diff --git a/examples/android/binder/java/io/grpc/binder/cpp/exampleclient/native.cc b/examples/android/binder/java/io/grpc/binder/cpp/exampleclient/native.cc
new file mode 100644
index 00000000..effdf2e2
--- /dev/null
+++ b/examples/android/binder/java/io/grpc/binder/cpp/exampleclient/native.cc
@@ -0,0 +1,53 @@
+// Copyright 2021 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <android/log.h>
+#include <jni.h>
+
+#include "examples/protos/helloworld.grpc.pb.h"
+#include "examples/protos/helloworld.pb.h"
+
+#include "src/core/ext/transport/binder/client/channel_create.h"
+#include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h"
+
+extern "C" JNIEXPORT jstring JNICALL
+Java_io_grpc_binder_cpp_exampleclient_ButtonPressHandler_native_1entry(
+    JNIEnv* env, jobject /*this*/, jobject application) {
+  static bool first = true;
+  __android_log_print(ANDROID_LOG_INFO, "DemoClient", "Line number %d",
+                      __LINE__);
+  if (first) {
+    first = false;
+    grpc::experimental::BindToOnDeviceServerService(
+        env, application, "io.grpc.binder.cpp.exampleserver",
+        "io.grpc.binder.cpp.exampleserver.ExportedEndpointService");
+    return env->NewStringUTF("Clicked 1 time");
+  } else {
+    // TODO(mingcl): Use same signature security after it become available
+    auto channel = grpc::experimental::CreateBinderChannel(
+        env, application, "", "",
+        std::make_shared<
+            grpc::experimental::binder::UntrustedSecurityPolicy>());
+    auto stub = helloworld::Greeter::NewStub(channel);
+    grpc::ClientContext context;
+    helloworld::HelloRequest request;
+    helloworld::HelloReply response;
+    request.set_name("BinderTransportClient");
+    grpc::Status status = stub->SayHello(&context, request, &response);
+    if (status.ok()) {
+      return env->NewStringUTF(response.message().c_str());
+    }
+    return env->NewStringUTF("Clicked more than 1 time. Status not ok");
+  }
+}
diff --git a/examples/android/binder/java/io/grpc/binder/cpp/exampleserver/native.cc b/examples/android/binder/java/io/grpc/binder/cpp/exampleserver/native.cc
new file mode 100644
index 00000000..2afe578b
--- /dev/null
+++ b/examples/android/binder/java/io/grpc/binder/cpp/exampleserver/native.cc
@@ -0,0 +1,72 @@
+// Copyright 2021 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include +#include +#include +#include +#include + +#include "examples/protos/helloworld.grpc.pb.h" +#include "examples/protos/helloworld.pb.h" + +#include + +#include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h" +#include "src/core/ext/transport/binder/server/binder_server.h" +#include "src/core/ext/transport/binder/server/binder_server_credentials.h" + +namespace { +class GreeterService : public helloworld::Greeter::Service { + public: + grpc::Status SayHello(grpc::ServerContext*, + const helloworld::HelloRequest* request, + helloworld::HelloReply* response) override { + __android_log_print(ANDROID_LOG_INFO, "DemoServer", "Line number %d", + __LINE__); + __android_log_print(ANDROID_LOG_INFO, "DemoServer", "Got hello request: %s", + request->name().c_str()); + response->set_message("Hi, " + request->name()); + return grpc::Status::OK; + } +}; + +} // namespace + +extern "C" JNIEXPORT void JNICALL +Java_io_grpc_binder_cpp_exampleserver_ExportedEndpointService_init_1grpc_1server( + JNIEnv* env, jobject /*this*/) { + __android_log_print(ANDROID_LOG_INFO, "DemoServer", "Line number %d", + __LINE__); + static std::unique_ptr server = nullptr; + + if (server != nullptr) { + // Already initiated + return; + } + + static GreeterService service; + grpc::ServerBuilder server_builder; + server_builder.RegisterService(&service); + + // TODO(mingcl): Use same signature security after it become available + server_builder.AddListeningPort( + "binder:example.service", + grpc::experimental::BinderServerCredentials( + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>())); + + server = server_builder.BuildAndStart(); +} diff --git a/examples/android/helloworld/app/src/main/cpp/grpc-helloworld.cc b/examples/android/helloworld/app/src/main/cpp/grpc-helloworld.cc new file mode 100644 index 00000000..7a31b783 --- /dev/null +++ b/examples/android/helloworld/app/src/main/cpp/grpc-helloworld.cc @@ -0,0 +1,142 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "helloworld.grpc.pb.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +std::atomic stop_server(false); + +// Logic and data behind the server's behavior. 
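The binder client above ends up using the generated stub exactly the way the TCP-based examples later in this diff do; only the channel construction differs. A minimal sketch of that shared calling pattern, assuming only the helloworld protos referenced above (illustrative, not part of the patch):

```cpp
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

#include "examples/protos/helloworld.grpc.pb.h"

// Works for any channel -- a binder channel from CreateBinderChannel() or a
// plain TCP channel from grpc::CreateChannel().
std::string SayHelloOnce(const std::shared_ptr<grpc::Channel>& channel,
                         const std::string& name) {
  auto stub = helloworld::Greeter::NewStub(channel);
  grpc::ClientContext context;
  helloworld::HelloRequest request;
  helloworld::HelloReply reply;
  request.set_name(name);
  grpc::Status status = stub->SayHello(&context, request, &reply);
  return status.ok() ? reply.message() : status.error_message();
}
```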
+class GreeterServiceImpl final : public Greeter::Service { + Status SayHello(ServerContext* context, const HelloRequest* request, + HelloReply* reply) override { + std::string prefix("Hello "); + reply->set_message(prefix + request->name()); + return Status::OK; + } +}; + +void StartServer(JNIEnv* env, jobject obj, jmethodID is_cancelled_mid, + int port) { + const int host_port_buf_size = 1024; + char host_port[host_port_buf_size]; + snprintf(host_port, host_port_buf_size, "0.0.0.0:%d", port); + + GreeterServiceImpl service; + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(host_port, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + while (!stop_server.load()) { + // Check with the Java code to see if the user has requested the server stop or the app is no + // longer in the foreground. + jboolean is_cancelled = env->CallBooleanMethod(obj, is_cancelled_mid); + if (is_cancelled == JNI_TRUE) { + stop_server = true; + } + } +} + +class GreeterClient { + public: + GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + // The actual RPC. + Status status = stub_->SayHello(&context, request, &reply); + + if (status.ok()) { + return reply.message(); + } else { + return status.error_message(); + } + } + + private: + std::unique_ptr stub_; +}; + +// Send an RPC and return the response. Invoked from Java code. +extern "C" JNIEXPORT jstring JNICALL +Java_io_grpc_helloworldexample_cpp_HelloworldActivity_sayHello( + JNIEnv* env, jobject obj_unused, jstring host_raw, jint port_raw, + jstring message_raw) { + const char* host_chars = env->GetStringUTFChars(host_raw, (jboolean*)0); + std::string host(host_chars, env->GetStringUTFLength(host_raw)); + + int port = static_cast(port_raw); + + const char* message_chars = env->GetStringUTFChars(message_raw, (jboolean*)0); + std::string message(message_chars, env->GetStringUTFLength(message_raw)); + + const int host_port_buf_size = 1024; + char host_port[host_port_buf_size]; + snprintf(host_port, host_port_buf_size, "%s:%d", host.c_str(), port); + + GreeterClient greeter( + grpc::CreateChannel(host_port, grpc::InsecureChannelCredentials())); + std::string reply = greeter.SayHello(message); + + return env->NewStringUTF(reply.c_str()); +} + +// Start the server. Invoked from Java code. 
+extern "C" JNIEXPORT void JNICALL +Java_io_grpc_helloworldexample_cpp_HelloworldActivity_startServer( + JNIEnv* env, jobject obj_this, jint port_raw) { + int port = static_cast(port_raw); + + jclass cls = env->GetObjectClass(obj_this); + jmethodID is_cancelled_mid = + env->GetMethodID(cls, "isRunServerTaskCancelled", "()Z"); + + stop_server = false; + + StartServer(env, obj_this, is_cancelled_mid, port); +} diff --git a/examples/cpp/compression/greeter_client.cc b/examples/cpp/compression/greeter_client.cc new file mode 100644 index 00000000..b1ef4546 --- /dev/null +++ b/examples/cpp/compression/greeter_client.cc @@ -0,0 +1,93 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ChannelArguments; +using grpc::ClientContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class GreeterClient { + public: + GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // Overwrite the call's compression algorithm to DEFLATE. + context.set_compression_algorithm(GRPC_COMPRESS_DEFLATE); + + // The actual RPC. + Status status = stub_->SayHello(&context, request, &reply); + + // Act upon its status. + if (status.ok()) { + return reply.message(); + } else { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint (in this case, + // localhost at port 50051). We indicate that the channel isn't authenticated + // (use of InsecureChannelCredentials()). + ChannelArguments args; + // Set the default compression algorithm for the channel. 
+ args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP); + GreeterClient greeter(grpc::CreateCustomChannel( + "localhost:50051", grpc::InsecureChannelCredentials(), args)); + std::string user("world world world world"); + std::string reply = greeter.SayHello(user); + std::cout << "Greeter received: " << reply << std::endl; + + return 0; +} diff --git a/examples/cpp/compression/greeter_server.cc b/examples/cpp/compression/greeter_server.cc new file mode 100644 index 00000000..dc1b1f65 --- /dev/null +++ b/examples/cpp/compression/greeter_server.cc @@ -0,0 +1,76 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +// Logic and data behind the server's behavior. +class GreeterServiceImpl final : public Greeter::Service { + Status SayHello(ServerContext* context, const HelloRequest* request, + HelloReply* reply) override { + // Overwrite the call's compression algorithm to DEFLATE. + context->set_compression_algorithm(GRPC_COMPRESS_DEFLATE); + std::string prefix("Hello "); + reply->set_message(prefix + request->name()); + return Status::OK; + } +}; + +void RunServer() { + std::string server_address("0.0.0.0:50051"); + GreeterServiceImpl service; + + ServerBuilder builder; + // Set the default compression algorithm for the server. + builder.SetDefaultCompressionAlgorithm(GRPC_COMPRESS_GZIP); + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. + server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(); + + return 0; +} diff --git a/examples/cpp/helloworld/greeter_async_client.cc b/examples/cpp/helloworld/greeter_async_client.cc new file mode 100644 index 00000000..ab02bbc4 --- /dev/null +++ b/examples/cpp/helloworld/greeter_async_client.cc @@ -0,0 +1,121 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
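The compression server above pins a concrete algorithm per call. gRPC also lets a handler request a compression *level* and have the library pick a supported algorithm for it; a hedged sketch of that variant, reusing the same helloworld service (class name here is illustrative):

```cpp
#include <grpcpp/grpcpp.h>

#ifdef BAZEL_BUILD
#include "examples/protos/helloworld.grpc.pb.h"
#else
#include "helloworld.grpc.pb.h"
#endif

// Same service shape as above, but requesting a compression level instead of
// a specific algorithm; gRPC maps the level to an algorithm the peer supports.
class LevelGreeterServiceImpl final : public helloworld::Greeter::Service {
  grpc::Status SayHello(grpc::ServerContext* context,
                        const helloworld::HelloRequest* request,
                        helloworld::HelloReply* reply) override {
    context->set_compression_level(GRPC_COMPRESS_LEVEL_HIGH);
    reply->set_message("Hello " + request->name());
    return grpc::Status::OK;
  }
};
```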
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientAsyncResponseReader; +using grpc::ClientContext; +using grpc::CompletionQueue; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class GreeterClient { + public: + explicit GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // The producer-consumer queue we use to communicate asynchronously with the + // gRPC runtime. + CompletionQueue cq; + + // Storage for the status of the RPC upon completion. + Status status; + + // stub_->PrepareAsyncSayHello() creates an RPC object, returning + // an instance to store in "call" but does not actually start the RPC + // Because we are using the asynchronous API, we need to hold on to + // the "call" instance in order to get updates on the ongoing RPC. + std::unique_ptr > rpc( + stub_->PrepareAsyncSayHello(&context, request, &cq)); + + // StartCall initiates the RPC call + rpc->StartCall(); + + // Request that, upon completion of the RPC, "reply" be updated with the + // server's response; "status" with the indication of whether the operation + // was successful. Tag the request with the integer 1. + rpc->Finish(&reply, &status, (void*)1); + void* got_tag; + bool ok = false; + // Block until the next result is available in the completion queue "cq". + // The return value of Next should always be checked. This return value + // tells us whether there is any kind of event or the cq_ is shutting down. + GPR_ASSERT(cq.Next(&got_tag, &ok)); + + // Verify that the result from "cq" corresponds, by its tag, our previous + // request. + GPR_ASSERT(got_tag == (void*)1); + // ... and that the request was completed successfully. Note that "ok" + // corresponds solely to the request for updates introduced by Finish(). + GPR_ASSERT(ok); + + // Act upon the status of the actual RPC. + if (status.ok()) { + return reply.message(); + } else { + return "RPC failed"; + } + } + + private: + // Out of the passed in Channel comes the stub, stored here, our view of the + // server's exposed services. + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint (in this case, + // localhost at port 50051). We indicate that the channel isn't authenticated + // (use of InsecureChannelCredentials()). 
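As the comment above notes, the channel in these examples is unauthenticated. If a TLS endpoint were available, the only change would be the credentials object; a sketch under that assumption (target and options are illustrative):

```cpp
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

// Builds a TLS channel using the default trusted roots; the result could be
// passed to GreeterClient in place of the insecure channel below.
std::shared_ptr<grpc::Channel> MakeTlsChannel(const std::string& target) {
  grpc::SslCredentialsOptions ssl_opts;  // default roots, no client cert
  return grpc::CreateChannel(target, grpc::SslCredentials(ssl_opts));
}
```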
+ GreeterClient greeter(grpc::CreateChannel( + "localhost:50051", grpc::InsecureChannelCredentials())); + std::string user("world"); + std::string reply = greeter.SayHello(user); // The actual RPC call! + std::cout << "Greeter received: " << reply << std::endl; + + return 0; +} diff --git a/examples/cpp/helloworld/greeter_async_client2.cc b/examples/cpp/helloworld/greeter_async_client2.cc new file mode 100644 index 00000000..54e75c9c --- /dev/null +++ b/examples/cpp/helloworld/greeter_async_client2.cc @@ -0,0 +1,143 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientAsyncResponseReader; +using grpc::ClientContext; +using grpc::CompletionQueue; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class GreeterClient { + public: + explicit GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload and sends it to the server. + void SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Call object to store rpc data + AsyncClientCall* call = new AsyncClientCall; + + // stub_->PrepareAsyncSayHello() creates an RPC object, returning + // an instance to store in "call" but does not actually start the RPC + // Because we are using the asynchronous API, we need to hold on to + // the "call" instance in order to get updates on the ongoing RPC. + call->response_reader = + stub_->PrepareAsyncSayHello(&call->context, request, &cq_); + + // StartCall initiates the RPC call + call->response_reader->StartCall(); + + // Request that, upon completion of the RPC, "reply" be updated with the + // server's response; "status" with the indication of whether the operation + // was successful. Tag the request with the memory address of the call + // object. + call->response_reader->Finish(&call->reply, &call->status, (void*)call); + } + + // Loop while listening for completed responses. + // Prints out the response from the server. + void AsyncCompleteRpc() { + void* got_tag; + bool ok = false; + + // Block until the next result is available in the completion queue "cq". + while (cq_.Next(&got_tag, &ok)) { + // The tag in this example is the memory location of the call object + AsyncClientCall* call = static_cast(got_tag); + + // Verify that the request was completed successfully. Note that "ok" + // corresponds solely to the request for updates introduced by Finish(). + GPR_ASSERT(ok); + + if (call->status.ok()) + std::cout << "Greeter received: " << call->reply.message() << std::endl; + else + std::cout << "RPC failed" << std::endl; + + // Once we're complete, deallocate the call object. 
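Each AsyncClientCall above owns its own ClientContext, so a per-RPC deadline could be attached when the call is created; without one, a hung server keeps the tag pending forever. A small sketch (the five-second value is arbitrary):

```cpp
#include <chrono>

#include <grpcpp/grpcpp.h>

// Would be called on call->context before PrepareAsyncSayHello(); the RPC then
// completes with StatusCode::DEADLINE_EXCEEDED if no reply arrives in time.
void AttachDeadline(grpc::ClientContext* context) {
  context->set_deadline(std::chrono::system_clock::now() +
                        std::chrono::seconds(5));
}
```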
+ delete call; + } + } + + private: + // struct for keeping state and data information + struct AsyncClientCall { + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // Storage for the status of the RPC upon completion. + Status status; + + std::unique_ptr> response_reader; + }; + + // Out of the passed in Channel comes the stub, stored here, our view of the + // server's exposed services. + std::unique_ptr stub_; + + // The producer-consumer queue we use to communicate asynchronously with the + // gRPC runtime. + CompletionQueue cq_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint (in this case, + // localhost at port 50051). We indicate that the channel isn't authenticated + // (use of InsecureChannelCredentials()). + GreeterClient greeter(grpc::CreateChannel( + "localhost:50051", grpc::InsecureChannelCredentials())); + + // Spawn reader thread that loops indefinitely + std::thread thread_ = std::thread(&GreeterClient::AsyncCompleteRpc, &greeter); + + for (int i = 0; i < 100; i++) { + std::string user("world " + std::to_string(i)); + greeter.SayHello(user); // The actual RPC call! + } + + std::cout << "Press control-c to quit" << std::endl << std::endl; + thread_.join(); // blocks forever + + return 0; +} diff --git a/examples/cpp/helloworld/greeter_async_server.cc b/examples/cpp/helloworld/greeter_async_server.cc new file mode 100644 index 00000000..d35d3e97 --- /dev/null +++ b/examples/cpp/helloworld/greeter_async_server.cc @@ -0,0 +1,171 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerAsyncResponseWriter; +using grpc::ServerBuilder; +using grpc::ServerCompletionQueue; +using grpc::ServerContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class ServerImpl final { + public: + ~ServerImpl() { + server_->Shutdown(); + // Always shutdown the completion queue after the server. + cq_->Shutdown(); + } + + // There is no shutdown handling in this code. + void Run() { + std::string server_address("0.0.0.0:50051"); + + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service_" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *asynchronous* service. 
+ builder.RegisterService(&service_); + // Get hold of the completion queue used for the asynchronous communication + // with the gRPC runtime. + cq_ = builder.AddCompletionQueue(); + // Finally assemble the server. + server_ = builder.BuildAndStart(); + std::cout << "Server listening on " << server_address << std::endl; + + // Proceed to the server's main loop. + HandleRpcs(); + } + + private: + // Class encompasing the state and logic needed to serve a request. + class CallData { + public: + // Take in the "service" instance (in this case representing an asynchronous + // server) and the completion queue "cq" used for asynchronous communication + // with the gRPC runtime. + CallData(Greeter::AsyncService* service, ServerCompletionQueue* cq) + : service_(service), cq_(cq), responder_(&ctx_), status_(CREATE) { + // Invoke the serving logic right away. + Proceed(); + } + + void Proceed() { + if (status_ == CREATE) { + // Make this instance progress to the PROCESS state. + status_ = PROCESS; + + // As part of the initial CREATE state, we *request* that the system + // start processing SayHello requests. In this request, "this" acts are + // the tag uniquely identifying the request (so that different CallData + // instances can serve different requests concurrently), in this case + // the memory address of this CallData instance. + service_->RequestSayHello(&ctx_, &request_, &responder_, cq_, cq_, + this); + } else if (status_ == PROCESS) { + // Spawn a new CallData instance to serve new clients while we process + // the one for this CallData. The instance will deallocate itself as + // part of its FINISH state. + new CallData(service_, cq_); + + // The actual processing. + std::string prefix("Hello "); + reply_.set_message(prefix + request_.name()); + + // And we are done! Let the gRPC runtime know we've finished, using the + // memory address of this instance as the uniquely identifying tag for + // the event. + status_ = FINISH; + responder_.Finish(reply_, Status::OK, this); + } else { + GPR_ASSERT(status_ == FINISH); + // Once in the FINISH state, deallocate ourselves (CallData). + delete this; + } + } + + private: + // The means of communication with the gRPC runtime for an asynchronous + // server. + Greeter::AsyncService* service_; + // The producer-consumer queue where for asynchronous server notifications. + ServerCompletionQueue* cq_; + // Context for the rpc, allowing to tweak aspects of it such as the use + // of compression, authentication, as well as to send metadata back to the + // client. + ServerContext ctx_; + + // What we get from the client. + HelloRequest request_; + // What we send back to the client. + HelloReply reply_; + + // The means to get back to the client. + ServerAsyncResponseWriter responder_; + + // Let's implement a tiny state machine with the following states. + enum CallStatus { CREATE, PROCESS, FINISH }; + CallStatus status_; // The current serving state. + }; + + // This can be run in multiple threads if needed. + void HandleRpcs() { + // Spawn a new CallData instance to serve new clients. + new CallData(&service_, cq_.get()); + void* tag; // uniquely identifies a request. + bool ok; + while (true) { + // Block waiting to read the next event from the completion queue. The + // event is uniquely identified by its tag, which in this case is the + // memory address of a CallData instance. + // The return value of Next should always be checked. This return value + // tells us whether there is any kind of event or cq_ is shutting down. 
+ GPR_ASSERT(cq_->Next(&tag, &ok)); + GPR_ASSERT(ok); + static_cast(tag)->Proceed(); + } + } + + std::unique_ptr cq_; + Greeter::AsyncService service_; + std::unique_ptr server_; +}; + +int main(int argc, char** argv) { + ServerImpl server; + server.Run(); + + return 0; +} diff --git a/examples/cpp/helloworld/greeter_callback_client.cc b/examples/cpp/helloworld/greeter_callback_client.cc new file mode 100644 index 00000000..0e92cb90 --- /dev/null +++ b/examples/cpp/helloworld/greeter_callback_client.cc @@ -0,0 +1,125 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class GreeterClient { + public: + GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // The actual RPC. + std::mutex mu; + std::condition_variable cv; + bool done = false; + Status status; + stub_->async()->SayHello(&context, &request, &reply, + [&mu, &cv, &done, &status](Status s) { + status = std::move(s); + std::lock_guard lock(mu); + done = true; + cv.notify_one(); + }); + + std::unique_lock lock(mu); + while (!done) { + cv.wait(lock); + } + + // Act upon its status. + if (status.ok()) { + return reply.message(); + } else { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint specified by + // the argument "--target=" which is the only expected argument. + // We indicate that the channel isn't authenticated (use of + // InsecureChannelCredentials()). 
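The callback client above hand-rolls its wait with a mutex and condition variable. The same blocking wait can be expressed with std::promise/std::future; a sketch assuming the Greeter stub and message types shown above:

```cpp
#include <future>
#include <utility>

#include <grpcpp/grpcpp.h>

#ifdef BAZEL_BUILD
#include "examples/protos/helloworld.grpc.pb.h"
#else
#include "helloworld.grpc.pb.h"
#endif

// Issues the callback-API SayHello and blocks until its completion callback
// fires, returning the final status.
grpc::Status SayHelloAndWait(helloworld::Greeter::Stub* stub,
                             grpc::ClientContext* context,
                             const helloworld::HelloRequest& request,
                             helloworld::HelloReply* reply) {
  std::promise<grpc::Status> done;
  stub->async()->SayHello(
      context, &request, reply,
      [&done](grpc::Status s) { done.set_value(std::move(s)); });
  return done.get_future().get();
}
```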
+ std::string target_str; + std::string arg_str("--target"); + if (argc > 1) { + std::string arg_val = argv[1]; + size_t start_pos = arg_val.find(arg_str); + if (start_pos != std::string::npos) { + start_pos += arg_str.size(); + if (arg_val[start_pos] == '=') { + target_str = arg_val.substr(start_pos + 1); + } else { + std::cout << "The only correct argument syntax is --target=" + << std::endl; + return 0; + } + } else { + std::cout << "The only acceptable argument is --target=" << std::endl; + return 0; + } + } else { + target_str = "localhost:50051"; + } + GreeterClient greeter( + grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); + std::string user("world"); + std::string reply = greeter.SayHello(user); + std::cout << "Greeter received: " << reply << std::endl; + + return 0; +} diff --git a/examples/cpp/helloworld/greeter_callback_server.cc b/examples/cpp/helloworld/greeter_callback_server.cc new file mode 100644 index 00000000..8f935611 --- /dev/null +++ b/examples/cpp/helloworld/greeter_callback_server.cc @@ -0,0 +1,81 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::CallbackServerContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerUnaryReactor; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +// Logic and data behind the server's behavior. +class GreeterServiceImpl final : public Greeter::CallbackService { + ServerUnaryReactor* SayHello(CallbackServerContext* context, + const HelloRequest* request, + HelloReply* reply) override { + std::string prefix("Hello "); + reply->set_message(prefix + request->name()); + + ServerUnaryReactor* reactor = context->DefaultReactor(); + reactor->Finish(Status::OK); + return reactor; + } +}; + +void RunServer() { + std::string server_address("0.0.0.0:50051"); + GreeterServiceImpl service; + + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. 
+ server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(); + + return 0; +} diff --git a/examples/cpp/helloworld/greeter_client.cc b/examples/cpp/helloworld/greeter_client.cc new file mode 100644 index 00000000..6b9c12d5 --- /dev/null +++ b/examples/cpp/helloworld/greeter_client.cc @@ -0,0 +1,108 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class GreeterClient { + public: + GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // The actual RPC. + Status status = stub_->SayHello(&context, request, &reply); + + // Act upon its status. + if (status.ok()) { + return reply.message(); + } else { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint specified by + // the argument "--target=" which is the only expected argument. + // We indicate that the channel isn't authenticated (use of + // InsecureChannelCredentials()). 
+ std::string target_str; + std::string arg_str("--target"); + if (argc > 1) { + std::string arg_val = argv[1]; + size_t start_pos = arg_val.find(arg_str); + if (start_pos != std::string::npos) { + start_pos += arg_str.size(); + if (arg_val[start_pos] == '=') { + target_str = arg_val.substr(start_pos + 1); + } else { + std::cout << "The only correct argument syntax is --target=" + << std::endl; + return 0; + } + } else { + std::cout << "The only acceptable argument is --target=" << std::endl; + return 0; + } + } else { + target_str = "localhost:50051"; + } + GreeterClient greeter( + grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); + std::string user("world"); + std::string reply = greeter.SayHello(user); + std::cout << "Greeter received: " << reply << std::endl; + + return 0; +} diff --git a/examples/cpp/helloworld/greeter_server.cc b/examples/cpp/helloworld/greeter_server.cc new file mode 100644 index 00000000..560b8b42 --- /dev/null +++ b/examples/cpp/helloworld/greeter_server.cc @@ -0,0 +1,76 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +// Logic and data behind the server's behavior. +class GreeterServiceImpl final : public Greeter::Service { + Status SayHello(ServerContext* context, const HelloRequest* request, + HelloReply* reply) override { + std::string prefix("Hello "); + reply->set_message(prefix + request->name()); + return Status::OK; + } +}; + +void RunServer() { + std::string server_address("0.0.0.0:50051"); + GreeterServiceImpl service; + + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. 
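The comment above says some other thread has to shut the server down for Wait() to return; none of the synchronous examples in this diff actually do that. A minimal sketch of such a hook (how it gets triggered -- signal handler, admin RPC, etc. -- is left open):

```cpp
#include <grpcpp/grpcpp.h>

// Called from another thread; makes the blocked server->Wait() return once
// in-flight RPCs drain. Shutdown() also has an overload taking a deadline.
void RequestServerShutdown(grpc::Server* server) { server->Shutdown(); }
```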
+ server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(); + + return 0; +} diff --git a/examples/cpp/helloworld/xds_greeter_client.cc b/examples/cpp/helloworld/xds_greeter_client.cc new file mode 100644 index 00000000..839625ec --- /dev/null +++ b/examples/cpp/helloworld/xds_greeter_client.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +ABSL_FLAG(std::string, target, "xds:///helloworld:50051", "Target string"); +ABSL_FLAG(bool, secure, true, "Secure mode"); + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class GreeterClient { + public: + GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // The actual RPC. + Status status = stub_->SayHello(&context, request, &reply); + + // Act upon its status. + if (status.ok()) { + return reply.message(); + } else { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + GreeterClient greeter(grpc::CreateChannel( + absl::GetFlag(FLAGS_target), absl::GetFlag(FLAGS_secure) + ? grpc::experimental::XdsCredentials( + grpc::InsecureChannelCredentials()) + : grpc::InsecureChannelCredentials())); + std::string user("world"); + std::string reply = greeter.SayHello(user); + std::cout << "Greeter received: " << reply << std::endl; + + return 0; +} diff --git a/examples/cpp/helloworld/xds_greeter_server.cc b/examples/cpp/helloworld/xds_greeter_server.cc new file mode 100644 index 00000000..21a964de --- /dev/null +++ b/examples/cpp/helloworld/xds_greeter_server.cc @@ -0,0 +1,108 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +ABSL_FLAG(int32_t, port, 50051, "Server port for service."); +ABSL_FLAG(int32_t, maintenance_port, 50052, + "Server port for maintenance if --secure is used."); +ABSL_FLAG(bool, secure, true, "Secure mode"); + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +// Logic and data behind the server's behavior. +class GreeterServiceImpl final : public Greeter::Service { + Status SayHello(ServerContext* context, const HelloRequest* request, + HelloReply* reply) override { + std::string prefix("Hello "); + reply->set_message(prefix + request->name()); + return Status::OK; + } +}; + +void RunServer() { + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + int port = absl::GetFlag(FLAGS_port); + int maintenance_port = absl::GetFlag(FLAGS_maintenance_port); + grpc::experimental::XdsServerBuilder xds_builder; + ServerBuilder builder; + std::unique_ptr xds_enabled_server; + std::unique_ptr server; + GreeterServiceImpl service; + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + xds_builder.RegisterService(&service); + if (absl::GetFlag(FLAGS_secure)) { + // Listen on the given address with XdsServerCredentials and a fallback of + // InsecureServerCredentials + xds_builder.AddListeningPort(absl::StrCat("0.0.0.0:", port), + grpc::experimental::XdsServerCredentials( + grpc::InsecureServerCredentials())); + xds_enabled_server = xds_builder.BuildAndStart(); + gpr_log(GPR_INFO, "Server starting on 0.0.0.0:%d", port); + grpc::AddAdminServices(&builder); + // For the maintenance server, do not use any authentication mechanism. + builder.AddListeningPort(absl::StrCat("0.0.0.0:", maintenance_port), + grpc::InsecureServerCredentials()); + server = builder.BuildAndStart(); + gpr_log(GPR_INFO, "Maintenance server listening on 0.0.0.0:%d", + maintenance_port); + } else { + grpc::AddAdminServices(&xds_builder); + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(absl::StrCat("0.0.0.0:", port), + grpc::InsecureServerCredentials()); + server = xds_builder.BuildAndStart(); + gpr_log(GPR_INFO, "Server listening on 0.0.0.0:%d", port); + } + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. 
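For the xds examples to start, the process also needs an xDS bootstrap file; gRPC reads its path from the GRPC_XDS_BOOTSTRAP environment variable. A hedged sketch of how a POSIX launcher might point at it before building the channel or server (path handling is illustrative):

```cpp
#include <cstdlib>

// POSIX-only sketch: have the xDS client in this process read its bootstrap
// configuration from the given JSON file before any xds:/// channel or
// XdsServerBuilder is created.
void PointAtXdsBootstrap(const char* bootstrap_json_path) {
  setenv("GRPC_XDS_BOOTSTRAP", bootstrap_json_path, /*overwrite=*/1);
}
```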
+ server->Wait(); +} + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + RunServer(); + return 0; +} diff --git a/examples/cpp/keyvaluestore/client.cc b/examples/cpp/keyvaluestore/client.cc new file mode 100644 index 00000000..75c09a0a --- /dev/null +++ b/examples/cpp/keyvaluestore/client.cc @@ -0,0 +1,99 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include "caching_interceptor.h" + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/keyvaluestore.grpc.pb.h" +#else +#include "keyvaluestore.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using keyvaluestore::KeyValueStore; +using keyvaluestore::Request; +using keyvaluestore::Response; + +class KeyValueStoreClient { + public: + KeyValueStoreClient(std::shared_ptr channel) + : stub_(KeyValueStore::NewStub(channel)) {} + + // Requests each key in the vector and displays the key and its corresponding + // value as a pair + void GetValues(const std::vector& keys) { + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + auto stream = stub_->GetValues(&context); + for (const auto& key : keys) { + // Key we are sending to the server. + Request request; + request.set_key(key); + stream->Write(request); + + // Get the value for the sent key + Response response; + stream->Read(&response); + std::cout << key << " : " << response.value() << "\n"; + } + stream->WritesDone(); + Status status = stream->Finish(); + if (!status.ok()) { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + std::cout << "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint (in this case, + // localhost at port 50051). We indicate that the channel isn't authenticated + // (use of InsecureChannelCredentials()). + // In this example, we are using a cache which has been added in as an + // interceptor. + grpc::ChannelArguments args; + std::vector< + std::unique_ptr> + interceptor_creators; + interceptor_creators.push_back(std::unique_ptr( + new CachingInterceptorFactory())); + auto channel = grpc::experimental::CreateCustomChannelWithInterceptors( + "localhost:50051", grpc::InsecureChannelCredentials(), args, + std::move(interceptor_creators)); + KeyValueStoreClient client(channel); + std::vector keys = {"key1", "key2", "key3", "key4", + "key5", "key1", "key2", "key4"}; + client.GetValues(keys); + + return 0; +} diff --git a/examples/cpp/keyvaluestore/server.cc b/examples/cpp/keyvaluestore/server.cc new file mode 100644 index 00000000..e75da9c6 --- /dev/null +++ b/examples/cpp/keyvaluestore/server.cc @@ -0,0 +1,97 @@ +/* + * + * Copyright 2018 gRPC authors. 
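The keyvaluestore client above pulls CachingInterceptorFactory from caching_interceptor.h, which is not part of this diff. For orientation, a hedged sketch of the minimal factory/interceptor shape such a header is expected to provide -- this version simply forwards every batch:

```cpp
#include <grpcpp/grpcpp.h>
#include <grpcpp/support/client_interceptor.h>

// Pass-through interceptor: every interception point just continues the RPC.
class PassThroughInterceptor : public grpc::experimental::Interceptor {
 public:
  void Intercept(
      grpc::experimental::InterceptorBatchMethods* methods) override {
    methods->Proceed();
  }
};

class PassThroughInterceptorFactory
    : public grpc::experimental::ClientInterceptorFactoryInterface {
 public:
  grpc::experimental::Interceptor* CreateClientInterceptor(
      grpc::experimental::ClientRpcInfo* /*info*/) override {
    return new PassThroughInterceptor();  // ownership passes to the channel
  }
};
```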
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/keyvaluestore.grpc.pb.h" +#else +#include "keyvaluestore.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::ServerReaderWriter; +using grpc::Status; +using keyvaluestore::KeyValueStore; +using keyvaluestore::Request; +using keyvaluestore::Response; + +struct kv_pair { + const char* key; + const char* value; +}; + +static const kv_pair kvs_map[] = { + {"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}, + {"key4", "value4"}, {"key5", "value5"}, +}; + +const char* get_value_from_map(const char* key) { + for (size_t i = 0; i < sizeof(kvs_map) / sizeof(kv_pair); ++i) { + if (strcmp(key, kvs_map[i].key) == 0) { + return kvs_map[i].value; + } + } + return ""; +} + +// Logic and data behind the server's behavior. +class KeyValueStoreServiceImpl final : public KeyValueStore::Service { + Status GetValues(ServerContext* context, + ServerReaderWriter* stream) override { + Request request; + while (stream->Read(&request)) { + Response response; + response.set_value(get_value_from_map(request.key().c_str())); + stream->Write(response); + } + return Status::OK; + } +}; + +void RunServer() { + std::string server_address("0.0.0.0:50051"); + KeyValueStoreServiceImpl service; + + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case, it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. + server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(); + + return 0; +} diff --git a/examples/cpp/load_balancing/greeter_client.cc b/examples/cpp/load_balancing/greeter_client.cc new file mode 100644 index 00000000..25322dbf --- /dev/null +++ b/examples/cpp/load_balancing/greeter_client.cc @@ -0,0 +1,90 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ChannelArguments; +using grpc::ClientContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class GreeterClient { + public: + GreeterClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // The actual RPC. + Status status = stub_->SayHello(&context, request, &reply); + + // Act upon its status. + if (status.ok()) { + return reply.message(); + } else { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint (in this case, + // localhost at port 50051). We indicate that the channel isn't authenticated + // (use of InsecureChannelCredentials()). + ChannelArguments args; + // Set the load balancing policy for the channel. + args.SetLoadBalancingPolicyName("round_robin"); + GreeterClient greeter(grpc::CreateCustomChannel( + "localhost:50051", grpc::InsecureChannelCredentials(), args)); + std::string user("world"); + std::string reply = greeter.SayHello(user); + std::cout << "Greeter received: " << reply << std::endl; + + return 0; +} diff --git a/examples/cpp/load_balancing/greeter_server.cc b/examples/cpp/load_balancing/greeter_server.cc new file mode 100644 index 00000000..d021b54c --- /dev/null +++ b/examples/cpp/load_balancing/greeter_server.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +// Logic and data behind the server's behavior. 
+class GreeterServiceImpl final : public Greeter::Service { + Status SayHello(ServerContext* context, const HelloRequest* request, + HelloReply* reply) override { + std::string prefix("Hello "); + reply->set_message(prefix + request->name()); + return Status::OK; + } +}; + +void RunServer() { + std::string server_address("0.0.0.0:50051"); + GreeterServiceImpl service; + + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. + server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(); + + return 0; +} diff --git a/examples/cpp/metadata/greeter_client.cc b/examples/cpp/metadata/greeter_client.cc new file mode 100644 index 00000000..3784d2fa --- /dev/null +++ b/examples/cpp/metadata/greeter_client.cc @@ -0,0 +1,102 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +class CustomHeaderClient { + public: + CustomHeaderClient(std::shared_ptr channel) + : stub_(Greeter::NewStub(channel)) {} + + // Assembles the client's payload, sends it and presents the response back + // from the server. + std::string SayHello(const std::string& user) { + // Data we are sending to the server. + HelloRequest request; + request.set_name(user); + + // Container for the data we expect from the server. + HelloReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // Setting custom metadata to be sent to the server + context.AddMetadata("custom-header", "Custom Value"); + + // Setting custom binary metadata + char bytes[8] = {'\0', '\1', '\2', '\3', '\4', '\5', '\6', '\7'}; + context.AddMetadata("custom-bin", std::string(bytes, 8)); + + // The actual RPC. + Status status = stub_->SayHello(&context, request, &reply); + + // Act upon its status. 
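+    // (The companion metadata server always attaches "custom-server-metadata"
+    // and "custom-trailing-metadata", so the find() calls below are not checked
+    // against end(); code talking to an arbitrary server should confirm the
+    // keys are present before dereferencing the iterator.)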
+ if (status.ok()) { + std::cout << "Client received initial metadata from server: " + << context.GetServerInitialMetadata() + .find("custom-server-metadata") + ->second + << std::endl; + std::cout << "Client received trailing metadata from server: " + << context.GetServerTrailingMetadata() + .find("custom-trailing-metadata") + ->second + << std::endl; + return reply.message(); + } else { + std::cout << status.error_code() << ": " << status.error_message() + << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint (in this case, + // localhost at port 50051). We indicate that the channel isn't authenticated + // (use of InsecureChannelCredentials()). + CustomHeaderClient greeter(grpc::CreateChannel( + "localhost:50051", grpc::InsecureChannelCredentials())); + std::string user("world"); + std::string reply = greeter.SayHello(user); + std::cout << "Client received message: " << reply << std::endl; + return 0; +} diff --git a/examples/cpp/metadata/greeter_server.cc b/examples/cpp/metadata/greeter_server.cc new file mode 100644 index 00000000..19d8bb0d --- /dev/null +++ b/examples/cpp/metadata/greeter_server.cc @@ -0,0 +1,97 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#ifdef BAZEL_BUILD +#include "examples/protos/helloworld.grpc.pb.h" +#else +#include "helloworld.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using helloworld::Greeter; +using helloworld::HelloReply; +using helloworld::HelloRequest; + +// Logic and data behind the server's behavior. 
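+// Binary-valued metadata is distinguished purely by the "-bin" suffix on the
+// key. The handler below performs that check inline; an equivalent standalone
+// predicate (a sketch with a hypothetical helper name, assuming plain
+// std::string keys) would be:
+//
+//   bool IsBinaryHeader(const std::string& key) {
+//     return key.size() >= 4 && key.compare(key.size() - 4, 4, "-bin") == 0;
+//   }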
+class GreeterServiceImpl final : public Greeter::Service { + Status SayHello(ServerContext* context, const HelloRequest* request, + HelloReply* reply) override { + std::string prefix("Hello "); + + // Get the client's initial metadata + std::cout << "Client metadata: " << std::endl; + const std::multimap metadata = + context->client_metadata(); + for (auto iter = metadata.begin(); iter != metadata.end(); ++iter) { + std::cout << "Header key: " << iter->first << ", value: "; + // Check for binary value + size_t isbin = iter->first.find("-bin"); + if ((isbin != std::string::npos) && (isbin + 4 == iter->first.size())) { + std::cout << std::hex; + for (auto c : iter->second) { + std::cout << static_cast(c); + } + std::cout << std::dec; + } else { + std::cout << iter->second; + } + std::cout << std::endl; + } + + context->AddInitialMetadata("custom-server-metadata", + "initial metadata value"); + context->AddTrailingMetadata("custom-trailing-metadata", + "trailing metadata value"); + reply->set_message(prefix + request->name()); + return Status::OK; + } +}; + +void RunServer() { + std::string server_address("0.0.0.0:50051"); + GreeterServiceImpl service; + + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. + server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(); + + return 0; +} diff --git a/examples/cpp/route_guide/helper.cc b/examples/cpp/route_guide/helper.cc new file mode 100644 index 00000000..cd889f6a --- /dev/null +++ b/examples/cpp/route_guide/helper.cc @@ -0,0 +1,164 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include +#ifdef BAZEL_BUILD +#include "examples/protos/route_guide.grpc.pb.h" +#else +#include "route_guide.grpc.pb.h" +#endif + +namespace routeguide { + +std::string GetDbFileContent(int argc, char** argv) { + std::string db_path; + std::string arg_str("--db_path"); + if (argc > 1) { + std::string argv_1 = argv[1]; + size_t start_position = argv_1.find(arg_str); + if (start_position != std::string::npos) { + start_position += arg_str.size(); + if (argv_1[start_position] == ' ' || argv_1[start_position] == '=') { + db_path = argv_1.substr(start_position + 1); + } + } + } else { +#ifdef BAZEL_BUILD + db_path = "cpp/route_guide/route_guide_db.json"; +#else + db_path = "route_guide_db.json"; +#endif + } + std::ifstream db_file(db_path); + if (!db_file.is_open()) { + std::cout << "Failed to open " << db_path << std::endl; + return ""; + } + std::stringstream db; + db << db_file.rdbuf(); + return db.str(); +} + +// A simple parser for the json db file. It requires the db file to have the +// exact form of [{"location": { "latitude": 123, "longitude": 456}, "name": +// "the name can be empty" }, { ... } ... The spaces will be stripped. +class Parser { + public: + explicit Parser(const std::string& db) : db_(db) { + // Remove all spaces. + db_.erase(std::remove_if(db_.begin(), db_.end(), isspace), db_.end()); + if (!Match("[")) { + SetFailedAndReturnFalse(); + } + } + + bool Finished() { return current_ >= db_.size(); } + + bool TryParseOne(Feature* feature) { + if (failed_ || Finished() || !Match("{")) { + return SetFailedAndReturnFalse(); + } + if (!Match(location_) || !Match("{") || !Match(latitude_)) { + return SetFailedAndReturnFalse(); + } + long temp = 0; + ReadLong(&temp); + feature->mutable_location()->set_latitude(temp); + if (!Match(",") || !Match(longitude_)) { + return SetFailedAndReturnFalse(); + } + ReadLong(&temp); + feature->mutable_location()->set_longitude(temp); + if (!Match("},") || !Match(name_) || !Match("\"")) { + return SetFailedAndReturnFalse(); + } + size_t name_start = current_; + while (current_ != db_.size() && db_[current_++] != '"') { + } + if (current_ == db_.size()) { + return SetFailedAndReturnFalse(); + } + feature->set_name(db_.substr(name_start, current_ - name_start - 1)); + if (!Match("},")) { + if (db_[current_ - 1] == ']' && current_ == db_.size()) { + return true; + } + return SetFailedAndReturnFalse(); + } + return true; + } + + private: + bool SetFailedAndReturnFalse() { + failed_ = true; + return false; + } + + bool Match(const std::string& prefix) { + bool eq = db_.substr(current_, prefix.size()) == prefix; + current_ += prefix.size(); + return eq; + } + + void ReadLong(long* l) { + size_t start = current_; + while (current_ != db_.size() && db_[current_] != ',' && + db_[current_] != '}') { + current_++; + } + // It will throw an exception if fails. 
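+    // (Specifically, std::stol raises std::invalid_argument for non-numeric
+    // input and std::out_of_range for values that do not fit in a long.)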
+ *l = std::stol(db_.substr(start, current_ - start)); + } + + bool failed_ = false; + std::string db_; + size_t current_ = 0; + const std::string location_ = "\"location\":"; + const std::string latitude_ = "\"latitude\":"; + const std::string longitude_ = "\"longitude\":"; + const std::string name_ = "\"name\":"; +}; + +void ParseDb(const std::string& db, std::vector* feature_list) { + feature_list->clear(); + std::string db_content(db); + db_content.erase( + std::remove_if(db_content.begin(), db_content.end(), isspace), + db_content.end()); + + Parser parser(db_content); + Feature feature; + while (!parser.Finished()) { + feature_list->push_back(Feature()); + if (!parser.TryParseOne(&feature_list->back())) { + std::cout << "Error parsing the db file"; + feature_list->clear(); + break; + } + } + std::cout << "DB parsed, loaded " << feature_list->size() << " features." + << std::endl; +} + +} // namespace routeguide diff --git a/examples/cpp/route_guide/route_guide_callback_client.cc b/examples/cpp/route_guide/route_guide_callback_client.cc new file mode 100644 index 00000000..4bab4ae3 --- /dev/null +++ b/examples/cpp/route_guide/route_guide_callback_client.cc @@ -0,0 +1,360 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "helper.h" + +#include +#include +#include +#include +#include +#include +#ifdef BAZEL_BUILD +#include "examples/protos/route_guide.grpc.pb.h" +#else +#include "route_guide.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using routeguide::Feature; +using routeguide::Point; +using routeguide::Rectangle; +using routeguide::RouteGuide; +using routeguide::RouteNote; +using routeguide::RouteSummary; + +Point MakePoint(long latitude, long longitude) { + Point p; + p.set_latitude(latitude); + p.set_longitude(longitude); + return p; +} + +Feature MakeFeature(const std::string& name, long latitude, long longitude) { + Feature f; + f.set_name(name); + f.mutable_location()->CopyFrom(MakePoint(latitude, longitude)); + return f; +} + +RouteNote MakeRouteNote(const std::string& message, long latitude, + long longitude) { + RouteNote n; + n.set_message(message); + n.mutable_location()->CopyFrom(MakePoint(latitude, longitude)); + return n; +} + +class RouteGuideClient { + public: + RouteGuideClient(std::shared_ptr channel, const std::string& db) + : stub_(RouteGuide::NewStub(channel)) { + routeguide::ParseDb(db, &feature_list_); + } + + void GetFeature() { + Point point; + Feature feature; + point = MakePoint(409146138, -746188906); + GetOneFeature(point, &feature); + point = MakePoint(0, 0); + GetOneFeature(point, &feature); + } + + void ListFeatures() { + routeguide::Rectangle rect; + Feature feature; + + rect.mutable_lo()->set_latitude(400000000); + rect.mutable_lo()->set_longitude(-750000000); + rect.mutable_hi()->set_latitude(420000000); + rect.mutable_hi()->set_longitude(-730000000); + std::cout << "Looking for features between 40, -75 and 42, -73" + << std::endl; + + class Reader : public grpc::ClientReadReactor { + public: + Reader(RouteGuide::Stub* stub, float coord_factor, + const routeguide::Rectangle& rect) + : coord_factor_(coord_factor) { + stub->async()->ListFeatures(&context_, &rect, this); + StartRead(&feature_); + StartCall(); + } + void OnReadDone(bool ok) override { + if (ok) { + std::cout << "Found feature called " << feature_.name() << " at " + << feature_.location().latitude() / coord_factor_ << ", " + << feature_.location().longitude() / coord_factor_ + << std::endl; + StartRead(&feature_); + } + } + void OnDone(const Status& s) override { + std::unique_lock l(mu_); + status_ = s; + done_ = true; + cv_.notify_one(); + } + Status Await() { + std::unique_lock l(mu_); + cv_.wait(l, [this] { return done_; }); + return std::move(status_); + } + + private: + ClientContext context_; + float coord_factor_; + Feature feature_; + std::mutex mu_; + std::condition_variable cv_; + Status status_; + bool done_ = false; + }; + Reader reader(stub_.get(), kCoordFactor_, rect); + Status status = std::move(reader.Await()); + if (status.ok()) { + std::cout << "ListFeatures rpc succeeded." << std::endl; + } else { + std::cout << "ListFeatures rpc failed." 
<< std::endl; + } + } + + void RecordRoute() { + class Recorder : public grpc::ClientWriteReactor { + public: + Recorder(RouteGuide::Stub* stub, float coord_factor, + const std::vector* feature_list) + : coord_factor_(coord_factor), + feature_list_(feature_list), + generator_( + std::chrono::system_clock::now().time_since_epoch().count()), + feature_distribution_(0, feature_list->size() - 1), + delay_distribution_(500, 1500) { + stub->async()->RecordRoute(&context_, &stats_, this); + // Use a hold since some StartWrites are invoked indirectly from a + // delayed lambda in OnWriteDone rather than directly from the reaction + // itself + AddHold(); + NextWrite(); + StartCall(); + } + void OnWriteDone(bool ok) override { + // Delay and then do the next write or WritesDone + alarm_.Set( + std::chrono::system_clock::now() + + std::chrono::milliseconds(delay_distribution_(generator_)), + [this](bool /*ok*/) { NextWrite(); }); + } + void OnDone(const Status& s) override { + std::unique_lock l(mu_); + status_ = s; + done_ = true; + cv_.notify_one(); + } + Status Await(RouteSummary* stats) { + std::unique_lock l(mu_); + cv_.wait(l, [this] { return done_; }); + *stats = stats_; + return std::move(status_); + } + + private: + void NextWrite() { + if (points_remaining_ != 0) { + const Feature& f = + (*feature_list_)[feature_distribution_(generator_)]; + std::cout << "Visiting point " + << f.location().latitude() / coord_factor_ << ", " + << f.location().longitude() / coord_factor_ << std::endl; + StartWrite(&f.location()); + points_remaining_--; + } else { + StartWritesDone(); + RemoveHold(); + } + } + ClientContext context_; + float coord_factor_; + int points_remaining_ = 10; + Point point_; + RouteSummary stats_; + const std::vector* feature_list_; + std::default_random_engine generator_; + std::uniform_int_distribution feature_distribution_; + std::uniform_int_distribution delay_distribution_; + grpc::Alarm alarm_; + std::mutex mu_; + std::condition_variable cv_; + Status status_; + bool done_ = false; + }; + Recorder recorder(stub_.get(), kCoordFactor_, &feature_list_); + RouteSummary stats; + Status status = std::move(recorder.Await(&stats)); + if (status.ok()) { + std::cout << "Finished trip with " << stats.point_count() << " points\n" + << "Passed " << stats.feature_count() << " features\n" + << "Travelled " << stats.distance() << " meters\n" + << "It took " << stats.elapsed_time() << " seconds" + << std::endl; + } else { + std::cout << "RecordRoute rpc failed." 
<< std::endl; + } + } + + void RouteChat() { + class Chatter : public grpc::ClientBidiReactor { + public: + explicit Chatter(RouteGuide::Stub* stub) + : notes_{MakeRouteNote("First message", 0, 0), + MakeRouteNote("Second message", 0, 1), + MakeRouteNote("Third message", 1, 0), + MakeRouteNote("Fourth message", 0, 0)}, + notes_iterator_(notes_.begin()) { + stub->async()->RouteChat(&context_, this); + NextWrite(); + StartRead(&server_note_); + StartCall(); + } + void OnWriteDone(bool /*ok*/) override { NextWrite(); } + void OnReadDone(bool ok) override { + if (ok) { + std::cout << "Got message " << server_note_.message() << " at " + << server_note_.location().latitude() << ", " + << server_note_.location().longitude() << std::endl; + StartRead(&server_note_); + } + } + void OnDone(const Status& s) override { + std::unique_lock l(mu_); + status_ = s; + done_ = true; + cv_.notify_one(); + } + Status Await() { + std::unique_lock l(mu_); + cv_.wait(l, [this] { return done_; }); + return std::move(status_); + } + + private: + void NextWrite() { + if (notes_iterator_ != notes_.end()) { + const auto& note = *notes_iterator_; + std::cout << "Sending message " << note.message() << " at " + << note.location().latitude() << ", " + << note.location().longitude() << std::endl; + StartWrite(¬e); + notes_iterator_++; + } else { + StartWritesDone(); + } + } + ClientContext context_; + const std::vector notes_; + std::vector::const_iterator notes_iterator_; + RouteNote server_note_; + std::mutex mu_; + std::condition_variable cv_; + Status status_; + bool done_ = false; + }; + + Chatter chatter(stub_.get()); + Status status = std::move(chatter.Await()); + if (!status.ok()) { + std::cout << "RouteChat rpc failed." << std::endl; + } + } + + private: + bool GetOneFeature(const Point& point, Feature* feature) { + ClientContext context; + bool result; + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->async()->GetFeature( + &context, &point, feature, + [&result, &mu, &cv, &done, feature, this](Status status) { + bool ret; + if (!status.ok()) { + std::cout << "GetFeature rpc failed." << std::endl; + ret = false; + } else if (!feature->has_location()) { + std::cout << "Server returns incomplete feature." << std::endl; + ret = false; + } else if (feature->name().empty()) { + std::cout << "Found no feature at " + << feature->location().latitude() / kCoordFactor_ << ", " + << feature->location().longitude() / kCoordFactor_ + << std::endl; + ret = true; + } else { + std::cout << "Found feature called " << feature->name() << " at " + << feature->location().latitude() / kCoordFactor_ << ", " + << feature->location().longitude() / kCoordFactor_ + << std::endl; + ret = true; + } + std::lock_guard lock(mu); + result = ret; + done = true; + cv.notify_one(); + }); + std::unique_lock lock(mu); + cv.wait(lock, [&done] { return done; }); + return result; + } + + const float kCoordFactor_ = 10000000.0; + std::unique_ptr stub_; + std::vector feature_list_; +}; + +int main(int argc, char** argv) { + // Expect only arg: --db_path=path/to/route_guide_db.json. 
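+  // GetDbFileContent() only inspects argv[1]: it accepts a single
+  // "--db_path=<file>" argument and, when no argument is given at all, falls
+  // back to the bundled route_guide_db.json.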
+ std::string db = routeguide::GetDbFileContent(argc, argv); + RouteGuideClient guide( + grpc::CreateChannel("localhost:50051", + grpc::InsecureChannelCredentials()), + db); + + std::cout << "-------------- GetFeature --------------" << std::endl; + guide.GetFeature(); + std::cout << "-------------- ListFeatures --------------" << std::endl; + guide.ListFeatures(); + std::cout << "-------------- RecordRoute --------------" << std::endl; + guide.RecordRoute(); + std::cout << "-------------- RouteChat --------------" << std::endl; + guide.RouteChat(); + + return 0; +} diff --git a/examples/cpp/route_guide/route_guide_callback_server.cc b/examples/cpp/route_guide/route_guide_callback_server.cc new file mode 100644 index 00000000..1c9fb5f6 --- /dev/null +++ b/examples/cpp/route_guide/route_guide_callback_server.cc @@ -0,0 +1,278 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "helper.h" + +#include +#include +#include +#include +#include +#ifdef BAZEL_BUILD +#include "examples/protos/route_guide.grpc.pb.h" +#else +#include "route_guide.grpc.pb.h" +#endif + +using grpc::CallbackServerContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::Status; +using routeguide::Feature; +using routeguide::Point; +using routeguide::Rectangle; +using routeguide::RouteGuide; +using routeguide::RouteNote; +using routeguide::RouteSummary; +using std::chrono::system_clock; + +float ConvertToRadians(float num) { return num * 3.1415926 / 180; } + +// The formula is based on http://mathforum.org/library/drmath/view/51879.html +float GetDistance(const Point& start, const Point& end) { + const float kCoordFactor = 10000000.0; + float lat_1 = start.latitude() / kCoordFactor; + float lat_2 = end.latitude() / kCoordFactor; + float lon_1 = start.longitude() / kCoordFactor; + float lon_2 = end.longitude() / kCoordFactor; + float lat_rad_1 = ConvertToRadians(lat_1); + float lat_rad_2 = ConvertToRadians(lat_2); + float delta_lat_rad = ConvertToRadians(lat_2 - lat_1); + float delta_lon_rad = ConvertToRadians(lon_2 - lon_1); + + float a = pow(sin(delta_lat_rad / 2), 2) + + cos(lat_rad_1) * cos(lat_rad_2) * pow(sin(delta_lon_rad / 2), 2); + float c = 2 * atan2(sqrt(a), sqrt(1 - a)); + int R = 6371000; // metres + + return R * c; +} + +std::string GetFeatureName(const Point& point, + const std::vector& feature_list) { + for (const Feature& f : feature_list) { + if (f.location().latitude() == point.latitude() && + f.location().longitude() == point.longitude()) { + return f.name(); + } + } + return ""; +} + +class RouteGuideImpl final : public RouteGuide::CallbackService { + public: + explicit RouteGuideImpl(const std::string& db) { + routeguide::ParseDb(db, &feature_list_); + } + + grpc::ServerUnaryReactor* GetFeature(CallbackServerContext* context, + const Point* point, + Feature* feature) override { + feature->set_name(GetFeatureName(*point, feature_list_)); + 
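+    // Echo the requested location back; a simple unary callback method only
+    // needs the context-provided default reactor, which is finished right away.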
feature->mutable_location()->CopyFrom(*point); + auto* reactor = context->DefaultReactor(); + reactor->Finish(Status::OK); + return reactor; + } + + grpc::ServerWriteReactor* ListFeatures( + CallbackServerContext* context, + const routeguide::Rectangle* rectangle) override { + class Lister : public grpc::ServerWriteReactor { + public: + Lister(const routeguide::Rectangle* rectangle, + const std::vector* feature_list) + : left_((std::min)(rectangle->lo().longitude(), + rectangle->hi().longitude())), + right_((std::max)(rectangle->lo().longitude(), + rectangle->hi().longitude())), + top_((std::max)(rectangle->lo().latitude(), + rectangle->hi().latitude())), + bottom_((std::min)(rectangle->lo().latitude(), + rectangle->hi().latitude())), + feature_list_(feature_list), + next_feature_(feature_list_->begin()) { + NextWrite(); + } + void OnDone() override { delete this; } + void OnWriteDone(bool /*ok*/) override { NextWrite(); } + + private: + void NextWrite() { + while (next_feature_ != feature_list_->end()) { + const Feature& f = *next_feature_; + next_feature_++; + if (f.location().longitude() >= left_ && + f.location().longitude() <= right_ && + f.location().latitude() >= bottom_ && + f.location().latitude() <= top_) { + StartWrite(&f); + return; + } + } + // Didn't write anything, all is done. + Finish(Status::OK); + } + const long left_; + const long right_; + const long top_; + const long bottom_; + const std::vector* feature_list_; + std::vector::const_iterator next_feature_; + }; + return new Lister(rectangle, &feature_list_); + } + + grpc::ServerReadReactor* RecordRoute(CallbackServerContext* context, + RouteSummary* summary) override { + class Recorder : public grpc::ServerReadReactor { + public: + Recorder(RouteSummary* summary, const std::vector* feature_list) + : start_time_(system_clock::now()), + summary_(summary), + feature_list_(feature_list) { + StartRead(&point_); + } + void OnDone() { delete this; } + void OnReadDone(bool ok) { + if (ok) { + point_count_++; + if (!GetFeatureName(point_, *feature_list_).empty()) { + feature_count_++; + } + if (point_count_ != 1) { + distance_ += GetDistance(previous_, point_); + } + previous_ = point_; + StartRead(&point_); + } else { + summary_->set_point_count(point_count_); + summary_->set_feature_count(feature_count_); + summary_->set_distance(static_cast(distance_)); + auto secs = std::chrono::duration_cast( + system_clock::now() - start_time_); + summary_->set_elapsed_time(secs.count()); + Finish(Status::OK); + } + } + + private: + system_clock::time_point start_time_; + RouteSummary* summary_; + const std::vector* feature_list_; + Point point_; + int point_count_ = 0; + int feature_count_ = 0; + float distance_ = 0.0; + Point previous_; + }; + return new Recorder(summary, &feature_list_); + } + + grpc::ServerBidiReactor* RouteChat( + CallbackServerContext* context) override { + class Chatter : public grpc::ServerBidiReactor { + public: + Chatter(std::mutex* mu, std::vector* received_notes) + : mu_(mu), received_notes_(received_notes) { + StartRead(¬e_); + } + void OnDone() override { + // Collect the read_starter thread if needed + if (read_starter_.joinable()) { + read_starter_.join(); + } + delete this; + } + void OnReadDone(bool ok) override { + if (ok) { + // We may need to wait an arbitary amount of time on this mutex + // and we cannot delay the reaction, so start it in a thread + // Collect the previous read_starter thread if needed + if (read_starter_.joinable()) { + read_starter_.join(); + } + read_starter_ = 
std::thread([this] { + mu_->lock(); + notes_iterator_ = received_notes_->begin(); + NextWrite(); + }); + } else { + Finish(Status::OK); + } + } + void OnWriteDone(bool /*ok*/) override { NextWrite(); } + + private: + void NextWrite() { + while (notes_iterator_ != received_notes_->end()) { + const RouteNote& n = *notes_iterator_; + notes_iterator_++; + if (n.location().latitude() == note_.location().latitude() && + n.location().longitude() == note_.location().longitude()) { + StartWrite(&n); + return; + } + } + // Didn't write anything, so all done with this note + received_notes_->push_back(note_); + mu_->unlock(); + StartRead(¬e_); + } + RouteNote note_; + std::mutex* mu_; + std::vector* received_notes_; + std::vector::iterator notes_iterator_; + std::thread read_starter_; + }; + return new Chatter(&mu_, &received_notes_); + } + + private: + std::vector feature_list_; + std::mutex mu_; + std::vector received_notes_; +}; + +void RunServer(const std::string& db_path) { + std::string server_address("0.0.0.0:50051"); + RouteGuideImpl service(db_path); + + ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + server->Wait(); +} + +int main(int argc, char** argv) { + // Expect only arg: --db_path=path/to/route_guide_db.json. + std::string db = routeguide::GetDbFileContent(argc, argv); + RunServer(db); + + return 0; +} diff --git a/examples/cpp/route_guide/route_guide_client.cc b/examples/cpp/route_guide/route_guide_client.cc new file mode 100644 index 00000000..fd29a31e --- /dev/null +++ b/examples/cpp/route_guide/route_guide_client.cc @@ -0,0 +1,236 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include + +#include "helper.h" + +#include +#include +#include +#include +#include +#ifdef BAZEL_BUILD +#include "examples/protos/route_guide.grpc.pb.h" +#else +#include "route_guide.grpc.pb.h" +#endif + +using grpc::Channel; +using grpc::ClientContext; +using grpc::ClientReader; +using grpc::ClientReaderWriter; +using grpc::ClientWriter; +using grpc::Status; +using routeguide::Feature; +using routeguide::Point; +using routeguide::Rectangle; +using routeguide::RouteGuide; +using routeguide::RouteNote; +using routeguide::RouteSummary; + +Point MakePoint(long latitude, long longitude) { + Point p; + p.set_latitude(latitude); + p.set_longitude(longitude); + return p; +} + +Feature MakeFeature(const std::string& name, long latitude, long longitude) { + Feature f; + f.set_name(name); + f.mutable_location()->CopyFrom(MakePoint(latitude, longitude)); + return f; +} + +RouteNote MakeRouteNote(const std::string& message, long latitude, + long longitude) { + RouteNote n; + n.set_message(message); + n.mutable_location()->CopyFrom(MakePoint(latitude, longitude)); + return n; +} + +class RouteGuideClient { + public: + RouteGuideClient(std::shared_ptr channel, const std::string& db) + : stub_(RouteGuide::NewStub(channel)) { + routeguide::ParseDb(db, &feature_list_); + } + + void GetFeature() { + Point point; + Feature feature; + point = MakePoint(409146138, -746188906); + GetOneFeature(point, &feature); + point = MakePoint(0, 0); + GetOneFeature(point, &feature); + } + + void ListFeatures() { + routeguide::Rectangle rect; + Feature feature; + ClientContext context; + + rect.mutable_lo()->set_latitude(400000000); + rect.mutable_lo()->set_longitude(-750000000); + rect.mutable_hi()->set_latitude(420000000); + rect.mutable_hi()->set_longitude(-730000000); + std::cout << "Looking for features between 40, -75 and 42, -73" + << std::endl; + + std::unique_ptr > reader( + stub_->ListFeatures(&context, rect)); + while (reader->Read(&feature)) { + std::cout << "Found feature called " << feature.name() << " at " + << feature.location().latitude() / kCoordFactor_ << ", " + << feature.location().longitude() / kCoordFactor_ << std::endl; + } + Status status = reader->Finish(); + if (status.ok()) { + std::cout << "ListFeatures rpc succeeded." << std::endl; + } else { + std::cout << "ListFeatures rpc failed." << std::endl; + } + } + + void RecordRoute() { + Point point; + RouteSummary stats; + ClientContext context; + const int kPoints = 10; + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + + std::default_random_engine generator(seed); + std::uniform_int_distribution feature_distribution( + 0, feature_list_.size() - 1); + std::uniform_int_distribution delay_distribution(500, 1500); + + std::unique_ptr > writer( + stub_->RecordRoute(&context, &stats)); + for (int i = 0; i < kPoints; i++) { + const Feature& f = feature_list_[feature_distribution(generator)]; + std::cout << "Visiting point " << f.location().latitude() / kCoordFactor_ + << ", " << f.location().longitude() / kCoordFactor_ + << std::endl; + if (!writer->Write(f.location())) { + // Broken stream. 
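+        // (Write() returning false indicates the stream is already closed,
+        // e.g. the call failed or the server went away.)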
+ break; + } + std::this_thread::sleep_for( + std::chrono::milliseconds(delay_distribution(generator))); + } + writer->WritesDone(); + Status status = writer->Finish(); + if (status.ok()) { + std::cout << "Finished trip with " << stats.point_count() << " points\n" + << "Passed " << stats.feature_count() << " features\n" + << "Travelled " << stats.distance() << " meters\n" + << "It took " << stats.elapsed_time() << " seconds" + << std::endl; + } else { + std::cout << "RecordRoute rpc failed." << std::endl; + } + } + + void RouteChat() { + ClientContext context; + + std::shared_ptr > stream( + stub_->RouteChat(&context)); + + std::thread writer([stream]() { + std::vector notes{MakeRouteNote("First message", 0, 0), + MakeRouteNote("Second message", 0, 1), + MakeRouteNote("Third message", 1, 0), + MakeRouteNote("Fourth message", 0, 0)}; + for (const RouteNote& note : notes) { + std::cout << "Sending message " << note.message() << " at " + << note.location().latitude() << ", " + << note.location().longitude() << std::endl; + stream->Write(note); + } + stream->WritesDone(); + }); + + RouteNote server_note; + while (stream->Read(&server_note)) { + std::cout << "Got message " << server_note.message() << " at " + << server_note.location().latitude() << ", " + << server_note.location().longitude() << std::endl; + } + writer.join(); + Status status = stream->Finish(); + if (!status.ok()) { + std::cout << "RouteChat rpc failed." << std::endl; + } + } + + private: + bool GetOneFeature(const Point& point, Feature* feature) { + ClientContext context; + Status status = stub_->GetFeature(&context, point, feature); + if (!status.ok()) { + std::cout << "GetFeature rpc failed." << std::endl; + return false; + } + if (!feature->has_location()) { + std::cout << "Server returns incomplete feature." << std::endl; + return false; + } + if (feature->name().empty()) { + std::cout << "Found no feature at " + << feature->location().latitude() / kCoordFactor_ << ", " + << feature->location().longitude() / kCoordFactor_ << std::endl; + } else { + std::cout << "Found feature called " << feature->name() << " at " + << feature->location().latitude() / kCoordFactor_ << ", " + << feature->location().longitude() / kCoordFactor_ << std::endl; + } + return true; + } + + const float kCoordFactor_ = 10000000.0; + std::unique_ptr stub_; + std::vector feature_list_; +}; + +int main(int argc, char** argv) { + // Expect only arg: --db_path=path/to/route_guide_db.json. + std::string db = routeguide::GetDbFileContent(argc, argv); + RouteGuideClient guide( + grpc::CreateChannel("localhost:50051", + grpc::InsecureChannelCredentials()), + db); + + std::cout << "-------------- GetFeature --------------" << std::endl; + guide.GetFeature(); + std::cout << "-------------- ListFeatures --------------" << std::endl; + guide.ListFeatures(); + std::cout << "-------------- RecordRoute --------------" << std::endl; + guide.RecordRoute(); + std::cout << "-------------- RouteChat --------------" << std::endl; + guide.RouteChat(); + + return 0; +} diff --git a/examples/cpp/route_guide/route_guide_server.cc b/examples/cpp/route_guide/route_guide_server.cc new file mode 100644 index 00000000..0e291258 --- /dev/null +++ b/examples/cpp/route_guide/route_guide_server.cc @@ -0,0 +1,190 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include + +#include "helper.h" + +#include +#include +#include +#include +#include +#ifdef BAZEL_BUILD +#include "examples/protos/route_guide.grpc.pb.h" +#else +#include "route_guide.grpc.pb.h" +#endif + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::ServerReader; +using grpc::ServerReaderWriter; +using grpc::ServerWriter; +using grpc::Status; +using routeguide::Feature; +using routeguide::Point; +using routeguide::Rectangle; +using routeguide::RouteGuide; +using routeguide::RouteNote; +using routeguide::RouteSummary; +using std::chrono::system_clock; + +float ConvertToRadians(float num) { return num * 3.1415926 / 180; } + +// The formula is based on http://mathforum.org/library/drmath/view/51879.html +float GetDistance(const Point& start, const Point& end) { + const float kCoordFactor = 10000000.0; + float lat_1 = start.latitude() / kCoordFactor; + float lat_2 = end.latitude() / kCoordFactor; + float lon_1 = start.longitude() / kCoordFactor; + float lon_2 = end.longitude() / kCoordFactor; + float lat_rad_1 = ConvertToRadians(lat_1); + float lat_rad_2 = ConvertToRadians(lat_2); + float delta_lat_rad = ConvertToRadians(lat_2 - lat_1); + float delta_lon_rad = ConvertToRadians(lon_2 - lon_1); + + float a = pow(sin(delta_lat_rad / 2), 2) + + cos(lat_rad_1) * cos(lat_rad_2) * pow(sin(delta_lon_rad / 2), 2); + float c = 2 * atan2(sqrt(a), sqrt(1 - a)); + int R = 6371000; // metres + + return R * c; +} + +std::string GetFeatureName(const Point& point, + const std::vector& feature_list) { + for (const Feature& f : feature_list) { + if (f.location().latitude() == point.latitude() && + f.location().longitude() == point.longitude()) { + return f.name(); + } + } + return ""; +} + +class RouteGuideImpl final : public RouteGuide::Service { + public: + explicit RouteGuideImpl(const std::string& db) { + routeguide::ParseDb(db, &feature_list_); + } + + Status GetFeature(ServerContext* context, const Point* point, + Feature* feature) override { + feature->set_name(GetFeatureName(*point, feature_list_)); + feature->mutable_location()->CopyFrom(*point); + return Status::OK; + } + + Status ListFeatures(ServerContext* context, + const routeguide::Rectangle* rectangle, + ServerWriter* writer) override { + auto lo = rectangle->lo(); + auto hi = rectangle->hi(); + long left = (std::min)(lo.longitude(), hi.longitude()); + long right = (std::max)(lo.longitude(), hi.longitude()); + long top = (std::max)(lo.latitude(), hi.latitude()); + long bottom = (std::min)(lo.latitude(), hi.latitude()); + for (const Feature& f : feature_list_) { + if (f.location().longitude() >= left && + f.location().longitude() <= right && + f.location().latitude() >= bottom && f.location().latitude() <= top) { + writer->Write(f); + } + } + return Status::OK; + } + + Status RecordRoute(ServerContext* context, ServerReader* reader, + RouteSummary* summary) override { + Point point; + int point_count = 0; + int feature_count = 0; + float distance = 0.0; + Point previous; + + system_clock::time_point start_time = 
system_clock::now(); + while (reader->Read(&point)) { + point_count++; + if (!GetFeatureName(point, feature_list_).empty()) { + feature_count++; + } + if (point_count != 1) { + distance += GetDistance(previous, point); + } + previous = point; + } + system_clock::time_point end_time = system_clock::now(); + summary->set_point_count(point_count); + summary->set_feature_count(feature_count); + summary->set_distance(static_cast(distance)); + auto secs = + std::chrono::duration_cast(end_time - start_time); + summary->set_elapsed_time(secs.count()); + + return Status::OK; + } + + Status RouteChat(ServerContext* context, + ServerReaderWriter* stream) override { + RouteNote note; + while (stream->Read(¬e)) { + std::unique_lock lock(mu_); + for (const RouteNote& n : received_notes_) { + if (n.location().latitude() == note.location().latitude() && + n.location().longitude() == note.location().longitude()) { + stream->Write(n); + } + } + received_notes_.push_back(note); + } + + return Status::OK; + } + + private: + std::vector feature_list_; + std::mutex mu_; + std::vector received_notes_; +}; + +void RunServer(const std::string& db_path) { + std::string server_address("0.0.0.0:50051"); + RouteGuideImpl service(db_path); + + ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + server->Wait(); +} + +int main(int argc, char** argv) { + // Expect only arg: --db_path=path/to/route_guide_db.json. + std::string db = routeguide::GetDbFileContent(argc, argv); + RunServer(db); + + return 0; +} diff --git a/src/android/test/interop/app/src/main/cpp/grpc-interop.cc b/src/android/test/interop/app/src/main/cpp/grpc-interop.cc new file mode 100644 index 00000000..40009dd0 --- /dev/null +++ b/src/android/test/interop/app/src/main/cpp/grpc-interop.cc @@ -0,0 +1,128 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/cpp/interop/interop_client.h" + +std::shared_ptr GetClient(const char* host, + int port, + bool use_tls) { + const int host_port_buf_size = 1024; + char host_port[host_port_buf_size]; + snprintf(host_port, host_port_buf_size, "%s:%d", host, port); + + std::shared_ptr credentials; + if (use_tls) { + credentials = grpc::SslCredentials(grpc::SslCredentialsOptions()); + } else { + credentials = grpc::InsecureChannelCredentials(); + } + + grpc::testing::ChannelCreationFunc channel_creation_func = + std::bind(grpc::CreateChannel, host_port, credentials); + return std::shared_ptr( + new grpc::testing::InteropClient(channel_creation_func, true, false)); +} + +extern "C" JNIEXPORT jboolean JNICALL +Java_io_grpc_interop_cpp_InteropActivity_doEmpty(JNIEnv* env, jobject obj_this, + jstring host_raw, + jint port_raw, + jboolean use_tls_raw) { + const char* host = env->GetStringUTFChars(host_raw, (jboolean*)0); + int port = static_cast(port_raw); + bool use_tls = static_cast(use_tls_raw); + + jboolean result = GetClient(host, port, use_tls)->DoEmpty(); + env->ReleaseStringUTFChars(host_raw, host); + return result; +} + +extern "C" JNIEXPORT jboolean JNICALL +Java_io_grpc_interop_cpp_InteropActivity_doLargeUnary(JNIEnv* env, + jobject obj_this, + jstring host_raw, + jint port_raw, + jboolean use_tls_raw) { + const char* host = env->GetStringUTFChars(host_raw, (jboolean*)0); + int port = static_cast(port_raw); + bool use_tls = static_cast(use_tls_raw); + + jboolean result = GetClient(host, port, use_tls)->DoLargeUnary(); + env->ReleaseStringUTFChars(host_raw, host); + return result; +} + +extern "C" JNIEXPORT jboolean JNICALL +Java_io_grpc_interop_cpp_InteropActivity_doEmptyStream(JNIEnv* env, + jobject obj_this, + jstring host_raw, + jint port_raw, + jboolean use_tls_raw) { + const char* host = env->GetStringUTFChars(host_raw, (jboolean*)0); + int port = static_cast(port_raw); + bool use_tls = static_cast(use_tls_raw); + + jboolean result = GetClient(host, port, use_tls)->DoEmptyStream(); + env->ReleaseStringUTFChars(host_raw, host); + return result; +} + +extern "C" JNIEXPORT jboolean JNICALL +Java_io_grpc_interop_cpp_InteropActivity_doRequestStreaming( + JNIEnv* env, jobject obj_this, jstring host_raw, jint port_raw, + jboolean use_tls_raw) { + const char* host = env->GetStringUTFChars(host_raw, (jboolean*)0); + int port = static_cast(port_raw); + bool use_tls = static_cast(use_tls_raw); + + jboolean result = GetClient(host, port, use_tls)->DoRequestStreaming(); + env->ReleaseStringUTFChars(host_raw, host); + return result; +} + +extern "C" JNIEXPORT jboolean JNICALL +Java_io_grpc_interop_cpp_InteropActivity_doResponseStreaming( + JNIEnv* env, jobject obj_this, jstring host_raw, jint port_raw, + jboolean use_tls_raw) { + const char* host = env->GetStringUTFChars(host_raw, (jboolean*)0); + int port = static_cast(port_raw); + bool use_tls = static_cast(use_tls_raw); + + jboolean result = GetClient(host, port, use_tls)->DoResponseStreaming(); + env->ReleaseStringUTFChars(host_raw, host); + return result; +} + +extern "C" JNIEXPORT jboolean JNICALL +Java_io_grpc_interop_cpp_InteropActivity_doPingPong(JNIEnv* env, + jobject obj_this, + jstring host_raw, + jint port_raw, + jboolean use_tls_raw) { + const char* host = env->GetStringUTFChars(host_raw, (jboolean*)0); + int port = static_cast(port_raw); + bool use_tls = static_cast(use_tls_raw); + + jboolean result = GetClient(host, port, 
use_tls)->DoPingPong(); + env->ReleaseStringUTFChars(host_raw, host); + return result; +} diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc new file mode 100644 index 00000000..f5283280 --- /dev/null +++ b/src/compiler/cpp_generator.cc @@ -0,0 +1,2369 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/compiler/cpp_generator.h" + +#include +#include + +namespace grpc_cpp_generator { +namespace { + +template +std::string as_string(T x) { + std::ostringstream out; + out << x; + return out.str(); +} + +inline bool ClientOnlyStreaming(const grpc_generator::Method* method) { + return method->ClientStreaming() && !method->ServerStreaming(); +} + +inline bool ServerOnlyStreaming(const grpc_generator::Method* method) { + return !method->ClientStreaming() && method->ServerStreaming(); +} + +std::string FilenameIdentifier(const std::string& filename) { + std::string result; + for (unsigned i = 0; i < filename.size(); i++) { + char c = filename[i]; + if (isalnum(c)) { + result.push_back(c); + } else { + static char hex[] = "0123456789abcdef"; + result.push_back('_'); + result.push_back(hex[(c >> 4) & 0xf]); + result.push_back(hex[c & 0xf]); + } + } + return result; +} +} // namespace + +template +T* array_end(T (&array)[N]) { + return array + N; +} + +void PrintIncludes(grpc_generator::Printer* printer, + const std::vector& headers, + bool use_system_headers, const std::string& search_path) { + std::map vars; + + vars["l"] = use_system_headers ? '<' : '"'; + vars["r"] = use_system_headers ? '>' : '"'; + + if (!search_path.empty()) { + vars["l"] += search_path; + if (search_path[search_path.size() - 1] != '/') { + vars["l"] += '/'; + } + } + + for (auto i = headers.begin(); i != headers.end(); i++) { + vars["h"] = *i; + printer->Print(vars, "#include $l$$h$$r$\n"); + } +} + +std::string GetHeaderPrologue(grpc_generator::File* file, + const Parameters& params) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto printer = file->CreatePrinter(&output); + std::map vars; + + vars["filename"] = file->filename(); + vars["filename_identifier"] = FilenameIdentifier(file->filename()); + vars["filename_base"] = file->filename_without_ext(); + vars["message_header_ext"] = params.message_header_extension.empty() + ? 
kCppGeneratorMessageHeaderExt + : params.message_header_extension; + + printer->Print(vars, "// Generated by the gRPC C++ plugin.\n"); + printer->Print(vars, + "// If you make any local change, they will be lost.\n"); + printer->Print(vars, "// source: $filename$\n"); + std::string leading_comments = file->GetLeadingComments("//"); + if (!leading_comments.empty()) { + printer->Print(vars, "// Original file comments:\n"); + printer->PrintRaw(leading_comments.c_str()); + } + printer->Print(vars, "#ifndef GRPC_$filename_identifier$__INCLUDED\n"); + printer->Print(vars, "#define GRPC_$filename_identifier$__INCLUDED\n"); + printer->Print(vars, "\n"); + printer->Print(vars, "#include \"$filename_base$$message_header_ext$\"\n"); + printer->Print(vars, file->additional_headers().c_str()); + printer->Print(vars, "\n"); + } + return output; +} + +// Convert from "a/b/c.proto" to "#include \"a/b/c$message_header_ext$\"\n" +std::string ImportInludeFromProtoName(const std::string& proto_name) { + return std::string("#include \"") + + proto_name.substr(0, proto_name.size() - 6) + + std::string("$message_header_ext$\"\n"); +} + +std::string GetHeaderIncludes(grpc_generator::File* file, + const Parameters& params) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto printer = file->CreatePrinter(&output); + std::map vars; + + if (!params.additional_header_includes.empty()) { + PrintIncludes(printer.get(), params.additional_header_includes, false, + ""); + } + static const char* headers_strs[] = { + "functional", + "grpcpp/impl/codegen/async_generic_service.h", + "grpcpp/impl/codegen/async_stream.h", + "grpcpp/impl/codegen/async_unary_call.h", + "grpcpp/impl/codegen/client_callback.h", + "grpcpp/impl/codegen/client_context.h", + "grpcpp/impl/codegen/completion_queue.h", + "grpcpp/impl/codegen/message_allocator.h", + "grpcpp/impl/codegen/method_handler.h", + "grpcpp/impl/codegen/proto_utils.h", + "grpcpp/impl/codegen/rpc_method.h", + "grpcpp/impl/codegen/server_callback.h", + "grpcpp/impl/codegen/server_callback_handlers.h", + "grpcpp/impl/codegen/server_context.h", + "grpcpp/impl/codegen/service_type.h", + "grpcpp/impl/codegen/status.h", + "grpcpp/impl/codegen/stub_options.h", + "grpcpp/impl/codegen/sync_stream.h", + }; + std::vector headers(headers_strs, array_end(headers_strs)); + PrintIncludes(printer.get(), headers, params.use_system_headers, + params.grpc_search_path); + printer->Print(vars, "\n"); + + vars["message_header_ext"] = params.message_header_extension.empty() + ? 
kCppGeneratorMessageHeaderExt + : params.message_header_extension; + + if (params.include_import_headers) { + const std::vector import_names = file->GetImportNames(); + for (const auto& import_name : import_names) { + const std::string include_name = ImportInludeFromProtoName(import_name); + printer->Print(vars, include_name.c_str()); + } + printer->PrintRaw("\n"); + } + + if (!file->package().empty()) { + std::vector parts = file->package_parts(); + + for (auto part = parts.begin(); part != parts.end(); part++) { + vars["part"] = *part; + printer->Print(vars, "namespace $part$ {\n"); + } + printer->Print(vars, "\n"); + } + } + return output; +} + +void PrintHeaderClientMethodInterfaces(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars, + bool is_public) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + + struct { + std::string prefix; + std::string method_params; // extra arguments to method + std::string raw_args; // extra arguments to raw version of method + } async_prefixes[] = {{"Async", ", void* tag", ", tag"}, + {"PrepareAsync", "", ""}}; + + if (is_public) { + if (method->NoStreaming()) { + printer->Print( + *vars, + "virtual ::grpc::Status $Method$(::grpc::ClientContext* context, " + "const $Request$& request, $Response$* response) = 0;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + printer->Print( + *vars, + "std::unique_ptr< " + "::grpc::ClientAsyncResponseReaderInterface< $Response$>> " + "$AsyncPrefix$$Method$(::grpc::ClientContext* context, " + "const $Request$& request, " + "::grpc::CompletionQueue* cq) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< " + "::grpc::ClientAsyncResponseReaderInterface< $Response$>>(" + "$AsyncPrefix$$Method$Raw(context, request, cq));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientWriterInterface< $Request$>>" + " $Method$(" + "::grpc::ClientContext* context, $Response$* response) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< ::grpc::ClientWriterInterface< $Request$>>" + "($Method$Raw(context, response));\n"); + printer->Outdent(); + printer->Print("}\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientAsyncWriterInterface< $Request$>>" + " $AsyncPrefix$$Method$(::grpc::ClientContext* context, " + "$Response$* " + "response, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Indent(); + printer->Print(*vars, + "return std::unique_ptr< " + "::grpc::ClientAsyncWriterInterface< $Request$>>(" + "$AsyncPrefix$$Method$Raw(context, response, " + "cq$AsyncRawArgs$));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientReaderInterface< $Response$>>" + " $Method$(::grpc::ClientContext* context, const $Request$& request)" + " {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< ::grpc::ClientReaderInterface< $Response$>>" + "($Method$Raw(context, request));\n"); + printer->Outdent(); + printer->Print("}\n"); + for (auto& 
async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientAsyncReaderInterface< $Response$>> " + "$AsyncPrefix$$Method$(" + "::grpc::ClientContext* context, const $Request$& request, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< " + "::grpc::ClientAsyncReaderInterface< $Response$>>(" + "$AsyncPrefix$$Method$Raw(context, request, cq$AsyncRawArgs$));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } else if (method->BidiStreaming()) { + printer->Print(*vars, + "std::unique_ptr< ::grpc::ClientReaderWriterInterface< " + "$Request$, $Response$>> " + "$Method$(::grpc::ClientContext* context) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< " + "::grpc::ClientReaderWriterInterface< $Request$, $Response$>>(" + "$Method$Raw(context));\n"); + printer->Outdent(); + printer->Print("}\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print( + *vars, + "std::unique_ptr< " + "::grpc::ClientAsyncReaderWriterInterface< $Request$, $Response$>> " + "$AsyncPrefix$$Method$(::grpc::ClientContext* context, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< " + "::grpc::ClientAsyncReaderWriterInterface< $Request$, $Response$>>(" + "$AsyncPrefix$$Method$Raw(context, cq$AsyncRawArgs$));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } + } else { + if (method->NoStreaming()) { + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + printer->Print( + *vars, + "virtual ::grpc::ClientAsyncResponseReaderInterface< $Response$>* " + "$AsyncPrefix$$Method$Raw(::grpc::ClientContext* context, " + "const $Request$& request, " + "::grpc::CompletionQueue* cq) = 0;\n"); + } + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "virtual ::grpc::ClientWriterInterface< $Request$>*" + " $Method$Raw(" + "::grpc::ClientContext* context, $Response$* response) = 0;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + printer->Print( + *vars, + "virtual ::grpc::ClientAsyncWriterInterface< $Request$>*" + " $AsyncPrefix$$Method$Raw(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) = 0;\n"); + } + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "virtual ::grpc::ClientReaderInterface< $Response$>* " + "$Method$Raw(" + "::grpc::ClientContext* context, const $Request$& request) = 0;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + printer->Print( + *vars, + "virtual ::grpc::ClientAsyncReaderInterface< $Response$>* " + "$AsyncPrefix$$Method$Raw(" + "::grpc::ClientContext* context, const $Request$& request, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) = 0;\n"); + } + } else if (method->BidiStreaming()) { + printer->Print(*vars, + "virtual ::grpc::ClientReaderWriterInterface< $Request$, " + 
"$Response$>* " + "$Method$Raw(::grpc::ClientContext* context) = 0;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + printer->Print( + *vars, + "virtual ::grpc::ClientAsyncReaderWriterInterface< " + "$Request$, $Response$>* " + "$AsyncPrefix$$Method$Raw(::grpc::ClientContext* context, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) = 0;\n"); + } + } + } +} + +void PrintHeaderClientMethod(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars, + bool is_public) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + struct { + std::string prefix; + std::string method_params; // extra arguments to method + std::string raw_args; // extra arguments to raw version of method + } async_prefixes[] = {{"Async", ", void* tag", ", tag"}, + {"PrepareAsync", "", ""}}; + + if (is_public) { + if (method->NoStreaming()) { + printer->Print( + *vars, + "::grpc::Status $Method$(::grpc::ClientContext* context, " + "const $Request$& request, $Response$* response) override;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientAsyncResponseReader< $Response$>> " + "$AsyncPrefix$$Method$(::grpc::ClientContext* context, " + "const $Request$& request, " + "::grpc::CompletionQueue* cq) {\n"); + printer->Indent(); + printer->Print(*vars, + "return std::unique_ptr< " + "::grpc::ClientAsyncResponseReader< $Response$>>(" + "$AsyncPrefix$$Method$Raw(context, request, cq));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientWriter< $Request$>>" + " $Method$(" + "::grpc::ClientContext* context, $Response$* response) {\n"); + printer->Indent(); + printer->Print(*vars, + "return std::unique_ptr< ::grpc::ClientWriter< $Request$>>" + "($Method$Raw(context, response));\n"); + printer->Outdent(); + printer->Print("}\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print(*vars, + "std::unique_ptr< ::grpc::ClientAsyncWriter< $Request$>>" + " $AsyncPrefix$$Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< ::grpc::ClientAsyncWriter< $Request$>>(" + "$AsyncPrefix$$Method$Raw(context, response, " + "cq$AsyncRawArgs$));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientReader< $Response$>>" + " $Method$(::grpc::ClientContext* context, const $Request$& request)" + " {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< ::grpc::ClientReader< $Response$>>" + "($Method$Raw(context, request));\n"); + printer->Outdent(); + printer->Print("}\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientAsyncReader< $Response$>> " + 
"$AsyncPrefix$$Method$(" + "::grpc::ClientContext* context, const $Request$& request, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< ::grpc::ClientAsyncReader< $Response$>>(" + "$AsyncPrefix$$Method$Raw(context, request, cq$AsyncRawArgs$));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } else if (method->BidiStreaming()) { + printer->Print( + *vars, + "std::unique_ptr< ::grpc::ClientReaderWriter< $Request$, $Response$>>" + " $Method$(::grpc::ClientContext* context) {\n"); + printer->Indent(); + printer->Print(*vars, + "return std::unique_ptr< " + "::grpc::ClientReaderWriter< $Request$, $Response$>>(" + "$Method$Raw(context));\n"); + printer->Outdent(); + printer->Print("}\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print(*vars, + "std::unique_ptr< ::grpc::ClientAsyncReaderWriter< " + "$Request$, $Response$>> " + "$AsyncPrefix$$Method$(::grpc::ClientContext* context, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Indent(); + printer->Print( + *vars, + "return std::unique_ptr< " + "::grpc::ClientAsyncReaderWriter< $Request$, $Response$>>(" + "$AsyncPrefix$$Method$Raw(context, cq$AsyncRawArgs$));\n"); + printer->Outdent(); + printer->Print("}\n"); + } + } + } else { + if (method->NoStreaming()) { + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + printer->Print( + *vars, + "::grpc::ClientAsyncResponseReader< $Response$>* " + "$AsyncPrefix$$Method$Raw(::grpc::ClientContext* context, " + "const $Request$& request, " + "::grpc::CompletionQueue* cq) override;\n"); + } + } else if (ClientOnlyStreaming(method)) { + printer->Print(*vars, + "::grpc::ClientWriter< $Request$>* $Method$Raw(" + "::grpc::ClientContext* context, $Response$* response) " + "override;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print( + *vars, + "::grpc::ClientAsyncWriter< $Request$>* $AsyncPrefix$$Method$Raw(" + "::grpc::ClientContext* context, $Response$* response, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) override;\n"); + } + } else if (ServerOnlyStreaming(method)) { + printer->Print(*vars, + "::grpc::ClientReader< $Response$>* $Method$Raw(" + "::grpc::ClientContext* context, const $Request$& request)" + " override;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print( + *vars, + "::grpc::ClientAsyncReader< $Response$>* $AsyncPrefix$$Method$Raw(" + "::grpc::ClientContext* context, const $Request$& request, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) override;\n"); + } + } else if (method->BidiStreaming()) { + printer->Print(*vars, + "::grpc::ClientReaderWriter< $Request$, $Response$>* " + "$Method$Raw(::grpc::ClientContext* context) override;\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncRawArgs"] = async_prefix.raw_args; + printer->Print( + *vars, + "::grpc::ClientAsyncReaderWriter< $Request$, 
$Response$>* " + "$AsyncPrefix$$Method$Raw(::grpc::ClientContext* context, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) override;\n"); + } + } + } +} + +void PrintHeaderClientMethodCallbackInterfacesStart( + grpc_generator::Printer* printer, + std::map* /*vars*/) { + // This declares the interface for the callback-based API. The components + // are pure; even though this is new (post-1.0) API, it can be pure because + // it is an entirely new interface that happens to be scoped within + // StubInterface, not new additions to StubInterface itself + printer->Print("class async_interface {\n"); + // All methods in this new interface are public. There is no need for private + // "Raw" methods since the callback-based API returns unowned raw pointers + printer->Print(" public:\n"); + printer->Indent(); + printer->Print("virtual ~async_interface() {}\n"); +} + +void PrintHeaderClientMethodCallbackInterfaces( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + + if (method->NoStreaming()) { + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "const $Request$* request, $Response$* response, " + "std::function) = 0;\n"); + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "const $Request$* request, $Response$* response, " + "::grpc::ClientUnaryReactor* reactor) = 0;\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::ClientWriteReactor< $Request$>* " + "reactor) = 0;\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "const $Request$* request, " + "::grpc::ClientReadReactor< $Response$>* " + "reactor) = 0;\n"); + } else if (method->BidiStreaming()) { + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "::grpc::ClientBidiReactor< " + "$Request$,$Response$>* reactor) = 0;\n"); + } +} + +void PrintHeaderClientMethodCallbackInterfacesEnd( + grpc_generator::Printer* printer, + std::map* /*vars*/) { + printer->Outdent(); + printer->Print("};\n"); + // TODO: Remove typedef when all uses of experimental_async are migrated off. + printer->Print( + "typedef class async_interface experimental_async_interface;\n"); + + // Declare a function to give the async stub contents. It can't be pure + // since this is a new API in StubInterface, but it is meaningless by default + // (since any stub that wants to use it must have its own implementation of + // the callback functions therein), so make the default return value nullptr. + // Intentionally include the word "class" to avoid possible shadowing. + // TODO: Remove experimental_async call when possible, replace with nullptr. + printer->Print( + "virtual class async_interface* async() { return nullptr; }\n"); + + // TODO: Remove experimental_async call when possible. + printer->Print( + "class async_interface* experimental_async() { return async(); }\n"); +} + +void PrintHeaderClientMethodCallbackStart( + grpc_generator::Printer* printer, + std::map* /*vars*/) { + // This declares the stub entry for the callback-based API. 
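For context, the async_interface members declared above and the "class async" stub entry printed below are what application code reaches through stub->async(). A minimal caller-side sketch for a unary RPC, assuming the helloworld.proto Greeter service used elsewhere in this change (the blocking wait is an illustrative convenience, not generator output):

#include <condition_variable>
#include <memory>
#include <mutex>
#include <utility>

#include <grpcpp/grpcpp.h>

#include "examples/protos/helloworld.grpc.pb.h"

// Issues one unary RPC through the generated callback API and blocks until
// the completion lambda runs.
grpc::Status CallbackSayHello(const std::shared_ptr<grpc::Channel>& channel) {
  auto stub = helloworld::Greeter::NewStub(channel);

  grpc::ClientContext context;
  helloworld::HelloRequest request;
  helloworld::HelloReply reply;
  request.set_name("world");

  std::mutex mu;
  std::condition_variable cv;
  bool done = false;
  grpc::Status status;

  // async() is the accessor these printers declare; the lambda matches the
  // std::function<void(::grpc::Status)> overload emitted for unary methods.
  stub->async()->SayHello(&context, &request, &reply, [&](grpc::Status s) {
    std::lock_guard<std::mutex> lock(mu);
    status = std::move(s);
    done = true;
    cv.notify_one();
  });

  std::unique_lock<std::mutex> lock(mu);
  cv.wait(lock, [&] { return done; });
  return status;
}

The reactor overloads declared next to the std::function one serve the same purpose for callers that prefer to drive the call through a ClientUnaryReactor instead of a completion lambda.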
+  printer->Print("class async final :\n");
+  printer->Print(" public StubInterface::async_interface {\n");
+  printer->Print(" public:\n");
+  printer->Indent();
+}
+
+void PrintHeaderClientMethodCallback(grpc_generator::Printer* printer,
+                                     const grpc_generator::Method* method,
+                                     std::map<std::string, std::string>* vars) {
+  (*vars)["Method"] = method->name();
+  (*vars)["Request"] = method->input_type_name();
+  (*vars)["Response"] = method->output_type_name();
+
+  if (method->NoStreaming()) {
+    printer->Print(*vars,
+                   "void $Method$(::grpc::ClientContext* context, "
+                   "const $Request$* request, $Response$* response, "
+                   "std::function<void(::grpc::Status)>) override;\n");
+    printer->Print(*vars,
+                   "void $Method$(::grpc::ClientContext* context, "
+                   "const $Request$* request, $Response$* response, "
+                   "::grpc::ClientUnaryReactor* reactor) override;\n");
+  } else if (ClientOnlyStreaming(method)) {
+    printer->Print(*vars,
+                   "void $Method$(::grpc::ClientContext* context, "
+                   "$Response$* response, "
+                   "::grpc::ClientWriteReactor< $Request$>* "
+                   "reactor) override;\n");
+  } else if (ServerOnlyStreaming(method)) {
+    printer->Print(*vars,
+                   "void $Method$(::grpc::ClientContext* context, "
+                   "const $Request$* request, "
+                   "::grpc::ClientReadReactor< $Response$>* "
+                   "reactor) override;\n");
+
+  } else if (method->BidiStreaming()) {
+    printer->Print(*vars,
+                   "void $Method$(::grpc::ClientContext* context, "
+                   "::grpc::ClientBidiReactor< "
+                   "$Request$,$Response$>* reactor) override;\n");
+  }
+}
+
+void PrintHeaderClientMethodCallbackEnd(
+    grpc_generator::Printer* printer,
+    std::map<std::string, std::string>* /*vars*/) {
+  printer->Outdent();
+  printer->Print(" private:\n");
+  printer->Indent();
+  printer->Print("friend class Stub;\n");
+  printer->Print("explicit async(Stub* stub): stub_(stub) { }\n");
+  // include a function with a phony use of stub_ to avoid an unused
+  // private member warning for service with no methods
+  printer->Print("Stub* stub() { return stub_; }\n");
+  printer->Print("Stub* stub_;\n");
+  printer->Outdent();
+  printer->Print("};\n");
+
+  printer->Print(
+      "class async* async() override { "
+      "return &async_stub_; }\n");
+}
+
+void PrintHeaderClientMethodData(grpc_generator::Printer* printer,
+                                 const grpc_generator::Method* method,
+                                 std::map<std::string, std::string>* vars) {
+  (*vars)["Method"] = method->name();
+  printer->Print(*vars,
+                 "const ::grpc::internal::RpcMethod rpcmethod_$Method$_;\n");
+}
+
+void PrintHeaderServerMethodSync(grpc_generator::Printer* printer,
+                                 const grpc_generator::Method* method,
+                                 std::map<std::string, std::string>* vars) {
+  (*vars)["Method"] = method->name();
+  (*vars)["Request"] = method->input_type_name();
+  (*vars)["Response"] = method->output_type_name();
+  printer->Print(method->GetLeadingComments("//").c_str());
+  if (method->NoStreaming()) {
+    printer->Print(*vars,
+                   "virtual ::grpc::Status $Method$("
+                   "::grpc::ServerContext* context, const $Request$* request, "
+                   "$Response$* response);\n");
+  } else if (ClientOnlyStreaming(method)) {
+    printer->Print(*vars,
+                   "virtual ::grpc::Status $Method$("
+                   "::grpc::ServerContext* context, "
+                   "::grpc::ServerReader< $Request$>* reader, "
+                   "$Response$* response);\n");
+  } else if (ServerOnlyStreaming(method)) {
+    printer->Print(*vars,
+                   "virtual ::grpc::Status $Method$("
+                   "::grpc::ServerContext* context, const $Request$* request, "
+                   "::grpc::ServerWriter< $Response$>* writer);\n");
+  } else if (method->BidiStreaming()) {
+    printer->Print(
+        *vars,
+        "virtual ::grpc::Status $Method$("
+        "::grpc::ServerContext* context, "
+        "::grpc::ServerReaderWriter< $Response$, $Request$>* stream);"
+        "\n");
+  }
+
printer->Print(method->GetTrailingComments("//").c_str()); +} + +// Helper generator. Disables the sync API for Request and Response, then adds +// in an async API for RealRequest and RealResponse types. This is to be used +// to generate async and raw async APIs. +void PrintHeaderServerAsyncMethodsHelper( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + if (method->NoStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "$Response$* /*response*/) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print( + *vars, + "void Request$Method$(" + "::grpc::ServerContext* context, $RealRequest$* request, " + "::grpc::ServerAsyncResponseWriter< $RealResponse$>* response, " + "::grpc::CompletionQueue* new_call_cq, " + "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n"); + printer->Print(*vars, + " ::grpc::Service::RequestAsyncUnary($Idx$, context, " + "request, response, new_call_cq, notification_cq, tag);\n"); + printer->Print("}\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, " + "::grpc::ServerReader< $Request$>* /*reader*/, " + "$Response$* /*response*/) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print( + *vars, + "void Request$Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerAsyncReader< $RealResponse$, $RealRequest$>* reader, " + "::grpc::CompletionQueue* new_call_cq, " + "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n"); + printer->Print(*vars, + " ::grpc::Service::RequestAsyncClientStreaming($Idx$, " + "context, reader, new_call_cq, notification_cq, tag);\n"); + printer->Print("}\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "::grpc::ServerWriter< $Response$>* /*writer*/) override " + "{\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print( + *vars, + "void Request$Method$(" + "::grpc::ServerContext* context, $RealRequest$* request, " + "::grpc::ServerAsyncWriter< $RealResponse$>* writer, " + "::grpc::CompletionQueue* new_call_cq, " + "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n"); + printer->Print( + *vars, + " ::grpc::Service::RequestAsyncServerStreaming($Idx$, " + "context, request, writer, new_call_cq, notification_cq, tag);\n"); + printer->Print("}\n"); + } else if (method->BidiStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, " + "::grpc::ServerReaderWriter< $Response$, $Request$>* /*stream*/) " + " override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print( + *vars, + "void Request$Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerAsyncReaderWriter< $RealResponse$, $RealRequest$>* " + "stream, " + "::grpc::CompletionQueue* new_call_cq, " + "::grpc::ServerCompletionQueue* notification_cq, void *tag) {\n"); + 
printer->Print(*vars, + " ::grpc::Service::RequestAsyncBidiStreaming($Idx$, " + "context, stream, new_call_cq, notification_cq, tag);\n"); + printer->Print("}\n"); + } +} + +void PrintHeaderServerMethodAsync(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + // These will be disabled + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + // These will be used for the async API + (*vars)["RealRequest"] = method->input_type_name(); + (*vars)["RealResponse"] = method->output_type_name(); + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithAsyncMethod_$Method$ : public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service* /*service*/) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, + "WithAsyncMethod_$Method$() {\n" + " ::grpc::Service::MarkMethodAsync($Idx$);\n" + "}\n"); + printer->Print(*vars, + "~WithAsyncMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + PrintHeaderServerAsyncMethodsHelper(printer, method, vars); + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + +// Helper generator. Disables the sync API for Request and Response, then adds +// in a callback API for RealRequest and RealResponse types. This is to be used +// to generate callback and raw callback APIs. +void PrintHeaderServerCallbackMethodsHelper( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + if (method->NoStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "$Response$* /*response*/) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print(*vars, + "virtual ::grpc::ServerUnaryReactor* $Method$(\n" + " ::grpc::CallbackServerContext* /*context*/, " + "const $RealRequest$* /*request*/, " + "$RealResponse$* /*response*/)" + " { return nullptr; }\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, " + "::grpc::ServerReader< $Request$>* /*reader*/, " + "$Response$* /*response*/) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print(*vars, + "virtual ::grpc::ServerReadReactor< " + "$RealRequest$>* $Method$(\n" + " ::grpc::CallbackServerContext* " + "/*context*/, $RealResponse$* /*response*/)" + " { return nullptr; }\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "::grpc::ServerWriter< $Response$>* /*writer*/) override " + "{\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print( + *vars, + "virtual ::grpc::ServerWriteReactor< $RealResponse$>* $Method$(\n" + " ::grpc::CallbackServerContext* " + "/*context*/, const $RealRequest$* /*request*/)" + " { return nullptr; }\n"); + } else if (method->BidiStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + 
"::grpc::ServerContext* /*context*/, " + "::grpc::ServerReaderWriter< $Response$, $Request$>* /*stream*/) " + " override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print( + *vars, + "virtual ::grpc::ServerBidiReactor< $RealRequest$, $RealResponse$>* " + "$Method$(\n" + " ::grpc::CallbackServerContext* /*context*/)\n" + " { return nullptr; }\n"); + } +} + +void PrintHeaderServerMethodCallback(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + // These will be disabled + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + // These will be used for the callback API + (*vars)["RealRequest"] = method->input_type_name(); + (*vars)["RealResponse"] = method->output_type_name(); + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithCallbackMethod_$Method$ : public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service* /*service*/) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, "WithCallbackMethod_$Method$() {\n"); + if (method->NoStreaming()) { + printer->Print( + *vars, + " ::grpc::Service::MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackUnaryHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context, " + "const $RealRequest$* " + "request, " + "$RealResponse$* response) { " + "return this->$Method$(context, request, response); }));}\n"); + printer->Print(*vars, + "void SetMessageAllocatorFor_$Method$(\n" + " ::grpc::MessageAllocator< " + "$RealRequest$, $RealResponse$>* allocator) {\n" + " ::grpc::internal::MethodHandler* const handler = " + "::grpc::Service::GetHandler($Idx$);\n" + " static_cast<::grpc::internal::CallbackUnaryHandler< " + "$RealRequest$, $RealResponse$>*>(handler)\n" + " ->SetMessageAllocator(allocator);\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + " ::grpc::Service::MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackClientStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context, " + "$RealResponse$* " + "response) { " + "return this->$Method$(context, response); }));\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + " ::grpc::Service::MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackServerStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context, " + "const $RealRequest$* " + "request) { " + "return this->$Method$(context, request); }));\n"); + } else if (method->BidiStreaming()) { + printer->Print(*vars, + " ::grpc::Service::MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackBidiHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context) " + "{ return this->$Method$(context); }));\n"); + } + printer->Print(*vars, "}\n"); + printer->Print(*vars, + "~WithCallbackMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + PrintHeaderServerCallbackMethodsHelper(printer, method, vars); + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + +void PrintHeaderServerMethodRawCallback( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = 
method->name(); + // These will be disabled + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + // These will be used for raw API + (*vars)["RealRequest"] = "::grpc::ByteBuffer"; + (*vars)["RealResponse"] = "::grpc::ByteBuffer"; + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithRawCallbackMethod_$Method$ : public " + "BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service* /*service*/) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, "WithRawCallbackMethod_$Method$() {\n"); + if (method->NoStreaming()) { + printer->Print(*vars, + " ::grpc::Service::MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackUnaryHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context, " + "const $RealRequest$* " + "request, " + "$RealResponse$* response) { return " + "this->$Method$(context, request, response); }));\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + " ::grpc::Service::MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackClientStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context, " + "$RealResponse$* response) " + "{ return this->$Method$(context, response); }));\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + " ::grpc::Service::MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackServerStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context, " + "const" + "$RealRequest$* request) { return " + "this->$Method$(context, request); }));\n"); + } else if (method->BidiStreaming()) { + printer->Print(*vars, + " ::grpc::Service::MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackBidiHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this](\n" + " ::grpc::CallbackServerContext* context) " + "{ return this->$Method$(context); }));\n"); + } + printer->Print(*vars, "}\n"); + printer->Print(*vars, + "~WithRawCallbackMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + PrintHeaderServerCallbackMethodsHelper(printer, method, vars); + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + +void PrintHeaderServerMethodStreamedUnary( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + if (method->NoStreaming()) { + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithStreamedUnaryMethod_$Method$ : " + "public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service* /*service*/) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, + "WithStreamedUnaryMethod_$Method$() {\n" + " ::grpc::Service::MarkMethodStreamed($Idx$,\n" + " new ::grpc::internal::StreamedUnaryHandler<\n" + " $Request$, $Response$>(\n" + " [this](::grpc::ServerContext* context,\n" + " ::grpc::ServerUnaryStreamer<\n" + " $Request$, $Response$>* streamer) {\n" + " return this->Streamed$Method$(context,\n" + " streamer);\n" + " }));\n" + "}\n"); + printer->Print(*vars, + "~WithStreamedUnaryMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + 
"}\n"); + printer->Print( + *vars, + "// disable regular version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "$Response$* /*response*/) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print(*vars, + "// replace default version of method with streamed unary\n" + "virtual ::grpc::Status Streamed$Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerUnaryStreamer< " + "$Request$,$Response$>* server_unary_streamer)" + " = 0;\n"); + printer->Outdent(); + printer->Print(*vars, "};\n"); + } +} + +void PrintHeaderServerMethodSplitStreaming( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + if (ServerOnlyStreaming(method)) { + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithSplitStreamingMethod_$Method$ : " + "public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service* /*service*/) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, + "WithSplitStreamingMethod_$Method$() {\n" + " ::grpc::Service::MarkMethodStreamed($Idx$,\n" + " new ::grpc::internal::SplitServerStreamingHandler<\n" + " $Request$, $Response$>(\n" + " [this](::grpc::ServerContext* context,\n" + " ::grpc::ServerSplitStreamer<\n" + " $Request$, $Response$>* streamer) {\n" + " return this->Streamed$Method$(context,\n" + " streamer);\n" + " }));\n" + "}\n"); + printer->Print(*vars, + "~WithSplitStreamingMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + printer->Print( + *vars, + "// disable regular version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "::grpc::ServerWriter< $Response$>* /*writer*/) override " + "{\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print(*vars, + "// replace default version of method with split streamed\n" + "virtual ::grpc::Status Streamed$Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerSplitStreamer< " + "$Request$,$Response$>* server_split_streamer)" + " = 0;\n"); + printer->Outdent(); + printer->Print(*vars, "};\n"); + } +} + +void PrintHeaderServerMethodGeneric(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithGenericMethod_$Method$ : public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service* /*service*/) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, + "WithGenericMethod_$Method$() {\n" + " ::grpc::Service::MarkMethodGeneric($Idx$);\n" + "}\n"); + printer->Print(*vars, + "~WithGenericMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + if (method->NoStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "$Response$* 
/*response*/) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, " + "::grpc::ServerReader< $Request$>* /*reader*/, " + "$Response$* /*response*/) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, const $Request$* /*request*/, " + "::grpc::ServerWriter< $Response$>* /*writer*/) override " + "{\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (method->BidiStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* /*context*/, " + "::grpc::ServerReaderWriter< $Response$, $Request$>* /*stream*/) " + " override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + +void PrintHeaderServerMethodRaw(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + // These will be disabled + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + // These will be used for raw API + (*vars)["RealRequest"] = "::grpc::ByteBuffer"; + (*vars)["RealResponse"] = "::grpc::ByteBuffer"; + printer->Print(*vars, "template \n"); + printer->Print(*vars, "class WithRawMethod_$Method$ : public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service* /*service*/) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, + "WithRawMethod_$Method$() {\n" + " ::grpc::Service::MarkMethodRaw($Idx$);\n" + "}\n"); + printer->Print(*vars, + "~WithRawMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + PrintHeaderServerAsyncMethodsHelper(printer, method, vars); + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + +void PrintHeaderService(grpc_generator::Printer* printer, + const grpc_generator::Service* service, + std::map* vars) { + (*vars)["Service"] = service->name(); + + printer->Print(service->GetLeadingComments("//").c_str()); + printer->Print(*vars, + "class $Service$ final {\n" + " public:\n"); + printer->Indent(); + + // Service metadata + printer->Print(*vars, + "static constexpr char const* service_full_name() {\n" + " return \"$Package$$Service$\";\n" + "}\n"); + + // Client side + printer->Print( + "class StubInterface {\n" + " public:\n"); + printer->Indent(); + printer->Print("virtual ~StubInterface() {}\n"); + for (int i = 0; i < service->method_count(); ++i) { + printer->Print(service->method(i)->GetLeadingComments("//").c_str()); + PrintHeaderClientMethodInterfaces(printer, service->method(i).get(), vars, + true); + printer->Print(service->method(i)->GetTrailingComments("//").c_str()); + } + PrintHeaderClientMethodCallbackInterfacesStart(printer, vars); + for (int i = 0; i < service->method_count(); ++i) { + printer->Print(service->method(i)->GetLeadingComments("//").c_str()); + 
PrintHeaderClientMethodCallbackInterfaces(printer, service->method(i).get(), + vars); + printer->Print(service->method(i)->GetTrailingComments("//").c_str()); + } + PrintHeaderClientMethodCallbackInterfacesEnd(printer, vars); + printer->Outdent(); + printer->Print(" private:\n"); + printer->Indent(); + for (int i = 0; i < service->method_count(); ++i) { + PrintHeaderClientMethodInterfaces(printer, service->method(i).get(), vars, + false); + } + printer->Outdent(); + printer->Print("};\n"); + printer->Print( + "class Stub final : public StubInterface" + " {\n public:\n"); + printer->Indent(); + printer->Print( + "Stub(const std::shared_ptr< ::grpc::ChannelInterface>& " + "channel, const ::grpc::StubOptions& options = " + "::grpc::StubOptions());\n"); + for (int i = 0; i < service->method_count(); ++i) { + PrintHeaderClientMethod(printer, service->method(i).get(), vars, true); + } + PrintHeaderClientMethodCallbackStart(printer, vars); + for (int i = 0; i < service->method_count(); ++i) { + PrintHeaderClientMethodCallback(printer, service->method(i).get(), vars); + } + PrintHeaderClientMethodCallbackEnd(printer, vars); + printer->Outdent(); + printer->Print("\n private:\n"); + printer->Indent(); + printer->Print("std::shared_ptr< ::grpc::ChannelInterface> channel_;\n"); + printer->Print("class async async_stub_{this};\n"); + for (int i = 0; i < service->method_count(); ++i) { + PrintHeaderClientMethod(printer, service->method(i).get(), vars, false); + } + for (int i = 0; i < service->method_count(); ++i) { + PrintHeaderClientMethodData(printer, service->method(i).get(), vars); + } + printer->Outdent(); + printer->Print("};\n"); + printer->Print( + "static std::unique_ptr NewStub(const std::shared_ptr< " + "::grpc::ChannelInterface>& channel, " + "const ::grpc::StubOptions& options = ::grpc::StubOptions());\n"); + + printer->Print("\n"); + + // Server side - base + printer->Print( + "class Service : public ::grpc::Service {\n" + " public:\n"); + printer->Indent(); + printer->Print("Service();\n"); + printer->Print("virtual ~Service();\n"); + for (int i = 0; i < service->method_count(); ++i) { + PrintHeaderServerMethodSync(printer, service->method(i).get(), vars); + } + printer->Outdent(); + printer->Print("};\n"); + + // Server side - Asynchronous + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodAsync(printer, service->method(i).get(), vars); + } + + printer->Print("typedef "); + + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i)->name(); + printer->Print(*vars, "WithAsyncMethod_$method_name$<"); + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + printer->Print(" >"); + } + printer->Print(" AsyncService;\n"); + + // Server side - Callback + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodCallback(printer, service->method(i).get(), vars); + } + + printer->Print("typedef "); + + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i)->name(); + printer->Print(*vars, "WithCallbackMethod_$method_name$<"); + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + printer->Print(" >"); + } + printer->Print(" CallbackService;\n"); + + // TODO: Remove following typedef once all uses of ExperimentalCallbackService + // are migrated to CallbackService + printer->Print("typedef CallbackService 
ExperimentalCallbackService;\n"); + + // Server side - Generic + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodGeneric(printer, service->method(i).get(), vars); + } + + // Server side - Raw + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodRaw(printer, service->method(i).get(), vars); + } + + // Server side - Raw Callback + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodRawCallback(printer, service->method(i).get(), vars); + } + + // Server side - Streamed Unary + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodStreamedUnary(printer, service->method(i).get(), + vars); + } + + printer->Print("typedef "); + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i)->name(); + if (service->method(i)->NoStreaming()) { + printer->Print(*vars, "WithStreamedUnaryMethod_$method_name$<"); + } + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + if (service->method(i)->NoStreaming()) { + printer->Print(" >"); + } + } + printer->Print(" StreamedUnaryService;\n"); + + // Server side - controlled server-side streaming + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodSplitStreaming(printer, service->method(i).get(), + vars); + } + + printer->Print("typedef "); + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i)->name(); + auto method = service->method(i); + if (ServerOnlyStreaming(method.get())) { + printer->Print(*vars, "WithSplitStreamingMethod_$method_name$<"); + } + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + if (ServerOnlyStreaming(method.get())) { + printer->Print(" >"); + } + } + printer->Print(" SplitStreamedService;\n"); + + // Server side - typedef for controlled both unary and server-side streaming + printer->Print("typedef "); + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i)->name(); + auto method = service->method(i); + if (ServerOnlyStreaming(method.get())) { + printer->Print(*vars, "WithSplitStreamingMethod_$method_name$<"); + } + if (service->method(i)->NoStreaming()) { + printer->Print(*vars, "WithStreamedUnaryMethod_$method_name$<"); + } + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + if (service->method(i)->NoStreaming() || + ServerOnlyStreaming(method.get())) { + printer->Print(" >"); + } + } + printer->Print(" StreamedService;\n"); + + printer->Outdent(); + printer->Print("};\n"); + printer->Print(service->GetTrailingComments("//").c_str()); +} + +std::string GetHeaderServices(grpc_generator::File* file, + const Parameters& params) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto printer = file->CreatePrinter(&output); + std::map vars; + // Package string is empty or ends with a dot. It is used to fully qualify + // method names. 
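To make that qualification concrete: for a proto declaring package helloworld; with a Greeter service (as used elsewhere in this change), vars["Package"] becomes "helloworld." and it surfaces in the generated code roughly like this (a hand-expanded sketch, not verbatim plugin output):

// In the generated header (printed by PrintHeaderService above, inside
// class Greeter final):
static constexpr char const* service_full_name() {
  return "helloworld.Greeter";
}

// In the generated source (printed by PrintSourceService further below),
// one "/$Package$$Service$/$Method$" entry per method:
static const char* Greeter_method_names[] = {
    "/helloworld.Greeter/SayHello",
};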
+    vars["Package"] = file->package();
+    if (!file->package().empty()) {
+      vars["Package"].append(".");
+    }
+
+    if (!params.services_namespace.empty()) {
+      vars["services_namespace"] = params.services_namespace;
+      printer->Print(vars, "\nnamespace $services_namespace$ {\n\n");
+    }
+
+    for (int i = 0; i < file->service_count(); ++i) {
+      PrintHeaderService(printer.get(), file->service(i).get(), &vars);
+      printer->Print("\n");
+    }
+
+    if (!params.services_namespace.empty()) {
+      printer->Print(vars, "} // namespace $services_namespace$\n\n");
+    }
+  }
+  return output;
+}
+
+std::string GetHeaderEpilogue(grpc_generator::File* file,
+                              const Parameters& /*params*/) {
+  std::string output;
+  {
+    // Scope the output stream so it closes and finalizes output to the string.
+    auto printer = file->CreatePrinter(&output);
+    std::map<std::string, std::string> vars;
+
+    vars["filename"] = file->filename();
+    vars["filename_identifier"] = FilenameIdentifier(file->filename());
+
+    if (!file->package().empty()) {
+      std::vector<std::string> parts = file->package_parts();
+
+      for (auto part = parts.rbegin(); part != parts.rend(); part++) {
+        vars["part"] = *part;
+        printer->Print(vars, "} // namespace $part$\n");
+      }
+      printer->Print(vars, "\n");
+    }
+
+    printer->Print(vars, "\n");
+    printer->Print(vars, "#endif // GRPC_$filename_identifier$__INCLUDED\n");
+
+    printer->Print(file->GetTrailingComments("//").c_str());
+  }
+  return output;
+}
+
+std::string GetSourcePrologue(grpc_generator::File* file,
+                              const Parameters& params) {
+  std::string output;
+  {
+    // Scope the output stream so it closes and finalizes output to the string.
+    auto printer = file->CreatePrinter(&output);
+    std::map<std::string, std::string> vars;
+
+    vars["filename"] = file->filename();
+    vars["filename_base"] = file->filename_without_ext();
+    vars["message_header_ext"] = params.message_header_extension.empty()
+                                     ? kCppGeneratorMessageHeaderExt
+                                     : params.message_header_extension;
+    vars["service_header_ext"] = kCppGeneratorServiceHeaderExt;
+
+    printer->Print(vars, "// Generated by the gRPC C++ plugin.\n");
+    printer->Print(vars,
+                   "// If you make any local change, they will be lost.\n");
+    printer->Print(vars, "// source: $filename$\n\n");
+
+    printer->Print(vars, "#include \"$filename_base$$message_header_ext$\"\n");
+    printer->Print(vars, "#include \"$filename_base$$service_header_ext$\"\n");
+    printer->Print(vars, "\n");
+  }
+  return output;
+}
+
+std::string GetSourceIncludes(grpc_generator::File* file,
+                              const Parameters& params) {
+  std::string output;
+  {
+    // Scope the output stream so it closes and finalizes output to the string.
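Stepping back for a moment: for a file named helloworld.proto built with the default .pb.h / .grpc.pb.h extensions, GetSourcePrologue above would open the generated .grpc.pb.cc roughly as follows (a hand-expanded sketch; the actual extensions follow the message_header_extension parameter handled there):

// Generated by the gRPC C++ plugin.
// If you make any local change, they will be lost.
// source: helloworld.proto

#include "helloworld.pb.h"
#include "helloworld.grpc.pb.h"

GetSourceIncludes, continuing below, then pulls in the grpcpp codegen headers and reopens the proto package namespace before the per-service definitions are emitted.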
+ auto printer = file->CreatePrinter(&output); + std::map vars; + static const char* headers_strs[] = { + "functional", + "grpcpp/impl/codegen/async_stream.h", + "grpcpp/impl/codegen/async_unary_call.h", + "grpcpp/impl/codegen/channel_interface.h", + "grpcpp/impl/codegen/client_unary_call.h", + "grpcpp/impl/codegen/client_callback.h", + "grpcpp/impl/codegen/message_allocator.h", + "grpcpp/impl/codegen/method_handler.h", + "grpcpp/impl/codegen/rpc_service_method.h", + "grpcpp/impl/codegen/server_callback.h", + "grpcpp/impl/codegen/server_callback_handlers.h", + "grpcpp/impl/codegen/server_context.h", + "grpcpp/impl/codegen/service_type.h", + "grpcpp/impl/codegen/sync_stream.h"}; + std::vector headers(headers_strs, array_end(headers_strs)); + PrintIncludes(printer.get(), headers, params.use_system_headers, + params.grpc_search_path); + + if (!file->package().empty()) { + std::vector parts = file->package_parts(); + + for (auto part = parts.begin(); part != parts.end(); part++) { + vars["part"] = *part; + printer->Print(vars, "namespace $part$ {\n"); + } + } + + printer->Print(vars, "\n"); + } + return output; +} + +void PrintSourceClientMethod(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + struct { + std::string prefix; + std::string start; // bool literal expressed as string + std::string method_params; // extra arguments to method + std::string create_args; // extra arguments to creator + } async_prefixes[] = {{"Async", "true", ", void* tag", ", tag"}, + {"PrepareAsync", "false", "", ", nullptr"}}; + if (method->NoStreaming()) { + printer->Print(*vars, + "::grpc::Status $ns$$Service$::Stub::$Method$(" + "::grpc::ClientContext* context, " + "const $Request$& request, $Response$* response) {\n"); + printer->Print(*vars, + " return ::grpc::internal::BlockingUnaryCall" + "< $Request$, $Response$, ::grpc::protobuf::MessageLite, " + "::grpc::protobuf::MessageLite>" + "(channel_.get(), rpcmethod_$Method$_, " + "context, request, response);\n}\n\n"); + + printer->Print(*vars, + "void $ns$$Service$::Stub::async::$Method$(" + "::grpc::ClientContext* context, " + "const $Request$* request, $Response$* response, " + "std::function f) {\n"); + printer->Print(*vars, + " ::grpc::internal::CallbackUnaryCall" + "< $Request$, $Response$, ::grpc::protobuf::MessageLite, " + "::grpc::protobuf::MessageLite>" + "(stub_->channel_.get(), stub_->rpcmethod_$Method$_, " + "context, request, response, std::move(f));\n}\n\n"); + + printer->Print(*vars, + "void $ns$$Service$::Stub::async::$Method$(" + "::grpc::ClientContext* context, " + "const $Request$* request, $Response$* response, " + "::grpc::ClientUnaryReactor* reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackUnaryFactory::Create" + "< ::grpc::protobuf::MessageLite, " + "::grpc::protobuf::MessageLite>" + "(stub_->channel_.get(), stub_->rpcmethod_$Method$_, " + "context, request, response, reactor);\n}\n\n"); + + printer->Print(*vars, + "::grpc::ClientAsyncResponseReader< $Response$>* " + "$ns$$Service$::Stub::PrepareAsync$Method$Raw(::grpc::" + "ClientContext* context, " + "const $Request$& request, " + "::grpc::CompletionQueue* cq) {\n"); + printer->Print(*vars, + " return " + "::grpc::internal::ClientAsyncResponseReaderHelper::Create" + "< $Response$, $Request$, ::grpc::protobuf::MessageLite, " + "::grpc::protobuf::MessageLite>" + "(channel_.get(), 
cq, rpcmethod_$Method$_, " + "context, request);\n" + "}\n\n"); + printer->Print(*vars, + "::grpc::ClientAsyncResponseReader< $Response$>* " + "$ns$$Service$::Stub::Async$Method$Raw(::grpc::" + "ClientContext* context, " + "const $Request$& request, " + "::grpc::CompletionQueue* cq) {\n"); + printer->Print(*vars, + " auto* result =\n" + " this->PrepareAsync$Method$Raw(context, request, cq);\n" + " result->StartCall();\n" + " return result;\n" + "}\n\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print(*vars, + "::grpc::ClientWriter< $Request$>* " + "$ns$$Service$::Stub::$Method$Raw(" + "::grpc::ClientContext* context, $Response$* response) {\n"); + printer->Print(*vars, + " return ::grpc::internal::ClientWriterFactory< " + "$Request$>::Create(" + "channel_.get(), " + "rpcmethod_$Method$_, " + "context, response);\n" + "}\n\n"); + + printer->Print(*vars, + "void $ns$$Service$::" + "Stub::async::$Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::ClientWriteReactor< $Request$>* reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackWriterFactory< " + "$Request$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, response, reactor);\n" + "}\n\n"); + + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncStart"] = async_prefix.start; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncCreateArgs"] = async_prefix.create_args; + printer->Print(*vars, + "::grpc::ClientAsyncWriter< $Request$>* " + "$ns$$Service$::Stub::$AsyncPrefix$$Method$Raw(" + "::grpc::ClientContext* context, $Response$* response, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Print( + *vars, + " return ::grpc::internal::ClientAsyncWriterFactory< $Request$>" + "::Create(channel_.get(), cq, " + "rpcmethod_$Method$_, " + "context, response, $AsyncStart$$AsyncCreateArgs$);\n" + "}\n\n"); + } + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "::grpc::ClientReader< $Response$>* " + "$ns$$Service$::Stub::$Method$Raw(" + "::grpc::ClientContext* context, const $Request$& request) {\n"); + printer->Print(*vars, + " return ::grpc::internal::ClientReaderFactory< " + "$Response$>::Create(" + "channel_.get(), " + "rpcmethod_$Method$_, " + "context, request);\n" + "}\n\n"); + + printer->Print(*vars, + "void $ns$$Service$::Stub::async::$Method$(::grpc::" + "ClientContext* context, " + "const $Request$* request, " + "::grpc::ClientReadReactor< $Response$>* reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackReaderFactory< " + "$Response$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, request, reactor);\n" + "}\n\n"); + + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncStart"] = async_prefix.start; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncCreateArgs"] = async_prefix.create_args; + printer->Print( + *vars, + "::grpc::ClientAsyncReader< $Response$>* " + "$ns$$Service$::Stub::$AsyncPrefix$$Method$Raw(" + "::grpc::ClientContext* context, const $Request$& request, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Print(*vars, + " return ::grpc::internal::ClientAsyncReaderFactory< " + "$Response$>" + "::Create(channel_.get(), cq, " + "rpcmethod_$Method$_, " + "context, request, $AsyncStart$$AsyncCreateArgs$);\n" + "}\n\n"); + } + } else if 
(method->BidiStreaming()) { + printer->Print( + *vars, + "::grpc::ClientReaderWriter< $Request$, $Response$>* " + "$ns$$Service$::Stub::$Method$Raw(::grpc::ClientContext* context) {\n"); + printer->Print(*vars, + " return ::grpc::internal::ClientReaderWriterFactory< " + "$Request$, $Response$>::Create(" + "channel_.get(), " + "rpcmethod_$Method$_, " + "context);\n" + "}\n\n"); + + printer->Print(*vars, + "void $ns$$Service$::Stub::async::$Method$(::grpc::" + "ClientContext* context, " + "::grpc::ClientBidiReactor< $Request$,$Response$>* " + "reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackReaderWriterFactory< " + "$Request$,$Response$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, reactor);\n" + "}\n\n"); + + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncStart"] = async_prefix.start; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["AsyncCreateArgs"] = async_prefix.create_args; + printer->Print(*vars, + "::grpc::ClientAsyncReaderWriter< $Request$, $Response$>* " + "$ns$$Service$::Stub::$AsyncPrefix$$Method$Raw(::grpc::" + "ClientContext* context, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$) {\n"); + printer->Print(*vars, + " return " + "::grpc::internal::ClientAsyncReaderWriterFactory< " + "$Request$, $Response$>::Create(" + "channel_.get(), cq, " + "rpcmethod_$Method$_, " + "context, $AsyncStart$$AsyncCreateArgs$);\n" + "}\n\n"); + } + } +} + +void PrintSourceServerMethod(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + if (method->NoStreaming()) { + printer->Print(*vars, + "::grpc::Status $ns$$Service$::Service::$Method$(" + "::grpc::ServerContext* context, " + "const $Request$* request, $Response$* response) {\n"); + printer->Print(" (void) context;\n"); + printer->Print(" (void) request;\n"); + printer->Print(" (void) response;\n"); + printer->Print( + " return ::grpc::Status(" + "::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"); + printer->Print("}\n\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print(*vars, + "::grpc::Status $ns$$Service$::Service::$Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerReader< $Request$>* reader, " + "$Response$* response) {\n"); + printer->Print(" (void) context;\n"); + printer->Print(" (void) reader;\n"); + printer->Print(" (void) response;\n"); + printer->Print( + " return ::grpc::Status(" + "::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"); + printer->Print("}\n\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print(*vars, + "::grpc::Status $ns$$Service$::Service::$Method$(" + "::grpc::ServerContext* context, " + "const $Request$* request, " + "::grpc::ServerWriter< $Response$>* writer) {\n"); + printer->Print(" (void) context;\n"); + printer->Print(" (void) request;\n"); + printer->Print(" (void) writer;\n"); + printer->Print( + " return ::grpc::Status(" + "::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"); + printer->Print("}\n\n"); + } else if (method->BidiStreaming()) { + printer->Print(*vars, + "::grpc::Status $ns$$Service$::Service::$Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerReaderWriter< $Response$, $Request$>* " + "stream) {\n"); + printer->Print(" (void) context;\n"); + printer->Print(" (void) stream;\n"); + printer->Print( + " return ::grpc::Status(" + 
"::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"); + printer->Print("}\n\n"); + } +} + +void PrintSourceService(grpc_generator::Printer* printer, + const grpc_generator::Service* service, + std::map* vars) { + (*vars)["Service"] = service->name(); + + if (service->method_count() > 0) { + printer->Print(*vars, + "static const char* $prefix$$Service$_method_names[] = {\n"); + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Method"] = service->method(i)->name(); + printer->Print(*vars, " \"/$Package$$Service$/$Method$\",\n"); + } + printer->Print(*vars, "};\n\n"); + } + + printer->Print(*vars, + "std::unique_ptr< $ns$$Service$::Stub> $ns$$Service$::NewStub(" + "const std::shared_ptr< ::grpc::ChannelInterface>& channel, " + "const ::grpc::StubOptions& options) {\n" + " (void)options;\n" + " std::unique_ptr< $ns$$Service$::Stub> stub(new " + "$ns$$Service$::Stub(channel, options));\n" + " return stub;\n" + "}\n\n"); + printer->Print(*vars, + "$ns$$Service$::Stub::Stub(const std::shared_ptr< " + "::grpc::ChannelInterface>& channel, const " + "::grpc::StubOptions& options)\n"); + printer->Indent(); + printer->Print(": channel_(channel)"); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + (*vars)["Method"] = method->name(); + (*vars)["Idx"] = as_string(i); + if (method->NoStreaming()) { + (*vars)["StreamingType"] = "NORMAL_RPC"; + // NOTE: There is no reason to consider streamed-unary as a separate + // category here since this part is setting up the client-side stub + // and this appears as a NORMAL_RPC from the client-side. + } else if (ClientOnlyStreaming(method.get())) { + (*vars)["StreamingType"] = "CLIENT_STREAMING"; + } else if (ServerOnlyStreaming(method.get())) { + (*vars)["StreamingType"] = "SERVER_STREAMING"; + } else { + (*vars)["StreamingType"] = "BIDI_STREAMING"; + } + printer->Print( + *vars, + ", rpcmethod_$Method$_(" + "$prefix$$Service$_method_names[$Idx$], options.suffix_for_stats()," + "::grpc::internal::RpcMethod::$StreamingType$, " + "channel" + ")\n"); + } + printer->Print("{}\n\n"); + printer->Outdent(); + + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintSourceClientMethod(printer, service->method(i).get(), vars); + } + + printer->Print(*vars, "$ns$$Service$::Service::Service() {\n"); + printer->Indent(); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + (*vars)["Idx"] = as_string(i); + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + if (method->NoStreaming()) { + printer->Print( + *vars, + "AddMethod(new ::grpc::internal::RpcServiceMethod(\n" + " $prefix$$Service$_method_names[$Idx$],\n" + " ::grpc::internal::RpcMethod::NORMAL_RPC,\n" + " new ::grpc::internal::RpcMethodHandler< $ns$$Service$::Service, " + "$Request$, $Response$, ::grpc::protobuf::MessageLite, " + "::grpc::protobuf::MessageLite>(\n" + " []($ns$$Service$::Service* service,\n" + " ::grpc::ServerContext* ctx,\n" + " const $Request$* req,\n" + " $Response$* resp) {\n" + " return service->$Method$(ctx, req, resp);\n" + " }, this)));\n"); + } else if (ClientOnlyStreaming(method.get())) { + printer->Print( + *vars, + "AddMethod(new ::grpc::internal::RpcServiceMethod(\n" + " $prefix$$Service$_method_names[$Idx$],\n" + " ::grpc::internal::RpcMethod::CLIENT_STREAMING,\n" + " new ::grpc::internal::ClientStreamingHandler< " + "$ns$$Service$::Service, $Request$, $Response$>(\n" + " 
[]($ns$$Service$::Service* service,\n" + " ::grpc::ServerContext* ctx,\n" + " ::grpc::ServerReader<$Request$>* reader,\n" + " $Response$* resp) {\n" + " return service->$Method$(ctx, reader, resp);\n" + " }, this)));\n"); + } else if (ServerOnlyStreaming(method.get())) { + printer->Print( + *vars, + "AddMethod(new ::grpc::internal::RpcServiceMethod(\n" + " $prefix$$Service$_method_names[$Idx$],\n" + " ::grpc::internal::RpcMethod::SERVER_STREAMING,\n" + " new ::grpc::internal::ServerStreamingHandler< " + "$ns$$Service$::Service, $Request$, $Response$>(\n" + " []($ns$$Service$::Service* service,\n" + " ::grpc::ServerContext* ctx,\n" + " const $Request$* req,\n" + " ::grpc::ServerWriter<$Response$>* writer) {\n" + " return service->$Method$(ctx, req, writer);\n" + " }, this)));\n"); + } else if (method->BidiStreaming()) { + printer->Print(*vars, + "AddMethod(new ::grpc::internal::RpcServiceMethod(\n" + " $prefix$$Service$_method_names[$Idx$],\n" + " ::grpc::internal::RpcMethod::BIDI_STREAMING,\n" + " new ::grpc::internal::BidiStreamingHandler< " + "$ns$$Service$::Service, $Request$, $Response$>(\n" + " []($ns$$Service$::Service* service,\n" + " ::grpc::ServerContext* ctx,\n" + " ::grpc::ServerReaderWriter<$Response$,\n" + " $Request$>* stream) {\n" + " return service->$Method$(ctx, stream);\n" + " }, this)));\n"); + } + } + printer->Outdent(); + printer->Print(*vars, "}\n\n"); + printer->Print(*vars, + "$ns$$Service$::Service::~Service() {\n" + "}\n\n"); + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintSourceServerMethod(printer, service->method(i).get(), vars); + } +} + +std::string GetSourceServices(grpc_generator::File* file, + const Parameters& params) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto printer = file->CreatePrinter(&output); + std::map vars; + // Package string is empty or ends with a dot. It is used to fully qualify + // method names. + vars["Package"] = file->package(); + if (!file->package().empty()) { + vars["Package"].append("."); + } + if (!params.services_namespace.empty()) { + vars["ns"] = params.services_namespace + "::"; + vars["prefix"] = params.services_namespace; + } else { + vars["ns"] = ""; + vars["prefix"] = ""; + } + + for (int i = 0; i < file->service_count(); ++i) { + PrintSourceService(printer.get(), file->service(i).get(), &vars); + printer->Print("\n"); + } + } + return output; +} + +std::string GetSourceEpilogue(grpc_generator::File* file, + const Parameters& /*params*/) { + std::string temp; + + if (!file->package().empty()) { + std::vector parts = file->package_parts(); + + for (auto part = parts.begin(); part != parts.end(); part++) { + temp.append("} // namespace "); + temp.append(*part); + temp.append("\n"); + } + temp.append("\n"); + } + + return temp; +} + +// TODO(mmukhi): Make sure we need parameters or not. +std::string GetMockPrologue(grpc_generator::File* file, + const Parameters& params) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto printer = file->CreatePrinter(&output); + std::map vars; + + vars["filename"] = file->filename(); + vars["filename_base"] = file->filename_without_ext(); + vars["message_header_ext"] = params.message_header_extension.empty() + ? 
kCppGeneratorMessageHeaderExt + : params.message_header_extension; + vars["service_header_ext"] = kCppGeneratorServiceHeaderExt; + + printer->Print(vars, "// Generated by the gRPC C++ plugin.\n"); + printer->Print(vars, + "// If you make any local change, they will be lost.\n"); + printer->Print(vars, "// source: $filename$\n\n"); + + printer->Print(vars, "#include \"$filename_base$$message_header_ext$\"\n"); + printer->Print(vars, "#include \"$filename_base$$service_header_ext$\"\n"); + if (params.include_import_headers) { + const std::vector import_names = file->GetImportNames(); + for (const auto& import_name : import_names) { + const std::string include_name = ImportInludeFromProtoName(import_name); + printer->Print(vars, include_name.c_str()); + } + printer->PrintRaw("\n"); + } + printer->Print(vars, file->additional_headers().c_str()); + printer->Print(vars, "\n"); + } + return output; +} + +// TODO(mmukhi): Add client-stream and completion-queue headers. +std::string GetMockIncludes(grpc_generator::File* file, + const Parameters& params) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto printer = file->CreatePrinter(&output); + std::map vars; + + static const char* headers_strs[] = { + "grpcpp/impl/codegen/async_stream.h", + "grpcpp/impl/codegen/sync_stream.h", + }; + std::vector headers(headers_strs, array_end(headers_strs)); + PrintIncludes(printer.get(), headers, params.use_system_headers, + params.grpc_search_path); + + std::vector gmock_header; + if (params.gmock_search_path.empty()) { + gmock_header.push_back("gmock/gmock.h"); + PrintIncludes(printer.get(), gmock_header, params.use_system_headers, + params.grpc_search_path); + } else { + gmock_header.push_back("gmock.h"); + // We use local includes when a gmock_search_path is given + PrintIncludes(printer.get(), gmock_header, false, + params.gmock_search_path); + } + + if (!file->package().empty()) { + std::vector parts = file->package_parts(); + + for (auto part = parts.begin(); part != parts.end(); part++) { + vars["part"] = *part; + printer->Print(vars, "namespace $part$ {\n"); + } + } + + printer->Print(vars, "\n"); + } + return output; +} + +void PrintMockClientMethods(grpc_generator::Printer* printer, + const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + + struct { + std::string prefix; + std::string method_params; // extra arguments to method + int extra_method_param_count; + } async_prefixes[] = {{"Async", ", void* tag", 1}, {"PrepareAsync", "", 0}}; + + if (method->NoStreaming()) { + printer->Print( + *vars, + "MOCK_METHOD3($Method$, ::grpc::Status(::grpc::ClientContext* context, " + "const $Request$& request, $Response$* response));\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + printer->Print( + *vars, + "MOCK_METHOD3($AsyncPrefix$$Method$Raw, " + "::grpc::ClientAsyncResponseReaderInterface< $Response$>*" + "(::grpc::ClientContext* context, const $Request$& request, " + "::grpc::CompletionQueue* cq));\n"); + } + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "MOCK_METHOD2($Method$Raw, " + "::grpc::ClientWriterInterface< $Request$>*" + "(::grpc::ClientContext* context, $Response$* response));\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = 
async_prefix.method_params; + (*vars)["MockArgs"] = + std::to_string(3 + async_prefix.extra_method_param_count); + printer->Print(*vars, + "MOCK_METHOD$MockArgs$($AsyncPrefix$$Method$Raw, " + "::grpc::ClientAsyncWriterInterface< $Request$>*" + "(::grpc::ClientContext* context, $Response$* response, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$));\n"); + } + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "MOCK_METHOD2($Method$Raw, " + "::grpc::ClientReaderInterface< $Response$>*" + "(::grpc::ClientContext* context, const $Request$& request));\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["MockArgs"] = + std::to_string(3 + async_prefix.extra_method_param_count); + printer->Print( + *vars, + "MOCK_METHOD$MockArgs$($AsyncPrefix$$Method$Raw, " + "::grpc::ClientAsyncReaderInterface< $Response$>*" + "(::grpc::ClientContext* context, const $Request$& request, " + "::grpc::CompletionQueue* cq$AsyncMethodParams$));\n"); + } + } else if (method->BidiStreaming()) { + printer->Print( + *vars, + "MOCK_METHOD1($Method$Raw, " + "::grpc::ClientReaderWriterInterface< $Request$, $Response$>*" + "(::grpc::ClientContext* context));\n"); + for (auto async_prefix : async_prefixes) { + (*vars)["AsyncPrefix"] = async_prefix.prefix; + (*vars)["AsyncMethodParams"] = async_prefix.method_params; + (*vars)["MockArgs"] = + std::to_string(2 + async_prefix.extra_method_param_count); + printer->Print( + *vars, + "MOCK_METHOD$MockArgs$($AsyncPrefix$$Method$Raw, " + "::grpc::ClientAsyncReaderWriterInterface<$Request$, " + "$Response$>*" + "(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq" + "$AsyncMethodParams$));\n"); + } + } +} + +void PrintMockService(grpc_generator::Printer* printer, + const grpc_generator::Service* service, + std::map* vars) { + (*vars)["Service"] = service->name(); + + printer->Print(*vars, + "class Mock$Service$Stub : public $Service$::StubInterface {\n" + " public:\n"); + printer->Indent(); + for (int i = 0; i < service->method_count(); ++i) { + PrintMockClientMethods(printer, service->method(i).get(), vars); + } + printer->Outdent(); + printer->Print("};\n"); +} + +std::string GetMockServices(grpc_generator::File* file, + const Parameters& params) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto printer = file->CreatePrinter(&output); + std::map vars; + // Package string is empty or ends with a dot. It is used to fully qualify + // method names. 
+ vars["Package"] = file->package(); + if (!file->package().empty()) { + vars["Package"].append("."); + } + + if (!params.services_namespace.empty()) { + vars["services_namespace"] = params.services_namespace; + printer->Print(vars, "\nnamespace $services_namespace$ {\n\n"); + } + + for (int i = 0; i < file->service_count(); i++) { + PrintMockService(printer.get(), file->service(i).get(), &vars); + printer->Print("\n"); + } + + if (!params.services_namespace.empty()) { + printer->Print(vars, "} // namespace $services_namespace$\n\n"); + } + } + return output; +} + +std::string GetMockEpilogue(grpc_generator::File* file, + const Parameters& /*params*/) { + std::string temp; + + if (!file->package().empty()) { + std::vector parts = file->package_parts(); + + for (auto part = parts.begin(); part != parts.end(); part++) { + temp.append("} // namespace "); + temp.append(*part); + temp.append("\n"); + } + temp.append("\n"); + } + + return temp; +} + +} // namespace grpc_cpp_generator diff --git a/src/compiler/cpp_plugin.cc b/src/compiler/cpp_plugin.cc new file mode 100644 index 00000000..2de27454 --- /dev/null +++ b/src/compiler/cpp_plugin.cc @@ -0,0 +1,26 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Generates cpp gRPC service interface out of Protobuf IDL. +// +#include "src/compiler/cpp_plugin.h" + +int main(int argc, char* argv[]) { + CppGrpcGenerator generator; + return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/src/compiler/csharp_generator.cc b/src/compiler/csharp_generator.cc new file mode 100644 index 00000000..b42c07cd --- /dev/null +++ b/src/compiler/csharp_generator.cc @@ -0,0 +1,829 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/compiler/csharp_generator.h" + +#include +#include +#include +#include + +#include "src/compiler/config.h" +#include "src/compiler/csharp_generator_helpers.h" + +using grpc::protobuf::Descriptor; +using grpc::protobuf::FileDescriptor; +using grpc::protobuf::MethodDescriptor; +using grpc::protobuf::ServiceDescriptor; +using grpc::protobuf::io::Printer; +using grpc::protobuf::io::StringOutputStream; +using grpc_generator::StringReplace; +using std::vector; + +namespace grpc_csharp_generator { +namespace { + +// This function is a massaged version of +// https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/compiler/csharp/csharp_doc_comment.cc +// Currently, we cannot easily reuse the functionality as +// google/protobuf/compiler/csharp/csharp_doc_comment.h is not a public header. +// TODO(jtattermusch): reuse the functionality from google/protobuf. +bool GenerateDocCommentBodyImpl(grpc::protobuf::io::Printer* printer, + grpc::protobuf::SourceLocation location) { + std::string comments = location.leading_comments.empty() + ? location.trailing_comments + : location.leading_comments; + if (comments.empty()) { + return false; + } + // XML escaping... no need for apostrophes etc as the whole text is going to + // be a child + // node of a summary element, not part of an attribute. + comments = grpc_generator::StringReplace(comments, "&", "&", true); + comments = grpc_generator::StringReplace(comments, "<", "<", true); + + std::vector lines; + grpc_generator::Split(comments, '\n', &lines); + // TODO: We really should work out which part to put in the summary and which + // to put in the remarks... + // but that needs to be part of a bigger effort to understand the markdown + // better anyway. + printer->Print("/// \n"); + bool last_was_empty = false; + // We squash multiple blank lines down to one, and remove any trailing blank + // lines. We need + // to preserve the blank lines themselves, as this is relevant in the + // markdown. + // Note that we can't remove leading or trailing whitespace as *that's* + // relevant in markdown too. + // (We don't skip "just whitespace" lines, either.) + for (std::vector::iterator it = lines.begin(); it != lines.end(); + ++it) { + std::string line = *it; + if (line.empty()) { + last_was_empty = true; + } else { + if (last_was_empty) { + printer->Print("///\n"); + } + last_was_empty = false; + printer->Print("///$line$\n", "line", *it); + } + } + printer->Print("/// \n"); + return true; +} + +void GenerateGeneratedCodeAttribute(grpc::protobuf::io::Printer* printer) { + // Mark the code as generated using the [GeneratedCode] attribute. + // We don't provide plugin version info in attribute the because: + // * the version information is not readily available from the plugin's code. + // * it would cause a lot of churn in the pre-generated code + // in this repository every time the version is updated. 
+ printer->Print( + "[global::System.CodeDom.Compiler.GeneratedCode(\"grpc_csharp_plugin\", " + "null)]\n"); +} + +template +bool GenerateDocCommentBody(grpc::protobuf::io::Printer* printer, + const DescriptorType* descriptor) { + grpc::protobuf::SourceLocation location; + if (!descriptor->GetSourceLocation(&location)) { + return false; + } + return GenerateDocCommentBodyImpl(printer, location); +} + +void GenerateDocCommentServerMethod(grpc::protobuf::io::Printer* printer, + const MethodDescriptor* method) { + if (GenerateDocCommentBody(printer, method)) { + if (method->client_streaming()) { + printer->Print( + "/// Used for reading requests from " + "the client.\n"); + } else { + printer->Print( + "/// The request received from the " + "client.\n"); + } + if (method->server_streaming()) { + printer->Print( + "/// Used for sending responses back " + "to the client.\n"); + } + printer->Print( + "/// The context of the server-side call " + "handler being invoked.\n"); + if (method->server_streaming()) { + printer->Print( + "/// A task indicating completion of the " + "handler.\n"); + } else { + printer->Print( + "/// The response to send back to the client (wrapped by a " + "task).\n"); + } + } +} + +void GenerateDocCommentClientMethod(grpc::protobuf::io::Printer* printer, + const MethodDescriptor* method, + bool is_sync, bool use_call_options) { + if (GenerateDocCommentBody(printer, method)) { + if (!method->client_streaming()) { + printer->Print( + "/// The request to send to the " + "server.\n"); + } + if (!use_call_options) { + printer->Print( + "/// The initial metadata to send with the " + "call. This parameter is optional.\n"); + printer->Print( + "/// An optional deadline for the call. The " + "call will be cancelled if deadline is hit.\n"); + printer->Print( + "/// An optional token for " + "canceling the call.\n"); + } else { + printer->Print( + "/// The options for the call.\n"); + } + if (is_sync) { + printer->Print( + "/// The response received from the server.\n"); + } else { + printer->Print("/// The call object.\n"); + } + } +} + +std::string GetServiceClassName(const ServiceDescriptor* service) { + return service->name(); +} + +std::string GetClientClassName(const ServiceDescriptor* service) { + return service->name() + "Client"; +} + +std::string GetServerClassName(const ServiceDescriptor* service) { + return service->name() + "Base"; +} + +std::string GetCSharpMethodType(const MethodDescriptor* method) { + if (method->client_streaming()) { + if (method->server_streaming()) { + return "grpc::MethodType.DuplexStreaming"; + } else { + return "grpc::MethodType.ClientStreaming"; + } + } else { + if (method->server_streaming()) { + return "grpc::MethodType.ServerStreaming"; + } else { + return "grpc::MethodType.Unary"; + } + } +} + +std::string GetCSharpServerMethodType(const MethodDescriptor* method) { + if (method->client_streaming()) { + if (method->server_streaming()) { + return "grpc::DuplexStreamingServerMethod"; + } else { + return "grpc::ClientStreamingServerMethod"; + } + } else { + if (method->server_streaming()) { + return "grpc::ServerStreamingServerMethod"; + } else { + return "grpc::UnaryServerMethod"; + } + } +} + +std::string GetServiceNameFieldName() { return "__ServiceName"; } + +std::string GetMarshallerFieldName(const Descriptor* message) { + return "__Marshaller_" + + grpc_generator::StringReplace(message->full_name(), ".", "_", true); +} + +std::string GetMethodFieldName(const MethodDescriptor* method) { + return "__Method_" + method->name(); +} + 
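The helpers just above derive deterministic C# identifiers from protobuf descriptors: GetMarshallerFieldName replaces the dots in a message's full name with underscores and prepends "__Marshaller_", and GetMethodFieldName prepends "__Method_" to the method name. A standalone sketch of that naming scheme, illustrative only (it uses std::replace in place of grpc_generator::StringReplace, and the sample names are just example inputs):

// Illustrative sketch only -- not part of the plugin. Mirrors the naming
// scheme of GetMarshallerFieldName/GetMethodFieldName above.
#include <algorithm>
#include <iostream>
#include <string>

static std::string MarshallerFieldNameFor(std::string message_full_name) {
  // "helloworld.HelloRequest" -> "__Marshaller_helloworld_HelloRequest"
  std::replace(message_full_name.begin(), message_full_name.end(), '.', '_');
  return "__Marshaller_" + message_full_name;
}

static std::string MethodFieldNameFor(const std::string& method_name) {
  return "__Method_" + method_name;
}

int main() {
  std::cout << MarshallerFieldNameFor("helloworld.HelloRequest") << "\n";
  // prints: __Marshaller_helloworld_HelloRequest
  std::cout << MethodFieldNameFor("SayHello") << "\n";
  // prints: __Method_SayHello
}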
+std::string GetMethodRequestParamMaybe(const MethodDescriptor* method, + bool invocation_param = false) { + if (method->client_streaming()) { + return ""; + } + if (invocation_param) { + return "request, "; + } + return GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()) + " request, "; +} + +std::string GetAccessLevel(bool internal_access) { + return internal_access ? "internal" : "public"; +} + +std::string GetMethodReturnTypeClient(const MethodDescriptor* method) { + if (method->client_streaming()) { + if (method->server_streaming()) { + return "grpc::AsyncDuplexStreamingCall<" + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()) + ", " + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type()) + ">"; + } else { + return "grpc::AsyncClientStreamingCall<" + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()) + ", " + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type()) + ">"; + } + } else { + if (method->server_streaming()) { + return "grpc::AsyncServerStreamingCall<" + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type()) + ">"; + } else { + return "grpc::AsyncUnaryCall<" + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type()) + ">"; + } + } +} + +std::string GetMethodRequestParamServer(const MethodDescriptor* method) { + if (method->client_streaming()) { + return "grpc::IAsyncStreamReader<" + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()) + + "> requestStream"; + } + return GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()) + " request"; +} + +std::string GetMethodReturnTypeServer(const MethodDescriptor* method) { + if (method->server_streaming()) { + return "global::System.Threading.Tasks.Task"; + } + return "global::System.Threading.Tasks.Task<" + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type()) + ">"; +} + +std::string GetMethodResponseStreamMaybe(const MethodDescriptor* method) { + if (method->server_streaming()) { + return ", grpc::IServerStreamWriter<" + + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type()) + + "> responseStream"; + } + return ""; +} + +// Gets vector of all messages used as input or output types. 
+std::vector GetUsedMessages( + const ServiceDescriptor* service) { + std::set descriptor_set; + std::vector + result; // vector is to maintain stable ordering + for (int i = 0; i < service->method_count(); i++) { + const MethodDescriptor* method = service->method(i); + if (descriptor_set.find(method->input_type()) == descriptor_set.end()) { + descriptor_set.insert(method->input_type()); + result.push_back(method->input_type()); + } + if (descriptor_set.find(method->output_type()) == descriptor_set.end()) { + descriptor_set.insert(method->output_type()); + result.push_back(method->output_type()); + } + } + return result; +} + +void GenerateMarshallerFields(Printer* out, const ServiceDescriptor* service) { + std::vector used_messages = GetUsedMessages(service); + if (used_messages.size() != 0) { + // Generate static helper methods for serialization/deserialization + GenerateGeneratedCodeAttribute(out); + out->Print( + "static void __Helper_SerializeMessage(" + "global::Google.Protobuf.IMessage message, " + "grpc::SerializationContext context)\n" + "{\n"); + out->Indent(); + out->Print( + "#if !GRPC_DISABLE_PROTOBUF_BUFFER_SERIALIZATION\n" + "if (message is global::Google.Protobuf.IBufferMessage)\n" + "{\n"); + out->Indent(); + out->Print( + "context.SetPayloadLength(message.CalculateSize());\n" + "global::Google.Protobuf.MessageExtensions.WriteTo(message, " + "context.GetBufferWriter());\n" + "context.Complete();\n" + "return;\n"); + out->Outdent(); + out->Print( + "}\n" + "#endif\n"); + out->Print( + "context.Complete(" + "global::Google.Protobuf.MessageExtensions.ToByteArray(message));\n"); + out->Outdent(); + out->Print("}\n\n"); + + GenerateGeneratedCodeAttribute(out); + out->Print( + "static class __Helper_MessageCache\n" + "{\n"); + out->Indent(); + out->Print( + "public static readonly bool IsBufferMessage = " + "global::System.Reflection.IntrospectionExtensions.GetTypeInfo(typeof(" + "global::Google.Protobuf.IBufferMessage)).IsAssignableFrom(typeof(T));" + "\n"); + out->Outdent(); + out->Print("}\n\n"); + + GenerateGeneratedCodeAttribute(out); + out->Print( + "static T __Helper_DeserializeMessage(" + "grpc::DeserializationContext context, " + "global::Google.Protobuf.MessageParser parser) " + "where T : global::Google.Protobuf.IMessage\n" + "{\n"); + out->Indent(); + out->Print( + "#if !GRPC_DISABLE_PROTOBUF_BUFFER_SERIALIZATION\n" + "if (__Helper_MessageCache.IsBufferMessage)\n" + "{\n"); + out->Indent(); + out->Print( + "return parser.ParseFrom(context.PayloadAsReadOnlySequence());\n"); + out->Outdent(); + out->Print( + "}\n" + "#endif\n"); + out->Print("return parser.ParseFrom(context.PayloadAsNewBuffer());\n"); + out->Outdent(); + out->Print("}\n\n"); + } + + for (size_t i = 0; i < used_messages.size(); i++) { + const Descriptor* message = used_messages[i]; + GenerateGeneratedCodeAttribute(out); + out->Print( + "static readonly grpc::Marshaller<$type$> $fieldname$ = " + "grpc::Marshallers.Create(__Helper_SerializeMessage, " + "context => __Helper_DeserializeMessage(context, $type$.Parser));\n", + "fieldname", GetMarshallerFieldName(message), "type", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(message)); + } + out->Print("\n"); +} + +void GenerateStaticMethodField(Printer* out, const MethodDescriptor* method) { + GenerateGeneratedCodeAttribute(out); + out->Print( + "static readonly grpc::Method<$request$, $response$> $fieldname$ = new " + "grpc::Method<$request$, $response$>(\n", + "fieldname", GetMethodFieldName(method), "request", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()), 
"response", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type())); + out->Indent(); + out->Indent(); + out->Print("$methodtype$,\n", "methodtype", GetCSharpMethodType(method)); + out->Print("$servicenamefield$,\n", "servicenamefield", + GetServiceNameFieldName()); + out->Print("\"$methodname$\",\n", "methodname", method->name()); + out->Print("$requestmarshaller$,\n", "requestmarshaller", + GetMarshallerFieldName(method->input_type())); + out->Print("$responsemarshaller$);\n", "responsemarshaller", + GetMarshallerFieldName(method->output_type())); + out->Print("\n"); + out->Outdent(); + out->Outdent(); +} + +void GenerateServiceDescriptorProperty(Printer* out, + const ServiceDescriptor* service) { + std::ostringstream index; + index << service->index(); + out->Print("/// Service descriptor\n"); + out->Print( + "public static global::Google.Protobuf.Reflection.ServiceDescriptor " + "Descriptor\n"); + out->Print("{\n"); + out->Print(" get { return $umbrella$.Descriptor.Services[$index$]; }\n", + "umbrella", + GRPC_CUSTOM_CSHARP_GETREFLECTIONCLASSNAME(service->file()), + "index", index.str()); + out->Print("}\n"); + out->Print("\n"); +} + +void GenerateServerClass(Printer* out, const ServiceDescriptor* service) { + out->Print( + "/// Base class for server-side implementations of " + "$servicename$\n", + "servicename", GetServiceClassName(service)); + out->Print( + "[grpc::BindServiceMethod(typeof($classname$), " + "\"BindService\")]\n", + "classname", GetServiceClassName(service)); + out->Print("public abstract partial class $name$\n", "name", + GetServerClassName(service)); + out->Print("{\n"); + out->Indent(); + for (int i = 0; i < service->method_count(); i++) { + const MethodDescriptor* method = service->method(i); + GenerateDocCommentServerMethod(out, method); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public virtual $returntype$ " + "$methodname$($request$$response_stream_maybe$, " + "grpc::ServerCallContext context)\n", + "methodname", method->name(), "returntype", + GetMethodReturnTypeServer(method), "request", + GetMethodRequestParamServer(method), "response_stream_maybe", + GetMethodResponseStreamMaybe(method)); + out->Print("{\n"); + out->Indent(); + out->Print( + "throw new grpc::RpcException(" + "new grpc::Status(grpc::StatusCode.Unimplemented, \"\"));\n"); + out->Outdent(); + out->Print("}\n\n"); + } + out->Outdent(); + out->Print("}\n"); + out->Print("\n"); +} + +void GenerateClientStub(Printer* out, const ServiceDescriptor* service) { + out->Print("/// Client for $servicename$\n", "servicename", + GetServiceClassName(service)); + out->Print("public partial class $name$ : grpc::ClientBase<$name$>\n", "name", + GetClientClassName(service)); + out->Print("{\n"); + out->Indent(); + + // constructors + out->Print( + "/// Creates a new client for $servicename$\n" + "/// The channel to use to make remote " + "calls.\n", + "servicename", GetServiceClassName(service)); + GenerateGeneratedCodeAttribute(out); + out->Print("public $name$(grpc::ChannelBase channel) : base(channel)\n", + "name", GetClientClassName(service)); + out->Print("{\n"); + out->Print("}\n"); + out->Print( + "/// Creates a new client for $servicename$ that uses a custom " + "CallInvoker.\n" + "/// The callInvoker to use to make remote " + "calls.\n", + "servicename", GetServiceClassName(service)); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public $name$(grpc::CallInvoker callInvoker) : base(callInvoker)\n", + "name", GetClientClassName(service)); + out->Print("{\n"); + out->Print("}\n"); 
+ out->Print( + "/// Protected parameterless constructor to allow creation" + " of test doubles.\n"); + GenerateGeneratedCodeAttribute(out); + out->Print("protected $name$() : base()\n", "name", + GetClientClassName(service)); + out->Print("{\n"); + out->Print("}\n"); + out->Print( + "/// Protected constructor to allow creation of configured " + "clients.\n" + "/// The client configuration.\n"); + GenerateGeneratedCodeAttribute(out); + out->Print( + "protected $name$(ClientBaseConfiguration configuration)" + " : base(configuration)\n", + "name", GetClientClassName(service)); + out->Print("{\n"); + out->Print("}\n\n"); + + for (int i = 0; i < service->method_count(); i++) { + const MethodDescriptor* method = service->method(i); + if (!method->client_streaming() && !method->server_streaming()) { + // unary calls have an extra synchronous stub method + GenerateDocCommentClientMethod(out, method, true, false); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public virtual $response$ $methodname$($request$ request, " + "grpc::Metadata " + "headers = null, global::System.DateTime? deadline = null, " + "global::System.Threading.CancellationToken " + "cancellationToken = " + "default(global::System.Threading.CancellationToken))\n", + "methodname", method->name(), "request", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()), "response", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type())); + out->Print("{\n"); + out->Indent(); + out->Print( + "return $methodname$(request, new grpc::CallOptions(headers, " + "deadline, " + "cancellationToken));\n", + "methodname", method->name()); + out->Outdent(); + out->Print("}\n"); + + // overload taking CallOptions as a param + GenerateDocCommentClientMethod(out, method, true, true); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public virtual $response$ $methodname$($request$ request, " + "grpc::CallOptions options)\n", + "methodname", method->name(), "request", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()), "response", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type())); + out->Print("{\n"); + out->Indent(); + out->Print( + "return CallInvoker.BlockingUnaryCall($methodfield$, null, options, " + "request);\n", + "methodfield", GetMethodFieldName(method)); + out->Outdent(); + out->Print("}\n"); + } + + std::string method_name = method->name(); + if (!method->client_streaming() && !method->server_streaming()) { + method_name += "Async"; // prevent name clash with synchronous method. + } + GenerateDocCommentClientMethod(out, method, false, false); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public virtual $returntype$ " + "$methodname$($request_maybe$grpc::Metadata " + "headers = null, global::System.DateTime? 
deadline = null, " + "global::System.Threading.CancellationToken " + "cancellationToken = " + "default(global::System.Threading.CancellationToken))\n", + "methodname", method_name, "request_maybe", + GetMethodRequestParamMaybe(method), "returntype", + GetMethodReturnTypeClient(method)); + out->Print("{\n"); + out->Indent(); + + out->Print( + "return $methodname$($request_maybe$new grpc::CallOptions(headers, " + "deadline, " + "cancellationToken));\n", + "methodname", method_name, "request_maybe", + GetMethodRequestParamMaybe(method, true)); + out->Outdent(); + out->Print("}\n"); + + // overload taking CallOptions as a param + GenerateDocCommentClientMethod(out, method, false, true); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public virtual $returntype$ " + "$methodname$($request_maybe$grpc::CallOptions " + "options)\n", + "methodname", method_name, "request_maybe", + GetMethodRequestParamMaybe(method), "returntype", + GetMethodReturnTypeClient(method)); + out->Print("{\n"); + out->Indent(); + if (!method->client_streaming() && !method->server_streaming()) { + // Non-Streaming + out->Print( + "return CallInvoker.AsyncUnaryCall($methodfield$, null, options, " + "request);\n", + "methodfield", GetMethodFieldName(method)); + } else if (method->client_streaming() && !method->server_streaming()) { + // Client Streaming Only + out->Print( + "return CallInvoker.AsyncClientStreamingCall($methodfield$, null, " + "options);\n", + "methodfield", GetMethodFieldName(method)); + } else if (!method->client_streaming() && method->server_streaming()) { + // Server Streaming Only + out->Print( + "return CallInvoker.AsyncServerStreamingCall($methodfield$, null, " + "options, request);\n", + "methodfield", GetMethodFieldName(method)); + } else { + // Bi-Directional Streaming + out->Print( + "return CallInvoker.AsyncDuplexStreamingCall($methodfield$, null, " + "options);\n", + "methodfield", GetMethodFieldName(method)); + } + out->Outdent(); + out->Print("}\n"); + } + + // override NewInstance method + out->Print( + "/// Creates a new instance of client from given " + "ClientBaseConfiguration.\n"); + GenerateGeneratedCodeAttribute(out); + out->Print( + "protected override $name$ NewInstance(ClientBaseConfiguration " + "configuration)\n", + "name", GetClientClassName(service)); + out->Print("{\n"); + out->Indent(); + out->Print("return new $name$(configuration);\n", "name", + GetClientClassName(service)); + out->Outdent(); + out->Print("}\n"); + + out->Outdent(); + out->Print("}\n"); + out->Print("\n"); +} + +void GenerateBindServiceMethod(Printer* out, const ServiceDescriptor* service) { + out->Print( + "/// Creates service definition that can be registered with a " + "server\n"); + out->Print( + "/// An object implementing the server-side" + " handling logic.\n"); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public static grpc::ServerServiceDefinition BindService($implclass$ " + "serviceImpl)\n", + "implclass", GetServerClassName(service)); + out->Print("{\n"); + out->Indent(); + + out->Print("return grpc::ServerServiceDefinition.CreateBuilder()"); + out->Indent(); + out->Indent(); + for (int i = 0; i < service->method_count(); i++) { + const MethodDescriptor* method = service->method(i); + out->Print("\n.AddMethod($methodfield$, serviceImpl.$methodname$)", + "methodfield", GetMethodFieldName(method), "methodname", + method->name()); + } + out->Print(".Build();\n"); + out->Outdent(); + out->Outdent(); + + out->Outdent(); + out->Print("}\n"); + out->Print("\n"); +} + +void 
GenerateBindServiceWithBinderMethod(Printer* out, + const ServiceDescriptor* service) { + out->Print( + "/// Register service method with a service " + "binder with or without implementation. Useful when customizing the " + "service binding logic.\n" + "/// Note: this method is part of an experimental API that can change or " + "be " + "removed without any prior notice.\n"); + out->Print( + "/// Service methods will be bound by " + "calling AddMethod on this object." + "\n"); + out->Print( + "/// An object implementing the server-side" + " handling logic.\n"); + GenerateGeneratedCodeAttribute(out); + out->Print( + "public static void BindService(grpc::ServiceBinderBase serviceBinder, " + "$implclass$ " + "serviceImpl)\n", + "implclass", GetServerClassName(service)); + out->Print("{\n"); + out->Indent(); + + for (int i = 0; i < service->method_count(); i++) { + const MethodDescriptor* method = service->method(i); + out->Print( + "serviceBinder.AddMethod($methodfield$, serviceImpl == null ? null : " + "new $servermethodtype$<$inputtype$, $outputtype$>(" + "serviceImpl.$methodname$));\n", + "methodfield", GetMethodFieldName(method), "servermethodtype", + GetCSharpServerMethodType(method), "inputtype", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->input_type()), "outputtype", + GRPC_CUSTOM_CSHARP_GETCLASSNAME(method->output_type()), "methodname", + method->name()); + } + + out->Outdent(); + out->Print("}\n"); + out->Print("\n"); +} + +void GenerateService(Printer* out, const ServiceDescriptor* service, + bool generate_client, bool generate_server, + bool internal_access) { + GenerateDocCommentBody(out, service); + + out->Print("$access_level$ static partial class $classname$\n", + "access_level", GetAccessLevel(internal_access), "classname", + GetServiceClassName(service)); + out->Print("{\n"); + out->Indent(); + out->Print("static readonly string $servicenamefield$ = \"$servicename$\";\n", + "servicenamefield", GetServiceNameFieldName(), "servicename", + service->full_name()); + out->Print("\n"); + + GenerateMarshallerFields(out, service); + for (int i = 0; i < service->method_count(); i++) { + GenerateStaticMethodField(out, service->method(i)); + } + GenerateServiceDescriptorProperty(out, service); + + if (generate_server) { + GenerateServerClass(out, service); + } + if (generate_client) { + GenerateClientStub(out, service); + } + if (generate_server) { + GenerateBindServiceMethod(out, service); + GenerateBindServiceWithBinderMethod(out, service); + } + + out->Outdent(); + out->Print("}\n"); +} + +} // anonymous namespace + +std::string GetServices(const FileDescriptor* file, bool generate_client, + bool generate_server, bool internal_access) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + + StringOutputStream output_stream(&output); + Printer out(&output_stream, '$'); + + // Don't write out any output if there no services, to avoid empty service + // files being generated for proto files that don't declare any. + if (file->service_count() == 0) { + return output; + } + + // Write out a file header. + out.Print("// \n"); + out.Print( + "// Generated by the protocol buffer compiler. 
DO NOT EDIT!\n"); + out.Print("// source: $filename$\n", "filename", file->name()); + out.Print("// \n"); + + // use C++ style as there are no file-level XML comments in .NET + std::string leading_comments = GetCsharpComments(file, true); + if (!leading_comments.empty()) { + out.Print("// Original file comments:\n"); + out.PrintRaw(leading_comments.c_str()); + } + + out.Print("#pragma warning disable 0414, 1591\n"); + + out.Print("#region Designer generated code\n"); + out.Print("\n"); + out.Print("using grpc = global::Grpc.Core;\n"); + out.Print("\n"); + + std::string file_namespace = GRPC_CUSTOM_CSHARP_GETFILENAMESPACE(file); + if (file_namespace != "") { + out.Print("namespace $namespace$ {\n", "namespace", file_namespace); + out.Indent(); + } + for (int i = 0; i < file->service_count(); i++) { + GenerateService(&out, file->service(i), generate_client, generate_server, + internal_access); + } + if (file_namespace != "") { + out.Outdent(); + out.Print("}\n"); + } + out.Print("#endregion\n"); + } + return output; +} + +} // namespace grpc_csharp_generator diff --git a/src/compiler/csharp_plugin.cc b/src/compiler/csharp_plugin.cc new file mode 100644 index 00000000..f63bb4de --- /dev/null +++ b/src/compiler/csharp_plugin.cc @@ -0,0 +1,87 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Generates C# gRPC service interface out of Protobuf IDL. + +#include + +#include "src/compiler/config.h" +#include "src/compiler/csharp_generator.h" +#include "src/compiler/csharp_generator_helpers.h" + +class CSharpGrpcGenerator : public grpc::protobuf::compiler::CodeGenerator { + public: + CSharpGrpcGenerator() {} + ~CSharpGrpcGenerator() {} + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + bool Generate(const grpc::protobuf::FileDescriptor* file, + const std::string& parameter, + grpc::protobuf::compiler::GeneratorContext* context, + std::string* error) const override { + std::vector > options; + grpc::protobuf::compiler::ParseGeneratorParameter(parameter, &options); + + bool generate_client = true; + bool generate_server = true; + bool internal_access = false; + // the suffix that will get appended to the name generated from the name + // of the original .proto file + std::string file_suffix = "Grpc.cs"; + for (size_t i = 0; i < options.size(); i++) { + if (options[i].first == "no_client") { + generate_client = false; + } else if (options[i].first == "no_server") { + generate_server = false; + } else if (options[i].first == "internal_access") { + internal_access = true; + } else if (options[i].first == "file_suffix") { + file_suffix = options[i].second; + } else { + *error = "Unknown generator option: " + options[i].first; + return false; + } + } + + std::string code = grpc_csharp_generator::GetServices( + file, generate_client, generate_server, internal_access); + if (code.size() == 0) { + return true; // don't generate a file if there are no services + } + + // Get output file name. 
+ std::string file_name; + if (!grpc_csharp_generator::ServicesFilename(file, file_suffix, + file_name)) { + return false; + } + std::unique_ptr output( + context->Open(file_name)); + grpc::protobuf::io::CodedOutputStream coded_out(output.get()); + coded_out.WriteRaw(code.data(), code.size()); + return true; + } +}; + +int main(int argc, char* argv[]) { + CSharpGrpcGenerator generator; + return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/src/compiler/node_generator.cc b/src/compiler/node_generator.cc new file mode 100644 index 00000000..8c9f0263 --- /dev/null +++ b/src/compiler/node_generator.cc @@ -0,0 +1,276 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/compiler/node_generator.h" + +#include + +#include "src/compiler/config.h" +#include "src/compiler/generator_helpers.h" +#include "src/compiler/node_generator_helpers.h" + +using grpc::protobuf::Descriptor; +using grpc::protobuf::FileDescriptor; +using grpc::protobuf::MethodDescriptor; +using grpc::protobuf::ServiceDescriptor; +using grpc::protobuf::io::Printer; +using grpc::protobuf::io::StringOutputStream; +using std::map; + +namespace grpc_node_generator { +namespace { + +// Returns the alias we assign to the module of the given .proto filename +// when importing. Copied entirely from +// github:google/protobuf/src/google/protobuf/compiler/js/js_generator.cc#L154 +std::string ModuleAlias(const std::string filename) { + // This scheme could technically cause problems if a file includes any 2 of: + // foo/bar_baz.proto + // foo_bar_baz.proto + // foo_bar/baz.proto + // + // We'll worry about this problem if/when we actually see it. This name isn't + // exposed to users so we can change it later if we need to. + std::string basename = grpc_generator::StripProto(filename); + basename = grpc_generator::StringReplace(basename, "-", "$"); + basename = grpc_generator::StringReplace(basename, "/", "_"); + basename = grpc_generator::StringReplace(basename, ".", "_"); + return basename + "_pb"; +} + +// Given a filename like foo/bar/baz.proto, returns the corresponding JavaScript +// message file foo/bar/baz.js +std::string GetJSMessageFilename(const std::string& filename) { + std::string name = filename; + return grpc_generator::StripProto(name) + "_pb.js"; +} + +// Given a filename like foo/bar/baz.proto, returns the root directory +// path ../../ +std::string GetRootPath(const std::string& from_filename, + const std::string& to_filename) { + if (to_filename.find("google/protobuf") == 0) { + // Well-known types (.proto files in the google/protobuf directory) are + // assumed to come from the 'google-protobuf' npm package. We may want to + // generalize this exception later by letting others put generated code in + // their own npm packages. 
+ return "google-protobuf/"; + } + size_t slashes = std::count(from_filename.begin(), from_filename.end(), '/'); + if (slashes == 0) { + return "./"; + } + std::string result = ""; + for (size_t i = 0; i < slashes; i++) { + result += "../"; + } + return result; +} + +// Return the relative path to load to_file from the directory containing +// from_file, assuming that both paths are relative to the same directory +std::string GetRelativePath(const std::string& from_file, + const std::string& to_file) { + return GetRootPath(from_file, to_file) + to_file; +} + +/* Finds all message types used in all services in the file, and returns them + * as a map of fully qualified message type name to message descriptor */ +map GetAllMessages(const FileDescriptor* file) { + map message_types; + for (int service_num = 0; service_num < file->service_count(); + service_num++) { + const ServiceDescriptor* service = file->service(service_num); + for (int method_num = 0; method_num < service->method_count(); + method_num++) { + const MethodDescriptor* method = service->method(method_num); + const Descriptor* input_type = method->input_type(); + const Descriptor* output_type = method->output_type(); + message_types[input_type->full_name()] = input_type; + message_types[output_type->full_name()] = output_type; + } + } + return message_types; +} + +std::string MessageIdentifierName(const std::string& name) { + return grpc_generator::StringReplace(name, ".", "_"); +} + +std::string NodeObjectPath(const Descriptor* descriptor) { + std::string module_alias = ModuleAlias(descriptor->file()->name()); + std::string name = descriptor->full_name(); + grpc_generator::StripPrefix(&name, descriptor->file()->package() + "."); + return module_alias + "." + name; +} + +// Prints out the message serializer and deserializer functions +void PrintMessageTransformer(const Descriptor* descriptor, Printer* out, + const Parameters& params) { + map template_vars; + std::string full_name = descriptor->full_name(); + template_vars["identifier_name"] = MessageIdentifierName(full_name); + template_vars["name"] = full_name; + template_vars["node_name"] = NodeObjectPath(descriptor); + // Print the serializer + out->Print(template_vars, "function serialize_$identifier_name$(arg) {\n"); + out->Indent(); + out->Print(template_vars, "if (!(arg instanceof $node_name$)) {\n"); + out->Indent(); + out->Print(template_vars, + "throw new Error('Expected argument of type $name$');\n"); + out->Outdent(); + out->Print("}\n"); + if (params.minimum_node_version > 5) { + // Node version is > 5, we should use Buffer.from + out->Print("return Buffer.from(arg.serializeBinary());\n"); + } else { + out->Print("return new Buffer(arg.serializeBinary());\n"); + } + out->Outdent(); + out->Print("}\n\n"); + + // Print the deserializer + out->Print(template_vars, + "function deserialize_$identifier_name$(buffer_arg) {\n"); + out->Indent(); + out->Print( + template_vars, + "return $node_name$.deserializeBinary(new Uint8Array(buffer_arg));\n"); + out->Outdent(); + out->Print("}\n\n"); +} + +void PrintMethod(const MethodDescriptor* method, Printer* out) { + const Descriptor* input_type = method->input_type(); + const Descriptor* output_type = method->output_type(); + map vars; + vars["service_name"] = method->service()->full_name(); + vars["name"] = method->name(); + vars["input_type"] = NodeObjectPath(input_type); + vars["input_type_id"] = MessageIdentifierName(input_type->full_name()); + vars["output_type"] = NodeObjectPath(output_type); + vars["output_type_id"] = 
MessageIdentifierName(output_type->full_name()); + vars["client_stream"] = method->client_streaming() ? "true" : "false"; + vars["server_stream"] = method->server_streaming() ? "true" : "false"; + out->Print("{\n"); + out->Indent(); + out->Print(vars, "path: '/$service_name$/$name$',\n"); + out->Print(vars, "requestStream: $client_stream$,\n"); + out->Print(vars, "responseStream: $server_stream$,\n"); + out->Print(vars, "requestType: $input_type$,\n"); + out->Print(vars, "responseType: $output_type$,\n"); + out->Print(vars, "requestSerialize: serialize_$input_type_id$,\n"); + out->Print(vars, "requestDeserialize: deserialize_$input_type_id$,\n"); + out->Print(vars, "responseSerialize: serialize_$output_type_id$,\n"); + out->Print(vars, "responseDeserialize: deserialize_$output_type_id$,\n"); + out->Outdent(); + out->Print("}"); +} + +// Prints out the service descriptor object +void PrintService(const ServiceDescriptor* service, Printer* out) { + map template_vars; + out->Print(GetNodeComments(service, true).c_str()); + template_vars["name"] = service->name(); + out->Print(template_vars, "var $name$Service = exports.$name$Service = {\n"); + out->Indent(); + for (int i = 0; i < service->method_count(); i++) { + std::string method_name = + grpc_generator::LowercaseFirstLetter(service->method(i)->name()); + out->Print(GetNodeComments(service->method(i), true).c_str()); + out->Print("$method_name$: ", "method_name", method_name); + PrintMethod(service->method(i), out); + out->Print(",\n"); + out->Print(GetNodeComments(service->method(i), false).c_str()); + } + out->Outdent(); + out->Print("};\n\n"); + out->Print(template_vars, + "exports.$name$Client = " + "grpc.makeGenericClientConstructor($name$Service);\n"); + out->Print(GetNodeComments(service, false).c_str()); +} + +void PrintImports(const FileDescriptor* file, Printer* out) { + out->Print("var grpc = require('grpc');\n"); + if (file->message_type_count() > 0) { + std::string file_path = + GetRelativePath(file->name(), GetJSMessageFilename(file->name())); + out->Print("var $module_alias$ = require('$file_path$');\n", "module_alias", + ModuleAlias(file->name()), "file_path", file_path); + } + + for (int i = 0; i < file->dependency_count(); i++) { + std::string file_path = GetRelativePath( + file->name(), GetJSMessageFilename(file->dependency(i)->name())); + out->Print("var $module_alias$ = require('$file_path$');\n", "module_alias", + ModuleAlias(file->dependency(i)->name()), "file_path", + file_path); + } + out->Print("\n"); +} + +void PrintTransformers(const FileDescriptor* file, Printer* out, + const Parameters& params) { + map messages = GetAllMessages(file); + for (std::map::iterator it = messages.begin(); + it != messages.end(); it++) { + PrintMessageTransformer(it->second, out, params); + } + out->Print("\n"); +} + +void PrintServices(const FileDescriptor* file, Printer* out) { + for (int i = 0; i < file->service_count(); i++) { + PrintService(file->service(i), out); + } +} +} // namespace + +std::string GenerateFile(const FileDescriptor* file, const Parameters& params) { + std::string output; + { + StringOutputStream output_stream(&output); + Printer out(&output_stream, '$'); + + if (file->service_count() == 0) { + return output; + } + out.Print("// GENERATED CODE -- DO NOT EDIT!\n\n"); + + std::string leading_comments = GetNodeComments(file, true); + if (!leading_comments.empty()) { + out.Print("// Original file comments:\n"); + out.PrintRaw(leading_comments.c_str()); + } + + out.Print("'use strict';\n"); + + PrintImports(file, 
&out); + + PrintTransformers(file, &out, params); + + PrintServices(file, &out); + + out.Print(GetNodeComments(file, false).c_str()); + } + return output; +} + +} // namespace grpc_node_generator diff --git a/src/compiler/node_plugin.cc b/src/compiler/node_plugin.cc new file mode 100644 index 00000000..9334024e --- /dev/null +++ b/src/compiler/node_plugin.cc @@ -0,0 +1,82 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Generates Node gRPC service interface out of Protobuf IDL. + +#include + +#include "src/compiler/config.h" +#include "src/compiler/node_generator.h" +#include "src/compiler/node_generator_helpers.h" + +using grpc_node_generator::GenerateFile; +using grpc_node_generator::GetJSServiceFilename; + +class NodeGrpcGenerator : public grpc::protobuf::compiler::CodeGenerator { + public: + NodeGrpcGenerator() {} + ~NodeGrpcGenerator() {} + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + bool Generate(const grpc::protobuf::FileDescriptor* file, + const std::string& parameter, + grpc::protobuf::compiler::GeneratorContext* context, + std::string* error) const override { + grpc_node_generator::Parameters generator_parameters; + generator_parameters.minimum_node_version = 4; + + if (!parameter.empty()) { + std::vector parameters_list = + grpc_generator::tokenize(parameter, ","); + for (auto parameter_string = parameters_list.begin(); + parameter_string != parameters_list.end(); parameter_string++) { + std::vector param = + grpc_generator::tokenize(*parameter_string, "="); + if (param[0] == "minimum_node_version") { + sscanf(param[1].c_str(), "%d", + &generator_parameters.minimum_node_version); + } else { + *error = std::string("Unknown parameter: ") + *parameter_string; + return false; + } + } + } + + std::string code = GenerateFile(file, generator_parameters); + if (code.size() == 0) { + return true; + } + + // Get output file name + std::string file_name = GetJSServiceFilename(file->name()); + + std::unique_ptr output( + context->Open(file_name)); + grpc::protobuf::io::CodedOutputStream coded_out(output.get()); + coded_out.WriteRaw(code.data(), code.size()); + return true; + } +}; + +int main(int argc, char* argv[]) { + NodeGrpcGenerator generator; + return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/src/compiler/objective_c_generator.cc b/src/compiler/objective_c_generator.cc new file mode 100644 index 00000000..cd05ae78 --- /dev/null +++ b/src/compiler/objective_c_generator.cc @@ -0,0 +1,451 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/compiler/objective_c_generator.h" + +#include +#include +#include + +#include + +#include "src/compiler/config.h" +#include "src/compiler/objective_c_generator_helpers.h" + +using ::google::protobuf::compiler::objectivec::ClassName; +using ::grpc::protobuf::FileDescriptor; +using ::grpc::protobuf::MethodDescriptor; +using ::grpc::protobuf::ServiceDescriptor; +using ::grpc::protobuf::io::Printer; +using ::std::map; +using ::std::set; + +namespace grpc_objective_c_generator { +namespace { + +void PrintProtoRpcDeclarationAsPragma(Printer* printer, + const MethodDescriptor* method, + map< ::std::string, ::std::string> vars) { + vars["client_stream"] = method->client_streaming() ? "stream " : ""; + vars["server_stream"] = method->server_streaming() ? "stream " : ""; + + printer->Print(vars, + "#pragma mark $method_name$($client_stream$$request_type$)" + " returns ($server_stream$$response_type$)\n\n"); +} + +template +static void PrintAllComments(const DescriptorType* desc, Printer* printer, + bool deprecated = false) { + std::vector comments; + grpc_generator::GetComment(desc, grpc_generator::COMMENTTYPE_LEADING_DETACHED, + &comments); + grpc_generator::GetComment(desc, grpc_generator::COMMENTTYPE_LEADING, + &comments); + grpc_generator::GetComment(desc, grpc_generator::COMMENTTYPE_TRAILING, + &comments); + if (comments.empty()) { + return; + } + printer->Print("/**\n"); + for (auto it = comments.begin(); it != comments.end(); ++it) { + printer->Print(" * "); + size_t start_pos = it->find_first_not_of(' '); + if (start_pos != std::string::npos) { + printer->PrintRaw(it->c_str() + start_pos); + } + printer->Print("\n"); + } + if (deprecated) { + printer->Print(" *\n"); + printer->Print( + " * This method belongs to a set of APIs that have been deprecated. " + "Using" + " the v2 API is recommended.\n"); + } + printer->Print(" */\n"); +} + +void PrintMethodSignature(Printer* printer, const MethodDescriptor* method, + const map< ::std::string, ::std::string>& vars) { + // Print comment + PrintAllComments(method, printer, true); + + printer->Print(vars, "- ($return_type$)$method_name$With"); + if (method->client_streaming()) { + printer->Print("RequestsWriter:(GRXWriter *)requestWriter"); + } else { + printer->Print(vars, "Request:($request_class$ *)request"); + } + + // TODO(jcanizales): Put this on a new line and align colons. 
+ if (method->server_streaming()) { + printer->Print(vars, + " eventHandler:(void(^)(BOOL done, " + "$response_class$ *_Nullable response, NSError *_Nullable " + "error))eventHandler"); + } else { + printer->Print(vars, + " handler:(void(^)($response_class$ *_Nullable response, " + "NSError *_Nullable error))handler"); + } +} + +void PrintSimpleSignature(Printer* printer, const MethodDescriptor* method, + map< ::std::string, ::std::string> vars) { + vars["method_name"] = + grpc_generator::LowercaseFirstLetter(vars["method_name"]); + vars["return_type"] = "void"; + PrintMethodSignature(printer, method, vars); +} + +void PrintAdvancedSignature(Printer* printer, const MethodDescriptor* method, + map< ::std::string, ::std::string> vars) { + vars["method_name"] = "RPCTo" + vars["method_name"]; + vars["return_type"] = "GRPCProtoCall *"; + PrintMethodSignature(printer, method, vars); +} + +void PrintV2Signature(Printer* printer, const MethodDescriptor* method, + map< ::std::string, ::std::string> vars) { + if (method->client_streaming()) { + vars["return_type"] = "GRPCStreamingProtoCall *"; + } else { + vars["return_type"] = "GRPCUnaryProtoCall *"; + } + vars["method_name"] = + grpc_generator::LowercaseFirstLetter(vars["method_name"]); + + PrintAllComments(method, printer); + + printer->Print(vars, "- ($return_type$)$method_name$With"); + if (method->client_streaming()) { + printer->Print("ResponseHandler:(id)handler"); + } else { + printer->Print(vars, + "Message:($request_class$ *)message " + "responseHandler:(id)handler"); + } + printer->Print(" callOptions:(GRPCCallOptions *_Nullable)callOptions"); +} + +inline map< ::std::string, ::std::string> GetMethodVars( + const MethodDescriptor* method) { + map< ::std::string, ::std::string> res; + res["method_name"] = method->name(); + res["request_type"] = method->input_type()->name(); + res["response_type"] = method->output_type()->name(); + res["request_class"] = ClassName(method->input_type()); + res["response_class"] = ClassName(method->output_type()); + return res; +} + +void PrintMethodDeclarations(Printer* printer, const MethodDescriptor* method) { + map< ::std::string, ::std::string> vars = GetMethodVars(method); + + PrintProtoRpcDeclarationAsPragma(printer, method, vars); + + PrintSimpleSignature(printer, method, vars); + printer->Print(";\n\n"); + PrintAdvancedSignature(printer, method, vars); + printer->Print(";\n\n\n"); +} + +void PrintV2MethodDeclarations(Printer* printer, + const MethodDescriptor* method) { + map< ::std::string, ::std::string> vars = GetMethodVars(method); + + PrintProtoRpcDeclarationAsPragma(printer, method, vars); + + PrintV2Signature(printer, method, vars); + printer->Print(";\n\n"); +} + +void PrintSimpleImplementation(Printer* printer, const MethodDescriptor* method, + map< ::std::string, ::std::string> vars) { + printer->Print("{\n"); + printer->Print(vars, " [[self RPCTo$method_name$With"); + if (method->client_streaming()) { + printer->Print("RequestsWriter:requestWriter"); + } else { + printer->Print("Request:request"); + } + if (method->server_streaming()) { + printer->Print(" eventHandler:eventHandler] start];\n"); + } else { + printer->Print(" handler:handler] start];\n"); + } + printer->Print("}\n"); +} + +void PrintAdvancedImplementation(Printer* printer, + const MethodDescriptor* method, + map< ::std::string, ::std::string> vars) { + printer->Print("{\n"); + printer->Print(vars, " return [self RPCToMethod:@\"$method_name$\"\n"); + + printer->Print(" requestsWriter:"); + if (method->client_streaming()) { + 
printer->Print("requestWriter\n"); + } else { + printer->Print("[GRXWriter writerWithValue:request]\n"); + } + + printer->Print(vars, " responseClass:[$response_class$ class]\n"); + + printer->Print(" responsesWriteable:[GRXWriteable "); + if (method->server_streaming()) { + printer->Print("writeableWithEventHandler:eventHandler]];\n"); + } else { + printer->Print("writeableWithSingleHandler:handler]];\n"); + } + + printer->Print("}\n"); +} + +void PrintV2Implementation(Printer* printer, const MethodDescriptor* method, + map< ::std::string, ::std::string> vars) { + printer->Print(" {\n"); + if (method->client_streaming()) { + printer->Print(vars, " return [self RPCToMethod:@\"$method_name$\"\n"); + printer->Print(" responseHandler:handler\n"); + printer->Print(" callOptions:callOptions\n"); + printer->Print( + vars, " responseClass:[$response_class$ class]];\n}\n\n"); + } else { + printer->Print(vars, " return [self RPCToMethod:@\"$method_name$\"\n"); + printer->Print(" message:message\n"); + printer->Print(" responseHandler:handler\n"); + printer->Print(" callOptions:callOptions\n"); + printer->Print( + vars, " responseClass:[$response_class$ class]];\n}\n\n"); + } +} + +void PrintMethodImplementations(Printer* printer, + const MethodDescriptor* method, + const Parameters& generator_params) { + map< ::std::string, ::std::string> vars = GetMethodVars(method); + + PrintProtoRpcDeclarationAsPragma(printer, method, vars); + + if (!generator_params.no_v1_compatibility) { + // TODO(jcanizales): Print documentation from the method. + PrintSimpleSignature(printer, method, vars); + PrintSimpleImplementation(printer, method, vars); + + printer->Print("// Returns a not-yet-started RPC object.\n"); + PrintAdvancedSignature(printer, method, vars); + PrintAdvancedImplementation(printer, method, vars); + } + + PrintV2Signature(printer, method, vars); + PrintV2Implementation(printer, method, vars); +} + +} // namespace + +::std::string GetAllMessageClasses(const FileDescriptor* file) { + ::std::string output; + set< ::std::string> classes; + for (int i = 0; i < file->service_count(); i++) { + const auto service = file->service(i); + for (int i = 0; i < service->method_count(); i++) { + const auto method = service->method(i); + classes.insert(ClassName(method->input_type())); + classes.insert(ClassName(method->output_type())); + } + } + for (auto one_class : classes) { + output += "@class " + one_class + ";\n"; + } + + return output; +} + +::std::string GetProtocol(const ServiceDescriptor* service, + const Parameters& generator_params) { + ::std::string output; + + if (generator_params.no_v1_compatibility) return output; + + // Scope the output stream so it closes and finalizes output to the string. + grpc::protobuf::io::StringOutputStream output_stream(&output); + Printer printer(&output_stream, '$'); + + map< ::std::string, ::std::string> vars = { + {"service_class", ServiceClassName(service)}}; + + printer.Print(vars, + "/**\n" + " * The methods in this protocol belong to a set of old APIs " + "that have been deprecated. They do not\n" + " * recognize call options provided in the initializer. 
Using " + "the v2 protocol is recommended.\n" + " */\n"); + printer.Print(vars, "@protocol $service_class$ \n\n"); + for (int i = 0; i < service->method_count(); i++) { + PrintMethodDeclarations(&printer, service->method(i)); + } + printer.Print("@end\n\n"); + + return output; +} + +::std::string GetV2Protocol(const ServiceDescriptor* service) { + ::std::string output; + + // Scope the output stream so it closes and finalizes output to the string. + grpc::protobuf::io::StringOutputStream output_stream(&output); + Printer printer(&output_stream, '$'); + + map< ::std::string, ::std::string> vars = { + {"service_class", ServiceClassName(service) + "2"}}; + + printer.Print(vars, "@protocol $service_class$ \n\n"); + for (int i = 0; i < service->method_count(); i++) { + PrintV2MethodDeclarations(&printer, service->method(i)); + } + printer.Print("@end\n\n"); + + return output; +} + +::std::string GetInterface(const ServiceDescriptor* service, + const Parameters& generator_params) { + ::std::string output; + + // Scope the output stream so it closes and finalizes output to the string. + grpc::protobuf::io::StringOutputStream output_stream(&output); + Printer printer(&output_stream, '$'); + + map< ::std::string, ::std::string> vars = { + {"service_class", ServiceClassName(service)}}; + + printer.Print(vars, + "/**\n" + " * Basic service implementation, over gRPC, that only does\n" + " * marshalling and parsing.\n" + " */\n"); + printer.Print(vars, + "@interface $service_class$ :" + " GRPCProtoService<$service_class$2"); + if (!generator_params.no_v1_compatibility) { + printer.Print(vars, ", $service_class$"); + } + printer.Print(">\n"); + printer.Print( + "- (instancetype)initWithHost:(NSString *)host " + "callOptions:(GRPCCallOptions " + "*_Nullable)callOptions" + " NS_DESIGNATED_INITIALIZER;\n"); + printer.Print( + "+ (instancetype)serviceWithHost:(NSString *)host " + "callOptions:(GRPCCallOptions *_Nullable)callOptions;\n"); + if (!generator_params.no_v1_compatibility) { + printer.Print( + "// The following methods belong to a set of old APIs that have been " + "deprecated.\n"); + printer.Print("- (instancetype)initWithHost:(NSString *)host;\n"); + printer.Print("+ (instancetype)serviceWithHost:(NSString *)host;\n"); + } + printer.Print("@end\n"); + + return output; +} + +::std::string GetSource(const ServiceDescriptor* service, + const Parameters& generator_params) { + ::std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. 
+ grpc::protobuf::io::StringOutputStream output_stream(&output); + Printer printer(&output_stream, '$'); + + map< ::std::string, ::std::string> vars = { + {"service_name", service->name()}, + {"service_class", ServiceClassName(service)}, + {"package", service->file()->package()}}; + + printer.Print(vars, + "@implementation $service_class$\n\n" + "#pragma clang diagnostic push\n" + "#pragma clang diagnostic ignored " + "\"-Wobjc-designated-initializers\"\n\n" + "// Designated initializer\n" + "- (instancetype)initWithHost:(NSString *)host " + "callOptions:(GRPCCallOptions *_Nullable)callOptions {\n" + " return [super initWithHost:host\n" + " packageName:@\"$package$\"\n" + " serviceName:@\"$service_name$\"\n" + " callOptions:callOptions];\n" + "}\n\n"); + if (!generator_params.no_v1_compatibility) { + printer.Print(vars, + "- (instancetype)initWithHost:(NSString *)host {\n" + " return [super initWithHost:host\n" + " packageName:@\"$package$\"\n" + " serviceName:@\"$service_name$\"];\n" + "}\n\n"); + } + printer.Print("#pragma clang diagnostic pop\n\n"); + + if (!generator_params.no_v1_compatibility) { + printer.Print( + "// Override superclass initializer to disallow different" + " package and service names.\n" + "- (instancetype)initWithHost:(NSString *)host\n" + " packageName:(NSString *)packageName\n" + " serviceName:(NSString *)serviceName {\n" + " return [self initWithHost:host];\n" + "}\n\n"); + } + printer.Print( + "- (instancetype)initWithHost:(NSString *)host\n" + " packageName:(NSString *)packageName\n" + " serviceName:(NSString *)serviceName\n" + " callOptions:(GRPCCallOptions *)callOptions {\n" + " return [self initWithHost:host callOptions:callOptions];\n" + "}\n\n"); + + printer.Print("#pragma mark - Class Methods\n\n"); + if (!generator_params.no_v1_compatibility) { + printer.Print( + "+ (instancetype)serviceWithHost:(NSString *)host {\n" + " return [[self alloc] initWithHost:host];\n" + "}\n\n"); + } + printer.Print( + "+ (instancetype)serviceWithHost:(NSString *)host " + "callOptions:(GRPCCallOptions *_Nullable)callOptions {\n" + " return [[self alloc] initWithHost:host callOptions:callOptions];\n" + "}\n\n"); + + printer.Print("#pragma mark - Method Implementations\n\n"); + + for (int i = 0; i < service->method_count(); i++) { + PrintMethodImplementations(&printer, service->method(i), + generator_params); + } + + printer.Print("@end\n"); + } + return output; +} + +} // namespace grpc_objective_c_generator diff --git a/src/compiler/objective_c_plugin.cc b/src/compiler/objective_c_plugin.cc new file mode 100644 index 00000000..8cc3fdfe --- /dev/null +++ b/src/compiler/objective_c_plugin.cc @@ -0,0 +1,317 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Generates Objective C gRPC service interface out of Protobuf IDL. 
+ +#include + +#include + +#include "src/compiler/config.h" +#include "src/compiler/objective_c_generator.h" +#include "src/compiler/objective_c_generator_helpers.h" + +using ::google::protobuf::compiler::objectivec:: + IsProtobufLibraryBundledProtoFile; +using ::google::protobuf::compiler::objectivec::ProtobufLibraryFrameworkName; +using ::grpc_objective_c_generator::FrameworkImport; +using ::grpc_objective_c_generator::LocalImport; +using ::grpc_objective_c_generator::PreprocIfElse; +using ::grpc_objective_c_generator::PreprocIfNot; +using ::grpc_objective_c_generator::SystemImport; + +namespace { + +inline ::std::string ImportProtoHeaders( + const grpc::protobuf::FileDescriptor* dep, const char* indent, + const ::std::string& framework, + const ::std::string& pb_runtime_import_prefix) { + ::std::string header = grpc_objective_c_generator::MessageHeaderName(dep); + + if (!IsProtobufLibraryBundledProtoFile(dep)) { + if (framework.empty()) { + return indent + LocalImport(header); + } else { + return indent + FrameworkImport(header, framework); + } + } + + ::std::string base_name = header; + grpc_generator::StripPrefix(&base_name, "google/protobuf/"); + ::std::string file_name = "GPB" + base_name; + // create the import code snippet + ::std::string framework_header = + ::std::string(ProtobufLibraryFrameworkName) + "/" + file_name; + ::std::string local_header = file_name; + if (!pb_runtime_import_prefix.empty()) { + local_header = pb_runtime_import_prefix + "/" + file_name; + } + + static const ::std::string kFrameworkImportsCondition = + "GPB_USE_PROTOBUF_FRAMEWORK_IMPORTS"; + return PreprocIfElse(kFrameworkImportsCondition, + indent + SystemImport(framework_header), + indent + LocalImport(local_header)); +} + +} // namespace + +class ObjectiveCGrpcGenerator : public grpc::protobuf::compiler::CodeGenerator { + public: + ObjectiveCGrpcGenerator() {} + virtual ~ObjectiveCGrpcGenerator() {} + + public: + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + virtual bool Generate(const grpc::protobuf::FileDescriptor* file, + const ::std::string& parameter, + grpc::protobuf::compiler::GeneratorContext* context, + ::std::string* error) const override { + if (file->service_count() == 0) { + // No services. Do nothing. 
+ return true; + } + + bool grpc_local_import = false; + ::std::string framework; + ::std::string pb_runtime_import_prefix; + ::std::string grpc_local_import_prefix; + std::vector<::std::string> params_list = + grpc_generator::tokenize(parameter, ","); + for (auto param_str = params_list.begin(); param_str != params_list.end(); + ++param_str) { + std::vector<::std::string> param = + grpc_generator::tokenize(*param_str, "="); + if (param[0] == "generate_for_named_framework") { + if (param.size() != 2) { + *error = + std::string("Format: generate_for_named_framework="); + return false; + } else if (param[1].empty()) { + *error = + std::string("Name of framework cannot be empty for parameter: ") + + param[0]; + return false; + } + framework = param[1]; + } else if (param[0] == "runtime_import_prefix") { + if (param.size() != 2) { + *error = grpc::string("Format: runtime_import_prefix=dir/"); + return false; + } + pb_runtime_import_prefix = param[1]; + grpc_generator::StripSuffix(&pb_runtime_import_prefix, "/"); + } else if (param[0] == "grpc_local_import_prefix") { + grpc_local_import = true; + if (param.size() != 2) { + *error = grpc::string("Format: grpc_local_import_prefix=dir/"); + return false; + } + grpc_local_import_prefix = param[1]; + } + } + + static const ::std::string kNonNullBegin = "NS_ASSUME_NONNULL_BEGIN\n"; + static const ::std::string kNonNullEnd = "NS_ASSUME_NONNULL_END\n"; + static const ::std::string kProtocolOnly = "GPB_GRPC_PROTOCOL_ONLY"; + static const ::std::string kForwardDeclare = + "GPB_GRPC_FORWARD_DECLARE_MESSAGE_PROTO"; + + ::std::string file_name = + google::protobuf::compiler::objectivec::FilePath(file); + + grpc_objective_c_generator::Parameters generator_params; + generator_params.no_v1_compatibility = false; + + if (!parameter.empty()) { + std::vector parameters_list = + grpc_generator::tokenize(parameter, ","); + for (auto parameter_string = parameters_list.begin(); + parameter_string != parameters_list.end(); parameter_string++) { + std::vector param = + grpc_generator::tokenize(*parameter_string, "="); + if (param[0] == "no_v1_compatibility") { + generator_params.no_v1_compatibility = true; + } + } + } + + // Write out a file header. + ::std::string file_header = + "// Code generated by gRPC proto compiler. 
DO NOT EDIT!\n" + "// source: " + + file->name() + "\n\n"; + + { + // Generate .pbrpc.h + + ::std::string imports; + if (framework.empty()) { + imports = LocalImport(file_name + ".pbobjc.h"); + } else { + imports = FrameworkImport(file_name + ".pbobjc.h", framework); + } + + ::std::string system_imports; + if (grpc_local_import) { + system_imports = + LocalImport(grpc_local_import_prefix + "ProtoRPC/ProtoService.h"); + if (generator_params.no_v1_compatibility) { + system_imports += + LocalImport(grpc_local_import_prefix + "ProtoRPC/ProtoRPC.h"); + } else { + system_imports += LocalImport(grpc_local_import_prefix + + "ProtoRPC/ProtoRPCLegacy.h"); + system_imports += LocalImport(grpc_local_import_prefix + + "RxLibrary/GRXWriteable.h"); + system_imports += + LocalImport(grpc_local_import_prefix + "RxLibrary/GRXWriter.h"); + } + } else { + system_imports = SystemImport("ProtoRPC/ProtoService.h"); + if (generator_params.no_v1_compatibility) { + system_imports += SystemImport("ProtoRPC/ProtoRPC.h"); + } else { + system_imports += SystemImport("ProtoRPC/ProtoRPCLegacy.h"); + system_imports += SystemImport("RxLibrary/GRXWriteable.h"); + system_imports += SystemImport("RxLibrary/GRXWriter.h"); + } + } + + ::std::string forward_declarations = + "@class GRPCUnaryProtoCall;\n" + "@class GRPCStreamingProtoCall;\n" + "@class GRPCCallOptions;\n" + "@protocol GRPCProtoResponseHandler;\n"; + if (!generator_params.no_v1_compatibility) { + forward_declarations += "@class GRPCProtoCall;\n"; + } + forward_declarations += "\n"; + + ::std::string class_declarations = + grpc_objective_c_generator::GetAllMessageClasses(file); + + ::std::string class_imports; + for (int i = 0; i < file->dependency_count(); i++) { + class_imports += ImportProtoHeaders( + file->dependency(i), " ", framework, pb_runtime_import_prefix); + } + + ::std::string ng_protocols; + for (int i = 0; i < file->service_count(); i++) { + const grpc::protobuf::ServiceDescriptor* service = file->service(i); + ng_protocols += grpc_objective_c_generator::GetV2Protocol(service); + } + + ::std::string protocols; + for (int i = 0; i < file->service_count(); i++) { + const grpc::protobuf::ServiceDescriptor* service = file->service(i); + protocols += + grpc_objective_c_generator::GetProtocol(service, generator_params); + } + + ::std::string interfaces; + for (int i = 0; i < file->service_count(); i++) { + const grpc::protobuf::ServiceDescriptor* service = file->service(i); + interfaces += + grpc_objective_c_generator::GetInterface(service, generator_params); + } + + Write(context, file_name + ".pbrpc.h", + file_header + SystemImport("Foundation/Foundation.h") + "\n" + + PreprocIfNot(kForwardDeclare, imports) + "\n" + + PreprocIfNot(kProtocolOnly, system_imports) + "\n" + + class_declarations + "\n" + + PreprocIfNot(kForwardDeclare, class_imports) + "\n" + + forward_declarations + "\n" + kNonNullBegin + "\n" + + ng_protocols + protocols + "\n" + + PreprocIfNot(kProtocolOnly, interfaces) + "\n" + kNonNullEnd + + "\n"); + } + + { + // Generate .pbrpc.m + + ::std::string imports; + if (framework.empty()) { + imports = LocalImport(file_name + ".pbrpc.h") + + LocalImport(file_name + ".pbobjc.h"); + } else { + imports = FrameworkImport(file_name + ".pbrpc.h", framework) + + FrameworkImport(file_name + ".pbobjc.h", framework); + } + + if (grpc_local_import) { + if (generator_params.no_v1_compatibility) { + imports += + LocalImport(grpc_local_import_prefix + "ProtoRPC/ProtoRPC.h"); + } else { + imports += LocalImport(grpc_local_import_prefix + + 
"ProtoRPC/ProtoRPCLegacy.h"); + imports += LocalImport(grpc_local_import_prefix + + "RxLibrary/GRXWriter+Immediate.h"); + } + } else { + if (generator_params.no_v1_compatibility) { + imports += SystemImport("ProtoRPC/ProtoRPC.h"); + } else { + imports += SystemImport("ProtoRPC/ProtoRPCLegacy.h"); + imports += SystemImport("RxLibrary/GRXWriter+Immediate.h"); + } + } + + ::std::string class_imports; + for (int i = 0; i < file->dependency_count(); i++) { + class_imports += ImportProtoHeaders(file->dependency(i), "", framework, + pb_runtime_import_prefix); + } + + ::std::string definitions; + for (int i = 0; i < file->service_count(); i++) { + const grpc::protobuf::ServiceDescriptor* service = file->service(i); + definitions += + grpc_objective_c_generator::GetSource(service, generator_params); + } + + Write(context, file_name + ".pbrpc.m", + file_header + + PreprocIfNot(kProtocolOnly, imports + "\n" + class_imports + + "\n" + definitions)); + } + + return true; + } + + private: + // Write the given code into the given file. + void Write(grpc::protobuf::compiler::GeneratorContext* context, + const ::std::string& filename, const ::std::string& code) const { + std::unique_ptr output( + context->Open(filename)); + grpc::protobuf::io::CodedOutputStream coded_out(output.get()); + coded_out.WriteRaw(code.data(), code.size()); + } +}; + +int main(int argc, char* argv[]) { + ObjectiveCGrpcGenerator generator; + return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/src/compiler/php_generator.cc b/src/compiler/php_generator.cc new file mode 100644 index 00000000..78d95419 --- /dev/null +++ b/src/compiler/php_generator.cc @@ -0,0 +1,340 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/compiler/config.h" +#include "src/compiler/generator_helpers.h" +#include "src/compiler/php_generator_helpers.h" + +using google::protobuf::compiler::php::GeneratedClassName; +using grpc::protobuf::Descriptor; +using grpc::protobuf::FileDescriptor; +using grpc::protobuf::MethodDescriptor; +using grpc::protobuf::ServiceDescriptor; +using grpc::protobuf::io::Printer; +using grpc::protobuf::io::StringOutputStream; +using std::map; + +namespace grpc_php_generator { +namespace { + +std::string ConvertToPhpNamespace(const std::string& name) { + std::vector tokens = grpc_generator::tokenize(name, "."); + std::ostringstream oss; + for (unsigned int i = 0; i < tokens.size(); i++) { + oss << (i == 0 ? 
"" : "\\") + << grpc_generator::CapitalizeFirstLetter(tokens[i]); + } + return oss.str(); +} + +std::string PackageName(const FileDescriptor* file) { + if (file->options().has_php_namespace()) { + return file->options().php_namespace(); + } else { + return ConvertToPhpNamespace(file->package()); + } +} + +std::string MessageIdentifierName(const std::string& name, + const FileDescriptor* file) { + std::vector tokens = grpc_generator::tokenize(name, "."); + std::ostringstream oss; + if (PackageName(file) != "") { + oss << PackageName(file) << "\\"; + } + oss << grpc_generator::CapitalizeFirstLetter(tokens[tokens.size() - 1]); + return oss.str(); +} + +void PrintMethod(const MethodDescriptor* method, Printer* out) { + const Descriptor* input_type = method->input_type(); + const Descriptor* output_type = method->output_type(); + map vars; + vars["service_name"] = method->service()->full_name(); + vars["name"] = method->name(); + vars["input_type_id"] = + MessageIdentifierName(GeneratedClassName(input_type), input_type->file()); + vars["output_type_id"] = MessageIdentifierName( + GeneratedClassName(output_type), output_type->file()); + + out->Print("/**\n"); + out->Print(GetPHPComments(method, " *").c_str()); + if (method->client_streaming()) { + if (method->server_streaming()) { + vars["return_type_id"] = "\\Grpc\\BidiStreamingCall"; + } else { + vars["return_type_id"] = "\\Grpc\\ClientStreamingCall"; + } + out->Print(vars, + " * @param array $$metadata metadata\n" + " * @param array $$options call options\n" + " * @return $return_type_id$\n */\n" + "public function $name$($$metadata = [], " + "$$options = []) {\n"); + out->Indent(); + out->Indent(); + if (method->server_streaming()) { + out->Print("return $$this->_bidiRequest("); + } else { + out->Print("return $$this->_clientStreamRequest("); + } + out->Print(vars, + "'/$service_name$/$name$',\n" + "['\\$output_type_id$','decode'],\n" + "$$metadata, $$options);\n"); + } else { + if (method->server_streaming()) { + vars["return_type_id"] = "\\Grpc\\ServerStreamingCall"; + } else { + vars["return_type_id"] = "\\Grpc\\UnaryCall"; + } + out->Print(vars, + " * @param \\$input_type_id$ $$argument input argument\n" + " * @param array $$metadata metadata\n" + " * @param array $$options call options\n" + " * @return $return_type_id$\n */\n" + "public function $name$(\\$input_type_id$ $$argument,\n" + " $$metadata = [], $$options = []) {\n"); + out->Indent(); + out->Indent(); + if (method->server_streaming()) { + out->Print("return $$this->_serverStreamRequest("); + } else { + out->Print("return $$this->_simpleRequest("); + } + out->Print(vars, + "'/$service_name$/$name$',\n" + "$$argument,\n" + "['\\$output_type_id$', 'decode'],\n" + "$$metadata, $$options);\n"); + } + out->Outdent(); + out->Outdent(); + out->Print("}\n\n"); +} + +void PrintServerMethod(const MethodDescriptor* method, Printer* out) { + map vars; + const Descriptor* input_type = method->input_type(); + const Descriptor* output_type = method->output_type(); + vars["service_name"] = method->service()->full_name(); + vars["method_name"] = method->name(); + vars["input_type_id"] = + MessageIdentifierName(GeneratedClassName(input_type), input_type->file()); + vars["output_type_id"] = MessageIdentifierName( + GeneratedClassName(output_type), output_type->file()); + + out->Print("/**\n"); + out->Print(GetPHPComments(method, " *").c_str()); + + const char* method_template; + if (method->client_streaming() && method->server_streaming()) { + method_template = + " * @param 
\\Grpc\\ServerCallReader $$reader read client request data " + "of \\$input_type_id$\n" + " * @param \\Grpc\\ServerCallWriter $$writer write response data of " + "\\$output_type_id$\n" + " * @param \\Grpc\\ServerContext $$context server request context\n" + " * @return void\n" + " */\n" + "public function $method_name$(\n" + " \\Grpc\\ServerCallReader $$reader,\n" + " \\Grpc\\ServerCallWriter $$writer,\n" + " \\Grpc\\ServerContext $$context\n" + "): void {\n" + " $$context->setStatus(\\Grpc\\Status::unimplemented());\n" + " $$writer->finish();\n" + "}\n\n"; + } else if (method->client_streaming()) { + method_template = + " * @param \\Grpc\\ServerCallReader $$reader read client request data " + "of \\$input_type_id$\n" + " * @param \\Grpc\\ServerContext $$context server request context\n" + " * @return \\$output_type_id$ for response data, null if an error " + "occurred\n" + " * initial metadata (if any) and status (if not ok) should be set " + "to $$context\n" + " */\n" + "public function $method_name$(\n" + " \\Grpc\\ServerCallReader $$reader,\n" + " \\Grpc\\ServerContext $$context\n" + "): ?\\$output_type_id$ {\n" + " $$context->setStatus(\\Grpc\\Status::unimplemented());\n" + " return null;\n" + "}\n\n"; + } else if (method->server_streaming()) { + method_template = + " * @param \\$input_type_id$ $$request client request\n" + " * @param \\Grpc\\ServerCallWriter $$writer write response data of " + "\\$output_type_id$\n" + " * @param \\Grpc\\ServerContext $$context server request context\n" + " * @return void\n" + " */\n" + "public function $method_name$(\n" + " \\$input_type_id$ $$request,\n" + " \\Grpc\\ServerCallWriter $$writer,\n" + " \\Grpc\\ServerContext $$context\n" + "): void {\n" + " $$context->setStatus(\\Grpc\\Status::unimplemented());\n" + " $$writer->finish();\n" + "}\n\n"; + } else { + method_template = + " * @param \\$input_type_id$ $$request client request\n" + " * @param \\Grpc\\ServerContext $$context server request context\n" + " * @return \\$output_type_id$ for response data, null if an error " + "occurred\n" + " * initial metadata (if any) and status (if not ok) should be set " + "to $$context\n" + " */\n" + "public function $method_name$(\n" + " \\$input_type_id$ $$request,\n" + " \\Grpc\\ServerContext $$context\n" + "): ?\\$output_type_id$ {\n" + " $$context->setStatus(\\Grpc\\Status::unimplemented());\n" + " return null;\n" + "}\n\n"; + } + out->Print(vars, method_template); +} + +void PrintServerMethodDescriptors(const ServiceDescriptor* service, + Printer* out) { + map<std::string, std::string> vars; + vars["service_name"] = service->full_name(); + + out->Print( + "/**\n" + " * Get the method descriptors of the service for server registration\n" + " *\n" + " * @return array of \\Grpc\\MethodDescriptor for the service methods\n" + " */\n" + "public final function getMethodDescriptors(): array\n{\n"); + out->Indent(); + out->Indent(); + out->Print("return [\n"); + out->Indent(); + out->Indent(); + for (int i = 0; i < service->method_count(); i++) { + auto method = service->method(i); + auto input_type = method->input_type(); + vars["method_name"] = method->name(); + vars["input_type_id"] = MessageIdentifierName( + GeneratedClassName(input_type), input_type->file()); + if (method->client_streaming() && method->server_streaming()) { + vars["call_type"] = "BIDI_STREAMING_CALL"; + } else if (method->client_streaming()) { + vars["call_type"] = "CLIENT_STREAMING_CALL"; + } else if (method->server_streaming()) { + vars["call_type"] = "SERVER_STREAMING_CALL"; + } else { + vars["call_type"] = 
"UNARY_CALL"; + } + out->Print( + vars, + "'/$service_name$/$method_name$' => new \\Grpc\\MethodDescriptor(\n" + " $$this,\n" + " '$method_name$',\n" + " '\\$input_type_id$',\n" + " \\Grpc\\MethodDescriptor::$call_type$\n" + "),\n"); + } + out->Outdent(); + out->Outdent(); + out->Print("];\n"); + out->Outdent(); + out->Outdent(); + out->Print("}\n\n"); +} + +// Prints out the service descriptor object +void PrintService(const ServiceDescriptor* service, + const std::string& class_suffix, bool is_server, + Printer* out) { + map vars; + out->Print("/**\n"); + out->Print(GetPHPComments(service, " *").c_str()); + out->Print(" */\n"); + vars["name"] = GetPHPServiceClassname(service, class_suffix, is_server); + vars["extends"] = is_server ? "" : "extends \\Grpc\\BaseStub "; + out->Print(vars, "class $name$ $extends${\n\n"); + out->Indent(); + out->Indent(); + if (!is_server) { + out->Print( + "/**\n * @param string $$hostname hostname\n" + " * @param array $$opts channel options\n" + " * @param \\Grpc\\Channel $$channel (optional) re-use channel object\n" + " */\n" + "public function __construct($$hostname, $$opts, " + "$$channel = null) {\n"); + out->Indent(); + out->Indent(); + out->Print("parent::__construct($$hostname, $$opts, $$channel);\n"); + out->Outdent(); + out->Outdent(); + out->Print("}\n\n"); + } + for (int i = 0; i < service->method_count(); i++) { + if (is_server) { + PrintServerMethod(service->method(i), out); + } else { + PrintMethod(service->method(i), out); + } + } + if (is_server) { + PrintServerMethodDescriptors(service, out); + } + out->Outdent(); + out->Outdent(); + out->Print("}\n"); +} +} // namespace + +std::string GenerateFile(const FileDescriptor* file, + const ServiceDescriptor* service, + const std::string& class_suffix, bool is_server) { + std::string output; + { + StringOutputStream output_stream(&output); + Printer out(&output_stream, '$'); + + out.Print(" vars; + std::string php_namespace = PackageName(file); + vars["package"] = php_namespace; + out.Print(vars, "namespace $package$;\n\n"); + + PrintService(service, class_suffix, is_server, &out); + } + return output; +} + +} // namespace grpc_php_generator diff --git a/src/compiler/php_plugin.cc b/src/compiler/php_plugin.cc new file mode 100644 index 00000000..7d4e4ce3 --- /dev/null +++ b/src/compiler/php_plugin.cc @@ -0,0 +1,96 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Generates PHP gRPC service interface out of Protobuf IDL. 
+ +#include <memory> + +#include "src/compiler/config.h" +#include "src/compiler/php_generator.h" +#include "src/compiler/php_generator_helpers.h" + +using google::protobuf::compiler::ParseGeneratorParameter; +using grpc_php_generator::GenerateFile; +using grpc_php_generator::GetPHPServiceFilename; + +class PHPGrpcGenerator : public grpc::protobuf::compiler::CodeGenerator { + public: + PHPGrpcGenerator() {} + ~PHPGrpcGenerator() {} + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + bool Generate(const grpc::protobuf::FileDescriptor* file, + const std::string& parameter, + grpc::protobuf::compiler::GeneratorContext* context, + std::string* error) const override { + if (file->service_count() == 0) { + return true; + } + + std::vector<std::pair<std::string, std::string> > options; + ParseGeneratorParameter(parameter, &options); + + bool generate_server = false; + std::string class_suffix; + for (size_t i = 0; i < options.size(); ++i) { + if (options[i].first == "class_suffix") { + class_suffix = options[i].second; + } else if (options[i].first == "generate_server") { + generate_server = true; + } else { + *error = "unsupported options: " + options[i].first; + return false; + } + } + + for (int i = 0; i < file->service_count(); i++) { + GenerateService(file, file->service(i), class_suffix, false, context); + if (generate_server) { + GenerateService(file, file->service(i), class_suffix, true, context); + } + } + + return true; + } + + private: + void GenerateService( + const grpc::protobuf::FileDescriptor* file, + const grpc::protobuf::ServiceDescriptor* service, + const std::string& class_suffix, bool is_server, + grpc::protobuf::compiler::GeneratorContext* context) const { + std::string code = GenerateFile(file, service, class_suffix, is_server); + + // Get output file name + std::string file_name = + GetPHPServiceFilename(file, service, class_suffix, is_server); + + std::unique_ptr<grpc::protobuf::io::ZeroCopyOutputStream> output( + context->Open(file_name)); + grpc::protobuf::io::CodedOutputStream coded_out(output.get()); + coded_out.WriteRaw(code.data(), code.size()); + } +}; + +int main(int argc, char* argv[]) { + PHPGrpcGenerator generator; + return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/src/compiler/python_generator.cc b/src/compiler/python_generator.cc new file mode 100644 index 00000000..753fe1c8 --- /dev/null +++ b/src/compiler/python_generator.cc @@ -0,0 +1,925 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/compiler/python_generator.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/compiler/config.h" +#include "src/compiler/generator_helpers.h" +#include "src/compiler/protobuf_plugin.h" +#include "src/compiler/python_generator_helpers.h" +#include "src/compiler/python_private_generator.h" + +using grpc::protobuf::FileDescriptor; +using grpc::protobuf::compiler::GeneratorContext; +using grpc::protobuf::io::CodedOutputStream; +using grpc::protobuf::io::ZeroCopyOutputStream; +using std::make_pair; +using std::map; +using std::pair; +using std::replace; +using std::set; +using std::tuple; +using std::vector; + +namespace grpc_python_generator { + +std::string generator_file_name; + +namespace { + +typedef map StringMap; +typedef vector StringVector; +typedef tuple StringPair; +typedef set StringPairSet; + +// Provides RAII indentation handling. Use as: +// { +// IndentScope raii_my_indent_var_name_here(my_py_printer); +// // constructor indented my_py_printer +// ... +// // destructor called at end of scope, un-indenting my_py_printer +// } +class IndentScope { + public: + explicit IndentScope(grpc_generator::Printer* printer) : printer_(printer) { + // NOTE(rbellevi): Two-space tabs are hard-coded in the protocol compiler. + // Doubling our indents and outdents guarantees compliance with PEP8. + printer_->Indent(); + printer_->Indent(); + } + + ~IndentScope() { + printer_->Outdent(); + printer_->Outdent(); + } + + private: + grpc_generator::Printer* printer_; +}; + +PrivateGenerator::PrivateGenerator(const GeneratorConfiguration& config, + const grpc_generator::File* file) + : config(config), file(file) {} + +void PrivateGenerator::PrintAllComments(StringVector comments, + grpc_generator::Printer* out) { + if (comments.empty()) { + // Python requires code structures like class and def to have + // a body, even if it is just "pass" or a docstring. We need + // to ensure not to generate empty bodies. We could do something + // smarter and more sophisticated, but at the moment, if there is + // no docstring to print, we simply emit "pass" to ensure validity + // of the generated code. + out->Print( + "\"\"\"Missing associated documentation comment in .proto " + "file.\"\"\"\n"); + return; + } + out->Print("\"\"\""); + for (StringVector::iterator it = comments.begin(); it != comments.end(); + ++it) { + size_t start_pos = it->find_first_not_of(' '); + if (start_pos != std::string::npos) { + out->PrintRaw(it->c_str() + start_pos); + } + out->Print("\n"); + } + out->Print("\"\"\"\n"); +} + +bool PrivateGenerator::PrintBetaServicer(const grpc_generator::Service* service, + grpc_generator::Printer* out) { + StringMap service_dict; + service_dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(service_dict, "class Beta$Service$Servicer(object):\n"); + { + IndentScope raii_class_indent(out); + out->Print( + "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n" + "\nIt is recommended to use the GA API (classes and functions in this\n" + "file not marked beta) for all further purposes. This class was " + "generated\n" + "only to ease transition from grpcio<0.15.0 to " + "grpcio>=0.15.0.\"\"\"\n"); + StringVector service_comments = service->GetAllComments(); + PrintAllComments(service_comments, out); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + std::string arg_name = + method->ClientStreaming() ? 
"request_iterator" : "request"; + StringMap method_dict; + method_dict["Method"] = method->name(); + method_dict["ArgName"] = arg_name; + out->Print(method_dict, "def $Method$(self, $ArgName$, context):\n"); + { + IndentScope raii_method_indent(out); + StringVector method_comments = method->GetAllComments(); + PrintAllComments(method_comments, out); + out->Print("context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)\n"); + } + } + } + return true; +} + +bool PrivateGenerator::PrintBetaStub(const grpc_generator::Service* service, + grpc_generator::Printer* out) { + StringMap service_dict; + service_dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(service_dict, "class Beta$Service$Stub(object):\n"); + { + IndentScope raii_class_indent(out); + out->Print( + "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n" + "\nIt is recommended to use the GA API (classes and functions in this\n" + "file not marked beta) for all further purposes. This class was " + "generated\n" + "only to ease transition from grpcio<0.15.0 to " + "grpcio>=0.15.0.\"\"\"\n"); + StringVector service_comments = service->GetAllComments(); + PrintAllComments(service_comments, out); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + std::string arg_name = + method->ClientStreaming() ? "request_iterator" : "request"; + StringMap method_dict; + method_dict["Method"] = method->name(); + method_dict["ArgName"] = arg_name; + out->Print(method_dict, + "def $Method$(self, $ArgName$, timeout, metadata=None, " + "with_call=False, protocol_options=None):\n"); + { + IndentScope raii_method_indent(out); + StringVector method_comments = method->GetAllComments(); + PrintAllComments(method_comments, out); + out->Print("raise NotImplementedError()\n"); + } + if (!method->ServerStreaming()) { + out->Print(method_dict, "$Method$.future = None\n"); + } + } + } + return true; +} + +bool PrivateGenerator::PrintBetaServerFactory( + const std::string& package_qualified_service_name, + const grpc_generator::Service* service, grpc_generator::Printer* out) { + StringMap service_dict; + service_dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(service_dict, + "def beta_create_$Service$_server(servicer, pool=None, " + "pool_size=None, default_timeout=None, maximum_timeout=None):\n"); + { + IndentScope raii_create_server_indent(out); + out->Print( + "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n" + "\nIt is recommended to use the GA API (classes and functions in this\n" + "file not marked beta) for all further purposes. This function was\n" + "generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0" + "\"\"\"\n"); + StringMap method_implementation_constructors; + StringMap input_message_modules_and_classes; + StringMap output_message_modules_and_classes; + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + const std::string method_implementation_constructor = + std::string(method->ClientStreaming() ? "stream_" : "unary_") + + std::string(method->ServerStreaming() ? 
"stream_" : "unary_") + + "inline"; + std::string input_message_module_and_class; + if (!method->get_module_and_message_path_input( + &input_message_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + std::string output_message_module_and_class; + if (!method->get_module_and_message_path_output( + &output_message_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + method_implementation_constructors.insert( + make_pair(method->name(), method_implementation_constructor)); + input_message_modules_and_classes.insert( + make_pair(method->name(), input_message_module_and_class)); + output_message_modules_and_classes.insert( + make_pair(method->name(), output_message_module_and_class)); + } + StringMap method_dict; + method_dict["PackageQualifiedServiceName"] = package_qualified_service_name; + out->Print("request_deserializers = {\n"); + for (StringMap::iterator name_and_input_module_class_pair = + input_message_modules_and_classes.begin(); + name_and_input_module_class_pair != + input_message_modules_and_classes.end(); + name_and_input_module_class_pair++) { + method_dict["MethodName"] = name_and_input_module_class_pair->first; + method_dict["InputTypeModuleAndClass"] = + name_and_input_module_class_pair->second; + IndentScope raii_indent(out); + out->Print(method_dict, + "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$InputTypeModuleAndClass$.FromString,\n"); + } + out->Print("}\n"); + out->Print("response_serializers = {\n"); + for (StringMap::iterator name_and_output_module_class_pair = + output_message_modules_and_classes.begin(); + name_and_output_module_class_pair != + output_message_modules_and_classes.end(); + name_and_output_module_class_pair++) { + method_dict["MethodName"] = name_and_output_module_class_pair->first; + method_dict["OutputTypeModuleAndClass"] = + name_and_output_module_class_pair->second; + IndentScope raii_indent(out); + out->Print(method_dict, + "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$OutputTypeModuleAndClass$.SerializeToString,\n"); + } + out->Print("}\n"); + out->Print("method_implementations = {\n"); + for (StringMap::iterator name_and_implementation_constructor = + method_implementation_constructors.begin(); + name_and_implementation_constructor != + method_implementation_constructors.end(); + name_and_implementation_constructor++) { + method_dict["Method"] = name_and_implementation_constructor->first; + method_dict["Constructor"] = name_and_implementation_constructor->second; + IndentScope raii_descriptions_indent(out); + const std::string method_name = + name_and_implementation_constructor->first; + out->Print(method_dict, + "(\'$PackageQualifiedServiceName$\', \'$Method$\'): " + "face_utilities.$Constructor$(servicer.$Method$),\n"); + } + out->Print("}\n"); + out->Print( + "server_options = beta_implementations.server_options(" + "request_deserializers=request_deserializers, " + "response_serializers=response_serializers, " + "thread_pool=pool, thread_pool_size=pool_size, " + "default_timeout=default_timeout, " + "maximum_timeout=maximum_timeout)\n"); + out->Print( + "return beta_implementations.server(method_implementations, " + "options=server_options)\n"); + } + return true; +} + +bool PrivateGenerator::PrintBetaStubFactory( + const std::string& package_qualified_service_name, + const grpc_generator::Service* service, grpc_generator::Printer* out) { + 
StringMap dict; + dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(dict, + "def beta_create_$Service$_stub(channel, host=None," + " metadata_transformer=None, pool=None, pool_size=None):\n"); + { + IndentScope raii_create_server_indent(out); + out->Print( + "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n" + "\nIt is recommended to use the GA API (classes and functions in this\n" + "file not marked beta) for all further purposes. This function was\n" + "generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0" + "\"\"\"\n"); + StringMap method_cardinalities; + StringMap input_message_modules_and_classes; + StringMap output_message_modules_and_classes; + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + const std::string method_cardinality = + std::string(method->ClientStreaming() ? "STREAM" : "UNARY") + "_" + + std::string(method->ServerStreaming() ? "STREAM" : "UNARY"); + std::string input_message_module_and_class; + if (!method->get_module_and_message_path_input( + &input_message_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + std::string output_message_module_and_class; + if (!method->get_module_and_message_path_output( + &output_message_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + method_cardinalities.insert( + make_pair(method->name(), method_cardinality)); + input_message_modules_and_classes.insert( + make_pair(method->name(), input_message_module_and_class)); + output_message_modules_and_classes.insert( + make_pair(method->name(), output_message_module_and_class)); + } + StringMap method_dict; + method_dict["PackageQualifiedServiceName"] = package_qualified_service_name; + out->Print("request_serializers = {\n"); + for (StringMap::iterator name_and_input_module_class_pair = + input_message_modules_and_classes.begin(); + name_and_input_module_class_pair != + input_message_modules_and_classes.end(); + name_and_input_module_class_pair++) { + method_dict["MethodName"] = name_and_input_module_class_pair->first; + method_dict["InputTypeModuleAndClass"] = + name_and_input_module_class_pair->second; + IndentScope raii_indent(out); + out->Print(method_dict, + "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$InputTypeModuleAndClass$.SerializeToString,\n"); + } + out->Print("}\n"); + out->Print("response_deserializers = {\n"); + for (StringMap::iterator name_and_output_module_class_pair = + output_message_modules_and_classes.begin(); + name_and_output_module_class_pair != + output_message_modules_and_classes.end(); + name_and_output_module_class_pair++) { + method_dict["MethodName"] = name_and_output_module_class_pair->first; + method_dict["OutputTypeModuleAndClass"] = + name_and_output_module_class_pair->second; + IndentScope raii_indent(out); + out->Print(method_dict, + "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$OutputTypeModuleAndClass$.FromString,\n"); + } + out->Print("}\n"); + out->Print("cardinalities = {\n"); + for (StringMap::iterator name_and_cardinality = + method_cardinalities.begin(); + name_and_cardinality != method_cardinalities.end(); + name_and_cardinality++) { + method_dict["Method"] = name_and_cardinality->first; + method_dict["Cardinality"] = name_and_cardinality->second; + IndentScope raii_descriptions_indent(out); + out->Print(method_dict, + "\'$Method$\': 
cardinality.Cardinality.$Cardinality$,\n"); + } + out->Print("}\n"); + out->Print( + "stub_options = beta_implementations.stub_options(" + "host=host, metadata_transformer=metadata_transformer, " + "request_serializers=request_serializers, " + "response_deserializers=response_deserializers, " + "thread_pool=pool, thread_pool_size=pool_size)\n"); + out->Print(method_dict, + "return beta_implementations.dynamic_stub(channel, " + "\'$PackageQualifiedServiceName$\', " + "cardinalities, options=stub_options)\n"); + } + return true; +} + +bool PrivateGenerator::PrintStub( + const std::string& package_qualified_service_name, + const grpc_generator::Service* service, grpc_generator::Printer* out) { + StringMap dict; + dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(dict, "class $Service$Stub(object):\n"); + { + IndentScope raii_class_indent(out); + StringVector service_comments = service->GetAllComments(); + PrintAllComments(service_comments, out); + out->Print("\n"); + out->Print("def __init__(self, channel):\n"); + { + IndentScope raii_init_indent(out); + out->Print("\"\"\"Constructor.\n"); + out->Print("\n"); + out->Print("Args:\n"); + { + IndentScope raii_args_indent(out); + out->Print("channel: A grpc.Channel.\n"); + } + out->Print("\"\"\"\n"); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + std::string multi_callable_constructor = + std::string(method->ClientStreaming() ? "stream" : "unary") + "_" + + std::string(method->ServerStreaming() ? "stream" : "unary"); + std::string request_module_and_class; + if (!method->get_module_and_message_path_input( + &request_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + std::string response_module_and_class; + if (!method->get_module_and_message_path_output( + &response_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + StringMap method_dict; + method_dict["Method"] = method->name(); + method_dict["MultiCallableConstructor"] = multi_callable_constructor; + out->Print(method_dict, + "self.$Method$ = channel.$MultiCallableConstructor$(\n"); + { + method_dict["PackageQualifiedService"] = + package_qualified_service_name; + method_dict["RequestModuleAndClass"] = request_module_and_class; + method_dict["ResponseModuleAndClass"] = response_module_and_class; + IndentScope raii_first_attribute_indent(out); + IndentScope raii_second_attribute_indent(out); + out->Print(method_dict, "'/$PackageQualifiedService$/$Method$',\n"); + out->Print(method_dict, + "request_serializer=$RequestModuleAndClass$." + "SerializeToString,\n"); + out->Print( + method_dict, + "response_deserializer=$ResponseModuleAndClass$.FromString,\n"); + out->Print(")\n"); + } + } + } + } + return true; +} + +bool PrivateGenerator::PrintServicer(const grpc_generator::Service* service, + grpc_generator::Printer* out) { + StringMap service_dict; + service_dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(service_dict, "class $Service$Servicer(object):\n"); + { + IndentScope raii_class_indent(out); + StringVector service_comments = service->GetAllComments(); + PrintAllComments(service_comments, out); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + std::string arg_name = + method->ClientStreaming() ? 
"request_iterator" : "request"; + StringMap method_dict; + method_dict["Method"] = method->name(); + method_dict["ArgName"] = arg_name; + out->Print("\n"); + out->Print(method_dict, "def $Method$(self, $ArgName$, context):\n"); + { + IndentScope raii_method_indent(out); + StringVector method_comments = method->GetAllComments(); + PrintAllComments(method_comments, out); + out->Print("context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n"); + out->Print("context.set_details('Method not implemented!')\n"); + out->Print("raise NotImplementedError('Method not implemented!')\n"); + } + } + } + return true; +} + +bool PrivateGenerator::PrintAddServicerToServer( + const std::string& package_qualified_service_name, + const grpc_generator::Service* service, grpc_generator::Printer* out) { + StringMap service_dict; + service_dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(service_dict, + "def add_$Service$Servicer_to_server(servicer, server):\n"); + { + IndentScope raii_class_indent(out); + out->Print("rpc_method_handlers = {\n"); + { + IndentScope raii_dict_first_indent(out); + IndentScope raii_dict_second_indent(out); + for (int i = 0; i < service->method_count(); ++i) { + auto method = service->method(i); + std::string method_handler_constructor = + std::string(method->ClientStreaming() ? "stream" : "unary") + "_" + + std::string(method->ServerStreaming() ? "stream" : "unary") + + "_rpc_method_handler"; + std::string request_module_and_class; + if (!method->get_module_and_message_path_input( + &request_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + std::string response_module_and_class; + if (!method->get_module_and_message_path_output( + &response_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + StringMap method_dict; + method_dict["Method"] = method->name(); + method_dict["MethodHandlerConstructor"] = method_handler_constructor; + method_dict["RequestModuleAndClass"] = request_module_and_class; + method_dict["ResponseModuleAndClass"] = response_module_and_class; + out->Print(method_dict, + "'$Method$': grpc.$MethodHandlerConstructor$(\n"); + { + IndentScope raii_call_first_indent(out); + IndentScope raii_call_second_indent(out); + out->Print(method_dict, "servicer.$Method$,\n"); + out->Print( + method_dict, + "request_deserializer=$RequestModuleAndClass$.FromString,\n"); + out->Print( + method_dict, + "response_serializer=$ResponseModuleAndClass$.SerializeToString," + "\n"); + } + out->Print("),\n"); + } + } + StringMap method_dict; + method_dict["PackageQualifiedServiceName"] = package_qualified_service_name; + out->Print("}\n"); + out->Print("generic_handler = grpc.method_handlers_generic_handler(\n"); + { + IndentScope raii_call_first_indent(out); + IndentScope raii_call_second_indent(out); + out->Print(method_dict, + "'$PackageQualifiedServiceName$', rpc_method_handlers)\n"); + } + out->Print("server.add_generic_rpc_handlers((generic_handler,))\n"); + } + return true; +} + +/* Prints out a service class used as a container for static methods pertaining + * to a class. This class has the exact name of service written in the ".proto" + * file, with no suffixes. Since this class merely acts as a namespace, it + * should never be instantiated. 
+ */ +bool PrivateGenerator::PrintServiceClass( + const std::string& package_qualified_service_name, + const grpc_generator::Service* service, grpc_generator::Printer* out) { + StringMap dict; + dict["Service"] = service->name(); + out->Print("\n\n"); + out->Print(" # This class is part of an EXPERIMENTAL API.\n"); + out->Print(dict, "class $Service$(object):\n"); + { + IndentScope class_indent(out); + StringVector service_comments = service->GetAllComments(); + PrintAllComments(service_comments, out); + for (int i = 0; i < service->method_count(); ++i) { + const auto& method = service->method(i); + std::string request_module_and_class; + if (!method->get_module_and_message_path_input( + &request_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + std::string response_module_and_class; + if (!method->get_module_and_message_path_output( + &response_module_and_class, generator_file_name, + generate_in_pb2_grpc, config.import_prefix, + config.prefixes_to_filter)) { + return false; + } + out->Print("\n"); + StringMap method_dict; + method_dict["Method"] = method->name(); + out->Print("@staticmethod\n"); + out->Print(method_dict, "def $Method$("); + std::string request_parameter( + method->ClientStreaming() ? "request_iterator" : "request"); + StringMap args_dict; + args_dict["RequestParameter"] = request_parameter; + { + IndentScope args_indent(out); + IndentScope args_double_indent(out); + out->Print(args_dict, "$RequestParameter$,\n"); + out->Print("target,\n"); + out->Print("options=(),\n"); + out->Print("channel_credentials=None,\n"); + out->Print("call_credentials=None,\n"); + out->Print("insecure=False,\n"); + out->Print("compression=None,\n"); + out->Print("wait_for_ready=None,\n"); + out->Print("timeout=None,\n"); + out->Print("metadata=None):\n"); + } + { + IndentScope method_indent(out); + std::string arity_method_name = + std::string(method->ClientStreaming() ? "stream" : "unary") + "_" + + std::string(method->ServerStreaming() ? "stream" : "unary"); + args_dict["ArityMethodName"] = arity_method_name; + args_dict["PackageQualifiedService"] = package_qualified_service_name; + args_dict["Method"] = method->name(); + out->Print(args_dict, + "return " + "grpc.experimental.$ArityMethodName$($RequestParameter$, " + "target, '/$PackageQualifiedService$/$Method$',\n"); + { + IndentScope continuation_indent(out); + StringMap serializer_dict; + serializer_dict["RequestModuleAndClass"] = request_module_and_class; + serializer_dict["ResponseModuleAndClass"] = response_module_and_class; + out->Print(serializer_dict, + "$RequestModuleAndClass$.SerializeToString,\n"); + out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n"); + out->Print("options, channel_credentials,\n"); + out->Print( + "insecure, call_credentials, compression, wait_for_ready, " + "timeout, metadata)\n"); + } + } + } + } + // TODO(rbellevi): Add methods pertinent to the server side as well. 
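  // Illustrative sketch (service and message names borrowed from the
  // helloworld example; the module alias is approximate): for a unary-unary
  // method, the statements above emit Python roughly of the form
  //
  //   class Greeter(object):
  //       @staticmethod
  //       def SayHello(request, target, options=(), channel_credentials=None,
  //                    call_credentials=None, insecure=False, compression=None,
  //                    wait_for_ready=None, timeout=None, metadata=None):
  //           return grpc.experimental.unary_unary(
  //               request, target, '/helloworld.Greeter/SayHello',
  //               helloworld__pb2.HelloRequest.SerializeToString,
  //               helloworld__pb2.HelloReply.FromString,
  //               options, channel_credentials, insecure, call_credentials,
  //               compression, wait_for_ready, timeout, metadata)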
+ return true; +} + +bool PrivateGenerator::PrintBetaPreamble(grpc_generator::Printer* out) { + StringMap var; + var["Package"] = config.beta_package_root; + out->Print(var, + "from $Package$ import implementations as beta_implementations\n"); + out->Print(var, "from $Package$ import interfaces as beta_interfaces\n"); + out->Print("from grpc.framework.common import cardinality\n"); + out->Print( + "from grpc.framework.interfaces.face import utilities as " + "face_utilities\n"); + return true; +} + +bool PrivateGenerator::PrintPreamble(grpc_generator::Printer* out) { + StringMap var; + var["Package"] = config.grpc_package_root; + out->Print(var, "import $Package$\n"); + if (generate_in_pb2_grpc) { + out->Print("\n"); + StringPairSet imports_set; + for (int i = 0; i < file->service_count(); ++i) { + auto service = file->service(i); + for (int j = 0; j < service->method_count(); ++j) { + auto method = service.get()->method(j); + + std::string input_type_file_name = method->get_input_type_name(); + std::string input_module_name = + ModuleName(input_type_file_name, config.import_prefix, + config.prefixes_to_filter); + std::string input_module_alias = + ModuleAlias(input_type_file_name, config.import_prefix, + config.prefixes_to_filter); + imports_set.insert( + std::make_tuple(input_module_name, input_module_alias)); + + std::string output_type_file_name = method->get_output_type_name(); + std::string output_module_name = + ModuleName(output_type_file_name, config.import_prefix, + config.prefixes_to_filter); + std::string output_module_alias = + ModuleAlias(output_type_file_name, config.import_prefix, + config.prefixes_to_filter); + imports_set.insert( + std::make_tuple(output_module_name, output_module_alias)); + } + } + + for (StringPairSet::iterator it = imports_set.begin(); + it != imports_set.end(); ++it) { + auto module_name = std::get<0>(*it); + var["ModuleAlias"] = std::get<1>(*it); + const size_t last_dot_pos = module_name.rfind('.'); + if (last_dot_pos == std::string::npos) { + var["ImportStatement"] = "import " + module_name; + } else { + var["ImportStatement"] = "from " + module_name.substr(0, last_dot_pos) + + " import " + + module_name.substr(last_dot_pos + 1); + } + out->Print(var, "$ImportStatement$ as $ModuleAlias$\n"); + } + } + return true; +} + +bool PrivateGenerator::PrintGAServices(grpc_generator::Printer* out) { + std::string package = file->package(); + if (!package.empty()) { + package = package.append("."); + } + for (int i = 0; i < file->service_count(); ++i) { + auto service = file->service(i); + std::string package_qualified_service_name = package + service->name(); + if (!(PrintStub(package_qualified_service_name, service.get(), out) && + PrintServicer(service.get(), out) && + PrintAddServicerToServer(package_qualified_service_name, + service.get(), out) && + PrintServiceClass(package_qualified_service_name, service.get(), + out))) { + return false; + } + } + return true; +} + +bool PrivateGenerator::PrintBetaServices(grpc_generator::Printer* out) { + std::string package = file->package(); + if (!package.empty()) { + package = package.append("."); + } + for (int i = 0; i < file->service_count(); ++i) { + auto service = file->service(i); + std::string package_qualified_service_name = package + service->name(); + if (!(PrintBetaServicer(service.get(), out) && + PrintBetaStub(service.get(), out) && + PrintBetaServerFactory(package_qualified_service_name, service.get(), + out) && + PrintBetaStubFactory(package_qualified_service_name, service.get(), + out))) { + return 
false; + } + } + return true; +} + +pair PrivateGenerator::GetGrpcServices() { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + auto out = file->CreatePrinter(&output); + if (generate_in_pb2_grpc) { + out->Print( + "# Generated by the gRPC Python protocol compiler plugin. " + "DO NOT EDIT!\n\"\"\"" + "Client and server classes corresponding to protobuf-defined " + "services.\"\"\"\n"); + if (!PrintPreamble(out.get())) { + return make_pair(false, ""); + } + if (!PrintGAServices(out.get())) { + return make_pair(false, ""); + } + } else { + out->Print("try:\n"); + { + IndentScope raii_dict_try_indent(out.get()); + out->Print( + "# THESE ELEMENTS WILL BE DEPRECATED.\n" + "# Please use the generated *_pb2_grpc.py files instead.\n"); + if (!PrintPreamble(out.get())) { + return make_pair(false, ""); + } + if (!PrintBetaPreamble(out.get())) { + return make_pair(false, ""); + } + if (!PrintGAServices(out.get())) { + return make_pair(false, ""); + } + if (!PrintBetaServices(out.get())) { + return make_pair(false, ""); + } + } + out->Print("except ImportError:\n"); + { + IndentScope raii_dict_except_indent(out.get()); + out->Print("pass"); + } + } + } + return make_pair(true, std::move(output)); +} + +} // namespace + +GeneratorConfiguration::GeneratorConfiguration() + : grpc_package_root("grpc"), + beta_package_root("grpc.beta"), + import_prefix("") {} + +PythonGrpcGenerator::PythonGrpcGenerator(const GeneratorConfiguration& config) + : config_(config) {} + +PythonGrpcGenerator::~PythonGrpcGenerator() {} + +static bool GenerateGrpc(GeneratorContext* context, PrivateGenerator& generator, + std::string file_name, bool generate_in_pb2_grpc) { + bool success; + std::unique_ptr output; + std::unique_ptr coded_output; + std::string grpc_code; + + if (generate_in_pb2_grpc) { + output.reset(context->Open(file_name)); + generator.generate_in_pb2_grpc = true; + } else { + output.reset(context->OpenForInsert(file_name, "module_scope")); + generator.generate_in_pb2_grpc = false; + } + + coded_output.reset(new CodedOutputStream(output.get())); + tie(success, grpc_code) = generator.GetGrpcServices(); + + if (success) { + coded_output->WriteRaw(grpc_code.data(), grpc_code.size()); + return true; + } else { + return false; + } +} + +static bool ParseParameters(const std::string& parameter, + std::string* grpc_version, + std::vector* strip_prefixes, + std::string* error) { + std::vector comma_delimited_parameters; + grpc_python_generator::Split(parameter, ',', &comma_delimited_parameters); + if (comma_delimited_parameters.size() == 1 && + comma_delimited_parameters[0].empty()) { + *grpc_version = "grpc_2_0"; + } else if (comma_delimited_parameters.size() == 1) { + *grpc_version = comma_delimited_parameters[0]; + } else if (comma_delimited_parameters.size() == 2) { + *grpc_version = comma_delimited_parameters[0]; + std::copy(comma_delimited_parameters.begin() + 1, + comma_delimited_parameters.end(), + std::back_inserter(*strip_prefixes)); + } else { + *error = "--grpc_python_out received too many comma-delimited parameters."; + return false; + } + return true; +} + +uint64_t PythonGrpcGenerator::GetSupportedFeatures() const { + return FEATURE_PROTO3_OPTIONAL; +} + +bool PythonGrpcGenerator::Generate(const FileDescriptor* file, + const std::string& parameter, + GeneratorContext* context, + std::string* error) const { + // Get output file name. 
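  // For example, "path/to/foo-bar.proto" becomes "path/to/foo_bar_pb2.py" and
  // "path/to/foo_bar_pb2_grpc.py": the ".proto" suffix is stripped and '-' is
  // replaced with '_'. Which of the two files receives the generated services
  // depends on the grpc_version parameter handled below ("grpc_2_0" writes
  // only the *_pb2_grpc.py file; "grpc_1_0" also inserts the legacy beta code
  // into the *_pb2.py file).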
+ std::string pb2_file_name; + std::string pb2_grpc_file_name; + static const int proto_suffix_length = strlen(".proto"); + if (file->name().size() > static_cast(proto_suffix_length) && + file->name().find_last_of(".proto") == file->name().size() - 1) { + std::string base = + file->name().substr(0, file->name().size() - proto_suffix_length); + std::replace(base.begin(), base.end(), '-', '_'); + pb2_file_name = base + "_pb2.py"; + pb2_grpc_file_name = base + "_pb2_grpc.py"; + } else { + *error = "Invalid proto file name. Proto file must end with .proto"; + return false; + } + generator_file_name = file->name(); + + ProtoBufFile pbfile(file); + std::string grpc_version; + GeneratorConfiguration extended_config(config_); + bool success = ParseParameters(parameter, &grpc_version, + &(extended_config.prefixes_to_filter), error); + PrivateGenerator generator(extended_config, &pbfile); + if (!success) return false; + if (grpc_version == "grpc_2_0") { + return GenerateGrpc(context, generator, pb2_grpc_file_name, true); + } else if (grpc_version == "grpc_1_0") { + return GenerateGrpc(context, generator, pb2_grpc_file_name, true) && + GenerateGrpc(context, generator, pb2_file_name, false); + } else { + *error = "Invalid grpc version '" + grpc_version + "'."; + return false; + } +} + +} // namespace grpc_python_generator diff --git a/src/compiler/python_plugin.cc b/src/compiler/python_plugin.cc new file mode 100644 index 00000000..81eb1d4f --- /dev/null +++ b/src/compiler/python_plugin.cc @@ -0,0 +1,29 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Generates a Python gRPC service interface out of Protobuf IDL. + +#include "src/compiler/config.h" +#include "src/compiler/protobuf_plugin.h" +#include "src/compiler/python_generator.h" + +int main(int argc, char* argv[]) { + grpc_python_generator::GeneratorConfiguration config; + grpc_python_generator::PythonGrpcGenerator generator(config); + return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/src/compiler/ruby_generator.cc b/src/compiler/ruby_generator.cc new file mode 100644 index 00000000..c553e1c3 --- /dev/null +++ b/src/compiler/ruby_generator.cc @@ -0,0 +1,215 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/compiler/ruby_generator.h" + +#include +#include +#include + +#include "src/compiler/config.h" +#include "src/compiler/ruby_generator_helpers-inl.h" +#include "src/compiler/ruby_generator_map-inl.h" +#include "src/compiler/ruby_generator_string-inl.h" + +using grpc::protobuf::FileDescriptor; +using grpc::protobuf::MethodDescriptor; +using grpc::protobuf::ServiceDescriptor; +using grpc::protobuf::io::Printer; +using grpc::protobuf::io::StringOutputStream; +using std::map; +using std::vector; + +namespace grpc_ruby_generator { +namespace { + +// Prints out the method using the ruby gRPC DSL. +void PrintMethod(const MethodDescriptor* method, Printer* out) { + std::string input_type = RubyTypeOf(method->input_type()); + if (method->client_streaming()) { + input_type = "stream(" + input_type + ")"; + } + std::string output_type = RubyTypeOf(method->output_type()); + if (method->server_streaming()) { + output_type = "stream(" + output_type + ")"; + } + std::map method_vars = ListToDict({ + "mth.name", + method->name(), + "input.type", + input_type, + "output.type", + output_type, + }); + out->Print(GetRubyComments(method, true).c_str()); + out->Print(method_vars, "rpc :$mth.name$, $input.type$, $output.type$\n"); + out->Print(GetRubyComments(method, false).c_str()); +} + +// Prints out the service using the ruby gRPC DSL. +void PrintService(const ServiceDescriptor* service, Printer* out) { + if (service->method_count() == 0) { + return; + } + + // Begin the service module + std::map module_vars = ListToDict({ + "module.name", + Modularize(service->name()), + }); + out->Print(module_vars, "module $module.name$\n"); + out->Indent(); + + out->Print(GetRubyComments(service, true).c_str()); + out->Print("class Service\n"); + + // Write the indented class body. + out->Indent(); + out->Print("\n"); + out->Print("include ::GRPC::GenericService\n"); + out->Print("\n"); + out->Print("self.marshal_class_method = :encode\n"); + out->Print("self.unmarshal_class_method = :decode\n"); + std::map pkg_vars = + ListToDict({"service_full_name", service->full_name()}); + out->Print(pkg_vars, "self.service_name = '$service_full_name$'\n"); + out->Print("\n"); + for (int i = 0; i < service->method_count(); ++i) { + PrintMethod(service->method(i), out); + } + out->Outdent(); + + out->Print("end\n"); + out->Print("\n"); + out->Print("Stub = Service.rpc_stub_class\n"); + + // End the service module + out->Outdent(); + out->Print("end\n"); + out->Print(GetRubyComments(service, false).c_str()); +} + +} // namespace + +// The following functions are copied directly from the source for the protoc +// ruby generator +// to ensure compatibility (with the exception of int and string type changes). +// See +// https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/compiler/ruby/ruby_generator.cc#L250 +// TODO: keep up to date with protoc code generation, though this behavior isn't +// expected to change +bool IsLower(char ch) { return ch >= 'a' && ch <= 'z'; } + +char ToUpper(char ch) { return IsLower(ch) ? (ch - 'a' + 'A') : ch; } + +// Package names in protobuf are snake_case by convention, but Ruby module +// names must be PascalCased. 
+// +// foo_bar_baz -> FooBarBaz +std::string PackageToModule(const std::string& name) { + bool next_upper = true; + std::string result; + result.reserve(name.size()); + + for (std::string::size_type i = 0; i < name.size(); i++) { + if (name[i] == '_') { + next_upper = true; + } else { + if (next_upper) { + result.push_back(ToUpper(name[i])); + } else { + result.push_back(name[i]); + } + next_upper = false; + } + } + + return result; +} +// end copying of protoc generator for ruby code + +std::string GetServices(const FileDescriptor* file) { + std::string output; + { + // Scope the output stream so it closes and finalizes output to the string. + + StringOutputStream output_stream(&output); + Printer out(&output_stream, '$'); + + // Don't write out any output if there no services, to avoid empty service + // files being generated for proto files that don't declare any. + if (file->service_count() == 0) { + return output; + } + + std::string package_name = RubyPackage(file); + + // Write out a file header. + std::map header_comment_vars = ListToDict({ + "file.name", + file->name(), + "file.package", + package_name, + }); + out.Print("# Generated by the protocol buffer compiler. DO NOT EDIT!\n"); + out.Print(header_comment_vars, + "# Source: $file.name$ for package '$file.package$'\n"); + + std::string leading_comments = GetRubyComments(file, true); + if (!leading_comments.empty()) { + out.Print("# Original file comments:\n"); + out.PrintRaw(leading_comments.c_str()); + } + + out.Print("\n"); + out.Print("require 'grpc'\n"); + // Write out require statemment to import the separately generated file + // that defines the messages used by the service. This is generated by the + // main ruby plugin. + std::map dep_vars = ListToDict({ + "dep.name", + MessagesRequireName(file), + }); + out.Print(dep_vars, "require '$dep.name$'\n"); + + // Write out services within the modules + out.Print("\n"); + std::vector modules = Split(package_name, '.'); + for (size_t i = 0; i < modules.size(); ++i) { + std::map module_vars = ListToDict({ + "module.name", + PackageToModule(modules[i]), + }); + out.Print(module_vars, "module $module.name$\n"); + out.Indent(); + } + for (int i = 0; i < file->service_count(); ++i) { + auto service = file->service(i); + PrintService(service, &out); + } + for (size_t i = 0; i < modules.size(); ++i) { + out.Outdent(); + out.Print("end\n"); + } + + out.Print(GetRubyComments(file, false).c_str()); + } + return output; +} + +} // namespace grpc_ruby_generator diff --git a/src/compiler/ruby_plugin.cc b/src/compiler/ruby_plugin.cc new file mode 100644 index 00000000..8821e613 --- /dev/null +++ b/src/compiler/ruby_plugin.cc @@ -0,0 +1,61 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Generates Ruby gRPC service interface out of Protobuf IDL. 
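//
// As a rough illustration (names borrowed from the helloworld example), the
// generated Ruby for a service resembles:
//
//   # Generated by the protocol buffer compiler. DO NOT EDIT!
//   # Source: helloworld.proto for package 'helloworld'
//
//   require 'grpc'
//   require 'helloworld_pb'
//
//   module Helloworld
//     module Greeter
//       class Service
//         include ::GRPC::GenericService
//
//         self.marshal_class_method = :encode
//         self.unmarshal_class_method = :decode
//         self.service_name = 'helloworld.Greeter'
//
//         rpc :SayHello, HelloRequest, HelloReply
//       end
//
//       Stub = Service.rpc_stub_class
//     end
//   end
//
// (How the request/reply types are qualified is determined by RubyTypeOf.)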
+ +#include + +#include "src/compiler/config.h" +#include "src/compiler/ruby_generator.h" +#include "src/compiler/ruby_generator_helpers-inl.h" + +class RubyGrpcGenerator : public grpc::protobuf::compiler::CodeGenerator { + public: + RubyGrpcGenerator() {} + ~RubyGrpcGenerator() {} + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + bool Generate(const grpc::protobuf::FileDescriptor* file, + const std::string& /*parameter*/, + grpc::protobuf::compiler::GeneratorContext* context, + std::string* /*error*/) const override { + std::string code = grpc_ruby_generator::GetServices(file); + if (code.size() == 0) { + return true; // don't generate a file if there are no services + } + + // Get output file name. + std::string file_name; + if (!grpc_ruby_generator::ServicesFilename(file, &file_name)) { + return false; + } + std::unique_ptr output( + context->Open(file_name)); + grpc::protobuf::io::CodedOutputStream coded_out(output.get()); + coded_out.WriteRaw(code.data(), code.size()); + return true; + } +}; + +int main(int argc, char* argv[]) { + RubyGrpcGenerator generator; + return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/src/core/ext/filters/census/grpc_context.cc b/src/core/ext/filters/census/grpc_context.cc new file mode 100644 index 00000000..6659f701 --- /dev/null +++ b/src/core/ext/filters/census/grpc_context.cc @@ -0,0 +1,39 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call.h" + +void grpc_census_call_set_context(grpc_call* call, census_context* context) { + GRPC_API_TRACE("grpc_census_call_set_context(call=%p, census_context=%p)", 2, + (call, context)); + if (context != nullptr) { + grpc_call_context_set(call, GRPC_CONTEXT_TRACING, context, nullptr); + } +} + +census_context* grpc_census_call_get_context(grpc_call* call) { + GRPC_API_TRACE("grpc_census_call_get_context(call=%p)", 1, (call)); + return static_cast( + grpc_call_context_get(call, GRPC_CONTEXT_TRACING)); +} diff --git a/src/core/ext/filters/client_channel/backend_metric.cc b/src/core/ext/filters/client_channel/backend_metric.cc new file mode 100644 index 00000000..235916a9 --- /dev/null +++ b/src/core/ext/filters/client_channel/backend_metric.cc @@ -0,0 +1,80 @@ +// +// Copyright 2019 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/ext/filters/client_channel/backend_metric.h" + +#include "absl/strings/string_view.h" +#include "upb/upb.hpp" +#include "xds/data/orca/v3/orca_load_report.upb.h" + +namespace grpc_core { + +namespace { + +template +std::map ParseMap( + xds_data_orca_v3_OrcaLoadReport* msg, + const EntryType* (*entry_func)(const xds_data_orca_v3_OrcaLoadReport*, + size_t*), + upb_strview (*key_func)(const EntryType*), + double (*value_func)(const EntryType*), Arena* arena) { + std::map result; + size_t i = UPB_MAP_BEGIN; + while (true) { + const auto* entry = entry_func(msg, &i); + if (entry == nullptr) break; + upb_strview key_view = key_func(entry); + char* key = static_cast(arena->Alloc(key_view.size)); + memcpy(key, key_view.data, key_view.size); + result[absl::string_view(key, key_view.size)] = value_func(entry); + } + return result; +} + +} // namespace + +const LoadBalancingPolicy::BackendMetricData* ParseBackendMetricData( + const grpc_slice& serialized_load_report, Arena* arena) { + upb::Arena upb_arena; + xds_data_orca_v3_OrcaLoadReport* msg = xds_data_orca_v3_OrcaLoadReport_parse( + reinterpret_cast( + GRPC_SLICE_START_PTR(serialized_load_report)), + GRPC_SLICE_LENGTH(serialized_load_report), upb_arena.ptr()); + if (msg == nullptr) return nullptr; + LoadBalancingPolicy::BackendMetricData* backend_metric_data = + arena->New(); + backend_metric_data->cpu_utilization = + xds_data_orca_v3_OrcaLoadReport_cpu_utilization(msg); + backend_metric_data->mem_utilization = + xds_data_orca_v3_OrcaLoadReport_mem_utilization(msg); + backend_metric_data->requests_per_second = + xds_data_orca_v3_OrcaLoadReport_rps(msg); + backend_metric_data->request_cost = + ParseMap( + msg, xds_data_orca_v3_OrcaLoadReport_request_cost_next, + xds_data_orca_v3_OrcaLoadReport_RequestCostEntry_key, + xds_data_orca_v3_OrcaLoadReport_RequestCostEntry_value, arena); + backend_metric_data->utilization = + ParseMap( + msg, xds_data_orca_v3_OrcaLoadReport_utilization_next, + xds_data_orca_v3_OrcaLoadReport_UtilizationEntry_key, + xds_data_orca_v3_OrcaLoadReport_UtilizationEntry_value, arena); + return backend_metric_data; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/backup_poller.cc b/src/core/ext/filters/client_channel/backup_poller.cc new file mode 100644 index 00000000..1332e877 --- /dev/null +++ b/src/core/ext/filters/client_channel/backup_poller.cc @@ -0,0 +1,183 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/global_config.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" + +#define DEFAULT_POLL_INTERVAL_MS 5000 + +namespace { +struct backup_poller { + grpc_timer polling_timer; + grpc_closure run_poller_closure; + grpc_closure shutdown_closure; + gpr_mu* pollset_mu; + grpc_pollset* pollset; // guarded by pollset_mu + bool shutting_down; // guarded by pollset_mu + gpr_refcount refs; + gpr_refcount shutdown_refs; +}; +} // namespace + +static gpr_once g_once = GPR_ONCE_INIT; +static gpr_mu g_poller_mu; +static backup_poller* g_poller = nullptr; // guarded by g_poller_mu +// g_poll_interval_ms is set only once at the first time +// grpc_client_channel_start_backup_polling() is called, after that it is +// treated as const. +static int g_poll_interval_ms = DEFAULT_POLL_INTERVAL_MS; + +GPR_GLOBAL_CONFIG_DEFINE_INT32( + grpc_client_channel_backup_poll_interval_ms, DEFAULT_POLL_INTERVAL_MS, + "Declares the interval in ms between two backup polls on client channels. " + "These polls are run in the timer thread so that gRPC can process " + "connection failures while there is no active polling thread. " + "They help reconnect disconnected client channels (mostly due to " + "idleness), so that the next RPC on this channel won't fail. Set to 0 to " + "turn off the backup polls."); + +void grpc_client_channel_global_init_backup_polling() { + gpr_once_init(&g_once, [] { gpr_mu_init(&g_poller_mu); }); + int32_t poll_interval_ms = + GPR_GLOBAL_CONFIG_GET(grpc_client_channel_backup_poll_interval_ms); + if (poll_interval_ms < 0) { + gpr_log(GPR_ERROR, + "Invalid GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS: %d, " + "default value %d will be used.", + poll_interval_ms, g_poll_interval_ms); + } else { + g_poll_interval_ms = poll_interval_ms; + } +} + +static void backup_poller_shutdown_unref(backup_poller* p) { + if (gpr_unref(&p->shutdown_refs)) { + grpc_pollset_destroy(p->pollset); + gpr_free(p->pollset); + gpr_free(p); + } +} + +static void done_poller(void* arg, grpc_error_handle /*error*/) { + backup_poller_shutdown_unref(static_cast(arg)); +} + +static void g_poller_unref() { + gpr_mu_lock(&g_poller_mu); + if (gpr_unref(&g_poller->refs)) { + backup_poller* p = g_poller; + g_poller = nullptr; + gpr_mu_unlock(&g_poller_mu); + gpr_mu_lock(p->pollset_mu); + p->shutting_down = true; + grpc_pollset_shutdown( + p->pollset, GRPC_CLOSURE_INIT(&p->shutdown_closure, done_poller, p, + grpc_schedule_on_exec_ctx)); + gpr_mu_unlock(p->pollset_mu); + grpc_timer_cancel(&p->polling_timer); + backup_poller_shutdown_unref(p); + } else { + gpr_mu_unlock(&g_poller_mu); + } +} + +static void run_poller(void* arg, grpc_error_handle error) { + backup_poller* p = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + if (error != GRPC_ERROR_CANCELLED) { + GRPC_LOG_IF_ERROR("run_poller", GRPC_ERROR_REF(error)); + } + backup_poller_shutdown_unref(p); + return; + } + gpr_mu_lock(p->pollset_mu); + if (p->shutting_down) { + gpr_mu_unlock(p->pollset_mu); + backup_poller_shutdown_unref(p); + return; + } + grpc_error_handle err = + 
grpc_pollset_work(p->pollset, nullptr, grpc_core::ExecCtx::Get()->Now()); + gpr_mu_unlock(p->pollset_mu); + GRPC_LOG_IF_ERROR("Run client channel backup poller", err); + grpc_timer_init(&p->polling_timer, + grpc_core::ExecCtx::Get()->Now() + g_poll_interval_ms, + &p->run_poller_closure); +} + +static void g_poller_init_locked() { + if (g_poller == nullptr) { + g_poller = grpc_core::Zalloc(); + g_poller->pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + g_poller->shutting_down = false; + grpc_pollset_init(g_poller->pollset, &g_poller->pollset_mu); + gpr_ref_init(&g_poller->refs, 0); + // one for timer cancellation, one for pollset shutdown, one for g_poller + gpr_ref_init(&g_poller->shutdown_refs, 3); + GRPC_CLOSURE_INIT(&g_poller->run_poller_closure, run_poller, g_poller, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&g_poller->polling_timer, + grpc_core::ExecCtx::Get()->Now() + g_poll_interval_ms, + &g_poller->run_poller_closure); + } +} + +void grpc_client_channel_start_backup_polling( + grpc_pollset_set* interested_parties) { + if (g_poll_interval_ms == 0 || grpc_iomgr_run_in_background()) { + return; + } + gpr_mu_lock(&g_poller_mu); + g_poller_init_locked(); + gpr_ref(&g_poller->refs); + /* Get a reference to g_poller->pollset before releasing g_poller_mu to make + * TSAN happy. Otherwise, reading from g_poller (i.e g_poller->pollset) after + * releasing the lock and setting g_poller to NULL in g_poller_unref() is + * being flagged as a data-race by TSAN */ + grpc_pollset* pollset = g_poller->pollset; + gpr_mu_unlock(&g_poller_mu); + + grpc_pollset_set_add_pollset(interested_parties, pollset); +} + +void grpc_client_channel_stop_backup_polling( + grpc_pollset_set* interested_parties) { + if (g_poll_interval_ms == 0 || grpc_iomgr_run_in_background()) { + return; + } + grpc_pollset_set_del_pollset(interested_parties, g_poller->pollset); + g_poller_unref(); +} diff --git a/src/core/ext/filters/client_channel/channel_connectivity.cc b/src/core/ext/filters/client_channel/channel_connectivity.cc new file mode 100644 index 00000000..a75b9746 --- /dev/null +++ b/src/core/ext/filters/client_channel/channel_connectivity.cc @@ -0,0 +1,220 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/lame_client.h" + +namespace { + +bool IsLameChannel(grpc_channel* channel) { + grpc_channel_element* elem = + grpc_channel_stack_last_element(grpc_channel_get_channel_stack(channel)); + return elem->filter == &grpc_lame_filter; +} + +} // namespace + +grpc_connectivity_state grpc_channel_check_connectivity_state( + grpc_channel* channel, int try_to_connect) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE( + "grpc_channel_check_connectivity_state(channel=%p, try_to_connect=%d)", 2, + (channel, try_to_connect)); + // Forward through to the underlying client channel. + grpc_core::ClientChannel* client_channel = + grpc_core::ClientChannel::GetFromChannel(channel); + if (GPR_UNLIKELY(client_channel == nullptr)) { + if (IsLameChannel(channel)) return GRPC_CHANNEL_TRANSIENT_FAILURE; + gpr_log(GPR_ERROR, + "grpc_channel_check_connectivity_state called on something that is " + "not a client channel"); + return GRPC_CHANNEL_SHUTDOWN; + } + return client_channel->CheckConnectivityState(try_to_connect); +} + +int grpc_channel_num_external_connectivity_watchers(grpc_channel* channel) { + grpc_core::ClientChannel* client_channel = + grpc_core::ClientChannel::GetFromChannel(channel); + if (client_channel == nullptr) { + if (!IsLameChannel(channel)) { + gpr_log(GPR_ERROR, + "grpc_channel_num_external_connectivity_watchers called on " + "something that is not a client channel"); + } + return 0; + } + return client_channel->NumExternalConnectivityWatchers(); +} + +int grpc_channel_support_connectivity_watcher(grpc_channel* channel) { + return grpc_core::ClientChannel::GetFromChannel(channel) != nullptr; +} + +namespace grpc_core { +namespace { + +class StateWatcher : public DualRefCounted { + public: + StateWatcher(grpc_channel* channel, grpc_completion_queue* cq, void* tag, + grpc_connectivity_state last_observed_state, + gpr_timespec deadline) + : channel_(channel), cq_(cq), tag_(tag), state_(last_observed_state) { + GPR_ASSERT(grpc_cq_begin_op(cq, tag)); + GRPC_CHANNEL_INTERNAL_REF(channel, "watch_channel_connectivity"); + GRPC_CLOSURE_INIT(&on_complete_, WatchComplete, this, nullptr); + GRPC_CLOSURE_INIT(&on_timeout_, TimeoutComplete, this, nullptr); + ClientChannel* client_channel = ClientChannel::GetFromChannel(channel); + if (client_channel == nullptr) { + // If the target URI used to create the channel was invalid, channel + // stack initialization failed, and that caused us to create a lame + // channel. In that case, connectivity state will never change (it + // will always be TRANSIENT_FAILURE), so we don't actually start a + // watch, but we are hiding that fact from the application. + if (IsLameChannel(channel)) { + // Ref from object creation is held by timer callback. + StartTimer(grpc_timespec_to_millis_round_up(deadline)); + return; + } + gpr_log(GPR_ERROR, + "grpc_channel_watch_connectivity_state called on " + "something that is not a client channel"); + GPR_ASSERT(false); + } + // Take an addition ref, so we have two (the first one is from the + // creation of this object). One will be held by the timer callback, + // the other by the watcher callback. 
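    // Each callback drops its ref with Unref(): WatchComplete() cancels the
    // timer and TimeoutComplete() cancels the watch, so whichever runs second
    // releases the last strong ref. At that point Orphan() posts the
    // completion to the CQ, and FinishedCompletion() drops the weak ref once
    // the application has been given the tag.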
+ Ref().release(); + auto* watcher_timer_init_state = new WatcherTimerInitState( + this, grpc_timespec_to_millis_round_up(deadline)); + client_channel->AddExternalConnectivityWatcher( + grpc_polling_entity_create_from_pollset(grpc_cq_pollset(cq)), &state_, + &on_complete_, watcher_timer_init_state->closure()); + } + + ~StateWatcher() override { + GRPC_CHANNEL_INTERNAL_UNREF(channel_, "watch_channel_connectivity"); + } + + private: + // A fire-and-forget object used to delay starting the timer until the + // ClientChannel actually starts the watch. + class WatcherTimerInitState { + public: + WatcherTimerInitState(StateWatcher* state_watcher, grpc_millis deadline) + : state_watcher_(state_watcher), deadline_(deadline) { + GRPC_CLOSURE_INIT(&closure_, WatcherTimerInit, this, nullptr); + } + + grpc_closure* closure() { return &closure_; } + + private: + static void WatcherTimerInit(void* arg, grpc_error_handle /*error*/) { + auto* self = static_cast(arg); + self->state_watcher_->StartTimer(self->deadline_); + delete self; + } + + StateWatcher* state_watcher_; + grpc_millis deadline_; + grpc_closure closure_; + }; + + void StartTimer(grpc_millis deadline) { + grpc_timer_init(&timer_, deadline, &on_timeout_); + } + + static void WatchComplete(void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures)) { + GRPC_LOG_IF_ERROR("watch_completion_error", GRPC_ERROR_REF(error)); + } + grpc_timer_cancel(&self->timer_); + self->Unref(); + } + + static void TimeoutComplete(void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + self->timer_fired_ = error == GRPC_ERROR_NONE; + // If this is a client channel (not a lame channel), cancel the watch. + ClientChannel* client_channel = + ClientChannel::GetFromChannel(self->channel_); + if (client_channel != nullptr) { + client_channel->CancelExternalConnectivityWatcher(&self->on_complete_); + } + self->Unref(); + } + + // Invoked when both strong refs are released. + void Orphan() override { + WeakRef().release(); // Take a weak ref until completion is finished. + grpc_error_handle error = + timer_fired_ ? GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Timed out waiting for connection state change") + : GRPC_ERROR_NONE; + grpc_cq_end_op(cq_, tag_, error, FinishedCompletion, this, + &completion_storage_); + } + + // Called when the completion is returned to the CQ. 
+ static void FinishedCompletion(void* arg, grpc_cq_completion* /*ignored*/) { + auto* self = static_cast(arg); + self->WeakUnref(); + } + + grpc_channel* channel_; + grpc_completion_queue* cq_; + void* tag_; + + grpc_connectivity_state state_; + + grpc_cq_completion completion_storage_; + + grpc_closure on_complete_; + grpc_timer timer_; + grpc_closure on_timeout_; + + bool timer_fired_ = false; +}; + +} // namespace +} // namespace grpc_core + +void grpc_channel_watch_connectivity_state( + grpc_channel* channel, grpc_connectivity_state last_observed_state, + gpr_timespec deadline, grpc_completion_queue* cq, void* tag) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE( + "grpc_channel_watch_connectivity_state(" + "channel=%p, last_observed_state=%d, " + "deadline=gpr_timespec { tv_sec: %" PRId64 + ", tv_nsec: %d, clock_type: %d }, " + "cq=%p, tag=%p)", + 7, + (channel, (int)last_observed_state, deadline.tv_sec, deadline.tv_nsec, + (int)deadline.clock_type, cq, tag)); + new grpc_core::StateWatcher(channel, cq, tag, last_observed_state, deadline); +} diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc new file mode 100644 index 00000000..94c25142 --- /dev/null +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -0,0 +1,3119 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" + +#include +#include +#include +#include +#include + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backend_metric.h" +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/config_selector.h" +#include "src/core/ext/filters/client_channel/dynamic_filters.h" +#include "src/core/ext/filters/client_channel/global_subchannel_pool.h" +#include "src/core/ext/filters/client_channel/http_connect_handshaker.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/local_subchannel_pool.h" +#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/resolver_result_parsing.h" +#include "src/core/ext/filters/client_channel/retry_filter.h" +#include "src/core/ext/filters/client_channel/subchannel.h" +#include "src/core/ext/filters/deadline/deadline_filter.h" +#include "src/core/ext/service_config/service_config.h" +#include "src/core/ext/service_config/service_config_call_data.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_metadata.h" + +// +// Client channel filter +// + +#define GRPC_ARG_HEALTH_CHECK_SERVICE_NAME \ + "grpc.internal.health_check_service_name" + +namespace grpc_core { + +using internal::ClientChannelGlobalParsedConfig; +using internal::ClientChannelMethodParsedConfig; +using internal::ClientChannelServiceConfigParser; + +TraceFlag grpc_client_channel_call_trace(false, "client_channel_call"); +TraceFlag grpc_client_channel_routing_trace(false, "client_channel_routing"); + +// +// ClientChannel::CallData definition +// + +class ClientChannel::CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args); + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* final_info, + grpc_closure* then_schedule_closure); + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch); + static void SetPollent(grpc_call_element* elem, grpc_polling_entity* pollent); + + // Invoked by channel for queued calls when name resolution is completed. 
+ static void CheckResolution(void* arg, grpc_error_handle error); + // Helper function for applying the service config to a call while + // holding ClientChannel::resolution_mu_. + // Returns true if the service config has been applied to the call, in which + // case the caller must invoke ResolutionDone() or AsyncResolutionDone() + // with the returned error. + bool CheckResolutionLocked(grpc_call_element* elem, grpc_error_handle* error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::resolution_mu_); + // Schedules a callback to continue processing the call once + // resolution is complete. The callback will not run until after this + // method returns. + void AsyncResolutionDone(grpc_call_element* elem, grpc_error_handle error); + + private: + class ResolverQueuedCallCanceller; + + CallData(grpc_call_element* elem, const ClientChannel& chand, + const grpc_call_element_args& args); + ~CallData(); + + // Returns the index into pending_batches_ to be used for batch. + static size_t GetBatchIndex(grpc_transport_stream_op_batch* batch); + void PendingBatchesAdd(grpc_call_element* elem, + grpc_transport_stream_op_batch* batch); + static void FailPendingBatchInCallCombiner(void* arg, + grpc_error_handle error); + // A predicate type and some useful implementations for PendingBatchesFail(). + typedef bool (*YieldCallCombinerPredicate)( + const CallCombinerClosureList& closures); + static bool YieldCallCombiner(const CallCombinerClosureList& /*closures*/) { + return true; + } + static bool NoYieldCallCombiner(const CallCombinerClosureList& /*closures*/) { + return false; + } + static bool YieldCallCombinerIfPendingBatchesFound( + const CallCombinerClosureList& closures) { + return closures.size() > 0; + } + // Fails all pending batches. + // If yield_call_combiner_predicate returns true, assumes responsibility for + // yielding the call combiner. + void PendingBatchesFail( + grpc_call_element* elem, grpc_error_handle error, + YieldCallCombinerPredicate yield_call_combiner_predicate); + static void ResumePendingBatchInCallCombiner(void* arg, + grpc_error_handle ignored); + // Resumes all pending batches on lb_call_. + void PendingBatchesResume(grpc_call_element* elem); + + // Applies service config to the call. Must be invoked once we know + // that the resolver has returned results to the channel. + // If an error is returned, the error indicates the status with which + // the call should be failed. + grpc_error_handle ApplyServiceConfigToCallLocked( + grpc_call_element* elem, grpc_metadata_batch* initial_metadata) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::resolution_mu_); + // Invoked when the resolver result is applied to the caller, on both + // success or failure. + static void ResolutionDone(void* arg, grpc_error_handle error); + // Removes the call (if present) from the channel's list of calls queued + // for name resolution. + void MaybeRemoveCallFromResolverQueuedCallsLocked(grpc_call_element* elem) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::resolution_mu_); + // Adds the call (if not already present) to the channel's list of + // calls queued for name resolution. 
+ void MaybeAddCallToResolverQueuedCallsLocked(grpc_call_element* elem) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::resolution_mu_); + + static void RecvTrailingMetadataReadyForConfigSelectorCommitCallback( + void* arg, grpc_error_handle error); + void InjectRecvTrailingMetadataReadyForConfigSelectorCommitCallback( + grpc_transport_stream_op_batch* batch); + + void CreateDynamicCall(grpc_call_element* elem); + + // State for handling deadlines. + // The code in deadline_filter.c requires this to be the first field. + // TODO(roth): This is slightly sub-optimal in that grpc_deadline_state + // and this struct both independently store pointers to the call stack + // and call combiner. If/when we have time, find a way to avoid this + // without breaking the grpc_deadline_state abstraction. + grpc_deadline_state deadline_state_; + + grpc_slice path_; // Request path. + gpr_cycle_counter call_start_time_; + grpc_millis deadline_; + Arena* arena_; + grpc_call_stack* owning_call_; + CallCombiner* call_combiner_; + grpc_call_context_element* call_context_; + + grpc_polling_entity* pollent_ = nullptr; + + grpc_closure resolution_done_closure_; + + // Accessed while holding ClientChannel::resolution_mu_. + bool service_config_applied_ ABSL_GUARDED_BY(&ClientChannel::resolution_mu_) = + false; + bool queued_pending_resolver_result_ + ABSL_GUARDED_BY(&ClientChannel::resolution_mu_) = false; + ClientChannel::ResolverQueuedCall resolver_queued_call_ + ABSL_GUARDED_BY(&ClientChannel::resolution_mu_); + ResolverQueuedCallCanceller* resolver_call_canceller_ + ABSL_GUARDED_BY(&ClientChannel::resolution_mu_) = nullptr; + + grpc_closure* original_recv_trailing_metadata_ready_ = nullptr; + grpc_closure recv_trailing_metadata_ready_; + + RefCountedPtr dynamic_filters_; + RefCountedPtr dynamic_call_; + + // Batches are added to this list when received from above. + // They are removed when we are done handling the batch (i.e., when + // either we have invoked all of the batch's callbacks or we have + // passed the batch down to the LB call and are not intercepting any of + // its callbacks). + grpc_transport_stream_op_batch* pending_batches_[MAX_PENDING_BATCHES] = {}; + + // Set when we get a cancel_stream op. + grpc_error_handle cancel_error_ = GRPC_ERROR_NONE; +}; + +// +// Filter vtable +// + +const grpc_channel_filter ClientChannel::kFilterVtable = { + ClientChannel::CallData::StartTransportStreamOpBatch, + ClientChannel::StartTransportOp, + sizeof(ClientChannel::CallData), + ClientChannel::CallData::Init, + ClientChannel::CallData::SetPollent, + ClientChannel::CallData::Destroy, + sizeof(ClientChannel), + ClientChannel::Init, + ClientChannel::Destroy, + ClientChannel::GetChannelInfo, + "client-channel", +}; + +// +// dynamic termination filter +// + +namespace { + +// Channel arg pointer vtable for GRPC_ARG_CLIENT_CHANNEL. +void* ClientChannelArgCopy(void* p) { return p; } +void ClientChannelArgDestroy(void* /*p*/) {} +int ClientChannelArgCmp(void* p, void* q) { return QsortCompare(p, q); } +const grpc_arg_pointer_vtable kClientChannelArgPointerVtable = { + ClientChannelArgCopy, ClientChannelArgDestroy, ClientChannelArgCmp}; + +// Channel arg pointer vtable for GRPC_ARG_SERVICE_CONFIG_OBJ. 
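// Unlike the GRPC_ARG_CLIENT_CHANNEL vtable above, which treats the pointer
// as unowned, this vtable refs the ServiceConfig on copy and unrefs it on
// destruction, so the channel args share ownership of the object.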
+void* ServiceConfigObjArgCopy(void* p) { + auto* service_config = static_cast(p); + service_config->Ref().release(); + return p; +} +void ServiceConfigObjArgDestroy(void* p) { + auto* service_config = static_cast(p); + service_config->Unref(); +} +int ServiceConfigObjArgCmp(void* p, void* q) { return QsortCompare(p, q); } +const grpc_arg_pointer_vtable kServiceConfigObjArgPointerVtable = { + ServiceConfigObjArgCopy, ServiceConfigObjArgDestroy, + ServiceConfigObjArgCmp}; + +class DynamicTerminationFilter { + public: + class CallData; + + static const grpc_channel_filter kFilterVtable; + + static grpc_error_handle Init(grpc_channel_element* elem, + grpc_channel_element_args* args) { + GPR_ASSERT(args->is_last); + GPR_ASSERT(elem->filter == &kFilterVtable); + new (elem->channel_data) DynamicTerminationFilter(args->channel_args); + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_channel_element* elem) { + auto* chand = static_cast(elem->channel_data); + chand->~DynamicTerminationFilter(); + } + + // Will never be called. + static void StartTransportOp(grpc_channel_element* /*elem*/, + grpc_transport_op* /*op*/) {} + static void GetChannelInfo(grpc_channel_element* /*elem*/, + const grpc_channel_info* /*info*/) {} + + private: + explicit DynamicTerminationFilter(const grpc_channel_args* args) + : chand_(grpc_channel_args_find_pointer( + args, GRPC_ARG_CLIENT_CHANNEL)) {} + + ClientChannel* chand_; +}; + +class DynamicTerminationFilter::CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args) { + new (elem->call_data) CallData(*args); + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* then_schedule_closure) { + auto* calld = static_cast(elem->call_data); + RefCountedPtr subchannel_call; + if (GPR_LIKELY(calld->lb_call_ != nullptr)) { + subchannel_call = calld->lb_call_->subchannel_call(); + } + calld->~CallData(); + if (GPR_LIKELY(subchannel_call != nullptr)) { + subchannel_call->SetAfterCallStackDestroy(then_schedule_closure); + } else { + // TODO(yashkt) : This can potentially be a Closure::Run + ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, GRPC_ERROR_NONE); + } + } + + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + auto* calld = static_cast(elem->call_data); + calld->lb_call_->StartTransportStreamOpBatch(batch); + } + + static void SetPollent(grpc_call_element* elem, + grpc_polling_entity* pollent) { + auto* calld = static_cast(elem->call_data); + auto* chand = static_cast(elem->channel_data); + ClientChannel* client_channel = chand->chand_; + grpc_call_element_args args = {calld->owning_call_, nullptr, + calld->call_context_, calld->path_, + /*start_time=*/0, calld->deadline_, + calld->arena_, calld->call_combiner_}; + auto* service_config_call_data = + static_cast( + calld->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + calld->lb_call_ = client_channel->CreateLoadBalancedCall( + args, pollent, nullptr, + service_config_call_data->call_dispatch_controller(), + /*is_transparent_retry=*/false); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p dynamic_termination_calld=%p: create lb_call=%p", chand, + client_channel, calld->lb_call_.get()); + } + } + + private: + explicit CallData(const grpc_call_element_args& args) + : path_(grpc_slice_ref_internal(args.path)), + deadline_(args.deadline), + 
arena_(args.arena), + owning_call_(args.call_stack), + call_combiner_(args.call_combiner), + call_context_(args.context) {} + + ~CallData() { grpc_slice_unref_internal(path_); } + + grpc_slice path_; // Request path. + grpc_millis deadline_; + Arena* arena_; + grpc_call_stack* owning_call_; + CallCombiner* call_combiner_; + grpc_call_context_element* call_context_; + + OrphanablePtr lb_call_; +}; + +const grpc_channel_filter DynamicTerminationFilter::kFilterVtable = { + DynamicTerminationFilter::CallData::StartTransportStreamOpBatch, + DynamicTerminationFilter::StartTransportOp, + sizeof(DynamicTerminationFilter::CallData), + DynamicTerminationFilter::CallData::Init, + DynamicTerminationFilter::CallData::SetPollent, + DynamicTerminationFilter::CallData::Destroy, + sizeof(DynamicTerminationFilter), + DynamicTerminationFilter::Init, + DynamicTerminationFilter::Destroy, + DynamicTerminationFilter::GetChannelInfo, + "dynamic_filter_termination", +}; + +} // namespace + +// +// ClientChannel::ResolverResultHandler +// + +class ClientChannel::ResolverResultHandler : public Resolver::ResultHandler { + public: + explicit ResolverResultHandler(ClientChannel* chand) : chand_(chand) { + GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "ResolverResultHandler"); + } + + ~ResolverResultHandler() override { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: resolver shutdown complete", chand_); + } + GRPC_CHANNEL_STACK_UNREF(chand_->owning_stack_, "ResolverResultHandler"); + } + + void ReturnResult(Resolver::Result result) override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + chand_->OnResolverResultChangedLocked(std::move(result)); + } + + void ReturnError(grpc_error_handle error) override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + chand_->OnResolverErrorLocked(error); + } + + private: + ClientChannel* chand_; +}; + +// +// ClientChannel::SubchannelWrapper +// + +// This class is a wrapper for Subchannel that hides details of the +// channel's implementation (such as the health check service name and +// connected subchannel) from the LB policy API. +// +// Note that no synchronization is needed here, because even if the +// underlying subchannel is shared between channels, this wrapper will only +// be used within one channel, so it will always be synchronized by the +// control plane work_serializer. +class ClientChannel::SubchannelWrapper : public SubchannelInterface { + public: + SubchannelWrapper(ClientChannel* chand, RefCountedPtr subchannel, + absl::optional health_check_service_name) + : SubchannelInterface( + GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace) + ? 
"SubchannelWrapper" + : nullptr), + chand_(chand), + subchannel_(std::move(subchannel)), + health_check_service_name_(std::move(health_check_service_name)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p: creating subchannel wrapper %p for subchannel %p", + chand, this, subchannel_.get()); + } + GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "SubchannelWrapper"); + if (chand_->channelz_node_ != nullptr) { + auto* subchannel_node = subchannel_->channelz_node(); + if (subchannel_node != nullptr) { + auto it = chand_->subchannel_refcount_map_.find(subchannel_.get()); + if (it == chand_->subchannel_refcount_map_.end()) { + chand_->channelz_node_->AddChildSubchannel(subchannel_node->uuid()); + it = chand_->subchannel_refcount_map_.emplace(subchannel_.get(), 0) + .first; + } + ++it->second; + } + } + chand_->subchannel_wrappers_.insert(this); + } + + ~SubchannelWrapper() override { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p: destroying subchannel wrapper %p for subchannel %p", + chand_, this, subchannel_.get()); + } + chand_->subchannel_wrappers_.erase(this); + if (chand_->channelz_node_ != nullptr) { + auto* subchannel_node = subchannel_->channelz_node(); + if (subchannel_node != nullptr) { + auto it = chand_->subchannel_refcount_map_.find(subchannel_.get()); + GPR_ASSERT(it != chand_->subchannel_refcount_map_.end()); + --it->second; + if (it->second == 0) { + chand_->channelz_node_->RemoveChildSubchannel( + subchannel_node->uuid()); + chand_->subchannel_refcount_map_.erase(it); + } + } + } + GRPC_CHANNEL_STACK_UNREF(chand_->owning_stack_, "SubchannelWrapper"); + } + + grpc_connectivity_state CheckConnectivityState() override { + return subchannel_->CheckConnectivityState(health_check_service_name_); + } + + void WatchConnectivityState( + grpc_connectivity_state initial_state, + std::unique_ptr watcher) override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + auto& watcher_wrapper = watcher_map_[watcher.get()]; + GPR_ASSERT(watcher_wrapper == nullptr); + watcher_wrapper = new WatcherWrapper(std::move(watcher), + Ref(DEBUG_LOCATION, "WatcherWrapper"), + initial_state); + subchannel_->WatchConnectivityState( + initial_state, health_check_service_name_, + RefCountedPtr( + watcher_wrapper)); + } + + void CancelConnectivityStateWatch(ConnectivityStateWatcherInterface* watcher) + override ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + auto it = watcher_map_.find(watcher); + GPR_ASSERT(it != watcher_map_.end()); + subchannel_->CancelConnectivityStateWatch(health_check_service_name_, + it->second); + watcher_map_.erase(it); + } + + RefCountedPtr connected_subchannel() const { + return subchannel_->connected_subchannel(); + } + + void AttemptToConnect() override { subchannel_->AttemptToConnect(); } + + void ResetBackoff() override { subchannel_->ResetBackoff(); } + + const grpc_channel_args* channel_args() override { + return subchannel_->channel_args(); + } + + void ThrottleKeepaliveTime(int new_keepalive_time) { + subchannel_->ThrottleKeepaliveTime(new_keepalive_time); + } + + private: + // Subchannel and SubchannelInterface have different interfaces for + // their respective ConnectivityStateWatcherInterface classes. + // The one in Subchannel updates the ConnectedSubchannel along with + // the state, whereas the one in SubchannelInterface does not expose + // the ConnectedSubchannel. + // + // This wrapper provides a bridge between the two. 
It implements + // Subchannel::ConnectivityStateWatcherInterface and wraps + // the instance of SubchannelInterface::ConnectivityStateWatcherInterface + // that was passed in by the LB policy. We pass an instance of this + // class to the underlying Subchannel, and when we get updates from + // the subchannel, we pass those on to the wrapped watcher to return + // the update to the LB policy. This allows us to set the connected + // subchannel before passing the result back to the LB policy. + class WatcherWrapper : public Subchannel::ConnectivityStateWatcherInterface { + public: + WatcherWrapper( + std::unique_ptr + watcher, + RefCountedPtr parent, + grpc_connectivity_state initial_state) + : watcher_(std::move(watcher)), + parent_(std::move(parent)), + last_seen_state_(initial_state) {} + + ~WatcherWrapper() override { + auto* parent = parent_.release(); // ref owned by lambda + parent->chand_->work_serializer_->Run( + [parent]() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(parent_->chand_->work_serializer_) { + parent->Unref(DEBUG_LOCATION, "WatcherWrapper"); + }, + DEBUG_LOCATION); + } + + void OnConnectivityStateChange() override { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p: connectivity change for subchannel wrapper %p " + "subchannel %p; hopping into work_serializer", + parent_->chand_, parent_.get(), parent_->subchannel_.get()); + } + Ref().release(); // ref owned by lambda + parent_->chand_->work_serializer_->Run( + [this]() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(parent_->chand_->work_serializer_) { + ApplyUpdateInControlPlaneWorkSerializer(); + Unref(); + }, + DEBUG_LOCATION); + } + + grpc_pollset_set* interested_parties() override { + SubchannelInterface::ConnectivityStateWatcherInterface* watcher = + watcher_.get(); + if (watcher_ == nullptr) watcher = replacement_->watcher_.get(); + return watcher->interested_parties(); + } + + WatcherWrapper* MakeReplacement() { + auto* replacement = + new WatcherWrapper(std::move(watcher_), parent_, last_seen_state_); + replacement_ = replacement; + return replacement; + } + + grpc_connectivity_state last_seen_state() const { return last_seen_state_; } + + private: + void ApplyUpdateInControlPlaneWorkSerializer() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(parent_->chand_->work_serializer_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p: processing connectivity change in work serializer " + "for subchannel wrapper %p subchannel %p " + "watcher=%p", + parent_->chand_, parent_.get(), parent_->subchannel_.get(), + watcher_.get()); + } + ConnectivityStateChange state_change = PopConnectivityStateChange(); + absl::optional keepalive_throttling = + state_change.status.GetPayload(kKeepaliveThrottlingKey); + if (keepalive_throttling.has_value()) { + int new_keepalive_time = -1; + if (absl::SimpleAtoi(std::string(keepalive_throttling.value()), + &new_keepalive_time)) { + if (new_keepalive_time > parent_->chand_->keepalive_time_) { + parent_->chand_->keepalive_time_ = new_keepalive_time; + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: throttling keepalive time to %d", + parent_->chand_, parent_->chand_->keepalive_time_); + } + // Propagate the new keepalive time to all subchannels. This is so + // that new transports created by any subchannel (and not just the + // subchannel that received the GOAWAY), use the new keepalive time. 
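OnConnectivityStateChange() above takes an extra ref before scheduling a closure on the control-plane WorkSerializer and releases it inside that closure, so the wrapper cannot be destroyed while the update is still queued. A simplified sketch of the same idiom, expressed with std::shared_ptr and a toy serializer rather than gRPC's intrusive ref counting (MiniSerializer and Watcher are illustrative names):

#include <functional>
#include <memory>
#include <queue>

// Trivial stand-in for a work serializer: closures run one at a time, in order.
class MiniSerializer {
 public:
  void Run(std::function<void()> fn) { queue_.push(std::move(fn)); }
  void Drain() {
    while (!queue_.empty()) {
      queue_.front()();
      queue_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> queue_;
};

// Keeps itself alive until the scheduled closure has run, mirroring the
// Ref().release() / Unref() pairing in the code above.
class Watcher : public std::enable_shared_from_this<Watcher> {
 public:
  explicit Watcher(MiniSerializer* serializer) : serializer_(serializer) {}

  void OnStateChange(int new_state) {
    // Capture a strong ref so the watcher cannot be destroyed while the
    // closure is still queued on the serializer.
    auto self = shared_from_this();
    serializer_->Run([self, new_state]() { self->ApplyUpdate(new_state); });
  }

 private:
  void ApplyUpdate(int new_state) { last_seen_state_ = new_state; }

  MiniSerializer* serializer_;
  int last_seen_state_ = 0;
};

For shared_from_this() to be valid, a Watcher must be owned by a std::shared_ptr, e.g. created with std::make_shared<Watcher>(&serializer).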
+ for (auto* subchannel_wrapper : + parent_->chand_->subchannel_wrappers_) { + subchannel_wrapper->ThrottleKeepaliveTime(new_keepalive_time); + } + } + } else { + gpr_log(GPR_ERROR, "chand=%p: Illegal keepalive throttling value %s", + parent_->chand_, + std::string(keepalive_throttling.value()).c_str()); + } + } + // Ignore update if the parent WatcherWrapper has been replaced + // since this callback was scheduled. + if (watcher_ != nullptr) { + last_seen_state_ = state_change.state; + watcher_->OnConnectivityStateChange(state_change.state); + } + } + + std::unique_ptr + watcher_; + RefCountedPtr parent_; + grpc_connectivity_state last_seen_state_; + WatcherWrapper* replacement_ = nullptr; + }; + + ClientChannel* chand_; + RefCountedPtr subchannel_; + absl::optional health_check_service_name_; + // Maps from the address of the watcher passed to us by the LB policy + // to the address of the WrapperWatcher that we passed to the underlying + // subchannel. This is needed so that when the LB policy calls + // CancelConnectivityStateWatch() with its watcher, we know the + // corresponding WrapperWatcher to cancel on the underlying subchannel. + std::map watcher_map_ + ABSL_GUARDED_BY(&ClientChannel::work_serializer_); +}; + +// +// ClientChannel::ExternalConnectivityWatcher +// + +ClientChannel::ExternalConnectivityWatcher::ExternalConnectivityWatcher( + ClientChannel* chand, grpc_polling_entity pollent, + grpc_connectivity_state* state, grpc_closure* on_complete, + grpc_closure* watcher_timer_init) + : chand_(chand), + pollent_(pollent), + initial_state_(*state), + state_(state), + on_complete_(on_complete), + watcher_timer_init_(watcher_timer_init) { + grpc_polling_entity_add_to_pollset_set(&pollent_, + chand_->interested_parties_); + GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "ExternalConnectivityWatcher"); + { + MutexLock lock(&chand_->external_watchers_mu_); + // Will be deleted when the watch is complete. + GPR_ASSERT(chand->external_watchers_[on_complete] == nullptr); + // Store a ref to the watcher in the external_watchers_ map. + chand->external_watchers_[on_complete] = + Ref(DEBUG_LOCATION, "AddWatcherToExternalWatchersMapLocked"); + } + // Pass the ref from creating the object to Start(). + chand_->work_serializer_->Run( + [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + // The ref is passed to AddWatcherLocked(). + AddWatcherLocked(); + }, + DEBUG_LOCATION); +} + +ClientChannel::ExternalConnectivityWatcher::~ExternalConnectivityWatcher() { + grpc_polling_entity_del_from_pollset_set(&pollent_, + chand_->interested_parties_); + GRPC_CHANNEL_STACK_UNREF(chand_->owning_stack_, + "ExternalConnectivityWatcher"); +} + +void ClientChannel::ExternalConnectivityWatcher:: + RemoveWatcherFromExternalWatchersMap(ClientChannel* chand, + grpc_closure* on_complete, + bool cancel) { + RefCountedPtr watcher; + { + MutexLock lock(&chand->external_watchers_mu_); + auto it = chand->external_watchers_.find(on_complete); + if (it != chand->external_watchers_.end()) { + watcher = std::move(it->second); + chand->external_watchers_.erase(it); + } + } + // watcher->Cancel() will hop into the WorkSerializer, so we have to unlock + // the mutex before calling it. 
+ if (watcher != nullptr && cancel) watcher->Cancel(); +} + +void ClientChannel::ExternalConnectivityWatcher::Notify( + grpc_connectivity_state state, const absl::Status& /* status */) { + bool done = false; + if (!done_.compare_exchange_strong(done, true, std::memory_order_relaxed, + std::memory_order_relaxed)) { + return; // Already done. + } + // Remove external watcher. + ExternalConnectivityWatcher::RemoveWatcherFromExternalWatchersMap( + chand_, on_complete_, /*cancel=*/false); + // Report new state to the user. + *state_ = state; + ExecCtx::Run(DEBUG_LOCATION, on_complete_, GRPC_ERROR_NONE); + // Hop back into the work_serializer to clean up. + // Not needed in state SHUTDOWN, because the tracker will + // automatically remove all watchers in that case. + if (state != GRPC_CHANNEL_SHUTDOWN) { + chand_->work_serializer_->Run( + [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + RemoveWatcherLocked(); + }, + DEBUG_LOCATION); + } +} + +void ClientChannel::ExternalConnectivityWatcher::Cancel() { + bool done = false; + if (!done_.compare_exchange_strong(done, true, std::memory_order_relaxed, + std::memory_order_relaxed)) { + return; // Already done. + } + ExecCtx::Run(DEBUG_LOCATION, on_complete_, GRPC_ERROR_CANCELLED); + // Hop back into the work_serializer to clean up. + chand_->work_serializer_->Run( + [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + RemoveWatcherLocked(); + }, + DEBUG_LOCATION); +} + +void ClientChannel::ExternalConnectivityWatcher::AddWatcherLocked() { + Closure::Run(DEBUG_LOCATION, watcher_timer_init_, GRPC_ERROR_NONE); + // Add new watcher. Pass the ref of the object from creation to OrphanablePtr. + chand_->state_tracker_.AddWatcher( + initial_state_, OrphanablePtr(this)); +} + +void ClientChannel::ExternalConnectivityWatcher::RemoveWatcherLocked() { + chand_->state_tracker_.RemoveWatcher(this); +} + +// +// ClientChannel::ConnectivityWatcherAdder +// + +class ClientChannel::ConnectivityWatcherAdder { + public: + ConnectivityWatcherAdder( + ClientChannel* chand, grpc_connectivity_state initial_state, + OrphanablePtr watcher) + : chand_(chand), + initial_state_(initial_state), + watcher_(std::move(watcher)) { + GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "ConnectivityWatcherAdder"); + chand_->work_serializer_->Run( + [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + AddWatcherLocked(); + }, + DEBUG_LOCATION); + } + + private: + void AddWatcherLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + chand_->state_tracker_.AddWatcher(initial_state_, std::move(watcher_)); + GRPC_CHANNEL_STACK_UNREF(chand_->owning_stack_, "ConnectivityWatcherAdder"); + delete this; + } + + ClientChannel* chand_; + grpc_connectivity_state initial_state_; + OrphanablePtr watcher_; +}; + +// +// ClientChannel::ConnectivityWatcherRemover +// + +class ClientChannel::ConnectivityWatcherRemover { + public: + ConnectivityWatcherRemover(ClientChannel* chand, + AsyncConnectivityStateWatcherInterface* watcher) + : chand_(chand), watcher_(watcher) { + GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "ConnectivityWatcherRemover"); + chand_->work_serializer_->Run( + [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + RemoveWatcherLocked(); + }, + DEBUG_LOCATION); + } + + private: + void RemoveWatcherLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + chand_->state_tracker_.RemoveWatcher(watcher_); + GRPC_CHANNEL_STACK_UNREF(chand_->owning_stack_, + "ConnectivityWatcherRemover"); + delete this; 
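Notify() and Cancel() above race against each other and also coordinate with the external-watchers map: an atomic compare-and-exchange lets whichever caller arrives first claim the completion, and the map entry is removed under the mutex while Cancel() itself runs only after the lock is released, since it hops back into the WorkSerializer. A standalone sketch of both pieces, with illustrative names and only the standard library:

#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <mutex>

class OneShotWatcher {
 public:
  explicit OneShotWatcher(std::function<void(bool /*cancelled*/)> on_done)
      : on_done_(std::move(on_done)) {}

  // Whichever of Notify()/Cancel() wins the compare-and-exchange fires the
  // callback; the loser is a no-op, so the completion runs exactly once.
  void Notify() { Fire(false); }
  void Cancel() { Fire(true); }

 private:
  void Fire(bool cancelled) {
    bool expected = false;
    if (!done_.compare_exchange_strong(expected, true,
                                       std::memory_order_relaxed,
                                       std::memory_order_relaxed)) {
      return;  // Already completed or cancelled.
    }
    on_done_(cancelled);
  }

  std::atomic<bool> done_{false};
  std::function<void(bool)> on_done_;
};

class WatcherTable {
 public:
  void Add(void* key, std::shared_ptr<OneShotWatcher> watcher) {
    std::lock_guard<std::mutex> lock(mu_);
    watchers_[key] = std::move(watcher);
  }

  // The entry is removed while holding the mutex, but Cancel() is invoked
  // only after the lock is released, since it may re-enter code that must
  // not run under mu_.
  void RemoveAndMaybeCancel(void* key, bool cancel) {
    std::shared_ptr<OneShotWatcher> removed;
    {
      std::lock_guard<std::mutex> lock(mu_);
      auto it = watchers_.find(key);
      if (it != watchers_.end()) {
        removed = std::move(it->second);
        watchers_.erase(it);
      }
    }
    if (removed != nullptr && cancel) removed->Cancel();
  }

 private:
  std::mutex mu_;
  std::map<void*, std::shared_ptr<OneShotWatcher>> watchers_;
};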
+ } + + ClientChannel* chand_; + AsyncConnectivityStateWatcherInterface* watcher_; +}; + +// +// ClientChannel::ClientChannelControlHelper +// + +class ClientChannel::ClientChannelControlHelper + : public LoadBalancingPolicy::ChannelControlHelper { + public: + explicit ClientChannelControlHelper(ClientChannel* chand) : chand_(chand) { + GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "ClientChannelControlHelper"); + } + + ~ClientChannelControlHelper() override { + GRPC_CHANNEL_STACK_UNREF(chand_->owning_stack_, + "ClientChannelControlHelper"); + } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + if (chand_->resolver_ == nullptr) return nullptr; // Shutting down. + // Determine health check service name. + absl::optional health_check_service_name; + const char* health_check_service_name_arg = grpc_channel_args_find_string( + &args, GRPC_ARG_HEALTH_CHECK_SERVICE_NAME); + if (health_check_service_name_arg != nullptr) { + bool inhibit_health_checking = grpc_channel_args_find_bool( + &args, GRPC_ARG_INHIBIT_HEALTH_CHECKING, false); + if (!inhibit_health_checking) { + health_check_service_name = health_check_service_name_arg; + } + } + // Construct channel args for subchannel. + // Remove channel args that should not affect subchannel uniqueness. + absl::InlinedVector args_to_remove = { + GRPC_ARG_HEALTH_CHECK_SERVICE_NAME, + GRPC_ARG_INHIBIT_HEALTH_CHECKING, + GRPC_ARG_CHANNELZ_CHANNEL_NODE, + }; + // Add channel args needed for the subchannel. + absl::InlinedVector args_to_add = { + SubchannelPoolInterface::CreateChannelArg( + chand_->subchannel_pool_.get()), + }; + // Check if default authority arg is already set. + const char* default_authority = + grpc_channel_args_find_string(&args, GRPC_ARG_DEFAULT_AUTHORITY); + // Add args from subchannel address. + if (address.args() != nullptr) { + for (size_t j = 0; j < address.args()->num_args; ++j) { + grpc_arg& arg = address.args()->args[j]; + if (strcmp(arg.key, GRPC_ARG_DEFAULT_AUTHORITY) == 0) { + // Don't add default authority arg from subchannel address if + // it's already set at the channel level -- the value from the + // application should take precedence over what is set by the + // resolver. + if (default_authority != nullptr) continue; + default_authority = arg.value.string; + } + args_to_add.emplace_back(arg); + } + } + // If we haven't already set the default authority arg, add it from + // the channel. + if (default_authority == nullptr) { + // Remove it, just in case it's actually present but is the wrong type. + args_to_remove.push_back(GRPC_ARG_DEFAULT_AUTHORITY); + args_to_add.push_back(grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast(chand_->default_authority_.c_str()))); + } + grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( + &args, args_to_remove.data(), args_to_remove.size(), args_to_add.data(), + args_to_add.size()); + // Create subchannel. + RefCountedPtr subchannel = + chand_->client_channel_factory_->CreateSubchannel(address.address(), + new_args); + grpc_channel_args_destroy(new_args); + if (subchannel == nullptr) return nullptr; + // Make sure the subchannel has updated keepalive time. + subchannel->ThrottleKeepaliveTime(chand_->keepalive_time_); + // Create and return wrapper for the subchannel. 
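CreateSubchannel() above builds the subchannel's channel args by removing keys that must not affect subchannel uniqueness, appending new ones, and letting a channel-level default authority take precedence over the per-address value from the resolver. A rough sketch of that merge logic over plain string maps (BuildSubchannelArgs is an illustrative helper, not a gRPC API):

#include <map>
#include <string>
#include <vector>

using ArgMap = std::map<std::string, std::string>;

// Per-subchannel args: drop keys in `to_remove`, merge the per-address args,
// then append `to_add`. A value already present at the channel level (e.g. a
// default authority set by the application) wins over the address-level one.
ArgMap BuildSubchannelArgs(const ArgMap& channel_args,
                           const ArgMap& address_args,
                           const std::vector<std::string>& to_remove,
                           const ArgMap& to_add) {
  ArgMap result = channel_args;
  for (const auto& key : to_remove) result.erase(key);
  for (const auto& kv : address_args) {
    result.insert(kv);  // insert() keeps an existing (channel-level) key.
  }
  for (const auto& kv : to_add) result[kv.first] = kv.second;
  return result;
}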
+ return MakeRefCounted( + chand_, std::move(subchannel), std::move(health_check_service_name)); + } + + void UpdateState( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + if (chand_->resolver_ == nullptr) return; // Shutting down. + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + const char* extra = chand_->disconnect_error_ == GRPC_ERROR_NONE + ? "" + : " (ignoring -- channel shutting down)"; + gpr_log(GPR_INFO, "chand=%p: update: state=%s status=(%s) picker=%p%s", + chand_, ConnectivityStateName(state), status.ToString().c_str(), + picker.get(), extra); + } + // Do update only if not shutting down. + if (chand_->disconnect_error_ == GRPC_ERROR_NONE) { + chand_->UpdateStateAndPickerLocked(state, status, "helper", + std::move(picker)); + } + } + + void RequestReresolution() override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + if (chand_->resolver_ == nullptr) return; // Shutting down. + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: started name re-resolving", chand_); + } + chand_->resolver_->RequestReresolutionLocked(); + } + + absl::string_view GetAuthority() override { + return chand_->default_authority_; + } + + void AddTraceEvent(TraceSeverity severity, absl::string_view message) override + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand_->work_serializer_) { + if (chand_->resolver_ == nullptr) return; // Shutting down. + if (chand_->channelz_node_ != nullptr) { + chand_->channelz_node_->AddTraceEvent( + ConvertSeverityEnum(severity), + grpc_slice_from_copied_buffer(message.data(), message.size())); + } + } + + private: + static channelz::ChannelTrace::Severity ConvertSeverityEnum( + TraceSeverity severity) { + if (severity == TRACE_INFO) return channelz::ChannelTrace::Info; + if (severity == TRACE_WARNING) return channelz::ChannelTrace::Warning; + return channelz::ChannelTrace::Error; + } + + ClientChannel* chand_; +}; + +// +// ClientChannel implementation +// + +ClientChannel* ClientChannel::GetFromChannel(grpc_channel* channel) { + grpc_channel_element* elem = + grpc_channel_stack_last_element(grpc_channel_get_channel_stack(channel)); + if (elem->filter != &kFilterVtable) return nullptr; + return static_cast(elem->channel_data); +} + +grpc_error_handle ClientChannel::Init(grpc_channel_element* elem, + grpc_channel_element_args* args) { + GPR_ASSERT(args->is_last); + GPR_ASSERT(elem->filter == &kFilterVtable); + grpc_error_handle error = GRPC_ERROR_NONE; + new (elem->channel_data) ClientChannel(args, &error); + return error; +} + +void ClientChannel::Destroy(grpc_channel_element* elem) { + ClientChannel* chand = static_cast(elem->channel_data); + chand->~ClientChannel(); +} + +namespace { + +RefCountedPtr GetSubchannelPool( + const grpc_channel_args* args) { + const bool use_local_subchannel_pool = grpc_channel_args_find_bool( + args, GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, false); + if (use_local_subchannel_pool) { + return MakeRefCounted(); + } + return GlobalSubchannelPool::instance(); +} + +channelz::ChannelNode* GetChannelzNode(const grpc_channel_args* args) { + return grpc_channel_args_find_pointer( + args, GRPC_ARG_CHANNELZ_CHANNEL_NODE); +} + +} // namespace + +ClientChannel::ClientChannel(grpc_channel_element_args* args, + grpc_error_handle* error) + : deadline_checking_enabled_( + grpc_deadline_checking_enabled(args->channel_args)), + owning_stack_(args->channel_stack), + 
client_channel_factory_( + ClientChannelFactory::GetFromChannelArgs(args->channel_args)), + channelz_node_(GetChannelzNode(args->channel_args)), + interested_parties_(grpc_pollset_set_create()), + work_serializer_(std::make_shared()), + state_tracker_("client_channel", GRPC_CHANNEL_IDLE), + subchannel_pool_(GetSubchannelPool(args->channel_args)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: creating client_channel for channel stack %p", + this, owning_stack_); + } + // Start backup polling. + grpc_client_channel_start_backup_polling(interested_parties_); + // Check client channel factory. + if (client_channel_factory_ == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Missing client channel factory in args for client channel filter"); + return; + } + // Get default service config. If none is specified via the client API, + // we use an empty config. + const char* service_config_json = grpc_channel_args_find_string( + args->channel_args, GRPC_ARG_SERVICE_CONFIG); + if (service_config_json == nullptr) service_config_json = "{}"; + *error = GRPC_ERROR_NONE; + default_service_config_ = + ServiceConfig::Create(args->channel_args, service_config_json, error); + if (*error != GRPC_ERROR_NONE) { + default_service_config_.reset(); + return; + } + // Get URI to resolve, using proxy mapper if needed. + const char* server_uri = + grpc_channel_args_find_string(args->channel_args, GRPC_ARG_SERVER_URI); + if (server_uri == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "target URI channel arg missing or wrong type in client channel " + "filter"); + return; + } + uri_to_resolve_ = server_uri; + char* proxy_name = nullptr; + grpc_channel_args* new_args = nullptr; + ProxyMapperRegistry::MapName(server_uri, args->channel_args, &proxy_name, + &new_args); + if (proxy_name != nullptr) { + uri_to_resolve_ = proxy_name; + gpr_free(proxy_name); + } + // Make sure the URI to resolve is valid, so that we know that + // resolver creation will succeed later. + if (!ResolverRegistry::IsValidTarget(uri_to_resolve_)) { + *error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("the target uri is not valid: ", uri_to_resolve_.c_str())); + return; + } + // Strip out service config channel arg, so that it doesn't affect + // subchannel uniqueness when the args flow down to that layer. + const char* arg_to_remove = GRPC_ARG_SERVICE_CONFIG; + channel_args_ = grpc_channel_args_copy_and_remove( + new_args != nullptr ? new_args : args->channel_args, &arg_to_remove, 1); + grpc_channel_args_destroy(new_args); + // Set initial keepalive time. + keepalive_time_ = grpc_channel_args_find_integer( + channel_args_, GRPC_ARG_KEEPALIVE_TIME_MS, + {-1 /* default value, unset */, 1, INT_MAX}); + // Set default authority. + const char* default_authority = + grpc_channel_args_find_string(channel_args_, GRPC_ARG_DEFAULT_AUTHORITY); + if (default_authority == nullptr) { + default_authority_ = ResolverRegistry::GetDefaultAuthority(server_uri); + } else { + default_authority_ = default_authority; + } + // Success. + *error = GRPC_ERROR_NONE; +} + +ClientChannel::~ClientChannel() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: destroying channel", this); + } + DestroyResolverAndLbPolicyLocked(); + grpc_channel_args_destroy(channel_args_); + GRPC_ERROR_UNREF(resolver_transient_failure_error_); + // Stop backup polling. 
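The constructor above reads GRPC_ARG_KEEPALIVE_TIME_MS with a default of -1 and bounds of [1, INT_MAX], and similar lookups with defaults appear throughout. A tiny sketch of that kind of bounded lookup over a plain map (FindIntegerOption is an illustrative stand-in for grpc_channel_args_find_integer(), not the real API):

#include <algorithm>
#include <map>
#include <string>

// Returns `default_value` when the key is absent; otherwise clamps the stored
// value into [min_value, max_value].
int FindIntegerOption(const std::map<std::string, int>& options,
                      const std::string& key, int default_value,
                      int min_value, int max_value) {
  auto it = options.find(key);
  if (it == options.end()) return default_value;
  return std::min(std::max(it->second, min_value), max_value);
}

// Usage in the spirit of the keepalive lookup above:
//   FindIntegerOption(options, "keepalive_time_ms",
//                     /*default=*/-1, /*min=*/1, /*max=*/INT_MAX);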
+ grpc_client_channel_stop_backup_polling(interested_parties_); + grpc_pollset_set_destroy(interested_parties_); + GRPC_ERROR_UNREF(disconnect_error_); +} + +OrphanablePtr +ClientChannel::CreateLoadBalancedCall( + const grpc_call_element_args& args, grpc_polling_entity* pollent, + grpc_closure* on_call_destruction_complete, + ConfigSelector::CallDispatchController* call_dispatch_controller, + bool is_transparent_retry) { + return OrphanablePtr(args.arena->New( + this, args, pollent, on_call_destruction_complete, + call_dispatch_controller, is_transparent_retry)); +} + +namespace { + +RefCountedPtr ChooseLbPolicy( + const Resolver::Result& resolver_result, + const internal::ClientChannelGlobalParsedConfig* parsed_service_config) { + // Prefer the LB policy config found in the service config. + if (parsed_service_config->parsed_lb_config() != nullptr) { + return parsed_service_config->parsed_lb_config(); + } + // Try the deprecated LB policy name from the service config. + // If not, try the setting from channel args. + const char* policy_name = nullptr; + if (!parsed_service_config->parsed_deprecated_lb_policy().empty()) { + policy_name = parsed_service_config->parsed_deprecated_lb_policy().c_str(); + } else { + policy_name = grpc_channel_args_find_string(resolver_result.args, + GRPC_ARG_LB_POLICY_NAME); + } + // Use pick_first if nothing was specified and we didn't select grpclb + // above. + if (policy_name == nullptr) policy_name = "pick_first"; + // Now that we have the policy name, construct an empty config for it. + Json config_json = Json::Array{Json::Object{ + {policy_name, Json::Object{}}, + }}; + grpc_error_handle parse_error = GRPC_ERROR_NONE; + auto lb_policy_config = LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + config_json, &parse_error); + // The policy name came from one of three places: + // - The deprecated loadBalancingPolicy field in the service config, + // in which case the code in ClientChannelServiceConfigParser + // already verified that the policy does not require a config. + // - One of the hard-coded values here, all of which are known to not + // require a config. + // - A channel arg, in which case the application did something that + // is a misuse of our API. + // In the first two cases, these assertions will always be true. In + // the last case, this is probably fine for now. + // TODO(roth): If the last case becomes a problem, add better error + // handling here. + GPR_ASSERT(lb_policy_config != nullptr); + GPR_ASSERT(parse_error == GRPC_ERROR_NONE); + return lb_policy_config; +} + +} // namespace + +void ClientChannel::OnResolverResultChangedLocked(Resolver::Result result) { + // Handle race conditions. + if (resolver_ == nullptr) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: got resolver result", this); + } + // We only want to trace the address resolution in the follow cases: + // (a) Address resolution resulted in service config change. + // (b) Address resolution that causes number of backends to go from + // zero to non-zero. + // (c) Address resolution that causes number of backends to go from + // non-zero to zero. + // (d) Address resolution that causes a new LB policy to be created. + // + // We track a list of strings to eventually be concatenated and traced. 
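ChooseLbPolicy() above resolves the policy in a fixed order: the parsed loadBalancingConfig from the service config, then the deprecated loadBalancingPolicy name, then the GRPC_ARG_LB_POLICY_NAME channel arg, and finally "pick_first". A compact sketch of the name-selection part (the parsed config, when present, short-circuits this entirely; ChooseLbPolicyName is an illustrative helper):

#include <string>

// Name precedence when no parsed loadBalancingConfig is present:
// deprecated service-config field, then channel arg, then "pick_first".
std::string ChooseLbPolicyName(const std::string& deprecated_service_config_name,
                               const char* channel_arg_name) {
  if (!deprecated_service_config_name.empty()) {
    return deprecated_service_config_name;
  }
  if (channel_arg_name != nullptr) return channel_arg_name;
  return "pick_first";
}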
+ absl::InlinedVector trace_strings; + if (result.addresses.empty() && previous_resolution_contained_addresses_) { + trace_strings.push_back("Address list became empty"); + } else if (!result.addresses.empty() && + !previous_resolution_contained_addresses_) { + trace_strings.push_back("Address list became non-empty"); + } + previous_resolution_contained_addresses_ = !result.addresses.empty(); + std::string service_config_error_string_storage; + if (result.service_config_error != GRPC_ERROR_NONE) { + service_config_error_string_storage = + grpc_error_std_string(result.service_config_error); + trace_strings.push_back(service_config_error_string_storage.c_str()); + } + // Choose the service config. + RefCountedPtr service_config; + RefCountedPtr config_selector; + if (result.service_config_error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: resolver returned service config error: %s", + this, grpc_error_std_string(result.service_config_error).c_str()); + } + // If the service config was invalid, then fallback to the + // previously returned service config. + if (saved_service_config_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p: resolver returned invalid service config. " + "Continuing to use previous service config.", + this); + } + service_config = saved_service_config_; + config_selector = saved_config_selector_; + } else { + // We received an invalid service config and we don't have a + // previous service config to fall back to. Put the channel into + // TRANSIENT_FAILURE. + OnResolverErrorLocked(GRPC_ERROR_REF(result.service_config_error)); + trace_strings.push_back("no valid service config"); + } + } else if (result.service_config == nullptr) { + // Resolver did not return any service config. + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p: resolver returned no service config. Using default " + "service config for channel.", + this); + } + service_config = default_service_config_; + } else { + // Use ServiceConfig and ConfigSelector returned by resolver. + service_config = result.service_config; + config_selector = ConfigSelector::GetFromChannelArgs(*result.args); + } + if (service_config != nullptr) { + // Extract global config for client channel. + const internal::ClientChannelGlobalParsedConfig* parsed_service_config = + static_cast( + service_config->GetGlobalParsedConfig( + internal::ClientChannelServiceConfigParser::ParserIndex())); + // Choose LB policy config. + RefCountedPtr lb_policy_config = + ChooseLbPolicy(result, parsed_service_config); + // Check if the ServiceConfig has changed. + const bool service_config_changed = + saved_service_config_ == nullptr || + service_config->json_string() != saved_service_config_->json_string(); + // Check if the ConfigSelector has changed. + const bool config_selector_changed = !ConfigSelector::Equals( + saved_config_selector_.get(), config_selector.get()); + // If either has changed, apply the global parameters now. + if (service_config_changed || config_selector_changed) { + // Update service config in control plane. + UpdateServiceConfigInControlPlaneLocked(std::move(service_config), + std::move(config_selector), + lb_policy_config->name()); + } else if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: service config not changed", this); + } + // Create or update LB policy, as needed. 
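The branches above pick which service config to use: on a resolver error the previously saved config is reused if one exists (otherwise the channel reports the failure), a missing config falls back to the channel default, and otherwise the resolver-supplied config wins; a change is then detected by comparing the serialized JSON. A simplified sketch of that decision, with illustrative types:

#include <memory>
#include <string>

struct ServiceConfig {
  std::string json;
};

struct ResolverResult {
  bool has_error = false;
  std::shared_ptr<ServiceConfig> service_config;  // may be null
};

// Mirrors the three branches above. A null return on error means there is no
// saved config to fall back to, so the caller reports the failure.
std::shared_ptr<ServiceConfig> ChooseServiceConfig(
    const ResolverResult& result,
    const std::shared_ptr<ServiceConfig>& saved,
    const std::shared_ptr<ServiceConfig>& channel_default) {
  if (result.has_error) return saved;
  if (result.service_config == nullptr) return channel_default;
  return result.service_config;
}

// A config "changed" when its serialized form differs, which is how the code
// above decides whether to push a new config to the data plane.
bool ConfigChanged(const std::shared_ptr<ServiceConfig>& old_config,
                   const std::shared_ptr<ServiceConfig>& new_config) {
  if (old_config == nullptr || new_config == nullptr) {
    return old_config != new_config;
  }
  return old_config->json != new_config->json;
}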
+ CreateOrUpdateLbPolicyLocked( + std::move(lb_policy_config), + parsed_service_config->health_check_service_name(), std::move(result)); + if (service_config_changed || config_selector_changed) { + // Start using new service config for calls. + // This needs to happen after the LB policy has been updated, since + // the ConfigSelector may need the LB policy to know about new + // destinations before it can send RPCs to those destinations. + UpdateServiceConfigInDataPlaneLocked(); + // TODO(ncteisen): might be worth somehow including a snippet of the + // config in the trace, at the risk of bloating the trace logs. + trace_strings.push_back("Service config changed"); + } + } + // Add channel trace event. + if (!trace_strings.empty()) { + std::string message = + absl::StrCat("Resolution event: ", absl::StrJoin(trace_strings, ", ")); + if (channelz_node_ != nullptr) { + channelz_node_->AddTraceEvent(channelz::ChannelTrace::Severity::Info, + grpc_slice_from_cpp_string(message)); + } + } +} + +void ClientChannel::OnResolverErrorLocked(grpc_error_handle error) { + if (resolver_ == nullptr) { + GRPC_ERROR_UNREF(error); + return; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: resolver transient failure: %s", this, + grpc_error_std_string(error).c_str()); + } + // If we already have an LB policy from a previous resolution + // result, then we continue to let it set the connectivity state. + // Otherwise, we go into TRANSIENT_FAILURE. + if (lb_policy_ == nullptr) { + grpc_error_handle state_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Resolver transient failure", &error, 1); + absl::Status status = grpc_error_to_absl_status(state_error); + { + MutexLock lock(&resolution_mu_); + // Update resolver transient failure. + GRPC_ERROR_UNREF(resolver_transient_failure_error_); + resolver_transient_failure_error_ = state_error; + // Process calls that were queued waiting for the resolver result. + for (ResolverQueuedCall* call = resolver_queued_calls_; call != nullptr; + call = call->next) { + grpc_call_element* elem = call->elem; + CallData* calld = static_cast(elem->call_data); + grpc_error_handle error = GRPC_ERROR_NONE; + if (calld->CheckResolutionLocked(elem, &error)) { + calld->AsyncResolutionDone(elem, error); + } + } + } + // Update connectivity state. + UpdateStateAndPickerLocked( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, "resolver failure", + absl::make_unique(status)); + } + GRPC_ERROR_UNREF(error); +} + +void ClientChannel::CreateOrUpdateLbPolicyLocked( + RefCountedPtr lb_policy_config, + const absl::optional& health_check_service_name, + Resolver::Result result) { + // Construct update. + LoadBalancingPolicy::UpdateArgs update_args; + update_args.addresses = std::move(result.addresses); + update_args.config = std::move(lb_policy_config); + // Add health check service name to channel args. + absl::InlinedVector args_to_add; + if (health_check_service_name.has_value()) { + args_to_add.push_back(grpc_channel_arg_string_create( + const_cast(GRPC_ARG_HEALTH_CHECK_SERVICE_NAME), + const_cast(health_check_service_name->c_str()))); + } + // Remove the config selector from channel args so that we're not holding + // unnecessary refs that cause it to be destroyed somewhere other than in the + // WorkSerializer. + const char* arg_to_remove = GRPC_ARG_CONFIG_SELECTOR; + update_args.args = grpc_channel_args_copy_and_add_and_remove( + result.args, &arg_to_remove, 1, args_to_add.data(), args_to_add.size()); + // Create policy if needed. 
+ if (lb_policy_ == nullptr) { + lb_policy_ = CreateLbPolicyLocked(*update_args.args); + } + // Update the policy. + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: Updating child policy %p", this, + lb_policy_.get()); + } + lb_policy_->UpdateLocked(std::move(update_args)); +} + +// Creates a new LB policy. +OrphanablePtr ClientChannel::CreateLbPolicyLocked( + const grpc_channel_args& args) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = work_serializer_; + lb_policy_args.channel_control_helper = + absl::make_unique(this); + lb_policy_args.args = &args; + OrphanablePtr lb_policy = + MakeOrphanable(std::move(lb_policy_args), + &grpc_client_channel_routing_trace); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: created new LB policy %p", this, + lb_policy.get()); + } + grpc_pollset_set_add_pollset_set(lb_policy->interested_parties(), + interested_parties_); + return lb_policy; +} + +void ClientChannel::AddResolverQueuedCall(ResolverQueuedCall* call, + grpc_polling_entity* pollent) { + // Add call to queued calls list. + call->next = resolver_queued_calls_; + resolver_queued_calls_ = call; + // Add call's pollent to channel's interested_parties, so that I/O + // can be done under the call's CQ. + grpc_polling_entity_add_to_pollset_set(pollent, interested_parties_); +} + +void ClientChannel::RemoveResolverQueuedCall(ResolverQueuedCall* to_remove, + grpc_polling_entity* pollent) { + // Remove call's pollent from channel's interested_parties. + grpc_polling_entity_del_from_pollset_set(pollent, interested_parties_); + // Remove from queued calls list. + for (ResolverQueuedCall** call = &resolver_queued_calls_; *call != nullptr; + call = &(*call)->next) { + if (*call == to_remove) { + *call = to_remove->next; + return; + } + } +} + +void ClientChannel::UpdateServiceConfigInControlPlaneLocked( + RefCountedPtr service_config, + RefCountedPtr config_selector, const char* lb_policy_name) { + UniquePtr service_config_json( + gpr_strdup(service_config->json_string().c_str())); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p: resolver returned updated service config: \"%s\"", this, + service_config_json.get()); + } + // Save service config. + saved_service_config_ = std::move(service_config); + // Swap out the data used by GetChannelInfo(). + UniquePtr lb_policy_name_owned(gpr_strdup(lb_policy_name)); + { + MutexLock lock(&info_mu_); + info_lb_policy_name_ = std::move(lb_policy_name_owned); + info_service_config_json_ = std::move(service_config_json); + } + // Save config selector. + saved_config_selector_ = std::move(config_selector); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: using ConfigSelector %p", this, + saved_config_selector_.get()); + } +} + +void ClientChannel::UpdateServiceConfigInDataPlaneLocked() { + // Grab ref to service config. + RefCountedPtr service_config = saved_service_config_; + // Grab ref to config selector. Use default if resolver didn't supply one. 
+ RefCountedPtr config_selector = saved_config_selector_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: switching to ConfigSelector %p", this, + saved_config_selector_.get()); + } + if (config_selector == nullptr) { + config_selector = + MakeRefCounted(saved_service_config_); + } + absl::InlinedVector args_to_add = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_CLIENT_CHANNEL), this, + &kClientChannelArgPointerVtable), + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_SERVICE_CONFIG_OBJ), service_config.get(), + &kServiceConfigObjArgPointerVtable), + }; + grpc_channel_args* new_args = grpc_channel_args_copy_and_add( + channel_args_, args_to_add.data(), args_to_add.size()); + new_args = config_selector->ModifyChannelArgs(new_args); + bool enable_retries = + grpc_channel_args_find_bool(new_args, GRPC_ARG_ENABLE_RETRIES, true); + // Construct dynamic filter stack. + std::vector filters = + config_selector->GetFilters(); + if (enable_retries) { + filters.push_back(&kRetryFilterVtable); + } else { + filters.push_back(&DynamicTerminationFilter::kFilterVtable); + } + RefCountedPtr dynamic_filters = + DynamicFilters::Create(new_args, std::move(filters)); + GPR_ASSERT(dynamic_filters != nullptr); + grpc_channel_args_destroy(new_args); + // Grab data plane lock to update service config. + // + // We defer unreffing the old values (and deallocating memory) until + // after releasing the lock to keep the critical section small. + { + MutexLock lock(&resolution_mu_); + GRPC_ERROR_UNREF(resolver_transient_failure_error_); + resolver_transient_failure_error_ = GRPC_ERROR_NONE; + // Update service config. + received_service_config_data_ = true; + // Old values will be unreffed after lock is released. + service_config_.swap(service_config); + config_selector_.swap(config_selector); + dynamic_filters_.swap(dynamic_filters); + // Process calls that were queued waiting for the resolver result. + for (ResolverQueuedCall* call = resolver_queued_calls_; call != nullptr; + call = call->next) { + // If there are a lot of queued calls here, resuming them all may cause us + // to stay inside C-core for a long period of time. All of that work would + // be done using the same ExecCtx instance and therefore the same cached + // value of "now". The longer it takes to finish all of this work and exit + // from C-core, the more stale the cached value of "now" may become. This + // can cause problems whereby (e.g.) we calculate a timer deadline based + // on the stale value, which results in the timer firing too early. To + // avoid this, we invalidate the cached value for each call we process. + ExecCtx::Get()->InvalidateNow(); + grpc_call_element* elem = call->elem; + CallData* calld = static_cast(elem->call_data); + grpc_error_handle error = GRPC_ERROR_NONE; + if (calld->CheckResolutionLocked(elem, &error)) { + calld->AsyncResolutionDone(elem, error); + } + } + } + // Old values will be unreffed after lock is released when they go out + // of scope. +} + +void ClientChannel::CreateResolverLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: starting name resolution", this); + } + resolver_ = ResolverRegistry::CreateResolver( + uri_to_resolve_.c_str(), channel_args_, interested_parties_, + work_serializer_, absl::make_unique(this)); + // Since the validity of the args was checked when the channel was created, + // CreateResolver() must return a non-null result. 
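The block above swaps the new service config, config selector, and dynamic filter stack into place while holding resolution_mu_, but deliberately lets the old objects be released only after the lock is dropped so the critical section stays short. A small sketch of that swap-under-lock pattern using std::shared_ptr in place of gRPC's ref counting (PlaneState and Snapshot are illustrative names):

#include <memory>
#include <mutex>
#include <utility>

struct Snapshot {
  // Whatever state the data plane reads under the lock.
  int generation = 0;
};

class PlaneState {
 public:
  // Installs a new snapshot while holding the mutex, but lets the old one be
  // destroyed after the lock is released, keeping the critical section short.
  void Update(std::shared_ptr<Snapshot> next) {
    std::shared_ptr<Snapshot> old;
    {
      std::lock_guard<std::mutex> lock(mu_);
      old = std::move(current_);
      current_ = std::move(next);
    }
    // `old` goes out of scope here, outside the critical section.
  }

 private:
  std::mutex mu_;
  std::shared_ptr<Snapshot> current_;
};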
+ GPR_ASSERT(resolver_ != nullptr); + UpdateStateAndPickerLocked( + GRPC_CHANNEL_CONNECTING, absl::Status(), "started resolving", + absl::make_unique(nullptr)); + resolver_->StartLocked(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: created resolver=%p", this, resolver_.get()); + } +} + +void ClientChannel::DestroyResolverAndLbPolicyLocked() { + if (resolver_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: shutting down resolver=%p", this, + resolver_.get()); + } + resolver_.reset(); + if (lb_policy_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p: shutting down lb_policy=%p", this, + lb_policy_.get()); + } + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_.reset(); + } + } +} + +void ClientChannel::UpdateStateAndPickerLocked( + grpc_connectivity_state state, const absl::Status& status, + const char* reason, + std::unique_ptr picker) { + // Special case for IDLE and SHUTDOWN states. + if (picker == nullptr || state == GRPC_CHANNEL_SHUTDOWN) { + saved_service_config_.reset(); + saved_config_selector_.reset(); + // Acquire resolution lock to update config selector and associated state. + // To minimize lock contention, we wait to unref these objects until + // after we release the lock. + RefCountedPtr service_config_to_unref; + RefCountedPtr config_selector_to_unref; + RefCountedPtr dynamic_filters_to_unref; + { + MutexLock lock(&resolution_mu_); + received_service_config_data_ = false; + service_config_to_unref = std::move(service_config_); + config_selector_to_unref = std::move(config_selector_); + dynamic_filters_to_unref = std::move(dynamic_filters_); + } + } + // Update connectivity state. + state_tracker_.SetState(state, status, reason); + if (channelz_node_ != nullptr) { + channelz_node_->SetConnectivityState(state); + channelz_node_->AddTraceEvent( + channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string( + channelz::ChannelNode::GetChannelConnectivityStateChangeString( + state))); + } + // Grab data plane lock to update the picker. + { + MutexLock lock(&data_plane_mu_); + // Swap out the picker. + // Note: Original value will be destroyed after the lock is released. + picker_.swap(picker); + // Re-process queued picks. + for (LbQueuedCall* call = lb_queued_calls_; call != nullptr; + call = call->next) { + // If there are a lot of queued calls here, resuming them all may cause us + // to stay inside C-core for a long period of time. All of that work would + // be done using the same ExecCtx instance and therefore the same cached + // value of "now". The longer it takes to finish all of this work and exit + // from C-core, the more stale the cached value of "now" may become. This + // can cause problems whereby (e.g.) we calculate a timer deadline based + // on the stale value, which results in the timer firing too early. To + // avoid this, we invalidate the cached value for each call we process. + ExecCtx::Get()->InvalidateNow(); + grpc_error_handle error = GRPC_ERROR_NONE; + if (call->lb_call->PickSubchannelLocked(&error)) { + call->lb_call->AsyncPickDone(error); + } + } + } +} + +namespace { + +// TODO(roth): Remove this in favor of the gprpp Match() function once +// we can do that without breaking lock annotations. 
+template +T HandlePickResult( + LoadBalancingPolicy::PickResult* result, + std::function complete_func, + std::function queue_func, + std::function fail_func, + std::function drop_func) { + auto* complete_pick = + absl::get_if(&result->result); + if (complete_pick != nullptr) { + return complete_func(complete_pick); + } + auto* queue_pick = + absl::get_if(&result->result); + if (queue_pick != nullptr) { + return queue_func(queue_pick); + } + auto* fail_pick = + absl::get_if(&result->result); + if (fail_pick != nullptr) { + return fail_func(fail_pick); + } + auto* drop_pick = + absl::get_if(&result->result); + GPR_ASSERT(drop_pick != nullptr); + return drop_func(drop_pick); +} + +} // namespace + +grpc_error_handle ClientChannel::DoPingLocked(grpc_transport_op* op) { + if (state_tracker_.state() != GRPC_CHANNEL_READY) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("channel not connected"); + } + LoadBalancingPolicy::PickResult result; + { + MutexLock lock(&data_plane_mu_); + result = picker_->Pick(LoadBalancingPolicy::PickArgs()); + } + return HandlePickResult( + &result, + // Complete pick. + [op](LoadBalancingPolicy::PickResult::Complete* complete_pick) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::work_serializer_) { + SubchannelWrapper* subchannel = static_cast( + complete_pick->subchannel.get()); + RefCountedPtr connected_subchannel = + subchannel->connected_subchannel(); + connected_subchannel->Ping(op->send_ping.on_initiate, + op->send_ping.on_ack); + return GRPC_ERROR_NONE; + }, + // Queue pick. + [](LoadBalancingPolicy::PickResult::Queue* /*queue_pick*/) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("LB picker queued call"); + }, + // Fail pick. + [](LoadBalancingPolicy::PickResult::Fail* fail_pick) { + return absl_status_to_grpc_error(fail_pick->status); + }, + // Drop pick. + [](LoadBalancingPolicy::PickResult::Drop* drop_pick) { + return absl_status_to_grpc_error(drop_pick->status); + }); +} + +void ClientChannel::StartTransportOpLocked(grpc_transport_op* op) { + // Connectivity watch. + if (op->start_connectivity_watch != nullptr) { + state_tracker_.AddWatcher(op->start_connectivity_watch_state, + std::move(op->start_connectivity_watch)); + } + if (op->stop_connectivity_watch != nullptr) { + state_tracker_.RemoveWatcher(op->stop_connectivity_watch); + } + // Ping. + if (op->send_ping.on_initiate != nullptr || op->send_ping.on_ack != nullptr) { + grpc_error_handle error = DoPingLocked(op); + if (error != GRPC_ERROR_NONE) { + ExecCtx::Run(DEBUG_LOCATION, op->send_ping.on_initiate, + GRPC_ERROR_REF(error)); + ExecCtx::Run(DEBUG_LOCATION, op->send_ping.on_ack, error); + } + op->bind_pollset = nullptr; + op->send_ping.on_initiate = nullptr; + op->send_ping.on_ack = nullptr; + } + // Reset backoff. + if (op->reset_connect_backoff) { + if (lb_policy_ != nullptr) { + lb_policy_->ResetBackoffLocked(); + } + } + // Disconnect or enter IDLE. + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p: disconnect_with_error: %s", this, + grpc_error_std_string(op->disconnect_with_error).c_str()); + } + DestroyResolverAndLbPolicyLocked(); + intptr_t value; + if (grpc_error_get_int(op->disconnect_with_error, + GRPC_ERROR_INT_CHANNEL_CONNECTIVITY_STATE, &value) && + static_cast(value) == GRPC_CHANNEL_IDLE) { + if (disconnect_error_ == GRPC_ERROR_NONE) { + // Enter IDLE state. 
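HandlePickResult() above dispatches over the variant inside PickResult with a chain of absl::get_if calls, which the TODO notes could eventually become a Match()-style visitor. A standalone sketch of the same shape using std::variant and std::visit (the pick types here are simplified stand-ins, not gRPC's):

#include <string>
#include <variant>

// Simplified stand-ins for the four pick outcomes.
struct CompletePick { int subchannel_id; };
struct QueuePick {};
struct FailPick { std::string status; };
struct DropPick { std::string status; };

using PickResult = std::variant<CompletePick, QueuePick, FailPick, DropPick>;

// One handler per alternative, with the dispatch done by std::visit instead
// of a chain of get_if calls.
std::string DescribePick(const PickResult& result) {
  struct Visitor {
    std::string operator()(const CompletePick& p) const {
      return "complete on subchannel " + std::to_string(p.subchannel_id);
    }
    std::string operator()(const QueuePick&) const { return "queued"; }
    std::string operator()(const FailPick& p) const {
      return "failed: " + p.status;
    }
    std::string operator()(const DropPick& p) const {
      return "dropped: " + p.status;
    }
  };
  return std::visit(Visitor{}, result);
}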
+ UpdateStateAndPickerLocked(GRPC_CHANNEL_IDLE, absl::Status(), + "channel entering IDLE", nullptr); + } + GRPC_ERROR_UNREF(op->disconnect_with_error); + } else { + // Disconnect. + GPR_ASSERT(disconnect_error_ == GRPC_ERROR_NONE); + disconnect_error_ = op->disconnect_with_error; + UpdateStateAndPickerLocked( + GRPC_CHANNEL_SHUTDOWN, absl::Status(), "shutdown from API", + absl::make_unique( + grpc_error_to_absl_status(op->disconnect_with_error))); + } + } + GRPC_CHANNEL_STACK_UNREF(owning_stack_, "start_transport_op"); + ExecCtx::Run(DEBUG_LOCATION, op->on_consumed, GRPC_ERROR_NONE); +} + +void ClientChannel::StartTransportOp(grpc_channel_element* elem, + grpc_transport_op* op) { + ClientChannel* chand = static_cast(elem->channel_data); + GPR_ASSERT(op->set_accept_stream == false); + // Handle bind_pollset. + if (op->bind_pollset != nullptr) { + grpc_pollset_set_add_pollset(chand->interested_parties_, op->bind_pollset); + } + // Pop into control plane work_serializer for remaining ops. + GRPC_CHANNEL_STACK_REF(chand->owning_stack_, "start_transport_op"); + chand->work_serializer_->Run( + [chand, op]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand->work_serializer_) { + chand->StartTransportOpLocked(op); + }, + DEBUG_LOCATION); +} + +void ClientChannel::GetChannelInfo(grpc_channel_element* elem, + const grpc_channel_info* info) { + ClientChannel* chand = static_cast(elem->channel_data); + MutexLock lock(&chand->info_mu_); + if (info->lb_policy_name != nullptr) { + *info->lb_policy_name = gpr_strdup(chand->info_lb_policy_name_.get()); + } + if (info->service_config_json != nullptr) { + *info->service_config_json = + gpr_strdup(chand->info_service_config_json_.get()); + } +} + +void ClientChannel::AddLbQueuedCall(LbQueuedCall* call, + grpc_polling_entity* pollent) { + // Add call to queued picks list. + call->next = lb_queued_calls_; + lb_queued_calls_ = call; + // Add call's pollent to channel's interested_parties, so that I/O + // can be done under the call's CQ. + grpc_polling_entity_add_to_pollset_set(pollent, interested_parties_); +} + +void ClientChannel::RemoveLbQueuedCall(LbQueuedCall* to_remove, + grpc_polling_entity* pollent) { + // Remove call's pollent from channel's interested_parties. + grpc_polling_entity_del_from_pollset_set(pollent, interested_parties_); + // Remove from queued picks list. + for (LbQueuedCall** call = &lb_queued_calls_; *call != nullptr; + call = &(*call)->next) { + if (*call == to_remove) { + *call = to_remove->next; + return; + } + } +} + +void ClientChannel::TryToConnectLocked() { + if (lb_policy_ != nullptr) { + lb_policy_->ExitIdleLocked(); + } else if (resolver_ == nullptr) { + CreateResolverLocked(); + } + GRPC_CHANNEL_STACK_UNREF(owning_stack_, "TryToConnect"); +} + +grpc_connectivity_state ClientChannel::CheckConnectivityState( + bool try_to_connect) { + // state_tracker_ is guarded by work_serializer_, which we're not + // holding here. But the one method of state_tracker_ that *is* + // thread-safe to call without external synchronization is the state() + // method, so we can disable thread-safety analysis for this one read. 
+ grpc_connectivity_state out = ABSL_TS_UNCHECKED_READ(state_tracker_).state(); + if (out == GRPC_CHANNEL_IDLE && try_to_connect) { + GRPC_CHANNEL_STACK_REF(owning_stack_, "TryToConnect"); + work_serializer_->Run([this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED( + work_serializer_) { TryToConnectLocked(); }, + DEBUG_LOCATION); + } + return out; +} + +void ClientChannel::AddConnectivityWatcher( + grpc_connectivity_state initial_state, + OrphanablePtr watcher) { + new ConnectivityWatcherAdder(this, initial_state, std::move(watcher)); +} + +void ClientChannel::RemoveConnectivityWatcher( + AsyncConnectivityStateWatcherInterface* watcher) { + new ConnectivityWatcherRemover(this, watcher); +} + +// +// CallData implementation +// + +ClientChannel::CallData::CallData(grpc_call_element* elem, + const ClientChannel& chand, + const grpc_call_element_args& args) + : deadline_state_(elem, args, + GPR_LIKELY(chand.deadline_checking_enabled_) + ? args.deadline + : GRPC_MILLIS_INF_FUTURE), + path_(grpc_slice_ref_internal(args.path)), + call_start_time_(args.start_time), + deadline_(args.deadline), + arena_(args.arena), + owning_call_(args.call_stack), + call_combiner_(args.call_combiner), + call_context_(args.context) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: created call", &chand, this); + } +} + +ClientChannel::CallData::~CallData() { + grpc_slice_unref_internal(path_); + GRPC_ERROR_UNREF(cancel_error_); + // Make sure there are no remaining pending batches. + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + GPR_ASSERT(pending_batches_[i] == nullptr); + } +} + +grpc_error_handle ClientChannel::CallData::Init( + grpc_call_element* elem, const grpc_call_element_args* args) { + ClientChannel* chand = static_cast(elem->channel_data); + new (elem->call_data) CallData(elem, *chand, *args); + return GRPC_ERROR_NONE; +} + +void ClientChannel::CallData::Destroy( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* then_schedule_closure) { + CallData* calld = static_cast(elem->call_data); + RefCountedPtr dynamic_call = + std::move(calld->dynamic_call_); + calld->~CallData(); + if (GPR_LIKELY(dynamic_call != nullptr)) { + dynamic_call->SetAfterCallStackDestroy(then_schedule_closure); + } else { + // TODO(yashkt) : This can potentially be a Closure::Run + ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, GRPC_ERROR_NONE); + } +} + +void ClientChannel::CallData::StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + GPR_TIMER_SCOPE("cc_start_transport_stream_op_batch", 0); + CallData* calld = static_cast(elem->call_data); + ClientChannel* chand = static_cast(elem->channel_data); + if (GPR_LIKELY(chand->deadline_checking_enabled_)) { + grpc_deadline_state_client_start_transport_stream_op_batch(elem, batch); + } + // Intercept recv_trailing_metadata to call CallDispatchController::Commit(), + // in case we wind up failing the call before we get down to the retry + // or LB call layer. + if (batch->recv_trailing_metadata) { + calld->InjectRecvTrailingMetadataReadyForConfigSelectorCommitCallback( + batch); + } + // If we already have a dynamic call, pass the batch down to it. + // Note that once we have done so, we do not need to acquire the channel's + // resolution mutex, which is more efficient (especially for streaming calls). 
+ if (calld->dynamic_call_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: starting batch on dynamic_call=%p", + chand, calld, calld->dynamic_call_.get()); + } + calld->dynamic_call_->StartTransportStreamOpBatch(batch); + return; + } + // We do not yet have a dynamic call. + // + // If we've previously been cancelled, immediately fail any new batches. + if (GPR_UNLIKELY(calld->cancel_error_ != GRPC_ERROR_NONE)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: failing batch with error: %s", + chand, calld, + grpc_error_std_string(calld->cancel_error_).c_str()); + } + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(calld->cancel_error_), calld->call_combiner_); + return; + } + // Handle cancellation. + if (GPR_UNLIKELY(batch->cancel_stream)) { + // Stash a copy of cancel_error in our call data, so that we can use + // it for subsequent operations. This ensures that if the call is + // cancelled before any batches are passed down (e.g., if the deadline + // is in the past when the call starts), we can return the right + // error to the caller when the first batch does get passed down. + GRPC_ERROR_UNREF(calld->cancel_error_); + calld->cancel_error_ = + GRPC_ERROR_REF(batch->payload->cancel_stream.cancel_error); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: recording cancel_error=%s", chand, + calld, grpc_error_std_string(calld->cancel_error_).c_str()); + } + // Fail all pending batches. + calld->PendingBatchesFail(elem, GRPC_ERROR_REF(calld->cancel_error_), + NoYieldCallCombiner); + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(calld->cancel_error_), calld->call_combiner_); + return; + } + // Add the batch to the pending list. + calld->PendingBatchesAdd(elem, batch); + // For batches containing a send_initial_metadata op, acquire the + // channel's resolution mutex to apply the service config to the call, + // after which we will create a dynamic call. + if (GPR_LIKELY(batch->send_initial_metadata)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: grabbing resolution mutex to apply service " + "config", + chand, calld); + } + CheckResolution(elem, GRPC_ERROR_NONE); + } else { + // For all other batches, release the call combiner. + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: saved batch, yielding call combiner", chand, + calld); + } + GRPC_CALL_COMBINER_STOP(calld->call_combiner_, + "batch does not include send_initial_metadata"); + } +} + +void ClientChannel::CallData::SetPollent(grpc_call_element* elem, + grpc_polling_entity* pollent) { + CallData* calld = static_cast(elem->call_data); + calld->pollent_ = pollent; +} + +// +// pending_batches management +// + +size_t ClientChannel::CallData::GetBatchIndex( + grpc_transport_stream_op_batch* batch) { + // Note: It is important the send_initial_metadata be the first entry + // here, since the code in ApplyServiceConfigToCallLocked() and + // CheckResolutionLocked() assumes it will be. 
+ if (batch->send_initial_metadata) return 0; + if (batch->send_message) return 1; + if (batch->send_trailing_metadata) return 2; + if (batch->recv_initial_metadata) return 3; + if (batch->recv_message) return 4; + if (batch->recv_trailing_metadata) return 5; + GPR_UNREACHABLE_CODE(return (size_t)-1); +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::CallData::PendingBatchesAdd( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + ClientChannel* chand = static_cast(elem->channel_data); + const size_t idx = GetBatchIndex(batch); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: adding pending batch at index %" PRIuPTR, chand, + this, idx); + } + grpc_transport_stream_op_batch*& pending = pending_batches_[idx]; + GPR_ASSERT(pending == nullptr); + pending = batch; +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::CallData::FailPendingBatchInCallCombiner( + void* arg, grpc_error_handle error) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + CallData* calld = static_cast(batch->handler_private.extra_arg); + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(error), calld->call_combiner_); +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::CallData::PendingBatchesFail( + grpc_call_element* elem, grpc_error_handle error, + YieldCallCombinerPredicate yield_call_combiner_predicate) { + GPR_ASSERT(error != GRPC_ERROR_NONE); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + size_t num_batches = 0; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + if (pending_batches_[i] != nullptr) ++num_batches; + } + gpr_log(GPR_INFO, + "chand=%p calld=%p: failing %" PRIuPTR " pending batches: %s", + elem->channel_data, this, num_batches, + grpc_error_std_string(error).c_str()); + } + CallCombinerClosureList closures; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + grpc_transport_stream_op_batch*& batch = pending_batches_[i]; + if (batch != nullptr) { + batch->handler_private.extra_arg = this; + GRPC_CLOSURE_INIT(&batch->handler_private.closure, + FailPendingBatchInCallCombiner, batch, + grpc_schedule_on_exec_ctx); + closures.Add(&batch->handler_private.closure, GRPC_ERROR_REF(error), + "PendingBatchesFail"); + batch = nullptr; + } + } + if (yield_call_combiner_predicate(closures)) { + closures.RunClosures(call_combiner_); + } else { + closures.RunClosuresWithoutYielding(call_combiner_); + } + GRPC_ERROR_UNREF(error); +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::CallData::ResumePendingBatchInCallCombiner( + void* arg, grpc_error_handle /*ignored*/) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + auto* elem = + static_cast(batch->handler_private.extra_arg); + auto* calld = static_cast(elem->call_data); + // Note: This will release the call combiner. + calld->dynamic_call_->StartTransportStreamOpBatch(batch); +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::CallData::PendingBatchesResume(grpc_call_element* elem) { + ClientChannel* chand = static_cast(elem->channel_data); + // Retries not enabled; send down batches as-is. 
+ if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + size_t num_batches = 0; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + if (pending_batches_[i] != nullptr) ++num_batches; + } + gpr_log(GPR_INFO, + "chand=%p calld=%p: starting %" PRIuPTR + " pending batches on dynamic_call=%p", + chand, this, num_batches, dynamic_call_.get()); + } + CallCombinerClosureList closures; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + grpc_transport_stream_op_batch*& batch = pending_batches_[i]; + if (batch != nullptr) { + batch->handler_private.extra_arg = elem; + GRPC_CLOSURE_INIT(&batch->handler_private.closure, + ResumePendingBatchInCallCombiner, batch, nullptr); + closures.Add(&batch->handler_private.closure, GRPC_ERROR_NONE, + "resuming pending batch from client channel call"); + batch = nullptr; + } + } + // Note: This will release the call combiner. + closures.RunClosures(call_combiner_); +} + +// +// name resolution +// + +// A class to handle the call combiner cancellation callback for a +// queued pick. +class ClientChannel::CallData::ResolverQueuedCallCanceller { + public: + explicit ResolverQueuedCallCanceller(grpc_call_element* elem) : elem_(elem) { + auto* calld = static_cast(elem->call_data); + GRPC_CALL_STACK_REF(calld->owning_call_, "ResolverQueuedCallCanceller"); + GRPC_CLOSURE_INIT(&closure_, &CancelLocked, this, + grpc_schedule_on_exec_ctx); + calld->call_combiner_->SetNotifyOnCancel(&closure_); + } + + private: + static void CancelLocked(void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + auto* chand = static_cast(self->elem_->channel_data); + auto* calld = static_cast(self->elem_->call_data); + { + MutexLock lock(&chand->resolution_mu_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: cancelling resolver queued pick: " + "error=%s self=%p calld->resolver_pick_canceller=%p", + chand, calld, grpc_error_std_string(error).c_str(), self, + calld->resolver_call_canceller_); + } + if (calld->resolver_call_canceller_ == self && error != GRPC_ERROR_NONE) { + // Remove pick from list of queued picks. + calld->MaybeRemoveCallFromResolverQueuedCallsLocked(self->elem_); + // Fail pending batches on the call. + calld->PendingBatchesFail(self->elem_, GRPC_ERROR_REF(error), + YieldCallCombinerIfPendingBatchesFound); + } + } + GRPC_CALL_STACK_UNREF(calld->owning_call_, "ResolvingQueuedCallCanceller"); + delete self; + } + + grpc_call_element* elem_; + grpc_closure closure_; +}; + +void ClientChannel::CallData::MaybeRemoveCallFromResolverQueuedCallsLocked( + grpc_call_element* elem) { + if (!queued_pending_resolver_result_) return; + auto* chand = static_cast(elem->channel_data); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: removing from resolver queued picks list", + chand, this); + } + chand->RemoveResolverQueuedCall(&resolver_queued_call_, pollent_); + queued_pending_resolver_result_ = false; + // Lame the call combiner canceller. 
+ resolver_call_canceller_ = nullptr; +} + +void ClientChannel::CallData::MaybeAddCallToResolverQueuedCallsLocked( + grpc_call_element* elem) { + if (queued_pending_resolver_result_) return; + auto* chand = static_cast(elem->channel_data); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: adding to resolver queued picks list", + chand, this); + } + queued_pending_resolver_result_ = true; + resolver_queued_call_.elem = elem; + chand->AddResolverQueuedCall(&resolver_queued_call_, pollent_); + // Register call combiner cancellation callback. + resolver_call_canceller_ = new ResolverQueuedCallCanceller(elem); +} + +grpc_error_handle ClientChannel::CallData::ApplyServiceConfigToCallLocked( + grpc_call_element* elem, grpc_metadata_batch* initial_metadata) { + ClientChannel* chand = static_cast(elem->channel_data); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: applying service config to call", + chand, this); + } + ConfigSelector* config_selector = chand->config_selector_.get(); + if (config_selector != nullptr) { + // Use the ConfigSelector to determine the config for the call. + ConfigSelector::CallConfig call_config = + config_selector->GetCallConfig({&path_, initial_metadata, arena_}); + if (call_config.error != GRPC_ERROR_NONE) return call_config.error; + // Create a ClientChannelServiceConfigCallData for the call. This stores + // a ref to the ServiceConfig and caches the right set of parsed configs + // to use for the call. The ClientChannelServiceConfigCallData will store + // itself in the call context, so that it can be accessed by filters + // below us in the stack, and it will be cleaned up when the call ends. + auto* service_config_call_data = + arena_->New( + std::move(call_config.service_config), call_config.method_configs, + std::move(call_config.call_attributes), + call_config.call_dispatch_controller, call_context_); + // Apply our own method params to the call. + auto* method_params = static_cast( + service_config_call_data->GetMethodParsedConfig( + internal::ClientChannelServiceConfigParser::ParserIndex())); + if (method_params != nullptr) { + // If the deadline from the service config is shorter than the one + // from the client API, reset the deadline timer. + if (chand->deadline_checking_enabled_ && method_params->timeout() != 0) { + const grpc_millis per_method_deadline = + grpc_cycle_counter_to_millis_round_up(call_start_time_) + + method_params->timeout(); + if (per_method_deadline < deadline_) { + deadline_ = per_method_deadline; + grpc_deadline_state_reset(elem, deadline_); + } + } + // If the service config set wait_for_ready and the application + // did not explicitly set it, use the value from the service config. + uint32_t* send_initial_metadata_flags = + &pending_batches_[0] + ->payload->send_initial_metadata.send_initial_metadata_flags; + if (method_params->wait_for_ready().has_value() && + !(*send_initial_metadata_flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET)) { + if (method_params->wait_for_ready().value()) { + *send_initial_metadata_flags |= GRPC_INITIAL_METADATA_WAIT_FOR_READY; + } else { + *send_initial_metadata_flags &= ~GRPC_INITIAL_METADATA_WAIT_FOR_READY; + } + } + } + // Set the dynamic filter stack. 
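The service-config application in ApplyServiceConfigToCallLocked boils down to two merges before the function finishes below by capturing the dynamic filter stack: the effective deadline becomes the earlier of the client deadline and the call start time plus the per-method timeout, and the service config's wait_for_ready value only applies when the application did not set the flag explicitly. A rough standalone restatement of that logic, using plain integers and hypothetical flag constants in place of grpc_millis and the GRPC_INITIAL_METADATA_* bits:

#include <cstdint>

struct MethodParams {
  int64_t timeout_ms = 0;           // 0 means "no timeout in service config"
  bool has_wait_for_ready = false;  // optional<bool> flattened for brevity
  bool wait_for_ready = false;
};

// Hypothetical flag bits standing in for the GRPC_INITIAL_METADATA_* macros.
constexpr uint32_t kWaitForReady = 1u << 0;
constexpr uint32_t kWaitForReadyExplicitlySet = 1u << 1;

int64_t EffectiveDeadline(int64_t call_start_ms, int64_t client_deadline_ms,
                          const MethodParams& params) {
  if (params.timeout_ms == 0) return client_deadline_ms;
  const int64_t per_method_deadline = call_start_ms + params.timeout_ms;
  return per_method_deadline < client_deadline_ms ? per_method_deadline
                                                  : client_deadline_ms;
}

uint32_t MergeWaitForReady(uint32_t flags, const MethodParams& params) {
  // Only honor the service config if the app didn't set the flag itself.
  if (params.has_wait_for_ready && !(flags & kWaitForReadyExplicitlySet)) {
    if (params.wait_for_ready) {
      flags |= kWaitForReady;
    } else {
      flags &= ~kWaitForReady;
    }
  }
  return flags;
}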
+ dynamic_filters_ = chand->dynamic_filters_; + } + return GRPC_ERROR_NONE; +} + +void ClientChannel::CallData:: + RecvTrailingMetadataReadyForConfigSelectorCommitCallback( + void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + auto* service_config_call_data = + static_cast( + self->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + if (service_config_call_data != nullptr) { + service_config_call_data->call_dispatch_controller()->Commit(); + } + // Chain to original callback. + Closure::Run(DEBUG_LOCATION, self->original_recv_trailing_metadata_ready_, + GRPC_ERROR_REF(error)); +} + +void ClientChannel::CallData:: + InjectRecvTrailingMetadataReadyForConfigSelectorCommitCallback( + grpc_transport_stream_op_batch* batch) { + original_recv_trailing_metadata_ready_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, + RecvTrailingMetadataReadyForConfigSelectorCommitCallback, + this, nullptr); + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &recv_trailing_metadata_ready_; +} + +void ClientChannel::CallData::AsyncResolutionDone(grpc_call_element* elem, + grpc_error_handle error) { + // TODO(roth): Does this callback need to hold a ref to the call stack? + GRPC_CLOSURE_INIT(&resolution_done_closure_, ResolutionDone, elem, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &resolution_done_closure_, error); +} + +void ClientChannel::CallData::ResolutionDone(void* arg, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + ClientChannel* chand = static_cast(elem->channel_data); + CallData* calld = static_cast(elem->call_data); + if (error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: error applying config to call: error=%s", + chand, calld, grpc_error_std_string(error).c_str()); + } + calld->PendingBatchesFail(elem, GRPC_ERROR_REF(error), YieldCallCombiner); + return; + } + calld->CreateDynamicCall(elem); +} + +void ClientChannel::CallData::CheckResolution(void* arg, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + CallData* calld = static_cast(elem->call_data); + ClientChannel* chand = static_cast(elem->channel_data); + bool resolution_complete; + { + MutexLock lock(&chand->resolution_mu_); + resolution_complete = calld->CheckResolutionLocked(elem, &error); + } + if (resolution_complete) { + ResolutionDone(elem, error); + GRPC_ERROR_UNREF(error); + } +} + +bool ClientChannel::CallData::CheckResolutionLocked(grpc_call_element* elem, + grpc_error_handle* error) { + ClientChannel* chand = static_cast(elem->channel_data); + // If we're still in IDLE, we need to start resolving. + if (GPR_UNLIKELY(chand->CheckConnectivityState(false) == GRPC_CHANNEL_IDLE)) { + // Bounce into the control plane work serializer to start resolving, + // in case we are still in IDLE state. Since we are holding on to the + // resolution mutex here, we offload it on the ExecCtx so that we don't + // deadlock with ourselves. 
+ GRPC_CHANNEL_STACK_REF(chand->owning_stack_, "CheckResolutionLocked"); + ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle /*error*/) { + auto* chand = static_cast(arg); + chand->work_serializer_->Run( + [chand]() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(chand->work_serializer_) { + chand->CheckConnectivityState(/*try_to_connect=*/true); + GRPC_CHANNEL_STACK_UNREF(chand->owning_stack_, + "CheckResolutionLocked"); + }, + DEBUG_LOCATION); + }, + chand, nullptr), + GRPC_ERROR_NONE); + } + // Get send_initial_metadata batch and flags. + auto& send_initial_metadata = + pending_batches_[0]->payload->send_initial_metadata; + grpc_metadata_batch* initial_metadata_batch = + send_initial_metadata.send_initial_metadata; + const uint32_t send_initial_metadata_flags = + send_initial_metadata.send_initial_metadata_flags; + // If we don't yet have a resolver result, we need to queue the call + // until we get one. + if (GPR_UNLIKELY(!chand->received_service_config_data_)) { + // If the resolver returned transient failure before returning the + // first service config, fail any non-wait_for_ready calls. + grpc_error_handle resolver_error = chand->resolver_transient_failure_error_; + if (resolver_error != GRPC_ERROR_NONE && + (send_initial_metadata_flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) == + 0) { + MaybeRemoveCallFromResolverQueuedCallsLocked(elem); + *error = GRPC_ERROR_REF(resolver_error); + return true; + } + // Either the resolver has not yet returned a result, or it has + // returned transient failure but the call is wait_for_ready. In + // either case, queue the call. + MaybeAddCallToResolverQueuedCallsLocked(elem); + return false; + } + // Apply service config to call if not yet applied. + if (GPR_LIKELY(!service_config_applied_)) { + service_config_applied_ = true; + *error = ApplyServiceConfigToCallLocked(elem, initial_metadata_batch); + } + MaybeRemoveCallFromResolverQueuedCallsLocked(elem); + return true; +} + +void ClientChannel::CallData::CreateDynamicCall(grpc_call_element* elem) { + auto* chand = static_cast(elem->channel_data); + DynamicFilters::Call::Args args = {std::move(dynamic_filters_), + pollent_, + path_, + call_start_time_, + deadline_, + arena_, + call_context_, + call_combiner_}; + grpc_error_handle error = GRPC_ERROR_NONE; + DynamicFilters* channel_stack = args.channel_stack.get(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log( + GPR_INFO, + "chand=%p calld=%p: creating dynamic call stack on channel_stack=%p", + chand, this, channel_stack); + } + dynamic_call_ = channel_stack->CreateCall(std::move(args), &error); + if (error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: failed to create dynamic call: error=%s", + chand, this, grpc_error_std_string(error).c_str()); + } + PendingBatchesFail(elem, error, YieldCallCombiner); + return; + } + PendingBatchesResume(elem); +} + +// +// ClientChannel::LoadBalancedCall::Metadata +// + +class ClientChannel::LoadBalancedCall::Metadata + : public LoadBalancingPolicy::MetadataInterface { + public: + Metadata(LoadBalancedCall* lb_call, grpc_metadata_batch* batch) + : lb_call_(lb_call), batch_(batch) {} + + void Add(absl::string_view key, absl::string_view value) override { + grpc_linked_mdelem* linked_mdelem = static_cast( + lb_call_->arena_->Alloc(sizeof(grpc_linked_mdelem))); + linked_mdelem->md = grpc_mdelem_from_slices( + ExternallyManagedSlice(key.data(), key.size()), + 
ExternallyManagedSlice(value.data(), value.size())); + GPR_ASSERT(batch_->LinkTail(linked_mdelem) == GRPC_ERROR_NONE); + } + + std::vector> TestOnlyCopyToVector() + override { + std::vector> result; + batch_->ForEach([&](grpc_mdelem md) { + auto key = std::string(StringViewFromSlice(GRPC_MDKEY(md))); + if (key != ":path") { + result.push_back( + std::make_pair(std::move(key), + std::string(StringViewFromSlice(GRPC_MDVALUE(md))))); + } + }); + return result; + } + + absl::optional Lookup(absl::string_view key, + std::string* buffer) const override { + return batch_->GetValue(key, buffer); + } + + private: + LoadBalancedCall* lb_call_; + grpc_metadata_batch* batch_; +}; + +// +// ClientChannel::LoadBalancedCall::LbCallState +// + +class ClientChannel::LoadBalancedCall::LbCallState + : public LoadBalancingPolicy::CallState { + public: + explicit LbCallState(LoadBalancedCall* lb_call) : lb_call_(lb_call) {} + + void* Alloc(size_t size) override { return lb_call_->arena_->Alloc(size); } + + const LoadBalancingPolicy::BackendMetricData* GetBackendMetricData() + override { + if (lb_call_->backend_metric_data_ == nullptr) { + grpc_linked_mdelem* md = lb_call_->recv_trailing_metadata_->legacy_index() + ->named.x_endpoint_load_metrics_bin; + if (md != nullptr) { + lb_call_->backend_metric_data_ = + ParseBackendMetricData(GRPC_MDVALUE(md->md), lb_call_->arena_); + } + } + return lb_call_->backend_metric_data_; + } + + absl::string_view ExperimentalGetCallAttribute(const char* key) override { + auto* service_config_call_data = static_cast( + lb_call_->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + auto& call_attributes = service_config_call_data->call_attributes(); + auto it = call_attributes.find(key); + if (it == call_attributes.end()) return absl::string_view(); + return it->second; + } + + private: + LoadBalancedCall* lb_call_; +}; + +// +// LoadBalancedCall +// + +namespace { + +CallTracer::CallAttemptTracer* GetCallAttemptTracer( + grpc_call_context_element* context, bool is_transparent_retry) { + auto* call_tracer = + static_cast(context[GRPC_CONTEXT_CALL_TRACER].value); + if (call_tracer == nullptr) return nullptr; + return call_tracer->StartNewAttempt(is_transparent_retry); +} + +} // namespace + +ClientChannel::LoadBalancedCall::LoadBalancedCall( + ClientChannel* chand, const grpc_call_element_args& args, + grpc_polling_entity* pollent, grpc_closure* on_call_destruction_complete, + ConfigSelector::CallDispatchController* call_dispatch_controller, + bool is_transparent_retry) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace) + ? "LoadBalancedCall" + : nullptr), + chand_(chand), + path_(grpc_slice_ref_internal(args.path)), + deadline_(args.deadline), + arena_(args.arena), + owning_call_(args.call_stack), + call_combiner_(args.call_combiner), + call_context_(args.context), + pollent_(pollent), + on_call_destruction_complete_(on_call_destruction_complete), + call_dispatch_controller_(call_dispatch_controller), + call_attempt_tracer_( + GetCallAttemptTracer(args.context, is_transparent_retry)) {} + +ClientChannel::LoadBalancedCall::~LoadBalancedCall() { + grpc_slice_unref_internal(path_); + GRPC_ERROR_UNREF(cancel_error_); + GRPC_ERROR_UNREF(failure_error_); + if (backend_metric_data_ != nullptr) { + backend_metric_data_ + ->LoadBalancingPolicy::BackendMetricData::~BackendMetricData(); + } + // Make sure there are no remaining pending batches. 
+ for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + GPR_ASSERT(pending_batches_[i] == nullptr); + } + if (on_call_destruction_complete_ != nullptr) { + ExecCtx::Run(DEBUG_LOCATION, on_call_destruction_complete_, + GRPC_ERROR_NONE); + } +} + +void ClientChannel::LoadBalancedCall::Orphan() { + // Compute latency and report it to the tracer. + if (call_attempt_tracer_ != nullptr) { + gpr_timespec latency = + gpr_cycle_counter_sub(gpr_get_cycle_counter(), lb_call_start_time_); + call_attempt_tracer_->RecordEnd(latency); + } + Unref(); +} + +size_t ClientChannel::LoadBalancedCall::GetBatchIndex( + grpc_transport_stream_op_batch* batch) { + // Note: It is important the send_initial_metadata be the first entry + // here, since the code in PickSubchannelLocked() assumes it will be. + if (batch->send_initial_metadata) return 0; + if (batch->send_message) return 1; + if (batch->send_trailing_metadata) return 2; + if (batch->recv_initial_metadata) return 3; + if (batch->recv_message) return 4; + if (batch->recv_trailing_metadata) return 5; + GPR_UNREACHABLE_CODE(return (size_t)-1); +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::LoadBalancedCall::PendingBatchesAdd( + grpc_transport_stream_op_batch* batch) { + const size_t idx = GetBatchIndex(batch); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: adding pending batch at index %" PRIuPTR, + chand_, this, idx); + } + GPR_ASSERT(pending_batches_[idx] == nullptr); + pending_batches_[idx] = batch; +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::LoadBalancedCall::FailPendingBatchInCallCombiner( + void* arg, grpc_error_handle error) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + auto* self = static_cast(batch->handler_private.extra_arg); + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(error), self->call_combiner_); +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::LoadBalancedCall::PendingBatchesFail( + grpc_error_handle error, + YieldCallCombinerPredicate yield_call_combiner_predicate) { + GPR_ASSERT(error != GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(failure_error_); + failure_error_ = error; + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + size_t num_batches = 0; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + if (pending_batches_[i] != nullptr) ++num_batches; + } + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: failing %" PRIuPTR " pending batches: %s", + chand_, this, num_batches, grpc_error_std_string(error).c_str()); + } + CallCombinerClosureList closures; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + grpc_transport_stream_op_batch*& batch = pending_batches_[i]; + if (batch != nullptr) { + batch->handler_private.extra_arg = this; + GRPC_CLOSURE_INIT(&batch->handler_private.closure, + FailPendingBatchInCallCombiner, batch, + grpc_schedule_on_exec_ctx); + closures.Add(&batch->handler_private.closure, GRPC_ERROR_REF(error), + "PendingBatchesFail"); + batch = nullptr; + } + } + if (yield_call_combiner_predicate(closures)) { + closures.RunClosures(call_combiner_); + } else { + closures.RunClosuresWithoutYielding(call_combiner_); + } +} + +// This is called via the call combiner, so access to calld is synchronized. 
+void ClientChannel::LoadBalancedCall::ResumePendingBatchInCallCombiner( + void* arg, grpc_error_handle /*ignored*/) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + SubchannelCall* subchannel_call = + static_cast(batch->handler_private.extra_arg); + // Note: This will release the call combiner. + subchannel_call->StartTransportStreamOpBatch(batch); +} + +// This is called via the call combiner, so access to calld is synchronized. +void ClientChannel::LoadBalancedCall::PendingBatchesResume() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + size_t num_batches = 0; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + if (pending_batches_[i] != nullptr) ++num_batches; + } + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: starting %" PRIuPTR + " pending batches on subchannel_call=%p", + chand_, this, num_batches, subchannel_call_.get()); + } + CallCombinerClosureList closures; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + grpc_transport_stream_op_batch*& batch = pending_batches_[i]; + if (batch != nullptr) { + batch->handler_private.extra_arg = subchannel_call_.get(); + GRPC_CLOSURE_INIT(&batch->handler_private.closure, + ResumePendingBatchInCallCombiner, batch, + grpc_schedule_on_exec_ctx); + closures.Add(&batch->handler_private.closure, GRPC_ERROR_NONE, + "resuming pending batch from LB call"); + batch = nullptr; + } + } + // Note: This will release the call combiner. + closures.RunClosures(call_combiner_); +} + +void ClientChannel::LoadBalancedCall::StartTransportStreamOpBatch( + grpc_transport_stream_op_batch* batch) { + // Handle call tracing. + if (call_attempt_tracer_ != nullptr) { + // Record send ops in tracer. + if (batch->cancel_stream) { + call_attempt_tracer_->RecordCancel( + GRPC_ERROR_REF(batch->payload->cancel_stream.cancel_error)); + } + if (batch->send_initial_metadata) { + call_attempt_tracer_->RecordSendInitialMetadata( + batch->payload->send_initial_metadata.send_initial_metadata, + batch->payload->send_initial_metadata.send_initial_metadata_flags); + peer_string_ = batch->payload->send_initial_metadata.peer_string; + original_send_initial_metadata_on_complete_ = batch->on_complete; + GRPC_CLOSURE_INIT(&send_initial_metadata_on_complete_, + SendInitialMetadataOnComplete, this, nullptr); + batch->on_complete = &send_initial_metadata_on_complete_; + } + if (batch->send_message) { + call_attempt_tracer_->RecordSendMessage( + *batch->payload->send_message.send_message); + } + if (batch->send_trailing_metadata) { + call_attempt_tracer_->RecordSendTrailingMetadata( + batch->payload->send_trailing_metadata.send_trailing_metadata); + } + // Intercept recv ops. 
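The recv-op interception that follows uses the same stash-and-chain pattern as the send-side tracing hooks above: remember the caller's completion callback, substitute the call's own closure in the batch payload, and invoke the original once the tracer has been notified. A toy version of the pattern, with std::function standing in for grpc_closure and invented names (RecvOp, Interceptor):

#include <functional>
#include <iostream>
#include <string>
#include <utility>

using Callback = std::function<void(const std::string& error)>;

struct RecvOp {
  Callback on_complete;  // set by the caller, fired by the transport
};

class Interceptor {
 public:
  // Swap in our own callback and remember the original.
  void Intercept(RecvOp* op) {
    original_ = std::move(op->on_complete);
    op->on_complete = [this](const std::string& error) {
      std::cout << "recording trace data, error=" << error << "\n";
      original_(error);  // chain to the caller's callback
    };
  }

 private:
  Callback original_;
};

int main() {
  RecvOp op;
  op.on_complete = [](const std::string& error) {
    std::cout << "caller sees completion, error=" << error << "\n";
  };
  Interceptor interceptor;
  interceptor.Intercept(&op);
  op.on_complete("ok");  // the transport would invoke this when the op finishes
}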
+ if (batch->recv_initial_metadata) { + recv_initial_metadata_ = + batch->payload->recv_initial_metadata.recv_initial_metadata; + original_recv_initial_metadata_ready_ = + batch->payload->recv_initial_metadata.recv_initial_metadata_ready; + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_, RecvInitialMetadataReady, + this, nullptr); + batch->payload->recv_initial_metadata.recv_initial_metadata_ready = + &recv_initial_metadata_ready_; + } + if (batch->recv_message) { + recv_message_ = batch->payload->recv_message.recv_message; + original_recv_message_ready_ = + batch->payload->recv_message.recv_message_ready; + GRPC_CLOSURE_INIT(&recv_message_ready_, RecvMessageReady, this, nullptr); + batch->payload->recv_message.recv_message_ready = &recv_message_ready_; + } + } + // Intercept recv_trailing_metadata even if there is no call tracer, + // since we may need to notify the LB policy about trailing metadata. + if (batch->recv_trailing_metadata) { + recv_trailing_metadata_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata; + transport_stream_stats_ = + batch->payload->recv_trailing_metadata.collect_stats; + original_recv_trailing_metadata_ready_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReady, + this, nullptr); + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &recv_trailing_metadata_ready_; + } + // If we've already gotten a subchannel call, pass the batch down to it. + // Note that once we have picked a subchannel, we do not need to acquire + // the channel's data plane mutex, which is more efficient (especially for + // streaming calls). + if (subchannel_call_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: starting batch on subchannel_call=%p", + chand_, this, subchannel_call_.get()); + } + subchannel_call_->StartTransportStreamOpBatch(batch); + return; + } + // We do not yet have a subchannel call. + // + // If we've previously been cancelled, immediately fail any new batches. + if (GPR_UNLIKELY(cancel_error_ != GRPC_ERROR_NONE)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: failing batch with error: %s", + chand_, this, grpc_error_std_string(cancel_error_).c_str()); + } + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(cancel_error_), call_combiner_); + return; + } + // Handle cancellation. + if (GPR_UNLIKELY(batch->cancel_stream)) { + // Stash a copy of cancel_error in our call data, so that we can use + // it for subsequent operations. This ensures that if the call is + // cancelled before any batches are passed down (e.g., if the deadline + // is in the past when the call starts), we can return the right + // error to the caller when the first batch does get passed down. + GRPC_ERROR_UNREF(cancel_error_); + cancel_error_ = GRPC_ERROR_REF(batch->payload->cancel_stream.cancel_error); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: recording cancel_error=%s", + chand_, this, grpc_error_std_string(cancel_error_).c_str()); + } + // Fail all pending batches. + PendingBatchesFail(GRPC_ERROR_REF(cancel_error_), NoYieldCallCombiner); + // Note: This will release the call combiner. 
+ grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(cancel_error_), call_combiner_); + return; + } + // Add the batch to the pending list. + PendingBatchesAdd(batch); + // For batches containing a send_initial_metadata op, acquire the + // channel's data plane mutex to pick a subchannel. + if (GPR_LIKELY(batch->send_initial_metadata)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: grabbing data plane mutex to perform pick", + chand_, this); + } + PickSubchannel(this, GRPC_ERROR_NONE); + } else { + // For all other batches, release the call combiner. + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: saved batch, yielding call combiner", + chand_, this); + } + GRPC_CALL_COMBINER_STOP(call_combiner_, + "batch does not include send_initial_metadata"); + } +} + +void ClientChannel::LoadBalancedCall::SendInitialMetadataOnComplete( + void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + self->call_attempt_tracer_->RecordOnDoneSendInitialMetadata( + self->peer_string_); + Closure::Run(DEBUG_LOCATION, + self->original_send_initial_metadata_on_complete_, + GRPC_ERROR_REF(error)); +} + +void ClientChannel::LoadBalancedCall::RecvInitialMetadataReady( + void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + // recv_initial_metadata_flags is not populated for clients + self->call_attempt_tracer_->RecordReceivedInitialMetadata( + self->recv_initial_metadata_, 0 /* recv_initial_metadata_flags */); + } + Closure::Run(DEBUG_LOCATION, self->original_recv_initial_metadata_ready_, + GRPC_ERROR_REF(error)); +} + +void ClientChannel::LoadBalancedCall::RecvMessageReady( + void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + if (*self->recv_message_ != nullptr) { + self->call_attempt_tracer_->RecordReceivedMessage(**self->recv_message_); + } + Closure::Run(DEBUG_LOCATION, self->original_recv_message_ready_, + GRPC_ERROR_REF(error)); +} + +void ClientChannel::LoadBalancedCall::RecvTrailingMetadataReady( + void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + // Check if we have a tracer or an LB callback to invoke. + if (self->call_attempt_tracer_ != nullptr || + self->lb_recv_trailing_metadata_ready_ != nullptr) { + // Get the call's status. + absl::Status status; + if (error != GRPC_ERROR_NONE) { + // Get status from error. + grpc_status_code code; + std::string message; + grpc_error_get_status(error, self->deadline_, &code, &message, + /*http_error=*/nullptr, /*error_string=*/nullptr); + status = absl::Status(static_cast(code), message); + } else { + // Get status from headers. + const auto& fields = self->recv_trailing_metadata_->legacy_index()->named; + GPR_ASSERT(fields.grpc_status != nullptr); + grpc_status_code code = + grpc_get_status_code_from_metadata(fields.grpc_status->md); + if (code != GRPC_STATUS_OK) { + absl::string_view message; + if (fields.grpc_message != nullptr) { + message = StringViewFromSlice(GRPC_MDVALUE(fields.grpc_message->md)); + } + status = absl::Status(static_cast(code), message); + } + } + // If we have a tracer, notify it. + if (self->call_attempt_tracer_ != nullptr) { + self->call_attempt_tracer_->RecordReceivedTrailingMetadata( + status, self->recv_trailing_metadata_, + *self->transport_stream_stats_); + } + // If the LB policy requested a callback for trailing metadata, invoke + // the callback. 
+ if (self->lb_recv_trailing_metadata_ready_ != nullptr) { + Metadata trailing_metadata(self, self->recv_trailing_metadata_); + LbCallState lb_call_state(self); + self->lb_recv_trailing_metadata_ready_(status, &trailing_metadata, + &lb_call_state); + } + } + // Chain to original callback. + if (self->failure_error_ != GRPC_ERROR_NONE) { + error = self->failure_error_; + self->failure_error_ = GRPC_ERROR_NONE; + } else { + error = GRPC_ERROR_REF(error); + } + Closure::Run(DEBUG_LOCATION, self->original_recv_trailing_metadata_ready_, + error); +} + +void ClientChannel::LoadBalancedCall::CreateSubchannelCall() { + SubchannelCall::Args call_args = { + std::move(connected_subchannel_), pollent_, path_, /*start_time=*/0, + deadline_, arena_, + // TODO(roth): When we implement hedging support, we will probably + // need to use a separate call context for each subchannel call. + call_context_, call_combiner_}; + grpc_error_handle error = GRPC_ERROR_NONE; + subchannel_call_ = SubchannelCall::Create(std::move(call_args), &error); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: create subchannel_call=%p: error=%s", chand_, + this, subchannel_call_.get(), grpc_error_std_string(error).c_str()); + } + if (on_call_destruction_complete_ != nullptr) { + subchannel_call_->SetAfterCallStackDestroy(on_call_destruction_complete_); + on_call_destruction_complete_ = nullptr; + } + if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { + PendingBatchesFail(error, YieldCallCombiner); + } else { + PendingBatchesResume(); + } +} + +// A class to handle the call combiner cancellation callback for a +// queued pick. +// TODO(roth): When we implement hedging support, we won't be able to +// register a call combiner cancellation closure for each LB pick, +// because there may be multiple LB picks happening in parallel. +// Instead, we will probably need to maintain a list in the CallData +// object of pending LB picks to be cancelled when the closure runs. +class ClientChannel::LoadBalancedCall::LbQueuedCallCanceller { + public: + explicit LbQueuedCallCanceller(RefCountedPtr lb_call) + : lb_call_(std::move(lb_call)) { + GRPC_CALL_STACK_REF(lb_call_->owning_call_, "LbQueuedCallCanceller"); + GRPC_CLOSURE_INIT(&closure_, &CancelLocked, this, nullptr); + lb_call_->call_combiner_->SetNotifyOnCancel(&closure_); + } + + private: + static void CancelLocked(void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + auto* lb_call = self->lb_call_.get(); + auto* chand = lb_call->chand_; + { + MutexLock lock(&chand->data_plane_mu_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: cancelling queued pick: " + "error=%s self=%p calld->pick_canceller=%p", + chand, lb_call, grpc_error_std_string(error).c_str(), self, + lb_call->lb_call_canceller_); + } + if (lb_call->lb_call_canceller_ == self && error != GRPC_ERROR_NONE) { + lb_call->call_dispatch_controller_->Commit(); + // Remove pick from list of queued picks. + lb_call->MaybeRemoveCallFromLbQueuedCallsLocked(); + // Fail pending batches on the call. 
+ lb_call->PendingBatchesFail(GRPC_ERROR_REF(error), + YieldCallCombinerIfPendingBatchesFound); + } + } + GRPC_CALL_STACK_UNREF(lb_call->owning_call_, "LbQueuedCallCanceller"); + delete self; + } + + RefCountedPtr lb_call_; + grpc_closure closure_; +}; + +void ClientChannel::LoadBalancedCall::MaybeRemoveCallFromLbQueuedCallsLocked() { + if (!queued_pending_lb_pick_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: removing from queued picks list", + chand_, this); + } + chand_->RemoveLbQueuedCall(&queued_call_, pollent_); + queued_pending_lb_pick_ = false; + // Lame the call combiner canceller. + lb_call_canceller_ = nullptr; +} + +void ClientChannel::LoadBalancedCall::MaybeAddCallToLbQueuedCallsLocked() { + if (queued_pending_lb_pick_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: adding to queued picks list", + chand_, this); + } + queued_pending_lb_pick_ = true; + queued_call_.lb_call = this; + chand_->AddLbQueuedCall(&queued_call_, pollent_); + // Register call combiner cancellation callback. + lb_call_canceller_ = new LbQueuedCallCanceller(Ref()); +} + +void ClientChannel::LoadBalancedCall::AsyncPickDone(grpc_error_handle error) { + // TODO(roth): Does this callback need to hold a ref to LoadBalancedCall? + GRPC_CLOSURE_INIT(&pick_closure_, PickDone, this, grpc_schedule_on_exec_ctx); + ExecCtx::Run(DEBUG_LOCATION, &pick_closure_, error); +} + +void ClientChannel::LoadBalancedCall::PickDone(void* arg, + grpc_error_handle error) { + auto* self = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: failed to pick subchannel: error=%s", + self->chand_, self, grpc_error_std_string(error).c_str()); + } + self->PendingBatchesFail(GRPC_ERROR_REF(error), YieldCallCombiner); + return; + } + self->call_dispatch_controller_->Commit(); + self->CreateSubchannelCall(); +} + +void ClientChannel::LoadBalancedCall::PickSubchannel(void* arg, + grpc_error_handle error) { + auto* self = static_cast(arg); + bool pick_complete; + { + MutexLock lock(&self->chand_->data_plane_mu_); + pick_complete = self->PickSubchannelLocked(&error); + } + if (pick_complete) { + PickDone(self, error); + GRPC_ERROR_UNREF(error); + } +} + +bool ClientChannel::LoadBalancedCall::PickSubchannelLocked( + grpc_error_handle* error) { + GPR_ASSERT(connected_subchannel_ == nullptr); + GPR_ASSERT(subchannel_call_ == nullptr); + // Grab initial metadata. + auto& send_initial_metadata = + pending_batches_[0]->payload->send_initial_metadata; + grpc_metadata_batch* initial_metadata_batch = + send_initial_metadata.send_initial_metadata; + const uint32_t send_initial_metadata_flags = + send_initial_metadata.send_initial_metadata_flags; + // Perform LB pick. 
+ LoadBalancingPolicy::PickArgs pick_args; + pick_args.path = StringViewFromSlice(path_); + LbCallState lb_call_state(this); + pick_args.call_state = &lb_call_state; + Metadata initial_metadata(this, initial_metadata_batch); + pick_args.initial_metadata = &initial_metadata; + auto result = chand_->picker_->Pick(pick_args); + return HandlePickResult( + &result, + // CompletePick + [this](LoadBalancingPolicy::PickResult::Complete* complete_pick) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::data_plane_mu_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, + "chand=%p lb_call=%p: LB pick succeeded: subchannel=%p", + chand_, this, complete_pick->subchannel.get()); + } + GPR_ASSERT(complete_pick->subchannel != nullptr); + // Grab a ref to the connected subchannel while we're still + // holding the data plane mutex. + SubchannelWrapper* subchannel = static_cast( + complete_pick->subchannel.get()); + connected_subchannel_ = subchannel->connected_subchannel(); + // If the subchannel has no connected subchannel (e.g., if the + // subchannel has moved out of state READY but the LB policy hasn't + // yet seen that change and given us a new picker), then just + // queue the pick. We'll try again as soon as we get a new picker. + // TODO(roth): In this case, we need to invoke the LB + // policy's recv_trailing_metadata_ready callback to tell it + // that the pick has been abandoned. + if (connected_subchannel_ == nullptr) { + MaybeAddCallToLbQueuedCallsLocked(); + return false; + } + lb_recv_trailing_metadata_ready_ = + std::move(complete_pick->recv_trailing_metadata_ready); + MaybeRemoveCallFromLbQueuedCallsLocked(); + return true; + }, + // QueuePick + [this](LoadBalancingPolicy::PickResult::Queue* /*queue_pick*/) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::data_plane_mu_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: LB pick queued", chand_, + this); + } + MaybeAddCallToLbQueuedCallsLocked(); + return false; + }, + // FailPick + [this, send_initial_metadata_flags, + &error](LoadBalancingPolicy::PickResult::Fail* fail_pick) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::data_plane_mu_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: LB pick failed: %s", + chand_, this, fail_pick->status.ToString().c_str()); + } + // If wait_for_ready is false, then the error indicates the RPC + // attempt's final status. + if ((send_initial_metadata_flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY) == 0) { + grpc_error_handle lb_error = + absl_status_to_grpc_error(fail_pick->status); + *error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed to pick subchannel", &lb_error, 1); + GRPC_ERROR_UNREF(lb_error); + MaybeRemoveCallFromLbQueuedCallsLocked(); + return true; + } + // If wait_for_ready is true, then queue to retry when we get a new + // picker. 
+ MaybeAddCallToLbQueuedCallsLocked(); + return false; + }, + // DropPick + [this, &error](LoadBalancingPolicy::PickResult::Drop* drop_pick) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::data_plane_mu_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_routing_trace)) { + gpr_log(GPR_INFO, "chand=%p lb_call=%p: LB pick dropped: %s", + chand_, this, drop_pick->status.ToString().c_str()); + } + *error = + grpc_error_set_int(absl_status_to_grpc_error(drop_pick->status), + GRPC_ERROR_INT_LB_POLICY_DROP, 1); + MaybeRemoveCallFromLbQueuedCallsLocked(); + return true; + }); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/client_channel_channelz.cc b/src/core/ext/filters/client_channel/client_channel_channelz.cc new file mode 100644 index 00000000..d543aa41 --- /dev/null +++ b/src/core/ext/filters/client_channel/client_channel_channelz.cc @@ -0,0 +1,96 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/client_channel_channelz.h" + +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/lib/channel/channelz_registry.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" + +namespace grpc_core { +namespace channelz { + +SubchannelNode::SubchannelNode(std::string target_address, + size_t channel_tracer_max_nodes) + : BaseNode(EntityType::kSubchannel, target_address), + target_(std::move(target_address)), + trace_(channel_tracer_max_nodes) {} + +SubchannelNode::~SubchannelNode() {} + +void SubchannelNode::UpdateConnectivityState(grpc_connectivity_state state) { + connectivity_state_.store(state, std::memory_order_relaxed); +} + +void SubchannelNode::SetChildSocket(RefCountedPtr socket) { + MutexLock lock(&socket_mu_); + child_socket_ = std::move(socket); +} + +Json SubchannelNode::RenderJson() { + // Create and fill the data child. + grpc_connectivity_state state = + connectivity_state_.load(std::memory_order_relaxed); + Json::Object data = { + {"state", + Json::Object{ + {"state", ConnectivityStateName(state)}, + }}, + {"target", target_}, + }; + + // Fill in the channel trace if applicable + Json trace_json = trace_.RenderJson(); + if (trace_json.type() != Json::Type::JSON_NULL) { + data["trace"] = std::move(trace_json); + } + // Ask CallCountingHelper to populate call count data. + call_counter_.PopulateCallCounts(&data); + // Construct top-level object. + Json::Object object{ + {"ref", + Json::Object{ + {"subchannelId", std::to_string(uuid())}, + }}, + {"data", std::move(data)}, + }; + // Populate the child socket. 
+ RefCountedPtr child_socket; + { + MutexLock lock(&socket_mu_); + child_socket = child_socket_; + } + if (child_socket != nullptr && child_socket->uuid() != 0) { + object["socketRef"] = Json::Array{ + Json::Object{ + {"socketId", std::to_string(child_socket->uuid())}, + {"name", child_socket->name()}, + }, + }; + } + return object; +} + +} // namespace channelz +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/client_channel_factory.cc b/src/core/ext/filters/client_channel/client_channel_factory.cc new file mode 100644 index 00000000..7e234a3e --- /dev/null +++ b/src/core/ext/filters/client_channel/client_channel_factory.cc @@ -0,0 +1,56 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/client_channel_factory.h" + +#include "src/core/lib/channel/channel_args.h" + +// Channel arg key for client channel factory. +#define GRPC_ARG_CLIENT_CHANNEL_FACTORY "grpc.client_channel_factory" + +namespace grpc_core { + +namespace { + +void* factory_arg_copy(void* f) { return f; } +void factory_arg_destroy(void* /*f*/) {} +int factory_arg_cmp(void* factory1, void* factory2) { + return QsortCompare(factory1, factory2); +} +const grpc_arg_pointer_vtable factory_arg_vtable = { + factory_arg_copy, factory_arg_destroy, factory_arg_cmp}; + +} // namespace + +grpc_arg ClientChannelFactory::CreateChannelArg(ClientChannelFactory* factory) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_CLIENT_CHANNEL_FACTORY), factory, + &factory_arg_vtable); +} + +ClientChannelFactory* ClientChannelFactory::GetFromChannelArgs( + const grpc_channel_args* args) { + const grpc_arg* arg = + grpc_channel_args_find(args, GRPC_ARG_CLIENT_CHANNEL_FACTORY); + if (arg == nullptr || arg->type != GRPC_ARG_POINTER) return nullptr; + return static_cast(arg->value.pointer.p); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/client_channel_plugin.cc b/src/core/ext/filters/client_channel/client_channel_plugin.cc new file mode 100644 index 00000000..82205181 --- /dev/null +++ b/src/core/ext/filters/client_channel/client_channel_plugin.cc @@ -0,0 +1,74 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
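CreateChannelArg() and GetFromChannelArgs() in client_channel_factory.cc above are the two halves of the usual pointer-valued channel-arg round trip. A hedged usage fragment, assuming the relevant gRPC core headers are already included and that a long-lived factory instance exists; this is a sketch, not a complete translation unit:

// `factory` must outlive the returned args: the vtable installed by
// CreateChannelArg() copies only the raw pointer and never frees it.
grpc_channel_args* AddFactoryToArgs(const grpc_channel_args* base_args,
                                    grpc_core::ClientChannelFactory* factory) {
  grpc_arg factory_arg =
      grpc_core::ClientChannelFactory::CreateChannelArg(factory);
  return grpc_channel_args_copy_and_add(base_args, &factory_arg, 1);
}

// Later, e.g. when the client channel filter is initialized:
//   grpc_core::ClientChannelFactory* factory =
//       grpc_core::ClientChannelFactory::GetFromChannelArgs(args);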
+ * + */ + +#include + +#include +#include +#include + +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/client_channel_channelz.h" +#include "src/core/ext/filters/client_channel/global_subchannel_pool.h" +#include "src/core/ext/filters/client_channel/http_connect_handshaker.h" +#include "src/core/ext/filters/client_channel/http_proxy.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/resolver_result_parsing.h" +#include "src/core/ext/filters/client_channel/retry_service_config.h" +#include "src/core/ext/filters/client_channel/retry_throttle.h" +#include "src/core/lib/config/core_configuration.h" + +void grpc_client_channel_init(void) { + grpc_core::internal::ClientChannelServiceConfigParser::Register(); + grpc_core::internal::RetryServiceConfigParser::Register(); + grpc_core::LoadBalancingPolicyRegistry::Builder::InitRegistry(); + grpc_core::ResolverRegistry::Builder::InitRegistry(); + grpc_core::internal::ServerRetryThrottleMap::Init(); + grpc_core::ProxyMapperRegistry::Init(); + grpc_core::RegisterHttpProxyMapper(); + grpc_core::GlobalSubchannelPool::Init(); + grpc_client_channel_global_init_backup_polling(); +} + +void grpc_client_channel_shutdown(void) { + grpc_core::GlobalSubchannelPool::Shutdown(); + grpc_core::ProxyMapperRegistry::Shutdown(); + grpc_core::internal::ServerRetryThrottleMap::Shutdown(); + grpc_core::ResolverRegistry::Builder::ShutdownRegistry(); + grpc_core::LoadBalancingPolicyRegistry::Builder::ShutdownRegistry(); +} + +namespace grpc_core { + +void BuildClientChannelConfiguration(CoreConfiguration::Builder* builder) { + RegisterHttpConnectHandshaker(builder); + builder->channel_init()->RegisterStage( + GRPC_CLIENT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [](grpc_channel_stack_builder* builder) { + return grpc_channel_stack_builder_append_filter( + builder, &grpc_core::ClientChannel::kFilterVtable, nullptr, + nullptr); + }); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/config_selector.cc b/src/core/ext/filters/client_channel/config_selector.cc new file mode 100644 index 00000000..e741a3e6 --- /dev/null +++ b/src/core/ext/filters/client_channel/config_selector.cc @@ -0,0 +1,59 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
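BuildClientChannelConfiguration() in client_channel_plugin.cc above shows the registration shape for any filter appended at channel-init time: choose a stack type and priority, then append the filter's vtable from a channel_init callback. A sketch of the same shape for a hypothetical filter; my_filter_vtable and RegisterMyFilter are illustrative names, and the gRPC core headers are assumed to be included:

namespace grpc_core {

// Hypothetical filter vtable; a real one would be a grpc_channel_filter
// defined alongside the filter implementation.
extern const grpc_channel_filter my_filter_vtable;

void RegisterMyFilter(CoreConfiguration::Builder* builder) {
  builder->channel_init()->RegisterStage(
      GRPC_CLIENT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY,
      [](grpc_channel_stack_builder* builder) {
        // Returning false would abort channel creation.
        return grpc_channel_stack_builder_append_filter(
            builder, &my_filter_vtable, nullptr, nullptr);
      });
}

}  // namespace grpc_core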
+//
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/ext/filters/client_channel/config_selector.h"
+
+#include "src/core/lib/channel/channel_args.h"
+
+namespace grpc_core {
+
+namespace {
+
+void* ConfigSelectorArgCopy(void* p) {
+  ConfigSelector* config_selector = static_cast<ConfigSelector*>(p);
+  config_selector->Ref().release();
+  return p;
+}
+
+void ConfigSelectorArgDestroy(void* p) {
+  ConfigSelector* config_selector = static_cast<ConfigSelector*>(p);
+  config_selector->Unref();
+}
+
+int ConfigSelectorArgCmp(void* p, void* q) { return QsortCompare(p, q); }
+
+const grpc_arg_pointer_vtable kChannelArgVtable = {
+    ConfigSelectorArgCopy, ConfigSelectorArgDestroy, ConfigSelectorArgCmp};
+
+}  // namespace
+
+grpc_arg ConfigSelector::MakeChannelArg() const {
+  return grpc_channel_arg_pointer_create(
+      const_cast<char*>(GRPC_ARG_CONFIG_SELECTOR),
+      const_cast<ConfigSelector*>(this), &kChannelArgVtable);
+}
+
+RefCountedPtr<ConfigSelector> ConfigSelector::GetFromChannelArgs(
+    const grpc_channel_args& args) {
+  ConfigSelector* config_selector =
+      grpc_channel_args_find_pointer<ConfigSelector>(&args,
+                                                     GRPC_ARG_CONFIG_SELECTOR);
+  return config_selector != nullptr ? config_selector->Ref() : nullptr;
+}
+
+}  // namespace grpc_core
diff --git a/src/core/ext/filters/client_channel/dynamic_filters.cc b/src/core/ext/filters/client_channel/dynamic_filters.cc
new file mode 100644
index 00000000..8e7f5db2
--- /dev/null
+++ b/src/core/ext/filters/client_channel/dynamic_filters.cc
@@ -0,0 +1,190 @@
+//
+// Copyright 2020 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/ext/filters/client_channel/dynamic_filters.h"
+
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/channel/channel_stack.h"
+#include "src/core/lib/surface/lame_client.h"
+
+// Conversion between call and call stack.
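The CALL_TO_CALL_STACK and CALL_STACK_TO_CALL macros that follow rely on a single arena allocation in which the Call object is immediately followed, after alignment padding, by its grpc_call_stack, so either pointer can be recovered from the other with fixed offset arithmetic. A self-contained sketch of that layout under simplified, placeholder types:

#include <cstddef>
#include <cstdlib>
#include <new>

// Placeholder stand-ins for DynamicFilters::Call and grpc_call_stack.
struct Call { int state = 0; };
struct CallStack { int refcount = 1; };

constexpr size_t RoundUpToAlignment(size_t size) {
  constexpr size_t kAlign = alignof(std::max_align_t);
  return (size + kAlign - 1) & ~(kAlign - 1);
}

CallStack* CallToCallStack(Call* call) {
  return reinterpret_cast<CallStack*>(reinterpret_cast<char*>(call) +
                                      RoundUpToAlignment(sizeof(Call)));
}

Call* CallStackToCall(CallStack* stack) {
  return reinterpret_cast<Call*>(reinterpret_cast<char*>(stack) -
                                 RoundUpToAlignment(sizeof(Call)));
}

int main() {
  // One allocation holds the Call followed by its CallStack, mirroring the
  // sizing done when the call is created in the arena.
  const size_t total = RoundUpToAlignment(sizeof(Call)) + sizeof(CallStack);
  void* storage = std::malloc(total);
  Call* call = new (storage) Call;
  CallStack* stack = new (CallToCallStack(call)) CallStack;
  const bool ok = (CallStackToCall(stack) == call);
  stack->~CallStack();
  call->~Call();
  std::free(storage);
  return ok ? 0 : 1;
}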
+#define CALL_TO_CALL_STACK(call) \ + (grpc_call_stack*)((char*)(call) + GPR_ROUND_UP_TO_ALIGNMENT_SIZE( \ + sizeof(DynamicFilters::Call))) +#define CALL_STACK_TO_CALL(callstack) \ + (DynamicFilters::Call*)(((char*)(call_stack)) - \ + GPR_ROUND_UP_TO_ALIGNMENT_SIZE( \ + sizeof(DynamicFilters::Call))) + +namespace grpc_core { + +// +// DynamicFilters::Call +// + +DynamicFilters::Call::Call(Args args, grpc_error_handle* error) + : channel_stack_(std::move(args.channel_stack)) { + grpc_call_stack* call_stack = CALL_TO_CALL_STACK(this); + const grpc_call_element_args call_args = { + call_stack, /* call_stack */ + nullptr, /* server_transport_data */ + args.context, /* context */ + args.path, /* path */ + args.start_time, /* start_time */ + args.deadline, /* deadline */ + args.arena, /* arena */ + args.call_combiner /* call_combiner */ + }; + *error = grpc_call_stack_init(channel_stack_->channel_stack_, 1, Destroy, + this, &call_args); + if (GPR_UNLIKELY(*error != GRPC_ERROR_NONE)) { + gpr_log(GPR_ERROR, "error: %s", grpc_error_std_string(*error).c_str()); + return; + } + grpc_call_stack_set_pollset_or_pollset_set(call_stack, args.pollent); +} + +void DynamicFilters::Call::StartTransportStreamOpBatch( + grpc_transport_stream_op_batch* batch) { + grpc_call_stack* call_stack = CALL_TO_CALL_STACK(this); + grpc_call_element* top_elem = grpc_call_stack_element(call_stack, 0); + GRPC_CALL_LOG_OP(GPR_INFO, top_elem, batch); + top_elem->filter->start_transport_stream_op_batch(top_elem, batch); +} + +void DynamicFilters::Call::SetAfterCallStackDestroy(grpc_closure* closure) { + GPR_ASSERT(after_call_stack_destroy_ == nullptr); + GPR_ASSERT(closure != nullptr); + after_call_stack_destroy_ = closure; +} + +RefCountedPtr DynamicFilters::Call::Ref() { + IncrementRefCount(); + return RefCountedPtr(this); +} + +RefCountedPtr DynamicFilters::Call::Ref( + const grpc_core::DebugLocation& location, const char* reason) { + IncrementRefCount(location, reason); + return RefCountedPtr(this); +} + +void DynamicFilters::Call::Unref() { + GRPC_CALL_STACK_UNREF(CALL_TO_CALL_STACK(this), ""); +} + +void DynamicFilters::Call::Unref(const DebugLocation& /*location*/, + const char* reason) { + GRPC_CALL_STACK_UNREF(CALL_TO_CALL_STACK(this), reason); +} + +void DynamicFilters::Call::Destroy(void* arg, grpc_error_handle /*error*/) { + DynamicFilters::Call* self = static_cast(arg); + // Keep some members before destroying the subchannel call. + grpc_closure* after_call_stack_destroy = self->after_call_stack_destroy_; + RefCountedPtr channel_stack = std::move(self->channel_stack_); + // Destroy the subchannel call. + self->~Call(); + // Destroy the call stack. This should be after destroying the call, because + // call->after_call_stack_destroy(), if not null, will free the call arena. + grpc_call_stack_destroy(CALL_TO_CALL_STACK(self), nullptr, + after_call_stack_destroy); + // Automatically reset channel_stack. This should be after destroying the call + // stack, because destroying call stack needs access to the channel stack. 
+} + +void DynamicFilters::Call::IncrementRefCount() { + GRPC_CALL_STACK_REF(CALL_TO_CALL_STACK(this), ""); +} + +void DynamicFilters::Call::IncrementRefCount( + const grpc_core::DebugLocation& /*location*/, const char* reason) { + GRPC_CALL_STACK_REF(CALL_TO_CALL_STACK(this), reason); +} + +// +// DynamicFilters +// + +namespace { + +void DestroyChannelStack(void* arg, grpc_error_handle /*error*/) { + grpc_channel_stack* channel_stack = static_cast(arg); + grpc_channel_stack_destroy(channel_stack); + gpr_free(channel_stack); +} + +std::pair CreateChannelStack( + const grpc_channel_args* args, + std::vector filters) { + // Allocate memory for channel stack. + const size_t channel_stack_size = + grpc_channel_stack_size(filters.data(), filters.size()); + grpc_channel_stack* channel_stack = + reinterpret_cast(gpr_zalloc(channel_stack_size)); + // Initialize stack. + grpc_error_handle error = grpc_channel_stack_init( + /*initial_refs=*/1, DestroyChannelStack, channel_stack, filters.data(), + filters.size(), args, /*optional_transport=*/nullptr, "DynamicFilters", + channel_stack); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "error initializing client internal stack: %s", + grpc_error_std_string(error).c_str()); + grpc_channel_stack_destroy(channel_stack); + gpr_free(channel_stack); + return {nullptr, error}; + } + return {channel_stack, GRPC_ERROR_NONE}; +} + +} // namespace + +RefCountedPtr DynamicFilters::Create( + const grpc_channel_args* args, + std::vector filters) { + // Attempt to create channel stack from requested filters. + auto p = CreateChannelStack(args, std::move(filters)); + if (p.second != GRPC_ERROR_NONE) { + // Channel stack creation failed with requested filters. + // Create with lame filter instead. + grpc_error_handle error = p.second; + grpc_arg error_arg = MakeLameClientErrorArg(&error); + grpc_channel_args* new_args = + grpc_channel_args_copy_and_add(args, &error_arg, 1); + GRPC_ERROR_UNREF(error); + p = CreateChannelStack(new_args, {&grpc_lame_filter}); + GPR_ASSERT(p.second == GRPC_ERROR_NONE); + grpc_channel_args_destroy(new_args); + } + return MakeRefCounted(p.first); +} + +DynamicFilters::~DynamicFilters() { + GRPC_CHANNEL_STACK_UNREF(channel_stack_, "~DynamicFilters"); +} + +RefCountedPtr DynamicFilters::CreateCall( + DynamicFilters::Call::Args args, grpc_error_handle* error) { + size_t allocation_size = GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(Call)) + + channel_stack_->call_stack_size; + Call* call = static_cast(args.arena->Alloc(allocation_size)); + new (call) Call(std::move(args), error); + return call; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/global_subchannel_pool.cc b/src/core/ext/filters/client_channel/global_subchannel_pool.cc new file mode 100644 index 00000000..72593a6f --- /dev/null +++ b/src/core/ext/filters/client_channel/global_subchannel_pool.cc @@ -0,0 +1,83 @@ +// +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include "src/core/ext/filters/client_channel/global_subchannel_pool.h" + +#include "src/core/ext/filters/client_channel/subchannel.h" + +namespace grpc_core { + +#define GRPC_REGISTER_SUBCHANNEL_CALM_DOWN_AFTER_ATTEMPTS 100 +#define GRPC_REGISTER_SUBCHANNEL_CALM_DOWN_MICROS 10 + +void GlobalSubchannelPool::Init() { + instance_ = new RefCountedPtr( + MakeRefCounted()); +} + +void GlobalSubchannelPool::Shutdown() { + // To ensure Init() was called before. + GPR_ASSERT(instance_ != nullptr); + // To ensure Shutdown() was not called before. + GPR_ASSERT(*instance_ != nullptr); + instance_->reset(); + delete instance_; +} + +RefCountedPtr GlobalSubchannelPool::instance() { + GPR_ASSERT(instance_ != nullptr); + GPR_ASSERT(*instance_ != nullptr); + return *instance_; +} + +RefCountedPtr GlobalSubchannelPool::RegisterSubchannel( + const SubchannelKey& key, RefCountedPtr constructed) { + MutexLock lock(&mu_); + auto it = subchannel_map_.find(key); + if (it != subchannel_map_.end()) { + RefCountedPtr existing = it->second->RefIfNonZero(); + if (existing != nullptr) return existing; + } + subchannel_map_[key] = constructed.get(); + return constructed; +} + +RefCountedPtr* GlobalSubchannelPool::instance_ = nullptr; + +void GlobalSubchannelPool::UnregisterSubchannel(const SubchannelKey& key, + Subchannel* subchannel) { + MutexLock lock(&mu_); + auto it = subchannel_map_.find(key); + // delete only if key hasn't been re-registered to a different subchannel + // between strong-unreffing and unregistration of subchannel. + if (it != subchannel_map_.end() && it->second == subchannel) { + subchannel_map_.erase(it); + } +} + +RefCountedPtr GlobalSubchannelPool::FindSubchannel( + const SubchannelKey& key) { + MutexLock lock(&mu_); + auto it = subchannel_map_.find(key); + if (it == subchannel_map_.end()) return nullptr; + return it->second->RefIfNonZero(); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/health/health_check_client.cc b/src/core/ext/filters/client_channel/health/health_check_client.cc new file mode 100644 index 00000000..53d13a44 --- /dev/null +++ b/src/core/ext/filters/client_channel/health/health_check_client.cc @@ -0,0 +1,619 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
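[Editorial sketch] The GlobalSubchannelPool above is essentially a keyed registry of weakly held subchannels: Register/Find resolve races with subchannel destruction via RefIfNonZero(), and Unregister erases an entry only if the key still maps to the same object. A small sketch of that shape, with std::weak_ptr standing in for the RefIfNonZero() handshake and hypothetical names throughout (the real pool stores raw pointers, so its identity check also works mid-destruction):

#include <map>
#include <memory>
#include <mutex>
#include <string>

namespace sketch {

struct Subchannel {
  std::string target;
};

class Pool {
 public:
  // Returns the already-registered subchannel if it is still alive,
  // otherwise registers (or re-registers) `constructed`.
  std::shared_ptr<Subchannel> Register(const std::string& key,
                                       std::shared_ptr<Subchannel> constructed) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(key);
    if (it != map_.end()) {
      if (auto existing = it->second.lock()) return existing;
    }
    map_[key] = constructed;
    return constructed;
  }

  // Erase only if the key still maps to this exact subchannel; another
  // subchannel may have been registered under the key in the meantime.
  void Unregister(const std::string& key, const Subchannel* subchannel) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(key);
    if (it != map_.end() && it->second.lock().get() == subchannel) {
      map_.erase(it);
    }
  }

  std::shared_ptr<Subchannel> Find(const std::string& key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(key);
    if (it == map_.end()) return nullptr;
    return it->second.lock();  // may be null if already destroyed
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::weak_ptr<Subchannel>> map_;
};

}  // namespace sketch

int main() {
  sketch::Pool pool;
  auto sc = std::make_shared<sketch::Subchannel>();
  auto got = pool.Register("ipv4:127.0.0.1:50051", sc);
  return got == sc ? 0 : 1;
}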
+ * + */ + +#include + +#include "src/core/ext/filters/client_channel/health/health_check_client.h" + +#include +#include + +#include "upb/upb.hpp" + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/status_metadata.h" +#include "src/proto/grpc/health/v1/health.upb.h" + +#define HEALTH_CHECK_INITIAL_CONNECT_BACKOFF_SECONDS 1 +#define HEALTH_CHECK_RECONNECT_BACKOFF_MULTIPLIER 1.6 +#define HEALTH_CHECK_RECONNECT_MAX_BACKOFF_SECONDS 120 +#define HEALTH_CHECK_RECONNECT_JITTER 0.2 + +namespace grpc_core { + +TraceFlag grpc_health_check_client_trace(false, "health_check_client"); + +// +// HealthCheckClient +// + +HealthCheckClient::HealthCheckClient( + std::string service_name, + RefCountedPtr connected_subchannel, + grpc_pollset_set* interested_parties, + RefCountedPtr channelz_node, + RefCountedPtr watcher) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace) + ? "HealthCheckClient" + : nullptr), + service_name_(std::move(service_name)), + connected_subchannel_(std::move(connected_subchannel)), + interested_parties_(interested_parties), + channelz_node_(std::move(channelz_node)), + watcher_(std::move(watcher)), + retry_backoff_( + BackOff::Options() + .set_initial_backoff( + HEALTH_CHECK_INITIAL_CONNECT_BACKOFF_SECONDS * 1000) + .set_multiplier(HEALTH_CHECK_RECONNECT_BACKOFF_MULTIPLIER) + .set_jitter(HEALTH_CHECK_RECONNECT_JITTER) + .set_max_backoff(HEALTH_CHECK_RECONNECT_MAX_BACKOFF_SECONDS * + 1000)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "created HealthCheckClient %p", this); + } + GRPC_CLOSURE_INIT(&retry_timer_callback_, OnRetryTimer, this, + grpc_schedule_on_exec_ctx); + StartCall(); +} + +HealthCheckClient::~HealthCheckClient() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "destroying HealthCheckClient %p", this); + } +} + +void HealthCheckClient::SetHealthStatus(grpc_connectivity_state state, + const char* reason) { + MutexLock lock(&mu_); + SetHealthStatusLocked(state, reason); +} + +void HealthCheckClient::SetHealthStatusLocked(grpc_connectivity_state state, + const char* reason) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "HealthCheckClient %p: setting state=%s reason=%s", this, + ConnectivityStateName(state), reason); + } + if (watcher_ != nullptr) { + watcher_->Notify(state, + state == GRPC_CHANNEL_TRANSIENT_FAILURE + ? 
absl::Status(absl::StatusCode::kUnavailable, reason) + : absl::Status()); + } +} + +void HealthCheckClient::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "HealthCheckClient %p: shutting down", this); + } + { + MutexLock lock(&mu_); + shutting_down_ = true; + watcher_.reset(); + call_state_.reset(); + if (retry_timer_callback_pending_) { + grpc_timer_cancel(&retry_timer_); + } + } + Unref(DEBUG_LOCATION, "orphan"); +} + +void HealthCheckClient::StartCall() { + MutexLock lock(&mu_); + StartCallLocked(); +} + +void HealthCheckClient::StartCallLocked() { + if (shutting_down_) return; + GPR_ASSERT(call_state_ == nullptr); + SetHealthStatusLocked(GRPC_CHANNEL_CONNECTING, "starting health watch"); + call_state_ = MakeOrphanable(Ref(), interested_parties_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "HealthCheckClient %p: created CallState %p", this, + call_state_.get()); + } + call_state_->StartCall(); +} + +void HealthCheckClient::StartRetryTimerLocked() { + SetHealthStatusLocked(GRPC_CHANNEL_TRANSIENT_FAILURE, + "health check call failed; will retry after backoff"); + grpc_millis next_try = retry_backoff_.NextAttemptTime(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "HealthCheckClient %p: health check call lost...", this); + grpc_millis timeout = next_try - ExecCtx::Get()->Now(); + if (timeout > 0) { + gpr_log(GPR_INFO, + "HealthCheckClient %p: ... will retry in %" PRId64 "ms.", this, + timeout); + } else { + gpr_log(GPR_INFO, "HealthCheckClient %p: ... retrying immediately.", + this); + } + } + // Ref for callback, tracked manually. + Ref(DEBUG_LOCATION, "health_retry_timer").release(); + retry_timer_callback_pending_ = true; + grpc_timer_init(&retry_timer_, next_try, &retry_timer_callback_); +} + +void HealthCheckClient::OnRetryTimer(void* arg, grpc_error_handle error) { + HealthCheckClient* self = static_cast(arg); + { + MutexLock lock(&self->mu_); + self->retry_timer_callback_pending_ = false; + if (!self->shutting_down_ && error == GRPC_ERROR_NONE && + self->call_state_ == nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "HealthCheckClient %p: restarting health check call", + self); + } + self->StartCallLocked(); + } + } + self->Unref(DEBUG_LOCATION, "health_retry_timer"); +} + +// +// protobuf helpers +// + +namespace { + +void EncodeRequest(const std::string& service_name, + ManualConstructor* send_message) { + upb::Arena arena; + grpc_health_v1_HealthCheckRequest* request_struct = + grpc_health_v1_HealthCheckRequest_new(arena.ptr()); + grpc_health_v1_HealthCheckRequest_set_service( + request_struct, + upb_strview_make(service_name.data(), service_name.size())); + size_t buf_length; + char* buf = grpc_health_v1_HealthCheckRequest_serialize( + request_struct, arena.ptr(), &buf_length); + grpc_slice request_slice = GRPC_SLICE_MALLOC(buf_length); + memcpy(GRPC_SLICE_START_PTR(request_slice), buf, buf_length); + grpc_slice_buffer slice_buffer; + grpc_slice_buffer_init(&slice_buffer); + grpc_slice_buffer_add(&slice_buffer, request_slice); + send_message->Init(&slice_buffer, 0); + grpc_slice_buffer_destroy_internal(&slice_buffer); +} + +// Returns true if healthy. +// If there was an error parsing the response, sets *error and returns false. +bool DecodeResponse(grpc_slice_buffer* slice_buffer, grpc_error_handle* error) { + // If message is empty, assume unhealthy. 
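[Editorial sketch] Before the parsing code that follows, it may help to see the decision DecodeResponse() makes in isolation: an empty payload and an unparseable payload are both reported as errors, and only a SERVING status counts as healthy. A standalone sketch of that control flow, with a trivial stand-in for the upb-generated parser:

#include <optional>
#include <string>

namespace sketch {

enum class ServingStatus { kUnknown, kServing, kNotServing };

// Stand-in for grpc_health_v1_HealthCheckResponse_parse(): only "SERVING"
// and "NOT_SERVING" are accepted; anything else is a parse failure.
std::optional<ServingStatus> Parse(const std::string& payload) {
  if (payload == "SERVING") return ServingStatus::kServing;
  if (payload == "NOT_SERVING") return ServingStatus::kNotServing;
  return std::nullopt;
}

// Returns true only for a well-formed SERVING response; fills *error for
// empty or unparseable payloads, mirroring DecodeResponse().
bool Decode(const std::string& payload, std::string* error) {
  if (payload.empty()) {
    *error = "health check response was empty";
    return false;
  }
  std::optional<ServingStatus> status = Parse(payload);
  if (!status.has_value()) {
    *error = "cannot parse health check response";
    return false;
  }
  return *status == ServingStatus::kServing;
}

}  // namespace sketch

int main() {
  std::string error;
  return sketch::Decode("SERVING", &error) ? 0 : 1;
}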
+ if (slice_buffer->length == 0) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("health check response was empty"); + return false; + } + // Concatenate the slices to form a single string. + std::unique_ptr recv_message_deleter; + uint8_t* recv_message; + if (slice_buffer->count == 1) { + recv_message = GRPC_SLICE_START_PTR(slice_buffer->slices[0]); + } else { + recv_message = static_cast(gpr_malloc(slice_buffer->length)); + recv_message_deleter.reset(recv_message); + size_t offset = 0; + for (size_t i = 0; i < slice_buffer->count; ++i) { + memcpy(recv_message + offset, + GRPC_SLICE_START_PTR(slice_buffer->slices[i]), + GRPC_SLICE_LENGTH(slice_buffer->slices[i])); + offset += GRPC_SLICE_LENGTH(slice_buffer->slices[i]); + } + } + // Deserialize message. + upb::Arena arena; + grpc_health_v1_HealthCheckResponse* response_struct = + grpc_health_v1_HealthCheckResponse_parse( + reinterpret_cast(recv_message), slice_buffer->length, + arena.ptr()); + if (response_struct == nullptr) { + // Can't parse message; assume unhealthy. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "cannot parse health check response"); + return false; + } + int32_t status = grpc_health_v1_HealthCheckResponse_status(response_struct); + return status == grpc_health_v1_HealthCheckResponse_SERVING; +} + +} // namespace + +// +// HealthCheckClient::CallState +// + +HealthCheckClient::CallState::CallState( + RefCountedPtr health_check_client, + grpc_pollset_set* interested_parties) + : health_check_client_(std::move(health_check_client)), + pollent_(grpc_polling_entity_create_from_pollset_set(interested_parties)), + arena_(Arena::Create(health_check_client_->connected_subchannel_ + ->GetInitialCallSizeEstimate())), + payload_(context_), + send_initial_metadata_(arena_), + send_trailing_metadata_(arena_), + recv_initial_metadata_(arena_), + recv_trailing_metadata_(arena_) {} + +HealthCheckClient::CallState::~CallState() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, "HealthCheckClient %p: destroying CallState %p", + health_check_client_.get(), this); + } + for (size_t i = 0; i < GRPC_CONTEXT_COUNT; i++) { + if (context_[i].destroy != nullptr) { + context_[i].destroy(context_[i].value); + } + } + // Unset the call combiner cancellation closure. This has the + // effect of scheduling the previously set cancellation closure, if + // any, so that it can release any internal references it may be + // holding to the call stack. + call_combiner_.SetNotifyOnCancel(nullptr); + arena_->Destroy(); +} + +void HealthCheckClient::CallState::Orphan() { + call_combiner_.Cancel(GRPC_ERROR_CANCELLED); + Cancel(); +} + +void HealthCheckClient::CallState::StartCall() { + SubchannelCall::Args args = { + health_check_client_->connected_subchannel_, + &pollent_, + GRPC_MDSTR_SLASH_GRPC_DOT_HEALTH_DOT_V1_DOT_HEALTH_SLASH_WATCH, + gpr_get_cycle_counter(), // start_time + GRPC_MILLIS_INF_FUTURE, // deadline + arena_, + context_, + &call_combiner_, + }; + grpc_error_handle error = GRPC_ERROR_NONE; + call_ = SubchannelCall::Create(std::move(args), &error).release(); + // Register after-destruction callback. + GRPC_CLOSURE_INIT(&after_call_stack_destruction_, AfterCallStackDestruction, + this, grpc_schedule_on_exec_ctx); + call_->SetAfterCallStackDestroy(&after_call_stack_destruction_); + // Check if creation failed. 
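[Editorial sketch] If creation fails (checked just below), the call is ended with retry=true, which eventually lands in StartRetryTimerLocked() and the BackOff configured in the constructor (1s initial delay, 1.6x growth, 20% jitter, 120s cap). A minimal sketch of that kind of jittered exponential backoff, independent of grpc_core::BackOff:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <random>

namespace sketch {

// Hypothetical stand-in for grpc_core::BackOff, using the HEALTH_CHECK_*
// constants defined near the top of this file.
class Backoff {
 public:
  // Returns the delay (in ms) to wait before the next attempt.
  int64_t NextDelayMs() {
    std::uniform_real_distribution<double> jitter(1.0 - kJitter, 1.0 + kJitter);
    int64_t delay = static_cast<int64_t>(current_ms_ * jitter(rng_));
    current_ms_ = std::min<double>(current_ms_ * kMultiplier, kMaxMs);
    return delay;
  }

  // Called after a successful response so the next failure retries quickly.
  void Reset() { current_ms_ = kInitialMs; }

 private:
  static constexpr double kInitialMs = 1000.0;    // 1 second
  static constexpr double kMultiplier = 1.6;
  static constexpr double kJitter = 0.2;          // +/- 20%
  static constexpr double kMaxMs = 120 * 1000.0;  // 2 minutes
  double current_ms_ = kInitialMs;
  std::mt19937 rng_{std::random_device{}()};
};

}  // namespace sketch

int main() {
  sketch::Backoff backoff;
  for (int i = 0; i < 5; ++i) {
    std::printf("attempt %d: wait %lld ms\n", i,
                static_cast<long long>(backoff.NextDelayMs()));
  }
  return 0;
}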
+ if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "HealthCheckClient %p CallState %p: error creating health " + "checking call on subchannel (%s); will retry", + health_check_client_.get(), this, + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + CallEndedLocked(/*retry=*/true); + return; + } + // Initialize payload and batch. + payload_.context = context_; + batch_.payload = &payload_; + // on_complete callback takes ref, handled manually. + call_->Ref(DEBUG_LOCATION, "on_complete").release(); + batch_.on_complete = GRPC_CLOSURE_INIT(&on_complete_, OnComplete, this, + grpc_schedule_on_exec_ctx); + // Add send_initial_metadata op. + error = grpc_metadata_batch_add_head( + &send_initial_metadata_, &path_metadata_storage_, + grpc_mdelem_from_slices( + GRPC_MDSTR_PATH, + GRPC_MDSTR_SLASH_GRPC_DOT_HEALTH_DOT_V1_DOT_HEALTH_SLASH_WATCH), + GRPC_BATCH_PATH); + GPR_ASSERT(error == GRPC_ERROR_NONE); + payload_.send_initial_metadata.send_initial_metadata = + &send_initial_metadata_; + payload_.send_initial_metadata.send_initial_metadata_flags = 0; + payload_.send_initial_metadata.peer_string = nullptr; + batch_.send_initial_metadata = true; + // Add send_message op. + EncodeRequest(health_check_client_->service_name_, &send_message_); + payload_.send_message.send_message.reset(send_message_.get()); + batch_.send_message = true; + // Add send_trailing_metadata op. + payload_.send_trailing_metadata.send_trailing_metadata = + &send_trailing_metadata_; + batch_.send_trailing_metadata = true; + // Add recv_initial_metadata op. + payload_.recv_initial_metadata.recv_initial_metadata = + &recv_initial_metadata_; + payload_.recv_initial_metadata.recv_flags = nullptr; + payload_.recv_initial_metadata.trailing_metadata_available = nullptr; + payload_.recv_initial_metadata.peer_string = nullptr; + // recv_initial_metadata_ready callback takes ref, handled manually. + call_->Ref(DEBUG_LOCATION, "recv_initial_metadata_ready").release(); + payload_.recv_initial_metadata.recv_initial_metadata_ready = + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_, RecvInitialMetadataReady, + this, grpc_schedule_on_exec_ctx); + batch_.recv_initial_metadata = true; + // Add recv_message op. + payload_.recv_message.recv_message = &recv_message_; + payload_.recv_message.call_failed_before_recv_message = nullptr; + // recv_message callback takes ref, handled manually. + call_->Ref(DEBUG_LOCATION, "recv_message_ready").release(); + payload_.recv_message.recv_message_ready = GRPC_CLOSURE_INIT( + &recv_message_ready_, RecvMessageReady, this, grpc_schedule_on_exec_ctx); + batch_.recv_message = true; + // Start batch. + StartBatch(&batch_); + // Initialize recv_trailing_metadata batch. + recv_trailing_metadata_batch_.payload = &payload_; + // Add recv_trailing_metadata op. + payload_.recv_trailing_metadata.recv_trailing_metadata = + &recv_trailing_metadata_; + payload_.recv_trailing_metadata.collect_stats = &collect_stats_; + // This callback signals the end of the call, so it relies on the + // initial ref instead of taking a new ref. When it's invoked, the + // initial ref is released. + payload_.recv_trailing_metadata.recv_trailing_metadata_ready = + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, + RecvTrailingMetadataReady, this, + grpc_schedule_on_exec_ctx); + recv_trailing_metadata_batch_.recv_trailing_metadata = true; + // Start recv_trailing_metadata batch. 
+ StartBatch(&recv_trailing_metadata_batch_); +} + +void HealthCheckClient::CallState::StartBatchInCallCombiner( + void* arg, grpc_error_handle /*error*/) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + SubchannelCall* call = + static_cast(batch->handler_private.extra_arg); + call->StartTransportStreamOpBatch(batch); +} + +void HealthCheckClient::CallState::StartBatch( + grpc_transport_stream_op_batch* batch) { + batch->handler_private.extra_arg = call_; + GRPC_CLOSURE_INIT(&batch->handler_private.closure, StartBatchInCallCombiner, + batch, grpc_schedule_on_exec_ctx); + GRPC_CALL_COMBINER_START(&call_combiner_, &batch->handler_private.closure, + GRPC_ERROR_NONE, "start_subchannel_batch"); +} + +void HealthCheckClient::CallState::AfterCallStackDestruction( + void* arg, grpc_error_handle /*error*/) { + HealthCheckClient::CallState* self = + static_cast(arg); + delete self; +} + +void HealthCheckClient::CallState::OnCancelComplete( + void* arg, grpc_error_handle /*error*/) { + HealthCheckClient::CallState* self = + static_cast(arg); + GRPC_CALL_COMBINER_STOP(&self->call_combiner_, "health_cancel"); + self->call_->Unref(DEBUG_LOCATION, "cancel"); +} + +void HealthCheckClient::CallState::StartCancel(void* arg, + grpc_error_handle /*error*/) { + HealthCheckClient::CallState* self = + static_cast(arg); + auto* batch = grpc_make_transport_stream_op( + GRPC_CLOSURE_CREATE(OnCancelComplete, self, grpc_schedule_on_exec_ctx)); + batch->cancel_stream = true; + batch->payload->cancel_stream.cancel_error = GRPC_ERROR_CANCELLED; + self->call_->StartTransportStreamOpBatch(batch); +} + +void HealthCheckClient::CallState::Cancel() { + bool expected = false; + if (cancelled_.compare_exchange_strong(expected, true, + std::memory_order_acq_rel, + std::memory_order_acquire)) { + call_->Ref(DEBUG_LOCATION, "cancel").release(); + GRPC_CALL_COMBINER_START( + &call_combiner_, + GRPC_CLOSURE_CREATE(StartCancel, this, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE, "health_cancel"); + } +} + +void HealthCheckClient::CallState::OnComplete(void* arg, + grpc_error_handle /*error*/) { + HealthCheckClient::CallState* self = + static_cast(arg); + GRPC_CALL_COMBINER_STOP(&self->call_combiner_, "on_complete"); + self->send_initial_metadata_.Clear(); + self->send_trailing_metadata_.Clear(); + self->call_->Unref(DEBUG_LOCATION, "on_complete"); +} + +void HealthCheckClient::CallState::RecvInitialMetadataReady( + void* arg, grpc_error_handle /*error*/) { + HealthCheckClient::CallState* self = + static_cast(arg); + GRPC_CALL_COMBINER_STOP(&self->call_combiner_, "recv_initial_metadata_ready"); + self->recv_initial_metadata_.Clear(); + self->call_->Unref(DEBUG_LOCATION, "recv_initial_metadata_ready"); +} + +void HealthCheckClient::CallState::DoneReadingRecvMessage( + grpc_error_handle error) { + recv_message_.reset(); + if (error != GRPC_ERROR_NONE) { + GRPC_ERROR_UNREF(error); + Cancel(); + grpc_slice_buffer_destroy_internal(&recv_message_buffer_); + call_->Unref(DEBUG_LOCATION, "recv_message_ready"); + return; + } + const bool healthy = DecodeResponse(&recv_message_buffer_, &error); + const grpc_connectivity_state state = + healthy ? GRPC_CHANNEL_READY : GRPC_CHANNEL_TRANSIENT_FAILURE; + health_check_client_->SetHealthStatus( + state, error == GRPC_ERROR_NONE && !healthy + ? "backend unhealthy" + : grpc_error_std_string(error).c_str()); + seen_response_.store(true, std::memory_order_release); + grpc_slice_buffer_destroy_internal(&recv_message_buffer_); + // Start another recv_message batch. 
+ // This re-uses the ref we're holding. + // Note: Can't just reuse batch_ here, since we don't know that all + // callbacks from the original batch have completed yet. + recv_message_batch_.payload = &payload_; + payload_.recv_message.recv_message = &recv_message_; + payload_.recv_message.call_failed_before_recv_message = nullptr; + payload_.recv_message.recv_message_ready = GRPC_CLOSURE_INIT( + &recv_message_ready_, RecvMessageReady, this, grpc_schedule_on_exec_ctx); + recv_message_batch_.recv_message = true; + StartBatch(&recv_message_batch_); +} + +grpc_error_handle HealthCheckClient::CallState::PullSliceFromRecvMessage() { + grpc_slice slice; + grpc_error_handle error = recv_message_->Pull(&slice); + if (error == GRPC_ERROR_NONE) { + grpc_slice_buffer_add(&recv_message_buffer_, slice); + } + return error; +} + +void HealthCheckClient::CallState::ContinueReadingRecvMessage() { + while (recv_message_->Next(SIZE_MAX, &recv_message_ready_)) { + grpc_error_handle error = PullSliceFromRecvMessage(); + if (error != GRPC_ERROR_NONE) { + DoneReadingRecvMessage(error); + return; + } + if (recv_message_buffer_.length == recv_message_->length()) { + DoneReadingRecvMessage(GRPC_ERROR_NONE); + break; + } + } +} + +void HealthCheckClient::CallState::OnByteStreamNext(void* arg, + grpc_error_handle error) { + HealthCheckClient::CallState* self = + static_cast(arg); + if (error != GRPC_ERROR_NONE) { + self->DoneReadingRecvMessage(GRPC_ERROR_REF(error)); + return; + } + error = self->PullSliceFromRecvMessage(); + if (error != GRPC_ERROR_NONE) { + self->DoneReadingRecvMessage(error); + return; + } + if (self->recv_message_buffer_.length == self->recv_message_->length()) { + self->DoneReadingRecvMessage(GRPC_ERROR_NONE); + } else { + self->ContinueReadingRecvMessage(); + } +} + +void HealthCheckClient::CallState::RecvMessageReady( + void* arg, grpc_error_handle /*error*/) { + HealthCheckClient::CallState* self = + static_cast(arg); + GRPC_CALL_COMBINER_STOP(&self->call_combiner_, "recv_message_ready"); + if (self->recv_message_ == nullptr) { + self->call_->Unref(DEBUG_LOCATION, "recv_message_ready"); + return; + } + grpc_slice_buffer_init(&self->recv_message_buffer_); + GRPC_CLOSURE_INIT(&self->recv_message_ready_, OnByteStreamNext, self, + grpc_schedule_on_exec_ctx); + self->ContinueReadingRecvMessage(); + // Ref will continue to be held until we finish draining the byte stream. +} + +void HealthCheckClient::CallState::RecvTrailingMetadataReady( + void* arg, grpc_error_handle error) { + HealthCheckClient::CallState* self = + static_cast(arg); + GRPC_CALL_COMBINER_STOP(&self->call_combiner_, + "recv_trailing_metadata_ready"); + // Get call status. + grpc_status_code status = GRPC_STATUS_UNKNOWN; + if (error != GRPC_ERROR_NONE) { + grpc_error_get_status(error, GRPC_MILLIS_INF_FUTURE, &status, + nullptr /* slice */, nullptr /* http_error */, + nullptr /* error_string */); + } else if (self->recv_trailing_metadata_.legacy_index()->named.grpc_status != + nullptr) { + status = grpc_get_status_code_from_metadata( + self->recv_trailing_metadata_.legacy_index()->named.grpc_status->md); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_health_check_client_trace)) { + gpr_log(GPR_INFO, + "HealthCheckClient %p CallState %p: health watch failed with " + "status %d", + self->health_check_client_.get(), self, status); + } + // Clean up. + self->recv_trailing_metadata_.Clear(); + // For status UNIMPLEMENTED, give up and assume always healthy. 
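[Editorial sketch] The code below implements that rule: UNIMPLEMENTED means the server has no health service, so health checking is disabled and the channel is reported READY, while any other terminal status keeps the retry machinery going. A tiny standalone sketch of the decision, with hypothetical enums in place of grpc_status_code and grpc_connectivity_state:

#include <cstdio>

namespace sketch {

enum class Status { kOk, kUnimplemented, kUnavailable, kUnknown };
enum class State { kReady, kTransientFailure };

struct WatchOutcome {
  State reported_state;
  bool retry;
};

WatchOutcome OnWatchCallEnded(Status status) {
  if (status == Status::kUnimplemented) {
    // Server has no health service: stop watching but treat it as healthy.
    return {State::kReady, /*retry=*/false};
  }
  // Any other terminal status: report unhealthy and retry (immediately if a
  // response was ever seen, otherwise after backoff).
  return {State::kTransientFailure, /*retry=*/true};
}

}  // namespace sketch

int main() {
  auto outcome = sketch::OnWatchCallEnded(sketch::Status::kUnimplemented);
  std::printf("retry=%d\n", outcome.retry ? 1 : 0);  // prints retry=0
  return 0;
}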
+ bool retry = true; + if (status == GRPC_STATUS_UNIMPLEMENTED) { + static const char kErrorMessage[] = + "health checking Watch method returned UNIMPLEMENTED; " + "disabling health checks but assuming server is healthy"; + gpr_log(GPR_ERROR, kErrorMessage); + if (self->health_check_client_->channelz_node_ != nullptr) { + self->health_check_client_->channelz_node_->AddTraceEvent( + channelz::ChannelTrace::Error, + grpc_slice_from_static_string(kErrorMessage)); + } + self->health_check_client_->SetHealthStatus(GRPC_CHANNEL_READY, + kErrorMessage); + retry = false; + } + MutexLock lock(&self->health_check_client_->mu_); + self->CallEndedLocked(retry); +} + +void HealthCheckClient::CallState::CallEndedLocked(bool retry) { + // If this CallState is still in use, this call ended because of a failure, + // so we need to stop using it and optionally create a new one. + // Otherwise, we have deliberately ended this call, and no further action + // is required. + if (this == health_check_client_->call_state_.get()) { + health_check_client_->call_state_.reset(); + if (retry) { + GPR_ASSERT(!health_check_client_->shutting_down_); + if (seen_response_.load(std::memory_order_acquire)) { + // If the call fails after we've gotten a successful response, reset + // the backoff and restart the call immediately. + health_check_client_->retry_backoff_.Reset(); + health_check_client_->StartCallLocked(); + } else { + // If the call failed without receiving any messages, retry later. + health_check_client_->StartRetryTimerLocked(); + } + } + } + // When the last ref to the call stack goes away, the CallState object + // will be automatically destroyed. + call_->Unref(DEBUG_LOCATION, "call_ended"); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/http_connect_handshaker.cc b/src/core/ext/filters/client_channel/http_connect_handshaker.cc new file mode 100644 index 00000000..4140f6ce --- /dev/null +++ b/src/core/ext/filters/client_channel/http_connect_handshaker.cc @@ -0,0 +1,392 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/client_channel/http_connect_handshaker.h" + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/channel/handshaker_registry.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/http/format_request.h" +#include "src/core/lib/http/parser.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/uri/uri_parser.h" + +namespace grpc_core { + +namespace { + +class HttpConnectHandshaker : public Handshaker { + public: + HttpConnectHandshaker(); + void Shutdown(grpc_error_handle why) override; + void DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override; + const char* name() const override { return "http_connect"; } + + private: + ~HttpConnectHandshaker() override; + void CleanupArgsForFailureLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void HandshakeFailedLocked(grpc_error_handle error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + static void OnWriteDone(void* arg, grpc_error_handle error); + static void OnReadDone(void* arg, grpc_error_handle error); + static void OnWriteDoneScheduler(void* arg, grpc_error_handle error); + static void OnReadDoneScheduler(void* arg, grpc_error_handle error); + + Mutex mu_; + + bool is_shutdown_ ABSL_GUARDED_BY(mu_) = false; + // Endpoint and read buffer to destroy after a shutdown. + grpc_endpoint* endpoint_to_destroy_ ABSL_GUARDED_BY(mu_) = nullptr; + grpc_slice_buffer* read_buffer_to_destroy_ ABSL_GUARDED_BY(mu_) = nullptr; + + // State saved while performing the handshake. + HandshakerArgs* args_ = nullptr; + grpc_closure* on_handshake_done_ = nullptr; + + // Objects for processing the HTTP CONNECT request and response. + grpc_slice_buffer write_buffer_ ABSL_GUARDED_BY(mu_); + grpc_closure request_done_closure_ ABSL_GUARDED_BY(mu_); + grpc_closure response_read_closure_ ABSL_GUARDED_BY(mu_); + grpc_http_parser http_parser_ ABSL_GUARDED_BY(mu_); + grpc_http_response http_response_ ABSL_GUARDED_BY(mu_); +}; + +HttpConnectHandshaker::~HttpConnectHandshaker() { + if (endpoint_to_destroy_ != nullptr) { + grpc_endpoint_destroy(endpoint_to_destroy_); + } + if (read_buffer_to_destroy_ != nullptr) { + grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_); + gpr_free(read_buffer_to_destroy_); + } + grpc_slice_buffer_destroy_internal(&write_buffer_); + grpc_http_parser_destroy(&http_parser_); + grpc_http_response_destroy(&http_response_); +} + +// Set args fields to nullptr, saving the endpoint and read buffer for +// later destruction. +void HttpConnectHandshaker::CleanupArgsForFailureLocked() { + endpoint_to_destroy_ = args_->endpoint; + args_->endpoint = nullptr; + read_buffer_to_destroy_ = args_->read_buffer; + args_->read_buffer = nullptr; + grpc_channel_args_destroy(args_->args); + args_->args = nullptr; +} + +// If the handshake failed or we're shutting down, clean up and invoke the +// callback with the error. 
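[Editorial sketch] The function that follows centralizes the failure path: if no error was provided it synthesizes one (covering the shutdown-after-success race), it cleans up and flips is_shutdown_ exactly once, and it then invokes the handshake-done callback. A rough standalone sketch of that shape, with a std::string error and a std::function callback standing in for grpc_error_handle and the gRPC closure:

#include <functional>
#include <string>
#include <utility>

namespace sketch {

class Handshaker {
 public:
  explicit Handshaker(std::function<void(std::string)> on_done)
      : on_done_(std::move(on_done)) {}

  // In the real code this runs with mu_ held; the Locked suffix is kept to
  // mirror that convention.
  void HandshakeFailedLocked(std::string error) {
    if (error.empty()) {
      // Shut down after the endpoint op succeeded but before its callback
      // ran: synthesize an error so the caller still observes a failure.
      error = "Handshaker shutdown";
    }
    if (!is_shutdown_) {
      // First failure: release per-handshake resources and make later
      // Shutdown() calls no-ops.
      CleanupArgsForFailureLocked();
      is_shutdown_ = true;
    }
    on_done_(error);
  }

 private:
  void CleanupArgsForFailureLocked() { /* release endpoint, buffers, args */ }

  bool is_shutdown_ = false;
  std::function<void(std::string)> on_done_;
};

}  // namespace sketch

int main() {
  sketch::Handshaker h([](const std::string&) { /* observe the error */ });
  h.HandshakeFailedLocked("");  // reports "Handshaker shutdown"
  return 0;
}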
+void HttpConnectHandshaker::HandshakeFailedLocked(grpc_error_handle error) { + if (error == GRPC_ERROR_NONE) { + // If we were shut down after an endpoint operation succeeded but + // before the endpoint callback was invoked, we need to generate our + // own error. + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown"); + } + if (!is_shutdown_) { + // TODO(ctiller): It is currently necessary to shutdown endpoints + // before destroying them, even if we know that there are no + // pending read/write callbacks. This should be fixed, at which + // point this can be removed. + grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error)); + // Not shutting down, so the handshake failed. Clean up before + // invoking the callback. + CleanupArgsForFailureLocked(); + // Set shutdown to true so that subsequent calls to + // http_connect_handshaker_shutdown() do nothing. + is_shutdown_ = true; + } + // Invoke callback. + ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, error); +} + +// This callback can be invoked inline while already holding onto the mutex. To +// avoid deadlocks, schedule OnWriteDone on ExecCtx. +void HttpConnectHandshaker::OnWriteDoneScheduler(void* arg, + grpc_error_handle error) { + auto* handshaker = static_cast(arg); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_INIT(&handshaker->request_done_closure_, + &HttpConnectHandshaker::OnWriteDone, handshaker, + grpc_schedule_on_exec_ctx), + GRPC_ERROR_REF(error)); +} + +// Callback invoked when finished writing HTTP CONNECT request. +void HttpConnectHandshaker::OnWriteDone(void* arg, grpc_error_handle error) { + auto* handshaker = static_cast(arg); + ReleasableMutexLock lock(&handshaker->mu_); + if (error != GRPC_ERROR_NONE || handshaker->is_shutdown_) { + // If the write failed or we're shutting down, clean up and invoke the + // callback with the error. + handshaker->HandshakeFailedLocked(GRPC_ERROR_REF(error)); + lock.Release(); + handshaker->Unref(); + } else { + // Otherwise, read the response. + // The read callback inherits our ref to the handshaker. + grpc_endpoint_read( + handshaker->args_->endpoint, handshaker->args_->read_buffer, + GRPC_CLOSURE_INIT(&handshaker->response_read_closure_, + &HttpConnectHandshaker::OnReadDoneScheduler, + handshaker, grpc_schedule_on_exec_ctx), + /*urgent=*/true); + } +} + +// This callback can be invoked inline while already holding onto the mutex. To +// avoid deadlocks, schedule OnReadDone on ExecCtx. +void HttpConnectHandshaker::OnReadDoneScheduler(void* arg, + grpc_error_handle error) { + auto* handshaker = static_cast(arg); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_INIT(&handshaker->response_read_closure_, + &HttpConnectHandshaker::OnReadDone, handshaker, + grpc_schedule_on_exec_ctx), + GRPC_ERROR_REF(error)); +} + +// Callback invoked for reading HTTP CONNECT response. +void HttpConnectHandshaker::OnReadDone(void* arg, grpc_error_handle error) { + auto* handshaker = static_cast(arg); + ReleasableMutexLock lock(&handshaker->mu_); + if (error != GRPC_ERROR_NONE || handshaker->is_shutdown_) { + // If the read failed or we're shutting down, clean up and invoke the + // callback with the error. + handshaker->HandshakeFailedLocked(GRPC_ERROR_REF(error)); + goto done; + } + // Add buffer to parser. 
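[Editorial sketch] The loop below feeds each read slice to grpc_http_parser and, once the parser reaches the body, re-packs the unconsumed bytes so the next protocol layer sees them. A simplified standalone sketch of that incremental parse-and-split step, using a plain "\r\n\r\n" scan as a stand-in for the parser:

#include <cstdio>
#include <string>

namespace sketch {

struct ParseResult {
  bool headers_complete = false;
  std::string leftover;  // bytes after the header block, if any
};

// Accumulates reads until the end of the response headers is seen, then
// splits off whatever trailing bytes belong to the tunneled protocol.
ParseResult FeedConnectResponse(std::string* buffered, const std::string& read) {
  buffered->append(read);
  ParseResult result;
  const std::size_t end = buffered->find("\r\n\r\n");
  if (end == std::string::npos) return result;  // need another read
  result.headers_complete = true;
  result.leftover = buffered->substr(end + 4);
  buffered->resize(end + 4);
  return result;
}

}  // namespace sketch

int main() {
  std::string buffered;
  sketch::FeedConnectResponse(&buffered, "HTTP/1.0 200 Connected\r\n");
  auto r = sketch::FeedConnectResponse(&buffered, "\r\n\x16\x03");  // TLS bytes follow
  std::printf("complete=%d leftover_bytes=%zu\n", r.headers_complete ? 1 : 0,
              r.leftover.size());
  return 0;
}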
+ for (size_t i = 0; i < handshaker->args_->read_buffer->count; ++i) { + if (GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i]) > 0) { + size_t body_start_offset = 0; + error = grpc_http_parser_parse(&handshaker->http_parser_, + handshaker->args_->read_buffer->slices[i], + &body_start_offset); + if (error != GRPC_ERROR_NONE) { + handshaker->HandshakeFailedLocked(error); + goto done; + } + if (handshaker->http_parser_.state == GRPC_HTTP_BODY) { + // Remove the data we've already read from the read buffer, + // leaving only the leftover bytes (if any). + grpc_slice_buffer tmp_buffer; + grpc_slice_buffer_init(&tmp_buffer); + if (body_start_offset < + GRPC_SLICE_LENGTH(handshaker->args_->read_buffer->slices[i])) { + grpc_slice_buffer_add( + &tmp_buffer, + grpc_slice_split_tail(&handshaker->args_->read_buffer->slices[i], + body_start_offset)); + } + grpc_slice_buffer_addn(&tmp_buffer, + &handshaker->args_->read_buffer->slices[i + 1], + handshaker->args_->read_buffer->count - i - 1); + grpc_slice_buffer_swap(handshaker->args_->read_buffer, &tmp_buffer); + grpc_slice_buffer_destroy_internal(&tmp_buffer); + break; + } + } + } + // If we're not done reading the response, read more data. + // TODO(roth): In practice, I suspect that the response to a CONNECT + // request will never include a body, in which case this check is + // sufficient. However, the language of RFC-2817 doesn't explicitly + // forbid the response from including a body. If there is a body, + // it's possible that we might have parsed part but not all of the + // body, in which case this check will cause us to fail to parse the + // remainder of the body. If that ever becomes an issue, we may + // need to fix the HTTP parser to understand when the body is + // complete (e.g., handling chunked transfer encoding or looking + // at the Content-Length: header). + if (handshaker->http_parser_.state != GRPC_HTTP_BODY) { + grpc_slice_buffer_reset_and_unref_internal(handshaker->args_->read_buffer); + grpc_endpoint_read( + handshaker->args_->endpoint, handshaker->args_->read_buffer, + GRPC_CLOSURE_INIT(&handshaker->response_read_closure_, + &HttpConnectHandshaker::OnReadDoneScheduler, + handshaker, grpc_schedule_on_exec_ctx), + /*urgent=*/true); + return; + } + // Make sure we got a 2xx response. + if (handshaker->http_response_.status < 200 || + handshaker->http_response_.status >= 300) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("HTTP proxy returned response code ", + handshaker->http_response_.status)); + handshaker->HandshakeFailedLocked(error); + goto done; + } + // Success. Invoke handshake-done callback. + ExecCtx::Run(DEBUG_LOCATION, handshaker->on_handshake_done_, error); +done: + // Set shutdown to true so that subsequent calls to + // http_connect_handshaker_shutdown() do nothing. + handshaker->is_shutdown_ = true; + lock.Release(); + handshaker->Unref(); +} + +// +// Public handshaker methods +// + +void HttpConnectHandshaker::Shutdown(grpc_error_handle why) { + { + MutexLock lock(&mu_); + if (!is_shutdown_) { + is_shutdown_ = true; + grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why)); + CleanupArgsForFailureLocked(); + } + } + GRPC_ERROR_UNREF(why); +} + +void HttpConnectHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, + grpc_closure* on_handshake_done, + HandshakerArgs* args) { + // Check for HTTP CONNECT channel arg. + // If not found, invoke on_handshake_done without doing anything. 
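[Editorial sketch] Besides the server arg checked here, the handshaker also reads GRPC_ARG_HTTP_CONNECT_HEADERS a bit further down: a newline-separated list of "Name:Value" entries, with malformed entries skipped. A standalone sketch of that parsing, assuming nothing beyond the standard library:

#include <cstdio>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace sketch {

std::vector<std::pair<std::string, std::string>> ParseConnectHeaders(
    const std::string& arg) {
  std::vector<std::pair<std::string, std::string>> headers;
  std::istringstream lines(arg);
  std::string line;
  while (std::getline(lines, line)) {
    const std::size_t sep = line.find(':');
    if (sep == std::string::npos) {
      // Mirrors the "skipping unparseable HTTP CONNECT header" warning.
      std::fprintf(stderr, "skipping unparseable HTTP CONNECT header: %s\n",
                   line.c_str());
      continue;
    }
    headers.emplace_back(line.substr(0, sep), line.substr(sep + 1));
  }
  return headers;
}

}  // namespace sketch

int main() {
  auto headers = sketch::ParseConnectHeaders(
      "Proxy-Authorization:Basic dXNlcjpwYXNz\nX-Debug:1");
  for (const auto& h : headers) {
    std::printf("%s => %s\n", h.first.c_str(), h.second.c_str());
  }
  return 0;
}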
+ const grpc_arg* arg = + grpc_channel_args_find(args->args, GRPC_ARG_HTTP_CONNECT_SERVER); + char* server_name = grpc_channel_arg_get_string(arg); + if (server_name == nullptr) { + // Set shutdown to true so that subsequent calls to + // http_connect_handshaker_shutdown() do nothing. + { + MutexLock lock(&mu_); + is_shutdown_ = true; + } + ExecCtx::Run(DEBUG_LOCATION, on_handshake_done, GRPC_ERROR_NONE); + return; + } + // Get headers from channel args. + arg = grpc_channel_args_find(args->args, GRPC_ARG_HTTP_CONNECT_HEADERS); + char* arg_header_string = grpc_channel_arg_get_string(arg); + grpc_http_header* headers = nullptr; + size_t num_headers = 0; + char** header_strings = nullptr; + size_t num_header_strings = 0; + if (arg_header_string != nullptr) { + gpr_string_split(arg_header_string, "\n", &header_strings, + &num_header_strings); + headers = static_cast( + gpr_malloc(sizeof(grpc_http_header) * num_header_strings)); + for (size_t i = 0; i < num_header_strings; ++i) { + char* sep = strchr(header_strings[i], ':'); + + if (sep == nullptr) { + gpr_log(GPR_ERROR, "skipping unparseable HTTP CONNECT header: %s", + header_strings[i]); + continue; + } + *sep = '\0'; + headers[num_headers].key = header_strings[i]; + headers[num_headers].value = sep + 1; + ++num_headers; + } + } + // Save state in the handshaker object. + MutexLock lock(&mu_); + args_ = args; + on_handshake_done_ = on_handshake_done; + // Log connection via proxy. + std::string proxy_name(grpc_endpoint_get_peer(args->endpoint)); + gpr_log(GPR_INFO, "Connecting to server %s via HTTP proxy %s", server_name, + proxy_name.c_str()); + // Construct HTTP CONNECT request. + grpc_httpcli_request request; + request.host = server_name; + request.ssl_host_override = nullptr; + request.http.method = const_cast("CONNECT"); + request.http.path = server_name; + request.http.version = GRPC_HTTP_HTTP10; // Set by OnReadDone + request.http.hdrs = headers; + request.http.hdr_count = num_headers; + request.http.body_length = 0; + request.http.body = nullptr; + request.handshaker = &grpc_httpcli_plaintext; + grpc_slice request_slice = grpc_httpcli_format_connect_request(&request); + grpc_slice_buffer_add(&write_buffer_, request_slice); + // Clean up. + gpr_free(headers); + for (size_t i = 0; i < num_header_strings; ++i) { + gpr_free(header_strings[i]); + } + gpr_free(header_strings); + // Take a new ref to be held by the write callback. 
+ Ref().release(); + grpc_endpoint_write( + args->endpoint, &write_buffer_, + GRPC_CLOSURE_INIT(&request_done_closure_, + &HttpConnectHandshaker::OnWriteDoneScheduler, this, + grpc_schedule_on_exec_ctx), + nullptr); +} + +HttpConnectHandshaker::HttpConnectHandshaker() { + grpc_slice_buffer_init(&write_buffer_); + grpc_http_parser_init(&http_parser_, GRPC_HTTP_RESPONSE, &http_response_); +} + +// +// handshaker factory +// + +class HttpConnectHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* /*args*/, + grpc_pollset_set* /*interested_parties*/, + HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(MakeRefCounted()); + } + ~HttpConnectHandshakerFactory() override = default; +}; + +} // namespace + +void RegisterHttpConnectHandshaker(CoreConfiguration::Builder* builder) { + builder->handshaker_registry()->RegisterHandshakerFactory( + true /* at_start */, grpc_core::HANDSHAKER_CLIENT, + absl::make_unique()); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/http_proxy.cc b/src/core/ext/filters/client_channel/http_proxy.cc new file mode 100644 index 00000000..4bc7b1b1 --- /dev/null +++ b/src/core/ext/filters/client_channel/http_proxy.cc @@ -0,0 +1,234 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/http_proxy.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/http_connect_handshaker.h" +#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/slice/b64.h" +#include "src/core/lib/uri/uri_parser.h" + +namespace grpc_core { +namespace { + +/** + * Parses the 'https_proxy' env var (fallback on 'http_proxy') and returns the + * proxy hostname to resolve or nullptr on error. Also sets 'user_cred' to user + * credentials if present in the 'http_proxy' env var, otherwise leaves it + * unchanged. It is caller's responsibility to gpr_free user_cred. + */ +// TODO(hork): change this to return std::string +char* GetHttpProxyServer(const grpc_channel_args* args, char** user_cred) { + GPR_ASSERT(user_cred != nullptr); + absl::StatusOr uri; + char* proxy_name = nullptr; + char** authority_strs = nullptr; + size_t authority_nstrs; + /* We check the following places to determine the HTTP proxy to use, stopping + * at the first one that is set: + * 1. GRPC_ARG_HTTP_PROXY channel arg + * 2. grpc_proxy environment variable + * 3. https_proxy environment variable + * 4. http_proxy environment variable + * If none of the above are set, then no HTTP proxy will be used. 
+ */ + char* uri_str = + gpr_strdup(grpc_channel_args_find_string(args, GRPC_ARG_HTTP_PROXY)); + if (uri_str == nullptr) uri_str = gpr_getenv("grpc_proxy"); + if (uri_str == nullptr) uri_str = gpr_getenv("https_proxy"); + if (uri_str == nullptr) uri_str = gpr_getenv("http_proxy"); + if (uri_str == nullptr) return nullptr; + // an emtpy value means "don't use proxy" + if (uri_str[0] == '\0') goto done; + uri = URI::Parse(uri_str); + if (!uri.ok() || uri->authority().empty()) { + gpr_log(GPR_ERROR, "cannot parse value of 'http_proxy' env var. Error: %s", + uri.status().ToString().c_str()); + goto done; + } + if (uri->scheme() != "http") { + gpr_log(GPR_ERROR, "'%s' scheme not supported in proxy URI", + uri->scheme().c_str()); + goto done; + } + /* Split on '@' to separate user credentials from host */ + gpr_string_split(uri->authority().c_str(), "@", &authority_strs, + &authority_nstrs); + GPR_ASSERT(authority_nstrs != 0); /* should have at least 1 string */ + if (authority_nstrs == 1) { + /* User cred not present in authority */ + proxy_name = authority_strs[0]; + } else if (authority_nstrs == 2) { + /* User cred found */ + *user_cred = authority_strs[0]; + proxy_name = authority_strs[1]; + gpr_log(GPR_DEBUG, "userinfo found in proxy URI"); + } else { + /* Bad authority */ + for (size_t i = 0; i < authority_nstrs; i++) { + gpr_free(authority_strs[i]); + } + proxy_name = nullptr; + } + gpr_free(authority_strs); +done: + gpr_free(uri_str); + return proxy_name; +} + +// Adds the default port if target does not contain a port. +std::string MaybeAddDefaultPort(absl::string_view target) { + absl::string_view host; + absl::string_view port; + SplitHostPort(target, &host, &port); + if (port.empty()) { + return JoinHostPort(host, kDefaultSecurePortInt); + } + return std::string(target); +} + +class HttpProxyMapper : public ProxyMapperInterface { + public: + bool MapName(const char* server_uri, const grpc_channel_args* args, + char** name_to_resolve, grpc_channel_args** new_args) override { + if (!grpc_channel_args_find_bool(args, GRPC_ARG_ENABLE_HTTP_PROXY, true)) { + return false; + } + char* user_cred = nullptr; + *name_to_resolve = GetHttpProxyServer(args, &user_cred); + if (*name_to_resolve == nullptr) return false; + char* no_proxy_str = nullptr; + std::string server_target; + absl::StatusOr uri = URI::Parse(server_uri); + if (!uri.ok() || uri->path().empty()) { + gpr_log(GPR_ERROR, + "'http_proxy' environment variable set, but cannot " + "parse server URI '%s' -- not using proxy. Error: %s", + server_uri, uri.status().ToString().c_str()); + goto no_use_proxy; + } + if (uri->scheme() == "unix") { + gpr_log(GPR_INFO, "not using proxy for Unix domain socket '%s'", + server_uri); + goto no_use_proxy; + } + /* Prefer using 'no_grpc_proxy'. Fallback on 'no_proxy' if it is not set. 
*/ + no_proxy_str = gpr_getenv("no_grpc_proxy"); + if (no_proxy_str == nullptr) no_proxy_str = gpr_getenv("no_proxy"); + if (no_proxy_str != nullptr) { + static const char* NO_PROXY_SEPARATOR = ","; + bool use_proxy = true; + std::string server_host; + std::string server_port; + if (!SplitHostPort(absl::StripPrefix(uri->path(), "/"), &server_host, + &server_port)) { + gpr_log(GPR_INFO, + "unable to split host and port, not checking no_proxy list for " + "host '%s'", + server_uri); + gpr_free(no_proxy_str); + } else { + size_t uri_len = server_host.size(); + char** no_proxy_hosts; + size_t num_no_proxy_hosts; + gpr_string_split(no_proxy_str, NO_PROXY_SEPARATOR, &no_proxy_hosts, + &num_no_proxy_hosts); + for (size_t i = 0; i < num_no_proxy_hosts; i++) { + char* no_proxy_entry = no_proxy_hosts[i]; + size_t no_proxy_len = strlen(no_proxy_entry); + if (no_proxy_len <= uri_len && + gpr_stricmp(no_proxy_entry, + &(server_host.c_str()[uri_len - no_proxy_len])) == + 0) { + gpr_log(GPR_INFO, "not using proxy for host in no_proxy list '%s'", + server_uri); + use_proxy = false; + break; + } + } + for (size_t i = 0; i < num_no_proxy_hosts; i++) { + gpr_free(no_proxy_hosts[i]); + } + gpr_free(no_proxy_hosts); + gpr_free(no_proxy_str); + if (!use_proxy) goto no_use_proxy; + } + } + server_target = + MaybeAddDefaultPort(absl::StripPrefix(uri->path(), "/")).c_str(); + grpc_arg args_to_add[2]; + args_to_add[0] = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_HTTP_CONNECT_SERVER), + const_cast(server_target.c_str())); + if (user_cred != nullptr) { + /* Use base64 encoding for user credentials as stated in RFC 7617 */ + char* encoded_user_cred = + grpc_base64_encode(user_cred, strlen(user_cred), 0, 0); + std::string header = + absl::StrCat("Proxy-Authorization:Basic ", encoded_user_cred); + gpr_free(encoded_user_cred); + args_to_add[1] = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_HTTP_CONNECT_HEADERS), + const_cast(header.c_str())); + *new_args = grpc_channel_args_copy_and_add(args, args_to_add, 2); + } else { + *new_args = grpc_channel_args_copy_and_add(args, args_to_add, 1); + } + gpr_free(user_cred); + return true; + no_use_proxy: + gpr_free(*name_to_resolve); + *name_to_resolve = nullptr; + gpr_free(user_cred); + return false; + } + + bool MapAddress(const grpc_resolved_address& /*address*/, + const grpc_channel_args* /*args*/, + grpc_resolved_address** /*new_address*/, + grpc_channel_args** /*new_args*/) override { + return false; + } +}; + +} // namespace + +void RegisterHttpProxyMapper() { + ProxyMapperRegistry::Register( + true /* at_start */, + std::unique_ptr(new HttpProxyMapper())); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy.cc b/src/core/ext/filters/client_channel/lb_policy.cc new file mode 100644 index 00000000..44363b4a --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy.cc @@ -0,0 +1,131 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
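[Editorial sketch] Returning to HttpProxyMapper::MapName() above: the proxy is bypassed whenever any comma-separated no_proxy entry is a case-insensitive suffix of the server host. A standalone sketch of that check, with hypothetical helper names:

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <sstream>
#include <string>

namespace sketch {

bool EqualsIgnoreCase(const std::string& a, const std::string& b) {
  return a.size() == b.size() &&
         std::equal(a.begin(), a.end(), b.begin(), [](char x, char y) {
           return std::tolower(static_cast<unsigned char>(x)) ==
                  std::tolower(static_cast<unsigned char>(y));
         });
}

// Returns true if server_host matches any entry in the comma-separated
// no_proxy list (suffix match, case-insensitive), as in the loop above.
bool BypassProxy(const std::string& server_host, const std::string& no_proxy) {
  std::istringstream entries(no_proxy);
  std::string entry;
  while (std::getline(entries, entry, ',')) {
    if (entry.empty() || entry.size() > server_host.size()) continue;
    const std::string tail =
        server_host.substr(server_host.size() - entry.size());
    if (EqualsIgnoreCase(entry, tail)) return true;
  }
  return false;
}

}  // namespace sketch

int main() {
  std::printf("%d\n", sketch::BypassProxy("api.internal.example.com",
                                          "localhost,.Example.com"));  // 1
  return 0;
}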
+ * + */ + +#include + +#include "src/core/ext/filters/client_channel/lb_policy.h" + +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/lib/iomgr/combiner.h" + +namespace grpc_core { + +DebugOnlyTraceFlag grpc_trace_lb_policy_refcount(false, "lb_policy_refcount"); + +// +// LoadBalancingPolicy +// + +LoadBalancingPolicy::LoadBalancingPolicy(Args args, intptr_t initial_refcount) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_trace_lb_policy_refcount) + ? "LoadBalancingPolicy" + : nullptr, + initial_refcount), + work_serializer_(std::move(args.work_serializer)), + interested_parties_(grpc_pollset_set_create()), + channel_control_helper_(std::move(args.channel_control_helper)) {} + +LoadBalancingPolicy::~LoadBalancingPolicy() { + grpc_pollset_set_destroy(interested_parties_); +} + +void LoadBalancingPolicy::Orphan() { + ShutdownLocked(); + Unref(DEBUG_LOCATION, "Orphan"); +} + +// +// LoadBalancingPolicy::UpdateArgs +// + +LoadBalancingPolicy::UpdateArgs::UpdateArgs(const UpdateArgs& other) { + addresses = other.addresses; + config = other.config; + args = grpc_channel_args_copy(other.args); +} + +LoadBalancingPolicy::UpdateArgs::UpdateArgs(UpdateArgs&& other) noexcept { + addresses = std::move(other.addresses); + config = std::move(other.config); + // TODO(roth): Use std::move() once channel args is converted to C++. + args = other.args; + other.args = nullptr; +} + +LoadBalancingPolicy::UpdateArgs& LoadBalancingPolicy::UpdateArgs::operator=( + const UpdateArgs& other) { + if (&other == this) { + return *this; + } + addresses = other.addresses; + config = other.config; + grpc_channel_args_destroy(args); + args = grpc_channel_args_copy(other.args); + return *this; +} + +LoadBalancingPolicy::UpdateArgs& LoadBalancingPolicy::UpdateArgs::operator=( + UpdateArgs&& other) noexcept { + addresses = std::move(other.addresses); + config = std::move(other.config); + // TODO(roth): Use std::move() once channel args is converted to C++. + grpc_channel_args_destroy(args); + args = other.args; + other.args = nullptr; + return *this; +} + +// +// LoadBalancingPolicy::QueuePicker +// + +LoadBalancingPolicy::PickResult LoadBalancingPolicy::QueuePicker::Pick( + PickArgs /*args*/) { + // We invoke the parent's ExitIdleLocked() via a closure instead + // of doing it directly here, for two reasons: + // 1. ExitIdleLocked() may cause the policy's state to change and + // a new picker to be delivered to the channel. If that new + // picker is delivered before ExitIdleLocked() returns, then by + // the time this function returns, the pick will already have + // been processed, and we'll be trying to re-process the same + // pick again, leading to a crash. + // 2. We are currently running in the data plane mutex, but we + // need to bounce into the control plane work_serializer to call + // ExitIdleLocked(). + if (!exit_idle_called_ && parent_ != nullptr) { + exit_idle_called_ = true; + auto* parent = parent_->Ref().release(); // ref held by lambda. 
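[Editorial sketch] The ExecCtx::Run/WorkSerializer dance that follows exists so ExitIdleLocked() never runs inline under the data-plane mutex; the closure owns a ref to the policy and drops it only after re-entering the policy on its own serializer. A minimal standalone sketch of that hop, with std::shared_ptr and a plain task queue standing in for the gRPC ref counting, ExecCtx, and WorkSerializer:

#include <cstdio>
#include <functional>
#include <memory>
#include <queue>

namespace sketch {

struct Policy {
  void ExitIdle() { std::printf("ExitIdle on policy %p\n", (void*)this); }
};

using Task = std::function<void()>;

// Queue the ExitIdle work instead of calling it inline, so the current pick
// finishes before the policy can publish a new picker.
void QueueExitIdle(std::shared_ptr<Policy> policy, std::queue<Task>* serializer) {
  serializer->push([policy] { policy->ExitIdle(); });  // ref held by lambda
}

}  // namespace sketch

int main() {
  auto policy = std::make_shared<sketch::Policy>();
  std::queue<sketch::Task> serializer;
  sketch::QueueExitIdle(policy, &serializer);
  // Later, the serializer drains its queue on the control-plane side.
  while (!serializer.empty()) {
    serializer.front()();
    serializer.pop();
  }
  return 0;
}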
+ ExecCtx::Run(DEBUG_LOCATION, + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle /*error*/) { + auto* parent = static_cast(arg); + parent->work_serializer()->Run( + [parent]() { + parent->ExitIdleLocked(); + parent->Unref(); + }, + DEBUG_LOCATION); + }, + parent, nullptr), + GRPC_ERROR_NONE); + } + return PickResult::Queue(); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/address_filtering.cc b/src/core/ext/filters/client_channel/lb_policy/address_filtering.cc new file mode 100644 index 00000000..c6078e0d --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/address_filtering.cc @@ -0,0 +1,96 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/address_filtering.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +#include "src/core/lib/channel/channel_args.h" + +#define GRPC_ARG_HIERARCHICAL_PATH "grpc.internal.address.hierarchical_path" + +namespace grpc_core { + +const char* kHierarchicalPathAttributeKey = "hierarchical_path"; + +namespace { + +class HierarchicalPathAttribute : public ServerAddress::AttributeInterface { + public: + explicit HierarchicalPathAttribute(std::vector path) + : path_(std::move(path)) {} + + std::unique_ptr Copy() const override { + return absl::make_unique(path_); + } + + int Cmp(const AttributeInterface* other) const override { + const std::vector& other_path = + static_cast(other)->path_; + for (size_t i = 0; i < path_.size(); ++i) { + if (other_path.size() == i) return 1; + int r = path_[i].compare(other_path[i]); + if (r != 0) return r; + } + if (other_path.size() > path_.size()) return -1; + return 0; + } + + std::string ToString() const override { + return absl::StrCat("[", absl::StrJoin(path_, ", "), "]"); + } + + const std::vector& path() const { return path_; } + + private: + std::vector path_; +}; + +} // namespace + +std::unique_ptr +MakeHierarchicalPathAttribute(std::vector path) { + return absl::make_unique(std::move(path)); +} + +HierarchicalAddressMap MakeHierarchicalAddressMap( + const ServerAddressList& addresses) { + HierarchicalAddressMap result; + for (const ServerAddress& address : addresses) { + const HierarchicalPathAttribute* path_attribute = + static_cast( + address.GetAttribute(kHierarchicalPathAttributeKey)); + if (path_attribute == nullptr) continue; + const std::vector& path = path_attribute->path(); + auto it = path.begin(); + ServerAddressList& target_list = result[*it]; + std::unique_ptr new_attribute; + ++it; + if (it != path.end()) { + std::vector remaining_path(it, path.end()); + new_attribute = absl::make_unique( + std::move(remaining_path)); + } + target_list.emplace_back(address.WithAttribute( + kHierarchicalPathAttributeKey, std::move(new_attribute))); + } + return result; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc 
b/src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc new file mode 100644 index 00000000..1d0b1ef2 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc @@ -0,0 +1,304 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" + +#include + +#include "absl/strings/str_cat.h" + +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" + +namespace grpc_core { + +// +// ChildPolicyHandler::Helper +// + +class ChildPolicyHandler::Helper + : public LoadBalancingPolicy::ChannelControlHelper { + public: + explicit Helper(RefCountedPtr parent) + : parent_(std::move(parent)) {} + + ~Helper() override { parent_.reset(DEBUG_LOCATION, "Helper"); } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override { + if (parent_->shutting_down_) return nullptr; + if (!CalledByCurrentChild() && !CalledByPendingChild()) return nullptr; + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); + } + + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override { + if (parent_->shutting_down_) return; + // If this request is from the pending child policy, ignore it until + // it reports something other than CONNECTING, at which point we swap it + // into place. + if (CalledByPendingChild()) { + if (GRPC_TRACE_FLAG_ENABLED(*(parent_->tracer_))) { + gpr_log(GPR_INFO, + "[child_policy_handler %p] helper %p: pending child policy %p " + "reports state=%s (%s)", + parent_.get(), this, child_, ConnectivityStateName(state), + status.ToString().c_str()); + } + if (state == GRPC_CHANNEL_CONNECTING) return; + grpc_pollset_set_del_pollset_set( + parent_->child_policy_->interested_parties(), + parent_->interested_parties()); + parent_->child_policy_ = std::move(parent_->pending_child_policy_); + } else if (!CalledByCurrentChild()) { + // This request is from an outdated child, so ignore it. + return; + } + parent_->channel_control_helper()->UpdateState(state, status, + std::move(picker)); + } + + void RequestReresolution() override { + if (parent_->shutting_down_) return; + // Only forward re-resolution requests from the most recent child, + // since that's the one that will be receiving any update we receive + // from the resolver. + const LoadBalancingPolicy* latest_child_policy = + parent_->pending_child_policy_ != nullptr + ? 
parent_->pending_child_policy_.get() + : parent_->child_policy_.get(); + if (child_ != latest_child_policy) return; + if (GRPC_TRACE_FLAG_ENABLED(*(parent_->tracer_))) { + gpr_log(GPR_INFO, "[child_policy_handler %p] started name re-resolving", + parent_.get()); + } + parent_->channel_control_helper()->RequestReresolution(); + } + + absl::string_view GetAuthority() override { + return parent_->channel_control_helper()->GetAuthority(); + } + + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override { + if (parent_->shutting_down_) return; + if (!CalledByPendingChild() && !CalledByCurrentChild()) return; + parent_->channel_control_helper()->AddTraceEvent(severity, message); + } + + void set_child(LoadBalancingPolicy* child) { child_ = child; } + + private: + bool CalledByPendingChild() const { + GPR_ASSERT(child_ != nullptr); + return child_ == parent_->pending_child_policy_.get(); + } + + bool CalledByCurrentChild() const { + GPR_ASSERT(child_ != nullptr); + return child_ == parent_->child_policy_.get(); + }; + + RefCountedPtr parent_; + LoadBalancingPolicy* child_ = nullptr; +}; + +// +// ChildPolicyHandler +// + +void ChildPolicyHandler::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { + gpr_log(GPR_INFO, "[child_policy_handler %p] shutting down", this); + } + shutting_down_ = true; + if (child_policy_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { + gpr_log(GPR_INFO, "[child_policy_handler %p] shutting down lb_policy %p", + this, child_policy_.get()); + } + grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(), + interested_parties()); + child_policy_.reset(); + } + if (pending_child_policy_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { + gpr_log(GPR_INFO, + "[child_policy_handler %p] shutting down pending lb_policy %p", + this, pending_child_policy_.get()); + } + grpc_pollset_set_del_pollset_set( + pending_child_policy_->interested_parties(), interested_parties()); + pending_child_policy_.reset(); + } +} + +void ChildPolicyHandler::UpdateLocked(UpdateArgs args) { + // If the child policy name changes, we need to create a new child + // policy. When this happens, we leave child_policy_ as-is and store + // the new child policy in pending_child_policy_. Once the new child + // policy transitions into state READY, we swap it into child_policy_, + // replacing the original child policy. So pending_child_policy_ is + // non-null only between when we apply an update that changes the child + // policy name and when the new child reports state READY. + // + // Updates can arrive at any point during this transition. We always + // apply updates relative to the most recently created child policy, + // even if the most recent one is still in pending_child_policy_. This + // is true both when applying the updates to an existing child policy + // and when determining whether we need to create a new policy. + // + // As a result of this, there are several cases to consider here: + // + // 1. We have no existing child policy (i.e., this is the first update + // we receive after being created; in this case, both child_policy_ + // and pending_child_policy_ are null). In this case, we create a + // new child policy and store it in child_policy_. + // + // 2. 
We have an existing child policy and have no pending child policy + // from a previous update (i.e., either there has not been a + // previous update that changed the policy name, or we have already + // finished swapping in the new policy; in this case, child_policy_ + // is non-null but pending_child_policy_ is null). In this case: + // a. If going from the current config to the new config does not + // require a new policy, then we update the existing child policy. + // b. If going from the current config to the new config does require a + // new policy, we create a new policy. The policy will be stored in + // pending_child_policy_ and will later be swapped into + // child_policy_ by the helper when the new child transitions + // into state READY. + // + // 3. We have an existing child policy and have a pending child policy + // from a previous update (i.e., a previous update set + // pending_child_policy_ as per case 2b above and that policy has + // not yet transitioned into state READY and been swapped into + // child_policy_; in this case, both child_policy_ and + // pending_child_policy_ are non-null). In this case: + // a. If going from the current config to the new config does not + // require a new policy, then we update the existing pending + // child policy. + // b. If going from the current config to the new config does require a + // new child policy, then we create a new policy. The new + // policy is stored in pending_child_policy_ (replacing the one + // that was there before, which will be immediately shut down) + // and will later be swapped into child_policy_ by the helper + // when the new child transitions into state READY. + const bool create_policy = + // case 1 + child_policy_ == nullptr || + // cases 2b and 3b + ConfigChangeRequiresNewPolicyInstance(current_config_.get(), + args.config.get()); + current_config_ = args.config; + LoadBalancingPolicy* policy_to_update = nullptr; + if (create_policy) { + // Cases 1, 2b, and 3b: create a new child policy. + // If child_policy_ is null, we set it (case 1), else we set + // pending_child_policy_ (cases 2b and 3b). + // TODO(roth): In cases 2b and 3b, we should start a timer here, so + // that there's an upper bound on the amount of time it takes us to + // switch to the new policy, even if the new policy stays in + // CONNECTING for a very long period of time. + if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { + gpr_log(GPR_INFO, + "[child_policy_handler %p] creating new %schild policy %s", this, + child_policy_ == nullptr ? "" : "pending ", args.config->name()); + } + auto& lb_policy = + child_policy_ == nullptr ? child_policy_ : pending_child_policy_; + lb_policy = CreateChildPolicy(args.config->name(), *args.args); + policy_to_update = lb_policy.get(); + } else { + // Cases 2a and 3a: update an existing policy. + // If we have a pending child policy, send the update to the pending + // policy (case 3a), else send it to the current policy (case 2a). + policy_to_update = pending_child_policy_ != nullptr + ? pending_child_policy_.get() + : child_policy_.get(); + } + GPR_ASSERT(policy_to_update != nullptr); + // Update the policy. + if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { + gpr_log(GPR_INFO, "[child_policy_handler %p] updating %schild policy %p", + this, + policy_to_update == pending_child_policy_.get() ? 
"pending " : "", + policy_to_update); + } + policy_to_update->UpdateLocked(std::move(args)); +} + +void ChildPolicyHandler::ExitIdleLocked() { + if (child_policy_ != nullptr) { + child_policy_->ExitIdleLocked(); + if (pending_child_policy_ != nullptr) { + pending_child_policy_->ExitIdleLocked(); + } + } +} + +void ChildPolicyHandler::ResetBackoffLocked() { + if (child_policy_ != nullptr) { + child_policy_->ResetBackoffLocked(); + if (pending_child_policy_ != nullptr) { + pending_child_policy_->ResetBackoffLocked(); + } + } +} + +OrphanablePtr ChildPolicyHandler::CreateChildPolicy( + const char* child_policy_name, const grpc_channel_args& args) { + Helper* helper = new Helper(Ref(DEBUG_LOCATION, "Helper")); + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = work_serializer(); + lb_policy_args.channel_control_helper = + std::unique_ptr(helper); + lb_policy_args.args = &args; + OrphanablePtr lb_policy = + CreateLoadBalancingPolicy(child_policy_name, std::move(lb_policy_args)); + if (GPR_UNLIKELY(lb_policy == nullptr)) { + gpr_log(GPR_ERROR, "could not create LB policy \"%s\"", child_policy_name); + return nullptr; + } + helper->set_child(lb_policy.get()); + if (GRPC_TRACE_FLAG_ENABLED(*tracer_)) { + gpr_log(GPR_INFO, + "[child_policy_handler %p] created new LB policy \"%s\" (%p)", this, + child_policy_name, lb_policy.get()); + } + channel_control_helper()->AddTraceEvent( + ChannelControlHelper::TRACE_INFO, + absl::StrCat("Created new LB policy \"", child_policy_name, "\"")); + grpc_pollset_set_add_pollset_set(lb_policy->interested_parties(), + interested_parties()); + return lb_policy; +} + +bool ChildPolicyHandler::ConfigChangeRequiresNewPolicyInstance( + LoadBalancingPolicy::Config* old_config, + LoadBalancingPolicy::Config* new_config) const { + return strcmp(old_config->name(), new_config->name()) != 0; +} + +OrphanablePtr +ChildPolicyHandler::CreateLoadBalancingPolicy( + const char* name, LoadBalancingPolicy::Args args) const { + return LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( + name, std::move(args)); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.cc new file mode 100644 index 00000000..a3cf12ac --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.cc @@ -0,0 +1,148 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.h" + +#include + +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/profiling/timers.h" + +static grpc_error_handle clr_init_channel_elem( + grpc_channel_element* /*elem*/, grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void clr_destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +namespace { + +struct call_data { + // Stats object to update. + grpc_core::RefCountedPtr client_stats; + // State for intercepting send_initial_metadata. + grpc_closure on_complete_for_send; + grpc_closure* original_on_complete_for_send; + bool send_initial_metadata_succeeded = false; + // State for intercepting recv_initial_metadata. + grpc_closure recv_initial_metadata_ready; + grpc_closure* original_recv_initial_metadata_ready; + bool recv_initial_metadata_succeeded = false; +}; + +} // namespace + +static void on_complete_for_send(void* arg, grpc_error_handle error) { + call_data* calld = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + calld->send_initial_metadata_succeeded = true; + } + grpc_core::Closure::Run(DEBUG_LOCATION, calld->original_on_complete_for_send, + GRPC_ERROR_REF(error)); +} + +static void recv_initial_metadata_ready(void* arg, grpc_error_handle error) { + call_data* calld = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + calld->recv_initial_metadata_succeeded = true; + } + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_initial_metadata_ready, + GRPC_ERROR_REF(error)); +} + +static grpc_error_handle clr_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + GPR_ASSERT(args->context != nullptr); + new (elem->call_data) call_data(); + return GRPC_ERROR_NONE; +} + +static void clr_destroy_call_elem(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + call_data* calld = static_cast(elem->call_data); + if (calld->client_stats != nullptr) { + // Record call finished, optionally setting client_failed_to_send and + // received. + calld->client_stats->AddCallFinished( + !calld->send_initial_metadata_succeeded /* client_failed_to_send */, + calld->recv_initial_metadata_succeeded /* known_received */); + } + calld->~call_data(); +} + +static void clr_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + call_data* calld = static_cast(elem->call_data); + GPR_TIMER_SCOPE("clr_start_transport_stream_op_batch", 0); + // Handle send_initial_metadata. + if (batch->send_initial_metadata) { + // Grab client stats object from metadata. + auto client_stats_md = + batch->payload->send_initial_metadata.send_initial_metadata->Remove( + grpc_slice_from_static_string( + grpc_core::kGrpcLbClientStatsMetadataKey)); + if (client_stats_md.has_value()) { + grpc_core::GrpcLbClientStats* client_stats = + const_cast( + reinterpret_cast( + GRPC_SLICE_START_PTR(*client_stats_md))); + if (client_stats != nullptr) { + calld->client_stats.reset(client_stats); + // Intercept completion. 
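+        // Note: this is the usual filter interception idiom: stash the
+        // batch's original closure, substitute our own, and have our
+        // callback (on_complete_for_send above) record the outcome before
+        // re-running the original via Closure::Run().  The same pattern is
+        // applied to recv_initial_metadata_ready further down.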
+ calld->original_on_complete_for_send = batch->on_complete; + GRPC_CLOSURE_INIT(&calld->on_complete_for_send, on_complete_for_send, + calld, grpc_schedule_on_exec_ctx); + batch->on_complete = &calld->on_complete_for_send; + } + } + } + // Intercept completion of recv_initial_metadata. + if (batch->recv_initial_metadata) { + calld->original_recv_initial_metadata_ready = + batch->payload->recv_initial_metadata.recv_initial_metadata_ready; + GRPC_CLOSURE_INIT(&calld->recv_initial_metadata_ready, + recv_initial_metadata_ready, calld, + grpc_schedule_on_exec_ctx); + batch->payload->recv_initial_metadata.recv_initial_metadata_ready = + &calld->recv_initial_metadata_ready; + } + // Chain to next filter. + grpc_call_next_op(elem, batch); +} + +const grpc_channel_filter grpc_client_load_reporting_filter = { + clr_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + clr_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + clr_destroy_call_elem, + 0, // sizeof(channel_data) + clr_init_channel_elem, + clr_destroy_channel_elem, + grpc_channel_next_get_info, + "client_load_reporting"}; diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc new file mode 100644 index 00000000..489d153c --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc @@ -0,0 +1,1866 @@ +// +// Copyright 2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +/// Implementation of the gRPC LB policy. +/// +/// This policy takes as input a list of resolved addresses, which must +/// include at least one balancer address. +/// +/// An internal channel (\a lb_channel_) is created for the addresses +/// from that are balancers. This channel behaves just like a regular +/// channel that uses pick_first to select from the list of balancer +/// addresses. +/// +/// When we get our initial update, we instantiate the internal *streaming* +/// call to the LB server (whichever address pick_first chose). The call +/// will be complete when either the balancer sends status or when we cancel +/// the call (e.g., because we are shutting down). In needed, we retry the +/// call. If we received at least one valid message from the server, a new +/// call attempt will be made immediately; otherwise, we apply back-off +/// delays between attempts. +/// +/// We maintain an internal round_robin policy instance for distributing +/// requests across backends. Whenever we receive a new serverlist from +/// the balancer, we update the round_robin policy with the new list of +/// addresses. If we cannot communicate with the balancer on startup, +/// however, we may enter fallback mode, in which case we will populate +/// the child policy's addresses from the backend addresses returned by the +/// resolver. 
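+/// Fallback mode is entered in several situations: when the
+/// fallback-at-startup timer fires before a serverlist has been received,
+/// when the balancer channel reports TRANSIENT_FAILURE during that startup
+/// window, when the balancer call ends without ever having delivered a
+/// serverlist, or when the balancer explicitly asks for fallback in its
+/// response.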
+/// +/// Once a child policy instance is in place (and getting updated as described), +/// calls for a pick, a ping, or a cancellation will be serviced right +/// away by forwarding them to the child policy instance. Any time there's no +/// child policy available (i.e., right after the creation of the gRPCLB +/// policy), pick requests are queued. +/// +/// \see https://github.com/grpc/grpc/blob/master/doc/load-balancing.md for the +/// high level design and details. + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/strip.h" +#include "upb/upb.hpp" + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/static_metadata.h" + +#define GRPC_GRPCLB_INITIAL_CONNECT_BACKOFF_SECONDS 1 +#define GRPC_GRPCLB_RECONNECT_BACKOFF_MULTIPLIER 1.6 +#define GRPC_GRPCLB_RECONNECT_MAX_BACKOFF_SECONDS 120 +#define GRPC_GRPCLB_RECONNECT_JITTER 0.2 +#define GRPC_GRPCLB_DEFAULT_FALLBACK_TIMEOUT_MS 10000 +#define GRPC_GRPCLB_DEFAULT_SUBCHANNEL_DELETION_DELAY_MS 10000 + +namespace grpc_core { + +TraceFlag grpc_lb_glb_trace(false, "glb"); + +const char kGrpcLbClientStatsMetadataKey[] = "grpclb_client_stats"; +const char kGrpcLbLbTokenMetadataKey[] = "lb-token"; + +const char kGrpcLbAddressAttributeKey[] = "grpclb"; + +namespace { + +constexpr char kGrpclb[] = "grpclb"; + +class GrpcLbConfig : public LoadBalancingPolicy::Config { + public: + GrpcLbConfig(RefCountedPtr child_policy, + std::string service_name) + : child_policy_(std::move(child_policy)), + service_name_(std::move(service_name)) {} + const char* 
name() const override { return kGrpclb; } + + RefCountedPtr child_policy() const { + return child_policy_; + } + + const std::string& service_name() const { return service_name_; } + + private: + RefCountedPtr child_policy_; + std::string service_name_; +}; + +class GrpcLb : public LoadBalancingPolicy { + public: + explicit GrpcLb(Args args); + + const char* name() const override { return kGrpclb; } + + void UpdateLocked(UpdateArgs args) override; + void ResetBackoffLocked() override; + + private: + /// Contains a call to the LB server and all the data related to the call. + class BalancerCallState : public InternallyRefCounted { + public: + explicit BalancerCallState( + RefCountedPtr parent_grpclb_policy); + ~BalancerCallState() override; + + // It's the caller's responsibility to ensure that Orphan() is called from + // inside the combiner. + void Orphan() override; + + void StartQuery(); + + GrpcLbClientStats* client_stats() const { return client_stats_.get(); } + + bool seen_initial_response() const { return seen_initial_response_; } + bool seen_serverlist() const { return seen_serverlist_; } + + private: + GrpcLb* grpclb_policy() const { + return static_cast(grpclb_policy_.get()); + } + + void ScheduleNextClientLoadReportLocked(); + void SendClientLoadReportLocked(); + + static void MaybeSendClientLoadReport(void* arg, grpc_error_handle error); + static void ClientLoadReportDone(void* arg, grpc_error_handle error); + static void OnInitialRequestSent(void* arg, grpc_error_handle error); + static void OnBalancerMessageReceived(void* arg, grpc_error_handle error); + static void OnBalancerStatusReceived(void* arg, grpc_error_handle error); + + void MaybeSendClientLoadReportLocked(grpc_error_handle error); + void ClientLoadReportDoneLocked(grpc_error_handle error); + void OnInitialRequestSentLocked(); + void OnBalancerMessageReceivedLocked(); + void OnBalancerStatusReceivedLocked(grpc_error_handle error); + + // The owning LB policy. + RefCountedPtr grpclb_policy_; + + // The streaming call to the LB server. Always non-NULL. + grpc_call* lb_call_ = nullptr; + + // recv_initial_metadata + grpc_metadata_array lb_initial_metadata_recv_; + + // send_message + grpc_byte_buffer* send_message_payload_ = nullptr; + grpc_closure lb_on_initial_request_sent_; + + // recv_message + grpc_byte_buffer* recv_message_payload_ = nullptr; + grpc_closure lb_on_balancer_message_received_; + bool seen_initial_response_ = false; + bool seen_serverlist_ = false; + + // recv_trailing_metadata + grpc_closure lb_on_balancer_status_received_; + grpc_metadata_array lb_trailing_metadata_recv_; + grpc_status_code lb_call_status_; + grpc_slice lb_call_status_details_; + + // The stats for client-side load reporting associated with this LB call. + // Created after the first serverlist is received. + RefCountedPtr client_stats_; + grpc_millis client_stats_report_interval_ = 0; + grpc_timer client_load_report_timer_; + bool client_load_report_timer_callback_pending_ = false; + bool last_client_load_report_counters_were_zero_ = false; + bool client_load_report_is_due_ = false; + // The closure used for either the load report timer or the callback for + // completion of sending the load report. 
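+    // Rough lifecycle: once a serverlist has been received and the balancer
+    // announced a non-zero reporting interval,
+    // ScheduleNextClientLoadReportLocked() arms client_load_report_timer_;
+    // when it fires, MaybeSendClientLoadReport() hops back onto the work
+    // serializer and either sends the report immediately or defers it until
+    // the previous send has finished; once ClientLoadReportDone() runs, the
+    // timer is re-armed and the cycle repeats.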
+ grpc_closure client_load_report_closure_; + }; + + class SubchannelWrapper : public DelegatingSubchannel { + public: + SubchannelWrapper(RefCountedPtr subchannel, + RefCountedPtr lb_policy, std::string lb_token, + RefCountedPtr client_stats) + : DelegatingSubchannel(std::move(subchannel)), + lb_policy_(std::move(lb_policy)), + lb_token_(std::move(lb_token)), + client_stats_(std::move(client_stats)) {} + + ~SubchannelWrapper() override { + if (!lb_policy_->shutting_down_) { + lb_policy_->CacheDeletedSubchannelLocked(wrapped_subchannel()); + } + } + + const std::string& lb_token() const { return lb_token_; } + GrpcLbClientStats* client_stats() const { return client_stats_.get(); } + + private: + RefCountedPtr lb_policy_; + std::string lb_token_; + RefCountedPtr client_stats_; + }; + + class TokenAndClientStatsAttribute + : public ServerAddress::AttributeInterface { + public: + TokenAndClientStatsAttribute(std::string lb_token, + RefCountedPtr client_stats) + : lb_token_(std::move(lb_token)), + client_stats_(std::move(client_stats)) {} + + std::unique_ptr Copy() const override { + return absl::make_unique(lb_token_, + client_stats_); + } + + int Cmp(const AttributeInterface* other_base) const override { + const TokenAndClientStatsAttribute* other = + static_cast(other_base); + int r = lb_token_.compare(other->lb_token_); + if (r != 0) return r; + return QsortCompare(client_stats_.get(), other->client_stats_.get()); + } + + std::string ToString() const override { + return absl::StrFormat("lb_token=\"%s\" client_stats=%p", lb_token_, + client_stats_.get()); + } + + const std::string& lb_token() const { return lb_token_; } + RefCountedPtr client_stats() const { + return client_stats_; + } + + private: + std::string lb_token_; + RefCountedPtr client_stats_; + }; + + class Serverlist : public RefCounted { + public: + // Takes ownership of serverlist. + explicit Serverlist(std::vector serverlist) + : serverlist_(std::move(serverlist)) {} + + bool operator==(const Serverlist& other) const; + + const std::vector& serverlist() const { return serverlist_; } + + // Returns a text representation suitable for logging. + std::string AsText() const; + + // Extracts all non-drop entries into a ServerAddressList. + ServerAddressList GetServerAddressList( + GrpcLbClientStats* client_stats) const; + + // Returns true if the serverlist contains at least one drop entry and + // no backend address entries. + bool ContainsAllDropEntries() const; + + // Returns the LB token to use for a drop, or null if the call + // should not be dropped. + // + // Note: This is called from the picker, so it will be invoked in + // the channel's data plane mutex, NOT the control plane + // work_serializer. It should not be accessed by any other part of the LB + // policy. + const char* ShouldDrop(); + + private: + std::vector serverlist_; + + // Guarded by the channel's data plane mutex, NOT the control + // plane work_serializer. It should not be accessed by anything but the + // picker via the ShouldDrop() method. + size_t drop_index_ = 0; + }; + + class Picker : public SubchannelPicker { + public: + Picker(RefCountedPtr serverlist, + std::unique_ptr child_picker, + RefCountedPtr client_stats) + : serverlist_(std::move(serverlist)), + child_picker_(std::move(child_picker)), + client_stats_(std::move(client_stats)) {} + + PickResult Pick(PickArgs args) override; + + private: + // Serverlist to be used for determining drops. 
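+    // When non-null, Pick() first consults serverlist_->ShouldDrop(), which
+    // walks the serverlist entries round-robin and returns the LB token of a
+    // drop entry (or nullptr).  A null serverlist_ disables drop handling
+    // entirely; see Helper::UpdateState() for when that happens.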
+ RefCountedPtr serverlist_; + + std::unique_ptr child_picker_; + RefCountedPtr client_stats_; + }; + + class Helper : public ChannelControlHelper { + public: + explicit Helper(RefCountedPtr parent) + : parent_(std::move(parent)) {} + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override; + void RequestReresolution() override; + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + RefCountedPtr parent_; + }; + + class StateWatcher : public AsyncConnectivityStateWatcherInterface { + public: + explicit StateWatcher(RefCountedPtr parent) + : AsyncConnectivityStateWatcherInterface(parent->work_serializer()), + parent_(std::move(parent)) {} + + ~StateWatcher() override { parent_.reset(DEBUG_LOCATION, "StateWatcher"); } + + private: + void OnConnectivityStateChange(grpc_connectivity_state new_state, + const absl::Status& status) override { + if (parent_->fallback_at_startup_checks_pending_ && + new_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + // In TRANSIENT_FAILURE. Cancel the fallback timer and go into + // fallback mode immediately. + gpr_log(GPR_INFO, + "[grpclb %p] balancer channel in state:TRANSIENT_FAILURE (%s); " + "entering fallback mode", + parent_.get(), status.ToString().c_str()); + parent_->fallback_at_startup_checks_pending_ = false; + grpc_timer_cancel(&parent_->lb_fallback_timer_); + parent_->fallback_mode_ = true; + parent_->CreateOrUpdateChildPolicyLocked(); + // Cancel the watch, since we don't care about the channel state once we + // go into fallback mode. + parent_->CancelBalancerChannelConnectivityWatchLocked(); + } + } + + RefCountedPtr parent_; + }; + + ~GrpcLb() override; + + void ShutdownLocked() override; + + // Helper functions used in UpdateLocked(). + void ProcessAddressesAndChannelArgsLocked(const ServerAddressList& addresses, + const grpc_channel_args& args); + static ServerAddressList AddNullLbTokenToAddresses( + const ServerAddressList& addresses); + + void CancelBalancerChannelConnectivityWatchLocked(); + + // Methods for dealing with fallback state. + void MaybeEnterFallbackModeAfterStartup(); + static void OnFallbackTimer(void* arg, grpc_error_handle error); + void OnFallbackTimerLocked(grpc_error_handle error); + + // Methods for dealing with the balancer call. + void StartBalancerCallLocked(); + void StartBalancerCallRetryTimerLocked(); + static void OnBalancerCallRetryTimer(void* arg, grpc_error_handle error); + void OnBalancerCallRetryTimerLocked(grpc_error_handle error); + + // Methods for dealing with the child policy. + grpc_channel_args* CreateChildPolicyArgsLocked( + bool is_backend_from_grpclb_load_balancer); + OrphanablePtr CreateChildPolicyLocked( + const grpc_channel_args* args); + void CreateOrUpdateChildPolicyLocked(); + + // Subchannel caching. + void CacheDeletedSubchannelLocked( + RefCountedPtr subchannel); + void StartSubchannelCacheTimerLocked(); + static void OnSubchannelCacheTimer(void* arg, grpc_error_handle error); + void OnSubchannelCacheTimerLocked(grpc_error_handle error); + + // Who the client is trying to communicate with. + std::string server_name_; + // Configurations for the policy. + RefCountedPtr config_; + + // Current channel args from the resolver. + grpc_channel_args* args_ = nullptr; + + // Internal state. 
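+  // Unless noted otherwise, the state below is only accessed from within the
+  // policy's WorkSerializer (i.e., from the *Locked methods); the one
+  // data-plane exception, Serverlist::drop_index_, is documented on that
+  // class.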
+ bool shutting_down_ = false; + + // The channel for communicating with the LB server. + grpc_channel* lb_channel_ = nullptr; + StateWatcher* watcher_ = nullptr; + // Response generator to inject address updates into lb_channel_. + RefCountedPtr response_generator_; + // Parent channelz node. + RefCountedPtr parent_channelz_node_; + + // The data associated with the current LB call. It holds a ref to this LB + // policy. It's initialized every time we query for backends. It's reset to + // NULL whenever the current LB call is no longer needed (e.g., the LB policy + // is shutting down, or the LB call has ended). A non-NULL lb_calld_ always + // contains a non-NULL lb_call_. + OrphanablePtr lb_calld_; + // Timeout in milliseconds for the LB call. 0 means no deadline. + const int lb_call_timeout_ms_ = 0; + // Balancer call retry state. + BackOff lb_call_backoff_; + bool retry_timer_callback_pending_ = false; + grpc_timer lb_call_retry_timer_; + grpc_closure lb_on_call_retry_; + + // The deserialized response from the balancer. May be nullptr until one + // such response has arrived. + RefCountedPtr serverlist_; + + // Whether we're in fallback mode. + bool fallback_mode_ = false; + // The backend addresses from the resolver. + ServerAddressList fallback_backend_addresses_; + // State for fallback-at-startup checks. + // Timeout after startup after which we will go into fallback mode if + // we have not received a serverlist from the balancer. + const int fallback_at_startup_timeout_ = 0; + bool fallback_at_startup_checks_pending_ = false; + grpc_timer lb_fallback_timer_; + grpc_closure lb_on_fallback_; + + // The child policy to use for the backends. + OrphanablePtr child_policy_; + // Child policy in state READY. + bool child_policy_ready_ = false; + + // Deleted subchannel caching. + const grpc_millis subchannel_cache_interval_ms_; + std::map>> + cached_subchannels_; + grpc_timer subchannel_cache_timer_; + grpc_closure on_subchannel_cache_timer_; + bool subchannel_cache_timer_pending_ = false; +}; + +// +// GrpcLb::Serverlist +// + +bool GrpcLb::Serverlist::operator==(const Serverlist& other) const { + return serverlist_ == other.serverlist_; +} + +void ParseServer(const GrpcLbServer& server, grpc_resolved_address* addr) { + memset(addr, 0, sizeof(*addr)); + if (server.drop) return; + const uint16_t netorder_port = grpc_htons(static_cast(server.port)); + /* the addresses are given in binary format (a in(6)_addr struct) in + * server->ip_address.bytes. 
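+   * A 4-byte address is copied into the sin_addr of a grpc_sockaddr_in and a
+   * 16-byte address into the sin6_addr of a grpc_sockaddr_in6; in both cases
+   * the port is stored in network byte order.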
*/ + if (server.ip_size == 4) { + addr->len = static_cast(sizeof(grpc_sockaddr_in)); + grpc_sockaddr_in* addr4 = reinterpret_cast(&addr->addr); + addr4->sin_family = GRPC_AF_INET; + memcpy(&addr4->sin_addr, server.ip_addr, server.ip_size); + addr4->sin_port = netorder_port; + } else if (server.ip_size == 16) { + addr->len = static_cast(sizeof(grpc_sockaddr_in6)); + grpc_sockaddr_in6* addr6 = + reinterpret_cast(&addr->addr); + addr6->sin6_family = GRPC_AF_INET6; + memcpy(&addr6->sin6_addr, server.ip_addr, server.ip_size); + addr6->sin6_port = netorder_port; + } +} + +std::string GrpcLb::Serverlist::AsText() const { + std::vector entries; + for (size_t i = 0; i < serverlist_.size(); ++i) { + const GrpcLbServer& server = serverlist_[i]; + std::string ipport; + if (server.drop) { + ipport = "(drop)"; + } else { + grpc_resolved_address addr; + ParseServer(server, &addr); + ipport = grpc_sockaddr_to_string(&addr, false); + } + entries.push_back(absl::StrFormat(" %" PRIuPTR ": %s token=%s\n", i, + ipport, server.load_balance_token)); + } + return absl::StrJoin(entries, ""); +} + +bool IsServerValid(const GrpcLbServer& server, size_t idx, bool log) { + if (server.drop) return false; + if (GPR_UNLIKELY(server.port >> 16 != 0)) { + if (log) { + gpr_log(GPR_ERROR, + "Invalid port '%d' at index %" PRIuPTR + " of serverlist. Ignoring.", + server.port, idx); + } + return false; + } + if (GPR_UNLIKELY(server.ip_size != 4 && server.ip_size != 16)) { + if (log) { + gpr_log(GPR_ERROR, + "Expected IP to be 4 or 16 bytes, got %d at index %" PRIuPTR + " of serverlist. Ignoring", + server.ip_size, idx); + } + return false; + } + return true; +} + +// Returns addresses extracted from the serverlist. +ServerAddressList GrpcLb::Serverlist::GetServerAddressList( + GrpcLbClientStats* client_stats) const { + RefCountedPtr stats; + if (client_stats != nullptr) stats = client_stats->Ref(); + ServerAddressList addresses; + for (size_t i = 0; i < serverlist_.size(); ++i) { + const GrpcLbServer& server = serverlist_[i]; + if (!IsServerValid(server, i, false)) continue; + // Address processing. + grpc_resolved_address addr; + ParseServer(server, &addr); + // LB token processing. + const size_t lb_token_length = strnlen( + server.load_balance_token, GPR_ARRAY_SIZE(server.load_balance_token)); + std::string lb_token(server.load_balance_token, lb_token_length); + if (lb_token.empty()) { + gpr_log(GPR_INFO, + "Missing LB token for backend address '%s'. The empty token will " + "be used instead", + grpc_sockaddr_to_uri(&addr).c_str()); + } + // Attach attribute to address containing LB token and stats object. + std::map> + attributes; + attributes[kGrpcLbAddressAttributeKey] = + absl::make_unique(std::move(lb_token), + stats); + // Add address. + addresses.emplace_back(addr, /*args=*/nullptr, std::move(attributes)); + } + return addresses; +} + +bool GrpcLb::Serverlist::ContainsAllDropEntries() const { + if (serverlist_.empty()) return false; + for (const GrpcLbServer& server : serverlist_) { + if (!server.drop) return false; + } + return true; +} + +const char* GrpcLb::Serverlist::ShouldDrop() { + if (serverlist_.empty()) return nullptr; + GrpcLbServer& server = serverlist_[drop_index_]; + drop_index_ = (drop_index_ + 1) % serverlist_.size(); + return server.drop ? server.load_balance_token : nullptr; +} + +// +// GrpcLb::Picker +// + +GrpcLb::PickResult GrpcLb::Picker::Pick(PickArgs args) { + // Check if we should drop the call. + const char* drop_token = + serverlist_ == nullptr ? 
nullptr : serverlist_->ShouldDrop(); + if (drop_token != nullptr) { + // Update client load reporting stats to indicate the number of + // dropped calls. Note that we have to do this here instead of in + // the client_load_reporting filter, because we do not create a + // subchannel call (and therefore no client_load_reporting filter) + // for dropped calls. + if (client_stats_ != nullptr) { + client_stats_->AddCallDropped(drop_token); + } + return PickResult::Drop( + absl::UnavailableError("drop directed by grpclb balancer")); + } + // Forward pick to child policy. + PickResult result = child_picker_->Pick(args); + // If pick succeeded, add LB token to initial metadata. + auto* complete_pick = absl::get_if(&result.result); + if (complete_pick != nullptr) { + const SubchannelWrapper* subchannel_wrapper = + static_cast(complete_pick->subchannel.get()); + // Encode client stats object into metadata for use by + // client_load_reporting filter. + GrpcLbClientStats* client_stats = subchannel_wrapper->client_stats(); + if (client_stats != nullptr) { + client_stats->Ref().release(); // Ref passed via metadata. + // The metadata value is a hack: we pretend the pointer points to + // a string and rely on the client_load_reporting filter to know + // how to interpret it. + args.initial_metadata->Add( + kGrpcLbClientStatsMetadataKey, + absl::string_view(reinterpret_cast(client_stats), 0)); + // Update calls-started. + client_stats->AddCallStarted(); + } + // Encode the LB token in metadata. + // Create a new copy on the call arena, since the subchannel list + // may get refreshed between when we return this pick and when the + // initial metadata goes out on the wire. + if (!subchannel_wrapper->lb_token().empty()) { + char* lb_token = static_cast( + args.call_state->Alloc(subchannel_wrapper->lb_token().size() + 1)); + strcpy(lb_token, subchannel_wrapper->lb_token().c_str()); + args.initial_metadata->Add(kGrpcLbLbTokenMetadataKey, lb_token); + } + // Unwrap subchannel to pass up to the channel. + complete_pick->subchannel = subchannel_wrapper->wrapped_subchannel(); + } + return result; +} + +// +// GrpcLb::Helper +// + +RefCountedPtr GrpcLb::Helper::CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) { + if (parent_->shutting_down_) return nullptr; + const TokenAndClientStatsAttribute* attribute = + static_cast( + address.GetAttribute(kGrpcLbAddressAttributeKey)); + if (attribute == nullptr) { + gpr_log(GPR_ERROR, + "[grpclb %p] no TokenAndClientStatsAttribute for address %p", + parent_.get(), address.ToString().c_str()); + abort(); + } + std::string lb_token = attribute->lb_token(); + RefCountedPtr client_stats = attribute->client_stats(); + return MakeRefCounted( + parent_->channel_control_helper()->CreateSubchannel(std::move(address), + args), + parent_->Ref(DEBUG_LOCATION, "SubchannelWrapper"), std::move(lb_token), + std::move(client_stats)); +} + +void GrpcLb::Helper::UpdateState(grpc_connectivity_state state, + const absl::Status& status, + std::unique_ptr picker) { + if (parent_->shutting_down_) return; + // Record whether child policy reports READY. + parent_->child_policy_ready_ = state == GRPC_CHANNEL_READY; + // Enter fallback mode if needed. + parent_->MaybeEnterFallbackModeAfterStartup(); + // We pass the serverlist to the picker so that it can handle drops. 
+ // However, we don't want to handle drops in the case where the child + // policy is reporting a state other than READY (unless we are + // dropping *all* calls), because we don't want to process drops for picks + // that yield a QUEUE result; this would result in dropping too many calls, + // since we will see the queued picks multiple times, and we'd consider each + // one a separate call for the drop calculation. So in this case, we pass + // a null serverlist to the picker, which tells it not to do drops. + RefCountedPtr serverlist; + if (state == GRPC_CHANNEL_READY || + (parent_->serverlist_ != nullptr && + parent_->serverlist_->ContainsAllDropEntries())) { + serverlist = parent_->serverlist_; + } + RefCountedPtr client_stats; + if (parent_->lb_calld_ != nullptr && + parent_->lb_calld_->client_stats() != nullptr) { + client_stats = parent_->lb_calld_->client_stats()->Ref(); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p helper %p] state=%s (%s) wrapping child " + "picker %p (serverlist=%p, client_stats=%p)", + parent_.get(), this, ConnectivityStateName(state), + status.ToString().c_str(), picker.get(), serverlist.get(), + client_stats.get()); + } + parent_->channel_control_helper()->UpdateState( + state, status, + absl::make_unique(std::move(serverlist), std::move(picker), + std::move(client_stats))); +} + +void GrpcLb::Helper::RequestReresolution() { + if (parent_->shutting_down_) return; + // If we are talking to a balancer, we expect to get updated addresses + // from the balancer, so we can ignore the re-resolution request from + // the child policy. Otherwise, pass the re-resolution request up to the + // channel. + if (parent_->lb_calld_ == nullptr || + !parent_->lb_calld_->seen_initial_response()) { + parent_->channel_control_helper()->RequestReresolution(); + } +} + +absl::string_view GrpcLb::Helper::GetAuthority() { + return parent_->channel_control_helper()->GetAuthority(); +} + +void GrpcLb::Helper::AddTraceEvent(TraceSeverity severity, + absl::string_view message) { + if (parent_->shutting_down_) return; + parent_->channel_control_helper()->AddTraceEvent(severity, message); +} + +// +// GrpcLb::BalancerCallState +// + +GrpcLb::BalancerCallState::BalancerCallState( + RefCountedPtr parent_grpclb_policy) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace) ? "BalancerCallState" + : nullptr), + grpclb_policy_(std::move(parent_grpclb_policy)) { + GPR_ASSERT(grpclb_policy_ != nullptr); + GPR_ASSERT(!grpclb_policy()->shutting_down_); + // Init the LB call. Note that the LB call will progress every time there's + // activity in grpclb_policy_->interested_parties(), which is comprised of + // the polling entities from client_channel. + GPR_ASSERT(!grpclb_policy()->server_name_.empty()); + // Closure Initialization + GRPC_CLOSURE_INIT(&lb_on_initial_request_sent_, OnInitialRequestSent, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&lb_on_balancer_message_received_, + OnBalancerMessageReceived, this, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&lb_on_balancer_status_received_, OnBalancerStatusReceived, + this, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&client_load_report_closure_, MaybeSendClientLoadReport, + this, grpc_schedule_on_exec_ctx); + const grpc_millis deadline = + grpclb_policy()->lb_call_timeout_ms_ == 0 + ? 
GRPC_MILLIS_INF_FUTURE + : ExecCtx::Get()->Now() + grpclb_policy()->lb_call_timeout_ms_; + lb_call_ = grpc_channel_create_pollset_set_call( + grpclb_policy()->lb_channel_, nullptr, GRPC_PROPAGATE_DEFAULTS, + grpclb_policy_->interested_parties(), + GRPC_MDSTR_SLASH_GRPC_DOT_LB_DOT_V1_DOT_LOADBALANCER_SLASH_BALANCELOAD, + nullptr, deadline, nullptr); + // Init the LB call request payload. + upb::Arena arena; + grpc_slice request_payload_slice = GrpcLbRequestCreate( + grpclb_policy()->config_->service_name().empty() + ? grpclb_policy()->server_name_.c_str() + : grpclb_policy()->config_->service_name().c_str(), + arena.ptr()); + send_message_payload_ = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_slice_unref_internal(request_payload_slice); + // Init other data associated with the LB call. + grpc_metadata_array_init(&lb_initial_metadata_recv_); + grpc_metadata_array_init(&lb_trailing_metadata_recv_); +} + +GrpcLb::BalancerCallState::~BalancerCallState() { + GPR_ASSERT(lb_call_ != nullptr); + grpc_call_unref(lb_call_); + grpc_metadata_array_destroy(&lb_initial_metadata_recv_); + grpc_metadata_array_destroy(&lb_trailing_metadata_recv_); + grpc_byte_buffer_destroy(send_message_payload_); + grpc_byte_buffer_destroy(recv_message_payload_); + grpc_slice_unref_internal(lb_call_status_details_); +} + +void GrpcLb::BalancerCallState::Orphan() { + GPR_ASSERT(lb_call_ != nullptr); + // If we are here because grpclb_policy wants to cancel the call, + // lb_on_balancer_status_received_ will complete the cancellation and clean + // up. Otherwise, we are here because grpclb_policy has to orphan a failed + // call, then the following cancellation will be a no-op. + grpc_call_cancel_internal(lb_call_); + if (client_load_report_timer_callback_pending_) { + grpc_timer_cancel(&client_load_report_timer_); + } + // Note that the initial ref is hold by lb_on_balancer_status_received_ + // instead of the caller of this function. So the corresponding unref happens + // in lb_on_balancer_status_received_ instead of here. +} + +void GrpcLb::BalancerCallState::StartQuery() { + GPR_ASSERT(lb_call_ != nullptr); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, "[grpclb %p] lb_calld=%p: Starting LB call %p", + grpclb_policy_.get(), this, lb_call_); + } + // Create the ops. + grpc_call_error call_error; + grpc_op ops[3]; + memset(ops, 0, sizeof(ops)); + // Op: send initial metadata. + grpc_op* op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY | + GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET; + op->reserved = nullptr; + op++; + // Op: send request message. + GPR_ASSERT(send_message_payload_ != nullptr); + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = send_message_payload_; + op->flags = 0; + op->reserved = nullptr; + op++; + // TODO(roth): We currently track this ref manually. Once the + // ClosureRef API is ready, we should pass the RefCountedPtr<> along + // with the callback. + auto self = Ref(DEBUG_LOCATION, "on_initial_request_sent"); + self.release(); + call_error = grpc_call_start_batch_and_execute(lb_call_, ops, + static_cast(op - ops), + &lb_on_initial_request_sent_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + // Op: recv initial metadata. 
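+  // Note: the call is intentionally driven by three separate batches: the
+  // send batch above completes once, the recv batch started below completes
+  // when the initial metadata and first balancer message arrive (after which
+  // OnBalancerMessageReceivedLocked() keeps the stream alive by starting a
+  // new single-op RECV_MESSAGE batch per message), and the status batch
+  // further down completes only when the call ends.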
+ op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &lb_initial_metadata_recv_; + op->flags = 0; + op->reserved = nullptr; + op++; + // Op: recv response. + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_message_payload_; + op->flags = 0; + op->reserved = nullptr; + op++; + // TODO(roth): We currently track this ref manually. Once the + // ClosureRef API is ready, we should pass the RefCountedPtr<> along + // with the callback. + self = Ref(DEBUG_LOCATION, "on_message_received"); + self.release(); + call_error = grpc_call_start_batch_and_execute( + lb_call_, ops, static_cast(op - ops), + &lb_on_balancer_message_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + // Op: recv server status. + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = + &lb_trailing_metadata_recv_; + op->data.recv_status_on_client.status = &lb_call_status_; + op->data.recv_status_on_client.status_details = &lb_call_status_details_; + op->flags = 0; + op->reserved = nullptr; + op++; + // This callback signals the end of the LB call, so it relies on the initial + // ref instead of a new ref. When it's invoked, it's the initial ref that is + // unreffed. + call_error = grpc_call_start_batch_and_execute( + lb_call_, ops, static_cast(op - ops), + &lb_on_balancer_status_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); +} + +void GrpcLb::BalancerCallState::ScheduleNextClientLoadReportLocked() { + // InvalidateNow to avoid getting stuck re-initializing this timer + // in a loop while draining the currently-held WorkSerializer. + // Also see https://github.com/grpc/grpc/issues/26079. + ExecCtx::Get()->InvalidateNow(); + const grpc_millis next_client_load_report_time = + ExecCtx::Get()->Now() + client_stats_report_interval_; + GRPC_CLOSURE_INIT(&client_load_report_closure_, MaybeSendClientLoadReport, + this, grpc_schedule_on_exec_ctx); + grpc_timer_init(&client_load_report_timer_, next_client_load_report_time, + &client_load_report_closure_); + client_load_report_timer_callback_pending_ = true; +} + +void GrpcLb::BalancerCallState::MaybeSendClientLoadReport( + void* arg, grpc_error_handle error) { + BalancerCallState* lb_calld = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + lb_calld->grpclb_policy()->work_serializer()->Run( + [lb_calld, error]() { lb_calld->MaybeSendClientLoadReportLocked(error); }, + DEBUG_LOCATION); +} + +void GrpcLb::BalancerCallState::MaybeSendClientLoadReportLocked( + grpc_error_handle error) { + client_load_report_timer_callback_pending_ = false; + if (error != GRPC_ERROR_NONE || this != grpclb_policy()->lb_calld_.get()) { + Unref(DEBUG_LOCATION, "client_load_report"); + GRPC_ERROR_UNREF(error); + return; + } + // If we've already sent the initial request, then we can go ahead and send + // the load report. Otherwise, we need to wait until the initial request has + // been sent to send this (see OnInitialRequestSentLocked()). + if (send_message_payload_ == nullptr) { + SendClientLoadReportLocked(); + } else { + client_load_report_is_due_ = true; + } +} + +void GrpcLb::BalancerCallState::SendClientLoadReportLocked() { + // Construct message payload. + GPR_ASSERT(send_message_payload_ == nullptr); + // Get snapshot of stats. 
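+  // (The zero-suppression below only skips a report when the previous report
+  // was also all-zero, so the first idle interval after activity is still
+  // reported once.)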
+ int64_t num_calls_started; + int64_t num_calls_finished; + int64_t num_calls_finished_with_client_failed_to_send; + int64_t num_calls_finished_known_received; + std::unique_ptr drop_token_counts; + client_stats_->Get(&num_calls_started, &num_calls_finished, + &num_calls_finished_with_client_failed_to_send, + &num_calls_finished_known_received, &drop_token_counts); + // Skip client load report if the counters were all zero in the last + // report and they are still zero in this one. + if (num_calls_started == 0 && num_calls_finished == 0 && + num_calls_finished_with_client_failed_to_send == 0 && + num_calls_finished_known_received == 0 && + (drop_token_counts == nullptr || drop_token_counts->empty())) { + if (last_client_load_report_counters_were_zero_) { + ScheduleNextClientLoadReportLocked(); + return; + } + last_client_load_report_counters_were_zero_ = true; + } else { + last_client_load_report_counters_were_zero_ = false; + } + // Populate load report. + upb::Arena arena; + grpc_slice request_payload_slice = GrpcLbLoadReportRequestCreate( + num_calls_started, num_calls_finished, + num_calls_finished_with_client_failed_to_send, + num_calls_finished_known_received, drop_token_counts.get(), arena.ptr()); + send_message_payload_ = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_slice_unref_internal(request_payload_slice); + // Send the report. + grpc_op op; + memset(&op, 0, sizeof(op)); + op.op = GRPC_OP_SEND_MESSAGE; + op.data.send_message.send_message = send_message_payload_; + GRPC_CLOSURE_INIT(&client_load_report_closure_, ClientLoadReportDone, this, + grpc_schedule_on_exec_ctx); + grpc_call_error call_error = grpc_call_start_batch_and_execute( + lb_call_, &op, 1, &client_load_report_closure_); + if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) { + gpr_log(GPR_ERROR, + "[grpclb %p] lb_calld=%p call_error=%d sending client load report", + grpclb_policy_.get(), this, call_error); + GPR_ASSERT(GRPC_CALL_OK == call_error); + } +} + +void GrpcLb::BalancerCallState::ClientLoadReportDone(void* arg, + grpc_error_handle error) { + BalancerCallState* lb_calld = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + lb_calld->grpclb_policy()->work_serializer()->Run( + [lb_calld, error]() { lb_calld->ClientLoadReportDoneLocked(error); }, + DEBUG_LOCATION); +} + +void GrpcLb::BalancerCallState::ClientLoadReportDoneLocked( + grpc_error_handle error) { + grpc_byte_buffer_destroy(send_message_payload_); + send_message_payload_ = nullptr; + if (error != GRPC_ERROR_NONE || this != grpclb_policy()->lb_calld_.get()) { + Unref(DEBUG_LOCATION, "client_load_report"); + GRPC_ERROR_UNREF(error); + return; + } + ScheduleNextClientLoadReportLocked(); +} + +void GrpcLb::BalancerCallState::OnInitialRequestSent( + void* arg, grpc_error_handle /*error*/) { + BalancerCallState* lb_calld = static_cast(arg); + lb_calld->grpclb_policy()->work_serializer()->Run( + [lb_calld]() { lb_calld->OnInitialRequestSentLocked(); }, DEBUG_LOCATION); +} + +void GrpcLb::BalancerCallState::OnInitialRequestSentLocked() { + grpc_byte_buffer_destroy(send_message_payload_); + send_message_payload_ = nullptr; + // If we attempted to send a client load report before the initial request was + // sent (and this lb_calld is still in use), send the load report now. 
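+  // (send_message_payload_ doubles as a "send in flight" flag: it stays
+  // non-null while the initial request or a previous load report is being
+  // sent, which is what MaybeSendClientLoadReportLocked() checks before
+  // deciding whether to send immediately or set client_load_report_is_due_.)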
+ if (client_load_report_is_due_ && this == grpclb_policy()->lb_calld_.get()) { + SendClientLoadReportLocked(); + client_load_report_is_due_ = false; + } + Unref(DEBUG_LOCATION, "on_initial_request_sent"); +} + +void GrpcLb::BalancerCallState::OnBalancerMessageReceived( + void* arg, grpc_error_handle /*error*/) { + BalancerCallState* lb_calld = static_cast(arg); + lb_calld->grpclb_policy()->work_serializer()->Run( + [lb_calld]() { lb_calld->OnBalancerMessageReceivedLocked(); }, + DEBUG_LOCATION); +} + +void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked() { + // Null payload means the LB call was cancelled. + if (this != grpclb_policy()->lb_calld_.get() || + recv_message_payload_ == nullptr) { + Unref(DEBUG_LOCATION, "on_message_received"); + return; + } + grpc_byte_buffer_reader bbr; + grpc_byte_buffer_reader_init(&bbr, recv_message_payload_); + grpc_slice response_slice = grpc_byte_buffer_reader_readall(&bbr); + grpc_byte_buffer_reader_destroy(&bbr); + grpc_byte_buffer_destroy(recv_message_payload_); + recv_message_payload_ = nullptr; + GrpcLbResponse response; + upb::Arena arena; + if (!GrpcLbResponseParse(response_slice, arena.ptr(), &response) || + (response.type == response.INITIAL && seen_initial_response_)) { + char* response_slice_str = + grpc_dump_slice(response_slice, GPR_DUMP_ASCII | GPR_DUMP_HEX); + gpr_log(GPR_ERROR, + "[grpclb %p] lb_calld=%p: Invalid LB response received: '%s'. " + "Ignoring.", + grpclb_policy(), this, response_slice_str); + gpr_free(response_slice_str); + } else { + switch (response.type) { + case response.INITIAL: { + if (response.client_stats_report_interval != 0) { + client_stats_report_interval_ = std::max( + int64_t(GPR_MS_PER_SEC), response.client_stats_report_interval); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Received initial LB response " + "message; client load reporting interval = %" PRId64 + " milliseconds", + grpclb_policy(), this, client_stats_report_interval_); + } + } else if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Received initial LB response " + "message; client load reporting NOT enabled", + grpclb_policy(), this); + } + seen_initial_response_ = true; + break; + } + case response.SERVERLIST: { + GPR_ASSERT(lb_call_ != nullptr); + auto serverlist_wrapper = + MakeRefCounted(std::move(response.serverlist)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Serverlist with %" PRIuPTR + " servers received:\n%s", + grpclb_policy(), this, + serverlist_wrapper->serverlist().size(), + serverlist_wrapper->AsText().c_str()); + } + seen_serverlist_ = true; + // Start sending client load report only after we start using the + // serverlist returned from the current LB call. + if (client_stats_report_interval_ > 0 && client_stats_ == nullptr) { + client_stats_ = MakeRefCounted(); + // Ref held by callback. + Ref(DEBUG_LOCATION, "client_load_report").release(); + ScheduleNextClientLoadReportLocked(); + } + // Check if the serverlist differs from the previous one. + if (grpclb_policy()->serverlist_ != nullptr && + *grpclb_policy()->serverlist_ == *serverlist_wrapper) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Incoming server list identical " + "to current, ignoring.", + grpclb_policy(), this); + } + } else { // New serverlist. + // Dispose of the fallback. 
+ // TODO(roth): Ideally, we should stay in fallback mode until we + // know that we can reach at least one of the backends in the new + // serverlist. Unfortunately, we can't do that, since we need to + // send the new addresses to the child policy in order to determine + // if they are reachable, and if we don't exit fallback mode now, + // CreateOrUpdateChildPolicyLocked() will use the fallback + // addresses instead of the addresses from the new serverlist. + // However, if we can't reach any of the servers in the new + // serverlist, then the child policy will never switch away from + // the fallback addresses, but the grpclb policy will still think + // that we're not in fallback mode, which means that we won't send + // updates to the child policy when the fallback addresses are + // updated by the resolver. This is sub-optimal, but the only way + // to fix it is to maintain a completely separate child policy for + // fallback mode, and that's more work than we want to put into + // the grpclb implementation at this point, since we're deprecating + // it in favor of the xds policy. We will implement this the + // right way in the xds policy instead. + if (grpclb_policy()->fallback_mode_) { + gpr_log(GPR_INFO, + "[grpclb %p] Received response from balancer; exiting " + "fallback mode", + grpclb_policy()); + grpclb_policy()->fallback_mode_ = false; + } + if (grpclb_policy()->fallback_at_startup_checks_pending_) { + grpclb_policy()->fallback_at_startup_checks_pending_ = false; + grpc_timer_cancel(&grpclb_policy()->lb_fallback_timer_); + grpclb_policy()->CancelBalancerChannelConnectivityWatchLocked(); + } + // Update the serverlist in the GrpcLb instance. This serverlist + // instance will be destroyed either upon the next update or when the + // GrpcLb instance is destroyed. + grpclb_policy()->serverlist_ = std::move(serverlist_wrapper); + grpclb_policy()->CreateOrUpdateChildPolicyLocked(); + } + break; + } + case response.FALLBACK: { + if (!grpclb_policy()->fallback_mode_) { + gpr_log(GPR_INFO, + "[grpclb %p] Entering fallback mode as requested by balancer", + grpclb_policy()); + if (grpclb_policy()->fallback_at_startup_checks_pending_) { + grpclb_policy()->fallback_at_startup_checks_pending_ = false; + grpc_timer_cancel(&grpclb_policy()->lb_fallback_timer_); + grpclb_policy()->CancelBalancerChannelConnectivityWatchLocked(); + } + grpclb_policy()->fallback_mode_ = true; + grpclb_policy()->CreateOrUpdateChildPolicyLocked(); + // Reset serverlist, so that if the balancer exits fallback + // mode by sending the same serverlist we were previously + // using, we don't incorrectly ignore it as a duplicate. + grpclb_policy()->serverlist_.reset(); + } + break; + } + } + } + grpc_slice_unref_internal(response_slice); + if (!grpclb_policy()->shutting_down_) { + // Keep listening for serverlist updates. + grpc_op op; + memset(&op, 0, sizeof(op)); + op.op = GRPC_OP_RECV_MESSAGE; + op.data.recv_message.recv_message = &recv_message_payload_; + op.flags = 0; + op.reserved = nullptr; + // Reuse the "OnBalancerMessageReceivedLocked" ref taken in StartQuery(). 
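+  // (Ref discipline: StartQuery() took explicit refs for the initial-request
+  // and message-received callbacks; the message-received ref is the one
+  // re-used here for every new RECV_MESSAGE op, while the status callback
+  // relies on the call state's initial ref, which is released in
+  // OnBalancerStatusReceivedLocked().)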
+ const grpc_call_error call_error = grpc_call_start_batch_and_execute( + lb_call_, &op, 1, &lb_on_balancer_message_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + } else { + Unref(DEBUG_LOCATION, "on_message_received+grpclb_shutdown"); + } +} + +void GrpcLb::BalancerCallState::OnBalancerStatusReceived( + void* arg, grpc_error_handle error) { + BalancerCallState* lb_calld = static_cast(arg); + (void)GRPC_ERROR_REF(error); // owned by lambda + lb_calld->grpclb_policy()->work_serializer()->Run( + [lb_calld, error]() { lb_calld->OnBalancerStatusReceivedLocked(error); }, + DEBUG_LOCATION); +} + +void GrpcLb::BalancerCallState::OnBalancerStatusReceivedLocked( + grpc_error_handle error) { + GPR_ASSERT(lb_call_ != nullptr); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + char* status_details = grpc_slice_to_c_string(lb_call_status_details_); + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Status from LB server received. " + "Status = %d, details = '%s', (lb_call: %p), error '%s'", + grpclb_policy(), this, lb_call_status_, status_details, lb_call_, + grpc_error_std_string(error).c_str()); + gpr_free(status_details); + } + GRPC_ERROR_UNREF(error); + // If this lb_calld is still in use, this call ended because of a failure so + // we want to retry connecting. Otherwise, we have deliberately ended this + // call and no further action is required. + if (this == grpclb_policy()->lb_calld_.get()) { + // If the fallback-at-startup checks are pending, go into fallback mode + // immediately. This short-circuits the timeout for the fallback-at-startup + // case. + if (grpclb_policy()->fallback_at_startup_checks_pending_) { + GPR_ASSERT(!seen_serverlist_); + gpr_log(GPR_INFO, + "[grpclb %p] Balancer call finished without receiving " + "serverlist; entering fallback mode", + grpclb_policy()); + grpclb_policy()->fallback_at_startup_checks_pending_ = false; + grpc_timer_cancel(&grpclb_policy()->lb_fallback_timer_); + grpclb_policy()->CancelBalancerChannelConnectivityWatchLocked(); + grpclb_policy()->fallback_mode_ = true; + grpclb_policy()->CreateOrUpdateChildPolicyLocked(); + } else { + // This handles the fallback-after-startup case. + grpclb_policy()->MaybeEnterFallbackModeAfterStartup(); + } + grpclb_policy()->lb_calld_.reset(); + GPR_ASSERT(!grpclb_policy()->shutting_down_); + grpclb_policy()->channel_control_helper()->RequestReresolution(); + if (seen_initial_response_) { + // If we lose connection to the LB server, reset the backoff and restart + // the LB call immediately. + grpclb_policy()->lb_call_backoff_.Reset(); + grpclb_policy()->StartBalancerCallLocked(); + } else { + // If this LB call fails establishing any connection to the LB server, + // retry later. + grpclb_policy()->StartBalancerCallRetryTimerLocked(); + } + } + Unref(DEBUG_LOCATION, "lb_call_ended"); +} + +// +// helper code for creating balancer channel +// + +ServerAddressList ExtractBalancerAddresses(const grpc_channel_args& args) { + const ServerAddressList* addresses = + FindGrpclbBalancerAddressesInChannelArgs(args); + if (addresses != nullptr) return *addresses; + return ServerAddressList(); +} + +/* Returns the channel args for the LB channel, used to create a bidirectional + * stream for the reception of load balancing updates. + * + * Inputs: + * - \a response_generator: in order to propagate updates from the resolver + * above the grpclb policy. + * - \a args: other args inherited from the grpclb policy. 
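Editor's note: when the balancer call ends while it is still the active call, the policy either restarts the call immediately (if the balancer ever responded, the backoff is reset) or schedules a retry after a backoff delay. A minimal, self-contained sketch of that decision, assuming a conventional exponential backoff with jitter rather than gRPC's actual BackOff class:

#include <algorithm>
#include <cstdint>
#include <random>

class Backoff {
 public:
  Backoff(int64_t initial_ms, double multiplier, double jitter, int64_t max_ms)
      : initial_ms_(initial_ms), multiplier_(multiplier), jitter_(jitter),
        max_ms_(max_ms), current_ms_(initial_ms) {}

  // Next delay in milliseconds, growing exponentially with +/- jitter.
  int64_t NextDelayMs() {
    std::uniform_real_distribution<double> dist(1.0 - jitter_, 1.0 + jitter_);
    int64_t delay = static_cast<int64_t>(current_ms_ * dist(rng_));
    current_ms_ = std::min<int64_t>(
        max_ms_, static_cast<int64_t>(current_ms_ * multiplier_));
    return delay;
  }

  void Reset() { current_ms_ = initial_ms_; }

 private:
  int64_t initial_ms_;
  double multiplier_;
  double jitter_;
  int64_t max_ms_;
  int64_t current_ms_;
  std::mt19937 rng_{std::random_device{}()};
};

// If the balancer spoke to us at least once, reconnect immediately with a
// fresh backoff; otherwise wait out the current backoff delay first.
int64_t MillisUntilNextBalancerCall(bool seen_initial_response, Backoff& backoff) {
  if (seen_initial_response) {
    backoff.Reset();
    return 0;
  }
  return backoff.NextDelayMs();
}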
*/ +grpc_channel_args* BuildBalancerChannelArgs( + FakeResolverResponseGenerator* response_generator, + const grpc_channel_args* args) { + // Channel args to remove. + static const char* args_to_remove[] = { + // LB policy name, since we want to use the default (pick_first) in + // the LB channel. + GRPC_ARG_LB_POLICY_NAME, + // Strip out the service config, since we don't want the LB policy + // config specified for the parent channel to affect the LB channel. + GRPC_ARG_SERVICE_CONFIG, + // The channel arg for the server URI, since that will be different for + // the LB channel than for the parent channel. The client channel + // factory will re-add this arg with the right value. + GRPC_ARG_SERVER_URI, + // The fake resolver response generator, because we are replacing it + // with the one from the grpclb policy, used to propagate updates to + // the LB channel. + GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + // The LB channel should use the authority indicated by the target + // authority table (see \a ModifyGrpclbBalancerChannelArgs), + // as opposed to the authority from the parent channel. + GRPC_ARG_DEFAULT_AUTHORITY, + // Just as for \a GRPC_ARG_DEFAULT_AUTHORITY, the LB channel should be + // treated as a stand-alone channel and not inherit this argument from the + // args of the parent channel. + GRPC_SSL_TARGET_NAME_OVERRIDE_ARG, + // Don't want to pass down channelz node from parent; the balancer + // channel will get its own. + GRPC_ARG_CHANNELZ_CHANNEL_NODE, + }; + // Channel args to add. + absl::InlinedVector args_to_add = { + // The fake resolver response generator, which we use to inject + // address updates into the LB channel. + grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + response_generator), + // A channel arg indicating the target is a grpclb load balancer. + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER), 1), + // Tells channelz that this is an internal channel. + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_CHANNELZ_IS_INTERNAL_CHANNEL), 1), + }; + // Construct channel args. + grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( + args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), args_to_add.data(), + args_to_add.size()); + // Make any necessary modifications for security. 
+ return ModifyGrpclbBalancerChannelArgs(new_args); +} + +// +// ctor and dtor +// + +std::string GetServerNameFromChannelArgs(const grpc_channel_args* args) { + const char* server_uri = + grpc_channel_args_find_string(args, GRPC_ARG_SERVER_URI); + GPR_ASSERT(server_uri != nullptr); + absl::StatusOr uri = URI::Parse(server_uri); + GPR_ASSERT(uri.ok() && !uri->path().empty()); + return std::string(absl::StripPrefix(uri->path(), "/")); +} + +GrpcLb::GrpcLb(Args args) + : LoadBalancingPolicy(std::move(args)), + server_name_(GetServerNameFromChannelArgs(args.args)), + response_generator_(MakeRefCounted()), + lb_call_timeout_ms_(grpc_channel_args_find_integer( + args.args, GRPC_ARG_GRPCLB_CALL_TIMEOUT_MS, {0, 0, INT_MAX})), + lb_call_backoff_( + BackOff::Options() + .set_initial_backoff(GRPC_GRPCLB_INITIAL_CONNECT_BACKOFF_SECONDS * + 1000) + .set_multiplier(GRPC_GRPCLB_RECONNECT_BACKOFF_MULTIPLIER) + .set_jitter(GRPC_GRPCLB_RECONNECT_JITTER) + .set_max_backoff(GRPC_GRPCLB_RECONNECT_MAX_BACKOFF_SECONDS * + 1000)), + fallback_at_startup_timeout_(grpc_channel_args_find_integer( + args.args, GRPC_ARG_GRPCLB_FALLBACK_TIMEOUT_MS, + {GRPC_GRPCLB_DEFAULT_FALLBACK_TIMEOUT_MS, 0, INT_MAX})), + subchannel_cache_interval_ms_(grpc_channel_args_find_integer( + args.args, GRPC_ARG_GRPCLB_SUBCHANNEL_CACHE_INTERVAL_MS, + {GRPC_GRPCLB_DEFAULT_SUBCHANNEL_DELETION_DELAY_MS, 0, INT_MAX})) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p] Will use '%s' as the server name for LB request.", + this, server_name_.c_str()); + } + // Closure Initialization + GRPC_CLOSURE_INIT(&lb_on_fallback_, &GrpcLb::OnFallbackTimer, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&lb_on_call_retry_, &GrpcLb::OnBalancerCallRetryTimer, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_subchannel_cache_timer_, &OnSubchannelCacheTimer, this, + nullptr); +} + +GrpcLb::~GrpcLb() { grpc_channel_args_destroy(args_); } + +void GrpcLb::ShutdownLocked() { + shutting_down_ = true; + lb_calld_.reset(); + if (subchannel_cache_timer_pending_) { + subchannel_cache_timer_pending_ = false; + grpc_timer_cancel(&subchannel_cache_timer_); + } + cached_subchannels_.clear(); + if (retry_timer_callback_pending_) { + grpc_timer_cancel(&lb_call_retry_timer_); + } + if (fallback_at_startup_checks_pending_) { + fallback_at_startup_checks_pending_ = false; + grpc_timer_cancel(&lb_fallback_timer_); + CancelBalancerChannelConnectivityWatchLocked(); + } + if (child_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(), + interested_parties()); + child_policy_.reset(); + } + // We destroy the LB channel here instead of in our destructor because + // destroying the channel triggers a last callback to + // OnBalancerChannelConnectivityChangedLocked(), and we need to be + // alive when that callback is invoked. 
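Editor's note: GetServerNameFromChannelArgs() above reduces to "take the URI path of the channel target and strip its leading slash". A stand-alone sketch using plain std::string parsing (the real code goes through grpc_core::URI; names here are illustrative):

#include <string>

std::string ServerNameFromTarget(const std::string& target_uri) {
  // Skip a "<scheme>://<authority>" prefix if present.
  std::string::size_type pos = target_uri.find("://");
  std::string path;
  if (pos == std::string::npos) {
    path = target_uri;  // no scheme: treat the whole string as the path
  } else {
    std::string::size_type path_start = target_uri.find('/', pos + 3);
    path = (path_start == std::string::npos) ? "" : target_uri.substr(path_start);
  }
  while (!path.empty() && path.front() == '/') path.erase(path.begin());
  return path;
}

// e.g. ServerNameFromTarget("dns:///lb.example.com:443") == "lb.example.com:443"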
+ if (lb_channel_ != nullptr) { + if (parent_channelz_node_ != nullptr) { + channelz::ChannelNode* child_channelz_node = + grpc_channel_get_channelz_node(lb_channel_); + GPR_ASSERT(child_channelz_node != nullptr); + parent_channelz_node_->RemoveChildChannel(child_channelz_node->uuid()); + } + grpc_channel_destroy(lb_channel_); + lb_channel_ = nullptr; + } +} + +// +// public methods +// + +void GrpcLb::ResetBackoffLocked() { + if (lb_channel_ != nullptr) { + grpc_channel_reset_connect_backoff(lb_channel_); + } + if (child_policy_ != nullptr) { + child_policy_->ResetBackoffLocked(); + } +} + +void GrpcLb::UpdateLocked(UpdateArgs args) { + const bool is_initial_update = lb_channel_ == nullptr; + config_ = args.config; + GPR_ASSERT(config_ != nullptr); + ProcessAddressesAndChannelArgsLocked(args.addresses, *args.args); + // Update the existing child policy. + if (child_policy_ != nullptr) CreateOrUpdateChildPolicyLocked(); + // If this is the initial update, start the fallback-at-startup checks + // and the balancer call. + if (is_initial_update) { + fallback_at_startup_checks_pending_ = true; + // Start timer. + grpc_millis deadline = ExecCtx::Get()->Now() + fallback_at_startup_timeout_; + Ref(DEBUG_LOCATION, "on_fallback_timer").release(); // Ref for callback + grpc_timer_init(&lb_fallback_timer_, deadline, &lb_on_fallback_); + // Start watching the channel's connectivity state. If the channel + // goes into state TRANSIENT_FAILURE before the timer fires, we go into + // fallback mode even if the fallback timeout has not elapsed. + ClientChannel* client_channel = ClientChannel::GetFromChannel(lb_channel_); + GPR_ASSERT(client_channel != nullptr); + // Ref held by callback. + watcher_ = new StateWatcher(Ref(DEBUG_LOCATION, "StateWatcher")); + client_channel->AddConnectivityWatcher( + GRPC_CHANNEL_IDLE, + OrphanablePtr(watcher_)); + // Start balancer call. + StartBalancerCallLocked(); + } +} + +// +// helpers for UpdateLocked() +// + +ServerAddressList GrpcLb::AddNullLbTokenToAddresses( + const ServerAddressList& addresses) { + ServerAddressList addresses_out; + for (const ServerAddress& address : addresses) { + addresses_out.emplace_back(address.WithAttribute( + kGrpcLbAddressAttributeKey, + absl::make_unique("", nullptr))); + } + return addresses_out; +} + +void GrpcLb::ProcessAddressesAndChannelArgsLocked( + const ServerAddressList& addresses, const grpc_channel_args& args) { + // Update fallback address list. + fallback_backend_addresses_ = AddNullLbTokenToAddresses(addresses); + // Make sure that GRPC_ARG_LB_POLICY_NAME is set in channel args, + // since we use this to trigger the client_load_reporting filter. + static const char* args_to_remove[] = {GRPC_ARG_LB_POLICY_NAME}; + grpc_arg new_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_LB_POLICY_NAME), const_cast("grpclb")); + grpc_channel_args_destroy(args_); + args_ = grpc_channel_args_copy_and_add_and_remove( + &args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), &new_arg, 1); + // Construct args for balancer channel. + ServerAddressList balancer_addresses = ExtractBalancerAddresses(args); + grpc_channel_args* lb_channel_args = + BuildBalancerChannelArgs(response_generator_.get(), &args); + // Create balancer channel if needed. + if (lb_channel_ == nullptr) { + std::string uri_str = absl::StrCat("fake:///", server_name_); + lb_channel_ = + CreateGrpclbBalancerChannel(uri_str.c_str(), *lb_channel_args); + GPR_ASSERT(lb_channel_ != nullptr); + // Set up channelz linkage. 
+ channelz::ChannelNode* child_channelz_node = + grpc_channel_get_channelz_node(lb_channel_); + channelz::ChannelNode* parent_channelz_node = + grpc_channel_args_find_pointer( + &args, GRPC_ARG_CHANNELZ_CHANNEL_NODE); + if (child_channelz_node != nullptr && parent_channelz_node != nullptr) { + parent_channelz_node->AddChildChannel(child_channelz_node->uuid()); + parent_channelz_node_ = parent_channelz_node->Ref(); + } + } + // Propagate updates to the LB channel (pick_first) through the fake + // resolver. + Resolver::Result result; + result.addresses = std::move(balancer_addresses); + result.args = lb_channel_args; + response_generator_->SetResponse(std::move(result)); +} + +void GrpcLb::CancelBalancerChannelConnectivityWatchLocked() { + ClientChannel* client_channel = ClientChannel::GetFromChannel(lb_channel_); + GPR_ASSERT(client_channel != nullptr); + client_channel->RemoveConnectivityWatcher(watcher_); +} + +// +// code for balancer channel and call +// + +void GrpcLb::StartBalancerCallLocked() { + GPR_ASSERT(lb_channel_ != nullptr); + if (shutting_down_) return; + // Init the LB call data. + GPR_ASSERT(lb_calld_ == nullptr); + lb_calld_ = MakeOrphanable(Ref()); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p] Query for backends (lb_channel: %p, lb_calld: %p)", + this, lb_channel_, lb_calld_.get()); + } + lb_calld_->StartQuery(); +} + +void GrpcLb::StartBalancerCallRetryTimerLocked() { + grpc_millis next_try = lb_call_backoff_.NextAttemptTime(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, "[grpclb %p] Connection to LB server lost...", this); + grpc_millis timeout = next_try - ExecCtx::Get()->Now(); + if (timeout > 0) { + gpr_log(GPR_INFO, "[grpclb %p] ... retry_timer_active in %" PRId64 "ms.", + this, timeout); + } else { + gpr_log(GPR_INFO, "[grpclb %p] ... retry_timer_active immediately.", + this); + } + } + // TODO(roth): We currently track this ref manually. Once the + // ClosureRef API is ready, we should pass the RefCountedPtr<> along + // with the callback. + auto self = Ref(DEBUG_LOCATION, "on_balancer_call_retry_timer"); + self.release(); + retry_timer_callback_pending_ = true; + grpc_timer_init(&lb_call_retry_timer_, next_try, &lb_on_call_retry_); +} + +void GrpcLb::OnBalancerCallRetryTimer(void* arg, grpc_error_handle error) { + GrpcLb* grpclb_policy = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + grpclb_policy->work_serializer()->Run( + [grpclb_policy, error]() { + grpclb_policy->OnBalancerCallRetryTimerLocked(error); + }, + DEBUG_LOCATION); +} + +void GrpcLb::OnBalancerCallRetryTimerLocked(grpc_error_handle error) { + retry_timer_callback_pending_ = false; + if (!shutting_down_ && error == GRPC_ERROR_NONE && lb_calld_ == nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, "[grpclb %p] Restarting call to LB server", this); + } + StartBalancerCallLocked(); + } + Unref(DEBUG_LOCATION, "on_balancer_call_retry_timer"); + GRPC_ERROR_UNREF(error); +} + +// +// code for handling fallback mode +// + +void GrpcLb::MaybeEnterFallbackModeAfterStartup() { + // Enter fallback mode if all of the following are true: + // - We are not currently in fallback mode. + // - We are not currently waiting for the initial fallback timeout. + // - We are not currently in contact with the balancer. + // - The child policy is not in state READY. 
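Editor's note: the four conditions listed above can be read as a single predicate. This is only a restatement with hypothetical names, not new behavior:

struct FallbackInputs {
  bool fallback_mode;
  bool fallback_at_startup_checks_pending;
  bool balancer_call_has_serverlist;  // lb_calld_ != nullptr && seen_serverlist()
  bool child_policy_ready;
};

bool ShouldEnterFallbackAfterStartup(const FallbackInputs& in) {
  return !in.fallback_mode &&
         !in.fallback_at_startup_checks_pending &&
         !in.balancer_call_has_serverlist &&
         !in.child_policy_ready;
}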
+ if (!fallback_mode_ && !fallback_at_startup_checks_pending_ && + (lb_calld_ == nullptr || !lb_calld_->seen_serverlist()) && + !child_policy_ready_) { + gpr_log(GPR_INFO, + "[grpclb %p] lost contact with balancer and backends from " + "most recent serverlist; entering fallback mode", + this); + fallback_mode_ = true; + CreateOrUpdateChildPolicyLocked(); + } +} + +void GrpcLb::OnFallbackTimer(void* arg, grpc_error_handle error) { + GrpcLb* grpclb_policy = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + grpclb_policy->work_serializer()->Run( + [grpclb_policy, error]() { grpclb_policy->OnFallbackTimerLocked(error); }, + DEBUG_LOCATION); +} + +void GrpcLb::OnFallbackTimerLocked(grpc_error_handle error) { + // If we receive a serverlist after the timer fires but before this callback + // actually runs, don't fall back. + if (fallback_at_startup_checks_pending_ && !shutting_down_ && + error == GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, + "[grpclb %p] No response from balancer after fallback timeout; " + "entering fallback mode", + this); + fallback_at_startup_checks_pending_ = false; + CancelBalancerChannelConnectivityWatchLocked(); + fallback_mode_ = true; + CreateOrUpdateChildPolicyLocked(); + } + Unref(DEBUG_LOCATION, "on_fallback_timer"); + GRPC_ERROR_UNREF(error); +} + +// +// code for interacting with the child policy +// + +grpc_channel_args* GrpcLb::CreateChildPolicyArgsLocked( + bool is_backend_from_grpclb_load_balancer) { + absl::InlinedVector args_to_add; + args_to_add.emplace_back(grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ADDRESS_IS_BACKEND_FROM_GRPCLB_LOAD_BALANCER), + is_backend_from_grpclb_load_balancer)); + if (is_backend_from_grpclb_load_balancer) { + args_to_add.emplace_back(grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_INHIBIT_HEALTH_CHECKING), 1)); + } + return grpc_channel_args_copy_and_add(args_, args_to_add.data(), + args_to_add.size()); +} + +OrphanablePtr GrpcLb::CreateChildPolicyLocked( + const grpc_channel_args* args) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = work_serializer(); + lb_policy_args.args = args; + lb_policy_args.channel_control_helper = absl::make_unique(Ref()); + OrphanablePtr lb_policy = + MakeOrphanable(std::move(lb_policy_args), + &grpc_lb_glb_trace); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, "[grpclb %p] Created new child policy handler (%p)", this, + lb_policy.get()); + } + // Add the gRPC LB's interested_parties pollset_set to that of the newly + // created child policy. This will make the child policy progress upon + // activity on gRPC LB, which in turn is tied to the application's call. + grpc_pollset_set_add_pollset_set(lb_policy->interested_parties(), + interested_parties()); + return lb_policy; +} + +void GrpcLb::CreateOrUpdateChildPolicyLocked() { + if (shutting_down_) return; + // Construct update args. + UpdateArgs update_args; + bool is_backend_from_grpclb_load_balancer = false; + if (fallback_mode_) { + // If CreateOrUpdateChildPolicyLocked() is invoked when we haven't + // received any serverlist from the balancer, we use the fallback backends + // returned by the resolver. Note that the fallback backend list may be + // empty, in which case the new round_robin policy will keep the requested + // picks pending. + update_args.addresses = fallback_backend_addresses_; + } else { + update_args.addresses = serverlist_->GetServerAddressList( + lb_calld_ == nullptr ? 
nullptr : lb_calld_->client_stats()); + is_backend_from_grpclb_load_balancer = true; + } + update_args.args = + CreateChildPolicyArgsLocked(is_backend_from_grpclb_load_balancer); + GPR_ASSERT(update_args.args != nullptr); + update_args.config = config_->child_policy(); + // Create child policy if needed. + if (child_policy_ == nullptr) { + child_policy_ = CreateChildPolicyLocked(update_args.args); + } + // Update the policy. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, "[grpclb %p] Updating child policy handler %p", this, + child_policy_.get()); + } + child_policy_->UpdateLocked(std::move(update_args)); +} + +// +// subchannel caching +// + +void GrpcLb::CacheDeletedSubchannelLocked( + RefCountedPtr subchannel) { + grpc_millis deletion_time = + ExecCtx::Get()->Now() + subchannel_cache_interval_ms_; + cached_subchannels_[deletion_time].push_back(std::move(subchannel)); + if (!subchannel_cache_timer_pending_) { + Ref(DEBUG_LOCATION, "OnSubchannelCacheTimer").release(); + subchannel_cache_timer_pending_ = true; + StartSubchannelCacheTimerLocked(); + } +} + +void GrpcLb::StartSubchannelCacheTimerLocked() { + GPR_ASSERT(!cached_subchannels_.empty()); + grpc_timer_init(&subchannel_cache_timer_, cached_subchannels_.begin()->first, + &on_subchannel_cache_timer_); +} + +void GrpcLb::OnSubchannelCacheTimer(void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); + self->work_serializer()->Run( + [self, error]() { self->GrpcLb::OnSubchannelCacheTimerLocked(error); }, + DEBUG_LOCATION); +} + +void GrpcLb::OnSubchannelCacheTimerLocked(grpc_error_handle error) { + if (subchannel_cache_timer_pending_ && error == GRPC_ERROR_NONE) { + auto it = cached_subchannels_.begin(); + if (it != cached_subchannels_.end()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_glb_trace)) { + gpr_log(GPR_INFO, + "[grpclb %p] removing %" PRIuPTR " subchannels from cache", + this, it->second.size()); + } + cached_subchannels_.erase(it); + } + if (!cached_subchannels_.empty()) { + StartSubchannelCacheTimerLocked(); + return; + } + subchannel_cache_timer_pending_ = false; + } + Unref(DEBUG_LOCATION, "OnSubchannelCacheTimer"); + GRPC_ERROR_UNREF(error); +} + +// +// factory +// + +class GrpcLbFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kGrpclb; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (json.type() == Json::Type::JSON_NULL) { + return MakeRefCounted(nullptr, ""); + } + std::vector error_list; + Json child_policy_config_json_tmp; + const Json* child_policy_config_json; + std::string service_name; + auto it = json.object_value().find("serviceName"); + if (it != json.object_value().end()) { + const Json& service_name_json = it->second; + if (service_name_json.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:serviceName error:type should be string")); + } else { + service_name = service_name_json.string_value(); + } + } + it = json.object_value().find("childPolicy"); + if (it == json.object_value().end()) { + child_policy_config_json_tmp = Json::Array{Json::Object{ + {"round_robin", Json::Object()}, + }}; + child_policy_config_json = &child_policy_config_json_tmp; + } else { + 
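Editor's note: the subchannel cache above groups deleted subchannels into buckets keyed by deletion time and keeps a single timer armed for the earliest bucket. Below is a simplified, self-contained model of that bookkeeping; the real code erases one bucket per timer firing and then re-arms, and the names here are hypothetical.

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

template <typename T>
class DelayedReleaseCache {
 public:
  explicit DelayedReleaseCache(int64_t retention_ms) : retention_ms_(retention_ms) {}

  // Adds an item and returns the deletion time of the earliest bucket, i.e.
  // the deadline the caller should (re)arm its timer for.
  int64_t Add(int64_t now_ms, std::shared_ptr<T> item) {
    cache_[now_ms + retention_ms_].push_back(std::move(item));
    return cache_.begin()->first;
  }

  // Drops every bucket whose deletion time has passed. Returns the next
  // deadline, or -1 if the cache is empty and the timer can stop.
  int64_t OnTimer(int64_t now_ms) {
    while (!cache_.empty() && cache_.begin()->first <= now_ms) {
      cache_.erase(cache_.begin());
    }
    return cache_.empty() ? -1 : cache_.begin()->first;
  }

 private:
  const int64_t retention_ms_;
  std::map<int64_t, std::vector<std::shared_ptr<T>>> cache_;
};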
child_policy_config_json = &it->second; + } + grpc_error_handle parse_error = GRPC_ERROR_NONE; + RefCountedPtr child_policy_config = + LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + *child_policy_config_json, &parse_error); + if (parse_error != GRPC_ERROR_NONE) { + std::vector child_errors; + child_errors.push_back(parse_error); + error_list.push_back( + GRPC_ERROR_CREATE_FROM_VECTOR("field:childPolicy", &child_errors)); + } + if (error_list.empty()) { + return MakeRefCounted(std::move(child_policy_config), + std::move(service_name)); + } else { + *error = GRPC_ERROR_CREATE_FROM_VECTOR("GrpcLb Parser", &error_list); + return nullptr; + } + } +}; + +} // namespace + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_lb_policy_grpclb_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_grpclb_shutdown() {} + +namespace grpc_core { +void RegisterGrpcLbLoadReportingFilter(CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage( + GRPC_CLIENT_SUBCHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [](grpc_channel_stack_builder* builder) { + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + const grpc_arg* channel_arg = + grpc_channel_args_find(args, GRPC_ARG_LB_POLICY_NAME); + if (channel_arg != nullptr && channel_arg->type == GRPC_ARG_STRING && + strcmp(channel_arg->value.string, "grpclb") == 0) { + // TODO(roth): When we get around to re-attempting + // https://github.com/grpc/grpc/pull/16214, we should try to keep + // this filter at the very top of the subchannel stack, since that + // will minimize the number of metadata elements that the filter + // needs to iterate through to find the ClientStats object. + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_client_load_reporting_filter, nullptr, nullptr); + } + return true; + }); +} +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.cc new file mode 100644 index 00000000..8168d156 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.cc @@ -0,0 +1,76 @@ +// +// Copyright 2019 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/useful.h" + +// Channel arg key for the list of balancer addresses. 
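Editor's note: for reference, a config object that the parser above would accept might look like the following. The shape is assumed from the fields the parser reads ("serviceName" as a string, "childPolicy" as an LB config list); when "childPolicy" is omitted, the parser falls back to round_robin.

constexpr char kExampleGrpclbConfig[] = R"json(
{
  "serviceName": "lb.example.com",
  "childPolicy": [ { "round_robin": {} } ]
}
)json";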
+#define GRPC_ARG_GRPCLB_BALANCER_ADDRESSES "grpc.grpclb_balancer_addresses" + +namespace grpc_core { + +namespace { + +void* BalancerAddressesArgCopy(void* p) { + ServerAddressList* address_list = static_cast(p); + return new ServerAddressList(*address_list); +} + +void BalancerAddressesArgDestroy(void* p) { + ServerAddressList* address_list = static_cast(p); + delete address_list; +} + +int BalancerAddressesArgCmp(void* p, void* q) { + ServerAddressList* address_list1 = static_cast(p); + ServerAddressList* address_list2 = static_cast(q); + if (address_list1 == nullptr || address_list2 == nullptr) { + return QsortCompare(address_list1, address_list2); + } + if (address_list1->size() > address_list2->size()) return 1; + if (address_list1->size() < address_list2->size()) return -1; + for (size_t i = 0; i < address_list1->size(); ++i) { + int retval = (*address_list1)[i].Cmp((*address_list2)[i]); + if (retval != 0) return retval; + } + return 0; +} + +const grpc_arg_pointer_vtable kBalancerAddressesArgVtable = { + BalancerAddressesArgCopy, BalancerAddressesArgDestroy, + BalancerAddressesArgCmp}; + +} // namespace + +grpc_arg CreateGrpclbBalancerAddressesArg( + const ServerAddressList* address_list) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_GRPCLB_BALANCER_ADDRESSES), + const_cast(address_list), + &kBalancerAddressesArgVtable); +} + +const ServerAddressList* FindGrpclbBalancerAddressesInChannelArgs( + const grpc_channel_args& args) { + return grpc_channel_args_find_pointer( + &args, const_cast(GRPC_ARG_GRPCLB_BALANCER_ADDRESSES)); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.cc new file mode 100644 index 00000000..b8b4889e --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.cc @@ -0,0 +1,36 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.h" + +#include + +namespace grpc_core { + +grpc_channel_args* ModifyGrpclbBalancerChannelArgs(grpc_channel_args* args) { + return args; +} + +grpc_channel* CreateGrpclbBalancerChannel(const char* target_uri, + const grpc_channel_args& args) { + return grpc_insecure_channel_create(target_uri, &args, nullptr); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc new file mode 100644 index 00000000..a4a67ea1 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc @@ -0,0 +1,83 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "absl/container/inlined_vector.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +grpc_channel_args* ModifyGrpclbBalancerChannelArgs(grpc_channel_args* args) { + absl::InlinedVector args_to_remove; + absl::InlinedVector args_to_add; + // Substitute the channel credentials with a version without call + // credentials: the load balancer is not necessarily trusted to handle + // bearer token credentials. + grpc_channel_credentials* channel_credentials = + grpc_channel_credentials_find_in_args(args); + RefCountedPtr creds_sans_call_creds; + if (channel_credentials != nullptr) { + creds_sans_call_creds = + channel_credentials->duplicate_without_call_credentials(); + GPR_ASSERT(creds_sans_call_creds != nullptr); + args_to_remove.emplace_back(GRPC_ARG_CHANNEL_CREDENTIALS); + args_to_add.emplace_back( + grpc_channel_credentials_to_arg(creds_sans_call_creds.get())); + } + grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove( + args, args_to_remove.data(), args_to_remove.size(), args_to_add.data(), + args_to_add.size()); + // Clean up. + grpc_channel_args_destroy(args); + return result; +} + +grpc_channel* CreateGrpclbBalancerChannel(const char* target_uri, + const grpc_channel_args& args) { + grpc_channel_credentials* creds = + grpc_channel_credentials_find_in_args(&args); + if (creds == nullptr) { + // Build with security but parent channel is insecure. + return grpc_insecure_channel_create(target_uri, &args, nullptr); + } + const char* arg_to_remove = GRPC_ARG_CHANNEL_CREDENTIALS; + grpc_channel_args* new_args = + grpc_channel_args_copy_and_remove(&args, &arg_to_remove, 1); + grpc_channel* channel = + grpc_secure_channel_create(creds, target_uri, new_args, nullptr); + grpc_channel_args_destroy(new_args); + return channel; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.cc new file mode 100644 index 00000000..e7b5d288 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.cc @@ -0,0 +1,93 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h" + +#include + +#include "absl/memory/memory.h" + +#include +#include + +#include "src/core/lib/gprpp/sync.h" + +namespace grpc_core { + +void GrpcLbClientStats::AddCallStarted() { + gpr_atm_full_fetch_add(&num_calls_started_, (gpr_atm)1); +} + +void GrpcLbClientStats::AddCallFinished( + bool finished_with_client_failed_to_send, bool finished_known_received) { + gpr_atm_full_fetch_add(&num_calls_finished_, (gpr_atm)1); + if (finished_with_client_failed_to_send) { + gpr_atm_full_fetch_add(&num_calls_finished_with_client_failed_to_send_, + (gpr_atm)1); + } + if (finished_known_received) { + gpr_atm_full_fetch_add(&num_calls_finished_known_received_, (gpr_atm)1); + } +} + +void GrpcLbClientStats::AddCallDropped(const char* token) { + // Increment num_calls_started and num_calls_finished. + gpr_atm_full_fetch_add(&num_calls_started_, (gpr_atm)1); + gpr_atm_full_fetch_add(&num_calls_finished_, (gpr_atm)1); + // Record the drop. + MutexLock lock(&drop_count_mu_); + if (drop_token_counts_ == nullptr) { + drop_token_counts_ = absl::make_unique(); + } + for (size_t i = 0; i < drop_token_counts_->size(); ++i) { + if (strcmp((*drop_token_counts_)[i].token.get(), token) == 0) { + ++(*drop_token_counts_)[i].count; + return; + } + } + // Not found, so add a new entry. + drop_token_counts_->emplace_back( + grpc_core::UniquePtr(gpr_strdup(token)), 1); +} + +namespace { + +void AtomicGetAndResetCounter(int64_t* value, gpr_atm* counter) { + *value = static_cast(gpr_atm_full_xchg(counter, (gpr_atm)0)); +} + +} // namespace + +void GrpcLbClientStats::Get( + int64_t* num_calls_started, int64_t* num_calls_finished, + int64_t* num_calls_finished_with_client_failed_to_send, + int64_t* num_calls_finished_known_received, + std::unique_ptr* drop_token_counts) { + AtomicGetAndResetCounter(num_calls_started, &num_calls_started_); + AtomicGetAndResetCounter(num_calls_finished, &num_calls_finished_); + AtomicGetAndResetCounter(num_calls_finished_with_client_failed_to_send, + &num_calls_finished_with_client_failed_to_send_); + AtomicGetAndResetCounter(num_calls_finished_known_received, + &num_calls_finished_known_received_); + MutexLock lock(&drop_count_mu_); + *drop_token_counts = std::move(drop_token_counts_); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.cc new file mode 100644 index 00000000..1d397686 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h" + +#include "google/protobuf/duration.upb.h" +#include "google/protobuf/timestamp.upb.h" + +#include + +namespace grpc_core { + +bool GrpcLbServer::operator==(const GrpcLbServer& other) const { + if (ip_size != other.ip_size) return false; + int r = memcmp(ip_addr, other.ip_addr, ip_size); + if (r != 0) return false; + if (port != other.port) return false; + r = strncmp(load_balance_token, other.load_balance_token, + sizeof(load_balance_token)); + if (r != 0) return false; + return drop == other.drop; +} + +namespace { + +grpc_slice grpc_grpclb_request_encode( + const grpc_lb_v1_LoadBalanceRequest* request, upb_arena* arena) { + size_t buf_length; + char* buf = + grpc_lb_v1_LoadBalanceRequest_serialize(request, arena, &buf_length); + return grpc_slice_from_copied_buffer(buf, buf_length); +} + +} // namespace + +grpc_slice GrpcLbRequestCreate(const char* lb_service_name, upb_arena* arena) { + grpc_lb_v1_LoadBalanceRequest* req = grpc_lb_v1_LoadBalanceRequest_new(arena); + grpc_lb_v1_InitialLoadBalanceRequest* initial_request = + grpc_lb_v1_LoadBalanceRequest_mutable_initial_request(req, arena); + size_t name_len = std::min(strlen(lb_service_name), + size_t(GRPC_GRPCLB_SERVICE_NAME_MAX_LENGTH)); + grpc_lb_v1_InitialLoadBalanceRequest_set_name( + initial_request, upb_strview_make(lb_service_name, name_len)); + return grpc_grpclb_request_encode(req, arena); +} + +namespace { + +void google_protobuf_Timestamp_assign(google_protobuf_Timestamp* timestamp, + const gpr_timespec& value) { + google_protobuf_Timestamp_set_seconds(timestamp, value.tv_sec); + google_protobuf_Timestamp_set_nanos(timestamp, value.tv_nsec); +} + +} // namespace + +grpc_slice GrpcLbLoadReportRequestCreate( + int64_t num_calls_started, int64_t num_calls_finished, + int64_t num_calls_finished_with_client_failed_to_send, + int64_t num_calls_finished_known_received, + const GrpcLbClientStats::DroppedCallCounts* drop_token_counts, + upb_arena* arena) { + grpc_lb_v1_LoadBalanceRequest* req = grpc_lb_v1_LoadBalanceRequest_new(arena); + grpc_lb_v1_ClientStats* req_stats = + grpc_lb_v1_LoadBalanceRequest_mutable_client_stats(req, arena); + google_protobuf_Timestamp_assign( + grpc_lb_v1_ClientStats_mutable_timestamp(req_stats, arena), + gpr_now(GPR_CLOCK_REALTIME)); + grpc_lb_v1_ClientStats_set_num_calls_started(req_stats, num_calls_started); + grpc_lb_v1_ClientStats_set_num_calls_finished(req_stats, num_calls_finished); + grpc_lb_v1_ClientStats_set_num_calls_finished_with_client_failed_to_send( + req_stats, num_calls_finished_with_client_failed_to_send); + grpc_lb_v1_ClientStats_set_num_calls_finished_known_received( + req_stats, num_calls_finished_known_received); + if (drop_token_counts != nullptr) { + for (size_t i = 0; i < drop_token_counts->size(); ++i) { + const GrpcLbClientStats::DropTokenCount& cur = (*drop_token_counts)[i]; + grpc_lb_v1_ClientStatsPerToken* cur_msg = + grpc_lb_v1_ClientStats_add_calls_finished_with_drop(req_stats, arena); + const size_t token_len = strlen(cur.token.get()); + char* 
token = reinterpret_cast(upb_arena_malloc(arena, token_len)); + memcpy(token, cur.token.get(), token_len); + grpc_lb_v1_ClientStatsPerToken_set_load_balance_token( + cur_msg, upb_strview_make(token, token_len)); + grpc_lb_v1_ClientStatsPerToken_set_num_calls(cur_msg, cur.count); + } + } + return grpc_grpclb_request_encode(req, arena); +} + +namespace { + +bool ParseServerList(const grpc_lb_v1_LoadBalanceResponse& response, + std::vector* server_list) { + // Determine the number of servers. + const grpc_lb_v1_ServerList* server_list_msg = + grpc_lb_v1_LoadBalanceResponse_server_list(&response); + if (server_list_msg == nullptr) return false; + size_t server_count = 0; + const grpc_lb_v1_Server* const* servers = + grpc_lb_v1_ServerList_servers(server_list_msg, &server_count); + // Populate servers. + if (server_count > 0) { + server_list->reserve(server_count); + for (size_t i = 0; i < server_count; ++i) { + GrpcLbServer& cur = *server_list->emplace(server_list->end()); + upb_strview address = grpc_lb_v1_Server_ip_address(servers[i]); + if (address.size == 0) { + ; // Nothing to do because cur->ip_address is an empty string. + } else if (address.size <= GRPC_GRPCLB_SERVER_IP_ADDRESS_MAX_SIZE) { + cur.ip_size = static_cast(address.size); + memcpy(cur.ip_addr, address.data, address.size); + } + cur.port = grpc_lb_v1_Server_port(servers[i]); + upb_strview token = grpc_lb_v1_Server_load_balance_token(servers[i]); + if (token.size == 0) { + ; // Nothing to do because cur->load_balance_token is an empty string. + } else if (token.size <= GRPC_GRPCLB_SERVER_LOAD_BALANCE_TOKEN_MAX_SIZE) { + memcpy(cur.load_balance_token, token.data, token.size); + } else { + gpr_log(GPR_ERROR, + "grpc_lb_v1_LoadBalanceResponse has too long token. len=%zu", + token.size); + } + cur.drop = grpc_lb_v1_Server_drop(servers[i]); + } + } + return true; +} + +grpc_millis grpc_grpclb_duration_to_millis( + const google_protobuf_Duration* duration_pb) { + return static_cast( + (google_protobuf_Duration_seconds(duration_pb) * GPR_MS_PER_SEC) + + (google_protobuf_Duration_nanos(duration_pb) / GPR_NS_PER_MS)); +} + +} // namespace + +bool GrpcLbResponseParse(const grpc_slice& serialized_response, + upb_arena* arena, GrpcLbResponse* result) { + grpc_lb_v1_LoadBalanceResponse* response = + grpc_lb_v1_LoadBalanceResponse_parse( + reinterpret_cast( + GRPC_SLICE_START_PTR(serialized_response)), + GRPC_SLICE_LENGTH(serialized_response), arena); + // Handle serverlist responses. + if (ParseServerList(*response, &result->serverlist)) { + result->type = result->SERVERLIST; + return true; + } + // Handle initial responses. + auto* initial_response = + grpc_lb_v1_LoadBalanceResponse_initial_response(response); + if (initial_response != nullptr) { + result->type = result->INITIAL; + const google_protobuf_Duration* client_stats_report_interval = + grpc_lb_v1_InitialLoadBalanceResponse_client_stats_report_interval( + initial_response); + if (client_stats_report_interval != nullptr) { + result->client_stats_report_interval = + grpc_grpclb_duration_to_millis(client_stats_report_interval); + } + return true; + } + // Handle fallback. + if (grpc_lb_v1_LoadBalanceResponse_has_fallback_response(response)) { + result->type = result->FALLBACK; + return true; + } + // Unknown response type. 
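Editor's note: grpc_grpclb_duration_to_millis() above is a straightforward unit conversion (GPR_MS_PER_SEC is 1000 and GPR_NS_PER_MS is 1000000). Restated in isolation with a worked example:

#include <cstdint>

constexpr int64_t DurationToMillis(int64_t seconds, int32_t nanos) {
  return seconds * 1000 + nanos / 1000000;
}

// e.g. a client_stats_report_interval of { seconds: 1, nanos: 500000000 }
// becomes 1500 ms.
static_assert(DurationToMillis(1, 500000000) == 1500, "1.5s == 1500ms");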
+ return false; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc new file mode 100644 index 00000000..ad03e6c0 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc @@ -0,0 +1,522 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/subchannel_list.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/filters/client_channel/subchannel.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_core { + +TraceFlag grpc_lb_pick_first_trace(false, "pick_first"); + +namespace { + +// +// pick_first LB policy +// + +constexpr char kPickFirst[] = "pick_first"; + +class PickFirst : public LoadBalancingPolicy { + public: + explicit PickFirst(Args args); + + const char* name() const override { return kPickFirst; } + + void UpdateLocked(UpdateArgs args) override; + void ExitIdleLocked() override; + void ResetBackoffLocked() override; + + private: + ~PickFirst() override; + + class PickFirstSubchannelList; + + class PickFirstSubchannelData + : public SubchannelData { + public: + PickFirstSubchannelData( + SubchannelList* + subchannel_list, + const ServerAddress& address, + RefCountedPtr subchannel) + : SubchannelData(subchannel_list, address, std::move(subchannel)) {} + + void ProcessConnectivityChangeLocked( + grpc_connectivity_state connectivity_state) override; + + // Processes the connectivity change to READY for an unselected subchannel. + void ProcessUnselectedReadyLocked(); + + void CheckConnectivityStateAndStartWatchingLocked(); + }; + + class PickFirstSubchannelList + : public SubchannelList { + public: + PickFirstSubchannelList(PickFirst* policy, TraceFlag* tracer, + ServerAddressList addresses, + const grpc_channel_args& args) + : SubchannelList(policy, tracer, std::move(addresses), + policy->channel_control_helper(), args) { + // Need to maintain a ref to the LB policy as long as we maintain + // any references to subchannels, since the subchannels' + // pollset_sets will include the LB policy's pollset_set. 
+ policy->Ref(DEBUG_LOCATION, "subchannel_list").release(); + } + + ~PickFirstSubchannelList() override { + PickFirst* p = static_cast(policy()); + p->Unref(DEBUG_LOCATION, "subchannel_list"); + } + + bool in_transient_failure() const { return in_transient_failure_; } + void set_in_transient_failure(bool in_transient_failure) { + in_transient_failure_ = in_transient_failure; + } + + private: + bool in_transient_failure_ = false; + }; + + class Picker : public SubchannelPicker { + public: + explicit Picker(RefCountedPtr subchannel) + : subchannel_(std::move(subchannel)) {} + + PickResult Pick(PickArgs /*args*/) override { + return PickResult::Complete(subchannel_); + } + + private: + RefCountedPtr subchannel_; + }; + + void ShutdownLocked() override; + + void AttemptToConnectUsingLatestUpdateArgsLocked(); + + // Lateset update args. + UpdateArgs latest_update_args_; + // All our subchannels. + OrphanablePtr subchannel_list_; + // Latest pending subchannel list. + OrphanablePtr latest_pending_subchannel_list_; + // Selected subchannel in \a subchannel_list_. + PickFirstSubchannelData* selected_ = nullptr; + // Are we in IDLE state? + bool idle_ = false; + // Are we shut down? + bool shutdown_ = false; +}; + +PickFirst::PickFirst(Args args) : LoadBalancingPolicy(std::move(args)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, "Pick First %p created.", this); + } +} + +PickFirst::~PickFirst() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, "Destroying Pick First %p", this); + } + GPR_ASSERT(subchannel_list_ == nullptr); + GPR_ASSERT(latest_pending_subchannel_list_ == nullptr); +} + +void PickFirst::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, "Pick First %p Shutting down", this); + } + shutdown_ = true; + subchannel_list_.reset(); + latest_pending_subchannel_list_.reset(); +} + +void PickFirst::ExitIdleLocked() { + if (shutdown_) return; + if (idle_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, "Pick First %p exiting idle", this); + } + idle_ = false; + AttemptToConnectUsingLatestUpdateArgsLocked(); + } +} + +void PickFirst::ResetBackoffLocked() { + if (subchannel_list_ != nullptr) subchannel_list_->ResetBackoffLocked(); + if (latest_pending_subchannel_list_ != nullptr) { + latest_pending_subchannel_list_->ResetBackoffLocked(); + } +} + +void PickFirst::AttemptToConnectUsingLatestUpdateArgsLocked() { + // Create a subchannel list from the latest_update_args_. + auto subchannel_list = MakeOrphanable( + this, &grpc_lb_pick_first_trace, latest_update_args_.addresses, + *latest_update_args_.args); + // Empty update or no valid subchannels. + if (subchannel_list->num_subchannels() == 0) { + // Unsubscribe from all current subchannels. + subchannel_list_ = std::move(subchannel_list); // Empty list. + selected_ = nullptr; + // If not idle, put the channel in TRANSIENT_FAILURE. + // (If we are idle, then this will happen in ExitIdleLocked() if we + // haven't gotten a non-empty update by the time the application tries + // to start a new call.) + absl::Status status = absl::UnavailableError("Empty update"); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + return; + } + // If one of the subchannels in the new list is already in state + // READY, then select it immediately. This can happen when the + // currently selected subchannel is also present in the update. 
It + // can also happen if one of the subchannels in the update is already + // in the global subchannel pool because it's in use by another channel. + for (size_t i = 0; i < subchannel_list->num_subchannels(); ++i) { + PickFirstSubchannelData* sd = subchannel_list->subchannel(i); + grpc_connectivity_state state = sd->CheckConnectivityStateLocked(); + if (state == GRPC_CHANNEL_READY) { + subchannel_list_ = std::move(subchannel_list); + sd->StartConnectivityWatchLocked(); + sd->ProcessUnselectedReadyLocked(); + // If there was a previously pending update (which may or may + // not have contained the currently selected subchannel), drop + // it, so that it doesn't override what we've done here. + latest_pending_subchannel_list_.reset(); + return; + } + } + if (selected_ == nullptr) { + // We don't yet have a selected subchannel, so replace the current + // subchannel list immediately. + subchannel_list_ = std::move(subchannel_list); + // If we're not in IDLE state, start trying to connect to the first + // subchannel in the new list. + // Note: No need to use CheckConnectivityStateAndStartWatchingLocked() + // here, since we've already checked the initial connectivity + // state of all subchannels above. + subchannel_list_->subchannel(0)->StartConnectivityWatchLocked(); + subchannel_list_->subchannel(0)->subchannel()->AttemptToConnect(); + } else { + // We do have a selected subchannel (which means it's READY), so keep + // using it until one of the subchannels in the new list reports READY. + if (latest_pending_subchannel_list_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, + "Pick First %p Shutting down latest pending subchannel list " + "%p, about to be replaced by newer latest %p", + this, latest_pending_subchannel_list_.get(), + subchannel_list.get()); + } + } + latest_pending_subchannel_list_ = std::move(subchannel_list); + // If we're not in IDLE state, start trying to connect to the first + // subchannel in the new list. + // Note: No need to use CheckConnectivityStateAndStartWatchingLocked() + // here, since we've already checked the initial connectivity + // state of all subchannels above. + latest_pending_subchannel_list_->subchannel(0) + ->StartConnectivityWatchLocked(); + latest_pending_subchannel_list_->subchannel(0) + ->subchannel() + ->AttemptToConnect(); + } +} + +void PickFirst::UpdateLocked(UpdateArgs args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, + "Pick First %p received update with %" PRIuPTR " addresses", this, + args.addresses.size()); + } + // Update the latest_update_args_ + grpc_arg new_arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_INHIBIT_HEALTH_CHECKING), 1); + const grpc_channel_args* new_args = + grpc_channel_args_copy_and_add(args.args, &new_arg, 1); + std::swap(new_args, args.args); + grpc_channel_args_destroy(new_args); + latest_update_args_ = std::move(args); + // If we are not in idle, start connection attempt immediately. + // Otherwise, we defer the attempt into ExitIdleLocked(). + if (!idle_) { + AttemptToConnectUsingLatestUpdateArgsLocked(); + } +} + +void PickFirst::PickFirstSubchannelData::ProcessConnectivityChangeLocked( + grpc_connectivity_state connectivity_state) { + PickFirst* p = static_cast(subchannel_list()->policy()); + // The notification must be for a subchannel in either the current or + // latest pending subchannel lists. 
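Editor's note: the scan above prefers a subchannel that is already READY before starting a fresh connection attempt at index 0. A tiny stand-alone restatement with a hypothetical enum:

#include <vector>

enum class ConnState { kIdle, kConnecting, kReady, kTransientFailure };

// Returns the index of a subchannel to select immediately, or -1 to begin
// connection attempts at index 0.
int PickAlreadyReady(const std::vector<ConnState>& states) {
  for (int i = 0; i < static_cast<int>(states.size()); ++i) {
    if (states[i] == ConnState::kReady) return i;
  }
  return -1;  // nothing READY yet
}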
+ GPR_ASSERT(subchannel_list() == p->subchannel_list_.get() || + subchannel_list() == p->latest_pending_subchannel_list_.get()); + GPR_ASSERT(connectivity_state != GRPC_CHANNEL_SHUTDOWN); + // Handle updates for the currently selected subchannel. + if (p->selected_ == this) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, + "Pick First %p selected subchannel connectivity changed to %s", p, + ConnectivityStateName(connectivity_state)); + } + // If the new state is anything other than READY and there is a + // pending update, switch to the pending update. + if (connectivity_state != GRPC_CHANNEL_READY && + p->latest_pending_subchannel_list_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, + "Pick First %p promoting pending subchannel list %p to " + "replace %p", + p, p->latest_pending_subchannel_list_.get(), + p->subchannel_list_.get()); + } + p->selected_ = nullptr; + CancelConnectivityWatchLocked( + "selected subchannel failed; switching to pending update"); + p->subchannel_list_ = std::move(p->latest_pending_subchannel_list_); + // Set our state to that of the pending subchannel list. + if (p->subchannel_list_->in_transient_failure()) { + absl::Status status = absl::UnavailableError( + "selected subchannel failed; switching to pending update"); + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + } else { + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + absl::make_unique( + p->Ref(DEBUG_LOCATION, "QueuePicker"))); + } + } else { + if (connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + // If the selected subchannel goes bad, request a re-resolution. We + // also set the channel state to IDLE. The reason is that if the new + // state is TRANSIENT_FAILURE due to a GOAWAY reception we don't want + // to connect to the re-resolved backends until we leave IDLE state. + // TODO(qianchengz): We may want to request re-resolution in + // ExitIdleLocked(). + p->idle_ = true; + p->channel_control_helper()->RequestReresolution(); + p->selected_ = nullptr; + p->subchannel_list_.reset(); + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_IDLE, absl::Status(), + absl::make_unique( + p->Ref(DEBUG_LOCATION, "QueuePicker"))); + } else { + // This is unlikely but can happen when a subchannel has been asked + // to reconnect by a different channel and this channel has dropped + // some connectivity state notifications. + if (connectivity_state == GRPC_CHANNEL_READY) { + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_READY, absl::Status(), + absl::make_unique(subchannel()->Ref())); + } else { // CONNECTING + p->channel_control_helper()->UpdateState( + connectivity_state, absl::Status(), + absl::make_unique( + p->Ref(DEBUG_LOCATION, "QueuePicker"))); + } + } + } + return; + } + // If we get here, there are two possible cases: + // 1. We do not currently have a selected subchannel, and the update is + // for a subchannel in p->subchannel_list_ that we're trying to + // connect to. The goal here is to find a subchannel that we can + // select. + // 2. We do currently have a selected subchannel, and the update is + // for a subchannel in p->latest_pending_subchannel_list_. The + // goal here is to find a subchannel from the update that we can + // select in place of the current one. 
+ subchannel_list()->set_in_transient_failure(false); + switch (connectivity_state) { + case GRPC_CHANNEL_READY: { + ProcessUnselectedReadyLocked(); + break; + } + case GRPC_CHANNEL_TRANSIENT_FAILURE: { + CancelConnectivityWatchLocked("connection attempt failed"); + PickFirstSubchannelData* sd = this; + size_t next_index = + (sd->Index() + 1) % subchannel_list()->num_subchannels(); + sd = subchannel_list()->subchannel(next_index); + // If we're tried all subchannels, set state to TRANSIENT_FAILURE. + if (sd->Index() == 0) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, + "Pick First %p subchannel list %p failed to connect to " + "all subchannels", + p, subchannel_list()); + } + subchannel_list()->set_in_transient_failure(true); + // In case 2, swap to the new subchannel list. This means reporting + // TRANSIENT_FAILURE and dropping the existing (working) connection, + // but we can't ignore what the control plane has told us. + if (subchannel_list() == p->latest_pending_subchannel_list_.get()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, + "Pick First %p promoting pending subchannel list %p to " + "replace %p", + p, p->latest_pending_subchannel_list_.get(), + p->subchannel_list_.get()); + } + p->subchannel_list_ = std::move(p->latest_pending_subchannel_list_); + } + // If this is the current subchannel list (either because we were + // in case 1 or because we were in case 2 and just promoted it to + // be the current list), re-resolve and report new state. + if (subchannel_list() == p->subchannel_list_.get()) { + p->channel_control_helper()->RequestReresolution(); + absl::Status status = + absl::UnavailableError("failed to connect to all addresses"); + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + } + } + sd->CheckConnectivityStateAndStartWatchingLocked(); + break; + } + case GRPC_CHANNEL_CONNECTING: + case GRPC_CHANNEL_IDLE: { + // Only update connectivity state in case 1. + if (subchannel_list() == p->subchannel_list_.get()) { + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + absl::make_unique( + p->Ref(DEBUG_LOCATION, "QueuePicker"))); + } + break; + } + case GRPC_CHANNEL_SHUTDOWN: + GPR_UNREACHABLE_CODE(break); + } +} + +void PickFirst::PickFirstSubchannelData::ProcessUnselectedReadyLocked() { + PickFirst* p = static_cast(subchannel_list()->policy()); + // If we get here, there are two possible cases: + // 1. We do not currently have a selected subchannel, and the update is + // for a subchannel in p->subchannel_list_ that we're trying to + // connect to. The goal here is to find a subchannel that we can + // select. + // 2. We do currently have a selected subchannel, and the update is + // for a subchannel in p->latest_pending_subchannel_list_. The + // goal here is to find a subchannel from the update that we can + // select in place of the current one. + GPR_ASSERT(subchannel_list() == p->subchannel_list_.get() || + subchannel_list() == p->latest_pending_subchannel_list_.get()); + // Case 2. Promote p->latest_pending_subchannel_list_ to p->subchannel_list_. 
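Editor's note: on TRANSIENT_FAILURE, pick_first advances to the next subchannel modulo the list size; wrapping back to index 0 means every subchannel in the list has failed. Modeled in isolation with illustrative names:

#include <cstddef>

struct NextAttempt {
  size_t next_index;       // subchannel to try next
  bool whole_list_failed;  // true when the attempt wrapped around to index 0
};

NextAttempt AdvanceAfterFailure(size_t failed_index, size_t num_subchannels) {
  NextAttempt result;
  result.next_index = (failed_index + 1) % num_subchannels;
  result.whole_list_failed = (result.next_index == 0);
  return result;
}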
+ if (subchannel_list() == p->latest_pending_subchannel_list_.get()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, + "Pick First %p promoting pending subchannel list %p to " + "replace %p", + p, p->latest_pending_subchannel_list_.get(), + p->subchannel_list_.get()); + } + p->subchannel_list_ = std::move(p->latest_pending_subchannel_list_); + } + // Cases 1 and 2. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_pick_first_trace)) { + gpr_log(GPR_INFO, "Pick First %p selected subchannel %p", p, subchannel()); + } + p->selected_ = this; + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_READY, absl::Status(), + absl::make_unique(subchannel()->Ref())); + for (size_t i = 0; i < subchannel_list()->num_subchannels(); ++i) { + if (i != Index()) { + subchannel_list()->subchannel(i)->ShutdownLocked(); + } + } +} + +void PickFirst::PickFirstSubchannelData:: + CheckConnectivityStateAndStartWatchingLocked() { + PickFirst* p = static_cast(subchannel_list()->policy()); + // Check current state. + grpc_connectivity_state current_state = CheckConnectivityStateLocked(); + // Start watch. + StartConnectivityWatchLocked(); + // If current state is READY, select the subchannel now, since we started + // watching from this state and will not get a notification of it + // transitioning into this state. + // If the current state is not READY, attempt to connect. + if (current_state == GRPC_CHANNEL_READY) { + if (p->selected_ != this) ProcessUnselectedReadyLocked(); + } else { + subchannel()->AttemptToConnect(); + } +} + +class PickFirstConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kPickFirst; } +}; + +// +// factory +// + +class PickFirstFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kPickFirst; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error_handle* /*error*/) const override { + return MakeRefCounted(); + } +}; + +} // namespace + +} // namespace grpc_core + +void grpc_lb_policy_pick_first_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_pick_first_shutdown() {} diff --git a/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc b/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc new file mode 100644 index 00000000..cb9fc3a3 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc @@ -0,0 +1,918 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include + +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy/address_filtering.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_core { + +TraceFlag grpc_lb_priority_trace(false, "priority_lb"); + +namespace { + +constexpr char kPriority[] = "priority_experimental"; + +// How long we keep a child around for after it is no longer being used +// (either because it has been removed from the config or because we +// have switched to a higher-priority child). +constexpr int kChildRetentionIntervalMs = 15 * 60 * 1000; + +// Default for how long we wait for a newly created child to get connected +// before starting to attempt the next priority. Overridable via channel arg. +constexpr int kDefaultChildFailoverTimeoutMs = 10000; + +// Config for priority LB policy. +class PriorityLbConfig : public LoadBalancingPolicy::Config { + public: + struct PriorityLbChild { + RefCountedPtr config; + bool ignore_reresolution_requests = false; + }; + + PriorityLbConfig(std::map children, + std::vector priorities) + : children_(std::move(children)), priorities_(std::move(priorities)) {} + + const char* name() const override { return kPriority; } + + const std::map& children() const { + return children_; + } + const std::vector& priorities() const { return priorities_; } + + private: + const std::map children_; + const std::vector priorities_; +}; + +// priority LB policy. +class PriorityLb : public LoadBalancingPolicy { + public: + explicit PriorityLb(Args args); + + const char* name() const override { return kPriority; } + + void UpdateLocked(UpdateArgs args) override; + void ExitIdleLocked() override; + void ResetBackoffLocked() override; + + private: + // Each ChildPriority holds a ref to the PriorityLb. + class ChildPriority : public InternallyRefCounted { + public: + ChildPriority(RefCountedPtr priority_policy, std::string name); + + ~ChildPriority() override { + priority_policy_.reset(DEBUG_LOCATION, "ChildPriority"); + } + + const std::string& name() const { return name_; } + + void UpdateLocked(RefCountedPtr config, + bool ignore_reresolution_requests); + void ExitIdleLocked(); + void ResetBackoffLocked(); + void DeactivateLocked(); + void MaybeReactivateLocked(); + void MaybeCancelFailoverTimerLocked(); + + void Orphan() override; + + std::unique_ptr GetPicker() { + return absl::make_unique(picker_wrapper_); + } + + grpc_connectivity_state connectivity_state() const { + return connectivity_state_; + } + + const absl::Status& connectivity_status() const { + return connectivity_status_; + } + + bool failover_timer_callback_pending() const { + return failover_timer_callback_pending_; + } + + private: + // A simple wrapper for ref-counting a picker from the child policy. 
+ class RefCountedPicker : public RefCounted { + public: + explicit RefCountedPicker(std::unique_ptr picker) + : picker_(std::move(picker)) {} + PickResult Pick(PickArgs args) { return picker_->Pick(args); } + + private: + std::unique_ptr picker_; + }; + + // A non-ref-counted wrapper for RefCountedPicker. + class RefCountedPickerWrapper : public SubchannelPicker { + public: + explicit RefCountedPickerWrapper(RefCountedPtr picker) + : picker_(std::move(picker)) {} + PickResult Pick(PickArgs args) override { return picker_->Pick(args); } + + private: + RefCountedPtr picker_; + }; + + class Helper : public ChannelControlHelper { + public: + explicit Helper(RefCountedPtr priority) + : priority_(std::move(priority)) {} + + ~Helper() override { priority_.reset(DEBUG_LOCATION, "Helper"); } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, + const absl::Status& status, + std::unique_ptr picker) override; + void RequestReresolution() override; + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + RefCountedPtr priority_; + }; + + // Methods for dealing with the child policy. + OrphanablePtr CreateChildPolicyLocked( + const grpc_channel_args* args); + + void OnConnectivityStateUpdateLocked( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker); + + void StartFailoverTimerLocked(); + + static void OnFailoverTimer(void* arg, grpc_error_handle error); + void OnFailoverTimerLocked(grpc_error_handle error); + static void OnDeactivationTimer(void* arg, grpc_error_handle error); + void OnDeactivationTimerLocked(grpc_error_handle error); + + RefCountedPtr priority_policy_; + const std::string name_; + bool ignore_reresolution_requests_ = false; + + OrphanablePtr child_policy_; + + grpc_connectivity_state connectivity_state_ = GRPC_CHANNEL_CONNECTING; + absl::Status connectivity_status_; + RefCountedPtr picker_wrapper_; + + // States for delayed removal. + grpc_timer deactivation_timer_; + grpc_closure on_deactivation_timer_; + bool deactivation_timer_callback_pending_ = false; + + // States of failover. + grpc_timer failover_timer_; + grpc_closure on_failover_timer_; + bool failover_timer_callback_pending_ = false; + }; + + ~PriorityLb() override; + + void ShutdownLocked() override; + + // Returns UINT32_MAX if child is not in current priority list. + uint32_t GetChildPriorityLocked(const std::string& child_name) const; + + void HandleChildConnectivityStateChangeLocked(ChildPriority* child); + void DeleteChild(ChildPriority* child); + + void TryNextPriorityLocked(bool report_connecting); + void SelectPriorityLocked(uint32_t priority); + + const int child_failover_timeout_ms_; + + // Current channel args and config from the resolver. + const grpc_channel_args* args_ = nullptr; + RefCountedPtr config_; + HierarchicalAddressMap addresses_; + + // Internal state. + bool shutting_down_ = false; + + std::map> children_; + // The priority that is being used. + uint32_t current_priority_ = UINT32_MAX; + // Points to the current child from before the most recent update. + // We will continue to use this child until we decide which of the new + // children to use. 
+ ChildPriority* current_child_from_before_update_ = nullptr; +}; + +// +// PriorityLb +// + +PriorityLb::PriorityLb(Args args) + : LoadBalancingPolicy(std::move(args)), + child_failover_timeout_ms_(grpc_channel_args_find_integer( + args.args, GRPC_ARG_PRIORITY_FAILOVER_TIMEOUT_MS, + {kDefaultChildFailoverTimeoutMs, 0, INT_MAX})) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] created", this); + } +} + +PriorityLb::~PriorityLb() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] destroying priority LB policy", this); + } + grpc_channel_args_destroy(args_); +} + +void PriorityLb::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] shutting down", this); + } + shutting_down_ = true; + children_.clear(); +} + +void PriorityLb::ExitIdleLocked() { + if (current_priority_ != UINT32_MAX) { + const std::string& child_name = config_->priorities()[current_priority_]; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] exiting IDLE for current priority %d child %s", + this, current_priority_, child_name.c_str()); + } + children_[child_name]->ExitIdleLocked(); + } +} + +void PriorityLb::ResetBackoffLocked() { + for (const auto& p : children_) p.second->ResetBackoffLocked(); +} + +void PriorityLb::UpdateLocked(UpdateArgs args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] received update", this); + } + // Save current child. + if (current_priority_ != UINT32_MAX) { + const std::string& child_name = config_->priorities()[current_priority_]; + current_child_from_before_update_ = children_[child_name].get(); + // Unset current_priority_, since it was an index into the old + // config's priority list and may no longer be valid. It will be + // reset later by TryNextPriorityLocked(), but we unset it here in + // case updating any of our children triggers a state update. + current_priority_ = UINT32_MAX; + } + // Update config. + config_ = std::move(args.config); + // Update args. + grpc_channel_args_destroy(args_); + args_ = args.args; + args.args = nullptr; + // Update addresses. + addresses_ = MakeHierarchicalAddressMap(args.addresses); + // Check all existing children against the new config. + for (const auto& p : children_) { + const std::string& child_name = p.first; + auto& child = p.second; + auto config_it = config_->children().find(child_name); + if (config_it == config_->children().end()) { + // Existing child not found in new config. Deactivate it. + child->DeactivateLocked(); + } else { + // Existing child found in new config. Update it. + child->UpdateLocked(config_it->second.config, + config_it->second.ignore_reresolution_requests); + } + } + // Try to get connected. + TryNextPriorityLocked(/*report_connecting=*/children_.empty()); +} + +uint32_t PriorityLb::GetChildPriorityLocked( + const std::string& child_name) const { + for (uint32_t priority = 0; priority < config_->priorities().size(); + ++priority) { + if (config_->priorities()[priority] == child_name) return priority; + } + return UINT32_MAX; +} + +void PriorityLb::HandleChildConnectivityStateChangeLocked( + ChildPriority* child) { + // Special case for the child that was the current child before the + // most recent update. 
+ if (child == current_child_from_before_update_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] state update for current child from before " + "config update", + this); + } + if (child->connectivity_state() == GRPC_CHANNEL_READY || + child->connectivity_state() == GRPC_CHANNEL_IDLE) { + // If it's still READY or IDLE, we stick with this child, so pass + // the new picker up to our parent. + channel_control_helper()->UpdateState(child->connectivity_state(), + child->connectivity_status(), + child->GetPicker()); + } else { + // If it's no longer READY or IDLE, we should stop using it. + // We already started trying other priorities as a result of the + // update, but calling TryNextPriorityLocked() ensures that we will + // properly select between CONNECTING and TRANSIENT_FAILURE as the + // new state to report to our parent. + current_child_from_before_update_ = nullptr; + TryNextPriorityLocked(/*report_connecting=*/true); + } + return; + } + // Otherwise, find the child's priority. + uint32_t child_priority = GetChildPriorityLocked(child->name()); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] state update for priority %u, child %s, current " + "priority %u", + this, child_priority, child->name().c_str(), current_priority_); + } + // Ignore priorities not in the current config. + if (child_priority == UINT32_MAX) return; + // Ignore lower-than-current priorities. + if (child_priority > current_priority_) return; + // If a child reports TRANSIENT_FAILURE, start trying the next priority. + // Note that even if this is for a higher-than-current priority, we + // may still need to create some children between this priority and + // the current one (e.g., if we got an update that inserted new + // priorities ahead of the current one). + if (child->connectivity_state() == GRPC_CHANNEL_TRANSIENT_FAILURE) { + TryNextPriorityLocked( + /*report_connecting=*/child_priority == current_priority_); + return; + } + // The update is for a higher-than-current priority (or for any + // priority if we don't have any current priority). + if (child_priority < current_priority_) { + // If the child reports READY or IDLE, switch to that priority. + // Otherwise, ignore the update. + if (child->connectivity_state() == GRPC_CHANNEL_READY || + child->connectivity_state() == GRPC_CHANNEL_IDLE) { + SelectPriorityLocked(child_priority); + } + return; + } + // The current priority has returned a new picker, so pass it up to + // our parent. + channel_control_helper()->UpdateState(child->connectivity_state(), + child->connectivity_status(), + child->GetPicker()); +} + +void PriorityLb::DeleteChild(ChildPriority* child) { + // If this was the current child from before the most recent update, + // stop using it. We already started trying other priorities as a + // result of the update, but calling TryNextPriorityLocked() ensures that + // we will properly select between CONNECTING and TRANSIENT_FAILURE as the + // new state to report to our parent. + if (current_child_from_before_update_ == child) { + current_child_from_before_update_ = nullptr; + TryNextPriorityLocked(/*report_connecting=*/true); + } + children_.erase(child->name()); +} + +void PriorityLb::TryNextPriorityLocked(bool report_connecting) { + current_priority_ = UINT32_MAX; + for (uint32_t priority = 0; priority < config_->priorities().size(); + ++priority) { + // If the child for the priority does not exist yet, create it. 
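// A standalone sketch of the selection rule the loop below implements: walk
// the priorities in order and use the first child that is READY or IDLE; if a
// child is still within its failover timeout, stop and wait for it; if every
// priority has been failing, report TRANSIENT_FAILURE. The types here are
// hypothetical stand-ins, not the gRPC classes used below.
#include <cstdint>
#include <vector>

namespace priority_sketch {

enum class ChildState { kReady, kIdle, kConnecting, kTransientFailure };

struct Child {
  ChildState state;
  bool failover_timer_pending;
};

struct Decision {
  enum Kind { kUseChild, kWaitForChild, kAllFailed } kind;
  uint32_t priority;
};

Decision ChoosePriority(const std::vector<Child>& children_by_priority) {
  for (uint32_t p = 0; p < children_by_priority.size(); ++p) {
    const Child& child = children_by_priority[p];
    if (child.state == ChildState::kReady ||
        child.state == ChildState::kIdle) {
      return {Decision::kUseChild, p};
    }
    if (child.failover_timer_pending) {
      // Still attempting its initial connection; give it time to finish.
      return {Decision::kWaitForChild, p};
    }
    // This child has been failing for a while; fall through to the next
    // priority.
  }
  return {Decision::kAllFailed, 0};
}

}  // namespace priority_sketch
// The loop body below applies the same walk to the real children map, lazily
// creating any child that does not exist yet.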
+ const std::string& child_name = config_->priorities()[priority]; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] trying priority %u, child %s", this, + priority, child_name.c_str()); + } + auto& child = children_[child_name]; + if (child == nullptr) { + if (report_connecting) { + channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + absl::make_unique(Ref(DEBUG_LOCATION, "QueuePicker"))); + } + child = MakeOrphanable( + Ref(DEBUG_LOCATION, "ChildPriority"), child_name); + auto child_config = config_->children().find(child_name); + GPR_DEBUG_ASSERT(child_config != config_->children().end()); + child->UpdateLocked(child_config->second.config, + child_config->second.ignore_reresolution_requests); + return; + } + // The child already exists. + child->MaybeReactivateLocked(); + // If the child is in state READY or IDLE, switch to it. + if (child->connectivity_state() == GRPC_CHANNEL_READY || + child->connectivity_state() == GRPC_CHANNEL_IDLE) { + SelectPriorityLocked(priority); + return; + } + // Child is not READY or IDLE. + // If its failover timer is still pending, give it time to fire. + if (child->failover_timer_callback_pending()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] priority %u, child %s: child still " + "attempting to connect, will wait", + this, priority, child_name.c_str()); + } + if (report_connecting) { + channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + absl::make_unique(Ref(DEBUG_LOCATION, "QueuePicker"))); + } + return; + } + // Child has been failing for a while. Move on to the next priority. + } + // If there are no more priorities to try, report TRANSIENT_FAILURE. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] no priority reachable, putting channel in " + "TRANSIENT_FAILURE", + this); + } + current_child_from_before_update_ = nullptr; + absl::Status status = absl::UnavailableError("no ready priority"); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); +} + +void PriorityLb::SelectPriorityLocked(uint32_t priority) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] selected priority %u, child %s", this, + priority, config_->priorities()[priority].c_str()); + } + current_priority_ = priority; + current_child_from_before_update_ = nullptr; + // Deactivate lower priorities. + for (uint32_t p = priority + 1; p < config_->priorities().size(); ++p) { + const std::string& child_name = config_->priorities()[p]; + auto it = children_.find(child_name); + if (it != children_.end()) it->second->DeactivateLocked(); + } + // Update picker. 
+ auto& child = children_[config_->priorities()[priority]]; + channel_control_helper()->UpdateState(child->connectivity_state(), + child->connectivity_status(), + child->GetPicker()); +} + +// +// PriorityLb::ChildPriority +// + +PriorityLb::ChildPriority::ChildPriority( + RefCountedPtr priority_policy, std::string name) + : priority_policy_(std::move(priority_policy)), name_(std::move(name)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] creating child %s (%p)", + priority_policy_.get(), name_.c_str(), this); + } + GRPC_CLOSURE_INIT(&on_failover_timer_, OnFailoverTimer, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_deactivation_timer_, OnDeactivationTimer, this, + grpc_schedule_on_exec_ctx); + // Start the failover timer. + StartFailoverTimerLocked(); +} + +void PriorityLb::ChildPriority::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] child %s (%p): orphaned", + priority_policy_.get(), name_.c_str(), this); + } + MaybeCancelFailoverTimerLocked(); + if (deactivation_timer_callback_pending_) { + grpc_timer_cancel(&deactivation_timer_); + } + // Remove the child policy's interested_parties pollset_set from the + // xDS policy. + grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(), + priority_policy_->interested_parties()); + child_policy_.reset(); + // Drop our ref to the child's picker, in case it's holding a ref to + // the child. + picker_wrapper_.reset(); + if (deactivation_timer_callback_pending_) { + grpc_timer_cancel(&deactivation_timer_); + } + Unref(DEBUG_LOCATION, "ChildPriority+Orphan"); +} + +void PriorityLb::ChildPriority::UpdateLocked( + RefCountedPtr config, + bool ignore_reresolution_requests) { + if (priority_policy_->shutting_down_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] child %s (%p): start update", + priority_policy_.get(), name_.c_str(), this); + } + ignore_reresolution_requests_ = ignore_reresolution_requests; + // Create policy if needed. + if (child_policy_ == nullptr) { + child_policy_ = CreateChildPolicyLocked(priority_policy_->args_); + } + // Construct update args. + UpdateArgs update_args; + update_args.config = std::move(config); + update_args.addresses = priority_policy_->addresses_[name_]; + update_args.args = grpc_channel_args_copy(priority_policy_->args_); + // Update the policy. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): updating child policy handler %p", + priority_policy_.get(), name_.c_str(), this, child_policy_.get()); + } + child_policy_->UpdateLocked(std::move(update_args)); +} + +OrphanablePtr +PriorityLb::ChildPriority::CreateChildPolicyLocked( + const grpc_channel_args* args) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = priority_policy_->work_serializer(); + lb_policy_args.args = args; + lb_policy_args.channel_control_helper = + absl::make_unique(this->Ref(DEBUG_LOCATION, "Helper")); + OrphanablePtr lb_policy = + MakeOrphanable(std::move(lb_policy_args), + &grpc_lb_priority_trace); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): created new child policy " + "handler %p", + priority_policy_.get(), name_.c_str(), this, lb_policy.get()); + } + // Add the parent's interested_parties pollset_set to that of the newly + // created child policy. 
This will make the child policy progress upon + // activity on the parent LB, which in turn is tied to the application's call. + grpc_pollset_set_add_pollset_set(lb_policy->interested_parties(), + priority_policy_->interested_parties()); + return lb_policy; +} + +void PriorityLb::ChildPriority::ExitIdleLocked() { + if (connectivity_state_ == GRPC_CHANNEL_IDLE && + !failover_timer_callback_pending_) { + StartFailoverTimerLocked(); + } + child_policy_->ExitIdleLocked(); +} + +void PriorityLb::ChildPriority::ResetBackoffLocked() { + child_policy_->ResetBackoffLocked(); +} + +void PriorityLb::ChildPriority::OnConnectivityStateUpdateLocked( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): state update: %s (%s) picker %p", + priority_policy_.get(), name_.c_str(), this, + ConnectivityStateName(state), status.ToString().c_str(), + picker.get()); + } + // Store the state and picker. + connectivity_state_ = state; + connectivity_status_ = status; + picker_wrapper_ = MakeRefCounted(std::move(picker)); + // If READY or TRANSIENT_FAILURE, cancel failover timer. + if (state == GRPC_CHANNEL_READY || state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + MaybeCancelFailoverTimerLocked(); + } + // Notify the parent policy. + priority_policy_->HandleChildConnectivityStateChangeLocked(this); +} + +void PriorityLb::ChildPriority::StartFailoverTimerLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): starting failover timer for %d ms", + priority_policy_.get(), name_.c_str(), this, + priority_policy_->child_failover_timeout_ms_); + } + Ref(DEBUG_LOCATION, "ChildPriority+OnFailoverTimerLocked").release(); + grpc_timer_init( + &failover_timer_, + ExecCtx::Get()->Now() + priority_policy_->child_failover_timeout_ms_, + &on_failover_timer_); + failover_timer_callback_pending_ = true; +} + +void PriorityLb::ChildPriority::MaybeCancelFailoverTimerLocked() { + if (failover_timer_callback_pending_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): cancelling failover timer", + priority_policy_.get(), name_.c_str(), this); + } + grpc_timer_cancel(&failover_timer_); + failover_timer_callback_pending_ = false; + } +} + +void PriorityLb::ChildPriority::OnFailoverTimer(void* arg, + grpc_error_handle error) { + ChildPriority* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + self->priority_policy_->work_serializer()->Run( + [self, error]() { self->OnFailoverTimerLocked(error); }, DEBUG_LOCATION); +} + +void PriorityLb::ChildPriority::OnFailoverTimerLocked(grpc_error_handle error) { + if (error == GRPC_ERROR_NONE && failover_timer_callback_pending_ && + !priority_policy_->shutting_down_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): failover timer fired, " + "reporting TRANSIENT_FAILURE", + priority_policy_.get(), name_.c_str(), this); + } + failover_timer_callback_pending_ = false; + OnConnectivityStateUpdateLocked( + GRPC_CHANNEL_TRANSIENT_FAILURE, + absl::Status(absl::StatusCode::kUnavailable, "failover timer fired"), + nullptr); + } + Unref(DEBUG_LOCATION, "ChildPriority+OnFailoverTimerLocked"); + GRPC_ERROR_UNREF(error); +} + +void PriorityLb::ChildPriority::DeactivateLocked() { + // If already deactivated, don't do it again. 
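// A standalone sketch of the retention bookkeeping that DeactivateLocked()
// and MaybeReactivateLocked() implement: a child that is no longer used is
// not destroyed immediately but kept around for a retention interval
// (kChildRetentionIntervalMs above), so that a quick fail-back does not have
// to rebuild it from scratch. The std::chrono-based names here are
// illustrative; the real code below uses a grpc_timer instead.
#include <chrono>

namespace retention_sketch {

using Clock = std::chrono::steady_clock;
constexpr std::chrono::minutes kRetentionInterval{15};

struct ChildRecord {
  bool deactivated = false;
  Clock::time_point deactivated_at;

  void Deactivate() {
    if (deactivated) return;  // already deactivated; don't restart the clock
    deactivated = true;
    deactivated_at = Clock::now();
  }

  void MaybeReactivate() { deactivated = false; }

  // True once the child has been unused for the full retention interval.
  bool ShouldDelete(Clock::time_point now) const {
    return deactivated && now - deactivated_at >= kRetentionInterval;
  }
};

}  // namespace retention_sketch
// The real DeactivateLocked() continues below.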
+ if (deactivation_timer_callback_pending_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): deactivating -- will remove in %d " + "ms.", + priority_policy_.get(), name_.c_str(), this, + kChildRetentionIntervalMs); + } + MaybeCancelFailoverTimerLocked(); + // Start a timer to delete the child. + Ref(DEBUG_LOCATION, "ChildPriority+timer").release(); + grpc_timer_init(&deactivation_timer_, + ExecCtx::Get()->Now() + kChildRetentionIntervalMs, + &on_deactivation_timer_); + deactivation_timer_callback_pending_ = true; +} + +void PriorityLb::ChildPriority::MaybeReactivateLocked() { + if (deactivation_timer_callback_pending_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, "[priority_lb %p] child %s (%p): reactivating", + priority_policy_.get(), name_.c_str(), this); + } + deactivation_timer_callback_pending_ = false; + grpc_timer_cancel(&deactivation_timer_); + } +} + +void PriorityLb::ChildPriority::OnDeactivationTimer(void* arg, + grpc_error_handle error) { + ChildPriority* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + self->priority_policy_->work_serializer()->Run( + [self, error]() { self->OnDeactivationTimerLocked(error); }, + DEBUG_LOCATION); +} + +void PriorityLb::ChildPriority::OnDeactivationTimerLocked( + grpc_error_handle error) { + if (error == GRPC_ERROR_NONE && deactivation_timer_callback_pending_ && + !priority_policy_->shutting_down_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) { + gpr_log(GPR_INFO, + "[priority_lb %p] child %s (%p): deactivation timer fired, " + "deleting child", + priority_policy_.get(), name_.c_str(), this); + } + deactivation_timer_callback_pending_ = false; + priority_policy_->DeleteChild(this); + } + Unref(DEBUG_LOCATION, "ChildPriority+timer"); + GRPC_ERROR_UNREF(error); +} + +// +// PriorityLb::ChildPriority::Helper +// + +RefCountedPtr +PriorityLb::ChildPriority::Helper::CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) { + if (priority_->priority_policy_->shutting_down_) return nullptr; + return priority_->priority_policy_->channel_control_helper() + ->CreateSubchannel(std::move(address), args); +} + +void PriorityLb::ChildPriority::Helper::UpdateState( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + if (priority_->priority_policy_->shutting_down_) return; + // Notify the priority. 
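// A standalone sketch of the delegation pattern used by ChildPriority::Helper:
// the child policy only sees an abstract helper interface, and the parent
// installs a thin adapter that forwards every state update back into the
// parent's own handling. The interfaces here are simplified stand-ins, not
// the real ChannelControlHelper API.
#include <functional>
#include <string>
#include <utility>

namespace helper_sketch {

enum class State { kIdle, kConnecting, kReady, kTransientFailure };

// What a child policy is given: somewhere to report its state.
class StateReporter {
 public:
  virtual ~StateReporter() = default;
  virtual void UpdateState(State state, std::string status) = 0;
};

// Parent-side adapter: tags each update with the child's name and hands it
// to the parent's handler.
class ForwardingReporter : public StateReporter {
 public:
  ForwardingReporter(std::string child_name,
                     std::function<void(std::string, State, std::string)> sink)
      : child_name_(std::move(child_name)), sink_(std::move(sink)) {}

  void UpdateState(State state, std::string status) override {
    sink_(child_name_, state, std::move(status));
  }

 private:
  std::string child_name_;
  std::function<void(std::string, State, std::string)> sink_;
};

}  // namespace helper_sketch
// Below, the real Helper hands the child's update to
// ChildPriority::OnConnectivityStateUpdateLocked().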
+ priority_->OnConnectivityStateUpdateLocked(state, status, std::move(picker)); +} + +void PriorityLb::ChildPriority::Helper::RequestReresolution() { + if (priority_->priority_policy_->shutting_down_) return; + if (priority_->ignore_reresolution_requests_) { + return; + } + priority_->priority_policy_->channel_control_helper()->RequestReresolution(); +} + +absl::string_view PriorityLb::ChildPriority::Helper::GetAuthority() { + return priority_->priority_policy_->channel_control_helper()->GetAuthority(); +} + +void PriorityLb::ChildPriority::Helper::AddTraceEvent( + TraceSeverity severity, absl::string_view message) { + if (priority_->priority_policy_->shutting_down_) return; + priority_->priority_policy_->channel_control_helper()->AddTraceEvent(severity, + message); +} + +// +// factory +// + +class PriorityLbFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kPriority; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (json.type() == Json::Type::JSON_NULL) { + // priority was mentioned as a policy in the deprecated + // loadBalancingPolicy field or in the client API. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:priority policy requires " + "configuration. Please use loadBalancingConfig field of service " + "config instead."); + return nullptr; + } + std::vector error_list; + // Children. + std::map children; + auto it = json.object_value().find("children"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:children error:required field missing")); + } else if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:children error:type should be object")); + } else { + const Json::Object& object = it->second.object_value(); + for (const auto& p : object) { + const std::string& child_name = p.first; + const Json& element = p.second; + if (element.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:children key:", child_name, + " error:should be type object"))); + } else { + auto it2 = element.object_value().find("config"); + if (it2 == element.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:children key:", child_name, + " error:missing 'config' field"))); + } else { + grpc_error_handle parse_error = GRPC_ERROR_NONE; + auto config = LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + it2->second, &parse_error); + bool ignore_resolution_requests = false; + // If present, ignore_reresolution_requests must be of type + // boolean. 
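// For reference, the parser below accepts child entries of this overall
// shape (the child names and the inner round_robin policy are illustrative
// only, not taken from this file):
//
//   {
//     "children": {
//       "child_a": {
//         "config": [ { "round_robin": {} } ],
//         "ignore_reresolution_requests": true
//       },
//       "child_b": { "config": [ { "round_robin": {} } ] }
//     },
//     "priorities": [ "child_a", "child_b" ]
//   }
//
// "config" is required for every child, "ignore_reresolution_requests" must
// be a JSON boolean when present, and each entry of "priorities" must name a
// declared child; the checks below and further down enforce exactly that.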
+ auto it3 = + element.object_value().find("ignore_reresolution_requests"); + if (it3 != element.object_value().end()) { + if (it3->second.type() == Json::Type::JSON_TRUE) { + ignore_resolution_requests = true; + } else if (it3->second.type() != Json::Type::JSON_FALSE) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:children key:", child_name, + " field:ignore_reresolution_requests:should " + "be type boolean"))); + } + } + if (config == nullptr) { + GPR_DEBUG_ASSERT(parse_error != GRPC_ERROR_NONE); + error_list.push_back( + GRPC_ERROR_CREATE_REFERENCING_FROM_COPIED_STRING( + absl::StrCat("field:children key:", child_name).c_str(), + &parse_error, 1)); + GRPC_ERROR_UNREF(parse_error); + } + children[child_name].config = std::move(config); + children[child_name].ignore_reresolution_requests = + ignore_resolution_requests; + } + } + } + } + // Priorities. + std::vector priorities; + it = json.object_value().find("priorities"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:priorities error:required field missing")); + } else if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:priorities error:type should be array")); + } else { + const Json::Array& array = it->second.array_value(); + for (size_t i = 0; i < array.size(); ++i) { + const Json& element = array[i]; + if (element.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "field:priorities element:", i, " error:should be type string"))); + } else if (children.find(element.string_value()) == children.end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "field:priorities element:", i, " error:unknown child '", + element.string_value(), "'"))); + } else { + priorities.emplace_back(element.string_value()); + } + } + if (priorities.size() != children.size()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "field:priorities error:priorities size (", priorities.size(), + ") != children size (", children.size(), ")"))); + } + } + if (error_list.empty()) { + return MakeRefCounted(std::move(children), + std::move(priorities)); + } else { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "priority_experimental LB policy config", &error_list); + return nullptr; + } + } +}; + +} // namespace + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_lb_policy_priority_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_priority_shutdown() {} diff --git a/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc b/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc new file mode 100644 index 00000000..3f954f2d --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc @@ -0,0 +1,757 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#define XXH_INLINE_ALL +#include "xxhash.h" + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/subchannel_list.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/subchannel.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/static_metadata.h" + +namespace grpc_core { + +const char* kRequestRingHashAttribute = "request_ring_hash"; +TraceFlag grpc_lb_ring_hash_trace(false, "ring_hash_lb"); + +// Helper Parser method +void ParseRingHashLbConfig(const Json& json, size_t* min_ring_size, + size_t* max_ring_size, + std::vector* error_list) { + *min_ring_size = 1024; + *max_ring_size = 8388608; + if (json.type() != Json::Type::OBJECT) { + error_list->push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "ring_hash_experimental should be of type object")); + return; + } + const Json::Object& ring_hash = json.object_value(); + auto ring_hash_it = ring_hash.find("min_ring_size"); + if (ring_hash_it != ring_hash.end()) { + if (ring_hash_it->second.type() != Json::Type::NUMBER) { + error_list->push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:min_ring_size error: should be of type number")); + } else { + *min_ring_size = gpr_parse_nonnegative_int( + ring_hash_it->second.string_value().c_str()); + } + } + ring_hash_it = ring_hash.find("max_ring_size"); + if (ring_hash_it != ring_hash.end()) { + if (ring_hash_it->second.type() != Json::Type::NUMBER) { + error_list->push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:max_ring_size error: should be of type number")); + } else { + *max_ring_size = gpr_parse_nonnegative_int( + ring_hash_it->second.string_value().c_str()); + } + } + if (*min_ring_size == 0 || *min_ring_size > 8388608 || *max_ring_size == 0 || + *max_ring_size > 8388608 || *min_ring_size > *max_ring_size) { + error_list->push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:max_ring_size and or min_ring_size error: " + "values need to be in the range of 1 to 8388608 " + "and max_ring_size cannot be smaller than " + "min_ring_size")); + } +} + +namespace { + +constexpr char kRingHash[] = "ring_hash_experimental"; + +class RingHashLbConfig : public LoadBalancingPolicy::Config { + public: + RingHashLbConfig(size_t min_ring_size, size_t max_ring_size) + : min_ring_size_(min_ring_size), max_ring_size_(max_ring_size) {} + const char* name() const override { return kRingHash; } + size_t min_ring_size() const { return min_ring_size_; } + size_t max_ring_size() const { return max_ring_size_; } + + private: + size_t min_ring_size_; + size_t max_ring_size_; +}; + +// +// ring_hash LB policy +// +class RingHash : public LoadBalancingPolicy { + public: + explicit RingHash(Args args); + + const char* name() const override { return kRingHash; } + + void UpdateLocked(UpdateArgs args) override; + void ResetBackoffLocked() override; + + private: + ~RingHash() override; + + // Forward declaration. 
+ class RingHashSubchannelList; + + // Data for a particular subchannel in a subchannel list. + // This subclass adds the following functionality: + // - Tracks the previous connectivity state of the subchannel, so that + // we know how many subchannels are in each state. + class RingHashSubchannelData + : public SubchannelData { + public: + RingHashSubchannelData( + SubchannelList* + subchannel_list, + const ServerAddress& address, + RefCountedPtr subchannel) + : SubchannelData(subchannel_list, address, std::move(subchannel)), + address_(address) {} + + grpc_connectivity_state connectivity_state() const { + return last_connectivity_state_; + } + const ServerAddress& address() const { return address_; } + + bool seen_failure_since_ready() const { return seen_failure_since_ready_; } + + // Performs connectivity state updates that need to be done both when we + // first start watching and when a watcher notification is received. + void UpdateConnectivityStateLocked( + grpc_connectivity_state connectivity_state); + + private: + // Performs connectivity state updates that need to be done only + // after we have started watching. + void ProcessConnectivityChangeLocked( + grpc_connectivity_state connectivity_state) override; + + ServerAddress address_; + grpc_connectivity_state last_connectivity_state_ = GRPC_CHANNEL_SHUTDOWN; + bool seen_failure_since_ready_ = false; + }; + + // A list of subchannels. + class RingHashSubchannelList + : public SubchannelList { + public: + RingHashSubchannelList(RingHash* policy, TraceFlag* tracer, + ServerAddressList addresses, + const grpc_channel_args& args) + : SubchannelList(policy, tracer, std::move(addresses), + policy->channel_control_helper(), args) { + // Need to maintain a ref to the LB policy as long as we maintain + // any references to subchannels, since the subchannels' + // pollset_sets will include the LB policy's pollset_set. + policy->Ref(DEBUG_LOCATION, "subchannel_list").release(); + } + + ~RingHashSubchannelList() override { + RingHash* p = static_cast(policy()); + p->Unref(DEBUG_LOCATION, "subchannel_list"); + } + + // Starts watching the subchannels in this list. + void StartWatchingLocked(); + + // Updates the counters of subchannels in each state when a + // subchannel transitions from old_state to new_state. + void UpdateStateCountersLocked(grpc_connectivity_state old_state, + grpc_connectivity_state new_state); + + // Updates the RH policy's connectivity state based on the + // subchannel list's state counters, creating new picker and new ring. + // Furthermore, return a bool indicating whether the aggregated state is + // Transient Failure. + bool UpdateRingHashConnectivityStateLocked(); + + private: + size_t num_idle_ = 0; + size_t num_ready_ = 0; + size_t num_connecting_ = 0; + size_t num_transient_failure_ = 0; + }; + + class Picker : public SubchannelPicker { + public: + Picker(RefCountedPtr parent, + RingHashSubchannelList* subchannel_list); + + PickResult Pick(PickArgs args) override; + + private: + struct RingEntry { + uint64_t hash; + RefCountedPtr subchannel; + grpc_connectivity_state connectivity_state; + }; + + // A fire-and-forget class that schedules subchannel connection attempts + // on the control plane WorkSerializer. 
+ class SubchannelConnectionAttempter : public Orphanable { + public: + explicit SubchannelConnectionAttempter( + RefCountedPtr ring_hash_lb) + : ring_hash_lb_(std::move(ring_hash_lb)) { + GRPC_CLOSURE_INIT(&closure_, RunInExecCtx, this, nullptr); + } + + void AddSubchannel(RefCountedPtr subchannel) { + subchannels_.push_back(std::move(subchannel)); + } + + void Orphan() override { + // Hop into ExecCtx, so that we're not holding the data plane mutex + // while we run control-plane code. + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); + } + + private: + static void RunInExecCtx(void* arg, grpc_error_handle /*error*/) { + auto* self = static_cast(arg); + self->ring_hash_lb_->work_serializer()->Run( + [self]() { + if (!self->ring_hash_lb_->shutdown_) { + for (auto& subchannel : self->subchannels_) { + subchannel->AttemptToConnect(); + } + } + delete self; + }, + DEBUG_LOCATION); + } + + RefCountedPtr ring_hash_lb_; + grpc_closure closure_; + absl::InlinedVector, 10> subchannels_; + }; + + RefCountedPtr parent_; + + // A ring of subchannels. + std::vector ring_; + }; + + void ShutdownLocked() override; + + // Current config from resolver. + RefCountedPtr config_; + + // list of subchannels. + OrphanablePtr subchannel_list_; + // indicating if we are shutting down. + bool shutdown_ = false; +}; + +// +// RingHash::Picker +// + +RingHash::Picker::Picker(RefCountedPtr parent, + RingHashSubchannelList* subchannel_list) + : parent_(std::move(parent)) { + size_t num_subchannels = subchannel_list->num_subchannels(); + // Store the weights while finding the sum. + struct AddressWeight { + std::string address; + // Default weight is 1 for the cases where a weight is not provided, + // each occurrence of the address will be counted a weight value of 1. + uint32_t weight = 1; + double normalized_weight; + }; + std::vector address_weights; + size_t sum = 0; + address_weights.reserve(num_subchannels); + for (size_t i = 0; i < num_subchannels; ++i) { + RingHashSubchannelData* sd = subchannel_list->subchannel(i); + const ServerAddressWeightAttribute* weight_attribute = static_cast< + const ServerAddressWeightAttribute*>(sd->address().GetAttribute( + ServerAddressWeightAttribute::kServerAddressWeightAttributeKey)); + AddressWeight address_weight; + address_weight.address = + grpc_sockaddr_to_string(&sd->address().address(), false); + if (weight_attribute != nullptr) { + GPR_ASSERT(weight_attribute->weight() != 0); + address_weight.weight = weight_attribute->weight(); + } + sum += address_weight.weight; + address_weights.push_back(std::move(address_weight)); + } + // Calculating normalized weights and find min and max. + double min_normalized_weight = 1.0; + double max_normalized_weight = 0.0; + for (auto& address : address_weights) { + address.normalized_weight = static_cast(address.weight) / sum; + min_normalized_weight = + std::min(address.normalized_weight, min_normalized_weight); + max_normalized_weight = + std::max(address.normalized_weight, max_normalized_weight); + } + // Scale up the number of hashes per host such that the least-weighted host + // gets a whole number of hashes on the ring. Other hosts might not end up + // with whole numbers, and that's fine (the ring-building algorithm below can + // handle this). This preserves the original implementation's behavior: when + // weights aren't provided, all hosts should get an equal number of hashes. In + // the case where this number exceeds the max_ring_size, it's scaled back down + // to fit. 
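// A worked example of the scaling described above, assuming min_ring_size =
// 1024, max_ring_size = 8388608, and two hosts with weights 1 and 2 (so
// normalized weights 1/3 and 2/3, and min_normalized_weight = 1/3):
//
//   scale     = min(ceil(1/3 * 1024) / (1/3), 8388608)
//             = min(342 * 3, 8388608) = 1026
//   ring_size = ceil(scale) = 1026
//   hashes    = scale * normalized_weight
//             = 342 for the weight-1 host, 684 for the weight-2 host
//
// The least-weighted host gets a whole number of hashes (342), the per-host
// counts add up to the ring size, and the result is clamped to max_ring_size
// if the scaled value would exceed it. The code below computes exactly this.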
+ const size_t min_ring_size = parent_->config_->min_ring_size(); + const size_t max_ring_size = parent_->config_->max_ring_size(); + const double scale = std::min( + std::ceil(min_normalized_weight * min_ring_size) / min_normalized_weight, + static_cast(max_ring_size)); + // Reserve memory for the entire ring up front. + const uint64_t ring_size = std::ceil(scale); + ring_.reserve(ring_size); + // Populate the hash ring by walking through the (host, weight) pairs in + // normalized_host_weights, and generating (scale * weight) hashes for each + // host. Since these aren't necessarily whole numbers, we maintain running + // sums -- current_hashes and target_hashes -- which allows us to populate the + // ring in a mostly stable way. + absl::InlinedVector hash_key_buffer; + double current_hashes = 0.0; + double target_hashes = 0.0; + uint64_t min_hashes_per_host = ring_size; + uint64_t max_hashes_per_host = 0; + for (size_t i = 0; i < num_subchannels; ++i) { + const std::string& address_string = address_weights[i].address; + hash_key_buffer.assign(address_string.begin(), address_string.end()); + hash_key_buffer.emplace_back('_'); + auto offset_start = hash_key_buffer.end(); + target_hashes += scale * address_weights[i].normalized_weight; + size_t count = 0; + auto current_state = + subchannel_list->subchannel(i)->subchannel()->CheckConnectivityState(); + while (current_hashes < target_hashes) { + const std::string count_str = absl::StrCat(count); + hash_key_buffer.insert(offset_start, count_str.begin(), count_str.end()); + absl::string_view hash_key(hash_key_buffer.data(), + hash_key_buffer.size()); + const uint64_t hash = XXH64(hash_key.data(), hash_key.size(), 0); + ring_.push_back({hash, + subchannel_list->subchannel(i)->subchannel()->Ref(), + current_state}); + ++count; + ++current_hashes; + hash_key_buffer.erase(offset_start, hash_key_buffer.end()); + } + min_hashes_per_host = + std::min(static_cast(i), min_hashes_per_host); + max_hashes_per_host = + std::max(static_cast(i), max_hashes_per_host); + } + std::sort(ring_.begin(), ring_.end(), + [](const RingEntry& lhs, const RingEntry& rhs) -> bool { + return lhs.hash < rhs.hash; + }); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) { + gpr_log(GPR_INFO, + "[RH %p picker %p] created picker from subchannel_list=%p " + "with %" PRIuPTR " ring entries", + parent_.get(), this, subchannel_list, ring_.size()); + } +} + +RingHash::PickResult RingHash::Picker::Pick(PickArgs args) { + auto hash = + args.call_state->ExperimentalGetCallAttribute(kRequestRingHashAttribute); + uint64_t h; + if (!absl::SimpleAtoi(hash, &h)) { + return PickResult::Fail( + absl::InternalError("xds ring hash value is not a number")); + } + // Ported from https://github.com/RJ/ketama/blob/master/libketama/ketama.c + // (ketama_get_server) NOTE: The algorithm depends on using signed integers + // for lowp, highp, and first_index. Do not change them! + int64_t lowp = 0; + int64_t highp = ring_.size(); + int64_t first_index = 0; + while (true) { + first_index = (lowp + highp) / 2; + if (first_index == static_cast(ring_.size())) { + first_index = 0; + break; + } + uint64_t midval = ring_[first_index].hash; + uint64_t midval1 = first_index == 0 ? 
0 : ring_[first_index - 1].hash; + if (h <= midval && h > midval1) { + break; + } + if (midval < h) { + lowp = first_index + 1; + } else { + highp = first_index - 1; + } + if (lowp > highp) { + first_index = 0; + break; + } + } + OrphanablePtr subchannel_connection_attempter; + auto ScheduleSubchannelConnectionAttempt = + [&](RefCountedPtr subchannel) { + if (subchannel_connection_attempter == nullptr) { + subchannel_connection_attempter = + MakeOrphanable(parent_); + } + subchannel_connection_attempter->AddSubchannel(std::move(subchannel)); + }; + switch (ring_[first_index].connectivity_state) { + case GRPC_CHANNEL_READY: + return PickResult::Complete(ring_[first_index].subchannel); + case GRPC_CHANNEL_IDLE: + ScheduleSubchannelConnectionAttempt(ring_[first_index].subchannel); + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHANNEL_CONNECTING: + return PickResult::Queue(); + default: // GRPC_CHANNEL_TRANSIENT_FAILURE + break; + } + ScheduleSubchannelConnectionAttempt(ring_[first_index].subchannel); + // Loop through remaining subchannels to find one in READY. + // On the way, we make sure the right set of connection attempts + // will happen. + bool found_second_subchannel = false; + bool found_first_non_failed = false; + for (size_t i = 1; i < ring_.size(); ++i) { + const RingEntry& entry = ring_[(first_index + i) % ring_.size()]; + if (entry.subchannel == ring_[first_index].subchannel) { + continue; + } + if (entry.connectivity_state == GRPC_CHANNEL_READY) { + return PickResult::Complete(entry.subchannel); + } + if (!found_second_subchannel) { + switch (entry.connectivity_state) { + case GRPC_CHANNEL_IDLE: + ScheduleSubchannelConnectionAttempt(entry.subchannel); + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHANNEL_CONNECTING: + return PickResult::Queue(); + default: + break; + } + found_second_subchannel = true; + } + if (!found_first_non_failed) { + if (entry.connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + ScheduleSubchannelConnectionAttempt(entry.subchannel); + } else { + if (entry.connectivity_state == GRPC_CHANNEL_IDLE) { + ScheduleSubchannelConnectionAttempt(entry.subchannel); + } + found_first_non_failed = true; + } + } + } + return PickResult::Fail(absl::UnavailableError( + "xds ring hash found a subchannel that is in TRANSIENT_FAILURE state")); +} + +// +// RingHash::RingHashSubchannelList +// + +void RingHash::RingHashSubchannelList::StartWatchingLocked() { + if (num_subchannels() == 0) return; + // Check current state of each subchannel synchronously. + for (size_t i = 0; i < num_subchannels(); ++i) { + grpc_connectivity_state state = + subchannel(i)->CheckConnectivityStateLocked(); + subchannel(i)->UpdateConnectivityStateLocked(state); + } + // Start connectivity watch for each subchannel. + for (size_t i = 0; i < num_subchannels(); i++) { + if (subchannel(i)->subchannel() != nullptr) { + subchannel(i)->StartConnectivityWatchLocked(); + } + } + RingHash* p = static_cast(policy()); + // Sending up the initial picker while all subchannels are in IDLE state. 
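// A standalone sketch of the ketama-style lookup performed by Pick() above:
// the ring is sorted by hash, and a request hash maps to the first entry
// whose hash is >= the request hash, wrapping to entry 0 past the end.
// Pick() implements this with an explicit binary search; std::lower_bound
// expresses the same mapping. Names here are illustrative.
#include <algorithm>
#include <cstdint>
#include <vector>

namespace ring_lookup_sketch {

size_t FindRingIndex(const std::vector<uint64_t>& sorted_ring_hashes,
                     uint64_t request_hash) {
  auto it = std::lower_bound(sorted_ring_hashes.begin(),
                             sorted_ring_hashes.end(), request_hash);
  if (it == sorted_ring_hashes.end()) return 0;  // wrap around
  return static_cast<size_t>(it - sorted_ring_hashes.begin());
}

}  // namespace ring_lookup_sketch
// StartWatchingLocked() continues below, sending up the initial picker built
// from this ring.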
+ p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_READY, absl::Status(), + absl::make_unique(p->Ref(DEBUG_LOCATION, "RingHashPicker"), + this)); +} + +void RingHash::RingHashSubchannelList::UpdateStateCountersLocked( + grpc_connectivity_state old_state, grpc_connectivity_state new_state) { + GPR_ASSERT(new_state != GRPC_CHANNEL_SHUTDOWN); + if (old_state == GRPC_CHANNEL_IDLE) { + GPR_ASSERT(num_idle_ > 0); + --num_idle_; + } else if (old_state == GRPC_CHANNEL_READY) { + GPR_ASSERT(num_ready_ > 0); + --num_ready_; + } else if (old_state == GRPC_CHANNEL_CONNECTING) { + GPR_ASSERT(num_connecting_ > 0); + --num_connecting_; + } else if (old_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + GPR_ASSERT(num_transient_failure_ > 0); + --num_transient_failure_; + } + if (new_state == GRPC_CHANNEL_IDLE) { + ++num_idle_; + } else if (new_state == GRPC_CHANNEL_READY) { + ++num_ready_; + } else if (new_state == GRPC_CHANNEL_CONNECTING) { + ++num_connecting_; + } else if (new_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + ++num_transient_failure_; + } +} + +// Sets the RH policy's connectivity state and generates a new picker based +// on the current subchannel list or requests an re-attempt by returning true.. +bool RingHash::RingHashSubchannelList::UpdateRingHashConnectivityStateLocked() { + RingHash* p = static_cast(policy()); + // Only set connectivity state if this is the current subchannel list. + if (p->subchannel_list_.get() != this) return false; + // The overall aggregation rules here are: + // 1. If there is at least one subchannel in READY state, report READY. + // 2. If there are 2 or more subchannels in TRANSIENT_FAILURE state, report + // TRANSIENT_FAILURE. + // 3. If there is at least one subchannel in CONNECTING state, report + // CONNECTING. + // 4. If there is at least one subchannel in IDLE state, report IDLE. + // 5. Otherwise, report TRANSIENT_FAILURE. + if (num_ready_ > 0) { + /* READY */ + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_READY, absl::Status(), + absl::make_unique(p->Ref(DEBUG_LOCATION, "RingHashPicker"), + this)); + return false; + } + if (num_connecting_ > 0 && num_transient_failure_ < 2) { + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + absl::make_unique(p->Ref(DEBUG_LOCATION, "QueuePicker"))); + return false; + } + if (num_idle_ > 0 && num_transient_failure_ < 2) { + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_IDLE, absl::Status(), + absl::make_unique(p->Ref(DEBUG_LOCATION, "RingHashPicker"), + this)); + return false; + } + absl::Status status = + absl::UnavailableError("connections to backend failing or idle"); + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + return true; +} + +// +// RingHash::RingHashSubchannelData +// + +void RingHash::RingHashSubchannelData::UpdateConnectivityStateLocked( + grpc_connectivity_state connectivity_state) { + RingHash* p = static_cast(subchannel_list()->policy()); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) { + gpr_log( + GPR_INFO, + "[RR %p] connectivity changed for subchannel %p, subchannel_list %p " + "(index %" PRIuPTR " of %" PRIuPTR "): prev_state=%s new_state=%s", + p, subchannel(), subchannel_list(), Index(), + subchannel_list()->num_subchannels(), + ConnectivityStateName(last_connectivity_state_), + ConnectivityStateName(connectivity_state)); + } + // Decide what state to report for aggregation purposes. 
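// A standalone sketch of the aggregation rules listed above, expressed as a
// pure function over the per-state counters. Names are illustrative; the
// real method also builds the corresponding picker and returns whether the
// aggregate state is TRANSIENT_FAILURE.
#include <cstddef>

namespace ring_hash_aggregation_sketch {

enum class AggState { kReady, kConnecting, kIdle, kTransientFailure };

AggState Aggregate(size_t num_ready, size_t num_connecting, size_t num_idle,
                   size_t num_transient_failure) {
  if (num_ready > 0) return AggState::kReady;                      // rule 1
  if (num_transient_failure >= 2) return AggState::kTransientFailure;  // rule 2
  if (num_connecting > 0) return AggState::kConnecting;            // rule 3
  if (num_idle > 0) return AggState::kIdle;                        // rule 4
  return AggState::kTransientFailure;                              // rule 5
}

}  // namespace ring_hash_aggregation_sketch
// The per-subchannel state tracking that feeds these counters continues
// below.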
+ // If we haven't seen a failure since the last time we were in state + // READY, then we report the state change as-is. However, once we do see + // a failure, we report TRANSIENT_FAILURE and do not report any subsequent + // state changes until we go back into state READY. + if (!seen_failure_since_ready_) { + if (connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + seen_failure_since_ready_ = true; + } + subchannel_list()->UpdateStateCountersLocked(last_connectivity_state_, + connectivity_state); + } else { + if (connectivity_state == GRPC_CHANNEL_READY) { + seen_failure_since_ready_ = false; + subchannel_list()->UpdateStateCountersLocked( + GRPC_CHANNEL_TRANSIENT_FAILURE, connectivity_state); + } + } + // Record last seen connectivity state. + last_connectivity_state_ = connectivity_state; +} + +void RingHash::RingHashSubchannelData::ProcessConnectivityChangeLocked( + grpc_connectivity_state connectivity_state) { + RingHash* p = static_cast(subchannel_list()->policy()); + GPR_ASSERT(subchannel() != nullptr); + // If the new state is TRANSIENT_FAILURE, re-resolve. + // Only do this if we've started watching, not at startup time. + // Otherwise, if the subchannel was already in state TRANSIENT_FAILURE + // when the subchannel list was created, we'd wind up in a constant + // loop of re-resolution. + // Also attempt to reconnect. + if (connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) { + gpr_log(GPR_INFO, + "[RR %p] Subchannel %p has gone into TRANSIENT_FAILURE. " + "Requesting re-resolution", + p, subchannel()); + } + p->channel_control_helper()->RequestReresolution(); + } + // Update state counters. + UpdateConnectivityStateLocked(connectivity_state); + // Update the RH policy's connectivity state, creating new picker and new + // ring. + bool transient_failure = + subchannel_list()->UpdateRingHashConnectivityStateLocked(); + // While the ring_hash policy is reporting TRANSIENT_FAILURE, it will + // not be getting any pick requests from the priority policy. + // However, because the ring_hash policy does not attempt to + // reconnect to subchannels unless it is getting pick requests, + // it will need special handling to ensure that it will eventually + // recover from TRANSIENT_FAILURE state once the problem is resolved. + // Specifically, it will make sure that it is attempting to connect to + // at least one subchannel at any given time. After a given subchannel + // fails a connection attempt, it will move on to the next subchannel + // in the ring. It will keep doing this until one of the subchannels + // successfully connects, at which point it will report READY and stop + // proactively trying to connect. The policy will remain in + // TRANSIENT_FAILURE until at least one subchannel becomes connected, + // even if subchannels are in state CONNECTING during that time. 
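// A standalone sketch of the "sticky" failure accounting implemented by
// UpdateConnectivityStateLocked() above: once a subchannel has failed, it
// keeps counting as TRANSIENT_FAILURE for aggregation purposes until it
// actually reaches READY again, so CONNECTING/IDLE transitions during
// reconnect attempts do not make the aggregate state flap. Names here are
// illustrative.
namespace sticky_failure_sketch {

enum class State { kIdle, kConnecting, kReady, kTransientFailure };

struct SubchannelView {
  State reported = State::kIdle;  // what the aggregation counters see
  bool seen_failure_since_ready = false;

  void OnStateChange(State actual) {
    if (!seen_failure_since_ready) {
      if (actual == State::kTransientFailure) seen_failure_since_ready = true;
      reported = actual;
    } else if (actual == State::kReady) {
      seen_failure_since_ready = false;
      reported = State::kReady;
    }
    // Otherwise keep reporting TRANSIENT_FAILURE until READY is seen.
  }
};

}  // namespace sticky_failure_sketch
// ProcessConnectivityChangeLocked() continues below with the proactive
// connection attempt described in the comment above.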
+ if (transient_failure && + connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + size_t next_index = (Index() + 1) % subchannel_list()->num_subchannels(); + RingHashSubchannelData* next_sd = subchannel_list()->subchannel(next_index); + next_sd->subchannel()->AttemptToConnect(); + } +} + +// +// RingHash +// + +RingHash::RingHash(Args args) : LoadBalancingPolicy(std::move(args)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) { + gpr_log(GPR_INFO, "[RH %p] Created", this); + } +} + +RingHash::~RingHash() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) { + gpr_log(GPR_INFO, "[RH %p] Destroying Ring Hash policy", this); + } + GPR_ASSERT(subchannel_list_ == nullptr); +} + +void RingHash::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) { + gpr_log(GPR_INFO, "[RH %p] Shutting down", this); + } + shutdown_ = true; + subchannel_list_.reset(); +} + +void RingHash::ResetBackoffLocked() { subchannel_list_->ResetBackoffLocked(); } + +void RingHash::UpdateLocked(UpdateArgs args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) { + gpr_log(GPR_INFO, "[RR %p] received update with %" PRIuPTR " addresses", + this, args.addresses.size()); + } + config_ = std::move(args.config); + // Filter out any address with weight 0. + ServerAddressList addresses; + addresses.reserve(args.addresses.size()); + for (ServerAddress& address : args.addresses) { + const ServerAddressWeightAttribute* weight_attribute = + static_cast(address.GetAttribute( + ServerAddressWeightAttribute::kServerAddressWeightAttributeKey)); + if (weight_attribute == nullptr || weight_attribute->weight() > 0) { + addresses.push_back(std::move(address)); + } + } + subchannel_list_ = MakeOrphanable( + this, &grpc_lb_ring_hash_trace, std::move(addresses), *args.args); + if (subchannel_list_->num_subchannels() == 0) { + // If the new list is empty, immediately transition to TRANSIENT_FAILURE. + absl::Status status = absl::UnavailableError("Empty update"); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + } else { + // Start watching the new list. + subchannel_list_->StartWatchingLocked(); + } +} + +// +// factory +// + +class RingHashFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kRingHash; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + size_t min_ring_size; + size_t max_ring_size; + std::vector error_list; + ParseRingHashLbConfig(json, &min_ring_size, &max_ring_size, &error_list); + if (error_list.empty()) { + return MakeRefCounted(min_ring_size, max_ring_size); + } else { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "ring_hash_experimental LB policy config", &error_list); + return nullptr; + } + } +}; + +} // namespace + +void GrpcLbPolicyRingHashInit() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void GrpcLbPolicyRingHashShutdown() {} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/rls/rls.cc b/src/core/ext/filters/client_channel/lb_policy/rls/rls.cc new file mode 100644 index 00000000..2b0483a6 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/rls/rls.cc @@ -0,0 +1,2502 @@ +// +// Copyright 2020 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Implementation of the Route Lookup Service (RLS) LB policy +// +// The policy queries a route lookup service for the name of the actual service +// to use. A child policy that recognizes the name as a field of its +// configuration will take further load balancing action on the request. + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "upb/upb.hpp" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/dual_ref_counted.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/json/json_util.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/uri/uri_parser.h" +#include "src/proto/grpc/lookup/v1/rls.upb.h" + +namespace grpc_core { + +TraceFlag grpc_lb_rls_trace(false, "rls_lb"); + +namespace { + +const char* kRls = "rls"; +const char kGrpc[] = "grpc"; +const char* kRlsRequestPath = "/grpc.lookup.v1.RouteLookupService/RouteLookup"; +const char* kFakeTargetFieldValue = "fake_target_field_value"; +const char* kRlsHeaderKey = "X-Google-RLS-Data"; + +const grpc_millis kDefaultLookupServiceTimeout = 10000; +const grpc_millis kMaxMaxAge = 5 * 60 * GPR_MS_PER_SEC; +const grpc_millis kMinExpirationTime = 5 * GPR_MS_PER_SEC; +const grpc_millis kCacheBackoffInitial = 1 * GPR_MS_PER_SEC; +const double kCacheBackoffMultiplier = 1.6; +const double kCacheBackoffJitter = 0.2; +const grpc_millis kCacheBackoffMax = 120 * GPR_MS_PER_SEC; +const grpc_millis kDefaultThrottleWindowSize = 30 * GPR_MS_PER_SEC; +const double kDefaultThrottleRatioForSuccesses = 2.0; +const int kDefaultThrottlePaddings = 8; +const grpc_millis kCacheCleanupTimerInterval = 
60 * GPR_MS_PER_SEC; +const int64_t kMaxCacheSizeBytes = 5 * 1024 * 1024; + +// Parsed RLS LB policy configuration. +class RlsLbConfig : public LoadBalancingPolicy::Config { + public: + struct KeyBuilder { + std::map> + header_keys; + std::string host_key; + std::string service_key; + std::string method_key; + std::map constant_keys; + }; + using KeyBuilderMap = std::unordered_map; + + struct RouteLookupConfig { + KeyBuilderMap key_builder_map; + std::string lookup_service; + grpc_millis lookup_service_timeout = 0; + grpc_millis max_age = 0; + grpc_millis stale_age = 0; + int64_t cache_size_bytes = 0; + std::string default_target; + }; + + RlsLbConfig(RouteLookupConfig route_lookup_config, Json child_policy_config, + std::string child_policy_config_target_field_name, + RefCountedPtr + default_child_policy_parsed_config) + : route_lookup_config_(std::move(route_lookup_config)), + child_policy_config_(std::move(child_policy_config)), + child_policy_config_target_field_name_( + std::move(child_policy_config_target_field_name)), + default_child_policy_parsed_config_( + std::move(default_child_policy_parsed_config)) {} + + const char* name() const override { return kRls; } + + const KeyBuilderMap& key_builder_map() const { + return route_lookup_config_.key_builder_map; + } + const std::string& lookup_service() const { + return route_lookup_config_.lookup_service; + } + grpc_millis lookup_service_timeout() const { + return route_lookup_config_.lookup_service_timeout; + } + grpc_millis max_age() const { return route_lookup_config_.max_age; } + grpc_millis stale_age() const { return route_lookup_config_.stale_age; } + int64_t cache_size_bytes() const { + return route_lookup_config_.cache_size_bytes; + } + const std::string& default_target() const { + return route_lookup_config_.default_target; + } + const Json& child_policy_config() const { return child_policy_config_; } + const std::string& child_policy_config_target_field_name() const { + return child_policy_config_target_field_name_; + } + RefCountedPtr + default_child_policy_parsed_config() const { + return default_child_policy_parsed_config_; + } + + private: + RouteLookupConfig route_lookup_config_; + Json child_policy_config_; + std::string child_policy_config_target_field_name_; + RefCountedPtr + default_child_policy_parsed_config_; +}; + +// RLS LB policy. +class RlsLb : public LoadBalancingPolicy { + public: + explicit RlsLb(Args args); + + const char* name() const override { return kRls; } + void UpdateLocked(UpdateArgs args) override; + void ExitIdleLocked() override; + void ResetBackoffLocked() override; + + private: + // Key to access entries in the cache and the request map. + struct RequestKey { + std::map key_map; + + bool operator==(const RequestKey& rhs) const { + return key_map == rhs.key_map; + } + + template + friend H AbslHashValue(H h, const RequestKey& key) { + std::hash string_hasher; + for (auto& kv : key.key_map) { + h = H::combine(std::move(h), string_hasher(kv.first), + string_hasher(kv.second)); + } + return h; + } + + size_t Size() const { + size_t size = sizeof(RequestKey); + for (auto& kv : key_map) { + size += kv.first.length() + kv.second.length(); + } + return size; + } + + std::string ToString() const { + return absl::StrCat( + "{", absl::StrJoin(key_map, ",", absl::PairFormatter("=")), "}"); + } + }; + + // Data from an RLS response. 
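+  // For example (hypothetical values): status=OK, targets=["backend-a:443"],
+  // plus opaque header_data that the picker attaches to routed requests as
+  // the "X-Google-RLS-Data" header and that is echoed back to the RLS server
+  // when refreshing a stale entry.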
+ struct ResponseInfo { + absl::Status status; + std::vector targets; + std::string header_data; + + std::string ToString() const { + return absl::StrFormat("{status=%s, targets=[%s], header_data=\"%s\"}", + status.ToString(), absl::StrJoin(targets, ","), + header_data); + } + }; + + // Wraps a child policy for a given RLS target. + class ChildPolicyWrapper : public DualRefCounted { + public: + ChildPolicyWrapper(RefCountedPtr lb_policy, std::string target); + + // Note: We are forced to disable lock analysis here because + // Orphan() is called by OrphanablePtr<>, which cannot have lock + // annotations for this particular caller. + void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; + + const std::string& target() const { return target_; } + + PickResult Pick(PickArgs args) ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return picker_->Pick(args); + } + + // Updates for the child policy are handled in two phases: + // 1. In StartUpdate(), we parse and validate the new child policy + // config and store the parsed config. + // 2. In MaybeFinishUpdate(), we actually pass the parsed config to the + // child policy's UpdateLocked() method. + // + // The reason we do this is to avoid deadlocks. In StartUpdate(), + // if the new config fails to validate, then we need to set + // picker_ to an instance that will fail all requests, which + // requires holding the lock. However, we cannot call the child + // policy's UpdateLocked() method from MaybeFinishUpdate() while + // holding the lock, since that would cause a deadlock: the child's + // UpdateLocked() will call the helper's UpdateState() method, which + // will try to acquire the lock to set picker_. So StartUpdate() is + // called while we are still holding the lock, but MaybeFinishUpdate() + // is called after releasing it. + // + // Both methods grab the data they need from the parent object. + void StartUpdate() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + // Does not take ownership of channel_args. + void MaybeFinishUpdate() ABSL_LOCKS_EXCLUDED(&RlsLb::mu_); + + void ExitIdleLocked() { + if (child_policy_ != nullptr) child_policy_->ExitIdleLocked(); + } + + void ResetBackoffLocked() { + if (child_policy_ != nullptr) child_policy_->ResetBackoffLocked(); + } + + // Gets the connectivity state of the child policy. Once the child policy + // reports TRANSIENT_FAILURE, the function will always return + // TRANSIENT_FAILURE state instead of the actual state of the child policy + // until the child policy reports another READY state. + grpc_connectivity_state connectivity_state() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return connectivity_state_; + } + + private: + // ChannelControlHelper object that allows the child policy to update state + // with the wrapper. 
+ class ChildPolicyHelper : public LoadBalancingPolicy::ChannelControlHelper { + public: + explicit ChildPolicyHelper(WeakRefCountedPtr wrapper) + : wrapper_(std::move(wrapper)) {} + ~ChildPolicyHelper() override { + wrapper_.reset(DEBUG_LOCATION, "ChildPolicyHelper"); + } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, + const absl::Status& status, + std::unique_ptr picker) override; + void RequestReresolution() override; + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + WeakRefCountedPtr wrapper_; + }; + + RefCountedPtr lb_policy_; + std::string target_; + + bool is_shutdown_ = false; + + OrphanablePtr child_policy_; + RefCountedPtr pending_config_; + + grpc_connectivity_state connectivity_state_ ABSL_GUARDED_BY(&RlsLb::mu_) = + GRPC_CHANNEL_IDLE; + std::unique_ptr picker_ + ABSL_GUARDED_BY(&RlsLb::mu_); + }; + + // A picker that uses the cache and the request map in the LB policy + // (synchronized via a mutex) to determine how to route requests. + class Picker : public LoadBalancingPolicy::SubchannelPicker { + public: + explicit Picker(RefCountedPtr lb_policy); + ~Picker() override; + + PickResult Pick(PickArgs args) override; + + private: + RefCountedPtr lb_policy_; + RefCountedPtr config_; + RefCountedPtr default_child_policy_; + }; + + // An LRU cache with adjustable size. + class Cache { + public: + using Iterator = std::list::iterator; + + class Entry : public InternallyRefCounted { + public: + Entry(RefCountedPtr lb_policy, const RequestKey& key); + + // Notify the entry when it's evicted from the cache. Performs shut down. + // Note: We are forced to disable lock analysis here because + // Orphan() is called by OrphanablePtr<>, which cannot have lock + // annotations for this particular caller. + void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; + + const absl::Status& status() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return status_; + } + grpc_millis backoff_time() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return backoff_time_; + } + grpc_millis backoff_expiration_time() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return backoff_expiration_time_; + } + grpc_millis data_expiration_time() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return data_expiration_time_; + } + const std::string& header_data() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return header_data_; + } + grpc_millis stale_time() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return stale_time_; + } + grpc_millis min_expiration_time() const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return min_expiration_time_; + } + + std::unique_ptr TakeBackoffState() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return std::move(backoff_state_); + } + + // Cache size of entry. + size_t Size() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Pick subchannel for request based on the entry's state. + PickResult Pick(PickArgs args) ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // If the cache entry is in backoff state, resets the backoff and, if + // applicable, its backoff timer. The method does not update the LB + // policy's picker; the caller is responsible for that if necessary. + void ResetBackoff() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Check if the entry should be removed by the clean-up timer. 
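+      // (An entry qualifies once both its data and its backoff state have
+      // expired; see the implementation further below.)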
+ bool ShouldRemove() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Check if the entry can be evicted from the cache, i.e. the + // min_expiration_time_ has passed. + bool CanEvict() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Updates the entry upon reception of a new RLS response. + // Returns a list of child policy wrappers on which FinishUpdate() + // needs to be called after releasing the lock. + std::vector OnRlsResponseLocked( + ResponseInfo response, std::unique_ptr backoff_state) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Moves entry to the end of the LRU list. + void MarkUsed() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + private: + class BackoffTimer : public InternallyRefCounted { + public: + BackoffTimer(RefCountedPtr entry, grpc_millis backoff_time); + + // Note: We are forced to disable lock analysis here because + // Orphan() is called by OrphanablePtr<>, which cannot have lock + // annotations for this particular caller. + void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; + + private: + static void OnBackoffTimer(void* args, grpc_error_handle error); + + RefCountedPtr entry_; + bool armed_ ABSL_GUARDED_BY(&RlsLb::mu_) = true; + grpc_timer backoff_timer_; + grpc_closure backoff_timer_callback_; + }; + + RefCountedPtr lb_policy_; + + bool is_shutdown_ ABSL_GUARDED_BY(&RlsLb::mu_) = false; + + // Backoff states + absl::Status status_ ABSL_GUARDED_BY(&RlsLb::mu_); + std::unique_ptr backoff_state_ ABSL_GUARDED_BY(&RlsLb::mu_); + grpc_millis backoff_time_ ABSL_GUARDED_BY(&RlsLb::mu_) = + GRPC_MILLIS_INF_PAST; + grpc_millis backoff_expiration_time_ ABSL_GUARDED_BY(&RlsLb::mu_) = + GRPC_MILLIS_INF_PAST; + OrphanablePtr backoff_timer_; + + // RLS response states + std::vector> child_policy_wrappers_ + ABSL_GUARDED_BY(&RlsLb::mu_); + std::string header_data_ ABSL_GUARDED_BY(&RlsLb::mu_); + grpc_millis data_expiration_time_ ABSL_GUARDED_BY(&RlsLb::mu_) = + GRPC_MILLIS_INF_PAST; + grpc_millis stale_time_ ABSL_GUARDED_BY(&RlsLb::mu_) = + GRPC_MILLIS_INF_PAST; + + grpc_millis min_expiration_time_ ABSL_GUARDED_BY(&RlsLb::mu_); + Cache::Iterator lru_iterator_ ABSL_GUARDED_BY(&RlsLb::mu_); + }; + + explicit Cache(RlsLb* lb_policy); + + // Finds an entry from the cache that corresponds to a key. If an entry is + // not found, nullptr is returned. Otherwise, the entry is considered + // recently used and its order in the LRU list of the cache is updated. + Entry* Find(const RequestKey& key) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Finds an entry from the cache that corresponds to a key. If an entry is + // not found, an entry is created, inserted in the cache, and returned to + // the caller. Otherwise, the entry found is returned to the caller. The + // entry returned to the user is considered recently used and its order in + // the LRU list of the cache is updated. + Entry* FindOrInsert(const RequestKey& key) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Resizes the cache. If the new cache size is greater than the current size + // of the cache, do nothing. Otherwise, evict the oldest entries that + // exceed the new size limit of the cache. + void Resize(size_t bytes) ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Resets backoff of all the cache entries. + void ResetAllBackoff() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Shutdown the cache; clean-up and orphan all the stored cache entries. 
+ void Shutdown() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + private: + static void OnCleanupTimer(void* arg, grpc_error_handle error); + + // Returns the entry size for a given key. + static size_t EntrySizeForKey(const RequestKey& key); + + // Evicts oversized cache elements when the current size is greater than + // the specified limit. + void MaybeShrinkSize(size_t bytes) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + RlsLb* lb_policy_; + + size_t size_limit_ ABSL_GUARDED_BY(&RlsLb::mu_) = 0; + size_t size_ ABSL_GUARDED_BY(&RlsLb::mu_) = 0; + + std::list lru_list_ ABSL_GUARDED_BY(&RlsLb::mu_); + std::unordered_map, absl::Hash> + map_ ABSL_GUARDED_BY(&RlsLb::mu_); + grpc_timer cleanup_timer_; + grpc_closure timer_callback_; + }; + + // Channel for communicating with the RLS server. + // Contains throttling logic for RLS requests. + class RlsChannel : public InternallyRefCounted { + public: + RlsChannel(RefCountedPtr lb_policy, const std::string& target, + const grpc_channel_args* parent_channel_args); + + // Shuts down the channel. + void Orphan() override; + + // Starts an RLS call. + // If stale_entry is non-null, it points to the entry containing + // stale data for the key. + void StartRlsCall(const RequestKey& key, Cache::Entry* stale_entry) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Reports the result of an RLS call to the throttle. + void ReportResponseLocked(bool response_succeeded) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + // Checks if a proposed RLS call should be throttled. + bool ShouldThrottle() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + return throttle_.ShouldThrottle(); + } + + // Resets the channel's backoff. + void ResetBackoff(); + + grpc_channel* channel() const { return channel_; } + + private: + // Watches the state of the RLS channel. Notifies the LB policy when + // the channel was previously in TRANSIENT_FAILURE and then becomes READY. + class StateWatcher : public AsyncConnectivityStateWatcherInterface { + public: + explicit StateWatcher(RefCountedPtr rls_channel) + : AsyncConnectivityStateWatcherInterface( + rls_channel->lb_policy_->work_serializer()), + rls_channel_(std::move(rls_channel)) {} + + private: + void OnConnectivityStateChange(grpc_connectivity_state new_state, + const absl::Status& status) override; + + RefCountedPtr rls_channel_; + bool was_transient_failure_ = false; + }; + + // Throttle state for RLS requests. + class Throttle { + public: + explicit Throttle(int window_size_seconds = 0, + double ratio_for_successes = 0, int paddings = 0); + + bool ShouldThrottle() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + void RegisterResponse(bool success) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_); + + private: + grpc_millis window_size_; + double ratio_for_successes_; + int paddings_; + + // Logged timestamp of requests. + std::deque requests_ ABSL_GUARDED_BY(&RlsLb::mu_); + + // Logged timestamp of responses that were successful. + std::deque successes_ ABSL_GUARDED_BY(&RlsLb::mu_); + }; + + RefCountedPtr lb_policy_; + bool is_shutdown_ = false; + + grpc_channel* channel_ = nullptr; + RefCountedPtr parent_channelz_node_; + StateWatcher* watcher_ = nullptr; + Throttle throttle_ ABSL_GUARDED_BY(&RlsLb::mu_); + }; + + // A pending RLS request. Instances will be tracked in request_map_. + class RlsRequest : public InternallyRefCounted { + public: + // Asynchronously starts a call on rls_channel for key. + // Stores backoff_state, which will be transferred to the data cache + // if the RLS request fails. 
+ RlsRequest(RefCountedPtr lb_policy, RlsLb::RequestKey key, + RefCountedPtr rls_channel, + std::unique_ptr backoff_state, + grpc_lookup_v1_RouteLookupRequest_Reason reason, + std::string stale_header_data); + ~RlsRequest() override; + + // Shuts down the request. If the request is still in flight, it is + // cancelled, in which case no response will be added to the cache. + void Orphan() override; + + private: + // Callback to be invoked to start the call. + static void StartCall(void* arg, grpc_error_handle error); + + // Helper for StartCall() that runs within the WorkSerializer. + void StartCallLocked(); + + // Callback to be invoked when the call is completed. + static void OnRlsCallComplete(void* arg, grpc_error_handle error); + + // Call completion callback running on LB policy WorkSerializer. + void OnRlsCallCompleteLocked(grpc_error_handle error); + + grpc_byte_buffer* MakeRequestProto(); + ResponseInfo ParseResponseProto(); + + RefCountedPtr lb_policy_; + RlsLb::RequestKey key_; + RefCountedPtr rls_channel_; + std::unique_ptr backoff_state_; + grpc_lookup_v1_RouteLookupRequest_Reason reason_; + std::string stale_header_data_; + + // RLS call state. + grpc_millis deadline_; + grpc_closure call_start_cb_; + grpc_closure call_complete_cb_; + grpc_call* call_ = nullptr; + grpc_byte_buffer* send_message_ = nullptr; + grpc_metadata_array recv_initial_metadata_; + grpc_byte_buffer* recv_message_ = nullptr; + grpc_metadata_array recv_trailing_metadata_; + grpc_status_code status_recv_; + grpc_slice status_details_recv_; + }; + + void ShutdownLocked() override; + + // Returns a new picker to the channel to trigger reprocessing of + // pending picks. Schedules the actual picker update on the ExecCtx + // to be run later, so it's safe to invoke this while holding the lock. + void UpdatePickerAsync(); + // Hops into work serializer and calls UpdatePickerLocked(). + static void UpdatePickerCallback(void* arg, grpc_error_handle error); + // Updates the picker in the work serializer. + void UpdatePickerLocked() ABSL_LOCKS_EXCLUDED(&mu_); + + // The name of the server for the channel. + std::string server_name_; + + // Mutex to guard LB policy state that is accessed by the picker. + Mutex mu_; + bool is_shutdown_ ABSL_GUARDED_BY(mu_) = false; + Cache cache_ ABSL_GUARDED_BY(mu_); + // Maps an RLS request key to an RlsRequest object that represents a pending + // RLS request. + std::unordered_map, + absl::Hash> + request_map_ ABSL_GUARDED_BY(mu_); + // The channel on which RLS requests are sent. + // Note that this channel may be swapped out when the RLS policy gets + // an update. However, when that happens, any existing entries in + // request_map_ will continue to use the previous channel. + OrphanablePtr rls_channel_ ABSL_GUARDED_BY(mu_); + + // Accessed only from within WorkSerializer. + ServerAddressList addresses_; + const grpc_channel_args* channel_args_ = nullptr; + RefCountedPtr config_; + RefCountedPtr default_child_policy_; + std::map child_policy_map_; +}; + +// +// RlsLb::ChildPolicyWrapper +// + +RlsLb::ChildPolicyWrapper::ChildPolicyWrapper(RefCountedPtr lb_policy, + std::string target) + : DualRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace) ? 
"ChildPolicyWrapper" + : nullptr), + lb_policy_(lb_policy), + target_(std::move(target)), + picker_(absl::make_unique(std::move(lb_policy))) { + lb_policy_->child_policy_map_.emplace(target_, this); +} + +void RlsLb::ChildPolicyWrapper::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] ChildPolicyWrapper=%p [%s]: shutdown", + lb_policy_.get(), this, target_.c_str()); + } + is_shutdown_ = true; + lb_policy_->child_policy_map_.erase(target_); + if (child_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(), + lb_policy_->interested_parties()); + child_policy_.reset(); + } + picker_.reset(); +} + +grpc_error_handle InsertOrUpdateChildPolicyField(const std::string& field, + const std::string& value, + Json* config) { + if (config->type() != Json::Type::ARRAY) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "child policy configuration is not an array"); + } + std::vector error_list; + for (Json& child_json : *config->mutable_array()) { + if (child_json.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "child policy item is not an object")); + } else { + Json::Object& child = *child_json.mutable_object(); + if (child.size() != 1) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "child policy item contains more than one field")); + } else { + Json& child_config_json = child.begin()->second; + if (child_config_json.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "child policy item config is not an object")); + } else { + Json::Object& child_config = *child_config_json.mutable_object(); + child_config[field] = Json(value); + } + } + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("errors when inserting field \"", field, + "\" for child policy"), + &error_list); +} + +void RlsLb::ChildPolicyWrapper::StartUpdate() { + Json child_policy_config = lb_policy_->config_->child_policy_config(); + grpc_error_handle error = InsertOrUpdateChildPolicyField( + lb_policy_->config_->child_policy_config_target_field_name(), target_, + &child_policy_config); + GPR_ASSERT(error == GRPC_ERROR_NONE); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log( + GPR_INFO, + "[rlslb %p] ChildPolicyWrapper=%p [%s]: validating update, config: %s", + lb_policy_.get(), this, target_.c_str(), + child_policy_config.Dump().c_str()); + } + pending_config_ = LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + child_policy_config, &error); + // Returned RLS target fails the validation. + if (error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] ChildPolicyWrapper=%p [%s]: config failed to parse: " + "%s; config: %s", + lb_policy_.get(), this, target_.c_str(), + grpc_error_std_string(error).c_str(), + child_policy_config.Dump().c_str()); + } + pending_config_.reset(); + picker_ = absl::make_unique( + grpc_error_to_absl_status(error)); + GRPC_ERROR_UNREF(error); + child_policy_.reset(); + } +} + +void RlsLb::ChildPolicyWrapper::MaybeFinishUpdate() { + // If pending_config_ is not set, that means StartUpdate() failed, so + // there's nothing to do here. + if (pending_config_ == nullptr) return; + // If child policy doesn't yet exist, create it. 
+ if (child_policy_ == nullptr) { + Args create_args; + create_args.work_serializer = lb_policy_->work_serializer(); + create_args.channel_control_helper = absl::make_unique( + WeakRef(DEBUG_LOCATION, "ChildPolicyHelper")); + create_args.args = lb_policy_->channel_args_; + child_policy_ = MakeOrphanable(std::move(create_args), + &grpc_lb_rls_trace); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] ChildPolicyWrapper=%p [%s], created new child policy " + "handler %p", + lb_policy_.get(), this, target_.c_str(), child_policy_.get()); + } + grpc_pollset_set_add_pollset_set(child_policy_->interested_parties(), + lb_policy_->interested_parties()); + } + // Send the child the updated config. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] ChildPolicyWrapper=%p [%s], updating child policy " + "handler %p", + lb_policy_.get(), this, target_.c_str(), child_policy_.get()); + } + UpdateArgs update_args; + update_args.config = std::move(pending_config_); + update_args.addresses = lb_policy_->addresses_; + update_args.args = grpc_channel_args_copy(lb_policy_->channel_args_); + child_policy_->UpdateLocked(std::move(update_args)); +} + +// +// RlsLb::ChildPolicyWrapper::ChildPolicyHelper +// + +RefCountedPtr +RlsLb::ChildPolicyWrapper::ChildPolicyHelper::CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] ChildPolicyWrapper=%p [%s] ChildPolicyHelper=%p: " + "CreateSubchannel() for %s", + wrapper_->lb_policy_.get(), wrapper_.get(), + wrapper_->target_.c_str(), this, address.ToString().c_str()); + } + if (wrapper_->is_shutdown_) return nullptr; + return wrapper_->lb_policy_->channel_control_helper()->CreateSubchannel( + std::move(address), args); +} + +void RlsLb::ChildPolicyWrapper::ChildPolicyHelper::UpdateState( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] ChildPolicyWrapper=%p [%s] ChildPolicyHelper=%p: " + "UpdateState(state=%s, status=%s, picker=%p)", + wrapper_->lb_policy_.get(), wrapper_.get(), + wrapper_->target_.c_str(), this, ConnectivityStateName(state), + status.ToString().c_str(), picker.get()); + } + { + MutexLock lock(&wrapper_->lb_policy_->mu_); + if (wrapper_->is_shutdown_) return; + if (wrapper_->connectivity_state_ == GRPC_CHANNEL_TRANSIENT_FAILURE && + state != GRPC_CHANNEL_READY) { + return; + } + wrapper_->connectivity_state_ = state; + GPR_DEBUG_ASSERT(picker != nullptr); + if (picker != nullptr) { + wrapper_->picker_ = std::move(picker); + } + } + wrapper_->lb_policy_->UpdatePickerLocked(); +} + +void RlsLb::ChildPolicyWrapper::ChildPolicyHelper::RequestReresolution() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] ChildPolicyWrapper=%p [%s] ChildPolicyHelper=%p: " + "RequestReresolution", + wrapper_->lb_policy_.get(), wrapper_.get(), + wrapper_->target_.c_str(), this); + } + if (wrapper_->is_shutdown_) return; + wrapper_->lb_policy_->channel_control_helper()->RequestReresolution(); +} + +absl::string_view RlsLb::ChildPolicyWrapper::ChildPolicyHelper::GetAuthority() { + return wrapper_->lb_policy_->channel_control_helper()->GetAuthority(); +} + +void RlsLb::ChildPolicyWrapper::ChildPolicyHelper::AddTraceEvent( + TraceSeverity severity, absl::string_view message) { + if (wrapper_->is_shutdown_) return; + 
wrapper_->lb_policy_->channel_control_helper()->AddTraceEvent(severity, + message); +} + +// +// RlsLb::Picker +// + +// Builds the key to be used for a request based on path and initial_metadata. +std::map BuildKeyMap( + const RlsLbConfig::KeyBuilderMap& key_builder_map, absl::string_view path, + const std::string& host, + const LoadBalancingPolicy::MetadataInterface* initial_metadata) { + size_t last_slash_pos = path.npos; // May need this a few times, so cache it. + // Find key builder for this path. + auto it = key_builder_map.find(std::string(path)); + if (it == key_builder_map.end()) { + // Didn't find exact match, try method wildcard. + last_slash_pos = path.rfind("/"); + GPR_DEBUG_ASSERT(last_slash_pos != path.npos); + if (GPR_UNLIKELY(last_slash_pos == path.npos)) return {}; + std::string service(path.substr(0, last_slash_pos + 1)); + it = key_builder_map.find(service); + if (it == key_builder_map.end()) return {}; + } + const RlsLbConfig::KeyBuilder* key_builder = &it->second; + // Construct key map using key builder. + std::map key_map; + // Add header keys. + for (const auto& p : key_builder->header_keys) { + const std::string& key = p.first; + const std::vector& header_names = p.second; + for (const std::string& header_name : header_names) { + std::string buffer; + absl::optional value = + initial_metadata->Lookup(header_name, &buffer); + if (value.has_value()) { + key_map[key] = std::string(*value); + break; + } + } + } + // Add constant keys. + key_map.insert(key_builder->constant_keys.begin(), + key_builder->constant_keys.end()); + // Add host key. + if (!key_builder->host_key.empty()) { + key_map[key_builder->host_key] = host; + } + // Add service key. + if (!key_builder->service_key.empty()) { + if (last_slash_pos == path.npos) { + last_slash_pos = path.rfind("/"); + GPR_DEBUG_ASSERT(last_slash_pos != path.npos); + if (GPR_UNLIKELY(last_slash_pos == path.npos)) return {}; + } + key_map[key_builder->service_key] = + std::string(path.substr(1, last_slash_pos - 1)); + } + // Add method key. + if (!key_builder->method_key.empty()) { + if (last_slash_pos == path.npos) { + last_slash_pos = path.rfind("/"); + GPR_DEBUG_ASSERT(last_slash_pos != path.npos); + if (GPR_UNLIKELY(last_slash_pos == path.npos)) return {}; + } + key_map[key_builder->method_key] = + std::string(path.substr(last_slash_pos + 1)); + } + return key_map; +} + +RlsLb::Picker::Picker(RefCountedPtr lb_policy) + : lb_policy_(std::move(lb_policy)), config_(lb_policy_->config_) { + if (lb_policy_->default_child_policy_ != nullptr) { + default_child_policy_ = + lb_policy_->default_child_policy_->Ref(DEBUG_LOCATION, "Picker"); + } +} + +RlsLb::Picker::~Picker() { + // It's not safe to unref the default child policy in the picker, + // since that needs to be done in the WorkSerializer. + if (default_child_policy_ != nullptr) { + auto* default_child_policy = default_child_policy_.release(); + lb_policy_->work_serializer()->Run( + [default_child_policy]() { + default_child_policy->Unref(DEBUG_LOCATION, "Picker"); + }, + DEBUG_LOCATION); + } +} + +LoadBalancingPolicy::PickResult RlsLb::Picker::Pick(PickArgs args) { + // Construct key for request. 
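+  // Illustrative example (hypothetical key builder): for path
+  // "/pkg.Service/Method" with service_key="svc" and method_key="m", the
+  // resulting key map would be {m=Method, svc=pkg.Service}.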
+ RequestKey key = {BuildKeyMap(config_->key_builder_map(), args.path, + lb_policy_->server_name_, + args.initial_metadata)}; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] picker=%p: request keys: %s", + lb_policy_.get(), this, key.ToString().c_str()); + } + grpc_millis now = ExecCtx::Get()->Now(); + MutexLock lock(&lb_policy_->mu_); + if (lb_policy_->is_shutdown_) { + return PickResult::Fail( + absl::UnavailableError("LB policy already shut down")); + } + // Check if there's a cache entry. + Cache::Entry* entry = lb_policy_->cache_.Find(key); + // If there is no cache entry, or if the cache entry is not in backoff + // and has a stale time in the past, and there is not already a + // pending RLS request for this key, then try to start a new RLS request. + if ((entry == nullptr || + (entry->stale_time() < now && entry->backoff_time() < now)) && + lb_policy_->request_map_.find(key) == lb_policy_->request_map_.end()) { + // Check if requests are being throttled. + if (lb_policy_->rls_channel_->ShouldThrottle()) { + // Request is throttled. + // If there is no non-expired data in the cache, then we use the + // default target if set, or else we fail the pick. + if (entry == nullptr || entry->data_expiration_time() < now) { + if (default_child_policy_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] picker=%p: RLS call throttled; " + "using default target", + lb_policy_.get(), this); + } + return default_child_policy_->Pick(args); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] picker=%p: RLS call throttled; failing pick", + lb_policy_.get(), this); + } + return PickResult::Fail( + absl::UnavailableError("RLS request throttled")); + } + } + // Start the RLS call. + lb_policy_->rls_channel_->StartRlsCall( + key, (entry == nullptr || entry->data_expiration_time() < now) ? nullptr + : entry); + } + // If the cache entry exists, see if it has usable data. + if (entry != nullptr) { + // If the entry has non-expired data, use it. + if (entry->data_expiration_time() >= now) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] picker=%p: using cache entry %p", + lb_policy_.get(), this, entry); + } + return entry->Pick(args); + } + // If the entry is in backoff, then use the default target if set, + // or else fail the pick. + if (entry->backoff_time() >= now) { + if (default_child_policy_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log( + GPR_INFO, + "[rlslb %p] picker=%p: RLS call in backoff; using default target", + lb_policy_.get(), this); + } + return default_child_policy_->Pick(args); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] picker=%p: RLS call in backoff; failing pick", + lb_policy_.get(), this); + } + return PickResult::Fail(entry->status()); + } + } + // RLS call pending. Queue the pick. 
+ if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] picker=%p: RLS request pending; queuing pick", + lb_policy_.get(), this); + } + return PickResult::Queue(); +} + +// +// RlsLb::Cache::Entry::BackoffTimer +// + +RlsLb::Cache::Entry::BackoffTimer::BackoffTimer(RefCountedPtr entry, + grpc_millis backoff_time) + : entry_(std::move(entry)) { + GRPC_CLOSURE_INIT(&backoff_timer_callback_, OnBackoffTimer, this, nullptr); + Ref(DEBUG_LOCATION, "BackoffTimer").release(); + grpc_timer_init(&backoff_timer_, backoff_time, &backoff_timer_callback_); +} + +void RlsLb::Cache::Entry::BackoffTimer::Orphan() { + if (armed_) { + armed_ = false; + grpc_timer_cancel(&backoff_timer_); + } + Unref(DEBUG_LOCATION, "Orphan"); +} + +void RlsLb::Cache::Entry::BackoffTimer::OnBackoffTimer( + void* arg, grpc_error_handle /*error*/) { + auto* self = static_cast(arg); + self->entry_->lb_policy_->work_serializer()->Run( + [self]() { + RefCountedPtr backoff_timer(self); + { + MutexLock lock(&self->entry_->lb_policy_->mu_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] cache entry=%p %s, armed_=%d: " + "backoff timer fired", + self->entry_->lb_policy_.get(), self->entry_.get(), + self->entry_->is_shutdown_ + ? "(shut down)" + : self->entry_->lru_iterator_->ToString().c_str(), + self->armed_); + } + bool cancelled = !self->armed_; + self->armed_ = false; + if (cancelled) return; + } + // The pick was in backoff state and there could be a pick queued if + // wait_for_ready is true. We'll update the picker for that case. + self->entry_->lb_policy_->UpdatePickerLocked(); + }, + DEBUG_LOCATION); +} + +// +// RlsLb::Cache::Entry +// + +std::unique_ptr MakeCacheEntryBackoff() { + return absl::make_unique( + BackOff::Options() + .set_initial_backoff(kCacheBackoffInitial) + .set_multiplier(kCacheBackoffMultiplier) + .set_jitter(kCacheBackoffJitter) + .set_max_backoff(kCacheBackoffMax)); +} + +RlsLb::Cache::Entry::Entry(RefCountedPtr lb_policy, + const RequestKey& key) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace) ? "CacheEntry" : nullptr), + lb_policy_(std::move(lb_policy)), + backoff_state_(MakeCacheEntryBackoff()), + min_expiration_time_(ExecCtx::Get()->Now() + kMinExpirationTime), + lru_iterator_(lb_policy_->cache_.lru_list_.insert( + lb_policy_->cache_.lru_list_.end(), key)) {} + +void RlsLb::Cache::Entry::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] cache entry=%p %s: cache entry evicted", + lb_policy_.get(), this, lru_iterator_->ToString().c_str()); + } + is_shutdown_ = true; + lb_policy_->cache_.lru_list_.erase(lru_iterator_); + lru_iterator_ = lb_policy_->cache_.lru_list_.end(); // Just in case. + backoff_state_.reset(); + if (backoff_timer_ != nullptr) { + backoff_timer_.reset(); + lb_policy_->UpdatePickerAsync(); + } + child_policy_wrappers_.clear(); + Unref(DEBUG_LOCATION, "Orphan"); +} + +size_t RlsLb::Cache::Entry::Size() const { + // lru_iterator_ is not valid once we're shut down. 
+ GPR_ASSERT(!is_shutdown_); + return lb_policy_->cache_.EntrySizeForKey(*lru_iterator_); +} + +LoadBalancingPolicy::PickResult RlsLb::Cache::Entry::Pick(PickArgs args) { + for (const auto& child_policy_wrapper : child_policy_wrappers_) { + if (child_policy_wrapper->connectivity_state() == + GRPC_CHANNEL_TRANSIENT_FAILURE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] cache entry=%p %s: target %s in state " + "TRANSIENT_FAILURE; skipping", + lb_policy_.get(), this, lru_iterator_->ToString().c_str(), + child_policy_wrapper->target().c_str()); + } + continue; + } + // Child policy not in TRANSIENT_FAILURE, so delegate. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log( + GPR_INFO, + "[rlslb %p] cache entry=%p %s: target %s in state %s; " + "delegating", + lb_policy_.get(), this, lru_iterator_->ToString().c_str(), + child_policy_wrapper->target().c_str(), + ConnectivityStateName(child_policy_wrapper->connectivity_state())); + } + // Add header data. + if (!header_data_.empty()) { + char* copied_header_data = + static_cast(args.call_state->Alloc(header_data_.length() + 1)); + strcpy(copied_header_data, header_data_.c_str()); + args.initial_metadata->Add(kRlsHeaderKey, copied_header_data); + } + return child_policy_wrapper->Pick(args); + } + // No child policy found. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] cache entry=%p %s: no healthy target found; " + "failing pick", + lb_policy_.get(), this, lru_iterator_->ToString().c_str()); + } + return PickResult::Fail( + absl::UnavailableError("all RLS targets unreachable")); +} + +void RlsLb::Cache::Entry::ResetBackoff() { + backoff_time_ = GRPC_MILLIS_INF_PAST; + backoff_timer_.reset(); +} + +bool RlsLb::Cache::Entry::ShouldRemove() const { + grpc_millis now = ExecCtx::Get()->Now(); + return data_expiration_time_ < now && backoff_expiration_time_ < now; +} + +bool RlsLb::Cache::Entry::CanEvict() const { + grpc_millis now = ExecCtx::Get()->Now(); + return min_expiration_time_ < now; +} + +void RlsLb::Cache::Entry::MarkUsed() { + auto& lru_list = lb_policy_->cache_.lru_list_; + auto new_it = lru_list.insert(lru_list.end(), *lru_iterator_); + lru_list.erase(lru_iterator_); + lru_iterator_ = new_it; +} + +std::vector +RlsLb::Cache::Entry::OnRlsResponseLocked( + ResponseInfo response, std::unique_ptr backoff_state) { + // Move the entry to the end of the LRU list. + MarkUsed(); + // If the request failed, store the failed status and update the + // backoff state. + if (!response.status.ok()) { + status_ = response.status; + if (backoff_state != nullptr) { + backoff_state_ = std::move(backoff_state); + } else { + backoff_state_ = MakeCacheEntryBackoff(); + } + backoff_time_ = backoff_state_->NextAttemptTime(); + grpc_millis now = ExecCtx::Get()->Now(); + backoff_expiration_time_ = now + (backoff_time_ - now) * 2; + backoff_timer_ = MakeOrphanable( + Ref(DEBUG_LOCATION, "BackoffTimer"), backoff_time_); + lb_policy_->UpdatePickerAsync(); + return {}; + } + // Request succeeded, so store the result. + header_data_ = std::move(response.header_data); + grpc_millis now = ExecCtx::Get()->Now(); + data_expiration_time_ = now + lb_policy_->config_->max_age(); + stale_time_ = now + lb_policy_->config_->stale_age(); + status_ = absl::OkStatus(); + backoff_state_.reset(); + backoff_time_ = GRPC_MILLIS_INF_PAST; + backoff_expiration_time_ = GRPC_MILLIS_INF_PAST; + // Check if we need to update this list of targets. 
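+  // (Order matters: the same targets in a different order count as a change,
+  // since Pick() prefers earlier entries in the target list.)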
+ bool targets_changed = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&RlsLb::mu_) { + if (child_policy_wrappers_.size() != response.targets.size()) return true; + for (size_t i = 0; i < response.targets.size(); ++i) { + if (child_policy_wrappers_[i]->target() != response.targets[i]) { + return true; + } + } + return false; + }(); + if (!targets_changed) { + // Targets didn't change, so we're not updating the list of child + // policies. Return a new picker so that any queued requests can be + // re-processed. + lb_policy_->UpdatePickerAsync(); + return {}; + } + // Target list changed, so update it. + std::set old_targets; + for (RefCountedPtr& child_policy_wrapper : + child_policy_wrappers_) { + old_targets.emplace(child_policy_wrapper->target()); + } + bool update_picker = false; + std::vector child_policies_to_finish_update; + std::vector> new_child_policy_wrappers; + new_child_policy_wrappers.reserve(response.targets.size()); + for (std::string& target : response.targets) { + auto it = lb_policy_->child_policy_map_.find(target); + if (it == lb_policy_->child_policy_map_.end()) { + auto new_child = MakeRefCounted( + lb_policy_->Ref(DEBUG_LOCATION, "ChildPolicyWrapper"), target); + new_child->StartUpdate(); + child_policies_to_finish_update.push_back(new_child.get()); + new_child_policy_wrappers.emplace_back(std::move(new_child)); + } else { + new_child_policy_wrappers.emplace_back( + it->second->Ref(DEBUG_LOCATION, "CacheEntry")); + // If the target already existed but was not previously used for + // this key, then we'll need to update the picker, since we + // didn't actually create a new child policy, which would have + // triggered an RLS picker update when it returned its first picker. + if (old_targets.find(target) == old_targets.end()) { + update_picker = true; + } + } + } + child_policy_wrappers_ = std::move(new_child_policy_wrappers); + if (update_picker) { + lb_policy_->UpdatePickerAsync(); + } + return child_policies_to_finish_update; +} + +// +// RlsLb::Cache +// + +RlsLb::Cache::Cache(RlsLb* lb_policy) : lb_policy_(lb_policy) { + grpc_millis now = ExecCtx::Get()->Now(); + lb_policy_->Ref(DEBUG_LOCATION, "CacheCleanupTimer").release(); + GRPC_CLOSURE_INIT(&timer_callback_, OnCleanupTimer, this, nullptr); + grpc_timer_init(&cleanup_timer_, now + kCacheCleanupTimerInterval, + &timer_callback_); +} + +RlsLb::Cache::Entry* RlsLb::Cache::Find(const RequestKey& key) { + auto it = map_.find(key); + if (it == map_.end()) return nullptr; + it->second->MarkUsed(); + return it->second.get(); +} + +RlsLb::Cache::Entry* RlsLb::Cache::FindOrInsert(const RequestKey& key) { + auto it = map_.find(key); + // If not found, create new entry. + if (it == map_.end()) { + size_t entry_size = EntrySizeForKey(key); + MaybeShrinkSize(size_limit_ - std::min(size_limit_, entry_size)); + Entry* entry = + new Entry(lb_policy_->Ref(DEBUG_LOCATION, "CacheEntry"), key); + map_.emplace(key, OrphanablePtr(entry)); + size_ += entry_size; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] key=%s: cache entry added, entry=%p", + lb_policy_, key.ToString().c_str(), entry); + } + return entry; + } + // Entry found, so use it. 
+ if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] key=%s: found cache entry %p", lb_policy_, + key.ToString().c_str(), it->second.get()); + } + it->second->MarkUsed(); + return it->second.get(); +} + +void RlsLb::Cache::Resize(size_t bytes) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] resizing cache to %" PRIuPTR " bytes", + lb_policy_, bytes); + } + size_limit_ = bytes; + MaybeShrinkSize(size_limit_); +} + +void RlsLb::Cache::ResetAllBackoff() { + for (auto& p : map_) { + p.second->ResetBackoff(); + } + lb_policy_->UpdatePickerAsync(); +} + +void RlsLb::Cache::Shutdown() { + map_.clear(); + lru_list_.clear(); + grpc_timer_cancel(&cleanup_timer_); +} + +void RlsLb::Cache::OnCleanupTimer(void* arg, grpc_error_handle error) { + Cache* cache = static_cast(arg); + GRPC_ERROR_REF(error); + cache->lb_policy_->work_serializer()->Run( + [cache, error]() { + RefCountedPtr lb_policy(cache->lb_policy_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] cache cleanup timer fired (%s)", + cache->lb_policy_, grpc_error_std_string(error).c_str()); + } + if (error == GRPC_ERROR_CANCELLED) return; + MutexLock lock(&lb_policy->mu_); + if (lb_policy->is_shutdown_) return; + for (auto it = cache->map_.begin(); it != cache->map_.end();) { + if (GPR_UNLIKELY(it->second->ShouldRemove() && + it->second->CanEvict())) { + cache->size_ -= it->second->Size(); + it = cache->map_.erase(it); + } else { + ++it; + } + } + grpc_millis now = ExecCtx::Get()->Now(); + lb_policy.release(); + grpc_timer_init(&cache->cleanup_timer_, + now + kCacheCleanupTimerInterval, + &cache->timer_callback_); + }, + DEBUG_LOCATION); +} + +size_t RlsLb::Cache::EntrySizeForKey(const RequestKey& key) { + // Key is stored twice, once in LRU list and again in the cache map. 
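+  // Illustrative arithmetic: a key whose strings total 64 bytes contributes
+  // 2 * (sizeof(RequestKey) + 64) + sizeof(Entry) bytes toward the limit.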
+  return (key.Size() * 2) + sizeof(Entry);
+}
+
+void RlsLb::Cache::MaybeShrinkSize(size_t bytes) {
+  while (size_ > bytes) {
+    auto lru_it = lru_list_.begin();
+    if (GPR_UNLIKELY(lru_it == lru_list_.end())) break;
+    auto map_it = map_.find(*lru_it);
+    GPR_ASSERT(map_it != map_.end());
+    if (!map_it->second->CanEvict()) break;
+    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) {
+      gpr_log(GPR_INFO, "[rlslb %p] LRU eviction: removing entry %p %s",
+              lb_policy_, map_it->second.get(), lru_it->ToString().c_str());
+    }
+    size_ -= map_it->second->Size();
+    map_.erase(map_it);
+  }
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) {
+    gpr_log(GPR_INFO,
+            "[rlslb %p] LRU pass complete: desired size=%" PRIuPTR
+            " size=%" PRIuPTR,
+            lb_policy_, bytes, size_);
+  }
+}
+
+//
+// RlsLb::RlsChannel::StateWatcher
+//
+
+void RlsLb::RlsChannel::StateWatcher::OnConnectivityStateChange(
+    grpc_connectivity_state new_state, const absl::Status& status) {
+  auto* lb_policy = rls_channel_->lb_policy_.get();
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) {
+    gpr_log(GPR_INFO,
+            "[rlslb %p] RlsChannel=%p StateWatcher=%p: "
+            "state changed to %s (%s)",
+            lb_policy, rls_channel_.get(), this,
+            ConnectivityStateName(new_state), status.ToString().c_str());
+  }
+  if (rls_channel_->is_shutdown_) return;
+  MutexLock lock(&lb_policy->mu_);
+  if (new_state == GRPC_CHANNEL_READY && was_transient_failure_) {
+    was_transient_failure_ = false;
+    // Reset the backoff of all cache entries, so that we don't
+    // double-penalize if an RLS request fails while the channel is
+    // down, since the throttling for the channel being down is handled
+    // at the channel level instead of in the individual cache entries.
+    lb_policy->cache_.ResetAllBackoff();
+  } else if (new_state == GRPC_CHANNEL_TRANSIENT_FAILURE) {
+    was_transient_failure_ = true;
+  }
+}
+
+//
+// RlsLb::RlsChannel::Throttle
+//
+
+RlsLb::RlsChannel::Throttle::Throttle(int window_size_seconds,
+                                      double ratio_for_successes,
+                                      int paddings) {
+  GPR_DEBUG_ASSERT(window_size_seconds >= 0);
+  GPR_DEBUG_ASSERT(ratio_for_successes >= 0);
+  GPR_DEBUG_ASSERT(paddings >= 0);
+  window_size_ = window_size_seconds == 0
+                     ? kDefaultThrottleWindowSize
+                     : window_size_seconds * GPR_MS_PER_SEC;
+  ratio_for_successes_ = ratio_for_successes == 0
+                             ? kDefaultThrottleRatioForSuccesses
+                             : ratio_for_successes;
+  paddings_ = paddings == 0 ? kDefaultThrottlePaddings : paddings;
+}
+
+bool RlsLb::RlsChannel::Throttle::ShouldThrottle() {
+  grpc_millis now = ExecCtx::Get()->Now();
+  while (!requests_.empty() && now - requests_.front() > window_size_) {
+    requests_.pop_front();
+  }
+  while (!successes_.empty() && now - successes_.front() > window_size_) {
+    successes_.pop_front();
+  }
+  int successes = successes_.size();
+  int requests = requests_.size();
+  bool result = ((rand() % (requests + paddings_)) <
+                 static_cast<double>(requests) -
+                     static_cast<double>(successes) * ratio_for_successes_);
+  requests_.push_back(now);
+  return result;
+}
+
+void RlsLb::RlsChannel::Throttle::RegisterResponse(bool success) {
+  if (success) {
+    successes_.push_back(ExecCtx::Get()->Now());
+  }
+}
+
+//
+// RlsLb::RlsChannel
+//
+
+RlsLb::RlsChannel::RlsChannel(RefCountedPtr<RlsLb> lb_policy,
+                              const std::string& target,
+                              const grpc_channel_args* parent_channel_args)
+    : InternallyRefCounted<RlsChannel>(
+          GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace) ? "RlsChannel" : nullptr),
+      lb_policy_(std::move(lb_policy)) {
+  // Get channel creds from parent channel.
+ // TODO(roth): Once we eliminate insecure builds, get this via a + // method on the helper instead of digging through channel args. + grpc_channel_credentials* creds = + grpc_channel_credentials_find_in_args(parent_channel_args); + // Use the parent channel's authority. + std::string authority(lb_policy_->channel_control_helper()->GetAuthority()); + absl::InlinedVector args = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast(authority.c_str())), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_CHANNELZ_IS_INTERNAL_CHANNEL), 1), + }; + // Propagate fake security connector expected targets, if any. + // (This is ugly, but it seems better than propagating all channel args + // from the parent channel by default and then having a giant + // exclude list of args to strip out, like we do in grpclb.) + const char* fake_security_expected_targets = grpc_channel_args_find_string( + parent_channel_args, GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS); + if (fake_security_expected_targets != nullptr) { + args.push_back(grpc_channel_arg_string_create( + const_cast(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS), + const_cast(fake_security_expected_targets))); + } + grpc_channel_args rls_channel_args = {args.size(), args.data()}; + channel_ = grpc_secure_channel_create(creds, target.c_str(), + &rls_channel_args, nullptr); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] RlsChannel=%p: created channel %p for %s", + lb_policy_.get(), this, channel_, target.c_str()); + } + if (channel_ != nullptr) { + // Set up channelz linkage. + channelz::ChannelNode* child_channelz_node = + grpc_channel_get_channelz_node(channel_); + channelz::ChannelNode* parent_channelz_node = + grpc_channel_args_find_pointer( + parent_channel_args, GRPC_ARG_CHANNELZ_CHANNEL_NODE); + if (child_channelz_node != nullptr && parent_channelz_node != nullptr) { + parent_channelz_node->AddChildChannel(child_channelz_node->uuid()); + parent_channelz_node_ = parent_channelz_node->Ref(); + } + // Start connectivity watch. + ClientChannel* client_channel = ClientChannel::GetFromChannel(channel_); + GPR_ASSERT(client_channel != nullptr); + watcher_ = new StateWatcher(Ref(DEBUG_LOCATION, "StateWatcher")); + client_channel->AddConnectivityWatcher( + GRPC_CHANNEL_IDLE, + OrphanablePtr(watcher_)); + } +} + +void RlsLb::RlsChannel::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] RlsChannel=%p, channel=%p: shutdown", + lb_policy_.get(), this, channel_); + } + is_shutdown_ = true; + if (channel_ != nullptr) { + // Remove channelz linkage. + if (parent_channelz_node_ != nullptr) { + channelz::ChannelNode* child_channelz_node = + grpc_channel_get_channelz_node(channel_); + GPR_ASSERT(child_channelz_node != nullptr); + parent_channelz_node_->RemoveChildChannel(child_channelz_node->uuid()); + } + // Stop connectivity watch. 
+ if (watcher_ != nullptr) { + ClientChannel* client_channel = ClientChannel::GetFromChannel(channel_); + GPR_ASSERT(client_channel != nullptr); + client_channel->RemoveConnectivityWatcher(watcher_); + watcher_ = nullptr; + } + grpc_channel_destroy(channel_); + } + Unref(DEBUG_LOCATION, "Orphan"); +} + +void RlsLb::RlsChannel::StartRlsCall(const RequestKey& key, + Cache::Entry* stale_entry) { + std::unique_ptr backoff_state; + grpc_lookup_v1_RouteLookupRequest_Reason reason = + grpc_lookup_v1_RouteLookupRequest_REASON_MISS; + std::string stale_header_data; + if (stale_entry != nullptr) { + backoff_state = stale_entry->TakeBackoffState(); + reason = grpc_lookup_v1_RouteLookupRequest_REASON_STALE; + stale_header_data = stale_entry->header_data(); + } + lb_policy_->request_map_.emplace( + key, MakeOrphanable( + lb_policy_->Ref(DEBUG_LOCATION, "RlsRequest"), key, + lb_policy_->rls_channel_->Ref(DEBUG_LOCATION, "RlsRequest"), + std::move(backoff_state), reason, std::move(stale_header_data))); +} + +void RlsLb::RlsChannel::ReportResponseLocked(bool response_succeeded) { + throttle_.RegisterResponse(response_succeeded); +} + +void RlsLb::RlsChannel::ResetBackoff() { + GPR_DEBUG_ASSERT(channel_ != nullptr); + grpc_channel_reset_connect_backoff(channel_); +} + +// +// RlsLb::RlsRequest +// + +RlsLb::RlsRequest::RlsRequest(RefCountedPtr lb_policy, RequestKey key, + RefCountedPtr rls_channel, + std::unique_ptr backoff_state, + grpc_lookup_v1_RouteLookupRequest_Reason reason, + std::string stale_header_data) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace) ? "RlsRequest" : nullptr), + lb_policy_(std::move(lb_policy)), + key_(std::move(key)), + rls_channel_(std::move(rls_channel)), + backoff_state_(std::move(backoff_state)), + reason_(reason), + stale_header_data_(std::move(stale_header_data)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] rls_request=%p: RLS request created for key %s", + lb_policy_.get(), this, key_.ToString().c_str()); + } + GRPC_CLOSURE_INIT(&call_complete_cb_, OnRlsCallComplete, this, nullptr); + ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_INIT(&call_start_cb_, StartCall, + Ref(DEBUG_LOCATION, "StartCall").release(), nullptr), + GRPC_ERROR_NONE); +} + +RlsLb::RlsRequest::~RlsRequest() { GPR_ASSERT(call_ == nullptr); } + +void RlsLb::RlsRequest::Orphan() { + if (call_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] rls_request=%p %s: cancelling RLS call", + lb_policy_.get(), this, key_.ToString().c_str()); + } + grpc_call_cancel_internal(call_); + } + Unref(DEBUG_LOCATION, "Orphan"); +} + +void RlsLb::RlsRequest::StartCall(void* arg, grpc_error_handle /*error*/) { + auto* request = static_cast(arg); + request->lb_policy_->work_serializer()->Run( + [request]() { + request->StartCallLocked(); + request->Unref(DEBUG_LOCATION, "StartCall"); + }, + DEBUG_LOCATION); +} + +void RlsLb::RlsRequest::StartCallLocked() { + { + MutexLock lock(&lb_policy_->mu_); + if (lb_policy_->is_shutdown_) return; + } + grpc_millis now = ExecCtx::Get()->Now(); + deadline_ = now + lb_policy_->config_->lookup_service_timeout(); + grpc_metadata_array_init(&recv_initial_metadata_); + grpc_metadata_array_init(&recv_trailing_metadata_); + call_ = grpc_channel_create_pollset_set_call( + rls_channel_->channel(), nullptr, GRPC_PROPAGATE_DEFAULTS, + lb_policy_->interested_parties(), + grpc_slice_from_static_string(kRlsRequestPath), nullptr, deadline_, + nullptr); + grpc_op ops[6]; + memset(ops, 0, 
         sizeof(ops));
+  grpc_op* op = ops;
+  op->op = GRPC_OP_SEND_INITIAL_METADATA;
+  ++op;
+  op->op = GRPC_OP_SEND_MESSAGE;
+  send_message_ = MakeRequestProto();
+  op->data.send_message.send_message = send_message_;
+  ++op;
+  op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT;
+  ++op;
+  op->op = GRPC_OP_RECV_INITIAL_METADATA;
+  op->data.recv_initial_metadata.recv_initial_metadata =
+      &recv_initial_metadata_;
+  ++op;
+  op->op = GRPC_OP_RECV_MESSAGE;
+  op->data.recv_message.recv_message = &recv_message_;
+  ++op;
+  op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
+  op->data.recv_status_on_client.trailing_metadata = &recv_trailing_metadata_;
+  op->data.recv_status_on_client.status = &status_recv_;
+  op->data.recv_status_on_client.status_details = &status_details_recv_;
+  ++op;
+  Ref(DEBUG_LOCATION, "OnRlsCallComplete").release();
+  auto call_error = grpc_call_start_batch_and_execute(
+      call_, ops, static_cast<size_t>(op - ops), &call_complete_cb_);
+  GPR_ASSERT(call_error == GRPC_CALL_OK);
+}
+
+void RlsLb::RlsRequest::OnRlsCallComplete(void* arg, grpc_error_handle error) {
+  auto* request = static_cast<RlsRequest*>(arg);
+  GRPC_ERROR_REF(error);
+  request->lb_policy_->work_serializer()->Run(
+      [request, error]() {
+        request->OnRlsCallCompleteLocked(error);
+        request->Unref(DEBUG_LOCATION, "OnRlsCallComplete");
+      },
+      DEBUG_LOCATION);
+}
+
+void RlsLb::RlsRequest::OnRlsCallCompleteLocked(grpc_error_handle error) {
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) {
+    std::string status_message(StringViewFromSlice(status_details_recv_));
+    gpr_log(GPR_INFO,
+            "[rlslb %p] rls_request=%p %s, error=%s, status={%d, %s} RLS call "
+            "response received",
+            lb_policy_.get(), this, key_.ToString().c_str(),
+            grpc_error_std_string(error).c_str(), status_recv_,
+            status_message.c_str());
+  }
+  // Parse response.
+  ResponseInfo response;
+  if (error != GRPC_ERROR_NONE) {
+    grpc_status_code code;
+    std::string message;
+    grpc_error_get_status(error, deadline_, &code, &message,
+                          /*http_error=*/nullptr, /*error_string=*/nullptr);
+    response.status =
+        absl::Status(static_cast<absl::StatusCode>(code), message);
+  } else if (status_recv_ != GRPC_STATUS_OK) {
+    response.status = absl::Status(static_cast<absl::StatusCode>(status_recv_),
+                                   StringViewFromSlice(status_details_recv_));
+  } else {
+    response = ParseResponseProto();
+  }
+  // Clean up call state.
+  grpc_byte_buffer_destroy(send_message_);
+  grpc_byte_buffer_destroy(recv_message_);
+  grpc_metadata_array_destroy(&recv_initial_metadata_);
+  grpc_metadata_array_destroy(&recv_trailing_metadata_);
+  grpc_slice_unref_internal(status_details_recv_);
+  grpc_call_unref(call_);
+  call_ = nullptr;
+  // Return result to cache.
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) {
+    gpr_log(GPR_INFO, "[rlslb %p] rls_request=%p %s: response info: %s",
+            lb_policy_.get(), this, key_.ToString().c_str(),
+            response.ToString().c_str());
+  }
+  std::vector<ChildPolicyWrapper*> child_policies_to_finish_update;
+  {
+    MutexLock lock(&lb_policy_->mu_);
+    if (lb_policy_->is_shutdown_) return;
+    // Report whether the RLS call succeeded, so that the throttle records
+    // a success only for successful responses.
+    rls_channel_->ReportResponseLocked(response.status.ok());
+    Cache::Entry* cache_entry = lb_policy_->cache_.FindOrInsert(key_);
+    child_policies_to_finish_update = cache_entry->OnRlsResponseLocked(
+        std::move(response), std::move(backoff_state_));
+    lb_policy_->request_map_.erase(key_);
+  }
+  // Now that we've released the lock, finish the update on any newly
+  // created child policies.
+ for (ChildPolicyWrapper* child : child_policies_to_finish_update) { + child->MaybeFinishUpdate(); + } +} + +grpc_byte_buffer* RlsLb::RlsRequest::MakeRequestProto() { + upb::Arena arena; + grpc_lookup_v1_RouteLookupRequest* req = + grpc_lookup_v1_RouteLookupRequest_new(arena.ptr()); + grpc_lookup_v1_RouteLookupRequest_set_target_type( + req, upb_strview_make(kGrpc, sizeof(kGrpc) - 1)); + for (const auto& kv : key_.key_map) { + grpc_lookup_v1_RouteLookupRequest_key_map_set( + req, upb_strview_make(kv.first.data(), kv.first.size()), + upb_strview_make(kv.second.data(), kv.second.size()), arena.ptr()); + } + grpc_lookup_v1_RouteLookupRequest_set_reason(req, reason_); + if (!stale_header_data_.empty()) { + grpc_lookup_v1_RouteLookupRequest_set_stale_header_data( + req, + upb_strview_make(stale_header_data_.data(), stale_header_data_.size())); + } + size_t len; + char* buf = + grpc_lookup_v1_RouteLookupRequest_serialize(req, arena.ptr(), &len); + grpc_slice send_slice = grpc_slice_from_copied_buffer(buf, len); + grpc_byte_buffer* byte_buffer = grpc_raw_byte_buffer_create(&send_slice, 1); + grpc_slice_unref_internal(send_slice); + return byte_buffer; +} + +RlsLb::ResponseInfo RlsLb::RlsRequest::ParseResponseProto() { + ResponseInfo response_info; + upb::Arena arena; + grpc_byte_buffer_reader bbr; + grpc_byte_buffer_reader_init(&bbr, recv_message_); + grpc_slice recv_slice = grpc_byte_buffer_reader_readall(&bbr); + grpc_byte_buffer_reader_destroy(&bbr); + grpc_lookup_v1_RouteLookupResponse* response = + grpc_lookup_v1_RouteLookupResponse_parse( + reinterpret_cast(GRPC_SLICE_START_PTR(recv_slice)), + GRPC_SLICE_LENGTH(recv_slice), arena.ptr()); + grpc_slice_unref_internal(recv_slice); + if (response == nullptr) { + response_info.status = absl::InternalError("cannot parse RLS response"); + return response_info; + } + size_t num_targets; + const upb_strview* targets_strview = + grpc_lookup_v1_RouteLookupResponse_targets(response, &num_targets); + if (num_targets == 0) { + response_info.status = + absl::InvalidArgumentError("RLS response has no target entry"); + return response_info; + } + response_info.targets.reserve(num_targets); + for (size_t i = 0; i < num_targets; ++i) { + response_info.targets.emplace_back(targets_strview[i].data, + targets_strview[i].size); + } + upb_strview header_data_strview = + grpc_lookup_v1_RouteLookupResponse_header_data(response); + response_info.header_data = + std::string(header_data_strview.data, header_data_strview.size); + return response_info; +} + +// +// RlsLb +// + +std::string GetServerUri(const grpc_channel_args* args) { + const char* server_uri_str = + grpc_channel_args_find_string(args, GRPC_ARG_SERVER_URI); + GPR_ASSERT(server_uri_str != nullptr); + absl::StatusOr uri = URI::Parse(server_uri_str); + GPR_ASSERT(uri.ok()); + return std::string(absl::StripPrefix(uri->path(), "/")); +} + +RlsLb::RlsLb(Args args) + : LoadBalancingPolicy(std::move(args)), + server_name_(GetServerUri(args.args)), + cache_(this) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] policy created", this); + } +} + +void RlsLb::UpdateLocked(UpdateArgs args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] policy updated", this); + } + // Swap out config, addresses, and channel args. 
+ RefCountedPtr old_config = std::move(config_); + config_ = std::move(args.config); + ServerAddressList old_addresses = std::move(addresses_); + addresses_ = std::move(args.addresses); + grpc_channel_args_destroy(channel_args_); + channel_args_ = grpc_channel_args_copy(args.args); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace) && + (old_config == nullptr || + old_config->child_policy_config() != config_->child_policy_config())) { + gpr_log(GPR_INFO, "[rlslb %p] updated child policy config: %s", this, + config_->child_policy_config().Dump().c_str()); + } + // Determine whether we need to update all child policies. + bool update_child_policies = + old_config == nullptr || + old_config->child_policy_config() != config_->child_policy_config() || + old_addresses != addresses_ || + grpc_channel_args_compare(args.args, channel_args_) != 0; + // If default target changes, swap out child policy. + bool created_default_child = false; + if (old_config == nullptr || + config_->default_target() != old_config->default_target()) { + if (config_->default_target().empty()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] unsetting default target", this); + } + default_child_policy_.reset(); + } else { + auto it = child_policy_map_.find(config_->default_target()); + if (it == child_policy_map_.end()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] creating new default target", this); + } + default_child_policy_ = MakeRefCounted( + Ref(DEBUG_LOCATION, "ChildPolicyWrapper"), + config_->default_target()); + created_default_child = true; + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, + "[rlslb %p] using existing child for default target", this); + } + default_child_policy_ = + it->second->Ref(DEBUG_LOCATION, "DefaultChildPolicy"); + } + } + } + // Now grab the lock to swap out the state it guards. + { + MutexLock lock(&mu_); + // Swap out RLS channel if needed. + if (old_config == nullptr || + config_->lookup_service() != old_config->lookup_service()) { + rls_channel_ = + MakeOrphanable(Ref(DEBUG_LOCATION, "RlsChannel"), + config_->lookup_service(), channel_args_); + } + // Resize cache if needed. + if (old_config == nullptr || + config_->cache_size_bytes() != old_config->cache_size_bytes()) { + cache_.Resize(config_->cache_size_bytes()); + } + // Start update of child policies if needed. + if (update_child_policies) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] starting child policy updates", this); + } + for (auto& p : child_policy_map_) { + p.second->StartUpdate(); + } + } else if (created_default_child) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] starting default child policy update", + this); + } + default_child_policy_->StartUpdate(); + } + } + // Now that we've released the lock, finish update of child policies. + if (update_child_policies) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] finishing child policy updates", this); + } + for (auto& p : child_policy_map_) { + p.second->MaybeFinishUpdate(); + } + } else if (created_default_child) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] finishing default child policy update", + this); + } + default_child_policy_->MaybeFinishUpdate(); + } + // In principle, we need to update the picker here only if the config + // fields used by the picker have changed. 
However, it seems fragile + // to check individual fields, since the picker logic could change in + // the future to use additional config fields, and we might not + // remember to update the code here. So for now, we just unconditionally + // update the picker here, even though it's probably redundant. + UpdatePickerLocked(); +} + +void RlsLb::ExitIdleLocked() { + MutexLock lock(&mu_); + for (auto& child_entry : child_policy_map_) { + child_entry.second->ExitIdleLocked(); + } +} + +void RlsLb::ResetBackoffLocked() { + { + MutexLock lock(&mu_); + rls_channel_->ResetBackoff(); + cache_.ResetAllBackoff(); + } + for (auto& child : child_policy_map_) { + child.second->ResetBackoffLocked(); + } +} + +void RlsLb::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] policy shutdown", this); + } + MutexLock lock(&mu_); + is_shutdown_ = true; + config_.reset(DEBUG_LOCATION, "ShutdownLocked"); + if (channel_args_ != nullptr) { + grpc_channel_args_destroy(channel_args_); + } + cache_.Shutdown(); + request_map_.clear(); + rls_channel_.reset(); + default_child_policy_.reset(); +} + +void RlsLb::UpdatePickerAsync() { + // Run via the ExecCtx, since the caller may be holding the lock, and + // we don't want to be doing that when we hop into the WorkSerializer, + // in case the WorkSerializer callback happens to run inline. + ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_CREATE(UpdatePickerCallback, + Ref(DEBUG_LOCATION, "UpdatePickerCallback").release(), + grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); +} + +void RlsLb::UpdatePickerCallback(void* arg, grpc_error_handle /*error*/) { + auto* rls_lb = static_cast(arg); + rls_lb->work_serializer()->Run( + [rls_lb]() { + RefCountedPtr lb_policy(rls_lb); + lb_policy->UpdatePickerLocked(); + lb_policy.reset(DEBUG_LOCATION, "UpdatePickerCallback"); + }, + DEBUG_LOCATION); +} + +void RlsLb::UpdatePickerLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] updating picker", this); + } + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + if (!child_policy_map_.empty()) { + state = GRPC_CHANNEL_TRANSIENT_FAILURE; + int num_idle = 0; + int num_connecting = 0; + { + MutexLock lock(&mu_); + if (is_shutdown_) return; + for (auto& p : child_policy_map_) { + grpc_connectivity_state child_state = p.second->connectivity_state(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] target %s in state %s", this, + p.second->target().c_str(), + ConnectivityStateName(child_state)); + } + if (child_state == GRPC_CHANNEL_READY) { + state = GRPC_CHANNEL_READY; + break; + } else if (child_state == GRPC_CHANNEL_CONNECTING) { + ++num_connecting; + } else if (child_state == GRPC_CHANNEL_IDLE) { + ++num_idle; + } + } + if (state != GRPC_CHANNEL_READY) { + if (num_connecting > 0) { + state = GRPC_CHANNEL_CONNECTING; + } else if (num_idle > 0) { + state = GRPC_CHANNEL_IDLE; + } + } + } + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) { + gpr_log(GPR_INFO, "[rlslb %p] reporting state %s", this, + ConnectivityStateName(state)); + } + absl::Status status; + if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + status = absl::UnavailableError("no children available"); + } + channel_control_helper()->UpdateState( + state, status, absl::make_unique(Ref(DEBUG_LOCATION, "Picker"))); +} + +// +// RlsLbFactory +// + +grpc_error_handle ParseJsonHeaders(size_t idx, const Json& json, + std::string* key, + std::vector* headers) { + if (json.type() != Json::Type::OBJECT) { + 
return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "field:headers index:", idx, " error:type should be OBJECT")); + } + std::vector error_list; + // requiredMatch must not be present. + if (json.object_value().find("requiredMatch") != json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:requiredMatch error:must not be present")); + } + // Find key. + if (ParseJsonObjectField(json.object_value(), "key", key, &error_list) && + key->empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:key error:must be non-empty")); + } + // Find headers. + const Json::Array* headers_json = nullptr; + ParseJsonObjectField(json.object_value(), "names", &headers_json, + &error_list); + if (headers_json != nullptr) { + if (headers_json->empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:names error:list is empty")); + } else { + size_t name_idx = 0; + for (const Json& name_json : *headers_json) { + if (name_json.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "field:names index:", name_idx, " error:type should be STRING"))); + } else if (name_json.string_value().empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:names index:", name_idx, + " error:header name must be non-empty"))); + } else { + headers->push_back(name_json.string_value()); + } + ++name_idx; + } + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("field:headers index:", idx), &error_list); +} + +std::string ParseJsonMethodName(size_t idx, const Json& json, + grpc_error_handle* error) { + if (json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "field:names index:", idx, " error:type should be OBJECT")); + return ""; + } + std::vector error_list; + // Find service name. + absl::string_view service_name; + ParseJsonObjectField(json.object_value(), "service", &service_name, + &error_list); + // Find method name. + absl::string_view method_name; + ParseJsonObjectField(json.object_value(), "method", &method_name, &error_list, + /*required=*/false); + // Return error, if any. + *error = GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("field:names index:", idx), &error_list); + // Construct path. + return absl::StrCat("/", service_name, "/", method_name); +} + +grpc_error_handle ParseGrpcKeybuilder( + size_t idx, const Json& json, RlsLbConfig::KeyBuilderMap* key_builder_map) { + if (json.type() != Json::Type::OBJECT) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "field:grpc_keybuilders index:", idx, " error:type should be OBJECT")); + } + std::vector error_list; + // Parse names. 
+ std::set names; + const Json::Array* names_array = nullptr; + if (ParseJsonObjectField(json.object_value(), "names", &names_array, + &error_list)) { + if (names_array->empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:names error:list is empty")); + } else { + size_t name_idx = 0; + for (const Json& name_json : *names_array) { + grpc_error_handle child_error = GRPC_ERROR_NONE; + std::string name = + ParseJsonMethodName(name_idx++, name_json, &child_error); + if (child_error != GRPC_ERROR_NONE) { + error_list.push_back(child_error); + } else { + bool inserted = names.insert(name).second; + if (!inserted) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:names error:duplicate entry for ", name))); + } + } + } + } + } + // Helper function to check for duplicate keys. + std::set all_keys; + auto duplicate_key_check_func = [&all_keys, + &error_list](const std::string& key) { + auto it = all_keys.find(key); + if (it != all_keys.end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("key \"", key, "\" listed multiple times"))); + } else { + all_keys.insert(key); + } + }; + // Parse headers. + RlsLbConfig::KeyBuilder key_builder; + const Json::Array* headers_array = nullptr; + ParseJsonObjectField(json.object_value(), "headers", &headers_array, + &error_list, /*required=*/false); + if (headers_array != nullptr) { + size_t header_idx = 0; + for (const Json& header_json : *headers_array) { + std::string key; + std::vector headers; + grpc_error_handle child_error = + ParseJsonHeaders(header_idx++, header_json, &key, &headers); + if (child_error != GRPC_ERROR_NONE) { + error_list.push_back(child_error); + } else { + duplicate_key_check_func(key); + key_builder.header_keys.emplace(key, std::move(headers)); + } + } + } + // Parse extraKeys. + const Json::Object* extra_keys = nullptr; + ParseJsonObjectField(json.object_value(), "extraKeys", &extra_keys, + &error_list, /*required=*/false); + if (extra_keys != nullptr) { + std::vector extra_keys_errors; + if (ParseJsonObjectField(*extra_keys, "host", &key_builder.host_key, + &extra_keys_errors, /*required=*/false) && + key_builder.host_key.empty()) { + extra_keys_errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:host error:must be non-empty")); + } + if (!key_builder.host_key.empty()) { + duplicate_key_check_func(key_builder.host_key); + } + if (ParseJsonObjectField(*extra_keys, "service", &key_builder.service_key, + &extra_keys_errors, /*required=*/false) && + key_builder.service_key.empty()) { + extra_keys_errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:service error:must be non-empty")); + } + if (!key_builder.service_key.empty()) { + duplicate_key_check_func(key_builder.service_key); + } + if (ParseJsonObjectField(*extra_keys, "method", &key_builder.method_key, + &extra_keys_errors, /*required=*/false) && + key_builder.method_key.empty()) { + extra_keys_errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:method error:must be non-empty")); + } + if (!key_builder.method_key.empty()) { + duplicate_key_check_func(key_builder.method_key); + } + if (!extra_keys_errors.empty()) { + error_list.push_back( + GRPC_ERROR_CREATE_FROM_VECTOR("field:extraKeys", &extra_keys_errors)); + } + } + // Parse constantKeys. 
+ const Json::Object* constant_keys = nullptr; + ParseJsonObjectField(json.object_value(), "constantKeys", &constant_keys, + &error_list, /*required=*/false); + if (constant_keys != nullptr) { + std::vector constant_keys_errors; + for (const auto& p : *constant_keys) { + const std::string& key = p.first; + const Json& value = p.second; + if (key.empty()) { + constant_keys_errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "error:keys must be non-empty")); + } + duplicate_key_check_func(key); + ExtractJsonString(value, key, &key_builder.constant_keys[key], + &constant_keys_errors); + } + if (!constant_keys_errors.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_VECTOR( + "field:constantKeys", &constant_keys_errors)); + } + } + // Insert key_builder into key_builder_map. + for (const std::string& name : names) { + bool inserted = key_builder_map->emplace(name, key_builder).second; + if (!inserted) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:names error:duplicate entry for ", name))); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("index:", idx), &error_list); +} + +RlsLbConfig::KeyBuilderMap ParseGrpcKeybuilders( + const Json::Array& key_builder_list, grpc_error_handle* error) { + RlsLbConfig::KeyBuilderMap key_builder_map; + if (key_builder_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:grpcKeybuilders error:list is empty"); + return key_builder_map; + } + std::vector error_list; + size_t idx = 0; + for (const Json& key_builder : key_builder_list) { + grpc_error_handle child_error = + ParseGrpcKeybuilder(idx++, key_builder, &key_builder_map); + if (child_error != GRPC_ERROR_NONE) error_list.push_back(child_error); + } + *error = GRPC_ERROR_CREATE_FROM_VECTOR("field:grpcKeybuilders", &error_list); + return key_builder_map; +} + +RlsLbConfig::RouteLookupConfig ParseRouteLookupConfig( + const Json::Object& json, grpc_error_handle* error) { + std::vector error_list; + RlsLbConfig::RouteLookupConfig route_lookup_config; + // Parse grpcKeybuilders. + const Json::Array* keybuilder_list = nullptr; + ParseJsonObjectField(json, "grpcKeybuilders", &keybuilder_list, &error_list); + if (keybuilder_list != nullptr) { + grpc_error_handle child_error = GRPC_ERROR_NONE; + route_lookup_config.key_builder_map = + ParseGrpcKeybuilders(*keybuilder_list, &child_error); + if (child_error != GRPC_ERROR_NONE) error_list.push_back(child_error); + } + // Parse lookupService. + if (ParseJsonObjectField(json, "lookupService", + &route_lookup_config.lookup_service, &error_list)) { + if (!ResolverRegistry::IsValidTarget(route_lookup_config.lookup_service)) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:lookupService error:must be valid gRPC target URI")); + } + } + // Parse lookupServiceTimeout. + route_lookup_config.lookup_service_timeout = kDefaultLookupServiceTimeout; + ParseJsonObjectFieldAsDuration(json, "lookupServiceTimeout", + &route_lookup_config.lookup_service_timeout, + &error_list, /*required=*/false); + // Parse maxAge. + route_lookup_config.max_age = kMaxMaxAge; + bool max_age_set = ParseJsonObjectFieldAsDuration( + json, "maxAge", &route_lookup_config.max_age, &error_list, + /*required=*/false); + // Clamp maxAge to the max allowed value. + if (route_lookup_config.max_age > kMaxMaxAge) { + route_lookup_config.max_age = kMaxMaxAge; + } + // Parse staleAge. 
+ route_lookup_config.stale_age = kMaxMaxAge; + bool stale_age_set = ParseJsonObjectFieldAsDuration( + json, "staleAge", &route_lookup_config.stale_age, &error_list, + /*required=*/false); + // If staleAge is set, then maxAge must also be set. + if (stale_age_set && !max_age_set) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxAge error:must be set if staleAge is set")); + } + // Ignore staleAge if greater than or equal to maxAge. + if (route_lookup_config.stale_age >= route_lookup_config.max_age) { + route_lookup_config.stale_age = route_lookup_config.max_age; + } + // Parse cacheSizeBytes. + ParseJsonObjectField(json, "cacheSizeBytes", + &route_lookup_config.cache_size_bytes, &error_list); + if (route_lookup_config.cache_size_bytes <= 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:cacheSizeBytes error:must be greater than 0")); + } + // Clamp cacheSizeBytes to the max allowed value. + if (route_lookup_config.cache_size_bytes > kMaxCacheSizeBytes) { + route_lookup_config.cache_size_bytes = kMaxCacheSizeBytes; + } + // Parse defaultTarget. + if (ParseJsonObjectField(json, "defaultTarget", + &route_lookup_config.default_target, &error_list, + /*required=*/false)) { + if (route_lookup_config.default_target.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:defaultTarget error:must be non-empty if set")); + } + } + *error = + GRPC_ERROR_CREATE_FROM_VECTOR("field:routeLookupConfig", &error_list); + return route_lookup_config; +} + +grpc_error_handle ValidateChildPolicyList( + const Json& child_policy_list, + const std::string& child_policy_config_target_field_name, + const std::string& default_target, Json* child_policy_config, + RefCountedPtr* + default_child_policy_parsed_config) { + // Add target to each entry in the config proto. + *child_policy_config = child_policy_list; + std::string target = + default_target.empty() ? kFakeTargetFieldValue : default_target; + grpc_error_handle error = InsertOrUpdateChildPolicyField( + child_policy_config_target_field_name, target, child_policy_config); + if (error != GRPC_ERROR_NONE) return error; + // Parse the config. + RefCountedPtr parsed_config = + LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + *child_policy_config, &error); + if (error != GRPC_ERROR_NONE) return error; + // Find the chosen config and return it in JSON form. + // We remove all non-selected configs, and in the selected config, we leave + // the target field in place, set to the default value. This slightly + // optimizes what we need to do later when we update a child policy for a + // given target. + if (parsed_config != nullptr) { + for (Json& config : *(child_policy_config->mutable_array())) { + if (config.object_value().begin()->first == parsed_config->name()) { + Json save_config = std::move(config); + child_policy_config->mutable_array()->clear(); + child_policy_config->mutable_array()->push_back(std::move(save_config)); + break; + } + } + } + // If default target is set, return the parsed config. 
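Editorial aside, not part of the patch: pulling together the fields parsed above, a routeLookupConfig that this parser would accept could look like the following. All values are invented for illustration; durations use the JSON Duration string form, and staleAge only takes effect when it is less than maxAge.

// Illustrative config only; the constant name is not part of this patch.
constexpr char kExampleRouteLookupConfig[] = R"json({
  "grpcKeybuilders": [{
    "names": [{"service": "helloworld.Greeter", "method": "SayHello"}],
    "headers": [{"key": "user-key", "names": ["user-id", "x-user-id"]}],
    "extraKeys": {"host": "host-key", "service": "service-key",
                  "method": "method-key"},
    "constantKeys": {"cluster": "demo"}
  }],
  "lookupService": "rls-server.example.com:443",
  "lookupServiceTimeout": "10s",
  "maxAge": "300s",
  "staleAge": "240s",
  "cacheSizeBytes": 20971520,
  "defaultTarget": "fallback.example.com:443"
})json";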
+ if (!default_target.empty()) { + *default_child_policy_parsed_config = std::move(parsed_config); + } + return GRPC_ERROR_NONE; +} + +class RlsLbFactory : public LoadBalancingPolicyFactory { + public: + const char* name() const override { return kRls; } + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& config, grpc_error_handle* error) const override { + std::vector error_list; + // Parse routeLookupConfig. + RlsLbConfig::RouteLookupConfig route_lookup_config; + const Json::Object* route_lookup_config_json = nullptr; + if (ParseJsonObjectField(config.object_value(), "routeLookupConfig", + &route_lookup_config_json, &error_list)) { + grpc_error_handle child_error = GRPC_ERROR_NONE; + route_lookup_config = + ParseRouteLookupConfig(*route_lookup_config_json, &child_error); + if (child_error != GRPC_ERROR_NONE) error_list.push_back(child_error); + } + // Parse childPolicyConfigTargetFieldName. + std::string child_policy_config_target_field_name; + if (ParseJsonObjectField( + config.object_value(), "childPolicyConfigTargetFieldName", + &child_policy_config_target_field_name, &error_list)) { + if (child_policy_config_target_field_name.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:childPolicyConfigTargetFieldName error:must be non-empty")); + } + } + // Parse childPolicy. + Json child_policy_config; + RefCountedPtr + default_child_policy_parsed_config; + auto it = config.object_value().find("childPolicy"); + if (it == config.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:childPolicy error:does not exist.")); + } else if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:childPolicy error:type should be ARRAY")); + } else { + grpc_error_handle child_error = ValidateChildPolicyList( + it->second, child_policy_config_target_field_name, + route_lookup_config.default_target, &child_policy_config, + &default_child_policy_parsed_config); + if (child_error != GRPC_ERROR_NONE) { + error_list.push_back(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "field:childPolicy", &child_error, 1)); + GRPC_ERROR_UNREF(child_error); + } + } + // Return result. + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "errors parsing RLS LB policy config", &error_list); + return MakeRefCounted( + std::move(route_lookup_config), std::move(child_policy_config), + std::move(child_policy_config_target_field_name), + std::move(default_child_policy_parsed_config)); + } +}; + +bool RlsEnabled() { + char* value = gpr_getenv("GRPC_EXPERIMENTAL_ENABLE_RLS_LB_POLICY"); + bool parsed_value; + bool parse_succeeded = gpr_parse_bool_value(value, &parsed_value); + gpr_free(value); + return parse_succeeded && parsed_value; +} + +} // namespace + +void RlsLbPluginInit() { + if (!RlsEnabled()) return; + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void RlsLbPluginShutdown() {} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc new file mode 100644 index 00000000..675ecf47 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc @@ -0,0 +1,503 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/** Round Robin Policy. + * + * Before every pick, the \a get_next_ready_subchannel_index_locked function + * returns the p->subchannel_list->subchannels index for next subchannel, + * respecting the relative order of the addresses provided upon creation or + * updates. Note however that updates will start picking from the beginning of + * the updated list. */ + +#include + +#include +#include + +#include + +#include "src/core/ext/filters/client_channel/lb_policy/subchannel_list.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/subchannel.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/static_metadata.h" + +namespace grpc_core { + +TraceFlag grpc_lb_round_robin_trace(false, "round_robin"); + +namespace { + +// +// round_robin LB policy +// + +constexpr char kRoundRobin[] = "round_robin"; + +class RoundRobin : public LoadBalancingPolicy { + public: + explicit RoundRobin(Args args); + + const char* name() const override { return kRoundRobin; } + + void UpdateLocked(UpdateArgs args) override; + void ResetBackoffLocked() override; + + private: + ~RoundRobin() override; + + // Forward declaration. + class RoundRobinSubchannelList; + + // Data for a particular subchannel in a subchannel list. + // This subclass adds the following functionality: + // - Tracks the previous connectivity state of the subchannel, so that + // we know how many subchannels are in each state. + class RoundRobinSubchannelData + : public SubchannelData { + public: + RoundRobinSubchannelData( + SubchannelList* + subchannel_list, + const ServerAddress& address, + RefCountedPtr subchannel) + : SubchannelData(subchannel_list, address, std::move(subchannel)) {} + + grpc_connectivity_state connectivity_state() const { + return last_connectivity_state_; + } + + bool seen_failure_since_ready() const { return seen_failure_since_ready_; } + + // Performs connectivity state updates that need to be done both when we + // first start watching and when a watcher notification is received. + void UpdateConnectivityStateLocked( + grpc_connectivity_state connectivity_state); + + private: + // Performs connectivity state updates that need to be done only + // after we have started watching. + void ProcessConnectivityChangeLocked( + grpc_connectivity_state connectivity_state) override; + + grpc_connectivity_state last_connectivity_state_ = GRPC_CHANNEL_IDLE; + bool seen_failure_since_ready_ = false; + }; + + // A list of subchannels. 
+ class RoundRobinSubchannelList + : public SubchannelList { + public: + RoundRobinSubchannelList(RoundRobin* policy, TraceFlag* tracer, + ServerAddressList addresses, + const grpc_channel_args& args) + : SubchannelList(policy, tracer, std::move(addresses), + policy->channel_control_helper(), args) { + // Need to maintain a ref to the LB policy as long as we maintain + // any references to subchannels, since the subchannels' + // pollset_sets will include the LB policy's pollset_set. + policy->Ref(DEBUG_LOCATION, "subchannel_list").release(); + } + + ~RoundRobinSubchannelList() override { + RoundRobin* p = static_cast(policy()); + p->Unref(DEBUG_LOCATION, "subchannel_list"); + } + + // Starts watching the subchannels in this list. + void StartWatchingLocked(); + + // Updates the counters of subchannels in each state when a + // subchannel transitions from old_state to new_state. + void UpdateStateCountersLocked(grpc_connectivity_state old_state, + grpc_connectivity_state new_state); + + // If this subchannel list is the RR policy's current subchannel + // list, updates the RR policy's connectivity state based on the + // subchannel list's state counters. + void MaybeUpdateRoundRobinConnectivityStateLocked(); + + // Updates the RR policy's overall state based on the counters of + // subchannels in each state. + void UpdateRoundRobinStateFromSubchannelStateCountsLocked(); + + private: + size_t num_ready_ = 0; + size_t num_connecting_ = 0; + size_t num_transient_failure_ = 0; + }; + + class Picker : public SubchannelPicker { + public: + Picker(RoundRobin* parent, RoundRobinSubchannelList* subchannel_list); + + PickResult Pick(PickArgs args) override; + + private: + // Using pointer value only, no ref held -- do not dereference! + RoundRobin* parent_; + + size_t last_picked_index_; + absl::InlinedVector, 10> subchannels_; + }; + + void ShutdownLocked() override; + + /** list of subchannels */ + OrphanablePtr subchannel_list_; + /** Latest version of the subchannel list. + * Subchannel connectivity callbacks will only promote updated subchannel + * lists if they equal \a latest_pending_subchannel_list. In other words, + * racing callbacks that reference outdated subchannel lists won't perform any + * update. */ + OrphanablePtr latest_pending_subchannel_list_; + /** are we shutting down? */ + bool shutdown_ = false; +}; + +// +// RoundRobin::Picker +// + +RoundRobin::Picker::Picker(RoundRobin* parent, + RoundRobinSubchannelList* subchannel_list) + : parent_(parent) { + for (size_t i = 0; i < subchannel_list->num_subchannels(); ++i) { + RoundRobinSubchannelData* sd = subchannel_list->subchannel(i); + if (sd->connectivity_state() == GRPC_CHANNEL_READY) { + subchannels_.push_back(sd->subchannel()->Ref()); + } + } + // For discussion on why we generate a random starting index for + // the picker, see https://github.com/grpc/grpc-go/issues/2580. + // TODO(roth): rand(3) is not thread-safe. This should be replaced with + // something better as part of https://github.com/grpc/grpc/issues/17891. 
+ last_picked_index_ = rand() % subchannels_.size(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, + "[RR %p picker %p] created picker from subchannel_list=%p " + "with %" PRIuPTR " READY subchannels; last_picked_index_=%" PRIuPTR, + parent_, this, subchannel_list, subchannels_.size(), + last_picked_index_); + } +} + +RoundRobin::PickResult RoundRobin::Picker::Pick(PickArgs /*args*/) { + last_picked_index_ = (last_picked_index_ + 1) % subchannels_.size(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, + "[RR %p picker %p] returning index %" PRIuPTR ", subchannel=%p", + parent_, this, last_picked_index_, + subchannels_[last_picked_index_].get()); + } + return PickResult::Complete(subchannels_[last_picked_index_]); +} + +// +// RoundRobin +// + +RoundRobin::RoundRobin(Args args) : LoadBalancingPolicy(std::move(args)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, "[RR %p] Created", this); + } +} + +RoundRobin::~RoundRobin() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, "[RR %p] Destroying Round Robin policy", this); + } + GPR_ASSERT(subchannel_list_ == nullptr); + GPR_ASSERT(latest_pending_subchannel_list_ == nullptr); +} + +void RoundRobin::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, "[RR %p] Shutting down", this); + } + shutdown_ = true; + subchannel_list_.reset(); + latest_pending_subchannel_list_.reset(); +} + +void RoundRobin::ResetBackoffLocked() { + subchannel_list_->ResetBackoffLocked(); + if (latest_pending_subchannel_list_ != nullptr) { + latest_pending_subchannel_list_->ResetBackoffLocked(); + } +} + +void RoundRobin::RoundRobinSubchannelList::StartWatchingLocked() { + if (num_subchannels() == 0) return; + // Check current state of each subchannel synchronously, since any + // subchannel already used by some other channel may have a non-IDLE + // state. + for (size_t i = 0; i < num_subchannels(); ++i) { + grpc_connectivity_state state = + subchannel(i)->CheckConnectivityStateLocked(); + if (state != GRPC_CHANNEL_IDLE) { + subchannel(i)->UpdateConnectivityStateLocked(state); + } + } + // Start connectivity watch for each subchannel. + for (size_t i = 0; i < num_subchannels(); i++) { + if (subchannel(i)->subchannel() != nullptr) { + subchannel(i)->StartConnectivityWatchLocked(); + subchannel(i)->subchannel()->AttemptToConnect(); + } + } + // Now set the LB policy's state based on the subchannels' states. + UpdateRoundRobinStateFromSubchannelStateCountsLocked(); +} + +void RoundRobin::RoundRobinSubchannelList::UpdateStateCountersLocked( + grpc_connectivity_state old_state, grpc_connectivity_state new_state) { + GPR_ASSERT(old_state != GRPC_CHANNEL_SHUTDOWN); + GPR_ASSERT(new_state != GRPC_CHANNEL_SHUTDOWN); + if (old_state == GRPC_CHANNEL_READY) { + GPR_ASSERT(num_ready_ > 0); + --num_ready_; + } else if (old_state == GRPC_CHANNEL_CONNECTING) { + GPR_ASSERT(num_connecting_ > 0); + --num_connecting_; + } else if (old_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + GPR_ASSERT(num_transient_failure_ > 0); + --num_transient_failure_; + } + if (new_state == GRPC_CHANNEL_READY) { + ++num_ready_; + } else if (new_state == GRPC_CHANNEL_CONNECTING) { + ++num_connecting_; + } else if (new_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + ++num_transient_failure_; + } +} + +// Sets the RR policy's connectivity state and generates a new picker based +// on the current subchannel list. 
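For illustration only (not in this patch): the rotation performed by Pick() in isolation, with the randomized starting index noted in the TODO above. The class name RoundRobinCursor is invented; as with the picker above, it assumes at least one READY subchannel.

#include <cstddef>
#include <cstdlib>

// Rotates through [0, size) starting at a random offset so that channels
// created at the same time do not all begin with the same backend.
class RoundRobinCursor {
 public:
  // num_ready_subchannels must be > 0.
  explicit RoundRobinCursor(size_t num_ready_subchannels)
      : size_(num_ready_subchannels), last_picked_(std::rand() % size_) {}

  // Each pick advances one position and wraps around.
  size_t Next() {
    last_picked_ = (last_picked_ + 1) % size_;
    return last_picked_;
  }

 private:
  size_t size_;
  size_t last_picked_;
};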
+void RoundRobin::RoundRobinSubchannelList:: + MaybeUpdateRoundRobinConnectivityStateLocked() { + RoundRobin* p = static_cast(policy()); + // Only set connectivity state if this is the current subchannel list. + if (p->subchannel_list_.get() != this) return; + /* In priority order. The first rule to match terminates the search (ie, if we + * are on rule n, all previous rules were unfulfilled). + * + * 1) RULE: ANY subchannel is READY => policy is READY. + * CHECK: subchannel_list->num_ready > 0. + * + * 2) RULE: ANY subchannel is CONNECTING => policy is CONNECTING. + * CHECK: sd->curr_connectivity_state == CONNECTING. + * + * 3) RULE: ALL subchannels are TRANSIENT_FAILURE => policy is + * TRANSIENT_FAILURE. + * CHECK: subchannel_list->num_transient_failures == + * subchannel_list->num_subchannels. + */ + if (num_ready_ > 0) { + /* 1) READY */ + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_READY, absl::Status(), absl::make_unique(p, this)); + } else if (num_connecting_ > 0) { + /* 2) CONNECTING */ + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_CONNECTING, absl::Status(), + absl::make_unique(p->Ref(DEBUG_LOCATION, "QueuePicker"))); + } else if (num_transient_failure_ == num_subchannels()) { + /* 3) TRANSIENT_FAILURE */ + absl::Status status = + absl::UnavailableError("connections to all backends failing"); + p->channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + } +} + +void RoundRobin::RoundRobinSubchannelList:: + UpdateRoundRobinStateFromSubchannelStateCountsLocked() { + RoundRobin* p = static_cast(policy()); + // If we have at least one READY subchannel, then swap to the new list. + // Also, if all of the subchannels are in TRANSIENT_FAILURE, then we know + // we've tried all of them and failed, so we go ahead and swap over + // anyway; this may cause the channel to go from READY to TRANSIENT_FAILURE, + // but we are doing what the control plane told us to do. + if (num_ready_ > 0 || num_transient_failure_ == num_subchannels()) { + if (p->subchannel_list_.get() != this) { + // Promote this list to p->subchannel_list_. + // This list must be p->latest_pending_subchannel_list_, because + // any previous update would have been shut down already and + // therefore we would not be receiving a notification for them. + GPR_ASSERT(p->latest_pending_subchannel_list_.get() == this); + GPR_ASSERT(!shutting_down()); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + const size_t old_num_subchannels = + p->subchannel_list_ != nullptr + ? p->subchannel_list_->num_subchannels() + : 0; + gpr_log(GPR_INFO, + "[RR %p] phasing out subchannel list %p (size %" PRIuPTR + ") in favor of %p (size %" PRIuPTR ")", + p, p->subchannel_list_.get(), old_num_subchannels, this, + num_subchannels()); + } + p->subchannel_list_ = std::move(p->latest_pending_subchannel_list_); + } + } + // Update the RR policy's connectivity state if needed. 
+ MaybeUpdateRoundRobinConnectivityStateLocked(); +} + +void RoundRobin::RoundRobinSubchannelData::UpdateConnectivityStateLocked( + grpc_connectivity_state connectivity_state) { + RoundRobin* p = static_cast(subchannel_list()->policy()); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log( + GPR_INFO, + "[RR %p] connectivity changed for subchannel %p, subchannel_list %p " + "(index %" PRIuPTR " of %" PRIuPTR "): prev_state=%s new_state=%s", + p, subchannel(), subchannel_list(), Index(), + subchannel_list()->num_subchannels(), + ConnectivityStateName(last_connectivity_state_), + ConnectivityStateName(connectivity_state)); + } + // Decide what state to report for aggregation purposes. + // If we haven't seen a failure since the last time we were in state + // READY, then we report the state change as-is. However, once we do see + // a failure, we report TRANSIENT_FAILURE and do not report any subsequent + // state changes until we go back into state READY. + if (!seen_failure_since_ready_) { + if (connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + seen_failure_since_ready_ = true; + } + subchannel_list()->UpdateStateCountersLocked(last_connectivity_state_, + connectivity_state); + } else { + if (connectivity_state == GRPC_CHANNEL_READY) { + seen_failure_since_ready_ = false; + subchannel_list()->UpdateStateCountersLocked( + GRPC_CHANNEL_TRANSIENT_FAILURE, connectivity_state); + } + } + // Record last seen connectivity state. + last_connectivity_state_ = connectivity_state; +} + +void RoundRobin::RoundRobinSubchannelData::ProcessConnectivityChangeLocked( + grpc_connectivity_state connectivity_state) { + RoundRobin* p = static_cast(subchannel_list()->policy()); + GPR_ASSERT(subchannel() != nullptr); + // If the new state is TRANSIENT_FAILURE, re-resolve. + // Only do this if we've started watching, not at startup time. + // Otherwise, if the subchannel was already in state TRANSIENT_FAILURE + // when the subchannel list was created, we'd wind up in a constant + // loop of re-resolution. + // Also attempt to reconnect. + if (connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, + "[RR %p] Subchannel %p has gone into TRANSIENT_FAILURE. " + "Requesting re-resolution", + p, subchannel()); + } + p->channel_control_helper()->RequestReresolution(); + subchannel()->AttemptToConnect(); + } + // Update state counters. + UpdateConnectivityStateLocked(connectivity_state); + // Update overall state and renew notification. + subchannel_list()->UpdateRoundRobinStateFromSubchannelStateCountsLocked(); +} + +void RoundRobin::UpdateLocked(UpdateArgs args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, "[RR %p] received update with %" PRIuPTR " addresses", + this, args.addresses.size()); + } + // Replace latest_pending_subchannel_list_. + if (latest_pending_subchannel_list_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) { + gpr_log(GPR_INFO, + "[RR %p] Shutting down previous pending subchannel list %p", this, + latest_pending_subchannel_list_.get()); + } + } + latest_pending_subchannel_list_ = MakeOrphanable( + this, &grpc_lb_round_robin_trace, std::move(args.addresses), *args.args); + if (latest_pending_subchannel_list_->num_subchannels() == 0) { + // If the new list is empty, immediately promote the new list to the + // current list and transition to TRANSIENT_FAILURE. 
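Editorial aside, not part of the patch: the "sticky TRANSIENT_FAILURE" filtering in UpdateConnectivityStateLocked above, restated as a small standalone state machine. The struct name is invented; grpc_connectivity_state comes from <grpc/grpc.h>.

#include <grpc/grpc.h>

// Tracks which state a subchannel should be counted as for aggregation:
// raw updates pass through until a failure is seen, after which the
// subchannel keeps counting as TRANSIENT_FAILURE until it reports READY.
struct StickyFailureFilter {
  grpc_connectivity_state counted = GRPC_CHANNEL_IDLE;
  bool seen_failure_since_ready = false;

  grpc_connectivity_state Update(grpc_connectivity_state raw) {
    if (!seen_failure_since_ready) {
      if (raw == GRPC_CHANNEL_TRANSIENT_FAILURE) {
        seen_failure_since_ready = true;
      }
      counted = raw;
    } else if (raw == GRPC_CHANNEL_READY) {
      seen_failure_since_ready = false;
      counted = raw;
    }
    return counted;
  }
};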
+ absl::Status status = absl::UnavailableError("Empty update"); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + subchannel_list_ = std::move(latest_pending_subchannel_list_); + } else if (subchannel_list_ == nullptr) { + // If there is no current list, immediately promote the new list to + // the current list and start watching it. + subchannel_list_ = std::move(latest_pending_subchannel_list_); + subchannel_list_->StartWatchingLocked(); + } else { + // Start watching the pending list. It will get swapped into the + // current list when it reports READY. + latest_pending_subchannel_list_->StartWatchingLocked(); + } +} + +class RoundRobinConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kRoundRobin; } +}; + +// +// factory +// + +class RoundRobinFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kRoundRobin; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error_handle* /*error*/) const override { + return MakeRefCounted(); + } +}; + +} // namespace + +} // namespace grpc_core + +void grpc_lb_policy_round_robin_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_round_robin_shutdown() {} diff --git a/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc b/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc new file mode 100644 index 00000000..785975d6 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc @@ -0,0 +1,741 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" + +#include + +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy/address_filtering.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_core { + +TraceFlag grpc_lb_weighted_target_trace(false, "weighted_target_lb"); + +namespace { + +constexpr char kWeightedTarget[] = "weighted_target_experimental"; + +// How long we keep a child around for after it has been removed from +// the config. +constexpr int kChildRetentionIntervalMs = 15 * 60 * 1000; + +// Config for weighted_target LB policy. +class WeightedTargetLbConfig : public LoadBalancingPolicy::Config { + public: + struct ChildConfig { + uint32_t weight; + RefCountedPtr config; + }; + + using TargetMap = std::map; + + explicit WeightedTargetLbConfig(TargetMap target_map) + : target_map_(std::move(target_map)) {} + + const char* name() const override { return kWeightedTarget; } + + const TargetMap& target_map() const { return target_map_; } + + private: + TargetMap target_map_; +}; + +// weighted_target LB policy. +class WeightedTargetLb : public LoadBalancingPolicy { + public: + explicit WeightedTargetLb(Args args); + + const char* name() const override { return kWeightedTarget; } + + void UpdateLocked(UpdateArgs args) override; + void ResetBackoffLocked() override; + + private: + // A simple wrapper for ref-counting a picker from the child policy. + class ChildPickerWrapper : public RefCounted { + public: + explicit ChildPickerWrapper(std::unique_ptr picker) + : picker_(std::move(picker)) {} + PickResult Pick(PickArgs args) { return picker_->Pick(args); } + + private: + std::unique_ptr picker_; + }; + + // Picks a child using stateless WRR and then delegates to that + // child's picker. + class WeightedPicker : public SubchannelPicker { + public: + // Maintains a weighted list of pickers from each child that is in + // ready state. The first element in the pair represents the end of a + // range proportional to the child's weight. The start of the range + // is the previous value in the vector and is 0 for the first element. + using PickerList = absl::InlinedVector< + std::pair>, 1>; + + explicit WeightedPicker(PickerList pickers) + : pickers_(std::move(pickers)) {} + + PickResult Pick(PickArgs args) override; + + private: + PickerList pickers_; + }; + + // Each WeightedChild holds a ref to its parent WeightedTargetLb. 
+ class WeightedChild : public InternallyRefCounted { + public: + WeightedChild(RefCountedPtr weighted_target_policy, + const std::string& name); + ~WeightedChild() override; + + void Orphan() override; + + void UpdateLocked(const WeightedTargetLbConfig::ChildConfig& config, + ServerAddressList addresses, + const grpc_channel_args* args); + void ResetBackoffLocked(); + void DeactivateLocked(); + + uint32_t weight() const { return weight_; } + grpc_connectivity_state connectivity_state() const { + return connectivity_state_; + } + RefCountedPtr picker_wrapper() const { + return picker_wrapper_; + } + + private: + class Helper : public ChannelControlHelper { + public: + explicit Helper(RefCountedPtr weighted_child) + : weighted_child_(std::move(weighted_child)) {} + + ~Helper() override { weighted_child_.reset(DEBUG_LOCATION, "Helper"); } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, + const absl::Status& status, + std::unique_ptr picker) override; + void RequestReresolution() override; + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + RefCountedPtr weighted_child_; + }; + + // Methods for dealing with the child policy. + OrphanablePtr CreateChildPolicyLocked( + const grpc_channel_args* args); + + void OnConnectivityStateUpdateLocked( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker); + + static void OnDelayedRemovalTimer(void* arg, grpc_error_handle error); + void OnDelayedRemovalTimerLocked(grpc_error_handle error); + + // The owning LB policy. + RefCountedPtr weighted_target_policy_; + + const std::string name_; + + uint32_t weight_; + + OrphanablePtr child_policy_; + + RefCountedPtr picker_wrapper_; + grpc_connectivity_state connectivity_state_ = GRPC_CHANNEL_CONNECTING; + bool seen_failure_since_ready_ = false; + + // States for delayed removal. + grpc_timer delayed_removal_timer_; + grpc_closure on_delayed_removal_timer_; + bool delayed_removal_timer_callback_pending_ = false; + bool shutdown_ = false; + }; + + ~WeightedTargetLb() override; + + void ShutdownLocked() override; + + void UpdateStateLocked(); + + // Current config from the resolver. + RefCountedPtr config_; + + // Internal state. + bool shutting_down_ = false; + + // Children. + std::map> targets_; +}; + +// +// WeightedTargetLb::WeightedPicker +// + +WeightedTargetLb::PickResult WeightedTargetLb::WeightedPicker::Pick( + PickArgs args) { + // Generate a random number in [0, total weight). + const uint32_t key = rand() % pickers_[pickers_.size() - 1].first; + // Find the index in pickers_ corresponding to key. + size_t mid = 0; + size_t start_index = 0; + size_t end_index = pickers_.size() - 1; + size_t index = 0; + while (end_index > start_index) { + mid = (start_index + end_index) / 2; + if (pickers_[mid].first > key) { + end_index = mid; + } else if (pickers_[mid].first < key) { + start_index = mid + 1; + } else { + index = mid + 1; + break; + } + } + if (index == 0) index = start_index; + GPR_ASSERT(pickers_[index].first > key); + // Delegate to the child picker. 
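For illustration only (not in this patch): the weighted pick above walks a list of cumulative range ends with a hand-rolled binary search. The sketch below builds such a list from raw weights and resolves a key with std::upper_bound; both function names are invented.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// weights {3, 1, 6} -> range ends {3, 4, 10}; child i owns the key range
// [ends[i-1], ends[i]).
std::vector<uint32_t> BuildRangeEnds(const std::vector<uint32_t>& weights) {
  std::vector<uint32_t> range_ends;
  uint32_t end = 0;
  for (uint32_t weight : weights) {
    end += weight;
    range_ends.push_back(end);
  }
  return range_ends;
}

// Picks an index with probability proportional to its weight: the chosen
// child is the first range whose end is strictly greater than the key.
size_t PickWeightedIndex(const std::vector<uint32_t>& range_ends) {
  const uint32_t key = std::rand() % range_ends.back();
  return std::upper_bound(range_ends.begin(), range_ends.end(), key) -
         range_ends.begin();
}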
+ return pickers_[index].second->Pick(args); +} + +// +// WeightedTargetLb +// + +WeightedTargetLb::WeightedTargetLb(Args args) + : LoadBalancingPolicy(std::move(args)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, "[weighted_target_lb %p] created", this); + } +} + +WeightedTargetLb::~WeightedTargetLb() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] destroying weighted_target LB policy", + this); + } +} + +void WeightedTargetLb::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, "[weighted_target_lb %p] shutting down", this); + } + shutting_down_ = true; + targets_.clear(); +} + +void WeightedTargetLb::ResetBackoffLocked() { + for (auto& p : targets_) p.second->ResetBackoffLocked(); +} + +void WeightedTargetLb::UpdateLocked(UpdateArgs args) { + if (shutting_down_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, "[weighted_target_lb %p] Received update", this); + } + // Update config. + config_ = std::move(args.config); + // Deactivate the targets not in the new config. + for (const auto& p : targets_) { + const std::string& name = p.first; + WeightedChild* child = p.second.get(); + if (config_->target_map().find(name) == config_->target_map().end()) { + child->DeactivateLocked(); + } + } + // Create any children that don't already exist. + // Note that we add all children before updating any of them, because + // an update may trigger a child to immediately update its + // connectivity state (e.g., reporting TRANSIENT_FAILURE immediately when + // receiving an empty address list), and we don't want to return an + // overall state with incomplete data. + for (const auto& p : config_->target_map()) { + const std::string& name = p.first; + auto it = targets_.find(name); + if (it == targets_.end()) { + targets_.emplace(name, MakeOrphanable( + Ref(DEBUG_LOCATION, "WeightedChild"), name)); + } + } + // Update all children. + HierarchicalAddressMap address_map = + MakeHierarchicalAddressMap(args.addresses); + for (const auto& p : config_->target_map()) { + const std::string& name = p.first; + const WeightedTargetLbConfig::ChildConfig& config = p.second; + targets_[name]->UpdateLocked(config, std::move(address_map[name]), + args.args); + } + UpdateStateLocked(); +} + +void WeightedTargetLb::UpdateStateLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] scanning children to determine " + "connectivity state", + this); + } + // Construct a new picker which maintains a map of all child pickers + // that are ready. Each child is represented by a portion of the range + // proportional to its weight, such that the total range is the sum of the + // weights of all children. + WeightedPicker::PickerList picker_list; + uint32_t end = 0; + // Also count the number of children in each state, to determine the + // overall state. + size_t num_connecting = 0; + size_t num_idle = 0; + size_t num_transient_failures = 0; + for (const auto& p : targets_) { + const std::string& child_name = p.first; + const WeightedChild* child = p.second.get(); + // Skip the targets that are not in the latest update. 
+ if (config_->target_map().find(child_name) == config_->target_map().end()) { + continue; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] child=%s state=%s weight=%d picker=%p", + this, child_name.c_str(), + ConnectivityStateName(child->connectivity_state()), + child->weight(), child->picker_wrapper().get()); + } + switch (child->connectivity_state()) { + case GRPC_CHANNEL_READY: { + end += child->weight(); + picker_list.push_back(std::make_pair(end, child->picker_wrapper())); + break; + } + case GRPC_CHANNEL_CONNECTING: { + ++num_connecting; + break; + } + case GRPC_CHANNEL_IDLE: { + ++num_idle; + break; + } + case GRPC_CHANNEL_TRANSIENT_FAILURE: { + ++num_transient_failures; + break; + } + default: + GPR_UNREACHABLE_CODE(return ); + } + } + // Determine aggregated connectivity state. + grpc_connectivity_state connectivity_state; + if (!picker_list.empty()) { + connectivity_state = GRPC_CHANNEL_READY; + } else if (num_connecting > 0) { + connectivity_state = GRPC_CHANNEL_CONNECTING; + } else if (num_idle > 0) { + connectivity_state = GRPC_CHANNEL_IDLE; + } else { + connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, "[weighted_target_lb %p] connectivity changed to %s", + this, ConnectivityStateName(connectivity_state)); + } + std::unique_ptr picker; + absl::Status status; + switch (connectivity_state) { + case GRPC_CHANNEL_READY: + picker = absl::make_unique(std::move(picker_list)); + break; + case GRPC_CHANNEL_CONNECTING: + case GRPC_CHANNEL_IDLE: + picker = + absl::make_unique(Ref(DEBUG_LOCATION, "QueuePicker")); + break; + default: + status = absl::UnavailableError( + "weighted_target: all children report state TRANSIENT_FAILURE"); + picker = absl::make_unique(status); + } + channel_control_helper()->UpdateState(connectivity_state, status, + std::move(picker)); +} + +// +// WeightedTargetLb::WeightedChild +// + +WeightedTargetLb::WeightedChild::WeightedChild( + RefCountedPtr weighted_target_policy, + const std::string& name) + : weighted_target_policy_(std::move(weighted_target_policy)), name_(name) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, "[weighted_target_lb %p] created WeightedChild %p for %s", + weighted_target_policy_.get(), this, name_.c_str()); + } + GRPC_CLOSURE_INIT(&on_delayed_removal_timer_, OnDelayedRemovalTimer, this, + grpc_schedule_on_exec_ctx); +} + +WeightedTargetLb::WeightedChild::~WeightedChild() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] WeightedChild %p %s: destroying child", + weighted_target_policy_.get(), this, name_.c_str()); + } + weighted_target_policy_.reset(DEBUG_LOCATION, "WeightedChild"); +} + +void WeightedTargetLb::WeightedChild::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] WeightedChild %p %s: shutting down child", + weighted_target_policy_.get(), this, name_.c_str()); + } + // Remove the child policy's interested_parties pollset_set from the + // xDS policy. + grpc_pollset_set_del_pollset_set( + child_policy_->interested_parties(), + weighted_target_policy_->interested_parties()); + child_policy_.reset(); + // Drop our ref to the child's picker, in case it's holding a ref to + // the child. 
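Stepping back to UpdateStateLocked above: the per-child counts collapse into a simple precedence rule for the aggregate state. A standalone restatement with plain types, for illustration only:

  #include <cstddef>

  enum class AggregatedState { kReady, kConnecting, kIdle, kTransientFailure };

  // READY wins if any child picker is ready; otherwise CONNECTING, then IDLE;
  // TRANSIENT_FAILURE only when no counted child is in a better state.
  AggregatedState Aggregate(std::size_t num_ready, std::size_t num_connecting,
                            std::size_t num_idle) {
    if (num_ready > 0) return AggregatedState::kReady;
    if (num_connecting > 0) return AggregatedState::kConnecting;
    if (num_idle > 0) return AggregatedState::kIdle;
    return AggregatedState::kTransientFailure;
  }

The picker chosen above follows the same precedence: a WeightedPicker when READY, a QueuePicker while CONNECTING or IDLE, and a TransientFailurePicker otherwise.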
+ picker_wrapper_.reset(); + if (delayed_removal_timer_callback_pending_) { + delayed_removal_timer_callback_pending_ = false; + grpc_timer_cancel(&delayed_removal_timer_); + } + shutdown_ = true; + Unref(); +} + +OrphanablePtr +WeightedTargetLb::WeightedChild::CreateChildPolicyLocked( + const grpc_channel_args* args) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = weighted_target_policy_->work_serializer(); + lb_policy_args.args = args; + lb_policy_args.channel_control_helper = + absl::make_unique(this->Ref(DEBUG_LOCATION, "Helper")); + OrphanablePtr lb_policy = + MakeOrphanable(std::move(lb_policy_args), + &grpc_lb_weighted_target_trace); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] WeightedChild %p %s: Created new child " + "policy handler %p", + weighted_target_policy_.get(), this, name_.c_str(), + lb_policy.get()); + } + // Add the xDS's interested_parties pollset_set to that of the newly created + // child policy. This will make the child policy progress upon activity on + // xDS LB, which in turn is tied to the application's call. + grpc_pollset_set_add_pollset_set( + lb_policy->interested_parties(), + weighted_target_policy_->interested_parties()); + return lb_policy; +} + +void WeightedTargetLb::WeightedChild::UpdateLocked( + const WeightedTargetLbConfig::ChildConfig& config, + ServerAddressList addresses, const grpc_channel_args* args) { + if (weighted_target_policy_->shutting_down_) return; + // Update child weight. + weight_ = config.weight; + // Reactivate if needed. + if (delayed_removal_timer_callback_pending_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] WeightedChild %p %s: reactivating", + weighted_target_policy_.get(), this, name_.c_str()); + } + delayed_removal_timer_callback_pending_ = false; + grpc_timer_cancel(&delayed_removal_timer_); + } + // Create child policy if needed. + if (child_policy_ == nullptr) { + child_policy_ = CreateChildPolicyLocked(args); + } + // Construct update args. + UpdateArgs update_args; + update_args.config = config.config; + update_args.addresses = std::move(addresses); + update_args.args = grpc_channel_args_copy(args); + // Update the policy. + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] WeightedChild %p %s: Updating child " + "policy handler %p", + weighted_target_policy_.get(), this, name_.c_str(), + child_policy_.get()); + } + child_policy_->UpdateLocked(std::move(update_args)); +} + +void WeightedTargetLb::WeightedChild::ResetBackoffLocked() { + child_policy_->ResetBackoffLocked(); +} + +void WeightedTargetLb::WeightedChild::OnConnectivityStateUpdateLocked( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + // Cache the picker in the WeightedChild. + picker_wrapper_ = MakeRefCounted(std::move(picker)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] WeightedChild %p %s: connectivity " + "state update: state=%s (%s) picker_wrapper=%p", + weighted_target_policy_.get(), this, name_.c_str(), + ConnectivityStateName(state), status.ToString().c_str(), + picker_wrapper_.get()); + } + // If the child reports IDLE, immediately tell it to exit idle. + if (state == GRPC_CHANNEL_IDLE) child_policy_->ExitIdleLocked(); + // Decide what state to report for aggregation purposes. 
+ // If we haven't seen a failure since the last time we were in state + // READY, then we report the state change as-is. However, once we do see + // a failure, we report TRANSIENT_FAILURE and ignore any subsequent state + // changes until we go back into state READY. + if (!seen_failure_since_ready_) { + if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + seen_failure_since_ready_ = true; + } + } else { + if (state != GRPC_CHANNEL_READY) return; + seen_failure_since_ready_ = false; + } + connectivity_state_ = state; + // Notify the LB policy. + weighted_target_policy_->UpdateStateLocked(); +} + +void WeightedTargetLb::WeightedChild::DeactivateLocked() { + // If already deactivated, don't do that again. + if (weight_ == 0) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_weighted_target_trace)) { + gpr_log(GPR_INFO, + "[weighted_target_lb %p] WeightedChild %p %s: deactivating", + weighted_target_policy_.get(), this, name_.c_str()); + } + // Set the child weight to 0 so that future picker won't contain this child. + weight_ = 0; + // Start a timer to delete the child. + Ref(DEBUG_LOCATION, "WeightedChild+timer").release(); + delayed_removal_timer_callback_pending_ = true; + grpc_timer_init(&delayed_removal_timer_, + ExecCtx::Get()->Now() + kChildRetentionIntervalMs, + &on_delayed_removal_timer_); +} + +void WeightedTargetLb::WeightedChild::OnDelayedRemovalTimer( + void* arg, grpc_error_handle error) { + WeightedChild* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + self->weighted_target_policy_->work_serializer()->Run( + [self, error]() { self->OnDelayedRemovalTimerLocked(error); }, + DEBUG_LOCATION); +} + +void WeightedTargetLb::WeightedChild::OnDelayedRemovalTimerLocked( + grpc_error_handle error) { + if (error == GRPC_ERROR_NONE && delayed_removal_timer_callback_pending_ && + !shutdown_ && weight_ == 0) { + delayed_removal_timer_callback_pending_ = false; + weighted_target_policy_->targets_.erase(name_); + } + Unref(DEBUG_LOCATION, "WeightedChild+timer"); + GRPC_ERROR_UNREF(error); +} + +// +// WeightedTargetLb::WeightedChild::Helper +// + +RefCountedPtr +WeightedTargetLb::WeightedChild::Helper::CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) { + if (weighted_child_->weighted_target_policy_->shutting_down_) return nullptr; + return weighted_child_->weighted_target_policy_->channel_control_helper() + ->CreateSubchannel(std::move(address), args); +} + +void WeightedTargetLb::WeightedChild::Helper::UpdateState( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + if (weighted_child_->weighted_target_policy_->shutting_down_) return; + weighted_child_->OnConnectivityStateUpdateLocked(state, status, + std::move(picker)); +} + +void WeightedTargetLb::WeightedChild::Helper::RequestReresolution() { + if (weighted_child_->weighted_target_policy_->shutting_down_) return; + weighted_child_->weighted_target_policy_->channel_control_helper() + ->RequestReresolution(); +} + +absl::string_view WeightedTargetLb::WeightedChild::Helper::GetAuthority() { + return weighted_child_->weighted_target_policy_->channel_control_helper() + ->GetAuthority(); +} + +void WeightedTargetLb::WeightedChild::Helper::AddTraceEvent( + TraceSeverity severity, absl::string_view message) { + if (weighted_child_->weighted_target_policy_->shutting_down_) return; + weighted_child_->weighted_target_policy_->channel_control_helper() + ->AddTraceEvent(severity, message); +} + +// +// factory +// + +class WeightedTargetLbFactory : public 
LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kWeightedTarget; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (json.type() == Json::Type::JSON_NULL) { + // weighted_target was mentioned as a policy in the deprecated + // loadBalancingPolicy field or in the client API. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:weighted_target policy requires " + "configuration. Please use loadBalancingConfig field of service " + "config instead."); + return nullptr; + } + std::vector error_list; + // Weight map. + WeightedTargetLbConfig::TargetMap target_map; + auto it = json.object_value().find("targets"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:targets error:required field not present")); + } else if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:targets error:type should be object")); + } else { + for (const auto& p : it->second.object_value()) { + WeightedTargetLbConfig::ChildConfig child_config; + std::vector child_errors = + ParseChildConfig(p.second, &child_config); + if (!child_errors.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("field:targets key:", p.first), &child_errors)); + } else { + target_map[p.first] = std::move(child_config); + } + } + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "weighted_target_experimental LB policy config", &error_list); + return nullptr; + } + return MakeRefCounted(std::move(target_map)); + } + + private: + static std::vector ParseChildConfig( + const Json& json, WeightedTargetLbConfig::ChildConfig* child_config) { + std::vector error_list; + if (json.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "value should be of type object")); + return error_list; + } + // Weight. + auto it = json.object_value().find("weight"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "required field \"weight\" not specified")); + } else if (it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:weight error:must be of type number")); + } else { + int weight = gpr_parse_nonnegative_int(it->second.string_value().c_str()); + if (weight == -1) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:weight error:unparseable value")); + } else if (weight == 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:weight error:value must be greater than zero")); + } else { + child_config->weight = weight; + } + } + // Child policy. 
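For orientation while reading this parser, here is a hedged example of the JSON object it accepts; the target names, weights, and the round_robin child policy are purely illustrative:

  {
    "targets": {
      "region_a": { "weight": 75, "childPolicy": [ { "round_robin": {} } ] },
      "region_b": { "weight": 25, "childPolicy": [ { "round_robin": {} } ] }
    }
  }

Per the checks above, "targets" is required and must be an object, and each "weight" must be a number greater than zero; "childPolicy" is handed off to LoadBalancingPolicyRegistry just below.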
+ it = json.object_value().find("childPolicy"); + if (it != json.object_value().end()) { + grpc_error_handle parse_error = GRPC_ERROR_NONE; + child_config->config = + LoadBalancingPolicyRegistry::ParseLoadBalancingConfig(it->second, + &parse_error); + if (child_config->config == nullptr) { + GPR_DEBUG_ASSERT(parse_error != GRPC_ERROR_NONE); + std::vector child_errors; + child_errors.push_back(parse_error); + error_list.push_back( + GRPC_ERROR_CREATE_FROM_VECTOR("field:childPolicy", &child_errors)); + } + } + return error_list; + } +}; + +} // namespace + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_lb_policy_weighted_target_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_weighted_target_shutdown() {} diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc new file mode 100644 index 00000000..49a4496e --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc @@ -0,0 +1,745 @@ +// +// Copyright 2019 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include + +#include "absl/strings/str_cat.h" + +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/xds/xds_certificate_provider.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/security/credentials/xds/xds_credentials.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_core { + +TraceFlag grpc_cds_lb_trace(false, "cds_lb"); + +namespace { + +constexpr char kCds[] = "cds_experimental"; + +// Config for this LB policy. +class CdsLbConfig : public LoadBalancingPolicy::Config { + public: + explicit CdsLbConfig(std::string cluster) : cluster_(std::move(cluster)) {} + const std::string& cluster() const { return cluster_; } + const char* name() const override { return kCds; } + + private: + std::string cluster_; +}; + +// CDS LB policy. +class CdsLb : public LoadBalancingPolicy { + public: + CdsLb(RefCountedPtr xds_client, Args args); + + const char* name() const override { return kCds; } + + void UpdateLocked(UpdateArgs args) override; + void ResetBackoffLocked() override; + void ExitIdleLocked() override; + + private: + // Watcher for getting cluster data from XdsClient. 
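Since CdsLbConfig carries nothing but a cluster name, the service-config object accepted by this file's parser (shown near the end of the file) is as small as the following; the cluster name is illustrative:

  { "cluster": "example_cluster" }

That name identifies the CDS resource the policy asks the XdsClient to watch through the ClusterWatcher declared below.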
+ class ClusterWatcher : public XdsClient::ClusterWatcherInterface { + public: + ClusterWatcher(RefCountedPtr parent, std::string name) + : parent_(std::move(parent)), name_(std::move(name)) {} + + void OnClusterChanged(XdsApi::CdsUpdate cluster_data) override { + new Notifier(parent_, name_, std::move(cluster_data)); + } + void OnError(grpc_error_handle error) override { + new Notifier(parent_, name_, error); + } + void OnResourceDoesNotExist() override { new Notifier(parent_, name_); } + + private: + class Notifier { + public: + Notifier(RefCountedPtr parent, std::string name, + XdsApi::CdsUpdate update); + Notifier(RefCountedPtr parent, std::string name, + grpc_error_handle error); + explicit Notifier(RefCountedPtr parent, std::string name); + + private: + enum Type { kUpdate, kError, kDoesNotExist }; + + static void RunInExecCtx(void* arg, grpc_error_handle error); + void RunInWorkSerializer(grpc_error_handle error); + + RefCountedPtr parent_; + std::string name_; + grpc_closure closure_; + XdsApi::CdsUpdate update_; + Type type_; + }; + + RefCountedPtr parent_; + std::string name_; + }; + + struct WatcherState { + // Pointer to watcher, to be used when cancelling. + // Not owned, so do not dereference. + ClusterWatcher* watcher = nullptr; + // Most recent update obtained from this watcher. + absl::optional update; + }; + + // Delegating helper to be passed to child policy. + class Helper : public ChannelControlHelper { + public: + explicit Helper(RefCountedPtr parent) : parent_(std::move(parent)) {} + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override; + void RequestReresolution() override; + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + RefCountedPtr parent_; + }; + + ~CdsLb() override; + + void ShutdownLocked() override; + + bool GenerateDiscoveryMechanismForCluster( + const std::string& name, Json::Array* discovery_mechanisms, + std::set* clusters_needed); + void OnClusterChanged(const std::string& name, + XdsApi::CdsUpdate cluster_data); + void OnError(const std::string& name, grpc_error_handle error); + void OnResourceDoesNotExist(const std::string& name); + + grpc_error_handle UpdateXdsCertificateProvider( + const std::string& cluster_name, const XdsApi::CdsUpdate& cluster_data); + + void CancelClusterDataWatch(absl::string_view cluster_name, + XdsClient::ClusterWatcherInterface* watcher, + bool delay_unsubscription = false); + + void MaybeDestroyChildPolicyLocked(); + + RefCountedPtr config_; + + // Current channel args from the resolver. + const grpc_channel_args* args_ = nullptr; + + // The xds client. + RefCountedPtr xds_client_; + + // Maps from cluster name to the state for that cluster. + // The root of the tree is config_->cluster(). + std::map watchers_; + + RefCountedPtr root_certificate_provider_; + RefCountedPtr identity_certificate_provider_; + RefCountedPtr xds_certificate_provider_; + + // Child LB policy. + OrphanablePtr child_policy_; + + // Internal state. 
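The Notifier above exists because XdsClient may invoke watcher callbacks from arbitrary threads, while CdsLb state must only be touched inside its WorkSerializer. A std-only analogue of that serialization guarantee, for illustration (this is not grpc's WorkSerializer API):

  #include <functional>
  #include <mutex>
  #include <queue>
  #include <utility>

  class Serializer {
   public:
    // Enqueue fn; whichever caller finds the queue idle drains it, so tasks
    // for one policy never run concurrently -- the property CdsLb relies on.
    void Run(std::function<void()> fn) {
      std::unique_lock<std::mutex> lock(mu_);
      queue_.push(std::move(fn));
      if (draining_) return;  // an earlier caller is already draining
      draining_ = true;
      while (!queue_.empty()) {
        std::function<void()> next = std::move(queue_.front());
        queue_.pop();
        lock.unlock();
        next();  // run outside the lock
        lock.lock();
      }
      draining_ = false;
    }

   private:
    std::mutex mu_;
    std::queue<std::function<void()>> queue_;
    bool draining_ = false;
  };

The real code additionally bounces each notification through an ExecCtx closure before calling WorkSerializer::Run, a grpc-internal detail the sketch omits.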
+ bool shutting_down_ = false; +}; + +// +// CdsLb::ClusterWatcher::Notifier +// + +CdsLb::ClusterWatcher::Notifier::Notifier(RefCountedPtr parent, + std::string name, + XdsApi::CdsUpdate update) + : parent_(std::move(parent)), + name_(std::move(name)), + update_(std::move(update)), + type_(kUpdate) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); +} + +CdsLb::ClusterWatcher::Notifier::Notifier(RefCountedPtr parent, + std::string name, + grpc_error_handle error) + : parent_(std::move(parent)), name_(std::move(name)), type_(kError) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, error); +} + +CdsLb::ClusterWatcher::Notifier::Notifier(RefCountedPtr parent, + std::string name) + : parent_(std::move(parent)), name_(std::move(name)), type_(kDoesNotExist) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); +} + +void CdsLb::ClusterWatcher::Notifier::RunInExecCtx(void* arg, + grpc_error_handle error) { + Notifier* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); + self->parent_->work_serializer()->Run( + [self, error]() { self->RunInWorkSerializer(error); }, DEBUG_LOCATION); +} + +void CdsLb::ClusterWatcher::Notifier::RunInWorkSerializer( + grpc_error_handle error) { + switch (type_) { + case kUpdate: + parent_->OnClusterChanged(name_, std::move(update_)); + break; + case kError: + parent_->OnError(name_, error); + break; + case kDoesNotExist: + parent_->OnResourceDoesNotExist(name_); + break; + }; + delete this; +} + +// +// CdsLb::Helper +// + +RefCountedPtr CdsLb::Helper::CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) { + if (parent_->shutting_down_) return nullptr; + return parent_->channel_control_helper()->CreateSubchannel(std::move(address), + args); +} + +void CdsLb::Helper::UpdateState(grpc_connectivity_state state, + const absl::Status& status, + std::unique_ptr picker) { + if (parent_->shutting_down_ || parent_->child_policy_ == nullptr) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, + "[cdslb %p] state updated by child: %s message_state: (%s)", this, + ConnectivityStateName(state), status.ToString().c_str()); + } + parent_->channel_control_helper()->UpdateState(state, status, + std::move(picker)); +} + +void CdsLb::Helper::RequestReresolution() { + if (parent_->shutting_down_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] Re-resolution requested from child policy.", + parent_.get()); + } + parent_->channel_control_helper()->RequestReresolution(); +} + +absl::string_view CdsLb::Helper::GetAuthority() { + return parent_->channel_control_helper()->GetAuthority(); +} + +void CdsLb::Helper::AddTraceEvent(TraceSeverity severity, + absl::string_view message) { + if (parent_->shutting_down_) return; + parent_->channel_control_helper()->AddTraceEvent(severity, message); +} + +// +// CdsLb +// + +CdsLb::CdsLb(RefCountedPtr xds_client, Args args) + : LoadBalancingPolicy(std::move(args)), xds_client_(std::move(xds_client)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] created -- using xds client %p", this, + xds_client_.get()); + } +} + +CdsLb::~CdsLb() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] destroying cds LB policy", this); + } +} + +void CdsLb::ShutdownLocked() { + if 
(GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] shutting down", this); + } + shutting_down_ = true; + MaybeDestroyChildPolicyLocked(); + if (xds_client_ != nullptr) { + for (auto& watcher : watchers_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] cancelling watch for cluster %s", this, + watcher.first.c_str()); + } + CancelClusterDataWatch(watcher.first, watcher.second.watcher, + /*delay_unsubscription=*/false); + } + watchers_.clear(); + xds_client_.reset(DEBUG_LOCATION, "CdsLb"); + } + grpc_channel_args_destroy(args_); + args_ = nullptr; +} + +void CdsLb::MaybeDestroyChildPolicyLocked() { + if (child_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(), + interested_parties()); + child_policy_.reset(); + } +} + +void CdsLb::ResetBackoffLocked() { + if (child_policy_ != nullptr) child_policy_->ResetBackoffLocked(); +} + +void CdsLb::ExitIdleLocked() { + if (child_policy_ != nullptr) child_policy_->ExitIdleLocked(); +} + +void CdsLb::UpdateLocked(UpdateArgs args) { + // Update config. + auto old_config = std::move(config_); + config_ = std::move(args.config); + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] received update: cluster=%s", this, + config_->cluster().c_str()); + } + // Update args. + grpc_channel_args_destroy(args_); + args_ = args.args; + args.args = nullptr; + // If cluster name changed, cancel watcher and restart. + if (old_config == nullptr || old_config->cluster() != config_->cluster()) { + if (old_config != nullptr) { + for (auto& watcher : watchers_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] cancelling watch for cluster %s", this, + watcher.first.c_str()); + } + CancelClusterDataWatch(watcher.first, watcher.second.watcher, + /*delay_unsubscription=*/true); + } + watchers_.clear(); + } + auto watcher = absl::make_unique(Ref(), config_->cluster()); + watchers_[config_->cluster()].watcher = watcher.get(); + xds_client_->WatchClusterData(config_->cluster(), std::move(watcher)); + } +} + +// This method will attempt to generate one or multiple entries of discovery +// mechanism recursively: +// For cluster types EDS or LOGICAL_DNS, one discovery mechanism entry may be +// generated cluster name, type and other data from the CdsUpdate inserted into +// the entry and the entry appended to the array of entries. +// Note, discovery mechanism entry can be generated if an CdsUpdate is +// available; otherwise, just return false. For cluster type AGGREGATE, +// recursively call the method for each child cluster. +bool CdsLb::GenerateDiscoveryMechanismForCluster( + const std::string& name, Json::Array* discovery_mechanisms, + std::set* clusters_needed) { + clusters_needed->insert(name); + auto& state = watchers_[name]; + // Create a new watcher if needed. + if (state.watcher == nullptr) { + auto watcher = absl::make_unique(Ref(), name); + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] starting watch for cluster %s", this, + name.c_str()); + } + state.watcher = watcher.get(); + xds_client_->WatchClusterData(name, std::move(watcher)); + return false; + } + // Don't have the update we need yet. + if (!state.update.has_value()) return false; + // For AGGREGATE clusters, recursively expand to child clusters. 
+ if (state.update->cluster_type == XdsApi::CdsUpdate::ClusterType::AGGREGATE) { + bool missing_cluster = false; + for (const std::string& child_name : + state.update->prioritized_cluster_names) { + if (!GenerateDiscoveryMechanismForCluster( + child_name, discovery_mechanisms, clusters_needed)) { + missing_cluster = true; + } + } + return !missing_cluster; + } + Json::Object mechanism = { + {"clusterName", name}, + {"max_concurrent_requests", state.update->max_concurrent_requests}, + }; + switch (state.update->cluster_type) { + case XdsApi::CdsUpdate::ClusterType::EDS: + mechanism["type"] = "EDS"; + if (!state.update->eds_service_name.empty()) { + mechanism["edsServiceName"] = state.update->eds_service_name; + } + break; + case XdsApi::CdsUpdate::ClusterType::LOGICAL_DNS: + mechanism["type"] = "LOGICAL_DNS"; + mechanism["dnsHostname"] = state.update->dns_hostname; + break; + default: + GPR_ASSERT(0); + break; + } + if (state.update->lrs_load_reporting_server_name.has_value()) { + mechanism["lrsLoadReportingServerName"] = + state.update->lrs_load_reporting_server_name.value(); + } + discovery_mechanisms->emplace_back(std::move(mechanism)); + return true; +} + +void CdsLb::OnClusterChanged(const std::string& name, + XdsApi::CdsUpdate cluster_data) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log( + GPR_INFO, + "[cdslb %p] received CDS update for cluster %s from xds client %p: %s", + this, name.c_str(), xds_client_.get(), cluster_data.ToString().c_str()); + } + // Store the update in the map if we are still interested in watching this + // cluster (i.e., it is not cancelled already). + // If we've already deleted this entry, then this is an update notification + // that was scheduled before the deletion, so we can just ignore it. + auto it = watchers_.find(name); + if (it == watchers_.end()) return; + it->second.update = cluster_data; + // Take care of integration with new certificate code. + grpc_error_handle error = GRPC_ERROR_NONE; + error = UpdateXdsCertificateProvider(name, it->second.update.value()); + if (error != GRPC_ERROR_NONE) { + return OnError(name, error); + } + // Scan the map starting from the root cluster to generate the list of + // discovery mechanisms. If we don't have some of the data we need (i.e., we + // just started up and not all watchers have returned data yet), then don't + // update the child policy at all. + Json::Array discovery_mechanisms; + std::set clusters_needed; + if (GenerateDiscoveryMechanismForCluster( + config_->cluster(), &discovery_mechanisms, &clusters_needed)) { + // Construct config for child policy. 
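To make the construction below easier to follow, this is roughly the shape of the generated child config for a single non-aggregate EDS cluster; every value is a placeholder, and the exact fields depend on the CdsUpdate:

  [
    {
      "xds_cluster_resolver_experimental": {
        "xdsLbPolicy": [ { "ROUND_ROBIN": {} } ],
        "discoveryMechanisms": [
          {
            "clusterName": "example_cluster",
            "max_concurrent_requests": 1024,
            "type": "EDS",
            "edsServiceName": "example_eds_service"
          }
        ]
      }
    }
  ]

An AGGREGATE cluster instead contributes one mechanism entry per underlying cluster, in the order produced by the recursion above, and lrsLoadReportingServerName is added only when the CdsUpdate carries one.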
+ Json::Object xds_lb_policy; + if (cluster_data.lb_policy == "RING_HASH") { + xds_lb_policy["RING_HASH"] = Json::Object{ + {"min_ring_size", cluster_data.min_ring_size}, + {"max_ring_size", cluster_data.max_ring_size}, + }; + } else { + xds_lb_policy["ROUND_ROBIN"] = Json::Object(); + } + Json::Object child_config = { + {"xdsLbPolicy", + Json::Array{ + xds_lb_policy, + }}, + {"discoveryMechanisms", std::move(discovery_mechanisms)}, + }; + Json json = Json::Array{ + Json::Object{ + {"xds_cluster_resolver_experimental", std::move(child_config)}, + }, + }; + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + std::string json_str = json.Dump(/*indent=*/1); + gpr_log(GPR_INFO, "[cdslb %p] generated config for child policy: %s", + this, json_str.c_str()); + } + RefCountedPtr config = + LoadBalancingPolicyRegistry::ParseLoadBalancingConfig(json, &error); + if (error != GRPC_ERROR_NONE) { + OnError(name, error); + return; + } + // Create child policy if not already present. + if (child_policy_ == nullptr) { + LoadBalancingPolicy::Args args; + args.work_serializer = work_serializer(); + args.args = args_; + args.channel_control_helper = absl::make_unique(Ref()); + child_policy_ = LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( + config->name(), std::move(args)); + if (child_policy_ == nullptr) { + OnError(name, GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "failed to create child policy")); + return; + } + grpc_pollset_set_add_pollset_set(child_policy_->interested_parties(), + interested_parties()); + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] created child policy %s (%p)", this, + config->name(), child_policy_.get()); + } + } + // Update child policy. + UpdateArgs args; + args.config = std::move(config); + if (xds_certificate_provider_ != nullptr) { + grpc_arg arg_to_add = xds_certificate_provider_->MakeChannelArg(); + args.args = grpc_channel_args_copy_and_add(args_, &arg_to_add, 1); + } else { + args.args = grpc_channel_args_copy(args_); + } + child_policy_->UpdateLocked(std::move(args)); + } + // Remove entries in watchers_ for any clusters not in clusters_needed + for (auto it = watchers_.begin(); it != watchers_.end();) { + const std::string& cluster_name = it->first; + if (clusters_needed.find(cluster_name) != clusters_needed.end()) { + ++it; + continue; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { + gpr_log(GPR_INFO, "[cdslb %p] cancelling watch for cluster %s", this, + cluster_name.c_str()); + } + CancelClusterDataWatch(cluster_name, it->second.watcher, + /*delay_unsubscription=*/false); + it = watchers_.erase(it); + } +} + +void CdsLb::OnError(const std::string& name, grpc_error_handle error) { + gpr_log(GPR_ERROR, "[cdslb %p] xds error obtaining data for cluster %s: %s", + this, name.c_str(), grpc_error_std_string(error).c_str()); + // Go into TRANSIENT_FAILURE if we have not yet created the child + // policy (i.e., we have not yet received data from xds). Otherwise, + // we keep running with the data we had previously. 
+ if (child_policy_ == nullptr) { + absl::Status status = grpc_error_to_absl_status(error); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + } + GRPC_ERROR_UNREF(error); +} + +void CdsLb::OnResourceDoesNotExist(const std::string& name) { + gpr_log(GPR_ERROR, + "[cdslb %p] CDS resource for %s does not exist -- reporting " + "TRANSIENT_FAILURE", + this, name.c_str()); + absl::Status status = absl::UnavailableError( + absl::StrCat("CDS resource \"", config_->cluster(), "\" does not exist")); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + MaybeDestroyChildPolicyLocked(); +} + +grpc_error_handle CdsLb::UpdateXdsCertificateProvider( + const std::string& cluster_name, const XdsApi::CdsUpdate& cluster_data) { + // Early out if channel is not configured to use xds security. + grpc_channel_credentials* channel_credentials = + grpc_channel_credentials_find_in_args(args_); + if (channel_credentials == nullptr || + channel_credentials->type() != kCredentialsTypeXds) { + xds_certificate_provider_ = nullptr; + return GRPC_ERROR_NONE; + } + if (xds_certificate_provider_ == nullptr) { + xds_certificate_provider_ = MakeRefCounted(); + } + // Configure root cert. + absl::string_view root_provider_instance_name = + cluster_data.common_tls_context.certificate_validation_context + .ca_certificate_provider_instance.instance_name; + absl::string_view root_provider_cert_name = + cluster_data.common_tls_context.certificate_validation_context + .ca_certificate_provider_instance.certificate_name; + RefCountedPtr new_root_provider; + if (!root_provider_instance_name.empty()) { + new_root_provider = + xds_client_->certificate_provider_store() + .CreateOrGetCertificateProvider(root_provider_instance_name); + if (new_root_provider == nullptr) { + return grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Certificate provider instance name: \"", + root_provider_instance_name, "\" not recognized.")), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE); + } + } + if (root_certificate_provider_ != new_root_provider) { + if (root_certificate_provider_ != nullptr && + root_certificate_provider_->interested_parties() != nullptr) { + grpc_pollset_set_del_pollset_set( + interested_parties(), + root_certificate_provider_->interested_parties()); + } + if (new_root_provider != nullptr && + new_root_provider->interested_parties() != nullptr) { + grpc_pollset_set_add_pollset_set(interested_parties(), + new_root_provider->interested_parties()); + } + root_certificate_provider_ = std::move(new_root_provider); + } + xds_certificate_provider_->UpdateRootCertNameAndDistributor( + cluster_name, root_provider_cert_name, + root_certificate_provider_ == nullptr + ? nullptr + : root_certificate_provider_->distributor()); + // Configure identity cert. 
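The "certificate provider instance name" resolved above refers to an entry configured in the xDS bootstrap file. The fragment below is a loosely remembered illustration only; the field names and the file_watcher plugin are assumptions, not taken from this change:

  {
    "certificate_providers": {
      "example_instance": {
        "plugin_name": "file_watcher",
        "config": {
          "certificate_file": "/path/to/client.pem",
          "private_key_file": "/path/to/client.key",
          "ca_certificate_file": "/path/to/ca.pem"
        }
      }
    }
  }

With such an entry, an instance name of "example_instance" would resolve through the certificate_provider_store() lookup above; any other name hits the UNAVAILABLE error path.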
+ absl::string_view identity_provider_instance_name = + cluster_data.common_tls_context.tls_certificate_provider_instance + .instance_name; + absl::string_view identity_provider_cert_name = + cluster_data.common_tls_context.tls_certificate_provider_instance + .certificate_name; + RefCountedPtr new_identity_provider; + if (!identity_provider_instance_name.empty()) { + new_identity_provider = + xds_client_->certificate_provider_store() + .CreateOrGetCertificateProvider(identity_provider_instance_name); + if (new_identity_provider == nullptr) { + return grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Certificate provider instance name: \"", + identity_provider_instance_name, "\" not recognized.")), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE); + } + } + if (identity_certificate_provider_ != new_identity_provider) { + if (identity_certificate_provider_ != nullptr && + identity_certificate_provider_->interested_parties() != nullptr) { + grpc_pollset_set_del_pollset_set( + interested_parties(), + identity_certificate_provider_->interested_parties()); + } + if (new_identity_provider != nullptr && + new_identity_provider->interested_parties() != nullptr) { + grpc_pollset_set_add_pollset_set( + interested_parties(), new_identity_provider->interested_parties()); + } + identity_certificate_provider_ = std::move(new_identity_provider); + } + xds_certificate_provider_->UpdateIdentityCertNameAndDistributor( + cluster_name, identity_provider_cert_name, + identity_certificate_provider_ == nullptr + ? nullptr + : identity_certificate_provider_->distributor()); + // Configure SAN matchers. + const std::vector& match_subject_alt_names = + cluster_data.common_tls_context.certificate_validation_context + .match_subject_alt_names; + xds_certificate_provider_->UpdateSubjectAlternativeNameMatchers( + cluster_name, match_subject_alt_names); + return GRPC_ERROR_NONE; +} + +void CdsLb::CancelClusterDataWatch(absl::string_view cluster_name, + XdsClient::ClusterWatcherInterface* watcher, + bool delay_unsubscription) { + if (xds_certificate_provider_ != nullptr) { + std::string name(cluster_name); + xds_certificate_provider_->UpdateRootCertNameAndDistributor(name, "", + nullptr); + xds_certificate_provider_->UpdateIdentityCertNameAndDistributor(name, "", + nullptr); + xds_certificate_provider_->UpdateSubjectAlternativeNameMatchers(name, {}); + } + xds_client_->CancelClusterDataWatch(cluster_name, watcher, + delay_unsubscription); +} +// +// factory +// + +class CdsLbFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + RefCountedPtr xds_client = + XdsClient::GetFromChannelArgs(*args.args); + if (xds_client == nullptr) { + gpr_log(GPR_ERROR, + "XdsClient not present in channel args -- cannot instantiate " + "cds LB policy"); + return nullptr; + } + return MakeOrphanable(std::move(xds_client), std::move(args)); + } + + const char* name() const override { return kCds; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (json.type() == Json::Type::JSON_NULL) { + // xds was mentioned as a policy in the deprecated loadBalancingPolicy + // field or in the client API. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:cds policy requires configuration. 
" + "Please use loadBalancingConfig field of service config instead."); + return nullptr; + } + std::vector error_list; + // cluster name. + std::string cluster; + auto it = json.object_value().find("cluster"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "required field 'cluster' not present")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:cluster error:type should be string")); + } else { + cluster = it->second.string_value(); + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Cds Parser", &error_list); + return nullptr; + } + return MakeRefCounted(std::move(cluster)); + } +}; + +} // namespace + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_lb_policy_cds_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_cds_shutdown() {} diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_impl.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_impl.cc new file mode 100644 index 00000000..c65fa330 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_impl.cc @@ -0,0 +1,795 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include + +#include "absl/strings/string_view.h" + +#include + +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy/xds/xds.h" +#include "src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/ext/xds/xds_client_stats.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/work_serializer.h" + +namespace grpc_core { + +TraceFlag grpc_xds_cluster_impl_lb_trace(false, "xds_cluster_impl_lb"); + +namespace { + +// +// global circuit breaker atomic map +// + +class CircuitBreakerCallCounterMap { + public: + using Key = + std::pair; + + class CallCounter : public RefCounted { + public: + explicit CallCounter(Key key) : key_(std::move(key)) {} + ~CallCounter() override; + + uint32_t Load() { + return concurrent_requests_.load(std::memory_order_seq_cst); + } + uint32_t Increment() { return concurrent_requests_.fetch_add(1); } + void Decrement() { concurrent_requests_.fetch_sub(1); } + + private: + Key key_; + std::atomic concurrent_requests_{0}; + }; + + RefCountedPtr GetOrCreate(const std::string& cluster, + const std::string& eds_service_name); + + private: + Mutex mu_; + std::map map_ ABSL_GUARDED_BY(mu_); +}; + +CircuitBreakerCallCounterMap* g_call_counter_map = nullptr; + +RefCountedPtr +CircuitBreakerCallCounterMap::GetOrCreate(const std::string& cluster, + const std::string& eds_service_name) { + Key key(cluster, eds_service_name); + RefCountedPtr result; + MutexLock lock(&mu_); + auto it = map_.find(key); + if (it == map_.end()) { + it = map_.insert({key, nullptr}).first; + } else { + result = it->second->RefIfNonZero(); + } + if (result == nullptr) { + result = MakeRefCounted(std::move(key)); + it->second = result.get(); + } + return result; +} + +CircuitBreakerCallCounterMap::CallCounter::~CallCounter() { + MutexLock lock(&g_call_counter_map->mu_); + auto it = g_call_counter_map->map_.find(key_); + if (it != g_call_counter_map->map_.end() && it->second == this) { + g_call_counter_map->map_.erase(it); + } +} + +// +// LB policy +// + +constexpr char kXdsClusterImpl[] = "xds_cluster_impl_experimental"; + +// Config for xDS Cluster Impl LB policy. 
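A std-only analogue of the per-(cluster, EDS service) call counter defined above, shown only to illustrate how callers use it; the names and the per-gate limit are hypothetical, and unlike the real map this sketch does not prune empty entries:

  #include <atomic>
  #include <cstdint>
  #include <map>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <utility>

  class CallGate {
   public:
    explicit CallGate(uint32_t max_in_flight) : max_(max_in_flight) {}
    // Admission check; a caller that gets true must call Finish() when done.
    // As in the Picker further below, the load and the increment are separate
    // steps, so the cap can briefly be exceeded under concurrent picks.
    bool TryStart() {
      if (in_flight_.load(std::memory_order_relaxed) >= max_) return false;
      in_flight_.fetch_add(1, std::memory_order_relaxed);
      return true;
    }
    void Finish() { in_flight_.fetch_sub(1, std::memory_order_relaxed); }

   private:
    const uint32_t max_;
    std::atomic<uint32_t> in_flight_{0};
  };

  // One shared gate per (cluster, eds_service_name), mirroring the keying above.
  std::shared_ptr<CallGate> GetGate(const std::string& cluster,
                                    const std::string& eds_service,
                                    uint32_t max_in_flight) {
    static std::mutex mu;
    static std::map<std::pair<std::string, std::string>,
                    std::weak_ptr<CallGate>> gates;
    std::lock_guard<std::mutex> lock(mu);
    auto& slot = gates[{cluster, eds_service}];
    std::shared_ptr<CallGate> gate = slot.lock();
    if (gate == nullptr) {
      gate = std::make_shared<CallGate>(max_in_flight);
      slot = gate;
    }
    return gate;
  }

The real class keeps the limit in the picker rather than in the counter, and relies on RefIfNonZero() plus erase-on-destruction to keep the global map tidy.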
+class XdsClusterImplLbConfig : public LoadBalancingPolicy::Config { + public: + XdsClusterImplLbConfig( + RefCountedPtr child_policy, + std::string cluster_name, std::string eds_service_name, + absl::optional lrs_load_reporting_server_name, + uint32_t max_concurrent_requests, + RefCountedPtr drop_config) + : child_policy_(std::move(child_policy)), + cluster_name_(std::move(cluster_name)), + eds_service_name_(std::move(eds_service_name)), + lrs_load_reporting_server_name_( + std::move(lrs_load_reporting_server_name)), + max_concurrent_requests_(max_concurrent_requests), + drop_config_(std::move(drop_config)) {} + + const char* name() const override { return kXdsClusterImpl; } + + RefCountedPtr child_policy() const { + return child_policy_; + } + const std::string& cluster_name() const { return cluster_name_; } + const std::string& eds_service_name() const { return eds_service_name_; } + const absl::optional& lrs_load_reporting_server_name() const { + return lrs_load_reporting_server_name_; + }; + uint32_t max_concurrent_requests() const { return max_concurrent_requests_; } + RefCountedPtr drop_config() const { + return drop_config_; + } + + private: + RefCountedPtr child_policy_; + std::string cluster_name_; + std::string eds_service_name_; + absl::optional lrs_load_reporting_server_name_; + uint32_t max_concurrent_requests_; + RefCountedPtr drop_config_; +}; + +// xDS Cluster Impl LB policy. +class XdsClusterImplLb : public LoadBalancingPolicy { + public: + XdsClusterImplLb(RefCountedPtr xds_client, Args args); + + const char* name() const override { return kXdsClusterImpl; } + + void UpdateLocked(UpdateArgs args) override; + void ExitIdleLocked() override; + void ResetBackoffLocked() override; + + private: + class StatsSubchannelWrapper : public DelegatingSubchannel { + public: + StatsSubchannelWrapper( + RefCountedPtr wrapped_subchannel, + RefCountedPtr locality_stats) + : DelegatingSubchannel(std::move(wrapped_subchannel)), + locality_stats_(std::move(locality_stats)) {} + + XdsClusterLocalityStats* locality_stats() const { + return locality_stats_.get(); + } + + private: + RefCountedPtr locality_stats_; + }; + + // A simple wrapper for ref-counting a picker from the child policy. + class RefCountedPicker : public RefCounted { + public: + explicit RefCountedPicker(std::unique_ptr picker) + : picker_(std::move(picker)) {} + PickResult Pick(PickArgs args) { return picker_->Pick(args); } + + private: + std::unique_ptr picker_; + }; + + // A picker that wraps the picker from the child to perform drops. 
+ class Picker : public SubchannelPicker { + public: + Picker(XdsClusterImplLb* xds_cluster_impl_lb, + RefCountedPtr picker); + + PickResult Pick(PickArgs args) override; + + private: + RefCountedPtr call_counter_; + uint32_t max_concurrent_requests_; + RefCountedPtr drop_config_; + RefCountedPtr drop_stats_; + RefCountedPtr picker_; + }; + + class Helper : public ChannelControlHelper { + public: + explicit Helper(RefCountedPtr xds_cluster_impl_policy) + : xds_cluster_impl_policy_(std::move(xds_cluster_impl_policy)) {} + + ~Helper() override { + xds_cluster_impl_policy_.reset(DEBUG_LOCATION, "Helper"); + } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override; + void RequestReresolution() override; + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + RefCountedPtr xds_cluster_impl_policy_; + }; + + ~XdsClusterImplLb() override; + + void ShutdownLocked() override; + + OrphanablePtr CreateChildPolicyLocked( + const grpc_channel_args* args); + void UpdateChildPolicyLocked(ServerAddressList addresses, + const grpc_channel_args* args); + + void MaybeUpdatePickerLocked(); + + // Current config from the resolver. + RefCountedPtr config_; + + // Current concurrent number of requests. + RefCountedPtr call_counter_; + + // Internal state. + bool shutting_down_ = false; + + // The xds client. + RefCountedPtr xds_client_; + + // The stats for client-side load reporting. + RefCountedPtr drop_stats_; + + OrphanablePtr child_policy_; + + // Latest state and picker reported by the child policy. + grpc_connectivity_state state_ = GRPC_CHANNEL_IDLE; + absl::Status status_; + RefCountedPtr picker_; +}; + +// +// XdsClusterImplLb::Picker +// + +XdsClusterImplLb::Picker::Picker(XdsClusterImplLb* xds_cluster_impl_lb, + RefCountedPtr picker) + : call_counter_(xds_cluster_impl_lb->call_counter_), + max_concurrent_requests_( + xds_cluster_impl_lb->config_->max_concurrent_requests()), + drop_config_(xds_cluster_impl_lb->config_->drop_config()), + drop_stats_(xds_cluster_impl_lb->drop_stats_), + picker_(std::move(picker)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_impl_lb %p] constructed new picker %p", + xds_cluster_impl_lb, this); + } +} + +LoadBalancingPolicy::PickResult XdsClusterImplLb::Picker::Pick( + LoadBalancingPolicy::PickArgs args) { + // Handle EDS drops. + const std::string* drop_category; + if (drop_config_->ShouldDrop(&drop_category)) { + if (drop_stats_ != nullptr) drop_stats_->AddCallDropped(*drop_category); + return PickResult::Drop(absl::UnavailableError( + absl::StrCat("EDS-configured drop: ", *drop_category))); + } + // Handle circuit breaking. + uint32_t current = call_counter_->Load(); + // Check and see if we exceeded the max concurrent requests count. + if (current >= max_concurrent_requests_) { + if (drop_stats_ != nullptr) drop_stats_->AddUncategorizedDrops(); + return PickResult::Drop(absl::UnavailableError("circuit breaker drop")); + } + call_counter_->Increment(); + // If we're not dropping the call, we should always have a child picker. + if (picker_ == nullptr) { // Should never happen. + call_counter_->Decrement(); + return PickResult::Fail(absl::InternalError( + "xds_cluster_impl picker not given any child picker")); + } + // Not dropping, so delegate to child picker. 
+ PickResult result = picker_->Pick(args); + auto* complete_pick = absl::get_if(&result.result); + if (complete_pick != nullptr) { + XdsClusterLocalityStats* locality_stats = nullptr; + if (drop_stats_ != nullptr) { // If load reporting is enabled. + auto* subchannel_wrapper = + static_cast(complete_pick->subchannel.get()); + // Handle load reporting. + locality_stats = subchannel_wrapper->locality_stats()->Ref().release(); + // Record a call started. + locality_stats->AddCallStarted(); + // Unwrap subchannel to pass back up the stack. + complete_pick->subchannel = subchannel_wrapper->wrapped_subchannel(); + } + // Intercept the recv_trailing_metadata op to record call completion. + auto* call_counter = call_counter_->Ref(DEBUG_LOCATION, "call").release(); + auto original_recv_trailing_metadata_ready = + complete_pick->recv_trailing_metadata_ready; + complete_pick->recv_trailing_metadata_ready = + // Note: This callback does not run in either the control plane + // work serializer or in the data plane mutex. + [locality_stats, original_recv_trailing_metadata_ready, call_counter]( + absl::Status status, MetadataInterface* metadata, + CallState* call_state) { + // Record call completion for load reporting. + if (locality_stats != nullptr) { + locality_stats->AddCallFinished(!status.ok()); + locality_stats->Unref(DEBUG_LOCATION, "LocalityStats+call"); + } + // Decrement number of calls in flight. + call_counter->Decrement(); + call_counter->Unref(DEBUG_LOCATION, "call"); + // Invoke the original recv_trailing_metadata_ready callback, if any. + if (original_recv_trailing_metadata_ready != nullptr) { + original_recv_trailing_metadata_ready(status, metadata, call_state); + } + }; + } else { + // TODO(roth): We should ideally also record call failures here in the case + // where a pick fails. This is challenging, because we don't know which + // picks are for wait_for_ready RPCs or how many times we'll return a + // failure for the same wait_for_ready RPC. + call_counter_->Decrement(); + } + return result; +} + +// +// XdsClusterImplLb +// + +XdsClusterImplLb::XdsClusterImplLb(RefCountedPtr xds_client, + Args args) + : LoadBalancingPolicy(std::move(args)), xds_client_(std::move(xds_client)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_impl_lb %p] created -- using xds client %p", + this, xds_client_.get()); + } +} + +XdsClusterImplLb::~XdsClusterImplLb() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_impl_lb %p] destroying xds_cluster_impl LB policy", + this); + } +} + +void XdsClusterImplLb::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_impl_lb %p] shutting down", this); + } + shutting_down_ = true; + // Remove the child policy's interested_parties pollset_set from the + // xDS policy. + if (child_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(), + interested_parties()); + child_policy_.reset(); + } + // Drop our ref to the child's picker, in case it's holding a ref to + // the child. + picker_.reset(); + drop_stats_.reset(); + xds_client_.reset(); +} + +void XdsClusterImplLb::ExitIdleLocked() { + if (child_policy_ != nullptr) child_policy_->ExitIdleLocked(); +} + +void XdsClusterImplLb::ResetBackoffLocked() { + // The XdsClient will have its backoff reset by the xds resolver, so we + // don't need to do it here. 
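Looking back at Picker::Pick above, the recv_trailing_metadata interception is, at its core, wrapping an existing completion callback so that bookkeeping always runs before the original hook. A distilled std-only sketch of that shape (illustrative; the real code carries locality stats and a call-counter ref rather than bare atomics):

  #include <atomic>
  #include <functional>

  using Completion = std::function<void(bool call_ok)>;

  Completion WrapWithBookkeeping(Completion original,
                                 std::atomic<int>* in_flight,
                                 std::atomic<int>* failures) {
    return [original, in_flight, failures](bool call_ok) {
      if (!call_ok) failures->fetch_add(1);  // record the outcome
      in_flight->fetch_sub(1);               // release the circuit-breaker slot
      if (original) original(call_ok);       // chain to the pre-existing hook
    };
  }

Because the decrement is deferred to this wrapper only on the completed-pick path, the other paths in Pick (no child picker, or a non-complete result) decrement immediately, as seen above.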
+ if (child_policy_ != nullptr) child_policy_->ResetBackoffLocked(); +} + +void XdsClusterImplLb::UpdateLocked(UpdateArgs args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_impl_lb %p] Received update", this); + } + // Update config. + const bool is_initial_update = config_ == nullptr; + auto old_config = std::move(config_); + config_ = std::move(args.config); + // On initial update, create drop stats. + if (is_initial_update) { + if (config_->lrs_load_reporting_server_name().has_value()) { + drop_stats_ = xds_client_->AddClusterDropStats( + config_->lrs_load_reporting_server_name().value(), + config_->cluster_name(), config_->eds_service_name()); + } + call_counter_ = g_call_counter_map->GetOrCreate( + config_->cluster_name(), config_->eds_service_name()); + } else { + // Cluster name, EDS service name, and LRS server name should never + // change, because the EDS policy above us should be swapped out if + // that happens. + GPR_ASSERT(config_->cluster_name() == old_config->cluster_name()); + GPR_ASSERT(config_->eds_service_name() == old_config->eds_service_name()); + GPR_ASSERT(config_->lrs_load_reporting_server_name() == + old_config->lrs_load_reporting_server_name()); + } + // Update picker if max_concurrent_requests has changed. + if (is_initial_update || config_->max_concurrent_requests() != + old_config->max_concurrent_requests()) { + MaybeUpdatePickerLocked(); + } + // Update child policy. + UpdateChildPolicyLocked(std::move(args.addresses), args.args); +} + +void XdsClusterImplLb::MaybeUpdatePickerLocked() { + // If we're dropping all calls, report READY, regardless of what (or + // whether) the child has reported. + if (config_->drop_config() != nullptr && config_->drop_config()->drop_all()) { + auto drop_picker = absl::make_unique(this, picker_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_impl_lb %p] updating connectivity (drop all): " + "state=READY " + "picker=%p", + this, drop_picker.get()); + } + channel_control_helper()->UpdateState(GRPC_CHANNEL_READY, absl::Status(), + std::move(drop_picker)); + return; + } + // Otherwise, update only if we have a child picker. + if (picker_ != nullptr) { + auto drop_picker = absl::make_unique(this, picker_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_impl_lb %p] updating connectivity: state=%s " + "status=(%s) " + "picker=%p", + this, ConnectivityStateName(state_), status_.ToString().c_str(), + drop_picker.get()); + } + channel_control_helper()->UpdateState(state_, status_, + std::move(drop_picker)); + } +} + +OrphanablePtr XdsClusterImplLb::CreateChildPolicyLocked( + const grpc_channel_args* args) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = work_serializer(); + lb_policy_args.args = args; + lb_policy_args.channel_control_helper = + absl::make_unique(Ref(DEBUG_LOCATION, "Helper")); + OrphanablePtr lb_policy = + MakeOrphanable(std::move(lb_policy_args), + &grpc_xds_cluster_impl_lb_trace); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_impl_lb %p] Created new child policy handler %p", + this, lb_policy.get()); + } + // Add our interested_parties pollset_set to that of the newly created + // child policy. This will make the child policy progress upon activity on + // this policy, which in turn is tied to the application's call. 
+ grpc_pollset_set_add_pollset_set(lb_policy->interested_parties(), + interested_parties()); + return lb_policy; +} + +void XdsClusterImplLb::UpdateChildPolicyLocked(ServerAddressList addresses, + const grpc_channel_args* args) { + // Create policy if needed. + if (child_policy_ == nullptr) { + child_policy_ = CreateChildPolicyLocked(args); + } + // Construct update args. + UpdateArgs update_args; + update_args.addresses = std::move(addresses); + update_args.config = config_->child_policy(); + grpc_arg cluster_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_XDS_CLUSTER_NAME), + const_cast(config_->cluster_name().c_str())); + update_args.args = grpc_channel_args_copy_and_add(args, &cluster_arg, 1); + // Update the policy. + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_impl_lb %p] Updating child policy handler %p", this, + child_policy_.get()); + } + child_policy_->UpdateLocked(std::move(update_args)); +} + +// +// XdsClusterImplLb::Helper +// + +RefCountedPtr XdsClusterImplLb::Helper::CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) { + if (xds_cluster_impl_policy_->shutting_down_) return nullptr; + // If load reporting is enabled, wrap the subchannel such that it + // includes the locality stats object, which will be used by the EdsPicker. + if (xds_cluster_impl_policy_->config_->lrs_load_reporting_server_name() + .has_value()) { + RefCountedPtr locality_name; + auto* attribute = address.GetAttribute(kXdsLocalityNameAttributeKey); + if (attribute != nullptr) { + const auto* locality_attr = + static_cast(attribute); + locality_name = locality_attr->locality_name(); + } + RefCountedPtr locality_stats = + xds_cluster_impl_policy_->xds_client_->AddClusterLocalityStats( + *xds_cluster_impl_policy_->config_ + ->lrs_load_reporting_server_name(), + xds_cluster_impl_policy_->config_->cluster_name(), + xds_cluster_impl_policy_->config_->eds_service_name(), + std::move(locality_name)); + return MakeRefCounted( + xds_cluster_impl_policy_->channel_control_helper()->CreateSubchannel( + std::move(address), args), + std::move(locality_stats)); + } + // Load reporting not enabled, so don't wrap the subchannel. + return xds_cluster_impl_policy_->channel_control_helper()->CreateSubchannel( + std::move(address), args); +} + +void XdsClusterImplLb::Helper::UpdateState( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + if (xds_cluster_impl_policy_->shutting_down_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_impl_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_impl_lb %p] child connectivity state update: " + "state=%s (%s) " + "picker=%p", + xds_cluster_impl_policy_.get(), ConnectivityStateName(state), + status.ToString().c_str(), picker.get()); + } + // Save the state and picker. + xds_cluster_impl_policy_->state_ = state; + xds_cluster_impl_policy_->status_ = status; + xds_cluster_impl_policy_->picker_ = + MakeRefCounted(std::move(picker)); + // Wrap the picker and return it to the channel. 
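+  // (A descriptive note, not from the original change: rewrapping here is
+  // what keeps drop decisions in the picking path -- the wrapper picker
+  // created in MaybeUpdatePickerLocked() applies the drop config and the
+  // max_concurrent_requests call counter in addition to delegating to the
+  // child picker cached above; the exact ordering of those checks lives in
+  // the picker's Pick() implementation.)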
+ xds_cluster_impl_policy_->MaybeUpdatePickerLocked(); +} + +void XdsClusterImplLb::Helper::RequestReresolution() { + if (xds_cluster_impl_policy_->shutting_down_) return; + xds_cluster_impl_policy_->channel_control_helper()->RequestReresolution(); +} + +absl::string_view XdsClusterImplLb::Helper::GetAuthority() { + return xds_cluster_impl_policy_->channel_control_helper()->GetAuthority(); +} + +void XdsClusterImplLb::Helper::AddTraceEvent(TraceSeverity severity, + absl::string_view message) { + if (xds_cluster_impl_policy_->shutting_down_) return; + xds_cluster_impl_policy_->channel_control_helper()->AddTraceEvent(severity, + message); +} + +// +// factory +// + +class XdsClusterImplLbFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + RefCountedPtr xds_client = + XdsClient::GetFromChannelArgs(*args.args); + if (xds_client == nullptr) { + gpr_log(GPR_ERROR, + "XdsClient not present in channel args -- cannot instantiate " + "xds_cluster_impl LB policy"); + return nullptr; + } + return MakeOrphanable(std::move(xds_client), + std::move(args)); + } + + const char* name() const override { return kXdsClusterImpl; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (json.type() == Json::Type::JSON_NULL) { + // This policy was configured in the deprecated loadBalancingPolicy + // field or in the client API. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:xds_cluster_impl policy requires " + "configuration. Please use loadBalancingConfig field of service " + "config instead."); + return nullptr; + } + std::vector error_list; + // Child policy. + RefCountedPtr child_policy; + auto it = json.object_value().find("childPolicy"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:childPolicy error:required field missing")); + } else { + grpc_error_handle parse_error = GRPC_ERROR_NONE; + child_policy = LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + it->second, &parse_error); + if (child_policy == nullptr) { + GPR_DEBUG_ASSERT(parse_error != GRPC_ERROR_NONE); + std::vector child_errors; + child_errors.push_back(parse_error); + error_list.push_back( + GRPC_ERROR_CREATE_FROM_VECTOR("field:childPolicy", &child_errors)); + } + } + // Cluster name. + std::string cluster_name; + it = json.object_value().find("clusterName"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:clusterName error:required field missing")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:clusterName error:type should be string")); + } else { + cluster_name = it->second.string_value(); + } + // EDS service name. + std::string eds_service_name; + it = json.object_value().find("edsServiceName"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:edsServiceName error:type should be string")); + } else { + eds_service_name = it->second.string_value(); + } + } + // LRS load reporting server name. 
+ absl::optional lrs_load_reporting_server_name; + it = json.object_value().find("lrsLoadReportingServerName"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:lrsLoadReportingServerName error:type should be string")); + } else { + lrs_load_reporting_server_name = it->second.string_value(); + } + } + // Max concurrent requests. + uint32_t max_concurrent_requests = 1024; + it = json.object_value().find("maxConcurrentRequests"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:max_concurrent_requests error:must be of type number")); + } else { + max_concurrent_requests = + gpr_parse_nonnegative_int(it->second.string_value().c_str()); + } + } + // Drop config. + auto drop_config = MakeRefCounted(); + it = json.object_value().find("dropCategories"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:dropCategories error:required field missing")); + } else { + std::vector child_errors = + ParseDropCategories(it->second, drop_config.get()); + if (!child_errors.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_VECTOR( + "field:dropCategories", &child_errors)); + } + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "xds_cluster_impl_experimental LB policy config", &error_list); + return nullptr; + } + return MakeRefCounted( + std::move(child_policy), std::move(cluster_name), + std::move(eds_service_name), std::move(lrs_load_reporting_server_name), + max_concurrent_requests, std::move(drop_config)); + } + + private: + static std::vector ParseDropCategories( + const Json& json, XdsApi::EdsUpdate::DropConfig* drop_config) { + std::vector error_list; + if (json.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "dropCategories field is not an array")); + return error_list; + } + for (size_t i = 0; i < json.array_value().size(); ++i) { + const Json& entry = json.array_value()[i]; + std::vector child_errors = + ParseDropCategory(entry, drop_config); + if (!child_errors.empty()) { + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("errors parsing index ", i)); + for (size_t i = 0; i < child_errors.size(); ++i) { + error = grpc_error_add_child(error, child_errors[i]); + } + error_list.push_back(error); + } + } + return error_list; + } + + static std::vector ParseDropCategory( + const Json& json, XdsApi::EdsUpdate::DropConfig* drop_config) { + std::vector error_list; + if (json.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "dropCategories entry is not an object")); + return error_list; + } + std::string category; + auto it = json.object_value().find("category"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"category\" field not present")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"category\" field is not a string")); + } else { + category = it->second.string_value(); + } + uint32_t requests_per_million = 0; + it = json.object_value().find("requests_per_million"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"requests_per_million\" field is not present")); + } else if 
(it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"requests_per_million\" field is not a number")); + } else { + requests_per_million = + gpr_parse_nonnegative_int(it->second.string_value().c_str()); + } + if (error_list.empty()) { + drop_config->AddCategory(std::move(category), requests_per_million); + } + return error_list; + } +}; + +} // namespace + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_lb_policy_xds_cluster_impl_init() { + grpc_core::g_call_counter_map = new grpc_core::CircuitBreakerCallCounterMap(); + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_xds_cluster_impl_shutdown() { + delete grpc_core::g_call_counter_map; +} diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc new file mode 100644 index 00000000..0455204b --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc @@ -0,0 +1,701 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +#include + +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/resolver/xds/xds_resolver.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/transport/error_utils.h" + +#define GRPC_XDS_CLUSTER_MANAGER_CHILD_RETENTION_INTERVAL_MS (15 * 60 * 1000) + +namespace grpc_core { + +TraceFlag grpc_xds_cluster_manager_lb_trace(false, "xds_cluster_manager_lb"); + +namespace { + +constexpr char kXdsClusterManager[] = "xds_cluster_manager_experimental"; + +// Config for xds_cluster_manager LB policy. +class XdsClusterManagerLbConfig : public LoadBalancingPolicy::Config { + public: + using ClusterMap = + std::map>; + + explicit XdsClusterManagerLbConfig(ClusterMap cluster_map) + : cluster_map_(std::move(cluster_map)) {} + + const char* name() const override { return kXdsClusterManager; } + + const ClusterMap& cluster_map() const { return cluster_map_; } + + private: + ClusterMap cluster_map_; +}; + +// xds_cluster_manager LB policy. 
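+//
+// For reference, the service config fragment that selects this policy has the
+// shape parsed by XdsClusterManagerLbFactory at the bottom of this file --
+// roughly:
+//
+//   "xds_cluster_manager_experimental": {
+//     "children": {
+//       "cluster_a": {
+//         "childPolicy": [ { "cds_experimental": { "cluster": "cluster_a" } } ]
+//       }
+//     }
+//   }
+//
+// The child name "cluster_a" and the cds_experimental child are illustrative
+// placeholders; any LB policy config known to the registry is accepted as the
+// childPolicy value.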
+class XdsClusterManagerLb : public LoadBalancingPolicy { + public: + explicit XdsClusterManagerLb(Args args); + + const char* name() const override { return kXdsClusterManager; } + + void UpdateLocked(UpdateArgs args) override; + void ExitIdleLocked() override; + void ResetBackoffLocked() override; + + private: + // A simple wrapper for ref-counting a picker from the child policy. + class ChildPickerWrapper : public RefCounted { + public: + ChildPickerWrapper(std::string name, + std::unique_ptr picker) + : name_(std::move(name)), picker_(std::move(picker)) {} + PickResult Pick(PickArgs args) { return picker_->Pick(args); } + + const std::string& name() const { return name_; } + + private: + std::string name_; + std::unique_ptr picker_; + }; + + // Picks a child using prefix or path matching and then delegates to that + // child's picker. + class ClusterPicker : public SubchannelPicker { + public: + // Maintains a map of cluster names to pickers. + using ClusterMap = std::map>; + + // It is required that the keys of cluster_map have to live at least as long + // as the ClusterPicker instance. + explicit ClusterPicker(ClusterMap cluster_map) + : cluster_map_(std::move(cluster_map)) {} + + PickResult Pick(PickArgs args) override; + + private: + ClusterMap cluster_map_; + }; + + // Each ClusterChild holds a ref to its parent XdsClusterManagerLb. + class ClusterChild : public InternallyRefCounted { + public: + ClusterChild(RefCountedPtr xds_cluster_manager_policy, + const std::string& name); + ~ClusterChild() override; + + void Orphan() override; + + void UpdateLocked(RefCountedPtr config, + const ServerAddressList& addresses, + const grpc_channel_args* args); + void ExitIdleLocked(); + void ResetBackoffLocked(); + void DeactivateLocked(); + + grpc_connectivity_state connectivity_state() const { + return connectivity_state_; + } + RefCountedPtr picker_wrapper() const { + return picker_wrapper_; + } + + private: + class Helper : public ChannelControlHelper { + public: + explicit Helper(RefCountedPtr xds_cluster_manager_child) + : xds_cluster_manager_child_(std::move(xds_cluster_manager_child)) {} + + ~Helper() override { + xds_cluster_manager_child_.reset(DEBUG_LOCATION, "Helper"); + } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, + const absl::Status& status, + std::unique_ptr picker) override; + void RequestReresolution() override; + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + RefCountedPtr xds_cluster_manager_child_; + }; + + // Methods for dealing with the child policy. + OrphanablePtr CreateChildPolicyLocked( + const grpc_channel_args* args); + + static void OnDelayedRemovalTimer(void* arg, grpc_error_handle error); + void OnDelayedRemovalTimerLocked(grpc_error_handle error); + + // The owning LB policy. + RefCountedPtr xds_cluster_manager_policy_; + + // Points to the corresponding key in children map. + const std::string name_; + + OrphanablePtr child_policy_; + + RefCountedPtr picker_wrapper_; + grpc_connectivity_state connectivity_state_ = GRPC_CHANNEL_IDLE; + bool seen_failure_since_ready_ = false; + + // States for delayed removal. 
+ grpc_timer delayed_removal_timer_; + grpc_closure on_delayed_removal_timer_; + bool delayed_removal_timer_callback_pending_ = false; + bool shutdown_ = false; + }; + + ~XdsClusterManagerLb() override; + + void ShutdownLocked() override; + + void UpdateStateLocked(); + + // Current config from the resolver. + RefCountedPtr config_; + + // Internal state. + bool shutting_down_ = false; + + // Children. + std::map> children_; +}; + +// +// XdsClusterManagerLb::ClusterPicker +// + +XdsClusterManagerLb::PickResult XdsClusterManagerLb::ClusterPicker::Pick( + PickArgs args) { + auto cluster_name = + args.call_state->ExperimentalGetCallAttribute(kXdsClusterAttribute); + auto it = cluster_map_.find(cluster_name); + if (it != cluster_map_.end()) { + return it->second->Pick(args); + } + return PickResult::Fail(absl::InternalError(absl::StrCat( + "xds cluster manager picker: unknown cluster \"", cluster_name, "\""))); +} + +// +// XdsClusterManagerLb +// + +XdsClusterManagerLb::XdsClusterManagerLb(Args args) + : LoadBalancingPolicy(std::move(args)) {} + +XdsClusterManagerLb::~XdsClusterManagerLb() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log( + GPR_INFO, + "[xds_cluster_manager_lb %p] destroying xds_cluster_manager LB policy", + this); + } +} + +void XdsClusterManagerLb::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_manager_lb %p] shutting down", this); + } + shutting_down_ = true; + children_.clear(); +} + +void XdsClusterManagerLb::ExitIdleLocked() { + for (auto& p : children_) p.second->ExitIdleLocked(); +} + +void XdsClusterManagerLb::ResetBackoffLocked() { + for (auto& p : children_) p.second->ResetBackoffLocked(); +} + +void XdsClusterManagerLb::UpdateLocked(UpdateArgs args) { + if (shutting_down_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_manager_lb %p] Received update", this); + } + // Update config. + config_ = std::move(args.config); + // Deactivate the children not in the new config. + for (const auto& p : children_) { + const std::string& name = p.first; + ClusterChild* child = p.second.get(); + if (config_->cluster_map().find(name) == config_->cluster_map().end()) { + child->DeactivateLocked(); + } + } + // Add or update the children in the new config. + for (const auto& p : config_->cluster_map()) { + const std::string& name = p.first; + const RefCountedPtr& config = p.second; + auto it = children_.find(name); + if (it == children_.end()) { + it = children_ + .emplace(name, MakeOrphanable( + Ref(DEBUG_LOCATION, "ClusterChild"), name)) + .first; + } + it->second->UpdateLocked(config, args.addresses, args.args); + } + UpdateStateLocked(); +} + +void XdsClusterManagerLb::UpdateStateLocked() { + // Also count the number of children in each state, to determine the + // overall state. + size_t num_ready = 0; + size_t num_connecting = 0; + size_t num_idle = 0; + size_t num_transient_failures = 0; + for (const auto& p : children_) { + const auto& child_name = p.first; + const ClusterChild* child = p.second.get(); + // Skip the children that are not in the latest update. 
+ if (config_->cluster_map().find(child_name) == + config_->cluster_map().end()) { + continue; + } + switch (child->connectivity_state()) { + case GRPC_CHANNEL_READY: { + ++num_ready; + break; + } + case GRPC_CHANNEL_CONNECTING: { + ++num_connecting; + break; + } + case GRPC_CHANNEL_IDLE: { + ++num_idle; + break; + } + case GRPC_CHANNEL_TRANSIENT_FAILURE: { + ++num_transient_failures; + break; + } + default: + GPR_UNREACHABLE_CODE(return ); + } + } + // Determine aggregated connectivity state. + grpc_connectivity_state connectivity_state; + if (num_ready > 0) { + connectivity_state = GRPC_CHANNEL_READY; + } else if (num_connecting > 0) { + connectivity_state = GRPC_CHANNEL_CONNECTING; + } else if (num_idle > 0) { + connectivity_state = GRPC_CHANNEL_IDLE; + } else { + connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_manager_lb %p] connectivity changed to %s", + this, ConnectivityStateName(connectivity_state)); + } + ClusterPicker::ClusterMap cluster_map; + for (const auto& p : config_->cluster_map()) { + const std::string& cluster_name = p.first; + RefCountedPtr& child_picker = cluster_map[cluster_name]; + child_picker = children_[cluster_name]->picker_wrapper(); + if (child_picker == nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_manager_lb %p] child %s has not yet returned a " + "picker; creating a QueuePicker.", + this, cluster_name.c_str()); + } + child_picker = MakeRefCounted( + cluster_name, + absl::make_unique(Ref(DEBUG_LOCATION, "QueuePicker"))); + } + } + std::unique_ptr picker = + absl::make_unique(std::move(cluster_map)); + absl::Status status; + if (connectivity_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + status = absl::Status(absl::StatusCode::kUnavailable, + "TRANSIENT_FAILURE from XdsClusterManagerLb"); + } + channel_control_helper()->UpdateState(connectivity_state, status, + std::move(picker)); +} + +// +// XdsClusterManagerLb::ClusterChild +// + +XdsClusterManagerLb::ClusterChild::ClusterChild( + RefCountedPtr xds_cluster_manager_policy, + const std::string& name) + : xds_cluster_manager_policy_(std::move(xds_cluster_manager_policy)), + name_(name) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_manager_lb %p] created ClusterChild %p for %s", + xds_cluster_manager_policy_.get(), this, name_.c_str()); + } + GRPC_CLOSURE_INIT(&on_delayed_removal_timer_, OnDelayedRemovalTimer, this, + grpc_schedule_on_exec_ctx); +} + +XdsClusterManagerLb::ClusterChild::~ClusterChild() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_manager_lb %p] ClusterChild %p: destroying " + "child", + xds_cluster_manager_policy_.get(), this); + } + xds_cluster_manager_policy_.reset(DEBUG_LOCATION, "ClusterChild"); +} + +void XdsClusterManagerLb::ClusterChild::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_manager_lb %p] ClusterChild %p %s: " + "shutting down child", + xds_cluster_manager_policy_.get(), this, name_.c_str()); + } + // Remove the child policy's interested_parties pollset_set from the + // xDS policy. + grpc_pollset_set_del_pollset_set( + child_policy_->interested_parties(), + xds_cluster_manager_policy_->interested_parties()); + child_policy_.reset(); + // Drop our ref to the child's picker, in case it's holding a ref to + // the child. 
+ picker_wrapper_.reset(); + if (delayed_removal_timer_callback_pending_) { + grpc_timer_cancel(&delayed_removal_timer_); + } + shutdown_ = true; + Unref(); +} + +OrphanablePtr +XdsClusterManagerLb::ClusterChild::CreateChildPolicyLocked( + const grpc_channel_args* args) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = + xds_cluster_manager_policy_->work_serializer(); + lb_policy_args.args = args; + lb_policy_args.channel_control_helper = + absl::make_unique(this->Ref(DEBUG_LOCATION, "Helper")); + OrphanablePtr lb_policy = + MakeOrphanable(std::move(lb_policy_args), + &grpc_xds_cluster_manager_lb_trace); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_manager_lb %p] ClusterChild %p %s: Created " + "new child " + "policy handler %p", + xds_cluster_manager_policy_.get(), this, name_.c_str(), + lb_policy.get()); + } + // Add the xDS's interested_parties pollset_set to that of the newly created + // child policy. This will make the child policy progress upon activity on + // xDS LB, which in turn is tied to the application's call. + grpc_pollset_set_add_pollset_set( + lb_policy->interested_parties(), + xds_cluster_manager_policy_->interested_parties()); + return lb_policy; +} + +void XdsClusterManagerLb::ClusterChild::UpdateLocked( + RefCountedPtr config, + const ServerAddressList& addresses, const grpc_channel_args* args) { + if (xds_cluster_manager_policy_->shutting_down_) return; + // Update child weight. + // Reactivate if needed. + if (delayed_removal_timer_callback_pending_) { + delayed_removal_timer_callback_pending_ = false; + grpc_timer_cancel(&delayed_removal_timer_); + } + // Create child policy if needed. + if (child_policy_ == nullptr) { + child_policy_ = CreateChildPolicyLocked(args); + } + // Construct update args. + UpdateArgs update_args; + update_args.config = std::move(config); + update_args.addresses = addresses; + update_args.args = grpc_channel_args_copy(args); + // Update the policy. + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_manager_lb %p] ClusterChild %p %s: " + "Updating child " + "policy handler %p", + xds_cluster_manager_policy_.get(), this, name_.c_str(), + child_policy_.get()); + } + child_policy_->UpdateLocked(std::move(update_args)); +} + +void XdsClusterManagerLb::ClusterChild::ExitIdleLocked() { + child_policy_->ExitIdleLocked(); +} + +void XdsClusterManagerLb::ClusterChild::ResetBackoffLocked() { + child_policy_->ResetBackoffLocked(); +} + +void XdsClusterManagerLb::ClusterChild::DeactivateLocked() { + // If already deactivated, don't do that again. + if (delayed_removal_timer_callback_pending_) return; + // Set the child weight to 0 so that future picker won't contain this child. + // Start a timer to delete the child. 
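+  // (The delay is GRPC_XDS_CLUSTER_MANAGER_CHILD_RETENTION_INTERVAL_MS, i.e.
+  // 15 minutes, so a cluster that briefly drops out of the config and then
+  // reappears gets its existing child policy back instead of a cold restart;
+  // UpdateLocked() above cancels this timer when the child is reused.)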
+ Ref(DEBUG_LOCATION, "ClusterChild+timer").release(); + grpc_timer_init(&delayed_removal_timer_, + ExecCtx::Get()->Now() + + GRPC_XDS_CLUSTER_MANAGER_CHILD_RETENTION_INTERVAL_MS, + &on_delayed_removal_timer_); + delayed_removal_timer_callback_pending_ = true; +} + +void XdsClusterManagerLb::ClusterChild::OnDelayedRemovalTimer( + void* arg, grpc_error_handle error) { + ClusterChild* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); // Ref owned by the lambda + self->xds_cluster_manager_policy_->work_serializer()->Run( + [self, error]() { self->OnDelayedRemovalTimerLocked(error); }, + DEBUG_LOCATION); +} + +void XdsClusterManagerLb::ClusterChild::OnDelayedRemovalTimerLocked( + grpc_error_handle error) { + delayed_removal_timer_callback_pending_ = false; + if (error == GRPC_ERROR_NONE && !shutdown_) { + xds_cluster_manager_policy_->children_.erase(name_); + } + Unref(DEBUG_LOCATION, "ClusterChild+timer"); + GRPC_ERROR_UNREF(error); +} + +// +// XdsClusterManagerLb::ClusterChild::Helper +// + +RefCountedPtr +XdsClusterManagerLb::ClusterChild::Helper::CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) { + if (xds_cluster_manager_child_->xds_cluster_manager_policy_->shutting_down_) { + return nullptr; + } + return xds_cluster_manager_child_->xds_cluster_manager_policy_ + ->channel_control_helper() + ->CreateSubchannel(std::move(address), args); +} + +void XdsClusterManagerLb::ClusterChild::Helper::UpdateState( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_cluster_manager_lb_trace)) { + gpr_log( + GPR_INFO, + "[xds_cluster_manager_lb %p] child %s: received update: state=%s (%s) " + "picker=%p", + xds_cluster_manager_child_->xds_cluster_manager_policy_.get(), + xds_cluster_manager_child_->name_.c_str(), ConnectivityStateName(state), + status.ToString().c_str(), picker.get()); + } + if (xds_cluster_manager_child_->xds_cluster_manager_policy_->shutting_down_) { + return; + } + // Cache the picker in the ClusterChild. + xds_cluster_manager_child_->picker_wrapper_ = + MakeRefCounted(xds_cluster_manager_child_->name_, + std::move(picker)); + // Decide what state to report for aggregation purposes. + // If we haven't seen a failure since the last time we were in state + // READY, then we report the state change as-is. However, once we do see + // a failure, we report TRANSIENT_FAILURE and ignore any subsequent state + // changes until we go back into state READY. + if (!xds_cluster_manager_child_->seen_failure_since_ready_) { + if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + xds_cluster_manager_child_->seen_failure_since_ready_ = true; + } + } else { + if (state != GRPC_CHANNEL_READY) return; + xds_cluster_manager_child_->seen_failure_since_ready_ = false; + } + xds_cluster_manager_child_->connectivity_state_ = state; + // Notify the LB policy. 
+ xds_cluster_manager_child_->xds_cluster_manager_policy_->UpdateStateLocked(); +} + +void XdsClusterManagerLb::ClusterChild::Helper::RequestReresolution() { + if (xds_cluster_manager_child_->xds_cluster_manager_policy_->shutting_down_) { + return; + } + xds_cluster_manager_child_->xds_cluster_manager_policy_ + ->channel_control_helper() + ->RequestReresolution(); +} + +absl::string_view XdsClusterManagerLb::ClusterChild::Helper::GetAuthority() { + return xds_cluster_manager_child_->xds_cluster_manager_policy_ + ->channel_control_helper() + ->GetAuthority(); +} + +void XdsClusterManagerLb::ClusterChild::Helper::AddTraceEvent( + TraceSeverity severity, absl::string_view message) { + if (xds_cluster_manager_child_->xds_cluster_manager_policy_->shutting_down_) { + return; + } + xds_cluster_manager_child_->xds_cluster_manager_policy_ + ->channel_control_helper() + ->AddTraceEvent(severity, message); +} + +// +// factory +// + +class XdsClusterManagerLbFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kXdsClusterManager; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (json.type() == Json::Type::JSON_NULL) { + // xds_cluster_manager was mentioned as a policy in the deprecated + // loadBalancingPolicy field or in the client API. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:xds_cluster_manager policy requires " + "configuration. Please use loadBalancingConfig field of service " + "config instead."); + return nullptr; + } + std::vector error_list; + XdsClusterManagerLbConfig::ClusterMap cluster_map; + std::set clusters_to_be_used; + auto it = json.object_value().find("children"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:children error:required field not present")); + } else if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:children error:type should be object")); + } else { + for (const auto& p : it->second.object_value()) { + const std::string& child_name = p.first; + if (child_name.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:children element error: name cannot be empty")); + continue; + } + RefCountedPtr child_config; + std::vector child_errors = + ParseChildConfig(p.second, &child_config); + if (!child_errors.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("field:children name:", child_name), &child_errors)); + } else { + cluster_map[child_name] = std::move(child_config); + clusters_to_be_used.insert(child_name); + } + } + } + if (cluster_map.empty()) { + error_list.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("no valid children configured")); + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "xds_cluster_manager_experimental LB policy config", &error_list); + return nullptr; + } + return MakeRefCounted(std::move(cluster_map)); + } + + private: + static std::vector ParseChildConfig( + const Json& json, + RefCountedPtr* child_config) { + std::vector error_list; + if (json.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "value should be of type 
object")); + return error_list; + } + auto it = json.object_value().find("childPolicy"); + if (it == json.object_value().end()) { + error_list.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("did not find childPolicy")); + } else { + grpc_error_handle parse_error = GRPC_ERROR_NONE; + *child_config = LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + it->second, &parse_error); + if (*child_config == nullptr) { + GPR_DEBUG_ASSERT(parse_error != GRPC_ERROR_NONE); + std::vector child_errors; + child_errors.push_back(parse_error); + error_list.push_back( + GRPC_ERROR_CREATE_FROM_VECTOR("field:childPolicy", &child_errors)); + } + } + return error_list; + } +}; + +} // namespace + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_lb_policy_xds_cluster_manager_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_xds_cluster_manager_shutdown() {} diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc new file mode 100644 index 00000000..8ca5f065 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc @@ -0,0 +1,1362 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" + +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy/address_filtering.h" +#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h" +#include "src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h" +#include "src/core/ext/filters/client_channel/lb_policy/xds/xds.h" +#include "src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h" +#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/xds/xds_channel_args.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/ext/xds/xds_client_stats.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/uri/uri_parser.h" + +#define GRPC_EDS_DEFAULT_FALLBACK_TIMEOUT 10000 + +namespace grpc_core { + +TraceFlag grpc_lb_xds_cluster_resolver_trace(false, "xds_cluster_resolver_lb"); + +const char* kXdsLocalityNameAttributeKey = "xds_locality_name"; + +namespace { + +constexpr char kXdsClusterResolver[] = "xds_cluster_resolver_experimental"; + +// Config for EDS LB policy. +class XdsClusterResolverLbConfig : public LoadBalancingPolicy::Config { + public: + struct DiscoveryMechanism { + std::string cluster_name; + absl::optional lrs_load_reporting_server_name; + uint32_t max_concurrent_requests; + enum DiscoveryMechanismType { + EDS, + LOGICAL_DNS, + }; + DiscoveryMechanismType type; + std::string eds_service_name; + std::string dns_hostname; + + bool operator==(const DiscoveryMechanism& other) const { + return (cluster_name == other.cluster_name && + lrs_load_reporting_server_name == + other.lrs_load_reporting_server_name && + max_concurrent_requests == other.max_concurrent_requests && + type == other.type && + eds_service_name == other.eds_service_name && + dns_hostname == other.dns_hostname); + } + }; + + XdsClusterResolverLbConfig( + std::vector discovery_mechanisms, Json xds_lb_policy) + : discovery_mechanisms_(std::move(discovery_mechanisms)), + xds_lb_policy_(std::move(xds_lb_policy)) {} + + const char* name() const override { return kXdsClusterResolver; } + const std::vector& discovery_mechanisms() const { + return discovery_mechanisms_; + } + + const Json& xds_lb_policy() const { return xds_lb_policy_; } + + private: + std::vector discovery_mechanisms_; + Json xds_lb_policy_; +}; + +// Xds Cluster Resolver LB policy. +class XdsClusterResolverLb : public LoadBalancingPolicy { + public: + XdsClusterResolverLb(RefCountedPtr xds_client, Args args, + std::string server_name, bool is_xds_uri); + + const char* name() const override { return kXdsClusterResolver; } + + void UpdateLocked(UpdateArgs args) override; + void ResetBackoffLocked() override; + void ExitIdleLocked() override; + + private: + // Discovery Mechanism Base class + // + // Implemented by EDS and LOGICAL_DNS. 
+ // + // Implementations are responsible for calling the LB policy's + // OnEndpointChanged(), OnError(), and OnResourceDoesNotExist() + // methods when the corresponding events occur. + // + // Must implement Orphan() method to cancel the watchers. + class DiscoveryMechanism : public InternallyRefCounted { + public: + DiscoveryMechanism( + RefCountedPtr xds_cluster_resolver_lb, + size_t index) + : parent_(std::move(xds_cluster_resolver_lb)), index_(index) {} + virtual void Start() = 0; + void Orphan() override = 0; + virtual Json::Array override_child_policy() = 0; + virtual bool disable_reresolution() = 0; + + // Returns a pair containing the cluster and eds_service_name + // to use for LRS load reporting. Caller must ensure that config_ is set + // before calling. + std::pair GetLrsClusterKey() const { + if (!parent_->is_xds_uri_) return {parent_->server_name_, nullptr}; + return { + parent_->config_->discovery_mechanisms()[index_].cluster_name, + parent_->config_->discovery_mechanisms()[index_].eds_service_name}; + } + + protected: + XdsClusterResolverLb* parent() const { return parent_.get(); } + size_t index() const { return index_; } + + private: + RefCountedPtr parent_; + // Stores its own index in the vector of DiscoveryMechanism. + size_t index_; + }; + + class EdsDiscoveryMechanism : public DiscoveryMechanism { + public: + EdsDiscoveryMechanism( + RefCountedPtr xds_cluster_resolver_lb, + size_t index) + : DiscoveryMechanism(std::move(xds_cluster_resolver_lb), index) {} + void Start() override; + void Orphan() override; + Json::Array override_child_policy() override { return Json::Array{}; } + bool disable_reresolution() override { return true; } + + private: + class EndpointWatcher : public XdsClient::EndpointWatcherInterface { + public: + explicit EndpointWatcher( + RefCountedPtr discovery_mechanism) + : discovery_mechanism_(std::move(discovery_mechanism)) {} + ~EndpointWatcher() override { + discovery_mechanism_.reset(DEBUG_LOCATION, "EndpointWatcher"); + } + void OnEndpointChanged(XdsApi::EdsUpdate update) override { + new Notifier(discovery_mechanism_, std::move(update)); + } + void OnError(grpc_error_handle error) override { + new Notifier(discovery_mechanism_, error); + } + void OnResourceDoesNotExist() override { + new Notifier(discovery_mechanism_); + } + + private: + class Notifier { + public: + Notifier(RefCountedPtr discovery_mechanism, + XdsApi::EdsUpdate update); + Notifier(RefCountedPtr discovery_mechanism, + grpc_error_handle error); + explicit Notifier( + RefCountedPtr discovery_mechanism); + ~Notifier() { discovery_mechanism_.reset(DEBUG_LOCATION, "Notifier"); } + + private: + enum Type { kUpdate, kError, kDoesNotExist }; + + static void RunInExecCtx(void* arg, grpc_error_handle error); + void RunInWorkSerializer(grpc_error_handle error); + + RefCountedPtr discovery_mechanism_; + grpc_closure closure_; + XdsApi::EdsUpdate update_; + Type type_; + }; + + RefCountedPtr discovery_mechanism_; + }; + + absl::string_view GetEdsResourceName() const { + if (!parent()->is_xds_uri_) return parent()->server_name_; + if (!parent() + ->config_->discovery_mechanisms()[index()] + .eds_service_name.empty()) { + return parent() + ->config_->discovery_mechanisms()[index()] + .eds_service_name; + } + return parent()->config_->discovery_mechanisms()[index()].cluster_name; + } + + // Note that this is not owned, so this pointer must never be dereferenced. 
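+    // (Ownership passes to the XdsClient when Start() hands the watcher to
+    // WatchEndpointData(); this raw pointer is retained only so that Orphan()
+    // can identify which watch to cancel via CancelEndpointDataWatch().)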
+ EndpointWatcher* watcher_ = nullptr; + }; + + class LogicalDNSDiscoveryMechanism : public DiscoveryMechanism { + public: + LogicalDNSDiscoveryMechanism( + RefCountedPtr xds_cluster_resolver_lb, + size_t index) + : DiscoveryMechanism(std::move(xds_cluster_resolver_lb), index) {} + void Start() override; + void Orphan() override; + Json::Array override_child_policy() override { + return Json::Array{ + Json::Object{ + {"pick_first", Json::Object()}, + }, + }; + } + bool disable_reresolution() override { return false; }; + + private: + class ResolverResultHandler : public Resolver::ResultHandler { + public: + explicit ResolverResultHandler( + RefCountedPtr discovery_mechanism) + : discovery_mechanism_(std::move(discovery_mechanism)) {} + + ~ResolverResultHandler() override {} + + void ReturnResult(Resolver::Result result) override; + + void ReturnError(grpc_error_handle error) override; + + private: + RefCountedPtr discovery_mechanism_; + }; + + // This is necessary only because of a bug in msvc where nested class cannot + // access protected member in base class. + friend class ResolverResultHandler; + + OrphanablePtr resolver_; + }; + + struct DiscoveryMechanismEntry { + OrphanablePtr discovery_mechanism; + bool first_update_received = false; + // Number of priorities this mechanism has contributed to priority_list_. + // (The sum of this across all discovery mechanisms should always equal + // the number of priorities in priority_list_.) + uint32_t num_priorities = 0; + RefCountedPtr drop_config; + // Populated only when an update has been delivered by the mechanism + // but has not yet been applied to the LB policy's combined priority_list_. + absl::optional pending_priority_list; + }; + + class Helper : public ChannelControlHelper { + public: + explicit Helper( + RefCountedPtr xds_cluster_resolver_policy) + : xds_cluster_resolver_policy_(std::move(xds_cluster_resolver_policy)) { + } + + ~Helper() override { + xds_cluster_resolver_policy_.reset(DEBUG_LOCATION, "Helper"); + } + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override; + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override; + // This is a no-op, because we get the addresses from the xds + // client, which is a watch-based API. + void RequestReresolution() override {} + absl::string_view GetAuthority() override; + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override; + + private: + RefCountedPtr xds_cluster_resolver_policy_; + }; + + ~XdsClusterResolverLb() override; + + void ShutdownLocked() override; + + void OnEndpointChanged(size_t index, XdsApi::EdsUpdate update); + void OnError(size_t index, grpc_error_handle error); + void OnResourceDoesNotExist(size_t index); + + void MaybeDestroyChildPolicyLocked(); + + void UpdatePriorityList(XdsApi::EdsUpdate::PriorityList priority_list); + void UpdateChildPolicyLocked(); + OrphanablePtr CreateChildPolicyLocked( + const grpc_channel_args* args); + ServerAddressList CreateChildPolicyAddressesLocked(); + RefCountedPtr CreateChildPolicyConfigLocked(); + grpc_channel_args* CreateChildPolicyArgsLocked( + const grpc_channel_args* args_in); + + // The xds client and endpoint watcher. + RefCountedPtr xds_client_; + + // Server name from target URI. + std::string server_name_; + bool is_xds_uri_; + + // Current channel args and config from the resolver. + const grpc_channel_args* args_ = nullptr; + RefCountedPtr config_; + + // Internal state. 
+ bool shutting_down_ = false; + + // Vector of discovery mechansism entries in priority order. + std::vector discovery_mechanisms_; + + // The latest data from the endpoint watcher. + XdsApi::EdsUpdate::PriorityList priority_list_; + // State used to retain child policy names for priority policy. + std::vector priority_child_numbers_; + + OrphanablePtr child_policy_; +}; + +// +// XdsClusterResolverLb::Helper +// + +RefCountedPtr +XdsClusterResolverLb::Helper::CreateSubchannel(ServerAddress address, + const grpc_channel_args& args) { + if (xds_cluster_resolver_policy_->shutting_down_) return nullptr; + return xds_cluster_resolver_policy_->channel_control_helper() + ->CreateSubchannel(std::move(address), args); +} + +void XdsClusterResolverLb::Helper::UpdateState( + grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) { + if (xds_cluster_resolver_policy_->shutting_down_ || + xds_cluster_resolver_policy_->child_policy_ == nullptr) { + return; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p] child policy updated state=%s (%s) " + "picker=%p", + xds_cluster_resolver_policy_.get(), ConnectivityStateName(state), + status.ToString().c_str(), picker.get()); + } + xds_cluster_resolver_policy_->channel_control_helper()->UpdateState( + state, status, std::move(picker)); +} + +absl::string_view XdsClusterResolverLb::Helper::GetAuthority() { + return xds_cluster_resolver_policy_->channel_control_helper()->GetAuthority(); +} + +void XdsClusterResolverLb::Helper::AddTraceEvent(TraceSeverity severity, + absl::string_view message) { + if (xds_cluster_resolver_policy_->shutting_down_) return; + xds_cluster_resolver_policy_->channel_control_helper()->AddTraceEvent( + severity, message); +} + +// +// XdsClusterResolverLb::EdsDiscoveryMechanism +// + +void XdsClusterResolverLb::EdsDiscoveryMechanism::Start() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p] eds discovery mechanism %" PRIuPTR + ":%p starting xds watch for %s", + parent(), index(), this, std::string(GetEdsResourceName()).c_str()); + } + auto watcher = absl::make_unique( + Ref(DEBUG_LOCATION, "EdsDiscoveryMechanism")); + watcher_ = watcher.get(); + parent()->xds_client_->WatchEndpointData(GetEdsResourceName(), + std::move(watcher)); +} + +void XdsClusterResolverLb::EdsDiscoveryMechanism::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p] eds discovery mechanism %" PRIuPTR + ":%p cancelling xds watch for %s", + parent(), index(), this, std::string(GetEdsResourceName()).c_str()); + } + parent()->xds_client_->CancelEndpointDataWatch(GetEdsResourceName(), + watcher_); + Unref(); +} + +// +// XdsClusterResolverLb::EndpointWatcher::Notifier +// + +XdsClusterResolverLb::EdsDiscoveryMechanism::EndpointWatcher::Notifier:: + Notifier(RefCountedPtr + discovery_mechanism, + XdsApi::EdsUpdate update) + : discovery_mechanism_(std::move(discovery_mechanism)), + update_(std::move(update)), + type_(kUpdate) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); +} + +XdsClusterResolverLb::EdsDiscoveryMechanism::EndpointWatcher::Notifier:: + Notifier(RefCountedPtr + discovery_mechanism, + grpc_error_handle error) + : discovery_mechanism_(std::move(discovery_mechanism)), type_(kError) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, 
nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, error); +} + +XdsClusterResolverLb::EdsDiscoveryMechanism::EndpointWatcher::Notifier:: + Notifier(RefCountedPtr + discovery_mechanism) + : discovery_mechanism_(std::move(discovery_mechanism)), + type_(kDoesNotExist) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); +} + +void XdsClusterResolverLb::EdsDiscoveryMechanism::EndpointWatcher::Notifier:: + RunInExecCtx(void* arg, grpc_error_handle error) { + Notifier* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); + self->discovery_mechanism_->parent()->work_serializer()->Run( + [self, error]() { self->RunInWorkSerializer(error); }, DEBUG_LOCATION); +} + +void XdsClusterResolverLb::EdsDiscoveryMechanism::EndpointWatcher::Notifier:: + RunInWorkSerializer(grpc_error_handle error) { + switch (type_) { + case kUpdate: + discovery_mechanism_->parent()->OnEndpointChanged( + discovery_mechanism_->index(), std::move(update_)); + break; + case kError: + discovery_mechanism_->parent()->OnError(discovery_mechanism_->index(), + error); + break; + case kDoesNotExist: + discovery_mechanism_->parent()->OnResourceDoesNotExist( + discovery_mechanism_->index()); + break; + }; + delete this; +} + +// +// XdsClusterResolverLb::LogicalDNSDiscoveryMechanism +// + +void XdsClusterResolverLb::LogicalDNSDiscoveryMechanism::Start() { + std::string target = + parent()->config_->discovery_mechanisms()[index()].dns_hostname; + grpc_channel_args* args = nullptr; + FakeResolverResponseGenerator* fake_resolver_response_generator = + grpc_channel_args_find_pointer( + parent()->args_, + GRPC_ARG_XDS_LOGICAL_DNS_CLUSTER_FAKE_RESOLVER_RESPONSE_GENERATOR); + if (fake_resolver_response_generator != nullptr) { + target = absl::StrCat("fake:", target); + grpc_arg new_arg = FakeResolverResponseGenerator::MakeChannelArg( + fake_resolver_response_generator); + args = grpc_channel_args_copy_and_add(parent()->args_, &new_arg, 1); + } else { + target = absl::StrCat("dns:", target); + args = grpc_channel_args_copy(parent()->args_); + } + resolver_ = ResolverRegistry::CreateResolver( + target.c_str(), args, parent()->interested_parties(), + parent()->work_serializer(), + absl::make_unique( + Ref(DEBUG_LOCATION, "LogicalDNSDiscoveryMechanism"))); + grpc_channel_args_destroy(args); + if (resolver_ == nullptr) { + parent()->OnResourceDoesNotExist(index()); + return; + } + resolver_->StartLocked(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p] logical DNS discovery mechanism " + "%" PRIuPTR ":%p starting dns resolver %p", + parent(), index(), this, resolver_.get()); + } +} + +void XdsClusterResolverLb::LogicalDNSDiscoveryMechanism::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log( + GPR_INFO, + "[xds_cluster_resolver_lb %p] logical DNS discovery mechanism %" PRIuPTR + ":%p shutting down dns resolver %p", + parent(), index(), this, resolver_.get()); + } + resolver_.reset(); + Unref(); +} + +// +// XdsClusterResolverLb::LogicalDNSDiscoveryMechanism::ResolverResultHandler +// + +void XdsClusterResolverLb::LogicalDNSDiscoveryMechanism::ResolverResultHandler:: + ReturnResult(Resolver::Result result) { + // convert result to eds update + XdsApi::EdsUpdate update; + XdsApi::EdsUpdate::Priority::Locality locality; + locality.name = MakeRefCounted("", "", ""); + locality.lb_weight = 1; + locality.endpoints = std::move(result.addresses); + 
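+  // (The DNS addresses are repackaged as a synthetic EDS update: a single
+  // priority holding one locality with an empty name ("", "", "") and
+  // lb_weight 1, so the rest of the policy can handle LOGICAL_DNS endpoints
+  // the same way as EDS endpoints.)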
XdsApi::EdsUpdate::Priority priority; + priority.localities.emplace(locality.name.get(), std::move(locality)); + update.priorities.emplace_back(std::move(priority)); + discovery_mechanism_->parent()->OnEndpointChanged( + discovery_mechanism_->index(), std::move(update)); +} + +void XdsClusterResolverLb::LogicalDNSDiscoveryMechanism::ResolverResultHandler:: + ReturnError(grpc_error_handle error) { + discovery_mechanism_->parent()->OnError(discovery_mechanism_->index(), error); +} + +// +// XdsClusterResolverLb public methods +// + +XdsClusterResolverLb::XdsClusterResolverLb(RefCountedPtr xds_client, + Args args, std::string server_name, + bool is_xds_uri) + : LoadBalancingPolicy(std::move(args)), + xds_client_(std::move(xds_client)), + server_name_(std::move(server_name)), + is_xds_uri_(is_xds_uri) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p] created -- xds_client=%p, " + "server_name=%s, is_xds_uri=%d", + this, xds_client_.get(), server_name_.c_str(), is_xds_uri_); + } + // EDS-only flow. + if (!is_xds_uri_) { + // Couple polling. + grpc_pollset_set_add_pollset_set(xds_client_->interested_parties(), + interested_parties()); + } +} + +XdsClusterResolverLb::~XdsClusterResolverLb() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p] destroying xds_cluster_resolver LB " + "policy", + this); + } +} + +void XdsClusterResolverLb::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_resolver_lb %p] shutting down", this); + } + shutting_down_ = true; + MaybeDestroyChildPolicyLocked(); + discovery_mechanisms_.clear(); + if (!is_xds_uri_) { + // Decouple polling. + grpc_pollset_set_del_pollset_set(xds_client_->interested_parties(), + interested_parties()); + } + xds_client_.reset(DEBUG_LOCATION, "XdsClusterResolverLb"); + // Destroy channel args. + grpc_channel_args_destroy(args_); + args_ = nullptr; +} + +void XdsClusterResolverLb::MaybeDestroyChildPolicyLocked() { + if (child_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(), + interested_parties()); + child_policy_.reset(); + } +} + +void XdsClusterResolverLb::UpdateLocked(UpdateArgs args) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_resolver_lb %p] Received update", this); + } + const bool is_initial_update = args_ == nullptr; + // Update config. + auto old_config = std::move(config_); + config_ = std::move(args.config); + // Update args. + grpc_channel_args_destroy(args_); + args_ = args.args; + args.args = nullptr; + // Update child policy if needed. + if (child_policy_ != nullptr) UpdateChildPolicyLocked(); + // Create endpoint watcher if needed. 
+ if (is_initial_update) { + for (const auto& config : config_->discovery_mechanisms()) { + DiscoveryMechanismEntry entry; + if (config.type == XdsClusterResolverLbConfig::DiscoveryMechanism:: + DiscoveryMechanismType::EDS) { + entry.discovery_mechanism = + grpc_core::MakeOrphanable( + Ref(DEBUG_LOCATION, "EdsDiscoveryMechanism"), + discovery_mechanisms_.size()); + } else if (config.type == XdsClusterResolverLbConfig::DiscoveryMechanism:: + DiscoveryMechanismType::LOGICAL_DNS) { + entry.discovery_mechanism = + grpc_core::MakeOrphanable( + Ref(DEBUG_LOCATION, "LogicalDNSDiscoveryMechanism"), + discovery_mechanisms_.size()); + } else { + GPR_ASSERT(0); + } + discovery_mechanisms_.push_back(std::move(entry)); + } + // Call start() on all discovery mechanisms after creation. + for (const auto& discovery_mechanism : discovery_mechanisms_) { + discovery_mechanism.discovery_mechanism->Start(); + } + } +} + +void XdsClusterResolverLb::ResetBackoffLocked() { + // When the XdsClient is instantiated in the resolver instead of in this + // LB policy, this is done via the resolver, so we don't need to do it here. + if (!is_xds_uri_ && xds_client_ != nullptr) xds_client_->ResetBackoff(); + if (child_policy_ != nullptr) { + child_policy_->ResetBackoffLocked(); + } +} + +void XdsClusterResolverLb::ExitIdleLocked() { + if (child_policy_ != nullptr) child_policy_->ExitIdleLocked(); +} + +void XdsClusterResolverLb::OnEndpointChanged(size_t index, + XdsApi::EdsUpdate update) { + if (shutting_down_) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p] Received update from xds client" + " for discovery mechanism %" PRIuPTR "", + this, index); + } + // We need at least one priority for each discovery mechanism, just so that we + // have a child in which to create the xds_cluster_impl policy. This ensures + // that we properly handle the case of a discovery mechanism dropping 100% of + // calls, the OnError() case, and the OnResourceDoesNotExist() case. + if (update.priorities.empty()) update.priorities.emplace_back(); + discovery_mechanisms_[index].drop_config = std::move(update.drop_config); + discovery_mechanisms_[index].pending_priority_list = + std::move(update.priorities); + discovery_mechanisms_[index].first_update_received = true; + // If any discovery mechanism has not received its first update, + // wait until that happens before creating the child policy. + // TODO(roth): If this becomes problematic in the future (e.g., a + // secondary discovery mechanism delaying us from starting up at all), + // we can consider some sort of optimization whereby we can create the + // priority policy with only a subset of its children. But we need to + // make sure not to get into a situation where the priority policy + // will put the channel into TRANSIENT_FAILURE instead of CONNECTING + // while we're still waiting for the other discovery mechanism(s). + for (DiscoveryMechanismEntry& mechanism : discovery_mechanisms_) { + if (!mechanism.first_update_received) return; + } + // Construct new priority list. + XdsApi::EdsUpdate::PriorityList priority_list; + size_t priority_index = 0; + for (DiscoveryMechanismEntry& mechanism : discovery_mechanisms_) { + // If the mechanism has a pending update, use that. + // Otherwise, use the priorities that it previously contributed to the + // combined list. 
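+    // (For example, if mechanism A last contributed 2 priorities and
+    // mechanism B contributed 1, the combined list is [A0, A1, B0]; when only
+    // B has a pending update, A's two priorities are copied unchanged from the
+    // old priority_list_ and B's new priorities are appended after them.)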
+ if (mechanism.pending_priority_list.has_value()) { + priority_list.insert(priority_list.end(), + mechanism.pending_priority_list->begin(), + mechanism.pending_priority_list->end()); + priority_index += mechanism.num_priorities; + mechanism.num_priorities = mechanism.pending_priority_list->size(); + mechanism.pending_priority_list.reset(); + } else { + priority_list.insert( + priority_list.end(), priority_list_.begin() + priority_index, + priority_list_.begin() + priority_index + mechanism.num_priorities); + priority_index += mechanism.num_priorities; + } + } + // Update child policy. + UpdatePriorityList(std::move(priority_list)); +} + +void XdsClusterResolverLb::OnError(size_t index, grpc_error_handle error) { + gpr_log(GPR_ERROR, + "[xds_cluster_resolver_lb %p] discovery mechanism %" PRIuPTR + " xds watcher reported error: %s", + this, index, grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + if (shutting_down_) return; + if (!discovery_mechanisms_[index].first_update_received) { + // Call OnEndpointChanged with an empty update just like + // OnResourceDoesNotExist. + OnEndpointChanged(index, XdsApi::EdsUpdate()); + } +} + +void XdsClusterResolverLb::OnResourceDoesNotExist(size_t index) { + gpr_log(GPR_ERROR, + "[xds_cluster_resolver_lb %p] discovery mechanism %" PRIuPTR + " resource does not exist", + this, index); + if (shutting_down_) return; + // Call OnEndpointChanged with an empty update. + OnEndpointChanged(index, XdsApi::EdsUpdate()); +} + +// +// child policy-related methods +// + +void XdsClusterResolverLb::UpdatePriorityList( + XdsApi::EdsUpdate::PriorityList priority_list) { + // Build some maps from locality to child number and the reverse from + // the old data in priority_list_ and priority_child_numbers_. + std::map + locality_child_map; + std::map> child_locality_map; + for (size_t priority = 0; priority < priority_list_.size(); ++priority) { + size_t child_number = priority_child_numbers_[priority]; + const auto& localities = priority_list_[priority].localities; + for (const auto& p : localities) { + XdsLocalityName* locality_name = p.first; + locality_child_map[locality_name] = child_number; + child_locality_map[child_number].insert(locality_name); + } + } + // Construct new list of children. + std::vector priority_child_numbers; + for (size_t priority = 0; priority < priority_list.size(); ++priority) { + const auto& localities = priority_list[priority].localities; + absl::optional child_number; + // If one of the localities in this priority already existed, reuse its + // child number. + for (const auto& p : localities) { + XdsLocalityName* locality_name = p.first; + if (!child_number.has_value()) { + auto it = locality_child_map.find(locality_name); + if (it != locality_child_map.end()) { + child_number = it->second; + locality_child_map.erase(it); + // Remove localities that *used* to be in this child number, so + // that we don't incorrectly reuse this child number for a + // subsequent priority. + for (XdsLocalityName* old_locality : + child_locality_map[*child_number]) { + locality_child_map.erase(old_locality); + } + } + } else { + // Remove all localities that are now in this child number, so + // that we don't accidentally reuse this child number for a + // subsequent priority. + locality_child_map.erase(locality_name); + } + } + // If we didn't find an existing child number, assign a new one. 
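// For illustration (made-up localities): if the old update had priority 0 =
// {locality A} -> child0 and priority 1 = {locality B} -> child1, and the new
// update swaps the two localities, the lookup above reuses child1 for the new
// priority 0 and child0 for the new priority 1. Each locality keeps its child
// name, so the priority policy can carry over the existing child rather than
// creating a fresh one.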
+ if (!child_number.has_value()) { + for (child_number = 0; + child_locality_map.find(*child_number) != child_locality_map.end(); + ++(*child_number)) { + } + // Add entry so we know that the child number is in use. + // (Don't need to add the list of localities, since we won't use them.) + child_locality_map[*child_number]; + } + priority_child_numbers.push_back(*child_number); + } + // Save update. + priority_list_ = std::move(priority_list); + priority_child_numbers_ = std::move(priority_child_numbers); + // Update child policy. + UpdateChildPolicyLocked(); +} + +ServerAddressList XdsClusterResolverLb::CreateChildPolicyAddressesLocked() { + ServerAddressList addresses; + for (size_t priority = 0; priority < priority_list_.size(); ++priority) { + const auto& localities = priority_list_[priority].localities; + std::string priority_child_name = + absl::StrCat("child", priority_child_numbers_[priority]); + for (const auto& p : localities) { + const auto& locality_name = p.first; + const auto& locality = p.second; + std::vector hierarchical_path = { + priority_child_name, locality_name->AsHumanReadableString()}; + for (const auto& endpoint : locality.endpoints) { + const ServerAddressWeightAttribute* weight_attribute = static_cast< + const ServerAddressWeightAttribute*>(endpoint.GetAttribute( + ServerAddressWeightAttribute::kServerAddressWeightAttributeKey)); + uint32_t weight = locality.lb_weight; + if (weight_attribute != nullptr) { + weight = locality.lb_weight * weight_attribute->weight(); + } + addresses.emplace_back( + endpoint + .WithAttribute(kHierarchicalPathAttributeKey, + MakeHierarchicalPathAttribute(hierarchical_path)) + .WithAttribute(kXdsLocalityNameAttributeKey, + absl::make_unique( + locality_name->Ref())) + .WithAttribute( + ServerAddressWeightAttribute:: + kServerAddressWeightAttributeKey, + absl::make_unique(weight))); + } + } + } + return addresses; +} + +RefCountedPtr +XdsClusterResolverLb::CreateChildPolicyConfigLocked() { + Json::Object priority_children; + Json::Array priority_priorities; + // Setting up index to iterate through the discovery mechanisms and keeping + // track the discovery_mechanism each priority belongs to. + size_t discovery_index = 0; + // Setting up num_priorities_remaining to track the priorities in each + // discovery_mechanism. + size_t num_priorities_remaining_in_discovery = + discovery_mechanisms_[discovery_index].num_priorities; + for (size_t priority = 0; priority < priority_list_.size(); ++priority) { + Json child_policy; + if (!discovery_mechanisms_[discovery_index] + .discovery_mechanism->override_child_policy() + .empty()) { + child_policy = discovery_mechanisms_[discovery_index] + .discovery_mechanism->override_child_policy(); + } else { + const auto& xds_lb_policy = config_->xds_lb_policy().object_value(); + if (xds_lb_policy.find("ROUND_ROBIN") != xds_lb_policy.end()) { + const auto& localities = priority_list_[priority].localities; + Json::Object weighted_targets; + for (const auto& p : localities) { + XdsLocalityName* locality_name = p.first; + const auto& locality = p.second; + // Construct JSON object containing locality name. + Json::Object locality_name_json; + if (!locality_name->region().empty()) { + locality_name_json["region"] = locality_name->region(); + } + if (!locality_name->zone().empty()) { + locality_name_json["zone"] = locality_name->zone(); + } + if (!locality_name->sub_zone().empty()) { + locality_name_json["sub_zone"] = locality_name->sub_zone(); + } + // Add weighted target entry. 
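// For illustration, one such entry, keyed by whatever
// locality_name->AsHumanReadableString() returns, looks roughly like this
// (weight made up):
//   "<locality name>": {
//     "weight": 3,
//     "childPolicy": [ { "round_robin": {} } ]
//   }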
+ weighted_targets[locality_name->AsHumanReadableString()] = + Json::Object{ + {"weight", locality.lb_weight}, + {"childPolicy", + Json::Array{ + Json::Object{ + {"round_robin", Json::Object()}, + }, + }}, + }; + } + // Construct locality-picking policy. + // Start with field from our config and add the "targets" field. + child_policy = Json::Array{ + Json::Object{ + {"weighted_target_experimental", + Json::Object{ + {"targets", Json::Object()}, + }}, + }, + }; + Json::Object& config = + *(*child_policy.mutable_array())[0].mutable_object(); + auto it = config.begin(); + GPR_ASSERT(it != config.end()); + (*it->second.mutable_object())["targets"] = std::move(weighted_targets); + } else { + auto it = xds_lb_policy.find("RING_HASH"); + GPR_ASSERT(it != xds_lb_policy.end()); + Json::Object ring_hash_experimental_policy = it->second.object_value(); + child_policy = Json::Array{ + Json::Object{ + {"ring_hash_experimental", ring_hash_experimental_policy}, + }, + }; + } + } + // Wrap it in the drop policy. + Json::Array drop_categories; + if (discovery_mechanisms_[discovery_index].drop_config != nullptr) { + for (const auto& category : discovery_mechanisms_[discovery_index] + .drop_config->drop_category_list()) { + drop_categories.push_back(Json::Object{ + {"category", category.name}, + {"requests_per_million", category.parts_per_million}, + }); + } + } + const auto lrs_key = discovery_mechanisms_[discovery_index] + .discovery_mechanism->GetLrsClusterKey(); + Json::Object xds_cluster_impl_config = { + {"clusterName", std::string(lrs_key.first)}, + {"childPolicy", std::move(child_policy)}, + {"dropCategories", std::move(drop_categories)}, + {"maxConcurrentRequests", + config_->discovery_mechanisms()[discovery_index] + .max_concurrent_requests}, + }; + if (!lrs_key.second.empty()) { + xds_cluster_impl_config["edsServiceName"] = std::string(lrs_key.second); + } + if (config_->discovery_mechanisms()[discovery_index] + .lrs_load_reporting_server_name.has_value()) { + xds_cluster_impl_config["lrsLoadReportingServerName"] = + config_->discovery_mechanisms()[discovery_index] + .lrs_load_reporting_server_name.value(); + } + Json locality_picking_policy = Json::Array{Json::Object{ + {"xds_cluster_impl_experimental", std::move(xds_cluster_impl_config)}, + }}; + // Add priority entry. + const size_t child_number = priority_child_numbers_[priority]; + std::string child_name = absl::StrCat("child", child_number); + priority_priorities.emplace_back(child_name); + Json::Object child_config = { + {"config", std::move(locality_picking_policy)}, + }; + if (discovery_mechanisms_[discovery_index] + .discovery_mechanism->disable_reresolution()) { + child_config["ignore_reresolution_requests"] = true; + } + priority_children[child_name] = std::move(child_config); + // Each priority in the priority_list_ should correspond to a priority in a + // discovery mechanism in discovery_mechanisms_ (both in the same order). + // Keeping track of the discovery_mechanism each priority belongs to. 
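// For illustration, the config this function ultimately returns has roughly
// the following shape for a single EDS priority using ROUND_ROBIN (values
// made up, optional fields omitted):
//   [ { "priority_experimental": {
//         "children": {
//           "child0": {
//             "config": [ { "xds_cluster_impl_experimental": {
//               "clusterName": "cluster_a",
//               "childPolicy": [ { "weighted_target_experimental": {
//                   "targets": { "<locality name>": {
//                     "weight": 3,
//                     "childPolicy": [ { "round_robin": {} } ] } } } } ],
//               "dropCategories": [],
//               "maxConcurrentRequests": 1024 } } ] } },
//         "priorities": [ "child0" ] } } ]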
+ --num_priorities_remaining_in_discovery; + while (num_priorities_remaining_in_discovery == 0 && + discovery_index < discovery_mechanisms_.size() - 1) { + ++discovery_index; + num_priorities_remaining_in_discovery = + discovery_mechanisms_[discovery_index].num_priorities; + } + } + // There should be matching number of priorities in discovery_mechanisms_ and + // in priority_list_; therefore at the end of looping through all the + // priorities, num_priorities_remaining should be down to 0, and index should + // be the last index in discovery_mechanisms_. + GPR_ASSERT(num_priorities_remaining_in_discovery == 0); + GPR_ASSERT(discovery_index == discovery_mechanisms_.size() - 1); + Json json = Json::Array{Json::Object{ + {"priority_experimental", + Json::Object{ + {"children", std::move(priority_children)}, + {"priorities", std::move(priority_priorities)}, + }}, + }}; + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + std::string json_str = json.Dump(/*indent=*/1); + gpr_log( + GPR_INFO, + "[xds_cluster_resolver_lb %p] generated config for child policy: %s", + this, json_str.c_str()); + } + grpc_error_handle error = GRPC_ERROR_NONE; + RefCountedPtr config = + LoadBalancingPolicyRegistry::ParseLoadBalancingConfig(json, &error); + if (error != GRPC_ERROR_NONE) { + // This should never happen, but if it does, we basically have no + // way to fix it, so we put the channel in TRANSIENT_FAILURE. + gpr_log(GPR_ERROR, + "[xds_cluster_resolver_lb %p] error parsing generated child policy " + "config -- " + "will put channel in TRANSIENT_FAILURE: %s", + this, grpc_error_std_string(error).c_str()); + absl::Status status = absl::InternalError( + "xds_cluster_resolver LB policy: error parsing generated child policy " + "config"); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + return nullptr; + } + return config; +} + +void XdsClusterResolverLb::UpdateChildPolicyLocked() { + if (shutting_down_) return; + UpdateArgs update_args; + update_args.config = CreateChildPolicyConfigLocked(); + if (update_args.config == nullptr) return; + update_args.addresses = CreateChildPolicyAddressesLocked(); + update_args.args = CreateChildPolicyArgsLocked(args_); + if (child_policy_ == nullptr) { + child_policy_ = CreateChildPolicyLocked(update_args.args); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_cluster_resolver_lb %p] Updating child policy %p", + this, child_policy_.get()); + } + child_policy_->UpdateLocked(std::move(update_args)); +} + +grpc_channel_args* XdsClusterResolverLb::CreateChildPolicyArgsLocked( + const grpc_channel_args* args) { + absl::InlinedVector new_args = { + // Inhibit client-side health checking, since the balancer does this + // for us. 
+ grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_INHIBIT_HEALTH_CHECKING), 1), + }; + if (!is_xds_uri_) new_args.push_back(xds_client_->MakeChannelArg()); + return grpc_channel_args_copy_and_add(args, new_args.data(), new_args.size()); +} + +OrphanablePtr +XdsClusterResolverLb::CreateChildPolicyLocked(const grpc_channel_args* args) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.work_serializer = work_serializer(); + lb_policy_args.args = args; + lb_policy_args.channel_control_helper = + absl::make_unique(Ref(DEBUG_LOCATION, "Helper")); + OrphanablePtr lb_policy = + LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( + "priority_experimental", std::move(lb_policy_args)); + if (GPR_UNLIKELY(lb_policy == nullptr)) { + gpr_log(GPR_ERROR, + "[xds_cluster_resolver_lb %p] failure creating child policy", this); + return nullptr; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_cluster_resolver_trace)) { + gpr_log(GPR_INFO, + "[xds_cluster_resolver_lb %p]: Created new child policy %p", this, + lb_policy.get()); + } + // Add our interested_parties pollset_set to that of the newly created + // child policy. This will make the child policy progress upon activity on + // this policy, which in turn is tied to the application's call. + grpc_pollset_set_add_pollset_set(lb_policy->interested_parties(), + interested_parties()); + return lb_policy; +} + +// +// factory +// + +class XdsClusterResolverLbFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + // Find server name. + const char* server_uri = + grpc_channel_args_find_string(args.args, GRPC_ARG_SERVER_URI); + GPR_ASSERT(server_uri != nullptr); + absl::StatusOr uri = URI::Parse(server_uri); + GPR_ASSERT(uri.ok() && !uri->path().empty()); + absl::string_view server_name = absl::StripPrefix(uri->path(), "/"); + // Determine if it's an xds URI. + bool is_xds_uri = uri->scheme() == "xds" || uri->scheme() == "google-c2p"; + // Get XdsClient. + RefCountedPtr xds_client = + XdsClient::GetFromChannelArgs(*args.args); + if (xds_client == nullptr) { + if (!is_xds_uri) { + grpc_error_handle error = GRPC_ERROR_NONE; + xds_client = XdsClient::GetOrCreate(args.args, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "cannot get or create XdsClient to instantiate " + "xds_cluster_resolver LB policy: %s", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + return nullptr; + } + } else { + gpr_log(GPR_ERROR, + "XdsClient not present in channel args -- cannot instantiate " + "xds_cluster_resolver LB policy"); + return nullptr; + } + } + return MakeOrphanable( + std::move(xds_client), std::move(args), server_name, is_xds_uri); + } + + const char* name() const override { return kXdsClusterResolver; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (json.type() == Json::Type::JSON_NULL) { + // xds_cluster_resolver was mentioned as a policy in the deprecated + // loadBalancingPolicy field or in the client API. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:xds_cluster_resolver policy " + "requires configuration. 
" + "Please use loadBalancingConfig field of service config instead."); + return nullptr; + } + std::vector error_list; + std::vector + discovery_mechanisms; + auto it = json.object_value().find("discoveryMechanisms"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:discoveryMechanisms error:required field missing")); + } else if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:discoveryMechanisms error:type should be array")); + } else { + const Json::Array& array = it->second.array_value(); + for (size_t i = 0; i < array.size(); ++i) { + XdsClusterResolverLbConfig::DiscoveryMechanism discovery_mechanism; + std::vector discovery_mechanism_errors = + ParseDiscoveryMechanism(array[i], &discovery_mechanism); + if (!discovery_mechanism_errors.empty()) { + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:discovery_mechanism element: ", i, " error")); + for (const grpc_error_handle& discovery_mechanism_error : + discovery_mechanism_errors) { + error = grpc_error_add_child(error, discovery_mechanism_error); + } + error_list.push_back(error); + } + discovery_mechanisms.emplace_back(std::move(discovery_mechanism)); + } + } + if (discovery_mechanisms.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:discovery_mechanism error:list is missing or empty")); + } + Json xds_lb_policy = Json::Object{ + {"ROUND_ROBIN", Json::Object()}, + }; + it = json.object_value().find("xdsLbPolicy"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:xdsLbPolicy error:type should be array")); + } else { + const Json::Array& array = it->second.array_value(); + for (size_t i = 0; i < array.size(); ++i) { + if (array[i].type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:xdsLbPolicy error:element should be of type object")); + continue; + } + const Json::Object& policy = array[i].object_value(); + auto policy_it = policy.find("ROUND_ROBIN"); + if (policy_it != policy.end()) { + if (policy_it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:ROUND_ROBIN error:type should be object")); + } + break; + } + policy_it = policy.find("RING_HASH"); + if (policy_it != policy.end()) { + xds_lb_policy = array[i]; + size_t min_ring_size; + size_t max_ring_size; + ParseRingHashLbConfig(policy_it->second, &min_ring_size, + &max_ring_size, &error_list); + } + } + } + } + // Construct config. + if (error_list.empty()) { + return MakeRefCounted( + std::move(discovery_mechanisms), std::move(xds_lb_policy)); + } else { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "xds_cluster_resolver_experimental LB policy config", &error_list); + return nullptr; + } + } + + private: + static std::vector ParseDiscoveryMechanism( + const Json& json, + XdsClusterResolverLbConfig::DiscoveryMechanism* discovery_mechanism) { + std::vector error_list; + if (json.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "value should be of type object")); + return error_list; + } + // Cluster name. 
+ auto it = json.object_value().find("clusterName"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:clusterName error:required field missing")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:clusterName error:type should be string")); + } else { + discovery_mechanism->cluster_name = it->second.string_value(); + } + // LRS load reporting server name. + it = json.object_value().find("lrsLoadReportingServerName"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:lrsLoadReportingServerName error:type should be string")); + } else { + discovery_mechanism->lrs_load_reporting_server_name.emplace( + it->second.string_value()); + } + } + // Max concurrent requests. + discovery_mechanism->max_concurrent_requests = 1024; + it = json.object_value().find("max_concurrent_requests"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:max_concurrent_requests error:must be of type number")); + } else { + discovery_mechanism->max_concurrent_requests = + gpr_parse_nonnegative_int(it->second.string_value().c_str()); + } + } + // Discovery Mechanism type + it = json.object_value().find("type"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:type error:required field missing")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:type error:type should be string")); + } else { + if (it->second.string_value() == "EDS") { + discovery_mechanism->type = XdsClusterResolverLbConfig:: + DiscoveryMechanism::DiscoveryMechanismType::EDS; + it = json.object_value().find("edsServiceName"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:edsServiceName error:type should be string")); + } else { + discovery_mechanism->eds_service_name = it->second.string_value(); + } + } + } else if (it->second.string_value() == "LOGICAL_DNS") { + discovery_mechanism->type = XdsClusterResolverLbConfig:: + DiscoveryMechanism::DiscoveryMechanismType::LOGICAL_DNS; + it = json.object_value().find("dnsHostname"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:dnsHostname error:required field missing")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:dnsHostname error:type should be string")); + } else { + discovery_mechanism->dns_hostname = it->second.string_value(); + } + } else { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:type error:invalid type")); + } + } + return error_list; + } + + class XdsClusterResolverChildHandler : public ChildPolicyHandler { + public: + XdsClusterResolverChildHandler(RefCountedPtr xds_client, + Args args, absl::string_view server_name, + bool is_xds_uri) + : ChildPolicyHandler(std::move(args), + &grpc_lb_xds_cluster_resolver_trace), + xds_client_(std::move(xds_client)), + server_name_(server_name), + is_xds_uri_(is_xds_uri) {} + + bool ConfigChangeRequiresNewPolicyInstance( + LoadBalancingPolicy::Config* old_config, + LoadBalancingPolicy::Config* 
new_config) const override { + GPR_ASSERT(old_config->name() == kXdsClusterResolver); + GPR_ASSERT(new_config->name() == kXdsClusterResolver); + XdsClusterResolverLbConfig* old_xds_cluster_resolver_config = + static_cast(old_config); + XdsClusterResolverLbConfig* new_xds_cluster_resolver_config = + static_cast(new_config); + return old_xds_cluster_resolver_config->discovery_mechanisms() != + new_xds_cluster_resolver_config->discovery_mechanisms(); + } + + OrphanablePtr CreateLoadBalancingPolicy( + const char* /*name*/, LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(xds_client_, std::move(args), + server_name_, is_xds_uri_); + } + + private: + RefCountedPtr xds_client_; + std::string server_name_; + bool is_xds_uri_; + }; +}; + +} // namespace + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_lb_policy_xds_cluster_resolver_init() { + grpc_core::LoadBalancingPolicyRegistry::Builder:: + RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +void grpc_lb_policy_xds_cluster_resolver_shutdown() {} diff --git a/src/core/ext/filters/client_channel/lb_policy_registry.cc b/src/core/ext/filters/client_channel/lb_policy_registry.cc new file mode 100644 index 00000000..c72a6095 --- /dev/null +++ b/src/core/ext/filters/client_channel/lb_policy_registry.cc @@ -0,0 +1,185 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include "src/core/lib/gpr/string.h" + +namespace grpc_core { + +namespace { + +class RegistryState { + public: + RegistryState() {} + + void RegisterLoadBalancingPolicyFactory( + std::unique_ptr factory) { + gpr_log(GPR_DEBUG, "registering LB policy factory for \"%s\"", + factory->name()); + for (size_t i = 0; i < factories_.size(); ++i) { + GPR_ASSERT(strcmp(factories_[i]->name(), factory->name()) != 0); + } + factories_.push_back(std::move(factory)); + } + + LoadBalancingPolicyFactory* GetLoadBalancingPolicyFactory( + const char* name) const { + for (size_t i = 0; i < factories_.size(); ++i) { + if (strcmp(name, factories_[i]->name()) == 0) { + return factories_[i].get(); + } + } + return nullptr; + } + + private: + absl::InlinedVector, 10> + factories_; +}; + +RegistryState* g_state = nullptr; + +} // namespace + +// +// LoadBalancingPolicyRegistry::Builder +// + +void LoadBalancingPolicyRegistry::Builder::InitRegistry() { + if (g_state == nullptr) g_state = new RegistryState(); +} + +void LoadBalancingPolicyRegistry::Builder::ShutdownRegistry() { + delete g_state; + g_state = nullptr; +} + +void LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + std::unique_ptr factory) { + InitRegistry(); + g_state->RegisterLoadBalancingPolicyFactory(std::move(factory)); +} + +// +// LoadBalancingPolicyRegistry +// + +OrphanablePtr +LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( + const char* name, LoadBalancingPolicy::Args args) { + GPR_ASSERT(g_state != nullptr); + // Find factory. + LoadBalancingPolicyFactory* factory = + g_state->GetLoadBalancingPolicyFactory(name); + if (factory == nullptr) return nullptr; // Specified name not found. + // Create policy via factory. + return factory->CreateLoadBalancingPolicy(std::move(args)); +} + +bool LoadBalancingPolicyRegistry::LoadBalancingPolicyExists( + const char* name, bool* requires_config) { + GPR_ASSERT(g_state != nullptr); + auto* factory = g_state->GetLoadBalancingPolicyFactory(name); + if (factory == nullptr) { + return false; + } + if (requires_config != nullptr) { + grpc_error_handle error = GRPC_ERROR_NONE; + // Check if the load balancing policy allows an empty config + *requires_config = + factory->ParseLoadBalancingConfig(Json(), &error) == nullptr; + GRPC_ERROR_UNREF(error); + } + return true; +} + +namespace { + +// Returns the JSON node of policy (with both policy name and config content) +// given the JSON node of a LoadBalancingConfig array. +grpc_error_handle ParseLoadBalancingConfigHelper( + const Json& lb_config_array, Json::Object::const_iterator* result) { + if (lb_config_array.type() != Json::Type::ARRAY) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("type should be array"); + } + // Find the first LB policy that this client supports. 
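// For illustration, given a LoadBalancingConfig list such as
//   [ { "some_unknown_policy": {} }, { "round_robin": {} } ]
// the loop below skips the first entry (no factory registered under that
// name) and selects "round_robin", assuming that policy is registered.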
+ std::vector policies_tried; + for (const Json& lb_config : lb_config_array.array_value()) { + if (lb_config.type() != Json::Type::OBJECT) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "child entry should be of type object"); + } + if (lb_config.object_value().empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "no policy found in child entry"); + } + if (lb_config.object_value().size() > 1) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("oneOf violation"); + } + auto it = lb_config.object_value().begin(); + if (it->second.type() != Json::Type::OBJECT) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "child entry should be of type object"); + } + // If we support this policy, then select it. + if (LoadBalancingPolicyRegistry::LoadBalancingPolicyExists( + it->first.c_str(), nullptr)) { + *result = it; + return GRPC_ERROR_NONE; + } + policies_tried.push_back(it->first); + } + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "No known policies in list: ", absl::StrJoin(policies_tried, " "))); +} + +} // namespace + +RefCountedPtr +LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + GPR_ASSERT(g_state != nullptr); + Json::Object::const_iterator policy; + *error = ParseLoadBalancingConfigHelper(json, &policy); + if (*error != GRPC_ERROR_NONE) { + return nullptr; + } + // Find factory. + LoadBalancingPolicyFactory* factory = + g_state->GetLoadBalancingPolicyFactory(policy->first.c_str()); + if (factory == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Factory not found for policy \"%s\"", policy->first)); + return nullptr; + } + // Parse load balancing config via factory. + return factory->ParseLoadBalancingConfig(policy->second, error); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/local_subchannel_pool.cc b/src/core/ext/filters/client_channel/local_subchannel_pool.cc new file mode 100644 index 00000000..49338ecc --- /dev/null +++ b/src/core/ext/filters/client_channel/local_subchannel_pool.cc @@ -0,0 +1,56 @@ +// +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/ext/filters/client_channel/local_subchannel_pool.h" + +#include "src/core/ext/filters/client_channel/subchannel.h" + +namespace grpc_core { + +RefCountedPtr LocalSubchannelPool::RegisterSubchannel( + const SubchannelKey& key, RefCountedPtr constructed) { + auto it = subchannel_map_.find(key); + // Because this pool is only accessed under the client channel's work + // serializer, and because FindSubchannel is checked before invoking + // RegisterSubchannel, no such subchannel should exist in the map. 
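// For illustration, the expected calling pattern (all under the work
// serializer) is roughly:
//   RefCountedPtr<Subchannel> sc = pool->FindSubchannel(key);
//   if (sc == nullptr) {
//     sc = pool->RegisterSubchannel(key, MakeNewSubchannel(...));
//   }
// where MakeNewSubchannel() is a stand-in for however the caller actually
// constructs the Subchannel; it is not part of this file.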
+  GPR_ASSERT(it == subchannel_map_.end());
+  subchannel_map_[key] = constructed.get();
+  return constructed;
+}
+
+void LocalSubchannelPool::UnregisterSubchannel(const SubchannelKey& key,
+                                               Subchannel* subchannel) {
+  auto it = subchannel_map_.find(key);
+  // Because this subchannel pool is accessed only under the client
+  // channel's work serializer, any subchannel created by RegisterSubchannel
+  // will be deleted from the map in UnregisterSubchannel.
+  GPR_ASSERT(it != subchannel_map_.end());
+  GPR_ASSERT(it->second == subchannel);
+  subchannel_map_.erase(it);
+}
+
+RefCountedPtr<Subchannel> LocalSubchannelPool::FindSubchannel(
+    const SubchannelKey& key) {
+  auto it = subchannel_map_.find(key);
+  if (it == subchannel_map_.end()) return nullptr;
+  return it->second->Ref();
+}
+
+} // namespace grpc_core
diff --git a/src/core/ext/filters/client_channel/proxy_mapper_registry.cc b/src/core/ext/filters/client_channel/proxy_mapper_registry.cc
new file mode 100644
index 00000000..2fc25725
--- /dev/null
+++ b/src/core/ext/filters/client_channel/proxy_mapper_registry.cc
@@ -0,0 +1,89 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include
+
+#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h"
+
+#include
+#include
+
+namespace grpc_core {
+
+namespace {
+
+using ProxyMapperList = std::vector<std::unique_ptr<ProxyMapper>>;
+ProxyMapperList* g_proxy_mapper_list;
+
+} // namespace
+
+void ProxyMapperRegistry::Init() {
+  if (g_proxy_mapper_list == nullptr) {
+    g_proxy_mapper_list = new ProxyMapperList();
+  }
+}
+
+void ProxyMapperRegistry::Shutdown() {
+  delete g_proxy_mapper_list;
+  // Clean up in case we re-initialize later.
+  // TODO(roth): This should ideally live in Init(). However, if we did this
+  // there, then we would do it AFTER we start registering proxy mappers from
+  // third-party plugins, so they'd never show up (and would leak memory).
+  // We probably need some sort of dependency system for plugins to fix
+  // this.
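// For illustration, a plugin that wants to intercept proxy resolution would
// typically call, from its own init hook:
//   ProxyMapperRegistry::Register(/*at_start=*/false,
//                                 absl::make_unique<MyProxyMapper>());
// where MyProxyMapper is a hypothetical ProxyMapper implementation;
// MapName()/MapAddress() below then consult the registered mappers in order.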
+ g_proxy_mapper_list = nullptr; +} + +void ProxyMapperRegistry::Register( + bool at_start, std::unique_ptr mapper) { + Init(); + if (at_start) { + g_proxy_mapper_list->insert(g_proxy_mapper_list->begin(), + std::move(mapper)); + } else { + g_proxy_mapper_list->emplace_back(std::move(mapper)); + } +} + +bool ProxyMapperRegistry::MapName(const char* server_uri, + const grpc_channel_args* args, + char** name_to_resolve, + grpc_channel_args** new_args) { + Init(); + for (const auto& mapper : *g_proxy_mapper_list) { + if (mapper->MapName(server_uri, args, name_to_resolve, new_args)) { + return true; + } + } + return false; +} + +bool ProxyMapperRegistry::MapAddress(const grpc_resolved_address& address, + const grpc_channel_args* args, + grpc_resolved_address** new_address, + grpc_channel_args** new_args) { + Init(); + for (const auto& mapper : *g_proxy_mapper_list) { + if (mapper->MapAddress(address, args, new_address, new_args)) { + return true; + } + } + return false; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/resolver.cc b/src/core/ext/filters/client_channel/resolver.cc new file mode 100644 index 00000000..a5494b16 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver.cc @@ -0,0 +1,87 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/resolver.h" + +grpc_core::DebugOnlyTraceFlag grpc_trace_resolver_refcount(false, + "resolver_refcount"); + +namespace grpc_core { + +// +// Resolver +// + +Resolver::Resolver() + : InternallyRefCounted(GRPC_TRACE_FLAG_ENABLED(grpc_trace_resolver_refcount) + ? 
"Resolver" + : nullptr) {} + +// +// Resolver::Result +// + +Resolver::Result::~Result() { + GRPC_ERROR_UNREF(service_config_error); + grpc_channel_args_destroy(args); +} + +Resolver::Result::Result(const Result& other) { + addresses = other.addresses; + service_config = other.service_config; + service_config_error = GRPC_ERROR_REF(other.service_config_error); + args = grpc_channel_args_copy(other.args); +} + +Resolver::Result::Result(Result&& other) noexcept { + addresses = std::move(other.addresses); + service_config = std::move(other.service_config); + service_config_error = other.service_config_error; + other.service_config_error = GRPC_ERROR_NONE; + args = other.args; + other.args = nullptr; +} + +Resolver::Result& Resolver::Result::operator=(const Result& other) { + if (&other == this) { + return *this; + } + addresses = other.addresses; + service_config = other.service_config; + GRPC_ERROR_UNREF(service_config_error); + service_config_error = GRPC_ERROR_REF(other.service_config_error); + grpc_channel_args_destroy(args); + args = grpc_channel_args_copy(other.args); + return *this; +} + +Resolver::Result& Resolver::Result::operator=(Result&& other) noexcept { + addresses = std::move(other.addresses); + service_config = std::move(other.service_config); + GRPC_ERROR_UNREF(service_config_error); + service_config_error = other.service_config_error; + other.service_config_error = GRPC_ERROR_NONE; + grpc_channel_args_destroy(args); + args = other.args; + other.args = nullptr; + return *this; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/resolver/binder/binder_resolver.cc b/src/core/ext/filters/client_channel/resolver/binder/binder_resolver.cc new file mode 100644 index 00000000..68f21c80 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/binder/binder_resolver.cc @@ -0,0 +1,139 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_HAVE_UNIX_SOCKET + +#include + +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" + +namespace grpc_core { +namespace { + +class BinderResolver : public Resolver { + public: + BinderResolver(ServerAddressList addresses, ResolverArgs args) + : result_handler_(std::move(args.result_handler)), + addresses_(std::move(addresses)), + channel_args_(grpc_channel_args_copy(args.args)) {} + + ~BinderResolver() override { grpc_channel_args_destroy(channel_args_); }; + + void StartLocked() override { + Result result; + result.addresses = std::move(addresses_); + result.args = channel_args_; + channel_args_ = nullptr; + result_handler_->ReturnResult(std::move(result)); + } + + void ShutdownLocked() override {} + + private: + std::unique_ptr result_handler_; + ServerAddressList addresses_; + const grpc_channel_args* channel_args_ = nullptr; +}; + +class BinderResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + return ParseUri(uri, nullptr); + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + ServerAddressList addresses; + if (!ParseUri(args.uri, &addresses)) return nullptr; + return MakeOrphanable(std::move(addresses), + std::move(args)); + } + + const char* scheme() const override { return "binder"; } + + private: + static grpc_error_handle BinderAddrPopulate( + absl::string_view path, grpc_resolved_address* resolved_addr) { + path = absl::StripPrefix(path, "/"); + if (path.empty()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING("path is empty"); + } + // Store parsed path in a unix socket so it can be reinterpreted as + // sockaddr. An invalid address family (AF_MAX) is set to make sure it won't + // be accidentally used. 
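// For illustration, a hypothetical URI "binder:my.service" yields the path
// "my.service" once any leading "/" is stripped, so sun_path ends up holding
// "my.service" and len becomes sizeof(sun_family) (typically 2) + 10 + 1 = 13.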
+    memset(resolved_addr, 0, sizeof(*resolved_addr));
+    struct sockaddr_un* un =
+        reinterpret_cast<struct sockaddr_un*>(resolved_addr->addr);
+    un->sun_family = AF_MAX;
+    static_assert(sizeof(un->sun_path) >= 101,
+                  "unix socket path size is unexpectedly short");
+    if (path.size() + 1 > sizeof(un->sun_path)) {
+      return GRPC_ERROR_CREATE_FROM_CPP_STRING(
+          absl::StrCat(path, " is too long to be handled"));
+    }
+    // `un` has already been set to zero, so there is no need to append a
+    // null terminator after the string.
+    memcpy(un->sun_path, path.data(), path.size());
+    resolved_addr->len =
+        static_cast<socklen_t>(sizeof(un->sun_family) + path.size() + 1);
+    return GRPC_ERROR_NONE;
+  }
+
+  static bool ParseUri(const URI& uri, ServerAddressList* addresses) {
+    grpc_resolved_address addr;
+    {
+      if (!uri.authority().empty()) {
+        gpr_log(GPR_ERROR, "authority is not supported in binder scheme");
+        return false;
+      }
+      grpc_error_handle error = BinderAddrPopulate(uri.path(), &addr);
+      if (error != GRPC_ERROR_NONE) {
+        gpr_log(GPR_ERROR, "%s", grpc_error_std_string(error).c_str());
+        GRPC_ERROR_UNREF(error);
+        return false;
+      }
+    }
+    if (addresses != nullptr) {
+      addresses->emplace_back(addr, nullptr /* args */);
+    }
+    return true;
+  }
+};
+
+} // namespace
+} // namespace grpc_core
+
+void grpc_resolver_binder_init() {
+  grpc_core::ResolverRegistry::Builder::RegisterResolverFactory(
+      absl::make_unique<grpc_core::BinderResolverFactory>());
+}
+
+void grpc_resolver_binder_shutdown() {}
+
+#else
+
+void grpc_resolver_binder_init() {}
+
+void grpc_resolver_binder_shutdown() {}
+
+#endif
diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc
new file mode 100644
index 00000000..68624cdb
--- /dev/null
+++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc
@@ -0,0 +1,526 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +#if GRPC_ARES == 1 + +#include +#include +#include + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/ext/filters/client_channel/http_connect_handshaker.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/service_config/service_config.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/gethostname.h" +#include "src/core/lib/iomgr/iomgr_custom.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/json/json.h" + +#define GRPC_DNS_INITIAL_CONNECT_BACKOFF_SECONDS 1 +#define GRPC_DNS_RECONNECT_BACKOFF_MULTIPLIER 1.6 +#define GRPC_DNS_RECONNECT_MAX_BACKOFF_SECONDS 120 +#define GRPC_DNS_RECONNECT_JITTER 0.2 + +namespace grpc_core { + +namespace { + +class AresDnsResolver : public Resolver { + public: + explicit AresDnsResolver(ResolverArgs args); + + void StartLocked() override; + + void RequestReresolutionLocked() override; + + void ResetBackoffLocked() override; + + void ShutdownLocked() override; + + private: + ~AresDnsResolver() override; + + void MaybeStartResolvingLocked(); + void StartResolvingLocked(); + + static void OnNextResolution(void* arg, grpc_error_handle error); + static void OnResolved(void* arg, grpc_error_handle error); + void OnNextResolutionLocked(grpc_error_handle error); + void OnResolvedLocked(grpc_error_handle error); + + /// DNS server to use (if not system default) + std::string dns_server_; + /// name to resolve (usually the same as target_name) + std::string name_to_resolve_; + /// channel args + grpc_channel_args* channel_args_; + std::shared_ptr work_serializer_; + std::unique_ptr result_handler_; + /// pollset_set to drive the name resolution process + grpc_pollset_set* interested_parties_; + + /// whether to request the service config + bool request_service_config_; + // whether or not to enable SRV DNS queries + bool enable_srv_queries_; + // timeout in milliseconds for active DNS queries + int query_timeout_ms_; + /// min interval between DNS requests + grpc_millis min_time_between_resolutions_; + + /// closures used by the work_serializer + grpc_closure on_next_resolution_; + grpc_closure on_resolved_; + /// are we currently resolving? 
+ bool resolving_ = false; + /// the pending resolving request + grpc_ares_request* pending_request_ = nullptr; + /// next resolution timer + bool have_next_resolution_timer_ = false; + grpc_timer next_resolution_timer_; + /// timestamp of last DNS request + grpc_millis last_resolution_timestamp_ = -1; + /// retry backoff state + BackOff backoff_; + /// currently resolving backend addresses + std::unique_ptr addresses_; + /// currently resolving balancer addresses + std::unique_ptr balancer_addresses_; + /// currently resolving service config + char* service_config_json_ = nullptr; + // has shutdown been initiated + bool shutdown_initiated_ = false; +}; + +AresDnsResolver::AresDnsResolver(ResolverArgs args) + : dns_server_(args.uri.authority()), + name_to_resolve_(absl::StripPrefix(args.uri.path(), "/")), + channel_args_(grpc_channel_args_copy(args.args)), + work_serializer_(std::move(args.work_serializer)), + result_handler_(std::move(args.result_handler)), + interested_parties_(args.pollset_set), + request_service_config_(!grpc_channel_args_find_bool( + channel_args_, GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION, true)), + enable_srv_queries_(grpc_channel_args_find_bool( + channel_args_, GRPC_ARG_DNS_ENABLE_SRV_QUERIES, false)), + query_timeout_ms_(grpc_channel_args_find_integer( + channel_args_, GRPC_ARG_DNS_ARES_QUERY_TIMEOUT_MS, + {GRPC_DNS_ARES_DEFAULT_QUERY_TIMEOUT_MS, 0, INT_MAX})), + min_time_between_resolutions_(grpc_channel_args_find_integer( + channel_args_, GRPC_ARG_DNS_MIN_TIME_BETWEEN_RESOLUTIONS_MS, + {1000 * 30, 0, INT_MAX})), + backoff_( + BackOff::Options() + .set_initial_backoff(GRPC_DNS_INITIAL_CONNECT_BACKOFF_SECONDS * + 1000) + .set_multiplier(GRPC_DNS_RECONNECT_BACKOFF_MULTIPLIER) + .set_jitter(GRPC_DNS_RECONNECT_JITTER) + .set_max_backoff(GRPC_DNS_RECONNECT_MAX_BACKOFF_SECONDS * 1000)) { + // Closure initialization. + GRPC_CLOSURE_INIT(&on_next_resolution_, OnNextResolution, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_resolved_, OnResolved, this, grpc_schedule_on_exec_ctx); +} + +AresDnsResolver::~AresDnsResolver() { + GRPC_CARES_TRACE_LOG("resolver:%p destroying AresDnsResolver", this); + grpc_channel_args_destroy(channel_args_); +} + +void AresDnsResolver::StartLocked() { + GRPC_CARES_TRACE_LOG("resolver:%p AresDnsResolver::StartLocked() is called.", + this); + MaybeStartResolvingLocked(); +} + +void AresDnsResolver::RequestReresolutionLocked() { + if (!resolving_) { + MaybeStartResolvingLocked(); + } +} + +void AresDnsResolver::ResetBackoffLocked() { + if (have_next_resolution_timer_) { + grpc_timer_cancel(&next_resolution_timer_); + } + backoff_.Reset(); +} + +void AresDnsResolver::ShutdownLocked() { + shutdown_initiated_ = true; + if (have_next_resolution_timer_) { + grpc_timer_cancel(&next_resolution_timer_); + } + if (pending_request_ != nullptr) { + grpc_cancel_ares_request_locked(pending_request_); + } +} + +void AresDnsResolver::OnNextResolution(void* arg, grpc_error_handle error) { + AresDnsResolver* r = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + r->work_serializer_->Run([r, error]() { r->OnNextResolutionLocked(error); }, + DEBUG_LOCATION); +} + +void AresDnsResolver::OnNextResolutionLocked(grpc_error_handle error) { + GRPC_CARES_TRACE_LOG( + "resolver:%p re-resolution timer fired. error: %s. 
shutdown_initiated_: " + "%d", + this, grpc_error_std_string(error).c_str(), shutdown_initiated_); + have_next_resolution_timer_ = false; + if (error == GRPC_ERROR_NONE && !shutdown_initiated_) { + if (!resolving_) { + GRPC_CARES_TRACE_LOG( + "resolver:%p start resolving due to re-resolution timer", this); + StartResolvingLocked(); + } + } + Unref(DEBUG_LOCATION, "next_resolution_timer"); + GRPC_ERROR_UNREF(error); +} + +bool ValueInJsonArray(const Json::Array& array, const char* value) { + for (const Json& entry : array) { + if (entry.type() == Json::Type::STRING && entry.string_value() == value) { + return true; + } + } + return false; +} + +std::string ChooseServiceConfig(char* service_config_choice_json, + grpc_error_handle* error) { + Json json = Json::Parse(service_config_choice_json, error); + if (*error != GRPC_ERROR_NONE) return ""; + if (json.type() != Json::Type::ARRAY) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Service Config Choices, error: should be of type array"); + return ""; + } + const Json* service_config = nullptr; + absl::InlinedVector error_list; + for (const Json& choice : json.array_value()) { + if (choice.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Service Config Choice, error: should be of type object")); + continue; + } + // Check client language, if specified. + auto it = choice.object_value().find("clientLanguage"); + if (it != choice.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:clientLanguage error:should be of type array")); + } else if (!ValueInJsonArray(it->second.array_value(), "c++")) { + continue; + } + } + // Check client hostname, if specified. + it = choice.object_value().find("clientHostname"); + if (it != choice.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:clientHostname error:should be of type array")); + } else { + char* hostname = grpc_gethostname(); + if (hostname == nullptr || + !ValueInJsonArray(it->second.array_value(), hostname)) { + continue; + } + } + } + // Check percentage, if specified. + it = choice.object_value().find("percentage"); + if (it != choice.object_value().end()) { + if (it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:percentage error:should be of type number")); + } else { + int random_pct = rand() % 100; + int percentage; + if (sscanf(it->second.string_value().c_str(), "%d", &percentage) != 1) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:percentage error:should be of type integer")); + } else if (random_pct > percentage || percentage == 0) { + continue; + } + } + } + // Found service config. 
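// For illustration, a single choice handled by this loop might look like
// (values made up):
//   { "clientLanguage": [ "c++" ],
//     "percentage": 50,
//     "serviceConfig": { "loadBalancingConfig": [ { "round_robin": {} } ] } }
// The "serviceConfig" of the first choice that passes these checks is the
// one returned.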
+ it = choice.object_value().find("serviceConfig"); + if (it == choice.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:serviceConfig error:required field missing")); + } else if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:serviceConfig error:should be of type object")); + } else if (service_config == nullptr) { + service_config = &it->second; + } + } + if (!error_list.empty()) { + service_config = nullptr; + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Service Config Choices Parser", + &error_list); + } + if (service_config == nullptr) return ""; + return service_config->Dump(); +} + +void AresDnsResolver::OnResolved(void* arg, grpc_error_handle error) { + AresDnsResolver* r = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + r->work_serializer_->Run([r, error]() { r->OnResolvedLocked(error); }, + DEBUG_LOCATION); +} + +void AresDnsResolver::OnResolvedLocked(grpc_error_handle error) { + GPR_ASSERT(resolving_); + resolving_ = false; + delete pending_request_; + pending_request_ = nullptr; + if (shutdown_initiated_) { + Unref(DEBUG_LOCATION, "OnResolvedLocked() shutdown"); + GRPC_ERROR_UNREF(error); + return; + } + if (addresses_ != nullptr || balancer_addresses_ != nullptr) { + Result result; + if (addresses_ != nullptr) { + result.addresses = std::move(*addresses_); + } + if (service_config_json_ != nullptr) { + std::string service_config_string = ChooseServiceConfig( + service_config_json_, &result.service_config_error); + gpr_free(service_config_json_); + if (result.service_config_error == GRPC_ERROR_NONE && + !service_config_string.empty()) { + GRPC_CARES_TRACE_LOG("resolver:%p selected service config choice: %s", + this, service_config_string.c_str()); + result.service_config = ServiceConfig::Create( + channel_args_, service_config_string, &result.service_config_error); + } + } + absl::InlinedVector new_args; + if (balancer_addresses_ != nullptr) { + new_args.push_back( + CreateGrpclbBalancerAddressesArg(balancer_addresses_.get())); + } + result.args = grpc_channel_args_copy_and_add(channel_args_, new_args.data(), + new_args.size()); + result_handler_->ReturnResult(std::move(result)); + addresses_.reset(); + balancer_addresses_.reset(); + // Reset backoff state so that we start from the beginning when the + // next request gets triggered. + backoff_.Reset(); + } else { + GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed: %s", this, + grpc_error_std_string(error).c_str()); + std::string error_message = + absl::StrCat("DNS resolution failed for service: ", name_to_resolve_); + result_handler_->ReturnError(grpc_error_set_int( + GRPC_ERROR_CREATE_REFERENCING_FROM_COPIED_STRING(error_message.c_str(), + &error, 1), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + // Set retry timer. + // InvalidateNow to avoid getting stuck re-initializing this timer + // in a loop while draining the currently-held WorkSerializer. + // Also see https://github.com/grpc/grpc/issues/26079. + ExecCtx::Get()->InvalidateNow(); + grpc_millis next_try = backoff_.NextAttemptTime(); + grpc_millis timeout = next_try - ExecCtx::Get()->Now(); + GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed (will retry): %s", + this, grpc_error_std_string(error).c_str()); + GPR_ASSERT(!have_next_resolution_timer_); + have_next_resolution_timer_ = true; + // TODO(roth): We currently deal with this ref manually. 
Once the + // new closure API is done, find a way to track this ref with the timer + // callback as part of the type system. + Ref(DEBUG_LOCATION, "retry-timer").release(); + if (timeout > 0) { + GRPC_CARES_TRACE_LOG("resolver:%p retrying in %" PRId64 " milliseconds", + this, timeout); + } else { + GRPC_CARES_TRACE_LOG("resolver:%p retrying immediately", this); + } + grpc_timer_init(&next_resolution_timer_, next_try, &on_next_resolution_); + } + Unref(DEBUG_LOCATION, "dns-resolving"); + GRPC_ERROR_UNREF(error); +} + +void AresDnsResolver::MaybeStartResolvingLocked() { + // If there is an existing timer, the time it fires is the earliest time we + // can start the next resolution. + if (have_next_resolution_timer_) return; + if (last_resolution_timestamp_ >= 0) { + // InvalidateNow to avoid getting stuck re-initializing this timer + // in a loop while draining the currently-held WorkSerializer. + // Also see https://github.com/grpc/grpc/issues/26079. + ExecCtx::Get()->InvalidateNow(); + const grpc_millis earliest_next_resolution = + last_resolution_timestamp_ + min_time_between_resolutions_; + const grpc_millis ms_until_next_resolution = + earliest_next_resolution - grpc_core::ExecCtx::Get()->Now(); + if (ms_until_next_resolution > 0) { + const grpc_millis last_resolution_ago = + grpc_core::ExecCtx::Get()->Now() - last_resolution_timestamp_; + GRPC_CARES_TRACE_LOG( + "resolver:%p In cooldown from last resolution (from %" PRId64 + " ms ago). Will resolve again in %" PRId64 " ms", + this, last_resolution_ago, ms_until_next_resolution); + have_next_resolution_timer_ = true; + // TODO(roth): We currently deal with this ref manually. Once the + // new closure API is done, find a way to track this ref with the timer + // callback as part of the type system. + Ref(DEBUG_LOCATION, "next_resolution_timer_cooldown").release(); + grpc_timer_init(&next_resolution_timer_, + ExecCtx::Get()->Now() + ms_until_next_resolution, + &on_next_resolution_); + return; + } + } + StartResolvingLocked(); +} + +void AresDnsResolver::StartResolvingLocked() { + // TODO(roth): We currently deal with this ref manually. Once the + // new closure API is done, find a way to track this ref with the timer + // callback as part of the type system. + Ref(DEBUG_LOCATION, "dns-resolving").release(); + GPR_ASSERT(!resolving_); + resolving_ = true; + service_config_json_ = nullptr; + pending_request_ = grpc_dns_lookup_ares_locked( + dns_server_.c_str(), name_to_resolve_.c_str(), kDefaultSecurePort, + interested_parties_, &on_resolved_, &addresses_, + enable_srv_queries_ ? &balancer_addresses_ : nullptr, + request_service_config_ ? &service_config_json_ : nullptr, + query_timeout_ms_, work_serializer_); + last_resolution_timestamp_ = grpc_core::ExecCtx::Get()->Now(); + GRPC_CARES_TRACE_LOG("resolver:%p Started resolving. 
pending_request_:%p", + this, pending_request_); +} + +// +// Factory +// + +class AresDnsResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + if (absl::StripPrefix(uri.path(), "/").empty()) { + gpr_log(GPR_ERROR, "no server name supplied in dns URI"); + return false; + } + return true; + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* scheme() const override { return "dns"; } +}; + +} // namespace + +} // namespace grpc_core + +extern grpc_address_resolver_vtable* grpc_resolve_address_impl; +static grpc_address_resolver_vtable* default_resolver; + +static grpc_error_handle blocking_resolve_address_ares( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + return default_resolver->blocking_resolve_address(name, default_port, + addresses); +} + +static grpc_address_resolver_vtable ares_resolver = { + grpc_resolve_address_ares, blocking_resolve_address_ares}; + +static bool should_use_ares(const char* resolver_env) { + // TODO(lidiz): Remove the "g_custom_iomgr_enabled" flag once c-ares support + // custom IO managers (e.g. gevent). + return !g_custom_iomgr_enabled && + (resolver_env == nullptr || strlen(resolver_env) == 0 || + gpr_stricmp(resolver_env, "ares") == 0); +} + +static bool g_use_ares_dns_resolver; + +void grpc_resolver_dns_ares_init() { + grpc_core::UniquePtr resolver = + GPR_GLOBAL_CONFIG_GET(grpc_dns_resolver); + if (should_use_ares(resolver.get())) { + g_use_ares_dns_resolver = true; + gpr_log(GPR_DEBUG, "Using ares dns resolver"); + address_sorting_init(); + grpc_error_handle error = grpc_ares_init(); + if (error != GRPC_ERROR_NONE) { + GRPC_LOG_IF_ERROR("grpc_ares_init() failed", error); + return; + } + if (default_resolver == nullptr) { + default_resolver = grpc_resolve_address_impl; + } + grpc_set_resolver_impl(&ares_resolver); + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); + } else { + g_use_ares_dns_resolver = false; + } +} + +void grpc_resolver_dns_ares_shutdown() { + if (g_use_ares_dns_resolver) { + address_sorting_shutdown(); + grpc_ares_cleanup(); + } +} + +#else /* GRPC_ARES == 1 */ + +void grpc_resolver_dns_ares_init(void) {} + +void grpc_resolver_dns_ares_shutdown(void) {} + +#endif /* GRPC_ARES == 1 */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_event_engine.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_event_engine.cc new file mode 100644 index 00000000..45c14648 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_event_engine.cc @@ -0,0 +1,31 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
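+// This file is a stub for the EventEngine-based iomgr configuration
+// (GRPC_USE_EVENT_ENGINE): c-ares polled-fd support is not implemented for
+// it, so the NewGrpcPolledFdFactory() defined below simply returns nullptr.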
+#include + +#include "src/core/lib/iomgr/port.h" +#if GRPC_ARES == 1 && defined(GRPC_USE_EVENT_ENGINE) + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" + +namespace grpc_core { + +std::unique_ptr NewGrpcPolledFdFactory( + std::shared_ptr /* work_serializer */) { + return nullptr; +} + +} // namespace grpc_core + +#endif /* GRPC_ARES == 1 && defined(GRPC_USE_EVENT_ENGINE) */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.cc new file mode 100644 index 00000000..e25bd55b --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.cc @@ -0,0 +1,110 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include "src/core/lib/iomgr/port.h" +#if GRPC_ARES == 1 && defined(GRPC_POSIX_SOCKET_ARES_EV_DRIVER) + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/iomgr_internal.h" + +namespace grpc_core { + +class GrpcPolledFdPosix : public GrpcPolledFd { + public: + GrpcPolledFdPosix(ares_socket_t as, grpc_pollset_set* driver_pollset_set) + : name_(absl::StrCat("c-ares fd: ", static_cast(as))), as_(as) { + fd_ = grpc_fd_create(static_cast(as), name_.c_str(), false); + driver_pollset_set_ = driver_pollset_set; + grpc_pollset_set_add_fd(driver_pollset_set_, fd_); + } + + ~GrpcPolledFdPosix() override { + grpc_pollset_set_del_fd(driver_pollset_set_, fd_); + /* c-ares library will close the fd inside grpc_fd. This fd may be picked up + immediately by another thread, and should not be closed by the following + grpc_fd_orphan. 
*/ + int phony_release_fd; + grpc_fd_orphan(fd_, nullptr, &phony_release_fd, "c-ares query finished"); + } + + void RegisterForOnReadableLocked(grpc_closure* read_closure) override { + grpc_fd_notify_on_read(fd_, read_closure); + } + + void RegisterForOnWriteableLocked(grpc_closure* write_closure) override { + grpc_fd_notify_on_write(fd_, write_closure); + } + + bool IsFdStillReadableLocked() override { + size_t bytes_available = 0; + return ioctl(grpc_fd_wrapped_fd(fd_), FIONREAD, &bytes_available) == 0 && + bytes_available > 0; + } + + void ShutdownLocked(grpc_error_handle error) override { + grpc_fd_shutdown(fd_, error); + } + + ares_socket_t GetWrappedAresSocketLocked() override { return as_; } + + const char* GetName() override { return name_.c_str(); } + + private: + std::string name_; + ares_socket_t as_; + grpc_fd* fd_; + grpc_pollset_set* driver_pollset_set_; +}; + +class GrpcPolledFdFactoryPosix : public GrpcPolledFdFactory { + public: + GrpcPolledFd* NewGrpcPolledFdLocked( + ares_socket_t as, grpc_pollset_set* driver_pollset_set, + std::shared_ptr /*work_serializer*/) override { + return new GrpcPolledFdPosix(as, driver_pollset_set); + } + + void ConfigureAresChannelLocked(ares_channel /*channel*/) override {} +}; + +std::unique_ptr NewGrpcPolledFdFactory( + std::shared_ptr work_serializer) { + (void)work_serializer; + return absl::make_unique(); +} + +} // namespace grpc_core + +#endif /* GRPC_ARES == 1 && defined(GRPC_POSIX_SOCKET_ARES_EV_DRIVER) */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc new file mode 100644 index 00000000..33b78b62 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_windows.cc @@ -0,0 +1,902 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include "src/core/lib/iomgr/port.h" +#if GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER) + +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/tcp_windows.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/slice/slice_internal.h" + +/* TODO(apolcyn): remove this hack after fixing upstream. + * Our grpc/c-ares code on Windows uses the ares_set_socket_functions API, + * which uses "struct iovec" type, which on Windows is defined inside of + * a c-ares header that is not public. 
+ * See https://github.com/c-ares/c-ares/issues/206. */ +struct iovec { + void* iov_base; + size_t iov_len; +}; + +namespace grpc_core { + +/* c-ares reads and takes action on the error codes of the + * "virtual socket operations" in this file, via the WSAGetLastError + * APIs. If code in this file wants to set a specific WSA error that + * c-ares should read, it must do so by calling SetWSAError() on the + * WSAErrorContext instance passed to it. A WSAErrorContext must only be + * instantiated at the top of the virtual socket function callstack. */ +class WSAErrorContext { + public: + explicit WSAErrorContext(){}; + + ~WSAErrorContext() { + if (error_ != 0) { + WSASetLastError(error_); + } + } + + /* Disallow copy and assignment operators */ + WSAErrorContext(const WSAErrorContext&) = delete; + WSAErrorContext& operator=(const WSAErrorContext&) = delete; + + void SetWSAError(int error) { error_ = error; } + + private: + int error_ = 0; +}; + +/* c-ares creates its own sockets and is meant to read them when readable and + * write them when writeable. To fit this socket usage model into the grpc + * windows poller (which gives notifications when attempted reads and writes are + * actually fulfilled rather than possible), this GrpcPolledFdWindows class + * takes advantage of the ares_set_socket_functions API and acts as a virtual + * socket. It holds its own read and write buffers which are written to and read + * from c-ares and are used with the grpc windows poller, and it, e.g., + * manufactures virtual socket error codes when it e.g. needs to tell the c-ares + * library to wait for an async read. */ +class GrpcPolledFdWindows { + public: + enum WriteState { + WRITE_IDLE, + WRITE_REQUESTED, + WRITE_PENDING, + WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY, + }; + + GrpcPolledFdWindows(ares_socket_t as, + std::shared_ptr work_serializer, + int address_family, int socket_type) + : work_serializer_(std::move(work_serializer)), + read_buf_(grpc_empty_slice()), + write_buf_(grpc_empty_slice()), + tcp_write_state_(WRITE_IDLE), + name_(absl::StrFormat("c-ares socket: %" PRIdPTR, as)), + gotten_into_driver_list_(false), + address_family_(address_family), + socket_type_(socket_type) { + // Closure Initialization + GRPC_CLOSURE_INIT(&outer_read_closure_, + &GrpcPolledFdWindows::OnIocpReadable, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&outer_write_closure_, + &GrpcPolledFdWindows::OnIocpWriteable, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_tcp_connect_locked_, + &GrpcPolledFdWindows::OnTcpConnect, this, + grpc_schedule_on_exec_ctx); + winsocket_ = grpc_winsocket_create(as, name_.c_str()); + } + + ~GrpcPolledFdWindows() { + grpc_slice_unref_internal(read_buf_); + grpc_slice_unref_internal(write_buf_); + GPR_ASSERT(read_closure_ == nullptr); + GPR_ASSERT(write_closure_ == nullptr); + grpc_winsocket_destroy(winsocket_); + } + + void ScheduleAndNullReadClosure(grpc_error_handle error) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, read_closure_, error); + read_closure_ = nullptr; + } + + void ScheduleAndNullWriteClosure(grpc_error_handle error) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, write_closure_, error); + write_closure_ = nullptr; + } + + void RegisterForOnReadableLocked(grpc_closure* read_closure) { + GPR_ASSERT(read_closure_ == nullptr); + read_closure_ = read_closure; + GPR_ASSERT(GRPC_SLICE_LENGTH(read_buf_) == 0); + grpc_slice_unref_internal(read_buf_); + GPR_ASSERT(!read_buf_has_data_); + read_buf_ = GRPC_SLICE_MALLOC(4192); + if (connect_done_) { + 
work_serializer_->Run([this]() { ContinueRegisterForOnReadableLocked(); }, + DEBUG_LOCATION); + } else { + GPR_ASSERT(pending_continue_register_for_on_readable_locked_ == false); + pending_continue_register_for_on_readable_locked_ = true; + } + } + + void ContinueRegisterForOnReadableLocked() { + GRPC_CARES_TRACE_LOG( + "fd:|%s| InnerContinueRegisterForOnReadableLocked " + "wsa_connect_error_:%d", + GetName(), wsa_connect_error_); + GPR_ASSERT(connect_done_); + if (wsa_connect_error_ != 0) { + ScheduleAndNullReadClosure(GRPC_WSA_ERROR(wsa_connect_error_, "connect")); + return; + } + WSABUF buffer; + buffer.buf = (char*)GRPC_SLICE_START_PTR(read_buf_); + buffer.len = GRPC_SLICE_LENGTH(read_buf_); + memset(&winsocket_->read_info.overlapped, 0, sizeof(OVERLAPPED)); + recv_from_source_addr_len_ = sizeof(recv_from_source_addr_); + DWORD flags = 0; + if (WSARecvFrom(grpc_winsocket_wrapped_socket(winsocket_), &buffer, 1, + nullptr, &flags, (sockaddr*)recv_from_source_addr_, + &recv_from_source_addr_len_, + &winsocket_->read_info.overlapped, nullptr)) { + int wsa_last_error = WSAGetLastError(); + char* msg = gpr_format_message(wsa_last_error); + GRPC_CARES_TRACE_LOG( + "fd:|%s| RegisterForOnReadableLocked WSARecvFrom error code:|%d| " + "msg:|%s|", + GetName(), wsa_last_error, msg); + gpr_free(msg); + if (wsa_last_error != WSA_IO_PENDING) { + ScheduleAndNullReadClosure( + GRPC_WSA_ERROR(wsa_last_error, "WSARecvFrom")); + return; + } + } + grpc_socket_notify_on_read(winsocket_, &outer_read_closure_); + } + + void RegisterForOnWriteableLocked(grpc_closure* write_closure) { + if (socket_type_ == SOCK_DGRAM) { + GRPC_CARES_TRACE_LOG("fd:|%s| RegisterForOnWriteableLocked called", + GetName()); + } else { + GPR_ASSERT(socket_type_ == SOCK_STREAM); + GRPC_CARES_TRACE_LOG( + "fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d", + GetName(), tcp_write_state_); + } + GPR_ASSERT(write_closure_ == nullptr); + write_closure_ = write_closure; + if (connect_done_) { + work_serializer_->Run( + [this]() { ContinueRegisterForOnWriteableLocked(); }, DEBUG_LOCATION); + } else { + GPR_ASSERT(pending_continue_register_for_on_writeable_locked_ == false); + pending_continue_register_for_on_writeable_locked_ = true; + } + } + + void ContinueRegisterForOnWriteableLocked() { + GRPC_CARES_TRACE_LOG( + "fd:|%s| InnerContinueRegisterForOnWriteableLocked " + "wsa_connect_error_:%d", + GetName(), wsa_connect_error_); + GPR_ASSERT(connect_done_); + if (wsa_connect_error_ != 0) { + ScheduleAndNullWriteClosure( + GRPC_WSA_ERROR(wsa_connect_error_, "connect")); + return; + } + if (socket_type_ == SOCK_DGRAM) { + ScheduleAndNullWriteClosure(GRPC_ERROR_NONE); + } else { + GPR_ASSERT(socket_type_ == SOCK_STREAM); + int wsa_error_code = 0; + switch (tcp_write_state_) { + case WRITE_IDLE: + ScheduleAndNullWriteClosure(GRPC_ERROR_NONE); + break; + case WRITE_REQUESTED: + tcp_write_state_ = WRITE_PENDING; + if (SendWriteBuf(nullptr, &winsocket_->write_info.overlapped, + &wsa_error_code) != 0) { + ScheduleAndNullWriteClosure( + GRPC_WSA_ERROR(wsa_error_code, "WSASend (overlapped)")); + } else { + grpc_socket_notify_on_write(winsocket_, &outer_write_closure_); + } + break; + case WRITE_PENDING: + case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY: + abort(); + } + } + } + + bool IsFdStillReadableLocked() { return read_buf_has_data_; } + + void ShutdownLocked(grpc_error_handle error) { + grpc_winsocket_shutdown(winsocket_); + } + + ares_socket_t GetWrappedAresSocketLocked() { + return grpc_winsocket_wrapped_socket(winsocket_); + } + 
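// The class above adapts c-ares' readiness-style socket expectations to the
// completion-based Windows poller by buffering data and manufacturing
// WSAEWOULDBLOCK while an overlapped operation is still in flight. Below is a
// minimal, self-contained sketch of that buffering idea; it is illustrative
// only (separate from the patch), uses invented names, and leaves out the
// Windows/overlapped-I/O details.

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

class BufferedReadinessAdapter {
 public:
  // Invoked by a completion-based poller once an async read finishes.
  void OnReadComplete(const char* data, size_t len) {
    buf_.insert(buf_.end(), data, data + len);
  }

  // Readiness-style recv: drain buffered bytes, or report "would block" so
  // the caller re-registers interest and tries again later.
  int Recv(char* out, size_t cap, bool* would_block) {
    if (buf_.empty()) {
      *would_block = true;
      return -1;
    }
    size_t n = std::min(cap, buf_.size());
    std::memcpy(out, buf_.data(), n);
    buf_.erase(buf_.begin(), buf_.begin() + n);
    *would_block = false;
    return static_cast<int>(n);
  }

 private:
  std::vector<char> buf_;
};

// GrpcPolledFdWindows applies the same idea with grpc_slice buffers and real
// overlapped I/O, and mirrors it on the write path as well.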
+ const char* GetName() { return name_.c_str(); } + + ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data, + ares_socket_t data_len, int flags, + struct sockaddr* from, ares_socklen_t* from_len) { + GRPC_CARES_TRACE_LOG( + "fd:|%s| RecvFrom called read_buf_has_data:%d Current read buf " + "length:|%d|", + GetName(), read_buf_has_data_, GRPC_SLICE_LENGTH(read_buf_)); + if (!read_buf_has_data_) { + wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK); + return -1; + } + ares_ssize_t bytes_read = 0; + for (size_t i = 0; i < GRPC_SLICE_LENGTH(read_buf_) && i < data_len; i++) { + ((char*)data)[i] = GRPC_SLICE_START_PTR(read_buf_)[i]; + bytes_read++; + } + read_buf_ = grpc_slice_sub_no_ref(read_buf_, bytes_read, + GRPC_SLICE_LENGTH(read_buf_)); + if (GRPC_SLICE_LENGTH(read_buf_) == 0) { + read_buf_has_data_ = false; + } + /* c-ares overloads this recv_from virtual socket function to receive + * data on both UDP and TCP sockets, and from is nullptr for TCP. */ + if (from != nullptr) { + GPR_ASSERT(*from_len <= recv_from_source_addr_len_); + memcpy(from, &recv_from_source_addr_, recv_from_source_addr_len_); + *from_len = recv_from_source_addr_len_; + } + return bytes_read; + } + + grpc_slice FlattenIovec(const struct iovec* iov, int iov_count) { + int total = 0; + for (int i = 0; i < iov_count; i++) { + total += iov[i].iov_len; + } + grpc_slice out = GRPC_SLICE_MALLOC(total); + size_t cur = 0; + for (int i = 0; i < iov_count; i++) { + for (int k = 0; k < iov[i].iov_len; k++) { + GRPC_SLICE_START_PTR(out)[cur++] = ((char*)iov[i].iov_base)[k]; + } + } + return out; + } + + int SendWriteBuf(LPDWORD bytes_sent_ptr, LPWSAOVERLAPPED overlapped, + int* wsa_error_code) { + WSABUF buf; + buf.len = GRPC_SLICE_LENGTH(write_buf_); + buf.buf = (char*)GRPC_SLICE_START_PTR(write_buf_); + DWORD flags = 0; + int out = WSASend(grpc_winsocket_wrapped_socket(winsocket_), &buf, 1, + bytes_sent_ptr, flags, overlapped, nullptr); + *wsa_error_code = WSAGetLastError(); + GRPC_CARES_TRACE_LOG( + "fd:|%s| SendWriteBuf WSASend buf.len:%d *bytes_sent_ptr:%d " + "overlapped:%p " + "return:%d *wsa_error_code:%d", + GetName(), buf.len, bytes_sent_ptr != nullptr ? *bytes_sent_ptr : 0, + overlapped, out, *wsa_error_code); + return out; + } + + ares_ssize_t SendV(WSAErrorContext* wsa_error_ctx, const struct iovec* iov, + int iov_count) { + GRPC_CARES_TRACE_LOG( + "fd:|%s| SendV called connect_done_:%d wsa_connect_error_:%d", + GetName(), connect_done_, wsa_connect_error_); + if (!connect_done_) { + wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK); + return -1; + } + if (wsa_connect_error_ != 0) { + wsa_error_ctx->SetWSAError(wsa_connect_error_); + return -1; + } + switch (socket_type_) { + case SOCK_DGRAM: + return SendVUDP(wsa_error_ctx, iov, iov_count); + case SOCK_STREAM: + return SendVTCP(wsa_error_ctx, iov, iov_count); + default: + abort(); + } + } + + ares_ssize_t SendVUDP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov, + int iov_count) { + // c-ares doesn't handle retryable errors on writes of UDP sockets. + // Therefore, the sendv handler for UDP sockets must only attempt + // to write everything inline. 
+ GRPC_CARES_TRACE_LOG("fd:|%s| SendVUDP called", GetName()); + GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0); + grpc_slice_unref_internal(write_buf_); + write_buf_ = FlattenIovec(iov, iov_count); + DWORD bytes_sent = 0; + int wsa_error_code = 0; + if (SendWriteBuf(&bytes_sent, nullptr, &wsa_error_code) != 0) { + grpc_slice_unref_internal(write_buf_); + write_buf_ = grpc_empty_slice(); + wsa_error_ctx->SetWSAError(wsa_error_code); + char* msg = gpr_format_message(wsa_error_code); + GRPC_CARES_TRACE_LOG( + "fd:|%s| SendVUDP SendWriteBuf error code:%d msg:|%s|", GetName(), + wsa_error_code, msg); + gpr_free(msg); + return -1; + } + write_buf_ = grpc_slice_sub_no_ref(write_buf_, bytes_sent, + GRPC_SLICE_LENGTH(write_buf_)); + return bytes_sent; + } + + ares_ssize_t SendVTCP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov, + int iov_count) { + // The "sendv" handler on TCP sockets buffers up write + // requests and returns an artificial WSAEWOULDBLOCK. Writing that buffer + // out in the background, and making further send progress in general, will + // happen as long as c-ares continues to show interest in writeability on + // this fd. + GRPC_CARES_TRACE_LOG("fd:|%s| SendVTCP called tcp_write_state_:%d", + GetName(), tcp_write_state_); + switch (tcp_write_state_) { + case WRITE_IDLE: + tcp_write_state_ = WRITE_REQUESTED; + GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0); + grpc_slice_unref_internal(write_buf_); + write_buf_ = FlattenIovec(iov, iov_count); + wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK); + return -1; + case WRITE_REQUESTED: + case WRITE_PENDING: + wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK); + return -1; + case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY: + // c-ares is retrying a send on data that we previously returned + // WSAEWOULDBLOCK for, but then subsequently wrote out in the + // background. Right now, we assume that c-ares is retrying the same + // send again. If c-ares still needs to send even more data, we'll get + // to it eventually. 
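+        // Report back to c-ares, as "sent", exactly the bytes that were
+        // already written out in the background, after checking below that
+        // they form a prefix of the buffer c-ares is now retrying.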
+ grpc_slice currently_attempted = FlattenIovec(iov, iov_count); + GPR_ASSERT(GRPC_SLICE_LENGTH(currently_attempted) >= + GRPC_SLICE_LENGTH(write_buf_)); + ares_ssize_t total_sent = 0; + for (size_t i = 0; i < GRPC_SLICE_LENGTH(write_buf_); i++) { + GPR_ASSERT(GRPC_SLICE_START_PTR(currently_attempted)[i] == + GRPC_SLICE_START_PTR(write_buf_)[i]); + total_sent++; + } + grpc_slice_unref_internal(currently_attempted); + tcp_write_state_ = WRITE_IDLE; + return total_sent; + } + abort(); + } + + static void OnTcpConnect(void* arg, grpc_error_handle error) { + GrpcPolledFdWindows* grpc_polled_fd = + static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + grpc_polled_fd->work_serializer_->Run( + [grpc_polled_fd, error]() { + grpc_polled_fd->OnTcpConnectLocked(error); + }, + DEBUG_LOCATION); + } + + void OnTcpConnectLocked(grpc_error_handle error) { + GRPC_CARES_TRACE_LOG( + "fd:%s InnerOnTcpConnectLocked error:|%s| " + "pending_register_for_readable:%d" + " pending_register_for_writeable:%d", + GetName(), grpc_error_std_string(error).c_str(), + pending_continue_register_for_on_readable_locked_, + pending_continue_register_for_on_writeable_locked_); + GPR_ASSERT(!connect_done_); + connect_done_ = true; + GPR_ASSERT(wsa_connect_error_ == 0); + if (error == GRPC_ERROR_NONE) { + DWORD transferred_bytes = 0; + DWORD flags; + BOOL wsa_success = + WSAGetOverlappedResult(grpc_winsocket_wrapped_socket(winsocket_), + &winsocket_->write_info.overlapped, + &transferred_bytes, FALSE, &flags); + GPR_ASSERT(transferred_bytes == 0); + if (!wsa_success) { + wsa_connect_error_ = WSAGetLastError(); + char* msg = gpr_format_message(wsa_connect_error_); + GRPC_CARES_TRACE_LOG( + "fd:%s InnerOnTcpConnectLocked WSA overlapped result code:%d " + "msg:|%s|", + GetName(), wsa_connect_error_, msg); + gpr_free(msg); + } + } else { + // Spoof up an error code that will cause any future c-ares operations on + // this fd to abort. + wsa_connect_error_ = WSA_OPERATION_ABORTED; + } + if (pending_continue_register_for_on_readable_locked_) { + work_serializer_->Run([this]() { ContinueRegisterForOnReadableLocked(); }, + DEBUG_LOCATION); + } + if (pending_continue_register_for_on_writeable_locked_) { + work_serializer_->Run( + [this]() { ContinueRegisterForOnWriteableLocked(); }, DEBUG_LOCATION); + } + GRPC_ERROR_UNREF(error); + } + + int Connect(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target, + ares_socklen_t target_len) { + switch (socket_type_) { + case SOCK_DGRAM: + return ConnectUDP(wsa_error_ctx, target, target_len); + case SOCK_STREAM: + return ConnectTCP(wsa_error_ctx, target, target_len); + default: + abort(); + } + } + + int ConnectUDP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target, + ares_socklen_t target_len) { + GRPC_CARES_TRACE_LOG("fd:%s ConnectUDP", GetName()); + GPR_ASSERT(!connect_done_); + GPR_ASSERT(wsa_connect_error_ == 0); + SOCKET s = grpc_winsocket_wrapped_socket(winsocket_); + int out = + WSAConnect(s, target, target_len, nullptr, nullptr, nullptr, nullptr); + wsa_connect_error_ = WSAGetLastError(); + wsa_error_ctx->SetWSAError(wsa_connect_error_); + connect_done_ = true; + char* msg = gpr_format_message(wsa_connect_error_); + GRPC_CARES_TRACE_LOG("fd:%s WSAConnect error code:|%d| msg:|%s|", GetName(), + wsa_connect_error_, msg); + gpr_free(msg); + // c-ares expects a posix-style connect API + return out == 0 ? 
0 : -1; + } + + int ConnectTCP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target, + ares_socklen_t target_len) { + GRPC_CARES_TRACE_LOG("fd:%s ConnectTCP", GetName()); + LPFN_CONNECTEX ConnectEx; + GUID guid = WSAID_CONNECTEX; + DWORD ioctl_num_bytes; + SOCKET s = grpc_winsocket_wrapped_socket(winsocket_); + if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid), + &ConnectEx, sizeof(ConnectEx), &ioctl_num_bytes, nullptr, + nullptr) != 0) { + int wsa_last_error = WSAGetLastError(); + wsa_error_ctx->SetWSAError(wsa_last_error); + char* msg = gpr_format_message(wsa_last_error); + GRPC_CARES_TRACE_LOG( + "fd:%s WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) error code:%d " + "msg:|%s|", + GetName(), wsa_last_error, msg); + gpr_free(msg); + connect_done_ = true; + wsa_connect_error_ = wsa_last_error; + return -1; + } + grpc_resolved_address wildcard4_addr; + grpc_resolved_address wildcard6_addr; + grpc_sockaddr_make_wildcards(0, &wildcard4_addr, &wildcard6_addr); + grpc_resolved_address* local_address = nullptr; + if (address_family_ == AF_INET) { + local_address = &wildcard4_addr; + } else { + local_address = &wildcard6_addr; + } + if (bind(s, (struct sockaddr*)local_address->addr, + (int)local_address->len) != 0) { + int wsa_last_error = WSAGetLastError(); + wsa_error_ctx->SetWSAError(wsa_last_error); + char* msg = gpr_format_message(wsa_last_error); + GRPC_CARES_TRACE_LOG("fd:%s bind error code:%d msg:|%s|", GetName(), + wsa_last_error, msg); + gpr_free(msg); + connect_done_ = true; + wsa_connect_error_ = wsa_last_error; + return -1; + } + int out = 0; + if (ConnectEx(s, target, target_len, nullptr, 0, nullptr, + &winsocket_->write_info.overlapped) == 0) { + out = -1; + int wsa_last_error = WSAGetLastError(); + wsa_error_ctx->SetWSAError(wsa_last_error); + char* msg = gpr_format_message(wsa_last_error); + GRPC_CARES_TRACE_LOG("fd:%s ConnectEx error code:%d msg:|%s|", GetName(), + wsa_last_error, msg); + gpr_free(msg); + if (wsa_last_error == WSA_IO_PENDING) { + // c-ares only understands WSAEINPROGRESS and EWOULDBLOCK error codes on + // connect, but an async connect on IOCP socket will give + // WSA_IO_PENDING, so we need to convert. + wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK); + } else { + // By returning a non-retryable error to c-ares at this point, + // we're aborting the possibility of any future operations on this fd. + connect_done_ = true; + wsa_connect_error_ = wsa_last_error; + return -1; + } + } + grpc_socket_notify_on_write(winsocket_, &on_tcp_connect_locked_); + return out; + } + + static void OnIocpReadable(void* arg, grpc_error_handle error) { + GrpcPolledFdWindows* polled_fd = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + polled_fd->work_serializer_->Run( + [polled_fd, error]() { polled_fd->OnIocpReadableLocked(error); }, + DEBUG_LOCATION); + } + + // TODO(apolcyn): improve this error handling to be less conversative. + // An e.g. ECONNRESET error here should result in errors when + // c-ares reads from this socket later, but it shouldn't necessarily cancel + // the entire resolution attempt. Doing so will allow the "inject broken + // nameserver list" test to pass on Windows. + void OnIocpReadableLocked(grpc_error_handle error) { + if (error == GRPC_ERROR_NONE) { + if (winsocket_->read_info.wsa_error != 0) { + /* WSAEMSGSIZE would be due to receiving more data + * than our read buffer's fixed capacity. Assume that + * the connection is TCP and read the leftovers + * in subsequent c-ares reads. 
*/ + if (winsocket_->read_info.wsa_error != WSAEMSGSIZE) { + error = GRPC_WSA_ERROR(winsocket_->read_info.wsa_error, + "OnIocpReadableInner"); + GRPC_CARES_TRACE_LOG( + "fd:|%s| OnIocpReadableInner winsocket_->read_info.wsa_error " + "code:|%d| msg:|%s|", + GetName(), winsocket_->read_info.wsa_error, + grpc_error_std_string(error).c_str()); + } + } + } + if (error == GRPC_ERROR_NONE) { + read_buf_ = grpc_slice_sub_no_ref( + read_buf_, 0, winsocket_->read_info.bytes_transferred); + read_buf_has_data_ = true; + } else { + grpc_slice_unref_internal(read_buf_); + read_buf_ = grpc_empty_slice(); + } + GRPC_CARES_TRACE_LOG( + "fd:|%s| OnIocpReadable finishing. read buf length now:|%d|", GetName(), + GRPC_SLICE_LENGTH(read_buf_)); + ScheduleAndNullReadClosure(error); + } + + static void OnIocpWriteable(void* arg, grpc_error_handle error) { + GrpcPolledFdWindows* polled_fd = static_cast(arg); + (void)GRPC_ERROR_REF(error); // error owned by lambda + polled_fd->work_serializer_->Run( + [polled_fd, error]() { polled_fd->OnIocpWriteableLocked(error); }, + DEBUG_LOCATION); + } + + void OnIocpWriteableLocked(grpc_error_handle error) { + GRPC_CARES_TRACE_LOG("OnIocpWriteableInner. fd:|%s|", GetName()); + GPR_ASSERT(socket_type_ == SOCK_STREAM); + if (error == GRPC_ERROR_NONE) { + if (winsocket_->write_info.wsa_error != 0) { + error = GRPC_WSA_ERROR(winsocket_->write_info.wsa_error, + "OnIocpWriteableInner"); + GRPC_CARES_TRACE_LOG( + "fd:|%s| OnIocpWriteableInner. winsocket_->write_info.wsa_error " + "code:|%d| msg:|%s|", + GetName(), winsocket_->write_info.wsa_error, + grpc_error_std_string(error).c_str()); + } + } + GPR_ASSERT(tcp_write_state_ == WRITE_PENDING); + if (error == GRPC_ERROR_NONE) { + tcp_write_state_ = WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY; + write_buf_ = grpc_slice_sub_no_ref( + write_buf_, 0, winsocket_->write_info.bytes_transferred); + GRPC_CARES_TRACE_LOG("fd:|%s| OnIocpWriteableInner. bytes transferred:%d", + GetName(), winsocket_->write_info.bytes_transferred); + } else { + grpc_slice_unref_internal(write_buf_); + write_buf_ = grpc_empty_slice(); + } + ScheduleAndNullWriteClosure(error); + } + + bool gotten_into_driver_list() const { return gotten_into_driver_list_; } + void set_gotten_into_driver_list() { gotten_into_driver_list_ = true; } + + private: + std::shared_ptr work_serializer_; + char recv_from_source_addr_[200]; + ares_socklen_t recv_from_source_addr_len_; + grpc_slice read_buf_; + bool read_buf_has_data_ = false; + grpc_slice write_buf_; + grpc_closure* read_closure_ = nullptr; + grpc_closure* write_closure_ = nullptr; + grpc_closure outer_read_closure_; + grpc_closure outer_write_closure_; + grpc_winsocket* winsocket_; + // tcp_write_state_ is only used on TCP GrpcPolledFds + WriteState tcp_write_state_; + std::string name_; + bool gotten_into_driver_list_; + int address_family_; + int socket_type_; + grpc_closure on_tcp_connect_locked_; + bool connect_done_ = false; + int wsa_connect_error_ = 0; + // We don't run register_for_{readable,writeable} logic until + // a socket is connected. In the interim, we queue readable/writeable + // registrations with the following state. 
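+  // OnTcpConnectLocked() flushes these flags: once the connect result (and
+  // wsa_connect_error_) is known, it re-posts the deferred
+  // ContinueRegisterForOnReadableLocked() / ContinueRegisterForOnWriteableLocked()
+  // calls onto the work serializer.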
+ bool pending_continue_register_for_on_readable_locked_ = false; + bool pending_continue_register_for_on_writeable_locked_ = false; +}; + +struct SockToPolledFdEntry { + SockToPolledFdEntry(SOCKET s, GrpcPolledFdWindows* fd) + : socket(s), polled_fd(fd) {} + SOCKET socket; + GrpcPolledFdWindows* polled_fd; + SockToPolledFdEntry* next = nullptr; +}; + +/* A SockToPolledFdMap can make ares_socket_t types (SOCKET's on windows) + * to GrpcPolledFdWindow's, and is used to find the appropriate + * GrpcPolledFdWindows to handle a virtual socket call when c-ares makes that + * socket call on the ares_socket_t type. Instances are owned by and one-to-one + * with a GrpcPolledFdWindows factory and event driver */ +class SockToPolledFdMap { + public: + explicit SockToPolledFdMap(std::shared_ptr work_serializer) + : work_serializer_(std::move(work_serializer)) {} + + ~SockToPolledFdMap() { GPR_ASSERT(head_ == nullptr); } + + void AddNewSocket(SOCKET s, GrpcPolledFdWindows* polled_fd) { + SockToPolledFdEntry* new_node = new SockToPolledFdEntry(s, polled_fd); + new_node->next = head_; + head_ = new_node; + } + + GrpcPolledFdWindows* LookupPolledFd(SOCKET s) { + for (SockToPolledFdEntry* node = head_; node != nullptr; + node = node->next) { + if (node->socket == s) { + GPR_ASSERT(node->polled_fd != nullptr); + return node->polled_fd; + } + } + abort(); + } + + void RemoveEntry(SOCKET s) { + GPR_ASSERT(head_ != nullptr); + SockToPolledFdEntry** prev = &head_; + for (SockToPolledFdEntry* node = head_; node != nullptr; + node = node->next) { + if (node->socket == s) { + *prev = node->next; + delete node; + return; + } + prev = &node->next; + } + abort(); + } + + /* These virtual socket functions are called from within the c-ares + * library. These methods generally dispatch those socket calls to the + * appropriate methods. The virtual "socket" and "close" methods are + * special and instead create/add and remove/destroy GrpcPolledFdWindows + * objects. 
+ */ + static ares_socket_t Socket(int af, int type, int protocol, void* user_data) { + if (type != SOCK_DGRAM && type != SOCK_STREAM) { + GRPC_CARES_TRACE_LOG("Socket called with invalid socket type:%d", type); + return INVALID_SOCKET; + } + SockToPolledFdMap* map = static_cast(user_data); + SOCKET s = WSASocket(af, type, protocol, nullptr, 0, + grpc_get_default_wsa_socket_flags()); + if (s == INVALID_SOCKET) { + GRPC_CARES_TRACE_LOG( + "WSASocket failed with params af:%d type:%d protocol:%d", af, type, + protocol); + return s; + } + grpc_tcp_set_non_block(s); + GrpcPolledFdWindows* polled_fd = + new GrpcPolledFdWindows(s, map->work_serializer_, af, type); + GRPC_CARES_TRACE_LOG( + "fd:|%s| created with params af:%d type:%d protocol:%d", + polled_fd->GetName(), af, type, protocol); + map->AddNewSocket(s, polled_fd); + return s; + } + + static int Connect(ares_socket_t as, const struct sockaddr* target, + ares_socklen_t target_len, void* user_data) { + WSAErrorContext wsa_error_ctx; + SockToPolledFdMap* map = static_cast(user_data); + GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as); + return polled_fd->Connect(&wsa_error_ctx, target, target_len); + } + + static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov, + int iovec_count, void* user_data) { + WSAErrorContext wsa_error_ctx; + SockToPolledFdMap* map = static_cast(user_data); + GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as); + return polled_fd->SendV(&wsa_error_ctx, iov, iovec_count); + } + + static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len, + int flags, struct sockaddr* from, + ares_socklen_t* from_len, void* user_data) { + WSAErrorContext wsa_error_ctx; + SockToPolledFdMap* map = static_cast(user_data); + GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(as); + return polled_fd->RecvFrom(&wsa_error_ctx, data, data_len, flags, from, + from_len); + } + + static int CloseSocket(SOCKET s, void* user_data) { + SockToPolledFdMap* map = static_cast(user_data); + GrpcPolledFdWindows* polled_fd = map->LookupPolledFd(s); + map->RemoveEntry(s); + // See https://github.com/grpc/grpc/pull/20284, this trace log is + // intentionally placed to attempt to trigger a crash in case of a + // use after free on polled_fd. + GRPC_CARES_TRACE_LOG("CloseSocket called for socket: %s", + polled_fd->GetName()); + // If a gRPC polled fd has not made it in to the driver's list yet, then + // the driver has not and will never see this socket. + if (!polled_fd->gotten_into_driver_list()) { + polled_fd->ShutdownLocked(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Shut down c-ares fd before without it ever having made it into the " + "driver's list")); + } + delete polled_fd; + return 0; + } + + private: + SockToPolledFdEntry* head_ = nullptr; + std::shared_ptr work_serializer_; +}; + +const struct ares_socket_functions custom_ares_sock_funcs = { + &SockToPolledFdMap::Socket /* socket */, + &SockToPolledFdMap::CloseSocket /* close */, + &SockToPolledFdMap::Connect /* connect */, + &SockToPolledFdMap::RecvFrom /* recvfrom */, + &SockToPolledFdMap::SendV /* sendv */, +}; + +/* A thin wrapper over a GrpcPolledFdWindows object but with a shorter + lifetime. This object releases it's GrpcPolledFdWindows upon destruction, + so that c-ares can close it via usual socket teardown. 
*/ +class GrpcPolledFdWindowsWrapper : public GrpcPolledFd { + public: + explicit GrpcPolledFdWindowsWrapper(GrpcPolledFdWindows* wrapped) + : wrapped_(wrapped) {} + + ~GrpcPolledFdWindowsWrapper() {} + + void RegisterForOnReadableLocked(grpc_closure* read_closure) override { + wrapped_->RegisterForOnReadableLocked(read_closure); + } + + void RegisterForOnWriteableLocked(grpc_closure* write_closure) override { + wrapped_->RegisterForOnWriteableLocked(write_closure); + } + + bool IsFdStillReadableLocked() override { + return wrapped_->IsFdStillReadableLocked(); + } + + void ShutdownLocked(grpc_error_handle error) override { + wrapped_->ShutdownLocked(error); + } + + ares_socket_t GetWrappedAresSocketLocked() override { + return wrapped_->GetWrappedAresSocketLocked(); + } + + const char* GetName() override { return wrapped_->GetName(); } + + private: + GrpcPolledFdWindows* wrapped_; +}; + +class GrpcPolledFdFactoryWindows : public GrpcPolledFdFactory { + public: + explicit GrpcPolledFdFactoryWindows( + std::shared_ptr work_serializer) + : sock_to_polled_fd_map_(std::move(work_serializer)) {} + + GrpcPolledFd* NewGrpcPolledFdLocked( + ares_socket_t as, grpc_pollset_set* driver_pollset_set, + std::shared_ptr work_serializer) override { + GrpcPolledFdWindows* polled_fd = sock_to_polled_fd_map_.LookupPolledFd(as); + // Set a flag so that the virtual socket "close" method knows it + // doesn't need to call ShutdownLocked, since now the driver will. + polled_fd->set_gotten_into_driver_list(); + return new GrpcPolledFdWindowsWrapper(polled_fd); + } + + void ConfigureAresChannelLocked(ares_channel channel) override { + ares_set_socket_functions(channel, &custom_ares_sock_funcs, + &sock_to_polled_fd_map_); + } + + private: + SockToPolledFdMap sock_to_polled_fd_map_; +}; + +std::unique_ptr NewGrpcPolledFdFactory( + std::shared_ptr work_serializer) { + return absl::make_unique( + std::move(work_serializer)); +} + +} // namespace grpc_core + +#endif /* GRPC_ARES == 1 && defined(GPR_WINDOWS) */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc new file mode 100644 index 00000000..fe5a42af --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc @@ -0,0 +1,1203 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#if GRPC_ARES == 1 + +#include +#include + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/nameser.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/timer.h" + +using grpc_core::ServerAddress; +using grpc_core::ServerAddressList; + +grpc_core::TraceFlag grpc_trace_cares_address_sorting(false, + "cares_address_sorting"); + +grpc_core::TraceFlag grpc_trace_cares_resolver(false, "cares_resolver"); + +typedef struct fd_node { + /** the owner of this fd node */ + grpc_ares_ev_driver* ev_driver; + /** a closure wrapping on_readable_locked, which should be + invoked when the grpc_fd in this node becomes readable. */ + grpc_closure read_closure; + /** a closure wrapping on_writable_locked, which should be + invoked when the grpc_fd in this node becomes writable. */ + grpc_closure write_closure; + /** next fd node in the list */ + struct fd_node* next; + + /** wrapped fd that's polled by grpc's poller for the current platform */ + grpc_core::GrpcPolledFd* grpc_polled_fd; + /** if the readable closure has been registered */ + bool readable_registered; + /** if the writable closure has been registered */ + bool writable_registered; + /** if the fd has been shutdown yet from grpc iomgr perspective */ + bool already_shutdown; +} fd_node; + +struct grpc_ares_ev_driver { + /** the ares_channel owned by this event driver */ + ares_channel channel; + /** pollset set for driving the IO events of the channel */ + grpc_pollset_set* pollset_set; + /** refcount of the event driver */ + gpr_refcount refs; + + /** work_serializer to synchronize c-ares and I/O callbacks on */ + std::shared_ptr work_serializer; + /** a list of grpc_fd that this event driver is currently using. */ + fd_node* fds; + /** is this event driver being shut down */ + bool shutting_down; + /** request object that's using this ev driver */ + grpc_ares_request* request; + /** Owned by the ev_driver. Creates new GrpcPolledFd's */ + std::unique_ptr polled_fd_factory; + /** query timeout in milliseconds */ + int query_timeout_ms; + /** alarm to cancel active queries */ + grpc_timer query_timeout; + /** cancels queries on a timeout */ + grpc_closure on_timeout_locked; + /** alarm to poll ares_process on in case fd events don't happen */ + grpc_timer ares_backup_poll_alarm; + /** polls ares_process on a periodic timer */ + grpc_closure on_ares_backup_poll_alarm_locked; +}; + +// TODO(apolcyn): make grpc_ares_hostbyname_request a sub-class +// of GrpcAresQuery. 
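// The grpc_ares_ev_driver above (via its gpr_refcount) and grpc_ares_request
// below (via pending_queries) follow the same pending-work pattern: each
// outstanding query, timer, or closure holds a reference, and completion work
// runs only when the count drops to zero. A small stand-alone sketch of that
// pattern, illustrative only and with invented names:

#include <atomic>

struct PendingWorkTracker {
  std::atomic<int> pending{0};

  // Called when a new query/timer/closure starts.
  void Ref() { pending.fetch_add(1, std::memory_order_relaxed); }

  // Called when that work finishes; returns true for the last piece of work,
  // i.e. the point at which "all work complete" handling may run.
  bool Unref() {
    return pending.fetch_sub(1, std::memory_order_acq_rel) == 1;
  }
};

// In the code below, grpc_ares_request_unref_locked() plays the Unref() role:
// when pending_queries reaches zero it triggers
// grpc_ares_ev_driver_on_queries_complete_locked() for the whole request.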
+typedef struct grpc_ares_hostbyname_request { + /** following members are set in create_hostbyname_request_locked + */ + /** the top-level request instance */ + grpc_ares_request* parent_request; + /** host to resolve, parsed from the name to resolve */ + char* host; + /** port to fill in sockaddr_in, parsed from the name to resolve */ + uint16_t port; + /** is it a grpclb address */ + bool is_balancer; + /** for logging and errors: the query type ("A" or "AAAA") */ + const char* qtype; +} grpc_ares_hostbyname_request; + +static void grpc_ares_request_ref_locked(grpc_ares_request* r); +static void grpc_ares_request_unref_locked(grpc_ares_request* r); + +// TODO(apolcyn): as a part of C++-ification, find a way to +// organize per-query and per-resolution information in such a way +// that doesn't involve allocating a number of different data +// structures. +class GrpcAresQuery { + public: + explicit GrpcAresQuery(grpc_ares_request* r, const std::string& name) + : r_(r), name_(name) { + grpc_ares_request_ref_locked(r_); + } + + ~GrpcAresQuery() { grpc_ares_request_unref_locked(r_); } + + grpc_ares_request* parent_request() { return r_; } + + const std::string& name() { return name_; } + + private: + /* the top level request instance */ + grpc_ares_request* r_; + /** for logging and errors */ + const std::string name_; +}; + +static grpc_ares_ev_driver* grpc_ares_ev_driver_ref( + grpc_ares_ev_driver* ev_driver) { + GRPC_CARES_TRACE_LOG("request:%p Ref ev_driver %p", ev_driver->request, + ev_driver); + gpr_ref(&ev_driver->refs); + return ev_driver; +} + +static void grpc_ares_ev_driver_unref(grpc_ares_ev_driver* ev_driver) { + GRPC_CARES_TRACE_LOG("request:%p Unref ev_driver %p", ev_driver->request, + ev_driver); + if (gpr_unref(&ev_driver->refs)) { + GRPC_CARES_TRACE_LOG("request:%p destroy ev_driver %p", ev_driver->request, + ev_driver); + GPR_ASSERT(ev_driver->fds == nullptr); + ares_destroy(ev_driver->channel); + grpc_ares_complete_request_locked(ev_driver->request); + delete ev_driver; + } +} + +static void fd_node_destroy_locked(fd_node* fdn) { + GRPC_CARES_TRACE_LOG("request:%p delete fd: %s", fdn->ev_driver->request, + fdn->grpc_polled_fd->GetName()); + GPR_ASSERT(!fdn->readable_registered); + GPR_ASSERT(!fdn->writable_registered); + GPR_ASSERT(fdn->already_shutdown); + delete fdn->grpc_polled_fd; + gpr_free(fdn); +} + +static void fd_node_shutdown_locked(fd_node* fdn, const char* reason) { + if (!fdn->already_shutdown) { + fdn->already_shutdown = true; + fdn->grpc_polled_fd->ShutdownLocked( + GRPC_ERROR_CREATE_FROM_STATIC_STRING(reason)); + } +} + +void grpc_ares_ev_driver_on_queries_complete_locked( + grpc_ares_ev_driver* ev_driver) { + // We mark the event driver as being shut down. + // grpc_ares_notify_on_event_locked will shut down any remaining + // fds. + ev_driver->shutting_down = true; + grpc_timer_cancel(&ev_driver->query_timeout); + grpc_timer_cancel(&ev_driver->ares_backup_poll_alarm); + grpc_ares_ev_driver_unref(ev_driver); +} + +void grpc_ares_ev_driver_shutdown_locked(grpc_ares_ev_driver* ev_driver) { + ev_driver->shutting_down = true; + fd_node* fn = ev_driver->fds; + while (fn != nullptr) { + fd_node_shutdown_locked(fn, "grpc_ares_ev_driver_shutdown"); + fn = fn->next; + } +} + +// Search fd in the fd_node list head. This is an O(n) search, the max possible +// value of n is ARES_GETSOCK_MAXNUM (16). n is typically 1 - 2 in our tests. 
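+// The stack-allocated phony head below lets the removal code treat "matched
+// the first node" and "matched a later node" identically, with no special
+// case for updating *head.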
+static fd_node* pop_fd_node_locked(fd_node** head, ares_socket_t as) { + fd_node phony_head; + phony_head.next = *head; + fd_node* node = &phony_head; + while (node->next != nullptr) { + if (node->next->grpc_polled_fd->GetWrappedAresSocketLocked() == as) { + fd_node* ret = node->next; + node->next = node->next->next; + *head = phony_head.next; + return ret; + } + node = node->next; + } + return nullptr; +} + +static grpc_millis calculate_next_ares_backup_poll_alarm_ms( + grpc_ares_ev_driver* driver) { + // An alternative here could be to use ares_timeout to try to be more + // accurate, but that would require using "struct timeval"'s, which just makes + // things a bit more complicated. So just poll every second, as suggested + // by the c-ares code comments. + grpc_millis ms_until_next_ares_backup_poll_alarm = 1000; + GRPC_CARES_TRACE_LOG( + "request:%p ev_driver=%p. next ares process poll time in " + "%" PRId64 " ms", + driver->request, driver, ms_until_next_ares_backup_poll_alarm); + return ms_until_next_ares_backup_poll_alarm + + grpc_core::ExecCtx::Get()->Now(); +} + +static void on_timeout_locked(grpc_ares_ev_driver* driver, + grpc_error_handle error) { + GRPC_CARES_TRACE_LOG( + "request:%p ev_driver=%p on_timeout_locked. driver->shutting_down=%d. " + "err=%s", + driver->request, driver, driver->shutting_down, + grpc_error_std_string(error).c_str()); + if (!driver->shutting_down && error == GRPC_ERROR_NONE) { + grpc_ares_ev_driver_shutdown_locked(driver); + } + grpc_ares_ev_driver_unref(driver); + GRPC_ERROR_UNREF(error); +} + +static void on_timeout(void* arg, grpc_error_handle error) { + grpc_ares_ev_driver* driver = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + driver->work_serializer->Run( + [driver, error]() { on_timeout_locked(driver, error); }, DEBUG_LOCATION); +} + +static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver); + +static void on_ares_backup_poll_alarm_locked(grpc_ares_ev_driver* driver, + grpc_error_handle error); + +static void on_ares_backup_poll_alarm(void* arg, grpc_error_handle error) { + grpc_ares_ev_driver* driver = static_cast(arg); + (void)GRPC_ERROR_REF(error); + driver->work_serializer->Run( + [driver, error]() { on_ares_backup_poll_alarm_locked(driver, error); }, + DEBUG_LOCATION); +} + +/* In case of non-responsive DNS servers, dropped packets, etc., c-ares has + * intelligent timeout and retry logic, which we can take advantage of by + * polling ares_process_fd on time intervals. Overall, the c-ares library is + * meant to be called into and given a chance to proceed name resolution: + * a) when fd events happen + * b) when some time has passed without fd events having happened + * For the latter, we use this backup poller. Also see + * https://github.com/grpc/grpc/pull/17688 description for more details. */ +static void on_ares_backup_poll_alarm_locked(grpc_ares_ev_driver* driver, + grpc_error_handle error) { + GRPC_CARES_TRACE_LOG( + "request:%p ev_driver=%p on_ares_backup_poll_alarm_locked. " + "driver->shutting_down=%d. " + "err=%s", + driver->request, driver, driver->shutting_down, + grpc_error_std_string(error).c_str()); + if (!driver->shutting_down && error == GRPC_ERROR_NONE) { + fd_node* fdn = driver->fds; + while (fdn != nullptr) { + if (!fdn->already_shutdown) { + GRPC_CARES_TRACE_LOG( + "request:%p ev_driver=%p on_ares_backup_poll_alarm_locked; " + "ares_process_fd. 
fd=%s", + driver->request, driver, fdn->grpc_polled_fd->GetName()); + ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked(); + ares_process_fd(driver->channel, as, as); + } + fdn = fdn->next; + } + if (!driver->shutting_down) { + // InvalidateNow to avoid getting stuck re-initializing this timer + // in a loop while draining the currently-held WorkSerializer. + // Also see https://github.com/grpc/grpc/issues/26079. + grpc_core::ExecCtx::Get()->InvalidateNow(); + grpc_millis next_ares_backup_poll_alarm = + calculate_next_ares_backup_poll_alarm_ms(driver); + grpc_ares_ev_driver_ref(driver); + GRPC_CLOSURE_INIT(&driver->on_ares_backup_poll_alarm_locked, + on_ares_backup_poll_alarm, driver, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&driver->ares_backup_poll_alarm, + next_ares_backup_poll_alarm, + &driver->on_ares_backup_poll_alarm_locked); + } + grpc_ares_notify_on_event_locked(driver); + } + grpc_ares_ev_driver_unref(driver); + GRPC_ERROR_UNREF(error); +} + +static void on_readable_locked(fd_node* fdn, grpc_error_handle error) { + GPR_ASSERT(fdn->readable_registered); + grpc_ares_ev_driver* ev_driver = fdn->ev_driver; + const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked(); + fdn->readable_registered = false; + GRPC_CARES_TRACE_LOG("request:%p readable on %s", fdn->ev_driver->request, + fdn->grpc_polled_fd->GetName()); + if (error == GRPC_ERROR_NONE) { + do { + ares_process_fd(ev_driver->channel, as, ARES_SOCKET_BAD); + } while (fdn->grpc_polled_fd->IsFdStillReadableLocked()); + } else { + // If error is not GRPC_ERROR_NONE, it means the fd has been shutdown or + // timed out. The pending lookups made on this ev_driver will be cancelled + // by the following ares_cancel() and the on_done callbacks will be invoked + // with a status of ARES_ECANCELLED. The remaining file descriptors in this + // ev_driver will be cleaned up in the follwing + // grpc_ares_notify_on_event_locked(). + ares_cancel(ev_driver->channel); + } + grpc_ares_notify_on_event_locked(ev_driver); + grpc_ares_ev_driver_unref(ev_driver); + GRPC_ERROR_UNREF(error); +} + +static void on_readable(void* arg, grpc_error_handle error) { + fd_node* fdn = static_cast(arg); + (void)GRPC_ERROR_REF(error); /* ref owned by lambda */ + fdn->ev_driver->work_serializer->Run( + [fdn, error]() { on_readable_locked(fdn, error); }, DEBUG_LOCATION); +} + +static void on_writable_locked(fd_node* fdn, grpc_error_handle error) { + GPR_ASSERT(fdn->writable_registered); + grpc_ares_ev_driver* ev_driver = fdn->ev_driver; + const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked(); + fdn->writable_registered = false; + GRPC_CARES_TRACE_LOG("request:%p writable on %s", ev_driver->request, + fdn->grpc_polled_fd->GetName()); + if (error == GRPC_ERROR_NONE) { + ares_process_fd(ev_driver->channel, ARES_SOCKET_BAD, as); + } else { + // If error is not GRPC_ERROR_NONE, it means the fd has been shutdown or + // timed out. The pending lookups made on this ev_driver will be cancelled + // by the following ares_cancel() and the on_done callbacks will be invoked + // with a status of ARES_ECANCELLED. The remaining file descriptors in this + // ev_driver will be cleaned up in the follwing + // grpc_ares_notify_on_event_locked(). 
+    ares_cancel(ev_driver->channel);
+  }
+  grpc_ares_notify_on_event_locked(ev_driver);
+  grpc_ares_ev_driver_unref(ev_driver);
+  GRPC_ERROR_UNREF(error);
+}
+
+static void on_writable(void* arg, grpc_error_handle error) {
+  fd_node* fdn = static_cast<fd_node*>(arg);
+  (void)GRPC_ERROR_REF(error); /* ref owned by lambda */
+  fdn->ev_driver->work_serializer->Run(
+      [fdn, error]() { on_writable_locked(fdn, error); }, DEBUG_LOCATION);
+}
+
+// Get the file descriptors used by the ev_driver's ares channel, register
+// driver_closure with these file descriptors.
+static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) {
+  fd_node* new_list = nullptr;
+  if (!ev_driver->shutting_down) {
+    ares_socket_t socks[ARES_GETSOCK_MAXNUM];
+    int socks_bitmask =
+        ares_getsock(ev_driver->channel, socks, ARES_GETSOCK_MAXNUM);
+    for (size_t i = 0; i < ARES_GETSOCK_MAXNUM; i++) {
+      if (ARES_GETSOCK_READABLE(socks_bitmask, i) ||
+          ARES_GETSOCK_WRITABLE(socks_bitmask, i)) {
+        fd_node* fdn = pop_fd_node_locked(&ev_driver->fds, socks[i]);
+        // Create a new fd_node if sock[i] is not in the fd_node list.
+        if (fdn == nullptr) {
+          fdn = static_cast<fd_node*>(gpr_malloc(sizeof(fd_node)));
+          fdn->grpc_polled_fd =
+              ev_driver->polled_fd_factory->NewGrpcPolledFdLocked(
+                  socks[i], ev_driver->pollset_set, ev_driver->work_serializer);
+          GRPC_CARES_TRACE_LOG("request:%p new fd: %s", ev_driver->request,
+                               fdn->grpc_polled_fd->GetName());
+          fdn->ev_driver = ev_driver;
+          fdn->readable_registered = false;
+          fdn->writable_registered = false;
+          fdn->already_shutdown = false;
+        }
+        fdn->next = new_list;
+        new_list = fdn;
+        // Register read_closure if the socket is readable and read_closure has
+        // not been registered with this socket.
+        if (ARES_GETSOCK_READABLE(socks_bitmask, i) &&
+            !fdn->readable_registered) {
+          grpc_ares_ev_driver_ref(ev_driver);
+          GRPC_CARES_TRACE_LOG("request:%p notify read on: %s",
+                               ev_driver->request,
+                               fdn->grpc_polled_fd->GetName());
+          GRPC_CLOSURE_INIT(&fdn->read_closure, on_readable, fdn,
+                            grpc_schedule_on_exec_ctx);
+          fdn->grpc_polled_fd->RegisterForOnReadableLocked(&fdn->read_closure);
+          fdn->readable_registered = true;
+        }
+        // Register write_closure if the socket is writable and write_closure
+        // has not been registered with this socket.
+        if (ARES_GETSOCK_WRITABLE(socks_bitmask, i) &&
+            !fdn->writable_registered) {
+          GRPC_CARES_TRACE_LOG("request:%p notify write on: %s",
+                               ev_driver->request,
+                               fdn->grpc_polled_fd->GetName());
+          grpc_ares_ev_driver_ref(ev_driver);
+          GRPC_CLOSURE_INIT(&fdn->write_closure, on_writable, fdn,
+                            grpc_schedule_on_exec_ctx);
+          fdn->grpc_polled_fd->RegisterForOnWriteableLocked(
+              &fdn->write_closure);
+          fdn->writable_registered = true;
+        }
+      }
+    }
+  }
+  // Any remaining fds in ev_driver->fds were not returned by ares_getsock() and
+  // are therefore no longer in use, so they can be shut down and removed from
+  // the list.
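+  // Nodes that still have a read or write closure registered cannot be freed
+  // yet: they are shut down here and kept on the list until those closures
+  // run (with an error) and drop their registration.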
+ while (ev_driver->fds != nullptr) { + fd_node* cur = ev_driver->fds; + ev_driver->fds = ev_driver->fds->next; + fd_node_shutdown_locked(cur, "c-ares fd shutdown"); + if (!cur->readable_registered && !cur->writable_registered) { + fd_node_destroy_locked(cur); + } else { + cur->next = new_list; + new_list = cur; + } + } + ev_driver->fds = new_list; +} + +void grpc_ares_ev_driver_start_locked(grpc_ares_ev_driver* ev_driver) { + grpc_ares_notify_on_event_locked(ev_driver); + // Initialize overall DNS resolution timeout alarm + grpc_millis timeout = + ev_driver->query_timeout_ms == 0 + ? GRPC_MILLIS_INF_FUTURE + : ev_driver->query_timeout_ms + grpc_core::ExecCtx::Get()->Now(); + GRPC_CARES_TRACE_LOG( + "request:%p ev_driver=%p grpc_ares_ev_driver_start_locked. timeout in " + "%" PRId64 " ms", + ev_driver->request, ev_driver, timeout); + grpc_ares_ev_driver_ref(ev_driver); + GRPC_CLOSURE_INIT(&ev_driver->on_timeout_locked, on_timeout, ev_driver, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&ev_driver->query_timeout, timeout, + &ev_driver->on_timeout_locked); + // Initialize the backup poll alarm + grpc_millis next_ares_backup_poll_alarm = + calculate_next_ares_backup_poll_alarm_ms(ev_driver); + grpc_ares_ev_driver_ref(ev_driver); + GRPC_CLOSURE_INIT(&ev_driver->on_ares_backup_poll_alarm_locked, + on_ares_backup_poll_alarm, ev_driver, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&ev_driver->ares_backup_poll_alarm, + next_ares_backup_poll_alarm, + &ev_driver->on_ares_backup_poll_alarm_locked); +} + +static void noop_inject_channel_config(ares_channel /*channel*/) {} + +void (*grpc_ares_test_only_inject_config)(ares_channel channel) = + noop_inject_channel_config; + +grpc_error_handle grpc_ares_ev_driver_create_locked( + grpc_ares_ev_driver** ev_driver, grpc_pollset_set* pollset_set, + int query_timeout_ms, + std::shared_ptr work_serializer, + grpc_ares_request* request) { + *ev_driver = new grpc_ares_ev_driver(); + ares_options opts; + memset(&opts, 0, sizeof(opts)); + opts.flags |= ARES_FLAG_STAYOPEN; + int status = ares_init_options(&(*ev_driver)->channel, &opts, ARES_OPT_FLAGS); + grpc_ares_test_only_inject_config((*ev_driver)->channel); + GRPC_CARES_TRACE_LOG("request:%p grpc_ares_ev_driver_create_locked", request); + if (status != ARES_SUCCESS) { + grpc_error_handle err = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Failed to init ares channel. 
C-ares error: ", ares_strerror(status))); + gpr_free(*ev_driver); + return err; + } + (*ev_driver)->work_serializer = std::move(work_serializer); + gpr_ref_init(&(*ev_driver)->refs, 1); + (*ev_driver)->pollset_set = pollset_set; + (*ev_driver)->fds = nullptr; + (*ev_driver)->shutting_down = false; + (*ev_driver)->request = request; + (*ev_driver)->polled_fd_factory = + grpc_core::NewGrpcPolledFdFactory((*ev_driver)->work_serializer); + (*ev_driver) + ->polled_fd_factory->ConfigureAresChannelLocked((*ev_driver)->channel); + (*ev_driver)->query_timeout_ms = query_timeout_ms; + return GRPC_ERROR_NONE; +} + +static void log_address_sorting_list(const grpc_ares_request* r, + const ServerAddressList& addresses, + const char* input_output_str) { + for (size_t i = 0; i < addresses.size(); i++) { + std::string addr_str = + grpc_sockaddr_to_string(&addresses[i].address(), true); + gpr_log(GPR_INFO, + "(c-ares resolver) request:%p c-ares address sorting: %s[%" PRIuPTR + "]=%s", + r, input_output_str, i, addr_str.c_str()); + } +} + +void grpc_cares_wrapper_address_sorting_sort(const grpc_ares_request* r, + ServerAddressList* addresses) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_cares_address_sorting)) { + log_address_sorting_list(r, *addresses, "input"); + } + address_sorting_sortable* sortables = static_cast( + gpr_zalloc(sizeof(address_sorting_sortable) * addresses->size())); + for (size_t i = 0; i < addresses->size(); ++i) { + sortables[i].user_data = &(*addresses)[i]; + memcpy(&sortables[i].dest_addr.addr, &(*addresses)[i].address().addr, + (*addresses)[i].address().len); + sortables[i].dest_addr.len = (*addresses)[i].address().len; + } + address_sorting_rfc_6724_sort(sortables, addresses->size()); + ServerAddressList sorted; + sorted.reserve(addresses->size()); + for (size_t i = 0; i < addresses->size(); ++i) { + sorted.emplace_back(*static_cast(sortables[i].user_data)); + } + gpr_free(sortables); + *addresses = std::move(sorted); + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_cares_address_sorting)) { + log_address_sorting_list(r, *addresses, "output"); + } +} + +static void grpc_ares_request_ref_locked(grpc_ares_request* r) { + r->pending_queries++; +} + +static void grpc_ares_request_unref_locked(grpc_ares_request* r) { + r->pending_queries--; + if (r->pending_queries == 0u) { + grpc_ares_ev_driver_on_queries_complete_locked(r->ev_driver); + } +} + +void grpc_ares_complete_request_locked(grpc_ares_request* r) { + /* Invoke on_done callback and destroy the + request */ + r->ev_driver = nullptr; + ServerAddressList* addresses = r->addresses_out->get(); + if (addresses != nullptr) { + grpc_cares_wrapper_address_sorting_sort(r, addresses); + GRPC_ERROR_UNREF(r->error); + r->error = GRPC_ERROR_NONE; + // TODO(apolcyn): allow c-ares to return a service config + // with no addresses along side it + } + if (r->balancer_addresses_out != nullptr) { + ServerAddressList* balancer_addresses = r->balancer_addresses_out->get(); + if (balancer_addresses != nullptr) { + grpc_cares_wrapper_address_sorting_sort(r, balancer_addresses); + } + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_done, r->error); +} + +/* Note that the returned object takes a reference to qtype, so + * qtype must outlive it. 
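+ * (In practice, the callers below pass string literals such as "A" and
+ * "AAAA", which trivially satisfy this.)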
*/ +static grpc_ares_hostbyname_request* create_hostbyname_request_locked( + grpc_ares_request* parent_request, const char* host, uint16_t port, + bool is_balancer, const char* qtype) { + GRPC_CARES_TRACE_LOG( + "request:%p create_hostbyname_request_locked host:%s port:%d " + "is_balancer:%d qtype:%s", + parent_request, host, port, is_balancer, qtype); + grpc_ares_hostbyname_request* hr = new grpc_ares_hostbyname_request(); + hr->parent_request = parent_request; + hr->host = gpr_strdup(host); + hr->port = port; + hr->is_balancer = is_balancer; + hr->qtype = qtype; + grpc_ares_request_ref_locked(parent_request); + return hr; +} + +static void destroy_hostbyname_request_locked( + grpc_ares_hostbyname_request* hr) { + grpc_ares_request_unref_locked(hr->parent_request); + gpr_free(hr->host); + delete hr; +} + +static void on_hostbyname_done_locked(void* arg, int status, int /*timeouts*/, + struct hostent* hostent) { + grpc_ares_hostbyname_request* hr = + static_cast(arg); + grpc_ares_request* r = hr->parent_request; + if (status == ARES_SUCCESS) { + GRPC_CARES_TRACE_LOG( + "request:%p on_hostbyname_done_locked qtype=%s host=%s ARES_SUCCESS", r, + hr->qtype, hr->host); + std::unique_ptr* address_list_ptr = + hr->is_balancer ? r->balancer_addresses_out : r->addresses_out; + if (*address_list_ptr == nullptr) { + *address_list_ptr = absl::make_unique(); + } + ServerAddressList& addresses = **address_list_ptr; + for (size_t i = 0; hostent->h_addr_list[i] != nullptr; ++i) { + absl::InlinedVector args_to_add; + if (hr->is_balancer) { + args_to_add.emplace_back(grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), hr->host)); + } + grpc_channel_args* args = grpc_channel_args_copy_and_add( + nullptr, args_to_add.data(), args_to_add.size()); + switch (hostent->h_addrtype) { + case AF_INET6: { + size_t addr_len = sizeof(struct sockaddr_in6); + struct sockaddr_in6 addr; + memset(&addr, 0, addr_len); + memcpy(&addr.sin6_addr, hostent->h_addr_list[i], + sizeof(struct in6_addr)); + addr.sin6_family = static_cast(hostent->h_addrtype); + addr.sin6_port = hr->port; + addresses.emplace_back(&addr, addr_len, args); + char output[INET6_ADDRSTRLEN]; + ares_inet_ntop(AF_INET6, &addr.sin6_addr, output, INET6_ADDRSTRLEN); + GRPC_CARES_TRACE_LOG( + "request:%p c-ares resolver gets a AF_INET6 result: \n" + " addr: %s\n port: %d\n sin6_scope_id: %d\n", + r, output, ntohs(hr->port), addr.sin6_scope_id); + break; + } + case AF_INET: { + size_t addr_len = sizeof(struct sockaddr_in); + struct sockaddr_in addr; + memset(&addr, 0, addr_len); + memcpy(&addr.sin_addr, hostent->h_addr_list[i], + sizeof(struct in_addr)); + addr.sin_family = static_cast(hostent->h_addrtype); + addr.sin_port = hr->port; + addresses.emplace_back(&addr, addr_len, args); + char output[INET_ADDRSTRLEN]; + ares_inet_ntop(AF_INET, &addr.sin_addr, output, INET_ADDRSTRLEN); + GRPC_CARES_TRACE_LOG( + "request:%p c-ares resolver gets a AF_INET result: \n" + " addr: %s\n port: %d\n", + r, output, ntohs(hr->port)); + break; + } + } + } + } else { + std::string error_msg = absl::StrFormat( + "C-ares status is not ARES_SUCCESS qtype=%s name=%s is_balancer=%d: %s", + hr->qtype, hr->host, hr->is_balancer, ares_strerror(status)); + GRPC_CARES_TRACE_LOG("request:%p on_hostbyname_done_locked: %s", r, + error_msg.c_str()); + grpc_error_handle error = + GRPC_ERROR_CREATE_FROM_CPP_STRING(std::move(error_msg)); + r->error = grpc_error_add_child(error, r->error); + } + destroy_hostbyname_request_locked(hr); +} + +static void 
on_srv_query_done_locked(void* arg, int status, int /*timeouts*/, + unsigned char* abuf, int alen) { + GrpcAresQuery* q = static_cast(arg); + grpc_ares_request* r = q->parent_request(); + if (status == ARES_SUCCESS) { + GRPC_CARES_TRACE_LOG( + "request:%p on_srv_query_done_locked name=%s ARES_SUCCESS", r, + q->name().c_str()); + struct ares_srv_reply* reply; + const int parse_status = ares_parse_srv_reply(abuf, alen, &reply); + GRPC_CARES_TRACE_LOG("request:%p ares_parse_srv_reply: %d", r, + parse_status); + if (parse_status == ARES_SUCCESS) { + for (struct ares_srv_reply* srv_it = reply; srv_it != nullptr; + srv_it = srv_it->next) { + if (grpc_ares_query_ipv6()) { + grpc_ares_hostbyname_request* hr = create_hostbyname_request_locked( + r, srv_it->host, htons(srv_it->port), true /* is_balancer */, + "AAAA"); + ares_gethostbyname(r->ev_driver->channel, hr->host, AF_INET6, + on_hostbyname_done_locked, hr); + } + grpc_ares_hostbyname_request* hr = create_hostbyname_request_locked( + r, srv_it->host, htons(srv_it->port), true /* is_balancer */, "A"); + ares_gethostbyname(r->ev_driver->channel, hr->host, AF_INET, + on_hostbyname_done_locked, hr); + grpc_ares_notify_on_event_locked(r->ev_driver); + } + } + if (reply != nullptr) { + ares_free_data(reply); + } + } else { + std::string error_msg = absl::StrFormat( + "C-ares status is not ARES_SUCCESS qtype=SRV name=%s: %s", q->name(), + ares_strerror(status)); + GRPC_CARES_TRACE_LOG("request:%p on_srv_query_done_locked: %s", r, + error_msg.c_str()); + grpc_error_handle error = + GRPC_ERROR_CREATE_FROM_CPP_STRING(std::move(error_msg)); + r->error = grpc_error_add_child(error, r->error); + } + delete q; +} + +static const char g_service_config_attribute_prefix[] = "grpc_config="; + +static void on_txt_done_locked(void* arg, int status, int /*timeouts*/, + unsigned char* buf, int len) { + GrpcAresQuery* q = static_cast(arg); + std::unique_ptr query_deleter(q); + grpc_ares_request* r = q->parent_request(); + const size_t prefix_len = sizeof(g_service_config_attribute_prefix) - 1; + struct ares_txt_ext* result = nullptr; + struct ares_txt_ext* reply = nullptr; + grpc_error_handle error = GRPC_ERROR_NONE; + if (status != ARES_SUCCESS) goto fail; + GRPC_CARES_TRACE_LOG("request:%p on_txt_done_locked name=%s ARES_SUCCESS", r, + q->name().c_str()); + status = ares_parse_txt_reply_ext(buf, len, &reply); + if (status != ARES_SUCCESS) goto fail; + // Find service config in TXT record. + for (result = reply; result != nullptr; result = result->next) { + if (result->record_start && + memcmp(result->txt, g_service_config_attribute_prefix, prefix_len) == + 0) { + break; + } + } + // Found a service config record. + if (result != nullptr) { + size_t service_config_len = result->length - prefix_len; + *r->service_config_json_out = + static_cast(gpr_malloc(service_config_len + 1)); + memcpy(*r->service_config_json_out, result->txt + prefix_len, + service_config_len); + for (result = result->next; result != nullptr && !result->record_start; + result = result->next) { + *r->service_config_json_out = static_cast( + gpr_realloc(*r->service_config_json_out, + service_config_len + result->length + 1)); + memcpy(*r->service_config_json_out + service_config_len, result->txt, + result->length); + service_config_len += result->length; + } + (*r->service_config_json_out)[service_config_len] = '\0'; + GRPC_CARES_TRACE_LOG("request:%p found service config: %s", r, + *r->service_config_json_out); + } + // Clean up. 
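+  // (Each TXT character-string is limited to 255 bytes, so a long service
+  // config may be split across several strings within the record; the loop
+  // above reassembles the continuation chunks.)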
+ ares_free_data(reply); + return; +fail: + std::string error_msg = + absl::StrFormat("C-ares status is not ARES_SUCCESS qtype=TXT name=%s: %s", + q->name(), ares_strerror(status)); + GRPC_CARES_TRACE_LOG("request:%p on_txt_done_locked %s", r, + error_msg.c_str()); + error = GRPC_ERROR_CREATE_FROM_CPP_STRING(std::move(error_msg)); + r->error = grpc_error_add_child(error, r->error); +} + +void grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( + grpc_ares_request* r, const char* dns_server, const char* name, + const char* default_port, grpc_pollset_set* interested_parties, + int query_timeout_ms, + std::shared_ptr work_serializer) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_ares_hostbyname_request* hr = nullptr; + /* parse name, splitting it into host and port parts */ + std::string host; + std::string port; + grpc_core::SplitHostPort(name, &host, &port); + if (host.empty()) { + error = grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("unparseable host:port"), + GRPC_ERROR_STR_TARGET_ADDRESS, name); + goto error_cleanup; + } else if (port.empty()) { + if (default_port == nullptr) { + error = grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("no port in name"), + GRPC_ERROR_STR_TARGET_ADDRESS, name); + goto error_cleanup; + } + port = default_port; + } + error = grpc_ares_ev_driver_create_locked(&r->ev_driver, interested_parties, + query_timeout_ms, + std::move(work_serializer), r); + if (error != GRPC_ERROR_NONE) goto error_cleanup; + // If dns_server is specified, use it. + if (dns_server != nullptr && dns_server[0] != '\0') { + GRPC_CARES_TRACE_LOG("request:%p Using DNS server %s", r, dns_server); + grpc_resolved_address addr; + if (grpc_parse_ipv4_hostport(dns_server, &addr, false /* log_errors */)) { + r->dns_server_addr.family = AF_INET; + struct sockaddr_in* in = reinterpret_cast(addr.addr); + memcpy(&r->dns_server_addr.addr.addr4, &in->sin_addr, + sizeof(struct in_addr)); + r->dns_server_addr.tcp_port = grpc_sockaddr_get_port(&addr); + r->dns_server_addr.udp_port = grpc_sockaddr_get_port(&addr); + } else if (grpc_parse_ipv6_hostport(dns_server, &addr, + false /* log_errors */)) { + r->dns_server_addr.family = AF_INET6; + struct sockaddr_in6* in6 = + reinterpret_cast(addr.addr); + memcpy(&r->dns_server_addr.addr.addr6, &in6->sin6_addr, + sizeof(struct in6_addr)); + r->dns_server_addr.tcp_port = grpc_sockaddr_get_port(&addr); + r->dns_server_addr.udp_port = grpc_sockaddr_get_port(&addr); + } else { + error = grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("cannot parse authority"), + GRPC_ERROR_STR_TARGET_ADDRESS, name); + goto error_cleanup; + } + int status = + ares_set_servers_ports(r->ev_driver->channel, &r->dns_server_addr); + if (status != ARES_SUCCESS) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "C-ares status is not ARES_SUCCESS: ", ares_strerror(status))); + goto error_cleanup; + } + } + r->pending_queries = 1; + if (grpc_ares_query_ipv6()) { + hr = create_hostbyname_request_locked(r, host.c_str(), + grpc_strhtons(port.c_str()), + /*is_balancer=*/false, "AAAA"); + ares_gethostbyname(r->ev_driver->channel, hr->host, AF_INET6, + on_hostbyname_done_locked, hr); + } + hr = create_hostbyname_request_locked(r, host.c_str(), + grpc_strhtons(port.c_str()), + /*is_balancer=*/false, "A"); + ares_gethostbyname(r->ev_driver->channel, hr->host, AF_INET, + on_hostbyname_done_locked, hr); + if (r->balancer_addresses_out != nullptr) { + /* Query the SRV record */ + std::string service_name = 
absl::StrCat("_grpclb._tcp.", host); + GrpcAresQuery* srv_query = new GrpcAresQuery(r, service_name); + ares_query(r->ev_driver->channel, service_name.c_str(), ns_c_in, ns_t_srv, + on_srv_query_done_locked, srv_query); + } + if (r->service_config_json_out != nullptr) { + std::string config_name = absl::StrCat("_grpc_config.", host); + GrpcAresQuery* txt_query = new GrpcAresQuery(r, config_name); + ares_search(r->ev_driver->channel, config_name.c_str(), ns_c_in, ns_t_txt, + on_txt_done_locked, txt_query); + } + grpc_ares_ev_driver_start_locked(r->ev_driver); + grpc_ares_request_unref_locked(r); + return; + +error_cleanup: + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_done, error); +} + +static bool inner_resolve_as_ip_literal_locked( + const char* name, const char* default_port, + std::unique_ptr* addrs, std::string* host, + std::string* port, std::string* hostport) { + if (!grpc_core::SplitHostPort(name, host, port)) { + gpr_log(GPR_ERROR, + "Failed to parse %s to host:port while attempting to resolve as ip " + "literal.", + name); + return false; + } + if (port->empty()) { + if (default_port == nullptr) { + gpr_log(GPR_ERROR, + "No port or default port for %s while attempting to resolve as " + "ip literal.", + name); + return false; + } + *port = default_port; + } + grpc_resolved_address addr; + *hostport = grpc_core::JoinHostPort(*host, atoi(port->c_str())); + if (grpc_parse_ipv4_hostport(hostport->c_str(), &addr, + false /* log errors */) || + grpc_parse_ipv6_hostport(hostport->c_str(), &addr, + false /* log errors */)) { + GPR_ASSERT(*addrs == nullptr); + *addrs = absl::make_unique(); + (*addrs)->emplace_back(addr.addr, addr.len, nullptr /* args */); + return true; + } + return false; +} + +static bool resolve_as_ip_literal_locked( + const char* name, const char* default_port, + std::unique_ptr* addrs) { + std::string host; + std::string port; + std::string hostport; + bool out = inner_resolve_as_ip_literal_locked(name, default_port, addrs, + &host, &port, &hostport); + return out; +} + +static bool target_matches_localhost_inner(const char* name, std::string* host, + std::string* port) { + if (!grpc_core::SplitHostPort(name, host, port)) { + gpr_log(GPR_ERROR, "Unable to split host and port for name: %s", name); + return false; + } + return gpr_stricmp(host->c_str(), "localhost") == 0; +} + +static bool target_matches_localhost(const char* name) { + std::string host; + std::string port; + return target_matches_localhost_inner(name, &host, &port); +} + +#ifdef GRPC_ARES_RESOLVE_LOCALHOST_MANUALLY +static bool inner_maybe_resolve_localhost_manually_locked( + const grpc_ares_request* r, const char* name, const char* default_port, + std::unique_ptr* addrs, std::string* host, + std::string* port) { + grpc_core::SplitHostPort(name, host, port); + if (host->empty()) { + gpr_log(GPR_ERROR, + "Failed to parse %s into host:port during manual localhost " + "resolution check.", + name); + return false; + } + if (port->empty()) { + if (default_port == nullptr) { + gpr_log(GPR_ERROR, + "No port or default port for %s during manual localhost " + "resolution check.", + name); + return false; + } + *port = default_port; + } + if (gpr_stricmp(host->c_str(), "localhost") == 0) { + GPR_ASSERT(*addrs == nullptr); + *addrs = absl::make_unique(); + uint16_t numeric_port = grpc_strhtons(port->c_str()); + // Append the ipv6 loopback address. 
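+    // (::1 is built below by zeroing the in6_addr and setting its last byte
+    // to 1; 127.0.0.1 is built analogously further down.)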
+ struct sockaddr_in6 ipv6_loopback_addr; + memset(&ipv6_loopback_addr, 0, sizeof(ipv6_loopback_addr)); + ((char*)&ipv6_loopback_addr.sin6_addr)[15] = 1; + ipv6_loopback_addr.sin6_family = AF_INET6; + ipv6_loopback_addr.sin6_port = numeric_port; + (*addrs)->emplace_back(&ipv6_loopback_addr, sizeof(ipv6_loopback_addr), + nullptr /* args */); + // Append the ipv4 loopback address. + struct sockaddr_in ipv4_loopback_addr; + memset(&ipv4_loopback_addr, 0, sizeof(ipv4_loopback_addr)); + ((char*)&ipv4_loopback_addr.sin_addr)[0] = 0x7f; + ((char*)&ipv4_loopback_addr.sin_addr)[3] = 0x01; + ipv4_loopback_addr.sin_family = AF_INET; + ipv4_loopback_addr.sin_port = numeric_port; + (*addrs)->emplace_back(&ipv4_loopback_addr, sizeof(ipv4_loopback_addr), + nullptr /* args */); + // Let the address sorter figure out which one should be tried first. + grpc_cares_wrapper_address_sorting_sort(r, addrs->get()); + return true; + } + return false; +} + +static bool grpc_ares_maybe_resolve_localhost_manually_locked( + const grpc_ares_request* r, const char* name, const char* default_port, + std::unique_ptr* addrs) { + std::string host; + std::string port; + return inner_maybe_resolve_localhost_manually_locked(r, name, default_port, + addrs, &host, &port); +} +#else /* GRPC_ARES_RESOLVE_LOCALHOST_MANUALLY */ +static bool grpc_ares_maybe_resolve_localhost_manually_locked( + const grpc_ares_request* /*r*/, const char* /*name*/, + const char* /*default_port*/, + std::unique_ptr* /*addrs*/) { + return false; +} +#endif /* GRPC_ARES_RESOLVE_LOCALHOST_MANUALLY */ + +static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( + const char* dns_server, const char* name, const char* default_port, + grpc_pollset_set* interested_parties, grpc_closure* on_done, + std::unique_ptr* addrs, + std::unique_ptr* balancer_addrs, + char** service_config_json, int query_timeout_ms, + std::shared_ptr work_serializer) { + grpc_ares_request* r = new grpc_ares_request(); + r->ev_driver = nullptr; + r->on_done = on_done; + r->addresses_out = addrs; + r->balancer_addresses_out = balancer_addrs; + r->service_config_json_out = service_config_json; + GRPC_CARES_TRACE_LOG( + "request:%p c-ares grpc_dns_lookup_ares_locked_impl name=%s, " + "default_port=%s", + r, name, default_port); + // Early out if the target is an ipv4 or ipv6 literal. + if (resolve_as_ip_literal_locked(name, default_port, addrs)) { + grpc_ares_complete_request_locked(r); + return r; + } + // Early out if the target is localhost and we're on Windows. + if (grpc_ares_maybe_resolve_localhost_manually_locked(r, name, default_port, + addrs)) { + grpc_ares_complete_request_locked(r); + return r; + } + // Don't query for SRV and TXT records if the target is "localhost", so + // as to cut down on lookups over the network, especially in tests: + // https://github.com/grpc/proposal/pull/79 + if (target_matches_localhost(name)) { + r->balancer_addresses_out = nullptr; + r->service_config_json_out = nullptr; + } + // Look up name using c-ares lib. 
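+  // The A/AAAA (and, when requested, SRV and TXT) queries run asynchronously;
+  // r is returned to the caller immediately and completion is signaled
+  // through r->on_done.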
+ grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( + r, dns_server, name, default_port, interested_parties, query_timeout_ms, + std::move(work_serializer)); + return r; +} + +grpc_ares_request* (*grpc_dns_lookup_ares_locked)( + const char* dns_server, const char* name, const char* default_port, + grpc_pollset_set* interested_parties, grpc_closure* on_done, + std::unique_ptr* addrs, + std::unique_ptr* balancer_addrs, + char** service_config_json, int query_timeout_ms, + std::shared_ptr work_serializer) = + grpc_dns_lookup_ares_locked_impl; + +static void grpc_cancel_ares_request_locked_impl(grpc_ares_request* r) { + GPR_ASSERT(r != nullptr); + if (r->ev_driver != nullptr) { + grpc_ares_ev_driver_shutdown_locked(r->ev_driver); + } +} + +void (*grpc_cancel_ares_request_locked)(grpc_ares_request* r) = + grpc_cancel_ares_request_locked_impl; + +// ares_library_init and ares_library_cleanup are currently no-op except under +// Windows. Calling them may cause race conditions when other parts of the +// binary calls these functions concurrently. +#ifdef GPR_WINDOWS +grpc_error_handle grpc_ares_init(void) { + int status = ares_library_init(ARES_LIB_INIT_ALL); + if (status != ARES_SUCCESS) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("ares_library_init failed: ", ares_strerror(status))); + } + return GRPC_ERROR_NONE; +} + +void grpc_ares_cleanup(void) { ares_library_cleanup(); } +#else +grpc_error_handle grpc_ares_init(void) { return GRPC_ERROR_NONE; } +void grpc_ares_cleanup(void) {} +#endif // GPR_WINDOWS + +/* + * grpc_resolve_address_ares related structs and functions + */ + +typedef struct grpc_resolve_address_ares_request { + /* work_serializer that queries and related callbacks run under */ + std::shared_ptr work_serializer; + /** the pointer to receive the resolved addresses */ + grpc_resolved_addresses** addrs_out; + /** currently resolving addresses */ + std::unique_ptr addresses; + /** closure to call when the resolve_address_ares request completes */ + grpc_closure* on_resolve_address_done; + /** a closure wrapping on_resolve_address_done, which should be invoked when + the grpc_dns_lookup_ares_locked operation is done. 
*/ + grpc_closure on_dns_lookup_done_locked; + /* target name */ + const char* name; + /* default port to use if none is specified */ + const char* default_port; + /* pollset_set to be driven by */ + grpc_pollset_set* interested_parties; + /* underlying ares_request that the query is performed on */ + grpc_ares_request* ares_request = nullptr; +} grpc_resolve_address_ares_request; + +static void on_dns_lookup_done_locked(grpc_resolve_address_ares_request* r, + grpc_error_handle error) { + delete r->ares_request; + grpc_resolved_addresses** resolved_addresses = r->addrs_out; + if (r->addresses == nullptr || r->addresses->empty()) { + *resolved_addresses = nullptr; + } else { + *resolved_addresses = static_cast( + gpr_zalloc(sizeof(grpc_resolved_addresses))); + (*resolved_addresses)->naddrs = r->addresses->size(); + (*resolved_addresses)->addrs = + static_cast(gpr_zalloc( + sizeof(grpc_resolved_address) * (*resolved_addresses)->naddrs)); + for (size_t i = 0; i < (*resolved_addresses)->naddrs; ++i) { + memcpy(&(*resolved_addresses)->addrs[i], &(*r->addresses)[i].address(), + sizeof(grpc_resolved_address)); + } + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_resolve_address_done, error); + delete r; +} + +static void on_dns_lookup_done(void* arg, grpc_error_handle error) { + grpc_resolve_address_ares_request* r = + static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + r->work_serializer->Run([r, error]() { on_dns_lookup_done_locked(r, error); }, + DEBUG_LOCATION); +} + +static void grpc_resolve_address_invoke_dns_lookup_ares_locked(void* arg) { + grpc_resolve_address_ares_request* r = + static_cast(arg); + GRPC_CLOSURE_INIT(&r->on_dns_lookup_done_locked, on_dns_lookup_done, r, + grpc_schedule_on_exec_ctx); + r->ares_request = grpc_dns_lookup_ares_locked( + nullptr /* dns_server */, r->name, r->default_port, r->interested_parties, + &r->on_dns_lookup_done_locked, &r->addresses, + nullptr /* balancer_addresses */, nullptr /* service_config_json */, + GRPC_DNS_ARES_DEFAULT_QUERY_TIMEOUT_MS, r->work_serializer); +} + +static void grpc_resolve_address_ares_impl(const char* name, + const char* default_port, + grpc_pollset_set* interested_parties, + grpc_closure* on_done, + grpc_resolved_addresses** addrs) { + grpc_resolve_address_ares_request* r = + new grpc_resolve_address_ares_request(); + r->work_serializer = std::make_shared(); + r->addrs_out = addrs; + r->on_resolve_address_done = on_done; + r->name = name; + r->default_port = default_port; + r->interested_parties = interested_parties; + r->work_serializer->Run( + [r]() { grpc_resolve_address_invoke_dns_lookup_ares_locked(r); }, + DEBUG_LOCATION); +} + +void (*grpc_resolve_address_ares)( + const char* name, const char* default_port, + grpc_pollset_set* interested_parties, grpc_closure* on_done, + grpc_resolved_addresses** addrs) = grpc_resolve_address_ares_impl; + +#endif /* GRPC_ARES == 1 */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_event_engine.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_event_engine.cc new file mode 100644 index 00000000..fdd4e00c --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_event_engine.cc @@ -0,0 +1,28 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "src/core/lib/iomgr/port.h" +#if GRPC_ARES == 1 && defined(GRPC_USE_EVENT_ENGINE) + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" + +bool grpc_ares_query_ipv6() { + /* The libuv grpc code currently does not have the code to probe for this, + * so we assume for now that IPv6 is always available in contexts where this + * code will be used. */ + return true; +} + +#endif /* GRPC_ARES == 1 && defined(GRPC_USE_EVENT_ENGINE) */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_posix.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_posix.cc new file mode 100644 index 00000000..23c0fec7 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_posix.cc @@ -0,0 +1,29 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" +#if GRPC_ARES == 1 && defined(GRPC_POSIX_SOCKET_ARES_EV_DRIVER) + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" + +bool grpc_ares_query_ipv6() { return grpc_ipv6_loopback_available(); } + +#endif /* GRPC_ARES == 1 && defined(GRPC_POSIX_SOCKET_ARES_EV_DRIVER) */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc new file mode 100644 index 00000000..f76c6a4a --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc @@ -0,0 +1,34 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" +#if GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER) + +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/socket_windows.h" + +bool grpc_ares_query_ipv6() { return grpc_ipv6_loopback_available(); } + +#endif /* GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER) */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.cc b/src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.cc new file mode 100644 index 00000000..07a617c1 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.cc @@ -0,0 +1,28 @@ +// +// Copyright 2019 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// This is similar to the sockaddr resolver, except that it supports a +// bunch of query args that are useful for dependency injection in tests. + +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" + +GPR_GLOBAL_CONFIG_DEFINE_STRING( + grpc_dns_resolver, "", + "Declares which DNS resolver to use. The default is ares if gRPC is built " + "with c-ares support. Otherwise, the value of this environment variable is " + "ignored.") diff --git a/src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc b/src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc new file mode 100644 index 00000000..3a8bc71c --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc @@ -0,0 +1,332 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/work_serializer.h" + +#define GRPC_DNS_INITIAL_CONNECT_BACKOFF_SECONDS 1 +#define GRPC_DNS_RECONNECT_BACKOFF_MULTIPLIER 1.6 +#define GRPC_DNS_RECONNECT_MAX_BACKOFF_SECONDS 120 +#define GRPC_DNS_RECONNECT_JITTER 0.2 + +namespace grpc_core { + +namespace { + +class NativeDnsResolver : public Resolver { + public: + explicit NativeDnsResolver(ResolverArgs args); + + void StartLocked() override; + + void RequestReresolutionLocked() override; + + void ResetBackoffLocked() override; + + void ShutdownLocked() override; + + private: + ~NativeDnsResolver() override; + + void MaybeStartResolvingLocked(); + void StartResolvingLocked(); + + static void OnNextResolution(void* arg, grpc_error_handle error); + void OnNextResolutionLocked(grpc_error_handle error); + static void OnResolved(void* arg, grpc_error_handle error); + void OnResolvedLocked(grpc_error_handle error); + + /// name to resolve + std::string name_to_resolve_; + /// channel args + grpc_channel_args* channel_args_ = nullptr; + std::shared_ptr work_serializer_; + std::unique_ptr result_handler_; + /// pollset_set to drive the name resolution process + grpc_pollset_set* interested_parties_ = nullptr; + /// are we shutting down? + bool shutdown_ = false; + /// are we currently resolving? 
+ bool resolving_ = false; + grpc_closure on_resolved_; + /// next resolution timer + bool have_next_resolution_timer_ = false; + grpc_timer next_resolution_timer_; + grpc_closure on_next_resolution_; + /// min time between DNS requests + grpc_millis min_time_between_resolutions_; + /// timestamp of last DNS request + grpc_millis last_resolution_timestamp_ = -1; + /// retry backoff state + BackOff backoff_; + /// currently resolving addresses + grpc_resolved_addresses* addresses_ = nullptr; +}; + +NativeDnsResolver::NativeDnsResolver(ResolverArgs args) + : name_to_resolve_(absl::StripPrefix(args.uri.path(), "/")), + channel_args_(grpc_channel_args_copy(args.args)), + work_serializer_(std::move(args.work_serializer)), + result_handler_(std::move(args.result_handler)), + interested_parties_(grpc_pollset_set_create()), + min_time_between_resolutions_(grpc_channel_args_find_integer( + channel_args_, GRPC_ARG_DNS_MIN_TIME_BETWEEN_RESOLUTIONS_MS, + {1000 * 30, 0, INT_MAX})), + backoff_( + BackOff::Options() + .set_initial_backoff(GRPC_DNS_INITIAL_CONNECT_BACKOFF_SECONDS * + 1000) + .set_multiplier(GRPC_DNS_RECONNECT_BACKOFF_MULTIPLIER) + .set_jitter(GRPC_DNS_RECONNECT_JITTER) + .set_max_backoff(GRPC_DNS_RECONNECT_MAX_BACKOFF_SECONDS * 1000)) { + if (args.pollset_set != nullptr) { + grpc_pollset_set_add_pollset_set(interested_parties_, args.pollset_set); + } +} + +NativeDnsResolver::~NativeDnsResolver() { + grpc_channel_args_destroy(channel_args_); + grpc_pollset_set_destroy(interested_parties_); +} + +void NativeDnsResolver::StartLocked() { MaybeStartResolvingLocked(); } + +void NativeDnsResolver::RequestReresolutionLocked() { + if (!resolving_) { + MaybeStartResolvingLocked(); + } +} + +void NativeDnsResolver::ResetBackoffLocked() { + if (have_next_resolution_timer_) { + grpc_timer_cancel(&next_resolution_timer_); + } + backoff_.Reset(); +} + +void NativeDnsResolver::ShutdownLocked() { + shutdown_ = true; + if (have_next_resolution_timer_) { + grpc_timer_cancel(&next_resolution_timer_); + } +} + +void NativeDnsResolver::OnNextResolution(void* arg, grpc_error_handle error) { + NativeDnsResolver* r = static_cast(arg); + (void)GRPC_ERROR_REF(error); // ref owned by lambda + r->work_serializer_->Run([r, error]() { r->OnNextResolutionLocked(error); }, + DEBUG_LOCATION); +} + +void NativeDnsResolver::OnNextResolutionLocked(grpc_error_handle error) { + have_next_resolution_timer_ = false; + if (error == GRPC_ERROR_NONE && !resolving_) { + StartResolvingLocked(); + } + Unref(DEBUG_LOCATION, "retry-timer"); + GRPC_ERROR_UNREF(error); +} + +void NativeDnsResolver::OnResolved(void* arg, grpc_error_handle error) { + NativeDnsResolver* r = static_cast(arg); + (void)GRPC_ERROR_REF(error); // owned by lambda + r->work_serializer_->Run([r, error]() { r->OnResolvedLocked(error); }, + DEBUG_LOCATION); +} + +void NativeDnsResolver::OnResolvedLocked(grpc_error_handle error) { + GPR_ASSERT(resolving_); + resolving_ = false; + if (shutdown_) { + Unref(DEBUG_LOCATION, "dns-resolving"); + GRPC_ERROR_UNREF(error); + return; + } + if (addresses_ != nullptr) { + Result result; + for (size_t i = 0; i < addresses_->naddrs; ++i) { + result.addresses.emplace_back(&addresses_->addrs[i].addr, + addresses_->addrs[i].len, + nullptr /* args */); + } + grpc_resolved_addresses_destroy(addresses_); + result.args = grpc_channel_args_copy(channel_args_); + result_handler_->ReturnResult(std::move(result)); + // Reset backoff state so that we start from the beginning when the + // next request gets triggered. 
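+    // (With the options set in the constructor, retry delays would otherwise
+    // grow from 1 second toward a 120 second cap, using a 1.6x multiplier and
+    // 20% jitter.)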
+ backoff_.Reset(); + } else { + gpr_log(GPR_INFO, "dns resolution failed (will retry): %s", + grpc_error_std_string(error).c_str()); + // Return transient error. + std::string error_message = + absl::StrCat("DNS resolution failed for service: ", name_to_resolve_); + result_handler_->ReturnError(grpc_error_set_int( + GRPC_ERROR_CREATE_REFERENCING_FROM_COPIED_STRING(error_message.c_str(), + &error, 1), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + // Set up for retry. + // InvalidateNow to avoid getting stuck re-initializing this timer + // in a loop while draining the currently-held WorkSerializer. + // Also see https://github.com/grpc/grpc/issues/26079. + ExecCtx::Get()->InvalidateNow(); + grpc_millis next_try = backoff_.NextAttemptTime(); + grpc_millis timeout = next_try - ExecCtx::Get()->Now(); + GPR_ASSERT(!have_next_resolution_timer_); + have_next_resolution_timer_ = true; + // TODO(roth): We currently deal with this ref manually. Once the + // new closure API is done, find a way to track this ref with the timer + // callback as part of the type system. + Ref(DEBUG_LOCATION, "next_resolution_timer").release(); + if (timeout > 0) { + gpr_log(GPR_DEBUG, "retrying in %" PRId64 " milliseconds", timeout); + } else { + gpr_log(GPR_DEBUG, "retrying immediately"); + } + GRPC_CLOSURE_INIT(&on_next_resolution_, NativeDnsResolver::OnNextResolution, + this, grpc_schedule_on_exec_ctx); + grpc_timer_init(&next_resolution_timer_, next_try, &on_next_resolution_); + } + Unref(DEBUG_LOCATION, "dns-resolving"); + GRPC_ERROR_UNREF(error); +} + +void NativeDnsResolver::MaybeStartResolvingLocked() { + // If there is an existing timer, the time it fires is the earliest time we + // can start the next resolution. + if (have_next_resolution_timer_) return; + if (last_resolution_timestamp_ >= 0) { + // InvalidateNow to avoid getting stuck re-initializing this timer + // in a loop while draining the currently-held WorkSerializer. + // Also see https://github.com/grpc/grpc/issues/26079. + ExecCtx::Get()->InvalidateNow(); + const grpc_millis earliest_next_resolution = + last_resolution_timestamp_ + min_time_between_resolutions_; + const grpc_millis ms_until_next_resolution = + earliest_next_resolution - grpc_core::ExecCtx::Get()->Now(); + if (ms_until_next_resolution > 0) { + const grpc_millis last_resolution_ago = + grpc_core::ExecCtx::Get()->Now() - last_resolution_timestamp_; + gpr_log(GPR_DEBUG, + "In cooldown from last resolution (from %" PRId64 + " ms ago). Will resolve again in %" PRId64 " ms", + last_resolution_ago, ms_until_next_resolution); + have_next_resolution_timer_ = true; + // TODO(roth): We currently deal with this ref manually. Once the + // new closure API is done, find a way to track this ref with the timer + // callback as part of the type system. + Ref(DEBUG_LOCATION, "next_resolution_timer_cooldown").release(); + GRPC_CLOSURE_INIT(&on_next_resolution_, + NativeDnsResolver::OnNextResolution, this, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&next_resolution_timer_, + ExecCtx::Get()->Now() + ms_until_next_resolution, + &on_next_resolution_); + return; + } + } + StartResolvingLocked(); +} + +void NativeDnsResolver::StartResolvingLocked() { + gpr_log(GPR_DEBUG, "Start resolving."); + // TODO(roth): We currently deal with this ref manually. Once the + // new closure API is done, find a way to track this ref with the timer + // callback as part of the type system. 
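+  // The "dns-resolving" ref taken here is released in OnResolvedLocked().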
+ Ref(DEBUG_LOCATION, "dns-resolving").release(); + GPR_ASSERT(!resolving_); + resolving_ = true; + addresses_ = nullptr; + GRPC_CLOSURE_INIT(&on_resolved_, NativeDnsResolver::OnResolved, this, + grpc_schedule_on_exec_ctx); + grpc_resolve_address(name_to_resolve_.c_str(), kDefaultSecurePort, + interested_parties_, &on_resolved_, &addresses_); + last_resolution_timestamp_ = grpc_core::ExecCtx::Get()->Now(); +} + +// +// Factory +// + +class NativeDnsResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + if (GPR_UNLIKELY(!uri.authority().empty())) { + gpr_log(GPR_ERROR, "authority based dns uri's not supported"); + return false; + } + if (absl::StripPrefix(uri.path(), "/").empty()) { + gpr_log(GPR_ERROR, "no server name supplied in dns URI"); + return false; + } + return true; + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + if (!IsValidUri(args.uri)) return nullptr; + return MakeOrphanable(std::move(args)); + } + + const char* scheme() const override { return "dns"; } +}; + +} // namespace + +} // namespace grpc_core + +void grpc_resolver_dns_native_init() { + grpc_core::UniquePtr resolver = + GPR_GLOBAL_CONFIG_GET(grpc_dns_resolver); + if (gpr_stricmp(resolver.get(), "native") == 0) { + gpr_log(GPR_DEBUG, "Using native dns resolver"); + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); + } else { + grpc_core::ResolverRegistry::Builder::InitRegistry(); + grpc_core::ResolverFactory* existing_factory = + grpc_core::ResolverRegistry::LookupResolverFactory("dns"); + if (existing_factory == nullptr) { + gpr_log(GPR_DEBUG, "Using native dns resolver"); + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); + } + } +} + +void grpc_resolver_dns_native_shutdown() {} diff --git a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc new file mode 100644 index 00000000..dceef6f0 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc @@ -0,0 +1,380 @@ +// +// Copyright 2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// This is similar to the sockaddr resolver, except that it supports a +// bunch of query args that are useful for dependency injection in tests. 
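+// Tests attach a FakeResolverResponseGenerator to a channel through the
+// GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR channel arg and then push results
+// into it with SetResponse() / SetFailure(); the generator forwards them to
+// the FakeResolver defined below.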
+ +#include + +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +namespace grpc_core { + +// This cannot be in an anonymous namespace, because it is a friend of +// FakeResolverResponseGenerator. +class FakeResolver : public Resolver { + public: + explicit FakeResolver(ResolverArgs args); + + void StartLocked() override; + + void RequestReresolutionLocked() override; + + private: + friend class FakeResolverResponseGenerator; + friend class FakeResolverResponseSetter; + + ~FakeResolver() override; + + void ShutdownLocked() override; + + void MaybeSendResultLocked(); + + void ReturnReresolutionResult(); + + // passed-in parameters + grpc_channel_args* channel_args_ = nullptr; + std::shared_ptr work_serializer_; + std::unique_ptr result_handler_; + RefCountedPtr response_generator_; + // If has_next_result_ is true, next_result_ is the next resolution result + // to be returned. + bool has_next_result_ = false; + Result next_result_; + // Result to use for the pretended re-resolution in + // RequestReresolutionLocked(). + bool has_reresolution_result_ = false; + Result reresolution_result_; + // True after the call to StartLocked(). + bool started_ = false; + // True after the call to ShutdownLocked(). + bool shutdown_ = false; + // if true, return failure + bool return_failure_ = false; + // pending re-resolution + bool reresolution_closure_pending_ = false; +}; + +FakeResolver::FakeResolver(ResolverArgs args) + : work_serializer_(std::move(args.work_serializer)), + result_handler_(std::move(args.result_handler)), + response_generator_( + FakeResolverResponseGenerator::GetFromArgs(args.args)) { + // Channels sharing the same subchannels may have different resolver response + // generators. If we don't remove this arg, subchannel pool will create new + // subchannels for the same address instead of reusing existing ones because + // of different values of this channel arg. + const char* args_to_remove[] = {GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR}; + channel_args_ = grpc_channel_args_copy_and_remove( + args.args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove)); + if (response_generator_ != nullptr) { + response_generator_->SetFakeResolver(Ref()); + } +} + +FakeResolver::~FakeResolver() { grpc_channel_args_destroy(channel_args_); } + +void FakeResolver::StartLocked() { + started_ = true; + MaybeSendResultLocked(); +} + +void FakeResolver::RequestReresolutionLocked() { + if (has_reresolution_result_ || return_failure_) { + next_result_ = reresolution_result_; + has_next_result_ = true; + // Return the result in a different closure, so that we don't call + // back into the LB policy while it's still processing the previous + // update. 
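+    // (ReturnReresolutionResult() runs MaybeSendResultLocked() and then drops
+    // the ref taken below.)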
+ if (!reresolution_closure_pending_) { + reresolution_closure_pending_ = true; + Ref().release(); // ref held by closure + work_serializer_->Run([this]() { ReturnReresolutionResult(); }, + DEBUG_LOCATION); + } + } +} + +void FakeResolver::ShutdownLocked() { + shutdown_ = true; + if (response_generator_ != nullptr) { + response_generator_->SetFakeResolver(nullptr); + response_generator_.reset(); + } +} + +void FakeResolver::MaybeSendResultLocked() { + if (!started_ || shutdown_) return; + if (return_failure_) { + // TODO(roth): Change resolver result generator to be able to inject + // the error to be returned. + result_handler_->ReturnError(grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resolver transient failure"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + return_failure_ = false; + } else if (has_next_result_) { + Result result; + result.addresses = std::move(next_result_.addresses); + result.service_config = std::move(next_result_.service_config); + // TODO(roth): Use std::move() once grpc_error is converted to C++. + result.service_config_error = next_result_.service_config_error; + next_result_.service_config_error = GRPC_ERROR_NONE; + // When both next_results_ and channel_args_ contain an arg with the same + // name, only the one in next_results_ will be kept since next_results_ is + // before channel_args_. + result.args = grpc_channel_args_union(next_result_.args, channel_args_); + result_handler_->ReturnResult(std::move(result)); + has_next_result_ = false; + } +} + +void FakeResolver::ReturnReresolutionResult() { + reresolution_closure_pending_ = false; + MaybeSendResultLocked(); + Unref(); +} + +class FakeResolverResponseSetter { + public: + explicit FakeResolverResponseSetter(RefCountedPtr resolver, + Resolver::Result result, + bool has_result = false, + bool immediate = true) + : resolver_(std::move(resolver)), + result_(std::move(result)), + has_result_(has_result), + immediate_(immediate) {} + void SetResponseLocked(); + void SetReresolutionResponseLocked(); + void SetFailureLocked(); + + private: + RefCountedPtr resolver_; + Resolver::Result result_; + bool has_result_; + bool immediate_; +}; + +// Deletes object when done +void FakeResolverResponseSetter::SetReresolutionResponseLocked() { + if (!resolver_->shutdown_) { + resolver_->reresolution_result_ = std::move(result_); + resolver_->has_reresolution_result_ = has_result_; + } + delete this; +} + +// Deletes object when done +void FakeResolverResponseSetter::SetResponseLocked() { + if (!resolver_->shutdown_) { + resolver_->next_result_ = std::move(result_); + resolver_->has_next_result_ = true; + resolver_->MaybeSendResultLocked(); + } + delete this; +} + +// Deletes object when done +void FakeResolverResponseSetter::SetFailureLocked() { + if (!resolver_->shutdown_) { + resolver_->return_failure_ = true; + if (immediate_) resolver_->MaybeSendResultLocked(); + } + delete this; +} + +// +// FakeResolverResponseGenerator +// + +FakeResolverResponseGenerator::FakeResolverResponseGenerator() {} + +FakeResolverResponseGenerator::~FakeResolverResponseGenerator() {} + +void FakeResolverResponseGenerator::SetResponse(Resolver::Result result) { + RefCountedPtr resolver; + { + MutexLock lock(&mu_); + if (resolver_ == nullptr) { + has_result_ = true; + result_ = std::move(result); + return; + } + resolver = resolver_->Ref(); + } + FakeResolverResponseSetter* arg = + new FakeResolverResponseSetter(resolver, std::move(result)); + resolver->work_serializer_->Run([arg]() { arg->SetResponseLocked(); }, + 
DEBUG_LOCATION); +} + +void FakeResolverResponseGenerator::SetReresolutionResponse( + Resolver::Result result) { + RefCountedPtr resolver; + { + MutexLock lock(&mu_); + GPR_ASSERT(resolver_ != nullptr); + resolver = resolver_->Ref(); + } + FakeResolverResponseSetter* arg = new FakeResolverResponseSetter( + resolver, std::move(result), true /* has_result */); + resolver->work_serializer_->Run( + [arg]() { arg->SetReresolutionResponseLocked(); }, DEBUG_LOCATION); +} + +void FakeResolverResponseGenerator::UnsetReresolutionResponse() { + RefCountedPtr resolver; + { + MutexLock lock(&mu_); + GPR_ASSERT(resolver_ != nullptr); + resolver = resolver_->Ref(); + } + FakeResolverResponseSetter* arg = + new FakeResolverResponseSetter(resolver, Resolver::Result()); + resolver->work_serializer_->Run( + [arg]() { arg->SetReresolutionResponseLocked(); }, DEBUG_LOCATION); +} + +void FakeResolverResponseGenerator::SetFailure() { + RefCountedPtr resolver; + { + MutexLock lock(&mu_); + GPR_ASSERT(resolver_ != nullptr); + resolver = resolver_->Ref(); + } + FakeResolverResponseSetter* arg = + new FakeResolverResponseSetter(resolver, Resolver::Result()); + resolver->work_serializer_->Run([arg]() { arg->SetFailureLocked(); }, + DEBUG_LOCATION); +} + +void FakeResolverResponseGenerator::SetFailureOnReresolution() { + RefCountedPtr resolver; + { + MutexLock lock(&mu_); + GPR_ASSERT(resolver_ != nullptr); + resolver = resolver_->Ref(); + } + FakeResolverResponseSetter* arg = new FakeResolverResponseSetter( + resolver, Resolver::Result(), false /* has_result */, + false /* immediate */); + resolver->work_serializer_->Run([arg]() { arg->SetFailureLocked(); }, + DEBUG_LOCATION); +} + +void FakeResolverResponseGenerator::SetFakeResolver( + RefCountedPtr resolver) { + MutexLock lock(&mu_); + resolver_ = std::move(resolver); + if (resolver_ == nullptr) return; + if (has_result_) { + FakeResolverResponseSetter* arg = + new FakeResolverResponseSetter(resolver_, std::move(result_)); + resolver_->work_serializer_->Run([arg]() { arg->SetResponseLocked(); }, + DEBUG_LOCATION); + has_result_ = false; + } +} + +namespace { + +void* ResponseGeneratorChannelArgCopy(void* p) { + auto* generator = static_cast(p); + generator->Ref().release(); + return p; +} + +void ResponseGeneratorChannelArgDestroy(void* p) { + auto* generator = static_cast(p); + generator->Unref(); +} + +int ResponseGeneratorChannelArgCmp(void* a, void* b) { + return QsortCompare(a, b); +} + +} // namespace + +const grpc_arg_pointer_vtable + FakeResolverResponseGenerator::kChannelArgPointerVtable = { + ResponseGeneratorChannelArgCopy, ResponseGeneratorChannelArgDestroy, + ResponseGeneratorChannelArgCmp}; + +grpc_arg FakeResolverResponseGenerator::MakeChannelArg( + FakeResolverResponseGenerator* generator) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR), generator, + &kChannelArgPointerVtable); +} + +RefCountedPtr +FakeResolverResponseGenerator::GetFromArgs(const grpc_channel_args* args) { + auto* response_generator = + grpc_channel_args_find_pointer( + args, GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR); + if (response_generator == nullptr) return nullptr; + return response_generator->Ref(); +} + +// +// Factory +// + +namespace { + +class FakeResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& /*uri*/) const override { return true; } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* scheme() const override 
{ return "fake"; } +}; + +} // namespace + +} // namespace grpc_core + +void grpc_resolver_fake_init() { + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); +} + +void grpc_resolver_fake_shutdown() {} diff --git a/src/core/ext/filters/client_channel/resolver/google_c2p/google_c2p_resolver.cc b/src/core/ext/filters/client_channel/resolver/google_c2p/google_c2p_resolver.cc new file mode 100644 index 00000000..a430e7a9 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/google_c2p/google_c2p_resolver.cc @@ -0,0 +1,384 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" + +namespace grpc_core { + +namespace { + +class GoogleCloud2ProdResolver : public Resolver { + public: + explicit GoogleCloud2ProdResolver(ResolverArgs args); + + void StartLocked() override; + void RequestReresolutionLocked() override; + void ResetBackoffLocked() override; + void ShutdownLocked() override; + + private: + // Represents an HTTP request to the metadata server. + class MetadataQuery : public InternallyRefCounted { + public: + MetadataQuery(RefCountedPtr resolver, + const char* path, grpc_polling_entity* pollent); + ~MetadataQuery() override; + + void Orphan() override; + + private: + static void OnHttpRequestDone(void* arg, grpc_error_handle error); + + // Calls OnDone() if not already called. Releases a ref. + void MaybeCallOnDone(grpc_error_handle error); + + // If error is not GRPC_ERROR_NONE, then it's not safe to look at response. + virtual void OnDone(GoogleCloud2ProdResolver* resolver, + const grpc_http_response* response, + grpc_error_handle error) = 0; + + RefCountedPtr resolver_; + grpc_httpcli_context context_; + grpc_httpcli_response response_; + grpc_closure on_done_; + std::atomic on_done_called_{false}; + }; + + // A metadata server query to get the zone. + class ZoneQuery : public MetadataQuery { + public: + ZoneQuery(RefCountedPtr resolver, + grpc_polling_entity* pollent); + + private: + void OnDone(GoogleCloud2ProdResolver* resolver, + const grpc_http_response* response, + grpc_error_handle error) override; + }; + + // A metadata server query to get the IPv6 address. 
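+  // Its result is recorded in supports_ipv6_ via IPv6QueryDone().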
+ class IPv6Query : public MetadataQuery { + public: + IPv6Query(RefCountedPtr resolver, + grpc_polling_entity* pollent); + + private: + void OnDone(GoogleCloud2ProdResolver* resolver, + const grpc_http_response* response, + grpc_error_handle error) override; + }; + + void ZoneQueryDone(std::string zone); + void IPv6QueryDone(bool ipv6_supported); + void StartXdsResolver(); + + std::shared_ptr work_serializer_; + grpc_polling_entity pollent_; + bool using_dns_ = false; + OrphanablePtr child_resolver_; + + OrphanablePtr zone_query_; + absl::optional zone_; + + OrphanablePtr ipv6_query_; + absl::optional supports_ipv6_; +}; + +// +// GoogleCloud2ProdResolver::MetadataQuery +// + +GoogleCloud2ProdResolver::MetadataQuery::MetadataQuery( + RefCountedPtr resolver, const char* path, + grpc_polling_entity* pollent) + : resolver_(std::move(resolver)) { + grpc_httpcli_context_init(&context_); + // Start HTTP request. + GRPC_CLOSURE_INIT(&on_done_, OnHttpRequestDone, this, nullptr); + Ref().release(); // Ref held by callback. + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + grpc_http_header header = {const_cast("Metadata-Flavor"), + const_cast("Google")}; + request.host = const_cast("metadata.google.internal"); + request.http.path = const_cast(path); + request.http.hdr_count = 1; + request.http.hdrs = &header; + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("c2p_resolver"); + grpc_httpcli_get(&context_, pollent, resource_quota, &request, + ExecCtx::Get()->Now() + 10000, // 10s timeout + &on_done_, &response_); +} + +GoogleCloud2ProdResolver::MetadataQuery::~MetadataQuery() { + grpc_httpcli_context_destroy(&context_); + grpc_http_response_destroy(&response_); +} + +void GoogleCloud2ProdResolver::MetadataQuery::Orphan() { + // TODO(roth): Once the HTTP client library supports cancellation, + // use that here. + MaybeCallOnDone(GRPC_ERROR_CANCELLED); +} + +void GoogleCloud2ProdResolver::MetadataQuery::OnHttpRequestDone( + void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + self->MaybeCallOnDone(GRPC_ERROR_REF(error)); +} + +void GoogleCloud2ProdResolver::MetadataQuery::MaybeCallOnDone( + grpc_error_handle error) { + bool expected = false; + if (!on_done_called_.compare_exchange_strong(expected, true, + std::memory_order_relaxed, + std::memory_order_relaxed)) { + // We've already called OnDone(), so just clean up. + GRPC_ERROR_UNREF(error); + Unref(); + return; + } + // Hop back into WorkSerializer to call OnDone(). + // Note: We implicitly pass our ref to the callback here. 
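// Aside (illustrative sketch, not part of this patch): the
// compare_exchange_strong() check in MaybeCallOnDone() above is a
// "run exactly once" guard, so that whichever of the HTTP completion callback
// and Orphan() arrives second only cleans up. A minimal standalone version of
// that guard, using nothing beyond the standard library:
#include <atomic>

class CallOnceGuard {
 public:
  // Returns true for exactly one caller; every later caller gets false.
  bool TryAcquire() {
    bool expected = false;
    return done_.compare_exchange_strong(expected, true,
                                         std::memory_order_relaxed,
                                         std::memory_order_relaxed);
  }

 private:
  std::atomic<bool> done_{false};
};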
+ resolver_->work_serializer_->Run( + [this, error]() { + OnDone(resolver_.get(), &response_, error); + Unref(); + }, + DEBUG_LOCATION); +} + +// +// GoogleCloud2ProdResolver::ZoneQuery +// + +GoogleCloud2ProdResolver::ZoneQuery::ZoneQuery( + RefCountedPtr resolver, + grpc_polling_entity* pollent) + : MetadataQuery(std::move(resolver), "/computeMetadata/v1/instance/zone", + pollent) {} + +void GoogleCloud2ProdResolver::ZoneQuery::OnDone( + GoogleCloud2ProdResolver* resolver, const grpc_http_response* response, + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "error fetching zone from metadata server: %s", + grpc_error_std_string(error).c_str()); + } + std::string zone; + if (error == GRPC_ERROR_NONE && response->status == 200) { + absl::string_view body(response->body, response->body_length); + size_t i = body.find_last_of('/'); + if (i == body.npos) { + gpr_log(GPR_ERROR, "could not parse zone from metadata server: %s", + std::string(body).c_str()); + } else { + zone = std::string(body.substr(i + 1)); + } + } + resolver->ZoneQueryDone(std::move(zone)); + GRPC_ERROR_UNREF(error); +} + +// +// GoogleCloud2ProdResolver::IPv6Query +// + +GoogleCloud2ProdResolver::IPv6Query::IPv6Query( + RefCountedPtr resolver, + grpc_polling_entity* pollent) + : MetadataQuery(std::move(resolver), + "/computeMetadata/v1/instance/network-interfaces/0/ipv6s", + pollent) {} + +void GoogleCloud2ProdResolver::IPv6Query::OnDone( + GoogleCloud2ProdResolver* resolver, const grpc_http_response* response, + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "error fetching IPv6 address from metadata server: %s", + grpc_error_std_string(error).c_str()); + } + resolver->IPv6QueryDone(error == GRPC_ERROR_NONE && response->status == 200); + GRPC_ERROR_UNREF(error); +} + +// +// GoogleCloud2ProdResolver +// + +GoogleCloud2ProdResolver::GoogleCloud2ProdResolver(ResolverArgs args) + : work_serializer_(std::move(args.work_serializer)), + pollent_(grpc_polling_entity_create_from_pollset_set(args.pollset_set)) { + absl::string_view name_to_resolve = absl::StripPrefix(args.uri.path(), "/"); + // If we're not running on GCP, we can't use DirectPath, so delegate + // to the DNS resolver. + if (!grpc_alts_is_running_on_gcp() || + // If the client is already using xDS, we can't use it here, because + // they may be talking to a completely different xDS server than we + // want to. + // TODO(roth): When we implement xDS federation, remove this constraint. + UniquePtr(gpr_getenv("GRPC_XDS_BOOTSTRAP")) != nullptr || + UniquePtr(gpr_getenv("GRPC_XDS_BOOTSTRAP_CONFIG")) != nullptr) { + using_dns_ = true; + child_resolver_ = ResolverRegistry::CreateResolver( + absl::StrCat("dns:", name_to_resolve).c_str(), args.args, + args.pollset_set, work_serializer_, std::move(args.result_handler)); + GPR_ASSERT(child_resolver_ != nullptr); + return; + } + // Create xds resolver. + child_resolver_ = ResolverRegistry::CreateResolver( + absl::StrCat("xds:", name_to_resolve).c_str(), args.args, + args.pollset_set, work_serializer_, std::move(args.result_handler)); + GPR_ASSERT(child_resolver_ != nullptr); +} + +void GoogleCloud2ProdResolver::StartLocked() { + if (using_dns_) { + child_resolver_->StartLocked(); + return; + } + // Using xDS. Start metadata server queries. 
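// Aside (illustrative, not part of this patch): the two queries started below
// are plain HTTP GETs against the GCE metadata server, of the form
//
//   GET http://metadata.google.internal/computeMetadata/v1/instance/zone
//   Metadata-Flavor: Google
//
// The zone endpoint returns a body such as
// "projects/123456789/zones/us-central1-a", and ZoneQuery::OnDone() above
// keeps only the text after the last '/'. A standalone sketch of that parsing
// step:
#include <string>
#include "absl/strings/string_view.h"

std::string ZoneFromMetadataBody(absl::string_view body) {
  size_t i = body.find_last_of('/');
  if (i == absl::string_view::npos) return "";  // unparseable: leave zone empty
  return std::string(body.substr(i + 1));       // e.g. "us-central1-a"
}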
+ zone_query_ = MakeOrphanable(Ref(), &pollent_); + ipv6_query_ = MakeOrphanable(Ref(), &pollent_); +} + +void GoogleCloud2ProdResolver::RequestReresolutionLocked() { + if (child_resolver_ != nullptr) { + child_resolver_->RequestReresolutionLocked(); + } +} + +void GoogleCloud2ProdResolver::ResetBackoffLocked() { + if (child_resolver_ != nullptr) { + child_resolver_->ResetBackoffLocked(); + } +} + +void GoogleCloud2ProdResolver::ShutdownLocked() { + zone_query_.reset(); + ipv6_query_.reset(); + child_resolver_.reset(); +} + +void GoogleCloud2ProdResolver::ZoneQueryDone(std::string zone) { + zone_query_.reset(); + zone_ = std::move(zone); + if (supports_ipv6_.has_value()) StartXdsResolver(); +} + +void GoogleCloud2ProdResolver::IPv6QueryDone(bool ipv6_supported) { + ipv6_query_.reset(); + supports_ipv6_ = ipv6_supported; + if (zone_.has_value()) StartXdsResolver(); +} + +void GoogleCloud2ProdResolver::StartXdsResolver() { + // Construct bootstrap JSON. + std::random_device rd; + std::mt19937 mt(rd()); + std::uniform_int_distribution dist(1, UINT64_MAX); + Json::Object node = { + {"id", absl::StrCat("C2P-", dist(mt))}, + }; + if (!zone_->empty()) { + node["locality"] = Json::Object{ + {"zone", *zone_}, + }; + }; + if (*supports_ipv6_) { + node["metadata"] = Json::Object{ + {"TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true}, + }; + } + // Allow the TD server uri to be overridden for testing purposes. + UniquePtr override_server( + gpr_getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI")); + const char* server_uri = + override_server != nullptr && strlen(override_server.get()) > 0 + ? override_server.get() + : "directpath-pa.googleapis.com"; + Json bootstrap = Json::Object{ + {"xds_servers", + Json::Array{ + Json::Object{ + {"server_uri", server_uri}, + {"channel_creds", + Json::Array{ + Json::Object{ + {"type", "google_default"}, + }, + }}, + {"server_features", Json::Array{"xds_v3"}}, + }, + }}, + {"node", std::move(node)}, + }; + // Inject bootstrap JSON as fallback config. + internal::SetXdsFallbackBootstrapConfig(bootstrap.Dump().c_str()); + // Now start xDS resolver. + child_resolver_->StartLocked(); +} + +// +// Factory +// + +class GoogleCloud2ProdResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + if (GPR_UNLIKELY(!uri.authority().empty())) { + gpr_log(GPR_ERROR, "google-c2p URI scheme does not support authorities"); + return false; + } + return true; + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + if (!IsValidUri(args.uri)) return nullptr; + return MakeOrphanable(std::move(args)); + } + + const char* scheme() const override { return "google-c2p"; } +}; + +} // namespace + +void GoogleCloud2ProdResolverInit() { + // TODO(roth): Remove env var protection once this code is proven stable. 
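// Aside (illustrative, not part of this patch): for a VM in zone
// "us-central1-a" with IPv6 support and no test override of the Traffic
// Director URI, the fallback bootstrap assembled by StartXdsResolver() above
// is shaped like the following (the node id suffix is a random number):
#include <string>

const std::string kExampleC2PBootstrap = R"({
  "xds_servers": [{
    "server_uri": "directpath-pa.googleapis.com",
    "channel_creds": [{"type": "google_default"}],
    "server_features": ["xds_v3"]
  }],
  "node": {
    "id": "C2P-1234567890",
    "locality": {"zone": "us-central1-a"},
    "metadata": {"TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE": true}
  }
})";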
+ UniquePtr value(gpr_getenv("GRPC_EXPERIMENTAL_GOOGLE_C2P_RESOLVER")); + bool parsed_value; + bool parse_succeeded = gpr_parse_bool_value(value.get(), &parsed_value); + if (parse_succeeded && parsed_value) { + ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); + } +} + +void GoogleCloud2ProdResolverShutdown() {} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc b/src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc new file mode 100644 index 00000000..46efbb06 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc @@ -0,0 +1,195 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include + +#include "absl/strings/str_split.h" + +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +namespace grpc_core { + +namespace { + +class SockaddrResolver : public Resolver { + public: + SockaddrResolver(ServerAddressList addresses, ResolverArgs args); + ~SockaddrResolver() override; + + void StartLocked() override; + + void ShutdownLocked() override {} + + private: + std::unique_ptr result_handler_; + ServerAddressList addresses_; + const grpc_channel_args* channel_args_ = nullptr; +}; + +SockaddrResolver::SockaddrResolver(ServerAddressList addresses, + ResolverArgs args) + : result_handler_(std::move(args.result_handler)), + addresses_(std::move(addresses)), + channel_args_(grpc_channel_args_copy(args.args)) {} + +SockaddrResolver::~SockaddrResolver() { + grpc_channel_args_destroy(channel_args_); +} + +void SockaddrResolver::StartLocked() { + Result result; + result.addresses = std::move(addresses_); + // TODO(roth): Use std::move() once channel args is converted to C++. + result.args = channel_args_; + channel_args_ = nullptr; + result_handler_->ReturnResult(std::move(result)); +} + +// +// Factory +// + +bool ParseUri(const URI& uri, + bool parse(const URI& uri, grpc_resolved_address* dst), + ServerAddressList* addresses) { + if (!uri.authority().empty()) { + gpr_log(GPR_ERROR, "authority-based URIs not supported by the %s scheme", + uri.scheme().c_str()); + return false; + } + // Construct addresses. 
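// Aside (illustrative, not part of this patch): the loop below splits the URI
// path on ',' so that a single target can carry several addresses. Targets
// handled by the factories in this file therefore look like
//   "ipv4:127.0.0.1:10000,127.0.0.1:10001"
//   "ipv6:[::1]:10000"
//   "unix:/tmp/greeter.sock"
// and can be passed straight to the public channel-creation API, e.g.:
#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>

std::shared_ptr<grpc::Channel> MakeLoopbackChannel() {
  return grpc::CreateChannel("ipv4:127.0.0.1:10000",
                             grpc::InsecureChannelCredentials());
}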
+ bool errors_found = false; + for (absl::string_view ith_path : absl::StrSplit(uri.path(), ',')) { + URI ith_uri(uri.scheme(), "", std::string(ith_path), {}, ""); + grpc_resolved_address addr; + if (!parse(ith_uri, &addr)) { + errors_found = true; + break; + } + if (addresses != nullptr) { + addresses->emplace_back(addr, nullptr /* args */); + } + } + return !errors_found; +} + +OrphanablePtr CreateSockaddrResolver( + ResolverArgs args, bool parse(const URI& uri, grpc_resolved_address* dst)) { + ServerAddressList addresses; + if (!ParseUri(args.uri, parse, &addresses)) return nullptr; + // Instantiate resolver. + return MakeOrphanable(std::move(addresses), + std::move(args)); +} + +class IPv4ResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + return ParseUri(uri, grpc_parse_ipv4, nullptr); + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + return CreateSockaddrResolver(std::move(args), grpc_parse_ipv4); + } + + const char* scheme() const override { return "ipv4"; } +}; + +class IPv6ResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + return ParseUri(uri, grpc_parse_ipv6, nullptr); + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + return CreateSockaddrResolver(std::move(args), grpc_parse_ipv6); + } + + const char* scheme() const override { return "ipv6"; } +}; + +#ifdef GRPC_HAVE_UNIX_SOCKET +class UnixResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + return ParseUri(uri, grpc_parse_unix, nullptr); + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + return CreateSockaddrResolver(std::move(args), grpc_parse_unix); + } + + std::string GetDefaultAuthority(const URI& /*uri*/) const override { + return "localhost"; + } + + const char* scheme() const override { return "unix"; } +}; + +class UnixAbstractResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + return ParseUri(uri, grpc_parse_unix_abstract, nullptr); + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + return CreateSockaddrResolver(std::move(args), grpc_parse_unix_abstract); + } + + std::string GetDefaultAuthority(const URI& /*uri*/) const override { + return "localhost"; + } + + const char* scheme() const override { return "unix-abstract"; } +}; +#endif // GRPC_HAVE_UNIX_SOCKET + +} // namespace + +} // namespace grpc_core + +void grpc_resolver_sockaddr_init() { + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); +#ifdef GRPC_HAVE_UNIX_SOCKET + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); +#endif +} + +void grpc_resolver_sockaddr_shutdown() {} diff --git a/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc new file mode 100644 index 00000000..0e00e6f2 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc @@ -0,0 +1,980 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "re2/re2.h" +#define XXH_INLINE_ALL +#include "xxhash.h" + +#include "src/core/ext/filters/client_channel/config_selector.h" +#include "src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/xds/xds_channel_args.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/ext/xds/xds_http_filters.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/timeout_encoding.h" + +namespace grpc_core { + +TraceFlag grpc_xds_resolver_trace(false, "xds_resolver"); + +const char* kXdsClusterAttribute = "xds_cluster_name"; + +namespace { + +// +// XdsResolver +// + +class XdsResolver : public Resolver { + public: + explicit XdsResolver(ResolverArgs args) + : work_serializer_(std::move(args.work_serializer)), + result_handler_(std::move(args.result_handler)), + server_name_(absl::StripPrefix(args.uri.path(), "/")), + args_(grpc_channel_args_copy(args.args)), + interested_parties_(args.pollset_set) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] created for server name %s", this, + server_name_.c_str()); + } + } + + ~XdsResolver() override { + grpc_channel_args_destroy(args_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] destroyed", this); + } + } + + void StartLocked() override; + + void ShutdownLocked() override; + + void ResetBackoffLocked() override { + if (xds_client_ != nullptr) xds_client_->ResetBackoff(); + } + + private: + class Notifier { + public: + Notifier(RefCountedPtr resolver, XdsApi::LdsUpdate update); + Notifier(RefCountedPtr resolver, XdsApi::RdsUpdate update); + Notifier(RefCountedPtr resolver, grpc_error_handle error); + explicit Notifier(RefCountedPtr resolver); + + private: + enum Type { kLdsUpdate, kRdsUpdate, kError, kDoesNotExist }; + + static void RunInExecCtx(void* arg, grpc_error_handle error); + void RunInWorkSerializer(grpc_error_handle error); + + RefCountedPtr resolver_; + grpc_closure closure_; + XdsApi::LdsUpdate update_; + Type type_; + }; + + class ListenerWatcher : public XdsClient::ListenerWatcherInterface { + public: + explicit ListenerWatcher(RefCountedPtr resolver) + : resolver_(std::move(resolver)) {} + void OnListenerChanged(XdsApi::LdsUpdate listener) override { + new Notifier(resolver_, std::move(listener)); + } + void OnError(grpc_error_handle error) override { + new Notifier(resolver_, error); + } + void OnResourceDoesNotExist() override { new Notifier(resolver_); } + + private: + RefCountedPtr resolver_; + }; + + class RouteConfigWatcher : public XdsClient::RouteConfigWatcherInterface { + public: + explicit RouteConfigWatcher(RefCountedPtr resolver) + : resolver_(std::move(resolver)) {} + void OnRouteConfigChanged(XdsApi::RdsUpdate route_config) override { + new Notifier(resolver_, 
std::move(route_config)); + } + void OnError(grpc_error_handle error) override { + new Notifier(resolver_, error); + } + void OnResourceDoesNotExist() override { new Notifier(resolver_); } + + private: + RefCountedPtr resolver_; + }; + + // An entry in the map of clusters that need to be present in the LB + // policy config. The map holds a weak ref. One strong ref is held by + // the ConfigSelector, and another is held by each call assigned to + // the cluster by the ConfigSelector. The ref for each call is held + // until the call is committed. When the strong refs go away, we hop + // back into the WorkSerializer to remove the entry from the map. + class ClusterState : public DualRefCounted { + public: + using ClusterStateMap = + std::map>; + + ClusterState(RefCountedPtr resolver, + const std::string& cluster_name) + : resolver_(std::move(resolver)), + it_(resolver_->cluster_state_map_.emplace(cluster_name, WeakRef()) + .first) {} + + void Orphan() override { + auto* resolver = resolver_.release(); + resolver->work_serializer_->Run( + [resolver]() { + resolver->MaybeRemoveUnusedClusters(); + resolver->Unref(); + }, + DEBUG_LOCATION); + } + + const std::string& cluster() const { return it_->first; } + + private: + RefCountedPtr resolver_; + ClusterStateMap::iterator it_; + }; + + // Call dispatch controller, created for each call handled by the + // ConfigSelector. Holds a ref to the ClusterState object until the + // call is committed. + class XdsCallDispatchController + : public ConfigSelector::CallDispatchController { + public: + explicit XdsCallDispatchController( + RefCountedPtr cluster_state) + : cluster_state_(std::move(cluster_state)) {} + + bool ShouldRetry() override { + // TODO(donnadionne): Implement the retry circuit breaker here. + return true; + } + + void Commit() override { + // TODO(donnadionne): If ShouldRetry() was called previously, + // decrement the retry circuit breaker counter. + cluster_state_.reset(); + } + + private: + // Note: The XdsCallDispatchController object is never actually destroyed, + // so do not add any data members that require destruction unless you have + // some other way to clean them up. + RefCountedPtr cluster_state_; + }; + + class XdsConfigSelector : public ConfigSelector { + public: + XdsConfigSelector(RefCountedPtr resolver, + grpc_error_handle* error); + ~XdsConfigSelector() override; + + const char* name() const override { return "XdsConfigSelector"; } + + bool Equals(const ConfigSelector* other) const override { + const auto* other_xds = static_cast(other); + // Don't need to compare resolver_, since that will always be the same. 
+ return route_table_ == other_xds->route_table_ && + clusters_ == other_xds->clusters_; + } + + CallConfig GetCallConfig(GetCallConfigArgs args) override; + + std::vector<const grpc_channel_filter*> GetFilters() override { + return filters_; + } + + grpc_channel_args* ModifyChannelArgs(grpc_channel_args* args) override; + + private: + struct Route { + struct ClusterWeightState { + uint32_t range_end; + absl::string_view cluster; + RefCountedPtr<ServiceConfig> method_config; + + bool operator==(const ClusterWeightState& other) const; + }; + + XdsApi::Route route; + RefCountedPtr<ServiceConfig> method_config; + absl::InlinedVector<ClusterWeightState, 2> weighted_cluster_state; + + bool operator==(const Route& other) const; + }; + using RouteTable = std::vector<Route>; + + void MaybeAddCluster(const std::string& name); + grpc_error_handle CreateMethodConfig( + const XdsApi::Route& route, + const XdsApi::Route::ClusterWeight* cluster_weight, + RefCountedPtr<ServiceConfig>* method_config); + + RefCountedPtr<XdsResolver> resolver_; + RouteTable route_table_; + std::map<absl::string_view, RefCountedPtr<ClusterState>> clusters_; + std::vector<const grpc_channel_filter*> filters_; + }; + + void OnListenerUpdate(XdsApi::LdsUpdate listener); + void OnRouteConfigUpdate(XdsApi::RdsUpdate rds_update); + void OnError(grpc_error_handle error); + void OnResourceDoesNotExist(); + + grpc_error_handle CreateServiceConfig( + RefCountedPtr<ServiceConfig>* service_config); + void GenerateResult(); + void MaybeRemoveUnusedClusters(); + + std::shared_ptr<WorkSerializer> work_serializer_; + std::unique_ptr<ResultHandler> result_handler_; + std::string server_name_; + const grpc_channel_args* args_; + grpc_pollset_set* interested_parties_; + + RefCountedPtr<XdsClient> xds_client_; + + XdsClient::ListenerWatcherInterface* listener_watcher_ = nullptr; + // This will not contain the RouteConfiguration, even if it comes with the + // LDS response; instead, the relevant VirtualHost from the + // RouteConfiguration will be saved in current_virtual_host_.
+ XdsApi::LdsUpdate current_listener_; + + std::string route_config_name_; + XdsClient::RouteConfigWatcherInterface* route_config_watcher_ = nullptr; + XdsApi::RdsUpdate::VirtualHost current_virtual_host_; + + ClusterState::ClusterStateMap cluster_state_map_; +}; + +// +// XdsResolver::Notifier +// + +XdsResolver::Notifier::Notifier(RefCountedPtr resolver, + XdsApi::LdsUpdate update) + : resolver_(std::move(resolver)), + update_(std::move(update)), + type_(kLdsUpdate) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); +} + +XdsResolver::Notifier::Notifier(RefCountedPtr resolver, + XdsApi::RdsUpdate update) + : resolver_(std::move(resolver)), type_(kRdsUpdate) { + update_.http_connection_manager.rds_update = std::move(update); + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); +} + +XdsResolver::Notifier::Notifier(RefCountedPtr resolver, + grpc_error_handle error) + : resolver_(std::move(resolver)), type_(kError) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, error); +} + +XdsResolver::Notifier::Notifier(RefCountedPtr resolver) + : resolver_(std::move(resolver)), type_(kDoesNotExist) { + GRPC_CLOSURE_INIT(&closure_, &RunInExecCtx, this, nullptr); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); +} + +void XdsResolver::Notifier::RunInExecCtx(void* arg, grpc_error_handle error) { + Notifier* self = static_cast(arg); + (void)GRPC_ERROR_REF(error); + self->resolver_->work_serializer_->Run( + [self, error]() { self->RunInWorkSerializer(error); }, DEBUG_LOCATION); +} + +void XdsResolver::Notifier::RunInWorkSerializer(grpc_error_handle error) { + if (resolver_->xds_client_ == nullptr) { + GRPC_ERROR_UNREF(error); + delete this; + return; + } + switch (type_) { + case kLdsUpdate: + resolver_->OnListenerUpdate(std::move(update_)); + break; + case kRdsUpdate: + resolver_->OnRouteConfigUpdate( + std::move(*update_.http_connection_manager.rds_update)); + break; + case kError: + resolver_->OnError(error); + break; + case kDoesNotExist: + resolver_->OnResourceDoesNotExist(); + break; + }; + delete this; +} + +// +// XdsResolver::XdsConfigSelector::Route +// + +bool MethodConfigsEqual(const ServiceConfig* sc1, const ServiceConfig* sc2) { + if (sc1 == nullptr) return sc2 == nullptr; + if (sc2 == nullptr) return false; + return sc1->json_string() == sc2->json_string(); +} + +bool XdsResolver::XdsConfigSelector::Route::ClusterWeightState::operator==( + const ClusterWeightState& other) const { + return range_end == other.range_end && cluster == other.cluster && + MethodConfigsEqual(method_config.get(), other.method_config.get()); +} + +bool XdsResolver::XdsConfigSelector::Route::operator==( + const Route& other) const { + return route == other.route && + weighted_cluster_state == other.weighted_cluster_state && + MethodConfigsEqual(method_config.get(), other.method_config.get()); +} + +// +// XdsResolver::XdsConfigSelector +// + +XdsResolver::XdsConfigSelector::XdsConfigSelector( + RefCountedPtr resolver, grpc_error_handle* error) + : resolver_(std::move(resolver)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] creating XdsConfigSelector %p", + resolver_.get(), this); + } + // 1. Construct the route table + // 2 Update resolver's cluster state map + // 3. Construct cluster list to hold on to entries in the cluster state + // map. 
+ // Reserve the necessary entries up-front to avoid reallocation as we add + // elements. This is necessary because the string_view in the entry's + // weighted_cluster_state field points to the memory in the route field, so + // moving the entry in a reallocation will cause the string_view to point to + // invalid data. + route_table_.reserve(resolver_->current_virtual_host_.routes.size()); + for (auto& route : resolver_->current_virtual_host_.routes) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] XdsConfigSelector %p: route: %s", + resolver_.get(), this, route.ToString().c_str()); + } + route_table_.emplace_back(); + auto& route_entry = route_table_.back(); + route_entry.route = route; + // If the route doesn't specify a timeout, set its timeout to the global + // one. + if (!route.max_stream_duration.has_value()) { + route_entry.route.max_stream_duration = + resolver_->current_listener_.http_connection_manager + .http_max_stream_duration; + } + if (route.weighted_clusters.empty()) { + *error = CreateMethodConfig(route_entry.route, nullptr, + &route_entry.method_config); + MaybeAddCluster(route.cluster_name); + } else { + uint32_t end = 0; + for (const auto& weighted_cluster : route_entry.route.weighted_clusters) { + Route::ClusterWeightState cluster_weight_state; + *error = CreateMethodConfig(route_entry.route, &weighted_cluster, + &cluster_weight_state.method_config); + if (*error != GRPC_ERROR_NONE) return; + end += weighted_cluster.weight; + cluster_weight_state.range_end = end; + cluster_weight_state.cluster = weighted_cluster.name; + route_entry.weighted_cluster_state.push_back( + std::move(cluster_weight_state)); + MaybeAddCluster(weighted_cluster.name); + } + } + } + // Populate filter list. + for (const auto& http_filter : + resolver_->current_listener_.http_connection_manager.http_filters) { + // Find filter. This is guaranteed to succeed, because it's checked + // at config validation time in the XdsApi code. + const XdsHttpFilterImpl* filter_impl = + XdsHttpFilterRegistry::GetFilterForType( + http_filter.config.config_proto_type_name); + GPR_ASSERT(filter_impl != nullptr); + // Add C-core filter to list. + if (filter_impl->channel_filter() != nullptr) { + filters_.push_back(filter_impl->channel_filter()); + } + } +} + +XdsResolver::XdsConfigSelector::~XdsConfigSelector() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] destroying XdsConfigSelector %p", + resolver_.get(), this); + } + clusters_.clear(); + resolver_->MaybeRemoveUnusedClusters(); +} + +const XdsHttpFilterImpl::FilterConfig* FindFilterConfigOverride( + const std::string& instance_name, + const XdsApi::RdsUpdate::VirtualHost& vhost, const XdsApi::Route& route, + const XdsApi::Route::ClusterWeight* cluster_weight) { + // Check ClusterWeight, if any. + if (cluster_weight != nullptr) { + auto it = cluster_weight->typed_per_filter_config.find(instance_name); + if (it != cluster_weight->typed_per_filter_config.end()) return &it->second; + } + // Check Route. + auto it = route.typed_per_filter_config.find(instance_name); + if (it != route.typed_per_filter_config.end()) return &it->second; + // Check VirtualHost. + it = vhost.typed_per_filter_config.find(instance_name); + if (it != vhost.typed_per_filter_config.end()) return &it->second; + // Not found. 
+ return nullptr; +} + +grpc_error_handle XdsResolver::XdsConfigSelector::CreateMethodConfig( + const XdsApi::Route& route, + const XdsApi::Route::ClusterWeight* cluster_weight, + RefCountedPtr* method_config) { + std::vector fields; + // Set retry policy if any. + if (route.retry_policy.has_value() && !route.retry_policy->retry_on.Empty()) { + std::vector retry_parts; + retry_parts.push_back(absl::StrFormat( + "\"retryPolicy\": {\n" + " \"maxAttempts\": %d,\n" + " \"initialBackoff\": \"%d.%09ds\",\n" + " \"maxBackoff\": \"%d.%09ds\",\n" + " \"backoffMultiplier\": 2,\n", + route.retry_policy->num_retries + 1, + route.retry_policy->retry_back_off.base_interval.seconds, + route.retry_policy->retry_back_off.base_interval.nanos, + route.retry_policy->retry_back_off.max_interval.seconds, + route.retry_policy->retry_back_off.max_interval.nanos)); + std::vector code_parts; + if (route.retry_policy->retry_on.Contains(GRPC_STATUS_CANCELLED)) { + code_parts.push_back(" \"CANCELLED\""); + } + if (route.retry_policy->retry_on.Contains(GRPC_STATUS_DEADLINE_EXCEEDED)) { + code_parts.push_back(" \"DEADLINE_EXCEEDED\""); + } + if (route.retry_policy->retry_on.Contains(GRPC_STATUS_INTERNAL)) { + code_parts.push_back(" \"INTERNAL\""); + } + if (route.retry_policy->retry_on.Contains(GRPC_STATUS_RESOURCE_EXHAUSTED)) { + code_parts.push_back(" \"RESOURCE_EXHAUSTED\""); + } + if (route.retry_policy->retry_on.Contains(GRPC_STATUS_UNAVAILABLE)) { + code_parts.push_back(" \"UNAVAILABLE\""); + } + retry_parts.push_back( + absl::StrFormat(" \"retryableStatusCodes\": [\n %s ]\n", + absl::StrJoin(code_parts, ",\n"))); + retry_parts.push_back(absl::StrFormat(" }")); + fields.emplace_back(absl::StrJoin(retry_parts, "")); + } + // Set timeout. + if (route.max_stream_duration.has_value() && + (route.max_stream_duration->seconds != 0 || + route.max_stream_duration->nanos != 0)) { + fields.emplace_back(absl::StrFormat(" \"timeout\": \"%d.%09ds\"", + route.max_stream_duration->seconds, + route.max_stream_duration->nanos)); + } + // Handle xDS HTTP filters. + std::map> per_filter_configs; + grpc_channel_args* args = grpc_channel_args_copy(resolver_->args_); + for (const auto& http_filter : + resolver_->current_listener_.http_connection_manager.http_filters) { + // Find filter. This is guaranteed to succeed, because it's checked + // at config validation time in the XdsApi code. + const XdsHttpFilterImpl* filter_impl = + XdsHttpFilterRegistry::GetFilterForType( + http_filter.config.config_proto_type_name); + GPR_ASSERT(filter_impl != nullptr); + // If there is not actually any C-core filter associated with this + // xDS filter, then it won't need any config, so skip it. + if (filter_impl->channel_filter() == nullptr) continue; + // Allow filter to add channel args that may affect service config + // parsing. + args = filter_impl->ModifyChannelArgs(args); + // Find config override, if any. + const XdsHttpFilterImpl::FilterConfig* config_override = + FindFilterConfigOverride(http_filter.name, + resolver_->current_virtual_host_, route, + cluster_weight); + // Generate service config for filter. 
+ auto method_config_field = + filter_impl->GenerateServiceConfig(http_filter.config, config_override); + if (!method_config_field.ok()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "failed to generate method config for HTTP filter ", http_filter.name, + ": ", method_config_field.status().ToString())); + } + per_filter_configs[method_config_field->service_config_field_name] + .push_back(method_config_field->element); + } + for (const auto& p : per_filter_configs) { + fields.emplace_back(absl::StrCat(" \"", p.first, "\": [\n", + absl::StrJoin(p.second, ",\n"), + "\n ]")); + } + // Construct service config. + grpc_error_handle error = GRPC_ERROR_NONE; + if (!fields.empty()) { + std::string json = absl::StrCat( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " {}\n" + " ],\n" + " ", + absl::StrJoin(fields, ",\n"), + "\n } ]\n" + "}"); + *method_config = ServiceConfig::Create(args, json.c_str(), &error); + } + grpc_channel_args_destroy(args); + return error; +} + +grpc_channel_args* XdsResolver::XdsConfigSelector::ModifyChannelArgs( + grpc_channel_args* args) { + return args; +} + +void XdsResolver::XdsConfigSelector::MaybeAddCluster(const std::string& name) { + if (clusters_.find(name) == clusters_.end()) { + auto it = resolver_->cluster_state_map_.find(name); + if (it == resolver_->cluster_state_map_.end()) { + auto new_cluster_state = MakeRefCounted(resolver_, name); + clusters_[new_cluster_state->cluster()] = std::move(new_cluster_state); + } else { + clusters_[it->second->cluster()] = it->second->Ref(); + } + } +} + +absl::optional GetHeaderValue( + grpc_metadata_batch* initial_metadata, absl::string_view header_name, + std::string* concatenated_value) { + // Note: If we ever allow binary headers here, we still need to + // special-case ignore "grpc-tags-bin" and "grpc-trace-bin", since + // they are not visible to the LB policy in grpc-go. + if (absl::EndsWith(header_name, "-bin")) { + return absl::nullopt; + } else if (header_name == "content-type") { + return "application/grpc"; + } + return initial_metadata->GetValue(header_name, concatenated_value); +} + +bool HeadersMatch(const std::vector& header_matchers, + grpc_metadata_batch* initial_metadata) { + for (const auto& header_matcher : header_matchers) { + std::string concatenated_value; + if (!header_matcher.Match(GetHeaderValue( + initial_metadata, header_matcher.name(), &concatenated_value))) { + return false; + } + } + return true; +} + +absl::optional HeaderHashHelper( + const XdsApi::Route::HashPolicy& policy, + grpc_metadata_batch* initial_metadata) { + GPR_ASSERT(policy.type == XdsApi::Route::HashPolicy::HEADER); + std::string value_buffer; + absl::optional header_value = + GetHeaderValue(initial_metadata, policy.header_name, &value_buffer); + if (!header_value.has_value()) { + return absl::nullopt; + } + if (policy.regex != nullptr) { + // If GetHeaderValue() did not already store the value in + // value_buffer, copy it there now, so we can modify it. + if (header_value->data() != value_buffer.data()) { + value_buffer = std::string(*header_value); + } + RE2::GlobalReplace(&value_buffer, *policy.regex, policy.regex_substitution); + header_value = value_buffer; + } + return XXH64(header_value->data(), header_value->size(), 0); +} + +bool UnderFraction(const uint32_t fraction_per_million) { + // Generate a random number in [0, 1000000). 
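// Aside (illustrative, not part of this patch): for a route carrying a 5s
// max_stream_duration and a retry policy of 3 retries on UNAVAILABLE with
// 100ms initial and 1s maximum backoff, the per-method config assembled by
// CreateMethodConfig() above comes out roughly as:
const char* kExampleRouteMethodConfig = R"({
  "methodConfig": [ {
      "name": [
        {}
      ],
      "retryPolicy": {
        "maxAttempts": 4,
        "initialBackoff": "0.100000000s",
        "maxBackoff": "1.000000000s",
        "backoffMultiplier": 2,
        "retryableStatusCodes": [
          "UNAVAILABLE" ]
      },
      "timeout": "5.000000000s"
  } ]
})";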
+ const uint32_t random_number = rand() % 1000000; + return random_number < fraction_per_million; +} + +ConfigSelector::CallConfig XdsResolver::XdsConfigSelector::GetCallConfig( + GetCallConfigArgs args) { + for (const auto& entry : route_table_) { + // Path matching. + if (!entry.route.matchers.path_matcher.Match( + StringViewFromSlice(*args.path))) { + continue; + } + // Header Matching. + if (!HeadersMatch(entry.route.matchers.header_matchers, + args.initial_metadata)) { + continue; + } + // Match fraction check + if (entry.route.matchers.fraction_per_million.has_value() && + !UnderFraction(entry.route.matchers.fraction_per_million.value())) { + continue; + } + // Found a route match + absl::string_view cluster_name; + RefCountedPtr method_config; + if (entry.route.weighted_clusters.empty()) { + cluster_name = entry.route.cluster_name; + method_config = entry.method_config; + } else { + const uint32_t key = + rand() % + entry.weighted_cluster_state[entry.weighted_cluster_state.size() - 1] + .range_end; + // Find the index in weighted clusters corresponding to key. + size_t mid = 0; + size_t start_index = 0; + size_t end_index = entry.weighted_cluster_state.size() - 1; + size_t index = 0; + while (end_index > start_index) { + mid = (start_index + end_index) / 2; + if (entry.weighted_cluster_state[mid].range_end > key) { + end_index = mid; + } else if (entry.weighted_cluster_state[mid].range_end < key) { + start_index = mid + 1; + } else { + index = mid + 1; + break; + } + } + if (index == 0) index = start_index; + GPR_ASSERT(entry.weighted_cluster_state[index].range_end > key); + cluster_name = entry.weighted_cluster_state[index].cluster; + method_config = entry.weighted_cluster_state[index].method_config; + } + auto it = clusters_.find(cluster_name); + GPR_ASSERT(it != clusters_.end()); + // Generate a hash. + absl::optional hash; + for (const auto& hash_policy : entry.route.hash_policies) { + absl::optional new_hash; + switch (hash_policy.type) { + case XdsApi::Route::HashPolicy::HEADER: + new_hash = HeaderHashHelper(hash_policy, args.initial_metadata); + break; + case XdsApi::Route::HashPolicy::CHANNEL_ID: + new_hash = static_cast( + reinterpret_cast(resolver_.get())); + break; + default: + GPR_ASSERT(0); + } + if (new_hash.has_value()) { + // Rotating the old value prevents duplicate hash rules from cancelling + // each other out and preserves all of the entropy + const uint64_t old_value = + hash.has_value() ? ((hash.value() << 1) | (hash.value() >> 63)) : 0; + hash = old_value ^ new_hash.value(); + } + // If the policy is a terminal policy and a hash has been generated, + // ignore the rest of the hash policies. + if (hash_policy.terminal && hash.has_value()) { + break; + } + } + if (!hash.has_value()) { + // If there is no hash, we just choose a random value as a default. + // We cannot directly use the result of rand() as the hash value, + // since it is a 32-bit number and not a 64-bit number and will + // therefore not be evenly distributed. 
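// Aside (illustrative sketch, not part of this patch): the hand-rolled binary
// search above selects the first weighted-cluster entry whose cumulative
// range_end exceeds the random key. The same selection expressed with
// std::upper_bound over the cumulative ends:
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <vector>

size_t PickWeightedIndex(const std::vector<uint32_t>& cumulative_range_ends) {
  const uint32_t key =
      static_cast<uint32_t>(rand()) % cumulative_range_ends.back();
  return static_cast<size_t>(
      std::upper_bound(cumulative_range_ends.begin(),
                       cumulative_range_ends.end(), key) -
      cumulative_range_ends.begin());
}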
+ uint32_t upper = rand(); + uint32_t lower = rand(); + hash = (static_cast(upper) << 32) | lower; + } + CallConfig call_config; + if (method_config != nullptr) { + call_config.method_configs = + method_config->GetMethodParsedConfigVector(grpc_empty_slice()); + call_config.service_config = std::move(method_config); + } + call_config.call_attributes[kXdsClusterAttribute] = it->first; + std::string hash_string = absl::StrCat(hash.value()); + char* hash_value = + static_cast(args.arena->Alloc(hash_string.size() + 1)); + memcpy(hash_value, hash_string.c_str(), hash_string.size()); + hash_value[hash_string.size()] = '\0'; + call_config.call_attributes[kRequestRingHashAttribute] = hash_value; + call_config.call_dispatch_controller = + args.arena->New(it->second->Ref()); + return call_config; + } + return CallConfig(); +} + +// +// XdsResolver +// + +void XdsResolver::StartLocked() { + grpc_error_handle error = GRPC_ERROR_NONE; + xds_client_ = XdsClient::GetOrCreate(args_, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "Failed to create xds client -- channel will remain in " + "TRANSIENT_FAILURE: %s", + grpc_error_std_string(error).c_str()); + result_handler_->ReturnError(error); + return; + } + grpc_pollset_set_add_pollset_set(xds_client_->interested_parties(), + interested_parties_); + auto watcher = absl::make_unique(Ref()); + listener_watcher_ = watcher.get(); + xds_client_->WatchListenerData(server_name_, std::move(watcher)); +} + +void XdsResolver::ShutdownLocked() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] shutting down", this); + } + if (xds_client_ != nullptr) { + if (listener_watcher_ != nullptr) { + xds_client_->CancelListenerDataWatch(server_name_, listener_watcher_, + /*delay_unsubscription=*/false); + } + if (route_config_watcher_ != nullptr) { + xds_client_->CancelRouteConfigDataWatch( + server_name_, route_config_watcher_, /*delay_unsubscription=*/false); + } + grpc_pollset_set_del_pollset_set(xds_client_->interested_parties(), + interested_parties_); + xds_client_.reset(); + } +} + +void XdsResolver::OnListenerUpdate(XdsApi::LdsUpdate listener) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] received updated listener data", this); + } + if (listener.http_connection_manager.route_config_name != + route_config_name_) { + if (route_config_watcher_ != nullptr) { + xds_client_->CancelRouteConfigDataWatch( + route_config_name_, route_config_watcher_, + /*delay_unsubscription=*/ + !listener.http_connection_manager.route_config_name.empty()); + route_config_watcher_ = nullptr; + } + route_config_name_ = + std::move(listener.http_connection_manager.route_config_name); + if (!route_config_name_.empty()) { + current_virtual_host_.routes.clear(); + auto watcher = absl::make_unique(Ref()); + route_config_watcher_ = watcher.get(); + xds_client_->WatchRouteConfigData(route_config_name_, std::move(watcher)); + } + } + current_listener_ = std::move(listener); + if (route_config_name_.empty()) { + GPR_ASSERT( + current_listener_.http_connection_manager.rds_update.has_value()); + OnRouteConfigUpdate( + std::move(*current_listener_.http_connection_manager.rds_update)); + } else { + // HCM may contain newer filter config. 
We need to propagate the update as + // config selector to the channel + GenerateResult(); + } +} + +void XdsResolver::OnRouteConfigUpdate(XdsApi::RdsUpdate rds_update) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] received updated route config", this); + } + // Find the relevant VirtualHost from the RouteConfiguration. + XdsApi::RdsUpdate::VirtualHost* vhost = + rds_update.FindVirtualHostForDomain(server_name_); + if (vhost == nullptr) { + OnError(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("could not find VirtualHost for ", server_name_, + " in RouteConfiguration"))); + return; + } + // Save the virtual host in the resolver. + current_virtual_host_ = std::move(*vhost); + // Send a new result to the channel. + GenerateResult(); +} + +void XdsResolver::OnError(grpc_error_handle error) { + gpr_log(GPR_ERROR, "[xds_resolver %p] received error from XdsClient: %s", + this, grpc_error_std_string(error).c_str()); + Result result; + grpc_arg new_arg = xds_client_->MakeChannelArg(); + result.args = grpc_channel_args_copy_and_add(args_, &new_arg, 1); + result.service_config_error = error; + result_handler_->ReturnResult(std::move(result)); +} + +void XdsResolver::OnResourceDoesNotExist() { + gpr_log(GPR_ERROR, + "[xds_resolver %p] LDS/RDS resource does not exist -- clearing " + "update and returning empty service config", + this); + current_virtual_host_.routes.clear(); + Result result; + result.service_config = + ServiceConfig::Create(args_, "{}", &result.service_config_error); + GPR_ASSERT(result.service_config != nullptr); + result.args = grpc_channel_args_copy(args_); + result_handler_->ReturnResult(std::move(result)); +} + +grpc_error_handle XdsResolver::CreateServiceConfig( + RefCountedPtr* service_config) { + std::vector clusters; + for (const auto& cluster : cluster_state_map_) { + clusters.push_back( + absl::StrFormat(" \"%s\":{\n" + " \"childPolicy\":[ {\n" + " \"cds_experimental\":{\n" + " \"cluster\": \"%s\"\n" + " }\n" + " } ]\n" + " }", + cluster.first, cluster.first)); + } + std::vector config_parts; + config_parts.push_back( + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"xds_cluster_manager_experimental\":{\n" + " \"children\":{\n"); + config_parts.push_back(absl::StrJoin(clusters, ",\n")); + config_parts.push_back( + " }\n" + " } }\n" + " ]\n" + "}"); + std::string json = absl::StrJoin(config_parts, ""); + grpc_error_handle error = GRPC_ERROR_NONE; + *service_config = ServiceConfig::Create(args_, json.c_str(), &error); + return error; +} + +void XdsResolver::GenerateResult() { + if (current_virtual_host_.routes.empty()) return; + // First create XdsConfigSelector, which may add new entries to the cluster + // state map, and then CreateServiceConfig for LB policies. 
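// Aside (illustrative, not part of this patch): with two clusters named
// "cluster_a" and "cluster_b" in the cluster state map, the service config
// produced by CreateServiceConfig() above looks like:
const char* kExampleXdsServiceConfig = R"({
  "loadBalancingConfig":[
    { "xds_cluster_manager_experimental":{
        "children":{
          "cluster_a":{
            "childPolicy":[ {
              "cds_experimental":{
                "cluster": "cluster_a"
              }
            } ]
          },
          "cluster_b":{
            "childPolicy":[ {
              "cds_experimental":{
                "cluster": "cluster_b"
              }
            } ]
          }
        }
    } }
  ]
})";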
+ grpc_error_handle error = GRPC_ERROR_NONE; + auto config_selector = MakeRefCounted(Ref(), &error); + if (error != GRPC_ERROR_NONE) { + OnError(error); + return; + } + Result result; + error = CreateServiceConfig(&result.service_config); + if (error != GRPC_ERROR_NONE) { + OnError(error); + return; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_resolver_trace)) { + gpr_log(GPR_INFO, "[xds_resolver %p] generated service config: %s", this, + result.service_config->json_string().c_str()); + } + grpc_arg new_args[] = { + xds_client_->MakeChannelArg(), + config_selector->MakeChannelArg(), + }; + result.args = + grpc_channel_args_copy_and_add(args_, new_args, GPR_ARRAY_SIZE(new_args)); + result_handler_->ReturnResult(std::move(result)); +} + +void XdsResolver::MaybeRemoveUnusedClusters() { + bool update_needed = false; + for (auto it = cluster_state_map_.begin(); it != cluster_state_map_.end();) { + RefCountedPtr cluster_state = it->second->RefIfNonZero(); + if (cluster_state != nullptr) { + ++it; + } else { + update_needed = true; + it = cluster_state_map_.erase(it); + } + } + if (update_needed && xds_client_ != nullptr) { + // Send a new result to the channel. + GenerateResult(); + } +} + +// +// Factory +// + +class XdsResolverFactory : public ResolverFactory { + public: + bool IsValidUri(const URI& uri) const override { + if (GPR_UNLIKELY(!uri.authority().empty())) { + gpr_log(GPR_ERROR, "URI authority not supported"); + return false; + } + return true; + } + + OrphanablePtr CreateResolver(ResolverArgs args) const override { + if (!IsValidUri(args.uri)) return nullptr; + return MakeOrphanable(std::move(args)); + } + + const char* scheme() const override { return "xds"; } +}; + +} // namespace + +} // namespace grpc_core + +void grpc_resolver_xds_init() { + grpc_core::ResolverRegistry::Builder::RegisterResolverFactory( + absl::make_unique()); +} + +void grpc_resolver_xds_shutdown() {} diff --git a/src/core/ext/filters/client_channel/resolver_registry.cc b/src/core/ext/filters/client_channel/resolver_registry.cc new file mode 100644 index 00000000..b1d0ee21 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver_registry.cc @@ -0,0 +1,195 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include + +namespace grpc_core { + +namespace { + +class RegistryState { + public: + RegistryState() : default_prefix_(gpr_strdup("dns:///")) {} + + void SetDefaultPrefix(const char* default_resolver_prefix) { + GPR_ASSERT(default_resolver_prefix != nullptr); + GPR_ASSERT(*default_resolver_prefix != '\0'); + default_prefix_.reset(gpr_strdup(default_resolver_prefix)); + } + + void RegisterResolverFactory(std::unique_ptr factory) { + for (size_t i = 0; i < factories_.size(); ++i) { + GPR_ASSERT(strcmp(factories_[i]->scheme(), factory->scheme()) != 0); + } + factories_.push_back(std::move(factory)); + } + + ResolverFactory* LookupResolverFactory(absl::string_view scheme) const { + for (size_t i = 0; i < factories_.size(); ++i) { + if (scheme == factories_[i]->scheme()) { + return factories_[i].get(); + } + } + return nullptr; + } + + // Returns the factory for the scheme of \a target. If \a target does + // not parse as a URI, prepends \a default_prefix_ and tries again. + // If URI parsing is successful (in either attempt), sets \a uri to + // point to the parsed URI. + // If \a default_prefix_ needs to be prepended, sets \a canonical_target + // to the canonical target string. + ResolverFactory* FindResolverFactory(absl::string_view target, URI* uri, + std::string* canonical_target) const { + GPR_ASSERT(uri != nullptr); + absl::StatusOr tmp_uri = URI::Parse(target); + ResolverFactory* factory = + tmp_uri.ok() ? LookupResolverFactory(tmp_uri->scheme()) : nullptr; + if (factory != nullptr) { + *uri = std::move(*tmp_uri); + return factory; + } + *canonical_target = absl::StrCat(default_prefix_.get(), target); + absl::StatusOr tmp_uri2 = URI::Parse(*canonical_target); + factory = + tmp_uri2.ok() ? LookupResolverFactory(tmp_uri2->scheme()) : nullptr; + if (factory != nullptr) { + *uri = std::move(*tmp_uri2); + return factory; + } + if (!tmp_uri.ok() || !tmp_uri2.ok()) { + gpr_log(GPR_ERROR, "%s", + absl::StrFormat("Error parsing URI(s). '%s':%s; '%s':%s", target, + tmp_uri.status().ToString(), *canonical_target, + tmp_uri2.status().ToString()) + .c_str()); + return nullptr; + } + gpr_log(GPR_ERROR, "Don't know how to resolve '%s' or '%s'.", + std::string(target).c_str(), canonical_target->c_str()); + return nullptr; + } + + private: + // We currently support 10 factories without doing additional + // allocation. This number could be raised if there is a case where + // more factories are needed and the additional allocations are + // hurting performance (which is unlikely, since these allocations + // only occur at gRPC initialization time). 
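// Aside (illustrative, not part of this patch): the practical effect of
// FindResolverFactory() above is that a bare target whose scheme is not
// registered gets the default prefix prepended on a second parsing attempt,
// for example
//   "localhost:50051"      -> "dns:///localhost:50051"  (dns factory)
//   "ipv4:127.0.0.1:10000" -> used as-is                (ipv4 factory)
//   "xds:///my-service"    -> used as-is                (xds factory)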
+ absl::InlinedVector, 10> factories_; + grpc_core::UniquePtr default_prefix_; +}; + +static RegistryState* g_state = nullptr; + +} // namespace + +// +// ResolverRegistry::Builder +// + +void ResolverRegistry::Builder::InitRegistry() { + if (g_state == nullptr) g_state = new RegistryState(); +} + +void ResolverRegistry::Builder::ShutdownRegistry() { + delete g_state; + g_state = nullptr; +} + +void ResolverRegistry::Builder::SetDefaultPrefix(const char* default_prefix) { + InitRegistry(); + g_state->SetDefaultPrefix(default_prefix); +} + +void ResolverRegistry::Builder::RegisterResolverFactory( + std::unique_ptr factory) { + InitRegistry(); + g_state->RegisterResolverFactory(std::move(factory)); +} + +// +// ResolverRegistry +// + +ResolverFactory* ResolverRegistry::LookupResolverFactory(const char* scheme) { + GPR_ASSERT(g_state != nullptr); + return g_state->LookupResolverFactory(scheme); +} + +bool ResolverRegistry::IsValidTarget(absl::string_view target) { + URI uri; + std::string canonical_target; + ResolverFactory* factory = + g_state->FindResolverFactory(target, &uri, &canonical_target); + return factory == nullptr ? false : factory->IsValidUri(uri); +} + +OrphanablePtr ResolverRegistry::CreateResolver( + const char* target, const grpc_channel_args* args, + grpc_pollset_set* pollset_set, + std::shared_ptr work_serializer, + std::unique_ptr result_handler) { + GPR_ASSERT(g_state != nullptr); + ResolverArgs resolver_args; + ResolverFactory* factory = g_state->FindResolverFactory( + target, &resolver_args.uri, &resolver_args.uri_string); + if (factory == nullptr) return nullptr; + if (resolver_args.uri_string.empty()) resolver_args.uri_string = target; + resolver_args.args = args; + resolver_args.pollset_set = pollset_set; + resolver_args.work_serializer = std::move(work_serializer); + resolver_args.result_handler = std::move(result_handler); + return factory->CreateResolver(std::move(resolver_args)); +} + +std::string ResolverRegistry::GetDefaultAuthority(absl::string_view target) { + GPR_ASSERT(g_state != nullptr); + URI uri; + std::string canonical_target; + ResolverFactory* factory = + g_state->FindResolverFactory(target, &uri, &canonical_target); + std::string authority = + factory == nullptr ? "" : factory->GetDefaultAuthority(uri); + return authority; +} + +grpc_core::UniquePtr ResolverRegistry::AddDefaultPrefixIfNeeded( + const char* target) { + GPR_ASSERT(g_state != nullptr); + URI uri; + std::string canonical_target; + g_state->FindResolverFactory(target, &uri, &canonical_target); + return grpc_core::UniquePtr(canonical_target.empty() + ? gpr_strdup(target) + : gpr_strdup(canonical_target.c_str())); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/resolver_result_parsing.cc b/src/core/ext/filters/client_channel/resolver_result_parsing.cc new file mode 100644 index 00000000..4154ee78 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver_result_parsing.cc @@ -0,0 +1,189 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/client_channel/resolver_result_parsing.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/json/json_util.h" +#include "src/core/lib/uri/uri_parser.h" + +// As per the retry design, we do not allow more than 5 retry attempts. +#define MAX_MAX_RETRY_ATTEMPTS 5 + +namespace grpc_core { +namespace internal { + +namespace { +size_t g_client_channel_service_config_parser_index; +} + +size_t ClientChannelServiceConfigParser::ParserIndex() { + return g_client_channel_service_config_parser_index; +} + +void ClientChannelServiceConfigParser::Register() { + g_client_channel_service_config_parser_index = + ServiceConfigParser::RegisterParser( + absl::make_unique()); +} + +namespace { + +absl::optional ParseHealthCheckConfig(const Json& field, + grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + if (field.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:healthCheckConfig error:should be of type object"); + return absl::nullopt; + } + std::vector error_list; + absl::optional service_name; + auto it = field.object_value().find("serviceName"); + if (it != field.object_value().end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:serviceName error:should be of type string")); + } else { + service_name = it->second.string_value(); + } + } + *error = + GRPC_ERROR_CREATE_FROM_VECTOR("field:healthCheckConfig", &error_list); + return service_name; +} + +} // namespace + +std::unique_ptr +ClientChannelServiceConfigParser::ParseGlobalParams( + const grpc_channel_args* /*args*/, const Json& json, + grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + std::vector error_list; + // Parse LB config. + RefCountedPtr parsed_lb_config; + auto it = json.object_value().find("loadBalancingConfig"); + if (it != json.object_value().end()) { + grpc_error_handle parse_error = GRPC_ERROR_NONE; + parsed_lb_config = LoadBalancingPolicyRegistry::ParseLoadBalancingConfig( + it->second, &parse_error); + if (parse_error != GRPC_ERROR_NONE) { + std::vector lb_errors; + lb_errors.push_back(parse_error); + error_list.push_back(GRPC_ERROR_CREATE_FROM_VECTOR( + "field:loadBalancingConfig", &lb_errors)); + } + } + // Parse deprecated LB policy. 
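// Aside (illustrative, not part of this patch): an example of service config
// JSON covering the fields this parser handles, both the global ones above
// and the per-method waitForReady/timeout handled further below (the service
// and method names are made up):
const char* kExampleClientChannelServiceConfig = R"({
  "loadBalancingConfig": [ { "round_robin": {} } ],
  "healthCheckConfig": { "serviceName": "helloworld.Greeter" },
  "methodConfig": [ {
    "name": [ { "service": "helloworld.Greeter", "method": "SayHello" } ],
    "waitForReady": true,
    "timeout": "5s"
  } ]
})";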
+ std::string lb_policy_name; + it = json.object_value().find("loadBalancingPolicy"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:type should be string")); + } else { + lb_policy_name = it->second.string_value(); + for (size_t i = 0; i < lb_policy_name.size(); ++i) { + lb_policy_name[i] = tolower(lb_policy_name[i]); + } + bool requires_config = false; + if (!LoadBalancingPolicyRegistry::LoadBalancingPolicyExists( + lb_policy_name.c_str(), &requires_config)) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:loadBalancingPolicy error:Unknown lb policy")); + } else if (requires_config) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("field:loadBalancingPolicy error:", lb_policy_name, + " requires a config. Please use loadBalancingConfig " + "instead."))); + } + } + } + // Parse health check config. + absl::optional health_check_service_name; + it = json.object_value().find("healthCheckConfig"); + if (it != json.object_value().end()) { + grpc_error_handle parsing_error = GRPC_ERROR_NONE; + health_check_service_name = + ParseHealthCheckConfig(it->second, &parsing_error); + if (parsing_error != GRPC_ERROR_NONE) { + error_list.push_back(parsing_error); + } + } + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Client channel global parser", + &error_list); + if (*error == GRPC_ERROR_NONE) { + return absl::make_unique( + std::move(parsed_lb_config), std::move(lb_policy_name), + std::move(health_check_service_name)); + } + return nullptr; +} + +std::unique_ptr +ClientChannelServiceConfigParser::ParsePerMethodParams( + const grpc_channel_args* /*args*/, const Json& json, + grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + std::vector error_list; + // Parse waitForReady. + absl::optional wait_for_ready; + auto it = json.object_value().find("waitForReady"); + if (it != json.object_value().end()) { + if (it->second.type() == Json::Type::JSON_TRUE) { + wait_for_ready.emplace(true); + } else if (it->second.type() == Json::Type::JSON_FALSE) { + wait_for_ready.emplace(false); + } else { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:waitForReady error:Type should be true/false")); + } + } + // Parse timeout. + grpc_millis timeout = 0; + ParseJsonObjectFieldAsDuration(json.object_value(), "timeout", &timeout, + &error_list, false); + // Return result. + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Client channel parser", &error_list); + if (*error == GRPC_ERROR_NONE) { + return absl::make_unique(timeout, + wait_for_ready); + } + return nullptr; +} + +} // namespace internal +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/retry_filter.cc b/src/core/ext/filters/client_channel/retry_filter.cc new file mode 100644 index 00000000..e8a6a467 --- /dev/null +++ b/src/core/ext/filters/client_channel/retry_filter.cc @@ -0,0 +1,2573 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
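// For illustration only (not part of the patch): referring back to
// ClientChannelServiceConfigParser::ParsePerMethodParams() at the end of
// resolver_result_parsing.cc above -- it accepts "waitForReady" as a JSON
// boolean and "timeout" as a duration string. The service name below is a
// made-up placeholder.
constexpr char kExampleMethodConfig[] = R"json({
  "methodConfig": [ {
    "name": [ { "service": "example.EchoService" } ],
    "waitForReady": true,
    "timeout": "5s"
  } ]
})json";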
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/client_channel/retry_filter.h" + +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/strip.h" + +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/retry_service_config.h" +#include "src/core/ext/filters/client_channel/retry_throttle.h" +#include "src/core/ext/service_config/service_config.h" +#include "src/core/ext/service_config/service_config_call_data.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_metadata.h" +#include "src/core/lib/uri/uri_parser.h" + +// +// Retry filter +// + +// This filter is intended to be used in the DynamicFilter stack in the +// client channel, which is situated between the name resolver and the +// LB policy. Normally, the last filter in the DynamicFilter stack is +// the DynamicTerminationFilter (see client_channel.cc), which creates a +// LoadBalancedCall and delegates to it. However, when retries are +// enabled, this filter is used instead of the DynamicTerminationFilter. +// +// In order to support retries, we act as a proxy for stream op batches. +// When we get a batch from the surface, we add it to our list of pending +// batches, and we then use those batches to construct separate "child" +// batches to be started on an LB call. When the child batches return, we +// then decide which pending batches have been completed and schedule their +// callbacks accordingly. If a call attempt fails and we want to retry it, +// we create a new LB call and start again, constructing new "child" batches +// for the new LB call. +// +// Note that retries are committed when receiving data from the server +// (except for Trailers-Only responses). However, there may be many +// send ops started before receiving any data, so we may have already +// completed some number of send ops (and returned the completions up to +// the surface) by the time we realize that we need to retry. To deal +// with this, we cache data for send ops, so that we can replay them on a +// different LB call even after we have completed the original batches. +// +// The code is structured as follows: +// - In CallData (in the parent channel), we maintain a list of pending +// ops and cached data for send ops. +// - There is a CallData::CallAttempt object for each retry attempt. +// This object contains the LB call for that attempt and state to indicate +// which ops from the CallData object have already been sent down to that +// LB call. +// - There is a CallData::CallAttempt::BatchData object for each "child" +// batch sent on the LB call. 
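// Toy model (not gRPC code, not part of the patch): the proxying scheme
// described in the comment above, reduced to a skeleton. The parent call
// keeps the batches accepted from the surface; each attempt only records how
// much of that state it has already started, so a brand-new attempt simply
// replays everything. All Toy* names are invented for illustration.
#include <cstddef>
#include <vector>

struct ToyCallData {
  std::vector<int> pending_batches;  // ops accepted from the surface so far
};

struct ToyCallAttempt {
  std::size_t num_started = 0;  // how many of those this attempt has started
  // Build "child" batches for everything this attempt has not yet started.
  std::vector<int> BuildChildBatches(const ToyCallData& calld) {
    std::vector<int> child(calld.pending_batches.begin() + num_started,
                           calld.pending_batches.end());
    num_started = calld.pending_batches.size();
    return child;
  }
};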
+// +// When constructing the "child" batches, we compare the state in the +// CallAttempt object against the state in the CallData object to see +// which batches need to be sent on the LB call for a given attempt. + +// TODO(roth): In subsequent PRs: +// - add support for transparent retries (including initial metadata) +// - implement hedging + +// By default, we buffer 256 KiB per RPC for retries. +// TODO(roth): Do we have any data to suggest a better value? +#define DEFAULT_PER_RPC_RETRY_BUFFER_SIZE (256 << 10) + +// This value was picked arbitrarily. It can be changed if there is +// any even moderately compelling reason to do so. +#define RETRY_BACKOFF_JITTER 0.2 + +namespace grpc_core { + +namespace { + +using internal::RetryGlobalConfig; +using internal::RetryMethodConfig; +using internal::RetryServiceConfigParser; +using internal::ServerRetryThrottleData; + +TraceFlag grpc_retry_trace(false, "retry"); + +// +// RetryFilter +// + +class RetryFilter { + public: + class CallData; + + static grpc_error_handle Init(grpc_channel_element* elem, + grpc_channel_element_args* args) { + GPR_ASSERT(args->is_last); + GPR_ASSERT(elem->filter == &kRetryFilterVtable); + grpc_error_handle error = GRPC_ERROR_NONE; + new (elem->channel_data) RetryFilter(args->channel_args, &error); + return error; + } + + static void Destroy(grpc_channel_element* elem) { + auto* chand = static_cast(elem->channel_data); + chand->~RetryFilter(); + } + + // Will never be called. + static void StartTransportOp(grpc_channel_element* /*elem*/, + grpc_transport_op* /*op*/) {} + static void GetChannelInfo(grpc_channel_element* /*elem*/, + const grpc_channel_info* /*info*/) {} + + private: + static size_t GetMaxPerRpcRetryBufferSize(const grpc_channel_args* args) { + return static_cast(grpc_channel_args_find_integer( + args, GRPC_ARG_PER_RPC_RETRY_BUFFER_SIZE, + {DEFAULT_PER_RPC_RETRY_BUFFER_SIZE, 0, INT_MAX})); + } + + RetryFilter(const grpc_channel_args* args, grpc_error_handle* error) + : client_channel_(grpc_channel_args_find_pointer( + args, GRPC_ARG_CLIENT_CHANNEL)), + per_rpc_retry_buffer_size_(GetMaxPerRpcRetryBufferSize(args)) { + // Get retry throttling parameters from service config. + auto* service_config = grpc_channel_args_find_pointer( + args, GRPC_ARG_SERVICE_CONFIG_OBJ); + if (service_config == nullptr) return; + const auto* config = static_cast( + service_config->GetGlobalParsedConfig( + RetryServiceConfigParser::ParserIndex())); + if (config == nullptr) return; + // Get server name from target URI. + const char* server_uri = + grpc_channel_args_find_string(args, GRPC_ARG_SERVER_URI); + if (server_uri == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "server URI channel arg missing or wrong type in client channel " + "filter"); + return; + } + absl::StatusOr uri = URI::Parse(server_uri); + if (!uri.ok() || uri->path().empty()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "could not extract server name from target URI"); + return; + } + std::string server_name(absl::StripPrefix(uri->path(), "/")); + // Get throttling config for server_name. 
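// Usage sketch (not part of the patch; the 512 KiB value is arbitrary): the
// buffer limit read by GetMaxPerRpcRetryBufferSize() above is an ordinary
// channel argument, so a client can raise it when building its channel.
#include <grpcpp/grpcpp.h>

grpc::ChannelArguments MakeArgsWithLargerRetryBuffer() {
  grpc::ChannelArguments args;
  args.SetInt(GRPC_ARG_PER_RPC_RETRY_BUFFER_SIZE, 512 * 1024);
  return args;
}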
+ retry_throttle_data_ = internal::ServerRetryThrottleMap::GetDataForServer( + server_name, config->max_milli_tokens(), config->milli_token_ratio()); + } + + ClientChannel* client_channel_; + size_t per_rpc_retry_buffer_size_; + RefCountedPtr retry_throttle_data_; +}; + +// +// RetryFilter::CallData +// + +class RetryFilter::CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args); + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* then_schedule_closure); + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch); + static void SetPollent(grpc_call_element* elem, grpc_polling_entity* pollent); + + private: + class CallStackDestructionBarrier; + + // Pending batches stored in call data. + struct PendingBatch { + // The pending batch. If nullptr, this slot is empty. + grpc_transport_stream_op_batch* batch = nullptr; + // Indicates whether payload for send ops has been cached in CallData. + bool send_ops_cached = false; + }; + + // State associated with each call attempt. + class CallAttempt : public RefCounted { + public: + explicit CallAttempt(CallData* calld); + ~CallAttempt() override; + + bool lb_call_committed() const { return lb_call_committed_; } + + // Constructs and starts whatever batches are needed on this call + // attempt. + void StartRetriableBatches(); + + // Frees cached send ops that have already been completed after + // committing the call. + void FreeCachedSendOpDataAfterCommit(); + + // Cancels the call attempt. + void CancelFromSurface(grpc_transport_stream_op_batch* cancel_batch); + + private: + // State used for starting a retryable batch on the call attempt's LB call. + // This provides its own grpc_transport_stream_op_batch and other data + // structures needed to populate the ops in the batch. + // We allocate one struct on the arena for each attempt at starting a + // batch on a given LB call. + class BatchData + : public RefCounted { + public: + BatchData(RefCountedPtr call_attempt, int refcount, + bool set_on_complete); + ~BatchData() override; + + grpc_transport_stream_op_batch* batch() { return &batch_; } + + // Adds retriable send_initial_metadata op. + void AddRetriableSendInitialMetadataOp(); + // Adds retriable send_message op. + void AddRetriableSendMessageOp(); + // Adds retriable send_trailing_metadata op. + void AddRetriableSendTrailingMetadataOp(); + // Adds retriable recv_initial_metadata op. + void AddRetriableRecvInitialMetadataOp(); + // Adds retriable recv_message op. + void AddRetriableRecvMessageOp(); + // Adds retriable recv_trailing_metadata op. + void AddRetriableRecvTrailingMetadataOp(); + // Adds cancel_stream op. + void AddCancelStreamOp(grpc_error_handle error); + + private: + // Frees cached send ops that were completed by the completed batch in + // batch_data. Used when batches are completed after the call is + // committed. + void FreeCachedSendOpDataForCompletedBatch(); + + // If there is a pending recv_initial_metadata op, adds a closure + // to closures for recv_initial_metadata_ready. + void MaybeAddClosureForRecvInitialMetadataCallback( + grpc_error_handle error, CallCombinerClosureList* closures); + // Intercepts recv_initial_metadata_ready callback for retries. + // Commits the call and returns the initial metadata up the stack. 
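// For illustration only (not part of the patch): the max_milli_tokens /
// milli_token_ratio values consumed by the
// ServerRetryThrottleMap::GetDataForServer() call near the top of the
// constructor above come from the service config's "retryThrottling" block
// (parsed by the parser declared in retry_service_config.h, included
// earlier). The numbers below are example values only.
constexpr char kExampleRetryThrottling[] = R"json({
  "retryThrottling": { "maxTokens": 10, "tokenRatio": 0.1 }
})json";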
+ static void RecvInitialMetadataReady(void* arg, grpc_error_handle error); + + // If there is a pending recv_message op, adds a closure to closures + // for recv_message_ready. + void MaybeAddClosureForRecvMessageCallback( + grpc_error_handle error, CallCombinerClosureList* closures); + // Intercepts recv_message_ready callback for retries. + // Commits the call and returns the message up the stack. + static void RecvMessageReady(void* arg, grpc_error_handle error); + + // If there is a pending recv_trailing_metadata op, adds a closure to + // closures for recv_trailing_metadata_ready. + void MaybeAddClosureForRecvTrailingMetadataReady( + grpc_error_handle error, CallCombinerClosureList* closures); + // Adds any necessary closures for deferred batch completion + // callbacks to closures. + void AddClosuresForDeferredCompletionCallbacks( + CallCombinerClosureList* closures); + // For any pending batch containing an op that has not yet been started, + // adds the pending batch's completion closures to closures. + void AddClosuresToFailUnstartedPendingBatches( + grpc_error_handle error, CallCombinerClosureList* closures); + // Runs necessary closures upon completion of a call attempt. + void RunClosuresForCompletedCall(grpc_error_handle error); + // Intercepts recv_trailing_metadata_ready callback for retries. + // Commits the call and returns the trailing metadata up the stack. + static void RecvTrailingMetadataReady(void* arg, grpc_error_handle error); + + // Adds the on_complete closure for the pending batch completed in + // batch_data to closures. + void AddClosuresForCompletedPendingBatch( + grpc_error_handle error, CallCombinerClosureList* closures); + + // If there are any cached ops to replay or pending ops to start on the + // LB call, adds them to closures. + void AddClosuresForReplayOrPendingSendOps( + CallCombinerClosureList* closures); + + // Callback used to intercept on_complete from LB calls. + static void OnComplete(void* arg, grpc_error_handle error); + + // Callback used to handle on_complete for internally generated + // cancel_stream op. + static void OnCompleteForCancelOp(void* arg, grpc_error_handle error); + + RefCountedPtr call_attempt_; + // The batch to use in the LB call. + // Its payload field points to CallAttempt::batch_payload_. + grpc_transport_stream_op_batch batch_; + // For intercepting on_complete. + grpc_closure on_complete_; + }; + + class AttemptDispatchController + : public ConfigSelector::CallDispatchController { + public: + explicit AttemptDispatchController(CallAttempt* call_attempt) + : call_attempt_(call_attempt) {} + + // Will never be called. + bool ShouldRetry() override { return false; } + + void Commit() override { + call_attempt_->lb_call_committed_ = true; + auto* calld = call_attempt_->calld_; + if (calld->retry_committed_) { + auto* service_config_call_data = + static_cast( + calld->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA] + .value); + service_config_call_data->call_dispatch_controller()->Commit(); + } + } + + private: + CallAttempt* call_attempt_; + }; + + // Creates a BatchData object on the call's arena with the + // specified refcount. If set_on_complete is true, the batch's + // on_complete callback will be set to point to on_complete(); + // otherwise, the batch's on_complete callback will be null. 
+ BatchData* CreateBatch(int refcount, bool set_on_complete) { + return calld_->arena_->New(Ref(DEBUG_LOCATION, "CreateBatch"), + refcount, set_on_complete); + } + + // If there are any cached send ops that need to be replayed on this + // call attempt, creates and returns a new batch to replay those ops. + // Otherwise, returns nullptr. + BatchData* MaybeCreateBatchForReplay(); + + // Adds a closure to closures that will execute batch in the call combiner. + void AddClosureForBatch(grpc_transport_stream_op_batch* batch, + const char* reason, + CallCombinerClosureList* closures); + + // Helper function used to start a recv_trailing_metadata batch. This + // is used in the case where a recv_initial_metadata or recv_message + // op fails in a way that we know the call is over but when the application + // has not yet started its own recv_trailing_metadata op. + void AddBatchForInternalRecvTrailingMetadata( + CallCombinerClosureList* closures); + + // Adds a batch to closures to cancel this call attempt. + void AddBatchForCancelOp(grpc_error_handle error, + CallCombinerClosureList* closures); + + // Adds batches for pending batches to closures. + void AddBatchesForPendingBatches(CallCombinerClosureList* closures); + + // Adds whatever batches are needed on this attempt to closures. + void AddRetriableBatches(CallCombinerClosureList* closures); + + // Returns true if any send op in the batch was not yet started on this + // attempt. + bool PendingBatchContainsUnstartedSendOps(PendingBatch* pending); + + // Returns true if there are cached send ops to replay. + bool HaveSendOpsToReplay(); + + // If our retry state is no longer needed, switch to fast path by moving + // our LB call into calld_->committed_call_ and having calld_ drop + // its ref to us. + void MaybeSwitchToFastPath(); + + // Returns true if the call should be retried. + // If server_pushback_md is non-null, sets *server_pushback_ms. + bool ShouldRetry(absl::optional status, bool is_lb_drop, + grpc_mdelem* server_pushback_md, + grpc_millis* server_pushback_ms); + + // Abandons the call attempt. Unrefs any deferred batches. + void Abandon(); + + static void OnPerAttemptRecvTimer(void* arg, grpc_error_handle error); + static void OnPerAttemptRecvTimerLocked(void* arg, grpc_error_handle error); + void MaybeCancelPerAttemptRecvTimer(); + + CallData* calld_; + AttemptDispatchController attempt_dispatch_controller_; + OrphanablePtr lb_call_; + bool lb_call_committed_ = false; + + grpc_timer per_attempt_recv_timer_; + grpc_closure on_per_attempt_recv_timer_; + bool per_attempt_recv_timer_pending_ = false; + + // BatchData.batch.payload points to this. + grpc_transport_stream_op_batch_payload batch_payload_; + // For send_initial_metadata. + grpc_linked_mdelem retry_attempts_metadata_; + grpc_metadata_batch send_initial_metadata_{calld_->arena_}; + // For send_message. + // TODO(roth): Restructure this to eliminate use of ManualConstructor. + ManualConstructor send_message_; + // For send_trailing_metadata. + grpc_metadata_batch send_trailing_metadata_{calld_->arena_}; + // For intercepting recv_initial_metadata. + grpc_metadata_batch recv_initial_metadata_{calld_->arena_}; + grpc_closure recv_initial_metadata_ready_; + bool trailing_metadata_available_ = false; + // For intercepting recv_message. + grpc_closure recv_message_ready_; + OrphanablePtr recv_message_; + // For intercepting recv_trailing_metadata. 
+ grpc_metadata_batch recv_trailing_metadata_{calld_->arena_}; + grpc_transport_stream_stats collect_stats_; + grpc_closure recv_trailing_metadata_ready_; + // These fields indicate which ops have been started and completed on + // this call attempt. + size_t started_send_message_count_ = 0; + size_t completed_send_message_count_ = 0; + size_t started_recv_message_count_ = 0; + size_t completed_recv_message_count_ = 0; + bool started_send_initial_metadata_ : 1; + bool completed_send_initial_metadata_ : 1; + bool started_send_trailing_metadata_ : 1; + bool completed_send_trailing_metadata_ : 1; + bool started_recv_initial_metadata_ : 1; + bool completed_recv_initial_metadata_ : 1; + bool started_recv_trailing_metadata_ : 1; + bool completed_recv_trailing_metadata_ : 1; + // State for callback processing. + RefCountedPtr recv_initial_metadata_ready_deferred_batch_; + grpc_error_handle recv_initial_metadata_error_ = GRPC_ERROR_NONE; + RefCountedPtr recv_message_ready_deferred_batch_; + grpc_error_handle recv_message_error_ = GRPC_ERROR_NONE; + struct OnCompleteDeferredBatch { + OnCompleteDeferredBatch(RefCountedPtr batch, + grpc_error_handle error) + : batch(std::move(batch)), error(error) {} + RefCountedPtr batch; + grpc_error_handle error; + }; + // There cannot be more than 3 pending send op batches at a time. + absl::InlinedVector + on_complete_deferred_batches_; + RefCountedPtr recv_trailing_metadata_internal_batch_; + grpc_error_handle recv_trailing_metadata_error_ = GRPC_ERROR_NONE; + bool seen_recv_trailing_metadata_from_surface_ : 1; + // NOTE: Do not move this next to the metadata bitfields above. That would + // save space but will also result in a data race because compiler + // will generate a 2 byte store which overwrites the meta-data + // fields upon setting this field. + bool abandoned_ : 1; + }; + + CallData(RetryFilter* chand, const grpc_call_element_args& args); + ~CallData(); + + void StartTransportStreamOpBatch(grpc_transport_stream_op_batch* batch); + + // Returns the index into pending_batches_ to be used for batch. + static size_t GetBatchIndex(grpc_transport_stream_op_batch* batch); + PendingBatch* PendingBatchesAdd(grpc_transport_stream_op_batch* batch); + void PendingBatchClear(PendingBatch* pending); + void MaybeClearPendingBatch(PendingBatch* pending); + static void FailPendingBatchInCallCombiner(void* arg, + grpc_error_handle error); + // Fails all pending batches. Does NOT yield call combiner. + void PendingBatchesFail(grpc_error_handle error); + // Returns a pointer to the first pending batch for which predicate(batch) + // returns true, or null if not found. + template + PendingBatch* PendingBatchFind(const char* log_message, Predicate predicate); + + // Caches data for send ops so that it can be retried later, if not + // already cached. + void MaybeCacheSendOpsForBatch(PendingBatch* pending); + void FreeCachedSendInitialMetadata(); + // Frees cached send_message at index idx. + void FreeCachedSendMessage(size_t idx); + void FreeCachedSendTrailingMetadata(); + void FreeAllCachedSendOpData(); + + // Commits the call so that no further retry attempts will be performed. + void RetryCommit(CallAttempt* call_attempt); + + // Starts a timer to retry after appropriate back-off. + // If server_pushback_ms is -1, retry_backoff_ is used. 
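// Illustrative sketch (not the BackOff class backing retry_backoff_): absent
// a server push-back, the delay handed to StartRetryTimer() grows
// exponentially up to the policy's maxBackoff and is jittered by
// RETRY_BACKOFF_JITTER (here +/-20%). A self-contained restatement of that
// computation:
#include <algorithm>
#include <random>

double NextRetryDelayMs(double current_backoff_ms, double multiplier,
                        double max_backoff_ms, double jitter,
                        std::mt19937& rng) {
  const double next = std::min(current_backoff_ms * multiplier, max_backoff_ms);
  std::uniform_real_distribution<double> dist(next * (1.0 - jitter),
                                              next * (1.0 + jitter));
  return dist(rng);
}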
+ void StartRetryTimer(grpc_millis server_pushback_ms); + + static void OnRetryTimer(void* arg, grpc_error_handle error); + static void OnRetryTimerLocked(void* arg, grpc_error_handle error); + + OrphanablePtr CreateLoadBalancedCall( + ConfigSelector::CallDispatchController* call_dispatch_controller); + + void CreateCallAttempt(); + + RetryFilter* chand_; + grpc_polling_entity* pollent_; + RefCountedPtr retry_throttle_data_; + const RetryMethodConfig* retry_policy_ = nullptr; + BackOff retry_backoff_; + + grpc_slice path_; // Request path. + grpc_millis deadline_; + Arena* arena_; + grpc_call_stack* owning_call_; + CallCombiner* call_combiner_; + grpc_call_context_element* call_context_; + + grpc_error_handle cancelled_from_surface_ = GRPC_ERROR_NONE; + + RefCountedPtr call_stack_destruction_barrier_; + + // TODO(roth): As part of implementing hedging, we will need to maintain a + // list of all pending attempts, so that we can cancel them all if the call + // gets cancelled. + RefCountedPtr call_attempt_; + + // LB call used when we've committed to a call attempt and the retry + // state for that attempt is no longer needed. This provides a fast + // path for long-running streaming calls that minimizes overhead. + OrphanablePtr committed_call_; + + // When are are not yet fully committed to a particular call (i.e., + // either we might still retry or we have committed to the call but + // there are still some cached ops to be replayed on the call), + // batches received from above will be added to this list, and they + // will not be removed until we have invoked their completion callbacks. + size_t bytes_buffered_for_retry_ = 0; + PendingBatch pending_batches_[MAX_PENDING_BATCHES]; + bool pending_send_initial_metadata_ : 1; + bool pending_send_message_ : 1; + bool pending_send_trailing_metadata_ : 1; + + // Retry state. + bool retry_committed_ : 1; + bool retry_timer_pending_ : 1; + int num_attempts_completed_ = 0; + grpc_timer retry_timer_; + grpc_closure retry_closure_; + + // Cached data for retrying send ops. + // send_initial_metadata + bool seen_send_initial_metadata_ = false; + grpc_metadata_batch send_initial_metadata_{arena_}; + uint32_t send_initial_metadata_flags_; + // TODO(roth): As part of implementing hedging, we'll probably need to + // have the LB call set a value in CallAttempt and then propagate it + // from CallAttempt to the parent call when we commit. Otherwise, we + // may leave this with a value for a peer other than the one we + // actually commit to. Alternatively, maybe see if there's a way to + // change the surface API such that the peer isn't available until + // after initial metadata is received? (Could even change the + // transport API to return this with the recv_initial_metadata op.) + gpr_atm* peer_string_; + // send_message + // When we get a send_message op, we replace the original byte stream + // with a CachingByteStream that caches the slices to a local buffer for + // use in retries. + // Note: We inline the cache for the first 3 send_message ops and use + // dynamic allocation after that. This number was essentially picked + // at random; it could be changed in the future to tune performance. + // TODO(roth): As part of implementing hedging, we may need some + // synchronization here, since ByteStreamCache does not provide any + // synchronization, so it's not safe to have multiple + // CachingByteStreams read from the same ByteStreamCache concurrently. 
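// Toy analogue (not the ByteStreamCache / CachingByteStream API): the caching
// described in the comment above amounts to "tee whatever the transport
// consumes into a buffer owned by the call, so a later attempt can replay the
// identical payload". In miniature:
#include <cstddef>
#include <string>
#include <vector>

class ToyCachingStream {
 public:
  explicit ToyCachingStream(std::string source) : source_(std::move(source)) {}
  // Consumed by the current attempt; every chunk handed out is also cached.
  std::string Next(std::size_t max_len) {
    std::string chunk = source_.substr(cursor_, max_len);
    cursor_ += chunk.size();
    cache_.push_back(chunk);
    return chunk;
  }
  // A replay on a new attempt reads from here instead of the original source.
  const std::vector<std::string>& cached_chunks() const { return cache_; }

 private:
  std::string source_;
  std::size_t cursor_ = 0;
  std::vector<std::string> cache_;
};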
+ absl::InlinedVector send_messages_; + // send_trailing_metadata + bool seen_send_trailing_metadata_ = false; + grpc_metadata_batch send_trailing_metadata_{arena_}; +}; + +// +// RetryFilter::CallData::CallStackDestructionBarrier +// + +// A class to track the existence of LoadBalancedCall call stacks that +// we've created. We wait until all such call stacks have been +// destroyed before we return the on_call_stack_destruction closure up +// to the surface. +// +// The parent RetryFilter::CallData object holds a ref to this object. +// When it is destroyed, it will store the on_call_stack_destruction +// closure from the surface in this object and then release its ref. +// We also take a ref to this object for each LB call we create, and +// those refs are not released until the LB call stack is destroyed. +// When this object is destroyed, it will invoke the +// on_call_stack_destruction closure from the surface. +class RetryFilter::CallData::CallStackDestructionBarrier + : public RefCounted { + public: + CallStackDestructionBarrier() {} + + ~CallStackDestructionBarrier() override { + // TODO(yashkt) : This can potentially be a Closure::Run + ExecCtx::Run(DEBUG_LOCATION, on_call_stack_destruction_, GRPC_ERROR_NONE); + } + + // Set the closure from the surface. This closure will be invoked + // when this object is destroyed. + void set_on_call_stack_destruction(grpc_closure* on_call_stack_destruction) { + on_call_stack_destruction_ = on_call_stack_destruction; + } + + // Invoked to get an on_call_stack_destruction closure for a new LB call. + grpc_closure* MakeLbCallDestructionClosure(CallData* calld) { + Ref().release(); // Ref held by callback. + grpc_closure* on_lb_call_destruction_complete = + calld->arena_->New(); + GRPC_CLOSURE_INIT(on_lb_call_destruction_complete, + OnLbCallDestructionComplete, this, nullptr); + return on_lb_call_destruction_complete; + } + + private: + static void OnLbCallDestructionComplete(void* arg, + grpc_error_handle /*error*/) { + auto* self = static_cast(arg); + self->Unref(); + } + + grpc_closure* on_call_stack_destruction_ = nullptr; +}; + +// +// RetryFilter::CallData::CallAttempt +// + +RetryFilter::CallData::CallAttempt::CallAttempt(CallData* calld) + : RefCounted(GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace) ? "CallAttempt" + : nullptr), + calld_(calld), + attempt_dispatch_controller_(this), + batch_payload_(calld->call_context_), + started_send_initial_metadata_(false), + completed_send_initial_metadata_(false), + started_send_trailing_metadata_(false), + completed_send_trailing_metadata_(false), + started_recv_initial_metadata_(false), + completed_recv_initial_metadata_(false), + started_recv_trailing_metadata_(false), + completed_recv_trailing_metadata_(false), + seen_recv_trailing_metadata_from_surface_(false), + abandoned_(false) { + lb_call_ = calld->CreateLoadBalancedCall(&attempt_dispatch_controller_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: create lb_call=%p", + calld->chand_, calld, this, lb_call_.get()); + } + // If per_attempt_recv_timeout is set, start a timer. 
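// Simplified analogue (not part of the patch): the CallStackDestructionBarrier
// above is essentially "run one closure when the last reference disappears".
// The same shape with std types only; the parent call and every LB call stack
// would each hold one copy of the returned handle, and whichever copy is
// destroyed last fires the closure.
#include <functional>
#include <memory>

using ToyBarrier = std::shared_ptr<void>;

ToyBarrier MakeToyBarrier(std::function<void()> on_all_released) {
  // No managed object: only the deleter matters, and it runs exactly once,
  // when the final copy of the returned handle goes away.
  return ToyBarrier(nullptr,
                    [cb = std::move(on_all_released)](void*) { cb(); });
}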
+ if (calld->retry_policy_ != nullptr && + calld->retry_policy_->per_attempt_recv_timeout().has_value()) { + grpc_millis per_attempt_recv_deadline = + ExecCtx::Get()->Now() + + *calld->retry_policy_->per_attempt_recv_timeout(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: per-attempt timeout in %" PRId64 + " ms", + calld->chand_, calld, this, + *calld->retry_policy_->per_attempt_recv_timeout()); + } + // Schedule retry after computed delay. + GRPC_CLOSURE_INIT(&on_per_attempt_recv_timer_, OnPerAttemptRecvTimer, this, + nullptr); + GRPC_CALL_STACK_REF(calld->owning_call_, "OnPerAttemptRecvTimer"); + Ref(DEBUG_LOCATION, "OnPerAttemptRecvTimer").release(); + per_attempt_recv_timer_pending_ = true; + grpc_timer_init(&per_attempt_recv_timer_, per_attempt_recv_deadline, + &on_per_attempt_recv_timer_); + } +} + +RetryFilter::CallData::CallAttempt::~CallAttempt() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: destroying call attempt", + calld_->chand_, calld_, this); + } +} + +void RetryFilter::CallData::CallAttempt::FreeCachedSendOpDataAfterCommit() { + // TODO(roth): When we implement hedging, this logic will need to get + // a bit more complex, because there may be other (now abandoned) call + // attempts still using this data. We may need to do some sort of + // ref-counting instead. + if (completed_send_initial_metadata_) { + calld_->FreeCachedSendInitialMetadata(); + } + for (size_t i = 0; i < completed_send_message_count_; ++i) { + calld_->FreeCachedSendMessage(i); + } + if (completed_send_trailing_metadata_) { + calld_->FreeCachedSendTrailingMetadata(); + } +} + +bool RetryFilter::CallData::CallAttempt::PendingBatchContainsUnstartedSendOps( + PendingBatch* pending) { + if (pending->batch->on_complete == nullptr) return false; + if (pending->batch->send_initial_metadata && + !started_send_initial_metadata_) { + return true; + } + if (pending->batch->send_message && + started_send_message_count_ < calld_->send_messages_.size()) { + return true; + } + if (pending->batch->send_trailing_metadata && + !started_send_trailing_metadata_) { + return true; + } + return false; +} + +bool RetryFilter::CallData::CallAttempt::HaveSendOpsToReplay() { + // We don't check send_initial_metadata here, because that op will always + // be started as soon as it is received from the surface, so it will + // never need to be started at this point. + return started_send_message_count_ < calld_->send_messages_.size() || + (calld_->seen_send_trailing_metadata_ && + !started_send_trailing_metadata_); +} + +void RetryFilter::CallData::CallAttempt::MaybeSwitchToFastPath() { + // If we're not yet committed, we can't switch yet. + // TODO(roth): As part of implementing hedging, this logic needs to + // check that *this* call attempt is the one that we've committed to. + // Might need to replace abandoned_ with an enum indicating whether we're + // in flight, abandoned, or the winning call attempt. + if (!calld_->retry_committed_) return; + // If we've already switched to fast path, there's nothing to do here. + if (calld_->committed_call_ != nullptr) return; + // If the perAttemptRecvTimeout timer is pending, we can't switch yet. + if (per_attempt_recv_timer_pending_) return; + // If there are still send ops to replay, we can't switch yet. 
+ if (HaveSendOpsToReplay()) return; + // If we started an internal batch for recv_trailing_metadata but have not + // yet seen that op from the surface, we can't switch yet. + if (recv_trailing_metadata_internal_batch_ != nullptr) return; + // Switch to fast path. + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: retry state no longer needed; " + "moving LB call to parent and unreffing the call attempt", + calld_->chand_, calld_, this); + } + calld_->committed_call_ = std::move(lb_call_); + calld_->call_attempt_.reset(DEBUG_LOCATION, "MaybeSwitchToFastPath"); +} + +// If there are any cached send ops that need to be replayed on the +// current call attempt, creates and returns a new batch to replay those ops. +// Otherwise, returns nullptr. +RetryFilter::CallData::CallAttempt::BatchData* +RetryFilter::CallData::CallAttempt::MaybeCreateBatchForReplay() { + BatchData* replay_batch_data = nullptr; + // send_initial_metadata. + if (calld_->seen_send_initial_metadata_ && !started_send_initial_metadata_ && + !calld_->pending_send_initial_metadata_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: replaying previously completed " + "send_initial_metadata op", + calld_->chand_, calld_, this); + } + replay_batch_data = CreateBatch(1, true /* set_on_complete */); + replay_batch_data->AddRetriableSendInitialMetadataOp(); + } + // send_message. + // Note that we can only have one send_message op in flight at a time. + if (started_send_message_count_ < calld_->send_messages_.size() && + started_send_message_count_ == completed_send_message_count_ && + !calld_->pending_send_message_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: replaying previously completed " + "send_message op", + calld_->chand_, calld_, this); + } + if (replay_batch_data == nullptr) { + replay_batch_data = CreateBatch(1, true /* set_on_complete */); + } + replay_batch_data->AddRetriableSendMessageOp(); + } + // send_trailing_metadata. + // Note that we only add this op if we have no more send_message ops + // to start, since we can't send down any more send_message ops after + // send_trailing_metadata. + if (calld_->seen_send_trailing_metadata_ && + started_send_message_count_ == calld_->send_messages_.size() && + !started_send_trailing_metadata_ && + !calld_->pending_send_trailing_metadata_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: replaying previously completed " + "send_trailing_metadata op", + calld_->chand_, calld_, this); + } + if (replay_batch_data == nullptr) { + replay_batch_data = CreateBatch(1, true /* set_on_complete */); + } + replay_batch_data->AddRetriableSendTrailingMetadataOp(); + } + return replay_batch_data; +} + +namespace { + +void StartBatchInCallCombiner(void* arg, grpc_error_handle /*ignored*/) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + auto* lb_call = static_cast( + batch->handler_private.extra_arg); + // Note: This will release the call combiner. 
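// Restatement only (no new logic): the checks performed by
// MaybeSwitchToFastPath() above collapse into a single predicate. Spelled out
// with plain booleans purely for readability:
bool CanSwitchToFastPath(bool retry_committed, bool already_on_fast_path,
                         bool per_attempt_recv_timer_pending,
                         bool have_send_ops_to_replay,
                         bool internal_recv_trailing_md_batch_pending) {
  return retry_committed && !already_on_fast_path &&
         !per_attempt_recv_timer_pending && !have_send_ops_to_replay &&
         !internal_recv_trailing_md_batch_pending;
}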
+ lb_call->StartTransportStreamOpBatch(batch); +} + +} // namespace + +void RetryFilter::CallData::CallAttempt::AddClosureForBatch( + grpc_transport_stream_op_batch* batch, const char* reason, + CallCombinerClosureList* closures) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: adding batch (%s): %s", + calld_->chand_, calld_, this, reason, + grpc_transport_stream_op_batch_string(batch).c_str()); + } + batch->handler_private.extra_arg = lb_call_.get(); + GRPC_CLOSURE_INIT(&batch->handler_private.closure, StartBatchInCallCombiner, + batch, grpc_schedule_on_exec_ctx); + closures->Add(&batch->handler_private.closure, GRPC_ERROR_NONE, reason); +} + +void RetryFilter::CallData::CallAttempt:: + AddBatchForInternalRecvTrailingMetadata(CallCombinerClosureList* closures) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: call failed but " + "recv_trailing_metadata not started; starting it internally", + calld_->chand_, calld_, this); + } + // Create batch_data with 2 refs, since this batch will be unreffed twice: + // once for the recv_trailing_metadata_ready callback when the batch + // completes, and again when we actually get a recv_trailing_metadata + // op from the surface. + BatchData* batch_data = CreateBatch(2, false /* set_on_complete */); + batch_data->AddRetriableRecvTrailingMetadataOp(); + recv_trailing_metadata_internal_batch_.reset(batch_data); + AddClosureForBatch(batch_data->batch(), + "starting internal recv_trailing_metadata", closures); +} + +void RetryFilter::CallData::CallAttempt::AddBatchForCancelOp( + grpc_error_handle error, CallCombinerClosureList* closures) { + BatchData* cancel_batch_data = CreateBatch(1, /*set_on_complete=*/true); + cancel_batch_data->AddCancelStreamOp(error); + AddClosureForBatch(cancel_batch_data->batch(), + "start cancellation batch on call attempt", closures); +} + +void RetryFilter::CallData::CallAttempt::AddBatchesForPendingBatches( + CallCombinerClosureList* closures) { + for (size_t i = 0; i < GPR_ARRAY_SIZE(calld_->pending_batches_); ++i) { + PendingBatch* pending = &calld_->pending_batches_[i]; + grpc_transport_stream_op_batch* batch = pending->batch; + if (batch == nullptr) continue; + bool has_send_ops = false; + // Skip any batch that either (a) has already been started on this + // call attempt or (b) we can't start yet because we're still + // replaying send ops that need to be completed first. + // TODO(roth): Note that if any one op in the batch can't be sent + // yet due to ops that we're replaying, we don't start any of the ops + // in the batch. This is probably okay, but it could conceivably + // lead to increased latency in some cases -- e.g., we could delay + // starting a recv op due to it being in the same batch with a send + // op. If/when we revamp the callback protocol in + // transport_stream_op_batch, we may be able to fix this. + if (batch->send_initial_metadata) { + if (started_send_initial_metadata_) continue; + has_send_ops = true; + } + if (batch->send_message) { + if (completed_send_message_count_ < started_send_message_count_) { + continue; + } + has_send_ops = true; + } + // Note that we only start send_trailing_metadata if we have no more + // send_message ops to start, since we can't send down any more + // send_message ops after send_trailing_metadata. 
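// Restatement only (no new logic): the ordering rule in the comment above --
// send_trailing_metadata may only be started once every cached send_message
// has been started -- as a standalone predicate:
#include <cstddef>
bool CanStartSendTrailingMetadata(std::size_t started_send_message_count,
                                  std::size_t total_cached_send_messages,
                                  bool started_send_trailing_metadata) {
  return started_send_message_count == total_cached_send_messages &&
         !started_send_trailing_metadata;
}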
+ if (batch->send_trailing_metadata) { + if (started_send_message_count_ + batch->send_message < + calld_->send_messages_.size() || + started_send_trailing_metadata_) { + continue; + } + has_send_ops = true; + } + int num_callbacks = has_send_ops; // All send ops share one callback. + if (batch->recv_initial_metadata) { + if (started_recv_initial_metadata_) continue; + ++num_callbacks; + } + if (batch->recv_message) { + if (completed_recv_message_count_ < started_recv_message_count_) { + continue; + } + ++num_callbacks; + } + if (batch->recv_trailing_metadata) { + if (started_recv_trailing_metadata_) { + seen_recv_trailing_metadata_from_surface_ = true; + // If we previously completed a recv_trailing_metadata op + // initiated by AddBatchForInternalRecvTrailingMetadata(), use the + // result of that instead of trying to re-start this op. + if (GPR_UNLIKELY(recv_trailing_metadata_internal_batch_ != nullptr)) { + // If the batch completed, then trigger the completion callback + // directly, so that we return the previously returned results to + // the application. Otherwise, just unref the internally started + // batch, since we'll propagate the completion when it completes. + if (completed_recv_trailing_metadata_) { + closures->Add( + &recv_trailing_metadata_ready_, recv_trailing_metadata_error_, + "re-executing recv_trailing_metadata_ready to propagate " + "internally triggered result"); + // Ref will be released by callback. + recv_trailing_metadata_internal_batch_.release(); + } else { + recv_trailing_metadata_internal_batch_.reset( + DEBUG_LOCATION, + "internally started recv_trailing_metadata batch pending and " + "recv_trailing_metadata started from surface"); + GRPC_ERROR_UNREF(recv_trailing_metadata_error_); + } + recv_trailing_metadata_error_ = GRPC_ERROR_NONE; + } + // We don't want the fact that we've already started this op internally + // to prevent us from adding a batch that may contain other ops. + // Instead, we'll just skip adding this op below. + if (num_callbacks == 0) continue; + } else { + ++num_callbacks; + } + } + // If we're already committed and the following conditions are met, + // just send the batch down as-is: + // - The batch contains no cached send ops. (If it does, we need + // the logic below to use the cached payloads.) + // - The batch does not contain recv_trailing_metadata when we have + // already started an internal recv_trailing_metadata batch. (If + // we've already started an internal recv_trailing_metadata batch, + // then we need the logic below to send all ops in the batch + // *except* the recv_trailing_metadata op.) + if (calld_->retry_committed_ && !pending->send_ops_cached && + (!batch->recv_trailing_metadata || !started_recv_trailing_metadata_)) { + AddClosureForBatch( + batch, + "start non-replayable pending batch on call attempt after commit", + closures); + calld_->PendingBatchClear(pending); + continue; + } + // Create batch with the right number of callbacks. + BatchData* batch_data = + CreateBatch(num_callbacks, has_send_ops /* set_on_complete */); + // Cache send ops if needed. + calld_->MaybeCacheSendOpsForBatch(pending); + // send_initial_metadata. + if (batch->send_initial_metadata) { + batch_data->AddRetriableSendInitialMetadataOp(); + } + // send_message. + if (batch->send_message) { + batch_data->AddRetriableSendMessageOp(); + } + // send_trailing_metadata. + if (batch->send_trailing_metadata) { + batch_data->AddRetriableSendTrailingMetadataOp(); + } + // recv_initial_metadata. 
+ if (batch->recv_initial_metadata) { + // recv_flags is only used on the server side. + GPR_ASSERT(batch->payload->recv_initial_metadata.recv_flags == nullptr); + batch_data->AddRetriableRecvInitialMetadataOp(); + } + // recv_message. + if (batch->recv_message) { + batch_data->AddRetriableRecvMessageOp(); + } + // recv_trailing_metadata. + if (batch->recv_trailing_metadata && !started_recv_trailing_metadata_) { + batch_data->AddRetriableRecvTrailingMetadataOp(); + } + AddClosureForBatch(batch_data->batch(), + "start replayable pending batch on call attempt", + closures); + } +} + +void RetryFilter::CallData::CallAttempt::AddRetriableBatches( + CallCombinerClosureList* closures) { + // Replay previously-returned send_* ops if needed. + BatchData* replay_batch_data = MaybeCreateBatchForReplay(); + if (replay_batch_data != nullptr) { + AddClosureForBatch(replay_batch_data->batch(), + "start replay batch on call attempt", closures); + } + // Now add pending batches. + AddBatchesForPendingBatches(closures); +} + +void RetryFilter::CallData::CallAttempt::StartRetriableBatches() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: constructing retriable batches", + calld_->chand_, calld_, this); + } + // Construct list of closures to execute, one for each pending batch. + CallCombinerClosureList closures; + AddRetriableBatches(&closures); + // Note: This will yield the call combiner. + // Start batches on LB call. + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: starting %" PRIuPTR + " retriable batches on lb_call=%p", + calld_->chand_, calld_, this, closures.size(), lb_call_.get()); + } + closures.RunClosures(calld_->call_combiner_); +} + +void RetryFilter::CallData::CallAttempt::CancelFromSurface( + grpc_transport_stream_op_batch* cancel_batch) { + MaybeCancelPerAttemptRecvTimer(); + // Propagate cancellation to LB call. + lb_call_->StartTransportStreamOpBatch(cancel_batch); +} + +bool RetryFilter::CallData::CallAttempt::ShouldRetry( + absl::optional status, bool is_lb_drop, + grpc_mdelem* server_pushback_md, grpc_millis* server_pushback_ms) { + // LB drops always inhibit retries. + if (is_lb_drop) return false; + // TODO(roth): Handle transparent retries here. + // If no retry policy, don't retry. + if (calld_->retry_policy_ == nullptr) return false; + // Check status. + if (status.has_value()) { + if (GPR_LIKELY(*status == GRPC_STATUS_OK)) { + if (calld_->retry_throttle_data_ != nullptr) { + calld_->retry_throttle_data_->RecordSuccess(); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: call succeeded", + calld_->chand_, calld_, this); + } + return false; + } + // Status is not OK. Check whether the status is retryable. + if (!calld_->retry_policy_->retryable_status_codes().Contains(*status)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: status %s not configured as " + "retryable", + calld_->chand_, calld_, this, + grpc_status_code_to_string(*status)); + } + return false; + } + } + // Record the failure and check whether retries are throttled. + // Note that it's important for this check to come after the status + // code check above, since we should only record failures whose statuses + // match the configured retryable status codes, so that we don't count + // things like failures due to malformed requests (INVALID_ARGUMENT). 
+ // Conversely, it's important for this to come before the remaining + // checks, so that we don't fail to record failures due to other factors. + if (calld_->retry_throttle_data_ != nullptr && + !calld_->retry_throttle_data_->RecordFailure()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: retries throttled", + calld_->chand_, calld_, this); + } + return false; + } + // Check whether the call is committed. + if (calld_->retry_committed_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: retries already committed", + calld_->chand_, calld_, this); + } + return false; + } + // Check whether we have retries remaining. + ++calld_->num_attempts_completed_; + if (calld_->num_attempts_completed_ >= + calld_->retry_policy_->max_attempts()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log( + GPR_INFO, "chand=%p calld=%p attempt=%p: exceeded %d retry attempts", + calld_->chand_, calld_, this, calld_->retry_policy_->max_attempts()); + } + return false; + } + // Check server push-back. + if (server_pushback_md != nullptr) { + // If the value is "-1" or any other unparseable string, we do not retry. + uint32_t ms; + if (!grpc_parse_slice_to_uint32(GRPC_MDVALUE(*server_pushback_md), &ms)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: not retrying due to server " + "push-back", + calld_->chand_, calld_, this); + } + return false; + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log( + GPR_INFO, + "chand=%p calld=%p attempt=%p: server push-back: retry in %u ms", + calld_->chand_, calld_, this, ms); + } + *server_pushback_ms = static_cast(ms); + } + } + // Check with call dispatch controller. + auto* service_config_call_data = + static_cast( + calld_->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + if (!service_config_call_data->call_dispatch_controller()->ShouldRetry()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log( + GPR_INFO, + "chand=%p calld=%p attempt=%p: call dispatch controller denied retry", + calld_->chand_, calld_, this); + } + return false; + } + // We should retry. + return true; +} + +void RetryFilter::CallData::CallAttempt::Abandon() { + abandoned_ = true; + // Unref batches for deferred completion callbacks that will now never + // be invoked. 
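// For illustration only (not part of the patch): the attempt and status-code
// limits applied above come from the method config's "retryPolicy" block
// (parsed by the parser declared in retry_service_config.h, included
// earlier), while the push-back value is read from the server's
// "grpc-retry-pushback-ms" metadata. A typical policy block, with example
// values:
constexpr char kExampleRetryPolicy[] = R"json({
  "retryPolicy": {
    "maxAttempts": 4,
    "initialBackoff": "0.1s",
    "maxBackoff": "1s",
    "backoffMultiplier": 2,
    "retryableStatusCodes": [ "UNAVAILABLE" ]
  }
})json";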
+ if (started_recv_trailing_metadata_ && + !seen_recv_trailing_metadata_from_surface_) { + recv_trailing_metadata_internal_batch_.reset( + DEBUG_LOCATION, + "internal recv_trailing_metadata completed before that op was " + "started from the surface"); + } + GRPC_ERROR_UNREF(recv_trailing_metadata_error_); + recv_trailing_metadata_error_ = GRPC_ERROR_NONE; + recv_initial_metadata_ready_deferred_batch_.reset( + DEBUG_LOCATION, + "unref deferred recv_initial_metadata_ready batch due to retry"); + GRPC_ERROR_UNREF(recv_initial_metadata_error_); + recv_initial_metadata_error_ = GRPC_ERROR_NONE; + recv_message_ready_deferred_batch_.reset( + DEBUG_LOCATION, "unref deferred recv_message_ready batch due to retry"); + GRPC_ERROR_UNREF(recv_message_error_); + recv_message_error_ = GRPC_ERROR_NONE; + for (auto& on_complete_deferred_batch : on_complete_deferred_batches_) { + on_complete_deferred_batch.batch.reset( + DEBUG_LOCATION, "unref deferred on_complete batch due to retry"); + GRPC_ERROR_UNREF(on_complete_deferred_batch.error); + } + on_complete_deferred_batches_.clear(); +} + +void RetryFilter::CallData::CallAttempt::OnPerAttemptRecvTimer( + void* arg, grpc_error_handle error) { + auto* call_attempt = static_cast(arg); + GRPC_CLOSURE_INIT(&call_attempt->on_per_attempt_recv_timer_, + OnPerAttemptRecvTimerLocked, call_attempt, nullptr); + GRPC_CALL_COMBINER_START(call_attempt->calld_->call_combiner_, + &call_attempt->on_per_attempt_recv_timer_, + GRPC_ERROR_REF(error), "per-attempt timer fired"); +} + +void RetryFilter::CallData::CallAttempt::OnPerAttemptRecvTimerLocked( + void* arg, grpc_error_handle error) { + auto* call_attempt = static_cast(arg); + auto* calld = call_attempt->calld_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: perAttemptRecvTimeout timer fired: " + "error=%s, per_attempt_recv_timer_pending_=%d", + calld->chand_, calld, call_attempt, + grpc_error_std_string(error).c_str(), + call_attempt->per_attempt_recv_timer_pending_); + } + CallCombinerClosureList closures; + if (error == GRPC_ERROR_NONE && + call_attempt->per_attempt_recv_timer_pending_) { + call_attempt->per_attempt_recv_timer_pending_ = false; + // Cancel this attempt. + // TODO(roth): When implementing hedging, we should not cancel the + // current attempt. + call_attempt->AddBatchForCancelOp( + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "retry perAttemptRecvTimeout exceeded"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_CANCELLED), + &closures); + // Check whether we should retry. + if (call_attempt->ShouldRetry( + /*status=*/absl::nullopt, /*is_lb_drop=*/false, + /*server_pushback_md=*/nullptr, /*server_pushback_ms=*/nullptr)) { + // Mark current attempt as abandoned. + call_attempt->Abandon(); + // We are retrying. Start backoff timer. + calld->StartRetryTimer(/*server_pushback_ms=*/-1); + } else { + // Not retrying, so commit the call. + calld->RetryCommit(call_attempt); + // If retry state is no longer needed, switch to fast path for + // subsequent batches. 
+ call_attempt->MaybeSwitchToFastPath(); + } + } + closures.RunClosures(calld->call_combiner_); + call_attempt->Unref(DEBUG_LOCATION, "OnPerAttemptRecvTimer"); + GRPC_CALL_STACK_UNREF(calld->owning_call_, "OnPerAttemptRecvTimer"); +} + +void RetryFilter::CallData::CallAttempt::MaybeCancelPerAttemptRecvTimer() { + if (per_attempt_recv_timer_pending_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: cancelling " + "perAttemptRecvTimeout timer", + calld_->chand_, calld_, this); + } + per_attempt_recv_timer_pending_ = false; + grpc_timer_cancel(&per_attempt_recv_timer_); + } +} + +// +// RetryFilter::CallData::CallAttempt::BatchData +// + +RetryFilter::CallData::CallAttempt::BatchData::BatchData( + RefCountedPtr attempt, int refcount, bool set_on_complete) + : RefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace) ? "BatchData" : nullptr, + refcount), + call_attempt_(std::move(attempt)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: creating batch %p", + call_attempt_->calld_->chand_, call_attempt_->calld_, + call_attempt_.get(), this); + } + // We hold a ref to the call stack for every batch sent on a call attempt. + // This is because some batches on the call attempt may not complete + // until after all of the batches are completed at the surface (because + // each batch that is pending at the surface holds a ref). This + // can happen for replayed send ops, and it can happen for + // recv_initial_metadata and recv_message ops on a call attempt that has + // been abandoned. + GRPC_CALL_STACK_REF(call_attempt_->calld_->owning_call_, "Retry BatchData"); + batch_.payload = &call_attempt_->batch_payload_; + if (set_on_complete) { + GRPC_CLOSURE_INIT(&on_complete_, OnComplete, this, nullptr); + batch_.on_complete = &on_complete_; + } +} + +RetryFilter::CallData::CallAttempt::BatchData::~BatchData() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: destroying batch %p", + call_attempt_->calld_->chand_, call_attempt_->calld_, + call_attempt_.get(), this); + } + GRPC_CALL_STACK_UNREF(call_attempt_->calld_->owning_call_, "Retry BatchData"); + call_attempt_.reset(DEBUG_LOCATION, "~BatchData"); +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + FreeCachedSendOpDataForCompletedBatch() { + auto* calld = call_attempt_->calld_; + // TODO(roth): When we implement hedging, this logic will need to get + // a bit more complex, because there may be other (now abandoned) call + // attempts still using this data. We may need to do some sort of + // ref-counting instead. + if (batch_.send_initial_metadata) { + calld->FreeCachedSendInitialMetadata(); + } + if (batch_.send_message) { + calld->FreeCachedSendMessage(call_attempt_->completed_send_message_count_ - + 1); + } + if (batch_.send_trailing_metadata) { + calld->FreeCachedSendTrailingMetadata(); + } +} + +// +// recv_initial_metadata callback handling +// + +void RetryFilter::CallData::CallAttempt::BatchData:: + MaybeAddClosureForRecvInitialMetadataCallback( + grpc_error_handle error, CallCombinerClosureList* closures) { + // Find pending batch. 
+ PendingBatch* pending = call_attempt_->calld_->PendingBatchFind( + "invoking recv_initial_metadata_ready for", + [](grpc_transport_stream_op_batch* batch) { + return batch->recv_initial_metadata && + batch->payload->recv_initial_metadata + .recv_initial_metadata_ready != nullptr; + }); + if (pending == nullptr) { + GRPC_ERROR_UNREF(error); + return; + } + // Return metadata. + *pending->batch->payload->recv_initial_metadata.recv_initial_metadata = + std::move(call_attempt_->recv_initial_metadata_); + // Propagate trailing_metadata_available. + *pending->batch->payload->recv_initial_metadata.trailing_metadata_available = + call_attempt_->trailing_metadata_available_; + // Update bookkeeping. + // Note: Need to do this before invoking the callback, since invoking + // the callback will result in yielding the call combiner. + grpc_closure* recv_initial_metadata_ready = + pending->batch->payload->recv_initial_metadata + .recv_initial_metadata_ready; + pending->batch->payload->recv_initial_metadata.recv_initial_metadata_ready = + nullptr; + call_attempt_->calld_->MaybeClearPendingBatch(pending); + // Add callback to closures. + closures->Add(recv_initial_metadata_ready, error, + "recv_initial_metadata_ready for pending batch"); +} + +void RetryFilter::CallData::CallAttempt::BatchData::RecvInitialMetadataReady( + void* arg, grpc_error_handle error) { + RefCountedPtr batch_data(static_cast(arg)); + CallAttempt* call_attempt = batch_data->call_attempt_.get(); + CallData* calld = call_attempt->calld_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p batch_data=%p: " + "got recv_initial_metadata_ready, error=%s", + calld->chand_, calld, call_attempt, batch_data.get(), + grpc_error_std_string(error).c_str()); + } + call_attempt->completed_recv_initial_metadata_ = true; + // If this attempt has been abandoned, then we're not going to use the + // result of this recv_initial_metadata op, so do nothing. + if (call_attempt->abandoned_) { + GRPC_CALL_COMBINER_STOP( + calld->call_combiner_, + "recv_initial_metadata_ready for abandoned attempt"); + return; + } + // Cancel per-attempt recv timer, if any. + call_attempt->MaybeCancelPerAttemptRecvTimer(); + // If we're not committed, check the response to see if we need to commit. + if (!calld->retry_committed_) { + // If we got an error or a Trailers-Only response and have not yet gotten + // the recv_trailing_metadata_ready callback, then defer propagating this + // callback back to the surface. We can evaluate whether to retry when + // recv_trailing_metadata comes back. + if (GPR_UNLIKELY((call_attempt->trailing_metadata_available_ || + error != GRPC_ERROR_NONE) && + !call_attempt->completed_recv_trailing_metadata_)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: deferring " + "recv_initial_metadata_ready (Trailers-Only)", + calld->chand_, calld, call_attempt); + } + call_attempt->recv_initial_metadata_ready_deferred_batch_ = + std::move(batch_data); + call_attempt->recv_initial_metadata_error_ = GRPC_ERROR_REF(error); + CallCombinerClosureList closures; + if (error != GRPC_ERROR_NONE) { + call_attempt->AddBatchForCancelOp(GRPC_ERROR_REF(error), &closures); + } + if (!call_attempt->started_recv_trailing_metadata_) { + // recv_trailing_metadata not yet started by application; start it + // ourselves to get status. 
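// Restatement only (no new logic): the deferral condition used by
// RecvInitialMetadataReady() above -- hold the surface callback whenever the
// attempt may still be retried (an error or a Trailers-Only response) and
// trailing metadata has not yet arrived to settle the decision:
bool ShouldDeferRecvInitialMetadataReady(bool trailing_metadata_available,
                                         bool got_error,
                                         bool completed_recv_trailing_metadata) {
  return (trailing_metadata_available || got_error) &&
         !completed_recv_trailing_metadata;
}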
+ call_attempt->AddBatchForInternalRecvTrailingMetadata(&closures); + } + closures.RunClosures(calld->call_combiner_); + return; + } + // Received valid initial metadata, so commit the call. + calld->RetryCommit(call_attempt); + // If retry state is no longer needed, switch to fast path for + // subsequent batches. + call_attempt->MaybeSwitchToFastPath(); + } + // Invoke the callback to return the result to the surface. + CallCombinerClosureList closures; + batch_data->MaybeAddClosureForRecvInitialMetadataCallback( + GRPC_ERROR_REF(error), &closures); + closures.RunClosures(calld->call_combiner_); +} + +// +// recv_message callback handling +// + +void RetryFilter::CallData::CallAttempt::BatchData:: + MaybeAddClosureForRecvMessageCallback(grpc_error_handle error, + CallCombinerClosureList* closures) { + // Find pending op. + PendingBatch* pending = call_attempt_->calld_->PendingBatchFind( + "invoking recv_message_ready for", + [](grpc_transport_stream_op_batch* batch) { + return batch->recv_message && + batch->payload->recv_message.recv_message_ready != nullptr; + }); + if (pending == nullptr) { + GRPC_ERROR_UNREF(error); + return; + } + // Return payload. + *pending->batch->payload->recv_message.recv_message = + std::move(call_attempt_->recv_message_); + // Update bookkeeping. + // Note: Need to do this before invoking the callback, since invoking + // the callback will result in yielding the call combiner. + grpc_closure* recv_message_ready = + pending->batch->payload->recv_message.recv_message_ready; + pending->batch->payload->recv_message.recv_message_ready = nullptr; + call_attempt_->calld_->MaybeClearPendingBatch(pending); + // Add callback to closures. + closures->Add(recv_message_ready, error, + "recv_message_ready for pending batch"); +} + +void RetryFilter::CallData::CallAttempt::BatchData::RecvMessageReady( + void* arg, grpc_error_handle error) { + RefCountedPtr batch_data(static_cast(arg)); + CallAttempt* call_attempt = batch_data->call_attempt_.get(); + CallData* calld = call_attempt->calld_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p batch_data=%p: " + "got recv_message_ready, error=%s", + calld->chand_, calld, call_attempt, batch_data.get(), + grpc_error_std_string(error).c_str()); + } + ++call_attempt->completed_recv_message_count_; + // If this attempt has been abandoned, then we're not going to use the + // result of this recv_message op, so do nothing. + if (call_attempt->abandoned_) { + GRPC_CALL_COMBINER_STOP(calld->call_combiner_, + "recv_message_ready for abandoned attempt"); + return; + } + // Cancel per-attempt recv timer, if any. + call_attempt->MaybeCancelPerAttemptRecvTimer(); + // If we're not committed, check the response to see if we need to commit. + if (!calld->retry_committed_) { + // If we got an error or the payload was nullptr and we have not yet gotten + // the recv_trailing_metadata_ready callback, then defer propagating this + // callback back to the surface. We can evaluate whether to retry when + // recv_trailing_metadata comes back. 
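+    // For example (illustrative scenario): if the server ends the call with
+    // status UNAVAILABLE without sending a response message, recv_message
+    // completes here with a null payload before recv_trailing_metadata;
+    // deferring the callback lets us retry the call instead of surfacing the
+    // null read to the application.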
+ if (GPR_UNLIKELY((call_attempt->recv_message_ == nullptr || + error != GRPC_ERROR_NONE) && + !call_attempt->completed_recv_trailing_metadata_)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: deferring recv_message_ready " + "(nullptr message and recv_trailing_metadata pending)", + calld->chand_, calld, call_attempt); + } + call_attempt->recv_message_ready_deferred_batch_ = std::move(batch_data); + call_attempt->recv_message_error_ = GRPC_ERROR_REF(error); + CallCombinerClosureList closures; + if (error != GRPC_ERROR_NONE) { + call_attempt->AddBatchForCancelOp(GRPC_ERROR_REF(error), &closures); + } + if (!call_attempt->started_recv_trailing_metadata_) { + // recv_trailing_metadata not yet started by application; start it + // ourselves to get status. + call_attempt->AddBatchForInternalRecvTrailingMetadata(&closures); + } + closures.RunClosures(calld->call_combiner_); + return; + } + // Received a valid message, so commit the call. + calld->RetryCommit(call_attempt); + // If retry state is no longer needed, switch to fast path for + // subsequent batches. + call_attempt->MaybeSwitchToFastPath(); + } + // Invoke the callback to return the result to the surface. + CallCombinerClosureList closures; + batch_data->MaybeAddClosureForRecvMessageCallback(GRPC_ERROR_REF(error), + &closures); + closures.RunClosures(calld->call_combiner_); +} + +// +// recv_trailing_metadata handling +// + +namespace { + +// Sets *status, *server_pushback_md, and *is_lb_drop based on md_batch +// and error. +void GetCallStatus(grpc_millis deadline, grpc_metadata_batch* md_batch, + grpc_error_handle error, grpc_status_code* status, + grpc_mdelem** server_pushback_md, bool* is_lb_drop) { + if (error != GRPC_ERROR_NONE) { + grpc_error_get_status(error, deadline, status, nullptr, nullptr, nullptr); + intptr_t value = 0; + if (grpc_error_get_int(error, GRPC_ERROR_INT_LB_POLICY_DROP, &value) && + value != 0) { + *is_lb_drop = true; + } + } else { + GPR_ASSERT(md_batch->legacy_index()->named.grpc_status != nullptr); + *status = grpc_get_status_code_from_metadata( + md_batch->legacy_index()->named.grpc_status->md); + if (md_batch->legacy_index()->named.grpc_retry_pushback_ms != nullptr) { + *server_pushback_md = + &md_batch->legacy_index()->named.grpc_retry_pushback_ms->md; + } + } + GRPC_ERROR_UNREF(error); +} + +} // namespace + +void RetryFilter::CallData::CallAttempt::BatchData:: + MaybeAddClosureForRecvTrailingMetadataReady( + grpc_error_handle error, CallCombinerClosureList* closures) { + auto* calld = call_attempt_->calld_; + // Find pending batch. + PendingBatch* pending = calld->PendingBatchFind( + "invoking recv_trailing_metadata_ready for", + [](grpc_transport_stream_op_batch* batch) { + return batch->recv_trailing_metadata && + batch->payload->recv_trailing_metadata + .recv_trailing_metadata_ready != nullptr; + }); + // If we generated the recv_trailing_metadata op internally via + // AddBatchForInternalRecvTrailingMetadata(), then there will be no + // pending batch. + if (pending == nullptr) { + call_attempt_->recv_trailing_metadata_error_ = error; + return; + } + // Copy transport stats to be delivered up to the surface. + grpc_transport_move_stats( + &call_attempt_->collect_stats_, + pending->batch->payload->recv_trailing_metadata.collect_stats); + // Return metadata. + *pending->batch->payload->recv_trailing_metadata.recv_trailing_metadata = + std::move(call_attempt_->recv_trailing_metadata_); + // Add closure. 
+ closures->Add(pending->batch->payload->recv_trailing_metadata + .recv_trailing_metadata_ready, + error, "recv_trailing_metadata_ready for pending batch"); + // Update bookkeeping. + pending->batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + nullptr; + calld->MaybeClearPendingBatch(pending); +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddClosuresForDeferredCompletionCallbacks( + CallCombinerClosureList* closures) { + // Add closure for deferred recv_initial_metadata_ready. + if (GPR_UNLIKELY(call_attempt_->recv_initial_metadata_ready_deferred_batch_ != + nullptr)) { + MaybeAddClosureForRecvInitialMetadataCallback( + call_attempt_->recv_initial_metadata_error_, closures); + call_attempt_->recv_initial_metadata_ready_deferred_batch_.reset( + DEBUG_LOCATION, "resuming deferred recv_initial_metadata_ready"); + call_attempt_->recv_initial_metadata_error_ = GRPC_ERROR_NONE; + } + // Add closure for deferred recv_message_ready. + if (GPR_UNLIKELY(call_attempt_->recv_message_ready_deferred_batch_ != + nullptr)) { + MaybeAddClosureForRecvMessageCallback(call_attempt_->recv_message_error_, + closures); + call_attempt_->recv_message_ready_deferred_batch_.reset( + DEBUG_LOCATION, "resuming deferred recv_message_ready"); + call_attempt_->recv_message_error_ = GRPC_ERROR_NONE; + } + // Add closures for deferred on_complete callbacks. + for (auto& on_complete_deferred_batch : + call_attempt_->on_complete_deferred_batches_) { + closures->Add(&on_complete_deferred_batch.batch->on_complete_, + on_complete_deferred_batch.error, "resuming on_complete"); + on_complete_deferred_batch.batch.release(); + } + call_attempt_->on_complete_deferred_batches_.clear(); +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddClosuresToFailUnstartedPendingBatches( + grpc_error_handle error, CallCombinerClosureList* closures) { + auto* calld = call_attempt_->calld_; + for (size_t i = 0; i < GPR_ARRAY_SIZE(calld->pending_batches_); ++i) { + PendingBatch* pending = &calld->pending_batches_[i]; + if (pending->batch == nullptr) continue; + if (call_attempt_->PendingBatchContainsUnstartedSendOps(pending)) { + closures->Add(pending->batch->on_complete, GRPC_ERROR_REF(error), + "failing on_complete for pending batch"); + pending->batch->on_complete = nullptr; + calld->MaybeClearPendingBatch(pending); + } + } + GRPC_ERROR_UNREF(error); +} + +void RetryFilter::CallData::CallAttempt::BatchData::RunClosuresForCompletedCall( + grpc_error_handle error) { + // Construct list of closures to execute. + CallCombinerClosureList closures; + // First, add closure for recv_trailing_metadata_ready. + MaybeAddClosureForRecvTrailingMetadataReady(GRPC_ERROR_REF(error), &closures); + // If there are deferred batch completion callbacks, add them to closures. + AddClosuresForDeferredCompletionCallbacks(&closures); + // Add closures to fail any pending batches that have not yet been started. + AddClosuresToFailUnstartedPendingBatches(GRPC_ERROR_REF(error), &closures); + // Schedule all of the closures identified above. + // Note: This will release the call combiner. 
+ closures.RunClosures(call_attempt_->calld_->call_combiner_); + GRPC_ERROR_UNREF(error); +} + +void RetryFilter::CallData::CallAttempt::BatchData::RecvTrailingMetadataReady( + void* arg, grpc_error_handle error) { + RefCountedPtr batch_data(static_cast(arg)); + CallAttempt* call_attempt = batch_data->call_attempt_.get(); + CallData* calld = call_attempt->calld_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p batch_data=%p: " + "got recv_trailing_metadata_ready, error=%s", + calld->chand_, calld, call_attempt, batch_data.get(), + grpc_error_std_string(error).c_str()); + } + call_attempt->completed_recv_trailing_metadata_ = true; + // If this attempt has been abandoned, then we're not going to use the + // result of this recv_trailing_metadata op, so do nothing. + if (call_attempt->abandoned_) { + GRPC_CALL_COMBINER_STOP( + calld->call_combiner_, + "recv_trailing_metadata_ready for abandoned attempt"); + return; + } + // Cancel per-attempt recv timer, if any. + call_attempt->MaybeCancelPerAttemptRecvTimer(); + // Get the call's status and check for server pushback metadata. + grpc_status_code status = GRPC_STATUS_OK; + grpc_mdelem* server_pushback_md = nullptr; + grpc_metadata_batch* md_batch = + batch_data->batch_.payload->recv_trailing_metadata.recv_trailing_metadata; + bool is_lb_drop = false; + GetCallStatus(calld->deadline_, md_batch, GRPC_ERROR_REF(error), &status, + &server_pushback_md, &is_lb_drop); + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log( + GPR_INFO, + "chand=%p calld=%p attempt=%p: call finished, status=%s is_lb_drop=%d", + calld->chand_, calld, call_attempt, grpc_status_code_to_string(status), + is_lb_drop); + } + // Check if we should retry. + grpc_millis server_pushback_ms = -1; + if (call_attempt->ShouldRetry(status, is_lb_drop, server_pushback_md, + &server_pushback_ms)) { + // Start retry timer. + calld->StartRetryTimer(server_pushback_ms); + // Cancel call attempt. + CallCombinerClosureList closures; + call_attempt->AddBatchForCancelOp( + error == GRPC_ERROR_NONE + ? grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("call attempt failed"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_CANCELLED) + : GRPC_ERROR_REF(error), + &closures); + // Record that this attempt has been abandoned. + call_attempt->Abandon(); + // Yields call combiner. + closures.RunClosures(calld->call_combiner_); + return; + } + // Not retrying, so commit the call. + calld->RetryCommit(call_attempt); + // If retry state is no longer needed, switch to fast path for + // subsequent batches. + call_attempt->MaybeSwitchToFastPath(); + // Run any necessary closures. + batch_data->RunClosuresForCompletedCall(GRPC_ERROR_REF(error)); +} + +// +// on_complete callback handling +// + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddClosuresForCompletedPendingBatch(grpc_error_handle error, + CallCombinerClosureList* closures) { + auto* calld = call_attempt_->calld_; + PendingBatch* pending = calld->PendingBatchFind( + "completed", [this](grpc_transport_stream_op_batch* batch) { + // Match the pending batch with the same set of send ops as the + // batch we've just completed. + return batch->on_complete != nullptr && + batch_.send_initial_metadata == batch->send_initial_metadata && + batch_.send_message == batch->send_message && + batch_.send_trailing_metadata == batch->send_trailing_metadata; + }); + // If batch_data is a replay batch, then there will be no pending + // batch to complete. 
+ if (pending == nullptr) { + GRPC_ERROR_UNREF(error); + return; + } + // Propagate payload. + if (batch_.send_message) { + pending->batch->payload->send_message.stream_write_closed = + batch_.payload->send_message.stream_write_closed; + } + // Add closure. + closures->Add(pending->batch->on_complete, error, + "on_complete for pending batch"); + pending->batch->on_complete = nullptr; + calld->MaybeClearPendingBatch(pending); +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddClosuresForReplayOrPendingSendOps(CallCombinerClosureList* closures) { + auto* calld = call_attempt_->calld_; + bool have_pending_send_ops = call_attempt_->HaveSendOpsToReplay(); + // We don't check send_initial_metadata here, because that op will always + // be started as soon as it is received from the surface, so it will + // never need to be started at this point. + if (!have_pending_send_ops) { + for (size_t i = 0; i < GPR_ARRAY_SIZE(calld->pending_batches_); ++i) { + PendingBatch* pending = &calld->pending_batches_[i]; + grpc_transport_stream_op_batch* batch = pending->batch; + if (batch == nullptr || pending->send_ops_cached) continue; + if (batch->send_message || batch->send_trailing_metadata) { + have_pending_send_ops = true; + break; + } + } + } + if (have_pending_send_ops) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p: starting next batch for pending " + "send op(s)", + calld->chand_, calld, call_attempt_.get()); + } + call_attempt_->AddRetriableBatches(closures); + } +} + +void RetryFilter::CallData::CallAttempt::BatchData::OnComplete( + void* arg, grpc_error_handle error) { + RefCountedPtr batch_data(static_cast(arg)); + CallAttempt* call_attempt = batch_data->call_attempt_.get(); + CallData* calld = call_attempt->calld_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p batch_data=%p: " + "got on_complete, error=%s, batch=%s", + calld->chand_, calld, call_attempt, batch_data.get(), + grpc_error_std_string(error).c_str(), + grpc_transport_stream_op_batch_string(&batch_data->batch_).c_str()); + } + // If this attempt has been abandoned, then we're not going to propagate + // the completion of this batch, so do nothing. + if (call_attempt->abandoned_) { + GRPC_CALL_COMBINER_STOP(calld->call_combiner_, + "on_complete for abandoned attempt"); + return; + } + // If we got an error and have not yet gotten the + // recv_trailing_metadata_ready callback, then defer propagating this + // callback back to the surface. We can evaluate whether to retry when + // recv_trailing_metadata comes back. + if (GPR_UNLIKELY(!calld->retry_committed_ && error != GRPC_ERROR_NONE && + !call_attempt->completed_recv_trailing_metadata_)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p attempt=%p: deferring on_complete", + calld->chand_, calld, call_attempt); + } + call_attempt->on_complete_deferred_batches_.emplace_back( + std::move(batch_data), GRPC_ERROR_REF(error)); + CallCombinerClosureList closures; + call_attempt->AddBatchForCancelOp(GRPC_ERROR_REF(error), &closures); + if (!call_attempt->started_recv_trailing_metadata_) { + // recv_trailing_metadata not yet started by application; start it + // ourselves to get status. + call_attempt->AddBatchForInternalRecvTrailingMetadata(&closures); + } + closures.RunClosures(calld->call_combiner_); + return; + } + // Update bookkeeping in call_attempt. 
+ if (batch_data->batch_.send_initial_metadata) { + call_attempt->completed_send_initial_metadata_ = true; + } + if (batch_data->batch_.send_message) { + ++call_attempt->completed_send_message_count_; + } + if (batch_data->batch_.send_trailing_metadata) { + call_attempt->completed_send_trailing_metadata_ = true; + } + // If the call is committed, free cached data for send ops that we've just + // completed. + if (calld->retry_committed_) { + batch_data->FreeCachedSendOpDataForCompletedBatch(); + } + // Construct list of closures to execute. + CallCombinerClosureList closures; + // Add closure for the completed pending batch, if any. + batch_data->AddClosuresForCompletedPendingBatch(GRPC_ERROR_REF(error), + &closures); + // If needed, add a callback to start any replay or pending send ops on + // the LB call. + if (!call_attempt->completed_recv_trailing_metadata_) { + batch_data->AddClosuresForReplayOrPendingSendOps(&closures); + } + // If retry state is no longer needed (i.e., we're committed and there + // are no more send ops to replay), switch to fast path for subsequent + // batches. + call_attempt->MaybeSwitchToFastPath(); + // Schedule all of the closures identified above. + // Note: This yields the call combiner. + closures.RunClosures(calld->call_combiner_); +} + +void RetryFilter::CallData::CallAttempt::BatchData::OnCompleteForCancelOp( + void* arg, grpc_error_handle error) { + RefCountedPtr batch_data(static_cast(arg)); + CallAttempt* call_attempt = batch_data->call_attempt_.get(); + CallData* calld = call_attempt->calld_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p attempt=%p batch_data=%p: " + "got on_complete for cancel_stream batch, error=%s, batch=%s", + calld->chand_, calld, call_attempt, batch_data.get(), + grpc_error_std_string(error).c_str(), + grpc_transport_stream_op_batch_string(&batch_data->batch_).c_str()); + } + GRPC_CALL_COMBINER_STOP( + calld->call_combiner_, + "on_complete for internally generated cancel_stream op"); +} + +// +// retriable batch construction +// + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddRetriableSendInitialMetadataOp() { + auto* calld = call_attempt_->calld_; + // Maps the number of retries to the corresponding metadata value slice. + const grpc_slice* retry_count_strings[] = {&GRPC_MDSTR_1, &GRPC_MDSTR_2, + &GRPC_MDSTR_3, &GRPC_MDSTR_4}; + // We need to make a copy of the metadata batch for each attempt, since + // the filters in the subchannel stack may modify this batch, and we don't + // want those modifications to be passed forward to subsequent attempts. + // + // If we've already completed one or more attempts, add the + // grpc-retry-attempts header. 
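+  // For example: the batch built for the third attempt (two attempts already
+  // completed) carries "grpc-previous-rpc-attempts: 2" in its initial
+  // metadata.  (The attempt count here is illustrative.)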
+ grpc_metadata_batch_copy(&calld->send_initial_metadata_, + &call_attempt_->send_initial_metadata_); + if (GPR_UNLIKELY(call_attempt_->send_initial_metadata_.legacy_index() + ->named.grpc_previous_rpc_attempts != nullptr)) { + call_attempt_->send_initial_metadata_.Remove( + GRPC_BATCH_GRPC_PREVIOUS_RPC_ATTEMPTS); + } + if (GPR_UNLIKELY(calld->num_attempts_completed_ > 0)) { + grpc_mdelem retry_md = grpc_mdelem_create( + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS, + *retry_count_strings[calld->num_attempts_completed_ - 1], nullptr); + grpc_error_handle error = grpc_metadata_batch_add_tail( + &call_attempt_->send_initial_metadata_, + &call_attempt_->retry_attempts_metadata_, retry_md, + GRPC_BATCH_GRPC_PREVIOUS_RPC_ATTEMPTS); + if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { + gpr_log(GPR_ERROR, "error adding retry metadata: %s", + grpc_error_std_string(error).c_str()); + GPR_ASSERT(false); + } + } + call_attempt_->started_send_initial_metadata_ = true; + batch_.send_initial_metadata = true; + batch_.payload->send_initial_metadata.send_initial_metadata = + &call_attempt_->send_initial_metadata_; + batch_.payload->send_initial_metadata.send_initial_metadata_flags = + calld->send_initial_metadata_flags_; + batch_.payload->send_initial_metadata.peer_string = calld->peer_string_; +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddRetriableSendMessageOp() { + auto* calld = call_attempt_->calld_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log( + GPR_INFO, + "chand=%p calld=%p attempt=%p: starting calld->send_messages[%" PRIuPTR + "]", + calld->chand_, calld, call_attempt_.get(), + call_attempt_->started_send_message_count_); + } + ByteStreamCache* cache = + calld->send_messages_[call_attempt_->started_send_message_count_]; + ++call_attempt_->started_send_message_count_; + call_attempt_->send_message_.Init(cache); + batch_.send_message = true; + batch_.payload->send_message.send_message.reset( + call_attempt_->send_message_.get()); +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddRetriableSendTrailingMetadataOp() { + auto* calld = call_attempt_->calld_; + // We need to make a copy of the metadata batch for each attempt, since + // the filters in the subchannel stack may modify this batch, and we don't + // want those modifications to be passed forward to subsequent attempts. 
+ grpc_metadata_batch_copy(&calld->send_trailing_metadata_, + &call_attempt_->send_trailing_metadata_); + call_attempt_->started_send_trailing_metadata_ = true; + batch_.send_trailing_metadata = true; + batch_.payload->send_trailing_metadata.send_trailing_metadata = + &call_attempt_->send_trailing_metadata_; +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddRetriableRecvInitialMetadataOp() { + call_attempt_->started_recv_initial_metadata_ = true; + batch_.recv_initial_metadata = true; + call_attempt_->recv_initial_metadata_.Clear(); + batch_.payload->recv_initial_metadata.recv_initial_metadata = + &call_attempt_->recv_initial_metadata_; + batch_.payload->recv_initial_metadata.trailing_metadata_available = + &call_attempt_->trailing_metadata_available_; + GRPC_CLOSURE_INIT(&call_attempt_->recv_initial_metadata_ready_, + RecvInitialMetadataReady, this, grpc_schedule_on_exec_ctx); + batch_.payload->recv_initial_metadata.recv_initial_metadata_ready = + &call_attempt_->recv_initial_metadata_ready_; +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddRetriableRecvMessageOp() { + ++call_attempt_->started_recv_message_count_; + batch_.recv_message = true; + batch_.payload->recv_message.recv_message = &call_attempt_->recv_message_; + batch_.payload->recv_message.call_failed_before_recv_message = nullptr; + GRPC_CLOSURE_INIT(&call_attempt_->recv_message_ready_, RecvMessageReady, this, + grpc_schedule_on_exec_ctx); + batch_.payload->recv_message.recv_message_ready = + &call_attempt_->recv_message_ready_; +} + +void RetryFilter::CallData::CallAttempt::BatchData:: + AddRetriableRecvTrailingMetadataOp() { + call_attempt_->started_recv_trailing_metadata_ = true; + batch_.recv_trailing_metadata = true; + call_attempt_->recv_trailing_metadata_.Clear(); + batch_.payload->recv_trailing_metadata.recv_trailing_metadata = + &call_attempt_->recv_trailing_metadata_; + batch_.payload->recv_trailing_metadata.collect_stats = + &call_attempt_->collect_stats_; + GRPC_CLOSURE_INIT(&call_attempt_->recv_trailing_metadata_ready_, + RecvTrailingMetadataReady, this, grpc_schedule_on_exec_ctx); + batch_.payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &call_attempt_->recv_trailing_metadata_ready_; +} + +void RetryFilter::CallData::CallAttempt::BatchData::AddCancelStreamOp( + grpc_error_handle error) { + batch_.cancel_stream = true; + batch_.payload->cancel_stream.cancel_error = error; + // Override on_complete callback. + GRPC_CLOSURE_INIT(&on_complete_, OnCompleteForCancelOp, this, nullptr); +} + +// +// CallData vtable functions +// + +grpc_error_handle RetryFilter::CallData::Init( + grpc_call_element* elem, const grpc_call_element_args* args) { + auto* chand = static_cast(elem->channel_data); + new (elem->call_data) CallData(chand, *args); + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: created call", chand, + elem->call_data); + } + return GRPC_ERROR_NONE; +} + +void RetryFilter::CallData::Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* then_schedule_closure) { + auto* calld = static_cast(elem->call_data); + // Save our ref to the CallStackDestructionBarrier until after our + // dtor is invoked. + RefCountedPtr call_stack_destruction_barrier = + std::move(calld->call_stack_destruction_barrier_); + calld->~CallData(); + // Now set the callback in the CallStackDestructionBarrier object, + // right before we release our ref to it (implicitly upon returning). 
+  // The callback will be invoked when the CallStackDestructionBarrier
+  // is destroyed.
+  call_stack_destruction_barrier->set_on_call_stack_destruction(
+      then_schedule_closure);
+}
+
+void RetryFilter::CallData::StartTransportStreamOpBatch(
+    grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
+  auto* calld = static_cast<CallData*>(elem->call_data);
+  calld->StartTransportStreamOpBatch(batch);
+}
+
+void RetryFilter::CallData::SetPollent(grpc_call_element* elem,
+                                       grpc_polling_entity* pollent) {
+  auto* calld = static_cast<CallData*>(elem->call_data);
+  calld->pollent_ = pollent;
+}
+
+//
+// CallData implementation
+//
+
+const RetryMethodConfig* GetRetryPolicy(
+    const grpc_call_context_element* context) {
+  if (context == nullptr) return nullptr;
+  auto* svc_cfg_call_data = static_cast<ServiceConfigCallData*>(
+      context[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value);
+  if (svc_cfg_call_data == nullptr) return nullptr;
+  return static_cast<const RetryMethodConfig*>(
+      svc_cfg_call_data->GetMethodParsedConfig(
+          RetryServiceConfigParser::ParserIndex()));
+}
+
+RetryFilter::CallData::CallData(RetryFilter* chand,
+                                const grpc_call_element_args& args)
+    : chand_(chand),
+      retry_throttle_data_(chand->retry_throttle_data_),
+      retry_policy_(GetRetryPolicy(args.context)),
+      retry_backoff_(
+          BackOff::Options()
+              .set_initial_backoff(retry_policy_ == nullptr
+                                       ? 0
+                                       : retry_policy_->initial_backoff())
+              .set_multiplier(retry_policy_ == nullptr
+                                  ? 0
+                                  : retry_policy_->backoff_multiplier())
+              .set_jitter(RETRY_BACKOFF_JITTER)
+              .set_max_backoff(
+                  retry_policy_ == nullptr ? 0 : retry_policy_->max_backoff())),
+      path_(grpc_slice_ref_internal(args.path)),
+      deadline_(args.deadline),
+      arena_(args.arena),
+      owning_call_(args.call_stack),
+      call_combiner_(args.call_combiner),
+      call_context_(args.context),
+      call_stack_destruction_barrier_(
+          arena_->New<CallStackDestructionBarrier>()),
+      pending_send_initial_metadata_(false),
+      pending_send_message_(false),
+      pending_send_trailing_metadata_(false),
+      retry_committed_(false),
+      retry_timer_pending_(false) {}
+
+RetryFilter::CallData::~CallData() {
+  grpc_slice_unref_internal(path_);
+  // Make sure there are no remaining pending batches.
+  for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) {
+    GPR_ASSERT(pending_batches_[i].batch == nullptr);
+  }
+  GRPC_ERROR_UNREF(cancelled_from_surface_);
+}
+
+void RetryFilter::CallData::StartTransportStreamOpBatch(
+    grpc_transport_stream_op_batch* batch) {
+  // If we have an LB call, delegate to the LB call.
+  if (committed_call_ != nullptr) {
+    // Note: This will release the call combiner.
+    committed_call_->StartTransportStreamOpBatch(batch);
+    return;
+  }
+  // Handle cancellation.
+  if (GPR_UNLIKELY(batch->cancel_stream)) {
+    grpc_error_handle cancel_error = batch->payload->cancel_stream.cancel_error;
+    if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) {
+      gpr_log(GPR_INFO, "chand=%p calld=%p: cancelled from surface: %s", chand_,
+              this, grpc_error_std_string(cancel_error).c_str());
+    }
+    // If we have a current call attempt, commit the call, then send
+    // the cancellation down to that attempt.  When the call fails, it
+    // will not be retried, because we have committed it here.
+ if (call_attempt_ != nullptr) { + RetryCommit(call_attempt_.get()); + // TODO(roth): When implementing hedging, this will get more + // complex, because instead of just passing the batch down to a + // single call attempt, we'll need to cancel multiple call + // attempts and wait for the cancellation on_complete from each call + // attempt before we propagate the on_complete from this batch + // back to the surface. + // Note: This will release the call combiner. + call_attempt_->CancelFromSurface(batch); + return; + } + // Save cancel_error in case subsequent batches are started. + GRPC_ERROR_UNREF(cancelled_from_surface_); + cancelled_from_surface_ = GRPC_ERROR_REF(cancel_error); + // Cancel retry timer. + if (retry_timer_pending_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: cancelling retry timer", chand_, + this); + } + retry_timer_pending_ = false; // Lame timer callback. + grpc_timer_cancel(&retry_timer_); + FreeAllCachedSendOpData(); + } + // Fail pending batches. + PendingBatchesFail(GRPC_ERROR_REF(cancel_error)); + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(cancel_error), call_combiner_); + return; + } + // Add the batch to the pending list. + PendingBatch* pending = PendingBatchesAdd(batch); + // If the timer is pending, yield the call combiner and wait for it to + // run, since we don't want to start another call attempt until it does. + if (retry_timer_pending_) { + GRPC_CALL_COMBINER_STOP(call_combiner_, + "added pending batch while retry timer pending"); + return; + } + // If we do not yet have a call attempt, create one. + if (call_attempt_ == nullptr) { + // If we were previously cancelled from the surface, cancel this + // batch instead of creating a call attempt. + if (cancelled_from_surface_ != GRPC_ERROR_NONE) { + PendingBatchClear(pending); + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(cancelled_from_surface_), call_combiner_); + return; + } + // If there is no retry policy, then commit retries immediately. + // This ensures that the code below will always jump to the fast path. + // TODO(roth): Remove this special case when we implement + // transparent retries. + if (retry_policy_ == nullptr) retry_committed_ = true; + // If this is the first batch and retries are already committed + // (e.g., if this batch put the call above the buffer size limit), then + // immediately create an LB call and delegate the batch to it. This + // avoids the overhead of unnecessarily allocating a CallAttempt + // object or caching any of the send op data. + // Note that we would ideally like to do this also on subsequent + // attempts (e.g., if a batch puts the call above the buffer size + // limit since the last attempt was complete), but in practice that's + // not really worthwhile, because we will almost always have cached and + // completed at least the send_initial_metadata op on the previous + // attempt, which means that we'd need special logic to replay the + // batch anyway, which is exactly what the CallAttempt object provides. + // We also skip this optimization if perAttemptRecvTimeout is set in the + // retry policy, because we need the code in CallAttempt to handle + // the associated timer. 
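+    // For example (illustrative): a unary call whose request message alone
+    // exceeds the retry buffer is committed in PendingBatchesAdd() before we
+    // get here, so it takes the branch below and goes straight to a
+    // LoadBalancedCall with no CallAttempt bookkeeping.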
+ if (num_attempts_completed_ == 0 && retry_committed_ && + (retry_policy_ == nullptr || + !retry_policy_->per_attempt_recv_timeout().has_value())) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: retry committed before first attempt; " + "creating LB call", + chand_, this); + } + PendingBatchClear(pending); + auto* service_config_call_data = + static_cast( + call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + committed_call_ = CreateLoadBalancedCall( + service_config_call_data->call_dispatch_controller()); + committed_call_->StartTransportStreamOpBatch(batch); + return; + } + // Otherwise, create a call attempt. + // The attempt will automatically start any necessary replays or + // pending batches. + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: creating call attempt", chand_, + this); + } + CreateCallAttempt(); + return; + } + // Send batches to call attempt. + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: starting batch on attempt=%p", chand_, + this, call_attempt_.get()); + } + call_attempt_->StartRetriableBatches(); +} + +OrphanablePtr +RetryFilter::CallData::CreateLoadBalancedCall( + ConfigSelector::CallDispatchController* call_dispatch_controller) { + grpc_call_element_args args = {owning_call_, nullptr, call_context_, + path_, /*start_time=*/0, deadline_, + arena_, call_combiner_}; + return chand_->client_channel_->CreateLoadBalancedCall( + args, pollent_, + // This callback holds a ref to the CallStackDestructionBarrier + // object until the LB call is destroyed. + call_stack_destruction_barrier_->MakeLbCallDestructionClosure(this), + call_dispatch_controller, + // TODO(roth): Change this when we support transparent retries. + /*is_transparent_retry=*/false); +} + +void RetryFilter::CallData::CreateCallAttempt() { + call_attempt_ = MakeRefCounted(this); + call_attempt_->StartRetriableBatches(); +} + +// +// send op data caching +// + +void RetryFilter::CallData::MaybeCacheSendOpsForBatch(PendingBatch* pending) { + if (pending->send_ops_cached) return; + pending->send_ops_cached = true; + grpc_transport_stream_op_batch* batch = pending->batch; + // Save a copy of metadata for send_initial_metadata ops. + if (batch->send_initial_metadata) { + seen_send_initial_metadata_ = true; + grpc_metadata_batch* send_initial_metadata = + batch->payload->send_initial_metadata.send_initial_metadata; + grpc_metadata_batch_copy(send_initial_metadata, &send_initial_metadata_); + send_initial_metadata_flags_ = + batch->payload->send_initial_metadata.send_initial_metadata_flags; + peer_string_ = batch->payload->send_initial_metadata.peer_string; + } + // Set up cache for send_message ops. + if (batch->send_message) { + ByteStreamCache* cache = arena_->New( + std::move(batch->payload->send_message.send_message)); + send_messages_.push_back(cache); + } + // Save metadata batch for send_trailing_metadata ops. 
+ if (batch->send_trailing_metadata) { + seen_send_trailing_metadata_ = true; + grpc_metadata_batch* send_trailing_metadata = + batch->payload->send_trailing_metadata.send_trailing_metadata; + grpc_metadata_batch_copy(send_trailing_metadata, &send_trailing_metadata_); + } +} + +void RetryFilter::CallData::FreeCachedSendInitialMetadata() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: destroying send_initial_metadata", + chand_, this); + } + send_initial_metadata_.Clear(); +} + +void RetryFilter::CallData::FreeCachedSendMessage(size_t idx) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: destroying send_messages[%" PRIuPTR "]", chand_, + this, idx); + } + send_messages_[idx]->Destroy(); +} + +void RetryFilter::CallData::FreeCachedSendTrailingMetadata() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: destroying send_trailing_metadata", + chand_, this); + } + send_trailing_metadata_.Clear(); +} + +void RetryFilter::CallData::FreeAllCachedSendOpData() { + if (seen_send_initial_metadata_) { + FreeCachedSendInitialMetadata(); + } + for (size_t i = 0; i < send_messages_.size(); ++i) { + FreeCachedSendMessage(i); + } + if (seen_send_trailing_metadata_) { + FreeCachedSendTrailingMetadata(); + } +} + +// +// pending_batches management +// + +size_t RetryFilter::CallData::GetBatchIndex( + grpc_transport_stream_op_batch* batch) { + if (batch->send_initial_metadata) return 0; + if (batch->send_message) return 1; + if (batch->send_trailing_metadata) return 2; + if (batch->recv_initial_metadata) return 3; + if (batch->recv_message) return 4; + if (batch->recv_trailing_metadata) return 5; + GPR_UNREACHABLE_CODE(return (size_t)-1); +} + +// This is called via the call combiner, so access to calld is synchronized. +RetryFilter::CallData::PendingBatch* RetryFilter::CallData::PendingBatchesAdd( + grpc_transport_stream_op_batch* batch) { + const size_t idx = GetBatchIndex(batch); + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: adding pending batch at index %" PRIuPTR, + chand_, this, idx); + } + PendingBatch* pending = &pending_batches_[idx]; + GPR_ASSERT(pending->batch == nullptr); + pending->batch = batch; + pending->send_ops_cached = false; + // Update state in calld about pending batches. + // Also check if the batch takes us over the retry buffer limit. + // Note: We don't check the size of trailing metadata here, because + // gRPC clients do not send trailing metadata. + if (batch->send_initial_metadata) { + pending_send_initial_metadata_ = true; + bytes_buffered_for_retry_ += batch->payload->send_initial_metadata + .send_initial_metadata->TransportSize(); + } + if (batch->send_message) { + pending_send_message_ = true; + bytes_buffered_for_retry_ += + batch->payload->send_message.send_message->length(); + } + if (batch->send_trailing_metadata) { + pending_send_trailing_metadata_ = true; + } + // TODO(roth): When we implement hedging, if there are currently attempts + // in flight, we will need to pick the one on which the max number of send + // ops have already been sent, and we commit to that attempt. 
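+  // For example (illustrative sizes): with the default retry buffer
+  // (commonly 256 KiB, adjustable via GRPC_ARG_PER_RPC_RETRY_BUFFER_SIZE),
+  // buffering a 300 KiB request message exceeds the limit, so the check
+  // below commits the call and disables further retries for it.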
+ if (GPR_UNLIKELY(bytes_buffered_for_retry_ > + chand_->per_rpc_retry_buffer_size_)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: exceeded retry buffer size, committing", + chand_, this); + } + RetryCommit(call_attempt_.get()); + } + return pending; +} + +void RetryFilter::CallData::PendingBatchClear(PendingBatch* pending) { + if (pending->batch->send_initial_metadata) { + pending_send_initial_metadata_ = false; + } + if (pending->batch->send_message) { + pending_send_message_ = false; + } + if (pending->batch->send_trailing_metadata) { + pending_send_trailing_metadata_ = false; + } + pending->batch = nullptr; +} + +void RetryFilter::CallData::MaybeClearPendingBatch(PendingBatch* pending) { + grpc_transport_stream_op_batch* batch = pending->batch; + // We clear the pending batch if all of its callbacks have been + // scheduled and reset to nullptr. + if (batch->on_complete == nullptr && + (!batch->recv_initial_metadata || + batch->payload->recv_initial_metadata.recv_initial_metadata_ready == + nullptr) && + (!batch->recv_message || + batch->payload->recv_message.recv_message_ready == nullptr) && + (!batch->recv_trailing_metadata || + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready == + nullptr)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: clearing pending batch", chand_, + this); + } + PendingBatchClear(pending); + } +} + +// This is called via the call combiner, so access to calld is synchronized. +void RetryFilter::CallData::FailPendingBatchInCallCombiner( + void* arg, grpc_error_handle error) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + CallData* call = static_cast(batch->handler_private.extra_arg); + // Note: This will release the call combiner. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(error), call->call_combiner_); +} + +// This is called via the call combiner, so access to calld is synchronized. 
+void RetryFilter::CallData::PendingBatchesFail(grpc_error_handle error) { + GPR_ASSERT(error != GRPC_ERROR_NONE); + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + size_t num_batches = 0; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + if (pending_batches_[i].batch != nullptr) ++num_batches; + } + gpr_log(GPR_INFO, + "chand=%p calld=%p: failing %" PRIuPTR " pending batches: %s", + chand_, this, num_batches, grpc_error_std_string(error).c_str()); + } + CallCombinerClosureList closures; + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + PendingBatch* pending = &pending_batches_[i]; + grpc_transport_stream_op_batch* batch = pending->batch; + if (batch != nullptr) { + batch->handler_private.extra_arg = this; + GRPC_CLOSURE_INIT(&batch->handler_private.closure, + FailPendingBatchInCallCombiner, batch, + grpc_schedule_on_exec_ctx); + closures.Add(&batch->handler_private.closure, GRPC_ERROR_REF(error), + "PendingBatchesFail"); + PendingBatchClear(pending); + } + } + closures.RunClosuresWithoutYielding(call_combiner_); + GRPC_ERROR_UNREF(error); +} + +template +RetryFilter::CallData::PendingBatch* RetryFilter::CallData::PendingBatchFind( + const char* log_message, Predicate predicate) { + for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches_); ++i) { + PendingBatch* pending = &pending_batches_[i]; + grpc_transport_stream_op_batch* batch = pending->batch; + if (batch != nullptr && predicate(batch)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: %s pending batch at index %" PRIuPTR, + chand_, this, log_message, i); + } + return pending; + } + } + return nullptr; +} + +// +// retry code +// + +void RetryFilter::CallData::RetryCommit(CallAttempt* call_attempt) { + if (retry_committed_) return; + retry_committed_ = true; + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: committing retries", chand_, this); + } + if (call_attempt != nullptr) { + // If the call attempt's LB call has been committed, inform the call + // dispatch controller that the call has been committed. + // Note: If call_attempt is null, this is happening before the first + // retry attempt is started, in which case we'll just pass the real + // call dispatch controller down into the LB call, and it won't be + // our problem anymore. + if (call_attempt->lb_call_committed()) { + auto* service_config_call_data = + static_cast( + call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + service_config_call_data->call_dispatch_controller()->Commit(); + } + // Free cached send ops. + call_attempt->FreeCachedSendOpDataAfterCommit(); + } +} + +void RetryFilter::CallData::StartRetryTimer(grpc_millis server_pushback_ms) { + // Reset call attempt. + call_attempt_.reset(DEBUG_LOCATION, "StartRetryTimer"); + // Compute backoff delay. + grpc_millis next_attempt_time; + if (server_pushback_ms >= 0) { + next_attempt_time = ExecCtx::Get()->Now() + server_pushback_ms; + retry_backoff_.Reset(); + } else { + next_attempt_time = retry_backoff_.NextAttemptTime(); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_retry_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: retrying failed call in %" PRId64 " ms", chand_, + this, next_attempt_time - ExecCtx::Get()->Now()); + } + // Schedule retry after computed delay. 
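+  // For example (illustrative values): with initialBackoff=1s,
+  // backoffMultiplier=2, and maxBackoff=10s, successive retries are delayed
+  // roughly 1s, 2s, 4s, 8s, 10s, ... (each randomized by the configured
+  // jitter), while a server pushback of, say, 2s schedules the next attempt
+  // 2s from now and also resets the backoff state, as computed above.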
+ GRPC_CLOSURE_INIT(&retry_closure_, OnRetryTimer, this, nullptr); + GRPC_CALL_STACK_REF(owning_call_, "OnRetryTimer"); + retry_timer_pending_ = true; + grpc_timer_init(&retry_timer_, next_attempt_time, &retry_closure_); +} + +void RetryFilter::CallData::OnRetryTimer(void* arg, grpc_error_handle error) { + auto* calld = static_cast(arg); + GRPC_CLOSURE_INIT(&calld->retry_closure_, OnRetryTimerLocked, calld, nullptr); + GRPC_CALL_COMBINER_START(calld->call_combiner_, &calld->retry_closure_, + GRPC_ERROR_REF(error), "retry timer fired"); +} + +void RetryFilter::CallData::OnRetryTimerLocked(void* arg, + grpc_error_handle error) { + auto* calld = static_cast(arg); + if (error == GRPC_ERROR_NONE && calld->retry_timer_pending_) { + calld->retry_timer_pending_ = false; + calld->CreateCallAttempt(); + } else { + GRPC_CALL_COMBINER_STOP(calld->call_combiner_, "retry timer cancelled"); + } + GRPC_CALL_STACK_UNREF(calld->owning_call_, "OnRetryTimer"); +} + +} // namespace + +const grpc_channel_filter kRetryFilterVtable = { + RetryFilter::CallData::StartTransportStreamOpBatch, + RetryFilter::StartTransportOp, + sizeof(RetryFilter::CallData), + RetryFilter::CallData::Init, + RetryFilter::CallData::SetPollent, + RetryFilter::CallData::Destroy, + sizeof(RetryFilter), + RetryFilter::Init, + RetryFilter::Destroy, + RetryFilter::GetChannelInfo, + "retry_filter", +}; + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/retry_service_config.cc b/src/core/ext/filters/client_channel/retry_service_config.cc new file mode 100644 index 00000000..7d83a949 --- /dev/null +++ b/src/core/ext/filters/client_channel/retry_service_config.cc @@ -0,0 +1,316 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/client_channel/retry_service_config.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/json/json_util.h" +#include "src/core/lib/uri/uri_parser.h" + +// As per the retry design, we do not allow more than 5 retry attempts. 
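+// Illustrative example of the service config JSON this parser consumes
+// (field names follow the retry design; all values below are examples only):
+//
+//   {
+//     "retryThrottling": { "maxTokens": 10, "tokenRatio": 0.1 },
+//     "methodConfig": [ {
+//       "name": [ { "service": "example.ExampleService" } ],
+//       "retryPolicy": {
+//         "maxAttempts": 3,
+//         "initialBackoff": "1s",
+//         "maxBackoff": "10s",
+//         "backoffMultiplier": 2,
+//         "retryableStatusCodes": [ "UNAVAILABLE" ]
+//       }
+//     } ]
+//   }
+//
+// With the throttling settings above, each retryable failure drains one
+// token and each success refills tokenRatio tokens; retries are disabled
+// whenever fewer than half of maxTokens remain (see retry_throttle.cc).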
+#define MAX_MAX_RETRY_ATTEMPTS 5
+
+namespace grpc_core {
+namespace internal {
+
+namespace {
+size_t g_retry_service_config_parser_index;
+}
+
+size_t RetryServiceConfigParser::ParserIndex() {
+  return g_retry_service_config_parser_index;
+}
+
+void RetryServiceConfigParser::Register() {
+  g_retry_service_config_parser_index = ServiceConfigParser::RegisterParser(
+      absl::make_unique<RetryServiceConfigParser>());
+}
+
+namespace {
+
+grpc_error_handle ParseRetryThrottling(const Json& json,
+                                       intptr_t* max_milli_tokens,
+                                       intptr_t* milli_token_ratio) {
+  if (json.type() != Json::Type::OBJECT) {
+    return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+        "field:retryThrottling error:Type should be object");
+  }
+  std::vector<grpc_error_handle> error_list;
+  // Parse maxTokens.
+  auto it = json.object_value().find("maxTokens");
+  if (it == json.object_value().end()) {
+    error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+        "field:retryThrottling field:maxTokens error:Not found"));
+  } else if (it->second.type() != Json::Type::NUMBER) {
+    error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+        "field:retryThrottling field:maxTokens error:Type should be "
+        "number"));
+  } else {
+    *max_milli_tokens =
+        gpr_parse_nonnegative_int(it->second.string_value().c_str()) * 1000;
+    if (*max_milli_tokens <= 0) {
+      error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+          "field:retryThrottling field:maxTokens error:should be "
+          "greater than zero"));
+    }
+  }
+  // Parse tokenRatio.
+  it = json.object_value().find("tokenRatio");
+  if (it == json.object_value().end()) {
+    error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+        "field:retryThrottling field:tokenRatio error:Not found"));
+  } else if (it->second.type() != Json::Type::NUMBER) {
+    error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+        "field:retryThrottling field:tokenRatio error:type should be "
+        "number"));
+  } else {
+    // We support up to 3 decimal digits.
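+    // For example: "0.316" parses to a milli_token_ratio of 316 and "1.5"
+    // parses to 1500; digits past the third decimal place are ignored.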
+ size_t whole_len = it->second.string_value().size(); + const char* value = it->second.string_value().c_str(); + uint32_t multiplier = 1; + uint32_t decimal_value = 0; + const char* decimal_point = strchr(value, '.'); + if (decimal_point != nullptr) { + whole_len = static_cast(decimal_point - value); + multiplier = 1000; + size_t decimal_len = strlen(decimal_point + 1); + if (decimal_len > 3) decimal_len = 3; + if (!gpr_parse_bytes_to_uint32(decimal_point + 1, decimal_len, + &decimal_value)) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryThrottling field:tokenRatio error:Failed " + "parsing")); + return GRPC_ERROR_CREATE_FROM_VECTOR("retryThrottling", &error_list); + } + uint32_t decimal_multiplier = 1; + for (size_t i = 0; i < (3 - decimal_len); ++i) { + decimal_multiplier *= 10; + } + decimal_value *= decimal_multiplier; + } + uint32_t whole_value; + if (!gpr_parse_bytes_to_uint32(value, whole_len, &whole_value)) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryThrottling field:tokenRatio error:Failed " + "parsing")); + return GRPC_ERROR_CREATE_FROM_VECTOR("retryThrottling", &error_list); + } + *milli_token_ratio = + static_cast((whole_value * multiplier) + decimal_value); + if (*milli_token_ratio <= 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryThrottling field:tokenRatio error:value should " + "be greater than 0")); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("retryThrottling", &error_list); +} + +} // namespace + +std::unique_ptr +RetryServiceConfigParser::ParseGlobalParams(const grpc_channel_args* /*args*/, + const Json& json, + grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + auto it = json.object_value().find("retryThrottling"); + if (it == json.object_value().end()) return nullptr; + intptr_t max_milli_tokens = 0; + intptr_t milli_token_ratio = 0; + *error = + ParseRetryThrottling(it->second, &max_milli_tokens, &milli_token_ratio); + if (*error != GRPC_ERROR_NONE) return nullptr; + return absl::make_unique(max_milli_tokens, + milli_token_ratio); +} + +namespace { + +grpc_error_handle ParseRetryPolicy( + const grpc_channel_args* args, const Json& json, int* max_attempts, + grpc_millis* initial_backoff, grpc_millis* max_backoff, + float* backoff_multiplier, StatusCodeSet* retryable_status_codes, + absl::optional* per_attempt_recv_timeout) { + if (json.type() != Json::Type::OBJECT) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryPolicy error:should be of type object"); + } + std::vector error_list; + // Parse maxAttempts. + auto it = json.object_value().find("maxAttempts"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxAttempts error:required field missing")); + } else { + if (it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxAttempts error:should be of type number")); + } else { + *max_attempts = + gpr_parse_nonnegative_int(it->second.string_value().c_str()); + if (*max_attempts <= 1) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxAttempts error:should be at least 2")); + } else if (*max_attempts > MAX_MAX_RETRY_ATTEMPTS) { + gpr_log(GPR_ERROR, + "service config: clamped retryPolicy.maxAttempts at %d", + MAX_MAX_RETRY_ATTEMPTS); + *max_attempts = MAX_MAX_RETRY_ATTEMPTS; + } + } + } + // Parse initialBackoff. 
+ if (ParseJsonObjectFieldAsDuration(json.object_value(), "initialBackoff", + initial_backoff, &error_list) && + *initial_backoff == 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:initialBackoff error:must be greater than 0")); + } + // Parse maxBackoff. + if (ParseJsonObjectFieldAsDuration(json.object_value(), "maxBackoff", + max_backoff, &error_list) && + *max_backoff == 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxBackoff error:must be greater than 0")); + } + // Parse backoffMultiplier. + it = json.object_value().find("backoffMultiplier"); + if (it == json.object_value().end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:backoffMultiplier error:required field missing")); + } else { + if (it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:backoffMultiplier error:should be of type number")); + } else { + if (sscanf(it->second.string_value().c_str(), "%f", backoff_multiplier) != + 1) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:backoffMultiplier error:failed to parse")); + } else if (*backoff_multiplier <= 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:backoffMultiplier error:must be greater than 0")); + } + } + } + // Parse retryableStatusCodes. + it = json.object_value().find("retryableStatusCodes"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryableStatusCodes error:must be of type array")); + } else { + for (const Json& element : it->second.array_value()) { + if (element.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryableStatusCodes error:status codes should be of type " + "string")); + continue; + } + grpc_status_code status; + if (!grpc_status_code_from_string(element.string_value().c_str(), + &status)) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryableStatusCodes error:failed to parse status code")); + continue; + } + retryable_status_codes->Add(status); + } + } + } + // Parse perAttemptRecvTimeout. + if (grpc_channel_args_find_bool(args, GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING, + false)) { + it = json.object_value().find("perAttemptRecvTimeout"); + if (it != json.object_value().end()) { + grpc_millis per_attempt_recv_timeout_value; + if (!ParseDurationFromJson(it->second, &per_attempt_recv_timeout_value)) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:perAttemptRecvTimeout error:type must be STRING of the " + "form given by google.proto.Duration.")); + } else { + *per_attempt_recv_timeout = per_attempt_recv_timeout_value; + // TODO(roth): As part of implementing hedging, relax this check such + // that we allow a value of 0 if a hedging policy is specified. + if (per_attempt_recv_timeout_value == 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:perAttemptRecvTimeout error:must be greater than 0")); + } + } + } else if (retryable_status_codes->Empty()) { + // If perAttemptRecvTimeout not present, retryableStatusCodes must be + // non-empty. + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryableStatusCodes error:must be non-empty if " + "perAttemptRecvTimeout not present")); + } + } else { + // Hedging not enabled, so the error message for + // retryableStatusCodes unset should be different. 
+ if (retryable_status_codes->Empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:retryableStatusCodes error:must be non-empty")); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("retryPolicy", &error_list); +} + +} // namespace + +std::unique_ptr +RetryServiceConfigParser::ParsePerMethodParams(const grpc_channel_args* args, + const Json& json, + grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + // Parse retry policy. + auto it = json.object_value().find("retryPolicy"); + if (it == json.object_value().end()) return nullptr; + int max_attempts = 0; + grpc_millis initial_backoff = 0; + grpc_millis max_backoff = 0; + float backoff_multiplier = 0; + StatusCodeSet retryable_status_codes; + absl::optional per_attempt_recv_timeout; + *error = ParseRetryPolicy(args, it->second, &max_attempts, &initial_backoff, + &max_backoff, &backoff_multiplier, + &retryable_status_codes, &per_attempt_recv_timeout); + if (*error != GRPC_ERROR_NONE) return nullptr; + return absl::make_unique( + max_attempts, initial_backoff, max_backoff, backoff_multiplier, + retryable_status_codes, per_attempt_recv_timeout); +} + +} // namespace internal +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/retry_throttle.cc b/src/core/ext/filters/client_channel/retry_throttle.cc new file mode 100644 index 00000000..ddeb13f6 --- /dev/null +++ b/src/core/ext/filters/client_channel/retry_throttle.cc @@ -0,0 +1,162 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/retry_throttle.h" + +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/manual_constructor.h" + +namespace grpc_core { +namespace internal { + +// +// ServerRetryThrottleData +// + +ServerRetryThrottleData::ServerRetryThrottleData( + intptr_t max_milli_tokens, intptr_t milli_token_ratio, + ServerRetryThrottleData* old_throttle_data) + : max_milli_tokens_(max_milli_tokens), + milli_token_ratio_(milli_token_ratio) { + intptr_t initial_milli_tokens = max_milli_tokens; + // If there was a pre-existing entry for this server name, initialize + // the token count by scaling proportionately to the old data. This + // ensures that if we're already throttling retries on the old scale, + // we will start out doing the same thing on the new one. + if (old_throttle_data != nullptr) { + double token_fraction = + static_cast( + gpr_atm_acq_load(&old_throttle_data->milli_tokens_)) / + static_cast(old_throttle_data->max_milli_tokens_); + initial_milli_tokens = + static_cast(token_fraction * max_milli_tokens); + } + gpr_atm_rel_store(&milli_tokens_, static_cast(initial_milli_tokens)); + // If there was a pre-existing entry, mark it as stale and give it a + // pointer to the new entry, which is its replacement. + if (old_throttle_data != nullptr) { + Ref().release(); // Ref held by pre-existing entry. 
+ gpr_atm_rel_store(&old_throttle_data->replacement_, + reinterpret_cast(this)); + } +} + +ServerRetryThrottleData::~ServerRetryThrottleData() { + ServerRetryThrottleData* replacement = + reinterpret_cast( + gpr_atm_acq_load(&replacement_)); + if (replacement != nullptr) { + replacement->Unref(); + } +} + +void ServerRetryThrottleData::GetReplacementThrottleDataIfNeeded( + ServerRetryThrottleData** throttle_data) { + while (true) { + ServerRetryThrottleData* new_throttle_data = + reinterpret_cast( + gpr_atm_acq_load(&(*throttle_data)->replacement_)); + if (new_throttle_data == nullptr) return; + *throttle_data = new_throttle_data; + } +} + +bool ServerRetryThrottleData::RecordFailure() { + // First, check if we are stale and need to be replaced. + ServerRetryThrottleData* throttle_data = this; + GetReplacementThrottleDataIfNeeded(&throttle_data); + // We decrement milli_tokens by 1000 (1 token) for each failure. + const intptr_t new_value = + static_cast(gpr_atm_no_barrier_clamped_add( + &throttle_data->milli_tokens_, static_cast(-1000), + static_cast(0), + static_cast(throttle_data->max_milli_tokens_))); + // Retries are allowed as long as the new value is above the threshold + // (max_milli_tokens / 2). + return new_value > throttle_data->max_milli_tokens_ / 2; +} + +void ServerRetryThrottleData::RecordSuccess() { + // First, check if we are stale and need to be replaced. + ServerRetryThrottleData* throttle_data = this; + GetReplacementThrottleDataIfNeeded(&throttle_data); + // We increment milli_tokens by milli_token_ratio for each success. + gpr_atm_no_barrier_clamped_add( + &throttle_data->milli_tokens_, + static_cast(throttle_data->milli_token_ratio_), + static_cast(0), + static_cast(throttle_data->max_milli_tokens_)); +} + +// +// ServerRetryThrottleMap +// + +using StringToDataMap = + std::map>; +static gpr_mu g_mu; +static StringToDataMap* g_map; + +void ServerRetryThrottleMap::Init() { + gpr_mu_init(&g_mu); + g_map = new StringToDataMap(); +} + +void ServerRetryThrottleMap::Shutdown() { + gpr_mu_destroy(&g_mu); + delete g_map; + g_map = nullptr; +} + +RefCountedPtr ServerRetryThrottleMap::GetDataForServer( + const std::string& server_name, intptr_t max_milli_tokens, + intptr_t milli_token_ratio) { + RefCountedPtr result; + gpr_mu_lock(&g_mu); + auto it = g_map->find(server_name); + ServerRetryThrottleData* throttle_data = + it == g_map->end() ? nullptr : it->second.get(); + if (throttle_data == nullptr || + throttle_data->max_milli_tokens() != max_milli_tokens || + throttle_data->milli_token_ratio() != milli_token_ratio) { + // Entry not found, or found with old parameters. Create a new one. + it = g_map + ->emplace(server_name, + MakeRefCounted( + max_milli_tokens, milli_token_ratio, throttle_data)) + .first; + throttle_data = it->second.get(); + } + gpr_mu_unlock(&g_mu); + return throttle_data->Ref(); +} + +} // namespace internal +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/server_address.cc b/src/core/ext/filters/client_channel/server_address.cc new file mode 100644 index 00000000..e082256b --- /dev/null +++ b/src/core/ext/filters/client_channel/server_address.cc @@ -0,0 +1,170 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
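Editor's note: ServerRetryThrottleData above is a token bucket kept in milli-tokens: each failure costs 1000 milli-tokens (one token), each success refunds milli_token_ratio, the balance is clamped to [0, max_milli_tokens], and retries are allowed only while the balance stays above half of the maximum. A single-threaded sketch of the same arithmetic follows; the class and method names are illustrative, and the real code uses clamped atomic adds so many calls can record outcomes concurrently.

#include <algorithm>
#include <cstdio>

class RetryThrottle {
 public:
  RetryThrottle(long max_milli_tokens, long milli_token_ratio)
      : max_milli_tokens_(max_milli_tokens),
        milli_token_ratio_(milli_token_ratio),
        milli_tokens_(max_milli_tokens) {}

  // Returns true if retries are still allowed after recording the failure.
  bool RecordFailure() {
    milli_tokens_ = std::max(0L, milli_tokens_ - 1000);
    return milli_tokens_ > max_milli_tokens_ / 2;
  }

  void RecordSuccess() {
    milli_tokens_ =
        std::min(max_milli_tokens_, milli_tokens_ + milli_token_ratio_);
  }

 private:
  const long max_milli_tokens_;
  const long milli_token_ratio_;
  long milli_tokens_;
};

int main() {
  RetryThrottle throttle(/*max_milli_tokens=*/10000, /*milli_token_ratio=*/100);
  // Five consecutive failures drain 5000 milli-tokens, hitting the 50% cutoff.
  for (int i = 0; i < 5; ++i) {
    std::printf("failure %d -> retries allowed: %s\n", i + 1,
                throttle.RecordFailure() ? "yes" : "no");
  }
}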
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/client_channel/server_address.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +#include "src/core/lib/address_utils/sockaddr_utils.h" + +namespace grpc_core { + +// +// ServerAddressWeightAttribute +// +const char* ServerAddressWeightAttribute::kServerAddressWeightAttributeKey = + "server_address_weight"; + +// +// ServerAddress +// + +ServerAddress::ServerAddress( + const grpc_resolved_address& address, grpc_channel_args* args, + std::map> attributes) + : address_(address), args_(args), attributes_(std::move(attributes)) {} + +ServerAddress::ServerAddress( + const void* address, size_t address_len, grpc_channel_args* args, + std::map> attributes) + : args_(args), attributes_(std::move(attributes)) { + memcpy(address_.addr, address, address_len); + address_.len = static_cast(address_len); +} + +ServerAddress::ServerAddress(const ServerAddress& other) + : address_(other.address_), args_(grpc_channel_args_copy(other.args_)) { + for (const auto& p : other.attributes_) { + attributes_[p.first] = p.second->Copy(); + } +} +ServerAddress& ServerAddress::operator=(const ServerAddress& other) { + if (&other == this) { + return *this; + } + address_ = other.address_; + grpc_channel_args_destroy(args_); + args_ = grpc_channel_args_copy(other.args_); + attributes_.clear(); + for (const auto& p : other.attributes_) { + attributes_[p.first] = p.second->Copy(); + } + return *this; +} + +ServerAddress::ServerAddress(ServerAddress&& other) noexcept + : address_(other.address_), + args_(other.args_), + attributes_(std::move(other.attributes_)) { + other.args_ = nullptr; +} +ServerAddress& ServerAddress::operator=(ServerAddress&& other) noexcept { + address_ = other.address_; + grpc_channel_args_destroy(args_); + args_ = other.args_; + other.args_ = nullptr; + attributes_ = std::move(other.attributes_); + return *this; +} + +namespace { + +int CompareAttributes( + const std::map>& + attributes1, + const std::map>& + attributes2) { + auto it2 = attributes2.begin(); + for (auto it1 = attributes1.begin(); it1 != attributes1.end(); ++it1) { + // attributes2 has fewer elements than attributes1 + if (it2 == attributes2.end()) return -1; + // compare keys + int retval = strcmp(it1->first, it2->first); + if (retval != 0) return retval; + // compare values + retval = it1->second->Cmp(it2->second.get()); + if (retval != 0) return retval; + ++it2; + } + // attributes1 has fewer elements than attributes2 + if (it2 != attributes2.end()) return 1; + // equal + return 0; +} + +} // namespace + +int ServerAddress::Cmp(const ServerAddress& other) const { + if (address_.len > other.address_.len) return 1; + if (address_.len < other.address_.len) return -1; + int retval = memcmp(address_.addr, other.address_.addr, address_.len); + if (retval != 0) return retval; + retval = grpc_channel_args_compare(args_, other.args_); + if (retval != 0) return retval; + return CompareAttributes(attributes_, other.attributes_); +} + +const ServerAddress::AttributeInterface* ServerAddress::GetAttribute( + const 
char* key) const { + auto it = attributes_.find(key); + if (it == attributes_.end()) return nullptr; + return it->second.get(); +} + +// Returns a copy of the address with a modified attribute. +// If the new value is null, the attribute is removed. +ServerAddress ServerAddress::WithAttribute( + const char* key, std::unique_ptr value) const { + ServerAddress address = *this; + if (value == nullptr) { + address.attributes_.erase(key); + } else { + address.attributes_[key] = std::move(value); + } + return address; +} + +std::string ServerAddress::ToString() const { + std::vector parts = { + grpc_sockaddr_to_string(&address_, false), + }; + if (args_ != nullptr) { + parts.emplace_back( + absl::StrCat("args={", grpc_channel_args_string(args_), "}")); + } + if (!attributes_.empty()) { + std::vector attrs; + for (const auto& p : attributes_) { + attrs.emplace_back(absl::StrCat(p.first, "=", p.second->ToString())); + } + parts.emplace_back( + absl::StrCat("attributes={", absl::StrJoin(attrs, ", "), "}")); + } + return absl::StrJoin(parts, " "); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/service_config_channel_arg_filter.cc b/src/core/ext/filters/client_channel/service_config_channel_arg_filter.cc new file mode 100644 index 00000000..19054299 --- /dev/null +++ b/src/core/ext/filters/client_channel/service_config_channel_arg_filter.cc @@ -0,0 +1,156 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// This filter reads GRPC_ARG_SERVICE_CONFIG and populates ServiceConfigCallData +// in the call context per call for direct channels. 
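Editor's note: ServerAddress::Cmp and CompareAttributes above build a total order by comparing the cheap fields first and returning the first non-zero result, and WithAttribute returns a modified copy rather than mutating in place. The sketch below shows both ideas over a plain std::map; the Endpoint type and its exact tie-breaking rules are illustrative, not the gRPC comparison verbatim.

#include <iostream>
#include <map>
#include <string>

struct Endpoint {
  std::string address;
  std::map<std::string, std::string> attributes;

  // Three-way compare: address first, then attributes element by element,
  // treating the shorter map as "less" when one is a prefix of the other.
  int Cmp(const Endpoint& other) const {
    if (int r = address.compare(other.address)) return r;
    auto it2 = other.attributes.begin();
    for (auto it1 = attributes.begin(); it1 != attributes.end(); ++it1) {
      if (it2 == other.attributes.end()) return 1;  // other has fewer entries
      if (int r = it1->first.compare(it2->first)) return r;
      if (int r = it1->second.compare(it2->second)) return r;
      ++it2;
    }
    return it2 == other.attributes.end() ? 0 : -1;  // we have fewer entries
  }

  // Copy-with-modification: an empty value removes the attribute.
  Endpoint WithAttribute(const std::string& key,
                         const std::string& value) const {
    Endpoint copy = *this;
    if (value.empty()) copy.attributes.erase(key);
    else copy.attributes[key] = value;
    return copy;
  }
};

int main() {
  Endpoint a{"10.0.0.1:443", {{"weight", "10"}}};
  Endpoint b = a.WithAttribute("weight", "20");
  std::cout << a.Cmp(b) << "\n";  // negative: "10" sorts before "20"
}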
+ +#include + +#include "src/core/ext/service_config/service_config_call_data.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" + +namespace grpc_core { + +namespace { + +class ServiceConfigChannelArgChannelData { + public: + explicit ServiceConfigChannelArgChannelData( + const grpc_channel_element_args* args) { + const char* service_config_str = grpc_channel_args_find_string( + args->channel_args, GRPC_ARG_SERVICE_CONFIG); + if (service_config_str != nullptr) { + grpc_error_handle service_config_error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + args->channel_args, service_config_str, &service_config_error); + if (service_config_error == GRPC_ERROR_NONE) { + service_config_ = std::move(service_config); + } else { + gpr_log(GPR_ERROR, "%s", + grpc_error_std_string(service_config_error).c_str()); + } + GRPC_ERROR_UNREF(service_config_error); + } + } + + RefCountedPtr service_config() const { + return service_config_; + } + + private: + RefCountedPtr service_config_; +}; + +class ServiceConfigChannelArgCallData { + public: + ServiceConfigChannelArgCallData( + RefCountedPtr service_config, + const ServiceConfigParser::ParsedConfigVector* method_config, + const grpc_call_element_args* args) + : call_context_(args->context), + service_config_call_data_(std::move(service_config), method_config, + /*call_attributes=*/{}) { + GPR_DEBUG_ASSERT(args->context != nullptr); + // No need to set the destroy function, since it will be cleaned up + // when this filter is destroyed in the filter stack. + args->context[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value = + &service_config_call_data_; + } + + ~ServiceConfigChannelArgCallData() { + // Remove the entry from call context, just in case anyone above us + // tries to look at it during call stack destruction. 
+ call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value = nullptr; + } + + private: + grpc_call_context_element* call_context_; + ServiceConfigCallData service_config_call_data_; +}; + +grpc_error_handle ServiceConfigChannelArgInitCallElem( + grpc_call_element* elem, const grpc_call_element_args* args) { + auto* chand = + static_cast(elem->channel_data); + auto* calld = static_cast(elem->call_data); + RefCountedPtr service_config = chand->service_config(); + const ServiceConfigParser::ParsedConfigVector* method_config = nullptr; + if (service_config != nullptr) { + method_config = service_config->GetMethodParsedConfigVector(args->path); + } + new (calld) ServiceConfigChannelArgCallData(std::move(service_config), + method_config, args); + return GRPC_ERROR_NONE; +} + +void ServiceConfigChannelArgDestroyCallElem( + grpc_call_element* elem, const grpc_call_final_info* /* final_info */, + grpc_closure* /* then_schedule_closure */) { + ServiceConfigChannelArgCallData* calld = + static_cast(elem->call_data); + calld->~ServiceConfigChannelArgCallData(); +} + +grpc_error_handle ServiceConfigChannelArgInitChannelElem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + ServiceConfigChannelArgChannelData* chand = + static_cast(elem->channel_data); + new (chand) ServiceConfigChannelArgChannelData(args); + return GRPC_ERROR_NONE; +} + +void ServiceConfigChannelArgDestroyChannelElem(grpc_channel_element* elem) { + ServiceConfigChannelArgChannelData* chand = + static_cast(elem->channel_data); + chand->~ServiceConfigChannelArgChannelData(); +} + +const grpc_channel_filter ServiceConfigChannelArgFilter = { + grpc_call_next_op, + grpc_channel_next_op, + sizeof(ServiceConfigChannelArgCallData), + ServiceConfigChannelArgInitCallElem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + ServiceConfigChannelArgDestroyCallElem, + sizeof(ServiceConfigChannelArgChannelData), + ServiceConfigChannelArgInitChannelElem, + ServiceConfigChannelArgDestroyChannelElem, + grpc_channel_next_get_info, + "service_config_channel_arg"}; + +} // namespace + +void RegisterServiceConfigChannelArgFilter( + CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage( + GRPC_CLIENT_DIRECT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [](grpc_channel_stack_builder* builder) { + const grpc_channel_args* channel_args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (grpc_channel_args_want_minimal_stack(channel_args) || + grpc_channel_args_find_string(channel_args, + GRPC_ARG_SERVICE_CONFIG) == nullptr) { + return true; + } + return grpc_channel_stack_builder_prepend_filter( + builder, &ServiceConfigChannelArgFilter, nullptr, nullptr); + }); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/subchannel.cc b/src/core/ext/filters/client_channel/subchannel.cc new file mode 100644 index 00000000..39ded508 --- /dev/null +++ b/src/core/ext/filters/client_channel/subchannel.cc @@ -0,0 +1,1021 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
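Editor's note: the filter above never heap-allocates its state. The channel stack hands it a raw slot of sizeof(ServiceConfigChannelArgCallData) bytes inside the call element, the init hook constructs the object there with placement new, and the destroy hook invokes the destructor explicitly. A minimal standalone sketch of that lifecycle follows; the CallData type and the stack-based slot are illustrative stand-ins for the element storage.

#include <cstdio>
#include <new>
#include <string>

struct CallData {
  explicit CallData(std::string path) : path(std::move(path)) {
    std::printf("init call for %s\n", this->path.c_str());
  }
  ~CallData() { std::printf("destroy call for %s\n", path.c_str()); }
  std::string path;
};

int main() {
  // Stand-in for the per-call slot the channel stack reserves for the filter.
  alignas(CallData) unsigned char slot[sizeof(CallData)];

  // InitCallElem-style: construct in place, no separate allocation.
  CallData* calld = new (slot) CallData("/helloworld.Greeter/SayHello");

  // DestroyCallElem-style: run the destructor by hand; the storage itself is
  // owned by the surrounding arena, so there is no delete.
  calld->~CallData();
}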
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/client_channel/subchannel.h" + +#include +#include + +#include +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/health/health_check_client.h" +#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/ext/filters/client_channel/subchannel_pool_interface.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/alloc.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/status_metadata.h" +#include "src/core/lib/uri/uri_parser.h" + +// Strong and weak refs. +#define INTERNAL_REF_BITS 16 +#define STRONG_REF_MASK (~(gpr_atm)((1 << INTERNAL_REF_BITS) - 1)) + +// Backoff parameters. +#define GRPC_SUBCHANNEL_INITIAL_CONNECT_BACKOFF_SECONDS 1 +#define GRPC_SUBCHANNEL_RECONNECT_BACKOFF_MULTIPLIER 1.6 +#define GRPC_SUBCHANNEL_RECONNECT_MIN_TIMEOUT_SECONDS 20 +#define GRPC_SUBCHANNEL_RECONNECT_MAX_BACKOFF_SECONDS 120 +#define GRPC_SUBCHANNEL_RECONNECT_JITTER 0.2 + +// Conversion between subchannel call and call stack. +#define SUBCHANNEL_CALL_TO_CALL_STACK(call) \ + (grpc_call_stack*)((char*)(call) + \ + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(SubchannelCall))) +#define CALL_STACK_TO_SUBCHANNEL_CALL(callstack) \ + (SubchannelCall*)(((char*)(call_stack)) - \ + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(SubchannelCall))) + +namespace grpc_core { + +TraceFlag grpc_trace_subchannel(false, "subchannel"); +DebugOnlyTraceFlag grpc_trace_subchannel_refcount(false, "subchannel_refcount"); + +// +// ConnectedSubchannel +// + +ConnectedSubchannel::ConnectedSubchannel( + grpc_channel_stack* channel_stack, const grpc_channel_args* args, + RefCountedPtr channelz_subchannel) + : RefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_trace_subchannel_refcount) + ? 
"ConnectedSubchannel" + : nullptr), + channel_stack_(channel_stack), + args_(grpc_channel_args_copy(args)), + channelz_subchannel_(std::move(channelz_subchannel)) {} + +ConnectedSubchannel::~ConnectedSubchannel() { + grpc_channel_args_destroy(args_); + GRPC_CHANNEL_STACK_UNREF(channel_stack_, "connected_subchannel_dtor"); +} + +void ConnectedSubchannel::StartWatch( + grpc_pollset_set* interested_parties, + OrphanablePtr watcher) { + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->start_connectivity_watch = std::move(watcher); + op->start_connectivity_watch_state = GRPC_CHANNEL_READY; + op->bind_pollset_set = interested_parties; + grpc_channel_element* elem = grpc_channel_stack_element(channel_stack_, 0); + elem->filter->start_transport_op(elem, op); +} + +void ConnectedSubchannel::Ping(grpc_closure* on_initiate, + grpc_closure* on_ack) { + grpc_transport_op* op = grpc_make_transport_op(nullptr); + grpc_channel_element* elem; + op->send_ping.on_initiate = on_initiate; + op->send_ping.on_ack = on_ack; + elem = grpc_channel_stack_element(channel_stack_, 0); + elem->filter->start_transport_op(elem, op); +} + +size_t ConnectedSubchannel::GetInitialCallSizeEstimate() const { + return GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(SubchannelCall)) + + channel_stack_->call_stack_size; +} + +// +// SubchannelCall +// + +RefCountedPtr SubchannelCall::Create(Args args, + grpc_error_handle* error) { + const size_t allocation_size = + args.connected_subchannel->GetInitialCallSizeEstimate(); + Arena* arena = args.arena; + return RefCountedPtr(new ( + arena->Alloc(allocation_size)) SubchannelCall(std::move(args), error)); +} + +SubchannelCall::SubchannelCall(Args args, grpc_error_handle* error) + : connected_subchannel_(std::move(args.connected_subchannel)), + deadline_(args.deadline) { + grpc_call_stack* callstk = SUBCHANNEL_CALL_TO_CALL_STACK(this); + const grpc_call_element_args call_args = { + callstk, /* call_stack */ + nullptr, /* server_transport_data */ + args.context, /* context */ + args.path, /* path */ + args.start_time, /* start_time */ + args.deadline, /* deadline */ + args.arena, /* arena */ + args.call_combiner /* call_combiner */ + }; + *error = grpc_call_stack_init(connected_subchannel_->channel_stack(), 1, + SubchannelCall::Destroy, this, &call_args); + if (GPR_UNLIKELY(*error != GRPC_ERROR_NONE)) { + gpr_log(GPR_ERROR, "error: %s", grpc_error_std_string(*error).c_str()); + return; + } + grpc_call_stack_set_pollset_or_pollset_set(callstk, args.pollent); + auto* channelz_node = connected_subchannel_->channelz_subchannel(); + if (channelz_node != nullptr) { + channelz_node->RecordCallStarted(); + } +} + +void SubchannelCall::StartTransportStreamOpBatch( + grpc_transport_stream_op_batch* batch) { + GPR_TIMER_SCOPE("subchannel_call_process_op", 0); + MaybeInterceptRecvTrailingMetadata(batch); + grpc_call_stack* call_stack = SUBCHANNEL_CALL_TO_CALL_STACK(this); + grpc_call_element* top_elem = grpc_call_stack_element(call_stack, 0); + GRPC_CALL_LOG_OP(GPR_INFO, top_elem, batch); + top_elem->filter->start_transport_stream_op_batch(top_elem, batch); +} + +grpc_call_stack* SubchannelCall::GetCallStack() { + return SUBCHANNEL_CALL_TO_CALL_STACK(this); +} + +void SubchannelCall::SetAfterCallStackDestroy(grpc_closure* closure) { + GPR_ASSERT(after_call_stack_destroy_ == nullptr); + GPR_ASSERT(closure != nullptr); + after_call_stack_destroy_ = closure; +} + +RefCountedPtr SubchannelCall::Ref() { + IncrementRefCount(); + return RefCountedPtr(this); +} + +RefCountedPtr SubchannelCall::Ref( + 
const grpc_core::DebugLocation& location, const char* reason) { + IncrementRefCount(location, reason); + return RefCountedPtr(this); +} + +void SubchannelCall::Unref() { + GRPC_CALL_STACK_UNREF(SUBCHANNEL_CALL_TO_CALL_STACK(this), ""); +} + +void SubchannelCall::Unref(const DebugLocation& /*location*/, + const char* reason) { + GRPC_CALL_STACK_UNREF(SUBCHANNEL_CALL_TO_CALL_STACK(this), reason); +} + +void SubchannelCall::Destroy(void* arg, grpc_error_handle /*error*/) { + GPR_TIMER_SCOPE("subchannel_call_destroy", 0); + SubchannelCall* self = static_cast(arg); + // Keep some members before destroying the subchannel call. + grpc_closure* after_call_stack_destroy = self->after_call_stack_destroy_; + RefCountedPtr connected_subchannel = + std::move(self->connected_subchannel_); + // Destroy the subchannel call. + self->~SubchannelCall(); + // Destroy the call stack. This should be after destroying the subchannel + // call, because call->after_call_stack_destroy(), if not null, will free the + // call arena. + grpc_call_stack_destroy(SUBCHANNEL_CALL_TO_CALL_STACK(self), nullptr, + after_call_stack_destroy); + // Automatically reset connected_subchannel. This should be after destroying + // the call stack, because destroying call stack needs access to the channel + // stack. +} + +void SubchannelCall::MaybeInterceptRecvTrailingMetadata( + grpc_transport_stream_op_batch* batch) { + // only intercept payloads with recv trailing. + if (!batch->recv_trailing_metadata) { + return; + } + // only add interceptor is channelz is enabled. + if (connected_subchannel_->channelz_subchannel() == nullptr) { + return; + } + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReady, + this, grpc_schedule_on_exec_ctx); + // save some state needed for the interception callback. + GPR_ASSERT(recv_trailing_metadata_ == nullptr); + recv_trailing_metadata_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata; + original_recv_trailing_metadata_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &recv_trailing_metadata_ready_; +} + +namespace { + +// Sets *status based on the rest of the parameters. 
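Editor's note: GetInitialCallSizeEstimate and the SUBCHANNEL_CALL_TO_CALL_STACK macro above depend on a single arena allocation holding the SubchannelCall object followed, at an alignment-rounded offset, by the call stack, so either pointer can be derived from the other with plain arithmetic. The sketch below shows that layout with illustrative Header/Payload types; the RoundUp helper mirrors the idea of GPR_ROUND_UP_TO_ALIGNMENT_SIZE but is not the gRPC macro.

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <new>

constexpr size_t RoundUp(size_t n, size_t align) {
  return (n + align - 1) & ~(align - 1);
}

struct Header { int id; };
struct Payload { double values[4]; };

int main() {
  // One allocation: header first, payload at an alignment-rounded offset.
  const size_t offset = RoundUp(sizeof(Header), alignof(std::max_align_t));
  void* block = std::malloc(offset + sizeof(Payload));
  Header* header = new (block) Header{42};
  Payload* payload =
      new (static_cast<char*>(block) + offset) Payload{{1, 2, 3, 4}};

  // Either object can be recovered from the other, as the macros above do.
  Header* from_payload =
      reinterpret_cast<Header*>(reinterpret_cast<char*>(payload) - offset);
  std::printf("id=%d first=%.0f same=%d\n", from_payload->id,
              payload->values[0], from_payload == header);

  std::free(block);
}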
+void GetCallStatus(grpc_status_code* status, grpc_millis deadline, + grpc_metadata_batch* md_batch, grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + grpc_error_get_status(error, deadline, status, nullptr, nullptr, nullptr); + } else { + if (md_batch->legacy_index()->named.grpc_status != nullptr) { + *status = grpc_get_status_code_from_metadata( + md_batch->legacy_index()->named.grpc_status->md); + } else { + *status = GRPC_STATUS_UNKNOWN; + } + } + GRPC_ERROR_UNREF(error); +} + +} // namespace + +void SubchannelCall::RecvTrailingMetadataReady(void* arg, + grpc_error_handle error) { + SubchannelCall* call = static_cast(arg); + GPR_ASSERT(call->recv_trailing_metadata_ != nullptr); + grpc_status_code status = GRPC_STATUS_OK; + GetCallStatus(&status, call->deadline_, call->recv_trailing_metadata_, + GRPC_ERROR_REF(error)); + channelz::SubchannelNode* channelz_subchannel = + call->connected_subchannel_->channelz_subchannel(); + GPR_ASSERT(channelz_subchannel != nullptr); + if (status == GRPC_STATUS_OK) { + channelz_subchannel->RecordCallSucceeded(); + } else { + channelz_subchannel->RecordCallFailed(); + } + Closure::Run(DEBUG_LOCATION, call->original_recv_trailing_metadata_, + GRPC_ERROR_REF(error)); +} + +void SubchannelCall::IncrementRefCount() { + GRPC_CALL_STACK_REF(SUBCHANNEL_CALL_TO_CALL_STACK(this), ""); +} + +void SubchannelCall::IncrementRefCount( + const grpc_core::DebugLocation& /*location*/, const char* reason) { + GRPC_CALL_STACK_REF(SUBCHANNEL_CALL_TO_CALL_STACK(this), reason); +} + +// +// Subchannel::ConnectedSubchannelStateWatcher +// + +class Subchannel::ConnectedSubchannelStateWatcher + : public AsyncConnectivityStateWatcherInterface { + public: + // Must be instantiated while holding c->mu. + explicit ConnectedSubchannelStateWatcher(WeakRefCountedPtr c) + : subchannel_(std::move(c)) {} + + ~ConnectedSubchannelStateWatcher() override { + subchannel_.reset(DEBUG_LOCATION, "state_watcher"); + } + + private: + void OnConnectivityStateChange(grpc_connectivity_state new_state, + const absl::Status& status) override { + Subchannel* c = subchannel_.get(); + MutexLock lock(&c->mu_); + switch (new_state) { + case GRPC_CHANNEL_TRANSIENT_FAILURE: + case GRPC_CHANNEL_SHUTDOWN: { + if (!c->disconnected_ && c->connected_subchannel_ != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_subchannel)) { + gpr_log(GPR_INFO, + "subchannel %p %s: Connected subchannel %p has gone into " + "%s. Attempting to reconnect.", + c, c->key_.ToString().c_str(), + c->connected_subchannel_.get(), + ConnectivityStateName(new_state)); + } + c->connected_subchannel_.reset(); + if (c->channelz_node() != nullptr) { + c->channelz_node()->SetChildSocket(nullptr); + } + // We need to construct our own status if the underlying state was + // shutdown since the accompanying status will be StatusCode::OK + // otherwise. + c->SetConnectivityStateLocked( + GRPC_CHANNEL_TRANSIENT_FAILURE, + new_state == GRPC_CHANNEL_SHUTDOWN + ? absl::Status(absl::StatusCode::kUnavailable, + "Subchannel has disconnected.") + : status); + c->backoff_begun_ = false; + c->backoff_.Reset(); + } + break; + } + default: { + // In principle, this should never happen. We should not get + // a callback for READY, because that was the state we started + // this watch from. And a connected subchannel should never go + // from READY to CONNECTING or IDLE. 
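Editor's note: MaybeInterceptRecvTrailingMetadata above swaps the batch's recv_trailing_metadata_ready closure for its own, records the final status for channelz, and then runs the closure it displaced. A small std::function sketch of that wrap-record-forward pattern follows; CallStats, InterceptForStats and the status codes are illustrative names.

#include <functional>
#include <iostream>

struct CallStats {
  int succeeded = 0;
  int failed = 0;
};

using StatusCallback = std::function<void(int status_code)>;

// Wraps the original completion callback so the interceptor sees the status
// first, then hands control back to whoever was registered before.
StatusCallback InterceptForStats(StatusCallback original, CallStats* stats) {
  return [original = std::move(original), stats](int status_code) {
    if (status_code == 0) ++stats->succeeded; else ++stats->failed;
    original(status_code);  // forward to the saved callback
  };
}

int main() {
  CallStats stats;
  StatusCallback on_done = [](int status_code) {
    std::cout << "caller sees status " << status_code << "\n";
  };
  on_done = InterceptForStats(std::move(on_done), &stats);
  on_done(0);   // OK
  on_done(14);  // e.g. UNAVAILABLE
  std::cout << "ok=" << stats.succeeded << " failed=" << stats.failed << "\n";
}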
+ c->SetConnectivityStateLocked(new_state, status); + } + } + } + + WeakRefCountedPtr subchannel_; +}; + +// Asynchronously notifies the \a watcher of a change in the connectvity state +// of \a subchannel to the current \a state. Deletes itself when done. +class Subchannel::AsyncWatcherNotifierLocked { + public: + AsyncWatcherNotifierLocked( + RefCountedPtr watcher, + grpc_connectivity_state state, const absl::Status& status) + : watcher_(std::move(watcher)) { + watcher_->PushConnectivityStateChange({state, status}); + ExecCtx::Run(DEBUG_LOCATION, + GRPC_CLOSURE_INIT( + &closure_, + [](void* arg, grpc_error_handle /*error*/) { + auto* self = + static_cast(arg); + self->watcher_->OnConnectivityStateChange(); + delete self; + }, + this, nullptr), + GRPC_ERROR_NONE); + } + + private: + RefCountedPtr watcher_; + grpc_closure closure_; +}; + +// +// Subchannel::ConnectivityStateWatcherList +// + +void Subchannel::ConnectivityStateWatcherList::AddWatcherLocked( + RefCountedPtr watcher) { + watchers_.insert(std::make_pair(watcher.get(), std::move(watcher))); +} + +void Subchannel::ConnectivityStateWatcherList::RemoveWatcherLocked( + ConnectivityStateWatcherInterface* watcher) { + watchers_.erase(watcher); +} + +void Subchannel::ConnectivityStateWatcherList::NotifyLocked( + grpc_connectivity_state state, const absl::Status& status) { + for (const auto& p : watchers_) { + new AsyncWatcherNotifierLocked(p.second, state, status); + } +} + +// +// Subchannel::HealthWatcherMap::HealthWatcher +// + +// State needed for tracking the connectivity state with a particular +// health check service name. +class Subchannel::HealthWatcherMap::HealthWatcher + : public AsyncConnectivityStateWatcherInterface { + public: + HealthWatcher(WeakRefCountedPtr c, + std::string health_check_service_name) + : subchannel_(std::move(c)), + health_check_service_name_(std::move(health_check_service_name)), + state_(subchannel_->state_ == GRPC_CHANNEL_READY + ? GRPC_CHANNEL_CONNECTING + : subchannel_->state_) { + // If the subchannel is already connected, start health checking. + if (subchannel_->state_ == GRPC_CHANNEL_READY) StartHealthCheckingLocked(); + } + + ~HealthWatcher() override { + subchannel_.reset(DEBUG_LOCATION, "health_watcher"); + } + + const std::string& health_check_service_name() const { + return health_check_service_name_; + } + + grpc_connectivity_state state() const { return state_; } + + void AddWatcherLocked( + grpc_connectivity_state initial_state, + RefCountedPtr watcher) { + if (state_ != initial_state) { + new AsyncWatcherNotifierLocked(watcher, state_, status_); + } + watcher_list_.AddWatcherLocked(std::move(watcher)); + } + + void RemoveWatcherLocked( + Subchannel::ConnectivityStateWatcherInterface* watcher) { + watcher_list_.RemoveWatcherLocked(watcher); + } + + bool HasWatchers() const { return !watcher_list_.empty(); } + + void NotifyLocked(grpc_connectivity_state state, const absl::Status& status) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(subchannel_->mu_) { + if (state == GRPC_CHANNEL_READY) { + // If we had not already notified for CONNECTING state, do so now. + // (We may have missed this earlier, because if the transition + // from IDLE to CONNECTING to READY was too quick, the connected + // subchannel may not have sent us a notification for CONNECTING.) + if (state_ != GRPC_CHANNEL_CONNECTING) { + state_ = GRPC_CHANNEL_CONNECTING; + status_ = status; + watcher_list_.NotifyLocked(state_, status); + } + // If we've become connected, start health checking. 
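Editor's note: AsyncWatcherNotifierLocked above records the state change while the subchannel mutex is held, then schedules a closure that delivers the notification outside the lock and deletes itself once it has run. The sketch below shows that post-then-self-delete pattern with a plain callback queue standing in for ExecCtx; Watcher, EventQueue and the state string are illustrative.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Watcher {
  void OnStateChange(const std::string& state) {
    std::cout << "watcher notified: " << state << "\n";
  }
};

// Stand-in for ExecCtx: closures queued now, run later outside any lock.
using EventQueue = std::vector<std::function<void()>>;

class AsyncNotifier {
 public:
  AsyncNotifier(Watcher* watcher, std::string state, EventQueue* queue)
      : watcher_(watcher), state_(std::move(state)) {
    queue->push_back([this] {
      watcher_->OnStateChange(state_);
      delete this;  // one-shot: the notifier owns itself until it has run
    });
  }

 private:
  Watcher* watcher_;
  std::string state_;
};

int main() {
  Watcher watcher;
  EventQueue queue;
  new AsyncNotifier(&watcher, "TRANSIENT_FAILURE", &queue);  // deletes itself
  for (auto& closure : queue) closure();  // drained later, outside the "lock"
}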
+ StartHealthCheckingLocked(); + } else { + state_ = state; + status_ = status; + watcher_list_.NotifyLocked(state_, status); + // We're not connected, so stop health checking. + health_check_client_.reset(); + } + } + + void Orphan() override { + watcher_list_.Clear(); + health_check_client_.reset(); + Unref(); + } + + private: + void OnConnectivityStateChange(grpc_connectivity_state new_state, + const absl::Status& status) override { + MutexLock lock(&subchannel_->mu_); + if (new_state != GRPC_CHANNEL_SHUTDOWN && health_check_client_ != nullptr) { + state_ = new_state; + status_ = status; + watcher_list_.NotifyLocked(new_state, status); + } + } + + void StartHealthCheckingLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(subchannel_->mu_) { + GPR_ASSERT(health_check_client_ == nullptr); + health_check_client_ = MakeOrphanable( + health_check_service_name_, subchannel_->connected_subchannel_, + subchannel_->pollset_set_, subchannel_->channelz_node_, Ref()); + } + + WeakRefCountedPtr subchannel_; + std::string health_check_service_name_; + OrphanablePtr health_check_client_; + grpc_connectivity_state state_; + absl::Status status_; + ConnectivityStateWatcherList watcher_list_; +}; + +// +// Subchannel::HealthWatcherMap +// + +void Subchannel::HealthWatcherMap::AddWatcherLocked( + WeakRefCountedPtr subchannel, + grpc_connectivity_state initial_state, + const std::string& health_check_service_name, + RefCountedPtr watcher) { + // If the health check service name is not already present in the map, + // add it. + auto it = map_.find(health_check_service_name); + HealthWatcher* health_watcher; + if (it == map_.end()) { + auto w = MakeOrphanable(std::move(subchannel), + health_check_service_name); + health_watcher = w.get(); + map_.emplace(health_check_service_name, std::move(w)); + } else { + health_watcher = it->second.get(); + } + // Add the watcher to the entry. + health_watcher->AddWatcherLocked(initial_state, std::move(watcher)); +} + +void Subchannel::HealthWatcherMap::RemoveWatcherLocked( + const std::string& health_check_service_name, + ConnectivityStateWatcherInterface* watcher) { + auto it = map_.find(health_check_service_name); + GPR_ASSERT(it != map_.end()); + it->second->RemoveWatcherLocked(watcher); + // If we just removed the last watcher for this service name, remove + // the map entry. + if (!it->second->HasWatchers()) map_.erase(it); +} + +void Subchannel::HealthWatcherMap::NotifyLocked(grpc_connectivity_state state, + const absl::Status& status) { + for (const auto& p : map_) { + p.second->NotifyLocked(state, status); + } +} + +grpc_connectivity_state +Subchannel::HealthWatcherMap::CheckConnectivityStateLocked( + Subchannel* subchannel, const std::string& health_check_service_name) { + auto it = map_.find(health_check_service_name); + if (it == map_.end()) { + // If the health check service name is not found in the map, we're + // not currently doing a health check for that service name. If the + // subchannel's state without health checking is READY, report + // CONNECTING, since that's what we'd be in as soon as we do start a + // watch. Otherwise, report the channel's state without health checking. + return subchannel->state_ == GRPC_CHANNEL_READY ? 
GRPC_CHANNEL_CONNECTING + : subchannel->state_; + } + HealthWatcher* health_watcher = it->second.get(); + return health_watcher->state(); +} + +void Subchannel::HealthWatcherMap::ShutdownLocked() { map_.clear(); } + +// +// Subchannel +// + +namespace { + +BackOff::Options ParseArgsForBackoffValues( + const grpc_channel_args* args, grpc_millis* min_connect_timeout_ms) { + grpc_millis initial_backoff_ms = + GRPC_SUBCHANNEL_INITIAL_CONNECT_BACKOFF_SECONDS * 1000; + *min_connect_timeout_ms = + GRPC_SUBCHANNEL_RECONNECT_MIN_TIMEOUT_SECONDS * 1000; + grpc_millis max_backoff_ms = + GRPC_SUBCHANNEL_RECONNECT_MAX_BACKOFF_SECONDS * 1000; + bool fixed_reconnect_backoff = false; + if (args != nullptr) { + for (size_t i = 0; i < args->num_args; i++) { + if (0 == strcmp(args->args[i].key, + "grpc.testing.fixed_reconnect_backoff_ms")) { + fixed_reconnect_backoff = true; + initial_backoff_ms = *min_connect_timeout_ms = max_backoff_ms = + grpc_channel_arg_get_integer( + &args->args[i], + {static_cast(initial_backoff_ms), 100, INT_MAX}); + } else if (0 == + strcmp(args->args[i].key, GRPC_ARG_MIN_RECONNECT_BACKOFF_MS)) { + fixed_reconnect_backoff = false; + *min_connect_timeout_ms = grpc_channel_arg_get_integer( + &args->args[i], + {static_cast(*min_connect_timeout_ms), 100, INT_MAX}); + } else if (0 == + strcmp(args->args[i].key, GRPC_ARG_MAX_RECONNECT_BACKOFF_MS)) { + fixed_reconnect_backoff = false; + max_backoff_ms = grpc_channel_arg_get_integer( + &args->args[i], {static_cast(max_backoff_ms), 100, INT_MAX}); + } else if (0 == strcmp(args->args[i].key, + GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS)) { + fixed_reconnect_backoff = false; + initial_backoff_ms = grpc_channel_arg_get_integer( + &args->args[i], + {static_cast(initial_backoff_ms), 100, INT_MAX}); + } + } + } + return BackOff::Options() + .set_initial_backoff(initial_backoff_ms) + .set_multiplier(fixed_reconnect_backoff + ? 1.0 + : GRPC_SUBCHANNEL_RECONNECT_BACKOFF_MULTIPLIER) + .set_jitter(fixed_reconnect_backoff ? 0.0 + : GRPC_SUBCHANNEL_RECONNECT_JITTER) + .set_max_backoff(max_backoff_ms); +} + +} // namespace + +void Subchannel::ConnectivityStateWatcherInterface::PushConnectivityStateChange( + ConnectivityStateChange state_change) { + MutexLock lock(&mu_); + connectivity_state_queue_.push_back(std::move(state_change)); +} + +Subchannel::ConnectivityStateWatcherInterface::ConnectivityStateChange +Subchannel::ConnectivityStateWatcherInterface::PopConnectivityStateChange() { + MutexLock lock(&mu_); + GPR_ASSERT(!connectivity_state_queue_.empty()); + ConnectivityStateChange state_change = connectivity_state_queue_.front(); + connectivity_state_queue_.pop_front(); + return state_change; +} + +Subchannel::Subchannel(SubchannelKey key, + OrphanablePtr connector, + const grpc_channel_args* args) + : DualRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_trace_subchannel_refcount) ? "Subchannel" + : nullptr), + key_(std::move(key)), + pollset_set_(grpc_pollset_set_create()), + connector_(std::move(connector)), + backoff_(ParseArgsForBackoffValues(args, &min_connect_timeout_ms_)) { + GRPC_STATS_INC_CLIENT_SUBCHANNELS_CREATED(); + GRPC_CLOSURE_INIT(&on_connecting_finished_, OnConnectingFinished, this, + grpc_schedule_on_exec_ctx); + // Check proxy mapper to determine address to connect to and channel + // args to use. 
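Editor's note: ParseArgsForBackoffValues above reads each backoff-related channel arg through grpc_channel_arg_get_integer with a {default, min, max} triple, so a missing arg falls back to its default and an out-of-range value is clamped rather than rejected. A tiny sketch of that lookup-with-clamping helper over a plain map of args follows; the helper and option names are illustrative.

#include <algorithm>
#include <climits>
#include <iostream>
#include <map>
#include <string>

struct IntegerOptions {
  int default_value;
  int min_value;
  int max_value;
};

// Missing key -> default; present key -> clamped into [min, max].
int GetIntegerArg(const std::map<std::string, int>& args,
                  const std::string& key, IntegerOptions options) {
  auto it = args.find(key);
  if (it == args.end()) return options.default_value;
  return std::min(std::max(it->second, options.min_value), options.max_value);
}

int main() {
  std::map<std::string, int> args = {{"grpc.initial_reconnect_backoff_ms", 50}};
  // Present but below the 100 ms floor -> clamped to 100.
  std::cout << GetIntegerArg(args, "grpc.initial_reconnect_backoff_ms",
                             {1000, 100, INT_MAX}) << "\n";
  // Absent -> falls back to the 120000 ms default.
  std::cout << GetIntegerArg(args, "grpc.max_reconnect_backoff_ms",
                             {120000, 100, INT_MAX}) << "\n";
}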
+ address_for_connect_ = key_.address(); + grpc_resolved_address* new_address = nullptr; + grpc_channel_args* new_args = nullptr; + if (ProxyMapperRegistry::MapAddress(address_for_connect_, args, &new_address, + &new_args)) { + GPR_ASSERT(new_address != nullptr); + address_for_connect_ = *new_address; + gpr_free(new_address); + } + if (new_args != nullptr) { + args_ = new_args; + } else { + args_ = grpc_channel_args_copy(args); + } + // Initialize channelz. + const bool channelz_enabled = grpc_channel_args_find_bool( + args_, GRPC_ARG_ENABLE_CHANNELZ, GRPC_ENABLE_CHANNELZ_DEFAULT); + if (channelz_enabled) { + const size_t channel_tracer_max_memory = + static_cast(grpc_channel_args_find_integer( + args_, GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, + {GRPC_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE_DEFAULT, 0, + INT_MAX})); + channelz_node_ = MakeRefCounted( + grpc_sockaddr_to_uri(&key_.address()), channel_tracer_max_memory); + channelz_node_->AddTraceEvent( + channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string("subchannel created")); + } +} + +Subchannel::~Subchannel() { + if (channelz_node_ != nullptr) { + channelz_node_->AddTraceEvent( + channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string("Subchannel destroyed")); + channelz_node_->UpdateConnectivityState(GRPC_CHANNEL_SHUTDOWN); + } + grpc_channel_args_destroy(args_); + connector_.reset(); + grpc_pollset_set_destroy(pollset_set_); +} + +RefCountedPtr Subchannel::Create( + OrphanablePtr connector, + const grpc_resolved_address& address, const grpc_channel_args* args) { + SubchannelKey key(address, args); + SubchannelPoolInterface* subchannel_pool = + SubchannelPoolInterface::GetSubchannelPoolFromChannelArgs(args); + GPR_ASSERT(subchannel_pool != nullptr); + RefCountedPtr c = subchannel_pool->FindSubchannel(key); + if (c != nullptr) { + return c; + } + c = MakeRefCounted(std::move(key), std::move(connector), args); + // Try to register the subchannel before setting the subchannel pool. + // Otherwise, in case of a registration race, unreffing c in + // RegisterSubchannel() will cause c to be tried to be unregistered, while + // its key maps to a different subchannel. + RefCountedPtr registered = + subchannel_pool->RegisterSubchannel(c->key_, c); + if (registered == c) c->subchannel_pool_ = subchannel_pool->Ref(); + return registered; +} + +void Subchannel::ThrottleKeepaliveTime(int new_keepalive_time) { + MutexLock lock(&mu_); + // Only update the value if the new keepalive time is larger. 
+ if (new_keepalive_time > keepalive_time_) { + keepalive_time_ = new_keepalive_time; + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_subchannel)) { + gpr_log(GPR_INFO, "subchannel %p %s: throttling keepalive time to %d", + this, key_.ToString().c_str(), new_keepalive_time); + } + const grpc_arg arg_to_add = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_TIME_MS), new_keepalive_time); + const char* arg_to_remove = GRPC_ARG_KEEPALIVE_TIME_MS; + grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( + args_, &arg_to_remove, 1, &arg_to_add, 1); + grpc_channel_args_destroy(args_); + args_ = new_args; + } +} + +channelz::SubchannelNode* Subchannel::channelz_node() { + return channelz_node_.get(); +} + +grpc_connectivity_state Subchannel::CheckConnectivityState( + const absl::optional& health_check_service_name) { + MutexLock lock(&mu_); + if (health_check_service_name.has_value()) { + return health_watcher_map_.CheckConnectivityStateLocked( + this, *health_check_service_name); + } + return state_; +} + +void Subchannel::WatchConnectivityState( + grpc_connectivity_state initial_state, + const absl::optional& health_check_service_name, + RefCountedPtr watcher) { + MutexLock lock(&mu_); + grpc_pollset_set* interested_parties = watcher->interested_parties(); + if (interested_parties != nullptr) { + grpc_pollset_set_add_pollset_set(pollset_set_, interested_parties); + } + if (!health_check_service_name.has_value()) { + if (state_ != initial_state) { + new AsyncWatcherNotifierLocked(watcher, state_, status_); + } + watcher_list_.AddWatcherLocked(std::move(watcher)); + } else { + health_watcher_map_.AddWatcherLocked( + WeakRef(DEBUG_LOCATION, "health_watcher"), initial_state, + *health_check_service_name, std::move(watcher)); + } +} + +void Subchannel::CancelConnectivityStateWatch( + const absl::optional& health_check_service_name, + ConnectivityStateWatcherInterface* watcher) { + MutexLock lock(&mu_); + grpc_pollset_set* interested_parties = watcher->interested_parties(); + if (interested_parties != nullptr) { + grpc_pollset_set_del_pollset_set(pollset_set_, interested_parties); + } + if (!health_check_service_name.has_value()) { + watcher_list_.RemoveWatcherLocked(watcher); + } else { + health_watcher_map_.RemoveWatcherLocked(*health_check_service_name, + watcher); + } +} + +void Subchannel::AttemptToConnect() { + MutexLock lock(&mu_); + MaybeStartConnectingLocked(); +} + +void Subchannel::ResetBackoff() { + MutexLock lock(&mu_); + backoff_.Reset(); + if (have_retry_alarm_) { + retry_immediately_ = true; + grpc_timer_cancel(&retry_alarm_); + } else { + backoff_begun_ = false; + MaybeStartConnectingLocked(); + } +} + +void Subchannel::Orphan() { + // The subchannel_pool is only used once here in this subchannel, so the + // access can be outside of the lock. + if (subchannel_pool_ != nullptr) { + subchannel_pool_->UnregisterSubchannel(key_, this); + subchannel_pool_.reset(); + } + MutexLock lock(&mu_); + GPR_ASSERT(!disconnected_); + disconnected_ = true; + connector_.reset(); + connected_subchannel_.reset(); + health_watcher_map_.ShutdownLocked(); +} + +namespace { + +// Returns a string indicating the subchannel's connectivity state change to +// \a state. 
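Editor's note: Subchannel::Create above first asks the pool for an existing subchannel with the same key and only registers a new one on a miss; if a concurrent caller wins the registration race, the freshly built subchannel is discarded in favor of the registered one. The sketch below shows a simplified get-or-register pool; ConnectionPool and Connection are illustrative, and it resolves the race with a single lock rather than gRPC's find-then-register two-step.

#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct Connection {
  explicit Connection(std::string target) : target(std::move(target)) {}
  std::string target;
};

// Keyed pool: callers racing on the same key all end up sharing one instance.
class ConnectionPool {
 public:
  std::shared_ptr<Connection> GetOrCreate(const std::string& key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = pool_.find(key);
    if (it != pool_.end()) {
      if (auto existing = it->second.lock()) return existing;  // still alive
    }
    auto created = std::make_shared<Connection>(key);
    pool_[key] = created;  // register the new entry (or replace a dead one)
    return created;
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::weak_ptr<Connection>> pool_;
};

int main() {
  ConnectionPool pool;
  auto a = pool.GetOrCreate("10.0.0.1:443");
  auto b = pool.GetOrCreate("10.0.0.1:443");
  std::cout << (a == b ? "shared" : "distinct") << "\n";  // prints "shared"
}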
+const char* SubchannelConnectivityStateChangeString( + grpc_connectivity_state state) { + switch (state) { + case GRPC_CHANNEL_IDLE: + return "Subchannel state change to IDLE"; + case GRPC_CHANNEL_CONNECTING: + return "Subchannel state change to CONNECTING"; + case GRPC_CHANNEL_READY: + return "Subchannel state change to READY"; + case GRPC_CHANNEL_TRANSIENT_FAILURE: + return "Subchannel state change to TRANSIENT_FAILURE"; + case GRPC_CHANNEL_SHUTDOWN: + return "Subchannel state change to SHUTDOWN"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +} // namespace + +// Note: Must be called with a state that is different from the current state. +void Subchannel::SetConnectivityStateLocked(grpc_connectivity_state state, + const absl::Status& status) { + state_ = state; + status_ = status; + if (channelz_node_ != nullptr) { + channelz_node_->UpdateConnectivityState(state); + channelz_node_->AddTraceEvent( + channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string( + SubchannelConnectivityStateChangeString(state))); + } + // Notify non-health watchers. + watcher_list_.NotifyLocked(state, status); + // Notify health watchers. + health_watcher_map_.NotifyLocked(state, status); +} + +void Subchannel::MaybeStartConnectingLocked() { + if (disconnected_) { + // Don't try to connect if we're already disconnected. + return; + } + if (connecting_) { + // Already connecting: don't restart. + return; + } + if (connected_subchannel_ != nullptr) { + // Already connected: don't restart. + return; + } + connecting_ = true; + WeakRef(DEBUG_LOCATION, "connecting") + .release(); // ref held by pending connect + if (!backoff_begun_) { + backoff_begun_ = true; + ContinueConnectingLocked(); + } else { + GPR_ASSERT(!have_retry_alarm_); + have_retry_alarm_ = true; + const grpc_millis time_til_next = + next_attempt_deadline_ - ExecCtx::Get()->Now(); + if (time_til_next <= 0) { + gpr_log(GPR_INFO, "subchannel %p %s: Retry immediately", this, + key_.ToString().c_str()); + } else { + gpr_log(GPR_INFO, "subchannel %p %s: Retry in %" PRId64 " milliseconds", + this, key_.ToString().c_str(), time_til_next); + } + GRPC_CLOSURE_INIT(&on_retry_alarm_, OnRetryAlarm, this, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&retry_alarm_, next_attempt_deadline_, &on_retry_alarm_); + } +} + +void Subchannel::OnRetryAlarm(void* arg, grpc_error_handle error) { + WeakRefCountedPtr c(static_cast(arg)); + MutexLock lock(&c->mu_); + c->have_retry_alarm_ = false; + if (c->disconnected_) { + error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING("Disconnected", + &error, 1); + } else if (c->retry_immediately_) { + c->retry_immediately_ = false; + error = GRPC_ERROR_NONE; + } else { + (void)GRPC_ERROR_REF(error); + } + if (error == GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, + "subchannel %p %s: failed to connect to channel, retrying", c.get(), + c->key_.ToString().c_str()); + c->ContinueConnectingLocked(); + // Still connecting, keep ref around. Note that this stolen ref won't + // be dropped without first acquiring c->mu_. 
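Editor's note: MaybeStartConnectingLocked above drives reconnects off a BackOff object configured by the constants at the top of the file: a 1 s initial backoff grown 1.6x per failure, capped at 120 s, with 20% jitter, and the retry fires immediately when the computed deadline is already in the past. A standalone sketch of that schedule follows; the Backoff class here is illustrative, not gRPC's BackOff.

#include <algorithm>
#include <cstdio>
#include <random>

class Backoff {
 public:
  Backoff(double initial_ms, double multiplier, double jitter, double max_ms)
      : current_ms_(initial_ms), multiplier_(multiplier), jitter_(jitter),
        max_ms_(max_ms) {}

  // Returns the delay before the next connection attempt and advances state.
  double NextAttemptDelayMs() {
    std::uniform_real_distribution<double> dist(1.0 - jitter_, 1.0 + jitter_);
    double delay = current_ms_ * dist(rng_);
    current_ms_ = std::min(current_ms_ * multiplier_, max_ms_);
    return delay;
  }

  void Reset(double initial_ms) { current_ms_ = initial_ms; }

 private:
  double current_ms_;
  const double multiplier_;
  const double jitter_;
  const double max_ms_;
  std::mt19937 rng_{std::random_device{}()};
};

int main() {
  // Same shape as the subchannel defaults: 1 s initial, 1.6x, 20% jitter, 120 s cap.
  Backoff backoff(1000, 1.6, 0.2, 120000);
  for (int attempt = 1; attempt <= 6; ++attempt) {
    std::printf("attempt %d: retry in %.0f ms\n", attempt,
                backoff.NextAttemptDelayMs());
  }
}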
+ c.release(); + } + GRPC_ERROR_UNREF(error); +} + +void Subchannel::ContinueConnectingLocked() { + SubchannelConnector::Args args; + args.address = &address_for_connect_; + args.interested_parties = pollset_set_; + const grpc_millis min_deadline = + min_connect_timeout_ms_ + ExecCtx::Get()->Now(); + next_attempt_deadline_ = backoff_.NextAttemptTime(); + args.deadline = std::max(next_attempt_deadline_, min_deadline); + args.channel_args = args_; + SetConnectivityStateLocked(GRPC_CHANNEL_CONNECTING, absl::Status()); + connector_->Connect(args, &connecting_result_, &on_connecting_finished_); +} + +void Subchannel::OnConnectingFinished(void* arg, grpc_error_handle error) { + WeakRefCountedPtr c(static_cast(arg)); + const grpc_channel_args* delete_channel_args = + c->connecting_result_.channel_args; + { + MutexLock lock(&c->mu_); + c->connecting_ = false; + if (c->connecting_result_.transport != nullptr && + c->PublishTransportLocked()) { + // Do nothing, transport was published. + } else if (!c->disconnected_) { + gpr_log(GPR_INFO, "subchannel %p %s: connect failed: %s", c.get(), + c->key_.ToString().c_str(), grpc_error_std_string(error).c_str()); + c->SetConnectivityStateLocked(GRPC_CHANNEL_TRANSIENT_FAILURE, + grpc_error_to_absl_status(error)); + } + } + grpc_channel_args_destroy(delete_channel_args); + c.reset(DEBUG_LOCATION, "connecting"); +} + +namespace { + +void ConnectionDestroy(void* arg, grpc_error_handle /*error*/) { + grpc_channel_stack* stk = static_cast(arg); + grpc_channel_stack_destroy(stk); + gpr_free(stk); +} + +} // namespace + +bool Subchannel::PublishTransportLocked() { + // Construct channel stack. + grpc_channel_stack_builder* builder = grpc_channel_stack_builder_create(); + grpc_channel_stack_builder_set_channel_arguments( + builder, connecting_result_.channel_args); + grpc_channel_stack_builder_set_transport(builder, + connecting_result_.transport); + if (!CoreConfiguration::Get().channel_init().CreateStack( + builder, GRPC_CLIENT_SUBCHANNEL)) { + grpc_channel_stack_builder_destroy(builder); + return false; + } + grpc_channel_stack* stk; + grpc_error_handle error = grpc_channel_stack_builder_finish( + builder, 0, 1, ConnectionDestroy, nullptr, + reinterpret_cast(&stk)); + if (error != GRPC_ERROR_NONE) { + grpc_transport_destroy(connecting_result_.transport); + gpr_log(GPR_ERROR, + "subchannel %p %s: error initializing subchannel stack: %s", this, + key_.ToString().c_str(), grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + return false; + } + RefCountedPtr socket = + std::move(connecting_result_.socket_node); + connecting_result_.Reset(); + if (disconnected_) { + grpc_channel_stack_destroy(stk); + gpr_free(stk); + return false; + } + // Publish. + connected_subchannel_.reset( + new ConnectedSubchannel(stk, args_, channelz_node_)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_subchannel)) { + gpr_log(GPR_INFO, "subchannel %p %s: new connected subchannel at %p", this, + key_.ToString().c_str(), connected_subchannel_.get()); + } + if (channelz_node_ != nullptr) { + channelz_node_->SetChildSocket(std::move(socket)); + } + // Start watching connected subchannel. + connected_subchannel_->StartWatch( + pollset_set_, MakeOrphanable( + WeakRef(DEBUG_LOCATION, "state_watcher"))); + // Report initial state. 
+ SetConnectivityStateLocked(GRPC_CHANNEL_READY, absl::Status()); + return true; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/subchannel_pool_interface.cc b/src/core/ext/filters/client_channel/subchannel_pool_interface.cc new file mode 100644 index 00000000..d6b12ecc --- /dev/null +++ b/src/core/ext/filters/client_channel/subchannel_pool_interface.cc @@ -0,0 +1,126 @@ +// +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/ext/filters/client_channel/subchannel_pool_interface.h" + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/useful.h" + +// The subchannel pool to reuse subchannels. +#define GRPC_ARG_SUBCHANNEL_POOL "grpc.internal.subchannel_pool" +// The subchannel key ID that is only used in test to make each key unique. +#define GRPC_ARG_SUBCHANNEL_KEY_TEST_ONLY_ID "grpc.subchannel_key_test_only_id" + +namespace grpc_core { + +TraceFlag grpc_subchannel_pool_trace(false, "subchannel_pool"); + +SubchannelKey::SubchannelKey(const grpc_resolved_address& address, + const grpc_channel_args* args) { + Init(address, args, grpc_channel_args_normalize); +} + +SubchannelKey::~SubchannelKey() { + grpc_channel_args_destroy(const_cast(args_)); +} + +SubchannelKey::SubchannelKey(const SubchannelKey& other) { + Init(other.address_, other.args_, grpc_channel_args_copy); +} + +SubchannelKey& SubchannelKey::operator=(const SubchannelKey& other) { + if (&other == this) { + return *this; + } + grpc_channel_args_destroy(const_cast(args_)); + Init(other.address_, other.args_, grpc_channel_args_copy); + return *this; +} + +SubchannelKey::SubchannelKey(SubchannelKey&& other) noexcept { + address_ = other.address_; + args_ = other.args_; + other.args_ = nullptr; +} + +SubchannelKey& SubchannelKey::operator=(SubchannelKey&& other) noexcept { + address_ = other.address_; + args_ = other.args_; + other.args_ = nullptr; + return *this; +} + +bool SubchannelKey::operator<(const SubchannelKey& other) const { + if (address_.len < other.address_.len) return true; + if (address_.len > other.address_.len) return false; + int r = memcmp(address_.addr, other.address_.addr, address_.len); + if (r < 0) return true; + if (r > 0) return false; + return grpc_channel_args_compare(args_, other.args_) < 0; +} + +void SubchannelKey::Init( + const grpc_resolved_address& address, const grpc_channel_args* args, + grpc_channel_args* (*copy_channel_args)(const grpc_channel_args* args)) { + address_ = address; + args_ = copy_channel_args(args); +} + +std::string SubchannelKey::ToString() const { + return absl::StrCat("{address=", grpc_sockaddr_to_uri(&address_), + ", args=", grpc_channel_args_string(args_), "}"); +} + +namespace { + +void* arg_copy(void* p) { + auto* subchannel_pool = static_cast(p); + subchannel_pool->Ref().release(); + return p; +} + +void arg_destroy(void* p) { + auto* subchannel_pool = static_cast(p); + subchannel_pool->Unref(); +} + +int arg_cmp(void* a, void* b) { 
return QsortCompare(a, b); } + +const grpc_arg_pointer_vtable subchannel_pool_arg_vtable = { + arg_copy, arg_destroy, arg_cmp}; + +} // namespace + +grpc_arg SubchannelPoolInterface::CreateChannelArg( + SubchannelPoolInterface* subchannel_pool) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_SUBCHANNEL_POOL), subchannel_pool, + &subchannel_pool_arg_vtable); +} + +SubchannelPoolInterface* +SubchannelPoolInterface::GetSubchannelPoolFromChannelArgs( + const grpc_channel_args* args) { + const grpc_arg* arg = grpc_channel_args_find(args, GRPC_ARG_SUBCHANNEL_POOL); + if (arg == nullptr || arg->type != GRPC_ARG_POINTER) return nullptr; + return static_cast(arg->value.pointer.p); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_idle/client_idle_filter.cc b/src/core/ext/filters/client_idle/client_idle_filter.cc new file mode 100644 index 00000000..6cd3d12e --- /dev/null +++ b/src/core/ext/filters/client_idle/client_idle_filter.cc @@ -0,0 +1,264 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "src/core/ext/filters/client_idle/idle_filter_state.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/transport/http2_errors.h" + +// TODO(juanlishen): The idle filter is disabled in client channel by default +// due to b/143502997. Try to fix the bug and enable the filter by default. +#define DEFAULT_IDLE_TIMEOUT_MS INT_MAX +// The user input idle timeout smaller than this would be capped to it. +#define MIN_IDLE_TIMEOUT_MS (1 /*second*/ * 1000) + +namespace grpc_core { + +TraceFlag grpc_trace_client_idle_filter(false, "client_idle_filter"); + +#define GRPC_IDLE_FILTER_LOG(format, ...) 
\ + do { \ + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_client_idle_filter)) { \ + gpr_log(GPR_INFO, "(client idle filter) " format, ##__VA_ARGS__); \ + } \ + } while (0) + +namespace { + +grpc_millis GetClientIdleTimeout(const grpc_channel_args* args) { + return std::max( + grpc_channel_arg_get_integer( + grpc_channel_args_find(args, GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS), + {DEFAULT_IDLE_TIMEOUT_MS, 0, INT_MAX}), + MIN_IDLE_TIMEOUT_MS); +} + +class ChannelData { + public: + static grpc_error_handle Init(grpc_channel_element* elem, + grpc_channel_element_args* args); + static void Destroy(grpc_channel_element* elem); + + static void StartTransportOp(grpc_channel_element* elem, + grpc_transport_op* op); + + void IncreaseCallCount(); + + void DecreaseCallCount(); + + private: + ChannelData(grpc_channel_element* elem, grpc_channel_element_args* args, + grpc_error_handle* error); + ~ChannelData() = default; + + static void IdleTimerCallback(void* arg, grpc_error_handle error); + static void IdleTransportOpCompleteCallback(void* arg, + grpc_error_handle error); + + void StartIdleTimer(); + + void EnterIdle(); + + grpc_channel_element* elem_; + // The channel stack to which we take refs for pending callbacks. + grpc_channel_stack* channel_stack_; + // Timeout after the last RPC finishes on the client channel at which the + // channel goes back into IDLE state. + const grpc_millis client_idle_timeout_; + + // Member data used to track the state of channel. + IdleFilterState idle_filter_state_{false}; + + // Idle timer and its callback closure. + grpc_timer idle_timer_; + grpc_closure idle_timer_callback_; + + // The transport op telling the client channel to enter IDLE. + grpc_transport_op idle_transport_op_; + grpc_closure idle_transport_op_complete_callback_; +}; + +grpc_error_handle ChannelData::Init(grpc_channel_element* elem, + grpc_channel_element_args* args) { + grpc_error_handle error = GRPC_ERROR_NONE; + new (elem->channel_data) ChannelData(elem, args, &error); + return error; +} + +void ChannelData::Destroy(grpc_channel_element* elem) { + ChannelData* chand = static_cast(elem->channel_data); + chand->~ChannelData(); +} + +void ChannelData::StartTransportOp(grpc_channel_element* elem, + grpc_transport_op* op) { + ChannelData* chand = static_cast(elem->channel_data); + // Catch the disconnect_with_error transport op. + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + // IncreaseCallCount() introduces a phony call and prevent the timer from + // being reset by other threads. + chand->IncreaseCallCount(); + // If the timer has been set, cancel the timer. + // No synchronization issues here. grpc_timer_cancel() is valid as long as + // the timer has been init()ed before. + grpc_timer_cancel(&chand->idle_timer_); + } + // Pass the op to the next filter. + grpc_channel_next_op(elem, op); +} + +void ChannelData::IncreaseCallCount() { + idle_filter_state_.IncreaseCallCount(); +} + +void ChannelData::DecreaseCallCount() { + if (idle_filter_state_.DecreaseCallCount()) { + // If there are no more calls in progress, start the idle timer. + StartIdleTimer(); + } +} + +ChannelData::ChannelData(grpc_channel_element* elem, + grpc_channel_element_args* args, + grpc_error_handle* /*error*/) + : elem_(elem), + channel_stack_(args->channel_stack), + client_idle_timeout_(GetClientIdleTimeout(args->channel_args)) { + // If the idle filter is explicitly disabled in channel args, this ctor should + // not get called. 
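Editor's note: the idle filter above counts calls: every call increments the count on Init and decrements it on Destroy, and when the count reaches zero DecreaseCallCount arms the idle timer; StartTransportOp bumps the count with a phony call so an in-flight shutdown keeps the timer from re-arming. A single-threaded sketch of that bookkeeping follows; IdleTracker and its callback are illustrative, and the real filter does this with an atomic state word (see idle_filter_state.cc below).

#include <functional>
#include <iostream>

class IdleTracker {
 public:
  explicit IdleTracker(std::function<void()> arm_idle_timer)
      : arm_idle_timer_(std::move(arm_idle_timer)) {}

  void IncreaseCallCount() { ++calls_in_flight_; }

  void DecreaseCallCount() {
    if (--calls_in_flight_ == 0) {
      arm_idle_timer_();  // no activity left: start counting down to IDLE
    }
  }

 private:
  int calls_in_flight_ = 0;
  std::function<void()> arm_idle_timer_;
};

int main() {
  IdleTracker tracker([] { std::cout << "idle timer armed\n"; });
  tracker.IncreaseCallCount();  // call 1 starts
  tracker.IncreaseCallCount();  // call 2 starts
  tracker.DecreaseCallCount();  // call 2 finishes, one still active
  tracker.DecreaseCallCount();  // last call finishes -> timer armed
}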
+ GPR_ASSERT(client_idle_timeout_ != GRPC_MILLIS_INF_FUTURE); + GRPC_IDLE_FILTER_LOG("created with max_leisure_time = %" PRId64 " ms", + client_idle_timeout_); + // Initialize the idle timer without setting it. + grpc_timer_init_unset(&idle_timer_); + // Initialize the idle timer callback closure. + GRPC_CLOSURE_INIT(&idle_timer_callback_, IdleTimerCallback, this, + grpc_schedule_on_exec_ctx); + // Initialize the idle transport op complete callback. + GRPC_CLOSURE_INIT(&idle_transport_op_complete_callback_, + IdleTransportOpCompleteCallback, this, + grpc_schedule_on_exec_ctx); +} + +void ChannelData::IdleTimerCallback(void* arg, grpc_error_handle error) { + GRPC_IDLE_FILTER_LOG("timer alarms"); + ChannelData* chand = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + GRPC_IDLE_FILTER_LOG("timer canceled"); + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack_, "max idle timer callback"); + return; + } + if (chand->idle_filter_state_.CheckTimer()) { + chand->StartIdleTimer(); + } else { + chand->EnterIdle(); + } + GRPC_IDLE_FILTER_LOG("timer finishes"); + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack_, "max idle timer callback"); +} + +void ChannelData::IdleTransportOpCompleteCallback(void* arg, + grpc_error_handle /*error*/) { + ChannelData* chand = static_cast(arg); + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack_, "idle transport op"); +} + +void ChannelData::StartIdleTimer() { + GRPC_IDLE_FILTER_LOG("timer has started"); + // Hold a ref to the channel stack for the timer callback. + GRPC_CHANNEL_STACK_REF(channel_stack_, "max idle timer callback"); + grpc_timer_init(&idle_timer_, ExecCtx::Get()->Now() + client_idle_timeout_, + &idle_timer_callback_); +} + +void ChannelData::EnterIdle() { + GRPC_IDLE_FILTER_LOG("the channel will enter IDLE"); + // Hold a ref to the channel stack for the transport op. + GRPC_CHANNEL_STACK_REF(channel_stack_, "idle transport op"); + // Initialize the transport op. + idle_transport_op_ = {}; + idle_transport_op_.disconnect_with_error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("enter idle"), + GRPC_ERROR_INT_CHANNEL_CONNECTIVITY_STATE, GRPC_CHANNEL_IDLE); + idle_transport_op_.on_consumed = &idle_transport_op_complete_callback_; + // Pass the transport op down to the channel stack. 
+ grpc_channel_next_op(elem_, &idle_transport_op_); +} + +class CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args); + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* final_info, + grpc_closure* then_schedule_closure); +}; + +grpc_error_handle CallData::Init(grpc_call_element* elem, + const grpc_call_element_args* /*args*/) { + ChannelData* chand = static_cast(elem->channel_data); + chand->IncreaseCallCount(); + return GRPC_ERROR_NONE; +} + +void CallData::Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + ChannelData* chand = static_cast(elem->channel_data); + chand->DecreaseCallCount(); +} + +const grpc_channel_filter grpc_client_idle_filter = { + grpc_call_next_op, + ChannelData::StartTransportOp, + sizeof(CallData), + CallData::Init, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + CallData::Destroy, + sizeof(ChannelData), + ChannelData::Init, + ChannelData::Destroy, + grpc_channel_next_get_info, + "client_idle"}; + +} // namespace + +void RegisterClientIdleFilter(CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage( + GRPC_CLIENT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [](grpc_channel_stack_builder* builder) { + const grpc_channel_args* channel_args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (!grpc_channel_args_want_minimal_stack(channel_args) && + GetClientIdleTimeout(channel_args) != INT_MAX) { + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_client_idle_filter, nullptr, nullptr); + } else { + return true; + } + }); +} +} // namespace grpc_core diff --git a/src/core/ext/filters/client_idle/idle_filter_state.cc b/src/core/ext/filters/client_idle/idle_filter_state.cc new file mode 100644 index 00000000..cd068972 --- /dev/null +++ b/src/core/ext/filters/client_idle/idle_filter_state.cc @@ -0,0 +1,96 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/filters/client_idle/idle_filter_state.h" + +#include + +namespace grpc_core { + +IdleFilterState::IdleFilterState(bool start_timer) + : state_(start_timer ? kTimerStarted : 0) {} + +void IdleFilterState::IncreaseCallCount() { + uintptr_t state = state_.load(std::memory_order_relaxed); + uintptr_t new_state; + do { + // Increment the counter, and flag that there's been activity. + new_state = state; + new_state |= kCallsStartedSinceLastTimerCheck; + new_state += kCallIncrement; + } while (!state_.compare_exchange_weak( + state, new_state, std::memory_order_acq_rel, std::memory_order_relaxed)); +} + +bool IdleFilterState::DecreaseCallCount() { + uintptr_t state = state_.load(std::memory_order_relaxed); + uintptr_t new_state; + bool start_timer; + do { + start_timer = false; + new_state = state; + // Decrement call count (and assert there's at least one call outstanding!) 
+ assert(new_state >= kCallIncrement); + new_state -= kCallIncrement; + // If that decrement reaches a call count of zero and we have not started a + // timer + if ((new_state >> kCallsInProgressShift) == 0 && + (new_state & kTimerStarted) == 0) { + // Flag that we will start a timer, and mark it started so nobody else + // does. + start_timer = true; + new_state |= kTimerStarted; + new_state &= ~kCallsInProgressShift; + } + } while (!state_.compare_exchange_weak( + state, new_state, std::memory_order_acq_rel, std::memory_order_relaxed)); + return start_timer; +} + +bool IdleFilterState::CheckTimer() { + uintptr_t state = state_.load(std::memory_order_relaxed); + uintptr_t new_state; + bool start_timer; + do { + if ((state >> kCallsInProgressShift) != 0) { + // Still calls in progress: nothing needs updating, just return + // and keep the timer going! + return true; + } + new_state = state; + bool is_active = false; + if (new_state & kCallsStartedSinceLastTimerCheck) { + // If any calls started since the last time we checked, then consider the + // channel still active and try again. + is_active = true; + new_state &= ~kCallsStartedSinceLastTimerCheck; + } + if (is_active) { + // If we are still active, we should signal that the timer should start + // again. + start_timer = true; + } else { + // Otherwise, we should not start the timer again, and we should signal + // that in the updated state. + start_timer = false; + new_state &= ~kTimerStarted; + } + } while (!state_.compare_exchange_weak( + state, new_state, std::memory_order_acq_rel, std::memory_order_relaxed)); + return start_timer; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/deadline/deadline_filter.cc b/src/core/ext/filters/deadline/deadline_filter.cc new file mode 100644 index 00000000..9bdd51fc --- /dev/null +++ b/src/core/ext/filters/deadline/deadline_filter.cc @@ -0,0 +1,391 @@ +// +// Copyright 2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/deadline/deadline_filter.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +// A fire-and-forget class representing a pending deadline timer. +// Allocated on the call arena. +class TimerState { + public: + TimerState(grpc_call_element* elem, grpc_millis deadline) : elem_(elem) { + grpc_deadline_state* deadline_state = + static_cast(elem_->call_data); + GRPC_CALL_STACK_REF(deadline_state->call_stack, "DeadlineTimerState"); + GRPC_CLOSURE_INIT(&closure_, TimerCallback, this, nullptr); + grpc_timer_init(&timer_, deadline, &closure_); + } + + void Cancel() { grpc_timer_cancel(&timer_); } + + private: + // The on_complete callback used when sending a cancel_error batch down the + // filter stack. 
Yields the call combiner when the batch returns. + static void YieldCallCombiner(void* arg, grpc_error_handle /*ignored*/) { + TimerState* self = static_cast(arg); + grpc_deadline_state* deadline_state = + static_cast(self->elem_->call_data); + GRPC_CALL_COMBINER_STOP(deadline_state->call_combiner, + "got on_complete from cancel_stream batch"); + GRPC_CALL_STACK_UNREF(deadline_state->call_stack, "DeadlineTimerState"); + } + + // This is called via the call combiner, so access to deadline_state is + // synchronized. + static void SendCancelOpInCallCombiner(void* arg, grpc_error_handle error) { + TimerState* self = static_cast(arg); + grpc_transport_stream_op_batch* batch = grpc_make_transport_stream_op( + GRPC_CLOSURE_INIT(&self->closure_, YieldCallCombiner, self, nullptr)); + batch->cancel_stream = true; + batch->payload->cancel_stream.cancel_error = GRPC_ERROR_REF(error); + self->elem_->filter->start_transport_stream_op_batch(self->elem_, batch); + } + + // Timer callback. + static void TimerCallback(void* arg, grpc_error_handle error) { + TimerState* self = static_cast(arg); + grpc_deadline_state* deadline_state = + static_cast(self->elem_->call_data); + if (error != GRPC_ERROR_CANCELLED) { + error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Deadline Exceeded"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_DEADLINE_EXCEEDED); + deadline_state->call_combiner->Cancel(GRPC_ERROR_REF(error)); + GRPC_CLOSURE_INIT(&self->closure_, SendCancelOpInCallCombiner, self, + nullptr); + GRPC_CALL_COMBINER_START(deadline_state->call_combiner, &self->closure_, + error, + "deadline exceeded -- sending cancel_stream op"); + } else { + GRPC_CALL_STACK_UNREF(deadline_state->call_stack, "DeadlineTimerState"); + } + } + + // NOTE: This object's dtor is never called, so do not add any data + // members that require destruction! + // TODO(roth): We should ideally call this object's dtor somewhere, + // but that would require adding more synchronization, because we'd + // need to call the dtor only after both (a) the timer callback + // finishes and (b) the filter sees the call completion and attempts + // to cancel the timer. + grpc_call_element* elem_; + grpc_timer timer_; + grpc_closure closure_; +}; + +} // namespace grpc_core + +// +// grpc_deadline_state +// + +// Starts the deadline timer. +// This is called via the call combiner, so access to deadline_state is +// synchronized. +static void start_timer_if_needed(grpc_call_element* elem, + grpc_millis deadline) { + if (deadline == GRPC_MILLIS_INF_FUTURE) return; + grpc_deadline_state* deadline_state = + static_cast(elem->call_data); + GPR_ASSERT(deadline_state->timer_state == nullptr); + deadline_state->timer_state = + deadline_state->arena->New(elem, deadline); +} + +// Cancels the deadline timer. +// This is called via the call combiner, so access to deadline_state is +// synchronized. +static void cancel_timer_if_needed(grpc_deadline_state* deadline_state) { + if (deadline_state->timer_state != nullptr) { + deadline_state->timer_state->Cancel(); + deadline_state->timer_state = nullptr; + } +} + +// Callback run when we receive trailing metadata. +static void recv_trailing_metadata_ready(void* arg, grpc_error_handle error) { + grpc_deadline_state* deadline_state = static_cast(arg); + cancel_timer_if_needed(deadline_state); + // Invoke the original callback. 
+ grpc_core::Closure::Run(DEBUG_LOCATION, + deadline_state->original_recv_trailing_metadata_ready, + GRPC_ERROR_REF(error)); +} + +// Inject our own recv_trailing_metadata_ready callback into op. +static void inject_recv_trailing_metadata_ready( + grpc_deadline_state* deadline_state, grpc_transport_stream_op_batch* op) { + deadline_state->original_recv_trailing_metadata_ready = + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + GRPC_CLOSURE_INIT(&deadline_state->recv_trailing_metadata_ready, + recv_trailing_metadata_ready, deadline_state, + grpc_schedule_on_exec_ctx); + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &deadline_state->recv_trailing_metadata_ready; +} + +// Callback and associated state for starting the timer after call stack +// initialization has been completed. +struct start_timer_after_init_state { + start_timer_after_init_state(grpc_call_element* elem, grpc_millis deadline) + : elem(elem), deadline(deadline) {} + ~start_timer_after_init_state() { start_timer_if_needed(elem, deadline); } + + bool in_call_combiner = false; + grpc_call_element* elem; + grpc_millis deadline; + grpc_closure closure; +}; +static void start_timer_after_init(void* arg, grpc_error_handle error) { + struct start_timer_after_init_state* state = + static_cast(arg); + grpc_deadline_state* deadline_state = + static_cast(state->elem->call_data); + if (!state->in_call_combiner) { + // We are initially called without holding the call combiner, so we + // need to bounce ourselves into it. + state->in_call_combiner = true; + GRPC_CALL_COMBINER_START(deadline_state->call_combiner, &state->closure, + GRPC_ERROR_REF(error), + "scheduling deadline timer"); + return; + } + delete state; + GRPC_CALL_COMBINER_STOP(deadline_state->call_combiner, + "done scheduling deadline timer"); +} + +grpc_deadline_state::grpc_deadline_state(grpc_call_element* elem, + const grpc_call_element_args& args, + grpc_millis deadline) + : call_stack(args.call_stack), + call_combiner(args.call_combiner), + arena(args.arena) { + // Deadline will always be infinite on servers, so the timer will only be + // set on clients with a finite deadline. + if (deadline != GRPC_MILLIS_INF_FUTURE) { + // When the deadline passes, we indicate the failure by sending down + // an op with cancel_error set. However, we can't send down any ops + // until after the call stack is fully initialized. If we start the + // timer here, we have no guarantee that the timer won't pop before + // call stack initialization is finished. To avoid that problem, we + // create a closure to start the timer, and we schedule that closure + // to be run after call stack initialization is done. 
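For reference, the finite deadline handled here is the one an application sets on its ClientContext before starting the call; a minimal sketch using the public gRPC C++ API (illustrative only, not part of this diff):

#include <chrono>

#include <grpcpp/client_context.h>

void ConfigureDeadline(grpc::ClientContext* context) {
  // The deadline filter turns this absolute deadline into a timer; if the
  // timer fires before the call completes, the call is cancelled with
  // DEADLINE_EXCEEDED.
  context->set_deadline(std::chrono::system_clock::now() +
                        std::chrono::seconds(5));
}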
+ struct start_timer_after_init_state* state = + new start_timer_after_init_state(elem, deadline); + GRPC_CLOSURE_INIT(&state->closure, start_timer_after_init, state, + grpc_schedule_on_exec_ctx); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &state->closure, GRPC_ERROR_NONE); + } +} + +grpc_deadline_state::~grpc_deadline_state() { cancel_timer_if_needed(this); } + +void grpc_deadline_state_reset(grpc_call_element* elem, + grpc_millis new_deadline) { + grpc_deadline_state* deadline_state = + static_cast(elem->call_data); + cancel_timer_if_needed(deadline_state); + start_timer_if_needed(elem, new_deadline); +} + +void grpc_deadline_state_client_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + grpc_deadline_state* deadline_state = + static_cast(elem->call_data); + if (op->cancel_stream) { + cancel_timer_if_needed(deadline_state); + } else { + // Make sure we know when the call is complete, so that we can cancel + // the timer. + if (op->recv_trailing_metadata) { + inject_recv_trailing_metadata_ready(deadline_state, op); + } + } +} + +// +// filter code +// + +// Constructor for channel_data. Used for both client and server filters. +static grpc_error_handle deadline_init_channel_elem( + grpc_channel_element* /*elem*/, grpc_channel_element_args* args) { + GPR_ASSERT(!args->is_last); + return GRPC_ERROR_NONE; +} + +// Destructor for channel_data. Used for both client and server filters. +static void deadline_destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +// Call data used for both client and server filter. +typedef struct base_call_data { + grpc_deadline_state deadline_state; +} base_call_data; + +// Additional call data used only for the server filter. +typedef struct server_call_data { + base_call_data base; // Must be first. + // The closure for receiving initial metadata. + grpc_closure recv_initial_metadata_ready; + // Received initial metadata batch. + grpc_metadata_batch* recv_initial_metadata; + // The original recv_initial_metadata_ready closure, which we chain to + // after our own closure is invoked. + grpc_closure* next_recv_initial_metadata_ready; +} server_call_data; + +// Constructor for call_data. Used for both client and server filters. +static grpc_error_handle deadline_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + new (elem->call_data) grpc_deadline_state(elem, *args, args->deadline); + return GRPC_ERROR_NONE; +} + +// Destructor for call_data. Used for both client and server filters. +static void deadline_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + grpc_deadline_state* deadline_state = + static_cast(elem->call_data); + deadline_state->~grpc_deadline_state(); +} + +// Method for starting a call op for client filter. +static void deadline_client_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + grpc_deadline_state_client_start_transport_stream_op_batch(elem, op); + // Chain to next filter. + grpc_call_next_op(elem, op); +} + +// Callback for receiving initial metadata on the server. +static void recv_initial_metadata_ready(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + server_call_data* calld = static_cast(elem->call_data); + start_timer_if_needed( + elem, calld->recv_initial_metadata->get(grpc_core::GrpcTimeoutMetadata()) + .value_or(GRPC_MILLIS_INF_FUTURE)); + // Invoke the next callback. 
+ grpc_core::Closure::Run(DEBUG_LOCATION, + calld->next_recv_initial_metadata_ready, + GRPC_ERROR_REF(error)); +} + +// Method for starting a call op for server filter. +static void deadline_server_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + server_call_data* calld = static_cast(elem->call_data); + if (op->cancel_stream) { + cancel_timer_if_needed(&calld->base.deadline_state); + } else { + // If we're receiving initial metadata, we need to get the deadline + // from the recv_initial_metadata_ready callback. So we inject our + // own callback into that hook. + if (op->recv_initial_metadata) { + calld->next_recv_initial_metadata_ready = + op->payload->recv_initial_metadata.recv_initial_metadata_ready; + calld->recv_initial_metadata = + op->payload->recv_initial_metadata.recv_initial_metadata; + GRPC_CLOSURE_INIT(&calld->recv_initial_metadata_ready, + recv_initial_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + op->payload->recv_initial_metadata.recv_initial_metadata_ready = + &calld->recv_initial_metadata_ready; + } + // Make sure we know when the call is complete, so that we can cancel + // the timer. + // Note that we trigger this on recv_trailing_metadata, even though + // the client never sends trailing metadata, because this is the + // hook that tells us when the call is complete on the server side. + if (op->recv_trailing_metadata) { + inject_recv_trailing_metadata_ready(&calld->base.deadline_state, op); + } + } + // Chain to next filter. + grpc_call_next_op(elem, op); +} + +const grpc_channel_filter grpc_client_deadline_filter = { + deadline_client_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(base_call_data), + deadline_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + deadline_destroy_call_elem, + 0, // sizeof(channel_data) + deadline_init_channel_elem, + deadline_destroy_channel_elem, + grpc_channel_next_get_info, + "deadline", +}; + +const grpc_channel_filter grpc_server_deadline_filter = { + deadline_server_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(server_call_data), + deadline_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + deadline_destroy_call_elem, + 0, // sizeof(channel_data) + deadline_init_channel_elem, + deadline_destroy_channel_elem, + grpc_channel_next_get_info, + "deadline", +}; + +bool grpc_deadline_checking_enabled(const grpc_channel_args* channel_args) { + return grpc_channel_arg_get_bool( + grpc_channel_args_find(channel_args, GRPC_ARG_ENABLE_DEADLINE_CHECKS), + !grpc_channel_args_want_minimal_stack(channel_args)); +} + +namespace grpc_core { +void RegisterDeadlineFilter(CoreConfiguration::Builder* builder) { + auto register_filter = [builder](grpc_channel_stack_type type, + const grpc_channel_filter* filter) { + builder->channel_init()->RegisterStage( + type, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [filter](grpc_channel_stack_builder* builder) { + if (grpc_deadline_checking_enabled( + grpc_channel_stack_builder_get_channel_arguments(builder))) { + return grpc_channel_stack_builder_prepend_filter(builder, filter, + nullptr, nullptr); + } + return true; + }); + }; + register_filter(GRPC_CLIENT_DIRECT_CHANNEL, &grpc_client_deadline_filter); + register_filter(GRPC_SERVER_CHANNEL, &grpc_server_deadline_filter); +} +} // namespace grpc_core diff --git a/src/core/ext/filters/fault_injection/fault_injection_filter.cc b/src/core/ext/filters/fault_injection/fault_injection_filter.cc new file mode 100644 index 00000000..2a9a9bed 
--- /dev/null +++ b/src/core/ext/filters/fault_injection/fault_injection_filter.cc @@ -0,0 +1,503 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/fault_injection/fault_injection_filter.h" + +#include + +#include "absl/strings/numbers.h" + +#include +#include + +#include "src/core/ext/filters/fault_injection/service_config_parser.h" +#include "src/core/ext/service_config/service_config_call_data.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/transport/status_conversion.h" + +namespace grpc_core { + +TraceFlag grpc_fault_injection_filter_trace(false, "fault_injection_filter"); + +namespace { + +std::atomic g_active_faults{0}; +static_assert( + std::is_trivially_destructible>::value, + "the active fault counter needs to have a trivially destructible type"); + +inline int GetMetadatumValueInt(grpc_mdelem md) { + int res; + if (absl::SimpleAtoi(StringViewFromSlice(GRPC_MDVALUE(md)), &res)) { + return res; + } else { + return -1; + } +} + +inline uint32_t GetMetadatumValueUnsignedInt(grpc_mdelem md) { + uint32_t res; + if (absl::SimpleAtoi(StringViewFromSlice(GRPC_MDVALUE(md)), &res)) { + return res; + } else { + return -1; + } +} + +inline int64_t GetMetadatumValueInt64(grpc_mdelem md) { + int64_t res; + if (absl::SimpleAtoi(StringViewFromSlice(GRPC_MDVALUE(md)), &res)) { + return res; + } else { + return -1; + } +} + +inline bool UnderFraction(const uint32_t numerator, + const uint32_t denominator) { + if (numerator <= 0) return false; + if (numerator >= denominator) return true; + // Generate a random number in [0, denominator). + const uint32_t random_number = rand() % denominator; + return random_number < numerator; +} + +class ChannelData { + public: + static grpc_error_handle Init(grpc_channel_element* elem, + grpc_channel_element_args* args); + static void Destroy(grpc_channel_element* elem); + + int index() const { return index_; } + + private: + ChannelData(grpc_channel_element* elem, grpc_channel_element_args* args); + ~ChannelData() = default; + + // The relative index of instances of the same filter. + int index_; +}; + +class CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args); + + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*then_schedule_closure*/); + + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch); + + private: + class ResumeBatchCanceller; + + CallData(grpc_call_element* elem, const grpc_call_element_args* args); + ~CallData(); + + void DecideWhetherToInjectFaults(grpc_metadata_batch* initial_metadata); + + // Checks if current active faults exceed the allowed max faults. 
+ bool HaveActiveFaultsQuota(bool increment); + + // Returns true if this RPC needs to be delayed. If so, this call will be + // counted as an active fault. + bool MaybeDelay(); + + // Returns the aborted RPC status if this RPC needs to be aborted. If so, + // this call will be counted as an active fault. Otherwise, it returns + // GRPC_ERROR_NONE. + // If this call is already been delay injected, skip the active faults + // quota check. + grpc_error_handle MaybeAbort(); + + // Delays the stream operations batch. + void DelayBatch(grpc_call_element* elem, + grpc_transport_stream_op_batch* batch); + + // Cancels the delay timer. + void CancelDelayTimer() { grpc_timer_cancel(&delay_timer_); } + + // Finishes the fault injection, should only be called once. + void FaultInjectionFinished() { + g_active_faults.fetch_sub(1, std::memory_order_relaxed); + } + + // This is a callback that will be invoked after the delay timer is up. + static void ResumeBatch(void* arg, grpc_error_handle error); + + // This is a callback invoked upon completion of recv_trailing_metadata. + // Injects the abort_error_ to the recv_trailing_metadata batch if needed. + static void HijackedRecvTrailingMetadataReady(void* arg, grpc_error_handle); + + // Used to track the policy structs that needs to be destroyed in dtor. + bool fi_policy_owned_ = false; + const FaultInjectionMethodParsedConfig::FaultInjectionPolicy* fi_policy_; + grpc_call_stack* owning_call_; + Arena* arena_; + CallCombiner* call_combiner_; + + // Indicates whether we are doing a delay and/or an abort for this call. + bool delay_request_ = false; + bool abort_request_ = false; + + // Delay states + grpc_timer delay_timer_ ABSL_GUARDED_BY(delay_mu_); + ResumeBatchCanceller* resume_batch_canceller_ ABSL_GUARDED_BY(delay_mu_); + grpc_transport_stream_op_batch* delayed_batch_ ABSL_GUARDED_BY(delay_mu_); + // Abort states + grpc_error_handle abort_error_ = GRPC_ERROR_NONE; + grpc_closure recv_trailing_metadata_ready_; + grpc_closure* original_recv_trailing_metadata_ready_; + // Protects the asynchronous delay, resume, and cancellation. 
+ Mutex delay_mu_; +}; + +// ChannelData + +grpc_error_handle ChannelData::Init(grpc_channel_element* elem, + grpc_channel_element_args* args) { + GPR_ASSERT(elem->filter == &FaultInjectionFilterVtable); + new (elem->channel_data) ChannelData(elem, args); + return GRPC_ERROR_NONE; +} + +void ChannelData::Destroy(grpc_channel_element* elem) { + auto* chand = static_cast(elem->channel_data); + chand->~ChannelData(); +} + +ChannelData::ChannelData(grpc_channel_element* elem, + grpc_channel_element_args* args) + : index_(grpc_channel_stack_filter_instance_number(args->channel_stack, + elem)) {} + +// CallData::ResumeBatchCanceller + +class CallData::ResumeBatchCanceller { + public: + explicit ResumeBatchCanceller(grpc_call_element* elem) : elem_(elem) { + auto* calld = static_cast(elem->call_data); + GRPC_CALL_STACK_REF(calld->owning_call_, "ResumeBatchCanceller"); + GRPC_CLOSURE_INIT(&closure_, &Cancel, this, grpc_schedule_on_exec_ctx); + calld->call_combiner_->SetNotifyOnCancel(&closure_); + } + + private: + static void Cancel(void* arg, grpc_error_handle error) { + auto* self = static_cast(arg); + auto* chand = static_cast(self->elem_->channel_data); + auto* calld = static_cast(self->elem_->call_data); + { + MutexLock lock(&calld->delay_mu_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_fault_injection_filter_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: cancelling schdueled pick: " + "error=%s self=%p calld->resume_batch_canceller_=%p", + chand, calld, grpc_error_std_string(error).c_str(), self, + calld->resume_batch_canceller_); + } + if (error != GRPC_ERROR_NONE && calld->resume_batch_canceller_ == self) { + // Cancel the delayed pick. + calld->CancelDelayTimer(); + calld->FaultInjectionFinished(); + // Fail pending batches on the call. + grpc_transport_stream_op_batch_finish_with_failure( + calld->delayed_batch_, GRPC_ERROR_REF(error), + calld->call_combiner_); + } + } + GRPC_CALL_STACK_UNREF(calld->owning_call_, "ResumeBatchCanceller"); + delete self; + } + + grpc_call_element* elem_; + grpc_closure closure_; +}; + +// CallData + +grpc_error_handle CallData::Init(grpc_call_element* elem, + const grpc_call_element_args* args) { + auto* calld = new (elem->call_data) CallData(elem, args); + if (calld->fi_policy_ == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "failed to find fault injection policy"); + } + return GRPC_ERROR_NONE; +} + +void CallData::Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*then_schedule_closure*/) { + auto* calld = static_cast(elem->call_data); + calld->~CallData(); +} + +void CallData::StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + auto* calld = static_cast(elem->call_data); + // There should only be one send_initial_metdata op, and fault injection also + // only need to be enforced once. + if (batch->send_initial_metadata) { + calld->DecideWhetherToInjectFaults( + batch->payload->send_initial_metadata.send_initial_metadata); + if (GRPC_TRACE_FLAG_ENABLED(grpc_fault_injection_filter_trace)) { + gpr_log(GPR_INFO, + "chand=%p calld=%p: Fault injection triggered delay=%d abort=%d", + elem->channel_data, calld, calld->delay_request_, + calld->abort_request_); + } + if (calld->MaybeDelay()) { + // Delay the batch, and pass down the batch in the scheduled closure. 
+ calld->DelayBatch(elem, batch); + return; + } + grpc_error_handle abort_error = calld->MaybeAbort(); + if (abort_error != GRPC_ERROR_NONE) { + calld->abort_error_ = abort_error; + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(calld->abort_error_), calld->call_combiner_); + return; + } + } else { + if (batch->recv_trailing_metadata) { + // Intercept recv_trailing_metadata callback so that we can inject the + // failure when aborting streaming calls, because their + // recv_trailing_metatdata op may not be on the same batch as the + // send_initial_metadata op. + calld->original_recv_trailing_metadata_ready_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready_; + } + if (calld->abort_error_ != GRPC_ERROR_NONE) { + // If we already decided to abort, then immediately fail this batch. + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(calld->abort_error_), calld->call_combiner_); + return; + } + } + // Chain to the next filter. + grpc_call_next_op(elem, batch); +} + +CallData::CallData(grpc_call_element* elem, const grpc_call_element_args* args) + : owning_call_(args->call_stack), + arena_(args->arena), + call_combiner_(args->call_combiner) { + auto* chand = static_cast(elem->channel_data); + // Fetch the fault injection policy from the service config, based on the + // relative index for which policy should this CallData use. + auto* service_config_call_data = static_cast( + args->context[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + auto* method_params = static_cast( + service_config_call_data->GetMethodParsedConfig( + FaultInjectionServiceConfigParser::ParserIndex())); + if (method_params != nullptr) { + fi_policy_ = method_params->fault_injection_policy(chand->index()); + } + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, + HijackedRecvTrailingMetadataReady, elem, + grpc_schedule_on_exec_ctx); +} + +CallData::~CallData() { + if (fi_policy_owned_) { + fi_policy_->~FaultInjectionPolicy(); + } + GRPC_ERROR_UNREF(abort_error_); +} + +void CallData::DecideWhetherToInjectFaults( + grpc_metadata_batch* initial_metadata) { + FaultInjectionMethodParsedConfig::FaultInjectionPolicy* copied_policy = + nullptr; + // Update the policy with values in initial metadata. + if (!fi_policy_->abort_code_header.empty() || + !fi_policy_->abort_percentage_header.empty() || + !fi_policy_->delay_header.empty() || + !fi_policy_->delay_percentage_header.empty()) { + // Defer the actual copy until the first matched header. + auto maybe_copy_policy_func = [this, &copied_policy]() { + if (copied_policy == nullptr) { + copied_policy = + arena_->New( + *fi_policy_); + } + }; + initial_metadata->ForEach([&](grpc_mdelem md) { + absl::string_view key = StringViewFromSlice(GRPC_MDKEY(md)); + // Only perform string comparison if: + // 1. Needs to check this header; + // 2. The value is not been filled before. 
+ if (!fi_policy_->abort_code_header.empty() && + (copied_policy == nullptr || + copied_policy->abort_code == GRPC_STATUS_OK) && + key == fi_policy_->abort_code_header) { + maybe_copy_policy_func(); + grpc_status_code_from_int(GetMetadatumValueInt(md), + &copied_policy->abort_code); + } + if (!fi_policy_->abort_percentage_header.empty() && + key == fi_policy_->abort_percentage_header) { + maybe_copy_policy_func(); + copied_policy->abort_percentage_numerator = + std::min(GetMetadatumValueUnsignedInt(md), + fi_policy_->abort_percentage_numerator); + } + if (!fi_policy_->delay_header.empty() && + (copied_policy == nullptr || copied_policy->delay == 0) && + key == fi_policy_->delay_header) { + maybe_copy_policy_func(); + copied_policy->delay = static_cast( + std::max(GetMetadatumValueInt64(md), int64_t(0))); + } + if (!fi_policy_->delay_percentage_header.empty() && + key == fi_policy_->delay_percentage_header) { + maybe_copy_policy_func(); + copied_policy->delay_percentage_numerator = + std::min(GetMetadatumValueUnsignedInt(md), + fi_policy_->delay_percentage_numerator); + } + }); + if (copied_policy != nullptr) fi_policy_ = copied_policy; + } + // Roll the dice + delay_request_ = fi_policy_->delay != 0 && + UnderFraction(fi_policy_->delay_percentage_numerator, + fi_policy_->delay_percentage_denominator); + abort_request_ = fi_policy_->abort_code != GRPC_STATUS_OK && + UnderFraction(fi_policy_->abort_percentage_numerator, + fi_policy_->abort_percentage_denominator); + if (!delay_request_ && !abort_request_) { + if (copied_policy != nullptr) copied_policy->~FaultInjectionPolicy(); + // No fault injection for this call + } else { + fi_policy_owned_ = copied_policy != nullptr; + } +} + +bool CallData::HaveActiveFaultsQuota(bool increment) { + if (g_active_faults.load(std::memory_order_acquire) >= + fi_policy_->max_faults) { + return false; + } + if (increment) g_active_faults.fetch_add(1, std::memory_order_relaxed); + return true; +} + +bool CallData::MaybeDelay() { + if (delay_request_) { + return HaveActiveFaultsQuota(true); + } + return false; +} + +grpc_error_handle CallData::MaybeAbort() { + if (abort_request_ && (delay_request_ || HaveActiveFaultsQuota(false))) { + return grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_COPIED_STRING(fi_policy_->abort_message.c_str()), + GRPC_ERROR_INT_GRPC_STATUS, fi_policy_->abort_code); + } + return GRPC_ERROR_NONE; +} + +void CallData::DelayBatch(grpc_call_element* elem, + grpc_transport_stream_op_batch* batch) { + MutexLock lock(&delay_mu_); + delayed_batch_ = batch; + resume_batch_canceller_ = new ResumeBatchCanceller(elem); + grpc_millis resume_time = ExecCtx::Get()->Now() + fi_policy_->delay; + GRPC_CLOSURE_INIT(&batch->handler_private.closure, ResumeBatch, elem, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&delay_timer_, resume_time, &batch->handler_private.closure); +} + +void CallData::ResumeBatch(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + auto* calld = static_cast(elem->call_data); + MutexLock lock(&calld->delay_mu_); + // Cancelled or canceller has already run + if (error == GRPC_ERROR_CANCELLED || + calld->resume_batch_canceller_ == nullptr) { + return; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_fault_injection_filter_trace)) { + gpr_log(GPR_INFO, "chand=%p calld=%p: Resuming delayed stream op batch %p", + elem->channel_data, calld, calld->delayed_batch_); + } + // Lame the canceller + calld->resume_batch_canceller_ = nullptr; + // Finish fault injection. 
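The header matching in DecideWhetherToInjectFaults above lets a client tune the fault per RPC by sending the configured override headers as initial metadata. A hedged sketch; the header names below are placeholders and only take effect if they match the policy's abort_code_header / abort_percentage_header / delay_header settings:

#include <grpcpp/client_context.h>

void AddFaultOverrideMetadata(grpc::ClientContext* context) {
  // Placeholder header names; the real names come from the configured policy.
  context->AddMetadata("x-envoy-fault-abort-grpc-request", "14");  // 14 == UNAVAILABLE
  context->AddMetadata("x-envoy-fault-abort-request-percentage", "50");
  context->AddMetadata("x-envoy-fault-delay-request", "1000");  // delay in milliseconds
}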
+ calld->FaultInjectionFinished(); + // Abort if needed. + error = calld->MaybeAbort(); + if (error != GRPC_ERROR_NONE) { + calld->abort_error_ = error; + grpc_transport_stream_op_batch_finish_with_failure( + calld->delayed_batch_, GRPC_ERROR_REF(calld->abort_error_), + calld->call_combiner_); + return; + } + // Chain to the next filter. + grpc_call_next_op(elem, calld->delayed_batch_); +} + +void CallData::HijackedRecvTrailingMetadataReady(void* arg, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + auto* calld = static_cast(elem->call_data); + if (calld->abort_error_ != GRPC_ERROR_NONE) { + error = grpc_error_add_child(GRPC_ERROR_REF(error), + GRPC_ERROR_REF(calld->abort_error_)); + } else { + error = GRPC_ERROR_REF(error); + } + Closure::Run(DEBUG_LOCATION, calld->original_recv_trailing_metadata_ready_, + error); +} + +} // namespace + +extern const grpc_channel_filter FaultInjectionFilterVtable = { + CallData::StartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(CallData), + CallData::Init, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + CallData::Destroy, + sizeof(ChannelData), + ChannelData::Init, + ChannelData::Destroy, + grpc_channel_next_get_info, + "fault_injection_filter", +}; + +void FaultInjectionFilterInit(void) { + grpc_core::FaultInjectionServiceConfigParser::Register(); +} + +void FaultInjectionFilterShutdown(void) {} + +} // namespace grpc_core diff --git a/src/core/ext/filters/fault_injection/service_config_parser.cc b/src/core/ext/filters/fault_injection/service_config_parser.cc new file mode 100644 index 00000000..a0c8b419 --- /dev/null +++ b/src/core/ext/filters/fault_injection/service_config_parser.cc @@ -0,0 +1,181 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
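For orientation, the parser that follows consumes per-method service config of roughly the shape sketched below. This is a hedged example, not part of this diff: the service name and target are placeholders, and the GRPC_ARG_PARSE_FAULT_INJECTION_METHOD_CONFIG gate that ParsePerMethodParams checks is set explicitly here purely for illustration.

#include <memory>

#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/channel_arguments.h>

std::shared_ptr<grpc::Channel> MakeFaultInjectedChannel() {
  grpc::ChannelArguments args;
  // The parser below only runs when this channel arg is present.
  args.SetInt(GRPC_ARG_PARSE_FAULT_INJECTION_METHOD_CONFIG, 1);
  // Default service config carrying one fault injection policy: abort 10% of
  // calls with UNAVAILABLE, delay 10% of calls by 500 ms, and never exceed
  // 100 concurrently faulted calls.
  args.SetServiceConfigJSON(R"({
    "methodConfig": [{
      "name": [{ "service": "helloworld.Greeter" }],
      "faultInjectionPolicy": [{
        "abortCode": "UNAVAILABLE",
        "abortMessage": "Fault injected",
        "abortPercentageNumerator": 10,
        "abortPercentageDenominator": 100,
        "delay": "0.5s",
        "delayPercentageNumerator": 10,
        "delayPercentageDenominator": 100,
        "maxFaults": 100
      }]
    }]
  })");
  return grpc::CreateCustomChannel("localhost:50051",
                                   grpc::InsecureChannelCredentials(), args);
}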
+// + +#include + +#include "src/core/ext/filters/fault_injection/service_config_parser.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +#include "src/core/ext/filters/fault_injection/fault_injection_filter.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/json/json_util.h" + +namespace grpc_core { + +namespace { + +size_t g_fault_injection_parser_index; + +std::vector +ParseFaultInjectionPolicy(const Json::Array& policies_json_array, + std::vector* error_list) { + std::vector policies; + for (size_t i = 0; i < policies_json_array.size(); i++) { + FaultInjectionMethodParsedConfig::FaultInjectionPolicy + fault_injection_policy; + std::vector sub_error_list; + if (policies_json_array[i].type() != Json::Type::OBJECT) { + error_list->push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "faultInjectionPolicy index ", i, " is not a JSON object"))); + continue; + } + const Json::Object& json_object = policies_json_array[i].object_value(); + // Parse abort_code + std::string abort_code_string; + if (ParseJsonObjectField(json_object, "abortCode", &abort_code_string, + &sub_error_list, false)) { + if (!grpc_status_code_from_string(abort_code_string.c_str(), + &(fault_injection_policy.abort_code))) { + sub_error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:abortCode error:failed to parse status code")); + } + } + // Parse abort_message + if (!ParseJsonObjectField(json_object, "abortMessage", + &fault_injection_policy.abort_message, + &sub_error_list, false)) { + fault_injection_policy.abort_message = "Fault injected"; + } + // Parse abort_code_header + ParseJsonObjectField(json_object, "abortCodeHeader", + &fault_injection_policy.abort_code_header, + &sub_error_list, false); + // Parse abort_percentage_header + ParseJsonObjectField(json_object, "abortPercentageHeader", + &fault_injection_policy.abort_percentage_header, + &sub_error_list, false); + // Parse abort_percentage_numerator + ParseJsonObjectField(json_object, "abortPercentageNumerator", + &fault_injection_policy.abort_percentage_numerator, + &sub_error_list, false); + // Parse abort_percentage_denominator + if (ParseJsonObjectField( + json_object, "abortPercentageDenominator", + &fault_injection_policy.abort_percentage_denominator, + &sub_error_list, false)) { + if (fault_injection_policy.abort_percentage_denominator != 100 && + fault_injection_policy.abort_percentage_denominator != 10000 && + fault_injection_policy.abort_percentage_denominator != 1000000) { + sub_error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:abortPercentageDenominator error:Denominator can only be " + "one of " + "100, 10000, 1000000")); + } + } + // Parse delay + ParseJsonObjectFieldAsDuration(json_object, "delay", + &fault_injection_policy.delay, + &sub_error_list, false); + // Parse delay_header + ParseJsonObjectField(json_object, "delayHeader", + &fault_injection_policy.delay_header, &sub_error_list, + false); + // Parse delay_percentage_header + ParseJsonObjectField(json_object, "delayPercentageHeader", + &fault_injection_policy.delay_percentage_header, + &sub_error_list, false); + // Parse delay_percentage_numerator + ParseJsonObjectField(json_object, "delayPercentageNumerator", + &fault_injection_policy.delay_percentage_numerator, + &sub_error_list, false); + // Parse delay_percentage_denominator + if (ParseJsonObjectField( + json_object, "delayPercentageDenominator", + 
&fault_injection_policy.delay_percentage_denominator, + &sub_error_list, false)) { + if (fault_injection_policy.delay_percentage_denominator != 100 && + fault_injection_policy.delay_percentage_denominator != 10000 && + fault_injection_policy.delay_percentage_denominator != 1000000) { + sub_error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:delayPercentageDenominator error:Denominator can only be " + "one of " + "100, 10000, 1000000")); + } + } + // Parse max_faults + if (ParseJsonObjectField(json_object, "maxFaults", + &fault_injection_policy.max_faults, + &sub_error_list, false)) { + if (fault_injection_policy.max_faults < 0) { + sub_error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxFaults error:should be zero or positive")); + } + } + if (!sub_error_list.empty()) { + error_list->push_back(GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("failed to parse faultInjectionPolicy index ", i), + &sub_error_list)); + } + policies.push_back(std::move(fault_injection_policy)); + } + return policies; +} + +} // namespace + +std::unique_ptr +FaultInjectionServiceConfigParser::ParsePerMethodParams( + const grpc_channel_args* args, const Json& json, grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + // Only parse fault injection policy if the following channel arg is present. + if (!grpc_channel_args_find_bool( + args, GRPC_ARG_PARSE_FAULT_INJECTION_METHOD_CONFIG, false)) { + return nullptr; + } + // Parse fault injection policy from given Json + std::vector + fault_injection_policies; + std::vector error_list; + const Json::Array* policies_json_array; + if (ParseJsonObjectField(json.object_value(), "faultInjectionPolicy", + &policies_json_array, &error_list)) { + fault_injection_policies = + ParseFaultInjectionPolicy(*policies_json_array, &error_list); + } + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Fault injection parser", &error_list); + if (*error != GRPC_ERROR_NONE || fault_injection_policies.empty()) { + return nullptr; + } + return absl::make_unique( + std::move(fault_injection_policies)); +} + +void FaultInjectionServiceConfigParser::Register() { + g_fault_injection_parser_index = ServiceConfigParser::RegisterParser( + absl::make_unique()); +} + +size_t FaultInjectionServiceConfigParser::ParserIndex() { + return g_fault_injection_parser_index; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc new file mode 100644 index 00000000..5310cfdd --- /dev/null +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -0,0 +1,602 @@ +/* + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/http/client/http_client_filter.h" + +#include +#include + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/b64.h" +#include "src/core/lib/slice/percent_encoding.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_conversion.h" +#include "src/core/lib/transport/transport_impl.h" + +#define EXPECTED_CONTENT_TYPE "application/grpc" +#define EXPECTED_CONTENT_TYPE_LENGTH (sizeof(EXPECTED_CONTENT_TYPE) - 1) + +/* default maximum size of payload eligible for GET request */ +static constexpr size_t kMaxPayloadSizeForGet = 2048; + +static void recv_initial_metadata_ready(void* user_data, + grpc_error_handle error); +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle error); +static void on_send_message_next_done(void* arg, grpc_error_handle error); +static void send_message_on_complete(void* arg, grpc_error_handle error); + +namespace { +struct call_data { + call_data(grpc_call_element* elem, const grpc_call_element_args& args) + : call_combiner(args.call_combiner) { + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready, + ::recv_initial_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, + ::recv_trailing_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_send_message_next_done, ::on_send_message_next_done, + elem, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&send_message_on_complete, ::send_message_on_complete, + elem, grpc_schedule_on_exec_ctx); + } + + ~call_data() { GRPC_ERROR_UNREF(recv_initial_metadata_error); } + + grpc_core::CallCombiner* call_combiner; + // State for handling send_initial_metadata ops. + grpc_linked_mdelem method; + grpc_linked_mdelem scheme; + grpc_linked_mdelem content_type; + grpc_linked_mdelem user_agent; + // State for handling recv_initial_metadata ops. + grpc_metadata_batch* recv_initial_metadata; + grpc_error_handle recv_initial_metadata_error = GRPC_ERROR_NONE; + grpc_closure* original_recv_initial_metadata_ready = nullptr; + grpc_closure recv_initial_metadata_ready; + // State for handling recv_trailing_metadata ops. + grpc_metadata_batch* recv_trailing_metadata; + grpc_closure* original_recv_trailing_metadata_ready; + grpc_closure recv_trailing_metadata_ready; + grpc_error_handle recv_trailing_metadata_error = GRPC_ERROR_NONE; + bool seen_recv_trailing_metadata_ready = false; + // State for handling send_message ops. 
+ grpc_transport_stream_op_batch* send_message_batch; + size_t send_message_bytes_read = 0; + grpc_core::ManualConstructor send_message_cache; + grpc_core::ManualConstructor + send_message_caching_stream; + grpc_closure on_send_message_next_done; + grpc_closure* original_send_message_on_complete; + grpc_closure send_message_on_complete; +}; + +struct channel_data { + grpc_mdelem static_scheme; + grpc_mdelem user_agent; + size_t max_payload_size_for_get; +}; +} // namespace + +static grpc_error_handle client_filter_incoming_metadata( + grpc_metadata_batch* b) { + if (b->legacy_index()->named.status != nullptr) { + /* If both gRPC status and HTTP status are provided in the response, we + * should prefer the gRPC status code, as mentioned in + * https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md. + */ + if (b->legacy_index()->named.grpc_status != nullptr || + grpc_mdelem_static_value_eq(b->legacy_index()->named.status->md, + GRPC_MDELEM_STATUS_200)) { + b->Remove(GRPC_BATCH_STATUS); + } else { + char* val = grpc_dump_slice( + GRPC_MDVALUE(b->legacy_index()->named.status->md), GPR_DUMP_ASCII); + std::string msg = + absl::StrCat("Received http2 header with status: ", val); + grpc_error_handle e = grpc_error_set_str( + grpc_error_set_int( + grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Received http2 :status header with non-200 OK status"), + GRPC_ERROR_STR_VALUE, val), + GRPC_ERROR_INT_GRPC_STATUS, + grpc_http2_status_to_grpc_status(atoi(val))), + GRPC_ERROR_STR_GRPC_MESSAGE, msg); + gpr_free(val); + return e; + } + } + + if (b->legacy_index()->named.grpc_message != nullptr) { + grpc_slice pct_decoded_msg = grpc_core::PermissivePercentDecodeSlice( + GRPC_MDVALUE(b->legacy_index()->named.grpc_message->md)); + if (grpc_slice_is_equivalent( + pct_decoded_msg, + GRPC_MDVALUE(b->legacy_index()->named.grpc_message->md))) { + grpc_slice_unref_internal(pct_decoded_msg); + } else { + grpc_metadata_batch_set_value(b->legacy_index()->named.grpc_message, + pct_decoded_msg); + } + } + + if (b->legacy_index()->named.content_type != nullptr) { + if (!grpc_mdelem_static_value_eq( + b->legacy_index()->named.content_type->md, + GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC)) { + if (grpc_slice_buf_start_eq( + GRPC_MDVALUE(b->legacy_index()->named.content_type->md), + EXPECTED_CONTENT_TYPE, EXPECTED_CONTENT_TYPE_LENGTH) && + (GRPC_SLICE_START_PTR(GRPC_MDVALUE( + b->legacy_index() + ->named.content_type->md))[EXPECTED_CONTENT_TYPE_LENGTH] == + '+' || + GRPC_SLICE_START_PTR(GRPC_MDVALUE( + b->legacy_index() + ->named.content_type->md))[EXPECTED_CONTENT_TYPE_LENGTH] == + ';')) { + /* Although the C implementation doesn't (currently) generate them, + any custom +-suffix is explicitly valid. */ + /* TODO(klempner): We should consider preallocating common values such + as +proto or +json, or at least stashing them if we see them. */ + /* TODO(klempner): Should we be surfacing this to application code? */ + } else { + /* TODO(klempner): We're currently allowing this, but we shouldn't + see it without a proxy so log for now. 
*/ + char* val = grpc_dump_slice( + GRPC_MDVALUE(b->legacy_index()->named.content_type->md), + GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "Unexpected content-type '%s'", val); + gpr_free(val); + } + } + b->Remove(GRPC_BATCH_CONTENT_TYPE); + } + + return GRPC_ERROR_NONE; +} + +static void recv_initial_metadata_ready(void* user_data, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (error == GRPC_ERROR_NONE) { + error = client_filter_incoming_metadata(calld->recv_initial_metadata); + calld->recv_initial_metadata_error = GRPC_ERROR_REF(error); + } else { + (void)GRPC_ERROR_REF(error); + } + grpc_closure* closure = calld->original_recv_initial_metadata_ready; + calld->original_recv_initial_metadata_ready = nullptr; + if (calld->seen_recv_trailing_metadata_ready) { + GRPC_CALL_COMBINER_START( + calld->call_combiner, &calld->recv_trailing_metadata_ready, + calld->recv_trailing_metadata_error, "continue recv_trailing_metadata"); + } + grpc_core::Closure::Run(DEBUG_LOCATION, closure, error); +} + +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (calld->original_recv_initial_metadata_ready != nullptr) { + calld->recv_trailing_metadata_error = GRPC_ERROR_REF(error); + calld->seen_recv_trailing_metadata_ready = true; + GRPC_CALL_COMBINER_STOP(calld->call_combiner, + "deferring recv_trailing_metadata_ready until " + "after recv_initial_metadata_ready"); + return; + } + if (error == GRPC_ERROR_NONE) { + error = client_filter_incoming_metadata(calld->recv_trailing_metadata); + } else { + (void)GRPC_ERROR_REF(error); + } + error = grpc_error_add_child( + error, GRPC_ERROR_REF(calld->recv_initial_metadata_error)); + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_trailing_metadata_ready, error); +} + +static void send_message_on_complete(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + call_data* calld = static_cast(elem->call_data); + calld->send_message_cache.Destroy(); + // Set the batch's send_message bit back to true, so the retry code + // above knows what was in this batch. + calld->send_message_batch->send_message = true; + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_send_message_on_complete, + GRPC_ERROR_REF(error)); +} + +// Pulls a slice from the send_message byte stream, updating +// calld->send_message_bytes_read. +static grpc_error_handle pull_slice_from_send_message(call_data* calld) { + grpc_slice incoming_slice; + grpc_error_handle error = + calld->send_message_caching_stream->Pull(&incoming_slice); + if (error == GRPC_ERROR_NONE) { + calld->send_message_bytes_read += GRPC_SLICE_LENGTH(incoming_slice); + grpc_slice_unref_internal(incoming_slice); + } + return error; +} + +// Reads as many slices as possible from the send_message byte stream. +// Upon successful return, if calld->send_message_bytes_read == +// calld->send_message_caching_stream->length(), then we have completed +// reading from the byte stream; otherwise, an async read has been dispatched +// and on_send_message_next_done() will be invoked when it is complete. 
+static grpc_error_handle read_all_available_send_message_data(
+    call_data* calld) {
+  while (calld->send_message_caching_stream->Next(
+      SIZE_MAX, &calld->on_send_message_next_done)) {
+    grpc_error_handle error = pull_slice_from_send_message(calld);
+    if (error != GRPC_ERROR_NONE) return error;
+    if (calld->send_message_bytes_read ==
+        calld->send_message_caching_stream->length()) {
+      break;
+    }
+  }
+  return GRPC_ERROR_NONE;
+}
+
+// Async callback for ByteStream::Next().
+static void on_send_message_next_done(void* arg, grpc_error_handle error) {
+  grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
+  call_data* calld = static_cast<call_data*>(elem->call_data);
+  if (error != GRPC_ERROR_NONE) {
+    grpc_transport_stream_op_batch_finish_with_failure(
+        calld->send_message_batch, error, calld->call_combiner);
+    return;
+  }
+  error = pull_slice_from_send_message(calld);
+  if (error != GRPC_ERROR_NONE) {
+    grpc_transport_stream_op_batch_finish_with_failure(
+        calld->send_message_batch, error, calld->call_combiner);
+    return;
+  }
+  // There may or may not be more to read, but we don't care. If we got
+  // here, then we know that all of the data was not available
+  // synchronously, so we were not able to do a cached call. Instead,
+  // we just reset the byte stream and then send down the batch as-is.
+  calld->send_message_caching_stream->Reset();
+  grpc_call_next_op(elem, calld->send_message_batch);
+}
+
+static char* slice_buffer_to_string(grpc_slice_buffer* slice_buffer) {
+  char* payload_bytes =
+      static_cast<char*>(gpr_malloc(slice_buffer->length + 1));
+  size_t offset = 0;
+  for (size_t i = 0; i < slice_buffer->count; ++i) {
+    memcpy(payload_bytes + offset,
+           GRPC_SLICE_START_PTR(slice_buffer->slices[i]),
+           GRPC_SLICE_LENGTH(slice_buffer->slices[i]));
+    offset += GRPC_SLICE_LENGTH(slice_buffer->slices[i]);
+  }
+  *(payload_bytes + offset) = '\0';
+  return payload_bytes;
+}
+
+// Modifies the path entry in the batch's send_initial_metadata to
+// append the base64-encoded query for a GET request.
+static grpc_error_handle update_path_for_get(
+    grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
+  call_data* calld = static_cast<call_data*>(elem->call_data);
+  grpc_slice path_slice =
+      GRPC_MDVALUE(batch->payload->send_initial_metadata.send_initial_metadata
+                       ->legacy_index()
+                       ->named.path->md);
+  /* sum up individual component's lengths and allocate enough memory to
+   * hold combined path+query */
+  size_t estimated_len = GRPC_SLICE_LENGTH(path_slice);
+  estimated_len++; /* for the '?' */
+  estimated_len += grpc_base64_estimate_encoded_size(
+      batch->payload->send_message.send_message->length(),
+      false /* multi_line */);
+  grpc_core::UnmanagedMemorySlice path_with_query_slice(estimated_len);
+  /* memcopy individual pieces into this slice */
+  char* write_ptr =
+      reinterpret_cast<char*>(GRPC_SLICE_START_PTR(path_with_query_slice));
+  char* original_path =
+      reinterpret_cast<char*>(GRPC_SLICE_START_PTR(path_slice));
+  memcpy(write_ptr, original_path, GRPC_SLICE_LENGTH(path_slice));
+  write_ptr += GRPC_SLICE_LENGTH(path_slice);
+  *write_ptr++ = '?';
+  char* payload_bytes =
+      slice_buffer_to_string(calld->send_message_cache->cache_buffer());
+  grpc_base64_encode_core(write_ptr, payload_bytes,
+                          batch->payload->send_message.send_message->length(),
+                          true /* url_safe */, false /* multi_line */);
+  gpr_free(payload_bytes);
+  /* remove trailing unused memory and add trailing 0 to terminate string */
+  char* t =
+      reinterpret_cast<char*>(GRPC_SLICE_START_PTR(path_with_query_slice));
+  /* safe to use strlen since base64_encode will always add '\0' */
+  path_with_query_slice =
+      grpc_slice_sub_no_ref(path_with_query_slice, 0, strlen(t));
+  /* substitute previous path with the new path+query */
+  grpc_mdelem mdelem_path_and_query =
+      grpc_mdelem_from_slices(GRPC_MDSTR_PATH, path_with_query_slice);
+  grpc_metadata_batch* b =
+      batch->payload->send_initial_metadata.send_initial_metadata;
+  return b->Substitute(b->legacy_index()->named.path, mdelem_path_and_query);
+}
+
+static void remove_if_present(grpc_metadata_batch* batch,
+                              grpc_metadata_batch_callouts_index idx) {
+  batch->Remove(idx);
+}
+
+static void http_client_start_transport_stream_op_batch(
+    grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
+  call_data* calld = static_cast<call_data*>(elem->call_data);
+  channel_data* channeld = static_cast<channel_data*>(elem->channel_data);
+  GPR_TIMER_SCOPE("http_client_start_transport_stream_op_batch", 0);
+
+  if (batch->recv_initial_metadata) {
+    /* substitute our callback for the higher callback */
+    calld->recv_initial_metadata =
+        batch->payload->recv_initial_metadata.recv_initial_metadata;
+    calld->original_recv_initial_metadata_ready =
+        batch->payload->recv_initial_metadata.recv_initial_metadata_ready;
+    batch->payload->recv_initial_metadata.recv_initial_metadata_ready =
+        &calld->recv_initial_metadata_ready;
+  }
+
+  if (batch->recv_trailing_metadata) {
+    /* substitute our callback for the higher callback */
+    calld->recv_trailing_metadata =
+        batch->payload->recv_trailing_metadata.recv_trailing_metadata;
+    calld->original_recv_trailing_metadata_ready =
+        batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
+    batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
+        &calld->recv_trailing_metadata_ready;
+  }
+
+  grpc_error_handle error = GRPC_ERROR_NONE;
+  bool batch_will_be_handled_asynchronously = false;
+  if (batch->send_initial_metadata) {
+    // Decide which HTTP VERB to use. We use GET if the request is marked
+    // cacheable, and the operation contains both initial metadata and send
+    // message, and the payload is below the size threshold, and all the data
+    // for this request is immediately available.
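Editor's note (illustrative, not part of the patch): before the method-selection code continues below, the following self-contained sketch shows the effect of update_path_for_get() above: the cached request payload is URL-safe base64-encoded and appended to the path as a query string. Base64UrlEncode and PathWithQuery are hypothetical helpers written only for this example; the filter itself uses grpc_base64_estimate_encoded_size() and grpc_base64_encode_core().

#include <cstdint>
#include <iostream>
#include <string>

// URL-safe base64 without padding ('-' and '_' instead of '+' and '/'),
// the variant suitable for embedding a payload in a request path.
std::string Base64UrlEncode(const std::string& in) {
  static const char kAlphabet[] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
  std::string out;
  out.reserve((in.size() + 2) / 3 * 4);
  size_t i = 0;
  while (i + 3 <= in.size()) {
    uint32_t n = (static_cast<uint8_t>(in[i]) << 16) |
                 (static_cast<uint8_t>(in[i + 1]) << 8) |
                 static_cast<uint8_t>(in[i + 2]);
    out.push_back(kAlphabet[(n >> 18) & 0x3F]);
    out.push_back(kAlphabet[(n >> 12) & 0x3F]);
    out.push_back(kAlphabet[(n >> 6) & 0x3F]);
    out.push_back(kAlphabet[n & 0x3F]);
    i += 3;
  }
  if (i + 1 == in.size()) {  // one trailing byte -> two output characters
    uint32_t n = static_cast<uint8_t>(in[i]) << 16;
    out.push_back(kAlphabet[(n >> 18) & 0x3F]);
    out.push_back(kAlphabet[(n >> 12) & 0x3F]);
  } else if (i + 2 == in.size()) {  // two trailing bytes -> three characters
    uint32_t n = (static_cast<uint8_t>(in[i]) << 16) |
                 (static_cast<uint8_t>(in[i + 1]) << 8);
    out.push_back(kAlphabet[(n >> 18) & 0x3F]);
    out.push_back(kAlphabet[(n >> 12) & 0x3F]);
    out.push_back(kAlphabet[(n >> 6) & 0x3F]);
  }
  return out;
}

// Build "<path>?<base64url(payload)>", mirroring what update_path_for_get()
// writes into the :path metadata for a cacheable GET request.
std::string PathWithQuery(const std::string& path, const std::string& payload) {
  return path + "?" + Base64UrlEncode(payload);
}

int main() {
  std::cout << PathWithQuery("/helloworld.Greeter/SayHello",
                             "serialized-request-bytes")
            << "\n";
}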
+ grpc_mdelem method = GRPC_MDELEM_METHOD_POST; + if (batch->send_message && + (batch->payload->send_initial_metadata.send_initial_metadata_flags & + GRPC_INITIAL_METADATA_CACHEABLE_REQUEST) && + batch->payload->send_message.send_message->length() < + channeld->max_payload_size_for_get) { + calld->send_message_bytes_read = 0; + calld->send_message_cache.Init( + std::move(batch->payload->send_message.send_message)); + calld->send_message_caching_stream.Init(calld->send_message_cache.get()); + batch->payload->send_message.send_message.reset( + calld->send_message_caching_stream.get()); + calld->original_send_message_on_complete = batch->on_complete; + batch->on_complete = &calld->send_message_on_complete; + calld->send_message_batch = batch; + error = read_all_available_send_message_data(calld); + if (error != GRPC_ERROR_NONE) goto done; + // If all the data has been read, then we can use GET. + if (calld->send_message_bytes_read == + calld->send_message_caching_stream->length()) { + method = GRPC_MDELEM_METHOD_GET; + error = update_path_for_get(elem, batch); + if (error != GRPC_ERROR_NONE) goto done; + batch->send_message = false; + calld->send_message_caching_stream->Orphan(); + } else { + // Not all data is available. The batch will be sent down + // asynchronously in on_send_message_next_done(). + batch_will_be_handled_asynchronously = true; + // Fall back to POST. + gpr_log(GPR_DEBUG, + "Request is marked Cacheable but not all data is available. " + "Falling back to POST"); + } + } else if (batch->payload->send_initial_metadata + .send_initial_metadata_flags & + GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST) { + method = GRPC_MDELEM_METHOD_PUT; + } + + remove_if_present( + batch->payload->send_initial_metadata.send_initial_metadata, + GRPC_BATCH_METHOD); + remove_if_present( + batch->payload->send_initial_metadata.send_initial_metadata, + GRPC_BATCH_SCHEME); + remove_if_present( + batch->payload->send_initial_metadata.send_initial_metadata, + GRPC_BATCH_CONTENT_TYPE); + remove_if_present( + batch->payload->send_initial_metadata.send_initial_metadata, + GRPC_BATCH_USER_AGENT); + + /* Send : prefixed headers, which have to be before any application + layer headers. 
*/ + error = grpc_metadata_batch_add_head( + batch->payload->send_initial_metadata.send_initial_metadata, + &calld->method, method, GRPC_BATCH_METHOD); + if (error != GRPC_ERROR_NONE) goto done; + error = grpc_metadata_batch_add_head( + batch->payload->send_initial_metadata.send_initial_metadata, + &calld->scheme, channeld->static_scheme, GRPC_BATCH_SCHEME); + if (error != GRPC_ERROR_NONE) goto done; + batch->payload->send_initial_metadata.send_initial_metadata->Set( + grpc_core::TeMetadata(), grpc_core::TeMetadata::kTrailers); + error = grpc_metadata_batch_add_tail( + batch->payload->send_initial_metadata.send_initial_metadata, + &calld->content_type, GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC, + GRPC_BATCH_CONTENT_TYPE); + if (error != GRPC_ERROR_NONE) goto done; + error = grpc_metadata_batch_add_tail( + batch->payload->send_initial_metadata.send_initial_metadata, + &calld->user_agent, GRPC_MDELEM_REF(channeld->user_agent), + GRPC_BATCH_USER_AGENT); + if (error != GRPC_ERROR_NONE) goto done; + } + +done: + if (error != GRPC_ERROR_NONE) { + grpc_transport_stream_op_batch_finish_with_failure(batch, error, + calld->call_combiner); + } else if (!batch_will_be_handled_asynchronously) { + grpc_call_next_op(elem, batch); + } +} + +/* Constructor for call_data */ +static grpc_error_handle http_client_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + new (elem->call_data) call_data(elem, *args); + return GRPC_ERROR_NONE; +} + +/* Destructor for call_data */ +static void http_client_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + call_data* calld = static_cast(elem->call_data); + calld->~call_data(); +} + +static grpc_mdelem scheme_from_args(const grpc_channel_args* args) { + unsigned i; + size_t j; + grpc_mdelem valid_schemes[] = {GRPC_MDELEM_SCHEME_HTTP, + GRPC_MDELEM_SCHEME_HTTPS}; + if (args != nullptr) { + for (i = 0; i < args->num_args; ++i) { + if (args->args[i].type == GRPC_ARG_STRING && + strcmp(args->args[i].key, GRPC_ARG_HTTP2_SCHEME) == 0) { + for (j = 0; j < GPR_ARRAY_SIZE(valid_schemes); j++) { + if (0 == grpc_slice_str_cmp(GRPC_MDVALUE(valid_schemes[j]), + args->args[i].value.string)) { + return valid_schemes[j]; + } + } + } + } + } + return GRPC_MDELEM_SCHEME_HTTP; +} + +static size_t max_payload_size_from_args(const grpc_channel_args* args) { + if (args != nullptr) { + for (size_t i = 0; i < args->num_args; ++i) { + if (0 == strcmp(args->args[i].key, GRPC_ARG_MAX_PAYLOAD_SIZE_FOR_GET)) { + if (args->args[i].type != GRPC_ARG_INTEGER) { + gpr_log(GPR_ERROR, "%s: must be an integer", + GRPC_ARG_MAX_PAYLOAD_SIZE_FOR_GET); + } else { + return static_cast(args->args[i].value.integer); + } + } + } + } + return kMaxPayloadSizeForGet; +} + +static grpc_core::ManagedMemorySlice user_agent_from_args( + const grpc_channel_args* args, const char* transport_name) { + std::vector user_agent_fields; + + for (size_t i = 0; args && i < args->num_args; i++) { + if (0 == strcmp(args->args[i].key, GRPC_ARG_PRIMARY_USER_AGENT_STRING)) { + if (args->args[i].type != GRPC_ARG_STRING) { + gpr_log(GPR_ERROR, "Channel argument '%s' should be a string", + GRPC_ARG_PRIMARY_USER_AGENT_STRING); + } else { + user_agent_fields.push_back(args->args[i].value.string); + } + } + } + + user_agent_fields.push_back( + absl::StrFormat("grpc-c/%s (%s; %s)", grpc_version_string(), + GPR_PLATFORM_STRING, transport_name)); + + for (size_t i = 0; args && i < args->num_args; i++) { + if (0 == 
strcmp(args->args[i].key, GRPC_ARG_SECONDARY_USER_AGENT_STRING)) { + if (args->args[i].type != GRPC_ARG_STRING) { + gpr_log(GPR_ERROR, "Channel argument '%s' should be a string", + GRPC_ARG_SECONDARY_USER_AGENT_STRING); + } else { + user_agent_fields.push_back(args->args[i].value.string); + } + } + } + + std::string user_agent_string = absl::StrJoin(user_agent_fields, " "); + return grpc_core::ManagedMemorySlice(user_agent_string.c_str()); +} + +/* Constructor for channel_data */ +static grpc_error_handle http_client_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + channel_data* chand = static_cast(elem->channel_data); + GPR_ASSERT(!args->is_last); + GPR_ASSERT(args->optional_transport != nullptr); + chand->static_scheme = scheme_from_args(args->channel_args); + chand->max_payload_size_for_get = + max_payload_size_from_args(args->channel_args); + chand->user_agent = grpc_mdelem_from_slices( + GRPC_MDSTR_USER_AGENT, + user_agent_from_args(args->channel_args, + args->optional_transport->vtable->name)); + return GRPC_ERROR_NONE; +} + +/* Destructor for channel data */ +static void http_client_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* chand = static_cast(elem->channel_data); + GRPC_MDELEM_UNREF(chand->user_agent); +} + +const grpc_channel_filter grpc_http_client_filter = { + http_client_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + http_client_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + http_client_destroy_call_elem, + sizeof(channel_data), + http_client_init_channel_elem, + http_client_destroy_channel_elem, + grpc_channel_next_get_info, + "http-client"}; diff --git a/src/core/ext/filters/http/client_authority_filter.cc b/src/core/ext/filters/http/client_authority_filter.cc new file mode 100644 index 00000000..1915ebda --- /dev/null +++ b/src/core/ext/filters/http/client_authority_filter.cc @@ -0,0 +1,159 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/http/client_authority_filter.h" + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/static_metadata.h" + +namespace { + +struct call_data { + grpc_linked_mdelem authority_storage; + grpc_core::CallCombiner* call_combiner; +}; + +struct channel_data { + grpc_core::ManagedMemorySlice default_authority; + grpc_mdelem default_authority_mdelem; +}; + +void client_authority_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + channel_data* chand = static_cast(elem->channel_data); + call_data* calld = static_cast(elem->call_data); + // Handle send_initial_metadata. + // If the initial metadata doesn't already contain :authority, add it. + if (batch->send_initial_metadata && + batch->payload->send_initial_metadata.send_initial_metadata + ->legacy_index() + ->named.authority == nullptr) { + grpc_error_handle error = grpc_metadata_batch_add_head( + batch->payload->send_initial_metadata.send_initial_metadata, + &calld->authority_storage, + GRPC_MDELEM_REF(chand->default_authority_mdelem), GRPC_BATCH_AUTHORITY); + if (error != GRPC_ERROR_NONE) { + grpc_transport_stream_op_batch_finish_with_failure(batch, error, + calld->call_combiner); + return; + } + } + // Pass control down the stack. + grpc_call_next_op(elem, batch); +} + +/* Constructor for call_data */ +grpc_error_handle client_authority_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + call_data* calld = static_cast(elem->call_data); + calld->call_combiner = args->call_combiner; + return GRPC_ERROR_NONE; +} + +/* Destructor for call_data */ +void client_authority_destroy_call_elem( + grpc_call_element* /*elem*/, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) {} + +/* Constructor for channel_data */ +grpc_error_handle client_authority_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + channel_data* chand = static_cast(elem->channel_data); + const grpc_arg* default_authority_arg = + grpc_channel_args_find(args->channel_args, GRPC_ARG_DEFAULT_AUTHORITY); + if (default_authority_arg == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "GRPC_ARG_DEFAULT_AUTHORITY channel arg. not found. Note that direct " + "channels must explicitly specify a value for this argument."); + } + const char* default_authority_str = + grpc_channel_arg_get_string(default_authority_arg); + if (default_authority_str == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "GRPC_ARG_DEFAULT_AUTHORITY channel arg. 
must be a string"); + } + chand->default_authority = + grpc_core::ManagedMemorySlice(default_authority_str); + chand->default_authority_mdelem = grpc_mdelem_create( + GRPC_MDSTR_AUTHORITY, chand->default_authority, nullptr); + GPR_ASSERT(!args->is_last); + return GRPC_ERROR_NONE; +} + +/* Destructor for channel data */ +void client_authority_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* chand = static_cast(elem->channel_data); + grpc_slice_unref_internal(chand->default_authority); + GRPC_MDELEM_UNREF(chand->default_authority_mdelem); +} +} // namespace + +const grpc_channel_filter grpc_client_authority_filter = { + client_authority_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + client_authority_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + client_authority_destroy_call_elem, + sizeof(channel_data), + client_authority_init_channel_elem, + client_authority_destroy_channel_elem, + grpc_channel_next_get_info, + "authority"}; + +static bool add_client_authority_filter(grpc_channel_stack_builder* builder) { + const grpc_channel_args* channel_args = + grpc_channel_stack_builder_get_channel_arguments(builder); + const grpc_arg* disable_client_authority_filter_arg = grpc_channel_args_find( + channel_args, GRPC_ARG_DISABLE_CLIENT_AUTHORITY_FILTER); + if (disable_client_authority_filter_arg != nullptr) { + const bool is_client_authority_filter_disabled = + grpc_channel_arg_get_bool(disable_client_authority_filter_arg, false); + if (is_client_authority_filter_disabled) { + return true; + } + } + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_client_authority_filter, nullptr, nullptr); +} + +namespace grpc_core { +void RegisterClientAuthorityFilter(CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, INT_MAX, + add_client_authority_filter); + builder->channel_init()->RegisterStage(GRPC_CLIENT_DIRECT_CHANNEL, INT_MAX, + add_client_authority_filter); +} +} // namespace grpc_core diff --git a/src/core/ext/filters/http/http_filters_plugin.cc b/src/core/ext/filters/http/http_filters_plugin.cc new file mode 100644 index 00000000..dac830d4 --- /dev/null +++ b/src/core/ext/filters/http/http_filters_plugin.cc @@ -0,0 +1,90 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/ext/filters/http/client/http_client_filter.h" +#include "src/core/ext/filters/http/message_compress/message_compress_filter.h" +#include "src/core/ext/filters/http/message_compress/message_decompress_filter.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/transport/transport_impl.h" + +static bool is_building_http_like_transport( + grpc_channel_stack_builder* builder) { + grpc_transport* t = grpc_channel_stack_builder_get_transport(builder); + return t != nullptr && strstr(t->vtable->name, "http"); +} + +namespace grpc_core { +void RegisterHttpFilters(CoreConfiguration::Builder* builder) { + auto optional = [builder](grpc_channel_stack_type channel_type, + bool enable_in_minimal_stack, + const char* control_channel_arg, + const grpc_channel_filter* filter) { + builder->channel_init()->RegisterStage( + channel_type, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [enable_in_minimal_stack, control_channel_arg, + filter](grpc_channel_stack_builder* builder) { + if (!is_building_http_like_transport(builder)) return true; + const grpc_channel_args* channel_args = + grpc_channel_stack_builder_get_channel_arguments(builder); + bool enable = grpc_channel_arg_get_bool( + grpc_channel_args_find(channel_args, control_channel_arg), + enable_in_minimal_stack || + !grpc_channel_args_want_minimal_stack(channel_args)); + if (!enable) return true; + return grpc_channel_stack_builder_prepend_filter(builder, filter, + nullptr, nullptr); + }); + }; + auto required = [builder](grpc_channel_stack_type channel_type, + const grpc_channel_filter* filter) { + builder->channel_init()->RegisterStage( + channel_type, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [filter](grpc_channel_stack_builder* builder) { + if (!is_building_http_like_transport(builder)) return true; + return grpc_channel_stack_builder_prepend_filter(builder, filter, + nullptr, nullptr); + }); + }; + optional(GRPC_CLIENT_SUBCHANNEL, false, + GRPC_ARG_ENABLE_PER_MESSAGE_COMPRESSION, + &grpc_message_compress_filter); + optional(GRPC_CLIENT_DIRECT_CHANNEL, false, + GRPC_ARG_ENABLE_PER_MESSAGE_COMPRESSION, + &grpc_message_compress_filter); + optional(GRPC_SERVER_CHANNEL, false, GRPC_ARG_ENABLE_PER_MESSAGE_COMPRESSION, + &grpc_message_compress_filter); + optional(GRPC_CLIENT_SUBCHANNEL, true, + GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION, + &grpc_core::MessageDecompressFilter); + optional(GRPC_CLIENT_DIRECT_CHANNEL, true, + GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION, + &grpc_core::MessageDecompressFilter); + optional(GRPC_SERVER_CHANNEL, true, GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION, + &grpc_core::MessageDecompressFilter); + required(GRPC_CLIENT_SUBCHANNEL, &grpc_http_client_filter); + required(GRPC_CLIENT_DIRECT_CHANNEL, &grpc_http_client_filter); + required(GRPC_SERVER_CHANNEL, &grpc_http_server_filter); +} +} // namespace grpc_core diff --git a/src/core/ext/filters/http/message_compress/message_compress_filter.cc b/src/core/ext/filters/http/message_compress/message_compress_filter.cc new file mode 100644 index 00000000..478fb96f --- /dev/null +++ b/src/core/ext/filters/http/message_compress/message_compress_filter.cc @@ -0,0 +1,550 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/http/message_compress/message_compress_filter.h" + +#include +#include + +#include "absl/types/optional.h" + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/algorithm_metadata.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/compression/compression_internal.h" +#include "src/core/lib/compression/message_compress.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/transport/static_metadata.h" + +namespace { + +class ChannelData { + public: + explicit ChannelData(grpc_channel_element_args* args) { + // Get the enabled and the default algorithms from channel args. + enabled_compression_algorithms_bitset_ = + grpc_channel_args_compression_algorithm_get_states(args->channel_args); + default_compression_algorithm_ = + grpc_channel_args_get_channel_default_compression_algorithm( + args->channel_args); + // Make sure the default is enabled. + if (!grpc_core::GetBit(enabled_compression_algorithms_bitset_, + default_compression_algorithm_)) { + const char* name; + GPR_ASSERT(grpc_compression_algorithm_name(default_compression_algorithm_, + &name) == 1); + gpr_log(GPR_ERROR, + "default compression algorithm %s not enabled: switching to none", + name); + default_compression_algorithm_ = GRPC_COMPRESS_NONE; + } + enabled_message_compression_algorithms_bitset_ = + grpc_compression_bitset_to_message_bitset( + enabled_compression_algorithms_bitset_); + enabled_stream_compression_algorithms_bitset_ = + grpc_compression_bitset_to_stream_bitset( + enabled_compression_algorithms_bitset_); + GPR_ASSERT(!args->is_last); + } + + grpc_compression_algorithm default_compression_algorithm() const { + return default_compression_algorithm_; + } + + uint32_t enabled_compression_algorithms_bitset() const { + return enabled_compression_algorithms_bitset_; + } + + uint32_t enabled_message_compression_algorithms_bitset() const { + return enabled_message_compression_algorithms_bitset_; + } + + uint32_t enabled_stream_compression_algorithms_bitset() const { + return enabled_stream_compression_algorithms_bitset_; + } + + private: + /** The default, channel-level, compression algorithm */ + grpc_compression_algorithm default_compression_algorithm_; + /** Bitset of enabled compression algorithms */ + uint32_t enabled_compression_algorithms_bitset_; + /** Bitset of enabled message compression algorithms */ + uint32_t enabled_message_compression_algorithms_bitset_; + /** Bitset of enabled stream compression algorithms */ + uint32_t enabled_stream_compression_algorithms_bitset_; +}; + +class CallData { + public: + CallData(grpc_call_element* elem, const grpc_call_element_args& args) + : call_combiner_(args.call_combiner) { + ChannelData* channeld = static_cast(elem->channel_data); + // The call's message 
compression algorithm is set to channel's default + // setting. It can be overridden later by initial metadata. + if (GPR_LIKELY( + grpc_core::GetBit(channeld->enabled_compression_algorithms_bitset(), + channeld->default_compression_algorithm()))) { + message_compression_algorithm_ = + grpc_compression_algorithm_to_message_compression_algorithm( + channeld->default_compression_algorithm()); + } + GRPC_CLOSURE_INIT(&start_send_message_batch_in_call_combiner_, + StartSendMessageBatch, elem, grpc_schedule_on_exec_ctx); + } + + ~CallData() { + if (state_initialized_) { + grpc_slice_buffer_destroy_internal(&slices_); + } + GRPC_ERROR_UNREF(cancel_error_); + } + + void CompressStartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch); + + private: + bool SkipMessageCompression(); + void InitializeState(grpc_call_element* elem); + + grpc_error_handle ProcessSendInitialMetadata( + grpc_call_element* elem, grpc_metadata_batch* initial_metadata); + + // Methods for processing a send_message batch + static void StartSendMessageBatch(void* elem_arg, grpc_error_handle unused); + static void OnSendMessageNextDone(void* elem_arg, grpc_error_handle error); + grpc_error_handle PullSliceFromSendMessage(); + void ContinueReadingSendMessage(grpc_call_element* elem); + void FinishSendMessage(grpc_call_element* elem); + void SendMessageBatchContinue(grpc_call_element* elem); + static void FailSendMessageBatchInCallCombiner(void* calld_arg, + grpc_error_handle error); + + static void SendMessageOnComplete(void* calld_arg, grpc_error_handle error); + + grpc_core::CallCombiner* call_combiner_; + grpc_message_compression_algorithm message_compression_algorithm_ = + GRPC_MESSAGE_COMPRESS_NONE; + grpc_error_handle cancel_error_ = GRPC_ERROR_NONE; + grpc_transport_stream_op_batch* send_message_batch_ = nullptr; + bool seen_initial_metadata_ = false; + /* Set to true, if the fields below are initialized. */ + bool state_initialized_ = false; + grpc_closure start_send_message_batch_in_call_combiner_; + /* The fields below are only initialized when we compress the payload. + * Keep them at the bottom of the struct, so they don't pollute the + * cache-lines. */ + grpc_linked_mdelem message_compression_algorithm_storage_; + grpc_linked_mdelem stream_compression_algorithm_storage_; + grpc_linked_mdelem accept_encoding_storage_; + grpc_linked_mdelem accept_stream_encoding_storage_; + grpc_slice_buffer slices_; /**< Buffers up input slices to be compressed */ + // Allocate space for the replacement stream + std::aligned_storage::type + replacement_stream_; + grpc_closure* original_send_message_on_complete_ = nullptr; + grpc_closure send_message_on_complete_; + grpc_closure on_send_message_next_done_; +}; + +// Returns true if we should skip message compression for the current message. +bool CallData::SkipMessageCompression() { + // If the flags of this message indicate that it shouldn't be compressed, we + // skip message compression. + uint32_t flags = + send_message_batch_->payload->send_message.send_message->flags(); + if (flags & (GRPC_WRITE_NO_COMPRESS | GRPC_WRITE_INTERNAL_COMPRESS)) { + return true; + } + // If this call doesn't have any message compression algorithm set, skip + // message compression. + return message_compression_algorithm_ == GRPC_MESSAGE_COMPRESS_NONE; +} + +// Determines the compression algorithm from the initial metadata and the +// channel's default setting. 
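Editor's note (illustrative, not part of the patch): before the FindCompressionAlgorithm() implementation that follows, the selection logic it describes — start from the channel default, allow a per-call override from initial metadata, and fall back to no compression when the override is not enabled — can be sketched without any gRPC metadata machinery. The enum, GetBit helper, and PickAlgorithm below are invented names for this sketch only.

#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical stand-in for the compression algorithm enum used by the filter.
enum Algorithm : uint32_t { kNone = 0, kDeflate = 1, kGzip = 2 };

constexpr bool GetBit(uint32_t bitset, uint32_t bit) {
  return (bitset & (1u << bit)) != 0;
}

// Channel default is used unless the call's initial metadata requests an
// override; a disabled override falls back to no compression.
Algorithm PickAlgorithm(uint32_t enabled_bitset, Algorithm channel_default,
                        std::optional<Algorithm> per_call_request) {
  if (!per_call_request.has_value()) return channel_default;
  if (GetBit(enabled_bitset, *per_call_request)) return *per_call_request;
  std::cerr << "requested algorithm disabled; not compressing\n";
  return kNone;
}

int main() {
  const uint32_t enabled = (1u << kNone) | (1u << kGzip);  // deflate disabled
  std::cout << PickAlgorithm(enabled, kGzip, std::nullopt) << "\n";  // 2 (gzip)
  std::cout << PickAlgorithm(enabled, kGzip, kDeflate) << "\n";      // 0 (none)
  std::cout << PickAlgorithm(enabled, kGzip, kGzip) << "\n";         // 2 (gzip)
}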
+grpc_compression_algorithm FindCompressionAlgorithm( + grpc_metadata_batch* initial_metadata, ChannelData* channeld) { + if (initial_metadata->legacy_index()->named.grpc_internal_encoding_request == + nullptr) { + return channeld->default_compression_algorithm(); + } + grpc_compression_algorithm compression_algorithm; + // Parse the compression algorithm from the initial metadata. + grpc_mdelem md = initial_metadata->legacy_index() + ->named.grpc_internal_encoding_request->md; + GPR_ASSERT(grpc_compression_algorithm_parse(GRPC_MDVALUE(md), + &compression_algorithm)); + // Remove this metadata since it's an internal one (i.e., it won't be + // transmitted out). + initial_metadata->Remove(GRPC_BATCH_GRPC_INTERNAL_ENCODING_REQUEST); + // Check if that algorithm is enabled. Note that GRPC_COMPRESS_NONE is always + // enabled. + // TODO(juanlishen): Maybe use channel default or abort() if the algorithm + // from the initial metadata is disabled. + if (GPR_LIKELY( + grpc_core::GetBit(channeld->enabled_compression_algorithms_bitset(), + compression_algorithm))) { + return compression_algorithm; + } + const char* algorithm_name; + GPR_ASSERT( + grpc_compression_algorithm_name(compression_algorithm, &algorithm_name)); + gpr_log(GPR_ERROR, + "Invalid compression algorithm from initial metadata: '%s' " + "(previously disabled). " + "Will not compress.", + algorithm_name); + return GRPC_COMPRESS_NONE; +} + +void CallData::InitializeState(grpc_call_element* elem) { + GPR_DEBUG_ASSERT(!state_initialized_); + state_initialized_ = true; + grpc_slice_buffer_init(&slices_); + GRPC_CLOSURE_INIT(&send_message_on_complete_, SendMessageOnComplete, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_send_message_next_done_, OnSendMessageNextDone, elem, + grpc_schedule_on_exec_ctx); +} + +grpc_error_handle CallData::ProcessSendInitialMetadata( + grpc_call_element* elem, grpc_metadata_batch* initial_metadata) { + ChannelData* channeld = static_cast(elem->channel_data); + // Find the compression algorithm. + grpc_compression_algorithm compression_algorithm = + FindCompressionAlgorithm(initial_metadata, channeld); + // Note that at most one of the following algorithms can be set. + message_compression_algorithm_ = + grpc_compression_algorithm_to_message_compression_algorithm( + compression_algorithm); + grpc_stream_compression_algorithm stream_compression_algorithm = + grpc_compression_algorithm_to_stream_compression_algorithm( + compression_algorithm); + // Hint compression algorithm. + grpc_error_handle error = GRPC_ERROR_NONE; + if (message_compression_algorithm_ != GRPC_MESSAGE_COMPRESS_NONE) { + InitializeState(elem); + error = grpc_metadata_batch_add_tail( + initial_metadata, &message_compression_algorithm_storage_, + grpc_message_compression_encoding_mdelem( + message_compression_algorithm_), + GRPC_BATCH_GRPC_ENCODING); + } else if (stream_compression_algorithm != GRPC_STREAM_COMPRESS_NONE) { + InitializeState(elem); + error = grpc_metadata_batch_add_tail( + initial_metadata, &stream_compression_algorithm_storage_, + grpc_stream_compression_encoding_mdelem(stream_compression_algorithm), + GRPC_BATCH_CONTENT_ENCODING); + } + if (error != GRPC_ERROR_NONE) return error; + // Convey supported compression algorithms. 
+ error = grpc_metadata_batch_add_tail( + initial_metadata, &accept_encoding_storage_, + GRPC_MDELEM_ACCEPT_ENCODING_FOR_ALGORITHMS( + channeld->enabled_message_compression_algorithms_bitset()), + GRPC_BATCH_GRPC_ACCEPT_ENCODING); + if (error != GRPC_ERROR_NONE) return error; + // Do not overwrite accept-encoding header if it already presents (e.g., added + // by some proxy). + if (!initial_metadata->legacy_index()->named.accept_encoding) { + error = grpc_metadata_batch_add_tail( + initial_metadata, &accept_stream_encoding_storage_, + GRPC_MDELEM_ACCEPT_STREAM_ENCODING_FOR_ALGORITHMS( + channeld->enabled_stream_compression_algorithms_bitset()), + GRPC_BATCH_ACCEPT_ENCODING); + } + return error; +} + +void CallData::SendMessageOnComplete(void* calld_arg, grpc_error_handle error) { + CallData* calld = static_cast(calld_arg); + grpc_slice_buffer_reset_and_unref_internal(&calld->slices_); + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_send_message_on_complete_, + GRPC_ERROR_REF(error)); +} + +void CallData::SendMessageBatchContinue(grpc_call_element* elem) { + // Note: The call to grpc_call_next_op() results in yielding the + // call combiner, so we need to clear send_message_batch_ before we do that. + grpc_transport_stream_op_batch* send_message_batch = send_message_batch_; + send_message_batch_ = nullptr; + grpc_call_next_op(elem, send_message_batch); +} + +void CallData::FinishSendMessage(grpc_call_element* elem) { + GPR_DEBUG_ASSERT(message_compression_algorithm_ != + GRPC_MESSAGE_COMPRESS_NONE); + // Compress the data if appropriate. + grpc_slice_buffer tmp; + grpc_slice_buffer_init(&tmp); + uint32_t send_flags = + send_message_batch_->payload->send_message.send_message->flags(); + bool did_compress = + grpc_msg_compress(message_compression_algorithm_, &slices_, &tmp); + if (did_compress) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { + const char* algo_name; + const size_t before_size = slices_.length; + const size_t after_size = tmp.length; + const float savings_ratio = 1.0f - static_cast(after_size) / + static_cast(before_size); + GPR_ASSERT(grpc_message_compression_algorithm_name( + message_compression_algorithm_, &algo_name)); + gpr_log(GPR_INFO, + "Compressed[%s] %" PRIuPTR " bytes vs. %" PRIuPTR + " bytes (%.2f%% savings)", + algo_name, before_size, after_size, 100 * savings_ratio); + } + grpc_slice_buffer_swap(&slices_, &tmp); + send_flags |= GRPC_WRITE_INTERNAL_COMPRESS; + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { + const char* algo_name; + GPR_ASSERT(grpc_message_compression_algorithm_name( + message_compression_algorithm_, &algo_name)); + gpr_log(GPR_INFO, + "Algorithm '%s' enabled but decided not to compress. Input size: " + "%" PRIuPTR, + algo_name, slices_.length); + } + } + grpc_slice_buffer_destroy_internal(&tmp); + // Swap out the original byte stream with our new one and send the + // batch down. 
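Editor's note (illustrative, not part of the patch): the did_compress decision above keeps the compressed buffer only when it is actually smaller than the input, and logs the savings ratio when the compression trace flag is enabled. A rough stand-alone equivalent is sketched below using raw zlib instead of grpc_msg_compress(); MaybeCompress is a hypothetical helper, and the example assumes zlib is available.

#include <zlib.h>

#include <cstdio>
#include <string>
#include <vector>

// Compress `input` with zlib; return the original bytes unchanged when
// compression does not help, mirroring the did_compress decision above.
std::string MaybeCompress(const std::string& input, bool* did_compress) {
  uLongf dest_len = compressBound(input.size());
  std::vector<Bytef> dest(dest_len);
  int rc = compress(dest.data(), &dest_len,
                    reinterpret_cast<const Bytef*>(input.data()), input.size());
  *did_compress = (rc == Z_OK && dest_len < input.size());
  if (!*did_compress) return input;  // keep the uncompressed payload
  double savings = 1.0 - static_cast<double>(dest_len) / input.size();
  std::printf("Compressed %zu bytes to %lu bytes (%.2f%% savings)\n",
              input.size(), static_cast<unsigned long>(dest_len),
              100.0 * savings);
  return std::string(reinterpret_cast<const char*>(dest.data()), dest_len);
}

int main() {
  bool did_compress = false;
  std::string big(4096, 'a');          // highly compressible
  MaybeCompress(big, &did_compress);   // did_compress == true, savings logged
  MaybeCompress("hi", &did_compress);  // too small to shrink; left as-is
}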
+ new (&replacement_stream_) + grpc_core::SliceBufferByteStream(&slices_, send_flags); + send_message_batch_->payload->send_message.send_message.reset( + reinterpret_cast( + &replacement_stream_)); + original_send_message_on_complete_ = send_message_batch_->on_complete; + send_message_batch_->on_complete = &send_message_on_complete_; + SendMessageBatchContinue(elem); +} + +void CallData::FailSendMessageBatchInCallCombiner(void* calld_arg, + grpc_error_handle error) { + CallData* calld = static_cast(calld_arg); + if (calld->send_message_batch_ != nullptr) { + grpc_transport_stream_op_batch_finish_with_failure( + calld->send_message_batch_, GRPC_ERROR_REF(error), + calld->call_combiner_); + calld->send_message_batch_ = nullptr; + } +} + +// Pulls a slice from the send_message byte stream and adds it to slices_. +grpc_error_handle CallData::PullSliceFromSendMessage() { + grpc_slice incoming_slice; + grpc_error_handle error = + send_message_batch_->payload->send_message.send_message->Pull( + &incoming_slice); + if (error == GRPC_ERROR_NONE) { + grpc_slice_buffer_add(&slices_, incoming_slice); + } + return error; +} + +// Reads as many slices as possible from the send_message byte stream. +// If all data has been read, invokes FinishSendMessage(). Otherwise, +// an async call to ByteStream::Next() has been started, which will +// eventually result in calling OnSendMessageNextDone(). +void CallData::ContinueReadingSendMessage(grpc_call_element* elem) { + if (slices_.length == + send_message_batch_->payload->send_message.send_message->length()) { + FinishSendMessage(elem); + return; + } + while (send_message_batch_->payload->send_message.send_message->Next( + ~static_cast(0), &on_send_message_next_done_)) { + grpc_error_handle error = PullSliceFromSendMessage(); + if (error != GRPC_ERROR_NONE) { + // Closure callback; does not take ownership of error. + FailSendMessageBatchInCallCombiner(this, error); + GRPC_ERROR_UNREF(error); + return; + } + if (slices_.length == + send_message_batch_->payload->send_message.send_message->length()) { + FinishSendMessage(elem); + break; + } + } +} + +// Async callback for ByteStream::Next(). +void CallData::OnSendMessageNextDone(void* elem_arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(elem_arg); + CallData* calld = static_cast(elem->call_data); + if (error != GRPC_ERROR_NONE) { + // Closure callback; does not take ownership of error. + FailSendMessageBatchInCallCombiner(calld, error); + return; + } + error = calld->PullSliceFromSendMessage(); + if (error != GRPC_ERROR_NONE) { + // Closure callback; does not take ownership of error. + FailSendMessageBatchInCallCombiner(calld, error); + GRPC_ERROR_UNREF(error); + return; + } + if (calld->slices_.length == calld->send_message_batch_->payload->send_message + .send_message->length()) { + calld->FinishSendMessage(elem); + } else { + calld->ContinueReadingSendMessage(elem); + } +} + +void CallData::StartSendMessageBatch(void* elem_arg, + grpc_error_handle /*unused*/) { + grpc_call_element* elem = static_cast(elem_arg); + CallData* calld = static_cast(elem->call_data); + if (calld->SkipMessageCompression()) { + calld->SendMessageBatchContinue(elem); + } else { + calld->ContinueReadingSendMessage(elem); + } +} + +void CallData::CompressStartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + GPR_TIMER_SCOPE("compress_start_transport_stream_op_batch", 0); + // Handle cancel_stream. 
+ if (batch->cancel_stream) { + GRPC_ERROR_UNREF(cancel_error_); + cancel_error_ = GRPC_ERROR_REF(batch->payload->cancel_stream.cancel_error); + if (send_message_batch_ != nullptr) { + if (!seen_initial_metadata_) { + GRPC_CALL_COMBINER_START( + call_combiner_, + GRPC_CLOSURE_CREATE(FailSendMessageBatchInCallCombiner, this, + grpc_schedule_on_exec_ctx), + GRPC_ERROR_REF(cancel_error_), "failing send_message op"); + } else { + send_message_batch_->payload->send_message.send_message->Shutdown( + GRPC_ERROR_REF(cancel_error_)); + } + } + } else if (cancel_error_ != GRPC_ERROR_NONE) { + grpc_transport_stream_op_batch_finish_with_failure( + batch, GRPC_ERROR_REF(cancel_error_), call_combiner_); + return; + } + // Handle send_initial_metadata. + if (batch->send_initial_metadata) { + GPR_ASSERT(!seen_initial_metadata_); + grpc_error_handle error = ProcessSendInitialMetadata( + elem, batch->payload->send_initial_metadata.send_initial_metadata); + if (error != GRPC_ERROR_NONE) { + grpc_transport_stream_op_batch_finish_with_failure(batch, error, + call_combiner_); + return; + } + seen_initial_metadata_ = true; + // If we had previously received a batch containing a send_message op, + // handle it now. Note that we need to re-enter the call combiner + // for this, since we can't send two batches down while holding the + // call combiner, since the connected_channel filter (at the bottom of + // the call stack) will release the call combiner for each batch it sees. + if (send_message_batch_ != nullptr) { + GRPC_CALL_COMBINER_START( + call_combiner_, &start_send_message_batch_in_call_combiner_, + GRPC_ERROR_NONE, "starting send_message after send_initial_metadata"); + } + } + // Handle send_message. + if (batch->send_message) { + GPR_ASSERT(send_message_batch_ == nullptr); + send_message_batch_ = batch; + // If we have not yet seen send_initial_metadata, then we have to + // wait. We save the batch and then drop the call combiner, which we'll + // have to pick up again later when we get send_initial_metadata. + if (!seen_initial_metadata_) { + GRPC_CALL_COMBINER_STOP( + call_combiner_, "send_message batch pending send_initial_metadata"); + return; + } + StartSendMessageBatch(elem, GRPC_ERROR_NONE); + } else { + // Pass control down the stack. 
+ grpc_call_next_op(elem, batch); + } +} + +void CompressStartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + CallData* calld = static_cast(elem->call_data); + calld->CompressStartTransportStreamOpBatch(elem, batch); +} + +/* Constructor for call_data */ +grpc_error_handle CompressInitCallElem(grpc_call_element* elem, + const grpc_call_element_args* args) { + new (elem->call_data) CallData(elem, *args); + return GRPC_ERROR_NONE; +} + +/* Destructor for call_data */ +void CompressDestroyCallElem(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + CallData* calld = static_cast(elem->call_data); + calld->~CallData(); +} + +/* Constructor for ChannelData */ +grpc_error_handle CompressInitChannelElem(grpc_channel_element* elem, + grpc_channel_element_args* args) { + new (elem->channel_data) ChannelData(args); + return GRPC_ERROR_NONE; +} + +/* Destructor for channel data */ +void CompressDestroyChannelElem(grpc_channel_element* elem) { + ChannelData* channeld = static_cast(elem->channel_data); + channeld->~ChannelData(); +} + +} // namespace + +const grpc_channel_filter grpc_message_compress_filter = { + CompressStartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(CallData), + CompressInitCallElem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + CompressDestroyCallElem, + sizeof(ChannelData), + CompressInitChannelElem, + CompressDestroyChannelElem, + grpc_channel_next_get_info, + "message_compress"}; diff --git a/src/core/ext/filters/http/message_compress/message_decompress_filter.cc b/src/core/ext/filters/http/message_compress/message_decompress_filter.cc new file mode 100644 index 00000000..b4653506 --- /dev/null +++ b/src/core/ext/filters/http/message_compress/message_decompress_filter.cc @@ -0,0 +1,398 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include "src/core/ext/filters/http/message_compress/message_decompress_filter.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/message_size/message_size_filter.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/algorithm_metadata.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/compression/compression_internal.h" +#include "src/core/lib/compression/message_compress.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +namespace grpc_core { +namespace { + +class ChannelData { + public: + explicit ChannelData(const grpc_channel_element_args* args) + : max_recv_size_(GetMaxRecvSizeFromChannelArgs(args->channel_args)) {} + + int max_recv_size() const { return max_recv_size_; } + + private: + int max_recv_size_; +}; + +class CallData { + public: + CallData(const grpc_call_element_args& args, const ChannelData* chand) + : call_combiner_(args.call_combiner), + max_recv_message_length_(chand->max_recv_size()) { + // Initialize state for recv_initial_metadata_ready callback + GRPC_CLOSURE_INIT(&on_recv_initial_metadata_ready_, + OnRecvInitialMetadataReady, this, + grpc_schedule_on_exec_ctx); + // Initialize state for recv_message_ready callback + grpc_slice_buffer_init(&recv_slices_); + GRPC_CLOSURE_INIT(&on_recv_message_next_done_, OnRecvMessageNextDone, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_recv_message_ready_, OnRecvMessageReady, this, + grpc_schedule_on_exec_ctx); + // Initialize state for recv_trailing_metadata_ready callback + GRPC_CLOSURE_INIT(&on_recv_trailing_metadata_ready_, + OnRecvTrailingMetadataReady, this, + grpc_schedule_on_exec_ctx); + const MessageSizeParsedConfig* limits = + MessageSizeParsedConfig::GetFromCallContext(args.context); + if (limits != nullptr && limits->limits().max_recv_size >= 0 && + (limits->limits().max_recv_size < max_recv_message_length_ || + max_recv_message_length_ < 0)) { + max_recv_message_length_ = limits->limits().max_recv_size; + } + } + + ~CallData() { grpc_slice_buffer_destroy_internal(&recv_slices_); } + + void DecompressStartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch); + + private: + static void OnRecvInitialMetadataReady(void* arg, grpc_error_handle error); + + // Methods for processing a receive message event + void MaybeResumeOnRecvMessageReady(); + static void OnRecvMessageReady(void* arg, grpc_error_handle error); + static void OnRecvMessageNextDone(void* arg, grpc_error_handle error); + grpc_error_handle PullSliceFromRecvMessage(); + void ContinueReadingRecvMessage(); + void FinishRecvMessage(); + void ContinueRecvMessageReadyCallback(grpc_error_handle error); + + // Methods for processing a recv_trailing_metadata event + void MaybeResumeOnRecvTrailingMetadataReady(); + static void OnRecvTrailingMetadataReady(void* arg, grpc_error_handle error); + + CallCombiner* call_combiner_; + // Overall error for the call + grpc_error_handle error_ = GRPC_ERROR_NONE; + // Fields for handling recv_initial_metadata_ready callback + grpc_closure on_recv_initial_metadata_ready_; + grpc_closure* original_recv_initial_metadata_ready_ = nullptr; + grpc_metadata_batch* recv_initial_metadata_ = nullptr; + // Fields for handling recv_message_ready callback + bool seen_recv_message_ready_ = 
false; + int max_recv_message_length_; + grpc_message_compression_algorithm algorithm_ = GRPC_MESSAGE_COMPRESS_NONE; + grpc_closure on_recv_message_ready_; + grpc_closure* original_recv_message_ready_ = nullptr; + grpc_closure on_recv_message_next_done_; + OrphanablePtr* recv_message_ = nullptr; + // recv_slices_ holds the slices read from the original recv_message stream. + // It is initialized during construction and reset when a new stream is + // created using it. + grpc_slice_buffer recv_slices_; + std::aligned_storage::type + recv_replacement_stream_; + // Fields for handling recv_trailing_metadata_ready callback + bool seen_recv_trailing_metadata_ready_ = false; + grpc_closure on_recv_trailing_metadata_ready_; + grpc_closure* original_recv_trailing_metadata_ready_ = nullptr; + grpc_error_handle on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE; +}; + +grpc_message_compression_algorithm DecodeMessageCompressionAlgorithm( + grpc_mdelem md) { + grpc_message_compression_algorithm algorithm = + grpc_message_compression_algorithm_from_slice(GRPC_MDVALUE(md)); + if (algorithm == GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) { + char* md_c_str = grpc_slice_to_c_string(GRPC_MDVALUE(md)); + gpr_log(GPR_ERROR, + "Invalid incoming message compression algorithm: '%s'. " + "Interpreting incoming data as uncompressed.", + md_c_str); + gpr_free(md_c_str); + return GRPC_MESSAGE_COMPRESS_NONE; + } + return algorithm; +} + +void CallData::OnRecvInitialMetadataReady(void* arg, grpc_error_handle error) { + CallData* calld = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + grpc_linked_mdelem* grpc_encoding = + calld->recv_initial_metadata_->legacy_index()->named.grpc_encoding; + if (grpc_encoding != nullptr) { + calld->algorithm_ = DecodeMessageCompressionAlgorithm(grpc_encoding->md); + } + } + calld->MaybeResumeOnRecvMessageReady(); + calld->MaybeResumeOnRecvTrailingMetadataReady(); + grpc_closure* closure = calld->original_recv_initial_metadata_ready_; + calld->original_recv_initial_metadata_ready_ = nullptr; + Closure::Run(DEBUG_LOCATION, closure, GRPC_ERROR_REF(error)); +} + +void CallData::MaybeResumeOnRecvMessageReady() { + if (seen_recv_message_ready_) { + seen_recv_message_ready_ = false; + GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_message_ready_, + GRPC_ERROR_NONE, + "continue recv_message_ready callback"); + } +} + +void CallData::OnRecvMessageReady(void* arg, grpc_error_handle error) { + CallData* calld = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + if (calld->original_recv_initial_metadata_ready_ != nullptr) { + calld->seen_recv_message_ready_ = true; + GRPC_CALL_COMBINER_STOP(calld->call_combiner_, + "Deferring OnRecvMessageReady until after " + "OnRecvInitialMetadataReady"); + return; + } + if (calld->algorithm_ != GRPC_MESSAGE_COMPRESS_NONE) { + // recv_message can be NULL if trailing metadata is received instead of + // message, or it's possible that the message was not compressed. + if (*calld->recv_message_ == nullptr || + (*calld->recv_message_)->length() == 0 || + ((*calld->recv_message_)->flags() & GRPC_WRITE_INTERNAL_COMPRESS) == + 0) { + return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_NONE); + } + if (calld->max_recv_message_length_ >= 0 && + (*calld->recv_message_)->length() > + static_cast(calld->max_recv_message_length_)) { + GPR_DEBUG_ASSERT(calld->error_ == GRPC_ERROR_NONE); + calld->error_ = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Received message larger than max (%u vs. 
%d)", + (*calld->recv_message_)->length(), + calld->max_recv_message_length_)), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_RESOURCE_EXHAUSTED); + return calld->ContinueRecvMessageReadyCallback( + GRPC_ERROR_REF(calld->error_)); + } + grpc_slice_buffer_destroy_internal(&calld->recv_slices_); + grpc_slice_buffer_init(&calld->recv_slices_); + return calld->ContinueReadingRecvMessage(); + } + } + calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error)); +} + +void CallData::ContinueReadingRecvMessage() { + while ((*recv_message_) + ->Next((*recv_message_)->length() - recv_slices_.length, + &on_recv_message_next_done_)) { + grpc_error_handle error = PullSliceFromRecvMessage(); + if (error != GRPC_ERROR_NONE) { + return ContinueRecvMessageReadyCallback(error); + } + // We have read the entire message. + if (recv_slices_.length == (*recv_message_)->length()) { + return FinishRecvMessage(); + } + } +} + +grpc_error_handle CallData::PullSliceFromRecvMessage() { + grpc_slice incoming_slice; + grpc_error_handle error = (*recv_message_)->Pull(&incoming_slice); + if (error == GRPC_ERROR_NONE) { + grpc_slice_buffer_add(&recv_slices_, incoming_slice); + } + return error; +} + +void CallData::OnRecvMessageNextDone(void* arg, grpc_error_handle error) { + CallData* calld = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + return calld->ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error)); + } + error = calld->PullSliceFromRecvMessage(); + if (error != GRPC_ERROR_NONE) { + return calld->ContinueRecvMessageReadyCallback(error); + } + if (calld->recv_slices_.length == (*calld->recv_message_)->length()) { + calld->FinishRecvMessage(); + } else { + calld->ContinueReadingRecvMessage(); + } +} + +void CallData::FinishRecvMessage() { + grpc_slice_buffer decompressed_slices; + grpc_slice_buffer_init(&decompressed_slices); + if (grpc_msg_decompress(algorithm_, &recv_slices_, &decompressed_slices) == + 0) { + GPR_DEBUG_ASSERT(error_ == GRPC_ERROR_NONE); + error_ = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unexpected error decompressing data for algorithm with " + "enum value ", + algorithm_)); + grpc_slice_buffer_destroy_internal(&decompressed_slices); + } else { + uint32_t recv_flags = + ((*recv_message_)->flags() & (~GRPC_WRITE_INTERNAL_COMPRESS)) | + GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED; + // Swap out the original receive byte stream with our new one and send the + // batch down. + // Initializing recv_replacement_stream_ with decompressed_slices removes + // all the slices from decompressed_slices leaving it empty. + new (&recv_replacement_stream_) + SliceBufferByteStream(&decompressed_slices, recv_flags); + recv_message_->reset( + reinterpret_cast(&recv_replacement_stream_)); + recv_message_ = nullptr; + } + ContinueRecvMessageReadyCallback(GRPC_ERROR_REF(error_)); +} + +void CallData::ContinueRecvMessageReadyCallback(grpc_error_handle error) { + MaybeResumeOnRecvTrailingMetadataReady(); + // The surface will clean up the receiving stream if there is an error. 
+ grpc_closure* closure = original_recv_message_ready_; + original_recv_message_ready_ = nullptr; + Closure::Run(DEBUG_LOCATION, closure, error); +} + +void CallData::MaybeResumeOnRecvTrailingMetadataReady() { + if (seen_recv_trailing_metadata_ready_) { + seen_recv_trailing_metadata_ready_ = false; + grpc_error_handle error = on_recv_trailing_metadata_ready_error_; + on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE; + GRPC_CALL_COMBINER_START(call_combiner_, &on_recv_trailing_metadata_ready_, + error, "Continuing OnRecvTrailingMetadataReady"); + } +} + +void CallData::OnRecvTrailingMetadataReady(void* arg, grpc_error_handle error) { + CallData* calld = static_cast(arg); + if (calld->original_recv_initial_metadata_ready_ != nullptr || + calld->original_recv_message_ready_ != nullptr) { + calld->seen_recv_trailing_metadata_ready_ = true; + calld->on_recv_trailing_metadata_ready_error_ = GRPC_ERROR_REF(error); + GRPC_CALL_COMBINER_STOP( + calld->call_combiner_, + "Deferring OnRecvTrailingMetadataReady until after " + "OnRecvInitialMetadataReady and OnRecvMessageReady"); + return; + } + error = grpc_error_add_child(GRPC_ERROR_REF(error), calld->error_); + calld->error_ = GRPC_ERROR_NONE; + grpc_closure* closure = calld->original_recv_trailing_metadata_ready_; + calld->original_recv_trailing_metadata_ready_ = nullptr; + Closure::Run(DEBUG_LOCATION, closure, error); +} + +void CallData::DecompressStartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + // Handle recv_initial_metadata. + if (batch->recv_initial_metadata) { + recv_initial_metadata_ = + batch->payload->recv_initial_metadata.recv_initial_metadata; + original_recv_initial_metadata_ready_ = + batch->payload->recv_initial_metadata.recv_initial_metadata_ready; + batch->payload->recv_initial_metadata.recv_initial_metadata_ready = + &on_recv_initial_metadata_ready_; + } + // Handle recv_message + if (batch->recv_message) { + recv_message_ = batch->payload->recv_message.recv_message; + original_recv_message_ready_ = + batch->payload->recv_message.recv_message_ready; + batch->payload->recv_message.recv_message_ready = &on_recv_message_ready_; + } + // Handle recv_trailing_metadata + if (batch->recv_trailing_metadata) { + original_recv_trailing_metadata_ready_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &on_recv_trailing_metadata_ready_; + } + // Pass control down the stack. 
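Editor's note (illustrative, not part of the patch): the batch-interception pattern used throughout this filter — remember the surface's callback, install the filter's own in its place, post-process the data, then chain to the original — can be shown with plain std::function objects. Batch and DecompressLikeFilter below are illustrative names only, not gRPC types, and the "decompression" is a stand-in transform.

#include <functional>
#include <iostream>
#include <string>

// A "batch" here is just a payload plus the callback the surface wants run
// when the message is ready.
struct Batch {
  std::string message;
  std::function<void(const std::string&)> on_message_ready;
};

// Remember the original callback, install our own, post-process, then chain.
class DecompressLikeFilter {
 public:
  void StartBatch(Batch* batch) {
    original_ready_ = std::move(batch->on_message_ready);
    batch->on_message_ready = [this](const std::string& msg) {
      OnMessageReady(msg);
    };
  }

 private:
  void OnMessageReady(const std::string& msg) {
    std::string transformed = "[decompressed] " + msg;  // stand-in transform
    original_ready_(transformed);  // hand the result back up the stack
  }
  std::function<void(const std::string&)> original_ready_;
};

int main() {
  Batch batch;
  batch.message = "compressed-bytes";
  batch.on_message_ready = [](const std::string& m) { std::cout << m << "\n"; };
  DecompressLikeFilter filter;
  filter.StartBatch(&batch);
  // The transport would normally invoke this when data arrives:
  batch.on_message_ready(batch.message);  // prints "[decompressed] compressed-bytes"
}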
+ grpc_call_next_op(elem, batch); +} + +void DecompressStartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + GPR_TIMER_SCOPE("decompress_start_transport_stream_op_batch", 0); + CallData* calld = static_cast(elem->call_data); + calld->DecompressStartTransportStreamOpBatch(elem, batch); +} + +grpc_error_handle DecompressInitCallElem(grpc_call_element* elem, + const grpc_call_element_args* args) { + ChannelData* chand = static_cast(elem->channel_data); + new (elem->call_data) CallData(*args, chand); + return GRPC_ERROR_NONE; +} + +void DecompressDestroyCallElem(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + CallData* calld = static_cast(elem->call_data); + calld->~CallData(); +} + +grpc_error_handle DecompressInitChannelElem(grpc_channel_element* elem, + grpc_channel_element_args* args) { + ChannelData* chand = static_cast(elem->channel_data); + new (chand) ChannelData(args); + return GRPC_ERROR_NONE; +} + +void DecompressDestroyChannelElem(grpc_channel_element* elem) { + ChannelData* chand = static_cast(elem->channel_data); + chand->~ChannelData(); +} + +} // namespace + +const grpc_channel_filter MessageDecompressFilter = { + DecompressStartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(CallData), + DecompressInitCallElem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + DecompressDestroyCallElem, + sizeof(ChannelData), + DecompressInitChannelElem, + DecompressDestroyChannelElem, + grpc_channel_next_get_info, + "message_decompress"}; +} // namespace grpc_core diff --git a/src/core/ext/filters/http/server/http_server_filter.cc b/src/core/ext/filters/http/server/http_server_filter.cc new file mode 100644 index 00000000..0a2108d0 --- /dev/null +++ b/src/core/ext/filters/http/server/http_server_filter.cc @@ -0,0 +1,537 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/http/server/http_server_filter.h" + +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/b64.h" +#include "src/core/lib/slice/percent_encoding.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/static_metadata.h" + +#define EXPECTED_CONTENT_TYPE "application/grpc" +#define EXPECTED_CONTENT_TYPE_LENGTH (sizeof(EXPECTED_CONTENT_TYPE) - 1) + +static void hs_recv_initial_metadata_ready(void* user_data, + grpc_error_handle err); +static void hs_recv_trailing_metadata_ready(void* user_data, + grpc_error_handle err); +static void hs_recv_message_ready(void* user_data, grpc_error_handle err); + +namespace { + +struct call_data { + call_data(grpc_call_element* elem, const grpc_call_element_args& args) + : call_combiner(args.call_combiner) { + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready, + hs_recv_initial_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_message_ready, hs_recv_message_ready, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, + hs_recv_trailing_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + } + + ~call_data() { + GRPC_ERROR_UNREF(recv_initial_metadata_ready_error); + if (have_read_stream) { + read_stream->Orphan(); + } + } + + grpc_core::CallCombiner* call_combiner; + + // Outgoing headers to add to send_initial_metadata. + grpc_linked_mdelem status; + grpc_linked_mdelem content_type; + + // If we see the recv_message contents in the GET query string, we + // store it here. + grpc_core::ManualConstructor read_stream; + bool have_read_stream = false; + + // State for intercepting recv_initial_metadata. + grpc_closure recv_initial_metadata_ready; + grpc_error_handle recv_initial_metadata_ready_error = GRPC_ERROR_NONE; + grpc_closure* original_recv_initial_metadata_ready; + grpc_metadata_batch* recv_initial_metadata = nullptr; + uint32_t* recv_initial_metadata_flags; + bool seen_recv_initial_metadata_ready = false; + + // State for intercepting recv_message. 
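// ---- Illustrative aside (not part of this patch) ---------------------------
// GRPC_CLOSURE_INIT in the constructor above binds a free function to a
// user-data pointer (here, the call element). A stripped-down model of that
// C-style closure idiom, with invented names:
struct ToyClosure {
  void (*cb)(void* arg, int error);  // callback to run later
  void* arg;                         // user data handed back to the callback
};
inline void ToyClosureInit(ToyClosure* c, void (*cb)(void*, int), void* arg) {
  c->cb = cb;
  c->arg = arg;
}
inline void ToyClosureRun(ToyClosure* c, int error) { c->cb(c->arg, error); }
// Usage sketch: ToyClosureInit(&closure, on_ready, elem);
//               ToyClosureRun(&closure, /*error=*/0);
// -----------------------------------------------------------------------------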
+ grpc_closure* original_recv_message_ready; + grpc_closure recv_message_ready; + grpc_core::OrphanablePtr* recv_message; + bool seen_recv_message_ready = false; + + // State for intercepting recv_trailing_metadata + grpc_closure recv_trailing_metadata_ready; + grpc_closure* original_recv_trailing_metadata_ready; + grpc_error_handle recv_trailing_metadata_ready_error; + bool seen_recv_trailing_metadata_ready = false; +}; + +struct channel_data { + bool surface_user_agent; +}; + +} // namespace + +static grpc_error_handle hs_filter_outgoing_metadata(grpc_metadata_batch* b) { + if (b->legacy_index()->named.grpc_message != nullptr) { + grpc_slice pct_encoded_msg = grpc_core::PercentEncodeSlice( + GRPC_MDVALUE(b->legacy_index()->named.grpc_message->md), + grpc_core::PercentEncodingType::Compatible); + if (grpc_slice_is_equivalent( + pct_encoded_msg, + GRPC_MDVALUE(b->legacy_index()->named.grpc_message->md))) { + grpc_slice_unref_internal(pct_encoded_msg); + } else { + grpc_metadata_batch_set_value(b->legacy_index()->named.grpc_message, + pct_encoded_msg); + } + } + return GRPC_ERROR_NONE; +} + +static void hs_add_error(const char* error_name, grpc_error_handle* cumulative, + grpc_error_handle new_err) { + if (new_err == GRPC_ERROR_NONE) return; + if (*cumulative == GRPC_ERROR_NONE) { + *cumulative = GRPC_ERROR_CREATE_FROM_COPIED_STRING(error_name); + } + *cumulative = grpc_error_add_child(*cumulative, new_err); +} + +// Metadata equality within this filter leverages the fact that the sender was +// likely using the gRPC chttp2 transport, in which case the encoder would emit +// indexed values, in which case the local hpack parser would intern the +// relevant metadata, allowing a simple pointer comparison. +// +// That said, if the header was transmitted sans indexing/encoding, we still +// need to do the right thing. +// +// Assumptions: +// 1) The keys for a and b_static must match +// 2) b_static must be a statically allocated metadata object. +// 3) It is assumed that the remote end is indexing, but not necessary. +// TODO(arjunroy): Revisit this method when grpc_mdelem is strongly typed. +static bool md_strict_equal(grpc_mdelem a, grpc_mdelem b_static) { + // Hpack encoder on the remote side should emit indexed values, in which case + // hpack parser on this end should pick up interned values, in which case the + // pointer comparison alone is enough. 
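// ---- Illustrative aside (not part of this patch) ---------------------------
// The fast path described above relies on interning: equal interned values
// share one allocation, so pointer equality implies value equality, and a
// byte-wise comparison is only the fallback. A self-contained model:
#include <string>
#include <unordered_set>
class InternTable {
 public:
  // Returns a stable pointer; equal strings always map to the same pointer.
  const std::string* Intern(const std::string& s) {
    return &*pool_.insert(s).first;
  }
 private:
  std::unordered_set<std::string> pool_;
};
inline bool ValuesEqual(const std::string* a, const std::string* b) {
  if (a == b) return true;  // interned: pointer comparison is enough
  return *a == *b;          // not interned: compare the bytes
}
// -----------------------------------------------------------------------------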
+ // + if (GPR_LIKELY(GRPC_MDELEM_IS_INTERNED(a))) { + return a.payload == b_static.payload; + } else { + return grpc_slice_eq_static_interned(GRPC_MDVALUE(a), + GRPC_MDVALUE(b_static)); + } +} + +static grpc_error_handle hs_filter_incoming_metadata(grpc_call_element* elem, + grpc_metadata_batch* b) { + call_data* calld = static_cast(elem->call_data); + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* error_name = "Failed processing incoming headers"; + + if (b->legacy_index()->named.method != nullptr) { + if (md_strict_equal(b->legacy_index()->named.method->md, + GRPC_MDELEM_METHOD_POST)) { + *calld->recv_initial_metadata_flags &= + ~(GRPC_INITIAL_METADATA_CACHEABLE_REQUEST | + GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST); + } else if (md_strict_equal(b->legacy_index()->named.method->md, + GRPC_MDELEM_METHOD_PUT)) { + *calld->recv_initial_metadata_flags &= + ~GRPC_INITIAL_METADATA_CACHEABLE_REQUEST; + *calld->recv_initial_metadata_flags |= + GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST; + } else if (md_strict_equal(b->legacy_index()->named.method->md, + GRPC_MDELEM_METHOD_GET)) { + *calld->recv_initial_metadata_flags |= + GRPC_INITIAL_METADATA_CACHEABLE_REQUEST; + *calld->recv_initial_metadata_flags &= + ~GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST; + } else { + hs_add_error(error_name, &error, + grpc_attach_md_to_error( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad header"), + b->legacy_index()->named.method->md)); + } + b->Remove(GRPC_BATCH_METHOD); + } else { + hs_add_error(error_name, &error, + grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"), + GRPC_ERROR_STR_KEY, ":method")); + } + + auto te = b->Take(grpc_core::TeMetadata()); + if (te == grpc_core::TeMetadata::kTrailers) { + // Do nothing, ok. + } else if (!te.has_value()) { + hs_add_error(error_name, &error, + grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"), + GRPC_ERROR_STR_KEY, "te")); + } else { + hs_add_error(error_name, &error, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad te header")); + } + + if (b->legacy_index()->named.scheme != nullptr) { + if (!md_strict_equal(b->legacy_index()->named.scheme->md, + GRPC_MDELEM_SCHEME_HTTP) && + !md_strict_equal(b->legacy_index()->named.scheme->md, + GRPC_MDELEM_SCHEME_HTTPS) && + !grpc_mdelem_static_value_eq(b->legacy_index()->named.scheme->md, + GRPC_MDELEM_SCHEME_GRPC)) { + hs_add_error(error_name, &error, + grpc_attach_md_to_error( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad header"), + b->legacy_index()->named.scheme->md)); + } + b->Remove(GRPC_BATCH_SCHEME); + } else { + hs_add_error(error_name, &error, + grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"), + GRPC_ERROR_STR_KEY, ":scheme")); + } + + if (b->legacy_index()->named.content_type != nullptr) { + if (!grpc_mdelem_static_value_eq( + b->legacy_index()->named.content_type->md, + GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC)) { + if (grpc_slice_buf_start_eq( + GRPC_MDVALUE(b->legacy_index()->named.content_type->md), + EXPECTED_CONTENT_TYPE, EXPECTED_CONTENT_TYPE_LENGTH) && + (GRPC_SLICE_START_PTR(GRPC_MDVALUE( + b->legacy_index() + ->named.content_type->md))[EXPECTED_CONTENT_TYPE_LENGTH] == + '+' || + GRPC_SLICE_START_PTR(GRPC_MDVALUE( + b->legacy_index() + ->named.content_type->md))[EXPECTED_CONTENT_TYPE_LENGTH] == + ';')) { + /* Although the C implementation doesn't (currently) generate them, + any custom +-suffix is explicitly valid. 
*/ + /* TODO(klempner): We should consider preallocating common values such + as +proto or +json, or at least stashing them if we see them. */ + /* TODO(klempner): Should we be surfacing this to application code? */ + } else { + /* TODO(klempner): We're currently allowing this, but we shouldn't + see it without a proxy so log for now. */ + char* val = grpc_dump_slice( + GRPC_MDVALUE(b->legacy_index()->named.content_type->md), + GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "Unexpected content-type '%s'", val); + gpr_free(val); + } + } + b->Remove(GRPC_BATCH_CONTENT_TYPE); + } + + if (b->legacy_index()->named.path == nullptr) { + hs_add_error(error_name, &error, + grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"), + GRPC_ERROR_STR_KEY, ":path")); + } else if (*calld->recv_initial_metadata_flags & + GRPC_INITIAL_METADATA_CACHEABLE_REQUEST) { + /* We have a cacheable request made with GET verb. The path contains the + * query parameter which is base64 encoded request payload. */ + const char k_query_separator = '?'; + grpc_slice path_slice = GRPC_MDVALUE(b->legacy_index()->named.path->md); + uint8_t* path_ptr = GRPC_SLICE_START_PTR(path_slice); + size_t path_length = GRPC_SLICE_LENGTH(path_slice); + /* offset of the character '?' */ + size_t offset = 0; + for (offset = 0; offset < path_length && *path_ptr != k_query_separator; + path_ptr++, offset++) { + } + if (offset < path_length) { + grpc_slice query_slice = + grpc_slice_sub(path_slice, offset + 1, path_length); + + /* substitute path metadata with just the path (not query) */ + grpc_mdelem mdelem_path_without_query = grpc_mdelem_from_slices( + GRPC_MDSTR_PATH, grpc_slice_sub(path_slice, 0, offset)); + + (void)b->Substitute(b->legacy_index()->named.path, + mdelem_path_without_query); + + /* decode payload from query and add to the slice buffer to be returned */ + const int k_url_safe = 1; + grpc_slice_buffer read_slice_buffer; + grpc_slice_buffer_init(&read_slice_buffer); + grpc_slice_buffer_add( + &read_slice_buffer, + grpc_base64_decode_with_len( + reinterpret_cast GRPC_SLICE_START_PTR(query_slice), + GRPC_SLICE_LENGTH(query_slice), k_url_safe)); + calld->read_stream.Init(&read_slice_buffer, 0); + grpc_slice_buffer_destroy_internal(&read_slice_buffer); + calld->have_read_stream = true; + grpc_slice_unref_internal(query_slice); + } else { + gpr_log(GPR_ERROR, "GET request without QUERY"); + } + } + + if (b->legacy_index()->named.host != nullptr && + b->legacy_index()->named.authority == nullptr) { + grpc_linked_mdelem* el = b->legacy_index()->named.host; + grpc_mdelem md = GRPC_MDELEM_REF(el->md); + b->Remove(el); + hs_add_error( + error_name, &error, + grpc_metadata_batch_add_head( + b, el, + grpc_mdelem_from_slices(GRPC_MDSTR_AUTHORITY, + grpc_slice_ref_internal(GRPC_MDVALUE(md))), + GRPC_BATCH_AUTHORITY)); + GRPC_MDELEM_UNREF(md); + } + + if (b->legacy_index()->named.authority == nullptr) { + hs_add_error(error_name, &error, + grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Missing header"), + GRPC_ERROR_STR_KEY, ":authority")); + } + + channel_data* chand = static_cast(elem->channel_data); + if (!chand->surface_user_agent && + b->legacy_index()->named.user_agent != nullptr) { + b->Remove(GRPC_BATCH_USER_AGENT); + } + + return error; +} + +static void hs_recv_initial_metadata_ready(void* user_data, + grpc_error_handle err) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + calld->seen_recv_initial_metadata_ready = true; + if (err == GRPC_ERROR_NONE) 
{ + err = hs_filter_incoming_metadata(elem, calld->recv_initial_metadata); + calld->recv_initial_metadata_ready_error = GRPC_ERROR_REF(err); + if (calld->seen_recv_message_ready) { + // We've already seen the recv_message callback, but we previously + // deferred it, so we need to return it here. + // Replace the recv_message byte stream if needed. + if (calld->have_read_stream) { + calld->recv_message->reset(calld->read_stream.get()); + calld->have_read_stream = false; + } + // Re-enter call combiner for original_recv_message_ready, since the + // surface code will release the call combiner for each callback it + // receives. + GRPC_CALL_COMBINER_START( + calld->call_combiner, calld->original_recv_message_ready, + GRPC_ERROR_REF(err), + "resuming recv_message_ready from recv_initial_metadata_ready"); + } + } else { + (void)GRPC_ERROR_REF(err); + } + if (calld->seen_recv_trailing_metadata_ready) { + GRPC_CALL_COMBINER_START(calld->call_combiner, + &calld->recv_trailing_metadata_ready, + calld->recv_trailing_metadata_ready_error, + "resuming hs_recv_trailing_metadata_ready from " + "hs_recv_initial_metadata_ready"); + } + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_initial_metadata_ready, err); +} + +static void hs_recv_message_ready(void* user_data, grpc_error_handle err) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + calld->seen_recv_message_ready = true; + if (calld->seen_recv_initial_metadata_ready) { + // We've already seen the recv_initial_metadata callback, so + // replace the recv_message byte stream if needed and invoke the + // original recv_message callback immediately. + if (calld->have_read_stream) { + calld->recv_message->reset(calld->read_stream.get()); + calld->have_read_stream = false; + } + grpc_core::Closure::Run(DEBUG_LOCATION, calld->original_recv_message_ready, + GRPC_ERROR_REF(err)); + } else { + // We have not yet seen the recv_initial_metadata callback, so we + // need to wait to see if this is a GET request. + // Note that we release the call combiner here, so that other + // callbacks can run. 
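// ---- Illustrative aside (not part of this patch) ---------------------------
// The two callbacks above can fire in either order, but message delivery must
// wait until the initial metadata has been inspected (it may carry the
// GET-encoded payload). A simplified model of that re-ordering, using the
// same "remember the later event, replay it afterwards" idea:
#include <functional>
#include <utility>
class RecvSequencer {
 public:
  void OnInitialMetadata() {
    saw_initial_metadata_ = true;
    if (deferred_message_) {  // replay the message callback we held back
      std::function<void()> cb = std::move(deferred_message_);
      deferred_message_ = nullptr;
      cb();
    }
  }
  void OnMessage(std::function<void()> cb) {
    if (saw_initial_metadata_) {
      cb();  // safe to deliver immediately
    } else {
      deferred_message_ = std::move(cb);  // hold until metadata is seen
    }
  }
 private:
  bool saw_initial_metadata_ = false;
  std::function<void()> deferred_message_;
};
// -----------------------------------------------------------------------------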
+ GRPC_CALL_COMBINER_STOP( + calld->call_combiner, + "pausing recv_message_ready until recv_initial_metadata_ready"); + } +} + +static void hs_recv_trailing_metadata_ready(void* user_data, + grpc_error_handle err) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (!calld->seen_recv_initial_metadata_ready) { + calld->recv_trailing_metadata_ready_error = GRPC_ERROR_REF(err); + calld->seen_recv_trailing_metadata_ready = true; + GRPC_CALL_COMBINER_STOP(calld->call_combiner, + "deferring hs_recv_trailing_metadata_ready until " + "ater hs_recv_initial_metadata_ready"); + return; + } + err = grpc_error_add_child( + GRPC_ERROR_REF(err), + GRPC_ERROR_REF(calld->recv_initial_metadata_ready_error)); + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_trailing_metadata_ready, err); +} + +static grpc_error_handle hs_mutate_op(grpc_call_element* elem, + grpc_transport_stream_op_batch* op) { + /* grab pointers to our data from the call element */ + call_data* calld = static_cast(elem->call_data); + + if (op->send_initial_metadata) { + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* error_name = "Failed sending initial metadata"; + hs_add_error( + error_name, &error, + grpc_metadata_batch_add_head( + op->payload->send_initial_metadata.send_initial_metadata, + &calld->status, GRPC_MDELEM_STATUS_200, GRPC_BATCH_STATUS)); + hs_add_error(error_name, &error, + grpc_metadata_batch_add_tail( + op->payload->send_initial_metadata.send_initial_metadata, + &calld->content_type, + GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC, + GRPC_BATCH_CONTENT_TYPE)); + hs_add_error(error_name, &error, + hs_filter_outgoing_metadata( + op->payload->send_initial_metadata.send_initial_metadata)); + if (error != GRPC_ERROR_NONE) return error; + } + + if (op->recv_initial_metadata) { + /* substitute our callback for the higher callback */ + GPR_ASSERT(op->payload->recv_initial_metadata.recv_flags != nullptr); + calld->recv_initial_metadata = + op->payload->recv_initial_metadata.recv_initial_metadata; + calld->recv_initial_metadata_flags = + op->payload->recv_initial_metadata.recv_flags; + calld->original_recv_initial_metadata_ready = + op->payload->recv_initial_metadata.recv_initial_metadata_ready; + op->payload->recv_initial_metadata.recv_initial_metadata_ready = + &calld->recv_initial_metadata_ready; + } + + if (op->recv_message) { + calld->recv_message = op->payload->recv_message.recv_message; + calld->original_recv_message_ready = + op->payload->recv_message.recv_message_ready; + op->payload->recv_message.recv_message_ready = &calld->recv_message_ready; + } + + if (op->recv_trailing_metadata) { + calld->original_recv_trailing_metadata_ready = + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready; + } + + if (op->send_trailing_metadata) { + grpc_error_handle error = hs_filter_outgoing_metadata( + op->payload->send_trailing_metadata.send_trailing_metadata); + if (error != GRPC_ERROR_NONE) return error; + } + + return GRPC_ERROR_NONE; +} + +static void hs_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + GPR_TIMER_SCOPE("hs_start_transport_stream_op_batch", 0); + call_data* calld = static_cast(elem->call_data); + grpc_error_handle error = hs_mutate_op(elem, op); + if (error != GRPC_ERROR_NONE) { + grpc_transport_stream_op_batch_finish_with_failure(op, error, + 
calld->call_combiner); + } else { + grpc_call_next_op(elem, op); + } +} + +/* Constructor for call_data */ +static grpc_error_handle hs_init_call_elem(grpc_call_element* elem, + const grpc_call_element_args* args) { + new (elem->call_data) call_data(elem, *args); + return GRPC_ERROR_NONE; +} + +/* Destructor for call_data */ +static void hs_destroy_call_elem(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + call_data* calld = static_cast(elem->call_data); + calld->~call_data(); +} + +/* Constructor for channel_data */ +static grpc_error_handle hs_init_channel_elem(grpc_channel_element* elem, + grpc_channel_element_args* args) { + channel_data* chand = static_cast(elem->channel_data); + GPR_ASSERT(!args->is_last); + chand->surface_user_agent = grpc_channel_arg_get_bool( + grpc_channel_args_find(args->channel_args, + const_cast(GRPC_ARG_SURFACE_USER_AGENT)), + true); + return GRPC_ERROR_NONE; +} + +/* Destructor for channel data */ +static void hs_destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +const grpc_channel_filter grpc_http_server_filter = { + hs_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + hs_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + hs_destroy_call_elem, + sizeof(channel_data), + hs_init_channel_elem, + hs_destroy_channel_elem, + grpc_channel_next_get_info, + "http-server"}; diff --git a/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc b/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc new file mode 100644 index 00000000..9be2f739 --- /dev/null +++ b/src/core/ext/filters/load_reporting/server_load_reporting_filter.cc @@ -0,0 +1,364 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/filters/load_reporting/server_load_reporting_filter.h" + +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.h" +#include "src/core/ext/filters/load_reporting/registered_opencensus_objects.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/context.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/uri/uri_parser.h" + +namespace grpc { + +constexpr char kEncodedIpv4AddressLengthString[] = "08"; +constexpr char kEncodedIpv6AddressLengthString[] = "32"; +constexpr char kEmptyAddressLengthString[] = "00"; +constexpr size_t kLengthPrefixSize = 2; + +grpc_error_handle ServerLoadReportingChannelData::Init( + grpc_channel_element* /* elem */, grpc_channel_element_args* args) { + GPR_ASSERT(!args->is_last); + // Find and record the peer_identity. + const grpc_auth_context* auth_context = + grpc_find_auth_context_in_args(args->channel_args); + if (auth_context != nullptr && + grpc_auth_context_peer_is_authenticated(auth_context)) { + grpc_auth_property_iterator auth_it = + grpc_auth_context_peer_identity(auth_context); + const grpc_auth_property* auth_property = + grpc_auth_property_iterator_next(&auth_it); + if (auth_property != nullptr) { + peer_identity_ = auth_property->value; + peer_identity_len_ = auth_property->value_length; + } + } + return GRPC_ERROR_NONE; +} + +void ServerLoadReportingCallData::Destroy( + grpc_call_element* elem, const grpc_call_final_info* final_info, + grpc_closure* /*then_call_closure*/) { + ServerLoadReportingChannelData* chand = + reinterpret_cast(elem->channel_data); + // Only record an end if we've recorded its corresponding start, which is + // indicated by a non-null client_ip_and_lr_token_. Note that it's possible + // that we attempt to record the call end before we have recorded the call + // start, because the data needed for recording the start comes from the + // initial metadata, which may not be ready before the call finishes. 
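// ---- Illustrative aside (not part of this patch) ---------------------------
// The guard that follows pairs an "end" record with its "start" record: the
// end is emitted only if the data needed at start time ever became available.
// A small stand-alone model of that rule (types and names are invented,
// assumes C++17 for std::optional):
#include <optional>
#include <string>
#include <utility>
struct CallStartInfo {
  std::string token;  // filled in from initial metadata; may never arrive
};
class CallTracker {
 public:
  void RecordStart(CallStartInfo info) { start_ = std::move(info); }
  // Emits the "end" record only if a "start" record was ever captured.
  bool RecordEndIfStarted() {
    if (!start_.has_value()) return false;  // start never happened; skip end
    // ... emit the paired end measurement, tagged with start_->token ...
    return true;
  }
 private:
  std::optional<CallStartInfo> start_;
};
// -----------------------------------------------------------------------------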
+ if (client_ip_and_lr_token_ != nullptr) { + opencensus::stats::Record( + {{::grpc::load_reporter::MeasureEndCount(), 1}, + {::grpc::load_reporter::MeasureEndBytesSent(), + final_info->stats.transport_stream_stats.outgoing.data_bytes}, + {::grpc::load_reporter::MeasureEndBytesReceived(), + final_info->stats.transport_stream_stats.incoming.data_bytes}, + {::grpc::load_reporter::MeasureEndLatencyMs(), + gpr_time_to_millis(final_info->stats.latency)}}, + {{::grpc::load_reporter::TagKeyToken(), + {client_ip_and_lr_token_, client_ip_and_lr_token_len_}}, + {::grpc::load_reporter::TagKeyHost(), + {target_host_, target_host_len_}}, + {::grpc::load_reporter::TagKeyUserId(), + {chand->peer_identity(), chand->peer_identity_len()}}, + {::grpc::load_reporter::TagKeyStatus(), + GetStatusTagForStatus(final_info->final_status)}}); + gpr_free(client_ip_and_lr_token_); + } + gpr_free(target_host_); + grpc_slice_unref_internal(service_method_); +} + +void ServerLoadReportingCallData::StartTransportStreamOpBatch( + grpc_call_element* elem, TransportStreamOpBatch* op) { + GPR_TIMER_SCOPE("lr_start_transport_stream_op", 0); + if (op->recv_initial_metadata() != nullptr) { + // Save some fields to use when initial metadata is ready. + peer_string_ = op->get_peer_string(); + recv_initial_metadata_ = + op->op()->payload->recv_initial_metadata.recv_initial_metadata; + original_recv_initial_metadata_ready_ = op->recv_initial_metadata_ready(); + // Substitute the original closure for the wrapper closure. + op->set_recv_initial_metadata_ready(&recv_initial_metadata_ready_); + } else if (op->send_trailing_metadata() != nullptr) { + GRPC_LOG_IF_ERROR( + "server_load_reporting_filter", + grpc_metadata_batch_filter(op->send_trailing_metadata()->batch(), + SendTrailingMetadataFilter, elem, + "send_trailing_metadata filtering error")); + } + grpc_call_next_op(elem, op->op()); +} + +std::string ServerLoadReportingCallData::GetCensusSafeClientIpString() { + // Find the client URI string. + const char* client_uri_str = + reinterpret_cast(gpr_atm_acq_load(peer_string_)); + if (client_uri_str == nullptr) { + gpr_log(GPR_ERROR, + "Unable to extract client URI string (peer string) from gRPC " + "metadata."); + return ""; + } + absl::StatusOr client_uri = + grpc_core::URI::Parse(client_uri_str); + if (!client_uri.ok()) { + gpr_log(GPR_ERROR, + "Unable to parse the client URI string (peer string) to a client " + "URI. Error: %s", + client_uri.status().ToString().c_str()); + return ""; + } + // Parse the client URI into grpc_resolved_address. + grpc_resolved_address resolved_address; + bool success = grpc_parse_uri(*client_uri, &resolved_address); + if (!success) { + gpr_log(GPR_ERROR, + "Unable to parse client URI into a grpc_resolved_address."); + return ""; + } + // Convert the socket address in the grpc_resolved_address into a hex string + // according to the address family. 
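// ---- Illustrative aside (not part of this patch) ---------------------------
// The conversion below emits exactly 8 hex characters for an IPv4 address and
// 32 for an IPv6 address. Formatting the raw network-order bytes two hex
// digits at a time produces the same fixed-width encoding; a generic sketch:
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
inline std::string BytesToHex(const uint8_t* bytes, std::size_t len) {
  std::string out;
  out.reserve(len * 2);
  char buf[3];
  for (std::size_t i = 0; i < len; ++i) {
    std::snprintf(buf, sizeof(buf), "%02x", bytes[i]);
    out.append(buf, 2);
  }
  return out;  // 4 bytes -> 8 chars (IPv4), 16 bytes -> 32 chars (IPv6)
}
// -----------------------------------------------------------------------------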
+ grpc_sockaddr* addr = reinterpret_cast(resolved_address.addr); + if (addr->sa_family == GRPC_AF_INET) { + grpc_sockaddr_in* addr4 = reinterpret_cast(addr); + return absl::StrFormat("%08x", grpc_ntohl(addr4->sin_addr.s_addr)); + } else if (addr->sa_family == GRPC_AF_INET6) { + grpc_sockaddr_in6* addr6 = reinterpret_cast(addr); + std::string client_ip; + client_ip.reserve(32); + uint32_t* addr6_next_long = reinterpret_cast(&addr6->sin6_addr); + for (size_t i = 0; i < 4; ++i) { + absl::StrAppendFormat(&client_ip, "%08x", grpc_ntohl(*addr6_next_long++)); + } + return client_ip; + } else { + GPR_UNREACHABLE_CODE(); + } +} + +void ServerLoadReportingCallData::StoreClientIpAndLrToken(const char* lr_token, + size_t lr_token_len) { + std::string client_ip = GetCensusSafeClientIpString(); + client_ip_and_lr_token_len_ = + kLengthPrefixSize + client_ip.size() + lr_token_len; + client_ip_and_lr_token_ = static_cast( + gpr_zalloc(client_ip_and_lr_token_len_ * sizeof(char))); + char* cur_pos = client_ip_and_lr_token_; + // Store the IP length prefix. + if (client_ip.empty()) { + strncpy(cur_pos, kEmptyAddressLengthString, kLengthPrefixSize); + } else if (client_ip.size() == 8) { + strncpy(cur_pos, kEncodedIpv4AddressLengthString, kLengthPrefixSize); + } else if (client_ip.size() == 32) { + strncpy(cur_pos, kEncodedIpv6AddressLengthString, kLengthPrefixSize); + } else { + GPR_UNREACHABLE_CODE(); + } + cur_pos += kLengthPrefixSize; + // Store the IP. + if (!client_ip.empty()) { + strncpy(cur_pos, client_ip.c_str(), client_ip.size()); + } + cur_pos += client_ip.size(); + // Store the LR token. + if (lr_token_len != 0) { + strncpy(cur_pos, lr_token, lr_token_len); + } + GPR_ASSERT( + static_cast(cur_pos + lr_token_len - client_ip_and_lr_token_) == + client_ip_and_lr_token_len_); +} + +grpc_filtered_mdelem ServerLoadReportingCallData::RecvInitialMetadataFilter( + void* user_data, grpc_mdelem md) { + grpc_call_element* elem = reinterpret_cast(user_data); + ServerLoadReportingCallData* calld = + reinterpret_cast(elem->call_data); + if (grpc_slice_eq(GRPC_MDKEY(md), GRPC_MDSTR_PATH)) { + calld->service_method_ = grpc_slice_ref_internal(GRPC_MDVALUE(md)); + } else if (calld->target_host_ == nullptr && + grpc_slice_eq(GRPC_MDKEY(md), GRPC_MDSTR_AUTHORITY)) { + grpc_slice target_host_slice = GRPC_MDVALUE(md); + calld->target_host_len_ = GRPC_SLICE_LENGTH(target_host_slice); + calld->target_host_ = + reinterpret_cast(gpr_zalloc(calld->target_host_len_)); + for (size_t i = 0; i < calld->target_host_len_; ++i) { + calld->target_host_[i] = static_cast( + tolower(GRPC_SLICE_START_PTR(target_host_slice)[i])); + } + } else if (grpc_slice_str_cmp(GRPC_MDKEY(md), + grpc_core::kGrpcLbLbTokenMetadataKey) == 0) { + if (calld->client_ip_and_lr_token_ == nullptr) { + calld->StoreClientIpAndLrToken( + reinterpret_cast GRPC_SLICE_START_PTR(GRPC_MDVALUE(md)), + GRPC_SLICE_LENGTH(GRPC_MDVALUE(md))); + } + return GRPC_FILTERED_REMOVE(); + } + return GRPC_FILTERED_MDELEM(md); +} + +void ServerLoadReportingCallData::RecvInitialMetadataReady( + void* arg, grpc_error_handle err) { + grpc_call_element* elem = reinterpret_cast(arg); + ServerLoadReportingCallData* calld = + reinterpret_cast(elem->call_data); + ServerLoadReportingChannelData* chand = + reinterpret_cast(elem->channel_data); + if (err == GRPC_ERROR_NONE) { + GRPC_LOG_IF_ERROR( + "server_load_reporting_filter", + grpc_metadata_batch_filter(calld->recv_initial_metadata_, + RecvInitialMetadataFilter, elem, + "recv_initial_metadata filtering error")); + // If the LB token was 
not found in the recv_initial_metadata, only the + // client IP part will be recorded (with an empty LB token). + if (calld->client_ip_and_lr_token_ == nullptr) { + calld->StoreClientIpAndLrToken(nullptr, 0); + } + opencensus::stats::Record( + {{::grpc::load_reporter::MeasureStartCount(), 1}}, + {{::grpc::load_reporter::TagKeyToken(), + {calld->client_ip_and_lr_token_, calld->client_ip_and_lr_token_len_}}, + {::grpc::load_reporter::TagKeyHost(), + {calld->target_host_, calld->target_host_len_}}, + {::grpc::load_reporter::TagKeyUserId(), + {chand->peer_identity(), chand->peer_identity_len()}}}); + } + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_initial_metadata_ready_, + GRPC_ERROR_REF(err)); +} + +grpc_error_handle ServerLoadReportingCallData::Init( + grpc_call_element* elem, const grpc_call_element_args* /*args*/) { + service_method_ = grpc_empty_slice(); + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_, RecvInitialMetadataReady, + elem, grpc_schedule_on_exec_ctx); + return GRPC_ERROR_NONE; +} + +grpc_filtered_mdelem ServerLoadReportingCallData::SendTrailingMetadataFilter( + void* user_data, grpc_mdelem md) { + grpc_call_element* elem = reinterpret_cast(user_data); + ServerLoadReportingCallData* calld = + reinterpret_cast(elem->call_data); + ServerLoadReportingChannelData* chand = + reinterpret_cast(elem->channel_data); + if (grpc_slice_eq(GRPC_MDKEY(md), GRPC_MDSTR_LB_COST_BIN)) { + const grpc_slice value = GRPC_MDVALUE(md); + const size_t cost_entry_size = GRPC_SLICE_LENGTH(value); + if (cost_entry_size < sizeof(double)) { + gpr_log(GPR_ERROR, + "Cost metadata value too small (%zu bytes) to hold valid data. " + "Ignoring.", + cost_entry_size); + return GRPC_FILTERED_REMOVE(); + } + const double* cost_entry_ptr = + reinterpret_cast(GRPC_SLICE_START_PTR(value)); + double cost_value; + memcpy(&cost_value, cost_entry_ptr, sizeof(double)); + cost_entry_ptr++; + const char* cost_name = reinterpret_cast(cost_entry_ptr); + const size_t cost_name_len = cost_entry_size - sizeof(double); + opencensus::stats::Record( + {{::grpc::load_reporter::MeasureOtherCallMetric(), cost_value}}, + {{::grpc::load_reporter::TagKeyToken(), + {calld->client_ip_and_lr_token_, calld->client_ip_and_lr_token_len_}}, + {::grpc::load_reporter::TagKeyHost(), + {calld->target_host_, calld->target_host_len_}}, + {::grpc::load_reporter::TagKeyUserId(), + {chand->peer_identity(), chand->peer_identity_len()}}, + {::grpc::load_reporter::TagKeyMetricName(), + {cost_name, cost_name_len}}}); + return GRPC_FILTERED_REMOVE(); + } + return GRPC_FILTERED_MDELEM(md); +} + +const char* ServerLoadReportingCallData::GetStatusTagForStatus( + grpc_status_code status) { + switch (status) { + case GRPC_STATUS_OK: + return ::grpc::load_reporter::kCallStatusOk; + case GRPC_STATUS_UNKNOWN: + case GRPC_STATUS_DEADLINE_EXCEEDED: + case GRPC_STATUS_UNIMPLEMENTED: + case GRPC_STATUS_INTERNAL: + case GRPC_STATUS_UNAVAILABLE: + case GRPC_STATUS_DATA_LOSS: + return ::grpc::load_reporter::kCallStatusServerError; + default: + return ::grpc::load_reporter::kCallStatusClientError; + } +} + +namespace { +bool MaybeAddServerLoadReportingFilter(const grpc_channel_args& args) { + return grpc_channel_arg_get_bool( + grpc_channel_args_find(&args, GRPC_ARG_ENABLE_LOAD_REPORTING), false); +} +} // namespace + +// TODO(juanlishen): We should register the filter during grpc initialization +// time once OpenCensus is compatible with our build system. 
For now, we force +// registration of the server load reporting filter at static initialization +// time if we build with the filter target. +struct ServerLoadReportingFilterStaticRegistrar { + ServerLoadReportingFilterStaticRegistrar() { + static std::atomic registered{false}; + if (registered.load(std::memory_order_acquire)) return; + RegisterChannelFilter( + "server_load_reporting", GRPC_SERVER_CHANNEL, INT_MAX, + MaybeAddServerLoadReportingFilter); + // Access measures to ensure they are initialized. Otherwise, we can't + // create any valid view before the first RPC. + ::grpc::load_reporter::MeasureStartCount(); + ::grpc::load_reporter::MeasureEndCount(); + ::grpc::load_reporter::MeasureEndBytesSent(); + ::grpc::load_reporter::MeasureEndBytesReceived(); + ::grpc::load_reporter::MeasureEndLatencyMs(); + ::grpc::load_reporter::MeasureOtherCallMetric(); + registered.store(true, std::memory_order_release); + } +} server_load_reporting_filter_static_registrar; + +} // namespace grpc diff --git a/src/core/ext/filters/max_age/max_age_filter.cc b/src/core/ext/filters/max_age/max_age_filter.cc new file mode 100644 index 00000000..c0687284 --- /dev/null +++ b/src/core/ext/filters/max_age/max_age_filter.cc @@ -0,0 +1,560 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/filters/max_age/max_age_filter.h" + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/transport/http2_errors.h" + +/* If these settings change, make sure that we are not sending a GOAWAY for + * inproc transport, since a GOAWAY to inproc ends up destroying the transport. + */ +#define DEFAULT_MAX_CONNECTION_AGE_MS INT_MAX +#define DEFAULT_MAX_CONNECTION_AGE_GRACE_MS INT_MAX +#define DEFAULT_MAX_CONNECTION_IDLE_MS INT_MAX +#define MAX_CONNECTION_AGE_JITTER 0.1 + +#define MAX_CONNECTION_AGE_INTEGER_OPTIONS \ + { DEFAULT_MAX_CONNECTION_AGE_MS, 1, INT_MAX } +#define MAX_CONNECTION_IDLE_INTEGER_OPTIONS \ + { DEFAULT_MAX_CONNECTION_IDLE_MS, 1, INT_MAX } + +/* States for idle_state in channel_data */ +#define MAX_IDLE_STATE_INIT ((gpr_atm)0) +#define MAX_IDLE_STATE_SEEN_EXIT_IDLE ((gpr_atm)1) +#define MAX_IDLE_STATE_SEEN_ENTER_IDLE ((gpr_atm)2) +#define MAX_IDLE_STATE_TIMER_SET ((gpr_atm)3) + +namespace { +struct channel_data { + /* The channel stack to which we take refs for pending callbacks. 
*/
+ grpc_channel_stack* channel_stack;
+ /* Guards access to max_age_timer, max_age_timer_pending, max_age_grace_timer
+ and max_age_grace_timer_pending */
+ grpc_core::Mutex max_age_timer_mu;
+ /* True if the max_age timer callback is currently pending */
+ bool max_age_timer_pending ABSL_GUARDED_BY(max_age_timer_mu) = false;
+ /* True if the max_age_grace timer callback is currently pending */
+ bool max_age_grace_timer_pending ABSL_GUARDED_BY(max_age_timer_mu) = false;
+ /* The timer for checking if the channel has reached its max age */
+ grpc_timer max_age_timer ABSL_GUARDED_BY(max_age_timer_mu);
+ /* The timer for checking if the max-aged channel has used up the grace
+ period */
+ grpc_timer max_age_grace_timer ABSL_GUARDED_BY(max_age_timer_mu);
+ /* The timer for checking if the channel's idle duration reaches
+ max_connection_idle */
+ grpc_timer max_idle_timer;
+ /* Allowed max time a channel may have no outstanding rpcs */
+ grpc_millis max_connection_idle;
+ /* Allowed max time a channel may exist */
+ grpc_millis max_connection_age;
+ /* Allowed grace period after the channel reaches its max age */
+ grpc_millis max_connection_age_grace;
+ /* Closure to run when the channel's idle duration reaches max_connection_idle
+ and should be closed gracefully */
+ grpc_closure max_idle_timer_cb;
+ /* Closure to run when the channel reaches its max age and should be closed
+ gracefully */
+ grpc_closure close_max_age_channel;
+ /* Closure to run when the channel uses up its max age grace time and should
+ be closed forcibly */
+ grpc_closure force_close_max_age_channel;
+ /* Closure to run when the init of the channel stack is done and the max_idle
+ timer should be started */
+ grpc_closure start_max_idle_timer_after_init;
+ /* Closure to run when the init of the channel stack is done and the max_age
+ timer should be started */
+ grpc_closure start_max_age_timer_after_init;
+ /* Closure to run when the goaway op is finished and the max_age_grace_timer
+ should be started */
+ grpc_closure start_max_age_grace_timer_after_goaway_op;
+ /* Number of active calls */
+ gpr_atm call_count;
+ /* TODO(zyc): C++lize this state machine */
+ /* 'idle_state' holds the states of max_idle_timer and channel idleness.
+ It can contain one of the following values:
+ +--------------------------------+----------------+---------+
+ | idle_state | max_idle_timer | channel |
+ +--------------------------------+----------------+---------+
+ | MAX_IDLE_STATE_INIT | unset | busy |
+ | MAX_IDLE_STATE_TIMER_SET | set, valid | idle |
+ | MAX_IDLE_STATE_SEEN_EXIT_IDLE | set, invalid | busy |
+ | MAX_IDLE_STATE_SEEN_ENTER_IDLE | set, invalid | idle |
+ +--------------------------------+----------------+---------+
+
+ MAX_IDLE_STATE_INIT: The initial and final state of 'idle_state'. The
+ channel has one or more active calls, and the timer is not set. Note that
+ we may put a virtual call to hold this state at channel initialization or
+ shutdown, so that the channel won't enter other states.
+
+ MAX_IDLE_STATE_TIMER_SET: The state after the timer is set and no calls
+ have arrived after the timer is set. The channel must have 0 active calls in
+ this state. If the timer is fired in this state, we will close the channel
+ due to idleness.
+
+ MAX_IDLE_STATE_SEEN_EXIT_IDLE: The state after the timer is set and at
+ least one call has arrived after the timer is set. The channel must have one
+ or more active calls in this state. If the timer is fired in this state, we
+ won't reschedule it.
+
+ MAX_IDLE_STATE_SEEN_ENTER_IDLE: The state after the timer is set and at
+ least one call has arrived after the timer is set, BUT the channel
+ currently has 0 active calls. If the timer is fired in this state, we will
+ reschedule it.
+
+ max_idle_timer will not be cancelled (unless the channel is shutting down).
+ If the timer callback is called when the max_idle_timer is valid (i.e.
+ idle_state is MAX_IDLE_STATE_TIMER_SET), the channel will be closed due to
+ idleness, otherwise the channel won't be changed.
+
+ State transitions:
+ MAX_IDLE_STATE_INIT <-------3------ MAX_IDLE_STATE_SEEN_EXIT_IDLE
+ ^ | ^ ^ |
+ | | | | |
+ 1 2 +-----------4------------+ 6 7
+ | | | | |
+ | v | | v
+ MAX_IDLE_STATE_TIMER_SET <----5------ MAX_IDLE_STATE_SEEN_ENTER_IDLE
+
+ For 1, 3, 5 : See max_idle_timer_cb() function
+ For 2, 7 : See decrease_call_count() function
+ For 4, 6 : See increase_call_count() function */
+ gpr_atm idle_state;
+ /* Time when the channel finished its last outstanding call, in grpc_millis */
+ gpr_atm last_enter_idle_time_millis;
+};
+} // namespace
+
+/* Increase the number of active calls. Before the increment, if there are no
+ calls, the max_idle_timer should be cancelled. */
+static void increase_call_count(channel_data* chand) {
+ /* Exit idle */
+ if (gpr_atm_full_fetch_add(&chand->call_count, 1) == 0) {
+ while (true) {
+ gpr_atm idle_state = gpr_atm_acq_load(&chand->idle_state);
+ switch (idle_state) {
+ case MAX_IDLE_STATE_TIMER_SET:
+ /* max_idle_timer_cb may have already set idle_state to
+ MAX_IDLE_STATE_INIT, in this case, we don't need to set it to
+ MAX_IDLE_STATE_SEEN_EXIT_IDLE */
+ gpr_atm_rel_cas(&chand->idle_state, MAX_IDLE_STATE_TIMER_SET,
+ MAX_IDLE_STATE_SEEN_EXIT_IDLE);
+ return;
+ case MAX_IDLE_STATE_SEEN_ENTER_IDLE:
+ gpr_atm_rel_store(&chand->idle_state, MAX_IDLE_STATE_SEEN_EXIT_IDLE);
+ return;
+ default:
+ /* try again */
+ break;
+ }
+ }
+ }
+}
+
+/* Decrease the number of active calls. After the decrement, if there are no
+ calls, the max_idle_timer should be started. */
+static void decrease_call_count(channel_data* chand) {
+ /* Enter idle */
+ if (gpr_atm_full_fetch_add(&chand->call_count, -1) == 1) {
+ gpr_atm_no_barrier_store(&chand->last_enter_idle_time_millis,
+ (gpr_atm)grpc_core::ExecCtx::Get()->Now());
+ while (true) {
+ gpr_atm idle_state = gpr_atm_acq_load(&chand->idle_state);
+ switch (idle_state) {
+ case MAX_IDLE_STATE_INIT:
+ GRPC_CHANNEL_STACK_REF(chand->channel_stack,
+ "max_age max_idle_timer");
+ grpc_timer_init(
+ &chand->max_idle_timer,
+ grpc_core::ExecCtx::Get()->Now() + chand->max_connection_idle,
+ &chand->max_idle_timer_cb);
+ gpr_atm_rel_store(&chand->idle_state, MAX_IDLE_STATE_TIMER_SET);
+ return;
+ case MAX_IDLE_STATE_SEEN_EXIT_IDLE:
+ if (gpr_atm_rel_cas(&chand->idle_state, MAX_IDLE_STATE_SEEN_EXIT_IDLE,
+ MAX_IDLE_STATE_SEEN_ENTER_IDLE)) {
+ return;
+ }
+ break;
+ default:
+ /* try again */
+ break;
+ }
+ }
+ }
+}
+
+static void start_max_idle_timer_after_init(void* arg,
+ grpc_error_handle /*error*/) {
+ channel_data* chand = static_cast<channel_data*>(arg);
+ /* Decrease call_count. If there are no active calls at this time,
+ max_idle_timer will start here. If the number of active calls is not 0,
+ max_idle_timer will start after all the active calls end.
*/ + decrease_call_count(chand); + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack, + "max_age start_max_idle_timer_after_init"); +} + +namespace grpc_core { + +class ConnectivityWatcher : public AsyncConnectivityStateWatcherInterface { + public: + explicit ConnectivityWatcher(channel_data* chand) : chand_(chand) { + GRPC_CHANNEL_STACK_REF(chand_->channel_stack, "max_age conn_watch"); + } + + ~ConnectivityWatcher() override { + GRPC_CHANNEL_STACK_UNREF(chand_->channel_stack, "max_age conn_watch"); + } + + private: + void OnConnectivityStateChange(grpc_connectivity_state new_state, + const absl::Status& /* status */) override { + if (new_state != GRPC_CHANNEL_SHUTDOWN) return; + { + MutexLock lock(&chand_->max_age_timer_mu); + if (chand_->max_age_timer_pending) { + grpc_timer_cancel(&chand_->max_age_timer); + chand_->max_age_timer_pending = false; + } + if (chand_->max_age_grace_timer_pending) { + grpc_timer_cancel(&chand_->max_age_grace_timer); + chand_->max_age_grace_timer_pending = false; + } + } + /* If there are no active calls, this increasement will cancel + max_idle_timer, and prevent max_idle_timer from being started in the + future. */ + increase_call_count(chand_); + if (gpr_atm_acq_load(&chand_->idle_state) == + MAX_IDLE_STATE_SEEN_EXIT_IDLE) { + grpc_timer_cancel(&chand_->max_idle_timer); + } + } + + channel_data* chand_; +}; + +} // namespace grpc_core + +static void start_max_age_timer_after_init(void* arg, + grpc_error_handle /*error*/) { + channel_data* chand = static_cast(arg); + { + grpc_core::MutexLock lock(&chand->max_age_timer_mu); + chand->max_age_timer_pending = true; + GRPC_CHANNEL_STACK_REF(chand->channel_stack, "max_age max_age_timer"); + grpc_timer_init( + &chand->max_age_timer, + grpc_core::ExecCtx::Get()->Now() + chand->max_connection_age, + &chand->close_max_age_channel); + } + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->start_connectivity_watch.reset(new grpc_core::ConnectivityWatcher(chand)); + op->start_connectivity_watch_state = GRPC_CHANNEL_IDLE; + grpc_channel_next_op(grpc_channel_stack_element(chand->channel_stack, 0), op); + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack, + "max_age start_max_age_timer_after_init"); +} + +static void start_max_age_grace_timer_after_goaway_op( + void* arg, grpc_error_handle /*error*/) { + channel_data* chand = static_cast(arg); + { + grpc_core::MutexLock lock(&chand->max_age_timer_mu); + chand->max_age_grace_timer_pending = true; + GRPC_CHANNEL_STACK_REF(chand->channel_stack, "max_age max_age_grace_timer"); + grpc_timer_init(&chand->max_age_grace_timer, + chand->max_connection_age_grace == GRPC_MILLIS_INF_FUTURE + ? 
GRPC_MILLIS_INF_FUTURE + : grpc_core::ExecCtx::Get()->Now() + + chand->max_connection_age_grace, + &chand->force_close_max_age_channel); + } + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack, + "max_age start_max_age_grace_timer_after_goaway_op"); +} + +static void close_max_idle_channel(channel_data* chand) { + /* Prevent the max idle timer from being set again */ + gpr_atm_no_barrier_fetch_add(&chand->call_count, 1); + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->goaway_error = + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("max_idle"), + GRPC_ERROR_INT_HTTP2_ERROR, GRPC_HTTP2_NO_ERROR); + grpc_channel_element* elem = + grpc_channel_stack_element(chand->channel_stack, 0); + elem->filter->start_transport_op(elem, op); +} + +static void max_idle_timer_cb(void* arg, grpc_error_handle error) { + channel_data* chand = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + bool try_again = true; + while (try_again) { + gpr_atm idle_state = gpr_atm_acq_load(&chand->idle_state); + switch (idle_state) { + case MAX_IDLE_STATE_TIMER_SET: + close_max_idle_channel(chand); + /* This MAX_IDLE_STATE_INIT is a final state, we don't have to check + * if idle_state has been changed */ + gpr_atm_rel_store(&chand->idle_state, MAX_IDLE_STATE_INIT); + try_again = false; + break; + case MAX_IDLE_STATE_SEEN_EXIT_IDLE: + if (gpr_atm_rel_cas(&chand->idle_state, MAX_IDLE_STATE_SEEN_EXIT_IDLE, + MAX_IDLE_STATE_INIT)) { + try_again = false; + } + break; + case MAX_IDLE_STATE_SEEN_ENTER_IDLE: + GRPC_CHANNEL_STACK_REF(chand->channel_stack, + "max_age max_idle_timer"); + grpc_timer_init(&chand->max_idle_timer, + static_cast(gpr_atm_no_barrier_load( + &chand->last_enter_idle_time_millis)) + + chand->max_connection_idle, + &chand->max_idle_timer_cb); + /* idle_state may have already been set to + MAX_IDLE_STATE_SEEN_EXIT_IDLE by increase_call_count(), in this + case, we don't need to set it to MAX_IDLE_STATE_TIMER_SET */ + gpr_atm_rel_cas(&chand->idle_state, MAX_IDLE_STATE_SEEN_ENTER_IDLE, + MAX_IDLE_STATE_TIMER_SET); + try_again = false; + break; + default: + /* try again */ + break; + } + } + } + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack, "max_age max_idle_timer"); +} + +static void close_max_age_channel(void* arg, grpc_error_handle error) { + channel_data* chand = static_cast(arg); + { + grpc_core::MutexLock lock(&chand->max_age_timer_mu); + chand->max_age_timer_pending = false; + } + if (error == GRPC_ERROR_NONE) { + GRPC_CHANNEL_STACK_REF(chand->channel_stack, + "max_age start_max_age_grace_timer_after_goaway_op"); + grpc_transport_op* op = grpc_make_transport_op( + &chand->start_max_age_grace_timer_after_goaway_op); + op->goaway_error = + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("max_age"), + GRPC_ERROR_INT_HTTP2_ERROR, GRPC_HTTP2_NO_ERROR); + grpc_channel_element* elem = + grpc_channel_stack_element(chand->channel_stack, 0); + elem->filter->start_transport_op(elem, op); + } else if (error != GRPC_ERROR_CANCELLED) { + GRPC_LOG_IF_ERROR("close_max_age_channel", error); + } + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack, "max_age max_age_timer"); +} + +static void force_close_max_age_channel(void* arg, grpc_error_handle error) { + channel_data* chand = static_cast(arg); + { + grpc_core::MutexLock lock(&chand->max_age_timer_mu); + chand->max_age_grace_timer_pending = false; + } + if (error == GRPC_ERROR_NONE) { + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->disconnect_with_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Channel reaches max age"); + 
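// ---- Illustrative aside (not part of this patch) ---------------------------
// The idle-state handling above is a lock-free state machine: each actor
// loads the state, decides on a transition, and retries when a concurrent
// compare-and-swap shows the state moved underneath it. The same shape with
// std::atomic (states and names below are invented and simplified):
#include <atomic>
enum ToyIdleState : int { kInit = 0, kTimerSet = 1, kSeenExitIdle = 2 };
inline void ToyOnCallStarted(std::atomic<int>& state) {
  while (true) {
    int seen = state.load(std::memory_order_acquire);
    if (seen != kTimerSet) return;  // nothing to invalidate in this sketch
    int expected = kTimerSet;
    if (state.compare_exchange_weak(expected, kSeenExitIdle,
                                    std::memory_order_release,
                                    std::memory_order_relaxed)) {
      return;  // the armed timer is now marked stale
    }
    // CAS failed: someone else changed the state first; loop and re-check.
  }
}
// -----------------------------------------------------------------------------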
grpc_channel_element* elem = + grpc_channel_stack_element(chand->channel_stack, 0); + elem->filter->start_transport_op(elem, op); + } else if (error != GRPC_ERROR_CANCELLED) { + GRPC_LOG_IF_ERROR("force_close_max_age_channel", error); + } + GRPC_CHANNEL_STACK_UNREF(chand->channel_stack, "max_age max_age_grace_timer"); +} + +/* A random jitter of +/-10% will be added to MAX_CONNECTION_AGE to spread out + connection storms. Note that the MAX_CONNECTION_AGE option without jitter + would not create connection storms by itself, but if there happened to be a + connection storm it could cause it to repeat at a fixed period. */ +static grpc_millis +add_random_max_connection_age_jitter_and_convert_to_grpc_millis(int value) { + /* generate a random number between 1 - MAX_CONNECTION_AGE_JITTER and + 1 + MAX_CONNECTION_AGE_JITTER */ + double multiplier = rand() * MAX_CONNECTION_AGE_JITTER * 2.0 / RAND_MAX + + 1.0 - MAX_CONNECTION_AGE_JITTER; + double result = multiplier * value; + /* INT_MAX - 0.5 converts the value to float, so that result will not be + cast to int implicitly before the comparison. */ + return result > (static_cast(GRPC_MILLIS_INF_FUTURE)) - 0.5 + ? GRPC_MILLIS_INF_FUTURE + : static_cast(result); +} + +/* Constructor for call_data. */ +static grpc_error_handle max_age_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* /*args*/) { + channel_data* chand = static_cast(elem->channel_data); + increase_call_count(chand); + return GRPC_ERROR_NONE; +} + +/* Destructor for call_data. */ +static void max_age_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + channel_data* chand = static_cast(elem->channel_data); + decrease_call_count(chand); +} + +/* Constructor for channel_data. */ +static grpc_error_handle max_age_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + channel_data* chand = static_cast(elem->channel_data); + new (chand) channel_data(); + chand->channel_stack = args->channel_stack; + chand->max_connection_age = + add_random_max_connection_age_jitter_and_convert_to_grpc_millis( + DEFAULT_MAX_CONNECTION_AGE_MS); + chand->max_connection_age_grace = + DEFAULT_MAX_CONNECTION_AGE_GRACE_MS == INT_MAX + ? GRPC_MILLIS_INF_FUTURE + : DEFAULT_MAX_CONNECTION_AGE_GRACE_MS; + chand->max_connection_idle = DEFAULT_MAX_CONNECTION_IDLE_MS == INT_MAX + ? GRPC_MILLIS_INF_FUTURE + : DEFAULT_MAX_CONNECTION_IDLE_MS; + chand->idle_state = MAX_IDLE_STATE_INIT; + gpr_atm_no_barrier_store(&chand->last_enter_idle_time_millis, GPR_ATM_MIN); + for (size_t i = 0; i < args->channel_args->num_args; ++i) { + if (0 == strcmp(args->channel_args->args[i].key, + GRPC_ARG_MAX_CONNECTION_AGE_MS)) { + const int value = grpc_channel_arg_get_integer( + &args->channel_args->args[i], MAX_CONNECTION_AGE_INTEGER_OPTIONS); + chand->max_connection_age = + add_random_max_connection_age_jitter_and_convert_to_grpc_millis( + value); + } else if (0 == strcmp(args->channel_args->args[i].key, + GRPC_ARG_MAX_CONNECTION_AGE_GRACE_MS)) { + const int value = grpc_channel_arg_get_integer( + &args->channel_args->args[i], + {DEFAULT_MAX_CONNECTION_AGE_GRACE_MS, 0, INT_MAX}); + chand->max_connection_age_grace = + value == INT_MAX ? 
GRPC_MILLIS_INF_FUTURE : value; + } else if (0 == strcmp(args->channel_args->args[i].key, + GRPC_ARG_MAX_CONNECTION_IDLE_MS)) { + const int value = grpc_channel_arg_get_integer( + &args->channel_args->args[i], MAX_CONNECTION_IDLE_INTEGER_OPTIONS); + chand->max_connection_idle = + value == INT_MAX ? GRPC_MILLIS_INF_FUTURE : value; + } + } + GRPC_CLOSURE_INIT(&chand->max_idle_timer_cb, max_idle_timer_cb, chand, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&chand->close_max_age_channel, close_max_age_channel, chand, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&chand->force_close_max_age_channel, + force_close_max_age_channel, chand, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&chand->start_max_idle_timer_after_init, + start_max_idle_timer_after_init, chand, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&chand->start_max_age_timer_after_init, + start_max_age_timer_after_init, chand, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&chand->start_max_age_grace_timer_after_goaway_op, + start_max_age_grace_timer_after_goaway_op, chand, + grpc_schedule_on_exec_ctx); + + if (chand->max_connection_age != GRPC_MILLIS_INF_FUTURE) { + /* When the channel reaches its max age, we send down an op with + goaway_error set. However, we can't send down any ops until after the + channel stack is fully initialized. If we start the timer here, we have + no guarantee that the timer won't pop before channel stack initialization + is finished. To avoid that problem, we create a closure to start the + timer, and we schedule that closure to be run after call stack + initialization is done. */ + GRPC_CHANNEL_STACK_REF(chand->channel_stack, + "max_age start_max_age_timer_after_init"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + &chand->start_max_age_timer_after_init, + GRPC_ERROR_NONE); + } + + /* Initialize the number of calls as 1, so that the max_idle_timer will not + start until start_max_idle_timer_after_init is invoked. */ + gpr_atm_rel_store(&chand->call_count, 1); + if (chand->max_connection_idle != GRPC_MILLIS_INF_FUTURE) { + GRPC_CHANNEL_STACK_REF(chand->channel_stack, + "max_age start_max_idle_timer_after_init"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + &chand->start_max_idle_timer_after_init, + GRPC_ERROR_NONE); + } + return GRPC_ERROR_NONE; +} + +/* Destructor for channel_data. 
*/ +static void max_age_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* chand = static_cast(elem->channel_data); + chand->~channel_data(); +} + +const grpc_channel_filter grpc_max_age_filter = { + grpc_call_next_op, + grpc_channel_next_op, + 0, /* sizeof_call_data */ + max_age_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + max_age_destroy_call_elem, + sizeof(channel_data), + max_age_init_channel_elem, + max_age_destroy_channel_elem, + grpc_channel_next_get_info, + "max_age"}; + +namespace grpc_core { +void RegisterMaxAgeFilter(CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage( + GRPC_SERVER_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [](grpc_channel_stack_builder* builder) { + const grpc_channel_args* channel_args = + grpc_channel_stack_builder_get_channel_arguments(builder); + bool enable = grpc_channel_arg_get_integer( + grpc_channel_args_find( + channel_args, GRPC_ARG_MAX_CONNECTION_AGE_MS), + MAX_CONNECTION_AGE_INTEGER_OPTIONS) != INT_MAX || + grpc_channel_arg_get_integer( + grpc_channel_args_find( + channel_args, GRPC_ARG_MAX_CONNECTION_IDLE_MS), + MAX_CONNECTION_IDLE_INTEGER_OPTIONS) != INT_MAX; + if (enable) { + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_max_age_filter, nullptr, nullptr); + } else { + return true; + } + }); +} +} // namespace grpc_core diff --git a/src/core/ext/filters/message_size/message_size_filter.cc b/src/core/ext/filters/message_size/message_size_filter.cc new file mode 100644 index 00000000..edb256c3 --- /dev/null +++ b/src/core/ext/filters/message_size/message_size_filter.cc @@ -0,0 +1,402 @@ +// +// Copyright 2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/ext/filters/message_size/message_size_filter.h" + +#include +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/ext/service_config/service_config_call_data.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/surface/call.h" + +static void recv_message_ready(void* user_data, grpc_error_handle error); +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle error); + +namespace grpc_core { + +namespace { +size_t g_message_size_parser_index; +} // namespace + +// +// MessageSizeParsedConfig +// + +const MessageSizeParsedConfig* MessageSizeParsedConfig::GetFromCallContext( + const grpc_call_context_element* context) { + if (context == nullptr) return nullptr; + auto* svc_cfg_call_data = static_cast( + context[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value); + if (svc_cfg_call_data == nullptr) return nullptr; + return static_cast( + svc_cfg_call_data->GetMethodParsedConfig( + MessageSizeParser::ParserIndex())); +} + +// +// MessageSizeParser +// + +std::unique_ptr +MessageSizeParser::ParsePerMethodParams(const grpc_channel_args* /*args*/, + const Json& json, + grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr && *error == GRPC_ERROR_NONE); + std::vector error_list; + // Max request size. + int max_request_message_bytes = -1; + auto it = json.object_value().find("maxRequestMessageBytes"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::STRING && + it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxRequestMessageBytes error:should be of type number")); + } else { + max_request_message_bytes = + gpr_parse_nonnegative_int(it->second.string_value().c_str()); + if (max_request_message_bytes == -1) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxRequestMessageBytes error:should be non-negative")); + } + } + } + // Max response size. 
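// ---- Illustrative aside (not part of this patch) ---------------------------
// The field handling above (and the mirror-image handling of the response
// limit just below) follows one pattern: accept the value, parse it as a
// non-negative integer, and push a descriptive error instead of failing fast.
// A stand-alone sketch of that pattern with invented names:
#include <cstdlib>
#include <string>
#include <vector>
inline int ParseNonNegativeInt(const std::string& s) {
  if (s.empty()) return -1;
  char* end = nullptr;
  long v = std::strtol(s.c_str(), &end, 10);
  if (end == nullptr || *end != '\0' || v < 0) return -1;  // reject junk
  return static_cast<int>(v);
}
inline int ParseLimitField(const std::string& field, const std::string& value,
                           std::vector<std::string>* errors) {
  int parsed = ParseNonNegativeInt(value);
  if (parsed == -1) {
    errors->push_back("field:" + field + " error:should be non-negative");
  }
  return parsed;
}
// -----------------------------------------------------------------------------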
+ int max_response_message_bytes = -1; + it = json.object_value().find("maxResponseMessageBytes"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::STRING && + it->second.type() != Json::Type::NUMBER) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxResponseMessageBytes error:should be of type number")); + } else { + max_response_message_bytes = + gpr_parse_nonnegative_int(it->second.string_value().c_str()); + if (max_response_message_bytes == -1) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:maxResponseMessageBytes error:should be non-negative")); + } + } + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Message size parser", &error_list); + return nullptr; + } + return absl::make_unique(max_request_message_bytes, + max_response_message_bytes); +} + +void MessageSizeParser::Register() { + g_message_size_parser_index = ServiceConfigParser::RegisterParser( + absl::make_unique()); +} + +size_t MessageSizeParser::ParserIndex() { return g_message_size_parser_index; } + +int GetMaxRecvSizeFromChannelArgs(const grpc_channel_args* args) { + if (grpc_channel_args_want_minimal_stack(args)) return -1; + return grpc_channel_args_find_integer( + args, GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, + {GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH, -1, INT_MAX}); +} + +int GetMaxSendSizeFromChannelArgs(const grpc_channel_args* args) { + if (grpc_channel_args_want_minimal_stack(args)) return -1; + return grpc_channel_args_find_integer( + args, GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, + {GRPC_DEFAULT_MAX_SEND_MESSAGE_LENGTH, -1, INT_MAX}); +} + +} // namespace grpc_core + +namespace { +struct channel_data { + grpc_core::MessageSizeParsedConfig::message_size_limits limits; +}; + +struct call_data { + call_data(grpc_call_element* elem, const channel_data& chand, + const grpc_call_element_args& args) + : call_combiner(args.call_combiner), limits(chand.limits) { + GRPC_CLOSURE_INIT(&recv_message_ready, ::recv_message_ready, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, + ::recv_trailing_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + // Get max sizes from channel data, then merge in per-method config values. + // Note: Per-method config is only available on the client, so we + // apply the max request size to the send limit and the max response + // size to the receive limit. + const grpc_core::MessageSizeParsedConfig* limits = + grpc_core::MessageSizeParsedConfig::GetFromCallContext(args.context); + if (limits != nullptr) { + if (limits->limits().max_send_size >= 0 && + (limits->limits().max_send_size < this->limits.max_send_size || + this->limits.max_send_size < 0)) { + this->limits.max_send_size = limits->limits().max_send_size; + } + if (limits->limits().max_recv_size >= 0 && + (limits->limits().max_recv_size < this->limits.max_recv_size || + this->limits.max_recv_size < 0)) { + this->limits.max_recv_size = limits->limits().max_recv_size; + } + } + } + + ~call_data() { GRPC_ERROR_UNREF(error); } + + grpc_core::CallCombiner* call_combiner; + grpc_core::MessageSizeParsedConfig::message_size_limits limits; + // Receive closures are chained: we inject this closure as the + // recv_message_ready up-call on transport_stream_op, and remember to + // call our next_recv_message_ready member after handling it. 
+ grpc_closure recv_message_ready; + grpc_closure recv_trailing_metadata_ready; + // The error caused by a message that is too large, or GRPC_ERROR_NONE + grpc_error_handle error = GRPC_ERROR_NONE; + // Used by recv_message_ready. + grpc_core::OrphanablePtr* recv_message = nullptr; + // Original recv_message_ready callback, invoked after our own. + grpc_closure* next_recv_message_ready = nullptr; + // Original recv_trailing_metadata callback, invoked after our own. + grpc_closure* original_recv_trailing_metadata_ready; + bool seen_recv_trailing_metadata = false; + grpc_error_handle recv_trailing_metadata_error; +}; + +} // namespace + +// Callback invoked when we receive a message. Here we check the max +// receive message size. +static void recv_message_ready(void* user_data, grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (*calld->recv_message != nullptr && calld->limits.max_recv_size >= 0 && + (*calld->recv_message)->length() > + static_cast(calld->limits.max_recv_size)) { + grpc_error_handle new_error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Received message larger than max (%u vs. %d)", + (*calld->recv_message)->length(), calld->limits.max_recv_size)), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_RESOURCE_EXHAUSTED); + error = grpc_error_add_child(GRPC_ERROR_REF(error), new_error); + GRPC_ERROR_UNREF(calld->error); + calld->error = GRPC_ERROR_REF(error); + } else { + (void)GRPC_ERROR_REF(error); + } + // Invoke the next callback. + grpc_closure* closure = calld->next_recv_message_ready; + calld->next_recv_message_ready = nullptr; + if (calld->seen_recv_trailing_metadata) { + /* We might potentially see another RECV_MESSAGE op. In that case, we do not + * want to run the recv_trailing_metadata_ready closure again. The newer + * RECV_MESSAGE op cannot cause any errors since the transport has already + * invoked the recv_trailing_metadata_ready closure and all further + * RECV_MESSAGE ops will get null payloads. */ + calld->seen_recv_trailing_metadata = false; + GRPC_CALL_COMBINER_START(calld->call_combiner, + &calld->recv_trailing_metadata_ready, + calld->recv_trailing_metadata_error, + "continue recv_trailing_metadata_ready"); + } + grpc_core::Closure::Run(DEBUG_LOCATION, closure, error); +} + +// Callback invoked on completion of recv_trailing_metadata +// Notifies the recv_trailing_metadata batch of any message size failures +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (calld->next_recv_message_ready != nullptr) { + calld->seen_recv_trailing_metadata = true; + calld->recv_trailing_metadata_error = GRPC_ERROR_REF(error); + GRPC_CALL_COMBINER_STOP(calld->call_combiner, + "deferring recv_trailing_metadata_ready until " + "after recv_message_ready"); + return; + } + error = + grpc_error_add_child(GRPC_ERROR_REF(error), GRPC_ERROR_REF(calld->error)); + // Invoke the next callback. + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_trailing_metadata_ready, error); +} + +// Start transport stream op. +static void message_size_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + call_data* calld = static_cast(elem->call_data); + // Check max send message size. 
+ if (op->send_message && calld->limits.max_send_size >= 0 && + op->payload->send_message.send_message->length() > + static_cast(calld->limits.max_send_size)) { + grpc_transport_stream_op_batch_finish_with_failure( + op, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Sent message larger than max (%u vs. %d)", + op->payload->send_message.send_message->length(), + calld->limits.max_send_size)), + GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_RESOURCE_EXHAUSTED), + calld->call_combiner); + return; + } + // Inject callback for receiving a message. + if (op->recv_message) { + calld->next_recv_message_ready = + op->payload->recv_message.recv_message_ready; + calld->recv_message = op->payload->recv_message.recv_message; + op->payload->recv_message.recv_message_ready = &calld->recv_message_ready; + } + // Inject callback for receiving trailing metadata. + if (op->recv_trailing_metadata) { + calld->original_recv_trailing_metadata_ready = + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready; + } + // Chain to the next filter. + grpc_call_next_op(elem, op); +} + +// Constructor for call_data. +static grpc_error_handle message_size_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + channel_data* chand = static_cast(elem->channel_data); + new (elem->call_data) call_data(elem, *chand, *args); + return GRPC_ERROR_NONE; +} + +// Destructor for call_data. +static void message_size_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + call_data* calld = static_cast(elem->call_data); + calld->~call_data(); +} + +grpc_core::MessageSizeParsedConfig::message_size_limits get_message_size_limits( + const grpc_channel_args* channel_args) { + grpc_core::MessageSizeParsedConfig::message_size_limits lim; + lim.max_send_size = grpc_core::GetMaxSendSizeFromChannelArgs(channel_args); + lim.max_recv_size = grpc_core::GetMaxRecvSizeFromChannelArgs(channel_args); + return lim; +} + +// Constructor for channel_data. +static grpc_error_handle message_size_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + GPR_ASSERT(!args->is_last); + channel_data* chand = static_cast(elem->channel_data); + new (chand) channel_data(); + chand->limits = get_message_size_limits(args->channel_args); + return GRPC_ERROR_NONE; +} + +// Destructor for channel_data. 
+static void message_size_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* chand = static_cast(elem->channel_data); + chand->~channel_data(); +} + +const grpc_channel_filter grpc_message_size_filter = { + message_size_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + message_size_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + message_size_destroy_call_elem, + sizeof(channel_data), + message_size_init_channel_elem, + message_size_destroy_channel_elem, + grpc_channel_next_get_info, + "message_size"}; + +// Used for GRPC_CLIENT_SUBCHANNEL +static bool maybe_add_message_size_filter_subchannel( + grpc_channel_stack_builder* builder) { + const grpc_channel_args* channel_args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (grpc_channel_args_want_minimal_stack(channel_args)) { + return true; + } + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_message_size_filter, nullptr, nullptr); +} + +// Used for GRPC_CLIENT_DIRECT_CHANNEL and GRPC_SERVER_CHANNEL. Adds the filter +// only if message size limits or service config is specified. +static bool maybe_add_message_size_filter(grpc_channel_stack_builder* builder) { + const grpc_channel_args* channel_args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (grpc_channel_args_want_minimal_stack(channel_args)) { + return true; + } + bool enable = false; + grpc_core::MessageSizeParsedConfig::message_size_limits lim = + get_message_size_limits(channel_args); + if (lim.max_send_size != -1 || lim.max_recv_size != -1) { + enable = true; + } + const grpc_arg* a = + grpc_channel_args_find(channel_args, GRPC_ARG_SERVICE_CONFIG); + const char* svc_cfg_str = grpc_channel_arg_get_string(a); + if (svc_cfg_str != nullptr) { + enable = true; + } + if (enable) { + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_message_size_filter, nullptr, nullptr); + } else { + return true; + } +} + +void grpc_message_size_filter_init(void) { + grpc_core::MessageSizeParser::Register(); +} + +void grpc_message_size_filter_shutdown(void) {} + +namespace grpc_core { +void RegisterMessageSizeFilter(CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage( + GRPC_CLIENT_SUBCHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + maybe_add_message_size_filter_subchannel); + builder->channel_init()->RegisterStage(GRPC_CLIENT_DIRECT_CHANNEL, + GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + maybe_add_message_size_filter); + builder->channel_init()->RegisterStage(GRPC_SERVER_CHANNEL, + GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + maybe_add_message_size_filter); +} +} // namespace grpc_core diff --git a/src/core/ext/filters/server_config_selector/server_config_selector.cc b/src/core/ext/filters/server_config_selector/server_config_selector.cc new file mode 100644 index 00000000..85df64e4 --- /dev/null +++ b/src/core/ext/filters/server_config_selector/server_config_selector.cc @@ -0,0 +1,67 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
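The message_size filter that ends above draws its limits from two sources: the GRPC_ARG_MAX_SEND_MESSAGE_LENGTH / GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH channel args and the per-method maxRequestMessageBytes / maxResponseMessageBytes fields parsed by MessageSizeParser, with the call_data constructor keeping the tighter of the two. A hedged client-side sketch that exercises both paths; "example.Echo", the target address, and all byte counts are placeholders:

    #include <grpcpp/grpcpp.h>

    std::shared_ptr<grpc::Channel> MakeSizeLimitedChannel() {
      grpc::ChannelArguments args;
      // Channel-arg limits, read by GetMaxRecvSizeFromChannelArgs /
      // GetMaxSendSizeFromChannelArgs above (-1 would mean unlimited).
      args.SetMaxReceiveMessageSize(4 * 1024 * 1024);
      args.SetMaxSendMessageSize(1 * 1024 * 1024);
      // Per-method limits via service config; setting GRPC_ARG_SERVICE_CONFIG
      // also keeps the filter in the stack in maybe_add_message_size_filter.
      args.SetServiceConfigJSON(R"({
        "methodConfig": [{
          "name": [{"service": "example.Echo"}],
          "maxRequestMessageBytes": "65536",
          "maxResponseMessageBytes": "65536"
        }]
      })");
      return grpc::CreateCustomChannel("localhost:50051",
                                       grpc::InsecureChannelCredentials(), args);
    }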
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/filters/server_config_selector/server_config_selector.h" + +#include "src/core/lib/channel/channel_args.h" + +namespace grpc_core { +namespace { + +void* ServerConfigSelectorProviderArgCopy(void* p) { + ServerConfigSelectorProvider* arg = + static_cast(p); + return arg->Ref().release(); +} + +void ServerConfigSelectorProviderArgDestroy(void* p) { + ServerConfigSelectorProvider* arg = + static_cast(p); + arg->Unref(); +} + +int ServerConfigSelectorProviderArgCmp(void* p, void* q) { + return QsortCompare(p, q); +} + +const grpc_arg_pointer_vtable kChannelArgVtable = { + ServerConfigSelectorProviderArgCopy, ServerConfigSelectorProviderArgDestroy, + ServerConfigSelectorProviderArgCmp}; + +const char* kServerConfigSelectorProviderChannelArgName = + "grpc.internal.server_config_selector_provider"; + +} // namespace + +grpc_arg ServerConfigSelectorProvider::MakeChannelArg() const { + return grpc_channel_arg_pointer_create( + const_cast(kServerConfigSelectorProviderChannelArgName), + const_cast(this), &kChannelArgVtable); +} + +RefCountedPtr +ServerConfigSelectorProvider::GetFromChannelArgs( + const grpc_channel_args& args) { + ServerConfigSelectorProvider* config_selector_provider = + grpc_channel_args_find_pointer( + &args, kServerConfigSelectorProviderChannelArgName); + return config_selector_provider != nullptr ? config_selector_provider->Ref() + : nullptr; +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/server_config_selector/server_config_selector_filter.cc b/src/core/ext/filters/server_config_selector/server_config_selector_filter.cc new file mode 100644 index 00000000..0f54ad52 --- /dev/null +++ b/src/core/ext/filters/server_config_selector/server_config_selector_filter.cc @@ -0,0 +1,265 @@ +// +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
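server_config_selector.cc above is only the plumbing that lets a ServerConfigSelectorProvider ride through channel args: MakeChannelArg() wraps the provider in a pointer arg whose vtable refs and unrefs it, and GetFromChannelArgs() takes a fresh ref back out. The same pattern works for any ref-counted object; a self-contained sketch with a stand-in type (MyConfig and the "grpc.internal.my_config" arg name are hypothetical):

    #include <grpc/grpc.h>

    #include "src/core/lib/channel/channel_args.h"
    #include "src/core/lib/gprpp/ref_counted.h"
    #include "src/core/lib/gprpp/ref_counted_ptr.h"

    namespace {
    struct MyConfig : public grpc_core::RefCounted<MyConfig> {
      int value = 0;
    };
    void* MyConfigCopy(void* p) { return static_cast<MyConfig*>(p)->Ref().release(); }
    void MyConfigDestroy(void* p) { static_cast<MyConfig*>(p)->Unref(); }
    int MyConfigCmp(void* p, void* q) { return p == q ? 0 : (p < q ? -1 : 1); }
    const grpc_arg_pointer_vtable kMyConfigVtable = {MyConfigCopy, MyConfigDestroy,
                                                     MyConfigCmp};
    }  // namespace

    grpc_channel_args* AttachMyConfig(const grpc_channel_args* base) {
      auto cfg = grpc_core::MakeRefCounted<MyConfig>();
      grpc_arg arg = grpc_channel_arg_pointer_create(
          const_cast<char*>("grpc.internal.my_config"), cfg.get(), &kMyConfigVtable);
      // Copying the args takes its own ref through the vtable; cfg's local ref
      // is dropped when it goes out of scope.
      return grpc_channel_args_copy_and_add(base, &arg, 1);
    }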
+// +// + +#include + +#include "src/core/ext/filters/server_config_selector/server_config_selector_filter.h" + +#include "src/core/ext/filters/server_config_selector/server_config_selector.h" +#include "src/core/ext/service_config/service_config_call_data.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_core { + +namespace { + +class ChannelData { + public: + static grpc_error_handle Init(grpc_channel_element* elem, + grpc_channel_element_args* args); + static void Destroy(grpc_channel_element* elem); + + absl::StatusOr> config_selector() { + MutexLock lock(&mu_); + return config_selector_; + } + + private: + class ServerConfigSelectorWatcher + : public ServerConfigSelectorProvider::ServerConfigSelectorWatcher { + public: + explicit ServerConfigSelectorWatcher(ChannelData* chand) : chand_(chand) {} + void OnServerConfigSelectorUpdate( + absl::StatusOr> update) override { + MutexLock lock(&chand_->mu_); + chand_->config_selector_ = std::move(update); + } + + private: + ChannelData* chand_; + }; + + explicit ChannelData(RefCountedPtr + server_config_selector_provider); + ~ChannelData(); + + RefCountedPtr server_config_selector_provider_; + Mutex mu_; + absl::StatusOr> config_selector_ + ABSL_GUARDED_BY(mu_); +}; + +class CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args); + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* /* final_info */, + grpc_closure* /* then_schedule_closure */); + static void StartTransportStreamOpBatch(grpc_call_element* elem, + grpc_transport_stream_op_batch* op); + + private: + CallData(grpc_call_element* elem, const grpc_call_element_args& args); + ~CallData(); + static void RecvInitialMetadataReady(void* user_data, + grpc_error_handle error); + static void RecvTrailingMetadataReady(void* user_data, + grpc_error_handle error); + void MaybeResumeRecvTrailingMetadataReady(); + + grpc_call_context_element* call_context_; + grpc_core::CallCombiner* call_combiner_; + ServiceConfigCallData service_config_call_data_; + // Overall error for the call + grpc_error_handle error_ = GRPC_ERROR_NONE; + // State for keeping track of recv_initial_metadata + grpc_metadata_batch* recv_initial_metadata_ = nullptr; + grpc_closure* original_recv_initial_metadata_ready_ = nullptr; + grpc_closure recv_initial_metadata_ready_; + // State for keeping of track of recv_trailing_metadata + grpc_closure* original_recv_trailing_metadata_ready_; + grpc_closure recv_trailing_metadata_ready_; + grpc_error_handle recv_trailing_metadata_ready_error_; + bool seen_recv_trailing_metadata_ready_ = false; +}; + +// ChannelData + +grpc_error_handle ChannelData::Init(grpc_channel_element* elem, + grpc_channel_element_args* args) { + GPR_ASSERT(elem->filter = &kServerConfigSelectorFilter); + RefCountedPtr server_config_selector_provider = + ServerConfigSelectorProvider::GetFromChannelArgs(*args->channel_args); + if (server_config_selector_provider == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No ServerConfigSelectorProvider object found"); + } + new (elem->channel_data) + ChannelData(std::move(server_config_selector_provider)); + return GRPC_ERROR_NONE; +} + +void ChannelData::Destroy(grpc_channel_element* elem) { + auto* chand = static_cast(elem->channel_data); + chand->~ChannelData(); +} + +ChannelData::ChannelData( + RefCountedPtr server_config_selector_provider) + : server_config_selector_provider_( + std::move(server_config_selector_provider)) { + 
GPR_ASSERT(server_config_selector_provider_ != nullptr); + auto server_config_selector_watcher = + absl::make_unique(this); + config_selector_ = server_config_selector_provider_->Watch( + std::move(server_config_selector_watcher)); +} + +ChannelData::~ChannelData() { server_config_selector_provider_->CancelWatch(); } + +// CallData + +grpc_error_handle CallData::Init(grpc_call_element* elem, + const grpc_call_element_args* args) { + new (elem->call_data) CallData(elem, *args); + return GRPC_ERROR_NONE; +} + +void CallData::Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*then_schedule_closure*/) { + auto* calld = static_cast(elem->call_data); + calld->~CallData(); +} + +void CallData::StartTransportStreamOpBatch(grpc_call_element* elem, + grpc_transport_stream_op_batch* op) { + CallData* calld = static_cast(elem->call_data); + if (op->recv_initial_metadata) { + calld->recv_initial_metadata_ = + op->payload->recv_initial_metadata.recv_initial_metadata; + calld->original_recv_initial_metadata_ready_ = + op->payload->recv_initial_metadata.recv_initial_metadata_ready; + op->payload->recv_initial_metadata.recv_initial_metadata_ready = + &calld->recv_initial_metadata_ready_; + } + if (op->recv_trailing_metadata) { + // We might generate errors on receiving initial metadata which we need to + // bubble up through recv_trailing_metadata_ready + calld->original_recv_trailing_metadata_ready_ = + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready_; + } + // Chain to the next filter. + grpc_call_next_op(elem, op); +} + +CallData::CallData(grpc_call_element* elem, const grpc_call_element_args& args) + : call_context_(args.context), call_combiner_(args.call_combiner) { + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_, RecvInitialMetadataReady, + elem, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReady, + elem, grpc_schedule_on_exec_ctx); +} + +CallData::~CallData() { + // Remove the entry from call context, just in case anyone above us + // tries to look at it during call stack destruction. 
+ call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value = nullptr; + GRPC_ERROR_UNREF(error_); +} + +void CallData::RecvInitialMetadataReady(void* user_data, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + CallData* calld = static_cast(elem->call_data); + ChannelData* chand = static_cast(elem->channel_data); + if (error == GRPC_ERROR_NONE) { + auto config_selector = chand->config_selector(); + if (config_selector.ok()) { + auto call_config = + config_selector.value()->GetCallConfig(calld->recv_initial_metadata_); + if (call_config.error != GRPC_ERROR_NONE) { + calld->error_ = call_config.error; + error = call_config.error; // Does not take a ref + } else { + calld->service_config_call_data_ = + ServiceConfigCallData(std::move(call_config.service_config), + call_config.method_configs, {}); + calld->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value = + &calld->service_config_call_data_; + } + } else { + calld->error_ = absl_status_to_grpc_error(config_selector.status()); + error = calld->error_; + } + } + calld->MaybeResumeRecvTrailingMetadataReady(); + grpc_closure* closure = calld->original_recv_initial_metadata_ready_; + calld->original_recv_initial_metadata_ready_ = nullptr; + Closure::Run(DEBUG_LOCATION, closure, GRPC_ERROR_REF(error)); +} + +void CallData::RecvTrailingMetadataReady(void* user_data, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(user_data); + CallData* calld = static_cast(elem->call_data); + if (calld->original_recv_initial_metadata_ready_ != nullptr) { + calld->seen_recv_trailing_metadata_ready_ = true; + calld->recv_trailing_metadata_ready_error_ = GRPC_ERROR_REF(error); + GRPC_CALL_COMBINER_STOP(calld->call_combiner_, + "Deferring RecvTrailingMetadataReady until after " + "RecvInitialMetadataReady"); + return; + } + error = grpc_error_add_child(GRPC_ERROR_REF(error), calld->error_); + calld->error_ = GRPC_ERROR_NONE; + grpc_closure* closure = calld->original_recv_trailing_metadata_ready_; + calld->original_recv_trailing_metadata_ready_ = nullptr; + Closure::Run(DEBUG_LOCATION, closure, error); +} + +void CallData::MaybeResumeRecvTrailingMetadataReady() { + if (seen_recv_trailing_metadata_ready_) { + seen_recv_trailing_metadata_ready_ = false; + grpc_error_handle error = recv_trailing_metadata_ready_error_; + recv_trailing_metadata_ready_error_ = GRPC_ERROR_NONE; + GRPC_CALL_COMBINER_START(call_combiner_, &recv_trailing_metadata_ready_, + error, "Continuing RecvTrailingMetadataReady"); + } +} + +} // namespace + +const grpc_channel_filter kServerConfigSelectorFilter = { + CallData::StartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(CallData), + CallData::Init, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + CallData::Destroy, + sizeof(ChannelData), + ChannelData::Init, + ChannelData::Destroy, + grpc_channel_next_get_info, + "server_config_selector_filter", +}; + +} // namespace grpc_core diff --git a/src/core/ext/service_config/service_config.cc b/src/core/ext/service_config/service_config.cc new file mode 100644 index 00000000..643b175b --- /dev/null +++ b/src/core/ext/service_config/service_config.cc @@ -0,0 +1,227 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
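The filter above leaves all policy to whatever ServerConfigSelector the provider supplies: at recv_initial_metadata time it calls GetCallConfig() and, on success, stores the result in the call context as a ServiceConfigCallData. A hedged sketch of a trivial selector that always returns one pre-parsed config; it mirrors only what the filter code touches (GetCallConfig taking a grpc_metadata_batch* and a CallConfig carrying error / service_config / method_configs), so the exact base-class declaration in server_config_selector.h may differ in detail:

    #include "src/core/ext/filters/server_config_selector/server_config_selector.h"
    #include "src/core/ext/service_config/service_config.h"

    class FixedServerConfigSelector : public grpc_core::ServerConfigSelector {
     public:
      explicit FixedServerConfigSelector(
          grpc_core::RefCountedPtr<grpc_core::ServiceConfig> config)
          : config_(std::move(config)) {}

      CallConfig GetCallConfig(grpc_metadata_batch* /*metadata*/) override {
        CallConfig call_config;
        // A real selector would pick per-method configs based on the request
        // path from the metadata; this one hands back the whole parsed config.
        call_config.service_config = config_;
        call_config.method_configs = nullptr;
        return call_config;
      }

     private:
      grpc_core::RefCountedPtr<grpc_core::ServiceConfig> config_;
    };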
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/service_config/service_config.h" + +#include + +#include "absl/strings/str_cat.h" + +#include + +#include "src/core/ext/service_config/service_config_parser.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +RefCountedPtr ServiceConfig::Create( + const grpc_channel_args* args, absl::string_view json_string, + grpc_error_handle* error) { + GPR_DEBUG_ASSERT(error != nullptr); + Json json = Json::Parse(json_string, error); + if (*error != GRPC_ERROR_NONE) return nullptr; + return MakeRefCounted(args, std::string(json_string), + std::move(json), error); +} + +ServiceConfig::ServiceConfig(const grpc_channel_args* args, + std::string json_string, Json json, + grpc_error_handle* error) + : json_string_(std::move(json_string)), json_(std::move(json)) { + GPR_DEBUG_ASSERT(error != nullptr); + if (json_.type() != Json::Type::OBJECT) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("JSON value is not an object"); + return; + } + std::vector error_list; + grpc_error_handle global_error = GRPC_ERROR_NONE; + parsed_global_configs_ = + ServiceConfigParser::ParseGlobalParameters(args, json_, &global_error); + if (global_error != GRPC_ERROR_NONE) error_list.push_back(global_error); + grpc_error_handle local_error = ParsePerMethodParams(args); + if (local_error != GRPC_ERROR_NONE) error_list.push_back(local_error); + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Service config parsing error", + &error_list); + } +} + +ServiceConfig::~ServiceConfig() { + for (auto& p : parsed_method_configs_map_) { + grpc_slice_unref_internal(p.first); + } +} + +grpc_error_handle ServiceConfig::ParseJsonMethodConfig( + const grpc_channel_args* args, const Json& json) { + std::vector error_list; + // Parse method config with each registered parser. + auto parsed_configs = + absl::make_unique(); + grpc_error_handle parser_error = GRPC_ERROR_NONE; + *parsed_configs = + ServiceConfigParser::ParsePerMethodParameters(args, json, &parser_error); + if (parser_error != GRPC_ERROR_NONE) { + error_list.push_back(parser_error); + } + parsed_method_config_vectors_storage_.push_back(std::move(parsed_configs)); + const auto* vector_ptr = parsed_method_config_vectors_storage_.back().get(); + // Add an entry for each path. 
+ bool found_name = false; + auto it = json.object_value().find("name"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:name error:not of type Array")); + return GRPC_ERROR_CREATE_FROM_VECTOR("methodConfig", &error_list); + } + const Json::Array& name_array = it->second.array_value(); + for (const Json& name : name_array) { + grpc_error_handle parse_error = GRPC_ERROR_NONE; + std::string path = ParseJsonMethodName(name, &parse_error); + if (parse_error != GRPC_ERROR_NONE) { + error_list.push_back(parse_error); + } else { + found_name = true; + if (path.empty()) { + if (default_method_config_vector_ != nullptr) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:name error:multiple default method configs")); + } + default_method_config_vector_ = vector_ptr; + } else { + grpc_slice key = grpc_slice_from_copied_string(path.c_str()); + // If the key is not already present in the map, this will + // store a ref to the key in the map. + auto& value = parsed_method_configs_map_[key]; + if (value != nullptr) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:name error:multiple method configs with same name")); + // The map entry already existed, so we need to unref the + // key we just created. + grpc_slice_unref_internal(key); + } else { + value = vector_ptr; + } + } + } + } + } + if (!found_name) { + parsed_method_config_vectors_storage_.pop_back(); + } + return GRPC_ERROR_CREATE_FROM_VECTOR("methodConfig", &error_list); +} + +grpc_error_handle ServiceConfig::ParsePerMethodParams( + const grpc_channel_args* args) { + std::vector error_list; + auto it = json_.object_value().find("methodConfig"); + if (it != json_.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:methodConfig error:not of type Array")); + } + for (const Json& method_config : it->second.array_value()) { + if (method_config.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:methodConfig error:not of type Object")); + continue; + } + grpc_error_handle error = ParseJsonMethodConfig(args, method_config); + if (error != GRPC_ERROR_NONE) { + error_list.push_back(error); + } + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("Method Params", &error_list); +} + +std::string ServiceConfig::ParseJsonMethodName(const Json& json, + grpc_error_handle* error) { + if (json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:name error:type is not object"); + return ""; + } + // Find service name. + const std::string* service_name = nullptr; + auto it = json.object_value().find("service"); + if (it != json.object_value().end() && + it->second.type() != Json::Type::JSON_NULL) { + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:name error: field:service error:not of type string"); + return ""; + } + if (!it->second.string_value().empty()) { + service_name = &it->second.string_value(); + } + } + const std::string* method_name = nullptr; + // Find method name. 
+ it = json.object_value().find("method"); + if (it != json.object_value().end() && + it->second.type() != Json::Type::JSON_NULL) { + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:name error: field:method error:not of type string"); + return ""; + } + if (!it->second.string_value().empty()) { + method_name = &it->second.string_value(); + } + } + // If neither service nor method are specified, it's the default. + // Method name may not be specified without service name. + if (service_name == nullptr) { + if (method_name != nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:name error:method name populated without service name"); + } + return ""; + } + // Construct path. + return absl::StrCat("/", *service_name, "/", + method_name == nullptr ? "" : *method_name); +} + +const ServiceConfigParser::ParsedConfigVector* +ServiceConfig::GetMethodParsedConfigVector(const grpc_slice& path) const { + if (parsed_method_configs_map_.empty()) { + return default_method_config_vector_; + } + // Try looking up the full path in the map. + auto it = parsed_method_configs_map_.find(path); + if (it != parsed_method_configs_map_.end()) return it->second; + // If we didn't find a match for the path, try looking for a wildcard + // entry (i.e., change "/service/method" to "/service/"). + UniquePtr path_str(grpc_slice_to_c_string(path)); + char* sep = strrchr(path_str.get(), '/'); + if (sep == nullptr) return nullptr; // Shouldn't ever happen. + sep[1] = '\0'; + grpc_slice wildcard_path = grpc_slice_from_static_string(path_str.get()); + it = parsed_method_configs_map_.find(wildcard_path); + if (it != parsed_method_configs_map_.end()) return it->second; + // Try default method config, if set. + return default_method_config_vector_; +} + +} // namespace grpc_core diff --git a/src/core/ext/service_config/service_config_parser.cc b/src/core/ext/service_config/service_config_parser.cc new file mode 100644 index 00000000..f649fd77 --- /dev/null +++ b/src/core/ext/service_config/service_config_parser.cc @@ -0,0 +1,89 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
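ServiceConfig above keys parsed method configs by "/service/method" path: an exact path wins, then a "/service/" wildcard entry, then an entry whose name list contains an empty object (the default). An illustrative parse-and-lookup round trip; the service name and byte limits are placeholders, and the maxRequestMessageBytes fields only yield parsed configs once MessageSizeParser::Register() has run:

    #include <grpc/slice.h>
    #include <grpc/support/log.h>

    #include "src/core/ext/service_config/service_config.h"

    void ServiceConfigLookupSketch() {
      grpc_error_handle error = GRPC_ERROR_NONE;
      auto svc_cfg = grpc_core::ServiceConfig::Create(
          /*args=*/nullptr,
          R"({
            "methodConfig": [
              {"name": [{"service": "example.Echo", "method": "Ping"}],
               "maxRequestMessageBytes": "1024"},
              {"name": [{"service": "example.Echo"}],
               "maxRequestMessageBytes": "4096"},
              {"name": [{}],
               "maxRequestMessageBytes": "65536"}
            ]
          })",
          &error);
      GPR_ASSERT(error == GRPC_ERROR_NONE);
      // "/example.Echo/Ping" hits the exact entry; any other example.Echo method
      // falls back to the "/example.Echo/" wildcard; other paths get the default.
      grpc_slice path = grpc_slice_from_static_string("/example.Echo/Other");
      const grpc_core::ServiceConfigParser::ParsedConfigVector* configs =
          svc_cfg->GetMethodParsedConfigVector(path);
      (void)configs;
    }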
+// + +#include + +#include "src/core/ext/service_config/service_config_parser.h" + +#include + +namespace grpc_core { + +namespace { +typedef absl::InlinedVector, + ServiceConfigParser::kNumPreallocatedParsers> + ServiceConfigParserList; +ServiceConfigParserList* g_registered_parsers; +} // namespace + +void ServiceConfigParserInit() { + GPR_ASSERT(g_registered_parsers == nullptr); + g_registered_parsers = new ServiceConfigParserList(); +} + +void ServiceConfigParserShutdown() { + delete g_registered_parsers; + g_registered_parsers = nullptr; +} + +size_t ServiceConfigParser::RegisterParser(std::unique_ptr parser) { + g_registered_parsers->push_back(std::move(parser)); + return g_registered_parsers->size() - 1; +} + +ServiceConfigParser::ParsedConfigVector +ServiceConfigParser::ParseGlobalParameters(const grpc_channel_args* args, + const Json& json, + grpc_error_handle* error) { + ParsedConfigVector parsed_global_configs; + std::vector error_list; + for (size_t i = 0; i < g_registered_parsers->size(); i++) { + grpc_error_handle parser_error = GRPC_ERROR_NONE; + auto parsed_config = (*g_registered_parsers)[i]->ParseGlobalParams( + args, json, &parser_error); + if (parser_error != GRPC_ERROR_NONE) { + error_list.push_back(parser_error); + } + parsed_global_configs.push_back(std::move(parsed_config)); + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR("Global Params", &error_list); + } + return parsed_global_configs; +} + +ServiceConfigParser::ParsedConfigVector +ServiceConfigParser::ParsePerMethodParameters(const grpc_channel_args* args, + const Json& json, + grpc_error_handle* error) { + ParsedConfigVector parsed_method_configs; + std::vector error_list; + for (size_t i = 0; i < g_registered_parsers->size(); i++) { + grpc_error_handle parser_error = GRPC_ERROR_NONE; + auto parsed_config = (*g_registered_parsers)[i]->ParsePerMethodParams( + args, json, &parser_error); + if (parser_error != GRPC_ERROR_NONE) { + error_list.push_back(parser_error); + } + parsed_method_configs.push_back(std::move(parsed_config)); + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR("methodConfig", &error_list); + } + return parsed_method_configs; +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/binder/client/channel_create.cc b/src/core/ext/transport/binder/client/channel_create.cc new file mode 100644 index 00000000..c6392d93 --- /dev/null +++ b/src/core/ext/transport/binder/client/channel_create.cc @@ -0,0 +1,168 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/binder/client/channel_create.h" + +// The interface is only defined if GPR_ANDROID is defined, because some +// arguments requires JNI. 
+// Furthermore, the interface is non-phony only when +// GPR_SUPPORT_BINDER_TRANSPORT is true because actual implementation of binder +// transport requires newer version of NDK API + +#ifdef GPR_ANDROID + +#include +#include + +#ifdef GPR_SUPPORT_BINDER_TRANSPORT + +#include + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" + +#include +#include + +#include "src/core/ext/transport/binder/client/channel_create_impl.h" +#include "src/core/ext/transport/binder/client/connection_id_generator.h" +#include "src/core/ext/transport/binder/client/endpoint_binder_pool.h" +#include "src/core/ext/transport/binder/client/jni_utils.h" +#include "src/core/ext/transport/binder/transport/binder_transport.h" +#include "src/core/ext/transport/binder/wire_format/binder.h" +#include "src/core/ext/transport/binder/wire_format/binder_android.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/transport.h" +#include "src/cpp/client/create_channel_internal.h" + +namespace grpc { +namespace experimental { + +namespace { +// TODO(mingcl): To support multiple binder transport connection at the same +// time, we will need to generate unique connection id for each connection. +// For now we just use a single connection id globally. This will be fixed after +// we drop the BindToOnDeviceServerService interface +std::string g_connection_id; +} // namespace + +void BindToOnDeviceServerService(void* jni_env_void, jobject application, + absl::string_view package_name, + absl::string_view class_name) { + // Init gRPC library first so gpr_log works + grpc::internal::GrpcLibrary init_lib; + init_lib.init(); + + JNIEnv* jni_env = static_cast(jni_env_void); + + { + GPR_ASSERT(g_connection_id == ""); + g_connection_id = grpc_binder::GetConnectionIdGenerator()->Generate( + std::string(package_name), std::string(class_name)); + GPR_ASSERT(g_connection_id != ""); + } + + // clang-format off + grpc_binder::CallStaticJavaMethod(jni_env, + "io/grpc/binder/cpp/NativeConnectionHelper", + "tryEstablishConnection", + "(Landroid/content/Context;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V", + application, std::string(package_name), std::string(class_name), g_connection_id); + // clang-format on +} + +// BindToOndeviceServerService need to be called before this, in a different +// task (due to Android API design). (Reference: +// https://stackoverflow.com/a/3055749) +std::shared_ptr CreateBinderChannel( + void* jni_env_void, jobject application, absl::string_view package_name, + absl::string_view class_name, + std::shared_ptr + security_policy) { + return CreateCustomBinderChannel(jni_env_void, application, package_name, + class_name, security_policy, + ChannelArguments()); +} + +// BindToOndeviceServerService need to be called before this, in a different +// task (due to Android API design). 
(Reference: +// https://stackoverflow.com/a/3055749) +std::shared_ptr CreateCustomBinderChannel( + void*, jobject /*application*/, absl::string_view /*package_name*/, + absl::string_view /*class_name*/, + std::shared_ptr security_policy, + const ChannelArguments& args) { + GPR_ASSERT(security_policy != nullptr); + + std::unique_ptr endpoint_binder; + grpc_binder::GetEndpointBinderPool()->GetEndpointBinder( + g_connection_id, [&](std::unique_ptr e) { + endpoint_binder = std::move(e); + }); + // This assumes the above callback will be called immediately before + // `GetEndpointBinder` returns + GPR_ASSERT(endpoint_binder != nullptr); + + grpc_channel_args channel_args; + args.SetChannelArgs(&channel_args); + return CreateChannelInternal( + "", + ::grpc::internal::CreateChannelFromBinderImpl( + std::move(endpoint_binder), security_policy, &channel_args), + std::vector< + std::unique_ptr>()); +} + +} // namespace experimental +} // namespace grpc + +#else // !GPR_SUPPORT_BINDER_TRANSPORT + +namespace grpc { +namespace experimental { + +void BindToOnDeviceServerService(void*, jobject, absl::string_view, + absl::string_view) { + GPR_ASSERT(0); +} + +std::shared_ptr CreateBinderChannel( + void*, jobject, absl::string_view, absl::string_view, + std::shared_ptr) { + GPR_ASSERT(0); + return {}; +} + +std::shared_ptr CreateCustomBinderChannel( + void*, jobject, absl::string_view, absl::string_view, + std::shared_ptr, + const ChannelArguments&) { + GPR_ASSERT(0); + return {}; +} + +} // namespace experimental +} // namespace grpc + +#endif // GPR_SUPPORT_BINDER_TRANSPORT + +#endif // GPR_ANDROID diff --git a/src/core/ext/transport/binder/client/channel_create_impl.cc b/src/core/ext/transport/binder/client/channel_create_impl.cc new file mode 100644 index 00000000..ab88c1b1 --- /dev/null +++ b/src/core/ext/transport/binder/client/channel_create_impl.cc @@ -0,0 +1,62 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
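Besides CreateBinderChannel, the file above defines CreateCustomBinderChannel, which differs only in accepting a ChannelArguments; note that in the implementation shown the package and class parameters are currently unused, and the channel is built from whatever connection an earlier BindToOnDeviceServerService call (made in a different task, per the comment above) has established. A hedged sketch from inside a JNI entry point; the package, class, and message-size values are hypothetical placeholders:

    #include <jni.h>

    #include <grpcpp/grpcpp.h>

    #include "src/core/ext/transport/binder/client/channel_create.h"
    #include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h"

    // Call only after BindToOnDeviceServerService() has run in an earlier task.
    std::shared_ptr<grpc::Channel> MakeTunedBinderChannel(JNIEnv* jni_env,
                                                          jobject application) {
      grpc::ChannelArguments ch_args;
      ch_args.SetMaxReceiveMessageSize(1 * 1024 * 1024);  // placeholder limit
      return grpc::experimental::CreateCustomBinderChannel(
          jni_env, application, "com.example.binderserver",
          "com.example.binderserver.ExportedEndpointService",
          std::make_shared<grpc::experimental::binder::UntrustedSecurityPolicy>(),
          ch_args);
    }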
+ +#include + +#include "src/core/ext/transport/binder/client/channel_create_impl.h" + +#include +#include + +#include "src/core/ext/transport/binder/transport/binder_transport.h" +#include "src/core/ext/transport/binder/wire_format/binder.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/channel.h" + +namespace grpc { +namespace internal { + +grpc_channel* CreateChannelFromBinderImpl( + std::unique_ptr endpoint_binder, + std::shared_ptr security_policy, + const grpc_channel_args* args) { + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_channel_create_from_binder(target=%p, args=%p)", 2, + ((void*)1234, args)); + + grpc_transport* transport = grpc_create_binder_transport_client( + std::move(endpoint_binder), security_policy); + GPR_ASSERT(transport); + + // TODO(b/192207753): check binder alive and ping binder + + // TODO(b/192207758): Figure out if we are required to set authority here + grpc_arg default_authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test.authority")); + grpc_channel_args* final_args = + grpc_channel_args_copy_and_add(args, &default_authority_arg, 1); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_channel* channel = grpc_channel_create( + "binder_target_placeholder", final_args, GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, &error); + // TODO(mingcl): Handle error properly + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_channel_args_destroy(final_args); + return channel; +} + +} // namespace internal +} // namespace grpc diff --git a/src/core/ext/transport/binder/client/connection_id_generator.cc b/src/core/ext/transport/binder/client/connection_id_generator.cc new file mode 100644 index 00000000..144296d3 --- /dev/null +++ b/src/core/ext/transport/binder/client/connection_id_generator.cc @@ -0,0 +1,67 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "src/core/ext/transport/binder/client/connection_id_generator.h" + +#include "absl/strings/str_cat.h" + +namespace { +// Make sure `s` does not contain characters other than numbers, alphabets, +// period and underscore +std::string Normalize(absl::string_view str_view) { + std::string s = std::string(str_view); + for (size_t i = 0; i < s.length(); i++) { + if (!isalnum(s[i]) && s[i] != '.') { + s[i] = '_'; + } + } + return s; +} + +// Remove prefix of the string if the string is longer than len +std::string StripToLength(const std::string& s, size_t len) { + if (s.length() > len) { + return s.substr(s.length() - len, len); + } + return s; +} +} // namespace + +namespace grpc_binder { + +std::string ConnectionIdGenerator::Generate(absl::string_view package_name, + absl::string_view class_name) { + // reserve some room for serial number + const size_t kReserveForNumbers = 15; + std::string s = StripToLength( + absl::StrCat(Normalize(package_name), "-", Normalize(class_name)), + kPathLengthLimit - kReserveForNumbers); + std::string ret; + { + grpc_core::MutexLock l(&m_); + // Insert a hyphen before serial number + ret = absl::StrCat(s, "-", ++count_); + } + GPR_ASSERT(ret.length() < kPathLengthLimit); + return ret; +} + +ConnectionIdGenerator* GetConnectionIdGenerator() { + static ConnectionIdGenerator* cig = new ConnectionIdGenerator(); + return cig; +} + +} // namespace grpc_binder diff --git a/src/core/ext/transport/binder/client/endpoint_binder_pool.cc b/src/core/ext/transport/binder/client/endpoint_binder_pool.cc new file mode 100644 index 00000000..aa3794cf --- /dev/null +++ b/src/core/ext/transport/binder/client/endpoint_binder_pool.cc @@ -0,0 +1,108 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/binder/client/endpoint_binder_pool.h" + +#include "src/core/ext/transport/binder/client/jni_utils.h" + +#ifdef GPR_SUPPORT_BINDER_TRANSPORT + +#include + +#include "src/core/ext/transport/binder/wire_format/binder_android.h" + +extern "C" { +// Adds endpoint binder to binder pool when Java notify us that the endpoint +// binder is ready. 
This is called from GrpcBinderConnection.java +JNIEXPORT void JNICALL +Java_io_grpc_binder_cpp_GrpcBinderConnection_notifyConnected__Ljava_lang_String_2Landroid_os_IBinder_2( + JNIEnv* jni_env, jobject, jstring conn_id_jstring, jobject ibinder) { + jboolean isCopy; + const char* conn_id = jni_env->GetStringUTFChars(conn_id_jstring, &isCopy); + gpr_log(GPR_ERROR, "%s called with conn_id = %s", __func__, conn_id); + GPR_ASSERT(ibinder != nullptr); + ndk::SpAIBinder aibinder = grpc_binder::FromJavaBinder(jni_env, ibinder); + gpr_log(GPR_ERROR, "aibinder = %p", aibinder.get()); + auto b = absl::make_unique(aibinder); + GPR_ASSERT(b != nullptr); + grpc_binder::GetEndpointBinderPool()->AddEndpointBinder(conn_id, + std::move(b)); + if (isCopy == JNI_TRUE) { + jni_env->ReleaseStringUTFChars(conn_id_jstring, conn_id); + } +} +} + +#endif // GPR_SUPPORT_BINDER_TRANSPORT + +namespace grpc_binder { + +void EndpointBinderPool ::GetEndpointBinder( + std::string conn_id, + std::function)> cb) { + gpr_log(GPR_ERROR, "GetEndpointBinder %s", conn_id.c_str()); + std::unique_ptr b; + { + grpc_core::MutexLock l(&m_); + if (binder_map_.count(conn_id)) { + b = std::move(binder_map_[conn_id]); + binder_map_.erase(conn_id); + GPR_ASSERT(b != nullptr); + } else { + if (pending_requests_.count(conn_id) != 0) { + gpr_log(GPR_ERROR, "Duplicate GetEndpointBinder request. conn_id = %s", + conn_id.c_str()); + return; + } + pending_requests_[conn_id] = std::move(cb); + return; + } + } + GPR_ASSERT(b != nullptr); + cb(std::move(b)); +} + +void EndpointBinderPool::AddEndpointBinder( + std::string conn_id, std::unique_ptr b) { + gpr_log(GPR_ERROR, "AddEndpointBinder %s", conn_id.c_str()); + GPR_ASSERT(b != nullptr); + // cb will be set in the following block if there is a pending callback + std::function)> cb = nullptr; + { + grpc_core::MutexLock l(&m_); + if (binder_map_.count(conn_id) != 0) { + gpr_log(GPR_ERROR, "EndpointBinder already in the pool. conn_id = %s", + conn_id.c_str()); + return; + } + if (pending_requests_.count(conn_id)) { + cb = std::move(pending_requests_[conn_id]); + pending_requests_.erase(conn_id); + } else { + binder_map_[conn_id] = std::move(b); + b = nullptr; + } + } + if (cb != nullptr) { + cb(std::move(b)); + } +} + +EndpointBinderPool* GetEndpointBinderPool() { + static EndpointBinderPool* p = new EndpointBinderPool(); + return p; +} +} // namespace grpc_binder diff --git a/src/core/ext/transport/binder/client/jni_utils.cc b/src/core/ext/transport/binder/client/jni_utils.cc new file mode 100644 index 00000000..e8f094e6 --- /dev/null +++ b/src/core/ext/transport/binder/client/jni_utils.cc @@ -0,0 +1,89 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
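The EndpointBinderPool above is a symmetric hand-off: whichever of AddEndpointBinder (driven by the JNI notifyConnected entry point) and GetEndpointBinder (driven by channel creation) runs first, the other side finds it waiting and the callback fires exactly once with the binder. A small sketch of that contract; the connection id is a placeholder and `binder` stands for whatever concrete grpc_binder::Binder the platform code produced:

    #include <memory>

    #include "src/core/ext/transport/binder/client/endpoint_binder_pool.h"

    void PoolHandOffSketch(std::unique_ptr<grpc_binder::Binder> binder) {
      std::unique_ptr<grpc_binder::Binder> received;
      // Request first: the callback is parked under the connection id...
      grpc_binder::GetEndpointBinderPool()->GetEndpointBinder(
          "com.example.binderserver-1",
          [&](std::unique_ptr<grpc_binder::Binder> b) { received = std::move(b); });
      // ...and fires as soon as the matching AddEndpointBinder arrives (here
      // synchronously; in production it comes from the JNI callback above).
      grpc_binder::GetEndpointBinderPool()->AddEndpointBinder(
          "com.example.binderserver-1", std::move(binder));
      // `received` now owns the binder.
      (void)received;
    }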
+ +#include + +#include "src/core/ext/transport/binder/client/jni_utils.h" + +#include + +#if defined(ANDROID) || defined(__ANDROID__) + +namespace grpc_binder { + +void CallStaticJavaMethod(JNIEnv* env, const std::string& clazz, + const std::string& method, const std::string& type, + jobject application, const std::string& pkg, + const std::string& cls) { + jclass cl = env->FindClass(clazz.c_str()); + if (cl == nullptr) { + gpr_log(GPR_ERROR, "No class %s", clazz.c_str()); + } + + jmethodID mid = env->GetStaticMethodID(cl, method.c_str(), type.c_str()); + if (mid == nullptr) { + gpr_log(GPR_ERROR, "No method id %s", method.c_str()); + } + + env->CallStaticVoidMethod(cl, mid, application, + env->NewStringUTF(pkg.c_str()), + env->NewStringUTF(cls.c_str())); +} + +void CallStaticJavaMethod(JNIEnv* env, const std::string& clazz, + const std::string& method, const std::string& type, + jobject application, const std::string& pkg, + const std::string& cls, const std::string& conn_id) { + jclass cl = env->FindClass(clazz.c_str()); + if (cl == nullptr) { + gpr_log(GPR_ERROR, "No class %s", clazz.c_str()); + } + + jmethodID mid = env->GetStaticMethodID(cl, method.c_str(), type.c_str()); + if (mid == nullptr) { + gpr_log(GPR_ERROR, "No method id %s", method.c_str()); + } + + env->CallStaticVoidMethod( + cl, mid, application, env->NewStringUTF(pkg.c_str()), + env->NewStringUTF(cls.c_str()), env->NewStringUTF(conn_id.c_str())); +} + +jobject CallStaticJavaMethodForObject(JNIEnv* env, const std::string& clazz, + const std::string& method, + const std::string& type) { + jclass cl = env->FindClass(clazz.c_str()); + if (cl == nullptr) { + gpr_log(GPR_ERROR, "No class %s", clazz.c_str()); + return nullptr; + } + + jmethodID mid = env->GetStaticMethodID(cl, method.c_str(), type.c_str()); + if (mid == nullptr) { + gpr_log(GPR_ERROR, "No method id %s", method.c_str()); + return nullptr; + } + + jobject object = env->CallStaticObjectMethod(cl, mid); + if (object == nullptr) { + gpr_log(GPR_ERROR, "Got null object from Java"); + return nullptr; + } + + return object; +} + +} // namespace grpc_binder + +#endif diff --git a/src/core/ext/transport/binder/security_policy/internal_only_security_policy.cc b/src/core/ext/transport/binder/security_policy/internal_only_security_policy.cc new file mode 100644 index 00000000..91587599 --- /dev/null +++ b/src/core/ext/transport/binder/security_policy/internal_only_security_policy.cc @@ -0,0 +1,39 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
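Looking back at the ConnectionIdGenerator further above: Generate() keeps alphanumerics and '.', replaces every other character with '_', trims the front of the combined name to leave room for a serial number, and appends "-<count>". A small illustration with placeholder names; the exact output depends on kPathLengthLimit and on how many ids have been generated so far:

    #include "src/core/ext/transport/binder/client/connection_id_generator.h"

    void ConnectionIdSketch() {
      std::string id = grpc_binder::GetConnectionIdGenerator()->Generate(
          "com.example.app", "com.example.app.MainService");
      // Expected shape on a first call:
      //   "com.example.app-com.example.app.MainService-1"
      (void)id;
    }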
+ +#include <grpc/support/port_platform.h> + +#include "src/core/ext/transport/binder/security_policy/internal_only_security_policy.h" + +#ifdef GPR_ANDROID + +#include <unistd.h> + +namespace grpc { +namespace experimental { +namespace binder { + +InternalOnlySecurityPolicy::InternalOnlySecurityPolicy() = default; + +InternalOnlySecurityPolicy::~InternalOnlySecurityPolicy() = default; + +bool InternalOnlySecurityPolicy::IsAuthorized(int uid) { + return static_cast<uid_t>(uid) == getuid(); +}; + +} // namespace binder +} // namespace experimental +} // namespace grpc + +#endif diff --git a/src/core/ext/transport/binder/security_policy/untrusted_security_policy.cc b/src/core/ext/transport/binder/security_policy/untrusted_security_policy.cc new file mode 100644 index 00000000..801b19b9 --- /dev/null +++ b/src/core/ext/transport/binder/security_policy/untrusted_security_policy.cc @@ -0,0 +1,31 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h" + +namespace grpc { +namespace experimental { +namespace binder { + +UntrustedSecurityPolicy::UntrustedSecurityPolicy() = default; + +UntrustedSecurityPolicy::~UntrustedSecurityPolicy() = default; + +bool UntrustedSecurityPolicy::IsAuthorized(int) { return true; }; + +} // namespace binder +} // namespace experimental +} // namespace grpc diff --git a/src/core/ext/transport/binder/server/binder_server.cc b/src/core/ext/transport/binder/server/binder_server.cc new file mode 100644 index 00000000..27120d03 --- /dev/null +++ b/src/core/ext/transport/binder/server/binder_server.cc @@ -0,0 +1,248 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
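Both policies above implement the single SecurityPolicy hook, IsAuthorized(int uid): UntrustedSecurityPolicy accepts every caller, while InternalOnlySecurityPolicy accepts only the server's own UID. A hedged sketch of a stricter allow-list policy built on the same interface (declared in security_policy.h); whether IsAuthorized is the only member that must be overridden should be checked against that header:

    #include <set>
    #include <utility>

    #include "src/core/ext/transport/binder/security_policy/security_policy.h"

    class AllowListSecurityPolicy
        : public grpc::experimental::binder::SecurityPolicy {
     public:
      explicit AllowListSecurityPolicy(std::set<int> allowed_uids)
          : allowed_uids_(std::move(allowed_uids)) {}

      // Authorize a caller only if its Android UID is on the allow list.
      bool IsAuthorized(int uid) override { return allowed_uids_.count(uid) > 0; }

     private:
      std::set<int> allowed_uids_;
    };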
+ +#include + +#include "src/core/ext/transport/binder/server/binder_server.h" + +#include +#include +#include + +#include "absl/memory/memory.h" + +#include + +#include "src/core/ext/transport/binder/transport/binder_transport.h" +#include "src/core/ext/transport/binder/wire_format/binder_android.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/transport/error_utils.h" + +#ifdef GPR_SUPPORT_BINDER_TRANSPORT + +#include +#include +#include + +extern "C" { + +// This will be invoked from +// src/core/ext/transport/binder/java/io/grpc/binder/cpp/GrpcCppServerBuilder.java +JNIEXPORT jobject JNICALL +Java_io_grpc_binder_cpp_GrpcCppServerBuilder_GetEndpointBinderInternal__Ljava_lang_String_2( + JNIEnv* jni_env, jobject, jstring conn_id_jstring) { + AIBinder* ai_binder = nullptr; + + { + // This block is the scope of conn_id c-string + jboolean isCopy; + const char* conn_id = jni_env->GetStringUTFChars(conn_id_jstring, &isCopy); + ai_binder = + static_cast(grpc_get_endpoint_binder(std::string(conn_id))); + if (ai_binder == nullptr) { + gpr_log(GPR_ERROR, "Cannot find endpoint binder with connection id = %s", + conn_id); + } + if (isCopy == JNI_TRUE) { + jni_env->ReleaseStringUTFChars(conn_id_jstring, conn_id); + } + } + + if (ai_binder == nullptr) { + return nullptr; + } + + return AIBinder_toJavaBinder(jni_env, ai_binder); +} +} + +#endif + +namespace grpc { +namespace experimental { +namespace binder { + +void* GetEndpointBinder(const std::string& service) { + return grpc_get_endpoint_binder(service); +} + +void AddEndpointBinder(const std::string& service, void* endpoint_binder) { + grpc_add_endpoint_binder(service, endpoint_binder); +} + +void RemoveEndpointBinder(const std::string& service) { + grpc_remove_endpoint_binder(service); +} + +} // namespace binder +} // namespace experimental +} // namespace grpc + +static absl::flat_hash_map* g_endpoint_binder_pool = + nullptr; + +namespace { + +grpc_core::Mutex* GetBinderPoolMutex() { + static grpc_core::Mutex* mu = new grpc_core::Mutex(); + return mu; +} + +} // namespace + +void grpc_add_endpoint_binder(const std::string& service, + void* endpoint_binder) { + grpc_core::MutexLock lock(GetBinderPoolMutex()); + if (g_endpoint_binder_pool == nullptr) { + g_endpoint_binder_pool = new absl::flat_hash_map(); + } + (*g_endpoint_binder_pool)[service] = endpoint_binder; +} + +void grpc_remove_endpoint_binder(const std::string& service) { + grpc_core::MutexLock lock(GetBinderPoolMutex()); + if (g_endpoint_binder_pool == nullptr) { + return; + } + g_endpoint_binder_pool->erase(service); +} + +void* grpc_get_endpoint_binder(const std::string& service) { + grpc_core::MutexLock lock(GetBinderPoolMutex()); + if (g_endpoint_binder_pool == nullptr) { + return nullptr; + } + auto iter = g_endpoint_binder_pool->find(service); + return iter == g_endpoint_binder_pool->end() ? 
nullptr : iter->second; +} + +namespace grpc_core { + +class BinderServerListener : public Server::ListenerInterface { + public: + BinderServerListener( + Server* server, std::string addr, BinderTxReceiverFactory factory, + std::shared_ptr + security_policy) + : server_(server), + addr_(std::move(addr)), + factory_(std::move(factory)), + security_policy_(security_policy) {} + + void Start(Server* /*server*/, + const std::vector* /*pollsets*/) override { + tx_receiver_ = factory_( + [this](transaction_code_t code, grpc_binder::ReadableParcel* parcel, + int uid) { return OnSetupTransport(code, parcel, uid); }); + endpoint_binder_ = tx_receiver_->GetRawBinder(); + grpc_add_endpoint_binder(addr_, endpoint_binder_); + } + + channelz::ListenSocketNode* channelz_listen_socket_node() const override { + return nullptr; + } + + void SetOnDestroyDone(grpc_closure* on_destroy_done) override { + on_destroy_done_ = on_destroy_done; + } + + void Orphan() override { delete this; } + + ~BinderServerListener() override { + ExecCtx::Get()->Flush(); + if (on_destroy_done_) { + ExecCtx::Run(DEBUG_LOCATION, on_destroy_done_, GRPC_ERROR_NONE); + ExecCtx::Get()->Flush(); + } + grpc_remove_endpoint_binder(addr_); + } + + private: + absl::Status OnSetupTransport(transaction_code_t code, + grpc_binder::ReadableParcel* parcel, int uid) { + grpc_core::ExecCtx exec_ctx; + if (grpc_binder::BinderTransportTxCode(code) != + grpc_binder::BinderTransportTxCode::SETUP_TRANSPORT) { + return absl::InvalidArgumentError("Not a SETUP_TRANSPORT request"); + } + + gpr_log(GPR_ERROR, "calling uid = %d", uid); + if (!security_policy_->IsAuthorized(uid)) { + // TODO(mingcl): For now we just ignore this unauthorized + // SETUP_TRANSPORT transaction and ghost the client. Check if we should + // send back a SHUTDOWN_TRANSPORT in this case. + return absl::PermissionDeniedError( + "UID " + std::to_string(uid) + + " is not allowed to connect to this " + "server according to security policy."); + } + + int version; + absl::Status status = parcel->ReadInt32(&version); + if (!status.ok()) { + return status; + } + gpr_log(GPR_INFO, "version = %d", version); + // TODO(waynetu): Check supported version. + std::unique_ptr client_binder{}; + status = parcel->ReadBinder(&client_binder); + if (!status.ok()) { + return status; + } + if (!client_binder) { + return absl::InvalidArgumentError("NULL binder read from the parcel"); + } + client_binder->Initialize(); + // Finish the second half of SETUP_TRANSPORT in + // grpc_create_binder_transport_server(). + grpc_transport* server_transport = grpc_create_binder_transport_server( + std::move(client_binder), security_policy_); + GPR_ASSERT(server_transport); + grpc_channel_args* args = grpc_channel_args_copy(server_->channel_args()); + grpc_error_handle error = server_->SetupTransport(server_transport, nullptr, + args, nullptr, nullptr); + grpc_channel_args_destroy(args); + return grpc_error_to_absl_status(error); + } + + Server* server_; + grpc_closure* on_destroy_done_ = nullptr; + std::string addr_; + BinderTxReceiverFactory factory_; + std::shared_ptr security_policy_; + void* endpoint_binder_ = nullptr; + std::unique_ptr tx_receiver_; +}; + +bool AddBinderPort(const std::string& addr, grpc_server* server, + BinderTxReceiverFactory factory, + std::shared_ptr + security_policy) { + // TODO(mingcl): Check if the addr is valid here after binder address resolver + // related code are merged. 
+ const std::string kBinderUriScheme = "binder:"; + if (addr.compare(0, kBinderUriScheme.size(), kBinderUriScheme) != 0) { + return false; + } + std::string conn_id = addr.substr(kBinderUriScheme.size()); + grpc_core::Server* core_server = server->core_server.get(); + core_server->AddListener( + grpc_core::OrphanablePtr( + new grpc_core::BinderServerListener( + core_server, conn_id, std::move(factory), security_policy))); + return true; +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/binder/server/binder_server_credentials.cc b/src/core/ext/transport/binder/server/binder_server_credentials.cc new file mode 100644 index 00000000..d585a555 --- /dev/null +++ b/src/core/ext/transport/binder/server/binder_server_credentials.cc @@ -0,0 +1,73 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "src/core/ext/transport/binder/security_policy/security_policy.h" +#include "src/core/ext/transport/binder/server/binder_server.h" +#include "src/core/ext/transport/binder/wire_format/binder_android.h" + +namespace grpc { +namespace experimental { + +namespace { + +class BinderServerCredentialsImpl final : public ServerCredentials { + public: + explicit BinderServerCredentialsImpl( + std::shared_ptr + security_policy) + : security_policy_(security_policy) {} +#ifdef GPR_SUPPORT_BINDER_TRANSPORT + int AddPortToServer(const std::string& addr, grpc_server* server) override { + return grpc_core::AddBinderPort( + std::string(addr), server, + [](grpc_binder::TransactionReceiver::OnTransactCb transact_cb) { + return absl::make_unique( + nullptr, std::move(transact_cb)); + }, + security_policy_); + } +#else + int AddPortToServer(const std::string& /*addr*/, + grpc_server* /*server*/) override { + return 0; + } +#endif // GPR_SUPPORT_BINDER_TRANSPORT + + void SetAuthMetadataProcessor( + const std::shared_ptr& /*processor*/) override { + GPR_ASSERT(false); + } + + private: + bool IsInsecure() const override { return true; } + + std::shared_ptr security_policy_; +}; + +} // namespace + +std::shared_ptr BinderServerCredentials( + std::shared_ptr + security_policy) { + GPR_ASSERT(security_policy != nullptr); + return std::shared_ptr( + new BinderServerCredentialsImpl(security_policy)); +} + +} // namespace experimental +} // namespace grpc diff --git a/src/core/ext/transport/binder/transport/binder_transport.cc b/src/core/ext/transport/binder/transport/binder_transport.cc new file mode 100644 index 00000000..6cd5933e --- /dev/null +++ b/src/core/ext/transport/binder/transport/binder_transport.cc @@ -0,0 +1,757 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/binder/transport/binder_transport.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" + +#include + +#include "src/core/ext/transport/binder/transport/binder_stream.h" +#include "src/core/ext/transport/binder/utils/transport_stream_receiver.h" +#include "src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h" +#include "src/core/ext/transport/binder/wire_format/wire_reader.h" +#include "src/core/ext/transport/binder/wire_format/wire_reader_impl.h" +#include "src/core/ext/transport/binder/wire_format/wire_writer.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_utils.h" +#include "src/core/lib/transport/byte_stream.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_metadata.h" +#include "src/core/lib/transport/transport.h" + +#ifndef NDEBUG +static void grpc_binder_stream_ref(grpc_binder_stream* s, const char* reason) { + grpc_stream_ref(s->refcount, reason); +} +static void grpc_binder_stream_unref(grpc_binder_stream* s, + const char* reason) { + grpc_stream_unref(s->refcount, reason); +} +static void grpc_binder_ref_transport(grpc_binder_transport* t, + const char* reason, const char* file, + int line) { + t->refs.Ref(grpc_core::DebugLocation(file, line), reason); +} +static void grpc_binder_unref_transport(grpc_binder_transport* t, + const char* reason, const char* file, + int line) { + if (t->refs.Unref(grpc_core::DebugLocation(file, line), reason)) { + delete t; + } +} +#else +static void grpc_binder_stream_ref(grpc_binder_stream* s) { + grpc_stream_ref(s->refcount); +} +static void grpc_binder_stream_unref(grpc_binder_stream* s) { + grpc_stream_unref(s->refcount); +} +static void grpc_binder_ref_transport(grpc_binder_transport* t) { + t->refs.Ref(); +} +static void grpc_binder_unref_transport(grpc_binder_transport* t) { + if (t->refs.Unref()) { + delete t; + } +} +#endif + +#ifndef NDEBUG +#define GRPC_BINDER_STREAM_REF(stream, reason) \ + grpc_binder_stream_ref(stream, reason) +#define GRPC_BINDER_STREAM_UNREF(stream, reason) \ + grpc_binder_stream_unref(stream, reason) +#define GRPC_BINDER_REF_TRANSPORT(t, r) \ + grpc_binder_ref_transport(t, r, __FILE__, __LINE__) +#define GRPC_BINDER_UNREF_TRANSPORT(t, r) \ + grpc_binder_unref_transport(t, r, __FILE__, __LINE__) +#else +#define GRPC_BINDER_STREAM_REF(stream, reason) grpc_binder_stream_ref(stream) +#define GRPC_BINDER_STREAM_UNREF(stream, reason) \ + grpc_binder_stream_unref(stream) +#define GRPC_BINDER_REF_TRANSPORT(t, r) grpc_binder_ref_transport(t) +#define GRPC_BINDER_UNREF_TRANSPORT(t, r) grpc_binder_unref_transport(t) +#endif + +static int init_stream(grpc_transport* gt, grpc_stream* gs, + grpc_stream_refcount* refcount, const void* server_data, + grpc_core::Arena* arena) { + GPR_TIMER_SCOPE("init_stream", 0); + gpr_log(GPR_INFO, "%s = %p %p %p %p %p", __func__, gt, gs, 
refcount, + server_data, arena); + grpc_binder_transport* t = reinterpret_cast(gt); + // TODO(mingcl): Figure out if we need to worry about concurrent invocation + // here + new (gs) grpc_binder_stream(t, refcount, server_data, arena, + t->NewStreamTxCode(), t->is_client); + return 0; +} + +static void set_pollset(grpc_transport* gt, grpc_stream* gs, grpc_pollset* gp) { + gpr_log(GPR_INFO, "%s = %p %p %p", __func__, gt, gs, gp); +} + +static void set_pollset_set(grpc_transport*, grpc_stream*, grpc_pollset_set*) { + gpr_log(GPR_INFO, __func__); +} + +static void AssignMetadata(grpc_metadata_batch* mb, + const grpc_binder::Metadata& md) { + mb->Clear(); + for (auto& p : md) { + mb->Append(p.first, grpc_slice_from_cpp_string(p.second)); + } +} + +static void cancel_stream_locked(grpc_binder_transport* gbt, + grpc_binder_stream* gbs, + grpc_error_handle error) { + gpr_log(GPR_INFO, "cancel_stream_locked"); + if (!gbs->is_closed) { + GPR_ASSERT(gbs->cancel_self_error == GRPC_ERROR_NONE); + gbs->is_closed = true; + gbs->cancel_self_error = GRPC_ERROR_REF(error); + gbt->transport_stream_receiver->CancelStream(gbs->tx_code); + gbt->registered_stream.erase(gbs->tx_code); + if (gbs->recv_initial_metadata_ready != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_initial_metadata_ready, + GRPC_ERROR_REF(error)); + gbs->recv_initial_metadata_ready = nullptr; + gbs->recv_initial_metadata = nullptr; + gbs->trailing_metadata_available = nullptr; + } + if (gbs->recv_message_ready != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_message_ready, + GRPC_ERROR_REF(error)); + gbs->recv_message_ready = nullptr; + gbs->recv_message->reset(); + gbs->recv_message = nullptr; + gbs->call_failed_before_recv_message = nullptr; + } + if (gbs->recv_trailing_metadata_finished != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + gbs->recv_trailing_metadata_finished, + GRPC_ERROR_REF(error)); + gbs->recv_trailing_metadata_finished = nullptr; + gbs->recv_trailing_metadata = nullptr; + } + } + GRPC_ERROR_UNREF(error); +} + +static bool ContainsAuthorityAndPath(const grpc_binder::Metadata& metadata) { + bool has_authority = false; + bool has_path = false; + for (const auto& kv : metadata) { + if (kv.first == grpc_core::StringViewFromSlice(GRPC_MDSTR_AUTHORITY)) { + has_authority = true; + } + if (kv.first == grpc_core::StringViewFromSlice(GRPC_MDSTR_PATH)) { + has_path = true; + } + } + return has_authority && has_path; +} + +static void recv_initial_metadata_locked(void* arg, + grpc_error_handle /*error*/) { + RecvInitialMetadataArgs* args = static_cast(arg); + grpc_binder_stream* gbs = args->gbs; + + gpr_log(GPR_INFO, + "recv_initial_metadata_locked is_client = %d is_closed = %d", + gbs->is_client, gbs->is_closed); + + if (!gbs->is_closed) { + grpc_error_handle error = [&] { + GPR_ASSERT(gbs->recv_initial_metadata); + GPR_ASSERT(gbs->recv_initial_metadata_ready); + if (!args->initial_metadata.ok()) { + gpr_log(GPR_ERROR, "Failed to parse initial metadata"); + return absl_status_to_grpc_error(args->initial_metadata.status()); + } + if (!gbs->is_client) { + // For server, we expect :authority and :path in initial metadata. 
+ if (!ContainsAuthorityAndPath(*args->initial_metadata)) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + "Missing :authority or :path in initial metadata"); + } + } + AssignMetadata(gbs->recv_initial_metadata, *args->initial_metadata); + return GRPC_ERROR_NONE; + }(); + + grpc_closure* cb = gbs->recv_initial_metadata_ready; + gbs->recv_initial_metadata_ready = nullptr; + gbs->recv_initial_metadata = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + } + GRPC_BINDER_STREAM_UNREF(gbs, "recv_initial_metadata"); +} + +static void recv_message_locked(void* arg, grpc_error_handle /*error*/) { + RecvMessageArgs* args = static_cast(arg); + grpc_binder_stream* gbs = args->gbs; + + gpr_log(GPR_INFO, "recv_message_locked is_client = %d is_closed = %d", + gbs->is_client, gbs->is_closed); + + if (!gbs->is_closed) { + grpc_error_handle error = [&] { + GPR_ASSERT(gbs->recv_message); + GPR_ASSERT(gbs->recv_message_ready); + if (!args->message.ok()) { + gpr_log(GPR_ERROR, "Failed to receive message"); + if (args->message.status().message() == + grpc_binder::TransportStreamReceiver:: + kGrpcBinderTransportCancelledGracefully) { + gpr_log(GPR_ERROR, "message cancelled gracefully"); + // Cancelled because we've already received trailing metadata. + // It's not an error in this case. + return GRPC_ERROR_NONE; + } else { + return absl_status_to_grpc_error(args->message.status()); + } + } + grpc_slice_buffer buf; + grpc_slice_buffer_init(&buf); + grpc_slice_buffer_add(&buf, grpc_slice_from_cpp_string(*args->message)); + + gbs->sbs.Init(&buf, 0); + gbs->recv_message->reset(gbs->sbs.get()); + return GRPC_ERROR_NONE; + }(); + + if (error != GRPC_ERROR_NONE && + gbs->call_failed_before_recv_message != nullptr) { + *gbs->call_failed_before_recv_message = true; + } + grpc_closure* cb = gbs->recv_message_ready; + gbs->recv_message_ready = nullptr; + gbs->recv_message = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + } + + GRPC_BINDER_STREAM_UNREF(gbs, "recv_message"); +} + +static void recv_trailing_metadata_locked(void* arg, + grpc_error_handle /*error*/) { + RecvTrailingMetadataArgs* args = static_cast(arg); + grpc_binder_stream* gbs = args->gbs; + + gpr_log(GPR_INFO, + "recv_trailing_metadata_locked is_client = %d is_closed = %d", + gbs->is_client, gbs->is_closed); + + if (!gbs->is_closed) { + grpc_error_handle error = [&] { + GPR_ASSERT(gbs->recv_trailing_metadata); + GPR_ASSERT(gbs->recv_trailing_metadata_finished); + if (!args->trailing_metadata.ok()) { + gpr_log(GPR_ERROR, "Failed to receive trailing metadata"); + return absl_status_to_grpc_error(args->trailing_metadata.status()); + } + if (!gbs->is_client) { + // Client will not send non-empty trailing metadata. 
+ if (!args->trailing_metadata.value().empty()) { + gpr_log(GPR_ERROR, "Server receives non-empty trailing metadata."); + return GRPC_ERROR_CANCELLED; + } + } else { + AssignMetadata(gbs->recv_trailing_metadata, *args->trailing_metadata); + // Append status to metadata + // TODO(b/192208695): See if we can avoid manually putting the status + // code into the header + gpr_log(GPR_INFO, "status = %d", args->status); + grpc_linked_mdelem* glm = static_cast( + gbs->arena->Alloc(sizeof(grpc_linked_mdelem))); + glm->md = grpc_get_reffed_status_elem(args->status); + GPR_ASSERT(gbs->recv_trailing_metadata->LinkTail(glm) == + GRPC_ERROR_NONE); + gpr_log(GPR_INFO, "trailing_metadata = %p", + gbs->recv_trailing_metadata); + gpr_log(GPR_INFO, "glm = %p", glm); + } + return GRPC_ERROR_NONE; + }(); + + if (gbs->is_client || gbs->trailing_metadata_sent) { + grpc_closure* cb = gbs->recv_trailing_metadata_finished; + gbs->recv_trailing_metadata_finished = nullptr; + gbs->recv_trailing_metadata = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + } else { + // According to the transport explainer - "Server extra: This op shouldn't + // actually be considered complete until the server has also sent trailing + // metadata to provide the other side with final status" + // + // We haven't sent trailing metadata yet, so we have to delay completing + // the recv_trailing_metadata callback. + gbs->need_to_call_trailing_metadata_callback = true; + } + } + GRPC_BINDER_STREAM_UNREF(gbs, "recv_trailing_metadata"); +} + +static void perform_stream_op_locked(void* stream_op, + grpc_error_handle /*error*/) { + grpc_transport_stream_op_batch* op = + static_cast(stream_op); + grpc_binder_stream* gbs = + static_cast(op->handler_private.extra_arg); + grpc_binder_transport* gbt = gbs->t; + if (op->cancel_stream) { + // TODO(waynetu): Is this true? + GPR_ASSERT(!op->send_initial_metadata && !op->send_message && + !op->send_trailing_metadata && !op->recv_initial_metadata && + !op->recv_message && !op->recv_trailing_metadata); + gpr_log(GPR_INFO, "cancel_stream is_client = %d", gbs->is_client); + if (!gbs->is_client) { + // Send trailing metadata to inform the other end about the cancellation, + // regardless of whether we'd already done that or not. + grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbt->is_client); + cancel_tx.SetSuffix(grpc_binder::Metadata{}); + cancel_tx.SetStatus(1); + absl::Status status = gbt->wire_writer->RpcCall(cancel_tx); + } + cancel_stream_locked(gbt, gbs, op->payload->cancel_stream.cancel_error); + if (op->on_complete != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, GRPC_ERROR_NONE); + } + GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op"); + return; + } + + if (gbs->is_closed) { + if (op->send_message) { + // Reset the send_message payload to prevent memory leaks.
+ op->payload->send_message.send_message.reset(); + } + if (op->recv_initial_metadata) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + if (op->recv_message) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + op->payload->recv_message.recv_message_ready, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + if (op->recv_trailing_metadata) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + if (op->on_complete != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, + GRPC_ERROR_REF(gbs->cancel_self_error)); + } + GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op"); + return; + } + + int tx_code = gbs->tx_code; + grpc_binder::Transaction tx(tx_code, gbt->is_client); + + if (op->send_initial_metadata) { + gpr_log(GPR_INFO, "send_initial_metadata"); + grpc_binder::Metadata init_md; + auto batch = op->payload->send_initial_metadata.send_initial_metadata; + + batch->ForEach([&](grpc_mdelem md) { + absl::string_view key = grpc_core::StringViewFromSlice(GRPC_MDKEY(md)); + absl::string_view value = + grpc_core::StringViewFromSlice(GRPC_MDVALUE(md)); + gpr_log(GPR_INFO, "send initial metadata key-value %s", + absl::StrCat(key, " ", value).c_str()); + if (grpc_slice_eq(GRPC_MDKEY(md), GRPC_MDSTR_PATH)) { + // TODO(b/192208403): Figure out if it is correct to simply drop the '/' + // prefix and treat it as the rpc method name + GPR_ASSERT(value[0] == '/'); + std::string path = std::string(value).substr(1); + + // Only the client sends the method ref. + GPR_ASSERT(gbt->is_client); + tx.SetMethodRef(path); + } else { + init_md.emplace_back(std::string(key), std::string(value)); + } + }); + tx.SetPrefix(init_md); + } + if (op->send_message) { + gpr_log(GPR_INFO, "send_message"); + size_t remaining = op->payload->send_message.send_message->length(); + std::string message_data; + while (remaining > 0) { + grpc_slice message_slice; + // TODO(waynetu): Temporarily assume that the message is ready. + GPR_ASSERT( + op->payload->send_message.send_message->Next(SIZE_MAX, nullptr)); + grpc_error_handle error = + op->payload->send_message.send_message->Pull(&message_slice); + // TODO(waynetu): Cancel the stream if error is not GRPC_ERROR_NONE. + GPR_ASSERT(error == GRPC_ERROR_NONE); + uint8_t* p = GRPC_SLICE_START_PTR(message_slice); + size_t len = GRPC_SLICE_LENGTH(message_slice); + remaining -= len; + message_data += std::string(reinterpret_cast(p), len); + grpc_slice_unref_internal(message_slice); + } + gpr_log(GPR_INFO, "message_data = %s", message_data.c_str()); + tx.SetData(message_data); + // TODO(b/192369787): Are we supposed to reset here to avoid a + // use-after-free issue in call.cc? + op->payload->send_message.send_message.reset(); + } + + if (op->send_trailing_metadata) { + gpr_log(GPR_INFO, "send_trailing_metadata"); + auto batch = op->payload->send_trailing_metadata.send_trailing_metadata; + grpc_binder::Metadata trailing_metadata; + + batch->ForEach([&](grpc_mdelem md) { + // Client will not send trailing metadata.
+ GPR_ASSERT(!gbt->is_client); + + if (grpc_slice_eq(GRPC_MDKEY(md), GRPC_MDSTR_GRPC_STATUS)) { + int status = grpc_get_status_code_from_metadata(md); + gpr_log(GPR_INFO, "send trailing metadata status = %d", status); + tx.SetStatus(status); + } else { + absl::string_view key = grpc_core::StringViewFromSlice(GRPC_MDKEY(md)); + absl::string_view value = + grpc_core::StringViewFromSlice(GRPC_MDVALUE(md)); + gpr_log(GPR_INFO, "send trailing metadata key-value %s", + absl::StrCat(key, " ", value).c_str()); + trailing_metadata.emplace_back(std::string(key), std::string(value)); + } + }); + // TODO(mingcl): Will we ever have key-value pairs here? According to the + // wire format, the client suffix data is always empty. + tx.SetSuffix(trailing_metadata); + } + if (op->recv_initial_metadata) { + gpr_log(GPR_INFO, "recv_initial_metadata"); + gbs->recv_initial_metadata_ready = + op->payload->recv_initial_metadata.recv_initial_metadata_ready; + gbs->recv_initial_metadata = + op->payload->recv_initial_metadata.recv_initial_metadata; + gbs->trailing_metadata_available = + op->payload->recv_initial_metadata.trailing_metadata_available; + GRPC_BINDER_STREAM_REF(gbs, "recv_initial_metadata"); + gbt->transport_stream_receiver->RegisterRecvInitialMetadata( + tx_code, [tx_code, gbs, + gbt](absl::StatusOr initial_metadata) { + grpc_core::ExecCtx exec_ctx; + gbs->recv_initial_metadata_args.tx_code = tx_code; + gbs->recv_initial_metadata_args.initial_metadata = + std::move(initial_metadata); + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&gbs->recv_initial_metadata_closure, + recv_initial_metadata_locked, + &gbs->recv_initial_metadata_args, nullptr), + GRPC_ERROR_NONE); + }); + } + if (op->recv_message) { + gpr_log(GPR_INFO, "recv_message"); + gbs->recv_message_ready = op->payload->recv_message.recv_message_ready; + gbs->recv_message = op->payload->recv_message.recv_message; + gbs->call_failed_before_recv_message = + op->payload->recv_message.call_failed_before_recv_message; + GRPC_BINDER_STREAM_REF(gbs, "recv_message"); + gbt->transport_stream_receiver->RegisterRecvMessage( + tx_code, [tx_code, gbs, gbt](absl::StatusOr message) { + grpc_core::ExecCtx exec_ctx; + gbs->recv_message_args.tx_code = tx_code; + gbs->recv_message_args.message = std::move(message); + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&gbs->recv_message_closure, recv_message_locked, + &gbs->recv_message_args, nullptr), + GRPC_ERROR_NONE); + }); + } + if (op->recv_trailing_metadata) { + gpr_log(GPR_INFO, "recv_trailing_metadata"); + gbs->recv_trailing_metadata_finished = + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + gbs->recv_trailing_metadata = + op->payload->recv_trailing_metadata.recv_trailing_metadata; + GRPC_BINDER_STREAM_REF(gbs, "recv_trailing_metadata"); + gbt->transport_stream_receiver->RegisterRecvTrailingMetadata( + tx_code, [tx_code, gbs, gbt]( + absl::StatusOr trailing_metadata, + int status) { + grpc_core::ExecCtx exec_ctx; + gbs->recv_trailing_metadata_args.tx_code = tx_code; + gbs->recv_trailing_metadata_args.trailing_metadata = + std::move(trailing_metadata); + gbs->recv_trailing_metadata_args.status = status; + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&gbs->recv_trailing_metadata_closure, + recv_trailing_metadata_locked, + &gbs->recv_trailing_metadata_args, nullptr), + GRPC_ERROR_NONE); + }); + } + // Only send a transaction when there is a send op present.
+ absl::Status status = absl::OkStatus(); + if (op->send_initial_metadata || op->send_message || + op->send_trailing_metadata) { + // TODO(waynetu): RpcCall() is doing a lot of work (including waiting for + // acknowledgements from the other side). Consider delaying this operation + // with combiner. + status = gbt->wire_writer->RpcCall(tx); + if (!gbs->is_client && op->send_trailing_metadata) { + gbs->trailing_metadata_sent = true; + // According to the transport explainer - "Server extra: This op shouldn't + // actually be considered complete until the server has also sent trailing + // metadata to provide the other side with final status" + // + // Because we've finished sending trailing metadata here, we can safely + // complete the recv_trailing_metadata callback here. + if (gbs->need_to_call_trailing_metadata_callback) { + grpc_closure* cb = gbs->recv_trailing_metadata_finished; + gbs->recv_trailing_metadata_finished = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); + gbs->need_to_call_trailing_metadata_callback = false; + } + } + } + // Note that this should only be scheduled when all non-recv ops are + // completed + if (op->on_complete != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, + absl_status_to_grpc_error(status)); + gpr_log(GPR_INFO, "on_complete closure scheduled"); + } + GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op"); +} + +static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, + grpc_transport_stream_op_batch* op) { + GPR_TIMER_SCOPE("perform_stream_op", 0); + grpc_binder_transport* gbt = reinterpret_cast(gt); + grpc_binder_stream* gbs = reinterpret_cast(gs); + gpr_log(GPR_INFO, "%s = %p %p %p is_client = %d", __func__, gt, gs, op, + gbs->is_client); + GRPC_BINDER_STREAM_REF(gbs, "perform_stream_op"); + op->handler_private.extra_arg = gbs; + gbt->combiner->Run(GRPC_CLOSURE_INIT(&op->handler_private.closure, + perform_stream_op_locked, op, nullptr), + GRPC_ERROR_NONE); +} + +static void close_transport_locked(grpc_binder_transport* gbt) { + gbt->state_tracker.SetState(GRPC_CHANNEL_SHUTDOWN, absl::OkStatus(), + "transport closed due to disconnection/goaway"); + while (!gbt->registered_stream.empty()) { + cancel_stream_locked( + gbt, gbt->registered_stream.begin()->second, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("transport closed"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } +} + +static void perform_transport_op_locked(void* transport_op, + grpc_error_handle /*error*/) { + grpc_transport_op* op = static_cast(transport_op); + grpc_binder_transport* gbt = + static_cast(op->handler_private.extra_arg); + // TODO(waynetu): Should we lock here to avoid a data race?
+ if (op->start_connectivity_watch != nullptr) { + gbt->state_tracker.AddWatcher(op->start_connectivity_watch_state, + std::move(op->start_connectivity_watch)); + } + if (op->stop_connectivity_watch != nullptr) { + gbt->state_tracker.RemoveWatcher(op->stop_connectivity_watch); + } + if (op->set_accept_stream) { + gbt->accept_stream_fn = op->set_accept_stream_fn; + gbt->accept_stream_user_data = op->set_accept_stream_user_data; + } + if (op->on_consumed) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_consumed, GRPC_ERROR_NONE); + } + bool do_close = false; + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + do_close = true; + GRPC_ERROR_UNREF(op->disconnect_with_error); + } + if (op->goaway_error != GRPC_ERROR_NONE) { + do_close = true; + GRPC_ERROR_UNREF(op->goaway_error); + } + if (do_close) { + close_transport_locked(gbt); + } + GRPC_BINDER_UNREF_TRANSPORT(gbt, "perform_transport_op"); +} + +static void perform_transport_op(grpc_transport* gt, grpc_transport_op* op) { + gpr_log(GPR_INFO, __func__); + grpc_binder_transport* gbt = reinterpret_cast(gt); + op->handler_private.extra_arg = gbt; + GRPC_BINDER_REF_TRANSPORT(gbt, "perform_transport_op"); + gbt->combiner->Run( + GRPC_CLOSURE_INIT(&op->handler_private.closure, + perform_transport_op_locked, op, nullptr), + GRPC_ERROR_NONE); +} + +static void destroy_stream_locked(void* sp, grpc_error_handle /*error*/) { + grpc_binder_stream* gbs = static_cast(sp); + grpc_binder_transport* gbt = gbs->t; + cancel_stream_locked( + gbt, gbs, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("destroy stream"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + gbs->~grpc_binder_stream(); +} + +static void destroy_stream(grpc_transport* /*gt*/, grpc_stream* gs, + grpc_closure* then_schedule_closure) { + gpr_log(GPR_INFO, __func__); + grpc_binder_stream* gbs = reinterpret_cast(gs); + gbs->destroy_stream_then_closure = then_schedule_closure; + gbs->t->combiner->Run(GRPC_CLOSURE_INIT(&gbs->destroy_stream, + destroy_stream_locked, gbs, nullptr), + GRPC_ERROR_NONE); +} + +static void destroy_transport_locked(void* gt, grpc_error_handle /*error*/) { + grpc_binder_transport* gbt = static_cast(gt); + close_transport_locked(gbt); + // Release the references held by the transport. + gbt->wire_reader = nullptr; + gbt->transport_stream_receiver = nullptr; + gbt->wire_writer = nullptr; + GRPC_BINDER_UNREF_TRANSPORT(gbt, "transport destroyed"); +} + +static void destroy_transport(grpc_transport* gt) { + gpr_log(GPR_INFO, __func__); + grpc_binder_transport* gbt = reinterpret_cast(gt); + gbt->combiner->Run( + GRPC_CLOSURE_CREATE(destroy_transport_locked, gbt, nullptr), + GRPC_ERROR_NONE); +} + +static grpc_endpoint* get_endpoint(grpc_transport*) { + gpr_log(GPR_INFO, __func__); + return nullptr; +} + +// See grpc_transport_vtable declaration for meaning of each field +static const grpc_transport_vtable vtable = {sizeof(grpc_binder_stream), + "binder", + init_stream, + set_pollset, + set_pollset_set, + perform_stream_op, + perform_transport_op, + destroy_stream, + destroy_transport, + get_endpoint}; + +static const grpc_transport_vtable* get_vtable() { return &vtable; } + +static void accept_stream_locked(void* gt, grpc_error_handle /*error*/) { + grpc_binder_transport* gbt = static_cast(gt); + if (gbt->accept_stream_fn) { + // must pass in a non-null value. 
+ (*gbt->accept_stream_fn)(gbt->accept_stream_user_data, &gbt->base, gbt); + } +} + +grpc_binder_transport::grpc_binder_transport( + std::unique_ptr binder, bool is_client, + std::shared_ptr security_policy) + : is_client(is_client), + combiner(grpc_combiner_create()), + state_tracker(is_client ? "binder_transport_client" + : "binder_transport_server"), + refs(1, nullptr) { + gpr_log(GPR_INFO, __func__); + base.vtable = get_vtable(); + GRPC_CLOSURE_INIT(&accept_stream_closure, accept_stream_locked, this, + nullptr); + transport_stream_receiver = + std::make_shared( + is_client, /*accept_stream_callback=*/[this] { + grpc_core::ExecCtx exec_ctx; + combiner->Run(&accept_stream_closure, GRPC_ERROR_NONE); + }); + // WireReader holds a ref to grpc_binder_transport. + GRPC_BINDER_REF_TRANSPORT(this, "wire reader"); + wire_reader = grpc_core::MakeOrphanable( + transport_stream_receiver, is_client, security_policy, + /*on_destruct_callback=*/ + [this] { + // Unref transport when destructed. + GRPC_BINDER_UNREF_TRANSPORT(this, "wire reader"); + }); + wire_writer = wire_reader->SetupTransport(std::move(binder)); +} + +grpc_binder_transport::~grpc_binder_transport() { + GRPC_COMBINER_UNREF(combiner, "binder_transport"); +} + +grpc_transport* grpc_create_binder_transport_client( + std::unique_ptr endpoint_binder, + std::shared_ptr + security_policy) { + gpr_log(GPR_INFO, __func__); + + GPR_ASSERT(endpoint_binder != nullptr); + GPR_ASSERT(security_policy != nullptr); + + grpc_binder_transport* t = new grpc_binder_transport( + std::move(endpoint_binder), /*is_client=*/true, security_policy); + + return &t->base; +} + +grpc_transport* grpc_create_binder_transport_server( + std::unique_ptr client_binder, + std::shared_ptr + security_policy) { + gpr_log(GPR_INFO, __func__); + + GPR_ASSERT(client_binder != nullptr); + GPR_ASSERT(security_policy != nullptr); + + grpc_binder_transport* t = new grpc_binder_transport( + std::move(client_binder), /*is_client=*/false, security_policy); + + return &t->base; +} diff --git a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc new file mode 100644 index 00000000..74c9ef1e --- /dev/null +++ b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc @@ -0,0 +1,252 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
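The file below implements the TransportStreamReceiver used by the transport above: a callback registered before the corresponding data arrives is stored, while data that arrives before any callback is registered is queued per stream and handed to the next Register* call. A minimal sketch of that ordering contract follows (editor's illustration only, not part of this patch; the constructor arguments, the stream id, and the metadata values are assumptions based on the call sites in binder_transport.cc):

#include "absl/status/statusor.h"

#include "src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h"

void IllustrateReceiverOrdering() {
  // Assumed constructor shape, mirroring the call site in binder_transport.cc.
  grpc_binder::TransportStreamReceiverImpl receiver(
      /*is_client=*/true, /*accept_stream_callback=*/nullptr);
  // Initial metadata for stream 1 arrives before anyone registered for it;
  // the receiver queues it internally.
  receiver.NotifyRecvInitialMetadata(
      /*id=*/1, grpc_binder::Metadata{{"key", "value"}});
  // A later registration for the same stream is served immediately from the
  // pending queue instead of being stored.
  receiver.RegisterRecvInitialMetadata(
      /*id=*/1, [](absl::StatusOr<grpc_binder::Metadata> metadata) {
        // Invoked synchronously with {"key", "value"} in this scenario.
      });
}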
+ +#include + +#include "src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h" + +#include +#include +#include + +#include + +namespace grpc_binder { + +const absl::string_view + TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully = + "grpc-binder-transport: cancelled gracefully"; + +void TransportStreamReceiverImpl::RegisterRecvInitialMetadata( + StreamIdentifier id, InitialMetadataCallbackType cb) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + absl::StatusOr initial_metadata{}; + { + grpc_core::MutexLock l(&m_); + GPR_ASSERT(initial_metadata_cbs_.count(id) == 0); + auto iter = pending_initial_metadata_.find(id); + if (iter == pending_initial_metadata_.end()) { + if (trailing_metadata_recvd_.count(id)) { + cb(absl::CancelledError("")); + } else { + initial_metadata_cbs_[id] = std::move(cb); + } + cb = nullptr; + } else { + initial_metadata = std::move(iter->second.front()); + iter->second.pop(); + if (iter->second.empty()) { + pending_initial_metadata_.erase(iter); + } + } + } + if (cb != nullptr) { + cb(std::move(initial_metadata)); + } +} + +void TransportStreamReceiverImpl::RegisterRecvMessage( + StreamIdentifier id, MessageDataCallbackType cb) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + absl::StatusOr message{}; + { + grpc_core::MutexLock l(&m_); + GPR_ASSERT(message_cbs_.count(id) == 0); + auto iter = pending_message_.find(id); + if (iter == pending_message_.end()) { + // If we'd already received trailing-metadata and there's no pending + // messages, cancel the callback. + if (trailing_metadata_recvd_.count(id)) { + cb(absl::CancelledError( + TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully)); + } else { + message_cbs_[id] = std::move(cb); + } + cb = nullptr; + } else { + // We'll still keep all pending messages received before the trailing + // metadata since they're issued before the end of stream, as promised by + // WireReader which keeps transactions commit in-order. 
+ message = std::move(iter->second.front()); + iter->second.pop(); + if (iter->second.empty()) { + pending_message_.erase(iter); + } + } + } + if (cb != nullptr) { + cb(std::move(message)); + } +} + +void TransportStreamReceiverImpl::RegisterRecvTrailingMetadata( + StreamIdentifier id, TrailingMetadataCallbackType cb) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + std::pair, int> trailing_metadata{}; + { + grpc_core::MutexLock l(&m_); + GPR_ASSERT(trailing_metadata_cbs_.count(id) == 0); + auto iter = pending_trailing_metadata_.find(id); + if (iter == pending_trailing_metadata_.end()) { + trailing_metadata_cbs_[id] = std::move(cb); + cb = nullptr; + } else { + trailing_metadata = std::move(iter->second.front()); + iter->second.pop(); + if (iter->second.empty()) { + pending_trailing_metadata_.erase(iter); + } + } + } + if (cb != nullptr) { + cb(std::move(trailing_metadata.first), trailing_metadata.second); + } +} + +void TransportStreamReceiverImpl::NotifyRecvInitialMetadata( + StreamIdentifier id, absl::StatusOr initial_metadata) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + if (!is_client_ && accept_stream_callback_ && initial_metadata.ok()) { + accept_stream_callback_(); + } + InitialMetadataCallbackType cb; + { + grpc_core::MutexLock l(&m_); + auto iter = initial_metadata_cbs_.find(id); + if (iter != initial_metadata_cbs_.end()) { + cb = iter->second; + initial_metadata_cbs_.erase(iter); + } else { + pending_initial_metadata_[id].push(std::move(initial_metadata)); + return; + } + } + cb(std::move(initial_metadata)); +} + +void TransportStreamReceiverImpl::NotifyRecvMessage( + StreamIdentifier id, absl::StatusOr message) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + MessageDataCallbackType cb; + { + grpc_core::MutexLock l(&m_); + auto iter = message_cbs_.find(id); + if (iter != message_cbs_.end()) { + cb = iter->second; + message_cbs_.erase(iter); + } else { + pending_message_[id].push(std::move(message)); + return; + } + } + cb(std::move(message)); +} + +void TransportStreamReceiverImpl::NotifyRecvTrailingMetadata( + StreamIdentifier id, absl::StatusOr trailing_metadata, + int status) { + // Trailing metadata mark the end of the stream. Since TransportStreamReceiver + // assumes in-order commitments of transactions and that trailing metadata is + // parsed after message data, we can safely cancel all upcoming callbacks of + // recv_message. 
+ gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + OnRecvTrailingMetadata(id); + TrailingMetadataCallbackType cb; + { + grpc_core::MutexLock l(&m_); + auto iter = trailing_metadata_cbs_.find(id); + if (iter != trailing_metadata_cbs_.end()) { + cb = iter->second; + trailing_metadata_cbs_.erase(iter); + } else { + pending_trailing_metadata_[id].emplace(std::move(trailing_metadata), + status); + return; + } + } + cb(std::move(trailing_metadata), status); +} + +void TransportStreamReceiverImpl::CancelInitialMetadataCallback( + StreamIdentifier id, absl::Status error) { + InitialMetadataCallbackType callback = nullptr; + { + grpc_core::MutexLock l(&m_); + auto iter = initial_metadata_cbs_.find(id); + if (iter != initial_metadata_cbs_.end()) { + callback = std::move(iter->second); + initial_metadata_cbs_.erase(iter); + } + } + if (callback != nullptr) { + std::move(callback)(error); + } +} + +void TransportStreamReceiverImpl::CancelMessageCallback(StreamIdentifier id, + absl::Status error) { + MessageDataCallbackType callback = nullptr; + { + grpc_core::MutexLock l(&m_); + auto iter = message_cbs_.find(id); + if (iter != message_cbs_.end()) { + callback = std::move(iter->second); + message_cbs_.erase(iter); + } + } + if (callback != nullptr) { + std::move(callback)(error); + } +} + +void TransportStreamReceiverImpl::CancelTrailingMetadataCallback( + StreamIdentifier id, absl::Status error) { + TrailingMetadataCallbackType callback = nullptr; + { + grpc_core::MutexLock l(&m_); + auto iter = trailing_metadata_cbs_.find(id); + if (iter != trailing_metadata_cbs_.end()) { + callback = std::move(iter->second); + trailing_metadata_cbs_.erase(iter); + } + } + if (callback != nullptr) { + std::move(callback)(error, 0); + } +} + +void TransportStreamReceiverImpl::OnRecvTrailingMetadata(StreamIdentifier id) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + m_.Lock(); + trailing_metadata_recvd_.insert(id); + m_.Unlock(); + CancelInitialMetadataCallback(id, absl::CancelledError("")); + CancelMessageCallback( + id, + absl::CancelledError( + TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully)); +} + +void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + CancelInitialMetadataCallback(id, absl::CancelledError("Stream cancelled")); + CancelMessageCallback(id, absl::CancelledError("Stream cancelled")); + CancelTrailingMetadataCallback(id, absl::CancelledError("Stream cancelled")); + grpc_core::MutexLock l(&m_); + trailing_metadata_recvd_.erase(id); + pending_initial_metadata_.erase(id); + pending_message_.erase(id); + pending_trailing_metadata_.erase(id); +} +} // namespace grpc_binder diff --git a/src/core/ext/transport/binder/wire_format/binder_android.cc b/src/core/ext/transport/binder/wire_format/binder_android.cc new file mode 100644 index 00000000..b78fe8be --- /dev/null +++ b/src/core/ext/transport/binder/wire_format/binder_android.cc @@ -0,0 +1,331 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#ifdef GPR_SUPPORT_BINDER_TRANSPORT + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" + +#include + +#include "src/core/ext/transport/binder/wire_format/binder_android.h" +#include "src/core/lib/gprpp/sync.h" + +extern "C" { +// TODO(mingcl): This function is introduced at API level 32 and is not +// available in any NDK release yet, so we declare it as a weak symbol so that +// we can use it without triggering an undefined-reference error. Its purpose +// is to disable the interface token header in the Parcel so that it conforms +// to the BinderChannel wire format. +extern void AIBinder_Class_disableInterfaceTokenHeader(AIBinder_Class* clazz) + __attribute__((weak)); +// This was released in API level 31. +extern int32_t AParcel_getDataSize(const AParcel* parcel) __attribute__((weak)); +} + +namespace grpc_binder { +namespace { + +struct BinderUserData { + explicit BinderUserData(grpc_core::RefCountedPtr wire_reader_ref, + TransactionReceiver::OnTransactCb* callback) + : wire_reader_ref(wire_reader_ref), callback(callback) {} + grpc_core::RefCountedPtr wire_reader_ref; + TransactionReceiver::OnTransactCb* callback; +}; + +struct OnCreateArgs { + grpc_core::RefCountedPtr wire_reader_ref; + TransactionReceiver::OnTransactCb* callback; +}; + +void* f_onCreate_userdata(void* data) { + auto* args = static_cast(data); + return new BinderUserData(args->wire_reader_ref, args->callback); +} + +void f_onDestroy_delete(void* data) { + auto* user_data = static_cast(data); + delete user_data; +} + +void* f_onCreate_noop(void* /*args*/) { return nullptr; } +void f_onDestroy_noop(void* /*userData*/) {} + +// TODO(mingcl): Consider if thread safety is a requirement here +binder_status_t f_onTransact(AIBinder* binder, transaction_code_t code, + const AParcel* in, AParcel* /*out*/) { + gpr_log(GPR_INFO, __func__); + gpr_log(GPR_INFO, "tx code = %u", code); + + auto* user_data = static_cast(AIBinder_getUserData(binder)); + TransactionReceiver::OnTransactCb* callback = user_data->callback; + // Wrap the parcel in a ReadableParcel. + std::unique_ptr output = + absl::make_unique(in); + // The lock should be released "after" the callback finishes.
+ absl::Status status = + (*callback)(code, output.get(), AIBinder_getCallingUid()); + if (status.ok()) { + return STATUS_OK; + } else { + gpr_log(GPR_ERROR, "Callback failed: %s", status.ToString().c_str()); + return STATUS_UNKNOWN_ERROR; + } +} + +// StdStringAllocator, ReadString, StdVectorAllocator, and ReadVector's +// implementations are copied from android/binder_parcel_utils.h +// We cannot include the header because it does not compile in C++11 + +bool StdStringAllocator(void* stringData, int32_t length, char** buffer) { + if (length <= 0) return false; + + std::string* str = static_cast(stringData); + str->resize(static_cast(length) - 1); + *buffer = &(*str)[0]; + return true; +} + +binder_status_t AParcelReadString(const AParcel* parcel, std::string* str) { + void* stringData = static_cast(str); + return AParcel_readString(parcel, stringData, StdStringAllocator); +} + +template +bool StdVectorAllocator(void* vectorData, int32_t length, T** outBuffer) { + if (length < 0) return false; + + std::vector* vec = static_cast*>(vectorData); + if (static_cast(length) > vec->max_size()) return false; + + vec->resize(static_cast(length)); + *outBuffer = vec->data(); + return true; +} + +binder_status_t AParcelReadVector(const AParcel* parcel, + std::vector* vec) { + void* vectorData = static_cast(vec); + return AParcel_readByteArray(parcel, vectorData, StdVectorAllocator); +} + +} // namespace + +ndk::SpAIBinder FromJavaBinder(JNIEnv* jni_env, jobject binder) { + return ndk::SpAIBinder(AIBinder_fromJavaBinder(jni_env, binder)); +} + +TransactionReceiverAndroid::TransactionReceiverAndroid( + grpc_core::RefCountedPtr wire_reader_ref, + OnTransactCb transact_cb) + : transact_cb_(transact_cb) { + // TODO(mingcl): For now interface descriptor is always empty, figure out if + // we want it to be something more meaningful (we can probably manually change + // interface descriptor by modifying Java code's reply to + // os.IBinder.INTERFACE_TRANSACTION) + AIBinder_Class* aibinder_class = AIBinder_Class_define( + /*interfaceDescriptor=*/"", f_onCreate_userdata, f_onDestroy_delete, + f_onTransact); + + if (AIBinder_Class_disableInterfaceTokenHeader) { + AIBinder_Class_disableInterfaceTokenHeader(aibinder_class); + } else { + // TODO(mingcl): Make this a fatal error + gpr_log(GPR_ERROR, + "AIBinder_Class_disableInterfaceTokenHeader remain unresolved. " + "This BinderTransport implementation contains header and is not " + "compatible with Java's implementation"); + } + + // Pass the on-transact callback to the on-create function of the binder. The + // on-create function equips the callback with a mutex and gives it to the + // user data stored in the binder which can be retrieved later. + // Also Ref() (called implicitly by the copy constructor of RefCountedPtr) the + // wire reader so that it would not be destructed during the callback + // invocation. + OnCreateArgs args; + args.wire_reader_ref = wire_reader_ref; + args.callback = &transact_cb_; + binder_ = AIBinder_new(aibinder_class, &args); + GPR_ASSERT(binder_); + gpr_log(GPR_INFO, "AIBinder_associateClass = %d", + static_cast(AIBinder_associateClass(binder_, aibinder_class))); +} + +TransactionReceiverAndroid::~TransactionReceiverAndroid() { + // Release the binder. 
+ AIBinder_decStrong(binder_); +} + +namespace { + +binder_status_t f_onTransact_noop(AIBinder* /*binder*/, + transaction_code_t /*code*/, + const AParcel* /*in*/, AParcel* /*out*/) { + return {}; +} + +void AssociateWithNoopClass(AIBinder* binder) { + // Need to associate class before using it + AIBinder_Class* aibinder_class = AIBinder_Class_define( + "", f_onCreate_noop, f_onDestroy_noop, f_onTransact_noop); + + if (AIBinder_Class_disableInterfaceTokenHeader) { + AIBinder_Class_disableInterfaceTokenHeader(aibinder_class); + } else { + // TODO(mingcl): Make this a fatal error + gpr_log(GPR_ERROR, + "AIBinder_Class_disableInterfaceTokenHeader remain unresolved. " + "This BinderTransport implementation contains header and is not " + "compatible with Java's implementation"); + } + + gpr_log(GPR_INFO, "AIBinder_associateClass = %d", + static_cast(AIBinder_associateClass(binder, aibinder_class))); +} + +} // namespace + +void BinderAndroid::Initialize() { + AIBinder* binder = binder_.get(); + AssociateWithNoopClass(binder); +} + +absl::Status BinderAndroid::PrepareTransaction() { + AIBinder* binder = binder_.get(); + return AIBinder_prepareTransaction(binder, &input_parcel_->parcel_) == + STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AIBinder_prepareTransaction failed"); +} + +absl::Status BinderAndroid::Transact(BinderTransportTxCode tx_code) { + AIBinder* binder = binder_.get(); + // We only do one-way transaction and thus the output parcel is never used. + AParcel* unused_output_parcel; + absl::Status result = + (AIBinder_transact(binder, static_cast(tx_code), + &input_parcel_->parcel_, &unused_output_parcel, + FLAG_ONEWAY) == STATUS_OK) + ? absl::OkStatus() + : absl::InternalError("AIBinder_transact failed"); + AParcel_delete(unused_output_parcel); + return result; +} + +std::unique_ptr BinderAndroid::ConstructTxReceiver( + grpc_core::RefCountedPtr wire_reader_ref, + TransactionReceiver::OnTransactCb transact_cb) const { + return absl::make_unique(wire_reader_ref, + transact_cb); +} + +int32_t WritableParcelAndroid::GetDataSize() const { + if (AParcel_getDataSize) { + return AParcel_getDataSize(parcel_); + } else { + gpr_log(GPR_INFO, "[Warning] AParcel_getDataSize is not available"); + return 0; + } +} + +absl::Status WritableParcelAndroid::WriteInt32(int32_t data) { + return AParcel_writeInt32(parcel_, data) == STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AParcel_writeInt32 failed"); +} + +absl::Status WritableParcelAndroid::WriteInt64(int64_t data) { + return AParcel_writeInt64(parcel_, data) == STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AParcel_writeInt64 failed"); +} + +absl::Status WritableParcelAndroid::WriteBinder(HasRawBinder* binder) { + return AParcel_writeStrongBinder( + parcel_, reinterpret_cast(binder->GetRawBinder())) == + STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AParcel_writeStrongBinder failed"); +} + +absl::Status WritableParcelAndroid::WriteString(absl::string_view s) { + return AParcel_writeString(parcel_, s.data(), s.length()) == STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AParcel_writeString failed"); +} + +absl::Status WritableParcelAndroid::WriteByteArray(const int8_t* buffer, + int32_t length) { + return AParcel_writeByteArray(parcel_, buffer, length) == STATUS_OK + ? 
absl::OkStatus() + : absl::InternalError("AParcel_writeByteArray failed"); +} + +int32_t ReadableParcelAndroid::GetDataSize() const { + if (AParcel_getDataSize) { + return AParcel_getDataSize(parcel_); + } else { + gpr_log(GPR_INFO, "[Warning] AParcel_getDataSize is not available"); + return 0; + } +} + +absl::Status ReadableParcelAndroid::ReadInt32(int32_t* data) { + return AParcel_readInt32(parcel_, data) == STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AParcel_readInt32 failed"); +} + +absl::Status ReadableParcelAndroid::ReadInt64(int64_t* data) { + return AParcel_readInt64(parcel_, data) == STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AParcel_readInt64 failed"); +} + +absl::Status ReadableParcelAndroid::ReadBinder(std::unique_ptr* data) { + AIBinder* binder; + if (AParcel_readStrongBinder(parcel_, &binder) != STATUS_OK) { + *data = nullptr; + return absl::InternalError("AParcel_readStrongBinder failed"); + } + *data = absl::make_unique(ndk::SpAIBinder(binder)); + return absl::OkStatus(); +} + +absl::Status ReadableParcelAndroid::ReadByteArray(std::string* data) { + std::vector vec; + if (AParcelReadVector(parcel_, &vec) == STATUS_OK) { + data->resize(vec.size()); + if (!vec.empty()) { + memcpy(&((*data)[0]), vec.data(), vec.size()); + } + return absl::OkStatus(); + } + return absl::InternalError("AParcel_readByteArray failed"); +} + +absl::Status ReadableParcelAndroid::ReadString(std::string* str) { + return AParcelReadString(parcel_, str) == STATUS_OK + ? absl::OkStatus() + : absl::InternalError("AParcel_readString failed"); +} + +} // namespace grpc_binder + +#endif // GPR_SUPPORT_BINDER_TRANSPORT diff --git a/src/core/ext/transport/binder/wire_format/binder_constants.cc b/src/core/ext/transport/binder/wire_format/binder_constants.cc new file mode 100644 index 00000000..605c4911 --- /dev/null +++ b/src/core/ext/transport/binder/wire_format/binder_constants.cc @@ -0,0 +1,30 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/binder/wire_format/binder_constants.h" + +#ifndef GPR_SUPPORT_BINDER_TRANSPORT + +const int FIRST_CALL_TRANSACTION = 0x00000001; +const int LAST_CALL_TRANSACTION = 0x00FFFFFF; + +#endif // GPR_SUPPORT_BINDER_TRANSPORT + +namespace grpc_binder { + +const int kFirstCallId = FIRST_CALL_TRANSACTION + 1000; + +} // namespace grpc_binder diff --git a/src/core/ext/transport/binder/wire_format/transaction.cc b/src/core/ext/transport/binder/wire_format/transaction.cc new file mode 100644 index 00000000..753aed8d --- /dev/null +++ b/src/core/ext/transport/binder/wire_format/transaction.cc @@ -0,0 +1,30 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/binder/wire_format/transaction.h" + +namespace grpc_binder { + +const int kFlagPrefix = 0x1; +const int kFlagMessageData = 0x2; +const int kFlagSuffix = 0x4; +const int kFlagOutOfBandClose = 0x8; +const int kFlagExpectSingleMessage = 0x10; +const int kFlagStatusDescription = 0x20; +const int kFlagMessageDataIsParcelable = 0x40; +const int kFlagMessageDataIsPartial = 0x80; + +} // namespace grpc_binder diff --git a/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc b/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc new file mode 100644 index 00000000..dd2c74a3 --- /dev/null +++ b/src/core/ext/transport/binder/wire_format/wire_reader_impl.cc @@ -0,0 +1,381 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/binder/wire_format/wire_reader_impl.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" + +#include + +#include "src/core/ext/transport/binder/utils/transport_stream_receiver.h" +#include "src/core/ext/transport/binder/wire_format/binder.h" +#include "src/core/ext/transport/binder/wire_format/wire_writer.h" + +#define RETURN_IF_ERROR(expr) \ + do { \ + const absl::Status status = (expr); \ + if (!status.ok()) return status; \ + } while (0) + +namespace grpc_binder { +namespace { + +absl::StatusOr parse_metadata(ReadableParcel* reader) { + int num_header; + RETURN_IF_ERROR(reader->ReadInt32(&num_header)); + gpr_log(GPR_INFO, "num_header = %d", num_header); + if (num_header < 0) { + return absl::InvalidArgumentError("num_header cannot be negative"); + } + std::vector> ret; + for (int i = 0; i < num_header; i++) { + int count; + RETURN_IF_ERROR(reader->ReadInt32(&count)); + gpr_log(GPR_INFO, "count = %d", count); + std::string key{}; + if (count > 0) RETURN_IF_ERROR(reader->ReadByteArray(&key)); + gpr_log(GPR_INFO, "key = %s", key.c_str()); + RETURN_IF_ERROR(reader->ReadInt32(&count)); + gpr_log(GPR_INFO, "count = %d", count); + std::string value{}; + if (count > 0) RETURN_IF_ERROR(reader->ReadByteArray(&value)); + gpr_log(GPR_INFO, "value = %s", value.c_str()); + ret.emplace_back(key, value); + } + return ret; +} + +} // namespace + +WireReaderImpl::WireReaderImpl( + std::shared_ptr transport_stream_receiver, + bool is_client, + std::shared_ptr security_policy, + std::function on_destruct_callback) + : transport_stream_receiver_(std::move(transport_stream_receiver)), + is_client_(is_client), + security_policy_(security_policy), + 
on_destruct_callback_(on_destruct_callback) {} + +WireReaderImpl::~WireReaderImpl() { + if (on_destruct_callback_) { + on_destruct_callback_(); + } +} + +std::shared_ptr WireReaderImpl::SetupTransport( + std::unique_ptr binder) { + gpr_log(GPR_INFO, "Setting up transport"); + if (!is_client_) { + SendSetupTransport(binder.get()); + { + grpc_core::MutexLock lock(&mu_); + connected_ = true; + wire_writer_ = std::make_shared(std::move(binder)); + } + return wire_writer_; + } else { + SendSetupTransport(binder.get()); + auto other_end_binder = RecvSetupTransport(); + { + grpc_core::MutexLock lock(&mu_); + connected_ = true; + wire_writer_ = + std::make_shared(std::move(other_end_binder)); + } + return wire_writer_; + } +} + +void WireReaderImpl::SendSetupTransport(Binder* binder) { + binder->Initialize(); + gpr_log(GPR_INFO, "prepare transaction = %d", + binder->PrepareTransaction().ok()); + WritableParcel* writable_parcel = binder->GetWritableParcel(); + int32_t version = 77; + gpr_log(GPR_INFO, "write int32 = %d", + writable_parcel->WriteInt32(version).ok()); + // The lifetime of the transaction receiver is the same as the wire writer's. + // The transaction receiver is responsible for not calling the on-transact + // callback when it's dead. + // Give TransactionReceiver a Ref() since WireReader cannot be destructed + // during callback execution. TransactionReceiver should make sure that the + // callback owns a Ref() when it's being invoked. + tx_receiver_ = binder->ConstructTxReceiver( + /*wire_reader_ref=*/Ref(), + [this](transaction_code_t code, ReadableParcel* readable_parcel, + int uid) { + return this->ProcessTransaction(code, readable_parcel, uid); + }); + + gpr_log(GPR_INFO, "tx_receiver = %p", tx_receiver_->GetRawBinder()); + gpr_log(GPR_INFO, "AParcel_writeStrongBinder = %d", + writable_parcel->WriteBinder(tx_receiver_.get()).ok()); + gpr_log(GPR_INFO, "AIBinder_transact = %d", + binder->Transact(BinderTransportTxCode::SETUP_TRANSPORT).ok()); +} + +std::unique_ptr WireReaderImpl::RecvSetupTransport() { + // TODO(b/191941760): avoid blocking, handle wire_writer_noti lifetime + // better + gpr_log(GPR_INFO, "start waiting for noti"); + connection_noti_.WaitForNotification(); + gpr_log(GPR_INFO, "end waiting for noti"); + return std::move(other_end_binder_); +} + +absl::Status WireReaderImpl::ProcessTransaction(transaction_code_t code, + ReadableParcel* parcel, + int uid) { + gpr_log(GPR_INFO, __func__); + gpr_log(GPR_INFO, "tx code = %u", code); + if (code >= static_cast(kFirstCallId)) { + gpr_log(GPR_INFO, "This is probably a Streaming Tx"); + return ProcessStreamingTransaction(code, parcel); + } + + if (!(code >= static_cast( + BinderTransportTxCode::SETUP_TRANSPORT) && + code <= static_cast( + BinderTransportTxCode::PING_RESPONSE))) { + gpr_log(GPR_INFO, + "Received unknown control message. Shutdown transport gracefully."); + // TODO(waynetu): Shutdown transport gracefully. + return absl::OkStatus(); + } + + grpc_core::MutexLock lock(&mu_); + + if (BinderTransportTxCode(code) != BinderTransportTxCode::SETUP_TRANSPORT && + !connected_) { + return absl::InvalidArgumentError("Transports not connected yet"); + } + + // TODO(mingcl): See if we want to check the security policy for every RPC + // call or just during transport setup. 
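// Illustrative aside (not part of this patch): a minimal, self-contained
// sketch of the transaction-code dispatch performed above. Codes at or above
// kFirstCallId (FIRST_CALL_TRANSACTION + 1000) carry streaming RPC data,
// codes inside the control range are transport control messages, and anything
// else is ignored. The numeric control-range bounds used here are hypothetical
// placeholders, not the real BinderTransportTxCode values.
#include <cstdint>
#include <cstdio>

namespace sketch {
constexpr int32_t kFirstCallTransaction = 0x00000001;  // mirrors FIRST_CALL_TRANSACTION
constexpr int32_t kFirstCallId = kFirstCallTransaction + 1000;
// Hypothetical stand-ins for the first/last control codes.
constexpr int32_t kFirstControlCode = 1;
constexpr int32_t kLastControlCode = 5;

enum class TxKind { kControl, kStreaming, kUnknown };

TxKind Classify(int32_t code) {
  if (code >= kFirstCallId) return TxKind::kStreaming;
  if (code >= kFirstControlCode && code <= kLastControlCode) {
    return TxKind::kControl;
  }
  return TxKind::kUnknown;  // unknown control messages are dropped gracefully
}
}  // namespace sketch

int main() {
  std::printf("code 3    -> %d\n", static_cast<int>(sketch::Classify(3)));
  std::printf("code 1001 -> %d\n", static_cast<int>(sketch::Classify(1001)));
  std::printf("code 900  -> %d\n", static_cast<int>(sketch::Classify(900)));
}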
+ + switch (BinderTransportTxCode(code)) { + case BinderTransportTxCode::SETUP_TRANSPORT: { + if (recvd_setup_transport_) { + return absl::InvalidArgumentError( + "Already received a SETUP_TRANSPORT request"); + } + recvd_setup_transport_ = true; + + gpr_log(GPR_ERROR, "calling uid = %d", uid); + if (!security_policy_->IsAuthorized(uid)) { + return absl::PermissionDeniedError( + "UID " + std::to_string(uid) + + " is not allowed to connect to this " + "transport according to security policy."); + } + + int version; + RETURN_IF_ERROR(parcel->ReadInt32(&version)); + gpr_log(GPR_INFO, "version = %d", version); + std::unique_ptr binder{}; + RETURN_IF_ERROR(parcel->ReadBinder(&binder)); + if (!binder) { + return absl::InternalError("Read NULL binder from the parcel"); + } + binder->Initialize(); + other_end_binder_ = std::move(binder); + connection_noti_.Notify(); + break; + } + case BinderTransportTxCode::SHUTDOWN_TRANSPORT: { + gpr_log(GPR_ERROR, + "Received SHUTDOWN_TRANSPORT request but not implemented yet."); + return absl::UnimplementedError("SHUTDOWN_TRANSPORT"); + } + case BinderTransportTxCode::ACKNOWLEDGE_BYTES: { + int64_t num_bytes = -1; + RETURN_IF_ERROR(parcel->ReadInt64(&num_bytes)); + gpr_log(GPR_INFO, "received acknowledge bytes = %lld", + static_cast(num_bytes)); + wire_writer_->OnAckReceived(num_bytes); + break; + } + case BinderTransportTxCode::PING: { + if (is_client_) { + return absl::FailedPreconditionError("Receive PING request in client"); + } + int ping_id = -1; + RETURN_IF_ERROR(parcel->ReadInt32(&ping_id)); + gpr_log(GPR_INFO, "received ping id = %d", ping_id); + // TODO(waynetu): Ping back. + break; + } + case BinderTransportTxCode::PING_RESPONSE: { + int value = -1; + RETURN_IF_ERROR(parcel->ReadInt32(&value)); + gpr_log(GPR_INFO, "received ping response = %d", value); + break; + } + } + return absl::OkStatus(); +} + +absl::Status WireReaderImpl::ProcessStreamingTransaction( + transaction_code_t code, ReadableParcel* parcel) { + grpc_core::MutexLock lock(&mu_); + if (!connected_) { + return absl::InvalidArgumentError("Transports not connected yet"); + } + + // Indicate which callbacks should be cancelled. It will be initialized as the + // flags the in-coming transaction carries, and when a particular callback is + // completed, the corresponding bit in cancellation_flag will be set to 0 so + // that we won't cancel it afterward. + int cancellation_flags = 0; + absl::Status status = + ProcessStreamingTransactionImpl(code, parcel, &cancellation_flags); + if (!status.ok()) { + gpr_log(GPR_ERROR, "Failed to process streaming transaction: %s", + status.ToString().c_str()); + // Something went wrong when receiving transaction. Cancel failed requests. 
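// Illustrative aside (not part of this patch): a minimal sketch of the
// cancellation-flags bookkeeping used just below. The flag constants mirror
// transaction.cc above; the Consume()/Process() helpers are hypothetical.
// Each section's bit is cleared once it is delivered, so on error the bits
// still set identify exactly which receiver callbacks must be failed.
#include <cstdio>

namespace sketch {
constexpr int kFlagPrefix = 0x1;
constexpr int kFlagMessageData = 0x2;
constexpr int kFlagSuffix = 0x4;

// Pretend to consume one section of a transaction; `ok` simulates a parse error.
bool Consume(const char* what, bool ok) {
  std::printf("consuming %s: %s\n", what, ok ? "ok" : "failed");
  return ok;
}

// Returns the bitmask of sections that still need to be cancelled when an
// error interrupts processing.
int Process(int flags, bool fail_on_message) {
  int cancellation_flags = flags;
  if (flags & kFlagPrefix) {
    if (!Consume("initial metadata", /*ok=*/true)) return cancellation_flags;
    cancellation_flags &= ~kFlagPrefix;  // delivered; do not cancel later
  }
  if (flags & kFlagMessageData) {
    if (!Consume("message data", /*ok=*/!fail_on_message)) {
      return cancellation_flags;  // message (and suffix) bits are still set
    }
    cancellation_flags &= ~kFlagMessageData;
  }
  if (flags & kFlagSuffix) {
    if (!Consume("trailing metadata", /*ok=*/true)) return cancellation_flags;
    cancellation_flags &= ~kFlagSuffix;
  }
  return 0;
}
}  // namespace sketch

int main() {
  // A transaction carrying all three sections that fails while reading the
  // message still needs its message and trailing-metadata callbacks cancelled.
  int remaining = sketch::Process(
      sketch::kFlagPrefix | sketch::kFlagMessageData | sketch::kFlagSuffix,
      /*fail_on_message=*/true);
  std::printf("still to cancel: 0x%x\n", remaining);  // prints 0x6
}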
+ if (cancellation_flags & kFlagPrefix) { + gpr_log(GPR_INFO, "cancelling initial metadata"); + transport_stream_receiver_->NotifyRecvInitialMetadata(code, status); + } + if (cancellation_flags & kFlagMessageData) { + gpr_log(GPR_INFO, "cancelling message data"); + transport_stream_receiver_->NotifyRecvMessage(code, status); + } + if (cancellation_flags & kFlagSuffix) { + gpr_log(GPR_INFO, "cancelling trailing metadata"); + transport_stream_receiver_->NotifyRecvTrailingMetadata(code, status, 0); + } + } + if ((num_incoming_bytes_ - num_acknowledged_bytes_) >= kFlowControlAckBytes) { + GPR_ASSERT(wire_writer_); + absl::Status ack_status = wire_writer_->SendAck(num_incoming_bytes_); + if (status.ok()) { + status = ack_status; + } + num_acknowledged_bytes_ = num_incoming_bytes_; + } + return status; +} + +absl::Status WireReaderImpl::ProcessStreamingTransactionImpl( + transaction_code_t code, ReadableParcel* parcel, int* cancellation_flags) { + GPR_ASSERT(cancellation_flags); + num_incoming_bytes_ += parcel->GetDataSize(); + + int flags; + RETURN_IF_ERROR(parcel->ReadInt32(&flags)); + gpr_log(GPR_INFO, "flags = %d", flags); + *cancellation_flags = flags; + + // Ignore in-coming transaction with flag = 0 to match with Java + // implementation. + // TODO(waynetu): Check with grpc-java team to see whether this is the + // intended behavior. + // TODO(waynetu): What should be returned here? + if (flags == 0) { + gpr_log(GPR_INFO, "[WARNING] Receive empty transaction. Ignored."); + return absl::OkStatus(); + } + + int status = flags >> 16; + gpr_log(GPR_INFO, "status = %d", status); + gpr_log(GPR_INFO, "FLAG_PREFIX = %d", (flags & kFlagPrefix)); + gpr_log(GPR_INFO, "FLAG_MESSAGE_DATA = %d", (flags & kFlagMessageData)); + gpr_log(GPR_INFO, "FLAG_SUFFIX = %d", (flags & kFlagSuffix)); + int seq_num; + RETURN_IF_ERROR(parcel->ReadInt32(&seq_num)); + // TODO(waynetu): For now we'll just assume that the transactions commit in + // the same order they're issued. The following assertion detects + // out-of-order or missing transactions. WireReaderImpl should be fixed if + // we indeed found such behavior. + int32_t& expectation = expected_seq_num_[code]; + if (seq_num < 0 || seq_num != expectation) { + // Unexpected sequence number. + return absl::InternalError("Unexpected sequence number"); + } + // TODO(waynetu): According to the protocol, "The sequence number will wrap + // around to 0 if more than 2^31 messages are sent." For now we'll just + // assert that it never reach such circumstances. 
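// Illustrative aside (not part of this patch): a minimal sketch of the
// per-stream sequence-number check used here, assuming transactions for a
// given transaction code must arrive as 0, 1, 2, ... The map mirrors
// expected_seq_num_; everything else is a hypothetical stand-in.
#include <cstdint>
#include <cstdio>
#include <map>

namespace sketch {
class SequenceChecker {
 public:
  // Returns true if `seq_num` is the next expected number for `code`.
  bool Accept(uint32_t code, int32_t seq_num) {
    int32_t& expectation = expected_seq_num_[code];  // value-initialized to 0
    if (seq_num < 0 || seq_num != expectation) return false;
    ++expectation;
    return true;
  }

 private:
  std::map<uint32_t, int32_t> expected_seq_num_;
};
}  // namespace sketch

int main() {
  sketch::SequenceChecker checker;
  std::printf("%d\n", checker.Accept(1001, 0));  // 1: first transaction
  std::printf("%d\n", checker.Accept(1001, 1));  // 1: in order
  std::printf("%d\n", checker.Accept(1001, 3));  // 0: out of order or missing
  std::printf("%d\n", checker.Accept(1002, 0));  // 1: independent stream
}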
+ GPR_ASSERT(expectation < std::numeric_limits::max() && + "Sequence number too large"); + expectation++; + gpr_log(GPR_INFO, "sequence number = %d", seq_num); + if (flags & kFlagPrefix) { + std::string method_ref; + if (!is_client_) { + RETURN_IF_ERROR(parcel->ReadString(&method_ref)); + } + absl::StatusOr initial_metadata_or_error = parse_metadata(parcel); + if (!initial_metadata_or_error.ok()) { + return initial_metadata_or_error.status(); + } + if (!is_client_) { + initial_metadata_or_error->emplace_back(":path", + std::string("/") + method_ref); + } + transport_stream_receiver_->NotifyRecvInitialMetadata( + code, *initial_metadata_or_error); + *cancellation_flags &= ~kFlagPrefix; + } + if (flags & kFlagMessageData) { + int count; + RETURN_IF_ERROR(parcel->ReadInt32(&count)); + gpr_log(GPR_INFO, "count = %d", count); + std::string msg_data{}; + if (count > 0) { + RETURN_IF_ERROR(parcel->ReadByteArray(&msg_data)); + } + gpr_log(GPR_INFO, "msg_data = %s", msg_data.c_str()); + message_buffer_[code] += msg_data; + if ((flags & kFlagMessageDataIsPartial) == 0) { + std::string s = std::move(message_buffer_[code]); + message_buffer_.erase(code); + transport_stream_receiver_->NotifyRecvMessage(code, std::move(s)); + } + *cancellation_flags &= ~kFlagMessageData; + } + if (flags & kFlagSuffix) { + if (flags & kFlagStatusDescription) { + // FLAG_STATUS_DESCRIPTION set + std::string desc; + RETURN_IF_ERROR(parcel->ReadString(&desc)); + gpr_log(GPR_INFO, "description = %s", desc.c_str()); + } + Metadata trailing_metadata; + if (is_client_) { + absl::StatusOr trailing_metadata_or_error = + parse_metadata(parcel); + if (!trailing_metadata_or_error.ok()) { + return trailing_metadata_or_error.status(); + } + trailing_metadata = *trailing_metadata_or_error; + } + transport_stream_receiver_->NotifyRecvTrailingMetadata( + code, std::move(trailing_metadata), status); + *cancellation_flags &= ~kFlagSuffix; + } + return absl::OkStatus(); +} + +} // namespace grpc_binder diff --git a/src/core/ext/transport/binder/wire_format/wire_writer.cc b/src/core/ext/transport/binder/wire_format/wire_writer.cc new file mode 100644 index 00000000..ce47d82c --- /dev/null +++ b/src/core/ext/transport/binder/wire_format/wire_writer.cc @@ -0,0 +1,181 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/binder/wire_format/wire_writer.h" + +#include + +#include + +#define RETURN_IF_ERROR(expr) \ + do { \ + const absl::Status status = (expr); \ + if (!status.ok()) return status; \ + } while (0) + +namespace grpc_binder { +WireWriterImpl::WireWriterImpl(std::unique_ptr binder) + : binder_(std::move(binder)) {} + +absl::Status WireWriterImpl::WriteInitialMetadata(const Transaction& tx, + WritableParcel* parcel) { + if (tx.IsClient()) { + // Only client sends method ref. 
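// Illustrative aside (not part of this patch): a minimal sketch of the
// partial-message reassembly performed by the wire reader above. Chunks tagged
// kFlagMessageDataIsPartial are appended to a per-stream buffer; the first
// chunk without that flag completes the message. Apart from the flag constant,
// the names here are hypothetical.
#include <cstdio>
#include <map>
#include <string>

namespace sketch {
constexpr int kFlagMessageDataIsPartial = 0x80;

class Reassembler {
 public:
  // Returns true and fills `out` when a full message has been assembled.
  bool Add(uint32_t code, const std::string& chunk, int flags,
           std::string* out) {
    buffer_[code] += chunk;
    if ((flags & kFlagMessageDataIsPartial) != 0) return false;
    *out = std::move(buffer_[code]);
    buffer_.erase(code);
    return true;
  }

 private:
  std::map<uint32_t, std::string> buffer_;
};
}  // namespace sketch

int main() {
  sketch::Reassembler r;
  std::string msg;
  r.Add(1001, "hello ", sketch::kFlagMessageDataIsPartial, &msg);
  if (r.Add(1001, "world", /*flags=*/0, &msg)) {
    std::printf("%s\n", msg.c_str());  // prints "hello world"
  }
}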
+ RETURN_IF_ERROR(parcel->WriteString(tx.GetMethodRef())); + } + RETURN_IF_ERROR(parcel->WriteInt32(tx.GetPrefixMetadata().size())); + for (const auto& md : tx.GetPrefixMetadata()) { + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.first)); + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.second)); + } + return absl::OkStatus(); +} + +absl::Status WireWriterImpl::WriteTrailingMetadata(const Transaction& tx, + WritableParcel* parcel) { + if (tx.IsServer()) { + if (tx.GetFlags() & kFlagStatusDescription) { + RETURN_IF_ERROR(parcel->WriteString(tx.GetStatusDesc())); + } + RETURN_IF_ERROR(parcel->WriteInt32(tx.GetSuffixMetadata().size())); + for (const auto& md : tx.GetSuffixMetadata()) { + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.first)); + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(md.second)); + } + } else { + // client suffix currently is always empty according to the wireformat + if (!tx.GetSuffixMetadata().empty()) { + gpr_log(GPR_ERROR, "Got non-empty suffix metadata from client."); + } + } + return absl::OkStatus(); +} + +const int64_t WireWriterImpl::kBlockSize = 16 * 1024; +const int64_t WireWriterImpl::kFlowControlWindowSize = 128 * 1024; + +bool WireWriterImpl::CanBeSentInOneTransaction(const Transaction& tx) const { + return (tx.GetFlags() & kFlagMessageData) == 0 || + tx.GetMessageData().size() <= kBlockSize; +} + +absl::Status WireWriterImpl::RpcCallFastPath(const Transaction& tx) { + int& seq = seq_num_[tx.GetTxCode()]; + // Fast path: send data in one transaction. + RETURN_IF_ERROR(binder_->PrepareTransaction()); + WritableParcel* parcel = binder_->GetWritableParcel(); + RETURN_IF_ERROR(parcel->WriteInt32(tx.GetFlags())); + RETURN_IF_ERROR(parcel->WriteInt32(seq++)); + if (tx.GetFlags() & kFlagPrefix) { + RETURN_IF_ERROR(WriteInitialMetadata(tx, parcel)); + } + if (tx.GetFlags() & kFlagMessageData) { + RETURN_IF_ERROR(parcel->WriteByteArrayWithLength(tx.GetMessageData())); + } + if (tx.GetFlags() & kFlagSuffix) { + RETURN_IF_ERROR(WriteTrailingMetadata(tx, parcel)); + } + // FIXME(waynetu): Construct BinderTransportTxCode from an arbitrary integer + // is an undefined behavior. + return binder_->Transact(BinderTransportTxCode(tx.GetTxCode())); +} + +bool WireWriterImpl::WaitForAcknowledgement() { + if (num_outgoing_bytes_ < num_acknowledged_bytes_ + kFlowControlWindowSize) { + return true; + } + absl::Time deadline = absl::Now() + absl::Seconds(1); + do { + if (cv_.WaitWithDeadline(&mu_, deadline)) { + return false; + } + if (absl::Now() >= deadline) { + return false; + } + } while (num_outgoing_bytes_ >= + num_acknowledged_bytes_ + kFlowControlWindowSize); + return true; +} + +absl::Status WireWriterImpl::RpcCall(const Transaction& tx) { + // TODO(mingcl): check tx_code <= last call id + grpc_core::MutexLock lock(&mu_); + GPR_ASSERT(tx.GetTxCode() >= kFirstCallId); + if (CanBeSentInOneTransaction(tx)) { + return RpcCallFastPath(tx); + } + // Slow path: the message data is too large to fit in one transaction. 
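// Illustrative aside (not part of this patch): a minimal sketch of how the
// slow path below splits a large message into kBlockSize chunks and which
// flags each chunk carries (the actual sending, plus the flow-control wait on
// acknowledged bytes, is omitted). The constants mirror the patch; the
// ChunkPlan helper is hypothetical.
#include <algorithm>
#include <cstdio>
#include <vector>

namespace sketch {
constexpr int kFlagPrefix = 0x1;
constexpr int kFlagMessageData = 0x2;
constexpr int kFlagSuffix = 0x4;
constexpr int kFlagMessageDataIsPartial = 0x80;
constexpr size_t kBlockSize = 16 * 1024;

struct Chunk {
  size_t offset;
  size_t size;
  int flags;
};

std::vector<Chunk> PlanChunks(size_t message_size, int original_flags) {
  std::vector<Chunk> chunks;
  size_t sent = 0;
  while (sent < message_size) {
    size_t size = std::min(kBlockSize, message_size - sent);
    int flags = kFlagMessageData;
    // Initial metadata rides on the first chunk only.
    if (sent == 0 && (original_flags & kFlagPrefix)) flags |= kFlagPrefix;
    if (sent + kBlockSize >= message_size) {
      // Last chunk: attach trailing metadata if the call has any.
      if (original_flags & kFlagSuffix) flags |= kFlagSuffix;
    } else {
      flags |= kFlagMessageDataIsPartial;  // more chunks follow
    }
    chunks.push_back({sent, size, flags});
    sent += size;
  }
  return chunks;
}
}  // namespace sketch

int main() {
  // A 40 KiB message with both metadata sections becomes three transactions.
  auto chunks = sketch::PlanChunks(40 * 1024,
                                   sketch::kFlagPrefix | sketch::kFlagSuffix);
  for (const auto& c : chunks) {
    std::printf("offset=%zu size=%zu flags=0x%x\n", c.offset, c.size, c.flags);
  }
}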
+ int& seq = seq_num_[tx.GetTxCode()]; + int original_flags = tx.GetFlags(); + GPR_ASSERT(original_flags & kFlagMessageData); + absl::string_view data = tx.GetMessageData(); + size_t bytes_sent = 0; + while (bytes_sent < data.size()) { + if (!WaitForAcknowledgement()) { + return absl::InternalError("Timeout waiting for acknowledgement"); + } + RETURN_IF_ERROR(binder_->PrepareTransaction()); + WritableParcel* parcel = binder_->GetWritableParcel(); + size_t size = + std::min(static_cast(kBlockSize), data.size() - bytes_sent); + int flags = kFlagMessageData; + if (bytes_sent == 0) { + // This is the first transaction. Include initial metadata if there's any. + if (original_flags & kFlagPrefix) { + flags |= kFlagPrefix; + } + } + if (bytes_sent + kBlockSize >= data.size()) { + // This is the last transaction. Include trailing metadata if there's any. + if (original_flags & kFlagSuffix) { + flags |= kFlagSuffix; + } + } else { + // There are more messages to send. + flags |= kFlagMessageDataIsPartial; + } + RETURN_IF_ERROR(parcel->WriteInt32(flags)); + RETURN_IF_ERROR(parcel->WriteInt32(seq++)); + if (flags & kFlagPrefix) { + RETURN_IF_ERROR(WriteInitialMetadata(tx, parcel)); + } + RETURN_IF_ERROR( + parcel->WriteByteArrayWithLength(data.substr(bytes_sent, size))); + if (flags & kFlagSuffix) { + RETURN_IF_ERROR(WriteTrailingMetadata(tx, parcel)); + } + num_outgoing_bytes_ += parcel->GetDataSize(); + RETURN_IF_ERROR(binder_->Transact(BinderTransportTxCode(tx.GetTxCode()))); + bytes_sent += size; + } + return absl::OkStatus(); +} + +absl::Status WireWriterImpl::SendAck(int64_t num_bytes) { + grpc_core::MutexLock lock(&mu_); + RETURN_IF_ERROR(binder_->PrepareTransaction()); + WritableParcel* parcel = binder_->GetWritableParcel(); + RETURN_IF_ERROR(parcel->WriteInt64(num_bytes)); + return binder_->Transact(BinderTransportTxCode::ACKNOWLEDGE_BYTES); +} + +void WireWriterImpl::OnAckReceived(int64_t num_bytes) { + grpc_core::MutexLock lock(&mu_); + num_acknowledged_bytes_ = std::max(num_acknowledged_bytes_, num_bytes); + cv_.Signal(); +} + +} // namespace grpc_binder diff --git a/src/core/ext/transport/chttp2/alpn/alpn.cc b/src/core/ext/transport/chttp2/alpn/alpn.cc new file mode 100644 index 00000000..33f76276 --- /dev/null +++ b/src/core/ext/transport/chttp2/alpn/alpn.cc @@ -0,0 +1,45 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/alpn/alpn.h" + +#include + +#include "src/core/lib/gpr/useful.h" + +/* in order of preference */ +static const char* const supported_versions[] = {"grpc-exp", "h2"}; + +int grpc_chttp2_is_alpn_version_supported(const char* version, size_t size) { + size_t i; + for (i = 0; i < GPR_ARRAY_SIZE(supported_versions); i++) { + if (!strncmp(version, supported_versions[i], size)) return 1; + } + return 0; +} + +size_t grpc_chttp2_num_alpn_versions(void) { + return GPR_ARRAY_SIZE(supported_versions); +} + +const char* grpc_chttp2_get_alpn_version_index(size_t i) { + GPR_ASSERT(i < GPR_ARRAY_SIZE(supported_versions)); + return supported_versions[i]; +} diff --git a/src/core/ext/transport/chttp2/client/chttp2_connector.cc b/src/core/ext/transport/chttp2/client/chttp2_connector.cc new file mode 100644 index 00000000..03e95a13 --- /dev/null +++ b/src/core/ext/transport/chttp2/client/chttp2_connector.cc @@ -0,0 +1,276 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/client/chttp2_connector.h" + +#include + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/connector.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +Chttp2Connector::Chttp2Connector() { + GRPC_CLOSURE_INIT(&connected_, Connected, this, grpc_schedule_on_exec_ctx); +} + +Chttp2Connector::~Chttp2Connector() { + if (resource_quota_ != nullptr) { + grpc_resource_quota_unref_internal(resource_quota_); + } + if (endpoint_ != nullptr) { + grpc_endpoint_destroy(endpoint_); + } +} + +void Chttp2Connector::Connect(const Args& args, Result* result, + grpc_closure* notify) { + grpc_endpoint** ep; + { + MutexLock lock(&mu_); + GPR_ASSERT(notify_ == nullptr); + args_ = args; + result_ = result; + notify_ = notify; + GPR_ASSERT(!connecting_); + connecting_ = true; + GPR_ASSERT(endpoint_ == nullptr); + ep = &endpoint_; + if (resource_quota_ != nullptr) { + grpc_resource_quota_unref_internal(resource_quota_); + } + resource_quota_ = + grpc_resource_quota_from_channel_args(args.channel_args, true); + } + // In some implementations, the closure can be flushed before + // grpc_tcp_client_connect() returns, and since the closure requires access + // to mu_, this can result in a deadlock (see + // https://github.com/grpc/grpc/issues/16427 for details). + // grpc_tcp_client_connect() will fill endpoint_ with proper contents, and we + // make sure that we still exist at that point by taking a ref. + Ref().release(); // Ref held by callback. 
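// Illustrative aside (not part of this patch): a minimal sketch of the
// "ref held by callback" idiom used here -- the connector takes an extra
// strong reference before handing `connected_` to the TCP client, and the
// callback releases it once it has run, so the object outlives the pending
// operation. The classes below are hypothetical stand-ins for the gRPC
// machinery.
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

namespace sketch {
// A toy event loop that runs queued callbacks later.
std::vector<std::function<void()>> g_pending;

class Connector : public std::enable_shared_from_this<Connector> {
 public:
  void Connect() {
    // Keep *this alive until the callback fires, even if the caller drops
    // its reference in the meantime.
    auto self = shared_from_this();
    g_pending.push_back([self]() { self->Connected(); });
  }

 private:
  void Connected() { std::printf("connected; connector still alive\n"); }
};
}  // namespace sketch

int main() {
  {
    auto connector = std::make_shared<sketch::Connector>();
    connector->Connect();
  }  // caller's reference is gone, but the queued callback still owns one
  for (auto& cb : sketch::g_pending) cb();
}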
+ grpc_tcp_client_connect( + &connected_, ep, + grpc_slice_allocator_create(resource_quota_, + grpc_sockaddr_to_string(args.address, false), + args.channel_args), + args.interested_parties, args.channel_args, args.address, args.deadline); +} + +void Chttp2Connector::Shutdown(grpc_error_handle error) { + MutexLock lock(&mu_); + shutdown_ = true; + if (handshake_mgr_ != nullptr) { + handshake_mgr_->Shutdown(GRPC_ERROR_REF(error)); + } + // If handshaking is not yet in progress, shutdown the endpoint. + // Otherwise, the handshaker will do this for us. + if (!connecting_ && endpoint_ != nullptr) { + grpc_endpoint_shutdown(endpoint_, GRPC_ERROR_REF(error)); + } + GRPC_ERROR_UNREF(error); +} + +void Chttp2Connector::Connected(void* arg, grpc_error_handle error) { + Chttp2Connector* self = static_cast(arg); + bool unref = false; + { + MutexLock lock(&self->mu_); + GPR_ASSERT(self->connecting_); + self->connecting_ = false; + if (error != GRPC_ERROR_NONE || self->shutdown_) { + if (error == GRPC_ERROR_NONE) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("connector shutdown"); + } else { + error = GRPC_ERROR_REF(error); + } + if (self->endpoint_ != nullptr) { + grpc_endpoint_shutdown(self->endpoint_, GRPC_ERROR_REF(error)); + } + self->result_->Reset(); + grpc_closure* notify = self->notify_; + self->notify_ = nullptr; + ExecCtx::Run(DEBUG_LOCATION, notify, error); + unref = true; + } else { + GPR_ASSERT(self->endpoint_ != nullptr); + self->StartHandshakeLocked(); + } + } + if (unref) self->Unref(); +} + +void Chttp2Connector::StartHandshakeLocked() { + handshake_mgr_ = MakeRefCounted(); + CoreConfiguration::Get().handshaker_registry().AddHandshakers( + HANDSHAKER_CLIENT, args_.channel_args, args_.interested_parties, + handshake_mgr_.get()); + grpc_endpoint_add_to_pollset_set(endpoint_, args_.interested_parties); + handshake_mgr_->DoHandshake(endpoint_, args_.channel_args, args_.deadline, + nullptr /* acceptor */, OnHandshakeDone, this); + endpoint_ = nullptr; // Endpoint handed off to handshake manager. +} + +namespace { +void NullThenSchedClosure(const DebugLocation& location, grpc_closure** closure, + grpc_error_handle error) { + grpc_closure* c = *closure; + *closure = nullptr; + ExecCtx::Run(location, c, error); +} +} // namespace + +void Chttp2Connector::OnHandshakeDone(void* arg, grpc_error_handle error) { + auto* args = static_cast(arg); + Chttp2Connector* self = static_cast(args->user_data); + { + MutexLock lock(&self->mu_); + if (error != GRPC_ERROR_NONE || self->shutdown_) { + if (error == GRPC_ERROR_NONE) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("connector shutdown"); + // We were shut down after handshaking completed successfully, so + // destroy the endpoint here. + if (args->endpoint != nullptr) { + // TODO(ctiller): It is currently necessary to shutdown endpoints + // before destroying them, even if we know that there are no + // pending read/write callbacks. This should be fixed, at which + // point this can be removed. 
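// Illustrative aside (not part of this patch): a minimal sketch of the
// take-then-clear hand-off performed by NullThenSchedClosure() above. Clearing
// the stored callback before running it guarantees it fires at most once and
// leaves the slot free for the next Connect() attempt. The types below are
// hypothetical stand-ins for grpc_closure scheduling.
#include <cstdio>
#include <functional>
#include <utility>

namespace sketch {
using Closure = std::function<void(const char* error)>;

void NullThenRun(Closure* closure, const char* error) {
  Closure c = std::move(*closure);  // take ownership of the callback
  *closure = nullptr;               // clear the slot first
  if (c) c(error);                  // then invoke it
}
}  // namespace sketch

int main() {
  sketch::Closure notify = [](const char* error) {
    std::printf("notified: %s\n", error);
  };
  sketch::NullThenRun(&notify, "ok");
  sketch::NullThenRun(&notify, "ignored");  // slot already cleared; no-op
}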
+ grpc_endpoint_shutdown(args->endpoint, GRPC_ERROR_REF(error)); + grpc_endpoint_destroy(args->endpoint); + grpc_channel_args_destroy(args->args); + grpc_slice_buffer_destroy_internal(args->read_buffer); + gpr_free(args->read_buffer); + } + } else { + error = GRPC_ERROR_REF(error); + } + self->result_->Reset(); + NullThenSchedClosure(DEBUG_LOCATION, &self->notify_, error); + } else if (args->endpoint != nullptr) { + self->result_->transport = grpc_create_chttp2_transport( + args->args, args->endpoint, true, + grpc_resource_user_create( + self->resource_quota_, + absl::StrCat(grpc_endpoint_get_peer(args->endpoint), + ":connector_transport"))); + self->result_->socket_node = + grpc_chttp2_transport_get_socket_node(self->result_->transport); + self->result_->channel_args = args->args; + GPR_ASSERT(self->result_->transport != nullptr); + self->endpoint_ = args->endpoint; + self->Ref().release(); // Ref held by OnReceiveSettings() + GRPC_CLOSURE_INIT(&self->on_receive_settings_, OnReceiveSettings, self, + grpc_schedule_on_exec_ctx); + self->Ref().release(); // Ref held by OnTimeout() + grpc_chttp2_transport_start_reading(self->result_->transport, + args->read_buffer, + &self->on_receive_settings_, nullptr); + GRPC_CLOSURE_INIT(&self->on_timeout_, OnTimeout, self, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&self->timer_, self->args_.deadline, &self->on_timeout_); + } else { + // If the handshaking succeeded but there is no endpoint, then the + // handshaker may have handed off the connection to some external + // code. Just verify that exit_early flag is set. + GPR_DEBUG_ASSERT(args->exit_early); + NullThenSchedClosure(DEBUG_LOCATION, &self->notify_, error); + } + self->handshake_mgr_.reset(); + } + self->Unref(); +} + +void Chttp2Connector::OnReceiveSettings(void* arg, grpc_error_handle error) { + Chttp2Connector* self = static_cast(arg); + { + MutexLock lock(&self->mu_); + if (!self->notify_error_.has_value()) { + grpc_endpoint_delete_from_pollset_set(self->endpoint_, + self->args_.interested_parties); + if (error != GRPC_ERROR_NONE) { + // Transport got an error while waiting on SETTINGS frame. + // TODO(yashykt): The following two lines should be moved to + // SubchannelConnector::Result::Reset() + grpc_transport_destroy(self->result_->transport); + grpc_channel_args_destroy(self->result_->channel_args); + self->result_->Reset(); + } + self->MaybeNotify(GRPC_ERROR_REF(error)); + grpc_timer_cancel(&self->timer_); + } else { + // OnTimeout() was already invoked. Call Notify() again so that notify_ + // can be invoked. + self->MaybeNotify(GRPC_ERROR_NONE); + } + } + self->Unref(); +} + +void Chttp2Connector::OnTimeout(void* arg, grpc_error_handle /*error*/) { + Chttp2Connector* self = static_cast(arg); + { + MutexLock lock(&self->mu_); + if (!self->notify_error_.has_value()) { + // The transport did not receive the settings frame in time. Destroy the + // transport. + grpc_endpoint_delete_from_pollset_set(self->endpoint_, + self->args_.interested_parties); + // TODO(yashykt): The following two lines should be moved to + // SubchannelConnector::Result::Reset() + grpc_transport_destroy(self->result_->transport); + grpc_channel_args_destroy(self->result_->channel_args); + self->result_->Reset(); + self->MaybeNotify(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "connection attempt timed out before receiving SETTINGS frame")); + } else { + // OnReceiveSettings() was already invoked. Call Notify() again so that + // notify_ can be invoked. 
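// Illustrative aside (not part of this patch): a minimal sketch of the
// two-callback rendezvous implemented by MaybeNotify() just below. Whichever
// of the SETTINGS callback and the timeout callback runs first only records
// its result; the one that runs second performs the actual notification with
// the first result. The types below are hypothetical stand-ins.
#include <cstdio>
#include <optional>
#include <string>

namespace sketch {
class Rendezvous {
 public:
  // Called from both callbacks; notifies exactly once, on the second call.
  void MaybeNotify(const std::string& result) {
    if (first_result_.has_value()) {
      std::printf("notifying with first result: %s\n", first_result_->c_str());
      first_result_.reset();  // ready for the next connection attempt
    } else {
      first_result_ = result;  // remember it; the other callback will notify
    }
  }

 private:
  std::optional<std::string> first_result_;
};
}  // namespace sketch

int main() {
  sketch::Rendezvous r;
  r.MaybeNotify("ok: SETTINGS received");  // first caller: just records
  r.MaybeNotify("ignored");                // second caller: notifies
}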
+ self->MaybeNotify(GRPC_ERROR_NONE); + } + } + self->Unref(); +} + +void Chttp2Connector::MaybeNotify(grpc_error_handle error) { + if (notify_error_.has_value()) { + GRPC_ERROR_UNREF(error); + NullThenSchedClosure(DEBUG_LOCATION, ¬ify_, notify_error_.value()); + // Clear state for a new Connect(). + // Clear out the endpoint_, since it is the responsibility of + // the transport to shut it down. + endpoint_ = nullptr; + notify_error_.reset(); + } else { + notify_error_ = error; + } +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/client/insecure/channel_create.cc b/src/core/ext/transport/chttp2/client/insecure/channel_create.cc new file mode 100644 index 00000000..8d5674ad --- /dev/null +++ b/src/core/ext/transport/chttp2/client/insecure/channel_create.cc @@ -0,0 +1,119 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/transport/chttp2/client/chttp2_connector.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/channel.h" + +namespace grpc_core { + +class Chttp2InsecureClientChannelFactory : public ClientChannelFactory { + public: + RefCountedPtr CreateSubchannel( + const grpc_resolved_address& address, + const grpc_channel_args* args) override { + return Subchannel::Create(MakeOrphanable(), address, args); + } +}; + +namespace { + +grpc_channel* CreateChannel(const char* target, const grpc_channel_args* args, + grpc_error_handle* error) { + if (target == nullptr) { + gpr_log(GPR_ERROR, "cannot create channel with NULL target name"); + if (error != nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("channel target is NULL"); + } + return nullptr; + } + // Add channel arg containing the server URI. 
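// Illustrative aside (not part of this patch): a minimal model of the
// copy/add/remove channel-args manipulation used just below (and again in the
// secure variant later). Channel args are immutable, so "setting" the
// server-URI arg means copying the list, dropping any existing entry with the
// same key, and appending the new one. The container here is a hypothetical
// stand-in for grpc_channel_args.
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

namespace sketch {
using Args = std::vector<std::pair<std::string, std::string>>;

Args CopyAddAndRemove(const Args& in, const std::string& key_to_remove,
                      std::pair<std::string, std::string> to_add) {
  Args out;
  for (const auto& kv : in) {
    if (kv.first != key_to_remove) out.push_back(kv);
  }
  out.push_back(std::move(to_add));
  return out;
}
}  // namespace sketch

int main() {
  sketch::Args args = {{"grpc.server_uri", "stale-target"},
                       {"grpc.enable_retries", "1"}};
  sketch::Args new_args = sketch::CopyAddAndRemove(
      args, "grpc.server_uri", {"grpc.server_uri", "dns:///localhost:50051"});
  for (const auto& kv : new_args) {
    std::printf("%s = %s\n", kv.first.c_str(), kv.second.c_str());
  }
}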
+ grpc_core::UniquePtr canonical_target = + ResolverRegistry::AddDefaultPrefixIfNeeded(target); + grpc_arg arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVER_URI), canonical_target.get()); + const char* to_remove[] = {GRPC_ARG_SERVER_URI}; + grpc_channel_args* new_args = + grpc_channel_args_copy_and_add_and_remove(args, to_remove, 1, &arg, 1); + grpc_channel* channel = grpc_channel_create( + target, new_args, GRPC_CLIENT_CHANNEL, nullptr, nullptr, 0, error); + grpc_channel_args_destroy(new_args); + return channel; +} + +} // namespace + +} // namespace grpc_core + +namespace { + +grpc_core::Chttp2InsecureClientChannelFactory* g_factory; +gpr_once g_factory_once = GPR_ONCE_INIT; + +void FactoryInit() { + g_factory = new grpc_core::Chttp2InsecureClientChannelFactory(); +} + +} // namespace + +/* Create a client channel: + Asynchronously: - resolve target + - connect to it (trying alternatives as presented) + - perform handshakes */ +grpc_channel* grpc_insecure_channel_create(const char* target, + const grpc_channel_args* args, + void* reserved) { + grpc_core::ExecCtx exec_ctx; + args = grpc_channel_args_remove_grpc_internal(args); + GRPC_API_TRACE( + "grpc_insecure_channel_create(target=%s, args=%p, reserved=%p)", 3, + (target, args, reserved)); + GPR_ASSERT(reserved == nullptr); + // Add channel arg containing the client channel factory. + gpr_once_init(&g_factory_once, FactoryInit); + grpc_arg arg = grpc_core::ClientChannelFactory::CreateChannelArg(g_factory); + const char* arg_to_remove = arg.key; + grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( + args, &arg_to_remove, 1, &arg, 1); + grpc_error_handle error = GRPC_ERROR_NONE; + // Create channel. + grpc_channel* channel = grpc_core::CreateChannel(target, new_args, &error); + // Clean up. + grpc_channel_args_destroy(new_args); + grpc_channel_args_destroy(args); + if (channel == nullptr) { + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + channel = grpc_lame_client_channel_create( + target, status, "Failed to create client channel"); + } + return channel; +} diff --git a/src/core/ext/transport/chttp2/client/insecure/channel_create_posix.cc b/src/core/ext/transport/chttp2/client/insecure/channel_create_posix.cc new file mode 100644 index 00000000..db9d4d53 --- /dev/null +++ b/src/core/ext/transport/chttp2/client/insecure/channel_create_posix.cc @@ -0,0 +1,95 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#ifdef GPR_SUPPORT_CHANNELS_FROM_FD + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/tcp_client_posix.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/transport.h" + +grpc_channel* grpc_insecure_channel_create_from_fd( + const char* target, int fd, const grpc_channel_args* args) { + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_insecure_channel_create(target=%p, fd=%d, args=%p)", 3, + (target, fd, args)); + + grpc_arg default_authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test.authority")); + grpc_channel_args* final_args = + grpc_channel_args_copy_and_add(args, &default_authority_arg, 1); + + int flags = fcntl(fd, F_GETFL, 0); + GPR_ASSERT(fcntl(fd, F_SETFL, flags | O_NONBLOCK) == 0); + grpc_resource_quota* resource_quota = + grpc_resource_quota_from_channel_args(args, true); + grpc_slice_allocator* allocator = grpc_slice_allocator_create( + resource_quota, "fd-client:endpoint", final_args); + grpc_endpoint* client = grpc_tcp_client_create_from_fd( + grpc_fd_create(fd, "client", true), args, "fd-client", allocator); + grpc_transport* transport = grpc_create_chttp2_transport( + final_args, client, true, + grpc_resource_user_create(resource_quota, "fd-client:transport")); + grpc_resource_quota_unref_internal(resource_quota); + GPR_ASSERT(transport); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_channel* channel = + grpc_channel_create(target, final_args, GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, &error); + grpc_channel_args_destroy(final_args); + if (channel != nullptr) { + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + grpc_core::ExecCtx::Get()->Flush(); + } else { + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + grpc_transport_destroy(transport); + channel = grpc_lame_client_channel_create( + target, status, "Failed to create client channel"); + } + + return channel; +} + +#else // !GPR_SUPPORT_CHANNELS_FROM_FD + +grpc_channel* grpc_insecure_channel_create_from_fd( + const char* /* target */, int /* fd */, + const grpc_channel_args* /* args */) { + GPR_ASSERT(0); + return nullptr; +} + +#endif // GPR_SUPPORT_CHANNELS_FROM_FD diff --git a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc new file mode 100644 index 00000000..68cd3a91 --- /dev/null +++ b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc @@ -0,0 +1,189 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/transport/chttp2/client/chttp2_connector.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/uri/uri_parser.h" + +namespace grpc_core { + +class Chttp2SecureClientChannelFactory : public ClientChannelFactory { + public: + RefCountedPtr CreateSubchannel( + const grpc_resolved_address& address, + const grpc_channel_args* args) override { + grpc_channel_args* new_args = GetSecureNamingChannelArgs(args); + if (new_args == nullptr) { + gpr_log(GPR_ERROR, + "Failed to create channel args during subchannel creation."); + return nullptr; + } + RefCountedPtr s = Subchannel::Create( + MakeOrphanable(), address, new_args); + grpc_channel_args_destroy(new_args); + return s; + } + + private: + static grpc_channel_args* GetSecureNamingChannelArgs( + const grpc_channel_args* args) { + grpc_channel_credentials* channel_credentials = + grpc_channel_credentials_find_in_args(args); + if (channel_credentials == nullptr) { + gpr_log(GPR_ERROR, + "Can't create subchannel: channel credentials missing for secure " + "channel."); + return nullptr; + } + // Make sure security connector does not already exist in args. + if (grpc_security_connector_find_in_args(args) != nullptr) { + gpr_log(GPR_ERROR, + "Can't create subchannel: security connector already present in " + "channel args."); + return nullptr; + } + // Find the authority to use in the security connector. + const char* authority = + grpc_channel_args_find_string(args, GRPC_ARG_DEFAULT_AUTHORITY); + GPR_ASSERT(authority != nullptr); + // Create the security connector using the credentials and target name. + grpc_channel_args* new_args_from_connector = nullptr; + RefCountedPtr + subchannel_security_connector = + channel_credentials->create_security_connector( + /*call_creds=*/nullptr, authority, args, + &new_args_from_connector); + if (subchannel_security_connector == nullptr) { + gpr_log(GPR_ERROR, + "Failed to create secure subchannel for secure name '%s'", + authority); + return nullptr; + } + grpc_arg new_security_connector_arg = + grpc_security_connector_to_arg(subchannel_security_connector.get()); + grpc_channel_args* new_args = grpc_channel_args_copy_and_add( + new_args_from_connector != nullptr ? new_args_from_connector : args, + &new_security_connector_arg, 1); + subchannel_security_connector.reset(DEBUG_LOCATION, "lb_channel_create"); + grpc_channel_args_destroy(new_args_from_connector); + return new_args; + } +}; + +namespace { + +grpc_channel* CreateChannel(const char* target, const grpc_channel_args* args, + grpc_error_handle* error) { + if (target == nullptr) { + gpr_log(GPR_ERROR, "cannot create channel with NULL target name"); + if (error != nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("channel target is NULL"); + } + return nullptr; + } + // Add channel arg containing the server URI. 
+ grpc_core::UniquePtr canonical_target = + ResolverRegistry::AddDefaultPrefixIfNeeded(target); + grpc_arg arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVER_URI), canonical_target.get()); + const char* to_remove[] = {GRPC_ARG_SERVER_URI}; + grpc_channel_args* new_args = + grpc_channel_args_copy_and_add_and_remove(args, to_remove, 1, &arg, 1); + grpc_channel* channel = grpc_channel_create( + target, new_args, GRPC_CLIENT_CHANNEL, nullptr, nullptr, 0, error); + grpc_channel_args_destroy(new_args); + return channel; +} + +} // namespace + +} // namespace grpc_core + +namespace { + +grpc_core::Chttp2SecureClientChannelFactory* g_factory; +gpr_once g_factory_once = GPR_ONCE_INIT; + +void FactoryInit() { + g_factory = new grpc_core::Chttp2SecureClientChannelFactory(); +} + +} // namespace + +// Create a secure client channel: +// Asynchronously: - resolve target +// - connect to it (trying alternatives as presented) +// - perform handshakes +grpc_channel* grpc_secure_channel_create(grpc_channel_credentials* creds, + const char* target, + const grpc_channel_args* args, + void* reserved) { + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE( + "grpc_secure_channel_create(creds=%p, target=%s, args=%p, " + "reserved=%p)", + 4, ((void*)creds, target, (void*)args, (void*)reserved)); + GPR_ASSERT(reserved == nullptr); + args = grpc_channel_args_remove_grpc_internal(args); + grpc_channel* channel = nullptr; + grpc_error_handle error = GRPC_ERROR_NONE; + if (creds != nullptr) { + // Add channel args containing the client channel factory and channel + // credentials. + gpr_once_init(&g_factory_once, FactoryInit); + grpc_arg channel_factory_arg = + grpc_core::ClientChannelFactory::CreateChannelArg(g_factory); + grpc_arg args_to_add[] = {channel_factory_arg, + grpc_channel_credentials_to_arg(creds)}; + const char* arg_to_remove = channel_factory_arg.key; + grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( + args, &arg_to_remove, 1, args_to_add, GPR_ARRAY_SIZE(args_to_add)); + new_args = creds->update_arguments(new_args); + // Create channel. + channel = grpc_core::CreateChannel(target, new_args, &error); + // Clean up. + grpc_channel_args_destroy(new_args); + } + grpc_channel_args_destroy(args); + if (channel == nullptr) { + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + channel = grpc_lame_client_channel_create( + target, status, "Failed to create secure client channel"); + } + return channel; +} diff --git a/src/core/ext/transport/chttp2/server/chttp2_server.cc b/src/core/ext/transport/chttp2/server/chttp2_server.cc new file mode 100644 index 00000000..c2e1e14c --- /dev/null +++ b/src/core/ext/transport/chttp2/server/chttp2_server.cc @@ -0,0 +1,923 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/server/chttp2_server.h" + +#include +#include +#include + +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/resource_quota.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/server.h" + +namespace grpc_core { +namespace { + +const char kUnixUriPrefix[] = "unix:"; +const char kUnixAbstractUriPrefix[] = "unix-abstract:"; + +class Chttp2ServerListener : public Server::ListenerInterface { + public: + static grpc_error_handle Create(Server* server, grpc_resolved_address* addr, + grpc_channel_args* args, + Chttp2ServerArgsModifier args_modifier, + int* port_num); + + static grpc_error_handle CreateWithAcceptor( + Server* server, const char* name, grpc_channel_args* args, + Chttp2ServerArgsModifier args_modifier); + + // Do not instantiate directly. Use one of the factory methods above. + Chttp2ServerListener(Server* server, grpc_channel_args* args, + Chttp2ServerArgsModifier args_modifier, + grpc_resource_quota* resource_quota); + ~Chttp2ServerListener() override; + + void Start(Server* server, + const std::vector* pollsets) override; + + channelz::ListenSocketNode* channelz_listen_socket_node() const override { + return channelz_listen_socket_.get(); + } + + void SetOnDestroyDone(grpc_closure* on_destroy_done) override; + + void Orphan() override; + + private: + class ConfigFetcherWatcher + : public grpc_server_config_fetcher::WatcherInterface { + public: + explicit ConfigFetcherWatcher(RefCountedPtr listener) + : listener_(std::move(listener)) {} + + void UpdateConnectionManager( + RefCountedPtr + connection_manager) override; + + void StopServing() override; + + private: + RefCountedPtr listener_; + }; + + class ActiveConnection : public InternallyRefCounted { + public: + class HandshakingState : public InternallyRefCounted { + public: + HandshakingState(RefCountedPtr connection_ref, + grpc_pollset* accepting_pollset, + grpc_tcp_server_acceptor* acceptor, + grpc_channel_args* args, + grpc_resource_user* channel_resource_user); + + ~HandshakingState() override; + + void Orphan() override; + + void Start(grpc_endpoint* endpoint, grpc_channel_args* args); + + // Needed to be able to grab an external ref in ActiveConnection::Start() + using InternallyRefCounted::Ref; + + private: + static void OnTimeout(void* arg, grpc_error_handle error); + static void OnReceiveSettings(void* arg, grpc_error_handle /* error */); + static void OnHandshakeDone(void* arg, grpc_error_handle error); + RefCountedPtr const connection_; + grpc_pollset* const accepting_pollset_; + grpc_tcp_server_acceptor* acceptor_; + RefCountedPtr handshake_mgr_ + 
ABSL_GUARDED_BY(&connection_->mu_); + // State for enforcing handshake timeout on receiving HTTP/2 settings. + grpc_millis const deadline_; + grpc_timer timer_ ABSL_GUARDED_BY(&connection_->mu_); + grpc_closure on_timeout_ ABSL_GUARDED_BY(&connection_->mu_); + grpc_closure on_receive_settings_ ABSL_GUARDED_BY(&connection_->mu_); + grpc_pollset_set* const interested_parties_; + grpc_resource_user* channel_resource_user_; + }; + + ActiveConnection(grpc_pollset* accepting_pollset, + grpc_tcp_server_acceptor* acceptor, + grpc_channel_args* args, + grpc_resource_user* channel_resource_user); + ~ActiveConnection() override; + + void Orphan() override; + + void SendGoAway(); + + void Start(RefCountedPtr listener, + grpc_endpoint* endpoint, grpc_channel_args* args); + + // Needed to be able to grab an external ref in + // Chttp2ServerListener::OnAccept() + using InternallyRefCounted::Ref; + + private: + static void OnClose(void* arg, grpc_error_handle error); + + RefCountedPtr listener_; + Mutex mu_ ABSL_ACQUIRED_AFTER(&listener_->mu_); + // Set by HandshakingState before the handshaking begins and reset when + // handshaking is done. + OrphanablePtr handshaking_state_ ABSL_GUARDED_BY(&mu_); + // Set by HandshakingState when handshaking is done and a valid transport is + // created. + grpc_chttp2_transport* transport_ ABSL_GUARDED_BY(&mu_) = nullptr; + grpc_closure on_close_; + bool shutdown_ ABSL_GUARDED_BY(&mu_) = false; + }; + + // To allow access to RefCounted<> like interface. + friend class RefCountedPtr; + + // Should only be called once so as to start the TCP server. + void StartListening(); + + static void OnAccept(void* arg, grpc_endpoint* tcp, + grpc_pollset* accepting_pollset, + grpc_tcp_server_acceptor* acceptor); + + static void TcpServerShutdownComplete(void* arg, grpc_error_handle error); + + static void DestroyListener(Server* /*server*/, void* arg, + grpc_closure* destroy_done); + + // The interface required by RefCountedPtr<> has been manually implemented + // here to take a ref on tcp_server_ instead. Note that, the handshaker needs + // tcp_server_ to exist for the lifetime of the handshake since it's needed by + // acceptor. Sharing refs between the listener and tcp_server_ is just an + // optimization to avoid taking additional refs on the listener, since + // TcpServerShutdownComplete already holds a ref to the listener. + void IncrementRefCount() { grpc_tcp_server_ref(tcp_server_); } + void IncrementRefCount(const DebugLocation& /* location */, + const char* /* reason */) { + IncrementRefCount(); + } + + RefCountedPtr Ref() GRPC_MUST_USE_RESULT { + IncrementRefCount(); + return RefCountedPtr(this); + } + RefCountedPtr Ref(const DebugLocation& /* location */, + const char* /* reason */) + GRPC_MUST_USE_RESULT { + return Ref(); + } + + void Unref() { grpc_tcp_server_unref(tcp_server_); } + void Unref(const DebugLocation& /* location */, const char* /* reason */) { + Unref(); + } + + Server* const server_; + grpc_tcp_server* tcp_server_; + grpc_resolved_address resolved_address_; + Chttp2ServerArgsModifier const args_modifier_; + ConfigFetcherWatcher* config_fetcher_watcher_ = nullptr; + grpc_channel_args* args_; + Mutex connection_manager_mu_; + RefCountedPtr + connection_manager_ ABSL_GUARDED_BY(connection_manager_mu_); + Mutex mu_; + // Signals whether grpc_tcp_server_start() has been called. + bool started_ ABSL_GUARDED_BY(mu_) = false; + // Signals whether grpc_tcp_server_start() has completed. 
+ CondVar started_cv_ ABSL_GUARDED_BY(mu_); + // Signals whether new requests/connections are to be accepted. + bool is_serving_ ABSL_GUARDED_BY(mu_) = false; + // Signals whether the application has triggered shutdown. + bool shutdown_ ABSL_GUARDED_BY(mu_) = false; + std::map> connections_ + ABSL_GUARDED_BY(mu_); + grpc_closure tcp_server_shutdown_complete_ ABSL_GUARDED_BY(mu_); + grpc_closure* on_destroy_done_ ABSL_GUARDED_BY(mu_) = nullptr; + RefCountedPtr channelz_listen_socket_; + grpc_resource_quota* resource_quota_; +}; + +// +// Chttp2ServerListener::ConfigFetcherWatcher +// + +void Chttp2ServerListener::ConfigFetcherWatcher::UpdateConnectionManager( + RefCountedPtr + connection_manager) { + RefCountedPtr + connection_manager_to_destroy; + { + MutexLock lock(&listener_->connection_manager_mu_); + connection_manager_to_destroy = listener_->connection_manager_; + listener_->connection_manager_ = std::move(connection_manager); + } + { + MutexLock lock(&listener_->mu_); + if (listener_->shutdown_) { + return; + } + listener_->is_serving_ = true; + if (listener_->started_) return; + } + int port_temp; + grpc_error_handle error = grpc_tcp_server_add_port( + listener_->tcp_server_, &listener_->resolved_address_, &port_temp); + if (error != GRPC_ERROR_NONE) { + GRPC_ERROR_UNREF(error); + gpr_log(GPR_ERROR, "Error adding port to server: %s", + grpc_error_std_string(error).c_str()); + // TODO(yashykt): We wouldn't need to assert here if we bound to the + // port earlier during AddPort. + GPR_ASSERT(0); + } + listener_->StartListening(); + { + MutexLock lock(&listener_->mu_); + listener_->started_ = true; + listener_->started_cv_.SignalAll(); + } +} + +void Chttp2ServerListener::ConfigFetcherWatcher::StopServing() { + std::map> connections; + { + MutexLock lock(&listener_->mu_); + listener_->is_serving_ = false; + connections = std::move(listener_->connections_); + } + // Send GOAWAYs on the transports so that they disconnected when existing RPCs + // finish. 
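// Illustrative aside (not part of this patch): a minimal sketch of the
// "move the map out under the lock, then act on it outside the lock" pattern
// used below, so that per-connection GOAWAY work never runs while holding the
// listener mutex. The types here are hypothetical stand-ins.
#include <cstdio>
#include <map>
#include <mutex>
#include <string>

namespace sketch {
class Listener {
 public:
  void AddConnection(int id, std::string name) {
    std::lock_guard<std::mutex> lock(mu_);
    connections_[id] = std::move(name);
  }

  void StopServing() {
    std::map<int, std::string> connections;
    {
      std::lock_guard<std::mutex> lock(mu_);
      serving_ = false;
      connections = std::move(connections_);  // steal the map under the lock
    }
    // Potentially slow per-connection work happens without holding mu_.
    for (const auto& c : connections) {
      std::printf("sending GOAWAY to connection %d (%s)\n", c.first,
                  c.second.c_str());
    }
  }

 private:
  std::mutex mu_;
  bool serving_ = true;
  std::map<int, std::string> connections_;
};
}  // namespace sketch

int main() {
  sketch::Listener listener;
  listener.AddConnection(1, "client-a");
  listener.AddConnection(2, "client-b");
  listener.StopServing();
}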
+ for (auto& connection : connections) { + connection.first->SendGoAway(); + } +} + +// +// Chttp2ServerListener::ActiveConnection::HandshakingState +// + +grpc_millis GetConnectionDeadline(const grpc_channel_args* args) { + int timeout_ms = + grpc_channel_args_find_integer(args, GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS, + {120 * GPR_MS_PER_SEC, 1, INT_MAX}); + return ExecCtx::Get()->Now() + timeout_ms; +} + +Chttp2ServerListener::ActiveConnection::HandshakingState::HandshakingState( + RefCountedPtr connection_ref, + grpc_pollset* accepting_pollset, grpc_tcp_server_acceptor* acceptor, + grpc_channel_args* args, grpc_resource_user* channel_resource_user) + : connection_(std::move(connection_ref)), + accepting_pollset_(accepting_pollset), + acceptor_(acceptor), + handshake_mgr_(MakeRefCounted()), + deadline_(GetConnectionDeadline(args)), + interested_parties_(grpc_pollset_set_create()), + channel_resource_user_(channel_resource_user) { + grpc_pollset_set_add_pollset(interested_parties_, accepting_pollset_); + CoreConfiguration::Get().handshaker_registry().AddHandshakers( + HANDSHAKER_SERVER, args, interested_parties_, handshake_mgr_.get()); +} + +Chttp2ServerListener::ActiveConnection::HandshakingState::~HandshakingState() { + grpc_pollset_set_del_pollset(interested_parties_, accepting_pollset_); + grpc_pollset_set_destroy(interested_parties_); + if (channel_resource_user_ != nullptr) { + grpc_resource_user_unref(channel_resource_user_); + } + gpr_free(acceptor_); +} + +void Chttp2ServerListener::ActiveConnection::HandshakingState::Orphan() { + { + MutexLock lock(&connection_->mu_); + if (handshake_mgr_ != nullptr) { + handshake_mgr_->Shutdown( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Listener stopped serving.")); + } + } + Unref(); +} + +void Chttp2ServerListener::ActiveConnection::HandshakingState::Start( + grpc_endpoint* endpoint, grpc_channel_args* args) { + Ref().release(); // Held by OnHandshakeDone + RefCountedPtr handshake_mgr; + { + MutexLock lock(&connection_->mu_); + if (handshake_mgr_ == nullptr) return; + handshake_mgr = handshake_mgr_; + } + handshake_mgr->DoHandshake(endpoint, args, deadline_, acceptor_, + OnHandshakeDone, this); +} + +void Chttp2ServerListener::ActiveConnection::HandshakingState::OnTimeout( + void* arg, grpc_error_handle error) { + HandshakingState* self = static_cast(arg); + // Note that we may be called with GRPC_ERROR_NONE when the timer fires + // or with an error indicating that the timer system is being shut down. 
+ if (error != GRPC_ERROR_CANCELLED) { + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->disconnect_with_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Did not receive HTTP/2 settings before handshake timeout"); + grpc_chttp2_transport* transport = nullptr; + { + MutexLock lock(&self->connection_->mu_); + transport = self->connection_->transport_; + } + grpc_transport_perform_op(&transport->base, op); + } + self->Unref(); +} + +void Chttp2ServerListener::ActiveConnection::HandshakingState:: + OnReceiveSettings(void* arg, grpc_error_handle /* error */) { + HandshakingState* self = static_cast(arg); + grpc_timer_cancel(&self->timer_); + self->Unref(); +} + +void Chttp2ServerListener::ActiveConnection::HandshakingState::OnHandshakeDone( + void* arg, grpc_error_handle error) { + auto* args = static_cast(arg); + HandshakingState* self = static_cast(args->user_data); + OrphanablePtr handshaking_state_ref; + RefCountedPtr handshake_mgr; + bool cleanup_connection = false; + { + MutexLock connection_lock(&self->connection_->mu_); + if (error != GRPC_ERROR_NONE || self->connection_->shutdown_) { + std::string error_str = grpc_error_std_string(error); + gpr_log(GPR_DEBUG, "Handshaking failed: %s", error_str.c_str()); + cleanup_connection = true; + if (error == GRPC_ERROR_NONE && args->endpoint != nullptr) { + // We were shut down or stopped serving after handshaking completed + // successfully, so destroy the endpoint here. + // TODO(ctiller): It is currently necessary to shutdown endpoints + // before destroying them, even if we know that there are no + // pending read/write callbacks. This should be fixed, at which + // point this can be removed. + grpc_endpoint_shutdown(args->endpoint, GRPC_ERROR_NONE); + grpc_endpoint_destroy(args->endpoint); + grpc_channel_args_destroy(args->args); + grpc_slice_buffer_destroy_internal(args->read_buffer); + gpr_free(args->read_buffer); + } + } else { + // If the handshaking succeeded but there is no endpoint, then the + // handshaker may have handed off the connection to some external + // code, so we can just clean up here without creating a transport. + if (args->endpoint != nullptr) { + grpc_transport* transport = grpc_create_chttp2_transport( + args->args, args->endpoint, false, + grpc_resource_user_create( + self->connection_->listener_->resource_quota_, + absl::StrCat(grpc_endpoint_get_peer(args->endpoint), + ":chttp2_server_transport"))); + grpc_error_handle channel_init_err = + self->connection_->listener_->server_->SetupTransport( + transport, self->accepting_pollset_, args->args, + grpc_chttp2_transport_get_socket_node(transport), + self->channel_resource_user_, GRPC_RESOURCE_QUOTA_CHANNEL_SIZE); + self->channel_resource_user_ = nullptr; + if (channel_init_err == GRPC_ERROR_NONE) { + // Use notify_on_receive_settings callback to enforce the + // handshake deadline. + // Note: The reinterpret_cast<>s here are safe, because + // grpc_chttp2_transport is a C-style extension of + // grpc_transport, so this is morally equivalent of a + // static_cast<> to a derived class. + // TODO(roth): Change to static_cast<> when we C++-ify the + // transport API. + self->connection_->transport_ = + reinterpret_cast(transport); + GRPC_CHTTP2_REF_TRANSPORT(self->connection_->transport_, + "ActiveConnection"); // Held by connection_ + self->Ref().release(); // Held by OnReceiveSettings(). 
+ GRPC_CLOSURE_INIT(&self->on_receive_settings_, OnReceiveSettings, + self, grpc_schedule_on_exec_ctx); + // If the listener has been configured with a config fetcher, we need + // to watch on the transport being closed so that we can an updated + // list of active connections. + grpc_closure* on_close = nullptr; + if (self->connection_->listener_->config_fetcher_watcher_ != + nullptr) { + // Refs helds by OnClose() + self->connection_->Ref().release(); + on_close = &self->connection_->on_close_; + } else { + // Remove the connection from the connections_ map since OnClose() + // will not be invoked when a config fetcher is set. + cleanup_connection = true; + } + grpc_chttp2_transport_start_reading(transport, args->read_buffer, + &self->on_receive_settings_, + on_close); + grpc_channel_args_destroy(args->args); + self->Ref().release(); // Held by OnTimeout(). + GRPC_CLOSURE_INIT(&self->on_timeout_, OnTimeout, self, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&self->timer_, self->deadline_, &self->on_timeout_); + } else { + // Failed to create channel from transport. Clean up. + gpr_log(GPR_ERROR, "Failed to create channel: %s", + grpc_error_std_string(channel_init_err).c_str()); + GRPC_ERROR_UNREF(channel_init_err); + grpc_transport_destroy(transport); + grpc_slice_buffer_destroy_internal(args->read_buffer); + gpr_free(args->read_buffer); + cleanup_connection = true; + grpc_channel_args_destroy(args->args); + } + } else { + cleanup_connection = true; + } + } + // Since the handshake manager is done, the connection no longer needs to + // shutdown the handshake when the listener needs to stop serving. + // Avoid calling the destructor of HandshakeManager and HandshakingState + // from within the critical region. + handshake_mgr = std::move(self->handshake_mgr_); + handshaking_state_ref = std::move(self->connection_->handshaking_state_); + } + gpr_free(self->acceptor_); + self->acceptor_ = nullptr; + OrphanablePtr connection; + if (self->channel_resource_user_ != nullptr) { + grpc_resource_user_free(self->channel_resource_user_, + GRPC_RESOURCE_QUOTA_CHANNEL_SIZE); + } + if (cleanup_connection) { + MutexLock listener_lock(&self->connection_->listener_->mu_); + auto it = self->connection_->listener_->connections_.find( + self->connection_.get()); + if (it != self->connection_->listener_->connections_.end()) { + connection = std::move(it->second); + self->connection_->listener_->connections_.erase(it); + } + } + self->Unref(); +} + +// +// Chttp2ServerListener::ActiveConnection +// + +Chttp2ServerListener::ActiveConnection::ActiveConnection( + grpc_pollset* accepting_pollset, grpc_tcp_server_acceptor* acceptor, + grpc_channel_args* args, grpc_resource_user* channel_resource_user) + : handshaking_state_(MakeOrphanable( + Ref(), accepting_pollset, acceptor, args, channel_resource_user)) { + GRPC_CLOSURE_INIT(&on_close_, ActiveConnection::OnClose, this, + grpc_schedule_on_exec_ctx); +} + +Chttp2ServerListener::ActiveConnection::~ActiveConnection() { + if (transport_ != nullptr) { + GRPC_CHTTP2_UNREF_TRANSPORT(transport_, "ActiveConnection"); + } +} + +void Chttp2ServerListener::ActiveConnection::Orphan() { + OrphanablePtr handshaking_state; + { + MutexLock lock(&mu_); + shutdown_ = true; + // Reset handshaking_state_ since we have been orphaned by the listener + // signaling that the listener has stopped serving. 
+ handshaking_state = std::move(handshaking_state_); + } + Unref(); +} + +void Chttp2ServerListener::ActiveConnection::SendGoAway() { + grpc_chttp2_transport* transport = nullptr; + { + MutexLock lock(&mu_); + transport = transport_; + } + if (transport != nullptr) { + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->goaway_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Server is stopping to serve requests."); + grpc_transport_perform_op(&transport->base, op); + } +} + +void Chttp2ServerListener::ActiveConnection::Start( + RefCountedPtr listener, grpc_endpoint* endpoint, + grpc_channel_args* args) { + RefCountedPtr handshaking_state_ref; + listener_ = std::move(listener); + { + MutexLock lock(&mu_); + if (shutdown_) return; + // Hold a ref to HandshakingState to allow starting the handshake outside + // the critical region. + handshaking_state_ref = handshaking_state_->Ref(); + } + handshaking_state_ref->Start(endpoint, args); +} + +void Chttp2ServerListener::ActiveConnection::OnClose( + void* arg, grpc_error_handle /* error */) { + ActiveConnection* self = static_cast(arg); + OrphanablePtr connection; + { + MutexLock listener_lock(&self->listener_->mu_); + MutexLock connection_lock(&self->mu_); + // The node was already deleted from the connections_ list if the connection + // is shutdown. + if (!self->shutdown_) { + auto it = self->listener_->connections_.find(self); + if (it != self->listener_->connections_.end()) { + connection = std::move(it->second); + self->listener_->connections_.erase(it); + } + } + } + self->Unref(); +} + +// +// Chttp2ServerListener +// + +grpc_error_handle Chttp2ServerListener::Create( + Server* server, grpc_resolved_address* addr, grpc_channel_args* args, + Chttp2ServerArgsModifier args_modifier, int* port_num) { + Chttp2ServerListener* listener = nullptr; + // The bulk of this method is inside of a lambda to make cleanup + // easier without using goto. + grpc_error_handle error = [&]() { + grpc_error_handle error = GRPC_ERROR_NONE; + // Create Chttp2ServerListener. + listener = new Chttp2ServerListener( + server, args, args_modifier, + grpc_resource_quota_from_channel_args(args, true)); + grpc_resource_quota_ref_internal(listener->resource_quota_); + error = grpc_tcp_server_create( + &listener->tcp_server_shutdown_complete_, args, + grpc_slice_allocator_factory_create(listener->resource_quota_), + &listener->tcp_server_); + if (error != GRPC_ERROR_NONE) return error; + if (server->config_fetcher() != nullptr) { + listener->resolved_address_ = *addr; + // TODO(yashykt): Consider binding so as to be able to return the port + // number. + } else { + error = grpc_tcp_server_add_port(listener->tcp_server_, addr, port_num); + if (error != GRPC_ERROR_NONE) return error; + } + // Create channelz node. + if (grpc_channel_args_find_bool(args, GRPC_ARG_ENABLE_CHANNELZ, + GRPC_ENABLE_CHANNELZ_DEFAULT)) { + std::string string_address = grpc_sockaddr_to_uri(addr); + listener->channelz_listen_socket_ = + MakeRefCounted( + string_address.c_str(), + absl::StrFormat("chttp2 listener %s", string_address.c_str())); + } + // Register with the server only upon success + server->AddListener(OrphanablePtr(listener)); + return GRPC_ERROR_NONE; + }(); + if (error != GRPC_ERROR_NONE) { + if (listener != nullptr) { + if (listener->tcp_server_ != nullptr) { + // listener is deleted when tcp_server_ is shutdown. 
+ grpc_tcp_server_unref(listener->tcp_server_); + } else { + delete listener; + } + } else { + grpc_channel_args_destroy(args); + } + } + return error; +} + +grpc_error_handle Chttp2ServerListener::CreateWithAcceptor( + Server* server, const char* name, grpc_channel_args* args, + Chttp2ServerArgsModifier args_modifier) { + Chttp2ServerListener* listener = new Chttp2ServerListener( + server, args, args_modifier, + grpc_resource_quota_from_channel_args(args, true)); + grpc_resource_quota_ref_internal(listener->resource_quota_); + grpc_error_handle error = grpc_tcp_server_create( + &listener->tcp_server_shutdown_complete_, args, + grpc_slice_allocator_factory_create(listener->resource_quota_), + &listener->tcp_server_); + if (error != GRPC_ERROR_NONE) { + delete listener; + return error; + } + // TODO(yangg) channelz + TcpServerFdHandler** arg_val = + grpc_channel_args_find_pointer(args, name); + *arg_val = grpc_tcp_server_create_fd_handler(listener->tcp_server_); + server->AddListener(OrphanablePtr(listener)); + return GRPC_ERROR_NONE; +} + +Chttp2ServerListener::Chttp2ServerListener( + Server* server, grpc_channel_args* args, + Chttp2ServerArgsModifier args_modifier, grpc_resource_quota* resource_quota) + : server_(server), + args_modifier_(args_modifier), + args_(args), + resource_quota_(resource_quota) { + GRPC_CLOSURE_INIT(&tcp_server_shutdown_complete_, TcpServerShutdownComplete, + this, grpc_schedule_on_exec_ctx); +} + +Chttp2ServerListener::~Chttp2ServerListener() { + // Flush queued work before destroying handshaker factory, since that + // may do a synchronous unref. + ExecCtx::Get()->Flush(); + if (on_destroy_done_ != nullptr) { + ExecCtx::Run(DEBUG_LOCATION, on_destroy_done_, GRPC_ERROR_NONE); + ExecCtx::Get()->Flush(); + } + grpc_resource_quota_unref_internal(resource_quota_); + grpc_channel_args_destroy(args_); +} + +/* Server callback: start listening on our ports */ +void Chttp2ServerListener::Start( + Server* /*server*/, const std::vector* /* pollsets */) { + if (server_->config_fetcher() != nullptr) { + auto watcher = absl::make_unique(Ref()); + config_fetcher_watcher_ = watcher.get(); + server_->config_fetcher()->StartWatch( + grpc_sockaddr_to_string(&resolved_address_, false), std::move(watcher)); + } else { + { + MutexLock lock(&mu_); + started_ = true; + is_serving_ = true; + } + StartListening(); + } +} + +void Chttp2ServerListener::StartListening() { + grpc_tcp_server_start(tcp_server_, &server_->pollsets(), OnAccept, this); +} + +void Chttp2ServerListener::SetOnDestroyDone(grpc_closure* on_destroy_done) { + MutexLock lock(&mu_); + on_destroy_done_ = on_destroy_done; +} + +void Chttp2ServerListener::OnAccept(void* arg, grpc_endpoint* tcp, + grpc_pollset* accepting_pollset, + grpc_tcp_server_acceptor* acceptor) { + Chttp2ServerListener* self = static_cast(arg); + grpc_channel_args* args = self->args_; + grpc_channel_args* args_to_destroy = nullptr; + RefCountedPtr + connection_manager; + { + MutexLock lock(&self->connection_manager_mu_); + connection_manager = self->connection_manager_; + } + auto endpoint_cleanup = [&](grpc_error_handle error) { + grpc_endpoint_shutdown(tcp, error); + grpc_endpoint_destroy(tcp); + gpr_free(acceptor); + }; + if (self->server_->config_fetcher() != nullptr) { + if (connection_manager == nullptr) { + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No ConnectionManager configured. 
Closing connection."); + endpoint_cleanup(error); + return; + } + // TODO(yashykt): Maybe combine the following two arg modifiers into a + // single one. + // Make a copy of the args so as to avoid destroying the original. + args = grpc_channel_args_copy(args); + absl::StatusOr args_result = + connection_manager->UpdateChannelArgsForConnection(args, tcp); + if (!args_result.ok()) { + gpr_log(GPR_DEBUG, "Closing connection: %s", + args_result.status().ToString().c_str()); + endpoint_cleanup( + GRPC_ERROR_CREATE_FROM_CPP_STRING(args_result.status().ToString())); + return; + } + grpc_error_handle error = GRPC_ERROR_NONE; + args = self->args_modifier_(*args_result, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_DEBUG, "Closing connection: %s", + grpc_error_std_string(error).c_str()); + endpoint_cleanup(error); + grpc_channel_args_destroy(args); + return; + } + args_to_destroy = args; + } + grpc_resource_user* channel_resource_user = grpc_resource_user_create( + self->resource_quota_, + absl::StrCat(grpc_endpoint_get_peer(tcp), ":server_channel")); + auto connection = MakeOrphanable( + accepting_pollset, acceptor, args, channel_resource_user); + // We no longer own acceptor + acceptor = nullptr; + // Hold a ref to connection to allow starting handshake outside the + // critical region + RefCountedPtr connection_ref = connection->Ref(); + RefCountedPtr listener_ref; + { + MutexLock lock(&self->mu_); + // Shutdown the the connection if listener's stopped serving. + if (!self->shutdown_ && self->is_serving_) { + if (!grpc_resource_user_safe_alloc(channel_resource_user, + GRPC_RESOURCE_QUOTA_CHANNEL_SIZE)) { + gpr_log( + GPR_INFO, + "Memory quota exhausted, rejecting connection, no handshaking."); + } else { + // This ref needs to be taken in the critical region after having made + // sure that the listener has not been Orphaned, so as to avoid + // heap-use-after-free issues where `Ref()` is invoked when the ref of + // tcp_server_ has already reached 0. (Ref() implementation of + // Chttp2ServerListener is grpc_tcp_server_ref().) + listener_ref = self->Ref(); + self->connections_.emplace(connection.get(), std::move(connection)); + } + } + } + if (connection != nullptr) { + endpoint_cleanup(GRPC_ERROR_NONE); + } else { + connection_ref->Start(std::move(listener_ref), tcp, args); + } + grpc_channel_args_destroy(args_to_destroy); +} + +void Chttp2ServerListener::TcpServerShutdownComplete(void* arg, + grpc_error_handle error) { + Chttp2ServerListener* self = static_cast(arg); + self->channelz_listen_socket_.reset(); + GRPC_ERROR_UNREF(error); + delete self; +} + +/* Server callback: destroy the tcp listener (so we don't generate further + callbacks) */ +void Chttp2ServerListener::Orphan() { + // Cancel the watch before shutting down so as to avoid holding a ref to the + // listener in the watcher. + if (config_fetcher_watcher_ != nullptr) { + server_->config_fetcher()->CancelWatch(config_fetcher_watcher_); + } + std::map> connections; + grpc_tcp_server* tcp_server; + { + MutexLock lock(&mu_); + shutdown_ = true; + is_serving_ = false; + // Orphan the connections so that they can start cleaning up. + connections = std::move(connections_); + // If the listener is currently set to be serving but has not been started + // yet, it means that `grpc_tcp_server_start` is in progress. Wait for the + // operation to finish to avoid causing races. 
+ while (is_serving_ && !started_) { + started_cv_.Wait(&mu_); + } + tcp_server = tcp_server_; + } + grpc_tcp_server_shutdown_listeners(tcp_server); + grpc_tcp_server_unref(tcp_server); +} + +} // namespace + +// +// Chttp2ServerAddPort() +// + +grpc_error_handle Chttp2ServerAddPort(Server* server, const char* addr, + grpc_channel_args* args, + Chttp2ServerArgsModifier args_modifier, + int* port_num) { + if (strncmp(addr, "external:", 9) == 0) { + return grpc_core::Chttp2ServerListener::CreateWithAcceptor( + server, addr, args, args_modifier); + } + *port_num = -1; + grpc_resolved_addresses* resolved = nullptr; + std::vector error_list; + // Using lambda to avoid use of goto. + grpc_error_handle error = [&]() { + grpc_error_handle error = GRPC_ERROR_NONE; + if (absl::StartsWith(addr, kUnixUriPrefix)) { + error = grpc_resolve_unix_domain_address( + addr + sizeof(kUnixUriPrefix) - 1, &resolved); + } else if (absl::StartsWith(addr, kUnixAbstractUriPrefix)) { + error = grpc_resolve_unix_abstract_domain_address( + addr + sizeof(kUnixAbstractUriPrefix) - 1, &resolved); + } else { + error = grpc_blocking_resolve_address(addr, "https", &resolved); + } + if (error != GRPC_ERROR_NONE) return error; + // Create a listener for each resolved address. + for (size_t i = 0; i < resolved->naddrs; i++) { + // If address has a wildcard port (0), use the same port as a previous + // listener. + if (*port_num != -1 && grpc_sockaddr_get_port(&resolved->addrs[i]) == 0) { + grpc_sockaddr_set_port(&resolved->addrs[i], *port_num); + } + int port_temp = -1; + error = grpc_core::Chttp2ServerListener::Create( + server, &resolved->addrs[i], grpc_channel_args_copy(args), + args_modifier, &port_temp); + if (error != GRPC_ERROR_NONE) { + error_list.push_back(error); + } else { + if (*port_num == -1) { + *port_num = port_temp; + } else { + GPR_ASSERT(*port_num == port_temp); + } + } + } + if (error_list.size() == resolved->naddrs) { + std::string msg = + absl::StrFormat("No address added out of total %" PRIuPTR " resolved", + resolved->naddrs); + return GRPC_ERROR_CREATE_REFERENCING_FROM_COPIED_STRING( + msg.c_str(), error_list.data(), error_list.size()); + } else if (!error_list.empty()) { + std::string msg = absl::StrFormat( + "Only %" PRIuPTR " addresses added out of total %" PRIuPTR + " resolved", + resolved->naddrs - error_list.size(), resolved->naddrs); + error = GRPC_ERROR_CREATE_REFERENCING_FROM_COPIED_STRING( + msg.c_str(), error_list.data(), error_list.size()); + gpr_log(GPR_INFO, "WARNING: %s", grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + // we managed to bind some addresses: continue without error + } + return GRPC_ERROR_NONE; + }(); // lambda end + for (grpc_error_handle error : error_list) { + GRPC_ERROR_UNREF(error); + } + grpc_channel_args_destroy(args); + if (resolved != nullptr) { + grpc_resolved_addresses_destroy(resolved); + } + if (error != GRPC_ERROR_NONE) *port_num = 0; + return error; +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/server/insecure/server_chttp2.cc b/src/core/ext/transport/chttp2/server/insecure/server_chttp2.cc new file mode 100644 index 00000000..56588a37 --- /dev/null +++ b/src/core/ext/transport/chttp2/server/insecure/server_chttp2.cc @@ -0,0 +1,53 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
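The listener code above derives its handshake deadline from GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS (defaulting to 120 seconds in GetConnectionDeadline) and, in Chttp2ServerAddPort, accepts "unix:" and "unix-abstract:" targets in addition to resolvable host:port strings. The following is a minimal sketch of exercising both knobs from the C++ surface; the socket path and the 10-second timeout are illustrative assumptions, not values taken from this change.

#include <grpcpp/grpcpp.h>
#include <grpcpp/security/server_credentials.h>

std::unique_ptr<grpc::Server> BuildServer() {
  grpc::ServerBuilder builder;
  // Fail handshakes that have not delivered HTTP/2 settings within 10s
  // (the transport-level default used above is 120s).
  builder.AddChannelArgument(GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS, 10000);
  // A unix-domain target; Chttp2ServerAddPort strips the "unix:" prefix and
  // resolves the remaining path.
  builder.AddListeningPort("unix:/tmp/demo.sock",
                           grpc::InsecureServerCredentials());
  return builder.BuildAndStart();
}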
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "src/core/ext/transport/chttp2/server/chttp2_server.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/server.h" + +namespace { + +grpc_channel_args* ModifyArgsForConnection(grpc_channel_args* args, + grpc_error_handle* /*error*/) { + return args; +} + +} // namespace + +int grpc_server_add_insecure_http2_port(grpc_server* server, const char* addr) { + grpc_core::ExecCtx exec_ctx; + int port_num = 0; + GRPC_API_TRACE("grpc_server_add_insecure_http2_port(server=%p, addr=%s)", 2, + (server, addr)); + grpc_error_handle err = grpc_core::Chttp2ServerAddPort( + server->core_server.get(), addr, + grpc_channel_args_copy(server->core_server->channel_args()), + ModifyArgsForConnection, &port_num); + if (err != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "%s", grpc_error_std_string(err).c_str()); + + GRPC_ERROR_UNREF(err); + } + return port_num; +} diff --git a/src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.cc b/src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.cc new file mode 100644 index 00000000..e47d8af4 --- /dev/null +++ b/src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.cc @@ -0,0 +1,83 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
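For completeness, a minimal sketch of calling grpc_server_add_insecure_http2_port directly through the C core surface; the wildcard address literal is an assumption for illustration. As the implementation above shows, the function returns the bound port number, or 0 on failure, and the port must be added before the server is started.

#include <grpc/grpc.h>
#include <grpc/support/log.h>

void AddInsecurePort(grpc_server* server) {
  // ":0" asks the OS for any free port; the return value is the port actually
  // bound, or 0 if binding failed.
  int port = grpc_server_add_insecure_http2_port(server, "0.0.0.0:0");
  gpr_log(GPR_INFO, "bound insecure HTTP/2 port: %d", port);
}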
+ * + */ + +#include + +#include +#include +#include + +#ifdef GPR_SUPPORT_CHANNELS_FROM_FD + +#include "absl/strings/str_cat.h" + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" + +void grpc_server_add_insecure_channel_from_fd(grpc_server* server, + void* reserved, int fd) { + GPR_ASSERT(reserved == nullptr); + + grpc_core::ExecCtx exec_ctx; + grpc_core::Server* core_server = server->core_server.get(); + + const grpc_channel_args* server_args = core_server->channel_args(); + std::string name = absl::StrCat("fd:", fd); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create(name.c_str()); + grpc_endpoint* server_endpoint = grpc_tcp_create( + grpc_fd_create(fd, name.c_str(), true), server_args, name.c_str(), + grpc_slice_allocator_create(resource_quota, name, server_args)); + grpc_transport* transport = grpc_create_chttp2_transport( + server_args, server_endpoint, false /* is_client */, + grpc_resource_user_create(resource_quota, + absl::StrCat(name, ":transport"))); + grpc_error_handle error = core_server->SetupTransport( + transport, nullptr, server_args, nullptr, + grpc_resource_user_create(resource_quota, + absl::StrCat(name, ":channel"))); + grpc_resource_quota_unref_internal(resource_quota); + if (error == GRPC_ERROR_NONE) { + for (grpc_pollset* pollset : core_server->pollsets()) { + grpc_endpoint_add_to_pollset(server_endpoint, pollset); + } + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } else { + gpr_log(GPR_ERROR, "Failed to create channel: %s", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + grpc_transport_destroy(transport); + } +} + +#else // !GPR_SUPPORT_CHANNELS_FROM_FD + +void grpc_server_add_insecure_channel_from_fd(grpc_server* /* server */, + void* /* reserved */, + int /* fd */) { + GPR_ASSERT(0); +} + +#endif // GPR_SUPPORT_CHANNELS_FROM_FD diff --git a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc new file mode 100644 index 00000000..675e554a --- /dev/null +++ b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc @@ -0,0 +1,125 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
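A minimal sketch of handing an already-connected file descriptor to a server via grpc_server_add_insecure_channel_from_fd above. It assumes a POSIX build with GPR_SUPPORT_CHANNELS_FROM_FD, that grpc_init() has already been called, and that `fd` was obtained elsewhere (for example from accept()); the server is started first so that its pollsets exist when the endpoint is registered.

#include <grpc/grpc.h>
#include <grpc/grpc_posix.h>

void ServeConnectedFd(int fd) {
  grpc_server* server = grpc_server_create(/*args=*/nullptr, /*reserved=*/nullptr);
  grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr);
  grpc_server_register_completion_queue(server, cq, /*reserved=*/nullptr);
  grpc_server_start(server);
  // Wrap the connected fd as an insecure HTTP/2 channel on this server.
  grpc_server_add_insecure_channel_from_fd(server, /*reserved=*/nullptr, fd);
}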
+ * + */ + +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/ext/transport/chttp2/server/chttp2_server.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/server.h" + +namespace { + +grpc_channel_args* ModifyArgsForConnection(grpc_channel_args* args, + grpc_error_handle* error) { + grpc_server_credentials* server_credentials = + grpc_find_server_credentials_in_args(args); + if (server_credentials == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not find server credentials"); + return args; + } + auto security_connector = server_credentials->create_security_connector(args); + if (security_connector == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unable to create secure server with credentials of type ", + server_credentials->type())); + return args; + } + grpc_arg arg_to_add = + grpc_security_connector_to_arg(security_connector.get()); + grpc_channel_args* new_args = + grpc_channel_args_copy_and_add(args, &arg_to_add, 1); + grpc_channel_args_destroy(args); + return new_args; +} + +} // namespace + +int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr, + grpc_server_credentials* creds) { + grpc_core::ExecCtx exec_ctx; + grpc_error_handle err = GRPC_ERROR_NONE; + grpc_core::RefCountedPtr sc; + int port_num = 0; + grpc_channel_args* args = nullptr; + GRPC_API_TRACE( + "grpc_server_add_secure_http2_port(" + "server=%p, addr=%s, creds=%p)", + 3, (server, addr, creds)); + // Create security context. + if (creds == nullptr) { + err = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No credentials specified for secure server port (creds==NULL)"); + goto done; + } + // TODO(yashykt): Ideally, we would not want to have different behavior here + // based on whether a config fetcher is configured or not. Currently, we have + // a feature for SSL credentials reloading with an application callback that + // assumes that there is a single security connector. If we delay the creation + // of the security connector to after the creation of the listener(s), we + // would have potentially multiple security connectors which breaks the + // assumption for SSL creds reloading. When the API for SSL creds reloading is + // rewritten, we would be able to make this workaround go away by removing + // that assumption. As an immediate drawback of this workaround, config + // fetchers need to be registered before adding ports to the server. + if (server->core_server->config_fetcher() != nullptr) { + // Create channel args. 
+ grpc_arg arg_to_add = grpc_server_credentials_to_arg(creds); + args = grpc_channel_args_copy_and_add(server->core_server->channel_args(), + &arg_to_add, 1); + } else { + sc = creds->create_security_connector(nullptr); + if (sc == nullptr) { + err = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Unable to create secure server with credentials of type ", + creds->type())); + goto done; + } + grpc_arg args_to_add[2]; + args_to_add[0] = grpc_server_credentials_to_arg(creds); + args_to_add[1] = grpc_security_connector_to_arg(sc.get()); + args = grpc_channel_args_copy_and_add(server->core_server->channel_args(), + args_to_add, + GPR_ARRAY_SIZE(args_to_add)); + } + // Add server port. + err = grpc_core::Chttp2ServerAddPort(server->core_server.get(), addr, args, + ModifyArgsForConnection, &port_num); +done: + sc.reset(DEBUG_LOCATION, "server"); + if (err != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "%s", grpc_error_std_string(err).c_str()); + + GRPC_ERROR_UNREF(err); + } + return port_num; +} diff --git a/src/core/ext/transport/chttp2/transport/bin_decoder.cc b/src/core/ext/transport/chttp2/transport/bin_decoder.cc new file mode 100644 index 00000000..4be4cc25 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/bin_decoder.cc @@ -0,0 +1,252 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
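grpc_server_add_secure_http2_port above is the C-core entry point that the C++ ServerBuilder reaches when a secure listening port is added through server credentials. A minimal sketch using SslServerCredentials follows; server_key_pem and server_cert_pem are placeholder PEM strings supplied by the application, not values from this change.

#include <grpcpp/grpcpp.h>
#include <grpcpp/security/server_credentials.h>

std::shared_ptr<grpc::ServerCredentials> MakeSslCreds(
    const std::string& server_key_pem, const std::string& server_cert_pem) {
  grpc::SslServerCredentialsOptions ssl_opts;
  // One key/cert pair; the struct fields are {private_key, cert_chain}.
  ssl_opts.pem_key_cert_pairs.push_back({server_key_pem, server_cert_pem});
  return grpc::SslServerCredentials(ssl_opts);
}

// Usage with a ServerBuilder (address is illustrative):
//   builder.AddListeningPort("0.0.0.0:50051", MakeSslCreds(key, cert));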
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/bin_decoder.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +static uint8_t decode_table[] = { + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 62, 0x40, 0x40, 0x40, 63, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40}; + +static const uint8_t tail_xtra[4] = {0, 0, 1, 2}; + +static bool input_is_valid(const uint8_t* input_ptr, size_t length) { + size_t i; + + for (i = 0; i < length; ++i) { + if (GPR_UNLIKELY((decode_table[input_ptr[i]] & 0xC0) != 0)) { + gpr_log(GPR_ERROR, + "Base64 decoding failed, invalid character '%c' in base64 " + "input.\n", + static_cast(*input_ptr)); + return false; + } + } + return true; +} + +#define COMPOSE_OUTPUT_BYTE_0(input_ptr) \ + (uint8_t)((decode_table[(input_ptr)[0]] << 2) | \ + (decode_table[(input_ptr)[1]] >> 4)) + +#define COMPOSE_OUTPUT_BYTE_1(input_ptr) \ + (uint8_t)((decode_table[(input_ptr)[1]] << 4) | \ + (decode_table[(input_ptr)[2]] >> 2)) + +#define COMPOSE_OUTPUT_BYTE_2(input_ptr) \ + (uint8_t)((decode_table[(input_ptr)[2]] << 6) | decode_table[(input_ptr)[3]]) + +// By RFC 4648, if the length of the encoded string without padding is 4n+r, +// the length of decoded string is: 1) 3n if r = 0, 2) 3n + 1 if r = 2, 3, or +// 3) invalid if r = 1. +size_t grpc_chttp2_base64_infer_length_after_decode(const grpc_slice& slice) { + size_t len = GRPC_SLICE_LENGTH(slice); + const uint8_t* bytes = GRPC_SLICE_START_PTR(slice); + while (len > 0 && bytes[len - 1] == '=') { + len--; + } + if (GPR_UNLIKELY(GRPC_SLICE_LENGTH(slice) - len > 2)) { + gpr_log(GPR_ERROR, + "Base64 decoding failed. Input has more than 2 paddings."); + return 0; + } + size_t tuples = len / 4; + size_t tail_case = len % 4; + if (GPR_UNLIKELY(tail_case == 1)) { + gpr_log(GPR_ERROR, + "Base64 decoding failed. 
Input has a length of %zu (without" + " padding), which is invalid.\n", + len); + return 0; + } + return tuples * 3 + tail_xtra[tail_case]; +} + +bool grpc_base64_decode_partial(struct grpc_base64_decode_context* ctx) { + size_t input_tail; + + if (ctx->input_cur > ctx->input_end || ctx->output_cur > ctx->output_end) { + return false; + } + + // Process a block of 4 input characters and 3 output bytes + while (ctx->input_end >= ctx->input_cur + 4 && + ctx->output_end >= ctx->output_cur + 3) { + if (!input_is_valid(ctx->input_cur, 4)) return false; + ctx->output_cur[0] = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + ctx->output_cur[1] = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur); + ctx->output_cur[2] = COMPOSE_OUTPUT_BYTE_2(ctx->input_cur); + ctx->output_cur += 3; + ctx->input_cur += 4; + } + + // Process the tail of input data + input_tail = static_cast(ctx->input_end - ctx->input_cur); + if (input_tail == 4) { + // Process the input data with pad chars + if (ctx->input_cur[3] == '=') { + if (ctx->input_cur[2] == '=' && ctx->output_end >= ctx->output_cur + 1) { + if (!input_is_valid(ctx->input_cur, 2)) return false; + *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + ctx->input_cur += 4; + } else if (ctx->output_end >= ctx->output_cur + 2) { + if (!input_is_valid(ctx->input_cur, 3)) return false; + *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur); + ; + ctx->input_cur += 4; + } + } + + } else if (ctx->contains_tail && input_tail > 1) { + // Process the input data without pad chars, but constains_tail is set + if (ctx->output_end >= ctx->output_cur + tail_xtra[input_tail]) { + if (!input_is_valid(ctx->input_cur, input_tail)) return false; + switch (input_tail) { + case 3: + ctx->output_cur[1] = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur); + ABSL_FALLTHROUGH_INTENDED; + case 2: + ctx->output_cur[0] = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + } + ctx->output_cur += tail_xtra[input_tail]; + ctx->input_cur += input_tail; + } + } + + return true; +} + +grpc_slice grpc_chttp2_base64_decode(const grpc_slice& input) { + size_t input_length = GRPC_SLICE_LENGTH(input); + size_t output_length = input_length / 4 * 3; + struct grpc_base64_decode_context ctx; + grpc_slice output; + + if (GPR_UNLIKELY(input_length % 4 != 0)) { + gpr_log(GPR_ERROR, + "Base64 decoding failed, input of " + "grpc_chttp2_base64_decode has a length of %d, which is not a " + "multiple of 4.\n", + static_cast(input_length)); + return grpc_empty_slice(); + } + + if (input_length > 0) { + const uint8_t* input_end = GRPC_SLICE_END_PTR(input); + if (*(--input_end) == '=') { + output_length--; + if (*(--input_end) == '=') { + output_length--; + } + } + } + output = GRPC_SLICE_MALLOC(output_length); + + ctx.input_cur = GRPC_SLICE_START_PTR(input); + ctx.input_end = GRPC_SLICE_END_PTR(input); + ctx.output_cur = GRPC_SLICE_START_PTR(output); + ctx.output_end = GRPC_SLICE_END_PTR(output); + ctx.contains_tail = false; + + if (GPR_UNLIKELY(!grpc_base64_decode_partial(&ctx))) { + char* s = grpc_slice_to_c_string(input); + gpr_log(GPR_ERROR, "Base64 decoding failed, input string:\n%s\n", s); + gpr_free(s); + grpc_slice_unref_internal(output); + return grpc_empty_slice(); + } + GPR_ASSERT(ctx.output_cur == GRPC_SLICE_END_PTR(output)); + GPR_ASSERT(ctx.input_cur == GRPC_SLICE_END_PTR(input)); + return output; +} + +grpc_slice grpc_chttp2_base64_decode_with_length(const grpc_slice& input, + size_t output_length) { + size_t input_length = GRPC_SLICE_LENGTH(input); + 
grpc_slice output = GRPC_SLICE_MALLOC(output_length); + struct grpc_base64_decode_context ctx; + + // The length of a base64 string cannot be 4 * n + 1 + if (GPR_UNLIKELY(input_length % 4 == 1)) { + gpr_log(GPR_ERROR, + "Base64 decoding failed, input of " + "grpc_chttp2_base64_decode_with_length has a length of %d, which " + "has a tail of 1 byte.\n", + static_cast(input_length)); + grpc_slice_unref_internal(output); + return grpc_empty_slice(); + } + + if (GPR_UNLIKELY(output_length > + input_length / 4 * 3 + tail_xtra[input_length % 4])) { + gpr_log( + GPR_ERROR, + "Base64 decoding failed, output_length %d is longer " + "than the max possible output length %d.\n", + static_cast(output_length), + static_cast(input_length / 4 * 3 + tail_xtra[input_length % 4])); + grpc_slice_unref_internal(output); + return grpc_empty_slice(); + } + + ctx.input_cur = GRPC_SLICE_START_PTR(input); + ctx.input_end = GRPC_SLICE_END_PTR(input); + ctx.output_cur = GRPC_SLICE_START_PTR(output); + ctx.output_end = GRPC_SLICE_END_PTR(output); + ctx.contains_tail = true; + + if (GPR_UNLIKELY(!grpc_base64_decode_partial(&ctx))) { + char* s = grpc_slice_to_c_string(input); + gpr_log(GPR_ERROR, "Base64 decoding failed, input string:\n%s\n", s); + gpr_free(s); + grpc_slice_unref_internal(output); + return grpc_empty_slice(); + } + GPR_ASSERT(ctx.output_cur == GRPC_SLICE_END_PTR(output)); + GPR_ASSERT(ctx.input_cur <= GRPC_SLICE_END_PTR(input)); + return output; +} diff --git a/src/core/ext/transport/chttp2/transport/bin_encoder.cc b/src/core/ext/transport/chttp2/transport/bin_encoder.cc new file mode 100644 index 00000000..cd92f726 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/bin_encoder.cc @@ -0,0 +1,231 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
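To make the length inference above concrete: for "Zm9vYg==" the two '=' pads are stripped, leaving 6 = 4*1 + 2 characters, so the decoded length is 3*1 + 1 = 4 bytes ("foob"); a remainder of 3 would add 2 bytes instead, and a remainder of 1 is rejected. A minimal sketch exercising the decoder's internal helpers declared in bin_decoder.h; the input literal is only an example.

#include <grpc/slice.h>
#include <grpc/support/log.h>

#include "src/core/ext/transport/chttp2/transport/bin_decoder.h"

void DecodeExample() {
  grpc_slice in = grpc_slice_from_static_string("Zm9vYg==");  // 4*1 + 2 chars unpadded
  GPR_ASSERT(grpc_chttp2_base64_infer_length_after_decode(in) == 4);
  grpc_slice out = grpc_chttp2_base64_decode(in);  // 4-byte slice, empty slice on error
  GPR_ASSERT(GRPC_SLICE_LENGTH(out) == 4);
  grpc_slice_unref(out);
  grpc_slice_unref(in);
}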
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/bin_encoder.h" + +#include + +#include + +#include "src/core/ext/transport/chttp2/transport/huffsyms.h" + +static const char alphabet[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +struct b64_huff_sym { + uint16_t bits; + uint8_t length; +}; +static const b64_huff_sym huff_alphabet[64] = { + {0x21, 6}, {0x5d, 7}, {0x5e, 7}, {0x5f, 7}, {0x60, 7}, {0x61, 7}, + {0x62, 7}, {0x63, 7}, {0x64, 7}, {0x65, 7}, {0x66, 7}, {0x67, 7}, + {0x68, 7}, {0x69, 7}, {0x6a, 7}, {0x6b, 7}, {0x6c, 7}, {0x6d, 7}, + {0x6e, 7}, {0x6f, 7}, {0x70, 7}, {0x71, 7}, {0x72, 7}, {0xfc, 8}, + {0x73, 7}, {0xfd, 8}, {0x3, 5}, {0x23, 6}, {0x4, 5}, {0x24, 6}, + {0x5, 5}, {0x25, 6}, {0x26, 6}, {0x27, 6}, {0x6, 5}, {0x74, 7}, + {0x75, 7}, {0x28, 6}, {0x29, 6}, {0x2a, 6}, {0x7, 5}, {0x2b, 6}, + {0x76, 7}, {0x2c, 6}, {0x8, 5}, {0x9, 5}, {0x2d, 6}, {0x77, 7}, + {0x78, 7}, {0x79, 7}, {0x7a, 7}, {0x7b, 7}, {0x0, 5}, {0x1, 5}, + {0x2, 5}, {0x19, 6}, {0x1a, 6}, {0x1b, 6}, {0x1c, 6}, {0x1d, 6}, + {0x1e, 6}, {0x1f, 6}, {0x7fb, 11}, {0x18, 6}}; + +static const uint8_t tail_xtra[3] = {0, 2, 3}; + +grpc_slice grpc_chttp2_base64_encode(const grpc_slice& input) { + size_t input_length = GRPC_SLICE_LENGTH(input); + size_t input_triplets = input_length / 3; + size_t tail_case = input_length % 3; + size_t output_length = input_triplets * 4 + tail_xtra[tail_case]; + grpc_slice output = GRPC_SLICE_MALLOC(output_length); + const uint8_t* in = GRPC_SLICE_START_PTR(input); + char* out = reinterpret_cast GRPC_SLICE_START_PTR(output); + size_t i; + + /* encode full triplets */ + for (i = 0; i < input_triplets; i++) { + out[0] = alphabet[in[0] >> 2]; + out[1] = alphabet[((in[0] & 0x3) << 4) | (in[1] >> 4)]; + out[2] = alphabet[((in[1] & 0xf) << 2) | (in[2] >> 6)]; + out[3] = alphabet[in[2] & 0x3f]; + out += 4; + in += 3; + } + + /* encode the remaining bytes */ + switch (tail_case) { + case 0: + break; + case 1: + out[0] = alphabet[in[0] >> 2]; + out[1] = alphabet[(in[0] & 0x3) << 4]; + out += 2; + in += 1; + break; + case 2: + out[0] = alphabet[in[0] >> 2]; + out[1] = alphabet[((in[0] & 0x3) << 4) | (in[1] >> 4)]; + out[2] = alphabet[(in[1] & 0xf) << 2]; + out += 3; + in += 2; + break; + } + + GPR_ASSERT(out == (char*)GRPC_SLICE_END_PTR(output)); + GPR_ASSERT(in == GRPC_SLICE_END_PTR(input)); + return output; +} + +grpc_slice grpc_chttp2_huffman_compress(const grpc_slice& input) { + size_t nbits; + const uint8_t* in; + uint8_t* out; + grpc_slice output; + uint32_t temp = 0; + uint32_t temp_length = 0; + + nbits = 0; + for (in = GRPC_SLICE_START_PTR(input); in != GRPC_SLICE_END_PTR(input); + ++in) { + nbits += grpc_chttp2_huffsyms[*in].length; + } + + output = GRPC_SLICE_MALLOC(nbits / 8 + (nbits % 8 != 0)); + out = GRPC_SLICE_START_PTR(output); + for (in = GRPC_SLICE_START_PTR(input); in != GRPC_SLICE_END_PTR(input); + ++in) { + int sym = *in; + temp <<= grpc_chttp2_huffsyms[sym].length; + temp |= grpc_chttp2_huffsyms[sym].bits; + temp_length += grpc_chttp2_huffsyms[sym].length; + + while (temp_length > 8) { + temp_length -= 8; + *out++ = static_cast(temp >> temp_length); + } + } + + if (temp_length) { + /* NB: the following integer arithmetic operation needs to be in its + * expanded form due to the "integral promotion" performed (see section + * 3.2.1.1 of the C89 draft standard). 
A cast to the smaller container type + * is then required to avoid the compiler warning */ + *out++ = + static_cast(static_cast(temp << (8u - temp_length)) | + static_cast(0xffu >> temp_length)); + } + + GPR_ASSERT(out == GRPC_SLICE_END_PTR(output)); + + return output; +} + +struct huff_out { + uint32_t temp; + uint32_t temp_length; + uint8_t* out; +}; +static void enc_flush_some(huff_out* out) { + while (out->temp_length > 8) { + out->temp_length -= 8; + *out->out++ = static_cast(out->temp >> out->temp_length); + } +} + +static void enc_add2(huff_out* out, uint8_t a, uint8_t b) { + b64_huff_sym sa = huff_alphabet[a]; + b64_huff_sym sb = huff_alphabet[b]; + out->temp = (out->temp << (sa.length + sb.length)) | + (static_cast(sa.bits) << sb.length) | sb.bits; + out->temp_length += + static_cast(sa.length) + static_cast(sb.length); + enc_flush_some(out); +} + +static void enc_add1(huff_out* out, uint8_t a) { + b64_huff_sym sa = huff_alphabet[a]; + out->temp = (out->temp << sa.length) | sa.bits; + out->temp_length += sa.length; + enc_flush_some(out); +} + +grpc_slice grpc_chttp2_base64_encode_and_huffman_compress( + const grpc_slice& input) { + size_t input_length = GRPC_SLICE_LENGTH(input); + size_t input_triplets = input_length / 3; + size_t tail_case = input_length % 3; + size_t output_syms = input_triplets * 4 + tail_xtra[tail_case]; + size_t max_output_bits = 11 * output_syms; + size_t max_output_length = max_output_bits / 8 + (max_output_bits % 8 != 0); + grpc_slice output = GRPC_SLICE_MALLOC(max_output_length); + const uint8_t* in = GRPC_SLICE_START_PTR(input); + uint8_t* start_out = GRPC_SLICE_START_PTR(output); + huff_out out; + size_t i; + + out.temp = 0; + out.temp_length = 0; + out.out = start_out; + + /* encode full triplets */ + for (i = 0; i < input_triplets; i++) { + const uint8_t low_to_high = static_cast((in[0] & 0x3) << 4); + const uint8_t high_to_low = in[1] >> 4; + enc_add2(&out, in[0] >> 2, low_to_high | high_to_low); + + const uint8_t a = static_cast((in[1] & 0xf) << 2); + const uint8_t b = (in[2] >> 6); + enc_add2(&out, a | b, in[2] & 0x3f); + in += 3; + } + + /* encode the remaining bytes */ + switch (tail_case) { + case 0: + break; + case 1: + enc_add2(&out, in[0] >> 2, static_cast((in[0] & 0x3) << 4)); + in += 1; + break; + case 2: { + const uint8_t low_to_high = static_cast((in[0] & 0x3) << 4); + const uint8_t high_to_low = in[1] >> 4; + enc_add2(&out, in[0] >> 2, low_to_high | high_to_low); + enc_add1(&out, static_cast((in[1] & 0xf) << 2)); + in += 2; + break; + } + } + + if (out.temp_length) { + /* NB: the following integer arithmetic operation needs to be in its + * expanded form due to the "integral promotion" performed (see section + * 3.2.1.1 of the C89 draft standard). A cast to the smaller container type + * is then required to avoid the compiler warning */ + *out.out++ = static_cast( + static_cast(out.temp << (8u - out.temp_length)) | + static_cast(0xffu >> out.temp_length)); + } + + GPR_ASSERT(out.out <= GRPC_SLICE_END_PTR(output)); + GRPC_SLICE_SET_LENGTH(output, out.out - start_out); + + GPR_ASSERT(in == GRPC_SLICE_END_PTR(input)); + return output; +} diff --git a/src/core/ext/transport/chttp2/transport/chttp2_plugin.cc b/src/core/ext/transport/chttp2/transport/chttp2_plugin.cc new file mode 100644 index 00000000..ac13d73d --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/chttp2_plugin.cc @@ -0,0 +1,37 @@ +/* + * + * Copyright 2015 gRPC authors. 
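The sizing logic in the encoder above can be checked with a small example: a 5-byte input yields one full triplet plus a 2-byte tail, i.e. 4 + tail_xtra[2] = 7 base64 symbols with no padding; since no symbol in the base64 Huffman alphabet above is longer than 11 bits, the combined encoder reserves at most 7 * 11 = 77 bits, i.e. 10 bytes, and trims the slice to the bits actually written. A minimal sketch, with an illustrative input literal:

#include <grpc/slice.h>
#include <grpc/support/log.h>

#include "src/core/ext/transport/chttp2/transport/bin_encoder.h"

void EncodeExample() {
  grpc_slice in = grpc_slice_from_static_string("hello");     // 5 bytes
  grpc_slice b64 = grpc_chttp2_base64_encode(in);             // "aGVsbG8", 7 symbols
  GPR_ASSERT(GRPC_SLICE_LENGTH(b64) == 7);
  grpc_slice both = grpc_chttp2_base64_encode_and_huffman_compress(in);
  GPR_ASSERT(GRPC_SLICE_LENGTH(both) <= 10);                  // at most ceil(77 / 8) bytes
  grpc_slice_unref(b64);
  grpc_slice_unref(both);
  grpc_slice_unref(in);
}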
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/global_config.h" +#include "src/core/lib/transport/metadata.h" + +GPR_GLOBAL_CONFIG_DEFINE_BOOL( + grpc_experimental_disable_flow_control, false, + "If set, flow control will be effectively disabled. Max out all values and " + "assume the remote peer does the same. Thus we can ignore any flow control " + "bookkeeping, error checking, and decision making"); + +void grpc_chttp2_plugin_init(void) { + g_flow_control_enabled = + !GPR_GLOBAL_CONFIG_GET(grpc_experimental_disable_flow_control); +} + +void grpc_chttp2_plugin_shutdown(void) {} diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc new file mode 100644 index 00000000..adada282 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -0,0 +1,3351 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
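The GPR_GLOBAL_CONFIG_DEFINE_BOOL above is resolved from the process environment by default, so the flow-control kill switch can be flipped without code changes. A minimal sketch, assuming the usual mapping of the config name to the upper-cased environment variable GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL and a POSIX setenv:

#include <cstdlib>

#include <grpc/grpc.h>

int main() {
  // Must be set before grpc_init() runs the chttp2 plugin initializer above,
  // which is where g_flow_control_enabled is latched.
  setenv("GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL", "true", /*overwrite=*/1);
  grpc_init();
  // ... use gRPC with flow control effectively disabled ...
  grpc_shutdown();
  return 0;
}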
+// + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/context_list.h" +#include "src/core/ext/transport/chttp2/transport/frame_data.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/ext/transport/chttp2/transport/varint.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/stream_compression.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/http/parser.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/http2_errors.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_conversion.h" +#include "src/core/lib/transport/timeout_encoding.h" +#include "src/core/lib/transport/transport.h" +#include "src/core/lib/transport/transport_impl.h" +#include "src/core/lib/uri/uri_parser.h" + +#define DEFAULT_CONNECTION_WINDOW_TARGET (1024 * 1024) +#define MAX_WINDOW 0x7fffffffu +#define MAX_WRITE_BUFFER_SIZE (64 * 1024 * 1024) +#define DEFAULT_MAX_HEADER_LIST_SIZE (8 * 1024) + +#define DEFAULT_CLIENT_KEEPALIVE_TIME_MS INT_MAX +#define DEFAULT_CLIENT_KEEPALIVE_TIMEOUT_MS 20000 /* 20 seconds */ +#define DEFAULT_SERVER_KEEPALIVE_TIME_MS 7200000 /* 2 hours */ +#define DEFAULT_SERVER_KEEPALIVE_TIMEOUT_MS 20000 /* 20 seconds */ +#define DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS false +#define KEEPALIVE_TIME_BACKOFF_MULTIPLIER 2 + +#define DEFAULT_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS 300000 /* 5 minutes */ +#define DEFAULT_MAX_PINGS_BETWEEN_DATA 2 +#define DEFAULT_MAX_PING_STRIKES 2 + +#define DEFAULT_MAX_PENDING_INDUCED_FRAMES 10000 + +static int g_default_client_keepalive_time_ms = + DEFAULT_CLIENT_KEEPALIVE_TIME_MS; +static int g_default_client_keepalive_timeout_ms = + DEFAULT_CLIENT_KEEPALIVE_TIMEOUT_MS; +static int g_default_server_keepalive_time_ms = + DEFAULT_SERVER_KEEPALIVE_TIME_MS; +static int g_default_server_keepalive_timeout_ms = + DEFAULT_SERVER_KEEPALIVE_TIMEOUT_MS; +static bool g_default_client_keepalive_permit_without_calls = + DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS; +static bool g_default_server_keepalive_permit_without_calls = + DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS; + +static int g_default_min_recv_ping_interval_without_data_ms = + DEFAULT_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS; +static int g_default_max_pings_without_data = DEFAULT_MAX_PINGS_BETWEEN_DATA; +static int g_default_max_ping_strikes = DEFAULT_MAX_PING_STRIKES; + +#define MAX_CLIENT_STREAM_ID 0x7fffffffu +grpc_core::TraceFlag grpc_http_trace(false, "http"); +grpc_core::TraceFlag grpc_keepalive_trace(false, "http_keepalive"); +grpc_core::DebugOnlyTraceFlag grpc_trace_chttp2_refcount(false, + "chttp2_refcount"); + +// forward declarations of various callbacks that we'll build closures around +static void write_action_begin_locked(void* t, grpc_error_handle error); +static void write_action(void* t, grpc_error_handle error); +static void write_action_end(void* t, grpc_error_handle 
error); +static void write_action_end_locked(void* t, grpc_error_handle error); + +static void read_action(void* t, grpc_error_handle error); +static void read_action_locked(void* t, grpc_error_handle error); +static void continue_read_action_locked(grpc_chttp2_transport* t); + +static void complete_fetch(void* gs, grpc_error_handle error); +static void complete_fetch_locked(void* gs, grpc_error_handle error); +// Set a transport level setting, and push it to our peer +static void queue_setting_update(grpc_chttp2_transport* t, + grpc_chttp2_setting_id id, uint32_t value); + +static void close_from_api(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_error_handle error); + +// Start new streams that have been created if we can +static void maybe_start_some_streams(grpc_chttp2_transport* t); + +static void connectivity_state_set(grpc_chttp2_transport* t, + grpc_connectivity_state state, + const absl::Status& status, + const char* reason); + +static void benign_reclaimer(void* arg, grpc_error_handle error); +static void destructive_reclaimer(void* arg, grpc_error_handle error); +static void benign_reclaimer_locked(void* arg, grpc_error_handle error); +static void destructive_reclaimer_locked(void* arg, grpc_error_handle error); + +static void post_benign_reclaimer(grpc_chttp2_transport* t); +static void post_destructive_reclaimer(grpc_chttp2_transport* t); + +static void close_transport_locked(grpc_chttp2_transport* t, + grpc_error_handle error); +static void end_all_the_calls(grpc_chttp2_transport* t, + grpc_error_handle error); + +static void start_bdp_ping(void* tp, grpc_error_handle error); +static void finish_bdp_ping(void* tp, grpc_error_handle error); +static void start_bdp_ping_locked(void* tp, grpc_error_handle error); +static void finish_bdp_ping_locked(void* tp, grpc_error_handle error); +static void next_bdp_ping_timer_expired(void* tp, grpc_error_handle error); +static void next_bdp_ping_timer_expired_locked(void* tp, + grpc_error_handle error); + +static void cancel_pings(grpc_chttp2_transport* t, grpc_error_handle error); +static void send_ping_locked(grpc_chttp2_transport* t, + grpc_closure* on_initiate, grpc_closure* on_ack); +static void retry_initiate_ping_locked(void* tp, grpc_error_handle error); + +// keepalive-relevant functions +static void init_keepalive_ping(void* arg, grpc_error_handle error); +static void init_keepalive_ping_locked(void* arg, grpc_error_handle error); +static void start_keepalive_ping(void* arg, grpc_error_handle error); +static void finish_keepalive_ping(void* arg, grpc_error_handle error); +static void start_keepalive_ping_locked(void* arg, grpc_error_handle error); +static void finish_keepalive_ping_locked(void* arg, grpc_error_handle error); +static void keepalive_watchdog_fired(void* arg, grpc_error_handle error); +static void keepalive_watchdog_fired_locked(void* arg, grpc_error_handle error); + +static void reset_byte_stream(void* arg, grpc_error_handle error); + +// Flow control default enabled. 
Can be disabled by setting +// GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL +bool g_flow_control_enabled = true; + +namespace grpc_core { + +namespace { +TestOnlyGlobalHttp2TransportInitCallback test_only_init_callback = nullptr; +TestOnlyGlobalHttp2TransportDestructCallback test_only_destruct_callback = + nullptr; +} // namespace + +void TestOnlySetGlobalHttp2TransportInitCallback( + TestOnlyGlobalHttp2TransportInitCallback callback) { + test_only_init_callback = callback; +} + +void TestOnlySetGlobalHttp2TransportDestructCallback( + TestOnlyGlobalHttp2TransportDestructCallback callback) { + test_only_destruct_callback = callback; +} + +} // namespace grpc_core + +// +// CONSTRUCTION/DESTRUCTION/REFCOUNTING +// + +grpc_chttp2_transport::~grpc_chttp2_transport() { + size_t i; + + if (channelz_socket != nullptr) { + channelz_socket.reset(); + } + + grpc_endpoint_destroy(ep); + + grpc_slice_buffer_destroy_internal(&qbuf); + + grpc_slice_buffer_destroy_internal(&outbuf); + + grpc_error_handle error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed"); + // ContextList::Execute follows semantics of a callback function and does not + // take a ref on error + grpc_core::ContextList::Execute(cl, nullptr, error); + GRPC_ERROR_UNREF(error); + cl = nullptr; + + grpc_slice_buffer_destroy_internal(&read_buffer); + grpc_chttp2_goaway_parser_destroy(&goaway_parser); + + for (i = 0; i < STREAM_LIST_COUNT; i++) { + GPR_ASSERT(lists[i].head == nullptr); + GPR_ASSERT(lists[i].tail == nullptr); + } + + GRPC_ERROR_UNREF(goaway_error); + + GPR_ASSERT(grpc_chttp2_stream_map_size(&stream_map) == 0); + + grpc_chttp2_stream_map_destroy(&stream_map); + + GRPC_COMBINER_UNREF(combiner, "chttp2_transport"); + + cancel_pings(this, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed")); + + while (write_cb_pool) { + grpc_chttp2_write_cb* next = write_cb_pool->next; + gpr_free(write_cb_pool); + write_cb_pool = next; + } + + flow_control.Destroy(); + + GRPC_ERROR_UNREF(closed_with_error); + gpr_free(ping_acks); + if (grpc_core::test_only_destruct_callback != nullptr) { + grpc_core::test_only_destruct_callback(); + } +} + +static const grpc_transport_vtable* get_vtable(void); + +// Returns whether bdp is enabled +static bool read_channel_args(grpc_chttp2_transport* t, + const grpc_channel_args* channel_args, + bool is_client) { + bool enable_bdp = true; + bool channelz_enabled = GRPC_ENABLE_CHANNELZ_DEFAULT; + size_t i; + int j; + + for (i = 0; i < channel_args->num_args; i++) { + if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER)) { + const grpc_integer_options options = {-1, 0, INT_MAX}; + const int value = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + if (value >= 0) { + if ((t->next_stream_id & 1) != (value & 1)) { + gpr_log(GPR_ERROR, "%s: low bit must be %d on %s", + GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER, t->next_stream_id & 1, + is_client ? 
"client" : "server"); + } else { + t->next_stream_id = static_cast(value); + } + } + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER)) { + const grpc_integer_options options = {-1, 0, INT_MAX}; + const int value = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + if (value >= 0) { + t->hpack_compressor.SetMaxUsableSize(value); + } + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA)) { + t->ping_policy.max_pings_without_data = grpc_channel_arg_get_integer( + &channel_args->args[i], + {g_default_max_pings_without_data, 0, INT_MAX}); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_HTTP2_MAX_PING_STRIKES)) { + t->ping_policy.max_ping_strikes = grpc_channel_arg_get_integer( + &channel_args->args[i], {g_default_max_ping_strikes, 0, INT_MAX}); + } else if (0 == + strcmp(channel_args->args[i].key, + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS)) { + t->ping_policy.min_recv_ping_interval_without_data = + grpc_channel_arg_get_integer( + &channel_args->args[i], + grpc_integer_options{ + g_default_min_recv_ping_interval_without_data_ms, 0, + INT_MAX}); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE)) { + t->write_buffer_size = static_cast(grpc_channel_arg_get_integer( + &channel_args->args[i], {0, 0, MAX_WRITE_BUFFER_SIZE})); + } else if (0 == + strcmp(channel_args->args[i].key, GRPC_ARG_HTTP2_BDP_PROBE)) { + enable_bdp = grpc_channel_arg_get_bool(&channel_args->args[i], true); + } else if (0 == + strcmp(channel_args->args[i].key, GRPC_ARG_KEEPALIVE_TIME_MS)) { + const int value = grpc_channel_arg_get_integer( + &channel_args->args[i], + grpc_integer_options{t->is_client + ? g_default_client_keepalive_time_ms + : g_default_server_keepalive_time_ms, + 1, INT_MAX}); + t->keepalive_time = value == INT_MAX ? GRPC_MILLIS_INF_FUTURE : value; + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_KEEPALIVE_TIMEOUT_MS)) { + const int value = grpc_channel_arg_get_integer( + &channel_args->args[i], + grpc_integer_options{t->is_client + ? g_default_client_keepalive_timeout_ms + : g_default_server_keepalive_timeout_ms, + 0, INT_MAX}); + t->keepalive_timeout = value == INT_MAX ? 
GRPC_MILLIS_INF_FUTURE : value; + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS)) { + t->keepalive_permit_without_calls = static_cast( + grpc_channel_arg_get_integer(&channel_args->args[i], {0, 0, 1})); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_OPTIMIZATION_TARGET)) { + gpr_log(GPR_INFO, "GRPC_ARG_OPTIMIZATION_TARGET is deprecated"); + } else if (0 == + strcmp(channel_args->args[i].key, GRPC_ARG_ENABLE_CHANNELZ)) { + channelz_enabled = grpc_channel_arg_get_bool( + &channel_args->args[i], GRPC_ENABLE_CHANNELZ_DEFAULT); + } else { + static const struct { + const char* channel_arg_name; + grpc_chttp2_setting_id setting_id; + grpc_integer_options integer_options; + bool availability[2] /* server, client */; + } settings_map[] = {{GRPC_ARG_MAX_CONCURRENT_STREAMS, + GRPC_CHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, + {-1, 0, INT32_MAX}, + {true, false}}, + {GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER, + GRPC_CHTTP2_SETTINGS_HEADER_TABLE_SIZE, + {-1, 0, INT32_MAX}, + {true, true}}, + {GRPC_ARG_MAX_METADATA_SIZE, + GRPC_CHTTP2_SETTINGS_MAX_HEADER_LIST_SIZE, + {-1, 0, INT32_MAX}, + {true, true}}, + {GRPC_ARG_HTTP2_MAX_FRAME_SIZE, + GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE, + {-1, 16384, 16777215}, + {true, true}}, + {GRPC_ARG_HTTP2_ENABLE_TRUE_BINARY, + GRPC_CHTTP2_SETTINGS_GRPC_ALLOW_TRUE_BINARY_METADATA, + {1, 0, 1}, + {true, true}}, + {GRPC_ARG_HTTP2_STREAM_LOOKAHEAD_BYTES, + GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, + {-1, 5, INT32_MAX}, + {true, true}}}; + for (j = 0; j < static_cast GPR_ARRAY_SIZE(settings_map); j++) { + if (0 == strcmp(channel_args->args[i].key, + settings_map[j].channel_arg_name)) { + if (!settings_map[j].availability[is_client]) { + gpr_log(GPR_DEBUG, "%s is not available on %s", + settings_map[j].channel_arg_name, + is_client ? "clients" : "servers"); + } else { + int value = grpc_channel_arg_get_integer( + &channel_args->args[i], settings_map[j].integer_options); + if (value >= 0) { + queue_setting_update(t, settings_map[j].setting_id, + static_cast(value)); + } + } + break; + } + } + } + } + if (channelz_enabled) { + t->channelz_socket = + grpc_core::MakeRefCounted( + std::string(grpc_endpoint_get_local_address(t->ep)), t->peer_string, + absl::StrFormat("%s %s", get_vtable()->name, t->peer_string), + grpc_core::channelz::SocketNode::Security::GetFromChannelArgs( + channel_args)); + } + return enable_bdp; +} + +static void init_transport_keepalive_settings(grpc_chttp2_transport* t) { + if (t->is_client) { + t->keepalive_time = g_default_client_keepalive_time_ms == INT_MAX + ? GRPC_MILLIS_INF_FUTURE + : g_default_client_keepalive_time_ms; + t->keepalive_timeout = g_default_client_keepalive_timeout_ms == INT_MAX + ? GRPC_MILLIS_INF_FUTURE + : g_default_client_keepalive_timeout_ms; + t->keepalive_permit_without_calls = + g_default_client_keepalive_permit_without_calls; + } else { + t->keepalive_time = g_default_server_keepalive_time_ms == INT_MAX + ? GRPC_MILLIS_INF_FUTURE + : g_default_server_keepalive_time_ms; + t->keepalive_timeout = g_default_server_keepalive_timeout_ms == INT_MAX + ? 
GRPC_MILLIS_INF_FUTURE + : g_default_server_keepalive_timeout_ms; + t->keepalive_permit_without_calls = + g_default_server_keepalive_permit_without_calls; + } +} + +static void configure_transport_ping_policy(grpc_chttp2_transport* t) { + t->ping_policy.max_pings_without_data = g_default_max_pings_without_data; + t->ping_policy.max_ping_strikes = g_default_max_ping_strikes; + t->ping_policy.min_recv_ping_interval_without_data = + g_default_min_recv_ping_interval_without_data_ms; +} + +static void init_keepalive_pings_if_enabled(grpc_chttp2_transport* t) { + if (t->keepalive_time != GRPC_MILLIS_INF_FUTURE) { + t->keepalive_state = GRPC_CHTTP2_KEEPALIVE_STATE_WAITING; + GRPC_CHTTP2_REF_TRANSPORT(t, "init keepalive ping"); + GRPC_CLOSURE_INIT(&t->init_keepalive_ping_locked, init_keepalive_ping, t, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&t->keepalive_ping_timer, + grpc_core::ExecCtx::Get()->Now() + t->keepalive_time, + &t->init_keepalive_ping_locked); + } else { + // Use GRPC_CHTTP2_KEEPALIVE_STATE_DISABLED to indicate there are no + // inflight keeaplive timers + t->keepalive_state = GRPC_CHTTP2_KEEPALIVE_STATE_DISABLED; + } +} + +grpc_chttp2_transport::grpc_chttp2_transport( + const grpc_channel_args* channel_args, grpc_endpoint* ep, bool is_client, + grpc_resource_user* resource_user) + : refs(1, GRPC_TRACE_FLAG_ENABLED(grpc_trace_chttp2_refcount) + ? "chttp2_refcount" + : nullptr), + ep(ep), + peer_string(grpc_endpoint_get_peer(ep)), + resource_user(resource_user), + combiner(grpc_combiner_create()), + state_tracker(is_client ? "client_transport" : "server_transport", + GRPC_CHANNEL_READY), + is_client(is_client), + next_stream_id(is_client ? 1 : 2), + deframe_state(is_client ? GRPC_DTS_FH_0 : GRPC_DTS_CLIENT_PREFIX_0) { + GPR_ASSERT(strlen(GRPC_CHTTP2_CLIENT_CONNECT_STRING) == + GRPC_CHTTP2_CLIENT_CONNECT_STRLEN); + base.vtable = get_vtable(); + // 8 is a random stab in the dark as to a good initial size: it's small enough + // that it shouldn't waste memory for infrequently used connections, yet + // large enough that the exponential growth should happen nicely when it's + // needed. + // TODO(ctiller): tune this + grpc_chttp2_stream_map_init(&stream_map, 8); + + grpc_slice_buffer_init(&read_buffer); + grpc_slice_buffer_init(&outbuf); + if (is_client) { + grpc_slice_buffer_add(&outbuf, grpc_slice_from_copied_string( + GRPC_CHTTP2_CLIENT_CONNECT_STRING)); + } + grpc_slice_buffer_init(&qbuf); + // copy in initial settings to all setting sets + size_t i; + int j; + for (i = 0; i < GRPC_CHTTP2_NUM_SETTINGS; i++) { + for (j = 0; j < GRPC_NUM_SETTING_SETS; j++) { + settings[j][i] = grpc_chttp2_settings_parameters[i].default_value; + } + } + grpc_chttp2_goaway_parser_init(&goaway_parser); + + // configure http2 the way we like it + if (is_client) { + queue_setting_update(this, GRPC_CHTTP2_SETTINGS_ENABLE_PUSH, 0); + queue_setting_update(this, GRPC_CHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 0); + } + queue_setting_update(this, GRPC_CHTTP2_SETTINGS_MAX_HEADER_LIST_SIZE, + DEFAULT_MAX_HEADER_LIST_SIZE); + queue_setting_update(this, + GRPC_CHTTP2_SETTINGS_GRPC_ALLOW_TRUE_BINARY_METADATA, 1); + + configure_transport_ping_policy(this); + init_transport_keepalive_settings(this); + + bool enable_bdp = true; + if (channel_args) { + enable_bdp = read_channel_args(this, channel_args, is_client); + } + + if (g_flow_control_enabled) { + flow_control.Init(this, + enable_bdp); + } else { + flow_control.Init(this); + enable_bdp = false; + } + + // No pings allowed before receiving a header or data frame. 
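// Illustrative sketch, not part of this patch: the keepalive knobs consumed by
// read_channel_args() and init_transport_keepalive_settings() above are
// normally supplied from application code as channel args. The target address
// and the concrete values below are arbitrary assumptions; only the arg keys
// come from this file.
#include <memory>

#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/channel_arguments.h>

std::shared_ptr<grpc::Channel> MakeKeepaliveChannel() {
  grpc::ChannelArguments args;
  args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 20000);     // ping after 20s of inactivity
  args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000);  // wait 10s for the ping ack
  args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1);
  args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);  // 0 removes the limit
  return grpc::CreateCustomChannel("localhost:50051",
                                   grpc::InsecureChannelCredentials(), args);
}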
+ ping_state.pings_before_data_required = 0; + ping_state.is_delayed_ping_timer_set = false; + ping_state.last_ping_sent_time = GRPC_MILLIS_INF_PAST; + + ping_recv_state.last_ping_recv_time = GRPC_MILLIS_INF_PAST; + ping_recv_state.ping_strikes = 0; + + init_keepalive_pings_if_enabled(this); + + if (enable_bdp) { + bdp_ping_blocked = true; + grpc_chttp2_act_on_flowctl_action(flow_control->PeriodicUpdate(), this, + nullptr); + } + + grpc_chttp2_initiate_write(this, GRPC_CHTTP2_INITIATE_WRITE_INITIAL_WRITE); + post_benign_reclaimer(this); + if (grpc_core::test_only_init_callback != nullptr) { + grpc_core::test_only_init_callback(); + } +} + +static void destroy_transport_locked(void* tp, grpc_error_handle /*error*/) { + grpc_chttp2_transport* t = static_cast(tp); + t->destroying = 1; + grpc_resource_user_shutdown(t->resource_user); + grpc_resource_user_unref(t->resource_user); + close_transport_locked( + t, grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed"), + GRPC_ERROR_INT_OCCURRED_DURING_WRITE, t->write_state)); + // Must be the last line. + GRPC_CHTTP2_UNREF_TRANSPORT(t, "destroy"); +} + +static void destroy_transport(grpc_transport* gt) { + grpc_chttp2_transport* t = reinterpret_cast(gt); + t->combiner->Run(GRPC_CLOSURE_CREATE(destroy_transport_locked, t, nullptr), + GRPC_ERROR_NONE); +} + +static void close_transport_locked(grpc_chttp2_transport* t, + grpc_error_handle error) { + end_all_the_calls(t, GRPC_ERROR_REF(error)); + cancel_pings(t, GRPC_ERROR_REF(error)); + if (t->closed_with_error == GRPC_ERROR_NONE) { + if (!grpc_error_has_clear_grpc_status(error)) { + error = grpc_error_set_int(error, GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE); + } + if (t->write_state != GRPC_CHTTP2_WRITE_STATE_IDLE) { + if (t->close_transport_on_writes_finished == GRPC_ERROR_NONE) { + t->close_transport_on_writes_finished = + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Delayed close due to in-progress write"); + } + t->close_transport_on_writes_finished = + grpc_error_add_child(t->close_transport_on_writes_finished, error); + return; + } + GPR_ASSERT(error != GRPC_ERROR_NONE); + t->closed_with_error = GRPC_ERROR_REF(error); + connectivity_state_set(t, GRPC_CHANNEL_SHUTDOWN, absl::Status(), + "close_transport"); + if (t->ping_state.is_delayed_ping_timer_set) { + grpc_timer_cancel(&t->ping_state.delayed_ping_timer); + } + if (t->have_next_bdp_ping_timer) { + grpc_timer_cancel(&t->next_bdp_ping_timer); + } + switch (t->keepalive_state) { + case GRPC_CHTTP2_KEEPALIVE_STATE_WAITING: + grpc_timer_cancel(&t->keepalive_ping_timer); + break; + case GRPC_CHTTP2_KEEPALIVE_STATE_PINGING: + grpc_timer_cancel(&t->keepalive_ping_timer); + grpc_timer_cancel(&t->keepalive_watchdog_timer); + break; + case GRPC_CHTTP2_KEEPALIVE_STATE_DYING: + case GRPC_CHTTP2_KEEPALIVE_STATE_DISABLED: + // keepalive timers are not set in these two states + break; + } + + // flush writable stream list to avoid dangling references + grpc_chttp2_stream* s; + while (grpc_chttp2_list_pop_writable_stream(t, &s)) { + GRPC_CHTTP2_STREAM_UNREF(s, "chttp2_writing:close"); + } + GPR_ASSERT(t->write_state == GRPC_CHTTP2_WRITE_STATE_IDLE); + grpc_endpoint_shutdown(t->ep, GRPC_ERROR_REF(error)); + } + if (t->notify_on_receive_settings != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, t->notify_on_receive_settings, + GRPC_ERROR_REF(error)); + t->notify_on_receive_settings = nullptr; + } + if (t->notify_on_close != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, t->notify_on_close, + 
GRPC_ERROR_REF(error)); + t->notify_on_close = nullptr; + } + GRPC_ERROR_UNREF(error); +} + +#ifndef NDEBUG +void grpc_chttp2_stream_ref(grpc_chttp2_stream* s, const char* reason) { + grpc_stream_ref(s->refcount, reason); +} +void grpc_chttp2_stream_unref(grpc_chttp2_stream* s, const char* reason) { + grpc_stream_unref(s->refcount, reason); +} +#else +void grpc_chttp2_stream_ref(grpc_chttp2_stream* s) { + grpc_stream_ref(s->refcount); +} +void grpc_chttp2_stream_unref(grpc_chttp2_stream* s) { + grpc_stream_unref(s->refcount); +} +#endif + +grpc_chttp2_stream::Reffer::Reffer(grpc_chttp2_stream* s) { + // We reserve one 'active stream' that's dropped when the stream is + // read-closed. The others are for Chttp2IncomingByteStreams that are + // actively reading + GRPC_CHTTP2_STREAM_REF(s, "chttp2"); + GRPC_CHTTP2_REF_TRANSPORT(s->t, "stream"); +} + +grpc_chttp2_stream::grpc_chttp2_stream(grpc_chttp2_transport* t, + grpc_stream_refcount* refcount, + const void* server_data, + grpc_core::Arena* arena) + : t(t), + refcount(refcount), + reffer(this), + initial_metadata_buffer(arena), + trailing_metadata_buffer(arena) { + if (server_data) { + id = static_cast(reinterpret_cast(server_data)); + *t->accepting_stream = this; + grpc_chttp2_stream_map_add(&t->stream_map, id, this); + post_destructive_reclaimer(t); + } + if (t->flow_control->flow_control_enabled()) { + flow_control.Init( + static_cast( + t->flow_control.get()), + this); + } else { + flow_control.Init(); + } + + grpc_slice_buffer_init(&frame_storage); + grpc_slice_buffer_init(&unprocessed_incoming_frames_buffer); + grpc_slice_buffer_init(&flow_controlled_buffer); + GRPC_CLOSURE_INIT(&reset_byte_stream, ::reset_byte_stream, this, nullptr); +} + +grpc_chttp2_stream::~grpc_chttp2_stream() { + if (t->channelz_socket != nullptr) { + if ((t->is_client && eos_received) || (!t->is_client && eos_sent)) { + t->channelz_socket->RecordStreamSucceeded(); + } else { + t->channelz_socket->RecordStreamFailed(); + } + } + + GPR_ASSERT((write_closed && read_closed) || id == 0); + if (id != 0) { + GPR_ASSERT(grpc_chttp2_stream_map_find(&t->stream_map, id) == nullptr); + } + + grpc_slice_buffer_destroy_internal(&unprocessed_incoming_frames_buffer); + grpc_slice_buffer_destroy_internal(&frame_storage); + if (stream_compression_method != GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS) { + grpc_slice_buffer_destroy_internal(&compressed_data_buffer); + } + if (stream_decompression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS) { + grpc_slice_buffer_destroy_internal(&decompressed_data_buffer); + } + + for (int i = 0; i < STREAM_LIST_COUNT; i++) { + if (GPR_UNLIKELY(included[i])) { + gpr_log(GPR_ERROR, "%s stream %d still included in list %d", + t->is_client ? 
"client" : "server", id, i); + abort(); + } + } + + GPR_ASSERT(send_initial_metadata_finished == nullptr); + GPR_ASSERT(fetching_send_message == nullptr); + GPR_ASSERT(send_trailing_metadata_finished == nullptr); + GPR_ASSERT(recv_initial_metadata_ready == nullptr); + GPR_ASSERT(recv_message_ready == nullptr); + GPR_ASSERT(recv_trailing_metadata_finished == nullptr); + grpc_slice_buffer_destroy_internal(&flow_controlled_buffer); + GRPC_ERROR_UNREF(read_closed_error); + GRPC_ERROR_UNREF(write_closed_error); + GRPC_ERROR_UNREF(byte_stream_error); + flow_control.Destroy(); + if (!t->is_client) { + grpc_resource_user_free(t->resource_user, GRPC_RESOURCE_QUOTA_CALL_SIZE); + } + GRPC_CHTTP2_UNREF_TRANSPORT(t, "stream"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, destroy_stream_arg, GRPC_ERROR_NONE); +} + +static int init_stream(grpc_transport* gt, grpc_stream* gs, + grpc_stream_refcount* refcount, const void* server_data, + grpc_core::Arena* arena) { + GPR_TIMER_SCOPE("init_stream", 0); + grpc_chttp2_transport* t = reinterpret_cast(gt); + new (gs) grpc_chttp2_stream(t, refcount, server_data, arena); + return 0; +} + +static void destroy_stream_locked(void* sp, grpc_error_handle /*error*/) { + GPR_TIMER_SCOPE("destroy_stream", 0); + grpc_chttp2_stream* s = static_cast(sp); + s->~grpc_chttp2_stream(); +} + +static void destroy_stream(grpc_transport* gt, grpc_stream* gs, + grpc_closure* then_schedule_closure) { + GPR_TIMER_SCOPE("destroy_stream", 0); + grpc_chttp2_transport* t = reinterpret_cast(gt); + grpc_chttp2_stream* s = reinterpret_cast(gs); + if (s->stream_compression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS && + s->stream_compression_ctx != nullptr) { + grpc_stream_compression_context_destroy(s->stream_compression_ctx); + s->stream_compression_ctx = nullptr; + } + if (s->stream_decompression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS && + s->stream_decompression_ctx != nullptr) { + grpc_stream_compression_context_destroy(s->stream_decompression_ctx); + s->stream_decompression_ctx = nullptr; + } + + s->destroy_stream_arg = then_schedule_closure; + t->combiner->Run( + GRPC_CLOSURE_INIT(&s->destroy_stream, destroy_stream_locked, s, nullptr), + GRPC_ERROR_NONE); +} + +grpc_chttp2_stream* grpc_chttp2_parsing_accept_stream(grpc_chttp2_transport* t, + uint32_t id) { + if (t->accept_stream_cb == nullptr) { + return nullptr; + } + // Don't accept the stream if memory quota doesn't allow. Note that we should + // simply refuse the stream here instead of canceling the stream after it's + // accepted since the latter will create the call which costs much memory. 
+ GPR_ASSERT(t->resource_user != nullptr); + if (!grpc_resource_user_safe_alloc(t->resource_user, + GRPC_RESOURCE_QUOTA_CALL_SIZE)) { + gpr_log(GPR_INFO, "Memory exhausted, rejecting the stream."); + grpc_chttp2_add_rst_stream_to_next_write(t, id, GRPC_HTTP2_REFUSED_STREAM, + nullptr); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_RST_STREAM); + return nullptr; + } + grpc_chttp2_stream* accepting = nullptr; + GPR_ASSERT(t->accepting_stream == nullptr); + t->accepting_stream = &accepting; + t->accept_stream_cb(t->accept_stream_cb_user_data, &t->base, + reinterpret_cast(id)); + t->accepting_stream = nullptr; + return accepting; +} + +// +// OUTPUT PROCESSING +// + +static const char* write_state_name(grpc_chttp2_write_state st) { + switch (st) { + case GRPC_CHTTP2_WRITE_STATE_IDLE: + return "IDLE"; + case GRPC_CHTTP2_WRITE_STATE_WRITING: + return "WRITING"; + case GRPC_CHTTP2_WRITE_STATE_WRITING_WITH_MORE: + return "WRITING+MORE"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +static void set_write_state(grpc_chttp2_transport* t, + grpc_chttp2_write_state st, const char* reason) { + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_INFO, "W:%p %s [%s] state %s -> %s [%s]", t, + t->is_client ? "CLIENT" : "SERVER", t->peer_string.c_str(), + write_state_name(t->write_state), write_state_name(st), reason)); + t->write_state = st; + // If the state is being reset back to idle, it means a write was just + // finished. Make sure all the run_after_write closures are scheduled. + // + // This is also our chance to close the transport if the transport was marked + // to be closed after all writes finish (for example, if we received a go-away + // from peer while we had some pending writes) + if (st == GRPC_CHTTP2_WRITE_STATE_IDLE) { + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &t->run_after_write); + if (t->close_transport_on_writes_finished != GRPC_ERROR_NONE) { + grpc_error_handle err = t->close_transport_on_writes_finished; + t->close_transport_on_writes_finished = GRPC_ERROR_NONE; + close_transport_locked(t, err); + } + } +} + +static void inc_initiate_write_reason( + grpc_chttp2_initiate_write_reason reason) { + switch (reason) { + case GRPC_CHTTP2_INITIATE_WRITE_INITIAL_WRITE: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_INITIAL_WRITE(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_START_NEW_STREAM: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_START_NEW_STREAM(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_MESSAGE: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_SEND_MESSAGE(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_INITIAL_METADATA: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_SEND_INITIAL_METADATA(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_TRAILING_METADATA: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_SEND_TRAILING_METADATA(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_RETRY_SEND_PING: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_RETRY_SEND_PING(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_CONTINUE_PINGS: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_CONTINUE_PINGS(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_GOAWAY_SENT: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_GOAWAY_SENT(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_RST_STREAM: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_RST_STREAM(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_CLOSE_FROM_API: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_CLOSE_FROM_API(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_STREAM_FLOW_CONTROL: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_STREAM_FLOW_CONTROL(); + break; + case 
GRPC_CHTTP2_INITIATE_WRITE_TRANSPORT_FLOW_CONTROL: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_TRANSPORT_FLOW_CONTROL(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_SETTINGS: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_SEND_SETTINGS(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_FLOW_CONTROL_UNSTALLED_BY_SETTING: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_FLOW_CONTROL_UNSTALLED_BY_SETTING(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_FLOW_CONTROL_UNSTALLED_BY_UPDATE: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_FLOW_CONTROL_UNSTALLED_BY_UPDATE(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_APPLICATION_PING: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_APPLICATION_PING(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_BDP_PING: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_BDP_ESTIMATOR_PING(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_KEEPALIVE_PING: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_KEEPALIVE_PING(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_TRANSPORT_FLOW_CONTROL_UNSTALLED: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_TRANSPORT_FLOW_CONTROL_UNSTALLED(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_PING_RESPONSE: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_PING_RESPONSE(); + break; + case GRPC_CHTTP2_INITIATE_WRITE_FORCE_RST_STREAM: + GRPC_STATS_INC_HTTP2_INITIATE_WRITE_DUE_TO_FORCE_RST_STREAM(); + break; + } +} + +void grpc_chttp2_initiate_write(grpc_chttp2_transport* t, + grpc_chttp2_initiate_write_reason reason) { + GPR_TIMER_SCOPE("grpc_chttp2_initiate_write", 0); + + switch (t->write_state) { + case GRPC_CHTTP2_WRITE_STATE_IDLE: + inc_initiate_write_reason(reason); + set_write_state(t, GRPC_CHTTP2_WRITE_STATE_WRITING, + grpc_chttp2_initiate_write_reason_string(reason)); + GRPC_CHTTP2_REF_TRANSPORT(t, "writing"); + // Note that the 'write_action_begin_locked' closure is being scheduled + // on the 'finally_scheduler' of t->combiner. This means that + // 'write_action_begin_locked' is called only *after* all the other + // closures (some of which are potentially initiating more writes on the + // transport) are executed on the t->combiner. + // + // The reason for scheduling on finally_scheduler is to make sure we batch + // as many writes as possible. 'write_action_begin_locked' is the function + // that gathers all the relevant bytes (which are at various places in the + // grpc_chttp2_transport structure) and append them to 'outbuf' field in + // grpc_chttp2_transport thereby batching what would have been potentially + // multiple write operations. + // + // Also, 'write_action_begin_locked' only gathers the bytes into outbuf. + // It does not call the endpoint to write the bytes. 
That is done by the + // 'write_action' (which is scheduled by 'write_action_begin_locked') + t->combiner->FinallyRun( + GRPC_CLOSURE_INIT(&t->write_action_begin_locked, + write_action_begin_locked, t, nullptr), + GRPC_ERROR_NONE); + break; + case GRPC_CHTTP2_WRITE_STATE_WRITING: + set_write_state(t, GRPC_CHTTP2_WRITE_STATE_WRITING_WITH_MORE, + grpc_chttp2_initiate_write_reason_string(reason)); + break; + case GRPC_CHTTP2_WRITE_STATE_WRITING_WITH_MORE: + break; + } +} + +void grpc_chttp2_mark_stream_writable(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + if (t->closed_with_error == GRPC_ERROR_NONE && + grpc_chttp2_list_add_writable_stream(t, s)) { + GRPC_CHTTP2_STREAM_REF(s, "chttp2_writing:become"); + } +} + +static const char* begin_writing_desc(bool partial) { + if (partial) { + return "begin partial write in background"; + } else { + return "begin write in current thread"; + } +} + +static void write_action_begin_locked(void* gt, + grpc_error_handle /*error_ignored*/) { + GPR_TIMER_SCOPE("write_action_begin_locked", 0); + grpc_chttp2_transport* t = static_cast(gt); + GPR_ASSERT(t->write_state != GRPC_CHTTP2_WRITE_STATE_IDLE); + grpc_chttp2_begin_write_result r; + if (t->closed_with_error != GRPC_ERROR_NONE) { + r.writing = false; + } else { + r = grpc_chttp2_begin_write(t); + } + if (r.writing) { + if (r.partial) { + GRPC_STATS_INC_HTTP2_PARTIAL_WRITES(); + } + set_write_state(t, + r.partial ? GRPC_CHTTP2_WRITE_STATE_WRITING_WITH_MORE + : GRPC_CHTTP2_WRITE_STATE_WRITING, + begin_writing_desc(r.partial)); + write_action(t, GRPC_ERROR_NONE); + if (t->reading_paused_on_pending_induced_frames) { + GPR_ASSERT(t->num_pending_induced_frames == 0); + // We had paused reading, because we had many induced frames (SETTINGS + // ACK, PINGS ACK and RST_STREAMS) pending in t->qbuf. Now that we have + // been able to flush qbuf, we can resume reading. 
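// Self-contained model, not gRPC API: grpc_chttp2_initiate_write() and
// write_action_end_locked() above implement a three-state write machine. An
// IDLE transport schedules exactly one batched begin-write closure, further
// initiations while a write is in flight only mark "more work", and finishing
// a write either returns to IDLE or immediately starts the next batch.
#include <cassert>

enum class WriteState { kIdle, kWriting, kWritingWithMore };

struct WriteStateMachine {
  WriteState state = WriteState::kIdle;
  int scheduled_writes = 0;  // stands in for combiner->FinallyRun(...)

  void InitiateWrite() {
    switch (state) {
      case WriteState::kIdle:
        state = WriteState::kWriting;
        ++scheduled_writes;  // exactly one begin-write closure gets scheduled
        break;
      case WriteState::kWriting:
        state = WriteState::kWritingWithMore;  // coalesce with the in-flight write
        break;
      case WriteState::kWritingWithMore:
        break;  // already marked; nothing more to do
    }
  }

  void WriteFinished() {
    assert(state != WriteState::kIdle);
    if (state == WriteState::kWritingWithMore) {
      state = WriteState::kWriting;  // more work was requested: write again
      ++scheduled_writes;
    } else {
      state = WriteState::kIdle;  // run_after_write closures drain here
    }
  }
};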
+ GRPC_CHTTP2_IF_TRACING(gpr_log( + GPR_INFO, + "transport %p : Resuming reading after being paused due to too " + "many unwritten SETTINGS ACK, PINGS ACK and RST_STREAM frames", + t)); + t->reading_paused_on_pending_induced_frames = false; + continue_read_action_locked(t); + } + } else { + GRPC_STATS_INC_HTTP2_SPURIOUS_WRITES_BEGUN(); + set_write_state(t, GRPC_CHTTP2_WRITE_STATE_IDLE, "begin writing nothing"); + GRPC_CHTTP2_UNREF_TRANSPORT(t, "writing"); + } +} + +static void write_action(void* gt, grpc_error_handle /*error*/) { + GPR_TIMER_SCOPE("write_action", 0); + grpc_chttp2_transport* t = static_cast(gt); + void* cl = t->cl; + t->cl = nullptr; + grpc_endpoint_write( + t->ep, &t->outbuf, + GRPC_CLOSURE_INIT(&t->write_action_end_locked, write_action_end, t, + grpc_schedule_on_exec_ctx), + cl); +} + +static void write_action_end(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->write_action_end_locked, + write_action_end_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +// Callback from the grpc_endpoint after bytes have been written by calling +// sendmsg +static void write_action_end_locked(void* tp, grpc_error_handle error) { + GPR_TIMER_SCOPE("terminate_writing_with_lock", 0); + grpc_chttp2_transport* t = static_cast(tp); + + bool closed = false; + if (error != GRPC_ERROR_NONE) { + close_transport_locked(t, GRPC_ERROR_REF(error)); + closed = true; + } + + if (t->sent_goaway_state == GRPC_CHTTP2_GOAWAY_SEND_SCHEDULED) { + t->sent_goaway_state = GRPC_CHTTP2_GOAWAY_SENT; + closed = true; + if (grpc_chttp2_stream_map_size(&t->stream_map) == 0) { + close_transport_locked( + t, GRPC_ERROR_CREATE_FROM_STATIC_STRING("goaway sent")); + } + } + + switch (t->write_state) { + case GRPC_CHTTP2_WRITE_STATE_IDLE: + GPR_UNREACHABLE_CODE(break); + case GRPC_CHTTP2_WRITE_STATE_WRITING: + GPR_TIMER_MARK("state=writing", 0); + set_write_state(t, GRPC_CHTTP2_WRITE_STATE_IDLE, "finish writing"); + break; + case GRPC_CHTTP2_WRITE_STATE_WRITING_WITH_MORE: + GPR_TIMER_MARK("state=writing_stale_no_poller", 0); + set_write_state(t, GRPC_CHTTP2_WRITE_STATE_WRITING, "continue writing"); + GRPC_CHTTP2_REF_TRANSPORT(t, "writing"); + // If the transport is closed, we will retry writing on the endpoint + // and next write may contain part of the currently serialized frames. + // So, we should only call the run_after_write callbacks when the next + // write finishes, or the callbacks will be invoked when the stream is + // closed. + if (!closed) { + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &t->run_after_write); + } + t->combiner->FinallyRun( + GRPC_CLOSURE_INIT(&t->write_action_begin_locked, + write_action_begin_locked, t, nullptr), + GRPC_ERROR_NONE); + break; + } + + grpc_chttp2_end_write(t, GRPC_ERROR_REF(error)); + GRPC_CHTTP2_UNREF_TRANSPORT(t, "writing"); +} + +// Dirties an HTTP2 setting to be sent out next time a writing path occurs. +// If the change needs to occur immediately, manually initiate a write. 
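// Illustrative sketch, not part of this patch: the entries of settings_map
// above (and the values that queue_setting_update() below clamps and dirties)
// are normally driven by channel args set when the server or channel is built.
// The concrete values here are arbitrary assumptions; the arg keys and the
// clamp range come from this file.
#include <grpcpp/server_builder.h>

void AddHttp2SettingsArgs(grpc::ServerBuilder* builder) {
  // SETTINGS_MAX_CONCURRENT_STREAMS (server-side only, per settings_map).
  builder->AddChannelArgument(GRPC_ARG_MAX_CONCURRENT_STREAMS, 128);
  // SETTINGS_MAX_FRAME_SIZE; the transport clamps this to [16384, 16777215].
  builder->AddChannelArgument(GRPC_ARG_HTTP2_MAX_FRAME_SIZE, 1 << 20);
  // SETTINGS_MAX_HEADER_LIST_SIZE.
  builder->AddChannelArgument(GRPC_ARG_MAX_METADATA_SIZE, 16 * 1024);
}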
+static void queue_setting_update(grpc_chttp2_transport* t, + grpc_chttp2_setting_id id, uint32_t value) { + const grpc_chttp2_setting_parameters* sp = + &grpc_chttp2_settings_parameters[id]; + uint32_t use_value = grpc_core::Clamp(value, sp->min_value, sp->max_value); + if (use_value != value) { + gpr_log(GPR_INFO, "Requested parameter %s clamped from %d to %d", sp->name, + value, use_value); + } + if (use_value != t->settings[GRPC_LOCAL_SETTINGS][id]) { + t->settings[GRPC_LOCAL_SETTINGS][id] = use_value; + t->dirtied_local_settings = true; + } +} + +void grpc_chttp2_add_incoming_goaway(grpc_chttp2_transport* t, + uint32_t goaway_error, + uint32_t last_stream_id, + absl::string_view goaway_text) { + // Discard the error from a previous goaway frame (if any) + if (t->goaway_error != GRPC_ERROR_NONE) { + GRPC_ERROR_UNREF(t->goaway_error); + } + t->goaway_error = grpc_error_set_str( + grpc_error_set_int( + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("GOAWAY received"), + GRPC_ERROR_INT_HTTP2_ERROR, static_cast(goaway_error)), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE), + GRPC_ERROR_STR_RAW_BYTES, goaway_text); + + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_INFO, "transport %p got goaway with last stream id %d", t, + last_stream_id)); + // We want to log this irrespective of whether http tracing is enabled if we + // received a GOAWAY with a non NO_ERROR code. + if (goaway_error != GRPC_HTTP2_NO_ERROR) { + gpr_log(GPR_INFO, "%s: Got goaway [%d] err=%s", t->peer_string.c_str(), + goaway_error, grpc_error_std_string(t->goaway_error).c_str()); + } + absl::Status status = grpc_error_to_absl_status(t->goaway_error); + // When a client receives a GOAWAY with error code ENHANCE_YOUR_CALM and debug + // data equal to "too_many_pings", it should log the occurrence at a log level + // that is enabled by default and double the configured KEEPALIVE_TIME used + // for new connections on that channel. + if (GPR_UNLIKELY(t->is_client && + goaway_error == GRPC_HTTP2_ENHANCE_YOUR_CALM && + goaway_text == "too_many_pings")) { + gpr_log(GPR_ERROR, + "Received a GOAWAY with error code ENHANCE_YOUR_CALM and debug " + "data equal to \"too_many_pings\""); + double current_keepalive_time_ms = static_cast(t->keepalive_time); + constexpr int max_keepalive_time_ms = + INT_MAX / KEEPALIVE_TIME_BACKOFF_MULTIPLIER; + t->keepalive_time = + current_keepalive_time_ms > static_cast(max_keepalive_time_ms) + ? GRPC_MILLIS_INF_FUTURE + : static_cast(current_keepalive_time_ms * + KEEPALIVE_TIME_BACKOFF_MULTIPLIER); + status.SetPayload(grpc_core::kKeepaliveThrottlingKey, + absl::Cord(std::to_string(t->keepalive_time))); + } + // lie: use transient failure from the transport to indicate goaway has been + // received. 
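// Self-contained restatement, not gRPC API: the "too_many_pings" handling
// above multiplies the keepalive interval by KEEPALIVE_TIME_BACKOFF_MULTIPLIER
// (assumed to be 2 here) and saturates at "infinite" once another step would
// overflow INT_MAX -- the same clamp computed via max_keepalive_time_ms above.
#include <climits>
#include <cstdint>

constexpr int kBackoffMultiplier = 2;            // assumption: matches the #define
constexpr int64_t kMillisInfFuture = INT64_MAX;  // stands in for GRPC_MILLIS_INF_FUTURE

int64_t ThrottledKeepaliveTimeMs(int64_t current_keepalive_time_ms) {
  constexpr int kMaxKeepaliveTimeMs = INT_MAX / kBackoffMultiplier;
  if (current_keepalive_time_ms > kMaxKeepaliveTimeMs) return kMillisInfFuture;
  return current_keepalive_time_ms * kBackoffMultiplier;
}
// e.g. 20000 -> 40000 -> 80000 -> ... and eventually kMillisInfFuture.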
+ connectivity_state_set(t, GRPC_CHANNEL_TRANSIENT_FAILURE, status, + "got_goaway"); +} + +static void maybe_start_some_streams(grpc_chttp2_transport* t) { + grpc_chttp2_stream* s; + // cancel out streams that haven't yet started if we have received a GOAWAY + if (t->goaway_error != GRPC_ERROR_NONE) { + while (grpc_chttp2_list_pop_waiting_for_concurrency(t, &s)) { + grpc_chttp2_cancel_stream( + t, s, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("GOAWAY received"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } + return; + } + // start streams where we have free grpc_chttp2_stream ids and free + // * concurrency + while (t->next_stream_id <= MAX_CLIENT_STREAM_ID && + grpc_chttp2_stream_map_size(&t->stream_map) < + t->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS] && + grpc_chttp2_list_pop_waiting_for_concurrency(t, &s)) { + // safe since we can't (legally) be parsing this stream yet + GRPC_CHTTP2_IF_TRACING(gpr_log( + GPR_INFO, + "HTTP:%s: Transport %p allocating new grpc_chttp2_stream %p to id %d", + t->is_client ? "CLI" : "SVR", t, s, t->next_stream_id)); + + GPR_ASSERT(s->id == 0); + s->id = t->next_stream_id; + t->next_stream_id += 2; + + if (t->next_stream_id >= MAX_CLIENT_STREAM_ID) { + connectivity_state_set(t, GRPC_CHANNEL_TRANSIENT_FAILURE, + absl::Status(absl::StatusCode::kUnavailable, + "Transport Stream IDs exhausted"), + "no_more_stream_ids"); + } + + grpc_chttp2_stream_map_add(&t->stream_map, s->id, s); + post_destructive_reclaimer(t); + grpc_chttp2_mark_stream_writable(t, s); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_START_NEW_STREAM); + } + // cancel out streams that will never be started + if (t->next_stream_id >= MAX_CLIENT_STREAM_ID) { + while (grpc_chttp2_list_pop_waiting_for_concurrency(t, &s)) { + grpc_chttp2_cancel_stream( + t, s, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Stream IDs exhausted"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } + } +} + +// Flag that this closure barrier may be covering a write in a pollset, and so +// we should not complete this closure until we can prove that the write got +// scheduled +#define CLOSURE_BARRIER_MAY_COVER_WRITE (1 << 0) +// First bit of the reference count, stored in the high order bits (with the low +// bits being used for flags defined above) +#define CLOSURE_BARRIER_FIRST_REF_BIT (1 << 16) + +static grpc_closure* add_closure_barrier(grpc_closure* closure) { + closure->next_data.scratch += CLOSURE_BARRIER_FIRST_REF_BIT; + return closure; +} + +static void null_then_sched_closure(grpc_closure** closure) { + grpc_closure* c = *closure; + *closure = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, c, GRPC_ERROR_NONE); +} + +void grpc_chttp2_complete_closure_step(grpc_chttp2_transport* t, + grpc_chttp2_stream* /*s*/, + grpc_closure** pclosure, + grpc_error_handle error, + const char* desc) { + grpc_closure* closure = *pclosure; + *pclosure = nullptr; + if (closure == nullptr) { + GRPC_ERROR_UNREF(error); + return; + } + closure->next_data.scratch -= CLOSURE_BARRIER_FIRST_REF_BIT; + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log( + GPR_INFO, + "complete_closure_step: t=%p %p refs=%d flags=0x%04x desc=%s err=%s " + "write_state=%s", + t, closure, + static_cast(closure->next_data.scratch / + CLOSURE_BARRIER_FIRST_REF_BIT), + static_cast(closure->next_data.scratch % + CLOSURE_BARRIER_FIRST_REF_BIT), + desc, grpc_error_std_string(error).c_str(), + write_state_name(t->write_state)); + } + if (error != 
GRPC_ERROR_NONE) { + if (closure->error_data.error == GRPC_ERROR_NONE) { + closure->error_data.error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Error in HTTP transport completing operation"); + closure->error_data.error = + grpc_error_set_str(closure->error_data.error, + GRPC_ERROR_STR_TARGET_ADDRESS, t->peer_string); + } + closure->error_data.error = + grpc_error_add_child(closure->error_data.error, error); + } + if (closure->next_data.scratch < CLOSURE_BARRIER_FIRST_REF_BIT) { + if ((t->write_state == GRPC_CHTTP2_WRITE_STATE_IDLE) || + !(closure->next_data.scratch & CLOSURE_BARRIER_MAY_COVER_WRITE)) { + // Using GRPC_CLOSURE_SCHED instead of GRPC_CLOSURE_RUN to avoid running + // closures earlier than when it is safe to do so. + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, + closure->error_data.error); + } else { + grpc_closure_list_append(&t->run_after_write, closure, + closure->error_data.error); + } + } +} + +static bool contains_non_ok_status(grpc_metadata_batch* batch) { + if (batch->legacy_index()->named.grpc_status != nullptr) { + return !grpc_mdelem_static_value_eq( + batch->legacy_index()->named.grpc_status->md, + GRPC_MDELEM_GRPC_STATUS_0); + } + return false; +} + +static void maybe_become_writable_due_to_send_msg(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + if (s->id != 0 && (!s->write_buffering || + s->flow_controlled_buffer.length > t->write_buffer_size)) { + grpc_chttp2_mark_stream_writable(t, s); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_SEND_MESSAGE); + } +} + +static void add_fetched_slice_locked(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + s->fetched_send_message_length += + static_cast GRPC_SLICE_LENGTH(s->fetching_slice); + grpc_slice_buffer_add(&s->flow_controlled_buffer, s->fetching_slice); + maybe_become_writable_due_to_send_msg(t, s); +} + +static void continue_fetching_send_locked(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + for (;;) { + if (s->fetching_send_message == nullptr) { + // Stream was cancelled before message fetch completed + abort(); /* TODO(ctiller): what cleanup here? */ + } + if (s->fetched_send_message_length == s->fetching_send_message->length()) { + int64_t notify_offset = s->next_message_end_offset; + if (notify_offset <= s->flow_controlled_bytes_written) { + grpc_chttp2_complete_closure_step( + t, s, &s->fetching_send_message_finished, GRPC_ERROR_NONE, + "fetching_send_message_finished"); + } else { + grpc_chttp2_write_cb* cb = t->write_cb_pool; + if (cb == nullptr) { + cb = static_cast(gpr_malloc(sizeof(*cb))); + } else { + t->write_cb_pool = cb->next; + } + cb->call_at_byte = notify_offset; + cb->closure = s->fetching_send_message_finished; + s->fetching_send_message_finished = nullptr; + grpc_chttp2_write_cb** list = + s->fetching_send_message->flags() & GRPC_WRITE_THROUGH + ? 
&s->on_write_finished_cbs + : &s->on_flow_controlled_cbs; + cb->next = *list; + *list = cb; + } + s->fetching_send_message.reset(); + return; /* early out */ + } else if (s->fetching_send_message->Next( + UINT32_MAX, GRPC_CLOSURE_INIT(&s->complete_fetch_locked, + ::complete_fetch, s, + grpc_schedule_on_exec_ctx))) { + grpc_error_handle error = + s->fetching_send_message->Pull(&s->fetching_slice); + if (error != GRPC_ERROR_NONE) { + s->fetching_send_message.reset(); + grpc_chttp2_cancel_stream(t, s, error); + } else { + add_fetched_slice_locked(t, s); + } + } + } +} + +static void complete_fetch(void* gs, grpc_error_handle error) { + grpc_chttp2_stream* s = static_cast(gs); + s->t->combiner->Run(GRPC_CLOSURE_INIT(&s->complete_fetch_locked, + ::complete_fetch_locked, s, nullptr), + GRPC_ERROR_REF(error)); +} + +static void complete_fetch_locked(void* gs, grpc_error_handle error) { + grpc_chttp2_stream* s = static_cast(gs); + grpc_chttp2_transport* t = s->t; + if (error == GRPC_ERROR_NONE) { + error = s->fetching_send_message->Pull(&s->fetching_slice); + if (error == GRPC_ERROR_NONE) { + add_fetched_slice_locked(t, s); + continue_fetching_send_locked(t, s); + } + } + if (error != GRPC_ERROR_NONE) { + s->fetching_send_message.reset(); + grpc_chttp2_cancel_stream(t, s, error); + } +} + +static void log_metadata(const grpc_metadata_batch* md_batch, uint32_t id, + bool is_client, bool is_initial) { + md_batch->ForEach([=](grpc_mdelem md) { + char* key = grpc_slice_to_c_string(GRPC_MDKEY(md)); + char* value = grpc_slice_to_c_string(GRPC_MDVALUE(md)); + gpr_log(GPR_INFO, "HTTP:%d:%s:%s: %s: %s", id, is_initial ? "HDR" : "TRL", + is_client ? "CLI" : "SVR", key, value); + gpr_free(key); + gpr_free(value); + }); +} + +static void perform_stream_op_locked(void* stream_op, + grpc_error_handle /*error_ignored*/) { + GPR_TIMER_SCOPE("perform_stream_op_locked", 0); + + grpc_transport_stream_op_batch* op = + static_cast(stream_op); + grpc_chttp2_stream* s = + static_cast(op->handler_private.extra_arg); + grpc_transport_stream_op_batch_payload* op_payload = op->payload; + grpc_chttp2_transport* t = s->t; + + GRPC_STATS_INC_HTTP2_OP_BATCHES(); + + s->context = op->payload->context; + s->traced = op->is_traced; + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "perform_stream_op_locked: %s; on_complete = %p", + grpc_transport_stream_op_batch_string(op).c_str(), op->on_complete); + if (op->send_initial_metadata) { + log_metadata(op_payload->send_initial_metadata.send_initial_metadata, + s->id, t->is_client, true); + } + if (op->send_trailing_metadata) { + log_metadata(op_payload->send_trailing_metadata.send_trailing_metadata, + s->id, t->is_client, false); + } + } + + grpc_closure* on_complete = op->on_complete; + // on_complete will be null if and only if there are no send ops in the batch. + if (on_complete != nullptr) { + // This batch has send ops. Use final_data as a barrier until enqueue time; + // the initial counter is dropped at the end of this function. 
+ on_complete->next_data.scratch = CLOSURE_BARRIER_FIRST_REF_BIT; + on_complete->error_data.error = GRPC_ERROR_NONE; + } + + if (op->cancel_stream) { + GRPC_STATS_INC_HTTP2_OP_CANCEL(); + grpc_chttp2_cancel_stream(t, s, op_payload->cancel_stream.cancel_error); + } + + if (op->send_initial_metadata) { + if (t->is_client && t->channelz_socket != nullptr) { + t->channelz_socket->RecordStreamStartedFromLocal(); + } + GRPC_STATS_INC_HTTP2_OP_SEND_INITIAL_METADATA(); + GPR_ASSERT(s->send_initial_metadata_finished == nullptr); + on_complete->next_data.scratch |= CLOSURE_BARRIER_MAY_COVER_WRITE; + + // Identify stream compression + if (op_payload->send_initial_metadata.send_initial_metadata->legacy_index() + ->named.content_encoding == nullptr || + grpc_stream_compression_method_parse( + GRPC_MDVALUE( + op_payload->send_initial_metadata.send_initial_metadata + ->legacy_index() + ->named.content_encoding->md), + true, &s->stream_compression_method) == 0) { + s->stream_compression_method = GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS; + } + if (s->stream_compression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS) { + s->uncompressed_data_size = 0; + s->stream_compression_ctx = nullptr; + grpc_slice_buffer_init(&s->compressed_data_buffer); + } + s->send_initial_metadata_finished = add_closure_barrier(on_complete); + s->send_initial_metadata = + op_payload->send_initial_metadata.send_initial_metadata; + if (t->is_client) { + s->deadline = std::min( + s->deadline, + s->send_initial_metadata->get(grpc_core::GrpcTimeoutMetadata()) + .value_or(GRPC_MILLIS_INF_FUTURE)); + } + if (contains_non_ok_status(s->send_initial_metadata)) { + s->seen_error = true; + } + if (!s->write_closed) { + if (t->is_client) { + if (t->closed_with_error == GRPC_ERROR_NONE) { + GPR_ASSERT(s->id == 0); + grpc_chttp2_list_add_waiting_for_concurrency(t, s); + maybe_start_some_streams(t); + } else { + grpc_chttp2_cancel_stream( + t, s, + grpc_error_set_int( + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Transport closed", &t->closed_with_error, 1), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } + } else { + GPR_ASSERT(s->id != 0); + grpc_chttp2_mark_stream_writable(t, s); + if (!(op->send_message && + (op->payload->send_message.send_message->flags() & + GRPC_WRITE_BUFFER_HINT))) { + grpc_chttp2_initiate_write( + t, GRPC_CHTTP2_INITIATE_WRITE_SEND_INITIAL_METADATA); + } + } + } else { + s->send_initial_metadata = nullptr; + grpc_chttp2_complete_closure_step( + t, s, &s->send_initial_metadata_finished, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Attempt to send initial metadata after stream was closed", + &s->write_closed_error, 1), + "send_initial_metadata_finished"); + } + if (op_payload->send_initial_metadata.peer_string != nullptr) { + gpr_atm_rel_store(op_payload->send_initial_metadata.peer_string, + (gpr_atm)t->peer_string.c_str()); + } + } + + if (op->send_message) { + GRPC_STATS_INC_HTTP2_OP_SEND_MESSAGE(); + t->num_messages_in_next_write++; + GRPC_STATS_INC_HTTP2_SEND_MESSAGE_SIZE( + op->payload->send_message.send_message->length()); + on_complete->next_data.scratch |= CLOSURE_BARRIER_MAY_COVER_WRITE; + s->fetching_send_message_finished = add_closure_barrier(op->on_complete); + if (s->write_closed) { + op->payload->send_message.stream_write_closed = true; + // We should NOT return an error here, so as to avoid a cancel OP being + // started. The surface layer will notice that the stream has been closed + // for writes and fail the send message op. 
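// Illustrative sketch, not part of this patch: the per-stream deadline that the
// client branch above folds into s->deadline via GrpcTimeoutMetadata()
// originates from the application-level deadline on the call, carried on the
// wire as the grpc-timeout header. The 5-second value is an arbitrary
// assumption.
#include <chrono>

#include <grpcpp/client_context.h>

void SetCallDeadline(grpc::ClientContext* context) {
  context->set_deadline(std::chrono::system_clock::now() +
                        std::chrono::seconds(5));
}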
+ op->payload->send_message.send_message.reset(); + grpc_chttp2_complete_closure_step( + t, s, &s->fetching_send_message_finished, GRPC_ERROR_NONE, + "fetching_send_message_finished"); + } else { + GPR_ASSERT(s->fetching_send_message == nullptr); + uint8_t* frame_hdr = grpc_slice_buffer_tiny_add( + &s->flow_controlled_buffer, GRPC_HEADER_SIZE_IN_BYTES); + uint32_t flags = op_payload->send_message.send_message->flags(); + frame_hdr[0] = (flags & GRPC_WRITE_INTERNAL_COMPRESS) != 0; + size_t len = op_payload->send_message.send_message->length(); + frame_hdr[1] = static_cast(len >> 24); + frame_hdr[2] = static_cast(len >> 16); + frame_hdr[3] = static_cast(len >> 8); + frame_hdr[4] = static_cast(len); + s->fetching_send_message = + std::move(op_payload->send_message.send_message); + s->fetched_send_message_length = 0; + s->next_message_end_offset = + s->flow_controlled_bytes_written + + static_cast(s->flow_controlled_buffer.length) + + static_cast(len); + if (flags & GRPC_WRITE_BUFFER_HINT) { + s->next_message_end_offset -= t->write_buffer_size; + s->write_buffering = true; + } else { + s->write_buffering = false; + } + continue_fetching_send_locked(t, s); + maybe_become_writable_due_to_send_msg(t, s); + } + } + + if (op->send_trailing_metadata) { + GRPC_STATS_INC_HTTP2_OP_SEND_TRAILING_METADATA(); + GPR_ASSERT(s->send_trailing_metadata_finished == nullptr); + on_complete->next_data.scratch |= CLOSURE_BARRIER_MAY_COVER_WRITE; + s->send_trailing_metadata_finished = add_closure_barrier(on_complete); + s->send_trailing_metadata = + op_payload->send_trailing_metadata.send_trailing_metadata; + s->sent_trailing_metadata_op = op_payload->send_trailing_metadata.sent; + s->write_buffering = false; + if (contains_non_ok_status(s->send_trailing_metadata)) { + s->seen_error = true; + } + if (s->write_closed) { + s->send_trailing_metadata = nullptr; + s->sent_trailing_metadata_op = nullptr; + grpc_chttp2_complete_closure_step( + t, s, &s->send_trailing_metadata_finished, + op->payload->send_trailing_metadata.send_trailing_metadata->empty() + ? 
GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Attempt to send trailing metadata after " + "stream was closed"), + "send_trailing_metadata_finished"); + } else if (s->id != 0) { + // TODO(ctiller): check if there's flow control for any outstanding + // bytes before going writable + grpc_chttp2_mark_stream_writable(t, s); + grpc_chttp2_initiate_write( + t, GRPC_CHTTP2_INITIATE_WRITE_SEND_TRAILING_METADATA); + } + } + + if (op->recv_initial_metadata) { + GRPC_STATS_INC_HTTP2_OP_RECV_INITIAL_METADATA(); + GPR_ASSERT(s->recv_initial_metadata_ready == nullptr); + s->recv_initial_metadata_ready = + op_payload->recv_initial_metadata.recv_initial_metadata_ready; + s->recv_initial_metadata = + op_payload->recv_initial_metadata.recv_initial_metadata; + s->trailing_metadata_available = + op_payload->recv_initial_metadata.trailing_metadata_available; + if (op_payload->recv_initial_metadata.peer_string != nullptr) { + gpr_atm_rel_store(op_payload->recv_initial_metadata.peer_string, + (gpr_atm)t->peer_string.c_str()); + } + grpc_chttp2_maybe_complete_recv_initial_metadata(t, s); + } + + if (op->recv_message) { + GRPC_STATS_INC_HTTP2_OP_RECV_MESSAGE(); + size_t before = 0; + GPR_ASSERT(s->recv_message_ready == nullptr); + GPR_ASSERT(!s->pending_byte_stream); + s->recv_message_ready = op_payload->recv_message.recv_message_ready; + s->recv_message = op_payload->recv_message.recv_message; + s->call_failed_before_recv_message = + op_payload->recv_message.call_failed_before_recv_message; + if (s->id != 0) { + if (!s->read_closed) { + before = s->frame_storage.length + + s->unprocessed_incoming_frames_buffer.length; + } + } + grpc_chttp2_maybe_complete_recv_message(t, s); + if (s->id != 0) { + if (!s->read_closed && s->frame_storage.length == 0) { + size_t after = s->unprocessed_incoming_frames_buffer_cached_length; + s->flow_control->IncomingByteStreamUpdate(GRPC_HEADER_SIZE_IN_BYTES, + before - after); + grpc_chttp2_act_on_flowctl_action(s->flow_control->MakeAction(), t, s); + } + } + } + + if (op->recv_trailing_metadata) { + GRPC_STATS_INC_HTTP2_OP_RECV_TRAILING_METADATA(); + GPR_ASSERT(s->collecting_stats == nullptr); + s->collecting_stats = op_payload->recv_trailing_metadata.collect_stats; + GPR_ASSERT(s->recv_trailing_metadata_finished == nullptr); + s->recv_trailing_metadata_finished = + op_payload->recv_trailing_metadata.recv_trailing_metadata_ready; + s->recv_trailing_metadata = + op_payload->recv_trailing_metadata.recv_trailing_metadata; + s->final_metadata_requested = true; + grpc_chttp2_maybe_complete_recv_trailing_metadata(t, s); + } + + if (on_complete != nullptr) { + grpc_chttp2_complete_closure_step(t, s, &on_complete, GRPC_ERROR_NONE, + "op->on_complete"); + } + + GRPC_CHTTP2_STREAM_UNREF(s, "perform_stream_op"); +} + +static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, + grpc_transport_stream_op_batch* op) { + GPR_TIMER_SCOPE("perform_stream_op", 0); + grpc_chttp2_transport* t = reinterpret_cast(gt); + grpc_chttp2_stream* s = reinterpret_cast(gs); + + if (!t->is_client) { + if (op->send_initial_metadata) { + GPR_ASSERT(!op->payload->send_initial_metadata.send_initial_metadata + ->get(grpc_core::GrpcTimeoutMetadata()) + .has_value()); + } + if (op->send_trailing_metadata) { + GPR_ASSERT(!op->payload->send_trailing_metadata.send_trailing_metadata + ->get(grpc_core::GrpcTimeoutMetadata()) + .has_value()); + } + } + + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "perform_stream_op[s=%p]: %s", s, + 
grpc_transport_stream_op_batch_string(op).c_str()); + } + + GRPC_CHTTP2_STREAM_REF(s, "perform_stream_op"); + op->handler_private.extra_arg = gs; + t->combiner->Run(GRPC_CLOSURE_INIT(&op->handler_private.closure, + perform_stream_op_locked, op, nullptr), + GRPC_ERROR_NONE); +} + +static void cancel_pings(grpc_chttp2_transport* t, grpc_error_handle error) { + // callback remaining pings: they're not allowed to call into the transport, + // and maybe they hold resources that need to be freed + grpc_chttp2_ping_queue* pq = &t->ping_queue; + GPR_ASSERT(error != GRPC_ERROR_NONE); + for (size_t j = 0; j < GRPC_CHTTP2_PCL_COUNT; j++) { + grpc_closure_list_fail_all(&pq->lists[j], GRPC_ERROR_REF(error)); + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &pq->lists[j]); + } + GRPC_ERROR_UNREF(error); +} + +static void send_ping_locked(grpc_chttp2_transport* t, + grpc_closure* on_initiate, grpc_closure* on_ack) { + if (t->closed_with_error != GRPC_ERROR_NONE) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_initiate, + GRPC_ERROR_REF(t->closed_with_error)); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_ack, + GRPC_ERROR_REF(t->closed_with_error)); + return; + } + grpc_chttp2_ping_queue* pq = &t->ping_queue; + grpc_closure_list_append(&pq->lists[GRPC_CHTTP2_PCL_INITIATE], on_initiate, + GRPC_ERROR_NONE); + grpc_closure_list_append(&pq->lists[GRPC_CHTTP2_PCL_NEXT], on_ack, + GRPC_ERROR_NONE); +} + +// Specialized form of send_ping_locked for keepalive ping. If there is already +// a ping in progress, the keepalive ping would piggyback onto that ping, +// instead of waiting for that ping to complete and then starting a new ping. +static void send_keepalive_ping_locked(grpc_chttp2_transport* t) { + if (t->closed_with_error != GRPC_ERROR_NONE) { + t->combiner->Run(GRPC_CLOSURE_INIT(&t->start_keepalive_ping_locked, + start_keepalive_ping_locked, t, nullptr), + GRPC_ERROR_REF(t->closed_with_error)); + t->combiner->Run( + GRPC_CLOSURE_INIT(&t->finish_keepalive_ping_locked, + finish_keepalive_ping_locked, t, nullptr), + GRPC_ERROR_REF(t->closed_with_error)); + return; + } + grpc_chttp2_ping_queue* pq = &t->ping_queue; + if (!grpc_closure_list_empty(pq->lists[GRPC_CHTTP2_PCL_INFLIGHT])) { + // There is a ping in flight. Add yourself to the inflight closure list. 
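// Self-contained model, not gRPC API: the ping queue used above keeps three
// closure lists -- INITIATE (run when a ping is initiated), NEXT (acks waiting
// for the next ping frame) and INFLIGHT (acks waiting for the outstanding
// ping). A keepalive ping issued while another ping is already in flight just
// piggybacks on that ping's ack instead of putting a new frame on the wire.
#include <functional>
#include <iterator>
#include <vector>

struct PingQueueModel {
  std::vector<std::function<void()>> initiate, next, inflight;

  void SendPing(std::function<void()> on_initiate, std::function<void()> on_ack) {
    initiate.push_back(std::move(on_initiate));
    next.push_back(std::move(on_ack));
  }

  void SendKeepalivePing(std::function<void()> on_ack) {
    if (!inflight.empty()) {
      inflight.push_back(std::move(on_ack));  // piggyback on the in-flight ping
    } else {
      next.push_back(std::move(on_ack));  // ride the next ping frame
    }
  }

  void PingWritten() {  // a PING frame was flushed to the endpoint
    for (auto& cb : initiate) cb();
    initiate.clear();
    inflight.insert(inflight.end(), std::make_move_iterator(next.begin()),
                    std::make_move_iterator(next.end()));
    next.clear();
  }

  void PingAcked() {  // the matching PING ack arrived (grpc_chttp2_ack_ping)
    for (auto& cb : inflight) cb();
    inflight.clear();
  }
};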
+ t->combiner->Run(GRPC_CLOSURE_INIT(&t->start_keepalive_ping_locked, + start_keepalive_ping_locked, t, nullptr), + GRPC_ERROR_REF(t->closed_with_error)); + grpc_closure_list_append( + &pq->lists[GRPC_CHTTP2_PCL_INFLIGHT], + GRPC_CLOSURE_INIT(&t->finish_keepalive_ping_locked, + finish_keepalive_ping, t, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); + return; + } + grpc_closure_list_append( + &pq->lists[GRPC_CHTTP2_PCL_INITIATE], + GRPC_CLOSURE_INIT(&t->start_keepalive_ping_locked, start_keepalive_ping, + t, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); + grpc_closure_list_append( + &pq->lists[GRPC_CHTTP2_PCL_NEXT], + GRPC_CLOSURE_INIT(&t->finish_keepalive_ping_locked, finish_keepalive_ping, + t, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); +} + +void grpc_chttp2_retry_initiate_ping(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->retry_initiate_ping_locked, + retry_initiate_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void retry_initiate_ping_locked(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + t->ping_state.is_delayed_ping_timer_set = false; + if (error == GRPC_ERROR_NONE) { + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_RETRY_SEND_PING); + } + GRPC_CHTTP2_UNREF_TRANSPORT(t, "retry_initiate_ping_locked"); +} + +void grpc_chttp2_ack_ping(grpc_chttp2_transport* t, uint64_t id) { + grpc_chttp2_ping_queue* pq = &t->ping_queue; + if (pq->inflight_id != id) { + gpr_log(GPR_DEBUG, "Unknown ping response from %s: %" PRIx64, + t->peer_string.c_str(), id); + return; + } + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, + &pq->lists[GRPC_CHTTP2_PCL_INFLIGHT]); + if (!grpc_closure_list_empty(pq->lists[GRPC_CHTTP2_PCL_NEXT])) { + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_CONTINUE_PINGS); + } +} + +static void send_goaway(grpc_chttp2_transport* t, grpc_error_handle error) { + // We want to log this irrespective of whether http tracing is enabled + gpr_log(GPR_DEBUG, "%s: Sending goaway err=%s", t->peer_string.c_str(), + grpc_error_std_string(error).c_str()); + t->sent_goaway_state = GRPC_CHTTP2_GOAWAY_SEND_SCHEDULED; + grpc_http2_error_code http_error; + std::string message; + grpc_error_get_status(error, GRPC_MILLIS_INF_FUTURE, nullptr, &message, + &http_error, nullptr); + grpc_chttp2_goaway_append( + t->last_new_stream_id, static_cast(http_error), + grpc_slice_from_cpp_string(std::move(message)), &t->qbuf); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_GOAWAY_SENT); + GRPC_ERROR_UNREF(error); +} + +void grpc_chttp2_add_ping_strike(grpc_chttp2_transport* t) { + if (++t->ping_recv_state.ping_strikes > t->ping_policy.max_ping_strikes && + t->ping_policy.max_ping_strikes != 0) { + send_goaway(t, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("too_many_pings"), + GRPC_ERROR_INT_HTTP2_ERROR, GRPC_HTTP2_ENHANCE_YOUR_CALM)); + // The transport will be closed after the write is done + close_transport_locked( + t, grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Too many pings"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } +} + +void grpc_chttp2_reset_ping_clock(grpc_chttp2_transport* t) { + if (!t->is_client) { + t->ping_recv_state.last_ping_recv_time = GRPC_MILLIS_INF_PAST; + t->ping_recv_state.ping_strikes = 0; + } + t->ping_state.pings_before_data_required = + t->ping_policy.max_pings_without_data; +} + +static void perform_transport_op_locked(void* stream_op, + grpc_error_handle 
/*error_ignored*/) { + grpc_transport_op* op = static_cast(stream_op); + grpc_chttp2_transport* t = + static_cast(op->handler_private.extra_arg); + + if (op->goaway_error != GRPC_ERROR_NONE) { + send_goaway(t, op->goaway_error); + } + + if (op->set_accept_stream) { + t->accept_stream_cb = op->set_accept_stream_fn; + t->accept_stream_cb_user_data = op->set_accept_stream_user_data; + } + + if (op->bind_pollset) { + grpc_endpoint_add_to_pollset(t->ep, op->bind_pollset); + } + + if (op->bind_pollset_set) { + grpc_endpoint_add_to_pollset_set(t->ep, op->bind_pollset_set); + } + + if (op->send_ping.on_initiate != nullptr || op->send_ping.on_ack != nullptr) { + send_ping_locked(t, op->send_ping.on_initiate, op->send_ping.on_ack); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_APPLICATION_PING); + } + + if (op->start_connectivity_watch != nullptr) { + t->state_tracker.AddWatcher(op->start_connectivity_watch_state, + std::move(op->start_connectivity_watch)); + } + if (op->stop_connectivity_watch != nullptr) { + t->state_tracker.RemoveWatcher(op->stop_connectivity_watch); + } + + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + close_transport_locked(t, op->disconnect_with_error); + } + + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_consumed, GRPC_ERROR_NONE); + + GRPC_CHTTP2_UNREF_TRANSPORT(t, "transport_op"); +} + +static void perform_transport_op(grpc_transport* gt, grpc_transport_op* op) { + grpc_chttp2_transport* t = reinterpret_cast(gt); + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "perform_transport_op[t=%p]: %s", t, + grpc_transport_op_string(op).c_str()); + } + op->handler_private.extra_arg = gt; + GRPC_CHTTP2_REF_TRANSPORT(t, "transport_op"); + t->combiner->Run(GRPC_CLOSURE_INIT(&op->handler_private.closure, + perform_transport_op_locked, op, nullptr), + GRPC_ERROR_NONE); +} + +// +// INPUT PROCESSING - GENERAL +// + +void grpc_chttp2_maybe_complete_recv_initial_metadata( + grpc_chttp2_transport* /*t*/, grpc_chttp2_stream* s) { + if (s->recv_initial_metadata_ready != nullptr && + s->published_metadata[0] != GRPC_METADATA_NOT_PUBLISHED) { + if (s->seen_error) { + grpc_slice_buffer_reset_and_unref_internal(&s->frame_storage); + if (!s->pending_byte_stream) { + grpc_slice_buffer_reset_and_unref_internal( + &s->unprocessed_incoming_frames_buffer); + } + } + *s->recv_initial_metadata = std::move(s->initial_metadata_buffer); + null_then_sched_closure(&s->recv_initial_metadata_ready); + } +} + +void grpc_chttp2_maybe_complete_recv_message(grpc_chttp2_transport* /*t*/, + grpc_chttp2_stream* s) { + grpc_error_handle error = GRPC_ERROR_NONE; + if (s->recv_message_ready != nullptr) { + *s->recv_message = nullptr; + if (s->final_metadata_requested && s->seen_error) { + grpc_slice_buffer_reset_and_unref_internal(&s->frame_storage); + if (!s->pending_byte_stream) { + grpc_slice_buffer_reset_and_unref_internal( + &s->unprocessed_incoming_frames_buffer); + } + } + if (!s->pending_byte_stream) { + while (s->unprocessed_incoming_frames_buffer.length > 0 || + s->frame_storage.length > 0) { + if (s->unprocessed_incoming_frames_buffer.length == 0) { + grpc_slice_buffer_swap(&s->unprocessed_incoming_frames_buffer, + &s->frame_storage); + s->unprocessed_incoming_frames_decompressed = false; + } + if (!s->unprocessed_incoming_frames_decompressed && + s->stream_decompression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS) { + GPR_ASSERT(s->decompressed_data_buffer.length == 0); + bool end_of_context; + if (!s->stream_decompression_ctx) { + 
s->stream_decompression_ctx = + grpc_stream_compression_context_create( + s->stream_decompression_method); + } + if (!grpc_stream_decompress( + s->stream_decompression_ctx, + &s->unprocessed_incoming_frames_buffer, + &s->decompressed_data_buffer, nullptr, + GRPC_HEADER_SIZE_IN_BYTES - s->decompressed_header_bytes, + &end_of_context)) { + grpc_slice_buffer_reset_and_unref_internal(&s->frame_storage); + grpc_slice_buffer_reset_and_unref_internal( + &s->unprocessed_incoming_frames_buffer); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Stream decompression error."); + } else { + s->decompressed_header_bytes += s->decompressed_data_buffer.length; + if (s->decompressed_header_bytes == GRPC_HEADER_SIZE_IN_BYTES) { + s->decompressed_header_bytes = 0; + } + error = grpc_deframe_unprocessed_incoming_frames( + &s->data_parser, s, &s->decompressed_data_buffer, nullptr, + s->recv_message); + if (end_of_context) { + grpc_stream_compression_context_destroy( + s->stream_decompression_ctx); + s->stream_decompression_ctx = nullptr; + } + } + } else { + error = grpc_deframe_unprocessed_incoming_frames( + &s->data_parser, s, &s->unprocessed_incoming_frames_buffer, + nullptr, s->recv_message); + } + if (error != GRPC_ERROR_NONE) { + s->seen_error = true; + grpc_slice_buffer_reset_and_unref_internal(&s->frame_storage); + grpc_slice_buffer_reset_and_unref_internal( + &s->unprocessed_incoming_frames_buffer); + break; + } else if (*s->recv_message != nullptr) { + break; + } + } + } + // save the length of the buffer before handing control back to application + // threads. Needed to support correct flow control bookkeeping + s->unprocessed_incoming_frames_buffer_cached_length = + s->unprocessed_incoming_frames_buffer.length; + if (error == GRPC_ERROR_NONE && *s->recv_message != nullptr) { + null_then_sched_closure(&s->recv_message_ready); + } else if (s->published_metadata[1] != GRPC_METADATA_NOT_PUBLISHED) { + *s->recv_message = nullptr; + if (s->call_failed_before_recv_message != nullptr) { + *s->call_failed_before_recv_message = + (s->published_metadata[1] != GRPC_METADATA_PUBLISHED_AT_CLOSE); + } + null_then_sched_closure(&s->recv_message_ready); + } + GRPC_ERROR_UNREF(error); + } +} + +void grpc_chttp2_maybe_complete_recv_trailing_metadata(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + grpc_chttp2_maybe_complete_recv_message(t, s); + if (s->recv_trailing_metadata_finished != nullptr && s->read_closed && + s->write_closed) { + if (s->seen_error || !t->is_client) { + grpc_slice_buffer_reset_and_unref_internal(&s->frame_storage); + if (!s->pending_byte_stream) { + grpc_slice_buffer_reset_and_unref_internal( + &s->unprocessed_incoming_frames_buffer); + } + } + bool pending_data = s->pending_byte_stream || + s->unprocessed_incoming_frames_buffer.length > 0; + if (s->read_closed && s->frame_storage.length > 0 && !pending_data && + !s->seen_error && s->recv_trailing_metadata_finished != nullptr) { + // Maybe some SYNC_FLUSH data is left in frame_storage. Consume them and + // maybe decompress the next 5 bytes in the stream. 
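+      // (The 5 bytes referred to above are GRPC_HEADER_SIZE_IN_BYTES: the
+      // gRPC length-prefixed message header, i.e. a 1-byte compressed flag
+      // followed by a 4-byte big-endian payload length. For example, an
+      // uncompressed 5-byte payload is framed as
+      //   00 00 00 00 05 <payload bytes>
+      // -- illustrative values only.)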
+ if (s->stream_decompression_method == + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS) { + grpc_slice_buffer_move_first( + &s->frame_storage, + std::min(s->frame_storage.length, + size_t(GRPC_HEADER_SIZE_IN_BYTES)), + &s->unprocessed_incoming_frames_buffer); + if (s->unprocessed_incoming_frames_buffer.length > 0) { + s->unprocessed_incoming_frames_decompressed = true; + pending_data = true; + } + } else { + bool end_of_context; + if (!s->stream_decompression_ctx) { + s->stream_decompression_ctx = grpc_stream_compression_context_create( + s->stream_decompression_method); + } + if (!grpc_stream_decompress( + s->stream_decompression_ctx, &s->frame_storage, + &s->unprocessed_incoming_frames_buffer, nullptr, + GRPC_HEADER_SIZE_IN_BYTES, &end_of_context)) { + grpc_slice_buffer_reset_and_unref_internal(&s->frame_storage); + grpc_slice_buffer_reset_and_unref_internal( + &s->unprocessed_incoming_frames_buffer); + s->seen_error = true; + } else { + if (s->unprocessed_incoming_frames_buffer.length > 0) { + s->unprocessed_incoming_frames_decompressed = true; + pending_data = true; + } + if (end_of_context) { + grpc_stream_compression_context_destroy( + s->stream_decompression_ctx); + s->stream_decompression_ctx = nullptr; + } + } + } + } + if (s->read_closed && s->frame_storage.length == 0 && !pending_data && + s->recv_trailing_metadata_finished != nullptr) { + grpc_transport_move_stats(&s->stats, s->collecting_stats); + s->collecting_stats = nullptr; + *s->recv_trailing_metadata = std::move(s->trailing_metadata_buffer); + null_then_sched_closure(&s->recv_trailing_metadata_finished); + } + } +} + +static void remove_stream(grpc_chttp2_transport* t, uint32_t id, + grpc_error_handle error) { + grpc_chttp2_stream* s = static_cast( + grpc_chttp2_stream_map_delete(&t->stream_map, id)); + GPR_DEBUG_ASSERT(s); + if (t->incoming_stream == s) { + t->incoming_stream = nullptr; + grpc_chttp2_parsing_become_skip_parser(t); + } + if (s->pending_byte_stream) { + if (s->on_next != nullptr) { + grpc_core::Chttp2IncomingByteStream* bs = s->data_parser.parsing_frame; + if (error == GRPC_ERROR_NONE) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Truncated message"); + } + bs->PublishError(error); + bs->Unref(); + s->data_parser.parsing_frame = nullptr; + } else { + GRPC_ERROR_UNREF(s->byte_stream_error); + s->byte_stream_error = GRPC_ERROR_REF(error); + } + } + + if (grpc_chttp2_stream_map_size(&t->stream_map) == 0) { + post_benign_reclaimer(t); + if (t->sent_goaway_state == GRPC_CHTTP2_GOAWAY_SENT) { + close_transport_locked( + t, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Last stream closed after sending GOAWAY", &error, 1)); + } + } + if (grpc_chttp2_list_remove_writable_stream(t, s)) { + GRPC_CHTTP2_STREAM_UNREF(s, "chttp2_writing:remove_stream"); + } + grpc_chttp2_list_remove_stalled_by_stream(t, s); + grpc_chttp2_list_remove_stalled_by_transport(t, s); + + GRPC_ERROR_UNREF(error); + + maybe_start_some_streams(t); +} + +void grpc_chttp2_cancel_stream(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_error_handle due_to_error) { + if (!t->is_client && !s->sent_trailing_metadata && + grpc_error_has_clear_grpc_status(due_to_error)) { + close_from_api(t, s, due_to_error); + return; + } + + if (!s->read_closed || !s->write_closed) { + if (s->id != 0) { + grpc_http2_error_code http_error; + grpc_error_get_status(due_to_error, s->deadline, nullptr, nullptr, + &http_error, nullptr); + grpc_chttp2_add_rst_stream_to_next_write( + t, s->id, static_cast(http_error), &s->stats.outgoing); + 
grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_RST_STREAM); + } + } + if (due_to_error != GRPC_ERROR_NONE && !s->seen_error) { + s->seen_error = true; + } + grpc_chttp2_mark_stream_closed(t, s, 1, 1, due_to_error); +} + +void grpc_chttp2_fake_status(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_error_handle error) { + grpc_status_code status; + std::string message; + grpc_error_get_status(error, s->deadline, &status, &message, nullptr, + nullptr); + if (status != GRPC_STATUS_OK) { + s->seen_error = true; + } + // stream_global->recv_trailing_metadata_finished gives us a + // last chance replacement: we've received trailing metadata, + // but something more important has become available to signal + // to the upper layers - drop what we've got, and then publish + // what we want - which is safe because we haven't told anyone + // about the metadata yet + if (s->published_metadata[1] == GRPC_METADATA_NOT_PUBLISHED || + s->recv_trailing_metadata_finished != nullptr) { + char status_string[GPR_LTOA_MIN_BUFSIZE]; + gpr_ltoa(status, status_string); + GRPC_LOG_IF_ERROR("add_status", + s->trailing_metadata_buffer.ReplaceOrAppend( + GRPC_MDSTR_GRPC_STATUS, + grpc_core::UnmanagedMemorySlice(status_string))); + if (!message.empty()) { + grpc_slice message_slice = grpc_slice_from_cpp_string(std::move(message)); + GRPC_LOG_IF_ERROR("add_status_message", + s->trailing_metadata_buffer.ReplaceOrAppend( + GRPC_MDSTR_GRPC_MESSAGE, message_slice)); + } + s->published_metadata[1] = GRPC_METADATA_SYNTHESIZED_FROM_FAKE; + grpc_chttp2_maybe_complete_recv_trailing_metadata(t, s); + } + + GRPC_ERROR_UNREF(error); +} + +static void add_error(grpc_error_handle error, grpc_error_handle* refs, + size_t* nrefs) { + if (error == GRPC_ERROR_NONE) return; + for (size_t i = 0; i < *nrefs; i++) { + if (error == refs[i]) { + return; + } + } + refs[*nrefs] = error; + ++*nrefs; +} + +static grpc_error_handle removal_error(grpc_error_handle extra_error, + grpc_chttp2_stream* s, + const char* main_error_msg) { + grpc_error_handle refs[3]; + size_t nrefs = 0; + add_error(s->read_closed_error, refs, &nrefs); + add_error(s->write_closed_error, refs, &nrefs); + add_error(extra_error, refs, &nrefs); + grpc_error_handle error = GRPC_ERROR_NONE; + if (nrefs > 0) { + error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(main_error_msg, + refs, nrefs); + } + GRPC_ERROR_UNREF(extra_error); + return error; +} + +static void flush_write_list(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_chttp2_write_cb** list, + grpc_error_handle error) { + while (*list) { + grpc_chttp2_write_cb* cb = *list; + *list = cb->next; + grpc_chttp2_complete_closure_step(t, s, &cb->closure, GRPC_ERROR_REF(error), + "on_write_finished_cb"); + cb->next = t->write_cb_pool; + t->write_cb_pool = cb; + } + GRPC_ERROR_UNREF(error); +} + +void grpc_chttp2_fail_pending_writes(grpc_chttp2_transport* t, + grpc_chttp2_stream* s, + grpc_error_handle error) { + error = + removal_error(error, s, "Pending writes failed due to stream closure"); + s->send_initial_metadata = nullptr; + grpc_chttp2_complete_closure_step(t, s, &s->send_initial_metadata_finished, + GRPC_ERROR_REF(error), + "send_initial_metadata_finished"); + + s->send_trailing_metadata = nullptr; + s->sent_trailing_metadata_op = nullptr; + grpc_chttp2_complete_closure_step(t, s, &s->send_trailing_metadata_finished, + GRPC_ERROR_REF(error), + "send_trailing_metadata_finished"); + + s->fetching_send_message.reset(); + grpc_chttp2_complete_closure_step(t, s, 
&s->fetching_send_message_finished, + GRPC_ERROR_REF(error), + "fetching_send_message_finished"); + flush_write_list(t, s, &s->on_write_finished_cbs, GRPC_ERROR_REF(error)); + flush_write_list(t, s, &s->on_flow_controlled_cbs, error); +} + +void grpc_chttp2_mark_stream_closed(grpc_chttp2_transport* t, + grpc_chttp2_stream* s, int close_reads, + int close_writes, grpc_error_handle error) { + if (s->read_closed && s->write_closed) { + // already closed, but we should still fake the status if needed. + grpc_error_handle overall_error = removal_error(error, s, "Stream removed"); + if (overall_error != GRPC_ERROR_NONE) { + grpc_chttp2_fake_status(t, s, overall_error); + } + grpc_chttp2_maybe_complete_recv_trailing_metadata(t, s); + return; + } + bool closed_read = false; + bool became_closed = false; + if (close_reads && !s->read_closed) { + s->read_closed_error = GRPC_ERROR_REF(error); + s->read_closed = true; + closed_read = true; + } + if (close_writes && !s->write_closed) { + s->write_closed_error = GRPC_ERROR_REF(error); + s->write_closed = true; + grpc_chttp2_fail_pending_writes(t, s, GRPC_ERROR_REF(error)); + } + if (s->read_closed && s->write_closed) { + became_closed = true; + grpc_error_handle overall_error = + removal_error(GRPC_ERROR_REF(error), s, "Stream removed"); + if (s->id != 0) { + remove_stream(t, s->id, GRPC_ERROR_REF(overall_error)); + } else { + // Purge streams waiting on concurrency still waiting for id assignment + grpc_chttp2_list_remove_waiting_for_concurrency(t, s); + } + if (overall_error != GRPC_ERROR_NONE) { + grpc_chttp2_fake_status(t, s, overall_error); + } + } + if (closed_read) { + for (int i = 0; i < 2; i++) { + if (s->published_metadata[i] == GRPC_METADATA_NOT_PUBLISHED) { + s->published_metadata[i] = GRPC_METADATA_PUBLISHED_AT_CLOSE; + } + } + grpc_chttp2_maybe_complete_recv_initial_metadata(t, s); + grpc_chttp2_maybe_complete_recv_message(t, s); + } + if (became_closed) { + grpc_chttp2_maybe_complete_recv_trailing_metadata(t, s); + GRPC_CHTTP2_STREAM_UNREF(s, "chttp2"); + } + GRPC_ERROR_UNREF(error); +} + +static void close_from_api(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_error_handle error) { + grpc_slice hdr; + grpc_slice status_hdr; + grpc_slice http_status_hdr; + grpc_slice content_type_hdr; + grpc_slice message_pfx; + uint8_t* p; + uint32_t len = 0; + grpc_status_code grpc_status; + std::string message; + grpc_error_get_status(error, s->deadline, &grpc_status, &message, nullptr, + nullptr); + + GPR_ASSERT(grpc_status >= 0 && (int)grpc_status < 100); + + // Hand roll a header block. + // This is unnecessarily ugly - at some point we should find a more + // elegant solution. + // It's complicated by the fact that our send machinery would be dead by + // the time we got around to sending this, so instead we ignore HPACK + // compression and just write the uncompressed bytes onto the wire. 
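+  //
+  // Descriptive note: each header below is emitted in the HPACK "literal
+  // header field without indexing, new name" form: a 0x00 type byte, a
+  // length-prefixed (non-Huffman) name, then a length-prefixed value. For
+  // example, ":status: 200" becomes
+  //   0x00, 0x07, ':','s','t','a','t','u','s', 0x03, '2','0','0'
+  // which is exactly the 13 bytes written into http_status_hdr below.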
+ if (!s->sent_initial_metadata) { + http_status_hdr = GRPC_SLICE_MALLOC(13); + p = GRPC_SLICE_START_PTR(http_status_hdr); + *p++ = 0x00; + *p++ = 7; + *p++ = ':'; + *p++ = 's'; + *p++ = 't'; + *p++ = 'a'; + *p++ = 't'; + *p++ = 'u'; + *p++ = 's'; + *p++ = 3; + *p++ = '2'; + *p++ = '0'; + *p++ = '0'; + GPR_ASSERT(p == GRPC_SLICE_END_PTR(http_status_hdr)); + len += static_cast GRPC_SLICE_LENGTH(http_status_hdr); + + content_type_hdr = GRPC_SLICE_MALLOC(31); + p = GRPC_SLICE_START_PTR(content_type_hdr); + *p++ = 0x00; + *p++ = 12; + *p++ = 'c'; + *p++ = 'o'; + *p++ = 'n'; + *p++ = 't'; + *p++ = 'e'; + *p++ = 'n'; + *p++ = 't'; + *p++ = '-'; + *p++ = 't'; + *p++ = 'y'; + *p++ = 'p'; + *p++ = 'e'; + *p++ = 16; + *p++ = 'a'; + *p++ = 'p'; + *p++ = 'p'; + *p++ = 'l'; + *p++ = 'i'; + *p++ = 'c'; + *p++ = 'a'; + *p++ = 't'; + *p++ = 'i'; + *p++ = 'o'; + *p++ = 'n'; + *p++ = '/'; + *p++ = 'g'; + *p++ = 'r'; + *p++ = 'p'; + *p++ = 'c'; + GPR_ASSERT(p == GRPC_SLICE_END_PTR(content_type_hdr)); + len += static_cast GRPC_SLICE_LENGTH(content_type_hdr); + } + + status_hdr = GRPC_SLICE_MALLOC(15 + (grpc_status >= 10)); + p = GRPC_SLICE_START_PTR(status_hdr); + *p++ = 0x00; /* literal header, not indexed */ + *p++ = 11; /* len(grpc-status) */ + *p++ = 'g'; + *p++ = 'r'; + *p++ = 'p'; + *p++ = 'c'; + *p++ = '-'; + *p++ = 's'; + *p++ = 't'; + *p++ = 'a'; + *p++ = 't'; + *p++ = 'u'; + *p++ = 's'; + if (grpc_status < 10) { + *p++ = 1; + *p++ = static_cast('0' + grpc_status); + } else { + *p++ = 2; + *p++ = static_cast('0' + (grpc_status / 10)); + *p++ = static_cast('0' + (grpc_status % 10)); + } + GPR_ASSERT(p == GRPC_SLICE_END_PTR(status_hdr)); + len += static_cast GRPC_SLICE_LENGTH(status_hdr); + + size_t msg_len = message.length(); + GPR_ASSERT(msg_len <= UINT32_MAX); + grpc_core::VarintWriter<1> msg_len_writer(msg_len); + message_pfx = GRPC_SLICE_MALLOC(14 + msg_len_writer.length()); + p = GRPC_SLICE_START_PTR(message_pfx); + *p++ = 0x00; /* literal header, not indexed */ + *p++ = 12; /* len(grpc-message) */ + *p++ = 'g'; + *p++ = 'r'; + *p++ = 'p'; + *p++ = 'c'; + *p++ = '-'; + *p++ = 'm'; + *p++ = 'e'; + *p++ = 's'; + *p++ = 's'; + *p++ = 'a'; + *p++ = 'g'; + *p++ = 'e'; + msg_len_writer.Write(0, p); + p += msg_len_writer.length(); + GPR_ASSERT(p == GRPC_SLICE_END_PTR(message_pfx)); + len += static_cast GRPC_SLICE_LENGTH(message_pfx); + len += static_cast(msg_len); + + hdr = GRPC_SLICE_MALLOC(9); + p = GRPC_SLICE_START_PTR(hdr); + *p++ = static_cast(len >> 16); + *p++ = static_cast(len >> 8); + *p++ = static_cast(len); + *p++ = GRPC_CHTTP2_FRAME_HEADER; + *p++ = GRPC_CHTTP2_DATA_FLAG_END_STREAM | GRPC_CHTTP2_DATA_FLAG_END_HEADERS; + *p++ = static_cast(s->id >> 24); + *p++ = static_cast(s->id >> 16); + *p++ = static_cast(s->id >> 8); + *p++ = static_cast(s->id); + GPR_ASSERT(p == GRPC_SLICE_END_PTR(hdr)); + + grpc_slice_buffer_add(&t->qbuf, hdr); + if (!s->sent_initial_metadata) { + grpc_slice_buffer_add(&t->qbuf, http_status_hdr); + grpc_slice_buffer_add(&t->qbuf, content_type_hdr); + } + grpc_slice_buffer_add(&t->qbuf, status_hdr); + grpc_slice_buffer_add(&t->qbuf, message_pfx); + grpc_slice_buffer_add(&t->qbuf, + grpc_slice_from_cpp_string(std::move(message))); + grpc_chttp2_reset_ping_clock(t); + grpc_chttp2_add_rst_stream_to_next_write(t, s->id, GRPC_HTTP2_NO_ERROR, + &s->stats.outgoing); + + grpc_chttp2_mark_stream_closed(t, s, 1, 1, error); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_CLOSE_FROM_API); +} + +struct cancel_stream_cb_args { + grpc_error_handle error; + 
grpc_chttp2_transport* t; +}; + +static void cancel_stream_cb(void* user_data, uint32_t /*key*/, void* stream) { + cancel_stream_cb_args* args = static_cast(user_data); + grpc_chttp2_stream* s = static_cast(stream); + grpc_chttp2_cancel_stream(args->t, s, GRPC_ERROR_REF(args->error)); +} + +static void end_all_the_calls(grpc_chttp2_transport* t, + grpc_error_handle error) { + intptr_t http2_error; + // If there is no explicit grpc or HTTP/2 error, set to UNAVAILABLE on server. + if (!t->is_client && !grpc_error_has_clear_grpc_status(error) && + !grpc_error_get_int(error, GRPC_ERROR_INT_HTTP2_ERROR, &http2_error)) { + error = grpc_error_set_int(error, GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE); + } + cancel_stream_cb_args args = {error, t}; + grpc_chttp2_stream_map_for_each(&t->stream_map, cancel_stream_cb, &args); + GRPC_ERROR_UNREF(error); +} + +// +// INPUT PROCESSING - PARSING +// + +template +static void WithUrgency(grpc_chttp2_transport* t, + grpc_core::chttp2::FlowControlAction::Urgency urgency, + grpc_chttp2_initiate_write_reason reason, F action) { + switch (urgency) { + case grpc_core::chttp2::FlowControlAction::Urgency::NO_ACTION_NEEDED: + break; + case grpc_core::chttp2::FlowControlAction::Urgency::UPDATE_IMMEDIATELY: + grpc_chttp2_initiate_write(t, reason); + ABSL_FALLTHROUGH_INTENDED; + case grpc_core::chttp2::FlowControlAction::Urgency::QUEUE_UPDATE: + action(); + break; + } +} + +void grpc_chttp2_act_on_flowctl_action( + const grpc_core::chttp2::FlowControlAction& action, + grpc_chttp2_transport* t, grpc_chttp2_stream* s) { + WithUrgency(t, action.send_stream_update(), + GRPC_CHTTP2_INITIATE_WRITE_STREAM_FLOW_CONTROL, + [t, s]() { grpc_chttp2_mark_stream_writable(t, s); }); + WithUrgency(t, action.send_transport_update(), + GRPC_CHTTP2_INITIATE_WRITE_TRANSPORT_FLOW_CONTROL, []() {}); + WithUrgency(t, action.send_initial_window_update(), + GRPC_CHTTP2_INITIATE_WRITE_SEND_SETTINGS, [t, &action]() { + queue_setting_update(t, + GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, + action.initial_window_size()); + }); + WithUrgency(t, action.send_max_frame_size_update(), + GRPC_CHTTP2_INITIATE_WRITE_SEND_SETTINGS, [t, &action]() { + queue_setting_update(t, GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE, + action.max_frame_size()); + }); +} + +static grpc_error_handle try_http_parsing(grpc_chttp2_transport* t) { + grpc_http_parser parser; + size_t i = 0; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_http_response response; + + grpc_http_parser_init(&parser, GRPC_HTTP_RESPONSE, &response); + + grpc_error_handle parse_error = GRPC_ERROR_NONE; + for (; i < t->read_buffer.count && parse_error == GRPC_ERROR_NONE; i++) { + parse_error = + grpc_http_parser_parse(&parser, t->read_buffer.slices[i], nullptr); + } + if (parse_error == GRPC_ERROR_NONE && + (parse_error = grpc_http_parser_eof(&parser)) == GRPC_ERROR_NONE) { + error = grpc_error_set_int( + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Trying to connect an http1.x server"), + GRPC_ERROR_INT_HTTP_STATUS, response.status), + GRPC_ERROR_INT_GRPC_STATUS, + grpc_http2_status_to_grpc_status(response.status)); + } + GRPC_ERROR_UNREF(parse_error); + + grpc_http_parser_destroy(&parser); + grpc_http_response_destroy(&response); + return error; +} + +static void read_action(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + t->combiner->Run( + GRPC_CLOSURE_INIT(&t->read_action_locked, read_action_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void read_action_locked(void* tp, 
grpc_error_handle error) { + GPR_TIMER_SCOPE("reading_action_locked", 0); + + grpc_chttp2_transport* t = static_cast(tp); + + (void)GRPC_ERROR_REF(error); + + grpc_error_handle err = error; + if (err != GRPC_ERROR_NONE) { + err = grpc_error_set_int(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Endpoint read failed", &err, 1), + GRPC_ERROR_INT_OCCURRED_DURING_WRITE, + t->write_state); + } + std::swap(err, error); + GRPC_ERROR_UNREF(err); + if (t->closed_with_error == GRPC_ERROR_NONE) { + GPR_TIMER_SCOPE("reading_action.parse", 0); + size_t i = 0; + grpc_error_handle errors[3] = {GRPC_ERROR_REF(error), GRPC_ERROR_NONE, + GRPC_ERROR_NONE}; + for (; i < t->read_buffer.count && errors[1] == GRPC_ERROR_NONE; i++) { + errors[1] = grpc_chttp2_perform_read(t, t->read_buffer.slices[i]); + } + if (errors[1] != GRPC_ERROR_NONE) { + errors[2] = try_http_parsing(t); + GRPC_ERROR_UNREF(error); + error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed parsing HTTP/2", errors, GPR_ARRAY_SIZE(errors)); + } + for (i = 0; i < GPR_ARRAY_SIZE(errors); i++) { + GRPC_ERROR_UNREF(errors[i]); + } + + GPR_TIMER_SCOPE("post_parse_locked", 0); + if (t->initial_window_update != 0) { + if (t->initial_window_update > 0) { + grpc_chttp2_stream* s; + while (grpc_chttp2_list_pop_stalled_by_stream(t, &s)) { + grpc_chttp2_mark_stream_writable(t, s); + grpc_chttp2_initiate_write( + t, GRPC_CHTTP2_INITIATE_WRITE_FLOW_CONTROL_UNSTALLED_BY_SETTING); + } + } + t->initial_window_update = 0; + } + } + + GPR_TIMER_SCOPE("post_reading_action_locked", 0); + bool keep_reading = false; + if (error == GRPC_ERROR_NONE && t->closed_with_error != GRPC_ERROR_NONE) { + error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Transport closed", &t->closed_with_error, 1); + } + if (error != GRPC_ERROR_NONE) { + // If a goaway frame was received, this might be the reason why the read + // failed. 
Add this info to the error + if (t->goaway_error != GRPC_ERROR_NONE) { + error = grpc_error_add_child(error, GRPC_ERROR_REF(t->goaway_error)); + } + + close_transport_locked(t, GRPC_ERROR_REF(error)); + t->endpoint_reading = 0; + } else if (t->closed_with_error == GRPC_ERROR_NONE) { + keep_reading = true; + // Since we have read a byte, reset the keepalive timer + if (t->keepalive_state == GRPC_CHTTP2_KEEPALIVE_STATE_WAITING) { + grpc_timer_cancel(&t->keepalive_ping_timer); + } + } + grpc_slice_buffer_reset_and_unref_internal(&t->read_buffer); + + if (keep_reading) { + if (t->num_pending_induced_frames >= DEFAULT_MAX_PENDING_INDUCED_FRAMES) { + t->reading_paused_on_pending_induced_frames = true; + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_INFO, + "transport %p : Pausing reading due to too " + "many unwritten SETTINGS ACK and RST_STREAM frames", + t)); + } else { + continue_read_action_locked(t); + } + } else { + GRPC_CHTTP2_UNREF_TRANSPORT(t, "reading_action"); + } + + GRPC_ERROR_UNREF(error); +} + +static void continue_read_action_locked(grpc_chttp2_transport* t) { + const bool urgent = t->goaway_error != GRPC_ERROR_NONE; + GRPC_CLOSURE_INIT(&t->read_action_locked, read_action, t, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(t->ep, &t->read_buffer, &t->read_action_locked, urgent); + grpc_chttp2_act_on_flowctl_action(t->flow_control->MakeAction(), t, nullptr); +} + +// t is reffed prior to calling the first time, and once the callback chain +// that kicks off finishes, it's unreffed +void schedule_bdp_ping_locked(grpc_chttp2_transport* t) { + t->flow_control->bdp_estimator()->SchedulePing(); + send_ping_locked( + t, + GRPC_CLOSURE_INIT(&t->start_bdp_ping_locked, start_bdp_ping, t, + grpc_schedule_on_exec_ctx), + GRPC_CLOSURE_INIT(&t->finish_bdp_ping_locked, finish_bdp_ping, t, + grpc_schedule_on_exec_ctx)); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_BDP_PING); +} + +static void start_bdp_ping(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->start_bdp_ping_locked, + start_bdp_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void start_bdp_ping_locked(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "%s: Start BDP ping err=%s", t->peer_string.c_str(), + grpc_error_std_string(error).c_str()); + } + if (error != GRPC_ERROR_NONE || t->closed_with_error != GRPC_ERROR_NONE) { + return; + } + // Reset the keepalive ping timer + if (t->keepalive_state == GRPC_CHTTP2_KEEPALIVE_STATE_WAITING) { + grpc_timer_cancel(&t->keepalive_ping_timer); + } + t->flow_control->bdp_estimator()->StartPing(); + t->bdp_ping_started = true; +} + +static void finish_bdp_ping(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->finish_bdp_ping_locked, + finish_bdp_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void finish_bdp_ping_locked(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "%s: Complete BDP ping err=%s", t->peer_string.c_str(), + grpc_error_std_string(error).c_str()); + } + if (error != GRPC_ERROR_NONE || t->closed_with_error != GRPC_ERROR_NONE) { + GRPC_CHTTP2_UNREF_TRANSPORT(t, "bdp_ping"); + return; + } + if (!t->bdp_ping_started) { + // start_bdp_ping_locked has not been run yet. 
Schedule + // finish_bdp_ping_locked to be run later. + t->combiner->Run(GRPC_CLOSURE_INIT(&t->finish_bdp_ping_locked, + finish_bdp_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); + return; + } + t->bdp_ping_started = false; + grpc_millis next_ping = t->flow_control->bdp_estimator()->CompletePing(); + grpc_chttp2_act_on_flowctl_action(t->flow_control->PeriodicUpdate(), t, + nullptr); + GPR_ASSERT(!t->have_next_bdp_ping_timer); + t->have_next_bdp_ping_timer = true; + GRPC_CLOSURE_INIT(&t->next_bdp_ping_timer_expired_locked, + next_bdp_ping_timer_expired, t, grpc_schedule_on_exec_ctx); + grpc_timer_init(&t->next_bdp_ping_timer, next_ping, + &t->next_bdp_ping_timer_expired_locked); +} + +static void next_bdp_ping_timer_expired(void* tp, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + t->combiner->Run( + GRPC_CLOSURE_INIT(&t->next_bdp_ping_timer_expired_locked, + next_bdp_ping_timer_expired_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void next_bdp_ping_timer_expired_locked(void* tp, + grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(tp); + GPR_ASSERT(t->have_next_bdp_ping_timer); + t->have_next_bdp_ping_timer = false; + if (error != GRPC_ERROR_NONE) { + GRPC_CHTTP2_UNREF_TRANSPORT(t, "bdp_ping"); + return; + } + if (t->flow_control->bdp_estimator()->accumulator() == 0) { + // Block the bdp ping till we receive more data. + t->bdp_ping_blocked = true; + GRPC_CHTTP2_UNREF_TRANSPORT(t, "bdp_ping"); + } else { + schedule_bdp_ping_locked(t); + } +} + +void grpc_chttp2_config_default_keepalive_args(grpc_channel_args* args, + bool is_client) { + size_t i; + if (args) { + for (i = 0; i < args->num_args; i++) { + if (0 == strcmp(args->args[i].key, GRPC_ARG_KEEPALIVE_TIME_MS)) { + const int value = grpc_channel_arg_get_integer( + &args->args[i], {is_client ? g_default_client_keepalive_time_ms + : g_default_server_keepalive_time_ms, + 1, INT_MAX}); + if (is_client) { + g_default_client_keepalive_time_ms = value; + } else { + g_default_server_keepalive_time_ms = value; + } + } else if (0 == + strcmp(args->args[i].key, GRPC_ARG_KEEPALIVE_TIMEOUT_MS)) { + const int value = grpc_channel_arg_get_integer( + &args->args[i], {is_client ? g_default_client_keepalive_timeout_ms + : g_default_server_keepalive_timeout_ms, + 0, INT_MAX}); + if (is_client) { + g_default_client_keepalive_timeout_ms = value; + } else { + g_default_server_keepalive_timeout_ms = value; + } + } else if (0 == strcmp(args->args[i].key, + GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS)) { + const bool value = static_cast(grpc_channel_arg_get_integer( + &args->args[i], + {is_client ? 
g_default_client_keepalive_permit_without_calls + : g_default_server_keepalive_timeout_ms, + 0, 1})); + if (is_client) { + g_default_client_keepalive_permit_without_calls = value; + } else { + g_default_server_keepalive_permit_without_calls = value; + } + } else if (0 == + strcmp(args->args[i].key, GRPC_ARG_HTTP2_MAX_PING_STRIKES)) { + g_default_max_ping_strikes = grpc_channel_arg_get_integer( + &args->args[i], {g_default_max_ping_strikes, 0, INT_MAX}); + } else if (0 == strcmp(args->args[i].key, + GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA)) { + g_default_max_pings_without_data = grpc_channel_arg_get_integer( + &args->args[i], {g_default_max_pings_without_data, 0, INT_MAX}); + } else if (0 == + strcmp( + args->args[i].key, + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS)) { + g_default_min_recv_ping_interval_without_data_ms = + grpc_channel_arg_get_integer( + &args->args[i], + {g_default_min_recv_ping_interval_without_data_ms, 0, INT_MAX}); + } + } + } +} + +static void init_keepalive_ping(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->init_keepalive_ping_locked, + init_keepalive_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void init_keepalive_ping_locked(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + GPR_ASSERT(t->keepalive_state == GRPC_CHTTP2_KEEPALIVE_STATE_WAITING); + if (t->destroying || t->closed_with_error != GRPC_ERROR_NONE) { + t->keepalive_state = GRPC_CHTTP2_KEEPALIVE_STATE_DYING; + } else if (error == GRPC_ERROR_NONE) { + if (t->keepalive_permit_without_calls || + grpc_chttp2_stream_map_size(&t->stream_map) > 0) { + t->keepalive_state = GRPC_CHTTP2_KEEPALIVE_STATE_PINGING; + GRPC_CHTTP2_REF_TRANSPORT(t, "keepalive ping end"); + grpc_timer_init_unset(&t->keepalive_watchdog_timer); + send_keepalive_ping_locked(t); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_KEEPALIVE_PING); + } else { + GRPC_CHTTP2_REF_TRANSPORT(t, "init keepalive ping"); + GRPC_CLOSURE_INIT(&t->init_keepalive_ping_locked, init_keepalive_ping, t, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&t->keepalive_ping_timer, + grpc_core::ExecCtx::Get()->Now() + t->keepalive_time, + &t->init_keepalive_ping_locked); + } + } else if (error == GRPC_ERROR_CANCELLED) { + // The keepalive ping timer may be cancelled by bdp + GRPC_CHTTP2_REF_TRANSPORT(t, "init keepalive ping"); + GRPC_CLOSURE_INIT(&t->init_keepalive_ping_locked, init_keepalive_ping, t, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&t->keepalive_ping_timer, + grpc_core::ExecCtx::Get()->Now() + t->keepalive_time, + &t->init_keepalive_ping_locked); + } + GRPC_CHTTP2_UNREF_TRANSPORT(t, "init keepalive ping"); +} + +static void start_keepalive_ping(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->start_keepalive_ping_locked, + start_keepalive_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void start_keepalive_ping_locked(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + return; + } + if (t->channelz_socket != nullptr) { + t->channelz_socket->RecordKeepaliveSent(); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_keepalive_trace)) { + gpr_log(GPR_INFO, "%s: Start keepalive ping", t->peer_string.c_str()); + } + GRPC_CHTTP2_REF_TRANSPORT(t, "keepalive watchdog"); + 
GRPC_CLOSURE_INIT(&t->keepalive_watchdog_fired_locked, + keepalive_watchdog_fired, t, grpc_schedule_on_exec_ctx); + grpc_timer_init(&t->keepalive_watchdog_timer, + grpc_core::ExecCtx::Get()->Now() + t->keepalive_timeout, + &t->keepalive_watchdog_fired_locked); + t->keepalive_ping_started = true; +} + +static void finish_keepalive_ping(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->finish_keepalive_ping_locked, + finish_keepalive_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void finish_keepalive_ping_locked(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + if (t->keepalive_state == GRPC_CHTTP2_KEEPALIVE_STATE_PINGING) { + if (error == GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_keepalive_trace)) { + gpr_log(GPR_INFO, "%s: Finish keepalive ping", t->peer_string.c_str()); + } + if (!t->keepalive_ping_started) { + // start_keepalive_ping_locked has not run yet. Reschedule + // finish_keepalive_ping_locked for it to be run later. + t->combiner->Run( + GRPC_CLOSURE_INIT(&t->finish_keepalive_ping_locked, + finish_keepalive_ping_locked, t, nullptr), + GRPC_ERROR_REF(error)); + return; + } + t->keepalive_ping_started = false; + t->keepalive_state = GRPC_CHTTP2_KEEPALIVE_STATE_WAITING; + grpc_timer_cancel(&t->keepalive_watchdog_timer); + GRPC_CHTTP2_REF_TRANSPORT(t, "init keepalive ping"); + GRPC_CLOSURE_INIT(&t->init_keepalive_ping_locked, init_keepalive_ping, t, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&t->keepalive_ping_timer, + grpc_core::ExecCtx::Get()->Now() + t->keepalive_time, + &t->init_keepalive_ping_locked); + } + } + GRPC_CHTTP2_UNREF_TRANSPORT(t, "keepalive ping end"); +} + +static void keepalive_watchdog_fired(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + t->combiner->Run( + GRPC_CLOSURE_INIT(&t->keepalive_watchdog_fired_locked, + keepalive_watchdog_fired_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void keepalive_watchdog_fired_locked(void* arg, + grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + if (t->keepalive_state == GRPC_CHTTP2_KEEPALIVE_STATE_PINGING) { + if (error == GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, "%s: Keepalive watchdog fired. Closing transport.", + t->peer_string.c_str()); + t->keepalive_state = GRPC_CHTTP2_KEEPALIVE_STATE_DYING; + close_transport_locked( + t, grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "keepalive watchdog timeout"), + GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE)); + } + } else { + // The watchdog timer should have been cancelled by + // finish_keepalive_ping_locked. 
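+    // (A cancelled timer runs this closure with GRPC_ERROR_CANCELLED, so any
+    // other error value here means the watchdog fired while we were not in
+    // the PINGING state, which points at a keepalive state-machine bug;
+    // hence the error log below.)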
+ if (GPR_UNLIKELY(error != GRPC_ERROR_CANCELLED)) { + gpr_log(GPR_ERROR, "keepalive_ping_end state error: %d (expect: %d)", + t->keepalive_state, GRPC_CHTTP2_KEEPALIVE_STATE_PINGING); + } + } + GRPC_CHTTP2_UNREF_TRANSPORT(t, "keepalive watchdog"); +} + +// +// CALLBACK LOOP +// + +static void connectivity_state_set(grpc_chttp2_transport* t, + grpc_connectivity_state state, + const absl::Status& status, + const char* reason) { + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_INFO, "transport %p set connectivity_state=%d", t, state)); + t->state_tracker.SetState(state, status, reason); +} + +// +// POLLSET STUFF +// + +static void set_pollset(grpc_transport* gt, grpc_stream* /*gs*/, + grpc_pollset* pollset) { + grpc_chttp2_transport* t = reinterpret_cast(gt); + grpc_endpoint_add_to_pollset(t->ep, pollset); +} + +static void set_pollset_set(grpc_transport* gt, grpc_stream* /*gs*/, + grpc_pollset_set* pollset_set) { + grpc_chttp2_transport* t = reinterpret_cast(gt); + grpc_endpoint_add_to_pollset_set(t->ep, pollset_set); +} + +// +// BYTE STREAM +// + +static void reset_byte_stream(void* arg, grpc_error_handle error) { + grpc_chttp2_stream* s = static_cast(arg); + s->pending_byte_stream = false; + if (error == GRPC_ERROR_NONE) { + grpc_chttp2_maybe_complete_recv_message(s->t, s); + grpc_chttp2_maybe_complete_recv_trailing_metadata(s->t, s); + } else { + GPR_ASSERT(error != GRPC_ERROR_NONE); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, s->on_next, GRPC_ERROR_REF(error)); + s->on_next = nullptr; + GRPC_ERROR_UNREF(s->byte_stream_error); + s->byte_stream_error = GRPC_ERROR_NONE; + grpc_chttp2_cancel_stream(s->t, s, GRPC_ERROR_REF(error)); + s->byte_stream_error = GRPC_ERROR_REF(error); + } +} + +namespace grpc_core { + +Chttp2IncomingByteStream::Chttp2IncomingByteStream( + grpc_chttp2_transport* transport, grpc_chttp2_stream* stream, + uint32_t frame_size, uint32_t flags) + : ByteStream(frame_size, flags), + transport_(transport), + stream_(stream), + refs_(2), + remaining_bytes_(frame_size) { + GRPC_ERROR_UNREF(stream->byte_stream_error); + stream->byte_stream_error = GRPC_ERROR_NONE; +} + +void Chttp2IncomingByteStream::OrphanLocked( + void* arg, grpc_error_handle /*error_ignored*/) { + Chttp2IncomingByteStream* bs = static_cast(arg); + grpc_chttp2_stream* s = bs->stream_; + grpc_chttp2_transport* t = s->t; + bs->Unref(); + s->pending_byte_stream = false; + grpc_chttp2_maybe_complete_recv_message(t, s); + grpc_chttp2_maybe_complete_recv_trailing_metadata(t, s); +} + +void Chttp2IncomingByteStream::Orphan() { + GPR_TIMER_SCOPE("incoming_byte_stream_destroy", 0); + transport_->combiner->Run( + GRPC_CLOSURE_INIT(&destroy_action_, + &Chttp2IncomingByteStream::OrphanLocked, this, nullptr), + GRPC_ERROR_NONE); +} + +void Chttp2IncomingByteStream::NextLocked(void* arg, + grpc_error_handle /*error_ignored*/) { + Chttp2IncomingByteStream* bs = static_cast(arg); + grpc_chttp2_transport* t = bs->transport_; + grpc_chttp2_stream* s = bs->stream_; + size_t cur_length = s->frame_storage.length; + if (!s->read_closed) { + s->flow_control->IncomingByteStreamUpdate(bs->next_action_.max_size_hint, + cur_length); + grpc_chttp2_act_on_flowctl_action(s->flow_control->MakeAction(), t, s); + } + GPR_ASSERT(s->unprocessed_incoming_frames_buffer.length == 0); + if (s->frame_storage.length > 0) { + grpc_slice_buffer_swap(&s->frame_storage, + &s->unprocessed_incoming_frames_buffer); + s->unprocessed_incoming_frames_decompressed = false; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, bs->next_action_.on_complete, + GRPC_ERROR_NONE); + } 
else if (s->byte_stream_error != GRPC_ERROR_NONE) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, bs->next_action_.on_complete, + GRPC_ERROR_REF(s->byte_stream_error)); + if (s->data_parser.parsing_frame != nullptr) { + s->data_parser.parsing_frame->Unref(); + s->data_parser.parsing_frame = nullptr; + } + } else if (s->read_closed) { + if (bs->remaining_bytes_ != 0) { + s->byte_stream_error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Truncated message", &s->read_closed_error, 1); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, bs->next_action_.on_complete, + GRPC_ERROR_REF(s->byte_stream_error)); + if (s->data_parser.parsing_frame != nullptr) { + s->data_parser.parsing_frame->Unref(); + s->data_parser.parsing_frame = nullptr; + } + } else { + // Should never reach here. + GPR_ASSERT(false); + } + } else { + s->on_next = bs->next_action_.on_complete; + } + bs->Unref(); +} + +bool Chttp2IncomingByteStream::Next(size_t max_size_hint, + grpc_closure* on_complete) { + GPR_TIMER_SCOPE("incoming_byte_stream_next", 0); + if (stream_->unprocessed_incoming_frames_buffer.length > 0) { + return true; + } else { + Ref(); + next_action_.max_size_hint = max_size_hint; + next_action_.on_complete = on_complete; + transport_->combiner->Run( + GRPC_CLOSURE_INIT(&next_action_.closure, + &Chttp2IncomingByteStream::NextLocked, this, nullptr), + GRPC_ERROR_NONE); + return false; + } +} + +void Chttp2IncomingByteStream::MaybeCreateStreamDecompressionCtx() { + GPR_DEBUG_ASSERT(stream_->stream_decompression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS); + if (!stream_->stream_decompression_ctx) { + stream_->stream_decompression_ctx = grpc_stream_compression_context_create( + stream_->stream_decompression_method); + } +} + +grpc_error_handle Chttp2IncomingByteStream::Pull(grpc_slice* slice) { + GPR_TIMER_SCOPE("incoming_byte_stream_pull", 0); + grpc_error_handle error; + if (stream_->unprocessed_incoming_frames_buffer.length > 0) { + if (!stream_->unprocessed_incoming_frames_decompressed && + stream_->stream_decompression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS) { + bool end_of_context; + MaybeCreateStreamDecompressionCtx(); + if (!grpc_stream_decompress(stream_->stream_decompression_ctx, + &stream_->unprocessed_incoming_frames_buffer, + &stream_->decompressed_data_buffer, nullptr, + MAX_SIZE_T, &end_of_context)) { + error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Stream decompression error."); + return error; + } + GPR_ASSERT(stream_->unprocessed_incoming_frames_buffer.length == 0); + grpc_slice_buffer_swap(&stream_->unprocessed_incoming_frames_buffer, + &stream_->decompressed_data_buffer); + stream_->unprocessed_incoming_frames_decompressed = true; + if (end_of_context) { + grpc_stream_compression_context_destroy( + stream_->stream_decompression_ctx); + stream_->stream_decompression_ctx = nullptr; + } + if (stream_->unprocessed_incoming_frames_buffer.length == 0) { + *slice = grpc_empty_slice(); + } + } + error = grpc_deframe_unprocessed_incoming_frames( + &stream_->data_parser, stream_, + &stream_->unprocessed_incoming_frames_buffer, slice, nullptr); + if (error != GRPC_ERROR_NONE) { + return error; + } + } else { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Truncated message"); + stream_->t->combiner->Run(&stream_->reset_byte_stream, + GRPC_ERROR_REF(error)); + return error; + } + return GRPC_ERROR_NONE; +} + +void Chttp2IncomingByteStream::PublishError(grpc_error_handle error) { + GPR_ASSERT(error != GRPC_ERROR_NONE); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, stream_->on_next, + 
GRPC_ERROR_REF(error)); + stream_->on_next = nullptr; + GRPC_ERROR_UNREF(stream_->byte_stream_error); + stream_->byte_stream_error = GRPC_ERROR_REF(error); + grpc_chttp2_cancel_stream(transport_, stream_, GRPC_ERROR_REF(error)); +} + +grpc_error_handle Chttp2IncomingByteStream::Push(const grpc_slice& slice, + grpc_slice* slice_out) { + if (remaining_bytes_ < GRPC_SLICE_LENGTH(slice)) { + grpc_error_handle error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Too many bytes in stream"); + transport_->combiner->Run(&stream_->reset_byte_stream, + GRPC_ERROR_REF(error)); + grpc_slice_unref_internal(slice); + return error; + } else { + remaining_bytes_ -= static_cast GRPC_SLICE_LENGTH(slice); + if (slice_out != nullptr) { + *slice_out = slice; + } + return GRPC_ERROR_NONE; + } +} + +grpc_error_handle Chttp2IncomingByteStream::Finished(grpc_error_handle error, + bool reset_on_error) { + if (error == GRPC_ERROR_NONE) { + if (remaining_bytes_ != 0) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Truncated message"); + } + } + if (error != GRPC_ERROR_NONE && reset_on_error) { + transport_->combiner->Run(&stream_->reset_byte_stream, + GRPC_ERROR_REF(error)); + } + Unref(); + return error; +} + +void Chttp2IncomingByteStream::Shutdown(grpc_error_handle error) { + GRPC_ERROR_UNREF(Finished(error, true /* reset_on_error */)); +} + +} // namespace grpc_core + +// +// RESOURCE QUOTAS +// + +static void post_benign_reclaimer(grpc_chttp2_transport* t) { + if (!t->benign_reclaimer_registered) { + t->benign_reclaimer_registered = true; + GRPC_CHTTP2_REF_TRANSPORT(t, "benign_reclaimer"); + GRPC_CLOSURE_INIT(&t->benign_reclaimer_locked, benign_reclaimer, t, + grpc_schedule_on_exec_ctx); + grpc_resource_user_post_reclaimer(t->resource_user, false, + &t->benign_reclaimer_locked); + } +} + +static void post_destructive_reclaimer(grpc_chttp2_transport* t) { + if (!t->destructive_reclaimer_registered) { + t->destructive_reclaimer_registered = true; + GRPC_CHTTP2_REF_TRANSPORT(t, "destructive_reclaimer"); + GRPC_CLOSURE_INIT(&t->destructive_reclaimer_locked, destructive_reclaimer, + t, grpc_schedule_on_exec_ctx); + grpc_resource_user_post_reclaimer(t->resource_user, true, + &t->destructive_reclaimer_locked); + } +} + +static void benign_reclaimer(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->benign_reclaimer_locked, + benign_reclaimer_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void benign_reclaimer_locked(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + if (error == GRPC_ERROR_NONE && + grpc_chttp2_stream_map_size(&t->stream_map) == 0) { + // Channel with no active streams: send a goaway to try and make it + // disconnect cleanly + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "HTTP2: %s - send goaway to free memory", + t->peer_string.c_str()); + } + send_goaway(t, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Buffers full"), + GRPC_ERROR_INT_HTTP2_ERROR, GRPC_HTTP2_ENHANCE_YOUR_CALM)); + } else if (error == GRPC_ERROR_NONE && + GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, + "HTTP2: %s - skip benign reclamation, there are still %" PRIdPTR + " streams", + t->peer_string.c_str(), + grpc_chttp2_stream_map_size(&t->stream_map)); + } + t->benign_reclaimer_registered = false; + if (error != GRPC_ERROR_CANCELLED) { + grpc_resource_user_finish_reclamation(t->resource_user); + } + GRPC_CHTTP2_UNREF_TRANSPORT(t, 
"benign_reclaimer"); +} + +static void destructive_reclaimer(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + t->combiner->Run(GRPC_CLOSURE_INIT(&t->destructive_reclaimer_locked, + destructive_reclaimer_locked, t, nullptr), + GRPC_ERROR_REF(error)); +} + +static void destructive_reclaimer_locked(void* arg, grpc_error_handle error) { + grpc_chttp2_transport* t = static_cast(arg); + size_t n = grpc_chttp2_stream_map_size(&t->stream_map); + t->destructive_reclaimer_registered = false; + if (error == GRPC_ERROR_NONE && n > 0) { + grpc_chttp2_stream* s = static_cast( + grpc_chttp2_stream_map_rand(&t->stream_map)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "HTTP2: %s - abandon stream id %d", + t->peer_string.c_str(), s->id); + } + grpc_chttp2_cancel_stream( + t, s, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Buffers full"), + GRPC_ERROR_INT_HTTP2_ERROR, + GRPC_HTTP2_ENHANCE_YOUR_CALM)); + if (n > 1) { + // Since we cancel one stream per destructive reclamation, if + // there are more streams left, we can immediately post a new + // reclaimer in case the resource quota needs to free more + // memory + post_destructive_reclaimer(t); + } + } + if (error != GRPC_ERROR_CANCELLED) { + grpc_resource_user_finish_reclamation(t->resource_user); + } + GRPC_CHTTP2_UNREF_TRANSPORT(t, "destructive_reclaimer"); +} + +// +// MONITORING +// + +const char* grpc_chttp2_initiate_write_reason_string( + grpc_chttp2_initiate_write_reason reason) { + switch (reason) { + case GRPC_CHTTP2_INITIATE_WRITE_INITIAL_WRITE: + return "INITIAL_WRITE"; + case GRPC_CHTTP2_INITIATE_WRITE_START_NEW_STREAM: + return "START_NEW_STREAM"; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_MESSAGE: + return "SEND_MESSAGE"; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_INITIAL_METADATA: + return "SEND_INITIAL_METADATA"; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_TRAILING_METADATA: + return "SEND_TRAILING_METADATA"; + case GRPC_CHTTP2_INITIATE_WRITE_RETRY_SEND_PING: + return "RETRY_SEND_PING"; + case GRPC_CHTTP2_INITIATE_WRITE_CONTINUE_PINGS: + return "CONTINUE_PINGS"; + case GRPC_CHTTP2_INITIATE_WRITE_GOAWAY_SENT: + return "GOAWAY_SENT"; + case GRPC_CHTTP2_INITIATE_WRITE_RST_STREAM: + return "RST_STREAM"; + case GRPC_CHTTP2_INITIATE_WRITE_CLOSE_FROM_API: + return "CLOSE_FROM_API"; + case GRPC_CHTTP2_INITIATE_WRITE_STREAM_FLOW_CONTROL: + return "STREAM_FLOW_CONTROL"; + case GRPC_CHTTP2_INITIATE_WRITE_TRANSPORT_FLOW_CONTROL: + return "TRANSPORT_FLOW_CONTROL"; + case GRPC_CHTTP2_INITIATE_WRITE_SEND_SETTINGS: + return "SEND_SETTINGS"; + case GRPC_CHTTP2_INITIATE_WRITE_FLOW_CONTROL_UNSTALLED_BY_SETTING: + return "FLOW_CONTROL_UNSTALLED_BY_SETTING"; + case GRPC_CHTTP2_INITIATE_WRITE_FLOW_CONTROL_UNSTALLED_BY_UPDATE: + return "FLOW_CONTROL_UNSTALLED_BY_UPDATE"; + case GRPC_CHTTP2_INITIATE_WRITE_APPLICATION_PING: + return "APPLICATION_PING"; + case GRPC_CHTTP2_INITIATE_WRITE_BDP_PING: + return "BDP_PING"; + case GRPC_CHTTP2_INITIATE_WRITE_KEEPALIVE_PING: + return "KEEPALIVE_PING"; + case GRPC_CHTTP2_INITIATE_WRITE_TRANSPORT_FLOW_CONTROL_UNSTALLED: + return "TRANSPORT_FLOW_CONTROL_UNSTALLED"; + case GRPC_CHTTP2_INITIATE_WRITE_PING_RESPONSE: + return "PING_RESPONSE"; + case GRPC_CHTTP2_INITIATE_WRITE_FORCE_RST_STREAM: + return "FORCE_RST_STREAM"; + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +static grpc_endpoint* chttp2_get_endpoint(grpc_transport* t) { + return (reinterpret_cast(t))->ep; +} + +static const grpc_transport_vtable vtable = {sizeof(grpc_chttp2_stream), + 
"chttp2", + init_stream, + set_pollset, + set_pollset_set, + perform_stream_op, + perform_transport_op, + destroy_stream, + destroy_transport, + chttp2_get_endpoint}; + +static const grpc_transport_vtable* get_vtable(void) { return &vtable; } + +grpc_core::RefCountedPtr +grpc_chttp2_transport_get_socket_node(grpc_transport* transport) { + grpc_chttp2_transport* t = + reinterpret_cast(transport); + return t->channelz_socket; +} + +grpc_transport* grpc_create_chttp2_transport( + const grpc_channel_args* channel_args, grpc_endpoint* ep, bool is_client, + grpc_resource_user* resource_user) { + auto t = + new grpc_chttp2_transport(channel_args, ep, is_client, resource_user); + return &t->base; +} + +void grpc_chttp2_transport_start_reading( + grpc_transport* transport, grpc_slice_buffer* read_buffer, + grpc_closure* notify_on_receive_settings, grpc_closure* notify_on_close) { + grpc_chttp2_transport* t = + reinterpret_cast(transport); + GRPC_CHTTP2_REF_TRANSPORT( + t, "reading_action"); /* matches unref inside reading_action */ + if (read_buffer != nullptr) { + grpc_slice_buffer_move_into(read_buffer, &t->read_buffer); + gpr_free(read_buffer); + } + t->notify_on_receive_settings = notify_on_receive_settings; + t->notify_on_close = notify_on_close; + t->combiner->Run( + GRPC_CLOSURE_INIT(&t->read_action_locked, read_action_locked, t, nullptr), + GRPC_ERROR_NONE); +} diff --git a/src/core/ext/transport/chttp2/transport/context_list.cc b/src/core/ext/transport/chttp2/transport/context_list.cc new file mode 100644 index 00000000..afb18abc --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/context_list.cc @@ -0,0 +1,68 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/ext/transport/chttp2/transport/context_list.h"
+
+namespace {
+void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*,
+                                    grpc_error_handle error) = nullptr;
+void* (*get_copied_context_fn_g)(void*) = nullptr;
+}  // namespace
+
+namespace grpc_core {
+void ContextList::Append(ContextList** head, grpc_chttp2_stream* s) {
+  if (get_copied_context_fn_g == nullptr ||
+      write_timestamps_callback_g == nullptr) {
+    return;
+  }
+  /* Create a new element in the list and add it at the front */
+  ContextList* elem = new ContextList();
+  elem->trace_context_ = get_copied_context_fn_g(s->context);
+  elem->byte_offset_ = s->byte_counter;
+  elem->next_ = *head;
+  *head = elem;
+}
+
+void ContextList::Execute(void* arg, grpc_core::Timestamps* ts,
+                          grpc_error_handle error) {
+  ContextList* head = static_cast<ContextList*>(arg);
+  ContextList* to_be_freed;
+  while (head != nullptr) {
+    if (write_timestamps_callback_g) {
+      if (ts) {
+        ts->byte_offset = static_cast<uint32_t>(head->byte_offset_);
+      }
+      write_timestamps_callback_g(head->trace_context_, ts, error);
+    }
+    to_be_freed = head;
+    head = head->next_;
+    delete to_be_freed;
+  }
+}
+
+void grpc_http2_set_write_timestamps_callback(
+    void (*fn)(void*, grpc_core::Timestamps*, grpc_error_handle error)) {
+  write_timestamps_callback_g = fn;
+}
+
+void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)) {
+  get_copied_context_fn_g = fn;
+}
+} /* namespace grpc_core */
diff --git a/src/core/ext/transport/chttp2/transport/flow_control.cc b/src/core/ext/transport/chttp2/transport/flow_control.cc
new file mode 100644
index 00000000..739d15df
--- /dev/null
+++ b/src/core/ext/transport/chttp2/transport/flow_control.cc
@@ -0,0 +1,430 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/flow_control.h" + +#include +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/gpr/string.h" + +grpc_core::TraceFlag grpc_flowctl_trace(false, "flowctl"); + +namespace grpc_core { +namespace chttp2 { + +TestOnlyTransportTargetWindowEstimatesMocker* + g_test_only_transport_target_window_estimates_mocker; + +bool g_test_only_transport_flow_control_window_check; + +namespace { + +static constexpr const int kTracePadding = 30; +static constexpr const int64_t kMaxWindowUpdateSize = (1u << 31) - 1; + +static char* fmt_int64_diff_str(int64_t old_val, int64_t new_val) { + std::string str; + if (old_val != new_val) { + str = absl::StrFormat("%" PRId64 " -> %" PRId64 "", old_val, new_val); + } else { + str = absl::StrFormat("%" PRId64 "", old_val); + } + return gpr_leftpad(str.c_str(), ' ', kTracePadding); +} + +static char* fmt_uint32_diff_str(uint32_t old_val, uint32_t new_val) { + std::string str; + if (old_val != new_val) { + str = absl::StrFormat("%" PRIu32 " -> %" PRIu32 "", old_val, new_val); + } else { + str = absl::StrFormat("%" PRIu32 "", old_val); + } + return gpr_leftpad(str.c_str(), ' ', kTracePadding); +} +} // namespace + +void FlowControlTrace::Init(const char* reason, TransportFlowControl* tfc, + StreamFlowControl* sfc) { + tfc_ = tfc; + sfc_ = sfc; + reason_ = reason; + remote_window_ = tfc->remote_window(); + target_window_ = tfc->target_window(); + announced_window_ = tfc->announced_window(); + if (sfc != nullptr) { + remote_window_delta_ = sfc->remote_window_delta(); + local_window_delta_ = sfc->local_window_delta(); + announced_window_delta_ = sfc->announced_window_delta(); + } +} + +void FlowControlTrace::Finish() { + uint32_t acked_local_window = + tfc_->transport()->settings[GRPC_SENT_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]; + uint32_t remote_window = + tfc_->transport()->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]; + char* trw_str = fmt_int64_diff_str(remote_window_, tfc_->remote_window()); + char* tlw_str = fmt_int64_diff_str(target_window_, tfc_->target_window()); + char* taw_str = + fmt_int64_diff_str(announced_window_, tfc_->announced_window()); + char* srw_str; + char* slw_str; + char* saw_str; + if (sfc_ != nullptr) { + srw_str = fmt_int64_diff_str(remote_window_delta_ + remote_window, + sfc_->remote_window_delta() + remote_window); + slw_str = + fmt_int64_diff_str(local_window_delta_ + acked_local_window, + sfc_->local_window_delta() + acked_local_window); + saw_str = + fmt_int64_diff_str(announced_window_delta_ + acked_local_window, + sfc_->announced_window_delta() + acked_local_window); + } else { + srw_str = gpr_leftpad("", ' ', kTracePadding); + slw_str = gpr_leftpad("", ' ', kTracePadding); + saw_str = gpr_leftpad("", ' ', kTracePadding); + } + gpr_log(GPR_DEBUG, + "%p[%u][%s] | %s | trw:%s, tlw:%s, taw:%s, srw:%s, slw:%s, saw:%s", + tfc_, sfc_ != nullptr ? sfc_->stream()->id : 0, + tfc_->transport()->is_client ? 
"cli" : "svr", reason_, trw_str, + tlw_str, taw_str, srw_str, slw_str, saw_str); + gpr_free(trw_str); + gpr_free(tlw_str); + gpr_free(taw_str); + gpr_free(srw_str); + gpr_free(slw_str); + gpr_free(saw_str); +} + +const char* FlowControlAction::UrgencyString(Urgency u) { + switch (u) { + case Urgency::NO_ACTION_NEEDED: + return "no action"; + case Urgency::UPDATE_IMMEDIATELY: + return "update immediately"; + case Urgency::QUEUE_UPDATE: + return "queue update"; + default: + GPR_UNREACHABLE_CODE(return "unknown"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +void FlowControlAction::Trace(grpc_chttp2_transport* t) const { + char* iw_str = fmt_uint32_diff_str( + t->settings[GRPC_SENT_SETTINGS][GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + initial_window_size_); + char* mf_str = fmt_uint32_diff_str( + t->settings[GRPC_SENT_SETTINGS][GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE], + max_frame_size_); + gpr_log(GPR_DEBUG, "t[%s], s[%s], iw:%s:%s mf:%s:%s", + UrgencyString(send_transport_update_), + UrgencyString(send_stream_update_), + UrgencyString(send_initial_window_update_), iw_str, + UrgencyString(send_max_frame_size_update_), mf_str); + gpr_free(iw_str); + gpr_free(mf_str); +} + +TransportFlowControlDisabled::TransportFlowControlDisabled( + grpc_chttp2_transport* t) { + remote_window_ = kMaxWindow; + target_initial_window_size_ = kMaxWindow; + announced_window_ = kMaxWindow; + t->settings[GRPC_PEER_SETTINGS][GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE] = + kFrameSize; + t->settings[GRPC_SENT_SETTINGS][GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE] = + kFrameSize; + t->settings[GRPC_ACKED_SETTINGS][GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE] = + kFrameSize; + t->settings[GRPC_PEER_SETTINGS][GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE] = + kMaxWindow; + t->settings[GRPC_SENT_SETTINGS][GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE] = + kMaxWindow; + t->settings[GRPC_ACKED_SETTINGS][GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE] = + kMaxWindow; +} + +TransportFlowControl::TransportFlowControl(const grpc_chttp2_transport* t, + bool enable_bdp_probe) + : t_(t), + enable_bdp_probe_(enable_bdp_probe), + bdp_estimator_(t->peer_string.c_str()), + pid_controller_(PidController::Args() + .set_gain_p(4) + .set_gain_i(8) + .set_gain_d(0) + .set_initial_control_value(TargetLogBdp()) + .set_min_control_value(-1) + .set_max_control_value(25) + .set_integral_range(10)), + last_pid_update_(ExecCtx::Get()->Now()) {} + +uint32_t TransportFlowControl::MaybeSendUpdate(bool writing_anyway) { + FlowControlTrace trace("t updt sent", this, nullptr); + const uint32_t target_announced_window = + static_cast(target_window()); + if ((writing_anyway || announced_window_ <= target_announced_window / 2) && + announced_window_ != target_announced_window) { + const uint32_t announce = + static_cast(Clamp(target_announced_window - announced_window_, + int64_t(0), kMaxWindowUpdateSize)); + announced_window_ += announce; + return announce; + } + return 0; +} + +grpc_error_handle TransportFlowControl::ValidateRecvData( + int64_t incoming_frame_size) { + if (incoming_frame_size > announced_window_) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "frame of size %" PRId64 " overflows local window of %" PRId64, + incoming_frame_size, announced_window_)); + } + return GRPC_ERROR_NONE; +} + +StreamFlowControl::StreamFlowControl(TransportFlowControl* tfc, + const grpc_chttp2_stream* s) + : tfc_(tfc), s_(s) {} + +grpc_error_handle StreamFlowControl::RecvData(int64_t incoming_frame_size) { + FlowControlTrace trace(" data recv", tfc_, this); + + grpc_error_handle 
error = GRPC_ERROR_NONE; + error = tfc_->ValidateRecvData(incoming_frame_size); + if (error != GRPC_ERROR_NONE) return error; + + uint32_t sent_init_window = + tfc_->transport()->settings[GRPC_SENT_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]; + uint32_t acked_init_window = + tfc_->transport()->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]; + + int64_t acked_stream_window = announced_window_delta_ + acked_init_window; + int64_t sent_stream_window = announced_window_delta_ + sent_init_window; + if (incoming_frame_size > acked_stream_window) { + if (incoming_frame_size <= sent_stream_window) { + gpr_log(GPR_ERROR, + "Incoming frame of size %" PRId64 + " exceeds local window size of %" PRId64 + ".\n" + "The (un-acked, future) window size would be %" PRId64 + " which is not exceeded.\n" + "This would usually cause a disconnection, but allowing it due to" + "broken HTTP2 implementations in the wild.\n" + "See (for example) https://github.com/netty/netty/issues/6520.", + incoming_frame_size, acked_stream_window, sent_stream_window); + } else { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "frame of size %" PRId64 " overflows local window of %" PRId64, + incoming_frame_size, acked_stream_window)); + } + } + + UpdateAnnouncedWindowDelta(tfc_, -incoming_frame_size); + local_window_delta_ -= incoming_frame_size; + tfc_->CommitRecvData(incoming_frame_size); + return GRPC_ERROR_NONE; +} + +uint32_t StreamFlowControl::MaybeSendUpdate() { + FlowControlTrace trace("s updt sent", tfc_, this); + // If a recently sent settings frame caused the stream's flow control window + // to go in the negative (or < GRPC_HEADER_SIZE_IN_BYTES), update the delta if + // one of the following conditions is satisfied - + // 1) There is a pending byte_stream and higher layers have expressed interest + // in reading additional data through the invokation of `Next()` where the + // bytes are to be available asynchronously. 2) There is a pending + // recv_message op. + // In these cases, we want to make sure that bytes are still flowing. 
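// Once the delta has been bumped as described in the comment above, the
// WINDOW_UPDATE size that gets announced reduces to simple integer
// arithmetic. A minimal standalone sketch of that computation, using plain
// integers in place of the flow-control classes; the function name and
// signature are illustrative only.
#include <algorithm>
#include <cstdint>

// How many bytes should the next stream-level WINDOW_UPDATE announce?
// It is the gap between what we are now willing to receive
// (local_window_delta) and what we have already advertised to the peer
// (announced_window_delta), clamped to the largest legal HTTP/2 window
// increment (2^31 - 1). A non-positive gap means no update is needed.
inline uint32_t SketchStreamWindowAnnounce(int64_t local_window_delta,
                                           int64_t announced_window_delta) {
  constexpr int64_t kMaxWindowUpdate = (1u << 31) - 1;
  const int64_t gap = local_window_delta - announced_window_delta;
  if (gap <= 0) return 0;
  return static_cast<uint32_t>(std::min(gap, kMaxWindowUpdate));
}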
+ if (local_window_delta_ < GRPC_HEADER_SIZE_IN_BYTES) { + if (s_->on_next != nullptr) { + GPR_DEBUG_ASSERT(s_->pending_byte_stream); + IncomingByteStreamUpdate(GRPC_HEADER_SIZE_IN_BYTES, 0); + } else if (s_->recv_message != nullptr) { + IncomingByteStreamUpdate(GRPC_HEADER_SIZE_IN_BYTES, + s_->frame_storage.length); + } + } + if (local_window_delta_ > announced_window_delta_) { + uint32_t announce = static_cast( + Clamp(local_window_delta_ - announced_window_delta_, int64_t(0), + kMaxWindowUpdateSize)); + UpdateAnnouncedWindowDelta(tfc_, announce); + return announce; + } + return 0; +} + +void StreamFlowControl::IncomingByteStreamUpdate(size_t max_size_hint, + size_t have_already) { + FlowControlTrace trace("app st recv", tfc_, this); + uint32_t max_recv_bytes; + + /* clamp max recv hint to an allowable size */ + if (max_size_hint >= kMaxWindowDelta) { + max_recv_bytes = kMaxWindowDelta; + } else { + max_recv_bytes = static_cast(max_size_hint); + } + + /* account for bytes already received but unknown to higher layers */ + if (max_recv_bytes >= have_already) { + max_recv_bytes -= static_cast(have_already); + } else { + max_recv_bytes = 0; + } + + /* add some small lookahead to keep pipelines flowing */ + GPR_DEBUG_ASSERT( + max_recv_bytes <= + kMaxWindowUpdateSize - + tfc_->transport() + ->settings[GRPC_SENT_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]); + if (local_window_delta_ < max_recv_bytes) { + uint32_t add_max_recv_bytes = + static_cast(max_recv_bytes - local_window_delta_); + local_window_delta_ += add_max_recv_bytes; + } +} + +// Take in a target and modifies it based on the memory pressure of the system +static double AdjustForMemoryPressure(grpc_resource_quota* quota, + double target) { + // do not increase window under heavy memory pressure. + double memory_pressure = grpc_resource_quota_get_memory_pressure(quota); + static const double kLowMemPressure = 0.1; + static const double kZeroTarget = 22; + static const double kHighMemPressure = 0.8; + static const double kMaxMemPressure = 0.9; + if (memory_pressure < kLowMemPressure && target < kZeroTarget) { + target = (target - kZeroTarget) * memory_pressure / kLowMemPressure + + kZeroTarget; + } else if (memory_pressure > kHighMemPressure) { + target *= 1 - std::min(1.0, (memory_pressure - kHighMemPressure) / + (kMaxMemPressure - kHighMemPressure)); + } + return target; +} + +double TransportFlowControl::TargetLogBdp() { + return AdjustForMemoryPressure(grpc_resource_user_quota(t_->resource_user), + 1 + log2(bdp_estimator_.EstimateBdp())); +} + +double TransportFlowControl::SmoothLogBdp(double value) { + grpc_millis now = ExecCtx::Get()->Now(); + double bdp_error = value - pid_controller_.last_control_value(); + const double dt = static_cast(now - last_pid_update_) * 1e-3; + last_pid_update_ = now; + // Limit dt to 100ms + const double kMaxDt = 0.1; + return pid_controller_.Update(bdp_error, dt > kMaxDt ? kMaxDt : dt); +} + +FlowControlAction::Urgency TransportFlowControl::DeltaUrgency( + int64_t value, grpc_chttp2_setting_id setting_id) { + int64_t delta = value - static_cast( + t_->settings[GRPC_LOCAL_SETTINGS][setting_id]); + // TODO(ncteisen): tune this + if (delta != 0 && (delta <= -value / 5 || delta >= value / 5)) { + return FlowControlAction::Urgency::QUEUE_UPDATE; + } else { + return FlowControlAction::Urgency::NO_ACTION_NEEDED; + } +} + +FlowControlAction TransportFlowControl::PeriodicUpdate() { + FlowControlAction action; + if (enable_bdp_probe_) { + // get bdp estimate and update initial_window accordingly. 
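// The 20% rule used by DeltaUrgency above is easy to miss in the integer
// arithmetic, so here is a minimal standalone restatement of it with plain
// integers; the function name is illustrative only. A settings update is
// queued only when the proposed value differs from the value currently
// recorded in local settings by at least one fifth of the proposed value.
#include <cstdint>

inline bool SketchWorthQueueingSettingsUpdate(int64_t proposed_value,
                                              int64_t local_settings_value) {
  const int64_t delta = proposed_value - local_settings_value;
  return delta != 0 &&
         (delta <= -proposed_value / 5 || delta >= proposed_value / 5);
}
// Example: moving the initial window from 65535 to 70000 (delta 4465, well
// under 70000 / 5 = 14000) queues nothing, while doubling it to 131070
// (delta 65535 >= 131070 / 5 = 26214) does.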
+ // target might change based on how much memory pressure we are under + // TODO(ncteisen): experiment with setting target to be huge under low + // memory pressure. + double target = pow(2, SmoothLogBdp(TargetLogBdp())); + if (g_test_only_transport_target_window_estimates_mocker != nullptr) { + // Hook for simulating unusual flow control situations in tests. + target = g_test_only_transport_target_window_estimates_mocker + ->ComputeNextTargetInitialWindowSizeFromPeriodicUpdate( + target_initial_window_size_ /* current target */); + } + // Though initial window 'could' drop to 0, we keep the floor at + // kMinInitialWindowSize + target_initial_window_size_ = static_cast(Clamp( + target, double(kMinInitialWindowSize), double(kMaxInitialWindowSize))); + action.set_send_initial_window_update( + DeltaUrgency(target_initial_window_size_, + GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE), + static_cast(target_initial_window_size_)); + + // get bandwidth estimate and update max_frame accordingly. + double bw_dbl = bdp_estimator_.EstimateBandwidth(); + // we target the max of BDP or bandwidth in microseconds. + int32_t frame_size = static_cast(Clamp( + std::max( + static_cast(Clamp(bw_dbl, 0.0, double(INT_MAX))) / 1000, + static_cast(target_initial_window_size_)), + 16384, 16777215)); + action.set_send_max_frame_size_update( + DeltaUrgency(static_cast(frame_size), + GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE), + frame_size); + } + return UpdateAction(action); +} + +FlowControlAction StreamFlowControl::UpdateAction(FlowControlAction action) { + // TODO(ncteisen): tune this + if (!s_->read_closed) { + uint32_t sent_init_window = + tfc_->transport()->settings[GRPC_SENT_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]; + if (local_window_delta_ > announced_window_delta_ && + announced_window_delta_ + sent_init_window <= sent_init_window / 2) { + action.set_send_stream_update( + FlowControlAction::Urgency::UPDATE_IMMEDIATELY); + } else if (local_window_delta_ > announced_window_delta_) { + action.set_send_stream_update(FlowControlAction::Urgency::QUEUE_UPDATE); + } + } + + return action; +} + +} // namespace chttp2 +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/transport/frame_data.cc b/src/core/ext/transport/chttp2/transport/frame_data.cc new file mode 100644 index 00000000..91a3f2d7 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/frame_data.cc @@ -0,0 +1,308 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/frame_data.h" + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/transport.h" + +grpc_chttp2_data_parser::~grpc_chttp2_data_parser() { + if (parsing_frame != nullptr) { + GRPC_ERROR_UNREF(parsing_frame->Finished( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Parser destroyed"), false)); + } + GRPC_ERROR_UNREF(error); +} + +grpc_error_handle grpc_chttp2_data_parser_begin_frame( + grpc_chttp2_data_parser* /*parser*/, uint8_t flags, uint32_t stream_id, + grpc_chttp2_stream* s) { + if (flags & ~GRPC_CHTTP2_DATA_FLAG_END_STREAM) { + return grpc_error_set_int(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "unsupported data flags: 0x%02x", flags)), + GRPC_ERROR_INT_STREAM_ID, + static_cast(stream_id)); + } + + if (flags & GRPC_CHTTP2_DATA_FLAG_END_STREAM) { + s->received_last_frame = true; + s->eos_received = true; + } else { + s->received_last_frame = false; + } + + return GRPC_ERROR_NONE; +} + +void grpc_chttp2_encode_data(uint32_t id, grpc_slice_buffer* inbuf, + uint32_t write_bytes, int is_eof, + grpc_transport_one_way_stats* stats, + grpc_slice_buffer* outbuf) { + grpc_slice hdr; + uint8_t* p; + static const size_t header_size = 9; + + hdr = GRPC_SLICE_MALLOC(header_size); + p = GRPC_SLICE_START_PTR(hdr); + GPR_ASSERT(write_bytes < (1 << 24)); + *p++ = static_cast(write_bytes >> 16); + *p++ = static_cast(write_bytes >> 8); + *p++ = static_cast(write_bytes); + *p++ = GRPC_CHTTP2_FRAME_DATA; + *p++ = is_eof ? 
GRPC_CHTTP2_DATA_FLAG_END_STREAM : 0; + *p++ = static_cast(id >> 24); + *p++ = static_cast(id >> 16); + *p++ = static_cast(id >> 8); + *p++ = static_cast(id); + grpc_slice_buffer_add(outbuf, hdr); + + grpc_slice_buffer_move_first_no_ref(inbuf, write_bytes, outbuf); + + stats->framing_bytes += header_size; + stats->data_bytes += write_bytes; +} + +grpc_error_handle grpc_deframe_unprocessed_incoming_frames( + grpc_chttp2_data_parser* p, grpc_chttp2_stream* s, + grpc_slice_buffer* slices, grpc_slice* slice_out, + grpc_core::OrphanablePtr* stream_out) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_chttp2_transport* t = s->t; + + while (slices->count > 0) { + uint8_t* beg = nullptr; + uint8_t* end = nullptr; + uint8_t* cur = nullptr; + + grpc_slice* slice = grpc_slice_buffer_peek_first(slices); + beg = GRPC_SLICE_START_PTR(*slice); + end = GRPC_SLICE_END_PTR(*slice); + cur = beg; + uint32_t message_flags; + + if (cur == end) { + grpc_slice_buffer_remove_first(slices); + continue; + } + + switch (p->state) { + case GRPC_CHTTP2_DATA_ERROR: + p->state = GRPC_CHTTP2_DATA_ERROR; + grpc_slice_buffer_remove_first(slices); + return GRPC_ERROR_REF(p->error); + case GRPC_CHTTP2_DATA_FH_0: + s->stats.incoming.framing_bytes++; + p->frame_type = *cur; + switch (p->frame_type) { + case 0: + p->is_frame_compressed = false; /* GPR_FALSE */ + break; + case 1: + p->is_frame_compressed = true; /* GPR_TRUE */ + break; + default: + p->error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Bad GRPC frame type 0x%02x", p->frame_type)); + p->error = grpc_error_set_int(p->error, GRPC_ERROR_INT_STREAM_ID, + static_cast(s->id)); + grpc_core::UniquePtr dmp( + grpc_dump_slice(*slice, GPR_DUMP_HEX | GPR_DUMP_ASCII)); + p->error = grpc_error_set_str(p->error, GRPC_ERROR_STR_RAW_BYTES, + dmp.get()); + p->error = + grpc_error_set_int(p->error, GRPC_ERROR_INT_OFFSET, cur - beg); + p->state = GRPC_CHTTP2_DATA_ERROR; + grpc_slice_buffer_remove_first(slices); + return GRPC_ERROR_REF(p->error); + } + if (++cur == end) { + p->state = GRPC_CHTTP2_DATA_FH_1; + grpc_slice_buffer_remove_first(slices); + continue; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_DATA_FH_1: + s->stats.incoming.framing_bytes++; + p->frame_size = (static_cast(*cur)) << 24; + if (++cur == end) { + p->state = GRPC_CHTTP2_DATA_FH_2; + grpc_slice_buffer_remove_first(slices); + continue; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_DATA_FH_2: + s->stats.incoming.framing_bytes++; + p->frame_size |= (static_cast(*cur)) << 16; + if (++cur == end) { + p->state = GRPC_CHTTP2_DATA_FH_3; + grpc_slice_buffer_remove_first(slices); + continue; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_DATA_FH_3: + s->stats.incoming.framing_bytes++; + p->frame_size |= (static_cast(*cur)) << 8; + if (++cur == end) { + p->state = GRPC_CHTTP2_DATA_FH_4; + grpc_slice_buffer_remove_first(slices); + continue; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_DATA_FH_4: + s->stats.incoming.framing_bytes++; + GPR_ASSERT(stream_out != nullptr); + GPR_ASSERT(p->parsing_frame == nullptr); + p->frame_size |= (static_cast(*cur)); + if (t->channelz_socket != nullptr) { + t->channelz_socket->RecordMessageReceived(); + } + p->state = GRPC_CHTTP2_DATA_FRAME; + ++cur; + message_flags = 0; + if (p->is_frame_compressed) { + message_flags |= GRPC_WRITE_INTERNAL_COMPRESS; + } + p->parsing_frame = new grpc_core::Chttp2IncomingByteStream( + t, s, p->frame_size, message_flags); + stream_out->reset(p->parsing_frame); + if (p->parsing_frame->remaining_bytes() == 0) { + 
GRPC_ERROR_UNREF(p->parsing_frame->Finished(GRPC_ERROR_NONE, true)); + p->parsing_frame = nullptr; + p->state = GRPC_CHTTP2_DATA_FH_0; + } + s->pending_byte_stream = true; + if (cur != end) { + grpc_slice_buffer_sub_first(slices, static_cast(cur - beg), + static_cast(end - beg)); + } else { + grpc_slice_buffer_remove_first(slices); + } + return GRPC_ERROR_NONE; + case GRPC_CHTTP2_DATA_FRAME: { + GPR_ASSERT(p->parsing_frame != nullptr); + GPR_ASSERT(slice_out != nullptr); + if (cur == end) { + grpc_slice_buffer_remove_first(slices); + continue; + } + uint32_t remaining = static_cast(end - cur); + if (remaining == p->frame_size) { + s->stats.incoming.data_bytes += remaining; + if (GRPC_ERROR_NONE != + (error = p->parsing_frame->Push( + grpc_slice_sub(*slice, static_cast(cur - beg), + static_cast(end - beg)), + slice_out))) { + grpc_slice_buffer_remove_first(slices); + return error; + } + if (GRPC_ERROR_NONE != + (error = p->parsing_frame->Finished(GRPC_ERROR_NONE, true))) { + grpc_slice_buffer_remove_first(slices); + return error; + } + p->parsing_frame = nullptr; + p->state = GRPC_CHTTP2_DATA_FH_0; + grpc_slice_buffer_remove_first(slices); + return GRPC_ERROR_NONE; + } else if (remaining < p->frame_size) { + s->stats.incoming.data_bytes += remaining; + if (GRPC_ERROR_NONE != + (error = p->parsing_frame->Push( + grpc_slice_sub(*slice, static_cast(cur - beg), + static_cast(end - beg)), + slice_out))) { + return error; + } + p->frame_size -= remaining; + grpc_slice_buffer_remove_first(slices); + return GRPC_ERROR_NONE; + } else { + GPR_ASSERT(remaining > p->frame_size); + s->stats.incoming.data_bytes += p->frame_size; + if (GRPC_ERROR_NONE != + p->parsing_frame->Push( + grpc_slice_sub( + *slice, static_cast(cur - beg), + static_cast(cur + p->frame_size - beg)), + slice_out)) { + grpc_slice_buffer_remove_first(slices); + return error; + } + if (GRPC_ERROR_NONE != + (error = p->parsing_frame->Finished(GRPC_ERROR_NONE, true))) { + grpc_slice_buffer_remove_first(slices); + return error; + } + p->parsing_frame = nullptr; + p->state = GRPC_CHTTP2_DATA_FH_0; + cur += p->frame_size; + grpc_slice_buffer_sub_first(slices, static_cast(cur - beg), + static_cast(end - beg)); + return GRPC_ERROR_NONE; + } + } + } + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_chttp2_data_parser_parse(void* /*parser*/, + grpc_chttp2_transport* t, + grpc_chttp2_stream* s, + const grpc_slice& slice, + int is_last) { + if (!s->pending_byte_stream) { + grpc_slice_ref_internal(slice); + grpc_slice_buffer_add(&s->frame_storage, slice); + grpc_chttp2_maybe_complete_recv_message(t, s); + } else if (s->on_next) { + GPR_ASSERT(s->frame_storage.length == 0); + grpc_slice_ref_internal(slice); + grpc_slice_buffer_add(&s->unprocessed_incoming_frames_buffer, slice); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, s->on_next, GRPC_ERROR_NONE); + s->on_next = nullptr; + s->unprocessed_incoming_frames_decompressed = false; + } else { + grpc_slice_ref_internal(slice); + grpc_slice_buffer_add(&s->frame_storage, slice); + } + + if (is_last && s->received_last_frame) { + grpc_chttp2_mark_stream_closed( + t, s, true, false, + t->is_client ? 
GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Data frame with END_STREAM flag received") + : GRPC_ERROR_NONE); + } + + return GRPC_ERROR_NONE; +} diff --git a/src/core/ext/transport/chttp2/transport/frame_goaway.cc b/src/core/ext/transport/chttp2/transport/frame_goaway.cc new file mode 100644 index 00000000..eb90e2e9 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/frame_goaway.cc @@ -0,0 +1,187 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/frame_goaway.h" + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/internal.h" + +void grpc_chttp2_goaway_parser_init(grpc_chttp2_goaway_parser* p) { + p->debug_data = nullptr; +} + +void grpc_chttp2_goaway_parser_destroy(grpc_chttp2_goaway_parser* p) { + gpr_free(p->debug_data); +} + +grpc_error_handle grpc_chttp2_goaway_parser_begin_frame( + grpc_chttp2_goaway_parser* p, uint32_t length, uint8_t /*flags*/) { + if (length < 8) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("goaway frame too short (%d bytes)", length)); + } + + gpr_free(p->debug_data); + p->debug_length = length - 8; + p->debug_data = static_cast(gpr_malloc(p->debug_length)); + p->debug_pos = 0; + p->state = GRPC_CHTTP2_GOAWAY_LSI0; + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_chttp2_goaway_parser_parse(void* parser, + grpc_chttp2_transport* t, + grpc_chttp2_stream* /*s*/, + const grpc_slice& slice, + int is_last) { + const uint8_t* const beg = GRPC_SLICE_START_PTR(slice); + const uint8_t* const end = GRPC_SLICE_END_PTR(slice); + const uint8_t* cur = beg; + grpc_chttp2_goaway_parser* p = + static_cast(parser); + + switch (p->state) { + case GRPC_CHTTP2_GOAWAY_LSI0: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_LSI0; + return GRPC_ERROR_NONE; + } + p->last_stream_id = (static_cast(*cur)) << 24; + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_GOAWAY_LSI1: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_LSI1; + return GRPC_ERROR_NONE; + } + p->last_stream_id |= (static_cast(*cur)) << 16; + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_GOAWAY_LSI2: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_LSI2; + return GRPC_ERROR_NONE; + } + p->last_stream_id |= (static_cast(*cur)) << 8; + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_GOAWAY_LSI3: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_LSI3; + return GRPC_ERROR_NONE; + } + p->last_stream_id |= (static_cast(*cur)); + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_GOAWAY_ERR0: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_ERR0; + return GRPC_ERROR_NONE; + } + p->error_code = (static_cast(*cur)) << 24; + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_GOAWAY_ERR1: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_ERR1; + return GRPC_ERROR_NONE; + } + p->error_code |= (static_cast(*cur)) << 16; + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case 
GRPC_CHTTP2_GOAWAY_ERR2: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_ERR2; + return GRPC_ERROR_NONE; + } + p->error_code |= (static_cast(*cur)) << 8; + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_GOAWAY_ERR3: + if (cur == end) { + p->state = GRPC_CHTTP2_GOAWAY_ERR3; + return GRPC_ERROR_NONE; + } + p->error_code |= (static_cast(*cur)); + ++cur; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_GOAWAY_DEBUG: + if (end != cur) { + memcpy(p->debug_data + p->debug_pos, cur, + static_cast(end - cur)); + } + GPR_ASSERT((size_t)(end - cur) < UINT32_MAX - p->debug_pos); + p->debug_pos += static_cast(end - cur); + p->state = GRPC_CHTTP2_GOAWAY_DEBUG; + if (is_last) { + grpc_chttp2_add_incoming_goaway( + t, p->error_code, p->last_stream_id, + absl::string_view(p->debug_data, p->debug_length)); + gpr_free(p->debug_data); + p->debug_data = nullptr; + } + return GRPC_ERROR_NONE; + } + GPR_UNREACHABLE_CODE( + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Should never reach here")); +} + +void grpc_chttp2_goaway_append(uint32_t last_stream_id, uint32_t error_code, + const grpc_slice& debug_data, + grpc_slice_buffer* slice_buffer) { + grpc_slice header = GRPC_SLICE_MALLOC(9 + 4 + 4); + uint8_t* p = GRPC_SLICE_START_PTR(header); + uint32_t frame_length; + GPR_ASSERT(GRPC_SLICE_LENGTH(debug_data) < UINT32_MAX - 4 - 4); + frame_length = 4 + 4 + static_cast GRPC_SLICE_LENGTH(debug_data); + + /* frame header: length */ + *p++ = static_cast(frame_length >> 16); + *p++ = static_cast(frame_length >> 8); + *p++ = static_cast(frame_length); + /* frame header: type */ + *p++ = GRPC_CHTTP2_FRAME_GOAWAY; + /* frame header: flags */ + *p++ = 0; + /* frame header: stream id */ + *p++ = 0; + *p++ = 0; + *p++ = 0; + *p++ = 0; + /* payload: last stream id */ + *p++ = static_cast(last_stream_id >> 24); + *p++ = static_cast(last_stream_id >> 16); + *p++ = static_cast(last_stream_id >> 8); + *p++ = static_cast(last_stream_id); + /* payload: error code */ + *p++ = static_cast(error_code >> 24); + *p++ = static_cast(error_code >> 16); + *p++ = static_cast(error_code >> 8); + *p++ = static_cast(error_code); + GPR_ASSERT(p == GRPC_SLICE_END_PTR(header)); + grpc_slice_buffer_add(slice_buffer, header); + grpc_slice_buffer_add(slice_buffer, debug_data); +} diff --git a/src/core/ext/transport/chttp2/transport/frame_ping.cc b/src/core/ext/transport/chttp2/transport/frame_ping.cc new file mode 100644 index 00000000..3ee99eba --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/frame_ping.cc @@ -0,0 +1,132 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/frame_ping.h" + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/internal.h" + +static bool g_disable_ping_ack = false; + +grpc_slice grpc_chttp2_ping_create(uint8_t ack, uint64_t opaque_8bytes) { + grpc_slice slice = GRPC_SLICE_MALLOC(9 + 8); + uint8_t* p = GRPC_SLICE_START_PTR(slice); + + *p++ = 0; + *p++ = 0; + *p++ = 8; + *p++ = GRPC_CHTTP2_FRAME_PING; + *p++ = ack ? 1 : 0; + *p++ = 0; + *p++ = 0; + *p++ = 0; + *p++ = 0; + *p++ = static_cast(opaque_8bytes >> 56); + *p++ = static_cast(opaque_8bytes >> 48); + *p++ = static_cast(opaque_8bytes >> 40); + *p++ = static_cast(opaque_8bytes >> 32); + *p++ = static_cast(opaque_8bytes >> 24); + *p++ = static_cast(opaque_8bytes >> 16); + *p++ = static_cast(opaque_8bytes >> 8); + *p++ = static_cast(opaque_8bytes); + + return slice; +} + +grpc_error_handle grpc_chttp2_ping_parser_begin_frame( + grpc_chttp2_ping_parser* parser, uint32_t length, uint8_t flags) { + if (flags & 0xfe || length != 8) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("invalid ping: length=%d, flags=%02x", length, flags)); + } + parser->byte = 0; + parser->is_ack = flags; + parser->opaque_8bytes = 0; + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_chttp2_ping_parser_parse(void* parser, + grpc_chttp2_transport* t, + grpc_chttp2_stream* /*s*/, + const grpc_slice& slice, + int is_last) { + const uint8_t* const beg = GRPC_SLICE_START_PTR(slice); + const uint8_t* const end = GRPC_SLICE_END_PTR(slice); + const uint8_t* cur = beg; + grpc_chttp2_ping_parser* p = static_cast(parser); + + while (p->byte != 8 && cur != end) { + p->opaque_8bytes |= ((static_cast(*cur)) << (56 - 8 * p->byte)); + cur++; + p->byte++; + } + + if (p->byte == 8) { + GPR_ASSERT(is_last); + if (p->is_ack) { + grpc_chttp2_ack_ping(t, p->opaque_8bytes); + } else { + if (!t->is_client) { + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + grpc_millis next_allowed_ping = + t->ping_recv_state.last_ping_recv_time + + t->ping_policy.min_recv_ping_interval_without_data; + + if (t->keepalive_permit_without_calls == 0 && + grpc_chttp2_stream_map_size(&t->stream_map) == 0) { + /* According to RFC1122, the interval of TCP Keep-Alive is default to + no less than two hours. When there is no outstanding streams, we + restrict the number of PINGS equivalent to TCP Keep-Alive. 
*/ + next_allowed_ping = + t->ping_recv_state.last_ping_recv_time + 7200 * GPR_MS_PER_SEC; + } + + if (next_allowed_ping > now) { + grpc_chttp2_add_ping_strike(t); + } + + t->ping_recv_state.last_ping_recv_time = now; + } + if (!g_disable_ping_ack) { + if (t->ping_ack_count == t->ping_ack_capacity) { + t->ping_ack_capacity = + std::max(t->ping_ack_capacity * 3 / 2, size_t(3)); + t->ping_acks = static_cast(gpr_realloc( + t->ping_acks, t->ping_ack_capacity * sizeof(*t->ping_acks))); + } + t->num_pending_induced_frames++; + t->ping_acks[t->ping_ack_count++] = p->opaque_8bytes; + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_PING_RESPONSE); + } + } + } + + return GRPC_ERROR_NONE; +} + +void grpc_set_disable_ping_ack(bool disable_ping_ack) { + g_disable_ping_ack = disable_ping_ack; +} diff --git a/src/core/ext/transport/chttp2/transport/frame_rst_stream.cc b/src/core/ext/transport/chttp2/transport/frame_rst_stream.cc new file mode 100644 index 00000000..bd92557e --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/frame_rst_stream.cc @@ -0,0 +1,118 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/frame_rst_stream.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/frame.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/transport/http2_errors.h" + +grpc_slice grpc_chttp2_rst_stream_create(uint32_t id, uint32_t code, + grpc_transport_one_way_stats* stats) { + static const size_t frame_size = 13; + grpc_slice slice = GRPC_SLICE_MALLOC(frame_size); + if (stats != nullptr) stats->framing_bytes += frame_size; + uint8_t* p = GRPC_SLICE_START_PTR(slice); + + // Frame size. + *p++ = 0; + *p++ = 0; + *p++ = 4; + // Frame type. + *p++ = GRPC_CHTTP2_FRAME_RST_STREAM; + // Flags. + *p++ = 0; + // Stream ID. + *p++ = static_cast(id >> 24); + *p++ = static_cast(id >> 16); + *p++ = static_cast(id >> 8); + *p++ = static_cast(id); + // Error code. 
+ *p++ = static_cast(code >> 24); + *p++ = static_cast(code >> 16); + *p++ = static_cast(code >> 8); + *p++ = static_cast(code); + + return slice; +} + +void grpc_chttp2_add_rst_stream_to_next_write( + grpc_chttp2_transport* t, uint32_t id, uint32_t code, + grpc_transport_one_way_stats* stats) { + t->num_pending_induced_frames++; + grpc_slice_buffer_add(&t->qbuf, + grpc_chttp2_rst_stream_create(id, code, stats)); +} + +grpc_error_handle grpc_chttp2_rst_stream_parser_begin_frame( + grpc_chttp2_rst_stream_parser* parser, uint32_t length, uint8_t flags) { + if (length != 4) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "invalid rst_stream: length=%d, flags=%02x", length, flags)); + } + parser->byte = 0; + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_chttp2_rst_stream_parser_parse(void* parser, + grpc_chttp2_transport* t, + grpc_chttp2_stream* s, + const grpc_slice& slice, + int is_last) { + const uint8_t* const beg = GRPC_SLICE_START_PTR(slice); + const uint8_t* const end = GRPC_SLICE_END_PTR(slice); + const uint8_t* cur = beg; + grpc_chttp2_rst_stream_parser* p = + static_cast(parser); + + while (p->byte != 4 && cur != end) { + p->reason_bytes[p->byte] = *cur; + cur++; + p->byte++; + } + s->stats.incoming.framing_bytes += static_cast(end - cur); + + if (p->byte == 4) { + GPR_ASSERT(is_last); + uint32_t reason = ((static_cast(p->reason_bytes[0])) << 24) | + ((static_cast(p->reason_bytes[1])) << 16) | + ((static_cast(p->reason_bytes[2])) << 8) | + ((static_cast(p->reason_bytes[3]))); + grpc_error_handle error = GRPC_ERROR_NONE; + if (reason != GRPC_HTTP2_NO_ERROR || s->trailing_metadata_buffer.empty()) { + error = grpc_error_set_int( + grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("RST_STREAM"), + GRPC_ERROR_STR_GRPC_MESSAGE, + absl::StrCat("Received RST_STREAM with error code ", reason)), + GRPC_ERROR_INT_HTTP2_ERROR, static_cast(reason)); + } + grpc_chttp2_mark_stream_closed(t, s, true, true, error); + } + + return GRPC_ERROR_NONE; +} diff --git a/src/core/ext/transport/chttp2/transport/frame_settings.cc b/src/core/ext/transport/chttp2/transport/frame_settings.cc new file mode 100644 index 00000000..c488f5ab --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/frame_settings.cc @@ -0,0 +1,273 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/frame_settings.h" + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/ext/transport/chttp2/transport/frame.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/transport/http2_errors.h" + +static uint8_t* fill_header(uint8_t* out, uint32_t length, uint8_t flags) { + *out++ = static_cast(length >> 16); + *out++ = static_cast(length >> 8); + *out++ = static_cast(length); + *out++ = GRPC_CHTTP2_FRAME_SETTINGS; + *out++ = flags; + *out++ = 0; + *out++ = 0; + *out++ = 0; + *out++ = 0; + return out; +} + +grpc_slice grpc_chttp2_settings_create(uint32_t* old_settings, + const uint32_t* new_settings, + uint32_t force_mask, size_t count) { + size_t i; + uint32_t n = 0; + grpc_slice output; + uint8_t* p; + + for (i = 0; i < count; i++) { + n += (new_settings[i] != old_settings[i] || (force_mask & (1u << i)) != 0); + } + + output = GRPC_SLICE_MALLOC(9 + 6 * n); + p = fill_header(GRPC_SLICE_START_PTR(output), 6 * n, 0); + + for (i = 0; i < count; i++) { + if (new_settings[i] != old_settings[i] || (force_mask & (1u << i)) != 0) { + *p++ = static_cast(grpc_setting_id_to_wire_id[i] >> 8); + *p++ = static_cast(grpc_setting_id_to_wire_id[i]); + *p++ = static_cast(new_settings[i] >> 24); + *p++ = static_cast(new_settings[i] >> 16); + *p++ = static_cast(new_settings[i] >> 8); + *p++ = static_cast(new_settings[i]); + old_settings[i] = new_settings[i]; + } + } + + GPR_ASSERT(p == GRPC_SLICE_END_PTR(output)); + + return output; +} + +grpc_slice grpc_chttp2_settings_ack_create(void) { + grpc_slice output = GRPC_SLICE_MALLOC(9); + fill_header(GRPC_SLICE_START_PTR(output), 0, GRPC_CHTTP2_FLAG_ACK); + return output; +} + +grpc_error_handle grpc_chttp2_settings_parser_begin_frame( + grpc_chttp2_settings_parser* parser, uint32_t length, uint8_t flags, + uint32_t* settings) { + parser->target_settings = settings; + memcpy(parser->incoming_settings, settings, + GRPC_CHTTP2_NUM_SETTINGS * sizeof(uint32_t)); + parser->is_ack = 0; + parser->state = GRPC_CHTTP2_SPS_ID0; + if (flags == GRPC_CHTTP2_FLAG_ACK) { + parser->is_ack = 1; + if (length != 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "non-empty settings ack frame received"); + } + return GRPC_ERROR_NONE; + } else if (flags != 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "invalid flags on settings frame"); + } else if (length % 6 != 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "settings frames must be a multiple of six bytes"); + } else { + return GRPC_ERROR_NONE; + } +} + +namespace { + +void StreamFlowControlWindowCheck(void* user_data, uint32_t /* key */, + void* stream) { + bool* error = static_cast(user_data); + grpc_chttp2_stream* s = static_cast(stream); + if ((s->t->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE] + + s->t->initial_window_update + s->flow_control->remote_window_delta()) > + ((1u << 31) - 1)) { + *error = true; + } +} + +} // namespace + +grpc_error_handle grpc_chttp2_settings_parser_parse(void* p, + grpc_chttp2_transport* t, + grpc_chttp2_stream* /*s*/, + const grpc_slice& slice, + int is_last) { + grpc_chttp2_settings_parser* parser = + static_cast(p); + const uint8_t* cur = GRPC_SLICE_START_PTR(slice); + const uint8_t* end = GRPC_SLICE_END_PTR(slice); + grpc_chttp2_setting_id id; + + if (parser->is_ack) { + return GRPC_ERROR_NONE; 
+ } + + for (;;) { + switch (parser->state) { + case GRPC_CHTTP2_SPS_ID0: + if (cur == end) { + parser->state = GRPC_CHTTP2_SPS_ID0; + if (is_last) { + memcpy(parser->target_settings, parser->incoming_settings, + GRPC_CHTTP2_NUM_SETTINGS * sizeof(uint32_t)); + t->num_pending_induced_frames++; + grpc_slice_buffer_add(&t->qbuf, grpc_chttp2_settings_ack_create()); + if (t->notify_on_receive_settings != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + t->notify_on_receive_settings, + GRPC_ERROR_NONE); + t->notify_on_receive_settings = nullptr; + } + } + return GRPC_ERROR_NONE; + } + parser->id = static_cast((static_cast(*cur)) << 8); + cur++; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_SPS_ID1: + if (cur == end) { + parser->state = GRPC_CHTTP2_SPS_ID1; + return GRPC_ERROR_NONE; + } + parser->id = static_cast(parser->id | (*cur)); + cur++; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_SPS_VAL0: + if (cur == end) { + parser->state = GRPC_CHTTP2_SPS_VAL0; + return GRPC_ERROR_NONE; + } + parser->value = (static_cast(*cur)) << 24; + cur++; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_SPS_VAL1: + if (cur == end) { + parser->state = GRPC_CHTTP2_SPS_VAL1; + return GRPC_ERROR_NONE; + } + parser->value |= (static_cast(*cur)) << 16; + cur++; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_SPS_VAL2: + if (cur == end) { + parser->state = GRPC_CHTTP2_SPS_VAL2; + return GRPC_ERROR_NONE; + } + parser->value |= (static_cast(*cur)) << 8; + cur++; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHTTP2_SPS_VAL3: + if (cur == end) { + parser->state = GRPC_CHTTP2_SPS_VAL3; + return GRPC_ERROR_NONE; + } else { + parser->state = GRPC_CHTTP2_SPS_ID0; + } + parser->value |= *cur; + cur++; + + if (grpc_wire_id_to_setting_id(parser->id, &id)) { + const grpc_chttp2_setting_parameters* sp = + &grpc_chttp2_settings_parameters[id]; + // If flow control is disabled we skip these. + if (!t->flow_control->flow_control_enabled() && + (id == GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE || + id == GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE)) { + continue; + } + if (parser->value < sp->min_value || parser->value > sp->max_value) { + switch (sp->invalid_value_behavior) { + case GRPC_CHTTP2_CLAMP_INVALID_VALUE: + parser->value = grpc_core::Clamp(parser->value, sp->min_value, + sp->max_value); + break; + case GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE: + grpc_chttp2_goaway_append( + t->last_new_stream_id, sp->error_value, + grpc_slice_from_static_string("HTTP2 settings error"), + &t->qbuf); + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "invalid value %u passed for %s", parser->value, sp->name)); + } + } + if (id == GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE && + parser->incoming_settings[id] != parser->value) { + t->initial_window_update += static_cast(parser->value) - + parser->incoming_settings[id]; + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_flowctl_trace)) { + gpr_log(GPR_INFO, "%p[%s] adding %d for initial_window change", t, + t->is_client ? 
"cli" : "svr", + static_cast(t->initial_window_update)); + } + if (grpc_core::chttp2:: + g_test_only_transport_flow_control_window_check) { + bool error = false; + if (parser->value > grpc_core::chttp2::kMaxInitialWindowSize || + parser->value < grpc_core::chttp2::kMinInitialWindowSize) { + error = true; + } else { + grpc_chttp2_stream_map_for_each( + &t->stream_map, StreamFlowControlWindowCheck, &error); + } + if (error) { + grpc_chttp2_goaway_append( + t->last_new_stream_id, sp->error_value, + grpc_slice_from_static_string("HTTP2 settings error"), + &t->qbuf); + } + } + } + parser->incoming_settings[id] = parser->value; + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "CHTTP2:%s:%s: got setting %s = %d", + t->is_client ? "CLI" : "SVR", t->peer_string.c_str(), + sp->name, parser->value); + } + } else if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_ERROR, "CHTTP2: Ignoring unknown setting %d (value %d)", + parser->id, parser->value); + } + break; + } + } +} diff --git a/src/core/ext/transport/chttp2/transport/frame_window_update.cc b/src/core/ext/transport/chttp2/transport/frame_window_update.cc new file mode 100644 index 00000000..9763e810 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/frame_window_update.cc @@ -0,0 +1,123 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/frame_window_update.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/internal.h" + +grpc_slice grpc_chttp2_window_update_create( + uint32_t id, uint32_t window_delta, grpc_transport_one_way_stats* stats) { + static const size_t frame_size = 13; + grpc_slice slice = GRPC_SLICE_MALLOC(frame_size); + stats->header_bytes += frame_size; + uint8_t* p = GRPC_SLICE_START_PTR(slice); + + GPR_ASSERT(window_delta); + + *p++ = 0; + *p++ = 0; + *p++ = 4; + *p++ = GRPC_CHTTP2_FRAME_WINDOW_UPDATE; + *p++ = 0; + *p++ = static_cast(id >> 24); + *p++ = static_cast(id >> 16); + *p++ = static_cast(id >> 8); + *p++ = static_cast(id); + *p++ = static_cast(window_delta >> 24); + *p++ = static_cast(window_delta >> 16); + *p++ = static_cast(window_delta >> 8); + *p++ = static_cast(window_delta); + + return slice; +} + +grpc_error_handle grpc_chttp2_window_update_parser_begin_frame( + grpc_chttp2_window_update_parser* parser, uint32_t length, uint8_t flags) { + if (flags || length != 4) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "invalid window update: length=%d, flags=%02x", length, flags)); + } + parser->byte = 0; + parser->amount = 0; + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_chttp2_window_update_parser_parse( + void* parser, grpc_chttp2_transport* t, grpc_chttp2_stream* s, + const grpc_slice& slice, int is_last) { + const uint8_t* const beg = GRPC_SLICE_START_PTR(slice); + const uint8_t* const end = GRPC_SLICE_END_PTR(slice); + const uint8_t* cur = beg; + grpc_chttp2_window_update_parser* p = + static_cast(parser); + + while (p->byte != 4 && cur != end) { + p->amount |= (static_cast(*cur)) << (8 * (3 - p->byte)); + cur++; + p->byte++; + } + + if (s != nullptr) { + s->stats.incoming.framing_bytes += static_cast(end - cur); + } + + if (p->byte == 4) { + // top bit is reserved and must be ignored. + uint32_t received_update = p->amount & 0x7fffffffu; + if (received_update == 0) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("invalid window update bytes: ", p->amount)); + } + GPR_ASSERT(is_last); + + if (t->incoming_stream_id != 0) { + if (s != nullptr) { + s->flow_control->RecvUpdate(received_update); + if (grpc_core::chttp2:: + g_test_only_transport_flow_control_window_check && + s->flow_control->remote_window_delta() > + grpc_core::chttp2::kMaxWindowDelta) { + GPR_ASSERT(false); + } + if (grpc_chttp2_list_remove_stalled_by_stream(t, s)) { + grpc_chttp2_mark_stream_writable(t, s); + grpc_chttp2_initiate_write( + t, GRPC_CHTTP2_INITIATE_WRITE_FLOW_CONTROL_UNSTALLED_BY_UPDATE); + } + } + } else { + bool was_zero = t->flow_control->remote_window() <= 0; + t->flow_control->RecvUpdate(received_update); + bool is_zero = t->flow_control->remote_window() <= 0; + if (was_zero && !is_zero) { + grpc_chttp2_initiate_write( + t, GRPC_CHTTP2_INITIATE_WRITE_TRANSPORT_FLOW_CONTROL_UNSTALLED); + } + } + } + + return GRPC_ERROR_NONE; +} diff --git a/src/core/ext/transport/chttp2/transport/hpack_encoder.cc b/src/core/ext/transport/chttp2/transport/hpack_encoder.cc new file mode 100644 index 00000000..5dfd11d9 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/hpack_encoder.cc @@ -0,0 +1,546 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" + +#include +#include + +/* This is here for grpc_is_binary_header + * TODO(murgatroid99): Remove this + */ +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/bin_encoder.h" +#include "src/core/ext/transport/chttp2/transport/hpack_utils.h" +#include "src/core/ext/transport/chttp2/transport/varint.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/validate_metadata.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/timeout_encoding.h" + +namespace grpc_core { + +namespace { + +/* don't consider adding anything bigger than this to the hpack table */ +constexpr size_t kMaxDecoderSpaceUsage = 512; +constexpr size_t kDataFrameHeaderSize = 9; + +} /* namespace */ + +/* fills p (which is expected to be kDataFrameHeaderSize bytes long) + * with a data frame header */ +static void FillHeader(uint8_t* p, uint8_t type, uint32_t id, size_t len, + uint8_t flags) { + /* len is the current frame size (i.e. for the frame we're finishing). + We finish a frame if: + 1) We called ensure_space(), (i.e. add_tiny_header_data()) and adding + 'need_bytes' to the frame would cause us to exceed max_frame_size. + 2) We called add_header_data, and adding the slice would cause us to exceed + max_frame_size. + 3) We're done encoding the header. + + Thus, len is always <= max_frame_size. + max_frame_size is derived from GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE, + which has a max allowable value of 16777215 (see chttp_transport.cc). + Thus, the following assert can be a debug assert. */ + GPR_DEBUG_ASSERT(len < 16777316); + *p++ = static_cast(len >> 16); + *p++ = static_cast(len >> 8); + *p++ = static_cast(len); + *p++ = type; + *p++ = flags; + *p++ = static_cast(id >> 24); + *p++ = static_cast(id >> 16); + *p++ = static_cast(id >> 8); + *p++ = static_cast(id); +} + +size_t HPackCompressor::Framer::CurrentFrameSize() const { + const size_t frame_size = + output_->length - prefix_.output_length_at_start_of_frame; + GPR_DEBUG_ASSERT(frame_size <= max_frame_size_); + return frame_size; +} + +// finish a frame - fill in the previously reserved header +void HPackCompressor::Framer::FinishFrame(bool is_header_boundary) { + const uint8_t type = is_first_frame_ ? GRPC_CHTTP2_FRAME_HEADER + : GRPC_CHTTP2_FRAME_CONTINUATION; + uint8_t flags = 0; + // per the HTTP/2 spec: + // A HEADERS frame carries the END_STREAM flag that signals the end of a + // stream. However, a HEADERS frame with the END_STREAM flag set can be + // followed by CONTINUATION frames on the same stream. Logically, the + // CONTINUATION frames are part of the HEADERS frame. + // Thus, we add the END_STREAM flag to the HEADER frame (the first frame). 
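// The two spec rules quoted above reduce to a small decision table. A
// minimal standalone sketch, written against the raw HTTP/2 wire values
// (frame types 0x1 HEADERS / 0x9 CONTINUATION, flags 0x1 END_STREAM /
// 0x4 END_HEADERS); the struct and function names are illustrative only.
#include <cstdint>

struct SketchHeaderFrameKind {
  uint8_t type;
  uint8_t flags;
};

inline SketchHeaderFrameKind SketchPickHeaderFrameKind(bool is_first_frame,
                                                       bool is_end_of_stream,
                                                       bool ends_header_block) {
  SketchHeaderFrameKind kind;
  // Only the first frame of a header block may be HEADERS; the rest are
  // CONTINUATION frames on the same stream.
  kind.type = is_first_frame ? 0x1 : 0x9;
  kind.flags = 0;
  // END_STREAM belongs on the HEADERS frame even if CONTINUATIONs follow.
  if (is_first_frame && is_end_of_stream) kind.flags |= 0x1;
  // END_HEADERS goes on whichever frame finishes the header block.
  if (ends_header_block) kind.flags |= 0x4;
  return kind;
}
// For example, a trailing-metadata block that fits in a single frame comes
// out as type HEADERS with both END_STREAM and END_HEADERS set.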
+ if (is_first_frame_ && is_end_of_stream_) { + flags |= GRPC_CHTTP2_DATA_FLAG_END_STREAM; + } + // per the HTTP/2 spec: + // A HEADERS frame without the END_HEADERS flag set MUST be followed by + // a CONTINUATION frame for the same stream. + // Thus, we add the END_HEADER flag to the last frame. + if (is_header_boundary) { + flags |= GRPC_CHTTP2_DATA_FLAG_END_HEADERS; + } + FillHeader(GRPC_SLICE_START_PTR(output_->slices[prefix_.header_idx]), type, + stream_id_, CurrentFrameSize(), flags); + stats_->framing_bytes += kDataFrameHeaderSize; + is_first_frame_ = false; +} + +// begin a new frame: reserve off header space, remember how many bytes we'd +// output before beginning +HPackCompressor::Framer::FramePrefix HPackCompressor::Framer::BeginFrame() { + grpc_slice reserved; + reserved.refcount = nullptr; + reserved.data.inlined.length = kDataFrameHeaderSize; + return FramePrefix{grpc_slice_buffer_add_indexed(output_, reserved), + output_->length}; +} + +// make sure that the current frame is of the type desired, and has sufficient +// space to add at least about_to_add bytes -- finishes the current frame if +// needed +void HPackCompressor::Framer::EnsureSpace(size_t need_bytes) { + if (GPR_LIKELY(CurrentFrameSize() + need_bytes <= max_frame_size_)) { + return; + } + FinishFrame(false); + prefix_ = BeginFrame(); +} + +void HPackCompressor::Framer::Add(grpc_slice slice) { + const size_t len = GRPC_SLICE_LENGTH(slice); + if (len == 0) return; + const size_t remaining = max_frame_size_ - CurrentFrameSize(); + if (len <= remaining) { + stats_->header_bytes += len; + grpc_slice_buffer_add(output_, slice); + } else { + stats_->header_bytes += remaining; + grpc_slice_buffer_add(output_, grpc_slice_split_head(&slice, remaining)); + FinishFrame(false); + prefix_ = BeginFrame(); + Add(slice); + } +} + +uint8_t* HPackCompressor::Framer::AddTiny(size_t len) { + EnsureSpace(len); + stats_->header_bytes += len; + return grpc_slice_buffer_tiny_add(output_, len); +} + +// Add a key to the dynamic table. Both key and value will be added to table at +// the decoder. 
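// The helpers below do not store HPACK wire indices directly: new_index is
// an "absolute" insertion count handed out by the encoder table, and the
// index emitted on the wire is derived from it when the entry is referenced
// later. As a reminder of that mapping, a minimal standalone sketch assuming
// the standard RFC 7541 numbering, where the 61 static-table entries occupy
// indices 1..61 and the most recently inserted dynamic entry gets index 62;
// the function name is illustrative only.
#include <cstdint>

inline uint32_t SketchHpackWireIndex(uint32_t newest_absolute_index,
                                     uint32_t entry_absolute_index) {
  constexpr uint32_t kStaticTableEntries = 61;
  // Newest entry => 62; each older entry sits one index further out.
  return kStaticTableEntries + 1 +
         (newest_absolute_index - entry_absolute_index);
}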
+void HPackCompressor::AddKeyWithIndex(grpc_slice_refcount* key_ref, + uint32_t new_index, uint32_t key_hash) { + key_index_.Insert(KeySliceRef(key_ref, key_hash), new_index); +} + +/* add an element to the decoder table */ +void HPackCompressor::AddElemWithIndex(grpc_mdelem elem, uint32_t new_index, + uint32_t elem_hash, uint32_t key_hash) { + GPR_DEBUG_ASSERT(GRPC_MDELEM_IS_INTERNED(elem)); + elem_index_.Insert(KeyElem(elem, elem_hash), new_index); + AddKeyWithIndex(GRPC_MDKEY(elem).refcount, new_index, key_hash); +} + +void HPackCompressor::AddElem(grpc_mdelem elem, size_t elem_size, + uint32_t elem_hash, uint32_t key_hash) { + uint32_t new_index = table_.AllocateIndex(elem_size); + if (new_index != 0) { + AddElemWithIndex(elem, new_index, elem_hash, key_hash); + } +} + +void HPackCompressor::AddKey(grpc_mdelem elem, size_t elem_size, + uint32_t key_hash) { + uint32_t new_index = table_.AllocateIndex(elem_size); + if (new_index != 0) { + AddKeyWithIndex(GRPC_MDKEY(elem).refcount, new_index, key_hash); + } +} + +void HPackCompressor::Framer::EmitIndexed(uint32_t elem_index) { + GRPC_STATS_INC_HPACK_SEND_INDEXED(); + VarintWriter<1> w(elem_index); + w.Write(0x80, AddTiny(w.length())); +} + +struct WireValue { + WireValue(uint8_t huffman_prefix, bool insert_null_before_wire_value, + const grpc_slice& slice) + : data(slice), + huffman_prefix(huffman_prefix), + insert_null_before_wire_value(insert_null_before_wire_value), + length(GRPC_SLICE_LENGTH(slice) + + (insert_null_before_wire_value ? 1 : 0)) {} + // While wire_value is const from the POV of hpack encoder code, actually + // adding it to a slice buffer will possibly split the slice. + const grpc_slice data; + const uint8_t huffman_prefix; + const bool insert_null_before_wire_value; + const size_t length; +}; + +static WireValue GetWireValue(const grpc_slice& value, bool true_binary_enabled, + bool is_bin_hdr) { + if (is_bin_hdr) { + if (true_binary_enabled) { + GRPC_STATS_INC_HPACK_SEND_BINARY(); + return WireValue(0x00, true, grpc_slice_ref_internal(value)); + } else { + GRPC_STATS_INC_HPACK_SEND_BINARY_BASE64(); + return WireValue(0x80, false, + grpc_chttp2_base64_encode_and_huffman_compress(value)); + } + } else { + /* TODO(ctiller): opportunistically compress non-binary headers */ + GRPC_STATS_INC_HPACK_SEND_UNCOMPRESSED(); + return WireValue(0x00, false, grpc_slice_ref_internal(value)); + } +} + +struct DefinitelyInterned { + static bool IsBinary(grpc_slice key) { + return grpc_is_refcounted_slice_binary_header(key); + } +}; +struct UnsureIfInterned { + static bool IsBinary(grpc_slice key) { + return grpc_is_binary_header_internal(key); + } +}; + +class StringValue { + public: + template + StringValue(MetadataKeyType, grpc_mdelem elem, bool use_true_binary_metadata) + : wire_value_(GetWireValue(GRPC_MDVALUE(elem), use_true_binary_metadata, + MetadataKeyType::IsBinary(GRPC_MDKEY(elem)))), + len_val_(wire_value_.length) {} + + size_t prefix_length() const { + return len_val_.length() + + (wire_value_.insert_null_before_wire_value ? 
1 : 0); + } + + void WritePrefix(uint8_t* prefix_data) { + len_val_.Write(wire_value_.huffman_prefix, prefix_data); + if (wire_value_.insert_null_before_wire_value) { + prefix_data[len_val_.length()] = 0; + } + } + + const grpc_slice& data() { return wire_value_.data; } + + private: + WireValue wire_value_; + VarintWriter<1> len_val_; +}; + +class NonBinaryStringValue { + public: + explicit NonBinaryStringValue(const grpc_slice& value) + : value_(value), len_val_(GRPC_SLICE_LENGTH(value)) {} + + size_t prefix_length() const { return len_val_.length(); } + + void WritePrefix(uint8_t* prefix_data) { len_val_.Write(0x00, prefix_data); } + + const grpc_slice& data() { return value_; } + + private: + grpc_slice value_; + VarintWriter<1> len_val_; +}; + +class StringKey { + public: + explicit StringKey(grpc_slice key) + : key_(key), len_key_(GRPC_SLICE_LENGTH(key)) {} + + size_t prefix_length() const { return 1 + len_key_.length(); } + + void WritePrefix(uint8_t type, uint8_t* data) { + data[0] = type; + len_key_.Write(0x00, data + 1); + } + + grpc_slice key() const { return key_; } + + private: + grpc_slice key_; + VarintWriter<1> len_key_; +}; + +void HPackCompressor::Framer::EmitLitHdrIncIdx(uint32_t key_index, + grpc_mdelem elem) { + GRPC_STATS_INC_HPACK_SEND_LITHDR_INCIDX(); + StringValue emit(DefinitelyInterned(), elem, use_true_binary_metadata_); + VarintWriter<2> key(key_index); + uint8_t* data = AddTiny(key.length() + emit.prefix_length()); + key.Write(0x40, data); + emit.WritePrefix(data + key.length()); + Add(emit.data()); +} + +void HPackCompressor::Framer::EmitLitHdrNotIdx(uint32_t key_index, + grpc_mdelem elem) { + GRPC_STATS_INC_HPACK_SEND_LITHDR_NOTIDX(); + StringValue emit(DefinitelyInterned(), elem, use_true_binary_metadata_); + VarintWriter<4> key(key_index); + uint8_t* data = AddTiny(key.length() + emit.prefix_length()); + key.Write(0x00, data); + emit.WritePrefix(data + key.length()); + Add(emit.data()); +} + +void HPackCompressor::Framer::EmitLitHdrWithStringKeyIncIdx(grpc_mdelem elem) { + GRPC_STATS_INC_HPACK_SEND_LITHDR_INCIDX_V(); + GRPC_STATS_INC_HPACK_SEND_UNCOMPRESSED(); + StringKey key(GRPC_MDKEY(elem)); + key.WritePrefix(0x40, AddTiny(key.prefix_length())); + Add(grpc_slice_ref_internal(key.key())); + StringValue emit(DefinitelyInterned(), elem, use_true_binary_metadata_); + emit.WritePrefix(AddTiny(emit.prefix_length())); + Add(emit.data()); +} + +void HPackCompressor::Framer::EmitLitHdrWithNonBinaryStringKeyIncIdx( + const grpc_slice& key_slice, const grpc_slice& value_slice) { + GRPC_STATS_INC_HPACK_SEND_LITHDR_INCIDX_V(); + GRPC_STATS_INC_HPACK_SEND_UNCOMPRESSED(); + StringKey key(key_slice); + key.WritePrefix(0x40, AddTiny(key.prefix_length())); + Add(grpc_slice_ref_internal(key.key())); + NonBinaryStringValue emit(value_slice); + emit.WritePrefix(AddTiny(emit.prefix_length())); + Add(grpc_slice_ref_internal(emit.data())); +} + +void HPackCompressor::Framer::EmitLitHdrWithStringKeyNotIdx(grpc_mdelem elem) { + GRPC_STATS_INC_HPACK_SEND_LITHDR_NOTIDX_V(); + GRPC_STATS_INC_HPACK_SEND_UNCOMPRESSED(); + StringKey key(GRPC_MDKEY(elem)); + key.WritePrefix(0x00, AddTiny(key.prefix_length())); + Add(grpc_slice_ref_internal(key.key())); + StringValue emit(UnsureIfInterned(), elem, use_true_binary_metadata_); + emit.WritePrefix(AddTiny(emit.prefix_length())); + Add(emit.data()); +} + +void HPackCompressor::Framer::AdvertiseTableSizeChange() { + VarintWriter<3> w(compressor_->table_.max_size()); + w.Write(0x20, AddTiny(w.length())); +} + +void 
HPackCompressor::Framer::Log(grpc_mdelem elem) { + char* k = grpc_slice_to_c_string(GRPC_MDKEY(elem)); + char* v = nullptr; + if (grpc_is_binary_header_internal(GRPC_MDKEY(elem))) { + v = grpc_dump_slice(GRPC_MDVALUE(elem), GPR_DUMP_HEX); + } else { + v = grpc_slice_to_c_string(GRPC_MDVALUE(elem)); + } + gpr_log( + GPR_INFO, + "Encode: '%s: %s', elem_interned=%d [%d], k_interned=%d, v_interned=%d", + k, v, GRPC_MDELEM_IS_INTERNED(elem), GRPC_MDELEM_STORAGE(elem), + grpc_slice_is_interned(GRPC_MDKEY(elem)), + grpc_slice_is_interned(GRPC_MDVALUE(elem))); + gpr_free(k); + gpr_free(v); +} + +struct EmitIndexedStatus { + EmitIndexedStatus() = default; + EmitIndexedStatus(uint32_t elem_hash, bool emitted, bool can_add) + : elem_hash(elem_hash), emitted(emitted), can_add(can_add) {} + const uint32_t elem_hash = 0; + const bool emitted = false; + const bool can_add = false; +}; + +/* encode an mdelem */ +void HPackCompressor::Framer::EncodeDynamic(grpc_mdelem elem) { + const grpc_slice& elem_key = GRPC_MDKEY(elem); + // User-provided key len validated in grpc_validate_header_key_is_legal(). + GPR_DEBUG_ASSERT(GRPC_SLICE_LENGTH(elem_key) > 0); + // Header ordering: all reserved headers (prefixed with ':') must precede + // regular headers. This can be a debug assert, since: + // 1) User cannot give us ':' headers (grpc_validate_header_key_is_legal()). + // 2) grpc filters/core should be checked during debug builds. */ +#ifndef NDEBUG + if (GRPC_SLICE_START_PTR(elem_key)[0] != ':') { /* regular header */ + seen_regular_header_ = true; + } else { + GPR_DEBUG_ASSERT( + !seen_regular_header_ && + "Reserved header (colon-prefixed) happening after regular ones."); + } +#endif + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + Log(elem); + } + const bool elem_interned = GRPC_MDELEM_IS_INTERNED(elem); + const bool key_interned = elem_interned || grpc_slice_is_interned(elem_key); + // Key is not interned, emit literals. + if (!key_interned) { + EmitLitHdrWithStringKeyNotIdx(elem); + return; + } + /* Interned metadata => maybe already indexed. */ + uint32_t elem_hash = 0; + if (elem_interned) { + // Update filter to see if we can perhaps add this elem. + elem_hash = GRPC_MDELEM_STORAGE(elem) == GRPC_MDELEM_STORAGE_INTERNED + ? reinterpret_cast( + GRPC_MDELEM_DATA(elem)) + ->hash() + : reinterpret_cast( + GRPC_MDELEM_DATA(elem)) + ->hash(); + bool can_add_to_hashtable = + compressor_->filter_elems_.AddElement(elem_hash % kNumFilterValues); + /* is this elem currently in the decoders table? */ + auto indices_key = + compressor_->elem_index_.Lookup(KeyElem(elem, elem_hash)); + if (indices_key.has_value() && + compressor_->table_.ConvertableToDynamicIndex(*indices_key)) { + EmitIndexed(compressor_->table_.DynamicIndex(*indices_key)); + return; + } + /* Didn't hit either cuckoo index, so no emit. */ + if (!can_add_to_hashtable) elem_hash = 0; + } + + /* should this elem be in the table? */ + const size_t decoder_space_usage = + grpc_core::MetadataSizeInHPackTable(elem, use_true_binary_metadata_); + const bool decoder_space_available = + decoder_space_usage < kMaxDecoderSpaceUsage; + const bool should_add_elem = + elem_interned && decoder_space_available && elem_hash != 0; + /* no hits for the elem... maybe there's a key? 
*/ + const uint32_t key_hash = elem_key.refcount->Hash(elem_key); + auto indices_key = + compressor_->key_index_.Lookup(KeySliceRef(elem_key.refcount, key_hash)); + if (indices_key.has_value() && + compressor_->table_.ConvertableToDynamicIndex(*indices_key)) { + if (should_add_elem) { + EmitLitHdrIncIdx(compressor_->table_.DynamicIndex(*indices_key), elem); + compressor_->AddElem(elem, decoder_space_usage, elem_hash, key_hash); + } else { + EmitLitHdrNotIdx(compressor_->table_.DynamicIndex(*indices_key), elem); + } + return; + } + /* no elem, key in the table... fall back to literal emission */ + const bool should_add_key = !elem_interned && decoder_space_available; + if (should_add_elem || should_add_key) { + EmitLitHdrWithStringKeyIncIdx(elem); + } else { + EmitLitHdrWithStringKeyNotIdx(elem); + } + if (should_add_elem) { + compressor_->AddElem(elem, decoder_space_usage, elem_hash, key_hash); + } else if (should_add_key) { + compressor_->AddKey(elem, decoder_space_usage, key_hash); + } +} + +void HPackCompressor::Framer::Encode(TeMetadata, TeMetadata::ValueType value) { + GPR_ASSERT(value == TeMetadata::ValueType::kTrailers); + if (compressor_->table_.ConvertableToDynamicIndex(compressor_->te_index_)) { + EmitIndexed(compressor_->table_.DynamicIndex(compressor_->te_index_)); + } else { + compressor_->te_index_ = compressor_->table_.AllocateIndex( + 2 /* te */ + 8 /* trailers */ + hpack_constants::kEntryOverhead); + EmitLitHdrWithNonBinaryStringKeyIncIdx(GRPC_MDSTR_TE, GRPC_MDSTR_TRAILERS); + } +} + +void HPackCompressor::Framer::Encode(GrpcTimeoutMetadata, + grpc_millis deadline) { + char timeout_str[GRPC_HTTP2_TIMEOUT_ENCODE_MIN_BUFSIZE]; + grpc_mdelem mdelem; + grpc_http2_encode_timeout(deadline - grpc_core::ExecCtx::Get()->Now(), + timeout_str); + mdelem = grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_TIMEOUT, grpc_core::UnmanagedMemorySlice(timeout_str)); + EncodeDynamic(mdelem); + GRPC_MDELEM_UNREF(mdelem); +} + +void HPackCompressor::SetMaxUsableSize(uint32_t max_table_size) { + max_usable_size_ = max_table_size; + SetMaxTableSize(std::min(table_.max_size(), max_table_size)); +} + +void HPackCompressor::SetMaxTableSize(uint32_t max_table_size) { + if (table_.SetMaxSize(std::min(max_usable_size_, max_table_size))) { + advertise_table_size_change_ = true; + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "set max table size from encoder to %d", + max_table_size); + } + } +} + +HPackCompressor::Framer::Framer(const EncodeHeaderOptions& options, + HPackCompressor* compressor, + grpc_slice_buffer* output) + : max_frame_size_(options.max_frame_size), + use_true_binary_metadata_(options.use_true_binary_metadata), + is_end_of_stream_(options.is_end_of_stream), + stream_id_(options.stream_id), + output_(output), + stats_(options.stats), + compressor_(compressor), + prefix_(BeginFrame()) { + if (absl::exchange(compressor_->advertise_table_size_change_, false)) { + AdvertiseTableSizeChange(); + } +} + +void HPackCompressor::Framer::Encode(grpc_mdelem md) { + if (GRPC_MDELEM_STORAGE(md) == GRPC_MDELEM_STORAGE_STATIC) { + const uintptr_t static_index = + reinterpret_cast(GRPC_MDELEM_DATA(md)) + ->StaticIndex(); + if (static_index < hpack_constants::kLastStaticEntry) { + EmitIndexed(static_cast(static_index + 1)); + return; + } + } + EncodeDynamic(md); +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/transport/hpack_encoder_table.cc b/src/core/ext/transport/chttp2/transport/hpack_encoder_table.cc new file mode 100644 index 00000000..bfe152a9 --- 
/dev/null +++ b/src/core/ext/transport/chttp2/transport/hpack_encoder_table.cc @@ -0,0 +1,86 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_encoder_table.h" + +#include + +namespace grpc_core { + +uint32_t HPackEncoderTable::AllocateIndex(size_t element_size) { + uint32_t new_index = tail_remote_index_ + table_elems_ + 1; + GPR_DEBUG_ASSERT(element_size < 65536); + + if (element_size > max_table_size_) { + while (table_size_ > 0) { + EvictOne(); + } + return 0; + } + + // Reserve space for this element in the remote table: if this overflows + // the current table, drop elements until it fits, matching the decompressor + // algorithm. + while (table_size_ + element_size > max_table_size_) { + EvictOne(); + } + GPR_ASSERT(table_elems_ < elem_size_.size()); + elem_size_[new_index % elem_size_.size()] = + static_cast(element_size); + table_size_ += element_size; + table_elems_++; + + return new_index; +} + +bool HPackEncoderTable::SetMaxSize(uint32_t max_table_size) { + if (max_table_size == max_table_size_) { + return false; + } + while (table_size_ > 0 && table_size_ > max_table_size) { + EvictOne(); + } + max_table_size_ = max_table_size; + const size_t max_table_elems = + hpack_constants::EntriesForBytes(max_table_size); + // TODO(ctiller): integrate with ResourceQuota to rebuild smaller when we can. + if (max_table_elems > elem_size_.size()) { + Rebuild(std::max(max_table_elems, 2 * elem_size_.size())); + } + return true; +} + +void HPackEncoderTable::EvictOne() { + tail_remote_index_++; + GPR_ASSERT(tail_remote_index_ > 0); + GPR_ASSERT(table_elems_ > 0); + auto removing_size = elem_size_[tail_remote_index_ % elem_size_.size()]; + GPR_ASSERT(table_size_ >= removing_size); + table_size_ -= removing_size; + table_elems_--; +} + +void HPackEncoderTable::Rebuild(uint32_t capacity) { + decltype(elem_size_) new_elem_size(capacity); + GPR_ASSERT(table_elems_ <= capacity); + for (uint32_t i = 0; i < table_elems_; i++) { + uint32_t ofs = tail_remote_index_ + i + 1; + new_elem_size[ofs % capacity] = elem_size_[ofs % elem_size_.size()]; + } + elem_size_.swap(new_elem_size); +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/transport/hpack_parser.cc b/src/core/ext/transport/chttp2/transport/hpack_parser.cc new file mode 100644 index 00000000..eb2f2f45 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/hpack_parser.cc @@ -0,0 +1,1454 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
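// A standalone sketch (not gRPC code; names are illustrative) of the
// RFC 7541 section 4.1 accounting that HPackEncoderTable implements above:
// an entry charges its name length plus value length plus 32 bytes of
// overhead, inserting evicts oldest-first until the newcomer fits, and an
// entry larger than the whole table simply empties it (AllocateIndex above
// returns 0 in that case).
#include <cstddef>
#include <deque>
#include <string>
#include <utility>

class ToyHpackTable {
 public:
  explicit ToyHpackTable(size_t max_bytes) : max_bytes_(max_bytes) {}

  static size_t EntrySize(const std::string& name, const std::string& value) {
    return name.size() + value.size() + 32;  // 32 byte per-entry overhead
  }

  // With max_bytes == 100, each ("k", "v") entry costs 34 bytes, so the
  // third Add() evicts the oldest entry before inserting.
  void Add(std::string name, std::string value) {
    const size_t size = EntrySize(name, value);
    if (size > max_bytes_) {  // can never fit: clear the table
      entries_.clear();
      used_bytes_ = 0;
      return;
    }
    while (used_bytes_ + size > max_bytes_) EvictOldest();
    used_bytes_ += size;
    entries_.emplace_back(std::move(name), std::move(value));
  }

 private:
  void EvictOldest() {
    used_bytes_ -= EntrySize(entries_.front().first, entries_.front().second);
    entries_.pop_front();
  }

  std::deque<std::pair<std::string, std::string>> entries_;
  size_t used_bytes_ = 0;
  size_t max_bytes_;
};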
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_parser.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/bin_encoder.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/match.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/validate_metadata.h" +#include "src/core/lib/transport/http2_errors.h" + +#if __cplusplus > 201103L +#define GRPC_HPACK_CONSTEXPR_FN constexpr +#define GRPC_HPACK_CONSTEXPR_VALUE constexpr +#else +#define GRPC_HPACK_CONSTEXPR_FN +#define GRPC_HPACK_CONSTEXPR_VALUE const +#endif + +namespace grpc_core { + +TraceFlag grpc_trace_chttp2_hpack_parser(false, "chttp2_hpack_parser"); + +/* state table for huffman decoding: given a state, gives an index/16 into + next_sub_tbl. Taking that index and adding the value of the nibble being + considered returns the next state. + + generated by gen_hpack_tables.c */ +static const uint8_t next_tbl[256] = { + 0, 1, 2, 3, 4, 1, 2, 5, 6, 1, 7, 8, 1, 3, 3, 9, 10, 11, 1, 1, + 1, 12, 1, 2, 13, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, + 14, 1, 15, 16, 1, 17, 1, 15, 2, 7, 3, 18, 19, 1, 1, 1, 1, 20, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 15, 2, 2, 7, 21, 1, 22, 1, 1, 1, 1, 1, + 1, 1, 1, 15, 2, 2, 2, 2, 2, 2, 23, 24, 25, 1, 1, 1, 1, 2, 2, 2, + 26, 3, 3, 27, 10, 28, 1, 1, 1, 1, 1, 1, 2, 3, 29, 10, 30, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 31, 1, 1, 1, 1, 1, 1, 1, 2, + 2, 2, 2, 2, 2, 2, 2, 32, 1, 1, 15, 33, 1, 34, 35, 9, 36, 1, 1, 1, + 1, 1, 1, 1, 37, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 26, 9, + 38, 1, 1, 1, 1, 1, 1, 1, 15, 2, 2, 2, 2, 26, 3, 3, 39, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 7, 3, 3, 3, 40, 2, + 41, 1, 1, 1, 42, 43, 1, 1, 44, 1, 1, 1, 1, 15, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 45, 46, 1, 1, 2, 2, 2, 35, 3, 3, 18, 47, 2, +}; + +/* next state, based upon current state and the current nibble: see above. 
+ generated by gen_hpack_tables.c */ +static const int16_t next_sub_tbl[48 * 16] = { + 1, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, + 218, 2, 6, 10, 13, 14, 15, 16, 17, 2, 6, 10, 13, 14, 15, + 16, 17, 3, 7, 11, 24, 3, 7, 11, 24, 3, 7, 11, 24, 3, + 7, 11, 24, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, + 4, 8, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, + 199, 200, 201, 202, 203, 4, 8, 4, 8, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 9, 133, 134, 135, 136, 137, 138, 139, 140, + 141, 142, 143, 144, 145, 146, 147, 3, 7, 11, 24, 3, 7, 11, 24, + 4, 8, 4, 8, 4, 8, 4, 8, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 12, 132, 4, 8, 4, 8, 4, 8, + 4, 8, 4, 8, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 18, 19, 20, 21, 4, 8, 4, + 8, 4, 8, 4, 8, 4, 8, 0, 0, 0, 22, 23, 91, 25, 26, + 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 3, + 7, 11, 24, 3, 7, 11, 24, 0, 0, 0, 0, 0, 41, 42, 43, + 2, 6, 10, 13, 14, 15, 16, 17, 3, 7, 11, 24, 3, 7, 11, + 24, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 0, 0, + 44, 45, 2, 6, 10, 13, 14, 15, 16, 17, 46, 47, 48, 49, 50, + 51, 52, 57, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, + 68, 69, 70, 71, 72, 74, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 73, 75, 76, 77, 78, 79, 80, 81, 82, + 83, 84, 85, 86, 87, 88, 89, 90, 3, 7, 11, 24, 3, 7, 11, + 24, 3, 7, 11, 24, 0, 0, 0, 0, 3, 7, 11, 24, 3, 7, + 11, 24, 4, 8, 4, 8, 0, 0, 0, 92, 0, 0, 0, 93, 94, + 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 3, 7, 11, 24, + 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, + 8, 4, 8, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 4, + 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 0, 0, + 0, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 2, 6, 10, 13, 14, 15, 16, 17, 4, 8, 4, 8, 4, 8, + 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 148, + 149, 150, 151, 3, 7, 11, 24, 4, 8, 4, 8, 0, 0, 0, 0, + 0, 0, 152, 153, 3, 7, 11, 24, 3, 7, 11, 24, 3, 7, 11, + 24, 154, 155, 156, 164, 3, 7, 11, 24, 3, 7, 11, 24, 3, 7, + 11, 24, 4, 8, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 157, 158, 159, 160, 161, 162, 163, 165, 166, 167, 168, 169, 170, 171, 172, + 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, + 188, 189, 190, 191, 192, 193, 194, 195, 196, 4, 8, 4, 8, 4, 8, + 4, 8, 4, 8, 4, 8, 4, 8, 197, 198, 4, 8, 4, 8, 4, + 8, 4, 8, 0, 0, 0, 0, 0, 0, 219, 220, 3, 7, 11, 24, + 4, 8, 4, 8, 4, 8, 0, 0, 221, 222, 223, 224, 3, 7, 11, + 24, 3, 7, 11, 24, 4, 8, 4, 8, 4, 8, 225, 228, 4, 8, + 4, 8, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 226, 227, 229, + 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, + 4, 8, 4, 8, 4, 8, 4, 8, 4, 8, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 255, +}; + +/* emission table: indexed like next_tbl, ultimately gives the byte to be + emitted, or -1 for no byte, or 256 for end of stream + + generated by gen_hpack_tables.c */ +static const uint16_t emit_tbl[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 0, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 0, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 0, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, + 86, 87, 
88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, + 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, + 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 0, + 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, + 0, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, + 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, + 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, + 219, 220, 221, 0, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, + 248, +}; + +/* generated by gen_hpack_tables.c */ +static const int16_t emit_sub_tbl[249 * 16] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, 48, 48, 48, 48, 48, 48, 48, 48, 49, 49, 49, 49, 49, 49, + 49, 49, 48, 48, 48, 48, 49, 49, 49, 49, 50, 50, 50, 50, 97, + 97, 97, 97, 48, 48, 49, 49, 50, 50, 97, 97, 99, 99, 101, 101, + 105, 105, 111, 111, 48, 49, 50, 97, 99, 101, 105, 111, 115, 116, -1, + -1, -1, -1, -1, -1, 32, 32, 32, 32, 32, 32, 32, 32, 37, 37, + 37, 37, 37, 37, 37, 37, 99, 99, 99, 99, 101, 101, 101, 101, 105, + 105, 105, 105, 111, 111, 111, 111, 115, 115, 116, 116, 32, 37, 45, 46, + 47, 51, 52, 53, 54, 55, 56, 57, 61, 61, 61, 61, 61, 61, 61, + 61, 65, 65, 65, 65, 65, 65, 65, 65, 115, 115, 115, 115, 116, 116, + 116, 116, 32, 32, 37, 37, 45, 45, 46, 46, 61, 65, 95, 98, 100, + 102, 103, 104, 108, 109, 110, 112, 114, 117, -1, -1, 58, 58, 58, 58, + 58, 58, 58, 58, 66, 66, 66, 66, 66, 66, 66, 66, 47, 47, 51, + 51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56, 57, 57, 61, 61, + 65, 65, 95, 95, 98, 98, 100, 100, 102, 102, 103, 103, 104, 104, 108, + 108, 109, 109, 110, 110, 112, 112, 114, 114, 117, 117, 58, 66, 67, 68, + 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 89, 106, 107, 113, 118, 119, 120, 121, 122, -1, -1, + -1, -1, 38, 38, 38, 38, 38, 38, 38, 38, 42, 42, 42, 42, 42, + 42, 42, 42, 44, 44, 44, 44, 44, 44, 44, 44, 59, 59, 59, 59, + 59, 59, 59, 59, 88, 88, 88, 88, 88, 88, 88, 88, 90, 90, 90, + 90, 90, 90, 90, 90, 33, 33, 34, 34, 40, 40, 41, 41, 63, 63, + 39, 43, 124, -1, -1, -1, 35, 35, 35, 35, 35, 35, 35, 35, 62, + 62, 62, 62, 62, 62, 62, 62, 0, 0, 0, 0, 36, 36, 36, 36, + 64, 64, 64, 64, 91, 91, 91, 91, 69, 69, 69, 69, 69, 69, 69, + 69, 70, 70, 70, 70, 70, 70, 70, 70, 71, 71, 71, 71, 71, 71, + 71, 71, 72, 72, 72, 72, 72, 72, 72, 72, 73, 73, 73, 73, 73, + 73, 73, 73, 74, 74, 74, 74, 74, 74, 74, 74, 75, 75, 75, 75, + 75, 75, 75, 75, 76, 76, 76, 76, 76, 76, 76, 76, 77, 77, 77, + 77, 77, 77, 77, 77, 78, 78, 78, 78, 78, 78, 78, 78, 79, 79, + 79, 79, 79, 79, 79, 79, 80, 80, 80, 80, 80, 80, 80, 80, 81, + 81, 81, 81, 81, 81, 81, 81, 82, 82, 82, 82, 82, 82, 82, 82, + 83, 83, 83, 83, 83, 83, 83, 83, 84, 84, 84, 84, 84, 84, 84, + 84, 85, 85, 85, 85, 85, 85, 85, 85, 86, 86, 86, 86, 86, 86, + 86, 86, 87, 87, 87, 87, 87, 87, 87, 87, 89, 89, 89, 89, 89, + 89, 89, 89, 106, 106, 106, 106, 106, 106, 106, 106, 107, 107, 107, 107, + 107, 107, 107, 107, 113, 113, 113, 113, 113, 113, 113, 113, 118, 118, 118, + 118, 118, 118, 118, 118, 119, 119, 119, 119, 119, 119, 119, 119, 120, 120, + 120, 120, 120, 120, 120, 120, 121, 121, 121, 121, 121, 121, 121, 121, 122, + 122, 122, 122, 122, 122, 122, 122, 38, 38, 38, 38, 42, 42, 42, 42, + 44, 44, 44, 44, 59, 
59, 59, 59, 88, 88, 88, 88, 90, 90, 90, + 90, 33, 34, 40, 41, 63, -1, -1, -1, 39, 39, 39, 39, 39, 39, + 39, 39, 43, 43, 43, 43, 43, 43, 43, 43, 124, 124, 124, 124, 124, + 124, 124, 124, 35, 35, 35, 35, 62, 62, 62, 62, 0, 0, 36, 36, + 64, 64, 91, 91, 93, 93, 126, 126, 94, 125, -1, -1, 60, 60, 60, + 60, 60, 60, 60, 60, 96, 96, 96, 96, 96, 96, 96, 96, 123, 123, + 123, 123, 123, 123, 123, 123, -1, -1, -1, -1, -1, -1, -1, -1, 92, + 92, 92, 92, 92, 92, 92, 92, 195, 195, 195, 195, 195, 195, 195, 195, + 208, 208, 208, 208, 208, 208, 208, 208, 128, 128, 128, 128, 130, 130, 130, + 130, 131, 131, 131, 131, 162, 162, 162, 162, 184, 184, 184, 184, 194, 194, + 194, 194, 224, 224, 224, 224, 226, 226, 226, 226, 153, 153, 161, 161, 167, + 167, 172, 172, 176, 176, 177, 177, 179, 179, 209, 209, 216, 216, 217, 217, + 227, 227, 229, 229, 230, 230, 129, 132, 133, 134, 136, 146, 154, 156, 160, + 163, 164, 169, 170, 173, 178, 181, 185, 186, 187, 189, 190, 196, 198, 228, + 232, 233, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 135, + 135, 135, 135, 135, 135, 135, 135, 137, 137, 137, 137, 137, 137, 137, 137, + 138, 138, 138, 138, 138, 138, 138, 138, 139, 139, 139, 139, 139, 139, 139, + 139, 140, 140, 140, 140, 140, 140, 140, 140, 141, 141, 141, 141, 141, 141, + 141, 141, 143, 143, 143, 143, 143, 143, 143, 143, 147, 147, 147, 147, 147, + 147, 147, 147, 149, 149, 149, 149, 149, 149, 149, 149, 150, 150, 150, 150, + 150, 150, 150, 150, 151, 151, 151, 151, 151, 151, 151, 151, 152, 152, 152, + 152, 152, 152, 152, 152, 155, 155, 155, 155, 155, 155, 155, 155, 157, 157, + 157, 157, 157, 157, 157, 157, 158, 158, 158, 158, 158, 158, 158, 158, 165, + 165, 165, 165, 165, 165, 165, 165, 166, 166, 166, 166, 166, 166, 166, 166, + 168, 168, 168, 168, 168, 168, 168, 168, 174, 174, 174, 174, 174, 174, 174, + 174, 175, 175, 175, 175, 175, 175, 175, 175, 180, 180, 180, 180, 180, 180, + 180, 180, 182, 182, 182, 182, 182, 182, 182, 182, 183, 183, 183, 183, 183, + 183, 183, 183, 188, 188, 188, 188, 188, 188, 188, 188, 191, 191, 191, 191, + 191, 191, 191, 191, 197, 197, 197, 197, 197, 197, 197, 197, 231, 231, 231, + 231, 231, 231, 231, 231, 239, 239, 239, 239, 239, 239, 239, 239, 9, 9, + 9, 9, 142, 142, 142, 142, 144, 144, 144, 144, 145, 145, 145, 145, 148, + 148, 148, 148, 159, 159, 159, 159, 171, 171, 171, 171, 206, 206, 206, 206, + 215, 215, 215, 215, 225, 225, 225, 225, 236, 236, 236, 236, 237, 237, 237, + 237, 199, 199, 207, 207, 234, 234, 235, 235, 192, 193, 200, 201, 202, 205, + 210, 213, 218, 219, 238, 240, 242, 243, 255, -1, 203, 203, 203, 203, 203, + 203, 203, 203, 204, 204, 204, 204, 204, 204, 204, 204, 211, 211, 211, 211, + 211, 211, 211, 211, 212, 212, 212, 212, 212, 212, 212, 212, 214, 214, 214, + 214, 214, 214, 214, 214, 221, 221, 221, 221, 221, 221, 221, 221, 222, 222, + 222, 222, 222, 222, 222, 222, 223, 223, 223, 223, 223, 223, 223, 223, 241, + 241, 241, 241, 241, 241, 241, 241, 244, 244, 244, 244, 244, 244, 244, 244, + 245, 245, 245, 245, 245, 245, 245, 245, 246, 246, 246, 246, 246, 246, 246, + 246, 247, 247, 247, 247, 247, 247, 247, 247, 248, 248, 248, 248, 248, 248, + 248, 248, 250, 250, 250, 250, 250, 250, 250, 250, 251, 251, 251, 251, 251, + 251, 251, 251, 252, 252, 252, 252, 252, 252, 252, 252, 253, 253, 253, 253, + 253, 253, 253, 253, 254, 254, 254, 254, 254, 254, 254, 254, 2, 2, 2, + 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, + 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 11, 11, 11, 11, 12, + 12, 12, 12, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, + 17, 17, 17, 17, 18, 18, 18, 18, 19, 19, 19, 19, 20, 20, 20, + 20, 21, 21, 21, 
21, 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, + 25, 25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 28, 28, 28, 29, + 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 127, 127, 127, 127, + 220, 220, 220, 220, 249, 249, 249, 249, 10, 13, 22, 256, 93, 93, 93, + 93, 126, 126, 126, 126, 94, 94, 125, 125, 60, 96, 123, -1, 92, 195, + 208, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 128, + 128, 128, 128, 128, 128, 128, 128, 130, 130, 130, 130, 130, 130, 130, 130, + 131, 131, 131, 131, 131, 131, 131, 131, 162, 162, 162, 162, 162, 162, 162, + 162, 184, 184, 184, 184, 184, 184, 184, 184, 194, 194, 194, 194, 194, 194, + 194, 194, 224, 224, 224, 224, 224, 224, 224, 224, 226, 226, 226, 226, 226, + 226, 226, 226, 153, 153, 153, 153, 161, 161, 161, 161, 167, 167, 167, 167, + 172, 172, 172, 172, 176, 176, 176, 176, 177, 177, 177, 177, 179, 179, 179, + 179, 209, 209, 209, 209, 216, 216, 216, 216, 217, 217, 217, 217, 227, 227, + 227, 227, 229, 229, 229, 229, 230, 230, 230, 230, 129, 129, 132, 132, 133, + 133, 134, 134, 136, 136, 146, 146, 154, 154, 156, 156, 160, 160, 163, 163, + 164, 164, 169, 169, 170, 170, 173, 173, 178, 178, 181, 181, 185, 185, 186, + 186, 187, 187, 189, 189, 190, 190, 196, 196, 198, 198, 228, 228, 232, 232, + 233, 233, 1, 135, 137, 138, 139, 140, 141, 143, 147, 149, 150, 151, 152, + 155, 157, 158, 165, 166, 168, 174, 175, 180, 182, 183, 188, 191, 197, 231, + 239, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 9, 9, 9, + 9, 9, 9, 9, 9, 142, 142, 142, 142, 142, 142, 142, 142, 144, 144, + 144, 144, 144, 144, 144, 144, 145, 145, 145, 145, 145, 145, 145, 145, 148, + 148, 148, 148, 148, 148, 148, 148, 159, 159, 159, 159, 159, 159, 159, 159, + 171, 171, 171, 171, 171, 171, 171, 171, 206, 206, 206, 206, 206, 206, 206, + 206, 215, 215, 215, 215, 215, 215, 215, 215, 225, 225, 225, 225, 225, 225, + 225, 225, 236, 236, 236, 236, 236, 236, 236, 236, 237, 237, 237, 237, 237, + 237, 237, 237, 199, 199, 199, 199, 207, 207, 207, 207, 234, 234, 234, 234, + 235, 235, 235, 235, 192, 192, 193, 193, 200, 200, 201, 201, 202, 202, 205, + 205, 210, 210, 213, 213, 218, 218, 219, 219, 238, 238, 240, 240, 242, 242, + 243, 243, 255, 255, 203, 204, 211, 212, 214, 221, 222, 223, 241, 244, 245, + 246, 247, 248, 250, 251, 252, 253, 254, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, 2, + 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, + 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, + 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, + 8, 8, 8, 8, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, + 12, 12, 12, 12, 12, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, + 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 17, + 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, + 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, + 20, 21, 21, 21, 21, 21, 21, 21, 21, 23, 23, 23, 23, 23, 23, + 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, + 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, + 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, + 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, + 31, 31, 31, 31, 31, 31, 127, 127, 127, 127, 127, 127, 127, 127, 220, + 220, 220, 220, 220, 220, 220, 220, 249, 249, 249, 249, 249, 249, 249, 249, + 10, 10, 13, 13, 22, 22, 256, 256, 67, 67, 67, 67, 67, 67, 67, + 67, 68, 68, 68, 68, 68, 68, 68, 68, 95, 95, 95, 95, 95, 95, + 95, 95, 98, 98, 98, 98, 98, 98, 98, 98, 100, 100, 100, 100, 100, + 100, 100, 100, 102, 102, 102, 102, 102, 102, 102, 102, 103, 103, 103, 103, + 103, 103, 103, 103, 104, 104, 104, 104, 104, 104, 104, 
104, 108, 108, 108, + 108, 108, 108, 108, 108, 109, 109, 109, 109, 109, 109, 109, 109, 110, 110, + 110, 110, 110, 110, 110, 110, 112, 112, 112, 112, 112, 112, 112, 112, 114, + 114, 114, 114, 114, 114, 114, 114, 117, 117, 117, 117, 117, 117, 117, 117, + 58, 58, 58, 58, 66, 66, 66, 66, 67, 67, 67, 67, 68, 68, 68, + 68, 69, 69, 69, 69, 70, 70, 70, 70, 71, 71, 71, 71, 72, 72, + 72, 72, 73, 73, 73, 73, 74, 74, 74, 74, 75, 75, 75, 75, 76, + 76, 76, 76, 77, 77, 77, 77, 78, 78, 78, 78, 79, 79, 79, 79, + 80, 80, 80, 80, 81, 81, 81, 81, 82, 82, 82, 82, 83, 83, 83, + 83, 84, 84, 84, 84, 85, 85, 85, 85, 86, 86, 86, 86, 87, 87, + 87, 87, 89, 89, 89, 89, 106, 106, 106, 106, 107, 107, 107, 107, 113, + 113, 113, 113, 118, 118, 118, 118, 119, 119, 119, 119, 120, 120, 120, 120, + 121, 121, 121, 121, 122, 122, 122, 122, 38, 38, 42, 42, 44, 44, 59, + 59, 88, 88, 90, 90, -1, -1, -1, -1, 33, 33, 33, 33, 33, 33, + 33, 33, 34, 34, 34, 34, 34, 34, 34, 34, 40, 40, 40, 40, 40, + 40, 40, 40, 41, 41, 41, 41, 41, 41, 41, 41, 63, 63, 63, 63, + 63, 63, 63, 63, 39, 39, 39, 39, 43, 43, 43, 43, 124, 124, 124, + 124, 35, 35, 62, 62, 0, 36, 64, 91, 93, 126, -1, -1, 94, 94, + 94, 94, 94, 94, 94, 94, 125, 125, 125, 125, 125, 125, 125, 125, 60, + 60, 60, 60, 96, 96, 96, 96, 123, 123, 123, 123, -1, -1, -1, -1, + 92, 92, 92, 92, 195, 195, 195, 195, 208, 208, 208, 208, 128, 128, 130, + 130, 131, 131, 162, 162, 184, 184, 194, 194, 224, 224, 226, 226, 153, 161, + 167, 172, 176, 177, 179, 209, 216, 217, 227, 229, 230, -1, -1, -1, -1, + -1, -1, -1, 129, 129, 129, 129, 129, 129, 129, 129, 132, 132, 132, 132, + 132, 132, 132, 132, 133, 133, 133, 133, 133, 133, 133, 133, 134, 134, 134, + 134, 134, 134, 134, 134, 136, 136, 136, 136, 136, 136, 136, 136, 146, 146, + 146, 146, 146, 146, 146, 146, 154, 154, 154, 154, 154, 154, 154, 154, 156, + 156, 156, 156, 156, 156, 156, 156, 160, 160, 160, 160, 160, 160, 160, 160, + 163, 163, 163, 163, 163, 163, 163, 163, 164, 164, 164, 164, 164, 164, 164, + 164, 169, 169, 169, 169, 169, 169, 169, 169, 170, 170, 170, 170, 170, 170, + 170, 170, 173, 173, 173, 173, 173, 173, 173, 173, 178, 178, 178, 178, 178, + 178, 178, 178, 181, 181, 181, 181, 181, 181, 181, 181, 185, 185, 185, 185, + 185, 185, 185, 185, 186, 186, 186, 186, 186, 186, 186, 186, 187, 187, 187, + 187, 187, 187, 187, 187, 189, 189, 189, 189, 189, 189, 189, 189, 190, 190, + 190, 190, 190, 190, 190, 190, 196, 196, 196, 196, 196, 196, 196, 196, 198, + 198, 198, 198, 198, 198, 198, 198, 228, 228, 228, 228, 228, 228, 228, 228, + 232, 232, 232, 232, 232, 232, 232, 232, 233, 233, 233, 233, 233, 233, 233, + 233, 1, 1, 1, 1, 135, 135, 135, 135, 137, 137, 137, 137, 138, 138, + 138, 138, 139, 139, 139, 139, 140, 140, 140, 140, 141, 141, 141, 141, 143, + 143, 143, 143, 147, 147, 147, 147, 149, 149, 149, 149, 150, 150, 150, 150, + 151, 151, 151, 151, 152, 152, 152, 152, 155, 155, 155, 155, 157, 157, 157, + 157, 158, 158, 158, 158, 165, 165, 165, 165, 166, 166, 166, 166, 168, 168, + 168, 168, 174, 174, 174, 174, 175, 175, 175, 175, 180, 180, 180, 180, 182, + 182, 182, 182, 183, 183, 183, 183, 188, 188, 188, 188, 191, 191, 191, 191, + 197, 197, 197, 197, 231, 231, 231, 231, 239, 239, 239, 239, 9, 9, 142, + 142, 144, 144, 145, 145, 148, 148, 159, 159, 171, 171, 206, 206, 215, 215, + 225, 225, 236, 236, 237, 237, 199, 207, 234, 235, 192, 192, 192, 192, 192, + 192, 192, 192, 193, 193, 193, 193, 193, 193, 193, 193, 200, 200, 200, 200, + 200, 200, 200, 200, 201, 201, 201, 201, 201, 201, 201, 201, 202, 202, 202, + 202, 202, 202, 202, 202, 205, 205, 205, 205, 205, 
205, 205, 205, 210, 210, + 210, 210, 210, 210, 210, 210, 213, 213, 213, 213, 213, 213, 213, 213, 218, + 218, 218, 218, 218, 218, 218, 218, 219, 219, 219, 219, 219, 219, 219, 219, + 238, 238, 238, 238, 238, 238, 238, 238, 240, 240, 240, 240, 240, 240, 240, + 240, 242, 242, 242, 242, 242, 242, 242, 242, 243, 243, 243, 243, 243, 243, + 243, 243, 255, 255, 255, 255, 255, 255, 255, 255, 203, 203, 203, 203, 204, + 204, 204, 204, 211, 211, 211, 211, 212, 212, 212, 212, 214, 214, 214, 214, + 221, 221, 221, 221, 222, 222, 222, 222, 223, 223, 223, 223, 241, 241, 241, + 241, 244, 244, 244, 244, 245, 245, 245, 245, 246, 246, 246, 246, 247, 247, + 247, 247, 248, 248, 248, 248, 250, 250, 250, 250, 251, 251, 251, 251, 252, + 252, 252, 252, 253, 253, 253, 253, 254, 254, 254, 254, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 11, 11, 12, 12, 14, + 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, + 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, + 30, 31, 31, 127, 127, 220, 220, 249, 249, -1, -1, 10, 10, 10, 10, + 10, 10, 10, 10, 13, 13, 13, 13, 13, 13, 13, 13, 22, 22, 22, + 22, 22, 22, 22, 22, 256, 256, 256, 256, 256, 256, 256, 256, 45, 45, + 45, 45, 45, 45, 45, 45, 46, 46, 46, 46, 46, 46, 46, 46, 47, + 47, 47, 47, 47, 47, 47, 47, 51, 51, 51, 51, 51, 51, 51, 51, + 52, 52, 52, 52, 52, 52, 52, 52, 53, 53, 53, 53, 53, 53, 53, + 53, 54, 54, 54, 54, 54, 54, 54, 54, 55, 55, 55, 55, 55, 55, + 55, 55, 56, 56, 56, 56, 56, 56, 56, 56, 57, 57, 57, 57, 57, + 57, 57, 57, 50, 50, 50, 50, 50, 50, 50, 50, 97, 97, 97, 97, + 97, 97, 97, 97, 99, 99, 99, 99, 99, 99, 99, 99, 101, 101, 101, + 101, 101, 101, 101, 101, 105, 105, 105, 105, 105, 105, 105, 105, 111, 111, + 111, 111, 111, 111, 111, 111, 115, 115, 115, 115, 115, 115, 115, 115, 116, + 116, 116, 116, 116, 116, 116, 116, 32, 32, 32, 32, 37, 37, 37, 37, + 45, 45, 45, 45, 46, 46, 46, 46, 47, 47, 47, 47, 51, 51, 51, + 51, 52, 52, 52, 52, 53, 53, 53, 53, 54, 54, 54, 54, 55, 55, + 55, 55, 56, 56, 56, 56, 57, 57, 57, 57, 61, 61, 61, 61, 65, + 65, 65, 65, 95, 95, 95, 95, 98, 98, 98, 98, 100, 100, 100, 100, + 102, 102, 102, 102, 103, 103, 103, 103, 104, 104, 104, 104, 108, 108, 108, + 108, 109, 109, 109, 109, 110, 110, 110, 110, 112, 112, 112, 112, 114, 114, + 114, 114, 117, 117, 117, 117, 58, 58, 66, 66, 67, 67, 68, 68, 69, + 69, 70, 70, 71, 71, 72, 72, 73, 73, 74, 74, 75, 75, 76, 76, + 77, 77, 78, 78, 79, 79, 80, 80, 81, 81, 82, 82, 83, 83, 84, + 84, 85, 85, 86, 86, 87, 87, 89, 89, 106, 106, 107, 107, 113, 113, + 118, 118, 119, 119, 120, 120, 121, 121, 122, 122, 38, 42, 44, 59, 88, + 90, -1, -1, 33, 33, 33, 33, 34, 34, 34, 34, 40, 40, 40, 40, + 41, 41, 41, 41, 63, 63, 63, 63, 39, 39, 43, 43, 124, 124, 35, + 62, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 36, 36, + 36, 36, 36, 36, 36, 36, 64, 64, 64, 64, 64, 64, 64, 64, 91, + 91, 91, 91, 91, 91, 91, 91, 93, 93, 93, 93, 93, 93, 93, 93, + 126, 126, 126, 126, 126, 126, 126, 126, 94, 94, 94, 94, 125, 125, 125, + 125, 60, 60, 96, 96, 123, 123, -1, -1, 92, 92, 195, 195, 208, 208, + 128, 130, 131, 162, 184, 194, 224, 226, -1, -1, 153, 153, 153, 153, 153, + 153, 153, 153, 161, 161, 161, 161, 161, 161, 161, 161, 167, 167, 167, 167, + 167, 167, 167, 167, 172, 172, 172, 172, 172, 172, 172, 172, 176, 176, 176, + 176, 176, 176, 176, 176, 177, 177, 177, 177, 177, 177, 177, 177, 179, 179, + 179, 179, 179, 179, 179, 179, 209, 209, 209, 209, 209, 209, 209, 209, 216, + 216, 216, 216, 216, 216, 216, 216, 217, 217, 217, 217, 217, 217, 217, 217, + 227, 227, 227, 227, 227, 227, 227, 227, 229, 229, 229, 229, 229, 229, 229, + 229, 
230, 230, 230, 230, 230, 230, 230, 230, 129, 129, 129, 129, 132, 132, + 132, 132, 133, 133, 133, 133, 134, 134, 134, 134, 136, 136, 136, 136, 146, + 146, 146, 146, 154, 154, 154, 154, 156, 156, 156, 156, 160, 160, 160, 160, + 163, 163, 163, 163, 164, 164, 164, 164, 169, 169, 169, 169, 170, 170, 170, + 170, 173, 173, 173, 173, 178, 178, 178, 178, 181, 181, 181, 181, 185, 185, + 185, 185, 186, 186, 186, 186, 187, 187, 187, 187, 189, 189, 189, 189, 190, + 190, 190, 190, 196, 196, 196, 196, 198, 198, 198, 198, 228, 228, 228, 228, + 232, 232, 232, 232, 233, 233, 233, 233, 1, 1, 135, 135, 137, 137, 138, + 138, 139, 139, 140, 140, 141, 141, 143, 143, 147, 147, 149, 149, 150, 150, + 151, 151, 152, 152, 155, 155, 157, 157, 158, 158, 165, 165, 166, 166, 168, + 168, 174, 174, 175, 175, 180, 180, 182, 182, 183, 183, 188, 188, 191, 191, + 197, 197, 231, 231, 239, 239, 9, 142, 144, 145, 148, 159, 171, 206, 215, + 225, 236, 237, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 199, 199, + 199, 199, 199, 199, 199, 199, 207, 207, 207, 207, 207, 207, 207, 207, 234, + 234, 234, 234, 234, 234, 234, 234, 235, 235, 235, 235, 235, 235, 235, 235, + 192, 192, 192, 192, 193, 193, 193, 193, 200, 200, 200, 200, 201, 201, 201, + 201, 202, 202, 202, 202, 205, 205, 205, 205, 210, 210, 210, 210, 213, 213, + 213, 213, 218, 218, 218, 218, 219, 219, 219, 219, 238, 238, 238, 238, 240, + 240, 240, 240, 242, 242, 242, 242, 243, 243, 243, 243, 255, 255, 255, 255, + 203, 203, 204, 204, 211, 211, 212, 212, 214, 214, 221, 221, 222, 222, 223, + 223, 241, 241, 244, 244, 245, 245, 246, 246, 247, 247, 248, 248, 250, 250, + 251, 251, 252, 252, 253, 253, 254, 254, 2, 3, 4, 5, 6, 7, 8, + 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 127, 220, 249, -1, 10, 10, 10, 10, 13, 13, 13, + 13, 22, 22, 22, 22, 256, 256, 256, 256, +}; + +namespace { +// The alphabet used for base64 encoding binary metadata. +static constexpr char kBase64Alphabet[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; + +// An inverted table: for each value in kBase64Alphabet, table contains the +// index with which it's stored, so we can quickly invert the encoding without +// any complicated runtime logic. +struct Base64InverseTable { + uint8_t table[256]{}; + GRPC_HPACK_CONSTEXPR_FN Base64InverseTable() { + for (int i = 0; i < 256; i++) { + table[i] = 255; + } + for (const char* p = kBase64Alphabet; *p; p++) { + uint8_t idx = *p; + uint8_t ofs = p - kBase64Alphabet; + table[idx] = ofs; + } + } +}; + +static GRPC_HPACK_CONSTEXPR_VALUE Base64InverseTable kBase64InverseTable; +} // namespace + +// Input tracks the current byte through the input data and provides it +// via a simple stream interface. +class HPackParser::Input { + public: + Input(grpc_slice_refcount* current_slice_refcount, const uint8_t* begin, + const uint8_t* end) + : current_slice_refcount_(current_slice_refcount), + begin_(begin), + end_(end), + frontier_(begin) {} + + // If input is backed by a slice, retrieve its refcount. If not, return + // nullptr. + grpc_slice_refcount* slice_refcount() { return current_slice_refcount_; } + + // Have we reached the end of input? 
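// A standalone sketch (not gRPC code) of the same trick kBase64InverseTable
// above uses: precompute a 256-entry inverse of the base64 alphabet so that
// decoding needs only a table lookup plus a single validity check per
// character. Unbase64Loop further below packs four 6-bit values into three
// bytes exactly like DecodeGroup here.
#include <array>
#include <cstdint>

constexpr char kStdBase64Alphabet[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

std::array<uint8_t, 256> BuildBase64Inverse() {
  std::array<uint8_t, 256> table;
  table.fill(255);  // 255 marks characters outside the alphabet
  for (int i = 0; kStdBase64Alphabet[i] != '\0'; ++i) {
    table[static_cast<uint8_t>(kStdBase64Alphabet[i])] =
        static_cast<uint8_t>(i);
  }
  return table;
}

// Decodes one full (unpadded) 4-character group into 3 bytes.
// e.g. DecodeGroup("TWFu", out) writes {'M', 'a', 'n'} and returns true.
bool DecodeGroup(const char* four, uint8_t out[3]) {
  static const std::array<uint8_t, 256> inverse = BuildBase64Inverse();
  uint32_t buffer = 0;
  for (int i = 0; i < 4; ++i) {
    const uint8_t bits = inverse[static_cast<uint8_t>(four[i])];
    if (bits > 63) return false;  // invalid character (or '=' padding)
    buffer = (buffer << 6) | bits;
  }
  out[0] = static_cast<uint8_t>(buffer >> 16);
  out[1] = static_cast<uint8_t>(buffer >> 8);
  out[2] = static_cast<uint8_t>(buffer);
  return true;
}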
+ bool end_of_stream() const { return begin_ == end_; } + // How many bytes until end of input + size_t remaining() const { return end_ - begin_; } + // Current position, as a pointer + const uint8_t* cur_ptr() const { return begin_; } + // End position, as a pointer + const uint8_t* end_ptr() const { return end_; } + // Move read position forward by n, unchecked + void Advance(size_t n) { begin_ += n; } + + // Retrieve the current character, or nullopt if end of stream + // Do not advance + absl::optional peek() const { + if (end_of_stream()) { + return {}; + } + return *begin_; + } + + // Retrieve and advance past the current character, or return nullopt if end + // of stream + absl::optional Next() { + if (end_of_stream()) { + return UnexpectedEOF(absl::optional()); + } + return *begin_++; + } + + // Helper to parse a varint delta on top of value, return nullopt on failure + // (setting error) + absl::optional ParseVarint(uint32_t value) { + // TODO(ctiller): break out a variant of this when we know there are at + // least 5 bytes in input_ + auto cur = Next(); + if (!cur) return {}; + value += *cur & 0x7f; + if ((*cur & 0x80) == 0) return value; + + cur = Next(); + if (!cur) return {}; + value += (*cur & 0x7f) << 7; + if ((*cur & 0x80) == 0) return value; + + cur = Next(); + if (!cur) return {}; + value += (*cur & 0x7f) << 14; + if ((*cur & 0x80) == 0) return value; + + cur = Next(); + if (!cur) return {}; + value += (*cur & 0x7f) << 21; + if ((*cur & 0x80) == 0) return value; + + cur = Next(); + if (!cur) return {}; + uint32_t c = (*cur) & 0x7f; + // We might overflow here, so we need to be a little careful about the + // addition + if (c > 0xf) return ParseVarintOutOfRange(value, *cur); + const uint32_t add = c << 28; + if (add > 0xffffffffu - value) { + return ParseVarintOutOfRange(value, *cur); + } + value += add; + if ((*cur & 0x80) == 0) return value; + + // Spec weirdness: we can add an infinite stream of 0x80 at the end of a + // varint and still end up with a correctly encoded varint. + do { + cur = Next(); + if (!cur.has_value()) return {}; + } while (*cur == 0x80); + + // BUT... the last byte needs to be 0x00 or we'll overflow dramatically! + if (*cur == 0) return value; + return ParseVarintOutOfRange(value, *cur); + } + + // Prefix for a string + struct StringPrefix { + // Number of bytes in input for string + uint32_t length; + // Is it huffman compressed + bool huff; + }; + + // Parse a string prefix + absl::optional ParseStringPrefix() { + auto cur = Next(); + if (!cur.has_value()) return {}; + // Huffman if the top bit is 1 + const bool huff = (*cur & 0x80) != 0; + // String length + uint32_t strlen = (*cur & 0x7f); + if (strlen == 0x7f) { + // all ones ==> varint string length + auto v = ParseVarint(0x7f); + if (!v.has_value()) return {}; + strlen = *v; + } + return StringPrefix{strlen, huff}; + } + + // Check if we saw an EOF.. must be verified before looking at TakeError + bool eof_error() const { return eof_error_; } + + // Extract the parse error, leaving the current error as NONE. + grpc_error_handle TakeError() { + grpc_error_handle out = error_; + error_ = GRPC_ERROR_NONE; + return out; + } + + // Set the current error - allows the rest of the code not to need to pass + // around StatusOr<> which would be prohibitive here. 
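// A standalone sketch (not gRPC code) of the RFC 7541 section 5.1 prefixed
// integer representation that VarintWriter<> and Input::ParseVarint above
// implement; ParseVarint receives the prefix bits already folded into its
// `value` argument by the caller. The bits of the first byte above the
// prefix carry the HPACK opcode and are masked off here. Overflow checks
// are omitted for brevity.
#include <cstddef>
#include <cstdint>
#include <vector>

// Encoding 1337 with a 5-bit prefix yields 0x1f 0x9a 0x0a
// (RFC 7541 appendix C.1.2).
std::vector<uint8_t> EncodePrefixedInt(uint32_t value, int prefix_bits) {
  const uint32_t max_prefix = (1u << prefix_bits) - 1;
  std::vector<uint8_t> out;
  if (value < max_prefix) {
    out.push_back(static_cast<uint8_t>(value));
    return out;
  }
  out.push_back(static_cast<uint8_t>(max_prefix));
  value -= max_prefix;
  while (value >= 0x80) {
    // high bit set: more continuation bytes follow
    out.push_back(static_cast<uint8_t>(0x80 | (value & 0x7f)));
    value >>= 7;
  }
  out.push_back(static_cast<uint8_t>(value));
  return out;
}

// Returns false on truncated input; advances *pos past the integer.
bool DecodePrefixedInt(const std::vector<uint8_t>& in, size_t* pos,
                       int prefix_bits, uint32_t* value) {
  const uint32_t max_prefix = (1u << prefix_bits) - 1;
  if (*pos >= in.size()) return false;
  uint32_t v = in[(*pos)++] & max_prefix;
  if (v < max_prefix) {
    *value = v;
    return true;
  }
  int shift = 0;
  while (true) {
    if (*pos >= in.size()) return false;
    const uint8_t b = in[(*pos)++];
    v += static_cast<uint32_t>(b & 0x7f) << shift;
    if ((b & 0x80) == 0) break;
    shift += 7;
  }
  *value = v;
  return true;
}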
+ GPR_ATTRIBUTE_NOINLINE void SetError(grpc_error_handle error) { + if (error_ != GRPC_ERROR_NONE || eof_error_) { + GRPC_ERROR_UNREF(error); + return; + } + error_ = error; + begin_ = end_; + } + + // If no error is set, set it to the value produced by error_factory. + // Return return_value unchanged. + template + GPR_ATTRIBUTE_NOINLINE T MaybeSetErrorAndReturn(F error_factory, + T return_value) { + if (error_ != GRPC_ERROR_NONE || eof_error_) return return_value; + error_ = error_factory(); + begin_ = end_; + return return_value; + } + + // Set the error to an unexpected eof, and return result (code golfed as this + // is a common case) + template + T UnexpectedEOF(T return_value) { + if (error_ != GRPC_ERROR_NONE) return return_value; + eof_error_ = true; + return return_value; + } + + // Update the frontier - signifies we've successfully parsed another element + void UpdateFrontier() { frontier_ = begin_; } + + // Get the frontier - for buffering should we fail due to eof + const uint8_t* frontier() const { return frontier_; } + + private: + // Helper to set the error to out of range for ParseVarint + absl::optional ParseVarintOutOfRange(uint32_t value, + uint8_t last_byte) { + return MaybeSetErrorAndReturn( + [value, last_byte] { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "integer overflow in hpack integer decoding: have 0x%08x, " + "got byte 0x%02x on byte 5", + value, last_byte)); + }, + absl::optional()); + } + + // Refcount if we are backed by a slice + grpc_slice_refcount* current_slice_refcount_; + // Current input point + const uint8_t* begin_; + // End of stream point + const uint8_t* const end_; + // Frontier denotes the first byte past successfully processed input + const uint8_t* frontier_; + // Current error + grpc_error_handle error_ = GRPC_ERROR_NONE; + // If the error was EOF, we flag it here.. + bool eof_error_ = false; +}; + +// Helper to parse a string and turn it into a slice with appropriate memory +// management characteristics +class HPackParser::String { + public: + // Helper to specify a string should be internalized + struct Intern {}; + // Helper to specify a string should be externalized + struct Extern {}; + + private: + // Forward declare take functions... 
we'll need them in the public interface + UnmanagedMemorySlice Take(Extern); + ManagedMemorySlice Take(Intern); + + public: + // If a String is a Slice then unref + ~String() { + if (auto* p = absl::get_if(&value_)) { + grpc_slice_unref_internal(*p); + } + } + + // Take the value and leave this empty + // Use Intern/Extern to choose memory management + template + auto Take() -> decltype(this->Take(T())) { + return Take(T()); + } + + String(const String&) = delete; + String& operator=(const String&) = delete; + String(String&& other) noexcept : value_(std::move(other.value_)) { + other.value_ = absl::Span(); + } + String& operator=(String&& other) noexcept { + value_ = std::move(other.value_); + other.value_ = absl::Span(); + return *this; + } + + // Parse a non-binary string + static absl::optional Parse(Input* input) { + auto pfx = input->ParseStringPrefix(); + if (!pfx.has_value()) return {}; + if (pfx->huff) { + // Huffman coded + std::vector output; + auto v = ParseHuff(input, pfx->length, + [&output](uint8_t c) { output.push_back(c); }); + if (!v) return {}; + return String(std::move(output)); + } + return ParseUncompressed(input, pfx->length); + } + + // Parse a binary string + static absl::optional ParseBinary(Input* input) { + auto pfx = input->ParseStringPrefix(); + if (!pfx.has_value()) return {}; + if (!pfx->huff) { + if (pfx->length > 0 && input->peek() == 0) { + // 'true-binary' + input->Advance(1); + return ParseUncompressed(input, pfx->length - 1); + } + // Base64 encoded... pull out the string, then unbase64 it + auto base64 = ParseUncompressed(input, pfx->length); + if (!base64.has_value()) return {}; + return Unbase64(input, std::move(*base64)); + } else { + // Huffman encoded... + std::vector decompressed; + // State here says either we don't know if it's base64 or binary, or we do + // and what is it. + enum class State { kUnsure, kBinary, kBase64 }; + State state = State::kUnsure; + auto decompressed_ok = + ParseHuff(input, pfx->length, [&state, &decompressed](uint8_t c) { + if (state == State::kUnsure) { + // First byte... if it's zero it's binary + if (c == 0) { + // Save the type, and skip the zero + state = State::kBinary; + return; + } else { + // Flag base64, store this value + state = State::kBase64; + } + } + // Non-first byte, or base64 first byte + decompressed.push_back(c); + }); + if (!decompressed_ok) return {}; + switch (state) { + case State::kUnsure: + // No bytes, empty span + return String(absl::Span()); + case State::kBinary: + // Binary, we're done + return String(std::move(decompressed)); + case State::kBase64: + // Base64 - unpack it + return Unbase64(input, String(std::move(decompressed))); + } + GPR_UNREACHABLE_CODE(abort();); + } + } + + private: + void AppendBytes(const uint8_t* data, size_t length); + explicit String(std::vector v) : value_(std::move(v)) {} + explicit String(absl::Span v) : value_(v) {} + String(grpc_slice_refcount* r, const uint8_t* begin, const uint8_t* end) + : value_(MakeSlice(r, begin, end)) {} + + // Given a refcount and a byte range, make a slice + static grpc_slice MakeSlice(grpc_slice_refcount* r, const uint8_t* begin, + const uint8_t* end) { + grpc_slice out; + out.refcount = r; + r->Ref(); + out.data.refcounted.bytes = const_cast(begin); + out.data.refcounted.length = end - begin; + return out; + } + + // Parse some huffman encoded bytes, using output(uint8_t b) to emit each + // decoded byte. 
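// A small illustration (not gRPC code; the type names are made up) of the
// decision ParseBinary above makes for "-bin" header values: if the first
// byte of the (possibly huffman-decompressed) value is 0x00, the remainder
// is the raw binary value ("true-binary"); otherwise the bytes are base64
// text that still has to be decoded.
#include <cstdint>
#include <vector>

enum class BinaryValueKind { kRawTrueBinary, kBase64Text };

struct ClassifiedValue {
  BinaryValueKind kind;
  std::vector<uint8_t> payload;  // raw value, or base64 text left to decode
};

ClassifiedValue ClassifyBinaryHeaderValue(const std::vector<uint8_t>& wire) {
  if (!wire.empty() && wire[0] == 0x00) {
    // 'true-binary': drop the sentinel zero byte, keep the rest verbatim.
    return {BinaryValueKind::kRawTrueBinary,
            std::vector<uint8_t>(wire.begin() + 1, wire.end())};
  }
  // Classic form: the value is base64 encoded.
  return {BinaryValueKind::kBase64Text, wire};
}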
+ template + static bool ParseHuff(Input* input, uint32_t length, Out output) { + GRPC_STATS_INC_HPACK_RECV_HUFFMAN(); + int16_t state = 0; + // Parse one half byte... we leverage some lookup tables to keep the logic + // here really simple. + auto nibble = [&output, &state](uint8_t nibble) { + int16_t emit = emit_sub_tbl[16 * emit_tbl[state] + nibble]; + int16_t next = next_sub_tbl[16 * next_tbl[state] + nibble]; + if (emit != -1) { + if (emit >= 0 && emit < 256) { + output(static_cast(emit)); + } else { + assert(emit == 256); + } + } + state = next; + }; + // If there's insufficient bytes remaining, return now. + if (input->remaining() < length) { + return input->UnexpectedEOF(false); + } + // Grab the byte range, and iterate through it. + const uint8_t* p = input->cur_ptr(); + input->Advance(length); + for (uint32_t i = 0; i < length; i++) { + nibble(p[i] >> 4); + nibble(p[i] & 0xf); + } + return true; + } + + // Parse some uncompressed string bytes. + static absl::optional ParseUncompressed(Input* input, + uint32_t length) { + GRPC_STATS_INC_HPACK_RECV_UNCOMPRESSED(); + // Check there's enough bytes + if (input->remaining() < length) { + return input->UnexpectedEOF(absl::optional()); + } + auto* refcount = input->slice_refcount(); + auto* p = input->cur_ptr(); + input->Advance(length); + if (refcount != nullptr) { + return String(refcount, p, p + length); + } else { + return String(absl::Span(p, length)); + } + } + + // Turn base64 encoded bytes into not base64 encoded bytes. + // Only takes input to set an error on failure. + static absl::optional Unbase64(Input* input, String s) { + auto v = Match( + s.value_, + [](const grpc_slice& slice) { + return Unbase64Loop(GRPC_SLICE_START_PTR(slice), + GRPC_SLICE_END_PTR(slice)); + }, + [](absl::Span span) { + return Unbase64Loop(span.begin(), span.end()); + }, + [](const std::vector& vec) { + return Unbase64Loop(vec.data(), vec.data() + vec.size()); + }); + if (!v.has_value()) { + return input->MaybeSetErrorAndReturn( + [] { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "illegal base64 encoding"); + }, + absl::optional()); + } + return String(std::move(*v)); + } + + // Main loop for Unbase64 + static absl::optional> Unbase64Loop(const uint8_t* cur, + const uint8_t* end) { + while (cur != end && end[-1] == '=') { + --end; + } + + std::vector out; + out.reserve(3 * (end - cur) / 4 + 3); + + // Decode 4 bytes at a time while we can + while (end - cur >= 4) { + uint32_t bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + uint32_t buffer = bits << 18; + ++cur; + + bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + buffer |= bits << 12; + ++cur; + + bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + buffer |= bits << 6; + ++cur; + + bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + buffer |= bits; + ++cur; + + out.insert(out.end(), {static_cast(buffer >> 16), + static_cast(buffer >> 8), + static_cast(buffer)}); + } + // Deal with the last 0, 1, 2, or 3 bytes. 
+ switch (end - cur) { + case 0: + return out; + case 1: + return {}; + case 2: { + uint32_t bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + uint32_t buffer = bits << 18; + + ++cur; + bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + buffer |= bits << 12; + + if (buffer & 0xffff) return {}; + out.push_back(static_cast(buffer >> 16)); + return out; + } + case 3: { + uint32_t bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + uint32_t buffer = bits << 18; + + ++cur; + bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + buffer |= bits << 12; + + ++cur; + bits = kBase64InverseTable.table[*cur]; + if (bits > 63) return {}; + buffer |= bits << 6; + + ++cur; + if (buffer & 0xff) return {}; + out.push_back(static_cast(buffer >> 16)); + out.push_back(static_cast(buffer >> 8)); + return out; + } + } + + GPR_UNREACHABLE_CODE(return out;); + } + + absl::variant, std::vector> + value_; +}; + +// Parser parses one key/value pair from a byte stream. +class HPackParser::Parser { + public: + Parser(Input* input, grpc_metadata_batch* metadata_buffer, + uint32_t metadata_size_limit, HPackTable* table, + uint8_t* dynamic_table_updates_allowed, uint32_t* frame_length, + LogInfo log_info) + : input_(input), + metadata_buffer_(metadata_buffer), + table_(table), + dynamic_table_updates_allowed_(dynamic_table_updates_allowed), + frame_length_(frame_length), + metadata_size_limit_(metadata_size_limit), + log_info_(log_info) {} + + // Skip any priority bits, or return false on failure + bool SkipPriority() { + if (input_->remaining() < 5) return input_->UnexpectedEOF(false); + input_->Advance(5); + return true; + } + + bool Parse() { + auto cur = *input_->Next(); + switch (cur >> 4) { + // Literal header not indexed - First byte format: 0000xxxx + // Literal header never indexed - First byte format: 0001xxxx + // Where xxxx: + // 0000 - literal key + // 1111 - indexed key, varint encoded index + // other - indexed key, inline encoded index + case 0: + case 1: + switch (cur & 0xf) { + case 0: // literal key + return FinishHeaderOmitFromTable(ParseLiteralKey()); + case 0xf: // varint encoded key index + return FinishHeaderOmitFromTable( + ParseVarIdxKey(0xf)); + default: // inline encoded key index + return FinishHeaderOmitFromTable( + ParseIdxKey(cur & 0xf)); + } + // Update max table size. + // First byte format: 001xxxxx + // Where xxxxx: + // 11111 - max size is varint encoded + // other - max size is stored inline + case 2: + // inline encoded max table size + return FinishMaxTableSize(cur & 0x1f); + case 3: + if (cur == 0x3f) { + // varint encoded max table size + return FinishMaxTableSize(input_->ParseVarint(0x1f)); + } else { + // inline encoded max table size + return FinishMaxTableSize(cur & 0x1f); + } + // Literal header with incremental indexing. 
+ // First byte format: 01xxxxxx + // Where xxxxxx: + // 000000 - literal key + // 111111 - indexed key, varint encoded index + // other - indexed key, inline encoded index + case 4: + if (cur == 0x40) { + // literal key + return FinishHeaderAndAddToTable(ParseLiteralKey()); + } + ABSL_FALLTHROUGH_INTENDED; + case 5: + case 6: + // inline encoded key index + return FinishHeaderAndAddToTable( + ParseIdxKey(cur & 0x3f)); + case 7: + if (cur == 0x7f) { + // varint encoded key index + return FinishHeaderAndAddToTable( + ParseVarIdxKey(0x3f)); + } else { + // inline encoded key index + return FinishHeaderAndAddToTable( + ParseIdxKey(cur & 0x3f)); + } + // Indexed Header Field Representation + // First byte format: 1xxxxxxx + // Where xxxxxxx: + // 0000000 - illegal + // 1111111 - varint encoded field index + // other - inline encoded field index + case 8: + if (cur == 0x80) { + // illegal value. + return input_->MaybeSetErrorAndReturn( + [] { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Illegal hpack op code"); + }, + false); + } + ABSL_FALLTHROUGH_INTENDED; + case 9: + case 10: + case 11: + case 12: + case 13: + case 14: + // inline encoded field index + return FinishIndexed(cur & 0x7f); + case 15: + if (cur == 0xff) { + // varint encoded field index + return FinishIndexed(input_->ParseVarint(0x7f)); + } else { + // inline encoded field index + return FinishIndexed(cur & 0x7f); + } + } + GPR_UNREACHABLE_CODE(abort()); + } + + private: + void GPR_ATTRIBUTE_NOINLINE LogHeader(const HPackTable::Memento& memento) { + const char* type; + switch (log_info_.type) { + case LogInfo::kHeaders: + type = "HDR"; + break; + case LogInfo::kTrailers: + type = "TRL"; + break; + case LogInfo::kDontKnow: + type = "???"; + break; + } + gpr_log(GPR_DEBUG, "HTTP:%d:%s:%s: %s", log_info_.stream_id, type, + log_info_.is_client ? "CLI" : "SVR", memento.DebugString().c_str()); + } + + bool EmitHeader(const HPackTable::Memento& md) { + // Pass up to the transport + if (GPR_UNLIKELY(metadata_buffer_ == nullptr)) return true; + *frame_length_ += md.transport_size(); + if (GPR_UNLIKELY(*frame_length_ > metadata_size_limit_)) { + return HandleMetadataSizeLimitExceeded(md); + } + + grpc_error_handle err = metadata_buffer_->Set(md); + if (GPR_UNLIKELY(err != GRPC_ERROR_NONE)) { + input_->SetError(err); + return false; + } + return true; + } + + bool FinishHeaderAndAddToTable(absl::optional md) { + // Allow higher code to just pass in failures ... simplifies things a bit. + if (!md.has_value()) return false; + // Log if desired + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_chttp2_hpack_parser)) { + LogHeader(*md); + } + // Emit whilst we own the metadata. + auto r = EmitHeader(*md); + // Add to the hpack table + grpc_error_handle err = table_->Add(std::move(*md)); + if (GPR_UNLIKELY(err != GRPC_ERROR_NONE)) { + input_->SetError(err); + return false; + }; + return r; + } + + bool FinishHeaderOmitFromTable(absl::optional md) { + // Allow higher code to just pass in failures ... simplifies things a bit. 
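// A plain restatement (not gRPC code) of the first-byte bit patterns that
// Parser::Parse() above switches on, per RFC 7541 section 6. The prefix
// width is the number of low bits of the first byte that begin the integer
// (index, length, or table size) which follows.
#include <cstdint>

enum class HpackOp {
  kIndexedField,         // 1xxxxxxx, 7-bit prefix: table index
  kLiteralIncremental,   // 01xxxxxx, 6-bit prefix: key index (0 = literal key)
  kTableSizeUpdate,      // 001xxxxx, 5-bit prefix: new maximum table size
  kLiteralNeverIndexed,  // 0001xxxx, 4-bit prefix: key index (0 = literal key)
  kLiteralNotIndexed,    // 0000xxxx, 4-bit prefix: key index (0 = literal key)
};

HpackOp ClassifyFirstByte(uint8_t b, int* prefix_bits) {
  if (b & 0x80) { *prefix_bits = 7; return HpackOp::kIndexedField; }
  if (b & 0x40) { *prefix_bits = 6; return HpackOp::kLiteralIncremental; }
  if (b & 0x20) { *prefix_bits = 5; return HpackOp::kTableSizeUpdate; }
  if (b & 0x10) { *prefix_bits = 4; return HpackOp::kLiteralNeverIndexed; }
  *prefix_bits = 4;
  return HpackOp::kLiteralNotIndexed;
}

// e.g. ClassifyFirstByte(0x82, &n) is kIndexedField with a 7-bit prefix
// holding index 2 (":method: GET" in the HPACK static table).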
+ if (!md.has_value()) return false; + return FinishHeaderOmitFromTable(*md); + } + + bool FinishHeaderOmitFromTable(const HPackTable::Memento& md) { + // Log if desired + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_chttp2_hpack_parser)) { + LogHeader(md); + } + return EmitHeader(md); + } + + // Parse a string encoded key and a string encoded value + template + absl::optional ParseLiteralKey() { + auto key = String::Parse(input_); + if (!key.has_value()) return {}; + auto key_slice = key->Take(); + auto value = + ParseValueString(grpc_is_refcounted_slice_binary_header(key_slice)); + if (GPR_UNLIKELY(!value.has_value())) { + grpc_slice_unref_internal(key_slice); + return {}; + } + return grpc_metadata_batch::Parse(key_slice, value->Take()); + } + + // Parse an index encoded key and a string encoded value + template + absl::optional ParseIdxKey(uint32_t index) { + const auto* elem = table_->Lookup(index); + if (GPR_UNLIKELY(elem == nullptr)) { + return InvalidHPackIndexError(index, + absl::optional()); + } + auto value = ParseValueString(elem->is_binary_header()); + if (GPR_UNLIKELY(!value.has_value())) return {}; + return elem->WithNewValue(value->Take()); + } + + // Parse a varint index encoded key and a string encoded value + template + absl::optional ParseVarIdxKey(uint32_t offset) { + auto index = input_->ParseVarint(offset); + if (GPR_UNLIKELY(!index.has_value())) return {}; + return ParseIdxKey(*index); + } + + // Parse a string, figuring out if it's binary or not by the key name. + absl::optional ParseValueString(bool is_binary) { + if (is_binary) { + return String::ParseBinary(input_); + } else { + return String::Parse(input_); + } + } + + // Emit an indexed field + bool FinishIndexed(absl::optional index) { + *dynamic_table_updates_allowed_ = 0; + if (!index.has_value()) return false; + const auto* elem = table_->Lookup(*index); + if (GPR_UNLIKELY(elem == nullptr)) { + return InvalidHPackIndexError(*index, false); + } + GRPC_STATS_INC_HPACK_RECV_INDEXED(); + return FinishHeaderOmitFromTable(*elem); + } + + // finish parsing a max table size change + bool FinishMaxTableSize(absl::optional size) { + if (!size.has_value()) return false; + if (*dynamic_table_updates_allowed_ == 0) { + return input_->MaybeSetErrorAndReturn( + [] { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "More than two max table size changes in a single frame"); + }, + false); + } + (*dynamic_table_updates_allowed_)--; + grpc_error_handle err = table_->SetCurrentTableSize(*size); + if (err != GRPC_ERROR_NONE) { + input_->SetError(err); + return false; + } + return true; + } + + // Set an invalid hpack index error if no error has been set. Returns result + // unmodified. + template + R InvalidHPackIndexError(uint32_t index, R result) { + return input_->MaybeSetErrorAndReturn( + [this, index] { + return grpc_error_set_int( + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid HPACK index received"), + GRPC_ERROR_INT_INDEX, + static_cast(index)), + GRPC_ERROR_INT_SIZE, + static_cast(this->table_->num_entries())); + }, + std::move(result)); + } + + GPR_ATTRIBUTE_NOINLINE + bool HandleMetadataSizeLimitExceeded(const HPackTable::Memento&) { + gpr_log(GPR_DEBUG, + "received initial metadata size exceeds limit (%" PRIu32 + " vs. %" PRIu32 + "). 
GRPC_ARG_MAX_METADATA_SIZE can be set to increase this limit.", + *frame_length_, metadata_size_limit_); + return input_->MaybeSetErrorAndReturn( + [] { + return grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "received initial metadata size exceeds limit"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_RESOURCE_EXHAUSTED); + }, + false); + } + + Input* const input_; + grpc_metadata_batch* const metadata_buffer_; + HPackTable* const table_; + uint8_t* const dynamic_table_updates_allowed_; + uint32_t* const frame_length_; + const uint32_t metadata_size_limit_; + const LogInfo log_info_; +}; + +UnmanagedMemorySlice HPackParser::String::Take(Extern) { + auto s = Match( + value_, + [](const grpc_slice& slice) { + // TODO(ctiller): Think about this before submission. + GPR_DEBUG_ASSERT(!grpc_slice_is_interned(slice)); + auto out_slice = grpc_slice_copy(slice); + grpc_slice_unref_internal(slice); + return static_cast(out_slice); + }, + [](absl::Span span) { + return UnmanagedMemorySlice( + reinterpret_cast(const_cast(span.begin())), + span.size()); + }, + [](const std::vector& v) { + return UnmanagedMemorySlice(reinterpret_cast(v.data()), + v.size()); + }); + value_ = absl::Span(); + return s; +} + +ManagedMemorySlice HPackParser::String::Take(Intern) { + auto s = Match( + value_, + [](const grpc_slice& slice) { + ManagedMemorySlice s(&slice); + grpc_slice_unref_internal(slice); + return s; + }, + [](absl::Span span) { + return ManagedMemorySlice( + reinterpret_cast(const_cast(span.data())), + span.size()); + }, + [](const std::vector& v) { + return ManagedMemorySlice(reinterpret_cast(v.data()), + v.size()); + }); + value_ = absl::Span(); + return s; +} + +/* PUBLIC INTERFACE */ + +HPackParser::HPackParser() = default; + +HPackParser::~HPackParser() = default; + +void HPackParser::BeginFrame(grpc_metadata_batch* metadata_buffer, + uint32_t metadata_size_limit, Boundary boundary, + Priority priority, LogInfo log_info) { + metadata_buffer_ = metadata_buffer; + boundary_ = boundary; + priority_ = priority; + dynamic_table_updates_allowed_ = 2; + frame_length_ = 0; + metadata_size_limit_ = metadata_size_limit; + log_info_ = log_info; +} + +grpc_error_handle HPackParser::Parse(const grpc_slice& slice, bool is_last) { + if (GPR_UNLIKELY(!unparsed_bytes_.empty())) { + std::vector buffer = std::move(unparsed_bytes_); + buffer.insert(buffer.end(), GRPC_SLICE_START_PTR(slice), + GRPC_SLICE_END_PTR(slice)); + return ParseInput( + Input(nullptr, buffer.data(), buffer.data() + buffer.size()), is_last); + } + return ParseInput(Input(slice.refcount, GRPC_SLICE_START_PTR(slice), + GRPC_SLICE_END_PTR(slice)), + is_last); +} + +grpc_error_handle HPackParser::ParseInput(Input input, bool is_last) { + if (ParseInputInner(&input)) { + return GRPC_ERROR_NONE; + } + if (input.eof_error()) { + if (GPR_UNLIKELY(is_last && is_boundary())) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Incomplete header at the end of a header/continuation sequence"); + } + unparsed_bytes_ = std::vector(input.frontier(), input.end_ptr()); + return GRPC_ERROR_NONE; + } + return input.TakeError(); +} + +bool HPackParser::ParseInputInner(Input* input) { + switch (priority_) { + case Priority::None: + break; + case Priority::Included: { + if (input->remaining() < 5) return input->UnexpectedEOF(false); + input->Advance(5); + input->UpdateFrontier(); + priority_ = Priority::None; + } + } + while (!input->end_of_stream()) { + if (GPR_UNLIKELY(!Parser(input, metadata_buffer_, metadata_size_limit_, + &table_, 
&dynamic_table_updates_allowed_, + &frame_length_, log_info_) + .Parse())) { + return false; + } + input->UpdateFrontier(); + } + return true; +} + +void HPackParser::FinishFrame() { metadata_buffer_ = nullptr; } + +} // namespace grpc_core + +// TODO(ctiller): this serves as an eviction notice for the remainder of this +// file... it belongs elsewhere! + +typedef void (*maybe_complete_func_type)(grpc_chttp2_transport* t, + grpc_chttp2_stream* s); +static const maybe_complete_func_type maybe_complete_funcs[] = { + grpc_chttp2_maybe_complete_recv_initial_metadata, + grpc_chttp2_maybe_complete_recv_trailing_metadata}; + +static void force_client_rst_stream(void* sp, grpc_error_handle /*error*/) { + grpc_chttp2_stream* s = static_cast(sp); + grpc_chttp2_transport* t = s->t; + if (!s->write_closed) { + grpc_chttp2_add_rst_stream_to_next_write(t, s->id, GRPC_HTTP2_NO_ERROR, + &s->stats.outgoing); + grpc_chttp2_initiate_write(t, GRPC_CHTTP2_INITIATE_WRITE_FORCE_RST_STREAM); + grpc_chttp2_mark_stream_closed(t, s, true, true, GRPC_ERROR_NONE); + } + GRPC_CHTTP2_STREAM_UNREF(s, "final_rst"); +} + +static void parse_stream_compression_md(grpc_chttp2_transport* /*t*/, + grpc_chttp2_stream* s, + grpc_metadata_batch* initial_metadata) { + if (initial_metadata->legacy_index()->named.content_encoding == nullptr || + grpc_stream_compression_method_parse( + GRPC_MDVALUE( + initial_metadata->legacy_index()->named.content_encoding->md), + false, &s->stream_decompression_method) == 0) { + s->stream_decompression_method = + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS; + } + + if (s->stream_decompression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS) { + s->stream_decompression_ctx = nullptr; + grpc_slice_buffer_init(&s->decompressed_data_buffer); + } +} + +grpc_error_handle grpc_chttp2_header_parser_parse(void* hpack_parser, + grpc_chttp2_transport* t, + grpc_chttp2_stream* s, + const grpc_slice& slice, + int is_last) { + GPR_TIMER_SCOPE("grpc_chttp2_header_parser_parse", 0); + auto* parser = static_cast(hpack_parser); + if (s != nullptr) { + s->stats.incoming.header_bytes += GRPC_SLICE_LENGTH(slice); + } + grpc_error_handle error = parser->Parse(slice, is_last != 0); + if (error != GRPC_ERROR_NONE) { + return error; + } + if (is_last) { + /* need to check for null stream: this can occur if we receive an invalid + stream id on a header */ + if (s != nullptr) { + if (parser->is_boundary()) { + if (s->header_frames_received == 2) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Too many trailer frames"); + } + /* Process stream compression md element if it exists */ + if (s->header_frames_received == + 0) { /* Only acts on initial metadata */ + parse_stream_compression_md(t, s, &s->initial_metadata_buffer); + } + s->published_metadata[s->header_frames_received] = + GRPC_METADATA_PUBLISHED_FROM_WIRE; + maybe_complete_funcs[s->header_frames_received](t, s); + s->header_frames_received++; + } + if (parser->is_eof()) { + if (t->is_client && !s->write_closed) { + /* server eof ==> complete closure; we may need to forcefully close + the stream. 
Wait until the combiner lock is ready to be released + however -- it might be that we receive a RST_STREAM following this + and can avoid the extra write */ + GRPC_CHTTP2_STREAM_REF(s, "final_rst"); + t->combiner->FinallyRun( + GRPC_CLOSURE_CREATE(force_client_rst_stream, s, nullptr), + GRPC_ERROR_NONE); + } + grpc_chttp2_mark_stream_closed(t, s, true, false, GRPC_ERROR_NONE); + } + } + parser->FinishFrame(); + } + return GRPC_ERROR_NONE; +} diff --git a/src/core/ext/transport/chttp2/transport/hpack_parser_table.cc b/src/core/ext/transport/chttp2/transport/hpack_parser_table.cc new file mode 100644 index 00000000..80a026e6 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/hpack_parser_table.cc @@ -0,0 +1,146 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_parser_table.h" + +#include +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/murmur_hash.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/validate_metadata.h" +#include "src/core/lib/transport/static_metadata.h" + +extern grpc_core::TraceFlag grpc_http_trace; + +namespace grpc_core { + +HPackTable::HPackTable() : static_metadata_(GetStaticMementos()) {} + +HPackTable::~HPackTable() = default; + +/* Evict one element from the table */ +void HPackTable::EvictOne() { + auto first_entry = std::move(entries_[first_entry_]); + GPR_ASSERT(first_entry.transport_size() <= mem_used_); + mem_used_ -= first_entry.transport_size(); + first_entry_ = ((first_entry_ + 1) % entries_.size()); + num_entries_--; +} + +void HPackTable::Rebuild(uint32_t new_cap) { + EntriesVec entries; + entries.resize(new_cap); + for (size_t i = 0; i < num_entries_; i++) { + entries[i] = std::move(entries_[(first_entry_ + i) % entries_.size()]); + } + first_entry_ = 0; + entries_.swap(entries); +} + +void HPackTable::SetMaxBytes(uint32_t max_bytes) { + if (max_bytes_ == max_bytes) { + return; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "Update hpack parser max size to %d", max_bytes); + } + while (mem_used_ > max_bytes) { + EvictOne(); + } + max_bytes_ = max_bytes; +} + +grpc_error_handle HPackTable::SetCurrentTableSize(uint32_t bytes) { + if (current_table_bytes_ == bytes) { + return GRPC_ERROR_NONE; + } + if (bytes > max_bytes_) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Attempt to make hpack table %d bytes when max is %d bytes", bytes, + max_bytes_)); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_INFO, "Update hpack parser table size to %d", bytes); + } + while (mem_used_ > bytes) { + EvictOne(); + } + current_table_bytes_ = bytes; + max_entries_ = hpack_constants::EntriesForBytes(bytes); + if (max_entries_ > entries_.size()) { + Rebuild(max_entries_); + } else if (max_entries_ < entries_.size() / 3) { + // TODO(ctiller): move to 
resource quota system, only shrink under memory + // pressure + uint32_t new_cap = + std::max(max_entries_, static_cast(kInlineEntries)); + if (new_cap != entries_.size()) { + Rebuild(new_cap); + } + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle HPackTable::Add(Memento md) { + if (current_table_bytes_ > max_bytes_) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "HPACK max table size reduced to %d but not reflected by hpack " + "stream (still at %d)", + max_bytes_, current_table_bytes_)); + } + + // we can't add elements bigger than the max table size + if (md.transport_size() > current_table_bytes_) { + // HPACK draft 10 section 4.4 states: + // If the size of the new entry is less than or equal to the maximum + // size, that entry is added to the table. It is not an error to + // attempt to add an entry that is larger than the maximum size; an + // attempt to add an entry larger than the entire table causes + // the table to be emptied of all existing entries, and results in an + // empty table. + while (num_entries_) { + EvictOne(); + } + return GRPC_ERROR_NONE; + } + + // evict entries to ensure no overflow + while (md.transport_size() > + static_cast(current_table_bytes_) - mem_used_) { + EvictOne(); + } + + // copy the finalized entry in + mem_used_ += md.transport_size(); + entries_[(first_entry_ + num_entries_) % entries_.size()] = std::move(md); + + // update accounting values + num_entries_++; + return GRPC_ERROR_NONE; +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/transport/hpack_utils.cc b/src/core/ext/transport/chttp2/transport/hpack_utils.cc new file mode 100644 index 00000000..fb1d2cb2 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/hpack_utils.cc @@ -0,0 +1,46 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_utils.h" + +#include "src/core/lib/surface/validate_metadata.h" + +namespace grpc_core { + +namespace { +size_t Base64EncodedSize(size_t raw_length) { + static constexpr uint8_t tail_xtra[3] = {0, 2, 3}; + return raw_length / 3 * 4 + tail_xtra[raw_length % 3]; +} +} // namespace + +// Return the size occupied by some metadata in the HPACK table. +size_t MetadataSizeInHPackTable(grpc_mdelem elem, + bool use_true_binary_metadata) { + const uint8_t* key_buf = GRPC_SLICE_START_PTR(GRPC_MDKEY(elem)); + size_t key_len = GRPC_SLICE_LENGTH(GRPC_MDKEY(elem)); + size_t overhead_and_key = 32 + key_len; + size_t value_len = GRPC_SLICE_LENGTH(GRPC_MDVALUE(elem)); + if (grpc_key_is_binary_header(key_buf, key_len)) { + return overhead_and_key + (use_true_binary_metadata + ? 
value_len + 1 + : Base64EncodedSize(value_len)); + } else { + return overhead_and_key + value_len; + } +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/transport/http2_settings.cc b/src/core/ext/transport/chttp2/transport/http2_settings.cc new file mode 100644 index 00000000..294ee8e4 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/http2_settings.cc @@ -0,0 +1,62 @@ +/* + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Automatically generated by tools/codegen/core/gen_settings_ids.py + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/http2_settings.h" + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/transport/http2_errors.h" + +const uint16_t grpc_setting_id_to_wire_id[] = {1, 2, 3, 4, 5, 6, 65027}; + +bool grpc_wire_id_to_setting_id(uint32_t wire_id, grpc_chttp2_setting_id* out) { + uint32_t i = wire_id - 1; + uint32_t x = i % 256; + uint32_t y = i / 256; + uint32_t h = x; + switch (y) { + case 254: + h += 4; + break; + } + *out = static_cast(h); + return h < GPR_ARRAY_SIZE(grpc_setting_id_to_wire_id) && + grpc_setting_id_to_wire_id[h] == wire_id; +} + +const grpc_chttp2_setting_parameters + grpc_chttp2_settings_parameters[GRPC_CHTTP2_NUM_SETTINGS] = { + {"HEADER_TABLE_SIZE", 4096u, 0u, 4294967295u, + GRPC_CHTTP2_CLAMP_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR}, + {"ENABLE_PUSH", 1u, 0u, 1u, GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE, + GRPC_HTTP2_PROTOCOL_ERROR}, + {"MAX_CONCURRENT_STREAMS", 4294967295u, 0u, 4294967295u, + GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR}, + {"INITIAL_WINDOW_SIZE", 65535u, 0u, 2147483647u, + GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE, + GRPC_HTTP2_FLOW_CONTROL_ERROR}, + {"MAX_FRAME_SIZE", 16384u, 16384u, 16777215u, + GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR}, + {"MAX_HEADER_LIST_SIZE", 16777216u, 0u, 16777216u, + GRPC_CHTTP2_CLAMP_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR}, + {"GRPC_ALLOW_TRUE_BINARY_METADATA", 0u, 0u, 1u, + GRPC_CHTTP2_CLAMP_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR}, +}; diff --git a/src/core/ext/transport/chttp2/transport/huffsyms.cc b/src/core/ext/transport/chttp2/transport/huffsyms.cc new file mode 100644 index 00000000..813e4c91 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/huffsyms.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/huffsyms.h" + +/* Constants pulled from the HPACK spec, and converted to C using the vim + command: + :%s/.* \([0-9a-f]\+\) \[ *\([0-9]\+\)\]/{0x\1, \2},/g */ +const grpc_chttp2_huffsym grpc_chttp2_huffsyms[GRPC_CHTTP2_NUM_HUFFSYMS] = { + {0x1ff8, 13}, {0x7fffd8, 23}, {0xfffffe2, 28}, {0xfffffe3, 28}, + {0xfffffe4, 28}, {0xfffffe5, 28}, {0xfffffe6, 28}, {0xfffffe7, 28}, + {0xfffffe8, 28}, {0xffffea, 24}, {0x3ffffffc, 30}, {0xfffffe9, 28}, + {0xfffffea, 28}, {0x3ffffffd, 30}, {0xfffffeb, 28}, {0xfffffec, 28}, + {0xfffffed, 28}, {0xfffffee, 28}, {0xfffffef, 28}, {0xffffff0, 28}, + {0xffffff1, 28}, {0xffffff2, 28}, {0x3ffffffe, 30}, {0xffffff3, 28}, + {0xffffff4, 28}, {0xffffff5, 28}, {0xffffff6, 28}, {0xffffff7, 28}, + {0xffffff8, 28}, {0xffffff9, 28}, {0xffffffa, 28}, {0xffffffb, 28}, + {0x14, 6}, {0x3f8, 10}, {0x3f9, 10}, {0xffa, 12}, + {0x1ff9, 13}, {0x15, 6}, {0xf8, 8}, {0x7fa, 11}, + {0x3fa, 10}, {0x3fb, 10}, {0xf9, 8}, {0x7fb, 11}, + {0xfa, 8}, {0x16, 6}, {0x17, 6}, {0x18, 6}, + {0x0, 5}, {0x1, 5}, {0x2, 5}, {0x19, 6}, + {0x1a, 6}, {0x1b, 6}, {0x1c, 6}, {0x1d, 6}, + {0x1e, 6}, {0x1f, 6}, {0x5c, 7}, {0xfb, 8}, + {0x7ffc, 15}, {0x20, 6}, {0xffb, 12}, {0x3fc, 10}, + {0x1ffa, 13}, {0x21, 6}, {0x5d, 7}, {0x5e, 7}, + {0x5f, 7}, {0x60, 7}, {0x61, 7}, {0x62, 7}, + {0x63, 7}, {0x64, 7}, {0x65, 7}, {0x66, 7}, + {0x67, 7}, {0x68, 7}, {0x69, 7}, {0x6a, 7}, + {0x6b, 7}, {0x6c, 7}, {0x6d, 7}, {0x6e, 7}, + {0x6f, 7}, {0x70, 7}, {0x71, 7}, {0x72, 7}, + {0xfc, 8}, {0x73, 7}, {0xfd, 8}, {0x1ffb, 13}, + {0x7fff0, 19}, {0x1ffc, 13}, {0x3ffc, 14}, {0x22, 6}, + {0x7ffd, 15}, {0x3, 5}, {0x23, 6}, {0x4, 5}, + {0x24, 6}, {0x5, 5}, {0x25, 6}, {0x26, 6}, + {0x27, 6}, {0x6, 5}, {0x74, 7}, {0x75, 7}, + {0x28, 6}, {0x29, 6}, {0x2a, 6}, {0x7, 5}, + {0x2b, 6}, {0x76, 7}, {0x2c, 6}, {0x8, 5}, + {0x9, 5}, {0x2d, 6}, {0x77, 7}, {0x78, 7}, + {0x79, 7}, {0x7a, 7}, {0x7b, 7}, {0x7ffe, 15}, + {0x7fc, 11}, {0x3ffd, 14}, {0x1ffd, 13}, {0xffffffc, 28}, + {0xfffe6, 20}, {0x3fffd2, 22}, {0xfffe7, 20}, {0xfffe8, 20}, + {0x3fffd3, 22}, {0x3fffd4, 22}, {0x3fffd5, 22}, {0x7fffd9, 23}, + {0x3fffd6, 22}, {0x7fffda, 23}, {0x7fffdb, 23}, {0x7fffdc, 23}, + {0x7fffdd, 23}, {0x7fffde, 23}, {0xffffeb, 24}, {0x7fffdf, 23}, + {0xffffec, 24}, {0xffffed, 24}, {0x3fffd7, 22}, {0x7fffe0, 23}, + {0xffffee, 24}, {0x7fffe1, 23}, {0x7fffe2, 23}, {0x7fffe3, 23}, + {0x7fffe4, 23}, {0x1fffdc, 21}, {0x3fffd8, 22}, {0x7fffe5, 23}, + {0x3fffd9, 22}, {0x7fffe6, 23}, {0x7fffe7, 23}, {0xffffef, 24}, + {0x3fffda, 22}, {0x1fffdd, 21}, {0xfffe9, 20}, {0x3fffdb, 22}, + {0x3fffdc, 22}, {0x7fffe8, 23}, {0x7fffe9, 23}, {0x1fffde, 21}, + {0x7fffea, 23}, {0x3fffdd, 22}, {0x3fffde, 22}, {0xfffff0, 24}, + {0x1fffdf, 21}, {0x3fffdf, 22}, {0x7fffeb, 23}, {0x7fffec, 23}, + {0x1fffe0, 21}, {0x1fffe1, 21}, {0x3fffe0, 22}, {0x1fffe2, 21}, + {0x7fffed, 23}, {0x3fffe1, 22}, {0x7fffee, 23}, {0x7fffef, 23}, + {0xfffea, 20}, {0x3fffe2, 22}, {0x3fffe3, 22}, {0x3fffe4, 22}, + {0x7ffff0, 23}, {0x3fffe5, 22}, {0x3fffe6, 22}, {0x7ffff1, 23}, + {0x3ffffe0, 26}, {0x3ffffe1, 26}, {0xfffeb, 20}, {0x7fff1, 19}, + {0x3fffe7, 22}, {0x7ffff2, 23}, {0x3fffe8, 22}, {0x1ffffec, 25}, + {0x3ffffe2, 26}, {0x3ffffe3, 26}, {0x3ffffe4, 26}, {0x7ffffde, 27}, + {0x7ffffdf, 27}, {0x3ffffe5, 26}, {0xfffff1, 24}, {0x1ffffed, 25}, + {0x7fff2, 19}, {0x1fffe3, 21}, {0x3ffffe6, 26}, {0x7ffffe0, 27}, + {0x7ffffe1, 27}, {0x3ffffe7, 26}, {0x7ffffe2, 27}, {0xfffff2, 24}, + {0x1fffe4, 21}, {0x1fffe5, 21}, {0x3ffffe8, 26}, {0x3ffffe9, 26}, + 
{0xffffffd, 28}, {0x7ffffe3, 27}, {0x7ffffe4, 27}, {0x7ffffe5, 27}, + {0xfffec, 20}, {0xfffff3, 24}, {0xfffed, 20}, {0x1fffe6, 21}, + {0x3fffe9, 22}, {0x1fffe7, 21}, {0x1fffe8, 21}, {0x7ffff3, 23}, + {0x3fffea, 22}, {0x3fffeb, 22}, {0x1ffffee, 25}, {0x1ffffef, 25}, + {0xfffff4, 24}, {0xfffff5, 24}, {0x3ffffea, 26}, {0x7ffff4, 23}, + {0x3ffffeb, 26}, {0x7ffffe6, 27}, {0x3ffffec, 26}, {0x3ffffed, 26}, + {0x7ffffe7, 27}, {0x7ffffe8, 27}, {0x7ffffe9, 27}, {0x7ffffea, 27}, + {0x7ffffeb, 27}, {0xffffffe, 28}, {0x7ffffec, 27}, {0x7ffffed, 27}, + {0x7ffffee, 27}, {0x7ffffef, 27}, {0x7fffff0, 27}, {0x3ffffee, 26}, + {0x3fffffff, 30}, +}; diff --git a/src/core/ext/transport/chttp2/transport/parsing.cc b/src/core/ext/transport/chttp2/transport/parsing.cc new file mode 100644 index 00000000..5f75fffb --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/parsing.cc @@ -0,0 +1,651 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/slice/slice_utils.h" +#include "src/core/lib/transport/http2_errors.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_conversion.h" +#include "src/core/lib/transport/timeout_encoding.h" + +using grpc_core::HPackParser; + +static grpc_error_handle init_frame_parser(grpc_chttp2_transport* t); +static grpc_error_handle init_header_frame_parser(grpc_chttp2_transport* t, + int is_continuation); +static grpc_error_handle init_data_frame_parser(grpc_chttp2_transport* t); +static grpc_error_handle init_rst_stream_parser(grpc_chttp2_transport* t); +static grpc_error_handle init_settings_frame_parser(grpc_chttp2_transport* t); +static grpc_error_handle init_window_update_frame_parser( + grpc_chttp2_transport* t); +static grpc_error_handle init_ping_parser(grpc_chttp2_transport* t); +static grpc_error_handle init_goaway_parser(grpc_chttp2_transport* t); +static grpc_error_handle init_non_header_skip_frame_parser( + grpc_chttp2_transport* t); + +static grpc_error_handle parse_frame_slice(grpc_chttp2_transport* t, + const grpc_slice& slice, + int is_last); + +grpc_error_handle grpc_chttp2_perform_read(grpc_chttp2_transport* t, + const grpc_slice& slice) { + const uint8_t* beg = GRPC_SLICE_START_PTR(slice); + const uint8_t* end = GRPC_SLICE_END_PTR(slice); + const uint8_t* cur = beg; + grpc_error_handle err; + + if (cur == end) return GRPC_ERROR_NONE; + + switch (t->deframe_state) { + case GRPC_DTS_CLIENT_PREFIX_0: + case GRPC_DTS_CLIENT_PREFIX_1: + case GRPC_DTS_CLIENT_PREFIX_2: + case GRPC_DTS_CLIENT_PREFIX_3: + case GRPC_DTS_CLIENT_PREFIX_4: + case GRPC_DTS_CLIENT_PREFIX_5: + case GRPC_DTS_CLIENT_PREFIX_6: + case GRPC_DTS_CLIENT_PREFIX_7: + case GRPC_DTS_CLIENT_PREFIX_8: + 
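+      // The GRPC_DTS_CLIENT_PREFIX_0..23 states all fall through to the loop
+      // below: together they consume, one byte per state, the 24-byte HTTP/2
+      // client connection preface ("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") that
+      // must precede the first frame, before frame-header deframing
+      // (GRPC_DTS_FH_0..8) begins.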
case GRPC_DTS_CLIENT_PREFIX_9: + case GRPC_DTS_CLIENT_PREFIX_10: + case GRPC_DTS_CLIENT_PREFIX_11: + case GRPC_DTS_CLIENT_PREFIX_12: + case GRPC_DTS_CLIENT_PREFIX_13: + case GRPC_DTS_CLIENT_PREFIX_14: + case GRPC_DTS_CLIENT_PREFIX_15: + case GRPC_DTS_CLIENT_PREFIX_16: + case GRPC_DTS_CLIENT_PREFIX_17: + case GRPC_DTS_CLIENT_PREFIX_18: + case GRPC_DTS_CLIENT_PREFIX_19: + case GRPC_DTS_CLIENT_PREFIX_20: + case GRPC_DTS_CLIENT_PREFIX_21: + case GRPC_DTS_CLIENT_PREFIX_22: + case GRPC_DTS_CLIENT_PREFIX_23: + while (cur != end && t->deframe_state != GRPC_DTS_FH_0) { + if (*cur != GRPC_CHTTP2_CLIENT_CONNECT_STRING[t->deframe_state]) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Connect string mismatch: expected '%c' (%d) got '%c' (%d) " + "at byte %d", + GRPC_CHTTP2_CLIENT_CONNECT_STRING[t->deframe_state], + static_cast(static_cast( + GRPC_CHTTP2_CLIENT_CONNECT_STRING[t->deframe_state])), + *cur, static_cast(*cur), t->deframe_state)); + } + ++cur; + // NOLINTNEXTLINE(bugprone-misplaced-widening-cast) + t->deframe_state = static_cast( + 1 + static_cast(t->deframe_state)); + } + if (cur == end) { + return GRPC_ERROR_NONE; + } + dts_fh_0: + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_0: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_frame_size = (static_cast(*cur)) << 16; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_1; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_1: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_frame_size |= (static_cast(*cur)) << 8; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_2; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_2: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_frame_size |= *cur; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_3; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_3: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_frame_type = *cur; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_4; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_4: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_frame_flags = *cur; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_5; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_5: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_stream_id = ((static_cast(*cur)) & 0x7f) << 24; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_6; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_6: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_stream_id |= (static_cast(*cur)) << 16; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_7; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_7: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_stream_id |= (static_cast(*cur)) << 8; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_8; + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FH_8: + GPR_DEBUG_ASSERT(cur < end); + t->incoming_stream_id |= (static_cast(*cur)); + t->deframe_state = GRPC_DTS_FRAME; + err = init_frame_parser(t); + if (err != GRPC_ERROR_NONE) { + return err; + } + if (t->incoming_frame_size == 0) { + err = parse_frame_slice(t, grpc_empty_slice(), 1); + if (err != GRPC_ERROR_NONE) { + return err; + } + t->incoming_stream = nullptr; + if (++cur == end) { + t->deframe_state = GRPC_DTS_FH_0; + return GRPC_ERROR_NONE; + } + goto dts_fh_0; /* loop */ + } else if (t->flow_control->flow_control_enabled() && + t->incoming_frame_size > + 
t->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE]) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Frame size %d is larger than max frame size %d", + t->incoming_frame_size, + t->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE])); + } + if (++cur == end) { + return GRPC_ERROR_NONE; + } + ABSL_FALLTHROUGH_INTENDED; + case GRPC_DTS_FRAME: + GPR_DEBUG_ASSERT(cur < end); + if (static_cast(end - cur) == t->incoming_frame_size) { + err = parse_frame_slice( + t, + grpc_slice_sub_no_ref(slice, static_cast(cur - beg), + static_cast(end - beg)), + 1); + if (err != GRPC_ERROR_NONE) { + return err; + } + t->deframe_state = GRPC_DTS_FH_0; + t->incoming_stream = nullptr; + return GRPC_ERROR_NONE; + } else if (static_cast(end - cur) > t->incoming_frame_size) { + size_t cur_offset = static_cast(cur - beg); + err = parse_frame_slice( + t, + grpc_slice_sub_no_ref(slice, cur_offset, + cur_offset + t->incoming_frame_size), + 1); + if (err != GRPC_ERROR_NONE) { + return err; + } + cur += t->incoming_frame_size; + t->incoming_stream = nullptr; + goto dts_fh_0; /* loop */ + } else { + err = parse_frame_slice( + t, + grpc_slice_sub_no_ref(slice, static_cast(cur - beg), + static_cast(end - beg)), + 0); + if (err != GRPC_ERROR_NONE) { + return err; + } + t->incoming_frame_size -= static_cast(end - cur); + return GRPC_ERROR_NONE; + } + GPR_UNREACHABLE_CODE(return GRPC_ERROR_NONE); + } + + GPR_UNREACHABLE_CODE(return GRPC_ERROR_NONE); +} + +static grpc_error_handle init_frame_parser(grpc_chttp2_transport* t) { + if (t->is_first_frame && + t->incoming_frame_type != GRPC_CHTTP2_FRAME_SETTINGS) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Expected SETTINGS frame as the first frame, got frame type ", + t->incoming_frame_type)); + } + t->is_first_frame = false; + if (t->expect_continuation_stream_id != 0) { + if (t->incoming_frame_type != GRPC_CHTTP2_FRAME_CONTINUATION) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Expected CONTINUATION frame, got frame type %02x", + t->incoming_frame_type)); + } + if (t->expect_continuation_stream_id != t->incoming_stream_id) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Expected CONTINUATION frame for grpc_chttp2_stream %08x, got " + "grpc_chttp2_stream %08x", + t->expect_continuation_stream_id, t->incoming_stream_id)); + } + return init_header_frame_parser(t, 1); + } + switch (t->incoming_frame_type) { + case GRPC_CHTTP2_FRAME_DATA: + return init_data_frame_parser(t); + case GRPC_CHTTP2_FRAME_HEADER: + return init_header_frame_parser(t, 0); + case GRPC_CHTTP2_FRAME_CONTINUATION: + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unexpected CONTINUATION frame"); + case GRPC_CHTTP2_FRAME_RST_STREAM: + return init_rst_stream_parser(t); + case GRPC_CHTTP2_FRAME_SETTINGS: + return init_settings_frame_parser(t); + case GRPC_CHTTP2_FRAME_WINDOW_UPDATE: + return init_window_update_frame_parser(t); + case GRPC_CHTTP2_FRAME_PING: + return init_ping_parser(t); + case GRPC_CHTTP2_FRAME_GOAWAY: + return init_goaway_parser(t); + default: + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_ERROR, "Unknown frame type %02x", t->incoming_frame_type); + } + return init_non_header_skip_frame_parser(t); + } +} + +static grpc_error_handle skip_parser(void* /*parser*/, + grpc_chttp2_transport* /*t*/, + grpc_chttp2_stream* /*s*/, + const grpc_slice& /*slice*/, + int /*is_last*/) { + return GRPC_ERROR_NONE; +} + +grpc_error_handle skip_header(grpc_mdelem md) { + GRPC_MDELEM_UNREF(md); 
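+  // With the element released, report success: headers seen while a frame is
+  // being skipped are dropped rather than recorded against any stream.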
+ return GRPC_ERROR_NONE; +} + +static HPackParser::Boundary hpack_boundary_type(grpc_chttp2_transport* t, + bool is_eoh) { + if (is_eoh) { + if (t->header_eof) { + return HPackParser::Boundary::EndOfStream; + } else { + return HPackParser::Boundary::EndOfHeaders; + } + } else { + return HPackParser::Boundary::None; + } +} + +static HPackParser::LogInfo hpack_parser_log_info( + grpc_chttp2_transport* t, HPackParser::LogInfo::Type type) { + return HPackParser::LogInfo{ + t->incoming_stream_id, + type, + t->is_client, + }; +} + +static grpc_error_handle init_header_skip_frame_parser( + grpc_chttp2_transport* t, HPackParser::Priority priority_type) { + bool is_eoh = t->expect_continuation_stream_id != 0; + t->parser = grpc_chttp2_header_parser_parse; + t->parser_data = &t->hpack_parser; + t->hpack_parser.BeginFrame( + nullptr, + t->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_HEADER_LIST_SIZE], + hpack_boundary_type(t, is_eoh), priority_type, + hpack_parser_log_info(t, HPackParser::LogInfo::kDontKnow)); + return GRPC_ERROR_NONE; +} + +static grpc_error_handle init_non_header_skip_frame_parser( + grpc_chttp2_transport* t) { + t->parser = skip_parser; + return GRPC_ERROR_NONE; +} + +void grpc_chttp2_parsing_become_skip_parser(grpc_chttp2_transport* t) { + if (t->parser == grpc_chttp2_header_parser_parse) { + t->hpack_parser.StopBufferingFrame(); + } else { + t->parser = skip_parser; + } +} + +static grpc_error_handle init_data_frame_parser(grpc_chttp2_transport* t) { + // Update BDP accounting since we have received a data frame. + grpc_core::BdpEstimator* bdp_est = t->flow_control->bdp_estimator(); + if (bdp_est) { + if (t->bdp_ping_blocked) { + t->bdp_ping_blocked = false; + GRPC_CHTTP2_REF_TRANSPORT(t, "bdp_ping"); + schedule_bdp_ping_locked(t); + } + bdp_est->AddIncomingBytes(t->incoming_frame_size); + } + grpc_chttp2_stream* s = + grpc_chttp2_parsing_lookup_stream(t, t->incoming_stream_id); + grpc_error_handle err = GRPC_ERROR_NONE; + grpc_core::chttp2::FlowControlAction action; + if (s == nullptr) { + err = t->flow_control->RecvData(t->incoming_frame_size); + action = t->flow_control->MakeAction(); + } else { + err = s->flow_control->RecvData(t->incoming_frame_size); + action = s->flow_control->MakeAction(); + } + grpc_chttp2_act_on_flowctl_action(action, t, s); + if (err != GRPC_ERROR_NONE) { + goto error_handler; + } + if (s == nullptr) { + return init_non_header_skip_frame_parser(t); + } + s->received_bytes += t->incoming_frame_size; + s->stats.incoming.framing_bytes += 9; + if (err == GRPC_ERROR_NONE && s->read_closed) { + return init_non_header_skip_frame_parser(t); + } + if (err == GRPC_ERROR_NONE) { + err = grpc_chttp2_data_parser_begin_frame( + &s->data_parser, t->incoming_frame_flags, s->id, s); + } +error_handler: + intptr_t unused; + if (err == GRPC_ERROR_NONE) { + t->incoming_stream = s; + /* t->parser = grpc_chttp2_data_parser_parse;*/ + t->parser = grpc_chttp2_data_parser_parse; + t->parser_data = &s->data_parser; + t->ping_state.last_ping_sent_time = GRPC_MILLIS_INF_PAST; + return GRPC_ERROR_NONE; + } else if (grpc_error_get_int(err, GRPC_ERROR_INT_STREAM_ID, &unused)) { + /* handle stream errors by closing the stream */ + if (s != nullptr) { + grpc_chttp2_mark_stream_closed(t, s, true, false, err); + } + grpc_chttp2_add_rst_stream_to_next_write(t, t->incoming_stream_id, + GRPC_HTTP2_PROTOCOL_ERROR, + &s->stats.outgoing); + return init_non_header_skip_frame_parser(t); + } else { + return err; + } +} + +static grpc_error_handle 
init_header_frame_parser(grpc_chttp2_transport* t, + int is_continuation) { + const bool is_eoh = + (t->incoming_frame_flags & GRPC_CHTTP2_DATA_FLAG_END_HEADERS) != 0; + grpc_chttp2_stream* s; + + /* TODO(ctiller): when to increment header_frames_received? */ + + if (is_eoh) { + t->expect_continuation_stream_id = 0; + } else { + t->expect_continuation_stream_id = t->incoming_stream_id; + } + + if (!is_continuation) { + t->header_eof = + (t->incoming_frame_flags & GRPC_CHTTP2_DATA_FLAG_END_STREAM) != 0; + } + + const auto priority_type = !is_continuation && (t->incoming_frame_flags & + GRPC_CHTTP2_FLAG_HAS_PRIORITY) + ? HPackParser::Priority::Included + : HPackParser::Priority::None; + + t->ping_state.last_ping_sent_time = GRPC_MILLIS_INF_PAST; + + /* could be a new grpc_chttp2_stream or an existing grpc_chttp2_stream */ + s = grpc_chttp2_parsing_lookup_stream(t, t->incoming_stream_id); + if (s == nullptr) { + if (GPR_UNLIKELY(is_continuation)) { + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_ERROR, + "grpc_chttp2_stream disbanded before CONTINUATION received")); + return init_header_skip_frame_parser(t, priority_type); + } + if (t->is_client) { + if (GPR_LIKELY((t->incoming_stream_id & 1) && + t->incoming_stream_id < t->next_stream_id)) { + /* this is an old (probably cancelled) grpc_chttp2_stream */ + } else { + GRPC_CHTTP2_IF_TRACING(gpr_log( + GPR_ERROR, "ignoring new grpc_chttp2_stream creation on client")); + } + return init_header_skip_frame_parser(t, priority_type); + } else if (GPR_UNLIKELY(t->last_new_stream_id >= t->incoming_stream_id)) { + GRPC_CHTTP2_IF_TRACING(gpr_log( + GPR_ERROR, + "ignoring out of order new grpc_chttp2_stream request on server; " + "last grpc_chttp2_stream " + "id=%d, new grpc_chttp2_stream id=%d", + t->last_new_stream_id, t->incoming_stream_id)); + return init_header_skip_frame_parser(t, priority_type); + } else if (GPR_UNLIKELY((t->incoming_stream_id & 1) == 0)) { + GRPC_CHTTP2_IF_TRACING(gpr_log( + GPR_ERROR, + "ignoring grpc_chttp2_stream with non-client generated index %d", + t->incoming_stream_id)); + return init_header_skip_frame_parser(t, priority_type); + } else if (GPR_UNLIKELY( + grpc_chttp2_stream_map_size(&t->stream_map) >= + t->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS])) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Max stream count exceeded"); + } + t->last_new_stream_id = t->incoming_stream_id; + s = t->incoming_stream = + grpc_chttp2_parsing_accept_stream(t, t->incoming_stream_id); + if (GPR_UNLIKELY(s == nullptr)) { + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_ERROR, "grpc_chttp2_stream not accepted")); + return init_header_skip_frame_parser(t, priority_type); + } + if (t->channelz_socket != nullptr) { + t->channelz_socket->RecordStreamStartedFromRemote(); + } + } else { + t->incoming_stream = s; + } + GPR_DEBUG_ASSERT(s != nullptr); + s->stats.incoming.framing_bytes += 9; + if (GPR_UNLIKELY(s->read_closed)) { + GRPC_CHTTP2_IF_TRACING(gpr_log( + GPR_ERROR, "skipping already closed grpc_chttp2_stream header")); + t->incoming_stream = nullptr; + return init_header_skip_frame_parser(t, priority_type); + } + t->parser = grpc_chttp2_header_parser_parse; + t->parser_data = &t->hpack_parser; + if (t->header_eof) { + s->eos_received = true; + } + grpc_metadata_batch* incoming_metadata_buffer = nullptr; + HPackParser::LogInfo::Type frame_type = HPackParser::LogInfo::kDontKnow; + switch (s->header_frames_received) { + case 0: + if (t->is_client && t->header_eof) { + GRPC_CHTTP2_IF_TRACING(gpr_log(GPR_INFO, "parsing 
Trailers-Only")); + if (s->trailing_metadata_available != nullptr) { + *s->trailing_metadata_available = true; + } + incoming_metadata_buffer = &s->trailing_metadata_buffer; + frame_type = HPackParser::LogInfo::kTrailers; + } else { + GRPC_CHTTP2_IF_TRACING(gpr_log(GPR_INFO, "parsing initial_metadata")); + incoming_metadata_buffer = &s->initial_metadata_buffer; + frame_type = HPackParser::LogInfo::kHeaders; + } + break; + case 1: + GRPC_CHTTP2_IF_TRACING(gpr_log(GPR_INFO, "parsing trailing_metadata")); + incoming_metadata_buffer = &s->trailing_metadata_buffer; + frame_type = HPackParser::LogInfo::kTrailers; + break; + case 2: + gpr_log(GPR_ERROR, "too many header frames received"); + return init_header_skip_frame_parser(t, priority_type); + } + t->hpack_parser.BeginFrame( + incoming_metadata_buffer, + t->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_HEADER_LIST_SIZE], + hpack_boundary_type(t, is_eoh), priority_type, + hpack_parser_log_info(t, frame_type)); + return GRPC_ERROR_NONE; +} + +static grpc_error_handle init_window_update_frame_parser( + grpc_chttp2_transport* t) { + grpc_error_handle err = grpc_chttp2_window_update_parser_begin_frame( + &t->simple.window_update, t->incoming_frame_size, + t->incoming_frame_flags); + if (err != GRPC_ERROR_NONE) return err; + if (t->incoming_stream_id != 0) { + grpc_chttp2_stream* s = t->incoming_stream = + grpc_chttp2_parsing_lookup_stream(t, t->incoming_stream_id); + if (s == nullptr) { + return init_non_header_skip_frame_parser(t); + } + s->stats.incoming.framing_bytes += 9; + } + t->parser = grpc_chttp2_window_update_parser_parse; + t->parser_data = &t->simple.window_update; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle init_ping_parser(grpc_chttp2_transport* t) { + grpc_error_handle err = grpc_chttp2_ping_parser_begin_frame( + &t->simple.ping, t->incoming_frame_size, t->incoming_frame_flags); + if (err != GRPC_ERROR_NONE) return err; + t->parser = grpc_chttp2_ping_parser_parse; + t->parser_data = &t->simple.ping; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle init_rst_stream_parser(grpc_chttp2_transport* t) { + grpc_error_handle err = grpc_chttp2_rst_stream_parser_begin_frame( + &t->simple.rst_stream, t->incoming_frame_size, t->incoming_frame_flags); + if (err != GRPC_ERROR_NONE) return err; + grpc_chttp2_stream* s = t->incoming_stream = + grpc_chttp2_parsing_lookup_stream(t, t->incoming_stream_id); + if (!t->incoming_stream) { + return init_non_header_skip_frame_parser(t); + } + s->stats.incoming.framing_bytes += 9; + t->parser = grpc_chttp2_rst_stream_parser_parse; + t->parser_data = &t->simple.rst_stream; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle init_goaway_parser(grpc_chttp2_transport* t) { + grpc_error_handle err = grpc_chttp2_goaway_parser_begin_frame( + &t->goaway_parser, t->incoming_frame_size, t->incoming_frame_flags); + if (err != GRPC_ERROR_NONE) return err; + t->parser = grpc_chttp2_goaway_parser_parse; + t->parser_data = &t->goaway_parser; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle init_settings_frame_parser(grpc_chttp2_transport* t) { + if (t->incoming_stream_id != 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Settings frame received for grpc_chttp2_stream"); + } + + grpc_error_handle err = grpc_chttp2_settings_parser_begin_frame( + &t->simple.settings, t->incoming_frame_size, t->incoming_frame_flags, + t->settings[GRPC_PEER_SETTINGS]); + if (err != GRPC_ERROR_NONE) { + return err; + } + if (t->incoming_frame_flags & GRPC_CHTTP2_FLAG_ACK) { + 
memcpy(t->settings[GRPC_ACKED_SETTINGS], t->settings[GRPC_SENT_SETTINGS], + GRPC_CHTTP2_NUM_SETTINGS * sizeof(uint32_t)); + t->hpack_parser.hpack_table()->SetMaxBytes( + t->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_HEADER_TABLE_SIZE]); + t->sent_local_settings = false; + } + t->parser = grpc_chttp2_settings_parser_parse; + t->parser_data = &t->simple.settings; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle parse_frame_slice(grpc_chttp2_transport* t, + const grpc_slice& slice, + int is_last) { + grpc_chttp2_stream* s = t->incoming_stream; + grpc_error_handle err = t->parser(t->parser_data, t, s, slice, is_last); + intptr_t unused; + if (GPR_LIKELY(err == GRPC_ERROR_NONE)) { + return err; + } else if (grpc_error_get_int(err, GRPC_ERROR_INT_STREAM_ID, &unused)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace)) { + gpr_log(GPR_ERROR, "%s", grpc_error_std_string(err).c_str()); + } + grpc_chttp2_parsing_become_skip_parser(t); + if (s) { + s->forced_close_error = err; + grpc_chttp2_add_rst_stream_to_next_write(t, t->incoming_stream_id, + GRPC_HTTP2_PROTOCOL_ERROR, + &s->stats.outgoing); + } else { + GRPC_ERROR_UNREF(err); + } + } + return err; +} diff --git a/src/core/ext/transport/chttp2/transport/stream_lists.cc b/src/core/ext/transport/chttp2/transport/stream_lists.cc new file mode 100644 index 00000000..9c3c3e69 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/stream_lists.cc @@ -0,0 +1,216 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" + +static const char* stream_list_id_string(grpc_chttp2_stream_list_id id) { + switch (id) { + case GRPC_CHTTP2_LIST_WRITABLE: + return "writable"; + case GRPC_CHTTP2_LIST_WRITING: + return "writing"; + case GRPC_CHTTP2_LIST_STALLED_BY_TRANSPORT: + return "stalled_by_transport"; + case GRPC_CHTTP2_LIST_STALLED_BY_STREAM: + return "stalled_by_stream"; + case GRPC_CHTTP2_LIST_WAITING_FOR_CONCURRENCY: + return "waiting_for_concurrency"; + case STREAM_LIST_COUNT: + GPR_UNREACHABLE_CODE(return "unknown"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +grpc_core::TraceFlag grpc_trace_http2_stream_state(false, "http2_stream_state"); + +/* core list management */ + +static bool stream_list_empty(grpc_chttp2_transport* t, + grpc_chttp2_stream_list_id id) { + return t->lists[id].head == nullptr; +} + +static bool stream_list_pop(grpc_chttp2_transport* t, + grpc_chttp2_stream** stream, + grpc_chttp2_stream_list_id id) { + grpc_chttp2_stream* s = t->lists[id].head; + if (s) { + grpc_chttp2_stream* new_head = s->links[id].next; + GPR_ASSERT(s->included[id]); + if (new_head) { + t->lists[id].head = new_head; + new_head->links[id].prev = nullptr; + } else { + t->lists[id].head = nullptr; + t->lists[id].tail = nullptr; + } + s->included[id] = 0; + } + *stream = s; + if (s && GRPC_TRACE_FLAG_ENABLED(grpc_trace_http2_stream_state)) { + gpr_log(GPR_INFO, "%p[%d][%s]: pop from %s", t, s->id, + t->is_client ? "cli" : "svr", stream_list_id_string(id)); + } + return s != nullptr; +} + +static void stream_list_remove(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_chttp2_stream_list_id id) { + GPR_ASSERT(s->included[id]); + s->included[id] = 0; + if (s->links[id].prev) { + s->links[id].prev->links[id].next = s->links[id].next; + } else { + GPR_ASSERT(t->lists[id].head == s); + t->lists[id].head = s->links[id].next; + } + if (s->links[id].next) { + s->links[id].next->links[id].prev = s->links[id].prev; + } else { + t->lists[id].tail = s->links[id].prev; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_http2_stream_state)) { + gpr_log(GPR_INFO, "%p[%d][%s]: remove from %s", t, s->id, + t->is_client ? "cli" : "svr", stream_list_id_string(id)); + } +} + +static bool stream_list_maybe_remove(grpc_chttp2_transport* t, + grpc_chttp2_stream* s, + grpc_chttp2_stream_list_id id) { + if (s->included[id]) { + stream_list_remove(t, s, id); + return true; + } else { + return false; + } +} + +static void stream_list_add_tail(grpc_chttp2_transport* t, + grpc_chttp2_stream* s, + grpc_chttp2_stream_list_id id) { + grpc_chttp2_stream* old_tail; + GPR_ASSERT(!s->included[id]); + old_tail = t->lists[id].tail; + s->links[id].next = nullptr; + s->links[id].prev = old_tail; + if (old_tail) { + old_tail->links[id].next = s; + } else { + t->lists[id].head = s; + } + t->lists[id].tail = s; + s->included[id] = 1; + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_http2_stream_state)) { + gpr_log(GPR_INFO, "%p[%d][%s]: add to %s", t, s->id, + t->is_client ? 
"cli" : "svr", stream_list_id_string(id)); + } +} + +static bool stream_list_add(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_chttp2_stream_list_id id) { + if (s->included[id]) { + return false; + } + stream_list_add_tail(t, s, id); + return true; +} + +/* wrappers for specializations */ + +bool grpc_chttp2_list_add_writable_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + GPR_ASSERT(s->id != 0); + return stream_list_add(t, s, GRPC_CHTTP2_LIST_WRITABLE); +} + +bool grpc_chttp2_list_pop_writable_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream** s) { + return stream_list_pop(t, s, GRPC_CHTTP2_LIST_WRITABLE); +} + +bool grpc_chttp2_list_remove_writable_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + return stream_list_maybe_remove(t, s, GRPC_CHTTP2_LIST_WRITABLE); +} + +bool grpc_chttp2_list_add_writing_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + return stream_list_add(t, s, GRPC_CHTTP2_LIST_WRITING); +} + +bool grpc_chttp2_list_have_writing_streams(grpc_chttp2_transport* t) { + return !stream_list_empty(t, GRPC_CHTTP2_LIST_WRITING); +} + +bool grpc_chttp2_list_pop_writing_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream** s) { + return stream_list_pop(t, s, GRPC_CHTTP2_LIST_WRITING); +} + +void grpc_chttp2_list_add_waiting_for_concurrency(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + stream_list_add(t, s, GRPC_CHTTP2_LIST_WAITING_FOR_CONCURRENCY); +} + +bool grpc_chttp2_list_pop_waiting_for_concurrency(grpc_chttp2_transport* t, + grpc_chttp2_stream** s) { + return stream_list_pop(t, s, GRPC_CHTTP2_LIST_WAITING_FOR_CONCURRENCY); +} + +void grpc_chttp2_list_remove_waiting_for_concurrency(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + stream_list_maybe_remove(t, s, GRPC_CHTTP2_LIST_WAITING_FOR_CONCURRENCY); +} + +void grpc_chttp2_list_add_stalled_by_transport(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + GPR_ASSERT(t->flow_control->flow_control_enabled()); + stream_list_add(t, s, GRPC_CHTTP2_LIST_STALLED_BY_TRANSPORT); +} + +bool grpc_chttp2_list_pop_stalled_by_transport(grpc_chttp2_transport* t, + grpc_chttp2_stream** s) { + return stream_list_pop(t, s, GRPC_CHTTP2_LIST_STALLED_BY_TRANSPORT); +} + +void grpc_chttp2_list_remove_stalled_by_transport(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + stream_list_maybe_remove(t, s, GRPC_CHTTP2_LIST_STALLED_BY_TRANSPORT); +} + +void grpc_chttp2_list_add_stalled_by_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + GPR_ASSERT(t->flow_control->flow_control_enabled()); + stream_list_add(t, s, GRPC_CHTTP2_LIST_STALLED_BY_STREAM); +} + +bool grpc_chttp2_list_pop_stalled_by_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream** s) { + return stream_list_pop(t, s, GRPC_CHTTP2_LIST_STALLED_BY_STREAM); +} + +bool grpc_chttp2_list_remove_stalled_by_stream(grpc_chttp2_transport* t, + grpc_chttp2_stream* s) { + return stream_list_maybe_remove(t, s, GRPC_CHTTP2_LIST_STALLED_BY_STREAM); +} diff --git a/src/core/ext/transport/chttp2/transport/stream_map.cc b/src/core/ext/transport/chttp2/transport/stream_map.cc new file mode 100644 index 00000000..647214b9 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/stream_map.cc @@ -0,0 +1,177 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/stream_map.h" + +#include + +#include +#include + +void grpc_chttp2_stream_map_init(grpc_chttp2_stream_map* map, + size_t initial_capacity) { + GPR_DEBUG_ASSERT(initial_capacity > 1); + map->keys = + static_cast(gpr_malloc(sizeof(uint32_t) * initial_capacity)); + map->values = + static_cast(gpr_malloc(sizeof(void*) * initial_capacity)); + map->count = 0; + map->free = 0; + map->capacity = initial_capacity; +} + +void grpc_chttp2_stream_map_destroy(grpc_chttp2_stream_map* map) { + gpr_free(map->keys); + gpr_free(map->values); +} + +static size_t compact(uint32_t* keys, void** values, size_t count) { + size_t i, out; + + for (i = 0, out = 0; i < count; i++) { + if (values[i]) { + keys[out] = keys[i]; + values[out] = values[i]; + out++; + } + } + + return out; +} + +void grpc_chttp2_stream_map_add(grpc_chttp2_stream_map* map, uint32_t key, + void* value) { + size_t count = map->count; + size_t capacity = map->capacity; + uint32_t* keys = map->keys; + void** values = map->values; + + // The first assertion ensures that the table is monotonically increasing. + GPR_ASSERT(count == 0 || keys[count - 1] < key); + GPR_DEBUG_ASSERT(value); + // Asserting that the key is not already in the map can be a debug assertion. + // Why: we're already checking that the map elements are monotonically + // increasing. If we re-add a key, i.e. if the key is already present, then + // either it is the most recently added key in the map (in which case the + // first assertion fails due to key == last_key) or there is a more recently + // added (larger) key at the end of the map: in which case the first assertion + // still fails due to key < last_key. 
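+  //
+  // Because keys (stream ids) are only ever appended in increasing order, the
+  // key/value arrays stay sorted and find() below can locate an entry with a
+  // plain binary search. Illustrative use (hypothetical ids): after adding
+  // streams 1, 3 and 5, keys == {1, 3, 5} and
+  // grpc_chttp2_stream_map_find(map, 3) returns the value stored for 3.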
+ GPR_DEBUG_ASSERT(grpc_chttp2_stream_map_find(map, key) == nullptr); + + if (count == capacity) { + if (map->free > capacity / 4) { + count = compact(keys, values, count); + map->free = 0; + } else { + /* resize when less than 25% of the table is free, because compaction + won't help much */ + map->capacity = capacity = 2 * capacity; + map->keys = keys = static_cast( + gpr_realloc(keys, capacity * sizeof(uint32_t))); + map->values = values = + static_cast(gpr_realloc(values, capacity * sizeof(void*))); + } + } + + keys[count] = key; + values[count] = value; + map->count = count + 1; +} + +template +static void** find(grpc_chttp2_stream_map* map, uint32_t key) { + size_t min_idx = 0; + size_t max_idx = map->count; + size_t mid_idx; + uint32_t* keys = map->keys; + void** values = map->values; + uint32_t mid_key; + + GPR_DEBUG_ASSERT(!strict_find || max_idx > 0); + if (!strict_find && max_idx == 0) return nullptr; + + while (min_idx < max_idx) { + /* find the midpoint, avoiding overflow */ + mid_idx = min_idx + ((max_idx - min_idx) / 2); + mid_key = keys[mid_idx]; + + if (mid_key < key) { + min_idx = mid_idx + 1; + } else if (mid_key > key) { + max_idx = mid_idx; + } else /* mid_key == key */ + { + return &values[mid_idx]; + } + } + + GPR_DEBUG_ASSERT(!strict_find); + return nullptr; +} + +void* grpc_chttp2_stream_map_delete(grpc_chttp2_stream_map* map, uint32_t key) { + void** pvalue = find(map, key); + GPR_DEBUG_ASSERT(pvalue != nullptr); + void* out = *pvalue; + GPR_DEBUG_ASSERT(out != nullptr); + *pvalue = nullptr; + map->free++; + /* recognize complete emptyness and ensure we can skip + defragmentation later */ + if (map->free == map->count) { + map->free = map->count = 0; + } + GPR_DEBUG_ASSERT(grpc_chttp2_stream_map_find(map, key) == nullptr); + return out; +} + +void* grpc_chttp2_stream_map_find(grpc_chttp2_stream_map* map, uint32_t key) { + void** pvalue = find(map, key); + return pvalue != nullptr ? *pvalue : nullptr; +} + +size_t grpc_chttp2_stream_map_size(grpc_chttp2_stream_map* map) { + return map->count - map->free; +} + +void* grpc_chttp2_stream_map_rand(grpc_chttp2_stream_map* map) { + if (map->count == map->free) { + return nullptr; + } + if (map->free != 0) { + map->count = compact(map->keys, map->values, map->count); + map->free = 0; + GPR_ASSERT(map->count > 0); + } + return map->values[(static_cast(rand())) % map->count]; +} + +void grpc_chttp2_stream_map_for_each(grpc_chttp2_stream_map* map, + void (*f)(void* user_data, uint32_t key, + void* value), + void* user_data) { + size_t i; + + for (i = 0; i < map->count; i++) { + if (map->values[i]) { + f(user_data, map->keys[i], map->values[i]); + } + } +} diff --git a/src/core/ext/transport/chttp2/transport/varint.cc b/src/core/ext/transport/chttp2/transport/varint.cc new file mode 100644 index 00000000..64782354 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/varint.cc @@ -0,0 +1,62 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/varint.h" + +#include "absl/base/attributes.h" + +namespace grpc_core { + +uint32_t VarintLength(uint32_t tail_value) { + if (tail_value < (1 << 7)) { + return 2; + } else if (tail_value < (1 << 14)) { + return 3; + } else if (tail_value < (1 << 21)) { + return 4; + } else if (tail_value < (1 << 28)) { + return 5; + } else { + return 6; + } +} + +void VarintWriteTail(uint32_t tail_value, uint8_t* target, + uint32_t tail_length) { + switch (tail_length) { + case 5: + target[4] = static_cast((tail_value >> 28) | 0x80); + ABSL_FALLTHROUGH_INTENDED; + case 4: + target[3] = static_cast((tail_value >> 21) | 0x80); + ABSL_FALLTHROUGH_INTENDED; + case 3: + target[2] = static_cast((tail_value >> 14) | 0x80); + ABSL_FALLTHROUGH_INTENDED; + case 2: + target[1] = static_cast((tail_value >> 7) | 0x80); + ABSL_FALLTHROUGH_INTENDED; + case 1: + target[0] = static_cast((tail_value) | 0x80); + } + target[tail_length - 1] &= 0x7f; +} + +} // namespace grpc_core diff --git a/src/core/ext/transport/chttp2/transport/writing.cc b/src/core/ext/transport/chttp2/transport/writing.cc new file mode 100644 index 00000000..bd8f555a --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/writing.cc @@ -0,0 +1,716 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/ext/transport/chttp2/transport/context_list.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/compression/stream_compression.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/http2_errors.h" + +static void add_to_write_list(grpc_chttp2_write_cb** list, + grpc_chttp2_write_cb* cb) { + cb->next = *list; + *list = cb; +} + +static void finish_write_cb(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + grpc_chttp2_write_cb* cb, grpc_error_handle error) { + grpc_chttp2_complete_closure_step(t, s, &cb->closure, error, + "finish_write_cb"); + cb->next = t->write_cb_pool; + t->write_cb_pool = cb; +} + +static void maybe_initiate_ping(grpc_chttp2_transport* t) { + grpc_chttp2_ping_queue* pq = &t->ping_queue; + if (grpc_closure_list_empty(pq->lists[GRPC_CHTTP2_PCL_NEXT])) { + /* no ping needed: wait */ + return; + } + if (!grpc_closure_list_empty(pq->lists[GRPC_CHTTP2_PCL_INFLIGHT])) { + /* ping already in-flight: wait */ + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_bdp_estimator_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_keepalive_trace)) { + gpr_log(GPR_INFO, "%s: Ping delayed [%s]: already pinging", + t->is_client ? 
"CLIENT" : "SERVER", t->peer_string.c_str()); + } + return; + } + if (t->ping_state.pings_before_data_required == 0 && + t->ping_policy.max_pings_without_data != 0) { + /* need to receive something of substance before sending a ping again */ + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_bdp_estimator_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_keepalive_trace)) { + gpr_log(GPR_INFO, "%s: Ping delayed [%s]: too many recent pings: %d/%d", + t->is_client ? "CLIENT" : "SERVER", t->peer_string.c_str(), + t->ping_state.pings_before_data_required, + t->ping_policy.max_pings_without_data); + } + return; + } + // InvalidateNow to avoid getting stuck re-initializing the ping timer + // in a loop while draining the currently-held combiner. Also see + // https://github.com/grpc/grpc/issues/26079. + grpc_core::ExecCtx::Get()->InvalidateNow(); + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + + grpc_millis next_allowed_ping_interval = + (t->keepalive_permit_without_calls == 0 && + grpc_chttp2_stream_map_size(&t->stream_map) == 0) + ? 7200 * GPR_MS_PER_SEC + : (GPR_MS_PER_SEC); /* A second is added to deal with network delays + and timing imprecision */ + grpc_millis next_allowed_ping = + t->ping_state.last_ping_sent_time + next_allowed_ping_interval; + + if (next_allowed_ping > now) { + /* not enough elapsed time between successive pings */ + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_bdp_estimator_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_keepalive_trace)) { + gpr_log(GPR_INFO, + "%s: Ping delayed [%s]: not enough time elapsed since last ping. " + " Last ping %f: Next ping %f: Now %f", + t->is_client ? "CLIENT" : "SERVER", t->peer_string.c_str(), + static_cast(t->ping_state.last_ping_sent_time), + static_cast(next_allowed_ping), static_cast(now)); + } + if (!t->ping_state.is_delayed_ping_timer_set) { + t->ping_state.is_delayed_ping_timer_set = true; + GRPC_CHTTP2_REF_TRANSPORT(t, "retry_initiate_ping_locked"); + GRPC_CLOSURE_INIT(&t->retry_initiate_ping_locked, + grpc_chttp2_retry_initiate_ping, t, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&t->ping_state.delayed_ping_timer, next_allowed_ping, + &t->retry_initiate_ping_locked); + } + return; + } + + pq->inflight_id = t->ping_ctr; + t->ping_ctr++; + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, + &pq->lists[GRPC_CHTTP2_PCL_INITIATE]); + grpc_closure_list_move(&pq->lists[GRPC_CHTTP2_PCL_NEXT], + &pq->lists[GRPC_CHTTP2_PCL_INFLIGHT]); + grpc_slice_buffer_add(&t->outbuf, + grpc_chttp2_ping_create(false, pq->inflight_id)); + GRPC_STATS_INC_HTTP2_PINGS_SENT(); + t->ping_state.last_ping_sent_time = now; + if (GRPC_TRACE_FLAG_ENABLED(grpc_http_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_bdp_estimator_trace) || + GRPC_TRACE_FLAG_ENABLED(grpc_keepalive_trace)) { + gpr_log(GPR_INFO, "%s: Ping sent [%s]: %d/%d", + t->is_client ? 
"CLIENT" : "SERVER", t->peer_string.c_str(), + t->ping_state.pings_before_data_required, + t->ping_policy.max_pings_without_data); + } + t->ping_state.pings_before_data_required -= + (t->ping_state.pings_before_data_required != 0); +} + +static bool update_list(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + int64_t send_bytes, grpc_chttp2_write_cb** list, + int64_t* ctr, grpc_error_handle error) { + bool sched_any = false; + grpc_chttp2_write_cb* cb = *list; + *list = nullptr; + *ctr += send_bytes; + while (cb) { + grpc_chttp2_write_cb* next = cb->next; + if (cb->call_at_byte <= *ctr) { + sched_any = true; + finish_write_cb(t, s, cb, GRPC_ERROR_REF(error)); + } else { + add_to_write_list(list, cb); + } + cb = next; + } + GRPC_ERROR_UNREF(error); + return sched_any; +} + +static void report_stall(grpc_chttp2_transport* t, grpc_chttp2_stream* s, + const char* staller) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_flowctl_trace)) { + gpr_log( + GPR_DEBUG, + "%s:%p stream %d moved to stalled list by %s. This is FULLY expected " + "to happen in a healthy program that is not seeing flow control stalls." + " However, if you know that there are unwanted stalls, here is some " + "helpful data: [fc:pending=%" PRIdPTR ":pending-compressed=%" PRIdPTR + ":flowed=%" PRId64 ":peer_initwin=%d:t_win=%" PRId64 + ":s_win=%d:s_delta=%" PRId64 "]", + t->peer_string.c_str(), t, s->id, staller, + s->flow_controlled_buffer.length, + s->stream_compression_method == + GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS + ? 0 + : s->compressed_data_buffer.length, + s->flow_controlled_bytes_flowed, + t->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + t->flow_control->remote_window(), + static_cast(std::max( + int64_t(0), + s->flow_control->remote_window_delta() + + static_cast( + t->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]))), + s->flow_control->remote_window_delta()); + } +} + +/* How many bytes would we like to put on the wire during a single syscall */ +static uint32_t target_write_size(grpc_chttp2_transport* /*t*/) { + return 1024 * 1024; +} + +// Returns true if initial_metadata contains only default headers. 
+static bool is_default_initial_metadata(grpc_metadata_batch* initial_metadata) { + return initial_metadata->default_count() == + initial_metadata->non_deadline_count(); +} + +namespace { +class StreamWriteContext; + +class WriteContext { + public: + explicit WriteContext(grpc_chttp2_transport* t) : t_(t) { + GRPC_STATS_INC_HTTP2_WRITES_BEGUN(); + GPR_TIMER_SCOPE("grpc_chttp2_begin_write", 0); + } + + // TODO(ctiller): make this the destructor + void FlushStats() { + GRPC_STATS_INC_HTTP2_SEND_INITIAL_METADATA_PER_WRITE( + initial_metadata_writes_); + GRPC_STATS_INC_HTTP2_SEND_MESSAGE_PER_WRITE(message_writes_); + GRPC_STATS_INC_HTTP2_SEND_TRAILING_METADATA_PER_WRITE( + trailing_metadata_writes_); + GRPC_STATS_INC_HTTP2_SEND_FLOWCTL_PER_WRITE(flow_control_writes_); + } + + void FlushSettings() { + if (t_->dirtied_local_settings && !t_->sent_local_settings) { + grpc_slice_buffer_add( + &t_->outbuf, grpc_chttp2_settings_create( + t_->settings[GRPC_SENT_SETTINGS], + t_->settings[GRPC_LOCAL_SETTINGS], + t_->force_send_settings, GRPC_CHTTP2_NUM_SETTINGS)); + t_->force_send_settings = false; + t_->dirtied_local_settings = false; + t_->sent_local_settings = true; + GRPC_STATS_INC_HTTP2_SETTINGS_WRITES(); + } + } + + void FlushQueuedBuffers() { + /* simple writes are queued to qbuf, and flushed here */ + grpc_slice_buffer_move_into(&t_->qbuf, &t_->outbuf); + t_->num_pending_induced_frames = 0; + GPR_ASSERT(t_->qbuf.count == 0); + } + + void FlushWindowUpdates() { + uint32_t transport_announce = + t_->flow_control->MaybeSendUpdate(t_->outbuf.count > 0); + if (transport_announce) { + grpc_transport_one_way_stats throwaway_stats; + grpc_slice_buffer_add( + &t_->outbuf, grpc_chttp2_window_update_create(0, transport_announce, + &throwaway_stats)); + grpc_chttp2_reset_ping_clock(t_); + } + } + + void FlushPingAcks() { + for (size_t i = 0; i < t_->ping_ack_count; i++) { + grpc_slice_buffer_add(&t_->outbuf, + grpc_chttp2_ping_create(true, t_->ping_acks[i])); + } + t_->ping_ack_count = 0; + } + + void EnactHpackSettings() { + t_->hpack_compressor.SetMaxTableSize( + t_->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_HEADER_TABLE_SIZE]); + } + + void UpdateStreamsNoLongerStalled() { + grpc_chttp2_stream* s; + while (grpc_chttp2_list_pop_stalled_by_transport(t_, &s)) { + if (t_->closed_with_error == GRPC_ERROR_NONE && + grpc_chttp2_list_add_writable_stream(t_, s)) { + if (!s->refcount->refs.RefIfNonZero()) { + grpc_chttp2_list_remove_writable_stream(t_, s); + } + } + } + } + + grpc_chttp2_stream* NextStream() { + if (t_->outbuf.length > target_write_size(t_)) { + result_.partial = true; + return nullptr; + } + + grpc_chttp2_stream* s; + if (!grpc_chttp2_list_pop_writable_stream(t_, &s)) { + return nullptr; + } + + return s; + } + + void IncInitialMetadataWrites() { ++initial_metadata_writes_; } + void IncWindowUpdateWrites() { ++flow_control_writes_; } + void IncMessageWrites() { ++message_writes_; } + void IncTrailingMetadataWrites() { ++trailing_metadata_writes_; } + + void NoteScheduledResults() { result_.early_results_scheduled = true; } + + grpc_chttp2_transport* transport() const { return t_; } + + grpc_chttp2_begin_write_result Result() { + result_.writing = t_->outbuf.count > 0; + return result_; + } + + private: + grpc_chttp2_transport* const t_; + + /* stats histogram counters: we increment these throughout this function, + and at the end publish to the central stats histograms */ + int flow_control_writes_ = 0; + int initial_metadata_writes_ = 0; + int trailing_metadata_writes_ = 0; + int 
message_writes_ = 0; + grpc_chttp2_begin_write_result result_ = {false, false, false}; +}; + +class DataSendContext { + public: + DataSendContext(WriteContext* write_context, grpc_chttp2_transport* t, + grpc_chttp2_stream* s) + : write_context_(write_context), + t_(t), + s_(s), + sending_bytes_before_(s_->sending_bytes) {} + + uint32_t stream_remote_window() const { + return static_cast(std::max( + int64_t(0), + s_->flow_control->remote_window_delta() + + static_cast( + t_->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE]))); + } + + uint32_t max_outgoing() const { + return static_cast(std::min( + t_->settings[GRPC_PEER_SETTINGS][GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE], + static_cast(std::min(int64_t(stream_remote_window()), + t_->flow_control->remote_window())))); + } + + bool AnyOutgoing() const { return max_outgoing() > 0; } + + void FlushUncompressedBytes() { + uint32_t send_bytes = static_cast( + std::min(size_t(max_outgoing()), s_->flow_controlled_buffer.length)); + is_last_frame_ = send_bytes == s_->flow_controlled_buffer.length && + s_->fetching_send_message == nullptr && + s_->send_trailing_metadata != nullptr && + s_->send_trailing_metadata->empty(); + grpc_chttp2_encode_data(s_->id, &s_->flow_controlled_buffer, send_bytes, + is_last_frame_, &s_->stats.outgoing, &t_->outbuf); + s_->flow_control->SentData(send_bytes); + s_->sending_bytes += send_bytes; + } + + void FlushCompressedBytes() { + GPR_DEBUG_ASSERT(s_->stream_compression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS); + + uint32_t send_bytes = static_cast( + std::min(size_t(max_outgoing()), s_->compressed_data_buffer.length)); + bool is_last_data_frame = + (send_bytes == s_->compressed_data_buffer.length && + s_->flow_controlled_buffer.length == 0 && + s_->fetching_send_message == nullptr); + if (is_last_data_frame && s_->send_trailing_metadata != nullptr && + s_->stream_compression_ctx != nullptr) { + if (GPR_UNLIKELY(!grpc_stream_compress( + s_->stream_compression_ctx, &s_->flow_controlled_buffer, + &s_->compressed_data_buffer, nullptr, MAX_SIZE_T, + GRPC_STREAM_COMPRESSION_FLUSH_FINISH))) { + gpr_log(GPR_ERROR, "Stream compression failed."); + } + grpc_stream_compression_context_destroy(s_->stream_compression_ctx); + s_->stream_compression_ctx = nullptr; + /* After finish, bytes in s->compressed_data_buffer may be + * more than max_outgoing. Start another round of the current + * while loop so that send_bytes and is_last_data_frame are + * recalculated. 
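+ * (The caller, StreamWriteContext::FlushData, loops until both the
+ * flow-controlled and compressed buffers drain or max_outgoing() reaches
+ * zero, so simply returning here is safe.)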
*/ + return; + } + is_last_frame_ = is_last_data_frame && + s_->send_trailing_metadata != nullptr && + s_->send_trailing_metadata->empty(); + grpc_chttp2_encode_data(s_->id, &s_->compressed_data_buffer, send_bytes, + is_last_frame_, &s_->stats.outgoing, &t_->outbuf); + s_->flow_control->SentData(send_bytes); + if (s_->compressed_data_buffer.length == 0) { + s_->sending_bytes += s_->uncompressed_data_size; + } + } + + void CompressMoreBytes() { + GPR_DEBUG_ASSERT(s_->stream_compression_method != + GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS); + + if (s_->stream_compression_ctx == nullptr) { + s_->stream_compression_ctx = + grpc_stream_compression_context_create(s_->stream_compression_method); + } + s_->uncompressed_data_size = s_->flow_controlled_buffer.length; + if (GPR_UNLIKELY(!grpc_stream_compress( + s_->stream_compression_ctx, &s_->flow_controlled_buffer, + &s_->compressed_data_buffer, nullptr, MAX_SIZE_T, + GRPC_STREAM_COMPRESSION_FLUSH_SYNC))) { + gpr_log(GPR_ERROR, "Stream compression failed."); + } + } + + bool is_last_frame() const { return is_last_frame_; } + + void CallCallbacks() { + if (update_list( + t_, s_, + static_cast(s_->sending_bytes - sending_bytes_before_), + &s_->on_flow_controlled_cbs, &s_->flow_controlled_bytes_flowed, + GRPC_ERROR_NONE)) { + write_context_->NoteScheduledResults(); + } + } + + private: + WriteContext* write_context_; + grpc_chttp2_transport* t_; + grpc_chttp2_stream* s_; + const size_t sending_bytes_before_; + bool is_last_frame_ = false; +}; + +class StreamWriteContext { + public: + StreamWriteContext(WriteContext* write_context, grpc_chttp2_stream* s) + : write_context_(write_context), t_(write_context->transport()), s_(s) { + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_INFO, "W:%p %s[%d] im-(sent,send)=(%d,%d) announce=%d", t_, + t_->is_client ? "CLIENT" : "SERVER", s->id, + s->sent_initial_metadata, s->send_initial_metadata != nullptr, + (int)(s->flow_control->local_window_delta() - + s->flow_control->announced_window_delta()))); + } + + void FlushInitialMetadata() { + /* send initial metadata if it's available */ + if (s_->sent_initial_metadata) return; + if (s_->send_initial_metadata == nullptr) return; + + // We skip this on the server side if there is no custom initial + // metadata, there are no messages to send, and we are also sending + // trailing metadata. 
This results in a Trailers-Only response, + // which is required for retries, as per: + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#when-retries-are-valid + if (!t_->is_client && s_->fetching_send_message == nullptr && + s_->flow_controlled_buffer.length == 0 && + compressed_data_buffer_len() == 0 && + s_->send_trailing_metadata != nullptr && + is_default_initial_metadata(s_->send_initial_metadata)) { + ConvertInitialMetadataToTrailingMetadata(); + } else { + t_->hpack_compressor.EncodeHeaders( + grpc_core::HPackCompressor::EncodeHeaderOptions{ + s_->id, // stream_id + false, // is_eof + t_->settings + [GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_GRPC_ALLOW_TRUE_BINARY_METADATA] != + 0, // use_true_binary_metadata + t_->settings + [GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE], // max_frame_size + &s_->stats.outgoing // stats + }, + *s_->send_initial_metadata, &t_->outbuf); + grpc_chttp2_reset_ping_clock(t_); + write_context_->IncInitialMetadataWrites(); + } + + s_->send_initial_metadata = nullptr; + s_->sent_initial_metadata = true; + write_context_->NoteScheduledResults(); + grpc_chttp2_complete_closure_step( + t_, s_, &s_->send_initial_metadata_finished, GRPC_ERROR_NONE, + "send_initial_metadata_finished"); + } + + size_t compressed_data_buffer_len() { + return s_->stream_compression_method == + GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS + ? 0 + : s_->compressed_data_buffer.length; + } + + void FlushWindowUpdates() { + /* send any window updates */ + const uint32_t stream_announce = s_->flow_control->MaybeSendUpdate(); + if (stream_announce == 0) return; + + grpc_slice_buffer_add( + &t_->outbuf, grpc_chttp2_window_update_create(s_->id, stream_announce, + &s_->stats.outgoing)); + grpc_chttp2_reset_ping_clock(t_); + write_context_->IncWindowUpdateWrites(); + } + + void FlushData() { + if (!s_->sent_initial_metadata) return; + + if (s_->flow_controlled_buffer.length == 0 && + compressed_data_buffer_len() == 0) { + return; // early out: nothing to do + } + + DataSendContext data_send_context(write_context_, t_, s_); + + if (!data_send_context.AnyOutgoing()) { + if (t_->flow_control->remote_window() <= 0) { + report_stall(t_, s_, "transport"); + grpc_chttp2_list_add_stalled_by_transport(t_, s_); + } else if (data_send_context.stream_remote_window() <= 0) { + report_stall(t_, s_, "stream"); + grpc_chttp2_list_add_stalled_by_stream(t_, s_); + } + return; // early out: nothing to do + } + + if (s_->stream_compression_method == + GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS) { + while (s_->flow_controlled_buffer.length > 0 && + data_send_context.max_outgoing() > 0) { + data_send_context.FlushUncompressedBytes(); + } + } else { + while ((s_->flow_controlled_buffer.length > 0 || + s_->compressed_data_buffer.length > 0) && + data_send_context.max_outgoing() > 0) { + if (s_->compressed_data_buffer.length > 0) { + data_send_context.FlushCompressedBytes(); + } else { + data_send_context.CompressMoreBytes(); + } + } + } + grpc_chttp2_reset_ping_clock(t_); + if (data_send_context.is_last_frame()) { + SentLastFrame(); + } + data_send_context.CallCallbacks(); + stream_became_writable_ = true; + if (s_->flow_controlled_buffer.length > 0 || + compressed_data_buffer_len() > 0) { + GRPC_CHTTP2_STREAM_REF(s_, "chttp2_writing:fork"); + grpc_chttp2_list_add_writable_stream(t_, s_); + } + write_context_->IncMessageWrites(); + } + + void FlushTrailingMetadata() { + if (!s_->sent_initial_metadata) return; + + if (s_->send_trailing_metadata == nullptr) return; + if 
(s_->fetching_send_message != nullptr) return; + if (s_->flow_controlled_buffer.length != 0) return; + if (compressed_data_buffer_len() != 0) return; + + GRPC_CHTTP2_IF_TRACING(gpr_log(GPR_INFO, "sending trailing_metadata")); + if (s_->send_trailing_metadata->empty()) { + grpc_chttp2_encode_data(s_->id, &s_->flow_controlled_buffer, 0, true, + &s_->stats.outgoing, &t_->outbuf); + } else { + t_->hpack_compressor.EncodeHeaders( + grpc_core::HPackCompressor::EncodeHeaderOptions{ + s_->id, true, + t_->settings + [GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_GRPC_ALLOW_TRUE_BINARY_METADATA] != + 0, + t_->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE], + &s_->stats.outgoing}, + grpc_core::ConcatMetadata( + grpc_core::MetadataArray( + extra_headers_for_trailing_metadata_, + num_extra_headers_for_trailing_metadata_), + *s_->send_trailing_metadata), + &t_->outbuf); + } + write_context_->IncTrailingMetadataWrites(); + grpc_chttp2_reset_ping_clock(t_); + SentLastFrame(); + + write_context_->NoteScheduledResults(); + grpc_chttp2_complete_closure_step( + t_, s_, &s_->send_trailing_metadata_finished, GRPC_ERROR_NONE, + "send_trailing_metadata_finished"); + } + + bool stream_became_writable() { return stream_became_writable_; } + + private: + void ConvertInitialMetadataToTrailingMetadata() { + GRPC_CHTTP2_IF_TRACING( + gpr_log(GPR_INFO, "not sending initial_metadata (Trailers-Only)")); + // When sending Trailers-Only, we need to move the :status and + // content-type headers to the trailers. + if (s_->send_initial_metadata->legacy_index()->named.status != nullptr) { + extra_headers_for_trailing_metadata_ + [num_extra_headers_for_trailing_metadata_++] = + &s_->send_initial_metadata->legacy_index()->named.status->md; + } + if (s_->send_initial_metadata->legacy_index()->named.content_type != + nullptr) { + extra_headers_for_trailing_metadata_ + [num_extra_headers_for_trailing_metadata_++] = + &s_->send_initial_metadata->legacy_index() + ->named.content_type->md; + } + } + + void SentLastFrame() { + s_->send_trailing_metadata = nullptr; + if (s_->sent_trailing_metadata_op) { + *s_->sent_trailing_metadata_op = true; + s_->sent_trailing_metadata_op = nullptr; + } + s_->sent_trailing_metadata = true; + s_->eos_sent = true; + + if (!t_->is_client && !s_->read_closed) { + grpc_slice_buffer_add( + &t_->outbuf, grpc_chttp2_rst_stream_create( + s_->id, GRPC_HTTP2_NO_ERROR, &s_->stats.outgoing)); + } + grpc_chttp2_mark_stream_closed(t_, s_, !t_->is_client, true, + GRPC_ERROR_NONE); + } + + WriteContext* const write_context_; + grpc_chttp2_transport* const t_; + grpc_chttp2_stream* const s_; + bool stream_became_writable_ = false; + grpc_mdelem* extra_headers_for_trailing_metadata_[2]; + size_t num_extra_headers_for_trailing_metadata_ = 0; +}; +} // namespace + +grpc_chttp2_begin_write_result grpc_chttp2_begin_write( + grpc_chttp2_transport* t) { + WriteContext ctx(t); + ctx.FlushSettings(); + ctx.FlushPingAcks(); + ctx.FlushQueuedBuffers(); + ctx.EnactHpackSettings(); + + if (t->flow_control->remote_window() > 0) { + ctx.UpdateStreamsNoLongerStalled(); + } + + /* for each grpc_chttp2_stream that's become writable, frame it's data + (according to available window sizes) and add to the output buffer */ + while (grpc_chttp2_stream* s = ctx.NextStream()) { + StreamWriteContext stream_ctx(&ctx, s); + size_t orig_len = t->outbuf.length; + stream_ctx.FlushInitialMetadata(); + stream_ctx.FlushWindowUpdates(); + stream_ctx.FlushData(); + stream_ctx.FlushTrailingMetadata(); + if (t->outbuf.length > 
orig_len) { + /* Add this stream to the list of the contexts to be traced at TCP */ + s->byte_counter += t->outbuf.length - orig_len; + if (s->traced && grpc_endpoint_can_track_err(t->ep)) { + grpc_core::ContextList::Append(&t->cl, s); + } + } + if (stream_ctx.stream_became_writable()) { + if (!grpc_chttp2_list_add_writing_stream(t, s)) { + /* already in writing list: drop ref */ + GRPC_CHTTP2_STREAM_UNREF(s, "chttp2_writing:already_writing"); + } else { + /* ref will be dropped at end of write */ + } + } else { + GRPC_CHTTP2_STREAM_UNREF(s, "chttp2_writing:no_write"); + } + } + + ctx.FlushWindowUpdates(); + + maybe_initiate_ping(t); + + return ctx.Result(); +} + +void grpc_chttp2_end_write(grpc_chttp2_transport* t, grpc_error_handle error) { + GPR_TIMER_SCOPE("grpc_chttp2_end_write", 0); + grpc_chttp2_stream* s; + + if (t->channelz_socket != nullptr) { + t->channelz_socket->RecordMessagesSent(t->num_messages_in_next_write); + } + t->num_messages_in_next_write = 0; + + while (grpc_chttp2_list_pop_writing_stream(t, &s)) { + if (s->sending_bytes != 0) { + update_list(t, s, static_cast(s->sending_bytes), + &s->on_write_finished_cbs, &s->flow_controlled_bytes_written, + GRPC_ERROR_REF(error)); + s->sending_bytes = 0; + } + GRPC_CHTTP2_STREAM_UNREF(s, "chttp2_writing:end"); + } + grpc_slice_buffer_reset_and_unref_internal(&t->outbuf); + GRPC_ERROR_UNREF(error); +} diff --git a/src/core/ext/transport/cronet/client/secure/cronet_channel_create.cc b/src/core/ext/transport/cronet/client/secure/cronet_channel_create.cc new file mode 100644 index 00000000..ef1a9b57 --- /dev/null +++ b/src/core/ext/transport/cronet/client/secure/cronet_channel_create.cc @@ -0,0 +1,66 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/cronet/client/secure/cronet_channel_create.h" + +#include +#include + +#include +#include + +#include "src/core/ext/transport/cronet/transport/cronet_transport.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/transport_impl.h" + +// Cronet transport object +typedef struct cronet_transport { + grpc_transport base; // must be first element in this structure + void* engine; + char* host; +} cronet_transport; + +extern grpc_transport_vtable grpc_cronet_vtable; + +GRPCAPI grpc_channel* grpc_cronet_secure_channel_create( + void* engine, const char* target, const grpc_channel_args* args, + void* reserved) { + gpr_log(GPR_DEBUG, + "grpc_create_cronet_transport: stream_engine = %p, target=%s", engine, + target); + + // Disable client authority filter when using Cronet + grpc_arg disable_client_authority_filter_arg; + disable_client_authority_filter_arg.key = + const_cast(GRPC_ARG_DISABLE_CLIENT_AUTHORITY_FILTER); + disable_client_authority_filter_arg.type = GRPC_ARG_INTEGER; + disable_client_authority_filter_arg.value.integer = 1; + grpc_channel_args* new_args = grpc_channel_args_copy_and_add( + args, &disable_client_authority_filter_arg, 1); + + grpc_transport* ct = + grpc_create_cronet_transport(engine, target, new_args, reserved); + + grpc_core::ExecCtx exec_ctx; + grpc_channel* channel = grpc_channel_create( + target, new_args, GRPC_CLIENT_DIRECT_CHANNEL, ct, nullptr, 0, nullptr); + grpc_channel_args_destroy(new_args); + return channel; +} diff --git a/src/core/ext/transport/cronet/plugin_registry/grpc_cronet_plugin_registry.cc b/src/core/ext/transport/cronet/plugin_registry/grpc_cronet_plugin_registry.cc new file mode 100644 index 00000000..9557b652 --- /dev/null +++ b/src/core/ext/transport/cronet/plugin_registry/grpc_cronet_plugin_registry.cc @@ -0,0 +1,45 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +void grpc_http_filters_init(void); +void grpc_http_filters_shutdown(void); +void grpc_chttp2_plugin_init(void); +void grpc_chttp2_plugin_shutdown(void); +void grpc_deadline_filter_init(void); +void grpc_deadline_filter_shutdown(void); +void grpc_client_channel_init(void); +void grpc_client_channel_shutdown(void); + +namespace grpc_core { +void ServiceConfigParserInit(void); +void ServiceConfigParserShutdown(void); +} // namespace grpc_core + +void grpc_register_built_in_plugins(void) { + grpc_register_plugin(grpc_http_filters_init, grpc_http_filters_shutdown); + grpc_register_plugin(grpc_core::ServiceConfigParserInit, + grpc_core::ServiceConfigParserShutdown); + grpc_register_plugin(grpc_chttp2_plugin_init, grpc_chttp2_plugin_shutdown); + grpc_register_plugin(grpc_deadline_filter_init, + grpc_deadline_filter_shutdown); + grpc_register_plugin(grpc_client_channel_init, grpc_client_channel_shutdown); +} diff --git a/src/core/ext/transport/cronet/transport/cronet_api_phony.cc b/src/core/ext/transport/cronet/transport/cronet_api_phony.cc new file mode 100644 index 00000000..0499112b --- /dev/null +++ b/src/core/ext/transport/cronet/transport/cronet_api_phony.cc @@ -0,0 +1,86 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +/* This file has empty implementation of all the functions exposed by the cronet +library, so we can build it in all environments */ + +#include + +#include + +#include "third_party/objective_c/Cronet/bidirectional_stream_c.h" + +#include + +#ifdef GRPC_COMPILE_WITH_CRONET +/* link with the real CRONET library in the build system */ +#else +/* Phony implementation of cronet API just to test for build-ability */ +bidirectional_stream* bidirectional_stream_create( + stream_engine* /*engine*/, void* /*annotation*/, + bidirectional_stream_callback* /*callback*/) { + GPR_ASSERT(0); + return nullptr; +} + +int bidirectional_stream_destroy(bidirectional_stream* /*stream*/) { + GPR_ASSERT(0); + return 0; +} + +int bidirectional_stream_start( + bidirectional_stream* /*stream*/, const char* /*url*/, int /*priority*/, + const char* /*method*/, + const bidirectional_stream_header_array* /*headers*/, + bool /*end_of_stream*/) { + GPR_ASSERT(0); + return 0; +} + +int bidirectional_stream_read(bidirectional_stream* /*stream*/, + char* /*buffer*/, int /*capacity*/) { + GPR_ASSERT(0); + return 0; +} + +int bidirectional_stream_write(bidirectional_stream* /*stream*/, + const char* /*buffer*/, int /*count*/, + bool /*end_of_stream*/) { + GPR_ASSERT(0); + return 0; +} + +void bidirectional_stream_cancel(bidirectional_stream* /*stream*/) { + GPR_ASSERT(0); +} + +void bidirectional_stream_disable_auto_flush(bidirectional_stream* /*stream*/, + bool /*disable_auto_flush*/) { + GPR_ASSERT(0); +} + +void bidirectional_stream_delay_request_headers_until_flush( + bidirectional_stream* /*stream*/, bool /*delay_headers_until_flush*/) { + GPR_ASSERT(0); +} + +void bidirectional_stream_flush(bidirectional_stream* /*stream*/) { + GPR_ASSERT(0); +} + +#endif /* GRPC_COMPILE_WITH_CRONET */ diff --git a/src/core/ext/transport/cronet/transport/cronet_status.cc b/src/core/ext/transport/cronet/transport/cronet_status.cc new file mode 100644 index 00000000..8b831f2d --- /dev/null +++ b/src/core/ext/transport/cronet/transport/cronet_status.cc @@ -0,0 +1,521 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/cronet/transport/cronet_status.h" + +const char* cronet_net_error_as_string(cronet_net_error_code net_error) { + switch (net_error) { + case OK: + return "OK"; + case CRONET_NET_ERROR_IO_PENDING: + return "CRONET_NET_ERROR_IO_PENDING"; + case CRONET_NET_ERROR_FAILED: + return "CRONET_NET_ERROR_FAILED"; + case CRONET_NET_ERROR_ABORTED: + return "CRONET_NET_ERROR_ABORTED"; + case CRONET_NET_ERROR_INVALID_ARGUMENT: + return "CRONET_NET_ERROR_INVALID_ARGUMENT"; + case CRONET_NET_ERROR_INVALID_HANDLE: + return "CRONET_NET_ERROR_INVALID_HANDLE"; + case CRONET_NET_ERROR_FILE_NOT_FOUND: + return "CRONET_NET_ERROR_FILE_NOT_FOUND"; + case CRONET_NET_ERROR_TIMED_OUT: + return "CRONET_NET_ERROR_TIMED_OUT"; + case CRONET_NET_ERROR_FILE_TOO_BIG: + return "CRONET_NET_ERROR_FILE_TOO_BIG"; + case CRONET_NET_ERROR_UNEXPECTED: + return "CRONET_NET_ERROR_UNEXPECTED"; + case CRONET_NET_ERROR_ACCESS_DENIED: + return "CRONET_NET_ERROR_ACCESS_DENIED"; + case CRONET_NET_ERROR_NOT_IMPLEMENTED: + return "CRONET_NET_ERROR_NOT_IMPLEMENTED"; + case CRONET_NET_ERROR_INSUFFICIENT_RESOURCES: + return "CRONET_NET_ERROR_INSUFFICIENT_RESOURCES"; + case CRONET_NET_ERROR_OUT_OF_MEMORY: + return "CRONET_NET_ERROR_OUT_OF_MEMORY"; + case CRONET_NET_ERROR_UPLOAD_FILE_CHANGED: + return "CRONET_NET_ERROR_UPLOAD_FILE_CHANGED"; + case CRONET_NET_ERROR_SOCKET_NOT_CONNECTED: + return "CRONET_NET_ERROR_SOCKET_NOT_CONNECTED"; + case CRONET_NET_ERROR_FILE_EXISTS: + return "CRONET_NET_ERROR_FILE_EXISTS"; + case CRONET_NET_ERROR_FILE_PATH_TOO_LONG: + return "CRONET_NET_ERROR_FILE_PATH_TOO_LONG"; + case CRONET_NET_ERROR_FILE_NO_SPACE: + return "CRONET_NET_ERROR_FILE_NO_SPACE"; + case CRONET_NET_ERROR_FILE_VIRUS_INFECTED: + return "CRONET_NET_ERROR_FILE_VIRUS_INFECTED"; + case CRONET_NET_ERROR_BLOCKED_BY_CLIENT: + return "CRONET_NET_ERROR_BLOCKED_BY_CLIENT"; + case CRONET_NET_ERROR_NETWORK_CHANGED: + return "CRONET_NET_ERROR_NETWORK_CHANGED"; + case CRONET_NET_ERROR_BLOCKED_BY_ADMINISTRATOR: + return "CRONET_NET_ERROR_BLOCKED_BY_ADMINISTRATOR"; + case CRONET_NET_ERROR_SOCKET_IS_CONNECTED: + return "CRONET_NET_ERROR_SOCKET_IS_CONNECTED"; + case CRONET_NET_ERROR_BLOCKED_ENROLLMENT_CHECK_PENDING: + return "CRONET_NET_ERROR_BLOCKED_ENROLLMENT_CHECK_PENDING"; + case CRONET_NET_ERROR_UPLOAD_STREAM_REWIND_NOT_SUPPORTED: + return "CRONET_NET_ERROR_UPLOAD_STREAM_REWIND_NOT_SUPPORTED"; + case CRONET_NET_ERROR_CONTEXT_SHUT_DOWN: + return "CRONET_NET_ERROR_CONTEXT_SHUT_DOWN"; + case CRONET_NET_ERROR_BLOCKED_BY_RESPONSE: + return "CRONET_NET_ERROR_BLOCKED_BY_RESPONSE"; + case CRONET_NET_ERROR_CLEARTEXT_NOT_PERMITTED: + return "CRONET_NET_ERROR_CLEARTEXT_NOT_PERMITTED"; + case CRONET_NET_ERROR_BLOCKED_BY_CSP: + return "CRONET_NET_ERROR_BLOCKED_BY_CSP"; + case CRONET_NET_ERROR_H2_OR_QUIC_REQUIRED: + return "CRONET_NET_ERROR_H2_OR_QUIC_REQUIRED"; + case CRONET_NET_ERROR_INSECURE_PRIVATE_NETWORK_REQUEST: + return "CRONET_NET_ERROR_INSECURE_PRIVATE_NETWORK_REQUEST"; + case CRONET_NET_ERROR_CONNECTION_CLOSED: + return "CRONET_NET_ERROR_CONNECTION_CLOSED"; + case CRONET_NET_ERROR_CONNECTION_RESET: + return "CRONET_NET_ERROR_CONNECTION_RESET"; + case CRONET_NET_ERROR_CONNECTION_REFUSED: + return "CRONET_NET_ERROR_CONNECTION_REFUSED"; + case CRONET_NET_ERROR_CONNECTION_ABORTED: + return "CRONET_NET_ERROR_CONNECTION_ABORTED"; + case CRONET_NET_ERROR_CONNECTION_FAILED: + return "CRONET_NET_ERROR_CONNECTION_FAILED"; + case CRONET_NET_ERROR_NAME_NOT_RESOLVED: + return "CRONET_NET_ERROR_NAME_NOT_RESOLVED"; + case 
CRONET_NET_ERROR_INTERNET_DISCONNECTED: + return "CRONET_NET_ERROR_INTERNET_DISCONNECTED"; + case CRONET_NET_ERROR_SSL_PROTOCOL_ERROR: + return "CRONET_NET_ERROR_SSL_PROTOCOL_ERROR"; + case CRONET_NET_ERROR_ADDRESS_INVALID: + return "CRONET_NET_ERROR_ADDRESS_INVALID"; + case CRONET_NET_ERROR_ADDRESS_UNREACHABLE: + return "CRONET_NET_ERROR_ADDRESS_UNREACHABLE"; + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_NEEDED: + return "CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_NEEDED"; + case CRONET_NET_ERROR_TUNNEL_CONNECTION_FAILED: + return "CRONET_NET_ERROR_TUNNEL_CONNECTION_FAILED"; + case CRONET_NET_ERROR_NO_SSL_VERSIONS_ENABLED: + return "CRONET_NET_ERROR_NO_SSL_VERSIONS_ENABLED"; + case CRONET_NET_ERROR_SSL_VERSION_OR_CIPHER_MISMATCH: + return "CRONET_NET_ERROR_SSL_VERSION_OR_CIPHER_MISMATCH"; + case CRONET_NET_ERROR_SSL_RENEGOTIATION_REQUESTED: + return "CRONET_NET_ERROR_SSL_RENEGOTIATION_REQUESTED"; + case CRONET_NET_ERROR_PROXY_AUTH_UNSUPPORTED: + return "CRONET_NET_ERROR_PROXY_AUTH_UNSUPPORTED"; + case CRONET_NET_ERROR_CERT_ERROR_IN_SSL_RENEGOTIATION: + return "CRONET_NET_ERROR_CERT_ERROR_IN_SSL_RENEGOTIATION"; + case CRONET_NET_ERROR_BAD_SSL_CLIENT_AUTH_CERT: + return "CRONET_NET_ERROR_BAD_SSL_CLIENT_AUTH_CERT"; + case CRONET_NET_ERROR_CONNECTION_TIMED_OUT: + return "CRONET_NET_ERROR_CONNECTION_TIMED_OUT"; + case CRONET_NET_ERROR_HOST_RESOLVER_QUEUE_TOO_LARGE: + return "CRONET_NET_ERROR_HOST_RESOLVER_QUEUE_TOO_LARGE"; + case CRONET_NET_ERROR_SOCKS_CONNECTION_FAILED: + return "CRONET_NET_ERROR_SOCKS_CONNECTION_FAILED"; + case CRONET_NET_ERROR_SOCKS_CONNECTION_HOST_UNREACHABLE: + return "CRONET_NET_ERROR_SOCKS_CONNECTION_HOST_UNREACHABLE"; + case CRONET_NET_ERROR_ALPN_NEGOTIATION_FAILED: + return "CRONET_NET_ERROR_ALPN_NEGOTIATION_FAILED"; + case CRONET_NET_ERROR_SSL_NO_RENEGOTIATION: + return "CRONET_NET_ERROR_SSL_NO_RENEGOTIATION"; + case CRONET_NET_ERROR_WINSOCK_UNEXPECTED_WRITTEN_BYTES: + return "CRONET_NET_ERROR_WINSOCK_UNEXPECTED_WRITTEN_BYTES"; + case CRONET_NET_ERROR_SSL_DECOMPRESSION_FAILURE_ALERT: + return "CRONET_NET_ERROR_SSL_DECOMPRESSION_FAILURE_ALERT"; + case CRONET_NET_ERROR_SSL_BAD_RECORD_MAC_ALERT: + return "CRONET_NET_ERROR_SSL_BAD_RECORD_MAC_ALERT"; + case CRONET_NET_ERROR_PROXY_AUTH_REQUESTED: + return "CRONET_NET_ERROR_PROXY_AUTH_REQUESTED"; + case CRONET_NET_ERROR_PROXY_CONNECTION_FAILED: + return "CRONET_NET_ERROR_PROXY_CONNECTION_FAILED"; + case CRONET_NET_ERROR_MANDATORY_PROXY_CONFIGURATION_FAILED: + return "CRONET_NET_ERROR_MANDATORY_PROXY_CONFIGURATION_FAILED"; + case CRONET_NET_ERROR_PRECONNECT_MAX_SOCKET_LIMIT: + return "CRONET_NET_ERROR_PRECONNECT_MAX_SOCKET_LIMIT"; + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_PRIVATE_KEY_ACCESS_DENIED: + return "CRONET_NET_ERROR_SSL_CLIENT_AUTH_PRIVATE_KEY_ACCESS_DENIED"; + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY: + return "CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY"; + case CRONET_NET_ERROR_PROXY_CERTIFICATE_INVALID: + return "CRONET_NET_ERROR_PROXY_CERTIFICATE_INVALID"; + case CRONET_NET_ERROR_NAME_RESOLUTION_FAILED: + return "CRONET_NET_ERROR_NAME_RESOLUTION_FAILED"; + case CRONET_NET_ERROR_NETWORK_ACCESS_DENIED: + return "CRONET_NET_ERROR_NETWORK_ACCESS_DENIED"; + case CRONET_NET_ERROR_TEMPORARILY_THROTTLED: + return "CRONET_NET_ERROR_TEMPORARILY_THROTTLED"; + case CRONET_NET_ERROR_HTTPS_PROXY_TUNNEL_RESPONSE_REDIRECT: + return "CRONET_NET_ERROR_HTTPS_PROXY_TUNNEL_RESPONSE_REDIRECT"; + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_SIGNATURE_FAILED: + return "CRONET_NET_ERROR_SSL_CLIENT_AUTH_SIGNATURE_FAILED"; + 
case CRONET_NET_ERROR_MSG_TOO_BIG: + return "CRONET_NET_ERROR_MSG_TOO_BIG"; + case CRONET_NET_ERROR_WS_PROTOCOL_ERROR: + return "CRONET_NET_ERROR_WS_PROTOCOL_ERROR"; + case CRONET_NET_ERROR_ADDRESS_IN_USE: + return "CRONET_NET_ERROR_ADDRESS_IN_USE"; + case CRONET_NET_ERROR_SSL_HANDSHAKE_NOT_COMPLETED: + return "CRONET_NET_ERROR_SSL_HANDSHAKE_NOT_COMPLETED"; + case CRONET_NET_ERROR_SSL_BAD_PEER_PUBLIC_KEY: + return "CRONET_NET_ERROR_SSL_BAD_PEER_PUBLIC_KEY"; + case CRONET_NET_ERROR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN: + return "CRONET_NET_ERROR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN"; + case CRONET_NET_ERROR_CLIENT_AUTH_CERT_TYPE_UNSUPPORTED: + return "CRONET_NET_ERROR_CLIENT_AUTH_CERT_TYPE_UNSUPPORTED"; + case CRONET_NET_ERROR_SSL_DECRYPT_ERROR_ALERT: + return "CRONET_NET_ERROR_SSL_DECRYPT_ERROR_ALERT"; + case CRONET_NET_ERROR_WS_THROTTLE_QUEUE_TOO_LARGE: + return "CRONET_NET_ERROR_WS_THROTTLE_QUEUE_TOO_LARGE"; + case CRONET_NET_ERROR_SSL_SERVER_CERT_CHANGED: + return "CRONET_NET_ERROR_SSL_SERVER_CERT_CHANGED"; + case CRONET_NET_ERROR_SSL_UNRECOGNIZED_NAME_ALERT: + return "CRONET_NET_ERROR_SSL_UNRECOGNIZED_NAME_ALERT"; + case CRONET_NET_ERROR_SOCKET_SET_RECEIVE_BUFFER_SIZE_ERROR: + return "CRONET_NET_ERROR_SOCKET_SET_RECEIVE_BUFFER_SIZE_ERROR"; + case CRONET_NET_ERROR_SOCKET_SET_SEND_BUFFER_SIZE_ERROR: + return "CRONET_NET_ERROR_SOCKET_SET_SEND_BUFFER_SIZE_ERROR"; + case CRONET_NET_ERROR_SOCKET_RECEIVE_BUFFER_SIZE_UNCHANGEABLE: + return "CRONET_NET_ERROR_SOCKET_RECEIVE_BUFFER_SIZE_UNCHANGEABLE"; + case CRONET_NET_ERROR_SOCKET_SEND_BUFFER_SIZE_UNCHANGEABLE: + return "CRONET_NET_ERROR_SOCKET_SEND_BUFFER_SIZE_UNCHANGEABLE"; + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT: + return "CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT"; + case CRONET_NET_ERROR_ICANN_NAME_COLLISION: + return "CRONET_NET_ERROR_ICANN_NAME_COLLISION"; + case CRONET_NET_ERROR_SSL_SERVER_CERT_BAD_FORMAT: + return "CRONET_NET_ERROR_SSL_SERVER_CERT_BAD_FORMAT"; + case CRONET_NET_ERROR_CT_STH_PARSING_FAILED: + return "CRONET_NET_ERROR_CT_STH_PARSING_FAILED"; + case CRONET_NET_ERROR_CT_STH_INCOMPLETE: + return "CRONET_NET_ERROR_CT_STH_INCOMPLETE"; + case CRONET_NET_ERROR_UNABLE_TO_REUSE_CONNECTION_FOR_PROXY_AUTH: + return "CRONET_NET_ERROR_UNABLE_TO_REUSE_CONNECTION_FOR_PROXY_AUTH"; + case CRONET_NET_ERROR_CT_CONSISTENCY_PROOF_PARSING_FAILED: + return "CRONET_NET_ERROR_CT_CONSISTENCY_PROOF_PARSING_FAILED"; + case CRONET_NET_ERROR_SSL_OBSOLETE_CIPHER: + return "CRONET_NET_ERROR_SSL_OBSOLETE_CIPHER"; + case CRONET_NET_ERROR_WS_UPGRADE: + return "CRONET_NET_ERROR_WS_UPGRADE"; + case CRONET_NET_ERROR_READ_IF_READY_NOT_IMPLEMENTED: + return "CRONET_NET_ERROR_READ_IF_READY_NOT_IMPLEMENTED"; + case CRONET_NET_ERROR_NO_BUFFER_SPACE: + return "CRONET_NET_ERROR_NO_BUFFER_SPACE"; + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_NO_COMMON_ALGORITHMS: + return "CRONET_NET_ERROR_SSL_CLIENT_AUTH_NO_COMMON_ALGORITHMS"; + case CRONET_NET_ERROR_EARLY_DATA_REJECTED: + return "CRONET_NET_ERROR_EARLY_DATA_REJECTED"; + case CRONET_NET_ERROR_WRONG_VERSION_ON_EARLY_DATA: + return "CRONET_NET_ERROR_WRONG_VERSION_ON_EARLY_DATA"; + case CRONET_NET_ERROR_TLS13_DOWNGRADE_DETECTED: + return "CRONET_NET_ERROR_TLS13_DOWNGRADE_DETECTED"; + case CRONET_NET_ERROR_SSL_KEY_USAGE_INCOMPATIBLE: + return "CRONET_NET_ERROR_SSL_KEY_USAGE_INCOMPATIBLE"; + case CRONET_NET_ERROR_CERT_COMMON_NAME_INVALID: + return "CRONET_NET_ERROR_CERT_COMMON_NAME_INVALID"; + case CRONET_NET_ERROR_CERT_DATE_INVALID: + return "CRONET_NET_ERROR_CERT_DATE_INVALID"; + case 
CRONET_NET_ERROR_CERT_AUTHORITY_INVALID: + return "CRONET_NET_ERROR_CERT_AUTHORITY_INVALID"; + case CRONET_NET_ERROR_CERT_CONTAINS_ERRORS: + return "CRONET_NET_ERROR_CERT_CONTAINS_ERRORS"; + case CRONET_NET_ERROR_CERT_NO_REVOCATION_MECHANISM: + return "CRONET_NET_ERROR_CERT_NO_REVOCATION_MECHANISM"; + case CRONET_NET_ERROR_CERT_UNABLE_TO_CHECK_REVOCATION: + return "CRONET_NET_ERROR_CERT_UNABLE_TO_CHECK_REVOCATION"; + case CRONET_NET_ERROR_CERT_REVOKED: + return "CRONET_NET_ERROR_CERT_REVOKED"; + case CRONET_NET_ERROR_CERT_INVALID: + return "CRONET_NET_ERROR_CERT_INVALID"; + case CRONET_NET_ERROR_CERT_WEAK_SIGNATURE_ALGORITHM: + return "CRONET_NET_ERROR_CERT_WEAK_SIGNATURE_ALGORITHM"; + case CRONET_NET_ERROR_CERT_NON_UNIQUE_NAME: + return "CRONET_NET_ERROR_CERT_NON_UNIQUE_NAME"; + case CRONET_NET_ERROR_CERT_WEAK_KEY: + return "CRONET_NET_ERROR_CERT_WEAK_KEY"; + case CRONET_NET_ERROR_CERT_NAME_CONSTRAINT_VIOLATION: + return "CRONET_NET_ERROR_CERT_NAME_CONSTRAINT_VIOLATION"; + case CRONET_NET_ERROR_CERT_VALIDITY_TOO_LONG: + return "CRONET_NET_ERROR_CERT_VALIDITY_TOO_LONG"; + case CRONET_NET_ERROR_CERTIFICATE_TRANSPARENCY_REQUIRED: + return "CRONET_NET_ERROR_CERTIFICATE_TRANSPARENCY_REQUIRED"; + case CRONET_NET_ERROR_CERT_SYMANTEC_LEGACY: + return "CRONET_NET_ERROR_CERT_SYMANTEC_LEGACY"; + case CRONET_NET_ERROR_CERT_KNOWN_INTERCEPTION_BLOCKED: + return "CRONET_NET_ERROR_CERT_KNOWN_INTERCEPTION_BLOCKED"; + case CRONET_NET_ERROR_SSL_OBSOLETE_VERSION: + return "CRONET_NET_ERROR_SSL_OBSOLETE_VERSION"; + case CRONET_NET_ERROR_CERT_END: + return "CRONET_NET_ERROR_CERT_END"; + case CRONET_NET_ERROR_INVALID_URL: + return "CRONET_NET_ERROR_INVALID_URL"; + case CRONET_NET_ERROR_DISALLOWED_URL_SCHEME: + return "CRONET_NET_ERROR_DISALLOWED_URL_SCHEME"; + case CRONET_NET_ERROR_UNKNOWN_URL_SCHEME: + return "CRONET_NET_ERROR_UNKNOWN_URL_SCHEME"; + case CRONET_NET_ERROR_INVALID_REDIRECT: + return "CRONET_NET_ERROR_INVALID_REDIRECT"; + case CRONET_NET_ERROR_TOO_MANY_REDIRECTS: + return "CRONET_NET_ERROR_TOO_MANY_REDIRECTS"; + case CRONET_NET_ERROR_UNSAFE_REDIRECT: + return "CRONET_NET_ERROR_UNSAFE_REDIRECT"; + case CRONET_NET_ERROR_UNSAFE_PORT: + return "CRONET_NET_ERROR_UNSAFE_PORT"; + case CRONET_NET_ERROR_INVALID_RESPONSE: + return "CRONET_NET_ERROR_INVALID_RESPONSE"; + case CRONET_NET_ERROR_INVALID_CHUNKED_ENCODING: + return "CRONET_NET_ERROR_INVALID_CHUNKED_ENCODING"; + case CRONET_NET_ERROR_METHOD_NOT_SUPPORTED: + return "CRONET_NET_ERROR_METHOD_NOT_SUPPORTED"; + case CRONET_NET_ERROR_UNEXPECTED_PROXY_AUTH: + return "CRONET_NET_ERROR_UNEXPECTED_PROXY_AUTH"; + case CRONET_NET_ERROR_EMPTY_RESPONSE: + return "CRONET_NET_ERROR_EMPTY_RESPONSE"; + case CRONET_NET_ERROR_RESPONSE_HEADERS_TOO_BIG: + return "CRONET_NET_ERROR_RESPONSE_HEADERS_TOO_BIG"; + case CRONET_NET_ERROR_PAC_SCRIPT_FAILED: + return "CRONET_NET_ERROR_PAC_SCRIPT_FAILED"; + case CRONET_NET_ERROR_REQUEST_RANGE_NOT_SATISFIABLE: + return "CRONET_NET_ERROR_REQUEST_RANGE_NOT_SATISFIABLE"; + case CRONET_NET_ERROR_MALFORMED_IDENTITY: + return "CRONET_NET_ERROR_MALFORMED_IDENTITY"; + case CRONET_NET_ERROR_CONTENT_DECODING_FAILED: + return "CRONET_NET_ERROR_CONTENT_DECODING_FAILED"; + case CRONET_NET_ERROR_NETWORK_IO_SUSPENDED: + return "CRONET_NET_ERROR_NETWORK_IO_SUSPENDED"; + case CRONET_NET_ERROR_SYN_REPLY_NOT_RECEIVED: + return "CRONET_NET_ERROR_SYN_REPLY_NOT_RECEIVED"; + case CRONET_NET_ERROR_ENCODING_CONVERSION_FAILED: + return "CRONET_NET_ERROR_ENCODING_CONVERSION_FAILED"; + case CRONET_NET_ERROR_UNRECOGNIZED_FTP_DIRECTORY_LISTING_FORMAT: + return 
"CRONET_NET_ERROR_UNRECOGNIZED_FTP_DIRECTORY_LISTING_FORMAT"; + case CRONET_NET_ERROR_NO_SUPPORTED_PROXIES: + return "CRONET_NET_ERROR_NO_SUPPORTED_PROXIES"; + case CRONET_NET_ERROR_HTTP2_PROTOCOL_ERROR: + return "CRONET_NET_ERROR_HTTP2_PROTOCOL_ERROR"; + case CRONET_NET_ERROR_INVALID_AUTH_CREDENTIALS: + return "CRONET_NET_ERROR_INVALID_AUTH_CREDENTIALS"; + case CRONET_NET_ERROR_UNSUPPORTED_AUTH_SCHEME: + return "CRONET_NET_ERROR_UNSUPPORTED_AUTH_SCHEME"; + case CRONET_NET_ERROR_ENCODING_DETECTION_FAILED: + return "CRONET_NET_ERROR_ENCODING_DETECTION_FAILED"; + case CRONET_NET_ERROR_MISSING_AUTH_CREDENTIALS: + return "CRONET_NET_ERROR_MISSING_AUTH_CREDENTIALS"; + case CRONET_NET_ERROR_UNEXPECTED_SECURITY_LIBRARY_STATUS: + return "CRONET_NET_ERROR_UNEXPECTED_SECURITY_LIBRARY_STATUS"; + case CRONET_NET_ERROR_MISCONFIGURED_AUTH_ENVIRONMENT: + return "CRONET_NET_ERROR_MISCONFIGURED_AUTH_ENVIRONMENT"; + case CRONET_NET_ERROR_UNDOCUMENTED_SECURITY_LIBRARY_STATUS: + return "CRONET_NET_ERROR_UNDOCUMENTED_SECURITY_LIBRARY_STATUS"; + case CRONET_NET_ERROR_RESPONSE_BODY_TOO_BIG_TO_DRAIN: + return "CRONET_NET_ERROR_RESPONSE_BODY_TOO_BIG_TO_DRAIN"; + case CRONET_NET_ERROR_RESPONSE_HEADERS_MULTIPLE_CONTENT_LENGTH: + return "CRONET_NET_ERROR_RESPONSE_HEADERS_MULTIPLE_CONTENT_LENGTH"; + case CRONET_NET_ERROR_INCOMPLETE_HTTP2_HEADERS: + return "CRONET_NET_ERROR_INCOMPLETE_HTTP2_HEADERS"; + case CRONET_NET_ERROR_PAC_NOT_IN_DHCP: + return "CRONET_NET_ERROR_PAC_NOT_IN_DHCP"; + case CRONET_NET_ERROR_RESPONSE_HEADERS_MULTIPLE_CONTENT_DISPOSITION: + return "CRONET_NET_ERROR_RESPONSE_HEADERS_MULTIPLE_CONTENT_DISPOSITION"; + case CRONET_NET_ERROR_RESPONSE_HEADERS_MULTIPLE_LOCATION: + return "CRONET_NET_ERROR_RESPONSE_HEADERS_MULTIPLE_LOCATION"; + case CRONET_NET_ERROR_HTTP2_SERVER_REFUSED_STREAM: + return "CRONET_NET_ERROR_HTTP2_SERVER_REFUSED_STREAM"; + case CRONET_NET_ERROR_HTTP2_PING_FAILED: + return "CRONET_NET_ERROR_HTTP2_PING_FAILED"; + case CRONET_NET_ERROR_CONTENT_LENGTH_MISMATCH: + return "CRONET_NET_ERROR_CONTENT_LENGTH_MISMATCH"; + case CRONET_NET_ERROR_INCOMPLETE_CHUNKED_ENCODING: + return "CRONET_NET_ERROR_INCOMPLETE_CHUNKED_ENCODING"; + case CRONET_NET_ERROR_QUIC_PROTOCOL_ERROR: + return "CRONET_NET_ERROR_QUIC_PROTOCOL_ERROR"; + case CRONET_NET_ERROR_RESPONSE_HEADERS_TRUNCATED: + return "CRONET_NET_ERROR_RESPONSE_HEADERS_TRUNCATED"; + case CRONET_NET_ERROR_QUIC_HANDSHAKE_FAILED: + return "CRONET_NET_ERROR_QUIC_HANDSHAKE_FAILED"; + case CRONET_NET_ERROR_HTTP2_INADEQUATE_TRANSPORT_SECURITY: + return "CRONET_NET_ERROR_HTTP2_INADEQUATE_TRANSPORT_SECURITY"; + case CRONET_NET_ERROR_HTTP2_FLOW_CONTROL_ERROR: + return "CRONET_NET_ERROR_HTTP2_FLOW_CONTROL_ERROR"; + case CRONET_NET_ERROR_HTTP2_FRAME_SIZE_ERROR: + return "CRONET_NET_ERROR_HTTP2_FRAME_SIZE_ERROR"; + case CRONET_NET_ERROR_HTTP2_COMPRESSION_ERROR: + return "CRONET_NET_ERROR_HTTP2_COMPRESSION_ERROR"; + case CRONET_NET_ERROR_PROXY_AUTH_REQUESTED_WITH_NO_CONNECTION: + return "CRONET_NET_ERROR_PROXY_AUTH_REQUESTED_WITH_NO_CONNECTION"; + case CRONET_NET_ERROR_HTTP_1_1_REQUIRED: + return "CRONET_NET_ERROR_HTTP_1_1_REQUIRED"; + case CRONET_NET_ERROR_PROXY_HTTP_1_1_REQUIRED: + return "CRONET_NET_ERROR_PROXY_HTTP_1_1_REQUIRED"; + case CRONET_NET_ERROR_PAC_SCRIPT_TERMINATED: + return "CRONET_NET_ERROR_PAC_SCRIPT_TERMINATED"; + case CRONET_NET_ERROR_INVALID_HTTP_RESPONSE: + return "CRONET_NET_ERROR_INVALID_HTTP_RESPONSE"; + case CRONET_NET_ERROR_CONTENT_DECODING_INIT_FAILED: + return "CRONET_NET_ERROR_CONTENT_DECODING_INIT_FAILED"; + case 
CRONET_NET_ERROR_HTTP2_RST_STREAM_NO_ERROR_RECEIVED: + return "CRONET_NET_ERROR_HTTP2_RST_STREAM_NO_ERROR_RECEIVED"; + case CRONET_NET_ERROR_HTTP2_PUSHED_STREAM_NOT_AVAILABLE: + return "CRONET_NET_ERROR_HTTP2_PUSHED_STREAM_NOT_AVAILABLE"; + case CRONET_NET_ERROR_HTTP2_CLAIMED_PUSHED_STREAM_RESET_BY_SERVER: + return "CRONET_NET_ERROR_HTTP2_CLAIMED_PUSHED_STREAM_RESET_BY_SERVER"; + case CRONET_NET_ERROR_TOO_MANY_RETRIES: + return "CRONET_NET_ERROR_TOO_MANY_RETRIES"; + case CRONET_NET_ERROR_HTTP2_STREAM_CLOSED: + return "CRONET_NET_ERROR_HTTP2_STREAM_CLOSED"; + case CRONET_NET_ERROR_HTTP2_CLIENT_REFUSED_STREAM: + return "CRONET_NET_ERROR_HTTP2_CLIENT_REFUSED_STREAM"; + case CRONET_NET_ERROR_HTTP2_PUSHED_RESPONSE_DOES_NOT_MATCH: + return "CRONET_NET_ERROR_HTTP2_PUSHED_RESPONSE_DOES_NOT_MATCH"; + case CRONET_NET_ERROR_HTTP_RESPONSE_CODE_FAILURE: + return "CRONET_NET_ERROR_HTTP_RESPONSE_CODE_FAILURE"; + case CRONET_NET_ERROR_QUIC_CERT_ROOT_NOT_KNOWN: + return "CRONET_NET_ERROR_QUIC_CERT_ROOT_NOT_KNOWN"; + case CRONET_NET_ERROR_CACHE_MISS: + return "CRONET_NET_ERROR_CACHE_MISS"; + case CRONET_NET_ERROR_CACHE_READ_FAILURE: + return "CRONET_NET_ERROR_CACHE_READ_FAILURE"; + case CRONET_NET_ERROR_CACHE_WRITE_FAILURE: + return "CRONET_NET_ERROR_CACHE_WRITE_FAILURE"; + case CRONET_NET_ERROR_CACHE_OPERATION_NOT_SUPPORTED: + return "CRONET_NET_ERROR_CACHE_OPERATION_NOT_SUPPORTED"; + case CRONET_NET_ERROR_CACHE_OPEN_FAILURE: + return "CRONET_NET_ERROR_CACHE_OPEN_FAILURE"; + case CRONET_NET_ERROR_CACHE_CREATE_FAILURE: + return "CRONET_NET_ERROR_CACHE_CREATE_FAILURE"; + case CRONET_NET_ERROR_CACHE_RACE: + return "CRONET_NET_ERROR_CACHE_RACE"; + case CRONET_NET_ERROR_CACHE_CHECKSUM_READ_FAILURE: + return "CRONET_NET_ERROR_CACHE_CHECKSUM_READ_FAILURE"; + case CRONET_NET_ERROR_CACHE_CHECKSUM_MISMATCH: + return "CRONET_NET_ERROR_CACHE_CHECKSUM_MISMATCH"; + case CRONET_NET_ERROR_CACHE_LOCK_TIMEOUT: + return "CRONET_NET_ERROR_CACHE_LOCK_TIMEOUT"; + case CRONET_NET_ERROR_CACHE_AUTH_FAILURE_AFTER_READ: + return "CRONET_NET_ERROR_CACHE_AUTH_FAILURE_AFTER_READ"; + case CRONET_NET_ERROR_CACHE_ENTRY_NOT_SUITABLE: + return "CRONET_NET_ERROR_CACHE_ENTRY_NOT_SUITABLE"; + case CRONET_NET_ERROR_CACHE_DOOM_FAILURE: + return "CRONET_NET_ERROR_CACHE_DOOM_FAILURE"; + case CRONET_NET_ERROR_CACHE_OPEN_OR_CREATE_FAILURE: + return "CRONET_NET_ERROR_CACHE_OPEN_OR_CREATE_FAILURE"; + case CRONET_NET_ERROR_INSECURE_RESPONSE: + return "CRONET_NET_ERROR_INSECURE_RESPONSE"; + case CRONET_NET_ERROR_NO_PRIVATE_KEY_FOR_CERT: + return "CRONET_NET_ERROR_NO_PRIVATE_KEY_FOR_CERT"; + case CRONET_NET_ERROR_ADD_USER_CERT_FAILED: + return "CRONET_NET_ERROR_ADD_USER_CERT_FAILED"; + case CRONET_NET_ERROR_INVALID_SIGNED_EXCHANGE: + return "CRONET_NET_ERROR_INVALID_SIGNED_EXCHANGE"; + case CRONET_NET_ERROR_INVALID_WEB_BUNDLE: + return "CRONET_NET_ERROR_INVALID_WEB_BUNDLE"; + case CRONET_NET_ERROR_TRUST_TOKEN_OPERATION_FAILED: + return "CRONET_NET_ERROR_TRUST_TOKEN_OPERATION_FAILED"; + case CRONET_NET_ERROR_TRUST_TOKEN_OPERATION_CACHE_HIT: + return "CRONET_NET_ERROR_TRUST_TOKEN_OPERATION_CACHE_HIT"; + case CRONET_NET_ERROR_FTP_FAILED: + return "CRONET_NET_ERROR_FTP_FAILED"; + case CRONET_NET_ERROR_FTP_SERVICE_UNAVAILABLE: + return "CRONET_NET_ERROR_FTP_SERVICE_UNAVAILABLE"; + case CRONET_NET_ERROR_FTP_TRANSFER_ABORTED: + return "CRONET_NET_ERROR_FTP_TRANSFER_ABORTED"; + case CRONET_NET_ERROR_FTP_FILE_BUSY: + return "CRONET_NET_ERROR_FTP_FILE_BUSY"; + case CRONET_NET_ERROR_FTP_SYNTAX_ERROR: + return "CRONET_NET_ERROR_FTP_SYNTAX_ERROR"; + case 
CRONET_NET_ERROR_FTP_COMMAND_NOT_SUPPORTED: + return "CRONET_NET_ERROR_FTP_COMMAND_NOT_SUPPORTED"; + case CRONET_NET_ERROR_FTP_BAD_COMMAND_SEQUENCE: + return "CRONET_NET_ERROR_FTP_BAD_COMMAND_SEQUENCE"; + case CRONET_NET_ERROR_PKCS12_IMPORT_BAD_PASSWORD: + return "CRONET_NET_ERROR_PKCS12_IMPORT_BAD_PASSWORD"; + case CRONET_NET_ERROR_PKCS12_IMPORT_FAILED: + return "CRONET_NET_ERROR_PKCS12_IMPORT_FAILED"; + case CRONET_NET_ERROR_IMPORT_CA_CERT_NOT_CA: + return "CRONET_NET_ERROR_IMPORT_CA_CERT_NOT_CA"; + case CRONET_NET_ERROR_IMPORT_CERT_ALREADY_EXISTS: + return "CRONET_NET_ERROR_IMPORT_CERT_ALREADY_EXISTS"; + case CRONET_NET_ERROR_IMPORT_CA_CERT_FAILED: + return "CRONET_NET_ERROR_IMPORT_CA_CERT_FAILED"; + case CRONET_NET_ERROR_IMPORT_SERVER_CERT_FAILED: + return "CRONET_NET_ERROR_IMPORT_SERVER_CERT_FAILED"; + case CRONET_NET_ERROR_PKCS12_IMPORT_INVALID_MAC: + return "CRONET_NET_ERROR_PKCS12_IMPORT_INVALID_MAC"; + case CRONET_NET_ERROR_PKCS12_IMPORT_INVALID_FILE: + return "CRONET_NET_ERROR_PKCS12_IMPORT_INVALID_FILE"; + case CRONET_NET_ERROR_PKCS12_IMPORT_UNSUPPORTED: + return "CRONET_NET_ERROR_PKCS12_IMPORT_UNSUPPORTED"; + case CRONET_NET_ERROR_KEY_GENERATION_FAILED: + return "CRONET_NET_ERROR_KEY_GENERATION_FAILED"; + case CRONET_NET_ERROR_PRIVATE_KEY_EXPORT_FAILED: + return "CRONET_NET_ERROR_PRIVATE_KEY_EXPORT_FAILED"; + case CRONET_NET_ERROR_SELF_SIGNED_CERT_GENERATION_FAILED: + return "CRONET_NET_ERROR_SELF_SIGNED_CERT_GENERATION_FAILED"; + case CRONET_NET_ERROR_CERT_DATABASE_CHANGED: + return "CRONET_NET_ERROR_CERT_DATABASE_CHANGED"; + case CRONET_NET_ERROR_DNS_MALFORMED_RESPONSE: + return "CRONET_NET_ERROR_DNS_MALFORMED_RESPONSE"; + case CRONET_NET_ERROR_DNS_SERVER_REQUIRES_TCP: + return "CRONET_NET_ERROR_DNS_SERVER_REQUIRES_TCP"; + case CRONET_NET_ERROR_DNS_SERVER_FAILED: + return "CRONET_NET_ERROR_DNS_SERVER_FAILED"; + case CRONET_NET_ERROR_DNS_TIMED_OUT: + return "CRONET_NET_ERROR_DNS_TIMED_OUT"; + case CRONET_NET_ERROR_DNS_CACHE_MISS: + return "CRONET_NET_ERROR_DNS_CACHE_MISS"; + case CRONET_NET_ERROR_DNS_SEARCH_EMPTY: + return "CRONET_NET_ERROR_DNS_SEARCH_EMPTY"; + case CRONET_NET_ERROR_DNS_SORT_ERROR: + return "CRONET_NET_ERROR_DNS_SORT_ERROR"; + case CRONET_NET_ERROR_DNS_SECURE_RESOLVER_HOSTNAME_RESOLUTION_FAILED: + return "CRONET_NET_ERROR_DNS_SECURE_RESOLVER_HOSTNAME_RESOLUTION_FAILED"; + } + return "UNAVAILABLE."; +} + +grpc_status_code cronet_net_error_to_grpc_error( + cronet_net_error_code net_error) { + switch (net_error) { + case OK: + return GRPC_STATUS_OK; + case CRONET_NET_ERROR_ABORTED: + return GRPC_STATUS_ABORTED; + case CRONET_NET_ERROR_ACCESS_DENIED: + case CRONET_NET_ERROR_NETWORK_ACCESS_DENIED: + return GRPC_STATUS_PERMISSION_DENIED; + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_NEEDED: + case CRONET_NET_ERROR_PROXY_AUTH_UNSUPPORTED: + case CRONET_NET_ERROR_BAD_SSL_CLIENT_AUTH_CERT: + case CRONET_NET_ERROR_PROXY_AUTH_REQUESTED: + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_PRIVATE_KEY_ACCESS_DENIED: + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY: + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_SIGNATURE_FAILED: + case CRONET_NET_ERROR_CLIENT_AUTH_CERT_TYPE_UNSUPPORTED: + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT: + case CRONET_NET_ERROR_SSL_CLIENT_AUTH_NO_COMMON_ALGORITHMS: + case CRONET_NET_ERROR_CERT_AUTHORITY_INVALID: + case CRONET_NET_ERROR_UNEXPECTED_PROXY_AUTH: + case CRONET_NET_ERROR_MALFORMED_IDENTITY: + case CRONET_NET_ERROR_INVALID_AUTH_CREDENTIALS: + case CRONET_NET_ERROR_UNSUPPORTED_AUTH_SCHEME: + case 
CRONET_NET_ERROR_MISSING_AUTH_CREDENTIALS: + return GRPC_STATUS_UNAUTHENTICATED; + default: + return GRPC_STATUS_UNAVAILABLE; + } +} diff --git a/src/core/ext/transport/cronet/transport/cronet_transport.cc b/src/core/ext/transport/cronet/transport/cronet_transport.cc new file mode 100644 index 00000000..4c719e5b --- /dev/null +++ b/src/core/ext/transport/cronet/transport/cronet_transport.cc @@ -0,0 +1,1533 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/cronet/transport/cronet_transport.h" + +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "third_party/objective_c/Cronet/bidirectional_stream_c.h" + +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/bin_decoder.h" +#include "src/core/ext/transport/chttp2/transport/bin_encoder.h" +#include "src/core/ext/transport/cronet/transport/cronet_status.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/validate_metadata.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/timeout_encoding.h" +#include "src/core/lib/transport/transport_impl.h" + +#define GRPC_HEADER_SIZE_IN_BYTES 5 +#define GRPC_FLUSH_READ_SIZE 4096 + +grpc_core::TraceFlag grpc_cronet_trace(false, "cronet"); +#define CRONET_LOG(...) \ + do { \ + if (grpc_cronet_trace.enabled()) gpr_log(__VA_ARGS__); \ + } while (0) + +enum e_op_result { + ACTION_TAKEN_WITH_CALLBACK, + ACTION_TAKEN_NO_CALLBACK, + NO_ACTION_POSSIBLE +}; + +enum e_op_id { + OP_SEND_INITIAL_METADATA = 0, + OP_SEND_MESSAGE, + OP_SEND_TRAILING_METADATA, + OP_RECV_MESSAGE, + OP_RECV_INITIAL_METADATA, + OP_RECV_TRAILING_METADATA, + OP_CANCEL_ERROR, + OP_ON_COMPLETE, + OP_FAILED, + OP_SUCCEEDED, + OP_CANCELED, + OP_RECV_MESSAGE_AND_ON_COMPLETE, + OP_READ_REQ_MADE, + OP_NUM_OPS +}; + +/* Cronet callbacks. See cronet_c_for_grpc.h for documentation for each. 
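+   These are invoked by Cronet on its network thread; see the comment on
+   execute_from_storage() below for how they interact with pending stream ops.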
*/ + +static void on_stream_ready(bidirectional_stream*); +static void on_response_headers_received( + bidirectional_stream*, const bidirectional_stream_header_array*, + const char*); +static void on_write_completed(bidirectional_stream*, const char*); +static void on_read_completed(bidirectional_stream*, char*, int); +static void on_response_trailers_received( + bidirectional_stream*, const bidirectional_stream_header_array*); +static void on_succeeded(bidirectional_stream*); +static void on_failed(bidirectional_stream*, int); +static void on_canceled(bidirectional_stream*); +static bidirectional_stream_callback cronet_callbacks = { + on_stream_ready, + on_response_headers_received, + on_read_completed, + on_write_completed, + on_response_trailers_received, + on_succeeded, + on_failed, + on_canceled}; + +/* Cronet transport object */ +struct grpc_cronet_transport { + grpc_transport base; /* must be first element in this structure */ + stream_engine* engine; + char* host; + bool use_packet_coalescing; +}; +typedef struct grpc_cronet_transport grpc_cronet_transport; + +/* TODO (makdharma): reorder structure for memory efficiency per + http://www.catb.org/esr/structure-packing/#_structure_reordering: */ +struct read_state { + explicit read_state(grpc_core::Arena* arena) + : trailing_metadata(arena), initial_metadata(arena) { + grpc_slice_buffer_init(&read_slice_buffer); + } + + /* vars to store data coming from server */ + char* read_buffer = nullptr; + bool length_field_received = false; + int received_bytes = 0; + int remaining_bytes = 0; + int length_field = 0; + bool compressed = false; + char grpc_header_bytes[GRPC_HEADER_SIZE_IN_BYTES] = {}; + char* payload_field = nullptr; + bool read_stream_closed = false; + + /* vars for holding data destined for the application */ + grpc_core::ManualConstructor sbs; + grpc_slice_buffer read_slice_buffer; + + /* vars for trailing metadata */ + grpc_metadata_batch trailing_metadata; + bool trailing_metadata_valid = false; + + /* vars for initial metadata */ + grpc_metadata_batch initial_metadata; +}; + +struct write_state { + char* write_buffer = nullptr; +}; + +/* track state of one stream op */ +struct op_state { + explicit op_state(grpc_core::Arena* arena) : rs(arena) {} + + bool state_op_done[OP_NUM_OPS] = {}; + bool state_callback_received[OP_NUM_OPS] = {}; + /* A non-zero gRPC status code has been seen */ + bool fail_state = false; + /* Transport is discarding all buffered messages */ + bool flush_read = false; + bool flush_cronet_when_ready = false; + bool pending_write_for_trailer = false; + bool pending_send_message = false; + /* User requested RECV_TRAILING_METADATA */ + bool pending_recv_trailing_metadata = false; + cronet_net_error_code net_error = OK; + grpc_error_handle cancel_error = GRPC_ERROR_NONE; + /* data structure for storing data coming from server */ + struct read_state rs; + /* data structure for storing data going to the server */ + struct write_state ws; +}; + +struct stream_obj; + +struct op_and_state { + op_and_state(stream_obj* s, const grpc_transport_stream_op_batch& op); + + grpc_transport_stream_op_batch op; + struct op_state state; + bool done = false; + struct stream_obj* s; /* Pointer back to the stream object */ + /* next op_and_state in the linked list */ + struct op_and_state* next = nullptr; +}; + +struct op_storage { + int num_pending_ops = 0; + struct op_and_state* head = nullptr; +}; + +struct stream_obj { + stream_obj(grpc_transport* gt, grpc_stream* gs, + grpc_stream_refcount* refcount, grpc_core::Arena* 
arena); + ~stream_obj(); + + grpc_core::Arena* arena; + struct op_and_state* oas = nullptr; + grpc_transport_stream_op_batch* curr_op = nullptr; + grpc_cronet_transport* curr_ct; + grpc_stream* curr_gs; + bidirectional_stream* cbs = nullptr; + bidirectional_stream_header_array header_array = + bidirectional_stream_header_array(); // Zero-initialize the structure. + + /* Stream level state. Some state will be tracked both at stream and stream_op + * level */ + struct op_state state; + + /* OP storage */ + struct op_storage storage; + + /* Mutex to protect storage */ + gpr_mu mu; + + /* Refcount object of the stream */ + grpc_stream_refcount* refcount; +}; + +#ifndef NDEBUG +#define GRPC_CRONET_STREAM_REF(stream, reason) \ + grpc_cronet_stream_ref((stream), (reason)) +#define GRPC_CRONET_STREAM_UNREF(stream, reason) \ + grpc_cronet_stream_unref((stream), (reason)) +void grpc_cronet_stream_ref(stream_obj* s, const char* reason) { + grpc_stream_ref(s->refcount, reason); +} +void grpc_cronet_stream_unref(stream_obj* s, const char* reason) { + grpc_stream_unref(s->refcount, reason); +} +#else +#define GRPC_CRONET_STREAM_REF(stream, reason) grpc_cronet_stream_ref((stream)) +#define GRPC_CRONET_STREAM_UNREF(stream, reason) \ + grpc_cronet_stream_unref((stream)) +void grpc_cronet_stream_ref(stream_obj* s) { grpc_stream_ref(s->refcount); } +void grpc_cronet_stream_unref(stream_obj* s) { grpc_stream_unref(s->refcount); } +#endif + +static enum e_op_result execute_stream_op(struct op_and_state* oas); + +/* + Utility function to translate enum into string for printing +*/ +static const char* op_result_string(enum e_op_result i) { + switch (i) { + case ACTION_TAKEN_WITH_CALLBACK: + return "ACTION_TAKEN_WITH_CALLBACK"; + case ACTION_TAKEN_NO_CALLBACK: + return "ACTION_TAKEN_NO_CALLBACK"; + case NO_ACTION_POSSIBLE: + return "NO_ACTION_POSSIBLE"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +static const char* op_id_string(enum e_op_id i) { + switch (i) { + case OP_SEND_INITIAL_METADATA: + return "OP_SEND_INITIAL_METADATA"; + case OP_SEND_MESSAGE: + return "OP_SEND_MESSAGE"; + case OP_SEND_TRAILING_METADATA: + return "OP_SEND_TRAILING_METADATA"; + case OP_RECV_MESSAGE: + return "OP_RECV_MESSAGE"; + case OP_RECV_INITIAL_METADATA: + return "OP_RECV_INITIAL_METADATA"; + case OP_RECV_TRAILING_METADATA: + return "OP_RECV_TRAILING_METADATA"; + case OP_CANCEL_ERROR: + return "OP_CANCEL_ERROR"; + case OP_ON_COMPLETE: + return "OP_ON_COMPLETE"; + case OP_FAILED: + return "OP_FAILED"; + case OP_SUCCEEDED: + return "OP_SUCCEEDED"; + case OP_CANCELED: + return "OP_CANCELED"; + case OP_RECV_MESSAGE_AND_ON_COMPLETE: + return "OP_RECV_MESSAGE_AND_ON_COMPLETE"; + case OP_READ_REQ_MADE: + return "OP_READ_REQ_MADE"; + case OP_NUM_OPS: + return "OP_NUM_OPS"; + } + return "UNKNOWN"; +} + +static void null_and_maybe_free_read_buffer(stream_obj* s) { + if (s->state.rs.read_buffer && + s->state.rs.read_buffer != s->state.rs.grpc_header_bytes) { + gpr_free(s->state.rs.read_buffer); + } + s->state.rs.read_buffer = nullptr; +} + +static void read_grpc_header(stream_obj* s) { + s->state.rs.read_buffer = s->state.rs.grpc_header_bytes; + s->state.rs.remaining_bytes = GRPC_HEADER_SIZE_IN_BYTES; + s->state.rs.received_bytes = 0; + s->state.rs.compressed = false; + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_read(%p)", s->cbs); + bidirectional_stream_read(s->cbs, s->state.rs.read_buffer, + s->state.rs.remaining_bytes); +} + +static grpc_error_handle make_error_with_desc(int error_code, + int cronet_internal_error_code, + const 
char* desc) { + return grpc_error_set_int(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Cronet error code:%d, Cronet error detail:%s", + cronet_internal_error_code, desc)), + GRPC_ERROR_INT_GRPC_STATUS, error_code); +} + +inline op_and_state::op_and_state(stream_obj* s, + const grpc_transport_stream_op_batch& op) + : op(op), state(s->arena), s(s) {} + +/* + Add a new stream op to op storage. +*/ +static void add_to_storage(struct stream_obj* s, + grpc_transport_stream_op_batch* op) { + struct op_storage* storage = &s->storage; + /* add new op at the beginning of the linked list. The memory is freed + in remove_from_storage */ + op_and_state* new_op = new op_and_state(s, *op); + gpr_mu_lock(&s->mu); + new_op->next = storage->head; + storage->head = new_op; + storage->num_pending_ops++; + if (op->send_message) { + s->state.pending_send_message = true; + } + if (op->recv_trailing_metadata) { + s->state.pending_recv_trailing_metadata = true; + } + CRONET_LOG(GPR_DEBUG, "adding new op %p. %d in the queue.", new_op, + storage->num_pending_ops); + gpr_mu_unlock(&s->mu); +} + +/* + Traverse the linked list and delete op and free memory +*/ +static void remove_from_storage(struct stream_obj* s, + struct op_and_state* oas) { + struct op_and_state* curr; + if (s->storage.head == nullptr || oas == nullptr) { + return; + } + if (s->storage.head == oas) { + s->storage.head = oas->next; + delete oas; + s->storage.num_pending_ops--; + CRONET_LOG(GPR_DEBUG, "Freed %p. Now %d in the queue", oas, + s->storage.num_pending_ops); + } else { + for (curr = s->storage.head; curr != nullptr; curr = curr->next) { + if (curr->next == oas) { + curr->next = oas->next; + s->storage.num_pending_ops--; + CRONET_LOG(GPR_DEBUG, "Freed %p. Now %d in the queue", oas, + s->storage.num_pending_ops); + delete oas; + break; + } else if (GPR_UNLIKELY(curr->next == nullptr)) { + CRONET_LOG(GPR_ERROR, "Reached end of LL and did not find op to free"); + } + } + } +} + +/* + Cycle through ops and try to take next action. Break when either + an action with callback is taken, or no action is possible. + This can get executed from the Cronet network thread via cronet callback + or on the application supplied thread via the perform_stream_op function. +*/ +static void execute_from_storage(stream_obj* s) { + gpr_mu_lock(&s->mu); + for (struct op_and_state* curr = s->storage.head; curr != nullptr;) { + CRONET_LOG(GPR_DEBUG, "calling op at %p. 
done = %d", curr, curr->done); + GPR_ASSERT(!curr->done); + enum e_op_result result = execute_stream_op(curr); + CRONET_LOG(GPR_DEBUG, "execute_stream_op[%p] returns %s", curr, + op_result_string(result)); + /* if this op is done, then remove it and free memory */ + if (curr->done) { + struct op_and_state* next = curr->next; + remove_from_storage(s, curr); + curr = next; + } else if (result == NO_ACTION_POSSIBLE) { + curr = curr->next; + } else if (result == ACTION_TAKEN_WITH_CALLBACK) { + /* wait for the callback */ + break; + } /* continue processing the same op if ACTION_TAKEN_WITHOUT_CALLBACK */ + } + gpr_mu_unlock(&s->mu); +} + +static void convert_cronet_array_to_metadata( + const bidirectional_stream_header_array* header_array, + grpc_metadata_batch* mds) { + for (size_t i = 0; i < header_array->count; i++) { + CRONET_LOG(GPR_DEBUG, "header key=%s, value=%s", + header_array->headers[i].key, header_array->headers[i].value); + grpc_slice value; + if (absl::EndsWith(header_array->headers[i].key, "-bin")) { + value = grpc_slice_from_static_string(header_array->headers[i].value); + value = grpc_slice_intern(grpc_chttp2_base64_decode_with_length( + value, grpc_chttp2_base64_infer_length_after_decode(value))); + } else { + value = grpc_slice_intern( + grpc_slice_from_static_string(header_array->headers[i].value)); + } + mds->Append(header_array->headers[i].key, value); + } +} + +/* + Cronet callback +*/ +static void on_failed(bidirectional_stream* stream, int net_error) { + gpr_log(GPR_ERROR, "on_failed(%p, %d)", stream, net_error); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + + stream_obj* s = static_cast(stream->annotation); + gpr_mu_lock(&s->mu); + bidirectional_stream_destroy(s->cbs); + s->state.state_callback_received[OP_FAILED] = true; + s->state.net_error = static_cast(net_error); + s->cbs = nullptr; + if (s->header_array.headers) { + gpr_free(s->header_array.headers); + s->header_array.headers = nullptr; + } + if (s->state.ws.write_buffer) { + gpr_free(s->state.ws.write_buffer); + s->state.ws.write_buffer = nullptr; + } + null_and_maybe_free_read_buffer(s); + gpr_mu_unlock(&s->mu); + execute_from_storage(s); + GRPC_CRONET_STREAM_UNREF(s, "cronet transport"); +} + +/* + Cronet callback +*/ +static void on_canceled(bidirectional_stream* stream) { + CRONET_LOG(GPR_DEBUG, "on_canceled(%p)", stream); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + + stream_obj* s = static_cast(stream->annotation); + gpr_mu_lock(&s->mu); + bidirectional_stream_destroy(s->cbs); + s->state.state_callback_received[OP_CANCELED] = true; + s->cbs = nullptr; + if (s->header_array.headers) { + gpr_free(s->header_array.headers); + s->header_array.headers = nullptr; + } + if (s->state.ws.write_buffer) { + gpr_free(s->state.ws.write_buffer); + s->state.ws.write_buffer = nullptr; + } + null_and_maybe_free_read_buffer(s); + gpr_mu_unlock(&s->mu); + execute_from_storage(s); + GRPC_CRONET_STREAM_UNREF(s, "cronet transport"); +} + +/* + Cronet callback +*/ +static void on_succeeded(bidirectional_stream* stream) { + CRONET_LOG(GPR_DEBUG, "on_succeeded(%p)", stream); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + + stream_obj* s = static_cast(stream->annotation); + gpr_mu_lock(&s->mu); + bidirectional_stream_destroy(s->cbs); + s->state.state_callback_received[OP_SUCCEEDED] = true; + s->cbs = nullptr; + null_and_maybe_free_read_buffer(s); + gpr_mu_unlock(&s->mu); + execute_from_storage(s); 
+ GRPC_CRONET_STREAM_UNREF(s, "cronet transport"); +} + +/* + Cronet callback +*/ +static void on_stream_ready(bidirectional_stream* stream) { + CRONET_LOG(GPR_DEBUG, "W: on_stream_ready(%p)", stream); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + stream_obj* s = static_cast(stream->annotation); + grpc_cronet_transport* t = s->curr_ct; + gpr_mu_lock(&s->mu); + s->state.state_op_done[OP_SEND_INITIAL_METADATA] = true; + s->state.state_callback_received[OP_SEND_INITIAL_METADATA] = true; + /* Free the memory allocated for headers */ + if (s->header_array.headers) { + gpr_free(s->header_array.headers); + s->header_array.headers = nullptr; + } + /* Send the initial metadata on wire if there is no SEND_MESSAGE or + * SEND_TRAILING_METADATA ops pending */ + if (t->use_packet_coalescing) { + if (s->state.flush_cronet_when_ready) { + CRONET_LOG(GPR_DEBUG, "cronet_bidirectional_stream_flush (%p)", s->cbs); + bidirectional_stream_flush(stream); + } + } + gpr_mu_unlock(&s->mu); + execute_from_storage(s); +} + +/* + Cronet callback +*/ +static void on_response_headers_received( + bidirectional_stream* stream, + const bidirectional_stream_header_array* headers, + const char* negotiated_protocol) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + CRONET_LOG(GPR_DEBUG, "R: on_response_headers_received(%p, %p, %s)", stream, + headers, negotiated_protocol); + stream_obj* s = static_cast(stream->annotation); + + /* Identify if this is a header or a trailer (in a trailer-only response case) + */ + for (size_t i = 0; i < headers->count; i++) { + if (0 == strcmp("grpc-status", headers->headers[i].key)) { + on_response_trailers_received(stream, headers); + + /* Do an extra read for a trailer-only stream to trigger on_succeeded() + * callback */ + read_grpc_header(s); + return; + } + } + + gpr_mu_lock(&s->mu); + convert_cronet_array_to_metadata(headers, &s->state.rs.initial_metadata); + s->state.state_callback_received[OP_RECV_INITIAL_METADATA] = true; + if (!(s->state.state_op_done[OP_CANCEL_ERROR] || + s->state.state_callback_received[OP_FAILED])) { + /* Do an extra read to trigger on_succeeded() callback in case connection + is closed */ + GPR_ASSERT(s->state.rs.length_field_received == false); + read_grpc_header(s); + } + gpr_mu_unlock(&s->mu); + execute_from_storage(s); +} + +/* + Cronet callback +*/ +static void on_write_completed(bidirectional_stream* stream, const char* data) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + stream_obj* s = static_cast(stream->annotation); + CRONET_LOG(GPR_DEBUG, "W: on_write_completed(%p, %s)", stream, data); + gpr_mu_lock(&s->mu); + if (s->state.ws.write_buffer) { + gpr_free(s->state.ws.write_buffer); + s->state.ws.write_buffer = nullptr; + } + s->state.state_callback_received[OP_SEND_MESSAGE] = true; + gpr_mu_unlock(&s->mu); + execute_from_storage(s); +} + +/* + Cronet callback +*/ +static void on_read_completed(bidirectional_stream* stream, char* data, + int count) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + stream_obj* s = static_cast(stream->annotation); + CRONET_LOG(GPR_DEBUG, "R: on_read_completed(%p, %p, %d)", stream, data, + count); + gpr_mu_lock(&s->mu); + s->state.state_callback_received[OP_RECV_MESSAGE] = true; + if (count > 0 && s->state.flush_read) { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_read(%p)", s->cbs); + bidirectional_stream_read(s->cbs, s->state.rs.read_buffer, + 
GRPC_FLUSH_READ_SIZE); + gpr_mu_unlock(&s->mu); + } else if (count > 0) { + s->state.rs.received_bytes += count; + s->state.rs.remaining_bytes -= count; + if (s->state.rs.remaining_bytes > 0) { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_read(%p)", s->cbs); + s->state.state_op_done[OP_READ_REQ_MADE] = true; + bidirectional_stream_read( + s->cbs, s->state.rs.read_buffer + s->state.rs.received_bytes, + s->state.rs.remaining_bytes); + gpr_mu_unlock(&s->mu); + } else { + gpr_mu_unlock(&s->mu); + execute_from_storage(s); + } + } else { + null_and_maybe_free_read_buffer(s); + s->state.rs.read_stream_closed = true; + gpr_mu_unlock(&s->mu); + execute_from_storage(s); + } +} + +/* + Cronet callback +*/ +static void on_response_trailers_received( + bidirectional_stream* stream, + const bidirectional_stream_header_array* trailers) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + CRONET_LOG(GPR_DEBUG, "R: on_response_trailers_received(%p,%p)", stream, + trailers); + stream_obj* s = static_cast(stream->annotation); + grpc_cronet_transport* t = s->curr_ct; + gpr_mu_lock(&s->mu); + s->state.rs.trailing_metadata_valid = false; + convert_cronet_array_to_metadata(trailers, &s->state.rs.trailing_metadata); + if (trailers->count > 0) { + s->state.rs.trailing_metadata_valid = true; + } + s->state.state_callback_received[OP_RECV_TRAILING_METADATA] = true; + /* Send a EOS when server terminates the stream (testServerFinishesRequest) to + * trigger on_succeeded */ + if (!s->state.state_op_done[OP_SEND_TRAILING_METADATA] && + !(s->state.state_op_done[OP_CANCEL_ERROR] || + s->state.state_callback_received[OP_FAILED])) { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_write (%p, 0)", s->cbs); + s->state.state_callback_received[OP_SEND_MESSAGE] = false; + bidirectional_stream_write(s->cbs, "", 0, true); + if (t->use_packet_coalescing) { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_flush (%p)", s->cbs); + bidirectional_stream_flush(s->cbs); + } + s->state.state_op_done[OP_SEND_TRAILING_METADATA] = true; + + gpr_mu_unlock(&s->mu); + } else { + gpr_mu_unlock(&s->mu); + execute_from_storage(s); + } +} + +/* + Utility function that takes the data from s->write_slice_buffer and assembles + into a contiguous byte stream with 5 byte gRPC header prepended. +*/ +static void create_grpc_frame(grpc_slice_buffer* write_slice_buffer, + char** pp_write_buffer, + size_t* p_write_buffer_size, uint32_t flags) { + size_t length = write_slice_buffer->length; + *p_write_buffer_size = length + GRPC_HEADER_SIZE_IN_BYTES; + /* This is freed in the on_write_completed callback */ + char* write_buffer = + static_cast(gpr_malloc(length + GRPC_HEADER_SIZE_IN_BYTES)); + *pp_write_buffer = write_buffer; + uint8_t* p = reinterpret_cast(write_buffer); + /* Append 5 byte header */ + /* Compressed flag */ + *p++ = static_cast((flags & GRPC_WRITE_INTERNAL_COMPRESS) ? 
1 : 0); + /* Message length */ + *p++ = static_cast(length >> 24); + *p++ = static_cast(length >> 16); + *p++ = static_cast(length >> 8); + *p++ = static_cast(length); + /* append actual data */ + size_t offset = 0; + for (size_t i = 0; i < write_slice_buffer->count; ++i) { + memcpy(p + offset, GRPC_SLICE_START_PTR(write_slice_buffer->slices[i]), + GRPC_SLICE_LENGTH(write_slice_buffer->slices[i])); + offset += GRPC_SLICE_LENGTH(write_slice_buffer->slices[i]); + } +} + +namespace { +class CronetMetadataEncoder { + public: + explicit CronetMetadataEncoder(bidirectional_stream_header** pp_headers, + size_t* p_count, const char* host, + size_t capacity, const char** method, + std::string* url) + : host_(host), + capacity_(capacity), + count_(*p_count), + headers_(*pp_headers), + method_(method), + url_(url) { + count_ = 0; + headers_ = static_cast( + gpr_malloc(sizeof(bidirectional_stream_header) * capacity_)); + } + + CronetMetadataEncoder(const CronetMetadataEncoder&) = delete; + CronetMetadataEncoder& operator=(const CronetMetadataEncoder&) = delete; + + template + void Encode(T, V value) { + auto value_slice = T::Encode(value); + auto key_slice = grpc_slice_from_static_string(T::key()); + auto mdelem = grpc_mdelem_from_slices(key_slice, value_slice); + Encode(mdelem); + GRPC_MDELEM_UNREF(mdelem); + } + + void Encode(grpc_mdelem mdelem) { + char* key = grpc_slice_to_c_string(GRPC_MDKEY(mdelem)); + char* value; + if (grpc_is_binary_header_internal(GRPC_MDKEY(mdelem))) { + grpc_slice wire_value = grpc_chttp2_base64_encode(GRPC_MDVALUE(mdelem)); + value = grpc_slice_to_c_string(wire_value); + grpc_slice_unref_internal(wire_value); + } else { + value = grpc_slice_to_c_string(GRPC_MDVALUE(mdelem)); + } + if (grpc_slice_eq_static_interned(GRPC_MDKEY(mdelem), GRPC_MDSTR_SCHEME) || + grpc_slice_eq_static_interned(GRPC_MDKEY(mdelem), + GRPC_MDSTR_AUTHORITY)) { + /* Cronet populates these fields on its own */ + gpr_free(key); + gpr_free(value); + return; + } + if (grpc_slice_eq_static_interned(GRPC_MDKEY(mdelem), GRPC_MDSTR_METHOD)) { + if (grpc_slice_eq_static_interned(GRPC_MDVALUE(mdelem), GRPC_MDSTR_PUT)) { + *method_ = "PUT"; + } else { + /* POST method in default*/ + *method_ = "POST"; + } + gpr_free(key); + gpr_free(value); + return; + } + if (grpc_slice_eq_static_interned(GRPC_MDKEY(mdelem), GRPC_MDSTR_PATH)) { + /* Create URL by appending :path value to the hostname */ + *url_ = absl::StrCat("https://", host_, value); + gpr_free(key); + gpr_free(value); + return; + } + CRONET_LOG(GPR_DEBUG, "header %s = %s", key, value); + GPR_ASSERT(count_ < capacity_); + headers_[count_].key = key; + headers_[count_].value = value; + ++count_; + } + + private: + const char* host_; + size_t capacity_; + size_t& count_; + bidirectional_stream_header*& headers_; + const char** method_; + std::string* url_; +}; +} // namespace + +/* + Convert metadata in a format that Cronet can consume +*/ +static void convert_metadata_to_cronet_headers( + grpc_metadata_batch* metadata, const char* host, std::string* pp_url, + bidirectional_stream_header** pp_headers, size_t* p_num_headers, + const char** method) { + CronetMetadataEncoder encoder(pp_headers, p_num_headers, host, + metadata->count(), method, pp_url); + metadata->Encode(&encoder); +} + +static void parse_grpc_header(const uint8_t* data, int* length, + bool* compressed) { + const uint8_t c = *data; + const uint8_t* p = data + 1; + *compressed = ((c & 0x01) == 0x01); + *length = 0; + *length |= (*p++) << 24; + *length |= (*p++) << 16; + *length |= (*p++) << 8; + 
*length |= (*p++); +} + +static bool header_has_authority(const grpc_metadata_batch* b) { + bool found = false; + b->ForEach([&](grpc_mdelem elem) { + if (grpc_slice_eq_static_interned(GRPC_MDKEY(elem), GRPC_MDSTR_AUTHORITY)) { + found = true; + } + }); + return found; +} + +/* + Op Execution: Decide if one of the actions contained in the stream op can be + executed. This is the heart of the state machine. +*/ +static bool op_can_be_run(grpc_transport_stream_op_batch* curr_op, + struct stream_obj* s, struct op_state* op_state, + enum e_op_id op_id) { + struct op_state* stream_state = &s->state; + grpc_cronet_transport* t = s->curr_ct; + bool result = true; + /* When call is canceled, every op can be run, except under following + conditions + */ + bool is_canceled_or_failed = stream_state->state_op_done[OP_CANCEL_ERROR] || + stream_state->state_callback_received[OP_FAILED]; + if (is_canceled_or_failed) { + if (op_id == OP_SEND_INITIAL_METADATA) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + if (op_id == OP_SEND_MESSAGE) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + if (op_id == OP_SEND_TRAILING_METADATA) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + if (op_id == OP_CANCEL_ERROR) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + /* already executed */ + if (op_id == OP_RECV_INITIAL_METADATA && + stream_state->state_op_done[OP_RECV_INITIAL_METADATA]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + if (op_id == OP_RECV_MESSAGE && op_state->state_op_done[OP_RECV_MESSAGE]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + if (op_id == OP_RECV_TRAILING_METADATA && + stream_state->state_op_done[OP_RECV_TRAILING_METADATA]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + /* ON_COMPLETE can be processed if one of the following conditions is met: + * 1. the stream failed + * 2. the stream is cancelled, and the callback is received + * 3. the stream succeeded before cancel is effective + * 4. the stream is cancelled, and the stream is never started */ + if (op_id == OP_ON_COMPLETE && + !(stream_state->state_callback_received[OP_FAILED] || + stream_state->state_callback_received[OP_CANCELED] || + stream_state->state_callback_received[OP_SUCCEEDED] || + !stream_state->state_op_done[OP_SEND_INITIAL_METADATA])) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + } else if (op_id == OP_SEND_INITIAL_METADATA) { + /* already executed */ + if (stream_state->state_op_done[OP_SEND_INITIAL_METADATA]) result = false; + } else if (op_id == OP_RECV_INITIAL_METADATA) { + if (stream_state->state_op_done[OP_RECV_INITIAL_METADATA]) { + /* already executed */ + result = false; + } else if (!stream_state + ->state_callback_received[OP_SEND_INITIAL_METADATA]) { + /* we haven't sent headers yet. */ + result = false; + } else if (!stream_state + ->state_callback_received[OP_RECV_INITIAL_METADATA] && + !stream_state->state_op_done[OP_RECV_TRAILING_METADATA]) { + /* we haven't received headers yet. */ + result = false; + } + } else if (op_id == OP_SEND_MESSAGE) { + if (op_state->state_op_done[OP_SEND_MESSAGE]) { + /* already executed (note we're checking op specific state, not stream + state) */ + result = false; + } else if (!stream_state + ->state_callback_received[OP_SEND_INITIAL_METADATA]) { + /* we haven't sent headers yet. 
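+ (on_stream_ready has not fired, so it is not yet safe to issue
+ bidirectional_stream_write for this message.)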
*/ + result = false; + } + } else if (op_id == OP_RECV_MESSAGE) { + if (op_state->state_op_done[OP_RECV_MESSAGE]) { + /* already executed */ + result = false; + } else if (!stream_state + ->state_callback_received[OP_RECV_INITIAL_METADATA] && + !stream_state->state_op_done[OP_RECV_TRAILING_METADATA]) { + /* we haven't received headers yet. */ + result = false; + } + } else if (op_id == OP_RECV_TRAILING_METADATA) { + if (stream_state->state_op_done[OP_RECV_TRAILING_METADATA]) { + /* already executed */ + result = false; + } else if (stream_state->state_op_done[OP_READ_REQ_MADE] && + !stream_state->state_op_done[OP_RECV_MESSAGE]) { + /* we have asked for but haven't received message yet. */ + result = false; + } else if (!stream_state + ->state_callback_received[OP_RECV_TRAILING_METADATA]) { + /* we haven't received trailers yet. */ + result = false; + } else if (!stream_state->state_callback_received[OP_SUCCEEDED]) { + /* we haven't received on_succeeded yet. */ + result = false; + } + } else if (op_id == OP_SEND_TRAILING_METADATA) { + if (stream_state->state_op_done[OP_SEND_TRAILING_METADATA]) { + /* already executed */ + result = false; + } else if (!stream_state + ->state_callback_received[OP_SEND_INITIAL_METADATA]) { + /* we haven't sent initial metadata yet */ + result = false; + } else if (stream_state->pending_send_message && + !stream_state->state_op_done[OP_SEND_MESSAGE]) { + /* we haven't sent message yet */ + result = false; + } else if (stream_state->state_op_done[OP_SEND_MESSAGE] && + !stream_state->state_callback_received[OP_SEND_MESSAGE] && + !(t->use_packet_coalescing && + stream_state->pending_write_for_trailer)) { + /* we haven't got on_write_completed for the send yet */ + result = false; + } + } else if (op_id == OP_CANCEL_ERROR) { + /* already executed */ + if (stream_state->state_op_done[OP_CANCEL_ERROR]) result = false; + } else if (op_id == OP_ON_COMPLETE) { + if (op_state->state_op_done[OP_ON_COMPLETE]) { + /* already executed (note we're checking op specific state, not stream + state) */ + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + /* Check if every op that was asked for is done. */ + /* TODO(muxi): We should not consider the recv ops here, since they + * have their own callbacks. We should invoke a batch's on_complete + * as soon as all of the batch's send ops are complete, even if + * there are still recv ops pending. 
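+ * Until that change is made, on_complete is held back by the checks below
+ * until the batch's recv ops have finished as well.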
*/ + else if (curr_op->send_initial_metadata && + !stream_state->state_callback_received[OP_SEND_INITIAL_METADATA]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } else if (curr_op->send_message && + !op_state->state_op_done[OP_SEND_MESSAGE]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } else if (curr_op->send_message && + !stream_state->state_callback_received[OP_SEND_MESSAGE]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } else if (curr_op->send_trailing_metadata && + !stream_state->state_op_done[OP_SEND_TRAILING_METADATA]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } else if (curr_op->recv_initial_metadata && + !stream_state->state_op_done[OP_RECV_INITIAL_METADATA]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } else if (curr_op->recv_message && + !op_state->state_op_done[OP_RECV_MESSAGE]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } else if (curr_op->cancel_stream && + !stream_state->state_callback_received[OP_CANCELED]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } else if (curr_op->recv_trailing_metadata) { + /* We aren't done with trailing metadata yet */ + if (!stream_state->state_op_done[OP_RECV_TRAILING_METADATA]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + /* We've asked for actual message in an earlier op, and it hasn't been + delivered yet. */ + else if (stream_state->state_op_done[OP_READ_REQ_MADE]) { + /* If this op is not the one asking for read, (which means some earlier + op has asked), and the read hasn't been delivered. */ + if (!curr_op->recv_message && + !stream_state->state_callback_received[OP_SUCCEEDED]) { + CRONET_LOG(GPR_DEBUG, "Because"); + result = false; + } + } + } + /* We should see at least one on_write_completed for the trailers that we + sent */ + else if (curr_op->send_trailing_metadata && + !stream_state->state_callback_received[OP_SEND_MESSAGE]) { + result = false; + } + } + CRONET_LOG(GPR_DEBUG, "op_can_be_run %s : %s", op_id_string(op_id), + result ? "YES" : "NO"); + return result; +} + +/* + TODO (makdharma): Break down this function in smaller chunks for readability. +*/ +static enum e_op_result execute_stream_op(struct op_and_state* oas) { + grpc_transport_stream_op_batch* stream_op = &oas->op; + struct stream_obj* s = oas->s; + grpc_cronet_transport* t = s->curr_ct; + struct op_state* stream_state = &s->state; + enum e_op_result result = NO_ACTION_POSSIBLE; + if (stream_op->send_initial_metadata && + op_can_be_run(stream_op, s, &oas->state, OP_SEND_INITIAL_METADATA)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_SEND_INITIAL_METADATA", oas); + /* Start new cronet stream. 
It is destroyed in on_succeeded, on_canceled, + * on_failed */ + GPR_ASSERT(s->cbs == nullptr); + GPR_ASSERT(!stream_state->state_op_done[OP_SEND_INITIAL_METADATA]); + s->cbs = + bidirectional_stream_create(t->engine, s->curr_gs, &cronet_callbacks); + CRONET_LOG(GPR_DEBUG, "%p = bidirectional_stream_create()", s->cbs); + if (t->use_packet_coalescing) { + bidirectional_stream_disable_auto_flush(s->cbs, true); + bidirectional_stream_delay_request_headers_until_flush(s->cbs, true); + } + std::string url; + const char* method = "POST"; + s->header_array.headers = nullptr; + convert_metadata_to_cronet_headers( + stream_op->payload->send_initial_metadata.send_initial_metadata, + t->host, &url, &s->header_array.headers, &s->header_array.count, + &method); + s->header_array.capacity = s->header_array.count; + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_start(%p, %s)", s->cbs, + url.c_str()); + bidirectional_stream_start(s->cbs, url.c_str(), 0, method, &s->header_array, + false); + unsigned int header_index; + for (header_index = 0; header_index < s->header_array.count; + header_index++) { + gpr_free(const_cast(s->header_array.headers[header_index].key)); + gpr_free(const_cast(s->header_array.headers[header_index].value)); + } + stream_state->state_op_done[OP_SEND_INITIAL_METADATA] = true; + if (t->use_packet_coalescing) { + if (!stream_op->send_message && !stream_op->send_trailing_metadata) { + s->state.flush_cronet_when_ready = true; + } + } + result = ACTION_TAKEN_WITH_CALLBACK; + } else if (stream_op->send_message && + op_can_be_run(stream_op, s, &oas->state, OP_SEND_MESSAGE)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_SEND_MESSAGE", oas); + stream_state->pending_send_message = false; + if (stream_state->state_op_done[OP_CANCEL_ERROR] || + stream_state->state_callback_received[OP_FAILED] || + stream_state->state_callback_received[OP_SUCCEEDED]) { + result = NO_ACTION_POSSIBLE; + CRONET_LOG(GPR_DEBUG, "Stream is either cancelled, failed or finished"); + } else { + grpc_slice_buffer write_slice_buffer; + grpc_slice slice; + grpc_slice_buffer_init(&write_slice_buffer); + while (write_slice_buffer.length < + stream_op->payload->send_message.send_message->length()) { + /* TODO(roth): When we add support for incremental sending,this code + * will need to be changed to support asynchronous delivery of the + * send_message payload. 
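+ * For now the payload is fully buffered, so the Next()/Pull() calls below
+ * are expected to complete synchronously; the GPR_ASSERTs guard that
+ * assumption.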
*/ + if (!stream_op->payload->send_message.send_message->Next( + stream_op->payload->send_message.send_message->length(), + nullptr)) { + /* Should never reach here */ + GPR_ASSERT(false); + } + if (GRPC_ERROR_NONE != + stream_op->payload->send_message.send_message->Pull(&slice)) { + /* Should never reach here */ + GPR_ASSERT(false); + } + grpc_slice_buffer_add(&write_slice_buffer, slice); + } + size_t write_buffer_size; + create_grpc_frame(&write_slice_buffer, &stream_state->ws.write_buffer, + &write_buffer_size, + stream_op->payload->send_message.send_message->flags()); + if (write_buffer_size > 0) { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_write (%p, %p)", s->cbs, + stream_state->ws.write_buffer); + stream_state->state_callback_received[OP_SEND_MESSAGE] = false; + bidirectional_stream_write(s->cbs, stream_state->ws.write_buffer, + static_cast(write_buffer_size), false); + grpc_slice_buffer_destroy_internal(&write_slice_buffer); + if (t->use_packet_coalescing) { + if (!stream_op->send_trailing_metadata) { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_flush (%p)", s->cbs); + bidirectional_stream_flush(s->cbs); + result = ACTION_TAKEN_WITH_CALLBACK; + } else { + stream_state->pending_write_for_trailer = true; + result = ACTION_TAKEN_NO_CALLBACK; + } + } else { + result = ACTION_TAKEN_WITH_CALLBACK; + } + } else { + /* Should never reach here */ + GPR_ASSERT(false); + } + } + stream_state->state_op_done[OP_SEND_MESSAGE] = true; + oas->state.state_op_done[OP_SEND_MESSAGE] = true; + stream_op->payload->send_message.send_message.reset(); + } else if (stream_op->send_trailing_metadata && + op_can_be_run(stream_op, s, &oas->state, + OP_SEND_TRAILING_METADATA)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_SEND_TRAILING_METADATA", oas); + if (stream_state->state_op_done[OP_CANCEL_ERROR] || + stream_state->state_callback_received[OP_FAILED] || + stream_state->state_callback_received[OP_SUCCEEDED]) { + result = NO_ACTION_POSSIBLE; + CRONET_LOG(GPR_DEBUG, "Stream is either cancelled, failed or finished"); + } else { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_write (%p, 0)", s->cbs); + stream_state->state_callback_received[OP_SEND_MESSAGE] = false; + bidirectional_stream_write(s->cbs, "", 0, true); + if (t->use_packet_coalescing) { + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_flush (%p)", s->cbs); + bidirectional_stream_flush(s->cbs); + } + result = ACTION_TAKEN_WITH_CALLBACK; + } + stream_state->state_op_done[OP_SEND_TRAILING_METADATA] = true; + } else if (stream_op->recv_initial_metadata && + op_can_be_run(stream_op, s, &oas->state, + OP_RECV_INITIAL_METADATA)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_RECV_INITIAL_METADATA", oas); + if (stream_state->state_op_done[OP_CANCEL_ERROR]) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + stream_op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_NONE); + } else if (stream_state->state_callback_received[OP_FAILED]) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + stream_op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_NONE); + } else if (stream_state->state_op_done[OP_RECV_TRAILING_METADATA]) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + stream_op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_NONE); + } else { + *stream_op->payload->recv_initial_metadata.recv_initial_metadata = + std::move(oas->s->state.rs.initial_metadata); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + stream_op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_NONE); + } + 
stream_state->state_op_done[OP_RECV_INITIAL_METADATA] = true; + result = ACTION_TAKEN_NO_CALLBACK; + } else if (stream_op->recv_message && + op_can_be_run(stream_op, s, &oas->state, OP_RECV_MESSAGE)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_RECV_MESSAGE", oas); + if (stream_state->state_op_done[OP_CANCEL_ERROR]) { + CRONET_LOG(GPR_DEBUG, "Stream is cancelled."); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, stream_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + stream_state->state_op_done[OP_RECV_MESSAGE] = true; + oas->state.state_op_done[OP_RECV_MESSAGE] = true; + result = ACTION_TAKEN_NO_CALLBACK; + } else if (stream_state->state_callback_received[OP_FAILED]) { + CRONET_LOG(GPR_DEBUG, "Stream failed."); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, stream_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + stream_state->state_op_done[OP_RECV_MESSAGE] = true; + oas->state.state_op_done[OP_RECV_MESSAGE] = true; + result = ACTION_TAKEN_NO_CALLBACK; + } else if (stream_state->rs.read_stream_closed) { + /* No more data will be received */ + CRONET_LOG(GPR_DEBUG, "read stream closed"); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, stream_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + stream_state->state_op_done[OP_RECV_MESSAGE] = true; + oas->state.state_op_done[OP_RECV_MESSAGE] = true; + result = ACTION_TAKEN_NO_CALLBACK; + } else if (stream_state->flush_read) { + CRONET_LOG(GPR_DEBUG, "flush read"); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, stream_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + stream_state->state_op_done[OP_RECV_MESSAGE] = true; + oas->state.state_op_done[OP_RECV_MESSAGE] = true; + result = ACTION_TAKEN_NO_CALLBACK; + } else if (!stream_state->rs.length_field_received) { + if (stream_state->rs.received_bytes == GRPC_HEADER_SIZE_IN_BYTES && + stream_state->rs.remaining_bytes == 0) { + /* Start a read operation for data */ + stream_state->rs.length_field_received = true; + parse_grpc_header( + reinterpret_cast(stream_state->rs.read_buffer), + &stream_state->rs.length_field, &stream_state->rs.compressed); + CRONET_LOG(GPR_DEBUG, "length field = %d", + stream_state->rs.length_field); + if (stream_state->rs.length_field > 0) { + stream_state->rs.read_buffer = static_cast( + gpr_malloc(static_cast(stream_state->rs.length_field))); + GPR_ASSERT(stream_state->rs.read_buffer); + stream_state->rs.remaining_bytes = stream_state->rs.length_field; + stream_state->rs.received_bytes = 0; + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_read(%p)", s->cbs); + stream_state->state_op_done[OP_READ_REQ_MADE] = + true; /* Indicates that at least one read request has been made */ + bidirectional_stream_read(s->cbs, stream_state->rs.read_buffer, + stream_state->rs.remaining_bytes); + result = ACTION_TAKEN_WITH_CALLBACK; + } else { + stream_state->rs.remaining_bytes = 0; + CRONET_LOG(GPR_DEBUG, "read operation complete. Empty response."); + /* Clean up read_slice_buffer in case there is unread data. 
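+ The buffer is then re-initialized and wrapped by rs.sbs, which is handed
+ up as the (empty) received message.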
*/ + grpc_slice_buffer_destroy_internal( + &stream_state->rs.read_slice_buffer); + grpc_slice_buffer_init(&stream_state->rs.read_slice_buffer); + uint32_t flags = 0; + if (stream_state->rs.compressed) { + flags |= GRPC_WRITE_INTERNAL_COMPRESS; + } + stream_state->rs.sbs.Init(&stream_state->rs.read_slice_buffer, flags); + stream_op->payload->recv_message.recv_message->reset( + stream_state->rs.sbs.get()); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + stream_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + stream_state->state_op_done[OP_RECV_MESSAGE] = true; + oas->state.state_op_done[OP_RECV_MESSAGE] = true; + + /* Extra read to trigger on_succeed */ + stream_state->rs.length_field_received = false; + stream_state->state_op_done[OP_READ_REQ_MADE] = + true; /* Indicates that at least one read request has been made */ + read_grpc_header(s); + result = ACTION_TAKEN_NO_CALLBACK; + } + } else if (stream_state->rs.remaining_bytes == 0) { + /* Start a read operation for first 5 bytes (GRPC header) */ + stream_state->rs.read_buffer = stream_state->rs.grpc_header_bytes; + stream_state->rs.remaining_bytes = GRPC_HEADER_SIZE_IN_BYTES; + stream_state->rs.received_bytes = 0; + stream_state->rs.compressed = false; + CRONET_LOG(GPR_DEBUG, "bidirectional_stream_read(%p)", s->cbs); + stream_state->state_op_done[OP_READ_REQ_MADE] = + true; /* Indicates that at least one read request has been made */ + bidirectional_stream_read(s->cbs, stream_state->rs.read_buffer, + stream_state->rs.remaining_bytes); + result = ACTION_TAKEN_WITH_CALLBACK; + } else { + result = NO_ACTION_POSSIBLE; + } + } else if (stream_state->rs.remaining_bytes == 0) { + CRONET_LOG(GPR_DEBUG, "read operation complete"); + grpc_slice read_data_slice = + GRPC_SLICE_MALLOC((uint32_t)stream_state->rs.length_field); + uint8_t* dst_p = GRPC_SLICE_START_PTR(read_data_slice); + memcpy(dst_p, stream_state->rs.read_buffer, + static_cast(stream_state->rs.length_field)); + null_and_maybe_free_read_buffer(s); + /* Clean up read_slice_buffer in case there is unread data. 
*/ + grpc_slice_buffer_destroy_internal(&stream_state->rs.read_slice_buffer); + grpc_slice_buffer_init(&stream_state->rs.read_slice_buffer); + grpc_slice_buffer_add(&stream_state->rs.read_slice_buffer, + read_data_slice); + uint32_t flags = 0; + if (stream_state->rs.compressed) { + flags = GRPC_WRITE_INTERNAL_COMPRESS; + } + stream_state->rs.sbs.Init(&stream_state->rs.read_slice_buffer, flags); + stream_op->payload->recv_message.recv_message->reset( + stream_state->rs.sbs.get()); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, stream_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + stream_state->state_op_done[OP_RECV_MESSAGE] = true; + oas->state.state_op_done[OP_RECV_MESSAGE] = true; + /* Do an extra read to trigger on_succeeded() callback in case connection + is closed */ + stream_state->rs.length_field_received = false; + read_grpc_header(s); + result = ACTION_TAKEN_NO_CALLBACK; + } + } else if (stream_op->recv_trailing_metadata && + op_can_be_run(stream_op, s, &oas->state, + OP_RECV_TRAILING_METADATA)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_RECV_TRAILING_METADATA", oas); + grpc_error_handle error = GRPC_ERROR_NONE; + if (stream_state->state_op_done[OP_CANCEL_ERROR]) { + error = GRPC_ERROR_REF(stream_state->cancel_error); + } else if (stream_state->state_callback_received[OP_FAILED]) { + grpc_status_code grpc_error_code = + cronet_net_error_to_grpc_error(stream_state->net_error); + const char* desc = cronet_net_error_as_string(stream_state->net_error); + error = + make_error_with_desc(grpc_error_code, stream_state->net_error, desc); + } else if (oas->s->state.rs.trailing_metadata_valid) { + *stream_op->payload->recv_trailing_metadata.recv_trailing_metadata = + std::move(oas->s->state.rs.trailing_metadata); + stream_state->rs.trailing_metadata_valid = false; + } + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + stream_op->payload->recv_trailing_metadata.recv_trailing_metadata_ready, + error); + stream_state->state_op_done[OP_RECV_TRAILING_METADATA] = true; + result = ACTION_TAKEN_NO_CALLBACK; + } else if (stream_op->cancel_stream && + op_can_be_run(stream_op, s, &oas->state, OP_CANCEL_ERROR)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_CANCEL_ERROR", oas); + if (s->cbs) { + CRONET_LOG(GPR_DEBUG, "W: bidirectional_stream_cancel(%p)", s->cbs); + bidirectional_stream_cancel(s->cbs); + result = ACTION_TAKEN_WITH_CALLBACK; + } else { + result = ACTION_TAKEN_NO_CALLBACK; + } + stream_state->state_op_done[OP_CANCEL_ERROR] = true; + if (stream_state->cancel_error == GRPC_ERROR_NONE) { + stream_state->cancel_error = + GRPC_ERROR_REF(stream_op->payload->cancel_stream.cancel_error); + } + } else if (op_can_be_run(stream_op, s, &oas->state, OP_ON_COMPLETE)) { + CRONET_LOG(GPR_DEBUG, "running: %p OP_ON_COMPLETE", oas); + if (stream_state->state_op_done[OP_CANCEL_ERROR]) { + if (stream_op->on_complete) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, stream_op->on_complete, + GRPC_ERROR_REF(stream_state->cancel_error)); + } + } else if (stream_state->state_callback_received[OP_FAILED]) { + if (stream_op->on_complete) { + const char* error_message = + cronet_net_error_as_string(stream_state->net_error); + grpc_status_code grpc_error_code = + cronet_net_error_to_grpc_error(stream_state->net_error); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, stream_op->on_complete, + make_error_with_desc(grpc_error_code, stream_state->net_error, + error_message)); + } + } else { + /* All actions in this stream_op are complete. 
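+ (Recv callbacks for this batch were already scheduled individually in the
+ branches above.)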
Call the on_complete + * callback + */ + if (stream_op->on_complete) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, stream_op->on_complete, + GRPC_ERROR_NONE); + } + } + oas->state.state_op_done[OP_ON_COMPLETE] = true; + oas->done = true; + /* reset any send message state, only if this ON_COMPLETE is about a send. + */ + if (stream_op->send_message) { + stream_state->state_callback_received[OP_SEND_MESSAGE] = false; + stream_state->state_op_done[OP_SEND_MESSAGE] = false; + } + result = ACTION_TAKEN_NO_CALLBACK; + /* If this is the on_complete callback being called for a received message - + make a note */ + if (stream_op->recv_message) { + stream_state->state_op_done[OP_RECV_MESSAGE_AND_ON_COMPLETE] = true; + } + } else { + result = NO_ACTION_POSSIBLE; + } + return result; +} + +/* + Functions used by upper layers to access transport functionality. +*/ + +inline stream_obj::stream_obj(grpc_transport* gt, grpc_stream* gs, + grpc_stream_refcount* refcount, + grpc_core::Arena* arena) + : arena(arena), + curr_ct(reinterpret_cast(gt)), + curr_gs(gs), + state(arena), + refcount(refcount) { + GRPC_CRONET_STREAM_REF(this, "cronet transport"); + gpr_mu_init(&mu); +} + +inline stream_obj::~stream_obj() { + null_and_maybe_free_read_buffer(this); + /* Clean up read_slice_buffer in case there is unread data. */ + grpc_slice_buffer_destroy_internal(&state.rs.read_slice_buffer); + GRPC_ERROR_UNREF(state.cancel_error); +} + +static int init_stream(grpc_transport* gt, grpc_stream* gs, + grpc_stream_refcount* refcount, + const void* /*server_data*/, grpc_core::Arena* arena) { + new (gs) stream_obj(gt, gs, refcount, arena); + return 0; +} + +static void set_pollset_do_nothing(grpc_transport* /*gt*/, grpc_stream* /*gs*/, + grpc_pollset* /*pollset*/) {} + +static void set_pollset_set_do_nothing(grpc_transport* /*gt*/, + grpc_stream* /*gs*/, + grpc_pollset_set* /*pollset_set*/) {} + +static void perform_stream_op(grpc_transport* /*gt*/, grpc_stream* gs, + grpc_transport_stream_op_batch* op) { + CRONET_LOG(GPR_DEBUG, "perform_stream_op"); + if (op->send_initial_metadata && + header_has_authority( + op->payload->send_initial_metadata.send_initial_metadata)) { + /* Cronet does not support :authority header field. 
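+ (CronetMetadataEncoder above likewise drops :authority, since Cronet
+ populates that field on its own.)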
We cancel the call when + this field is present in metadata */ + if (op->recv_initial_metadata) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_CANCELLED); + } + if (op->recv_message) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + op->payload->recv_message.recv_message_ready, + GRPC_ERROR_CANCELLED); + } + if (op->recv_trailing_metadata) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready, + GRPC_ERROR_CANCELLED); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, + GRPC_ERROR_CANCELLED); + return; + } + stream_obj* s = reinterpret_cast(gs); + add_to_storage(s, op); + execute_from_storage(s); +} + +static void destroy_stream(grpc_transport* /*gt*/, grpc_stream* gs, + grpc_closure* then_schedule_closure) { + stream_obj* s = reinterpret_cast(gs); + s->~stream_obj(); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, + GRPC_ERROR_NONE); +} + +static void destroy_transport(grpc_transport* /*gt*/) {} + +static grpc_endpoint* get_endpoint(grpc_transport* /*gt*/) { return nullptr; } + +static void perform_op(grpc_transport* /*gt*/, grpc_transport_op* /*op*/) {} + +static const grpc_transport_vtable grpc_cronet_vtable = { + sizeof(stream_obj), + "cronet_http", + init_stream, + set_pollset_do_nothing, + set_pollset_set_do_nothing, + perform_stream_op, + perform_op, + destroy_stream, + destroy_transport, + get_endpoint}; + +grpc_transport* grpc_create_cronet_transport(void* engine, const char* target, + const grpc_channel_args* args, + void* /*reserved*/) { + grpc_cronet_transport* ct = static_cast( + gpr_malloc(sizeof(grpc_cronet_transport))); + if (!ct) { + goto error; + } + ct->base.vtable = &grpc_cronet_vtable; + ct->engine = static_cast(engine); + ct->host = static_cast(gpr_malloc(strlen(target) + 1)); + if (!ct->host) { + goto error; + } + strcpy(ct->host, target); + + ct->use_packet_coalescing = true; + if (args) { + for (size_t i = 0; i < args->num_args; i++) { + if (0 == + strcmp(args->args[i].key, GRPC_ARG_USE_CRONET_PACKET_COALESCING)) { + if (GPR_UNLIKELY(args->args[i].type != GRPC_ARG_INTEGER)) { + gpr_log(GPR_ERROR, "%s ignored: it must be an integer", + GRPC_ARG_USE_CRONET_PACKET_COALESCING); + } else { + ct->use_packet_coalescing = (args->args[i].value.integer != 0); + } + } + } + } + + return &ct->base; + +error: + if (ct) { + if (ct->host) { + gpr_free(ct->host); + } + gpr_free(ct); + } + + return nullptr; +} diff --git a/src/core/ext/transport/inproc/inproc_plugin.cc b/src/core/ext/transport/inproc/inproc_plugin.cc new file mode 100644 index 00000000..8e251fa2 --- /dev/null +++ b/src/core/ext/transport/inproc/inproc_plugin.cc @@ -0,0 +1,28 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/inproc/inproc_transport.h" +#include "src/core/lib/debug/trace.h" + +grpc_core::TraceFlag grpc_inproc_trace(false, "inproc"); + +void grpc_inproc_plugin_init(void) { grpc_inproc_transport_init(); } + +void grpc_inproc_plugin_shutdown(void) { grpc_inproc_transport_shutdown(); } diff --git a/src/core/ext/transport/inproc/inproc_transport.cc b/src/core/ext/transport/inproc/inproc_transport.cc new file mode 100644 index 00000000..31d5e0b7 --- /dev/null +++ b/src/core/ext/transport/inproc/inproc_transport.cc @@ -0,0 +1,1360 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/inproc/inproc_transport.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/transport_impl.h" + +#define INPROC_LOG(...) \ + do { \ + if (GRPC_TRACE_FLAG_ENABLED(grpc_inproc_trace)) { \ + gpr_log(__VA_ARGS__); \ + } \ + } while (0) + +namespace { +grpc_slice g_empty_slice; +grpc_slice g_fake_path_key; +grpc_slice g_fake_path_value; +grpc_slice g_fake_auth_key; +grpc_slice g_fake_auth_value; + +struct inproc_stream; +bool cancel_stream_locked(inproc_stream* s, grpc_error_handle error); +void maybe_process_ops_locked(inproc_stream* s, grpc_error_handle error); +void op_state_machine_locked(inproc_stream* s, grpc_error_handle error); +void log_metadata(const grpc_metadata_batch* md_batch, bool is_client, + bool is_initial); +void fill_in_metadata(inproc_stream* s, const grpc_metadata_batch* metadata, + uint32_t flags, grpc_metadata_batch* out_md, + uint32_t* outflags, bool* markfilled); + +struct shared_mu { + shared_mu() { + // Share one lock between both sides since both sides get affected + gpr_mu_init(&mu); + gpr_ref_init(&refs, 2); + } + + ~shared_mu() { gpr_mu_destroy(&mu); } + + gpr_mu mu; + gpr_refcount refs; +}; + +struct inproc_transport { + inproc_transport(const grpc_transport_vtable* vtable, shared_mu* mu, + bool is_client) + : mu(mu), + is_client(is_client), + state_tracker(is_client ? 
"inproc_client" : "inproc_server", + GRPC_CHANNEL_READY) { + base.vtable = vtable; + // Start each side of transport with 2 refs since they each have a ref + // to the other + gpr_ref_init(&refs, 2); + } + + ~inproc_transport() { + if (gpr_unref(&mu->refs)) { + mu->~shared_mu(); + gpr_free(mu); + } + } + + void ref() { + INPROC_LOG(GPR_INFO, "ref_transport %p", this); + gpr_ref(&refs); + } + + void unref() { + INPROC_LOG(GPR_INFO, "unref_transport %p", this); + if (!gpr_unref(&refs)) { + return; + } + INPROC_LOG(GPR_INFO, "really_destroy_transport %p", this); + this->~inproc_transport(); + gpr_free(this); + } + + grpc_transport base; + shared_mu* mu; + gpr_refcount refs; + bool is_client; + grpc_core::ConnectivityStateTracker state_tracker; + void (*accept_stream_cb)(void* user_data, grpc_transport* transport, + const void* server_data); + void* accept_stream_data; + bool is_closed = false; + struct inproc_transport* other_side; + struct inproc_stream* stream_list = nullptr; +}; + +struct inproc_stream { + inproc_stream(inproc_transport* t, grpc_stream_refcount* refcount, + const void* server_data, grpc_core::Arena* arena) + : t(t), refs(refcount), arena(arena) { + // Ref this stream right now for ctor and list. + ref("inproc_init_stream:init"); + ref("inproc_init_stream:list"); + + stream_list_prev = nullptr; + gpr_mu_lock(&t->mu->mu); + stream_list_next = t->stream_list; + if (t->stream_list) { + t->stream_list->stream_list_prev = this; + } + t->stream_list = this; + gpr_mu_unlock(&t->mu->mu); + + if (!server_data) { + t->ref(); + inproc_transport* st = t->other_side; + st->ref(); + other_side = nullptr; // will get filled in soon + // Pass the client-side stream address to the server-side for a ref + ref("inproc_init_stream:clt"); // ref it now on behalf of server + // side to avoid destruction + INPROC_LOG(GPR_INFO, "calling accept stream cb %p %p", + st->accept_stream_cb, st->accept_stream_data); + (*st->accept_stream_cb)(st->accept_stream_data, &st->base, this); + } else { + // This is the server-side and is being called through accept_stream_cb + inproc_stream* cs = const_cast( + static_cast(server_data)); + other_side = cs; + // Ref the server-side stream on behalf of the client now + ref("inproc_init_stream:srv"); + + // Now we are about to affect the other side, so lock the transport + // to make sure that it doesn't get destroyed + gpr_mu_lock(&t->mu->mu); + cs->other_side = this; + // Now transfer from the other side's write_buffer if any to the to_read + // buffer + if (cs->write_buffer_initial_md_filled) { + (void)fill_in_metadata(this, &cs->write_buffer_initial_md, + cs->write_buffer_initial_md_flags, + &to_read_initial_md, &to_read_initial_md_flags, + &to_read_initial_md_filled); + deadline = std::min(deadline, cs->write_buffer_deadline); + cs->write_buffer_initial_md.Clear(); + cs->write_buffer_initial_md_filled = false; + } + if (cs->write_buffer_trailing_md_filled) { + (void)fill_in_metadata(this, &cs->write_buffer_trailing_md, 0, + &to_read_trailing_md, nullptr, + &to_read_trailing_md_filled); + cs->write_buffer_trailing_md.Clear(); + cs->write_buffer_trailing_md_filled = false; + } + if (cs->write_buffer_cancel_error != GRPC_ERROR_NONE) { + cancel_other_error = cs->write_buffer_cancel_error; + cs->write_buffer_cancel_error = GRPC_ERROR_NONE; + maybe_process_ops_locked(this, cancel_other_error); + } + + gpr_mu_unlock(&t->mu->mu); + } + } + + ~inproc_stream() { + GRPC_ERROR_UNREF(write_buffer_cancel_error); + GRPC_ERROR_UNREF(cancel_self_error); + 
GRPC_ERROR_UNREF(cancel_other_error); + + if (recv_inited) { + grpc_slice_buffer_destroy_internal(&recv_message); + } + + t->unref(); + } + +#ifndef NDEBUG +#define STREAM_REF(refs, reason) grpc_stream_ref(refs, reason) +#define STREAM_UNREF(refs, reason) grpc_stream_unref(refs, reason) +#else +#define STREAM_REF(refs, reason) grpc_stream_ref(refs) +#define STREAM_UNREF(refs, reason) grpc_stream_unref(refs) +#endif + void ref(const char* reason) { + INPROC_LOG(GPR_INFO, "ref_stream %p %s", this, reason); + STREAM_REF(refs, reason); + } + + void unref(const char* reason) { + INPROC_LOG(GPR_INFO, "unref_stream %p %s", this, reason); + STREAM_UNREF(refs, reason); + } +#undef STREAM_REF +#undef STREAM_UNREF + + inproc_transport* t; + grpc_stream_refcount* refs; + grpc_core::Arena* arena; + + grpc_metadata_batch to_read_initial_md{arena}; + uint32_t to_read_initial_md_flags = 0; + bool to_read_initial_md_filled = false; + grpc_metadata_batch to_read_trailing_md{arena}; + bool to_read_trailing_md_filled = false; + bool ops_needed = false; + // Write buffer used only during gap at init time when client-side + // stream is set up but server side stream is not yet set up + grpc_metadata_batch write_buffer_initial_md{arena}; + bool write_buffer_initial_md_filled = false; + uint32_t write_buffer_initial_md_flags = 0; + grpc_millis write_buffer_deadline = GRPC_MILLIS_INF_FUTURE; + grpc_metadata_batch write_buffer_trailing_md{arena}; + bool write_buffer_trailing_md_filled = false; + grpc_error_handle write_buffer_cancel_error = GRPC_ERROR_NONE; + + struct inproc_stream* other_side; + bool other_side_closed = false; // won't talk anymore + bool write_buffer_other_side_closed = false; // on hold + + grpc_transport_stream_op_batch* send_message_op = nullptr; + grpc_transport_stream_op_batch* send_trailing_md_op = nullptr; + grpc_transport_stream_op_batch* recv_initial_md_op = nullptr; + grpc_transport_stream_op_batch* recv_message_op = nullptr; + grpc_transport_stream_op_batch* recv_trailing_md_op = nullptr; + + grpc_slice_buffer recv_message; + grpc_core::ManualConstructor recv_stream; + bool recv_inited = false; + + bool initial_md_sent = false; + bool trailing_md_sent = false; + bool initial_md_recvd = false; + bool trailing_md_recvd = false; + // The following tracks if the server-side only pretends to have received + // trailing metadata since it no longer cares about the RPC. If that is the + // case, it is still ok for the client to send trailing metadata (in which + // case it will be ignored). + bool trailing_md_recvd_implicit_only = false; + + bool closed = false; + + grpc_error_handle cancel_self_error = GRPC_ERROR_NONE; + grpc_error_handle cancel_other_error = GRPC_ERROR_NONE; + + grpc_millis deadline = GRPC_MILLIS_INF_FUTURE; + + bool listed = true; + struct inproc_stream* stream_list_prev; + struct inproc_stream* stream_list_next; +}; + +void log_metadata(const grpc_metadata_batch* md_batch, bool is_client, + bool is_initial) { + md_batch->ForEach([=](grpc_mdelem md) { + char* key = grpc_slice_to_c_string(GRPC_MDKEY(md)); + char* value = grpc_slice_to_c_string(GRPC_MDVALUE(md)); + gpr_log(GPR_INFO, "INPROC:%s:%s: %s: %s", is_initial ? "HDR" : "TRL", + is_client ? 
"CLI" : "SVR", key, value); + gpr_free(key); + gpr_free(value); + }); +} + +namespace { + +class CopySink { + public: + explicit CopySink(grpc_metadata_batch* dst) : dst_(dst) {} + + void Encode(grpc_mdelem md) { + // Differently to grpc_metadata_batch_copy, we always copy slices here so + // that we don't need to deal with the plethora of edge cases in that world. + // TODO(ctiller): revisit this when deleting mdelem. + md = grpc_mdelem_from_slices(grpc_slice_intern(GRPC_MDKEY(md)), + grpc_slice_copy(GRPC_MDVALUE(md))); + // Error unused in non-debug builds. + grpc_error_handle GRPC_UNUSED error = dst_->Append(md); + // The only way that Append() can fail is if + // there's a duplicate entry for a callout. However, that can't be + // the case here, because we would not have been allowed to create + // a source batch that had that kind of conflict. + GPR_DEBUG_ASSERT(error == GRPC_ERROR_NONE); + } + + template + void Encode(T trait, V value) { + dst_->Set(trait, value); + } + + private: + grpc_metadata_batch* dst_; +}; + +} // namespace + +void fill_in_metadata(inproc_stream* s, const grpc_metadata_batch* metadata, + uint32_t flags, grpc_metadata_batch* out_md, + uint32_t* outflags, bool* markfilled) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_inproc_trace)) { + log_metadata(metadata, s->t->is_client, outflags != nullptr); + } + + if (outflags != nullptr) { + *outflags = flags; + } + if (markfilled != nullptr) { + *markfilled = true; + } + + // TODO(ctiller): copy the metadata batch, don't rely on a bespoke copy + // function. Can only do this once mdelems are out of the way though, too many + // edge cases otherwise. + out_md->Clear(); + CopySink sink(out_md); + metadata->Encode(&sink); +} + +int init_stream(grpc_transport* gt, grpc_stream* gs, + grpc_stream_refcount* refcount, const void* server_data, + grpc_core::Arena* arena) { + INPROC_LOG(GPR_INFO, "init_stream %p %p %p", gt, gs, server_data); + inproc_transport* t = reinterpret_cast(gt); + new (gs) inproc_stream(t, refcount, server_data, arena); + return 0; // return value is not important +} + +void close_stream_locked(inproc_stream* s) { + if (!s->closed) { + // Release the metadata that we would have written out + s->write_buffer_initial_md.Clear(); + s->write_buffer_trailing_md.Clear(); + + if (s->listed) { + inproc_stream* p = s->stream_list_prev; + inproc_stream* n = s->stream_list_next; + if (p != nullptr) { + p->stream_list_next = n; + } else { + s->t->stream_list = n; + } + if (n != nullptr) { + n->stream_list_prev = p; + } + s->listed = false; + s->unref("close_stream:list"); + } + s->closed = true; + s->unref("close_stream:closing"); + } +} + +// This function means that we are done talking/listening to the other side +void close_other_side_locked(inproc_stream* s, const char* reason) { + if (s->other_side != nullptr) { + // First release the metadata that came from the other side's arena + s->to_read_initial_md.Clear(); + s->to_read_trailing_md.Clear(); + + s->other_side->unref(reason); + s->other_side_closed = true; + s->other_side = nullptr; + } else if (!s->other_side_closed) { + s->write_buffer_other_side_closed = true; + } +} + +// Call the on_complete closure associated with this stream_op_batch if +// this stream_op_batch is only one of the pending operations for this +// stream. 
This is called when one of the pending operations for the stream +// is done and about to be NULLed out +void complete_if_batch_end_locked(inproc_stream* s, grpc_error_handle error, + grpc_transport_stream_op_batch* op, + const char* msg) { + int is_sm = static_cast(op == s->send_message_op); + int is_stm = static_cast(op == s->send_trailing_md_op); + // TODO(vjpai): We should not consider the recv ops here, since they + // have their own callbacks. We should invoke a batch's on_complete + // as soon as all of the batch's send ops are complete, even if there + // are still recv ops pending. + int is_rim = static_cast(op == s->recv_initial_md_op); + int is_rm = static_cast(op == s->recv_message_op); + int is_rtm = static_cast(op == s->recv_trailing_md_op); + + if ((is_sm + is_stm + is_rim + is_rm + is_rtm) == 1) { + INPROC_LOG(GPR_INFO, "%s %p %p %s", msg, s, op, + grpc_error_std_string(error).c_str()); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, + GRPC_ERROR_REF(error)); + } +} + +void maybe_process_ops_locked(inproc_stream* s, grpc_error_handle error) { + if (s && (error != GRPC_ERROR_NONE || s->ops_needed)) { + s->ops_needed = false; + op_state_machine_locked(s, error); + } +} + +void fail_helper_locked(inproc_stream* s, grpc_error_handle error) { + INPROC_LOG(GPR_INFO, "op_state_machine %p fail_helper", s); + // If we're failing this side, we need to make sure that + // we also send or have already sent trailing metadata + if (!s->trailing_md_sent) { + // Send trailing md to the other side indicating cancellation + s->trailing_md_sent = true; + + grpc_metadata_batch fake_md(s->arena); + inproc_stream* other = s->other_side; + grpc_metadata_batch* dest = (other == nullptr) + ? &s->write_buffer_trailing_md + : &other->to_read_trailing_md; + bool* destfilled = (other == nullptr) ? 
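+        // If the peer stream does not exist yet, the fabricated trailing
+        // metadata is staged in this stream's write buffer (and handed over
+        // when the peer is constructed); otherwise it is written directly
+        // into the peer's to_read buffer.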
&s->write_buffer_trailing_md_filled + : &other->to_read_trailing_md_filled; + (void)fill_in_metadata(s, &fake_md, 0, dest, nullptr, destfilled); + + if (other != nullptr) { + if (other->cancel_other_error == GRPC_ERROR_NONE) { + other->cancel_other_error = GRPC_ERROR_REF(error); + } + maybe_process_ops_locked(other, error); + } else if (s->write_buffer_cancel_error == GRPC_ERROR_NONE) { + s->write_buffer_cancel_error = GRPC_ERROR_REF(error); + } + } + if (s->recv_initial_md_op) { + grpc_error_handle err; + if (!s->t->is_client) { + // If this is a server, provide initial metadata with a path and authority + // since it expects that as well as no error yet + grpc_metadata_batch fake_md(s->arena); + grpc_linked_mdelem* path_md = + static_cast(s->arena->Alloc(sizeof(*path_md))); + path_md->md = grpc_mdelem_from_slices(g_fake_path_key, g_fake_path_value); + GPR_ASSERT(fake_md.LinkTail(path_md) == GRPC_ERROR_NONE); + grpc_linked_mdelem* auth_md = + static_cast(s->arena->Alloc(sizeof(*auth_md))); + auth_md->md = grpc_mdelem_from_slices(g_fake_auth_key, g_fake_auth_value); + GPR_ASSERT(fake_md.LinkTail(auth_md) == GRPC_ERROR_NONE); + + (void)fill_in_metadata( + s, &fake_md, 0, + s->recv_initial_md_op->payload->recv_initial_metadata + .recv_initial_metadata, + s->recv_initial_md_op->payload->recv_initial_metadata.recv_flags, + nullptr); + err = GRPC_ERROR_NONE; + } else { + err = GRPC_ERROR_REF(error); + } + if (s->recv_initial_md_op->payload->recv_initial_metadata + .trailing_metadata_available != nullptr) { + // Set to true unconditionally, because we're failing the call, so even + // if we haven't actually seen the send_trailing_metadata op from the + // other side, we're going to return trailing metadata anyway. + *s->recv_initial_md_op->payload->recv_initial_metadata + .trailing_metadata_available = true; + } + INPROC_LOG(GPR_INFO, + "fail_helper %p scheduling initial-metadata-ready %s %s", s, + grpc_error_std_string(error).c_str(), + grpc_error_std_string(err).c_str()); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_initial_md_op->payload->recv_initial_metadata + .recv_initial_metadata_ready, + err); + // Last use of err so no need to REF and then UNREF it + + complete_if_batch_end_locked( + s, error, s->recv_initial_md_op, + "fail_helper scheduling recv-initial-metadata-on-complete"); + s->recv_initial_md_op = nullptr; + } + if (s->recv_message_op) { + INPROC_LOG(GPR_INFO, "fail_helper %p scheduling message-ready %s", s, + grpc_error_std_string(error).c_str()); + if (s->recv_message_op->payload->recv_message + .call_failed_before_recv_message != nullptr) { + *s->recv_message_op->payload->recv_message + .call_failed_before_recv_message = true; + } + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_message_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_REF(error)); + complete_if_batch_end_locked( + s, error, s->recv_message_op, + "fail_helper scheduling recv-message-on-complete"); + s->recv_message_op = nullptr; + } + if (s->send_message_op) { + s->send_message_op->payload->send_message.send_message.reset(); + complete_if_batch_end_locked( + s, error, s->send_message_op, + "fail_helper scheduling send-message-on-complete"); + s->send_message_op = nullptr; + } + if (s->send_trailing_md_op) { + complete_if_batch_end_locked( + s, error, s->send_trailing_md_op, + "fail_helper scheduling send-trailng-md-on-complete"); + s->send_trailing_md_op = nullptr; + } + if (s->recv_trailing_md_op) { + INPROC_LOG(GPR_INFO, "fail_helper %p scheduling trailing-metadata-ready %s", + s, 
grpc_error_std_string(error).c_str()); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata_ready, + GRPC_ERROR_REF(error)); + INPROC_LOG(GPR_INFO, "fail_helper %p scheduling trailing-md-on-complete %s", + s, grpc_error_std_string(error).c_str()); + complete_if_batch_end_locked( + s, error, s->recv_trailing_md_op, + "fail_helper scheduling recv-trailing-metadata-on-complete"); + s->recv_trailing_md_op = nullptr; + } + close_other_side_locked(s, "fail_helper:other_side"); + close_stream_locked(s); + + GRPC_ERROR_UNREF(error); +} + +// TODO(vjpai): It should not be necessary to drain the incoming byte +// stream and create a new one; instead, we should simply pass the byte +// stream from the sender directly to the receiver as-is. +// +// Note that fixing this will also avoid the assumption in this code +// that the incoming byte stream's next() call will always return +// synchronously. That assumption is true today but may not always be +// true in the future. +void message_transfer_locked(inproc_stream* sender, inproc_stream* receiver) { + size_t remaining = + sender->send_message_op->payload->send_message.send_message->length(); + if (receiver->recv_inited) { + grpc_slice_buffer_destroy_internal(&receiver->recv_message); + } + grpc_slice_buffer_init(&receiver->recv_message); + receiver->recv_inited = true; + do { + grpc_slice message_slice; + grpc_closure unused; + GPR_ASSERT( + sender->send_message_op->payload->send_message.send_message->Next( + SIZE_MAX, &unused)); + grpc_error_handle error = + sender->send_message_op->payload->send_message.send_message->Pull( + &message_slice); + if (error != GRPC_ERROR_NONE) { + cancel_stream_locked(sender, GRPC_ERROR_REF(error)); + break; + } + GPR_ASSERT(error == GRPC_ERROR_NONE); + remaining -= GRPC_SLICE_LENGTH(message_slice); + grpc_slice_buffer_add(&receiver->recv_message, message_slice); + } while (remaining > 0); + sender->send_message_op->payload->send_message.send_message.reset(); + + receiver->recv_stream.Init(&receiver->recv_message, 0); + receiver->recv_message_op->payload->recv_message.recv_message->reset( + receiver->recv_stream.get()); + INPROC_LOG(GPR_INFO, "message_transfer_locked %p scheduling message-ready", + receiver); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + receiver->recv_message_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + complete_if_batch_end_locked( + sender, GRPC_ERROR_NONE, sender->send_message_op, + "message_transfer scheduling sender on_complete"); + complete_if_batch_end_locked( + receiver, GRPC_ERROR_NONE, receiver->recv_message_op, + "message_transfer scheduling receiver on_complete"); + + receiver->recv_message_op = nullptr; + sender->send_message_op = nullptr; +} + +void op_state_machine_locked(inproc_stream* s, grpc_error_handle error) { + // This function gets called when we have contents in the unprocessed reads + // Get what we want based on our ops wanted + // Schedule our appropriate closures + // and then return to ops_needed state if still needed + + grpc_error_handle new_err = GRPC_ERROR_NONE; + + bool needs_close = false; + + INPROC_LOG(GPR_INFO, "op_state_machine %p", s); + // cancellation takes precedence + inproc_stream* other = s->other_side; + + if (s->cancel_self_error != GRPC_ERROR_NONE) { + fail_helper_locked(s, GRPC_ERROR_REF(s->cancel_self_error)); + goto done; + } else if (s->cancel_other_error != GRPC_ERROR_NONE) { + fail_helper_locked(s, GRPC_ERROR_REF(s->cancel_other_error)); + goto 
done; + } else if (error != GRPC_ERROR_NONE) { + fail_helper_locked(s, GRPC_ERROR_REF(error)); + goto done; + } + + if (s->send_message_op && other) { + if (other->recv_message_op) { + message_transfer_locked(s, other); + maybe_process_ops_locked(other, GRPC_ERROR_NONE); + } else if (!s->t->is_client && s->trailing_md_sent) { + // A server send will never be matched if the server already sent status + s->send_message_op->payload->send_message.send_message.reset(); + complete_if_batch_end_locked( + s, GRPC_ERROR_NONE, s->send_message_op, + "op_state_machine scheduling send-message-on-complete case 1"); + s->send_message_op = nullptr; + } + } + // Pause a send trailing metadata if there is still an outstanding + // send message unless we know that the send message will never get + // matched to a receive. This happens on the client if the server has + // already sent status or on the server if the client has requested + // status + if (s->send_trailing_md_op && + (!s->send_message_op || + (s->t->is_client && + (s->trailing_md_recvd || s->to_read_trailing_md_filled)) || + (!s->t->is_client && other && + (other->trailing_md_recvd || other->to_read_trailing_md_filled || + other->recv_trailing_md_op)))) { + grpc_metadata_batch* dest = (other == nullptr) + ? &s->write_buffer_trailing_md + : &other->to_read_trailing_md; + bool* destfilled = (other == nullptr) ? &s->write_buffer_trailing_md_filled + : &other->to_read_trailing_md_filled; + if (*destfilled || s->trailing_md_sent) { + // The buffer is already in use; that's an error! + INPROC_LOG(GPR_INFO, "Extra trailing metadata %p", s); + new_err = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Extra trailing metadata"); + fail_helper_locked(s, GRPC_ERROR_REF(new_err)); + goto done; + } else { + if (!other || !other->closed) { + (void)fill_in_metadata( + s, + s->send_trailing_md_op->payload->send_trailing_metadata + .send_trailing_metadata, + 0, dest, nullptr, destfilled); + } + s->trailing_md_sent = true; + if (s->send_trailing_md_op->payload->send_trailing_metadata.sent) { + *s->send_trailing_md_op->payload->send_trailing_metadata.sent = true; + } + if (!s->t->is_client && s->trailing_md_recvd && s->recv_trailing_md_op) { + INPROC_LOG(GPR_INFO, + "op_state_machine %p scheduling trailing-metadata-ready", s); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata_ready, + GRPC_ERROR_NONE); + INPROC_LOG(GPR_INFO, + "op_state_machine %p scheduling trailing-md-on-complete", s); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + s->recv_trailing_md_op->on_complete, + GRPC_ERROR_NONE); + s->recv_trailing_md_op = nullptr; + needs_close = true; + } + } + maybe_process_ops_locked(other, GRPC_ERROR_NONE); + complete_if_batch_end_locked( + s, GRPC_ERROR_NONE, s->send_trailing_md_op, + "op_state_machine scheduling send-trailing-metadata-on-complete"); + s->send_trailing_md_op = nullptr; + } + if (s->recv_initial_md_op) { + if (s->initial_md_recvd) { + new_err = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Already recvd initial md"); + INPROC_LOG( + GPR_INFO, + "op_state_machine %p scheduling on_complete errors for already " + "recvd initial md %s", + s, grpc_error_std_string(new_err).c_str()); + fail_helper_locked(s, GRPC_ERROR_REF(new_err)); + goto done; + } + + if (s->to_read_initial_md_filled) { + s->initial_md_recvd = true; + fill_in_metadata( + s, &s->to_read_initial_md, s->to_read_initial_md_flags, + s->recv_initial_md_op->payload->recv_initial_metadata + .recv_initial_metadata, + 
s->recv_initial_md_op->payload->recv_initial_metadata.recv_flags, + nullptr); + if (s->deadline != GRPC_MILLIS_INF_FUTURE) { + s->recv_initial_md_op->payload->recv_initial_metadata + .recv_initial_metadata->Set(grpc_core::GrpcTimeoutMetadata(), + s->deadline); + } + if (s->recv_initial_md_op->payload->recv_initial_metadata + .trailing_metadata_available != nullptr) { + *s->recv_initial_md_op->payload->recv_initial_metadata + .trailing_metadata_available = + (other != nullptr && other->send_trailing_md_op != nullptr); + } + s->to_read_initial_md.Clear(); + s->to_read_initial_md_filled = false; + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_initial_md_op->payload->recv_initial_metadata + .recv_initial_metadata_ready, + GRPC_ERROR_NONE); + complete_if_batch_end_locked( + s, GRPC_ERROR_NONE, s->recv_initial_md_op, + "op_state_machine scheduling recv-initial-metadata-on-complete"); + s->recv_initial_md_op = nullptr; + } + } + if (s->recv_message_op) { + if (other && other->send_message_op) { + message_transfer_locked(other, s); + maybe_process_ops_locked(other, GRPC_ERROR_NONE); + } + } + if (s->to_read_trailing_md_filled) { + if (s->trailing_md_recvd) { + if (s->trailing_md_recvd_implicit_only) { + INPROC_LOG(GPR_INFO, + "op_state_machine %p already implicitly received trailing " + "metadata, so ignoring new trailing metadata from client", + s); + s->to_read_trailing_md.Clear(); + s->to_read_trailing_md_filled = false; + s->trailing_md_recvd_implicit_only = false; + } else { + new_err = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Already recvd trailing md"); + INPROC_LOG( + GPR_INFO, + "op_state_machine %p scheduling on_complete errors for already " + "recvd trailing md %s", + s, grpc_error_std_string(new_err).c_str()); + fail_helper_locked(s, GRPC_ERROR_REF(new_err)); + goto done; + } + } + if (s->recv_message_op != nullptr) { + // This message needs to be wrapped up because it will never be + // satisfied + *s->recv_message_op->payload->recv_message.recv_message = nullptr; + INPROC_LOG(GPR_INFO, "op_state_machine %p scheduling message-ready", s); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_message_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + complete_if_batch_end_locked( + s, new_err, s->recv_message_op, + "op_state_machine scheduling recv-message-on-complete"); + s->recv_message_op = nullptr; + } + if ((s->trailing_md_sent || s->t->is_client) && s->send_message_op) { + // Nothing further will try to receive from this stream, so finish off + // any outstanding send_message op + s->send_message_op->payload->send_message.send_message.reset(); + s->send_message_op->payload->send_message.stream_write_closed = true; + complete_if_batch_end_locked( + s, new_err, s->send_message_op, + "op_state_machine scheduling send-message-on-complete case 2"); + s->send_message_op = nullptr; + } + if (s->recv_trailing_md_op != nullptr) { + // We wanted trailing metadata and we got it + s->trailing_md_recvd = true; + fill_in_metadata(s, &s->to_read_trailing_md, 0, + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata, + nullptr, nullptr); + s->to_read_trailing_md.Clear(); + s->to_read_trailing_md_filled = false; + + // We should schedule the recv_trailing_md_op completion if + // 1. this stream is the client-side + // 2. 
this stream is the server-side AND has already sent its trailing md + // (If the server hasn't already sent its trailing md, it doesn't have + // a final status, so don't mark this op complete) + if (s->t->is_client || s->trailing_md_sent) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata_ready, + GRPC_ERROR_NONE); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + s->recv_trailing_md_op->on_complete, + GRPC_ERROR_NONE); + s->recv_trailing_md_op = nullptr; + needs_close = s->trailing_md_sent; + } + } else if (!s->trailing_md_recvd) { + INPROC_LOG( + GPR_INFO, + "op_state_machine %p has trailing md but not yet waiting for it", s); + } + } + if (!s->t->is_client && s->trailing_md_sent && + (s->recv_trailing_md_op != nullptr)) { + // In this case, we don't care to receive the write-close from the client + // because we have already sent status and the RPC is over as far as we + // are concerned. + INPROC_LOG(GPR_INFO, "op_state_machine %p scheduling trailing-md-ready %s", + s, grpc_error_std_string(new_err).c_str()); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata_ready, + GRPC_ERROR_REF(new_err)); + complete_if_batch_end_locked( + s, new_err, s->recv_trailing_md_op, + "op_state_machine scheduling recv-trailing-md-on-complete"); + s->trailing_md_recvd = true; + s->recv_trailing_md_op = nullptr; + // Since we are only pretending to have received the trailing MD, it would + // be ok (not an error) if the client actually sends it later. + s->trailing_md_recvd_implicit_only = true; + } + if (s->trailing_md_recvd && s->recv_message_op) { + // No further message will come on this stream, so finish off the + // recv_message_op + INPROC_LOG(GPR_INFO, "op_state_machine %p scheduling message-ready", s); + *s->recv_message_op->payload->recv_message.recv_message = nullptr; + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_message_op->payload->recv_message.recv_message_ready, + GRPC_ERROR_NONE); + complete_if_batch_end_locked( + s, new_err, s->recv_message_op, + "op_state_machine scheduling recv-message-on-complete"); + s->recv_message_op = nullptr; + } + if (s->trailing_md_recvd && s->send_message_op && s->t->is_client) { + // Nothing further will try to receive from this stream, so finish off + // any outstanding send_message op + s->send_message_op->payload->send_message.send_message.reset(); + complete_if_batch_end_locked( + s, new_err, s->send_message_op, + "op_state_machine scheduling send-message-on-complete case 3"); + s->send_message_op = nullptr; + } + if (s->send_message_op || s->send_trailing_md_op || s->recv_initial_md_op || + s->recv_message_op || s->recv_trailing_md_op) { + // Didn't get the item we wanted so we still need to get + // rescheduled + INPROC_LOG( + GPR_INFO, "op_state_machine %p still needs closure %p %p %p %p %p", s, + s->send_message_op, s->send_trailing_md_op, s->recv_initial_md_op, + s->recv_message_op, s->recv_trailing_md_op); + s->ops_needed = true; + } +done: + if (needs_close) { + close_other_side_locked(s, "op_state_machine"); + close_stream_locked(s); + } + GRPC_ERROR_UNREF(new_err); +} + +bool cancel_stream_locked(inproc_stream* s, grpc_error_handle error) { + bool ret = false; // was the cancel accepted + INPROC_LOG(GPR_INFO, "cancel_stream %p with %s", s, + grpc_error_std_string(error).c_str()); + if (s->cancel_self_error == GRPC_ERROR_NONE) { + ret = true; + s->cancel_self_error = GRPC_ERROR_REF(error); + 
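+    // Having recorded the cancellation locally, the code below wakes this
+    // side's state machine, fabricates empty trailing metadata for the peer
+    // (staged in the write buffer if the peer does not exist yet), and marks
+    // the peer with cancel_other_error so its pending ops get failed.
+    //
+    // For example, a client that calls grpc::ClientContext::TryCancel() on a
+    // call running over an in-process channel will typically reach this
+    // function through the cancel_stream handling in perform_stream_op()
+    // below.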
// Catch current value of other before it gets closed off + inproc_stream* other = s->other_side; + maybe_process_ops_locked(s, s->cancel_self_error); + // Send trailing md to the other side indicating cancellation, even if we + // already have + s->trailing_md_sent = true; + + grpc_metadata_batch cancel_md(s->arena); + + grpc_metadata_batch* dest = (other == nullptr) + ? &s->write_buffer_trailing_md + : &other->to_read_trailing_md; + bool* destfilled = (other == nullptr) ? &s->write_buffer_trailing_md_filled + : &other->to_read_trailing_md_filled; + (void)fill_in_metadata(s, &cancel_md, 0, dest, nullptr, destfilled); + + if (other != nullptr) { + if (other->cancel_other_error == GRPC_ERROR_NONE) { + other->cancel_other_error = GRPC_ERROR_REF(s->cancel_self_error); + } + maybe_process_ops_locked(other, other->cancel_other_error); + } else if (s->write_buffer_cancel_error == GRPC_ERROR_NONE) { + s->write_buffer_cancel_error = GRPC_ERROR_REF(s->cancel_self_error); + } + + // if we are a server and already received trailing md but + // couldn't complete that because we hadn't yet sent out trailing + // md, now's the chance + if (!s->t->is_client && s->trailing_md_recvd && s->recv_trailing_md_op) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata_ready, + GRPC_ERROR_REF(s->cancel_self_error)); + complete_if_batch_end_locked( + s, s->cancel_self_error, s->recv_trailing_md_op, + "cancel_stream scheduling trailing-md-on-complete"); + s->recv_trailing_md_op = nullptr; + } + } + + close_other_side_locked(s, "cancel_stream:other_side"); + close_stream_locked(s); + + GRPC_ERROR_UNREF(error); + return ret; +} + +void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +void perform_stream_op(grpc_transport* gt, grpc_stream* gs, + grpc_transport_stream_op_batch* op) { + INPROC_LOG(GPR_INFO, "perform_stream_op %p %p %p", gt, gs, op); + inproc_stream* s = reinterpret_cast(gs); + gpr_mu* mu = &s->t->mu->mu; // save aside in case s gets closed + gpr_mu_lock(mu); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_inproc_trace)) { + if (op->send_initial_metadata) { + log_metadata(op->payload->send_initial_metadata.send_initial_metadata, + s->t->is_client, true); + } + if (op->send_trailing_metadata) { + log_metadata(op->payload->send_trailing_metadata.send_trailing_metadata, + s->t->is_client, false); + } + } + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_closure* on_complete = op->on_complete; + // TODO(roth): This is a hack needed because we use data inside of the + // closure itself to do the barrier calculation (i.e., to ensure that + // we don't schedule the closure until all ops in the batch have been + // completed). This can go away once we move to a new C++ closure API + // that provides the ability to create a barrier closure. + if (on_complete == nullptr) { + on_complete = GRPC_CLOSURE_INIT(&op->handler_private.closure, do_nothing, + nullptr, grpc_schedule_on_exec_ctx); + } + + if (op->cancel_stream) { + // Call cancel_stream_locked without ref'ing the cancel_error because + // this function is responsible to make sure that that field gets unref'ed + cancel_stream_locked(s, op->payload->cancel_stream.cancel_error); + // this op can complete without an error + } else if (s->cancel_self_error != GRPC_ERROR_NONE) { + // already self-canceled so still give it an error + error = GRPC_ERROR_REF(s->cancel_self_error); + } else { + INPROC_LOG(GPR_INFO, "perform_stream_op %p %s%s%s%s%s%s%s", s, + s->t->is_client ? 
"client" : "server", + op->send_initial_metadata ? " send_initial_metadata" : "", + op->send_message ? " send_message" : "", + op->send_trailing_metadata ? " send_trailing_metadata" : "", + op->recv_initial_metadata ? " recv_initial_metadata" : "", + op->recv_message ? " recv_message" : "", + op->recv_trailing_metadata ? " recv_trailing_metadata" : ""); + } + + inproc_stream* other = s->other_side; + if (error == GRPC_ERROR_NONE && + (op->send_initial_metadata || op->send_trailing_metadata)) { + if (s->t->is_closed) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Endpoint already shutdown"); + } + if (error == GRPC_ERROR_NONE && op->send_initial_metadata) { + grpc_metadata_batch* dest = (other == nullptr) + ? &s->write_buffer_initial_md + : &other->to_read_initial_md; + uint32_t* destflags = (other == nullptr) + ? &s->write_buffer_initial_md_flags + : &other->to_read_initial_md_flags; + bool* destfilled = (other == nullptr) ? &s->write_buffer_initial_md_filled + : &other->to_read_initial_md_filled; + if (*destfilled || s->initial_md_sent) { + // The buffer is already in use; that's an error! + INPROC_LOG(GPR_INFO, "Extra initial metadata %p", s); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Extra initial metadata"); + } else { + if (!s->other_side_closed) { + (void)fill_in_metadata( + s, op->payload->send_initial_metadata.send_initial_metadata, + op->payload->send_initial_metadata.send_initial_metadata_flags, + dest, destflags, destfilled); + } + if (s->t->is_client) { + grpc_millis* dl = + (other == nullptr) ? &s->write_buffer_deadline : &other->deadline; + *dl = std::min( + *dl, op->payload->send_initial_metadata.send_initial_metadata + ->get(grpc_core::GrpcTimeoutMetadata()) + .value_or(GRPC_MILLIS_INF_FUTURE)); + s->initial_md_sent = true; + } + } + maybe_process_ops_locked(other, error); + } + } + + if (error == GRPC_ERROR_NONE && + (op->send_message || op->send_trailing_metadata || + op->recv_initial_metadata || op->recv_message || + op->recv_trailing_metadata)) { + // Mark ops that need to be processed by the state machine + if (op->send_message) { + s->send_message_op = op; + } + if (op->send_trailing_metadata) { + s->send_trailing_md_op = op; + } + if (op->recv_initial_metadata) { + s->recv_initial_md_op = op; + } + if (op->recv_message) { + s->recv_message_op = op; + } + if (op->recv_trailing_metadata) { + s->recv_trailing_md_op = op; + } + + // We want to initiate the state machine if: + // 1. We want to send a message and the other side wants to receive + // 2. We want to send trailing metadata and there isn't an unmatched send + // or the other side wants trailing metadata + // 3. We want initial metadata and the other side has sent it + // 4. We want to receive a message and there is a message ready + // 5. 
There is trailing metadata, even if nothing specifically wants + // that because that can shut down the receive message as well + if ((op->send_message && other && other->recv_message_op != nullptr) || + (op->send_trailing_metadata && + (!s->send_message_op || (other && other->recv_trailing_md_op))) || + (op->recv_initial_metadata && s->to_read_initial_md_filled) || + (op->recv_message && other && other->send_message_op != nullptr) || + (s->to_read_trailing_md_filled || s->trailing_md_recvd)) { + op_state_machine_locked(s, error); + } else { + s->ops_needed = true; + } + } else { + if (error != GRPC_ERROR_NONE) { + // Consume any send message that was sent here but that we are not pushing + // to the other side + if (op->send_message) { + op->payload->send_message.send_message.reset(); + } + // Schedule op's closures that we didn't push to op state machine + if (op->recv_initial_metadata) { + if (op->payload->recv_initial_metadata.trailing_metadata_available != + nullptr) { + // Set to true unconditionally, because we're failing the call, so + // even if we haven't actually seen the send_trailing_metadata op + // from the other side, we're going to return trailing metadata + // anyway. + *op->payload->recv_initial_metadata.trailing_metadata_available = + true; + } + INPROC_LOG( + GPR_INFO, + "perform_stream_op error %p scheduling initial-metadata-ready %s", + s, grpc_error_std_string(error).c_str()); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_REF(error)); + } + if (op->recv_message) { + INPROC_LOG( + GPR_INFO, + "perform_stream_op error %p scheduling recv message-ready %s", s, + grpc_error_std_string(error).c_str()); + if (op->payload->recv_message.call_failed_before_recv_message != + nullptr) { + *op->payload->recv_message.call_failed_before_recv_message = true; + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + op->payload->recv_message.recv_message_ready, + GRPC_ERROR_REF(error)); + } + if (op->recv_trailing_metadata) { + INPROC_LOG( + GPR_INFO, + "perform_stream_op error %p scheduling trailing-metadata-ready %s", + s, grpc_error_std_string(error).c_str()); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready, + GRPC_ERROR_REF(error)); + } + } + INPROC_LOG(GPR_INFO, "perform_stream_op %p scheduling on_complete %s", s, + grpc_error_std_string(error).c_str()); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_complete, GRPC_ERROR_REF(error)); + } + gpr_mu_unlock(mu); + GRPC_ERROR_UNREF(error); +} + +void close_transport_locked(inproc_transport* t) { + INPROC_LOG(GPR_INFO, "close_transport %p %d", t, t->is_closed); + t->state_tracker.SetState(GRPC_CHANNEL_SHUTDOWN, absl::Status(), + "close transport"); + if (!t->is_closed) { + t->is_closed = true; + /* Also end all streams on this transport */ + while (t->stream_list != nullptr) { + // cancel_stream_locked also adjusts stream list + cancel_stream_locked( + t->stream_list, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport closed"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } + } +} + +void perform_transport_op(grpc_transport* gt, grpc_transport_op* op) { + inproc_transport* t = reinterpret_cast(gt); + INPROC_LOG(GPR_INFO, "perform_transport_op %p %p", t, op); + gpr_mu_lock(&t->mu->mu); + if (op->start_connectivity_watch != nullptr) { + t->state_tracker.AddWatcher(op->start_connectivity_watch_state, + std::move(op->start_connectivity_watch)); + } + if 
(op->stop_connectivity_watch != nullptr) { + t->state_tracker.RemoveWatcher(op->stop_connectivity_watch); + } + if (op->set_accept_stream) { + t->accept_stream_cb = op->set_accept_stream_fn; + t->accept_stream_data = op->set_accept_stream_user_data; + } + if (op->on_consumed) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_consumed, GRPC_ERROR_NONE); + } + + bool do_close = false; + if (op->goaway_error != GRPC_ERROR_NONE) { + do_close = true; + GRPC_ERROR_UNREF(op->goaway_error); + } + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + do_close = true; + GRPC_ERROR_UNREF(op->disconnect_with_error); + } + + if (do_close) { + close_transport_locked(t); + } + gpr_mu_unlock(&t->mu->mu); +} + +void destroy_stream(grpc_transport* gt, grpc_stream* gs, + grpc_closure* then_schedule_closure) { + INPROC_LOG(GPR_INFO, "destroy_stream %p %p", gs, then_schedule_closure); + inproc_transport* t = reinterpret_cast(gt); + inproc_stream* s = reinterpret_cast(gs); + gpr_mu_lock(&t->mu->mu); + close_stream_locked(s); + gpr_mu_unlock(&t->mu->mu); + s->~inproc_stream(); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, + GRPC_ERROR_NONE); +} + +void destroy_transport(grpc_transport* gt) { + inproc_transport* t = reinterpret_cast(gt); + INPROC_LOG(GPR_INFO, "destroy_transport %p", t); + gpr_mu_lock(&t->mu->mu); + close_transport_locked(t); + gpr_mu_unlock(&t->mu->mu); + t->other_side->unref(); + t->unref(); +} + +/******************************************************************************* + * INTEGRATION GLUE + */ + +void set_pollset(grpc_transport* /*gt*/, grpc_stream* /*gs*/, + grpc_pollset* /*pollset*/) { + // Nothing to do here +} + +void set_pollset_set(grpc_transport* /*gt*/, grpc_stream* /*gs*/, + grpc_pollset_set* /*pollset_set*/) { + // Nothing to do here +} + +grpc_endpoint* get_endpoint(grpc_transport* /*t*/) { return nullptr; } + +const grpc_transport_vtable inproc_vtable = { + sizeof(inproc_stream), "inproc", init_stream, + set_pollset, set_pollset_set, perform_stream_op, + perform_transport_op, destroy_stream, destroy_transport, + get_endpoint}; + +/******************************************************************************* + * Main inproc transport functions + */ +void inproc_transports_create(grpc_transport** server_transport, + const grpc_channel_args* /*server_args*/, + grpc_transport** client_transport, + const grpc_channel_args* /*client_args*/) { + INPROC_LOG(GPR_INFO, "inproc_transports_create"); + shared_mu* mu = new (gpr_malloc(sizeof(*mu))) shared_mu(); + inproc_transport* st = new (gpr_malloc(sizeof(*st))) + inproc_transport(&inproc_vtable, mu, /*is_client=*/false); + inproc_transport* ct = new (gpr_malloc(sizeof(*ct))) + inproc_transport(&inproc_vtable, mu, /*is_client=*/true); + st->other_side = ct; + ct->other_side = st; + *server_transport = reinterpret_cast(st); + *client_transport = reinterpret_cast(ct); +} +} // namespace + +/******************************************************************************* + * GLOBAL INIT AND DESTROY + */ +void grpc_inproc_transport_init(void) { + grpc_core::ExecCtx exec_ctx; + g_empty_slice = grpc_core::ExternallyManagedSlice(); + + grpc_slice key_tmp = grpc_slice_from_static_string(":path"); + g_fake_path_key = grpc_slice_intern(key_tmp); + grpc_slice_unref_internal(key_tmp); + + g_fake_path_value = grpc_slice_from_static_string("/"); + + grpc_slice auth_tmp = grpc_slice_from_static_string(":authority"); + g_fake_auth_key = grpc_slice_intern(auth_tmp); + grpc_slice_unref_internal(auth_tmp); + + g_fake_auth_value = 
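+      // The fake ":path" / ":authority" slices set up here are used by
+      // fail_helper_locked() to synthesize minimal server-side initial
+      // metadata (effectively ":path: /" and ":authority: inproc-fail") when
+      // a call is failed before any real initial metadata has arrived.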
grpc_slice_from_static_string("inproc-fail"); +} + +grpc_channel* grpc_inproc_channel_create(grpc_server* server, + grpc_channel_args* args, + void* /*reserved*/) { + GRPC_API_TRACE("grpc_inproc_channel_create(server=%p, args=%p)", 2, + (server, args)); + + grpc_core::ExecCtx exec_ctx; + + // Remove max_connection_idle and max_connection_age channel arguments since + // those do not apply to inproc transports. + const char* args_to_remove[] = {GRPC_ARG_MAX_CONNECTION_IDLE_MS, + GRPC_ARG_MAX_CONNECTION_AGE_MS}; + const grpc_channel_args* server_args = grpc_channel_args_copy_and_remove( + server->core_server->channel_args(), args_to_remove, + GPR_ARRAY_SIZE(args_to_remove)); + // Add a default authority channel argument for the client + grpc_arg default_authority_arg; + default_authority_arg.type = GRPC_ARG_STRING; + default_authority_arg.key = const_cast(GRPC_ARG_DEFAULT_AUTHORITY); + default_authority_arg.value.string = const_cast("inproc.authority"); + grpc_channel_args* client_args = + grpc_channel_args_copy_and_add(args, &default_authority_arg, 1); + + grpc_transport* server_transport; + grpc_transport* client_transport; + inproc_transports_create(&server_transport, server_args, &client_transport, + client_args); + + // TODO(ncteisen): design and support channelz GetSocket for inproc. + grpc_error_handle error = server->core_server->SetupTransport( + server_transport, nullptr, server_args, nullptr); + grpc_channel* channel = nullptr; + if (error == GRPC_ERROR_NONE) { + channel = + grpc_channel_create("inproc", client_args, GRPC_CLIENT_DIRECT_CHANNEL, + client_transport, nullptr, 0, &error); + if (error != GRPC_ERROR_NONE) { + GPR_ASSERT(!channel); + gpr_log(GPR_ERROR, "Failed to create client channel: %s", + grpc_error_std_string(error).c_str()); + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + // client_transport was destroyed when grpc_channel_create saw an error. + grpc_transport_destroy(server_transport); + channel = grpc_lame_client_channel_create( + nullptr, status, "Failed to create client channel"); + } + } else { + GPR_ASSERT(!channel); + gpr_log(GPR_ERROR, "Failed to create server channel: %s", + grpc_error_std_string(error).c_str()); + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + grpc_transport_destroy(client_transport); + grpc_transport_destroy(server_transport); + channel = grpc_lame_client_channel_create( + nullptr, status, "Failed to create server channel"); + } + + // Free up created channel args + grpc_channel_args_destroy(server_args); + grpc_channel_args_destroy(client_args); + + // Now finish scheduled operations + + return channel; +} + +void grpc_inproc_transport_shutdown(void) { + grpc_core::ExecCtx exec_ctx; + grpc_slice_unref_internal(g_empty_slice); + grpc_slice_unref_internal(g_fake_path_key); + grpc_slice_unref_internal(g_fake_path_value); + grpc_slice_unref_internal(g_fake_auth_key); + grpc_slice_unref_internal(g_fake_auth_value); +} diff --git a/src/core/ext/xds/certificate_provider_registry.cc b/src/core/ext/xds/certificate_provider_registry.cc new file mode 100644 index 00000000..2b1780c1 --- /dev/null +++ b/src/core/ext/xds/certificate_provider_registry.cc @@ -0,0 +1,103 @@ +// +// +// Copyright 2020 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/ext/xds/certificate_provider_registry.h" + +#include "absl/container/inlined_vector.h" + +namespace grpc_core { + +namespace { + +class RegistryState { + public: + void RegisterCertificateProviderFactory( + std::unique_ptr factory) { + gpr_log(GPR_DEBUG, "registering certificate provider factory for \"%s\"", + factory->name()); + for (size_t i = 0; i < factories_.size(); ++i) { + GPR_ASSERT(strcmp(factories_[i]->name(), factory->name()) != 0); + } + factories_.push_back(std::move(factory)); + } + + CertificateProviderFactory* LookupCertificateProviderFactory( + absl::string_view name) const { + for (size_t i = 0; i < factories_.size(); ++i) { + if (name == factories_[i]->name()) { + return factories_[i].get(); + } + } + return nullptr; + } + + private: + // We currently support 3 factories without doing additional + // allocation. This number could be raised if there is a case where + // more factories are needed and the additional allocations are + // hurting performance (which is unlikely, since these allocations + // only occur at gRPC initialization time). + absl::InlinedVector, 3> + factories_; +}; + +static RegistryState* g_state = nullptr; + +} // namespace + +// +// CertificateProviderRegistry +// + +CertificateProviderFactory* +CertificateProviderRegistry::LookupCertificateProviderFactory( + absl::string_view name) { + GPR_ASSERT(g_state != nullptr); + return g_state->LookupCertificateProviderFactory(name); +} + +void CertificateProviderRegistry::InitRegistry() { + if (g_state == nullptr) g_state = new RegistryState(); +} + +void CertificateProviderRegistry::ShutdownRegistry() { + delete g_state; + g_state = nullptr; +} + +void CertificateProviderRegistry::RegisterCertificateProviderFactory( + std::unique_ptr factory) { + InitRegistry(); + g_state->RegisterCertificateProviderFactory(std::move(factory)); +} + +} // namespace grpc_core + +// +// Plugin registration +// + +void grpc_certificate_provider_registry_init() { + grpc_core::CertificateProviderRegistry::InitRegistry(); +} + +void grpc_certificate_provider_registry_shutdown() { + grpc_core::CertificateProviderRegistry::ShutdownRegistry(); +} diff --git a/src/core/ext/xds/certificate_provider_store.cc b/src/core/ext/xds/certificate_provider_store.cc new file mode 100644 index 00000000..dd66b97a --- /dev/null +++ b/src/core/ext/xds/certificate_provider_store.cc @@ -0,0 +1,87 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/ext/xds/certificate_provider_store.h" + +#include "src/core/ext/xds/certificate_provider_registry.h" + +namespace grpc_core { + +// If a certificate provider is created, the CertificateProviderStore +// maintains a raw pointer to the created CertificateProviderWrapper so that +// future calls to `CreateOrGetCertificateProvider()` with the same key result +// in returning a ref to this created certificate provider. This entry is +// deleted when the refcount to this provider reaches zero. +RefCountedPtr +CertificateProviderStore::CreateOrGetCertificateProvider( + absl::string_view key) { + RefCountedPtr result; + MutexLock lock(&mu_); + auto it = certificate_providers_map_.find(key); + if (it == certificate_providers_map_.end()) { + result = CreateCertificateProviderLocked(key); + if (result != nullptr) { + certificate_providers_map_.insert({result->key(), result.get()}); + } + } else { + result = it->second->RefIfNonZero(); + if (result == nullptr) { + result = CreateCertificateProviderLocked(key); + it->second = result.get(); + } + } + return result; +} + +RefCountedPtr +CertificateProviderStore::CreateCertificateProviderLocked( + absl::string_view key) { + auto plugin_config_it = plugin_config_map_.find(std::string(key)); + if (plugin_config_it == plugin_config_map_.end()) { + return nullptr; + } + CertificateProviderFactory* factory = + CertificateProviderRegistry::LookupCertificateProviderFactory( + plugin_config_it->second.plugin_name); + if (factory == nullptr) { + // This should never happen since an entry is only inserted in the + // plugin_config_map_ if the corresponding factory was found when parsing + // the xDS bootstrap file. + gpr_log(GPR_ERROR, "Certificate provider factory %s not found", + plugin_config_it->second.plugin_name.c_str()); + return nullptr; + } + return MakeRefCounted( + factory->CreateCertificateProvider(plugin_config_it->second.config), + Ref(), plugin_config_it->first); +} + +void CertificateProviderStore::ReleaseCertificateProvider( + absl::string_view key, CertificateProviderWrapper* wrapper) { + MutexLock lock(&mu_); + auto it = certificate_providers_map_.find(key); + if (it != certificate_providers_map_.end()) { + if (it->second == wrapper) { + certificate_providers_map_.erase(it); + } + } +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/file_watcher_certificate_provider_factory.cc b/src/core/ext/xds/file_watcher_certificate_provider_factory.cc new file mode 100644 index 00000000..7a793b06 --- /dev/null +++ b/src/core/ext/xds/file_watcher_certificate_provider_factory.cc @@ -0,0 +1,144 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
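+//
+// A minimal sketch of the JSON that Config::Parse() below accepts. The field
+// names come from the parsing code; the paths and the interval are
+// illustrative values only, not defaults:
+//
+//   {
+//     "certificate_file":    "/path/to/identity_chain.pem",
+//     "private_key_file":    "/path/to/identity_key.pem",
+//     "ca_certificate_file": "/path/to/roots.pem",
+//     "refresh_interval":    "600s"
+//   }
+//
+// "certificate_file" and "private_key_file" must be set or unset together,
+// at least one of "certificate_file" / "ca_certificate_file" is required,
+// and "refresh_interval" defaults to 10 minutes when omitted.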
+// +// + +#include + +#include "src/core/ext/xds/file_watcher_certificate_provider_factory.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include "src/core/ext/xds/certificate_provider_registry.h" +#include "src/core/lib/json/json_util.h" + +namespace grpc_core { + +namespace { + +const char* kFileWatcherPlugin = "file_watcher"; + +} // namespace + +// +// FileWatcherCertificateProviderFactory::Config +// + +const char* FileWatcherCertificateProviderFactory::Config::name() const { + return kFileWatcherPlugin; +} + +std::string FileWatcherCertificateProviderFactory::Config::ToString() const { + std::vector parts; + parts.push_back("{"); + if (!identity_cert_file_.empty()) { + parts.push_back( + absl::StrFormat("certificate_file=\"%s\", ", identity_cert_file_)); + } + if (!identity_cert_file_.empty()) { + parts.push_back( + absl::StrFormat("private_key_file=\"%s\", ", private_key_file_)); + } + if (!identity_cert_file_.empty()) { + parts.push_back( + absl::StrFormat("ca_certificate_file=\"%s\", ", root_cert_file_)); + } + parts.push_back( + absl::StrFormat("refresh_interval=%ldms}", refresh_interval_ms_)); + return absl::StrJoin(parts, ""); +} + +RefCountedPtr +FileWatcherCertificateProviderFactory::Config::Parse(const Json& config_json, + grpc_error_handle* error) { + auto config = MakeRefCounted(); + if (config_json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "error:config type should be OBJECT."); + return nullptr; + } + std::vector error_list; + ParseJsonObjectField(config_json.object_value(), "certificate_file", + &config->identity_cert_file_, &error_list, false); + ParseJsonObjectField(config_json.object_value(), "private_key_file", + &config->private_key_file_, &error_list, false); + if (config->identity_cert_file_.empty() != + config->private_key_file_.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "fields \"certificate_file\" and \"private_key_file\" must be both set " + "or both unset.")); + } + ParseJsonObjectField(config_json.object_value(), "ca_certificate_file", + &config->root_cert_file_, &error_list, false); + if (config->identity_cert_file_.empty() && config->root_cert_file_.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "At least one of \"certificate_file\" and \"ca_certificate_file\" must " + "be specified.")); + } + if (!ParseJsonObjectFieldAsDuration( + config_json.object_value(), "refresh_interval", + &config->refresh_interval_ms_, &error_list, false)) { + config->refresh_interval_ms_ = 10 * 60 * 1000; // 10 minutes default + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "Error parsing file watcher certificate provider config", &error_list); + return nullptr; + } + return config; +} + +// +// FileWatcherCertificateProviderFactory +// + +const char* FileWatcherCertificateProviderFactory::name() const { + return kFileWatcherPlugin; +} + +RefCountedPtr +FileWatcherCertificateProviderFactory::CreateCertificateProviderConfig( + const Json& config_json, grpc_error_handle* error) { + return FileWatcherCertificateProviderFactory::Config::Parse(config_json, + error); +} + +RefCountedPtr +FileWatcherCertificateProviderFactory::CreateCertificateProvider( + RefCountedPtr config) { + if (config->name() != name()) { + gpr_log(GPR_ERROR, "Wrong config type Actual:%s vs Expected:%s", + config->name(), name()); + return nullptr; + } + auto* file_watcher_config = + static_cast(config.get()); + return MakeRefCounted( + 
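+      // Passes the private key path, identity certificate path, and root
+      // certificate path, plus the refresh interval converted from
+      // milliseconds to seconds.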
file_watcher_config->private_key_file(), + file_watcher_config->identity_cert_file(), + file_watcher_config->root_cert_file(), + file_watcher_config->refresh_interval_ms() / GPR_MS_PER_SEC); +} + +void FileWatcherCertificateProviderInit() { + CertificateProviderRegistry::RegisterCertificateProviderFactory( + absl::make_unique()); +} + +void FileWatcherCertificateProviderShutdown() {} + +} // namespace grpc_core diff --git a/src/core/ext/xds/google_mesh_ca_certificate_provider_factory.cc b/src/core/ext/xds/google_mesh_ca_certificate_provider_factory.cc new file mode 100644 index 00000000..6e63ae4e --- /dev/null +++ b/src/core/ext/xds/google_mesh_ca_certificate_provider_factory.cc @@ -0,0 +1,265 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/ext/xds/google_mesh_ca_certificate_provider_factory.h" + +#include +#include + +#include "absl/strings/str_cat.h" + +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/json/json_util.h" + +namespace grpc_core { + +namespace { + +const char* kMeshCaPlugin = "meshCA"; + +} // namespace + +// +// GoogleMeshCaCertificateProviderFactory::Config +// + +const char* GoogleMeshCaCertificateProviderFactory::Config::name() const { + return kMeshCaPlugin; +} + +std::string GoogleMeshCaCertificateProviderFactory::Config::ToString() const { + // TODO(yashykt): To be filled + return "{}"; +} + +std::vector +GoogleMeshCaCertificateProviderFactory::Config::ParseJsonObjectStsService( + const Json::Object& sts_service) { + std::vector error_list_sts_service; + if (!ParseJsonObjectField(sts_service, "token_exchange_service_uri", + &sts_config_.token_exchange_service_uri, + &error_list_sts_service, false)) { + sts_config_.token_exchange_service_uri = + "securetoken.googleapis.com"; // default + } + ParseJsonObjectField(sts_service, "resource", &sts_config_.resource, + &error_list_sts_service, false); + ParseJsonObjectField(sts_service, "audience", &sts_config_.audience, + &error_list_sts_service, false); + if (!ParseJsonObjectField(sts_service, "scope", &sts_config_.scope, + &error_list_sts_service, false)) { + sts_config_.scope = + "https://www.googleapis.com/auth/cloud-platform"; // default + } + ParseJsonObjectField(sts_service, "requested_token_type", + &sts_config_.requested_token_type, + &error_list_sts_service, false); + ParseJsonObjectField(sts_service, "subject_token_path", + &sts_config_.subject_token_path, + &error_list_sts_service); + ParseJsonObjectField(sts_service, "subject_token_type", + &sts_config_.subject_token_type, + &error_list_sts_service); + ParseJsonObjectField(sts_service, "actor_token_path", + &sts_config_.actor_token_path, &error_list_sts_service, + false); + ParseJsonObjectField(sts_service, "actor_token_type", + &sts_config_.actor_token_type, &error_list_sts_service, + false); + return error_list_sts_service; +} + +std::vector 
+GoogleMeshCaCertificateProviderFactory::Config::ParseJsonObjectCallCredentials( + const Json::Object& call_credentials) { + std::vector error_list_call_credentials; + const Json::Object* sts_service = nullptr; + if (ParseJsonObjectField(call_credentials, "sts_service", &sts_service, + &error_list_call_credentials)) { + std::vector error_list_sts_service = + ParseJsonObjectStsService(*sts_service); + if (!error_list_sts_service.empty()) { + error_list_call_credentials.push_back(GRPC_ERROR_CREATE_FROM_VECTOR( + "field:sts_service", &error_list_sts_service)); + } + } + return error_list_call_credentials; +} + +std::vector +GoogleMeshCaCertificateProviderFactory::Config::ParseJsonObjectGoogleGrpc( + const Json::Object& google_grpc) { + std::vector error_list_google_grpc; + if (!ParseJsonObjectField(google_grpc, "target_uri", &endpoint_, + &error_list_google_grpc, false)) { + endpoint_ = "meshca.googleapis.com"; // Default target + } + const Json::Array* call_credentials_array = nullptr; + if (ParseJsonObjectField(google_grpc, "call_credentials", + &call_credentials_array, &error_list_google_grpc)) { + if (call_credentials_array->size() != 1) { + error_list_google_grpc.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:call_credentials error:Need exactly one entry.")); + } else { + const Json::Object* call_credentials = nullptr; + if (ExtractJsonType((*call_credentials_array)[0], "call_credentials[0]", + &call_credentials, &error_list_google_grpc)) { + std::vector error_list_call_credentials = + ParseJsonObjectCallCredentials(*call_credentials); + if (!error_list_call_credentials.empty()) { + error_list_google_grpc.push_back(GRPC_ERROR_CREATE_FROM_VECTOR( + "field:call_credentials", &error_list_call_credentials)); + } + } + } + } + + return error_list_google_grpc; +} + +std::vector +GoogleMeshCaCertificateProviderFactory::Config::ParseJsonObjectGrpcServices( + const Json::Object& grpc_service) { + std::vector error_list_grpc_services; + const Json::Object* google_grpc = nullptr; + if (ParseJsonObjectField(grpc_service, "google_grpc", &google_grpc, + &error_list_grpc_services)) { + std::vector error_list_google_grpc = + ParseJsonObjectGoogleGrpc(*google_grpc); + if (!error_list_google_grpc.empty()) { + error_list_grpc_services.push_back(GRPC_ERROR_CREATE_FROM_VECTOR( + "field:google_grpc", &error_list_google_grpc)); + } + } + if (!ParseJsonObjectFieldAsDuration(grpc_service, "timeout", &timeout_, + &error_list_grpc_services, false)) { + timeout_ = 10 * 1000; // 10sec default + } + return error_list_grpc_services; +} + +std::vector +GoogleMeshCaCertificateProviderFactory::Config::ParseJsonObjectServer( + const Json::Object& server) { + std::vector error_list_server; + std::string api_type; + if (ParseJsonObjectField(server, "api_type", &api_type, &error_list_server, + false)) { + if (api_type != "GRPC") { + error_list_server.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:api_type error:Only GRPC is supported")); + } + } + const Json::Array* grpc_services = nullptr; + if (ParseJsonObjectField(server, "grpc_services", &grpc_services, + &error_list_server)) { + if (grpc_services->size() != 1) { + error_list_server.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:grpc_services error:Need exactly one entry")); + } else { + const Json::Object* grpc_service = nullptr; + if (ExtractJsonType((*grpc_services)[0], "grpc_services[0]", + &grpc_service, &error_list_server)) { + std::vector error_list_grpc_services = + ParseJsonObjectGrpcServices(*grpc_service); + if 
(!error_list_grpc_services.empty()) { + error_list_server.push_back(GRPC_ERROR_CREATE_FROM_VECTOR( + "field:grpc_services", &error_list_grpc_services)); + } + } + } + } + return error_list_server; +} + +RefCountedPtr +GoogleMeshCaCertificateProviderFactory::Config::Parse( + const Json& config_json, grpc_error_handle* error) { + auto config = + MakeRefCounted(); + if (config_json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "error:config type should be OBJECT."); + return nullptr; + } + std::vector error_list; + const Json::Object* server = nullptr; + if (ParseJsonObjectField(config_json.object_value(), "server", &server, + &error_list)) { + std::vector error_list_server = + config->ParseJsonObjectServer(*server); + if (!error_list_server.empty()) { + error_list.push_back( + GRPC_ERROR_CREATE_FROM_VECTOR("field:server", &error_list_server)); + } + } + if (!ParseJsonObjectFieldAsDuration( + config_json.object_value(), "certificate_lifetime", + &config->certificate_lifetime_, &error_list, false)) { + config->certificate_lifetime_ = 24 * 60 * 60 * 1000; // 24hrs default + } + if (!ParseJsonObjectFieldAsDuration( + config_json.object_value(), "renewal_grace_period", + &config->renewal_grace_period_, &error_list, false)) { + config->renewal_grace_period_ = 12 * 60 * 60 * 1000; // 12hrs default + } + std::string key_type; + if (ParseJsonObjectField(config_json.object_value(), "key_type", &key_type, + &error_list, false)) { + if (key_type != "RSA") { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:key_type error:Only RSA is supported.")); + } + } + if (!ParseJsonObjectField(config_json.object_value(), "key_size", + &config->key_size_, &error_list, false)) { + config->key_size_ = 2048; // default 2048 bit key size + } + if (!ParseJsonObjectField(config_json.object_value(), "location", + &config->location_, &error_list, false)) { + // GCE/GKE Metadata server needs to be contacted to get the value. + } + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "Error parsing google Mesh CA config", &error_list); + return nullptr; + } + return config; +} + +// +// GoogleMeshCaCertificateProviderFactory +// + +const char* GoogleMeshCaCertificateProviderFactory::name() const { + return kMeshCaPlugin; +} + +RefCountedPtr +GoogleMeshCaCertificateProviderFactory::CreateCertificateProviderConfig( + const Json& config_json, grpc_error_handle* error) { + return GoogleMeshCaCertificateProviderFactory::Config::Parse(config_json, + error); +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_api.cc b/src/core/ext/xds/xds_api.cc new file mode 100644 index 00000000..bccaa74c --- /dev/null +++ b/src/core/ext/xds/xds_api.cc @@ -0,0 +1,3998 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/xds/xds_api.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "envoy/admin/v3/config_dump.upb.h" +#include "envoy/config/cluster/v3/circuit_breaker.upb.h" +#include "envoy/config/cluster/v3/cluster.upb.h" +#include "envoy/config/cluster/v3/cluster.upbdefs.h" +#include "envoy/config/core/v3/address.upb.h" +#include "envoy/config/core/v3/base.upb.h" +#include "envoy/config/core/v3/base.upbdefs.h" +#include "envoy/config/core/v3/config_source.upb.h" +#include "envoy/config/core/v3/health_check.upb.h" +#include "envoy/config/core/v3/protocol.upb.h" +#include "envoy/config/endpoint/v3/endpoint.upb.h" +#include "envoy/config/endpoint/v3/endpoint.upbdefs.h" +#include "envoy/config/endpoint/v3/endpoint_components.upb.h" +#include "envoy/config/endpoint/v3/load_report.upb.h" +#include "envoy/config/listener/v3/api_listener.upb.h" +#include "envoy/config/listener/v3/listener.upb.h" +#include "envoy/config/listener/v3/listener.upbdefs.h" +#include "envoy/config/listener/v3/listener_components.upb.h" +#include "envoy/config/route/v3/route.upb.h" +#include "envoy/config/route/v3/route.upbdefs.h" +#include "envoy/config/route/v3/route_components.upb.h" +#include "envoy/config/route/v3/route_components.upbdefs.h" +#include "envoy/extensions/clusters/aggregate/v3/cluster.upb.h" +#include "envoy/extensions/clusters/aggregate/v3/cluster.upbdefs.h" +#include "envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.upb.h" +#include "envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.upbdefs.h" +#include "envoy/extensions/transport_sockets/tls/v3/common.upb.h" +#include "envoy/extensions/transport_sockets/tls/v3/tls.upb.h" +#include "envoy/extensions/transport_sockets/tls/v3/tls.upbdefs.h" +#include "envoy/service/cluster/v3/cds.upb.h" +#include "envoy/service/cluster/v3/cds.upbdefs.h" +#include "envoy/service/discovery/v3/discovery.upb.h" +#include "envoy/service/discovery/v3/discovery.upbdefs.h" +#include "envoy/service/endpoint/v3/eds.upb.h" +#include "envoy/service/endpoint/v3/eds.upbdefs.h" +#include "envoy/service/listener/v3/lds.upb.h" +#include "envoy/service/load_stats/v3/lrs.upb.h" +#include "envoy/service/load_stats/v3/lrs.upbdefs.h" +#include "envoy/service/route/v3/rds.upb.h" +#include "envoy/service/route/v3/rds.upbdefs.h" +#include "envoy/service/status/v3/csds.upb.h" +#include "envoy/service/status/v3/csds.upbdefs.h" +#include "envoy/type/matcher/v3/regex.upb.h" +#include "envoy/type/matcher/v3/string.upb.h" +#include "envoy/type/v3/percent.upb.h" +#include "envoy/type/v3/range.upb.h" +#include "google/protobuf/any.upb.h" +#include "google/protobuf/duration.upb.h" +#include "google/protobuf/struct.upb.h" +#include "google/protobuf/timestamp.upb.h" +#include "google/protobuf/wrappers.upb.h" +#include "google/rpc/status.upb.h" +#include "upb/text_encode.h" +#include "upb/upb.h" +#include "upb/upb.hpp" +#include "xds/type/v3/typed_struct.upb.h" + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/slice/slice_utils.h" + +namespace 
grpc_core { + +// TODO(donnadionne): Check to see if cluster types aggregate_cluster and +// logical_dns are enabled, this will be +// removed once the cluster types are fully integration-tested and enabled by +// default. +bool XdsAggregateAndLogicalDnsClusterEnabled() { + char* value = gpr_getenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); + bool parsed_value; + bool parse_succeeded = gpr_parse_bool_value(value, &parsed_value); + gpr_free(value); + return parse_succeeded && parsed_value; +} + +// +// XdsApi::Route::HashPolicy +// + +XdsApi::Route::HashPolicy::HashPolicy(const HashPolicy& other) + : type(other.type), + header_name(other.header_name), + regex_substitution(other.regex_substitution) { + if (other.regex != nullptr) { + regex = + absl::make_unique(other.regex->pattern(), other.regex->options()); + } +} + +XdsApi::Route::HashPolicy& XdsApi::Route::HashPolicy::operator=( + const HashPolicy& other) { + type = other.type; + header_name = other.header_name; + if (other.regex != nullptr) { + regex = + absl::make_unique(other.regex->pattern(), other.regex->options()); + } + regex_substitution = other.regex_substitution; + return *this; +} + +XdsApi::Route::HashPolicy::HashPolicy(HashPolicy&& other) noexcept + : type(other.type), + header_name(std::move(other.header_name)), + regex(std::move(other.regex)), + regex_substitution(std::move(other.regex_substitution)) {} + +XdsApi::Route::HashPolicy& XdsApi::Route::HashPolicy::operator=( + HashPolicy&& other) noexcept { + type = other.type; + header_name = std::move(other.header_name); + regex = std::move(other.regex); + regex_substitution = std::move(other.regex_substitution); + return *this; +} + +bool XdsApi::Route::HashPolicy::HashPolicy::operator==( + const HashPolicy& other) const { + if (type != other.type) return false; + if (type == Type::HEADER) { + if (regex == nullptr) { + if (other.regex != nullptr) return false; + } else { + if (other.regex == nullptr) return false; + return header_name == other.header_name && + regex->pattern() == other.regex->pattern() && + regex_substitution == other.regex_substitution; + } + } + return true; +} + +std::string XdsApi::Route::HashPolicy::ToString() const { + std::vector contents; + switch (type) { + case Type::HEADER: + contents.push_back("type=HEADER"); + break; + case Type::CHANNEL_ID: + contents.push_back("type=CHANNEL_ID"); + break; + } + contents.push_back( + absl::StrFormat("terminal=%s", terminal ? "true" : "false")); + if (type == Type::HEADER) { + contents.push_back(absl::StrFormat( + "Header %s:/%s/%s", header_name, + (regex == nullptr) ? 
"" : regex->pattern(), regex_substitution)); + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +// +// XdsApi::Route::RetryPolicy +// +std::string XdsApi::Route::RetryPolicy::RetryBackOff::ToString() const { + std::vector contents; + contents.push_back( + absl::StrCat("RetryBackOff Base: ", base_interval.ToString())); + contents.push_back( + absl::StrCat("RetryBackOff max: ", max_interval.ToString())); + return absl::StrJoin(contents, ","); +} + +std::string XdsApi::Route::RetryPolicy::ToString() const { + std::vector contents; + contents.push_back(absl::StrFormat("num_retries=%d", num_retries)); + contents.push_back(retry_back_off.ToString()); + return absl::StrJoin(contents, ","); +} + +// +// XdsApi::Route +// + +std::string XdsApi::Route::Matchers::ToString() const { + std::vector contents; + contents.push_back( + absl::StrFormat("PathMatcher{%s}", path_matcher.ToString())); + for (const HeaderMatcher& header_matcher : header_matchers) { + contents.push_back(header_matcher.ToString()); + } + if (fraction_per_million.has_value()) { + contents.push_back(absl::StrFormat("Fraction Per Million %d", + fraction_per_million.value())); + } + return absl::StrJoin(contents, "\n"); +} + +std::string XdsApi::Route::ClusterWeight::ToString() const { + std::vector contents; + contents.push_back(absl::StrCat("cluster=", name)); + contents.push_back(absl::StrCat("weight=", weight)); + if (!typed_per_filter_config.empty()) { + std::vector parts; + for (const auto& p : typed_per_filter_config) { + const std::string& key = p.first; + const auto& config = p.second; + parts.push_back(absl::StrCat(key, "=", config.ToString())); + } + contents.push_back(absl::StrCat("typed_per_filter_config={", + absl::StrJoin(parts, ", "), "}")); + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +std::string XdsApi::Route::ToString() const { + std::vector contents; + contents.push_back(matchers.ToString()); + for (const HashPolicy& hash_policy : hash_policies) { + contents.push_back(absl::StrCat("hash_policy=", hash_policy.ToString())); + } + if (retry_policy.has_value()) { + contents.push_back( + absl::StrCat("retry_policy={", retry_policy->ToString(), "}")); + } + if (!cluster_name.empty()) { + contents.push_back(absl::StrFormat("Cluster name: %s", cluster_name)); + } + for (const ClusterWeight& cluster_weight : weighted_clusters) { + contents.push_back(cluster_weight.ToString()); + } + if (max_stream_duration.has_value()) { + contents.push_back(max_stream_duration->ToString()); + } + if (!typed_per_filter_config.empty()) { + contents.push_back("typed_per_filter_config={"); + for (const auto& p : typed_per_filter_config) { + const std::string& name = p.first; + const auto& config = p.second; + contents.push_back(absl::StrCat(" ", name, "=", config.ToString())); + } + contents.push_back("}"); + } + return absl::StrJoin(contents, "\n"); +} + +// +// XdsApi::RdsUpdate +// + +std::string XdsApi::RdsUpdate::ToString() const { + std::vector vhosts; + for (const VirtualHost& vhost : virtual_hosts) { + vhosts.push_back( + absl::StrCat("vhost={\n" + " domains=[", + absl::StrJoin(vhost.domains, ", "), + "]\n" + " routes=[\n")); + for (const XdsApi::Route& route : vhost.routes) { + vhosts.push_back(" {\n"); + vhosts.push_back(route.ToString()); + vhosts.push_back("\n }\n"); + } + vhosts.push_back(" ]\n"); + vhosts.push_back(" typed_per_filter_config={\n"); + for (const auto& p : vhost.typed_per_filter_config) { + const std::string& name = p.first; + const auto& config = p.second; + 
vhosts.push_back( + absl::StrCat(" ", name, "=", config.ToString(), "\n")); + } + vhosts.push_back(" }\n"); + vhosts.push_back("]\n"); + } + return absl::StrJoin(vhosts, ""); +} + +namespace { + +// Better match type has smaller value. +enum MatchType { + EXACT_MATCH, + SUFFIX_MATCH, + PREFIX_MATCH, + UNIVERSE_MATCH, + INVALID_MATCH, +}; + +// Returns true if match succeeds. +bool DomainMatch(MatchType match_type, const std::string& domain_pattern_in, + const std::string& expected_host_name_in) { + // Normalize the args to lower-case. Domain matching is case-insensitive. + std::string domain_pattern = domain_pattern_in; + std::string expected_host_name = expected_host_name_in; + std::transform(domain_pattern.begin(), domain_pattern.end(), + domain_pattern.begin(), + [](unsigned char c) { return std::tolower(c); }); + std::transform(expected_host_name.begin(), expected_host_name.end(), + expected_host_name.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (match_type == EXACT_MATCH) { + return domain_pattern == expected_host_name; + } else if (match_type == SUFFIX_MATCH) { + // Asterisk must match at least one char. + if (expected_host_name.size() < domain_pattern.size()) return false; + absl::string_view pattern_suffix(domain_pattern.c_str() + 1); + absl::string_view host_suffix(expected_host_name.c_str() + + expected_host_name.size() - + pattern_suffix.size()); + return pattern_suffix == host_suffix; + } else if (match_type == PREFIX_MATCH) { + // Asterisk must match at least one char. + if (expected_host_name.size() < domain_pattern.size()) return false; + absl::string_view pattern_prefix(domain_pattern.c_str(), + domain_pattern.size() - 1); + absl::string_view host_prefix(expected_host_name.c_str(), + pattern_prefix.size()); + return pattern_prefix == host_prefix; + } else { + return match_type == UNIVERSE_MATCH; + } +} + +MatchType DomainPatternMatchType(const std::string& domain_pattern) { + if (domain_pattern.empty()) return INVALID_MATCH; + if (domain_pattern.find('*') == std::string::npos) return EXACT_MATCH; + if (domain_pattern == "*") return UNIVERSE_MATCH; + if (domain_pattern[0] == '*') return SUFFIX_MATCH; + if (domain_pattern[domain_pattern.size() - 1] == '*') return PREFIX_MATCH; + return INVALID_MATCH; +} + +} // namespace + +XdsApi::RdsUpdate::VirtualHost* XdsApi::RdsUpdate::FindVirtualHostForDomain( + const std::string& domain) { + // Find the best matched virtual host. + // The search order for 4 groups of domain patterns: + // 1. Exact match. + // 2. Suffix match (e.g., "*ABC"). + // 3. Prefix match (e.g., "ABC*"). + // 4. Universe match (i.e., "*"). + // Within each group, longest match wins. + // If the same best matched domain pattern appears in multiple virtual hosts, + // the first matched virtual host wins. + VirtualHost* target_vhost = nullptr; + MatchType best_match_type = INVALID_MATCH; + size_t longest_match = 0; + // Check each domain pattern in each virtual host to determine the best + // matched virtual host. + for (VirtualHost& vhost : virtual_hosts) { + for (const std::string& domain_pattern : vhost.domains) { + // Check the match type first. Skip the pattern if it's not better than + // current match. + const MatchType match_type = DomainPatternMatchType(domain_pattern); + // This should be caught by RouteConfigParse(). + GPR_ASSERT(match_type != INVALID_MATCH); + if (match_type > best_match_type) continue; + if (match_type == best_match_type && + domain_pattern.size() <= longest_match) { + continue; + } + // Skip if match fails. 
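+      // For illustration (hypothetical inputs): given vhosts whose domains
+      // are ["foo.example.com"], ["*.example.com"] and ["*"], a lookup for
+      // "foo.example.com" picks the first vhost (EXACT_MATCH beats
+      // SUFFIX_MATCH), while "bar.example.com" picks the second vhost
+      // (SUFFIX_MATCH beats UNIVERSE_MATCH).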
+ if (!DomainMatch(match_type, domain_pattern, domain)) continue; + // Choose this match. + target_vhost = &vhost; + best_match_type = match_type; + longest_match = domain_pattern.size(); + if (best_match_type == EXACT_MATCH) break; + } + if (best_match_type == EXACT_MATCH) break; + } + return target_vhost; +} + +// +// XdsApi::CommonTlsContext::CertificateValidationContext +// + +std::string XdsApi::CommonTlsContext::CertificateValidationContext::ToString() + const { + std::vector contents; + for (const auto& match : match_subject_alt_names) { + contents.push_back(match.ToString()); + } + return absl::StrFormat("{match_subject_alt_names=[%s]}", + absl::StrJoin(contents, ", ")); +} + +bool XdsApi::CommonTlsContext::CertificateValidationContext::Empty() const { + return match_subject_alt_names.empty(); +} + +// +// XdsApi::CommonTlsContext::CertificateProviderPluginInstance +// + +std::string +XdsApi::CommonTlsContext::CertificateProviderPluginInstance::ToString() const { + absl::InlinedVector contents; + if (!instance_name.empty()) { + contents.push_back(absl::StrFormat("instance_name=%s", instance_name)); + } + if (!certificate_name.empty()) { + contents.push_back( + absl::StrFormat("certificate_name=%s", certificate_name)); + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +bool XdsApi::CommonTlsContext::CertificateProviderPluginInstance::Empty() + const { + return instance_name.empty() && certificate_name.empty(); +} + +// +// XdsApi::CommonTlsContext +// + +std::string XdsApi::CommonTlsContext::ToString() const { + absl::InlinedVector contents; + if (!tls_certificate_provider_instance.Empty()) { + contents.push_back( + absl::StrFormat("tls_certificate_provider_instance=%s", + tls_certificate_provider_instance.ToString())); + } + if (!certificate_validation_context.Empty()) { + contents.push_back( + absl::StrFormat("certificate_validation_context=%s", + certificate_validation_context.ToString())); + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +bool XdsApi::CommonTlsContext::Empty() const { + return tls_certificate_provider_instance.Empty() && + certificate_validation_context.Empty(); +} + +// +// XdsApi::DownstreamTlsContext +// + +std::string XdsApi::DownstreamTlsContext::ToString() const { + return absl::StrFormat("common_tls_context=%s, require_client_certificate=%s", + common_tls_context.ToString(), + require_client_certificate ? "true" : "false"); +} + +bool XdsApi::DownstreamTlsContext::Empty() const { + return common_tls_context.Empty(); +} + +// +// XdsApi::LdsUpdate::HttpConnectionManager +// + +std::string XdsApi::LdsUpdate::HttpConnectionManager::ToString() const { + absl::InlinedVector contents; + contents.push_back(absl::StrFormat( + "route_config_name=%s", + !route_config_name.empty() ? 
route_config_name.c_str() : "")); + contents.push_back(absl::StrFormat("http_max_stream_duration=%s", + http_max_stream_duration.ToString())); + if (rds_update.has_value()) { + contents.push_back( + absl::StrFormat("rds_update=%s", rds_update->ToString())); + } + if (!http_filters.empty()) { + std::vector filter_strings; + for (const auto& http_filter : http_filters) { + filter_strings.push_back(http_filter.ToString()); + } + contents.push_back(absl::StrCat("http_filters=[", + absl::StrJoin(filter_strings, ", "), "]")); + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +// +// XdsApi::LdsUpdate::HttpFilter +// + +std::string XdsApi::LdsUpdate::HttpConnectionManager::HttpFilter::ToString() + const { + return absl::StrCat("{name=", name, ", config=", config.ToString(), "}"); +} + +// +// XdsApi::LdsUpdate::FilterChainData +// + +std::string XdsApi::LdsUpdate::FilterChainData::ToString() const { + return absl::StrCat( + "{downstream_tls_context=", downstream_tls_context.ToString(), + " http_connection_manager=", http_connection_manager.ToString(), "}"); +} + +// +// XdsApi::LdsUpdate::FilterChainMap::CidrRange +// + +std::string XdsApi::LdsUpdate::FilterChainMap::CidrRange::ToString() const { + return absl::StrCat( + "{address_prefix=", grpc_sockaddr_to_string(&address, false), + ", prefix_len=", prefix_len, "}"); +} + +// +// FilterChain +// + +struct FilterChain { + struct FilterChainMatch { + uint32_t destination_port = 0; + std::vector prefix_ranges; + XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceType source_type = + XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceType::kAny; + std::vector + source_prefix_ranges; + std::vector source_ports; + std::vector server_names; + std::string transport_protocol; + std::vector application_protocols; + + std::string ToString() const; + } filter_chain_match; + + std::shared_ptr filter_chain_data; +}; + +std::string FilterChain::FilterChainMatch::ToString() const { + absl::InlinedVector contents; + if (destination_port != 0) { + contents.push_back(absl::StrCat("destination_port=", destination_port)); + } + if (!prefix_ranges.empty()) { + std::vector prefix_ranges_content; + for (const auto& range : prefix_ranges) { + prefix_ranges_content.push_back(range.ToString()); + } + contents.push_back(absl::StrCat( + "prefix_ranges={", absl::StrJoin(prefix_ranges_content, ", "), "}")); + } + if (source_type == XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceType:: + kSameIpOrLoopback) { + contents.push_back("source_type=SAME_IP_OR_LOOPBACK"); + } else if (source_type == XdsApi::LdsUpdate::FilterChainMap:: + ConnectionSourceType::kExternal) { + contents.push_back("source_type=EXTERNAL"); + } + if (!source_prefix_ranges.empty()) { + std::vector source_prefix_ranges_content; + for (const auto& range : source_prefix_ranges) { + source_prefix_ranges_content.push_back(range.ToString()); + } + contents.push_back( + absl::StrCat("source_prefix_ranges={", + absl::StrJoin(source_prefix_ranges_content, ", "), "}")); + } + if (!source_ports.empty()) { + contents.push_back( + absl::StrCat("source_ports={", absl::StrJoin(source_ports, ", "), "}")); + } + if (!server_names.empty()) { + contents.push_back( + absl::StrCat("server_names={", absl::StrJoin(server_names, ", "), "}")); + } + if (!transport_protocol.empty()) { + contents.push_back(absl::StrCat("transport_protocol=", transport_protocol)); + } + if (!application_protocols.empty()) { + contents.push_back(absl::StrCat("application_protocols={", + absl::StrJoin(application_protocols, ", 
"), + "}")); + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +// +// XdsApi::LdsUpdate::FilterChainMap +// + +std::string XdsApi::LdsUpdate::FilterChainMap::ToString() const { + std::vector contents; + for (const auto& destination_ip : destination_ip_vector) { + for (int source_type = 0; source_type < 3; ++source_type) { + for (const auto& source_ip : + destination_ip.source_types_array[source_type]) { + for (const auto& source_port_pair : source_ip.ports_map) { + FilterChain::FilterChainMatch filter_chain_match; + if (destination_ip.prefix_range.has_value()) { + filter_chain_match.prefix_ranges.push_back( + *destination_ip.prefix_range); + } + filter_chain_match.source_type = static_cast< + XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceType>( + source_type); + if (source_ip.prefix_range.has_value()) { + filter_chain_match.source_prefix_ranges.push_back( + *source_ip.prefix_range); + } + if (source_port_pair.first != 0) { + filter_chain_match.source_ports.push_back(source_port_pair.first); + } + contents.push_back(absl::StrCat( + "{filter_chain_match=", filter_chain_match.ToString(), + ", filter_chain=", source_port_pair.second.data->ToString(), + "}")); + } + } + } + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +// +// XdsApi::LdsUpdate +// + +std::string XdsApi::LdsUpdate::ToString() const { + absl::InlinedVector contents; + if (type == ListenerType::kTcpListener) { + contents.push_back(absl::StrCat("address=", address)); + contents.push_back( + absl::StrCat("filter_chain_map=", filter_chain_map.ToString())); + if (default_filter_chain.has_value()) { + contents.push_back(absl::StrCat("default_filter_chain=", + default_filter_chain->ToString())); + } + } else if (type == ListenerType::kHttpApiListener) { + contents.push_back(absl::StrFormat("http_connection_manager=%s", + http_connection_manager.ToString())); + } + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +// +// XdsApi::CdsUpdate +// + +std::string XdsApi::CdsUpdate::ToString() const { + absl::InlinedVector contents; + switch (cluster_type) { + case EDS: + contents.push_back("cluster_type=EDS"); + if (!eds_service_name.empty()) { + contents.push_back( + absl::StrFormat("eds_service_name=%s", eds_service_name)); + } + break; + case LOGICAL_DNS: + contents.push_back("cluster_type=LOGICAL_DNS"); + contents.push_back(absl::StrFormat("dns_hostname=%s", dns_hostname)); + break; + case AGGREGATE: + contents.push_back("cluster_type=AGGREGATE"); + contents.push_back( + absl::StrFormat("prioritized_cluster_names=[%s]", + absl::StrJoin(prioritized_cluster_names, ", "))); + } + if (!common_tls_context.Empty()) { + contents.push_back(absl::StrFormat("common_tls_context=%s", + common_tls_context.ToString())); + } + if (lrs_load_reporting_server_name.has_value()) { + contents.push_back(absl::StrFormat("lrs_load_reporting_server_name=%s", + lrs_load_reporting_server_name.value())); + } + contents.push_back(absl::StrCat("lb_policy=", lb_policy)); + if (lb_policy == "RING_HASH") { + contents.push_back(absl::StrCat("min_ring_size=", min_ring_size)); + contents.push_back(absl::StrCat("max_ring_size=", max_ring_size)); + } + contents.push_back( + absl::StrFormat("max_concurrent_requests=%d", max_concurrent_requests)); + return absl::StrCat("{", absl::StrJoin(contents, ", "), "}"); +} + +// +// XdsApi::EdsUpdate +// + +std::string XdsApi::EdsUpdate::Priority::Locality::ToString() const { + std::vector endpoint_strings; + for (const ServerAddress& endpoint : endpoints) { + 
endpoint_strings.emplace_back(endpoint.ToString()); + } + return absl::StrCat("{name=", name->AsHumanReadableString(), + ", lb_weight=", lb_weight, ", endpoints=[", + absl::StrJoin(endpoint_strings, ", "), "]}"); +} + +bool XdsApi::EdsUpdate::Priority::operator==(const Priority& other) const { + if (localities.size() != other.localities.size()) return false; + auto it1 = localities.begin(); + auto it2 = other.localities.begin(); + while (it1 != localities.end()) { + if (*it1->first != *it2->first) return false; + if (it1->second != it2->second) return false; + ++it1; + ++it2; + } + return true; +} + +std::string XdsApi::EdsUpdate::Priority::ToString() const { + std::vector locality_strings; + for (const auto& p : localities) { + locality_strings.emplace_back(p.second.ToString()); + } + return absl::StrCat("[", absl::StrJoin(locality_strings, ", "), "]"); +} + +bool XdsApi::EdsUpdate::DropConfig::ShouldDrop( + const std::string** category_name) const { + for (size_t i = 0; i < drop_category_list_.size(); ++i) { + const auto& drop_category = drop_category_list_[i]; + // Generate a random number in [0, 1000000). + const uint32_t random = static_cast(rand()) % 1000000; + if (random < drop_category.parts_per_million) { + *category_name = &drop_category.name; + return true; + } + } + return false; +} + +std::string XdsApi::EdsUpdate::DropConfig::ToString() const { + std::vector category_strings; + for (const DropCategory& category : drop_category_list_) { + category_strings.emplace_back( + absl::StrCat(category.name, "=", category.parts_per_million)); + } + return absl::StrCat("{[", absl::StrJoin(category_strings, ", "), + "], drop_all=", drop_all_, "}"); +} + +std::string XdsApi::EdsUpdate::ToString() const { + std::vector priority_strings; + for (size_t i = 0; i < priorities.size(); ++i) { + const Priority& priority = priorities[i]; + priority_strings.emplace_back( + absl::StrCat("priority ", i, ": ", priority.ToString())); + } + return absl::StrCat("priorities=[", absl::StrJoin(priority_strings, ", "), + "], drop_config=", drop_config->ToString()); +} + +// +// XdsApi +// + +const char* XdsApi::kLdsTypeUrl = + "type.googleapis.com/envoy.config.listener.v3.Listener"; +const char* XdsApi::kRdsTypeUrl = + "type.googleapis.com/envoy.config.route.v3.RouteConfiguration"; +const char* XdsApi::kCdsTypeUrl = + "type.googleapis.com/envoy.config.cluster.v3.Cluster"; +const char* XdsApi::kEdsTypeUrl = + "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment"; + +namespace { + +const char* kLdsV2TypeUrl = "type.googleapis.com/envoy.api.v2.Listener"; +const char* kRdsV2TypeUrl = + "type.googleapis.com/envoy.api.v2.RouteConfiguration"; +const char* kCdsV2TypeUrl = "type.googleapis.com/envoy.api.v2.Cluster"; +const char* kEdsV2TypeUrl = + "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"; + +bool IsLds(absl::string_view type_url, bool* is_v2 = nullptr) { + if (type_url == XdsApi::kLdsTypeUrl) return true; + if (type_url == kLdsV2TypeUrl) { + if (is_v2 != nullptr) *is_v2 = true; + return true; + } + return false; +} + +bool IsRds(absl::string_view type_url, bool* /*is_v2*/ = nullptr) { + return type_url == XdsApi::kRdsTypeUrl || type_url == kRdsV2TypeUrl; +} + +bool IsCds(absl::string_view type_url, bool* /*is_v2*/ = nullptr) { + return type_url == XdsApi::kCdsTypeUrl || type_url == kCdsV2TypeUrl; +} + +bool IsEds(absl::string_view type_url, bool* /*is_v2*/ = nullptr) { + return type_url == XdsApi::kEdsTypeUrl || type_url == kEdsV2TypeUrl; +} + +} // namespace + +// If gRPC is built with 
-DGRPC_XDS_USER_AGENT_NAME_SUFFIX="...", that string +// will be appended to the user agent name reported to the xDS server. +#ifdef GRPC_XDS_USER_AGENT_NAME_SUFFIX +#define GRPC_XDS_USER_AGENT_NAME_SUFFIX_STRING \ + " " GRPC_XDS_USER_AGENT_NAME_SUFFIX +#else +#define GRPC_XDS_USER_AGENT_NAME_SUFFIX_STRING "" +#endif + +// If gRPC is built with -DGRPC_XDS_USER_AGENT_VERSION_SUFFIX="...", that string +// will be appended to the user agent version reported to the xDS server. +#ifdef GRPC_XDS_USER_AGENT_VERSION_SUFFIX +#define GRPC_XDS_USER_AGENT_VERSION_SUFFIX_STRING \ + " " GRPC_XDS_USER_AGENT_VERSION_SUFFIX +#else +#define GRPC_XDS_USER_AGENT_VERSION_SUFFIX_STRING "" +#endif + +XdsApi::XdsApi(XdsClient* client, TraceFlag* tracer, + const XdsBootstrap::Node* node, + const CertificateProviderStore::PluginDefinitionMap* + certificate_provider_definition_map) + : client_(client), + tracer_(tracer), + node_(node), + certificate_provider_definition_map_(certificate_provider_definition_map), + build_version_(absl::StrCat("gRPC C-core ", GPR_PLATFORM_STRING, " ", + grpc_version_string(), + GRPC_XDS_USER_AGENT_NAME_SUFFIX_STRING, + GRPC_XDS_USER_AGENT_VERSION_SUFFIX_STRING)), + user_agent_name_(absl::StrCat("gRPC C-core ", GPR_PLATFORM_STRING, + GRPC_XDS_USER_AGENT_NAME_SUFFIX_STRING)), + user_agent_version_( + absl::StrCat("C-core ", grpc_version_string(), + GRPC_XDS_USER_AGENT_NAME_SUFFIX_STRING, + GRPC_XDS_USER_AGENT_VERSION_SUFFIX_STRING)) { + // Populate upb symtab with xDS proto messages that we want to print + // properly in logs. + // Note: This won't actually work properly until upb adds support for + // Any fields in textproto printing (internal b/178821188). + envoy_config_listener_v3_Listener_getmsgdef(symtab_.ptr()); + envoy_config_route_v3_RouteConfiguration_getmsgdef(symtab_.ptr()); + envoy_config_cluster_v3_Cluster_getmsgdef(symtab_.ptr()); + envoy_extensions_clusters_aggregate_v3_ClusterConfig_getmsgdef(symtab_.ptr()); + envoy_config_cluster_v3_Cluster_getmsgdef(symtab_.ptr()); + envoy_config_endpoint_v3_ClusterLoadAssignment_getmsgdef(symtab_.ptr()); + envoy_extensions_transport_sockets_tls_v3_UpstreamTlsContext_getmsgdef( + symtab_.ptr()); + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_getmsgdef( + symtab_.ptr()); + // Load HTTP filter proto messages into the upb symtab. + XdsHttpFilterRegistry::PopulateSymtab(symtab_.ptr()); +} + +namespace { + +struct EncodingContext { + XdsClient* client; // Used only for logging. Unsafe for dereferencing. + TraceFlag* tracer; + upb_symtab* symtab; + upb_arena* arena; + bool use_v3; + const CertificateProviderStore::PluginDefinitionMap* + certificate_provider_definition_map; +}; + +// Works for both std::string and absl::string_view. 
+template +inline upb_strview StdStringToUpbString(const T& str) { + return upb_strview_make(str.data(), str.size()); +} + +void PopulateMetadataValue(const EncodingContext& context, + google_protobuf_Value* value_pb, const Json& value); + +void PopulateListValue(const EncodingContext& context, + google_protobuf_ListValue* list_value, + const Json::Array& values) { + for (const auto& value : values) { + auto* value_pb = + google_protobuf_ListValue_add_values(list_value, context.arena); + PopulateMetadataValue(context, value_pb, value); + } +} + +void PopulateMetadata(const EncodingContext& context, + google_protobuf_Struct* metadata_pb, + const Json::Object& metadata) { + for (const auto& p : metadata) { + google_protobuf_Value* value = google_protobuf_Value_new(context.arena); + PopulateMetadataValue(context, value, p.second); + google_protobuf_Struct_fields_set( + metadata_pb, StdStringToUpbString(p.first), value, context.arena); + } +} + +void PopulateMetadataValue(const EncodingContext& context, + google_protobuf_Value* value_pb, const Json& value) { + switch (value.type()) { + case Json::Type::JSON_NULL: + google_protobuf_Value_set_null_value(value_pb, 0); + break; + case Json::Type::NUMBER: + google_protobuf_Value_set_number_value( + value_pb, strtod(value.string_value().c_str(), nullptr)); + break; + case Json::Type::STRING: + google_protobuf_Value_set_string_value( + value_pb, StdStringToUpbString(value.string_value())); + break; + case Json::Type::JSON_TRUE: + google_protobuf_Value_set_bool_value(value_pb, true); + break; + case Json::Type::JSON_FALSE: + google_protobuf_Value_set_bool_value(value_pb, false); + break; + case Json::Type::OBJECT: { + google_protobuf_Struct* struct_value = + google_protobuf_Value_mutable_struct_value(value_pb, context.arena); + PopulateMetadata(context, struct_value, value.object_value()); + break; + } + case Json::Type::ARRAY: { + google_protobuf_ListValue* list_value = + google_protobuf_Value_mutable_list_value(value_pb, context.arena); + PopulateListValue(context, list_value, value.array_value()); + break; + } + } +} + +// Helper functions to manually do protobuf string encoding, so that we +// can populate the node build_version field that was removed in v3. +std::string EncodeVarint(uint64_t val) { + std::string data; + do { + uint8_t byte = val & 0x7fU; + val >>= 7; + if (val) byte |= 0x80U; + data += byte; + } while (val); + return data; +} +std::string EncodeTag(uint32_t field_number, uint8_t wire_type) { + return EncodeVarint((field_number << 3) | wire_type); +} +std::string EncodeStringField(uint32_t field_number, const std::string& str) { + static const uint8_t kDelimitedWireType = 2; + return EncodeTag(field_number, kDelimitedWireType) + + EncodeVarint(str.size()) + str; +} + +void PopulateBuildVersion(const EncodingContext& context, + envoy_config_core_v3_Node* node_msg, + const std::string& build_version) { + std::string encoded_build_version = EncodeStringField(5, build_version); + // TODO(roth): This should use upb_msg_addunknown(), but that API is + // broken in the current version of upb, so we're using the internal + // API for now. Change this once we upgrade to a version of upb that + // fixes this bug. 
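+  // For illustration (hypothetical value): EncodeStringField(5, "abc")
+  // yields the bytes {0x2a, 0x03, 'a', 'b', 'c'}: the tag varint
+  // (5 << 3) | 2 = 0x2a (field 5, length-delimited wire type 2), a varint
+  // length of 3, and then the raw string bytes. Appending these bytes as an
+  // unknown field is equivalent on the wire to setting the removed
+  // build_version field (field 5) directly.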
+ _upb_msg_addunknown(node_msg, encoded_build_version.data(), + encoded_build_version.size(), context.arena); +} + +void PopulateNode(const EncodingContext& context, + const XdsBootstrap::Node* node, + const std::string& build_version, + const std::string& user_agent_name, + const std::string& user_agent_version, + envoy_config_core_v3_Node* node_msg) { + if (node != nullptr) { + if (!node->id.empty()) { + envoy_config_core_v3_Node_set_id(node_msg, + StdStringToUpbString(node->id)); + } + if (!node->cluster.empty()) { + envoy_config_core_v3_Node_set_cluster( + node_msg, StdStringToUpbString(node->cluster)); + } + if (!node->metadata.object_value().empty()) { + google_protobuf_Struct* metadata = + envoy_config_core_v3_Node_mutable_metadata(node_msg, context.arena); + PopulateMetadata(context, metadata, node->metadata.object_value()); + } + if (!node->locality_region.empty() || !node->locality_zone.empty() || + !node->locality_sub_zone.empty()) { + envoy_config_core_v3_Locality* locality = + envoy_config_core_v3_Node_mutable_locality(node_msg, context.arena); + if (!node->locality_region.empty()) { + envoy_config_core_v3_Locality_set_region( + locality, StdStringToUpbString(node->locality_region)); + } + if (!node->locality_zone.empty()) { + envoy_config_core_v3_Locality_set_zone( + locality, StdStringToUpbString(node->locality_zone)); + } + if (!node->locality_sub_zone.empty()) { + envoy_config_core_v3_Locality_set_sub_zone( + locality, StdStringToUpbString(node->locality_sub_zone)); + } + } + } + if (!context.use_v3) { + PopulateBuildVersion(context, node_msg, build_version); + } + envoy_config_core_v3_Node_set_user_agent_name( + node_msg, StdStringToUpbString(user_agent_name)); + envoy_config_core_v3_Node_set_user_agent_version( + node_msg, StdStringToUpbString(user_agent_version)); + envoy_config_core_v3_Node_add_client_features( + node_msg, upb_strview_makez("envoy.lb.does_not_support_overprovisioning"), + context.arena); +} + +inline absl::string_view UpbStringToAbsl(const upb_strview& str) { + return absl::string_view(str.data, str.size); +} + +inline std::string UpbStringToStdString(const upb_strview& str) { + return std::string(str.data, str.size); +} + +void MaybeLogDiscoveryRequest( + const EncodingContext& context, + const envoy_service_discovery_v3_DiscoveryRequest* request) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_service_discovery_v3_DiscoveryRequest_getmsgdef(context.symtab); + char buf[10240]; + upb_text_encode(request, msg_type, nullptr, 0, buf, sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] constructed ADS request: %s", + context.client, buf); + } +} + +grpc_slice SerializeDiscoveryRequest( + const EncodingContext& context, + envoy_service_discovery_v3_DiscoveryRequest* request) { + size_t output_length; + char* output = envoy_service_discovery_v3_DiscoveryRequest_serialize( + request, context.arena, &output_length); + return grpc_slice_from_copied_buffer(output, output_length); +} + +absl::string_view TypeUrlExternalToInternal(bool use_v3, + const std::string& type_url) { + if (!use_v3) { + if (type_url == XdsApi::kLdsTypeUrl) { + return kLdsV2TypeUrl; + } + if (type_url == XdsApi::kRdsTypeUrl) { + return kRdsV2TypeUrl; + } + if (type_url == XdsApi::kCdsTypeUrl) { + return kCdsV2TypeUrl; + } + if (type_url == XdsApi::kEdsTypeUrl) { + return kEdsV2TypeUrl; + } + } + return type_url; +} + +} // namespace + +grpc_slice XdsApi::CreateAdsRequest( + const XdsBootstrap::XdsServer& 
server, const std::string& type_url, + const std::set& resource_names, + const std::string& version, const std::string& nonce, + grpc_error_handle error, bool populate_node) { + upb::Arena arena; + const EncodingContext context = {client_, + tracer_, + symtab_.ptr(), + arena.ptr(), + server.ShouldUseV3(), + certificate_provider_definition_map_}; + // Create a request. + envoy_service_discovery_v3_DiscoveryRequest* request = + envoy_service_discovery_v3_DiscoveryRequest_new(arena.ptr()); + // Set type_url. + absl::string_view real_type_url = + TypeUrlExternalToInternal(server.ShouldUseV3(), type_url); + envoy_service_discovery_v3_DiscoveryRequest_set_type_url( + request, StdStringToUpbString(real_type_url)); + // Set version_info. + if (!version.empty()) { + envoy_service_discovery_v3_DiscoveryRequest_set_version_info( + request, StdStringToUpbString(version)); + } + // Set nonce. + if (!nonce.empty()) { + envoy_service_discovery_v3_DiscoveryRequest_set_response_nonce( + request, StdStringToUpbString(nonce)); + } + // Set error_detail if it's a NACK. + std::string error_string_storage; + if (error != GRPC_ERROR_NONE) { + google_rpc_Status* error_detail = + envoy_service_discovery_v3_DiscoveryRequest_mutable_error_detail( + request, arena.ptr()); + // Hard-code INVALID_ARGUMENT as the status code. + // TODO(roth): If at some point we decide we care about this value, + // we could attach a status code to the individual errors where we + // generate them in the parsing code, and then use that here. + google_rpc_Status_set_code(error_detail, GRPC_STATUS_INVALID_ARGUMENT); + // Error description comes from the error that was passed in. + error_string_storage = grpc_error_std_string(error); + upb_strview error_description = StdStringToUpbString(error_string_storage); + google_rpc_Status_set_message(error_detail, error_description); + GRPC_ERROR_UNREF(error); + } + // Populate node. + if (populate_node) { + envoy_config_core_v3_Node* node_msg = + envoy_service_discovery_v3_DiscoveryRequest_mutable_node(request, + arena.ptr()); + PopulateNode(context, node_, build_version_, user_agent_name_, + user_agent_version_, node_msg); + } + // Add resource_names. 
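+  // For illustration (hypothetical values), the request assembled here
+  // corresponds to a DiscoveryRequest along the lines of:
+  //   type_url: "type.googleapis.com/envoy.config.listener.v3.Listener"
+  //   version_info: "2"  response_nonce: "A"
+  //   node { id: "some-node-id" ... }
+  //   resource_names: "listener.example.com"
+  // with error_detail populated only when NACKing a prior response.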
+ for (const auto& resource_name : resource_names) { + envoy_service_discovery_v3_DiscoveryRequest_add_resource_names( + request, StdStringToUpbString(resource_name), arena.ptr()); + } + MaybeLogDiscoveryRequest(context, request); + return SerializeDiscoveryRequest(context, request); +} + +namespace { + +void MaybeLogDiscoveryResponse( + const EncodingContext& context, + const envoy_service_discovery_v3_DiscoveryResponse* response) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_service_discovery_v3_DiscoveryResponse_getmsgdef(context.symtab); + char buf[10240]; + upb_text_encode(response, msg_type, nullptr, 0, buf, sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] received response: %s", context.client, + buf); + } +} + +void MaybeLogListener(const EncodingContext& context, + const envoy_config_listener_v3_Listener* listener) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_config_listener_v3_Listener_getmsgdef(context.symtab); + char buf[10240]; + upb_text_encode(listener, msg_type, nullptr, 0, buf, sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] Listener: %s", context.client, buf); + } +} + +void MaybeLogHttpConnectionManager( + const EncodingContext& context, + const envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager* + http_connection_manager_config) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_getmsgdef( + context.symtab); + char buf[10240]; + upb_text_encode(http_connection_manager_config, msg_type, nullptr, 0, buf, + sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] HttpConnectionManager: %s", + context.client, buf); + } +} + +void MaybeLogRouteConfiguration( + const EncodingContext& context, + const envoy_config_route_v3_RouteConfiguration* route_config) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_config_route_v3_RouteConfiguration_getmsgdef(context.symtab); + char buf[10240]; + upb_text_encode(route_config, msg_type, nullptr, 0, buf, sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] RouteConfiguration: %s", context.client, + buf); + } +} + +void MaybeLogCluster(const EncodingContext& context, + const envoy_config_cluster_v3_Cluster* cluster) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_config_cluster_v3_Cluster_getmsgdef(context.symtab); + char buf[10240]; + upb_text_encode(cluster, msg_type, nullptr, 0, buf, sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] Cluster: %s", context.client, buf); + } +} + +void MaybeLogClusterLoadAssignment( + const EncodingContext& context, + const envoy_config_endpoint_v3_ClusterLoadAssignment* cla) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_config_endpoint_v3_ClusterLoadAssignment_getmsgdef( + context.symtab); + char buf[10240]; + upb_text_encode(cla, msg_type, nullptr, 0, buf, sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] ClusterLoadAssignment: %s", + context.client, buf); + } +} + +grpc_error_handle RoutePathMatchParse( + const envoy_config_route_v3_RouteMatch* match, XdsApi::Route* route, + 
bool* ignore_route) { + auto* case_sensitive_ptr = + envoy_config_route_v3_RouteMatch_case_sensitive(match); + bool case_sensitive = true; + if (case_sensitive_ptr != nullptr) { + case_sensitive = google_protobuf_BoolValue_value(case_sensitive_ptr); + } + StringMatcher::Type type; + std::string match_string; + if (envoy_config_route_v3_RouteMatch_has_prefix(match)) { + absl::string_view prefix = + UpbStringToAbsl(envoy_config_route_v3_RouteMatch_prefix(match)); + // Empty prefix "" is accepted. + if (!prefix.empty()) { + // Prefix "/" is accepted. + if (prefix[0] != '/') { + // Prefix which does not start with a / will never match anything, so + // ignore this route. + *ignore_route = true; + return GRPC_ERROR_NONE; + } + std::vector prefix_elements = + absl::StrSplit(prefix.substr(1), absl::MaxSplits('/', 2)); + if (prefix_elements.size() > 2) { + // Prefix cannot have more than 2 slashes. + *ignore_route = true; + return GRPC_ERROR_NONE; + } else if (prefix_elements.size() == 2 && prefix_elements[0].empty()) { + // Prefix contains empty string between the 2 slashes + *ignore_route = true; + return GRPC_ERROR_NONE; + } + } + type = StringMatcher::Type::kPrefix; + match_string = std::string(prefix); + } else if (envoy_config_route_v3_RouteMatch_has_path(match)) { + absl::string_view path = + UpbStringToAbsl(envoy_config_route_v3_RouteMatch_path(match)); + if (path.empty()) { + // Path that is empty will never match anything, so ignore this route. + *ignore_route = true; + return GRPC_ERROR_NONE; + } + if (path[0] != '/') { + // Path which does not start with a / will never match anything, so + // ignore this route. + *ignore_route = true; + return GRPC_ERROR_NONE; + } + std::vector path_elements = + absl::StrSplit(path.substr(1), absl::MaxSplits('/', 2)); + if (path_elements.size() != 2) { + // Path not in the required format of /service/method will never match + // anything, so ignore this route. + *ignore_route = true; + return GRPC_ERROR_NONE; + } else if (path_elements[0].empty()) { + // Path contains empty service name will never match anything, so ignore + // this route. + *ignore_route = true; + return GRPC_ERROR_NONE; + } else if (path_elements[1].empty()) { + // Path contains empty method name will never match anything, so ignore + // this route. 
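+    // For illustration (hypothetical values): prefixes "", "/" and
+    // "/pkg.Service/" are accepted, while "pkg.Service" (no leading '/') or
+    // "/a/b/c" (more than two path segments) cause the route to be ignored;
+    // a path matcher must look like "/pkg.Service/Method" with a non-empty
+    // service and method.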
+ *ignore_route = true; + return GRPC_ERROR_NONE; + } + type = StringMatcher::Type::kExact; + match_string = std::string(path); + } else if (envoy_config_route_v3_RouteMatch_has_safe_regex(match)) { + const envoy_type_matcher_v3_RegexMatcher* regex_matcher = + envoy_config_route_v3_RouteMatch_safe_regex(match); + GPR_ASSERT(regex_matcher != nullptr); + type = StringMatcher::Type::kSafeRegex; + match_string = UpbStringToStdString( + envoy_type_matcher_v3_RegexMatcher_regex(regex_matcher)); + } else { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid route path specifier specified."); + } + absl::StatusOr string_matcher = + StringMatcher::Create(type, match_string, case_sensitive); + if (!string_matcher.ok()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("path matcher: ", string_matcher.status().message())); + } + route->matchers.path_matcher = std::move(string_matcher.value()); + return GRPC_ERROR_NONE; +} + +grpc_error_handle RouteHeaderMatchersParse( + const envoy_config_route_v3_RouteMatch* match, XdsApi::Route* route) { + size_t size; + const envoy_config_route_v3_HeaderMatcher* const* headers = + envoy_config_route_v3_RouteMatch_headers(match, &size); + for (size_t i = 0; i < size; ++i) { + const envoy_config_route_v3_HeaderMatcher* header = headers[i]; + const std::string name = + UpbStringToStdString(envoy_config_route_v3_HeaderMatcher_name(header)); + HeaderMatcher::Type type; + std::string match_string; + int64_t range_start = 0; + int64_t range_end = 0; + bool present_match = false; + if (envoy_config_route_v3_HeaderMatcher_has_exact_match(header)) { + type = HeaderMatcher::Type::kExact; + match_string = UpbStringToStdString( + envoy_config_route_v3_HeaderMatcher_exact_match(header)); + } else if (envoy_config_route_v3_HeaderMatcher_has_safe_regex_match( + header)) { + const envoy_type_matcher_v3_RegexMatcher* regex_matcher = + envoy_config_route_v3_HeaderMatcher_safe_regex_match(header); + GPR_ASSERT(regex_matcher != nullptr); + type = HeaderMatcher::Type::kSafeRegex; + match_string = UpbStringToStdString( + envoy_type_matcher_v3_RegexMatcher_regex(regex_matcher)); + } else if (envoy_config_route_v3_HeaderMatcher_has_range_match(header)) { + type = HeaderMatcher::Type::kRange; + const envoy_type_v3_Int64Range* range_matcher = + envoy_config_route_v3_HeaderMatcher_range_match(header); + range_start = envoy_type_v3_Int64Range_start(range_matcher); + range_end = envoy_type_v3_Int64Range_end(range_matcher); + } else if (envoy_config_route_v3_HeaderMatcher_has_present_match(header)) { + type = HeaderMatcher::Type::kPresent; + present_match = envoy_config_route_v3_HeaderMatcher_present_match(header); + } else if (envoy_config_route_v3_HeaderMatcher_has_prefix_match(header)) { + type = HeaderMatcher::Type::kPrefix; + match_string = UpbStringToStdString( + envoy_config_route_v3_HeaderMatcher_prefix_match(header)); + } else if (envoy_config_route_v3_HeaderMatcher_has_suffix_match(header)) { + type = HeaderMatcher::Type::kSuffix; + match_string = UpbStringToStdString( + envoy_config_route_v3_HeaderMatcher_suffix_match(header)); + } else if (envoy_config_route_v3_HeaderMatcher_has_contains_match(header)) { + type = HeaderMatcher::Type::kContains; + match_string = UpbStringToStdString( + envoy_config_route_v3_HeaderMatcher_contains_match(header)); + } else { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid route header matcher specified."); + } + bool invert_match = + envoy_config_route_v3_HeaderMatcher_invert_match(header); + absl::StatusOr header_matcher = + 
HeaderMatcher::Create(name, type, match_string, range_start, range_end, + present_match, invert_match); + if (!header_matcher.ok()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("header matcher: ", header_matcher.status().message())); + } + route->matchers.header_matchers.emplace_back( + std::move(header_matcher.value())); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle RouteRuntimeFractionParse( + const envoy_config_route_v3_RouteMatch* match, XdsApi::Route* route) { + const envoy_config_core_v3_RuntimeFractionalPercent* runtime_fraction = + envoy_config_route_v3_RouteMatch_runtime_fraction(match); + if (runtime_fraction != nullptr) { + const envoy_type_v3_FractionalPercent* fraction = + envoy_config_core_v3_RuntimeFractionalPercent_default_value( + runtime_fraction); + if (fraction != nullptr) { + uint32_t numerator = envoy_type_v3_FractionalPercent_numerator(fraction); + const auto denominator = + static_cast( + envoy_type_v3_FractionalPercent_denominator(fraction)); + // Normalize to million. + switch (denominator) { + case envoy_type_v3_FractionalPercent_HUNDRED: + numerator *= 10000; + break; + case envoy_type_v3_FractionalPercent_TEN_THOUSAND: + numerator *= 100; + break; + case envoy_type_v3_FractionalPercent_MILLION: + break; + default: + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unknown denominator type"); + } + route->matchers.fraction_per_million = numerator; + } + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle ExtractHttpFilterTypeName(const EncodingContext& context, + const google_protobuf_Any* any, + absl::string_view* filter_type) { + *filter_type = UpbStringToAbsl(google_protobuf_Any_type_url(any)); + if (*filter_type == "type.googleapis.com/xds.type.v3.TypedStruct" || + *filter_type == "type.googleapis.com/udpa.type.v1.TypedStruct") { + upb_strview any_value = google_protobuf_Any_value(any); + const auto* typed_struct = xds_type_v3_TypedStruct_parse( + any_value.data, any_value.size, context.arena); + if (typed_struct == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "could not parse TypedStruct from filter config"); + } + *filter_type = + UpbStringToAbsl(xds_type_v3_TypedStruct_type_url(typed_struct)); + } + *filter_type = absl::StripPrefix(*filter_type, "type.googleapis.com/"); + return GRPC_ERROR_NONE; +} + +template +grpc_error_handle ParseTypedPerFilterConfig( + const EncodingContext& context, const ParentType* parent, + const EntryType* (*entry_func)(const ParentType*, size_t*), + upb_strview (*key_func)(const EntryType*), + const google_protobuf_Any* (*value_func)(const EntryType*), + XdsApi::TypedPerFilterConfig* typed_per_filter_config) { + size_t filter_it = UPB_MAP_BEGIN; + while (true) { + const auto* filter_entry = entry_func(parent, &filter_it); + if (filter_entry == nullptr) break; + absl::string_view key = UpbStringToAbsl(key_func(filter_entry)); + if (key.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("empty filter name in map"); + } + const google_protobuf_Any* any = value_func(filter_entry); + GPR_ASSERT(any != nullptr); + absl::string_view filter_type = + UpbStringToAbsl(google_protobuf_Any_type_url(any)); + if (filter_type.empty()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("no filter config specified for filter name ", key)); + } + bool is_optional = false; + if (filter_type == + "type.googleapis.com/envoy.config.route.v3.FilterConfig") { + upb_strview any_value = google_protobuf_Any_value(any); + const auto* filter_config = envoy_config_route_v3_FilterConfig_parse( + 
any_value.data, any_value.size, context.arena); + if (filter_config == nullptr) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("could not parse FilterConfig wrapper for ", key)); + } + is_optional = + envoy_config_route_v3_FilterConfig_is_optional(filter_config); + any = envoy_config_route_v3_FilterConfig_config(filter_config); + if (any == nullptr) { + if (is_optional) continue; + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("no filter config specified for filter name ", key)); + } + } + grpc_error_handle error = + ExtractHttpFilterTypeName(context, any, &filter_type); + if (error != GRPC_ERROR_NONE) return error; + const XdsHttpFilterImpl* filter_impl = + XdsHttpFilterRegistry::GetFilterForType(filter_type); + if (filter_impl == nullptr) { + if (is_optional) continue; + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("no filter registered for config type ", filter_type)); + } + absl::StatusOr filter_config = + filter_impl->GenerateFilterConfigOverride( + google_protobuf_Any_value(any), context.arena); + if (!filter_config.ok()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "filter config for type ", filter_type, + " failed to parse: ", filter_config.status().ToString())); + } + (*typed_per_filter_config)[std::string(key)] = std::move(*filter_config); + } + return GRPC_ERROR_NONE; +} + +XdsApi::Duration DurationParse(const google_protobuf_Duration* proto_duration) { + XdsApi::Duration duration; + duration.seconds = google_protobuf_Duration_seconds(proto_duration); + duration.nanos = google_protobuf_Duration_nanos(proto_duration); + return duration; +} + +grpc_error_handle RetryPolicyParse( + const EncodingContext& context, + const envoy_config_route_v3_RetryPolicy* retry_policy, + absl::optional* retry) { + std::vector errors; + XdsApi::Route::RetryPolicy retry_to_return; + auto retry_on = UpbStringToStdString( + envoy_config_route_v3_RetryPolicy_retry_on(retry_policy)); + std::vector codes = absl::StrSplit(retry_on, ','); + for (const auto& code : codes) { + if (code == "cancelled") { + retry_to_return.retry_on.Add(GRPC_STATUS_CANCELLED); + } else if (code == "deadline-exceeded") { + retry_to_return.retry_on.Add(GRPC_STATUS_DEADLINE_EXCEEDED); + } else if (code == "internal") { + retry_to_return.retry_on.Add(GRPC_STATUS_INTERNAL); + } else if (code == "resource-exhausted") { + retry_to_return.retry_on.Add(GRPC_STATUS_RESOURCE_EXHAUSTED); + } else if (code == "unavailable") { + retry_to_return.retry_on.Add(GRPC_STATUS_UNAVAILABLE); + } else { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer)) { + gpr_log(GPR_INFO, "Unsupported retry_on policy %s.", + std::string(code).c_str()); + } + } + } + const google_protobuf_UInt32Value* num_retries = + envoy_config_route_v3_RetryPolicy_num_retries(retry_policy); + if (num_retries != nullptr) { + uint32_t num_retries_value = google_protobuf_UInt32Value_value(num_retries); + if (num_retries_value == 0) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "RouteAction RetryPolicy num_retries set to invalid value 0.")); + } else { + retry_to_return.num_retries = num_retries_value; + } + } else { + retry_to_return.num_retries = 1; + } + const envoy_config_route_v3_RetryPolicy_RetryBackOff* backoff = + envoy_config_route_v3_RetryPolicy_retry_back_off(retry_policy); + if (backoff != nullptr) { + const google_protobuf_Duration* base_interval = + envoy_config_route_v3_RetryPolicy_RetryBackOff_base_interval(backoff); + if (base_interval == nullptr) { + 
errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "RouteAction RetryPolicy RetryBackoff missing base interval."));
+ } else {
+ retry_to_return.retry_back_off.base_interval =
+ DurationParse(base_interval);
+ }
+ const google_protobuf_Duration* max_interval =
+ envoy_config_route_v3_RetryPolicy_RetryBackOff_max_interval(backoff);
+ XdsApi::Duration max;
+ if (max_interval != nullptr) {
+ max = DurationParse(max_interval);
+ } else {
+ // If max_interval is not set, default it to 10x the base interval; if the
+ // scaled nanos exceed one second, carry the excess into the seconds field.
+ max.seconds = retry_to_return.retry_back_off.base_interval.seconds * 10;
+ max.nanos = retry_to_return.retry_back_off.base_interval.nanos * 10;
+ if (max.nanos > 1000000000) {
+ max.seconds += max.nanos / 1000000000;
+ max.nanos = max.nanos % 1000000000;
+ }
+ }
+ retry_to_return.retry_back_off.max_interval = max;
+ } else {
+ retry_to_return.retry_back_off.base_interval.seconds = 0;
+ retry_to_return.retry_back_off.base_interval.nanos = 25000000;
+ retry_to_return.retry_back_off.max_interval.seconds = 0;
+ retry_to_return.retry_back_off.max_interval.nanos = 250000000;
+ }
+ if (errors.empty()) {
+ *retry = retry_to_return;
+ return GRPC_ERROR_NONE;
+ } else {
+ return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing retry policy",
+ &errors);
+ }
+}
+
+grpc_error_handle RouteActionParse(const EncodingContext& context,
+ const envoy_config_route_v3_Route* route_msg,
+ XdsApi::Route* route, bool* ignore_route) {
+ if (!envoy_config_route_v3_Route_has_route(route_msg)) {
+ return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "No RouteAction found in route.");
+ }
+ const envoy_config_route_v3_RouteAction* route_action =
+ envoy_config_route_v3_Route_route(route_msg);
+ // Get the cluster or weighted_clusters in the RouteAction.
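+ // For illustration, a WeightedCluster along the lines of (proto text,
+ // roughly)
+ //   weighted_clusters {
+ //     total_weight { value: 100 }
+ //     clusters { name: "cluster_a" weight { value: 75 } }
+ //     clusters { name: "cluster_b" weight { value: 25 } }
+ //   }
+ // is accepted because the per-cluster weights sum to total_weight (which
+ // defaults to 100); clusters with weight 0 are dropped, and a mismatched
+ // sum or an empty resulting list is rejected below. The cluster names here
+ // are made up for the example.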
+ if (envoy_config_route_v3_RouteAction_has_cluster(route_action)) { + route->cluster_name = UpbStringToStdString( + envoy_config_route_v3_RouteAction_cluster(route_action)); + if (route->cluster_name.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "RouteAction cluster contains empty cluster name."); + } + } else if (envoy_config_route_v3_RouteAction_has_weighted_clusters( + route_action)) { + const envoy_config_route_v3_WeightedCluster* weighted_cluster = + envoy_config_route_v3_RouteAction_weighted_clusters(route_action); + uint32_t total_weight = 100; + const google_protobuf_UInt32Value* weight = + envoy_config_route_v3_WeightedCluster_total_weight(weighted_cluster); + if (weight != nullptr) { + total_weight = google_protobuf_UInt32Value_value(weight); + } + size_t clusters_size; + const envoy_config_route_v3_WeightedCluster_ClusterWeight* const* clusters = + envoy_config_route_v3_WeightedCluster_clusters(weighted_cluster, + &clusters_size); + uint32_t sum_of_weights = 0; + for (size_t j = 0; j < clusters_size; ++j) { + const envoy_config_route_v3_WeightedCluster_ClusterWeight* + cluster_weight = clusters[j]; + XdsApi::Route::ClusterWeight cluster; + cluster.name = UpbStringToStdString( + envoy_config_route_v3_WeightedCluster_ClusterWeight_name( + cluster_weight)); + if (cluster.name.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "RouteAction weighted_cluster cluster contains empty cluster " + "name."); + } + const google_protobuf_UInt32Value* weight = + envoy_config_route_v3_WeightedCluster_ClusterWeight_weight( + cluster_weight); + if (weight == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "RouteAction weighted_cluster cluster missing weight"); + } + cluster.weight = google_protobuf_UInt32Value_value(weight); + if (cluster.weight == 0) continue; + sum_of_weights += cluster.weight; + if (context.use_v3) { + grpc_error_handle error = ParseTypedPerFilterConfig< + envoy_config_route_v3_WeightedCluster_ClusterWeight, + envoy_config_route_v3_WeightedCluster_ClusterWeight_TypedPerFilterConfigEntry>( + context, cluster_weight, + envoy_config_route_v3_WeightedCluster_ClusterWeight_typed_per_filter_config_next, + envoy_config_route_v3_WeightedCluster_ClusterWeight_TypedPerFilterConfigEntry_key, + envoy_config_route_v3_WeightedCluster_ClusterWeight_TypedPerFilterConfigEntry_value, + &cluster.typed_per_filter_config); + if (error != GRPC_ERROR_NONE) return error; + } + route->weighted_clusters.emplace_back(std::move(cluster)); + } + if (total_weight != sum_of_weights) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "RouteAction weighted_cluster has incorrect total weight"); + } + if (route->weighted_clusters.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "RouteAction weighted_cluster has no valid clusters specified."); + } + } else { + // No cluster or weighted_clusters found in RouteAction, ignore this route. 
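+ // (Such a route is skipped rather than treated as an error, for example
+ // when the action uses a cluster specifier this client does not implement;
+ // the caller checks *ignore_route and simply moves on to the next route.)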
+ *ignore_route = true; + } + if (!*ignore_route) { + const envoy_config_route_v3_RouteAction_MaxStreamDuration* + max_stream_duration = + envoy_config_route_v3_RouteAction_max_stream_duration(route_action); + if (max_stream_duration != nullptr) { + const google_protobuf_Duration* duration = + envoy_config_route_v3_RouteAction_MaxStreamDuration_grpc_timeout_header_max( + max_stream_duration); + if (duration == nullptr) { + duration = + envoy_config_route_v3_RouteAction_MaxStreamDuration_max_stream_duration( + max_stream_duration); + } + if (duration != nullptr) { + route->max_stream_duration = DurationParse(duration); + } + } + } + // Get HashPolicy from RouteAction + size_t size = 0; + const envoy_config_route_v3_RouteAction_HashPolicy* const* hash_policies = + envoy_config_route_v3_RouteAction_hash_policy(route_action, &size); + for (size_t i = 0; i < size; ++i) { + const envoy_config_route_v3_RouteAction_HashPolicy* hash_policy = + hash_policies[i]; + XdsApi::Route::HashPolicy policy; + policy.terminal = + envoy_config_route_v3_RouteAction_HashPolicy_terminal(hash_policy); + const envoy_config_route_v3_RouteAction_HashPolicy_Header* header; + const envoy_config_route_v3_RouteAction_HashPolicy_FilterState* + filter_state; + if ((header = envoy_config_route_v3_RouteAction_HashPolicy_header( + hash_policy)) != nullptr) { + policy.type = XdsApi::Route::HashPolicy::Type::HEADER; + policy.header_name = UpbStringToStdString( + envoy_config_route_v3_RouteAction_HashPolicy_Header_header_name( + header)); + const struct envoy_type_matcher_v3_RegexMatchAndSubstitute* + regex_rewrite = + envoy_config_route_v3_RouteAction_HashPolicy_Header_regex_rewrite( + header); + if (regex_rewrite != nullptr) { + const envoy_type_matcher_v3_RegexMatcher* regex_matcher = + envoy_type_matcher_v3_RegexMatchAndSubstitute_pattern( + regex_rewrite); + if (regex_matcher == nullptr) { + gpr_log( + GPR_DEBUG, + "RouteAction HashPolicy contains policy specifier Header with " + "RegexMatchAndSubstitution but RegexMatcher pattern is " + "missing"); + continue; + } + RE2::Options options; + policy.regex = absl::make_unique( + UpbStringToStdString( + envoy_type_matcher_v3_RegexMatcher_regex(regex_matcher)), + options); + if (!policy.regex->ok()) { + gpr_log( + GPR_DEBUG, + "RouteAction HashPolicy contains policy specifier Header with " + "RegexMatchAndSubstitution but RegexMatcher pattern does not " + "compile"); + continue; + } + policy.regex_substitution = UpbStringToStdString( + envoy_type_matcher_v3_RegexMatchAndSubstitute_substitution( + regex_rewrite)); + } + } else if ((filter_state = + envoy_config_route_v3_RouteAction_HashPolicy_filter_state( + hash_policy)) != nullptr) { + std::string key = UpbStringToStdString( + envoy_config_route_v3_RouteAction_HashPolicy_FilterState_key( + filter_state)); + if (key == "io.grpc.channel_id") { + policy.type = XdsApi::Route::HashPolicy::Type::CHANNEL_ID; + } else { + gpr_log(GPR_DEBUG, + "RouteAction HashPolicy contains policy specifier " + "FilterState but " + "key is not io.grpc.channel_id."); + continue; + } + } else { + gpr_log(GPR_DEBUG, + "RouteAction HashPolicy contains unsupported policy specifier."); + continue; + } + route->hash_policies.emplace_back(std::move(policy)); + } + // Get retry policy + const envoy_config_route_v3_RetryPolicy* retry_policy = + envoy_config_route_v3_RouteAction_retry_policy(route_action); + if (retry_policy != nullptr) { + absl::optional retry; + grpc_error_handle error = RetryPolicyParse(context, retry_policy, &retry); + if (error != 
GRPC_ERROR_NONE) return error; + route->retry_policy = retry; + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle RouteConfigParse( + const EncodingContext& context, + const envoy_config_route_v3_RouteConfiguration* route_config, + bool /*is_v2*/, XdsApi::RdsUpdate* rds_update) { + MaybeLogRouteConfiguration(context, route_config); + // Get the virtual hosts. + size_t num_virtual_hosts; + const envoy_config_route_v3_VirtualHost* const* virtual_hosts = + envoy_config_route_v3_RouteConfiguration_virtual_hosts( + route_config, &num_virtual_hosts); + for (size_t i = 0; i < num_virtual_hosts; ++i) { + rds_update->virtual_hosts.emplace_back(); + XdsApi::RdsUpdate::VirtualHost& vhost = rds_update->virtual_hosts.back(); + // Parse domains. + size_t domain_size; + upb_strview const* domains = envoy_config_route_v3_VirtualHost_domains( + virtual_hosts[i], &domain_size); + for (size_t j = 0; j < domain_size; ++j) { + std::string domain_pattern = UpbStringToStdString(domains[j]); + const MatchType match_type = DomainPatternMatchType(domain_pattern); + if (match_type == INVALID_MATCH) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Invalid domain pattern \"", domain_pattern, "\".")); + } + vhost.domains.emplace_back(std::move(domain_pattern)); + } + if (vhost.domains.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("VirtualHost has no domains"); + } + // Parse typed_per_filter_config. + if (context.use_v3) { + grpc_error_handle error = ParseTypedPerFilterConfig< + envoy_config_route_v3_VirtualHost, + envoy_config_route_v3_VirtualHost_TypedPerFilterConfigEntry>( + context, virtual_hosts[i], + envoy_config_route_v3_VirtualHost_typed_per_filter_config_next, + envoy_config_route_v3_VirtualHost_TypedPerFilterConfigEntry_key, + envoy_config_route_v3_VirtualHost_TypedPerFilterConfigEntry_value, + &vhost.typed_per_filter_config); + if (error != GRPC_ERROR_NONE) return error; + } + // Parse retry policy. + absl::optional virtual_host_retry_policy; + const envoy_config_route_v3_RetryPolicy* retry_policy = + envoy_config_route_v3_VirtualHost_retry_policy(virtual_hosts[i]); + if (retry_policy != nullptr) { + grpc_error_handle error = + RetryPolicyParse(context, retry_policy, &virtual_host_retry_policy); + if (error != GRPC_ERROR_NONE) return error; + } + // Parse routes. 
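+ // Each route goes through several stages: routes whose match carries
+ // query_parameters are skipped outright; then the path match, header
+ // matchers, and runtime_fraction are parsed; and finally the RouteAction
+ // itself. An error at any stage rejects the whole RouteConfiguration,
+ // while an "ignored" route is silently dropped.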
+ size_t num_routes; + const envoy_config_route_v3_Route* const* routes = + envoy_config_route_v3_VirtualHost_routes(virtual_hosts[i], &num_routes); + if (num_routes < 1) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No route found in the virtual host."); + } + // Loop over the whole list of routes + for (size_t j = 0; j < num_routes; ++j) { + const envoy_config_route_v3_RouteMatch* match = + envoy_config_route_v3_Route_match(routes[j]); + if (match == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Match can't be null."); + } + size_t query_parameters_size; + static_cast(envoy_config_route_v3_RouteMatch_query_parameters( + match, &query_parameters_size)); + if (query_parameters_size > 0) { + continue; + } + XdsApi::Route route; + bool ignore_route = false; + grpc_error_handle error = + RoutePathMatchParse(match, &route, &ignore_route); + if (error != GRPC_ERROR_NONE) return error; + if (ignore_route) continue; + error = RouteHeaderMatchersParse(match, &route); + if (error != GRPC_ERROR_NONE) return error; + error = RouteRuntimeFractionParse(match, &route); + if (error != GRPC_ERROR_NONE) return error; + error = RouteActionParse(context, routes[j], &route, &ignore_route); + if (error != GRPC_ERROR_NONE) return error; + if (ignore_route) continue; + if (route.retry_policy == absl::nullopt && retry_policy != nullptr) { + route.retry_policy = virtual_host_retry_policy; + } + if (context.use_v3) { + grpc_error_handle error = ParseTypedPerFilterConfig< + envoy_config_route_v3_Route, + envoy_config_route_v3_Route_TypedPerFilterConfigEntry>( + context, routes[j], + envoy_config_route_v3_Route_typed_per_filter_config_next, + envoy_config_route_v3_Route_TypedPerFilterConfigEntry_key, + envoy_config_route_v3_Route_TypedPerFilterConfigEntry_value, + &route.typed_per_filter_config); + if (error != GRPC_ERROR_NONE) return error; + } + vhost.routes.emplace_back(std::move(route)); + } + if (vhost.routes.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("No valid routes specified."); + } + } + return GRPC_ERROR_NONE; +} + +// CertificateProviderInstance is deprecated but we are still supporting it for +// backward compatibility reasons. Note that we still parse the data into the +// same CertificateProviderPluginInstance struct since the fields are the same. +// TODO(yashykt): Remove this once we stop supporting the old way of fetching +// certificate provider instances. 
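+// Either way, the parsed instance_name must refer to a certificate provider
+// plugin defined in the bootstrap file's "certificate_providers" map, e.g.
+// (illustrative bootstrap snippet, names made up):
+//   "certificate_providers": {
+//     "example_provider": {
+//       "plugin_name": "file_watcher",
+//       "config": {...}
+//     }
+//   }
+// An instance_name with no matching entry is rejected by both parsers below.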
+grpc_error_handle CertificateProviderInstanceParse( + const EncodingContext& context, + const envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_CertificateProviderInstance* + certificate_provider_instance_proto, + XdsApi::CommonTlsContext::CertificateProviderPluginInstance* + certificate_provider_plugin_instance) { + *certificate_provider_plugin_instance = { + UpbStringToStdString( + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_CertificateProviderInstance_instance_name( + certificate_provider_instance_proto)), + UpbStringToStdString( + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_CertificateProviderInstance_certificate_name( + certificate_provider_instance_proto))}; + if (context.certificate_provider_definition_map->find( + certificate_provider_plugin_instance->instance_name) == + context.certificate_provider_definition_map->end()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unrecognized certificate provider instance name: ", + certificate_provider_plugin_instance->instance_name)); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle CertificateProviderPluginInstanceParse( + const EncodingContext& context, + const envoy_extensions_transport_sockets_tls_v3_CertificateProviderPluginInstance* + certificate_provider_plugin_instance_proto, + XdsApi::CommonTlsContext::CertificateProviderPluginInstance* + certificate_provider_plugin_instance) { + *certificate_provider_plugin_instance = { + UpbStringToStdString( + envoy_extensions_transport_sockets_tls_v3_CertificateProviderPluginInstance_instance_name( + certificate_provider_plugin_instance_proto)), + UpbStringToStdString( + envoy_extensions_transport_sockets_tls_v3_CertificateProviderPluginInstance_certificate_name( + certificate_provider_plugin_instance_proto))}; + if (context.certificate_provider_definition_map->find( + certificate_provider_plugin_instance->instance_name) == + context.certificate_provider_definition_map->end()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unrecognized certificate provider instance name: ", + certificate_provider_plugin_instance->instance_name)); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle CertificateValidationContextParse( + const EncodingContext& context, + const envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext* + certificate_validation_context_proto, + XdsApi::CommonTlsContext::CertificateValidationContext* + certificate_validation_context) { + std::vector errors; + size_t len = 0; + auto* subject_alt_names_matchers = + envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext_match_subject_alt_names( + certificate_validation_context_proto, &len); + for (size_t i = 0; i < len; ++i) { + StringMatcher::Type type; + std::string matcher; + if (envoy_type_matcher_v3_StringMatcher_has_exact( + subject_alt_names_matchers[i])) { + type = StringMatcher::Type::kExact; + matcher = UpbStringToStdString(envoy_type_matcher_v3_StringMatcher_exact( + subject_alt_names_matchers[i])); + } else if (envoy_type_matcher_v3_StringMatcher_has_prefix( + subject_alt_names_matchers[i])) { + type = StringMatcher::Type::kPrefix; + matcher = UpbStringToStdString(envoy_type_matcher_v3_StringMatcher_prefix( + subject_alt_names_matchers[i])); + } else if (envoy_type_matcher_v3_StringMatcher_has_suffix( + subject_alt_names_matchers[i])) { + type = StringMatcher::Type::kSuffix; + matcher = UpbStringToStdString(envoy_type_matcher_v3_StringMatcher_suffix( + subject_alt_names_matchers[i])); + } else if 
(envoy_type_matcher_v3_StringMatcher_has_contains( + subject_alt_names_matchers[i])) { + type = StringMatcher::Type::kContains; + matcher = + UpbStringToStdString(envoy_type_matcher_v3_StringMatcher_contains( + subject_alt_names_matchers[i])); + } else if (envoy_type_matcher_v3_StringMatcher_has_safe_regex( + subject_alt_names_matchers[i])) { + type = StringMatcher::Type::kSafeRegex; + auto* regex_matcher = envoy_type_matcher_v3_StringMatcher_safe_regex( + subject_alt_names_matchers[i]); + matcher = UpbStringToStdString( + envoy_type_matcher_v3_RegexMatcher_regex(regex_matcher)); + } else { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid StringMatcher specified")); + continue; + } + bool ignore_case = envoy_type_matcher_v3_StringMatcher_ignore_case( + subject_alt_names_matchers[i]); + absl::StatusOr string_matcher = + StringMatcher::Create(type, matcher, + /*case_sensitive=*/!ignore_case); + if (!string_matcher.ok()) { + errors.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("string matcher: ", string_matcher.status().message()))); + continue; + } + if (type == StringMatcher::Type::kSafeRegex && ignore_case) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "StringMatcher: ignore_case has no effect for SAFE_REGEX.")); + continue; + } + certificate_validation_context->match_subject_alt_names.push_back( + std::move(string_matcher.value())); + } + auto* ca_certificate_provider_instance = + envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext_ca_certificate_provider_instance( + certificate_validation_context_proto); + if (ca_certificate_provider_instance != nullptr) { + grpc_error_handle error = CertificateProviderPluginInstanceParse( + context, ca_certificate_provider_instance, + &certificate_validation_context->ca_certificate_provider_instance); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } + if (envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext_verify_certificate_spki( + certificate_validation_context_proto, nullptr) != nullptr) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "CertificateValidationContext: verify_certificate_spki " + "unsupported")); + } + if (envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext_verify_certificate_hash( + certificate_validation_context_proto, nullptr) != nullptr) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "CertificateValidationContext: verify_certificate_hash " + "unsupported")); + } + auto* require_signed_certificate_timestamp = + envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext_require_signed_certificate_timestamp( + certificate_validation_context_proto); + if (require_signed_certificate_timestamp != nullptr && + google_protobuf_BoolValue_value(require_signed_certificate_timestamp)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "CertificateValidationContext: " + "require_signed_certificate_timestamp unsupported")); + } + if (envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext_has_crl( + certificate_validation_context_proto)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "CertificateValidationContext: crl unsupported")); + } + if (envoy_extensions_transport_sockets_tls_v3_CertificateValidationContext_has_custom_validator_config( + certificate_validation_context_proto)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "CertificateValidationContext: custom_validator_config " + "unsupported")); + } + return 
GRPC_ERROR_CREATE_FROM_VECTOR( + "Error parsing CertificateValidationContext", &errors); +} + +grpc_error_handle CommonTlsContextParse( + const EncodingContext& context, + const envoy_extensions_transport_sockets_tls_v3_CommonTlsContext* + common_tls_context_proto, + XdsApi::CommonTlsContext* common_tls_context) { + std::vector errors; + // The validation context is derived from the oneof in + // 'validation_context_type'. 'validation_context_sds_secret_config' is not + // supported. + auto* combined_validation_context = + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_combined_validation_context( + common_tls_context_proto); + if (combined_validation_context != nullptr) { + auto* default_validation_context = + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_CombinedCertificateValidationContext_default_validation_context( + combined_validation_context); + if (default_validation_context != nullptr) { + grpc_error_handle error = CertificateValidationContextParse( + context, default_validation_context, + &common_tls_context->certificate_validation_context); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } + // If after parsing default_validation_context, + // common_tls_context->certificate_validation_context.ca_certificate_provider_instance + // is empty, fall back onto + // 'validation_context_certificate_provider_instance' inside + // 'combined_validation_context'. Note that this way of fetching root + // certificates is deprecated and will be removed in the future. + // TODO(yashykt): Remove this once it's no longer needed. + auto* validation_context_certificate_provider_instance = + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_CombinedCertificateValidationContext_validation_context_certificate_provider_instance( + combined_validation_context); + if (common_tls_context->certificate_validation_context + .ca_certificate_provider_instance.Empty() && + validation_context_certificate_provider_instance != nullptr) { + grpc_error_handle error = CertificateProviderInstanceParse( + context, validation_context_certificate_provider_instance, + &common_tls_context->certificate_validation_context + .ca_certificate_provider_instance); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } + } else { + auto* validation_context = + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_validation_context( + common_tls_context_proto); + if (validation_context != nullptr) { + grpc_error_handle error = CertificateValidationContextParse( + context, validation_context, + &common_tls_context->certificate_validation_context); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } else if ( + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_has_validation_context_sds_secret_config( + common_tls_context_proto)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "validation_context_sds_secret_config unsupported")); + } + } + auto* tls_certificate_provider_instance = + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_tls_certificate_provider_instance( + common_tls_context_proto); + if (tls_certificate_provider_instance != nullptr) { + grpc_error_handle error = CertificateProviderPluginInstanceParse( + context, tls_certificate_provider_instance, + &common_tls_context->tls_certificate_provider_instance); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } else { + // Fall back onto 'tls_certificate_certificate_provider_instance'. 
Note that + // this way of fetching identity certificates is deprecated and will be + // removed in the future. + // TODO(yashykt): Remove this once it's no longer needed. + auto* tls_certificate_certificate_provider_instance = + envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_tls_certificate_certificate_provider_instance( + common_tls_context_proto); + if (tls_certificate_certificate_provider_instance != nullptr) { + grpc_error_handle error = CertificateProviderInstanceParse( + context, tls_certificate_certificate_provider_instance, + &common_tls_context->tls_certificate_provider_instance); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } else { + if (envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_has_tls_certificates( + common_tls_context_proto)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "tls_certificates unsupported")); + } + if (envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_has_tls_certificate_sds_secret_configs( + common_tls_context_proto)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "tls_certificate_sds_secret_configs unsupported")); + } + } + } + if (envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_has_tls_params( + common_tls_context_proto)) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("tls_params unsupported")); + } + if (envoy_extensions_transport_sockets_tls_v3_CommonTlsContext_has_custom_handshaker( + common_tls_context_proto)) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("custom_handshaker unsupported")); + } + return GRPC_ERROR_CREATE_FROM_VECTOR("Error parsing CommonTlsContext", + &errors); +} + +grpc_error_handle HttpConnectionManagerParse( + bool is_client, const EncodingContext& context, + const envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager* + http_connection_manager_proto, + bool is_v2, + XdsApi::LdsUpdate::HttpConnectionManager* http_connection_manager) { + MaybeLogHttpConnectionManager(context, http_connection_manager_proto); + // Obtain max_stream_duration from Http Protocol Options. + const envoy_config_core_v3_HttpProtocolOptions* options = + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_common_http_protocol_options( + http_connection_manager_proto); + if (options != nullptr) { + const google_protobuf_Duration* duration = + envoy_config_core_v3_HttpProtocolOptions_max_stream_duration(options); + if (duration != nullptr) { + http_connection_manager->http_max_stream_duration = + DurationParse(duration); + } + } + // Parse filters. 
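+ // For example, an HttpConnectionManager whose http_filters list is,
+ // roughly,
+ //   http_filters { name: "fault"  typed_config { ... } }  # non-terminal
+ //   http_filters { name: "router" typed_config { ... } }  # terminal
+ // passes the checks below: the names are unique, each typed_config maps to
+ // a filter type known to the XdsHttpFilterRegistry, only the last filter
+ // is terminal, and every config parses. (The names and the fault filter
+ // are only illustrative.)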
+ if (!is_v2) { + size_t num_filters = 0; + const auto* http_filters = + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_http_filters( + http_connection_manager_proto, &num_filters); + std::set names_seen; + for (size_t i = 0; i < num_filters; ++i) { + const auto* http_filter = http_filters[i]; + absl::string_view name = UpbStringToAbsl( + envoy_extensions_filters_network_http_connection_manager_v3_HttpFilter_name( + http_filter)); + if (name.empty()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("empty filter name at index ", i)); + } + if (names_seen.find(name) != names_seen.end()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("duplicate HTTP filter name: ", name)); + } + names_seen.insert(name); + const bool is_optional = + envoy_extensions_filters_network_http_connection_manager_v3_HttpFilter_is_optional( + http_filter); + const google_protobuf_Any* any = + envoy_extensions_filters_network_http_connection_manager_v3_HttpFilter_typed_config( + http_filter); + if (any == nullptr) { + if (is_optional) continue; + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("no filter config specified for filter name ", name)); + } + absl::string_view filter_type; + grpc_error_handle error = + ExtractHttpFilterTypeName(context, any, &filter_type); + if (error != GRPC_ERROR_NONE) return error; + const XdsHttpFilterImpl* filter_impl = + XdsHttpFilterRegistry::GetFilterForType(filter_type); + if (filter_impl == nullptr) { + if (is_optional) continue; + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("no filter registered for config type ", filter_type)); + } + if ((is_client && !filter_impl->IsSupportedOnClients()) || + (!is_client && !filter_impl->IsSupportedOnServers())) { + if (is_optional) continue; + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Filter %s is not supported on %s", filter_type, + is_client ? "clients" : "servers")); + } + if (i < num_filters - 1) { + // Filters before the last filter must not be terminal. + if (filter_impl->IsTerminalFilter()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("terminal filter for config type ", filter_type, + " must be the last filter in the chain")); + } + } else { + // The last filter must be terminal. + if (!filter_impl->IsTerminalFilter()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("non-terminal filter for config type ", filter_type, + " is the last filter in the chain")); + } + } + absl::StatusOr filter_config = + filter_impl->GenerateFilterConfig(google_protobuf_Any_value(any), + context.arena); + if (!filter_config.ok()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "filter config for type ", filter_type, + " failed to parse: ", filter_config.status().ToString())); + } + http_connection_manager->http_filters.emplace_back( + XdsApi::LdsUpdate::HttpConnectionManager::HttpFilter{ + std::string(name), std::move(*filter_config)}); + } + } else { + // If using a v2 config, we just hard-code a list containing only the + // router filter without actually looking at the config. This ensures + // that the right thing happens in the xds resolver without having + // to expose whether the resource we received was v2 or v3. + http_connection_manager->http_filters.emplace_back( + XdsApi::LdsUpdate::HttpConnectionManager::HttpFilter{ + "router", {kXdsHttpRouterFilterConfigName, Json()}}); + } + if (is_client) { + // Found inlined route_config. Parse it to find the cluster_name. 
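+ // If there is no inlined route_config, the code below instead requires an
+ // Rds message whose config_source specifies ADS and records its
+ // route_config_name, so the route configuration is fetched via a separate
+ // RDS request.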
+ if (envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_has_route_config( + http_connection_manager_proto)) { + const envoy_config_route_v3_RouteConfiguration* route_config = + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_route_config( + http_connection_manager_proto); + XdsApi::RdsUpdate rds_update; + grpc_error_handle error = + RouteConfigParse(context, route_config, is_v2, &rds_update); + if (error != GRPC_ERROR_NONE) return error; + http_connection_manager->rds_update = std::move(rds_update); + return GRPC_ERROR_NONE; + } + // Validate that RDS must be used to get the route_config dynamically. + const envoy_extensions_filters_network_http_connection_manager_v3_Rds* rds = + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_rds( + http_connection_manager_proto); + if (rds == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "HttpConnectionManager neither has inlined route_config nor RDS."); + } + // Check that the ConfigSource specifies ADS. + const envoy_config_core_v3_ConfigSource* config_source = + envoy_extensions_filters_network_http_connection_manager_v3_Rds_config_source( + rds); + if (config_source == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "HttpConnectionManager missing config_source for RDS."); + } + if (!envoy_config_core_v3_ConfigSource_has_ads(config_source)) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "HttpConnectionManager ConfigSource for RDS does not specify ADS."); + } + // Get the route_config_name. + http_connection_manager->route_config_name = UpbStringToStdString( + envoy_extensions_filters_network_http_connection_manager_v3_Rds_route_config_name( + rds)); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle LdsResourceParseClient( + const EncodingContext& context, + const envoy_config_listener_v3_ApiListener* api_listener, bool is_v2, + XdsApi::LdsUpdate* lds_update) { + lds_update->type = XdsApi::LdsUpdate::ListenerType::kHttpApiListener; + const upb_strview encoded_api_listener = google_protobuf_Any_value( + envoy_config_listener_v3_ApiListener_api_listener(api_listener)); + const auto* http_connection_manager = + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_parse( + encoded_api_listener.data, encoded_api_listener.size, context.arena); + if (http_connection_manager == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not parse HttpConnectionManager config from ApiListener"); + } + return HttpConnectionManagerParse(true /* is_client */, context, + http_connection_manager, is_v2, + &lds_update->http_connection_manager); +} + +grpc_error_handle DownstreamTlsContextParse( + const EncodingContext& context, + const envoy_config_core_v3_TransportSocket* transport_socket, + XdsApi::DownstreamTlsContext* downstream_tls_context) { + absl::string_view name = UpbStringToAbsl( + envoy_config_core_v3_TransportSocket_name(transport_socket)); + if (name != "envoy.transport_sockets.tls") { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unrecognized transport socket: ", name)); + } + auto* typed_config = + envoy_config_core_v3_TransportSocket_typed_config(transport_socket); + std::vector errors; + if (typed_config != nullptr) { + const upb_strview encoded_downstream_tls_context = + google_protobuf_Any_value(typed_config); + auto* downstream_tls_context_proto = + envoy_extensions_transport_sockets_tls_v3_DownstreamTlsContext_parse( + encoded_downstream_tls_context.data, + 
encoded_downstream_tls_context.size, context.arena); + if (downstream_tls_context_proto == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Can't decode downstream tls context."); + } + auto* common_tls_context = + envoy_extensions_transport_sockets_tls_v3_DownstreamTlsContext_common_tls_context( + downstream_tls_context_proto); + if (common_tls_context != nullptr) { + grpc_error_handle error = + CommonTlsContextParse(context, common_tls_context, + &downstream_tls_context->common_tls_context); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } + auto* require_client_certificate = + envoy_extensions_transport_sockets_tls_v3_DownstreamTlsContext_require_client_certificate( + downstream_tls_context_proto); + if (require_client_certificate != nullptr) { + downstream_tls_context->require_client_certificate = + google_protobuf_BoolValue_value(require_client_certificate); + } + auto* require_sni = + envoy_extensions_transport_sockets_tls_v3_DownstreamTlsContext_require_sni( + downstream_tls_context_proto); + if (require_sni != nullptr && + google_protobuf_BoolValue_value(require_sni)) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("require_sni: unsupported")); + } + if (envoy_extensions_transport_sockets_tls_v3_DownstreamTlsContext_ocsp_staple_policy( + downstream_tls_context_proto) != + envoy_extensions_transport_sockets_tls_v3_DownstreamTlsContext_LENIENT_STAPLING) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "ocsp_staple_policy: Only LENIENT_STAPLING supported")); + } + } + if (downstream_tls_context->common_tls_context + .tls_certificate_provider_instance.instance_name.empty()) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "TLS configuration provided but no " + "tls_certificate_provider_instance found.")); + } + if (downstream_tls_context->require_client_certificate && + downstream_tls_context->common_tls_context.certificate_validation_context + .ca_certificate_provider_instance.instance_name.empty()) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "TLS configuration requires client certificates but no certificate " + "provider instance specified for validation.")); + } + if (!downstream_tls_context->common_tls_context.certificate_validation_context + .match_subject_alt_names.empty()) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "match_subject_alt_names not supported on servers")); + } + return GRPC_ERROR_CREATE_FROM_VECTOR("Error parsing DownstreamTlsContext", + &errors); +} + +grpc_error_handle CidrRangeParse( + const envoy_config_core_v3_CidrRange* cidr_range_proto, + XdsApi::LdsUpdate::FilterChainMap::CidrRange* cidr_range) { + std::string address_prefix = UpbStringToStdString( + envoy_config_core_v3_CidrRange_address_prefix(cidr_range_proto)); + grpc_error_handle error = + grpc_string_to_sockaddr(&cidr_range->address, address_prefix.c_str(), 0); + if (error != GRPC_ERROR_NONE) return error; + cidr_range->prefix_len = 0; + auto* prefix_len_proto = + envoy_config_core_v3_CidrRange_prefix_len(cidr_range_proto); + if (prefix_len_proto != nullptr) { + cidr_range->prefix_len = std::min( + google_protobuf_UInt32Value_value(prefix_len_proto), + (reinterpret_cast(cidr_range->address.addr)) + ->sa_family == GRPC_AF_INET + ? 
uint32_t(32) + : uint32_t(128)); + } + // Normalize the network address by masking it with prefix_len + grpc_sockaddr_mask_bits(&cidr_range->address, cidr_range->prefix_len); + return GRPC_ERROR_NONE; +} + +grpc_error_handle FilterChainMatchParse( + const envoy_config_listener_v3_FilterChainMatch* filter_chain_match_proto, + FilterChain::FilterChainMatch* filter_chain_match) { + auto* destination_port = + envoy_config_listener_v3_FilterChainMatch_destination_port( + filter_chain_match_proto); + if (destination_port != nullptr) { + filter_chain_match->destination_port = + google_protobuf_UInt32Value_value(destination_port); + } + size_t size = 0; + auto* prefix_ranges = envoy_config_listener_v3_FilterChainMatch_prefix_ranges( + filter_chain_match_proto, &size); + filter_chain_match->prefix_ranges.reserve(size); + for (size_t i = 0; i < size; i++) { + XdsApi::LdsUpdate::FilterChainMap::CidrRange cidr_range; + grpc_error_handle error = CidrRangeParse(prefix_ranges[i], &cidr_range); + if (error != GRPC_ERROR_NONE) return error; + filter_chain_match->prefix_ranges.push_back(cidr_range); + } + filter_chain_match->source_type = + static_cast( + envoy_config_listener_v3_FilterChainMatch_source_type( + filter_chain_match_proto)); + auto* source_prefix_ranges = + envoy_config_listener_v3_FilterChainMatch_source_prefix_ranges( + filter_chain_match_proto, &size); + filter_chain_match->source_prefix_ranges.reserve(size); + for (size_t i = 0; i < size; i++) { + XdsApi::LdsUpdate::FilterChainMap::CidrRange cidr_range; + grpc_error_handle error = + CidrRangeParse(source_prefix_ranges[i], &cidr_range); + if (error != GRPC_ERROR_NONE) return error; + filter_chain_match->source_prefix_ranges.push_back(cidr_range); + } + auto* source_ports = envoy_config_listener_v3_FilterChainMatch_source_ports( + filter_chain_match_proto, &size); + filter_chain_match->source_ports.reserve(size); + for (size_t i = 0; i < size; i++) { + filter_chain_match->source_ports.push_back(source_ports[i]); + } + auto* server_names = envoy_config_listener_v3_FilterChainMatch_server_names( + filter_chain_match_proto, &size); + for (size_t i = 0; i < size; i++) { + filter_chain_match->server_names.push_back( + UpbStringToStdString(server_names[i])); + } + filter_chain_match->transport_protocol = UpbStringToStdString( + envoy_config_listener_v3_FilterChainMatch_transport_protocol( + filter_chain_match_proto)); + auto* application_protocols = + envoy_config_listener_v3_FilterChainMatch_application_protocols( + filter_chain_match_proto, &size); + for (size_t i = 0; i < size; i++) { + filter_chain_match->application_protocols.push_back( + UpbStringToStdString(application_protocols[i])); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle FilterChainParse( + const EncodingContext& context, + const envoy_config_listener_v3_FilterChain* filter_chain_proto, bool is_v2, + FilterChain* filter_chain) { + std::vector errors; + auto* filter_chain_match = + envoy_config_listener_v3_FilterChain_filter_chain_match( + filter_chain_proto); + if (filter_chain_match != nullptr) { + grpc_error_handle error = FilterChainMatchParse( + filter_chain_match, &filter_chain->filter_chain_match); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } + // Parse the filters list. Currently we only support HttpConnectionManager. 
+ size_t size = 0; + auto* filters = + envoy_config_listener_v3_FilterChain_filters(filter_chain_proto, &size); + if (size != 1) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "FilterChain should have exactly one filter: HttpConnectionManager; no " + "other filter is supported at the moment")); + } else { + auto* typed_config = + envoy_config_listener_v3_Filter_typed_config(filters[0]); + if (typed_config == nullptr) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No typed_config found in filter.")); + } else { + absl::string_view type_url = + UpbStringToAbsl(google_protobuf_Any_type_url(typed_config)); + if (type_url != + "type.googleapis.com/" + "envoy.extensions.filters.network.http_connection_manager.v3." + "HttpConnectionManager") { + errors.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unsupported filter type ", type_url))); + } else { + const upb_strview encoded_http_connection_manager = + google_protobuf_Any_value(typed_config); + const auto* http_connection_manager = + envoy_extensions_filters_network_http_connection_manager_v3_HttpConnectionManager_parse( + encoded_http_connection_manager.data, + encoded_http_connection_manager.size, context.arena); + if (http_connection_manager == nullptr) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not parse HttpConnectionManager config from filter " + "typed_config")); + } else { + filter_chain->filter_chain_data = + std::make_shared(); + grpc_error_handle error = HttpConnectionManagerParse( + false /* is_client */, context, http_connection_manager, is_v2, + &filter_chain->filter_chain_data->http_connection_manager); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } + } + } + } + auto* transport_socket = + envoy_config_listener_v3_FilterChain_transport_socket(filter_chain_proto); + if (transport_socket != nullptr) { + grpc_error_handle error = DownstreamTlsContextParse( + context, transport_socket, + &filter_chain->filter_chain_data->downstream_tls_context); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } + return GRPC_ERROR_CREATE_FROM_VECTOR("Error parsing FilterChain", &errors); +} + +grpc_error_handle AddressParse( + const envoy_config_core_v3_Address* address_proto, std::string* address) { + const auto* socket_address = + envoy_config_core_v3_Address_socket_address(address_proto); + if (socket_address == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Address does not have socket_address"); + } + if (envoy_config_core_v3_SocketAddress_protocol(socket_address) != + envoy_config_core_v3_SocketAddress_TCP) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "SocketAddress protocol is not TCP"); + } + uint32_t port = envoy_config_core_v3_SocketAddress_port_value(socket_address); + if (port > 65535) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Invalid port"); + } + *address = JoinHostPort( + UpbStringToAbsl( + envoy_config_core_v3_SocketAddress_address(socket_address)), + port); + return GRPC_ERROR_NONE; +} + +// An intermediate map for filter chains that we create to validate the list of +// filter chains received from the control plane and to finally create +// XdsApi::LdsUpdate::FilterChainMap +struct InternalFilterChainMap { + using SourceIpMap = + std::map; + using ConnectionSourceTypesArray = std::array; + struct DestinationIp { + absl::optional prefix_range; + bool transport_protocol_raw_buffer_provided = false; + ConnectionSourceTypesArray source_types_array; + }; + using DestinationIpMap = std::map; + DestinationIpMap 
destination_ip_map; +}; + +grpc_error_handle AddFilterChainDataForSourcePort( + const FilterChain& filter_chain, + XdsApi::LdsUpdate::FilterChainMap::SourcePortsMap* ports_map, + uint32_t port) { + auto insert_result = ports_map->emplace( + port, XdsApi::LdsUpdate::FilterChainMap::FilterChainDataSharedPtr{ + filter_chain.filter_chain_data}); + if (!insert_result.second) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Duplicate matching rules detected when adding filter chain: ", + filter_chain.filter_chain_match.ToString())); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle AddFilterChainDataForSourcePorts( + const FilterChain& filter_chain, + XdsApi::LdsUpdate::FilterChainMap::SourcePortsMap* ports_map) { + if (filter_chain.filter_chain_match.source_ports.empty()) { + return AddFilterChainDataForSourcePort(filter_chain, ports_map, 0); + } else { + for (uint32_t port : filter_chain.filter_chain_match.source_ports) { + grpc_error_handle error = + AddFilterChainDataForSourcePort(filter_chain, ports_map, port); + if (error != GRPC_ERROR_NONE) return error; + } + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle AddFilterChainDataForSourceIpRange( + const FilterChain& filter_chain, + InternalFilterChainMap::SourceIpMap* source_ip_map) { + if (filter_chain.filter_chain_match.source_prefix_ranges.empty()) { + auto insert_result = source_ip_map->emplace( + "", XdsApi::LdsUpdate::FilterChainMap::SourceIp()); + return AddFilterChainDataForSourcePorts( + filter_chain, &insert_result.first->second.ports_map); + } else { + for (const auto& prefix_range : + filter_chain.filter_chain_match.source_prefix_ranges) { + auto insert_result = source_ip_map->emplace( + absl::StrCat(grpc_sockaddr_to_string(&prefix_range.address, false), + "/", prefix_range.prefix_len), + XdsApi::LdsUpdate::FilterChainMap::SourceIp()); + if (insert_result.second) { + insert_result.first->second.prefix_range.emplace(prefix_range); + } + grpc_error_handle error = AddFilterChainDataForSourcePorts( + filter_chain, &insert_result.first->second.ports_map); + if (error != GRPC_ERROR_NONE) return error; + } + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle AddFilterChainDataForSourceType( + const FilterChain& filter_chain, + InternalFilterChainMap::DestinationIp* destination_ip) { + GPR_ASSERT(static_cast( + filter_chain.filter_chain_match.source_type) < 3); + return AddFilterChainDataForSourceIpRange( + filter_chain, &destination_ip->source_types_array[static_cast( + filter_chain.filter_chain_match.source_type)]); +} + +grpc_error_handle AddFilterChainDataForApplicationProtocols( + const FilterChain& filter_chain, + InternalFilterChainMap::DestinationIp* destination_ip) { + // Only allow filter chains that do not mention application protocols + if (!filter_chain.filter_chain_match.application_protocols.empty()) { + return GRPC_ERROR_NONE; + } + return AddFilterChainDataForSourceType(filter_chain, destination_ip); +} + +grpc_error_handle AddFilterChainDataForTransportProtocol( + const FilterChain& filter_chain, + InternalFilterChainMap::DestinationIp* destination_ip) { + const std::string& transport_protocol = + filter_chain.filter_chain_match.transport_protocol; + // Only allow filter chains with no transport protocol or "raw_buffer" + if (!transport_protocol.empty() && transport_protocol != "raw_buffer") { + return GRPC_ERROR_NONE; + } + // If for this configuration, we've already seen filter chains that mention + // the transport protocol as "raw_buffer", we will never match filter chains + // that do not 
mention it. + if (destination_ip->transport_protocol_raw_buffer_provided && + transport_protocol.empty()) { + return GRPC_ERROR_NONE; + } + if (!transport_protocol.empty() && + !destination_ip->transport_protocol_raw_buffer_provided) { + destination_ip->transport_protocol_raw_buffer_provided = true; + // Clear out the previous entries if any since those entries did not mention + // "raw_buffer" + destination_ip->source_types_array = + InternalFilterChainMap::ConnectionSourceTypesArray(); + } + return AddFilterChainDataForApplicationProtocols(filter_chain, + destination_ip); +} + +grpc_error_handle AddFilterChainDataForServerNames( + const FilterChain& filter_chain, + InternalFilterChainMap::DestinationIp* destination_ip) { + // Don't continue adding filter chains with server names mentioned + if (!filter_chain.filter_chain_match.server_names.empty()) { + return GRPC_ERROR_NONE; + } + return AddFilterChainDataForTransportProtocol(filter_chain, destination_ip); +} + +grpc_error_handle AddFilterChainDataForDestinationIpRange( + const FilterChain& filter_chain, + InternalFilterChainMap::DestinationIpMap* destination_ip_map) { + if (filter_chain.filter_chain_match.prefix_ranges.empty()) { + auto insert_result = destination_ip_map->emplace( + "", InternalFilterChainMap::DestinationIp()); + return AddFilterChainDataForServerNames(filter_chain, + &insert_result.first->second); + } else { + for (const auto& prefix_range : + filter_chain.filter_chain_match.prefix_ranges) { + auto insert_result = destination_ip_map->emplace( + absl::StrCat(grpc_sockaddr_to_string(&prefix_range.address, false), + "/", prefix_range.prefix_len), + InternalFilterChainMap::DestinationIp()); + if (insert_result.second) { + insert_result.first->second.prefix_range.emplace(prefix_range); + } + grpc_error_handle error = AddFilterChainDataForServerNames( + filter_chain, &insert_result.first->second); + if (error != GRPC_ERROR_NONE) return error; + } + } + return GRPC_ERROR_NONE; +} + +XdsApi::LdsUpdate::FilterChainMap BuildFromInternalFilterChainMap( + InternalFilterChainMap* internal_filter_chain_map) { + XdsApi::LdsUpdate::FilterChainMap filter_chain_map; + for (auto& destination_ip_pair : + internal_filter_chain_map->destination_ip_map) { + XdsApi::LdsUpdate::FilterChainMap::DestinationIp destination_ip; + destination_ip.prefix_range = destination_ip_pair.second.prefix_range; + for (int i = 0; i < 3; i++) { + auto& source_ip_map = destination_ip_pair.second.source_types_array[i]; + for (auto& source_ip_pair : source_ip_map) { + destination_ip.source_types_array[i].push_back( + std::move(source_ip_pair.second)); + } + } + filter_chain_map.destination_ip_vector.push_back(std::move(destination_ip)); + } + return filter_chain_map; +} + +grpc_error_handle BuildFilterChainMap( + const std::vector& filter_chains, + XdsApi::LdsUpdate::FilterChainMap* filter_chain_map) { + InternalFilterChainMap internal_filter_chain_map; + for (const auto& filter_chain : filter_chains) { + // Discard filter chain entries that specify destination port + if (filter_chain.filter_chain_match.destination_port != 0) continue; + grpc_error_handle error = AddFilterChainDataForDestinationIpRange( + filter_chain, &internal_filter_chain_map.destination_ip_map); + if (error != GRPC_ERROR_NONE) return error; + } + *filter_chain_map = + BuildFromInternalFilterChainMap(&internal_filter_chain_map); + return GRPC_ERROR_NONE; +} + +grpc_error_handle LdsResourceParseServer( + const EncodingContext& context, + const envoy_config_listener_v3_Listener* listener, bool 
is_v2, + XdsApi::LdsUpdate* lds_update) { + lds_update->type = XdsApi::LdsUpdate::ListenerType::kTcpListener; + grpc_error_handle error = + AddressParse(envoy_config_listener_v3_Listener_address(listener), + &lds_update->address); + if (error != GRPC_ERROR_NONE) return error; + const auto* use_original_dst = + envoy_config_listener_v3_Listener_use_original_dst(listener); + if (use_original_dst != nullptr) { + if (google_protobuf_BoolValue_value(use_original_dst)) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Field \'use_original_dst\' is not supported."); + } + } + size_t size = 0; + auto* filter_chains = + envoy_config_listener_v3_Listener_filter_chains(listener, &size); + std::vector parsed_filter_chains; + parsed_filter_chains.reserve(size); + for (size_t i = 0; i < size; i++) { + FilterChain filter_chain; + error = FilterChainParse(context, filter_chains[i], is_v2, &filter_chain); + if (error != GRPC_ERROR_NONE) return error; + parsed_filter_chains.push_back(std::move(filter_chain)); + } + error = + BuildFilterChainMap(parsed_filter_chains, &lds_update->filter_chain_map); + if (error != GRPC_ERROR_NONE) return error; + auto* default_filter_chain = + envoy_config_listener_v3_Listener_default_filter_chain(listener); + if (default_filter_chain != nullptr) { + FilterChain filter_chain; + error = + FilterChainParse(context, default_filter_chain, is_v2, &filter_chain); + if (error != GRPC_ERROR_NONE) return error; + if (filter_chain.filter_chain_data != nullptr) { + lds_update->default_filter_chain = + std::move(*filter_chain.filter_chain_data); + } + } + if (size == 0 && default_filter_chain == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("No filter chain provided."); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle LdsResourceParse( + const EncodingContext& context, + const envoy_config_listener_v3_Listener* listener, bool is_v2, + XdsApi::LdsUpdate* lds_update) { + // Check whether it's a client or server listener. + const envoy_config_listener_v3_ApiListener* api_listener = + envoy_config_listener_v3_Listener_api_listener(listener); + const envoy_config_core_v3_Address* address = + envoy_config_listener_v3_Listener_address(listener); + if (api_listener != nullptr && address != nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Listener has both address and ApiListener"); + } + if (api_listener == nullptr && address == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Listener has neither address nor ApiListener"); + } + // Validate Listener fields. 
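+ // A client ("API") listener wraps an HttpConnectionManager inside
+ // api_listener and is handled by LdsResourceParseClient; a server listener
+ // carries an address plus filter_chains and is handled by
+ // LdsResourceParseServer. The checks above guarantee exactly one of the
+ // two shapes is present.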
+ grpc_error_handle error = GRPC_ERROR_NONE; + if (api_listener != nullptr) { + error = LdsResourceParseClient(context, api_listener, is_v2, lds_update); + } else { + error = LdsResourceParseServer(context, listener, is_v2, lds_update); + } + return error; +} + +grpc_error_handle UpstreamTlsContextParse( + const EncodingContext& context, + const envoy_config_core_v3_TransportSocket* transport_socket, + XdsApi::CommonTlsContext* common_tls_context) { + // Record Upstream tls context + absl::string_view name = UpbStringToAbsl( + envoy_config_core_v3_TransportSocket_name(transport_socket)); + if (name != "envoy.transport_sockets.tls") { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unrecognized transport socket: ", name)); + } + auto* typed_config = + envoy_config_core_v3_TransportSocket_typed_config(transport_socket); + if (typed_config != nullptr) { + const upb_strview encoded_upstream_tls_context = + google_protobuf_Any_value(typed_config); + auto* upstream_tls_context = + envoy_extensions_transport_sockets_tls_v3_UpstreamTlsContext_parse( + encoded_upstream_tls_context.data, + encoded_upstream_tls_context.size, context.arena); + if (upstream_tls_context == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Can't decode upstream tls context."); + } + auto* common_tls_context_proto = + envoy_extensions_transport_sockets_tls_v3_UpstreamTlsContext_common_tls_context( + upstream_tls_context); + if (common_tls_context_proto != nullptr) { + grpc_error_handle error = CommonTlsContextParse( + context, common_tls_context_proto, common_tls_context); + if (error != GRPC_ERROR_NONE) { + return grpc_error_add_child(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Error parsing UpstreamTlsContext"), + error); + } + } + } + if (common_tls_context->certificate_validation_context + .ca_certificate_provider_instance.instance_name.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "UpstreamTlsContext: TLS configuration provided but no " + "ca_certificate_provider_instance found."); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle CdsLogicalDnsParse( + const envoy_config_cluster_v3_Cluster* cluster, + XdsApi::CdsUpdate* cds_update) { + const auto* load_assignment = + envoy_config_cluster_v3_Cluster_load_assignment(cluster); + if (load_assignment == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "load_assignment not present for LOGICAL_DNS cluster"); + } + size_t num_localities; + const auto* const* localities = + envoy_config_endpoint_v3_ClusterLoadAssignment_endpoints(load_assignment, + &num_localities); + if (num_localities != 1) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("load_assignment for LOGICAL_DNS cluster must have " + "exactly one locality, found ", + num_localities)); + } + size_t num_endpoints; + const auto* const* endpoints = + envoy_config_endpoint_v3_LocalityLbEndpoints_lb_endpoints(localities[0], + &num_endpoints); + if (num_endpoints != 1) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("locality for LOGICAL_DNS cluster must have " + "exactly one endpoint, found ", + num_endpoints)); + } + const auto* endpoint = + envoy_config_endpoint_v3_LbEndpoint_endpoint(endpoints[0]); + if (endpoint == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "LbEndpoint endpoint field not set"); + } + const auto* address = envoy_config_endpoint_v3_Endpoint_address(endpoint); + if (address == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Endpoint address field not set"); + } + const auto* socket_address = + 
envoy_config_core_v3_Address_socket_address(address); + if (socket_address == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Address socket_address field not set"); + } + if (envoy_config_core_v3_SocketAddress_resolver_name(socket_address).size != + 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "LOGICAL_DNS clusters must NOT have a custom resolver name set"); + } + absl::string_view address_str = UpbStringToAbsl( + envoy_config_core_v3_SocketAddress_address(socket_address)); + if (address_str.empty()) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "SocketAddress address field not set"); + } + if (!envoy_config_core_v3_SocketAddress_has_port_value(socket_address)) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "SocketAddress port_value field not set"); + } + cds_update->dns_hostname = JoinHostPort( + address_str, + envoy_config_core_v3_SocketAddress_port_value(socket_address)); + return GRPC_ERROR_NONE; +} + +grpc_error_handle CdsResourceParse( + const EncodingContext& context, + const envoy_config_cluster_v3_Cluster* cluster, bool /*is_v2*/, + XdsApi::CdsUpdate* cds_update) { + std::vector errors; + // Check the cluster_discovery_type. + if (!envoy_config_cluster_v3_Cluster_has_type(cluster) && + !envoy_config_cluster_v3_Cluster_has_cluster_type(cluster)) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("DiscoveryType not found.")); + } else if (envoy_config_cluster_v3_Cluster_type(cluster) == + envoy_config_cluster_v3_Cluster_EDS) { + cds_update->cluster_type = XdsApi::CdsUpdate::ClusterType::EDS; + // Check the EDS config source. + const envoy_config_cluster_v3_Cluster_EdsClusterConfig* eds_cluster_config = + envoy_config_cluster_v3_Cluster_eds_cluster_config(cluster); + const envoy_config_core_v3_ConfigSource* eds_config = + envoy_config_cluster_v3_Cluster_EdsClusterConfig_eds_config( + eds_cluster_config); + if (!envoy_config_core_v3_ConfigSource_has_ads(eds_config)) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("EDS ConfigSource is not ADS.")); + } + // Record EDS service_name (if any). + upb_strview service_name = + envoy_config_cluster_v3_Cluster_EdsClusterConfig_service_name( + eds_cluster_config); + if (service_name.size != 0) { + cds_update->eds_service_name = UpbStringToStdString(service_name); + } + } else if (!XdsAggregateAndLogicalDnsClusterEnabled()) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("DiscoveryType is not valid.")); + } else if (envoy_config_cluster_v3_Cluster_type(cluster) == + envoy_config_cluster_v3_Cluster_LOGICAL_DNS) { + cds_update->cluster_type = XdsApi::CdsUpdate::ClusterType::LOGICAL_DNS; + grpc_error_handle error = CdsLogicalDnsParse(cluster, cds_update); + if (error != GRPC_ERROR_NONE) errors.push_back(error); + } else { + if (!envoy_config_cluster_v3_Cluster_has_cluster_type(cluster)) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("DiscoveryType is not valid.")); + } else { + const envoy_config_cluster_v3_Cluster_CustomClusterType* + custom_cluster_type = + envoy_config_cluster_v3_Cluster_cluster_type(cluster); + upb_strview type_name = + envoy_config_cluster_v3_Cluster_CustomClusterType_name( + custom_cluster_type); + if (UpbStringToAbsl(type_name) != "envoy.clusters.aggregate") { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "DiscoveryType is not valid.")); + } else { + cds_update->cluster_type = XdsApi::CdsUpdate::ClusterType::AGGREGATE; + // Retrieve aggregate clusters. 
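+ // For example, a custom cluster along the lines of (proto text, roughly)
+ //   cluster_type {
+ //     name: "envoy.clusters.aggregate"
+ //     typed_config {  # aggregate.v3.ClusterConfig
+ //       clusters: "cluster-primary"
+ //       clusters: "cluster-fallback"
+ //     }
+ //   }
+ // yields prioritized_cluster_names = {"cluster-primary", "cluster-fallback"},
+ // in the order listed; the cluster names here are made up.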
+ const google_protobuf_Any* typed_config = + envoy_config_cluster_v3_Cluster_CustomClusterType_typed_config( + custom_cluster_type); + const upb_strview aggregate_cluster_config_upb_strview = + google_protobuf_Any_value(typed_config); + const envoy_extensions_clusters_aggregate_v3_ClusterConfig* + aggregate_cluster_config = + envoy_extensions_clusters_aggregate_v3_ClusterConfig_parse( + aggregate_cluster_config_upb_strview.data, + aggregate_cluster_config_upb_strview.size, context.arena); + if (aggregate_cluster_config == nullptr) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Can't parse aggregate cluster.")); + } else { + size_t size; + const upb_strview* clusters = + envoy_extensions_clusters_aggregate_v3_ClusterConfig_clusters( + aggregate_cluster_config, &size); + for (size_t i = 0; i < size; ++i) { + const upb_strview cluster = clusters[i]; + cds_update->prioritized_cluster_names.emplace_back( + UpbStringToStdString(cluster)); + } + } + } + } + } + // Check the LB policy. + if (envoy_config_cluster_v3_Cluster_lb_policy(cluster) == + envoy_config_cluster_v3_Cluster_ROUND_ROBIN) { + cds_update->lb_policy = "ROUND_ROBIN"; + } else if (envoy_config_cluster_v3_Cluster_lb_policy(cluster) == + envoy_config_cluster_v3_Cluster_RING_HASH) { + cds_update->lb_policy = "RING_HASH"; + // Record ring hash lb config + auto* ring_hash_config = + envoy_config_cluster_v3_Cluster_ring_hash_lb_config(cluster); + if (ring_hash_config != nullptr) { + const google_protobuf_UInt64Value* max_ring_size = + envoy_config_cluster_v3_Cluster_RingHashLbConfig_maximum_ring_size( + ring_hash_config); + if (max_ring_size != nullptr) { + cds_update->max_ring_size = + google_protobuf_UInt64Value_value(max_ring_size); + if (cds_update->max_ring_size > 8388608 || + cds_update->max_ring_size == 0) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "max_ring_size is not in the range of 1 to 8388608.")); + } + } + const google_protobuf_UInt64Value* min_ring_size = + envoy_config_cluster_v3_Cluster_RingHashLbConfig_minimum_ring_size( + ring_hash_config); + if (min_ring_size != nullptr) { + cds_update->min_ring_size = + google_protobuf_UInt64Value_value(min_ring_size); + if (cds_update->min_ring_size > 8388608 || + cds_update->min_ring_size == 0) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "min_ring_size is not in the range of 1 to 8388608.")); + } + if (cds_update->min_ring_size > cds_update->max_ring_size) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "min_ring_size cannot be greater than max_ring_size.")); + } + } + if (envoy_config_cluster_v3_Cluster_RingHashLbConfig_hash_function( + ring_hash_config) != + envoy_config_cluster_v3_Cluster_RingHashLbConfig_XX_HASH) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "ring hash lb config has invalid hash function.")); + } + } + } else { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("LB policy is not supported.")); + } + auto* transport_socket = + envoy_config_cluster_v3_Cluster_transport_socket(cluster); + if (transport_socket != nullptr) { + grpc_error_handle error = UpstreamTlsContextParse( + context, transport_socket, &cds_update->common_tls_context); + if (error != GRPC_ERROR_NONE) { + errors.push_back( + grpc_error_add_child(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Error parsing security configuration"), + error)); + } + } + // Record LRS server name (if any). 
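+  // Only the self ConfigSource is supported; the empty string recorded below
+  // means "report load to the same server that sent this CDS resource".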
+ const envoy_config_core_v3_ConfigSource* lrs_server = + envoy_config_cluster_v3_Cluster_lrs_server(cluster); + if (lrs_server != nullptr) { + if (!envoy_config_core_v3_ConfigSource_has_self(lrs_server)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + ": LRS ConfigSource is not self.")); + } + cds_update->lrs_load_reporting_server_name.emplace(""); + } + // The Cluster resource encodes the circuit breaking parameters in a list of + // Thresholds messages, where each message specifies the parameters for a + // particular RoutingPriority. we will look only at the first entry in the + // list for priority DEFAULT and default to 1024 if not found. + if (envoy_config_cluster_v3_Cluster_has_circuit_breakers(cluster)) { + const envoy_config_cluster_v3_CircuitBreakers* circuit_breakers = + envoy_config_cluster_v3_Cluster_circuit_breakers(cluster); + size_t num_thresholds; + const envoy_config_cluster_v3_CircuitBreakers_Thresholds* const* + thresholds = envoy_config_cluster_v3_CircuitBreakers_thresholds( + circuit_breakers, &num_thresholds); + for (size_t i = 0; i < num_thresholds; ++i) { + const auto* threshold = thresholds[i]; + if (envoy_config_cluster_v3_CircuitBreakers_Thresholds_priority( + threshold) == envoy_config_core_v3_DEFAULT) { + const google_protobuf_UInt32Value* max_requests = + envoy_config_cluster_v3_CircuitBreakers_Thresholds_max_requests( + threshold); + if (max_requests != nullptr) { + cds_update->max_concurrent_requests = + google_protobuf_UInt32Value_value(max_requests); + } + break; + } + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing CDS resource", &errors); +} + +grpc_error_handle ServerAddressParseAndAppend( + const envoy_config_endpoint_v3_LbEndpoint* lb_endpoint, + ServerAddressList* list) { + // If health_status is not HEALTHY or UNKNOWN, skip this endpoint. + const int32_t health_status = + envoy_config_endpoint_v3_LbEndpoint_health_status(lb_endpoint); + if (health_status != envoy_config_core_v3_UNKNOWN && + health_status != envoy_config_core_v3_HEALTHY) { + return GRPC_ERROR_NONE; + } + // Find the ip:port. + const envoy_config_endpoint_v3_Endpoint* endpoint = + envoy_config_endpoint_v3_LbEndpoint_endpoint(lb_endpoint); + const envoy_config_core_v3_Address* address = + envoy_config_endpoint_v3_Endpoint_address(endpoint); + const envoy_config_core_v3_SocketAddress* socket_address = + envoy_config_core_v3_Address_socket_address(address); + std::string address_str = UpbStringToStdString( + envoy_config_core_v3_SocketAddress_address(socket_address)); + uint32_t port = envoy_config_core_v3_SocketAddress_port_value(socket_address); + if (GPR_UNLIKELY(port >> 16) != 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Invalid port."); + } + // Find load_balancing_weight for the endpoint. + const google_protobuf_UInt32Value* load_balancing_weight = + envoy_config_endpoint_v3_LbEndpoint_load_balancing_weight(lb_endpoint); + const int32_t weight = + load_balancing_weight != nullptr + ? google_protobuf_UInt32Value_value(load_balancing_weight) + : 500; + if (weight == 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid endpoint weight of 0."); + } + // Populate grpc_resolved_address. + grpc_resolved_address addr; + grpc_error_handle error = + grpc_string_to_sockaddr(&addr, address_str.c_str(), port); + if (error != GRPC_ERROR_NONE) return error; + // Append the address to the list. 
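+  // The endpoint weight is attached to the address as a ServerAddress
+  // attribute so the LB policies that later consume this list can read it.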
+  std::map<const char*, std::unique_ptr<ServerAddress::AttributeInterface>>
+      attributes;
+  attributes[ServerAddressWeightAttribute::kServerAddressWeightAttributeKey] =
+      absl::make_unique<ServerAddressWeightAttribute>(weight);
+  list->emplace_back(addr, nullptr, std::move(attributes));
+  return GRPC_ERROR_NONE;
+}
+
+grpc_error_handle LocalityParse(
+    const envoy_config_endpoint_v3_LocalityLbEndpoints* locality_lb_endpoints,
+    XdsApi::EdsUpdate::Priority::Locality* output_locality, size_t* priority) {
+  // Parse LB weight.
+  const google_protobuf_UInt32Value* lb_weight =
+      envoy_config_endpoint_v3_LocalityLbEndpoints_load_balancing_weight(
+          locality_lb_endpoints);
+  // If LB weight is not specified, it means this locality is assigned no load.
+  // TODO(juanlishen): When we support CDS to configure the inter-locality
+  // policy, we should change the LB weight handling.
+  output_locality->lb_weight =
+      lb_weight != nullptr ? google_protobuf_UInt32Value_value(lb_weight) : 0;
+  if (output_locality->lb_weight == 0) return GRPC_ERROR_NONE;
+  // Parse locality name.
+  const envoy_config_core_v3_Locality* locality =
+      envoy_config_endpoint_v3_LocalityLbEndpoints_locality(
+          locality_lb_endpoints);
+  if (locality == nullptr) {
+    return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Empty locality.");
+  }
+  std::string region =
+      UpbStringToStdString(envoy_config_core_v3_Locality_region(locality));
+  std::string zone =
+      UpbStringToStdString(envoy_config_core_v3_Locality_zone(locality));
+  std::string sub_zone =
+      UpbStringToStdString(envoy_config_core_v3_Locality_sub_zone(locality));
+  output_locality->name = MakeRefCounted<XdsLocalityName>(
+      std::move(region), std::move(zone), std::move(sub_zone));
+  // Parse the addresses.
+  size_t size;
+  const envoy_config_endpoint_v3_LbEndpoint* const* lb_endpoints =
+      envoy_config_endpoint_v3_LocalityLbEndpoints_lb_endpoints(
+          locality_lb_endpoints, &size);
+  for (size_t i = 0; i < size; ++i) {
+    grpc_error_handle error = ServerAddressParseAndAppend(
+        lb_endpoints[i], &output_locality->endpoints);
+    if (error != GRPC_ERROR_NONE) return error;
+  }
+  // Parse the priority.
+  *priority = envoy_config_endpoint_v3_LocalityLbEndpoints_priority(
+      locality_lb_endpoints);
+  return GRPC_ERROR_NONE;
+}
+
+grpc_error_handle DropParseAndAppend(
+    const envoy_config_endpoint_v3_ClusterLoadAssignment_Policy_DropOverload*
+        drop_overload,
+    XdsApi::EdsUpdate::DropConfig* drop_config) {
+  // Get the category.
+  std::string category = UpbStringToStdString(
+      envoy_config_endpoint_v3_ClusterLoadAssignment_Policy_DropOverload_category(
+          drop_overload));
+  if (category.empty()) {
+    return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Empty drop category name");
+  }
+  // Get the drop rate (per million).
+  const envoy_type_v3_FractionalPercent* drop_percentage =
+      envoy_config_endpoint_v3_ClusterLoadAssignment_Policy_DropOverload_drop_percentage(
+          drop_overload);
+  uint32_t numerator =
+      envoy_type_v3_FractionalPercent_numerator(drop_percentage);
+  const auto denominator =
+      static_cast<envoy_type_v3_FractionalPercent_DenominatorType>(
+          envoy_type_v3_FractionalPercent_denominator(drop_percentage));
+  // Normalize to million.
+  switch (denominator) {
+    case envoy_type_v3_FractionalPercent_HUNDRED:
+      numerator *= 10000;
+      break;
+    case envoy_type_v3_FractionalPercent_TEN_THOUSAND:
+      numerator *= 100;
+      break;
+    case envoy_type_v3_FractionalPercent_MILLION:
+      break;
+    default:
+      return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Unknown denominator type");
+  }
+  // Cap numerator to 1000000.
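+  // (e.g. {numerator: 2, denominator: HUNDRED} has already been scaled to
+  // 20000 parts per million at this point, i.e. a 2% drop rate.)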
+ numerator = std::min(numerator, 1000000u); + drop_config->AddCategory(std::move(category), numerator); + return GRPC_ERROR_NONE; +} + +grpc_error_handle EdsResourceParse( + const EncodingContext& /*context*/, + const envoy_config_endpoint_v3_ClusterLoadAssignment* + cluster_load_assignment, + bool /*is_v2*/, XdsApi::EdsUpdate* eds_update) { + std::vector errors; + // Get the endpoints. + size_t locality_size; + const envoy_config_endpoint_v3_LocalityLbEndpoints* const* endpoints = + envoy_config_endpoint_v3_ClusterLoadAssignment_endpoints( + cluster_load_assignment, &locality_size); + for (size_t j = 0; j < locality_size; ++j) { + size_t priority; + XdsApi::EdsUpdate::Priority::Locality locality; + grpc_error_handle error = LocalityParse(endpoints[j], &locality, &priority); + if (error != GRPC_ERROR_NONE) { + errors.push_back(error); + continue; + } + // Filter out locality with weight 0. + if (locality.lb_weight == 0) continue; + // Make sure prorities is big enough. Note that they might not + // arrive in priority order. + while (eds_update->priorities.size() < priority + 1) { + eds_update->priorities.emplace_back(); + } + eds_update->priorities[priority].localities.emplace(locality.name.get(), + std::move(locality)); + } + for (const auto& priority : eds_update->priorities) { + if (priority.localities.empty()) { + errors.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("sparse priority list")); + } + } + // Get the drop config. + eds_update->drop_config = MakeRefCounted(); + const envoy_config_endpoint_v3_ClusterLoadAssignment_Policy* policy = + envoy_config_endpoint_v3_ClusterLoadAssignment_policy( + cluster_load_assignment); + if (policy != nullptr) { + size_t drop_size; + const envoy_config_endpoint_v3_ClusterLoadAssignment_Policy_DropOverload* const* + drop_overload = + envoy_config_endpoint_v3_ClusterLoadAssignment_Policy_drop_overloads( + policy, &drop_size); + for (size_t j = 0; j < drop_size; ++j) { + grpc_error_handle error = + DropParseAndAppend(drop_overload[j], eds_update->drop_config.get()); + if (error != GRPC_ERROR_NONE) { + errors.push_back( + grpc_error_add_child(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "drop config validation error"), + error)); + } + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing EDS resource", &errors); +} + +template +grpc_error_handle AdsResponseParse( + const EncodingContext& context, ProtoParseFunction proto_parse_function, + ProtoResourceNameFunction proto_resource_name_function, + ResourceTypeSelectorFunction resource_type_selector_function, + ProtoLogFunction proto_log_function, + ResourceParseFunction resource_parse_function, + const envoy_service_discovery_v3_DiscoveryResponse* response, + const char* resource_type_string, + const std::set& expected_resource_names, + UpdateMap* update_map, std::set* resource_names_failed) { + std::vector errors; + // Get the resources from the response. + size_t size; + const google_protobuf_Any* const* resources = + envoy_service_discovery_v3_DiscoveryResponse_resources(response, &size); + for (size_t i = 0; i < size; ++i) { + // Check the type_url of the resource. + absl::string_view type_url = + UpbStringToAbsl(google_protobuf_Any_type_url(resources[i])); + bool is_v2 = false; + if (!resource_type_selector_function(type_url, &is_v2)) { + errors.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("resource index ", i, ": Resource is not ", + resource_type_string, "."))); + continue; + } + // Parse the resource. 
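+    // The Any's value field holds the serialized resource proto; hand it to
+    // the per-type upb parse function supplied by the caller.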
+ upb_strview serialized_resource = google_protobuf_Any_value(resources[i]); + auto* resource = proto_parse_function( + serialized_resource.data, serialized_resource.size, context.arena); + if (resource == nullptr) { + errors.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("resource index ", i, ": Can't parse ", + resource_type_string, " resource."))); + continue; + } + proto_log_function(context, resource); + // Check the resource name. Ignore unexpected names. + std::string resource_name = + UpbStringToStdString(proto_resource_name_function(resource)); + if (expected_resource_names.find(resource_name) == + expected_resource_names.end()) { + continue; + } + // Fail on duplicate resources. + if (update_map->find(resource_name) != update_map->end()) { + errors.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("duplicate resource name \"", resource_name, "\""))); + resource_names_failed->insert(resource_name); + continue; + } + // Validate resource. + decltype(UpdateMap::mapped_type::resource) update; + grpc_error_handle error = + resource_parse_function(context, resource, is_v2, &update); + if (error != GRPC_ERROR_NONE) { + errors.push_back( + grpc_error_add_child(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + resource_name, ": validation error")), + error)); + resource_names_failed->insert(resource_name); + } else { + // Store result in update map, in both validated and serialized form. + auto& resource_data = (*update_map)[resource_name]; + resource_data.resource = std::move(update); + resource_data.serialized_proto = + UpbStringToStdString(serialized_resource); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing ADS response", &errors); +} + +std::string TypeUrlInternalToExternal(absl::string_view type_url) { + if (type_url == kLdsV2TypeUrl) { + return XdsApi::kLdsTypeUrl; + } else if (type_url == kRdsV2TypeUrl) { + return XdsApi::kRdsTypeUrl; + } else if (type_url == kCdsV2TypeUrl) { + return XdsApi::kCdsTypeUrl; + } else if (type_url == kEdsV2TypeUrl) { + return XdsApi::kEdsTypeUrl; + } + return std::string(type_url); +} + +upb_strview LdsResourceName( + const envoy_config_listener_v3_Listener* lds_resource) { + return envoy_config_listener_v3_Listener_name(lds_resource); +} + +upb_strview RdsResourceName( + const envoy_config_route_v3_RouteConfiguration* rds_resource) { + return envoy_config_route_v3_RouteConfiguration_name(rds_resource); +} + +upb_strview CdsResourceName( + const envoy_config_cluster_v3_Cluster* cds_resource) { + return envoy_config_cluster_v3_Cluster_name(cds_resource); +} + +upb_strview EdsResourceName( + const envoy_config_endpoint_v3_ClusterLoadAssignment* eds_resource) { + return envoy_config_endpoint_v3_ClusterLoadAssignment_cluster_name( + eds_resource); +} + +} // namespace + +XdsApi::AdsParseResult XdsApi::ParseAdsResponse( + const XdsBootstrap::XdsServer& server, const grpc_slice& encoded_response, + const std::set& expected_listener_names, + const std::set& expected_route_configuration_names, + const std::set& expected_cluster_names, + const std::set& expected_eds_service_names) { + AdsParseResult result; + upb::Arena arena; + const EncodingContext context = {client_, + tracer_, + symtab_.ptr(), + arena.ptr(), + server.ShouldUseV3(), + certificate_provider_definition_map_}; + // Decode the response. 
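+  // The slice contains a serialized
+  // envoy.service.discovery.v3.DiscoveryResponse.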
+ const envoy_service_discovery_v3_DiscoveryResponse* response = + envoy_service_discovery_v3_DiscoveryResponse_parse( + reinterpret_cast(GRPC_SLICE_START_PTR(encoded_response)), + GRPC_SLICE_LENGTH(encoded_response), arena.ptr()); + // If decoding fails, output an empty type_url and return. + if (response == nullptr) { + result.parse_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Can't decode DiscoveryResponse."); + return result; + } + MaybeLogDiscoveryResponse(context, response); + // Record the type_url, the version_info, and the nonce of the response. + result.type_url = TypeUrlInternalToExternal(UpbStringToAbsl( + envoy_service_discovery_v3_DiscoveryResponse_type_url(response))); + result.version = UpbStringToStdString( + envoy_service_discovery_v3_DiscoveryResponse_version_info(response)); + result.nonce = UpbStringToStdString( + envoy_service_discovery_v3_DiscoveryResponse_nonce(response)); + // Parse the response according to the resource type. + // TODO(roth): When we have time, consider defining an interface for the + // methods of each resource type, so that we don't have to pass + // individual functions into each call to AdsResponseParse(). + if (IsLds(result.type_url)) { + result.parse_error = AdsResponseParse( + context, envoy_config_listener_v3_Listener_parse, LdsResourceName, + IsLds, MaybeLogListener, LdsResourceParse, response, "LDS", + expected_listener_names, &result.lds_update_map, + &result.resource_names_failed); + } else if (IsRds(result.type_url)) { + result.parse_error = AdsResponseParse( + context, envoy_config_route_v3_RouteConfiguration_parse, + RdsResourceName, IsRds, MaybeLogRouteConfiguration, RouteConfigParse, + response, "RDS", expected_route_configuration_names, + &result.rds_update_map, &result.resource_names_failed); + } else if (IsCds(result.type_url)) { + result.parse_error = AdsResponseParse( + context, envoy_config_cluster_v3_Cluster_parse, CdsResourceName, IsCds, + MaybeLogCluster, CdsResourceParse, response, "CDS", + expected_cluster_names, &result.cds_update_map, + &result.resource_names_failed); + } else if (IsEds(result.type_url)) { + result.parse_error = AdsResponseParse( + context, envoy_config_endpoint_v3_ClusterLoadAssignment_parse, + EdsResourceName, IsEds, MaybeLogClusterLoadAssignment, EdsResourceParse, + response, "EDS", expected_eds_service_names, &result.eds_update_map, + &result.resource_names_failed); + } + return result; +} + +namespace { + +void MaybeLogLrsRequest( + const EncodingContext& context, + const envoy_service_load_stats_v3_LoadStatsRequest* request) { + if (GRPC_TRACE_FLAG_ENABLED(*context.tracer) && + gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + const upb_msgdef* msg_type = + envoy_service_load_stats_v3_LoadStatsRequest_getmsgdef(context.symtab); + char buf[10240]; + upb_text_encode(request, msg_type, nullptr, 0, buf, sizeof(buf)); + gpr_log(GPR_DEBUG, "[xds_client %p] constructed LRS request: %s", + context.client, buf); + } +} + +grpc_slice SerializeLrsRequest( + const EncodingContext& context, + const envoy_service_load_stats_v3_LoadStatsRequest* request) { + size_t output_length; + char* output = envoy_service_load_stats_v3_LoadStatsRequest_serialize( + request, context.arena, &output_length); + return grpc_slice_from_copied_buffer(output, output_length); +} + +} // namespace + +grpc_slice XdsApi::CreateLrsInitialRequest( + const XdsBootstrap::XdsServer& server) { + upb::Arena arena; + const EncodingContext context = {client_, + tracer_, + symtab_.ptr(), + arena.ptr(), + server.ShouldUseV3(), + 
certificate_provider_definition_map_}; + // Create a request. + envoy_service_load_stats_v3_LoadStatsRequest* request = + envoy_service_load_stats_v3_LoadStatsRequest_new(arena.ptr()); + // Populate node. + envoy_config_core_v3_Node* node_msg = + envoy_service_load_stats_v3_LoadStatsRequest_mutable_node(request, + arena.ptr()); + PopulateNode(context, node_, build_version_, user_agent_name_, + user_agent_version_, node_msg); + envoy_config_core_v3_Node_add_client_features( + node_msg, upb_strview_makez("envoy.lrs.supports_send_all_clusters"), + arena.ptr()); + MaybeLogLrsRequest(context, request); + return SerializeLrsRequest(context, request); +} + +namespace { + +void LocalityStatsPopulate( + const EncodingContext& context, + envoy_config_endpoint_v3_UpstreamLocalityStats* output, + const XdsLocalityName& locality_name, + const XdsClusterLocalityStats::Snapshot& snapshot) { + // Set locality. + envoy_config_core_v3_Locality* locality = + envoy_config_endpoint_v3_UpstreamLocalityStats_mutable_locality( + output, context.arena); + if (!locality_name.region().empty()) { + envoy_config_core_v3_Locality_set_region( + locality, StdStringToUpbString(locality_name.region())); + } + if (!locality_name.zone().empty()) { + envoy_config_core_v3_Locality_set_zone( + locality, StdStringToUpbString(locality_name.zone())); + } + if (!locality_name.sub_zone().empty()) { + envoy_config_core_v3_Locality_set_sub_zone( + locality, StdStringToUpbString(locality_name.sub_zone())); + } + // Set total counts. + envoy_config_endpoint_v3_UpstreamLocalityStats_set_total_successful_requests( + output, snapshot.total_successful_requests); + envoy_config_endpoint_v3_UpstreamLocalityStats_set_total_requests_in_progress( + output, snapshot.total_requests_in_progress); + envoy_config_endpoint_v3_UpstreamLocalityStats_set_total_error_requests( + output, snapshot.total_error_requests); + envoy_config_endpoint_v3_UpstreamLocalityStats_set_total_issued_requests( + output, snapshot.total_issued_requests); + // Add backend metrics. + for (const auto& p : snapshot.backend_metrics) { + const std::string& metric_name = p.first; + const XdsClusterLocalityStats::BackendMetric& metric_value = p.second; + envoy_config_endpoint_v3_EndpointLoadMetricStats* load_metric = + envoy_config_endpoint_v3_UpstreamLocalityStats_add_load_metric_stats( + output, context.arena); + envoy_config_endpoint_v3_EndpointLoadMetricStats_set_metric_name( + load_metric, StdStringToUpbString(metric_name)); + envoy_config_endpoint_v3_EndpointLoadMetricStats_set_num_requests_finished_with_metric( + load_metric, metric_value.num_requests_finished_with_metric); + envoy_config_endpoint_v3_EndpointLoadMetricStats_set_total_metric_value( + load_metric, metric_value.total_metric_value); + } +} + +} // namespace + +grpc_slice XdsApi::CreateLrsRequest( + ClusterLoadReportMap cluster_load_report_map) { + upb::Arena arena; + const EncodingContext context = { + client_, tracer_, symtab_.ptr(), + arena.ptr(), false, certificate_provider_definition_map_}; + // Create a request. + envoy_service_load_stats_v3_LoadStatsRequest* request = + envoy_service_load_stats_v3_LoadStatsRequest_new(arena.ptr()); + for (auto& p : cluster_load_report_map) { + const std::string& cluster_name = p.first.first; + const std::string& eds_service_name = p.first.second; + const ClusterLoadReport& load_report = p.second; + // Add cluster stats. 
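+    // One ClusterStats entry is emitted per (cluster_name, eds_service_name)
+    // pair in the load report map.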
+ envoy_config_endpoint_v3_ClusterStats* cluster_stats = + envoy_service_load_stats_v3_LoadStatsRequest_add_cluster_stats( + request, arena.ptr()); + // Set the cluster name. + envoy_config_endpoint_v3_ClusterStats_set_cluster_name( + cluster_stats, StdStringToUpbString(cluster_name)); + // Set EDS service name, if non-empty. + if (!eds_service_name.empty()) { + envoy_config_endpoint_v3_ClusterStats_set_cluster_service_name( + cluster_stats, StdStringToUpbString(eds_service_name)); + } + // Add locality stats. + for (const auto& p : load_report.locality_stats) { + const XdsLocalityName& locality_name = *p.first; + const auto& snapshot = p.second; + envoy_config_endpoint_v3_UpstreamLocalityStats* locality_stats = + envoy_config_endpoint_v3_ClusterStats_add_upstream_locality_stats( + cluster_stats, arena.ptr()); + LocalityStatsPopulate(context, locality_stats, locality_name, snapshot); + } + // Add dropped requests. + uint64_t total_dropped_requests = 0; + for (const auto& p : load_report.dropped_requests.categorized_drops) { + const std::string& category = p.first; + const uint64_t count = p.second; + envoy_config_endpoint_v3_ClusterStats_DroppedRequests* dropped_requests = + envoy_config_endpoint_v3_ClusterStats_add_dropped_requests( + cluster_stats, arena.ptr()); + envoy_config_endpoint_v3_ClusterStats_DroppedRequests_set_category( + dropped_requests, StdStringToUpbString(category)); + envoy_config_endpoint_v3_ClusterStats_DroppedRequests_set_dropped_count( + dropped_requests, count); + total_dropped_requests += count; + } + total_dropped_requests += load_report.dropped_requests.uncategorized_drops; + // Set total dropped requests. + envoy_config_endpoint_v3_ClusterStats_set_total_dropped_requests( + cluster_stats, total_dropped_requests); + // Set real load report interval. + gpr_timespec timespec = + grpc_millis_to_timespec(load_report.load_report_interval, GPR_TIMESPAN); + google_protobuf_Duration* load_report_interval = + envoy_config_endpoint_v3_ClusterStats_mutable_load_report_interval( + cluster_stats, arena.ptr()); + google_protobuf_Duration_set_seconds(load_report_interval, timespec.tv_sec); + google_protobuf_Duration_set_nanos(load_report_interval, timespec.tv_nsec); + } + MaybeLogLrsRequest(context, request); + return SerializeLrsRequest(context, request); +} + +grpc_error_handle XdsApi::ParseLrsResponse( + const grpc_slice& encoded_response, bool* send_all_clusters, + std::set* cluster_names, + grpc_millis* load_reporting_interval) { + upb::Arena arena; + // Decode the response. + const envoy_service_load_stats_v3_LoadStatsResponse* decoded_response = + envoy_service_load_stats_v3_LoadStatsResponse_parse( + reinterpret_cast(GRPC_SLICE_START_PTR(encoded_response)), + GRPC_SLICE_LENGTH(encoded_response), arena.ptr()); + // Parse the response. + if (decoded_response == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Can't decode response."); + } + // Check send_all_clusters. + if (envoy_service_load_stats_v3_LoadStatsResponse_send_all_clusters( + decoded_response)) { + *send_all_clusters = true; + } else { + // Store the cluster names. + size_t size; + const upb_strview* clusters = + envoy_service_load_stats_v3_LoadStatsResponse_clusters(decoded_response, + &size); + for (size_t i = 0; i < size; ++i) { + cluster_names->emplace(UpbStringToStdString(clusters[i])); + } + } + // Get the load report interval. 
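+  // The interval arrives as a google.protobuf.Duration; convert it to
+  // grpc_millis via gpr_timespec.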
+ const google_protobuf_Duration* load_reporting_interval_duration = + envoy_service_load_stats_v3_LoadStatsResponse_load_reporting_interval( + decoded_response); + gpr_timespec timespec{ + google_protobuf_Duration_seconds(load_reporting_interval_duration), + google_protobuf_Duration_nanos(load_reporting_interval_duration), + GPR_TIMESPAN}; + *load_reporting_interval = gpr_time_to_millis(timespec); + return GRPC_ERROR_NONE; +} + +namespace { +google_protobuf_Timestamp* GrpcMillisToTimestamp(const EncodingContext& context, + grpc_millis value) { + google_protobuf_Timestamp* timestamp = + google_protobuf_Timestamp_new(context.arena); + gpr_timespec timespec = grpc_millis_to_timespec(value, GPR_CLOCK_REALTIME); + google_protobuf_Timestamp_set_seconds(timestamp, timespec.tv_sec); + google_protobuf_Timestamp_set_nanos(timestamp, timespec.tv_nsec); + return timestamp; +} + +envoy_admin_v3_UpdateFailureState* CreateUpdateFailureStateUpb( + const EncodingContext& context, + const XdsApi::ResourceMetadata* resource_metadata) { + auto* update_failure_state = + envoy_admin_v3_UpdateFailureState_new(context.arena); + envoy_admin_v3_UpdateFailureState_set_details( + update_failure_state, + StdStringToUpbString(resource_metadata->failed_details)); + envoy_admin_v3_UpdateFailureState_set_version_info( + update_failure_state, + StdStringToUpbString(resource_metadata->failed_version)); + envoy_admin_v3_UpdateFailureState_set_last_update_attempt( + update_failure_state, + GrpcMillisToTimestamp(context, resource_metadata->failed_update_time)); + return update_failure_state; +} + +void DumpLdsConfig(const EncodingContext& context, + const XdsApi::ResourceTypeMetadata& resource_type_metadata, + envoy_service_status_v3_PerXdsConfig* per_xds_config) { + upb_strview kLdsTypeUrlUpb = upb_strview_makez(XdsApi::kLdsTypeUrl); + auto* listener_config_dump = + envoy_service_status_v3_PerXdsConfig_mutable_listener_config( + per_xds_config, context.arena); + envoy_admin_v3_ListenersConfigDump_set_version_info( + listener_config_dump, + StdStringToUpbString(resource_type_metadata.version)); + for (auto& p : resource_type_metadata.resource_metadata_map) { + absl::string_view name = p.first; + const XdsApi::ResourceMetadata* meta = p.second; + const upb_strview name_upb = StdStringToUpbString(name); + auto* dynamic_listener = + envoy_admin_v3_ListenersConfigDump_add_dynamic_listeners( + listener_config_dump, context.arena); + envoy_admin_v3_ListenersConfigDump_DynamicListener_set_name( + dynamic_listener, name_upb); + envoy_admin_v3_ListenersConfigDump_DynamicListener_set_client_status( + dynamic_listener, meta->client_status); + if (!meta->serialized_proto.empty()) { + // Set in-effective listeners + auto* dynamic_listener_state = + envoy_admin_v3_ListenersConfigDump_DynamicListener_mutable_active_state( + dynamic_listener, context.arena); + envoy_admin_v3_ListenersConfigDump_DynamicListenerState_set_version_info( + dynamic_listener_state, StdStringToUpbString(meta->version)); + envoy_admin_v3_ListenersConfigDump_DynamicListenerState_set_last_updated( + dynamic_listener_state, + GrpcMillisToTimestamp(context, meta->update_time)); + auto* listener_any = + envoy_admin_v3_ListenersConfigDump_DynamicListenerState_mutable_listener( + dynamic_listener_state, context.arena); + google_protobuf_Any_set_type_url(listener_any, kLdsTypeUrlUpb); + google_protobuf_Any_set_value( + listener_any, StdStringToUpbString(meta->serialized_proto)); + } + if (meta->client_status == XdsApi::ResourceMetadata::NACKED) { + // Set error_state if 
NACKED + envoy_admin_v3_ListenersConfigDump_DynamicListener_set_error_state( + dynamic_listener, CreateUpdateFailureStateUpb(context, meta)); + } + } +} + +void DumpRdsConfig(const EncodingContext& context, + const XdsApi::ResourceTypeMetadata& resource_type_metadata, + envoy_service_status_v3_PerXdsConfig* per_xds_config) { + upb_strview kRdsTypeUrlUpb = upb_strview_makez(XdsApi::kRdsTypeUrl); + auto* route_config_dump = + envoy_service_status_v3_PerXdsConfig_mutable_route_config(per_xds_config, + context.arena); + for (auto& p : resource_type_metadata.resource_metadata_map) { + absl::string_view name = p.first; + const XdsApi::ResourceMetadata* meta = p.second; + const upb_strview name_upb = StdStringToUpbString(name); + auto* dynamic_route_config = + envoy_admin_v3_RoutesConfigDump_add_dynamic_route_configs( + route_config_dump, context.arena); + envoy_admin_v3_RoutesConfigDump_DynamicRouteConfig_set_client_status( + dynamic_route_config, meta->client_status); + auto* route_config_any = + envoy_admin_v3_RoutesConfigDump_DynamicRouteConfig_mutable_route_config( + dynamic_route_config, context.arena); + if (!meta->serialized_proto.empty()) { + // Set in-effective route configs + envoy_admin_v3_RoutesConfigDump_DynamicRouteConfig_set_version_info( + dynamic_route_config, StdStringToUpbString(meta->version)); + envoy_admin_v3_RoutesConfigDump_DynamicRouteConfig_set_last_updated( + dynamic_route_config, + GrpcMillisToTimestamp(context, meta->update_time)); + google_protobuf_Any_set_type_url(route_config_any, kRdsTypeUrlUpb); + google_protobuf_Any_set_value( + route_config_any, StdStringToUpbString(meta->serialized_proto)); + } else { + // If there isn't a working route config, we still need to print the + // name. + auto* route_config = + envoy_config_route_v3_RouteConfiguration_new(context.arena); + envoy_config_route_v3_RouteConfiguration_set_name(route_config, name_upb); + size_t length; + char* bytes = envoy_config_route_v3_RouteConfiguration_serialize( + route_config, context.arena, &length); + google_protobuf_Any_set_type_url(route_config_any, kRdsTypeUrlUpb); + google_protobuf_Any_set_value(route_config_any, + upb_strview_make(bytes, length)); + } + if (meta->client_status == XdsApi::ResourceMetadata::NACKED) { + // Set error_state if NACKED + envoy_admin_v3_RoutesConfigDump_DynamicRouteConfig_set_error_state( + dynamic_route_config, CreateUpdateFailureStateUpb(context, meta)); + } + } +} + +void DumpCdsConfig(const EncodingContext& context, + const XdsApi::ResourceTypeMetadata& resource_type_metadata, + envoy_service_status_v3_PerXdsConfig* per_xds_config) { + upb_strview kCdsTypeUrlUpb = upb_strview_makez(XdsApi::kCdsTypeUrl); + auto* cluster_config_dump = + envoy_service_status_v3_PerXdsConfig_mutable_cluster_config( + per_xds_config, context.arena); + envoy_admin_v3_ClustersConfigDump_set_version_info( + cluster_config_dump, + StdStringToUpbString(resource_type_metadata.version)); + for (auto& p : resource_type_metadata.resource_metadata_map) { + absl::string_view name = p.first; + const XdsApi::ResourceMetadata* meta = p.second; + const upb_strview name_upb = StdStringToUpbString(name); + auto* dynamic_cluster = + envoy_admin_v3_ClustersConfigDump_add_dynamic_active_clusters( + cluster_config_dump, context.arena); + envoy_admin_v3_ClustersConfigDump_DynamicCluster_set_client_status( + dynamic_cluster, meta->client_status); + auto* cluster_any = + envoy_admin_v3_ClustersConfigDump_DynamicCluster_mutable_cluster( + dynamic_cluster, context.arena); + if 
(!meta->serialized_proto.empty()) { + // Set in-effective clusters + envoy_admin_v3_ClustersConfigDump_DynamicCluster_set_version_info( + dynamic_cluster, StdStringToUpbString(meta->version)); + envoy_admin_v3_ClustersConfigDump_DynamicCluster_set_last_updated( + dynamic_cluster, GrpcMillisToTimestamp(context, meta->update_time)); + google_protobuf_Any_set_type_url(cluster_any, kCdsTypeUrlUpb); + google_protobuf_Any_set_value( + cluster_any, StdStringToUpbString(meta->serialized_proto)); + } else { + // If there isn't a working cluster, we still need to print the name. + auto* cluster = envoy_config_cluster_v3_Cluster_new(context.arena); + envoy_config_cluster_v3_Cluster_set_name(cluster, name_upb); + size_t length; + char* bytes = envoy_config_cluster_v3_Cluster_serialize( + cluster, context.arena, &length); + google_protobuf_Any_set_type_url(cluster_any, kCdsTypeUrlUpb); + google_protobuf_Any_set_value(cluster_any, + upb_strview_make(bytes, length)); + } + if (meta->client_status == XdsApi::ResourceMetadata::NACKED) { + // Set error_state if NACKED + envoy_admin_v3_ClustersConfigDump_DynamicCluster_set_error_state( + dynamic_cluster, CreateUpdateFailureStateUpb(context, meta)); + } + } +} + +void DumpEdsConfig(const EncodingContext& context, + const XdsApi::ResourceTypeMetadata& resource_type_metadata, + envoy_service_status_v3_PerXdsConfig* per_xds_config) { + upb_strview kEdsTypeUrlUpb = upb_strview_makez(XdsApi::kEdsTypeUrl); + auto* endpoint_config_dump = + envoy_service_status_v3_PerXdsConfig_mutable_endpoint_config( + per_xds_config, context.arena); + for (auto& p : resource_type_metadata.resource_metadata_map) { + absl::string_view name = p.first; + const XdsApi::ResourceMetadata* meta = p.second; + const upb_strview name_upb = StdStringToUpbString(name); + auto* dynamic_endpoint = + envoy_admin_v3_EndpointsConfigDump_add_dynamic_endpoint_configs( + endpoint_config_dump, context.arena); + envoy_admin_v3_EndpointsConfigDump_DynamicEndpointConfig_set_client_status( + dynamic_endpoint, meta->client_status); + auto* endpoint_any = + envoy_admin_v3_EndpointsConfigDump_DynamicEndpointConfig_mutable_endpoint_config( + dynamic_endpoint, context.arena); + if (!meta->serialized_proto.empty()) { + // Set in-effective endpoints + envoy_admin_v3_EndpointsConfigDump_DynamicEndpointConfig_set_version_info( + dynamic_endpoint, StdStringToUpbString(meta->version)); + envoy_admin_v3_EndpointsConfigDump_DynamicEndpointConfig_set_last_updated( + dynamic_endpoint, GrpcMillisToTimestamp(context, meta->update_time)); + google_protobuf_Any_set_type_url(endpoint_any, kEdsTypeUrlUpb); + google_protobuf_Any_set_value( + endpoint_any, StdStringToUpbString(meta->serialized_proto)); + } else { + // If there isn't a working endpoint, we still need to print the name. 
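+      // Serialize a bare ClusterLoadAssignment carrying only the cluster name
+      // so the config dump still lists the resource.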
+ auto* cluster_load_assignment = + envoy_config_endpoint_v3_ClusterLoadAssignment_new(context.arena); + envoy_config_endpoint_v3_ClusterLoadAssignment_set_cluster_name( + cluster_load_assignment, name_upb); + size_t length; + char* bytes = envoy_config_endpoint_v3_ClusterLoadAssignment_serialize( + cluster_load_assignment, context.arena, &length); + google_protobuf_Any_set_type_url(endpoint_any, kEdsTypeUrlUpb); + google_protobuf_Any_set_value(endpoint_any, + upb_strview_make(bytes, length)); + } + if (meta->client_status == XdsApi::ResourceMetadata::NACKED) { + // Set error_state if NACKED + envoy_admin_v3_EndpointsConfigDump_DynamicEndpointConfig_set_error_state( + dynamic_endpoint, CreateUpdateFailureStateUpb(context, meta)); + } + } +} + +} // namespace + +std::string XdsApi::AssembleClientConfig( + const ResourceTypeMetadataMap& resource_type_metadata_map) { + upb::Arena arena; + // Create the ClientConfig for resource metadata from XdsClient + auto* client_config = envoy_service_status_v3_ClientConfig_new(arena.ptr()); + // Fill-in the node information + auto* node = envoy_service_status_v3_ClientConfig_mutable_node(client_config, + arena.ptr()); + const EncodingContext context = { + client_, tracer_, symtab_.ptr(), + arena.ptr(), true, certificate_provider_definition_map_}; + PopulateNode(context, node_, build_version_, user_agent_name_, + user_agent_version_, node); + // Dump each xDS-type config into PerXdsConfig + for (auto& p : resource_type_metadata_map) { + absl::string_view type_url = p.first; + const ResourceTypeMetadata& resource_type_metadata = p.second; + if (type_url == kLdsTypeUrl) { + auto* per_xds_config = + envoy_service_status_v3_ClientConfig_add_xds_config(client_config, + context.arena); + DumpLdsConfig(context, resource_type_metadata, per_xds_config); + } else if (type_url == kRdsTypeUrl) { + auto* per_xds_config = + envoy_service_status_v3_ClientConfig_add_xds_config(client_config, + context.arena); + DumpRdsConfig(context, resource_type_metadata, per_xds_config); + } else if (type_url == kCdsTypeUrl) { + auto* per_xds_config = + envoy_service_status_v3_ClientConfig_add_xds_config(client_config, + context.arena); + DumpCdsConfig(context, resource_type_metadata, per_xds_config); + } else if (type_url == kEdsTypeUrl) { + auto* per_xds_config = + envoy_service_status_v3_ClientConfig_add_xds_config(client_config, + context.arena); + DumpEdsConfig(context, resource_type_metadata, per_xds_config); + } else { + gpr_log(GPR_ERROR, "invalid type_url %s", std::string(type_url).c_str()); + return ""; + } + } + // Serialize the upb message to bytes + size_t output_length; + char* output = envoy_service_status_v3_ClientConfig_serialize( + client_config, arena.ptr(), &output_length); + return std::string(output, output_length); +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_bootstrap.cc b/src/core/ext/xds/xds_bootstrap.cc new file mode 100644 index 00000000..1a7a82cd --- /dev/null +++ b/src/core/ext/xds/xds_bootstrap.cc @@ -0,0 +1,471 @@ +// +// Copyright 2019 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/xds/xds_bootstrap.h" + +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" + +#include "src/core/ext/xds/certificate_provider_registry.h" +#include "src/core/ext/xds/xds_api.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +// +// XdsChannelCredsRegistry +// + +bool XdsChannelCredsRegistry::IsSupported(const std::string& creds_type) { + return creds_type == "google_default" || creds_type == "insecure" || + creds_type == "fake"; +} + +bool XdsChannelCredsRegistry::IsValidConfig(const std::string& /*creds_type*/, + const Json& /*config*/) { + // Currently, none of the creds types actually take a config, but we + // ignore whatever might be specified in the bootstrap file for + // forward compatibility reasons. + return true; +} + +RefCountedPtr +XdsChannelCredsRegistry::MakeChannelCreds(const std::string& creds_type, + const Json& /*config*/) { + if (creds_type == "google_default") { + return grpc_google_default_credentials_create(nullptr); + } else if (creds_type == "insecure") { + return grpc_insecure_credentials_create(); + } else if (creds_type == "fake") { + return grpc_fake_transport_security_credentials_create(); + } + return nullptr; +} + +// +// XdsBootstrap::XdsServer +// + +bool XdsBootstrap::XdsServer::ShouldUseV3() const { + return server_features.find("xds_v3") != server_features.end(); +} + +// +// XdsBootstrap +// + +std::unique_ptr XdsBootstrap::Create( + absl::string_view json_string, grpc_error_handle* error) { + Json json = Json::Parse(json_string, error); + if (*error != GRPC_ERROR_NONE) { + grpc_error_handle error_out = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed to parse bootstrap JSON string", error, 1); + GRPC_ERROR_UNREF(*error); + *error = error_out; + return nullptr; + } + return absl::make_unique(std::move(json), error); +} + +XdsBootstrap::XdsBootstrap(Json json, grpc_error_handle* error) { + if (json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "malformed JSON in bootstrap file"); + return; + } + std::vector error_list; + auto it = json.mutable_object()->find("xds_servers"); + if (it == json.mutable_object()->end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"xds_servers\" field not present")); + } else if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"xds_servers\" field is not an array")); + } else { + grpc_error_handle parse_error = ParseXdsServerList(&it->second); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + it = json.mutable_object()->find("node"); + if (it != json.mutable_object()->end()) { + if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"node\" field is not an object")); + } else { + grpc_error_handle parse_error = ParseNode(&it->second); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + } + it = json.mutable_object()->find("server_listener_resource_name_template"); + if (it != 
json.mutable_object()->end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"server_listener_resource_name_template\" field is not a string")); + } else { + server_listener_resource_name_template_ = + std::move(*it->second.mutable_string_value()); + } + } + it = json.mutable_object()->find("certificate_providers"); + if (it != json.mutable_object()->end()) { + if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"certificate_providers\" field is not an object")); + } else { + grpc_error_handle parse_error = ParseCertificateProviders(&it->second); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + } + *error = GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing xds bootstrap file", + &error_list); +} + +grpc_error_handle XdsBootstrap::ParseXdsServerList(Json* json) { + std::vector error_list; + for (size_t i = 0; i < json->mutable_array()->size(); ++i) { + Json& child = json->mutable_array()->at(i); + if (child.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("array element ", i, " is not an object"))); + } else { + grpc_error_handle parse_error = ParseXdsServer(&child, i); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing \"xds_servers\" array", + &error_list); +} + +grpc_error_handle XdsBootstrap::ParseXdsServer(Json* json, size_t idx) { + std::vector error_list; + servers_.emplace_back(); + XdsServer& server = servers_[servers_.size() - 1]; + auto it = json->mutable_object()->find("server_uri"); + if (it == json->mutable_object()->end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"server_uri\" field not present")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"server_uri\" field is not a string")); + } else { + server.server_uri = std::move(*it->second.mutable_string_value()); + } + it = json->mutable_object()->find("channel_creds"); + if (it == json->mutable_object()->end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"channel_creds\" field not present")); + } else if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"channel_creds\" field is not an array")); + } else { + grpc_error_handle parse_error = + ParseChannelCredsArray(&it->second, &server); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + it = json->mutable_object()->find("server_features"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::ARRAY) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"server_features\" field is not an array")); + } else { + grpc_error_handle parse_error = + ParseServerFeaturesArray(&it->second, &server); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("errors parsing index ", idx), &error_list); +} + +grpc_error_handle XdsBootstrap::ParseChannelCredsArray(Json* json, + XdsServer* server) { + std::vector error_list; + for (size_t i = 0; i < json->mutable_array()->size(); ++i) { + Json& child = json->mutable_array()->at(i); + if (child.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("array element 
", i, " is not an object"))); + } else { + grpc_error_handle parse_error = ParseChannelCreds(&child, i, server); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + } + if (server->channel_creds_type.empty()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "no known creds type found in \"channel_creds\"")); + } + return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing \"channel_creds\" array", + &error_list); +} + +grpc_error_handle XdsBootstrap::ParseChannelCreds(Json* json, size_t idx, + XdsServer* server) { + std::vector error_list; + std::string type; + auto it = json->mutable_object()->find("type"); + if (it == json->mutable_object()->end()) { + error_list.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("\"type\" field not present")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("\"type\" field is not a string")); + } else { + type = std::move(*it->second.mutable_string_value()); + } + Json config; + it = json->mutable_object()->find("config"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"config\" field is not an object")); + } else { + config = std::move(it->second); + } + } + // Select the first channel creds type that we support. + if (server->channel_creds_type.empty() && + XdsChannelCredsRegistry::IsSupported(type)) { + if (!XdsChannelCredsRegistry::IsValidConfig(type, config)) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "invalid config for channel creds type \"", type, "\""))); + } + server->channel_creds_type = std::move(type); + server->channel_creds_config = std::move(config); + } + return GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("errors parsing index ", idx), &error_list); +} + +grpc_error_handle XdsBootstrap::ParseServerFeaturesArray(Json* json, + XdsServer* server) { + std::vector error_list; + for (size_t i = 0; i < json->mutable_array()->size(); ++i) { + Json& child = json->mutable_array()->at(i); + if (child.type() == Json::Type::STRING && + child.string_value() == "xds_v3") { + server->server_features.insert(std::move(*child.mutable_string_value())); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR( + "errors parsing \"server_features\" array", &error_list); +} + +grpc_error_handle XdsBootstrap::ParseNode(Json* json) { + std::vector error_list; + node_ = absl::make_unique(); + auto it = json->mutable_object()->find("id"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("\"id\" field is not a string")); + } else { + node_->id = std::move(*it->second.mutable_string_value()); + } + } + it = json->mutable_object()->find("cluster"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"cluster\" field is not a string")); + } else { + node_->cluster = std::move(*it->second.mutable_string_value()); + } + } + it = json->mutable_object()->find("locality"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"locality\" field is not an object")); + } else { + grpc_error_handle parse_error = ParseLocality(&it->second); + if (parse_error != GRPC_ERROR_NONE) 
error_list.push_back(parse_error); + } + } + it = json->mutable_object()->find("metadata"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"metadata\" field is not an object")); + } else { + node_->metadata = std::move(it->second); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing \"node\" object", + &error_list); +} + +grpc_error_handle XdsBootstrap::ParseLocality(Json* json) { + std::vector error_list; + auto it = json->mutable_object()->find("region"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"region\" field is not a string")); + } else { + node_->locality_region = std::move(*it->second.mutable_string_value()); + } + } + it = json->mutable_object()->find("zone"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"zone\" field is not a string")); + } else { + node_->locality_zone = std::move(*it->second.mutable_string_value()); + } + } + it = json->mutable_object()->find("sub_zone"); + if (it != json->mutable_object()->end()) { + if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"sub_zone\" field is not a string")); + } else { + node_->locality_sub_zone = std::move(*it->second.mutable_string_value()); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing \"locality\" object", + &error_list); +} + +grpc_error_handle XdsBootstrap::ParseCertificateProviders(Json* json) { + std::vector error_list; + for (auto& certificate_provider : *(json->mutable_object())) { + if (certificate_provider.second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "element \"", certificate_provider.first, "\" is not an object"))); + } else { + grpc_error_handle parse_error = ParseCertificateProvider( + certificate_provider.first, &certificate_provider.second); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR( + "errors parsing \"certificate_providers\" object", &error_list); +} + +grpc_error_handle XdsBootstrap::ParseCertificateProvider( + const std::string& instance_name, Json* certificate_provider_json) { + std::vector error_list; + auto it = certificate_provider_json->mutable_object()->find("plugin_name"); + if (it == certificate_provider_json->mutable_object()->end()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"plugin_name\" field not present")); + } else if (it->second.type() != Json::Type::STRING) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"plugin_name\" field is not a string")); + } else { + std::string plugin_name = std::move(*(it->second.mutable_string_value())); + CertificateProviderFactory* factory = + CertificateProviderRegistry::LookupCertificateProviderFactory( + plugin_name); + if (factory == nullptr) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unrecognized plugin name: ", plugin_name))); + } else { + RefCountedPtr config; + it = certificate_provider_json->mutable_object()->find("config"); + if (it != certificate_provider_json->mutable_object()->end()) { + if (it->second.type() != Json::Type::OBJECT) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "\"config\" 
field is not an object")); + } else { + grpc_error_handle parse_error = GRPC_ERROR_NONE; + config = factory->CreateCertificateProviderConfig(it->second, + &parse_error); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + } else { + // "config" is an optional field, so create an empty JSON object. + grpc_error_handle parse_error = GRPC_ERROR_NONE; + config = factory->CreateCertificateProviderConfig(Json::Object(), + &parse_error); + if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error); + } + certificate_providers_.insert( + {instance_name, {std::move(plugin_name), std::move(config)}}); + } + } + return GRPC_ERROR_CREATE_FROM_VECTOR_AND_CPP_STRING( + absl::StrCat("errors parsing element \"", instance_name, "\""), + &error_list); +} + +std::string XdsBootstrap::ToString() const { + std::vector parts; + if (node_ != nullptr) { + parts.push_back(absl::StrFormat( + "node={\n" + " id=\"%s\",\n" + " cluster=\"%s\",\n" + " locality={\n" + " region=\"%s\",\n" + " zone=\"%s\",\n" + " sub_zone=\"%s\"\n" + " },\n" + " metadata=%s,\n" + "},\n", + node_->id, node_->cluster, node_->locality_region, node_->locality_zone, + node_->locality_sub_zone, node_->metadata.Dump())); + } + parts.push_back( + absl::StrFormat("servers=[\n" + " {\n" + " uri=\"%s\",\n" + " creds_type=%s,\n", + server().server_uri, server().channel_creds_type)); + if (server().channel_creds_config.type() != Json::Type::JSON_NULL) { + parts.push_back(absl::StrFormat(" creds_config=%s,", + server().channel_creds_config.Dump())); + } + if (!server().server_features.empty()) { + parts.push_back(absl::StrCat(" server_features=[", + absl::StrJoin(server().server_features, ", "), + "],\n")); + } + parts.push_back(" }\n],\n"); + if (!server_listener_resource_name_template_.empty()) { + parts.push_back( + absl::StrFormat("server_listener_resource_name_template=\"%s\",\n", + server_listener_resource_name_template_)); + } + parts.push_back("certificate_providers={\n"); + for (const auto& entry : certificate_providers_) { + parts.push_back( + absl::StrFormat(" %s={\n" + " plugin_name=%s\n" + " config=%s\n" + " },\n", + entry.first, entry.second.plugin_name, + entry.second.config->ToString())); + } + parts.push_back("}"); + return absl::StrJoin(parts, ""); +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_certificate_provider.cc b/src/core/ext/xds/xds_certificate_provider.cc new file mode 100644 index 00000000..1b9ca75a --- /dev/null +++ b/src/core/ext/xds/xds_certificate_provider.cc @@ -0,0 +1,405 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include "src/core/ext/xds/xds_certificate_provider.h" + +#include "absl/functional/bind_front.h" +#include "absl/strings/str_cat.h" + +namespace grpc_core { + +namespace { + +class RootCertificatesWatcher + : public grpc_tls_certificate_distributor::TlsCertificatesWatcherInterface { + public: + // Takes a ref to \a parent instead of a raw pointer since the watcher is + // owned by the root certificate distributor and not by \a parent. Note that + // presently, the watcher is immediately deleted when + // CancelTlsCertificatesWatch() is called, but that can potentially change in + // the future. + RootCertificatesWatcher( + RefCountedPtr parent, + std::string cert_name) + : parent_(std::move(parent)), cert_name_(std::move(cert_name)) {} + + void OnCertificatesChanged(absl::optional root_certs, + absl::optional + /* key_cert_pairs */) override { + if (root_certs.has_value()) { + parent_->SetKeyMaterials(cert_name_, std::string(root_certs.value()), + absl::nullopt); + } + } + + void OnError(grpc_error_handle root_cert_error, + grpc_error_handle identity_cert_error) override { + if (root_cert_error != GRPC_ERROR_NONE) { + parent_->SetErrorForCert(cert_name_, root_cert_error /* pass the ref */, + absl::nullopt); + } + GRPC_ERROR_UNREF(identity_cert_error); + } + + private: + RefCountedPtr parent_; + std::string cert_name_; +}; + +class IdentityCertificatesWatcher + : public grpc_tls_certificate_distributor::TlsCertificatesWatcherInterface { + public: + // Takes a ref to \a parent instead of a raw pointer since the watcher is + // owned by the root certificate distributor and not by \a parent. Note that + // presently, the watcher is immediately deleted when + // CancelTlsCertificatesWatch() is called, but that can potentially change in + // the future. 
+ IdentityCertificatesWatcher( + RefCountedPtr parent, + std::string cert_name) + : parent_(std::move(parent)), cert_name_(std::move(cert_name)) {} + + void OnCertificatesChanged( + absl::optional /* root_certs */, + absl::optional key_cert_pairs) override { + if (key_cert_pairs.has_value()) { + parent_->SetKeyMaterials(cert_name_, absl::nullopt, key_cert_pairs); + } + } + + void OnError(grpc_error_handle root_cert_error, + grpc_error_handle identity_cert_error) override { + if (identity_cert_error != GRPC_ERROR_NONE) { + parent_->SetErrorForCert(cert_name_, absl::nullopt, + identity_cert_error /* pass the ref */); + } + GRPC_ERROR_UNREF(root_cert_error); + } + + private: + RefCountedPtr parent_; + std::string cert_name_; +}; + +} // namespace + +// +// XdsCertificateProvider::ClusterCertificateState +// + +XdsCertificateProvider::ClusterCertificateState::~ClusterCertificateState() { + if (root_cert_watcher_ != nullptr) { + root_cert_distributor_->CancelTlsCertificatesWatch(root_cert_watcher_); + } + if (identity_cert_watcher_ != nullptr) { + identity_cert_distributor_->CancelTlsCertificatesWatch( + identity_cert_watcher_); + } +} + +bool XdsCertificateProvider::ClusterCertificateState::IsSafeToRemove() const { + return !watching_root_certs_ && !watching_identity_certs_ && + root_cert_distributor_ == nullptr && + identity_cert_distributor_ == nullptr; +} + +void XdsCertificateProvider::ClusterCertificateState:: + UpdateRootCertNameAndDistributor( + const std::string& cert_name, absl::string_view root_cert_name, + RefCountedPtr root_cert_distributor) { + if (root_cert_name_ == root_cert_name && + root_cert_distributor_ == root_cert_distributor) { + return; + } + root_cert_name_ = std::string(root_cert_name); + if (watching_root_certs_) { + // The root certificates are being watched. Swap out the watcher. + if (root_cert_distributor_ != nullptr) { + root_cert_distributor_->CancelTlsCertificatesWatch(root_cert_watcher_); + } + if (root_cert_distributor != nullptr) { + UpdateRootCertWatcher(cert_name, root_cert_distributor.get()); + } else { + root_cert_watcher_ = nullptr; + xds_certificate_provider_->distributor_->SetErrorForCert( + "", + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No certificate provider available for root certificates"), + absl::nullopt); + } + } + // Swap out the root certificate distributor + root_cert_distributor_ = std::move(root_cert_distributor); +} + +void XdsCertificateProvider::ClusterCertificateState:: + UpdateIdentityCertNameAndDistributor( + const std::string& cert_name, absl::string_view identity_cert_name, + RefCountedPtr + identity_cert_distributor) { + if (identity_cert_name_ == identity_cert_name && + identity_cert_distributor_ == identity_cert_distributor) { + return; + } + identity_cert_name_ = std::string(identity_cert_name); + if (watching_identity_certs_) { + // The identity certificates are being watched. Swap out the watcher. 
+ if (identity_cert_distributor_ != nullptr) { + identity_cert_distributor_->CancelTlsCertificatesWatch( + identity_cert_watcher_); + } + if (identity_cert_distributor != nullptr) { + UpdateIdentityCertWatcher(cert_name, identity_cert_distributor.get()); + } else { + identity_cert_watcher_ = nullptr; + xds_certificate_provider_->distributor_->SetErrorForCert( + "", absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No certificate provider available for identity certificates")); + } + } + // Swap out the identity certificate distributor + identity_cert_distributor_ = std::move(identity_cert_distributor); +} + +void XdsCertificateProvider::ClusterCertificateState::UpdateRootCertWatcher( + const std::string& cert_name, + grpc_tls_certificate_distributor* root_cert_distributor) { + auto watcher = absl::make_unique( + xds_certificate_provider_->distributor_, cert_name); + root_cert_watcher_ = watcher.get(); + root_cert_distributor->WatchTlsCertificates(std::move(watcher), + root_cert_name_, absl::nullopt); +} + +void XdsCertificateProvider::ClusterCertificateState::UpdateIdentityCertWatcher( + const std::string& cert_name, + grpc_tls_certificate_distributor* identity_cert_distributor) { + auto watcher = absl::make_unique( + xds_certificate_provider_->distributor_, cert_name); + identity_cert_watcher_ = watcher.get(); + identity_cert_distributor->WatchTlsCertificates( + std::move(watcher), absl::nullopt, identity_cert_name_); +} + +void XdsCertificateProvider::ClusterCertificateState::WatchStatusCallback( + const std::string& cert_name, bool root_being_watched, + bool identity_being_watched) { + // We aren't specially handling the case where root_cert_distributor is same + // as identity_cert_distributor. Always using two separate watchers + // irrespective of the fact results in a straightforward design, and using a + // single watcher does not seem to provide any benefit other than cutting down + // on the number of callbacks. + if (root_being_watched && !watching_root_certs_) { + // We need to start watching root certs. + watching_root_certs_ = true; + if (root_cert_distributor_ == nullptr) { + xds_certificate_provider_->distributor_->SetErrorForCert( + cert_name, + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No certificate provider available for root certificates"), + absl::nullopt); + } else { + UpdateRootCertWatcher(cert_name, root_cert_distributor_.get()); + } + } else if (!root_being_watched && watching_root_certs_) { + // We need to cancel root certs watch. 
+ watching_root_certs_ = false; + if (root_cert_distributor_ != nullptr) { + root_cert_distributor_->CancelTlsCertificatesWatch(root_cert_watcher_); + root_cert_watcher_ = nullptr; + } + GPR_ASSERT(root_cert_watcher_ == nullptr); + } + if (identity_being_watched && !watching_identity_certs_) { + watching_identity_certs_ = true; + if (identity_cert_distributor_ == nullptr) { + xds_certificate_provider_->distributor_->SetErrorForCert( + cert_name, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No certificate provider available for identity certificates")); + } else { + UpdateIdentityCertWatcher(cert_name, identity_cert_distributor_.get()); + } + } else if (!identity_being_watched && watching_identity_certs_) { + watching_identity_certs_ = false; + if (identity_cert_distributor_ != nullptr) { + identity_cert_distributor_->CancelTlsCertificatesWatch( + identity_cert_watcher_); + identity_cert_watcher_ = nullptr; + } + GPR_ASSERT(identity_cert_watcher_ == nullptr); + } +} + +// +// XdsCertificateProvider +// + +XdsCertificateProvider::XdsCertificateProvider() + : distributor_(MakeRefCounted()) { + distributor_->SetWatchStatusCallback( + absl::bind_front(&XdsCertificateProvider::WatchStatusCallback, this)); +} + +XdsCertificateProvider::~XdsCertificateProvider() { + distributor_->SetWatchStatusCallback(nullptr); +} + +bool XdsCertificateProvider::ProvidesRootCerts(const std::string& cert_name) { + MutexLock lock(&mu_); + auto it = certificate_state_map_.find(cert_name); + if (it == certificate_state_map_.end()) return false; + return it->second->ProvidesRootCerts(); +} + +void XdsCertificateProvider::UpdateRootCertNameAndDistributor( + const std::string& cert_name, absl::string_view root_cert_name, + RefCountedPtr root_cert_distributor) { + MutexLock lock(&mu_); + auto it = certificate_state_map_.find(cert_name); + if (it == certificate_state_map_.end()) { + it = certificate_state_map_ + .emplace(cert_name, + absl::make_unique(this)) + .first; + } + it->second->UpdateRootCertNameAndDistributor(cert_name, root_cert_name, + root_cert_distributor); + // Delete unused entries. + if (it->second->IsSafeToRemove()) certificate_state_map_.erase(it); +} + +bool XdsCertificateProvider::ProvidesIdentityCerts( + const std::string& cert_name) { + MutexLock lock(&mu_); + auto it = certificate_state_map_.find(cert_name); + if (it == certificate_state_map_.end()) return false; + return it->second->ProvidesIdentityCerts(); +} + +void XdsCertificateProvider::UpdateIdentityCertNameAndDistributor( + const std::string& cert_name, absl::string_view identity_cert_name, + RefCountedPtr identity_cert_distributor) { + MutexLock lock(&mu_); + auto it = certificate_state_map_.find(cert_name); + if (it == certificate_state_map_.end()) { + it = certificate_state_map_ + .emplace(cert_name, + absl::make_unique(this)) + .first; + } + it->second->UpdateIdentityCertNameAndDistributor( + cert_name, identity_cert_name, identity_cert_distributor); + // Delete unused entries. 
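+// A rough usage sketch for orientation, assuming this file's includes and
+// hypothetical variable names (xds_provider, cluster_distributor, base_args)
+// that are not taken from this patch: a consumer points a cluster's cert
+// names at an underlying provider's distributor and then ships the
+// XdsCertificateProvider itself through channel args via
+// MakeChannelArg()/GetFromChannelArgs(), defined further below.
+//
+//   auto xds_provider = MakeRefCounted<XdsCertificateProvider>();
+//   xds_provider->UpdateRootCertNameAndDistributor(
+//       "cluster_A", "example_root_cert_name", cluster_distributor);
+//   xds_provider->UpdateIdentityCertNameAndDistributor(
+//       "cluster_A", "example_identity_cert_name", cluster_distributor);
+//   grpc_arg arg = xds_provider->MakeChannelArg();
+//   grpc_channel_args* args =
+//       grpc_channel_args_copy_and_add(base_args, &arg, 1);
+//   // ... later, on the consumer side:
+//   RefCountedPtr<XdsCertificateProvider> from_args =
+//       XdsCertificateProvider::GetFromChannelArgs(args);
+//   grpc_channel_args_destroy(args);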
+ if (it->second->IsSafeToRemove()) certificate_state_map_.erase(it); +} + +bool XdsCertificateProvider::GetRequireClientCertificate( + const std::string& cert_name) { + MutexLock lock(&mu_); + auto it = certificate_state_map_.find(cert_name); + if (it == certificate_state_map_.end()) return false; + return it->second->require_client_certificate(); +} + +void XdsCertificateProvider::UpdateRequireClientCertificate( + const std::string& cert_name, bool require_client_certificate) { + MutexLock lock(&mu_); + auto it = certificate_state_map_.find(cert_name); + if (it == certificate_state_map_.end()) return; + it->second->set_require_client_certificate(require_client_certificate); +} + +std::vector XdsCertificateProvider::GetSanMatchers( + const std::string& cluster) { + MutexLock lock(&san_matchers_mu_); + auto it = san_matcher_map_.find(cluster); + if (it == san_matcher_map_.end()) return {}; + return it->second; +} + +void XdsCertificateProvider::UpdateSubjectAlternativeNameMatchers( + const std::string& cluster, std::vector matchers) { + MutexLock lock(&san_matchers_mu_); + if (matchers.empty()) { + san_matcher_map_.erase(cluster); + } else { + san_matcher_map_[cluster] = std::move(matchers); + } +} + +void XdsCertificateProvider::WatchStatusCallback(std::string cert_name, + bool root_being_watched, + bool identity_being_watched) { + MutexLock lock(&mu_); + auto it = certificate_state_map_.find(cert_name); + if (it == certificate_state_map_.end()) { + it = certificate_state_map_ + .emplace(cert_name, + absl::make_unique(this)) + .first; + } + it->second->WatchStatusCallback(cert_name, root_being_watched, + identity_being_watched); + // Delete unused entries. + if (it->second->IsSafeToRemove()) certificate_state_map_.erase(it); +} + +namespace { + +void* XdsCertificateProviderArgCopy(void* p) { + XdsCertificateProvider* xds_certificate_provider = + static_cast(p); + return xds_certificate_provider->Ref().release(); +} + +void XdsCertificateProviderArgDestroy(void* p) { + XdsCertificateProvider* xds_certificate_provider = + static_cast(p); + xds_certificate_provider->Unref(); +} + +int XdsCertificateProviderArgCmp(void* p, void* q) { + return QsortCompare(p, q); +} + +const grpc_arg_pointer_vtable kChannelArgVtable = { + XdsCertificateProviderArgCopy, XdsCertificateProviderArgDestroy, + XdsCertificateProviderArgCmp}; + +} // namespace + +grpc_arg XdsCertificateProvider::MakeChannelArg() const { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_XDS_CERTIFICATE_PROVIDER), + const_cast(this), &kChannelArgVtable); +} + +RefCountedPtr +XdsCertificateProvider::GetFromChannelArgs(const grpc_channel_args* args) { + XdsCertificateProvider* xds_certificate_provider = + grpc_channel_args_find_pointer( + args, GRPC_ARG_XDS_CERTIFICATE_PROVIDER); + return xds_certificate_provider != nullptr ? xds_certificate_provider->Ref() + : nullptr; +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_channel_stack_modifier.cc b/src/core/ext/xds/xds_channel_stack_modifier.cc new file mode 100644 index 00000000..3fd5c586 --- /dev/null +++ b/src/core/ext/xds/xds_channel_stack_modifier.cc @@ -0,0 +1,113 @@ +// +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/ext/xds/xds_channel_stack_modifier.h" + +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/channel_init.h" + +namespace grpc_core { +namespace { + +void* XdsChannelStackModifierArgCopy(void* p) { + XdsChannelStackModifier* arg = static_cast(p); + return arg->Ref().release(); +} + +void XdsChannelStackModifierArgDestroy(void* p) { + XdsChannelStackModifier* arg = static_cast(p); + arg->Unref(); +} + +int XdsChannelStackModifierArgCmp(void* p, void* q) { + return QsortCompare(p, q); +} + +const grpc_arg_pointer_vtable kChannelArgVtable = { + XdsChannelStackModifierArgCopy, XdsChannelStackModifierArgDestroy, + XdsChannelStackModifierArgCmp}; + +const char* kXdsChannelStackModifierChannelArgName = + "grpc.internal.xds_channel_stack_modifier"; + +} // namespace + +bool XdsChannelStackModifier::ModifyChannelStack( + grpc_channel_stack_builder* builder) { + // Insert the filters after the census filter if present. + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_first(builder); + while (grpc_channel_stack_builder_move_next(it)) { + if (grpc_channel_stack_builder_iterator_is_end(it)) break; + const char* filter_name_at_it = + grpc_channel_stack_builder_iterator_filter_name(it); + if (strcmp("census_server", filter_name_at_it) == 0 || + strcmp("opencensus_server", filter_name_at_it) == 0) { + break; + } + } + if (grpc_channel_stack_builder_iterator_is_end(it)) { + // No census filter found. Reset iterator to the beginning. This will result + // in prepending the list of xDS HTTP filters to the current stack. Note + // that this stage is run before the stage that adds the top server filter, + // resulting in these filters being finally placed after the `server` + // filter. + grpc_channel_stack_builder_iterator_destroy(it); + it = grpc_channel_stack_builder_create_iterator_at_first(builder); + } + GPR_ASSERT(grpc_channel_stack_builder_move_next(it)); + for (const grpc_channel_filter* filter : filters_) { + GPR_ASSERT(grpc_channel_stack_builder_add_filter_before(it, filter, nullptr, + nullptr)); + } + grpc_channel_stack_builder_iterator_destroy(it); + return true; +} + +grpc_arg XdsChannelStackModifier::MakeChannelArg() const { + return grpc_channel_arg_pointer_create( + const_cast(kXdsChannelStackModifierChannelArgName), + const_cast(this), &kChannelArgVtable); +} + +RefCountedPtr +XdsChannelStackModifier::GetFromChannelArgs(const grpc_channel_args& args) { + XdsChannelStackModifier* config_selector_provider = + grpc_channel_args_find_pointer( + &args, kXdsChannelStackModifierChannelArgName); + return config_selector_provider != nullptr ? 
config_selector_provider->Ref() + : nullptr; +} + +void RegisterXdsChannelStackModifier(CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage( + GRPC_SERVER_CHANNEL, INT_MAX, [](grpc_channel_stack_builder* builder) { + grpc_core::RefCountedPtr + channel_stack_modifier = + XdsChannelStackModifier::GetFromChannelArgs( + *grpc_channel_stack_builder_get_channel_arguments(builder)); + if (channel_stack_modifier != nullptr) { + return channel_stack_modifier->ModifyChannelStack(builder); + } + return true; + }); +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_client.cc b/src/core/ext/xds/xds_client.cc new file mode 100644 index 00000000..7b167c8a --- /dev/null +++ b/src/core/ext/xds/xds_client.cc @@ -0,0 +1,2555 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/ext/xds/xds_client.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/xds/xds_api.h" +#include "src/core/ext/xds/xds_bootstrap.h" +#include "src/core/ext/xds/xds_channel_args.h" +#include "src/core/ext/xds/xds_client_stats.h" +#include "src/core/ext/xds/xds_http_filters.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/static_metadata.h" + +#define GRPC_XDS_INITIAL_CONNECT_BACKOFF_SECONDS 1 +#define GRPC_XDS_RECONNECT_BACKOFF_MULTIPLIER 1.6 +#define GRPC_XDS_RECONNECT_MAX_BACKOFF_SECONDS 120 +#define GRPC_XDS_RECONNECT_JITTER 0.2 +#define GRPC_XDS_MIN_CLIENT_LOAD_REPORTING_INTERVAL_MS 1000 + +namespace grpc_core { + +TraceFlag grpc_xds_client_trace(false, "xds_client"); +TraceFlag grpc_xds_client_refcount_trace(false, "xds_client_refcount"); + +namespace { + +Mutex* g_mu = nullptr; +const grpc_channel_args* g_channel_args ABSL_GUARDED_BY(*g_mu) = nullptr; +XdsClient* g_xds_client ABSL_GUARDED_BY(*g_mu) = nullptr; +char* g_fallback_bootstrap_config ABSL_GUARDED_BY(*g_mu) = nullptr; + +} // namespace + +// +// Internal class declarations +// + +// An xds call wrapper that can restart a call upon failure. Holds a ref to +// the xds channel. 
The template parameter is the kind of wrapped xds call. +template +class XdsClient::ChannelState::RetryableCall + : public InternallyRefCounted> { + public: + explicit RetryableCall(RefCountedPtr chand); + + void Orphan() override; + + void OnCallFinishedLocked(); + + T* calld() const { return calld_.get(); } + ChannelState* chand() const { return chand_.get(); } + + bool IsCurrentCallOnChannel() const; + + private: + void StartNewCallLocked(); + void StartRetryTimerLocked(); + static void OnRetryTimer(void* arg, grpc_error_handle error); + void OnRetryTimerLocked(grpc_error_handle error); + + // The wrapped xds call that talks to the xds server. It's instantiated + // every time we start a new call. It's null during call retry backoff. + OrphanablePtr calld_; + // The owning xds channel. + RefCountedPtr chand_; + + // Retry state. + BackOff backoff_; + grpc_timer retry_timer_; + grpc_closure on_retry_timer_; + bool retry_timer_callback_pending_ = false; + + bool shutting_down_ = false; +}; + +// Contains an ADS call to the xds server. +class XdsClient::ChannelState::AdsCallState + : public InternallyRefCounted { + public: + // The ctor and dtor should not be used directly. + explicit AdsCallState(RefCountedPtr> parent); + ~AdsCallState() override; + + void Orphan() override; + + RetryableCall* parent() const { return parent_.get(); } + ChannelState* chand() const { return parent_->chand(); } + XdsClient* xds_client() const { return chand()->xds_client(); } + bool seen_response() const { return seen_response_; } + + void SubscribeLocked(const std::string& type_url, const std::string& name) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + void UnsubscribeLocked(const std::string& type_url, const std::string& name, + bool delay_unsubscription) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + + bool HasSubscribedResources() const; + + private: + class ResourceState : public InternallyRefCounted { + public: + ResourceState(const std::string& type_url, const std::string& name, + bool sent_initial_request) + : type_url_(type_url), + name_(name), + sent_initial_request_(sent_initial_request) { + GRPC_CLOSURE_INIT(&timer_callback_, OnTimer, this, + grpc_schedule_on_exec_ctx); + } + + void Orphan() override { + Finish(); + Unref(DEBUG_LOCATION, "Orphan"); + } + + void Start(RefCountedPtr ads_calld) { + if (sent_initial_request_) return; + sent_initial_request_ = true; + ads_calld_ = std::move(ads_calld); + Ref(DEBUG_LOCATION, "timer").release(); + timer_pending_ = true; + grpc_timer_init( + &timer_, + ExecCtx::Get()->Now() + ads_calld_->xds_client()->request_timeout_, + &timer_callback_); + } + + void Finish() { + if (timer_pending_) { + grpc_timer_cancel(&timer_); + timer_pending_ = false; + } + } + + private: + static void OnTimer(void* arg, grpc_error_handle error) { + ResourceState* self = static_cast(arg); + { + MutexLock lock(&self->ads_calld_->xds_client()->mu_); + self->OnTimerLocked(GRPC_ERROR_REF(error)); + } + self->ads_calld_.reset(); + self->Unref(DEBUG_LOCATION, "timer"); + } + + void OnTimerLocked(grpc_error_handle error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_) { + if (error == GRPC_ERROR_NONE && timer_pending_) { + timer_pending_ = false; + grpc_error_handle watcher_error = + GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "timeout obtaining resource {type=%s name=%s} from xds server", + type_url_, name_)); + watcher_error = grpc_error_set_int( + watcher_error, GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE); + if 
(GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] %s", ads_calld_->xds_client(), + grpc_error_std_string(watcher_error).c_str()); + } + if (type_url_ == XdsApi::kLdsTypeUrl) { + ListenerState& state = ads_calld_->xds_client()->listener_map_[name_]; + state.meta.client_status = XdsApi::ResourceMetadata::DOES_NOT_EXIST; + for (const auto& p : state.watchers) { + p.first->OnError(GRPC_ERROR_REF(watcher_error)); + } + } else if (type_url_ == XdsApi::kRdsTypeUrl) { + RouteConfigState& state = + ads_calld_->xds_client()->route_config_map_[name_]; + state.meta.client_status = XdsApi::ResourceMetadata::DOES_NOT_EXIST; + for (const auto& p : state.watchers) { + p.first->OnError(GRPC_ERROR_REF(watcher_error)); + } + } else if (type_url_ == XdsApi::kCdsTypeUrl) { + ClusterState& state = ads_calld_->xds_client()->cluster_map_[name_]; + state.meta.client_status = XdsApi::ResourceMetadata::DOES_NOT_EXIST; + for (const auto& p : state.watchers) { + p.first->OnError(GRPC_ERROR_REF(watcher_error)); + } + } else if (type_url_ == XdsApi::kEdsTypeUrl) { + EndpointState& state = ads_calld_->xds_client()->endpoint_map_[name_]; + state.meta.client_status = XdsApi::ResourceMetadata::DOES_NOT_EXIST; + for (const auto& p : state.watchers) { + p.first->OnError(GRPC_ERROR_REF(watcher_error)); + } + } else { + GPR_UNREACHABLE_CODE(return ); + } + GRPC_ERROR_UNREF(watcher_error); + } + GRPC_ERROR_UNREF(error); + } + + const std::string type_url_; + const std::string name_; + + RefCountedPtr ads_calld_; + bool sent_initial_request_; + bool timer_pending_ = false; + grpc_timer timer_; + grpc_closure timer_callback_; + }; + + struct ResourceTypeState { + ~ResourceTypeState() { GRPC_ERROR_UNREF(error); } + + // Nonce and error for this resource type. + std::string nonce; + grpc_error_handle error = GRPC_ERROR_NONE; + + // Subscribed resources of this type. 
+ std::map> + subscribed_resources; + }; + + void SendMessageLocked(const std::string& type_url) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + + void AcceptLdsUpdateLocked(std::string version, grpc_millis update_time, + XdsApi::LdsUpdateMap lds_update_map, + const std::set& resource_names_failed) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + void AcceptRdsUpdateLocked(std::string version, grpc_millis update_time, + XdsApi::RdsUpdateMap rds_update_map) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + void AcceptCdsUpdateLocked(std::string version, grpc_millis update_time, + XdsApi::CdsUpdateMap cds_update_map, + const std::set& resource_names_failed) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + void AcceptEdsUpdateLocked(std::string version, grpc_millis update_time, + XdsApi::EdsUpdateMap eds_update_map) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + + template + void RejectAdsUpdateLocked(grpc_millis update_time, + const XdsApi::AdsParseResult& result, + StateMap* state_map) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + + static void OnRequestSent(void* arg, grpc_error_handle error); + void OnRequestSentLocked(grpc_error_handle error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + static void OnResponseReceived(void* arg, grpc_error_handle error); + bool OnResponseReceivedLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + static void OnStatusReceived(void* arg, grpc_error_handle error); + void OnStatusReceivedLocked(grpc_error_handle error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + + bool IsCurrentCallOnChannel() const; + + std::set ResourceNamesForRequest( + const std::string& type_url); + + // The owning RetryableCall<>. + RefCountedPtr> parent_; + + bool sent_initial_message_ = false; + bool seen_response_ = false; + + // Always non-NULL. + grpc_call* call_; + + // recv_initial_metadata + grpc_metadata_array initial_metadata_recv_; + + // send_message + grpc_byte_buffer* send_message_payload_ = nullptr; + grpc_closure on_request_sent_; + + // recv_message + grpc_byte_buffer* recv_message_payload_ = nullptr; + grpc_closure on_response_received_; + + // recv_trailing_metadata + grpc_metadata_array trailing_metadata_recv_; + grpc_status_code status_code_; + grpc_slice status_details_; + grpc_closure on_status_received_; + + // Resource types for which requests need to be sent. + std::set buffered_requests_; + + // State for each resource type. + std::map state_map_; +}; + +// Contains an LRS call to the xds server. +class XdsClient::ChannelState::LrsCallState + : public InternallyRefCounted { + public: + // The ctor and dtor should not be used directly. + explicit LrsCallState(RefCountedPtr> parent); + ~LrsCallState() override; + + void Orphan() override; + + void MaybeStartReportingLocked(); + + RetryableCall* parent() { return parent_.get(); } + ChannelState* chand() const { return parent_->chand(); } + XdsClient* xds_client() const { return chand()->xds_client(); } + bool seen_response() const { return seen_response_; } + + private: + // Reports client-side load stats according to a fixed interval. 
+ class Reporter : public InternallyRefCounted { + public: + Reporter(RefCountedPtr parent, grpc_millis report_interval) + : parent_(std::move(parent)), report_interval_(report_interval) { + GRPC_CLOSURE_INIT(&on_next_report_timer_, OnNextReportTimer, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_report_done_, OnReportDone, this, + grpc_schedule_on_exec_ctx); + ScheduleNextReportLocked(); + } + + void Orphan() override; + + private: + void ScheduleNextReportLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + static void OnNextReportTimer(void* arg, grpc_error_handle error); + bool OnNextReportTimerLocked(grpc_error_handle error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + bool SendReportLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + static void OnReportDone(void* arg, grpc_error_handle error); + bool OnReportDoneLocked(grpc_error_handle error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + + bool IsCurrentReporterOnCall() const { + return this == parent_->reporter_.get(); + } + XdsClient* xds_client() const { return parent_->xds_client(); } + + // The owning LRS call. + RefCountedPtr parent_; + + // The load reporting state. + const grpc_millis report_interval_; + bool last_report_counters_were_zero_ = false; + bool next_report_timer_callback_pending_ = false; + grpc_timer next_report_timer_; + grpc_closure on_next_report_timer_; + grpc_closure on_report_done_; + }; + + static void OnInitialRequestSent(void* arg, grpc_error_handle error); + void OnInitialRequestSentLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + static void OnResponseReceived(void* arg, grpc_error_handle error); + bool OnResponseReceivedLocked() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + static void OnStatusReceived(void* arg, grpc_error_handle error); + void OnStatusReceivedLocked(grpc_error_handle error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_); + + bool IsCurrentCallOnChannel() const; + + // The owning RetryableCall<>. + RefCountedPtr> parent_; + bool seen_response_ = false; + + // Always non-NULL. + grpc_call* call_; + + // recv_initial_metadata + grpc_metadata_array initial_metadata_recv_; + + // send_message + grpc_byte_buffer* send_message_payload_ = nullptr; + grpc_closure on_initial_request_sent_; + + // recv_message + grpc_byte_buffer* recv_message_payload_ = nullptr; + grpc_closure on_response_received_; + + // recv_trailing_metadata + grpc_metadata_array trailing_metadata_recv_; + grpc_status_code status_code_; + grpc_slice status_details_; + grpc_closure on_status_received_; + + // Load reporting state. + bool send_all_clusters_ = false; + std::set cluster_names_; // Asked for by the LRS server. + grpc_millis load_reporting_interval_ = 0; + OrphanablePtr reporter_; +}; + +// +// XdsClient::ChannelState::StateWatcher +// + +class XdsClient::ChannelState::StateWatcher + : public AsyncConnectivityStateWatcherInterface { + public: + explicit StateWatcher(RefCountedPtr parent) + : parent_(std::move(parent)) {} + + private: + void OnConnectivityStateChange(grpc_connectivity_state new_state, + const absl::Status& status) override { + MutexLock lock(&parent_->xds_client_->mu_); + if (!parent_->shutting_down_ && + new_state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + // In TRANSIENT_FAILURE. Notify all watchers of error. 
+ gpr_log(GPR_INFO, + "[xds_client %p] xds channel in state:TRANSIENT_FAILURE " + "status_message:(%s)", + parent_->xds_client(), status.ToString().c_str()); + parent_->xds_client_->NotifyOnErrorLocked( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "xds channel in TRANSIENT_FAILURE")); + } + } + + RefCountedPtr parent_; +}; + +// +// XdsClient::ChannelState +// + +namespace { + +grpc_channel* CreateXdsChannel(grpc_channel_args* args, + const XdsBootstrap::XdsServer& server) { + RefCountedPtr channel_creds = + XdsChannelCredsRegistry::MakeChannelCreds(server.channel_creds_type, + server.channel_creds_config); + return grpc_secure_channel_create(channel_creds.get(), + server.server_uri.c_str(), args, nullptr); +} + +} // namespace + +XdsClient::ChannelState::ChannelState(WeakRefCountedPtr xds_client, + const XdsBootstrap::XdsServer& server) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_refcount_trace) + ? "ChannelState" + : nullptr), + xds_client_(std::move(xds_client)), + server_(server) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] creating channel to %s", + xds_client_.get(), server.server_uri.c_str()); + } + channel_ = CreateXdsChannel(xds_client_->args_, server); + GPR_ASSERT(channel_ != nullptr); + StartConnectivityWatchLocked(); +} + +XdsClient::ChannelState::~ChannelState() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] Destroying xds channel %p", xds_client(), + this); + } + grpc_channel_destroy(channel_); + xds_client_.reset(DEBUG_LOCATION, "ChannelState"); +} + +void XdsClient::ChannelState::Orphan() { + shutting_down_ = true; + CancelConnectivityWatchLocked(); + ads_calld_.reset(); + lrs_calld_.reset(); + Unref(DEBUG_LOCATION, "ChannelState+orphaned"); +} + +XdsClient::ChannelState::AdsCallState* XdsClient::ChannelState::ads_calld() + const { + return ads_calld_->calld(); +} + +XdsClient::ChannelState::LrsCallState* XdsClient::ChannelState::lrs_calld() + const { + return lrs_calld_->calld(); +} + +bool XdsClient::ChannelState::HasActiveAdsCall() const { + return ads_calld_ != nullptr && ads_calld_->calld() != nullptr; +} + +void XdsClient::ChannelState::MaybeStartLrsCall() { + if (lrs_calld_ != nullptr) return; + lrs_calld_.reset( + new RetryableCall(Ref(DEBUG_LOCATION, "ChannelState+lrs"))); +} + +void XdsClient::ChannelState::StopLrsCall() { lrs_calld_.reset(); } + +void XdsClient::ChannelState::StartConnectivityWatchLocked() { + ClientChannel* client_channel = ClientChannel::GetFromChannel(channel_); + GPR_ASSERT(client_channel != nullptr); + watcher_ = new StateWatcher(Ref(DEBUG_LOCATION, "ChannelState+watch")); + client_channel->AddConnectivityWatcher( + GRPC_CHANNEL_IDLE, + OrphanablePtr(watcher_)); +} + +void XdsClient::ChannelState::CancelConnectivityWatchLocked() { + ClientChannel* client_channel = ClientChannel::GetFromChannel(channel_); + GPR_ASSERT(client_channel != nullptr); + client_channel->RemoveConnectivityWatcher(watcher_); +} + +void XdsClient::ChannelState::SubscribeLocked(const std::string& type_url, + const std::string& name) { + if (ads_calld_ == nullptr) { + // Start the ADS call if this is the first request. + ads_calld_.reset(new RetryableCall( + Ref(DEBUG_LOCATION, "ChannelState+ads"))); + // Note: AdsCallState's ctor will automatically subscribe to all + // resources that the XdsClient already has watchers for, so we can + // return here. 
+ return; + } + // If the ADS call is in backoff state, we don't need to do anything now + // because when the call is restarted it will resend all necessary requests. + if (ads_calld() == nullptr) return; + // Subscribe to this resource if the ADS call is active. + ads_calld()->SubscribeLocked(type_url, name); +} + +void XdsClient::ChannelState::UnsubscribeLocked(const std::string& type_url, + const std::string& name, + bool delay_unsubscription) { + if (ads_calld_ != nullptr) { + auto* calld = ads_calld_->calld(); + if (calld != nullptr) { + calld->UnsubscribeLocked(type_url, name, delay_unsubscription); + if (!calld->HasSubscribedResources()) ads_calld_.reset(); + } + } +} + +// +// XdsClient::ChannelState::RetryableCall<> +// + +template +XdsClient::ChannelState::RetryableCall::RetryableCall( + RefCountedPtr chand) + : chand_(std::move(chand)), + backoff_( + BackOff::Options() + .set_initial_backoff(GRPC_XDS_INITIAL_CONNECT_BACKOFF_SECONDS * + 1000) + .set_multiplier(GRPC_XDS_RECONNECT_BACKOFF_MULTIPLIER) + .set_jitter(GRPC_XDS_RECONNECT_JITTER) + .set_max_backoff(GRPC_XDS_RECONNECT_MAX_BACKOFF_SECONDS * 1000)) { + // Closure Initialization + GRPC_CLOSURE_INIT(&on_retry_timer_, OnRetryTimer, this, + grpc_schedule_on_exec_ctx); + StartNewCallLocked(); +} + +template +void XdsClient::ChannelState::RetryableCall::Orphan() { + shutting_down_ = true; + calld_.reset(); + if (retry_timer_callback_pending_) grpc_timer_cancel(&retry_timer_); + this->Unref(DEBUG_LOCATION, "RetryableCall+orphaned"); +} + +template +void XdsClient::ChannelState::RetryableCall::OnCallFinishedLocked() { + const bool seen_response = calld_->seen_response(); + calld_.reset(); + if (seen_response) { + // If we lost connection to the xds server, reset backoff and restart the + // call immediately. + backoff_.Reset(); + StartNewCallLocked(); + } else { + // If we failed to connect to the xds server, retry later. 
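+// For reference, a sketch rather than additional patch code: with the
+// BackOff options configured in the RetryableCall constructor above, the
+// retry delay grows roughly as 1s, 1.6s, 2.56s, ... up to the 120s cap, each
+// value perturbed by +/-20% jitter, and the sequence is reset whenever the
+// previous call had seen a response. The same options can be built
+// standalone from the constants defined at the top of this file:
+//
+//   BackOff backoff(
+//       BackOff::Options()
+//           .set_initial_backoff(GRPC_XDS_INITIAL_CONNECT_BACKOFF_SECONDS *
+//                                1000)
+//           .set_multiplier(GRPC_XDS_RECONNECT_BACKOFF_MULTIPLIER)
+//           .set_jitter(GRPC_XDS_RECONNECT_JITTER)
+//           .set_max_backoff(GRPC_XDS_RECONNECT_MAX_BACKOFF_SECONDS * 1000));
+//   // NextAttemptTime() returns an absolute grpc_millis deadline, as used by
+//   // StartRetryTimerLocked() below.
+//   grpc_millis next_attempt_time = backoff.NextAttemptTime();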
+ StartRetryTimerLocked(); + } +} + +template +void XdsClient::ChannelState::RetryableCall::StartNewCallLocked() { + if (shutting_down_) return; + GPR_ASSERT(chand_->channel_ != nullptr); + GPR_ASSERT(calld_ == nullptr); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] Start new call from retryable call (chand: %p, " + "retryable call: %p)", + chand()->xds_client(), chand(), this); + } + calld_ = MakeOrphanable( + this->Ref(DEBUG_LOCATION, "RetryableCall+start_new_call")); +} + +template +void XdsClient::ChannelState::RetryableCall::StartRetryTimerLocked() { + if (shutting_down_) return; + const grpc_millis next_attempt_time = backoff_.NextAttemptTime(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + grpc_millis timeout = + std::max(next_attempt_time - ExecCtx::Get()->Now(), grpc_millis(0)); + gpr_log(GPR_INFO, + "[xds_client %p] Failed to connect to xds server (chand: %p) " + "retry timer will fire in %" PRId64 "ms.", + chand()->xds_client(), chand(), timeout); + } + this->Ref(DEBUG_LOCATION, "RetryableCall+retry_timer_start").release(); + grpc_timer_init(&retry_timer_, next_attempt_time, &on_retry_timer_); + retry_timer_callback_pending_ = true; +} + +template +void XdsClient::ChannelState::RetryableCall::OnRetryTimer( + void* arg, grpc_error_handle error) { + RetryableCall* calld = static_cast(arg); + { + MutexLock lock(&calld->chand_->xds_client()->mu_); + calld->OnRetryTimerLocked(GRPC_ERROR_REF(error)); + } + calld->Unref(DEBUG_LOCATION, "RetryableCall+retry_timer_done"); +} + +template +void XdsClient::ChannelState::RetryableCall::OnRetryTimerLocked( + grpc_error_handle error) { + retry_timer_callback_pending_ = false; + if (!shutting_down_ && error == GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log( + GPR_INFO, + "[xds_client %p] Retry timer fires (chand: %p, retryable call: %p)", + chand()->xds_client(), chand(), this); + } + StartNewCallLocked(); + } + GRPC_ERROR_UNREF(error); +} + +// +// XdsClient::ChannelState::AdsCallState +// + +XdsClient::ChannelState::AdsCallState::AdsCallState( + RefCountedPtr> parent) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_refcount_trace) + ? "AdsCallState" + : nullptr), + parent_(std::move(parent)) { + // Init the ADS call. Note that the call will progress every time there's + // activity in xds_client()->interested_parties_, which is comprised of + // the polling entities from client_channel. + GPR_ASSERT(xds_client() != nullptr); + // Create a call with the specified method name. + const auto& method = + chand()->server_.ShouldUseV3() + ? GRPC_MDSTR_SLASH_ENVOY_DOT_SERVICE_DOT_DISCOVERY_DOT_V3_DOT_AGGREGATEDDISCOVERYSERVICE_SLASH_STREAMAGGREGATEDRESOURCES + : GRPC_MDSTR_SLASH_ENVOY_DOT_SERVICE_DOT_DISCOVERY_DOT_V2_DOT_AGGREGATEDDISCOVERYSERVICE_SLASH_STREAMAGGREGATEDRESOURCES; + call_ = grpc_channel_create_pollset_set_call( + chand()->channel_, nullptr, GRPC_PROPAGATE_DEFAULTS, + xds_client()->interested_parties_, method, nullptr, + GRPC_MILLIS_INF_FUTURE, nullptr); + GPR_ASSERT(call_ != nullptr); + // Init data associated with the call. + grpc_metadata_array_init(&initial_metadata_recv_); + grpc_metadata_array_init(&trailing_metadata_recv_); + // Start the call. + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] Starting ADS call (chand: %p, calld: %p, " + "call: %p)", + xds_client(), chand(), this, call_); + } + // Create the ops. 
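+// A clarifying sketch, not additional patch code: the two GRPC_MDSTR_...
+// static-metadata slices selected in the constructor above name the ADS
+// streaming method; spelled out as plain strings (use_v3 standing in for
+// chand()->server_.ShouldUseV3()):
+//
+//   const char* method_path =
+//       use_v3 ? "/envoy.service.discovery.v3.AggregatedDiscoveryService/"
+//                "StreamAggregatedResources"
+//              : "/envoy.service.discovery.v2.AggregatedDiscoveryService/"
+//                "StreamAggregatedResources";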
+ grpc_call_error call_error; + grpc_op ops[3]; + memset(ops, 0, sizeof(ops)); + // Op: send initial metadata. + grpc_op* op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY | + GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET; + op->reserved = nullptr; + op++; + call_error = grpc_call_start_batch_and_execute( + call_, ops, static_cast(op - ops), nullptr); + GPR_ASSERT(GRPC_CALL_OK == call_error); + // Op: send request message. + GRPC_CLOSURE_INIT(&on_request_sent_, OnRequestSent, this, + grpc_schedule_on_exec_ctx); + for (const auto& p : xds_client()->listener_map_) { + SubscribeLocked(XdsApi::kLdsTypeUrl, std::string(p.first)); + } + for (const auto& p : xds_client()->route_config_map_) { + SubscribeLocked(XdsApi::kRdsTypeUrl, std::string(p.first)); + } + for (const auto& p : xds_client()->cluster_map_) { + SubscribeLocked(XdsApi::kCdsTypeUrl, std::string(p.first)); + } + for (const auto& p : xds_client()->endpoint_map_) { + SubscribeLocked(XdsApi::kEdsTypeUrl, std::string(p.first)); + } + // Op: recv initial metadata. + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv_; + op->flags = 0; + op->reserved = nullptr; + op++; + // Op: recv response. + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_message_payload_; + op->flags = 0; + op->reserved = nullptr; + op++; + Ref(DEBUG_LOCATION, "ADS+OnResponseReceivedLocked").release(); + GRPC_CLOSURE_INIT(&on_response_received_, OnResponseReceived, this, + grpc_schedule_on_exec_ctx); + call_error = grpc_call_start_batch_and_execute( + call_, ops, static_cast(op - ops), &on_response_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + // Op: recv server status. + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv_; + op->data.recv_status_on_client.status = &status_code_; + op->data.recv_status_on_client.status_details = &status_details_; + op->flags = 0; + op->reserved = nullptr; + op++; + // This callback signals the end of the call, so it relies on the initial + // ref instead of a new ref. When it's invoked, it's the initial ref that is + // unreffed. + GRPC_CLOSURE_INIT(&on_status_received_, OnStatusReceived, this, + grpc_schedule_on_exec_ctx); + call_error = grpc_call_start_batch_and_execute( + call_, ops, static_cast(op - ops), &on_status_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); +} + +XdsClient::ChannelState::AdsCallState::~AdsCallState() { + grpc_metadata_array_destroy(&initial_metadata_recv_); + grpc_metadata_array_destroy(&trailing_metadata_recv_); + grpc_byte_buffer_destroy(send_message_payload_); + grpc_byte_buffer_destroy(recv_message_payload_); + grpc_slice_unref_internal(status_details_); + GPR_ASSERT(call_ != nullptr); + grpc_call_unref(call_); +} + +void XdsClient::ChannelState::AdsCallState::Orphan() { + GPR_ASSERT(call_ != nullptr); + // If we are here because xds_client wants to cancel the call, + // on_status_received_ will complete the cancellation and clean up. Otherwise, + // we are here because xds_client has to orphan a failed call, then the + // following cancellation will be a no-op. + grpc_call_cancel_internal(call_); + state_map_.clear(); + // Note that the initial ref is hold by on_status_received_. So the + // corresponding unref happens in on_status_received_ instead of here. 
+} + +void XdsClient::ChannelState::AdsCallState::SendMessageLocked( + const std::string& type_url) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&XdsClient::mu_) { + // Buffer message sending if an existing message is in flight. + if (send_message_payload_ != nullptr) { + buffered_requests_.insert(type_url); + return; + } + auto& state = state_map_[type_url]; + grpc_slice request_payload_slice; + std::set resource_names = + ResourceNamesForRequest(type_url); + request_payload_slice = xds_client()->api_.CreateAdsRequest( + chand()->server_, type_url, resource_names, + xds_client()->resource_version_map_[type_url], state.nonce, + GRPC_ERROR_REF(state.error), !sent_initial_message_); + if (type_url != XdsApi::kLdsTypeUrl && type_url != XdsApi::kRdsTypeUrl && + type_url != XdsApi::kCdsTypeUrl && type_url != XdsApi::kEdsTypeUrl) { + state_map_.erase(type_url); + } + sent_initial_message_ = true; + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] sending ADS request: type=%s version=%s nonce=%s " + "error=%s resources=%s", + xds_client(), type_url.c_str(), + xds_client()->resource_version_map_[type_url].c_str(), + state.nonce.c_str(), grpc_error_std_string(state.error).c_str(), + absl::StrJoin(resource_names, " ").c_str()); + } + GRPC_ERROR_UNREF(state.error); + state.error = GRPC_ERROR_NONE; + // Create message payload. + send_message_payload_ = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_slice_unref_internal(request_payload_slice); + // Send the message. + grpc_op op; + memset(&op, 0, sizeof(op)); + op.op = GRPC_OP_SEND_MESSAGE; + op.data.send_message.send_message = send_message_payload_; + Ref(DEBUG_LOCATION, "ADS+OnRequestSentLocked").release(); + GRPC_CLOSURE_INIT(&on_request_sent_, OnRequestSent, this, + grpc_schedule_on_exec_ctx); + grpc_call_error call_error = + grpc_call_start_batch_and_execute(call_, &op, 1, &on_request_sent_); + if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) { + gpr_log(GPR_ERROR, + "[xds_client %p] calld=%p call_error=%d sending ADS message", + xds_client(), this, call_error); + GPR_ASSERT(GRPC_CALL_OK == call_error); + } +} + +void XdsClient::ChannelState::AdsCallState::SubscribeLocked( + const std::string& type_url, const std::string& name) { + auto& state = state_map_[type_url].subscribed_resources[name]; + if (state == nullptr) { + state = MakeOrphanable( + type_url, name, !xds_client()->resource_version_map_[type_url].empty()); + SendMessageLocked(type_url); + } +} + +void XdsClient::ChannelState::AdsCallState::UnsubscribeLocked( + const std::string& type_url, const std::string& name, + bool delay_unsubscription) { + state_map_[type_url].subscribed_resources.erase(name); + if (!delay_unsubscription) SendMessageLocked(type_url); +} + +bool XdsClient::ChannelState::AdsCallState::HasSubscribedResources() const { + for (const auto& p : state_map_) { + if (!p.second.subscribed_resources.empty()) return true; + } + return false; +} + +namespace { + +// Build a resource metadata struct for ADS result accepting methods and CSDS. 
+XdsApi::ResourceMetadata CreateResourceMetadataAcked( + std::string serialized_proto, std::string version, + grpc_millis update_time) { + XdsApi::ResourceMetadata resource_metadata; + resource_metadata.serialized_proto = std::move(serialized_proto); + resource_metadata.update_time = update_time; + resource_metadata.version = std::move(version); + resource_metadata.client_status = XdsApi::ResourceMetadata::ACKED; + return resource_metadata; +} + +} // namespace + +void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdateLocked( + std::string version, grpc_millis update_time, + XdsApi::LdsUpdateMap lds_update_map, + const std::set& resource_names_failed) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] LDS update received containing %" PRIuPTR + " resources", + xds_client(), lds_update_map.size()); + } + auto& lds_state = state_map_[XdsApi::kLdsTypeUrl]; + std::set rds_resource_names_seen; + for (auto& p : lds_update_map) { + const std::string& listener_name = p.first; + XdsApi::LdsUpdate& lds_update = p.second.resource; + auto& state = lds_state.subscribed_resources[listener_name]; + if (state != nullptr) state->Finish(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] LDS resource %s: %s", xds_client(), + listener_name.c_str(), lds_update.ToString().c_str()); + } + // Record the RDS resource names seen. + if (!lds_update.http_connection_manager.route_config_name.empty()) { + rds_resource_names_seen.insert( + lds_update.http_connection_manager.route_config_name); + } + // Ignore identical update. + ListenerState& listener_state = xds_client()->listener_map_[listener_name]; + if (listener_state.update.has_value() && + *listener_state.update == lds_update) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] LDS update for %s identical to current, " + "ignoring.", + xds_client(), listener_name.c_str()); + } + continue; + } + // Update the listener state. + listener_state.update = std::move(lds_update); + listener_state.meta = CreateResourceMetadataAcked( + std::move(p.second.serialized_proto), version, update_time); + // Notify watchers. + for (const auto& p : listener_state.watchers) { + p.first->OnListenerChanged(*listener_state.update); + } + } + // For invalid resources in the update, if they are already in the + // cache, pretend that they are present in the update, so that we + // don't incorrectly consider them deleted below. + for (const std::string& listener_name : resource_names_failed) { + auto it = xds_client()->listener_map_.find(listener_name); + if (it != xds_client()->listener_map_.end()) { + auto& resource = it->second.update; + if (!resource.has_value()) continue; + lds_update_map[listener_name]; + if (!resource->http_connection_manager.route_config_name.empty()) { + rds_resource_names_seen.insert( + resource->http_connection_manager.route_config_name); + } + } + } + // For any subscribed resource that is not present in the update, + // remove it from the cache and notify watchers that it does not exist. 
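+// A consumer-side sketch for orientation only: the watchers notified in the
+// loops here implement the listener watcher interface declared in
+// xds_client.h. That declaration is not part of this diff, so the base-class
+// name below is an assumption; the three callbacks match the ones invoked in
+// this file.
+//
+//   class ExampleListenerWatcher
+//       : public XdsClient::ListenerWatcherInterface {
+//    public:
+//     void OnListenerChanged(XdsApi::LdsUpdate listener) override {
+//       // A new or changed Listener resource was accepted.
+//     }
+//     void OnError(grpc_error_handle error) override {
+//       // Transient problem (e.g. a NACKed update or channel failure);
+//       // cached data, if any, remains usable. The watcher owns the ref.
+//       GRPC_ERROR_UNREF(error);
+//     }
+//     void OnResourceDoesNotExist() override {
+//       // The resource was removed by (or never existed on) the server.
+//     }
+//   };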
+ for (const auto& p : lds_state.subscribed_resources) { + const std::string& listener_name = p.first; + if (lds_update_map.find(listener_name) == lds_update_map.end()) { + ListenerState& listener_state = + xds_client()->listener_map_[listener_name]; + // If the resource was newly requested but has not yet been received, + // we don't want to generate an error for the watchers, because this LDS + // response may be in reaction to an earlier request that did not yet + // request the new resource, so its absence from the response does not + // necessarily indicate that the resource does not exist. + // For that case, we rely on the request timeout instead. + if (!listener_state.update.has_value()) continue; + listener_state.update.reset(); + for (const auto& p : listener_state.watchers) { + p.first->OnResourceDoesNotExist(); + } + } + } + // For any RDS resource that is no longer referred to by any LDS + // resources, remove it from the cache and notify watchers that it + // does not exist. + auto& rds_state = state_map_[XdsApi::kRdsTypeUrl]; + for (const auto& p : rds_state.subscribed_resources) { + const std::string& rds_resource_name = p.first; + if (rds_resource_names_seen.find(rds_resource_name) == + rds_resource_names_seen.end()) { + RouteConfigState& route_config_state = + xds_client()->route_config_map_[rds_resource_name]; + route_config_state.update.reset(); + for (const auto& p : route_config_state.watchers) { + p.first->OnResourceDoesNotExist(); + } + } + } +} + +void XdsClient::ChannelState::AdsCallState::AcceptRdsUpdateLocked( + std::string version, grpc_millis update_time, + XdsApi::RdsUpdateMap rds_update_map) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] RDS update received containing %" PRIuPTR + " resources", + xds_client(), rds_update_map.size()); + } + auto& rds_state = state_map_[XdsApi::kRdsTypeUrl]; + for (auto& p : rds_update_map) { + const std::string& route_config_name = p.first; + XdsApi::RdsUpdate& rds_update = p.second.resource; + auto& state = rds_state.subscribed_resources[route_config_name]; + if (state != nullptr) state->Finish(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] RDS resource:\n%s", xds_client(), + rds_update.ToString().c_str()); + } + RouteConfigState& route_config_state = + xds_client()->route_config_map_[route_config_name]; + // Ignore identical update. + if (route_config_state.update.has_value() && + *route_config_state.update == rds_update) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] RDS resource identical to current, ignoring", + xds_client()); + } + continue; + } + // Update the cache. + route_config_state.update = std::move(rds_update); + route_config_state.meta = CreateResourceMetadataAcked( + std::move(p.second.serialized_proto), version, update_time); + // Notify all watchers. 
+ for (const auto& p : route_config_state.watchers) { + p.first->OnRouteConfigChanged(*route_config_state.update); + } + } +} + +void XdsClient::ChannelState::AdsCallState::AcceptCdsUpdateLocked( + std::string version, grpc_millis update_time, + XdsApi::CdsUpdateMap cds_update_map, + const std::set& resource_names_failed) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] CDS update received containing %" PRIuPTR + " resources", + xds_client(), cds_update_map.size()); + } + auto& cds_state = state_map_[XdsApi::kCdsTypeUrl]; + std::set eds_resource_names_seen; + for (auto& p : cds_update_map) { + const char* cluster_name = p.first.c_str(); + XdsApi::CdsUpdate& cds_update = p.second.resource; + auto& state = cds_state.subscribed_resources[cluster_name]; + if (state != nullptr) state->Finish(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] cluster=%s: %s", xds_client(), + cluster_name, cds_update.ToString().c_str()); + } + // Record the EDS resource names seen. + eds_resource_names_seen.insert(cds_update.eds_service_name.empty() + ? cluster_name + : cds_update.eds_service_name); + // Ignore identical update. + ClusterState& cluster_state = xds_client()->cluster_map_[cluster_name]; + if (cluster_state.update.has_value() && + *cluster_state.update == cds_update) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] CDS update identical to current, ignoring.", + xds_client()); + } + continue; + } + // Update the cluster state. + cluster_state.update = std::move(cds_update); + cluster_state.meta = CreateResourceMetadataAcked( + std::move(p.second.serialized_proto), version, update_time); + // Notify all watchers. + for (const auto& p : cluster_state.watchers) { + p.first->OnClusterChanged(cluster_state.update.value()); + } + } + // For invalid resources in the update, if they are already in the + // cache, pretend that they are present in the update, so that we + // don't incorrectly consider them deleted below. + for (const std::string& cluster_name : resource_names_failed) { + auto it = xds_client()->cluster_map_.find(cluster_name); + if (it != xds_client()->cluster_map_.end()) { + auto& resource = it->second.update; + if (!resource.has_value()) continue; + cds_update_map[cluster_name]; + eds_resource_names_seen.insert(resource->eds_service_name.empty() + ? cluster_name + : resource->eds_service_name); + } + } + // For any subscribed resource that is not present in the update, + // remove it from the cache and notify watchers that it does not exist. + for (const auto& p : cds_state.subscribed_resources) { + const std::string& cluster_name = p.first; + if (cds_update_map.find(cluster_name) == cds_update_map.end()) { + ClusterState& cluster_state = xds_client()->cluster_map_[cluster_name]; + // If the resource was newly requested but has not yet been received, + // we don't want to generate an error for the watchers, because this CDS + // response may be in reaction to an earlier request that did not yet + // request the new resource, so its absence from the response does not + // necessarily indicate that the resource does not exist. + // For that case, we rely on the request timeout instead. 
+ if (!cluster_state.update.has_value()) continue; + cluster_state.update.reset(); + for (const auto& p : cluster_state.watchers) { + p.first->OnResourceDoesNotExist(); + } + } + } + // For any EDS resource that is no longer referred to by any CDS + // resources, remove it from the cache and notify watchers that it + // does not exist. + auto& eds_state = state_map_[XdsApi::kEdsTypeUrl]; + for (const auto& p : eds_state.subscribed_resources) { + const std::string& eds_resource_name = p.first; + if (eds_resource_names_seen.find(eds_resource_name) == + eds_resource_names_seen.end()) { + EndpointState& endpoint_state = + xds_client()->endpoint_map_[eds_resource_name]; + endpoint_state.update.reset(); + for (const auto& p : endpoint_state.watchers) { + p.first->OnResourceDoesNotExist(); + } + } + } +} + +void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdateLocked( + std::string version, grpc_millis update_time, + XdsApi::EdsUpdateMap eds_update_map) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] EDS update received containing %" PRIuPTR + " resources", + xds_client(), eds_update_map.size()); + } + auto& eds_state = state_map_[XdsApi::kEdsTypeUrl]; + for (auto& p : eds_update_map) { + const char* eds_service_name = p.first.c_str(); + XdsApi::EdsUpdate& eds_update = p.second.resource; + auto& state = eds_state.subscribed_resources[eds_service_name]; + if (state != nullptr) state->Finish(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] EDS resource %s: %s", xds_client(), + eds_service_name, eds_update.ToString().c_str()); + } + EndpointState& endpoint_state = + xds_client()->endpoint_map_[eds_service_name]; + // Ignore identical update. + if (endpoint_state.update.has_value() && + *endpoint_state.update == eds_update) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] EDS update identical to current, ignoring.", + xds_client()); + } + continue; + } + // Update the cluster state. + endpoint_state.update = std::move(eds_update); + endpoint_state.meta = CreateResourceMetadataAcked( + std::move(p.second.serialized_proto), version, update_time); + // Notify all watchers. + for (const auto& p : endpoint_state.watchers) { + p.first->OnEndpointChanged(endpoint_state.update.value()); + } + } +} + +namespace { + +// Update resource_metadata for NACK. +void UpdateResourceMetadataNacked(const std::string& version, + const std::string& details, + grpc_millis update_time, + XdsApi::ResourceMetadata* resource_metadata) { + resource_metadata->client_status = XdsApi::ResourceMetadata::NACKED; + resource_metadata->failed_version = version; + resource_metadata->failed_details = details; + resource_metadata->failed_update_time = update_time; +} + +} // namespace + +template +void XdsClient::ChannelState::AdsCallState::RejectAdsUpdateLocked( + grpc_millis update_time, const XdsApi::AdsParseResult& result, + StateMap* state_map) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] %s update NACKed containing %" PRIuPTR + " invalid resources", + xds_client(), result.type_url.c_str(), + result.resource_names_failed.size()); + } + std::string details = grpc_error_std_string(result.parse_error); + for (auto& name : result.resource_names_failed) { + auto it = state_map->find(name); + if (it == state_map->end()) continue; + auto& state = it->second; + // Notify watchers of error. 
+ for (const auto& p : state.watchers) { + p.first->OnError(GRPC_ERROR_REF(result.parse_error)); + } + // Update resource metadata for CSDS. + UpdateResourceMetadataNacked(result.version, details, update_time, + &state.meta); + } +} + +void XdsClient::ChannelState::AdsCallState::OnRequestSent( + void* arg, grpc_error_handle error) { + AdsCallState* ads_calld = static_cast(arg); + { + MutexLock lock(&ads_calld->xds_client()->mu_); + ads_calld->OnRequestSentLocked(GRPC_ERROR_REF(error)); + } + ads_calld->Unref(DEBUG_LOCATION, "ADS+OnRequestSentLocked"); +} + +void XdsClient::ChannelState::AdsCallState::OnRequestSentLocked( + grpc_error_handle error) { + if (IsCurrentCallOnChannel() && error == GRPC_ERROR_NONE) { + // Clean up the sent message. + grpc_byte_buffer_destroy(send_message_payload_); + send_message_payload_ = nullptr; + // Continue to send another pending message if any. + // TODO(roth): The current code to handle buffered messages has the + // advantage of sending only the most recent list of resource names for + // each resource type (no matter how many times that resource type has + // been requested to send while the current message sending is still + // pending). But its disadvantage is that we send the requests in fixed + // order of resource types. We need to fix this if we are seeing some + // resource type(s) starved due to frequent requests of other resource + // type(s). + auto it = buffered_requests_.begin(); + if (it != buffered_requests_.end()) { + SendMessageLocked(*it); + buffered_requests_.erase(it); + } + } + GRPC_ERROR_UNREF(error); +} + +void XdsClient::ChannelState::AdsCallState::OnResponseReceived( + void* arg, grpc_error_handle /* error */) { + AdsCallState* ads_calld = static_cast(arg); + bool done; + { + MutexLock lock(&ads_calld->xds_client()->mu_); + done = ads_calld->OnResponseReceivedLocked(); + } + if (done) ads_calld->Unref(DEBUG_LOCATION, "ADS+OnResponseReceivedLocked"); +} + +bool XdsClient::ChannelState::AdsCallState::OnResponseReceivedLocked() { + // Empty payload means the call was cancelled. + if (!IsCurrentCallOnChannel() || recv_message_payload_ == nullptr) { + return true; + } + // Read the response. + grpc_byte_buffer_reader bbr; + grpc_byte_buffer_reader_init(&bbr, recv_message_payload_); + grpc_slice response_slice = grpc_byte_buffer_reader_readall(&bbr); + grpc_byte_buffer_reader_destroy(&bbr); + grpc_byte_buffer_destroy(recv_message_payload_); + recv_message_payload_ = nullptr; + // Parse and validate the response. + XdsApi::AdsParseResult result = xds_client()->api_.ParseAdsResponse( + chand()->server_, response_slice, + ResourceNamesForRequest(XdsApi::kLdsTypeUrl), + ResourceNamesForRequest(XdsApi::kRdsTypeUrl), + ResourceNamesForRequest(XdsApi::kCdsTypeUrl), + ResourceNamesForRequest(XdsApi::kEdsTypeUrl)); + grpc_slice_unref_internal(response_slice); + if (result.type_url.empty()) { + // Ignore unparsable response. + gpr_log(GPR_ERROR, + "[xds_client %p] Error parsing ADS response (%s) -- ignoring", + xds_client(), grpc_error_std_string(result.parse_error).c_str()); + GRPC_ERROR_UNREF(result.parse_error); + } else { + grpc_millis update_time = grpc_core::ExecCtx::Get()->Now(); + // Update nonce. + auto& state = state_map_[result.type_url]; + state.nonce = std::move(result.nonce); + // If we got an error, we'll NACK the update. 
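For context, a NACK in the SotW ADS protocol is simply the next DiscoveryRequest: it repeats the last accepted version_info, echoes the nonce of the rejected response, and carries an error_detail. A rough sketch with invented struct and function names, not the actual XdsApi interface:

// --- Editor's illustrative sketch (not part of this patch) ---
// Field names below mirror the xDS proto informally.
#include <string>
#include <vector>

struct DiscoveryRequestSketch {
  std::string type_url;
  std::string version_info;    // last *accepted* version, not the rejected one
  std::string response_nonce;  // nonce of the response being NACKed
  std::vector<std::string> resource_names;
  std::string error_detail;    // non-empty => this request is a NACK
};

DiscoveryRequestSketch MakeNack(const std::string& type_url,
                                const std::string& last_acked_version,
                                const std::string& nonce,
                                std::vector<std::string> resource_names,
                                const std::string& error_message) {
  return DiscoveryRequestSketch{type_url, last_acked_version, nonce,
                                std::move(resource_names), error_message};
}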
+ if (result.parse_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "[xds_client %p] ADS response invalid for resource type %s " + "version %s, will NACK: nonce=%s error=%s", + xds_client(), result.type_url.c_str(), result.version.c_str(), + state.nonce.c_str(), + grpc_error_std_string(result.parse_error).c_str()); + result.parse_error = + grpc_error_set_int(result.parse_error, GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE); + GRPC_ERROR_UNREF(state.error); + state.error = result.parse_error; + if (result.type_url == XdsApi::kLdsTypeUrl) { + RejectAdsUpdateLocked(update_time, result, + &xds_client()->listener_map_); + } else if (result.type_url == XdsApi::kRdsTypeUrl) { + RejectAdsUpdateLocked(update_time, result, + &xds_client()->route_config_map_); + } else if (result.type_url == XdsApi::kCdsTypeUrl) { + RejectAdsUpdateLocked(update_time, result, &xds_client()->cluster_map_); + } else if (result.type_url == XdsApi::kEdsTypeUrl) { + RejectAdsUpdateLocked(update_time, result, + &xds_client()->endpoint_map_); + } + } + // Process any valid resources. + bool have_valid_resources = false; + if (result.type_url == XdsApi::kLdsTypeUrl) { + have_valid_resources = !result.lds_update_map.empty(); + AcceptLdsUpdateLocked(result.version, update_time, + std::move(result.lds_update_map), + result.resource_names_failed); + } else if (result.type_url == XdsApi::kRdsTypeUrl) { + have_valid_resources = !result.rds_update_map.empty(); + AcceptRdsUpdateLocked(result.version, update_time, + std::move(result.rds_update_map)); + } else if (result.type_url == XdsApi::kCdsTypeUrl) { + have_valid_resources = !result.cds_update_map.empty(); + AcceptCdsUpdateLocked(result.version, update_time, + std::move(result.cds_update_map), + result.resource_names_failed); + } else if (result.type_url == XdsApi::kEdsTypeUrl) { + have_valid_resources = !result.eds_update_map.empty(); + AcceptEdsUpdateLocked(result.version, update_time, + std::move(result.eds_update_map)); + } + if (have_valid_resources) { + seen_response_ = true; + xds_client()->resource_version_map_[result.type_url] = + std::move(result.version); + // Start load reporting if needed. + auto& lrs_call = chand()->lrs_calld_; + if (lrs_call != nullptr) { + LrsCallState* lrs_calld = lrs_call->calld(); + if (lrs_calld != nullptr) lrs_calld->MaybeStartReportingLocked(); + } + } + // Send ACK or NACK. + SendMessageLocked(result.type_url); + } + if (xds_client()->shutting_down_) return true; + // Keep listening for updates. + grpc_op op; + memset(&op, 0, sizeof(op)); + op.op = GRPC_OP_RECV_MESSAGE; + op.data.recv_message.recv_message = &recv_message_payload_; + op.flags = 0; + op.reserved = nullptr; + GPR_ASSERT(call_ != nullptr); + // Reuse the "ADS+OnResponseReceivedLocked" ref taken in ctor. 
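The ref comment above refers to a pattern used throughout this file: one ref is taken when the receive closure is first scheduled, the closure re-arms itself after every message, and the ref is dropped only when the callback reports that the read loop is over. A simplified model using shared_ptr in place of the InternallyRefCounted machinery (editor's sketch, invented names):

// --- Editor's illustrative sketch (not part of this patch) ---
#include <memory>

class CallStateSketch : public std::enable_shared_from_this<CallStateSketch> {
 public:
  void StartReceiving() {
    self_ref_ = shared_from_this();  // the "ref taken in ctor" in the real code
    ArmReceive();
  }
  // Invoked by the transport when a message (or end of stream) arrives.
  void OnMessage(bool end_of_stream) {
    if (end_of_stream) {
      self_ref_.reset();  // drop the ref: the read loop is finished
      return;
    }
    ArmReceive();  // reuse the same ref for the next message
  }

 private:
  void ArmReceive() { /* start the next GRPC_OP_RECV_MESSAGE batch */ }
  std::shared_ptr<CallStateSketch> self_ref_;
};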
+ const grpc_call_error call_error = + grpc_call_start_batch_and_execute(call_, &op, 1, &on_response_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + return false; +} + +void XdsClient::ChannelState::AdsCallState::OnStatusReceived( + void* arg, grpc_error_handle error) { + AdsCallState* ads_calld = static_cast(arg); + { + MutexLock lock(&ads_calld->xds_client()->mu_); + ads_calld->OnStatusReceivedLocked(GRPC_ERROR_REF(error)); + } + ads_calld->Unref(DEBUG_LOCATION, "ADS+OnStatusReceivedLocked"); +} + +void XdsClient::ChannelState::AdsCallState::OnStatusReceivedLocked( + grpc_error_handle error) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + char* status_details = grpc_slice_to_c_string(status_details_); + gpr_log(GPR_INFO, + "[xds_client %p] ADS call status received. Status = %d, details " + "= '%s', (chand: %p, ads_calld: %p, call: %p), error '%s'", + xds_client(), status_code_, status_details, chand(), this, call_, + grpc_error_std_string(error).c_str()); + gpr_free(status_details); + } + // Ignore status from a stale call. + if (IsCurrentCallOnChannel()) { + // Try to restart the call. + parent_->OnCallFinishedLocked(); + // Send error to all watchers. + xds_client()->NotifyOnErrorLocked( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("xds call failed")); + } + GRPC_ERROR_UNREF(error); +} + +bool XdsClient::ChannelState::AdsCallState::IsCurrentCallOnChannel() const { + // If the retryable ADS call is null (which only happens when the xds channel + // is shutting down), all the ADS calls are stale. + if (chand()->ads_calld_ == nullptr) return false; + return this == chand()->ads_calld_->calld(); +} + +std::set +XdsClient::ChannelState::AdsCallState::ResourceNamesForRequest( + const std::string& type_url) { + std::set resource_names; + auto it = state_map_.find(type_url); + if (it != state_map_.end()) { + for (auto& p : it->second.subscribed_resources) { + resource_names.insert(p.first); + OrphanablePtr& state = p.second; + state->Start(Ref(DEBUG_LOCATION, "ResourceState")); + } + } + return resource_names; +} + +// +// XdsClient::ChannelState::LrsCallState::Reporter +// + +void XdsClient::ChannelState::LrsCallState::Reporter::Orphan() { + if (next_report_timer_callback_pending_) { + grpc_timer_cancel(&next_report_timer_); + } +} + +void XdsClient::ChannelState::LrsCallState::Reporter:: + ScheduleNextReportLocked() { + const grpc_millis next_report_time = ExecCtx::Get()->Now() + report_interval_; + grpc_timer_init(&next_report_timer_, next_report_time, + &on_next_report_timer_); + next_report_timer_callback_pending_ = true; +} + +void XdsClient::ChannelState::LrsCallState::Reporter::OnNextReportTimer( + void* arg, grpc_error_handle error) { + Reporter* self = static_cast(arg); + bool done; + { + MutexLock lock(&self->xds_client()->mu_); + done = self->OnNextReportTimerLocked(GRPC_ERROR_REF(error)); + } + if (done) self->Unref(DEBUG_LOCATION, "Reporter+timer"); +} + +bool XdsClient::ChannelState::LrsCallState::Reporter::OnNextReportTimerLocked( + grpc_error_handle error) { + next_report_timer_callback_pending_ = false; + if (error != GRPC_ERROR_NONE || !IsCurrentReporterOnCall()) { + GRPC_ERROR_UNREF(error); + return true; + } + return SendReportLocked(); +} + +namespace { + +bool LoadReportCountersAreZero(const XdsApi::ClusterLoadReportMap& snapshot) { + for (const auto& p : snapshot) { + const XdsApi::ClusterLoadReport& cluster_snapshot = p.second; + if (!cluster_snapshot.dropped_requests.IsZero()) return false; + for (const auto& q : cluster_snapshot.locality_stats) { + 
const XdsClusterLocalityStats::Snapshot& locality_snapshot = q.second; + if (!locality_snapshot.IsZero()) return false; + } + } + return true; +} + +} // namespace + +bool XdsClient::ChannelState::LrsCallState::Reporter::SendReportLocked() { + // Construct snapshot from all reported stats. + XdsApi::ClusterLoadReportMap snapshot = + xds_client()->BuildLoadReportSnapshotLocked(parent_->send_all_clusters_, + parent_->cluster_names_); + // Skip client load report if the counters were all zero in the last + // report and they are still zero in this one. + const bool old_val = last_report_counters_were_zero_; + last_report_counters_were_zero_ = LoadReportCountersAreZero(snapshot); + if (old_val && last_report_counters_were_zero_) { + if (xds_client()->load_report_map_.empty()) { + parent_->chand()->StopLrsCall(); + return true; + } + ScheduleNextReportLocked(); + return false; + } + // Create a request that contains the snapshot. + grpc_slice request_payload_slice = + xds_client()->api_.CreateLrsRequest(std::move(snapshot)); + parent_->send_message_payload_ = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_slice_unref_internal(request_payload_slice); + // Send the report. + grpc_op op; + memset(&op, 0, sizeof(op)); + op.op = GRPC_OP_SEND_MESSAGE; + op.data.send_message.send_message = parent_->send_message_payload_; + grpc_call_error call_error = grpc_call_start_batch_and_execute( + parent_->call_, &op, 1, &on_report_done_); + if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) { + gpr_log(GPR_ERROR, + "[xds_client %p] calld=%p call_error=%d sending client load report", + xds_client(), this, call_error); + GPR_ASSERT(GRPC_CALL_OK == call_error); + } + return false; +} + +void XdsClient::ChannelState::LrsCallState::Reporter::OnReportDone( + void* arg, grpc_error_handle error) { + Reporter* self = static_cast(arg); + bool done; + { + MutexLock lock(&self->xds_client()->mu_); + done = self->OnReportDoneLocked(GRPC_ERROR_REF(error)); + } + if (done) self->Unref(DEBUG_LOCATION, "Reporter+report_done"); +} + +bool XdsClient::ChannelState::LrsCallState::Reporter::OnReportDoneLocked( + grpc_error_handle error) { + grpc_byte_buffer_destroy(parent_->send_message_payload_); + parent_->send_message_payload_ = nullptr; + // If there are no more registered stats to report, cancel the call. + if (xds_client()->load_report_map_.empty()) { + parent_->chand()->StopLrsCall(); + GRPC_ERROR_UNREF(error); + return true; + } + if (error != GRPC_ERROR_NONE || !IsCurrentReporterOnCall()) { + GRPC_ERROR_UNREF(error); + // If this reporter is no longer the current one on the call, the reason + // might be that it was orphaned for a new one due to config update. + if (!IsCurrentReporterOnCall()) { + parent_->MaybeStartReportingLocked(); + } + return true; + } + ScheduleNextReportLocked(); + return false; +} + +// +// XdsClient::ChannelState::LrsCallState +// + +XdsClient::ChannelState::LrsCallState::LrsCallState( + RefCountedPtr> parent) + : InternallyRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_refcount_trace) + ? "LrsCallState" + : nullptr), + parent_(std::move(parent)) { + // Init the LRS call. Note that the call will progress every time there's + // activity in xds_client()->interested_parties_, which is comprised of + // the polling entities from client_channel. + GPR_ASSERT(xds_client() != nullptr); + const auto& method = + chand()->server_.ShouldUseV3() + ? 
GRPC_MDSTR_SLASH_ENVOY_DOT_SERVICE_DOT_LOAD_STATS_DOT_V3_DOT_LOADREPORTINGSERVICE_SLASH_STREAMLOADSTATS + : GRPC_MDSTR_SLASH_ENVOY_DOT_SERVICE_DOT_LOAD_STATS_DOT_V2_DOT_LOADREPORTINGSERVICE_SLASH_STREAMLOADSTATS; + call_ = grpc_channel_create_pollset_set_call( + chand()->channel_, nullptr, GRPC_PROPAGATE_DEFAULTS, + xds_client()->interested_parties_, method, nullptr, + GRPC_MILLIS_INF_FUTURE, nullptr); + GPR_ASSERT(call_ != nullptr); + // Init the request payload. + grpc_slice request_payload_slice = + xds_client()->api_.CreateLrsInitialRequest(chand()->server_); + send_message_payload_ = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_slice_unref_internal(request_payload_slice); + // Init other data associated with the LRS call. + grpc_metadata_array_init(&initial_metadata_recv_); + grpc_metadata_array_init(&trailing_metadata_recv_); + // Start the call. + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] Starting LRS call (chand: %p, calld: %p, " + "call: %p)", + xds_client(), chand(), this, call_); + } + // Create the ops. + grpc_call_error call_error; + grpc_op ops[3]; + memset(ops, 0, sizeof(ops)); + // Op: send initial metadata. + grpc_op* op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY | + GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET; + op->reserved = nullptr; + op++; + // Op: send request message. + GPR_ASSERT(send_message_payload_ != nullptr); + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = send_message_payload_; + op->flags = 0; + op->reserved = nullptr; + op++; + Ref(DEBUG_LOCATION, "LRS+OnInitialRequestSentLocked").release(); + GRPC_CLOSURE_INIT(&on_initial_request_sent_, OnInitialRequestSent, this, + grpc_schedule_on_exec_ctx); + call_error = grpc_call_start_batch_and_execute( + call_, ops, static_cast(op - ops), &on_initial_request_sent_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + // Op: recv initial metadata. + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv_; + op->flags = 0; + op->reserved = nullptr; + op++; + // Op: recv response. + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_message_payload_; + op->flags = 0; + op->reserved = nullptr; + op++; + Ref(DEBUG_LOCATION, "LRS+OnResponseReceivedLocked").release(); + GRPC_CLOSURE_INIT(&on_response_received_, OnResponseReceived, this, + grpc_schedule_on_exec_ctx); + call_error = grpc_call_start_batch_and_execute( + call_, ops, static_cast(op - ops), &on_response_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + // Op: recv server status. + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv_; + op->data.recv_status_on_client.status = &status_code_; + op->data.recv_status_on_client.status_details = &status_details_; + op->flags = 0; + op->reserved = nullptr; + op++; + // This callback signals the end of the call, so it relies on the initial + // ref instead of a new ref. When it's invoked, it's the initial ref that is + // unreffed. 
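In other words, shutdown of the call object is two-phase: Orphan() only cancels the underlying call, and the status callback, which runs exactly once per call, releases the ref that has been held since the call was started. A condensed model (editor's sketch; shared_ptr stands in for the real ref counting, names are invented):

// --- Editor's illustrative sketch (not part of this patch) ---
#include <memory>

class LrsCallSketch : public std::enable_shared_from_this<LrsCallSketch> {
 public:
  void Start() {
    // The "initial ref": kept alive until the call status arrives.
    initial_ref_ = shared_from_this();
    // register OnStatusReceived() with the transport here
  }
  void Orphan() {
    cancelled_ = true;
    // cancel the underlying call; OnStatusReceived() will still run once
  }
  void OnStatusReceived(int /*status_code*/) {
    // ... possibly schedule a retry if !cancelled_ ...
    initial_ref_.reset();  // this is where the initial ref is released
  }

 private:
  bool cancelled_ = false;
  std::shared_ptr<LrsCallSketch> initial_ref_;
};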
+ GRPC_CLOSURE_INIT(&on_status_received_, OnStatusReceived, this, + grpc_schedule_on_exec_ctx); + call_error = grpc_call_start_batch_and_execute( + call_, ops, static_cast(op - ops), &on_status_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); +} + +XdsClient::ChannelState::LrsCallState::~LrsCallState() { + grpc_metadata_array_destroy(&initial_metadata_recv_); + grpc_metadata_array_destroy(&trailing_metadata_recv_); + grpc_byte_buffer_destroy(send_message_payload_); + grpc_byte_buffer_destroy(recv_message_payload_); + grpc_slice_unref_internal(status_details_); + GPR_ASSERT(call_ != nullptr); + grpc_call_unref(call_); +} + +void XdsClient::ChannelState::LrsCallState::Orphan() { + reporter_.reset(); + GPR_ASSERT(call_ != nullptr); + // If we are here because xds_client wants to cancel the call, + // on_status_received_ will complete the cancellation and clean up. Otherwise, + // we are here because xds_client has to orphan a failed call, then the + // following cancellation will be a no-op. + grpc_call_cancel_internal(call_); + // Note that the initial ref is hold by on_status_received_. So the + // corresponding unref happens in on_status_received_ instead of here. +} + +void XdsClient::ChannelState::LrsCallState::MaybeStartReportingLocked() { + // Don't start again if already started. + if (reporter_ != nullptr) return; + // Don't start if the previous send_message op (of the initial request or the + // last report of the previous reporter) hasn't completed. + if (send_message_payload_ != nullptr) return; + // Don't start if no LRS response has arrived. + if (!seen_response()) return; + // Don't start if the ADS call hasn't received any valid response. Note that + // this must be the first channel because it is the current channel but its + // ADS call hasn't seen any response. + if (chand()->ads_calld_ == nullptr || + chand()->ads_calld_->calld() == nullptr || + !chand()->ads_calld_->calld()->seen_response()) { + return; + } + // Start reporting. + reporter_ = MakeOrphanable( + Ref(DEBUG_LOCATION, "LRS+load_report+start"), load_reporting_interval_); +} + +void XdsClient::ChannelState::LrsCallState::OnInitialRequestSent( + void* arg, grpc_error_handle /*error*/) { + LrsCallState* lrs_calld = static_cast(arg); + { + MutexLock lock(&lrs_calld->xds_client()->mu_); + lrs_calld->OnInitialRequestSentLocked(); + } + lrs_calld->Unref(DEBUG_LOCATION, "LRS+OnInitialRequestSentLocked"); +} + +void XdsClient::ChannelState::LrsCallState::OnInitialRequestSentLocked() { + // Clear the send_message_payload_. + grpc_byte_buffer_destroy(send_message_payload_); + send_message_payload_ = nullptr; + MaybeStartReportingLocked(); +} + +void XdsClient::ChannelState::LrsCallState::OnResponseReceived( + void* arg, grpc_error_handle /*error*/) { + LrsCallState* lrs_calld = static_cast(arg); + bool done; + { + MutexLock lock(&lrs_calld->xds_client()->mu_); + done = lrs_calld->OnResponseReceivedLocked(); + } + if (done) lrs_calld->Unref(DEBUG_LOCATION, "LRS+OnResponseReceivedLocked"); +} + +bool XdsClient::ChannelState::LrsCallState::OnResponseReceivedLocked() { + // Empty payload means the call was cancelled. + if (!IsCurrentCallOnChannel() || recv_message_payload_ == nullptr) { + return true; + } + // Read the response. 
+ grpc_byte_buffer_reader bbr; + grpc_byte_buffer_reader_init(&bbr, recv_message_payload_); + grpc_slice response_slice = grpc_byte_buffer_reader_readall(&bbr); + grpc_byte_buffer_reader_destroy(&bbr); + grpc_byte_buffer_destroy(recv_message_payload_); + recv_message_payload_ = nullptr; + // This anonymous lambda is a hack to avoid the usage of goto. + [&]() { + // Parse the response. + bool send_all_clusters = false; + std::set new_cluster_names; + grpc_millis new_load_reporting_interval; + grpc_error_handle parse_error = xds_client()->api_.ParseLrsResponse( + response_slice, &send_all_clusters, &new_cluster_names, + &new_load_reporting_interval); + if (parse_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "[xds_client %p] LRS response parsing failed. error=%s", + xds_client(), grpc_error_std_string(parse_error).c_str()); + GRPC_ERROR_UNREF(parse_error); + return; + } + seen_response_ = true; + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log( + GPR_INFO, + "[xds_client %p] LRS response received, %" PRIuPTR + " cluster names, send_all_clusters=%d, load_report_interval=%" PRId64 + "ms", + xds_client(), new_cluster_names.size(), send_all_clusters, + new_load_reporting_interval); + size_t i = 0; + for (const auto& name : new_cluster_names) { + gpr_log(GPR_INFO, "[xds_client %p] cluster_name %" PRIuPTR ": %s", + xds_client(), i++, name.c_str()); + } + } + if (new_load_reporting_interval < + GRPC_XDS_MIN_CLIENT_LOAD_REPORTING_INTERVAL_MS) { + new_load_reporting_interval = + GRPC_XDS_MIN_CLIENT_LOAD_REPORTING_INTERVAL_MS; + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] Increased load_report_interval to minimum " + "value %dms", + xds_client(), GRPC_XDS_MIN_CLIENT_LOAD_REPORTING_INTERVAL_MS); + } + } + // Ignore identical update. + if (send_all_clusters == send_all_clusters_ && + cluster_names_ == new_cluster_names && + load_reporting_interval_ == new_load_reporting_interval) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] Incoming LRS response identical to current, " + "ignoring.", + xds_client()); + } + return; + } + // Stop current load reporting (if any) to adopt the new config. + reporter_.reset(); + // Record the new config. + send_all_clusters_ = send_all_clusters; + cluster_names_ = std::move(new_cluster_names); + load_reporting_interval_ = new_load_reporting_interval; + // Try starting sending load report. + MaybeStartReportingLocked(); + }(); + grpc_slice_unref_internal(response_slice); + if (xds_client()->shutting_down_) return true; + // Keep listening for LRS config updates. + grpc_op op; + memset(&op, 0, sizeof(op)); + op.op = GRPC_OP_RECV_MESSAGE; + op.data.recv_message.recv_message = &recv_message_payload_; + op.flags = 0; + op.reserved = nullptr; + GPR_ASSERT(call_ != nullptr); + // Reuse the "OnResponseReceivedLocked" ref taken in ctor. 
+ const grpc_call_error call_error = + grpc_call_start_batch_and_execute(call_, &op, 1, &on_response_received_); + GPR_ASSERT(GRPC_CALL_OK == call_error); + return false; +} + +void XdsClient::ChannelState::LrsCallState::OnStatusReceived( + void* arg, grpc_error_handle error) { + LrsCallState* lrs_calld = static_cast(arg); + { + MutexLock lock(&lrs_calld->xds_client()->mu_); + lrs_calld->OnStatusReceivedLocked(GRPC_ERROR_REF(error)); + } + lrs_calld->Unref(DEBUG_LOCATION, "LRS+OnStatusReceivedLocked"); +} + +void XdsClient::ChannelState::LrsCallState::OnStatusReceivedLocked( + grpc_error_handle error) { + GPR_ASSERT(call_ != nullptr); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + char* status_details = grpc_slice_to_c_string(status_details_); + gpr_log(GPR_INFO, + "[xds_client %p] LRS call status received. Status = %d, details " + "= '%s', (chand: %p, calld: %p, call: %p), error '%s'", + xds_client(), status_code_, status_details, chand(), this, call_, + grpc_error_std_string(error).c_str()); + gpr_free(status_details); + } + // Ignore status from a stale call. + if (IsCurrentCallOnChannel()) { + GPR_ASSERT(!xds_client()->shutting_down_); + // Try to restart the call. + parent_->OnCallFinishedLocked(); + } + GRPC_ERROR_UNREF(error); +} + +bool XdsClient::ChannelState::LrsCallState::IsCurrentCallOnChannel() const { + // If the retryable LRS call is null (which only happens when the xds channel + // is shutting down), all the LRS calls are stale. + if (chand()->lrs_calld_ == nullptr) return false; + return this == chand()->lrs_calld_->calld(); +} + +// +// XdsClient +// + +namespace { + +grpc_millis GetRequestTimeout(const grpc_channel_args* args) { + return grpc_channel_args_find_integer( + args, GRPC_ARG_XDS_RESOURCE_DOES_NOT_EXIST_TIMEOUT_MS, + {15000, 0, INT_MAX}); +} + +grpc_channel_args* ModifyChannelArgs(const grpc_channel_args* args) { + absl::InlinedVector args_to_add = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_TIME_MS), + 5 * 60 * GPR_MS_PER_SEC), + }; + return grpc_channel_args_copy_and_add(args, args_to_add.data(), + args_to_add.size()); +} + +} // namespace + +XdsClient::XdsClient(std::unique_ptr bootstrap, + const grpc_channel_args* args) + : DualRefCounted( + GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_refcount_trace) ? "XdsClient" + : nullptr), + bootstrap_(std::move(bootstrap)), + args_(ModifyChannelArgs(args)), + request_timeout_(GetRequestTimeout(args)), + interested_parties_(grpc_pollset_set_create()), + certificate_provider_store_(MakeOrphanable( + bootstrap_->certificate_providers())), + api_(this, &grpc_xds_client_trace, bootstrap_->node(), + &bootstrap_->certificate_providers()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] creating xds client", this); + } + // Create ChannelState object. + chand_ = MakeOrphanable( + WeakRef(DEBUG_LOCATION, "XdsClient+ChannelState"), bootstrap_->server()); +} + +XdsClient::~XdsClient() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] destroying xds client", this); + } + grpc_channel_args_destroy(args_); + grpc_pollset_set_destroy(interested_parties_); +} + +void XdsClient::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] shutting down xds client", this); + } + { + MutexLock lock(g_mu); + if (g_xds_client == this) g_xds_client = nullptr; + } + { + MutexLock lock(&mu_); + shutting_down_ = true; + // Orphan ChannelState object. 
+ chand_.reset(); + // We do not clear cluster_map_ and endpoint_map_ if the xds client was + // created by the XdsResolver because the maps contain refs for watchers + // which in turn hold refs to the loadbalancing policies. At this point, it + // is possible for ADS calls to be in progress. Unreffing the loadbalancing + // policies before those calls are done would lead to issues such as + // https://github.com/grpc/grpc/issues/20928. + if (!listener_map_.empty()) { + cluster_map_.clear(); + endpoint_map_.clear(); + } + } +} + +void XdsClient::WatchListenerData( + absl::string_view listener_name, + std::unique_ptr watcher) { + std::string listener_name_str = std::string(listener_name); + MutexLock lock(&mu_); + ListenerState& listener_state = listener_map_[listener_name_str]; + ListenerWatcherInterface* w = watcher.get(); + listener_state.watchers[w] = std::move(watcher); + // If we've already received an LDS update, notify the new watcher + // immediately. + if (listener_state.update.has_value()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] returning cached listener data for %s", + this, listener_name_str.c_str()); + } + w->OnListenerChanged(*listener_state.update); + } + chand_->SubscribeLocked(XdsApi::kLdsTypeUrl, listener_name_str); +} + +void XdsClient::CancelListenerDataWatch(absl::string_view listener_name, + ListenerWatcherInterface* watcher, + bool delay_unsubscription) { + MutexLock lock(&mu_); + if (shutting_down_) return; + std::string listener_name_str = std::string(listener_name); + ListenerState& listener_state = listener_map_[listener_name_str]; + auto it = listener_state.watchers.find(watcher); + if (it != listener_state.watchers.end()) { + listener_state.watchers.erase(it); + if (listener_state.watchers.empty()) { + listener_map_.erase(listener_name_str); + chand_->UnsubscribeLocked(XdsApi::kLdsTypeUrl, listener_name_str, + delay_unsubscription); + } + } +} + +void XdsClient::WatchRouteConfigData( + absl::string_view route_config_name, + std::unique_ptr watcher) { + std::string route_config_name_str = std::string(route_config_name); + MutexLock lock(&mu_); + RouteConfigState& route_config_state = + route_config_map_[route_config_name_str]; + RouteConfigWatcherInterface* w = watcher.get(); + route_config_state.watchers[w] = std::move(watcher); + // If we've already received an RDS update, notify the new watcher + // immediately. 
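The four Watch*Data methods share one shape: register the watcher, replay the cached update if one is already present, then subscribe so the ADS stream starts (or keeps) requesting the resource. A generic sketch of that shape with invented names:

// --- Editor's illustrative sketch (not part of this patch) ---
#include <map>
#include <memory>
#include <optional>
#include <string>

template <typename Update, typename Watcher>
struct ResourceStateSketch {
  std::map<Watcher*, std::unique_ptr<Watcher>> watchers;
  std::optional<Update> update;  // cached copy of the last accepted resource
};

template <typename Update, typename Watcher, typename Subscribe>
void WatchResourceSketch(
    std::map<std::string, ResourceStateSketch<Update, Watcher>>& cache,
    const std::string& name, std::unique_ptr<Watcher> watcher,
    Subscribe subscribe) {
  auto& state = cache[name];
  Watcher* w = watcher.get();
  state.watchers[w] = std::move(watcher);
  // New watchers get the cached resource immediately instead of waiting for
  // the next server push.
  if (state.update.has_value()) w->OnChanged(*state.update);
  subscribe(name);  // may trigger an ADS request for this resource name
}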
+ if (route_config_state.update.has_value()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] returning cached route config data for %s", this, + route_config_name_str.c_str()); + } + w->OnRouteConfigChanged(*route_config_state.update); + } + chand_->SubscribeLocked(XdsApi::kRdsTypeUrl, route_config_name_str); +} + +void XdsClient::CancelRouteConfigDataWatch(absl::string_view route_config_name, + RouteConfigWatcherInterface* watcher, + bool delay_unsubscription) { + MutexLock lock(&mu_); + if (shutting_down_) return; + std::string route_config_name_str = std::string(route_config_name); + RouteConfigState& route_config_state = + route_config_map_[route_config_name_str]; + auto it = route_config_state.watchers.find(watcher); + if (it != route_config_state.watchers.end()) { + route_config_state.watchers.erase(it); + if (route_config_state.watchers.empty()) { + route_config_map_.erase(route_config_name_str); + chand_->UnsubscribeLocked(XdsApi::kRdsTypeUrl, route_config_name_str, + delay_unsubscription); + } + } +} + +void XdsClient::WatchClusterData( + absl::string_view cluster_name, + std::unique_ptr watcher) { + std::string cluster_name_str = std::string(cluster_name); + MutexLock lock(&mu_); + ClusterState& cluster_state = cluster_map_[cluster_name_str]; + ClusterWatcherInterface* w = watcher.get(); + cluster_state.watchers[w] = std::move(watcher); + // If we've already received a CDS update, notify the new watcher + // immediately. + if (cluster_state.update.has_value()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] returning cached cluster data for %s", + this, cluster_name_str.c_str()); + } + w->OnClusterChanged(cluster_state.update.value()); + } + chand_->SubscribeLocked(XdsApi::kCdsTypeUrl, cluster_name_str); +} + +void XdsClient::CancelClusterDataWatch(absl::string_view cluster_name, + ClusterWatcherInterface* watcher, + bool delay_unsubscription) { + MutexLock lock(&mu_); + if (shutting_down_) return; + std::string cluster_name_str = std::string(cluster_name); + ClusterState& cluster_state = cluster_map_[cluster_name_str]; + auto it = cluster_state.watchers.find(watcher); + if (it != cluster_state.watchers.end()) { + cluster_state.watchers.erase(it); + if (cluster_state.watchers.empty()) { + cluster_map_.erase(cluster_name_str); + chand_->UnsubscribeLocked(XdsApi::kCdsTypeUrl, cluster_name_str, + delay_unsubscription); + } + } +} + +void XdsClient::WatchEndpointData( + absl::string_view eds_service_name, + std::unique_ptr watcher) { + std::string eds_service_name_str = std::string(eds_service_name); + MutexLock lock(&mu_); + EndpointState& endpoint_state = endpoint_map_[eds_service_name_str]; + EndpointWatcherInterface* w = watcher.get(); + endpoint_state.watchers[w] = std::move(watcher); + // If we've already received an EDS update, notify the new watcher + // immediately. 
+ if (endpoint_state.update.has_value()) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] returning cached endpoint data for %s", + this, eds_service_name_str.c_str()); + } + w->OnEndpointChanged(endpoint_state.update.value()); + } + chand_->SubscribeLocked(XdsApi::kEdsTypeUrl, eds_service_name_str); +} + +void XdsClient::CancelEndpointDataWatch(absl::string_view eds_service_name, + EndpointWatcherInterface* watcher, + bool delay_unsubscription) { + MutexLock lock(&mu_); + if (shutting_down_) return; + std::string eds_service_name_str = std::string(eds_service_name); + EndpointState& endpoint_state = endpoint_map_[eds_service_name_str]; + auto it = endpoint_state.watchers.find(watcher); + if (it != endpoint_state.watchers.end()) { + endpoint_state.watchers.erase(it); + if (endpoint_state.watchers.empty()) { + endpoint_map_.erase(eds_service_name_str); + chand_->UnsubscribeLocked(XdsApi::kEdsTypeUrl, eds_service_name_str, + delay_unsubscription); + } + } +} + +RefCountedPtr XdsClient::AddClusterDropStats( + absl::string_view lrs_server, absl::string_view cluster_name, + absl::string_view eds_service_name) { + // TODO(roth): When we add support for direct federation, use the + // server name specified in lrs_server. + auto key = + std::make_pair(std::string(cluster_name), std::string(eds_service_name)); + MutexLock lock(&mu_); + // We jump through some hoops here to make sure that the absl::string_views + // stored in the XdsClusterDropStats object point to the strings + // in the load_report_map_ key, so that they have the same lifetime. + auto it = load_report_map_ + .emplace(std::make_pair(std::move(key), LoadReportState())) + .first; + LoadReportState& load_report_state = it->second; + RefCountedPtr cluster_drop_stats; + if (load_report_state.drop_stats != nullptr) { + cluster_drop_stats = load_report_state.drop_stats->RefIfNonZero(); + } + if (cluster_drop_stats == nullptr) { + if (load_report_state.drop_stats != nullptr) { + load_report_state.deleted_drop_stats += + load_report_state.drop_stats->GetSnapshotAndReset(); + } + cluster_drop_stats = MakeRefCounted( + Ref(DEBUG_LOCATION, "DropStats"), lrs_server, + it->first.first /*cluster_name*/, + it->first.second /*eds_service_name*/); + load_report_state.drop_stats = cluster_drop_stats.get(); + } + chand_->MaybeStartLrsCall(); + return cluster_drop_stats; +} + +void XdsClient::RemoveClusterDropStats( + absl::string_view /*lrs_server*/, absl::string_view cluster_name, + absl::string_view eds_service_name, + XdsClusterDropStats* cluster_drop_stats) { + MutexLock lock(&mu_); + // TODO(roth): When we add support for direct federation, use the + // server name specified in lrs_server. + auto it = load_report_map_.find( + std::make_pair(std::string(cluster_name), std::string(eds_service_name))); + if (it == load_report_map_.end()) return; + LoadReportState& load_report_state = it->second; + if (load_report_state.drop_stats == cluster_drop_stats) { + // Record final snapshot in deleted_drop_stats, which will be + // added to the next load report. + load_report_state.deleted_drop_stats += + load_report_state.drop_stats->GetSnapshotAndReset(); + load_report_state.drop_stats = nullptr; + } +} + +RefCountedPtr XdsClient::AddClusterLocalityStats( + absl::string_view lrs_server, absl::string_view cluster_name, + absl::string_view eds_service_name, + RefCountedPtr locality) { + // TODO(roth): When we add support for direct federation, use the + // server name specified in lrs_server. 
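As with AddClusterDropStats above, the state here is keyed by the (cluster_name, eds_service_name) pair, and the stats object keeps absl::string_views that point into the stored map key so the views stay valid for the lifetime of the entry (std::map nodes are stable while the entry exists). A small illustration of that interning trick (editor's sketch, invented names):

// --- Editor's illustrative sketch (not part of this patch) ---
#include <map>
#include <string>
#include <utility>

#include "absl/strings/string_view.h"

struct LoadReportEntrySketch {};

using KeySketch = std::pair<std::string, std::string>;  // {cluster, eds_service_name}

// Returns views that point into the stored key, not into the caller's strings.
std::pair<absl::string_view, absl::string_view> InternKeySketch(
    std::map<KeySketch, LoadReportEntrySketch>& m, KeySketch key) {
  auto it = m.emplace(std::move(key), LoadReportEntrySketch()).first;
  return {it->first.first, it->first.second};
}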
+ auto key = + std::make_pair(std::string(cluster_name), std::string(eds_service_name)); + MutexLock lock(&mu_); + // We jump through some hoops here to make sure that the absl::string_views + // stored in the XdsClusterLocalityStats object point to the strings + // in the load_report_map_ key, so that they have the same lifetime. + auto it = load_report_map_ + .emplace(std::make_pair(std::move(key), LoadReportState())) + .first; + LoadReportState& load_report_state = it->second; + LoadReportState::LocalityState& locality_state = + load_report_state.locality_stats[locality]; + RefCountedPtr cluster_locality_stats; + if (locality_state.locality_stats != nullptr) { + cluster_locality_stats = locality_state.locality_stats->RefIfNonZero(); + } + if (cluster_locality_stats == nullptr) { + if (locality_state.locality_stats != nullptr) { + locality_state.deleted_locality_stats += + locality_state.locality_stats->GetSnapshotAndReset(); + } + cluster_locality_stats = MakeRefCounted( + Ref(DEBUG_LOCATION, "LocalityStats"), lrs_server, + it->first.first /*cluster_name*/, it->first.second /*eds_service_name*/, + std::move(locality)); + locality_state.locality_stats = cluster_locality_stats.get(); + } + chand_->MaybeStartLrsCall(); + return cluster_locality_stats; +} + +void XdsClient::RemoveClusterLocalityStats( + absl::string_view /*lrs_server*/, absl::string_view cluster_name, + absl::string_view eds_service_name, + const RefCountedPtr& locality, + XdsClusterLocalityStats* cluster_locality_stats) { + MutexLock lock(&mu_); + // TODO(roth): When we add support for direct federation, use the + // server name specified in lrs_server. + auto it = load_report_map_.find( + std::make_pair(std::string(cluster_name), std::string(eds_service_name))); + if (it == load_report_map_.end()) return; + LoadReportState& load_report_state = it->second; + auto locality_it = load_report_state.locality_stats.find(locality); + if (locality_it == load_report_state.locality_stats.end()) return; + LoadReportState::LocalityState& locality_state = locality_it->second; + if (locality_state.locality_stats == cluster_locality_stats) { + // Record final snapshot in deleted_locality_stats, which will be + // added to the next load report. 
+ locality_state.deleted_locality_stats += + locality_state.locality_stats->GetSnapshotAndReset(); + locality_state.locality_stats = nullptr; + } +} + +void XdsClient::ResetBackoff() { + MutexLock lock(&mu_); + if (chand_ != nullptr) { + grpc_channel_reset_connect_backoff(chand_->channel()); + } +} + +void XdsClient::NotifyOnErrorLocked(grpc_error_handle error) { + for (const auto& p : listener_map_) { + const ListenerState& listener_state = p.second; + for (const auto& p : listener_state.watchers) { + p.first->OnError(GRPC_ERROR_REF(error)); + } + } + for (const auto& p : route_config_map_) { + const RouteConfigState& route_config_state = p.second; + for (const auto& p : route_config_state.watchers) { + p.first->OnError(GRPC_ERROR_REF(error)); + } + } + for (const auto& p : cluster_map_) { + const ClusterState& cluster_state = p.second; + for (const auto& p : cluster_state.watchers) { + p.first->OnError(GRPC_ERROR_REF(error)); + } + } + for (const auto& p : endpoint_map_) { + const EndpointState& endpoint_state = p.second; + for (const auto& p : endpoint_state.watchers) { + p.first->OnError(GRPC_ERROR_REF(error)); + } + } + GRPC_ERROR_UNREF(error); +} + +XdsApi::ClusterLoadReportMap XdsClient::BuildLoadReportSnapshotLocked( + bool send_all_clusters, const std::set& clusters) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] start building load report", this); + } + XdsApi::ClusterLoadReportMap snapshot_map; + for (auto load_report_it = load_report_map_.begin(); + load_report_it != load_report_map_.end();) { + // Cluster key is cluster and EDS service name. + const auto& cluster_key = load_report_it->first; + LoadReportState& load_report = load_report_it->second; + // If the CDS response for a cluster indicates to use LRS but the + // LRS server does not say that it wants reports for this cluster, + // then we'll have stats objects here whose data we're not going to + // include in the load report. However, we still need to clear out + // the data from the stats objects, so that if the LRS server starts + // asking for the data in the future, we don't incorrectly include + // data from previous reporting intervals in that future report. + const bool record_stats = + send_all_clusters || clusters.find(cluster_key.first) != clusters.end(); + XdsApi::ClusterLoadReport snapshot; + // Aggregate drop stats. + snapshot.dropped_requests = std::move(load_report.deleted_drop_stats); + if (load_report.drop_stats != nullptr) { + snapshot.dropped_requests += + load_report.drop_stats->GetSnapshotAndReset(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] cluster=%s eds_service_name=%s drop_stats=%p", + this, cluster_key.first.c_str(), cluster_key.second.c_str(), + load_report.drop_stats); + } + } + // Aggregate locality stats. 
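Load reports are delta based: each interval takes a snapshot-and-reset of every live counter and folds in the final snapshots of stats objects deleted during the interval, so nothing is dropped or double counted. A minimal model of that aggregation (editor's sketch, invented names):

// --- Editor's illustrative sketch (not part of this patch) ---
#include <atomic>
#include <cstdint>

struct SnapshotSketch {
  uint64_t ok = 0;
  SnapshotSketch& operator+=(const SnapshotSketch& other) {
    ok += other.ok;
    return *this;
  }
};

struct LiveStatsSketch {
  std::atomic<uint64_t> ok{0};
  SnapshotSketch GetSnapshotAndReset() {
    return SnapshotSketch{ok.exchange(0, std::memory_order_relaxed)};
  }
};

SnapshotSketch BuildIntervalReport(LiveStatsSketch* live,
                                   SnapshotSketch deleted_final_snapshots) {
  SnapshotSketch report = deleted_final_snapshots;  // stats from deleted objects
  if (live != nullptr) {
    report += live->GetSnapshotAndReset();  // delta since the last report
  }
  return report;
}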
+ for (auto it = load_report.locality_stats.begin(); + it != load_report.locality_stats.end();) { + const RefCountedPtr& locality_name = it->first; + auto& locality_state = it->second; + XdsClusterLocalityStats::Snapshot& locality_snapshot = + snapshot.locality_stats[locality_name]; + locality_snapshot = std::move(locality_state.deleted_locality_stats); + if (locality_state.locality_stats != nullptr) { + locality_snapshot += + locality_state.locality_stats->GetSnapshotAndReset(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] cluster=%s eds_service_name=%s " + "locality=%s locality_stats=%p", + this, cluster_key.first.c_str(), cluster_key.second.c_str(), + locality_name->AsHumanReadableString().c_str(), + locality_state.locality_stats); + } + } + // If the only thing left in this entry was final snapshots from + // deleted locality stats objects, remove the entry. + if (locality_state.locality_stats == nullptr) { + it = load_report.locality_stats.erase(it); + } else { + ++it; + } + } + // Compute load report interval. + const grpc_millis now = ExecCtx::Get()->Now(); + snapshot.load_report_interval = now - load_report.last_report_time; + load_report.last_report_time = now; + // Record snapshot. + if (record_stats) { + snapshot_map[cluster_key] = std::move(snapshot); + } + // If the only thing left in this entry was final snapshots from + // deleted stats objects, remove the entry. + if (load_report.locality_stats.empty() && + load_report.drop_stats == nullptr) { + load_report_it = load_report_map_.erase(load_report_it); + } else { + ++load_report_it; + } + } + return snapshot_map; +} + +std::string XdsClient::DumpClientConfigBinary() { + MutexLock lock(&mu_); + XdsApi::ResourceTypeMetadataMap resource_type_metadata_map; + // Update per-xds-type version if available, this version corresponding to the + // last successful ADS update version. + for (auto& p : resource_version_map_) { + resource_type_metadata_map[p.first].version = p.second; + } + // Collect resource metadata from listeners + auto& lds_map = + resource_type_metadata_map[XdsApi::kLdsTypeUrl].resource_metadata_map; + for (auto& p : listener_map_) { + lds_map[p.first] = &p.second.meta; + } + // Collect resource metadata from route configs + auto& rds_map = + resource_type_metadata_map[XdsApi::kRdsTypeUrl].resource_metadata_map; + for (auto& p : route_config_map_) { + rds_map[p.first] = &p.second.meta; + } + // Collect resource metadata from clusters + auto& cds_map = + resource_type_metadata_map[XdsApi::kCdsTypeUrl].resource_metadata_map; + for (auto& p : cluster_map_) { + cds_map[p.first] = &p.second.meta; + } + // Collect resource metadata from endpoints + auto& eds_map = + resource_type_metadata_map[XdsApi::kEdsTypeUrl].resource_metadata_map; + for (auto& p : endpoint_map_) { + eds_map[p.first] = &p.second.meta; + } + // Assemble config dump messages + return api_.AssembleClientConfig(resource_type_metadata_map); +} + +// +// accessors for global state +// + +void XdsClientGlobalInit() { + g_mu = new Mutex; + XdsHttpFilterRegistry::Init(); +} + +// TODO(roth): Find a better way to clear the fallback config that does +// not require using ABSL_NO_THREAD_SAFETY_ANALYSIS. 
+void XdsClientGlobalShutdown() ABSL_NO_THREAD_SAFETY_ANALYSIS { + gpr_free(g_fallback_bootstrap_config); + g_fallback_bootstrap_config = nullptr; + delete g_mu; + g_mu = nullptr; + XdsHttpFilterRegistry::Shutdown(); +} + +namespace { + +std::string GetBootstrapContents(const char* fallback_config, + grpc_error_handle* error) { + // First, try GRPC_XDS_BOOTSTRAP env var. + grpc_core::UniquePtr path(gpr_getenv("GRPC_XDS_BOOTSTRAP")); + if (path != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "Got bootstrap file location from GRPC_XDS_BOOTSTRAP " + "environment variable: %s", + path.get()); + } + grpc_slice contents; + *error = + grpc_load_file(path.get(), /*add_null_terminator=*/true, &contents); + if (*error != GRPC_ERROR_NONE) return ""; + std::string contents_str(StringViewFromSlice(contents)); + grpc_slice_unref_internal(contents); + return contents_str; + } + // Next, try GRPC_XDS_BOOTSTRAP_CONFIG env var. + grpc_core::UniquePtr env_config( + gpr_getenv("GRPC_XDS_BOOTSTRAP_CONFIG")); + if (env_config != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "Got bootstrap contents from GRPC_XDS_BOOTSTRAP_CONFIG " + "environment variable"); + } + return env_config.get(); + } + // Finally, try fallback config. + if (fallback_config != nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "Got bootstrap contents from fallback config"); + } + return fallback_config; + } + // No bootstrap config found. + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Environment variables GRPC_XDS_BOOTSTRAP or GRPC_XDS_BOOTSTRAP_CONFIG " + "not defined"); + return ""; +} + +} // namespace + +RefCountedPtr XdsClient::GetOrCreate(const grpc_channel_args* args, + grpc_error_handle* error) { + RefCountedPtr xds_client; + // If getting bootstrap from channel args, create a local XdsClient + // instance for the channel or server instead of using the global instance. + const char* bootstrap_config = grpc_channel_args_find_string( + args, GRPC_ARG_TEST_ONLY_DO_NOT_USE_IN_PROD_XDS_BOOTSTRAP_CONFIG); + if (bootstrap_config != nullptr) { + std::unique_ptr bootstrap = + XdsBootstrap::Create(bootstrap_config, error); + if (*error == GRPC_ERROR_NONE) { + grpc_channel_args* xds_channel_args = + grpc_channel_args_find_pointer( + args, + GRPC_ARG_TEST_ONLY_DO_NOT_USE_IN_PROD_XDS_CLIENT_CHANNEL_ARGS); + return MakeRefCounted(std::move(bootstrap), xds_channel_args); + } + return nullptr; + } + // Otherwise, use the global instance. + { + MutexLock lock(g_mu); + if (g_xds_client != nullptr) { + auto xds_client = g_xds_client->RefIfNonZero(); + if (xds_client != nullptr) return xds_client; + } + // Find bootstrap contents. + std::string bootstrap_contents = + GetBootstrapContents(g_fallback_bootstrap_config, error); + if (*error != GRPC_ERROR_NONE) return nullptr; + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "xDS bootstrap contents: %s", + bootstrap_contents.c_str()); + } + // Parse bootstrap. + std::unique_ptr bootstrap = + XdsBootstrap::Create(bootstrap_contents, error); + if (*error != GRPC_ERROR_NONE) return nullptr; + // Instantiate XdsClient. 
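GetBootstrapContents() above resolves the bootstrap with a fixed precedence: a file named by GRPC_XDS_BOOTSTRAP, then inline JSON in GRPC_XDS_BOOTSTRAP_CONFIG, then a process-wide fallback config. A condensed standalone sketch of that order (editor's addition; error handling simplified):

// --- Editor's illustrative sketch (not part of this patch) ---
#include <cstdlib>
#include <fstream>
#include <optional>
#include <sstream>
#include <string>

std::optional<std::string> ResolveBootstrapSketch(const char* fallback) {
  if (const char* path = std::getenv("GRPC_XDS_BOOTSTRAP")) {
    std::ifstream f(path);
    if (!f) return std::nullopt;  // the real code surfaces a grpc_error here
    std::ostringstream contents;
    contents << f.rdbuf();
    return contents.str();
  }
  if (const char* inline_json = std::getenv("GRPC_XDS_BOOTSTRAP_CONFIG")) {
    return std::string(inline_json);
  }
  if (fallback != nullptr) return std::string(fallback);
  return std::nullopt;  // "bootstrap not defined" error in the real code
}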
+    xds_client =
+        MakeRefCounted<XdsClient>(std::move(bootstrap), g_channel_args);
+    g_xds_client = xds_client.get();
+  }
+  return xds_client;
+}
+
+namespace internal {
+
+void SetXdsChannelArgsForTest(grpc_channel_args* args) {
+  MutexLock lock(g_mu);
+  g_channel_args = args;
+}
+
+void UnsetGlobalXdsClientForTest() {
+  MutexLock lock(g_mu);
+  g_xds_client = nullptr;
+}
+
+void SetXdsFallbackBootstrapConfig(const char* config) {
+  MutexLock lock(g_mu);
+  gpr_free(g_fallback_bootstrap_config);
+  g_fallback_bootstrap_config = gpr_strdup(config);
+}
+
+}  // namespace internal
+
+//
+// embedding XdsClient in channel args
+//
+
+#define GRPC_ARG_XDS_CLIENT "grpc.internal.xds_client"
+
+namespace {
+
+void* XdsClientArgCopy(void* p) {
+  XdsClient* xds_client = static_cast<XdsClient*>(p);
+  xds_client->Ref(DEBUG_LOCATION, "channel arg").release();
+  return p;
+}
+
+void XdsClientArgDestroy(void* p) {
+  XdsClient* xds_client = static_cast<XdsClient*>(p);
+  xds_client->Unref(DEBUG_LOCATION, "channel arg");
+}
+
+int XdsClientArgCmp(void* p, void* q) { return QsortCompare(p, q); }
+
+const grpc_arg_pointer_vtable kXdsClientArgVtable = {
+    XdsClientArgCopy, XdsClientArgDestroy, XdsClientArgCmp};
+
+}  // namespace
+
+grpc_arg XdsClient::MakeChannelArg() const {
+  return grpc_channel_arg_pointer_create(const_cast<char*>(GRPC_ARG_XDS_CLIENT),
+                                         const_cast<XdsClient*>(this),
+                                         &kXdsClientArgVtable);
+}
+
+RefCountedPtr<XdsClient> XdsClient::GetFromChannelArgs(
+    const grpc_channel_args& args) {
+  XdsClient* xds_client =
+      grpc_channel_args_find_pointer<XdsClient>(&args, GRPC_ARG_XDS_CLIENT);
+  if (xds_client == nullptr) return nullptr;
+  return xds_client->Ref(DEBUG_LOCATION, "GetFromChannelArgs");
+}
+
+}  // namespace grpc_core
+
+// The returned bytes may contain NULL (0), so we can't use a C string.
+grpc_slice grpc_dump_xds_configs() {
+  grpc_core::ApplicationCallbackExecCtx callback_exec_ctx;
+  grpc_core::ExecCtx exec_ctx;
+  grpc_error_handle error = GRPC_ERROR_NONE;
+  auto xds_client = grpc_core::XdsClient::GetOrCreate(nullptr, &error);
+  if (error != GRPC_ERROR_NONE) {
+    // If we aren't using xDS, just return an empty string.
+    GRPC_ERROR_UNREF(error);
+    return grpc_empty_slice();
+  }
+  return grpc_slice_from_cpp_string(xds_client->DumpClientConfigBinary());
+}
diff --git a/src/core/ext/xds/xds_client_stats.cc b/src/core/ext/xds/xds_client_stats.cc
new file mode 100644
index 00000000..437541dc
--- /dev/null
+++ b/src/core/ext/xds/xds_client_stats.cc
@@ -0,0 +1,160 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +#include "src/core/ext/xds/xds_client_stats.h" + +#include + +#include +#include + +#include "src/core/ext/xds/xds_client.h" + +namespace grpc_core { + +namespace { + +uint64_t GetAndResetCounter(std::atomic* from) { + return from->exchange(0, std::memory_order_relaxed); +} + +} // namespace + +// +// XdsClusterDropStats +// + +XdsClusterDropStats::XdsClusterDropStats(RefCountedPtr xds_client, + absl::string_view lrs_server_name, + absl::string_view cluster_name, + absl::string_view eds_service_name) + : RefCounted(GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_refcount_trace) + ? "XdsClusterDropStats" + : nullptr), + xds_client_(std::move(xds_client)), + lrs_server_name_(lrs_server_name), + cluster_name_(cluster_name), + eds_service_name_(eds_service_name) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, "[xds_client %p] created drop stats %p for {%s, %s, %s}", + xds_client_.get(), this, std::string(lrs_server_name_).c_str(), + std::string(cluster_name_).c_str(), + std::string(eds_service_name_).c_str()); + } +} + +XdsClusterDropStats::~XdsClusterDropStats() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] destroying drop stats %p for {%s, %s, %s}", + xds_client_.get(), this, std::string(lrs_server_name_).c_str(), + std::string(cluster_name_).c_str(), + std::string(eds_service_name_).c_str()); + } + xds_client_->RemoveClusterDropStats(lrs_server_name_, cluster_name_, + eds_service_name_, this); + xds_client_.reset(DEBUG_LOCATION, "DropStats"); +} + +XdsClusterDropStats::Snapshot XdsClusterDropStats::GetSnapshotAndReset() { + Snapshot snapshot; + snapshot.uncategorized_drops = GetAndResetCounter(&uncategorized_drops_); + MutexLock lock(&mu_); + snapshot.categorized_drops = std::move(categorized_drops_); + return snapshot; +} + +void XdsClusterDropStats::AddUncategorizedDrops() { + uncategorized_drops_.fetch_add(1); +} + +void XdsClusterDropStats::AddCallDropped(const std::string& category) { + MutexLock lock(&mu_); + ++categorized_drops_[category]; +} + +// +// XdsClusterLocalityStats +// + +XdsClusterLocalityStats::XdsClusterLocalityStats( + RefCountedPtr xds_client, absl::string_view lrs_server_name, + absl::string_view cluster_name, absl::string_view eds_service_name, + RefCountedPtr name) + : RefCounted(GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_refcount_trace) + ? 
"XdsClusterLocalityStats" + : nullptr), + xds_client_(std::move(xds_client)), + lrs_server_name_(lrs_server_name), + cluster_name_(cluster_name), + eds_service_name_(eds_service_name), + name_(std::move(name)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] created locality stats %p for {%s, %s, %s, %s}", + xds_client_.get(), this, std::string(lrs_server_name_).c_str(), + std::string(cluster_name_).c_str(), + std::string(eds_service_name_).c_str(), + name_->AsHumanReadableString().c_str()); + } +} + +XdsClusterLocalityStats::~XdsClusterLocalityStats() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { + gpr_log(GPR_INFO, + "[xds_client %p] destroying locality stats %p for {%s, %s, %s, %s}", + xds_client_.get(), this, std::string(lrs_server_name_).c_str(), + std::string(cluster_name_).c_str(), + std::string(eds_service_name_).c_str(), + name_->AsHumanReadableString().c_str()); + } + xds_client_->RemoveClusterLocalityStats(lrs_server_name_, cluster_name_, + eds_service_name_, name_, this); + xds_client_.reset(DEBUG_LOCATION, "LocalityStats"); +} + +XdsClusterLocalityStats::Snapshot +XdsClusterLocalityStats::GetSnapshotAndReset() { + Snapshot snapshot = { + GetAndResetCounter(&total_successful_requests_), + // Don't reset total_requests_in_progress because it's + // not related to a single reporting interval. + total_requests_in_progress_.load(std::memory_order_relaxed), + GetAndResetCounter(&total_error_requests_), + GetAndResetCounter(&total_issued_requests_), + {}}; + MutexLock lock(&backend_metrics_mu_); + snapshot.backend_metrics = std::move(backend_metrics_); + return snapshot; +} + +void XdsClusterLocalityStats::AddCallStarted() { + total_issued_requests_.fetch_add(1, std::memory_order_relaxed); + total_requests_in_progress_.fetch_add(1, std::memory_order_relaxed); +} + +void XdsClusterLocalityStats::AddCallFinished(bool fail) { + std::atomic& to_increment = + fail ? total_error_requests_ : total_successful_requests_; + to_increment.fetch_add(1, std::memory_order_relaxed); + total_requests_in_progress_.fetch_add(-1, std::memory_order_acq_rel); +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_http_fault_filter.cc b/src/core/ext/xds/xds_http_fault_filter.cc new file mode 100644 index 00000000..b545b99b --- /dev/null +++ b/src/core/ext/xds/xds_http_fault_filter.cc @@ -0,0 +1,227 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/ext/xds/xds_http_fault_filter.h" + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "envoy/extensions/filters/common/fault/v3/fault.upb.h" +#include "envoy/extensions/filters/http/fault/v3/fault.upb.h" +#include "envoy/extensions/filters/http/fault/v3/fault.upbdefs.h" +#include "envoy/type/v3/percent.upb.h" +#include "google/protobuf/any.upb.h" +#include "google/protobuf/duration.upb.h" +#include "google/protobuf/wrappers.upb.h" +#include "upb/def.h" + +#include + +#include "src/core/ext/filters/fault_injection/fault_injection_filter.h" +#include "src/core/ext/xds/xds_http_filters.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/transport/status_conversion.h" + +namespace grpc_core { + +const char* kXdsHttpFaultFilterConfigName = + "envoy.extensions.filters.http.fault.v3.HTTPFault"; + +namespace { + +uint32_t GetDenominator(const envoy_type_v3_FractionalPercent* fraction) { + if (fraction != nullptr) { + const auto denominator = + static_cast( + envoy_type_v3_FractionalPercent_denominator(fraction)); + switch (denominator) { + case envoy_type_v3_FractionalPercent_MILLION: + return 1000000; + case envoy_type_v3_FractionalPercent_TEN_THOUSAND: + return 10000; + case envoy_type_v3_FractionalPercent_HUNDRED: + default: + return 100; + } + } + // Use 100 as the default denominator + return 100; +} + +absl::StatusOr ParseHttpFaultIntoJson(upb_strview serialized_http_fault, + upb_arena* arena) { + auto* http_fault = envoy_extensions_filters_http_fault_v3_HTTPFault_parse( + serialized_http_fault.data, serialized_http_fault.size, arena); + if (http_fault == nullptr) { + return absl::InvalidArgumentError( + "could not parse fault injection filter config"); + } + // NOTE(lidiz): Here, we are manually translating the upb messages into the + // JSON form of the filter config as part of method config, which will be + // directly used later by service config. In this way, we can validate the + // filter configs, and NACK if needed. It also allows the service config to + // function independently without xDS, but not the other way around. 
+ // NOTE(lidiz): please refer to FaultInjectionPolicy for ground truth + // definitions, located at: + // src/core/ext/filters/fault_injection/service_config_parser.h + Json::Object fault_injection_policy_json; + // Section 1: Parse the abort injection config + const auto* fault_abort = + envoy_extensions_filters_http_fault_v3_HTTPFault_abort(http_fault); + if (fault_abort != nullptr) { + grpc_status_code abort_grpc_status_code = GRPC_STATUS_OK; + // Try if gRPC status code is set first + int abort_grpc_status_code_raw = + envoy_extensions_filters_http_fault_v3_FaultAbort_grpc_status( + fault_abort); + if (abort_grpc_status_code_raw != 0) { + if (!grpc_status_code_from_int(abort_grpc_status_code_raw, + &abort_grpc_status_code)) { + return absl::InvalidArgumentError(absl::StrCat( + "invalid gRPC status code: ", abort_grpc_status_code_raw)); + } + } else { + // if gRPC status code is empty, check http status + int abort_http_status_code = + envoy_extensions_filters_http_fault_v3_FaultAbort_http_status( + fault_abort); + if (abort_http_status_code != 0 and abort_http_status_code != 200) { + abort_grpc_status_code = + grpc_http2_status_to_grpc_status(abort_http_status_code); + } + } + // Set the abort_code, even if it's OK + fault_injection_policy_json["abortCode"] = + grpc_status_code_to_string(abort_grpc_status_code); + // Set the headers if we enabled header abort injection control + if (envoy_extensions_filters_http_fault_v3_FaultAbort_has_header_abort( + fault_abort)) { + fault_injection_policy_json["abortCodeHeader"] = + "x-envoy-fault-abort-grpc-request"; + fault_injection_policy_json["abortPercentageHeader"] = + "x-envoy-fault-abort-percentage"; + } + // Set the fraction percent + auto* percent = + envoy_extensions_filters_http_fault_v3_FaultAbort_percentage( + fault_abort); + fault_injection_policy_json["abortPercentageNumerator"] = + Json(envoy_type_v3_FractionalPercent_numerator(percent)); + fault_injection_policy_json["abortPercentageDenominator"] = + Json(GetDenominator(percent)); + } + // Section 2: Parse the delay injection config + const auto* fault_delay = + envoy_extensions_filters_http_fault_v3_HTTPFault_delay(http_fault); + if (fault_delay != nullptr) { + // Parse the delay duration + const auto* delay_duration = + envoy_extensions_filters_common_fault_v3_FaultDelay_fixed_delay( + fault_delay); + if (delay_duration != nullptr) { + fault_injection_policy_json["delay"] = absl::StrFormat( + "%d.%09ds", google_protobuf_Duration_seconds(delay_duration), + google_protobuf_Duration_nanos(delay_duration)); + } + // Set the headers if we enabled header delay injection control + if (envoy_extensions_filters_common_fault_v3_FaultDelay_has_header_delay( + fault_delay)) { + fault_injection_policy_json["delayHeader"] = + "x-envoy-fault-delay-request"; + fault_injection_policy_json["delayPercentageHeader"] = + "x-envoy-fault-delay-request-percentage"; + } + // Set the fraction percent + auto* percent = + envoy_extensions_filters_common_fault_v3_FaultDelay_percentage( + fault_delay); + fault_injection_policy_json["delayPercentageNumerator"] = + Json(envoy_type_v3_FractionalPercent_numerator(percent)); + fault_injection_policy_json["delayPercentageDenominator"] = + Json(GetDenominator(percent)); + } + // Section 3: Parse the maximum active faults + const auto* max_fault_wrapper = + envoy_extensions_filters_http_fault_v3_HTTPFault_max_active_faults( + http_fault); + if (max_fault_wrapper != nullptr) { + fault_injection_policy_json["maxFaults"] = + 
google_protobuf_UInt32Value_value(max_fault_wrapper); + } + return fault_injection_policy_json; +} + +} // namespace + +void XdsHttpFaultFilter::PopulateSymtab(upb_symtab* symtab) const { + envoy_extensions_filters_http_fault_v3_HTTPFault_getmsgdef(symtab); +} + +absl::StatusOr +XdsHttpFaultFilter::GenerateFilterConfig(upb_strview serialized_filter_config, + upb_arena* arena) const { + absl::StatusOr parse_result = + ParseHttpFaultIntoJson(serialized_filter_config, arena); + if (!parse_result.ok()) { + return parse_result.status(); + } + return FilterConfig{kXdsHttpFaultFilterConfigName, std::move(*parse_result)}; +} + +absl::StatusOr +XdsHttpFaultFilter::GenerateFilterConfigOverride( + upb_strview serialized_filter_config, upb_arena* arena) const { + // HTTPFault filter has the same message type in HTTP connection manager's + // filter config and in overriding filter config field. + return GenerateFilterConfig(serialized_filter_config, arena); +} + +const grpc_channel_filter* XdsHttpFaultFilter::channel_filter() const { + return &FaultInjectionFilterVtable; +} + +grpc_channel_args* XdsHttpFaultFilter::ModifyChannelArgs( + grpc_channel_args* args) const { + grpc_arg args_to_add = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_PARSE_FAULT_INJECTION_METHOD_CONFIG), 1); + grpc_channel_args* new_args = + grpc_channel_args_copy_and_add(args, &args_to_add, 1); + // Since this function takes the ownership of the channel args, it needs to + // deallocate the old ones to prevent leak. + grpc_channel_args_destroy(args); + return new_args; +} + +absl::StatusOr +XdsHttpFaultFilter::GenerateServiceConfig( + const FilterConfig& hcm_filter_config, + const FilterConfig* filter_config_override) const { + Json policy_json = filter_config_override != nullptr + ? filter_config_override->config + : hcm_filter_config.config; + // The policy JSON may be empty, that's allowed. + return ServiceConfigJsonEntry{"faultInjectionPolicy", policy_json.Dump()}; +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_http_filters.cc b/src/core/ext/xds/xds_http_filters.cc new file mode 100644 index 00000000..4f79336c --- /dev/null +++ b/src/core/ext/xds/xds_http_filters.cc @@ -0,0 +1,116 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
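// [Illustrative sketch, not part of this patch] GenerateServiceConfig() above
// prefers the per-route/per-virtual-host override and falls back to the HTTP
// connection manager's filter config when no override is present. The
// selection is just a null check; shown here with a simplified, hypothetical
// stand-in for FilterConfig:
#include <string>

struct FilterConfigExample {
  std::string config_proto_type_name;
  std::string config_json;  // serialized config used for the service config
};

const FilterConfigExample& EffectiveFilterConfigExample(
    const FilterConfigExample& hcm_filter_config,
    const FilterConfigExample* filter_config_override) {
  return filter_config_override != nullptr ? *filter_config_override
                                           : hcm_filter_config;
}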
+// + +#include + +#include "src/core/ext/xds/xds_http_filters.h" + +#include "envoy/extensions/filters/http/router/v3/router.upb.h" +#include "envoy/extensions/filters/http/router/v3/router.upbdefs.h" + +#include "src/core/ext/xds/xds_http_fault_filter.h" + +namespace grpc_core { + +const char* kXdsHttpRouterFilterConfigName = + "envoy.extensions.filters.http.router.v3.Router"; + +namespace { + +class XdsHttpRouterFilter : public XdsHttpFilterImpl { + public: + void PopulateSymtab(upb_symtab* symtab) const override { + envoy_extensions_filters_http_router_v3_Router_getmsgdef(symtab); + } + + absl::StatusOr GenerateFilterConfig( + upb_strview serialized_filter_config, upb_arena* arena) const override { + if (envoy_extensions_filters_http_router_v3_Router_parse( + serialized_filter_config.data, serialized_filter_config.size, + arena) == nullptr) { + return absl::InvalidArgumentError("could not parse router filter config"); + } + return FilterConfig{kXdsHttpRouterFilterConfigName, Json()}; + } + + absl::StatusOr GenerateFilterConfigOverride( + upb_strview /*serialized_filter_config*/, + upb_arena* /*arena*/) const override { + return absl::InvalidArgumentError( + "router filter does not support config override"); + } + + const grpc_channel_filter* channel_filter() const override { return nullptr; } + + // No-op. This will never be called, since channel_filter() returns null. + absl::StatusOr GenerateServiceConfig( + const FilterConfig& /*hcm_filter_config*/, + const FilterConfig* /*filter_config_override*/) const override { + return absl::UnimplementedError("router filter should never be called"); + } + + bool IsSupportedOnClients() const override { return true; } + + bool IsSupportedOnServers() const override { return true; } + + bool IsTerminalFilter() const override { return true; } +}; + +using FilterOwnerList = std::vector>; +using FilterRegistryMap = std::map; + +FilterOwnerList* g_filters = nullptr; +FilterRegistryMap* g_filter_registry = nullptr; + +} // namespace + +void XdsHttpFilterRegistry::RegisterFilter( + std::unique_ptr filter, + const std::set& config_proto_type_names) { + for (auto config_proto_type_name : config_proto_type_names) { + (*g_filter_registry)[config_proto_type_name] = filter.get(); + } + g_filters->push_back(std::move(filter)); +} + +const XdsHttpFilterImpl* XdsHttpFilterRegistry::GetFilterForType( + absl::string_view proto_type_name) { + auto it = g_filter_registry->find(proto_type_name); + if (it == g_filter_registry->end()) return nullptr; + return it->second; +} + +void XdsHttpFilterRegistry::PopulateSymtab(upb_symtab* symtab) { + for (const auto& filter : *g_filters) { + filter->PopulateSymtab(symtab); + } +} + +void XdsHttpFilterRegistry::Init() { + g_filters = new FilterOwnerList; + g_filter_registry = new FilterRegistryMap; + RegisterFilter(absl::make_unique(), + {kXdsHttpRouterFilterConfigName}); + RegisterFilter(absl::make_unique(), + {kXdsHttpFaultFilterConfigName}); +} + +void XdsHttpFilterRegistry::Shutdown() { + delete g_filter_registry; + delete g_filters; +} + +} // namespace grpc_core diff --git a/src/core/ext/xds/xds_server_config_fetcher.cc b/src/core/ext/xds/xds_server_config_fetcher.cc new file mode 100644 index 00000000..aed2f2dc --- /dev/null +++ b/src/core/ext/xds/xds_server_config_fetcher.cc @@ -0,0 +1,544 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
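// [Illustrative sketch, not part of this patch] XdsHttpFilterRegistry above
// keeps two structures: a vector that owns every filter implementation and a
// map from config proto type name to a borrowed pointer, so one
// implementation can be looked up under several type names. A self-contained
// version of that ownership split (all names invented for the example):
#include <map>
#include <memory>
#include <string>
#include <vector>

class FilterImplExample {
 public:
  virtual ~FilterImplExample() = default;
};

class FilterRegistryExample {
 public:
  void Register(std::unique_ptr<FilterImplExample> filter,
                const std::vector<std::string>& type_names) {
    for (const auto& name : type_names) by_type_name_[name] = filter.get();
    owned_.push_back(std::move(filter));  // keeps the implementation alive
  }

  const FilterImplExample* Lookup(const std::string& type_name) const {
    auto it = by_type_name_.find(type_name);
    return it == by_type_name_.end() ? nullptr : it->second;
  }

 private:
  std::vector<std::unique_ptr<FilterImplExample>> owned_;
  std::map<std::string, const FilterImplExample*> by_type_name_;
};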
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "absl/strings/str_replace.h" + +#include "src/core/ext/xds/xds_certificate_provider.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/security/credentials/xds/xds_credentials.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/uri/uri_parser.h" + +namespace grpc_core { + +TraceFlag grpc_xds_server_config_fetcher_trace(false, + "xds_server_config_fetcher"); + +namespace { + +class FilterChainMatchManager + : public grpc_server_config_fetcher::ConnectionManager { + public: + FilterChainMatchManager( + RefCountedPtr xds_client, + XdsApi::LdsUpdate::FilterChainMap filter_chain_map, + absl::optional default_filter_chain) + : xds_client_(xds_client), + filter_chain_map_(std::move(filter_chain_map)), + default_filter_chain_(std::move(default_filter_chain)) {} + + absl::StatusOr UpdateChannelArgsForConnection( + grpc_channel_args* args, grpc_endpoint* tcp) override; + + const XdsApi::LdsUpdate::FilterChainMap& filter_chain_map() const { + return filter_chain_map_; + } + + const absl::optional& + default_filter_chain() const { + return default_filter_chain_; + } + + private: + struct CertificateProviders { + // We need to save our own refs to the root and instance certificate + // providers since the xds certificate provider just stores a ref to their + // distributors. 
+ RefCountedPtr root; + RefCountedPtr instance; + RefCountedPtr xds; + }; + + absl::StatusOr> + CreateOrGetXdsCertificateProviderFromFilterChainData( + const XdsApi::LdsUpdate::FilterChainData* filter_chain); + + const RefCountedPtr xds_client_; + const XdsApi::LdsUpdate::FilterChainMap filter_chain_map_; + const absl::optional + default_filter_chain_; + Mutex mu_; + std::map + certificate_providers_map_ ABSL_GUARDED_BY(mu_); +}; + +bool IsLoopbackIp(const grpc_resolved_address* address) { + const grpc_sockaddr* sock_addr = + reinterpret_cast(&address->addr); + if (sock_addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast(sock_addr); + if (addr4->sin_addr.s_addr == grpc_htonl(INADDR_LOOPBACK)) { + return true; + } + } else if (sock_addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast(sock_addr); + if (memcmp(&addr6->sin6_addr, &in6addr_loopback, + sizeof(in6addr_loopback)) == 0) { + return true; + } + } + return false; +} + +const XdsApi::LdsUpdate::FilterChainData* FindFilterChainDataForSourcePort( + const XdsApi::LdsUpdate::FilterChainMap::SourcePortsMap& source_ports_map, + absl::string_view port_str) { + int port = 0; + if (!absl::SimpleAtoi(port_str, &port)) return nullptr; + auto it = source_ports_map.find(port); + if (it != source_ports_map.end()) { + return it->second.data.get(); + } + // Search for the catch-all port 0 since we didn't get a direct match + it = source_ports_map.find(0); + if (it != source_ports_map.end()) { + return it->second.data.get(); + } + return nullptr; +} + +const XdsApi::LdsUpdate::FilterChainData* FindFilterChainDataForSourceIp( + const XdsApi::LdsUpdate::FilterChainMap::SourceIpVector& source_ip_vector, + const grpc_resolved_address* source_ip, absl::string_view port) { + const XdsApi::LdsUpdate::FilterChainMap::SourceIp* best_match = nullptr; + for (const auto& entry : source_ip_vector) { + // Special case for catch-all + if (!entry.prefix_range.has_value()) { + if (best_match == nullptr) { + best_match = &entry; + } + continue; + } + if (best_match != nullptr && best_match->prefix_range.has_value() && + best_match->prefix_range->prefix_len >= + entry.prefix_range->prefix_len) { + continue; + } + if (grpc_sockaddr_match_subnet(source_ip, &entry.prefix_range->address, + entry.prefix_range->prefix_len)) { + best_match = &entry; + } + } + if (best_match == nullptr) return nullptr; + return FindFilterChainDataForSourcePort(best_match->ports_map, port); +} + +const XdsApi::LdsUpdate::FilterChainData* FindFilterChainDataForSourceType( + const XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceTypesArray& + source_types_array, + grpc_endpoint* tcp, absl::string_view destination_ip) { + auto source_uri = URI::Parse(grpc_endpoint_get_peer(tcp)); + if (!source_uri.ok() || + (source_uri->scheme() != "ipv4" && source_uri->scheme() != "ipv6")) { + return nullptr; + } + std::string host; + std::string port; + if (!SplitHostPort(source_uri->path(), &host, &port)) { + return nullptr; + } + grpc_resolved_address source_addr; + grpc_error_handle error = grpc_string_to_sockaddr( + &source_addr, host.c_str(), 0 /* port doesn't matter here */); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_DEBUG, "Could not parse string to socket address: %s", + host.c_str()); + GRPC_ERROR_UNREF(error); + return nullptr; + } + // Use kAny only if kSameIporLoopback and kExternal are empty + if (source_types_array[static_cast( + XdsApi::LdsUpdate::FilterChainMap:: + ConnectionSourceType::kSameIpOrLoopback)] + .empty() 
&& + source_types_array[static_cast(XdsApi::LdsUpdate::FilterChainMap:: + ConnectionSourceType::kExternal)] + .empty()) { + return FindFilterChainDataForSourceIp( + source_types_array[static_cast( + XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceType::kAny)], + &source_addr, port); + } + if (IsLoopbackIp(&source_addr) || host == destination_ip) { + return FindFilterChainDataForSourceIp( + source_types_array[static_cast( + XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceType:: + kSameIpOrLoopback)], + &source_addr, port); + } else { + return FindFilterChainDataForSourceIp( + source_types_array[static_cast( + XdsApi::LdsUpdate::FilterChainMap::ConnectionSourceType:: + kExternal)], + &source_addr, port); + } +} + +const XdsApi::LdsUpdate::FilterChainData* FindFilterChainDataForDestinationIp( + const XdsApi::LdsUpdate::FilterChainMap::DestinationIpVector + destination_ip_vector, + grpc_endpoint* tcp) { + auto destination_uri = URI::Parse(grpc_endpoint_get_local_address(tcp)); + if (!destination_uri.ok() || (destination_uri->scheme() != "ipv4" && + destination_uri->scheme() != "ipv6")) { + return nullptr; + } + std::string host; + std::string port; + if (!SplitHostPort(destination_uri->path(), &host, &port)) { + return nullptr; + } + grpc_resolved_address destination_addr; + grpc_error_handle error = grpc_string_to_sockaddr( + &destination_addr, host.c_str(), 0 /* port doesn't matter here */); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_DEBUG, "Could not parse string to socket address: %s", + host.c_str()); + GRPC_ERROR_UNREF(error); + return nullptr; + } + const XdsApi::LdsUpdate::FilterChainMap::DestinationIp* best_match = nullptr; + for (const auto& entry : destination_ip_vector) { + // Special case for catch-all + if (!entry.prefix_range.has_value()) { + if (best_match == nullptr) { + best_match = &entry; + } + continue; + } + if (best_match != nullptr && best_match->prefix_range.has_value() && + best_match->prefix_range->prefix_len >= + entry.prefix_range->prefix_len) { + continue; + } + if (grpc_sockaddr_match_subnet(&destination_addr, + &entry.prefix_range->address, + entry.prefix_range->prefix_len)) { + best_match = &entry; + } + } + if (best_match == nullptr) return nullptr; + return FindFilterChainDataForSourceType(best_match->source_types_array, tcp, + host); +} + +absl::StatusOr> +FilterChainMatchManager::CreateOrGetXdsCertificateProviderFromFilterChainData( + const XdsApi::LdsUpdate::FilterChainData* filter_chain) { + MutexLock lock(&mu_); + auto it = certificate_providers_map_.find(filter_chain); + if (it != certificate_providers_map_.end()) { + return it->second.xds; + } + CertificateProviders certificate_providers; + // Configure root cert. + absl::string_view root_provider_instance_name = + filter_chain->downstream_tls_context.common_tls_context + .certificate_validation_context.ca_certificate_provider_instance + .instance_name; + absl::string_view root_provider_cert_name = + filter_chain->downstream_tls_context.common_tls_context + .certificate_validation_context.ca_certificate_provider_instance + .certificate_name; + if (!root_provider_instance_name.empty()) { + certificate_providers.root = + xds_client_->certificate_provider_store() + .CreateOrGetCertificateProvider(root_provider_instance_name); + if (certificate_providers.root == nullptr) { + return absl::NotFoundError( + absl::StrCat("Certificate provider instance name: \"", + root_provider_instance_name, "\" not recognized.")); + } + } + // Configure identity cert. 
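// [Illustrative sketch, not part of this patch] The Find*() helpers above all
// follow the same pattern: prefer the most specific entry (longest matching
// prefix, or an exact source port) and fall back to a catch-all entry when
// nothing specific matches. A condensed, self-contained version of the
// "most specific prefix wins" scan over plain IPv4 addresses, with invented
// types:
#include <cstdint>
#include <vector>

struct PrefixEntryExample {
  uint32_t address = 0;  // prefix bits
  int prefix_len = -1;   // -1 means "catch-all": matches any address
};

const PrefixEntryExample* BestPrefixMatchExample(
    const std::vector<PrefixEntryExample>& entries, uint32_t addr) {
  const PrefixEntryExample* best = nullptr;
  for (const auto& entry : entries) {
    if (entry.prefix_len < 0) {  // catch-all entry, only used as a last resort
      if (best == nullptr) best = &entry;
      continue;
    }
    if (best != nullptr && best->prefix_len >= entry.prefix_len) continue;
    const uint32_t mask =
        entry.prefix_len == 0 ? 0u : 0xFFFFFFFFu << (32 - entry.prefix_len);
    if ((addr & mask) == (entry.address & mask)) best = &entry;
  }
  return best;  // nullptr: no filter chain matches this connection
}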
+ absl::string_view identity_provider_instance_name = + filter_chain->downstream_tls_context.common_tls_context + .tls_certificate_provider_instance.instance_name; + absl::string_view identity_provider_cert_name = + filter_chain->downstream_tls_context.common_tls_context + .tls_certificate_provider_instance.certificate_name; + if (!identity_provider_instance_name.empty()) { + certificate_providers.instance = + xds_client_->certificate_provider_store() + .CreateOrGetCertificateProvider(identity_provider_instance_name); + if (certificate_providers.instance == nullptr) { + return absl::NotFoundError( + absl::StrCat("Certificate provider instance name: \"", + identity_provider_instance_name, "\" not recognized.")); + } + } + certificate_providers.xds = MakeRefCounted(); + certificate_providers.xds->UpdateRootCertNameAndDistributor( + "", root_provider_cert_name, + certificate_providers.root == nullptr + ? nullptr + : certificate_providers.root->distributor()); + certificate_providers.xds->UpdateIdentityCertNameAndDistributor( + "", identity_provider_cert_name, + certificate_providers.instance == nullptr + ? nullptr + : certificate_providers.instance->distributor()); + certificate_providers.xds->UpdateRequireClientCertificate( + "", filter_chain->downstream_tls_context.require_client_certificate); + auto xds_certificate_provider = certificate_providers.xds; + certificate_providers_map_.emplace(filter_chain, + std::move(certificate_providers)); + return xds_certificate_provider; +} + +absl::StatusOr +FilterChainMatchManager::UpdateChannelArgsForConnection(grpc_channel_args* args, + grpc_endpoint* tcp) { + const auto* filter_chain = FindFilterChainDataForDestinationIp( + filter_chain_map_.destination_ip_vector, tcp); + if (filter_chain == nullptr && default_filter_chain_.has_value()) { + filter_chain = &default_filter_chain_.value(); + } + if (filter_chain == nullptr) { + grpc_channel_args_destroy(args); + return absl::UnavailableError("No matching filter chain found"); + } + // Nothing to update if credentials are not xDS. 
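// [Illustrative sketch, not part of this patch]
// CreateOrGetXdsCertificateProviderFromFilterChainData() above is a
// mutex-guarded "find in cache, otherwise build and remember" lookup keyed by
// the filter chain data. The same shape with generic standard-library types
// (everything below is invented for the example):
#include <map>
#include <memory>
#include <mutex>

struct ProviderExample {};

class ProviderCacheExample {
 public:
  std::shared_ptr<ProviderExample> GetOrCreate(const void* key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;  // reuse the cached provider
    auto provider = std::make_shared<ProviderExample>();
    cache_.emplace(key, provider);              // remember it for next time
    return provider;
  }

 private:
  std::mutex mu_;
  std::map<const void*, std::shared_ptr<ProviderExample>> cache_;
};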
+ grpc_server_credentials* server_creds = + grpc_find_server_credentials_in_args(args); + if (server_creds == nullptr || server_creds->type() != kCredentialsTypeXds) { + return args; + } + absl::StatusOr> result = + CreateOrGetXdsCertificateProviderFromFilterChainData(filter_chain); + if (!result.ok()) { + grpc_channel_args_destroy(args); + return result.status(); + } + RefCountedPtr xds_certificate_provider = + std::move(*result); + GPR_ASSERT(xds_certificate_provider != nullptr); + grpc_arg arg_to_add = xds_certificate_provider->MakeChannelArg(); + grpc_channel_args* updated_args = + grpc_channel_args_copy_and_add(args, &arg_to_add, 1); + grpc_channel_args_destroy(args); + return updated_args; +} + +class XdsServerConfigFetcher : public grpc_server_config_fetcher { + public: + explicit XdsServerConfigFetcher(RefCountedPtr xds_client, + grpc_server_xds_status_notifier notifier) + : xds_client_(std::move(xds_client)), serving_status_notifier_(notifier) { + GPR_ASSERT(xds_client_ != nullptr); + } + + void StartWatch(std::string listening_address, + std::unique_ptr + watcher) override { + grpc_server_config_fetcher::WatcherInterface* watcher_ptr = watcher.get(); + auto listener_watcher = absl::make_unique( + std::move(watcher), xds_client_, serving_status_notifier_, + listening_address); + auto* listener_watcher_ptr = listener_watcher.get(); + listening_address = absl::StrReplaceAll( + xds_client_->bootstrap().server_listener_resource_name_template(), + {{"%s", listening_address}}); + xds_client_->WatchListenerData(listening_address, + std::move(listener_watcher)); + MutexLock lock(&mu_); + auto& watcher_state = watchers_[watcher_ptr]; + watcher_state.listening_address = listening_address; + watcher_state.listener_watcher = listener_watcher_ptr; + } + + void CancelWatch( + grpc_server_config_fetcher::WatcherInterface* watcher) override { + MutexLock lock(&mu_); + auto it = watchers_.find(watcher); + if (it != watchers_.end()) { + // Cancel the watch on the listener before erasing + xds_client_->CancelListenerDataWatch(it->second.listening_address, + it->second.listener_watcher, + false /* delay_unsubscription */); + watchers_.erase(it); + } + } + + // Return the interested parties from the xds client so that it can be polled. + grpc_pollset_set* interested_parties() override { + return xds_client_->interested_parties(); + } + + private: + class ListenerWatcher : public XdsClient::ListenerWatcherInterface { + public: + explicit ListenerWatcher( + std::unique_ptr + server_config_watcher, + RefCountedPtr xds_client, + grpc_server_xds_status_notifier serving_status_notifier, + std::string listening_address) + : server_config_watcher_(std::move(server_config_watcher)), + xds_client_(std::move(xds_client)), + serving_status_notifier_(serving_status_notifier), + listening_address_(std::move(listening_address)) {} + + // Deleted due to special handling required for args_. Copy the channel args + // if we ever need these. 
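// [Illustrative sketch, not part of this patch] StartWatch() above expands
// the bootstrap's server_listener_resource_name_template by substituting its
// "%s" placeholder with the listening address before subscribing to the LDS
// resource. The substitution itself is a single absl::StrReplaceAll call:
#include <string>

#include "absl/strings/str_replace.h"

std::string ExpandListenerTemplateExample(
    const std::string& resource_name_template,
    const std::string& listening_address) {
  // e.g. "grpc/server?xds.resource.listening_address=%s" + "0.0.0.0:8080";
  // the template value here is only an example, it comes from the bootstrap.
  return absl::StrReplaceAll(resource_name_template,
                             {{"%s", listening_address}});
}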
+ ListenerWatcher(const ListenerWatcher&) = delete; + ListenerWatcher& operator=(const ListenerWatcher&) = delete; + + void OnListenerChanged(XdsApi::LdsUpdate listener) override { + if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_server_config_fetcher_trace)) { + gpr_log( + GPR_INFO, + "[ListenerWatcher %p] Received LDS update from xds client %p: %s", + this, xds_client_.get(), listener.ToString().c_str()); + } + if (listener.address != listening_address_) { + OnFatalError(absl::FailedPreconditionError( + "Address in LDS update does not match listening address")); + return; + } + if (filter_chain_match_manager_ == nullptr) { + if (serving_status_notifier_.on_serving_status_update != nullptr) { + serving_status_notifier_.on_serving_status_update( + serving_status_notifier_.user_data, listening_address_.c_str(), + {GRPC_STATUS_OK, ""}); + } else { + gpr_log(GPR_INFO, + "xDS Listener resource obtained; will start serving on %s", + listening_address_.c_str()); + } + } + if (filter_chain_match_manager_ == nullptr || + !(listener.filter_chain_map == + filter_chain_match_manager_->filter_chain_map() && + listener.default_filter_chain == + filter_chain_match_manager_->default_filter_chain())) { + filter_chain_match_manager_ = MakeRefCounted( + xds_client_, std::move(listener.filter_chain_map), + std::move(listener.default_filter_chain)); + server_config_watcher_->UpdateConnectionManager( + filter_chain_match_manager_); + } + } + + void OnError(grpc_error_handle error) override { + if (filter_chain_match_manager_ != nullptr) { + gpr_log(GPR_ERROR, + "ListenerWatcher:%p XdsClient reports error: %s for %s; " + "ignoring in favor of existing resource", + this, grpc_error_std_string(error).c_str(), + listening_address_.c_str()); + } else { + if (serving_status_notifier_.on_serving_status_update != nullptr) { + serving_status_notifier_.on_serving_status_update( + serving_status_notifier_.user_data, listening_address_.c_str(), + {GRPC_STATUS_UNAVAILABLE, grpc_error_std_string(error).c_str()}); + } else { + gpr_log( + GPR_ERROR, + "ListenerWatcher:%p error obtaining xDS Listener resource: %s; " + "not serving on %s", + this, grpc_error_std_string(error).c_str(), + listening_address_.c_str()); + } + } + GRPC_ERROR_UNREF(error); + } + + void OnFatalError(absl::Status status) { + gpr_log( + GPR_ERROR, + "ListenerWatcher:%p Encountered fatal error %s; not serving on %s", + this, status.ToString().c_str(), listening_address_.c_str()); + if (filter_chain_match_manager_ != nullptr) { + // The server has started listening already, so we need to gracefully + // stop serving. 
+ server_config_watcher_->StopServing(); + filter_chain_match_manager_.reset(); + } + if (serving_status_notifier_.on_serving_status_update != nullptr) { + serving_status_notifier_.on_serving_status_update( + serving_status_notifier_.user_data, listening_address_.c_str(), + {static_cast(status.raw_code()), + std::string(status.message()).c_str()}); + } + } + + void OnResourceDoesNotExist() override { + OnFatalError(absl::NotFoundError("Requested listener does not exist")); + } + + private: + std::unique_ptr + server_config_watcher_; + RefCountedPtr xds_client_; + grpc_server_xds_status_notifier serving_status_notifier_; + std::string listening_address_; + RefCountedPtr filter_chain_match_manager_; + }; + + struct WatcherState { + std::string listening_address; + ListenerWatcher* listener_watcher = nullptr; + }; + + RefCountedPtr xds_client_; + grpc_server_xds_status_notifier serving_status_notifier_; + Mutex mu_; + std::map + watchers_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace +} // namespace grpc_core + +grpc_server_config_fetcher* grpc_server_config_fetcher_xds_create( + grpc_server_xds_status_notifier notifier, const grpc_channel_args* args) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + args = grpc_channel_args_remove_grpc_internal(args); + GRPC_API_TRACE("grpc_server_config_fetcher_xds_create()", 0, ()); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::RefCountedPtr xds_client = + grpc_core::XdsClient::GetOrCreate(args, &error); + grpc_channel_args_destroy(args); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Failed to create xds client: %s", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + return nullptr; + } + if (xds_client->bootstrap() + .server_listener_resource_name_template() + .empty()) { + gpr_log(GPR_ERROR, + "server_listener_resource_name_template not provided in bootstrap " + "file."); + return nullptr; + } + return new grpc_core::XdsServerConfigFetcher(std::move(xds_client), notifier); +} diff --git a/src/core/lib/address_utils/parse_address.cc b/src/core/lib/address_utils/parse_address.cc new file mode 100644 index 00000000..c254d2eb --- /dev/null +++ b/src/core/lib/address_utils/parse_address.cc @@ -0,0 +1,320 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/address_utils/parse_address.h" + +#include +#include +#ifdef GRPC_HAVE_UNIX_SOCKET +#include +#endif +#ifdef GRPC_POSIX_SOCKET +#include +#include +#endif + +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/grpc_if_nametoindex.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" + +#ifdef GRPC_HAVE_UNIX_SOCKET + +bool grpc_parse_unix(const grpc_core::URI& uri, + grpc_resolved_address* resolved_addr) { + if (uri.scheme() != "unix") { + gpr_log(GPR_ERROR, "Expected 'unix' scheme, got '%s'", + uri.scheme().c_str()); + return false; + } + grpc_error_handle error = + grpc_core::UnixSockaddrPopulate(uri.path(), resolved_addr); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "%s", grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + return false; + } + return true; +} + +bool grpc_parse_unix_abstract(const grpc_core::URI& uri, + grpc_resolved_address* resolved_addr) { + if (uri.scheme() != "unix-abstract") { + gpr_log(GPR_ERROR, "Expected 'unix-abstract' scheme, got '%s'", + uri.scheme().c_str()); + return false; + } + grpc_error_handle error = + grpc_core::UnixAbstractSockaddrPopulate(uri.path(), resolved_addr); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "%s", grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + return false; + } + return true; +} + +namespace grpc_core { + +grpc_error_handle UnixSockaddrPopulate(absl::string_view path, + grpc_resolved_address* resolved_addr) { + memset(resolved_addr, 0, sizeof(*resolved_addr)); + struct sockaddr_un* un = + reinterpret_cast(resolved_addr->addr); + const size_t maxlen = sizeof(un->sun_path) - 1; + if (path.size() > maxlen) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Path name should not have more than ", maxlen, " characters")); + } + un->sun_family = AF_UNIX; + path.copy(un->sun_path, path.size()); + un->sun_path[path.size()] = '\0'; + resolved_addr->len = static_cast(sizeof(*un)); + return GRPC_ERROR_NONE; +} + +grpc_error_handle UnixAbstractSockaddrPopulate( + absl::string_view path, grpc_resolved_address* resolved_addr) { + memset(resolved_addr, 0, sizeof(*resolved_addr)); + struct sockaddr_un* un = + reinterpret_cast(resolved_addr->addr); + const size_t maxlen = sizeof(un->sun_path) - 1; + if (path.size() > maxlen) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Path name should not have more than ", maxlen, " characters")); + } + un->sun_family = AF_UNIX; + un->sun_path[0] = '\0'; + path.copy(un->sun_path + 1, path.size()); + resolved_addr->len = + static_cast(sizeof(un->sun_family) + path.size() + 1); + return GRPC_ERROR_NONE; +} + +} // namespace grpc_core + +#else /* GRPC_HAVE_UNIX_SOCKET */ + +bool grpc_parse_unix(const grpc_core::URI& /* uri */, + grpc_resolved_address* /* resolved_addr */) { + abort(); +} + +bool grpc_parse_unix_abstract(const grpc_core::URI& /* uri */, + grpc_resolved_address* /* resolved_addr */) { + abort(); +} + +namespace grpc_core { + +grpc_error_handle UnixSockaddrPopulate( + absl::string_view /* path */, grpc_resolved_address* /* resolved_addr */) { + abort(); +} + +grpc_error_handle UnixAbstractSockaddrPopulate( + absl::string_view /* path */, grpc_resolved_address* /* resolved_addr */) { + abort(); +} + +} // namespace grpc_core +#endif /* GRPC_HAVE_UNIX_SOCKET */ + +bool 
grpc_parse_ipv4_hostport(absl::string_view hostport, + grpc_resolved_address* addr, bool log_errors) { + bool success = false; + // Split host and port. + std::string host; + std::string port; + if (!grpc_core::SplitHostPort(hostport, &host, &port)) { + if (log_errors) { + gpr_log(GPR_ERROR, "Failed gpr_split_host_port(%s, ...)", + std::string(hostport).c_str()); + } + return false; + } + // Parse IP address. + memset(addr, 0, sizeof(*addr)); + addr->len = static_cast(sizeof(grpc_sockaddr_in)); + grpc_sockaddr_in* in = reinterpret_cast(addr->addr); + in->sin_family = GRPC_AF_INET; + if (grpc_inet_pton(GRPC_AF_INET, host.c_str(), &in->sin_addr) == 0) { + if (log_errors) { + gpr_log(GPR_ERROR, "invalid ipv4 address: '%s'", host.c_str()); + } + goto done; + } + // Parse port. + if (port.empty()) { + if (log_errors) gpr_log(GPR_ERROR, "no port given for ipv4 scheme"); + goto done; + } + int port_num; + if (sscanf(port.c_str(), "%d", &port_num) != 1 || port_num < 0 || + port_num > 65535) { + if (log_errors) gpr_log(GPR_ERROR, "invalid ipv4 port: '%s'", port.c_str()); + goto done; + } + in->sin_port = grpc_htons(static_cast(port_num)); + success = true; +done: + return success; +} + +bool grpc_parse_ipv4(const grpc_core::URI& uri, + grpc_resolved_address* resolved_addr) { + if (uri.scheme() != "ipv4") { + gpr_log(GPR_ERROR, "Expected 'ipv4' scheme, got '%s'", + uri.scheme().c_str()); + return false; + } + return grpc_parse_ipv4_hostport(absl::StripPrefix(uri.path(), "/"), + resolved_addr, true /* log_errors */); +} + +bool grpc_parse_ipv6_hostport(absl::string_view hostport, + grpc_resolved_address* addr, bool log_errors) { + bool success = false; + // Split host and port. + std::string host; + std::string port; + if (!grpc_core::SplitHostPort(hostport, &host, &port)) { + if (log_errors) { + gpr_log(GPR_ERROR, "Failed gpr_split_host_port(%s, ...)", + std::string(hostport).c_str()); + } + return false; + } + // Parse IP address. + memset(addr, 0, sizeof(*addr)); + addr->len = static_cast(sizeof(grpc_sockaddr_in6)); + grpc_sockaddr_in6* in6 = reinterpret_cast(addr->addr); + in6->sin6_family = GRPC_AF_INET6; + // Handle the RFC6874 syntax for IPv6 zone identifiers. + char* host_end = + static_cast(gpr_memrchr(host.c_str(), '%', host.size())); + if (host_end != nullptr) { + GPR_ASSERT(host_end >= host.c_str()); + char host_without_scope[GRPC_INET6_ADDRSTRLEN + 1]; + size_t host_without_scope_len = + static_cast(host_end - host.c_str()); + uint32_t sin6_scope_id = 0; + if (host_without_scope_len > GRPC_INET6_ADDRSTRLEN) { + if (log_errors) { + gpr_log( + GPR_ERROR, + "invalid ipv6 address length %zu. Length cannot be greater than " + "GRPC_INET6_ADDRSTRLEN i.e %d)", + host_without_scope_len, GRPC_INET6_ADDRSTRLEN); + } + goto done; + } + strncpy(host_without_scope, host.c_str(), host_without_scope_len); + host_without_scope[host_without_scope_len] = '\0'; + if (grpc_inet_pton(GRPC_AF_INET6, host_without_scope, &in6->sin6_addr) == + 0) { + if (log_errors) { + gpr_log(GPR_ERROR, "invalid ipv6 address: '%s'", host_without_scope); + } + goto done; + } + if (gpr_parse_bytes_to_uint32(host_end + 1, + host.size() - host_without_scope_len - 1, + &sin6_scope_id) == 0) { + if ((sin6_scope_id = grpc_if_nametoindex(host_end + 1)) == 0) { + gpr_log(GPR_ERROR, + "Invalid interface name: '%s'. " + "Non-numeric and failed if_nametoindex.", + host_end + 1); + goto done; + } + } + // Handle "sin6_scope_id" being type "u_long". See grpc issue #10027. 
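// [Illustrative sketch, not part of this patch] The zone-identifier handling
// above splits an RFC 6874 style host such as "fe80::1%eth0" at the last '%'
// and then treats the suffix either as a numeric scope id or as an interface
// name to resolve via grpc_if_nametoindex(). A standalone version of just
// the split step:
#include <string>
#include <utility>

// Returns {address, zone}; zone is empty when no '%' is present.
std::pair<std::string, std::string> SplitIpv6ZoneExample(
    const std::string& host) {
  const size_t pos = host.rfind('%');
  if (pos == std::string::npos) return {host, ""};
  return {host.substr(0, pos), host.substr(pos + 1)};
}
// SplitIpv6ZoneExample("fe80::1%eth0") -> {"fe80::1", "eth0"}. The real code
// additionally bounds the address part by GRPC_INET6_ADDRSTRLEN and rejects
// zones that are neither numeric nor a known interface name.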
+ in6->sin6_scope_id = sin6_scope_id; + } else { + if (grpc_inet_pton(GRPC_AF_INET6, host.c_str(), &in6->sin6_addr) == 0) { + if (log_errors) { + gpr_log(GPR_ERROR, "invalid ipv6 address: '%s'", host.c_str()); + } + goto done; + } + } + // Parse port. + if (port.empty()) { + if (log_errors) gpr_log(GPR_ERROR, "no port given for ipv6 scheme"); + goto done; + } + int port_num; + if (sscanf(port.c_str(), "%d", &port_num) != 1 || port_num < 0 || + port_num > 65535) { + if (log_errors) gpr_log(GPR_ERROR, "invalid ipv6 port: '%s'", port.c_str()); + goto done; + } + in6->sin6_port = grpc_htons(static_cast(port_num)); + success = true; +done: + return success; +} + +bool grpc_parse_ipv6(const grpc_core::URI& uri, + grpc_resolved_address* resolved_addr) { + if (uri.scheme() != "ipv6") { + gpr_log(GPR_ERROR, "Expected 'ipv6' scheme, got '%s'", + uri.scheme().c_str()); + return false; + } + return grpc_parse_ipv6_hostport(absl::StripPrefix(uri.path(), "/"), + resolved_addr, true /* log_errors */); +} + +bool grpc_parse_uri(const grpc_core::URI& uri, + grpc_resolved_address* resolved_addr) { + if (uri.scheme() == "unix") { + return grpc_parse_unix(uri, resolved_addr); + } + if (uri.scheme() == "unix-abstract") { + return grpc_parse_unix_abstract(uri, resolved_addr); + } + if (uri.scheme() == "ipv4") { + return grpc_parse_ipv4(uri, resolved_addr); + } + if (uri.scheme() == "ipv6") { + return grpc_parse_ipv6(uri, resolved_addr); + } + gpr_log(GPR_ERROR, "Can't parse scheme '%s'", uri.scheme().c_str()); + return false; +} + +uint16_t grpc_strhtons(const char* port) { + if (strcmp(port, "http") == 0) { + return htons(80); + } else if (strcmp(port, "https") == 0) { + return htons(443); + } + return htons(static_cast(atoi(port))); +} diff --git a/src/core/lib/address_utils/sockaddr_utils.cc b/src/core/lib/address_utils/sockaddr_utils.cc new file mode 100644 index 00000000..33afba1d --- /dev/null +++ b/src/core/lib/address_utils/sockaddr_utils.cc @@ -0,0 +1,412 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" + +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/event_engine/resolved_address_internal.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" + +static const uint8_t kV4MappedPrefix[] = {0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0xff, 0xff}; + +int grpc_sockaddr_is_v4mapped(const grpc_resolved_address* resolved_addr, + grpc_resolved_address* resolved_addr4_out) { + GPR_ASSERT(resolved_addr != resolved_addr4_out); + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + grpc_sockaddr_in* addr4_out = + resolved_addr4_out == nullptr + ? 
nullptr + : reinterpret_cast(resolved_addr4_out->addr); + if (addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast(addr); + if (memcmp(addr6->sin6_addr.s6_addr, kV4MappedPrefix, + sizeof(kV4MappedPrefix)) == 0) { + if (resolved_addr4_out != nullptr) { + /* Normalize ::ffff:0.0.0.0/96 to IPv4. */ + memset(resolved_addr4_out, 0, sizeof(*resolved_addr4_out)); + addr4_out->sin_family = GRPC_AF_INET; + /* s6_addr32 would be nice, but it's non-standard. */ + memcpy(&addr4_out->sin_addr, &addr6->sin6_addr.s6_addr[12], 4); + addr4_out->sin_port = addr6->sin6_port; + resolved_addr4_out->len = + static_cast(sizeof(grpc_sockaddr_in)); + } + return 1; + } + } + return 0; +} + +int grpc_sockaddr_to_v4mapped(const grpc_resolved_address* resolved_addr, + grpc_resolved_address* resolved_addr6_out) { + GPR_ASSERT(resolved_addr != resolved_addr6_out); + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + grpc_sockaddr_in6* addr6_out = + reinterpret_cast(resolved_addr6_out->addr); + if (addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast(addr); + memset(resolved_addr6_out, 0, sizeof(*resolved_addr6_out)); + addr6_out->sin6_family = GRPC_AF_INET6; + memcpy(&addr6_out->sin6_addr.s6_addr[0], kV4MappedPrefix, 12); + memcpy(&addr6_out->sin6_addr.s6_addr[12], &addr4->sin_addr, 4); + addr6_out->sin6_port = addr4->sin_port; + resolved_addr6_out->len = static_cast(sizeof(grpc_sockaddr_in6)); + return 1; + } + return 0; +} + +int grpc_sockaddr_is_wildcard(const grpc_resolved_address* resolved_addr, + int* port_out) { + const grpc_sockaddr* addr; + grpc_resolved_address addr4_normalized; + if (grpc_sockaddr_is_v4mapped(resolved_addr, &addr4_normalized)) { + resolved_addr = &addr4_normalized; + } + addr = reinterpret_cast(resolved_addr->addr); + if (addr->sa_family == GRPC_AF_INET) { + /* Check for 0.0.0.0 */ + const grpc_sockaddr_in* addr4 = + reinterpret_cast(addr); + if (addr4->sin_addr.s_addr != 0) { + return 0; + } + *port_out = grpc_ntohs(addr4->sin_port); + return 1; + } else if (addr->sa_family == GRPC_AF_INET6) { + /* Check for :: */ + const grpc_sockaddr_in6* addr6 = + reinterpret_cast(addr); + int i; + for (i = 0; i < 16; i++) { + if (addr6->sin6_addr.s6_addr[i] != 0) { + return 0; + } + } + *port_out = grpc_ntohs(addr6->sin6_port); + return 1; + } else { + return 0; + } +} + +void grpc_sockaddr_make_wildcards(int port, grpc_resolved_address* wild4_out, + grpc_resolved_address* wild6_out) { + grpc_sockaddr_make_wildcard4(port, wild4_out); + grpc_sockaddr_make_wildcard6(port, wild6_out); +} + +void grpc_sockaddr_make_wildcard4(int port, + grpc_resolved_address* resolved_wild_out) { + grpc_sockaddr_in* wild_out = + reinterpret_cast(resolved_wild_out->addr); + GPR_ASSERT(port >= 0 && port < 65536); + memset(resolved_wild_out, 0, sizeof(*resolved_wild_out)); + wild_out->sin_family = GRPC_AF_INET; + wild_out->sin_port = grpc_htons(static_cast(port)); + resolved_wild_out->len = static_cast(sizeof(grpc_sockaddr_in)); +} + +void grpc_sockaddr_make_wildcard6(int port, + grpc_resolved_address* resolved_wild_out) { + grpc_sockaddr_in6* wild_out = + reinterpret_cast(resolved_wild_out->addr); + GPR_ASSERT(port >= 0 && port < 65536); + memset(resolved_wild_out, 0, sizeof(*resolved_wild_out)); + wild_out->sin6_family = GRPC_AF_INET6; + wild_out->sin6_port = grpc_htons(static_cast(port)); + resolved_wild_out->len = static_cast(sizeof(grpc_sockaddr_in6)); +} + +std::string grpc_sockaddr_to_string(const grpc_resolved_address* 
resolved_addr, + bool normalize) { + const int save_errno = errno; + grpc_resolved_address addr_normalized; + if (normalize && grpc_sockaddr_is_v4mapped(resolved_addr, &addr_normalized)) { + resolved_addr = &addr_normalized; + } + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + const void* ip = nullptr; + int port = 0; + uint32_t sin6_scope_id = 0; + if (addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast(addr); + ip = &addr4->sin_addr; + port = grpc_ntohs(addr4->sin_port); + } else if (addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast(addr); + ip = &addr6->sin6_addr; + port = grpc_ntohs(addr6->sin6_port); + sin6_scope_id = addr6->sin6_scope_id; + } + char ntop_buf[GRPC_INET6_ADDRSTRLEN]; + std::string out; + if (ip != nullptr && grpc_inet_ntop(addr->sa_family, ip, ntop_buf, + sizeof(ntop_buf)) != nullptr) { + if (sin6_scope_id != 0) { + // Enclose sin6_scope_id with the format defined in RFC 6874 section 2. + std::string host_with_scope = + absl::StrFormat("%s%%25%" PRIu32, ntop_buf, sin6_scope_id); + out = grpc_core::JoinHostPort(host_with_scope, port); + } else { + out = grpc_core::JoinHostPort(ntop_buf, port); + } + } else { + out = absl::StrFormat("(sockaddr family=%d)", addr->sa_family); + } + /* This is probably redundant, but we wouldn't want to log the wrong error. */ + errno = save_errno; + return out; +} + +grpc_error_handle grpc_string_to_sockaddr(grpc_resolved_address* out, + const char* addr, int port) { + memset(out, 0, sizeof(grpc_resolved_address)); + grpc_sockaddr_in6* addr6 = reinterpret_cast(out->addr); + grpc_sockaddr_in* addr4 = reinterpret_cast(out->addr); + if (grpc_inet_pton(GRPC_AF_INET6, addr, &addr6->sin6_addr) == 1) { + addr6->sin6_family = GRPC_AF_INET6; + out->len = sizeof(grpc_sockaddr_in6); + } else if (grpc_inet_pton(GRPC_AF_INET, addr, &addr4->sin_addr) == 1) { + addr4->sin_family = GRPC_AF_INET; + out->len = sizeof(grpc_sockaddr_in); + } else { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Failed to parse address:", addr)); + } + grpc_sockaddr_set_port(out, port); + return GRPC_ERROR_NONE; +} + +std::string grpc_sockaddr_to_uri(const grpc_resolved_address* resolved_addr) { + if (resolved_addr->len == 0) return ""; + grpc_resolved_address addr_normalized; + if (grpc_sockaddr_is_v4mapped(resolved_addr, &addr_normalized)) { + resolved_addr = &addr_normalized; + } + const char* scheme = grpc_sockaddr_get_uri_scheme(resolved_addr); + if (scheme == nullptr || strcmp("unix", scheme) == 0) { + return grpc_sockaddr_to_uri_unix_if_possible(resolved_addr); + } + std::string path = + grpc_sockaddr_to_string(resolved_addr, false /* normalize */); + std::string uri_str; + if (scheme != nullptr) { + uri_str = absl::StrCat(scheme, ":", path); + } + return uri_str; +} + +const char* grpc_sockaddr_get_uri_scheme( + const grpc_resolved_address* resolved_addr) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + switch (addr->sa_family) { + case GRPC_AF_INET: + return "ipv4"; + case GRPC_AF_INET6: + return "ipv6"; + case GRPC_AF_UNIX: + return "unix"; + } + return nullptr; +} + +int grpc_sockaddr_get_family(const grpc_resolved_address* resolved_addr) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + return addr->sa_family; +} + +int grpc_sockaddr_get_port(const grpc_resolved_address* resolved_addr) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + switch (addr->sa_family) { + case 
GRPC_AF_INET: + return grpc_ntohs( + (reinterpret_cast(addr))->sin_port); + case GRPC_AF_INET6: + return grpc_ntohs( + (reinterpret_cast(addr))->sin6_port); + default: + if (grpc_is_unix_socket(resolved_addr)) { + return 1; + } + gpr_log(GPR_ERROR, "Unknown socket family %d in grpc_sockaddr_get_port", + addr->sa_family); + return 0; + } +} + +int grpc_sockaddr_set_port(grpc_resolved_address* resolved_addr, int port) { + grpc_sockaddr* addr = reinterpret_cast(resolved_addr->addr); + switch (addr->sa_family) { + case GRPC_AF_INET: + GPR_ASSERT(port >= 0 && port < 65536); + (reinterpret_cast(addr))->sin_port = + grpc_htons(static_cast(port)); + return 1; + case GRPC_AF_INET6: + GPR_ASSERT(port >= 0 && port < 65536); + (reinterpret_cast(addr))->sin6_port = + grpc_htons(static_cast(port)); + return 1; + default: + gpr_log(GPR_ERROR, "Unknown socket family %d in grpc_sockaddr_set_port", + addr->sa_family); + return 0; + } +} + +std::string grpc_sockaddr_get_packed_host( + const grpc_resolved_address* resolved_addr) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + if (addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast(addr); + const char* addr_bytes = reinterpret_cast(&addr4->sin_addr); + return std::string(addr_bytes, 4); + } else if (addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast(addr); + const char* addr_bytes = reinterpret_cast(&addr6->sin6_addr); + return std::string(addr_bytes, 16); + } else { + GPR_ASSERT(false); + } +} + +void grpc_sockaddr_mask_bits(grpc_resolved_address* address, + uint32_t mask_bits) { + grpc_sockaddr* addr = reinterpret_cast(address->addr); + if (addr->sa_family == GRPC_AF_INET) { + grpc_sockaddr_in* addr4 = reinterpret_cast(addr); + if (mask_bits == 0) { + memset(&addr4->sin_addr, 0, sizeof(addr4->sin_addr)); + return; + } else if (mask_bits >= 32) { + return; + } + uint32_t mask_ip_addr = (~(uint32_t(0))) << (32 - mask_bits); + addr4->sin_addr.s_addr &= grpc_htonl(mask_ip_addr); + } else if (addr->sa_family == GRPC_AF_INET6) { + grpc_sockaddr_in6* addr6 = reinterpret_cast(addr); + if (mask_bits == 0) { + memset(&addr6->sin6_addr, 0, sizeof(addr6->sin6_addr)); + return; + } else if (mask_bits >= 128) { + return; + } + // We cannot use s6_addr32 since it is not defined on all platforms that we + // need it on. 
+ uint32_t address_parts[4]; + GPR_ASSERT(sizeof(addr6->sin6_addr) == sizeof(address_parts)); + memcpy(address_parts, &addr6->sin6_addr, sizeof(grpc_in6_addr)); + if (mask_bits <= 32) { + uint32_t mask_ip_addr = (~(uint32_t(0))) << (32 - mask_bits); + address_parts[0] &= grpc_htonl(mask_ip_addr); + memset(&address_parts[1], 0, sizeof(uint32_t)); + memset(&address_parts[2], 0, sizeof(uint32_t)); + memset(&address_parts[3], 0, sizeof(uint32_t)); + } else if (mask_bits <= 64) { + mask_bits -= 32; + uint32_t mask_ip_addr = (~(uint32_t(0))) << (32 - mask_bits); + address_parts[1] &= grpc_htonl(mask_ip_addr); + memset(&address_parts[2], 0, sizeof(uint32_t)); + memset(&address_parts[3], 0, sizeof(uint32_t)); + } else if (mask_bits <= 96) { + mask_bits -= 64; + uint32_t mask_ip_addr = (~(uint32_t(0))) << (32 - mask_bits); + address_parts[2] &= grpc_htonl(mask_ip_addr); + memset(&address_parts[3], 0, sizeof(uint32_t)); + } else { + mask_bits -= 96; + uint32_t mask_ip_addr = (~(uint32_t(0))) << (32 - mask_bits); + address_parts[3] &= grpc_htonl(mask_ip_addr); + } + memcpy(&addr6->sin6_addr, address_parts, sizeof(grpc_in6_addr)); + } +} + +bool grpc_sockaddr_match_subnet(const grpc_resolved_address* address, + const grpc_resolved_address* subnet_address, + uint32_t mask_bits) { + auto* addr = reinterpret_cast(address->addr); + auto* subnet_addr = + reinterpret_cast(subnet_address->addr); + if (addr->sa_family != subnet_addr->sa_family) return false; + grpc_resolved_address masked_address; + memcpy(&masked_address, address, sizeof(grpc_resolved_address)); + addr = reinterpret_cast((&masked_address)->addr); + grpc_sockaddr_mask_bits(&masked_address, mask_bits); + if (addr->sa_family == GRPC_AF_INET) { + auto* addr4 = reinterpret_cast(addr); + auto* subnet_addr4 = reinterpret_cast(subnet_addr); + if (memcmp(&addr4->sin_addr, &subnet_addr4->sin_addr, + sizeof(addr4->sin_addr)) == 0) { + return true; + } + } else if (addr->sa_family == GRPC_AF_INET6) { + auto* addr6 = reinterpret_cast(addr); + auto* subnet_addr6 = + reinterpret_cast(subnet_addr); + if (memcmp(&addr6->sin6_addr, &subnet_addr6->sin6_addr, + sizeof(addr6->sin6_addr)) == 0) { + return true; + } + } + return false; +} + +namespace grpc_event_engine { +namespace experimental { + +std::string ResolvedAddressToURI(const EventEngine::ResolvedAddress& addr) { + auto gra = CreateGRPCResolvedAddress(addr); + return grpc_sockaddr_to_uri(&gra); +} + +} // namespace experimental +} // namespace grpc_event_engine diff --git a/src/core/lib/avl/avl.cc b/src/core/lib/avl/avl.cc new file mode 100644 index 00000000..cd1b940e --- /dev/null +++ b/src/core/lib/avl/avl.cc @@ -0,0 +1,306 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
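// [Illustrative sketch, not part of this patch] grpc_sockaddr_match_subnet()
// above masks the candidate address down to mask_bits and compares the
// result against the subnet address. For a single IPv4 address the whole
// check collapses to a shift, a mask and a comparison:
#include <cstdint>

bool Ipv4InSubnetExample(uint32_t addr, uint32_t subnet, uint32_t mask_bits) {
  if (mask_bits == 0) return true;     // a /0 prefix matches every address
  if (mask_bits > 32) mask_bits = 32;  // clamp, mirroring the >= 32 early out
  const uint32_t mask = 0xFFFFFFFFu << (32 - mask_bits);
  return (addr & mask) == (subnet & mask);
}
// Ipv4InSubnetExample(0xC0A80105 /*192.168.1.5*/,
//                     0xC0A80100 /*192.168.1.0*/, 24) == true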
+ * + */ + +#include + +#include "src/core/lib/avl/avl.h" + +#include +#include + +#include + +#include +#include + +grpc_avl grpc_avl_create(const grpc_avl_vtable* vtable) { + grpc_avl out; + out.vtable = vtable; + out.root = nullptr; + return out; +} + +static grpc_avl_node* ref_node(grpc_avl_node* node) { + if (node) { + gpr_ref(&node->refs); + } + return node; +} + +static void unref_node(const grpc_avl_vtable* vtable, grpc_avl_node* node, + void* user_data) { + if (node == nullptr) { + return; + } + if (gpr_unref(&node->refs)) { + vtable->destroy_key(node->key, user_data); + vtable->destroy_value(node->value, user_data); + unref_node(vtable, node->left, user_data); + unref_node(vtable, node->right, user_data); + gpr_free(node); + } +} + +static long node_height(grpc_avl_node* node) { + return node == nullptr ? 0 : node->height; +} + +#ifndef NDEBUG +static long calculate_height(grpc_avl_node* node) { + return node == nullptr ? 0 + : 1 + std::max(calculate_height(node->left), + calculate_height(node->right)); +} + +static grpc_avl_node* assert_invariants(grpc_avl_node* n) { + if (n == nullptr) return nullptr; + assert_invariants(n->left); + assert_invariants(n->right); + assert(calculate_height(n) == n->height); + assert(labs(node_height(n->left) - node_height(n->right)) <= 1); + return n; +} +#else +static grpc_avl_node* assert_invariants(grpc_avl_node* n) { return n; } +#endif + +grpc_avl_node* new_node(void* key, void* value, grpc_avl_node* left, + grpc_avl_node* right) { + grpc_avl_node* node = static_cast(gpr_malloc(sizeof(*node))); + gpr_ref_init(&node->refs, 1); + node->key = key; + node->value = value; + node->left = assert_invariants(left); + node->right = assert_invariants(right); + node->height = 1 + std::max(node_height(left), node_height(right)); + return node; +} + +static grpc_avl_node* get(const grpc_avl_vtable* vtable, grpc_avl_node* node, + void* key, void* user_data) { + long cmp; + + if (node == nullptr) { + return nullptr; + } + + cmp = vtable->compare_keys(node->key, key, user_data); + if (cmp == 0) { + return node; + } else if (cmp > 0) { + return get(vtable, node->left, key, user_data); + } else { + return get(vtable, node->right, key, user_data); + } +} + +void* grpc_avl_get(grpc_avl avl, void* key, void* user_data) { + grpc_avl_node* node = get(avl.vtable, avl.root, key, user_data); + return node ? 
node->value : nullptr; +} + +int grpc_avl_maybe_get(grpc_avl avl, void* key, void** value, void* user_data) { + grpc_avl_node* node = get(avl.vtable, avl.root, key, user_data); + if (node != nullptr) { + *value = node->value; + return 1; + } + return 0; +} + +static grpc_avl_node* rotate_left(const grpc_avl_vtable* vtable, void* key, + void* value, grpc_avl_node* left, + grpc_avl_node* right, void* user_data) { + grpc_avl_node* n = new_node(vtable->copy_key(right->key, user_data), + vtable->copy_value(right->value, user_data), + new_node(key, value, left, ref_node(right->left)), + ref_node(right->right)); + unref_node(vtable, right, user_data); + return n; +} + +static grpc_avl_node* rotate_right(const grpc_avl_vtable* vtable, void* key, + void* value, grpc_avl_node* left, + grpc_avl_node* right, void* user_data) { + grpc_avl_node* n = + new_node(vtable->copy_key(left->key, user_data), + vtable->copy_value(left->value, user_data), ref_node(left->left), + new_node(key, value, ref_node(left->right), right)); + unref_node(vtable, left, user_data); + return n; +} + +static grpc_avl_node* rotate_left_right(const grpc_avl_vtable* vtable, + void* key, void* value, + grpc_avl_node* left, + grpc_avl_node* right, void* user_data) { + /* rotate_right(..., rotate_left(left), right) */ + grpc_avl_node* n = + new_node(vtable->copy_key(left->right->key, user_data), + vtable->copy_value(left->right->value, user_data), + new_node(vtable->copy_key(left->key, user_data), + vtable->copy_value(left->value, user_data), + ref_node(left->left), ref_node(left->right->left)), + new_node(key, value, ref_node(left->right->right), right)); + unref_node(vtable, left, user_data); + return n; +} + +static grpc_avl_node* rotate_right_left(const grpc_avl_vtable* vtable, + void* key, void* value, + grpc_avl_node* left, + grpc_avl_node* right, void* user_data) { + /* rotate_left(..., left, rotate_right(right)) */ + grpc_avl_node* n = + new_node(vtable->copy_key(right->left->key, user_data), + vtable->copy_value(right->left->value, user_data), + new_node(key, value, left, ref_node(right->left->left)), + new_node(vtable->copy_key(right->key, user_data), + vtable->copy_value(right->value, user_data), + ref_node(right->left->right), ref_node(right->right))); + unref_node(vtable, right, user_data); + return n; +} + +static grpc_avl_node* rebalance(const grpc_avl_vtable* vtable, void* key, + void* value, grpc_avl_node* left, + grpc_avl_node* right, void* user_data) { + switch (node_height(left) - node_height(right)) { + case 2: + if (node_height(left->left) - node_height(left->right) == -1) { + return assert_invariants( + rotate_left_right(vtable, key, value, left, right, user_data)); + } else { + return assert_invariants( + rotate_right(vtable, key, value, left, right, user_data)); + } + case -2: + if (node_height(right->left) - node_height(right->right) == 1) { + return assert_invariants( + rotate_right_left(vtable, key, value, left, right, user_data)); + } else { + return assert_invariants( + rotate_left(vtable, key, value, left, right, user_data)); + } + default: + return assert_invariants(new_node(key, value, left, right)); + } +} + +static grpc_avl_node* add_key(const grpc_avl_vtable* vtable, + grpc_avl_node* node, void* key, void* value, + void* user_data) { + long cmp; + if (node == nullptr) { + return new_node(key, value, nullptr, nullptr); + } + cmp = vtable->compare_keys(node->key, key, user_data); + if (cmp == 0) { + return new_node(key, value, ref_node(node->left), ref_node(node->right)); + } else if (cmp > 0) { + 
return rebalance(vtable, vtable->copy_key(node->key, user_data), + vtable->copy_value(node->value, user_data), + add_key(vtable, node->left, key, value, user_data), + ref_node(node->right), user_data); + } else { + return rebalance( + vtable, vtable->copy_key(node->key, user_data), + vtable->copy_value(node->value, user_data), ref_node(node->left), + add_key(vtable, node->right, key, value, user_data), user_data); + } +} + +grpc_avl grpc_avl_add(grpc_avl avl, void* key, void* value, void* user_data) { + grpc_avl_node* old_root = avl.root; + avl.root = add_key(avl.vtable, avl.root, key, value, user_data); + assert_invariants(avl.root); + unref_node(avl.vtable, old_root, user_data); + return avl; +} + +static grpc_avl_node* in_order_head(grpc_avl_node* node) { + while (node->left != nullptr) { + node = node->left; + } + return node; +} + +static grpc_avl_node* in_order_tail(grpc_avl_node* node) { + while (node->right != nullptr) { + node = node->right; + } + return node; +} + +static grpc_avl_node* remove_key(const grpc_avl_vtable* vtable, + grpc_avl_node* node, void* key, + void* user_data) { + long cmp; + if (node == nullptr) { + return nullptr; + } + cmp = vtable->compare_keys(node->key, key, user_data); + if (cmp == 0) { + if (node->left == nullptr) { + return ref_node(node->right); + } else if (node->right == nullptr) { + return ref_node(node->left); + } else if (node->left->height < node->right->height) { + grpc_avl_node* h = in_order_head(node->right); + return rebalance( + vtable, vtable->copy_key(h->key, user_data), + vtable->copy_value(h->value, user_data), ref_node(node->left), + remove_key(vtable, node->right, h->key, user_data), user_data); + } else { + grpc_avl_node* h = in_order_tail(node->left); + return rebalance(vtable, vtable->copy_key(h->key, user_data), + vtable->copy_value(h->value, user_data), + remove_key(vtable, node->left, h->key, user_data), + ref_node(node->right), user_data); + } + } else if (cmp > 0) { + return rebalance(vtable, vtable->copy_key(node->key, user_data), + vtable->copy_value(node->value, user_data), + remove_key(vtable, node->left, key, user_data), + ref_node(node->right), user_data); + } else { + return rebalance( + vtable, vtable->copy_key(node->key, user_data), + vtable->copy_value(node->value, user_data), ref_node(node->left), + remove_key(vtable, node->right, key, user_data), user_data); + } +} + +grpc_avl grpc_avl_remove(grpc_avl avl, void* key, void* user_data) { + grpc_avl_node* old_root = avl.root; + avl.root = remove_key(avl.vtable, avl.root, key, user_data); + assert_invariants(avl.root); + unref_node(avl.vtable, old_root, user_data); + return avl; +} + +grpc_avl grpc_avl_ref(grpc_avl avl, void* /*user_data*/) { + ref_node(avl.root); + return avl; +} + +void grpc_avl_unref(grpc_avl avl, void* user_data) { + unref_node(avl.vtable, avl.root, user_data); +} + +int grpc_avl_is_empty(grpc_avl avl) { return avl.root == nullptr; } diff --git a/src/core/lib/backoff/backoff.cc b/src/core/lib/backoff/backoff.cc new file mode 100644 index 00000000..68bef770 --- /dev/null +++ b/src/core/lib/backoff/backoff.cc @@ -0,0 +1,78 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/backoff/backoff.h" + +#include + +#include "src/core/lib/gpr/useful.h" + +namespace grpc_core { + +namespace { + +/* Generate a random number between 0 and 1. We roll our own RNG because seeding + * rand() modifies a global variable we have no control over. */ +double generate_uniform_random_number(uint32_t* rng_state) { + constexpr uint32_t two_raise_31 = uint32_t(1) << 31; + *rng_state = (1103515245 * *rng_state + 12345) % two_raise_31; + return *rng_state / static_cast(two_raise_31); +} + +double generate_uniform_random_number_between(uint32_t* rng_state, double a, + double b) { + if (a == b) return a; + if (a > b) std::swap(a, b); // make sure a < b + const double range = b - a; + return a + generate_uniform_random_number(rng_state) * range; +} + +} // namespace + +BackOff::BackOff(const Options& options) + : options_(options), + rng_state_(static_cast(gpr_now(GPR_CLOCK_REALTIME).tv_nsec)) { + Reset(); +} + +grpc_millis BackOff::NextAttemptTime() { + if (initial_) { + initial_ = false; + return current_backoff_ + grpc_core::ExecCtx::Get()->Now(); + } + current_backoff_ = static_cast( + std::min(current_backoff_ * options_.multiplier(), + static_cast(options_.max_backoff()))); + const double jitter = generate_uniform_random_number_between( + &rng_state_, -options_.jitter() * current_backoff_, + options_.jitter() * current_backoff_); + const grpc_millis next_timeout = + static_cast(current_backoff_ + jitter); + return next_timeout + grpc_core::ExecCtx::Get()->Now(); +} + +void BackOff::Reset() { + current_backoff_ = options_.initial_backoff(); + initial_ = true; +} + +void BackOff::SetRandomSeed(uint32_t seed) { rng_state_ = seed; } + +} // namespace grpc_core diff --git a/src/core/lib/channel/channel_args.cc b/src/core/lib/channel/channel_args.cc new file mode 100644 index 00000000..4cd015bb --- /dev/null +++ b/src/core/lib/channel/channel_args.cc @@ -0,0 +1,400 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/channel/channel_args.h" + +#include +#include + +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" + +static grpc_arg copy_arg(const grpc_arg* src) { + grpc_arg dst; + dst.type = src->type; + dst.key = gpr_strdup(src->key); + switch (dst.type) { + case GRPC_ARG_STRING: + dst.value.string = gpr_strdup(src->value.string); + break; + case GRPC_ARG_INTEGER: + dst.value.integer = src->value.integer; + break; + case GRPC_ARG_POINTER: + dst.value.pointer = src->value.pointer; + dst.value.pointer.p = + src->value.pointer.vtable->copy(src->value.pointer.p); + break; + } + return dst; +} + +grpc_channel_args* grpc_channel_args_copy_and_add(const grpc_channel_args* src, + const grpc_arg* to_add, + size_t num_to_add) { + return grpc_channel_args_copy_and_add_and_remove(src, nullptr, 0, to_add, + num_to_add); +} + +grpc_channel_args* grpc_channel_args_copy_and_remove( + const grpc_channel_args* src, const char** to_remove, + size_t num_to_remove) { + return grpc_channel_args_copy_and_add_and_remove(src, to_remove, + num_to_remove, nullptr, 0); +} + +grpc_channel_args* grpc_channel_args_remove_grpc_internal( + const grpc_channel_args* src) { + if (src == nullptr) return nullptr; + // Create result. + grpc_channel_args* dst = + static_cast(gpr_malloc(sizeof(grpc_channel_args))); + dst->args = + static_cast(gpr_malloc(sizeof(grpc_arg) * src->num_args)); + dst->num_args = 0; + for (size_t i = 0; i < src->num_args; ++i) { + if (absl::StartsWith(src->args[i].key, "grpc.internal.")) continue; + dst->args[dst->num_args++] = copy_arg(&src->args[i]); + } + return dst; +} + +static bool should_remove_arg(const grpc_arg* arg, const char** to_remove, + size_t num_to_remove) { + for (size_t i = 0; i < num_to_remove; ++i) { + if (strcmp(arg->key, to_remove[i]) == 0) return true; + } + return false; +} + +grpc_channel_args* grpc_channel_args_copy_and_add_and_remove( + const grpc_channel_args* src, const char** to_remove, size_t num_to_remove, + const grpc_arg* to_add, size_t num_to_add) { + // Figure out how many args we'll be copying. + size_t num_args_to_copy = 0; + if (src != nullptr) { + for (size_t i = 0; i < src->num_args; ++i) { + if (!should_remove_arg(&src->args[i], to_remove, num_to_remove)) { + ++num_args_to_copy; + } + } + } + // Create result. + grpc_channel_args* dst = + static_cast(gpr_malloc(sizeof(grpc_channel_args))); + dst->num_args = num_args_to_copy + num_to_add; + if (dst->num_args == 0) { + dst->args = nullptr; + return dst; + } + dst->args = + static_cast(gpr_malloc(sizeof(grpc_arg) * dst->num_args)); + // Copy args from src that are not being removed. + size_t dst_idx = 0; + if (src != nullptr) { + for (size_t i = 0; i < src->num_args; ++i) { + if (!should_remove_arg(&src->args[i], to_remove, num_to_remove)) { + dst->args[dst_idx++] = copy_arg(&src->args[i]); + } + } + } + // Add args from to_add. 
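+  // Each added arg is deep-copied via copy_arg(), so the caller retains
+  // ownership of everything it passed in via to_add.
+  //
+  // Typical call pattern (illustrative sketch; the key name is hypothetical):
+  //   grpc_arg new_arg = grpc_channel_arg_integer_create(
+  //       const_cast<char*>("grpc.example_setting"), 42);
+  //   grpc_channel_args* updated =
+  //       grpc_channel_args_copy_and_add(old_args, &new_arg, 1);
+  //   ...
+  //   grpc_channel_args_destroy(updated);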
+ for (size_t i = 0; i < num_to_add; ++i) { + dst->args[dst_idx++] = copy_arg(&to_add[i]); + } + GPR_ASSERT(dst_idx == dst->num_args); + return dst; +} + +grpc_channel_args* grpc_channel_args_copy(const grpc_channel_args* src) { + return grpc_channel_args_copy_and_add(src, nullptr, 0); +} + +grpc_channel_args* grpc_channel_args_union(const grpc_channel_args* a, + const grpc_channel_args* b) { + if (a == nullptr) return grpc_channel_args_copy(b); + if (b == nullptr) return grpc_channel_args_copy(a); + const size_t max_out = (a->num_args + b->num_args); + grpc_arg* uniques = + static_cast(gpr_malloc(sizeof(*uniques) * max_out)); + for (size_t i = 0; i < a->num_args; ++i) uniques[i] = a->args[i]; + + size_t uniques_idx = a->num_args; + for (size_t i = 0; i < b->num_args; ++i) { + const char* b_key = b->args[i].key; + if (grpc_channel_args_find(a, b_key) == nullptr) { // not found + uniques[uniques_idx++] = b->args[i]; + } + } + grpc_channel_args* result = + grpc_channel_args_copy_and_add(nullptr, uniques, uniques_idx); + gpr_free(uniques); + return result; +} + +static int cmp_arg(const grpc_arg* a, const grpc_arg* b) { + int c = grpc_core::QsortCompare(a->type, b->type); + if (c != 0) return c; + c = strcmp(a->key, b->key); + if (c != 0) return c; + switch (a->type) { + case GRPC_ARG_STRING: + return strcmp(a->value.string, b->value.string); + case GRPC_ARG_INTEGER: + return grpc_core::QsortCompare(a->value.integer, b->value.integer); + case GRPC_ARG_POINTER: + c = grpc_core::QsortCompare(a->value.pointer.p, b->value.pointer.p); + if (c != 0) { + c = grpc_core::QsortCompare(a->value.pointer.vtable, + b->value.pointer.vtable); + if (c == 0) { + c = a->value.pointer.vtable->cmp(a->value.pointer.p, + b->value.pointer.p); + } + } + return c; + } + GPR_UNREACHABLE_CODE(return 0); +} + +/* stabilizing comparison function: since channel_args ordering matters for + * keys with the same name, we need to preserve that ordering */ +static int cmp_key_stable(const void* ap, const void* bp) { + const grpc_arg* const* a = static_cast(ap); + const grpc_arg* const* b = static_cast(bp); + int c = strcmp((*a)->key, (*b)->key); + if (c == 0) c = grpc_core::QsortCompare(*a, *b); + return c; +} + +grpc_channel_args* grpc_channel_args_normalize(const grpc_channel_args* src) { + grpc_arg** args = + static_cast(gpr_malloc(sizeof(grpc_arg*) * src->num_args)); + for (size_t i = 0; i < src->num_args; i++) { + args[i] = &src->args[i]; + } + if (src->num_args > 1) { + qsort(args, src->num_args, sizeof(grpc_arg*), cmp_key_stable); + } + + grpc_channel_args* b = + static_cast(gpr_malloc(sizeof(grpc_channel_args))); + b->num_args = src->num_args; + b->args = static_cast(gpr_malloc(sizeof(grpc_arg) * b->num_args)); + for (size_t i = 0; i < src->num_args; i++) { + b->args[i] = copy_arg(args[i]); + } + + gpr_free(args); + return b; +} + +void grpc_channel_args_destroy(grpc_channel_args* a) { + size_t i; + if (!a) return; + for (i = 0; i < a->num_args; i++) { + switch (a->args[i].type) { + case GRPC_ARG_STRING: + gpr_free(a->args[i].value.string); + break; + case GRPC_ARG_INTEGER: + break; + case GRPC_ARG_POINTER: + a->args[i].value.pointer.vtable->destroy(a->args[i].value.pointer.p); + break; + } + gpr_free(a->args[i].key); + } + gpr_free(a->args); + gpr_free(a); +} + +int grpc_channel_args_compare(const grpc_channel_args* a, + const grpc_channel_args* b) { + if (a == nullptr && b == nullptr) return 0; + if (a == nullptr || b == nullptr) return a == nullptr ? 
-1 : 1; + int c = grpc_core::QsortCompare(a->num_args, b->num_args); + if (c != 0) return c; + for (size_t i = 0; i < a->num_args; i++) { + c = cmp_arg(&a->args[i], &b->args[i]); + if (c != 0) return c; + } + return 0; +} + +const grpc_arg* grpc_channel_args_find(const grpc_channel_args* args, + const char* name) { + if (args != nullptr) { + for (size_t i = 0; i < args->num_args; ++i) { + if (strcmp(args->args[i].key, name) == 0) { + return &args->args[i]; + } + } + } + return nullptr; +} + +int grpc_channel_arg_get_integer(const grpc_arg* arg, + const grpc_integer_options options) { + if (arg == nullptr) return options.default_value; + if (arg->type != GRPC_ARG_INTEGER) { + gpr_log(GPR_ERROR, "%s ignored: it must be an integer", arg->key); + return options.default_value; + } + if (arg->value.integer < options.min_value) { + gpr_log(GPR_ERROR, "%s ignored: it must be >= %d", arg->key, + options.min_value); + return options.default_value; + } + if (arg->value.integer > options.max_value) { + gpr_log(GPR_ERROR, "%s ignored: it must be <= %d", arg->key, + options.max_value); + return options.default_value; + } + return arg->value.integer; +} + +int grpc_channel_args_find_integer(const grpc_channel_args* args, + const char* name, + const grpc_integer_options options) { + const grpc_arg* arg = grpc_channel_args_find(args, name); + return grpc_channel_arg_get_integer(arg, options); +} + +char* grpc_channel_arg_get_string(const grpc_arg* arg) { + if (arg == nullptr) return nullptr; + if (arg->type != GRPC_ARG_STRING) { + gpr_log(GPR_ERROR, "%s ignored: it must be an string", arg->key); + return nullptr; + } + return arg->value.string; +} + +char* grpc_channel_args_find_string(const grpc_channel_args* args, + const char* name) { + const grpc_arg* arg = grpc_channel_args_find(args, name); + return grpc_channel_arg_get_string(arg); +} + +bool grpc_channel_arg_get_bool(const grpc_arg* arg, bool default_value) { + if (arg == nullptr) return default_value; + if (arg->type != GRPC_ARG_INTEGER) { + gpr_log(GPR_ERROR, "%s ignored: it must be an integer", arg->key); + return default_value; + } + switch (arg->value.integer) { + case 0: + return false; + case 1: + return true; + default: + gpr_log(GPR_ERROR, "%s treated as bool but set to %d (assuming true)", + arg->key, arg->value.integer); + return true; + } +} + +bool grpc_channel_args_find_bool(const grpc_channel_args* args, + const char* name, bool default_value) { + const grpc_arg* arg = grpc_channel_args_find(args, name); + return grpc_channel_arg_get_bool(arg, default_value); +} + +bool grpc_channel_args_want_minimal_stack(const grpc_channel_args* args) { + return grpc_channel_arg_get_bool( + grpc_channel_args_find(args, GRPC_ARG_MINIMAL_STACK), false); +} + +grpc_arg grpc_channel_arg_string_create(char* name, char* value) { + grpc_arg arg; + arg.type = GRPC_ARG_STRING; + arg.key = name; + arg.value.string = value; + return arg; +} + +grpc_arg grpc_channel_arg_integer_create(char* name, int value) { + grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + arg.key = name; + arg.value.integer = value; + return arg; +} + +grpc_arg grpc_channel_arg_pointer_create( + char* name, void* value, const grpc_arg_pointer_vtable* vtable) { + grpc_arg arg; + arg.type = GRPC_ARG_POINTER; + arg.key = name; + arg.value.pointer.p = value; + arg.value.pointer.vtable = vtable; + return arg; +} + +std::string grpc_channel_args_string(const grpc_channel_args* args) { + if (args == nullptr) return ""; + std::vector arg_strings; + for (size_t i = 0; i < args->num_args; ++i) { + 
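+    // Render each argument as "key=value"; pointer args fall back to the
+    // pointer address because their payload is opaque here. The joined result
+    // looks like (values illustrative):
+    //   "grpc.primary_user_agent=grpc-c/1.x, grpc.minimal_stack=1"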
const grpc_arg& arg = args->args[i]; + std::string arg_string; + switch (arg.type) { + case GRPC_ARG_INTEGER: + arg_string = absl::StrFormat("%s=%d", arg.key, arg.value.integer); + break; + case GRPC_ARG_STRING: + arg_string = absl::StrFormat("%s=%s", arg.key, arg.value.string); + break; + case GRPC_ARG_POINTER: + arg_string = absl::StrFormat("%s=%p", arg.key, arg.value.pointer.p); + break; + default: + arg_string = "arg with unknown type"; + } + arg_strings.push_back(arg_string); + } + return absl::StrJoin(arg_strings, ", "); +} + +namespace { +grpc_channel_args_client_channel_creation_mutator g_mutator = nullptr; +} // namespace + +void grpc_channel_args_set_client_channel_creation_mutator( + grpc_channel_args_client_channel_creation_mutator cb) { + GPR_DEBUG_ASSERT(g_mutator == nullptr); + g_mutator = cb; +} +grpc_channel_args_client_channel_creation_mutator +grpc_channel_args_get_client_channel_creation_mutator() { + return g_mutator; +} diff --git a/src/core/lib/channel/channel_stack.cc b/src/core/lib/channel/channel_stack.cc new file mode 100644 index 00000000..a7f890e2 --- /dev/null +++ b/src/core/lib/channel/channel_stack.cc @@ -0,0 +1,267 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/channel/channel_stack.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/alloc.h" + +grpc_core::TraceFlag grpc_trace_channel(false, "channel"); + +/* Memory layouts. 
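+   Both stacks use the same scheme: a fixed header, an array of element
+   structs, then each filter's private data, with every piece rounded up to
+   GPR_MAX_ALIGNMENT. As a worked example (assuming GPR_MAX_ALIGNMENT == 16
+   and two filters whose sizeof_channel_data are 24 and 40 bytes),
+   grpc_channel_stack_size() adds the rounded header, the rounded two-element
+   array, and 32 + 48 bytes of per-filter storage.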
+ + Channel stack is laid out as: { + grpc_channel_stack stk; + padding to GPR_MAX_ALIGNMENT + grpc_channel_element[stk.count]; + per-filter memory, aligned to GPR_MAX_ALIGNMENT + } + + Call stack is laid out as: { + grpc_call_stack stk; + padding to GPR_MAX_ALIGNMENT + grpc_call_element[stk.count]; + per-filter memory, aligned to GPR_MAX_ALIGNMENT + } */ + +size_t grpc_channel_stack_size(const grpc_channel_filter** filters, + size_t filter_count) { + /* always need the header, and size for the channel elements */ + size_t size = GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_channel_stack)) + + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(filter_count * + sizeof(grpc_channel_element)); + size_t i; + + GPR_ASSERT((GPR_MAX_ALIGNMENT & (GPR_MAX_ALIGNMENT - 1)) == 0 && + "GPR_MAX_ALIGNMENT must be a power of two"); + + /* add the size for each filter */ + for (i = 0; i < filter_count; i++) { + size += GPR_ROUND_UP_TO_ALIGNMENT_SIZE(filters[i]->sizeof_channel_data); + } + + return size; +} + +#define CHANNEL_ELEMS_FROM_STACK(stk) \ + ((grpc_channel_element*)((char*)(stk) + GPR_ROUND_UP_TO_ALIGNMENT_SIZE( \ + sizeof(grpc_channel_stack)))) + +#define CALL_ELEMS_FROM_STACK(stk) \ + ((grpc_call_element*)((char*)(stk) + GPR_ROUND_UP_TO_ALIGNMENT_SIZE( \ + sizeof(grpc_call_stack)))) + +grpc_channel_element* grpc_channel_stack_element( + grpc_channel_stack* channel_stack, size_t index) { + return CHANNEL_ELEMS_FROM_STACK(channel_stack) + index; +} + +grpc_channel_element* grpc_channel_stack_last_element( + grpc_channel_stack* channel_stack) { + return grpc_channel_stack_element(channel_stack, channel_stack->count - 1); +} + +size_t grpc_channel_stack_filter_instance_number( + grpc_channel_stack* channel_stack, grpc_channel_element* elem) { + size_t num_found = 0; + for (size_t i = 0; i < channel_stack->count; ++i) { + grpc_channel_element* element = + grpc_channel_stack_element(channel_stack, i); + if (element == elem) break; + if (element->filter == elem->filter) ++num_found; + } + return num_found; +} + +grpc_call_element* grpc_call_stack_element(grpc_call_stack* call_stack, + size_t index) { + return CALL_ELEMS_FROM_STACK(call_stack) + index; +} + +grpc_error_handle grpc_channel_stack_init( + int initial_refs, grpc_iomgr_cb_func destroy, void* destroy_arg, + const grpc_channel_filter** filters, size_t filter_count, + const grpc_channel_args* channel_args, grpc_transport* optional_transport, + const char* name, grpc_channel_stack* stack) { + size_t call_size = + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_call_stack)) + + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(filter_count * sizeof(grpc_call_element)); + grpc_channel_element* elems; + grpc_channel_element_args args; + char* user_data; + size_t i; + + stack->count = filter_count; + GRPC_STREAM_REF_INIT(&stack->refcount, initial_refs, destroy, destroy_arg, + name); + elems = CHANNEL_ELEMS_FROM_STACK(stack); + user_data = (reinterpret_cast(elems)) + + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(filter_count * + sizeof(grpc_channel_element)); + + /* init per-filter data */ + grpc_error_handle first_error = GRPC_ERROR_NONE; + for (i = 0; i < filter_count; i++) { + args.channel_stack = stack; + args.channel_args = channel_args; + args.optional_transport = optional_transport; + args.is_first = i == 0; + args.is_last = i == (filter_count - 1); + elems[i].filter = filters[i]; + elems[i].channel_data = user_data; + grpc_error_handle error = + elems[i].filter->init_channel_elem(&elems[i], &args); + if (error != GRPC_ERROR_NONE) { + if (first_error == GRPC_ERROR_NONE) { + first_error = error; + } 
else { + GRPC_ERROR_UNREF(error); + } + } + user_data += + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(filters[i]->sizeof_channel_data); + call_size += GPR_ROUND_UP_TO_ALIGNMENT_SIZE(filters[i]->sizeof_call_data); + } + + GPR_ASSERT(user_data > (char*)stack); + GPR_ASSERT((uintptr_t)(user_data - (char*)stack) == + grpc_channel_stack_size(filters, filter_count)); + + stack->call_stack_size = call_size; + return first_error; +} + +void grpc_channel_stack_destroy(grpc_channel_stack* stack) { + grpc_channel_element* channel_elems = CHANNEL_ELEMS_FROM_STACK(stack); + size_t count = stack->count; + size_t i; + + /* destroy per-filter data */ + for (i = 0; i < count; i++) { + channel_elems[i].filter->destroy_channel_elem(&channel_elems[i]); + } +} + +grpc_error_handle grpc_call_stack_init( + grpc_channel_stack* channel_stack, int initial_refs, + grpc_iomgr_cb_func destroy, void* destroy_arg, + const grpc_call_element_args* elem_args) { + grpc_channel_element* channel_elems = CHANNEL_ELEMS_FROM_STACK(channel_stack); + size_t count = channel_stack->count; + grpc_call_element* call_elems; + char* user_data; + + elem_args->call_stack->count = count; + GRPC_STREAM_REF_INIT(&elem_args->call_stack->refcount, initial_refs, destroy, + destroy_arg, "CALL_STACK"); + call_elems = CALL_ELEMS_FROM_STACK(elem_args->call_stack); + user_data = (reinterpret_cast(call_elems)) + + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(count * sizeof(grpc_call_element)); + + /* init per-filter data */ + grpc_error_handle first_error = GRPC_ERROR_NONE; + for (size_t i = 0; i < count; i++) { + call_elems[i].filter = channel_elems[i].filter; + call_elems[i].channel_data = channel_elems[i].channel_data; + call_elems[i].call_data = user_data; + user_data += + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(call_elems[i].filter->sizeof_call_data); + } + for (size_t i = 0; i < count; i++) { + grpc_error_handle error = + call_elems[i].filter->init_call_elem(&call_elems[i], elem_args); + if (error != GRPC_ERROR_NONE) { + if (first_error == GRPC_ERROR_NONE) { + first_error = error; + } else { + GRPC_ERROR_UNREF(error); + } + } + } + return first_error; +} + +void grpc_call_stack_set_pollset_or_pollset_set(grpc_call_stack* call_stack, + grpc_polling_entity* pollent) { + size_t count = call_stack->count; + grpc_call_element* call_elems; + size_t i; + + call_elems = CALL_ELEMS_FROM_STACK(call_stack); + + /* init per-filter data */ + for (i = 0; i < count; i++) { + call_elems[i].filter->set_pollset_or_pollset_set(&call_elems[i], pollent); + } +} + +void grpc_call_stack_ignore_set_pollset_or_pollset_set( + grpc_call_element* /*elem*/, grpc_polling_entity* /*pollent*/) {} + +void grpc_call_stack_destroy(grpc_call_stack* stack, + const grpc_call_final_info* final_info, + grpc_closure* then_schedule_closure) { + grpc_call_element* elems = CALL_ELEMS_FROM_STACK(stack); + size_t count = stack->count; + size_t i; + + /* destroy per-filter data */ + for (i = 0; i < count; i++) { + elems[i].filter->destroy_call_elem( + &elems[i], final_info, + i == count - 1 ? 
then_schedule_closure : nullptr); + } +} + +void grpc_call_next_op(grpc_call_element* elem, + grpc_transport_stream_op_batch* op) { + grpc_call_element* next_elem = elem + 1; + GRPC_CALL_LOG_OP(GPR_INFO, next_elem, op); + next_elem->filter->start_transport_stream_op_batch(next_elem, op); +} + +void grpc_channel_next_get_info(grpc_channel_element* elem, + const grpc_channel_info* channel_info) { + grpc_channel_element* next_elem = elem + 1; + next_elem->filter->get_channel_info(next_elem, channel_info); +} + +void grpc_channel_next_op(grpc_channel_element* elem, grpc_transport_op* op) { + grpc_channel_element* next_elem = elem + 1; + next_elem->filter->start_transport_op(next_elem, op); +} + +grpc_channel_stack* grpc_channel_stack_from_top_element( + grpc_channel_element* elem) { + return reinterpret_cast( + reinterpret_cast(elem) - + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_channel_stack))); +} + +grpc_call_stack* grpc_call_stack_from_top_element(grpc_call_element* elem) { + return reinterpret_cast( + reinterpret_cast(elem) - + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_call_stack))); +} diff --git a/src/core/lib/channel/channel_stack_builder.cc b/src/core/lib/channel/channel_stack_builder.cc new file mode 100644 index 00000000..186cad3f --- /dev/null +++ b/src/core/lib/channel/channel_stack_builder.cc @@ -0,0 +1,313 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/channel/channel_stack_builder.h" + +#include + +#include +#include + +#include "src/core/lib/gprpp/memory.h" + +typedef struct filter_node { + struct filter_node* next; + struct filter_node* prev; + const grpc_channel_filter* filter; + grpc_post_filter_create_init_func init; + void* init_arg; +} filter_node; + +struct grpc_channel_stack_builder { + // sentinel nodes for filters that have been added + filter_node begin; + filter_node end; + // various set/get-able parameters + grpc_channel_args* args; + grpc_transport* transport; + grpc_resource_user* resource_user; + size_t preallocated_bytes; + char* target; + const char* name; +}; + +struct grpc_channel_stack_builder_iterator { + grpc_channel_stack_builder* builder; + filter_node* node; +}; + +grpc_channel_stack_builder* grpc_channel_stack_builder_create(void) { + grpc_channel_stack_builder* b = + grpc_core::Zalloc(); + b->begin.filter = nullptr; + b->end.filter = nullptr; + b->begin.next = &b->end; + b->begin.prev = &b->end; + b->end.next = &b->begin; + b->end.prev = &b->begin; + return b; +} + +void grpc_channel_stack_builder_set_target(grpc_channel_stack_builder* b, + const char* target) { + gpr_free(b->target); + b->target = gpr_strdup(target); +} + +const char* grpc_channel_stack_builder_get_target( + grpc_channel_stack_builder* b) { + return b->target; +} + +static grpc_channel_stack_builder_iterator* create_iterator_at_filter_node( + grpc_channel_stack_builder* builder, filter_node* node) { + grpc_channel_stack_builder_iterator* it = + static_cast( + gpr_malloc(sizeof(*it))); + it->builder = builder; + it->node = node; + return it; +} + +void grpc_channel_stack_builder_iterator_destroy( + grpc_channel_stack_builder_iterator* it) { + gpr_free(it); +} + +grpc_channel_stack_builder_iterator* +grpc_channel_stack_builder_create_iterator_at_first( + grpc_channel_stack_builder* builder) { + return create_iterator_at_filter_node(builder, &builder->begin); +} + +grpc_channel_stack_builder_iterator* +grpc_channel_stack_builder_create_iterator_at_last( + grpc_channel_stack_builder* builder) { + return create_iterator_at_filter_node(builder, &builder->end); +} + +bool grpc_channel_stack_builder_iterator_is_end( + grpc_channel_stack_builder_iterator* iterator) { + return iterator->node == &iterator->builder->end; +} + +const char* grpc_channel_stack_builder_iterator_filter_name( + grpc_channel_stack_builder_iterator* iterator) { + if (iterator->node->filter == nullptr) return nullptr; + return iterator->node->filter->name; +} + +bool grpc_channel_stack_builder_move_next( + grpc_channel_stack_builder_iterator* iterator) { + if (iterator->node == &iterator->builder->end) return false; + iterator->node = iterator->node->next; + return true; +} + +bool grpc_channel_stack_builder_move_prev( + grpc_channel_stack_builder_iterator* iterator) { + if (iterator->node == &iterator->builder->begin) return false; + iterator->node = iterator->node->prev; + return true; +} + +grpc_channel_stack_builder_iterator* grpc_channel_stack_builder_iterator_find( + grpc_channel_stack_builder* builder, const char* filter_name) { + GPR_ASSERT(filter_name != nullptr); + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_first(builder); + while (grpc_channel_stack_builder_move_next(it)) { + if (grpc_channel_stack_builder_iterator_is_end(it)) break; + const char* filter_name_at_it = + grpc_channel_stack_builder_iterator_filter_name(it); + if (strcmp(filter_name, filter_name_at_it) == 0) 
break; + } + return it; +} + +bool grpc_channel_stack_builder_move_prev( + grpc_channel_stack_builder_iterator* iterator); + +void grpc_channel_stack_builder_set_name(grpc_channel_stack_builder* builder, + const char* name) { + GPR_ASSERT(builder->name == nullptr); + builder->name = name; +} + +void grpc_channel_stack_builder_set_channel_arguments( + grpc_channel_stack_builder* builder, const grpc_channel_args* args) { + if (builder->args != nullptr) { + grpc_channel_args_destroy(builder->args); + } + builder->args = grpc_channel_args_copy(args); +} + +const grpc_channel_args* grpc_channel_stack_builder_get_channel_arguments( + grpc_channel_stack_builder* builder) { + return builder->args; +} + +void grpc_channel_stack_builder_set_transport( + grpc_channel_stack_builder* builder, grpc_transport* transport) { + GPR_ASSERT(builder->transport == nullptr); + builder->transport = transport; +} + +grpc_transport* grpc_channel_stack_builder_get_transport( + grpc_channel_stack_builder* builder) { + return builder->transport; +} + +bool grpc_channel_stack_builder_append_filter( + grpc_channel_stack_builder* builder, const grpc_channel_filter* filter, + grpc_post_filter_create_init_func post_init_func, void* user_data) { + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_last(builder); + bool ok = grpc_channel_stack_builder_add_filter_before( + it, filter, post_init_func, user_data); + grpc_channel_stack_builder_iterator_destroy(it); + return ok; +} + +bool grpc_channel_stack_builder_remove_filter( + grpc_channel_stack_builder* builder, const char* filter_name) { + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_iterator_find(builder, filter_name); + if (grpc_channel_stack_builder_iterator_is_end(it)) { + grpc_channel_stack_builder_iterator_destroy(it); + return false; + } + it->node->prev->next = it->node->next; + it->node->next->prev = it->node->prev; + gpr_free(it->node); + grpc_channel_stack_builder_iterator_destroy(it); + return true; +} + +bool grpc_channel_stack_builder_prepend_filter( + grpc_channel_stack_builder* builder, const grpc_channel_filter* filter, + grpc_post_filter_create_init_func post_init_func, void* user_data) { + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_first(builder); + bool ok = grpc_channel_stack_builder_add_filter_after( + it, filter, post_init_func, user_data); + grpc_channel_stack_builder_iterator_destroy(it); + return ok; +} + +static void add_after(filter_node* before, const grpc_channel_filter* filter, + grpc_post_filter_create_init_func post_init_func, + void* user_data) { + filter_node* new_node = + static_cast(gpr_malloc(sizeof(*new_node))); + new_node->next = before->next; + new_node->prev = before; + new_node->next->prev = new_node->prev->next = new_node; + new_node->filter = filter; + new_node->init = post_init_func; + new_node->init_arg = user_data; +} + +bool grpc_channel_stack_builder_add_filter_before( + grpc_channel_stack_builder_iterator* iterator, + const grpc_channel_filter* filter, + grpc_post_filter_create_init_func post_init_func, void* user_data) { + if (iterator->node == &iterator->builder->begin) return false; + add_after(iterator->node->prev, filter, post_init_func, user_data); + return true; +} + +bool grpc_channel_stack_builder_add_filter_after( + grpc_channel_stack_builder_iterator* iterator, + const grpc_channel_filter* filter, + grpc_post_filter_create_init_func post_init_func, void* user_data) { + if (iterator->node == 
&iterator->builder->end) return false; + add_after(iterator->node, filter, post_init_func, user_data); + return true; +} + +void grpc_channel_stack_builder_destroy(grpc_channel_stack_builder* builder) { + filter_node* p = builder->begin.next; + while (p != &builder->end) { + filter_node* next = p->next; + gpr_free(p); + p = next; + } + if (builder->args != nullptr) { + grpc_channel_args_destroy(builder->args); + } + gpr_free(builder->target); + gpr_free(builder); +} + +grpc_error_handle grpc_channel_stack_builder_finish( + grpc_channel_stack_builder* builder, size_t prefix_bytes, int initial_refs, + grpc_iomgr_cb_func destroy, void* destroy_arg, void** result) { + // count the number of filters + size_t num_filters = 0; + for (filter_node* p = builder->begin.next; p != &builder->end; p = p->next) { + num_filters++; + } + + // create an array of filters + const grpc_channel_filter** filters = + static_cast( + gpr_malloc(sizeof(*filters) * num_filters)); + size_t i = 0; + for (filter_node* p = builder->begin.next; p != &builder->end; p = p->next) { + filters[i++] = p->filter; + } + + // calculate the size of the channel stack + size_t channel_stack_size = grpc_channel_stack_size(filters, num_filters); + + // allocate memory, with prefix_bytes followed by channel_stack_size + *result = gpr_zalloc(prefix_bytes + channel_stack_size); + // fetch a pointer to the channel stack + grpc_channel_stack* channel_stack = reinterpret_cast( + static_cast(*result) + prefix_bytes); + // and initialize it + grpc_error_handle error = grpc_channel_stack_init( + initial_refs, destroy, destroy_arg == nullptr ? *result : destroy_arg, + filters, num_filters, builder->args, builder->transport, builder->name, + channel_stack); + + if (error != GRPC_ERROR_NONE) { + grpc_channel_stack_destroy(channel_stack); + gpr_free(*result); + *result = nullptr; + } else { + // run post-initialization functions + i = 0; + for (filter_node* p = builder->begin.next; p != &builder->end; + p = p->next) { + if (p->init != nullptr) { + p->init(channel_stack, grpc_channel_stack_element(channel_stack, i), + p->init_arg); + } + i++; + } + } + + grpc_channel_stack_builder_destroy(builder); + gpr_free(const_cast(filters)); + + return error; +} diff --git a/src/core/lib/channel/channel_trace.cc b/src/core/lib/channel/channel_trace.cc new file mode 100644 index 00000000..f5a825c0 --- /dev/null +++ b/src/core/lib/channel/channel_trace.cc @@ -0,0 +1,193 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/channel/channel_trace.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_core { +namespace channelz { + +ChannelTrace::TraceEvent::TraceEvent(Severity severity, const grpc_slice& data, + RefCountedPtr referenced_entity) + : severity_(severity), + data_(data), + timestamp_(grpc_millis_to_timespec(grpc_core::ExecCtx::Get()->Now(), + GPR_CLOCK_REALTIME)), + next_(nullptr), + referenced_entity_(std::move(referenced_entity)), + memory_usage_(sizeof(TraceEvent) + grpc_slice_memory_usage(data)) {} + +ChannelTrace::TraceEvent::TraceEvent(Severity severity, const grpc_slice& data) + : severity_(severity), + data_(data), + timestamp_(grpc_millis_to_timespec(grpc_core::ExecCtx::Get()->Now(), + GPR_CLOCK_REALTIME)), + next_(nullptr), + memory_usage_(sizeof(TraceEvent) + grpc_slice_memory_usage(data)) {} + +ChannelTrace::TraceEvent::~TraceEvent() { grpc_slice_unref_internal(data_); } + +ChannelTrace::ChannelTrace(size_t max_event_memory) + : num_events_logged_(0), + event_list_memory_usage_(0), + max_event_memory_(max_event_memory), + head_trace_(nullptr), + tail_trace_(nullptr) { + if (max_event_memory_ == 0) { + return; // tracing is disabled if max_event_memory_ == 0 + } + gpr_mu_init(&tracer_mu_); + time_created_ = grpc_millis_to_timespec(grpc_core::ExecCtx::Get()->Now(), + GPR_CLOCK_REALTIME); +} + +ChannelTrace::~ChannelTrace() { + if (max_event_memory_ == 0) { + return; // tracing is disabled if max_event_memory_ == 0 + } + TraceEvent* it = head_trace_; + while (it != nullptr) { + TraceEvent* to_free = it; + it = it->next(); + delete to_free; + } + gpr_mu_destroy(&tracer_mu_); +} + +void ChannelTrace::AddTraceEventHelper(TraceEvent* new_trace_event) { + ++num_events_logged_; + // first event case + if (head_trace_ == nullptr) { + head_trace_ = tail_trace_ = new_trace_event; + } + // regular event add case + else { + tail_trace_->set_next(new_trace_event); + tail_trace_ = tail_trace_->next(); + } + event_list_memory_usage_ += new_trace_event->memory_usage(); + // maybe garbage collect the tail until we are under the memory limit. 
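+  // Eviction starts at head_trace_, so the oldest events are dropped first
+  // until the tracked memory fits under max_event_memory_ again. (A
+  // max_event_memory_ of zero disables tracing entirely; callers check that
+  // before ever reaching this point.)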
+ while (event_list_memory_usage_ > max_event_memory_) { + TraceEvent* to_free = head_trace_; + event_list_memory_usage_ -= to_free->memory_usage(); + head_trace_ = head_trace_->next(); + delete to_free; + } +} + +void ChannelTrace::AddTraceEvent(Severity severity, const grpc_slice& data) { + if (max_event_memory_ == 0) { + grpc_slice_unref_internal(data); + return; // tracing is disabled if max_event_memory_ == 0 + } + AddTraceEventHelper(new TraceEvent(severity, data)); +} + +void ChannelTrace::AddTraceEventWithReference( + Severity severity, const grpc_slice& data, + RefCountedPtr referenced_entity) { + if (max_event_memory_ == 0) { + grpc_slice_unref_internal(data); + return; // tracing is disabled if max_event_memory_ == 0 + } + // create and fill up the new event + AddTraceEventHelper( + new TraceEvent(severity, data, std::move(referenced_entity))); +} + +namespace { + +const char* severity_string(ChannelTrace::Severity severity) { + switch (severity) { + case ChannelTrace::Severity::Info: + return "CT_INFO"; + case ChannelTrace::Severity::Warning: + return "CT_WARNING"; + case ChannelTrace::Severity::Error: + return "CT_ERROR"; + default: + GPR_UNREACHABLE_CODE(return "CT_UNKNOWN"); + } +} + +} // anonymous namespace + +Json ChannelTrace::TraceEvent::RenderTraceEvent() const { + char* description = grpc_slice_to_c_string(data_); + Json::Object object = { + {"description", description}, + {"severity", severity_string(severity_)}, + {"timestamp", gpr_format_timespec(timestamp_)}, + }; + gpr_free(description); + if (referenced_entity_ != nullptr) { + const bool is_channel = + (referenced_entity_->type() == BaseNode::EntityType::kTopLevelChannel || + referenced_entity_->type() == BaseNode::EntityType::kInternalChannel); + object[is_channel ? "channelRef" : "subchannelRef"] = Json::Object{ + {(is_channel ? "channelId" : "subchannelId"), + std::to_string(referenced_entity_->uuid())}, + }; + } + return object; +} + +Json ChannelTrace::RenderJson() const { + // Tracing is disabled if max_event_memory_ == 0. + if (max_event_memory_ == 0) { + return Json(); // JSON null + } + Json::Object object = { + {"creationTimestamp", gpr_format_timespec(time_created_)}, + }; + if (num_events_logged_ > 0) { + object["numEventsLogged"] = std::to_string(num_events_logged_); + } + // Only add in the event list if it is non-empty. + if (head_trace_ != nullptr) { + Json::Array array; + for (TraceEvent* it = head_trace_; it != nullptr; it = it->next()) { + array.emplace_back(it->RenderTraceEvent()); + } + object["events"] = std::move(array); + } + return object; +} + +} // namespace channelz +} // namespace grpc_core diff --git a/src/core/lib/channel/channelz.cc b/src/core/lib/channel/channelz.cc new file mode 100644 index 00000000..37dca3e2 --- /dev/null +++ b/src/core/lib/channel/channelz.cc @@ -0,0 +1,596 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/channel/channelz.h" + +#include +#include +#include + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/strip.h" + +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channelz_registry.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/slice/b64.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/uri/uri_parser.h" + +namespace grpc_core { +namespace channelz { + +// +// BaseNode +// + +BaseNode::BaseNode(EntityType type, std::string name) + : type_(type), uuid_(-1), name_(std::move(name)) { + // The registry will set uuid_ under its lock. + ChannelzRegistry::Register(this); +} + +BaseNode::~BaseNode() { ChannelzRegistry::Unregister(uuid_); } + +std::string BaseNode::RenderJsonString() { + Json json = RenderJson(); + return json.Dump(); +} + +// +// CallCountingHelper +// + +CallCountingHelper::CallCountingHelper() { + num_cores_ = std::max(1u, gpr_cpu_num_cores()); + per_cpu_counter_data_storage_.reserve(num_cores_); + for (size_t i = 0; i < num_cores_; ++i) { + per_cpu_counter_data_storage_.emplace_back(); + } +} + +void CallCountingHelper::RecordCallStarted() { + AtomicCounterData& data = + per_cpu_counter_data_storage_[ExecCtx::Get()->starting_cpu()]; + data.calls_started.fetch_add(1, std::memory_order_relaxed); + data.last_call_started_cycle.store(gpr_get_cycle_counter(), + std::memory_order_relaxed); +} + +void CallCountingHelper::RecordCallFailed() { + per_cpu_counter_data_storage_[ExecCtx::Get()->starting_cpu()] + .calls_failed.fetch_add(1, std::memory_order_relaxed); +} + +void CallCountingHelper::RecordCallSucceeded() { + per_cpu_counter_data_storage_[ExecCtx::Get()->starting_cpu()] + .calls_succeeded.fetch_add(1, std::memory_order_relaxed); +} + +void CallCountingHelper::CollectData(CounterData* out) { + for (size_t core = 0; core < num_cores_; ++core) { + AtomicCounterData& data = per_cpu_counter_data_storage_[core]; + + out->calls_started += data.calls_started.load(std::memory_order_relaxed); + out->calls_succeeded += + per_cpu_counter_data_storage_[core].calls_succeeded.load( + std::memory_order_relaxed); + out->calls_failed += per_cpu_counter_data_storage_[core].calls_failed.load( + std::memory_order_relaxed); + const gpr_cycle_counter last_call = + per_cpu_counter_data_storage_[core].last_call_started_cycle.load( + std::memory_order_relaxed); + if (last_call > out->last_call_started_cycle) { + out->last_call_started_cycle = last_call; + } + } +} + +void CallCountingHelper::PopulateCallCounts(Json::Object* json) { + CounterData data; + CollectData(&data); + if (data.calls_started != 0) { + (*json)["callsStarted"] = std::to_string(data.calls_started); + gpr_timespec ts = gpr_convert_clock_type( + gpr_cycle_counter_to_time(data.last_call_started_cycle), + GPR_CLOCK_REALTIME); + (*json)["lastCallStartedTimestamp"] = gpr_format_timespec(ts); + } + if (data.calls_succeeded != 0) { + (*json)["callsSucceeded"] = 
std::to_string(data.calls_succeeded); + } + if (data.calls_failed) { + (*json)["callsFailed"] = std::to_string(data.calls_failed); + } +} + +// +// ChannelNode +// + +ChannelNode::ChannelNode(std::string target, size_t channel_tracer_max_nodes, + bool is_internal_channel) + : BaseNode(is_internal_channel ? EntityType::kInternalChannel + : EntityType::kTopLevelChannel, + target), + target_(std::move(target)), + trace_(channel_tracer_max_nodes) {} + +const char* ChannelNode::GetChannelConnectivityStateChangeString( + grpc_connectivity_state state) { + switch (state) { + case GRPC_CHANNEL_IDLE: + return "Channel state change to IDLE"; + case GRPC_CHANNEL_CONNECTING: + return "Channel state change to CONNECTING"; + case GRPC_CHANNEL_READY: + return "Channel state change to READY"; + case GRPC_CHANNEL_TRANSIENT_FAILURE: + return "Channel state change to TRANSIENT_FAILURE"; + case GRPC_CHANNEL_SHUTDOWN: + return "Channel state change to SHUTDOWN"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +Json ChannelNode::RenderJson() { + Json::Object data = { + {"target", target_}, + }; + // Connectivity state. + // If low-order bit is on, then the field is set. + int state_field = connectivity_state_.load(std::memory_order_relaxed); + if ((state_field & 1) != 0) { + grpc_connectivity_state state = + static_cast(state_field >> 1); + data["state"] = Json::Object{ + {"state", ConnectivityStateName(state)}, + }; + } + // Fill in the channel trace if applicable. + Json trace_json = trace_.RenderJson(); + if (trace_json.type() != Json::Type::JSON_NULL) { + data["trace"] = std::move(trace_json); + } + // Ask CallCountingHelper to populate call count data. + call_counter_.PopulateCallCounts(&data); + // Construct outer object. + Json::Object json = { + {"ref", + Json::Object{ + {"channelId", std::to_string(uuid())}, + }}, + {"data", std::move(data)}, + }; + // Template method. Child classes may override this to add their specific + // functionality. + PopulateChildRefs(&json); + return json; +} + +void ChannelNode::PopulateChildRefs(Json::Object* json) { + MutexLock lock(&child_mu_); + if (!child_subchannels_.empty()) { + Json::Array array; + for (intptr_t subchannel_uuid : child_subchannels_) { + array.emplace_back(Json::Object{ + {"subchannelId", std::to_string(subchannel_uuid)}, + }); + } + (*json)["subchannelRef"] = std::move(array); + } + if (!child_channels_.empty()) { + Json::Array array; + for (intptr_t channel_uuid : child_channels_) { + array.emplace_back(Json::Object{ + {"channelId", std::to_string(channel_uuid)}, + }); + } + (*json)["channelRef"] = std::move(array); + } +} + +void ChannelNode::SetConnectivityState(grpc_connectivity_state state) { + // Store with low-order bit set to indicate that the field is set. 
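+  // Example: GRPC_CHANNEL_READY (2) is stored as (2 << 1) + 1 == 5. Readers
+  // test the low bit to see whether a state was ever recorded and shift right
+  // by one to recover the enum, which keeps "unset" distinguishable from
+  // GRPC_CHANNEL_IDLE (0) inside a single atomic int.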
+ int state_field = (state << 1) + 1; + connectivity_state_.store(state_field, std::memory_order_relaxed); +} + +void ChannelNode::AddChildChannel(intptr_t child_uuid) { + MutexLock lock(&child_mu_); + child_channels_.insert(child_uuid); +} + +void ChannelNode::RemoveChildChannel(intptr_t child_uuid) { + MutexLock lock(&child_mu_); + child_channels_.erase(child_uuid); +} + +void ChannelNode::AddChildSubchannel(intptr_t child_uuid) { + MutexLock lock(&child_mu_); + child_subchannels_.insert(child_uuid); +} + +void ChannelNode::RemoveChildSubchannel(intptr_t child_uuid) { + MutexLock lock(&child_mu_); + child_subchannels_.erase(child_uuid); +} + +// +// ServerNode +// + +ServerNode::ServerNode(size_t channel_tracer_max_nodes) + : BaseNode(EntityType::kServer, ""), trace_(channel_tracer_max_nodes) {} + +ServerNode::~ServerNode() {} + +void ServerNode::AddChildSocket(RefCountedPtr node) { + MutexLock lock(&child_mu_); + child_sockets_.insert(std::make_pair(node->uuid(), std::move(node))); +} + +void ServerNode::RemoveChildSocket(intptr_t child_uuid) { + MutexLock lock(&child_mu_); + child_sockets_.erase(child_uuid); +} + +void ServerNode::AddChildListenSocket(RefCountedPtr node) { + MutexLock lock(&child_mu_); + child_listen_sockets_.insert(std::make_pair(node->uuid(), std::move(node))); +} + +void ServerNode::RemoveChildListenSocket(intptr_t child_uuid) { + MutexLock lock(&child_mu_); + child_listen_sockets_.erase(child_uuid); +} + +std::string ServerNode::RenderServerSockets(intptr_t start_socket_id, + intptr_t max_results) { + GPR_ASSERT(start_socket_id >= 0); + GPR_ASSERT(max_results >= 0); + // If user does not set max_results, we choose 500. + size_t pagination_limit = max_results == 0 ? 500 : max_results; + Json::Object object; + { + MutexLock lock(&child_mu_); + size_t sockets_rendered = 0; + // Create list of socket refs. + Json::Array array; + auto it = child_sockets_.lower_bound(start_socket_id); + for (; it != child_sockets_.end() && sockets_rendered < pagination_limit; + ++it, ++sockets_rendered) { + array.emplace_back(Json::Object{ + {"socketId", std::to_string(it->first)}, + {"name", it->second->name()}, + }); + } + object["socketRef"] = std::move(array); + if (it == child_sockets_.end()) object["end"] = true; + } + Json json = std::move(object); + return json.Dump(); +} + +Json ServerNode::RenderJson() { + Json::Object data; + // Fill in the channel trace if applicable. + Json trace_json = trace_.RenderJson(); + if (trace_json.type() != Json::Type::JSON_NULL) { + data["trace"] = std::move(trace_json); + } + // Ask CallCountingHelper to populate call count data. + call_counter_.PopulateCallCounts(&data); + // Construct top-level object. + Json::Object object = { + {"ref", + Json::Object{ + {"serverId", std::to_string(uuid())}, + }}, + {"data", std::move(data)}, + }; + // Render listen sockets. 
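+  // Only references are emitted here; the full socket objects are retrieved
+  // through separate channelz queries. The fragment added below has the shape
+  // (values illustrative):
+  //   "listenSocket": [{"socketId": "7", "name": "..."}]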
+ { + MutexLock lock(&child_mu_); + if (!child_listen_sockets_.empty()) { + Json::Array array; + for (const auto& it : child_listen_sockets_) { + array.emplace_back(Json::Object{ + {"socketId", std::to_string(it.first)}, + {"name", it.second->name()}, + }); + } + object["listenSocket"] = std::move(array); + } + } + return object; +} + +// +// SocketNode::Security::Tls +// + +Json SocketNode::Security::Tls::RenderJson() { + Json::Object data; + if (type == NameType::kStandardName) { + data["standard_name"] = name; + } else if (type == NameType::kOtherName) { + data["other_name"] = name; + } + if (!local_certificate.empty()) { + data["local_certificate"] = absl::Base64Escape(local_certificate); + } + if (!remote_certificate.empty()) { + data["remote_certificate"] = absl::Base64Escape(remote_certificate); + } + return data; +} + +// +// SocketNode::Security +// + +Json SocketNode::Security::RenderJson() { + Json::Object data; + switch (type) { + case ModelType::kUnset: + break; + case ModelType::kTls: + if (tls) { + data["tls"] = tls->RenderJson(); + } + break; + case ModelType::kOther: + if (other) { + data["other"] = *other; + } + break; + } + return data; +} + +namespace { + +void* SecurityArgCopy(void* p) { + SocketNode::Security* xds_certificate_provider = + static_cast(p); + return xds_certificate_provider->Ref().release(); +} + +void SecurityArgDestroy(void* p) { + SocketNode::Security* xds_certificate_provider = + static_cast(p); + xds_certificate_provider->Unref(); +} + +int SecurityArgCmp(void* p, void* q) { return grpc_core::QsortCompare(p, q); } + +const grpc_arg_pointer_vtable kChannelArgVtable = { + SecurityArgCopy, SecurityArgDestroy, SecurityArgCmp}; + +} // namespace + +grpc_arg SocketNode::Security::MakeChannelArg() const { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_CHANNELZ_SECURITY), + const_cast(this), &kChannelArgVtable); +} + +RefCountedPtr SocketNode::Security::GetFromChannelArgs( + const grpc_channel_args* args) { + Security* security = grpc_channel_args_find_pointer( + args, GRPC_ARG_CHANNELZ_SECURITY); + return security != nullptr ? 
security->Ref() : nullptr; +} + +// +// SocketNode +// + +namespace { + +void PopulateSocketAddressJson(Json::Object* json, const char* name, + const char* addr_str) { + if (addr_str == nullptr) return; + Json::Object data; + absl::StatusOr uri = URI::Parse(addr_str); + if (uri.ok() && (uri->scheme() == "ipv4" || uri->scheme() == "ipv6")) { + std::string host; + std::string port; + GPR_ASSERT( + SplitHostPort(absl::StripPrefix(uri->path(), "/"), &host, &port)); + int port_num = -1; + if (!port.empty()) { + port_num = atoi(port.data()); + } + grpc_resolved_address resolved_host; + grpc_error_handle error = + grpc_string_to_sockaddr(&resolved_host, host.c_str(), port_num); + if (error == GRPC_ERROR_NONE) { + std::string packed_host = grpc_sockaddr_get_packed_host(&resolved_host); + std::string b64_host = absl::Base64Escape(packed_host); + data["tcpip_address"] = Json::Object{ + {"port", port_num}, + {"ip_address", b64_host}, + }; + (*json)[name] = std::move(data); + return; + } + GRPC_ERROR_UNREF(error); + } + if (uri.ok() && uri->scheme() == "unix") { + data["uds_address"] = Json::Object{ + {"filename", uri->path()}, + }; + } else { + data["other_address"] = Json::Object{ + {"name", addr_str}, + }; + } + (*json)[name] = std::move(data); +} + +} // namespace + +SocketNode::SocketNode(std::string local, std::string remote, std::string name, + RefCountedPtr security) + : BaseNode(EntityType::kSocket, std::move(name)), + local_(std::move(local)), + remote_(std::move(remote)), + security_(std::move(security)) {} + +void SocketNode::RecordStreamStartedFromLocal() { + streams_started_.fetch_add(1, std::memory_order_relaxed); + last_local_stream_created_cycle_.store(gpr_get_cycle_counter(), + std::memory_order_relaxed); +} + +void SocketNode::RecordStreamStartedFromRemote() { + streams_started_.fetch_add(1, std::memory_order_relaxed); + last_remote_stream_created_cycle_.store(gpr_get_cycle_counter(), + std::memory_order_relaxed); +} + +void SocketNode::RecordMessagesSent(uint32_t num_sent) { + messages_sent_.fetch_add(num_sent, std::memory_order_relaxed); + last_message_sent_cycle_.store(gpr_get_cycle_counter(), + std::memory_order_relaxed); +} + +void SocketNode::RecordMessageReceived() { + messages_received_.fetch_add(1, std::memory_order_relaxed); + last_message_received_cycle_.store(gpr_get_cycle_counter(), + std::memory_order_relaxed); +} + +Json SocketNode::RenderJson() { + // Create and fill the data child. 
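+  // Each counter below is emitted only when non-zero, and every
+  // "last...Timestamp" field is derived from the matching cycle counter
+  // converted to GPR_CLOCK_REALTIME, so the JSON reports wall-clock times.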
+ Json::Object data; + gpr_timespec ts; + int64_t streams_started = streams_started_.load(std::memory_order_relaxed); + if (streams_started != 0) { + data["streamsStarted"] = std::to_string(streams_started); + gpr_cycle_counter last_local_stream_created_cycle = + last_local_stream_created_cycle_.load(std::memory_order_relaxed); + if (last_local_stream_created_cycle != 0) { + ts = gpr_convert_clock_type( + gpr_cycle_counter_to_time(last_local_stream_created_cycle), + GPR_CLOCK_REALTIME); + data["lastLocalStreamCreatedTimestamp"] = gpr_format_timespec(ts); + } + gpr_cycle_counter last_remote_stream_created_cycle = + last_remote_stream_created_cycle_.load(std::memory_order_relaxed); + if (last_remote_stream_created_cycle != 0) { + ts = gpr_convert_clock_type( + gpr_cycle_counter_to_time(last_remote_stream_created_cycle), + GPR_CLOCK_REALTIME); + data["lastRemoteStreamCreatedTimestamp"] = gpr_format_timespec(ts); + } + } + int64_t streams_succeeded = + streams_succeeded_.load(std::memory_order_relaxed); + if (streams_succeeded != 0) { + data["streamsSucceeded"] = std::to_string(streams_succeeded); + } + int64_t streams_failed = streams_failed_.load(std::memory_order_relaxed); + if (streams_failed != 0) { + data["streamsFailed"] = std::to_string(streams_failed); + } + int64_t messages_sent = messages_sent_.load(std::memory_order_relaxed); + if (messages_sent != 0) { + data["messagesSent"] = std::to_string(messages_sent); + ts = gpr_convert_clock_type( + gpr_cycle_counter_to_time( + last_message_sent_cycle_.load(std::memory_order_relaxed)), + GPR_CLOCK_REALTIME); + data["lastMessageSentTimestamp"] = gpr_format_timespec(ts); + } + int64_t messages_received = + messages_received_.load(std::memory_order_relaxed); + if (messages_received != 0) { + data["messagesReceived"] = std::to_string(messages_received); + ts = gpr_convert_clock_type( + gpr_cycle_counter_to_time( + last_message_received_cycle_.load(std::memory_order_relaxed)), + GPR_CLOCK_REALTIME); + data["lastMessageReceivedTimestamp"] = gpr_format_timespec(ts); + } + int64_t keepalives_sent = keepalives_sent_.load(std::memory_order_relaxed); + if (keepalives_sent != 0) { + data["keepAlivesSent"] = std::to_string(keepalives_sent); + } + // Create and fill the parent object. + Json::Object object = { + {"ref", + Json::Object{ + {"socketId", std::to_string(uuid())}, + {"name", name()}, + }}, + {"data", std::move(data)}, + }; + if (security_ != nullptr && + security_->type != SocketNode::Security::ModelType::kUnset) { + object["security"] = security_->RenderJson(); + } + PopulateSocketAddressJson(&object, "remote", remote_.c_str()); + PopulateSocketAddressJson(&object, "local", local_.c_str()); + return object; +} + +// +// ListenSocketNode +// + +ListenSocketNode::ListenSocketNode(std::string local_addr, std::string name) + : BaseNode(EntityType::kSocket, std::move(name)), + local_addr_(std::move(local_addr)) {} + +Json ListenSocketNode::RenderJson() { + Json::Object object = { + {"ref", + Json::Object{ + {"socketId", std::to_string(uuid())}, + {"name", name()}, + }}, + }; + PopulateSocketAddressJson(&object, "local", local_addr_.c_str()); + return object; +} + +} // namespace channelz +} // namespace grpc_core diff --git a/src/core/lib/channel/channelz_registry.cc b/src/core/lib/channel/channelz_registry.cc new file mode 100644 index 00000000..3ef24792 --- /dev/null +++ b/src/core/lib/channel/channelz_registry.cc @@ -0,0 +1,285 @@ +/* + * + * Copyright 2017 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/channel/channelz_registry.h" + +#include +#include + +#include "absl/container/inlined_vector.h" + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_trace.h" +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/sync.h" + +namespace grpc_core { +namespace channelz { +namespace { + +// singleton instance of the registry. +ChannelzRegistry* g_channelz_registry = nullptr; + +const int kPaginationLimit = 100; + +} // anonymous namespace + +void ChannelzRegistry::Init() { g_channelz_registry = new ChannelzRegistry(); } + +void ChannelzRegistry::Shutdown() { delete g_channelz_registry; } + +ChannelzRegistry* ChannelzRegistry::Default() { + GPR_DEBUG_ASSERT(g_channelz_registry != nullptr); + return g_channelz_registry; +} + +void ChannelzRegistry::InternalRegister(BaseNode* node) { + MutexLock lock(&mu_); + node->uuid_ = ++uuid_generator_; + node_map_[node->uuid_] = node; +} + +void ChannelzRegistry::InternalUnregister(intptr_t uuid) { + GPR_ASSERT(uuid >= 1); + MutexLock lock(&mu_); + GPR_ASSERT(uuid <= uuid_generator_); + node_map_.erase(uuid); +} + +RefCountedPtr ChannelzRegistry::InternalGet(intptr_t uuid) { + MutexLock lock(&mu_); + if (uuid < 1 || uuid > uuid_generator_) { + return nullptr; + } + auto it = node_map_.find(uuid); + if (it == node_map_.end()) return nullptr; + // Found node. Return only if its refcount is not zero (i.e., when we + // know that there is no other thread about to destroy it). + BaseNode* node = it->second; + return node->RefIfNonZero(); +} + +std::string ChannelzRegistry::InternalGetTopChannels( + intptr_t start_channel_id) { + absl::InlinedVector, 10> top_level_channels; + RefCountedPtr node_after_pagination_limit; + { + MutexLock lock(&mu_); + for (auto it = node_map_.lower_bound(start_channel_id); + it != node_map_.end(); ++it) { + BaseNode* node = it->second; + RefCountedPtr node_ref; + if (node->type() == BaseNode::EntityType::kTopLevelChannel && + (node_ref = node->RefIfNonZero()) != nullptr) { + // Check if we are over pagination limit to determine if we need to set + // the "end" element. If we don't go through this block, we know that + // when the loop terminates, we have <= to kPaginationLimit. + // Note that because we have already increased this node's + // refcount, we need to decrease it, but we can't unref while + // holding the lock, because this may lead to a deadlock. + if (top_level_channels.size() == kPaginationLimit) { + node_after_pagination_limit = std::move(node_ref); + break; + } + top_level_channels.emplace_back(std::move(node_ref)); + } + } + } + Json::Object object; + if (!top_level_channels.empty()) { + // Create list of channels. 
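// A sketch of the rendered result, assuming three matching top-level channels
// and no pagination cutoff (channel bodies elided):
//
//   {"channel": [{...}, {...}, {...}], "end": true}
//
// When more than kPaginationLimit (100) channels remain, "end" is omitted and
// the caller is expected to query again starting from the next channel uuid.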
+ Json::Array array; + for (size_t i = 0; i < top_level_channels.size(); ++i) { + array.emplace_back(top_level_channels[i]->RenderJson()); + } + object["channel"] = std::move(array); + } + if (node_after_pagination_limit == nullptr) object["end"] = true; + Json json(std::move(object)); + return json.Dump(); +} + +std::string ChannelzRegistry::InternalGetServers(intptr_t start_server_id) { + absl::InlinedVector, 10> servers; + RefCountedPtr node_after_pagination_limit; + { + MutexLock lock(&mu_); + for (auto it = node_map_.lower_bound(start_server_id); + it != node_map_.end(); ++it) { + BaseNode* node = it->second; + RefCountedPtr node_ref; + if (node->type() == BaseNode::EntityType::kServer && + (node_ref = node->RefIfNonZero()) != nullptr) { + // Check if we are over pagination limit to determine if we need to set + // the "end" element. If we don't go through this block, we know that + // when the loop terminates, we have <= to kPaginationLimit. + // Note that because we have already increased this node's + // refcount, we need to decrease it, but we can't unref while + // holding the lock, because this may lead to a deadlock. + if (servers.size() == kPaginationLimit) { + node_after_pagination_limit = std::move(node_ref); + break; + } + servers.emplace_back(std::move(node_ref)); + } + } + } + Json::Object object; + if (!servers.empty()) { + // Create list of servers. + Json::Array array; + for (size_t i = 0; i < servers.size(); ++i) { + array.emplace_back(servers[i]->RenderJson()); + } + object["server"] = std::move(array); + } + if (node_after_pagination_limit == nullptr) object["end"] = true; + Json json(std::move(object)); + return json.Dump(); +} + +void ChannelzRegistry::InternalLogAllEntities() { + absl::InlinedVector, 10> nodes; + { + MutexLock lock(&mu_); + for (auto& p : node_map_) { + RefCountedPtr node = p.second->RefIfNonZero(); + if (node != nullptr) { + nodes.emplace_back(std::move(node)); + } + } + } + for (size_t i = 0; i < nodes.size(); ++i) { + std::string json = nodes[i]->RenderJsonString(); + gpr_log(GPR_INFO, "%s", json.c_str()); + } +} + +} // namespace channelz +} // namespace grpc_core + +char* grpc_channelz_get_top_channels(intptr_t start_channel_id) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + return gpr_strdup( + grpc_core::channelz::ChannelzRegistry::GetTopChannels(start_channel_id) + .c_str()); +} + +char* grpc_channelz_get_servers(intptr_t start_server_id) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + return gpr_strdup( + grpc_core::channelz::ChannelzRegistry::GetServers(start_server_id) + .c_str()); +} + +char* grpc_channelz_get_server(intptr_t server_id) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_core::RefCountedPtr server_node = + grpc_core::channelz::ChannelzRegistry::Get(server_id); + if (server_node == nullptr || + server_node->type() != + grpc_core::channelz::BaseNode::EntityType::kServer) { + return nullptr; + } + grpc_core::Json json = grpc_core::Json::Object{ + {"server", server_node->RenderJson()}, + }; + return gpr_strdup(json.Dump().c_str()); +} + +char* grpc_channelz_get_server_sockets(intptr_t server_id, + intptr_t start_socket_id, + intptr_t max_results) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + // Validate inputs before handing them of to the renderer. 
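A minimal caller-side sketch of the C wrappers above; it assumes grpc_init()
has already run (so the registry exists), and the helper name DumpChannelz is
made up. The key contract is that the returned string is heap-allocated and
must be released with gpr_free().

  #include <grpc/grpc.h>
  #include <grpc/support/alloc.h>
  #include <grpc/support/log.h>

  void DumpChannelz() {
    // 0 starts from the first top-level channel; see the pagination note above.
    char* json = grpc_channelz_get_top_channels(/*start_channel_id=*/0);
    if (json != nullptr) {
      gpr_log(GPR_INFO, "channelz top channels: %s", json);
      gpr_free(json);
    }
  }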
+ grpc_core::RefCountedPtr base_node = + grpc_core::channelz::ChannelzRegistry::Get(server_id); + if (base_node == nullptr || + base_node->type() != grpc_core::channelz::BaseNode::EntityType::kServer || + start_socket_id < 0 || max_results < 0) { + return nullptr; + } + // This cast is ok since we have just checked to make sure base_node is + // actually a server node. + grpc_core::channelz::ServerNode* server_node = + static_cast(base_node.get()); + return gpr_strdup( + server_node->RenderServerSockets(start_socket_id, max_results).c_str()); +} + +char* grpc_channelz_get_channel(intptr_t channel_id) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_core::RefCountedPtr channel_node = + grpc_core::channelz::ChannelzRegistry::Get(channel_id); + if (channel_node == nullptr || + (channel_node->type() != + grpc_core::channelz::BaseNode::EntityType::kTopLevelChannel && + channel_node->type() != + grpc_core::channelz::BaseNode::EntityType::kInternalChannel)) { + return nullptr; + } + grpc_core::Json json = grpc_core::Json::Object{ + {"channel", channel_node->RenderJson()}, + }; + return gpr_strdup(json.Dump().c_str()); +} + +char* grpc_channelz_get_subchannel(intptr_t subchannel_id) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_core::RefCountedPtr subchannel_node = + grpc_core::channelz::ChannelzRegistry::Get(subchannel_id); + if (subchannel_node == nullptr || + subchannel_node->type() != + grpc_core::channelz::BaseNode::EntityType::kSubchannel) { + return nullptr; + } + grpc_core::Json json = grpc_core::Json::Object{ + {"subchannel", subchannel_node->RenderJson()}, + }; + return gpr_strdup(json.Dump().c_str()); +} + +char* grpc_channelz_get_socket(intptr_t socket_id) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_core::RefCountedPtr socket_node = + grpc_core::channelz::ChannelzRegistry::Get(socket_id); + if (socket_node == nullptr || + socket_node->type() != + grpc_core::channelz::BaseNode::EntityType::kSocket) { + return nullptr; + } + grpc_core::Json json = grpc_core::Json::Object{ + {"socket", socket_node->RenderJson()}, + }; + return gpr_strdup(json.Dump().c_str()); +} diff --git a/src/core/lib/channel/connected_channel.cc b/src/core/lib/channel/connected_channel.cc new file mode 100644 index 00000000..2a03c9cf --- /dev/null +++ b/src/core/lib/channel/connected_channel.cc @@ -0,0 +1,245 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/channel/connected_channel.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/transport/transport.h" + +#define MAX_BUFFER_LENGTH 8192 + +typedef struct connected_channel_channel_data { + grpc_transport* transport; +} channel_data; + +struct callback_state { + grpc_closure closure; + grpc_closure* original_closure; + grpc_core::CallCombiner* call_combiner; + const char* reason; +}; +typedef struct connected_channel_call_data { + grpc_core::CallCombiner* call_combiner; + // Closures used for returning results on the call combiner. + callback_state on_complete[6]; // Max number of pending batches. + callback_state recv_initial_metadata_ready; + callback_state recv_message_ready; + callback_state recv_trailing_metadata_ready; +} call_data; + +static void run_in_call_combiner(void* arg, grpc_error_handle error) { + callback_state* state = static_cast(arg); + GRPC_CALL_COMBINER_START(state->call_combiner, state->original_closure, + GRPC_ERROR_REF(error), state->reason); +} + +static void run_cancel_in_call_combiner(void* arg, grpc_error_handle error) { + run_in_call_combiner(arg, error); + gpr_free(arg); +} + +static void intercept_callback(call_data* calld, callback_state* state, + bool free_when_done, const char* reason, + grpc_closure** original_closure) { + state->original_closure = *original_closure; + state->call_combiner = calld->call_combiner; + state->reason = reason; + *original_closure = GRPC_CLOSURE_INIT( + &state->closure, + free_when_done ? run_cancel_in_call_combiner : run_in_call_combiner, + state, grpc_schedule_on_exec_ctx); +} + +static callback_state* get_state_for_batch( + call_data* calld, grpc_transport_stream_op_batch* batch) { + if (batch->send_initial_metadata) return &calld->on_complete[0]; + if (batch->send_message) return &calld->on_complete[1]; + if (batch->send_trailing_metadata) return &calld->on_complete[2]; + if (batch->recv_initial_metadata) return &calld->on_complete[3]; + if (batch->recv_message) return &calld->on_complete[4]; + if (batch->recv_trailing_metadata) return &calld->on_complete[5]; + GPR_UNREACHABLE_CODE(return nullptr); +} + +/* We perform a small hack to locate transport data alongside the connected + channel data in call allocations, to allow everything to be pulled in minimal + cache line requests */ +#define TRANSPORT_STREAM_FROM_CALL_DATA(calld) \ + ((grpc_stream*)(((char*)(calld)) + \ + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data)))) +#define CALL_DATA_FROM_TRANSPORT_STREAM(transport_stream) \ + ((call_data*)(((char*)(transport_stream)) - \ + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data)))) + +/* Intercept a call operation and either push it directly up or translate it + into transport stream operations */ +static void connected_channel_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + call_data* calld = static_cast(elem->call_data); + channel_data* chand = static_cast(elem->channel_data); + if (batch->recv_initial_metadata) { + callback_state* state = &calld->recv_initial_metadata_ready; + intercept_callback( + calld, state, false, "recv_initial_metadata_ready", + &batch->payload->recv_initial_metadata.recv_initial_metadata_ready); + } + if (batch->recv_message) { + callback_state* state = &calld->recv_message_ready; + intercept_callback(calld, state, false, "recv_message_ready", + 
&batch->payload->recv_message.recv_message_ready); + } + if (batch->recv_trailing_metadata) { + callback_state* state = &calld->recv_trailing_metadata_ready; + intercept_callback( + calld, state, false, "recv_trailing_metadata_ready", + &batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready); + } + if (batch->cancel_stream) { + // There can be more than one cancellation batch in flight at any + // given time, so we can't just pick out a fixed index into + // calld->on_complete like we can for the other ops. However, + // cancellation isn't in the fast path, so we just allocate a new + // closure for each one. + callback_state* state = + static_cast(gpr_malloc(sizeof(*state))); + intercept_callback(calld, state, true, "on_complete (cancel_stream)", + &batch->on_complete); + } else if (batch->on_complete != nullptr) { + callback_state* state = get_state_for_batch(calld, batch); + intercept_callback(calld, state, false, "on_complete", &batch->on_complete); + } + grpc_transport_perform_stream_op( + chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld), batch); + GRPC_CALL_COMBINER_STOP(calld->call_combiner, "passed batch to transport"); +} + +static void connected_channel_start_transport_op(grpc_channel_element* elem, + grpc_transport_op* op) { + channel_data* chand = static_cast(elem->channel_data); + grpc_transport_perform_op(chand->transport, op); +} + +/* Constructor for call_data */ +static grpc_error_handle connected_channel_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + call_data* calld = static_cast(elem->call_data); + channel_data* chand = static_cast(elem->channel_data); + calld->call_combiner = args->call_combiner; + int r = grpc_transport_init_stream( + chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld), + &args->call_stack->refcount, args->server_transport_data, args->arena); + return r == 0 ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "transport stream initialization failed"); +} + +static void set_pollset_or_pollset_set(grpc_call_element* elem, + grpc_polling_entity* pollent) { + call_data* calld = static_cast(elem->call_data); + channel_data* chand = static_cast(elem->channel_data); + grpc_transport_set_pops(chand->transport, + TRANSPORT_STREAM_FROM_CALL_DATA(calld), pollent); +} + +/* Destructor for call_data */ +static void connected_channel_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* then_schedule_closure) { + call_data* calld = static_cast(elem->call_data); + channel_data* chand = static_cast(elem->channel_data); + grpc_transport_destroy_stream(chand->transport, + TRANSPORT_STREAM_FROM_CALL_DATA(calld), + then_schedule_closure); +} + +/* Constructor for channel_data */ +static grpc_error_handle connected_channel_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + channel_data* cd = static_cast(elem->channel_data); + GPR_ASSERT(args->is_last); + cd->transport = nullptr; + return GRPC_ERROR_NONE; +} + +/* Destructor for channel_data */ +static void connected_channel_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* cd = static_cast(elem->channel_data); + if (cd->transport) { + grpc_transport_destroy(cd->transport); + } +} + +/* No-op. 
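A sketch of the call-allocation layout that the TRANSPORT_STREAM_FROM_CALL_DATA
and CALL_DATA_FROM_TRANSPORT_STREAM macros above rely on (the concrete
transport stream type is whatever grpc_transport_stream_size() accounts for,
e.g. a chttp2 stream; that example is an assumption, not something this file
pins down):

  [ call_data for this filter ][ padding up to GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data)) ][ transport-specific grpc_stream ]

TRANSPORT_STREAM_FROM_CALL_DATA() steps forward from the filter's call data to
the transport stream, and CALL_DATA_FROM_TRANSPORT_STREAM() is its inverse.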
*/ +static void connected_channel_get_channel_info( + grpc_channel_element* /*elem*/, const grpc_channel_info* /*channel_info*/) { +} + +const grpc_channel_filter grpc_connected_filter = { + connected_channel_start_transport_stream_op_batch, + connected_channel_start_transport_op, + sizeof(call_data), + connected_channel_init_call_elem, + set_pollset_or_pollset_set, + connected_channel_destroy_call_elem, + sizeof(channel_data), + connected_channel_init_channel_elem, + connected_channel_destroy_channel_elem, + connected_channel_get_channel_info, + "connected", +}; + +static void bind_transport(grpc_channel_stack* channel_stack, + grpc_channel_element* elem, void* t) { + channel_data* cd = static_cast(elem->channel_data); + GPR_ASSERT(elem->filter == &grpc_connected_filter); + GPR_ASSERT(cd->transport == nullptr); + cd->transport = static_cast(t); + + /* HACK(ctiller): increase call stack size for the channel to make space + for channel data. We need a cleaner (but performant) way to do this, + and I'm not sure what that is yet. + This is only "safe" because call stacks place no additional data after + the last call element, and the last call element MUST be the connected + channel. */ + channel_stack->call_stack_size += + grpc_transport_stream_size(static_cast(t)); +} + +bool grpc_add_connected_filter(grpc_channel_stack_builder* builder) { + grpc_transport* t = grpc_channel_stack_builder_get_transport(builder); + GPR_ASSERT(t != nullptr); + return grpc_channel_stack_builder_append_filter( + builder, &grpc_connected_filter, bind_transport, t); +} + +grpc_stream* grpc_connected_channel_get_stream(grpc_call_element* elem) { + call_data* calld = static_cast(elem->call_data); + return TRANSPORT_STREAM_FROM_CALL_DATA(calld); +} diff --git a/src/core/lib/channel/handshaker.cc b/src/core/lib/channel/handshaker.cc new file mode 100644 index 00000000..7ebed4b9 --- /dev/null +++ b/src/core/lib/channel/handshaker.cc @@ -0,0 +1,222 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/channel/handshaker.h" + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +TraceFlag grpc_handshaker_trace(false, "handshaker"); + +namespace { + +std::string HandshakerArgsString(HandshakerArgs* args) { + size_t num_args = args->args != nullptr ? args->args->num_args : 0; + size_t read_buffer_length = + args->read_buffer != nullptr ? 
args->read_buffer->length : 0; + return absl::StrFormat( + "{endpoint=%p, args=%p {size=%" PRIuPTR + ": %s}, read_buffer=%p (length=%" PRIuPTR "), exit_early=%d}", + args->endpoint, args->args, num_args, + grpc_channel_args_string(args->args), args->read_buffer, + read_buffer_length, args->exit_early); +} + +} // namespace + +HandshakeManager::HandshakeManager() {} + +void HandshakeManager::Add(RefCountedPtr handshaker) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_handshaker_trace)) { + gpr_log( + GPR_INFO, + "handshake_manager %p: adding handshaker %s [%p] at index %" PRIuPTR, + this, handshaker->name(), handshaker.get(), handshakers_.size()); + } + MutexLock lock(&mu_); + handshakers_.push_back(std::move(handshaker)); +} + +HandshakeManager::~HandshakeManager() { handshakers_.clear(); } + +void HandshakeManager::Shutdown(grpc_error_handle why) { + { + MutexLock lock(&mu_); + // Shutdown the handshaker that's currently in progress, if any. + if (!is_shutdown_ && index_ > 0) { + is_shutdown_ = true; + handshakers_[index_ - 1]->Shutdown(GRPC_ERROR_REF(why)); + } + } + GRPC_ERROR_UNREF(why); +} + +// Helper function to call either the next handshaker or the +// on_handshake_done callback. +// Returns true if we've scheduled the on_handshake_done callback. +bool HandshakeManager::CallNextHandshakerLocked(grpc_error_handle error) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_handshaker_trace)) { + gpr_log(GPR_INFO, + "handshake_manager %p: error=%s shutdown=%d index=%" PRIuPTR + ", args=%s", + this, grpc_error_std_string(error).c_str(), is_shutdown_, index_, + HandshakerArgsString(&args_).c_str()); + } + GPR_ASSERT(index_ <= handshakers_.size()); + // If we got an error or we've been shut down or we're exiting early or + // we've finished the last handshaker, invoke the on_handshake_done + // callback. Otherwise, call the next handshaker. + if (error != GRPC_ERROR_NONE || is_shutdown_ || args_.exit_early || + index_ == handshakers_.size()) { + if (error == GRPC_ERROR_NONE && is_shutdown_) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("handshaker shutdown"); + // It is possible that the endpoint has already been destroyed by + // a shutdown call while this callback was sitting on the ExecCtx + // with no error. + if (args_.endpoint != nullptr) { + // TODO(roth): It is currently necessary to shutdown endpoints + // before destroying then, even when we know that there are no + // pending read/write callbacks. This should be fixed, at which + // point this can be removed. + grpc_endpoint_shutdown(args_.endpoint, GRPC_ERROR_REF(error)); + grpc_endpoint_destroy(args_.endpoint); + args_.endpoint = nullptr; + grpc_channel_args_destroy(args_.args); + args_.args = nullptr; + grpc_slice_buffer_destroy_internal(args_.read_buffer); + gpr_free(args_.read_buffer); + args_.read_buffer = nullptr; + } + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_handshaker_trace)) { + gpr_log(GPR_INFO, + "handshake_manager %p: handshaking complete -- scheduling " + "on_handshake_done with error=%s", + this, grpc_error_std_string(error).c_str()); + } + // Cancel deadline timer, since we're invoking the on_handshake_done + // callback now. 
+ grpc_timer_cancel(&deadline_timer_); + ExecCtx::Run(DEBUG_LOCATION, &on_handshake_done_, error); + is_shutdown_ = true; + } else { + auto handshaker = handshakers_[index_]; + if (GRPC_TRACE_FLAG_ENABLED(grpc_handshaker_trace)) { + gpr_log( + GPR_INFO, + "handshake_manager %p: calling handshaker %s [%p] at index %" PRIuPTR, + this, handshaker->name(), handshaker.get(), index_); + } + handshaker->DoHandshake(acceptor_, &call_next_handshaker_, &args_); + } + ++index_; + return is_shutdown_; +} + +void HandshakeManager::CallNextHandshakerFn(void* arg, + grpc_error_handle error) { + auto* mgr = static_cast(arg); + bool done; + { + MutexLock lock(&mgr->mu_); + done = mgr->CallNextHandshakerLocked(GRPC_ERROR_REF(error)); + } + // If we're invoked the final callback, we won't be coming back + // to this function, so we can release our reference to the + // handshake manager. + if (done) { + mgr->Unref(); + } +} + +void HandshakeManager::OnTimeoutFn(void* arg, grpc_error_handle error) { + auto* mgr = static_cast(arg); + if (error == GRPC_ERROR_NONE) { // Timer fired, rather than being cancelled + mgr->Shutdown(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake timed out")); + } + mgr->Unref(); +} + +void HandshakeManager::DoHandshake(grpc_endpoint* endpoint, + const grpc_channel_args* channel_args, + grpc_millis deadline, + grpc_tcp_server_acceptor* acceptor, + grpc_iomgr_cb_func on_handshake_done, + void* user_data) { + bool done; + { + MutexLock lock(&mu_); + GPR_ASSERT(index_ == 0); + // Construct handshaker args. These will be passed through all + // handshakers and eventually be freed by the on_handshake_done callback. + args_.endpoint = endpoint; + args_.args = grpc_channel_args_copy(channel_args); + args_.user_data = user_data; + args_.read_buffer = + static_cast(gpr_malloc(sizeof(*args_.read_buffer))); + grpc_slice_buffer_init(args_.read_buffer); + if (acceptor != nullptr && acceptor->external_connection && + acceptor->pending_data != nullptr) { + grpc_slice_buffer_swap(args_.read_buffer, + &(acceptor->pending_data->data.raw.slice_buffer)); + } + // Initialize state needed for calling handshakers. + acceptor_ = acceptor; + GRPC_CLOSURE_INIT(&call_next_handshaker_, + &HandshakeManager::CallNextHandshakerFn, this, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_handshake_done_, on_handshake_done, &args_, + grpc_schedule_on_exec_ctx); + // Start deadline timer, which owns a ref. + Ref().release(); + GRPC_CLOSURE_INIT(&on_timeout_, &HandshakeManager::OnTimeoutFn, this, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&deadline_timer_, deadline, &on_timeout_); + // Start first handshaker, which also owns a ref. + Ref().release(); + done = CallNextHandshakerLocked(GRPC_ERROR_NONE); + } + if (done) { + Unref(); + } +} + +} // namespace grpc_core + +void grpc_handshake_manager_add(grpc_handshake_manager* mgr, + grpc_handshaker* handshaker) { + // This is a transition method to aid the API change for handshakers. + grpc_core::RefCountedPtr refd_hs( + static_cast(handshaker)); + mgr->Add(refd_hs); +} diff --git a/src/core/lib/channel/handshaker_registry.cc b/src/core/lib/channel/handshaker_registry.cc new file mode 100644 index 00000000..d9d7bc9d --- /dev/null +++ b/src/core/lib/channel/handshaker_registry.cc @@ -0,0 +1,50 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
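For orientation, a sketch of the interface a handshaker must provide, inferred
only from how handshakers_ is used in this file; the authoritative declaration
lives in handshaker.h and may differ in detail:

  class Handshaker : public RefCounted<Handshaker> {
   public:
    ~Handshaker() override = default;
    // Shown in the trace logging above.
    virtual const char* name() const = 0;
    // Must eventually schedule on_handshake_done (here: call_next_handshaker_).
    virtual void DoHandshake(grpc_tcp_server_acceptor* acceptor,
                             grpc_closure* on_handshake_done,
                             HandshakerArgs* args) = 0;
    // Aborts in-flight work; takes ownership of `why`.
    virtual void Shutdown(grpc_error_handle why) = 0;
  };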
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/channel/handshaker_registry.h" + +namespace grpc_core { + +void HandshakerRegistry::Builder::RegisterHandshakerFactory( + bool at_start, HandshakerType handshaker_type, + std::unique_ptr factory) { + auto& vec = factories_[handshaker_type]; + auto where = at_start ? vec.begin() : vec.end(); + vec.insert(where, std::move(factory)); +} + +HandshakerRegistry HandshakerRegistry::Builder::Build() { + HandshakerRegistry out; + for (size_t i = 0; i < NUM_HANDSHAKER_TYPES; i++) { + out.factories_[i] = std::move(factories_[i]); + } + return out; +} + +void HandshakerRegistry::AddHandshakers(HandshakerType handshaker_type, + const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) const { + for (const auto& factory : factories_[handshaker_type]) { + factory->AddHandshakers(args, interested_parties, handshake_mgr); + } +} + +} // namespace grpc_core diff --git a/src/core/lib/channel/status_util.cc b/src/core/lib/channel/status_util.cc new file mode 100644 index 00000000..0c60030d --- /dev/null +++ b/src/core/lib/channel/status_util.cc @@ -0,0 +1,109 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
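A sketch of how a component might plug into this registry, assuming the factory
base class is HandshakerFactory with the AddHandshakers() signature used above
and that HANDSHAKER_CLIENT is one of the HandshakerType values; both are
inferences, not something this file declares:

  class MyHandshakerFactory : public HandshakerFactory {
   public:
    void AddHandshakers(const grpc_channel_args* args,
                        grpc_pollset_set* interested_parties,
                        HandshakeManager* handshake_mgr) override {
      // e.g. handshake_mgr->Add(MakeRefCounted<MyHandshaker>(args));
    }
  };

  // At registry-construction time:
  //   builder.RegisterHandshakerFactory(/*at_start=*/true, HANDSHAKER_CLIENT,
  //                                     absl::make_unique<MyHandshakerFactory>());
  // Passing at_start=true prepends the factory, so its handshakers are added
  // (and therefore run) before those of factories registered earlier for the
  // same handshaker type.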
+ * + */ + +#include + +#include "src/core/lib/channel/status_util.h" + +#include "src/core/lib/gpr/useful.h" + +struct status_string_entry { + const char* str; + grpc_status_code status; +}; +static const status_string_entry g_status_string_entries[] = { + {"OK", GRPC_STATUS_OK}, + {"CANCELLED", GRPC_STATUS_CANCELLED}, + {"UNKNOWN", GRPC_STATUS_UNKNOWN}, + {"INVALID_ARGUMENT", GRPC_STATUS_INVALID_ARGUMENT}, + {"DEADLINE_EXCEEDED", GRPC_STATUS_DEADLINE_EXCEEDED}, + {"NOT_FOUND", GRPC_STATUS_NOT_FOUND}, + {"ALREADY_EXISTS", GRPC_STATUS_ALREADY_EXISTS}, + {"PERMISSION_DENIED", GRPC_STATUS_PERMISSION_DENIED}, + {"UNAUTHENTICATED", GRPC_STATUS_UNAUTHENTICATED}, + {"RESOURCE_EXHAUSTED", GRPC_STATUS_RESOURCE_EXHAUSTED}, + {"FAILED_PRECONDITION", GRPC_STATUS_FAILED_PRECONDITION}, + {"ABORTED", GRPC_STATUS_ABORTED}, + {"OUT_OF_RANGE", GRPC_STATUS_OUT_OF_RANGE}, + {"UNIMPLEMENTED", GRPC_STATUS_UNIMPLEMENTED}, + {"INTERNAL", GRPC_STATUS_INTERNAL}, + {"UNAVAILABLE", GRPC_STATUS_UNAVAILABLE}, + {"DATA_LOSS", GRPC_STATUS_DATA_LOSS}, +}; + +bool grpc_status_code_from_string(const char* status_str, + grpc_status_code* status) { + for (size_t i = 0; i < GPR_ARRAY_SIZE(g_status_string_entries); ++i) { + if (strcmp(status_str, g_status_string_entries[i].str) == 0) { + *status = g_status_string_entries[i].status; + return true; + } + } + return false; +} + +const char* grpc_status_code_to_string(grpc_status_code status) { + switch (status) { + case GRPC_STATUS_OK: + return "OK"; + case GRPC_STATUS_CANCELLED: + return "CANCELLED"; + case GRPC_STATUS_UNKNOWN: + return "UNKNOWN"; + case GRPC_STATUS_INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case GRPC_STATUS_DEADLINE_EXCEEDED: + return "DEADLINE_EXCEEDED"; + case GRPC_STATUS_NOT_FOUND: + return "NOT_FOUND"; + case GRPC_STATUS_ALREADY_EXISTS: + return "ALREADY_EXISTS"; + case GRPC_STATUS_PERMISSION_DENIED: + return "PERMISSION_DENIED"; + case GRPC_STATUS_RESOURCE_EXHAUSTED: + return "RESOURCE_EXHAUSTED"; + case GRPC_STATUS_FAILED_PRECONDITION: + return "FAILED_PRECONDITION"; + case GRPC_STATUS_ABORTED: + return "ABORTED"; + case GRPC_STATUS_OUT_OF_RANGE: + return "OUT_OF_RANGE"; + case GRPC_STATUS_UNIMPLEMENTED: + return "UNIMPLEMENTED"; + case GRPC_STATUS_INTERNAL: + return "INTERNAL"; + case GRPC_STATUS_UNAVAILABLE: + return "UNAVAILABLE"; + case GRPC_STATUS_DATA_LOSS: + return "DATA_LOSS"; + case GRPC_STATUS_UNAUTHENTICATED: + return "UNAUTHENTICATED"; + default: + return "UNKNOWN"; + } +} + +bool grpc_status_code_from_int(int status_int, grpc_status_code* status) { + // The range of status code enum is [0, 16], 0 is OK, 16 is UNAUTHENTICATED. + if (status_int < GRPC_STATUS_OK || status_int > GRPC_STATUS_UNAUTHENTICATED) { + *status = GRPC_STATUS_UNKNOWN; + return false; + } + *status = static_cast(status_int); + return true; +} diff --git a/src/core/lib/compression/compression.cc b/src/core/lib/compression/compression.cc new file mode 100644 index 00000000..3d5d11f6 --- /dev/null +++ b/src/core/lib/compression/compression.cc @@ -0,0 +1,183 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
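A minimal usage sketch for the status helpers above; the "UNAVAILABLE" literal
is just an example of the strings accepted (they match the canonical status
names):

  #include <grpc/status.h>
  #include <grpc/support/log.h>

  #include "src/core/lib/channel/status_util.h"

  void StatusUtilExample() {
    grpc_status_code code;
    if (grpc_status_code_from_string("UNAVAILABLE", &code)) {
      gpr_log(GPR_INFO, "parsed: %s", grpc_status_code_to_string(code));
    }
    // Integers outside [GRPC_STATUS_OK, GRPC_STATUS_UNAUTHENTICATED] are
    // rejected and reported as UNKNOWN.
    if (!grpc_status_code_from_int(42, &code)) {
      GPR_ASSERT(code == GRPC_STATUS_UNKNOWN);
    }
  }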
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include + +#include "src/core/lib/compression/algorithm_metadata.h" +#include "src/core/lib/compression/compression_internal.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/slice/slice_utils.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/transport/static_metadata.h" + +int grpc_compression_algorithm_is_message( + grpc_compression_algorithm algorithm) { + return (algorithm >= GRPC_COMPRESS_DEFLATE && algorithm <= GRPC_COMPRESS_GZIP) + ? 1 + : 0; +} + +int grpc_compression_algorithm_is_stream(grpc_compression_algorithm algorithm) { + return (algorithm == GRPC_COMPRESS_STREAM_GZIP) ? 1 : 0; +} + +int grpc_compression_algorithm_parse(grpc_slice name, + grpc_compression_algorithm* algorithm) { + if (grpc_slice_eq_static_interned(name, GRPC_MDSTR_IDENTITY)) { + *algorithm = GRPC_COMPRESS_NONE; + return 1; + } else if (grpc_slice_eq_static_interned(name, GRPC_MDSTR_DEFLATE)) { + *algorithm = GRPC_COMPRESS_DEFLATE; + return 1; + } else if (grpc_slice_eq_static_interned(name, GRPC_MDSTR_GZIP)) { + *algorithm = GRPC_COMPRESS_GZIP; + return 1; + } else if (grpc_slice_eq_static_interned(name, + GRPC_MDSTR_STREAM_SLASH_GZIP)) { + *algorithm = GRPC_COMPRESS_STREAM_GZIP; + return 1; + } else { + return 0; + } +} + +int grpc_compression_algorithm_name(grpc_compression_algorithm algorithm, + const char** name) { + GRPC_API_TRACE("grpc_compression_algorithm_name(algorithm=%d, name=%p)", 2, + ((int)algorithm, name)); + switch (algorithm) { + case GRPC_COMPRESS_NONE: + *name = "identity"; + return 1; + case GRPC_COMPRESS_DEFLATE: + *name = "deflate"; + return 1; + case GRPC_COMPRESS_GZIP: + *name = "gzip"; + return 1; + case GRPC_COMPRESS_STREAM_GZIP: + *name = "stream/gzip"; + return 1; + case GRPC_COMPRESS_ALGORITHMS_COUNT: + return 0; + } + return 0; +} + +grpc_compression_algorithm grpc_compression_algorithm_for_level( + grpc_compression_level level, uint32_t accepted_encodings) { + grpc_compression_algorithm algo; + if (level == GRPC_COMPRESS_LEVEL_NONE) { + return GRPC_COMPRESS_NONE; + } else if (level <= GRPC_COMPRESS_LEVEL_HIGH) { + // TODO(mxyan): Design algorithm to select from all algorithms, including + // stream compression algorithm + if (!grpc_compression_algorithm_from_message_stream_compression_algorithm( + &algo, + grpc_message_compression_algorithm_for_level( + level, + grpc_compression_bitset_to_message_bitset(accepted_encodings)), + static_cast(0))) { + gpr_log(GPR_ERROR, "Parse compression level error"); + return GRPC_COMPRESS_NONE; + } + return algo; + } else { + gpr_log(GPR_ERROR, "Unknown compression level: %d", level); + return GRPC_COMPRESS_NONE; + } +} + +void grpc_compression_options_init(grpc_compression_options* opts) { + memset(opts, 0, sizeof(*opts)); + /* all enabled by default */ + opts->enabled_algorithms_bitset = (1u << GRPC_COMPRESS_ALGORITHMS_COUNT) - 1; +} + +void grpc_compression_options_enable_algorithm( + grpc_compression_options* opts, grpc_compression_algorithm algorithm) { + grpc_core::SetBit(&opts->enabled_algorithms_bitset, algorithm); +} + +void grpc_compression_options_disable_algorithm( + grpc_compression_options* opts, grpc_compression_algorithm algorithm) { + grpc_core::ClearBit(&opts->enabled_algorithms_bitset, algorithm); +} + +int grpc_compression_options_is_algorithm_enabled( + const grpc_compression_options* opts, + grpc_compression_algorithm algorithm) { + 
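A sketch of exercising the level-to-algorithm mapping defined above; the
accepted-encodings bitset and the HIGH level are illustrative choices:

  #include <stdint.h>

  #include <grpc/compression.h>
  #include <grpc/support/log.h>

  void CompressionLevelExample() {
    grpc_compression_options opts;
    grpc_compression_options_init(&opts);  // all algorithms enabled by default
    grpc_compression_options_disable_algorithm(&opts, GRPC_COMPRESS_GZIP);

    // Restrict the choice to what a peer advertised: identity + deflate.
    uint32_t accepted =
        (1u << GRPC_COMPRESS_NONE) | (1u << GRPC_COMPRESS_DEFLATE);
    grpc_compression_algorithm algo =
        grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_HIGH, accepted);
    const char* name = nullptr;
    if (grpc_compression_algorithm_name(algo, &name)) {
      gpr_log(GPR_INFO, "selected: %s", name);  // expected to be "deflate" here
    }
  }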
return grpc_compression_options_is_algorithm_enabled_internal(opts, + algorithm); +} + +grpc_slice grpc_compression_algorithm_slice( + grpc_compression_algorithm algorithm) { + switch (algorithm) { + case GRPC_COMPRESS_NONE: + return GRPC_MDSTR_IDENTITY; + case GRPC_COMPRESS_DEFLATE: + return GRPC_MDSTR_DEFLATE; + case GRPC_COMPRESS_GZIP: + return GRPC_MDSTR_GZIP; + case GRPC_COMPRESS_STREAM_GZIP: + return GRPC_MDSTR_STREAM_SLASH_GZIP; + case GRPC_COMPRESS_ALGORITHMS_COUNT: + return grpc_empty_slice(); + } + return grpc_empty_slice(); +} + +grpc_compression_algorithm grpc_compression_algorithm_from_slice( + const grpc_slice& str) { + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_IDENTITY)) { + return GRPC_COMPRESS_NONE; + } + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_DEFLATE)) { + return GRPC_COMPRESS_DEFLATE; + } + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_GZIP)) { + return GRPC_COMPRESS_GZIP; + } + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_STREAM_SLASH_GZIP)) { + return GRPC_COMPRESS_STREAM_GZIP; + } + return GRPC_COMPRESS_ALGORITHMS_COUNT; +} + +grpc_mdelem grpc_compression_encoding_mdelem( + grpc_compression_algorithm algorithm) { + switch (algorithm) { + case GRPC_COMPRESS_NONE: + return GRPC_MDELEM_GRPC_ENCODING_IDENTITY; + case GRPC_COMPRESS_DEFLATE: + return GRPC_MDELEM_GRPC_ENCODING_DEFLATE; + case GRPC_COMPRESS_GZIP: + return GRPC_MDELEM_GRPC_ENCODING_GZIP; + case GRPC_COMPRESS_STREAM_GZIP: + return GRPC_MDELEM_GRPC_ENCODING_GZIP; + default: + break; + } + return GRPC_MDNULL; +} diff --git a/src/core/lib/compression/compression_args.cc b/src/core/lib/compression/compression_args.cc new file mode 100644 index 00000000..ee55c2f3 --- /dev/null +++ b/src/core/lib/compression/compression_args.cc @@ -0,0 +1,138 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/compression/compression_args.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" + +grpc_compression_algorithm +grpc_channel_args_get_channel_default_compression_algorithm( + const grpc_channel_args* a) { + size_t i; + if (a == nullptr) return GRPC_COMPRESS_NONE; + for (i = 0; i < a->num_args; ++i) { + if (a->args[i].type == GRPC_ARG_INTEGER && + !strcmp(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, a->args[i].key)) { + grpc_compression_algorithm default_algorithm = + static_cast(a->args[i].value.integer); + return default_algorithm < GRPC_COMPRESS_ALGORITHMS_COUNT + ? 
default_algorithm + : GRPC_COMPRESS_NONE; + } + } + return GRPC_COMPRESS_NONE; +} + +grpc_channel_args* grpc_channel_args_set_channel_default_compression_algorithm( + grpc_channel_args* a, grpc_compression_algorithm algorithm) { + GPR_ASSERT(algorithm < GRPC_COMPRESS_ALGORITHMS_COUNT); + grpc_arg tmp; + tmp.type = GRPC_ARG_INTEGER; + tmp.key = const_cast(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM); + tmp.value.integer = algorithm; + return grpc_channel_args_copy_and_add(a, &tmp, 1); +} + +/** Returns 1 if the argument for compression algorithm's enabled states bitset + * was found in \a a, returning the arg's value in \a states. Otherwise, returns + * 0. */ +static int find_compression_algorithm_states_bitset(const grpc_channel_args* a, + int** states_arg) { + if (a != nullptr) { + size_t i; + for (i = 0; i < a->num_args; ++i) { + if (a->args[i].type == GRPC_ARG_INTEGER && + !strcmp(GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET, + a->args[i].key)) { + *states_arg = &a->args[i].value.integer; + **states_arg = + (**states_arg & ((1 << GRPC_COMPRESS_ALGORITHMS_COUNT) - 1)) | + 0x1; /* forcefully enable support for no compression */ + return 1; + } + } + } + return 0; /* GPR_FALSE */ +} + +grpc_channel_args* grpc_channel_args_compression_algorithm_set_state( + grpc_channel_args** a, grpc_compression_algorithm algorithm, int state) { + int* states_arg = nullptr; + grpc_channel_args* result = *a; + const int states_arg_found = + find_compression_algorithm_states_bitset(*a, &states_arg); + + if (grpc_channel_args_get_channel_default_compression_algorithm(*a) == + algorithm && + state == 0) { + const char* algo_name = nullptr; + GPR_ASSERT(grpc_compression_algorithm_name(algorithm, &algo_name) != 0); + gpr_log(GPR_ERROR, + "Tried to disable default compression algorithm '%s'. The " + "operation has been ignored.", + algo_name); + } else if (states_arg_found) { + if (state != 0) { + grpc_core::SetBit(reinterpret_cast(states_arg), algorithm); + } else if (algorithm != GRPC_COMPRESS_NONE) { + grpc_core::ClearBit(reinterpret_cast(states_arg), algorithm); + } + } else { + /* create a new arg */ + grpc_arg tmp; + tmp.type = GRPC_ARG_INTEGER; + tmp.key = + const_cast(GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET); + /* all enabled by default */ + tmp.value.integer = (1u << GRPC_COMPRESS_ALGORITHMS_COUNT) - 1; + if (state != 0) { + grpc_core::SetBit(reinterpret_cast(&tmp.value.integer), + algorithm); + } else if (algorithm != GRPC_COMPRESS_NONE) { + grpc_core::ClearBit(reinterpret_cast(&tmp.value.integer), + algorithm); + } + result = grpc_channel_args_copy_and_add(*a, &tmp, 1); + grpc_channel_args_destroy(*a); + *a = result; + } + return result; +} + +uint32_t grpc_channel_args_compression_algorithm_get_states( + const grpc_channel_args* a) { + int* states_arg; + if (find_compression_algorithm_states_bitset(a, &states_arg)) { + return static_cast(*states_arg); + } else { + return (1u << GRPC_COMPRESS_ALGORITHMS_COUNT) - 1; /* All algs. enabled */ + } +} diff --git a/src/core/lib/compression/compression_internal.cc b/src/core/lib/compression/compression_internal.cc new file mode 100644 index 00000000..ae2d7458 --- /dev/null +++ b/src/core/lib/compression/compression_internal.cc @@ -0,0 +1,283 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
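A sketch of setting and reading back a channel's default algorithm with the
helpers above, starting from empty channel args (the internal channel_args.h
include provides grpc_channel_args_destroy, as in this file):

  #include <grpc/compression.h>
  #include <grpc/support/log.h>

  #include "src/core/lib/channel/channel_args.h"
  #include "src/core/lib/compression/compression_args.h"

  void DefaultCompressionExample() {
    // Returns a fresh copy with GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM set;
    // the caller owns it.
    grpc_channel_args* args =
        grpc_channel_args_set_channel_default_compression_algorithm(
            nullptr, GRPC_COMPRESS_GZIP);
    GPR_ASSERT(
        grpc_channel_args_get_channel_default_compression_algorithm(args) ==
        GRPC_COMPRESS_GZIP);
    grpc_channel_args_destroy(args);
  }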
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/compression/compression_internal.h" + +#include +#include + +#include + +#include "src/core/lib/compression/algorithm_metadata.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/slice/slice_utils.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/transport/static_metadata.h" + +/* Interfaces related to MD */ + +grpc_message_compression_algorithm +grpc_message_compression_algorithm_from_slice(const grpc_slice& str) { + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_IDENTITY)) { + return GRPC_MESSAGE_COMPRESS_NONE; + } + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_DEFLATE)) { + return GRPC_MESSAGE_COMPRESS_DEFLATE; + } + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_GZIP)) { + return GRPC_MESSAGE_COMPRESS_GZIP; + } + return GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT; +} + +grpc_stream_compression_algorithm grpc_stream_compression_algorithm_from_slice( + const grpc_slice& str) { + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_IDENTITY)) { + return GRPC_STREAM_COMPRESS_NONE; + } + if (grpc_slice_eq_static_interned(str, GRPC_MDSTR_GZIP)) { + return GRPC_STREAM_COMPRESS_GZIP; + } + return GRPC_STREAM_COMPRESS_ALGORITHMS_COUNT; +} + +grpc_mdelem grpc_message_compression_encoding_mdelem( + grpc_message_compression_algorithm algorithm) { + switch (algorithm) { + case GRPC_MESSAGE_COMPRESS_NONE: + return GRPC_MDELEM_GRPC_ENCODING_IDENTITY; + case GRPC_MESSAGE_COMPRESS_DEFLATE: + return GRPC_MDELEM_GRPC_ENCODING_DEFLATE; + case GRPC_MESSAGE_COMPRESS_GZIP: + return GRPC_MDELEM_GRPC_ENCODING_GZIP; + default: + break; + } + return GRPC_MDNULL; +} + +grpc_mdelem grpc_stream_compression_encoding_mdelem( + grpc_stream_compression_algorithm algorithm) { + switch (algorithm) { + case GRPC_STREAM_COMPRESS_NONE: + return GRPC_MDELEM_CONTENT_ENCODING_IDENTITY; + case GRPC_STREAM_COMPRESS_GZIP: + return GRPC_MDELEM_CONTENT_ENCODING_GZIP; + default: + break; + } + return GRPC_MDNULL; +} + +/* Interfaces performing transformation between compression algorithms and + * levels. 
*/ +grpc_message_compression_algorithm +grpc_compression_algorithm_to_message_compression_algorithm( + grpc_compression_algorithm algo) { + switch (algo) { + case GRPC_COMPRESS_DEFLATE: + return GRPC_MESSAGE_COMPRESS_DEFLATE; + case GRPC_COMPRESS_GZIP: + return GRPC_MESSAGE_COMPRESS_GZIP; + default: + return GRPC_MESSAGE_COMPRESS_NONE; + } +} + +grpc_stream_compression_algorithm +grpc_compression_algorithm_to_stream_compression_algorithm( + grpc_compression_algorithm algo) { + switch (algo) { + case GRPC_COMPRESS_STREAM_GZIP: + return GRPC_STREAM_COMPRESS_GZIP; + default: + return GRPC_STREAM_COMPRESS_NONE; + } +} + +uint32_t grpc_compression_bitset_to_message_bitset(uint32_t bitset) { + return bitset & ((1u << GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) - 1); +} + +uint32_t grpc_compression_bitset_to_stream_bitset(uint32_t bitset) { + uint32_t identity = (bitset & 1u); + uint32_t other_bits = + (bitset >> (GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT - 1)) & + ((1u << GRPC_STREAM_COMPRESS_ALGORITHMS_COUNT) - 2); + return identity | other_bits; +} + +uint32_t grpc_compression_bitset_from_message_stream_compression_bitset( + uint32_t message_bitset, uint32_t stream_bitset) { + uint32_t offset_stream_bitset = + (stream_bitset & 1u) | + ((stream_bitset & (~1u)) << (GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT - 1)); + return message_bitset | offset_stream_bitset; +} + +int grpc_compression_algorithm_from_message_stream_compression_algorithm( + grpc_compression_algorithm* algorithm, + grpc_message_compression_algorithm message_algorithm, + grpc_stream_compression_algorithm stream_algorithm) { + if (message_algorithm != GRPC_MESSAGE_COMPRESS_NONE && + stream_algorithm != GRPC_STREAM_COMPRESS_NONE) { + *algorithm = GRPC_COMPRESS_NONE; + return 0; + } + if (message_algorithm == GRPC_MESSAGE_COMPRESS_NONE) { + switch (stream_algorithm) { + case GRPC_STREAM_COMPRESS_NONE: + *algorithm = GRPC_COMPRESS_NONE; + return 1; + case GRPC_STREAM_COMPRESS_GZIP: + *algorithm = GRPC_COMPRESS_STREAM_GZIP; + return 1; + default: + *algorithm = GRPC_COMPRESS_NONE; + return 0; + } + } else { + switch (message_algorithm) { + case GRPC_MESSAGE_COMPRESS_NONE: + *algorithm = GRPC_COMPRESS_NONE; + return 1; + case GRPC_MESSAGE_COMPRESS_DEFLATE: + *algorithm = GRPC_COMPRESS_DEFLATE; + return 1; + case GRPC_MESSAGE_COMPRESS_GZIP: + *algorithm = GRPC_COMPRESS_GZIP; + return 1; + default: + *algorithm = GRPC_COMPRESS_NONE; + return 0; + } + } +} + +/* Interfaces for message compression. 
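// Worked example for the bitset conversions above, using the public enum
// layout NONE=0, DEFLATE=1, GZIP=2, STREAM_GZIP=3 (so "everything enabled" is
// 0b1111):
//
//   grpc_compression_bitset_to_message_bitset(0b1111) == 0b0111   // NONE|DEFLATE|GZIP
//   grpc_compression_bitset_to_stream_bitset(0b1111)  == 0b0011   // identity | stream gzip
//   grpc_compression_bitset_from_message_stream_compression_bitset(0b0111, 0b0011)
//                                                      == 0b1111
//
// i.e. the stream bits sit immediately above the message bits and share the
// identity bit.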
*/ + +int grpc_message_compression_algorithm_name( + grpc_message_compression_algorithm algorithm, const char** name) { + GRPC_API_TRACE( + "grpc_message_compression_algorithm_name(algorithm=%d, name=%p)", 2, + ((int)algorithm, name)); + switch (algorithm) { + case GRPC_MESSAGE_COMPRESS_NONE: + *name = "identity"; + return 1; + case GRPC_MESSAGE_COMPRESS_DEFLATE: + *name = "deflate"; + return 1; + case GRPC_MESSAGE_COMPRESS_GZIP: + *name = "gzip"; + return 1; + case GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT: + return 0; + } + return 0; +} + +/* TODO(dgq): Add the ability to specify parameters to the individual + * compression algorithms */ +grpc_message_compression_algorithm grpc_message_compression_algorithm_for_level( + grpc_compression_level level, uint32_t accepted_encodings) { + GRPC_API_TRACE("grpc_message_compression_algorithm_for_level(level=%d)", 1, + ((int)level)); + if (level > GRPC_COMPRESS_LEVEL_HIGH) { + gpr_log(GPR_ERROR, "Unknown message compression level %d.", + static_cast(level)); + abort(); + } + + const size_t num_supported = + grpc_core::BitCount(accepted_encodings) - 1; /* discard NONE */ + if (level == GRPC_COMPRESS_LEVEL_NONE || num_supported == 0) { + return GRPC_MESSAGE_COMPRESS_NONE; + } + + GPR_ASSERT(level > 0); + + /* Establish a "ranking" or compression algorithms in increasing order of + * compression. + * This is simplistic and we will probably want to introduce other dimensions + * in the future (cpu/memory cost, etc). */ + const grpc_message_compression_algorithm algos_ranking[] = { + GRPC_MESSAGE_COMPRESS_GZIP, GRPC_MESSAGE_COMPRESS_DEFLATE}; + + /* intersect algos_ranking with the supported ones keeping the ranked order */ + grpc_message_compression_algorithm + sorted_supported_algos[GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT]; + size_t algos_supported_idx = 0; + for (size_t i = 0; i < GPR_ARRAY_SIZE(algos_ranking); i++) { + const grpc_message_compression_algorithm alg = algos_ranking[i]; + for (size_t j = 0; j < num_supported; j++) { + if (grpc_core::GetBit(accepted_encodings, alg) == 1) { + /* if \a alg in supported */ + sorted_supported_algos[algos_supported_idx++] = alg; + break; + } + } + if (algos_supported_idx == num_supported) break; + } + + switch (level) { + case GRPC_COMPRESS_LEVEL_NONE: + abort(); /* should have been handled already */ + case GRPC_COMPRESS_LEVEL_LOW: + return sorted_supported_algos[0]; + case GRPC_COMPRESS_LEVEL_MED: + return sorted_supported_algos[num_supported / 2]; + case GRPC_COMPRESS_LEVEL_HIGH: + return sorted_supported_algos[num_supported - 1]; + default: + abort(); + }; +} + +int grpc_message_compression_algorithm_parse( + grpc_slice value, grpc_message_compression_algorithm* algorithm) { + if (grpc_slice_eq_static_interned(value, GRPC_MDSTR_IDENTITY)) { + *algorithm = GRPC_MESSAGE_COMPRESS_NONE; + return 1; + } else if (grpc_slice_eq_static_interned(value, GRPC_MDSTR_DEFLATE)) { + *algorithm = GRPC_MESSAGE_COMPRESS_DEFLATE; + return 1; + } else if (grpc_slice_eq_static_interned(value, GRPC_MDSTR_GZIP)) { + *algorithm = GRPC_MESSAGE_COMPRESS_GZIP; + return 1; + } else { + return 0; + } +} + +/* Interfaces for stream compression. 
*/ + +int grpc_stream_compression_algorithm_parse( + grpc_slice value, grpc_stream_compression_algorithm* algorithm) { + if (grpc_slice_eq_static_interned(value, GRPC_MDSTR_IDENTITY)) { + *algorithm = GRPC_STREAM_COMPRESS_NONE; + return 1; + } else if (grpc_slice_eq_static_interned(value, GRPC_MDSTR_GZIP)) { + *algorithm = GRPC_STREAM_COMPRESS_GZIP; + return 1; + } else { + return 0; + } +} diff --git a/src/core/lib/compression/message_compress.cc b/src/core/lib/compression/message_compress.cc new file mode 100644 index 00000000..797d5edf --- /dev/null +++ b/src/core/lib/compression/message_compress.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/compression/message_compress.h" + +#include + +#include + +#include +#include + +#include "src/core/lib/slice/slice_internal.h" + +#define OUTPUT_BLOCK_SIZE 1024 + +static int zlib_body(z_stream* zs, grpc_slice_buffer* input, + grpc_slice_buffer* output, + int (*flate)(z_stream* zs, int flush)) { + int r = Z_STREAM_END; /* Do not fail on an empty input. */ + int flush; + size_t i; + grpc_slice outbuf = GRPC_SLICE_MALLOC(OUTPUT_BLOCK_SIZE); + const uInt uint_max = ~static_cast(0); + + GPR_ASSERT(GRPC_SLICE_LENGTH(outbuf) <= uint_max); + zs->avail_out = static_cast GRPC_SLICE_LENGTH(outbuf); + zs->next_out = GRPC_SLICE_START_PTR(outbuf); + flush = Z_NO_FLUSH; + for (i = 0; i < input->count; i++) { + if (i == input->count - 1) flush = Z_FINISH; + GPR_ASSERT(GRPC_SLICE_LENGTH(input->slices[i]) <= uint_max); + zs->avail_in = static_cast GRPC_SLICE_LENGTH(input->slices[i]); + zs->next_in = GRPC_SLICE_START_PTR(input->slices[i]); + do { + if (zs->avail_out == 0) { + grpc_slice_buffer_add_indexed(output, outbuf); + outbuf = GRPC_SLICE_MALLOC(OUTPUT_BLOCK_SIZE); + GPR_ASSERT(GRPC_SLICE_LENGTH(outbuf) <= uint_max); + zs->avail_out = static_cast GRPC_SLICE_LENGTH(outbuf); + zs->next_out = GRPC_SLICE_START_PTR(outbuf); + } + r = flate(zs, flush); + if (r < 0 && r != Z_BUF_ERROR /* not fatal */) { + gpr_log(GPR_INFO, "zlib error (%d)", r); + goto error; + } + } while (zs->avail_out == 0); + if (zs->avail_in) { + gpr_log(GPR_INFO, "zlib: not all input consumed"); + goto error; + } + } + if (r != Z_STREAM_END) { + gpr_log(GPR_INFO, "zlib: Data error"); + goto error; + } + + GPR_ASSERT(outbuf.refcount); + outbuf.data.refcounted.length -= zs->avail_out; + grpc_slice_buffer_add_indexed(output, outbuf); + + return 1; + +error: + grpc_slice_unref_internal(outbuf); + return 0; +} + +static void* zalloc_gpr(void* /*opaque*/, unsigned int items, + unsigned int size) { + return gpr_malloc(items * size); +} + +static void zfree_gpr(void* /*opaque*/, void* address) { gpr_free(address); } + +static int zlib_compress(grpc_slice_buffer* input, grpc_slice_buffer* output, + int gzip) { + z_stream zs; + int r; + size_t i; + size_t count_before = output->count; + size_t length_before = output->length; + memset(&zs, 0, sizeof(zs)); + zs.zalloc = zalloc_gpr; + 
zs.zfree = zfree_gpr; + r = deflateInit2(&zs, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 15 | (gzip ? 16 : 0), + 8, Z_DEFAULT_STRATEGY); + GPR_ASSERT(r == Z_OK); + r = zlib_body(&zs, input, output, deflate) && output->length < input->length; + if (!r) { + for (i = count_before; i < output->count; i++) { + grpc_slice_unref_internal(output->slices[i]); + } + output->count = count_before; + output->length = length_before; + } + deflateEnd(&zs); + return r; +} + +static int zlib_decompress(grpc_slice_buffer* input, grpc_slice_buffer* output, + int gzip) { + z_stream zs; + int r; + size_t i; + size_t count_before = output->count; + size_t length_before = output->length; + memset(&zs, 0, sizeof(zs)); + zs.zalloc = zalloc_gpr; + zs.zfree = zfree_gpr; + r = inflateInit2(&zs, 15 | (gzip ? 16 : 0)); + GPR_ASSERT(r == Z_OK); + r = zlib_body(&zs, input, output, inflate); + if (!r) { + for (i = count_before; i < output->count; i++) { + grpc_slice_unref_internal(output->slices[i]); + } + output->count = count_before; + output->length = length_before; + } + inflateEnd(&zs); + return r; +} + +static int copy(grpc_slice_buffer* input, grpc_slice_buffer* output) { + size_t i; + for (i = 0; i < input->count; i++) { + grpc_slice_buffer_add(output, grpc_slice_ref_internal(input->slices[i])); + } + return 1; +} + +static int compress_inner(grpc_message_compression_algorithm algorithm, + grpc_slice_buffer* input, grpc_slice_buffer* output) { + switch (algorithm) { + case GRPC_MESSAGE_COMPRESS_NONE: + /* the fallback path always needs to be send uncompressed: we simply + rely on that here */ + return 0; + case GRPC_MESSAGE_COMPRESS_DEFLATE: + return zlib_compress(input, output, 0); + case GRPC_MESSAGE_COMPRESS_GZIP: + return zlib_compress(input, output, 1); + case GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT: + break; + } + gpr_log(GPR_ERROR, "invalid compression algorithm %d", algorithm); + return 0; +} + +int grpc_msg_compress(grpc_message_compression_algorithm algorithm, + grpc_slice_buffer* input, grpc_slice_buffer* output) { + if (!compress_inner(algorithm, input, output)) { + copy(input, output); + return 0; + } + return 1; +} + +int grpc_msg_decompress(grpc_message_compression_algorithm algorithm, + grpc_slice_buffer* input, grpc_slice_buffer* output) { + switch (algorithm) { + case GRPC_MESSAGE_COMPRESS_NONE: + return copy(input, output); + case GRPC_MESSAGE_COMPRESS_DEFLATE: + return zlib_decompress(input, output, 0); + case GRPC_MESSAGE_COMPRESS_GZIP: + return zlib_decompress(input, output, 1); + case GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT: + break; + } + gpr_log(GPR_ERROR, "invalid compression algorithm %d", algorithm); + return 0; +} diff --git a/src/core/lib/compression/stream_compression.cc b/src/core/lib/compression/stream_compression.cc new file mode 100644 index 00000000..d827aeb9 --- /dev/null +++ b/src/core/lib/compression/stream_compression.cc @@ -0,0 +1,81 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
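A usage sketch for grpc_msg_compress() above; it assumes no runtime setup
beyond slice buffers is needed (the helpers above only touch zlib and gpr
allocation):

  #include <grpc/slice.h>
  #include <grpc/slice_buffer.h>

  #include "src/core/lib/compression/message_compress.h"

  void MessageCompressExample() {
    grpc_slice_buffer input;
    grpc_slice_buffer output;
    grpc_slice_buffer_init(&input);
    grpc_slice_buffer_init(&output);
    grpc_slice_buffer_add(&input,
                          grpc_slice_from_copied_string("hello hello hello"));

    // Returns 1 only if the gzip output is actually smaller than the input;
    // otherwise the data is copied verbatim and 0 is returned.
    int compressed =
        grpc_msg_compress(GRPC_MESSAGE_COMPRESS_GZIP, &input, &output);
    (void)compressed;

    grpc_slice_buffer_destroy(&input);
    grpc_slice_buffer_destroy(&output);
  }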
+ * + */ + +#include + +#include "src/core/lib/compression/stream_compression.h" + +#include + +#include "src/core/lib/compression/stream_compression_gzip.h" +#include "src/core/lib/slice/slice_utils.h" + +extern const grpc_stream_compression_vtable + grpc_stream_compression_identity_vtable; + +bool grpc_stream_compress(grpc_stream_compression_context* ctx, + grpc_slice_buffer* in, grpc_slice_buffer* out, + size_t* output_size, size_t max_output_size, + grpc_stream_compression_flush flush) { + return ctx->vtable->compress(ctx, in, out, output_size, max_output_size, + flush); +} + +bool grpc_stream_decompress(grpc_stream_compression_context* ctx, + grpc_slice_buffer* in, grpc_slice_buffer* out, + size_t* output_size, size_t max_output_size, + bool* end_of_context) { + return ctx->vtable->decompress(ctx, in, out, output_size, max_output_size, + end_of_context); +} + +grpc_stream_compression_context* grpc_stream_compression_context_create( + grpc_stream_compression_method method) { + switch (method) { + case GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS: + case GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS: + return grpc_stream_compression_identity_vtable.context_create(method); + case GRPC_STREAM_COMPRESSION_GZIP_COMPRESS: + case GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS: + return grpc_stream_compression_gzip_vtable.context_create(method); + default: + gpr_log(GPR_ERROR, "Unknown stream compression method: %d", method); + return nullptr; + } +} + +void grpc_stream_compression_context_destroy( + grpc_stream_compression_context* ctx) { + ctx->vtable->context_destroy(ctx); +} + +int grpc_stream_compression_method_parse( + grpc_slice value, bool is_compress, + grpc_stream_compression_method* method) { + if (grpc_slice_eq_static_interned(value, GRPC_MDSTR_IDENTITY)) { + *method = is_compress ? GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS + : GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS; + return 1; + } else if (grpc_slice_eq_static_interned(value, GRPC_MDSTR_GZIP)) { + *method = is_compress ? GRPC_STREAM_COMPRESSION_GZIP_COMPRESS + : GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS; + return 1; + } else { + return 0; + } +} diff --git a/src/core/lib/compression/stream_compression_gzip.cc b/src/core/lib/compression/stream_compression_gzip.cc new file mode 100644 index 00000000..92a44eab --- /dev/null +++ b/src/core/lib/compression/stream_compression_gzip.cc @@ -0,0 +1,231 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
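The stream-compression interface above is context-based: callers create a context for one direction, feed it slice buffers, and pick a flush mode per call. A minimal sketch of the compress side, assuming stream_compression.h and <cstdint> are included (the helper name is illustrative):

// Illustrative sketch: gzip-compress whatever is in `in`, flushing so the
// bytes written to `out` are immediately decodable on the other side.
static void example_stream_compress(grpc_slice_buffer* in,
                                    grpc_slice_buffer* out) {
  grpc_stream_compression_context* ctx =
      grpc_stream_compression_context_create(
          GRPC_STREAM_COMPRESSION_GZIP_COMPRESS);
  size_t output_size = 0;
  // FLUSH_SYNC emits buffered data without terminating the gzip stream, so
  // the same context can keep compressing later messages.
  grpc_stream_compress(ctx, in, out, &output_size,
                       /*max_output_size=*/SIZE_MAX,
                       GRPC_STREAM_COMPRESSION_FLUSH_SYNC);
  grpc_stream_compression_context_destroy(ctx);
}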
+ * + */ + +#include + +#include "src/core/lib/compression/stream_compression_gzip.h" + +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" + +#define OUTPUT_BLOCK_SIZE (1024) + +typedef struct grpc_stream_compression_context_gzip { + grpc_stream_compression_context base; + + z_stream zs; + int (*flate)(z_stream* zs, int flush); +} grpc_stream_compression_context_gzip; + +static bool gzip_flate(grpc_stream_compression_context_gzip* ctx, + grpc_slice_buffer* in, grpc_slice_buffer* out, + size_t* output_size, size_t max_output_size, int flush, + bool* end_of_context) { + GPR_ASSERT(flush == 0 || flush == Z_SYNC_FLUSH || flush == Z_FINISH); + /* Full flush is not allowed when inflating. */ + GPR_ASSERT(!(ctx->flate == inflate && (flush == Z_FINISH))); + + grpc_core::ExecCtx exec_ctx; + int r; + bool eoc = false; + size_t original_max_output_size = max_output_size; + while (max_output_size > 0 && (in->length > 0 || flush) && !eoc) { + size_t slice_size = max_output_size < OUTPUT_BLOCK_SIZE ? max_output_size + : OUTPUT_BLOCK_SIZE; + grpc_slice slice_out = GRPC_SLICE_MALLOC(slice_size); + ctx->zs.avail_out = static_cast(slice_size); + ctx->zs.next_out = GRPC_SLICE_START_PTR(slice_out); + while (ctx->zs.avail_out > 0 && in->length > 0 && !eoc) { + grpc_slice* slice = grpc_slice_buffer_peek_first(in); + ctx->zs.avail_in = static_cast GRPC_SLICE_LENGTH(*slice); + ctx->zs.next_in = GRPC_SLICE_START_PTR(*slice); + r = ctx->flate(&ctx->zs, Z_NO_FLUSH); + if (r < 0 && r != Z_BUF_ERROR) { + gpr_log(GPR_ERROR, "zlib error (%d)", r); + grpc_slice_unref_internal(slice_out); + grpc_slice_buffer_remove_first(in); + return false; + } else if (r == Z_STREAM_END && ctx->flate == inflate) { + eoc = true; + } + if (ctx->zs.avail_in > 0) { + grpc_slice_buffer_sub_first( + in, GRPC_SLICE_LENGTH(*slice) - ctx->zs.avail_in, + GRPC_SLICE_LENGTH(*slice)); + } else { + grpc_slice_buffer_remove_first(in); + } + } + if (flush != 0 && ctx->zs.avail_out > 0 && !eoc) { + GPR_ASSERT(in->length == 0); + r = ctx->flate(&ctx->zs, flush); + if (flush == Z_SYNC_FLUSH) { + switch (r) { + case Z_OK: + /* Maybe flush is not complete; just made some partial progress. */ + if (ctx->zs.avail_out > 0) { + flush = 0; + } + break; + case Z_BUF_ERROR: + case Z_STREAM_END: + flush = 0; + break; + default: + gpr_log(GPR_ERROR, "zlib error (%d)", r); + grpc_slice_unref_internal(slice_out); + + return false; + } + } else if (flush == Z_FINISH) { + switch (r) { + case Z_OK: + case Z_BUF_ERROR: + /* Wait for the next loop to assign additional output space. 
*/ + GPR_ASSERT(ctx->zs.avail_out == 0); + break; + case Z_STREAM_END: + flush = 0; + break; + default: + gpr_log(GPR_ERROR, "zlib error (%d)", r); + grpc_slice_unref_internal(slice_out); + + return false; + } + } + } + + if (ctx->zs.avail_out == 0) { + grpc_slice_buffer_add(out, slice_out); + } else if (ctx->zs.avail_out < slice_size) { + size_t len = GRPC_SLICE_LENGTH(slice_out); + GRPC_SLICE_SET_LENGTH(slice_out, len - ctx->zs.avail_out); + grpc_slice_buffer_add(out, slice_out); + } else { + grpc_slice_unref_internal(slice_out); + } + max_output_size -= (slice_size - ctx->zs.avail_out); + } + + if (end_of_context) { + *end_of_context = eoc; + } + if (output_size) { + *output_size = original_max_output_size - max_output_size; + } + return true; +} + +static bool grpc_stream_compress_gzip(grpc_stream_compression_context* ctx, + grpc_slice_buffer* in, + grpc_slice_buffer* out, + size_t* output_size, + size_t max_output_size, + grpc_stream_compression_flush flush) { + if (ctx == nullptr) { + return false; + } + grpc_stream_compression_context_gzip* gzip_ctx = + reinterpret_cast(ctx); + GPR_ASSERT(gzip_ctx->flate == deflate); + int gzip_flush; + switch (flush) { + case GRPC_STREAM_COMPRESSION_FLUSH_NONE: + gzip_flush = 0; + break; + case GRPC_STREAM_COMPRESSION_FLUSH_SYNC: + gzip_flush = Z_SYNC_FLUSH; + break; + case GRPC_STREAM_COMPRESSION_FLUSH_FINISH: + gzip_flush = Z_FINISH; + break; + default: + gzip_flush = 0; + } + return gzip_flate(gzip_ctx, in, out, output_size, max_output_size, gzip_flush, + nullptr); +} + +static bool grpc_stream_decompress_gzip(grpc_stream_compression_context* ctx, + grpc_slice_buffer* in, + grpc_slice_buffer* out, + size_t* output_size, + size_t max_output_size, + bool* end_of_context) { + if (ctx == nullptr) { + return false; + } + grpc_stream_compression_context_gzip* gzip_ctx = + reinterpret_cast(ctx); + GPR_ASSERT(gzip_ctx->flate == inflate); + return gzip_flate(gzip_ctx, in, out, output_size, max_output_size, + Z_SYNC_FLUSH, end_of_context); +} + +static grpc_stream_compression_context* +grpc_stream_compression_context_create_gzip( + grpc_stream_compression_method method) { + GPR_ASSERT(method == GRPC_STREAM_COMPRESSION_GZIP_COMPRESS || + method == GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + grpc_stream_compression_context_gzip* gzip_ctx = + static_cast( + gpr_zalloc(sizeof(grpc_stream_compression_context_gzip))); + int r; + if (gzip_ctx == nullptr) { + return nullptr; + } + if (method == GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS) { + r = inflateInit2(&gzip_ctx->zs, 0x1F); + gzip_ctx->flate = inflate; + } else { + r = deflateInit2(&gzip_ctx->zs, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 0x1F, 8, + Z_DEFAULT_STRATEGY); + gzip_ctx->flate = deflate; + } + if (r != Z_OK) { + gpr_free(gzip_ctx); + return nullptr; + } + + gzip_ctx->base.vtable = &grpc_stream_compression_gzip_vtable; + return reinterpret_cast(gzip_ctx); +} + +static void grpc_stream_compression_context_destroy_gzip( + grpc_stream_compression_context* ctx) { + if (ctx == nullptr) { + return; + } + grpc_stream_compression_context_gzip* gzip_ctx = + reinterpret_cast(ctx); + if (gzip_ctx->flate == inflate) { + inflateEnd(&gzip_ctx->zs); + } else { + deflateEnd(&gzip_ctx->zs); + } + gpr_free(ctx); +} + +const grpc_stream_compression_vtable grpc_stream_compression_gzip_vtable = { + grpc_stream_compress_gzip, grpc_stream_decompress_gzip, + grpc_stream_compression_context_create_gzip, + grpc_stream_compression_context_destroy_gzip}; diff --git a/src/core/lib/compression/stream_compression_identity.cc 
b/src/core/lib/compression/stream_compression_identity.cc new file mode 100644 index 00000000..adef4c44 --- /dev/null +++ b/src/core/lib/compression/stream_compression_identity.cc @@ -0,0 +1,91 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/compression/stream_compression_identity.h" + +#include +#include + +#include "src/core/lib/slice/slice_internal.h" + +#define OUTPUT_BLOCK_SIZE (1024) + +/* Singleton context used for all identity streams. */ +static grpc_stream_compression_context identity_ctx = { + &grpc_stream_compression_identity_vtable}; + +static void grpc_stream_compression_pass_through(grpc_slice_buffer* in, + grpc_slice_buffer* out, + size_t* output_size, + size_t max_output_size) { + if (max_output_size >= in->length) { + if (output_size) { + *output_size = in->length; + } + grpc_slice_buffer_move_into(in, out); + } else { + if (output_size) { + *output_size = max_output_size; + } + grpc_slice_buffer_move_first(in, max_output_size, out); + } +} + +static bool grpc_stream_compress_identity( + grpc_stream_compression_context* ctx, grpc_slice_buffer* in, + grpc_slice_buffer* out, size_t* output_size, size_t max_output_size, + grpc_stream_compression_flush /*flush*/) { + if (ctx == nullptr) { + return false; + } + grpc_stream_compression_pass_through(in, out, output_size, max_output_size); + return true; +} + +static bool grpc_stream_decompress_identity( + grpc_stream_compression_context* ctx, grpc_slice_buffer* in, + grpc_slice_buffer* out, size_t* output_size, size_t max_output_size, + bool* end_of_context) { + if (ctx == nullptr) { + return false; + } + grpc_stream_compression_pass_through(in, out, output_size, max_output_size); + if (end_of_context) { + *end_of_context = false; + } + return true; +} + +static grpc_stream_compression_context* +grpc_stream_compression_context_create_identity( + grpc_stream_compression_method method) { + GPR_ASSERT(method == GRPC_STREAM_COMPRESSION_IDENTITY_COMPRESS || + method == GRPC_STREAM_COMPRESSION_IDENTITY_DECOMPRESS); + /* No context needed in this case. Use fake context instead. */ + return &identity_ctx; +} + +static void grpc_stream_compression_context_destroy_identity( + grpc_stream_compression_context* /*ctx*/) {} + +const grpc_stream_compression_vtable grpc_stream_compression_identity_vtable = { + grpc_stream_compress_identity, grpc_stream_decompress_identity, + grpc_stream_compression_context_create_identity, + grpc_stream_compression_context_destroy_identity}; diff --git a/src/core/lib/config/core_configuration.cc b/src/core/lib/config/core_configuration.cc new file mode 100644 index 00000000..f37672d0 --- /dev/null +++ b/src/core/lib/config/core_configuration.cc @@ -0,0 +1,96 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/config/core_configuration.h" + +#include + +namespace grpc_core { + +std::atomic<CoreConfiguration*> CoreConfiguration::config_{nullptr}; +std::atomic<CoreConfiguration::RegisteredBuilder*> CoreConfiguration::builders_{ + nullptr}; + +CoreConfiguration::Builder::Builder() = default; + +CoreConfiguration* CoreConfiguration::Builder::Build() { + return new CoreConfiguration(this); +} + +CoreConfiguration::CoreConfiguration(Builder* builder) + : channel_init_(builder->channel_init_.Build()), + handshaker_registry_(builder->handshaker_registry_.Build()) {} + +void CoreConfiguration::RegisterBuilder(std::function<void(Builder*)> builder) { + GPR_ASSERT(config_.load(std::memory_order_relaxed) == nullptr && + "CoreConfiguration was already instantiated before builder " + "registration was completed"); + RegisteredBuilder* n = new RegisteredBuilder(); + n->builder = std::move(builder); + n->next = builders_.load(std::memory_order_relaxed); + while (!builders_.compare_exchange_weak(n->next, n, std::memory_order_acq_rel, + std::memory_order_relaxed)) { + } + GPR_ASSERT(config_.load(std::memory_order_relaxed) == nullptr && + "CoreConfiguration was already instantiated before builder " + "registration was completed"); +} + +const CoreConfiguration& CoreConfiguration::BuildNewAndMaybeSet() { + // Construct builder, pass it up to code that knows about build configuration + Builder builder; + // The linked list of builders stores things in reverse registration order. + // To get things registered in the order that systems relying on this expect, + // however, we need to run them in forward registration order, so we iterate + // once over the linked list to build a vector of builders, and then iterate + // over said vector in reverse to actually run the builders. + std::vector<RegisteredBuilder*> registered_builders; + for (RegisteredBuilder* b = builders_.load(std::memory_order_acquire); + b != nullptr; b = b->next) { + registered_builders.push_back(b); + } + for (auto it = registered_builders.rbegin(); it != registered_builders.rend(); + ++it) { + (*it)->builder(&builder); + } + // Finally, call the built-in configuration builder.
+ BuildCoreConfiguration(&builder); + // Use builder to construct a configuration + CoreConfiguration* p = builder.Build(); + // Try to set configuration global - it's possible another thread raced us + // here, in which case we drop the work we did and use the one that got set + // first + CoreConfiguration* expected = nullptr; + if (!config_.compare_exchange_strong(expected, p, std::memory_order_acq_rel, + std::memory_order_acquire)) { + delete p; + return *expected; + } + return *p; +} + +void CoreConfiguration::Reset() { + delete config_.exchange(nullptr, std::memory_order_acquire); + RegisteredBuilder* builder = + builders_.exchange(nullptr, std::memory_order_acquire); + while (builder != nullptr) { + RegisteredBuilder* next = builder->next; + delete builder; + builder = next; + } +} + +} // namespace grpc_core diff --git a/src/core/lib/debug/stats.cc b/src/core/lib/debug/stats.cc new file mode 100644 index 00000000..a0292e3b --- /dev/null +++ b/src/core/lib/debug/stats.cc @@ -0,0 +1,172 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/debug/stats.h" + +#include +#include + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" + +grpc_stats_data* grpc_stats_per_cpu_storage = nullptr; +static size_t g_num_cores; + +void grpc_stats_init(void) { + g_num_cores = std::max(1u, gpr_cpu_num_cores()); + grpc_stats_per_cpu_storage = static_cast<grpc_stats_data*>( + gpr_zalloc(sizeof(grpc_stats_data) * g_num_cores)); +} + +void grpc_stats_shutdown(void) { gpr_free(grpc_stats_per_cpu_storage); } + +void grpc_stats_collect(grpc_stats_data* output) { + memset(output, 0, sizeof(*output)); + for (size_t core = 0; core < g_num_cores; core++) { + for (size_t i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + output->counters[i] += gpr_atm_no_barrier_load( + &grpc_stats_per_cpu_storage[core].counters[i]); + } + for (size_t i = 0; i < GRPC_STATS_HISTOGRAM_BUCKETS; i++) { + output->histograms[i] += gpr_atm_no_barrier_load( + &grpc_stats_per_cpu_storage[core].histograms[i]); + } + } +} + +void grpc_stats_diff(const grpc_stats_data* b, const grpc_stats_data* a, + grpc_stats_data* c) { + for (size_t i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + c->counters[i] = b->counters[i] - a->counters[i]; + } + for (size_t i = 0; i < GRPC_STATS_HISTOGRAM_BUCKETS; i++) { + c->histograms[i] = b->histograms[i] - a->histograms[i]; + } +} + +int grpc_stats_histo_find_bucket_slow(int value, const int* table, + int table_size) { + GRPC_STATS_INC_HISTOGRAM_SLOW_LOOKUPS(); + const int* const start = table; + while (table_size > 0) { + int step = table_size / 2; + const int* it = table + step; + if (value >= *it) { + table = it + 1; + table_size -= step + 1; + } else { + table_size = step; + } + } + return static_cast<int>(table - start) - 1; +} + +size_t grpc_stats_histo_count(const grpc_stats_data* stats, + grpc_stats_histograms
histogram) { + size_t sum = 0; + for (int i = 0; i < grpc_stats_histo_buckets[histogram]; i++) { + sum += static_cast( + stats->histograms[grpc_stats_histo_start[histogram] + i]); + } + return sum; +} + +static double threshold_for_count_below(const gpr_atm* bucket_counts, + const int* bucket_boundaries, + int num_buckets, double count_below) { + double count_so_far; + double lower_bound; + double upper_bound; + int lower_idx; + int upper_idx; + + /* find the lowest bucket that gets us above count_below */ + count_so_far = 0.0; + for (lower_idx = 0; lower_idx < num_buckets; lower_idx++) { + count_so_far += static_cast(bucket_counts[lower_idx]); + if (count_so_far >= count_below) { + break; + } + } + if (count_so_far == count_below) { + /* this bucket hits the threshold exactly... we should be midway through + any run of zero values following the bucket */ + for (upper_idx = lower_idx + 1; upper_idx < num_buckets; upper_idx++) { + if (bucket_counts[upper_idx]) { + break; + } + } + return (bucket_boundaries[lower_idx] + bucket_boundaries[upper_idx]) / 2.0; + } else { + /* treat values as uniform throughout the bucket, and find where this value + should lie */ + lower_bound = bucket_boundaries[lower_idx]; + upper_bound = bucket_boundaries[lower_idx + 1]; + return upper_bound - (upper_bound - lower_bound) * + (count_so_far - count_below) / + static_cast(bucket_counts[lower_idx]); + } +} + +double grpc_stats_histo_percentile(const grpc_stats_data* stats, + grpc_stats_histograms histogram, + double percentile) { + size_t count = grpc_stats_histo_count(stats, histogram); + if (count == 0) return 0.0; + return threshold_for_count_below( + stats->histograms + grpc_stats_histo_start[histogram], + grpc_stats_histo_bucket_boundaries[histogram], + grpc_stats_histo_buckets[histogram], + static_cast(count) * percentile / 100.0); +} + +std::string grpc_stats_data_as_json(const grpc_stats_data* data) { + std::vector parts; + parts.push_back("{"); + for (size_t i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + parts.push_back(absl::StrFormat( + "\"%s\": %" PRIdPTR, grpc_stats_counter_name[i], data->counters[i])); + } + for (size_t i = 0; i < GRPC_STATS_HISTOGRAM_COUNT; i++) { + parts.push_back(absl::StrFormat("\"%s\": [", grpc_stats_histogram_name[i])); + for (int j = 0; j < grpc_stats_histo_buckets[i]; j++) { + parts.push_back( + absl::StrFormat("%s%" PRIdPTR, j == 0 ? "" : ",", + data->histograms[grpc_stats_histo_start[i] + j])); + } + parts.push_back( + absl::StrFormat("], \"%s_bkt\": [", grpc_stats_histogram_name[i])); + for (int j = 0; j < grpc_stats_histo_buckets[i]; j++) { + parts.push_back(absl::StrFormat( + "%s%d", j == 0 ? "" : ",", grpc_stats_histo_bucket_boundaries[i][j])); + } + parts.push_back("]"); + } + parts.push_back("}"); + return absl::StrJoin(parts, ""); +} diff --git a/src/core/lib/debug/stats_data.cc b/src/core/lib/debug/stats_data.cc new file mode 100644 index 00000000..21e27f2b --- /dev/null +++ b/src/core/lib/debug/stats_data.cc @@ -0,0 +1,689 @@ +/* + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
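The stats module above is read by snapshotting the per-CPU storage and then querying the snapshot; percentiles interpolate linearly inside the matching bucket, as threshold_for_count_below shows. A small usage sketch, assuming stats.h is included and grpc_stats_init() has already run (the reporting helper is illustrative):

// Illustrative sketch: snapshot the stats and report one histogram percentile
// plus the JSON dump of all counters and histograms.
static void example_report_stats() {
  grpc_stats_data snapshot;
  grpc_stats_collect(&snapshot);
  double p95_write_size = grpc_stats_histo_percentile(
      &snapshot, GRPC_STATS_HISTOGRAM_TCP_WRITE_SIZE, 95.0);
  std::string json = grpc_stats_data_as_json(&snapshot);
  gpr_log(GPR_INFO, "p95 tcp write size: %f stats: %s", p95_write_size,
          json.c_str());
}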
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Automatically generated by tools/codegen/core/gen_stats_data.py + */ + +#include + +#include "src/core/lib/debug/stats_data.h" + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" + +const char* grpc_stats_counter_name[GRPC_STATS_COUNTER_COUNT] = { + "client_calls_created", + "server_calls_created", + "cqs_created", + "client_channels_created", + "client_subchannels_created", + "server_channels_created", + "syscall_poll", + "syscall_wait", + "pollset_kick", + "pollset_kicked_without_poller", + "pollset_kicked_again", + "pollset_kick_wakeup_fd", + "pollset_kick_wakeup_cv", + "pollset_kick_own_thread", + "syscall_epoll_ctl", + "pollset_fd_cache_hits", + "histogram_slow_lookups", + "syscall_write", + "syscall_read", + "tcp_backup_pollers_created", + "tcp_backup_poller_polls", + "http2_op_batches", + "http2_op_cancel", + "http2_op_send_initial_metadata", + "http2_op_send_message", + "http2_op_send_trailing_metadata", + "http2_op_recv_initial_metadata", + "http2_op_recv_message", + "http2_op_recv_trailing_metadata", + "http2_settings_writes", + "http2_pings_sent", + "http2_writes_begun", + "http2_writes_offloaded", + "http2_writes_continued", + "http2_partial_writes", + "http2_initiate_write_due_to_initial_write", + "http2_initiate_write_due_to_start_new_stream", + "http2_initiate_write_due_to_send_message", + "http2_initiate_write_due_to_send_initial_metadata", + "http2_initiate_write_due_to_send_trailing_metadata", + "http2_initiate_write_due_to_retry_send_ping", + "http2_initiate_write_due_to_continue_pings", + "http2_initiate_write_due_to_goaway_sent", + "http2_initiate_write_due_to_rst_stream", + "http2_initiate_write_due_to_close_from_api", + "http2_initiate_write_due_to_stream_flow_control", + "http2_initiate_write_due_to_transport_flow_control", + "http2_initiate_write_due_to_send_settings", + "http2_initiate_write_due_to_bdp_estimator_ping", + "http2_initiate_write_due_to_flow_control_unstalled_by_setting", + "http2_initiate_write_due_to_flow_control_unstalled_by_update", + "http2_initiate_write_due_to_application_ping", + "http2_initiate_write_due_to_keepalive_ping", + "http2_initiate_write_due_to_transport_flow_control_unstalled", + "http2_initiate_write_due_to_ping_response", + "http2_initiate_write_due_to_force_rst_stream", + "http2_spurious_writes_begun", + "hpack_recv_indexed", + "hpack_recv_lithdr_incidx", + "hpack_recv_lithdr_incidx_v", + "hpack_recv_lithdr_notidx", + "hpack_recv_lithdr_notidx_v", + "hpack_recv_lithdr_nvridx", + "hpack_recv_lithdr_nvridx_v", + "hpack_recv_uncompressed", + "hpack_recv_huffman", + "hpack_recv_binary", + "hpack_recv_binary_base64", + "hpack_send_indexed", + "hpack_send_lithdr_incidx", + "hpack_send_lithdr_incidx_v", + "hpack_send_lithdr_notidx", + "hpack_send_lithdr_notidx_v", + "hpack_send_lithdr_nvridx", + "hpack_send_lithdr_nvridx_v", + "hpack_send_uncompressed", + "hpack_send_huffman", + "hpack_send_binary", + "hpack_send_binary_base64", + "combiner_locks_initiated", + "combiner_locks_scheduled_items", + "combiner_locks_scheduled_final_items", + "combiner_locks_offloaded", + "call_combiner_locks_initiated", + "call_combiner_locks_scheduled_items", + "call_combiner_set_notify_on_cancel", + "call_combiner_cancelled", + "executor_scheduled_short_items", + "executor_scheduled_long_items", + "executor_scheduled_to_self", + "executor_wakeup_initiated", + 
"executor_queue_drained", + "executor_push_retries", + "server_requested_calls", + "server_slowpath_requests_queued", + "cq_ev_queue_trylock_failures", + "cq_ev_queue_trylock_successes", + "cq_ev_queue_transient_pop_failures", +}; +const char* grpc_stats_counter_doc[GRPC_STATS_COUNTER_COUNT] = { + "Number of client side calls created by this process", + "Number of server side calls created by this process", + "Number of completion queues created", + "Number of client channels created", + "Number of client subchannels created", + "Number of server channels created", + "Number of polling syscalls (epoll_wait, poll, etc) made by this process", + "Number of sleeping syscalls made by this process", + "How many polling wakeups were performed by the process (only valid for " + "epoll1 right now)", + "How many times was a polling wakeup requested without an active poller " + "(only valid for epoll1 right now)", + "How many times was the same polling worker awoken repeatedly before " + "waking up (only valid for epoll1 right now)", + "How many times was an eventfd used as the wakeup vector for a polling " + "wakeup (only valid for epoll1 right now)", + "How many times was a condition variable used as the wakeup vector for a " + "polling wakeup (only valid for epoll1 right now)", + "How many times could a polling wakeup be satisfied by keeping the waking " + "thread awake? (only valid for epoll1 right now)", + "Number of epoll_ctl calls made (only valid for epollex right now)", + "Number of epoll_ctl calls skipped because the fd was cached as already " + "being added. (only valid for epollex right now)", + "Number of times histogram increments went through the slow (binary " + "search) path", + "Number of write syscalls (or equivalent - eg sendmsg) made by this " + "process", + "Number of read syscalls (or equivalent - eg recvmsg) made by this process", + "Number of times a backup poller has been created (this can be expensive)", + "Number of polls performed on the backup poller", + "Number of batches received by HTTP2 transport", + "Number of cancelations received by HTTP2 transport", + "Number of batches containing send initial metadata", + "Number of batches containing send message", + "Number of batches containing send trailing metadata", + "Number of batches containing receive initial metadata", + "Number of batches containing receive message", + "Number of batches containing receive trailing metadata", + "Number of settings frames sent", + "Number of HTTP2 pings sent by process", + "Number of HTTP2 writes initiated", + "Number of HTTP2 writes offloaded to the executor from application threads", + "Number of HTTP2 writes that finished seeing more data needed to be " + "written", + "Number of HTTP2 writes that were made knowing there was still more data " + "to be written (we cap maximum write size to syscall_write)", + "Number of HTTP2 writes initiated due to 'initial_write'", + "Number of HTTP2 writes initiated due to 'start_new_stream'", + "Number of HTTP2 writes initiated due to 'send_message'", + "Number of HTTP2 writes initiated due to 'send_initial_metadata'", + "Number of HTTP2 writes initiated due to 'send_trailing_metadata'", + "Number of HTTP2 writes initiated due to 'retry_send_ping'", + "Number of HTTP2 writes initiated due to 'continue_pings'", + "Number of HTTP2 writes initiated due to 'goaway_sent'", + "Number of HTTP2 writes initiated due to 'rst_stream'", + "Number of HTTP2 writes initiated due to 'close_from_api'", + "Number of HTTP2 writes initiated due to 
'stream_flow_control'", + "Number of HTTP2 writes initiated due to 'transport_flow_control'", + "Number of HTTP2 writes initiated due to 'send_settings'", + "Number of HTTP2 writes initiated due to 'bdp_estimator_ping'", + "Number of HTTP2 writes initiated due to " + "'flow_control_unstalled_by_setting'", + "Number of HTTP2 writes initiated due to " + "'flow_control_unstalled_by_update'", + "Number of HTTP2 writes initiated due to 'application_ping'", + "Number of HTTP2 writes initiated due to 'keepalive_ping'", + "Number of HTTP2 writes initiated due to " + "'transport_flow_control_unstalled'", + "Number of HTTP2 writes initiated due to 'ping_response'", + "Number of HTTP2 writes initiated due to 'force_rst_stream'", + "Number of HTTP2 writes initiated with nothing to write", + "Number of HPACK indexed fields received", + "Number of HPACK literal headers received with incremental indexing", + "Number of HPACK literal headers received with incremental indexing and " + "literal keys", + "Number of HPACK literal headers received with no indexing", + "Number of HPACK literal headers received with no indexing and literal " + "keys", + "Number of HPACK literal headers received with never-indexing", + "Number of HPACK literal headers received with never-indexing and literal " + "keys", + "Number of uncompressed strings received in metadata", + "Number of huffman encoded strings received in metadata", + "Number of binary strings received in metadata", + "Number of binary strings received encoded in base64 in metadata", + "Number of HPACK indexed fields sent", + "Number of HPACK literal headers sent with incremental indexing", + "Number of HPACK literal headers sent with incremental indexing and " + "literal keys", + "Number of HPACK literal headers sent with no indexing", + "Number of HPACK literal headers sent with no indexing and literal keys", + "Number of HPACK literal headers sent with never-indexing", + "Number of HPACK literal headers sent with never-indexing and literal keys", + "Number of uncompressed strings sent in metadata", + "Number of huffman encoded strings sent in metadata", + "Number of binary strings received in metadata", + "Number of binary strings received encoded in base64 in metadata", + "Number of combiner lock entries by process (first items queued to a " + "combiner)", + "Number of items scheduled against combiner locks", + "Number of final items scheduled against combiner locks", + "Number of combiner locks offloaded to different threads", + "Number of call combiner lock entries by process (first items queued to a " + "call combiner)", + "Number of items scheduled against call combiner locks", + "Number of times a cancellation callback was set on a call combiner", + "Number of times a call combiner was cancelled", + "Number of finite runtime closures scheduled against the executor (gRPC " + "thread pool)", + "Number of potentially infinite runtime closures scheduled against the " + "executor (gRPC thread pool)", + "Number of closures scheduled by the executor to the executor", + "Number of thread wakeups initiated within the executor", + "Number of times an executor queue was drained", + "Number of times we raced and were forced to retry pushing a closure to " + "the executor", + "How many calls were requested (not necessarily received) by the server", + "How many times was the server slow path taken (indicates too few " + "outstanding requests)", + "Number of lock (trylock) acquisition failures on completion queue event " + "queue. 
High value here indicates high contention on completion queues", + "Number of lock (trylock) acquisition successes on completion queue event " + "queue.", + "Number of times NULL was popped out of completion queue's event queue " + "even though the event queue was not empty", +}; +const char* grpc_stats_histogram_name[GRPC_STATS_HISTOGRAM_COUNT] = { + "call_initial_size", + "poll_events_returned", + "tcp_write_size", + "tcp_write_iov_size", + "tcp_read_size", + "tcp_read_offer", + "tcp_read_offer_iov_size", + "http2_send_message_size", + "http2_send_initial_metadata_per_write", + "http2_send_message_per_write", + "http2_send_trailing_metadata_per_write", + "http2_send_flowctl_per_write", + "server_cqs_checked", +}; +const char* grpc_stats_histogram_doc[GRPC_STATS_HISTOGRAM_COUNT] = { + "Initial size of the grpc_call arena created at call start", + "How many events are called for each syscall_poll", + "Number of bytes offered to each syscall_write", + "Number of byte segments offered to each syscall_write", + "Number of bytes received by each syscall_read", + "Number of bytes offered to each syscall_read", + "Number of byte segments offered to each syscall_read", + "Size of messages received by HTTP2 transport", + "Number of streams initiated written per TCP write", + "Number of streams whose payload was written per TCP write", + "Number of streams terminated per TCP write", + "Number of flow control updates written per TCP write", + // NOLINTNEXTLINE(bugprone-suspicious-missing-comma) + "How many completion queues were checked looking for a CQ that had " + "requested the incoming call", +}; +const int grpc_stats_table_0[65] = { + 0, 1, 2, 3, 4, 5, 7, 9, 11, 14, + 17, 21, 26, 32, 39, 47, 57, 68, 82, 98, + 117, 140, 167, 199, 238, 284, 339, 404, 482, 575, + 685, 816, 972, 1158, 1380, 1644, 1959, 2334, 2780, 3312, + 3945, 4699, 5597, 6667, 7941, 9459, 11267, 13420, 15984, 19038, + 22676, 27009, 32169, 38315, 45635, 54353, 64737, 77104, 91834, 109378, + 130273, 155159, 184799, 220100, 262144}; +const uint8_t grpc_stats_table_1[124] = { + 0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 6, + 7, 7, 7, 8, 9, 9, 10, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, + 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 22, 23, 24, + 24, 25, 25, 26, 26, 26, 27, 27, 28, 29, 29, 30, 30, 30, 31, 31, 32, 33, + 33, 34, 34, 34, 35, 35, 36, 37, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, + 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49, 49, 50, 50, + 51, 51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56, 57, 57, 58, 58}; +const int grpc_stats_table_2[129] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, + 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 94, 98, 102, 106, 110, + 114, 118, 122, 126, 131, 136, 141, 146, 151, 156, 162, 168, 174, 180, 186, + 192, 199, 206, 213, 220, 228, 236, 244, 252, 260, 269, 278, 287, 297, 307, + 317, 327, 338, 349, 360, 372, 384, 396, 409, 422, 436, 450, 464, 479, 494, + 510, 526, 543, 560, 578, 596, 615, 634, 654, 674, 695, 717, 739, 762, 785, + 809, 834, 859, 885, 912, 939, 967, 996, 1024}; +const uint8_t grpc_stats_table_3[166] = { + 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, + 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 16, + 17, 17, 18, 19, 19, 20, 21, 21, 22, 23, 23, 24, 25, 25, 26, 26, 27, 27, 28, + 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 36, 36, 37, 38, 39, + 40, 40, 41, 42, 42, 43, 44, 
44, 45, 46, 46, 47, 48, 48, 49, 49, 50, 50, 51, + 51, 52, 52, 53, 53, 54, 54, 55, 56, 57, 58, 59, 59, 60, 61, 62, 63, 63, 64, + 65, 65, 66, 67, 67, 68, 69, 69, 70, 71, 71, 72, 72, 73, 73, 74, 75, 75, 76, + 76, 77, 78, 79, 79, 80, 81, 82, 83, 84, 85, 85, 86, 87, 88, 88, 89, 90, 90, + 91, 92, 92, 93, 94, 94, 95, 95, 96, 97, 97, 98, 98, 99}; +const int grpc_stats_table_4[65] = { + 0, 1, 2, 3, 4, 6, 8, 11, + 15, 20, 26, 34, 44, 57, 73, 94, + 121, 155, 199, 255, 327, 419, 537, 688, + 881, 1128, 1444, 1848, 2365, 3026, 3872, 4954, + 6338, 8108, 10373, 13270, 16976, 21717, 27782, 35541, + 45467, 58165, 74409, 95189, 121772, 155778, 199281, 254933, + 326126, 417200, 533707, 682750, 873414, 1117323, 1429345, 1828502, + 2339127, 2992348, 3827987, 4896985, 6264509, 8013925, 10251880, 13114801, + 16777216}; +const uint8_t grpc_stats_table_5[87] = { + 0, 0, 1, 1, 2, 3, 3, 4, 4, 5, 6, 6, 7, 8, 8, 9, 10, 11, + 11, 12, 13, 13, 14, 15, 15, 16, 17, 17, 18, 19, 20, 20, 21, 22, 22, 23, + 24, 25, 25, 26, 27, 27, 28, 29, 29, 30, 31, 31, 32, 33, 34, 34, 35, 36, + 36, 37, 38, 39, 39, 40, 41, 41, 42, 43, 44, 44, 45, 45, 46, 47, 48, 48, + 49, 50, 51, 51, 52, 53, 53, 54, 55, 56, 56, 57, 58, 58, 59}; +const int grpc_stats_table_6[65] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 14, 16, 18, 20, 22, 24, 27, 30, 33, 36, 39, 43, 47, + 51, 56, 61, 66, 72, 78, 85, 92, 100, 109, 118, 128, 139, + 151, 164, 178, 193, 209, 226, 244, 264, 285, 308, 333, 359, 387, + 418, 451, 486, 524, 565, 609, 656, 707, 762, 821, 884, 952, 1024}; +const uint8_t grpc_stats_table_7[102] = { + 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, + 6, 7, 7, 7, 8, 8, 9, 9, 10, 11, 11, 12, 12, 13, 13, 14, 14, + 14, 15, 15, 16, 16, 17, 17, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, + 23, 24, 24, 24, 25, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32, + 32, 33, 33, 34, 35, 35, 36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, + 42, 42, 43, 44, 44, 45, 46, 46, 47, 48, 48, 49, 49, 50, 50, 51, 51}; +const int grpc_stats_table_8[9] = {0, 1, 2, 4, 7, 13, 23, 39, 64}; +const uint8_t grpc_stats_table_9[9] = {0, 0, 1, 2, 2, 3, 4, 4, 5}; +void grpc_stats_inc_call_initial_size(int value) { + value = grpc_core::Clamp(value, 0, 262144); + if (value < 6) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_CALL_INITIAL_SIZE, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4651092515166879744ull) { + int bucket = + grpc_stats_table_1[((_val.uint - 4618441417868443648ull) >> 49)] + 6; + _bkt.dbl = grpc_stats_table_0[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_CALL_INITIAL_SIZE, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_CALL_INITIAL_SIZE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_0, 64)); +} +void grpc_stats_inc_poll_events_returned(int value) { + value = grpc_core::Clamp(value, 0, 1024); + if (value < 29) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_POLL_EVENTS_RETURNED, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4642789003353915392ull) { + int bucket = + grpc_stats_table_3[((_val.uint - 4628855992006737920ull) >> 47)] + 29; + _bkt.dbl = grpc_stats_table_2[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_POLL_EVENTS_RETURNED, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_POLL_EVENTS_RETURNED, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_2, 
128)); +} +void grpc_stats_inc_tcp_write_size(int value) { + value = grpc_core::Clamp(value, 0, 16777216); + if (value < 5) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_WRITE_SIZE, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4683743612465315840ull) { + int bucket = + grpc_stats_table_5[((_val.uint - 4617315517961601024ull) >> 50)] + 5; + _bkt.dbl = grpc_stats_table_4[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_WRITE_SIZE, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_TCP_WRITE_SIZE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_4, 64)); +} +void grpc_stats_inc_tcp_write_iov_size(int value) { + value = grpc_core::Clamp(value, 0, 1024); + if (value < 13) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_WRITE_IOV_SIZE, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4637863191261478912ull) { + int bucket = + grpc_stats_table_7[((_val.uint - 4623507967449235456ull) >> 48)] + 13; + _bkt.dbl = grpc_stats_table_6[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_WRITE_IOV_SIZE, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_TCP_WRITE_IOV_SIZE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_6, 64)); +} +void grpc_stats_inc_tcp_read_size(int value) { + value = grpc_core::Clamp(value, 0, 16777216); + if (value < 5) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_READ_SIZE, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4683743612465315840ull) { + int bucket = + grpc_stats_table_5[((_val.uint - 4617315517961601024ull) >> 50)] + 5; + _bkt.dbl = grpc_stats_table_4[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_READ_SIZE, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_TCP_READ_SIZE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_4, 64)); +} +void grpc_stats_inc_tcp_read_offer(int value) { + value = grpc_core::Clamp(value, 0, 16777216); + if (value < 5) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_READ_OFFER, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4683743612465315840ull) { + int bucket = + grpc_stats_table_5[((_val.uint - 4617315517961601024ull) >> 50)] + 5; + _bkt.dbl = grpc_stats_table_4[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_READ_OFFER, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_TCP_READ_OFFER, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_4, 64)); +} +void grpc_stats_inc_tcp_read_offer_iov_size(int value) { + value = grpc_core::Clamp(value, 0, 1024); + if (value < 13) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_READ_OFFER_IOV_SIZE, + value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4637863191261478912ull) { + int bucket = + grpc_stats_table_7[((_val.uint - 4623507967449235456ull) >> 48)] + 13; + _bkt.dbl = grpc_stats_table_6[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_TCP_READ_OFFER_IOV_SIZE, + bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + 
GRPC_STATS_HISTOGRAM_TCP_READ_OFFER_IOV_SIZE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_6, 64)); +} +void grpc_stats_inc_http2_send_message_size(int value) { + value = grpc_core::Clamp(value, 0, 16777216); + if (value < 5) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_HTTP2_SEND_MESSAGE_SIZE, + value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4683743612465315840ull) { + int bucket = + grpc_stats_table_5[((_val.uint - 4617315517961601024ull) >> 50)] + 5; + _bkt.dbl = grpc_stats_table_4[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_HTTP2_SEND_MESSAGE_SIZE, + bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_MESSAGE_SIZE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_4, 64)); +} +void grpc_stats_inc_http2_send_initial_metadata_per_write(int value) { + value = grpc_core::Clamp(value, 0, 1024); + if (value < 13) { + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_INITIAL_METADATA_PER_WRITE, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4637863191261478912ull) { + int bucket = + grpc_stats_table_7[((_val.uint - 4623507967449235456ull) >> 48)] + 13; + _bkt.dbl = grpc_stats_table_6[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_INITIAL_METADATA_PER_WRITE, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_INITIAL_METADATA_PER_WRITE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_6, 64)); +} +void grpc_stats_inc_http2_send_message_per_write(int value) { + value = grpc_core::Clamp(value, 0, 1024); + if (value < 13) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_HTTP2_SEND_MESSAGE_PER_WRITE, + value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4637863191261478912ull) { + int bucket = + grpc_stats_table_7[((_val.uint - 4623507967449235456ull) >> 48)] + 13; + _bkt.dbl = grpc_stats_table_6[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_HTTP2_SEND_MESSAGE_PER_WRITE, + bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_MESSAGE_PER_WRITE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_6, 64)); +} +void grpc_stats_inc_http2_send_trailing_metadata_per_write(int value) { + value = grpc_core::Clamp(value, 0, 1024); + if (value < 13) { + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_TRAILING_METADATA_PER_WRITE, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4637863191261478912ull) { + int bucket = + grpc_stats_table_7[((_val.uint - 4623507967449235456ull) >> 48)] + 13; + _bkt.dbl = grpc_stats_table_6[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_TRAILING_METADATA_PER_WRITE, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_TRAILING_METADATA_PER_WRITE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_6, 64)); +} +void grpc_stats_inc_http2_send_flowctl_per_write(int value) { + value = grpc_core::Clamp(value, 0, 1024); + if (value < 13) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_HTTP2_SEND_FLOWCTL_PER_WRITE, + value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; 
+ _val.dbl = value; + if (_val.uint < 4637863191261478912ull) { + int bucket = + grpc_stats_table_7[((_val.uint - 4623507967449235456ull) >> 48)] + 13; + _bkt.dbl = grpc_stats_table_6[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_HTTP2_SEND_FLOWCTL_PER_WRITE, + bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_HTTP2_SEND_FLOWCTL_PER_WRITE, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_6, 64)); +} +void grpc_stats_inc_server_cqs_checked(int value) { + value = grpc_core::Clamp(value, 0, 64); + if (value < 3) { + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_SERVER_CQS_CHECKED, value); + return; + } + union { + double dbl; + uint64_t uint; + } _val, _bkt; + _val.dbl = value; + if (_val.uint < 4625196817309499392ull) { + int bucket = + grpc_stats_table_9[((_val.uint - 4613937818241073152ull) >> 51)] + 3; + _bkt.dbl = grpc_stats_table_8[bucket]; + bucket -= (_val.uint < _bkt.uint); + GRPC_STATS_INC_HISTOGRAM(GRPC_STATS_HISTOGRAM_SERVER_CQS_CHECKED, bucket); + return; + } + GRPC_STATS_INC_HISTOGRAM( + GRPC_STATS_HISTOGRAM_SERVER_CQS_CHECKED, + grpc_stats_histo_find_bucket_slow(value, grpc_stats_table_8, 8)); +} +const int grpc_stats_histo_buckets[13] = {64, 128, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 8}; +const int grpc_stats_histo_start[13] = {0, 64, 192, 256, 320, 384, 448, + 512, 576, 640, 704, 768, 832}; +const int* const grpc_stats_histo_bucket_boundaries[13] = { + grpc_stats_table_0, grpc_stats_table_2, grpc_stats_table_4, + grpc_stats_table_6, grpc_stats_table_4, grpc_stats_table_4, + grpc_stats_table_6, grpc_stats_table_4, grpc_stats_table_6, + grpc_stats_table_6, grpc_stats_table_6, grpc_stats_table_6, + grpc_stats_table_8}; +void (*const grpc_stats_inc_histogram[13])(int x) = { + grpc_stats_inc_call_initial_size, + grpc_stats_inc_poll_events_returned, + grpc_stats_inc_tcp_write_size, + grpc_stats_inc_tcp_write_iov_size, + grpc_stats_inc_tcp_read_size, + grpc_stats_inc_tcp_read_offer, + grpc_stats_inc_tcp_read_offer_iov_size, + grpc_stats_inc_http2_send_message_size, + grpc_stats_inc_http2_send_initial_metadata_per_write, + grpc_stats_inc_http2_send_message_per_write, + grpc_stats_inc_http2_send_trailing_metadata_per_write, + grpc_stats_inc_http2_send_flowctl_per_write, + grpc_stats_inc_server_cqs_checked}; diff --git a/src/core/lib/debug/trace.cc b/src/core/lib/debug/trace.cc new file mode 100644 index 00000000..87b61446 --- /dev/null +++ b/src/core/lib/debug/trace.cc @@ -0,0 +1,155 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
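The generated grpc_stats_inc_* helpers above avoid the binary-search slow path by reinterpreting the clamped sample as a double and using its high-order bits (the exponent plus a few mantissa bits) to index a precomputed bucket table, falling back to grpc_stats_histo_find_bucket_slow only for values above the table-covered range. A stripped-down sketch of the underlying idea, not the generated code itself:

#include <cstdint>
#include <cstring>

// Map a non-negative sample onto a coarse, roughly logarithmic bucket using
// only the exponent bits of its double representation; the generated code
// keeps extra mantissa bits and remaps them through a lookup table to get
// finer-grained buckets.
static int coarse_log_bucket(int value) {
  double d = value;
  uint64_t bits;
  std::memcpy(&bits, &d, sizeof(bits));  // well-defined way to read the bits
  return static_cast<int>(bits >> 52);   // biased exponent of the double
}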
+ * + */ + +#include + +#include "src/core/lib/debug/trace.h" + +#include + +#include + +#include +#include +#include + +GPR_GLOBAL_CONFIG_DEFINE_STRING( + grpc_trace, "", + "A comma separated list of tracers that provide additional insight into " + "how gRPC C core is processing requests via debug logs."); + +int grpc_tracer_set_enabled(const char* name, int enabled); + +namespace grpc_core { + +TraceFlag* TraceFlagList::root_tracer_ = nullptr; + +bool TraceFlagList::Set(const char* name, bool enabled) { + TraceFlag* t; + if (0 == strcmp(name, "all")) { + for (t = root_tracer_; t; t = t->next_tracer_) { + t->set_enabled(enabled); + } + } else if (0 == strcmp(name, "list_tracers")) { + LogAllTracers(); + } else if (0 == strcmp(name, "refcount")) { + for (t = root_tracer_; t; t = t->next_tracer_) { + if (strstr(t->name_, "refcount") != nullptr) { + t->set_enabled(enabled); + } + } + } else { + bool found = false; + for (t = root_tracer_; t; t = t->next_tracer_) { + if (0 == strcmp(name, t->name_)) { + t->set_enabled(enabled); + found = true; + } + } + // check for unknowns, but ignore "", to allow to GRPC_TRACE= + if (!found && 0 != strcmp(name, "")) { + gpr_log(GPR_ERROR, "Unknown trace var: '%s'", name); + return false; /* early return */ + } + } + return true; +} + +void TraceFlagList::Add(TraceFlag* flag) { + flag->next_tracer_ = root_tracer_; + root_tracer_ = flag; +} + +void TraceFlagList::LogAllTracers() { + gpr_log(GPR_DEBUG, "available tracers:"); + TraceFlag* t; + for (t = root_tracer_; t != nullptr; t = t->next_tracer_) { + gpr_log(GPR_DEBUG, "\t%s", t->name_); + } +} + +// Flags register themselves on the list during construction +TraceFlag::TraceFlag(bool default_enabled, const char* name) : name_(name) { + static_assert(std::is_trivially_destructible::value, + "TraceFlag needs to be trivially destructible."); + set_enabled(default_enabled); + TraceFlagList::Add(this); +} + +} // namespace grpc_core + +static void add(const char* beg, const char* end, char*** ss, size_t* ns) { + size_t n = *ns; + size_t np = n + 1; + char* s; + size_t len; + GPR_ASSERT(end >= beg); + len = static_cast(end - beg); + s = static_cast(gpr_malloc(len + 1)); + memcpy(s, beg, len); + s[len] = 0; + *ss = static_cast(gpr_realloc(*ss, sizeof(char**) * np)); + (*ss)[n] = s; + *ns = np; +} + +static void split(const char* s, char*** ss, size_t* ns) { + const char* c = strchr(s, ','); + if (c == nullptr) { + add(s, s + strlen(s), ss, ns); + } else { + add(s, c, ss, ns); + split(c + 1, ss, ns); + } +} + +static void parse(const char* s) { + char** strings = nullptr; + size_t nstrings = 0; + size_t i; + split(s, &strings, &nstrings); + + for (i = 0; i < nstrings; i++) { + if (strings[i][0] == '-') { + grpc_core::TraceFlagList::Set(strings[i] + 1, false); + } else { + grpc_core::TraceFlagList::Set(strings[i], true); + } + } + + for (i = 0; i < nstrings; i++) { + gpr_free(strings[i]); + } + gpr_free(strings); +} + +void grpc_tracer_init(const char* env_var_name) { + (void)env_var_name; // suppress unused variable error + grpc_tracer_init(); +} + +void grpc_tracer_init() { + grpc_core::UniquePtr value = GPR_GLOBAL_CONFIG_GET(grpc_trace); + parse(value.get()); +} + +void grpc_tracer_shutdown(void) {} + +int grpc_tracer_set_enabled(const char* name, int enabled) { + return grpc_core::TraceFlagList::Set(name, enabled != 0); +} diff --git a/src/core/lib/event_engine/endpoint_config.cc b/src/core/lib/event_engine/endpoint_config.cc new file mode 100644 index 00000000..ca6438d5 --- /dev/null +++ 
b/src/core/lib/event_engine/endpoint_config.cc @@ -0,0 +1,45 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/event_engine/endpoint_config_internal.h" +#include "src/core/lib/gpr/useful.h" + +namespace grpc_event_engine { +namespace experimental { + +EndpointConfig::Setting ChannelArgsEndpointConfig::Get( + absl::string_view key) const { + const grpc_arg* arg = grpc_channel_args_find(args_, std::string(key).c_str()); + if (arg == nullptr) { + return absl::monostate(); + } + switch (arg->type) { + case GRPC_ARG_STRING: + return absl::string_view(arg->value.string); + case GRPC_ARG_INTEGER: + return arg->value.integer; + case GRPC_ARG_POINTER: + return arg->value.pointer.p; + } + GPR_UNREACHABLE_CODE(return absl::monostate()); +} + +} // namespace experimental +} // namespace grpc_event_engine diff --git a/src/core/lib/event_engine/event_engine.cc b/src/core/lib/event_engine/event_engine.cc new file mode 100644 index 00000000..cce18ab4 --- /dev/null +++ b/src/core/lib/event_engine/event_engine.cc @@ -0,0 +1,50 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/event_engine/sockaddr.h" + +namespace grpc_event_engine { +namespace experimental { + +EventEngine::ResolvedAddress::ResolvedAddress(const sockaddr* address, + socklen_t size) + : size_(size) { + GPR_ASSERT(size <= sizeof(address_)); + memcpy(&address_, address, size); +} + +const struct sockaddr* EventEngine::ResolvedAddress::address() const { + return reinterpret_cast(address_); +} + +socklen_t EventEngine::ResolvedAddress::size() const { return size_; } + +std::shared_ptr +DefaultEventEngineFactory() { + // TODO(nnoble): delete when uv-ee is merged + abort(); +} + +} // namespace experimental +} // namespace grpc_event_engine diff --git a/src/core/lib/event_engine/memory_allocator.cc b/src/core/lib/event_engine/memory_allocator.cc new file mode 100644 index 00000000..ea2a3c26 --- /dev/null +++ b/src/core/lib/event_engine/memory_allocator.cc @@ -0,0 +1,70 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
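EventEngine::ResolvedAddress above is a fixed-size, by-value wrapper around a sockaddr: the constructor copies `size` bytes and asserts they fit the internal storage. A hedged sketch of wrapping an IPv4 loopback address (the helper name is illustrative; assumes <grpc/event_engine/event_engine.h>, <netinet/in.h>, and <cstring> are included):

// Illustrative sketch: build a ResolvedAddress for 127.0.0.1:port.
static grpc_event_engine::experimental::EventEngine::ResolvedAddress
MakeLoopbackAddress(uint16_t port) {
  sockaddr_in addr;
  memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_port = htons(port);
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  // The constructor copies sizeof(addr) bytes into its internal storage.
  return grpc_event_engine::experimental::EventEngine::ResolvedAddress(
      reinterpret_cast<const sockaddr*>(&addr), sizeof(addr));
}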
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/slice/slice_refcount.h" + +namespace grpc_event_engine { +namespace experimental { + +namespace { + +// Reference count for a slice allocated by MemoryAllocator::MakeSlice. +// Takes care of releasing memory back when the slice is destroyed. +class SliceRefCount { + public: + static void Destroy(void* p) { + auto* rc = static_cast(p); + rc->~SliceRefCount(); + gpr_free(rc); + } + SliceRefCount(std::shared_ptr allocator, + size_t size) + : base_(grpc_slice_refcount::Type::REGULAR, &refs_, Destroy, this, + &base_), + allocator_(std::move(allocator)), + size_(size) { + // Nothing to do here. + } + ~SliceRefCount() { allocator_->Release(size_); } + + grpc_slice_refcount* base_refcount() { return &base_; } + + private: + grpc_slice_refcount base_; + grpc_core::RefCount refs_; + std::shared_ptr allocator_; + size_t size_; +}; + +} // namespace + +grpc_slice MemoryAllocator::MakeSlice(MemoryRequest request) { + auto size = Reserve(request.Increase(sizeof(SliceRefCount))); + void* p = gpr_malloc(size); + new (p) SliceRefCount(allocator_, size); + grpc_slice slice; + slice.refcount = static_cast(p)->base_refcount(); + slice.data.refcounted.bytes = + static_cast(p) + sizeof(SliceRefCount); + slice.data.refcounted.length = size - sizeof(SliceRefCount); + return slice; +} + +} // namespace experimental +} // namespace grpc_event_engine diff --git a/src/core/lib/event_engine/sockaddr.cc b/src/core/lib/event_engine/sockaddr.cc new file mode 100644 index 00000000..dab44f93 --- /dev/null +++ b/src/core/lib/event_engine/sockaddr.cc @@ -0,0 +1,40 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include + +#include +#include +#include + +uint16_t grpc_htons(uint16_t hostshort) { return htons(hostshort); } + +uint16_t grpc_ntohs(uint16_t netshort) { return ntohs(netshort); } + +uint32_t grpc_htonl(uint32_t hostlong) { return htonl(hostlong); } + +uint32_t grpc_ntohl(uint32_t netlong) { return ntohl(netlong); } + +int grpc_inet_pton(int af, const char* src, void* dst) { + return inet_pton(af, src, dst); +} + +const char* grpc_inet_ntop(int af, const void* src, char* dst, size_t size) { + inet_ntop(af, src, dst, size); + return dst; +} + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/gpr/alloc.cc b/src/core/lib/gpr/alloc.cc new file mode 100644 index 00000000..9a46ff85 --- /dev/null +++ b/src/core/lib/gpr/alloc.cc @@ -0,0 +1,76 @@ +/* + * + * Copyright 2015 gRPC authors. 
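The MakeSlice() path above packs the refcount bookkeeping and the payload bytes into a single allocation, with the payload starting right after the header object. A toy sketch of that layout, using plain malloc and an invented Header type instead of gpr_malloc and SliceRefCount (no actual refcounting here):

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <new>

// Toy stand-in for SliceRefCount: bookkeeping that lives in front of the
// payload, inside the same allocation.
struct Header {
  size_t total_size;
  ~Header() { std::printf("releasing %zu bytes\n", total_size); }
};

// One allocation holds [Header][payload...]; callers only see the payload.
static char* make_buffer(size_t payload, size_t* payload_len) {
  size_t total = payload + sizeof(Header);
  void* p = std::malloc(total);
  new (p) Header{total};  // placement-new the header at the front
  *payload_len = total - sizeof(Header);
  return static_cast<char*>(p) + sizeof(Header);
}

static void free_buffer(char* payload) {
  Header* h = reinterpret_cast<Header*>(payload - sizeof(Header));
  h->~Header();
  std::free(h);
}

int main() {
  size_t len;
  char* buf = make_buffer(64, &len);
  std::printf("usable payload: %zu bytes\n", len);
  free_buffer(buf);
}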
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include +#include + +#include "src/core/lib/profiling/timers.h" + +void* gpr_malloc(size_t size) { + GPR_TIMER_SCOPE("gpr_malloc", 0); + void* p; + if (size == 0) return nullptr; + p = malloc(size); + if (!p) { + abort(); + } + return p; +} + +void* gpr_zalloc(size_t size) { + GPR_TIMER_SCOPE("gpr_zalloc", 0); + void* p; + if (size == 0) return nullptr; + p = calloc(size, 1); + if (!p) { + abort(); + } + return p; +} + +void gpr_free(void* p) { + GPR_TIMER_SCOPE("gpr_free", 0); + free(p); +} + +void* gpr_realloc(void* p, size_t size) { + GPR_TIMER_SCOPE("gpr_realloc", 0); + if ((size == 0) && (p == nullptr)) return nullptr; + p = realloc(p, size); + if (!p) { + abort(); + } + return p; +} + +void* gpr_malloc_aligned(size_t size, size_t alignment) { + GPR_ASSERT(((alignment - 1) & alignment) == 0); // Must be power of 2. + size_t extra = alignment - 1 + sizeof(void*); + void* p = gpr_malloc(size + extra); + void** ret = reinterpret_cast( + (reinterpret_cast(p) + extra) & ~(alignment - 1)); + ret[-1] = p; + return ret; +} + +void gpr_free_aligned(void* ptr) { gpr_free((static_cast(ptr))[-1]); } diff --git a/src/core/lib/gpr/atm.cc b/src/core/lib/gpr/atm.cc new file mode 100644 index 00000000..c40e97ea --- /dev/null +++ b/src/core/lib/gpr/atm.cc @@ -0,0 +1,35 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/gpr/useful.h" + +gpr_atm gpr_atm_no_barrier_clamped_add(gpr_atm* value, gpr_atm delta, + gpr_atm min, gpr_atm max) { + gpr_atm current_value; + gpr_atm new_value; + do { + current_value = gpr_atm_no_barrier_load(value); + new_value = grpc_core::Clamp(current_value + delta, min, max); + if (new_value == current_value) break; + } while (!gpr_atm_no_barrier_cas(value, current_value, new_value)); + return new_value; +} diff --git a/src/core/lib/gpr/cpu_iphone.cc b/src/core/lib/gpr/cpu_iphone.cc new file mode 100644 index 00000000..94a724de --- /dev/null +++ b/src/core/lib/gpr/cpu_iphone.cc @@ -0,0 +1,44 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
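gpr_malloc_aligned() above over-allocates, rounds the pointer up to the requested alignment, and stashes the raw pointer one slot below the aligned block so gpr_free_aligned() can find it later. A minimal standalone sketch of the same trick, assuming only the C allocator:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Over-allocate by alignment - 1 + sizeof(void*), round up, and remember the
// raw pointer at aligned[-1], exactly as gpr_malloc_aligned() does.
static void* aligned_malloc(size_t size, size_t alignment) {
  assert((alignment & (alignment - 1)) == 0);  // must be a power of two
  size_t extra = alignment - 1 + sizeof(void*);
  void* raw = std::malloc(size + extra);
  void** aligned = reinterpret_cast<void**>(
      (reinterpret_cast<uintptr_t>(raw) + extra) & ~(alignment - 1));
  aligned[-1] = raw;
  return aligned;
}

static void aligned_free(void* p) { std::free(static_cast<void**>(p)[-1]); }

int main() {
  void* p = aligned_malloc(100, 64);
  std::printf("remainder mod 64: %d\n",
              static_cast<int>(reinterpret_cast<uintptr_t>(p) % 64));
  aligned_free(p);
}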
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#ifdef GPR_CPU_IPHONE + +#include + +unsigned gpr_cpu_num_cores(void) { + size_t len; + unsigned int ncpu; + len = sizeof(ncpu); + sysctlbyname("hw.ncpu", &ncpu, &len, NULL, 0); + + return ncpu; +} + +/* Most code that's using this is using it to shard across work queues. So + unless profiling shows it's a problem or there appears a way to detect the + currently running CPU core, let's have it shard the default way. + Note that the interface in cpu.h lets gpr_cpu_num_cores return 0, but doing + it makes it impossible for gpr_cpu_current_cpu to satisfy its stated range, + and some code might be relying on it. */ +unsigned gpr_cpu_current_cpu(void) { return 0; } + +#endif /* GPR_CPU_IPHONE */ diff --git a/src/core/lib/gpr/cpu_linux.cc b/src/core/lib/gpr/cpu_linux.cc new file mode 100644 index 00000000..2e16e3d9 --- /dev/null +++ b/src/core/lib/gpr/cpu_linux.cc @@ -0,0 +1,82 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif /* _GNU_SOURCE */ + +#include + +#ifdef GPR_CPU_LINUX + +#include +#include +#include +#include + +#include +#include +#include + +static int ncpus = 0; + +static void init_num_cpus() { +#ifndef GPR_MUSL_LIBC_COMPAT + if (sched_getcpu() < 0) { + gpr_log(GPR_ERROR, "Error determining current CPU: %s\n", strerror(errno)); + ncpus = 1; + return; + } +#endif + /* This must be signed. sysconf returns -1 when the number cannot be + determined */ + ncpus = static_cast(sysconf(_SC_NPROCESSORS_CONF)); + if (ncpus < 1) { + gpr_log(GPR_ERROR, "Cannot determine number of CPUs: assuming 1"); + ncpus = 1; + } +} + +unsigned gpr_cpu_num_cores(void) { + static gpr_once once = GPR_ONCE_INIT; + gpr_once_init(&once, init_num_cpus); + return static_cast(ncpus); +} + +unsigned gpr_cpu_current_cpu(void) { +#ifdef GPR_MUSL_LIBC_COMPAT + // sched_getcpu() is undefined on musl + return 0; +#else + if (gpr_cpu_num_cores() == 1) { + return 0; + } + int cpu = sched_getcpu(); + if (cpu < 0) { + gpr_log(GPR_ERROR, "Error determining current CPU: %s\n", strerror(errno)); + return 0; + } + if (static_cast(cpu) >= gpr_cpu_num_cores()) { + gpr_log(GPR_DEBUG, "Cannot handle hot-plugged CPUs"); + return 0; + } + return static_cast(cpu); +#endif +} + +#endif /* GPR_CPU_LINUX */ diff --git a/src/core/lib/gpr/cpu_posix.cc b/src/core/lib/gpr/cpu_posix.cc new file mode 100644 index 00000000..7a1d5460 --- /dev/null +++ b/src/core/lib/gpr/cpu_posix.cc @@ -0,0 +1,83 @@ +/* + * + * Copyright 2015 gRPC authors. 
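As the comments in the cpu_* files above note, gpr_cpu_current_cpu() is mostly used to shard work across queues or locks rather than to identify a physical core precisely. A small usage sketch, assuming <grpc/support/cpu.h> is available; kNumShards is an arbitrary value chosen for illustration:

#include <stdio.h>

#include <grpc/support/cpu.h>

// Shard some per-thread state by the calling thread's CPU so unrelated
// threads mostly stay off each other's shards.
int main() {
  constexpr unsigned kNumShards = 8;
  unsigned cpu = gpr_cpu_current_cpu();  // always < gpr_cpu_num_cores()
  unsigned shard = cpu % kNumShards;
  printf("cores=%u cpu=%u -> shard %u of %u\n", gpr_cpu_num_cores(), cpu,
         shard, kNumShards);
  return 0;
}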
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#if defined(GPR_CPU_POSIX) + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" + +static long ncpus = 0; + +static pthread_key_t thread_id_key; + +static void init_ncpus() { + ncpus = sysconf(_SC_NPROCESSORS_CONF); + if (ncpus < 1 || ncpus > INT32_MAX) { + gpr_log(GPR_ERROR, "Cannot determine number of CPUs: assuming 1"); + ncpus = 1; + } +} + +unsigned gpr_cpu_num_cores(void) { + static gpr_once once = GPR_ONCE_INIT; + gpr_once_init(&once, init_ncpus); + return (unsigned)ncpus; +} + +static void delete_thread_id(void* value) { + if (value) { + free(value); + } +} + +static void init_thread_id_key(void) { + pthread_key_create(&thread_id_key, delete_thread_id); +} + +unsigned gpr_cpu_current_cpu(void) { + /* NOTE: there's no way I know to return the actual cpu index portably... + most code that's using this is using it to shard across work queues though, + so here we use thread identity instead to achieve a similar though not + identical effect */ + static gpr_once once = GPR_ONCE_INIT; + gpr_once_init(&once, init_thread_id_key); + + unsigned int* thread_id = + static_cast(pthread_getspecific(thread_id_key)); + if (thread_id == nullptr) { + // Note we cannot use gpr_malloc here because this allocation can happen in + // a main thread and will only be free'd when the main thread exits, which + // will cause our internal memory counters to believe it is a leak. + thread_id = static_cast(malloc(sizeof(unsigned int))); + pthread_setspecific(thread_id_key, thread_id); + } + + return (unsigned)grpc_core::HashPointer(thread_id, gpr_cpu_num_cores()); +} + +#endif /* GPR_CPU_POSIX */ diff --git a/src/core/lib/gpr/cpu_windows.cc b/src/core/lib/gpr/cpu_windows.cc new file mode 100644 index 00000000..8d894534 --- /dev/null +++ b/src/core/lib/gpr/cpu_windows.cc @@ -0,0 +1,33 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#ifdef GPR_WINDOWS +#include +#include + +unsigned gpr_cpu_num_cores(void) { + SYSTEM_INFO si; + GetSystemInfo(&si); + return si.dwNumberOfProcessors; +} + +unsigned gpr_cpu_current_cpu(void) { return GetCurrentProcessorNumber(); } + +#endif /* GPR_WINDOWS */ diff --git a/src/core/lib/gpr/env_linux.cc b/src/core/lib/gpr/env_linux.cc new file mode 100644 index 00000000..4b332468 --- /dev/null +++ b/src/core/lib/gpr/env_linux.cc @@ -0,0 +1,75 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* for secure_getenv. */ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include + +#ifdef GPR_LINUX_ENV + +#include +#include +#include +#include + +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" + +char* gpr_getenv(const char* name) { + char* result = nullptr; +#if defined(GPR_BACKWARDS_COMPATIBILITY_MODE) + typedef char* (*getenv_type)(const char*); + static getenv_type getenv_func = nullptr; + /* Check to see which getenv variant is supported (go from most + * to least secure) */ + if (getenv_func == nullptr) { + const char* names[] = {"secure_getenv", "__secure_getenv", "getenv"}; + for (size_t i = 0; i < GPR_ARRAY_SIZE(names); i++) { + getenv_func = (getenv_type)dlsym(RTLD_DEFAULT, names[i]); + if (getenv_func != nullptr) { + break; + } + } + } + result = getenv_func(name); +#elif __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 17) + result = secure_getenv(name); +#else + result = getenv(name); +#endif + return result == nullptr ? result : gpr_strdup(result); +} + +void gpr_setenv(const char* name, const char* value) { + int res = setenv(name, value, 1); + GPR_ASSERT(res == 0); +} + +void gpr_unsetenv(const char* name) { + int res = unsetenv(name); + GPR_ASSERT(res == 0); +} + +#endif /* GPR_LINUX_ENV */ diff --git a/src/core/lib/gpr/env_posix.cc b/src/core/lib/gpr/env_posix.cc new file mode 100644 index 00000000..fb2a21c6 --- /dev/null +++ b/src/core/lib/gpr/env_posix.cc @@ -0,0 +1,46 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_POSIX_ENV + +#include + +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" + +char* gpr_getenv(const char* name) { + char* result = getenv(name); + return result == nullptr ? 
result : gpr_strdup(result); +} + +void gpr_setenv(const char* name, const char* value) { + int res = setenv(name, value, 1); + GPR_ASSERT(res == 0); +} + +void gpr_unsetenv(const char* name) { + int res = unsetenv(name); + GPR_ASSERT(res == 0); +} + +#endif /* GPR_POSIX_ENV */ diff --git a/src/core/lib/gpr/env_windows.cc b/src/core/lib/gpr/env_windows.cc new file mode 100644 index 00000000..76c45fb8 --- /dev/null +++ b/src/core/lib/gpr/env_windows.cc @@ -0,0 +1,74 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_WINDOWS_ENV + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/string_windows.h" + +char* gpr_getenv(const char* name) { + char* result = NULL; + DWORD size; + LPTSTR tresult = NULL; + LPTSTR tname = gpr_char_to_tchar(name); + DWORD ret; + + ret = GetEnvironmentVariable(tname, NULL, 0); + if (ret == 0) { + gpr_free(tname); + return NULL; + } + size = ret * (DWORD)sizeof(TCHAR); + tresult = (LPTSTR)gpr_malloc(size); + ret = GetEnvironmentVariable(tname, tresult, size); + gpr_free(tname); + if (ret == 0) { + gpr_free(tresult); + return NULL; + } + result = gpr_tchar_to_char(tresult); + gpr_free(tresult); + return result; +} + +void gpr_setenv(const char* name, const char* value) { + LPTSTR tname = gpr_char_to_tchar(name); + LPTSTR tvalue = gpr_char_to_tchar(value); + BOOL res = SetEnvironmentVariable(tname, tvalue); + gpr_free(tname); + gpr_free(tvalue); + GPR_ASSERT(res); +} + +void gpr_unsetenv(const char* name) { + LPTSTR tname = gpr_char_to_tchar(name); + BOOL res = SetEnvironmentVariable(tname, NULL); + gpr_free(tname); + GPR_ASSERT(res); +} + +#endif /* GPR_WINDOWS_ENV */ diff --git a/src/core/lib/gpr/log.cc b/src/core/lib/gpr/log.cc new file mode 100644 index 00000000..d18f52b4 --- /dev/null +++ b/src/core/lib/gpr/log.cc @@ -0,0 +1,140 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
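One detail of the env helpers above that is easy to miss: unlike getenv(), gpr_getenv() hands back a gpr_strdup()'d copy that the caller must gpr_free(). A rough sketch, assuming an in-tree build (env.h is an internal header) and an invented GRPC_DEMO_VAR variable name:

#include <stdio.h>

#include <grpc/support/alloc.h>

#include "src/core/lib/gpr/env.h"

int main() {
  gpr_setenv("GRPC_DEMO_VAR", "hello");
  // Returned string is a copy owned by the caller, or nullptr if unset.
  char* value = gpr_getenv("GRPC_DEMO_VAR");
  if (value != nullptr) {
    printf("GRPC_DEMO_VAR=%s\n", value);
    gpr_free(value);
  }
  gpr_unsetenv("GRPC_DEMO_VAR");
  return 0;
}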
+ * + */ + +#include + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/global_config.h" + +#ifndef GPR_DEFAULT_LOG_VERBOSITY_STRING +#define GPR_DEFAULT_LOG_VERBOSITY_STRING "ERROR" +#endif // !GPR_DEFAULT_LOG_VERBOSITY_STRING + +GPR_GLOBAL_CONFIG_DEFINE_STRING(grpc_verbosity, + GPR_DEFAULT_LOG_VERBOSITY_STRING, + "Default gRPC logging verbosity") +GPR_GLOBAL_CONFIG_DEFINE_STRING(grpc_stacktrace_minloglevel, "", + "Messages logged at the same or higher level " + "than this will print stacktrace") + +static constexpr gpr_atm GPR_LOG_SEVERITY_UNSET = GPR_LOG_SEVERITY_ERROR + 10; +static constexpr gpr_atm GPR_LOG_SEVERITY_NONE = GPR_LOG_SEVERITY_ERROR + 11; + +void gpr_default_log(gpr_log_func_args* args); +static gpr_atm g_log_func = reinterpret_cast(gpr_default_log); +static gpr_atm g_min_severity_to_print = GPR_LOG_SEVERITY_UNSET; +static gpr_atm g_min_severity_to_print_stacktrace = GPR_LOG_SEVERITY_UNSET; + +const char* gpr_log_severity_string(gpr_log_severity severity) { + switch (severity) { + case GPR_LOG_SEVERITY_DEBUG: + return "D"; + case GPR_LOG_SEVERITY_INFO: + return "I"; + case GPR_LOG_SEVERITY_ERROR: + return "E"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +int gpr_should_log(gpr_log_severity severity) { + return static_cast(severity) >= + gpr_atm_no_barrier_load(&g_min_severity_to_print) + ? 1 + : 0; +} + +int gpr_should_log_stacktrace(gpr_log_severity severity) { + return static_cast(severity) >= + gpr_atm_no_barrier_load(&g_min_severity_to_print_stacktrace) + ? 1 + : 0; +} + +void gpr_log_message(const char* file, int line, gpr_log_severity severity, + const char* message) { + if (gpr_should_log(severity) == 0) { + return; + } + + gpr_log_func_args lfargs; + memset(&lfargs, 0, sizeof(lfargs)); + lfargs.file = file; + lfargs.line = line; + lfargs.severity = severity; + lfargs.message = message; + reinterpret_cast(gpr_atm_no_barrier_load(&g_log_func))(&lfargs); +} + +void gpr_set_log_verbosity(gpr_log_severity min_severity_to_print) { + gpr_atm_no_barrier_store(&g_min_severity_to_print, + (gpr_atm)min_severity_to_print); +} + +static gpr_atm parse_log_severity(const char* str, gpr_atm error_value) { + if (gpr_stricmp(str, "DEBUG") == 0) { + return GPR_LOG_SEVERITY_DEBUG; + } else if (gpr_stricmp(str, "INFO") == 0) { + return GPR_LOG_SEVERITY_INFO; + } else if (gpr_stricmp(str, "ERROR") == 0) { + return GPR_LOG_SEVERITY_ERROR; + } else if (gpr_stricmp(str, "NONE") == 0) { + return GPR_LOG_SEVERITY_NONE; + } else { + return error_value; + } +} + +void gpr_log_verbosity_init() { + // init verbosity when it hasn't been set + if ((gpr_atm_no_barrier_load(&g_min_severity_to_print)) == + GPR_LOG_SEVERITY_UNSET) { + grpc_core::UniquePtr verbosity = + GPR_GLOBAL_CONFIG_GET(grpc_verbosity); + gpr_atm min_severity_to_print = GPR_LOG_SEVERITY_ERROR; + if (strlen(verbosity.get()) > 0) { + min_severity_to_print = + parse_log_severity(verbosity.get(), min_severity_to_print); + } + gpr_atm_no_barrier_store(&g_min_severity_to_print, min_severity_to_print); + } + // init stacktrace_minloglevel when it hasn't been set + if ((gpr_atm_no_barrier_load(&g_min_severity_to_print_stacktrace)) == + GPR_LOG_SEVERITY_UNSET) { + grpc_core::UniquePtr stacktrace_minloglevel = + GPR_GLOBAL_CONFIG_GET(grpc_stacktrace_minloglevel); + gpr_atm min_severity_to_print_stacktrace = GPR_LOG_SEVERITY_NONE; + if (strlen(stacktrace_minloglevel.get()) > 0) { + min_severity_to_print_stacktrace = parse_log_severity( + 
stacktrace_minloglevel.get(), min_severity_to_print_stacktrace); + } + gpr_atm_no_barrier_store(&g_min_severity_to_print_stacktrace, + min_severity_to_print_stacktrace); + } +} + +void gpr_set_log_function(gpr_log_func f) { + gpr_atm_no_barrier_store(&g_log_func, (gpr_atm)(f ? f : gpr_default_log)); +} diff --git a/src/core/lib/gpr/log_android.cc b/src/core/lib/gpr/log_android.cc new file mode 100644 index 00000000..11ffd64d --- /dev/null +++ b/src/core/lib/gpr/log_android.cc @@ -0,0 +1,77 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_ANDROID + +#include +#include +#include +#include + +#include +#include + +static android_LogPriority severity_to_log_priority(gpr_log_severity severity) { + switch (severity) { + case GPR_LOG_SEVERITY_DEBUG: + return ANDROID_LOG_DEBUG; + case GPR_LOG_SEVERITY_INFO: + return ANDROID_LOG_INFO; + case GPR_LOG_SEVERITY_ERROR: + return ANDROID_LOG_ERROR; + } + return ANDROID_LOG_DEFAULT; +} + +void gpr_log(const char* file, int line, gpr_log_severity severity, + const char* format, ...) { + /* Avoid message construction if gpr_log_message won't log */ + if (gpr_should_log(severity) == 0) { + return; + } + char* message = NULL; + va_list args; + va_start(args, format); + vasprintf(&message, format, args); + va_end(args); + gpr_log_message(file, line, severity, message); + free(message); +} + +void gpr_default_log(gpr_log_func_args* args) { + const char* final_slash; + const char* display_file; + char* output = NULL; + + final_slash = strrchr(args->file, '/'); + if (final_slash == NULL) + display_file = args->file; + else + display_file = final_slash + 1; + + asprintf(&output, "%s:%d] %s", display_file, args->line, args->message); + + __android_log_write(severity_to_log_priority(args->severity), "GRPC", output); + + /* allocated by asprintf => use free, not gpr_free */ + free(output); +} + +#endif /* GPR_ANDROID */ diff --git a/src/core/lib/gpr/log_linux.cc b/src/core/lib/gpr/log_linux.cc new file mode 100644 index 00000000..850ee13a --- /dev/null +++ b/src/core/lib/gpr/log_linux.cc @@ -0,0 +1,114 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
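The log.cc machinery above routes every message through a single replaceable gpr_log_func after the verbosity check. A small sketch of installing a custom sink and raising the minimum severity, assuming the public <grpc/support/log.h> API:

#include <stdio.h>

#include <grpc/support/log.h>

// Drop-in replacement for gpr_default_log(); receives the already-formatted
// message plus file/line/severity, as assembled by gpr_log_message().
static void my_log(gpr_log_func_args* args) {
  fprintf(stderr, "[%s] %s:%d %s\n", gpr_log_severity_string(args->severity),
          args->file, args->line, args->message);
}

int main() {
  gpr_set_log_function(my_log);
  gpr_set_log_verbosity(GPR_LOG_SEVERITY_INFO);  // DEBUG messages are dropped
  gpr_log(GPR_DEBUG, "filtered out by the verbosity check");
  gpr_log(GPR_INFO, "this one reaches my_log (%d)", 42);
  return 0;
}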
+ * + */ + +#ifndef _POSIX_SOURCE +#define _POSIX_SOURCE +#endif + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include + +#ifdef GPR_LINUX_LOG + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gprpp/examine_stack.h" + +int gpr_should_log_stacktrace(gpr_log_severity severity); + +static long sys_gettid(void) { return syscall(__NR_gettid); } + +void gpr_log(const char* file, int line, gpr_log_severity severity, + const char* format, ...) { + /* Avoid message construction if gpr_log_message won't log */ + if (gpr_should_log(severity) == 0) { + return; + } + char* message = nullptr; + va_list args; + va_start(args, format); + if (vasprintf(&message, format, args) == -1) { + va_end(args); + return; + } + va_end(args); + gpr_log_message(file, line, severity, message); + /* message has been allocated by vasprintf above, and needs free */ + free(message); +} + +void gpr_default_log(gpr_log_func_args* args) { + const char* final_slash; + const char* display_file; + char time_buffer[64]; + time_t timer; + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + struct tm tm; + static GPR_THREAD_LOCAL(long) tid(0); + if (tid == 0) tid = sys_gettid(); + + timer = static_cast(now.tv_sec); + final_slash = strrchr(args->file, '/'); + if (final_slash == nullptr) { + display_file = args->file; + } else { + display_file = final_slash + 1; + } + + if (!localtime_r(&timer, &tm)) { + strcpy(time_buffer, "error:localtime"); + } else if (0 == + strftime(time_buffer, sizeof(time_buffer), "%m%d %H:%M:%S", &tm)) { + strcpy(time_buffer, "error:strftime"); + } + + std::string prefix = absl::StrFormat( + "%s%s.%09" PRId32 " %7ld %s:%d]", gpr_log_severity_string(args->severity), + time_buffer, now.tv_nsec, tid, display_file, args->line); + + absl::optional stack_trace = + gpr_should_log_stacktrace(args->severity) + ? grpc_core::GetCurrentStackTrace() + : absl::nullopt; + if (stack_trace) { + fprintf(stderr, "%-60s %s\n%s\n", prefix.c_str(), args->message, + stack_trace->c_str()); + } else { + fprintf(stderr, "%-60s %s\n", prefix.c_str(), args->message); + } +} + +#endif /* GPR_LINUX_LOG */ diff --git a/src/core/lib/gpr/log_posix.cc b/src/core/lib/gpr/log_posix.cc new file mode 100644 index 00000000..2c28bddd --- /dev/null +++ b/src/core/lib/gpr/log_posix.cc @@ -0,0 +1,110 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_POSIX_LOG + +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/lib/gprpp/examine_stack.h" + +int gpr_should_log_stacktrace(gpr_log_severity severity); + +static intptr_t sys_gettid(void) { return (intptr_t)pthread_self(); } + +void gpr_log(const char* file, int line, gpr_log_severity severity, + const char* format, ...) 
{ + /* Avoid message construction if gpr_log_message won't log */ + if (gpr_should_log(severity) == 0) { + return; + } + char buf[64]; + char* allocated = nullptr; + char* message = nullptr; + int ret; + va_list args; + va_start(args, format); + ret = vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + if (ret < 0) { + message = nullptr; + } else if ((size_t)ret <= sizeof(buf) - 1) { + message = buf; + } else { + message = allocated = (char*)gpr_malloc((size_t)ret + 1); + va_start(args, format); + vsnprintf(message, (size_t)(ret + 1), format, args); + va_end(args); + } + gpr_log_message(file, line, severity, message); + gpr_free(allocated); +} + +void gpr_default_log(gpr_log_func_args* args) { + const char* final_slash; + const char* display_file; + char time_buffer[64]; + time_t timer; + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + struct tm tm; + + timer = (time_t)now.tv_sec; + final_slash = strrchr(args->file, '/'); + if (final_slash == nullptr) + display_file = args->file; + else + display_file = final_slash + 1; + + if (!localtime_r(&timer, &tm)) { + strcpy(time_buffer, "error:localtime"); + } else if (0 == + strftime(time_buffer, sizeof(time_buffer), "%m%d %H:%M:%S", &tm)) { + strcpy(time_buffer, "error:strftime"); + } + + std::string prefix = absl::StrFormat( + "%s%s.%09d %7" PRIdPTR " %s:%d]", gpr_log_severity_string(args->severity), + time_buffer, (int)(now.tv_nsec), sys_gettid(), display_file, args->line); + + absl::optional stack_trace = + gpr_should_log_stacktrace(args->severity) + ? grpc_core::GetCurrentStackTrace() + : absl::nullopt; + if (stack_trace) { + fprintf(stderr, "%-70s %s\n%s\n", prefix.c_str(), args->message, + stack_trace->c_str()); + } else { + fprintf(stderr, "%-70s %s\n", prefix.c_str(), args->message); + } +} + +#endif /* defined(GPR_POSIX_LOG) */ diff --git a/src/core/lib/gpr/log_windows.cc b/src/core/lib/gpr/log_windows.cc new file mode 100644 index 00000000..472121b3 --- /dev/null +++ b/src/core/lib/gpr/log_windows.cc @@ -0,0 +1,116 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_WINDOWS_LOG + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/string_windows.h" +#include "src/core/lib/gprpp/examine_stack.h" + +int gpr_should_log_stacktrace(gpr_log_severity severity); + +void gpr_log(const char* file, int line, gpr_log_severity severity, + const char* format, ...) { + /* Avoid message construction if gpr_log_message won't log */ + if (gpr_should_log(severity) == 0) { + return; + } + + char* message = NULL; + va_list args; + int ret; + + /* Determine the length. */ + va_start(args, format); + ret = _vscprintf(format, args); + va_end(args); + if (ret < 0) { + message = NULL; + } else { + /* Allocate a new buffer, with space for the NUL terminator. */ + size_t strp_buflen = (size_t)ret + 1; + message = (char*)gpr_malloc(strp_buflen); + + /* Print to the buffer. 
*/ + va_start(args, format); + ret = vsnprintf_s(message, strp_buflen, _TRUNCATE, format, args); + va_end(args); + if ((size_t)ret != strp_buflen - 1) { + /* This should never happen. */ + gpr_free(message); + message = NULL; + } + } + + gpr_log_message(file, line, severity, message); + gpr_free(message); +} + +/* Simple starter implementation */ +void gpr_default_log(gpr_log_func_args* args) { + const char* final_slash; + const char* display_file; + char time_buffer[64]; + time_t timer; + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + struct tm tm; + + timer = (time_t)now.tv_sec; + final_slash = strrchr(args->file, '\\'); + if (final_slash == NULL) + display_file = args->file; + else + display_file = final_slash + 1; + + if (localtime_s(&tm, &timer)) { + strcpy(time_buffer, "error:localtime"); + } else if (0 == + strftime(time_buffer, sizeof(time_buffer), "%m%d %H:%M:%S", &tm)) { + strcpy(time_buffer, "error:strftime"); + } + + absl::optional stack_trace = + gpr_should_log_stacktrace(args->severity) + ? grpc_core::GetCurrentStackTrace() + : absl::nullopt; + if (stack_trace) { + fprintf(stderr, "%s%s.%09u %5lu %s:%d] %s\n%s\n", + gpr_log_severity_string(args->severity), time_buffer, + (int)(now.tv_nsec), GetCurrentThreadId(), display_file, args->line, + args->message, stack_trace->c_str()); + } else { + fprintf(stderr, "%s%s.%09u %5lu %s:%d] %s\n", + gpr_log_severity_string(args->severity), time_buffer, + (int)(now.tv_nsec), GetCurrentThreadId(), display_file, args->line, + args->message); + } + fflush(stderr); +} + +#endif /* GPR_WINDOWS_LOG */ diff --git a/src/core/lib/gpr/murmur_hash.cc b/src/core/lib/gpr/murmur_hash.cc new file mode 100644 index 00000000..063b104b --- /dev/null +++ b/src/core/lib/gpr/murmur_hash.cc @@ -0,0 +1,82 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
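log_posix.cc and the Windows variants above all rely on the same two-pass formatting idea: try a small stack buffer (or _vscprintf) to learn the length, then allocate exactly that much and format again. A standalone sketch of the pattern with plain vsnprintf and malloc:

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Format into a stack buffer first; fall back to a heap buffer sized from
// vsnprintf's return value only when the message does not fit.
static char* format_message(const char* fmt, ...) {
  char buf[64];
  va_list args;
  va_start(args, fmt);
  int needed = vsnprintf(buf, sizeof(buf), fmt, args);
  va_end(args);
  if (needed < 0) return nullptr;
  char* out = static_cast<char*>(malloc(static_cast<size_t>(needed) + 1));
  if (static_cast<size_t>(needed) < sizeof(buf)) {
    memcpy(out, buf, static_cast<size_t>(needed) + 1);
    return out;
  }
  va_start(args, fmt);
  vsnprintf(out, static_cast<size_t>(needed) + 1, fmt, args);
  va_end(args);
  return out;
}

int main() {
  char* msg = format_message("value=%d name=%s", 7, "something fairly long");
  printf("%s\n", msg);
  free(msg);
}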
+ * + */ + +#include + +#include "src/core/lib/gpr/murmur_hash.h" + +#include + +#include "absl/base/attributes.h" + +#define ROTL32(x, r) (((x) << (r)) | ((x) >> (32 - (r)))) + +#define FMIX32(h) \ + (h) ^= (h) >> 16; \ + (h) *= 0x85ebca6b; \ + (h) ^= (h) >> 13; \ + (h) *= 0xc2b2ae35; \ + (h) ^= (h) >> 16; + +uint32_t gpr_murmur_hash3(const void* key, size_t len, uint32_t seed) { + uint32_t h1 = seed; + uint32_t k1; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + const uint8_t* keyptr = static_cast(key); + const size_t bsize = sizeof(k1); + const size_t nblocks = len / bsize; + + /* body */ + for (size_t i = 0; i < nblocks; i++, keyptr += bsize) { + memcpy(&k1, keyptr, bsize); + + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + + k1 = 0; + + /* tail */ + switch (len & 3) { + case 3: + k1 ^= (static_cast(keyptr[2])) << 16; + ABSL_FALLTHROUGH_INTENDED; + case 2: + k1 ^= (static_cast(keyptr[1])) << 8; + ABSL_FALLTHROUGH_INTENDED; + case 1: + k1 ^= keyptr[0]; + k1 *= c1; + k1 = ROTL32(k1, 15); + k1 *= c2; + h1 ^= k1; + }; + + /* finalization */ + h1 ^= static_cast(len); + FMIX32(h1); + return h1; +} diff --git a/src/core/lib/gpr/string.cc b/src/core/lib/gpr/string.cc new file mode 100644 index 00000000..9f12ed47 --- /dev/null +++ b/src/core/lib/gpr/string.cc @@ -0,0 +1,343 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gpr/string.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" + +char* gpr_strdup(const char* src) { + char* dst; + size_t len; + + if (!src) { + return nullptr; + } + + len = strlen(src) + 1; + dst = static_cast(gpr_malloc(len)); + + memcpy(dst, src, len); + + return dst; +} + +std::string gpr_format_timespec(gpr_timespec tm) { + char time_buffer[35]; + char ns_buffer[11]; // '.' + 9 digits of precision + struct tm* tm_info = localtime(reinterpret_cast(&tm.tv_sec)); + strftime(time_buffer, sizeof(time_buffer), "%Y-%m-%dT%H:%M:%S", tm_info); + snprintf(ns_buffer, 11, ".%09d", tm.tv_nsec); + // This loop trims off trailing zeros by inserting a null character that the + // right point. We iterate in chunks of three because we want 0, 3, 6, or 9 + // fractional digits. + for (int i = 7; i >= 1; i -= 3) { + if (ns_buffer[i] == '0' && ns_buffer[i + 1] == '0' && + ns_buffer[i + 2] == '0') { + ns_buffer[i] = '\0'; + // Edge case in which all fractional digits were 0. 
+ if (i == 1) { + ns_buffer[0] = '\0'; + } + } else { + break; + } + } + return absl::StrCat(time_buffer, ns_buffer, "Z"); +} + +struct dump_out { + size_t capacity; + size_t length; + char* data; +}; + +static dump_out dump_out_create(void) { + dump_out r = {0, 0, nullptr}; + return r; +} + +static void dump_out_append(dump_out* out, char c) { + if (out->length == out->capacity) { + out->capacity = std::max(size_t(8), 2 * out->capacity); + out->data = static_cast(gpr_realloc(out->data, out->capacity)); + } + out->data[out->length++] = c; +} + +static void hexdump(dump_out* out, const char* buf, size_t len) { + static const char* hex = "0123456789abcdef"; + + const uint8_t* const beg = reinterpret_cast(buf); + const uint8_t* const end = beg + len; + const uint8_t* cur; + + for (cur = beg; cur != end; ++cur) { + if (cur != beg) dump_out_append(out, ' '); + dump_out_append(out, hex[*cur >> 4]); + dump_out_append(out, hex[*cur & 0xf]); + } +} + +static void asciidump(dump_out* out, const char* buf, size_t len) { + const uint8_t* const beg = reinterpret_cast(buf); + const uint8_t* const end = beg + len; + const uint8_t* cur; + int out_was_empty = (out->length == 0); + if (!out_was_empty) { + dump_out_append(out, ' '); + dump_out_append(out, '\''); + } + for (cur = beg; cur != end; ++cur) { + dump_out_append( + out, (isprint(*cur) ? *reinterpret_cast(cur) : '.')); + } + if (!out_was_empty) { + dump_out_append(out, '\''); + } +} + +char* gpr_dump_return_len(const char* buf, size_t len, uint32_t flags, + size_t* out_len) { + dump_out out = dump_out_create(); + if (flags & GPR_DUMP_HEX) { + hexdump(&out, buf, len); + } + if (flags & GPR_DUMP_ASCII) { + asciidump(&out, buf, len); + } + dump_out_append(&out, 0); + *out_len = out.length; + return out.data; +} + +char* gpr_dump(const char* buf, size_t len, uint32_t flags) { + size_t unused; + return gpr_dump_return_len(buf, len, flags, &unused); +} + +int gpr_parse_bytes_to_uint32(const char* buf, size_t len, uint32_t* result) { + uint32_t out = 0; + uint32_t new_val; + size_t i; + + if (len == 0) return 0; /* must have some bytes */ + + for (i = 0; i < len; i++) { + if (buf[i] < '0' || buf[i] > '9') return 0; /* bad char */ + new_val = 10 * out + static_cast(buf[i] - '0'); + if (new_val < out) return 0; /* overflow */ + out = new_val; + } + + *result = out; + return 1; +} + +void gpr_reverse_bytes(char* str, int len) { + char *p1, *p2; + for (p1 = str, p2 = str + len - 1; p2 > p1; ++p1, --p2) { + char temp = *p1; + *p1 = *p2; + *p2 = temp; + } +} + +int gpr_ltoa(long value, char* output) { + long sign; + int i = 0; + + if (value == 0) { + output[0] = '0'; + output[1] = 0; + return 1; + } + + sign = value < 0 ? -1 : 1; + while (value) { + output[i++] = static_cast('0' + sign * (value % 10)); + value /= 10; + } + if (sign < 0) output[i++] = '-'; + gpr_reverse_bytes(output, i); + output[i] = 0; + return i; +} + +int int64_ttoa(int64_t value, char* output) { + int64_t sign; + int i = 0; + + if (value == 0) { + output[0] = '0'; + output[1] = 0; + return 1; + } + + sign = value < 0 ? 
-1 : 1; + while (value) { + output[i++] = static_cast('0' + sign * (value % 10)); + value /= 10; + } + if (sign < 0) output[i++] = '-'; + gpr_reverse_bytes(output, i); + output[i] = 0; + return i; +} + +int gpr_parse_nonnegative_int(const char* value) { + char* end; + long result = strtol(value, &end, 10); + if (*end != '\0' || result < 0 || result > INT_MAX) return -1; + return static_cast(result); +} + +char* gpr_leftpad(const char* str, char flag, size_t length) { + const size_t str_length = strlen(str); + const size_t out_length = str_length > length ? str_length : length; + char* out = static_cast(gpr_malloc(out_length + 1)); + memset(out, flag, out_length - str_length); + memcpy(out + out_length - str_length, str, str_length); + out[out_length] = 0; + return out; +} + +char* gpr_strjoin(const char** strs, size_t nstrs, size_t* final_length) { + return gpr_strjoin_sep(strs, nstrs, "", final_length); +} + +char* gpr_strjoin_sep(const char** strs, size_t nstrs, const char* sep, + size_t* final_length) { + const size_t sep_len = strlen(sep); + size_t out_length = 0; + size_t i; + char* out; + for (i = 0; i < nstrs; i++) { + out_length += strlen(strs[i]); + } + out_length += 1; /* null terminator */ + if (nstrs > 0) { + out_length += sep_len * (nstrs - 1); /* separators */ + } + out = static_cast(gpr_malloc(out_length)); + out_length = 0; + for (i = 0; i < nstrs; i++) { + const size_t slen = strlen(strs[i]); + if (i != 0) { + memcpy(out + out_length, sep, sep_len); + out_length += sep_len; + } + memcpy(out + out_length, strs[i], slen); + out_length += slen; + } + out[out_length] = 0; + if (final_length != nullptr) { + *final_length = out_length; + } + return out; +} + +int gpr_strincmp(const char* a, const char* b, size_t n) { + int ca, cb; + do { + ca = tolower(*a); + cb = tolower(*b); + ++a; + ++b; + --n; + } while (ca == cb && ca != 0 && cb != 0 && n != 0); + return ca - cb; +} + +int gpr_stricmp(const char* a, const char* b) { + return gpr_strincmp(a, b, SIZE_MAX); +} + +static void add_string_to_split(const char* beg, const char* end, char*** strs, + size_t* nstrs, size_t* capstrs) { + char* out = + static_cast(gpr_malloc(static_cast(end - beg) + 1)); + memcpy(out, beg, static_cast(end - beg)); + out[end - beg] = 0; + if (*nstrs == *capstrs) { + *capstrs = std::max(size_t(8), 2 * *capstrs); + *strs = static_cast(gpr_realloc(*strs, sizeof(*strs) * *capstrs)); + } + (*strs)[*nstrs] = out; + ++*nstrs; +} + +void gpr_string_split(const char* input, const char* sep, char*** strs, + size_t* nstrs) { + const char* next; + *strs = nullptr; + *nstrs = 0; + size_t capstrs = 0; + while ((next = strstr(input, sep))) { + add_string_to_split(input, next, strs, nstrs, &capstrs); + input = next + strlen(sep); + } + add_string_to_split(input, input + strlen(input), strs, nstrs, &capstrs); +} + +void* gpr_memrchr(const void* s, int c, size_t n) { + if (s == nullptr) return nullptr; + char* b = const_cast(reinterpret_cast(s)); + size_t i; + for (i = 0; i < n; i++) { + if (b[n - i - 1] == c) { + return &b[n - i - 1]; + } + } + return nullptr; +} + +bool gpr_parse_bool_value(const char* value, bool* dst) { + const char* kTrue[] = {"1", "t", "true", "y", "yes"}; + const char* kFalse[] = {"0", "f", "false", "n", "no"}; + static_assert(sizeof(kTrue) == sizeof(kFalse), "true_false_equal"); + + if (value == nullptr) { + return false; + } + for (size_t i = 0; i < GPR_ARRAY_SIZE(kTrue); ++i) { + if (gpr_stricmp(value, kTrue[i]) == 0) { + *dst = true; + return true; + } else if (gpr_stricmp(value, kFalse[i]) == 
0) { + *dst = false; + return true; + } + } + return false; // didn't match a legal input +} diff --git a/src/core/lib/gpr/string_posix.cc b/src/core/lib/gpr/string_posix.cc new file mode 100644 index 00000000..d32775fb --- /dev/null +++ b/src/core/lib/gpr/string_posix.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_POSIX_STRING + +#include +#include +#include + +#include +#include + +int gpr_asprintf(char** strp, const char* format, ...) { + va_list args; + int ret; + char buf[64]; + size_t strp_buflen; + + /* Use a constant-sized buffer to determine the length. */ + va_start(args, format); + ret = vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + if (ret < 0) { + *strp = nullptr; + return -1; + } + + /* Allocate a new buffer, with space for the NUL terminator. */ + strp_buflen = static_cast(ret) + 1; + if ((*strp = static_cast(gpr_malloc(strp_buflen))) == nullptr) { + /* This shouldn't happen, because gpr_malloc() calls abort(). */ + return -1; + } + + /* Return early if we have all the bytes. */ + if (strp_buflen <= sizeof(buf)) { + memcpy(*strp, buf, strp_buflen); + return ret; + } + + /* Try again using the larger buffer. */ + va_start(args, format); + ret = vsnprintf(*strp, strp_buflen, format, args); + va_end(args); + if (static_cast(ret) == strp_buflen - 1) { + return ret; + } + + /* This should never happen. */ + gpr_free(*strp); + *strp = nullptr; + return -1; +} + +#endif /* GPR_POSIX_STRING */ diff --git a/src/core/lib/gpr/string_util_windows.cc b/src/core/lib/gpr/string_util_windows.cc new file mode 100644 index 00000000..8c8c99cd --- /dev/null +++ b/src/core/lib/gpr/string_util_windows.cc @@ -0,0 +1,82 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Posix code for gpr snprintf support. */ + +#include + +#ifdef GPR_WINDOWS + +/* Some platforms (namely msys) need wchar to be included BEFORE + anything else, especially strsafe.h. 
*/ +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/string_windows.h" + +#if defined UNICODE || defined _UNICODE +LPTSTR +gpr_char_to_tchar(LPCSTR input) { + LPTSTR ret; + int needed = MultiByteToWideChar(CP_UTF8, 0, input, -1, NULL, 0); + if (needed <= 0) return NULL; + ret = (LPTSTR)gpr_malloc((unsigned)needed * sizeof(TCHAR)); + MultiByteToWideChar(CP_UTF8, 0, input, -1, ret, needed); + return ret; +} + +LPSTR +gpr_tchar_to_char(LPCTSTR input) { + LPSTR ret; + int needed = WideCharToMultiByte(CP_UTF8, 0, input, -1, NULL, 0, NULL, NULL); + if (needed <= 0) return NULL; + ret = (LPSTR)gpr_malloc((unsigned)needed); + WideCharToMultiByte(CP_UTF8, 0, input, -1, ret, needed, NULL, NULL); + return ret; +} +#else +LPSTR gpr_tchar_to_char(LPCTSTR input) { return (LPSTR)gpr_strdup(input); } + +LPTSTR gpr_char_to_tchar(LPCTSTR input) { return (LPTSTR)gpr_strdup(input); } +#endif + +char* gpr_format_message(int messageid) { + LPTSTR tmessage; + char* message; + DWORD status = FormatMessage( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, (DWORD)messageid, MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT), + (LPTSTR)(&tmessage), 0, NULL); + if (status == 0) return gpr_strdup("Unable to retrieve error string"); + message = gpr_tchar_to_char(tmessage); + LocalFree(tmessage); + return message; +} + +#endif /* GPR_WINDOWS */ diff --git a/src/core/lib/gpr/string_windows.cc b/src/core/lib/gpr/string_windows.cc new file mode 100644 index 00000000..25bfd412 --- /dev/null +++ b/src/core/lib/gpr/string_windows.cc @@ -0,0 +1,69 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Windows code for gpr snprintf support. */ + +#include + +#ifdef GPR_WINDOWS_STRING + +#include +#include +#include + +#include +#include + +#include "src/core/lib/gpr/string.h" + +int gpr_asprintf(char** strp, const char* format, ...) { + va_list args; + int ret; + size_t strp_buflen; + + /* Determine the length. */ + va_start(args, format); + ret = _vscprintf(format, args); + va_end(args); + if (ret < 0) { + *strp = NULL; + return -1; + } + + /* Allocate a new buffer, with space for the NUL terminator. */ + strp_buflen = (size_t)ret + 1; + if ((*strp = (char*)gpr_malloc(strp_buflen)) == NULL) { + /* This shouldn't happen, because gpr_malloc() calls abort(). */ + return -1; + } + + /* Print to the buffer. */ + va_start(args, format); + ret = vsnprintf_s(*strp, strp_buflen, _TRUNCATE, format, args); + va_end(args); + if ((size_t)ret == strp_buflen - 1) { + return ret; + } + + /* This should never happen. */ + gpr_free(*strp); + *strp = NULL; + return -1; +} + +#endif /* GPR_WINDOWS_STRING */ diff --git a/src/core/lib/gpr/sync.cc b/src/core/lib/gpr/sync.cc new file mode 100644 index 00000000..28d506f7 --- /dev/null +++ b/src/core/lib/gpr/sync.cc @@ -0,0 +1,124 @@ +/* + * + * Copyright 2015 gRPC authors. 
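A quick usage sketch for the string helpers added above (gpr_strjoin_sep and gpr_string_split); it assumes an in-tree build since string.h is an internal header, and every buffer these helpers return is gpr_malloc()'d and owned by the caller:

#include <stdio.h>

#include <grpc/support/alloc.h>

#include "src/core/lib/gpr/string.h"

int main() {
  const char* parts[] = {"a", "b", "c"};
  size_t joined_len;
  char* joined = gpr_strjoin_sep(parts, 3, ",", &joined_len);
  printf("joined: %s (%zu bytes)\n", joined, joined_len);

  char** pieces;
  size_t npieces;
  gpr_string_split(joined, ",", &pieces, &npieces);
  for (size_t i = 0; i < npieces; i++) {
    printf("piece %zu: %s\n", i, pieces[i]);
    gpr_free(pieces[i]);  // each piece is individually allocated
  }
  gpr_free(pieces);
  gpr_free(joined);
  return 0;
}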
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Generic implementation of synchronization primitives. */ + +#include + +#include + +#include +#include +#include + +/* Number of mutexes to allocate for events, to avoid lock contention. + Should be a prime. */ +enum { event_sync_partitions = 31 }; + +/* Events are partitioned by address to avoid lock contention. */ +static struct sync_array_s { + gpr_mu mu; + gpr_cv cv; +} sync_array[event_sync_partitions]; + +/* This routine is executed once on first use, via event_once */ +static gpr_once event_once = GPR_ONCE_INIT; +static void event_initialize(void) { + int i; + for (i = 0; i != event_sync_partitions; i++) { + gpr_mu_init(&sync_array[i].mu); + gpr_cv_init(&sync_array[i].cv); + } +} + +/* Hash ev into an element of sync_array[]. */ +static struct sync_array_s* hash(gpr_event* ev) { + return &sync_array[reinterpret_cast(ev) % event_sync_partitions]; +} + +void gpr_event_init(gpr_event* ev) { + gpr_once_init(&event_once, &event_initialize); + ev->state = 0; +} + +void gpr_event_set(gpr_event* ev, void* value) { + struct sync_array_s* s = hash(ev); + gpr_mu_lock(&s->mu); + GPR_ASSERT(gpr_atm_acq_load(&ev->state) == 0); + gpr_atm_rel_store(&ev->state, (gpr_atm)value); + gpr_cv_broadcast(&s->cv); + gpr_mu_unlock(&s->mu); + GPR_ASSERT(value != nullptr); +} + +void* gpr_event_get(gpr_event* ev) { + return reinterpret_cast(gpr_atm_acq_load(&ev->state)); +} + +void* gpr_event_wait(gpr_event* ev, gpr_timespec abs_deadline) { + void* result = reinterpret_cast(gpr_atm_acq_load(&ev->state)); + if (result == nullptr) { + struct sync_array_s* s = hash(ev); + gpr_mu_lock(&s->mu); + do { + result = reinterpret_cast(gpr_atm_acq_load(&ev->state)); + } while (result == nullptr && !gpr_cv_wait(&s->cv, &s->mu, abs_deadline)); + gpr_mu_unlock(&s->mu); + } + return result; +} + +void gpr_ref_init(gpr_refcount* r, int n) { gpr_atm_rel_store(&r->count, n); } + +void gpr_ref(gpr_refcount* r) { gpr_atm_no_barrier_fetch_add(&r->count, 1); } + +void gpr_ref_non_zero(gpr_refcount* r) { +#ifndef NDEBUG + gpr_atm prior = gpr_atm_no_barrier_fetch_add(&r->count, 1); + assert(prior > 0); +#else + gpr_ref(r); +#endif +} + +void gpr_refn(gpr_refcount* r, int n) { + gpr_atm_no_barrier_fetch_add(&r->count, n); +} + +int gpr_unref(gpr_refcount* r) { + gpr_atm prior = gpr_atm_full_fetch_add(&r->count, -1); + GPR_ASSERT(prior > 0); + return prior == 1; +} + +int gpr_ref_is_unique(gpr_refcount* r) { + return gpr_atm_acq_load(&r->count) == 1; +} + +void gpr_stats_init(gpr_stats_counter* c, intptr_t n) { + gpr_atm_rel_store(&c->value, n); +} + +void gpr_stats_inc(gpr_stats_counter* c, intptr_t inc) { + gpr_atm_no_barrier_fetch_add(&c->value, inc); +} + +intptr_t gpr_stats_read(const gpr_stats_counter* c) { + /* don't need acquire-load, but we have no no-barrier load yet */ + return gpr_atm_acq_load(&c->value); +} diff --git a/src/core/lib/gpr/sync_abseil.cc b/src/core/lib/gpr/sync_abseil.cc new file mode 100644 index 00000000..16333366 --- /dev/null +++ 
b/src/core/lib/gpr/sync_abseil.cc @@ -0,0 +1,114 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#if defined(GPR_ABSEIL_SYNC) && !defined(GPR_CUSTOM_SYNC) + +#include +#include + +#include "absl/base/call_once.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" + +#include +#include +#include +#include + +#include "src/core/lib/profiling/timers.h" + +#ifdef GPR_LOW_LEVEL_COUNTERS +gpr_atm gpr_mu_locks = 0; +gpr_atm gpr_counter_atm_cas = 0; +gpr_atm gpr_counter_atm_add = 0; +#endif + +void gpr_mu_init(gpr_mu* mu) { + static_assert(sizeof(gpr_mu) == sizeof(absl::Mutex), + "gpr_mu and Mutex must be the same size"); + new (mu) absl::Mutex; +} + +void gpr_mu_destroy(gpr_mu* mu) { + reinterpret_cast(mu)->~Mutex(); +} + +void gpr_mu_lock(gpr_mu* mu) ABSL_NO_THREAD_SAFETY_ANALYSIS { + GPR_TIMER_SCOPE("gpr_mu_lock", 0); + reinterpret_cast(mu)->Lock(); +} + +void gpr_mu_unlock(gpr_mu* mu) ABSL_NO_THREAD_SAFETY_ANALYSIS { + GPR_TIMER_SCOPE("gpr_mu_unlock", 0); + reinterpret_cast(mu)->Unlock(); +} + +int gpr_mu_trylock(gpr_mu* mu) { + GPR_TIMER_SCOPE("gpr_mu_trylock", 0); + return reinterpret_cast(mu)->TryLock(); +} + +/*----------------------------------------*/ + +void gpr_cv_init(gpr_cv* cv) { + static_assert(sizeof(gpr_cv) == sizeof(absl::CondVar), + "gpr_cv and CondVar must be the same size"); + new (cv) absl::CondVar; +} + +void gpr_cv_destroy(gpr_cv* cv) { + reinterpret_cast(cv)->~CondVar(); +} + +int gpr_cv_wait(gpr_cv* cv, gpr_mu* mu, gpr_timespec abs_deadline) { + GPR_TIMER_SCOPE("gpr_cv_wait", 0); + if (gpr_time_cmp(abs_deadline, gpr_inf_future(abs_deadline.clock_type)) == + 0) { + reinterpret_cast(cv)->Wait( + reinterpret_cast(mu)); + return 0; + } + abs_deadline = gpr_convert_clock_type(abs_deadline, GPR_CLOCK_REALTIME); + timespec ts = {static_cast(abs_deadline.tv_sec), + static_cast(abs_deadline.tv_nsec)}; + return reinterpret_cast(cv)->WaitWithDeadline( + reinterpret_cast(mu), absl::TimeFromTimespec(ts)); +} + +void gpr_cv_signal(gpr_cv* cv) { + GPR_TIMER_MARK("gpr_cv_signal", 0); + reinterpret_cast(cv)->Signal(); +} + +void gpr_cv_broadcast(gpr_cv* cv) { + GPR_TIMER_MARK("gpr_cv_broadcast", 0); + reinterpret_cast(cv)->SignalAll(); +} + +/*----------------------------------------*/ + +void gpr_once_init(gpr_once* once, void (*init_function)(void)) { + static_assert(sizeof(gpr_once) == sizeof(absl::once_flag), + "gpr_once and absl::once_flag must be the same size"); + absl::call_once(*reinterpret_cast(once), init_function); +} + +#endif /* defined(GPR_ABSEIL_SYNC) && !defined(GPR_CUSTOM_SYNC) */ diff --git a/src/core/lib/gpr/sync_posix.cc b/src/core/lib/gpr/sync_posix.cc new file mode 100644 index 00000000..7ed5403a --- /dev/null +++ b/src/core/lib/gpr/sync_posix.cc @@ -0,0 +1,170 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
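Whichever sync backend above ends up compiled in (abseil, pthread, or Win32 critical sections), callers see the same gpr_mu/gpr_cv surface. A minimal sketch of a deadline wait against the public <grpc/support/sync.h> API; note that gpr_cv_wait() returns non-zero on timeout:

#include <stdio.h>

#include <grpc/support/sync.h>
#include <grpc/support/time.h>

int main() {
  gpr_mu mu;
  gpr_cv cv;
  gpr_mu_init(&mu);
  gpr_cv_init(&cv);

  gpr_mu_lock(&mu);
  // Absolute deadline 50ms from now; nothing signals cv, so this times out.
  gpr_timespec deadline = gpr_time_add(
      gpr_now(GPR_CLOCK_MONOTONIC), gpr_time_from_millis(50, GPR_TIMESPAN));
  int timed_out = gpr_cv_wait(&cv, &mu, deadline);
  gpr_mu_unlock(&mu);
  printf("timed out: %d\n", timed_out);

  gpr_cv_destroy(&cv);
  gpr_mu_destroy(&mu);
  return 0;
}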
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#if defined(GPR_POSIX_SYNC) && !defined(GPR_ABSEIL_SYNC) && \ + !defined(GPR_CUSTOM_SYNC) + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/profiling/timers.h" + +#ifdef GPR_LOW_LEVEL_COUNTERS +gpr_atm gpr_mu_locks = 0; +gpr_atm gpr_counter_atm_cas = 0; +gpr_atm gpr_counter_atm_add = 0; +#endif + +void gpr_mu_init(gpr_mu* mu) { +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_mutex_init(&mu->mutex, nullptr) == 0); + mu->leak_checker = static_cast(malloc(sizeof(*mu->leak_checker))); + GPR_ASSERT(mu->leak_checker != nullptr); +#else + GPR_ASSERT(pthread_mutex_init(mu, nullptr) == 0); +#endif +} + +void gpr_mu_destroy(gpr_mu* mu) { +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_mutex_destroy(&mu->mutex) == 0); + free(mu->leak_checker); +#else + GPR_ASSERT(pthread_mutex_destroy(mu) == 0); +#endif +} + +void gpr_mu_lock(gpr_mu* mu) { +#ifdef GPR_LOW_LEVEL_COUNTERS + GPR_ATM_INC_COUNTER(gpr_mu_locks); +#endif + GPR_TIMER_SCOPE("gpr_mu_lock", 0); +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_mutex_lock(&mu->mutex) == 0); +#else + GPR_ASSERT(pthread_mutex_lock(mu) == 0); +#endif +} + +void gpr_mu_unlock(gpr_mu* mu) { + GPR_TIMER_SCOPE("gpr_mu_unlock", 0); +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_mutex_unlock(&mu->mutex) == 0); +#else + GPR_ASSERT(pthread_mutex_unlock(mu) == 0); +#endif +} + +int gpr_mu_trylock(gpr_mu* mu) { + GPR_TIMER_SCOPE("gpr_mu_trylock", 0); + int err = 0; +#ifdef GRPC_ASAN_ENABLED + err = pthread_mutex_trylock(&mu->mutex); +#else + err = pthread_mutex_trylock(mu); +#endif + GPR_ASSERT(err == 0 || err == EBUSY); + return err == 0; +} + +/*----------------------------------------*/ + +void gpr_cv_init(gpr_cv* cv) { + pthread_condattr_t attr; + GPR_ASSERT(pthread_condattr_init(&attr) == 0); +#if GPR_LINUX + GPR_ASSERT(pthread_condattr_setclock(&attr, CLOCK_MONOTONIC) == 0); +#endif // GPR_LINUX + +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_cond_init(&cv->cond_var, &attr) == 0); + cv->leak_checker = static_cast(malloc(sizeof(*cv->leak_checker))); + GPR_ASSERT(cv->leak_checker != nullptr); +#else + GPR_ASSERT(pthread_cond_init(cv, &attr) == 0); +#endif +} + +void gpr_cv_destroy(gpr_cv* cv) { +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_cond_destroy(&cv->cond_var) == 0); + free(cv->leak_checker); +#else + GPR_ASSERT(pthread_cond_destroy(cv) == 0); +#endif +} + +int gpr_cv_wait(gpr_cv* cv, gpr_mu* mu, gpr_timespec abs_deadline) { + int err = 0; + if (gpr_time_cmp(abs_deadline, gpr_inf_future(abs_deadline.clock_type)) == + 0) { +#ifdef GRPC_ASAN_ENABLED + err = pthread_cond_wait(&cv->cond_var, &mu->mutex); +#else + err = pthread_cond_wait(cv, mu); +#endif + } else { + struct timespec abs_deadline_ts; +#if GPR_LINUX + abs_deadline = gpr_convert_clock_type(abs_deadline, GPR_CLOCK_MONOTONIC); +#else + abs_deadline = gpr_convert_clock_type(abs_deadline, GPR_CLOCK_REALTIME); +#endif // GPR_LINUX + abs_deadline_ts.tv_sec = static_cast(abs_deadline.tv_sec); + abs_deadline_ts.tv_nsec = abs_deadline.tv_nsec; +#ifdef GRPC_ASAN_ENABLED + err = pthread_cond_timedwait(&cv->cond_var, &mu->mutex, 
&abs_deadline_ts); +#else + err = pthread_cond_timedwait(cv, mu, &abs_deadline_ts); +#endif + } + GPR_ASSERT(err == 0 || err == ETIMEDOUT || err == EAGAIN); + return err == ETIMEDOUT; +} + +void gpr_cv_signal(gpr_cv* cv) { +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_cond_signal(&cv->cond_var) == 0); +#else + GPR_ASSERT(pthread_cond_signal(cv) == 0); +#endif +} + +void gpr_cv_broadcast(gpr_cv* cv) { +#ifdef GRPC_ASAN_ENABLED + GPR_ASSERT(pthread_cond_broadcast(&cv->cond_var) == 0); +#else + GPR_ASSERT(pthread_cond_broadcast(cv) == 0); +#endif +} + +/*----------------------------------------*/ + +void gpr_once_init(gpr_once* once, void (*init_function)(void)) { + GPR_ASSERT(pthread_once(once, init_function) == 0); +} + +#endif /* defined(GPR_POSIX_SYNC) && !defined(GPR_ABSEIL_SYNC) && \ + !defined(GPR_CUSTOM_SYNC) */ diff --git a/src/core/lib/gpr/sync_windows.cc b/src/core/lib/gpr/sync_windows.cc new file mode 100644 index 00000000..a6173c72 --- /dev/null +++ b/src/core/lib/gpr/sync_windows.cc @@ -0,0 +1,120 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Win32 code for gpr synchronization support. */ + +#include + +#if defined(GPR_WINDOWS) && !defined(GPR_ABSEIL_SYNC) && \ + !defined(GPR_CUSTOM_SYNC) + +#include +#include +#include + +void gpr_mu_init(gpr_mu* mu) { + InitializeCriticalSection(&mu->cs); + mu->locked = 0; +} + +void gpr_mu_destroy(gpr_mu* mu) { DeleteCriticalSection(&mu->cs); } + +void gpr_mu_lock(gpr_mu* mu) { + EnterCriticalSection(&mu->cs); + GPR_ASSERT(!mu->locked); + mu->locked = 1; +} + +void gpr_mu_unlock(gpr_mu* mu) { + mu->locked = 0; + LeaveCriticalSection(&mu->cs); +} + +int gpr_mu_trylock(gpr_mu* mu) { + int result = TryEnterCriticalSection(&mu->cs); + if (result) { + if (mu->locked) { /* This thread already holds the lock. */ + LeaveCriticalSection(&mu->cs); /* Decrement lock count. */ + result = 0; /* Indicate failure */ + } + mu->locked = 1; + } + return result; +} + +/*----------------------------------------*/ + +void gpr_cv_init(gpr_cv* cv) { InitializeConditionVariable(cv); } + +void gpr_cv_destroy(gpr_cv* cv) { + /* Condition variables don't need destruction in Win32. 
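+     (Win32 provides no delete call for CONDITION_VARIABLE, so there is
+     nothing to release here.)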
*/ +} + +int gpr_cv_wait(gpr_cv* cv, gpr_mu* mu, gpr_timespec abs_deadline) { + int timeout = 0; + DWORD timeout_max_ms; + mu->locked = 0; + if (gpr_time_cmp(abs_deadline, gpr_inf_future(abs_deadline.clock_type)) == + 0) { + SleepConditionVariableCS(cv, &mu->cs, INFINITE); + } else { + abs_deadline = gpr_convert_clock_type(abs_deadline, GPR_CLOCK_REALTIME); + gpr_timespec now = gpr_now(abs_deadline.clock_type); + int64_t now_ms = (int64_t)now.tv_sec * 1000 + now.tv_nsec / 1000000; + int64_t deadline_ms = + (int64_t)abs_deadline.tv_sec * 1000 + abs_deadline.tv_nsec / 1000000; + if (now_ms >= deadline_ms) { + timeout = 1; + } else { + if ((deadline_ms - now_ms) >= INFINITE) { + timeout_max_ms = INFINITE - 1; + } else { + timeout_max_ms = (DWORD)(deadline_ms - now_ms); + } + timeout = (SleepConditionVariableCS(cv, &mu->cs, timeout_max_ms) == 0 && + GetLastError() == ERROR_TIMEOUT); + } + } + mu->locked = 1; + return timeout; +} + +void gpr_cv_signal(gpr_cv* cv) { WakeConditionVariable(cv); } + +void gpr_cv_broadcast(gpr_cv* cv) { WakeAllConditionVariable(cv); } + +/*----------------------------------------*/ + +static void* phony; +struct run_once_func_arg { + void (*init_function)(void); +}; +static BOOL CALLBACK run_once_func(gpr_once* once, void* v, void** pv) { + struct run_once_func_arg* arg = (struct run_once_func_arg*)v; + (*arg->init_function)(); + return 1; +} + +void gpr_once_init(gpr_once* once, void (*init_function)(void)) { + struct run_once_func_arg arg; + arg.init_function = init_function; + InitOnceExecuteOnce(once, run_once_func, &arg, &phony); +} + +#endif /* defined(GPR_WINDOWS) && !defined(GPR_ABSEIL_SYNC) && \ + !defined(GPR_CUSTOM_SYNC) */ diff --git a/src/core/lib/gpr/time.cc b/src/core/lib/gpr/time.cc new file mode 100644 index 00000000..d796f414 --- /dev/null +++ b/src/core/lib/gpr/time.cc @@ -0,0 +1,264 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Generic implementation of time calls. */ + +#include + +#include +#include +#include + +#include +#include + +int gpr_time_cmp(gpr_timespec a, gpr_timespec b) { + int cmp = (a.tv_sec > b.tv_sec) - (a.tv_sec < b.tv_sec); + GPR_ASSERT(a.clock_type == b.clock_type); + if (cmp == 0 && a.tv_sec != INT64_MAX && a.tv_sec != INT64_MIN) { + cmp = (a.tv_nsec > b.tv_nsec) - (a.tv_nsec < b.tv_nsec); + } + return cmp; +} + +gpr_timespec gpr_time_min(gpr_timespec a, gpr_timespec b) { + return gpr_time_cmp(a, b) < 0 ? a : b; +} + +gpr_timespec gpr_time_max(gpr_timespec a, gpr_timespec b) { + return gpr_time_cmp(a, b) > 0 ? 
a : b; +} + +gpr_timespec gpr_time_0(gpr_clock_type type) { + gpr_timespec out; + out.tv_sec = 0; + out.tv_nsec = 0; + out.clock_type = type; + return out; +} + +gpr_timespec gpr_inf_future(gpr_clock_type type) { + gpr_timespec out; + out.tv_sec = INT64_MAX; + out.tv_nsec = 0; + out.clock_type = type; + return out; +} + +gpr_timespec gpr_inf_past(gpr_clock_type type) { + gpr_timespec out; + out.tv_sec = INT64_MIN; + out.tv_nsec = 0; + out.clock_type = type; + return out; +} + +static gpr_timespec to_seconds_from_sub_second_time(int64_t time_in_units, + int64_t units_per_sec, + gpr_clock_type type) { + gpr_timespec out; + if (time_in_units == INT64_MAX) { + out = gpr_inf_future(type); + } else if (time_in_units == INT64_MIN) { + out = gpr_inf_past(type); + } else { + if (time_in_units >= 0) { + out.tv_sec = time_in_units / units_per_sec; + } else { + out.tv_sec = (-((units_per_sec - 1) - (time_in_units + units_per_sec)) / + units_per_sec) - + 1; + } + out.tv_nsec = + static_cast((time_in_units - out.tv_sec * units_per_sec) * + GPR_NS_PER_SEC / units_per_sec); + out.clock_type = type; + } + return out; +} + +static gpr_timespec to_seconds_from_above_second_time(int64_t time_in_units, + int64_t secs_per_unit, + gpr_clock_type type) { + gpr_timespec out; + if (time_in_units >= INT64_MAX / secs_per_unit) { + out = gpr_inf_future(type); + } else if (time_in_units <= INT64_MIN / secs_per_unit) { + out = gpr_inf_past(type); + } else { + out.tv_sec = time_in_units * secs_per_unit; + out.tv_nsec = 0; + out.clock_type = type; + } + return out; +} + +gpr_timespec gpr_time_from_nanos(int64_t ns, gpr_clock_type clock_type) { + return to_seconds_from_sub_second_time(ns, GPR_NS_PER_SEC, clock_type); +} + +gpr_timespec gpr_time_from_micros(int64_t us, gpr_clock_type clock_type) { + return to_seconds_from_sub_second_time(us, GPR_US_PER_SEC, clock_type); +} + +gpr_timespec gpr_time_from_millis(int64_t ms, gpr_clock_type clock_type) { + return to_seconds_from_sub_second_time(ms, GPR_MS_PER_SEC, clock_type); +} + +gpr_timespec gpr_time_from_seconds(int64_t s, gpr_clock_type clock_type) { + return to_seconds_from_sub_second_time(s, 1, clock_type); +} + +gpr_timespec gpr_time_from_minutes(int64_t m, gpr_clock_type clock_type) { + return to_seconds_from_above_second_time(m, 60, clock_type); +} + +gpr_timespec gpr_time_from_hours(int64_t h, gpr_clock_type clock_type) { + return to_seconds_from_above_second_time(h, 3600, clock_type); +} + +gpr_timespec gpr_time_add(gpr_timespec a, gpr_timespec b) { + gpr_timespec sum; + int64_t inc = 0; + GPR_ASSERT(b.clock_type == GPR_TIMESPAN); + // tv_nsec in a timespan is always +ve. -ve timespan is represented as (-ve + // tv_sec, +ve tv_nsec). 
For example, timespan = -2.5 seconds is represented + // as {-3, 5e8, GPR_TIMESPAN} + GPR_ASSERT(b.tv_nsec >= 0); + sum.clock_type = a.clock_type; + sum.tv_nsec = a.tv_nsec + b.tv_nsec; + if (sum.tv_nsec >= GPR_NS_PER_SEC) { + sum.tv_nsec -= GPR_NS_PER_SEC; + inc++; + } + if (a.tv_sec == INT64_MAX || a.tv_sec == INT64_MIN) { + sum = a; + } else if (b.tv_sec == INT64_MAX || + (b.tv_sec >= 0 && a.tv_sec >= INT64_MAX - b.tv_sec)) { + sum = gpr_inf_future(sum.clock_type); + } else if (b.tv_sec == INT64_MIN || + (b.tv_sec <= 0 && a.tv_sec <= INT64_MIN - b.tv_sec)) { + sum = gpr_inf_past(sum.clock_type); + } else { + sum.tv_sec = a.tv_sec + b.tv_sec; + if (inc != 0 && sum.tv_sec == INT64_MAX - 1) { + sum = gpr_inf_future(sum.clock_type); + } else { + sum.tv_sec += inc; + } + } + return sum; +} + +gpr_timespec gpr_time_sub(gpr_timespec a, gpr_timespec b) { + gpr_timespec diff; + int64_t dec = 0; + if (b.clock_type == GPR_TIMESPAN) { + diff.clock_type = a.clock_type; + // tv_nsec in a timespan is always +ve. -ve timespan is represented as (-ve + // tv_sec, +ve tv_nsec). For example, timespan = -2.5 seconds is represented + // as {-3, 5e8, GPR_TIMESPAN} + GPR_ASSERT(b.tv_nsec >= 0); + } else { + GPR_ASSERT(a.clock_type == b.clock_type); + diff.clock_type = GPR_TIMESPAN; + } + diff.tv_nsec = a.tv_nsec - b.tv_nsec; + if (diff.tv_nsec < 0) { + diff.tv_nsec += GPR_NS_PER_SEC; + dec++; + } + if (a.tv_sec == INT64_MAX || a.tv_sec == INT64_MIN) { + diff = a; + } else if (b.tv_sec == INT64_MIN || + (b.tv_sec <= 0 && a.tv_sec >= INT64_MAX + b.tv_sec)) { + diff = gpr_inf_future(GPR_CLOCK_REALTIME); + } else if (b.tv_sec == INT64_MAX || + (b.tv_sec >= 0 && a.tv_sec <= INT64_MIN + b.tv_sec)) { + diff = gpr_inf_past(GPR_CLOCK_REALTIME); + } else { + diff.tv_sec = a.tv_sec - b.tv_sec; + if (dec != 0 && diff.tv_sec == INT64_MIN + 1) { + diff = gpr_inf_past(GPR_CLOCK_REALTIME); + } else { + diff.tv_sec -= dec; + } + } + return diff; +} + +int gpr_time_similar(gpr_timespec a, gpr_timespec b, gpr_timespec threshold) { + int cmp_ab; + + GPR_ASSERT(a.clock_type == b.clock_type); + GPR_ASSERT(threshold.clock_type == GPR_TIMESPAN); + + cmp_ab = gpr_time_cmp(a, b); + if (cmp_ab == 0) return 1; + if (cmp_ab < 0) { + return gpr_time_cmp(gpr_time_sub(b, a), threshold) <= 0; + } else { + return gpr_time_cmp(gpr_time_sub(a, b), threshold) <= 0; + } +} + +int32_t gpr_time_to_millis(gpr_timespec t) { + if (t.tv_sec >= 2147483) { + if (t.tv_sec == 2147483 && t.tv_nsec < 648 * GPR_NS_PER_MS) { + return 2147483 * GPR_MS_PER_SEC + t.tv_nsec / GPR_NS_PER_MS; + } + return 2147483647; + } else if (t.tv_sec <= -2147483) { + /* TODO(ctiller): correct handling here (it's so far in the past do we + care?) 
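+       (The 2147483 / 648 constants above are just INT32_MAX expressed in
+       whole seconds plus milliseconds: 2147483647 ms, roughly 24.8 days.)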
*/ + return -2147483647; + } else { + return static_cast(t.tv_sec * GPR_MS_PER_SEC + + t.tv_nsec / GPR_NS_PER_MS); + } +} + +double gpr_timespec_to_micros(gpr_timespec t) { + return static_cast(t.tv_sec) * GPR_US_PER_SEC + t.tv_nsec * 1e-3; +} + +gpr_timespec gpr_convert_clock_type(gpr_timespec t, gpr_clock_type clock_type) { + if (t.clock_type == clock_type) { + return t; + } + + if (t.tv_sec == INT64_MAX || t.tv_sec == INT64_MIN) { + t.clock_type = clock_type; + return t; + } + + if (clock_type == GPR_TIMESPAN) { + return gpr_time_sub(t, gpr_now(t.clock_type)); + } + + if (t.clock_type == GPR_TIMESPAN) { + return gpr_time_add(gpr_now(clock_type), t); + } + + // If the given input hits this code, the same result is not guaranteed for + // the same input because it relies on `gpr_now` to calculate the difference + // between two different clocks. Please be careful when you want to use this + // function in unit tests. (e.g. https://github.com/grpc/grpc/pull/22655) + return gpr_time_add(gpr_now(clock_type), + gpr_time_sub(t, gpr_now(t.clock_type))); +} diff --git a/src/core/lib/gpr/time_posix.cc b/src/core/lib/gpr/time_posix.cc new file mode 100644 index 00000000..0a9e1158 --- /dev/null +++ b/src/core/lib/gpr/time_posix.cc @@ -0,0 +1,186 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gpr/time_precise.h" + +#ifdef GPR_POSIX_TIME + +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#include +#include + +static struct timespec timespec_from_gpr(gpr_timespec gts) { + struct timespec rv; + if (sizeof(time_t) < sizeof(int64_t)) { + /* fine to assert, as this is only used in gpr_sleep_until */ + GPR_ASSERT(gts.tv_sec <= INT32_MAX && gts.tv_sec >= INT32_MIN); + } + rv.tv_sec = static_cast(gts.tv_sec); + rv.tv_nsec = gts.tv_nsec; + return rv; +} + +#if _POSIX_TIMERS > 0 || defined(__OpenBSD__) +static gpr_timespec gpr_from_timespec(struct timespec ts, + gpr_clock_type clock_type) { + /* + * timespec.tv_sec can have smaller size than gpr_timespec.tv_sec, + * but we are only using this function to implement gpr_now + * so there's no need to handle "infinity" values. 
+ */ + gpr_timespec rv; + rv.tv_sec = ts.tv_sec; + rv.tv_nsec = static_cast(ts.tv_nsec); + rv.clock_type = clock_type; + return rv; +} + +/** maps gpr_clock_type --> clockid_t for clock_gettime */ +static const clockid_t clockid_for_gpr_clock[] = {CLOCK_MONOTONIC, + CLOCK_REALTIME}; + +void gpr_time_init(void) { gpr_precise_clock_init(); } + +static gpr_timespec now_impl(gpr_clock_type clock_type) { + struct timespec now; + GPR_ASSERT(clock_type != GPR_TIMESPAN); + if (clock_type == GPR_CLOCK_PRECISE) { + gpr_timespec ret; + gpr_precise_clock_now(&ret); + return ret; + } else { +#if defined(GPR_BACKWARDS_COMPATIBILITY_MODE) && defined(__linux__) + /* avoid ABI problems by invoking syscalls directly */ + syscall(SYS_clock_gettime, clockid_for_gpr_clock[clock_type], &now); +#else + clock_gettime(clockid_for_gpr_clock[clock_type], &now); +#endif + return gpr_from_timespec(now, clock_type); + } +} +#else +/* For some reason Apple's OSes haven't implemented clock_gettime. */ + +#include +#include +#include + +static double g_time_scale; +static uint64_t g_time_start; + +void gpr_time_init(void) { + mach_timebase_info_data_t tb = {0, 1}; + gpr_precise_clock_init(); + mach_timebase_info(&tb); + g_time_scale = tb.numer; + g_time_scale /= tb.denom; + g_time_start = mach_absolute_time(); +} + +static gpr_timespec now_impl(gpr_clock_type clock) { + gpr_timespec now; + struct timeval now_tv; + double now_dbl; + + now.clock_type = clock; + switch (clock) { + case GPR_CLOCK_REALTIME: + // gettimeofday(...) function may return with a value whose tv_usec is + // greater than 1e6 on iOS The case is resolved with the guard at end of + // this function. + gettimeofday(&now_tv, nullptr); + now.tv_sec = now_tv.tv_sec; + now.tv_nsec = now_tv.tv_usec * 1000; + break; + case GPR_CLOCK_MONOTONIC: + now_dbl = ((double)(mach_absolute_time() - g_time_start)) * g_time_scale; + now.tv_sec = (int64_t)(now_dbl * 1e-9); + now.tv_nsec = (int32_t)(now_dbl - ((double)now.tv_sec) * 1e9); + break; + case GPR_CLOCK_PRECISE: + gpr_precise_clock_now(&now); + break; + case GPR_TIMESPAN: + abort(); + } + + // Guard the tv_nsec field in valid range for all clock types + while (GPR_UNLIKELY(now.tv_nsec >= 1e9)) { + now.tv_sec++; + now.tv_nsec -= 1e9; + } + while (GPR_UNLIKELY(now.tv_nsec < 0)) { + now.tv_sec--; + now.tv_nsec += 1e9; + } + + return now; +} +#endif + +gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type) = now_impl; + +#ifdef GPR_LOW_LEVEL_COUNTERS +gpr_atm gpr_now_call_count; +#endif +gpr_timespec gpr_now(gpr_clock_type clock_type) { +#ifdef GPR_LOW_LEVEL_COUNTERS + __atomic_fetch_add(&gpr_now_call_count, 1, __ATOMIC_RELAXED); +#endif + // validate clock type + GPR_ASSERT(clock_type == GPR_CLOCK_MONOTONIC || + clock_type == GPR_CLOCK_REALTIME || + clock_type == GPR_CLOCK_PRECISE); + gpr_timespec ts = gpr_now_impl(clock_type); + // tv_nsecs must be in the range [0, 1e9). + GPR_ASSERT(ts.tv_nsec >= 0 && ts.tv_nsec < 1e9); + return ts; +} + +void gpr_sleep_until(gpr_timespec until) { + gpr_timespec now; + gpr_timespec delta; + struct timespec delta_ts; + int ns_result; + + for (;;) { + /* We could simplify by using clock_nanosleep instead, but it might be + * slightly less portable. 
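+     * (The enclosing for(;;) loop also covers nanosleep() returning early,
+     * e.g. when interrupted by a signal: the remaining delta is recomputed
+     * from gpr_now() and the sleep is retried.)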
*/ + now = gpr_now(until.clock_type); + if (gpr_time_cmp(until, now) <= 0) { + return; + } + + delta = gpr_time_sub(until, now); + delta_ts = timespec_from_gpr(delta); + ns_result = nanosleep(&delta_ts, nullptr); + if (ns_result == 0) { + break; + } + } +} + +#endif /* GPR_POSIX_TIME */ diff --git a/src/core/lib/gpr/time_precise.cc b/src/core/lib/gpr/time_precise.cc new file mode 100644 index 00000000..c66b981d --- /dev/null +++ b/src/core/lib/gpr/time_precise.cc @@ -0,0 +1,168 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#if GPR_LINUX +#include +#include +#endif + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/time_precise.h" + +#ifndef GPR_CYCLE_COUNTER_CUSTOM +#if GPR_CYCLE_COUNTER_RDTSC_32 || GPR_CYCLE_COUNTER_RDTSC_64 +#if GPR_LINUX +static bool read_freq_from_kernel(double* freq) { + // Google production kernel export the frequency for us in kHz. + int fd = open("/sys/devices/system/cpu/cpu0/tsc_freq_khz", O_RDONLY); + if (fd == -1) { + return false; + } + char line[1024] = {}; + char* err; + bool ret = false; + int len = read(fd, line, sizeof(line) - 1); + if (len > 0) { + const long val = strtol(line, &err, 10); + if (line[0] != '\0' && (*err == '\n' || *err == '\0')) { + *freq = val * 1e3; // Value is kHz. + ret = true; + } + } + close(fd); + return ret; +} +#endif /* GPR_LINUX */ + +static double cycles_per_second = 0; +static gpr_cycle_counter start_cycle; + +static bool is_fake_clock() { + gpr_timespec start = gpr_now(GPR_CLOCK_MONOTONIC); + int64_t sum = 0; + for (int i = 0; i < 8; ++i) { + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec delta = gpr_time_sub(now, start); + sum += delta.tv_sec * GPR_NS_PER_SEC + delta.tv_nsec; + } + // If the clock doesn't move even a nano after 8 tries, it's a fake one. + return sum == 0; +} + +void gpr_precise_clock_init(void) { + gpr_log(GPR_DEBUG, "Calibrating timers"); + +#if GPR_LINUX + if (read_freq_from_kernel(&cycles_per_second)) { + start_cycle = gpr_get_cycle_counter(); + return; + } +#endif /* GPR_LINUX */ + + if (is_fake_clock()) { + cycles_per_second = 1; + start_cycle = 0; + return; + } + // Start from a loop of 1ms, and gradually increase the loop duration + // until we either converge or we have passed 255ms (1ms+2ms+...+128ms). + int64_t measurement_ns = GPR_NS_PER_MS; + double last_freq = -1; + bool converged = false; + for (int i = 0; i < 8 && !converged; ++i, measurement_ns *= 2) { + start_cycle = gpr_get_cycle_counter(); + int64_t loop_ns; + gpr_timespec start = gpr_now(GPR_CLOCK_MONOTONIC); + do { + // TODO(soheil): Maybe sleep instead of busy polling. + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec delta = gpr_time_sub(now, start); + loop_ns = delta.tv_sec * GPR_NS_PER_SEC + delta.tv_nsec; + } while (loop_ns < measurement_ns); + gpr_cycle_counter end_cycle = gpr_get_cycle_counter(); + // Frequency should be in Hz. 
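+    // Illustrative numbers (not measured): if end_cycle - start_cycle is
+    // 3,000,000 while the busy-wait above covered loop_ns = 1,000,000 ns,
+    // then freq = 3e6 / 1e6 * 1e9 = 3e9 Hz, i.e. a 3 GHz cycle counter.
+    // Calibration stops once two consecutive estimates agree within 1%.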
+ const double freq = + static_cast(end_cycle - start_cycle) / loop_ns * GPR_NS_PER_SEC; + converged = + last_freq != -1 && (freq * 0.99 < last_freq && last_freq < freq * 1.01); + last_freq = freq; + } + cycles_per_second = last_freq; + gpr_log(GPR_DEBUG, "... cycles_per_second = %f\n", cycles_per_second); +} + +gpr_timespec gpr_cycle_counter_to_time(gpr_cycle_counter cycles) { + const double secs = + static_cast(cycles - start_cycle) / cycles_per_second; + gpr_timespec ts; + ts.tv_sec = static_cast(secs); + ts.tv_nsec = static_cast(GPR_NS_PER_SEC * + (secs - static_cast(ts.tv_sec))); + ts.clock_type = GPR_CLOCK_PRECISE; + return ts; +} + +gpr_timespec gpr_cycle_counter_sub(gpr_cycle_counter a, gpr_cycle_counter b) { + const double secs = static_cast(a - b) / cycles_per_second; + gpr_timespec ts; + ts.tv_sec = static_cast(secs); + ts.tv_nsec = static_cast(GPR_NS_PER_SEC * + (secs - static_cast(ts.tv_sec))); + ts.clock_type = GPR_TIMESPAN; + return ts; +} + +void gpr_precise_clock_now(gpr_timespec* clk) { + int64_t counter = gpr_get_cycle_counter(); + *clk = gpr_cycle_counter_to_time(counter); +} +#elif GPR_CYCLE_COUNTER_FALLBACK +void gpr_precise_clock_init(void) {} + +gpr_cycle_counter gpr_get_cycle_counter() { + gpr_timespec ts = gpr_now(GPR_CLOCK_REALTIME); + return gpr_timespec_to_micros(ts); +} + +gpr_timespec gpr_cycle_counter_to_time(gpr_cycle_counter cycles) { + gpr_timespec ts; + ts.tv_sec = static_cast(cycles / GPR_US_PER_SEC); + ts.tv_nsec = static_cast((cycles - ts.tv_sec * GPR_US_PER_SEC) * + GPR_NS_PER_US); + ts.clock_type = GPR_CLOCK_PRECISE; + return ts; +} + +void gpr_precise_clock_now(gpr_timespec* clk) { + *clk = gpr_now(GPR_CLOCK_REALTIME); + clk->clock_type = GPR_CLOCK_PRECISE; +} + +gpr_timespec gpr_cycle_counter_sub(gpr_cycle_counter a, gpr_cycle_counter b) { + return gpr_time_sub(gpr_cycle_counter_to_time(a), + gpr_cycle_counter_to_time(b)); +} +#endif /* GPR_CYCLE_COUNTER_FALLBACK */ +#endif /* !GPR_CYCLE_COUNTER_CUSTOM */ diff --git a/src/core/lib/gpr/time_windows.cc b/src/core/lib/gpr/time_windows.cc new file mode 100644 index 00000000..39bca1b0 --- /dev/null +++ b/src/core/lib/gpr/time_windows.cc @@ -0,0 +1,99 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Win32 code for gpr time support. 
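+ *
+ * GPR_CLOCK_REALTIME is read via _ftime_s(); GPR_CLOCK_MONOTONIC and
+ * GPR_CLOCK_PRECISE are derived from QueryPerformanceCounter() ticks,
+ * converted to seconds with the QueryPerformanceFrequency() value captured
+ * once in gpr_time_init().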
*/ + +#include + +#ifdef GPR_WINDOWS_TIME + +#include +#include +#include + +#include +#include + +#include "src/core/lib/gpr/time_precise.h" + +static LARGE_INTEGER g_start_time; +static double g_time_scale; + +void gpr_time_init(void) { + LARGE_INTEGER frequency; + QueryPerformanceFrequency(&frequency); + QueryPerformanceCounter(&g_start_time); + g_time_scale = 1.0 / (double)frequency.QuadPart; +} + +static gpr_timespec now_impl(gpr_clock_type clock) { + gpr_timespec now_tv; + LONGLONG diff; + struct _timeb now_tb; + LARGE_INTEGER timestamp; + double now_dbl; + now_tv.clock_type = clock; + switch (clock) { + case GPR_CLOCK_REALTIME: + _ftime_s(&now_tb); + now_tv.tv_sec = (int64_t)now_tb.time; + now_tv.tv_nsec = now_tb.millitm * 1000000; + break; + case GPR_CLOCK_MONOTONIC: + case GPR_CLOCK_PRECISE: + QueryPerformanceCounter(×tamp); + diff = timestamp.QuadPart - g_start_time.QuadPart; + now_dbl = (double)diff * g_time_scale; + now_tv.tv_sec = (int64_t)now_dbl; + now_tv.tv_nsec = (int32_t)((now_dbl - (double)now_tv.tv_sec) * 1e9); + break; + case GPR_TIMESPAN: + abort(); + break; + } + return now_tv; +} + +gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type) = now_impl; + +gpr_timespec gpr_now(gpr_clock_type clock_type) { + return gpr_now_impl(clock_type); +} + +void gpr_sleep_until(gpr_timespec until) { + gpr_timespec now; + gpr_timespec delta; + int64_t sleep_millis; + + for (;;) { + /* We could simplify by using clock_nanosleep instead, but it might be + * slightly less portable. */ + now = gpr_now(until.clock_type); + if (gpr_time_cmp(until, now) <= 0) { + return; + } + + delta = gpr_time_sub(until, now); + sleep_millis = + delta.tv_sec * GPR_MS_PER_SEC + delta.tv_nsec / GPR_NS_PER_MS; + GPR_ASSERT((sleep_millis >= 0) && (sleep_millis <= INT_MAX)); + Sleep((DWORD)sleep_millis); + } +} + +#endif /* GPR_WINDOWS_TIME */ diff --git a/src/core/lib/gpr/tmpfile_msys.cc b/src/core/lib/gpr/tmpfile_msys.cc new file mode 100644 index 00000000..76cd886f --- /dev/null +++ b/src/core/lib/gpr/tmpfile_msys.cc @@ -0,0 +1,58 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_MSYS_TMPFILE + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string_windows.h" +#include "src/core/lib/gpr/tmpfile.h" + +FILE* gpr_tmpfile(const char* prefix, char** tmp_filename_out) { + FILE* result = NULL; + char tmp_filename[MAX_PATH]; + UINT success; + + if (tmp_filename_out != NULL) *tmp_filename_out = NULL; + + /* Generate a unique filename with our template + temporary path. */ + success = GetTempFileNameA(".", prefix, 0, tmp_filename); + fprintf(stderr, "success = %d\n", success); + + if (success) { + /* Open a file there. 
*/ + result = fopen(tmp_filename, "wb+"); + fprintf(stderr, "result = %p\n", result); + } + if (result != NULL && tmp_filename_out) { + *tmp_filename_out = gpr_strdup(tmp_filename); + } + + return result; +} + +#endif /* GPR_MSYS_TMPFILE */ diff --git a/src/core/lib/gpr/tmpfile_posix.cc b/src/core/lib/gpr/tmpfile_posix.cc new file mode 100644 index 00000000..166cdf68 --- /dev/null +++ b/src/core/lib/gpr/tmpfile_posix.cc @@ -0,0 +1,69 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_POSIX_TMPFILE + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" + +FILE* gpr_tmpfile(const char* prefix, char** tmp_filename) { + FILE* result = nullptr; + char* filename_template; + int fd; + + if (tmp_filename != nullptr) *tmp_filename = nullptr; + + gpr_asprintf(&filename_template, "/tmp/%s_XXXXXX", prefix); + GPR_ASSERT(filename_template != nullptr); + + fd = mkstemp(filename_template); + if (fd == -1) { + gpr_log(GPR_ERROR, "mkstemp failed for filename_template %s with error %s.", + filename_template, strerror(errno)); + goto end; + } + result = fdopen(fd, "w+"); + if (result == nullptr) { + gpr_log(GPR_ERROR, "Could not open file %s from fd %d (error = %s).", + filename_template, fd, strerror(errno)); + unlink(filename_template); + close(fd); + goto end; + } + +end: + if (result != nullptr && tmp_filename != nullptr) { + *tmp_filename = filename_template; + } else { + gpr_free(filename_template); + } + return result; +} + +#endif /* GPR_POSIX_TMPFILE */ diff --git a/src/core/lib/gpr/tmpfile_windows.cc b/src/core/lib/gpr/tmpfile_windows.cc new file mode 100644 index 00000000..d4868084 --- /dev/null +++ b/src/core/lib/gpr/tmpfile_windows.cc @@ -0,0 +1,69 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_WINDOWS_TMPFILE + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string_windows.h" +#include "src/core/lib/gpr/tmpfile.h" + +FILE* gpr_tmpfile(const char* prefix, char** tmp_filename_out) { + FILE* result = NULL; + LPTSTR template_string = NULL; + TCHAR tmp_path[MAX_PATH]; + TCHAR tmp_filename[MAX_PATH]; + DWORD status; + UINT success; + + if (tmp_filename_out != NULL) *tmp_filename_out = NULL; + + /* Convert our prefix to TCHAR. 
*/ + template_string = gpr_char_to_tchar(prefix); + GPR_ASSERT(template_string); + + /* Get the path to the best temporary folder available. */ + status = GetTempPath(MAX_PATH, tmp_path); + if (status == 0 || status > MAX_PATH) goto end; + + /* Generate a unique filename with our template + temporary path. */ + success = GetTempFileName(tmp_path, template_string, 0, tmp_filename); + if (!success) goto end; + + /* Open a file there. */ + if (_tfopen_s(&result, tmp_filename, TEXT("wb+")) != 0) goto end; + +end: + if (result && tmp_filename_out) { + *tmp_filename_out = gpr_tchar_to_char(tmp_filename); + } + + gpr_free(template_string); + return result; +} + +#endif /* GPR_WINDOWS_TMPFILE */ diff --git a/src/core/lib/gpr/wrap_memcpy.cc b/src/core/lib/gpr/wrap_memcpy.cc new file mode 100644 index 00000000..51efc93a --- /dev/null +++ b/src/core/lib/gpr/wrap_memcpy.cc @@ -0,0 +1,43 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +/* Provide a wrapped memcpy for targets that need to be backwards + * compatible with older libc's. + * + * Enable by setting LDFLAGS=-Wl,-wrap,memcpy when linking. + */ + +extern "C" { +#ifdef __linux__ +#if defined(__x86_64__) && !defined(GPR_MUSL_LIBC_COMPAT) && \ + !defined(__ANDROID__) +__asm__(".symver memcpy,memcpy@GLIBC_2.2.5"); +void* __wrap_memcpy(void* destination, const void* source, size_t num) { + return memcpy(destination, source, num); +} +#else /* !__x86_64__ */ +void* __wrap_memcpy(void* destination, const void* source, size_t num) { + return memmove(destination, source, num); +} +#endif +#endif +} diff --git a/src/core/lib/gprpp/arena.cc b/src/core/lib/gprpp/arena.cc new file mode 100644 index 00000000..9e979de3 --- /dev/null +++ b/src/core/lib/gprpp/arena.cc @@ -0,0 +1,104 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/gprpp/arena.h" + +#include + +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/alloc.h" +#include "src/core/lib/gprpp/memory.h" + +namespace { + +void* ArenaStorage(size_t initial_size) { + static constexpr size_t base_size = + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_core::Arena)); + initial_size = GPR_ROUND_UP_TO_ALIGNMENT_SIZE(initial_size); + size_t alloc_size = base_size + initial_size; + static constexpr size_t alignment = + (GPR_CACHELINE_SIZE > GPR_MAX_ALIGNMENT && + GPR_CACHELINE_SIZE % GPR_MAX_ALIGNMENT == 0) + ? GPR_CACHELINE_SIZE + : GPR_MAX_ALIGNMENT; + return gpr_malloc_aligned(alloc_size, alignment); +} + +} // namespace + +namespace grpc_core { + +Arena::~Arena() { + Zone* z = last_zone_; + while (z) { + Zone* prev_z = z->prev; + z->~Zone(); + gpr_free_aligned(z); + z = prev_z; + } +} + +Arena* Arena::Create(size_t initial_size) { + return new (ArenaStorage(initial_size)) Arena(initial_size); +} + +std::pair Arena::CreateWithAlloc(size_t initial_size, + size_t alloc_size) { + static constexpr size_t base_size = + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(Arena)); + auto* new_arena = + new (ArenaStorage(initial_size)) Arena(initial_size, alloc_size); + void* first_alloc = reinterpret_cast(new_arena) + base_size; + return std::make_pair(new_arena, first_alloc); +} + +size_t Arena::Destroy() { + size_t size = total_used_.load(std::memory_order_relaxed); + this->~Arena(); + gpr_free_aligned(this); + return size; +} + +void* Arena::AllocZone(size_t size) { + // If the allocation isn't able to end in the initial zone, create a new + // zone for this allocation, and any unused space in the initial zone is + // wasted. This overflowing and wasting is uncommon because of our arena + // sizing hysteresis (that is, most calls should have a large enough initial + // zone and will not need to grow the arena). + static constexpr size_t zone_base_size = + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(Zone)); + size_t alloc_size = zone_base_size + size; + Zone* z = new (gpr_malloc_aligned(alloc_size, GPR_MAX_ALIGNMENT)) Zone(); + { + gpr_spinlock_lock(&arena_growth_spinlock_); + z->prev = last_zone_; + last_zone_ = z; + gpr_spinlock_unlock(&arena_growth_spinlock_); + } + return reinterpret_cast(z) + zone_base_size; +} + +} // namespace grpc_core diff --git a/src/core/lib/gprpp/examine_stack.cc b/src/core/lib/gprpp/examine_stack.cc new file mode 100644 index 00000000..1c5d93ae --- /dev/null +++ b/src/core/lib/gprpp/examine_stack.cc @@ -0,0 +1,43 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/gprpp/examine_stack.h" + +namespace grpc_core { + +gpr_current_stack_trace_func g_current_stack_trace_provider = nullptr; + +gpr_current_stack_trace_func GetCurrentStackTraceProvider() { + return g_current_stack_trace_provider; +} + +void SetCurrentStackTraceProvider( + gpr_current_stack_trace_func current_stack_trace_provider) { + g_current_stack_trace_provider = current_stack_trace_provider; +} + +absl::optional GetCurrentStackTrace() { + if (g_current_stack_trace_provider != nullptr) { + return g_current_stack_trace_provider(); + } + return absl::nullopt; +} + +} // namespace grpc_core diff --git a/src/core/lib/gprpp/fork.cc b/src/core/lib/gprpp/fork.cc new file mode 100644 index 00000000..675f8e4d --- /dev/null +++ b/src/core/lib/gprpp/fork.cc @@ -0,0 +1,244 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gprpp/fork.h" + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/global_config.h" +#include "src/core/lib/gprpp/memory.h" + +/* + * NOTE: FORKING IS NOT GENERALLY SUPPORTED, THIS IS ONLY INTENDED TO WORK + * AROUND VERY SPECIFIC USE CASES. + */ + +#ifdef GRPC_ENABLE_FORK_SUPPORT +#define GRPC_ENABLE_FORK_SUPPORT_DEFAULT true +#else +#define GRPC_ENABLE_FORK_SUPPORT_DEFAULT false +#endif // GRPC_ENABLE_FORK_SUPPORT + +GPR_GLOBAL_CONFIG_DEFINE_BOOL(grpc_enable_fork_support, + GRPC_ENABLE_FORK_SUPPORT_DEFAULT, + "Enable fork support"); + +namespace grpc_core { +namespace internal { +// The exec_ctx_count has 2 modes, blocked and unblocked. +// When unblocked, the count is 2-indexed; exec_ctx_count=2 indicates +// 0 active ExecCtxs, exex_ctx_count=3 indicates 1 active ExecCtxs... + +// When blocked, the exec_ctx_count is 0-indexed. Note that ExecCtx +// creation can only be blocked if there is exactly 1 outstanding ExecCtx, +// meaning that BLOCKED and UNBLOCKED counts partition the integers +#define UNBLOCKED(n) ((n) + 2) +#define BLOCKED(n) (n) + +class ExecCtxState { + public: + ExecCtxState() : fork_complete_(true) { + gpr_mu_init(&mu_); + gpr_cv_init(&cv_); + gpr_atm_no_barrier_store(&count_, UNBLOCKED(0)); + } + + void IncExecCtxCount() { + gpr_atm count = gpr_atm_no_barrier_load(&count_); + while (true) { + if (count <= BLOCKED(1)) { + // This only occurs if we are trying to fork. Wait until the fork() + // operation completes before allowing new ExecCtxs. 
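+        // Concretely, with the 2-indexed encoding above, UNBLOCKED(0) == 2
+        // means "no active ExecCtx, no fork pending" and BLOCKED(1) == 1
+        // means "fork in progress with one outstanding ExecCtx". Both the
+        // count and fork_complete_ are re-examined under mu_ before this
+        // thread sleeps on cv_.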
+ gpr_mu_lock(&mu_); + if (gpr_atm_no_barrier_load(&count_) <= BLOCKED(1)) { + while (!fork_complete_) { + gpr_cv_wait(&cv_, &mu_, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + } + gpr_mu_unlock(&mu_); + } else if (gpr_atm_no_barrier_cas(&count_, count, count + 1)) { + break; + } + count = gpr_atm_no_barrier_load(&count_); + } + } + + void DecExecCtxCount() { gpr_atm_no_barrier_fetch_add(&count_, -1); } + + bool BlockExecCtx() { + // Assumes there is an active ExecCtx when this function is called + if (gpr_atm_no_barrier_cas(&count_, UNBLOCKED(1), BLOCKED(1))) { + gpr_mu_lock(&mu_); + fork_complete_ = false; + gpr_mu_unlock(&mu_); + return true; + } + return false; + } + + void AllowExecCtx() { + gpr_mu_lock(&mu_); + gpr_atm_no_barrier_store(&count_, UNBLOCKED(0)); + fork_complete_ = true; + gpr_cv_broadcast(&cv_); + gpr_mu_unlock(&mu_); + } + + ~ExecCtxState() { + gpr_mu_destroy(&mu_); + gpr_cv_destroy(&cv_); + } + + private: + bool fork_complete_; + gpr_mu mu_; + gpr_cv cv_; + gpr_atm count_; +}; + +class ThreadState { + public: + ThreadState() : awaiting_threads_(false), threads_done_(false), count_(0) { + gpr_mu_init(&mu_); + gpr_cv_init(&cv_); + } + + void IncThreadCount() { + gpr_mu_lock(&mu_); + count_++; + gpr_mu_unlock(&mu_); + } + + void DecThreadCount() { + gpr_mu_lock(&mu_); + count_--; + if (awaiting_threads_ && count_ == 0) { + threads_done_ = true; + gpr_cv_signal(&cv_); + } + gpr_mu_unlock(&mu_); + } + void AwaitThreads() { + gpr_mu_lock(&mu_); + awaiting_threads_ = true; + threads_done_ = (count_ == 0); + while (!threads_done_) { + gpr_cv_wait(&cv_, &mu_, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + awaiting_threads_ = true; + gpr_mu_unlock(&mu_); + } + + ~ThreadState() { + gpr_mu_destroy(&mu_); + gpr_cv_destroy(&cv_); + } + + private: + bool awaiting_threads_; + bool threads_done_; + gpr_mu mu_; + gpr_cv cv_; + int count_; +}; + +} // namespace internal + +void Fork::GlobalInit() { + if (!override_enabled_) { + support_enabled_.store(GPR_GLOBAL_CONFIG_GET(grpc_enable_fork_support), + std::memory_order_relaxed); + } + if (support_enabled_.load(std::memory_order_relaxed)) { + exec_ctx_state_ = new internal::ExecCtxState(); + thread_state_ = new internal::ThreadState(); + } +} + +void Fork::GlobalShutdown() { + if (support_enabled_.load(std::memory_order_relaxed)) { + delete exec_ctx_state_; + delete thread_state_; + } +} + +bool Fork::Enabled() { + return support_enabled_.load(std::memory_order_relaxed); +} + +// Testing Only +void Fork::Enable(bool enable) { + override_enabled_ = true; + support_enabled_.store(enable, std::memory_order_relaxed); +} + +void Fork::DoIncExecCtxCount() { exec_ctx_state_->IncExecCtxCount(); } + +void Fork::DoDecExecCtxCount() { exec_ctx_state_->DecExecCtxCount(); } + +void Fork::SetResetChildPollingEngineFunc( + Fork::child_postfork_func reset_child_polling_engine) { + reset_child_polling_engine_ = reset_child_polling_engine; +} +Fork::child_postfork_func Fork::GetResetChildPollingEngineFunc() { + return reset_child_polling_engine_; +} + +bool Fork::BlockExecCtx() { + if (support_enabled_.load(std::memory_order_relaxed)) { + return exec_ctx_state_->BlockExecCtx(); + } + return false; +} + +void Fork::AllowExecCtx() { + if (support_enabled_.load(std::memory_order_relaxed)) { + exec_ctx_state_->AllowExecCtx(); + } +} + +void Fork::IncThreadCount() { + if (support_enabled_.load(std::memory_order_relaxed)) { + thread_state_->IncThreadCount(); + } +} + +void Fork::DecThreadCount() { + if (support_enabled_.load(std::memory_order_relaxed)) { + 
thread_state_->DecThreadCount(); + } +} +void Fork::AwaitThreads() { + if (support_enabled_.load(std::memory_order_relaxed)) { + thread_state_->AwaitThreads(); + } +} + +internal::ExecCtxState* Fork::exec_ctx_state_ = nullptr; +internal::ThreadState* Fork::thread_state_ = nullptr; +std::atomic Fork::support_enabled_(false); +bool Fork::override_enabled_ = false; +Fork::child_postfork_func Fork::reset_child_polling_engine_ = nullptr; +} // namespace grpc_core diff --git a/src/core/lib/gprpp/global_config_env.cc b/src/core/lib/gprpp/global_config_env.cc new file mode 100644 index 00000000..d70c8e83 --- /dev/null +++ b/src/core/lib/gprpp/global_config_env.cc @@ -0,0 +1,137 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gprpp/global_config_env.h" + +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" + +namespace grpc_core { + +namespace { + +void DefaultGlobalConfigEnvErrorFunction(const char* error_message) { + gpr_log(GPR_ERROR, "%s", error_message); +} + +GlobalConfigEnvErrorFunctionType g_global_config_env_error_func = + DefaultGlobalConfigEnvErrorFunction; + +void LogParsingError(const char* name, const char* value) { + std::string error_message = absl::StrFormat( + "Illegal value '%s' specified for environment variable '%s'", value, + name); + (*g_global_config_env_error_func)(error_message.c_str()); +} + +} // namespace + +void SetGlobalConfigEnvErrorFunction(GlobalConfigEnvErrorFunctionType func) { + g_global_config_env_error_func = func; +} + +grpc_core::UniquePtr GlobalConfigEnv::GetValue() { + return grpc_core::UniquePtr(gpr_getenv(GetName())); +} + +void GlobalConfigEnv::SetValue(const char* value) { + gpr_setenv(GetName(), value); +} + +void GlobalConfigEnv::Unset() { gpr_unsetenv(GetName()); } + +char* GlobalConfigEnv::GetName() { + // This makes sure that name_ is in a canonical form having uppercase + // letters. This is okay to be called serveral times. + for (char* c = name_; *c != 0; ++c) { + *c = toupper(*c); + } + return name_; +} +static_assert(std::is_trivially_destructible::value, + "GlobalConfigEnvBool needs to be trivially destructible."); + +bool GlobalConfigEnvBool::Get() { + grpc_core::UniquePtr str = GetValue(); + if (str == nullptr) { + return default_value_; + } + // parsing given value string. + bool result = false; + if (!gpr_parse_bool_value(str.get(), &result)) { + LogParsingError(GetName(), str.get()); + result = default_value_; + } + return result; +} + +void GlobalConfigEnvBool::Set(bool value) { + SetValue(value ? "true" : "false"); +} + +static_assert(std::is_trivially_destructible::value, + "GlobalConfigEnvInt32 needs to be trivially destructible."); + +int32_t GlobalConfigEnvInt32::Get() { + grpc_core::UniquePtr str = GetValue(); + if (str == nullptr) { + return default_value_; + } + // parsing given value string. 
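+  // strtol() must consume the whole string for the value to be accepted:
+  // e.g. "4096" yields 4096, while "4096ms" or "abc" leaves *end != '\0',
+  // so the error is logged and default_value_ is used instead.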
+ char* end = str.get(); + long result = strtol(str.get(), &end, 10); + if (*end != 0) { + LogParsingError(GetName(), str.get()); + result = default_value_; + } + return static_cast(result); +} + +void GlobalConfigEnvInt32::Set(int32_t value) { + char buffer[GPR_LTOA_MIN_BUFSIZE]; + gpr_ltoa(value, buffer); + SetValue(buffer); +} + +static_assert(std::is_trivially_destructible::value, + "GlobalConfigEnvString needs to be trivially destructible."); + +grpc_core::UniquePtr GlobalConfigEnvString::Get() { + grpc_core::UniquePtr str = GetValue(); + if (str == nullptr) { + return grpc_core::UniquePtr(gpr_strdup(default_value_)); + } + return str; +} + +void GlobalConfigEnvString::Set(const char* value) { SetValue(value); } + +} // namespace grpc_core diff --git a/src/core/lib/gprpp/host_port.cc b/src/core/lib/gprpp/host_port.cc new file mode 100644 index 00000000..77eb1747 --- /dev/null +++ b/src/core/lib/gprpp/host_port.cc @@ -0,0 +1,112 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gprpp/host_port.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +#include + +namespace grpc_core { + +std::string JoinHostPort(absl::string_view host, int port) { + if (!host.empty() && host[0] != '[' && host.rfind(':') != host.npos) { + // IPv6 literals must be enclosed in brackets. + return absl::StrFormat("[%s]:%d", host, port); + } + // Ordinary non-bracketed host:port. + return absl::StrFormat("%s:%d", host, port); +} + +namespace { +bool DoSplitHostPort(absl::string_view name, absl::string_view* host, + absl::string_view* port, bool* has_port) { + *has_port = false; + if (!name.empty() && name[0] == '[') { + /* Parse a bracketed host, typically an IPv6 literal. */ + const size_t rbracket = name.find(']', 1); + if (rbracket == absl::string_view::npos) { + /* Unmatched [ */ + return false; + } + if (rbracket == name.size() - 1) { + /* ] */ + *port = absl::string_view(); + } else if (name[rbracket + 1] == ':') { + /* ]: */ + *port = name.substr(rbracket + 2, name.size() - rbracket - 2); + *has_port = true; + } else { + /* ] */ + return false; + } + *host = name.substr(1, rbracket - 1); + if (host->find(':') == absl::string_view::npos) { + /* Require all bracketed hosts to contain a colon, because a hostname or + IPv4 address should never use brackets. */ + *host = absl::string_view(); + return false; + } + } else { + size_t colon = name.find(':'); + if (colon != absl::string_view::npos && + name.find(':', colon + 1) == absl::string_view::npos) { + /* Exactly 1 colon. Split into host:port. */ + *host = name.substr(0, colon); + *port = name.substr(colon + 1, name.size() - colon - 1); + *has_port = true; + } else { + /* 0 or 2+ colons. Bare hostname or IPv6 litearal. 
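+         For example "localhost" (no colon) and "::1" (two colons) both land
+         here: the whole input becomes the host and no port is reported.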
*/ + *host = name; + *port = absl::string_view(); + } + } + return true; +} +} // namespace + +bool SplitHostPort(absl::string_view name, absl::string_view* host, + absl::string_view* port) { + bool unused; + return DoSplitHostPort(name, host, port, &unused); +} + +bool SplitHostPort(absl::string_view name, std::string* host, + std::string* port) { + GPR_DEBUG_ASSERT(host != nullptr && host->empty()); + GPR_DEBUG_ASSERT(port != nullptr && port->empty()); + absl::string_view host_view; + absl::string_view port_view; + bool has_port; + const bool ret = DoSplitHostPort(name, &host_view, &port_view, &has_port); + if (ret) { + // We always set the host, but port is set only when DoSplitHostPort find a + // port in the string, to remain backward compatible with the old + // gpr_split_host_port API. + *host = std::string(host_view); + if (has_port) { + *port = std::string(port_view); + } + } + return ret; +} + +} // namespace grpc_core diff --git a/src/core/lib/gprpp/mpscq.cc b/src/core/lib/gprpp/mpscq.cc new file mode 100644 index 00000000..b5a17173 --- /dev/null +++ b/src/core/lib/gprpp/mpscq.cc @@ -0,0 +1,108 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gprpp/mpscq.h" + +namespace grpc_core { + +// +// MultiProducerSingleConsumerQueue +// + +bool MultiProducerSingleConsumerQueue::Push(Node* node) { + node->next.store(nullptr, std::memory_order_relaxed); + Node* prev = head_.exchange(node, std::memory_order_acq_rel); + prev->next.store(node, std::memory_order_release); + return prev == &stub_; +} + +MultiProducerSingleConsumerQueue::Node* +MultiProducerSingleConsumerQueue::Pop() { + bool empty; + return PopAndCheckEnd(&empty); +} + +MultiProducerSingleConsumerQueue::Node* +MultiProducerSingleConsumerQueue::PopAndCheckEnd(bool* empty) { + Node* tail = tail_; + Node* next = tail_->next.load(std::memory_order_acquire); + if (tail == &stub_) { + // indicates the list is actually (ephemerally) empty + if (next == nullptr) { + *empty = true; + return nullptr; + } + tail_ = next; + tail = next; + next = tail->next.load(std::memory_order_acquire); + } + if (next != nullptr) { + *empty = false; + tail_ = next; + return tail; + } + Node* head = head_.load(std::memory_order_acquire); + if (tail != head) { + *empty = false; + // indicates a retry is in order: we're still adding + return nullptr; + } + Push(&stub_); + next = tail->next.load(std::memory_order_acquire); + if (next != nullptr) { + *empty = false; + tail_ = next; + return tail; + } + // indicates a retry is in order: we're still adding + *empty = false; + return nullptr; +} + +// +// LockedMultiProducerSingleConsumerQueue +// + +bool LockedMultiProducerSingleConsumerQueue::Push(Node* node) { + return queue_.Push(node); +} + +LockedMultiProducerSingleConsumerQueue::Node* +LockedMultiProducerSingleConsumerQueue::TryPop() { + if (mu_.TryLock()) { + Node* node = queue_.Pop(); + mu_.Unlock(); + return node; + } + return nullptr; +} + 
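+// Illustrative usage sketch, not part of this file: `WorkItem` is a
+// hypothetical caller-defined type. Producers on any thread Push() nodes
+// derived from Node; the single consumer Pop()s them back and downcasts.
+// Pop() may return nullptr transiently while a concurrent Push() is still
+// linking its node, so callers either retry or use
+// LockedMultiProducerSingleConsumerQueue::Pop(), which loops until the
+// in-flight push completes.
+//
+//   struct WorkItem : MultiProducerSingleConsumerQueue::Node {
+//     int payload = 0;
+//   };
+//
+//   MultiProducerSingleConsumerQueue q;
+//   auto* item = new WorkItem;
+//   item->payload = 42;
+//   q.Push(item);  // safe from any thread
+//   if (auto* got = static_cast<WorkItem*>(q.Pop())) {  // one consumer only
+//     delete got;
+//   }
+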
+LockedMultiProducerSingleConsumerQueue::Node* +LockedMultiProducerSingleConsumerQueue::Pop() { + MutexLock lock(&mu_); + bool empty = false; + Node* node; + do { + node = queue_.PopAndCheckEnd(&empty); + } while (node == nullptr && !empty); + return node; +} + +} // namespace grpc_core diff --git a/src/core/lib/gprpp/stat_posix.cc b/src/core/lib/gprpp/stat_posix.cc new file mode 100644 index 00000000..888361bf --- /dev/null +++ b/src/core/lib/gprpp/stat_posix.cc @@ -0,0 +1,49 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#ifdef GPR_POSIX_STAT + +#include +#include +#include +#include + +#include + +#include "src/core/lib/gprpp/stat.h" + +namespace grpc_core { + +absl::Status GetFileModificationTime(const char* filename, time_t* timestamp) { + GPR_ASSERT(filename != nullptr); + GPR_ASSERT(timestamp != nullptr); + struct stat buf; + if (stat(filename, &buf) != 0) { + const char* error_msg = strerror(errno); + gpr_log(GPR_ERROR, "stat failed for filename %s with error %s.", filename, + error_msg); + return absl::Status(absl::StatusCode::kInternal, error_msg); + } + // Last file/directory modification time. + *timestamp = buf.st_mtime; + return absl::OkStatus(); +} + +} // namespace grpc_core + +#endif // GPR_POSIX_STAT diff --git a/src/core/lib/gprpp/stat_windows.cc b/src/core/lib/gprpp/stat_windows.cc new file mode 100644 index 00000000..f1435e03 --- /dev/null +++ b/src/core/lib/gprpp/stat_windows.cc @@ -0,0 +1,48 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#ifdef GPR_WINDOWS_STAT + +#include +#include +#include + +#include + +#include "src/core/lib/gprpp/stat.h" + +namespace grpc_core { + +absl::Status GetFileModificationTime(const char* filename, time_t* timestamp) { + GPR_ASSERT(filename != nullptr); + GPR_ASSERT(timestamp != nullptr); + struct _stat buf; + if (_stat(filename, &buf) != 0) { + const char* error_msg = strerror(errno); + gpr_log(GPR_ERROR, "_stat failed for filename %s with error %s.", filename, + error_msg); + return absl::Status(absl::StatusCode::kInternal, error_msg); + } + // Last file/directory modification time. 
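+  // As in the POSIX version above, st_mtime from _stat() is a time_t in
+  // seconds since the Unix epoch, so callers can consume the result the same
+  // way on both platforms. Illustrative use ("cert.pem" is a made-up path):
+  //   time_t ts;
+  //   if (GetFileModificationTime("cert.pem", &ts).ok()) { /* use ts */ }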
+ *timestamp = buf.st_mtime; + return absl::OkStatus(); +} + +} // namespace grpc_core + +#endif // GPR_WINDOWS_STAT diff --git a/src/core/lib/gprpp/status_helper.cc b/src/core/lib/gprpp/status_helper.cc new file mode 100644 index 00000000..fa538dc5 --- /dev/null +++ b/src/core/lib/gprpp/status_helper.cc @@ -0,0 +1,427 @@ +// +// +// Copyright 2021 the gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/lib/gprpp/status_helper.h" + +#include "absl/strings/cord.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/time/clock.h" +#include "google/protobuf/any.upb.h" +#include "google/rpc/status.upb.h" +#include "upb/upb.hpp" + +#include + +#include "src/core/lib/gprpp/time_util.h" + +namespace grpc_core { + +namespace { + +#define TYPE_URL_PREFIX "type.googleapis.com/grpc.status." +#define TYPE_INT_TAG "int." +#define TYPE_STR_TAG "str." +#define TYPE_TIME_TAG "time." +#define TYPE_CHILDREN_TAG "children" +#define TYPE_URL(name) (TYPE_URL_PREFIX name) +const absl::string_view kTypeUrlPrefix = TYPE_URL_PREFIX; +const absl::string_view kTypeIntTag = TYPE_INT_TAG; +const absl::string_view kTypeStrTag = TYPE_STR_TAG; +const absl::string_view kTypeTimeTag = TYPE_TIME_TAG; +const absl::string_view kTypeChildrenTag = TYPE_CHILDREN_TAG; +const absl::string_view kChildrenPropertyUrl = TYPE_URL(TYPE_CHILDREN_TAG); + +const char* GetStatusIntPropertyUrl(StatusIntProperty key) { + switch (key) { + case StatusIntProperty::kErrorNo: + return TYPE_URL(TYPE_INT_TAG "errno"); + case StatusIntProperty::kFileLine: + return TYPE_URL(TYPE_INT_TAG "file_line"); + case StatusIntProperty::kStreamId: + return TYPE_URL(TYPE_INT_TAG "stream_id"); + case StatusIntProperty::kRpcStatus: + return TYPE_URL(TYPE_INT_TAG "grpc_status"); + case StatusIntProperty::kOffset: + return TYPE_URL(TYPE_INT_TAG "offset"); + case StatusIntProperty::kIndex: + return TYPE_URL(TYPE_INT_TAG "index"); + case StatusIntProperty::kSize: + return TYPE_URL(TYPE_INT_TAG "size"); + case StatusIntProperty::kHttp2Error: + return TYPE_URL(TYPE_INT_TAG "http2_error"); + case StatusIntProperty::kTsiCode: + return TYPE_URL(TYPE_INT_TAG "tsi_code"); + case StatusIntProperty::kWsaError: + return TYPE_URL(TYPE_INT_TAG "wsa_error"); + case StatusIntProperty::kFd: + return TYPE_URL(TYPE_INT_TAG "fd"); + case StatusIntProperty::kHttpStatus: + return TYPE_URL(TYPE_INT_TAG "http_status"); + case StatusIntProperty::kOccurredDuringWrite: + return TYPE_URL(TYPE_INT_TAG "occurred_during_write"); + case StatusIntProperty::ChannelConnectivityState: + return TYPE_URL(TYPE_INT_TAG "channel_connectivity_state"); + case StatusIntProperty::kLbPolicyDrop: + return TYPE_URL(TYPE_INT_TAG "lb_policy_drop"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +const char* GetStatusStrPropertyUrl(StatusStrProperty key) { + switch (key) { + case StatusStrProperty::kDescription: + return TYPE_URL(TYPE_STR_TAG 
"description"); + case StatusStrProperty::kFile: + return TYPE_URL(TYPE_STR_TAG "file"); + case StatusStrProperty::kOsError: + return TYPE_URL(TYPE_STR_TAG "os_error"); + case StatusStrProperty::kSyscall: + return TYPE_URL(TYPE_STR_TAG "syscall"); + case StatusStrProperty::kTargetAddress: + return TYPE_URL(TYPE_STR_TAG "target_address"); + case StatusStrProperty::kGrpcMessage: + return TYPE_URL(TYPE_STR_TAG "grpc_message"); + case StatusStrProperty::kRawBytes: + return TYPE_URL(TYPE_STR_TAG "raw_bytes"); + case StatusStrProperty::kTsiError: + return TYPE_URL(TYPE_STR_TAG "tsi_error"); + case StatusStrProperty::kFilename: + return TYPE_URL(TYPE_STR_TAG "filename"); + case StatusStrProperty::kKey: + return TYPE_URL(TYPE_STR_TAG "key"); + case StatusStrProperty::kValue: + return TYPE_URL(TYPE_STR_TAG "value"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +const char* GetStatusTimePropertyUrl(StatusTimeProperty key) { + switch (key) { + case StatusTimeProperty::kCreated: + return TYPE_URL(TYPE_TIME_TAG "created_time"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +void EncodeUInt32ToBytes(uint32_t v, char* buf) { + buf[0] = v & 0xFF; + buf[1] = (v >> 8) & 0xFF; + buf[2] = (v >> 16) & 0xFF; + buf[3] = (v >> 24) & 0xFF; +} + +uint32_t DecodeUInt32FromBytes(const char* buf) { + const unsigned char* ubuf = reinterpret_cast(buf); + return ubuf[0] | (uint32_t(ubuf[1]) << 8) | (uint32_t(ubuf[2]) << 16) | + (uint32_t(ubuf[3]) << 24); +} + +std::vector ParseChildren(absl::Cord children) { + std::vector result; + upb::Arena arena; + // Cord is flattened to iterate the buffer easily at the cost of memory copy. + // TODO(veblush): Optimize this once CordReader is introduced. + absl::string_view buf = children.Flatten(); + size_t cur = 0; + while (buf.size() - cur >= sizeof(uint32_t)) { + size_t msg_size = DecodeUInt32FromBytes(buf.data() + cur); + cur += sizeof(uint32_t); + GPR_ASSERT(buf.size() - cur >= msg_size); + google_rpc_Status* msg = + google_rpc_Status_parse(buf.data() + cur, msg_size, arena.ptr()); + cur += msg_size; + result.push_back(internal::StatusFromProto(msg)); + } + return result; +} + +} // namespace + +absl::Status StatusCreate(absl::StatusCode code, absl::string_view msg, + const DebugLocation& location, + std::initializer_list children) { + absl::Status s(code, msg); + if (location.file() != nullptr) { + StatusSetStr(&s, StatusStrProperty::kFile, location.file()); + } + if (location.line() != -1) { + StatusSetInt(&s, StatusIntProperty::kFileLine, location.line()); + } + StatusSetTime(&s, StatusTimeProperty::kCreated, absl::Now()); + for (const absl::Status& child : children) { + if (!child.ok()) { + StatusAddChild(&s, child); + } + } + return s; +} + +void StatusSetInt(absl::Status* status, StatusIntProperty key, intptr_t value) { + status->SetPayload(GetStatusIntPropertyUrl(key), + absl::Cord(std::to_string(value))); +} + +absl::optional StatusGetInt(const absl::Status& status, + StatusIntProperty key) { + absl::optional p = + status.GetPayload(GetStatusIntPropertyUrl(key)); + if (p.has_value()) { + absl::optional sv = p->TryFlat(); + intptr_t value; + if (sv.has_value()) { + if (absl::SimpleAtoi(*sv, &value)) { + return value; + } + } else { + if (absl::SimpleAtoi(std::string(*p), &value)) { + return value; + } + } + } + return {}; +} + +void StatusSetStr(absl::Status* status, StatusStrProperty key, + absl::string_view value) { + status->SetPayload(GetStatusStrPropertyUrl(key), absl::Cord(value)); +} + +absl::optional StatusGetStr(const absl::Status& status, + 
StatusStrProperty key) { + absl::optional p = + status.GetPayload(GetStatusStrPropertyUrl(key)); + if (p.has_value()) { + return std::string(*p); + } + return {}; +} + +void StatusSetTime(absl::Status* status, StatusTimeProperty key, + absl::Time time) { + status->SetPayload(GetStatusTimePropertyUrl(key), + absl::Cord(absl::string_view( + reinterpret_cast(&time), sizeof(time)))); +} + +absl::optional StatusGetTime(const absl::Status& status, + StatusTimeProperty key) { + absl::optional p = + status.GetPayload(GetStatusTimePropertyUrl(key)); + if (p.has_value()) { + absl::optional sv = p->TryFlat(); + if (sv.has_value()) { + return *reinterpret_cast(sv->data()); + } else { + std::string s = std::string(*p); + return *reinterpret_cast(s.c_str()); + } + } + return {}; +} + +void StatusAddChild(absl::Status* status, absl::Status child) { + upb::Arena arena; + // Serialize msg to buf + google_rpc_Status* msg = internal::StatusToProto(child, arena.ptr()); + size_t buf_len = 0; + char* buf = google_rpc_Status_serialize(msg, arena.ptr(), &buf_len); + // Append (msg-length and msg) to children payload + absl::optional old_children = + status->GetPayload(kChildrenPropertyUrl); + absl::Cord children; + if (old_children.has_value()) { + children = *old_children; + } + char head_buf[sizeof(uint32_t)]; + EncodeUInt32ToBytes(buf_len, head_buf); + children.Append(absl::string_view(head_buf, sizeof(uint32_t))); + children.Append(absl::string_view(buf, buf_len)); + status->SetPayload(kChildrenPropertyUrl, std::move(children)); +} + +std::vector StatusGetChildren(absl::Status status) { + absl::optional children = status.GetPayload(kChildrenPropertyUrl); + return children.has_value() ? ParseChildren(*children) + : std::vector(); +} + +std::string StatusToString(const absl::Status& status) { + if (status.ok()) { + return "OK"; + } + std::string head; + absl::StrAppend(&head, absl::StatusCodeToString(status.code())); + if (!status.message().empty()) { + absl::StrAppend(&head, ":", status.message()); + } + std::vector kvs; + absl::optional children; + status.ForEachPayload([&](absl::string_view type_url, + const absl::Cord& payload) { + if (absl::StartsWith(type_url, kTypeUrlPrefix)) { + type_url.remove_prefix(kTypeUrlPrefix.size()); + if (type_url == kTypeChildrenTag) { + children = payload; + return; + } + absl::string_view payload_view; + std::string payload_storage; + if (payload.TryFlat().has_value()) { + payload_view = payload.TryFlat().value(); + } else { + payload_storage = std::string(payload); + payload_view = payload_storage; + } + if (absl::StartsWith(type_url, kTypeIntTag)) { + type_url.remove_prefix(kTypeIntTag.size()); + kvs.push_back(absl::StrCat(type_url, ":", payload_view)); + } else if (absl::StartsWith(type_url, kTypeStrTag)) { + type_url.remove_prefix(kTypeStrTag.size()); + kvs.push_back(absl::StrCat(type_url, ":\"", + absl::CHexEscape(payload_view), "\"")); + } else if (absl::StartsWith(type_url, kTypeTimeTag)) { + type_url.remove_prefix(kTypeTimeTag.size()); + absl::Time t = + *reinterpret_cast(payload_view.data()); + kvs.push_back(absl::StrCat(type_url, ":\"", absl::FormatTime(t), "\"")); + } else { + kvs.push_back(absl::StrCat(type_url, ":\"", + absl::CHexEscape(payload_view), "\"")); + } + } else { + absl::optional payload_view = payload.TryFlat(); + std::string payload_str = absl::CHexEscape( + payload_view.has_value() ? 
*payload_view : std::string(payload)); + kvs.push_back(absl::StrCat(type_url, ":\"", payload_str, "\"")); + } + }); + if (children.has_value()) { + std::vector children_status = ParseChildren(*children); + std::vector children_text; + children_text.reserve(children_status.size()); + for (const absl::Status& child_status : children_status) { + children_text.push_back(StatusToString(child_status)); + } + kvs.push_back( + absl::StrCat("children:[", absl::StrJoin(children_text, ", "), "]")); + } + return kvs.empty() ? head + : absl::StrCat(head, " {", absl::StrJoin(kvs, ", "), "}"); +} + +namespace internal { + +google_rpc_Status* StatusToProto(const absl::Status& status, upb_arena* arena) { + google_rpc_Status* msg = google_rpc_Status_new(arena); + google_rpc_Status_set_code(msg, int32_t(status.code())); + google_rpc_Status_set_message( + msg, upb_strview_make(status.message().data(), status.message().size())); + status.ForEachPayload([&](absl::string_view type_url, + const absl::Cord& payload) { + google_protobuf_Any* any = google_rpc_Status_add_details(msg, arena); + char* type_url_buf = + reinterpret_cast(upb_arena_malloc(arena, type_url.size())); + memcpy(type_url_buf, type_url.data(), type_url.size()); + google_protobuf_Any_set_type_url( + any, upb_strview_make(type_url_buf, type_url.size())); + absl::optional v_view = payload.TryFlat(); + if (v_view.has_value()) { + google_protobuf_Any_set_value( + any, upb_strview_make(v_view->data(), v_view->size())); + } else { + char* buf = + reinterpret_cast(upb_arena_malloc(arena, payload.size())); + char* cur = buf; + for (absl::string_view chunk : payload.Chunks()) { + memcpy(cur, chunk.data(), chunk.size()); + cur += chunk.size(); + } + google_protobuf_Any_set_value(any, upb_strview_make(buf, payload.size())); + } + }); + return msg; +} + +absl::Status StatusFromProto(google_rpc_Status* msg) { + int32_t code = google_rpc_Status_code(msg); + upb_strview message = google_rpc_Status_message(msg); + absl::Status status(static_cast(code), + absl::string_view(message.data, message.size)); + size_t detail_len; + const google_protobuf_Any* const* details = + google_rpc_Status_details(msg, &detail_len); + for (size_t i = 0; i < detail_len; i++) { + upb_strview type_url = google_protobuf_Any_type_url(details[i]); + upb_strview value = google_protobuf_Any_value(details[i]); + status.SetPayload(absl::string_view(type_url.data, type_url.size), + absl::Cord(absl::string_view(value.data, value.size))); + } + return status; +} + +uintptr_t StatusAllocPtr(absl::Status s) { + // This relies the fact that absl::Status has only one member, StatusRep* + // so the sizeof(absl::Status) has the same size of intptr_t and StatusRep* + // can be stolen using placement allocation. + static_assert(sizeof(intptr_t) == sizeof(absl::Status), + "absl::Status should be as big as intptr_t"); + // This does two things; + // 1. Copies StatusRep* of absl::Status to ptr + // 2. Increases the counter of StatusRep if it's not inlined + uintptr_t ptr; + new (&ptr) absl::Status(s); + return ptr; +} + +void StatusFreePtr(uintptr_t ptr) { + // Decreases the counter of StatusRep if it's not inlined. + reinterpret_cast(&ptr)->~Status(); +} + +absl::Status StatusGetFromPtr(uintptr_t ptr) { + // Constructs Status from ptr having the address of StatusRep. 
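+  // The uintptr_t slot holds a placement-constructed absl::Status (see
+  // StatusAllocPtr above); reading it back through the cast copy-constructs a
+  // fresh Status, bumping the shared representation's refcount as usual, and
+  // the stored value stays valid until StatusFreePtr(ptr). Illustrative round
+  // trip:
+  //   uintptr_t p = StatusAllocPtr(absl::CancelledError("oops"));
+  //   absl::Status s = StatusGetFromPtr(p);  // equal to the original status
+  //   StatusFreePtr(p);                      // releases the stored copy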
+ return *reinterpret_cast(&ptr); +} + +uintptr_t StatusAllocHeapPtr(absl::Status s) { + if (s.ok()) return kOkStatusPtr; + absl::Status* ptr = new absl::Status(s); + return reinterpret_cast(ptr); +} + +void StatusFreeHeapPtr(uintptr_t ptr) { + absl::Status* s = reinterpret_cast(ptr); + delete s; +} + +absl::Status StatusGetFromHeapPtr(uintptr_t ptr) { + if (ptr == kOkStatusPtr) { + return absl::OkStatus(); + } else { + return *reinterpret_cast(ptr); + } +} + +} // namespace internal + +} // namespace grpc_core diff --git a/src/core/lib/gprpp/thd_posix.cc b/src/core/lib/gprpp/thd_posix.cc new file mode 100644 index 00000000..2d522c9b --- /dev/null +++ b/src/core/lib/gprpp/thd_posix.cc @@ -0,0 +1,209 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Posix implementation for gpr threads. */ + +#include + +#ifdef GPR_POSIX_SYNC + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/fork.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/thd.h" + +namespace grpc_core { +namespace { +class ThreadInternalsPosix; +struct thd_arg { + ThreadInternalsPosix* thread; + void (*body)(void* arg); /* body of a thread */ + void* arg; /* argument to a thread */ + const char* name; /* name of thread. Can be nullptr. */ + bool joinable; + bool tracked; +}; + +size_t RoundUpToPageSize(size_t size) { + // TODO(yunjiaw): Change this variable (page_size) to a function-level static + // when possible + size_t page_size = static_cast(sysconf(_SC_PAGESIZE)); + return (size + page_size - 1) & ~(page_size - 1); +} + +// Returns the minimum valid stack size that can be passed to +// pthread_attr_setstacksize. +size_t MinValidStackSize(size_t request_size) { + size_t min_stacksize = sysconf(_SC_THREAD_STACK_MIN); + if (request_size < min_stacksize) { + request_size = min_stacksize; + } + + // On some systems, pthread_attr_setstacksize() can fail if stacksize is + // not a multiple of the system page size. 
+ return RoundUpToPageSize(request_size); +} + +class ThreadInternalsPosix : public internal::ThreadInternalsInterface { + public: + ThreadInternalsPosix(const char* thd_name, void (*thd_body)(void* arg), + void* arg, bool* success, const Thread::Options& options) + : started_(false) { + gpr_mu_init(&mu_); + gpr_cv_init(&ready_); + pthread_attr_t attr; + /* don't use gpr_malloc as we may cause an infinite recursion with + * the profiling code */ + thd_arg* info = static_cast(malloc(sizeof(*info))); + GPR_ASSERT(info != nullptr); + info->thread = this; + info->body = thd_body; + info->arg = arg; + info->name = thd_name; + info->joinable = options.joinable(); + info->tracked = options.tracked(); + if (options.tracked()) { + Fork::IncThreadCount(); + } + + GPR_ASSERT(pthread_attr_init(&attr) == 0); + if (options.joinable()) { + GPR_ASSERT(pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE) == + 0); + } else { + GPR_ASSERT(pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED) == + 0); + } + + if (options.stack_size() != 0) { + size_t stack_size = MinValidStackSize(options.stack_size()); + GPR_ASSERT(pthread_attr_setstacksize(&attr, stack_size) == 0); + } + + *success = (pthread_create( + &pthread_id_, &attr, + [](void* v) -> void* { + thd_arg arg = *static_cast(v); + free(v); + if (arg.name != nullptr) { +#if GPR_APPLE_PTHREAD_NAME + /* Apple supports 64 characters, and will + * truncate if it's longer. */ + pthread_setname_np(arg.name); +#elif GPR_LINUX_PTHREAD_NAME + /* Linux supports 16 characters max, and will + * error if it's longer. */ + char buf[16]; + size_t buf_len = GPR_ARRAY_SIZE(buf) - 1; + strncpy(buf, arg.name, buf_len); + buf[buf_len] = '\0'; + pthread_setname_np(pthread_self(), buf); +#endif // GPR_APPLE_PTHREAD_NAME + } + + gpr_mu_lock(&arg.thread->mu_); + while (!arg.thread->started_) { + gpr_cv_wait(&arg.thread->ready_, &arg.thread->mu_, + gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&arg.thread->mu_); + + if (!arg.joinable) { + delete arg.thread; + } + + (*arg.body)(arg.arg); + if (arg.tracked) { + Fork::DecThreadCount(); + } + return nullptr; + }, + info) == 0); + + GPR_ASSERT(pthread_attr_destroy(&attr) == 0); + + if (!(*success)) { + /* don't use gpr_free, as this was allocated using malloc (see above) */ + free(info); + if (options.tracked()) { + Fork::DecThreadCount(); + } + } + } + + ~ThreadInternalsPosix() override { + gpr_mu_destroy(&mu_); + gpr_cv_destroy(&ready_); + } + + void Start() override { + gpr_mu_lock(&mu_); + started_ = true; + gpr_cv_signal(&ready_); + gpr_mu_unlock(&mu_); + } + + void Join() override { pthread_join(pthread_id_, nullptr); } + + private: + gpr_mu mu_; + gpr_cv ready_; + bool started_; + pthread_t pthread_id_; +}; + +} // namespace + +Thread::Thread(const char* thd_name, void (*thd_body)(void* arg), void* arg, + bool* success, const Options& options) + : options_(options) { + bool outcome = false; + impl_ = new ThreadInternalsPosix(thd_name, thd_body, arg, &outcome, options); + if (outcome) { + state_ = ALIVE; + } else { + state_ = FAILED; + delete impl_; + impl_ = nullptr; + } + + if (success != nullptr) { + *success = outcome; + } +} +} // namespace grpc_core + +// The following is in the external namespace as it is exposed as C89 API +gpr_thd_id gpr_thd_currentid(void) { + // Use C-style casting because Linux and OSX have different definitions + // of pthread_t so that a single C++ cast doesn't handle it. 
+ // NOLINTNEXTLINE(google-readability-casting) + return (gpr_thd_id)pthread_self(); +} + +#endif /* GPR_POSIX_SYNC */ diff --git a/src/core/lib/gprpp/thd_windows.cc b/src/core/lib/gprpp/thd_windows.cc new file mode 100644 index 00000000..ce1a7539 --- /dev/null +++ b/src/core/lib/gprpp/thd_windows.cc @@ -0,0 +1,171 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Windows implementation for gpr threads. */ + +#include + +#ifdef GPR_WINDOWS + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/thd.h" + +namespace { +class ThreadInternalsWindows; +struct thd_info { + ThreadInternalsWindows* thread; + void (*body)(void* arg); /* body of a thread */ + void* arg; /* argument to a thread */ + HANDLE join_event; /* the join event */ + bool joinable; /* whether it is joinable */ +}; + +GPR_THREAD_LOCAL(struct thd_info*) g_thd_info; + +class ThreadInternalsWindows + : public grpc_core::internal::ThreadInternalsInterface { + public: + ThreadInternalsWindows(void (*thd_body)(void* arg), void* arg, bool* success, + const grpc_core::Thread::Options& options) + : started_(false) { + gpr_mu_init(&mu_); + gpr_cv_init(&ready_); + + HANDLE handle; + info_ = (struct thd_info*)gpr_malloc(sizeof(*info_)); + info_->thread = this; + info_->body = thd_body; + info_->arg = arg; + info_->join_event = nullptr; + info_->joinable = options.joinable(); + if (info_->joinable) { + info_->join_event = CreateEvent(nullptr, FALSE, FALSE, nullptr); + if (info_->join_event == nullptr) { + gpr_free(info_); + *success = false; + return; + } + } + + if (options.stack_size() != 0) { + // Windows will round up the given stack_size value to nearest page. 
+ handle = CreateThread(nullptr, options.stack_size(), thread_body, info_, + 0, nullptr); + } else { + handle = CreateThread(nullptr, 64 * 1024, thread_body, info_, 0, nullptr); + } + + if (handle == nullptr) { + destroy_thread(); + *success = false; + } else { + CloseHandle(handle); + *success = true; + } + } + + ~ThreadInternalsWindows() override { + gpr_mu_destroy(&mu_); + gpr_cv_destroy(&ready_); + } + + void Start() override { + gpr_mu_lock(&mu_); + started_ = true; + gpr_cv_signal(&ready_); + gpr_mu_unlock(&mu_); + } + + void Join() override { + DWORD ret = WaitForSingleObject(info_->join_event, INFINITE); + GPR_ASSERT(ret == WAIT_OBJECT_0); + destroy_thread(); + } + + private: + static DWORD WINAPI thread_body(void* v) { + g_thd_info = static_cast(v); + gpr_mu_lock(&g_thd_info->thread->mu_); + while (!g_thd_info->thread->started_) { + gpr_cv_wait(&g_thd_info->thread->ready_, &g_thd_info->thread->mu_, + gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&g_thd_info->thread->mu_); + if (!g_thd_info->joinable) { + delete g_thd_info->thread; + g_thd_info->thread = nullptr; + } + g_thd_info->body(g_thd_info->arg); + if (g_thd_info->joinable) { + BOOL ret = SetEvent(g_thd_info->join_event); + GPR_ASSERT(ret); + } else { + gpr_free(g_thd_info); + } + return 0; + } + + void destroy_thread() { + if (info_ != nullptr && info_->joinable) { + CloseHandle(info_->join_event); + } + gpr_free(info_); + } + + gpr_mu mu_; + gpr_cv ready_; + bool started_; + thd_info* info_; +}; + +} // namespace + +namespace grpc_core { + +Thread::Thread(const char* thd_name, void (*thd_body)(void* arg), void* arg, + bool* success, const Options& options) + : options_(options) { + bool outcome = false; + impl_ = new ThreadInternalsWindows(thd_body, arg, &outcome, options); + if (outcome) { + state_ = ALIVE; + } else { + state_ = FAILED; + delete impl_; + impl_ = nullptr; + } + + if (success != nullptr) { + *success = outcome; + } +} + +} // namespace grpc_core + +gpr_thd_id gpr_thd_currentid(void) { + return reinterpret_cast(g_thd_info); +} + +#endif /* GPR_WINDOWS */ diff --git a/src/core/lib/gprpp/time_util.cc b/src/core/lib/gprpp/time_util.cc new file mode 100644 index 00000000..84795032 --- /dev/null +++ b/src/core/lib/gprpp/time_util.cc @@ -0,0 +1,77 @@ +// +// Copyright 2021 the gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/lib/gprpp/time_util.h" + +#include + +namespace grpc_core { + +gpr_timespec ToGprTimeSpec(absl::Duration duration) { + if (duration == absl::InfiniteDuration()) { + return gpr_inf_future(GPR_TIMESPAN); + } else if (duration == -absl::InfiniteDuration()) { + return gpr_inf_past(GPR_TIMESPAN); + } else { + int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); + int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); + return gpr_time_add(gpr_time_from_seconds(s, GPR_TIMESPAN), + gpr_time_from_nanos(n, GPR_TIMESPAN)); + } +} + +gpr_timespec ToGprTimeSpec(absl::Time time) { + if (time == absl::InfiniteFuture()) { + return gpr_inf_future(GPR_CLOCK_REALTIME); + } else if (time == absl::InfinitePast()) { + return gpr_inf_past(GPR_CLOCK_REALTIME); + } else { + timespec ts = absl::ToTimespec(time); + gpr_timespec out; + out.tv_sec = static_cast(ts.tv_sec); + out.tv_nsec = static_cast(ts.tv_nsec); + out.clock_type = GPR_CLOCK_REALTIME; + return out; + } +} + +absl::Duration ToAbslDuration(gpr_timespec ts) { + GPR_ASSERT(ts.clock_type == GPR_TIMESPAN); + if (gpr_time_cmp(ts, gpr_inf_future(GPR_TIMESPAN)) == 0) { + return absl::InfiniteDuration(); + } else if (gpr_time_cmp(ts, gpr_inf_past(GPR_TIMESPAN)) == 0) { + return -absl::InfiniteDuration(); + } else { + return absl::Seconds(ts.tv_sec) + absl::Nanoseconds(ts.tv_nsec); + } +} + +absl::Time ToAbslTime(gpr_timespec ts) { + GPR_ASSERT(ts.clock_type != GPR_TIMESPAN); + gpr_timespec rts = gpr_convert_clock_type(ts, GPR_CLOCK_REALTIME); + if (gpr_time_cmp(rts, gpr_inf_future(GPR_CLOCK_REALTIME)) == 0) { + return absl::InfiniteFuture(); + } else if (gpr_time_cmp(rts, gpr_inf_past(GPR_CLOCK_REALTIME)) == 0) { + return absl::InfinitePast(); + } else { + return absl::UnixEpoch() + absl::Seconds(rts.tv_sec) + + absl::Nanoseconds(rts.tv_nsec); + } +} + +} // namespace grpc_core diff --git a/src/core/lib/http/format_request.cc b/src/core/lib/http/format_request.cc new file mode 100644 index 00000000..1f0a5cf5 --- /dev/null +++ b/src/core/lib/http/format_request.cc @@ -0,0 +1,104 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/http/format_request.h" + +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" + +static void fill_common_header(const grpc_httpcli_request* request, + bool connection_close, + std::vector* buf) { + buf->push_back(request->http.path); + buf->push_back(" HTTP/1.0\r\n"); + /* just in case some crazy server really expects HTTP/1.1 */ + buf->push_back("Host: "); + buf->push_back(request->host); + buf->push_back("\r\n"); + if (connection_close) buf->push_back("Connection: close\r\n"); + buf->push_back("User-Agent: " GRPC_HTTPCLI_USER_AGENT "\r\n"); + /* user supplied headers */ + for (size_t i = 0; i < request->http.hdr_count; i++) { + buf->push_back(request->http.hdrs[i].key); + buf->push_back(": "); + buf->push_back(request->http.hdrs[i].value); + buf->push_back("\r\n"); + } +} + +grpc_slice grpc_httpcli_format_get_request( + const grpc_httpcli_request* request) { + std::vector out; + out.push_back("GET "); + fill_common_header(request, true, &out); + out.push_back("\r\n"); + std::string req = absl::StrJoin(out, ""); + return grpc_slice_from_copied_buffer(req.data(), req.size()); +} + +grpc_slice grpc_httpcli_format_post_request(const grpc_httpcli_request* request, + const char* body_bytes, + size_t body_size) { + std::vector out; + out.push_back("POST "); + fill_common_header(request, true, &out); + if (body_bytes != nullptr) { + bool has_content_type = false; + for (size_t i = 0; i < request->http.hdr_count; i++) { + if (strcmp(request->http.hdrs[i].key, "Content-Type") == 0) { + has_content_type = true; + break; + } + } + if (!has_content_type) { + out.push_back("Content-Type: text/plain\r\n"); + } + out.push_back(absl::StrFormat("Content-Length: %lu\r\n", + static_cast(body_size))); + } + out.push_back("\r\n"); + std::string req = absl::StrJoin(out, ""); + if (body_bytes != nullptr) { + absl::StrAppend(&req, absl::string_view(body_bytes, body_size)); + } + return grpc_slice_from_copied_buffer(req.data(), req.size()); +} + +grpc_slice grpc_httpcli_format_connect_request( + const grpc_httpcli_request* request) { + std::vector out; + out.push_back("CONNECT "); + fill_common_header(request, false, &out); + out.push_back("\r\n"); + std::string req = absl::StrJoin(out, ""); + return grpc_slice_from_copied_buffer(req.data(), req.size()); +} diff --git a/src/core/lib/http/httpcli.cc b/src/core/lib/http/httpcli.cc new file mode 100644 index 00000000..ab4431e3 --- /dev/null +++ b/src/core/lib/http/httpcli.cc @@ -0,0 +1,324 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/http/httpcli.h" + +#include + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/http/format_request.h" +#include "src/core/lib/http/parser.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { +namespace { + +class InternalRequest { + public: + InternalRequest(const grpc_slice& request_text, + grpc_httpcli_response* response, + grpc_resource_quota* resource_quota, absl::string_view host, + absl::string_view ssl_host_override, grpc_millis deadline, + const grpc_httpcli_handshaker* handshaker, + grpc_closure* on_done, grpc_httpcli_context* context, + grpc_polling_entity* pollent, const char* name) + : request_text_(request_text), + resource_quota_(resource_quota), + host_(host), + ssl_host_override_(ssl_host_override), + deadline_(deadline), + handshaker_(handshaker), + on_done_(on_done), + context_(context), + pollent_(pollent) { + grpc_http_parser_init(&parser_, GRPC_HTTP_RESPONSE, response); + grpc_slice_buffer_init(&incoming_); + grpc_slice_buffer_init(&outgoing_); + grpc_iomgr_register_object(&iomgr_obj_, name); + + GRPC_CLOSURE_INIT(&on_read_, OnRead, this, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&done_write_, DoneWrite, this, grpc_schedule_on_exec_ctx); + GPR_ASSERT(pollent); + grpc_polling_entity_add_to_pollset_set(pollent_, context->pollset_set); + grpc_resolve_address( + host_.c_str(), handshaker_->default_port, context_->pollset_set, + GRPC_CLOSURE_CREATE(OnResolved, this, grpc_schedule_on_exec_ctx), + &addresses_); + } + + ~InternalRequest() { + grpc_http_parser_destroy(&parser_); + if (addresses_ != nullptr) { + grpc_resolved_addresses_destroy(addresses_); + } + if (ep_ != nullptr) { + grpc_endpoint_destroy(ep_); + } + grpc_slice_unref_internal(request_text_); + grpc_iomgr_unregister_object(&iomgr_obj_); + grpc_slice_buffer_destroy_internal(&incoming_); + grpc_slice_buffer_destroy_internal(&outgoing_); + GRPC_ERROR_UNREF(overall_error_); + grpc_resource_quota_unref_internal(resource_quota_); + } + + private: + void Finish(grpc_error_handle error) { + grpc_polling_entity_del_from_pollset_set(pollent_, context_->pollset_set); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done_, error); + delete this; + } + + void AppendError(grpc_error_handle error) { + if (overall_error_ == GRPC_ERROR_NONE) { + overall_error_ = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed HTTP/1 client request"); + } + grpc_resolved_address* addr = &addresses_->addrs[next_address_ - 1]; + std::string addr_text = grpc_sockaddr_to_uri(addr); + overall_error_ = grpc_error_add_child( + overall_error_, + grpc_error_set_str(error, GRPC_ERROR_STR_TARGET_ADDRESS, addr_text)); + } + + void DoRead() { + grpc_endpoint_read(ep_, &incoming_, &on_read_, /*urgent=*/true); + } + + static void OnRead(void* user_data, grpc_error_handle error) { + InternalRequest* req = static_cast(user_data); + req->OnReadInternal(error); + } + + void OnReadInternal(grpc_error_handle error) { + size_t i; + + for (i = 0; i < incoming_.count; i++) { + if (GRPC_SLICE_LENGTH(incoming_.slices[i])) { + have_read_byte_ = 1; + grpc_error_handle err = + 
grpc_http_parser_parse(&parser_, incoming_.slices[i], nullptr); + if (err != GRPC_ERROR_NONE) { + Finish(err); + return; + } + } + } + + if (error == GRPC_ERROR_NONE) { + DoRead(); + } else if (!have_read_byte_) { + NextAddress(GRPC_ERROR_REF(error)); + } else { + Finish(grpc_http_parser_eof(&parser_)); + } + } + + void OnWritten() { DoRead(); } + + static void DoneWrite(void* arg, grpc_error_handle error) { + InternalRequest* req = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + req->OnWritten(); + } else { + req->NextAddress(GRPC_ERROR_REF(error)); + } + } + + void StartWrite() { + grpc_slice_ref_internal(request_text_); + grpc_slice_buffer_add(&outgoing_, request_text_); + grpc_endpoint_write(ep_, &outgoing_, &done_write_, nullptr); + } + + static void OnHandshakeDone(void* arg, grpc_endpoint* ep) { + InternalRequest* req = static_cast(arg); + + if (!ep) { + req->NextAddress(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unexplained handshake failure")); + return; + } + + req->ep_ = ep; + req->StartWrite(); + } + + static void OnConnected(void* arg, grpc_error_handle error) { + InternalRequest* req = static_cast(arg); + + if (!req->ep_) { + req->NextAddress(GRPC_ERROR_REF(error)); + return; + } + req->handshaker_->handshake(req, req->ep_, + req->ssl_host_override_.empty() + ? req->host_.c_str() + : req->ssl_host_override_.c_str(), + req->deadline_, OnHandshakeDone); + } + + void NextAddress(grpc_error_handle error) { + grpc_resolved_address* addr; + if (error != GRPC_ERROR_NONE) { + AppendError(error); + } + if (next_address_ == addresses_->naddrs) { + Finish(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed HTTP requests to all targets", &overall_error_, 1)); + return; + } + addr = &addresses_->addrs[next_address_++]; + GRPC_CLOSURE_INIT(&connected_, OnConnected, this, + grpc_schedule_on_exec_ctx); + grpc_tcp_client_connect(&connected_, &ep_, + grpc_slice_allocator_create( + resource_quota_, grpc_sockaddr_to_uri(addr)), + context_->pollset_set, nullptr, addr, deadline_); + } + + static void OnResolved(void* arg, grpc_error_handle error) { + InternalRequest* req = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + req->Finish(GRPC_ERROR_REF(error)); + return; + } + req->next_address_ = 0; + req->NextAddress(GRPC_ERROR_NONE); + } + + grpc_slice request_text_; + grpc_http_parser parser_; + grpc_resolved_addresses* addresses_ = nullptr; + size_t next_address_ = 0; + grpc_endpoint* ep_ = nullptr; + grpc_resource_quota* resource_quota_; + std::string host_; + std::string ssl_host_override_; + grpc_millis deadline_; + int have_read_byte_ = 0; + const grpc_httpcli_handshaker* handshaker_; + grpc_closure* on_done_; + grpc_httpcli_context* context_; + grpc_polling_entity* pollent_; + grpc_iomgr_object iomgr_obj_; + grpc_slice_buffer incoming_; + grpc_slice_buffer outgoing_; + grpc_closure on_read_; + grpc_closure done_write_; + grpc_closure connected_; + grpc_error_handle overall_error_ = GRPC_ERROR_NONE; +}; + +} // namespace +} // namespace grpc_core + +static grpc_httpcli_get_override g_get_override = nullptr; +static grpc_httpcli_post_override g_post_override = nullptr; + +static void plaintext_handshake(void* arg, grpc_endpoint* endpoint, + const char* /*host*/, grpc_millis /*deadline*/, + void (*on_done)(void* arg, + grpc_endpoint* endpoint)) { + on_done(arg, endpoint); +} + +const grpc_httpcli_handshaker grpc_httpcli_plaintext = {"http", + plaintext_handshake}; + +void grpc_httpcli_context_init(grpc_httpcli_context* context) { + context->pollset_set = 
grpc_pollset_set_create(); +} + +void grpc_httpcli_context_destroy(grpc_httpcli_context* context) { + grpc_pollset_set_destroy(context->pollset_set); +} + +static void internal_request_begin(grpc_httpcli_context* context, + grpc_polling_entity* pollent, + grpc_resource_quota* resource_quota, + const grpc_httpcli_request* request, + grpc_millis deadline, grpc_closure* on_done, + grpc_httpcli_response* response, + const char* name, + const grpc_slice& request_text) { + new grpc_core::InternalRequest( + request_text, response, resource_quota, request->host, + request->ssl_host_override, deadline, + request->handshaker ? request->handshaker : &grpc_httpcli_plaintext, + on_done, context, pollent, name); +} + +void grpc_httpcli_get(grpc_httpcli_context* context, + grpc_polling_entity* pollent, + grpc_resource_quota* resource_quota, + const grpc_httpcli_request* request, grpc_millis deadline, + grpc_closure* on_done, grpc_httpcli_response* response) { + if (g_get_override && g_get_override(request, deadline, on_done, response)) { + grpc_resource_quota_unref_internal(resource_quota); + return; + } + std::string name = + absl::StrFormat("HTTP:GET:%s:%s", request->host, request->http.path); + internal_request_begin(context, pollent, resource_quota, request, deadline, + on_done, response, name.c_str(), + grpc_httpcli_format_get_request(request)); +} + +void grpc_httpcli_post(grpc_httpcli_context* context, + grpc_polling_entity* pollent, + grpc_resource_quota* resource_quota, + const grpc_httpcli_request* request, + const char* body_bytes, size_t body_size, + grpc_millis deadline, grpc_closure* on_done, + grpc_httpcli_response* response) { + if (g_post_override && g_post_override(request, body_bytes, body_size, + deadline, on_done, response)) { + grpc_resource_quota_unref_internal(resource_quota); + return; + } + std::string name = + absl::StrFormat("HTTP:POST:%s:%s", request->host, request->http.path); + internal_request_begin( + context, pollent, resource_quota, request, deadline, on_done, response, + name.c_str(), + grpc_httpcli_format_post_request(request, body_bytes, body_size)); +} + +void grpc_httpcli_set_override(grpc_httpcli_get_override get, + grpc_httpcli_post_override post) { + g_get_override = get; + g_post_override = post; +} diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc new file mode 100644 index 00000000..1efcf20d --- /dev/null +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -0,0 +1,215 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/ssl_utils.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/ssl_transport_security.h" + +class grpc_httpcli_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + explicit grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name) + : grpc_channel_security_connector( + /*url_scheme=*/nullptr, + /*channel_creds=*/nullptr, + /*request_metadata_creds=*/nullptr), + secure_peer_name_(secure_peer_name) {} + + ~grpc_httpcli_ssl_channel_security_connector() override { + if (handshaker_factory_ != nullptr) { + tsi_ssl_client_handshaker_factory_unref(handshaker_factory_); + } + if (secure_peer_name_ != nullptr) { + gpr_free(secure_peer_name_); + } + } + + tsi_result InitHandshakerFactory(const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store) { + tsi_ssl_client_handshaker_options options; + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + return tsi_create_ssl_client_handshaker_factory_with_options( + &options, &handshaker_factory_); + } + + void add_handshakers(const grpc_channel_args* args, + grpc_pollset_set* /*interested_parties*/, + grpc_core::HandshakeManager* handshake_mgr) override { + tsi_handshaker* handshaker = nullptr; + if (handshaker_factory_ != nullptr) { + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + handshaker_factory_, secure_peer_name_, &handshaker); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + } + } + handshake_mgr->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this, args)); + } + + tsi_ssl_client_handshaker_factory* handshaker_factory() const { + return handshaker_factory_; + } + + void check_peer(tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* /*auth_context*/, + grpc_closure* on_peer_checked) override { + grpc_error_handle error = GRPC_ERROR_NONE; + + /* Check the peer name. 
*/ + if (secure_peer_name_ != nullptr && + !tsi_ssl_peer_matches_name(&peer, secure_peer_name_)) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Peer name ", secure_peer_name_, " is not in peer certificate")); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + tsi_peer_destruct(&peer); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast( + other_sc); + return strcmp(secure_peer_name_, other->secure_peer_name_); + } + + bool check_call_host(absl::string_view /*host*/, + grpc_auth_context* /*auth_context*/, + grpc_closure* /*on_call_host_checked*/, + grpc_error_handle* error) override { + *error = GRPC_ERROR_NONE; + return true; + } + + void cancel_check_call_host(grpc_closure* /*on_call_host_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + const char* secure_peer_name() const { return secure_peer_name_; } + + private: + tsi_ssl_client_handshaker_factory* handshaker_factory_ = nullptr; + char* secure_peer_name_; +}; + +static grpc_core::RefCountedPtr +httpcli_ssl_channel_security_connector_create( + const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, + const char* secure_peer_name, grpc_channel_args* /*channel_args*/) { + if (secure_peer_name != nullptr && pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, + "Cannot assert a secure peer name without a trust root."); + return nullptr; + } + grpc_core::RefCountedPtr c = + grpc_core::MakeRefCounted( + secure_peer_name == nullptr ? nullptr : gpr_strdup(secure_peer_name)); + tsi_result result = c->InitHandshakerFactory(pem_root_certs, root_store); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return nullptr; + } + return c; +} + +/* handshaker */ + +struct on_done_closure { + void (*func)(void* arg, grpc_endpoint* endpoint); + void* arg; + grpc_core::RefCountedPtr handshake_mgr; +}; +static void on_handshake_done(void* arg, grpc_error_handle error) { + auto* args = static_cast(arg); + on_done_closure* c = static_cast(args->user_data); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Secure transport setup failed: %s", + grpc_error_std_string(error).c_str()); + c->func(c->arg, nullptr); + } else { + grpc_channel_args_destroy(args->args); + grpc_slice_buffer_destroy_internal(args->read_buffer); + gpr_free(args->read_buffer); + c->func(c->arg, args->endpoint); + } + delete c; +} + +static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, + grpc_millis deadline, + void (*on_done)(void* arg, grpc_endpoint* endpoint)) { + auto* c = new on_done_closure(); + const char* pem_root_certs = + grpc_core::DefaultSslRootStore::GetPemRootCerts(); + const tsi_ssl_root_certs_store* root_store = + grpc_core::DefaultSslRootStore::GetRootStore(); + if (root_store == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + on_done(arg, nullptr); + gpr_free(c); + return; + } + c->func = on_done; + c->arg = arg; + grpc_core::RefCountedPtr sc = + httpcli_ssl_channel_security_connector_create( + pem_root_certs, root_store, host, + static_cast(arg)->args); + + GPR_ASSERT(sc != nullptr); + grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get()); + grpc_channel_args args = {1, &channel_arg}; + c->handshake_mgr = grpc_core::MakeRefCounted(); + 
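+  // The registry call below contributes the client handshakers configured for
+  // `args` (including the security handshaker derived from the connector arg
+  // added above); DoHandshake() then runs them over the raw TCP endpoint, and
+  // on_handshake_done() hands the resulting secure endpoint back to the HTTP
+  // client through c->func.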
grpc_core::CoreConfiguration::Get().handshaker_registry().AddHandshakers( + grpc_core::HANDSHAKER_CLIENT, &args, + /*interested_parties=*/nullptr, c->handshake_mgr.get()); + c->handshake_mgr->DoHandshake(tcp, /*channel_args=*/nullptr, deadline, + /*acceptor=*/nullptr, on_handshake_done, + /*user_data=*/c); + sc.reset(DEBUG_LOCATION, "httpcli"); +} + +const grpc_httpcli_handshaker grpc_httpcli_ssl = {"https", ssl_handshake}; diff --git a/src/core/lib/http/parser.cc b/src/core/lib/http/parser.cc new file mode 100644 index 00000000..d88ebdb4 --- /dev/null +++ b/src/core/lib/http/parser.cc @@ -0,0 +1,392 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/http/parser.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" + +grpc_core::TraceFlag grpc_http1_trace(false, "http1"); + +static char* buf2str(void* buffer, size_t length) { + char* out = static_cast(gpr_malloc(length + 1)); + memcpy(out, buffer, length); + out[length] = 0; + return out; +} + +static grpc_error_handle handle_response_line(grpc_http_parser* parser) { + uint8_t* beg = parser->cur_line; + uint8_t* cur = beg; + uint8_t* end = beg + parser->cur_line_length; + + if (cur == end || *cur++ != 'H') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'H'"); + } + if (cur == end || *cur++ != 'T') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'T'"); + } + if (cur == end || *cur++ != 'T') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'T'"); + } + if (cur == end || *cur++ != 'P') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'P'"); + } + if (cur == end || *cur++ != '/') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected '/'"); + } + if (cur == end || *cur++ != '1') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected '1'"); + } + if (cur == end || *cur++ != '.') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected '.'"); + } + if (cur == end || *cur < '0' || *cur++ > '1') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Expected HTTP/1.0 or HTTP/1.1"); + } + if (cur == end || *cur++ != ' ') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected ' '"); + } + if (cur == end || *cur < '1' || *cur++ > '9') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected status code"); + } + if (cur == end || *cur < '0' || *cur++ > '9') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected status code"); + } + if (cur == end || *cur < '0' || *cur++ > '9') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected status code"); + } + parser->http.response->status = + (cur[-3] - '0') * 100 + (cur[-2] - '0') * 10 + (cur[-1] - '0'); + if (cur == end || *cur++ != ' ') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected ' '"); + } + + /* we don't really care about the status code message */ + + return GRPC_ERROR_NONE; +} + +static grpc_error_handle handle_request_line(grpc_http_parser* parser) { + uint8_t* beg = parser->cur_line; + uint8_t* cur = beg; + 
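+  // cur_line holds one complete request line, e.g. "GET /index.html HTTP/1.1"
+  // plus its terminator; the code below peels off the method, then the path,
+  // then the version.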
uint8_t* end = beg + parser->cur_line_length; + uint8_t vers_major = 0; + uint8_t vers_minor = 0; + + while (cur != end && *cur++ != ' ') { + } + if (cur == end) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "No method on HTTP request line"); + } + parser->http.request->method = + buf2str(beg, static_cast(cur - beg - 1)); + + beg = cur; + while (cur != end && *cur++ != ' ') { + } + if (cur == end) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("No path on HTTP request line"); + } + parser->http.request->path = buf2str(beg, static_cast(cur - beg - 1)); + + if (cur == end || *cur++ != 'H') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'H'"); + } + if (cur == end || *cur++ != 'T') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'T'"); + } + if (cur == end || *cur++ != 'T') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'T'"); + } + if (cur == end || *cur++ != 'P') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected 'P'"); + } + if (cur == end || *cur++ != '/') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Expected '/'"); + } + vers_major = static_cast(*cur++ - '1' + 1); + ++cur; + if (cur == end) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "End of line in HTTP version string"); + } + vers_minor = static_cast(*cur++ - '1' + 1); + + if (vers_major == 1) { + if (vers_minor == 0) { + parser->http.request->version = GRPC_HTTP_HTTP10; + } else if (vers_minor == 1) { + parser->http.request->version = GRPC_HTTP_HTTP11; + } else { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Expected one of HTTP/1.0, HTTP/1.1, or HTTP/2.0"); + } + } else if (vers_major == 2) { + if (vers_minor == 0) { + parser->http.request->version = GRPC_HTTP_HTTP20; + } else { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Expected one of HTTP/1.0, HTTP/1.1, or HTTP/2.0"); + } + } else { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Expected one of HTTP/1.0, HTTP/1.1, or HTTP/2.0"); + } + + return GRPC_ERROR_NONE; +} + +static grpc_error_handle handle_first_line(grpc_http_parser* parser) { + switch (parser->type) { + case GRPC_HTTP_REQUEST: + return handle_request_line(parser); + case GRPC_HTTP_RESPONSE: + return handle_response_line(parser); + } + GPR_UNREACHABLE_CODE( + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Should never reach here")); +} + +static grpc_error_handle add_header(grpc_http_parser* parser) { + uint8_t* beg = parser->cur_line; + uint8_t* cur = beg; + uint8_t* end = beg + parser->cur_line_length; + size_t* hdr_count = nullptr; + grpc_http_header** hdrs = nullptr; + grpc_http_header hdr = {nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + + GPR_ASSERT(cur != end); + + if (*cur == ' ' || *cur == '\t') { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Continued header lines not supported yet"); + goto done; + } + + while (cur != end && *cur != ':') { + cur++; + } + if (cur == end) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Didn't find ':' in header string"); + goto done; + } + GPR_ASSERT(cur >= beg); + hdr.key = buf2str(beg, static_cast(cur - beg)); + cur++; /* skip : */ + + while (cur != end && (*cur == ' ' || *cur == '\t')) { + cur++; + } + GPR_ASSERT((size_t)(end - cur) >= parser->cur_line_end_length); + hdr.value = buf2str( + cur, static_cast(end - cur) - parser->cur_line_end_length); + + switch (parser->type) { + case GRPC_HTTP_RESPONSE: + hdr_count = &parser->http.response->hdr_count; + hdrs = &parser->http.response->hdrs; + break; + case GRPC_HTTP_REQUEST: + hdr_count = &parser->http.request->hdr_count; + hdrs = 
&parser->http.request->hdrs; + break; + } + + if (*hdr_count == parser->hdr_capacity) { + parser->hdr_capacity = + std::max(parser->hdr_capacity + 1, parser->hdr_capacity * 3 / 2); + *hdrs = static_cast( + gpr_realloc(*hdrs, parser->hdr_capacity * sizeof(**hdrs))); + } + (*hdrs)[(*hdr_count)++] = hdr; + +done: + if (error != GRPC_ERROR_NONE) { + gpr_free(hdr.key); + gpr_free(hdr.value); + } + return error; +} + +static grpc_error_handle finish_line(grpc_http_parser* parser, + bool* found_body_start) { + grpc_error_handle err; + switch (parser->state) { + case GRPC_HTTP_FIRST_LINE: + err = handle_first_line(parser); + if (err != GRPC_ERROR_NONE) return err; + parser->state = GRPC_HTTP_HEADERS; + break; + case GRPC_HTTP_HEADERS: + if (parser->cur_line_length == parser->cur_line_end_length) { + parser->state = GRPC_HTTP_BODY; + *found_body_start = true; + break; + } + err = add_header(parser); + if (err != GRPC_ERROR_NONE) { + return err; + } + break; + case GRPC_HTTP_BODY: + GPR_UNREACHABLE_CODE(return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Should never reach here")); + } + + parser->cur_line_length = 0; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle addbyte_body(grpc_http_parser* parser, uint8_t byte) { + size_t* body_length = nullptr; + char** body = nullptr; + + if (parser->type == GRPC_HTTP_RESPONSE) { + body_length = &parser->http.response->body_length; + body = &parser->http.response->body; + } else if (parser->type == GRPC_HTTP_REQUEST) { + body_length = &parser->http.request->body_length; + body = &parser->http.request->body; + } else { + GPR_UNREACHABLE_CODE( + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Should never reach here")); + } + + if (*body_length == parser->body_capacity) { + parser->body_capacity = std::max(size_t(8), parser->body_capacity * 3 / 2); + *body = static_cast(gpr_realloc(*body, parser->body_capacity)); + } + (*body)[*body_length] = static_cast(byte); + (*body_length)++; + + return GRPC_ERROR_NONE; +} + +static bool check_line(grpc_http_parser* parser) { + if (parser->cur_line_length >= 2 && + parser->cur_line[parser->cur_line_length - 2] == '\r' && + parser->cur_line[parser->cur_line_length - 1] == '\n') { + return true; + } + + // HTTP request with \n\r line termiantors. + else if (parser->cur_line_length >= 2 && + parser->cur_line[parser->cur_line_length - 2] == '\n' && + parser->cur_line[parser->cur_line_length - 1] == '\r') { + return true; + } + + // HTTP request with only \n line terminators. 
+ else if (parser->cur_line_length >= 1 && + parser->cur_line[parser->cur_line_length - 1] == '\n') { + parser->cur_line_end_length = 1; + return true; + } + + return false; +} + +static grpc_error_handle addbyte(grpc_http_parser* parser, uint8_t byte, + bool* found_body_start) { + switch (parser->state) { + case GRPC_HTTP_FIRST_LINE: + case GRPC_HTTP_HEADERS: + if (parser->cur_line_length >= GRPC_HTTP_PARSER_MAX_HEADER_LENGTH) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_http1_trace)) { + gpr_log(GPR_ERROR, "HTTP header max line length (%d) exceeded", + GRPC_HTTP_PARSER_MAX_HEADER_LENGTH); + } + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "HTTP header max line length exceeded"); + } + parser->cur_line[parser->cur_line_length] = byte; + parser->cur_line_length++; + if (check_line(parser)) { + return finish_line(parser, found_body_start); + } + return GRPC_ERROR_NONE; + case GRPC_HTTP_BODY: + return addbyte_body(parser, byte); + } + GPR_UNREACHABLE_CODE(return GRPC_ERROR_NONE); +} + +void grpc_http_parser_init(grpc_http_parser* parser, grpc_http_type type, + void* request_or_response) { + memset(parser, 0, sizeof(*parser)); + parser->state = GRPC_HTTP_FIRST_LINE; + parser->type = type; + parser->http.request_or_response = request_or_response; + parser->cur_line_end_length = 2; +} + +void grpc_http_parser_destroy(grpc_http_parser* /*parser*/) {} + +void grpc_http_request_destroy(grpc_http_request* request) { + size_t i; + gpr_free(request->body); + for (i = 0; i < request->hdr_count; i++) { + gpr_free(request->hdrs[i].key); + gpr_free(request->hdrs[i].value); + } + gpr_free(request->hdrs); + gpr_free(request->method); + gpr_free(request->path); +} + +void grpc_http_response_destroy(grpc_http_response* response) { + size_t i; + gpr_free(response->body); + for (i = 0; i < response->hdr_count; i++) { + gpr_free(response->hdrs[i].key); + gpr_free(response->hdrs[i].value); + } + gpr_free(response->hdrs); +} + +grpc_error_handle grpc_http_parser_parse(grpc_http_parser* parser, + const grpc_slice& slice, + size_t* start_of_body) { + for (size_t i = 0; i < GRPC_SLICE_LENGTH(slice); i++) { + bool found_body_start = false; + grpc_error_handle err = + addbyte(parser, GRPC_SLICE_START_PTR(slice)[i], &found_body_start); + if (err != GRPC_ERROR_NONE) return err; + if (found_body_start && start_of_body != nullptr) *start_of_body = i + 1; + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_http_parser_eof(grpc_http_parser* parser) { + if (parser->state != GRPC_HTTP_BODY) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Did not finish headers"); + } + return GRPC_ERROR_NONE; +} diff --git a/src/core/lib/iomgr/buffer_list.cc b/src/core/lib/iomgr/buffer_list.cc new file mode 100644 index 00000000..0ea6af25 --- /dev/null +++ b/src/core/lib/iomgr/buffer_list.cc @@ -0,0 +1,307 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
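A minimal usage sketch of the HTTP/1 parser added in src/core/lib/http/parser.cc above. It is illustrative only and is not part of the patch: it relies on the declarations from src/core/lib/http/parser.h (grpc_http_parser, grpc_http_response, GRPC_HTTP_RESPONSE) that the implementation shown here defines against, and the canned input string and wrapper function name are assumptions.

    // Illustrative only: parse a canned HTTP/1.1 response with the parser above.
    void ParseCannedResponse() {
      grpc_http_response response;
      memset(&response, 0, sizeof(response));
      grpc_http_parser parser;
      grpc_http_parser_init(&parser, GRPC_HTTP_RESPONSE, &response);

      grpc_slice input = grpc_slice_from_static_string(
          "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nhello");
      size_t start_of_body = 0;
      grpc_error_handle err =
          grpc_http_parser_parse(&parser, input, &start_of_body);
      GPR_ASSERT(err == GRPC_ERROR_NONE);
      GPR_ASSERT(response.status == 200);     // set by handle_response_line()
      GPR_ASSERT(response.hdr_count == 1);    // "Content-Type"
      GPR_ASSERT(response.body_length == 5);  // "hello": bytes after the blank line

      grpc_slice_unref(input);
      grpc_http_parser_destroy(&parser);
      grpc_http_response_destroy(&response);
    }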
+ * + */ + +#include + +#include "src/core/lib/iomgr/buffer_list.h" + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_LINUX_ERRQUEUE +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" + +namespace grpc_core { +namespace { +/** Fills gpr_timespec gts based on values from timespec ts */ +void fill_gpr_from_timestamp(gpr_timespec* gts, const struct timespec* ts) { + gts->tv_sec = ts->tv_sec; + gts->tv_nsec = static_cast(ts->tv_nsec); + gts->clock_type = GPR_CLOCK_REALTIME; +} + +void default_timestamps_callback(void* /*arg*/, grpc_core::Timestamps* /*ts*/, + grpc_error_handle /*shudown_err*/) { + gpr_log(GPR_DEBUG, "Timestamps callback has not been registered"); +} + +/** The saved callback function that will be invoked when we get all the + * timestamps that we are going to get for a TracedBuffer. */ +void (*timestamps_callback)(void*, grpc_core::Timestamps*, + grpc_error_handle shutdown_err) = + default_timestamps_callback; + +/* Used to extract individual opt stats from cmsg, so as to avoid troubles with + * unaligned reads */ +template +T read_unaligned(const void* ptr) { + T val; + memcpy(&val, ptr, sizeof(val)); + return val; +} + +/* Extracts opt stats from the tcp_info struct \a info to \a metrics */ +void extract_opt_stats_from_tcp_info(ConnectionMetrics* metrics, + const grpc_core::tcp_info* info) { + if (info == nullptr) { + return; + } + if (info->length > offsetof(grpc_core::tcp_info, tcpi_sndbuf_limited)) { + metrics->recurring_retrans.emplace(info->tcpi_retransmits); + metrics->is_delivery_rate_app_limited.emplace( + info->tcpi_delivery_rate_app_limited); + metrics->congestion_window.emplace(info->tcpi_snd_cwnd); + metrics->reordering.emplace(info->tcpi_reordering); + metrics->packet_retx.emplace(info->tcpi_total_retrans); + metrics->pacing_rate.emplace(info->tcpi_pacing_rate); + metrics->data_notsent.emplace(info->tcpi_notsent_bytes); + if (info->tcpi_min_rtt != UINT32_MAX) { + metrics->min_rtt.emplace(info->tcpi_min_rtt); + } + metrics->packet_sent.emplace(info->tcpi_data_segs_out); + metrics->delivery_rate.emplace(info->tcpi_delivery_rate); + metrics->busy_usec.emplace(info->tcpi_busy_time); + metrics->rwnd_limited_usec.emplace(info->tcpi_rwnd_limited); + metrics->sndbuf_limited_usec.emplace(info->tcpi_sndbuf_limited); + } + if (info->length > offsetof(grpc_core::tcp_info, tcpi_dsack_dups)) { + metrics->data_sent.emplace(info->tcpi_bytes_sent); + metrics->data_retx.emplace(info->tcpi_bytes_retrans); + metrics->packet_spurious_retx.emplace(info->tcpi_dsack_dups); + } +} + +/** Extracts opt stats from the given control message \a opt_stats to the + * connection metrics \a metrics */ +void extract_opt_stats_from_cmsg(ConnectionMetrics* metrics, + const cmsghdr* opt_stats) { + if (opt_stats == nullptr) { + return; + } + const auto* data = CMSG_DATA(opt_stats); + constexpr int64_t cmsg_hdr_len = CMSG_ALIGN(sizeof(struct cmsghdr)); + const int64_t len = opt_stats->cmsg_len - cmsg_hdr_len; + int64_t offset = 0; + + while (offset < len) { + const auto* attr = reinterpret_cast(data + offset); + const void* val = data + offset + NLA_HDRLEN; + switch (attr->nla_type) { + case TCP_NLA_BUSY: { + metrics->busy_usec.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_RWND_LIMITED: { + metrics->rwnd_limited_usec.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_SNDBUF_LIMITED: { + metrics->sndbuf_limited_usec.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_PACING_RATE: { + metrics->pacing_rate.emplace(read_unaligned(val)); + 
break; + } + case TCP_NLA_DELIVERY_RATE: { + metrics->delivery_rate.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_DELIVERY_RATE_APP_LMT: { + metrics->is_delivery_rate_app_limited.emplace( + read_unaligned(val)); + break; + } + case TCP_NLA_SND_CWND: { + metrics->congestion_window.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_MIN_RTT: { + metrics->min_rtt.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_SRTT: { + metrics->srtt.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_RECUR_RETRANS: { + metrics->recurring_retrans.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_BYTES_SENT: { + metrics->data_sent.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_DATA_SEGS_OUT: { + metrics->packet_sent.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_TOTAL_RETRANS: { + metrics->packet_retx.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_DELIVERED: { + metrics->packet_delivered.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_DELIVERED_CE: { + metrics->packet_delivered_ce.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_BYTES_RETRANS: { + metrics->data_retx.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_DSACK_DUPS: { + metrics->packet_spurious_retx.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_REORDERING: { + metrics->reordering.emplace(read_unaligned(val)); + break; + } + case TCP_NLA_SND_SSTHRESH: { + metrics->snd_ssthresh.emplace(read_unaligned(val)); + break; + } + } + offset += NLA_ALIGN(attr->nla_len); + } +} + +static int get_socket_tcp_info(grpc_core::tcp_info* info, int fd) { + memset(info, 0, sizeof(*info)); + info->length = offsetof(grpc_core::tcp_info, length); + return getsockopt(fd, IPPROTO_TCP, TCP_INFO, info, &(info->length)); +} +} /* namespace */ + +void TracedBuffer::AddNewEntry(TracedBuffer** head, uint32_t seq_no, int fd, + void* arg) { + GPR_DEBUG_ASSERT(head != nullptr); + TracedBuffer* new_elem = new TracedBuffer(seq_no, arg); + /* Store the current time as the sendmsg time. */ + new_elem->ts_.sendmsg_time.time = gpr_now(GPR_CLOCK_REALTIME); + new_elem->ts_.scheduled_time.time = gpr_inf_past(GPR_CLOCK_REALTIME); + new_elem->ts_.sent_time.time = gpr_inf_past(GPR_CLOCK_REALTIME); + new_elem->ts_.acked_time.time = gpr_inf_past(GPR_CLOCK_REALTIME); + + if (get_socket_tcp_info(&new_elem->ts_.info, fd) == 0) { + extract_opt_stats_from_tcp_info(&new_elem->ts_.sendmsg_time.metrics, + &new_elem->ts_.info); + } + if (*head == nullptr) { + *head = new_elem; + return; + } + /* Append at the end. */ + TracedBuffer* ptr = *head; + while (ptr->next_ != nullptr) { + ptr = ptr->next_; + } + ptr->next_ = new_elem; +} + +void TracedBuffer::ProcessTimestamp(TracedBuffer** head, + struct sock_extended_err* serr, + struct cmsghdr* opt_stats, + struct scm_timestamping* tss) { + GPR_DEBUG_ASSERT(head != nullptr); + TracedBuffer* elem = *head; + TracedBuffer* next = nullptr; + while (elem != nullptr) { + /* The byte number refers to the sequence number of the last byte which this + * timestamp relates to. 
*/ + if (serr->ee_data >= elem->seq_no_) { + switch (serr->ee_info) { + case SCM_TSTAMP_SCHED: + fill_gpr_from_timestamp(&(elem->ts_.scheduled_time.time), + &(tss->ts[0])); + extract_opt_stats_from_cmsg(&(elem->ts_.scheduled_time.metrics), + opt_stats); + elem = elem->next_; + break; + case SCM_TSTAMP_SND: + fill_gpr_from_timestamp(&(elem->ts_.sent_time.time), &(tss->ts[0])); + extract_opt_stats_from_cmsg(&(elem->ts_.sent_time.metrics), + opt_stats); + elem = elem->next_; + break; + case SCM_TSTAMP_ACK: + fill_gpr_from_timestamp(&(elem->ts_.acked_time.time), &(tss->ts[0])); + extract_opt_stats_from_cmsg(&(elem->ts_.acked_time.metrics), + opt_stats); + /* Got all timestamps. Do the callback and free this TracedBuffer. + * The thing below can be passed by value if we don't want the + * restriction on the lifetime. */ + timestamps_callback(elem->arg_, &(elem->ts_), GRPC_ERROR_NONE); + next = elem->next_; + delete static_cast(elem); + *head = elem = next; + break; + default: + abort(); + } + } else { + break; + } + } +} + +void TracedBuffer::Shutdown(TracedBuffer** head, void* remaining, + grpc_error_handle shutdown_err) { + GPR_DEBUG_ASSERT(head != nullptr); + TracedBuffer* elem = *head; + while (elem != nullptr) { + timestamps_callback(elem->arg_, &(elem->ts_), shutdown_err); + auto* next = elem->next_; + delete elem; + elem = next; + } + *head = nullptr; + if (remaining != nullptr) { + timestamps_callback(remaining, nullptr, shutdown_err); + } + GRPC_ERROR_UNREF(shutdown_err); +} + +void grpc_tcp_set_write_timestamps_callback( + void (*fn)(void*, grpc_core::Timestamps*, grpc_error_handle error)) { + timestamps_callback = fn; +} +} /* namespace grpc_core */ + +#else /* GRPC_LINUX_ERRQUEUE */ + +namespace grpc_core { +void grpc_tcp_set_write_timestamps_callback( + void (*fn)(void*, grpc_core::Timestamps*, grpc_error_handle error)) { + // Cast value of fn to void to avoid unused parameter warning. + // Can't comment out the name because some compilers and formatters don't + // like the sequence */* , which would arise from */*fn*/. + (void)fn; + gpr_log(GPR_DEBUG, "Timestamps callback is not enabled for this platform"); +} +} /* namespace grpc_core */ + +#endif /* GRPC_LINUX_ERRQUEUE */ diff --git a/src/core/lib/iomgr/call_combiner.cc b/src/core/lib/iomgr/call_combiner.cc new file mode 100644 index 00000000..7137bcaa --- /dev/null +++ b/src/core/lib/iomgr/call_combiner.cc @@ -0,0 +1,281 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
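The buffer_list.cc code above records kernel TX timestamps (scheduled/sent/acked) plus optional TCP stats for each traced write and reports them through a single process-wide callback. Below is a hedged sketch of registering such a callback; the registration function and the Timestamps layout come from the code above, while the callback body and names are assumptions.

    // Illustrative only: register a consumer for per-write timestamps.
    static void OnWriteTimestamps(void* /*arg*/, grpc_core::Timestamps* ts,
                                  grpc_error_handle error) {
      if (error != GRPC_ERROR_NONE || ts == nullptr) return;  // shutdown path
      gpr_log(GPR_DEBUG, "write acked at %" PRId64 " sec",
              ts->acked_time.time.tv_sec);
    }

    void InstallTimestampsCallback() {
      // After this, TracedBuffer::ProcessTimestamp() invokes OnWriteTimestamps
      // once SCM_TSTAMP_ACK arrives for a traced buffer.
      grpc_core::grpc_tcp_set_write_timestamps_callback(OnWriteTimestamps);
    }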
+ * + */ + +#include + +#include "src/core/lib/iomgr/call_combiner.h" + +#include + +#include + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/profiling/timers.h" + +namespace grpc_core { + +DebugOnlyTraceFlag grpc_call_combiner_trace(false, "call_combiner"); + +namespace { + +// grpc_error LSB can be used +constexpr static intptr_t kErrorBit = 1; + +grpc_error_handle DecodeCancelStateError(gpr_atm cancel_state) { + if (cancel_state & kErrorBit) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + return internal::StatusGetFromHeapPtr(cancel_state & ~kErrorBit); +#else + return reinterpret_cast(cancel_state & ~kErrorBit); +#endif + } + return GRPC_ERROR_NONE; +} + +} // namespace + +CallCombiner::CallCombiner() { + gpr_atm_no_barrier_store(&cancel_state_, 0); + gpr_atm_no_barrier_store(&size_, 0); +#ifdef GRPC_TSAN_ENABLED + GRPC_CLOSURE_INIT(&tsan_closure_, TsanClosure, this, + grpc_schedule_on_exec_ctx); +#endif +} + +CallCombiner::~CallCombiner() { + if (cancel_state_ & kErrorBit) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + internal::StatusFreeHeapPtr(cancel_state_ & ~kErrorBit); +#else + GRPC_ERROR_UNREF(reinterpret_cast( + cancel_state_ & ~static_cast(kErrorBit))); +#endif + } +} + +#ifdef GRPC_TSAN_ENABLED +void CallCombiner::TsanClosure(void* arg, grpc_error_handle error) { + CallCombiner* self = static_cast(arg); + // We ref-count the lock, and check if it's already taken. + // If it was taken, we should do nothing. Otherwise, we will mark it as + // locked. Note that if two different threads try to do this, only one of + // them will be able to mark the lock as acquired, while they both run their + // callbacks. In such cases (which should never happen for call_combiner), + // TSAN will correctly produce an error. + // + // TODO(soheil): This only covers the callbacks scheduled by + // CallCombiner::Start() and CallCombiner::Stop(). + // If in the future, a callback gets scheduled using other + // mechanisms, we will need to add APIs to externally lock + // call combiners. 
+ RefCountedPtr lock = self->tsan_lock_; + bool prev = false; + if (lock->taken.compare_exchange_strong(prev, true)) { + TSAN_ANNOTATE_RWLOCK_ACQUIRED(&lock->taken, true); + } else { + lock.reset(); + } + grpc_core::Closure::Run(DEBUG_LOCATION, self->original_closure_, + GRPC_ERROR_REF(error)); + if (lock != nullptr) { + TSAN_ANNOTATE_RWLOCK_RELEASED(&lock->taken, true); + bool prev = true; + GPR_ASSERT(lock->taken.compare_exchange_strong(prev, false)); + } +} +#endif + +void CallCombiner::ScheduleClosure(grpc_closure* closure, + grpc_error_handle error) { +#ifdef GRPC_TSAN_ENABLED + original_closure_ = closure; + ExecCtx::Run(DEBUG_LOCATION, &tsan_closure_, error); +#else + ExecCtx::Run(DEBUG_LOCATION, closure, error); +#endif +} + +#ifndef NDEBUG +#define DEBUG_ARGS const char *file, int line, +#define DEBUG_FMT_STR "%s:%d: " +#define DEBUG_FMT_ARGS , file, line +#else +#define DEBUG_ARGS +#define DEBUG_FMT_STR +#define DEBUG_FMT_ARGS +#endif + +void CallCombiner::Start(grpc_closure* closure, grpc_error_handle error, + DEBUG_ARGS const char* reason) { + GPR_TIMER_SCOPE("CallCombiner::Start", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, + "==> CallCombiner::Start() [%p] closure=%p [" DEBUG_FMT_STR + "%s] error=%s", + this, closure DEBUG_FMT_ARGS, reason, + grpc_error_std_string(error).c_str()); + } + size_t prev_size = + static_cast(gpr_atm_full_fetch_add(&size_, (gpr_atm)1)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " size: %" PRIdPTR " -> %" PRIdPTR, prev_size, + prev_size + 1); + } + GRPC_STATS_INC_CALL_COMBINER_LOCKS_SCHEDULED_ITEMS(); + if (prev_size == 0) { + GRPC_STATS_INC_CALL_COMBINER_LOCKS_INITIATED(); + GPR_TIMER_MARK("call_combiner_initiate", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " EXECUTING IMMEDIATELY"); + } + // Queue was empty, so execute this closure immediately. + ScheduleClosure(closure, error); + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " QUEUING"); + } + // Queue was not empty, so add closure to queue. + closure->error_data.error = error; + queue_.Push( + reinterpret_cast(closure)); + } +} + +void CallCombiner::Stop(DEBUG_ARGS const char* reason) { + GPR_TIMER_SCOPE("CallCombiner::Stop", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, "==> CallCombiner::Stop() [%p] [" DEBUG_FMT_STR "%s]", + this DEBUG_FMT_ARGS, reason); + } + size_t prev_size = + static_cast(gpr_atm_full_fetch_add(&size_, (gpr_atm)-1)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " size: %" PRIdPTR " -> %" PRIdPTR, prev_size, + prev_size - 1); + } + GPR_ASSERT(prev_size >= 1); + if (prev_size > 1) { + while (true) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " checking queue"); + } + bool empty; + grpc_closure* closure = + reinterpret_cast(queue_.PopAndCheckEnd(&empty)); + if (closure == nullptr) { + // This can happen either due to a race condition within the mpscq + // code or because of a race with Start(). 
+ if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " queue returned no result; checking again"); + } + continue; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " EXECUTING FROM QUEUE: closure=%p error=%s", + closure, + grpc_error_std_string(closure->error_data.error).c_str()); + } + ScheduleClosure(closure, closure->error_data.error); + break; + } + } else if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, " queue empty"); + } +} + +void CallCombiner::SetNotifyOnCancel(grpc_closure* closure) { + GRPC_STATS_INC_CALL_COMBINER_SET_NOTIFY_ON_CANCEL(); + while (true) { + // Decode original state. + gpr_atm original_state = gpr_atm_acq_load(&cancel_state_); + grpc_error_handle original_error = DecodeCancelStateError(original_state); + // If error is set, invoke the cancellation closure immediately. + // Otherwise, store the new closure. + if (original_error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, + "call_combiner=%p: scheduling notify_on_cancel callback=%p " + "for pre-existing cancellation", + this, closure); + } + ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_REF(original_error)); + break; + } else { + if (gpr_atm_full_cas(&cancel_state_, original_state, + reinterpret_cast(closure))) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, "call_combiner=%p: setting notify_on_cancel=%p", + this, closure); + } + // If we replaced an earlier closure, invoke the original + // closure with GRPC_ERROR_NONE. This allows callers to clean + // up any resources they may be holding for the callback. + if (original_state != 0) { + closure = reinterpret_cast(original_state); + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, + "call_combiner=%p: scheduling old cancel callback=%p", this, + closure); + } + ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + } + break; + } + } + // cas failed, try again. + } +} + +void CallCombiner::Cancel(grpc_error_handle error) { + GRPC_STATS_INC_CALL_COMBINER_CANCELLED(); +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + intptr_t status_ptr = internal::StatusAllocHeapPtr(error); + gpr_atm new_state = kErrorBit | status_ptr; +#else + gpr_atm new_state = kErrorBit | reinterpret_cast(error); +#endif + while (true) { + gpr_atm original_state = gpr_atm_acq_load(&cancel_state_); + grpc_error_handle original_error = DecodeCancelStateError(original_state); + if (original_error != GRPC_ERROR_NONE) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + internal::StatusFreeHeapPtr(status_ptr); +#else + GRPC_ERROR_UNREF(error); +#endif + break; + } + if (gpr_atm_full_cas(&cancel_state_, original_state, new_state)) { + if (original_state != 0) { + grpc_closure* notify_on_cancel = + reinterpret_cast(original_state); + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { + gpr_log(GPR_INFO, + "call_combiner=%p: scheduling notify_on_cancel callback=%p", + this, notify_on_cancel); + } + ExecCtx::Run(DEBUG_LOCATION, notify_on_cancel, GRPC_ERROR_REF(error)); + } + break; + } + // cas failed, try again. + } +} + +} // namespace grpc_core diff --git a/src/core/lib/iomgr/cfstream_handle.cc b/src/core/lib/iomgr/cfstream_handle.cc new file mode 100644 index 00000000..b94711c6 --- /dev/null +++ b/src/core/lib/iomgr/cfstream_handle.cc @@ -0,0 +1,210 @@ +/* + * + * Copyright 2018 gRPC authors. 
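call_combiner.cc above serializes closures for a single call: the first Start() on an idle combiner runs its closure immediately, later Start()s queue, and each Stop() releases the next queued closure. A hedged usage sketch follows; the GRPC_CALL_COMBINER_START/STOP wrappers are assumed to come from call_combiner.h (forwarding to Start()/Stop() with file/line in debug builds), and the function names are illustrative.

    // Illustrative only: typical Start/Stop pairing around a batch of work.
    void DoWorkUnderCombiner(grpc_core::CallCombiner* call_combiner,
                             grpc_closure* do_work) {
      // Runs do_work immediately if the combiner is idle, otherwise queues it
      // behind the closure currently holding the combiner.
      GRPC_CALL_COMBINER_START(call_combiner, do_work, GRPC_ERROR_NONE,
                               "starting batch");
    }

    void OnWorkDone(grpc_core::CallCombiner* call_combiner) {
      // Every Start() must be matched by exactly one Stop(); this releases the
      // combiner and dispatches the next queued closure, if any.
      GRPC_CALL_COMBINER_STOP(call_combiner, "batch complete");
    }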
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_CFSTREAM +#import + +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#import "src/core/lib/iomgr/cfstream_handle.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/error_cfstream.h" +#include "src/core/lib/iomgr/ev_apple.h" +#include "src/core/lib/iomgr/exec_ctx.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; + +GrpcLibraryInitHolder::GrpcLibraryInitHolder() { grpc_init(); } + +GrpcLibraryInitHolder::~GrpcLibraryInitHolder() { grpc_shutdown(); } + +void* CFStreamHandle::Retain(void* info) { + CFStreamHandle* handle = static_cast(info); + CFSTREAM_HANDLE_REF(handle, "retain"); + return info; +} + +void CFStreamHandle::Release(void* info) { + CFStreamHandle* handle = static_cast(info); + CFSTREAM_HANDLE_UNREF(handle, "release"); +} + +CFStreamHandle* CFStreamHandle::CreateStreamHandle( + CFReadStreamRef read_stream, CFWriteStreamRef write_stream) { + return new CFStreamHandle(read_stream, write_stream); +} + +void CFStreamHandle::ReadCallback(CFReadStreamRef stream, + CFStreamEventType type, + void* client_callback_info) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_error_handle error; + CFErrorRef stream_error; + CFStreamHandle* handle = static_cast(client_callback_info); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream ReadCallback (%p, %p, %lu, %p)", handle, + stream, type, client_callback_info); + } + switch (type) { + case kCFStreamEventOpenCompleted: + handle->open_event_.SetReady(); + break; + case kCFStreamEventHasBytesAvailable: + case kCFStreamEventEndEncountered: + handle->read_event_.SetReady(); + break; + case kCFStreamEventErrorOccurred: + stream_error = CFReadStreamCopyError(stream); + error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_CFERROR(stream_error, "read error"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE); + CFRelease(stream_error); + handle->open_event_.SetShutdown(GRPC_ERROR_REF(error)); + handle->write_event_.SetShutdown(GRPC_ERROR_REF(error)); + handle->read_event_.SetShutdown(GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); + break; + default: + GPR_UNREACHABLE_CODE(return ); + } +} +void CFStreamHandle::WriteCallback(CFWriteStreamRef stream, + CFStreamEventType type, + void* clientCallBackInfo) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_error_handle error; + CFErrorRef stream_error; + CFStreamHandle* handle = static_cast(clientCallBackInfo); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream WriteCallback (%p, %p, %lu, %p)", handle, + stream, type, clientCallBackInfo); + } + switch (type) { + case kCFStreamEventOpenCompleted: + handle->open_event_.SetReady(); + break; + case kCFStreamEventCanAcceptBytes: + case kCFStreamEventEndEncountered: + handle->write_event_.SetReady(); + break; + case 
kCFStreamEventErrorOccurred: + stream_error = CFWriteStreamCopyError(stream); + error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_CFERROR(stream_error, "write error"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE); + CFRelease(stream_error); + handle->open_event_.SetShutdown(GRPC_ERROR_REF(error)); + handle->write_event_.SetShutdown(GRPC_ERROR_REF(error)); + handle->read_event_.SetShutdown(GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); + break; + default: + GPR_UNREACHABLE_CODE(return ); + } +} + +CFStreamHandle::CFStreamHandle(CFReadStreamRef read_stream, + CFWriteStreamRef write_stream) { + gpr_ref_init(&refcount_, 1); + open_event_.InitEvent(); + read_event_.InitEvent(); + write_event_.InitEvent(); + dispatch_queue_ = dispatch_queue_create(nullptr, DISPATCH_QUEUE_SERIAL); + CFStreamClientContext ctx = {0, static_cast(this), + CFStreamHandle::Retain, CFStreamHandle::Release, + nil}; + CFReadStreamSetClient( + read_stream, + kCFStreamEventOpenCompleted | kCFStreamEventHasBytesAvailable | + kCFStreamEventErrorOccurred | kCFStreamEventEndEncountered, + CFStreamHandle::ReadCallback, &ctx); + CFWriteStreamSetClient( + write_stream, + kCFStreamEventOpenCompleted | kCFStreamEventCanAcceptBytes | + kCFStreamEventErrorOccurred | kCFStreamEventEndEncountered, + CFStreamHandle::WriteCallback, &ctx); + grpc_apple_register_read_stream(read_stream, dispatch_queue_); + grpc_apple_register_write_stream(write_stream, dispatch_queue_); +} + +CFStreamHandle::~CFStreamHandle() { + open_event_.DestroyEvent(); + read_event_.DestroyEvent(); + write_event_.DestroyEvent(); + dispatch_release(dispatch_queue_); +} + +void CFStreamHandle::NotifyOnOpen(grpc_closure* closure) { + open_event_.NotifyOn(closure); +} + +void CFStreamHandle::NotifyOnRead(grpc_closure* closure) { + read_event_.NotifyOn(closure); +} + +void CFStreamHandle::NotifyOnWrite(grpc_closure* closure) { + write_event_.NotifyOn(closure); +} + +void CFStreamHandle::Shutdown(grpc_error_handle error) { + open_event_.SetShutdown(GRPC_ERROR_REF(error)); + read_event_.SetShutdown(GRPC_ERROR_REF(error)); + write_event_.SetShutdown(GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); +} + +void CFStreamHandle::Ref(const char* file, int line, const char* reason) { + if (grpc_tcp_trace.enabled()) { + gpr_atm val = gpr_atm_no_barrier_load(&refcount_.count); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "CFStream Handle ref %p : %s %" PRIdPTR " -> %" PRIdPTR, this, + reason, val, val + 1); + } + gpr_ref(&refcount_); +} + +void CFStreamHandle::Unref(const char* file, int line, const char* reason) { + if (grpc_tcp_trace.enabled()) { + gpr_atm val = gpr_atm_no_barrier_load(&refcount_.count); + gpr_log(GPR_DEBUG, + "CFStream Handle unref %p : %s %" PRIdPTR " -> %" PRIdPTR, this, + reason, val, val - 1); + } + if (gpr_unref(&refcount_)) { + delete this; + } +} + +#else + +/* Creating a phony function so that the grpc_cfstream library will be + * non-empty. + */ +void CFStreamPhony() {} + +#endif diff --git a/src/core/lib/iomgr/combiner.cc b/src/core/lib/iomgr/combiner.cc new file mode 100644 index 00000000..1b5add88 --- /dev/null +++ b/src/core/lib/iomgr/combiner.cc @@ -0,0 +1,328 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
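cfstream_handle.cc above bridges CFStream run-loop events onto gRPC's closure model: each stream callback marks one of three internal events (open/read/write) ready, which fires whatever closure was registered via NotifyOnOpen/NotifyOnRead/NotifyOnWrite. A minimal sketch of the caller's side of that contract, assuming only the API shown above; the closure wiring is illustrative (the real endpoint keeps the closure as a long-lived member, as endpoint_cfstream.cc does later in this diff).

    // Illustrative only: waiting for readability on a CFStreamHandle.
    static grpc_closure g_on_readable;

    static void OnReadable(void* /*arg*/, grpc_error_handle /*error*/) {
      // Runs once ReadCallback() observes kCFStreamEventHasBytesAvailable (or
      // EndEncountered), or immediately with an error after Shutdown().
    }

    void WatchStream(CFReadStreamRef read_stream, CFWriteStreamRef write_stream) {
      CFStreamHandle* handle =
          CFStreamHandle::CreateStreamHandle(read_stream, write_stream);
      GRPC_CLOSURE_INIT(&g_on_readable, OnReadable, /*cb_arg=*/nullptr,
                        grpc_schedule_on_exec_ctx);
      handle->NotifyOnRead(&g_on_readable);
    }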
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/combiner.h" + +#include +#include +#include + +#include +#include + +#include "src/core/lib/gprpp/mpscq.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr_internal.h" + +grpc_core::DebugOnlyTraceFlag grpc_combiner_trace(false, "combiner"); + +#define GRPC_COMBINER_TRACE(fn) \ + do { \ + if (grpc_combiner_trace.enabled()) { \ + fn; \ + } \ + } while (0) + +#define STATE_UNORPHANED 1 +#define STATE_ELEM_COUNT_LOW_BIT 2 + +static void combiner_exec(grpc_core::Combiner* lock, grpc_closure* closure, + grpc_error_handle error); +static void combiner_finally_exec(grpc_core::Combiner* lock, + grpc_closure* closure, + grpc_error_handle error); + +static void offload(void* arg, grpc_error_handle error); + +grpc_core::Combiner* grpc_combiner_create(void) { + grpc_core::Combiner* lock = new grpc_core::Combiner(); + gpr_ref_init(&lock->refs, 1); + gpr_atm_no_barrier_store(&lock->state, STATE_UNORPHANED); + grpc_closure_list_init(&lock->final_list); + GRPC_CLOSURE_INIT(&lock->offload, offload, lock, nullptr); + GRPC_COMBINER_TRACE(gpr_log(GPR_INFO, "C:%p create", lock)); + return lock; +} + +static void really_destroy(grpc_core::Combiner* lock) { + GRPC_COMBINER_TRACE(gpr_log(GPR_INFO, "C:%p really_destroy", lock)); + GPR_ASSERT(gpr_atm_no_barrier_load(&lock->state) == 0); + delete lock; +} + +static void start_destroy(grpc_core::Combiner* lock) { + gpr_atm old_state = gpr_atm_full_fetch_add(&lock->state, -STATE_UNORPHANED); + GRPC_COMBINER_TRACE(gpr_log( + GPR_INFO, "C:%p really_destroy old_state=%" PRIdPTR, lock, old_state)); + if (old_state == 1) { + really_destroy(lock); + } +} + +#ifndef NDEBUG +#define GRPC_COMBINER_DEBUG_SPAM(op, delta) \ + if (grpc_combiner_trace.enabled()) { \ + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, \ + "C:%p %s %" PRIdPTR " --> %" PRIdPTR " %s", lock, (op), \ + gpr_atm_no_barrier_load(&lock->refs.count), \ + gpr_atm_no_barrier_load(&lock->refs.count) + (delta), reason); \ + } +#else +#define GRPC_COMBINER_DEBUG_SPAM(op, delta) +#endif + +void grpc_combiner_unref(grpc_core::Combiner* lock GRPC_COMBINER_DEBUG_ARGS) { + GRPC_COMBINER_DEBUG_SPAM("UNREF", -1); + if (gpr_unref(&lock->refs)) { + start_destroy(lock); + } +} + +grpc_core::Combiner* grpc_combiner_ref( + grpc_core::Combiner* lock GRPC_COMBINER_DEBUG_ARGS) { + GRPC_COMBINER_DEBUG_SPAM(" REF", 1); + gpr_ref(&lock->refs); + return lock; +} + +static void push_last_on_exec_ctx(grpc_core::Combiner* lock) { + lock->next_combiner_on_this_exec_ctx = nullptr; + if (grpc_core::ExecCtx::Get()->combiner_data()->active_combiner == nullptr) { + grpc_core::ExecCtx::Get()->combiner_data()->active_combiner = + grpc_core::ExecCtx::Get()->combiner_data()->last_combiner = lock; + } else { + grpc_core::ExecCtx::Get() + ->combiner_data() + ->last_combiner->next_combiner_on_this_exec_ctx = lock; + grpc_core::ExecCtx::Get()->combiner_data()->last_combiner = lock; + } +} + +static void push_first_on_exec_ctx(grpc_core::Combiner* lock) { + lock->next_combiner_on_this_exec_ctx = + grpc_core::ExecCtx::Get()->combiner_data()->active_combiner; + 
grpc_core::ExecCtx::Get()->combiner_data()->active_combiner = lock; + if (lock->next_combiner_on_this_exec_ctx == nullptr) { + grpc_core::ExecCtx::Get()->combiner_data()->last_combiner = lock; + } +} + +static void combiner_exec(grpc_core::Combiner* lock, grpc_closure* cl, + grpc_error_handle error) { + gpr_atm last = gpr_atm_full_fetch_add(&lock->state, STATE_ELEM_COUNT_LOW_BIT); + GRPC_COMBINER_TRACE(gpr_log(GPR_INFO, + "C:%p grpc_combiner_execute c=%p last=%" PRIdPTR, + lock, cl, last)); + if (last == 1) { + gpr_atm_no_barrier_store( + &lock->initiating_exec_ctx_or_null, + reinterpret_cast(grpc_core::ExecCtx::Get())); + // first element on this list: add it to the list of combiner locks + // executing within this exec_ctx + push_last_on_exec_ctx(lock); + } else { + // there may be a race with setting here: if that happens, we may delay + // offload for one or two actions, and that's fine + gpr_atm initiator = + gpr_atm_no_barrier_load(&lock->initiating_exec_ctx_or_null); + if (initiator != 0 && + initiator != reinterpret_cast(grpc_core::ExecCtx::Get())) { + gpr_atm_no_barrier_store(&lock->initiating_exec_ctx_or_null, 0); + } + } + GPR_ASSERT(last & STATE_UNORPHANED); // ensure lock has not been destroyed + assert(cl->cb); + cl->error_data.error = error; + lock->queue.Push(cl->next_data.mpscq_node.get()); +} + +static void move_next() { + grpc_core::ExecCtx::Get()->combiner_data()->active_combiner = + grpc_core::ExecCtx::Get() + ->combiner_data() + ->active_combiner->next_combiner_on_this_exec_ctx; + if (grpc_core::ExecCtx::Get()->combiner_data()->active_combiner == nullptr) { + grpc_core::ExecCtx::Get()->combiner_data()->last_combiner = nullptr; + } +} + +static void offload(void* arg, grpc_error_handle /*error*/) { + grpc_core::Combiner* lock = static_cast(arg); + push_last_on_exec_ctx(lock); +} + +static void queue_offload(grpc_core::Combiner* lock) { + move_next(); + GRPC_COMBINER_TRACE(gpr_log(GPR_INFO, "C:%p queue_offload", lock)); + grpc_core::Executor::Run(&lock->offload, GRPC_ERROR_NONE); +} + +bool grpc_combiner_continue_exec_ctx() { + grpc_core::Combiner* lock = + grpc_core::ExecCtx::Get()->combiner_data()->active_combiner; + if (lock == nullptr) { + return false; + } + + bool contended = + gpr_atm_no_barrier_load(&lock->initiating_exec_ctx_or_null) == 0; + + GRPC_COMBINER_TRACE(gpr_log(GPR_INFO, + "C:%p grpc_combiner_continue_exec_ctx " + "contended=%d " + "exec_ctx_ready_to_finish=%d " + "time_to_execute_final_list=%d", + lock, contended, + grpc_core::ExecCtx::Get()->IsReadyToFinish(), + lock->time_to_execute_final_list)); + + // offload only if all the following conditions are true: + // 1. the combiner is contended and has more than one closure to execute + // 2. the current execution context needs to finish as soon as possible + // 3. the current thread is not a worker for any background poller + // 4. 
the DEFAULT executor is threaded + if (contended && grpc_core::ExecCtx::Get()->IsReadyToFinish() && + !grpc_iomgr_platform_is_any_background_poller_thread() && + grpc_core::Executor::IsThreadedDefault()) { + // this execution context wants to move on: schedule remaining work to be + // picked up on the executor + queue_offload(lock); + return true; + } + + if (!lock->time_to_execute_final_list || + // peek to see if something new has shown up, and execute that with + // priority + (gpr_atm_acq_load(&lock->state) >> 1) > 1) { + grpc_core::MultiProducerSingleConsumerQueue::Node* n = lock->queue.Pop(); + GRPC_COMBINER_TRACE( + gpr_log(GPR_INFO, "C:%p maybe_finish_one n=%p", lock, n)); + if (n == nullptr) { + // queue is in an inconsistent state: use this as a cue that we should + // go off and do something else for a while (and come back later) + queue_offload(lock); + return true; + } + grpc_closure* cl = reinterpret_cast(n); + grpc_error_handle cl_err = cl->error_data.error; +#ifndef NDEBUG + cl->scheduled = false; +#endif + cl->cb(cl->cb_arg, cl_err); + GRPC_ERROR_UNREF(cl_err); + } else { + grpc_closure* c = lock->final_list.head; + GPR_ASSERT(c != nullptr); + grpc_closure_list_init(&lock->final_list); + int loops = 0; + while (c != nullptr) { + GRPC_COMBINER_TRACE( + gpr_log(GPR_INFO, "C:%p execute_final[%d] c=%p", lock, loops, c)); + grpc_closure* next = c->next_data.next; + grpc_error_handle error = c->error_data.error; +#ifndef NDEBUG + c->scheduled = false; +#endif + c->cb(c->cb_arg, error); + GRPC_ERROR_UNREF(error); + c = next; + } + } + + move_next(); + lock->time_to_execute_final_list = false; + gpr_atm old_state = + gpr_atm_full_fetch_add(&lock->state, -STATE_ELEM_COUNT_LOW_BIT); + GRPC_COMBINER_TRACE( + gpr_log(GPR_INFO, "C:%p finish old_state=%" PRIdPTR, lock, old_state)); +// Define a macro to ease readability of the following switch statement. +#define OLD_STATE_WAS(orphaned, elem_count) \ + (((orphaned) ? 0 : STATE_UNORPHANED) | \ + ((elem_count)*STATE_ELEM_COUNT_LOW_BIT)) + // Depending on what the previous state was, we need to perform different + // actions. + switch (old_state) { + default: + // we have multiple queued work items: just continue executing them + break; + case OLD_STATE_WAS(false, 2): + case OLD_STATE_WAS(true, 2): + // we're down to one queued item: if it's the final list we should do that + if (!grpc_closure_list_empty(lock->final_list)) { + lock->time_to_execute_final_list = true; + } + break; + case OLD_STATE_WAS(false, 1): + // had one count, one unorphaned --> unlocked unorphaned + return true; + case OLD_STATE_WAS(true, 1): + // and one count, one orphaned --> unlocked and orphaned + really_destroy(lock); + return true; + case OLD_STATE_WAS(false, 0): + case OLD_STATE_WAS(true, 0): + // these values are illegal - representing an already unlocked or + // deleted lock + GPR_UNREACHABLE_CODE(return true); + } + push_first_on_exec_ctx(lock); + return true; +} + +static void enqueue_finally(void* closure, grpc_error_handle error); + +static void combiner_finally_exec(grpc_core::Combiner* lock, + grpc_closure* closure, + grpc_error_handle error) { + GPR_ASSERT(lock != nullptr); + GRPC_COMBINER_TRACE(gpr_log( + GPR_INFO, "C:%p grpc_combiner_execute_finally c=%p; ac=%p", lock, closure, + grpc_core::ExecCtx::Get()->combiner_data()->active_combiner)); + if (grpc_core::ExecCtx::Get()->combiner_data()->active_combiner != lock) { + // Using error_data.scratch to store the combiner so that it can be accessed + // in enqueue_finally. 
+ closure->error_data.scratch = reinterpret_cast(lock); + lock->Run(GRPC_CLOSURE_CREATE(enqueue_finally, closure, nullptr), error); + return; + } + + if (grpc_closure_list_empty(lock->final_list)) { + gpr_atm_full_fetch_add(&lock->state, STATE_ELEM_COUNT_LOW_BIT); + } + grpc_closure_list_append(&lock->final_list, closure, error); +} + +static void enqueue_finally(void* closure, grpc_error_handle error) { + grpc_closure* cl = static_cast(closure); + grpc_core::Combiner* lock = + reinterpret_cast(cl->error_data.scratch); + cl->error_data.scratch = 0; + combiner_finally_exec(lock, cl, GRPC_ERROR_REF(error)); +} + +namespace grpc_core { +void Combiner::Run(grpc_closure* closure, grpc_error_handle error) { + combiner_exec(this, closure, error); +} + +void Combiner::FinallyRun(grpc_closure* closure, grpc_error_handle error) { + combiner_finally_exec(this, closure, error); +} +} // namespace grpc_core diff --git a/src/core/lib/iomgr/dualstack_socket_posix.cc b/src/core/lib/iomgr/dualstack_socket_posix.cc new file mode 100644 index 00000000..2d771133 --- /dev/null +++ b/src/core/lib/iomgr/dualstack_socket_posix.cc @@ -0,0 +1,48 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_UTILS_COMMON + +#include + +#include "src/core/lib/iomgr/socket_utils_posix.h" + +#ifndef GRPC_SET_SOCKET_DUALSTACK_CUSTOM + +/* This should be 0 in production, but it may be enabled for testing or + debugging purposes, to simulate an environment where IPv6 sockets can't + also speak IPv4. */ +int grpc_forbid_dualstack_sockets_for_testing = 0; + +int grpc_set_socket_dualstack(int fd) { + if (!grpc_forbid_dualstack_sockets_for_testing) { + const int off = 0; + return 0 == setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &off, sizeof(off)); + } else { + /* Force an IPv6-only socket, for testing purposes. */ + const int on = 1; + setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)); + return 0; + } +} +#endif // GRPC_SET_SOCKET_DUALSTACK_CUSTOM +#endif // GRPC_POSIX_SOCKET_UTILS_COMMON diff --git a/src/core/lib/iomgr/endpoint.cc b/src/core/lib/iomgr/endpoint.cc new file mode 100644 index 00000000..a2f864c7 --- /dev/null +++ b/src/core/lib/iomgr/endpoint.cc @@ -0,0 +1,67 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
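combiner.cc above implements a lock-free combiner: closures pushed with Run() execute one at a time on whichever ExecCtx currently owns the combiner, with the low bit of `state` holding the unorphaned flag and the remaining bits counting scheduled closures. A hedged usage sketch; grpc_combiner_create() and Combiner::Run()/FinallyRun() are from the code above, while GRPC_COMBINER_UNREF is assumed from combiner.h and the closure is illustrative.

    // Illustrative only: serialize work on a combiner.
    static void DoSerializedWork(void* /*arg*/, grpc_error_handle /*error*/) {
      // Runs with the combiner "held": no other closure scheduled on the same
      // combiner executes concurrently with this one.
    }

    void Example() {
      grpc_core::ExecCtx exec_ctx;  // closures drain when this flushes
      grpc_core::Combiner* combiner = grpc_combiner_create();
      combiner->Run(GRPC_CLOSURE_CREATE(DoSerializedWork, /*cb_arg=*/nullptr,
                                        /*scheduler=*/nullptr),
                    GRPC_ERROR_NONE);
      // Safe even with work still queued: the element count keeps the combiner
      // alive until the last closure has run.
      GRPC_COMBINER_UNREF(combiner, "example");
    }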
+ * + */ + +#include + +#include "src/core/lib/iomgr/endpoint.h" + +grpc_core::TraceFlag grpc_tcp_trace(false, "tcp"); + +void grpc_endpoint_read(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool urgent) { + ep->vtable->read(ep, slices, cb, urgent); +} + +void grpc_endpoint_write(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* arg) { + ep->vtable->write(ep, slices, cb, arg); +} + +void grpc_endpoint_add_to_pollset(grpc_endpoint* ep, grpc_pollset* pollset) { + ep->vtable->add_to_pollset(ep, pollset); +} + +void grpc_endpoint_add_to_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset_set) { + ep->vtable->add_to_pollset_set(ep, pollset_set); +} + +void grpc_endpoint_delete_from_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset_set) { + ep->vtable->delete_from_pollset_set(ep, pollset_set); +} + +void grpc_endpoint_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + ep->vtable->shutdown(ep, why); +} + +void grpc_endpoint_destroy(grpc_endpoint* ep) { ep->vtable->destroy(ep); } + +absl::string_view grpc_endpoint_get_peer(grpc_endpoint* ep) { + return ep->vtable->get_peer(ep); +} + +absl::string_view grpc_endpoint_get_local_address(grpc_endpoint* ep) { + return ep->vtable->get_local_address(ep); +} + +int grpc_endpoint_get_fd(grpc_endpoint* ep) { return ep->vtable->get_fd(ep); } + +bool grpc_endpoint_can_track_err(grpc_endpoint* ep) { + return ep->vtable->can_track_err(ep); +} diff --git a/src/core/lib/iomgr/endpoint_cfstream.cc b/src/core/lib/iomgr/endpoint_cfstream.cc new file mode 100644 index 00000000..224e2766 --- /dev/null +++ b/src/core/lib/iomgr/endpoint_cfstream.cc @@ -0,0 +1,389 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
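endpoint.cc above is a thin dispatcher: every grpc_endpoint_* call forwards to the matching entry of the endpoint's vtable, so the same caller code works against the POSIX TCP endpoint, the CFStream endpoint later in this diff, or any other implementation that embeds grpc_endpoint and fills in a grpc_endpoint_vtable. A small caller-side sketch, using only the wrappers shown above; the function names are assumptions.

    // Illustrative only: callers go through the wrappers, never the vtable.
    void StartRead(grpc_endpoint* ep, grpc_slice_buffer* incoming,
                   grpc_closure* on_read) {
      // Dispatches to ep->vtable->read(ep, incoming, on_read, /*urgent=*/false).
      grpc_endpoint_read(ep, incoming, on_read, /*urgent=*/false);
    }

    void StartWrite(grpc_endpoint* ep, grpc_slice_buffer* outgoing,
                    grpc_closure* on_write) {
      gpr_log(GPR_DEBUG, "writing to %s",
              std::string(grpc_endpoint_get_peer(ep)).c_str());
      grpc_endpoint_write(ep, outgoing, on_write, /*arg=*/nullptr);
    }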
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_CFSTREAM_ENDPOINT + +#import + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/cfstream_handle.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/endpoint.h" +#import "src/core/lib/iomgr/endpoint_cfstream.h" +#include "src/core/lib/iomgr/error_cfstream.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; + +struct CFStreamEndpoint { + grpc_endpoint base; + gpr_refcount refcount; + + CFReadStreamRef read_stream; + CFWriteStreamRef write_stream; + CFStreamHandle* stream_sync; + + grpc_closure* read_cb; + grpc_closure* write_cb; + grpc_slice_buffer* read_slices; + grpc_slice_buffer* write_slices; + + grpc_closure read_action; + grpc_closure write_action; + + std::string peer_string; + std::string local_address; + grpc_slice_allocator* slice_allocator; +}; +static void CFStreamFree(CFStreamEndpoint* ep) { + grpc_slice_allocator_destroy(ep->slice_allocator); + CFRelease(ep->read_stream); + CFRelease(ep->write_stream); + CFSTREAM_HANDLE_UNREF(ep->stream_sync, "free"); + delete ep; +} + +#ifndef NDEBUG +#define EP_REF(ep, reason) CFStreamRef((ep), (reason), __FILE__, __LINE__) +#define EP_UNREF(ep, reason) CFStreamUnref((ep), (reason), __FILE__, __LINE__) +static void CFStreamUnref(CFStreamEndpoint* ep, const char* reason, + const char* file, int line) { + if (grpc_tcp_trace.enabled()) { + gpr_atm val = gpr_atm_no_barrier_load(&ep->refcount.count); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "CFStream endpoint unref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, + reason, val, val - 1); + } + if (gpr_unref(&ep->refcount)) { + CFStreamFree(ep); + } +} +static void CFStreamRef(CFStreamEndpoint* ep, const char* reason, + const char* file, int line) { + if (grpc_tcp_trace.enabled()) { + gpr_atm val = gpr_atm_no_barrier_load(&ep->refcount.count); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "CFStream endpoint ref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, + reason, val, val + 1); + } + gpr_ref(&ep->refcount); +} +#else +#define EP_REF(ep, reason) CFStreamRef((ep)) +#define EP_UNREF(ep, reason) CFStreamUnref((ep)) +static void CFStreamUnref(CFStreamEndpoint* ep) { + if (gpr_unref(&ep->refcount)) { + CFStreamFree(ep); + } +} +static void CFStreamRef(CFStreamEndpoint* ep) { gpr_ref(&ep->refcount); } +#endif + +static grpc_error_handle CFStreamAnnotateError(grpc_error_handle src_error, + CFStreamEndpoint* ep) { + return grpc_error_set_str( + grpc_error_set_int(src_error, GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE), + GRPC_ERROR_STR_TARGET_ADDRESS, ep->peer_string); +} + +static void CallReadCb(CFStreamEndpoint* ep, grpc_error_handle error) { + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream endpoint:%p call_read_cb %p %p:%p", ep, + ep->read_cb, ep->read_cb->cb, ep->read_cb->cb_arg); + size_t i; + gpr_log(GPR_DEBUG, "read: error=%s", grpc_error_std_string(error).c_str()); + + for (i = 0; i < ep->read_slices->count; i++) { + char* dump = grpc_dump_slice(ep->read_slices->slices[i], + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "READ %p (peer=%s): %s", ep, ep->peer_string.c_str(), + dump); + gpr_free(dump); + } + } + grpc_closure* cb = ep->read_cb; + ep->read_cb = nullptr; + ep->read_slices = nullptr; + 
grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); +} + +static void CallWriteCb(CFStreamEndpoint* ep, grpc_error_handle error) { + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream endpoint:%p call_write_cb %p %p:%p", ep, + ep->write_cb, ep->write_cb->cb, ep->write_cb->cb_arg); + gpr_log(GPR_DEBUG, "write: error=%s", grpc_error_std_string(error).c_str()); + } + grpc_closure* cb = ep->write_cb; + ep->write_cb = nullptr; + ep->write_slices = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); +} + +static void ReadAction(void* arg, grpc_error_handle error) { + CFStreamEndpoint* ep = static_cast(arg); + GPR_ASSERT(ep->read_cb != nullptr); + if (error != GRPC_ERROR_NONE) { + grpc_slice_buffer_reset_and_unref_internal(ep->read_slices); + CallReadCb(ep, GRPC_ERROR_REF(error)); + EP_UNREF(ep, "read"); + return; + } + + GPR_ASSERT(ep->read_slices->count == 1); + grpc_slice slice = ep->read_slices->slices[0]; + size_t len = GRPC_SLICE_LENGTH(slice); + CFIndex read_size = + CFReadStreamRead(ep->read_stream, GRPC_SLICE_START_PTR(slice), len); + if (read_size == -1) { + grpc_slice_buffer_reset_and_unref_internal(ep->read_slices); + CFErrorRef stream_error = CFReadStreamCopyError(ep->read_stream); + if (stream_error != nullptr) { + error = CFStreamAnnotateError( + GRPC_ERROR_CREATE_FROM_CFERROR(stream_error, "Read error"), ep); + CFRelease(stream_error); + } else { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Read error"); + } + CallReadCb(ep, error); + EP_UNREF(ep, "read"); + } else if (read_size == 0) { + grpc_slice_buffer_reset_and_unref_internal(ep->read_slices); + CallReadCb(ep, + CFStreamAnnotateError( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Socket closed"), ep)); + EP_UNREF(ep, "read"); + } else { + if (read_size < static_cast(len)) { + grpc_slice_buffer_trim_end(ep->read_slices, len - read_size, nullptr); + } + CallReadCb(ep, GRPC_ERROR_NONE); + EP_UNREF(ep, "read"); + } +} + +static void WriteAction(void* arg, grpc_error_handle error) { + CFStreamEndpoint* ep = static_cast(arg); + GPR_ASSERT(ep->write_cb != nullptr); + if (error != GRPC_ERROR_NONE) { + grpc_slice_buffer_reset_and_unref_internal(ep->write_slices); + CallWriteCb(ep, GRPC_ERROR_REF(error)); + EP_UNREF(ep, "write"); + return; + } + + grpc_slice slice = grpc_slice_buffer_take_first(ep->write_slices); + size_t slice_len = GRPC_SLICE_LENGTH(slice); + CFIndex write_size = CFWriteStreamWrite( + ep->write_stream, GRPC_SLICE_START_PTR(slice), slice_len); + if (write_size == -1) { + grpc_slice_buffer_reset_and_unref_internal(ep->write_slices); + CFErrorRef stream_error = CFWriteStreamCopyError(ep->write_stream); + if (stream_error != nullptr) { + error = CFStreamAnnotateError( + GRPC_ERROR_CREATE_FROM_CFERROR(stream_error, "write failed."), ep); + CFRelease(stream_error); + } else { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("write failed."); + } + CallWriteCb(ep, error); + EP_UNREF(ep, "write"); + } else { + if (write_size < static_cast(GRPC_SLICE_LENGTH(slice))) { + grpc_slice_buffer_undo_take_first( + ep->write_slices, grpc_slice_sub(slice, write_size, slice_len)); + } + if (ep->write_slices->length > 0) { + ep->stream_sync->NotifyOnWrite(&ep->write_action); + } else { + CallWriteCb(ep, GRPC_ERROR_NONE); + EP_UNREF(ep, "write"); + } + + if (grpc_tcp_trace.enabled()) { + grpc_slice trace_slice = grpc_slice_sub(slice, 0, write_size); + char* dump = grpc_dump_slice(trace_slice, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "WRITE %p (peer=%s): %s", ep, ep->peer_string.c_str(), + dump); + 
gpr_free(dump); + grpc_slice_unref_internal(trace_slice); + } + } + grpc_slice_unref_internal(slice); +} + +static void CFStreamReadAllocationDone(void* arg, grpc_error_handle error) { + CFStreamEndpoint* ep = static_cast(arg); + if (error == GRPC_ERROR_NONE) { + ep->stream_sync->NotifyOnRead(&ep->read_action); + } else { + grpc_slice_buffer_reset_and_unref_internal(ep->read_slices); + CallReadCb(ep, error); + EP_UNREF(ep, "read"); + } +} + +static void CFStreamRead(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool urgent) { + CFStreamEndpoint* ep_impl = reinterpret_cast(ep); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream endpoint:%p read (%p, %p) length:%zu", ep_impl, + slices, cb, slices->length); + } + GPR_ASSERT(ep_impl->read_cb == nullptr); + ep_impl->read_cb = cb; + ep_impl->read_slices = slices; + grpc_slice_buffer_reset_and_unref_internal(slices); + EP_REF(ep_impl, "read"); + if (grpc_slice_allocator_allocate( + ep_impl->slice_allocator, GRPC_TCP_DEFAULT_READ_SLICE_SIZE, 1, + grpc_slice_allocator_intent::kReadBuffer, ep_impl->read_slices, + CFStreamReadAllocationDone, ep_impl)) { + ep_impl->stream_sync->NotifyOnRead(&ep_impl->read_action); + } +} + +static void CFStreamWrite(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* arg) { + CFStreamEndpoint* ep_impl = reinterpret_cast(ep); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream endpoint:%p write (%p, %p) length:%zu", + ep_impl, slices, cb, slices->length); + } + GPR_ASSERT(ep_impl->write_cb == nullptr); + ep_impl->write_cb = cb; + ep_impl->write_slices = slices; + EP_REF(ep_impl, "write"); + ep_impl->stream_sync->NotifyOnWrite(&ep_impl->write_action); +} + +void CFStreamShutdown(grpc_endpoint* ep, grpc_error_handle why) { + CFStreamEndpoint* ep_impl = reinterpret_cast(ep); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream endpoint:%p shutdown (%s)", ep_impl, + grpc_error_std_string(why).c_str()); + } + CFReadStreamClose(ep_impl->read_stream); + CFWriteStreamClose(ep_impl->write_stream); + ep_impl->stream_sync->Shutdown(why); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream endpoint:%p shutdown DONE (%s)", ep_impl, + grpc_error_std_string(why).c_str()); + } +} + +void CFStreamDestroy(grpc_endpoint* ep) { + CFStreamEndpoint* ep_impl = reinterpret_cast(ep); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CFStream endpoint:%p destroy", ep_impl); + } + EP_UNREF(ep_impl, "destroy"); +} + +absl::string_view CFStreamGetPeer(grpc_endpoint* ep) { + CFStreamEndpoint* ep_impl = reinterpret_cast(ep); + return ep_impl->peer_string; +} + +absl::string_view CFStreamGetLocalAddress(grpc_endpoint* ep) { + CFStreamEndpoint* ep_impl = reinterpret_cast(ep); + return ep_impl->local_address; +} + +int CFStreamGetFD(grpc_endpoint* ep) { return 0; } + +bool CFStreamCanTrackErr(grpc_endpoint* ep) { return false; } + +void CFStreamAddToPollset(grpc_endpoint* ep, grpc_pollset* pollset) {} +void CFStreamAddToPollsetSet(grpc_endpoint* ep, grpc_pollset_set* pollset) {} +void CFStreamDeleteFromPollsetSet(grpc_endpoint* ep, + grpc_pollset_set* pollset) {} + +static const grpc_endpoint_vtable vtable = {CFStreamRead, + CFStreamWrite, + CFStreamAddToPollset, + CFStreamAddToPollsetSet, + CFStreamDeleteFromPollsetSet, + CFStreamShutdown, + CFStreamDestroy, + CFStreamGetPeer, + CFStreamGetLocalAddress, + CFStreamGetFD, + CFStreamCanTrackErr}; + +grpc_endpoint* grpc_cfstream_endpoint_create( + CFReadStreamRef read_stream, CFWriteStreamRef write_stream, + 
const char* peer_string, grpc_slice_allocator* slice_allocator, + CFStreamHandle* stream_sync) { + CFStreamEndpoint* ep_impl = new CFStreamEndpoint; + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, + "CFStream endpoint:%p create readStream:%p writeStream: %p", + ep_impl, read_stream, write_stream); + } + ep_impl->base.vtable = &vtable; + gpr_ref_init(&ep_impl->refcount, 1); + ep_impl->read_stream = read_stream; + ep_impl->write_stream = write_stream; + CFRetain(read_stream); + CFRetain(write_stream); + ep_impl->stream_sync = stream_sync; + CFSTREAM_HANDLE_REF(ep_impl->stream_sync, "endpoint create"); + + ep_impl->peer_string = peer_string; + grpc_resolved_address resolved_local_addr; + resolved_local_addr.len = sizeof(resolved_local_addr.addr); + CFDataRef native_handle = static_cast(CFReadStreamCopyProperty( + ep_impl->read_stream, kCFStreamPropertySocketNativeHandle)); + CFSocketNativeHandle sockfd; + CFDataGetBytes(native_handle, CFRangeMake(0, sizeof(CFSocketNativeHandle)), + (UInt8*)&sockfd); + if (native_handle) { + CFRelease(native_handle); + } + if (getsockname(sockfd, reinterpret_cast(resolved_local_addr.addr), + &resolved_local_addr.len) < 0) { + ep_impl->local_address = ""; + } else { + ep_impl->local_address = grpc_sockaddr_to_uri(&resolved_local_addr); + } + ep_impl->read_cb = nil; + ep_impl->write_cb = nil; + ep_impl->read_slices = nil; + ep_impl->write_slices = nil; + GRPC_CLOSURE_INIT(&ep_impl->read_action, ReadAction, + static_cast(ep_impl), grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&ep_impl->write_action, WriteAction, + static_cast(ep_impl), grpc_schedule_on_exec_ctx); + ep_impl->slice_allocator = slice_allocator; + + return &ep_impl->base; +} + +#endif /* GRPC_CFSTREAM_ENDPOINT */ diff --git a/src/core/lib/iomgr/endpoint_pair_event_engine.cc b/src/core/lib/iomgr/endpoint_pair_event_engine.cc new file mode 100644 index 00000000..7c0c1008 --- /dev/null +++ b/src/core/lib/iomgr/endpoint_pair_event_engine.cc @@ -0,0 +1,32 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include + +#include + +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/port.h" + +grpc_endpoint_pair grpc_iomgr_create_endpoint_pair( + const char* /* name */, grpc_channel_args* /* args */) { + // TODO(hork): determine what's needed here in the long run + GPR_ASSERT( + false && + "grpc_iomgr_create_endpoint_pair is not suppoted with event_engine"); +} + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/endpoint_pair_posix.cc b/src/core/lib/iomgr/endpoint_pair_posix.cc new file mode 100644 index 00000000..217ea377 --- /dev/null +++ b/src/core/lib/iomgr/endpoint_pair_posix.cc @@ -0,0 +1,77 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
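The CFStream endpoint above plugs into iomgr purely through a table of function pointers: grpc_cfstream_endpoint_create() fills in a struct whose first member is the base grpc_endpoint, points base.vtable at the static table, and every read/write/shutdown/destroy call is dispatched through it. A self-contained sketch of that shape follows; Endpoint, EndpointVTable, CountingEndpoint and MakeCountingEndpoint are illustrative placeholders, not the real iomgr declarations.

#include <cstdio>
#include <string>

// Placeholder "base" object: the real code uses grpc_endpoint plus grpc_endpoint_vtable.
struct Endpoint;
struct EndpointVTable {
  void (*write)(Endpoint* ep, const std::string& data);
  void (*destroy)(Endpoint* ep);
};
struct Endpoint {
  const EndpointVTable* vtable;
};

// A concrete endpoint embeds the base as its first member, mirroring
// CFStreamEndpoint::base above, so the base pointer can be cast back.
struct CountingEndpoint {
  Endpoint base;
  size_t bytes_written = 0;
};

static void CountingWrite(Endpoint* ep, const std::string& data) {
  auto* impl = reinterpret_cast<CountingEndpoint*>(ep);
  impl->bytes_written += data.size();
}
static void CountingDestroy(Endpoint* ep) {
  delete reinterpret_cast<CountingEndpoint*>(ep);
}

static const EndpointVTable kCountingVTable = {CountingWrite, CountingDestroy};

Endpoint* MakeCountingEndpoint() {
  auto* impl = new CountingEndpoint;
  impl->base.vtable = &kCountingVTable;
  return &impl->base;
}

int main() {
  Endpoint* ep = MakeCountingEndpoint();
  ep->vtable->write(ep, "hello");
  std::printf("bytes written: %zu\n",
              reinterpret_cast<CountingEndpoint*>(ep)->bytes_written);
  ep->vtable->destroy(ep);
}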
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_TCP + +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" + +static void create_sockets(int sv[2]) { + int flags; + grpc_create_socketpair_if_unix(sv); + flags = fcntl(sv[0], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[0], F_SETFL, flags | O_NONBLOCK) == 0); + flags = fcntl(sv[1], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[1], F_SETFL, flags | O_NONBLOCK) == 0); + GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv[0]) == GRPC_ERROR_NONE); + GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv[1]) == GRPC_ERROR_NONE); +} + +grpc_endpoint_pair grpc_iomgr_create_endpoint_pair(const char* name, + grpc_channel_args* args) { + int sv[2]; + grpc_endpoint_pair p; + create_sockets(sv); + grpc_core::ExecCtx exec_ctx; + std::string final_name = absl::StrCat(name, ":client"); + grpc_resource_quota* resource_quota = + grpc_resource_quota_from_channel_args(args, true); + p.client = grpc_tcp_create( + grpc_fd_create(sv[1], final_name.c_str(), false), args, + "socketpair-server", + grpc_slice_allocator_create(resource_quota, "server_endpoint", args)); + final_name = absl::StrCat(name, ":server"); + p.server = grpc_tcp_create( + grpc_fd_create(sv[0], final_name.c_str(), false), args, + "socketpair-client", + grpc_slice_allocator_create(resource_quota, "client_endpoint", args)); + grpc_resource_quota_unref_internal(resource_quota); + return p; +} + +#endif diff --git a/src/core/lib/iomgr/endpoint_pair_windows.cc b/src/core/lib/iomgr/endpoint_pair_windows.cc new file mode 100644 index 00000000..c1d34fae --- /dev/null +++ b/src/core/lib/iomgr/endpoint_pair_windows.cc @@ -0,0 +1,95 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
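create_sockets() in endpoint_pair_posix.cc above boils down to: make a connected pair, switch both ends to non-blocking, and suppress SIGPIPE where possible. Stripped of the gRPC wrappers, the core of that looks roughly like the sketch below (MakeNonBlockingPair is a hypothetical helper; error handling reduced to asserts).

#include <cassert>
#include <fcntl.h>
#include <sys/socket.h>
#include <unistd.h>

// Create a connected AF_UNIX socket pair and put both ends in
// non-blocking mode, mirroring what create_sockets() does above.
static void MakeNonBlockingPair(int sv[2]) {
  int rc = socketpair(AF_UNIX, SOCK_STREAM, 0, sv);
  assert(rc == 0);
  for (int i = 0; i < 2; ++i) {
    int flags = fcntl(sv[i], F_GETFL, 0);
    assert(flags >= 0);
    rc = fcntl(sv[i], F_SETFL, flags | O_NONBLOCK);
    assert(rc == 0);
  }
  (void)rc;
}

int main() {
  int sv[2];
  MakeNonBlockingPair(sv);
  // sv[0] and sv[1] would each be handed to one endpoint of the pair.
  close(sv[0]);
  close(sv[1]);
}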
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET +#include +#include +#include + +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/tcp_windows.h" + +static void create_sockets(SOCKET sv[2]) { + SOCKET svr_sock = INVALID_SOCKET; + SOCKET lst_sock = INVALID_SOCKET; + SOCKET cli_sock = INVALID_SOCKET; + SOCKADDR_IN addr; + int addr_len = sizeof(addr); + + lst_sock = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, + grpc_get_default_wsa_socket_flags()); + GPR_ASSERT(lst_sock != INVALID_SOCKET); + + memset(&addr, 0, sizeof(addr)); + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr.sin_family = AF_INET; + GPR_ASSERT(bind(lst_sock, (grpc_sockaddr*)&addr, sizeof(addr)) != + SOCKET_ERROR); + GPR_ASSERT(listen(lst_sock, SOMAXCONN) != SOCKET_ERROR); + GPR_ASSERT(getsockname(lst_sock, (grpc_sockaddr*)&addr, &addr_len) != + SOCKET_ERROR); + + cli_sock = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, + grpc_get_default_wsa_socket_flags()); + GPR_ASSERT(cli_sock != INVALID_SOCKET); + + GPR_ASSERT(WSAConnect(cli_sock, (grpc_sockaddr*)&addr, addr_len, NULL, NULL, + NULL, NULL) == 0); + svr_sock = accept(lst_sock, (grpc_sockaddr*)&addr, &addr_len); + GPR_ASSERT(svr_sock != INVALID_SOCKET); + + closesocket(lst_sock); + grpc_tcp_prepare_socket(cli_sock); + grpc_tcp_prepare_socket(svr_sock); + + sv[1] = cli_sock; + sv[0] = svr_sock; +} + +grpc_endpoint_pair grpc_iomgr_create_endpoint_pair( + const char* name, grpc_channel_args* channel_args) { + SOCKET sv[2]; + grpc_endpoint_pair p; + create_sockets(sv); + grpc_core::ExecCtx exec_ctx; + grpc_resource_quota* resource_quota = + grpc_resource_quota_from_channel_args(channel_args, true); + p.client = + grpc_tcp_create(grpc_winsocket_create(sv[1], "endpoint:client"), + channel_args, "endpoint:server", + grpc_slice_allocator_create( + resource_quota, "endpoint:server", channel_args)); + p.server = + grpc_tcp_create(grpc_winsocket_create(sv[0], "endpoint:server"), + channel_args, "endpoint:client", + grpc_slice_allocator_create( + resource_quota, "endpoint:client", channel_args)); + grpc_resource_quota_unref_internal(resource_quota); + return p; +} + +#endif diff --git a/src/core/lib/iomgr/error.cc b/src/core/lib/iomgr/error.cc new file mode 100644 index 00000000..2bf0d72f --- /dev/null +++ b/src/core/lib/iomgr/error.cc @@ -0,0 +1,985 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
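Windows has no socketpair(), so create_sockets() in endpoint_pair_windows.cc above emulates one: listen on an ephemeral loopback port, connect to it, and accept, which leaves two connected sockets. A stripped-down sketch of the same pattern with plain Winsock calls (MakeLoopbackPair is a hypothetical helper; minimal error handling):

#include <cassert>
#include <winsock2.h>
// Link with ws2_32.lib.

// Emulate socketpair() on Windows: listen on 127.0.0.1:<ephemeral>,
// connect a client socket to it, then accept the server side.
static void MakeLoopbackPair(SOCKET sv[2]) {
  SOCKET listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  assert(listener != INVALID_SOCKET);

  sockaddr_in addr = {};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  addr.sin_port = 0;  // let the OS pick a free port
  int rc = bind(listener, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
  assert(rc == 0);
  rc = listen(listener, 1);
  assert(rc == 0);
  int addr_len = sizeof(addr);
  rc = getsockname(listener, reinterpret_cast<sockaddr*>(&addr), &addr_len);
  assert(rc == 0);

  SOCKET client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  assert(client != INVALID_SOCKET);
  rc = connect(client, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
  assert(rc == 0);

  sv[0] = accept(listener, nullptr, nullptr);  // server side of the pair
  assert(sv[0] != INVALID_SOCKET);
  sv[1] = client;                              // client side of the pair
  closesocket(listener);
  (void)rc;
}

int main() {
  WSADATA wsa;
  int rc = WSAStartup(MAKEWORD(2, 2), &wsa);
  assert(rc == 0);
  SOCKET sv[2];
  MakeLoopbackPair(sv);
  closesocket(sv[0]);
  closesocket(sv[1]);
  WSACleanup();
  (void)rc;
}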
+ * + */ +#include + +#include "src/core/lib/iomgr/error.h" + +#include +#include + +#include +#include +#include +#include + +#ifdef GPR_WINDOWS +#include +#endif + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/error_internal.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_utils.h" + +grpc_core::DebugOnlyTraceFlag grpc_trace_error_refcount(false, + "error_refcount"); +grpc_core::DebugOnlyTraceFlag grpc_trace_closure(false, "closure"); + +static gpr_atm g_error_creation_allowed = true; + +void grpc_disable_error_creation() { + gpr_atm_no_barrier_store(&g_error_creation_allowed, false); +} + +void grpc_enable_error_creation() { + gpr_atm_no_barrier_store(&g_error_creation_allowed, true); +} + +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + +absl::Status grpc_status_create(absl::StatusCode code, absl::string_view msg, + const grpc_core::DebugLocation& location, + size_t children_count, absl::Status* children) { + absl::Status s = StatusCreate(code, msg, location, {}); + for (size_t i = 0; i < children_count; ++i) { + if (!children[i].ok()) { + grpc_core::StatusAddChild(&s, children[i]); + } + } + return s; +} + +std::string grpc_error_std_string(absl::Status error) { + return grpc_core::StatusToString(error); +} + +absl::Status grpc_os_error(const grpc_core::DebugLocation& location, int err, + const char* call_name) { + absl::Status s = + StatusCreate(absl::StatusCode::kUnknown, "OS Error", location, {}); + grpc_core::StatusSetInt(&s, grpc_core::StatusIntProperty::kErrorNo, err); + grpc_core::StatusSetStr(&s, grpc_core::StatusStrProperty::kOsError, + strerror(err)); + grpc_core::StatusSetStr(&s, grpc_core::StatusStrProperty::kSyscall, + call_name); + return s; +} + +#ifdef GPR_WINDOWS +absl::Status grpc_wsa_error(const grpc_core::DebugLocation& location, int err, + const char* call_name) { + char* utf8_message = gpr_format_message(err); + absl::Status s = + StatusCreate(absl::StatusCode::kUnknown, "WSA Error", location, {}); + StatusSetInt(&s, grpc_core::StatusIntProperty::kWsaError, err); + StatusSetStr(&s, grpc_core::StatusStrProperty::kOsError, utf8_message); + StatusSetStr(&s, grpc_core::StatusStrProperty::kSyscall, call_name); + return s; +} +#endif + +grpc_error_handle grpc_error_set_int(grpc_error_handle src, + grpc_error_ints which, intptr_t value) { + if (src == GRPC_ERROR_NONE) { + src = absl::UnknownError(""); + StatusSetInt(&src, grpc_core::StatusIntProperty::kRpcStatus, + GRPC_STATUS_OK); + } + grpc_core::StatusSetInt( + &src, static_cast(which), value); + return src; +} + +bool grpc_error_get_int(grpc_error_handle error, grpc_error_ints which, + intptr_t* p) { + absl::optional value = grpc_core::StatusGetInt( + error, static_cast(which)); + if (value.has_value()) { + *p = *value; + return true; + } else { + // TODO(veblush): Remove this once absl::Status migration is done + if (which == GRPC_ERROR_INT_GRPC_STATUS) { + switch (error.code()) { + case absl::StatusCode::kOk: + *p = GRPC_STATUS_OK; + return true; + case absl::StatusCode::kResourceExhausted: + *p = GRPC_STATUS_RESOURCE_EXHAUSTED; + return true; + case absl::StatusCode::kCancelled: + *p = GRPC_STATUS_CANCELLED; + return true; + default: + break; + } + } + return false; + } +} + +grpc_error_handle grpc_error_set_str(grpc_error_handle src, + grpc_error_strs which, + absl::string_view str) { + if (src == GRPC_ERROR_NONE) { + src = absl::UnknownError(""); + StatusSetInt(&src, grpc_core::StatusIntProperty::kRpcStatus, + 
GRPC_STATUS_OK); + } + if (which == GRPC_ERROR_STR_DESCRIPTION) { + // To change the message of absl::Status, a new instance should be created + // with a code and payload because it doesn't have a setter for it. + absl::Status s = absl::Status(src.code(), str); + src.ForEachPayload( + [&](absl::string_view type_url, const absl::Cord& payload) { + s.SetPayload(type_url, payload); + }); + return s; + } else { + grpc_core::StatusSetStr( + &src, static_cast(which), str); + } + return src; +} + +bool grpc_error_get_str(grpc_error_handle error, grpc_error_strs which, + std::string* s) { + if (which == GRPC_ERROR_STR_DESCRIPTION) { + // absl::Status uses the message field for GRPC_ERROR_STR_DESCRIPTION + // instead of using payload. + absl::string_view msg = error.message(); + if (msg.empty()) { + return false; + } else { + *s = std::string(msg); + return true; + } + } else { + absl::optional value = grpc_core::StatusGetStr( + error, static_cast(which)); + if (value.has_value()) { + *s = std::move(*value); + return true; + } else { + // TODO(veblush): Remove this once absl::Status migration is done + if (which == GRPC_ERROR_STR_GRPC_MESSAGE) { + switch (error.code()) { + case absl::StatusCode::kOk: + *s = ""; + return true; + case absl::StatusCode::kResourceExhausted: + *s = "RESOURCE_EXHAUSTED"; + return true; + case absl::StatusCode::kCancelled: + *s = "CANCELLED"; + return true; + default: + break; + } + } + return false; + } + } +} + +grpc_error_handle grpc_error_add_child(grpc_error_handle src, + grpc_error_handle child) { + if (src.ok()) { + return child; + } else { + if (!child.ok()) { + grpc_core::StatusAddChild(&src, child); + } + return src; + } +} + +bool grpc_log_error(const char* what, grpc_error_handle error, const char* file, + int line) { + GPR_DEBUG_ASSERT(error != GRPC_ERROR_NONE); + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, "%s: %s", what, + grpc_core::StatusToString(error).c_str()); + return false; +} + +#else // GRPC_ERROR_IS_ABSEIL_STATUS + +static const char* error_int_name(grpc_error_ints key) { + switch (key) { + case GRPC_ERROR_INT_ERRNO: + return "errno"; + case GRPC_ERROR_INT_FILE_LINE: + return "file_line"; + case GRPC_ERROR_INT_STREAM_ID: + return "stream_id"; + case GRPC_ERROR_INT_GRPC_STATUS: + return "grpc_status"; + case GRPC_ERROR_INT_OFFSET: + return "offset"; + case GRPC_ERROR_INT_INDEX: + return "index"; + case GRPC_ERROR_INT_SIZE: + return "size"; + case GRPC_ERROR_INT_HTTP2_ERROR: + return "http2_error"; + case GRPC_ERROR_INT_TSI_CODE: + return "tsi_code"; + case GRPC_ERROR_INT_FD: + return "fd"; + case GRPC_ERROR_INT_WSA_ERROR: + return "wsa_error"; + case GRPC_ERROR_INT_HTTP_STATUS: + return "http_status"; + case GRPC_ERROR_INT_OCCURRED_DURING_WRITE: + return "occurred_during_write"; + case GRPC_ERROR_INT_CHANNEL_CONNECTIVITY_STATE: + return "channel_connectivity_state"; + case GRPC_ERROR_INT_LB_POLICY_DROP: + return "lb_policy_drop"; + case GRPC_ERROR_INT_MAX: + GPR_UNREACHABLE_CODE(return "unknown"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +static const char* error_str_name(grpc_error_strs key) { + switch (key) { + case GRPC_ERROR_STR_KEY: + return "key"; + case GRPC_ERROR_STR_VALUE: + return "value"; + case GRPC_ERROR_STR_DESCRIPTION: + return "description"; + case GRPC_ERROR_STR_OS_ERROR: + return "os_error"; + case GRPC_ERROR_STR_TARGET_ADDRESS: + return "target_address"; + case GRPC_ERROR_STR_SYSCALL: + return "syscall"; + case GRPC_ERROR_STR_FILE: + return "file"; + case GRPC_ERROR_STR_GRPC_MESSAGE: + return "grpc_message"; + case 
GRPC_ERROR_STR_RAW_BYTES: + return "raw_bytes"; + case GRPC_ERROR_STR_TSI_ERROR: + return "tsi_error"; + case GRPC_ERROR_STR_FILENAME: + return "filename"; + case GRPC_ERROR_STR_MAX: + GPR_UNREACHABLE_CODE(return "unknown"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +static const char* error_time_name(grpc_error_times key) { + switch (key) { + case GRPC_ERROR_TIME_CREATED: + return "created"; + case GRPC_ERROR_TIME_MAX: + GPR_UNREACHABLE_CODE(return "unknown"); + } + GPR_UNREACHABLE_CODE(return "unknown"); +} + +#ifndef NDEBUG +grpc_error_handle grpc_error_do_ref(grpc_error_handle err, const char* file, + int line) { + if (grpc_trace_error_refcount.enabled()) { + gpr_log(GPR_DEBUG, "%p: %" PRIdPTR " -> %" PRIdPTR " [%s:%d]", err, + gpr_atm_no_barrier_load(&err->atomics.refs.count), + gpr_atm_no_barrier_load(&err->atomics.refs.count) + 1, file, line); + } + gpr_ref(&err->atomics.refs); + return err; +} +#else +grpc_error_handle grpc_error_do_ref(grpc_error_handle err) { + gpr_ref(&err->atomics.refs); + return err; +} +#endif + +static void unref_errs(grpc_error_handle err) { + uint8_t slot = err->first_err; + while (slot != UINT8_MAX) { + grpc_linked_error* lerr = + reinterpret_cast(err->arena + slot); + GRPC_ERROR_UNREF(lerr->err); + GPR_ASSERT(err->last_err == slot ? lerr->next == UINT8_MAX + : lerr->next != UINT8_MAX); + slot = lerr->next; + } +} + +static void unref_strs(grpc_error_handle err) { + for (size_t which = 0; which < GRPC_ERROR_STR_MAX; ++which) { + uint8_t slot = err->strs[which]; + if (slot != UINT8_MAX) { + grpc_slice_unref_internal( + *reinterpret_cast(err->arena + slot)); + } + } +} + +static void error_destroy(grpc_error_handle err) { + GPR_ASSERT(!grpc_error_is_special(err)); + unref_errs(err); + unref_strs(err); + gpr_free( + reinterpret_cast(gpr_atm_acq_load(&err->atomics.error_string))); + gpr_free(err); +} + +#ifndef NDEBUG +void grpc_error_do_unref(grpc_error_handle err, const char* file, int line) { + if (grpc_trace_error_refcount.enabled()) { + gpr_log(GPR_DEBUG, "%p: %" PRIdPTR " -> %" PRIdPTR " [%s:%d]", err, + gpr_atm_no_barrier_load(&err->atomics.refs.count), + gpr_atm_no_barrier_load(&err->atomics.refs.count) - 1, file, line); + } + if (gpr_unref(&err->atomics.refs)) { + error_destroy(err); + } +} +#else +void grpc_error_do_unref(grpc_error_handle err) { + if (gpr_unref(&err->atomics.refs)) { + error_destroy(err); + } +} +#endif + +static uint8_t get_placement(grpc_error_handle* err, size_t size) { + GPR_ASSERT(*err); + uint8_t slots = static_cast(size / sizeof(intptr_t)); + if ((*err)->arena_size + slots > (*err)->arena_capacity) { + (*err)->arena_capacity = static_cast(std::min( + size_t(UINT8_MAX - 1), size_t(3 * (*err)->arena_capacity / 2))); + if ((*err)->arena_size + slots > (*err)->arena_capacity) { + return UINT8_MAX; + } +#ifndef NDEBUG + grpc_error_handle orig = *err; +#endif + *err = static_cast(gpr_realloc( + *err, sizeof(grpc_error) + (*err)->arena_capacity * sizeof(intptr_t))); +#ifndef NDEBUG + if (grpc_trace_error_refcount.enabled()) { + if (*err != orig) { + gpr_log(GPR_DEBUG, "realloc %p -> %p", orig, *err); + } + } +#endif + } + uint8_t placement = (*err)->arena_size; + (*err)->arena_size = static_cast((*err)->arena_size + slots); + return placement; +} + +static void internal_set_int(grpc_error_handle* err, grpc_error_ints which, + intptr_t value) { + uint8_t slot = (*err)->ints[which]; + if (slot == UINT8_MAX) { + slot = get_placement(err, sizeof(value)); + if (slot == UINT8_MAX) { + gpr_log(GPR_ERROR, "Error %p is full, 
dropping int {\"%s\":%" PRIiPTR "}", + *err, error_int_name(which), value); + return; + } + } + (*err)->ints[which] = slot; + (*err)->arena[slot] = value; +} + +static void internal_set_str(grpc_error_handle* err, grpc_error_strs which, + const grpc_slice& value) { + uint8_t slot = (*err)->strs[which]; + if (slot == UINT8_MAX) { + slot = get_placement(err, sizeof(value)); + if (slot == UINT8_MAX) { + char* str = grpc_slice_to_c_string(value); + gpr_log(GPR_ERROR, "Error %p is full, dropping string {\"%s\":\"%s\"}", + *err, error_str_name(which), str); + gpr_free(str); + return; + } + } else { + grpc_slice_unref_internal( + *reinterpret_cast((*err)->arena + slot)); + } + (*err)->strs[which] = slot; + memcpy((*err)->arena + slot, &value, sizeof(value)); +} + +static char* fmt_time(gpr_timespec tm); +static void internal_set_time(grpc_error_handle* err, grpc_error_times which, + gpr_timespec value) { + uint8_t slot = (*err)->times[which]; + if (slot == UINT8_MAX) { + slot = get_placement(err, sizeof(value)); + if (slot == UINT8_MAX) { + char* time_str = fmt_time(value); + gpr_log(GPR_ERROR, "Error %p is full, dropping \"%s\":\"%s\"}", *err, + error_time_name(which), time_str); + gpr_free(time_str); + return; + } + } + (*err)->times[which] = slot; + memcpy((*err)->arena + slot, &value, sizeof(value)); +} + +static void internal_add_error(grpc_error_handle* err, + grpc_error_handle new_err) { + grpc_linked_error new_last = {new_err, UINT8_MAX}; + uint8_t slot = get_placement(err, sizeof(grpc_linked_error)); + if (slot == UINT8_MAX) { + gpr_log(GPR_ERROR, "Error %p is full, dropping error %p = %s", *err, + new_err, grpc_error_string(new_err)); + GRPC_ERROR_UNREF(new_err); + return; + } + if ((*err)->first_err == UINT8_MAX) { + GPR_ASSERT((*err)->last_err == UINT8_MAX); + (*err)->last_err = slot; + (*err)->first_err = slot; + } else { + GPR_ASSERT((*err)->last_err != UINT8_MAX); + grpc_linked_error* old_last = + reinterpret_cast((*err)->arena + (*err)->last_err); + old_last->next = slot; + (*err)->last_err = slot; + } + memcpy((*err)->arena + slot, &new_last, sizeof(grpc_linked_error)); +} + +#define SLOTS_PER_INT (1) // == (sizeof(intptr_t) / sizeof(intptr_t)) +#define SLOTS_PER_STR (sizeof(grpc_slice) / sizeof(intptr_t)) +#define SLOTS_PER_TIME (sizeof(gpr_timespec) / sizeof(intptr_t)) +#define SLOTS_PER_LINKED_ERROR (sizeof(grpc_linked_error) / sizeof(intptr_t)) + +// size of storing one int and two slices and a timespec. 
For line, desc, file, +// and time created +#define DEFAULT_ERROR_CAPACITY \ + (SLOTS_PER_INT + (SLOTS_PER_STR * 2) + SLOTS_PER_TIME) + +// It is very common to include and extra int and string in an error +#define SURPLUS_CAPACITY (2 * SLOTS_PER_INT + SLOTS_PER_TIME) + +grpc_error_handle grpc_error_create(const char* file, int line, + const grpc_slice& desc, + grpc_error_handle* referencing, + size_t num_referencing) { + uint8_t initial_arena_capacity = static_cast( + DEFAULT_ERROR_CAPACITY + + static_cast(num_referencing * SLOTS_PER_LINKED_ERROR) + + SURPLUS_CAPACITY); + grpc_error_handle err = static_cast( + gpr_malloc(sizeof(*err) + initial_arena_capacity * sizeof(intptr_t))); + if (err == nullptr) { // TODO(ctiller): make gpr_malloc return NULL + return GRPC_ERROR_OOM; + } +#ifndef NDEBUG + if (!gpr_atm_no_barrier_load(&g_error_creation_allowed)) { + gpr_log(GPR_ERROR, + "Error creation occurred when error creation was disabled [%s:%d]", + file, line); + abort(); + } + if (grpc_trace_error_refcount.enabled()) { + gpr_log(GPR_DEBUG, "%p create [%s:%d]", err, file, line); + } +#endif + + err->arena_size = 0; + err->arena_capacity = initial_arena_capacity; + err->first_err = UINT8_MAX; + err->last_err = UINT8_MAX; + + memset(err->ints, UINT8_MAX, GRPC_ERROR_INT_MAX); + memset(err->strs, UINT8_MAX, GRPC_ERROR_STR_MAX); + memset(err->times, UINT8_MAX, GRPC_ERROR_TIME_MAX); + + internal_set_int(&err, GRPC_ERROR_INT_FILE_LINE, line); + internal_set_str(&err, GRPC_ERROR_STR_FILE, + grpc_slice_from_static_string(file)); + internal_set_str(&err, GRPC_ERROR_STR_DESCRIPTION, desc); + + for (size_t i = 0; i < num_referencing; ++i) { + if (referencing[i] == GRPC_ERROR_NONE) continue; + internal_add_error( + &err, + GRPC_ERROR_REF( + referencing[i])); // TODO(ncteisen), change ownership semantics + } + + internal_set_time(&err, GRPC_ERROR_TIME_CREATED, gpr_now(GPR_CLOCK_REALTIME)); + + gpr_atm_no_barrier_store(&err->atomics.error_string, 0); + gpr_ref_init(&err->atomics.refs, 1); + return err; +} + +static void ref_strs(grpc_error_handle err) { + for (size_t i = 0; i < GRPC_ERROR_STR_MAX; ++i) { + uint8_t slot = err->strs[i]; + if (slot != UINT8_MAX) { + grpc_slice_ref_internal( + *reinterpret_cast(err->arena + slot)); + } + } +} + +static void ref_errs(grpc_error_handle err) { + uint8_t slot = err->first_err; + while (slot != UINT8_MAX) { + grpc_linked_error* lerr = + reinterpret_cast(err->arena + slot); + (void)GRPC_ERROR_REF(lerr->err); + slot = lerr->next; + } +} + +static grpc_error_handle copy_error_and_unref(grpc_error_handle in) { + grpc_error_handle out; + if (grpc_error_is_special(in)) { + out = GRPC_ERROR_CREATE_FROM_STATIC_STRING("unknown"); + if (in == GRPC_ERROR_NONE) { + internal_set_str(&out, GRPC_ERROR_STR_DESCRIPTION, + grpc_slice_from_static_string("no error")); + internal_set_int(&out, GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_OK); + } else if (in == GRPC_ERROR_OOM) { + internal_set_str(&out, GRPC_ERROR_STR_DESCRIPTION, + grpc_slice_from_static_string("oom")); + } else if (in == GRPC_ERROR_CANCELLED) { + internal_set_str(&out, GRPC_ERROR_STR_DESCRIPTION, + grpc_slice_from_static_string("cancelled")); + internal_set_int(&out, GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_CANCELLED); + } + } else if (gpr_ref_is_unique(&in->atomics.refs)) { + out = in; + } else { + uint8_t new_arena_capacity = in->arena_capacity; + // the returned err will be added to, so we ensure this is room to avoid + // unneeded allocations. 
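The legacy grpc_error above keeps all attributes in a single allocation: a fixed header plus an intptr_t arena, with a uint8_t slot index per key (UINT8_MAX meaning unset) and get_placement() growing the arena by roughly 1.5x, capped so slot indices still fit in a byte. A much simplified, standalone illustration of the slot-index scheme (TinyError and IntKey are made-up names; ints only, no refcounting or growth cap):

#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified version of the slot scheme used by grpc_error above:
// each key maps to a slot index into a flat intptr_t arena, with
// UINT8_MAX meaning "no value recorded for this key".
enum IntKey { kErrno, kGrpcStatus, kHttpStatus, kNumIntKeys };

struct TinyError {
  uint8_t ints[kNumIntKeys];
  std::vector<intptr_t> arena;  // the real code grows a malloc'd block instead

  TinyError() { for (auto& s : ints) s = UINT8_MAX; }

  void SetInt(IntKey key, intptr_t value) {
    if (ints[key] == UINT8_MAX) {
      ints[key] = static_cast<uint8_t>(arena.size());
      arena.push_back(value);
    } else {
      arena[ints[key]] = value;
    }
  }

  bool GetInt(IntKey key, intptr_t* out) const {
    if (ints[key] == UINT8_MAX) return false;
    *out = arena[ints[key]];
    return true;
  }
};

int main() {
  TinyError err;
  err.SetInt(kErrno, 104);
  err.SetInt(kGrpcStatus, 14);
  intptr_t v;
  if (err.GetInt(kErrno, &v)) std::printf("errno slot holds %ld\n", static_cast<long>(v));
  if (!err.GetInt(kHttpStatus, &v)) std::printf("http_status unset\n");
}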
+ if (in->arena_capacity - in->arena_size < + static_cast SLOTS_PER_STR) { + new_arena_capacity = static_cast(3 * new_arena_capacity / 2); + } + out = static_cast( + gpr_malloc(sizeof(*in) + new_arena_capacity * sizeof(intptr_t))); +#ifndef NDEBUG + if (grpc_trace_error_refcount.enabled()) { + gpr_log(GPR_DEBUG, "%p create copying %p", out, in); + } +#endif + // bulk memcpy of the rest of the struct. + // NOLINTNEXTLINE(bugprone-sizeof-expression) + size_t skip = sizeof(&out->atomics); + memcpy(reinterpret_cast(reinterpret_cast(out) + skip), + reinterpret_cast(reinterpret_cast(in) + skip), + sizeof(*in) + (in->arena_size * sizeof(intptr_t)) - skip); + // manually set the atomics and the new capacity + gpr_atm_no_barrier_store(&out->atomics.error_string, 0); + gpr_ref_init(&out->atomics.refs, 1); + out->arena_capacity = new_arena_capacity; + ref_strs(out); + ref_errs(out); + GRPC_ERROR_UNREF(in); + } + return out; +} + +grpc_error_handle grpc_error_set_int(grpc_error_handle src, + grpc_error_ints which, intptr_t value) { + grpc_error_handle new_err = copy_error_and_unref(src); + internal_set_int(&new_err, which, value); + return new_err; +} + +struct special_error_status_map { + grpc_status_code code; + const char* msg; + size_t len; +}; +const special_error_status_map error_status_map[] = { + {GRPC_STATUS_OK, "", 0}, // GRPC_ERROR_NONE + {GRPC_STATUS_INVALID_ARGUMENT, "", 0}, // GRPC_ERROR_RESERVED_1 + {GRPC_STATUS_RESOURCE_EXHAUSTED, "RESOURCE_EXHAUSTED", + strlen("RESOURCE_EXHAUSTED")}, // GRPC_ERROR_OOM + {GRPC_STATUS_INVALID_ARGUMENT, "", 0}, // GRPC_ERROR_RESERVED_2 + {GRPC_STATUS_CANCELLED, "CANCELLED", + strlen("CANCELLED")}, // GRPC_ERROR_CANCELLED +}; + +bool grpc_error_get_int(grpc_error_handle err, grpc_error_ints which, + intptr_t* p) { + if (grpc_error_is_special(err)) { + if (which != GRPC_ERROR_INT_GRPC_STATUS) return false; + *p = error_status_map[reinterpret_cast(err)].code; + return true; + } + uint8_t slot = err->ints[which]; + if (slot != UINT8_MAX) { + if (p != nullptr) *p = err->arena[slot]; + return true; + } + return false; +} + +grpc_error_handle grpc_error_set_str(grpc_error_handle src, + grpc_error_strs which, + absl::string_view str) { + grpc_error_handle new_err = copy_error_and_unref(src); + internal_set_str(&new_err, which, + grpc_slice_from_copied_buffer(str.data(), str.length())); + return new_err; +} + +bool grpc_error_get_str(grpc_error_handle err, grpc_error_strs which, + std::string* s) { + if (grpc_error_is_special(err)) { + if (which != GRPC_ERROR_STR_GRPC_MESSAGE) return false; + const special_error_status_map& msg = + error_status_map[reinterpret_cast(err)]; + *s = std::string(msg.msg, msg.len); + return true; + } + uint8_t slot = err->strs[which]; + if (slot != UINT8_MAX) { + grpc_slice* slice = reinterpret_cast(err->arena + slot); + *s = std::string(grpc_core::StringViewFromSlice(*slice)); + return true; + } else { + return false; + } +} + +grpc_error_handle grpc_error_add_child(grpc_error_handle src, + grpc_error_handle child) { + if (src != GRPC_ERROR_NONE) { + if (child == GRPC_ERROR_NONE) { + /* \a child is empty. Simply return the ref to \a src */ + return src; + } else if (child != src) { + grpc_error_handle new_err = copy_error_and_unref(src); + internal_add_error(&new_err, child); + return new_err; + } else { + /* \a src and \a child are the same. Drop one of the references and return + * the other */ + GRPC_ERROR_UNREF(child); + return src; + } + } else { + /* \a src is empty. 
Simply return the ref to \a child */ + return child; + } +} + +static const char* no_error_string = "\"OK\""; +static const char* oom_error_string = "\"RESOURCE_EXHAUSTED\""; +static const char* cancelled_error_string = "\"CANCELLED\""; + +struct kv_pair { + char* key; + char* value; +}; +struct kv_pairs { + kv_pair* kvs; + size_t num_kvs; + size_t cap_kvs; +}; +static void append_chr(char c, char** s, size_t* sz, size_t* cap) { + if (*sz == *cap) { + *cap = std::max(size_t(8), 3 * *cap / 2); + *s = static_cast(gpr_realloc(*s, *cap)); + } + (*s)[(*sz)++] = c; +} + +static void append_str(const char* str, char** s, size_t* sz, size_t* cap) { + for (const char* c = str; *c; c++) { + append_chr(*c, s, sz, cap); + } +} + +static void append_esc_str(const uint8_t* str, size_t len, char** s, size_t* sz, + size_t* cap) { + static const char* hex = "0123456789abcdef"; + append_chr('"', s, sz, cap); + for (size_t i = 0; i < len; i++, str++) { + if (*str < 32 || *str >= 127) { + append_chr('\\', s, sz, cap); + switch (*str) { + case '\b': + append_chr('b', s, sz, cap); + break; + case '\f': + append_chr('f', s, sz, cap); + break; + case '\n': + append_chr('n', s, sz, cap); + break; + case '\r': + append_chr('r', s, sz, cap); + break; + case '\t': + append_chr('t', s, sz, cap); + break; + default: + append_chr('u', s, sz, cap); + append_chr('0', s, sz, cap); + append_chr('0', s, sz, cap); + append_chr(hex[*str >> 4], s, sz, cap); + append_chr(hex[*str & 0x0f], s, sz, cap); + break; + } + } else { + append_chr(static_cast(*str), s, sz, cap); + } + } + append_chr('"', s, sz, cap); +} + +static void append_kv(kv_pairs* kvs, char* key, char* value) { + if (kvs->num_kvs == kvs->cap_kvs) { + kvs->cap_kvs = std::max(3 * kvs->cap_kvs / 2, size_t(4)); + kvs->kvs = static_cast( + gpr_realloc(kvs->kvs, sizeof(*kvs->kvs) * kvs->cap_kvs)); + } + kvs->kvs[kvs->num_kvs].key = key; + kvs->kvs[kvs->num_kvs].value = value; + kvs->num_kvs++; +} + +static char* key_int(grpc_error_ints which) { + return gpr_strdup(error_int_name(which)); +} + +static char* fmt_int(intptr_t p) { + char* s; + gpr_asprintf(&s, "%" PRIdPTR, p); + return s; +} + +static void collect_ints_kvs(grpc_error_handle err, kv_pairs* kvs) { + for (size_t which = 0; which < GRPC_ERROR_INT_MAX; ++which) { + uint8_t slot = err->ints[which]; + if (slot != UINT8_MAX) { + append_kv(kvs, key_int(static_cast(which)), + fmt_int(err->arena[slot])); + } + } +} + +static char* key_str(grpc_error_strs which) { + return gpr_strdup(error_str_name(which)); +} + +static char* fmt_str(const grpc_slice& slice) { + char* s = nullptr; + size_t sz = 0; + size_t cap = 0; + append_esc_str(GRPC_SLICE_START_PTR(slice), GRPC_SLICE_LENGTH(slice), &s, &sz, + &cap); + append_chr(0, &s, &sz, &cap); + return s; +} + +static void collect_strs_kvs(grpc_error_handle err, kv_pairs* kvs) { + for (size_t which = 0; which < GRPC_ERROR_STR_MAX; ++which) { + uint8_t slot = err->strs[which]; + if (slot != UINT8_MAX) { + append_kv(kvs, key_str(static_cast(which)), + fmt_str(*reinterpret_cast(err->arena + slot))); + } + } +} + +static char* key_time(grpc_error_times which) { + return gpr_strdup(error_time_name(which)); +} + +static char* fmt_time(gpr_timespec tm) { + char* out; + const char* pfx = "!!"; + switch (tm.clock_type) { + case GPR_CLOCK_MONOTONIC: + pfx = "@monotonic:"; + break; + case GPR_CLOCK_REALTIME: + pfx = "@"; + break; + case GPR_CLOCK_PRECISE: + pfx = "@precise:"; + break; + case GPR_TIMESPAN: + pfx = ""; + break; + } + gpr_asprintf(&out, "\"%s%" PRId64 ".%09d\"", pfx, 
tm.tv_sec, tm.tv_nsec); + return out; +} + +static void collect_times_kvs(grpc_error_handle err, kv_pairs* kvs) { + for (size_t which = 0; which < GRPC_ERROR_TIME_MAX; ++which) { + uint8_t slot = err->times[which]; + if (slot != UINT8_MAX) { + append_kv(kvs, key_time(static_cast(which)), + fmt_time(*reinterpret_cast(err->arena + slot))); + } + } +} + +static void add_errs(grpc_error_handle err, char** s, size_t* sz, size_t* cap) { + uint8_t slot = err->first_err; + bool first = true; + while (slot != UINT8_MAX) { + grpc_linked_error* lerr = + reinterpret_cast(err->arena + slot); + if (!first) append_chr(',', s, sz, cap); + first = false; + const char* e = grpc_error_string(lerr->err); + append_str(e, s, sz, cap); + GPR_ASSERT(err->last_err == slot ? lerr->next == UINT8_MAX + : lerr->next != UINT8_MAX); + slot = lerr->next; + } +} + +static char* errs_string(grpc_error_handle err) { + char* s = nullptr; + size_t sz = 0; + size_t cap = 0; + append_chr('[', &s, &sz, &cap); + add_errs(err, &s, &sz, &cap); + append_chr(']', &s, &sz, &cap); + append_chr(0, &s, &sz, &cap); + return s; +} + +static int cmp_kvs(const void* a, const void* b) { + const kv_pair* ka = static_cast(a); + const kv_pair* kb = static_cast(b); + return strcmp(ka->key, kb->key); +} + +static char* finish_kvs(kv_pairs* kvs) { + char* s = nullptr; + size_t sz = 0; + size_t cap = 0; + + append_chr('{', &s, &sz, &cap); + for (size_t i = 0; i < kvs->num_kvs; i++) { + if (i != 0) append_chr(',', &s, &sz, &cap); + append_esc_str(reinterpret_cast(kvs->kvs[i].key), + strlen(kvs->kvs[i].key), &s, &sz, &cap); + gpr_free(kvs->kvs[i].key); + append_chr(':', &s, &sz, &cap); + append_str(kvs->kvs[i].value, &s, &sz, &cap); + gpr_free(kvs->kvs[i].value); + } + append_chr('}', &s, &sz, &cap); + append_chr(0, &s, &sz, &cap); + + gpr_free(kvs->kvs); + return s; +} + +const char* grpc_error_string(grpc_error_handle err) { + if (err == GRPC_ERROR_NONE) return no_error_string; + if (err == GRPC_ERROR_OOM) return oom_error_string; + if (err == GRPC_ERROR_CANCELLED) return cancelled_error_string; + + void* p = + reinterpret_cast(gpr_atm_acq_load(&err->atomics.error_string)); + if (p != nullptr) { + return static_cast(p); + } + + kv_pairs kvs; + memset(&kvs, 0, sizeof(kvs)); + + collect_ints_kvs(err, &kvs); + collect_strs_kvs(err, &kvs); + collect_times_kvs(err, &kvs); + if (err->first_err != UINT8_MAX) { + append_kv(&kvs, gpr_strdup("referenced_errors"), errs_string(err)); + } + + qsort(kvs.kvs, kvs.num_kvs, sizeof(kv_pair), cmp_kvs); + + char* out = finish_kvs(&kvs); + + if (!gpr_atm_rel_cas(&err->atomics.error_string, 0, + reinterpret_cast(out))) { + gpr_free(out); + out = reinterpret_cast(gpr_atm_acq_load(&err->atomics.error_string)); + } + + return out; +} + +std::string grpc_error_std_string(grpc_error_handle error) { + return std::string(grpc_error_string(error)); +} + +grpc_error_handle grpc_os_error(const char* file, int line, int err, + const char* call_name) { + return grpc_error_set_str( + grpc_error_set_str( + grpc_error_set_int( + grpc_error_create(file, line, + grpc_slice_from_static_string(strerror(err)), + nullptr, 0), + GRPC_ERROR_INT_ERRNO, err), + GRPC_ERROR_STR_OS_ERROR, strerror(err)), + GRPC_ERROR_STR_SYSCALL, call_name); +} + +#ifdef GPR_WINDOWS +grpc_error_handle grpc_wsa_error(const char* file, int line, int err, + const char* call_name) { + char* utf8_message = gpr_format_message(err); + grpc_error_handle error = grpc_error_set_str( + grpc_error_set_str( + grpc_error_set_int( + grpc_error_create(file, line, + 
grpc_slice_from_static_string("OS Error"), NULL, + 0), + GRPC_ERROR_INT_WSA_ERROR, err), + GRPC_ERROR_STR_OS_ERROR, utf8_message), + GRPC_ERROR_STR_SYSCALL, call_name); + gpr_free(utf8_message); + return error; +} +#endif + +bool grpc_log_error(const char* what, grpc_error_handle error, const char* file, + int line) { + GPR_DEBUG_ASSERT(error != GRPC_ERROR_NONE); + const char* msg = grpc_error_string(error); + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, "%s: %s", what, msg); + GRPC_ERROR_UNREF(error); + return false; +} + +#endif // GRPC_ERROR_IS_ABSEIL_STATUS diff --git a/src/core/lib/iomgr/error_cfstream.cc b/src/core/lib/iomgr/error_cfstream.cc new file mode 100644 index 00000000..84052aae --- /dev/null +++ b/src/core/lib/iomgr/error_cfstream.cc @@ -0,0 +1,59 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GRPC_CFSTREAM +#include + +#include + +#include "absl/strings/str_format.h" + +#include + +#include "src/core/lib/iomgr/error.h" + +#define MAX_ERROR_DESCRIPTION 256 + +grpc_error_handle grpc_error_create_from_cferror(const char* file, int line, + void* arg, + const char* custom_desc) { + CFErrorRef error = static_cast(arg); + char buf_domain[MAX_ERROR_DESCRIPTION]; + char buf_desc[MAX_ERROR_DESCRIPTION]; + CFErrorDomain domain = CFErrorGetDomain((error)); + CFIndex code = CFErrorGetCode((error)); + CFStringRef desc = CFErrorCopyDescription((error)); + CFStringGetCString(domain, buf_domain, MAX_ERROR_DESCRIPTION, + kCFStringEncodingUTF8); + CFStringGetCString(desc, buf_desc, MAX_ERROR_DESCRIPTION, + kCFStringEncodingUTF8); + std::string error_msg = + absl::StrFormat("%s (error domain:%s, code:%ld, description:%s)", + custom_desc, buf_domain, code, buf_desc); + CFRelease(desc); +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + return StatusCreate(absl::StatusCode::kUnknown, error_msg, + grpc_core::DebugLocation(file, line), {}); +#else + return grpc_error_create( + file, line, grpc_slice_from_copied_string(error_msg.c_str()), NULL, 0); +#endif +} +#endif /* GRPC_CFSTREAM */ diff --git a/src/core/lib/iomgr/ev_apple.cc b/src/core/lib/iomgr/ev_apple.cc new file mode 100644 index 00000000..25805826 --- /dev/null +++ b/src/core/lib/iomgr/ev_apple.cc @@ -0,0 +1,359 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/// Event engine based on Apple's CFRunLoop API family. 
If the CFRunLoop engine +/// is enabled (see iomgr_posix_cfstream.cc), a global thread is started to +/// handle and trigger all the CFStream events. The CFStream streams register +/// themselves with the run loop with functions grpc_apple_register_read_stream +/// and grpc_apple_register_read_stream. Pollsets are phony and block on a +/// condition variable in pollset_work(). + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_APPLE_EV + +#include + +#include + +#include "absl/time/time.h" + +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/gprpp/time_util.h" +#include "src/core/lib/iomgr/ev_apple.h" + +grpc_core::DebugOnlyTraceFlag grpc_apple_polling_trace(false, "apple_polling"); + +#ifndef NDEBUG +#define GRPC_POLLING_TRACE(format, ...) \ + if (GRPC_TRACE_FLAG_ENABLED(grpc_apple_polling_trace)) { \ + gpr_log(GPR_DEBUG, "(polling) " format, __VA_ARGS__); \ + } +#else +#define GRPC_POLLING_TRACE(...) +#endif // NDEBUG + +#define GRPC_POLLSET_KICK_BROADCAST ((grpc_pollset_worker*)1) + +struct GlobalRunLoopContext { + grpc_core::CondVar init_cv; + grpc_core::CondVar input_source_cv; + + grpc_core::Mutex mu; + + // Whether an input source registration is pending. Protected by mu. + bool input_source_registered = false; + + // The reference to the global run loop object. Protected by mu. + CFRunLoopRef run_loop; + + // Whether the pollset has been globally shut down. Protected by mu. + bool is_shutdown = false; +}; + +struct GrpcAppleWorker { + // The condition varible to kick the worker. Works with the pollset's lock + // (GrpcApplePollset.mu). + grpc_core::CondVar cv; + + // Whether the worker is kicked. Protected by the pollset's lock + // (GrpcApplePollset.mu). + bool kicked = false; +}; + +struct GrpcApplePollset { + grpc_core::Mutex mu; + + // Tracks the current workers in the pollset. Protected by mu. + std::list workers; + + // Whether the pollset is shut down. Protected by mu. + bool is_shutdown = false; + + // Closure to call when shutdown is done. Protected by mu. + grpc_closure* shutdown_closure; + + // Whether there's an outstanding kick that was not processed. Protected by + // mu. + bool kicked_without_poller = false; +}; + +static GlobalRunLoopContext* gGlobalRunLoopContext = nullptr; +static grpc_core::Thread* gGlobalRunLoopThread = nullptr; + +/// Register the stream with the dispatch queue. Callbacks of the stream will be +/// issued to the dispatch queue when a network event happens and will be +/// managed by Grand Central Dispatch. +static void grpc_apple_register_read_stream_queue( + CFReadStreamRef read_stream, dispatch_queue_t dispatch_queue) { + CFReadStreamSetDispatchQueue(read_stream, dispatch_queue); +} + +/// Register the stream with the dispatch queue. Callbacks of the stream will be +/// issued to the dispatch queue when a network event happens and will be +/// managed by Grand Central Dispatch. +static void grpc_apple_register_write_stream_queue( + CFWriteStreamRef write_stream, dispatch_queue_t dispatch_queue) { + CFWriteStreamSetDispatchQueue(write_stream, dispatch_queue); +} + +/// Register the stream with the global run loop. Callbacks of the stream will +/// be issued to the run loop when a network event happens and will be driven by +/// the global run loop thread gGlobalRunLoopThread. 
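ev_apple routes stream registration through file-scope function pointers: they default to the dispatch-queue helpers above and are flipped to the run-loop variants once the CFRunLoop pollset is initialized. The pick-the-strategy-once-at-init pattern, reduced to a standalone sketch (RegisterWithQueue, RegisterWithRunLoop and EnableRunLoopPolling are illustrative names):

#include <cstdio>

// Default strategy: hand the work to a background queue.
static void RegisterWithQueue(int stream_id) {
  std::printf("stream %d -> dispatch queue\n", stream_id);
}

// Alternative strategy: schedule the work on a global run loop.
static void RegisterWithRunLoop(int stream_id) {
  std::printf("stream %d -> global run loop\n", stream_id);
}

// File-scope function pointer, defaulted to the queue-based strategy,
// mirroring grpc_apple_register_read_stream_impl above.
static void (*register_stream_impl)(int) = RegisterWithQueue;

void EnableRunLoopPolling() { register_stream_impl = RegisterWithRunLoop; }

void RegisterStream(int stream_id) { register_stream_impl(stream_id); }

int main() {
  RegisterStream(1);        // goes to the dispatch queue
  EnableRunLoopPolling();   // analogous to what pollset_global_init() does
  RegisterStream(2);        // now goes to the run loop
}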
+static void grpc_apple_register_read_stream_run_loop( + CFReadStreamRef read_stream, dispatch_queue_t dispatch_queue) { + GRPC_POLLING_TRACE("Register read stream: %p", read_stream); + grpc_core::MutexLock lock(&gGlobalRunLoopContext->mu); + CFReadStreamScheduleWithRunLoop(read_stream, gGlobalRunLoopContext->run_loop, + kCFRunLoopDefaultMode); + gGlobalRunLoopContext->input_source_registered = true; + gGlobalRunLoopContext->input_source_cv.Signal(); +} + +/// Register the stream with the global run loop. Callbacks of the stream will +/// be issued to the run loop when a network event happens, and will be driven +/// by the global run loop thread gGlobalRunLoopThread. +static void grpc_apple_register_write_stream_run_loop( + CFWriteStreamRef write_stream, dispatch_queue_t dispatch_queue) { + GRPC_POLLING_TRACE("Register write stream: %p", write_stream); + grpc_core::MutexLock lock(&gGlobalRunLoopContext->mu); + CFWriteStreamScheduleWithRunLoop( + write_stream, gGlobalRunLoopContext->run_loop, kCFRunLoopDefaultMode); + gGlobalRunLoopContext->input_source_registered = true; + gGlobalRunLoopContext->input_source_cv.Signal(); +} + +/// The default implementation of stream registration is to register the stream +/// to a dispatch queue. However, if the CFRunLoop based pollset is enabled (by +/// macro and environment variable, see docs in iomgr_posix_cfstream.cc), the +/// CFStream streams are registered with the global run loop instead (see +/// pollset_global_init below). +static void (*grpc_apple_register_read_stream_impl)( + CFReadStreamRef, dispatch_queue_t) = grpc_apple_register_read_stream_queue; +static void (*grpc_apple_register_write_stream_impl)(CFWriteStreamRef, + dispatch_queue_t) = + grpc_apple_register_write_stream_queue; + +void grpc_apple_register_read_stream(CFReadStreamRef read_stream, + dispatch_queue_t dispatch_queue) { + grpc_apple_register_read_stream_impl(read_stream, dispatch_queue); +} + +void grpc_apple_register_write_stream(CFWriteStreamRef write_stream, + dispatch_queue_t dispatch_queue) { + grpc_apple_register_write_stream_impl(write_stream, dispatch_queue); +} + +/// Drive the run loop in a global singleton thread until the global run loop is +/// shutdown. +static void GlobalRunLoopFunc(void* arg) { + grpc_core::LockableAndReleasableMutexLock lock(&gGlobalRunLoopContext->mu); + gGlobalRunLoopContext->run_loop = CFRunLoopGetCurrent(); + gGlobalRunLoopContext->init_cv.Signal(); + + while (!gGlobalRunLoopContext->is_shutdown) { + // CFRunLoopRun() will return immediately if no stream is registered on it. + // So we wait on a conditional variable until a stream is registered; + // otherwise we'll be running a spinning loop. 
+ while (!gGlobalRunLoopContext->input_source_registered) { + gGlobalRunLoopContext->input_source_cv.Wait(&gGlobalRunLoopContext->mu); + } + gGlobalRunLoopContext->input_source_registered = false; + lock.Release(); + CFRunLoopRun(); + lock.Lock(); + } + lock.Release(); +} + +// pollset implementation + +static void pollset_global_init(void) { + gGlobalRunLoopContext = new GlobalRunLoopContext; + + grpc_apple_register_read_stream_impl = + grpc_apple_register_read_stream_run_loop; + grpc_apple_register_write_stream_impl = + grpc_apple_register_write_stream_run_loop; + + grpc_core::MutexLock lock(&gGlobalRunLoopContext->mu); + gGlobalRunLoopThread = + new grpc_core::Thread("apple_ev", GlobalRunLoopFunc, nullptr); + gGlobalRunLoopThread->Start(); + while (gGlobalRunLoopContext->run_loop == NULL) + gGlobalRunLoopContext->init_cv.Wait(&gGlobalRunLoopContext->mu); +} + +static void pollset_global_shutdown(void) { + { + grpc_core::MutexLock lock(&gGlobalRunLoopContext->mu); + gGlobalRunLoopContext->is_shutdown = true; + CFRunLoopStop(gGlobalRunLoopContext->run_loop); + } + gGlobalRunLoopThread->Join(); + delete gGlobalRunLoopThread; + delete gGlobalRunLoopContext; +} + +/// The caller must acquire the lock GrpcApplePollset.mu before calling this +/// function. The lock may be temporarily released when waiting on the condition +/// variable but will be re-acquired before the function returns. +/// +/// The Apple pollset simply waits on a condition variable until it is kicked. +/// The network events are handled in the global run loop thread. Processing of +/// these events will eventually trigger the kick. +static grpc_error_handle pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** worker, + grpc_millis deadline) { + GRPC_POLLING_TRACE("pollset work: %p, worker: %p, deadline: %" PRIu64, + pollset, worker, deadline); + GrpcApplePollset* apple_pollset = + reinterpret_cast(pollset); + GrpcAppleWorker actual_worker; + if (worker) { + *worker = reinterpret_cast(&actual_worker); + } + + if (apple_pollset->kicked_without_poller) { + // Process the outstanding kick and reset the flag. Do not block. + apple_pollset->kicked_without_poller = false; + } else { + // Block until kicked, timed out, or the pollset shuts down. + apple_pollset->workers.push_front(&actual_worker); + auto it = apple_pollset->workers.begin(); + + while (!actual_worker.kicked && !apple_pollset->is_shutdown) { + if (actual_worker.cv.WaitWithDeadline( + &apple_pollset->mu, grpc_core::ToAbslTime(grpc_millis_to_timespec( + deadline, GPR_CLOCK_REALTIME)))) { + // timed out + break; + } + } + + apple_pollset->workers.erase(it); + + // If the pollset is shut down asynchronously and this is the last pending + // worker, the shutdown process is complete at this moment and the shutdown + // callback will be called. + if (apple_pollset->is_shutdown && apple_pollset->workers.empty()) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, apple_pollset->shutdown_closure, + GRPC_ERROR_NONE); + } + } + + return GRPC_ERROR_NONE; +} + +/// Kick a specific worker. The caller must acquire the lock GrpcApplePollset.mu +/// before calling this function. +static void kick_worker(GrpcAppleWorker* worker) { + worker->kicked = true; + worker->cv.Signal(); +} + +/// The caller must acquire the lock GrpcApplePollset.mu before calling this +/// function. The kick action simply signals the condition variable of the +/// worker. 
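pollset_work() above parks each worker on its own condition variable under the pollset lock, a kick sets the worker's flag and signals it, and a kick that arrives while nobody is parked is remembered in kicked_without_poller. A condensed standalone version of that handshake using std::mutex and std::condition_variable (Pollset and Worker here are simplified stand-ins; no deadline or shutdown handling):

#include <condition_variable>
#include <cstdio>
#include <list>
#include <mutex>
#include <thread>

// Minimal version of the worker/kick handshake used by the Apple pollset above.
struct Worker {
  std::condition_variable cv;
  bool kicked = false;
};

struct Pollset {
  std::mutex mu;
  std::list<Worker*> workers;
  bool kicked_without_poller = false;

  // Block the calling "worker" until somebody kicks it.
  void Work(Worker* w) {
    std::unique_lock<std::mutex> lock(mu);
    if (kicked_without_poller) {  // a kick arrived while nobody was parked
      kicked_without_poller = false;
      return;
    }
    workers.push_front(w);
    auto it = workers.begin();
    w->cv.wait(lock, [w] { return w->kicked; });
    workers.erase(it);
  }

  // Wake one parked worker, or remember the kick if none is parked.
  void Kick() {
    std::lock_guard<std::mutex> lock(mu);
    if (workers.empty()) {
      kicked_without_poller = true;
    } else {
      Worker* w = workers.front();
      w->kicked = true;
      w->cv.notify_one();
    }
  }
};

int main() {
  Pollset ps;
  Worker w;
  std::thread t([&] { ps.Work(&w); std::printf("worker woke up\n"); });
  ps.Kick();  // if this lands before Work() parks, kicked_without_poller covers it
  t.join();
}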
+static grpc_error_handle pollset_kick(grpc_pollset* pollset, + grpc_pollset_worker* specific_worker) { + GrpcApplePollset* apple_pollset = + reinterpret_cast(pollset); + + GRPC_POLLING_TRACE("pollset kick: %p, worker:%p", pollset, specific_worker); + + if (specific_worker == nullptr) { + if (apple_pollset->workers.empty()) { + apple_pollset->kicked_without_poller = true; + } else { + GrpcAppleWorker* actual_worker = apple_pollset->workers.front(); + kick_worker(actual_worker); + } + } else if (specific_worker == GRPC_POLLSET_KICK_BROADCAST) { + for (auto& actual_worker : apple_pollset->workers) { + kick_worker(actual_worker); + } + } else { + GrpcAppleWorker* actual_worker = + reinterpret_cast(specific_worker); + kick_worker(actual_worker); + } + + return GRPC_ERROR_NONE; +} + +static void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + GRPC_POLLING_TRACE("pollset init: %p", pollset); + GrpcApplePollset* apple_pollset = new (pollset) GrpcApplePollset(); + *mu = grpc_core::GetUnderlyingGprMu(&apple_pollset->mu); +} + +/// The caller must acquire the lock GrpcApplePollset.mu before calling this +/// function. +static void pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + GRPC_POLLING_TRACE("pollset shutdown: %p", pollset); + + GrpcApplePollset* apple_pollset = + reinterpret_cast(pollset); + apple_pollset->is_shutdown = true; + pollset_kick(pollset, GRPC_POLLSET_KICK_BROADCAST); + + // If there is any worker blocked, shutdown will be done asynchronously. + if (apple_pollset->workers.empty()) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + } else { + apple_pollset->shutdown_closure = closure; + } +} + +static void pollset_destroy(grpc_pollset* pollset) { + GRPC_POLLING_TRACE("pollset destroy: %p", pollset); + GrpcApplePollset* apple_pollset = + reinterpret_cast(pollset); + apple_pollset->~GrpcApplePollset(); +} + +size_t pollset_size(void) { return sizeof(GrpcApplePollset); } + +grpc_pollset_vtable grpc_apple_pollset_vtable = { + pollset_global_init, pollset_global_shutdown, + pollset_init, pollset_shutdown, + pollset_destroy, pollset_work, + pollset_kick, pollset_size}; + +// pollset_set implementation + +grpc_pollset_set* pollset_set_create(void) { return nullptr; } +void pollset_set_destroy(grpc_pollset_set* pollset_set) {} +void pollset_set_add_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) {} +void pollset_set_del_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) {} +void pollset_set_add_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) {} +void pollset_set_del_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) {} + +grpc_pollset_set_vtable grpc_apple_pollset_set_vtable = { + pollset_set_create, pollset_set_destroy, + pollset_set_add_pollset, pollset_set_del_pollset, + pollset_set_add_pollset_set, pollset_set_del_pollset_set}; + +#endif diff --git a/src/core/lib/iomgr/ev_epoll1_linux.cc b/src/core/lib/iomgr/ev_epoll1_linux.cc new file mode 100644 index 00000000..cecf15c6 --- /dev/null +++ b/src/core/lib/iomgr/ev_epoll1_linux.cc @@ -0,0 +1,1364 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/iomgr/port.h" + +/* This polling engine is only relevant on linux kernels supporting epoll + epoll_create() or epoll_create1() */ +#ifdef GRPC_LINUX_EPOLL +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/block_annotate.h" +#include "src/core/lib/iomgr/ev_epoll1_linux.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/lockfree_event.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" +#include "src/core/lib/profiling/timers.h" + +static grpc_wakeup_fd global_wakeup_fd; + +/******************************************************************************* + * Singleton epoll set related fields + */ + +#define MAX_EPOLL_EVENTS 100 +#define MAX_EPOLL_EVENTS_HANDLED_PER_ITERATION 1 + +/* NOTE ON SYNCHRONIZATION: + * - Fields in this struct are only modified by the designated poller. Hence + * there is no need for any locks to protect the struct. + * - num_events and cursor fields have to be of atomic type to provide memory + * visibility guarantees only. i.e In case of multiple pollers, the designated + * polling thread keeps changing; the thread that wrote these values may be + * different from the thread reading the values + */ +typedef struct epoll_set { + int epfd; + + /* The epoll_events after the last call to epoll_wait() */ + struct epoll_event events[MAX_EPOLL_EVENTS]; + + /* The number of epoll_events after the last call to epoll_wait() */ + gpr_atm num_events; + + /* Index of the first event in epoll_events that has to be processed. This + * field is only valid if num_events > 0 */ + gpr_atm cursor; +} epoll_set; + +/* The global singleton epoll set */ +static epoll_set g_epoll_set; + +static int epoll_create_and_cloexec() { +#ifdef GRPC_LINUX_EPOLL_CREATE1 + int fd = epoll_create1(EPOLL_CLOEXEC); + if (fd < 0) { + gpr_log(GPR_ERROR, "epoll_create1 unavailable"); + } +#else + int fd = epoll_create(MAX_EPOLL_EVENTS); + if (fd < 0) { + gpr_log(GPR_ERROR, "epoll_create unavailable"); + } else if (fcntl(fd, F_SETFD, FD_CLOEXEC) != 0) { + gpr_log(GPR_ERROR, "fcntl following epoll_create failed"); + return -1; + } +#endif + return fd; +} + +/* Must be called *only* once */ +static bool epoll_set_init() { + g_epoll_set.epfd = epoll_create_and_cloexec(); + if (g_epoll_set.epfd < 0) { + return false; + } + + gpr_log(GPR_INFO, "grpc epoll fd: %d", g_epoll_set.epfd); + gpr_atm_no_barrier_store(&g_epoll_set.num_events, 0); + gpr_atm_no_barrier_store(&g_epoll_set.cursor, 0); + return true; +} + +/* epoll_set_init() MUST be called before calling this. 
*/ +static void epoll_set_shutdown() { + if (g_epoll_set.epfd >= 0) { + close(g_epoll_set.epfd); + g_epoll_set.epfd = -1; + } +} + +/******************************************************************************* + * Fd Declarations + */ + +/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ +struct grpc_fork_fd_list { + grpc_fd* fd; + grpc_fd* next; + grpc_fd* prev; +}; + +struct grpc_fd { + int fd; + + grpc_core::ManualConstructor read_closure; + grpc_core::ManualConstructor write_closure; + grpc_core::ManualConstructor error_closure; + + struct grpc_fd* freelist_next; + + grpc_iomgr_object iomgr_object; + + /* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ + grpc_fork_fd_list* fork_fd_list; +}; + +static void fd_global_init(void); +static void fd_global_shutdown(void); + +/******************************************************************************* + * Pollset Declarations + */ + +typedef enum { UNKICKED, KICKED, DESIGNATED_POLLER } kick_state; + +static const char* kick_state_string(kick_state st) { + switch (st) { + case UNKICKED: + return "UNKICKED"; + case KICKED: + return "KICKED"; + case DESIGNATED_POLLER: + return "DESIGNATED_POLLER"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +struct grpc_pollset_worker { + kick_state state; + int kick_state_mutator; // which line of code last changed kick state + bool initialized_cv; + grpc_pollset_worker* next; + grpc_pollset_worker* prev; + gpr_cv cv; + grpc_closure_list schedule_on_end_work; +}; + +#define SET_KICK_STATE(worker, kick_state) \ + do { \ + (worker)->state = (kick_state); \ + (worker)->kick_state_mutator = __LINE__; \ + } while (false) + +#define MAX_NEIGHBORHOODS 1024u + +typedef struct pollset_neighborhood { + union { + char pad[GPR_CACHELINE_SIZE]; + struct { + gpr_mu mu; + grpc_pollset* active_root; + }; + }; +} pollset_neighborhood; + +struct grpc_pollset { + gpr_mu mu; + pollset_neighborhood* neighborhood; + bool reassigning_neighborhood; + grpc_pollset_worker* root_worker; + bool kicked_without_poller; + + /* Set to true if the pollset is observed to have no workers available to + poll */ + bool seen_inactive; + bool shutting_down; /* Is the pollset shutting down ? */ + grpc_closure* shutdown_closure; /* Called after shutdown is complete */ + + /* Number of workers who are *about-to* attach themselves to the pollset + * worker list */ + int begin_refs; + + grpc_pollset* next; + grpc_pollset* prev; +}; + +/******************************************************************************* + * Pollset-set Declarations + */ + +struct grpc_pollset_set { + char unused; +}; + +/******************************************************************************* + * Common helpers + */ + +static bool append_error(grpc_error_handle* composite, grpc_error_handle error, + const char* desc) { + if (error == GRPC_ERROR_NONE) return true; + if (*composite == GRPC_ERROR_NONE) { + *composite = GRPC_ERROR_CREATE_FROM_COPIED_STRING(desc); + } + *composite = grpc_error_add_child(*composite, error); + return false; +} + +/******************************************************************************* + * Fd Definitions + */ + +/* We need to keep a freelist not because of any concerns of malloc performance + * but instead so that implementations with multiple threads in (for example) + * epoll_wait deal with the race between pollset removal and incoming poll + * notifications. 
+ * + * The problem is that the poller ultimately holds a reference to this + * object, so it is very difficult to know when is safe to free it, at least + * without some expensive synchronization. + * + * If we keep the object freelisted, in the worst case losing this race just + * becomes a spurious read notification on a reused fd. + */ + +/* The alarm system needs to be able to wakeup 'some poller' sometimes + * (specifically when a new alarm needs to be triggered earlier than the next + * alarm 'epoch'). This wakeup_fd gives us something to alert on when such a + * case occurs. */ + +static grpc_fd* fd_freelist = nullptr; +static gpr_mu fd_freelist_mu; + +/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ +static grpc_fd* fork_fd_list_head = nullptr; +static gpr_mu fork_fd_list_mu; + +static void fd_global_init(void) { gpr_mu_init(&fd_freelist_mu); } + +static void fd_global_shutdown(void) { + // TODO(guantaol): We don't have a reasonable explanation about this + // lock()/unlock() pattern. It can be a valid barrier if there is at most one + // pending lock() at this point. Otherwise, there is still a possibility of + // use-after-free race. Need to reason about the code and/or clean it up. + gpr_mu_lock(&fd_freelist_mu); + gpr_mu_unlock(&fd_freelist_mu); + while (fd_freelist != nullptr) { + grpc_fd* fd = fd_freelist; + fd_freelist = fd_freelist->freelist_next; + gpr_free(fd); + } + gpr_mu_destroy(&fd_freelist_mu); +} + +static void fork_fd_list_add_grpc_fd(grpc_fd* fd) { + if (grpc_core::Fork::Enabled()) { + gpr_mu_lock(&fork_fd_list_mu); + fd->fork_fd_list = + static_cast(gpr_malloc(sizeof(grpc_fork_fd_list))); + fd->fork_fd_list->next = fork_fd_list_head; + fd->fork_fd_list->prev = nullptr; + if (fork_fd_list_head != nullptr) { + fork_fd_list_head->fork_fd_list->prev = fd; + } + fork_fd_list_head = fd; + gpr_mu_unlock(&fork_fd_list_mu); + } +} + +static void fork_fd_list_remove_grpc_fd(grpc_fd* fd) { + if (grpc_core::Fork::Enabled()) { + gpr_mu_lock(&fork_fd_list_mu); + if (fork_fd_list_head == fd) { + fork_fd_list_head = fd->fork_fd_list->next; + } + if (fd->fork_fd_list->prev != nullptr) { + fd->fork_fd_list->prev->fork_fd_list->next = fd->fork_fd_list->next; + } + if (fd->fork_fd_list->next != nullptr) { + fd->fork_fd_list->next->fork_fd_list->prev = fd->fork_fd_list->prev; + } + gpr_free(fd->fork_fd_list); + gpr_mu_unlock(&fork_fd_list_mu); + } +} + +static grpc_fd* fd_create(int fd, const char* name, bool track_err) { + grpc_fd* new_fd = nullptr; + + gpr_mu_lock(&fd_freelist_mu); + if (fd_freelist != nullptr) { + new_fd = fd_freelist; + fd_freelist = fd_freelist->freelist_next; + } + gpr_mu_unlock(&fd_freelist_mu); + + if (new_fd == nullptr) { + new_fd = static_cast(gpr_malloc(sizeof(grpc_fd))); + new_fd->read_closure.Init(); + new_fd->write_closure.Init(); + new_fd->error_closure.Init(); + } + new_fd->fd = fd; + new_fd->read_closure->InitEvent(); + new_fd->write_closure->InitEvent(); + new_fd->error_closure->InitEvent(); + + new_fd->freelist_next = nullptr; + + std::string fd_name = absl::StrCat(name, " fd=", fd); + grpc_iomgr_register_object(&new_fd->iomgr_object, fd_name.c_str()); + fork_fd_list_add_grpc_fd(new_fd); +#ifndef NDEBUG + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_fd_refcount)) { + gpr_log(GPR_DEBUG, "FD %d %p create %s", fd, new_fd, fd_name.c_str()); + } +#endif + + struct epoll_event ev; + ev.events = static_cast(EPOLLIN | EPOLLOUT | EPOLLET); + /* Use the least significant bit of ev.data.ptr to store track_err. We expect + * the addresses to be word aligned. 
We need to store track_err to avoid + * synchronization issues when accessing it after receiving an event. + * Accessing fd would be a data race there because the fd might have been + * returned to the free list at that point. */ + ev.data.ptr = reinterpret_cast(reinterpret_cast(new_fd) | + (track_err ? 1 : 0)); + if (epoll_ctl(g_epoll_set.epfd, EPOLL_CTL_ADD, fd, &ev) != 0) { + gpr_log(GPR_ERROR, "epoll_ctl failed: %s", strerror(errno)); + } + + return new_fd; +} + +static int fd_wrapped_fd(grpc_fd* fd) { return fd->fd; } + +/* if 'releasing_fd' is true, it means that we are going to detach the internal + * fd from grpc_fd structure (i.e which means we should not be calling + * shutdown() syscall on that fd) */ +static void fd_shutdown_internal(grpc_fd* fd, grpc_error_handle why, + bool releasing_fd) { + if (fd->read_closure->SetShutdown(GRPC_ERROR_REF(why))) { + if (!releasing_fd) { + shutdown(fd->fd, SHUT_RDWR); + } else { + /* we need a phony event for earlier linux versions. */ + epoll_event phony_event; + if (epoll_ctl(g_epoll_set.epfd, EPOLL_CTL_DEL, fd->fd, &phony_event) != + 0) { + gpr_log(GPR_ERROR, "epoll_ctl failed: %s", strerror(errno)); + } + } + fd->write_closure->SetShutdown(GRPC_ERROR_REF(why)); + fd->error_closure->SetShutdown(GRPC_ERROR_REF(why)); + } + GRPC_ERROR_UNREF(why); +} + +/* Might be called multiple times */ +static void fd_shutdown(grpc_fd* fd, grpc_error_handle why) { + fd_shutdown_internal(fd, why, false); +} + +static void fd_orphan(grpc_fd* fd, grpc_closure* on_done, int* release_fd, + const char* reason) { + grpc_error_handle error = GRPC_ERROR_NONE; + bool is_release_fd = (release_fd != nullptr); + + if (!fd->read_closure->IsShutdown()) { + fd_shutdown_internal(fd, GRPC_ERROR_CREATE_FROM_COPIED_STRING(reason), + is_release_fd); + } + + /* If release_fd is not NULL, we should be relinquishing control of the file + descriptor fd->fd (but we still own the grpc_fd structure). 
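 *
 * [Editor's note — illustrative sketch, not part of the original change.
 *  A caller that wants to keep the underlying descriptor alive (for example
 *  to hand it off elsewhere) passes a non-null release_fd:
 *
 *    int raw_fd;
 *    fd_orphan(gfd, on_done, &raw_fd, "handoff");
 *    // grpc_fd bookkeeping is torn down and the struct is freelisted,
 *    // but raw_fd stays open and is now owned by the caller
 *
 *  whereas with release_fd == nullptr, fd_orphan() close()s fd->fd itself.
 *  "gfd", "on_done" and "handoff" are placeholder names for this sketch.]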
*/ + if (is_release_fd) { + *release_fd = fd->fd; + } else { + close(fd->fd); + } + + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_REF(error)); + + grpc_iomgr_unregister_object(&fd->iomgr_object); + fork_fd_list_remove_grpc_fd(fd); + fd->read_closure->DestroyEvent(); + fd->write_closure->DestroyEvent(); + fd->error_closure->DestroyEvent(); + + gpr_mu_lock(&fd_freelist_mu); + fd->freelist_next = fd_freelist; + fd_freelist = fd; + gpr_mu_unlock(&fd_freelist_mu); +} + +static bool fd_is_shutdown(grpc_fd* fd) { + return fd->read_closure->IsShutdown(); +} + +static void fd_notify_on_read(grpc_fd* fd, grpc_closure* closure) { + fd->read_closure->NotifyOn(closure); +} + +static void fd_notify_on_write(grpc_fd* fd, grpc_closure* closure) { + fd->write_closure->NotifyOn(closure); +} + +static void fd_notify_on_error(grpc_fd* fd, grpc_closure* closure) { + fd->error_closure->NotifyOn(closure); +} + +static void fd_become_readable(grpc_fd* fd) { fd->read_closure->SetReady(); } + +static void fd_become_writable(grpc_fd* fd) { fd->write_closure->SetReady(); } + +static void fd_has_errors(grpc_fd* fd) { fd->error_closure->SetReady(); } + +/******************************************************************************* + * Pollset Definitions + */ + +static GPR_THREAD_LOCAL(grpc_pollset*) g_current_thread_pollset; +static GPR_THREAD_LOCAL(grpc_pollset_worker*) g_current_thread_worker; + +/* The designated poller */ +static gpr_atm g_active_poller; + +static pollset_neighborhood* g_neighborhoods; +static size_t g_num_neighborhoods; + +/* Return true if first in list */ +static bool worker_insert(grpc_pollset* pollset, grpc_pollset_worker* worker) { + if (pollset->root_worker == nullptr) { + pollset->root_worker = worker; + worker->next = worker->prev = worker; + return true; + } else { + worker->next = pollset->root_worker; + worker->prev = worker->next->prev; + worker->next->prev = worker; + worker->prev->next = worker; + return false; + } +} + +/* Return true if last in list */ +typedef enum { EMPTIED, NEW_ROOT, REMOVED } worker_remove_result; + +static worker_remove_result worker_remove(grpc_pollset* pollset, + grpc_pollset_worker* worker) { + if (worker == pollset->root_worker) { + if (worker == worker->next) { + pollset->root_worker = nullptr; + return EMPTIED; + } else { + pollset->root_worker = worker->next; + worker->prev->next = worker->next; + worker->next->prev = worker->prev; + return NEW_ROOT; + } + } else { + worker->prev->next = worker->next; + worker->next->prev = worker->prev; + return REMOVED; + } +} + +static size_t choose_neighborhood(void) { + return static_cast(gpr_cpu_current_cpu()) % g_num_neighborhoods; +} + +static grpc_error_handle pollset_global_init(void) { + gpr_atm_no_barrier_store(&g_active_poller, 0); + global_wakeup_fd.read_fd = -1; + grpc_error_handle err = grpc_wakeup_fd_init(&global_wakeup_fd); + if (err != GRPC_ERROR_NONE) return err; + struct epoll_event ev; + ev.events = static_cast(EPOLLIN | EPOLLET); + ev.data.ptr = &global_wakeup_fd; + if (epoll_ctl(g_epoll_set.epfd, EPOLL_CTL_ADD, global_wakeup_fd.read_fd, + &ev) != 0) { + return GRPC_OS_ERROR(errno, "epoll_ctl"); + } + g_num_neighborhoods = + grpc_core::Clamp(gpr_cpu_num_cores(), 1u, MAX_NEIGHBORHOODS); + g_neighborhoods = static_cast( + gpr_zalloc(sizeof(*g_neighborhoods) * g_num_neighborhoods)); + for (size_t i = 0; i < g_num_neighborhoods; i++) { + gpr_mu_init(&g_neighborhoods[i].mu); + } + return GRPC_ERROR_NONE; +} + +static void pollset_global_shutdown(void) { + if (global_wakeup_fd.read_fd 
!= -1) grpc_wakeup_fd_destroy(&global_wakeup_fd); + for (size_t i = 0; i < g_num_neighborhoods; i++) { + gpr_mu_destroy(&g_neighborhoods[i].mu); + } + gpr_free(g_neighborhoods); +} + +static void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + gpr_mu_init(&pollset->mu); + *mu = &pollset->mu; + pollset->neighborhood = &g_neighborhoods[choose_neighborhood()]; + pollset->reassigning_neighborhood = false; + pollset->root_worker = nullptr; + pollset->kicked_without_poller = false; + pollset->seen_inactive = true; + pollset->shutting_down = false; + pollset->shutdown_closure = nullptr; + pollset->begin_refs = 0; + pollset->next = pollset->prev = nullptr; +} + +static void pollset_destroy(grpc_pollset* pollset) { + gpr_mu_lock(&pollset->mu); + if (!pollset->seen_inactive) { + pollset_neighborhood* neighborhood = pollset->neighborhood; + gpr_mu_unlock(&pollset->mu); + retry_lock_neighborhood: + gpr_mu_lock(&neighborhood->mu); + gpr_mu_lock(&pollset->mu); + if (!pollset->seen_inactive) { + if (pollset->neighborhood != neighborhood) { + gpr_mu_unlock(&neighborhood->mu); + neighborhood = pollset->neighborhood; + gpr_mu_unlock(&pollset->mu); + goto retry_lock_neighborhood; + } + pollset->prev->next = pollset->next; + pollset->next->prev = pollset->prev; + if (pollset == pollset->neighborhood->active_root) { + pollset->neighborhood->active_root = + pollset->next == pollset ? nullptr : pollset->next; + } + } + gpr_mu_unlock(&pollset->neighborhood->mu); + } + gpr_mu_unlock(&pollset->mu); + gpr_mu_destroy(&pollset->mu); +} + +static grpc_error_handle pollset_kick_all(grpc_pollset* pollset) { + GPR_TIMER_SCOPE("pollset_kick_all", 0); + grpc_error_handle error = GRPC_ERROR_NONE; + if (pollset->root_worker != nullptr) { + grpc_pollset_worker* worker = pollset->root_worker; + do { + GRPC_STATS_INC_POLLSET_KICK(); + switch (worker->state) { + case KICKED: + GRPC_STATS_INC_POLLSET_KICKED_AGAIN(); + break; + case UNKICKED: + SET_KICK_STATE(worker, KICKED); + if (worker->initialized_cv) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_CV(); + gpr_cv_signal(&worker->cv); + } + break; + case DESIGNATED_POLLER: + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_FD(); + SET_KICK_STATE(worker, KICKED); + append_error(&error, grpc_wakeup_fd_wakeup(&global_wakeup_fd), + "pollset_kick_all"); + break; + } + + worker = worker->next; + } while (worker != pollset->root_worker); + } + // TODO(sreek): Check if we need to set 'kicked_without_poller' to true here + // in the else case + return error; +} + +static void pollset_maybe_finish_shutdown(grpc_pollset* pollset) { + if (pollset->shutdown_closure != nullptr && pollset->root_worker == nullptr && + pollset->begin_refs == 0) { + GPR_TIMER_MARK("pollset_finish_shutdown", 0); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, pollset->shutdown_closure, + GRPC_ERROR_NONE); + pollset->shutdown_closure = nullptr; + } +} + +static void pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + GPR_TIMER_SCOPE("pollset_shutdown", 0); + GPR_ASSERT(pollset->shutdown_closure == nullptr); + GPR_ASSERT(!pollset->shutting_down); + pollset->shutdown_closure = closure; + pollset->shutting_down = true; + GRPC_LOG_IF_ERROR("pollset_shutdown", pollset_kick_all(pollset)); + pollset_maybe_finish_shutdown(pollset); +} + +static int poll_deadline_to_millis_timeout(grpc_millis millis) { + if (millis == GRPC_MILLIS_INF_FUTURE) return -1; + grpc_millis delta = millis - grpc_core::ExecCtx::Get()->Now(); + if (delta > INT_MAX) { + return INT_MAX; + } else if (delta < 0) { + return 0; + } else { + return 
static_cast(delta); + } +} + +/* Process the epoll events found by do_epoll_wait() function. + - g_epoll_set.cursor points to the index of the first event to be processed + - This function then processes up-to MAX_EPOLL_EVENTS_PER_ITERATION and + updates the g_epoll_set.cursor + + NOTE ON SYNCRHONIZATION: Similar to do_epoll_wait(), this function is only + called by g_active_poller thread. So there is no need for synchronization + when accessing fields in g_epoll_set */ +static grpc_error_handle process_epoll_events(grpc_pollset* /*pollset*/) { + GPR_TIMER_SCOPE("process_epoll_events", 0); + + static const char* err_desc = "process_events"; + grpc_error_handle error = GRPC_ERROR_NONE; + long num_events = gpr_atm_acq_load(&g_epoll_set.num_events); + long cursor = gpr_atm_acq_load(&g_epoll_set.cursor); + for (int idx = 0; + (idx < MAX_EPOLL_EVENTS_HANDLED_PER_ITERATION) && cursor != num_events; + idx++) { + long c = cursor++; + struct epoll_event* ev = &g_epoll_set.events[c]; + void* data_ptr = ev->data.ptr; + + if (data_ptr == &global_wakeup_fd) { + append_error(&error, grpc_wakeup_fd_consume_wakeup(&global_wakeup_fd), + err_desc); + } else { + grpc_fd* fd = reinterpret_cast( + reinterpret_cast(data_ptr) & ~static_cast(1)); + bool track_err = + reinterpret_cast(data_ptr) & static_cast(1); + bool cancel = (ev->events & EPOLLHUP) != 0; + bool error = (ev->events & EPOLLERR) != 0; + bool read_ev = (ev->events & (EPOLLIN | EPOLLPRI)) != 0; + bool write_ev = (ev->events & EPOLLOUT) != 0; + bool err_fallback = error && !track_err; + + if (error && !err_fallback) { + fd_has_errors(fd); + } + + if (read_ev || cancel || err_fallback) { + fd_become_readable(fd); + } + + if (write_ev || cancel || err_fallback) { + fd_become_writable(fd); + } + } + } + gpr_atm_rel_store(&g_epoll_set.cursor, cursor); + return error; +} + +/* Do epoll_wait and store the events in g_epoll_set.events field. This does not + "process" any of the events yet; that is done in process_epoll_events(). + *See process_epoll_events() function for more details. + + NOTE ON SYNCHRONIZATION: At any point of time, only the g_active_poller + (i.e the designated poller thread) will be calling this function. 
So there is + no need for any synchronization when accesing fields in g_epoll_set */ +static grpc_error_handle do_epoll_wait(grpc_pollset* ps, grpc_millis deadline) { + GPR_TIMER_SCOPE("do_epoll_wait", 0); + + int r; + int timeout = poll_deadline_to_millis_timeout(deadline); + if (timeout != 0) { + GRPC_SCHEDULING_START_BLOCKING_REGION; + } + do { + GRPC_STATS_INC_SYSCALL_POLL(); + r = epoll_wait(g_epoll_set.epfd, g_epoll_set.events, MAX_EPOLL_EVENTS, + timeout); + } while (r < 0 && errno == EINTR); + if (timeout != 0) { + GRPC_SCHEDULING_END_BLOCKING_REGION; + } + + if (r < 0) return GRPC_OS_ERROR(errno, "epoll_wait"); + + GRPC_STATS_INC_POLL_EVENTS_RETURNED(r); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "ps: %p poll got %d events", ps, r); + } + + gpr_atm_rel_store(&g_epoll_set.num_events, r); + gpr_atm_rel_store(&g_epoll_set.cursor, 0); + + return GRPC_ERROR_NONE; +} + +static bool begin_worker(grpc_pollset* pollset, grpc_pollset_worker* worker, + grpc_pollset_worker** worker_hdl, + grpc_millis deadline) { + GPR_TIMER_SCOPE("begin_worker", 0); + if (worker_hdl != nullptr) *worker_hdl = worker; + worker->initialized_cv = false; + SET_KICK_STATE(worker, UNKICKED); + worker->schedule_on_end_work = (grpc_closure_list)GRPC_CLOSURE_LIST_INIT; + pollset->begin_refs++; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p BEGIN_STARTS:%p", pollset, worker); + } + + if (pollset->seen_inactive) { + // pollset has been observed to be inactive, we need to move back to the + // active list + bool is_reassigning = false; + if (!pollset->reassigning_neighborhood) { + is_reassigning = true; + pollset->reassigning_neighborhood = true; + pollset->neighborhood = &g_neighborhoods[choose_neighborhood()]; + } + pollset_neighborhood* neighborhood = pollset->neighborhood; + gpr_mu_unlock(&pollset->mu); + // pollset unlocked: state may change (even worker->kick_state) + retry_lock_neighborhood: + gpr_mu_lock(&neighborhood->mu); + gpr_mu_lock(&pollset->mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p BEGIN_REORG:%p kick_state=%s is_reassigning=%d", + pollset, worker, kick_state_string(worker->state), + is_reassigning); + } + if (pollset->seen_inactive) { + if (neighborhood != pollset->neighborhood) { + gpr_mu_unlock(&neighborhood->mu); + neighborhood = pollset->neighborhood; + gpr_mu_unlock(&pollset->mu); + goto retry_lock_neighborhood; + } + + /* In the brief time we released the pollset locks above, the worker MAY + have been kicked. In this case, the worker should get out of this + pollset ASAP and hence this should neither add the pollset to + neighborhood nor mark the pollset as active. + + On a side note, the only way a worker's kick state could have changed + at this point is if it were "kicked specifically". 
Since the worker has + not added itself to the pollset yet (by calling worker_insert()), it is + not visible in the "kick any" path yet */ + if (worker->state == UNKICKED) { + pollset->seen_inactive = false; + if (neighborhood->active_root == nullptr) { + neighborhood->active_root = pollset->next = pollset->prev = pollset; + /* Make this the designated poller if there isn't one already */ + if (worker->state == UNKICKED && + gpr_atm_no_barrier_cas(&g_active_poller, 0, + reinterpret_cast(worker))) { + SET_KICK_STATE(worker, DESIGNATED_POLLER); + } + } else { + pollset->next = neighborhood->active_root; + pollset->prev = pollset->next->prev; + pollset->next->prev = pollset->prev->next = pollset; + } + } + } + if (is_reassigning) { + GPR_ASSERT(pollset->reassigning_neighborhood); + pollset->reassigning_neighborhood = false; + } + gpr_mu_unlock(&neighborhood->mu); + } + + worker_insert(pollset, worker); + pollset->begin_refs--; + if (worker->state == UNKICKED && !pollset->kicked_without_poller) { + GPR_ASSERT(gpr_atm_no_barrier_load(&g_active_poller) != (gpr_atm)worker); + worker->initialized_cv = true; + gpr_cv_init(&worker->cv); + while (worker->state == UNKICKED && !pollset->shutting_down) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p BEGIN_WAIT:%p kick_state=%s shutdown=%d", + pollset, worker, kick_state_string(worker->state), + pollset->shutting_down); + } + + if (gpr_cv_wait(&worker->cv, &pollset->mu, + grpc_millis_to_timespec(deadline, GPR_CLOCK_MONOTONIC)) && + worker->state == UNKICKED) { + /* If gpr_cv_wait returns true (i.e a timeout), pretend that the worker + received a kick */ + SET_KICK_STATE(worker, KICKED); + } + } + grpc_core::ExecCtx::Get()->InvalidateNow(); + } + + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, + "PS:%p BEGIN_DONE:%p kick_state=%s shutdown=%d " + "kicked_without_poller: %d", + pollset, worker, kick_state_string(worker->state), + pollset->shutting_down, pollset->kicked_without_poller); + } + + /* We release pollset lock in this function at a couple of places: + * 1. Briefly when assigning pollset to a neighborhood + * 2. When doing gpr_cv_wait() + * It is possible that 'kicked_without_poller' was set to true during (1) and + * 'shutting_down' is set to true during (1) or (2). If either of them is + * true, this worker cannot do polling */ + /* TODO(sreek): Perhaps there is a better way to handle kicked_without_poller + * case; especially when the worker is the DESIGNATED_POLLER */ + + if (pollset->kicked_without_poller) { + pollset->kicked_without_poller = false; + return false; + } + + return worker->state == DESIGNATED_POLLER && !pollset->shutting_down; +} + +static bool check_neighborhood_for_available_poller( + pollset_neighborhood* neighborhood) { + GPR_TIMER_SCOPE("check_neighborhood_for_available_poller", 0); + bool found_worker = false; + do { + grpc_pollset* inspect = neighborhood->active_root; + if (inspect == nullptr) { + break; + } + gpr_mu_lock(&inspect->mu); + GPR_ASSERT(!inspect->seen_inactive); + grpc_pollset_worker* inspect_worker = inspect->root_worker; + if (inspect_worker != nullptr) { + do { + switch (inspect_worker->state) { + case UNKICKED: + if (gpr_atm_no_barrier_cas( + &g_active_poller, 0, + reinterpret_cast(inspect_worker))) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. 
choose next poller to be %p", + inspect_worker); + } + SET_KICK_STATE(inspect_worker, DESIGNATED_POLLER); + if (inspect_worker->initialized_cv) { + GPR_TIMER_MARK("signal worker", 0); + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_CV(); + gpr_cv_signal(&inspect_worker->cv); + } + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. beaten to choose next poller"); + } + } + // even if we didn't win the cas, there's a worker, we can stop + found_worker = true; + break; + case KICKED: + break; + case DESIGNATED_POLLER: + found_worker = true; // ok, so someone else found the worker, but + // we'll accept that + break; + } + inspect_worker = inspect_worker->next; + } while (!found_worker && inspect_worker != inspect->root_worker); + } + if (!found_worker) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. mark pollset %p inactive", inspect); + } + inspect->seen_inactive = true; + if (inspect == neighborhood->active_root) { + neighborhood->active_root = + inspect->next == inspect ? nullptr : inspect->next; + } + inspect->next->prev = inspect->prev; + inspect->prev->next = inspect->next; + inspect->next = inspect->prev = nullptr; + } + gpr_mu_unlock(&inspect->mu); + } while (!found_worker); + return found_worker; +} + +static void end_worker(grpc_pollset* pollset, grpc_pollset_worker* worker, + grpc_pollset_worker** worker_hdl) { + GPR_TIMER_SCOPE("end_worker", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p END_WORKER:%p", pollset, worker); + } + if (worker_hdl != nullptr) *worker_hdl = nullptr; + /* Make sure we appear kicked */ + SET_KICK_STATE(worker, KICKED); + grpc_closure_list_move(&worker->schedule_on_end_work, + grpc_core::ExecCtx::Get()->closure_list()); + if (gpr_atm_no_barrier_load(&g_active_poller) == + reinterpret_cast(worker)) { + if (worker->next != worker && worker->next->state == UNKICKED) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. 
choose next poller to be peer %p", worker); + } + GPR_ASSERT(worker->next->initialized_cv); + gpr_atm_no_barrier_store(&g_active_poller, (gpr_atm)worker->next); + SET_KICK_STATE(worker->next, DESIGNATED_POLLER); + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_CV(); + gpr_cv_signal(&worker->next->cv); + if (grpc_core::ExecCtx::Get()->HasWork()) { + gpr_mu_unlock(&pollset->mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&pollset->mu); + } + } else { + gpr_atm_no_barrier_store(&g_active_poller, 0); + size_t poller_neighborhood_idx = + static_cast(pollset->neighborhood - g_neighborhoods); + gpr_mu_unlock(&pollset->mu); + bool found_worker = false; + bool scan_state[MAX_NEIGHBORHOODS]; + for (size_t i = 0; !found_worker && i < g_num_neighborhoods; i++) { + pollset_neighborhood* neighborhood = + &g_neighborhoods[(poller_neighborhood_idx + i) % + g_num_neighborhoods]; + if (gpr_mu_trylock(&neighborhood->mu)) { + found_worker = check_neighborhood_for_available_poller(neighborhood); + gpr_mu_unlock(&neighborhood->mu); + scan_state[i] = true; + } else { + scan_state[i] = false; + } + } + for (size_t i = 0; !found_worker && i < g_num_neighborhoods; i++) { + if (scan_state[i]) continue; + pollset_neighborhood* neighborhood = + &g_neighborhoods[(poller_neighborhood_idx + i) % + g_num_neighborhoods]; + gpr_mu_lock(&neighborhood->mu); + found_worker = check_neighborhood_for_available_poller(neighborhood); + gpr_mu_unlock(&neighborhood->mu); + } + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&pollset->mu); + } + } else if (grpc_core::ExecCtx::Get()->HasWork()) { + gpr_mu_unlock(&pollset->mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&pollset->mu); + } + if (worker->initialized_cv) { + gpr_cv_destroy(&worker->cv); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. remove worker"); + } + if (EMPTIED == worker_remove(pollset, worker)) { + pollset_maybe_finish_shutdown(pollset); + } + GPR_ASSERT(gpr_atm_no_barrier_load(&g_active_poller) != (gpr_atm)worker); +} + +/* pollset->po.mu lock must be held by the caller before calling this. + The function pollset_work() may temporarily release the lock (pollset->po.mu) + during the course of its execution but it will always re-acquire the lock and + ensure that it is held by the time the function returns */ +static grpc_error_handle pollset_work(grpc_pollset* ps, + grpc_pollset_worker** worker_hdl, + grpc_millis deadline) { + GPR_TIMER_SCOPE("pollset_work", 0); + grpc_pollset_worker worker; + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* err_desc = "pollset_work"; + if (ps->kicked_without_poller) { + ps->kicked_without_poller = false; + return GRPC_ERROR_NONE; + } + + if (begin_worker(ps, &worker, worker_hdl, deadline)) { + g_current_thread_pollset = ps; + g_current_thread_worker = &worker; + GPR_ASSERT(!ps->shutting_down); + GPR_ASSERT(!ps->seen_inactive); + + gpr_mu_unlock(&ps->mu); /* unlock */ + /* This is the designated polling thread at this point and should ideally do + polling. However, if there are unprocessed events left from a previous + call to do_epoll_wait(), skip calling epoll_wait() in this iteration and + process the pending epoll events. 
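 *
 * [Editor's note — illustrative restatement, not part of the original
 *  change. "Unprocessed events left" is detected by comparing the two
 *  atomics maintained in g_epoll_set:
 *
 *    cursor == num_events  ->  nothing pending: call do_epoll_wait() first
 *    cursor != num_events  ->  skip the wait, only run process_epoll_events()
 *
 *  which is exactly the g_epoll_set.cursor / g_epoll_set.num_events check a
 *  few lines below.]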
+ + The reason for decoupling do_epoll_wait and process_epoll_events is to + better distribute the work (i.e handling epoll events) across multiple + threads + + process_epoll_events() returns very quickly: It just queues the work on + exec_ctx but does not execute it (the actual exectution or more + accurately grpc_core::ExecCtx::Get()->Flush() happens in end_worker() + AFTER selecting a designated poller). So we are not waiting long periods + without a designated poller */ + if (gpr_atm_acq_load(&g_epoll_set.cursor) == + gpr_atm_acq_load(&g_epoll_set.num_events)) { + append_error(&error, do_epoll_wait(ps, deadline), err_desc); + } + append_error(&error, process_epoll_events(ps), err_desc); + + gpr_mu_lock(&ps->mu); /* lock */ + + g_current_thread_worker = nullptr; + } else { + g_current_thread_pollset = ps; + } + end_worker(ps, &worker, worker_hdl); + + g_current_thread_pollset = nullptr; + return error; +} + +static grpc_error_handle pollset_kick(grpc_pollset* pollset, + grpc_pollset_worker* specific_worker) { + GPR_TIMER_SCOPE("pollset_kick", 0); + GRPC_STATS_INC_POLLSET_KICK(); + grpc_error_handle ret_err = GRPC_ERROR_NONE; + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + std::vector log; + log.push_back(absl::StrFormat( + "PS:%p KICK:%p curps=%p curworker=%p root=%p", pollset, specific_worker, + static_cast(g_current_thread_pollset), + static_cast(g_current_thread_worker), pollset->root_worker)); + if (pollset->root_worker != nullptr) { + log.push_back(absl::StrFormat( + " {kick_state=%s next=%p {kick_state=%s}}", + kick_state_string(pollset->root_worker->state), + pollset->root_worker->next, + kick_state_string(pollset->root_worker->next->state))); + } + if (specific_worker != nullptr) { + log.push_back(absl::StrFormat(" worker_kick_state=%s", + kick_state_string(specific_worker->state))); + } + gpr_log(GPR_DEBUG, "%s", absl::StrJoin(log, "").c_str()); + } + + if (specific_worker == nullptr) { + if (g_current_thread_pollset != pollset) { + grpc_pollset_worker* root_worker = pollset->root_worker; + if (root_worker == nullptr) { + GRPC_STATS_INC_POLLSET_KICKED_WITHOUT_POLLER(); + pollset->kicked_without_poller = true; + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. kicked_without_poller"); + } + goto done; + } + grpc_pollset_worker* next_worker = root_worker->next; + if (root_worker->state == KICKED) { + GRPC_STATS_INC_POLLSET_KICKED_AGAIN(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. already kicked %p", root_worker); + } + SET_KICK_STATE(root_worker, KICKED); + goto done; + } else if (next_worker->state == KICKED) { + GRPC_STATS_INC_POLLSET_KICKED_AGAIN(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. already kicked %p", next_worker); + } + SET_KICK_STATE(next_worker, KICKED); + goto done; + } else if (root_worker == next_worker && // only try and wake up a poller + // if there is no next worker + root_worker == + reinterpret_cast( + gpr_atm_no_barrier_load(&g_active_poller))) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_FD(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. kicked %p", root_worker); + } + SET_KICK_STATE(root_worker, KICKED); + ret_err = grpc_wakeup_fd_wakeup(&global_wakeup_fd); + goto done; + } else if (next_worker->state == UNKICKED) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_CV(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. 
kicked %p", next_worker); + } + GPR_ASSERT(next_worker->initialized_cv); + SET_KICK_STATE(next_worker, KICKED); + gpr_cv_signal(&next_worker->cv); + goto done; + } else if (next_worker->state == DESIGNATED_POLLER) { + if (root_worker->state != DESIGNATED_POLLER) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log( + GPR_INFO, + " .. kicked root non-poller %p (initialized_cv=%d) (poller=%p)", + root_worker, root_worker->initialized_cv, next_worker); + } + SET_KICK_STATE(root_worker, KICKED); + if (root_worker->initialized_cv) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_CV(); + gpr_cv_signal(&root_worker->cv); + } + goto done; + } else { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_FD(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. non-root poller %p (root=%p)", next_worker, + root_worker); + } + SET_KICK_STATE(next_worker, KICKED); + ret_err = grpc_wakeup_fd_wakeup(&global_wakeup_fd); + goto done; + } + } else { + GRPC_STATS_INC_POLLSET_KICKED_AGAIN(); + GPR_ASSERT(next_worker->state == KICKED); + SET_KICK_STATE(next_worker, KICKED); + goto done; + } + } else { + GRPC_STATS_INC_POLLSET_KICK_OWN_THREAD(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. kicked while waking up"); + } + goto done; + } + + GPR_UNREACHABLE_CODE(goto done); + } + + if (specific_worker->state == KICKED) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. specific worker already kicked"); + } + goto done; + } else if (g_current_thread_worker == specific_worker) { + GRPC_STATS_INC_POLLSET_KICK_OWN_THREAD(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. mark %p kicked", specific_worker); + } + SET_KICK_STATE(specific_worker, KICKED); + goto done; + } else if (specific_worker == + reinterpret_cast( + gpr_atm_no_barrier_load(&g_active_poller))) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_FD(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. kick active poller"); + } + SET_KICK_STATE(specific_worker, KICKED); + ret_err = grpc_wakeup_fd_wakeup(&global_wakeup_fd); + goto done; + } else if (specific_worker->initialized_cv) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_CV(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. kick waiting worker"); + } + SET_KICK_STATE(specific_worker, KICKED); + gpr_cv_signal(&specific_worker->cv); + goto done; + } else { + GRPC_STATS_INC_POLLSET_KICKED_AGAIN(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, " .. 
kick non-waiting worker"); + } + SET_KICK_STATE(specific_worker, KICKED); + goto done; + } +done: + return ret_err; +} + +static void pollset_add_fd(grpc_pollset* /*pollset*/, grpc_fd* /*fd*/) {} + +/******************************************************************************* + * Pollset-set Definitions + */ + +static grpc_pollset_set* pollset_set_create(void) { + return reinterpret_cast(static_cast(0xdeafbeef)); +} + +static void pollset_set_destroy(grpc_pollset_set* /*pss*/) {} + +static void pollset_set_add_fd(grpc_pollset_set* /*pss*/, grpc_fd* /*fd*/) {} + +static void pollset_set_del_fd(grpc_pollset_set* /*pss*/, grpc_fd* /*fd*/) {} + +static void pollset_set_add_pollset(grpc_pollset_set* /*pss*/, + grpc_pollset* /*ps*/) {} + +static void pollset_set_del_pollset(grpc_pollset_set* /*pss*/, + grpc_pollset* /*ps*/) {} + +static void pollset_set_add_pollset_set(grpc_pollset_set* /*bag*/, + grpc_pollset_set* /*item*/) {} + +static void pollset_set_del_pollset_set(grpc_pollset_set* /*bag*/, + grpc_pollset_set* /*item*/) {} + +/******************************************************************************* + * Event engine binding + */ + +static bool is_any_background_poller_thread(void) { return false; } + +static void shutdown_background_closure(void) {} + +static bool add_closure_to_background_poller(grpc_closure* /*closure*/, + grpc_error_handle /*error*/) { + return false; +} + +static void shutdown_engine(void) { + fd_global_shutdown(); + pollset_global_shutdown(); + epoll_set_shutdown(); + if (grpc_core::Fork::Enabled()) { + gpr_mu_destroy(&fork_fd_list_mu); + grpc_core::Fork::SetResetChildPollingEngineFunc(nullptr); + } +} + +static const grpc_event_engine_vtable vtable = { + sizeof(grpc_pollset), + true, + false, + + fd_create, + fd_wrapped_fd, + fd_orphan, + fd_shutdown, + fd_notify_on_read, + fd_notify_on_write, + fd_notify_on_error, + fd_become_readable, + fd_become_writable, + fd_has_errors, + fd_is_shutdown, + + pollset_init, + pollset_shutdown, + pollset_destroy, + pollset_work, + pollset_kick, + pollset_add_fd, + + pollset_set_create, + pollset_set_destroy, + pollset_set_add_pollset, + pollset_set_del_pollset, + pollset_set_add_pollset_set, + pollset_set_del_pollset_set, + pollset_set_add_fd, + pollset_set_del_fd, + + is_any_background_poller_thread, + shutdown_background_closure, + shutdown_engine, + add_closure_to_background_poller, +}; + +/* Called by the child process's post-fork handler to close open fds, including + * the global epoll fd. This allows gRPC to shutdown in the child process + * without interfering with connections or RPCs ongoing in the parent. */ +static void reset_event_manager_on_fork() { + gpr_mu_lock(&fork_fd_list_mu); + while (fork_fd_list_head != nullptr) { + close(fork_fd_list_head->fd); + fork_fd_list_head->fd = -1; + fork_fd_list_head = fork_fd_list_head->fork_fd_list->next; + } + gpr_mu_unlock(&fork_fd_list_mu); + shutdown_engine(); + grpc_init_epoll1_linux(true); +} + +/* It is possible that GLIBC has epoll but the underlying kernel doesn't. 
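 *
 * [Editor's note — illustrative clarification, not part of the original
 *  change. "GLIBC has epoll but the underlying kernel doesn't" means the
 *  program links fine but the syscall fails at runtime, e.g.
 *
 *    int fd = epoll_create1(EPOLL_CLOEXEC);   // returns -1, errno typically ENOSYS
 *
 *  epoll_set_init() turns that failure into a false return, so
 *  grpc_init_epoll1_linux() below returns nullptr and the caller is free to
 *  fall back to a different polling engine.]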
+ * Create epoll_fd (epoll_set_init() takes care of that) to make sure epoll + * support is available */ +const grpc_event_engine_vtable* grpc_init_epoll1_linux( + bool /*explicit_request*/) { + if (!grpc_has_wakeup_fd()) { + gpr_log(GPR_ERROR, "Skipping epoll1 because of no wakeup fd."); + return nullptr; + } + + if (!epoll_set_init()) { + return nullptr; + } + + fd_global_init(); + + if (!GRPC_LOG_IF_ERROR("pollset_global_init", pollset_global_init())) { + fd_global_shutdown(); + epoll_set_shutdown(); + return nullptr; + } + + if (grpc_core::Fork::Enabled()) { + gpr_mu_init(&fork_fd_list_mu); + grpc_core::Fork::SetResetChildPollingEngineFunc( + reset_event_manager_on_fork); + } + return &vtable; +} + +#else /* defined(GRPC_LINUX_EPOLL) */ +#if defined(GRPC_POSIX_SOCKET_EV_EPOLL1) +#include "src/core/lib/iomgr/ev_epoll1_linux.h" +/* If GRPC_LINUX_EPOLL is not defined, it means epoll is not available. Return + * NULL */ +const grpc_event_engine_vtable* grpc_init_epoll1_linux( + bool /*explicit_request*/) { + return nullptr; +} +#endif /* defined(GRPC_POSIX_SOCKET_EV_EPOLL1) */ +#endif /* !defined(GRPC_LINUX_EPOLL) */ diff --git a/src/core/lib/iomgr/ev_epollex_linux.cc b/src/core/lib/iomgr/ev_epollex_linux.cc new file mode 100644 index 00000000..38e79bc6 --- /dev/null +++ b/src/core/lib/iomgr/ev_epollex_linux.cc @@ -0,0 +1,1654 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/lib/iomgr/port.h" + +/* This polling engine is only relevant on linux kernels supporting epoll() */ +#ifdef GRPC_LINUX_EPOLL_CREATE1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/spinlock.h" +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/block_annotate.h" +#include "src/core/lib/iomgr/ev_epollex_linux.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/is_epollexclusive_available.h" +#include "src/core/lib/iomgr/lockfree_event.h" +#include "src/core/lib/iomgr/sys_epoll_wrapper.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" +#include "src/core/lib/profiling/timers.h" + +// debug aid: create workers on the heap (allows asan to spot +// use-after-destruction) +//#define GRPC_EPOLLEX_CREATE_WORKERS_ON_HEAP 1 + +#define MAX_EPOLL_EVENTS 100 +#define MAX_FDS_IN_CACHE 32 + +grpc_core::DebugOnlyTraceFlag grpc_trace_pollable_refcount(false, + "pollable_refcount"); + +/******************************************************************************* + * pollable Declarations + */ + +typedef enum { PO_MULTI, PO_FD, PO_EMPTY } pollable_type; + +typedef struct pollable pollable; + +/// A pollable is something that can be polled: it has an epoll set to poll on, +/// and a wakeup fd for kicks +/// There are three broad types: +/// - PO_EMPTY - the empty pollable, used before file descriptors are added to +/// a pollset +/// - PO_FD - a pollable containing only one FD - used to optimize single-fd +/// pollsets (which are common with synchronous api usage) +/// - PO_MULTI - a pollable containing many fds +struct pollable { + pollable_type type; // immutable + grpc_core::RefCount refs; + + int epfd; + grpc_wakeup_fd wakeup; + + // The following are relevant only for type PO_FD + grpc_fd* owner_fd; // Set to the owner_fd if the type is PO_FD + gpr_mu owner_orphan_mu; // Synchronizes access to owner_orphaned field + bool owner_orphaned; // Is the owner fd orphaned + + grpc_pollset_set* pollset_set; + pollable* next; + pollable* prev; + + gpr_mu mu; + grpc_pollset_worker* root_worker; + + int event_cursor; + int event_count; + struct epoll_event events[MAX_EPOLL_EVENTS]; +}; + +static const char* pollable_type_string(pollable_type t) { + switch (t) { + case PO_MULTI: + return "pollset"; + case PO_FD: + return "fd"; + case PO_EMPTY: + return "empty"; + } + return ""; +} + +static std::string pollable_desc(pollable* p) { + return absl::StrFormat("type=%s epfd=%d wakeup=%d", + pollable_type_string(p->type), p->epfd, + p->wakeup.read_fd); +} + +/// Shared empty pollable - used by pollset to poll on until the first fd is +/// added +static pollable* g_empty_pollable; + +static grpc_error_handle pollable_create(pollable_type type, pollable** p); +static pollable* pollable_ref(pollable* p, + const grpc_core::DebugLocation& dbg_loc, + const char* reason) { + p->refs.Ref(dbg_loc, reason); + return p; +} +static void pollable_unref(pollable* p, const grpc_core::DebugLocation& dbg_loc, + const char* reason) { + if (p == nullptr) return; + if 
(GPR_UNLIKELY(p != nullptr && p->refs.Unref(dbg_loc, reason))) { + GRPC_FD_TRACE("pollable_unref: Closing epfd: %d", p->epfd); + close(p->epfd); + grpc_wakeup_fd_destroy(&p->wakeup); + gpr_mu_destroy(&p->owner_orphan_mu); + gpr_mu_destroy(&p->mu); + gpr_free(p); + } +} +#define POLLABLE_REF(p, r) pollable_ref((p), DEBUG_LOCATION, (r)) +#define POLLABLE_UNREF(p, r) pollable_unref((p), DEBUG_LOCATION, (r)) + +/******************************************************************************* + * Fd Declarations + */ + +struct grpc_fd { + grpc_fd(int fd, const char* name, bool track_err) + : fd(fd), track_err(track_err) { + gpr_mu_init(&orphan_mu); + gpr_mu_init(&pollable_mu); + read_closure.InitEvent(); + write_closure.InitEvent(); + error_closure.InitEvent(); + + std::string fd_name = absl::StrCat(name, " fd=", fd); + grpc_iomgr_register_object(&iomgr_object, fd_name.c_str()); +#ifndef NDEBUG + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_fd_refcount)) { + gpr_log(GPR_DEBUG, "FD %d %p create %s", fd, this, fd_name.c_str()); + } +#endif + } + + // This is really the dtor, but the poller threads waking up from + // epoll_wait() may access the (read|write|error)_closure after destruction. + // Since the object will be added to the free pool, this behavior is + // not going to cause issues, except spurious events if the FD is reused + // while the race happens. + void destroy() { + grpc_iomgr_unregister_object(&iomgr_object); + + POLLABLE_UNREF(pollable_obj, "fd_pollable"); + + // To clear out the allocations of pollset_fds, we need to swap its + // contents with a newly-constructed (and soon to be destructed) local + // variable of its same type. This is because InlinedVector::clear is _not_ + // guaranteed to actually free up allocations and this is important since + // this object doesn't have a conventional destructor. + absl::InlinedVector pollset_fds_tmp; + pollset_fds_tmp.swap(pollset_fds); + + gpr_mu_destroy(&pollable_mu); + gpr_mu_destroy(&orphan_mu); + + read_closure.DestroyEvent(); + write_closure.DestroyEvent(); + error_closure.DestroyEvent(); + + invalidate(); + } + +#ifndef NDEBUG + /* Since an fd is never really destroyed (i.e gpr_free() is not called), it is + * hard-to-debug cases where fd fields are accessed even after calling + * fd_destroy(). The following invalidates fd fields to make catching such + * errors easier */ + void invalidate() { + fd = -1; + gpr_atm_no_barrier_store(&refst, -1); + memset(&orphan_mu, -1, sizeof(orphan_mu)); + memset(&pollable_mu, -1, sizeof(pollable_mu)); + pollable_obj = nullptr; + on_done_closure = nullptr; + memset(&iomgr_object, -1, sizeof(iomgr_object)); + track_err = false; + } +#else + void invalidate() {} +#endif + + int fd; + + // refst format: + // bit 0 : 1=Active / 0=Orphaned + // bits 1-n : refcount + // Ref/Unref by two to avoid altering the orphaned bit + gpr_atm refst = 1; + + gpr_mu orphan_mu; + + // Protects pollable_obj and pollset_fds. + gpr_mu pollable_mu; + absl::InlinedVector pollset_fds; // Used in PO_MULTI. + pollable* pollable_obj = nullptr; // Used in PO_FD. + + grpc_core::LockfreeEvent read_closure; + grpc_core::LockfreeEvent write_closure; + grpc_core::LockfreeEvent error_closure; + + struct grpc_fd* freelist_next = nullptr; + grpc_closure* on_done_closure = nullptr; + + grpc_iomgr_object iomgr_object; + + // Do we need to track EPOLLERR events separately? 
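  // [Editor's note — illustrative, not part of the original change. When
  // track_err is true, EPOLLERR is surfaced through the error closure
  // instead of being folded into read/write readiness. The flag is also
  // packed into the epoll user data so it can be read without touching the
  // (possibly recycled) grpc_fd, matching pollable_add_fd() and
  // pollable_process_events() below:
  //   ev.data.ptr = (void*)((intptr_t)fd | (track_err ? 2 : 0));  // pack
  //   bool track_err = (intptr_t)data_ptr & 2;                    // unpack
  // (the low bit, 1, is reserved for tagging a pollable's wakeup fd).]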
+ bool track_err; +}; + +static void fd_global_init(void); +static void fd_global_shutdown(void); + +/******************************************************************************* + * Pollset Declarations + */ + +struct pwlink { + grpc_pollset_worker* next; + grpc_pollset_worker* prev; +}; +typedef enum { PWLINK_POLLABLE = 0, PWLINK_POLLSET, PWLINK_COUNT } pwlinks; + +struct grpc_pollset_worker { + bool kicked; + bool initialized_cv; +#ifndef NDEBUG + // debug aid: which thread started this worker + pid_t originator; +#endif + gpr_cv cv; + grpc_pollset* pollset; + pollable* pollable_obj; + + pwlink links[PWLINK_COUNT]; +}; + +struct grpc_pollset { + gpr_mu mu; + gpr_atm worker_count; + gpr_atm active_pollable_type; + pollable* active_pollable; + bool kicked_without_poller; + grpc_closure* shutdown_closure; + bool already_shutdown; + grpc_pollset_worker* root_worker; + int containing_pollset_set_count; +}; + +/******************************************************************************* + * Pollset-set Declarations + */ + +struct grpc_pollset_set { + grpc_core::RefCount refs; + gpr_mu mu; + grpc_pollset_set* parent; + + size_t pollset_count; + size_t pollset_capacity; + grpc_pollset** pollsets; + + size_t fd_count; + size_t fd_capacity; + grpc_fd** fds; +}; + +/******************************************************************************* + * Common helpers + */ + +static bool append_error(grpc_error_handle* composite, grpc_error_handle error, + const char* desc) { + if (error == GRPC_ERROR_NONE) return true; + if (*composite == GRPC_ERROR_NONE) { + *composite = GRPC_ERROR_CREATE_FROM_COPIED_STRING(desc); + } + *composite = grpc_error_add_child(*composite, error); + return false; +} + +/******************************************************************************* + * Fd Definitions + */ + +/* We need to keep a freelist not because of any concerns of malloc performance + * but instead so that implementations with multiple threads in (for example) + * epoll_wait deal with the race between pollset removal and incoming poll + * notifications. + * + * The problem is that the poller ultimately holds a reference to this + * object, so it is very difficult to know when is safe to free it, at least + * without some expensive synchronization. + * + * If we keep the object freelisted, in the worst case losing this race just + * becomes a spurious read notification on a reused fd. 
+ */ + +static grpc_fd* fd_freelist = nullptr; +static gpr_mu fd_freelist_mu; + +#ifndef NDEBUG +#define REF_BY(fd, n, reason) ref_by(fd, n, reason, __FILE__, __LINE__) +#define UNREF_BY(fd, n, reason) unref_by(fd, n, reason, __FILE__, __LINE__) +static void ref_by(grpc_fd* fd, int n, const char* reason, const char* file, + int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_fd_refcount)) { + gpr_log(GPR_DEBUG, + "FD %d %p ref %d %" PRIdPTR " -> %" PRIdPTR " [%s; %s:%d]", + fd->fd, fd, n, gpr_atm_no_barrier_load(&fd->refst), + gpr_atm_no_barrier_load(&fd->refst) + n, reason, file, line); + } +#else +#define REF_BY(fd, n, reason) \ + do { \ + ref_by(fd, n); \ + (void)(reason); \ + } while (0) +#define UNREF_BY(fd, n, reason) \ + do { \ + unref_by(fd, n); \ + (void)(reason); \ + } while (0) +static void ref_by(grpc_fd* fd, int n) { +#endif + GPR_ASSERT(gpr_atm_no_barrier_fetch_add(&fd->refst, n) > 0); +} + +/* Uninitialize and add to the freelist */ +static void fd_destroy(void* arg, grpc_error_handle /*error*/) { + grpc_fd* fd = static_cast(arg); + fd->destroy(); + + /* Add the fd to the freelist */ + gpr_mu_lock(&fd_freelist_mu); + fd->freelist_next = fd_freelist; + fd_freelist = fd; + gpr_mu_unlock(&fd_freelist_mu); +} + +#ifndef NDEBUG +static void unref_by(grpc_fd* fd, int n, const char* reason, const char* file, + int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_fd_refcount)) { + gpr_log(GPR_DEBUG, + "FD %d %p unref %d %" PRIdPTR " -> %" PRIdPTR " [%s; %s:%d]", + fd->fd, fd, n, gpr_atm_no_barrier_load(&fd->refst), + gpr_atm_no_barrier_load(&fd->refst) - n, reason, file, line); + } +#else +static void unref_by(grpc_fd* fd, int n) { +#endif + gpr_atm old = gpr_atm_full_fetch_add(&fd->refst, -n); + if (old == n) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_CREATE(fd_destroy, fd, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); + } else { + GPR_ASSERT(old > n); + } +} + +static void fd_global_init(void) { gpr_mu_init(&fd_freelist_mu); } + +static void fd_global_shutdown(void) { + // TODO(guantaol): We don't have a reasonable explanation about this + // lock()/unlock() pattern. It can be a valid barrier if there is at most one + // pending lock() at this point. Otherwise, there is still a possibility of + // use-after-free race. Need to reason about the code and/or clean it up. + gpr_mu_lock(&fd_freelist_mu); + gpr_mu_unlock(&fd_freelist_mu); + while (fd_freelist != nullptr) { + grpc_fd* fd = fd_freelist; + fd_freelist = fd_freelist->freelist_next; + gpr_free(fd); + } + gpr_mu_destroy(&fd_freelist_mu); +} + +static grpc_fd* fd_create(int fd, const char* name, bool track_err) { + grpc_fd* new_fd = nullptr; + + gpr_mu_lock(&fd_freelist_mu); + if (fd_freelist != nullptr) { + new_fd = fd_freelist; + fd_freelist = fd_freelist->freelist_next; + } + gpr_mu_unlock(&fd_freelist_mu); + + if (new_fd == nullptr) { + new_fd = static_cast(gpr_malloc(sizeof(grpc_fd))); + } + + return new (new_fd) grpc_fd(fd, name, track_err); +} + +static int fd_wrapped_fd(grpc_fd* fd) { + int ret_fd = fd->fd; + return (gpr_atm_acq_load(&fd->refst) & 1) ? ret_fd : -1; +} + +static void fd_orphan(grpc_fd* fd, grpc_closure* on_done, int* release_fd, + const char* reason) { + bool is_fd_closed = false; + + gpr_mu_lock(&fd->orphan_mu); + + // Get the fd->pollable_obj and set the owner_orphaned on that pollable to + // true so that the pollable will no longer access its owner_fd field. 
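  // [Editor's note — illustrative, not part of the original change. The
  // PO_FD pollable can outlive its owning grpc_fd (pollers may still hold
  // references to it), so owner_orphaned is the signal that owner_fd must
  // no longer be dereferenced. Per the field comments in struct pollable,
  // any later read takes the same lock, roughly (p being that pollable):
  //   gpr_mu_lock(&p->owner_orphan_mu);
  //   if (!p->owner_orphaned) { ... p->owner_fd is safe to use here ... }
  //   gpr_mu_unlock(&p->owner_orphan_mu);
  // ]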
+ gpr_mu_lock(&fd->pollable_mu); + pollable* pollable_obj = fd->pollable_obj; + + if (pollable_obj) { + gpr_mu_lock(&pollable_obj->owner_orphan_mu); + pollable_obj->owner_orphaned = true; + } + + fd->on_done_closure = on_done; + + /* If release_fd is not NULL, we should be relinquishing control of the file + descriptor fd->fd (but we still own the grpc_fd structure). */ + if (release_fd != nullptr) { + // Remove the FD from all epolls sets, before releasing it. + // Otherwise, we will receive epoll events after we release the FD. + epoll_event ev_fd; + memset(&ev_fd, 0, sizeof(ev_fd)); + if (pollable_obj != nullptr) { // For PO_FD. + epoll_ctl(pollable_obj->epfd, EPOLL_CTL_DEL, fd->fd, &ev_fd); + } + for (size_t i = 0; i < fd->pollset_fds.size(); ++i) { // For PO_MULTI. + const int epfd = fd->pollset_fds[i]; + epoll_ctl(epfd, EPOLL_CTL_DEL, fd->fd, &ev_fd); + } + *release_fd = fd->fd; + } else { + close(fd->fd); + is_fd_closed = true; + } + + // TODO(sreek): handle fd removal (where is_fd_closed=false) + if (!is_fd_closed) { + GRPC_FD_TRACE("epoll_fd %p (%d) was orphaned but not closed.", fd, fd->fd); + } + + /* Remove the active status but keep referenced. We want this grpc_fd struct + to be alive (and not added to freelist) until the end of this function */ + REF_BY(fd, 1, reason); + + grpc_core::ExecCtx::Run(DEBUG_LOCATION, fd->on_done_closure, GRPC_ERROR_NONE); + + if (pollable_obj) { + gpr_mu_unlock(&pollable_obj->owner_orphan_mu); + } + + gpr_mu_unlock(&fd->pollable_mu); + gpr_mu_unlock(&fd->orphan_mu); + + UNREF_BY(fd, 2, reason); /* Drop the reference */ +} + +static bool fd_is_shutdown(grpc_fd* fd) { + return fd->read_closure.IsShutdown(); +} + +/* Might be called multiple times */ +static void fd_shutdown(grpc_fd* fd, grpc_error_handle why) { + if (fd->read_closure.SetShutdown(GRPC_ERROR_REF(why))) { + if (shutdown(fd->fd, SHUT_RDWR)) { + if (errno != ENOTCONN) { + gpr_log(GPR_ERROR, "Error shutting down fd %d. 
errno: %d", + grpc_fd_wrapped_fd(fd), errno); + } + } + fd->write_closure.SetShutdown(GRPC_ERROR_REF(why)); + fd->error_closure.SetShutdown(GRPC_ERROR_REF(why)); + } + GRPC_ERROR_UNREF(why); +} + +static void fd_notify_on_read(grpc_fd* fd, grpc_closure* closure) { + fd->read_closure.NotifyOn(closure); +} + +static void fd_notify_on_write(grpc_fd* fd, grpc_closure* closure) { + fd->write_closure.NotifyOn(closure); +} + +static void fd_notify_on_error(grpc_fd* fd, grpc_closure* closure) { + fd->error_closure.NotifyOn(closure); +} + +static bool fd_has_pollset(grpc_fd* fd, grpc_pollset* pollset) { + const int epfd = pollset->active_pollable->epfd; + grpc_core::MutexLockForGprMu lock(&fd->pollable_mu); + for (size_t i = 0; i < fd->pollset_fds.size(); ++i) { + if (fd->pollset_fds[i] == epfd) { + return true; + } + } + return false; +} + +static void fd_add_pollset(grpc_fd* fd, grpc_pollset* pollset) { + const int epfd = pollset->active_pollable->epfd; + grpc_core::MutexLockForGprMu lock(&fd->pollable_mu); + fd->pollset_fds.push_back(epfd); +} + +/******************************************************************************* + * Pollable Definitions + */ + +static grpc_error_handle pollable_create(pollable_type type, pollable** p) { + *p = nullptr; + + int epfd = epoll_create1(EPOLL_CLOEXEC); + if (epfd == -1) { + return GRPC_OS_ERROR(errno, "epoll_create1"); + } + GRPC_FD_TRACE("Pollable_create: created epfd: %d (type: %d)", epfd, type); + *p = static_cast(gpr_malloc(sizeof(**p))); + grpc_error_handle err = grpc_wakeup_fd_init(&(*p)->wakeup); + if (err != GRPC_ERROR_NONE) { + GRPC_FD_TRACE( + "Pollable_create: closed epfd: %d (type: %d). wakeupfd_init error", + epfd, type); + close(epfd); + gpr_free(*p); + *p = nullptr; + return err; + } + struct epoll_event ev; + ev.events = static_cast(EPOLLIN | EPOLLET); + ev.data.ptr = + reinterpret_cast(1 | reinterpret_cast(&(*p)->wakeup)); + if (epoll_ctl(epfd, EPOLL_CTL_ADD, (*p)->wakeup.read_fd, &ev) != 0) { + err = GRPC_OS_ERROR(errno, "epoll_ctl"); + GRPC_FD_TRACE( + "Pollable_create: closed epfd: %d (type: %d). epoll_ctl error", epfd, + type); + close(epfd); + grpc_wakeup_fd_destroy(&(*p)->wakeup); + gpr_free(*p); + *p = nullptr; + return err; + } + + (*p)->type = type; + new (&(*p)->refs) grpc_core::RefCount( + 1, GRPC_TRACE_FLAG_ENABLED(grpc_trace_pollable_refcount) + ? "pollable_refcount" + : nullptr); + gpr_mu_init(&(*p)->mu); + (*p)->epfd = epfd; + (*p)->owner_fd = nullptr; + gpr_mu_init(&(*p)->owner_orphan_mu); + (*p)->owner_orphaned = false; + (*p)->pollset_set = nullptr; + (*p)->next = (*p)->prev = *p; + (*p)->root_worker = nullptr; + (*p)->event_cursor = 0; + (*p)->event_count = 0; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle pollable_add_fd(pollable* p, grpc_fd* fd) { + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* err_desc = "pollable_add_fd"; + const int epfd = p->epfd; + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "add fd %p (%d) to pollable %p", fd, fd->fd, p); + } + + struct epoll_event ev_fd; + ev_fd.events = + static_cast(EPOLLET | EPOLLIN | EPOLLOUT | EPOLLEXCLUSIVE); + /* Use the second least significant bit of ev_fd.data.ptr to store track_err + * to avoid synchronization issues when accessing it after receiving an event. + * Accessing fd would be a data race there because the fd might have been + * returned to the free list at that point. */ + ev_fd.data.ptr = reinterpret_cast(reinterpret_cast(fd) | + (fd->track_err ? 
2 : 0)); + GRPC_STATS_INC_SYSCALL_EPOLL_CTL(); + if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd->fd, &ev_fd) != 0) { + switch (errno) { + case EEXIST: + break; + default: + append_error(&error, GRPC_OS_ERROR(errno, "epoll_ctl"), err_desc); + } + } + + return error; +} + +/******************************************************************************* + * Pollset Definitions + */ + +static GPR_THREAD_LOCAL(grpc_pollset*) g_current_thread_pollset; +static GPR_THREAD_LOCAL(grpc_pollset_worker*) g_current_thread_worker; + +/* Global state management */ +static grpc_error_handle pollset_global_init(void) { + return pollable_create(PO_EMPTY, &g_empty_pollable); +} + +static void pollset_global_shutdown(void) { + POLLABLE_UNREF(g_empty_pollable, "g_empty_pollable"); +} + +/* pollset->mu must be held while calling this function */ +static void pollset_maybe_finish_shutdown(grpc_pollset* pollset) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, + "PS:%p (pollable:%p) maybe_finish_shutdown sc=%p (target:!NULL) " + "rw=%p (target:NULL) cpsc=%d (target:0)", + pollset, pollset->active_pollable, pollset->shutdown_closure, + pollset->root_worker, pollset->containing_pollset_set_count); + } + if (pollset->shutdown_closure != nullptr && pollset->root_worker == nullptr && + pollset->containing_pollset_set_count == 0) { + GPR_TIMER_MARK("pollset_finish_shutdown", 0); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, pollset->shutdown_closure, + GRPC_ERROR_NONE); + pollset->shutdown_closure = nullptr; + pollset->already_shutdown = true; + } +} + +/* pollset->mu must be held before calling this function, + * pollset->active_pollable->mu & specific_worker->pollable_obj->mu must not be + * held */ +static grpc_error_handle kick_one_worker(grpc_pollset_worker* specific_worker) { + GPR_TIMER_SCOPE("kick_one_worker", 0); + pollable* p = specific_worker->pollable_obj; + grpc_core::MutexLockForGprMu lock(&p->mu); + GPR_ASSERT(specific_worker != nullptr); + if (specific_worker->kicked) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p kicked_specific_but_already_kicked", p); + } + GRPC_STATS_INC_POLLSET_KICKED_AGAIN(); + return GRPC_ERROR_NONE; + } + if (g_current_thread_worker == specific_worker) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p kicked_specific_but_awake", p); + } + GRPC_STATS_INC_POLLSET_KICK_OWN_THREAD(); + specific_worker->kicked = true; + return GRPC_ERROR_NONE; + } + if (specific_worker == p->root_worker) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_FD(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p kicked_specific_via_wakeup_fd", p); + } + specific_worker->kicked = true; + grpc_error_handle error = grpc_wakeup_fd_wakeup(&p->wakeup); + return error; + } + if (specific_worker->initialized_cv) { + GRPC_STATS_INC_POLLSET_KICK_WAKEUP_CV(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p kicked_specific_via_cv", p); + } + specific_worker->kicked = true; + gpr_cv_signal(&specific_worker->cv); + return GRPC_ERROR_NONE; + } + // we can get here during end_worker after removing specific_worker from the + // pollable list but before removing it from the pollset list + return GRPC_ERROR_NONE; +} + +static grpc_error_handle pollset_kick(grpc_pollset* pollset, + grpc_pollset_worker* specific_worker) { + GPR_TIMER_SCOPE("pollset_kick", 0); + GRPC_STATS_INC_POLLSET_KICK(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, + "PS:%p kick %p tls_pollset=%p 
tls_worker=%p pollset.root_worker=%p", + pollset, specific_worker, + static_cast(g_current_thread_pollset), + static_cast(g_current_thread_worker), pollset->root_worker); + } + if (specific_worker == nullptr) { + if (g_current_thread_pollset != pollset) { + if (pollset->root_worker == nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p kicked_any_without_poller", pollset); + } + GRPC_STATS_INC_POLLSET_KICKED_WITHOUT_POLLER(); + pollset->kicked_without_poller = true; + return GRPC_ERROR_NONE; + } else { + // We've been asked to kick a poller, but we haven't been told which one + // ... any will do + // We look at the pollset worker list because: + // 1. the pollable list may include workers from other pollers, so we'd + // need to do an O(N) search + // 2. we'd additionally need to take the pollable lock, which we've so + // far avoided + // Now, we would prefer to wake a poller in cv_wait, and not in + // epoll_wait (since the latter would imply the need to do an additional + // wakeup) + // We know that if a worker is at the root of a pollable, it's (likely) + // also the root of a pollset, and we know that if a worker is NOT at + // the root of a pollset, it's (likely) not at the root of a pollable, + // so we take our chances and choose the SECOND worker enqueued against + // the pollset as a worker that's likely to be in cv_wait + return kick_one_worker( + pollset->root_worker->links[PWLINK_POLLSET].next); + } + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p kicked_any_but_awake", pollset); + } + GRPC_STATS_INC_POLLSET_KICK_OWN_THREAD(); + return GRPC_ERROR_NONE; + } + } else { + return kick_one_worker(specific_worker); + } +} + +static grpc_error_handle pollset_kick_all(grpc_pollset* pollset) { + GPR_TIMER_SCOPE("pollset_kick_all", 0); + grpc_error_handle error = GRPC_ERROR_NONE; + const char* err_desc = "pollset_kick_all"; + grpc_pollset_worker* w = pollset->root_worker; + if (w != nullptr) { + do { + GRPC_STATS_INC_POLLSET_KICK(); + append_error(&error, kick_one_worker(w), err_desc); + w = w->links[PWLINK_POLLSET].next; + } while (w != pollset->root_worker); + } + return error; +} + +static void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + gpr_mu_init(&pollset->mu); + gpr_atm_no_barrier_store(&pollset->worker_count, 0); + gpr_atm_no_barrier_store(&pollset->active_pollable_type, PO_EMPTY); + pollset->active_pollable = POLLABLE_REF(g_empty_pollable, "pollset"); + pollset->kicked_without_poller = false; + pollset->shutdown_closure = nullptr; + pollset->already_shutdown = false; + pollset->root_worker = nullptr; + pollset->containing_pollset_set_count = 0; + *mu = &pollset->mu; +} + +static int poll_deadline_to_millis_timeout(grpc_millis millis) { + if (millis == GRPC_MILLIS_INF_FUTURE) return -1; + grpc_millis delta = millis - grpc_core::ExecCtx::Get()->Now(); + if (delta > INT_MAX) { + return INT_MAX; + } else if (delta < 0) { + return 0; + } else { + return static_cast(delta); + } +} + +static void fd_become_readable(grpc_fd* fd) { fd->read_closure.SetReady(); } + +static void fd_become_writable(grpc_fd* fd) { fd->write_closure.SetReady(); } + +static void fd_has_errors(grpc_fd* fd) { fd->error_closure.SetReady(); } + +/* Get the pollable_obj attached to this fd. 
If none is attached, create a new + * pollable object (of type PO_FD), attach it to the fd and return it + * + * Note that if a pollable object is already attached to the fd, it may be of + * either PO_FD or PO_MULTI type */ +static grpc_error_handle get_fd_pollable(grpc_fd* fd, pollable** p) { + gpr_mu_lock(&fd->pollable_mu); + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* err_desc = "get_fd_pollable"; + if (fd->pollable_obj == nullptr) { + if (append_error(&error, pollable_create(PO_FD, &fd->pollable_obj), + err_desc)) { + fd->pollable_obj->owner_fd = fd; + if (!append_error(&error, pollable_add_fd(fd->pollable_obj, fd), + err_desc)) { + POLLABLE_UNREF(fd->pollable_obj, "fd_pollable"); + fd->pollable_obj = nullptr; + } + } + } + if (error == GRPC_ERROR_NONE) { + GPR_ASSERT(fd->pollable_obj != nullptr); + *p = POLLABLE_REF(fd->pollable_obj, "pollset"); + } else { + GPR_ASSERT(fd->pollable_obj == nullptr); + *p = nullptr; + } + gpr_mu_unlock(&fd->pollable_mu); + return error; +} + +/* pollset->po.mu lock must be held by the caller before calling this */ +static void pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + GPR_TIMER_SCOPE("pollset_shutdown", 0); + GPR_ASSERT(pollset->shutdown_closure == nullptr); + pollset->shutdown_closure = closure; + GRPC_LOG_IF_ERROR("pollset_shutdown", pollset_kick_all(pollset)); + pollset_maybe_finish_shutdown(pollset); +} + +static grpc_error_handle pollable_process_events(grpc_pollset* pollset, + pollable* pollable_obj, + bool drain) { + GPR_TIMER_SCOPE("pollable_process_events", 0); + static const char* err_desc = "pollset_process_events"; + // Use a simple heuristic to determine how many fd events to process + // per loop iteration. (events/workers) + int handle_count = 1; + int worker_count = gpr_atm_no_barrier_load(&pollset->worker_count); + GPR_ASSERT(worker_count > 0); + handle_count = + (pollable_obj->event_count - pollable_obj->event_cursor) / worker_count; + if (handle_count == 0) { + handle_count = 1; + } + grpc_error_handle error = GRPC_ERROR_NONE; + for (int i = 0; (drain || i < handle_count) && + pollable_obj->event_cursor != pollable_obj->event_count; + i++) { + int n = pollable_obj->event_cursor++; + struct epoll_event* ev = &pollable_obj->events[n]; + void* data_ptr = ev->data.ptr; + if (1 & reinterpret_cast(data_ptr)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p got pollset_wakeup %p", pollset, data_ptr); + } + append_error( + &error, + grpc_wakeup_fd_consume_wakeup(reinterpret_cast( + ~static_cast(1) & + reinterpret_cast(data_ptr))), + err_desc); + } else { + grpc_fd* fd = + reinterpret_cast(reinterpret_cast(data_ptr) & ~2); + bool track_err = reinterpret_cast(data_ptr) & 2; + bool cancel = (ev->events & EPOLLHUP) != 0; + bool error = (ev->events & EPOLLERR) != 0; + bool read_ev = (ev->events & (EPOLLIN | EPOLLPRI)) != 0; + bool write_ev = (ev->events & EPOLLOUT) != 0; + bool err_fallback = error && !track_err; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, + "PS:%p got fd %p: cancel=%d read=%d " + "write=%d", + pollset, fd, cancel, read_ev, write_ev); + } + if (error && !err_fallback) { + fd_has_errors(fd); + } + if (read_ev || cancel || err_fallback) { + fd_become_readable(fd); + } + if (write_ev || cancel || err_fallback) { + fd_become_writable(fd); + } + } + } + + return error; +} + +/* pollset_shutdown is guaranteed to be called before pollset_destroy. 
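+   At that point pollset_destroy() below only needs to drop the pollset's
+   reference on its active pollable and destroy the pollset mutex.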
*/ +static void pollset_destroy(grpc_pollset* pollset) { + POLLABLE_UNREF(pollset->active_pollable, "pollset"); + pollset->active_pollable = nullptr; + gpr_mu_destroy(&pollset->mu); +} + +static grpc_error_handle pollable_epoll(pollable* p, grpc_millis deadline) { + GPR_TIMER_SCOPE("pollable_epoll", 0); + int timeout = poll_deadline_to_millis_timeout(deadline); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "POLLABLE:%p[%s] poll for %dms", p, + pollable_desc(p).c_str(), timeout); + } + + if (timeout != 0) { + GRPC_SCHEDULING_START_BLOCKING_REGION; + } + int r; + do { + GRPC_STATS_INC_SYSCALL_POLL(); + r = epoll_wait(p->epfd, p->events, MAX_EPOLL_EVENTS, timeout); + } while (r < 0 && errno == EINTR); + if (timeout != 0) { + GRPC_SCHEDULING_END_BLOCKING_REGION; + } + + if (r < 0) return GRPC_OS_ERROR(errno, "epoll_wait"); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "POLLABLE:%p got %d events", p, r); + } + + p->event_cursor = 0; + p->event_count = r; + + return GRPC_ERROR_NONE; +} + +/* Return true if first in list */ +static bool worker_insert(grpc_pollset_worker** root_worker, + grpc_pollset_worker* worker, pwlinks link) { + if (*root_worker == nullptr) { + *root_worker = worker; + worker->links[link].next = worker->links[link].prev = worker; + return true; + } else { + worker->links[link].next = *root_worker; + worker->links[link].prev = worker->links[link].next->links[link].prev; + worker->links[link].next->links[link].prev = worker; + worker->links[link].prev->links[link].next = worker; + return false; + } +} + +/* returns the new root IFF the root changed */ +typedef enum { WRR_NEW_ROOT, WRR_EMPTIED, WRR_REMOVED } worker_remove_result; + +static worker_remove_result worker_remove(grpc_pollset_worker** root_worker, + grpc_pollset_worker* worker, + pwlinks link) { + if (worker == *root_worker) { + if (worker == worker->links[link].next) { + *root_worker = nullptr; + return WRR_EMPTIED; + } else { + *root_worker = worker->links[link].next; + worker->links[link].prev->links[link].next = worker->links[link].next; + worker->links[link].next->links[link].prev = worker->links[link].prev; + return WRR_NEW_ROOT; + } + } else { + worker->links[link].prev->links[link].next = worker->links[link].next; + worker->links[link].next->links[link].prev = worker->links[link].prev; + return WRR_REMOVED; + } +} + +/* Return true if this thread should poll */ +static bool begin_worker(grpc_pollset* pollset, grpc_pollset_worker* worker, + grpc_pollset_worker** worker_hdl, + grpc_millis deadline) { + GPR_TIMER_SCOPE("begin_worker", 0); + bool do_poll = + (pollset->shutdown_closure == nullptr && !pollset->already_shutdown); + gpr_atm_no_barrier_fetch_add(&pollset->worker_count, 1); + if (worker_hdl != nullptr) *worker_hdl = worker; + worker->initialized_cv = false; + worker->kicked = false; + worker->pollset = pollset; + worker->pollable_obj = + POLLABLE_REF(pollset->active_pollable, "pollset_worker"); + worker_insert(&pollset->root_worker, worker, PWLINK_POLLSET); + gpr_mu_lock(&worker->pollable_obj->mu); + if (!worker_insert(&worker->pollable_obj->root_worker, worker, + PWLINK_POLLABLE)) { + worker->initialized_cv = true; + gpr_cv_init(&worker->cv); + gpr_mu_unlock(&pollset->mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace) && + worker->pollable_obj->root_worker != worker) { + gpr_log(GPR_INFO, "PS:%p wait %p w=%p for %dms", pollset, + worker->pollable_obj, worker, + poll_deadline_to_millis_timeout(deadline)); + } + while (do_poll && 
worker->pollable_obj->root_worker != worker) { + if (gpr_cv_wait(&worker->cv, &worker->pollable_obj->mu, + grpc_millis_to_timespec(deadline, GPR_CLOCK_REALTIME))) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p timeout_wait %p w=%p", pollset, + worker->pollable_obj, worker); + } + do_poll = false; + } else if (worker->kicked) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PS:%p wakeup %p w=%p", pollset, + worker->pollable_obj, worker); + } + do_poll = false; + } else if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace) && + worker->pollable_obj->root_worker != worker) { + gpr_log(GPR_INFO, "PS:%p spurious_wakeup %p w=%p", pollset, + worker->pollable_obj, worker); + } + } + grpc_core::ExecCtx::Get()->InvalidateNow(); + } else { + gpr_mu_unlock(&pollset->mu); + } + gpr_mu_unlock(&worker->pollable_obj->mu); + + return do_poll; +} + +static void end_worker(grpc_pollset* pollset, grpc_pollset_worker* worker, + grpc_pollset_worker** /*worker_hdl*/) { + GPR_TIMER_SCOPE("end_worker", 0); + gpr_mu_lock(&pollset->mu); + gpr_mu_lock(&worker->pollable_obj->mu); + switch (worker_remove(&worker->pollable_obj->root_worker, worker, + PWLINK_POLLABLE)) { + case WRR_NEW_ROOT: { + // wakeup new poller + grpc_pollset_worker* new_root = worker->pollable_obj->root_worker; + GPR_ASSERT(new_root->initialized_cv); + gpr_cv_signal(&new_root->cv); + break; + } + case WRR_EMPTIED: + if (pollset->active_pollable != worker->pollable_obj) { + // pollable no longer being polled: flush events + (void)pollable_process_events(pollset, worker->pollable_obj, true); + } + break; + case WRR_REMOVED: + break; + } + gpr_mu_unlock(&worker->pollable_obj->mu); + POLLABLE_UNREF(worker->pollable_obj, "pollset_worker"); + if (worker_remove(&pollset->root_worker, worker, PWLINK_POLLSET) == + WRR_EMPTIED) { + pollset_maybe_finish_shutdown(pollset); + } + if (worker->initialized_cv) { + gpr_cv_destroy(&worker->cv); + } + gpr_atm_no_barrier_fetch_add(&pollset->worker_count, -1); +} + +#ifndef NDEBUG +static long sys_gettid(void) { return syscall(__NR_gettid); } +#endif + +/* pollset->mu lock must be held by the caller before calling this. 
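+   (callers obtain that mutex from pollset_init() through its gpr_mu** out
+   parameter)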
+ The function pollset_work() may temporarily release the lock (pollset->po.mu) + during the course of its execution but it will always re-acquire the lock and + ensure that it is held by the time the function returns */ +static grpc_error_handle pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** worker_hdl, + grpc_millis deadline) { + GPR_TIMER_SCOPE("pollset_work", 0); +#ifdef GRPC_EPOLLEX_CREATE_WORKERS_ON_HEAP + grpc_pollset_worker* worker = + (grpc_pollset_worker*)gpr_malloc(sizeof(*worker)); +#define WORKER_PTR (worker) +#else + grpc_pollset_worker worker; +#define WORKER_PTR (&worker) +#endif +#ifndef NDEBUG + WORKER_PTR->originator = sys_gettid(); +#endif + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, + "PS:%p work hdl=%p worker=%p now=%" PRId64 " deadline=%" PRId64 + " kwp=%d pollable=%p", + pollset, worker_hdl, WORKER_PTR, grpc_core::ExecCtx::Get()->Now(), + deadline, pollset->kicked_without_poller, pollset->active_pollable); + } + static const char* err_desc = "pollset_work"; + grpc_error_handle error = GRPC_ERROR_NONE; + if (pollset->kicked_without_poller) { + pollset->kicked_without_poller = false; + } else { + if (begin_worker(pollset, WORKER_PTR, worker_hdl, deadline)) { + g_current_thread_pollset = pollset; + g_current_thread_worker = WORKER_PTR; + if (WORKER_PTR->pollable_obj->event_cursor == + WORKER_PTR->pollable_obj->event_count) { + append_error(&error, pollable_epoll(WORKER_PTR->pollable_obj, deadline), + err_desc); + } + append_error( + &error, + pollable_process_events(pollset, WORKER_PTR->pollable_obj, false), + err_desc); + grpc_core::ExecCtx::Get()->Flush(); + g_current_thread_pollset = nullptr; + g_current_thread_worker = nullptr; + } + end_worker(pollset, WORKER_PTR, worker_hdl); + } +#ifdef GRPC_EPOLLEX_CREATE_WORKERS_ON_HEAP + gpr_free(worker); +#endif +#undef WORKER_PTR + return error; +} + +static grpc_error_handle pollset_transition_pollable_from_empty_to_fd_locked( + grpc_pollset* pollset, grpc_fd* fd) { + static const char* err_desc = "pollset_transition_pollable_from_empty_to_fd"; + grpc_error_handle error = GRPC_ERROR_NONE; + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, + "PS:%p add fd %p (%d); transition pollable from empty to fd", + pollset, fd, fd->fd); + } + append_error(&error, pollset_kick_all(pollset), err_desc); + POLLABLE_UNREF(pollset->active_pollable, "pollset"); + append_error(&error, get_fd_pollable(fd, &pollset->active_pollable), + err_desc); + return error; +} + +static grpc_error_handle pollset_transition_pollable_from_fd_to_multi_locked( + grpc_pollset* pollset, grpc_fd* and_add_fd) { + static const char* err_desc = "pollset_transition_pollable_from_fd_to_multi"; + grpc_error_handle error = GRPC_ERROR_NONE; + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log( + GPR_INFO, + "PS:%p add fd %p (%d); transition pollable from fd %p to multipoller", + pollset, and_add_fd, and_add_fd ? 
and_add_fd->fd : -1, + pollset->active_pollable->owner_fd); + } + append_error(&error, pollset_kick_all(pollset), err_desc); + grpc_fd* initial_fd = pollset->active_pollable->owner_fd; + POLLABLE_UNREF(pollset->active_pollable, "pollset"); + pollset->active_pollable = nullptr; + if (append_error(&error, pollable_create(PO_MULTI, &pollset->active_pollable), + err_desc)) { + append_error(&error, pollable_add_fd(pollset->active_pollable, initial_fd), + err_desc); + if (and_add_fd != nullptr) { + append_error(&error, + pollable_add_fd(pollset->active_pollable, and_add_fd), + err_desc); + } + } + return error; +} + +/* expects pollsets locked, flag whether fd is locked or not */ +static grpc_error_handle pollset_add_fd_locked(grpc_pollset* pollset, + grpc_fd* fd) { + grpc_error_handle error = GRPC_ERROR_NONE; + pollable* po_at_start = + POLLABLE_REF(pollset->active_pollable, "pollset_add_fd"); + switch (pollset->active_pollable->type) { + case PO_EMPTY: + /* empty pollable --> single fd pollable */ + error = pollset_transition_pollable_from_empty_to_fd_locked(pollset, fd); + break; + case PO_FD: + gpr_mu_lock(&po_at_start->owner_orphan_mu); + if (po_at_start->owner_orphaned) { + error = + pollset_transition_pollable_from_empty_to_fd_locked(pollset, fd); + } else { + /* fd --> multipoller */ + error = + pollset_transition_pollable_from_fd_to_multi_locked(pollset, fd); + } + gpr_mu_unlock(&po_at_start->owner_orphan_mu); + break; + case PO_MULTI: + error = pollable_add_fd(pollset->active_pollable, fd); + break; + } + if (error != GRPC_ERROR_NONE) { + POLLABLE_UNREF(pollset->active_pollable, "pollset"); + pollset->active_pollable = po_at_start; + } else { + gpr_atm_rel_store(&pollset->active_pollable_type, + pollset->active_pollable->type); + POLLABLE_UNREF(po_at_start, "pollset_add_fd"); + } + return error; +} + +static grpc_error_handle pollset_as_multipollable_locked( + grpc_pollset* pollset, pollable** pollable_obj) { + grpc_error_handle error = GRPC_ERROR_NONE; + pollable* po_at_start = + POLLABLE_REF(pollset->active_pollable, "pollset_as_multipollable"); + switch (pollset->active_pollable->type) { + case PO_EMPTY: + POLLABLE_UNREF(pollset->active_pollable, "pollset"); + error = pollable_create(PO_MULTI, &pollset->active_pollable); + /* Any workers currently polling on this pollset must now be woked up so + * that they can pick up the new active_pollable */ + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, + "PS:%p active pollable transition from empty to multi", + pollset); + } + static const char* err_desc = + "pollset_as_multipollable_locked: empty -> multi"; + append_error(&error, pollset_kick_all(pollset), err_desc); + break; + case PO_FD: + gpr_mu_lock(&po_at_start->owner_orphan_mu); + if (po_at_start->owner_orphaned) { + // Unlock before Unref'ing the pollable + gpr_mu_unlock(&po_at_start->owner_orphan_mu); + POLLABLE_UNREF(pollset->active_pollable, "pollset"); + error = pollable_create(PO_MULTI, &pollset->active_pollable); + } else { + error = pollset_transition_pollable_from_fd_to_multi_locked(pollset, + nullptr); + gpr_mu_unlock(&po_at_start->owner_orphan_mu); + } + break; + case PO_MULTI: + break; + } + if (error != GRPC_ERROR_NONE) { + POLLABLE_UNREF(pollset->active_pollable, "pollset"); + pollset->active_pollable = po_at_start; + *pollable_obj = nullptr; + } else { + gpr_atm_rel_store(&pollset->active_pollable_type, + pollset->active_pollable->type); + *pollable_obj = POLLABLE_REF(pollset->active_pollable, "pollset_set"); + POLLABLE_UNREF(po_at_start, 
"pollset_as_multipollable"); + } + return error; +} + +static void pollset_add_fd(grpc_pollset* pollset, grpc_fd* fd) { + GPR_TIMER_SCOPE("pollset_add_fd", 0); + + // We never transition from PO_MULTI to other modes (i.e., PO_FD or PO_EMPTY) + // and, thus, it is safe to simply store and check whether the FD has already + // been added to the active pollable previously. + if (gpr_atm_acq_load(&pollset->active_pollable_type) == PO_MULTI && + fd_has_pollset(fd, pollset)) { + return; + } + + grpc_core::MutexLockForGprMu lock(&pollset->mu); + grpc_error_handle error = pollset_add_fd_locked(pollset, fd); + + // If we are in PO_MULTI mode, we should update the pollsets of the FD. + if (gpr_atm_no_barrier_load(&pollset->active_pollable_type) == PO_MULTI) { + fd_add_pollset(fd, pollset); + } + + GRPC_LOG_IF_ERROR("pollset_add_fd", error); +} + +/******************************************************************************* + * Pollset-set Definitions + */ + +static grpc_pollset_set* pss_lock_adam(grpc_pollset_set* pss) { + gpr_mu_lock(&pss->mu); + while (pss->parent != nullptr) { + gpr_mu_unlock(&pss->mu); + pss = pss->parent; + gpr_mu_lock(&pss->mu); + } + return pss; +} + +static grpc_pollset_set* pollset_set_create(void) { + grpc_pollset_set* pss = + static_cast(gpr_zalloc(sizeof(*pss))); + gpr_mu_init(&pss->mu); + new (&pss->refs) grpc_core::RefCount(); + return pss; +} + +static void pollset_set_unref(grpc_pollset_set* pss) { + if (pss == nullptr) return; + if (GPR_LIKELY(!pss->refs.Unref())) return; + pollset_set_unref(pss->parent); + gpr_mu_destroy(&pss->mu); + for (size_t i = 0; i < pss->pollset_count; i++) { + gpr_mu_lock(&pss->pollsets[i]->mu); + if (0 == --pss->pollsets[i]->containing_pollset_set_count) { + pollset_maybe_finish_shutdown(pss->pollsets[i]); + } + gpr_mu_unlock(&pss->pollsets[i]->mu); + } + for (size_t i = 0; i < pss->fd_count; i++) { + UNREF_BY(pss->fds[i], 2, "pollset_set"); + } + gpr_free(pss->pollsets); + gpr_free(pss->fds); + gpr_free(pss); +} + +static void pollset_set_add_fd(grpc_pollset_set* pss, grpc_fd* fd) { + GPR_TIMER_SCOPE("pollset_set_add_fd", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PSS:%p: add fd %p (%d)", pss, fd, fd->fd); + } + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* err_desc = "pollset_set_add_fd"; + pss = pss_lock_adam(pss); + for (size_t i = 0; i < pss->pollset_count; i++) { + append_error(&error, pollable_add_fd(pss->pollsets[i]->active_pollable, fd), + err_desc); + } + if (pss->fd_count == pss->fd_capacity) { + pss->fd_capacity = std::max(pss->fd_capacity * 2, size_t(8)); + pss->fds = static_cast( + gpr_realloc(pss->fds, pss->fd_capacity * sizeof(*pss->fds))); + } + REF_BY(fd, 2, "pollset_set"); + pss->fds[pss->fd_count++] = fd; + gpr_mu_unlock(&pss->mu); + + GRPC_LOG_IF_ERROR(err_desc, error); +} + +static void pollset_set_del_fd(grpc_pollset_set* pss, grpc_fd* fd) { + GPR_TIMER_SCOPE("pollset_set_del_fd", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PSS:%p: del fd %p", pss, fd); + } + pss = pss_lock_adam(pss); + size_t i; + for (i = 0; i < pss->fd_count; i++) { + if (pss->fds[i] == fd) { + UNREF_BY(fd, 2, "pollset_set"); + break; + } + } + GPR_ASSERT(i != pss->fd_count); + for (; i < pss->fd_count - 1; i++) { + pss->fds[i] = pss->fds[i + 1]; + } + pss->fd_count--; + gpr_mu_unlock(&pss->mu); +} + +static void pollset_set_del_pollset(grpc_pollset_set* pss, grpc_pollset* ps) { + GPR_TIMER_SCOPE("pollset_set_del_pollset", 0); + if 
(GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PSS:%p: del pollset %p", pss, ps); + } + pss = pss_lock_adam(pss); + size_t i; + for (i = 0; i < pss->pollset_count; i++) { + if (pss->pollsets[i] == ps) { + break; + } + } + GPR_ASSERT(i != pss->pollset_count); + for (; i < pss->pollset_count - 1; i++) { + pss->pollsets[i] = pss->pollsets[i + 1]; + } + pss->pollset_count--; + gpr_mu_unlock(&pss->mu); + gpr_mu_lock(&ps->mu); + if (0 == --ps->containing_pollset_set_count) { + pollset_maybe_finish_shutdown(ps); + } + gpr_mu_unlock(&ps->mu); +} + +// add all fds to pollables, and output a new array of unorphaned out_fds +// assumes pollsets are multipollable +static grpc_error_handle add_fds_to_pollsets(grpc_fd** fds, size_t fd_count, + grpc_pollset** pollsets, + size_t pollset_count, + const char* err_desc, + grpc_fd** out_fds, + size_t* out_fd_count) { + GPR_TIMER_SCOPE("add_fds_to_pollsets", 0); + grpc_error_handle error = GRPC_ERROR_NONE; + for (size_t i = 0; i < fd_count; i++) { + gpr_mu_lock(&fds[i]->orphan_mu); + if ((gpr_atm_no_barrier_load(&fds[i]->refst) & 1) == 0) { + gpr_mu_unlock(&fds[i]->orphan_mu); + UNREF_BY(fds[i], 2, "pollset_set"); + } else { + for (size_t j = 0; j < pollset_count; j++) { + append_error(&error, + pollable_add_fd(pollsets[j]->active_pollable, fds[i]), + err_desc); + } + gpr_mu_unlock(&fds[i]->orphan_mu); + out_fds[(*out_fd_count)++] = fds[i]; + } + } + return error; +} + +static void pollset_set_add_pollset(grpc_pollset_set* pss, grpc_pollset* ps) { + GPR_TIMER_SCOPE("pollset_set_add_pollset", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PSS:%p: add pollset %p", pss, ps); + } + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* err_desc = "pollset_set_add_pollset"; + pollable* pollable_obj = nullptr; + gpr_mu_lock(&ps->mu); + if (!GRPC_LOG_IF_ERROR(err_desc, + pollset_as_multipollable_locked(ps, &pollable_obj))) { + GPR_ASSERT(pollable_obj == nullptr); + gpr_mu_unlock(&ps->mu); + return; + } + ps->containing_pollset_set_count++; + gpr_mu_unlock(&ps->mu); + pss = pss_lock_adam(pss); + size_t initial_fd_count = pss->fd_count; + pss->fd_count = 0; + append_error(&error, + add_fds_to_pollsets(pss->fds, initial_fd_count, &ps, 1, err_desc, + pss->fds, &pss->fd_count), + err_desc); + if (pss->pollset_count == pss->pollset_capacity) { + pss->pollset_capacity = std::max(pss->pollset_capacity * 2, size_t(8)); + pss->pollsets = static_cast(gpr_realloc( + pss->pollsets, pss->pollset_capacity * sizeof(*pss->pollsets))); + } + pss->pollsets[pss->pollset_count++] = ps; + gpr_mu_unlock(&pss->mu); + POLLABLE_UNREF(pollable_obj, "pollset_set"); + + GRPC_LOG_IF_ERROR(err_desc, error); +} + +static void pollset_set_add_pollset_set(grpc_pollset_set* a, + grpc_pollset_set* b) { + GPR_TIMER_SCOPE("pollset_set_add_pollset_set", 0); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PSS: merge (%p, %p)", a, b); + } + grpc_error_handle error = GRPC_ERROR_NONE; + static const char* err_desc = "pollset_set_add_fd"; + for (;;) { + if (a == b) { + // pollset ancestors are the same: nothing to do + return; + } + if (a > b) { + std::swap(a, b); + } + gpr_mu* a_mu = &a->mu; + gpr_mu* b_mu = &b->mu; + gpr_mu_lock(a_mu); + gpr_mu_lock(b_mu); + if (a->parent != nullptr) { + a = a->parent; + } else if (b->parent != nullptr) { + b = b->parent; + } else { + break; // exit loop, both pollsets locked + } + gpr_mu_unlock(a_mu); + gpr_mu_unlock(b_mu); + } + // try to do the least copying possible + // TODO(sreek): 
there's probably a better heuristic here + const size_t a_size = a->fd_count + a->pollset_count; + const size_t b_size = b->fd_count + b->pollset_count; + if (b_size > a_size) { + std::swap(a, b); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "PSS: parent %p to %p", b, a); + } + a->refs.Ref(); + b->parent = a; + if (a->fd_capacity < a->fd_count + b->fd_count) { + a->fd_capacity = std::max(2 * a->fd_capacity, a->fd_count + b->fd_count); + a->fds = static_cast( + gpr_realloc(a->fds, a->fd_capacity * sizeof(*a->fds))); + } + size_t initial_a_fd_count = a->fd_count; + a->fd_count = 0; + append_error( + &error, + add_fds_to_pollsets(a->fds, initial_a_fd_count, b->pollsets, + b->pollset_count, "merge_a2b", a->fds, &a->fd_count), + err_desc); + append_error( + &error, + add_fds_to_pollsets(b->fds, b->fd_count, a->pollsets, a->pollset_count, + "merge_b2a", a->fds, &a->fd_count), + err_desc); + if (a->pollset_capacity < a->pollset_count + b->pollset_count) { + a->pollset_capacity = + std::max(2 * a->pollset_capacity, a->pollset_count + b->pollset_count); + a->pollsets = static_cast( + gpr_realloc(a->pollsets, a->pollset_capacity * sizeof(*a->pollsets))); + } + if (b->pollset_count > 0) { + memcpy(a->pollsets + a->pollset_count, b->pollsets, + b->pollset_count * sizeof(*b->pollsets)); + } + a->pollset_count += b->pollset_count; + gpr_free(b->fds); + gpr_free(b->pollsets); + b->fds = nullptr; + b->pollsets = nullptr; + b->fd_count = b->fd_capacity = b->pollset_count = b->pollset_capacity = 0; + gpr_mu_unlock(&a->mu); + gpr_mu_unlock(&b->mu); +} + +static void pollset_set_del_pollset_set(grpc_pollset_set* /*bag*/, + grpc_pollset_set* /*item*/) {} + +/******************************************************************************* + * Event engine binding + */ + +static bool is_any_background_poller_thread(void) { return false; } + +static void shutdown_background_closure(void) {} + +static bool add_closure_to_background_poller(grpc_closure* /*closure*/, + grpc_error_handle /*error*/) { + return false; +} + +static void shutdown_engine(void) { + fd_global_shutdown(); + pollset_global_shutdown(); +} + +static const grpc_event_engine_vtable vtable = { + sizeof(grpc_pollset), + true, + false, + + fd_create, + fd_wrapped_fd, + fd_orphan, + fd_shutdown, + fd_notify_on_read, + fd_notify_on_write, + fd_notify_on_error, + fd_become_readable, + fd_become_writable, + fd_has_errors, + fd_is_shutdown, + + pollset_init, + pollset_shutdown, + pollset_destroy, + pollset_work, + pollset_kick, + pollset_add_fd, + + pollset_set_create, + pollset_set_unref, // destroy ==> unref 1 public ref + pollset_set_add_pollset, + pollset_set_del_pollset, + pollset_set_add_pollset_set, + pollset_set_del_pollset_set, + pollset_set_add_fd, + pollset_set_del_fd, + + is_any_background_poller_thread, + shutdown_background_closure, + shutdown_engine, + add_closure_to_background_poller, +}; + +const grpc_event_engine_vtable* grpc_init_epollex_linux( + bool /*explicitly_requested*/) { + if (!grpc_has_wakeup_fd()) { + gpr_log(GPR_ERROR, "Skipping epollex because of no wakeup fd."); + return nullptr; + } + + if (!grpc_is_epollexclusive_available()) { + gpr_log(GPR_INFO, "Skipping epollex because it is not supported."); + return nullptr; + } + + fd_global_init(); + + if (!GRPC_LOG_IF_ERROR("pollset_global_init", pollset_global_init())) { + pollset_global_shutdown(); + fd_global_shutdown(); + return nullptr; + } + + return &vtable; +} + +#else /* defined(GRPC_LINUX_EPOLL_CREATE1) */ +#if 
defined(GRPC_POSIX_SOCKET_EV_EPOLLEX) +#include "src/core/lib/iomgr/ev_epollex_linux.h" +/* If GRPC_LINUX_EPOLL_CREATE1 is not defined, it means + epoll_create1 is not available. Return NULL */ +const grpc_event_engine_vtable* grpc_init_epollex_linux( + bool /*explicitly_requested*/) { + return nullptr; +} +#endif /* defined(GRPC_POSIX_SOCKET_EV_EPOLLEX) */ + +#endif /* !defined(GRPC_LINUX_EPOLL_CREATE1) */ diff --git a/src/core/lib/iomgr/ev_poll_posix.cc b/src/core/lib/iomgr/ev_poll_posix.cc new file mode 100644 index 00000000..03dec3f5 --- /dev/null +++ b/src/core/lib/iomgr/ev_poll_posix.cc @@ -0,0 +1,1430 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_EV_POLL + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/murmur_hash.h" +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/block_annotate.h" +#include "src/core/lib/iomgr/ev_poll_posix.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" +#include "src/core/lib/profiling/timers.h" + +#define GRPC_POLLSET_KICK_BROADCAST ((grpc_pollset_worker*)1) + +/******************************************************************************* + * FD declarations + */ +typedef struct grpc_fd_watcher { + struct grpc_fd_watcher* next; + struct grpc_fd_watcher* prev; + grpc_pollset* pollset; + grpc_pollset_worker* worker; + grpc_fd* fd; +} grpc_fd_watcher; + +typedef struct grpc_cached_wakeup_fd grpc_cached_wakeup_fd; + +/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ +struct grpc_fork_fd_list { + /* Only one of fd or cached_wakeup_fd will be set. The unused field will be + set to nullptr. */ + grpc_fd* fd; + grpc_cached_wakeup_fd* cached_wakeup_fd; + + grpc_fork_fd_list* next; + grpc_fork_fd_list* prev; +}; + +struct grpc_fd { + int fd; + /* refst format: + bit0: 1=active/0=orphaned + bit1-n: refcount + meaning that mostly we ref by two to avoid altering the orphaned bit, + and just unref by 1 when we're ready to flag the object as orphaned */ + gpr_atm refst; + + gpr_mu mu; + int shutdown; + int closed; + int released; + gpr_atm pollhup; + grpc_error_handle shutdown_error; + + /* The watcher list. + + The following watcher related fields are protected by watcher_mu. + + An fd_watcher is an ephemeral object created when an fd wants to + begin polling, and destroyed after the poll. + + It denotes the fd's interest in whether to read poll or write poll + or both or neither on this fd. + + If a watcher is asked to poll for reads or writes, the read_watcher + or write_watcher fields are set respectively. A watcher may be asked + to poll for both, in which case both fields will be set. 
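+   (A watcher is installed by fd_begin_poll() just before the poll() call in
+   pollset_work() and detached again by fd_end_poll() once poll() returns.)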
+ + read_watcher and write_watcher may be NULL if no watcher has been + asked to poll for reads or writes. + + If an fd_watcher is not asked to poll for reads or writes, it's added + to a linked list of inactive watchers, rooted at inactive_watcher_root. + If at a later time there becomes need of a poller to poll, one of + the inactive pollers may be kicked out of their poll loops to take + that responsibility. */ + grpc_fd_watcher inactive_watcher_root; + grpc_fd_watcher* read_watcher; + grpc_fd_watcher* write_watcher; + + grpc_closure* read_closure; + grpc_closure* write_closure; + + grpc_closure* on_done_closure; + + grpc_iomgr_object iomgr_object; + + /* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ + grpc_fork_fd_list* fork_fd_list; +}; + +/* True when GRPC_ENABLE_FORK_SUPPORT=1. */ +static bool track_fds_for_fork = false; + +/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ +static grpc_fork_fd_list* fork_fd_list_head = nullptr; +static gpr_mu fork_fd_list_mu; + +/* Begin polling on an fd. + Registers that the given pollset is interested in this fd - so that if read + or writability interest changes, the pollset can be kicked to pick up that + new interest. + Return value is: + (fd_needs_read? read_mask : 0) | (fd_needs_write? write_mask : 0) + i.e. a combination of read_mask and write_mask determined by the fd's current + interest in said events. + Polling strategies that do not need to alter their behavior depending on the + fd's current interest (such as epoll) do not need to call this function. + MUST NOT be called with a pollset lock taken */ +static uint32_t fd_begin_poll(grpc_fd* fd, grpc_pollset* pollset, + grpc_pollset_worker* worker, uint32_t read_mask, + uint32_t write_mask, grpc_fd_watcher* watcher); +/* Complete polling previously started with fd_begin_poll + MUST NOT be called with a pollset lock taken + if got_read or got_write are 1, also does the become_{readable,writable} as + appropriate. 
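+   In pollset_work() below this is paired with fd_begin_poll() around the
+   poll() syscall: on a normal wakeup got_read/got_write come from revents
+   masked with POLLIN|POLLHUP|POLLERR and POLLOUT|POLLHUP|POLLERR, a timeout
+   passes 0/0, and a failed poll() passes 1/1 so that a bad fd is noticed on
+   the next pollset_work() iteration.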
*/ +static void fd_end_poll(grpc_fd_watcher* watcher, int got_read, int got_write); + +/* Return 1 if this fd is orphaned, 0 otherwise */ +static bool fd_is_orphaned(grpc_fd* fd); + +#ifndef NDEBUG +static void fd_ref(grpc_fd* fd, const char* reason, const char* file, int line); +static void fd_unref(grpc_fd* fd, const char* reason, const char* file, + int line); +#define GRPC_FD_REF(fd, reason) fd_ref(fd, reason, __FILE__, __LINE__) +#define GRPC_FD_UNREF(fd, reason) fd_unref(fd, reason, __FILE__, __LINE__) +#else +static void fd_ref(grpc_fd* fd); +static void fd_unref(grpc_fd* fd); +#define GRPC_FD_REF(fd, reason) fd_ref(fd) +#define GRPC_FD_UNREF(fd, reason) fd_unref(fd) +#endif + +#define CLOSURE_NOT_READY ((grpc_closure*)0) +#define CLOSURE_READY ((grpc_closure*)1) + +/******************************************************************************* + * pollset declarations + */ + +typedef struct grpc_cached_wakeup_fd { + grpc_wakeup_fd fd; + struct grpc_cached_wakeup_fd* next; + + /* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */ + grpc_fork_fd_list* fork_fd_list; +} grpc_cached_wakeup_fd; + +struct grpc_pollset_worker { + grpc_cached_wakeup_fd* wakeup_fd; + int reevaluate_polling_on_wakeup; + int kicked_specifically; + struct grpc_pollset_worker* next; + struct grpc_pollset_worker* prev; +}; + +struct grpc_pollset { + gpr_mu mu; + grpc_pollset_worker root_worker; + int shutting_down; + int called_shutdown; + int kicked_without_pollers; + grpc_closure* shutdown_done; + int pollset_set_count; + /* all polled fds */ + size_t fd_count; + size_t fd_capacity; + grpc_fd** fds; + /* Local cache of eventfds for workers */ + grpc_cached_wakeup_fd* local_wakeup_cache; +}; + +/* Add an fd to a pollset */ +static void pollset_add_fd(grpc_pollset* pollset, struct grpc_fd* fd); + +static void pollset_set_add_fd(grpc_pollset_set* pollset_set, grpc_fd* fd); + +/* Convert a timespec to milliseconds: + - very small or negative poll times are clamped to zero to do a + non-blocking poll (which becomes spin polling) + - other small values are rounded up to one millisecond + - longer than a millisecond polls are rounded up to the next nearest + millisecond to avoid spinning + - infinite timeouts are converted to -1 */ +static int poll_deadline_to_millis_timeout(grpc_millis deadline); + +/* Allow kick to wakeup the currently polling worker */ +#define GRPC_POLLSET_CAN_KICK_SELF 1 +/* Force the wakee to repoll when awoken */ +#define GRPC_POLLSET_REEVALUATE_POLLING_ON_WAKEUP 2 +/* As per pollset_kick, with an extended set of flags (defined above) + -- mostly for fd_posix's use. */ +static grpc_error_handle pollset_kick_ext(grpc_pollset* p, + grpc_pollset_worker* specific_worker, + uint32_t flags) GRPC_MUST_USE_RESULT; + +/* Return 1 if the pollset has active threads in pollset_work (pollset must + * be locked) */ +static bool pollset_has_workers(grpc_pollset* pollset); + +/******************************************************************************* + * pollset_set definitions + */ + +struct grpc_pollset_set { + gpr_mu mu; + + size_t pollset_count; + size_t pollset_capacity; + grpc_pollset** pollsets; + + size_t pollset_set_count; + size_t pollset_set_capacity; + struct grpc_pollset_set** pollset_sets; + + size_t fd_count; + size_t fd_capacity; + grpc_fd** fds; +}; + +/******************************************************************************* + * functions to track opened fds. No-ops unless track_fds_for_fork is true. 
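+ * Every grpc_fd and every cached wakeup fd gets a grpc_fork_fd_list node
+ * linked into the global list rooted at fork_fd_list_head (guarded by
+ * fork_fd_list_mu) and unlinked again when the owning object is destroyed.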
+ */ + +static void fork_fd_list_remove_node(grpc_fork_fd_list* node) { + if (track_fds_for_fork) { + gpr_mu_lock(&fork_fd_list_mu); + if (fork_fd_list_head == node) { + fork_fd_list_head = node->next; + } + if (node->prev != nullptr) { + node->prev->next = node->next; + } + if (node->next != nullptr) { + node->next->prev = node->prev; + } + gpr_free(node); + gpr_mu_unlock(&fork_fd_list_mu); + } +} + +static void fork_fd_list_add_node(grpc_fork_fd_list* node) { + gpr_mu_lock(&fork_fd_list_mu); + node->next = fork_fd_list_head; + node->prev = nullptr; + if (fork_fd_list_head != nullptr) { + fork_fd_list_head->prev = node; + } + fork_fd_list_head = node; + gpr_mu_unlock(&fork_fd_list_mu); +} + +static void fork_fd_list_add_grpc_fd(grpc_fd* fd) { + if (track_fds_for_fork) { + fd->fork_fd_list = + static_cast(gpr_malloc(sizeof(grpc_fork_fd_list))); + fd->fork_fd_list->fd = fd; + fd->fork_fd_list->cached_wakeup_fd = nullptr; + fork_fd_list_add_node(fd->fork_fd_list); + } +} + +static void fork_fd_list_add_wakeup_fd(grpc_cached_wakeup_fd* fd) { + if (track_fds_for_fork) { + fd->fork_fd_list = + static_cast(gpr_malloc(sizeof(grpc_fork_fd_list))); + fd->fork_fd_list->cached_wakeup_fd = fd; + fd->fork_fd_list->fd = nullptr; + fork_fd_list_add_node(fd->fork_fd_list); + } +} + +/******************************************************************************* + * fd_posix.c + */ + +#ifndef NDEBUG +#define REF_BY(fd, n, reason) ref_by(fd, n, reason, __FILE__, __LINE__) +#define UNREF_BY(fd, n, reason) unref_by(fd, n, reason, __FILE__, __LINE__) +static void ref_by(grpc_fd* fd, int n, const char* reason, const char* file, + int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_fd_refcount)) { + gpr_log(GPR_DEBUG, + "FD %d %p ref %d %" PRIdPTR " -> %" PRIdPTR " [%s; %s:%d]", + fd->fd, fd, n, gpr_atm_no_barrier_load(&fd->refst), + gpr_atm_no_barrier_load(&fd->refst) + n, reason, file, line); + } +#else +#define REF_BY(fd, n, reason) \ + do { \ + ref_by(fd, n); \ + (void)(reason); \ + } while (0) +#define UNREF_BY(fd, n, reason) \ + do { \ + unref_by(fd, n); \ + (void)(reason); \ + } while (0) +static void ref_by(grpc_fd* fd, int n) { +#endif + GPR_ASSERT(gpr_atm_no_barrier_fetch_add(&fd->refst, n) > 0); +} + +#ifndef NDEBUG +static void unref_by(grpc_fd* fd, int n, const char* reason, const char* file, + int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_fd_refcount)) { + gpr_log(GPR_DEBUG, + "FD %d %p unref %d %" PRIdPTR " -> %" PRIdPTR " [%s; %s:%d]", + fd->fd, fd, n, gpr_atm_no_barrier_load(&fd->refst), + gpr_atm_no_barrier_load(&fd->refst) - n, reason, file, line); + } +#else +static void unref_by(grpc_fd* fd, int n) { +#endif + gpr_atm old = gpr_atm_full_fetch_add(&fd->refst, -n); + if (old == n) { + gpr_mu_destroy(&fd->mu); + grpc_iomgr_unregister_object(&fd->iomgr_object); + fork_fd_list_remove_node(fd->fork_fd_list); + if (fd->shutdown) { + GRPC_ERROR_UNREF(fd->shutdown_error); + } +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + fd->shutdown_error.~Status(); +#endif + gpr_free(fd); + } else { + GPR_ASSERT(old > n); + } +} + +static grpc_fd* fd_create(int fd, const char* name, bool track_err) { + // Avoid unused-parameter warning for debug-only parameter + (void)track_err; + GPR_DEBUG_ASSERT(track_err == false); + grpc_fd* r = static_cast(gpr_malloc(sizeof(*r))); + gpr_mu_init(&r->mu); + gpr_atm_rel_store(&r->refst, 1); + r->shutdown = 0; +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + new (&r->shutdown_error) absl::Status(); +#endif + r->read_closure = CLOSURE_NOT_READY; + r->write_closure = CLOSURE_NOT_READY; + 
r->fd = fd; + r->inactive_watcher_root.next = r->inactive_watcher_root.prev = + &r->inactive_watcher_root; + r->read_watcher = r->write_watcher = nullptr; + r->on_done_closure = nullptr; + r->closed = 0; + r->released = 0; + gpr_atm_no_barrier_store(&r->pollhup, 0); + + std::string name2 = absl::StrCat(name, " fd=", fd); + grpc_iomgr_register_object(&r->iomgr_object, name2.c_str()); + fork_fd_list_add_grpc_fd(r); + return r; +} + +static bool fd_is_orphaned(grpc_fd* fd) { + return (gpr_atm_acq_load(&fd->refst) & 1) == 0; +} + +static grpc_error_handle pollset_kick_locked(grpc_fd_watcher* watcher) { + gpr_mu_lock(&watcher->pollset->mu); + GPR_ASSERT(watcher->worker); + grpc_error_handle err = + pollset_kick_ext(watcher->pollset, watcher->worker, + GRPC_POLLSET_REEVALUATE_POLLING_ON_WAKEUP); + gpr_mu_unlock(&watcher->pollset->mu); + return err; +} + +static void maybe_wake_one_watcher_locked(grpc_fd* fd) { + if (fd->inactive_watcher_root.next != &fd->inactive_watcher_root) { + (void)pollset_kick_locked(fd->inactive_watcher_root.next); + } else if (fd->read_watcher) { + (void)pollset_kick_locked(fd->read_watcher); + } else if (fd->write_watcher) { + (void)pollset_kick_locked(fd->write_watcher); + } +} + +static void wake_all_watchers_locked(grpc_fd* fd) { + grpc_fd_watcher* watcher; + for (watcher = fd->inactive_watcher_root.next; + watcher != &fd->inactive_watcher_root; watcher = watcher->next) { + (void)pollset_kick_locked(watcher); + } + if (fd->read_watcher) { + (void)pollset_kick_locked(fd->read_watcher); + } + if (fd->write_watcher && fd->write_watcher != fd->read_watcher) { + (void)pollset_kick_locked(fd->write_watcher); + } +} + +static int has_watchers(grpc_fd* fd) { + return fd->read_watcher != nullptr || fd->write_watcher != nullptr || + fd->inactive_watcher_root.next != &fd->inactive_watcher_root; +} + +static void close_fd_locked(grpc_fd* fd) { + fd->closed = 1; + if (!fd->released) { + close(fd->fd); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, fd->on_done_closure, GRPC_ERROR_NONE); +} + +static int fd_wrapped_fd(grpc_fd* fd) { + if (fd->released || fd->closed) { + return -1; + } else { + return fd->fd; + } +} + +static void fd_orphan(grpc_fd* fd, grpc_closure* on_done, int* release_fd, + const char* reason) { + fd->on_done_closure = on_done; + fd->released = release_fd != nullptr; + if (release_fd != nullptr) { + *release_fd = fd->fd; + fd->released = true; + } + gpr_mu_lock(&fd->mu); + REF_BY(fd, 1, reason); /* remove active status, but keep referenced */ + if (!has_watchers(fd)) { + close_fd_locked(fd); + } else { + wake_all_watchers_locked(fd); + } + gpr_mu_unlock(&fd->mu); + UNREF_BY(fd, 2, reason); /* drop the reference */ +} + +/* increment refcount by two to avoid changing the orphan bit */ +#ifndef NDEBUG +static void fd_ref(grpc_fd* fd, const char* reason, const char* file, + int line) { + ref_by(fd, 2, reason, file, line); +} + +static void fd_unref(grpc_fd* fd, const char* reason, const char* file, + int line) { + unref_by(fd, 2, reason, file, line); +} +#else +static void fd_ref(grpc_fd* fd) { ref_by(fd, 2); } + +static void fd_unref(grpc_fd* fd) { unref_by(fd, 2); } +#endif + +static grpc_error_handle fd_shutdown_error(grpc_fd* fd) { + if (!fd->shutdown) { + return GRPC_ERROR_NONE; + } else { + return grpc_error_set_int(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "FD shutdown", &fd->shutdown_error, 1), + GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE); + } +} + +static void notify_on_locked(grpc_fd* fd, grpc_closure** st, + grpc_closure* closure) { 
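+  /* The slot pointed to by `st` (fd->read_closure or fd->write_closure) is a
+     three-state cell:
+       CLOSURE_NOT_READY ((grpc_closure*)0)  nothing pending, event not seen
+       CLOSURE_READY     ((grpc_closure*)1)  event fired with nobody waiting
+       any other pointer                     a closure parked by a notify_on
+     Sketch of the common flow, using functions defined in this file:
+       fd_notify_on_read(fd, on_read);   parks on_read (NOT_READY -> closure)
+       fd_end_poll(watcher, 1, 0);       POLLIN seen: set_ready_locked() runs
+                                         on_read and resets to NOT_READY
+     If the event fires first the slot is READY and the next notify_on call
+     schedules its closure immediately. */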
+ if (fd->shutdown || gpr_atm_no_barrier_load(&fd->pollhup)) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, closure, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("FD shutdown"), + GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE)); + } else if (*st == CLOSURE_NOT_READY) { + /* not ready ==> switch to a waiting state by setting the closure */ + *st = closure; + } else if (*st == CLOSURE_READY) { + /* already ready ==> queue the closure to run immediately */ + *st = CLOSURE_NOT_READY; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, fd_shutdown_error(fd)); + maybe_wake_one_watcher_locked(fd); + } else { + /* upcallptr was set to a different closure. This is an error! */ + gpr_log(GPR_ERROR, + "User called a notify_on function with a previous callback still " + "pending"); + abort(); + } +} + +/* returns 1 if state becomes not ready */ +static int set_ready_locked(grpc_fd* fd, grpc_closure** st) { + if (*st == CLOSURE_READY) { + /* duplicate ready ==> ignore */ + return 0; + } else if (*st == CLOSURE_NOT_READY) { + /* not ready, and not waiting ==> flag ready */ + *st = CLOSURE_READY; + return 0; + } else { + /* waiting ==> queue closure */ + grpc_core::ExecCtx::Run(DEBUG_LOCATION, *st, fd_shutdown_error(fd)); + *st = CLOSURE_NOT_READY; + return 1; + } +} + +static void fd_shutdown(grpc_fd* fd, grpc_error_handle why) { + gpr_mu_lock(&fd->mu); + /* only shutdown once */ + if (!fd->shutdown) { + fd->shutdown = 1; + fd->shutdown_error = why; + /* signal read/write closed to OS so that future operations fail */ + shutdown(fd->fd, SHUT_RDWR); + set_ready_locked(fd, &fd->read_closure); + set_ready_locked(fd, &fd->write_closure); + } else { + GRPC_ERROR_UNREF(why); + } + gpr_mu_unlock(&fd->mu); +} + +static bool fd_is_shutdown(grpc_fd* fd) { + gpr_mu_lock(&fd->mu); + bool r = fd->shutdown; + gpr_mu_unlock(&fd->mu); + return r; +} + +static void fd_notify_on_read(grpc_fd* fd, grpc_closure* closure) { + gpr_mu_lock(&fd->mu); + notify_on_locked(fd, &fd->read_closure, closure); + gpr_mu_unlock(&fd->mu); +} + +static void fd_notify_on_write(grpc_fd* fd, grpc_closure* closure) { + gpr_mu_lock(&fd->mu); + notify_on_locked(fd, &fd->write_closure, closure); + gpr_mu_unlock(&fd->mu); +} + +static void fd_notify_on_error(grpc_fd* /*fd*/, grpc_closure* closure) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_ERROR, "Polling engine does not support tracking errors."); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_CANCELLED); +} + +static void fd_set_readable(grpc_fd* fd) { + gpr_mu_lock(&fd->mu); + set_ready_locked(fd, &fd->read_closure); + gpr_mu_unlock(&fd->mu); +} + +static void fd_set_writable(grpc_fd* fd) { + gpr_mu_lock(&fd->mu); + set_ready_locked(fd, &fd->write_closure); + gpr_mu_unlock(&fd->mu); +} + +static void fd_set_error(grpc_fd* /*fd*/) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_ERROR, "Polling engine does not support tracking errors."); + } +} + +static uint32_t fd_begin_poll(grpc_fd* fd, grpc_pollset* pollset, + grpc_pollset_worker* worker, uint32_t read_mask, + uint32_t write_mask, grpc_fd_watcher* watcher) { + uint32_t mask = 0; + grpc_closure* cur; + int requested; + /* keep track of pollers that have requested our events, in case they change + */ + GRPC_FD_REF(fd, "poll"); + + gpr_mu_lock(&fd->mu); + + /* if we are shutdown, then don't add to the watcher set */ + if (fd->shutdown) { + watcher->fd = nullptr; + watcher->pollset = nullptr; + watcher->worker = nullptr; + gpr_mu_unlock(&fd->mu); + 
GRPC_FD_UNREF(fd, "poll"); + return 0; + } + + /* if there is nobody polling for read, but we need to, then start doing so */ + cur = fd->read_closure; + requested = cur != CLOSURE_READY; + if (read_mask && fd->read_watcher == nullptr && requested) { + fd->read_watcher = watcher; + mask |= read_mask; + } + /* if there is nobody polling for write, but we need to, then start doing so + */ + cur = fd->write_closure; + requested = cur != CLOSURE_READY; + if (write_mask && fd->write_watcher == nullptr && requested) { + fd->write_watcher = watcher; + mask |= write_mask; + } + /* if not polling, remember this watcher in case we need someone to later */ + if (mask == 0 && worker != nullptr) { + watcher->next = &fd->inactive_watcher_root; + watcher->prev = watcher->next->prev; + watcher->next->prev = watcher->prev->next = watcher; + } + watcher->pollset = pollset; + watcher->worker = worker; + watcher->fd = fd; + gpr_mu_unlock(&fd->mu); + + return mask; +} + +static void fd_end_poll(grpc_fd_watcher* watcher, int got_read, int got_write) { + int was_polling = 0; + int kick = 0; + grpc_fd* fd = watcher->fd; + + if (fd == nullptr) { + return; + } + + gpr_mu_lock(&fd->mu); + + if (watcher == fd->read_watcher) { + /* remove read watcher, kick if we still need a read */ + was_polling = 1; + if (!got_read) { + kick = 1; + } + fd->read_watcher = nullptr; + } + if (watcher == fd->write_watcher) { + /* remove write watcher, kick if we still need a write */ + was_polling = 1; + if (!got_write) { + kick = 1; + } + fd->write_watcher = nullptr; + } + if (!was_polling && watcher->worker != nullptr) { + /* remove from inactive list */ + watcher->next->prev = watcher->prev; + watcher->prev->next = watcher->next; + } + if (got_read) { + if (set_ready_locked(fd, &fd->read_closure)) { + kick = 1; + } + } + if (got_write) { + if (set_ready_locked(fd, &fd->write_closure)) { + kick = 1; + } + } + if (kick) { + maybe_wake_one_watcher_locked(fd); + } + if (fd_is_orphaned(fd) && !has_watchers(fd) && !fd->closed) { + close_fd_locked(fd); + } + gpr_mu_unlock(&fd->mu); + + GRPC_FD_UNREF(fd, "poll"); +} + +/******************************************************************************* + * pollset_posix.c + */ + +static GPR_THREAD_LOCAL(grpc_pollset*) g_current_thread_poller; +static GPR_THREAD_LOCAL(grpc_pollset_worker*) g_current_thread_worker; + +static void remove_worker(grpc_pollset* /*p*/, grpc_pollset_worker* worker) { + worker->prev->next = worker->next; + worker->next->prev = worker->prev; +} + +static bool pollset_has_workers(grpc_pollset* p) { + return p->root_worker.next != &p->root_worker; +} + +static bool pollset_in_pollset_sets(grpc_pollset* p) { + return p->pollset_set_count; +} + +static bool pollset_has_observers(grpc_pollset* p) { + return pollset_has_workers(p) || pollset_in_pollset_sets(p); +} + +static grpc_pollset_worker* pop_front_worker(grpc_pollset* p) { + if (pollset_has_workers(p)) { + grpc_pollset_worker* w = p->root_worker.next; + remove_worker(p, w); + return w; + } else { + return nullptr; + } +} + +static void push_back_worker(grpc_pollset* p, grpc_pollset_worker* worker) { + worker->next = &p->root_worker; + worker->prev = worker->next->prev; + worker->prev->next = worker->next->prev = worker; +} + +static void push_front_worker(grpc_pollset* p, grpc_pollset_worker* worker) { + worker->prev = &p->root_worker; + worker->next = worker->prev->next; + worker->prev->next = worker->next->prev = worker; +} + +static void kick_append_error(grpc_error_handle* composite, + grpc_error_handle error) { + 
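+  /* The first failure turns *composite into a "Kick Failure" error; every
+     later failure is attached as a child, so a single grpc_error_handle can
+     report all wakeup-fd errors from one kick. work_combine_error() further
+     down applies the same pattern to poll() failures. */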
if (error == GRPC_ERROR_NONE) return; + if (*composite == GRPC_ERROR_NONE) { + *composite = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Kick Failure"); + } + *composite = grpc_error_add_child(*composite, error); +} + +static grpc_error_handle pollset_kick_ext(grpc_pollset* p, + grpc_pollset_worker* specific_worker, + uint32_t flags) { + GPR_TIMER_SCOPE("pollset_kick_ext", 0); + grpc_error_handle error = GRPC_ERROR_NONE; + GRPC_STATS_INC_POLLSET_KICK(); + + /* pollset->mu already held */ + if (specific_worker != nullptr) { + if (specific_worker == GRPC_POLLSET_KICK_BROADCAST) { + GPR_TIMER_SCOPE("pollset_kick_ext.broadcast", 0); + GPR_ASSERT((flags & GRPC_POLLSET_REEVALUATE_POLLING_ON_WAKEUP) == 0); + for (specific_worker = p->root_worker.next; + specific_worker != &p->root_worker; + specific_worker = specific_worker->next) { + kick_append_error( + &error, grpc_wakeup_fd_wakeup(&specific_worker->wakeup_fd->fd)); + } + p->kicked_without_pollers = true; + } else if (g_current_thread_worker != specific_worker) { + GPR_TIMER_MARK("different_thread_worker", 0); + if ((flags & GRPC_POLLSET_REEVALUATE_POLLING_ON_WAKEUP) != 0) { + specific_worker->reevaluate_polling_on_wakeup = true; + } + specific_worker->kicked_specifically = true; + kick_append_error(&error, + grpc_wakeup_fd_wakeup(&specific_worker->wakeup_fd->fd)); + } else if ((flags & GRPC_POLLSET_CAN_KICK_SELF) != 0) { + GPR_TIMER_MARK("kick_yoself", 0); + if ((flags & GRPC_POLLSET_REEVALUATE_POLLING_ON_WAKEUP) != 0) { + specific_worker->reevaluate_polling_on_wakeup = true; + } + specific_worker->kicked_specifically = true; + kick_append_error(&error, + grpc_wakeup_fd_wakeup(&specific_worker->wakeup_fd->fd)); + } + } else if (g_current_thread_poller != p) { + GPR_ASSERT((flags & GRPC_POLLSET_REEVALUATE_POLLING_ON_WAKEUP) == 0); + GPR_TIMER_MARK("kick_anonymous", 0); + specific_worker = pop_front_worker(p); + if (specific_worker != nullptr) { + if (g_current_thread_worker == specific_worker) { + GPR_TIMER_MARK("kick_anonymous_not_self", 0); + push_back_worker(p, specific_worker); + specific_worker = pop_front_worker(p); + if ((flags & GRPC_POLLSET_CAN_KICK_SELF) == 0 && + g_current_thread_worker == specific_worker) { + push_back_worker(p, specific_worker); + specific_worker = nullptr; + } + } + if (specific_worker != nullptr) { + GPR_TIMER_MARK("finally_kick", 0); + push_back_worker(p, specific_worker); + kick_append_error( + &error, grpc_wakeup_fd_wakeup(&specific_worker->wakeup_fd->fd)); + } + } else { + GPR_TIMER_MARK("kicked_no_pollers", 0); + p->kicked_without_pollers = true; + } + } + + GRPC_LOG_IF_ERROR("pollset_kick_ext", GRPC_ERROR_REF(error)); + return error; +} + +static grpc_error_handle pollset_kick(grpc_pollset* p, + grpc_pollset_worker* specific_worker) { + return pollset_kick_ext(p, specific_worker, 0); +} + +/* global state management */ + +static grpc_error_handle pollset_global_init(void) { return GRPC_ERROR_NONE; } + +static void pollset_global_shutdown(void) {} + +/* main interface */ + +static void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + gpr_mu_init(&pollset->mu); + *mu = &pollset->mu; + pollset->root_worker.next = pollset->root_worker.prev = &pollset->root_worker; + pollset->shutting_down = 0; + pollset->called_shutdown = 0; + pollset->kicked_without_pollers = 0; + pollset->local_wakeup_cache = nullptr; + pollset->kicked_without_pollers = 0; + pollset->fd_count = 0; + pollset->fd_capacity = 0; + pollset->fds = nullptr; + pollset->pollset_set_count = 0; +} + +static void pollset_destroy(grpc_pollset* pollset) { + 
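+  /* Runs only after shutdown has completed (hence the assert below): each
+     cached wakeup fd is destroyed and unlinked from the fork tracking list,
+     then the fd array and the pollset mutex are released. */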
GPR_ASSERT(!pollset_has_workers(pollset)); + while (pollset->local_wakeup_cache) { + grpc_cached_wakeup_fd* next = pollset->local_wakeup_cache->next; + fork_fd_list_remove_node(pollset->local_wakeup_cache->fork_fd_list); + grpc_wakeup_fd_destroy(&pollset->local_wakeup_cache->fd); + gpr_free(pollset->local_wakeup_cache); + pollset->local_wakeup_cache = next; + } + gpr_free(pollset->fds); + gpr_mu_destroy(&pollset->mu); +} + +static void pollset_add_fd(grpc_pollset* pollset, grpc_fd* fd) { + gpr_mu_lock(&pollset->mu); + size_t i; + /* TODO(ctiller): this is O(num_fds^2); maybe switch to a hash set here */ + for (i = 0; i < pollset->fd_count; i++) { + if (pollset->fds[i] == fd) goto exit; + } + if (pollset->fd_count == pollset->fd_capacity) { + pollset->fd_capacity = + std::max(pollset->fd_capacity + 8, pollset->fd_count * 3 / 2); + pollset->fds = static_cast( + gpr_realloc(pollset->fds, sizeof(grpc_fd*) * pollset->fd_capacity)); + } + pollset->fds[pollset->fd_count++] = fd; + GRPC_FD_REF(fd, "multipoller"); + (void)pollset_kick(pollset, nullptr); +exit: + gpr_mu_unlock(&pollset->mu); +} + +static void finish_shutdown(grpc_pollset* pollset) { + size_t i; + for (i = 0; i < pollset->fd_count; i++) { + GRPC_FD_UNREF(pollset->fds[i], "multipoller"); + } + pollset->fd_count = 0; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, pollset->shutdown_done, + GRPC_ERROR_NONE); +} + +static void work_combine_error(grpc_error_handle* composite, + grpc_error_handle error) { + if (error == GRPC_ERROR_NONE) return; + if (*composite == GRPC_ERROR_NONE) { + *composite = GRPC_ERROR_CREATE_FROM_STATIC_STRING("pollset_work"); + } + *composite = grpc_error_add_child(*composite, error); +} + +static grpc_error_handle pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** worker_hdl, + grpc_millis deadline) { + GPR_TIMER_SCOPE("pollset_work", 0); + grpc_pollset_worker worker; + if (worker_hdl) *worker_hdl = &worker; + grpc_error_handle error = GRPC_ERROR_NONE; + + /* Avoid malloc for small number of elements. 
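+     When pollset->fd_count + 2 entries fit in inline_elements, both the
+     pollfd array and the watcher array live on the stack; otherwise one
+     heap block is carved into the two arrays and freed after poll().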
*/ + enum { inline_elements = 96 }; + struct pollfd pollfd_space[inline_elements]; + struct grpc_fd_watcher watcher_space[inline_elements]; + + /* pollset->mu already held */ + int added_worker = 0; + int locked = 1; + int queued_work = 0; + int keep_polling = 0; + /* this must happen before we (potentially) drop pollset->mu */ + worker.next = worker.prev = nullptr; + worker.reevaluate_polling_on_wakeup = 0; + if (pollset->local_wakeup_cache != nullptr) { + worker.wakeup_fd = pollset->local_wakeup_cache; + pollset->local_wakeup_cache = worker.wakeup_fd->next; + } else { + worker.wakeup_fd = static_cast( + gpr_malloc(sizeof(*worker.wakeup_fd))); + error = grpc_wakeup_fd_init(&worker.wakeup_fd->fd); + fork_fd_list_add_wakeup_fd(worker.wakeup_fd); + if (error != GRPC_ERROR_NONE) { + GRPC_LOG_IF_ERROR("pollset_work", GRPC_ERROR_REF(error)); + return error; + } + } + worker.kicked_specifically = 0; + /* If we're shutting down then we don't execute any extended work */ + if (pollset->shutting_down) { + GPR_TIMER_MARK("pollset_work.shutting_down", 0); + goto done; + } + /* Start polling, and keep doing so while we're being asked to + re-evaluate our pollers (this allows poll() based pollers to + ensure they don't miss wakeups) */ + keep_polling = 1; + g_current_thread_poller = pollset; + while (keep_polling) { + keep_polling = 0; + if (!pollset->kicked_without_pollers || + deadline <= grpc_core::ExecCtx::Get()->Now()) { + if (!added_worker) { + push_front_worker(pollset, &worker); + added_worker = 1; + g_current_thread_worker = &worker; + } + GPR_TIMER_SCOPE("maybe_work_and_unlock", 0); +#define POLLOUT_CHECK (POLLOUT | POLLHUP | POLLERR) +#define POLLIN_CHECK (POLLIN | POLLHUP | POLLERR) + + int timeout; + int r; + size_t i, fd_count; + nfds_t pfd_count; + grpc_fd_watcher* watchers; + struct pollfd* pfds; + + timeout = poll_deadline_to_millis_timeout(deadline); + + if (pollset->fd_count + 2 <= inline_elements) { + pfds = pollfd_space; + watchers = watcher_space; + } else { + /* Allocate one buffer to hold both pfds and watchers arrays */ + const size_t pfd_size = sizeof(*pfds) * (pollset->fd_count + 2); + const size_t watch_size = sizeof(*watchers) * (pollset->fd_count + 2); + void* buf = gpr_malloc(pfd_size + watch_size); + pfds = static_cast(buf); + watchers = static_cast( + static_cast((static_cast(buf) + pfd_size))); + } + + fd_count = 0; + pfd_count = 1; + pfds[0].fd = GRPC_WAKEUP_FD_GET_READ_FD(&worker.wakeup_fd->fd); + pfds[0].events = POLLIN; + pfds[0].revents = 0; + for (i = 0; i < pollset->fd_count; i++) { + if (fd_is_orphaned(pollset->fds[i]) || + gpr_atm_no_barrier_load(&pollset->fds[i]->pollhup) == 1) { + GRPC_FD_UNREF(pollset->fds[i], "multipoller"); + } else { + pollset->fds[fd_count++] = pollset->fds[i]; + watchers[pfd_count].fd = pollset->fds[i]; + GRPC_FD_REF(watchers[pfd_count].fd, "multipoller_start"); + pfds[pfd_count].fd = pollset->fds[i]->fd; + pfds[pfd_count].revents = 0; + pfd_count++; + } + } + pollset->fd_count = fd_count; + gpr_mu_unlock(&pollset->mu); + + for (i = 1; i < pfd_count; i++) { + grpc_fd* fd = watchers[i].fd; + pfds[i].events = static_cast( + fd_begin_poll(fd, pollset, &worker, POLLIN, POLLOUT, &watchers[i])); + GRPC_FD_UNREF(fd, "multipoller_start"); + } + + /* TODO(vpai): Consider first doing a 0 timeout poll here to avoid + even going into the blocking annotation if possible */ + GRPC_SCHEDULING_START_BLOCKING_REGION; + GRPC_STATS_INC_SYSCALL_POLL(); + r = grpc_poll_function(pfds, pfd_count, timeout); + GRPC_SCHEDULING_END_BLOCKING_REGION; + + if 
(GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "%p poll=%d", pollset, r); + } + + if (r < 0) { + if (errno != EINTR) { + work_combine_error(&error, GRPC_OS_ERROR(errno, "poll")); + } + + for (i = 1; i < pfd_count; i++) { + if (watchers[i].fd == nullptr) { + fd_end_poll(&watchers[i], 0, 0); + } else { + // Wake up all the file descriptors, if we have an invalid one + // we can identify it on the next pollset_work() + fd_end_poll(&watchers[i], 1, 1); + } + } + } else if (r == 0) { + for (i = 1; i < pfd_count; i++) { + fd_end_poll(&watchers[i], 0, 0); + } + } else { + if (pfds[0].revents & POLLIN_CHECK) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "%p: got_wakeup", pollset); + } + work_combine_error( + &error, grpc_wakeup_fd_consume_wakeup(&worker.wakeup_fd->fd)); + } + for (i = 1; i < pfd_count; i++) { + if (watchers[i].fd == nullptr) { + fd_end_poll(&watchers[i], 0, 0); + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_INFO, "%p got_event: %d r:%d w:%d [%d]", pollset, + pfds[i].fd, (pfds[i].revents & POLLIN_CHECK) != 0, + (pfds[i].revents & POLLOUT_CHECK) != 0, pfds[i].revents); + } + /* This is a mitigation to prevent poll() from spinning on a + ** POLLHUP https://github.com/grpc/grpc/pull/13665 + */ + if (pfds[i].revents & POLLHUP) { + gpr_atm_no_barrier_store(&watchers[i].fd->pollhup, 1); + } + fd_end_poll(&watchers[i], pfds[i].revents & POLLIN_CHECK, + pfds[i].revents & POLLOUT_CHECK); + } + } + } + + if (pfds != pollfd_space) { + /* pfds and watchers are in the same memory block pointed to by pfds */ + gpr_free(pfds); + } + + locked = 0; + } else { + GPR_TIMER_MARK("pollset_work.kicked_without_pollers", 0); + pollset->kicked_without_pollers = 0; + } + /* Finished execution - start cleaning up. + Note that we may arrive here from outside the enclosing while() loop. + In that case we won't loop though as we haven't added worker to the + worker list, which means nobody could ask us to re-evaluate polling). */ + done: + if (!locked) { + queued_work |= grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&pollset->mu); + locked = 1; + } + /* If we're forced to re-evaluate polling (via pollset_kick with + GRPC_POLLSET_REEVALUATE_POLLING_ON_WAKEUP) then we land here and force + a loop */ + if (worker.reevaluate_polling_on_wakeup && error == GRPC_ERROR_NONE) { + worker.reevaluate_polling_on_wakeup = 0; + pollset->kicked_without_pollers = 0; + if (queued_work || worker.kicked_specifically) { + /* If there's queued work on the list, then set the deadline to be + immediate so we get back out of the polling loop quickly */ + deadline = 0; + } + keep_polling = 1; + } + } + g_current_thread_poller = nullptr; + if (added_worker) { + remove_worker(pollset, &worker); + g_current_thread_worker = nullptr; + } + /* release wakeup fd to the local pool */ + worker.wakeup_fd->next = pollset->local_wakeup_cache; + pollset->local_wakeup_cache = worker.wakeup_fd; + /* check shutdown conditions */ + if (pollset->shutting_down) { + if (pollset_has_workers(pollset)) { + (void)pollset_kick(pollset, nullptr); + } else if (!pollset->called_shutdown && !pollset_has_observers(pollset)) { + pollset->called_shutdown = 1; + gpr_mu_unlock(&pollset->mu); + finish_shutdown(pollset); + grpc_core::ExecCtx::Get()->Flush(); + /* Continuing to access pollset here is safe -- it is the caller's + * responsibility to not destroy when it has outstanding calls to + * pollset_work. + * TODO(dklempner): Can we refactor the shutdown logic to avoid this? 
*/ + gpr_mu_lock(&pollset->mu); + } + } + if (worker_hdl) *worker_hdl = nullptr; + GRPC_LOG_IF_ERROR("pollset_work", GRPC_ERROR_REF(error)); + return error; +} + +static void pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + GPR_ASSERT(!pollset->shutting_down); + pollset->shutting_down = 1; + pollset->shutdown_done = closure; + (void)pollset_kick(pollset, GRPC_POLLSET_KICK_BROADCAST); + if (!pollset->called_shutdown && !pollset_has_observers(pollset)) { + pollset->called_shutdown = 1; + finish_shutdown(pollset); + } +} + +static int poll_deadline_to_millis_timeout(grpc_millis deadline) { + if (deadline == GRPC_MILLIS_INF_FUTURE) return -1; + if (deadline == 0) return 0; + grpc_millis n = deadline - grpc_core::ExecCtx::Get()->Now(); + if (n < 0) return 0; + if (n > INT_MAX) return -1; + return static_cast(n); +} + +/******************************************************************************* + * pollset_set_posix.c + */ + +static grpc_pollset_set* pollset_set_create(void) { + grpc_pollset_set* pollset_set = + static_cast(gpr_zalloc(sizeof(*pollset_set))); + gpr_mu_init(&pollset_set->mu); + return pollset_set; +} + +static void pollset_set_destroy(grpc_pollset_set* pollset_set) { + size_t i; + gpr_mu_destroy(&pollset_set->mu); + for (i = 0; i < pollset_set->fd_count; i++) { + GRPC_FD_UNREF(pollset_set->fds[i], "pollset_set"); + } + for (i = 0; i < pollset_set->pollset_count; i++) { + grpc_pollset* pollset = pollset_set->pollsets[i]; + gpr_mu_lock(&pollset->mu); + pollset->pollset_set_count--; + /* check shutdown */ + if (pollset->shutting_down && !pollset->called_shutdown && + !pollset_has_observers(pollset)) { + pollset->called_shutdown = 1; + gpr_mu_unlock(&pollset->mu); + finish_shutdown(pollset); + } else { + gpr_mu_unlock(&pollset->mu); + } + } + gpr_free(pollset_set->pollsets); + gpr_free(pollset_set->pollset_sets); + gpr_free(pollset_set->fds); + gpr_free(pollset_set); +} + +static void pollset_set_add_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) { + size_t i, j; + gpr_mu_lock(&pollset->mu); + pollset->pollset_set_count++; + gpr_mu_unlock(&pollset->mu); + gpr_mu_lock(&pollset_set->mu); + if (pollset_set->pollset_count == pollset_set->pollset_capacity) { + pollset_set->pollset_capacity = + std::max(size_t(8), 2 * pollset_set->pollset_capacity); + pollset_set->pollsets = static_cast(gpr_realloc( + pollset_set->pollsets, + pollset_set->pollset_capacity * sizeof(*pollset_set->pollsets))); + } + pollset_set->pollsets[pollset_set->pollset_count++] = pollset; + for (i = 0, j = 0; i < pollset_set->fd_count; i++) { + if (fd_is_orphaned(pollset_set->fds[i])) { + GRPC_FD_UNREF(pollset_set->fds[i], "pollset_set"); + } else { + pollset_add_fd(pollset, pollset_set->fds[i]); + pollset_set->fds[j++] = pollset_set->fds[i]; + } + } + pollset_set->fd_count = j; + gpr_mu_unlock(&pollset_set->mu); +} + +static void pollset_set_del_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) { + size_t i; + gpr_mu_lock(&pollset_set->mu); + for (i = 0; i < pollset_set->pollset_count; i++) { + if (pollset_set->pollsets[i] == pollset) { + pollset_set->pollset_count--; + std::swap(pollset_set->pollsets[i], + pollset_set->pollsets[pollset_set->pollset_count]); + break; + } + } + gpr_mu_unlock(&pollset_set->mu); + gpr_mu_lock(&pollset->mu); + pollset->pollset_set_count--; + /* check shutdown */ + if (pollset->shutting_down && !pollset->called_shutdown && + !pollset_has_observers(pollset)) { + pollset->called_shutdown = 1; + gpr_mu_unlock(&pollset->mu); + 
finish_shutdown(pollset); + } else { + gpr_mu_unlock(&pollset->mu); + } +} + +static void pollset_set_add_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) { + size_t i, j; + gpr_mu_lock(&bag->mu); + if (bag->pollset_set_count == bag->pollset_set_capacity) { + bag->pollset_set_capacity = + std::max(size_t(8), 2 * bag->pollset_set_capacity); + bag->pollset_sets = static_cast( + gpr_realloc(bag->pollset_sets, + bag->pollset_set_capacity * sizeof(*bag->pollset_sets))); + } + bag->pollset_sets[bag->pollset_set_count++] = item; + for (i = 0, j = 0; i < bag->fd_count; i++) { + if (fd_is_orphaned(bag->fds[i])) { + GRPC_FD_UNREF(bag->fds[i], "pollset_set"); + } else { + pollset_set_add_fd(item, bag->fds[i]); + bag->fds[j++] = bag->fds[i]; + } + } + bag->fd_count = j; + gpr_mu_unlock(&bag->mu); +} + +static void pollset_set_del_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) { + size_t i; + gpr_mu_lock(&bag->mu); + for (i = 0; i < bag->pollset_set_count; i++) { + if (bag->pollset_sets[i] == item) { + bag->pollset_set_count--; + std::swap(bag->pollset_sets[i], + bag->pollset_sets[bag->pollset_set_count]); + break; + } + } + gpr_mu_unlock(&bag->mu); +} + +static void pollset_set_add_fd(grpc_pollset_set* pollset_set, grpc_fd* fd) { + size_t i; + gpr_mu_lock(&pollset_set->mu); + if (pollset_set->fd_count == pollset_set->fd_capacity) { + pollset_set->fd_capacity = + std::max(size_t(8), 2 * pollset_set->fd_capacity); + pollset_set->fds = static_cast( + gpr_realloc(pollset_set->fds, + pollset_set->fd_capacity * sizeof(*pollset_set->fds))); + } + GRPC_FD_REF(fd, "pollset_set"); + pollset_set->fds[pollset_set->fd_count++] = fd; + for (i = 0; i < pollset_set->pollset_count; i++) { + pollset_add_fd(pollset_set->pollsets[i], fd); + } + for (i = 0; i < pollset_set->pollset_set_count; i++) { + pollset_set_add_fd(pollset_set->pollset_sets[i], fd); + } + gpr_mu_unlock(&pollset_set->mu); +} + +static void pollset_set_del_fd(grpc_pollset_set* pollset_set, grpc_fd* fd) { + size_t i; + gpr_mu_lock(&pollset_set->mu); + for (i = 0; i < pollset_set->fd_count; i++) { + if (pollset_set->fds[i] == fd) { + pollset_set->fd_count--; + std::swap(pollset_set->fds[i], pollset_set->fds[pollset_set->fd_count]); + GRPC_FD_UNREF(fd, "pollset_set"); + break; + } + } + for (i = 0; i < pollset_set->pollset_set_count; i++) { + pollset_set_del_fd(pollset_set->pollset_sets[i], fd); + } + gpr_mu_unlock(&pollset_set->mu); +} + +/******************************************************************************* + * event engine binding + */ + +static bool is_any_background_poller_thread(void) { return false; } + +static void shutdown_background_closure(void) {} + +static bool add_closure_to_background_poller(grpc_closure* /*closure*/, + grpc_error_handle /*error*/) { + return false; +} + +static void shutdown_engine(void) { + pollset_global_shutdown(); + if (track_fds_for_fork) { + gpr_mu_destroy(&fork_fd_list_mu); + grpc_core::Fork::SetResetChildPollingEngineFunc(nullptr); + } +} + +static const grpc_event_engine_vtable vtable = { + sizeof(grpc_pollset), + false, + false, + + fd_create, + fd_wrapped_fd, + fd_orphan, + fd_shutdown, + fd_notify_on_read, + fd_notify_on_write, + fd_notify_on_error, + fd_set_readable, + fd_set_writable, + fd_set_error, + fd_is_shutdown, + + pollset_init, + pollset_shutdown, + pollset_destroy, + pollset_work, + pollset_kick, + pollset_add_fd, + + pollset_set_create, + pollset_set_destroy, + pollset_set_add_pollset, + pollset_set_del_pollset, + pollset_set_add_pollset_set, + 
pollset_set_del_pollset_set, + pollset_set_add_fd, + pollset_set_del_fd, + + is_any_background_poller_thread, + shutdown_background_closure, + shutdown_engine, + add_closure_to_background_poller, +}; + +/* Called by the child process's post-fork handler to close open fds, including + * worker wakeup fds. This allows gRPC to shutdown in the child process without + * interfering with connections or RPCs ongoing in the parent. */ +static void reset_event_manager_on_fork() { + gpr_mu_lock(&fork_fd_list_mu); + while (fork_fd_list_head != nullptr) { + if (fork_fd_list_head->fd != nullptr) { + if (!fork_fd_list_head->fd->closed) { + close(fork_fd_list_head->fd->fd); + } + fork_fd_list_head->fd->fd = -1; + } else { + close(fork_fd_list_head->cached_wakeup_fd->fd.read_fd); + fork_fd_list_head->cached_wakeup_fd->fd.read_fd = -1; + close(fork_fd_list_head->cached_wakeup_fd->fd.write_fd); + fork_fd_list_head->cached_wakeup_fd->fd.write_fd = -1; + } + fork_fd_list_head = fork_fd_list_head->next; + } + gpr_mu_unlock(&fork_fd_list_mu); +} + +const grpc_event_engine_vtable* grpc_init_poll_posix( + bool /*explicit_request*/) { + if (!grpc_has_wakeup_fd()) { + gpr_log(GPR_ERROR, "Skipping poll because of no wakeup fd."); + return nullptr; + } + if (!GRPC_LOG_IF_ERROR("pollset_global_init", pollset_global_init())) { + return nullptr; + } + if (grpc_core::Fork::Enabled()) { + track_fds_for_fork = true; + gpr_mu_init(&fork_fd_list_mu); + grpc_core::Fork::SetResetChildPollingEngineFunc( + reset_event_manager_on_fork); + } + return &vtable; +} + +#endif /* GRPC_POSIX_SOCKET_EV_POLL */ diff --git a/src/core/lib/iomgr/ev_posix.cc b/src/core/lib/iomgr/ev_posix.cc new file mode 100644 index 00000000..bae0f15c --- /dev/null +++ b/src/core/lib/iomgr/ev_posix.cc @@ -0,0 +1,417 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_EV + +#include + +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/global_config.h" +#include "src/core/lib/iomgr/ev_epoll1_linux.h" +#include "src/core/lib/iomgr/ev_epollex_linux.h" +#include "src/core/lib/iomgr/ev_poll_posix.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/internal_errqueue.h" + +GPR_GLOBAL_CONFIG_DEFINE_STRING( + grpc_poll_strategy, "all", + "Declares which polling engines to try when starting gRPC. 
" + "This is a comma-separated list of engines, which are tried in priority " + "order first -> last.") + +grpc_core::DebugOnlyTraceFlag grpc_polling_trace( + false, "polling"); /* Disabled by default */ + +/* Traces fd create/close operations */ +grpc_core::DebugOnlyTraceFlag grpc_fd_trace(false, "fd_trace"); +grpc_core::DebugOnlyTraceFlag grpc_trace_fd_refcount(false, "fd_refcount"); +grpc_core::DebugOnlyTraceFlag grpc_polling_api_trace(false, "polling_api"); + +// Polling API trace only enabled in debug builds +#ifndef NDEBUG +#define GRPC_POLLING_API_TRACE(format, ...) \ + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_api_trace)) { \ + gpr_log(GPR_INFO, "(polling-api) " format, __VA_ARGS__); \ + } +#else +#define GRPC_POLLING_API_TRACE(...) +#endif // NDEBUG + +/** Default poll() function - a pointer so that it can be overridden by some + * tests */ +#ifndef GPR_AIX +grpc_poll_function_type grpc_poll_function = poll; +#else +int aix_poll(struct pollfd fds[], nfds_t nfds, int timeout) { + return poll(fds, nfds, timeout); +} +grpc_poll_function_type grpc_poll_function = aix_poll; +#endif // GPR_AIX + +grpc_wakeup_fd grpc_global_wakeup_fd; + +static const grpc_event_engine_vtable* g_event_engine = nullptr; +static const char* g_poll_strategy_name = nullptr; + +typedef const grpc_event_engine_vtable* (*event_engine_factory_fn)( + bool explicit_request); + +struct event_engine_factory { + const char* name; + event_engine_factory_fn factory; +}; +namespace { + +grpc_poll_function_type real_poll_function; + +int phony_poll(struct pollfd fds[], nfds_t nfds, int timeout) { + if (timeout == 0) { + return real_poll_function(fds, nfds, 0); + } else { + gpr_log(GPR_ERROR, "Attempted a blocking poll when declared non-polling."); + GPR_ASSERT(false); + return -1; + } +} + +const grpc_event_engine_vtable* init_non_polling(bool explicit_request) { + if (!explicit_request) { + return nullptr; + } + // return the simplest engine as a phony but also override the poller + auto ret = grpc_init_poll_posix(explicit_request); + real_poll_function = grpc_poll_function; + grpc_poll_function = phony_poll; + + return ret; +} +} // namespace + +#define ENGINE_HEAD_CUSTOM "head_custom" +#define ENGINE_TAIL_CUSTOM "tail_custom" + +// The global array of event-engine factories. Each entry is a pair with a name +// and an event-engine generator function (nullptr if there is no generator +// registered for this name). The middle entries are the engines predefined by +// open-source gRPC. The head entries represent an opportunity for specific +// high-priority custom pollers to be added by the initializer plugins of +// custom-built gRPC libraries. The tail entries represent the same, but for +// low-priority custom pollers. 
The actual poller selected is either the first +// available one in the list if no specific poller is requested, or the first +// specific poller that is requested by name in the GRPC_POLL_STRATEGY +// environment variable if that variable is set (which should be a +// comma-separated list of one or more event engine names) +static event_engine_factory g_factories[] = { + {ENGINE_HEAD_CUSTOM, nullptr}, {ENGINE_HEAD_CUSTOM, nullptr}, + {ENGINE_HEAD_CUSTOM, nullptr}, {ENGINE_HEAD_CUSTOM, nullptr}, + {"epollex", grpc_init_epollex_linux}, {"epoll1", grpc_init_epoll1_linux}, + {"poll", grpc_init_poll_posix}, {"none", init_non_polling}, + {ENGINE_TAIL_CUSTOM, nullptr}, {ENGINE_TAIL_CUSTOM, nullptr}, + {ENGINE_TAIL_CUSTOM, nullptr}, {ENGINE_TAIL_CUSTOM, nullptr}, +}; + +static void add(const char* beg, const char* end, char*** ss, size_t* ns) { + size_t n = *ns; + size_t np = n + 1; + char* s; + size_t len; + GPR_ASSERT(end >= beg); + len = static_cast(end - beg); + s = static_cast(gpr_malloc(len + 1)); + memcpy(s, beg, len); + s[len] = 0; + *ss = static_cast(gpr_realloc(*ss, sizeof(char**) * np)); + (*ss)[n] = s; + *ns = np; +} + +static void split(const char* s, char*** ss, size_t* ns) { + const char* c = strchr(s, ','); + if (c == nullptr) { + add(s, s + strlen(s), ss, ns); + } else { + add(s, c, ss, ns); + split(c + 1, ss, ns); + } +} + +static bool is(const char* want, const char* have) { + return 0 == strcmp(want, "all") || 0 == strcmp(want, have); +} + +static void try_engine(const char* engine) { + for (size_t i = 0; i < GPR_ARRAY_SIZE(g_factories); i++) { + if (g_factories[i].factory != nullptr && is(engine, g_factories[i].name)) { + if ((g_event_engine = g_factories[i].factory( + 0 == strcmp(engine, g_factories[i].name)))) { + g_poll_strategy_name = g_factories[i].name; + gpr_log(GPR_DEBUG, "Using polling engine: %s", g_factories[i].name); + return; + } + } + } +} + +/* Call this before calling grpc_event_engine_init() */ +void grpc_register_event_engine_factory(const char* name, + event_engine_factory_fn factory, + bool add_at_head) { + const char* custom_match = + add_at_head ? ENGINE_HEAD_CUSTOM : ENGINE_TAIL_CUSTOM; + + // Overwrite an existing registration if already registered + for (size_t i = 0; i < GPR_ARRAY_SIZE(g_factories); i++) { + if (0 == strcmp(name, g_factories[i].name)) { + g_factories[i].factory = factory; + return; + } + } + + // Otherwise fill in an available custom slot + for (size_t i = 0; i < GPR_ARRAY_SIZE(g_factories); i++) { + if (0 == strcmp(g_factories[i].name, custom_match)) { + g_factories[i].name = name; + g_factories[i].factory = factory; + return; + } + } + + // Otherwise fail + GPR_ASSERT(false); +} + +/*If grpc_event_engine_init() has been called, returns the poll_strategy_name. + * Otherwise, returns nullptr. 
*/ +const char* grpc_get_poll_strategy_name() { return g_poll_strategy_name; } + +void grpc_event_engine_init(void) { + grpc_core::UniquePtr value = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy); + + char** strings = nullptr; + size_t nstrings = 0; + split(value.get(), &strings, &nstrings); + + for (size_t i = 0; g_event_engine == nullptr && i < nstrings; i++) { + try_engine(strings[i]); + } + + for (size_t i = 0; i < nstrings; i++) { + gpr_free(strings[i]); + } + gpr_free(strings); + + if (g_event_engine == nullptr) { + gpr_log(GPR_ERROR, "No event engine could be initialized from %s", + value.get()); + abort(); + } +} + +void grpc_event_engine_shutdown(void) { + g_event_engine->shutdown_engine(); + g_event_engine = nullptr; +} + +bool grpc_event_engine_can_track_errors(void) { + /* Only track errors if platform supports errqueue. */ + if (grpc_core::kernel_supports_errqueue()) { + return g_event_engine->can_track_err; + } + return false; +} + +bool grpc_event_engine_run_in_background(void) { + // g_event_engine is nullptr when using a custom iomgr. + return g_event_engine != nullptr && g_event_engine->run_in_background; +} + +grpc_fd* grpc_fd_create(int fd, const char* name, bool track_err) { + GRPC_POLLING_API_TRACE("fd_create(%d, %s, %d)", fd, name, track_err); + GRPC_FD_TRACE("fd_create(%d, %s, %d)", fd, name, track_err); + return g_event_engine->fd_create( + fd, name, track_err && grpc_event_engine_can_track_errors()); +} + +int grpc_fd_wrapped_fd(grpc_fd* fd) { + return g_event_engine->fd_wrapped_fd(fd); +} + +void grpc_fd_orphan(grpc_fd* fd, grpc_closure* on_done, int* release_fd, + const char* reason) { + GRPC_POLLING_API_TRACE("fd_orphan(%d, %p, %p, %s)", grpc_fd_wrapped_fd(fd), + on_done, release_fd, reason); + GRPC_FD_TRACE("grpc_fd_orphan, fd:%d closed", grpc_fd_wrapped_fd(fd)); + + g_event_engine->fd_orphan(fd, on_done, release_fd, reason); +} + +void grpc_fd_shutdown(grpc_fd* fd, grpc_error_handle why) { + GRPC_POLLING_API_TRACE("fd_shutdown(%d)", grpc_fd_wrapped_fd(fd)); + GRPC_FD_TRACE("fd_shutdown(%d)", grpc_fd_wrapped_fd(fd)); + g_event_engine->fd_shutdown(fd, why); +} + +bool grpc_fd_is_shutdown(grpc_fd* fd) { + return g_event_engine->fd_is_shutdown(fd); +} + +void grpc_fd_notify_on_read(grpc_fd* fd, grpc_closure* closure) { + g_event_engine->fd_notify_on_read(fd, closure); +} + +void grpc_fd_notify_on_write(grpc_fd* fd, grpc_closure* closure) { + g_event_engine->fd_notify_on_write(fd, closure); +} + +void grpc_fd_notify_on_error(grpc_fd* fd, grpc_closure* closure) { + g_event_engine->fd_notify_on_error(fd, closure); +} + +void grpc_fd_set_readable(grpc_fd* fd) { g_event_engine->fd_set_readable(fd); } + +void grpc_fd_set_writable(grpc_fd* fd) { g_event_engine->fd_set_writable(fd); } + +void grpc_fd_set_error(grpc_fd* fd) { g_event_engine->fd_set_error(fd); } + +static size_t pollset_size(void) { return g_event_engine->pollset_size; } + +static void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + GRPC_POLLING_API_TRACE("pollset_init(%p)", pollset); + g_event_engine->pollset_init(pollset, mu); +} + +static void pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + GRPC_POLLING_API_TRACE("pollset_shutdown(%p)", pollset); + g_event_engine->pollset_shutdown(pollset, closure); +} + +static void pollset_destroy(grpc_pollset* pollset) { + GRPC_POLLING_API_TRACE("pollset_destroy(%p)", pollset); + g_event_engine->pollset_destroy(pollset); +} + +static grpc_error_handle pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** worker, + grpc_millis deadline) { + 
GRPC_POLLING_API_TRACE("pollset_work(%p, %" PRId64 ") begin", pollset, + deadline); + grpc_error_handle err = + g_event_engine->pollset_work(pollset, worker, deadline); + GRPC_POLLING_API_TRACE("pollset_work(%p, %" PRId64 ") end", pollset, + deadline); + return err; +} + +static grpc_error_handle pollset_kick(grpc_pollset* pollset, + grpc_pollset_worker* specific_worker) { + GRPC_POLLING_API_TRACE("pollset_kick(%p, %p)", pollset, specific_worker); + return g_event_engine->pollset_kick(pollset, specific_worker); +} + +void grpc_pollset_add_fd(grpc_pollset* pollset, struct grpc_fd* fd) { + GRPC_POLLING_API_TRACE("pollset_add_fd(%p, %d)", pollset, + grpc_fd_wrapped_fd(fd)); + g_event_engine->pollset_add_fd(pollset, fd); +} + +void pollset_global_init() {} +void pollset_global_shutdown() {} + +grpc_pollset_vtable grpc_posix_pollset_vtable = { + pollset_global_init, pollset_global_shutdown, + pollset_init, pollset_shutdown, + pollset_destroy, pollset_work, + pollset_kick, pollset_size}; + +static grpc_pollset_set* pollset_set_create(void) { + grpc_pollset_set* pss = g_event_engine->pollset_set_create(); + GRPC_POLLING_API_TRACE("pollset_set_create(%p)", pss); + return pss; +} + +static void pollset_set_destroy(grpc_pollset_set* pollset_set) { + GRPC_POLLING_API_TRACE("pollset_set_destroy(%p)", pollset_set); + g_event_engine->pollset_set_destroy(pollset_set); +} + +static void pollset_set_add_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) { + GRPC_POLLING_API_TRACE("pollset_set_add_pollset(%p, %p)", pollset_set, + pollset); + g_event_engine->pollset_set_add_pollset(pollset_set, pollset); +} + +static void pollset_set_del_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) { + GRPC_POLLING_API_TRACE("pollset_set_del_pollset(%p, %p)", pollset_set, + pollset); + g_event_engine->pollset_set_del_pollset(pollset_set, pollset); +} + +static void pollset_set_add_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) { + GRPC_POLLING_API_TRACE("pollset_set_add_pollset_set(%p, %p)", bag, item); + g_event_engine->pollset_set_add_pollset_set(bag, item); +} + +static void pollset_set_del_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) { + GRPC_POLLING_API_TRACE("pollset_set_del_pollset_set(%p, %p)", bag, item); + g_event_engine->pollset_set_del_pollset_set(bag, item); +} + +grpc_pollset_set_vtable grpc_posix_pollset_set_vtable = { + pollset_set_create, pollset_set_destroy, + pollset_set_add_pollset, pollset_set_del_pollset, + pollset_set_add_pollset_set, pollset_set_del_pollset_set}; + +void grpc_pollset_set_add_fd(grpc_pollset_set* pollset_set, grpc_fd* fd) { + GRPC_POLLING_API_TRACE("pollset_set_add_fd(%p, %d)", pollset_set, + grpc_fd_wrapped_fd(fd)); + g_event_engine->pollset_set_add_fd(pollset_set, fd); +} + +void grpc_pollset_set_del_fd(grpc_pollset_set* pollset_set, grpc_fd* fd) { + GRPC_POLLING_API_TRACE("pollset_set_del_fd(%p, %d)", pollset_set, + grpc_fd_wrapped_fd(fd)); + g_event_engine->pollset_set_del_fd(pollset_set, fd); +} + +bool grpc_is_any_background_poller_thread(void) { + return g_event_engine->is_any_background_poller_thread(); +} + +bool grpc_add_closure_to_background_poller(grpc_closure* closure, + grpc_error_handle error) { + return g_event_engine->add_closure_to_background_poller(closure, error); +} + +void grpc_shutdown_background_closure(void) { + g_event_engine->shutdown_background_closure(); +} + +#endif // GRPC_POSIX_SOCKET_EV diff --git a/src/core/lib/iomgr/ev_windows.cc b/src/core/lib/iomgr/ev_windows.cc new file mode 100644 index 
00000000..e3f5715a --- /dev/null +++ b/src/core/lib/iomgr/ev_windows.cc @@ -0,0 +1,30 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include "src/core/lib/debug/trace.h" + +grpc_core::DebugOnlyTraceFlag grpc_polling_trace( + false, "polling"); /* Disabled by default */ + +#endif // GRPC_WINSOCK_SOCKET diff --git a/src/core/lib/iomgr/event_engine/closure.cc b/src/core/lib/iomgr/event_engine/closure.cc new file mode 100644 index 00000000..e2afe8ab --- /dev/null +++ b/src/core/lib/iomgr/event_engine/closure.cc @@ -0,0 +1,77 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include + +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/event_engine/closure.h" +#include "src/core/lib/iomgr/event_engine/pollset.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_event_engine { +namespace experimental { + +namespace { + +void RunClosure(grpc_closure* closure, grpc_error_handle error) { + GPR_ASSERT(closure != nullptr); +#ifndef NDEBUG + closure->scheduled = false; + if (grpc_trace_closure.enabled()) { + gpr_log(GPR_DEBUG, + "EventEngine: running closure %p: created [%s:%d]: %s [%s:%d]", + closure, closure->file_created, closure->line_created, + closure->run ? 
"run" : "scheduled", closure->file_initiated, + closure->line_initiated); + } +#endif + closure->cb(closure->cb_arg, error); +#ifndef NDEBUG + if (grpc_trace_closure.enabled()) { + gpr_log(GPR_DEBUG, "EventEngine: closure %p finished", closure); + } +#endif +} + +} // namespace + +std::function GrpcClosureToStatusCallback( + grpc_closure* closure) { + return [closure](absl::Status status) { + RunClosure(closure, absl_status_to_grpc_error(status)); + grpc_pollset_ee_broadcast_event(); + }; +} + +std::function GrpcClosureToCallback(grpc_closure* closure) { + return [closure]() { + RunClosure(closure, GRPC_ERROR_NONE); + grpc_pollset_ee_broadcast_event(); + }; +} + +std::function GrpcClosureToCallback(grpc_closure* closure, + grpc_error_handle error) { + return [closure, error]() { + RunClosure(closure, error); + grpc_pollset_ee_broadcast_event(); + }; +} + +} // namespace experimental +} // namespace grpc_event_engine + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/event_engine/endpoint.cc b/src/core/lib/iomgr/event_engine/endpoint.cc new file mode 100644 index 00000000..6ddbbb38 --- /dev/null +++ b/src/core/lib/iomgr/event_engine/endpoint.cc @@ -0,0 +1,173 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include "absl/strings/string_view.h" + +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/event_engine/closure.h" +#include "src/core/lib/iomgr/event_engine/endpoint.h" +#include "src/core/lib/iomgr/event_engine/pollset.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/iomgr/resource_quota.h" +#include "src/core/lib/transport/error_utils.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; + +namespace { + +using ::grpc_event_engine::experimental::EventEngine; +using ::grpc_event_engine::experimental::ResolvedAddressToURI; +using ::grpc_event_engine::experimental::SliceAllocator; +using ::grpc_event_engine::experimental::SliceBuffer; + +void endpoint_read(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool /* urgent */) { + auto* eeep = reinterpret_cast(ep); + if (eeep->endpoint == nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_CANCELLED); + return; + } + SliceBuffer* read_buffer = new (&eeep->read_buffer) SliceBuffer(slices); + eeep->endpoint->Read( + [eeep, cb](absl::Status status) { + auto* read_buffer = reinterpret_cast(&eeep->read_buffer); + read_buffer->~SliceBuffer(); + grpc_core::ExecCtx exec_ctx; + grpc_core::Closure::Run(DEBUG_LOCATION, cb, + absl_status_to_grpc_error(status)); + exec_ctx.Flush(); + grpc_pollset_ee_broadcast_event(); + }, + read_buffer); +} + +void endpoint_write(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* arg) { + // TODO(hork): adapt arg to some metrics collection mechanism. + (void)arg; + auto* eeep = reinterpret_cast(ep); + if (eeep->endpoint == nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_CANCELLED); + return; + } + SliceBuffer* write_buffer = new (&eeep->write_buffer) SliceBuffer(slices); + eeep->endpoint->Write( + [eeep, cb](absl::Status status) { + auto* write_buffer = + reinterpret_cast(&eeep->write_buffer); + write_buffer->~SliceBuffer(); + grpc_core::ExecCtx exec_ctx; + grpc_core::Closure::Run(DEBUG_LOCATION, cb, + absl_status_to_grpc_error(status)); + exec_ctx.Flush(); + grpc_pollset_ee_broadcast_event(); + }, + write_buffer); +} +void endpoint_add_to_pollset(grpc_endpoint* /* ep */, + grpc_pollset* /* pollset */) {} +void endpoint_add_to_pollset_set(grpc_endpoint* /* ep */, + grpc_pollset_set* /* pollset */) {} +void endpoint_delete_from_pollset_set(grpc_endpoint* /* ep */, + grpc_pollset_set* /* pollset */) {} +/// After shutdown, all endpoint operations except destroy are no-op, +/// and will return some kind of sane default (empty strings, nullptrs, etc). It +/// is the caller's responsibility to ensure that calls to endpoint_shutdown are +/// synchronized. 
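The contract spelled out in the comment above is easiest to see from the caller's side. The following is an illustrative sketch, not part of this patch; SketchShutdownSemantics, slices and on_read are placeholder names, and `ep` is assumed to wrap one of these EventEngine endpoints.

// Sketch: caller-side view of the shutdown contract, via the public
// grpc_endpoint_* wrappers.
void SketchShutdownSemantics(grpc_endpoint* ep, grpc_slice_buffer* slices,
                             grpc_closure* on_read) {
  grpc_core::ExecCtx exec_ctx;
  // Shutdown resets the wrapped EventEngine endpoint; callers must not let it
  // race with other operations on `ep`.
  grpc_endpoint_shutdown(ep, GRPC_ERROR_CREATE_FROM_STATIC_STRING("shutdown"));
  // Post-shutdown reads are no-ops: the closure is scheduled immediately with
  // GRPC_ERROR_CANCELLED instead of touching the transport.
  grpc_endpoint_read(ep, slices, on_read, /*urgent=*/false);
  // Only destroy releases the wrapper itself.
  grpc_endpoint_destroy(ep);
}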
+void endpoint_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + auto* eeep = reinterpret_cast(ep); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + std::string str = grpc_error_std_string(why); + gpr_log(GPR_INFO, "TCP Endpoint %p shutdown why=%s", eeep->endpoint.get(), + str.c_str()); + } + eeep->endpoint.reset(); +} + +void endpoint_destroy(grpc_endpoint* ep) { + auto* eeep = reinterpret_cast(ep); + delete eeep; +} + +absl::string_view endpoint_get_peer(grpc_endpoint* ep) { + auto* eeep = reinterpret_cast(ep); + if (eeep->endpoint == nullptr) { + return ""; + } + if (eeep->peer_address.empty()) { + const EventEngine::ResolvedAddress& addr = eeep->endpoint->GetPeerAddress(); + eeep->peer_address = ResolvedAddressToURI(addr); + } + return eeep->peer_address; +} + +absl::string_view endpoint_get_local_address(grpc_endpoint* ep) { + auto* eeep = reinterpret_cast(ep); + if (eeep->endpoint == nullptr) { + return ""; + } + if (eeep->local_address.empty()) { + const EventEngine::ResolvedAddress& addr = + eeep->endpoint->GetLocalAddress(); + eeep->local_address = ResolvedAddressToURI(addr); + } + return eeep->local_address; +} + +int endpoint_get_fd(grpc_endpoint* /* ep */) { return -1; } + +bool endpoint_can_track_err(grpc_endpoint* /* ep */) { return false; } + +grpc_endpoint_vtable grpc_event_engine_endpoint_vtable = { + endpoint_read, + endpoint_write, + endpoint_add_to_pollset, + endpoint_add_to_pollset_set, + endpoint_delete_from_pollset_set, + endpoint_shutdown, + endpoint_destroy, + endpoint_get_peer, + endpoint_get_local_address, + endpoint_get_fd, + endpoint_can_track_err}; + +} // namespace + +grpc_event_engine_endpoint* grpc_tcp_server_endpoint_create( + std::unique_ptr ee_endpoint) { + auto endpoint = new grpc_event_engine_endpoint; + endpoint->base.vtable = &grpc_event_engine_endpoint_vtable; + endpoint->endpoint = std::move(ee_endpoint); + return endpoint; +} + +grpc_endpoint* grpc_tcp_create(const grpc_channel_args* channel_args, + absl::string_view peer_address) { + auto endpoint = new grpc_event_engine_endpoint; + endpoint->base.vtable = &grpc_event_engine_endpoint_vtable; + return &endpoint->base; +} + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/event_engine/iomgr.cc b/src/core/lib/iomgr/event_engine/iomgr.cc new file mode 100644 index 00000000..27204328 --- /dev/null +++ b/src/core/lib/iomgr/event_engine/iomgr.cc @@ -0,0 +1,104 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
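grpc_tcp_server_endpoint_create above is how an accepted EventEngine::Endpoint is wrapped into an iomgr grpc_endpoint driven by the vtable defined in endpoint.cc. A small sketch, not part of the diff, assuming an already-established EventEngine endpoint `ee_ep` and that grpc_endpoint_get_peer returns an absl::string_view as the vtable above does; SketchWrapEndpoint is a placeholder name.

// Sketch: wrap an EventEngine endpoint and inspect it through the iomgr API.
grpc_endpoint* SketchWrapEndpoint(
    std::unique_ptr<grpc_event_engine::experimental::EventEngine::Endpoint>
        ee_ep) {
  grpc_event_engine_endpoint* wrapped =
      grpc_tcp_server_endpoint_create(std::move(ee_ep));
  // The peer address is resolved lazily on first access and then cached.
  gpr_log(GPR_INFO, "wrapped endpoint peer: %s",
          std::string(grpc_endpoint_get_peer(&wrapped->base)).c_str());
  return &wrapped->base;
}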
+#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/event_engine/iomgr.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/surface/init.h" + +extern grpc_tcp_client_vtable grpc_event_engine_tcp_client_vtable; +extern grpc_tcp_server_vtable grpc_event_engine_tcp_server_vtable; +extern grpc_timer_vtable grpc_event_engine_timer_vtable; +extern grpc_pollset_vtable grpc_event_engine_pollset_vtable; +extern grpc_pollset_set_vtable grpc_event_engine_pollset_set_vtable; +extern grpc_address_resolver_vtable grpc_event_engine_resolver_vtable; + +// Disabled by default. grpc_polling_trace must be defined in all iomgr +// implementations due to its usage in lockfree_event. +grpc_core::DebugOnlyTraceFlag grpc_polling_trace(false, "polling"); + +namespace { + +using ::grpc_event_engine::experimental::DefaultEventEngineFactory; +using ::grpc_event_engine::experimental::EventEngine; + +EventEngine* g_event_engine = nullptr; + +// TODO(nnoble): Instantiate the default EventEngine if none have been provided. +void iomgr_platform_init(void) { GPR_ASSERT(g_event_engine != nullptr); } + +void iomgr_platform_flush(void) {} + +void iomgr_platform_shutdown(void) { + delete g_event_engine; + g_event_engine = nullptr; +} + +void iomgr_platform_shutdown_background_closure(void) {} + +bool iomgr_platform_is_any_background_poller_thread(void) { + return g_event_engine->IsWorkerThread(); +} + +bool iomgr_platform_add_closure_to_background_poller( + grpc_closure* /* closure */, grpc_error_handle /* error */) { + return false; +} + +grpc_iomgr_platform_vtable vtable = { + iomgr_platform_init, + iomgr_platform_flush, + iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure, + iomgr_platform_is_any_background_poller_thread, + iomgr_platform_add_closure_to_background_poller}; + +} // namespace + +void grpc_set_default_iomgr_platform() { + grpc_set_tcp_client_impl(&grpc_event_engine_tcp_client_vtable); + grpc_set_tcp_server_impl(&grpc_event_engine_tcp_server_vtable); + grpc_set_timer_impl(&grpc_event_engine_timer_vtable); + grpc_set_pollset_vtable(&grpc_event_engine_pollset_vtable); + grpc_set_pollset_set_vtable(&grpc_event_engine_pollset_set_vtable); + grpc_set_resolver_impl(&grpc_event_engine_resolver_vtable); + grpc_set_iomgr_platform_vtable(&vtable); +} + +bool grpc_iomgr_run_in_background() { return false; } + +grpc_event_engine::experimental::EventEngine* grpc_iomgr_event_engine() { + return g_event_engine; +} + +namespace grpc_core { + +void SetDefaultEventEngine( + std::unique_ptr + event_engine) { + GPR_ASSERT(g_event_engine == nullptr); + g_event_engine = event_engine.release(); +} + +} // namespace grpc_core + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/event_engine/pollset.cc b/src/core/lib/iomgr/event_engine/pollset.cc new file mode 100644 index 00000000..b8e08d06 --- /dev/null +++ b/src/core/lib/iomgr/event_engine/pollset.cc @@ -0,0 +1,88 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include + +#include "src/core/lib/iomgr/event_engine/pollset.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_set.h" + +namespace { + +static gpr_mu g_mu; +static gpr_cv g_cv; + +// --- pollset vtable API --- +void pollset_global_init(void) { + gpr_mu_init(&g_mu); + gpr_cv_init(&g_cv); +} +void pollset_global_shutdown(void) { + gpr_cv_destroy(&g_cv); + gpr_mu_destroy(&g_mu); +} +void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { *mu = &g_mu; } +void pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); +} +void pollset_destroy(grpc_pollset* pollset) {} +grpc_error_handle pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** worker, + grpc_millis deadline) { + (void)worker; + gpr_cv_wait(&g_cv, &g_mu, + grpc_millis_to_timespec(deadline, GPR_CLOCK_REALTIME)); + return GRPC_ERROR_NONE; +} +grpc_error_handle pollset_kick(grpc_pollset* pollset, + grpc_pollset_worker* specific_worker) { + (void)pollset; + (void)specific_worker; + return GRPC_ERROR_NONE; +} +size_t pollset_size(void) { return 1; } + +// --- pollset_set vtable API --- +grpc_pollset_set* pollset_set_create(void) { return nullptr; } +void pollset_set_destroy(grpc_pollset_set* pollset_set) {} +void pollset_set_add_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) {} + +void pollset_set_del_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) {} +void pollset_set_add_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) {} +void pollset_set_del_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) {} + +} // namespace + +void grpc_pollset_ee_broadcast_event() { gpr_cv_signal(&g_cv); } + +// --- vtables --- +grpc_pollset_vtable grpc_event_engine_pollset_vtable = { + pollset_global_init, pollset_global_shutdown, + pollset_init, pollset_shutdown, + pollset_destroy, pollset_work, + pollset_kick, pollset_size}; + +grpc_pollset_set_vtable grpc_event_engine_pollset_set_vtable = { + pollset_set_create, pollset_set_destroy, + pollset_set_add_pollset, pollset_set_del_pollset, + pollset_set_add_pollset_set, pollset_set_del_pollset_set}; + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/event_engine/resolved_address_internal.cc b/src/core/lib/iomgr/event_engine/resolved_address_internal.cc new file mode 100644 index 00000000..561d91ed --- /dev/null +++ b/src/core/lib/iomgr/event_engine/resolved_address_internal.cc @@ -0,0 +1,41 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "src/core/lib/iomgr/event_engine/resolved_address_internal.h" + +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/iomgr/resolve_address.h" + +namespace grpc_event_engine { +namespace experimental { + +EventEngine::ResolvedAddress CreateResolvedAddress( + const grpc_resolved_address& addr) { + return EventEngine::ResolvedAddress( + reinterpret_cast(addr.addr), addr.len); +} + +grpc_resolved_address CreateGRPCResolvedAddress( + const EventEngine::ResolvedAddress& ra) { + grpc_resolved_address grpc_addr; + memcpy(grpc_addr.addr, ra.address(), ra.size()); + grpc_addr.len = ra.size(); + return grpc_addr; +} + +} // namespace experimental +} // namespace grpc_event_engine diff --git a/src/core/lib/iomgr/event_engine/resolver.cc b/src/core/lib/iomgr/event_engine/resolver.cc new file mode 100644 index 00000000..628c305e --- /dev/null +++ b/src/core/lib/iomgr/event_engine/resolver.cc @@ -0,0 +1,114 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include "absl/functional/bind_front.h" + +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/event_engine/iomgr.h" +#include "src/core/lib/iomgr/event_engine/promise.h" +#include "src/core/lib/iomgr/event_engine/resolved_address_internal.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/surface/init.h" +#include "src/core/lib/transport/error_utils.h" + +namespace { +using ::grpc_event_engine::experimental::CreateGRPCResolvedAddress; +using ::grpc_event_engine::experimental::EventEngine; +using ::grpc_event_engine::experimental::Promise; + +/// A fire-and-forget class representing an individual DNS request. +/// +/// This provides a place to store the ownership of the DNSResolver object until +/// the request is complete. +class DnsRequest { + public: + DnsRequest(std::unique_ptr dns_resolver, + absl::string_view address, absl::string_view default_port, + grpc_closure* on_done, grpc_resolved_addresses** addresses) + : dns_resolver_(std::move(dns_resolver)), + cb_(on_done), + addresses_(addresses) { + dns_resolver_->LookupHostname( + absl::bind_front(&DnsRequest::OnLookupComplete, this), address, + default_port, absl::InfiniteFuture()); + } + + private: + void OnLookupComplete( + absl::StatusOr> addresses) { + grpc_core::ExecCtx exec_ctx; + // Convert addresses to iomgr form. 
+ *addresses_ = static_cast( + gpr_malloc(sizeof(grpc_resolved_addresses))); + (*addresses_)->naddrs = addresses->size(); + (*addresses_)->addrs = static_cast( + gpr_malloc(sizeof(grpc_resolved_address) * addresses->size())); + for (size_t i = 0; i < addresses->size(); ++i) { + (*addresses_)->addrs[i] = CreateGRPCResolvedAddress((*addresses)[i]); + } + grpc_closure* cb = cb_; + delete this; + grpc_core::Closure::Run(DEBUG_LOCATION, cb, + absl_status_to_grpc_error(addresses.status())); + } + + std::unique_ptr dns_resolver_; + grpc_closure* cb_; + grpc_resolved_addresses** addresses_; +}; + +void resolve_address(const char* addr, const char* default_port, + grpc_pollset_set* /* interested_parties */, + grpc_closure* on_done, + grpc_resolved_addresses** addresses) { + std::unique_ptr dns_resolver = + grpc_iomgr_event_engine()->GetDNSResolver(); + if (dns_resolver == nullptr) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, on_done, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to get DNS Resolver.")); + return; + } + new DnsRequest(std::move(dns_resolver), addr, default_port, on_done, + addresses); +} + +void blocking_handle_async_resolve_done(void* arg, grpc_error_handle error) { + static_cast*>(arg)->Set(std::move(error)); +} + +grpc_error_handle blocking_resolve_address( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + grpc_closure on_done; + Promise evt; + GRPC_CLOSURE_INIT(&on_done, blocking_handle_async_resolve_done, &evt, + grpc_schedule_on_exec_ctx); + resolve_address(name, default_port, nullptr, &on_done, addresses); + return evt.Get(); +} + +} // namespace + +grpc_address_resolver_vtable grpc_event_engine_resolver_vtable{ + resolve_address, blocking_resolve_address}; + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/event_engine/tcp.cc b/src/core/lib/iomgr/event_engine/tcp.cc new file mode 100644 index 00000000..04f6216a --- /dev/null +++ b/src/core/lib/iomgr/event_engine/tcp.cc @@ -0,0 +1,293 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
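The Promise-based blocking path above is reached through the ordinary iomgr entry point once grpc_event_engine_resolver_vtable is installed. A hedged sketch, not part of the patch; the host and port literals are placeholders, and it assumes the public grpc_blocking_resolve_address wrapper dispatches to the vtable registered above.

// Sketch: synchronous lookup through the resolver vtable registered above.
void SketchBlockingLookup() {
  grpc_resolved_addresses* addrs = nullptr;
  grpc_error_handle err =
      grpc_blocking_resolve_address("localhost", "50051", &addrs);
  if (err == GRPC_ERROR_NONE) {
    gpr_log(GPR_INFO, "resolved %zu address(es)", addrs->naddrs);
    grpc_resolved_addresses_destroy(addrs);
  }
  GRPC_ERROR_UNREF(err);
}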
+#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/event_engine/endpoint_config_internal.h" +#include "src/core/lib/event_engine/sockaddr.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/iomgr/event_engine/closure.h" +#include "src/core/lib/iomgr/event_engine/endpoint.h" +#include "src/core/lib/iomgr/event_engine/iomgr.h" +#include "src/core/lib/iomgr/event_engine/pollset.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/surface/init.h" +#include "src/core/lib/transport/error_utils.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; + +namespace { +using ::grpc_event_engine::experimental::ChannelArgsEndpointConfig; +using ::grpc_event_engine::experimental::EventEngine; +using ::grpc_event_engine::experimental::GrpcClosureToStatusCallback; +using ::grpc_event_engine::experimental::SliceAllocator; +using ::grpc_event_engine::experimental::SliceAllocatorFactory; +using ::grpc_event_engine::experimental::SliceBuffer; +} // namespace + +class WrappedInternalSliceAllocator : public SliceAllocator { + public: + explicit WrappedInternalSliceAllocator(grpc_slice_allocator* slice_allocator) + : slice_allocator_(slice_allocator) {} + + ~WrappedInternalSliceAllocator() { + grpc_slice_allocator_destroy(slice_allocator_); + } + + absl::Status Allocate(size_t size, SliceBuffer* dest, + SliceAllocator::AllocateCallback cb) override { + // TODO(nnoble): requires the SliceBuffer definition. + grpc_slice_allocator_allocate( + slice_allocator_, size, 1, grpc_slice_allocator_intent::kReadBuffer, + dest->RawSliceBuffer(), + [](void* arg, grpc_error_handle error) { + auto cb = static_cast(arg); + (*cb)(grpc_error_to_absl_status(error)); + delete cb; + }, + new SliceAllocator::AllocateCallback(cb)); + return absl::OkStatus(); + } + + private: + grpc_slice_allocator* slice_allocator_; +}; + +class WrappedInternalSliceAllocatorFactory : public SliceAllocatorFactory { + public: + explicit WrappedInternalSliceAllocatorFactory( + grpc_slice_allocator_factory* slice_allocator_factory) + : slice_allocator_factory_(slice_allocator_factory) {} + + ~WrappedInternalSliceAllocatorFactory() { + grpc_slice_allocator_factory_destroy(slice_allocator_factory_); + } + + std::unique_ptr CreateSliceAllocator( + absl::string_view peer_name) override { + return absl::make_unique( + grpc_slice_allocator_factory_create_slice_allocator( + slice_allocator_factory_, peer_name)); + }; + + private: + grpc_slice_allocator_factory* slice_allocator_factory_; +}; + +struct grpc_tcp_server { + explicit grpc_tcp_server(std::unique_ptr listener) + : refcount(1, GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace) ? "tcp" : nullptr), + listener(std::move(listener)) { + shutdown_starting.head = nullptr; + shutdown_starting.tail = nullptr; + }; + ~grpc_tcp_server() { + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &shutdown_starting); + grpc_core::ExecCtx::Get()->Flush(); + } + grpc_core::RefCount refcount; + grpc_core::Mutex mu; + std::unique_ptr listener; + grpc_closure_list shutdown_starting ABSL_GUARDED_BY(mu); + grpc_tcp_server_cb on_accept_internal; + void* on_accept_internal_arg; +}; + +namespace { + +/// Converts a grpc_closure to an EventEngine Callback. The closure is expected +/// to already be initialized. 
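"Already be initialized" here means the grpc_closure has been set up with GRPC_CLOSURE_INIT (or equivalent) before it is handed to the converter defined below. A minimal sketch, not part of this change; on_connected and SketchInitializedClosure are hypothetical names.

// Sketch: the closure passed to GrpcClosureToOnConnectCallback (below) must
// already carry its callback and argument.
void on_connected(void* /*arg*/, grpc_error_handle error) {
  gpr_log(GPR_INFO, "connect finished: %s",
          grpc_error_std_string(error).c_str());
}

void SketchInitializedClosure(grpc_closure* on_connect_storage,
                              grpc_endpoint** endpoint_ptr) {
  GRPC_CLOSURE_INIT(on_connect_storage, on_connected, /*cb_arg=*/nullptr,
                    grpc_schedule_on_exec_ctx);
  // Safe to convert now: the returned callback runs on_connected exactly once
  // with the connect status and fills or clears *endpoint_ptr.
  auto on_connect_cb =
      GrpcClosureToOnConnectCallback(on_connect_storage, endpoint_ptr);
  (void)on_connect_cb;
}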
+EventEngine::OnConnectCallback GrpcClosureToOnConnectCallback( + grpc_closure* closure, grpc_endpoint** endpoint_ptr) { + return [closure, endpoint_ptr]( + absl::StatusOr> endpoint) { + grpc_core::ExecCtx exec_ctx; + if (endpoint.ok()) { + auto* grpc_endpoint_out = + reinterpret_cast(*endpoint_ptr); + grpc_endpoint_out->endpoint = std::move(*endpoint); + } else { + grpc_endpoint_destroy(*endpoint_ptr); + *endpoint_ptr = nullptr; + } + grpc_core::Closure::Run(DEBUG_LOCATION, closure, + absl_status_to_grpc_error(endpoint.status())); + exec_ctx.Flush(); + grpc_pollset_ee_broadcast_event(); + }; +} + +/// Usage note: this method does not take ownership of any pointer arguments. +void tcp_connect(grpc_closure* on_connect, grpc_endpoint** endpoint, + grpc_slice_allocator* slice_allocator, + grpc_pollset_set* /* interested_parties */, + const grpc_channel_args* channel_args, + const grpc_resolved_address* addr, grpc_millis deadline) { + grpc_event_engine_endpoint* ee_endpoint = + reinterpret_cast( + grpc_tcp_create(channel_args, grpc_sockaddr_to_uri(addr))); + *endpoint = &ee_endpoint->base; + EventEngine::OnConnectCallback ee_on_connect = + GrpcClosureToOnConnectCallback(on_connect, endpoint); + auto ee_slice_allocator = + absl::make_unique(slice_allocator); + EventEngine::ResolvedAddress ra(reinterpret_cast(addr->addr), + addr->len); + absl::Time ee_deadline = grpc_core::ToAbslTime( + grpc_millis_to_timespec(deadline, GPR_CLOCK_MONOTONIC)); + ChannelArgsEndpointConfig endpoint_config(channel_args); + absl::Status connected = grpc_iomgr_event_engine()->Connect( + ee_on_connect, ra, endpoint_config, std::move(ee_slice_allocator), + ee_deadline); + if (!connected.ok()) { + // EventEngine failed to start an asynchronous connect. + grpc_endpoint_destroy(*endpoint); + *endpoint = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_connect, + absl_status_to_grpc_error(connected)); + } +} + +grpc_error_handle tcp_server_create( + grpc_closure* shutdown_complete, const grpc_channel_args* args, + grpc_slice_allocator_factory* slice_allocator_factory, + grpc_tcp_server** server) { + ChannelArgsEndpointConfig endpoint_config(args); + auto ee_slice_allocator_factory = + absl::make_unique( + slice_allocator_factory); + EventEngine* event_engine = grpc_iomgr_event_engine(); + absl::StatusOr> listener = + event_engine->CreateListener( + [server](std::unique_ptr ee_endpoint, + const SliceAllocator& /*slice_allocator*/) { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT((*server)->on_accept_internal != nullptr); + grpc_event_engine_endpoint* iomgr_endpoint = + grpc_tcp_server_endpoint_create(std::move(ee_endpoint)); + grpc_tcp_server_acceptor* acceptor = + static_cast( + gpr_zalloc(sizeof(*acceptor))); + acceptor->from_server = *server; + acceptor->external_connection = false; + (*server)->on_accept_internal((*server)->on_accept_internal_arg, + &iomgr_endpoint->base, nullptr, + acceptor); + exec_ctx.Flush(); + grpc_pollset_ee_broadcast_event(); + }, + GrpcClosureToStatusCallback(shutdown_complete), endpoint_config, + std::move(ee_slice_allocator_factory)); + if (!listener.ok()) { + return absl_status_to_grpc_error(listener.status()); + } + *server = new grpc_tcp_server(std::move(*listener)); + return GRPC_ERROR_NONE; +} + +void tcp_server_start(grpc_tcp_server* server, + const std::vector* /* pollsets */, + grpc_tcp_server_cb on_accept_cb, void* cb_arg) { + server->on_accept_internal = on_accept_cb; + server->on_accept_internal_arg = cb_arg; + // The iomgr API does not handle situations where the server cannot 
start, so + // a crash may be preferable for now. + GPR_ASSERT(server->listener->Start().ok()); +} + +grpc_error_handle tcp_server_add_port(grpc_tcp_server* s, + const grpc_resolved_address* addr, + int* out_port) { + EventEngine::ResolvedAddress ra(reinterpret_cast(addr->addr), + addr->len); + auto port = s->listener->Bind(ra); + if (!port.ok()) { + return absl_status_to_grpc_error(port.status()); + } + *out_port = *port; + return GRPC_ERROR_NONE; +} + +grpc_core::TcpServerFdHandler* tcp_server_create_fd_handler( + grpc_tcp_server* /* s */) { + // EventEngine-iomgr does not support fds. + return nullptr; +} + +unsigned tcp_server_port_fd_count(grpc_tcp_server* /* s */, + unsigned /* port_index */) { + return 0; +} + +int tcp_server_port_fd(grpc_tcp_server* /* s */, unsigned /* port_index */, + unsigned /* fd_index */) { + // Note: only used internally + return -1; +} + +grpc_tcp_server* tcp_server_ref(grpc_tcp_server* s) { + s->refcount.Ref(DEBUG_LOCATION, "server ref"); + return s; +} + +void tcp_server_shutdown_starting_add(grpc_tcp_server* s, + grpc_closure* shutdown_starting) { + grpc_core::MutexLock lock(&s->mu); + grpc_closure_list_append(&s->shutdown_starting, shutdown_starting, + GRPC_ERROR_NONE); +} + +void tcp_server_unref(grpc_tcp_server* s) { + if (GPR_UNLIKELY(s->refcount.Unref(DEBUG_LOCATION, "server unref"))) { + delete s; + } +} + +// No-op, all are handled on listener unref +void tcp_server_shutdown_listeners(grpc_tcp_server* /* s */) {} + +} // namespace + +grpc_tcp_client_vtable grpc_event_engine_tcp_client_vtable = {tcp_connect}; +grpc_tcp_server_vtable grpc_event_engine_tcp_server_vtable = { + tcp_server_create, tcp_server_start, + tcp_server_add_port, tcp_server_create_fd_handler, + tcp_server_port_fd_count, tcp_server_port_fd, + tcp_server_ref, tcp_server_shutdown_starting_add, + tcp_server_unref, tcp_server_shutdown_listeners}; + +// Methods that are expected to exist elsewhere in the codebase. + +struct grpc_fd { + int fd; +}; + +grpc_fd* grpc_fd_create(int /* fd */, const char* /* name */, + bool /* track_err */) { + return nullptr; +} + +grpc_endpoint* grpc_tcp_client_create_from_fd( + grpc_fd* /* fd */, const grpc_channel_args* /* channel_args */, + const char* /* addr_str */, grpc_slice_allocator* slice_allocator) { + grpc_slice_allocator_destroy(slice_allocator); + return nullptr; +} + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/event_engine/timer.cc b/src/core/lib/iomgr/event_engine/timer.cc new file mode 100644 index 00000000..bb7dbfb9 --- /dev/null +++ b/src/core/lib/iomgr/event_engine/timer.cc @@ -0,0 +1,62 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
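+// How the vtables above are meant to be consumed (illustrative sketch only,
+// not part of this file): iomgr only uses them once they are installed
+// through the grpc_set_*_impl setters declared in tcp_client.h, tcp_server.h
+// and timer.h, presumably as part of the GRPC_USE_EVENT_ENGINE platform
+// initialization:
+//
+//   grpc_set_tcp_client_impl(&grpc_event_engine_tcp_client_vtable);
+//   grpc_set_tcp_server_impl(&grpc_event_engine_tcp_server_vtable);
+//   grpc_set_timer_impl(&grpc_event_engine_timer_vtable);  // defined below
+//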
+#include + +#ifdef GRPC_USE_EVENT_ENGINE +#include + +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/event_engine/closure.h" +#include "src/core/lib/iomgr/event_engine/iomgr.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/surface/init.h" +#include "src/core/lib/transport/error_utils.h" + +namespace { +using ::grpc_event_engine::experimental::EventEngine; +using ::grpc_event_engine::experimental::GrpcClosureToCallback; + +void timer_init(grpc_timer* timer, grpc_millis deadline, + grpc_closure* closure) { + timer->ee_task_handle = grpc_iomgr_event_engine()->RunAt( + grpc_core::ToAbslTime( + grpc_millis_to_timespec(deadline, GPR_CLOCK_REALTIME)), + GrpcClosureToCallback(closure)); + timer->closure = closure; +} + +void timer_cancel(grpc_timer* timer) { + auto handle = timer->ee_task_handle; + if (!grpc_iomgr_event_engine()->Cancel(handle)) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, timer->closure, + GRPC_ERROR_CANCELLED); + } +} + +/* Internal API */ +grpc_timer_check_result timer_check(grpc_millis* /* next */) { + return GRPC_TIMERS_NOT_CHECKED; +} +void timer_list_init() {} +void timer_list_shutdown(void) {} +void timer_consume_kick(void) {} + +} // namespace + +grpc_timer_vtable grpc_event_engine_timer_vtable = { + timer_init, timer_cancel, timer_check, + timer_list_init, timer_list_shutdown, timer_consume_kick}; + +#endif // GRPC_USE_EVENT_ENGINE diff --git a/src/core/lib/iomgr/exec_ctx.cc b/src/core/lib/iomgr/exec_ctx.cc new file mode 100644 index 00000000..ed480d8f --- /dev/null +++ b/src/core/lib/iomgr/exec_ctx.cc @@ -0,0 +1,227 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/exec_ctx.h" + +#include +#include + +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/profiling/timers.h" + +static void exec_ctx_run(grpc_closure* closure, grpc_error_handle error) { +#ifndef NDEBUG + closure->scheduled = false; + if (grpc_trace_closure.enabled()) { + gpr_log(GPR_DEBUG, "running closure %p: created [%s:%d]: %s [%s:%d]", + closure, closure->file_created, closure->line_created, + closure->run ? 
"run" : "scheduled", closure->file_initiated, + closure->line_initiated); + } +#endif + closure->cb(closure->cb_arg, error); +#ifndef NDEBUG + if (grpc_trace_closure.enabled()) { + gpr_log(GPR_DEBUG, "closure %p finished", closure); + } +#endif + GRPC_ERROR_UNREF(error); +} + +static void exec_ctx_sched(grpc_closure* closure, grpc_error_handle error) { + grpc_closure_list_append(grpc_core::ExecCtx::Get()->closure_list(), closure, + error); +} + +static gpr_timespec g_start_time; +static gpr_cycle_counter g_start_cycle; + +static grpc_millis timespan_to_millis_round_down(gpr_timespec ts) { + double x = GPR_MS_PER_SEC * static_cast(ts.tv_sec) + + static_cast(ts.tv_nsec) / GPR_NS_PER_MS; + if (x < 0) return 0; + if (x > static_cast(GRPC_MILLIS_INF_FUTURE)) { + return GRPC_MILLIS_INF_FUTURE; + } + return static_cast(x); +} + +static grpc_millis timespec_to_millis_round_down(gpr_timespec ts) { + return timespan_to_millis_round_down(gpr_time_sub(ts, g_start_time)); +} + +static grpc_millis timespan_to_millis_round_up(gpr_timespec ts) { + double x = GPR_MS_PER_SEC * static_cast(ts.tv_sec) + + static_cast(ts.tv_nsec) / GPR_NS_PER_MS + + static_cast(GPR_NS_PER_SEC - 1) / + static_cast(GPR_NS_PER_SEC); + if (x < 0) return 0; + if (x > static_cast(GRPC_MILLIS_INF_FUTURE)) { + return GRPC_MILLIS_INF_FUTURE; + } + return static_cast(x); +} + +static grpc_millis timespec_to_millis_round_up(gpr_timespec ts) { + return timespan_to_millis_round_up(gpr_time_sub(ts, g_start_time)); +} + +gpr_timespec grpc_millis_to_timespec(grpc_millis millis, + gpr_clock_type clock_type) { + // special-case infinities as grpc_millis can be 32bit on some platforms + // while gpr_time_from_millis always takes an int64_t. + if (millis == GRPC_MILLIS_INF_FUTURE) { + return gpr_inf_future(clock_type); + } + if (millis == GRPC_MILLIS_INF_PAST) { + return gpr_inf_past(clock_type); + } + + if (clock_type == GPR_TIMESPAN) { + return gpr_time_from_millis(millis, GPR_TIMESPAN); + } + return gpr_time_add(gpr_convert_clock_type(g_start_time, clock_type), + gpr_time_from_millis(millis, GPR_TIMESPAN)); +} + +grpc_millis grpc_timespec_to_millis_round_down(gpr_timespec ts) { + return timespec_to_millis_round_down( + gpr_convert_clock_type(ts, g_start_time.clock_type)); +} + +grpc_millis grpc_timespec_to_millis_round_up(gpr_timespec ts) { + return timespec_to_millis_round_up( + gpr_convert_clock_type(ts, g_start_time.clock_type)); +} + +grpc_millis grpc_cycle_counter_to_millis_round_down(gpr_cycle_counter cycles) { + return timespan_to_millis_round_down( + gpr_cycle_counter_sub(cycles, g_start_cycle)); +} + +grpc_millis grpc_cycle_counter_to_millis_round_up(gpr_cycle_counter cycles) { + return timespan_to_millis_round_up( + gpr_cycle_counter_sub(cycles, g_start_cycle)); +} + +namespace grpc_core { +GPR_THREAD_LOCAL(ExecCtx*) ExecCtx::exec_ctx_; +GPR_THREAD_LOCAL(ApplicationCallbackExecCtx*) +ApplicationCallbackExecCtx::callback_exec_ctx_; + +// WARNING: for testing purposes only! +void ExecCtx::TestOnlyGlobalInit(gpr_timespec new_val) { + g_start_time = new_val; +} + +void ExecCtx::GlobalInit(void) { + // gpr_now(GPR_CLOCK_MONOTONIC) incurs a syscall. We don't actually know the + // exact cycle the time was captured, so we use the average of cycles before + // and after the syscall as the starting cycle. 
+ const gpr_cycle_counter cycle_before = gpr_get_cycle_counter(); + g_start_time = gpr_now(GPR_CLOCK_MONOTONIC); + const gpr_cycle_counter cycle_after = gpr_get_cycle_counter(); + g_start_cycle = (cycle_before + cycle_after) / 2; +} + +bool ExecCtx::Flush() { + bool did_something = false; + GPR_TIMER_SCOPE("grpc_exec_ctx_flush", 0); + for (;;) { + if (!grpc_closure_list_empty(closure_list_)) { + grpc_closure* c = closure_list_.head; + closure_list_.head = closure_list_.tail = nullptr; + while (c != nullptr) { + grpc_closure* next = c->next_data.next; + grpc_error_handle error = c->error_data.error; + did_something = true; + exec_ctx_run(c, error); + c = next; + } + } else if (!grpc_combiner_continue_exec_ctx()) { + break; + } + } + GPR_ASSERT(combiner_data_.active_combiner == nullptr); + return did_something; +} + +grpc_millis ExecCtx::Now() { + if (!now_is_valid_) { + now_ = timespec_to_millis_round_down(gpr_now(GPR_CLOCK_MONOTONIC)); + now_is_valid_ = true; + } + return now_; +} + +void ExecCtx::Run(const DebugLocation& location, grpc_closure* closure, + grpc_error_handle error) { + (void)location; + if (closure == nullptr) { + GRPC_ERROR_UNREF(error); + return; + } +#ifndef NDEBUG + if (closure->scheduled) { + gpr_log(GPR_ERROR, + "Closure already scheduled. (closure: %p, created: [%s:%d], " + "previously scheduled at: [%s: %d], newly scheduled at [%s: %d]", + closure, closure->file_created, closure->line_created, + closure->file_initiated, closure->line_initiated, location.file(), + location.line()); + abort(); + } + closure->scheduled = true; + closure->file_initiated = location.file(); + closure->line_initiated = location.line(); + closure->run = false; + GPR_ASSERT(closure->cb != nullptr); +#endif + exec_ctx_sched(closure, error); +} + +void ExecCtx::RunList(const DebugLocation& location, grpc_closure_list* list) { + (void)location; + grpc_closure* c = list->head; + while (c != nullptr) { + grpc_closure* next = c->next_data.next; +#ifndef NDEBUG + if (c->scheduled) { + gpr_log(GPR_ERROR, + "Closure already scheduled. (closure: %p, created: [%s:%d], " + "previously scheduled at: [%s: %d], newly scheduled at [%s:%d]", + c, c->file_created, c->line_created, c->file_initiated, + c->line_initiated, location.file(), location.line()); + abort(); + } + c->scheduled = true; + c->file_initiated = location.file(); + c->line_initiated = location.line(); + c->run = false; + GPR_ASSERT(c->cb != nullptr); +#endif + exec_ctx_sched(c, c->error_data.error); + c = next; + } + list->head = list->tail = nullptr; +} + +} // namespace grpc_core diff --git a/src/core/lib/iomgr/executor.cc b/src/core/lib/iomgr/executor.cc new file mode 100644 index 00000000..4ef0c34a --- /dev/null +++ b/src/core/lib/iomgr/executor.cc @@ -0,0 +1,455 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
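+// A caller-side sketch of the ExecCtx API defined above (illustrative only;
+// `cb` stands in for a real grpc_closure*). Closures handed to ExecCtx::Run()
+// are not invoked immediately: they are appended to the thread-local closure
+// list and executed at the next Flush(), which the ExecCtx destructor also
+// performs.
+//
+//   grpc_core::ExecCtx exec_ctx;
+//   grpc_millis now = grpc_core::ExecCtx::Get()->Now();  // ms since g_start_time, cached
+//   grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE);  // queued, not yet run
+//   grpc_core::ExecCtx::Get()->Flush();  // runs `cb`; ~ExecCtx would flush as well
+//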
+ * + */ + +#include + +#include "src/core/lib/iomgr/executor.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/iomgr_internal.h" + +#define MAX_DEPTH 2 + +#define EXECUTOR_TRACE(format, ...) \ + do { \ + if (GRPC_TRACE_FLAG_ENABLED(executor_trace)) { \ + gpr_log(GPR_INFO, "EXECUTOR " format, __VA_ARGS__); \ + } \ + } while (0) + +#define EXECUTOR_TRACE0(str) \ + do { \ + if (GRPC_TRACE_FLAG_ENABLED(executor_trace)) { \ + gpr_log(GPR_INFO, "EXECUTOR " str); \ + } \ + } while (0) + +namespace grpc_core { +namespace { + +static GPR_THREAD_LOCAL(ThreadState*) g_this_thread_state; + +Executor* executors[static_cast(ExecutorType::NUM_EXECUTORS)]; + +void default_enqueue_short(grpc_closure* closure, grpc_error_handle error) { + executors[static_cast(ExecutorType::DEFAULT)]->Enqueue( + closure, error, true /* is_short */); +} + +void default_enqueue_long(grpc_closure* closure, grpc_error_handle error) { + executors[static_cast(ExecutorType::DEFAULT)]->Enqueue( + closure, error, false /* is_short */); +} + +void resolver_enqueue_short(grpc_closure* closure, grpc_error_handle error) { + executors[static_cast(ExecutorType::RESOLVER)]->Enqueue( + closure, error, true /* is_short */); +} + +void resolver_enqueue_long(grpc_closure* closure, grpc_error_handle error) { + executors[static_cast(ExecutorType::RESOLVER)]->Enqueue( + closure, error, false /* is_short */); +} + +using EnqueueFunc = void (*)(grpc_closure* closure, grpc_error_handle error); + +const EnqueueFunc + executor_enqueue_fns_[static_cast(ExecutorType::NUM_EXECUTORS)] + [static_cast(ExecutorJobType::NUM_JOB_TYPES)] = + {{default_enqueue_short, default_enqueue_long}, + {resolver_enqueue_short, resolver_enqueue_long}}; + +} // namespace + +TraceFlag executor_trace(false, "executor"); + +Executor::Executor(const char* name) : name_(name) { + adding_thread_lock_ = GPR_SPINLOCK_STATIC_INITIALIZER; + gpr_atm_rel_store(&num_threads_, 0); + max_threads_ = std::max(1u, 2 * gpr_cpu_num_cores()); +} + +void Executor::Init() { SetThreading(true); } + +size_t Executor::RunClosures(const char* executor_name, + grpc_closure_list list) { + size_t n = 0; + + // In the executor, the ExecCtx for the thread is declared in the executor + // thread itself, but this is the point where we could start seeing + // application-level callbacks. No need to create a new ExecCtx, though, + // since there already is one and it is flushed (but not destructed) in this + // function itself. The ApplicationCallbackExecCtx will have its callbacks + // invoked on its destruction, which will be after completing any closures in + // the executor's closure list (which were explicitly scheduled onto the + // executor). 
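+// Dispatch goes through the two-dimensional table above: Executor::Run()
+// (defined later in this file) indexes executor_enqueue_fns_ by
+// (ExecutorType, ExecutorJobType) and the selected function enqueues onto the
+// matching executor instance. A caller-side sketch (my_cb/my_arg are
+// placeholders; the third GRPC_CLOSURE_CREATE argument is the unused legacy
+// scheduler parameter):
+//
+//   grpc_closure* c = GRPC_CLOSURE_CREATE(my_cb, my_arg, nullptr);
+//   grpc_core::Executor::Run(c, GRPC_ERROR_NONE,
+//                            grpc_core::ExecutorType::DEFAULT,
+//                            grpc_core::ExecutorJobType::SHORT);
+//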
+ grpc_core::ApplicationCallbackExecCtx callback_exec_ctx( + GRPC_APP_CALLBACK_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + + grpc_closure* c = list.head; + while (c != nullptr) { + grpc_closure* next = c->next_data.next; + grpc_error_handle error = c->error_data.error; +#ifndef NDEBUG + EXECUTOR_TRACE("(%s) run %p [created by %s:%d]", executor_name, c, + c->file_created, c->line_created); + c->scheduled = false; +#else + EXECUTOR_TRACE("(%s) run %p", executor_name, c); +#endif + c->cb(c->cb_arg, error); + GRPC_ERROR_UNREF(error); + c = next; + n++; + grpc_core::ExecCtx::Get()->Flush(); + } + + return n; +} + +bool Executor::IsThreaded() const { + return gpr_atm_acq_load(&num_threads_) > 0; +} + +void Executor::SetThreading(bool threading) { + gpr_atm curr_num_threads = gpr_atm_acq_load(&num_threads_); + EXECUTOR_TRACE("(%s) SetThreading(%d) begin", name_, threading); + + if (threading) { + if (curr_num_threads > 0) { + EXECUTOR_TRACE("(%s) SetThreading(true). curr_num_threads > 0", name_); + return; + } + + GPR_ASSERT(num_threads_ == 0); + gpr_atm_rel_store(&num_threads_, 1); + thd_state_ = static_cast( + gpr_zalloc(sizeof(ThreadState) * max_threads_)); + + for (size_t i = 0; i < max_threads_; i++) { + gpr_mu_init(&thd_state_[i].mu); + gpr_cv_init(&thd_state_[i].cv); + thd_state_[i].id = i; + thd_state_[i].name = name_; + thd_state_[i].thd = grpc_core::Thread(); + thd_state_[i].elems = GRPC_CLOSURE_LIST_INIT; + } + + thd_state_[0].thd = + grpc_core::Thread(name_, &Executor::ThreadMain, &thd_state_[0]); + thd_state_[0].thd.Start(); + } else { // !threading + if (curr_num_threads == 0) { + EXECUTOR_TRACE("(%s) SetThreading(false). curr_num_threads == 0", name_); + return; + } + + for (size_t i = 0; i < max_threads_; i++) { + gpr_mu_lock(&thd_state_[i].mu); + thd_state_[i].shutdown = true; + gpr_cv_signal(&thd_state_[i].cv); + gpr_mu_unlock(&thd_state_[i].mu); + } + + /* Ensure no thread is adding a new thread. Once this is past, then no + * thread will try to add a new one either (since shutdown is true) */ + gpr_spinlock_lock(&adding_thread_lock_); + gpr_spinlock_unlock(&adding_thread_lock_); + + curr_num_threads = gpr_atm_no_barrier_load(&num_threads_); + for (gpr_atm i = 0; i < curr_num_threads; i++) { + thd_state_[i].thd.Join(); + EXECUTOR_TRACE("(%s) Thread %" PRIdPTR " of %" PRIdPTR " joined", name_, + i + 1, curr_num_threads); + } + + gpr_atm_rel_store(&num_threads_, 0); + for (size_t i = 0; i < max_threads_; i++) { + gpr_mu_destroy(&thd_state_[i].mu); + gpr_cv_destroy(&thd_state_[i].cv); + RunClosures(thd_state_[i].name, thd_state_[i].elems); + } + + gpr_free(thd_state_); + + // grpc_iomgr_shutdown_background_closure() will close all the registered + // fds in the background poller, and wait for all pending closures to + // finish. Thus, never call Executor::SetThreading(false) in the middle of + // an application. + // TODO(guantaol): create another method to finish all the pending closures + // registered in the background poller by grpc_core::Executor. 
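+// Once threading is switched off (as grpc_prefork() in fork_posix.cc later in
+// this patch does via SetThreadingAll(false)), Enqueue() no longer hands work
+// to pool threads: with num_threads_ == 0 it appends the closure to the
+// calling thread's ExecCtx instead. Caller-side sketch (`c` is a placeholder
+// closure):
+//
+//   grpc_core::Executor::SetThreadingAll(false);
+//   grpc_core::Executor::Run(c, GRPC_ERROR_NONE,
+//                            grpc_core::ExecutorType::DEFAULT,
+//                            grpc_core::ExecutorJobType::SHORT);  // queued inline
+//   grpc_core::ExecCtx::Get()->Flush();                           // runs here
+//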
+ grpc_iomgr_platform_shutdown_background_closure(); + } + + EXECUTOR_TRACE("(%s) SetThreading(%d) done", name_, threading); +} + +void Executor::Shutdown() { SetThreading(false); } + +void Executor::ThreadMain(void* arg) { + ThreadState* ts = static_cast(arg); + g_this_thread_state = ts; + + grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + + size_t subtract_depth = 0; + for (;;) { + EXECUTOR_TRACE("(%s) [%" PRIdPTR "]: step (sub_depth=%" PRIdPTR ")", + ts->name, ts->id, subtract_depth); + + gpr_mu_lock(&ts->mu); + ts->depth -= subtract_depth; + // Wait for closures to be enqueued or for the executor to be shutdown + while (grpc_closure_list_empty(ts->elems) && !ts->shutdown) { + ts->queued_long_job = false; + gpr_cv_wait(&ts->cv, &ts->mu, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + + if (ts->shutdown) { + EXECUTOR_TRACE("(%s) [%" PRIdPTR "]: shutdown", ts->name, ts->id); + gpr_mu_unlock(&ts->mu); + break; + } + + grpc_closure_list closures = ts->elems; + ts->elems = GRPC_CLOSURE_LIST_INIT; + gpr_mu_unlock(&ts->mu); + + EXECUTOR_TRACE("(%s) [%" PRIdPTR "]: execute", ts->name, ts->id); + + grpc_core::ExecCtx::Get()->InvalidateNow(); + subtract_depth = RunClosures(ts->name, closures); + } + + g_this_thread_state = nullptr; +} + +void Executor::Enqueue(grpc_closure* closure, grpc_error_handle error, + bool is_short) { + bool retry_push; + + do { + retry_push = false; + size_t cur_thread_count = + static_cast(gpr_atm_acq_load(&num_threads_)); + + // If the number of threads is zero(i.e either the executor is not threaded + // or already shutdown), then queue the closure on the exec context itself + if (cur_thread_count == 0) { +#ifndef NDEBUG + EXECUTOR_TRACE("(%s) schedule %p (created %s:%d) inline", name_, closure, + closure->file_created, closure->line_created); +#else + EXECUTOR_TRACE("(%s) schedule %p inline", name_, closure); +#endif + grpc_closure_list_append(grpc_core::ExecCtx::Get()->closure_list(), + closure, error); + return; + } + + if (grpc_iomgr_platform_add_closure_to_background_poller(closure, error)) { + return; + } + + ThreadState* ts = g_this_thread_state; + if (ts == nullptr) { + ts = &thd_state_[grpc_core::HashPointer(grpc_core::ExecCtx::Get(), + cur_thread_count)]; + } + + ThreadState* orig_ts = ts; + bool try_new_thread = false; + + for (;;) { +#ifndef NDEBUG + EXECUTOR_TRACE( + "(%s) try to schedule %p (%s) (created %s:%d) to thread " + "%" PRIdPTR, + name_, closure, is_short ? "short" : "long", closure->file_created, + closure->line_created, ts->id); +#else + EXECUTOR_TRACE("(%s) try to schedule %p (%s) to thread %" PRIdPTR, name_, + closure, is_short ? "short" : "long", ts->id); +#endif + + gpr_mu_lock(&ts->mu); + if (ts->queued_long_job) { + // if there's a long job queued, we never queue anything else to this + // queue (since long jobs can take 'infinite' time and we need to + // guarantee no starvation). Spin through queues and try again + gpr_mu_unlock(&ts->mu); + size_t idx = ts->id; + ts = &thd_state_[(idx + 1) % cur_thread_count]; + if (ts == orig_ts) { + // We cycled through all the threads. Retry enqueue again by creating + // a new thread + // + // TODO (sreek): There is a potential issue here. We are + // unconditionally setting try_new_thread to true here. What if the + // executor is shutdown OR if cur_thread_count is already equal to + // max_threads ? 
+ // (Fortunately, this is not an issue yet (as of july 2018) because + // there is only one instance of long job in gRPC and hence we will + // not hit this code path) + retry_push = true; + try_new_thread = true; + break; + } + + continue; // Try the next thread-state + } + + // == Found the thread state (i.e thread) to enqueue this closure! == + + // Also, if this thread has been waiting for closures, wake it up. + // - If grpc_closure_list_empty() is true and the Executor is not + // shutdown, it means that the thread must be waiting in ThreadMain() + // - Note that gpr_cv_signal() won't immediately wakeup the thread. That + // happens after we release the mutex &ts->mu a few lines below + if (grpc_closure_list_empty(ts->elems) && !ts->shutdown) { + gpr_cv_signal(&ts->cv); + } + + grpc_closure_list_append(&ts->elems, closure, error); + + // If we already queued more than MAX_DEPTH number of closures on this + // thread, use this as a hint to create more threads + ts->depth++; + try_new_thread = ts->depth > MAX_DEPTH && + cur_thread_count < max_threads_ && !ts->shutdown; + + ts->queued_long_job = !is_short; + + gpr_mu_unlock(&ts->mu); + break; + } + + if (try_new_thread && gpr_spinlock_trylock(&adding_thread_lock_)) { + cur_thread_count = static_cast(gpr_atm_acq_load(&num_threads_)); + if (cur_thread_count < max_threads_) { + // Increment num_threads (safe to do a store instead of a cas because we + // always increment num_threads under the 'adding_thread_lock') + gpr_atm_rel_store(&num_threads_, cur_thread_count + 1); + + thd_state_[cur_thread_count].thd = grpc_core::Thread( + name_, &Executor::ThreadMain, &thd_state_[cur_thread_count]); + thd_state_[cur_thread_count].thd.Start(); + } + gpr_spinlock_unlock(&adding_thread_lock_); + } + } while (retry_push); +} + +// Executor::InitAll() and Executor::ShutdownAll() functions are called in the +// the grpc_init() and grpc_shutdown() code paths which are protected by a +// global mutex. So it is okay to assume that these functions are thread-safe +void Executor::InitAll() { + EXECUTOR_TRACE0("Executor::InitAll() enter"); + + // Return if Executor::InitAll() is already called earlier + if (executors[static_cast(ExecutorType::DEFAULT)] != nullptr) { + GPR_ASSERT(executors[static_cast(ExecutorType::RESOLVER)] != + nullptr); + return; + } + + executors[static_cast(ExecutorType::DEFAULT)] = + new Executor("default-executor"); + executors[static_cast(ExecutorType::RESOLVER)] = + new Executor("resolver-executor"); + + executors[static_cast(ExecutorType::DEFAULT)]->Init(); + executors[static_cast(ExecutorType::RESOLVER)]->Init(); + + EXECUTOR_TRACE0("Executor::InitAll() done"); +} + +void Executor::Run(grpc_closure* closure, grpc_error_handle error, + ExecutorType executor_type, ExecutorJobType job_type) { + executor_enqueue_fns_[static_cast(executor_type)] + [static_cast(job_type)](closure, error); +} + +void Executor::ShutdownAll() { + EXECUTOR_TRACE0("Executor::ShutdownAll() enter"); + + // Return if Executor:SshutdownAll() is already called earlier + if (executors[static_cast(ExecutorType::DEFAULT)] == nullptr) { + GPR_ASSERT(executors[static_cast(ExecutorType::RESOLVER)] == + nullptr); + return; + } + + executors[static_cast(ExecutorType::DEFAULT)]->Shutdown(); + executors[static_cast(ExecutorType::RESOLVER)]->Shutdown(); + + // Delete the executor objects. 
+ // + // NOTE: It is important to call Shutdown() on all executors first before + // calling delete because it is possible for one executor (that is not + // shutdown yet) to call Enqueue() on a different executor which is already + // shutdown. This is legal and in such cases, the Enqueue() operation + // effectively "fails" and enqueues that closure on the calling thread's + // exec_ctx. + // + // By ensuring that all executors are shutdown first, we are also ensuring + // that no thread is active across all executors. + + delete executors[static_cast(ExecutorType::DEFAULT)]; + delete executors[static_cast(ExecutorType::RESOLVER)]; + executors[static_cast(ExecutorType::DEFAULT)] = nullptr; + executors[static_cast(ExecutorType::RESOLVER)] = nullptr; + + EXECUTOR_TRACE0("Executor::ShutdownAll() done"); +} + +bool Executor::IsThreaded(ExecutorType executor_type) { + GPR_ASSERT(executor_type < ExecutorType::NUM_EXECUTORS); + return executors[static_cast(executor_type)]->IsThreaded(); +} + +bool Executor::IsThreadedDefault() { + return Executor::IsThreaded(ExecutorType::DEFAULT); +} + +void Executor::SetThreadingAll(bool enable) { + EXECUTOR_TRACE("Executor::SetThreadingAll(%d) called", enable); + for (size_t i = 0; i < static_cast(ExecutorType::NUM_EXECUTORS); + i++) { + executors[i]->SetThreading(enable); + } +} + +void Executor::SetThreadingDefault(bool enable) { + EXECUTOR_TRACE("Executor::SetThreadingDefault(%d) called", enable); + executors[static_cast(ExecutorType::DEFAULT)]->SetThreading(enable); +} + +void grpc_executor_global_init() {} + +} // namespace grpc_core diff --git a/src/core/lib/iomgr/executor/mpmcqueue.cc b/src/core/lib/iomgr/executor/mpmcqueue.cc new file mode 100644 index 00000000..e4d9d21d --- /dev/null +++ b/src/core/lib/iomgr/executor/mpmcqueue.cc @@ -0,0 +1,182 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/executor/mpmcqueue.h" + +namespace grpc_core { + +DebugOnlyTraceFlag grpc_thread_pool_trace(false, "thread_pool"); + +inline void* InfLenFIFOQueue::PopFront() { + // Caller should already check queue is not empty and has already held the + // mutex. This function will assume that there is at least one element in the + // queue (i.e. queue_head_->content is valid). + void* result = queue_head_->content; + count_.store(count_.load(std::memory_order_relaxed) - 1, + std::memory_order_relaxed); + + // Updates Stats when trace flag turned on. 
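+// These per-queue statistics are only active in debug builds
+// (grpc_thread_pool_trace is a DebugOnlyTraceFlag) and are gated on the
+// "thread_pool" tracer registered above. One way to enable them, assuming a
+// POSIX environment and that this runs before grpc_init():
+//
+//   setenv("GRPC_TRACE", "thread_pool", 1 /* overwrite */);
+//   setenv("GRPC_VERBOSITY", "info", 1 /* overwrite */);
+//   grpc_init();
+//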
+ if (GRPC_TRACE_FLAG_ENABLED(grpc_thread_pool_trace)) { + gpr_timespec wait_time = + gpr_time_sub(gpr_now(GPR_CLOCK_MONOTONIC), queue_head_->insert_time); + stats_.num_completed++; + stats_.total_queue_time = gpr_time_add(stats_.total_queue_time, wait_time); + stats_.max_queue_time = gpr_time_max( + gpr_convert_clock_type(stats_.max_queue_time, GPR_TIMESPAN), wait_time); + + if (count_.load(std::memory_order_relaxed) == 0) { + stats_.busy_queue_time = + gpr_time_add(stats_.busy_queue_time, + gpr_time_sub(gpr_now(GPR_CLOCK_MONOTONIC), busy_time)); + } + + gpr_log(GPR_INFO, + "[InfLenFIFOQueue PopFront] num_completed: %" PRIu64 + " total_queue_time: %f max_queue_time: %f busy_queue_time: %f", + stats_.num_completed, + gpr_timespec_to_micros(stats_.total_queue_time), + gpr_timespec_to_micros(stats_.max_queue_time), + gpr_timespec_to_micros(stats_.busy_queue_time)); + } + + queue_head_ = queue_head_->next; + // Signal waiting thread + if (count_.load(std::memory_order_relaxed) > 0) { + TopWaiter()->cv.Signal(); + } + + return result; +} + +InfLenFIFOQueue::Node* InfLenFIFOQueue::AllocateNodes(int num) { + num_nodes_ = num_nodes_ + num; + Node* new_chunk = new Node[num]; + new_chunk[0].next = &new_chunk[1]; + new_chunk[num - 1].prev = &new_chunk[num - 2]; + for (int i = 1; i < num - 1; ++i) { + new_chunk[i].prev = &new_chunk[i - 1]; + new_chunk[i].next = &new_chunk[i + 1]; + } + return new_chunk; +} + +InfLenFIFOQueue::InfLenFIFOQueue() { + delete_list_size_ = kDeleteListInitSize; + delete_list_ = new Node*[delete_list_size_]; + + Node* new_chunk = AllocateNodes(kQueueInitNumNodes); + delete_list_[delete_list_count_++] = new_chunk; + queue_head_ = queue_tail_ = new_chunk; + new_chunk[0].prev = &new_chunk[kQueueInitNumNodes - 1]; + new_chunk[kQueueInitNumNodes - 1].next = &new_chunk[0]; + + waiters_.next = &waiters_; + waiters_.prev = &waiters_; +} + +InfLenFIFOQueue::~InfLenFIFOQueue() { + GPR_ASSERT(count_.load(std::memory_order_relaxed) == 0); + for (size_t i = 0; i < delete_list_count_; ++i) { + delete[] delete_list_[i]; + } + delete[] delete_list_; +} + +void InfLenFIFOQueue::Put(void* elem) { + MutexLock l(&mu_); + + int curr_count = count_.load(std::memory_order_relaxed); + + if (queue_tail_ == queue_head_ && curr_count != 0) { + // List is full. Expands list to double size by inserting new chunk of nodes + Node* new_chunk = AllocateNodes(curr_count); + delete_list_[delete_list_count_++] = new_chunk; + // Expands delete list on full. 
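+// Producer/consumer usage sketch for this queue (WorkItem is a placeholder
+// type; the queue stores raw void* and never owns the payload). Put() never
+// blocks, Get() blocks until an element is available, and the destructor
+// asserts that everything has been drained:
+//
+//   grpc_core::InfLenFIFOQueue q;
+//   q.Put(static_cast<void*>(new WorkItem));               // producer thread
+//   auto* item = static_cast<WorkItem*>(q.Get(nullptr));   // consumer thread, may block
+//   delete item;
+//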
+ if (delete_list_count_ == delete_list_size_) { + delete_list_size_ = delete_list_size_ * 2; + delete_list_ = new Node*[delete_list_size_]; + } + new_chunk[0].prev = queue_tail_->prev; + new_chunk[curr_count - 1].next = queue_head_; + queue_tail_->prev->next = new_chunk; + queue_head_->prev = &new_chunk[curr_count - 1]; + queue_tail_ = new_chunk; + } + queue_tail_->content = static_cast(elem); + + // Updates Stats info + if (GRPC_TRACE_FLAG_ENABLED(grpc_thread_pool_trace)) { + stats_.num_started++; + gpr_log(GPR_INFO, "[InfLenFIFOQueue Put] num_started: %" PRIu64, + stats_.num_started); + auto current_time = gpr_now(GPR_CLOCK_MONOTONIC); + if (curr_count == 0) { + busy_time = current_time; + } + queue_tail_->insert_time = current_time; + } + + count_.store(curr_count + 1, std::memory_order_relaxed); + queue_tail_ = queue_tail_->next; + + TopWaiter()->cv.Signal(); +} + +void* InfLenFIFOQueue::Get(gpr_timespec* wait_time) { + MutexLock l(&mu_); + + if (count_.load(std::memory_order_relaxed) == 0) { + gpr_timespec start_time; + if (GRPC_TRACE_FLAG_ENABLED(grpc_thread_pool_trace) && + wait_time != nullptr) { + start_time = gpr_now(GPR_CLOCK_MONOTONIC); + } + + Waiter self; + PushWaiter(&self); + do { + self.cv.Wait(&mu_); + } while (count_.load(std::memory_order_relaxed) == 0); + RemoveWaiter(&self); + if (GRPC_TRACE_FLAG_ENABLED(grpc_thread_pool_trace) && + wait_time != nullptr) { + *wait_time = gpr_time_sub(gpr_now(GPR_CLOCK_MONOTONIC), start_time); + } + } + GPR_DEBUG_ASSERT(count_.load(std::memory_order_relaxed) > 0); + return PopFront(); +} + +void InfLenFIFOQueue::PushWaiter(Waiter* waiter) { + waiter->next = waiters_.next; + waiter->prev = &waiters_; + waiter->next->prev = waiter; + waiter->prev->next = waiter; +} + +void InfLenFIFOQueue::RemoveWaiter(Waiter* waiter) { + GPR_DEBUG_ASSERT(waiter != &waiters_); + waiter->next->prev = waiter->prev; + waiter->prev->next = waiter->next; +} + +InfLenFIFOQueue::Waiter* InfLenFIFOQueue::TopWaiter() { return waiters_.next; } + +} // namespace grpc_core diff --git a/src/core/lib/iomgr/executor/threadpool.cc b/src/core/lib/iomgr/executor/threadpool.cc new file mode 100644 index 00000000..8bd954cc --- /dev/null +++ b/src/core/lib/iomgr/executor/threadpool.cc @@ -0,0 +1,136 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/executor/threadpool.h" + +namespace grpc_core { + +void ThreadPoolWorker::Run() { + while (true) { + void* elem; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_thread_pool_trace)) { + // Updates stats and print + gpr_timespec wait_time = gpr_time_0(GPR_TIMESPAN); + elem = queue_->Get(&wait_time); + stats_.sleep_time = gpr_time_add(stats_.sleep_time, wait_time); + gpr_log(GPR_INFO, + "ThreadPool Worker [%s %d] Stats: sleep_time %f", + thd_name_, index_, gpr_timespec_to_micros(stats_.sleep_time)); + } else { + elem = queue_->Get(nullptr); + } + if (elem == nullptr) { + break; + } + // Runs closure + auto* closure = static_cast(elem); + closure->functor_run(closure, closure->internal_success); + } +} + +void ThreadPool::SharedThreadPoolConstructor() { + // All worker threads in thread pool must be joinable. + thread_options_.set_joinable(true); + + // Create at least 1 worker thread. + if (num_threads_ <= 0) num_threads_ = 1; + + queue_ = new InfLenFIFOQueue(); + threads_ = static_cast( + gpr_zalloc(num_threads_ * sizeof(ThreadPoolWorker*))); + for (int i = 0; i < num_threads_; ++i) { + threads_[i] = new ThreadPoolWorker(thd_name_, queue_, thread_options_, i); + threads_[i]->Start(); + } +} + +size_t ThreadPool::DefaultStackSize() { +#if defined(__ANDROID__) || defined(__APPLE__) + return 1952 * 1024; +#else + return 64 * 1024; +#endif +} + +void ThreadPool::AssertHasNotBeenShutDown() { + // For debug checking purpose, using RELAXED order is sufficient. + GPR_DEBUG_ASSERT(!shut_down_.load(std::memory_order_relaxed)); +} + +ThreadPool::ThreadPool(int num_threads) : num_threads_(num_threads) { + thd_name_ = "ThreadPoolWorker"; + thread_options_ = Thread::Options(); + thread_options_.set_stack_size(DefaultStackSize()); + SharedThreadPoolConstructor(); +} + +ThreadPool::ThreadPool(int num_threads, const char* thd_name) + : num_threads_(num_threads), thd_name_(thd_name) { + thread_options_ = Thread::Options(); + thread_options_.set_stack_size(DefaultStackSize()); + SharedThreadPoolConstructor(); +} + +ThreadPool::ThreadPool(int num_threads, const char* thd_name, + const Thread::Options& thread_options) + : num_threads_(num_threads), + thd_name_(thd_name), + thread_options_(thread_options) { + if (thread_options_.stack_size() == 0) { + thread_options_.set_stack_size(DefaultStackSize()); + } + SharedThreadPoolConstructor(); +} + +ThreadPool::~ThreadPool() { + // For debug checking purpose, using RELAXED order is sufficient. + shut_down_.store(true, std::memory_order_relaxed); + + for (int i = 0; i < num_threads_; ++i) { + queue_->Put(nullptr); + } + + for (int i = 0; i < num_threads_; ++i) { + threads_[i]->Join(); + } + + for (int i = 0; i < num_threads_; ++i) { + delete threads_[i]; + } + gpr_free(threads_); + delete queue_; +} + +void ThreadPool::Add(grpc_completion_queue_functor* closure) { + AssertHasNotBeenShutDown(); + queue_->Put(static_cast(closure)); +} + +int ThreadPool::num_pending_closures() const { return queue_->count(); } + +int ThreadPool::pool_capacity() const { return num_threads_; } + +const Thread::Options& ThreadPool::thread_options() const { + return thread_options_; +} + +const char* ThreadPool::thread_name() const { return thd_name_; } +} // namespace grpc_core diff --git a/src/core/lib/iomgr/fork_posix.cc b/src/core/lib/iomgr/fork_posix.cc new file mode 100644 index 00000000..82654b5a --- /dev/null +++ b/src/core/lib/iomgr/fork_posix.cc @@ -0,0 +1,119 @@ +/* + * + * Copyright 2017 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_FORK + +#ifdef GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK +#include +#endif + +#include + +#include +#include +#include + +#include "src/core/lib/gprpp/fork.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" + +/* + * NOTE: FORKING IS NOT GENERALLY SUPPORTED, THIS IS ONLY INTENDED TO WORK + * AROUND VERY SPECIFIC USE CASES. + */ + +namespace { +bool skipped_handler = true; +bool registered_handlers = false; +} // namespace + +void grpc_prefork() { + skipped_handler = true; + // This may be called after core shuts down, so verify initialized before + // instantiating an ExecCtx. + if (!grpc_is_initialized()) { + return; + } + grpc_core::ExecCtx exec_ctx; + if (!grpc_core::Fork::Enabled()) { + gpr_log(GPR_ERROR, + "Fork support not enabled; try running with the " + "environment variable GRPC_ENABLE_FORK_SUPPORT=1"); + return; + } + const char* poll_strategy_name = grpc_get_poll_strategy_name(); + if (poll_strategy_name == nullptr || + (strcmp(poll_strategy_name, "epoll1") != 0 && + strcmp(poll_strategy_name, "poll") != 0)) { + gpr_log(GPR_INFO, + "Fork support is only compatible with the epoll1 and poll polling " + "strategies"); + } + if (!grpc_core::Fork::BlockExecCtx()) { + gpr_log(GPR_INFO, + "Other threads are currently calling into gRPC, skipping fork() " + "handlers"); + return; + } + grpc_timer_manager_set_threading(false); + grpc_core::Executor::SetThreadingAll(false); + grpc_core::ExecCtx::Get()->Flush(); + grpc_core::Fork::AwaitThreads(); + skipped_handler = false; +} + +void grpc_postfork_parent() { + if (!skipped_handler) { + grpc_core::Fork::AllowExecCtx(); + grpc_core::ExecCtx exec_ctx; + grpc_timer_manager_set_threading(true); + grpc_core::Executor::SetThreadingAll(true); + } +} + +void grpc_postfork_child() { + if (!skipped_handler) { + grpc_core::Fork::AllowExecCtx(); + grpc_core::ExecCtx exec_ctx; + grpc_core::Fork::child_postfork_func reset_polling_engine = + grpc_core::Fork::GetResetChildPollingEngineFunc(); + if (reset_polling_engine != nullptr) { + reset_polling_engine(); + } + grpc_timer_manager_set_threading(true); + grpc_core::Executor::SetThreadingAll(true); + } +} + +void grpc_fork_handlers_auto_register() { + if (grpc_core::Fork::Enabled() & !registered_handlers) { +#ifdef GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK + pthread_atfork(grpc_prefork, grpc_postfork_parent, grpc_postfork_child); + registered_handlers = true; +#endif // GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK + } +} + +#endif // GRPC_POSIX_FORK diff --git a/src/core/lib/iomgr/fork_windows.cc b/src/core/lib/iomgr/fork_windows.cc new file mode 100644 index 00000000..798f671b --- /dev/null +++ b/src/core/lib/iomgr/fork_windows.cc @@ -0,0 +1,41 @@ +/* + * + * Copyright 2017 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifndef GRPC_POSIX_FORK + +#include +#include + +/* + * NOTE: FORKING IS NOT GENERALLY SUPPORTED, THIS IS ONLY INTENDED TO WORK + * AROUND VERY SPECIFIC USE CASES. + */ + +void grpc_prefork() { gpr_log(GPR_ERROR, "Forking not supported on Windows"); } + +void grpc_postfork_parent() {} + +void grpc_postfork_child() {} + +void grpc_fork_handlers_auto_register() {} + +#endif // GRPC_POSIX_FORK diff --git a/src/core/lib/iomgr/gethostname_fallback.cc b/src/core/lib/iomgr/gethostname_fallback.cc new file mode 100644 index 00000000..65ae8187 --- /dev/null +++ b/src/core/lib/iomgr/gethostname_fallback.cc @@ -0,0 +1,30 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/gethostname.h" +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_GETHOSTNAME_FALLBACK + +#include + +char* grpc_gethostname() { return NULL; } + +#endif // GRPC_GETHOSTNAME_FALLBACK diff --git a/src/core/lib/iomgr/gethostname_host_name_max.cc b/src/core/lib/iomgr/gethostname_host_name_max.cc new file mode 100644 index 00000000..79f5daa8 --- /dev/null +++ b/src/core/lib/iomgr/gethostname_host_name_max.cc @@ -0,0 +1,40 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/gethostname.h" +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_HOST_NAME_MAX + +#include +#include + +#include + +char* grpc_gethostname() { + char* hostname = static_cast(gpr_malloc(HOST_NAME_MAX)); + if (gethostname(hostname, HOST_NAME_MAX) != 0) { + gpr_free(hostname); + return nullptr; + } + return hostname; +} + +#endif // GRPC_POSIX_HOST_NAME_MAX diff --git a/src/core/lib/iomgr/gethostname_sysconf.cc b/src/core/lib/iomgr/gethostname_sysconf.cc new file mode 100644 index 00000000..92c5de33 --- /dev/null +++ b/src/core/lib/iomgr/gethostname_sysconf.cc @@ -0,0 +1,40 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/gethostname.h" +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SYSCONF + +#include + +#include + +char* grpc_gethostname() { + size_t host_name_max = (size_t)sysconf(_SC_HOST_NAME_MAX); + char* hostname = (char*)gpr_malloc(host_name_max); + if (gethostname(hostname, host_name_max) != 0) { + gpr_free(hostname); + return nullptr; + } + return hostname; +} + +#endif // GRPC_POSIX_SYSCONF diff --git a/src/core/lib/iomgr/grpc_if_nametoindex_posix.cc b/src/core/lib/iomgr/grpc_if_nametoindex_posix.cc new file mode 100644 index 00000000..8916eac9 --- /dev/null +++ b/src/core/lib/iomgr/grpc_if_nametoindex_posix.cc @@ -0,0 +1,42 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#if GRPC_IF_NAMETOINDEX == 1 && defined(GRPC_POSIX_SOCKET_IF_NAMETOINDEX) + +#include +#include + +#include + +#include "src/core/lib/iomgr/grpc_if_nametoindex.h" + +uint32_t grpc_if_nametoindex(char* name) { + uint32_t out = if_nametoindex(name); + if (out == 0) { + gpr_log(GPR_DEBUG, "if_nametoindex failed for name %s. errno %d", name, + errno); + } + return out; +} + +#endif /* GRPC_IF_NAMETOINDEX == 1 && \ + defined(GRPC_POSIX_SOCKET_IF_NAMETOINDEX) */ diff --git a/src/core/lib/iomgr/grpc_if_nametoindex_unsupported.cc b/src/core/lib/iomgr/grpc_if_nametoindex_unsupported.cc new file mode 100644 index 00000000..63062433 --- /dev/null +++ b/src/core/lib/iomgr/grpc_if_nametoindex_unsupported.cc @@ -0,0 +1,38 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#if GRPC_IF_NAMETOINDEX == 0 || !defined(GRPC_POSIX_SOCKET_IF_NAMETOINDEX) + +#include + +#include "src/core/lib/iomgr/grpc_if_nametoindex.h" + +uint32_t grpc_if_nametoindex(char* name) { + gpr_log(GPR_DEBUG, + "Not attempting to convert interface name %s to index for current " + "platform.", + name); + return 0; +} + +#endif /* GRPC_IF_NAMETOINDEX == 0 || \ + !defined(GRPC_POSIX_SOCKET_IF_NAMETOINDEX) */ diff --git a/src/core/lib/iomgr/internal_errqueue.cc b/src/core/lib/iomgr/internal_errqueue.cc new file mode 100644 index 00000000..ac644b2f --- /dev/null +++ b/src/core/lib/iomgr/internal_errqueue.cc @@ -0,0 +1,68 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/internal_errqueue.h" + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_TCP + +#include +#include +#include +#include + +namespace grpc_core { +static bool errqueue_supported = false; + +bool kernel_supports_errqueue() { return errqueue_supported; } + +void grpc_errqueue_init() { +/* Both-compile time and run-time linux kernel versions should be at least 4.0.0 + */ +#ifdef GRPC_LINUX_ERRQUEUE + struct utsname buffer; + if (uname(&buffer) != 0) { + gpr_log(GPR_ERROR, "uname: %s", strerror(errno)); + return; + } + char* release = buffer.release; + if (release == nullptr) { + return; + } + + if (strtol(release, nullptr, 10) >= 4) { + errqueue_supported = true; + } else { + gpr_log(GPR_DEBUG, "ERRQUEUE support not enabled"); + } +#endif /* GRPC_LINUX_ERRQUEUE */ +} +} /* namespace grpc_core */ + +#else + +namespace grpc_core { +void grpc_errqueue_init() {} +} /* namespace grpc_core */ + +#endif /* GRPC_POSIX_SOCKET_TCP */ diff --git a/src/core/lib/iomgr/iocp_windows.cc b/src/core/lib/iomgr/iocp_windows.cc new file mode 100644 index 00000000..d4c76eca --- /dev/null +++ b/src/core/lib/iomgr/iocp_windows.cc @@ -0,0 +1,158 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include + +#include + +#include +#include +#include + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/timer.h" + +static ULONG g_iocp_kick_token; +static OVERLAPPED g_iocp_custom_overlap; + +static gpr_atm g_custom_events = 0; + +static HANDLE g_iocp; + +static DWORD deadline_to_millis_timeout(grpc_millis deadline) { + if (deadline == GRPC_MILLIS_INF_FUTURE) { + return INFINITE; + } + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + if (deadline < now) return 0; + grpc_millis timeout = deadline - now; + if (timeout > std::numeric_limits::max()) return INFINITE; + return static_cast(deadline - now); +} + +grpc_iocp_work_status grpc_iocp_work(grpc_millis deadline) { + BOOL success; + DWORD bytes = 0; + DWORD flags = 0; + ULONG_PTR completion_key; + LPOVERLAPPED overlapped; + grpc_winsocket* socket; + grpc_winsocket_callback_info* info; + GRPC_STATS_INC_SYSCALL_POLL(); + success = + GetQueuedCompletionStatus(g_iocp, &bytes, &completion_key, &overlapped, + deadline_to_millis_timeout(deadline)); + grpc_core::ExecCtx::Get()->InvalidateNow(); + if (success == 0 && overlapped == NULL) { + return GRPC_IOCP_WORK_TIMEOUT; + } + GPR_ASSERT(completion_key && overlapped); + if (overlapped == &g_iocp_custom_overlap) { + gpr_atm_full_fetch_add(&g_custom_events, -1); + if (completion_key == (ULONG_PTR)&g_iocp_kick_token) { + /* We were awoken from a kick. */ + return GRPC_IOCP_WORK_KICK; + } + gpr_log(GPR_ERROR, "Unknown custom completion key."); + abort(); + } + + socket = (grpc_winsocket*)completion_key; + if (overlapped == &socket->write_info.overlapped) { + info = &socket->write_info; + } else if (overlapped == &socket->read_info.overlapped) { + info = &socket->read_info; + } else { + abort(); + } + if (socket->shutdown_called) { + info->bytes_transferred = 0; + info->wsa_error = WSA_OPERATION_ABORTED; + } else { + success = WSAGetOverlappedResult(socket->socket, &info->overlapped, &bytes, + FALSE, &flags); + info->bytes_transferred = bytes; + info->wsa_error = success ? 
0 : WSAGetLastError(); + } + GPR_ASSERT(overlapped == &info->overlapped); + grpc_socket_become_ready(socket, info); + return GRPC_IOCP_WORK_WORK; +} + +void grpc_iocp_init(void) { + g_iocp = + CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, (ULONG_PTR)NULL, 0); + GPR_ASSERT(g_iocp); +} + +void grpc_iocp_kick(void) { + BOOL success; + + gpr_atm_full_fetch_add(&g_custom_events, 1); + success = PostQueuedCompletionStatus(g_iocp, 0, (ULONG_PTR)&g_iocp_kick_token, + &g_iocp_custom_overlap); + GPR_ASSERT(success); +} + +void grpc_iocp_flush(void) { + grpc_core::ExecCtx exec_ctx; + grpc_iocp_work_status work_status; + + do { + work_status = grpc_iocp_work(GRPC_MILLIS_INF_PAST); + } while (work_status == GRPC_IOCP_WORK_KICK || + grpc_core::ExecCtx::Get()->Flush()); +} + +void grpc_iocp_shutdown(void) { + grpc_core::ExecCtx exec_ctx; + while (gpr_atm_acq_load(&g_custom_events)) { + grpc_iocp_work(GRPC_MILLIS_INF_FUTURE); + grpc_core::ExecCtx::Get()->Flush(); + } + + GPR_ASSERT(CloseHandle(g_iocp)); +} + +void grpc_iocp_add_socket(grpc_winsocket* socket) { + HANDLE ret; + if (socket->added_to_iocp) return; + ret = CreateIoCompletionPort((HANDLE)socket->socket, g_iocp, + (uintptr_t)socket, 0); + if (!ret) { + char* utf8_message = gpr_format_message(WSAGetLastError()); + gpr_log(GPR_ERROR, "Unable to add socket to iocp: %s", utf8_message); + gpr_free(utf8_message); + __debugbreak(); + abort(); + } + socket->added_to_iocp = 1; + GPR_ASSERT(ret == g_iocp); +} + +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/iomgr.cc b/src/core/lib/iomgr/iomgr.cc new file mode 100644 index 00000000..8bfcebf9 --- /dev/null +++ b/src/core/lib/iomgr/iomgr.cc @@ -0,0 +1,196 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/iomgr.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/global_config.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/buffer_list.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/internal_errqueue.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/timer_manager.h" + +GPR_GLOBAL_CONFIG_DEFINE_BOOL(grpc_abort_on_leaks, false, + "A debugging aid to cause a call to abort() when " + "gRPC objects are leaked past grpc_shutdown()"); + +static gpr_mu g_mu; +static gpr_cv g_rcv; +static int g_shutdown; +static grpc_iomgr_object g_root_object; +static bool g_grpc_abort_on_leaks; + +void grpc_iomgr_init() { + grpc_core::ExecCtx exec_ctx; + if (!grpc_have_determined_iomgr_platform()) { + grpc_set_default_iomgr_platform(); + } + g_shutdown = 0; + gpr_mu_init(&g_mu); + gpr_cv_init(&g_rcv); + grpc_core::Executor::InitAll(); + g_root_object.next = g_root_object.prev = &g_root_object; + g_root_object.name = const_cast("root"); + grpc_iomgr_platform_init(); + grpc_timer_list_init(); + grpc_core::grpc_errqueue_init(); + g_grpc_abort_on_leaks = GPR_GLOBAL_CONFIG_GET(grpc_abort_on_leaks); +} + +void grpc_iomgr_start() { grpc_timer_manager_init(); } + +static size_t count_objects(void) { + grpc_iomgr_object* obj; + size_t n = 0; + for (obj = g_root_object.next; obj != &g_root_object; obj = obj->next) { + n++; + } + return n; +} + +size_t grpc_iomgr_count_objects_for_testing(void) { return count_objects(); } + +static void dump_objects(const char* kind) { + grpc_iomgr_object* obj; + for (obj = g_root_object.next; obj != &g_root_object; obj = obj->next) { + gpr_log(GPR_DEBUG, "%s OBJECT: %s %p", kind, obj->name, obj); + } +} + +void grpc_iomgr_shutdown() { + gpr_timespec shutdown_deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN)); + gpr_timespec last_warning_time = gpr_now(GPR_CLOCK_REALTIME); + + { + grpc_timer_manager_shutdown(); + grpc_iomgr_platform_flush(); + + gpr_mu_lock(&g_mu); + g_shutdown = 1; + while (g_root_object.next != &g_root_object) { + if (gpr_time_cmp( + gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), last_warning_time), + gpr_time_from_seconds(1, GPR_TIMESPAN)) >= 0) { + if (g_root_object.next != &g_root_object) { + gpr_log(GPR_DEBUG, + "Waiting for %" PRIuPTR " iomgr objects to be destroyed", + count_objects()); + } + last_warning_time = gpr_now(GPR_CLOCK_REALTIME); + } + grpc_core::ExecCtx::Get()->SetNowIomgrShutdown(); + if (grpc_timer_check(nullptr) == GRPC_TIMERS_FIRED) { + gpr_mu_unlock(&g_mu); + grpc_core::ExecCtx::Get()->Flush(); + grpc_iomgr_platform_flush(); + gpr_mu_lock(&g_mu); + continue; + } + if (g_root_object.next != &g_root_object) { + if (grpc_iomgr_abort_on_leaks()) { + gpr_log(GPR_DEBUG, + "Failed to free %" PRIuPTR + " iomgr objects before shutdown deadline: " + "memory leaks are likely", + count_objects()); + dump_objects("LEAKED"); + abort(); + } + gpr_timespec short_deadline = + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(100, GPR_TIMESPAN)); + if (gpr_cv_wait(&g_rcv, &g_mu, short_deadline)) { + if (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), shutdown_deadline) > + 0) { + if (g_root_object.next != &g_root_object) { + gpr_log(GPR_DEBUG, + "Failed to free %" PRIuPTR + " 
iomgr objects before shutdown deadline: " + "memory leaks are likely", + count_objects()); + dump_objects("LEAKED"); + } + break; + } + } + } + } + gpr_mu_unlock(&g_mu); + grpc_timer_list_shutdown(); + grpc_core::ExecCtx::Get()->Flush(); + grpc_core::Executor::ShutdownAll(); + } + + /* ensure all threads have left g_mu */ + gpr_mu_lock(&g_mu); + gpr_mu_unlock(&g_mu); + + grpc_iomgr_platform_shutdown(); + gpr_mu_destroy(&g_mu); + gpr_cv_destroy(&g_rcv); +} + +void grpc_iomgr_shutdown_background_closure() { + grpc_iomgr_platform_shutdown_background_closure(); +} + +bool grpc_iomgr_is_any_background_poller_thread() { + return grpc_iomgr_platform_is_any_background_poller_thread(); +} + +bool grpc_iomgr_add_closure_to_background_poller(grpc_closure* closure, + grpc_error_handle error) { + return grpc_iomgr_platform_add_closure_to_background_poller(closure, error); +} + +void grpc_iomgr_register_object(grpc_iomgr_object* obj, const char* name) { + obj->name = gpr_strdup(name); + gpr_mu_lock(&g_mu); + obj->next = &g_root_object; + obj->prev = g_root_object.prev; + obj->next->prev = obj->prev->next = obj; + gpr_mu_unlock(&g_mu); +} + +void grpc_iomgr_unregister_object(grpc_iomgr_object* obj) { + gpr_mu_lock(&g_mu); + obj->next->prev = obj->prev; + obj->prev->next = obj->next; + gpr_cv_signal(&g_rcv); + gpr_mu_unlock(&g_mu); + gpr_free(obj->name); +} + +bool grpc_iomgr_abort_on_leaks(void) { return g_grpc_abort_on_leaks; } diff --git a/src/core/lib/iomgr/iomgr_custom.cc b/src/core/lib/iomgr/iomgr_custom.cc new file mode 100644 index 00000000..70f4e1b7 --- /dev/null +++ b/src/core/lib/iomgr/iomgr_custom.cc @@ -0,0 +1,79 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/iomgr_custom.h" + +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/pollset_custom.h" +#include "src/core/lib/iomgr/pollset_set_custom.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/resolve_address_custom.h" + +gpr_thd_id g_init_thread; + +static void iomgr_platform_init(void) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Executor::SetThreadingAll(false); + g_init_thread = gpr_thd_currentid(); + grpc_pollset_global_init(); +} +static void iomgr_platform_flush(void) {} +static void iomgr_platform_shutdown(void) { grpc_pollset_global_shutdown(); } +static void iomgr_platform_shutdown_background_closure(void) {} +static bool iomgr_platform_is_any_background_poller_thread(void) { + return false; +} +static bool iomgr_platform_add_closure_to_background_poller( + grpc_closure* /*closure*/, grpc_error_handle /*error*/) { + return false; +} + +bool g_custom_iomgr_enabled = false; + +static grpc_iomgr_platform_vtable vtable = { + iomgr_platform_init, + iomgr_platform_flush, + iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure, + iomgr_platform_is_any_background_poller_thread, + iomgr_platform_add_closure_to_background_poller}; + +void grpc_custom_iomgr_init(grpc_socket_vtable* socket, + grpc_custom_resolver_vtable* resolver, + grpc_custom_timer_vtable* timer, + grpc_custom_poller_vtable* poller) { + g_custom_iomgr_enabled = true; + grpc_custom_endpoint_init(socket); + grpc_custom_timer_init(timer); + grpc_custom_pollset_init(poller); + grpc_custom_pollset_set_init(); + grpc_custom_resolver_init(resolver); + grpc_set_iomgr_platform_vtable(&vtable); +} + +#ifdef GRPC_CUSTOM_SOCKET +grpc_iomgr_platform_vtable* grpc_default_iomgr_platform_vtable() { + return &vtable; +} +#endif diff --git a/src/core/lib/iomgr/iomgr_internal.cc b/src/core/lib/iomgr/iomgr_internal.cc new file mode 100644 index 00000000..87ec71ad --- /dev/null +++ b/src/core/lib/iomgr/iomgr_internal.cc @@ -0,0 +1,53 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/iomgr_internal.h" + +#include + +static grpc_iomgr_platform_vtable* iomgr_platform_vtable = nullptr; + +bool grpc_have_determined_iomgr_platform() { + return iomgr_platform_vtable != nullptr; +} + +void grpc_set_iomgr_platform_vtable(grpc_iomgr_platform_vtable* vtable) { + iomgr_platform_vtable = vtable; +} + +void grpc_iomgr_platform_init() { iomgr_platform_vtable->init(); } + +void grpc_iomgr_platform_flush() { iomgr_platform_vtable->flush(); } + +void grpc_iomgr_platform_shutdown() { iomgr_platform_vtable->shutdown(); } + +void grpc_iomgr_platform_shutdown_background_closure() { + iomgr_platform_vtable->shutdown_background_closure(); +} + +bool grpc_iomgr_platform_is_any_background_poller_thread() { + return iomgr_platform_vtable->is_any_background_poller_thread(); +} + +bool grpc_iomgr_platform_add_closure_to_background_poller( + grpc_closure* closure, grpc_error_handle error) { + return iomgr_platform_vtable->add_closure_to_background_poller(closure, + error); +} diff --git a/src/core/lib/iomgr/iomgr_posix.cc b/src/core/lib/iomgr/iomgr_posix.cc new file mode 100644 index 00000000..2450ef0b --- /dev/null +++ b/src/core/lib/iomgr/iomgr_posix.cc @@ -0,0 +1,90 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_IOMGR + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/timer.h" + +extern grpc_tcp_server_vtable grpc_posix_tcp_server_vtable; +extern grpc_tcp_client_vtable grpc_posix_tcp_client_vtable; +extern grpc_timer_vtable grpc_generic_timer_vtable; +extern grpc_pollset_vtable grpc_posix_pollset_vtable; +extern grpc_pollset_set_vtable grpc_posix_pollset_set_vtable; +extern grpc_address_resolver_vtable grpc_posix_resolver_vtable; + +static void iomgr_platform_init(void) { + grpc_wakeup_fd_global_init(); + grpc_event_engine_init(); + grpc_tcp_posix_init(); +} + +static void iomgr_platform_flush(void) {} + +static void iomgr_platform_shutdown(void) { + grpc_tcp_posix_shutdown(); + grpc_event_engine_shutdown(); + grpc_wakeup_fd_global_destroy(); +} + +static void iomgr_platform_shutdown_background_closure(void) { + grpc_shutdown_background_closure(); +} + +static bool iomgr_platform_is_any_background_poller_thread(void) { + return grpc_is_any_background_poller_thread(); +} + +static bool iomgr_platform_add_closure_to_background_poller( + grpc_closure* closure, grpc_error_handle error) { + return grpc_add_closure_to_background_poller(closure, error); +} + +static grpc_iomgr_platform_vtable vtable = { + iomgr_platform_init, + iomgr_platform_flush, + iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure, + iomgr_platform_is_any_background_poller_thread, + iomgr_platform_add_closure_to_background_poller}; + +void grpc_set_default_iomgr_platform() { + grpc_set_tcp_client_impl(&grpc_posix_tcp_client_vtable); + grpc_set_tcp_server_impl(&grpc_posix_tcp_server_vtable); + grpc_set_timer_impl(&grpc_generic_timer_vtable); + grpc_set_pollset_vtable(&grpc_posix_pollset_vtable); + grpc_set_pollset_set_vtable(&grpc_posix_pollset_set_vtable); + grpc_set_resolver_impl(&grpc_posix_resolver_vtable); + grpc_set_iomgr_platform_vtable(&vtable); +} + +bool grpc_iomgr_run_in_background() { + return grpc_event_engine_run_in_background(); +} + +#endif /* GRPC_POSIX_SOCKET_IOMGR */ diff --git a/src/core/lib/iomgr/iomgr_posix_cfstream.cc b/src/core/lib/iomgr/iomgr_posix_cfstream.cc new file mode 100644 index 00000000..27af38e4 --- /dev/null +++ b/src/core/lib/iomgr/iomgr_posix_cfstream.cc @@ -0,0 +1,200 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/// CFStream is build-enabled on iOS by default and disabled by default on other +/// platforms (see port_platform.h). To enable CFStream build on another +/// platform, the users need to define macro "GRPC_CFSTREAM=1" when building +/// gRPC. 
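+/// (Illustrative example only, not an officially documented build command: a
+/// Bazel user might define the macro with something like
+/// `--copt=-DGRPC_CFSTREAM=1`.)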
+/// +/// When CFStream is to be built (either by default on iOS or by macro on other +/// platforms), the users can disable CFStream with environment variable +/// "grpc_cfstream=0". This will let gRPC to fallback to use POSIX sockets. In +/// addition, the users may choose to use an alternative CFRunLoop based pollset +/// "ev_apple" by setting environment variable "GRPC_CFSTREAM_RUN_LOOP=1". This +/// pollset resolves a bug from Apple when CFStream streams dispatch events to +/// dispatch queues. The caveat of this pollset is that users may not be able to +/// run a gRPC server in the same process. + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_CFSTREAM_IOMGR + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/ev_apple.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/timer.h" + +static const char* grpc_cfstream_env_var = "grpc_cfstream"; +static const char* grpc_cfstream_run_loop_env_var = "GRPC_CFSTREAM_RUN_LOOP"; + +extern grpc_tcp_server_vtable grpc_posix_tcp_server_vtable; +extern grpc_tcp_client_vtable grpc_posix_tcp_client_vtable; +extern grpc_tcp_client_vtable grpc_cfstream_client_vtable; +extern grpc_timer_vtable grpc_generic_timer_vtable; +extern grpc_pollset_vtable grpc_posix_pollset_vtable; +extern grpc_pollset_set_vtable grpc_posix_pollset_set_vtable; +extern grpc_address_resolver_vtable grpc_posix_resolver_vtable; + +static void apple_iomgr_platform_init(void) { grpc_pollset_global_init(); } + +static void apple_iomgr_platform_flush(void) {} + +static void apple_iomgr_platform_shutdown(void) { + grpc_pollset_global_shutdown(); +} + +static void apple_iomgr_platform_shutdown_background_closure(void) {} + +static bool apple_iomgr_platform_is_any_background_poller_thread(void) { + return false; +} + +static bool apple_iomgr_platform_add_closure_to_background_poller( + grpc_closure* closure, grpc_error_handle error) { + return false; +} + +static grpc_iomgr_platform_vtable apple_vtable = { + apple_iomgr_platform_init, + apple_iomgr_platform_flush, + apple_iomgr_platform_shutdown, + apple_iomgr_platform_shutdown_background_closure, + apple_iomgr_platform_is_any_background_poller_thread, + apple_iomgr_platform_add_closure_to_background_poller}; + +namespace { +struct CFStreamEnv { + bool enable_cfstream; + bool enable_cfstream_run_loop; +}; + +// Parses environment variables for CFStream specific settings +CFStreamEnv ParseEnvForCFStream() { + CFStreamEnv env; + char* enable_cfstream_str = getenv(grpc_cfstream_env_var); + env.enable_cfstream = + enable_cfstream_str == nullptr || enable_cfstream_str[0] != '0'; + char* enable_cfstream_run_loop_str = getenv(grpc_cfstream_run_loop_env_var); + // CFStream run-loop is disabled by default. The user has to enable it + // explicitly with environment variable. 
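+  // Illustrative shell invocation (hypothetical binary name, not part of this
+  // change):
+  //   GRPC_CFSTREAM_RUN_LOOP=1 ./client_binary
+  // Any value that does not start with '1' leaves the run loop disabled.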
+ env.enable_cfstream_run_loop = enable_cfstream_run_loop_str != nullptr && + enable_cfstream_run_loop_str[0] == '1'; + return env; +} + +void MaybeInitializeTcpPosix(void) { + CFStreamEnv env = ParseEnvForCFStream(); + if (!env.enable_cfstream || !env.enable_cfstream_run_loop) { + grpc_tcp_posix_init(); + } +} + +void MaybeShutdownTcpPosix(void) { + CFStreamEnv env = ParseEnvForCFStream(); + if (!env.enable_cfstream || !env.enable_cfstream_run_loop) { + grpc_tcp_posix_shutdown(); + } +} +} // namespace + +static void iomgr_platform_init(void) { + MaybeInitializeTcpPosix(); + grpc_wakeup_fd_global_init(); + grpc_event_engine_init(); +} + +static void iomgr_platform_flush(void) {} + +static void iomgr_platform_shutdown(void) { + grpc_event_engine_shutdown(); + grpc_wakeup_fd_global_destroy(); + MaybeShutdownTcpPosix(); +} + +static void iomgr_platform_shutdown_background_closure(void) { + grpc_shutdown_background_closure(); +} + +static bool iomgr_platform_is_any_background_poller_thread(void) { + return grpc_is_any_background_poller_thread(); +} + +static bool iomgr_platform_add_closure_to_background_poller( + grpc_closure* closure, grpc_error_handle error) { + return grpc_add_closure_to_background_poller(closure, error); +} + +static grpc_iomgr_platform_vtable vtable = { + iomgr_platform_init, + iomgr_platform_flush, + iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure, + iomgr_platform_is_any_background_poller_thread, + iomgr_platform_add_closure_to_background_poller}; + +void grpc_set_default_iomgr_platform() { + CFStreamEnv env = ParseEnvForCFStream(); + if (!env.enable_cfstream) { + // Use POSIX sockets for both client and server + grpc_set_tcp_client_impl(&grpc_posix_tcp_client_vtable); + grpc_set_tcp_server_impl(&grpc_posix_tcp_server_vtable); + grpc_set_pollset_vtable(&grpc_posix_pollset_vtable); + grpc_set_pollset_set_vtable(&grpc_posix_pollset_set_vtable); + grpc_set_iomgr_platform_vtable(&vtable); + } else if (env.enable_cfstream && !env.enable_cfstream_run_loop) { + // Use CFStream with dispatch queue for client; use POSIX sockets for server + grpc_set_tcp_client_impl(&grpc_cfstream_client_vtable); + grpc_set_tcp_server_impl(&grpc_posix_tcp_server_vtable); + grpc_set_pollset_vtable(&grpc_posix_pollset_vtable); + grpc_set_pollset_set_vtable(&grpc_posix_pollset_set_vtable); + grpc_set_iomgr_platform_vtable(&vtable); + } else { + // Use CFStream with CFRunLoop for client; server not supported + grpc_set_tcp_client_impl(&grpc_cfstream_client_vtable); + grpc_set_pollset_vtable(&grpc_apple_pollset_vtable); + grpc_set_pollset_set_vtable(&grpc_apple_pollset_set_vtable); + grpc_set_iomgr_platform_vtable(&apple_vtable); + } + grpc_set_timer_impl(&grpc_generic_timer_vtable); + grpc_set_resolver_impl(&grpc_posix_resolver_vtable); +} + +bool grpc_iomgr_run_in_background() { + char* enable_cfstream_str = getenv(grpc_cfstream_env_var); + bool enable_cfstream = + enable_cfstream_str == nullptr || enable_cfstream_str[0] != '0'; + char* enable_cfstream_run_loop_str = getenv(grpc_cfstream_run_loop_env_var); + // CFStream run-loop is disabled by default. The user has to enable it + // explicitly with environment variable. 
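+  // Note: when both grpc_cfstream and GRPC_CFSTREAM_RUN_LOOP are enabled this
+  // function returns false, i.e. iomgr does not run in the background; this
+  // lines up with the file-level comment that a gRPC server may not be usable
+  // in the same process in that mode.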
+ bool enable_cfstream_run_loop = enable_cfstream_run_loop_str != nullptr && + enable_cfstream_run_loop_str[0] == '1'; + if (enable_cfstream && enable_cfstream_run_loop) { + return false; + } else { + return grpc_event_engine_run_in_background(); + } +} + +#endif /* GRPC_CFSTREAM_IOMGR */ diff --git a/src/core/lib/iomgr/iomgr_windows.cc b/src/core/lib/iomgr/iomgr_windows.cc new file mode 100644 index 00000000..93fdaf85 --- /dev/null +++ b/src/core/lib/iomgr/iomgr_windows.cc @@ -0,0 +1,105 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include + +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/pollset_windows.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/timer.h" + +extern grpc_tcp_server_vtable grpc_windows_tcp_server_vtable; +extern grpc_tcp_client_vtable grpc_windows_tcp_client_vtable; +extern grpc_timer_vtable grpc_generic_timer_vtable; +extern grpc_pollset_vtable grpc_windows_pollset_vtable; +extern grpc_pollset_set_vtable grpc_windows_pollset_set_vtable; +extern grpc_address_resolver_vtable grpc_windows_resolver_vtable; + +/* Windows' io manager is going to be fully designed using IO completion + ports. All of what we're doing here is basically make sure that + Windows sockets are initialized in and out. 
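+   WSAStartup()/WSACleanup() below are that init/teardown pair; the completion
+   port itself is created and closed by grpc_iocp_init()/grpc_iocp_shutdown()
+   in iocp_windows.cc.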
*/ + +static void winsock_init(void) { + WSADATA wsaData; + int status = WSAStartup(MAKEWORD(2, 0), &wsaData); + GPR_ASSERT(status == 0); +} + +static void winsock_shutdown(void) { + int status = WSACleanup(); + GPR_ASSERT(status == 0); +} + +static void iomgr_platform_init(void) { + winsock_init(); + grpc_iocp_init(); + grpc_pollset_global_init(); + grpc_wsa_socket_flags_init(); +} + +static void iomgr_platform_flush(void) { grpc_iocp_flush(); } + +static void iomgr_platform_shutdown(void) { + grpc_pollset_global_shutdown(); + grpc_iocp_shutdown(); + winsock_shutdown(); +} + +static void iomgr_platform_shutdown_background_closure(void) {} + +static bool iomgr_platform_is_any_background_poller_thread(void) { + return false; +} + +static bool iomgr_platform_add_closure_to_background_poller( + grpc_closure* closure, grpc_error_handle error) { + return false; +} + +static grpc_iomgr_platform_vtable vtable = { + iomgr_platform_init, + iomgr_platform_flush, + iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure, + iomgr_platform_is_any_background_poller_thread, + iomgr_platform_add_closure_to_background_poller}; + +void grpc_set_default_iomgr_platform() { + grpc_set_tcp_client_impl(&grpc_windows_tcp_client_vtable); + grpc_set_tcp_server_impl(&grpc_windows_tcp_server_vtable); + grpc_set_timer_impl(&grpc_generic_timer_vtable); + grpc_set_pollset_vtable(&grpc_windows_pollset_vtable); + grpc_set_pollset_set_vtable(&grpc_windows_pollset_set_vtable); + grpc_set_resolver_impl(&grpc_windows_resolver_vtable); + grpc_set_iomgr_platform_vtable(&vtable); +} + +bool grpc_iomgr_run_in_background() { return false; } + +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/is_epollexclusive_available.cc b/src/core/lib/iomgr/is_epollexclusive_available.cc new file mode 100644 index 00000000..80cac4a5 --- /dev/null +++ b/src/core/lib/iomgr/is_epollexclusive_available.cc @@ -0,0 +1,119 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/is_epollexclusive_available.h" + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_LINUX_EPOLL_CREATE1 + +#include +#include +#include +#include + +#include + +#include "src/core/lib/iomgr/sys_epoll_wrapper.h" + +/* This polling engine is only relevant on linux kernels supporting epoll() */ +bool grpc_is_epollexclusive_available(void) { + static bool logged_why_not = false; + + int fd = epoll_create1(EPOLL_CLOEXEC); + if (fd < 0) { + if (!logged_why_not) { + gpr_log(GPR_DEBUG, + "epoll_create1 failed with error: %d. Not using epollex polling " + "engine.", + fd); + logged_why_not = true; + } + return false; + } + int evfd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + if (evfd < 0) { + if (!logged_why_not) { + gpr_log(GPR_DEBUG, + "eventfd failed with error: %d. 
Not using epollex polling " + "engine.", + fd); + logged_why_not = true; + } + close(fd); + return false; + } + struct epoll_event ev; + /* choose events that should cause an error on + EPOLLEXCLUSIVE enabled kernels - specifically the combination of + EPOLLONESHOT and EPOLLEXCLUSIVE */ + ev.events = + static_cast(EPOLLET | EPOLLIN | EPOLLEXCLUSIVE | EPOLLONESHOT); + ev.data.ptr = nullptr; + if (epoll_ctl(fd, EPOLL_CTL_ADD, evfd, &ev) != 0) { + if (errno != EINVAL) { + if (!logged_why_not) { + gpr_log( + GPR_ERROR, + "epoll_ctl with EPOLLEXCLUSIVE | EPOLLONESHOT failed with error: " + "%d. Not using epollex polling engine.", + errno); + logged_why_not = true; + } + close(fd); + close(evfd); + return false; + } + } else { + if (!logged_why_not) { + gpr_log(GPR_DEBUG, + "epoll_ctl with EPOLLEXCLUSIVE | EPOLLONESHOT succeeded. This is " + "evidence of no EPOLLEXCLUSIVE support. Not using " + "epollex polling engine."); + logged_why_not = true; + } + close(fd); + close(evfd); + return false; + } + // Check that EPOLLEXCLUSIVE is supported at all. + ev.events = static_cast(EPOLLET | EPOLLIN | EPOLLEXCLUSIVE); + if (epoll_ctl(fd, EPOLL_CTL_ADD, evfd, &ev) != 0) { + if (!logged_why_not) { + gpr_log(GPR_DEBUG, + "epoll_ctl with EPOLLEXCLUSIVE failed with error: " + "%d. Not using epollex polling engine.", + errno); + logged_why_not = true; + } + close(fd); + close(evfd); + return false; + } + close(evfd); + close(fd); + return true; +} + +#else + +bool grpc_is_epollexclusive_available(void) { return false; } + +#endif diff --git a/src/core/lib/iomgr/load_file.cc b/src/core/lib/iomgr/load_file.cc new file mode 100644 index 00000000..90686701 --- /dev/null +++ b/src/core/lib/iomgr/load_file.cc @@ -0,0 +1,81 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/load_file.h" + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/block_annotate.h" + +grpc_error_handle grpc_load_file(const char* filename, int add_null_terminator, + grpc_slice* output) { + unsigned char* contents = nullptr; + size_t contents_size = 0; + grpc_slice result = grpc_empty_slice(); + FILE* file; + size_t bytes_read = 0; + grpc_error_handle error = GRPC_ERROR_NONE; + + GRPC_SCHEDULING_START_BLOCKING_REGION; + file = fopen(filename, "rb"); + if (file == nullptr) { + error = GRPC_OS_ERROR(errno, "fopen"); + goto end; + } + fseek(file, 0, SEEK_END); + /* Converting to size_t on the assumption that it will not fail */ + contents_size = static_cast(ftell(file)); + fseek(file, 0, SEEK_SET); + contents = static_cast( + gpr_malloc(contents_size + (add_null_terminator ? 
1 : 0))); + bytes_read = fread(contents, 1, contents_size, file); + if (bytes_read < contents_size) { + gpr_free(contents); + error = GRPC_OS_ERROR(errno, "fread"); + GPR_ASSERT(ferror(file)); + goto end; + } + if (add_null_terminator) { + contents[contents_size++] = 0; + } + result = grpc_slice_new(contents, contents_size, gpr_free); + +end: + *output = result; + if (file != nullptr) fclose(file); + if (error != GRPC_ERROR_NONE) { + grpc_error_handle error_out = + grpc_error_set_str(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed to load file", &error, 1), + GRPC_ERROR_STR_FILENAME, + + filename); + GRPC_ERROR_UNREF(error); + error = error_out; + } + GRPC_SCHEDULING_END_BLOCKING_REGION_NO_EXEC_CTX; + return error; +} diff --git a/src/core/lib/iomgr/lockfree_event.cc b/src/core/lib/iomgr/lockfree_event.cc new file mode 100644 index 00000000..d41e5029 --- /dev/null +++ b/src/core/lib/iomgr/lockfree_event.cc @@ -0,0 +1,278 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/lockfree_event.h" + +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/exec_ctx.h" + +extern grpc_core::DebugOnlyTraceFlag grpc_polling_trace; + +/* 'state' holds the to call when the fd is readable or writable respectively. + It can contain one of the following values: + kClosureReady : The fd has an I/O event of interest but there is no + closure yet to execute + + kClosureNotReady : The fd has no I/O event of interest + + closure ptr : The closure to be executed when the fd has an I/O + event of interest + + shutdown_error | kShutdownBit : + 'shutdown_error' field ORed with kShutdownBit. + This indicates that the fd is shutdown. Since all + memory allocations are word-aligned, the lower two + bits of the shutdown_error pointer are always 0. So + it is safe to OR these with kShutdownBit + + Valid state transitions: + + <-----3------ kClosureNotReady -----1-------> kClosureReady + | | ^ | ^ | | + | | | | | | | + | +--------------4----------+ 6 +---------2---------------+ | + | | | + | v | + +-----5-------> [shutdown_error | kShutdownBit] <-------7---------+ + + For 1, 4 : See SetReady() function + For 2, 3 : See NotifyOn() function + For 5,6,7: See SetShutdown() function */ + +namespace grpc_core { + +LockfreeEvent::LockfreeEvent() { InitEvent(); } + +void LockfreeEvent::InitEvent() { + /* Perform an atomic store to start the state machine. + + Note carefully that LockfreeEvent *MAY* be used whilst in a destroyed + state, while a file descriptor is on a freelist. 
In such a state it may + be SetReady'd, and so we need to perform an atomic operation here to + ensure no races */ + gpr_atm_no_barrier_store(&state_, kClosureNotReady); +} + +void LockfreeEvent::DestroyEvent() { + gpr_atm curr; + do { + curr = gpr_atm_no_barrier_load(&state_); + if (curr & kShutdownBit) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + internal::StatusFreeHeapPtr(curr & ~kShutdownBit); +#else + GRPC_ERROR_UNREF((grpc_error_handle)(curr & ~kShutdownBit)); +#endif + } else { + GPR_ASSERT(curr == kClosureNotReady || curr == kClosureReady); + } + /* we CAS in a shutdown, no error value here. If this event is interacted + with post-deletion (see the note in the constructor) we want the bit + pattern to prevent error retention in a deleted object */ + } while (!gpr_atm_no_barrier_cas(&state_, curr, + kShutdownBit /* shutdown, no error */)); +} + +void LockfreeEvent::NotifyOn(grpc_closure* closure) { + while (true) { + /* This load needs to be an acquire load because this can be a shutdown + * error that we might need to reference. Adding acquire semantics makes + * sure that the shutdown error has been initialized properly before us + * referencing it. */ + gpr_atm curr = gpr_atm_acq_load(&state_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_DEBUG, + "LockfreeEvent::NotifyOn: %p curr=%" PRIxPTR " closure=%p", this, + curr, closure); + } + switch (curr) { + case kClosureNotReady: { + /* kClosureNotReady -> . + + We're guaranteed by API that there's an acquire barrier before here, + so there's no need to double-dip and this can be a release-only. + + The release itself pairs with the acquire half of a set_ready full + barrier. */ + if (gpr_atm_rel_cas(&state_, kClosureNotReady, + reinterpret_cast(closure))) { + return; /* Successful. Return */ + } + + break; /* retry */ + } + + case kClosureReady: { + /* Change the state to kClosureNotReady. Schedule the closure if + successful. If not, the state most likely transitioned to shutdown. + We should retry. + + This can be a no-barrier cas since the state is being transitioned to + kClosureNotReady; set_ready and set_shutdown do not schedule any + closure when transitioning out of CLOSURE_NO_READY state (i.e there + is no other code that needs to 'happen-after' this) */ + if (gpr_atm_no_barrier_cas(&state_, kClosureReady, kClosureNotReady)) { + ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + return; /* Successful. Return */ + } + + break; /* retry */ + } + + default: { + /* 'curr' is either a closure or the fd is shutdown(in which case 'curr' + contains a pointer to the shutdown-error). If the fd is shutdown, + schedule the closure with the shutdown error */ + if ((curr & kShutdownBit) > 0) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + grpc_error_handle shutdown_err = + internal::StatusGetFromHeapPtr(curr & ~kShutdownBit); +#else + grpc_error_handle shutdown_err = + reinterpret_cast(curr & ~kShutdownBit); +#endif + ExecCtx::Run(DEBUG_LOCATION, closure, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "FD Shutdown", &shutdown_err, 1)); + return; + } + + /* There is already a closure!. 
This indicates a bug in the code */ + gpr_log(GPR_ERROR, + "LockfreeEvent::NotifyOn: notify_on called with a previous " + "callback still pending"); + abort(); + } + } + } + + GPR_UNREACHABLE_CODE(return ); +} + +bool LockfreeEvent::SetShutdown(grpc_error_handle shutdown_error) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + intptr_t status_ptr = internal::StatusAllocHeapPtr(shutdown_error); + gpr_atm new_state = status_ptr | kShutdownBit; +#else + gpr_atm new_state = reinterpret_cast(shutdown_error) | kShutdownBit; +#endif + + while (true) { + gpr_atm curr = gpr_atm_no_barrier_load(&state_); + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_DEBUG, + "LockfreeEvent::SetShutdown: %p curr=%" PRIxPTR " err=%s", + &state_, curr, grpc_error_std_string(shutdown_error).c_str()); + } + switch (curr) { + case kClosureReady: + case kClosureNotReady: + /* Need a full barrier here so that the initial load in notify_on + doesn't need a barrier */ + if (gpr_atm_full_cas(&state_, curr, new_state)) { + return true; /* early out */ + } + break; /* retry */ + + default: { + /* 'curr' is either a closure or the fd is already shutdown */ + + /* If fd is already shutdown, we are done */ + if ((curr & kShutdownBit) > 0) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + internal::StatusFreeHeapPtr(status_ptr); +#else + GRPC_ERROR_UNREF(shutdown_error); +#endif + return false; + } + + /* Fd is not shutdown. Schedule the closure and move the state to + shutdown state. + Needs an acquire to pair with setting the closure (and get a + happens-after on that edge), and a release to pair with anything + loading the shutdown state. */ + if (gpr_atm_full_cas(&state_, curr, new_state)) { + ExecCtx::Run(DEBUG_LOCATION, reinterpret_cast(curr), + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "FD Shutdown", &shutdown_error, 1)); + return true; + } + + /* 'curr' was a closure but now changed to a different state. We will + have to retry */ + break; + } + } + } + + GPR_UNREACHABLE_CODE(return false); +} + +void LockfreeEvent::SetReady() { + while (true) { + gpr_atm curr = gpr_atm_no_barrier_load(&state_); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_polling_trace)) { + gpr_log(GPR_DEBUG, "LockfreeEvent::SetReady: %p curr=%" PRIxPTR, &state_, + curr); + } + + switch (curr) { + case kClosureReady: { + /* Already ready. We are done here */ + return; + } + + case kClosureNotReady: { + /* No barrier required as we're transitioning to a state that does not + involve a closure */ + if (gpr_atm_no_barrier_cas(&state_, kClosureNotReady, kClosureReady)) { + return; /* early out */ + } + break; /* retry */ + } + + default: { + /* 'curr' is either a closure or the fd is shutdown */ + if ((curr & kShutdownBit) > 0) { + /* The fd is shutdown. Do nothing */ + return; + } + /* Full cas: acquire pairs with this cas' release in the event of a + spurious set_ready; release pairs with this or the acquire in + notify_on (or set_shutdown) */ + else if (gpr_atm_full_cas(&state_, curr, kClosureNotReady)) { + ExecCtx::Run(DEBUG_LOCATION, reinterpret_cast(curr), + GRPC_ERROR_NONE); + return; + } + /* else the state changed again (only possible by either a racing + set_ready or set_shutdown functions. In both these cases, the closure + would have been scheduled for execution. 
So we are done here */ + return; + } + } + } +} + +} // namespace grpc_core diff --git a/src/core/lib/iomgr/polling_entity.cc b/src/core/lib/iomgr/polling_entity.cc new file mode 100644 index 00000000..0c1788a6 --- /dev/null +++ b/src/core/lib/iomgr/polling_entity.cc @@ -0,0 +1,96 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/polling_entity.h" + +#include +#include + +grpc_polling_entity grpc_polling_entity_create_from_pollset_set( + grpc_pollset_set* pollset_set) { + grpc_polling_entity pollent; + pollent.pollent.pollset_set = pollset_set; + pollent.tag = GRPC_POLLS_POLLSET_SET; + return pollent; +} + +grpc_polling_entity grpc_polling_entity_create_from_pollset( + grpc_pollset* pollset) { + grpc_polling_entity pollent; + pollent.pollent.pollset = pollset; + pollent.tag = GRPC_POLLS_POLLSET; + return pollent; +} + +grpc_pollset* grpc_polling_entity_pollset(grpc_polling_entity* pollent) { + if (pollent->tag == GRPC_POLLS_POLLSET) { + return pollent->pollent.pollset; + } + return nullptr; +} + +grpc_pollset_set* grpc_polling_entity_pollset_set( + grpc_polling_entity* pollent) { + if (pollent->tag == GRPC_POLLS_POLLSET_SET) { + return pollent->pollent.pollset_set; + } + return nullptr; +} + +bool grpc_polling_entity_is_empty(const grpc_polling_entity* pollent) { + return pollent->tag == GRPC_POLLS_NONE; +} + +void grpc_polling_entity_add_to_pollset_set(grpc_polling_entity* pollent, + grpc_pollset_set* pss_dst) { + if (pollent->tag == GRPC_POLLS_POLLSET) { + // CFStream does not use file destriptors. When CFStream is used, the fd + // pollset is possible to be null. 
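+    // (grpc_polling_entity_del_from_pollset_set below performs the same null
+    // check, but only when compiled with GRPC_CFSTREAM defined.)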
+ if (pollent->pollent.pollset != nullptr) { + grpc_pollset_set_add_pollset(pss_dst, pollent->pollent.pollset); + } + } else if (pollent->tag == GRPC_POLLS_POLLSET_SET) { + GPR_ASSERT(pollent->pollent.pollset_set != nullptr); + grpc_pollset_set_add_pollset_set(pss_dst, pollent->pollent.pollset_set); + } else { + gpr_log(GPR_ERROR, "Invalid grpc_polling_entity tag '%d'", pollent->tag); + abort(); + } +} + +void grpc_polling_entity_del_from_pollset_set(grpc_polling_entity* pollent, + grpc_pollset_set* pss_dst) { + if (pollent->tag == GRPC_POLLS_POLLSET) { +#ifdef GRPC_CFSTREAM + if (pollent->pollent.pollset != nullptr) { + grpc_pollset_set_del_pollset(pss_dst, pollent->pollent.pollset); + } +#else + GPR_ASSERT(pollent->pollent.pollset != nullptr); + grpc_pollset_set_del_pollset(pss_dst, pollent->pollent.pollset); +#endif + } else if (pollent->tag == GRPC_POLLS_POLLSET_SET) { + GPR_ASSERT(pollent->pollent.pollset_set != nullptr); + grpc_pollset_set_del_pollset_set(pss_dst, pollent->pollent.pollset_set); + } else { + gpr_log(GPR_ERROR, "Invalid grpc_polling_entity tag '%d'", pollent->tag); + abort(); + } +} diff --git a/src/core/lib/iomgr/pollset.cc b/src/core/lib/iomgr/pollset.cc new file mode 100644 index 00000000..ba2a58d8 --- /dev/null +++ b/src/core/lib/iomgr/pollset.cc @@ -0,0 +1,56 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/pollset.h" + +grpc_pollset_vtable* grpc_pollset_impl; + +void grpc_set_pollset_vtable(grpc_pollset_vtable* vtable) { + grpc_pollset_impl = vtable; +} + +void grpc_pollset_global_init() { grpc_pollset_impl->global_init(); } + +void grpc_pollset_global_shutdown() { grpc_pollset_impl->global_shutdown(); } + +void grpc_pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + grpc_pollset_impl->init(pollset, mu); +} + +void grpc_pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + grpc_pollset_impl->shutdown(pollset, closure); +} + +void grpc_pollset_destroy(grpc_pollset* pollset) { + grpc_pollset_impl->destroy(pollset); +} + +grpc_error_handle grpc_pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** worker, + grpc_millis deadline) { + return grpc_pollset_impl->work(pollset, worker, deadline); +} + +grpc_error_handle grpc_pollset_kick(grpc_pollset* pollset, + grpc_pollset_worker* specific_worker) { + return grpc_pollset_impl->kick(pollset, specific_worker); +} + +size_t grpc_pollset_size(void) { return grpc_pollset_impl->pollset_size(); } diff --git a/src/core/lib/iomgr/pollset_custom.cc b/src/core/lib/iomgr/pollset_custom.cc new file mode 100644 index 00000000..f70c2bcf --- /dev/null +++ b/src/core/lib/iomgr/pollset_custom.cc @@ -0,0 +1,105 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/pollset_custom.h" + +#include +#include + +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/iomgr_custom.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/timer.h" + +static grpc_custom_poller_vtable* poller_vtable; + +struct grpc_pollset { + gpr_mu mu; +}; + +static size_t pollset_size() { return sizeof(grpc_pollset); } + +static void pollset_global_init() { poller_vtable->init(); } + +static void pollset_global_shutdown() { poller_vtable->shutdown(); } + +static void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + gpr_mu_init(&pollset->mu); + *mu = &pollset->mu; +} + +static void pollset_shutdown(grpc_pollset* /*pollset*/, grpc_closure* closure) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); +} + +static void pollset_destroy(grpc_pollset* pollset) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + gpr_mu_destroy(&pollset->mu); +} + +static grpc_error_handle pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** /*worker_hdl*/, + grpc_millis deadline) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + gpr_mu_unlock(&pollset->mu); + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + grpc_millis timeout = 0; + if (deadline > now) { + timeout = deadline - now; + } + // We yield here because the poll() call might yield + // control back to the application + grpc_core::ExecCtx* curr = grpc_core::ExecCtx::Get(); + grpc_core::ExecCtx::Set(nullptr); + grpc_error_handle err = poller_vtable->poll(static_cast(timeout)); + grpc_core::ExecCtx::Set(curr); + grpc_core::ExecCtx::Get()->InvalidateNow(); + if (grpc_core::ExecCtx::Get()->HasWork()) { + grpc_core::ExecCtx::Get()->Flush(); + } + gpr_mu_lock(&pollset->mu); + return err; +} + +static grpc_error_handle pollset_kick( + grpc_pollset* /*pollset*/, grpc_pollset_worker* /*specific_worker*/) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + poller_vtable->kick(); + return GRPC_ERROR_NONE; +} + +grpc_pollset_vtable custom_pollset_vtable = { + pollset_global_init, pollset_global_shutdown, + pollset_init, pollset_shutdown, + pollset_destroy, pollset_work, + pollset_kick, pollset_size}; + +void grpc_custom_pollset_init(grpc_custom_poller_vtable* vtable) { + poller_vtable = vtable; + grpc_set_pollset_vtable(&custom_pollset_vtable); +} diff --git a/src/core/lib/iomgr/pollset_set.cc b/src/core/lib/iomgr/pollset_set.cc new file mode 100644 index 00000000..42a647a7 --- /dev/null +++ b/src/core/lib/iomgr/pollset_set.cc @@ -0,0 +1,55 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/pollset_set.h" + +grpc_pollset_set_vtable* grpc_pollset_set_impl; + +void grpc_set_pollset_set_vtable(grpc_pollset_set_vtable* vtable) { + grpc_pollset_set_impl = vtable; +} + +grpc_pollset_set* grpc_pollset_set_create() { + return grpc_pollset_set_impl->create(); +} + +void grpc_pollset_set_destroy(grpc_pollset_set* pollset_set) { + grpc_pollset_set_impl->destroy(pollset_set); +} + +void grpc_pollset_set_add_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) { + grpc_pollset_set_impl->add_pollset(pollset_set, pollset); +} + +void grpc_pollset_set_del_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) { + grpc_pollset_set_impl->del_pollset(pollset_set, pollset); +} + +void grpc_pollset_set_add_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) { + grpc_pollset_set_impl->add_pollset_set(bag, item); +} + +void grpc_pollset_set_del_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) { + grpc_pollset_set_impl->del_pollset_set(bag, item); +} diff --git a/src/core/lib/iomgr/pollset_set_custom.cc b/src/core/lib/iomgr/pollset_set_custom.cc new file mode 100644 index 00000000..db105bfd --- /dev/null +++ b/src/core/lib/iomgr/pollset_set_custom.cc @@ -0,0 +1,47 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/iomgr/port.h" + +static grpc_pollset_set* pollset_set_create(void) { + return reinterpret_cast(static_cast(0xdeafbeef)); +} + +static void pollset_set_destroy(grpc_pollset_set* /*pollset_set*/) {} + +static void pollset_set_add_pollset(grpc_pollset_set* /*pollset_set*/, + grpc_pollset* /*pollset*/) {} + +static void pollset_set_del_pollset(grpc_pollset_set* /*pollset_set*/, + grpc_pollset* /*pollset*/) {} + +static void pollset_set_add_pollset_set(grpc_pollset_set* /*bag*/, + grpc_pollset_set* /*item*/) {} + +static void pollset_set_del_pollset_set(grpc_pollset_set* /*bag*/, + grpc_pollset_set* /*item*/) {} + +static grpc_pollset_set_vtable vtable = { + pollset_set_create, pollset_set_destroy, + pollset_set_add_pollset, pollset_set_del_pollset, + pollset_set_add_pollset_set, pollset_set_del_pollset_set}; + +void grpc_custom_pollset_set_init() { grpc_set_pollset_set_vtable(&vtable); } diff --git a/src/core/lib/iomgr/pollset_set_windows.cc b/src/core/lib/iomgr/pollset_set_windows.cc new file mode 100644 index 00000000..1b105a24 --- /dev/null +++ b/src/core/lib/iomgr/pollset_set_windows.cc @@ -0,0 +1,52 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include "src/core/lib/iomgr/pollset_set_windows.h" + +static grpc_pollset_set* pollset_set_create(void) { + return (grpc_pollset_set*)((intptr_t)0xdeafbeef); +} + +static void pollset_set_destroy(grpc_pollset_set* pollset_set) {} + +static void pollset_set_add_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) {} + +static void pollset_set_del_pollset(grpc_pollset_set* pollset_set, + grpc_pollset* pollset) {} + +static void pollset_set_add_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) {} + +static void pollset_set_del_pollset_set(grpc_pollset_set* bag, + grpc_pollset_set* item) {} + +grpc_pollset_set_vtable grpc_windows_pollset_set_vtable = { + pollset_set_create, pollset_set_destroy, + pollset_set_add_pollset, pollset_set_del_pollset, + pollset_set_add_pollset_set, pollset_set_del_pollset_set}; + +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/pollset_windows.cc b/src/core/lib/iomgr/pollset_windows.cc new file mode 100644 index 00000000..f8758a3f --- /dev/null +++ b/src/core/lib/iomgr/pollset_windows.cc @@ -0,0 +1,243 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include + +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_windows.h" + +#define GRPC_POLLSET_KICK_BROADCAST ((grpc_pollset_worker*)1) + +grpc_core::DebugOnlyTraceFlag grpc_trace_fd_refcount(false, "fd_refcount"); + +gpr_mu grpc_polling_mu; +static grpc_pollset_worker* g_active_poller; +static grpc_pollset_worker g_global_root_worker; + +static void pollset_global_init(void) { + gpr_mu_init(&grpc_polling_mu); + g_active_poller = NULL; + g_global_root_worker.links[GRPC_POLLSET_WORKER_LINK_GLOBAL].next = + g_global_root_worker.links[GRPC_POLLSET_WORKER_LINK_GLOBAL].prev = + &g_global_root_worker; +} + +static void pollset_global_shutdown(void) { gpr_mu_destroy(&grpc_polling_mu); } + +static void remove_worker(grpc_pollset_worker* worker, + grpc_pollset_worker_link_type type) { + worker->links[type].prev->links[type].next = worker->links[type].next; + worker->links[type].next->links[type].prev = worker->links[type].prev; + worker->links[type].next = worker->links[type].prev = worker; +} + +static int has_workers(grpc_pollset_worker* root, + grpc_pollset_worker_link_type type) { + return root->links[type].next != root; +} + +static grpc_pollset_worker* pop_front_worker( + grpc_pollset_worker* root, grpc_pollset_worker_link_type type) { + if (has_workers(root, type)) { + grpc_pollset_worker* w = root->links[type].next; + remove_worker(w, type); + return w; + } else { + return NULL; + } +} + +static void push_front_worker(grpc_pollset_worker* root, + grpc_pollset_worker_link_type type, + grpc_pollset_worker* worker) { + worker->links[type].prev = root; + worker->links[type].next = worker->links[type].prev->links[type].next; + worker->links[type].prev->links[type].next = + worker->links[type].next->links[type].prev = worker; +} + +static size_t pollset_size(void) { return sizeof(grpc_pollset); } + +/* There isn't really any such thing as a pollset under Windows, due to the + nature of the IO completion ports. We're still going to provide a minimal + set of features for the sake of the rest of grpc. But grpc_pollset_work + won't actually do any polling, and return as quickly as possible. 
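+   Instead, a worker either becomes the active IOCP poller via grpc_iocp_work()
+   or parks on its condition variable until it is kicked.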
*/ + +static void pollset_init(grpc_pollset* pollset, gpr_mu** mu) { + *mu = &grpc_polling_mu; + pollset->root_worker.links[GRPC_POLLSET_WORKER_LINK_POLLSET].next = + pollset->root_worker.links[GRPC_POLLSET_WORKER_LINK_POLLSET].prev = + &pollset->root_worker; +} + +static void pollset_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + pollset->shutting_down = 1; + grpc_pollset_kick(pollset, GRPC_POLLSET_KICK_BROADCAST); + if (!pollset->is_iocp_worker) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + } else { + pollset->on_shutdown = closure; + } +} + +static void pollset_destroy(grpc_pollset* pollset) {} + +static grpc_error_handle pollset_work(grpc_pollset* pollset, + grpc_pollset_worker** worker_hdl, + grpc_millis deadline) { + grpc_pollset_worker worker; + if (worker_hdl) *worker_hdl = &worker; + + int added_worker = 0; + worker.links[GRPC_POLLSET_WORKER_LINK_POLLSET].next = + worker.links[GRPC_POLLSET_WORKER_LINK_POLLSET].prev = + worker.links[GRPC_POLLSET_WORKER_LINK_GLOBAL].next = + worker.links[GRPC_POLLSET_WORKER_LINK_GLOBAL].prev = NULL; + worker.kicked = 0; + worker.pollset = pollset; + gpr_cv_init(&worker.cv); + if (!pollset->kicked_without_pollers && !pollset->shutting_down) { + if (g_active_poller == NULL) { + grpc_pollset_worker* next_worker; + /* become poller */ + pollset->is_iocp_worker = 1; + g_active_poller = &worker; + gpr_mu_unlock(&grpc_polling_mu); + grpc_iocp_work(deadline); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&grpc_polling_mu); + pollset->is_iocp_worker = 0; + g_active_poller = NULL; + /* try to get a worker from this pollsets worker list */ + next_worker = pop_front_worker(&pollset->root_worker, + GRPC_POLLSET_WORKER_LINK_POLLSET); + if (next_worker == NULL) { + /* try to get a worker from the global list */ + next_worker = pop_front_worker(&g_global_root_worker, + GRPC_POLLSET_WORKER_LINK_GLOBAL); + } + if (next_worker != NULL) { + next_worker->kicked = 1; + gpr_cv_signal(&next_worker->cv); + } + + if (pollset->shutting_down && pollset->on_shutdown != NULL) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, pollset->on_shutdown, + GRPC_ERROR_NONE); + pollset->on_shutdown = NULL; + } + goto done; + } + push_front_worker(&g_global_root_worker, GRPC_POLLSET_WORKER_LINK_GLOBAL, + &worker); + push_front_worker(&pollset->root_worker, GRPC_POLLSET_WORKER_LINK_POLLSET, + &worker); + added_worker = 1; + while (!worker.kicked) { + if (gpr_cv_wait(&worker.cv, &grpc_polling_mu, + grpc_millis_to_timespec(deadline, GPR_CLOCK_REALTIME))) { + grpc_core::ExecCtx::Get()->InvalidateNow(); + break; + } + grpc_core::ExecCtx::Get()->InvalidateNow(); + } + } else { + pollset->kicked_without_pollers = 0; + } +done: + if (!grpc_closure_list_empty(*grpc_core::ExecCtx::Get()->closure_list())) { + gpr_mu_unlock(&grpc_polling_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&grpc_polling_mu); + } + if (added_worker) { + remove_worker(&worker, GRPC_POLLSET_WORKER_LINK_GLOBAL); + remove_worker(&worker, GRPC_POLLSET_WORKER_LINK_POLLSET); + } + gpr_cv_destroy(&worker.cv); + if (worker_hdl) *worker_hdl = NULL; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle pollset_kick(grpc_pollset* p, + grpc_pollset_worker* specific_worker) { + bool should_kick_global = false; + if (specific_worker != NULL) { + if (specific_worker == GRPC_POLLSET_KICK_BROADCAST) { + should_kick_global = true; + for (specific_worker = + p->root_worker.links[GRPC_POLLSET_WORKER_LINK_POLLSET].next; + specific_worker != &p->root_worker; + specific_worker = + 
specific_worker->links[GRPC_POLLSET_WORKER_LINK_POLLSET].next) { + specific_worker->kicked = 1; + should_kick_global = false; + gpr_cv_signal(&specific_worker->cv); + } + p->kicked_without_pollers = 1; + if (p->is_iocp_worker) { + grpc_iocp_kick(); + should_kick_global = false; + } + } else { + if (p->is_iocp_worker && g_active_poller == specific_worker) { + grpc_iocp_kick(); + } else { + specific_worker->kicked = 1; + gpr_cv_signal(&specific_worker->cv); + } + } + } else { + specific_worker = + pop_front_worker(&p->root_worker, GRPC_POLLSET_WORKER_LINK_POLLSET); + if (specific_worker != NULL) { + grpc_pollset_kick(p, specific_worker); + } else if (p->is_iocp_worker) { + grpc_iocp_kick(); + } else { + p->kicked_without_pollers = 1; + should_kick_global = true; + } + } + if (should_kick_global && g_active_poller == NULL) { + grpc_pollset_worker* next_global_worker = pop_front_worker( + &g_global_root_worker, GRPC_POLLSET_WORKER_LINK_GLOBAL); + if (next_global_worker != NULL) { + next_global_worker->kicked = 1; + gpr_cv_signal(&next_global_worker->cv); + } + } + return GRPC_ERROR_NONE; +} + +grpc_pollset_vtable grpc_windows_pollset_vtable = { + pollset_global_init, pollset_global_shutdown, + pollset_init, pollset_shutdown, + pollset_destroy, pollset_work, + pollset_kick, pollset_size}; + +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/resolve_address.cc b/src/core/lib/iomgr/resolve_address.cc new file mode 100644 index 00000000..a2e159a6 --- /dev/null +++ b/src/core/lib/iomgr/resolve_address.cc @@ -0,0 +1,55 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include "src/core/lib/iomgr/resolve_address.h" + +#include +#include + +namespace grpc_core { +const char* kDefaultSecurePort = "https"; +} // namespace grpc_core + +grpc_address_resolver_vtable* grpc_resolve_address_impl; + +void grpc_set_resolver_impl(grpc_address_resolver_vtable* vtable) { + grpc_resolve_address_impl = vtable; +} + +void grpc_resolve_address(const char* addr, const char* default_port, + grpc_pollset_set* interested_parties, + grpc_closure* on_done, + grpc_resolved_addresses** addresses) { + grpc_resolve_address_impl->resolve_address( + addr, default_port, interested_parties, on_done, addresses); +} + +void grpc_resolved_addresses_destroy(grpc_resolved_addresses* addresses) { + if (addresses != nullptr) { + gpr_free(addresses->addrs); + } + gpr_free(addresses); +} + +grpc_error_handle grpc_blocking_resolve_address( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + return grpc_resolve_address_impl->blocking_resolve_address(name, default_port, + addresses); +} diff --git a/src/core/lib/iomgr/resolve_address_custom.cc b/src/core/lib/iomgr/resolve_address_custom.cc new file mode 100644 index 00000000..6d8c599c --- /dev/null +++ b/src/core/lib/iomgr/resolve_address_custom.cc @@ -0,0 +1,169 @@ +/* + * + * Copyright 2018 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/resolve_address_custom.h" + +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/iomgr_custom.h" +#include "src/core/lib/iomgr/port.h" + +struct grpc_custom_resolver { + grpc_closure* on_done = nullptr; + grpc_resolved_addresses** addresses = nullptr; + std::string host; + std::string port; +}; + +static grpc_custom_resolver_vtable* resolve_address_vtable = nullptr; + +static int retry_named_port_failure(grpc_custom_resolver* r, + grpc_resolved_addresses** res) { + // This loop is copied from resolve_address_posix.c + const char* svc[][2] = {{"http", "80"}, {"https", "443"}}; + for (size_t i = 0; i < GPR_ARRAY_SIZE(svc); i++) { + if (r->port == svc[i][0]) { + r->port = svc[i][1]; + if (res) { + grpc_error_handle error = resolve_address_vtable->resolve( + r->host.c_str(), r->port.c_str(), res); + if (error != GRPC_ERROR_NONE) { + GRPC_ERROR_UNREF(error); + return 0; + } + } else { + resolve_address_vtable->resolve_async(r, r->host.c_str(), + r->port.c_str()); + } + return 1; + } + } + return 0; +} + +void grpc_custom_resolve_callback(grpc_custom_resolver* r, + grpc_resolved_addresses* result, + grpc_error_handle error) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + if (error == GRPC_ERROR_NONE) { + *r->addresses = result; + } else if (retry_named_port_failure(r, nullptr)) { + return; + } + if (r->on_done) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_done, error); + } + delete r; +} + +static grpc_error_handle try_split_host_port(const char* name, + const char* default_port, + std::string* host, + std::string* port) { + /* parse name, splitting it into host and port parts */ + grpc_core::SplitHostPort(name, host, port); + if (host->empty()) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("unparseable host:port: '%s'", name)); + } + if (port->empty()) { + // TODO(murgatroid99): add tests for this case + if (default_port == nullptr) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("no port in name '%s'", name)); + } + *port = default_port; + } + return GRPC_ERROR_NONE; +} + +static grpc_error_handle blocking_resolve_address_impl( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + + grpc_custom_resolver resolver; + grpc_error_handle err = + try_split_host_port(name, default_port, &resolver.host, &resolver.port); + if (err != GRPC_ERROR_NONE) { + return err; + } + + /* Call getaddrinfo */ + grpc_resolved_addresses* addrs; + grpc_core::ExecCtx* curr = grpc_core::ExecCtx::Get(); + grpc_core::ExecCtx::Set(nullptr); + err = 
resolve_address_vtable->resolve(resolver.host.c_str(), + resolver.port.c_str(), &addrs); + if (err != GRPC_ERROR_NONE) { + if (retry_named_port_failure(&resolver, &addrs)) { + GRPC_ERROR_UNREF(err); + err = GRPC_ERROR_NONE; + } + } + grpc_core::ExecCtx::Set(curr); + if (err == GRPC_ERROR_NONE) { + *addresses = addrs; + } + return err; +} + +static void resolve_address_impl(const char* name, const char* default_port, + grpc_pollset_set* /*interested_parties*/, + grpc_closure* on_done, + grpc_resolved_addresses** addrs) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + std::string host; + std::string port; + grpc_error_handle err = try_split_host_port(name, default_port, &host, &port); + if (err != GRPC_ERROR_NONE) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, err); + return; + } + grpc_custom_resolver* r = new grpc_custom_resolver(); + r->on_done = on_done; + r->addresses = addrs; + r->host = std::move(host); + r->port = std::move(port); + + /* Call getaddrinfo */ + resolve_address_vtable->resolve_async(r, r->host.c_str(), r->port.c_str()); +} + +static grpc_address_resolver_vtable custom_resolver_vtable = { + resolve_address_impl, blocking_resolve_address_impl}; + +void grpc_custom_resolver_init(grpc_custom_resolver_vtable* impl) { + resolve_address_vtable = impl; + grpc_set_resolver_impl(&custom_resolver_vtable); +} diff --git a/src/core/lib/iomgr/resolve_address_posix.cc b/src/core/lib/iomgr/resolve_address_posix.cc new file mode 100644 index 00000000..1427cbda --- /dev/null +++ b/src/core/lib/iomgr/resolve_address_posix.cc @@ -0,0 +1,170 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" +#ifdef GRPC_POSIX_SOCKET_RESOLVE_ADDRESS + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/block_annotate.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" + +static grpc_error_handle posix_blocking_resolve_address( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + grpc_core::ExecCtx exec_ctx; + struct addrinfo hints; + struct addrinfo *result = nullptr, *resp; + int s; + size_t i; + grpc_error_handle err; + + std::string host; + std::string port; + /* parse name, splitting it into host and port parts */ + grpc_core::SplitHostPort(name, &host, &port); + if (host.empty()) { + err = grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("unparseable host:port"), + GRPC_ERROR_STR_TARGET_ADDRESS, name); + goto done; + } + + if (port.empty()) { + if (default_port == nullptr) { + err = grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("no port in name"), + GRPC_ERROR_STR_TARGET_ADDRESS, name); + goto done; + } + port = default_port; + } + + /* Call getaddrinfo */ + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; /* ipv4 or ipv6 */ + hints.ai_socktype = SOCK_STREAM; /* stream socket */ + hints.ai_flags = AI_PASSIVE; /* for wildcard IP address */ + + GRPC_SCHEDULING_START_BLOCKING_REGION; + s = getaddrinfo(host.c_str(), port.c_str(), &hints, &result); + GRPC_SCHEDULING_END_BLOCKING_REGION; + + if (s != 0) { + /* Retry if well-known service name is recognized */ + const char* svc[][2] = {{"http", "80"}, {"https", "443"}}; + for (i = 0; i < GPR_ARRAY_SIZE(svc); i++) { + if (port == svc[i][0]) { + GRPC_SCHEDULING_START_BLOCKING_REGION; + s = getaddrinfo(host.c_str(), svc[i][1], &hints, &result); + GRPC_SCHEDULING_END_BLOCKING_REGION; + break; + } + } + } + + if (s != 0) { + err = grpc_error_set_str( + grpc_error_set_str( + grpc_error_set_str( + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING(gai_strerror(s)), + GRPC_ERROR_INT_ERRNO, s), + GRPC_ERROR_STR_OS_ERROR, gai_strerror(s)), + GRPC_ERROR_STR_SYSCALL, "getaddrinfo"), + GRPC_ERROR_STR_TARGET_ADDRESS, name); + goto done; + } + + /* Success path: set addrs non-NULL, fill it in */ + *addresses = static_cast( + gpr_malloc(sizeof(grpc_resolved_addresses))); + (*addresses)->naddrs = 0; + for (resp = result; resp != nullptr; resp = resp->ai_next) { + (*addresses)->naddrs++; + } + (*addresses)->addrs = static_cast( + gpr_malloc(sizeof(grpc_resolved_address) * (*addresses)->naddrs)); + i = 0; + for (resp = result; resp != nullptr; resp = resp->ai_next) { + memcpy(&(*addresses)->addrs[i].addr, resp->ai_addr, resp->ai_addrlen); + (*addresses)->addrs[i].len = resp->ai_addrlen; + i++; + } + err = GRPC_ERROR_NONE; + +done: + if (result) { + freeaddrinfo(result); + } + return err; +} + +struct request { + char* name; + char* default_port; + grpc_closure* on_done; + grpc_resolved_addresses** addrs_out; + grpc_closure request_closure; + void* arg; +}; +/* Callback to be passed to grpc Executor to asynch-ify + * grpc_blocking_resolve_address */ +static void do_request_thread(void* rp, grpc_error_handle /*error*/) { + request* r = 
static_cast(rp); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, r->on_done, + grpc_blocking_resolve_address(r->name, r->default_port, r->addrs_out)); + gpr_free(r->name); + gpr_free(r->default_port); + gpr_free(r); +} + +static void posix_resolve_address(const char* name, const char* default_port, + grpc_pollset_set* /*interested_parties*/, + grpc_closure* on_done, + grpc_resolved_addresses** addrs) { + request* r = static_cast(gpr_malloc(sizeof(request))); + GRPC_CLOSURE_INIT(&r->request_closure, do_request_thread, r, nullptr); + r->name = gpr_strdup(name); + r->default_port = gpr_strdup(default_port); + r->on_done = on_done; + r->addrs_out = addrs; + grpc_core::Executor::Run(&r->request_closure, GRPC_ERROR_NONE, + grpc_core::ExecutorType::RESOLVER); +} + +grpc_address_resolver_vtable grpc_posix_resolver_vtable = { + posix_resolve_address, posix_blocking_resolve_address}; +#endif diff --git a/src/core/lib/iomgr/resolve_address_windows.cc b/src/core/lib/iomgr/resolve_address_windows.cc new file mode 100644 index 00000000..8cea3c1e --- /dev/null +++ b/src/core/lib/iomgr/resolve_address_windows.cc @@ -0,0 +1,152 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" +#ifdef GRPC_WINSOCK_SOCKET + +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/block_annotate.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" + +struct request { + char* name; + char* default_port; + grpc_closure request_closure; + grpc_closure* on_done; + grpc_resolved_addresses** addresses; +}; +static grpc_error_handle windows_blocking_resolve_address( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + grpc_core::ExecCtx exec_ctx; + struct addrinfo hints; + struct addrinfo *result = NULL, *resp; + int s; + size_t i; + grpc_error_handle error = GRPC_ERROR_NONE; + + /* parse name, splitting it into host and port parts */ + std::string host; + std::string port; + grpc_core::SplitHostPort(name, &host, &port); + if (host.empty()) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("unparseable host:port: '%s'", name)); + goto done; + } + if (port.empty()) { + if (default_port == NULL) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("no port in name '%s'", name)); + goto done; + } + port = default_port; + } + + /* Call getaddrinfo */ + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; /* ipv4 or ipv6 */ + hints.ai_socktype = SOCK_STREAM; /* stream socket */ + hints.ai_flags = AI_PASSIVE; /* for wildcard IP address */ + + 
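+  /* As in the POSIX resolver earlier in this change, the getaddrinfo() call
+     below is bracketed by GRPC_SCHEDULING_START/END_BLOCKING_REGION to mark a
+     region where the thread may block in a system call. Unlike the POSIX
+     path, this Windows path does not retry with the well-known service ports
+     ("http" -> 80, "https" -> 443) when the lookup fails; it surfaces the WSA
+     error directly. */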
GRPC_SCHEDULING_START_BLOCKING_REGION; + s = getaddrinfo(host.c_str(), port.c_str(), &hints, &result); + GRPC_SCHEDULING_END_BLOCKING_REGION; + if (s != 0) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "getaddrinfo"); + goto done; + } + + /* Success path: set addrs non-NULL, fill it in */ + (*addresses) = + (grpc_resolved_addresses*)gpr_malloc(sizeof(grpc_resolved_addresses)); + (*addresses)->naddrs = 0; + for (resp = result; resp != NULL; resp = resp->ai_next) { + (*addresses)->naddrs++; + } + (*addresses)->addrs = (grpc_resolved_address*)gpr_malloc( + sizeof(grpc_resolved_address) * (*addresses)->naddrs); + i = 0; + for (resp = result; resp != NULL; resp = resp->ai_next) { + memcpy(&(*addresses)->addrs[i].addr, resp->ai_addr, resp->ai_addrlen); + (*addresses)->addrs[i].len = resp->ai_addrlen; + i++; + } + +done: + if (result) { + freeaddrinfo(result); + } + return error; +} + +/* Callback to be passed to grpc_executor to asynch-ify + * grpc_blocking_resolve_address */ +static void do_request_thread(void* rp, grpc_error_handle error) { + request* r = (request*)rp; + if (error == GRPC_ERROR_NONE) { + error = + grpc_blocking_resolve_address(r->name, r->default_port, r->addresses); + } else { + GRPC_ERROR_REF(error); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_done, error); + gpr_free(r->name); + gpr_free(r->default_port); + gpr_free(r); +} + +static void windows_resolve_address(const char* name, const char* default_port, + grpc_pollset_set* interested_parties, + grpc_closure* on_done, + grpc_resolved_addresses** addresses) { + request* r = (request*)gpr_malloc(sizeof(request)); + GRPC_CLOSURE_INIT(&r->request_closure, do_request_thread, r, nullptr); + r->name = gpr_strdup(name); + r->default_port = gpr_strdup(default_port); + r->on_done = on_done; + r->addresses = addresses; + grpc_core::Executor::Run(&r->request_closure, GRPC_ERROR_NONE, + grpc_core::ExecutorType::RESOLVER); +} + +grpc_address_resolver_vtable grpc_windows_resolver_vtable = { + windows_resolve_address, windows_blocking_resolve_address}; +#endif diff --git a/src/core/lib/iomgr/resource_quota.cc b/src/core/lib/iomgr/resource_quota.cc new file mode 100644 index 00000000..9f1ed409 --- /dev/null +++ b/src/core/lib/iomgr/resource_quota.cc @@ -0,0 +1,1106 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/resource_quota.h" + +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/slice/slice_internal.h" + +grpc_core::TraceFlag grpc_resource_quota_trace(false, "resource_quota"); + +#define MEMORY_USAGE_ESTIMATION_MAX 65536 + +/* Internal linked list pointers for a resource user */ +struct grpc_resource_user_link { + grpc_resource_user* next; + grpc_resource_user* prev; +}; +/* Resource users are kept in (potentially) several intrusive linked lists + at once. These are the list names. */ +typedef enum { + /* Resource users that are waiting for an allocation */ + GRPC_RULIST_AWAITING_ALLOCATION, + /* Resource users that have free memory available for internal reclamation */ + GRPC_RULIST_NON_EMPTY_FREE_POOL, + /* Resource users that have published a benign reclamation is available */ + GRPC_RULIST_RECLAIMER_BENIGN, + /* Resource users that have published a destructive reclamation is + available */ + GRPC_RULIST_RECLAIMER_DESTRUCTIVE, + /* Number of lists: must be last */ + GRPC_RULIST_COUNT +} grpc_rulist; + +struct grpc_resource_user { + /* The quota this resource user consumes from */ + grpc_resource_quota* resource_quota; + + /* Closure to schedule an allocation under the resource quota combiner lock */ + grpc_closure allocate_closure; + /* Closure to publish a non empty free pool under the resource quota combiner + lock */ + grpc_closure add_to_free_pool_closure; + + /* one ref for each ref call (released by grpc_resource_user_unref), and one + ref for each byte allocated (released by grpc_resource_user_free) */ + gpr_atm refs; + /* is this resource user unlocked? starts at 0, increases for each shutdown + call */ + gpr_atm shutdown; + + gpr_mu mu; + /* The amount of memory (in bytes) this user has cached for its own use: to + avoid quota contention, each resource user can keep some memory in + addition to what it is immediately using (e.g., for caching), and the quota + can pull it back under memory pressure. + This value can become negative if more memory has been requested than + existed in the free pool, at which point the quota is consulted to bring + this value non-negative (asynchronously). */ + int64_t free_pool; + /* A list of closures to call once free_pool becomes non-negative - ie when + all outstanding allocations have been granted. */ + grpc_closure_list on_allocated; + /* True if we are currently trying to allocate from the quota, false if not */ + bool allocating; + /* The amount of memory (in bytes) that has been requested from this user + * asynchronously but hasn't been granted yet. 
*/ + int64_t outstanding_allocations; + /* True if we are currently trying to add ourselves to the non-free quota + list, false otherwise */ + bool added_to_free_pool; + + /* The number of threads currently allocated to this resource user */ + gpr_atm num_threads_allocated; + + /* Reclaimers: index 0 is the benign reclaimer, 1 is the destructive reclaimer + */ + grpc_closure* reclaimers[2]; + /* Reclaimers just posted: once we're in the combiner lock, we'll move them + to the array above */ + grpc_closure* new_reclaimers[2]; + /* Trampoline closures to finish reclamation and re-enter the quota combiner + lock */ + grpc_closure post_reclaimer_closure[2]; + + /* Closure to execute under the quota combiner to de-register and shutdown the + resource user */ + grpc_closure destroy_closure; + + /* Links in the various grpc_rulist lists */ + grpc_resource_user_link links[GRPC_RULIST_COUNT]; + + /* The name of this resource user, for debugging/tracing */ + std::string name; +}; + +struct grpc_resource_quota { + /* refcount */ + gpr_refcount refs; + + /* estimate of current memory usage + scaled to the range [0..RESOURCE_USAGE_ESTIMATION_MAX] */ + gpr_atm memory_usage_estimation; + + /* Main combiner lock: all activity on a quota executes under this combiner + * (so no mutex is needed for this data structure) */ + grpc_core::Combiner* combiner; + /* Size of the resource quota */ + int64_t size; + /* Amount of free memory in the resource quota */ + int64_t free_pool; + /* Used size of memory in the resource quota. Updated as soon as the resource + * users start to allocate or free the memory. */ + gpr_atm used; + + gpr_atm last_size; + + /* Mutex to protect max_threads and num_threads_allocated */ + /* Note: We could have used gpr_atm for max_threads and num_threads_allocated + * and avoid having this mutex; but in that case, each invocation of the + * function grpc_resource_user_allocate_threads() would have had to do at + * least two atomic loads (for max_threads and num_threads_allocated) followed + * by a CAS (on num_threads_allocated). + * Moreover, we expect grpc_resource_user_allocate_threads() to be often + * called concurrently thereby increasing the chances of failing the CAS + * operation. This additional complexity is not worth the tiny perf gain we + * may (or may not) have by using atomics */ + gpr_mu thread_count_mu; + + /* Max number of threads allowed */ + int max_threads; + + /* Number of threads currently allocated via this resource_quota object */ + int num_threads_allocated; + + /* Has rq_step been scheduled to occur? 
*/ + bool step_scheduled; + + /* Are we currently reclaiming memory */ + bool reclaiming; + + /* Closure around rq_step */ + grpc_closure rq_step_closure; + + /* Closure around rq_reclamation_done */ + grpc_closure rq_reclamation_done_closure; + + /* This is only really usable for debugging: it's always a stale pointer, but + a stale pointer that might just be fresh enough to guide us to where the + reclamation system is stuck */ + grpc_closure* debug_only_last_initiated_reclaimer; + grpc_resource_user* debug_only_last_reclaimer_resource_user; + + /* Roots of all resource user lists */ + grpc_resource_user* roots[GRPC_RULIST_COUNT]; + + std::string name; +}; + +static void ru_unref_by(grpc_resource_user* resource_user, gpr_atm amount); + +/******************************************************************************* + * list management + */ + +static void rulist_add_head(grpc_resource_user* resource_user, + grpc_rulist list) { + grpc_resource_quota* resource_quota = resource_user->resource_quota; + grpc_resource_user** root = &resource_quota->roots[list]; + if (*root == nullptr) { + *root = resource_user; + resource_user->links[list].next = resource_user->links[list].prev = + resource_user; + } else { + resource_user->links[list].next = *root; + resource_user->links[list].prev = (*root)->links[list].prev; + resource_user->links[list].next->links[list].prev = + resource_user->links[list].prev->links[list].next = resource_user; + *root = resource_user; + } +} + +static void rulist_add_tail(grpc_resource_user* resource_user, + grpc_rulist list) { + grpc_resource_quota* resource_quota = resource_user->resource_quota; + grpc_resource_user** root = &resource_quota->roots[list]; + if (*root == nullptr) { + *root = resource_user; + resource_user->links[list].next = resource_user->links[list].prev = + resource_user; + } else { + resource_user->links[list].next = (*root)->links[list].next; + resource_user->links[list].prev = *root; + resource_user->links[list].next->links[list].prev = + resource_user->links[list].prev->links[list].next = resource_user; + } +} + +static bool rulist_empty(grpc_resource_quota* resource_quota, + grpc_rulist list) { + return resource_quota->roots[list] == nullptr; +} + +static grpc_resource_user* rulist_pop_head(grpc_resource_quota* resource_quota, + grpc_rulist list) { + grpc_resource_user** root = &resource_quota->roots[list]; + grpc_resource_user* resource_user = *root; + if (resource_user == nullptr) { + return nullptr; + } + if (resource_user->links[list].next == resource_user) { + *root = nullptr; + } else { + resource_user->links[list].next->links[list].prev = + resource_user->links[list].prev; + resource_user->links[list].prev->links[list].next = + resource_user->links[list].next; + *root = resource_user->links[list].next; + } + resource_user->links[list].next = resource_user->links[list].prev = nullptr; + return resource_user; +} + +static void rulist_remove(grpc_resource_user* resource_user, grpc_rulist list) { + if (resource_user->links[list].next == nullptr) return; + grpc_resource_quota* resource_quota = resource_user->resource_quota; + if (resource_quota->roots[list] == resource_user) { + resource_quota->roots[list] = resource_user->links[list].next; + if (resource_quota->roots[list] == resource_user) { + resource_quota->roots[list] = nullptr; + } + } + resource_user->links[list].next->links[list].prev = + resource_user->links[list].prev; + resource_user->links[list].prev->links[list].next = + resource_user->links[list].next; + 
resource_user->links[list].next = resource_user->links[list].prev = nullptr; +} + +/******************************************************************************* + * resource quota state machine + */ + +static bool rq_alloc(grpc_resource_quota* resource_quota); +static bool rq_reclaim_from_per_user_free_pool( + grpc_resource_quota* resource_quota); +static bool rq_reclaim(grpc_resource_quota* resource_quota, bool destructive); + +static void rq_step(void* rq, grpc_error_handle /*error*/) { + grpc_resource_quota* resource_quota = static_cast(rq); + resource_quota->step_scheduled = false; + do { + if (rq_alloc(resource_quota)) goto done; + } while (rq_reclaim_from_per_user_free_pool(resource_quota)); + + if (!rq_reclaim(resource_quota, false)) { + rq_reclaim(resource_quota, true); + } + +done: + grpc_resource_quota_unref_internal(resource_quota); +} + +static void rq_step_sched(grpc_resource_quota* resource_quota) { + if (resource_quota->step_scheduled) return; + resource_quota->step_scheduled = true; + grpc_resource_quota_ref_internal(resource_quota); + resource_quota->combiner->FinallyRun(&resource_quota->rq_step_closure, + GRPC_ERROR_NONE); +} + +/* update the atomically available resource estimate - use no barriers since + timeliness of delivery really doesn't matter much */ +static void rq_update_estimate(grpc_resource_quota* resource_quota) { + gpr_atm memory_usage_estimation = MEMORY_USAGE_ESTIMATION_MAX; + if (resource_quota->size != 0) { + memory_usage_estimation = grpc_core::Clamp( + static_cast( + (1.0 - (static_cast(resource_quota->free_pool)) / + (static_cast(resource_quota->size))) * + MEMORY_USAGE_ESTIMATION_MAX), + gpr_atm(0), gpr_atm(MEMORY_USAGE_ESTIMATION_MAX)); + } + gpr_atm_no_barrier_store(&resource_quota->memory_usage_estimation, + memory_usage_estimation); +} + +/* returns true if all allocations are completed */ +static bool rq_alloc(grpc_resource_quota* resource_quota) { + grpc_resource_user* resource_user; + while ((resource_user = rulist_pop_head(resource_quota, + GRPC_RULIST_AWAITING_ALLOCATION))) { + gpr_mu_lock(&resource_user->mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, + "RQ: check allocation for user %p shutdown=%" PRIdPTR + " free_pool=%" PRId64 " outstanding_allocations=%" PRId64, + resource_user, gpr_atm_no_barrier_load(&resource_user->shutdown), + resource_user->free_pool, resource_user->outstanding_allocations); + } + if (gpr_atm_no_barrier_load(&resource_user->shutdown)) { + resource_user->allocating = false; + grpc_closure_list_fail_all( + &resource_user->on_allocated, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource user shutdown")); + int64_t aborted_allocations = resource_user->outstanding_allocations; + resource_user->outstanding_allocations = 0; + resource_user->free_pool += aborted_allocations; + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &resource_user->on_allocated); + gpr_mu_unlock(&resource_user->mu); + if (aborted_allocations > 0) { + ru_unref_by(resource_user, static_cast(aborted_allocations)); + } + continue; + } + if (resource_user->free_pool < 0 && + -resource_user->free_pool <= resource_quota->free_pool) { + int64_t amt = -resource_user->free_pool; + resource_user->free_pool = 0; + resource_quota->free_pool -= amt; + rq_update_estimate(resource_quota); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, + "RQ %s %s: grant alloc %" PRId64 + " bytes; rq_free_pool -> %" PRId64, + resource_quota->name.c_str(), resource_user->name.c_str(), amt, + 
resource_quota->free_pool); + } + } else if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace) && + resource_user->free_pool >= 0) { + gpr_log(GPR_INFO, "RQ %s %s: discard already satisfied alloc request", + resource_quota->name.c_str(), resource_user->name.c_str()); + } + if (resource_user->free_pool >= 0) { + resource_user->allocating = false; + resource_user->outstanding_allocations = 0; + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &resource_user->on_allocated); + gpr_mu_unlock(&resource_user->mu); + } else { + rulist_add_head(resource_user, GRPC_RULIST_AWAITING_ALLOCATION); + gpr_mu_unlock(&resource_user->mu); + return false; + } + } + return true; +} + +/* returns true if any memory could be reclaimed from buffers */ +static bool rq_reclaim_from_per_user_free_pool( + grpc_resource_quota* resource_quota) { + grpc_resource_user* resource_user; + while ((resource_user = rulist_pop_head(resource_quota, + GRPC_RULIST_NON_EMPTY_FREE_POOL))) { + gpr_mu_lock(&resource_user->mu); + resource_user->added_to_free_pool = false; + if (resource_user->free_pool > 0) { + int64_t amt = resource_user->free_pool; + resource_user->free_pool = 0; + resource_quota->free_pool += amt; + rq_update_estimate(resource_quota); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, + "RQ %s %s: reclaim_from_per_user_free_pool %" PRId64 + " bytes; rq_free_pool -> %" PRId64, + resource_quota->name.c_str(), resource_user->name.c_str(), amt, + resource_quota->free_pool); + } + gpr_mu_unlock(&resource_user->mu); + return true; + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, + "RQ %s %s: failed to reclaim_from_per_user_free_pool; " + "free_pool = %" PRId64 "; rq_free_pool = %" PRId64, + resource_quota->name.c_str(), resource_user->name.c_str(), + resource_user->free_pool, resource_quota->free_pool); + } + gpr_mu_unlock(&resource_user->mu); + } + } + return false; +} + +/* returns true if reclamation is proceeding */ +static bool rq_reclaim(grpc_resource_quota* resource_quota, bool destructive) { + if (resource_quota->reclaiming) return true; + grpc_rulist list = destructive ? GRPC_RULIST_RECLAIMER_DESTRUCTIVE + : GRPC_RULIST_RECLAIMER_BENIGN; + grpc_resource_user* resource_user = rulist_pop_head(resource_quota, list); + if (resource_user == nullptr) return false; + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RQ %s %s: initiate %s reclamation", + resource_quota->name.c_str(), resource_user->name.c_str(), + destructive ? "destructive" : "benign"); + } + resource_quota->reclaiming = true; + grpc_resource_quota_ref_internal(resource_quota); + grpc_closure* c = resource_user->reclaimers[destructive]; + GPR_ASSERT(c); + resource_quota->debug_only_last_reclaimer_resource_user = resource_user; + resource_quota->debug_only_last_initiated_reclaimer = c; + resource_user->reclaimers[destructive] = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, c, GRPC_ERROR_NONE); + return true; +} + +/******************************************************************************* + * ru_slice: a slice implementation that is backed by a grpc_resource_user + */ + +namespace grpc_core { + +class RuSliceRefcount { + public: + static void Destroy(void* p) { + auto* rc = static_cast(p); + rc->~RuSliceRefcount(); + gpr_free(rc); + } + RuSliceRefcount(grpc_resource_user* resource_user, size_t size) + : base_(grpc_slice_refcount::Type::REGULAR, &refs_, Destroy, this, + &base_), + resource_user_(resource_user), + size_(size) { + // Nothing to do here. 
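+    // The slice payload lives in the same gpr_malloc() block as this
+    // refcount: ru_slice_create() below allocates sizeof(RuSliceRefcount) +
+    // size and points the slice data at rc + 1. The destructor hands those
+    // size_ bytes back to the resource user.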
+ } + ~RuSliceRefcount() { grpc_resource_user_free(resource_user_, size_); } + + grpc_slice_refcount* base_refcount() { return &base_; } + + private: + grpc_slice_refcount base_; + RefCount refs_; + grpc_resource_user* resource_user_; + size_t size_; +}; + +} // namespace grpc_core + +static grpc_slice ru_slice_create(grpc_resource_user* resource_user, + size_t size) { + auto* rc = static_cast( + gpr_malloc(sizeof(grpc_core::RuSliceRefcount) + size)); + new (rc) grpc_core::RuSliceRefcount(resource_user, size); + grpc_slice slice; + + slice.refcount = rc->base_refcount(); + slice.data.refcounted.bytes = reinterpret_cast(rc + 1); + slice.data.refcounted.length = size; + return slice; +} + +/******************************************************************************* + * grpc_resource_quota internal implementation: resource user manipulation under + * the combiner + */ + +// TODO(hork): rename all ru variables to resource_user +static void ru_allocate(void* ru, grpc_error_handle /*error*/) { + grpc_resource_user* resource_user = static_cast(ru); + if (rulist_empty(resource_user->resource_quota, + GRPC_RULIST_AWAITING_ALLOCATION)) { + rq_step_sched(resource_user->resource_quota); + } + rulist_add_tail(resource_user, GRPC_RULIST_AWAITING_ALLOCATION); +} + +static void ru_add_to_free_pool(void* ru, grpc_error_handle /*error*/) { + grpc_resource_user* resource_user = static_cast(ru); + if (!rulist_empty(resource_user->resource_quota, + GRPC_RULIST_AWAITING_ALLOCATION) && + rulist_empty(resource_user->resource_quota, + GRPC_RULIST_NON_EMPTY_FREE_POOL)) { + rq_step_sched(resource_user->resource_quota); + } + rulist_add_tail(resource_user, GRPC_RULIST_NON_EMPTY_FREE_POOL); +} + +static bool ru_post_reclaimer(grpc_resource_user* resource_user, + bool destructive) { + grpc_closure* closure = resource_user->new_reclaimers[destructive]; + GPR_ASSERT(closure != nullptr); + resource_user->new_reclaimers[destructive] = nullptr; + GPR_ASSERT(resource_user->reclaimers[destructive] == nullptr); + if (gpr_atm_acq_load(&resource_user->shutdown) > 0) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_CANCELLED); + return false; + } + resource_user->reclaimers[destructive] = closure; + return true; +} + +static void ru_post_benign_reclaimer(void* ru, grpc_error_handle /*error*/) { + grpc_resource_user* resource_user = static_cast(ru); + if (!ru_post_reclaimer(resource_user, false)) return; + if (!rulist_empty(resource_user->resource_quota, + GRPC_RULIST_AWAITING_ALLOCATION) && + rulist_empty(resource_user->resource_quota, + GRPC_RULIST_NON_EMPTY_FREE_POOL) && + rulist_empty(resource_user->resource_quota, + GRPC_RULIST_RECLAIMER_BENIGN)) { + rq_step_sched(resource_user->resource_quota); + } + rulist_add_tail(resource_user, GRPC_RULIST_RECLAIMER_BENIGN); +} + +static void ru_post_destructive_reclaimer(void* ru, + grpc_error_handle /*error*/) { + grpc_resource_user* resource_user = static_cast(ru); + if (!ru_post_reclaimer(resource_user, true)) return; + if (!rulist_empty(resource_user->resource_quota, + GRPC_RULIST_AWAITING_ALLOCATION) && + rulist_empty(resource_user->resource_quota, + GRPC_RULIST_NON_EMPTY_FREE_POOL) && + rulist_empty(resource_user->resource_quota, + GRPC_RULIST_RECLAIMER_BENIGN) && + rulist_empty(resource_user->resource_quota, + GRPC_RULIST_RECLAIMER_DESTRUCTIVE)) { + rq_step_sched(resource_user->resource_quota); + } + rulist_add_tail(resource_user, GRPC_RULIST_RECLAIMER_DESTRUCTIVE); +} + +static void ru_shutdown(void* ru, grpc_error_handle /*error*/) { + if 
(GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RU shutdown %p", ru); + } + grpc_resource_user* resource_user = static_cast(ru); + gpr_mu_lock(&resource_user->mu); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, resource_user->reclaimers[0], + GRPC_ERROR_CANCELLED); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, resource_user->reclaimers[1], + GRPC_ERROR_CANCELLED); + resource_user->reclaimers[0] = nullptr; + resource_user->reclaimers[1] = nullptr; + rulist_remove(resource_user, GRPC_RULIST_RECLAIMER_BENIGN); + rulist_remove(resource_user, GRPC_RULIST_RECLAIMER_DESTRUCTIVE); + if (resource_user->allocating) { + rq_step_sched(resource_user->resource_quota); + } + gpr_mu_unlock(&resource_user->mu); +} + +static void ru_destroy(void* ru, grpc_error_handle /*error*/) { + grpc_resource_user* resource_user = static_cast(ru); + GPR_ASSERT(gpr_atm_no_barrier_load(&resource_user->refs) == 0); + // Free all the remaining thread quota + grpc_resource_user_free_threads(resource_user, + static_cast(gpr_atm_no_barrier_load( + &resource_user->num_threads_allocated))); + + for (int i = 0; i < GRPC_RULIST_COUNT; i++) { + rulist_remove(resource_user, static_cast(i)); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, resource_user->reclaimers[0], + GRPC_ERROR_CANCELLED); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, resource_user->reclaimers[1], + GRPC_ERROR_CANCELLED); + if (resource_user->free_pool != 0) { + resource_user->resource_quota->free_pool += resource_user->free_pool; + rq_step_sched(resource_user->resource_quota); + } + grpc_resource_quota_unref_internal(resource_user->resource_quota); + gpr_mu_destroy(&resource_user->mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RU '%s' (%p) destroyed", resource_user->name.c_str(), + resource_user); + } + delete resource_user; +} + +static void ru_alloc_slices(grpc_slice_allocator* slice_allocator) { + for (size_t i = 0; i < slice_allocator->count; i++) { + grpc_slice_buffer_add_indexed( + slice_allocator->dest, ru_slice_create(slice_allocator->resource_user, + slice_allocator->length)); + } +} + +static void ru_allocated_slices(void* arg, grpc_error_handle error) { + grpc_slice_allocator* slice_allocator = + static_cast(arg); + if (error == GRPC_ERROR_NONE) ru_alloc_slices(slice_allocator); + grpc_core::Closure::Run(DEBUG_LOCATION, &slice_allocator->on_done, + GRPC_ERROR_REF(error)); +} + +/******************************************************************************* + * grpc_resource_quota internal implementation: quota manipulation under the + * combiner + */ + +struct rq_resize_args { + int64_t size; + grpc_resource_quota* resource_quota; + grpc_closure closure; +}; +static void rq_resize(void* args, grpc_error_handle /*error*/) { + rq_resize_args* a = static_cast(args); + int64_t delta = a->size - a->resource_quota->size; + a->resource_quota->size += delta; + a->resource_quota->free_pool += delta; + rq_update_estimate(a->resource_quota); + rq_step_sched(a->resource_quota); + grpc_resource_quota_unref_internal(a->resource_quota); + gpr_free(a); +} + +static void rq_reclamation_done(void* rq, grpc_error_handle /*error*/) { + grpc_resource_quota* resource_quota = static_cast(rq); + resource_quota->reclaiming = false; + rq_step_sched(resource_quota); + grpc_resource_quota_unref_internal(resource_quota); +} + +/******************************************************************************* + * grpc_resource_quota api + */ + +/* Public API */ +grpc_resource_quota* grpc_resource_quota_create(const char* name) { + 
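+  /* A freshly created quota is effectively unlimited: size and free_pool both
+     start at INT64_MAX, so allocations are only queued once the application
+     shrinks the quota with grpc_resource_quota_resize(). */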
grpc_resource_quota* resource_quota = new grpc_resource_quota; + gpr_ref_init(&resource_quota->refs, 1); + resource_quota->combiner = grpc_combiner_create(); + resource_quota->free_pool = INT64_MAX; + resource_quota->size = INT64_MAX; + resource_quota->used = 0; + gpr_atm_no_barrier_store(&resource_quota->last_size, GPR_ATM_MAX); + gpr_mu_init(&resource_quota->thread_count_mu); + resource_quota->max_threads = INT_MAX; + resource_quota->num_threads_allocated = 0; + resource_quota->step_scheduled = false; + resource_quota->reclaiming = false; + gpr_atm_no_barrier_store(&resource_quota->memory_usage_estimation, 0); + if (name != nullptr) { + resource_quota->name = name; + } else { + resource_quota->name = absl::StrCat( + "anonymous_pool_", reinterpret_cast(resource_quota)); + } + GRPC_CLOSURE_INIT(&resource_quota->rq_step_closure, rq_step, resource_quota, + nullptr); + GRPC_CLOSURE_INIT(&resource_quota->rq_reclamation_done_closure, + rq_reclamation_done, resource_quota, nullptr); + for (int i = 0; i < GRPC_RULIST_COUNT; i++) { + resource_quota->roots[i] = nullptr; + } + return resource_quota; +} + +void grpc_resource_quota_unref_internal(grpc_resource_quota* resource_quota) { + if (gpr_unref(&resource_quota->refs)) { + // No outstanding thread quota + GPR_ASSERT(resource_quota->num_threads_allocated == 0); + GRPC_COMBINER_UNREF(resource_quota->combiner, "resource_quota"); + gpr_mu_destroy(&resource_quota->thread_count_mu); + delete resource_quota; + } +} + +/* Public API */ +void grpc_resource_quota_unref(grpc_resource_quota* resource_quota) { + grpc_core::ExecCtx exec_ctx; + grpc_resource_quota_unref_internal(resource_quota); +} + +grpc_resource_quota* grpc_resource_quota_ref_internal( + grpc_resource_quota* resource_quota) { + gpr_ref(&resource_quota->refs); + return resource_quota; +} + +/* Public API */ +void grpc_resource_quota_ref(grpc_resource_quota* resource_quota) { + grpc_resource_quota_ref_internal(resource_quota); +} + +double grpc_resource_quota_get_memory_pressure( + grpc_resource_quota* resource_quota) { + return (static_cast(gpr_atm_no_barrier_load( + &resource_quota->memory_usage_estimation))) / + (static_cast(MEMORY_USAGE_ESTIMATION_MAX)); +} + +/* Public API */ +void grpc_resource_quota_set_max_threads(grpc_resource_quota* resource_quota, + int new_max_threads) { + GPR_ASSERT(new_max_threads >= 0); + gpr_mu_lock(&resource_quota->thread_count_mu); + resource_quota->max_threads = new_max_threads; + gpr_mu_unlock(&resource_quota->thread_count_mu); +} + +/* Public API */ +void grpc_resource_quota_resize(grpc_resource_quota* resource_quota, + size_t size) { + grpc_core::ExecCtx exec_ctx; + rq_resize_args* a = static_cast(gpr_malloc(sizeof(*a))); + a->resource_quota = grpc_resource_quota_ref_internal(resource_quota); + a->size = static_cast(size); + gpr_atm_no_barrier_store(&resource_quota->last_size, + (gpr_atm)std::min((size_t)GPR_ATM_MAX, size)); + GRPC_CLOSURE_INIT(&a->closure, rq_resize, a, grpc_schedule_on_exec_ctx); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &a->closure, GRPC_ERROR_NONE); +} + +size_t grpc_resource_quota_peek_size(grpc_resource_quota* resource_quota) { + return static_cast( + gpr_atm_no_barrier_load(&resource_quota->last_size)); +} + +/******************************************************************************* + * grpc_resource_user channel args api + */ + +grpc_resource_quota* grpc_resource_quota_from_channel_args( + const grpc_channel_args* channel_args, bool create) { + auto* resource_quota = grpc_channel_args_find_pointer( + channel_args, 
GRPC_ARG_RESOURCE_QUOTA); + if (resource_quota != nullptr) { + return grpc_resource_quota_ref_internal(resource_quota); + } + return create ? grpc_resource_quota_create(nullptr) : nullptr; +} + +static void* rq_copy(void* rq) { + grpc_resource_quota_ref(static_cast(rq)); + return rq; +} + +static void rq_destroy(void* rq) { + grpc_resource_quota_unref_internal(static_cast(rq)); +} + +static int rq_cmp(void* a, void* b) { return grpc_core::QsortCompare(a, b); } + +const grpc_arg_pointer_vtable* grpc_resource_quota_arg_vtable(void) { + static const grpc_arg_pointer_vtable vtable = {rq_copy, rq_destroy, rq_cmp}; + return &vtable; +} + +/******************************************************************************* + * grpc_resource_user api + */ + +grpc_resource_user* grpc_resource_user_create( + grpc_resource_quota* resource_quota, absl::string_view name) { + grpc_resource_user* resource_user = new grpc_resource_user; + resource_user->resource_quota = + grpc_resource_quota_ref_internal(resource_quota); + GRPC_CLOSURE_INIT(&resource_user->allocate_closure, &ru_allocate, + resource_user, nullptr); + GRPC_CLOSURE_INIT(&resource_user->add_to_free_pool_closure, + &ru_add_to_free_pool, resource_user, nullptr); + GRPC_CLOSURE_INIT(&resource_user->post_reclaimer_closure[0], + &ru_post_benign_reclaimer, resource_user, nullptr); + GRPC_CLOSURE_INIT(&resource_user->post_reclaimer_closure[1], + &ru_post_destructive_reclaimer, resource_user, nullptr); + GRPC_CLOSURE_INIT(&resource_user->destroy_closure, &ru_destroy, resource_user, + nullptr); + gpr_mu_init(&resource_user->mu); + gpr_atm_rel_store(&resource_user->refs, 1); + gpr_atm_rel_store(&resource_user->shutdown, 0); + resource_user->free_pool = 0; + grpc_closure_list_init(&resource_user->on_allocated); + resource_user->allocating = false; + resource_user->added_to_free_pool = false; + gpr_atm_no_barrier_store(&resource_user->num_threads_allocated, 0); + resource_user->reclaimers[0] = nullptr; + resource_user->reclaimers[1] = nullptr; + resource_user->new_reclaimers[0] = nullptr; + resource_user->new_reclaimers[1] = nullptr; + resource_user->outstanding_allocations = 0; + for (int i = 0; i < GRPC_RULIST_COUNT; i++) { + resource_user->links[i].next = resource_user->links[i].prev = nullptr; + } + if (!name.empty()) { + resource_user->name = std::string(name); + } else { + resource_user->name = absl::StrCat( + "anonymous_resource_user_", reinterpret_cast(resource_user)); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RU '%s' (%p) created", resource_user->name.c_str(), + resource_user); + } + return resource_user; +} + +grpc_resource_quota* grpc_resource_user_quota( + grpc_resource_user* resource_user) { + return resource_user->resource_quota; +} + +static void ru_ref_by(grpc_resource_user* resource_user, gpr_atm amount) { + GPR_ASSERT(amount > 0); + gpr_atm prior = gpr_atm_no_barrier_fetch_add(&resource_user->refs, amount); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RU '%s' (%p) reffing: %" PRIdPTR " -> %" PRIdPTR, + resource_user->name.c_str(), resource_user, prior, prior + amount); + } + GPR_ASSERT(prior != 0); +} + +static void ru_unref_by(grpc_resource_user* resource_user, gpr_atm amount) { + GPR_ASSERT(amount > 0); + gpr_atm old = gpr_atm_full_fetch_add(&resource_user->refs, -amount); + GPR_ASSERT(old >= amount); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RU '%s' (%p) unreffing: %" PRIdPTR " -> %" PRIdPTR, + 
resource_user->name.c_str(), resource_user, old, old - amount); + } + if (old == amount) { + resource_user->resource_quota->combiner->Run( + &resource_user->destroy_closure, GRPC_ERROR_NONE); + } +} + +void grpc_resource_user_ref(grpc_resource_user* resource_user) { + ru_ref_by(resource_user, 1); +} + +void grpc_resource_user_unref(grpc_resource_user* resource_user) { + ru_unref_by(resource_user, 1); +} + +void grpc_resource_user_shutdown(grpc_resource_user* resource_user) { + if (gpr_atm_full_fetch_add(&resource_user->shutdown, 1) == 0) { + resource_user->resource_quota->combiner->Run( + GRPC_CLOSURE_CREATE(ru_shutdown, resource_user, nullptr), + GRPC_ERROR_NONE); + } +} + +bool grpc_resource_user_allocate_threads(grpc_resource_user* resource_user, + int thread_count) { + GPR_ASSERT(thread_count >= 0); + bool is_success = false; + gpr_mu_lock(&resource_user->resource_quota->thread_count_mu); + grpc_resource_quota* resource_quota = resource_user->resource_quota; + if (resource_quota->num_threads_allocated + thread_count <= + resource_quota->max_threads) { + resource_quota->num_threads_allocated += thread_count; + gpr_atm_no_barrier_fetch_add(&resource_user->num_threads_allocated, + thread_count); + is_success = true; + } + gpr_mu_unlock(&resource_user->resource_quota->thread_count_mu); + return is_success; +} + +void grpc_resource_user_free_threads(grpc_resource_user* resource_user, + int thread_count) { + GPR_ASSERT(thread_count >= 0); + gpr_mu_lock(&resource_user->resource_quota->thread_count_mu); + grpc_resource_quota* resource_quota = resource_user->resource_quota; + resource_quota->num_threads_allocated -= thread_count; + int old_count = static_cast(gpr_atm_no_barrier_fetch_add( + &resource_user->num_threads_allocated, -thread_count)); + if (old_count < thread_count || resource_quota->num_threads_allocated < 0) { + gpr_log(GPR_ERROR, + "Releasing more threads (%d) than currently allocated " + "(resource_quota threads: %d, ru threads: %d)", + thread_count, resource_quota->num_threads_allocated + thread_count, + old_count); + abort(); + } + gpr_mu_unlock(&resource_user->resource_quota->thread_count_mu); +} + +static bool resource_user_alloc_locked(grpc_resource_user* resource_user, + size_t size, + grpc_closure* optional_on_done) { + ru_ref_by(resource_user, static_cast(size)); + resource_user->free_pool -= static_cast(size); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RQ %s %s: alloc %" PRIdPTR "; free_pool -> %" PRId64, + resource_user->resource_quota->name.c_str(), + resource_user->name.c_str(), size, resource_user->free_pool); + } + if (GPR_LIKELY(resource_user->free_pool >= 0)) return true; + // Slow path: We need to wait for the free pool to refill. 
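+  // The caller's closure (if any) is parked on on_allocated and
+  // allocate_closure is scheduled on the quota combiner; rq_alloc() later
+  // grants the request from the quota's free pool (running the closure), or
+  // fails it if the resource user has been shut down in the meantime.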
+ if (optional_on_done != nullptr) { + resource_user->outstanding_allocations += static_cast(size); + grpc_closure_list_append(&resource_user->on_allocated, optional_on_done, + GRPC_ERROR_NONE); + } + if (!resource_user->allocating) { + resource_user->allocating = true; + resource_user->resource_quota->combiner->Run( + &resource_user->allocate_closure, GRPC_ERROR_NONE); + } + return false; +} + +bool grpc_resource_user_safe_alloc(grpc_resource_user* resource_user, + size_t size) { + if (gpr_atm_no_barrier_load(&resource_user->shutdown)) return false; + gpr_mu_lock(&resource_user->mu); + grpc_resource_quota* resource_quota = resource_user->resource_quota; + bool cas_success; + do { + gpr_atm used = gpr_atm_no_barrier_load(&resource_quota->used); + gpr_atm new_used = used + size; + if (static_cast(new_used) > + grpc_resource_quota_peek_size(resource_quota)) { + gpr_mu_unlock(&resource_user->mu); + return false; + } + cas_success = gpr_atm_full_cas(&resource_quota->used, used, new_used); + } while (!cas_success); + resource_user_alloc_locked(resource_user, size, nullptr); + gpr_mu_unlock(&resource_user->mu); + return true; +} + +bool grpc_resource_user_alloc(grpc_resource_user* resource_user, size_t size, + grpc_closure* optional_on_done) { + // TODO(juanlishen): Maybe return immediately if shutting down. Deferring this + // because some tests become flaky after the change. + gpr_mu_lock(&resource_user->mu); + grpc_resource_quota* resource_quota = resource_user->resource_quota; + gpr_atm_no_barrier_fetch_add(&resource_quota->used, size); + const bool ret = + resource_user_alloc_locked(resource_user, size, optional_on_done); + gpr_mu_unlock(&resource_user->mu); + return ret; +} + +void grpc_resource_user_free(grpc_resource_user* resource_user, size_t size) { + gpr_mu_lock(&resource_user->mu); + grpc_resource_quota* resource_quota = resource_user->resource_quota; + gpr_atm prior = gpr_atm_no_barrier_fetch_add(&resource_quota->used, -size); + GPR_ASSERT(prior >= static_cast(size)); + bool was_zero_or_negative = resource_user->free_pool <= 0; + resource_user->free_pool += static_cast(size); + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RQ %s %s: free %" PRIdPTR "; free_pool -> %" PRId64, + resource_user->resource_quota->name.c_str(), + resource_user->name.c_str(), size, resource_user->free_pool); + } + bool is_bigger_than_zero = resource_user->free_pool > 0; + if (is_bigger_than_zero && was_zero_or_negative && + !resource_user->added_to_free_pool) { + resource_user->added_to_free_pool = true; + resource_quota->combiner->Run(&resource_user->add_to_free_pool_closure, + GRPC_ERROR_NONE); + } + gpr_mu_unlock(&resource_user->mu); + ru_unref_by(resource_user, static_cast(size)); +} + +void grpc_resource_user_post_reclaimer(grpc_resource_user* resource_user, + bool destructive, + grpc_closure* closure) { + GPR_ASSERT(resource_user->new_reclaimers[destructive] == nullptr); + resource_user->new_reclaimers[destructive] = closure; + resource_user->resource_quota->combiner->Run( + &resource_user->post_reclaimer_closure[destructive], GRPC_ERROR_NONE); +} + +void grpc_resource_user_finish_reclamation(grpc_resource_user* resource_user) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log(GPR_INFO, "RQ %s %s: reclamation complete", + resource_user->resource_quota->name.c_str(), + resource_user->name.c_str()); + } + resource_user->resource_quota->combiner->Run( + &resource_user->resource_quota->rq_reclamation_done_closure, + GRPC_ERROR_NONE); +} + 
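As an aside (not part of this patch), the allocation flow above is easiest to see end to end with a small usage sketch. It assumes the declarations live in "src/core/lib/iomgr/resource_quota.h", "src/core/lib/iomgr/closure.h", and "src/core/lib/iomgr/exec_ctx.h", and that an ExecCtx on the calling thread runs the deferred closures when it flushes; the names kAllocSize, OnAllocDone, and ResourceQuotaSketch are illustrative only.

// Minimal sketch of the resource quota / resource user API defined above.
// A fresh resource user starts with an empty free pool, so even a small first
// allocation is typically granted asynchronously by rq_alloc() under the
// quota combiner rather than returning true immediately.
#include <grpc/support/log.h>

#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/iomgr/resource_quota.h"

namespace {

constexpr size_t kAllocSize = 64 * 1024;

void OnAllocDone(void* arg, grpc_error_handle error) {
  // Runs once rq_alloc() grants the request, or with a non-OK error if the
  // resource user was shut down first (in which case the bytes were already
  // returned to the quota).
  grpc_resource_user* user = static_cast<grpc_resource_user*>(arg);
  gpr_log(GPR_INFO, "deferred allocation finished (ok=%d)",
          error == GRPC_ERROR_NONE);
  if (error == GRPC_ERROR_NONE) {
    grpc_resource_user_free(user, kAllocSize);
  }
}

void ResourceQuotaSketch() {
  grpc_core::ExecCtx exec_ctx;  // deferred closures run when this flushes

  grpc_resource_quota* quota = grpc_resource_quota_create("demo_quota");
  grpc_resource_quota_resize(quota, 1024 * 1024);  // 1 MiB budget

  grpc_resource_user* user = grpc_resource_user_create(quota, "demo_user");

  // Returns true only if the user's cached free pool already covers the
  // request; otherwise the closure is parked on on_allocated and run after
  // the quota grants the deficit.
  grpc_closure on_done;
  GRPC_CLOSURE_INIT(&on_done, OnAllocDone, user, grpc_schedule_on_exec_ctx);
  if (grpc_resource_user_alloc(user, kAllocSize, &on_done)) {
    // Immediate grant: no callback is queued in this case.
    grpc_resource_user_free(user, kAllocSize);
  }
  grpc_core::ExecCtx::Get()->Flush();  // let the combiner process the grant

  grpc_resource_user_shutdown(user);  // fails any still-pending allocation
  grpc_resource_user_unref(user);
  grpc_resource_quota_unref(quota);
}

}  // namespace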
+grpc_slice_allocator* grpc_slice_allocator_create( + grpc_resource_quota* resource_quota, absl::string_view name, + const grpc_channel_args* args) { + grpc_slice_allocator* slice_allocator = new grpc_slice_allocator; + slice_allocator->min_length = grpc_channel_args_find_integer( + args, GRPC_ARG_TCP_MIN_READ_CHUNK_SIZE, + {GRPC_SLICE_ALLOCATOR_MIN_ALLOCATE_SIZE, -1, INT_MAX}); + slice_allocator->max_length = grpc_channel_args_find_integer( + args, GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE, + {GRPC_SLICE_ALLOCATOR_MAX_ALLOCATE_SIZE, -1, INT_MAX}); + slice_allocator->resource_user = + grpc_resource_user_create(resource_quota, name); + GRPC_CLOSURE_INIT(&slice_allocator->on_allocated, ru_allocated_slices, + slice_allocator, grpc_schedule_on_exec_ctx); + return slice_allocator; +} + +void grpc_slice_allocator_destroy(grpc_slice_allocator* slice_allocator) { + ru_unref_by(slice_allocator->resource_user, 1); + delete slice_allocator; +} + +static size_t grpc_slice_allocator_adjust_allocation_length( + grpc_slice_allocator* slice_allocator, size_t requested_length, + grpc_slice_allocator_intent intent) { + if (intent == grpc_slice_allocator_intent::kDefault) { + return requested_length; + } + GPR_ASSERT(intent == grpc_slice_allocator_intent::kReadBuffer); + double pressure = grpc_resource_quota_get_memory_pressure( + slice_allocator->resource_user->resource_quota); + // Reduce allocation size proportional to the pressure > 80% usage. + size_t target = + requested_length * (pressure > 0.8 ? (1.0 - pressure) / 0.2 : 1.0); + // Target will be some multiple of 8 bytes, rounded up + target = + (static_cast(grpc_core::Clamp(target, slice_allocator->min_length, + slice_allocator->max_length)) + + 255) & + ~static_cast(255); + // Don't use more than 1/16th of the overall resource quota for a single + // read alloc + size_t rqmax = grpc_resource_quota_peek_size( + slice_allocator->resource_user->resource_quota); + if (target > rqmax / 16 && rqmax > 1024) { + target = rqmax / 16; + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_resource_quota_trace)) { + gpr_log( + GPR_INFO, + "SliceAllocator(%p) requested %zu bytes for (%s) intent, adjusted " + "allocation size to %zu", + slice_allocator, requested_length, + intent == grpc_slice_allocator_intent::kDefault ? 
"default" : "read", + target); + } + return target; +} + +bool grpc_slice_allocator_allocate(grpc_slice_allocator* slice_allocator, + size_t length, size_t count, + grpc_slice_allocator_intent intent, + grpc_slice_buffer* dest, + grpc_iomgr_cb_func cb, void* p) { + if (GPR_UNLIKELY( + gpr_atm_no_barrier_load(&slice_allocator->resource_user->shutdown))) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, &slice_allocator->on_allocated, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource user shutdown")); + return false; + } + GRPC_CLOSURE_INIT(&slice_allocator->on_done, cb, p, + grpc_schedule_on_exec_ctx); + slice_allocator->length = grpc_slice_allocator_adjust_allocation_length( + slice_allocator, length, intent); + slice_allocator->count = count; + slice_allocator->dest = dest; + const bool ret = grpc_resource_user_alloc(slice_allocator->resource_user, + count * slice_allocator->length, + &slice_allocator->on_allocated); + if (ret) ru_alloc_slices(slice_allocator); + return ret; +} + +grpc_slice_allocator_factory* grpc_slice_allocator_factory_create( + grpc_resource_quota* resource_quota) { + grpc_slice_allocator_factory* factory = new grpc_slice_allocator_factory; + factory->resource_quota = resource_quota; + return factory; +} + +grpc_slice_allocator* grpc_slice_allocator_factory_create_slice_allocator( + grpc_slice_allocator_factory* slice_allocator_factory, + absl::string_view name, grpc_channel_args* args) { + return grpc_slice_allocator_create(slice_allocator_factory->resource_quota, + name, args); +} + +void grpc_slice_allocator_factory_destroy( + grpc_slice_allocator_factory* slice_allocator_factory) { + grpc_resource_quota_unref_internal(slice_allocator_factory->resource_quota); + delete slice_allocator_factory; +} diff --git a/src/core/lib/iomgr/socket_factory_posix.cc b/src/core/lib/iomgr/socket_factory_posix.cc new file mode 100644 index 00000000..11f86569 --- /dev/null +++ b/src/core/lib/iomgr/socket_factory_posix.cc @@ -0,0 +1,95 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_SOCKET_FACTORY + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/socket_factory_posix.h" + +void grpc_socket_factory_init(grpc_socket_factory* factory, + const grpc_socket_factory_vtable* vtable) { + factory->vtable = vtable; + gpr_ref_init(&factory->refcount, 1); +} + +int grpc_socket_factory_socket(grpc_socket_factory* factory, int domain, + int type, int protocol) { + return factory->vtable->socket(factory, domain, type, protocol); +} + +int grpc_socket_factory_bind(grpc_socket_factory* factory, int sockfd, + const grpc_resolved_address* addr) { + return factory->vtable->bind(factory, sockfd, addr); +} + +int grpc_socket_factory_compare(grpc_socket_factory* a, + grpc_socket_factory* b) { + int c = grpc_core::QsortCompare(a, b); + if (c != 0) { + grpc_socket_factory* sma = a; + grpc_socket_factory* smb = b; + c = grpc_core::QsortCompare(sma->vtable, smb->vtable); + if (c == 0) { + c = sma->vtable->compare(sma, smb); + } + } + return c; +} + +grpc_socket_factory* grpc_socket_factory_ref(grpc_socket_factory* factory) { + gpr_ref(&factory->refcount); + return factory; +} + +void grpc_socket_factory_unref(grpc_socket_factory* factory) { + if (gpr_unref(&factory->refcount)) { + factory->vtable->destroy(factory); + } +} + +static void* socket_factory_arg_copy(void* p) { + return grpc_socket_factory_ref(static_cast(p)); +} + +static void socket_factory_arg_destroy(void* p) { + grpc_socket_factory_unref(static_cast(p)); +} + +static int socket_factory_cmp(void* a, void* b) { + return grpc_socket_factory_compare(static_cast(a), + static_cast(b)); +} + +static const grpc_arg_pointer_vtable socket_factory_arg_vtable = { + socket_factory_arg_copy, socket_factory_arg_destroy, socket_factory_cmp}; + +grpc_arg grpc_socket_factory_to_arg(grpc_socket_factory* factory) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_SOCKET_FACTORY), factory, + &socket_factory_arg_vtable); +} + +#endif diff --git a/src/core/lib/iomgr/socket_mutator.cc b/src/core/lib/iomgr/socket_mutator.cc new file mode 100644 index 00000000..b908e22f --- /dev/null +++ b/src/core/lib/iomgr/socket_mutator.cc @@ -0,0 +1,97 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/socket_mutator.h" + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/useful.h" + +void grpc_socket_mutator_init(grpc_socket_mutator* mutator, + const grpc_socket_mutator_vtable* vtable) { + mutator->vtable = vtable; + gpr_ref_init(&mutator->refcount, 1); +} + +grpc_socket_mutator* grpc_socket_mutator_ref(grpc_socket_mutator* mutator) { + gpr_ref(&mutator->refcount); + return mutator; +} + +bool grpc_socket_mutator_mutate_fd(grpc_socket_mutator* mutator, int fd, + grpc_fd_usage usage) { + if (mutator->vtable->mutate_fd_2 != nullptr) { + grpc_mutate_socket_info info{fd, usage}; + return mutator->vtable->mutate_fd_2(&info, mutator); + } + switch (usage) { + case GRPC_FD_SERVER_CONNECTION_USAGE: + return true; + case GRPC_FD_CLIENT_CONNECTION_USAGE: + case GRPC_FD_SERVER_LISTENER_USAGE: + return mutator->vtable->mutate_fd(fd, mutator); + } + GPR_UNREACHABLE_CODE(return false); +} + +int grpc_socket_mutator_compare(grpc_socket_mutator* a, + grpc_socket_mutator* b) { + int c = grpc_core::QsortCompare(a, b); + if (c != 0) { + grpc_socket_mutator* sma = a; + grpc_socket_mutator* smb = b; + c = grpc_core::QsortCompare(sma->vtable, smb->vtable); + if (c == 0) { + c = sma->vtable->compare(sma, smb); + } + } + return c; +} + +void grpc_socket_mutator_unref(grpc_socket_mutator* mutator) { + if (gpr_unref(&mutator->refcount)) { + mutator->vtable->destroy(mutator); + } +} + +static void* socket_mutator_arg_copy(void* p) { + return grpc_socket_mutator_ref(static_cast(p)); +} + +static void socket_mutator_arg_destroy(void* p) { + grpc_socket_mutator_unref(static_cast(p)); +} + +static int socket_mutator_cmp(void* a, void* b) { + return grpc_socket_mutator_compare(static_cast(a), + static_cast(b)); +} + +static const grpc_arg_pointer_vtable socket_mutator_arg_vtable = { + socket_mutator_arg_copy, socket_mutator_arg_destroy, socket_mutator_cmp}; + +grpc_arg grpc_socket_mutator_to_arg(grpc_socket_mutator* mutator) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_SOCKET_MUTATOR), mutator, + &socket_mutator_arg_vtable); +} diff --git a/src/core/lib/iomgr/socket_utils_common_posix.cc b/src/core/lib/iomgr/socket_utils_common_posix.cc new file mode 100644 index 00000000..7c6706be --- /dev/null +++ b/src/core/lib/iomgr/socket_utils_common_posix.cc @@ -0,0 +1,515 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_UTILS_COMMON + +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#ifdef GRPC_LINUX_TCP_H +#include +#else +#include +#endif +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/sockaddr.h" + +/* set a socket to use zerocopy */ +grpc_error_handle grpc_set_socket_zerocopy(int fd) { +#ifdef GRPC_LINUX_ERRQUEUE + const int enable = 1; + auto err = setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &enable, sizeof(enable)); + if (err != 0) { + return GRPC_OS_ERROR(errno, "setsockopt(SO_ZEROCOPY)"); + } + return GRPC_ERROR_NONE; +#else + (void)fd; + return GRPC_OS_ERROR(ENOSYS, "setsockopt(SO_ZEROCOPY)"); +#endif +} + +/* set a socket to non blocking mode */ +grpc_error_handle grpc_set_socket_nonblocking(int fd, int non_blocking) { + int oldflags = fcntl(fd, F_GETFL, 0); + if (oldflags < 0) { + return GRPC_OS_ERROR(errno, "fcntl"); + } + + if (non_blocking) { + oldflags |= O_NONBLOCK; + } else { + oldflags &= ~O_NONBLOCK; + } + + if (fcntl(fd, F_SETFL, oldflags) != 0) { + return GRPC_OS_ERROR(errno, "fcntl"); + } + + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_set_socket_no_sigpipe_if_possible(int fd) { +#ifdef GRPC_HAVE_SO_NOSIGPIPE + int val = 1; + int newval; + socklen_t intlen = sizeof(newval); + if (0 != setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &val, sizeof(val))) { + return GRPC_OS_ERROR(errno, "setsockopt(SO_NOSIGPIPE)"); + } + if (0 != getsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &newval, &intlen)) { + return GRPC_OS_ERROR(errno, "getsockopt(SO_NOSIGPIPE)"); + } + if ((newval != 0) != (val != 0)) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to set SO_NOSIGPIPE"); + } +#else + // Avoid unused parameter warning for conditional parameter + (void)fd; +#endif + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_set_socket_ip_pktinfo_if_possible(int fd) { + // Use conditionally-important parameter to avoid warning + (void)fd; +#ifdef GRPC_HAVE_IP_PKTINFO + int get_local_ip = 1; + if (0 != setsockopt(fd, IPPROTO_IP, IP_PKTINFO, &get_local_ip, + sizeof(get_local_ip))) { + return GRPC_OS_ERROR(errno, "setsockopt(IP_PKTINFO)"); + } +#endif + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_set_socket_ipv6_recvpktinfo_if_possible(int fd) { + // Use conditionally-important parameter to avoid warning + (void)fd; +#ifdef GRPC_HAVE_IPV6_RECVPKTINFO + int get_local_ip = 1; + if (0 != setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &get_local_ip, + sizeof(get_local_ip))) { + return GRPC_OS_ERROR(errno, "setsockopt(IPV6_RECVPKTINFO)"); + } +#endif + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_set_socket_sndbuf(int fd, int buffer_size_bytes) { + return 0 == setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &buffer_size_bytes, + sizeof(buffer_size_bytes)) + ? GRPC_ERROR_NONE + : GRPC_OS_ERROR(errno, "setsockopt(SO_SNDBUF)"); +} + +grpc_error_handle grpc_set_socket_rcvbuf(int fd, int buffer_size_bytes) { + return 0 == setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &buffer_size_bytes, + sizeof(buffer_size_bytes)) + ? 
GRPC_ERROR_NONE + : GRPC_OS_ERROR(errno, "setsockopt(SO_RCVBUF)"); +} + +/* set a socket to close on exec */ +grpc_error_handle grpc_set_socket_cloexec(int fd, int close_on_exec) { + int oldflags = fcntl(fd, F_GETFD, 0); + if (oldflags < 0) { + return GRPC_OS_ERROR(errno, "fcntl"); + } + + if (close_on_exec) { + oldflags |= FD_CLOEXEC; + } else { + oldflags &= ~FD_CLOEXEC; + } + + if (fcntl(fd, F_SETFD, oldflags) != 0) { + return GRPC_OS_ERROR(errno, "fcntl"); + } + + return GRPC_ERROR_NONE; +} + +/* set a socket to reuse old addresses */ +grpc_error_handle grpc_set_socket_reuse_addr(int fd, int reuse) { + int val = (reuse != 0); + int newval; + socklen_t intlen = sizeof(newval); + if (0 != setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val))) { + return GRPC_OS_ERROR(errno, "setsockopt(SO_REUSEADDR)"); + } + if (0 != getsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &newval, &intlen)) { + return GRPC_OS_ERROR(errno, "getsockopt(SO_REUSEADDR)"); + } + if ((newval != 0) != val) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to set SO_REUSEADDR"); + } + + return GRPC_ERROR_NONE; +} + +/* set a socket to reuse old addresses */ +grpc_error_handle grpc_set_socket_reuse_port(int fd, int reuse) { +#ifndef SO_REUSEPORT + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "SO_REUSEPORT unavailable on compiling system"); +#else + int val = (reuse != 0); + int newval; + socklen_t intlen = sizeof(newval); + if (0 != setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val))) { + return GRPC_OS_ERROR(errno, "setsockopt(SO_REUSEPORT)"); + } + if (0 != getsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &newval, &intlen)) { + return GRPC_OS_ERROR(errno, "getsockopt(SO_REUSEPORT)"); + } + if ((newval != 0) != val) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to set SO_REUSEPORT"); + } + + return GRPC_ERROR_NONE; +#endif +} + +static gpr_once g_probe_so_reuesport_once = GPR_ONCE_INIT; +static int g_support_so_reuseport = false; + +void probe_so_reuseport_once(void) { + int s = socket(AF_INET, SOCK_STREAM, 0); + if (s < 0) { + /* This might be an ipv6-only environment in which case 'socket(AF_INET,..)' + call would fail. 
Try creating IPv6 socket in that case */ + s = socket(AF_INET6, SOCK_STREAM, 0); + } + if (s >= 0) { + g_support_so_reuseport = GRPC_LOG_IF_ERROR( + "check for SO_REUSEPORT", grpc_set_socket_reuse_port(s, 1)); + close(s); + } +} + +bool grpc_is_socket_reuse_port_supported() { + gpr_once_init(&g_probe_so_reuesport_once, probe_so_reuseport_once); + return g_support_so_reuseport; +} + +/* disable nagle */ +grpc_error_handle grpc_set_socket_low_latency(int fd, int low_latency) { + int val = (low_latency != 0); + int newval; + socklen_t intlen = sizeof(newval); + if (0 != setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val))) { + return GRPC_OS_ERROR(errno, "setsockopt(TCP_NODELAY)"); + } + if (0 != getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &newval, &intlen)) { + return GRPC_OS_ERROR(errno, "getsockopt(TCP_NODELAY)"); + } + if ((newval != 0) != val) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to set TCP_NODELAY"); + } + return GRPC_ERROR_NONE; +} + +/* The default values for TCP_USER_TIMEOUT are currently configured to be in + * line with the default values of KEEPALIVE_TIMEOUT as proposed in + * https://github.com/grpc/proposal/blob/master/A18-tcp-user-timeout.md */ +#define DEFAULT_CLIENT_TCP_USER_TIMEOUT_MS 20000 /* 20 seconds */ +#define DEFAULT_SERVER_TCP_USER_TIMEOUT_MS 20000 /* 20 seconds */ + +static int g_default_client_tcp_user_timeout_ms = + DEFAULT_CLIENT_TCP_USER_TIMEOUT_MS; +static int g_default_server_tcp_user_timeout_ms = + DEFAULT_SERVER_TCP_USER_TIMEOUT_MS; +static bool g_default_client_tcp_user_timeout_enabled = false; +static bool g_default_server_tcp_user_timeout_enabled = true; + +#if GPR_LINUX == 1 +// For Linux, it will be detected to support TCP_USER_TIMEOUT +#ifndef TCP_USER_TIMEOUT +#define TCP_USER_TIMEOUT 18 +#endif +#define SOCKET_SUPPORTS_TCP_USER_TIMEOUT_DEFAULT 0 +#else +// For non-Linux, TCP_USER_TIMEOUT will be used if TCP_USER_TIMEOUT is defined. +#ifdef TCP_USER_TIMEOUT +#define SOCKET_SUPPORTS_TCP_USER_TIMEOUT_DEFAULT 0 +#else +#define TCP_USER_TIMEOUT 0 +#define SOCKET_SUPPORTS_TCP_USER_TIMEOUT_DEFAULT -1 +#endif // TCP_USER_TIMEOUT +#endif // GPR_LINUX == 1 + +// Whether the socket supports TCP_USER_TIMEOUT option. 
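The grpc_set_socket_tcp_user_timeout() logic just below derives TCP_USER_TIMEOUT from the standard keepalive channel args; a minimal client-side usage sketch follows (public C++ API; the target address and credentials are placeholders):

#include <grpcpp/grpcpp.h>

std::shared_ptr<grpc::Channel> MakeChannelWithTcpUserTimeout() {
  grpc::ChannelArguments args;
  // A finite keepalive time keeps TCP_USER_TIMEOUT enabled on the socket
  // (setting GRPC_ARG_KEEPALIVE_TIME_MS to INT_MAX disables it instead).
  args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 20000);
  // The keepalive timeout value is what gets applied as TCP_USER_TIMEOUT.
  args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 10000);
  return grpc::CreateCustomChannel("localhost:50051",
                                   grpc::InsecureChannelCredentials(), args);
}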
+// (0: don't know, 1: support, -1: not support) +static std::atomic g_socket_supports_tcp_user_timeout( + SOCKET_SUPPORTS_TCP_USER_TIMEOUT_DEFAULT); + +void config_default_tcp_user_timeout(bool enable, int timeout, bool is_client) { + if (is_client) { + g_default_client_tcp_user_timeout_enabled = enable; + if (timeout > 0) { + g_default_client_tcp_user_timeout_ms = timeout; + } + } else { + g_default_server_tcp_user_timeout_enabled = enable; + if (timeout > 0) { + g_default_server_tcp_user_timeout_ms = timeout; + } + } +} + +/* Set TCP_USER_TIMEOUT */ +grpc_error_handle grpc_set_socket_tcp_user_timeout( + int fd, const grpc_channel_args* channel_args, bool is_client) { + // Use conditionally-important parameter to avoid warning + (void)fd; + (void)channel_args; + (void)is_client; + extern grpc_core::TraceFlag grpc_tcp_trace; + if (g_socket_supports_tcp_user_timeout.load() >= 0) { + bool enable; + int timeout; + if (is_client) { + enable = g_default_client_tcp_user_timeout_enabled; + timeout = g_default_client_tcp_user_timeout_ms; + } else { + enable = g_default_server_tcp_user_timeout_enabled; + timeout = g_default_server_tcp_user_timeout_ms; + } + if (channel_args) { + for (unsigned int i = 0; i < channel_args->num_args; i++) { + if (0 == + strcmp(channel_args->args[i].key, GRPC_ARG_KEEPALIVE_TIME_MS)) { + const int value = grpc_channel_arg_get_integer( + &channel_args->args[i], grpc_integer_options{0, 1, INT_MAX}); + /* Continue using default if value is 0 */ + if (value == 0) { + continue; + } + /* Disable if value is INT_MAX */ + enable = value != INT_MAX; + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_KEEPALIVE_TIMEOUT_MS)) { + const int value = grpc_channel_arg_get_integer( + &channel_args->args[i], grpc_integer_options{0, 1, INT_MAX}); + /* Continue using default if value is 0 */ + if (value == 0) { + continue; + } + timeout = value; + } + } + } + if (enable) { + int newval; + socklen_t len = sizeof(newval); + // If this is the first time to use TCP_USER_TIMEOUT, try to check + // if it is available. + if (g_socket_supports_tcp_user_timeout.load() == 0) { + if (0 != getsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &newval, &len)) { + gpr_log(GPR_INFO, + "TCP_USER_TIMEOUT is not available. TCP_USER_TIMEOUT won't " + "be used thereafter"); + g_socket_supports_tcp_user_timeout.store(-1); + } else { + gpr_log(GPR_INFO, + "TCP_USER_TIMEOUT is available. TCP_USER_TIMEOUT will be " + "used thereafter"); + g_socket_supports_tcp_user_timeout.store(1); + } + } + if (g_socket_supports_tcp_user_timeout.load() > 0) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "Enabling TCP_USER_TIMEOUT with a timeout of %d ms", + timeout); + } + if (0 != setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &timeout, + sizeof(timeout))) { + gpr_log(GPR_ERROR, "setsockopt(TCP_USER_TIMEOUT) %s", + strerror(errno)); + return GRPC_ERROR_NONE; + } + if (0 != getsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &newval, &len)) { + gpr_log(GPR_ERROR, "getsockopt(TCP_USER_TIMEOUT) %s", + strerror(errno)); + return GRPC_ERROR_NONE; + } + if (newval != timeout) { + /* Do not fail on failing to set TCP_USER_TIMEOUT for now. 
*/ + gpr_log(GPR_ERROR, "Failed to set TCP_USER_TIMEOUT"); + return GRPC_ERROR_NONE; + } + } + } + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP_USER_TIMEOUT not supported for this platform"); + } + } + return GRPC_ERROR_NONE; +} + +/* set a socket using a grpc_socket_mutator */ +grpc_error_handle grpc_set_socket_with_mutator(int fd, grpc_fd_usage usage, + grpc_socket_mutator* mutator) { + GPR_ASSERT(mutator); + if (!grpc_socket_mutator_mutate_fd(mutator, fd, usage)) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("grpc_socket_mutator failed."); + } + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_apply_socket_mutator_in_args( + int fd, grpc_fd_usage usage, const grpc_channel_args* args) { + const grpc_arg* socket_mutator_arg = + grpc_channel_args_find(args, GRPC_ARG_SOCKET_MUTATOR); + if (socket_mutator_arg == nullptr) { + return GRPC_ERROR_NONE; + } + GPR_DEBUG_ASSERT(socket_mutator_arg->type == GRPC_ARG_POINTER); + grpc_socket_mutator* mutator = + static_cast(socket_mutator_arg->value.pointer.p); + return grpc_set_socket_with_mutator(fd, usage, mutator); +} + +static gpr_once g_probe_ipv6_once = GPR_ONCE_INIT; +static int g_ipv6_loopback_available; + +static void probe_ipv6_once(void) { + int fd = socket(AF_INET6, SOCK_STREAM, 0); + g_ipv6_loopback_available = 0; + if (fd < 0) { + gpr_log(GPR_INFO, "Disabling AF_INET6 sockets because socket() failed."); + } else { + grpc_sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_addr.s6_addr[15] = 1; /* [::1]:0 */ + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) == 0) { + g_ipv6_loopback_available = 1; + } else { + gpr_log(GPR_INFO, + "Disabling AF_INET6 sockets because ::1 is not available."); + } + close(fd); + } +} + +int grpc_ipv6_loopback_available(void) { + gpr_once_init(&g_probe_ipv6_once, probe_ipv6_once); + return g_ipv6_loopback_available; +} + +static grpc_error_handle error_for_fd(int fd, + const grpc_resolved_address* addr) { + if (fd >= 0) return GRPC_ERROR_NONE; + std::string addr_str = grpc_sockaddr_to_string(addr, false); + grpc_error_handle err = grpc_error_set_str( + GRPC_OS_ERROR(errno, "socket"), GRPC_ERROR_STR_TARGET_ADDRESS, addr_str); + return err; +} + +grpc_error_handle grpc_create_dualstack_socket( + const grpc_resolved_address* resolved_addr, int type, int protocol, + grpc_dualstack_mode* dsmode, int* newfd) { + return grpc_create_dualstack_socket_using_factory( + nullptr, resolved_addr, type, protocol, dsmode, newfd); +} + +static int create_socket(grpc_socket_factory* factory, int domain, int type, + int protocol) { + return (factory != nullptr) + ? grpc_socket_factory_socket(factory, domain, type, protocol) + : socket(domain, type, protocol); +} + +grpc_error_handle grpc_create_dualstack_socket_using_factory( + grpc_socket_factory* factory, const grpc_resolved_address* resolved_addr, + int type, int protocol, grpc_dualstack_mode* dsmode, int* newfd) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + int family = addr->sa_family; + if (family == AF_INET6) { + if (grpc_ipv6_loopback_available()) { + *newfd = create_socket(factory, family, type, protocol); + } else { + *newfd = -1; + errno = EAFNOSUPPORT; + } + /* Check if we've got a valid dualstack socket. */ + if (*newfd >= 0 && grpc_set_socket_dualstack(*newfd)) { + *dsmode = GRPC_DSMODE_DUALSTACK; + return GRPC_ERROR_NONE; + } + /* If this isn't an IPv4 address, then return whatever we've got. 
*/ + if (!grpc_sockaddr_is_v4mapped(resolved_addr, nullptr)) { + *dsmode = GRPC_DSMODE_IPV6; + return error_for_fd(*newfd, resolved_addr); + } + /* Fall back to AF_INET. */ + if (*newfd >= 0) { + close(*newfd); + } + family = AF_INET; + } + *dsmode = family == AF_INET ? GRPC_DSMODE_IPV4 : GRPC_DSMODE_NONE; + *newfd = create_socket(factory, family, type, protocol); + return error_for_fd(*newfd, resolved_addr); +} + +uint16_t grpc_htons(uint16_t hostshort) { return htons(hostshort); } + +uint16_t grpc_ntohs(uint16_t netshort) { return ntohs(netshort); } + +uint32_t grpc_htonl(uint32_t hostlong) { return htonl(hostlong); } + +uint32_t grpc_ntohl(uint32_t netlong) { return ntohl(netlong); } + +int grpc_inet_pton(int af, const char* src, void* dst) { + return inet_pton(af, src, dst); +} + +const char* grpc_inet_ntop(int af, const void* src, char* dst, size_t size) { + GPR_ASSERT(size <= (socklen_t)-1); + return inet_ntop(af, src, dst, static_cast(size)); +} + +#endif diff --git a/src/core/lib/iomgr/socket_utils_linux.cc b/src/core/lib/iomgr/socket_utils_linux.cc new file mode 100644 index 00000000..ef5d975b --- /dev/null +++ b/src/core/lib/iomgr/socket_utils_linux.cc @@ -0,0 +1,42 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_LINUX_SOCKETUTILS + +#include +#include + +#include + +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" + +int grpc_accept4(int sockfd, grpc_resolved_address* resolved_addr, int nonblock, + int cloexec) { + int flags = 0; + flags |= nonblock ? SOCK_NONBLOCK : 0; + flags |= cloexec ? SOCK_CLOEXEC : 0; + return accept4(sockfd, reinterpret_cast(resolved_addr->addr), + &resolved_addr->len, flags); +} + +#endif diff --git a/src/core/lib/iomgr/socket_utils_posix.cc b/src/core/lib/iomgr/socket_utils_posix.cc new file mode 100644 index 00000000..333e60db --- /dev/null +++ b/src/core/lib/iomgr/socket_utils_posix.cc @@ -0,0 +1,58 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKETUTILS + +#include +#include +#include + +#include + +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" + +int grpc_accept4(int sockfd, grpc_resolved_address* resolved_addr, int nonblock, + int cloexec) { + int fd, flags; + fd = accept(sockfd, reinterpret_cast(resolved_addr->addr), + &resolved_addr->len); + if (fd >= 0) { + if (nonblock) { + flags = fcntl(fd, F_GETFL, 0); + if (flags < 0) goto close_and_error; + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) != 0) goto close_and_error; + } + if (cloexec) { + flags = fcntl(fd, F_GETFD, 0); + if (flags < 0) goto close_and_error; + if (fcntl(fd, F_SETFD, flags | FD_CLOEXEC) != 0) goto close_and_error; + } + } + return fd; + +close_and_error: + close(fd); + return -1; +} + +#endif /* GRPC_POSIX_SOCKETUTILS */ diff --git a/src/core/lib/iomgr/socket_utils_windows.cc b/src/core/lib/iomgr/socket_utils_windows.cc new file mode 100644 index 00000000..4f483f07 --- /dev/null +++ b/src/core/lib/iomgr/socket_utils_windows.cc @@ -0,0 +1,47 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINDOWS_SOCKETUTILS + +#include + +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" + +uint16_t grpc_htons(uint16_t hostshort) { return htons(hostshort); } + +uint16_t grpc_ntohs(uint16_t netshort) { return ntohs(netshort); } + +uint32_t grpc_htonl(uint32_t hostlong) { return htonl(hostlong); } + +uint32_t grpc_ntohl(uint32_t netlong) { return ntohl(netlong); } + +int grpc_inet_pton(int af, const char* src, void* dst) { + return inet_pton(af, src, dst); +} + +const char* grpc_inet_ntop(int af, const void* src, char* dst, size_t size) { + /* Windows InetNtopA wants a mutable ip pointer */ + return InetNtopA(af, (void*)src, dst, size); +} + +#endif /* GRPC_WINDOWS_SOCKETUTILS */ diff --git a/src/core/lib/iomgr/socket_windows.cc b/src/core/lib/iomgr/socket_windows.cc new file mode 100644 index 00000000..9d804dda --- /dev/null +++ b/src/core/lib/iomgr/socket_windows.cc @@ -0,0 +1,202 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include + +// must be included after winsock2.h +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_windows.h" +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" + +static DWORD s_wsa_socket_flags; + +grpc_winsocket* grpc_winsocket_create(SOCKET socket, const char* name) { + grpc_winsocket* r = (grpc_winsocket*)gpr_malloc(sizeof(grpc_winsocket)); + memset(r, 0, sizeof(grpc_winsocket)); + r->socket = socket; + gpr_mu_init(&r->state_mu); + grpc_iomgr_register_object( + &r->iomgr_object, absl::StrFormat("%s:socket=0x%p", name, r).c_str()); + grpc_iocp_add_socket(r); + return r; +} + +SOCKET grpc_winsocket_wrapped_socket(grpc_winsocket* socket) { + return socket->socket; +} + +/* Schedule a shutdown of the socket operations. Will call the pending + operations to abort them. We need to do that this way because of the + various callsites of that function, which happens to be in various + mutex hold states, and that'd be unsafe to call them directly. */ +void grpc_winsocket_shutdown(grpc_winsocket* winsocket) { + /* Grab the function pointer for DisconnectEx for that specific socket. + It may change depending on the interface. */ + int status; + GUID guid = WSAID_DISCONNECTEX; + LPFN_DISCONNECTEX DisconnectEx; + DWORD ioctl_num_bytes; + + gpr_mu_lock(&winsocket->state_mu); + if (winsocket->shutdown_called) { + gpr_mu_unlock(&winsocket->state_mu); + return; + } + winsocket->shutdown_called = true; + gpr_mu_unlock(&winsocket->state_mu); + + status = WSAIoctl(winsocket->socket, SIO_GET_EXTENSION_FUNCTION_POINTER, + &guid, sizeof(guid), &DisconnectEx, sizeof(DisconnectEx), + &ioctl_num_bytes, NULL, NULL); + + if (status == 0) { + DisconnectEx(winsocket->socket, NULL, 0, 0); + } else { + char* utf8_message = gpr_format_message(WSAGetLastError()); + gpr_log(GPR_INFO, "Unable to retrieve DisconnectEx pointer : %s", + utf8_message); + gpr_free(utf8_message); + } + closesocket(winsocket->socket); +} + +static void destroy(grpc_winsocket* winsocket) { + grpc_iomgr_unregister_object(&winsocket->iomgr_object); + gpr_mu_destroy(&winsocket->state_mu); + gpr_free(winsocket); +} + +static bool check_destroyable(grpc_winsocket* winsocket) { + return winsocket->destroy_called == true && + winsocket->write_info.closure == NULL && + winsocket->read_info.closure == NULL; +} + +void grpc_winsocket_destroy(grpc_winsocket* winsocket) { + gpr_mu_lock(&winsocket->state_mu); + GPR_ASSERT(!winsocket->destroy_called); + winsocket->destroy_called = true; + bool should_destroy = check_destroyable(winsocket); + gpr_mu_unlock(&winsocket->state_mu); + if (should_destroy) destroy(winsocket); +} + +/* Calling notify_on_read or write means either of two things: +-) The IOCP already completed in the background, and we need to call +the callback now. +-) The IOCP hasn't completed yet, and we're queuing it for later. 
*/ +static void socket_notify_on_iocp(grpc_winsocket* socket, grpc_closure* closure, + grpc_winsocket_callback_info* info) { + GPR_ASSERT(info->closure == NULL); + gpr_mu_lock(&socket->state_mu); + if (info->has_pending_iocp) { + info->has_pending_iocp = 0; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + } else { + info->closure = closure; + } + gpr_mu_unlock(&socket->state_mu); +} + +void grpc_socket_notify_on_write(grpc_winsocket* socket, + grpc_closure* closure) { + socket_notify_on_iocp(socket, closure, &socket->write_info); +} + +void grpc_socket_notify_on_read(grpc_winsocket* socket, grpc_closure* closure) { + socket_notify_on_iocp(socket, closure, &socket->read_info); +} + +void grpc_socket_become_ready(grpc_winsocket* socket, + grpc_winsocket_callback_info* info) { + GPR_ASSERT(!info->has_pending_iocp); + gpr_mu_lock(&socket->state_mu); + if (info->closure) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, info->closure, GRPC_ERROR_NONE); + info->closure = NULL; + } else { + info->has_pending_iocp = 1; + } + bool should_destroy = check_destroyable(socket); + gpr_mu_unlock(&socket->state_mu); + if (should_destroy) destroy(socket); +} + +static gpr_once g_probe_ipv6_once = GPR_ONCE_INIT; +static bool g_ipv6_loopback_available = false; + +static void probe_ipv6_once(void) { + SOCKET s = socket(AF_INET6, SOCK_STREAM, 0); + g_ipv6_loopback_available = 0; + if (s == INVALID_SOCKET) { + gpr_log(GPR_INFO, "Disabling AF_INET6 sockets because socket() failed."); + } else { + grpc_sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_addr.s6_addr[15] = 1; /* [::1]:0 */ + if (bind(s, reinterpret_cast(&addr), sizeof(addr)) == 0) { + g_ipv6_loopback_available = 1; + } else { + gpr_log(GPR_INFO, + "Disabling AF_INET6 sockets because ::1 is not available."); + } + closesocket(s); + } +} + +int grpc_ipv6_loopback_available(void) { + gpr_once_init(&g_probe_ipv6_once, probe_ipv6_once); + return g_ipv6_loopback_available; +} + +DWORD grpc_get_default_wsa_socket_flags() { return s_wsa_socket_flags; } + +void grpc_wsa_socket_flags_init() { + s_wsa_socket_flags = WSA_FLAG_OVERLAPPED; + /* WSA_FLAG_NO_HANDLE_INHERIT may be not supported on the older Windows + versions, see + https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx + for details. */ + SOCKET sock = WSASocket(AF_INET6, SOCK_STREAM, IPPROTO_TCP, NULL, 0, + s_wsa_socket_flags | WSA_FLAG_NO_HANDLE_INHERIT); + if (sock != INVALID_SOCKET) { + /* Windows 7, Windows 2008 R2 with SP1 or later */ + s_wsa_socket_flags |= WSA_FLAG_NO_HANDLE_INHERIT; + closesocket(sock); + } +} + +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/tcp_client.cc b/src/core/lib/iomgr/tcp_client.cc new file mode 100644 index 00000000..02b7e4b4 --- /dev/null +++ b/src/core/lib/iomgr/tcp_client.cc @@ -0,0 +1,38 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/tcp_client.h" + +grpc_tcp_client_vtable* grpc_tcp_client_impl; + +void grpc_tcp_client_connect(grpc_closure* on_connect, grpc_endpoint** endpoint, + grpc_slice_allocator* slice_allocator, + grpc_pollset_set* interested_parties, + const grpc_channel_args* channel_args, + const grpc_resolved_address* addr, + grpc_millis deadline) { + grpc_tcp_client_impl->connect(on_connect, endpoint, slice_allocator, + interested_parties, channel_args, addr, + deadline); +} + +void grpc_set_tcp_client_impl(grpc_tcp_client_vtable* impl) { + grpc_tcp_client_impl = impl; +} diff --git a/src/core/lib/iomgr/tcp_client_cfstream.cc b/src/core/lib/iomgr/tcp_client_cfstream.cc new file mode 100644 index 00000000..5292e231 --- /dev/null +++ b/src/core/lib/iomgr/tcp_client_cfstream.cc @@ -0,0 +1,205 @@ + +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_CFSTREAM_CLIENT + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/cfstream_handle.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/endpoint_cfstream.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/error_cfstream.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/timer.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; + +struct CFStreamConnect { + gpr_mu mu; + gpr_refcount refcount; + + CFReadStreamRef read_stream; + CFWriteStreamRef write_stream; + CFStreamHandle* stream_handle; + + grpc_timer alarm; + grpc_closure on_alarm; + grpc_closure on_open; + + bool read_stream_open; + bool write_stream_open; + bool failed; + + grpc_closure* closure; + grpc_endpoint** endpoint; + int refs; + std::string addr_name; + grpc_slice_allocator* slice_allocator; +}; + +static void CFStreamConnectCleanup(CFStreamConnect* connect) { + CFSTREAM_HANDLE_UNREF(connect->stream_handle, "async connect clean up"); + CFRelease(connect->read_stream); + CFRelease(connect->write_stream); + if (connect->slice_allocator != nullptr) { + grpc_slice_allocator_destroy(connect->slice_allocator); + } + gpr_mu_destroy(&connect->mu); + delete connect; +} + +static void OnAlarm(void* arg, grpc_error_handle error) { + CFStreamConnect* connect = static_cast(arg); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CLIENT_CONNECT :%p OnAlarm, error:%s", connect, + grpc_error_std_string(error).c_str()); + } + gpr_mu_lock(&connect->mu); + grpc_closure* closure = connect->closure; + connect->closure = nil; + const bool done = (--connect->refs == 0); + gpr_mu_unlock(&connect->mu); + // Only schedule a callback once, by either OnAlarm or OnOpen. The + // first one issues callback while the second one does cleanup. 
+ if (done) { + CFStreamConnectCleanup(connect); + } else { + grpc_error_handle error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("connect() timed out"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); + } +} + +static void OnOpen(void* arg, grpc_error_handle error) { + CFStreamConnect* connect = static_cast(arg); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CLIENT_CONNECT :%p OnOpen, error:%s", connect, + grpc_error_std_string(error).c_str()); + } + gpr_mu_lock(&connect->mu); + grpc_timer_cancel(&connect->alarm); + grpc_closure* closure = connect->closure; + connect->closure = nil; + + bool done = (--connect->refs == 0); + grpc_endpoint** endpoint = connect->endpoint; + + // Only schedule a callback once, by either OnAlarm or OnOpen. The + // first one issues callback while the second one does cleanup. + if (done) { + gpr_mu_unlock(&connect->mu); + CFStreamConnectCleanup(connect); + } else { + if (error == GRPC_ERROR_NONE) { + CFErrorRef stream_error = CFReadStreamCopyError(connect->read_stream); + if (stream_error == NULL) { + stream_error = CFWriteStreamCopyError(connect->write_stream); + } + if (stream_error) { + error = GRPC_ERROR_CREATE_FROM_CFERROR(stream_error, "connect() error"); + CFRelease(stream_error); + } + if (error == GRPC_ERROR_NONE) { + *endpoint = grpc_cfstream_endpoint_create( + connect->read_stream, connect->write_stream, + connect->addr_name.c_str(), connect->slice_allocator, + connect->stream_handle); + connect->slice_allocator = nullptr; + } + } else { + (void)GRPC_ERROR_REF(error); + } + gpr_mu_unlock(&connect->mu); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); + } +} + +static void ParseResolvedAddress(const grpc_resolved_address* addr, + CFStringRef* host, int* port) { + std::string host_port = grpc_sockaddr_to_string(addr, true); + std::string host_string; + std::string port_string; + grpc_core::SplitHostPort(host_port, &host_string, &port_string); + *host = CFStringCreateWithCString(NULL, host_string.c_str(), + kCFStringEncodingUTF8); + *port = grpc_sockaddr_get_port(addr); +} + +static void CFStreamClientConnect(grpc_closure* closure, grpc_endpoint** ep, + grpc_slice_allocator* slice_allocator, + grpc_pollset_set* interested_parties, + const grpc_channel_args* channel_args, + const grpc_resolved_address* resolved_addr, + grpc_millis deadline) { + CFStreamConnect* connect = new CFStreamConnect(); + connect->closure = closure; + connect->endpoint = ep; + connect->addr_name = grpc_sockaddr_to_uri(resolved_addr); + connect->refs = 2; // One for the connect operation, one for the timer. 
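A standalone sketch (plain C++, hypothetical names) of the two-owner cleanup convention used by CFStreamConnect here and by the other connectors in this patch: the timeout path and the completion path each drop one reference, the first to run reports the outcome, and whichever runs last frees the shared state.

#include <mutex>

struct PendingConnect {
  std::mutex mu;
  int refs = 2;  // one for the timer, one for the connect completion
};

// Returns true when the caller is the last owner and must clean up.
bool DropRef(PendingConnect* c) {
  std::lock_guard<std::mutex> lock(c->mu);
  return --c->refs == 0;
}

void OnTimeout(PendingConnect* c) {
  if (DropRef(c)) {
    delete c;  // the completion already ran and reported the result
  } else {
    // Still pending: report "connect() timed out" to the waiting closure.
  }
}

void OnComplete(PendingConnect* c) {
  if (DropRef(c)) {
    delete c;  // the timer already fired and reported the timeout
  } else {
    // First to finish: hand the endpoint (or error) to the waiting closure.
  }
}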
+ gpr_ref_init(&connect->refcount, 1); + gpr_mu_init(&connect->mu); + + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_DEBUG, "CLIENT_CONNECT: %p, %s: asynchronously connecting", + connect, connect->addr_name.c_str()); + } + connect->slice_allocator = slice_allocator; + + CFReadStreamRef read_stream; + CFWriteStreamRef write_stream; + + CFStringRef host; + int port; + ParseResolvedAddress(resolved_addr, &host, &port); + CFStreamCreatePairWithSocketToHost(NULL, host, port, &read_stream, + &write_stream); + CFRelease(host); + connect->read_stream = read_stream; + connect->write_stream = write_stream; + connect->stream_handle = + CFStreamHandle::CreateStreamHandle(read_stream, write_stream); + GRPC_CLOSURE_INIT(&connect->on_open, OnOpen, static_cast(connect), + grpc_schedule_on_exec_ctx); + connect->stream_handle->NotifyOnOpen(&connect->on_open); + GRPC_CLOSURE_INIT(&connect->on_alarm, OnAlarm, connect, + grpc_schedule_on_exec_ctx); + gpr_mu_lock(&connect->mu); + CFReadStreamOpen(read_stream); + CFWriteStreamOpen(write_stream); + grpc_timer_init(&connect->alarm, deadline, &connect->on_alarm); + gpr_mu_unlock(&connect->mu); +} + +grpc_tcp_client_vtable grpc_cfstream_client_vtable = {CFStreamClientConnect}; + +#endif /* GRPC_CFSTREAM_CLIENT */ diff --git a/src/core/lib/iomgr/tcp_client_custom.cc b/src/core/lib/iomgr/tcp_client_custom.cc new file mode 100644 index 00000000..363a4922 --- /dev/null +++ b/src/core/lib/iomgr/tcp_client_custom.cc @@ -0,0 +1,152 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/iomgr_custom.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_custom.h" +#include "src/core/lib/iomgr/timer.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; +extern grpc_socket_vtable* grpc_custom_socket_vtable; + +struct grpc_custom_tcp_connect { + grpc_custom_socket* socket; + grpc_timer alarm; + grpc_closure on_alarm; + grpc_closure* closure; + grpc_endpoint** endpoint; + int refs; + std::string addr_name; + grpc_slice_allocator* slice_allocator; +}; + +static void custom_tcp_connect_cleanup(grpc_custom_tcp_connect* connect) { + if (connect->slice_allocator != nullptr) { + grpc_slice_allocator_destroy(connect->slice_allocator); + } + grpc_custom_socket* socket = connect->socket; + delete connect; + socket->refs--; + if (socket->refs == 0) { + grpc_custom_socket_vtable->destroy(socket); + gpr_free(socket); + } +} + +static void custom_close_callback(grpc_custom_socket* /*socket*/) {} + +static void on_alarm(void* acp, grpc_error_handle error) { + int done; + grpc_custom_socket* socket = static_cast(acp); + grpc_custom_tcp_connect* connect = socket->connector; + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "CLIENT_CONNECT: %s: on_alarm: error=%s", + connect->addr_name.c_str(), grpc_error_std_string(error).c_str()); + } + if (error == GRPC_ERROR_NONE) { + /* error == NONE implies that the timer ran out, and wasn't cancelled. If + it was cancelled, then the handler that cancelled it also should close + the handle, if applicable */ + grpc_custom_socket_vtable->close(socket, custom_close_callback); + } + done = (--connect->refs == 0); + if (done) { + custom_tcp_connect_cleanup(connect); + } +} + +static void custom_connect_callback_internal(grpc_custom_socket* socket, + grpc_error_handle error) { + grpc_custom_tcp_connect* connect = socket->connector; + int done; + grpc_closure* closure = connect->closure; + grpc_timer_cancel(&connect->alarm); + if (error == GRPC_ERROR_NONE) { + *connect->endpoint = custom_tcp_endpoint_create( + socket, connect->slice_allocator, connect->addr_name.c_str()); + connect->slice_allocator = nullptr; + } + done = (--connect->refs == 0); + if (done) { + grpc_core::ExecCtx::Get()->Flush(); + custom_tcp_connect_cleanup(connect); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); +} + +static void custom_connect_callback(grpc_custom_socket* socket, + grpc_error_handle error) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + if (grpc_core::ExecCtx::Get() == nullptr) { + /* If we are being run on a thread which does not have an exec_ctx created + * yet, we should create one. 
*/ + grpc_core::ExecCtx exec_ctx; + custom_connect_callback_internal(socket, error); + } else { + custom_connect_callback_internal(socket, error); + } +} + +static void tcp_connect(grpc_closure* closure, grpc_endpoint** ep, + grpc_slice_allocator* slice_allocator, + grpc_pollset_set* interested_parties, + const grpc_channel_args* channel_args, + const grpc_resolved_address* resolved_addr, + grpc_millis deadline) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + (void)channel_args; + (void)interested_parties; + grpc_custom_socket* socket = + static_cast(gpr_malloc(sizeof(grpc_custom_socket))); + socket->refs = 2; + (void)grpc_custom_socket_vtable->init(socket, GRPC_AF_UNSPEC); + grpc_custom_tcp_connect* connect = new grpc_custom_tcp_connect(); + connect->closure = closure; + connect->endpoint = ep; + connect->addr_name = grpc_sockaddr_to_uri(resolved_addr); + connect->slice_allocator = slice_allocator; + connect->socket = socket; + socket->connector = connect; + socket->endpoint = nullptr; + socket->listener = nullptr; + connect->refs = 2; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "CLIENT_CONNECT: %p %s: asynchronously connecting", + socket, connect->addr_name.c_str()); + } + + GRPC_CLOSURE_INIT(&connect->on_alarm, on_alarm, socket, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&connect->alarm, deadline, &connect->on_alarm); + grpc_custom_socket_vtable->connect( + socket, reinterpret_cast(resolved_addr->addr), + resolved_addr->len, custom_connect_callback); +} + +grpc_tcp_client_vtable custom_tcp_client_vtable = {tcp_connect}; diff --git a/src/core/lib/iomgr/tcp_client_posix.cc b/src/core/lib/iomgr/tcp_client_posix.cc new file mode 100644 index 00000000..f763cfea --- /dev/null +++ b/src/core/lib/iomgr/tcp_client_posix.cc @@ -0,0 +1,360 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_TCP_CLIENT + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_mutator.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/iomgr/tcp_client_posix.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" +#include "src/core/lib/slice/slice_internal.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; + +struct async_connect { + gpr_mu mu; + grpc_fd* fd; + grpc_timer alarm; + grpc_closure on_alarm; + int refs; + grpc_closure write_closure; + grpc_pollset_set* interested_parties; + std::string addr_str; + grpc_endpoint** ep; + grpc_closure* closure; + grpc_channel_args* channel_args; + grpc_slice_allocator* slice_allocator; +}; + +static grpc_error_handle prepare_socket(const grpc_resolved_address* addr, + int fd, + const grpc_channel_args* channel_args) { + grpc_error_handle err = GRPC_ERROR_NONE; + + GPR_ASSERT(fd >= 0); + + err = grpc_set_socket_nonblocking(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + err = grpc_set_socket_cloexec(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + if (!grpc_is_unix_socket(addr)) { + err = grpc_set_socket_low_latency(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + err = grpc_set_socket_reuse_addr(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + err = grpc_set_socket_tcp_user_timeout(fd, channel_args, + true /* is_client */); + if (err != GRPC_ERROR_NONE) goto error; + } + err = grpc_set_socket_no_sigpipe_if_possible(fd); + if (err != GRPC_ERROR_NONE) goto error; + + err = grpc_apply_socket_mutator_in_args(fd, GRPC_FD_CLIENT_CONNECTION_USAGE, + channel_args); + if (err != GRPC_ERROR_NONE) goto error; + + goto done; + +error: + if (fd >= 0) { + close(fd); + } +done: + return err; +} + +static void tc_on_alarm(void* acp, grpc_error_handle error) { + int done; + async_connect* ac = static_cast(acp); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "CLIENT_CONNECT: %s: on_alarm: error=%s", + ac->addr_str.c_str(), grpc_error_std_string(error).c_str()); + } + gpr_mu_lock(&ac->mu); + if (ac->fd != nullptr) { + grpc_fd_shutdown( + ac->fd, GRPC_ERROR_CREATE_FROM_STATIC_STRING("connect() timed out")); + } + done = (--ac->refs == 0); + gpr_mu_unlock(&ac->mu); + if (done) { + gpr_mu_destroy(&ac->mu); + if (ac->slice_allocator != nullptr) { + grpc_slice_allocator_destroy(ac->slice_allocator); + } + grpc_channel_args_destroy(ac->channel_args); + delete ac; + } +} + +grpc_endpoint* grpc_tcp_client_create_from_fd( + grpc_fd* fd, const grpc_channel_args* channel_args, const char* addr_str, + grpc_slice_allocator* slice_allocator) { + return grpc_tcp_create(fd, channel_args, addr_str, slice_allocator); +} + +static void on_writable(void* acp, grpc_error_handle error) { + async_connect* ac = static_cast(acp); + int so_error = 0; + socklen_t so_error_size; + int err; + int done; + grpc_endpoint** ep = ac->ep; + grpc_closure* closure = ac->closure; + grpc_fd* fd; + + (void)GRPC_ERROR_REF(error); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + 
gpr_log(GPR_INFO, "CLIENT_CONNECT: %s: on_writable: error=%s", + ac->addr_str.c_str(), grpc_error_std_string(error).c_str()); + } + + gpr_mu_lock(&ac->mu); + GPR_ASSERT(ac->fd); + fd = ac->fd; + ac->fd = nullptr; + gpr_mu_unlock(&ac->mu); + + grpc_timer_cancel(&ac->alarm); + + gpr_mu_lock(&ac->mu); + if (error != GRPC_ERROR_NONE) { + error = + grpc_error_set_str(error, GRPC_ERROR_STR_OS_ERROR, "Timeout occurred"); + goto finish; + } + + do { + so_error_size = sizeof(so_error); + err = getsockopt(grpc_fd_wrapped_fd(fd), SOL_SOCKET, SO_ERROR, &so_error, + &so_error_size); + } while (err < 0 && errno == EINTR); + if (err < 0) { + error = GRPC_OS_ERROR(errno, "getsockopt"); + goto finish; + } + + switch (so_error) { + case 0: + grpc_pollset_set_del_fd(ac->interested_parties, fd); + *ep = grpc_tcp_client_create_from_fd( + fd, ac->channel_args, ac->addr_str.c_str(), ac->slice_allocator); + ac->slice_allocator = nullptr; + fd = nullptr; + break; + case ENOBUFS: + /* We will get one of these errors if we have run out of + memory in the kernel for the data structures allocated + when you connect a socket. If this happens it is very + likely that if we wait a little bit then try again the + connection will work (since other programs or this + program will close their network connections and free up + memory). This does _not_ indicate that there is anything + wrong with the server we are connecting to, this is a + local problem. + + If you are looking at this code, then chances are that + your program or another program on the same computer + opened too many network connections. The "easy" fix: + don't do that! */ + gpr_log(GPR_ERROR, "kernel out of buffers"); + gpr_mu_unlock(&ac->mu); + grpc_fd_notify_on_write(fd, &ac->write_closure); + return; + case ECONNREFUSED: + /* This error shouldn't happen for anything other than connect(). */ + error = GRPC_OS_ERROR(so_error, "connect"); + break; + default: + /* We don't really know which syscall triggered the problem here, + so punt by reporting getsockopt(). */ + error = GRPC_OS_ERROR(so_error, "getsockopt(SO_ERROR)"); + break; + } + +finish: + if (fd != nullptr) { + grpc_pollset_set_del_fd(ac->interested_parties, fd); + grpc_fd_orphan(fd, nullptr, nullptr, "tcp_client_orphan"); + fd = nullptr; + } + done = (--ac->refs == 0); + gpr_mu_unlock(&ac->mu); + if (error != GRPC_ERROR_NONE) { + std::string str; + bool ret = grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, &str); + GPR_ASSERT(ret); + std::string description = + absl::StrCat("Failed to connect to remote host: ", str); + error = grpc_error_set_str(error, GRPC_ERROR_STR_DESCRIPTION, description); + error = + grpc_error_set_str(error, GRPC_ERROR_STR_TARGET_ADDRESS, ac->addr_str); + } + if (done) { + // This is safe even outside the lock, because "done", the sentinel, is + // populated *inside* the lock. 
+ gpr_mu_destroy(&ac->mu); + if (ac->slice_allocator != nullptr) { + grpc_slice_allocator_destroy(ac->slice_allocator); + ac->slice_allocator = nullptr; + } + grpc_channel_args_destroy(ac->channel_args); + delete ac; + } + // Push async connect closure to the executor since this may actually be + // called during the shutdown process, in which case a deadlock could form + // between the core shutdown mu and the connector mu (b/188239051) + grpc_core::Executor::Run(closure, error); +} + +grpc_error_handle grpc_tcp_client_prepare_fd( + const grpc_channel_args* channel_args, const grpc_resolved_address* addr, + grpc_resolved_address* mapped_addr, int* fd) { + grpc_dualstack_mode dsmode; + grpc_error_handle error; + *fd = -1; + /* Use dualstack sockets where available. Set mapped to v6 or v4 mapped to + v6. */ + if (!grpc_sockaddr_to_v4mapped(addr, mapped_addr)) { + /* addr is v4 mapped to v6 or v6. */ + memcpy(mapped_addr, addr, sizeof(*mapped_addr)); + } + error = + grpc_create_dualstack_socket(mapped_addr, SOCK_STREAM, 0, &dsmode, fd); + if (error != GRPC_ERROR_NONE) { + return error; + } + if (dsmode == GRPC_DSMODE_IPV4) { + /* Original addr is either v4 or v4 mapped to v6. Set mapped_addr to v4. */ + if (!grpc_sockaddr_is_v4mapped(addr, mapped_addr)) { + memcpy(mapped_addr, addr, sizeof(*mapped_addr)); + } + } + if ((error = prepare_socket(mapped_addr, *fd, channel_args)) != + GRPC_ERROR_NONE) { + return error; + } + return GRPC_ERROR_NONE; +} + +void grpc_tcp_client_create_from_prepared_fd( + grpc_pollset_set* interested_parties, grpc_closure* closure, const int fd, + const grpc_channel_args* channel_args, const grpc_resolved_address* addr, + grpc_millis deadline, grpc_endpoint** ep, + grpc_slice_allocator* slice_allocator) { + int err; + do { + err = connect(fd, reinterpret_cast(addr->addr), + addr->len); + } while (err < 0 && errno == EINTR); + + std::string name = absl::StrCat("tcp-client:", grpc_sockaddr_to_uri(addr)); + grpc_fd* fdobj = grpc_fd_create(fd, name.c_str(), true); + + if (err >= 0) { + *ep = grpc_tcp_client_create_from_fd(fdobj, channel_args, + grpc_sockaddr_to_uri(addr).c_str(), + slice_allocator); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + return; + } + if (errno != EWOULDBLOCK && errno != EINPROGRESS) { + grpc_slice_allocator_destroy(slice_allocator); + grpc_error_handle error = GRPC_OS_ERROR(errno, "connect"); + error = grpc_error_set_str(error, GRPC_ERROR_STR_TARGET_ADDRESS, + grpc_sockaddr_to_uri(addr)); + grpc_fd_orphan(fdobj, nullptr, nullptr, "tcp_client_connect_error"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); + return; + } + + grpc_pollset_set_add_fd(interested_parties, fdobj); + + async_connect* ac = new async_connect(); + ac->closure = closure; + ac->ep = ep; + ac->fd = fdobj; + ac->interested_parties = interested_parties; + ac->addr_str = grpc_sockaddr_to_uri(addr); + gpr_mu_init(&ac->mu); + ac->refs = 2; + ac->slice_allocator = slice_allocator; + GRPC_CLOSURE_INIT(&ac->write_closure, on_writable, ac, + grpc_schedule_on_exec_ctx); + ac->channel_args = grpc_channel_args_copy(channel_args); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "CLIENT_CONNECT: %s: asynchronously connecting fd %p", + ac->addr_str.c_str(), fdobj); + } + + gpr_mu_lock(&ac->mu); + GRPC_CLOSURE_INIT(&ac->on_alarm, tc_on_alarm, ac, grpc_schedule_on_exec_ctx); + grpc_timer_init(&ac->alarm, deadline, &ac->on_alarm); + grpc_fd_notify_on_write(ac->fd, &ac->write_closure); + gpr_mu_unlock(&ac->mu); +} + +static void 
tcp_connect(grpc_closure* closure, grpc_endpoint** ep, + grpc_slice_allocator* slice_allocator, + grpc_pollset_set* interested_parties, + const grpc_channel_args* channel_args, + const grpc_resolved_address* addr, + grpc_millis deadline) { + grpc_resolved_address mapped_addr; + int fd = -1; + grpc_error_handle error; + *ep = nullptr; + if ((error = grpc_tcp_client_prepare_fd(channel_args, addr, &mapped_addr, + &fd)) != GRPC_ERROR_NONE) { + grpc_slice_allocator_destroy(slice_allocator); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); + return; + } + grpc_tcp_client_create_from_prepared_fd(interested_parties, closure, fd, + channel_args, &mapped_addr, deadline, + ep, slice_allocator); +} + +grpc_tcp_client_vtable grpc_posix_tcp_client_vtable = {tcp_connect}; +#endif diff --git a/src/core/lib/iomgr/tcp_client_windows.cc b/src/core/lib/iomgr/tcp_client_windows.cc new file mode 100644 index 00000000..7d201048 --- /dev/null +++ b/src/core/lib/iomgr/tcp_client_windows.cc @@ -0,0 +1,241 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_windows.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/slice/slice_internal.h" + +struct async_connect { + grpc_closure* on_done; + gpr_mu mu; + grpc_winsocket* socket; + grpc_timer alarm; + grpc_closure on_alarm; + std::string addr_name; + int refs; + grpc_closure on_connect; + grpc_endpoint** endpoint; + grpc_channel_args* channel_args; + grpc_slice_allocator* slice_allocator; +}; + +static void async_connect_unlock_and_cleanup(async_connect* ac, + grpc_winsocket* socket) { + int done = (--ac->refs == 0); + gpr_mu_unlock(&ac->mu); + if (done) { + grpc_channel_args_destroy(ac->channel_args); + gpr_mu_destroy(&ac->mu); + if (ac->slice_allocator != nullptr) { + grpc_slice_allocator_destroy(ac->slice_allocator); + } + delete ac; + } + if (socket != NULL) grpc_winsocket_destroy(socket); +} + +static void on_alarm(void* acp, grpc_error_handle error) { + async_connect* ac = (async_connect*)acp; + gpr_mu_lock(&ac->mu); + grpc_winsocket* socket = ac->socket; + ac->socket = NULL; + if (socket != NULL) { + grpc_winsocket_shutdown(socket); + } + async_connect_unlock_and_cleanup(ac, socket); +} + +static void on_connect(void* acp, grpc_error_handle error) { + async_connect* ac = (async_connect*)acp; + grpc_endpoint** ep = ac->endpoint; + GPR_ASSERT(*ep == NULL); + grpc_closure* on_done = ac->on_done; + + (void)GRPC_ERROR_REF(error); + + gpr_mu_lock(&ac->mu); + grpc_winsocket* socket = 
ac->socket; + ac->socket = NULL; + gpr_mu_unlock(&ac->mu); + + grpc_timer_cancel(&ac->alarm); + + gpr_mu_lock(&ac->mu); + + if (error == GRPC_ERROR_NONE) { + if (socket != NULL) { + DWORD transfered_bytes = 0; + DWORD flags; + BOOL wsa_success = + WSAGetOverlappedResult(socket->socket, &socket->write_info.overlapped, + &transfered_bytes, FALSE, &flags); + GPR_ASSERT(transfered_bytes == 0); + if (!wsa_success) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "ConnectEx"); + closesocket(socket->socket); + } else { + *ep = grpc_tcp_create(socket, ac->channel_args, ac->addr_name.c_str(), + ac->slice_allocator); + ac->slice_allocator = nullptr; + socket = nullptr; + } + } else { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("socket is null"); + } + } + + async_connect_unlock_and_cleanup(ac, socket); + /* If the connection was aborted, the callback was already called when + the deadline was met. */ + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, error); +} + +/* Tries to issue one async connection, then schedules both an IOCP + notification request for the connection, and one timeout alert. */ +static void tcp_connect(grpc_closure* on_done, grpc_endpoint** endpoint, + grpc_slice_allocator* slice_allocator, + grpc_pollset_set* interested_parties, + const grpc_channel_args* channel_args, + const grpc_resolved_address* addr, + grpc_millis deadline) { + SOCKET sock = INVALID_SOCKET; + BOOL success; + int status; + grpc_resolved_address addr6_v4mapped; + grpc_resolved_address local_address; + grpc_winsocket* socket = NULL; + LPFN_CONNECTEX ConnectEx; + GUID guid = WSAID_CONNECTEX; + DWORD ioctl_num_bytes; + grpc_winsocket_callback_info* info; + grpc_error_handle error = GRPC_ERROR_NONE; + async_connect* ac = NULL; + + *endpoint = NULL; + + /* Use dualstack sockets where available. */ + if (grpc_sockaddr_to_v4mapped(addr, &addr6_v4mapped)) { + addr = &addr6_v4mapped; + } + + sock = WSASocket(AF_INET6, SOCK_STREAM, IPPROTO_TCP, NULL, 0, + grpc_get_default_wsa_socket_flags()); + if (sock == INVALID_SOCKET) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "WSASocket"); + goto failure; + } + + error = grpc_tcp_prepare_socket(sock); + if (error != GRPC_ERROR_NONE) { + goto failure; + } + + /* Grab the function pointer for ConnectEx for that specific socket. + It may change depending on the interface. */ + status = + WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid), + &ConnectEx, sizeof(ConnectEx), &ioctl_num_bytes, NULL, NULL); + + if (status != 0) { + error = GRPC_WSA_ERROR(WSAGetLastError(), + "WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER)"); + goto failure; + } + + grpc_sockaddr_make_wildcard6(0, &local_address); + + status = + bind(sock, (grpc_sockaddr*)&local_address.addr, (int)local_address.len); + if (status != 0) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "bind"); + goto failure; + } + + socket = grpc_winsocket_create(sock, "client"); + info = &socket->write_info; + success = ConnectEx(sock, (grpc_sockaddr*)&addr->addr, (int)addr->len, NULL, + 0, NULL, &info->overlapped); + + /* It wouldn't be unusual to get a success immediately. But we'll still get + an IOCP notification, so let's ignore it. 
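As background for the Windows path: ConnectEx (obtained above via WSAIoctl) requires the socket to be bound first, and a completed overlapped connect is confirmed with WSAGetOverlappedResult, exactly as on_connect does. A minimal sketch of that completion check in isolation; the helper name is illustrative, and the SO_UPDATE_CONNECT_CONTEXT call is the generic Winsock recommendation for ConnectEx rather than something this hunk performs:

#include <winsock2.h>
#include <mswsock.h>

// Confirms an overlapped ConnectEx() after its completion has been signalled.
// Returns 0 on success, otherwise a WSA error code.
int finish_connect_ex(SOCKET s, WSAOVERLAPPED* ov) {
  DWORD transferred = 0;
  DWORD flags = 0;
  if (!WSAGetOverlappedResult(s, ov, &transferred, FALSE, &flags)) {
    return WSAGetLastError();
  }
  // Per the ConnectEx documentation, the socket only behaves like a normally
  // connected socket (getpeername, shutdown, ...) once this option is set.
  if (setsockopt(s, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, nullptr, 0) != 0) {
    return WSAGetLastError();
  }
  return 0;
}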
*/ + if (!success) { + int last_error = WSAGetLastError(); + if (last_error != ERROR_IO_PENDING) { + error = GRPC_WSA_ERROR(last_error, "ConnectEx"); + goto failure; + } + } + + ac = new async_connect(); + ac->on_done = on_done; + ac->socket = socket; + gpr_mu_init(&ac->mu); + ac->refs = 2; + ac->addr_name = grpc_sockaddr_to_uri(addr); + ac->endpoint = endpoint; + ac->slice_allocator = slice_allocator; + ac->channel_args = grpc_channel_args_copy(channel_args); + GRPC_CLOSURE_INIT(&ac->on_connect, on_connect, ac, grpc_schedule_on_exec_ctx); + + GRPC_CLOSURE_INIT(&ac->on_alarm, on_alarm, ac, grpc_schedule_on_exec_ctx); + gpr_mu_lock(&ac->mu); + grpc_timer_init(&ac->alarm, deadline, &ac->on_alarm); + grpc_socket_notify_on_write(socket, &ac->on_connect); + gpr_mu_unlock(&ac->mu); + return; + +failure: + GPR_ASSERT(error != GRPC_ERROR_NONE); + std::string target_uri = grpc_sockaddr_to_uri(addr); + grpc_error_handle final_error = + grpc_error_set_str(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed to connect", &error, 1), + GRPC_ERROR_STR_TARGET_ADDRESS, target_uri); + GRPC_ERROR_UNREF(error); + grpc_slice_allocator_destroy(slice_allocator); + if (socket != NULL) { + grpc_winsocket_destroy(socket); + } else if (sock != INVALID_SOCKET) { + closesocket(sock); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, final_error); +} + +grpc_tcp_client_vtable grpc_windows_tcp_client_vtable = {tcp_connect}; + +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/tcp_custom.cc b/src/core/lib/iomgr/tcp_custom.cc new file mode 100644 index 00000000..a74bc288 --- /dev/null +++ b/src/core/lib/iomgr/tcp_custom.cc @@ -0,0 +1,377 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/tcp_custom.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/iomgr_custom.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/resource_quota.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +#define GRPC_TCP_DEFAULT_READ_SLICE_SIZE 8192 + +extern grpc_core::TraceFlag grpc_tcp_trace; + +grpc_socket_vtable* grpc_custom_socket_vtable = nullptr; +extern grpc_tcp_server_vtable custom_tcp_server_vtable; +extern grpc_tcp_client_vtable custom_tcp_client_vtable; + +void grpc_custom_endpoint_init(grpc_socket_vtable* impl) { + grpc_custom_socket_vtable = impl; + grpc_set_tcp_client_impl(&custom_tcp_client_vtable); + grpc_set_tcp_server_impl(&custom_tcp_server_vtable); +} + +struct custom_tcp_endpoint { + grpc_endpoint base; + gpr_refcount refcount; + grpc_custom_socket* socket; + + grpc_closure* read_cb = nullptr; + grpc_closure* write_cb = nullptr; + + grpc_slice_buffer* read_slices = nullptr; + grpc_slice_buffer* write_slices = nullptr; + + grpc_slice_allocator* slice_allocator; + + bool shutting_down; + + std::string peer_string; + std::string local_address; +}; +static void tcp_free(grpc_custom_socket* s) { + custom_tcp_endpoint* tcp = + reinterpret_cast(s->endpoint); + grpc_slice_allocator_destroy(tcp->slice_allocator); + delete tcp; + s->refs--; + if (s->refs == 0) { + grpc_custom_socket_vtable->destroy(s); + gpr_free(s); + } +} + +#ifndef NDEBUG +#define TCP_UNREF(tcp, reason) tcp_unref((tcp), (reason), __FILE__, __LINE__) +#define TCP_REF(tcp, reason) tcp_ref((tcp), (reason), __FILE__, __LINE__) +static void tcp_unref(custom_tcp_endpoint* tcp, const char* reason, + const char* file, int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_atm val = gpr_atm_no_barrier_load(&tcp->refcount.count); + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, + "TCP unref %p : %s %" PRIdPTR " -> %" PRIdPTR, tcp->socket, reason, + val, val - 1); + } + if (gpr_unref(&tcp->refcount)) { + tcp_free(tcp->socket); + } +} + +static void tcp_ref(custom_tcp_endpoint* tcp, const char* reason, + const char* file, int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_atm val = gpr_atm_no_barrier_load(&tcp->refcount.count); + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, + "TCP ref %p : %s %" PRIdPTR " -> %" PRIdPTR, tcp->socket, reason, + val, val + 1); + } + gpr_ref(&tcp->refcount); +} +#else +#define TCP_UNREF(tcp, reason) tcp_unref((tcp)) +#define TCP_REF(tcp, reason) tcp_ref((tcp)) +static void tcp_unref(custom_tcp_endpoint* tcp) { + if (gpr_unref(&tcp->refcount)) { + tcp_free(tcp->socket); + } +} + +static void tcp_ref(custom_tcp_endpoint* tcp) { gpr_ref(&tcp->refcount); } +#endif + +static void call_read_cb(custom_tcp_endpoint* tcp, grpc_error_handle error) { + grpc_closure* cb = tcp->read_cb; + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p call_cb %p %p:%p", tcp->socket, cb, cb->cb, + cb->cb_arg); + size_t i; + gpr_log(GPR_INFO, "read: error=%s", grpc_error_std_string(error).c_str()); + for (i = 0; i < tcp->read_slices->count; i++) { + char* dump = grpc_dump_slice(tcp->read_slices->slices[i], + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "READ %p (peer=%s): %s", tcp, tcp->peer_string.c_str(), + dump); + 
gpr_free(dump); + } + } + TCP_UNREF(tcp, "read"); + tcp->read_slices = nullptr; + tcp->read_cb = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); +} + +static void custom_read_callback(grpc_custom_socket* socket, size_t nread, + grpc_error_handle error) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_slice_buffer garbage; + custom_tcp_endpoint* tcp = + reinterpret_cast(socket->endpoint); + if (error == GRPC_ERROR_NONE && nread == 0) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("EOF"); + } + if (error == GRPC_ERROR_NONE) { + // Successful read + if (nread < tcp->read_slices->length) { + /* TODO(murgatroid99): Instead of discarding the unused part of the read + * buffer, reuse it as the next read buffer. */ + grpc_slice_buffer_init(&garbage); + grpc_slice_buffer_trim_end(tcp->read_slices, + tcp->read_slices->length - nread, &garbage); + grpc_slice_buffer_reset_and_unref_internal(&garbage); + } + } else { + grpc_slice_buffer_reset_and_unref_internal(tcp->read_slices); + } + call_read_cb(tcp, error); +} + +static void tcp_read_allocation_done(void* tcpp, grpc_error_handle error) { + custom_tcp_endpoint* tcp = static_cast(tcpp); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p read_allocation_done: %s", tcp->socket, + grpc_error_std_string(error).c_str()); + } + if (error == GRPC_ERROR_NONE) { + /* Before calling read, we allocate a buffer with exactly one slice + * to tcp->read_slices and wait for the callback indicating that the + * allocation was successful. So slices[0] should always exist here */ + char* buffer = reinterpret_cast( + GRPC_SLICE_START_PTR(tcp->read_slices->slices[0])); + size_t len = GRPC_SLICE_LENGTH(tcp->read_slices->slices[0]); + grpc_custom_socket_vtable->read(tcp->socket, buffer, len, + custom_read_callback); + } else { + grpc_slice_buffer_reset_and_unref_internal(tcp->read_slices); + call_read_cb(tcp, GRPC_ERROR_REF(error)); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "Initiating read on %p: error=%s", tcp->socket, + grpc_error_std_string(error).c_str()); + } +} + +static void endpoint_read(grpc_endpoint* ep, grpc_slice_buffer* read_slices, + grpc_closure* cb, bool /*urgent*/) { + custom_tcp_endpoint* tcp = reinterpret_cast(ep); + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + GPR_ASSERT(tcp->read_cb == nullptr); + tcp->read_cb = cb; + tcp->read_slices = read_slices; + grpc_slice_buffer_reset_and_unref_internal(read_slices); + TCP_REF(tcp, "read"); + if (grpc_slice_allocator_allocate( + tcp->slice_allocator, GRPC_TCP_DEFAULT_READ_SLICE_SIZE, 1, + grpc_slice_allocator_intent::kReadBuffer, tcp->read_slices, + tcp_read_allocation_done, tcp)) { + tcp_read_allocation_done(tcp, GRPC_ERROR_NONE); + } +} + +static void custom_write_callback(grpc_custom_socket* socket, + grpc_error_handle error) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + custom_tcp_endpoint* tcp = + reinterpret_cast(socket->endpoint); + grpc_closure* cb = tcp->write_cb; + tcp->write_cb = nullptr; + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "write complete on %p: error=%s", tcp->socket, + grpc_error_std_string(error).c_str()); + } + TCP_UNREF(tcp, "write"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); +} + +static void endpoint_write(grpc_endpoint* ep, grpc_slice_buffer* write_slices, + grpc_closure* cb, void* /*arg*/) { + custom_tcp_endpoint* tcp = reinterpret_cast(ep); + 
GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + size_t j; + + for (j = 0; j < write_slices->count; j++) { + char* data = grpc_dump_slice(write_slices->slices[j], + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "WRITE %p (peer=%s): %s", tcp->socket, + tcp->peer_string.c_str(), data); + gpr_free(data); + } + } + + if (tcp->shutting_down) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, cb, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("TCP socket is shutting down")); + return; + } + + GPR_ASSERT(tcp->write_cb == nullptr); + tcp->write_slices = write_slices; + GPR_ASSERT(tcp->write_slices->count <= UINT_MAX); + if (tcp->write_slices->count == 0) { + // No slices means we don't have to do anything + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); + return; + } + tcp->write_cb = cb; + TCP_REF(tcp, "write"); + grpc_custom_socket_vtable->write(tcp->socket, tcp->write_slices, + custom_write_callback); +} + +static void endpoint_add_to_pollset(grpc_endpoint* ep, grpc_pollset* pollset) { + // No-op. We're ignoring pollsets currently + (void)ep; + (void)pollset; +} + +static void endpoint_add_to_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset) { + // No-op. We're ignoring pollsets currently + (void)ep; + (void)pollset; +} + +static void endpoint_delete_from_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset) { + // No-op. We're ignoring pollsets currently + (void)ep; + (void)pollset; +} + +static void endpoint_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + custom_tcp_endpoint* tcp = reinterpret_cast(ep); + if (!tcp->shutting_down) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP %p shutdown why=%s", tcp->socket, + grpc_error_std_string(why).c_str()); + } + tcp->shutting_down = true; + // grpc_core::ExecCtx::Run(DEBUG_LOCATION,tcp->read_cb, + // GRPC_ERROR_REF(why)); + // grpc_core::ExecCtx::Run(DEBUG_LOCATION,tcp->write_cb, + // GRPC_ERROR_REF(why)); tcp->read_cb = nullptr; tcp->write_cb = nullptr; + grpc_custom_socket_vtable->shutdown(tcp->socket); + } + GRPC_ERROR_UNREF(why); +} + +static void custom_close_callback(grpc_custom_socket* socket) { + socket->refs--; + if (socket->refs == 0) { + grpc_custom_socket_vtable->destroy(socket); + gpr_free(socket); + } else if (socket->endpoint) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + custom_tcp_endpoint* tcp = + reinterpret_cast(socket->endpoint); + TCP_UNREF(tcp, "destroy"); + } +} + +static void endpoint_destroy(grpc_endpoint* ep) { + custom_tcp_endpoint* tcp = reinterpret_cast(ep); + grpc_custom_socket_vtable->close(tcp->socket, custom_close_callback); +} + +static absl::string_view endpoint_get_peer(grpc_endpoint* ep) { + custom_tcp_endpoint* tcp = reinterpret_cast(ep); + return tcp->peer_string; +} + +static absl::string_view endpoint_get_local_address(grpc_endpoint* ep) { + custom_tcp_endpoint* tcp = reinterpret_cast(ep); + return tcp->local_address; +} + +static int endpoint_get_fd(grpc_endpoint* /*ep*/) { return -1; } + +static bool endpoint_can_track_err(grpc_endpoint* /*ep*/) { return false; } + +static grpc_endpoint_vtable vtable = {endpoint_read, + endpoint_write, + endpoint_add_to_pollset, + endpoint_add_to_pollset_set, + endpoint_delete_from_pollset_set, + endpoint_shutdown, + endpoint_destroy, + endpoint_get_peer, + endpoint_get_local_address, + endpoint_get_fd, + endpoint_can_track_err}; + +grpc_endpoint* custom_tcp_endpoint_create(grpc_custom_socket* socket, + grpc_slice_allocator* 
slice_allocator, + const char* peer_string) { + custom_tcp_endpoint* tcp = new custom_tcp_endpoint; + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "Creating TCP endpoint %p", socket); + } + socket->refs++; + socket->endpoint = reinterpret_cast(tcp); + tcp->socket = socket; + tcp->base.vtable = &vtable; + gpr_ref_init(&tcp->refcount, 1); + tcp->peer_string = peer_string; + grpc_resolved_address resolved_local_addr; + resolved_local_addr.len = sizeof(resolved_local_addr.addr); + if (grpc_custom_socket_vtable->getsockname( + socket, reinterpret_cast(resolved_local_addr.addr), + reinterpret_cast(&resolved_local_addr.len)) != + GRPC_ERROR_NONE) { + tcp->local_address = ""; + } else { + tcp->local_address = grpc_sockaddr_to_uri(&resolved_local_addr); + } + tcp->shutting_down = false; + tcp->slice_allocator = slice_allocator; + return &tcp->base; +} diff --git a/src/core/lib/iomgr/tcp_posix.cc b/src/core/lib/iomgr/tcp_posix.cc new file mode 100644 index 00000000..a3e312b4 --- /dev/null +++ b/src/core/lib/iomgr/tcp_posix.cc @@ -0,0 +1,1848 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_TCP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/buffer_list.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/resource_quota.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +#ifndef SOL_TCP +#define SOL_TCP IPPROTO_TCP +#endif + +#ifndef TCP_INQ +#define TCP_INQ 36 +#define TCP_CM_INQ TCP_INQ +#endif + +#ifdef GRPC_HAVE_MSG_NOSIGNAL +#define SENDMSG_FLAGS MSG_NOSIGNAL +#else +#define SENDMSG_FLAGS 0 +#endif + +// TCP zero copy sendmsg flag. +// NB: We define this here as a fallback in case we're using an older set of +// library headers that has not defined MSG_ZEROCOPY. Since this constant is +// part of the kernel, we are guaranteed it will never change/disagree so +// defining it here is safe. 
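For readers unfamiliar with MSG_ZEROCOPY: the kernel only honours the flag once SO_ZEROCOPY has been enabled on the socket, and every zerocopy sendmsg() later produces a completion notification on the socket's error queue telling the sender when the pinned pages may be reused. That notification protocol is what the TcpZerocopySendCtx/TcpZerocopySendRecord machinery below manages. A minimal, Linux-only sketch of the raw round trip (illustrative helper names, no batching or reference counting):

#include <cerrno>
#include <cstdint>
#include <netinet/in.h>
#include <sys/socket.h>
#include <linux/errqueue.h>   // sock_extended_err, SO_EE_ORIGIN_ZEROCOPY

#ifndef SO_ZEROCOPY
#define SO_ZEROCOPY 60  // fallback for older headers; matches asm-generic
#endif

// Enable zerocopy once per socket; sendmsg(fd, &msg, MSG_ZEROCOPY) is
// rejected by the kernel without it.
bool enable_zerocopy(int fd) {
  const int one = 1;
  return setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one)) == 0;
}

// Drain one completion notification from the error queue. Each notification
// carries a range [ee_info, ee_data] of sendmsg() sequence numbers whose
// buffers the kernel has finished with.
bool read_zerocopy_completion(int fd, uint32_t* lo, uint32_t* hi) {
  union {
    char buf[CMSG_SPACE(sizeof(sock_extended_err) + sizeof(sockaddr_in))];
    cmsghdr align;
  } control;
  msghdr msg = {};
  msg.msg_control = control.buf;
  msg.msg_controllen = sizeof(control.buf);
  if (recvmsg(fd, &msg, MSG_ERRQUEUE) < 0) return false;  // EAGAIN: queue empty
  for (cmsghdr* cm = CMSG_FIRSTHDR(&msg); cm != nullptr;
       cm = CMSG_NXTHDR(&msg, cm)) {
    auto* serr = reinterpret_cast<sock_extended_err*>(CMSG_DATA(cm));
    if (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY && serr->ee_errno == 0) {
      *lo = serr->ee_info;
      *hi = serr->ee_data;
      return true;
    }
  }
  return false;
}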
+#ifndef MSG_ZEROCOPY +#define MSG_ZEROCOPY 0x4000000 +#endif + +#ifdef GRPC_MSG_IOVLEN_TYPE +typedef GRPC_MSG_IOVLEN_TYPE msg_iovlen_type; +#else +typedef size_t msg_iovlen_type; +#endif + +extern grpc_core::TraceFlag grpc_tcp_trace; + +namespace grpc_core { + +class TcpZerocopySendRecord { + public: + TcpZerocopySendRecord() { grpc_slice_buffer_init(&buf_); } + + ~TcpZerocopySendRecord() { + AssertEmpty(); + grpc_slice_buffer_destroy_internal(&buf_); + } + + // Given the slices that we wish to send, and the current offset into the + // slice buffer (indicating which have already been sent), populate an iovec + // array that will be used for a zerocopy enabled sendmsg(). + msg_iovlen_type PopulateIovs(size_t* unwind_slice_idx, + size_t* unwind_byte_idx, size_t* sending_length, + iovec* iov); + + // A sendmsg() may not be able to send the bytes that we requested at this + // time, returning EAGAIN (possibly due to backpressure). In this case, + // unwind the offset into the slice buffer so we retry sending these bytes. + void UnwindIfThrottled(size_t unwind_slice_idx, size_t unwind_byte_idx) { + out_offset_.byte_idx = unwind_byte_idx; + out_offset_.slice_idx = unwind_slice_idx; + } + + // Update the offset into the slice buffer based on how much we wanted to sent + // vs. what sendmsg() actually sent (which may be lower, possibly due to + // backpressure). + void UpdateOffsetForBytesSent(size_t sending_length, size_t actually_sent); + + // Indicates whether all underlying data has been sent or not. + bool AllSlicesSent() { return out_offset_.slice_idx == buf_.count; } + + // Reset this structure for a new tcp_write() with zerocopy. + void PrepareForSends(grpc_slice_buffer* slices_to_send) { + AssertEmpty(); + out_offset_.slice_idx = 0; + out_offset_.byte_idx = 0; + grpc_slice_buffer_swap(slices_to_send, &buf_); + Ref(); + } + + // References: 1 reference per sendmsg(), and 1 for the tcp_write(). + void Ref() { ref_.fetch_add(1, std::memory_order_relaxed); } + + // Unref: called when we get an error queue notification for a sendmsg(), if a + // sendmsg() failed or when tcp_write() is done. + bool Unref() { + const intptr_t prior = ref_.fetch_sub(1, std::memory_order_acq_rel); + GPR_DEBUG_ASSERT(prior > 0); + if (prior == 1) { + AllSendsComplete(); + return true; + } + return false; + } + + private: + struct OutgoingOffset { + size_t slice_idx = 0; + size_t byte_idx = 0; + }; + + void AssertEmpty() { + GPR_DEBUG_ASSERT(buf_.count == 0); + GPR_DEBUG_ASSERT(buf_.length == 0); + GPR_DEBUG_ASSERT(ref_.load(std::memory_order_relaxed) == 0); + } + + // When all sendmsg() calls associated with this tcp_write() have been + // completed (ie. we have received the notifications for each sequence number + // for each sendmsg()) and all reference counts have been dropped, drop our + // reference to the underlying data since we no longer need it. 
+ void AllSendsComplete() { + GPR_DEBUG_ASSERT(ref_.load(std::memory_order_relaxed) == 0); + grpc_slice_buffer_reset_and_unref_internal(&buf_); + } + + grpc_slice_buffer buf_; + std::atomic ref_{0}; + OutgoingOffset out_offset_; +}; + +class TcpZerocopySendCtx { + public: + static constexpr int kDefaultMaxSends = 4; + static constexpr size_t kDefaultSendBytesThreshold = 16 * 1024; // 16KB + + explicit TcpZerocopySendCtx( + int max_sends = kDefaultMaxSends, + size_t send_bytes_threshold = kDefaultSendBytesThreshold) + : max_sends_(max_sends), + free_send_records_size_(max_sends), + threshold_bytes_(send_bytes_threshold) { + send_records_ = static_cast( + gpr_malloc(max_sends * sizeof(*send_records_))); + free_send_records_ = static_cast( + gpr_malloc(max_sends * sizeof(*free_send_records_))); + if (send_records_ == nullptr || free_send_records_ == nullptr) { + gpr_free(send_records_); + gpr_free(free_send_records_); + gpr_log(GPR_INFO, "Disabling TCP TX zerocopy due to memory pressure.\n"); + memory_limited_ = true; + } else { + for (int idx = 0; idx < max_sends_; ++idx) { + new (send_records_ + idx) TcpZerocopySendRecord(); + free_send_records_[idx] = send_records_ + idx; + } + } + } + + ~TcpZerocopySendCtx() { + if (send_records_ != nullptr) { + for (int idx = 0; idx < max_sends_; ++idx) { + send_records_[idx].~TcpZerocopySendRecord(); + } + } + gpr_free(send_records_); + gpr_free(free_send_records_); + } + + // True if we were unable to allocate the various bookkeeping structures at + // transport initialization time. If memory limited, we do not zerocopy. + bool memory_limited() const { return memory_limited_; } + + // TCP send zerocopy maintains an implicit sequence number for every + // successful sendmsg() with zerocopy enabled; the kernel later gives us an + // error queue notification with this sequence number indicating that the + // underlying data buffers that we sent can now be released. Once that + // notification is received, we can release the buffers associated with this + // zerocopy send record. Here, we associate the sequence number with the data + // buffers that were sent with the corresponding call to sendmsg(). + void NoteSend(TcpZerocopySendRecord* record) { + record->Ref(); + AssociateSeqWithSendRecord(last_send_, record); + ++last_send_; + } + + // If sendmsg() actually failed, though, we need to revert the sequence number + // that we speculatively bumped before calling sendmsg(). Note that we bump + // this sequence number and perform relevant bookkeeping (see: NoteSend()) + // *before* calling sendmsg() since, if we called it *after* sendmsg(), then + // there is a possible race with the release notification which could occur on + // another thread before we do the necessary bookkeeping. Hence, calling + // NoteSend() *before* sendmsg() and implementing an undo function is needed. + void UndoSend() { + --last_send_; + if (ReleaseSendRecord(last_send_)->Unref()) { + // We should still be holding the ref taken by tcp_write(). + GPR_DEBUG_ASSERT(0); + } + } + + // Simply associate this send record (and the underlying sent data buffers) + // with the implicit sequence number for this zerocopy sendmsg(). + void AssociateSeqWithSendRecord(uint32_t seq, TcpZerocopySendRecord* record) { + MutexLock guard(&lock_); + ctx_lookup_.emplace(seq, record); + } + + // Get a send record for a send that we wish to do with zerocopy. 
+ TcpZerocopySendRecord* GetSendRecord() { + MutexLock guard(&lock_); + return TryGetSendRecordLocked(); + } + + // A given send record corresponds to a single tcp_write() with zerocopy + // enabled. This can result in several sendmsg() calls to flush all of the + // data to wire. Each sendmsg() takes a reference on the + // TcpZerocopySendRecord, and corresponds to a single sequence number. + // ReleaseSendRecord releases a reference on TcpZerocopySendRecord for a + // single sequence number. This is called either when we receive the relevant + // error queue notification (saying that we can discard the underlying + // buffers for this sendmsg()) is received from the kernel - or, in case + // sendmsg() was unsuccessful to begin with. + TcpZerocopySendRecord* ReleaseSendRecord(uint32_t seq) { + MutexLock guard(&lock_); + return ReleaseSendRecordLocked(seq); + } + + // After all the references to a TcpZerocopySendRecord are released, we can + // add it back to the pool (of size max_sends_). Note that we can only have + // max_sends_ tcp_write() instances with zerocopy enabled in flight at the + // same time. + void PutSendRecord(TcpZerocopySendRecord* record) { + GPR_DEBUG_ASSERT(record >= send_records_ && + record < send_records_ + max_sends_); + MutexLock guard(&lock_); + PutSendRecordLocked(record); + } + + // Indicate that we are disposing of this zerocopy context. This indicator + // will prevent new zerocopy writes from being issued. + void Shutdown() { shutdown_.store(true, std::memory_order_release); } + + // Indicates that there are no inflight tcp_write() instances with zerocopy + // enabled. + bool AllSendRecordsEmpty() { + MutexLock guard(&lock_); + return free_send_records_size_ == max_sends_; + } + + bool enabled() const { return enabled_; } + + void set_enabled(bool enabled) { + GPR_DEBUG_ASSERT(!enabled || !memory_limited()); + enabled_ = enabled; + } + + // Only use zerocopy if we are sending at least this many bytes. The + // additional overhead of reading the error queue for notifications means that + // zerocopy is not useful for small transfers. 
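Putting the pieces of this class together, the intended life cycle of a record is: GetSendRecord() and PrepareForSends() at the start of a zerocopy tcp_write(); NoteSend() immediately before each sendmsg(), with UndoSend() if that call fails outright; and, once the kernel's error-queue notification for a sequence number arrives, ReleaseSendRecord() followed by Unref() and, when the last reference drops, PutSendRecord(). A condensed, illustrative walk-through using only the methods defined here (fd, msg, attempted, outgoing_slices, ctx and seq are assumed to exist in the caller; locking and the error paths of the real code are omitted):

// Inside a zerocopy-enabled tcp_write(), conceptually:
TcpZerocopySendRecord* rec = ctx.GetSendRecord();   // may be nullptr if all
if (rec != nullptr) {                               // records are in flight
  rec->PrepareForSends(outgoing_slices);            // takes the write's ref
  while (!rec->AllSlicesSent()) {
    ctx.NoteSend(rec);                              // ref + seq *before* sendmsg()
    ssize_t sent = sendmsg(fd, &msg, MSG_ZEROCOPY);
    if (sent < 0) { ctx.UndoSend(); break; }        // revert the speculative seq
    rec->UpdateOffsetForBytesSent(attempted, static_cast<size_t>(sent));
  }
}
// Later, when the error queue reports sequence number `seq` as complete:
TcpZerocopySendRecord* done = ctx.ReleaseSendRecord(seq);
if (done->Unref()) ctx.PutSendRecord(done);         // all sendmsg() refs dropped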
+ size_t threshold_bytes() const { return threshold_bytes_; } + + private: + TcpZerocopySendRecord* ReleaseSendRecordLocked(uint32_t seq) { + auto iter = ctx_lookup_.find(seq); + GPR_DEBUG_ASSERT(iter != ctx_lookup_.end()); + TcpZerocopySendRecord* record = iter->second; + ctx_lookup_.erase(iter); + return record; + } + + TcpZerocopySendRecord* TryGetSendRecordLocked() { + if (shutdown_.load(std::memory_order_acquire)) { + return nullptr; + } + if (free_send_records_size_ == 0) { + return nullptr; + } + free_send_records_size_--; + return free_send_records_[free_send_records_size_]; + } + + void PutSendRecordLocked(TcpZerocopySendRecord* record) { + GPR_DEBUG_ASSERT(free_send_records_size_ < max_sends_); + free_send_records_[free_send_records_size_] = record; + free_send_records_size_++; + } + + TcpZerocopySendRecord* send_records_; + TcpZerocopySendRecord** free_send_records_; + int max_sends_; + int free_send_records_size_; + Mutex lock_; + uint32_t last_send_ = 0; + std::atomic shutdown_{false}; + bool enabled_ = false; + size_t threshold_bytes_ = kDefaultSendBytesThreshold; + std::unordered_map ctx_lookup_; + bool memory_limited_ = false; +}; + +} // namespace grpc_core + +using grpc_core::TcpZerocopySendCtx; +using grpc_core::TcpZerocopySendRecord; + +namespace { +struct grpc_tcp { + grpc_tcp(int max_sends, size_t send_bytes_threshold) + : tcp_zerocopy_send_ctx(max_sends, send_bytes_threshold) {} + grpc_endpoint base; + grpc_fd* em_fd; + int fd; + /* Used by the endpoint read function to distinguish the very first read call + * from the rest */ + bool is_first_read; + double target_length; + double bytes_read_this_round; + grpc_core::RefCount refcount; + gpr_atm shutdown_count; + + int min_read_chunk_size; + int max_read_chunk_size; + + /* garbage after the last read */ + grpc_slice_buffer last_read_buffer; + + grpc_slice_buffer* incoming_buffer; + int inq; /* bytes pending on the socket from the last read. */ + bool inq_capable; /* cache whether kernel supports inq */ + + grpc_slice_buffer* outgoing_buffer; + /* byte within outgoing_buffer->slices[0] to write next */ + size_t outgoing_byte_idx; + + grpc_closure* read_cb; + grpc_closure* write_cb; + grpc_closure* release_fd_cb; + int* release_fd; + + grpc_closure read_done_closure; + grpc_closure write_done_closure; + grpc_closure error_closure; + + std::string peer_string; + std::string local_address; + + grpc_slice_allocator* slice_allocator; + + grpc_core::TracedBuffer* tb_head; /* List of traced buffers */ + gpr_mu tb_mu; /* Lock for access to list of traced buffers */ + + /* grpc_endpoint_write takes an argument which if non-null means that the + * transport layer wants the TCP layer to collect timestamps for this write. + * This arg is forwarded to the timestamps callback function when the ACK + * timestamp is received from the kernel. This arg is a (void *) which allows + * users of this API to pass in a pointer to any kind of structure. This + * structure could actually be a tag or any book-keeping object that the user + * can use to distinguish between different traced writes. The only + * requirement from the TCP endpoint layer is that this arg should be non-null + * if the user wants timestamps for the write. */ + void* outgoing_buffer_arg; + /* A counter which starts at 0. It is initialized the first time the socket + * options for collecting timestamps are set, and is incremented with each + * byte sent. 
*/ + int bytes_counter; + bool socket_ts_enabled; /* True if timestamping options are set on the socket + */ + bool ts_capable; /* Cache whether we can set timestamping options */ + gpr_atm stop_error_notification; /* Set to 1 if we do not want to be notified + on errors anymore */ + TcpZerocopySendCtx tcp_zerocopy_send_ctx; + TcpZerocopySendRecord* current_zerocopy_send = nullptr; +}; + +struct backup_poller { + gpr_mu* pollset_mu; + grpc_closure run_poller; +}; + +} // namespace + +static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp); + +#define BACKUP_POLLER_POLLSET(b) ((grpc_pollset*)((b) + 1)) + +static grpc_core::Mutex* g_backup_poller_mu = nullptr; +static int g_uncovered_notifications_pending + ABSL_GUARDED_BY(g_backup_poller_mu); +static backup_poller* g_backup_poller ABSL_GUARDED_BY(g_backup_poller_mu); + +static void tcp_handle_read(void* arg /* grpc_tcp */, grpc_error_handle error); +static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error_handle error); +static void tcp_drop_uncovered_then_handle_write(void* arg /* grpc_tcp */, + grpc_error_handle error); + +static void done_poller(void* bp, grpc_error_handle /*error_ignored*/) { + backup_poller* p = static_cast(bp); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "BACKUP_POLLER:%p destroy", p); + } + grpc_pollset_destroy(BACKUP_POLLER_POLLSET(p)); + gpr_free(p); +} + +static void run_poller(void* bp, grpc_error_handle /*error_ignored*/) { + backup_poller* p = static_cast(bp); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "BACKUP_POLLER:%p run", p); + } + gpr_mu_lock(p->pollset_mu); + grpc_millis deadline = grpc_core::ExecCtx::Get()->Now() + 10 * GPR_MS_PER_SEC; + GRPC_STATS_INC_TCP_BACKUP_POLLER_POLLS(); + GRPC_LOG_IF_ERROR( + "backup_poller:pollset_work", + grpc_pollset_work(BACKUP_POLLER_POLLSET(p), nullptr, deadline)); + gpr_mu_unlock(p->pollset_mu); + g_backup_poller_mu->Lock(); + /* last "uncovered" notification is the ref that keeps us polling */ + if (g_uncovered_notifications_pending == 1) { + GPR_ASSERT(g_backup_poller == p); + g_backup_poller = nullptr; + g_uncovered_notifications_pending = 0; + g_backup_poller_mu->Unlock(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "BACKUP_POLLER:%p shutdown", p); + } + grpc_pollset_shutdown(BACKUP_POLLER_POLLSET(p), + GRPC_CLOSURE_INIT(&p->run_poller, done_poller, p, + grpc_schedule_on_exec_ctx)); + } else { + g_backup_poller_mu->Unlock(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "BACKUP_POLLER:%p reschedule", p); + } + grpc_core::Executor::Run(&p->run_poller, GRPC_ERROR_NONE, + grpc_core::ExecutorType::DEFAULT, + grpc_core::ExecutorJobType::LONG); + } +} + +static void drop_uncovered(grpc_tcp* /*tcp*/) { + int old_count; + backup_poller* p; + g_backup_poller_mu->Lock(); + p = g_backup_poller; + old_count = g_uncovered_notifications_pending--; + g_backup_poller_mu->Unlock(); + GPR_ASSERT(old_count > 1); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "BACKUP_POLLER:%p uncover cnt %d->%d", p, old_count, + old_count - 1); + } +} + +// gRPC API considers a Write operation to be done the moment it clears ‘flow +// control’ i.e., not necessarily sent on the wire. This means that the +// application MIGHT not call `grpc_completion_queue_next/pluck` in a timely +// manner when its `Write()` API is acked. 
+// +// We need to ensure that the fd is 'covered' (i.e being monitored by some +// polling thread and progress is made) and hence add it to a backup poller here +static void cover_self(grpc_tcp* tcp) { + backup_poller* p; + g_backup_poller_mu->Lock(); + int old_count = 0; + if (g_uncovered_notifications_pending == 0) { + g_uncovered_notifications_pending = 2; + p = static_cast( + gpr_zalloc(sizeof(*p) + grpc_pollset_size())); + g_backup_poller = p; + grpc_pollset_init(BACKUP_POLLER_POLLSET(p), &p->pollset_mu); + g_backup_poller_mu->Unlock(); + GRPC_STATS_INC_TCP_BACKUP_POLLERS_CREATED(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "BACKUP_POLLER:%p create", p); + } + grpc_core::Executor::Run( + GRPC_CLOSURE_INIT(&p->run_poller, run_poller, p, nullptr), + GRPC_ERROR_NONE, grpc_core::ExecutorType::DEFAULT, + grpc_core::ExecutorJobType::LONG); + } else { + old_count = g_uncovered_notifications_pending++; + p = g_backup_poller; + g_backup_poller_mu->Unlock(); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "BACKUP_POLLER:%p add %p cnt %d->%d", p, tcp, + old_count - 1, old_count); + } + grpc_pollset_add_fd(BACKUP_POLLER_POLLSET(p), tcp->em_fd); +} + +static void notify_on_read(grpc_tcp* tcp) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p notify_on_read", tcp); + } + grpc_fd_notify_on_read(tcp->em_fd, &tcp->read_done_closure); +} + +static void notify_on_write(grpc_tcp* tcp) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p notify_on_write", tcp); + } + if (!grpc_event_engine_run_in_background()) { + cover_self(tcp); + } + grpc_fd_notify_on_write(tcp->em_fd, &tcp->write_done_closure); +} + +static void tcp_drop_uncovered_then_handle_write(void* arg, + grpc_error_handle error) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p got_write: %s", arg, + grpc_error_std_string(error).c_str()); + } + drop_uncovered(static_cast(arg)); + tcp_handle_write(arg, error); +} + +static void add_to_estimate(grpc_tcp* tcp, size_t bytes) { + tcp->bytes_read_this_round += static_cast(bytes); +} + +static void finish_estimate(grpc_tcp* tcp) { + /* If we read >80% of the target buffer in one read loop, increase the size + of the target buffer to either the amount read, or twice its previous + value */ + if (tcp->bytes_read_this_round > tcp->target_length * 0.8) { + tcp->target_length = + std::max(2 * tcp->target_length, tcp->bytes_read_this_round); + } else { + tcp->target_length = + 0.99 * tcp->target_length + 0.01 * tcp->bytes_read_this_round; + } + tcp->bytes_read_this_round = 0; +} + +static grpc_error_handle tcp_annotate_error(grpc_error_handle src_error, + grpc_tcp* tcp) { + return grpc_error_set_str( + grpc_error_set_int( + grpc_error_set_int(src_error, GRPC_ERROR_INT_FD, tcp->fd), + /* All tcp errors are marked with UNAVAILABLE so that application may + * choose to retry. 
*/ + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE), + GRPC_ERROR_STR_TARGET_ADDRESS, tcp->peer_string); +} + +static void tcp_handle_read(void* arg /* grpc_tcp */, grpc_error_handle error); +static void tcp_handle_write(void* arg /* grpc_tcp */, grpc_error_handle error); + +static void tcp_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + grpc_tcp* tcp = reinterpret_cast(ep); + ZerocopyDisableAndWaitForRemaining(tcp); + grpc_fd_shutdown(tcp->em_fd, why); +} + +static void tcp_free(grpc_tcp* tcp) { + grpc_fd_orphan(tcp->em_fd, tcp->release_fd_cb, tcp->release_fd, + "tcp_unref_orphan"); + grpc_slice_buffer_destroy_internal(&tcp->last_read_buffer); + grpc_slice_allocator_destroy(tcp->slice_allocator); + /* The lock is not really necessary here, since all refs have been released */ + gpr_mu_lock(&tcp->tb_mu); + grpc_core::TracedBuffer::Shutdown( + &tcp->tb_head, tcp->outgoing_buffer_arg, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("endpoint destroyed")); + gpr_mu_unlock(&tcp->tb_mu); + tcp->outgoing_buffer_arg = nullptr; + gpr_mu_destroy(&tcp->tb_mu); + delete tcp; +} + +#ifndef NDEBUG +#define TCP_UNREF(tcp, reason) tcp_unref((tcp), (reason), DEBUG_LOCATION) +#define TCP_REF(tcp, reason) tcp_ref((tcp), (reason), DEBUG_LOCATION) +static void tcp_unref(grpc_tcp* tcp, const char* reason, + const grpc_core::DebugLocation& debug_location) { + if (GPR_UNLIKELY(tcp->refcount.Unref(debug_location, reason))) { + tcp_free(tcp); + } +} + +static void tcp_ref(grpc_tcp* tcp, const char* reason, + const grpc_core::DebugLocation& debug_location) { + tcp->refcount.Ref(debug_location, reason); +} +#else +#define TCP_UNREF(tcp, reason) tcp_unref((tcp)) +#define TCP_REF(tcp, reason) tcp_ref((tcp)) +static void tcp_unref(grpc_tcp* tcp) { + if (GPR_UNLIKELY(tcp->refcount.Unref())) { + tcp_free(tcp); + } +} + +static void tcp_ref(grpc_tcp* tcp) { tcp->refcount.Ref(); } +#endif + +static void tcp_destroy(grpc_endpoint* ep) { + grpc_tcp* tcp = reinterpret_cast(ep); + grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); + if (grpc_event_engine_can_track_errors()) { + ZerocopyDisableAndWaitForRemaining(tcp); + gpr_atm_no_barrier_store(&tcp->stop_error_notification, true); + grpc_fd_set_error(tcp->em_fd); + } + TCP_UNREF(tcp, "destroy"); +} + +static void call_read_cb(grpc_tcp* tcp, grpc_error_handle error) { + grpc_closure* cb = tcp->read_cb; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p call_cb %p %p:%p", tcp, cb, cb->cb, cb->cb_arg); + size_t i; + gpr_log(GPR_INFO, "READ %p (peer=%s) error=%s", tcp, + tcp->peer_string.c_str(), grpc_error_std_string(error).c_str()); + if (gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + for (i = 0; i < tcp->incoming_buffer->count; i++) { + char* dump = grpc_dump_slice(tcp->incoming_buffer->slices[i], + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "DATA: %s", dump); + gpr_free(dump); + } + } + } + + tcp->read_cb = nullptr; + tcp->incoming_buffer = nullptr; + grpc_core::Closure::Run(DEBUG_LOCATION, cb, error); +} + +#define MAX_READ_IOVEC 4 +static void tcp_do_read(grpc_tcp* tcp) { + GPR_TIMER_SCOPE("tcp_do_read", 0); + struct msghdr msg; + struct iovec iov[MAX_READ_IOVEC]; + ssize_t read_bytes; + size_t total_read_bytes = 0; + size_t iov_len = + std::min(MAX_READ_IOVEC, tcp->incoming_buffer->count); +#ifdef GRPC_LINUX_ERRQUEUE + constexpr size_t cmsg_alloc_space = + CMSG_SPACE(sizeof(grpc_core::scm_timestamping)) + CMSG_SPACE(sizeof(int)); +#else + constexpr size_t cmsg_alloc_space = 24 /* CMSG_SPACE(sizeof(int)) */; +#endif /* 
GRPC_LINUX_ERRQUEUE */ + char cmsgbuf[cmsg_alloc_space]; + for (size_t i = 0; i < iov_len; i++) { + iov[i].iov_base = GRPC_SLICE_START_PTR(tcp->incoming_buffer->slices[i]); + iov[i].iov_len = GRPC_SLICE_LENGTH(tcp->incoming_buffer->slices[i]); + } + + do { + /* Assume there is something on the queue. If we receive TCP_INQ from + * kernel, we will update this value, otherwise, we have to assume there is + * always something to read until we get EAGAIN. */ + tcp->inq = 1; + + msg.msg_name = nullptr; + msg.msg_namelen = 0; + msg.msg_iov = iov; + msg.msg_iovlen = static_cast(iov_len); + if (tcp->inq_capable) { + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + } else { + msg.msg_control = nullptr; + msg.msg_controllen = 0; + } + msg.msg_flags = 0; + + GRPC_STATS_INC_TCP_READ_OFFER(tcp->incoming_buffer->length); + GRPC_STATS_INC_TCP_READ_OFFER_IOV_SIZE(tcp->incoming_buffer->count); + + do { + GPR_TIMER_SCOPE("recvmsg", 0); + GRPC_STATS_INC_SYSCALL_READ(); + read_bytes = recvmsg(tcp->fd, &msg, 0); + } while (read_bytes < 0 && errno == EINTR); + + /* We have read something in previous reads. We need to deliver those + * bytes to the upper layer. */ + if (read_bytes <= 0 && total_read_bytes > 0) { + tcp->inq = 1; + break; + } + + if (read_bytes < 0) { + /* NB: After calling call_read_cb a parallel call of the read handler may + * be running. */ + if (errno == EAGAIN) { + finish_estimate(tcp); + tcp->inq = 0; + /* We've consumed the edge, request a new one */ + notify_on_read(tcp); + } else { + grpc_slice_buffer_reset_and_unref_internal(tcp->incoming_buffer); + call_read_cb(tcp, + tcp_annotate_error(GRPC_OS_ERROR(errno, "recvmsg"), tcp)); + TCP_UNREF(tcp, "read"); + } + return; + } + if (read_bytes == 0) { + /* 0 read size ==> end of stream + * + * We may have read something, i.e., total_read_bytes > 0, but + * since the connection is closed we will drop the data here, because we + * can't call the callback multiple times. */ + grpc_slice_buffer_reset_and_unref_internal(tcp->incoming_buffer); + call_read_cb( + tcp, tcp_annotate_error( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Socket closed"), tcp)); + TCP_UNREF(tcp, "read"); + return; + } + + GRPC_STATS_INC_TCP_READ_SIZE(read_bytes); + add_to_estimate(tcp, static_cast(read_bytes)); + GPR_DEBUG_ASSERT((size_t)read_bytes <= + tcp->incoming_buffer->length - total_read_bytes); + +#ifdef GRPC_HAVE_TCP_INQ + if (tcp->inq_capable) { + GPR_DEBUG_ASSERT(!(msg.msg_flags & MSG_CTRUNC)); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + for (; cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level == SOL_TCP && cmsg->cmsg_type == TCP_CM_INQ && + cmsg->cmsg_len == CMSG_LEN(sizeof(int))) { + tcp->inq = *reinterpret_cast(CMSG_DATA(cmsg)); + break; + } + } + } +#endif /* GRPC_HAVE_TCP_INQ */ + + total_read_bytes += read_bytes; + if (tcp->inq == 0 || total_read_bytes == tcp->incoming_buffer->length) { + /* We have filled incoming_buffer, and we cannot read any more. */ + break; + } + + /* We had a partial read, and still have space to read more data. + * So, adjust IOVs and try to read more. 
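The adjustment performed next is the standard "advance an iovec array past N consumed bytes" step: entries that were fully read are dropped, and the first partially read entry has its base and length shifted. The same logic as the loop below, extracted into an isolated helper (the helper name is hypothetical):

#include <cstddef>
#include <sys/uio.h>   // struct iovec

// Compacts iov[0..iov_len) so it describes only the bytes NOT yet consumed by
// a read/write of `consumed` bytes. Returns the new entry count.
size_t advance_iovec(struct iovec* iov, size_t iov_len, size_t consumed) {
  size_t j = 0;
  for (size_t i = 0; i < iov_len; ++i) {
    if (consumed >= iov[i].iov_len) {
      consumed -= iov[i].iov_len;   // entry fully consumed; drop it
      continue;
    }
    iov[j].iov_base = static_cast<char*>(iov[i].iov_base) + consumed;
    iov[j].iov_len = iov[i].iov_len - consumed;
    consumed = 0;                   // later entries are copied through as-is
    ++j;
  }
  return j;
}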
*/ + size_t remaining = read_bytes; + size_t j = 0; + for (size_t i = 0; i < iov_len; i++) { + if (remaining >= iov[i].iov_len) { + remaining -= iov[i].iov_len; + continue; + } + if (remaining > 0) { + iov[j].iov_base = static_cast(iov[i].iov_base) + remaining; + iov[j].iov_len = iov[i].iov_len - remaining; + remaining = 0; + } else { + iov[j].iov_base = iov[i].iov_base; + iov[j].iov_len = iov[i].iov_len; + } + ++j; + } + iov_len = j; + } while (true); + + if (tcp->inq == 0) { + finish_estimate(tcp); + } + + GPR_DEBUG_ASSERT(total_read_bytes > 0); + if (total_read_bytes < tcp->incoming_buffer->length) { + grpc_slice_buffer_trim_end(tcp->incoming_buffer, + tcp->incoming_buffer->length - total_read_bytes, + &tcp->last_read_buffer); + } + call_read_cb(tcp, GRPC_ERROR_NONE); + TCP_UNREF(tcp, "read"); +} + +static void tcp_read_allocation_done(void* tcpp, grpc_error_handle error) { + grpc_tcp* tcp = static_cast(tcpp); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p read_allocation_done: %s", tcp, + grpc_error_std_string(error).c_str()); + } + if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { + grpc_slice_buffer_reset_and_unref_internal(tcp->incoming_buffer); + grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); + call_read_cb(tcp, GRPC_ERROR_REF(error)); + TCP_UNREF(tcp, "read"); + } else { + tcp_do_read(tcp); + } +} + +static void tcp_continue_read(grpc_tcp* tcp) { + /* Wait for allocation only when there is no buffer left. */ + if (tcp->incoming_buffer->length == 0 && + tcp->incoming_buffer->count < MAX_READ_IOVEC) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p alloc_slices", tcp); + } + if (GPR_UNLIKELY(!grpc_slice_allocator_allocate( + tcp->slice_allocator, tcp->target_length, 1, + grpc_slice_allocator_intent::kReadBuffer, tcp->incoming_buffer, + tcp_read_allocation_done, tcp))) { + // Wait for allocation. + return; + } + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p do_read", tcp); + } + tcp_do_read(tcp); +} + +static void tcp_handle_read(void* arg /* grpc_tcp */, grpc_error_handle error) { + grpc_tcp* tcp = static_cast(arg); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p got_read: %s", tcp, + grpc_error_std_string(error).c_str()); + } + + if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { + grpc_slice_buffer_reset_and_unref_internal(tcp->incoming_buffer); + grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); + call_read_cb(tcp, GRPC_ERROR_REF(error)); + TCP_UNREF(tcp, "read"); + } else { + tcp_continue_read(tcp); + } +} + +static void tcp_read(grpc_endpoint* ep, grpc_slice_buffer* incoming_buffer, + grpc_closure* cb, bool urgent) { + grpc_tcp* tcp = reinterpret_cast(ep); + GPR_ASSERT(tcp->read_cb == nullptr); + tcp->read_cb = cb; + tcp->incoming_buffer = incoming_buffer; + grpc_slice_buffer_reset_and_unref_internal(incoming_buffer); + grpc_slice_buffer_swap(incoming_buffer, &tcp->last_read_buffer); + TCP_REF(tcp, "read"); + if (tcp->is_first_read) { + /* Endpoint read called for the very first time. Register read callback with + * the polling engine */ + tcp->is_first_read = false; + notify_on_read(tcp); + } else if (!urgent && tcp->inq == 0) { + /* Upper layer asked to read more but we know there is no pending data + * to read from previous reads. So, wait for POLLIN. + */ + notify_on_read(tcp); + } else { + /* Not the first time. We may or may not have more bytes available. 
In any + * case call tcp->read_done_closure (i.e tcp_handle_read()) which does the + * right thing (i.e calls tcp_do_read() which either reads the available + * bytes or calls notify_on_read() to be notified when new bytes become + * available */ + grpc_core::Closure::Run(DEBUG_LOCATION, &tcp->read_done_closure, + GRPC_ERROR_NONE); + } +} + +/* A wrapper around sendmsg. It sends \a msg over \a fd and returns the number + * of bytes sent. */ +ssize_t tcp_send(int fd, const struct msghdr* msg, int additional_flags = 0) { + GPR_TIMER_SCOPE("sendmsg", 1); + ssize_t sent_length; + do { + /* TODO(klempner): Cork if this is a partial write */ + GRPC_STATS_INC_SYSCALL_WRITE(); + sent_length = sendmsg(fd, msg, SENDMSG_FLAGS | additional_flags); + } while (sent_length < 0 && errno == EINTR); + return sent_length; +} + +/** This is to be called if outgoing_buffer_arg is not null. On linux platforms, + * this will call sendmsg with socket options set to collect timestamps inside + * the kernel. On return, sent_length is set to the return value of the sendmsg + * call. Returns false if setting the socket options failed. This is not + * implemented for non-linux platforms currently, and crashes out. + */ +static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, + size_t sending_length, + ssize_t* sent_length, + int additional_flags = 0); + +/** The callback function to be invoked when we get an error on the socket. */ +static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error_handle error); + +static TcpZerocopySendRecord* tcp_get_send_zerocopy_record( + grpc_tcp* tcp, grpc_slice_buffer* buf); + +#ifdef GRPC_LINUX_ERRQUEUE +static bool process_errors(grpc_tcp* tcp); + +static TcpZerocopySendRecord* tcp_get_send_zerocopy_record( + grpc_tcp* tcp, grpc_slice_buffer* buf) { + TcpZerocopySendRecord* zerocopy_send_record = nullptr; + const bool use_zerocopy = + tcp->tcp_zerocopy_send_ctx.enabled() && + tcp->tcp_zerocopy_send_ctx.threshold_bytes() < buf->length; + if (use_zerocopy) { + zerocopy_send_record = tcp->tcp_zerocopy_send_ctx.GetSendRecord(); + if (zerocopy_send_record == nullptr) { + process_errors(tcp); + zerocopy_send_record = tcp->tcp_zerocopy_send_ctx.GetSendRecord(); + } + if (zerocopy_send_record != nullptr) { + zerocopy_send_record->PrepareForSends(buf); + GPR_DEBUG_ASSERT(buf->count == 0); + GPR_DEBUG_ASSERT(buf->length == 0); + tcp->outgoing_byte_idx = 0; + tcp->outgoing_buffer = nullptr; + } + } + return zerocopy_send_record; +} + +static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* tcp) { + tcp->tcp_zerocopy_send_ctx.Shutdown(); + while (!tcp->tcp_zerocopy_send_ctx.AllSendRecordsEmpty()) { + process_errors(tcp); + } +} + +static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, + size_t sending_length, + ssize_t* sent_length, + int additional_flags) { + if (!tcp->socket_ts_enabled) { + uint32_t opt = grpc_core::kTimestampingSocketOptions; + if (setsockopt(tcp->fd, SOL_SOCKET, SO_TIMESTAMPING, + static_cast(&opt), sizeof(opt)) != 0) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_ERROR, "Failed to set timestamping options on the socket."); + } + return false; + } + tcp->bytes_counter = -1; + tcp->socket_ts_enabled = true; + } + /* Set control message to indicate that you want timestamps. 
*/ + union { + char cmsg_buf[CMSG_SPACE(sizeof(uint32_t))]; + struct cmsghdr align; + } u; + cmsghdr* cmsg = reinterpret_cast(u.cmsg_buf); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SO_TIMESTAMPING; + cmsg->cmsg_len = CMSG_LEN(sizeof(uint32_t)); + *reinterpret_cast(CMSG_DATA(cmsg)) = + grpc_core::kTimestampingRecordingOptions; + msg->msg_control = u.cmsg_buf; + msg->msg_controllen = CMSG_SPACE(sizeof(uint32_t)); + + /* If there was an error on sendmsg the logic in tcp_flush will handle it. */ + ssize_t length = tcp_send(tcp->fd, msg, additional_flags); + *sent_length = length; + /* Only save timestamps if all the bytes were taken by sendmsg. */ + if (sending_length == static_cast(length)) { + gpr_mu_lock(&tcp->tb_mu); + grpc_core::TracedBuffer::AddNewEntry( + &tcp->tb_head, static_cast(tcp->bytes_counter + length), + tcp->fd, tcp->outgoing_buffer_arg); + gpr_mu_unlock(&tcp->tb_mu); + tcp->outgoing_buffer_arg = nullptr; + } + return true; +} + +static void UnrefMaybePutZerocopySendRecord(grpc_tcp* tcp, + TcpZerocopySendRecord* record, + uint32_t seq, const char* tag); +// Reads \a cmsg to process zerocopy control messages. +static void process_zerocopy(grpc_tcp* tcp, struct cmsghdr* cmsg) { + GPR_DEBUG_ASSERT(cmsg); + auto serr = reinterpret_cast(CMSG_DATA(cmsg)); + GPR_DEBUG_ASSERT(serr->ee_errno == 0); + GPR_DEBUG_ASSERT(serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY); + const uint32_t lo = serr->ee_info; + const uint32_t hi = serr->ee_data; + for (uint32_t seq = lo; seq <= hi; ++seq) { + // TODO(arjunroy): It's likely that lo and hi refer to zerocopy sequence + // numbers that are generated by a single call to grpc_endpoint_write; ie. + // we can batch the unref operation. So, check if record is the same for + // both; if so, batch the unref/put. + TcpZerocopySendRecord* record = + tcp->tcp_zerocopy_send_ctx.ReleaseSendRecord(seq); + GPR_DEBUG_ASSERT(record); + UnrefMaybePutZerocopySendRecord(tcp, record, seq, "CALLBACK RCVD"); + } +} + +// Whether the cmsg received from error queue is of the IPv4 or IPv6 levels. +static bool CmsgIsIpLevel(const cmsghdr& cmsg) { + return (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR) || + (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR); +} + +static bool CmsgIsZeroCopy(const cmsghdr& cmsg) { + if (!CmsgIsIpLevel(cmsg)) { + return false; + } + auto serr = reinterpret_cast CMSG_DATA(&cmsg); + return serr->ee_errno == 0 && serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY; +} + +/** Reads \a cmsg to derive timestamps from the control messages. If a valid + * timestamp is found, the traced buffer list is updated with this timestamp. + * The caller of this function should be looping on the control messages found + * in \a msg. \a cmsg should point to the control message that the caller wants + * processed. + * On return, a pointer to a control message is returned. On the next iteration, + * CMSG_NXTHDR(msg, ret_val) should be passed as \a cmsg. 
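process_timestamp() below relies on the Linux SO_TIMESTAMPING machinery: once timestamp reporting is enabled on the socket, each selected TX event comes back on the error queue as an SCM_TIMESTAMPING control message (a struct scm_timestamping), paired with an IP_RECVERR/IPV6_RECVERR extended error identifying the event. A sketch of the enabling side; the exact flag set gRPC uses (grpc_core::kTimestampingSocketOptions) is defined elsewhere, so the flags chosen here are only illustrative:

#include <cstdint>
#include <sys/socket.h>
#include <linux/net_tstamp.h>   // SOF_TIMESTAMPING_* flags

// Ask the kernel for software TX timestamps (scheduled / sent / acked),
// tagged with a per-byte id so they can later be matched to writes.
bool enable_tx_timestamps(int fd) {
  uint32_t flags = SOF_TIMESTAMPING_SOFTWARE | SOF_TIMESTAMPING_TX_SCHED |
                   SOF_TIMESTAMPING_TX_SOFTWARE | SOF_TIMESTAMPING_TX_ACK |
                   SOF_TIMESTAMPING_OPT_ID | SOF_TIMESTAMPING_OPT_TSONLY;
  return setsockopt(fd, SOL_SOCKET, SO_TIMESTAMPING, &flags, sizeof(flags)) == 0;
}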
*/ +struct cmsghdr* process_timestamp(grpc_tcp* tcp, msghdr* msg, + struct cmsghdr* cmsg) { + auto next_cmsg = CMSG_NXTHDR(msg, cmsg); + cmsghdr* opt_stats = nullptr; + if (next_cmsg == nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_ERROR, "Received timestamp without extended error"); + } + return cmsg; + } + + /* Check if next_cmsg is an OPT_STATS msg */ + if (next_cmsg->cmsg_level == SOL_SOCKET && + next_cmsg->cmsg_type == SCM_TIMESTAMPING_OPT_STATS) { + opt_stats = next_cmsg; + next_cmsg = CMSG_NXTHDR(msg, opt_stats); + if (next_cmsg == nullptr) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_ERROR, "Received timestamp without extended error"); + } + return opt_stats; + } + } + + if (!(next_cmsg->cmsg_level == SOL_IP || next_cmsg->cmsg_level == SOL_IPV6) || + !(next_cmsg->cmsg_type == IP_RECVERR || + next_cmsg->cmsg_type == IPV6_RECVERR)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_ERROR, "Unexpected control message"); + } + return cmsg; + } + + auto tss = + reinterpret_cast(CMSG_DATA(cmsg)); + auto serr = reinterpret_cast(CMSG_DATA(next_cmsg)); + if (serr->ee_errno != ENOMSG || + serr->ee_origin != SO_EE_ORIGIN_TIMESTAMPING) { + gpr_log(GPR_ERROR, "Unexpected control message"); + return cmsg; + } + /* The error handling can potentially be done on another thread so we need + * to protect the traced buffer list. A lock free list might be better. Using + * a simple mutex for now. */ + gpr_mu_lock(&tcp->tb_mu); + grpc_core::TracedBuffer::ProcessTimestamp(&tcp->tb_head, serr, opt_stats, + tss); + gpr_mu_unlock(&tcp->tb_mu); + return next_cmsg; +} + +/** For linux platforms, reads the socket's error queue and processes error + * messages from the queue. + */ +static bool process_errors(grpc_tcp* tcp) { + bool processed_err = false; + struct iovec iov; + iov.iov_base = nullptr; + iov.iov_len = 0; + struct msghdr msg; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + msg.msg_iov = &iov; + msg.msg_iovlen = 0; + msg.msg_flags = 0; + /* Allocate enough space so we don't need to keep increasing this as size + * of OPT_STATS increase */ + constexpr size_t cmsg_alloc_space = + CMSG_SPACE(sizeof(grpc_core::scm_timestamping)) + + CMSG_SPACE(sizeof(sock_extended_err) + sizeof(sockaddr_in)) + + CMSG_SPACE(32 * NLA_ALIGN(NLA_HDRLEN + sizeof(uint64_t))); + /* Allocate aligned space for cmsgs received along with timestamps */ + union { + char rbuf[cmsg_alloc_space]; + struct cmsghdr align; + } aligned_buf; + msg.msg_control = aligned_buf.rbuf; + int r, saved_errno; + while (true) { + msg.msg_controllen = sizeof(aligned_buf.rbuf); + do { + r = recvmsg(tcp->fd, &msg, MSG_ERRQUEUE); + saved_errno = errno; + } while (r < 0 && saved_errno == EINTR); + + if (r == -1 && saved_errno == EAGAIN) { + return processed_err; /* No more errors to process */ + } + if (r == -1) { + return processed_err; + } + if (GPR_UNLIKELY((msg.msg_flags & MSG_CTRUNC) != 0)) { + gpr_log(GPR_ERROR, "Error message was truncated."); + } + + if (msg.msg_controllen == 0) { + /* There was no control message found. It was probably spurious. 
*/ + return processed_err; + } + bool seen = false; + for (auto cmsg = CMSG_FIRSTHDR(&msg); cmsg && cmsg->cmsg_len; + cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (CmsgIsZeroCopy(*cmsg)) { + process_zerocopy(tcp, cmsg); + seen = true; + processed_err = true; + } else if (cmsg->cmsg_level == SOL_SOCKET && + cmsg->cmsg_type == SCM_TIMESTAMPING) { + cmsg = process_timestamp(tcp, &msg, cmsg); + seen = true; + processed_err = true; + } else { + /* Got a control message that is not a timestamp or zerocopy. Don't know + * how to handle this. */ + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, + "unknown control message cmsg_level:%d cmsg_type:%d", + cmsg->cmsg_level, cmsg->cmsg_type); + } + return processed_err; + } + } + if (!seen) { + return processed_err; + } + } +} + +static void tcp_handle_error(void* arg /* grpc_tcp */, + grpc_error_handle error) { + grpc_tcp* tcp = static_cast(arg); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "TCP:%p got_error: %s", tcp, + grpc_error_std_string(error).c_str()); + } + + if (error != GRPC_ERROR_NONE || + static_cast(gpr_atm_acq_load(&tcp->stop_error_notification))) { + /* We aren't going to register to hear on error anymore, so it is safe to + * unref. */ + TCP_UNREF(tcp, "error-tracking"); + return; + } + + /* We are still interested in collecting timestamps, so let's try reading + * them. */ + bool processed = process_errors(tcp); + /* This might not a timestamps error. Set the read and write closures to be + * ready. */ + if (!processed) { + grpc_fd_set_readable(tcp->em_fd); + grpc_fd_set_writable(tcp->em_fd); + } + grpc_fd_notify_on_error(tcp->em_fd, &tcp->error_closure); +} + +#else /* GRPC_LINUX_ERRQUEUE */ +static TcpZerocopySendRecord* tcp_get_send_zerocopy_record( + grpc_tcp* /*tcp*/, grpc_slice_buffer* /*buf*/) { + return nullptr; +} + +static void ZerocopyDisableAndWaitForRemaining(grpc_tcp* /*tcp*/) {} + +static bool tcp_write_with_timestamps(grpc_tcp* /*tcp*/, struct msghdr* /*msg*/, + size_t /*sending_length*/, + ssize_t* /*sent_length*/, + int /*additional_flags*/) { + gpr_log(GPR_ERROR, "Write with timestamps not supported for this platform"); + GPR_ASSERT(0); + return false; +} + +static void tcp_handle_error(void* /*arg*/ /* grpc_tcp */, + grpc_error_handle /*error*/) { + gpr_log(GPR_ERROR, "Error handling is not supported for this platform"); + GPR_ASSERT(0); +} +#endif /* GRPC_LINUX_ERRQUEUE */ + +/* If outgoing_buffer_arg is filled, shuts down the list early, so that any + * release operations needed can be performed on the arg */ +void tcp_shutdown_buffer_list(grpc_tcp* tcp) { + if (tcp->outgoing_buffer_arg) { + gpr_mu_lock(&tcp->tb_mu); + grpc_core::TracedBuffer::Shutdown( + &tcp->tb_head, tcp->outgoing_buffer_arg, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("TracedBuffer list shutdown")); + gpr_mu_unlock(&tcp->tb_mu); + tcp->outgoing_buffer_arg = nullptr; + } +} + +#if defined(IOV_MAX) && IOV_MAX < 260 +#define MAX_WRITE_IOVEC IOV_MAX +#else +#define MAX_WRITE_IOVEC 260 +#endif +msg_iovlen_type TcpZerocopySendRecord::PopulateIovs(size_t* unwind_slice_idx, + size_t* unwind_byte_idx, + size_t* sending_length, + iovec* iov) { + msg_iovlen_type iov_size; + *unwind_slice_idx = out_offset_.slice_idx; + *unwind_byte_idx = out_offset_.byte_idx; + for (iov_size = 0; + out_offset_.slice_idx != buf_.count && iov_size != MAX_WRITE_IOVEC; + iov_size++) { + iov[iov_size].iov_base = + GRPC_SLICE_START_PTR(buf_.slices[out_offset_.slice_idx]) + + out_offset_.byte_idx; + iov[iov_size].iov_len = + 
GRPC_SLICE_LENGTH(buf_.slices[out_offset_.slice_idx]) - + out_offset_.byte_idx; + *sending_length += iov[iov_size].iov_len; + ++(out_offset_.slice_idx); + out_offset_.byte_idx = 0; + } + GPR_DEBUG_ASSERT(iov_size > 0); + return iov_size; +} + +void TcpZerocopySendRecord::UpdateOffsetForBytesSent(size_t sending_length, + size_t actually_sent) { + size_t trailing = sending_length - actually_sent; + while (trailing > 0) { + size_t slice_length; + out_offset_.slice_idx--; + slice_length = GRPC_SLICE_LENGTH(buf_.slices[out_offset_.slice_idx]); + if (slice_length > trailing) { + out_offset_.byte_idx = slice_length - trailing; + break; + } else { + trailing -= slice_length; + } + } +} + +// returns true if done, false if pending; if returning true, *error is set +static bool do_tcp_flush_zerocopy(grpc_tcp* tcp, TcpZerocopySendRecord* record, + grpc_error_handle* error) { + msg_iovlen_type iov_size; + ssize_t sent_length = 0; + size_t sending_length; + size_t unwind_slice_idx; + size_t unwind_byte_idx; + bool tried_sending_message; + msghdr msg; + // iov consumes a large space. Keep it as the last item on the stack to + // improve locality. After all, we expect only the first elements of it being + // populated in most cases. + iovec iov[MAX_WRITE_IOVEC]; + while (true) { + sending_length = 0; + iov_size = record->PopulateIovs(&unwind_slice_idx, &unwind_byte_idx, + &sending_length, iov); + msg.msg_name = nullptr; + msg.msg_namelen = 0; + msg.msg_iov = iov; + msg.msg_iovlen = iov_size; + msg.msg_flags = 0; + tried_sending_message = false; + // Before calling sendmsg (with or without timestamps): we + // take a single ref on the zerocopy send record. + tcp->tcp_zerocopy_send_ctx.NoteSend(record); + if (tcp->outgoing_buffer_arg != nullptr) { + if (!tcp->ts_capable || + !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length, + MSG_ZEROCOPY)) { + /* We could not set socket options to collect Fathom timestamps. + * Fallback on writing without timestamps. */ + tcp->ts_capable = false; + tcp_shutdown_buffer_list(tcp); + } else { + tried_sending_message = true; + } + } + if (!tried_sending_message) { + msg.msg_control = nullptr; + msg.msg_controllen = 0; + GRPC_STATS_INC_TCP_WRITE_SIZE(sending_length); + GRPC_STATS_INC_TCP_WRITE_IOV_SIZE(iov_size); + sent_length = tcp_send(tcp->fd, &msg, MSG_ZEROCOPY); + } + if (sent_length < 0) { + // If this particular send failed, drop ref taken earlier in this method. 
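+      // EAGAIN simply unwinds this record so the unsent slices are retried
+      // once the fd becomes writable again; any other errno is surfaced as a
+      // write error and shuts down the traced buffer list.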
+ tcp->tcp_zerocopy_send_ctx.UndoSend(); + if (errno == EAGAIN) { + record->UnwindIfThrottled(unwind_slice_idx, unwind_byte_idx); + return false; + } else if (errno == EPIPE) { + *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); + tcp_shutdown_buffer_list(tcp); + return true; + } else { + *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); + tcp_shutdown_buffer_list(tcp); + return true; + } + } + tcp->bytes_counter += sent_length; + record->UpdateOffsetForBytesSent(sending_length, + static_cast(sent_length)); + if (record->AllSlicesSent()) { + *error = GRPC_ERROR_NONE; + return true; + } + } +} + +static void UnrefMaybePutZerocopySendRecord(grpc_tcp* tcp, + TcpZerocopySendRecord* record, + uint32_t /*seq*/, + const char* /*tag*/) { + if (record->Unref()) { + tcp->tcp_zerocopy_send_ctx.PutSendRecord(record); + } +} + +static bool tcp_flush_zerocopy(grpc_tcp* tcp, TcpZerocopySendRecord* record, + grpc_error_handle* error) { + bool done = do_tcp_flush_zerocopy(tcp, record, error); + if (done) { + // Either we encountered an error, or we successfully sent all the bytes. + // In either case, we're done with this record. + UnrefMaybePutZerocopySendRecord(tcp, record, 0, "flush_done"); + } + return done; +} + +static bool tcp_flush(grpc_tcp* tcp, grpc_error_handle* error) { + struct msghdr msg; + struct iovec iov[MAX_WRITE_IOVEC]; + msg_iovlen_type iov_size; + ssize_t sent_length = 0; + size_t sending_length; + size_t trailing; + size_t unwind_slice_idx; + size_t unwind_byte_idx; + + // We always start at zero, because we eagerly unref and trim the slice + // buffer as we write + size_t outgoing_slice_idx = 0; + + while (true) { + sending_length = 0; + unwind_slice_idx = outgoing_slice_idx; + unwind_byte_idx = tcp->outgoing_byte_idx; + for (iov_size = 0; outgoing_slice_idx != tcp->outgoing_buffer->count && + iov_size != MAX_WRITE_IOVEC; + iov_size++) { + iov[iov_size].iov_base = + GRPC_SLICE_START_PTR( + tcp->outgoing_buffer->slices[outgoing_slice_idx]) + + tcp->outgoing_byte_idx; + iov[iov_size].iov_len = + GRPC_SLICE_LENGTH(tcp->outgoing_buffer->slices[outgoing_slice_idx]) - + tcp->outgoing_byte_idx; + sending_length += iov[iov_size].iov_len; + outgoing_slice_idx++; + tcp->outgoing_byte_idx = 0; + } + GPR_ASSERT(iov_size > 0); + + msg.msg_name = nullptr; + msg.msg_namelen = 0; + msg.msg_iov = iov; + msg.msg_iovlen = iov_size; + msg.msg_flags = 0; + bool tried_sending_message = false; + if (tcp->outgoing_buffer_arg != nullptr) { + if (!tcp->ts_capable || + !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length)) { + /* We could not set socket options to collect Fathom timestamps. + * Fallback on writing without timestamps. 
*/ + tcp->ts_capable = false; + tcp_shutdown_buffer_list(tcp); + } else { + tried_sending_message = true; + } + } + if (!tried_sending_message) { + msg.msg_control = nullptr; + msg.msg_controllen = 0; + + GRPC_STATS_INC_TCP_WRITE_SIZE(sending_length); + GRPC_STATS_INC_TCP_WRITE_IOV_SIZE(iov_size); + + sent_length = tcp_send(tcp->fd, &msg); + } + + if (sent_length < 0) { + if (errno == EAGAIN) { + tcp->outgoing_byte_idx = unwind_byte_idx; + // unref all and forget about all slices that have been written to this + // point + for (size_t idx = 0; idx < unwind_slice_idx; ++idx) { + grpc_slice_buffer_remove_first(tcp->outgoing_buffer); + } + return false; + } else if (errno == EPIPE) { + *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); + grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); + tcp_shutdown_buffer_list(tcp); + return true; + } else { + *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); + grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); + tcp_shutdown_buffer_list(tcp); + return true; + } + } + + GPR_ASSERT(tcp->outgoing_byte_idx == 0); + tcp->bytes_counter += sent_length; + trailing = sending_length - static_cast(sent_length); + while (trailing > 0) { + size_t slice_length; + + outgoing_slice_idx--; + slice_length = + GRPC_SLICE_LENGTH(tcp->outgoing_buffer->slices[outgoing_slice_idx]); + if (slice_length > trailing) { + tcp->outgoing_byte_idx = slice_length - trailing; + break; + } else { + trailing -= slice_length; + } + } + if (outgoing_slice_idx == tcp->outgoing_buffer->count) { + *error = GRPC_ERROR_NONE; + grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); + return true; + } + } +} + +static void tcp_handle_write(void* arg /* grpc_tcp */, + grpc_error_handle error) { + grpc_tcp* tcp = static_cast(arg); + grpc_closure* cb; + + if (error != GRPC_ERROR_NONE) { + cb = tcp->write_cb; + tcp->write_cb = nullptr; + if (tcp->current_zerocopy_send != nullptr) { + UnrefMaybePutZerocopySendRecord(tcp, tcp->current_zerocopy_send, 0, + "handle_write_err"); + tcp->current_zerocopy_send = nullptr; + } + grpc_core::Closure::Run(DEBUG_LOCATION, cb, GRPC_ERROR_REF(error)); + TCP_UNREF(tcp, "write"); + return; + } + + bool flush_result = + tcp->current_zerocopy_send != nullptr + ? tcp_flush_zerocopy(tcp, tcp->current_zerocopy_send, &error) + : tcp_flush(tcp, &error); + if (!flush_result) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "write: delayed"); + } + notify_on_write(tcp); + // tcp_flush does not populate error if it has returned false. + GPR_DEBUG_ASSERT(error == GRPC_ERROR_NONE); + } else { + cb = tcp->write_cb; + tcp->write_cb = nullptr; + tcp->current_zerocopy_send = nullptr; + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "write: %s", grpc_error_std_string(error).c_str()); + } + // No need to take a ref on error since tcp_flush provides a ref. 
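+    // Hand the flush result to the pending write callback and drop the ref
+    // that was taken when this write was originally delayed.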
+ grpc_core::Closure::Run(DEBUG_LOCATION, cb, error); + TCP_UNREF(tcp, "write"); + } +} + +static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf, + grpc_closure* cb, void* arg) { + GPR_TIMER_SCOPE("tcp_write", 0); + grpc_tcp* tcp = reinterpret_cast(ep); + grpc_error_handle error = GRPC_ERROR_NONE; + TcpZerocopySendRecord* zerocopy_send_record = nullptr; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + size_t i; + + for (i = 0; i < buf->count; i++) { + gpr_log(GPR_INFO, "WRITE %p (peer=%s)", tcp, tcp->peer_string.c_str()); + if (gpr_should_log(GPR_LOG_SEVERITY_DEBUG)) { + char* data = + grpc_dump_slice(buf->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "DATA: %s", data); + gpr_free(data); + } + } + } + + GPR_ASSERT(tcp->write_cb == nullptr); + GPR_DEBUG_ASSERT(tcp->current_zerocopy_send == nullptr); + + if (buf->length == 0) { + grpc_core::Closure::Run( + DEBUG_LOCATION, cb, + grpc_fd_is_shutdown(tcp->em_fd) + ? tcp_annotate_error(GRPC_ERROR_CREATE_FROM_STATIC_STRING("EOF"), + tcp) + : GRPC_ERROR_NONE); + tcp_shutdown_buffer_list(tcp); + return; + } + + zerocopy_send_record = tcp_get_send_zerocopy_record(tcp, buf); + if (zerocopy_send_record == nullptr) { + // Either not enough bytes, or couldn't allocate a zerocopy context. + tcp->outgoing_buffer = buf; + tcp->outgoing_byte_idx = 0; + } + tcp->outgoing_buffer_arg = arg; + if (arg) { + GPR_ASSERT(grpc_event_engine_can_track_errors()); + } + + bool flush_result = + zerocopy_send_record != nullptr + ? tcp_flush_zerocopy(tcp, zerocopy_send_record, &error) + : tcp_flush(tcp, &error); + if (!flush_result) { + TCP_REF(tcp, "write"); + tcp->write_cb = cb; + tcp->current_zerocopy_send = zerocopy_send_record; + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "write: delayed"); + } + notify_on_write(tcp); + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "write: %s", grpc_error_std_string(error).c_str()); + } + grpc_core::Closure::Run(DEBUG_LOCATION, cb, error); + } +} + +static void tcp_add_to_pollset(grpc_endpoint* ep, grpc_pollset* pollset) { + grpc_tcp* tcp = reinterpret_cast(ep); + grpc_pollset_add_fd(pollset, tcp->em_fd); +} + +static void tcp_add_to_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset_set) { + grpc_tcp* tcp = reinterpret_cast(ep); + grpc_pollset_set_add_fd(pollset_set, tcp->em_fd); +} + +static void tcp_delete_from_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset_set) { + grpc_tcp* tcp = reinterpret_cast(ep); + grpc_pollset_set_del_fd(pollset_set, tcp->em_fd); +} + +static absl::string_view tcp_get_peer(grpc_endpoint* ep) { + grpc_tcp* tcp = reinterpret_cast(ep); + return tcp->peer_string; +} + +static absl::string_view tcp_get_local_address(grpc_endpoint* ep) { + grpc_tcp* tcp = reinterpret_cast(ep); + return tcp->local_address; +} + +static int tcp_get_fd(grpc_endpoint* ep) { + grpc_tcp* tcp = reinterpret_cast(ep); + return tcp->fd; +} + +static bool tcp_can_track_err(grpc_endpoint* ep) { + grpc_tcp* tcp = reinterpret_cast(ep); + if (!grpc_event_engine_can_track_errors()) { + return false; + } + struct sockaddr addr; + socklen_t len = sizeof(addr); + if (getsockname(tcp->fd, &addr, &len) < 0) { + return false; + } + return addr.sa_family == AF_INET || addr.sa_family == AF_INET6; +} + +static const grpc_endpoint_vtable vtable = {tcp_read, + tcp_write, + tcp_add_to_pollset, + tcp_add_to_pollset_set, + tcp_delete_from_pollset_set, + tcp_shutdown, + tcp_destroy, + tcp_get_peer, + tcp_get_local_address, + tcp_get_fd, + 
tcp_can_track_err}; + +#define MAX_CHUNK_SIZE (32 * 1024 * 1024) + +grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, + const grpc_channel_args* channel_args, + const char* peer_string, + grpc_slice_allocator* slice_allocator) { + static constexpr bool kZerocpTxEnabledDefault = false; + int tcp_read_chunk_size = GRPC_TCP_DEFAULT_READ_SLICE_SIZE; + int tcp_max_read_chunk_size = 4 * 1024 * 1024; + int tcp_min_read_chunk_size = 256; + bool tcp_tx_zerocopy_enabled = kZerocpTxEnabledDefault; + int tcp_tx_zerocopy_send_bytes_thresh = + grpc_core::TcpZerocopySendCtx::kDefaultSendBytesThreshold; + int tcp_tx_zerocopy_max_simult_sends = + grpc_core::TcpZerocopySendCtx::kDefaultMaxSends; + if (channel_args != nullptr) { + for (size_t i = 0; i < channel_args->num_args; i++) { + if (0 == + strcmp(channel_args->args[i].key, GRPC_ARG_TCP_READ_CHUNK_SIZE)) { + grpc_integer_options options = {tcp_read_chunk_size, 1, MAX_CHUNK_SIZE}; + tcp_read_chunk_size = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_MIN_READ_CHUNK_SIZE)) { + grpc_integer_options options = {tcp_read_chunk_size, 1, MAX_CHUNK_SIZE}; + tcp_min_read_chunk_size = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE)) { + grpc_integer_options options = {tcp_read_chunk_size, 1, MAX_CHUNK_SIZE}; + tcp_max_read_chunk_size = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_TX_ZEROCOPY_ENABLED)) { + tcp_tx_zerocopy_enabled = grpc_channel_arg_get_bool( + &channel_args->args[i], kZerocpTxEnabledDefault); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_TX_ZEROCOPY_SEND_BYTES_THRESHOLD)) { + grpc_integer_options options = { + grpc_core::TcpZerocopySendCtx::kDefaultSendBytesThreshold, 0, + INT_MAX}; + tcp_tx_zerocopy_send_bytes_thresh = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + } else if (0 == strcmp(channel_args->args[i].key, + GRPC_ARG_TCP_TX_ZEROCOPY_MAX_SIMULT_SENDS)) { + grpc_integer_options options = { + grpc_core::TcpZerocopySendCtx::kDefaultMaxSends, 0, INT_MAX}; + tcp_tx_zerocopy_max_simult_sends = + grpc_channel_arg_get_integer(&channel_args->args[i], options); + } + } + } + + if (tcp_min_read_chunk_size > tcp_max_read_chunk_size) { + tcp_min_read_chunk_size = tcp_max_read_chunk_size; + } + tcp_read_chunk_size = grpc_core::Clamp( + tcp_read_chunk_size, tcp_min_read_chunk_size, tcp_max_read_chunk_size); + + grpc_tcp* tcp = new grpc_tcp(tcp_tx_zerocopy_max_simult_sends, + tcp_tx_zerocopy_send_bytes_thresh); + tcp->base.vtable = &vtable; + tcp->peer_string = peer_string; + tcp->fd = grpc_fd_wrapped_fd(em_fd); + tcp->slice_allocator = slice_allocator; + grpc_resolved_address resolved_local_addr; + memset(&resolved_local_addr, 0, sizeof(resolved_local_addr)); + resolved_local_addr.len = sizeof(resolved_local_addr.addr); + if (getsockname(tcp->fd, + reinterpret_cast(resolved_local_addr.addr), + &resolved_local_addr.len) < 0) { + tcp->local_address = ""; + } else { + tcp->local_address = grpc_sockaddr_to_uri(&resolved_local_addr); + } + tcp->read_cb = nullptr; + tcp->write_cb = nullptr; + tcp->current_zerocopy_send = nullptr; + tcp->release_fd_cb = nullptr; + tcp->release_fd = nullptr; + tcp->incoming_buffer = nullptr; + tcp->target_length = static_cast(tcp_read_chunk_size); + tcp->min_read_chunk_size = tcp_min_read_chunk_size; + 
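+  // min/max here come from GRPC_ARG_TCP_MIN_READ_CHUNK_SIZE and
+  // GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE, already clamped against each other
+  // above.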
tcp->max_read_chunk_size = tcp_max_read_chunk_size; + tcp->bytes_read_this_round = 0; + /* Will be set to false by the very first endpoint read function */ + tcp->is_first_read = true; + tcp->bytes_counter = -1; + tcp->socket_ts_enabled = false; + tcp->ts_capable = true; + tcp->outgoing_buffer_arg = nullptr; + if (tcp_tx_zerocopy_enabled && !tcp->tcp_zerocopy_send_ctx.memory_limited()) { +#ifdef GRPC_LINUX_ERRQUEUE + const int enable = 1; + auto err = + setsockopt(tcp->fd, SOL_SOCKET, SO_ZEROCOPY, &enable, sizeof(enable)); + if (err == 0) { + tcp->tcp_zerocopy_send_ctx.set_enabled(true); + } else { + gpr_log(GPR_ERROR, "Failed to set zerocopy options on the socket."); + } +#endif + } + /* paired with unref in grpc_tcp_destroy */ + new (&tcp->refcount) grpc_core::RefCount( + 1, GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace) ? "tcp" : nullptr); + gpr_atm_no_barrier_store(&tcp->shutdown_count, 0); + tcp->em_fd = em_fd; + grpc_slice_buffer_init(&tcp->last_read_buffer); + gpr_mu_init(&tcp->tb_mu); + tcp->tb_head = nullptr; + GRPC_CLOSURE_INIT(&tcp->read_done_closure, tcp_handle_read, tcp, + grpc_schedule_on_exec_ctx); + if (grpc_event_engine_run_in_background()) { + // If there is a polling engine always running in the background, there is + // no need to run the backup poller. + GRPC_CLOSURE_INIT(&tcp->write_done_closure, tcp_handle_write, tcp, + grpc_schedule_on_exec_ctx); + } else { + GRPC_CLOSURE_INIT(&tcp->write_done_closure, + tcp_drop_uncovered_then_handle_write, tcp, + grpc_schedule_on_exec_ctx); + } + /* Always assume there is something on the queue to read. */ + tcp->inq = 1; +#ifdef GRPC_HAVE_TCP_INQ + int one = 1; + if (setsockopt(tcp->fd, SOL_TCP, TCP_INQ, &one, sizeof(one)) == 0) { + tcp->inq_capable = true; + } else { + gpr_log(GPR_DEBUG, "cannot set inq fd=%d errno=%d", tcp->fd, errno); + tcp->inq_capable = false; + } +#else + tcp->inq_capable = false; +#endif /* GRPC_HAVE_TCP_INQ */ + /* Start being notified on errors if event engine can track errors. */ + if (grpc_event_engine_can_track_errors()) { + /* Grab a ref to tcp so that we can safely access the tcp struct when + * processing errors. We unref when we no longer want to track errors + * separately. */ + TCP_REF(tcp, "error-tracking"); + gpr_atm_rel_store(&tcp->stop_error_notification, 0); + GRPC_CLOSURE_INIT(&tcp->error_closure, tcp_handle_error, tcp, + grpc_schedule_on_exec_ctx); + grpc_fd_notify_on_error(tcp->em_fd, &tcp->error_closure); + } + + return &tcp->base; +} + +int grpc_tcp_fd(grpc_endpoint* ep) { + grpc_tcp* tcp = reinterpret_cast(ep); + GPR_ASSERT(ep->vtable == &vtable); + return grpc_fd_wrapped_fd(tcp->em_fd); +} + +void grpc_tcp_destroy_and_release_fd(grpc_endpoint* ep, int* fd, + grpc_closure* done) { + grpc_tcp* tcp = reinterpret_cast(ep); + GPR_ASSERT(ep->vtable == &vtable); + tcp->release_fd = fd; + tcp->release_fd_cb = done; + grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); + if (grpc_event_engine_can_track_errors()) { + /* Stop errors notification. 
*/ + ZerocopyDisableAndWaitForRemaining(tcp); + gpr_atm_no_barrier_store(&tcp->stop_error_notification, true); + grpc_fd_set_error(tcp->em_fd); + } + TCP_UNREF(tcp, "destroy"); +} + +void grpc_tcp_posix_init() { g_backup_poller_mu = new grpc_core::Mutex; } + +void grpc_tcp_posix_shutdown() { + delete g_backup_poller_mu; + g_backup_poller_mu = nullptr; +} + +#endif /* GRPC_POSIX_SOCKET_TCP */ diff --git a/src/core/lib/iomgr/tcp_server.cc b/src/core/lib/iomgr/tcp_server.cc new file mode 100644 index 00000000..70be5876 --- /dev/null +++ b/src/core/lib/iomgr/tcp_server.cc @@ -0,0 +1,79 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/tcp_server.h" + +grpc_tcp_server_vtable* grpc_tcp_server_impl; + +grpc_error_handle grpc_tcp_server_create( + grpc_closure* shutdown_complete, const grpc_channel_args* args, + grpc_slice_allocator_factory* slice_allocator_factory, + grpc_tcp_server** server) { + return grpc_tcp_server_impl->create(shutdown_complete, args, + slice_allocator_factory, server); +} + +void grpc_tcp_server_start(grpc_tcp_server* server, + const std::vector* pollsets, + grpc_tcp_server_cb on_accept_cb, void* cb_arg) { + grpc_tcp_server_impl->start(server, pollsets, on_accept_cb, cb_arg); +} + +grpc_error_handle grpc_tcp_server_add_port(grpc_tcp_server* s, + const grpc_resolved_address* addr, + int* out_port) { + return grpc_tcp_server_impl->add_port(s, addr, out_port); +} + +grpc_core::TcpServerFdHandler* grpc_tcp_server_create_fd_handler( + grpc_tcp_server* s) { + return grpc_tcp_server_impl->create_fd_handler(s); +} + +unsigned grpc_tcp_server_port_fd_count(grpc_tcp_server* s, + unsigned port_index) { + return grpc_tcp_server_impl->port_fd_count(s, port_index); +} + +int grpc_tcp_server_port_fd(grpc_tcp_server* s, unsigned port_index, + unsigned fd_index) { + return grpc_tcp_server_impl->port_fd(s, port_index, fd_index); +} + +grpc_tcp_server* grpc_tcp_server_ref(grpc_tcp_server* s) { + return grpc_tcp_server_impl->ref(s); +} + +void grpc_tcp_server_shutdown_starting_add(grpc_tcp_server* s, + grpc_closure* shutdown_starting) { + grpc_tcp_server_impl->shutdown_starting_add(s, shutdown_starting); +} + +void grpc_tcp_server_unref(grpc_tcp_server* s) { + grpc_tcp_server_impl->unref(s); +} + +void grpc_tcp_server_shutdown_listeners(grpc_tcp_server* s) { + grpc_tcp_server_impl->shutdown_listeners(s); +} + +void grpc_set_tcp_server_impl(grpc_tcp_server_vtable* impl) { + grpc_tcp_server_impl = impl; +} diff --git a/src/core/lib/iomgr/tcp_server_custom.cc b/src/core/lib/iomgr/tcp_server_custom.cc new file mode 100644 index 00000000..a75b2cd1 --- /dev/null +++ b/src/core/lib/iomgr/tcp_server_custom.cc @@ -0,0 +1,467 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include + +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/iomgr_custom.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/tcp_custom.h" +#include "src/core/lib/iomgr/tcp_server.h" + +extern grpc_core::TraceFlag grpc_tcp_trace; + +extern grpc_socket_vtable* grpc_custom_socket_vtable; + +/* one listening port */ +struct grpc_tcp_listener { + grpc_tcp_server* server; + unsigned port_index; + int port; + + grpc_custom_socket* socket; + + /* linked list */ + struct grpc_tcp_listener* next; + + bool closed; +}; + +struct grpc_tcp_server { + gpr_refcount refs; + + /* Called whenever accept() succeeds on a server port. */ + grpc_tcp_server_cb on_accept_cb; + void* on_accept_cb_arg; + + int open_ports; + + /* linked list of server ports */ + grpc_tcp_listener* head; + grpc_tcp_listener* tail; + + /* List of closures passed to shutdown_starting_add(). */ + grpc_closure_list shutdown_starting; + + /* shutdown callback */ + grpc_closure* shutdown_complete; + + bool shutdown; + bool so_reuseport; + + grpc_slice_allocator_factory* slice_allocator_factory; +}; + +static grpc_error_handle tcp_server_create( + grpc_closure* shutdown_complete, const grpc_channel_args* args, + grpc_slice_allocator_factory* slice_allocator_factory, + grpc_tcp_server** server) { + grpc_tcp_server* s = + static_cast(gpr_malloc(sizeof(grpc_tcp_server))); + s->so_reuseport = + grpc_channel_args_find_bool(args, GRPC_ARG_ALLOW_REUSEPORT, true); + gpr_ref_init(&s->refs, 1); + s->on_accept_cb = nullptr; + s->on_accept_cb_arg = nullptr; + s->open_ports = 0; + s->head = nullptr; + s->tail = nullptr; + s->shutdown_starting.head = nullptr; + s->shutdown_starting.tail = nullptr; + s->shutdown_complete = shutdown_complete; + s->shutdown = false; + s->slice_allocator_factory = slice_allocator_factory; + *server = s; + return GRPC_ERROR_NONE; +} + +static grpc_tcp_server* tcp_server_ref(grpc_tcp_server* s) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + gpr_ref(&s->refs); + return s; +} + +static void tcp_server_shutdown_starting_add(grpc_tcp_server* s, + grpc_closure* shutdown_starting) { + grpc_closure_list_append(&s->shutdown_starting, shutdown_starting, + GRPC_ERROR_NONE); +} + +static void finish_shutdown(grpc_tcp_server* s) { + GPR_ASSERT(s->shutdown); + if (s->shutdown_complete != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, s->shutdown_complete, + GRPC_ERROR_NONE); + } + + while (s->head) { + grpc_tcp_listener* sp = s->head; + s->head = sp->next; + sp->next = nullptr; + gpr_free(sp); + } + grpc_slice_allocator_factory_destroy(s->slice_allocator_factory); + gpr_free(s); +} + +static void custom_close_callback(grpc_custom_socket* socket) { + grpc_tcp_listener* sp = socket->listener; + if (sp) { + grpc_core::ExecCtx exec_ctx; + sp->server->open_ports--; + if (sp->server->open_ports 
== 0 && sp->server->shutdown) { + finish_shutdown(sp->server); + } + } + socket->refs--; + if (socket->refs == 0) { + grpc_custom_socket_vtable->destroy(socket); + gpr_free(socket); + } +} + +void grpc_custom_close_server_callback(grpc_tcp_listener* listener) { + if (listener) { + grpc_core::ExecCtx exec_ctx; + listener->server->open_ports--; + if (listener->server->open_ports == 0 && listener->server->shutdown) { + finish_shutdown(listener->server); + } + } +} + +static void close_listener(grpc_tcp_listener* sp) { + grpc_custom_socket* socket = sp->socket; + if (!sp->closed) { + sp->closed = true; + grpc_custom_socket_vtable->close(socket, custom_close_callback); + } +} + +static void tcp_server_destroy(grpc_tcp_server* s) { + int immediately_done = 0; + grpc_tcp_listener* sp; + + GPR_ASSERT(!s->shutdown); + s->shutdown = true; + + if (s->open_ports == 0) { + immediately_done = 1; + } + for (sp = s->head; sp; sp = sp->next) { + close_listener(sp); + } + + if (immediately_done) { + finish_shutdown(s); + } +} + +static void tcp_server_unref(grpc_tcp_server* s) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + if (gpr_unref(&s->refs)) { + /* Complete shutdown_starting work before destroying. */ + grpc_core::ExecCtx exec_ctx; + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &s->shutdown_starting); + grpc_core::ExecCtx::Get()->Flush(); + tcp_server_destroy(s); + } +} + +static void finish_accept(grpc_tcp_listener* sp, grpc_custom_socket* socket) { + grpc_tcp_server_acceptor* acceptor = + static_cast(gpr_malloc(sizeof(*acceptor))); + grpc_endpoint* ep = nullptr; + grpc_resolved_address peer_name; + std::string peer_name_string; + grpc_error_handle err; + + memset(&peer_name, 0, sizeof(grpc_resolved_address)); + peer_name.len = GRPC_MAX_SOCKADDR_SIZE; + err = grpc_custom_socket_vtable->getpeername( + socket, reinterpret_cast(&peer_name.addr), + reinterpret_cast(&peer_name.len)); + if (err == GRPC_ERROR_NONE) { + peer_name_string = grpc_sockaddr_to_uri(&peer_name); + } else { + GRPC_LOG_IF_ERROR("getpeername error", err); + GRPC_ERROR_UNREF(err); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "SERVER_CONNECT: %p accepted connection: %s", sp->server, + peer_name_string.c_str()); + } + ep = custom_tcp_endpoint_create( + socket, + grpc_slice_allocator_factory_create_slice_allocator( + sp->server->slice_allocator_factory, peer_name_string), + peer_name_string.c_str()); + acceptor->from_server = sp->server; + acceptor->port_index = sp->port_index; + acceptor->fd_index = 0; + acceptor->external_connection = false; + sp->server->on_accept_cb(sp->server->on_accept_cb_arg, ep, nullptr, acceptor); +} + +static void custom_accept_callback(grpc_custom_socket* socket, + grpc_custom_socket* client, + grpc_error_handle error); + +static void custom_accept_callback(grpc_custom_socket* socket, + grpc_custom_socket* client, + grpc_error_handle error) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_tcp_listener* sp = socket->listener; + if (error != GRPC_ERROR_NONE) { + if (!sp->closed) { + gpr_log(GPR_ERROR, "Accept failed: %s", + grpc_error_std_string(error).c_str()); + } + gpr_free(client); + GRPC_ERROR_UNREF(error); + return; + } + finish_accept(sp, client); + if (!sp->closed) { + grpc_custom_socket* new_socket = static_cast( + gpr_malloc(sizeof(grpc_custom_socket))); + new_socket->endpoint = nullptr; + new_socket->listener = nullptr; + new_socket->connector = nullptr; + new_socket->refs = 1; + grpc_custom_socket_vtable->accept(sp->socket, 
new_socket, + custom_accept_callback); + } +} + +static grpc_error_handle add_socket_to_server(grpc_tcp_server* s, + grpc_custom_socket* socket, + const grpc_resolved_address* addr, + unsigned port_index, + grpc_tcp_listener** listener) { + grpc_tcp_listener* sp = nullptr; + int port = -1; + grpc_error_handle error; + grpc_resolved_address sockname_temp; + + // NOTE(lidiz) The last argument is "flags" which is unused by other + // implementations. Python IO managers uses it to specify SO_REUSEPORT. + int flags = 0; + if (s->so_reuseport) { + flags |= GRPC_CUSTOM_SOCKET_OPT_SO_REUSEPORT; + } + + error = grpc_custom_socket_vtable->bind( + socket, reinterpret_cast(const_cast(addr->addr)), + addr->len, flags); + if (error != GRPC_ERROR_NONE) { + return error; + } + + error = grpc_custom_socket_vtable->listen(socket); + if (error != GRPC_ERROR_NONE) { + return error; + } + + sockname_temp.len = GRPC_MAX_SOCKADDR_SIZE; + error = grpc_custom_socket_vtable->getsockname( + socket, reinterpret_cast(&sockname_temp.addr), + reinterpret_cast(&sockname_temp.len)); + if (error != GRPC_ERROR_NONE) { + return error; + } + + port = grpc_sockaddr_get_port(&sockname_temp); + + GPR_ASSERT(port >= 0); + GPR_ASSERT(!s->on_accept_cb && "must add ports before starting server"); + sp = grpc_core::Zalloc(); + sp->next = nullptr; + if (s->head == nullptr) { + s->head = sp; + } else { + s->tail->next = sp; + } + s->tail = sp; + sp->server = s; + sp->socket = socket; + sp->port = port; + sp->port_index = port_index; + sp->closed = false; + s->open_ports++; + *listener = sp; + + return GRPC_ERROR_NONE; +} + +static grpc_error_handle tcp_server_add_port(grpc_tcp_server* s, + const grpc_resolved_address* addr, + int* port) { + // This function is mostly copied from tcp_server_windows.c + grpc_tcp_listener* sp = nullptr; + grpc_custom_socket* socket; + grpc_resolved_address addr6_v4mapped; + grpc_resolved_address wildcard; + grpc_resolved_address* allocated_addr = nullptr; + grpc_resolved_address sockname_temp; + unsigned port_index = 0; + grpc_error_handle error = GRPC_ERROR_NONE; + int family; + + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + + if (s->tail != nullptr) { + port_index = s->tail->port_index + 1; + } + + /* Check if this is a wildcard port, and if so, try to keep the port the same + as some previously created listener. */ + if (grpc_sockaddr_get_port(addr) == 0) { + for (sp = s->head; sp; sp = sp->next) { + socket = sp->socket; + sockname_temp.len = GRPC_MAX_SOCKADDR_SIZE; + if (grpc_custom_socket_vtable->getsockname( + socket, reinterpret_cast(&sockname_temp.addr), + reinterpret_cast(&sockname_temp.len)) == GRPC_ERROR_NONE) { + *port = grpc_sockaddr_get_port(&sockname_temp); + if (*port > 0) { + allocated_addr = static_cast( + gpr_malloc(sizeof(grpc_resolved_address))); + memcpy(allocated_addr, addr, sizeof(grpc_resolved_address)); + grpc_sockaddr_set_port(allocated_addr, *port); + addr = allocated_addr; + break; + } + } + } + } + + if (grpc_sockaddr_to_v4mapped(addr, &addr6_v4mapped)) { + addr = &addr6_v4mapped; + } + + /* Treat :: or 0.0.0.0 as a family-agnostic wildcard. 
*/ + if (grpc_sockaddr_is_wildcard(addr, port)) { + grpc_sockaddr_make_wildcard6(*port, &wildcard); + + addr = &wildcard; + } + + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "SERVER %p add_port %s error=%s", s, + grpc_sockaddr_to_string(addr, false).c_str(), + grpc_error_std_string(error).c_str()); + } + + family = grpc_sockaddr_get_family(addr); + socket = + static_cast(gpr_malloc(sizeof(grpc_custom_socket))); + socket->refs = 1; + socket->endpoint = nullptr; + socket->listener = nullptr; + socket->connector = nullptr; + error = grpc_custom_socket_vtable->init(socket, family); + + if (error == GRPC_ERROR_NONE) { + error = add_socket_to_server(s, socket, addr, port_index, &sp); + } + gpr_free(allocated_addr); + + if (error != GRPC_ERROR_NONE) { + grpc_error_handle error_out = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed to add port to server", &error, 1); + GRPC_ERROR_UNREF(error); + error = error_out; + *port = -1; + } else { + GPR_ASSERT(sp != nullptr); + *port = sp->port; + } + socket->listener = sp; + return error; +} + +static void tcp_server_start(grpc_tcp_server* server, + const std::vector* /*pollsets*/, + grpc_tcp_server_cb on_accept_cb, void* cb_arg) { + grpc_tcp_listener* sp; + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "SERVER_START %p", server); + } + GPR_ASSERT(on_accept_cb); + GPR_ASSERT(!server->on_accept_cb); + server->on_accept_cb = on_accept_cb; + server->on_accept_cb_arg = cb_arg; + for (sp = server->head; sp; sp = sp->next) { + grpc_custom_socket* new_socket = static_cast( + gpr_malloc(sizeof(grpc_custom_socket))); + new_socket->endpoint = nullptr; + new_socket->listener = nullptr; + new_socket->connector = nullptr; + new_socket->refs = 1; + grpc_custom_socket_vtable->accept(sp->socket, new_socket, + custom_accept_callback); + } +} + +static unsigned tcp_server_port_fd_count(grpc_tcp_server* /*s*/, + unsigned /*port_index*/) { + return 0; +} + +static int tcp_server_port_fd(grpc_tcp_server* /*s*/, unsigned /*port_index*/, + unsigned /*fd_index*/) { + return -1; +} + +static void tcp_server_shutdown_listeners(grpc_tcp_server* s) { + for (grpc_tcp_listener* sp = s->head; sp; sp = sp->next) { + if (!sp->closed) { + sp->closed = true; + grpc_custom_socket_vtable->close(sp->socket, custom_close_callback); + } + } +} + +static grpc_core::TcpServerFdHandler* tcp_server_create_fd_handler( + grpc_tcp_server* /*s*/) { + return nullptr; +} + +grpc_tcp_server_vtable custom_tcp_server_vtable = { + tcp_server_create, tcp_server_start, + tcp_server_add_port, tcp_server_create_fd_handler, + tcp_server_port_fd_count, tcp_server_port_fd, + tcp_server_ref, tcp_server_shutdown_starting_add, + tcp_server_unref, tcp_server_shutdown_listeners}; diff --git a/src/core/lib/iomgr/tcp_server_posix.cc b/src/core/lib/iomgr/tcp_server_posix.cc new file mode 100644 index 00000000..fc8ff01d --- /dev/null +++ b/src/core/lib/iomgr/tcp_server_posix.cc @@ -0,0 +1,645 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* FIXME: "posix" files shouldn't be depending on _GNU_SOURCE */ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_TCP_SERVER + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/tcp_server_utils_posix.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" + +static grpc_error_handle tcp_server_create( + grpc_closure* shutdown_complete, const grpc_channel_args* args, + grpc_slice_allocator_factory* slice_allocator_factory, + grpc_tcp_server** server) { + grpc_tcp_server* s = grpc_core::Zalloc(); + s->so_reuseport = grpc_is_socket_reuse_port_supported(); + s->expand_wildcard_addrs = false; + for (size_t i = 0; i < (args == nullptr ? 0 : args->num_args); i++) { + if (0 == strcmp(GRPC_ARG_ALLOW_REUSEPORT, args->args[i].key)) { + if (args->args[i].type == GRPC_ARG_INTEGER) { + s->so_reuseport = grpc_is_socket_reuse_port_supported() && + (args->args[i].value.integer != 0); + } else { + gpr_free(s); + grpc_slice_allocator_factory_destroy(slice_allocator_factory); + return GRPC_ERROR_CREATE_FROM_STATIC_STRING(GRPC_ARG_ALLOW_REUSEPORT + " must be an integer"); + } + } else if (0 == strcmp(GRPC_ARG_EXPAND_WILDCARD_ADDRS, args->args[i].key)) { + if (args->args[i].type == GRPC_ARG_INTEGER) { + s->expand_wildcard_addrs = (args->args[i].value.integer != 0); + } else { + gpr_free(s); + grpc_slice_allocator_factory_destroy(slice_allocator_factory); + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + GRPC_ARG_EXPAND_WILDCARD_ADDRS " must be an integer"); + } + } + } + gpr_ref_init(&s->refs, 1); + gpr_mu_init(&s->mu); + s->active_ports = 0; + s->destroyed_ports = 0; + s->shutdown = false; + s->shutdown_starting.head = nullptr; + s->shutdown_starting.tail = nullptr; + s->shutdown_complete = shutdown_complete; + s->on_accept_cb = nullptr; + s->on_accept_cb_arg = nullptr; + s->head = nullptr; + s->tail = nullptr; + s->nports = 0; + s->channel_args = grpc_channel_args_copy(args); + s->fd_handler = nullptr; + s->slice_allocator_factory = slice_allocator_factory; + gpr_atm_no_barrier_store(&s->next_pollset_to_assign, 0); + *server = s; + return GRPC_ERROR_NONE; +} + +static void finish_shutdown(grpc_tcp_server* s) { + gpr_mu_lock(&s->mu); + GPR_ASSERT(s->shutdown); + gpr_mu_unlock(&s->mu); + if (s->shutdown_complete != nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, s->shutdown_complete, + GRPC_ERROR_NONE); + } + gpr_mu_destroy(&s->mu); + while (s->head) { + grpc_tcp_listener* sp = s->head; + s->head = sp->next; + gpr_free(sp); + } + grpc_slice_allocator_factory_destroy(s->slice_allocator_factory); + grpc_channel_args_destroy(s->channel_args); + delete s->fd_handler; + gpr_free(s); +} + +static void destroyed_port(void* server, grpc_error_handle /*error*/) { + 
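+  // Completion callback for grpc_fd_orphan on each listener fd; the last
+  // port to be destroyed (destroyed_ports == nports) runs finish_shutdown.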
grpc_tcp_server* s = static_cast(server); + gpr_mu_lock(&s->mu); + s->destroyed_ports++; + if (s->destroyed_ports == s->nports) { + gpr_mu_unlock(&s->mu); + finish_shutdown(s); + } else { + GPR_ASSERT(s->destroyed_ports < s->nports); + gpr_mu_unlock(&s->mu); + } +} + +/* called when all listening endpoints have been shutdown, so no further + events will be received on them - at this point it's safe to destroy + things */ +static void deactivated_all_ports(grpc_tcp_server* s) { + /* delete ALL the things */ + gpr_mu_lock(&s->mu); + + GPR_ASSERT(s->shutdown); + + if (s->head) { + grpc_tcp_listener* sp; + for (sp = s->head; sp; sp = sp->next) { + grpc_unlink_if_unix_domain_socket(&sp->addr); + GRPC_CLOSURE_INIT(&sp->destroyed_closure, destroyed_port, s, + grpc_schedule_on_exec_ctx); + grpc_fd_orphan(sp->emfd, &sp->destroyed_closure, nullptr, + "tcp_listener_shutdown"); + } + gpr_mu_unlock(&s->mu); + } else { + gpr_mu_unlock(&s->mu); + finish_shutdown(s); + } +} + +static void tcp_server_destroy(grpc_tcp_server* s) { + gpr_mu_lock(&s->mu); + GPR_ASSERT(!s->shutdown); + s->shutdown = true; + /* shutdown all fd's */ + if (s->active_ports) { + grpc_tcp_listener* sp; + for (sp = s->head; sp; sp = sp->next) { + grpc_fd_shutdown( + sp->emfd, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server destroyed")); + } + gpr_mu_unlock(&s->mu); + } else { + gpr_mu_unlock(&s->mu); + deactivated_all_ports(s); + } +} + +/* event manager callback when reads are ready */ +static void on_read(void* arg, grpc_error_handle err) { + grpc_tcp_listener* sp = static_cast(arg); + grpc_pollset* read_notifier_pollset; + if (err != GRPC_ERROR_NONE) { + goto error; + } + + /* loop until accept4 returns EAGAIN, and then re-arm notification */ + for (;;) { + grpc_resolved_address addr; + memset(&addr, 0, sizeof(addr)); + addr.len = static_cast(sizeof(struct sockaddr_storage)); + /* Note: If we ever decide to return this address to the user, remember to + strip off the ::ffff:0.0.0.0/96 prefix first. */ + int fd = grpc_accept4(sp->fd, &addr, 1, 1); + if (fd < 0) { + switch (errno) { + case EINTR: + continue; + case EAGAIN: + grpc_fd_notify_on_read(sp->emfd, &sp->read_closure); + return; + default: + gpr_mu_lock(&sp->server->mu); + if (!sp->server->shutdown_listeners) { + gpr_log(GPR_ERROR, "Failed accept4: %s", strerror(errno)); + } else { + /* if we have shutdown listeners, accept4 could fail, and we + needn't notify users */ + } + gpr_mu_unlock(&sp->server->mu); + goto error; + } + } + + /* For UNIX sockets, the accept call might not fill up the member sun_path + * of sockaddr_un, so explicitly call getsockname to get it. 
*/ + if (grpc_is_unix_socket(&addr)) { + memset(&addr, 0, sizeof(addr)); + addr.len = static_cast(sizeof(struct sockaddr_storage)); + if (getsockname(fd, reinterpret_cast(addr.addr), + &(addr.len)) < 0) { + gpr_log(GPR_ERROR, "Failed getsockname: %s", strerror(errno)); + close(fd); + goto error; + } + } + + (void)grpc_set_socket_no_sigpipe_if_possible(fd); + + err = grpc_apply_socket_mutator_in_args(fd, GRPC_FD_SERVER_CONNECTION_USAGE, + sp->server->channel_args); + if (err != GRPC_ERROR_NONE) { + goto error; + } + + std::string addr_str = grpc_sockaddr_to_uri(&addr); + if (GRPC_TRACE_FLAG_ENABLED(grpc_tcp_trace)) { + gpr_log(GPR_INFO, "SERVER_CONNECT: incoming connection: %s", + addr_str.c_str()); + } + + std::string name = absl::StrCat("tcp-server-connection:", addr_str); + grpc_fd* fdobj = grpc_fd_create(fd, name.c_str(), true); + + read_notifier_pollset = (*(sp->server->pollsets)) + [static_cast(gpr_atm_no_barrier_fetch_add( + &sp->server->next_pollset_to_assign, 1)) % + sp->server->pollsets->size()]; + + grpc_pollset_add_fd(read_notifier_pollset, fdobj); + + // Create acceptor. + grpc_tcp_server_acceptor* acceptor = + static_cast(gpr_malloc(sizeof(*acceptor))); + acceptor->from_server = sp->server; + acceptor->port_index = sp->port_index; + acceptor->fd_index = sp->fd_index; + acceptor->external_connection = false; + sp->server->on_accept_cb( + sp->server->on_accept_cb_arg, + grpc_tcp_create(fdobj, sp->server->channel_args, addr_str.c_str(), + grpc_slice_allocator_factory_create_slice_allocator( + sp->server->slice_allocator_factory, + absl::StrCat("tcp_server_posix:", addr_str), + sp->server->channel_args)), + read_notifier_pollset, acceptor); + } + + GPR_UNREACHABLE_CODE(return ); + +error: + gpr_mu_lock(&sp->server->mu); + if (0 == --sp->server->active_ports && sp->server->shutdown) { + gpr_mu_unlock(&sp->server->mu); + deactivated_all_ports(sp->server); + } else { + gpr_mu_unlock(&sp->server->mu); + } +} + +/* Treat :: or 0.0.0.0 as a family-agnostic wildcard. */ +static grpc_error_handle add_wildcard_addrs_to_server(grpc_tcp_server* s, + unsigned port_index, + int requested_port, + int* out_port) { + grpc_resolved_address wild4; + grpc_resolved_address wild6; + unsigned fd_index = 0; + grpc_dualstack_mode dsmode; + grpc_tcp_listener* sp = nullptr; + grpc_tcp_listener* sp2 = nullptr; + grpc_error_handle v6_err = GRPC_ERROR_NONE; + grpc_error_handle v4_err = GRPC_ERROR_NONE; + *out_port = -1; + + if (grpc_tcp_server_have_ifaddrs() && s->expand_wildcard_addrs) { + return grpc_tcp_server_add_all_local_addrs(s, port_index, requested_port, + out_port); + } + + grpc_sockaddr_make_wildcards(requested_port, &wild4, &wild6); + /* Try listening on IPv6 first. */ + if ((v6_err = grpc_tcp_server_add_addr(s, &wild6, port_index, fd_index, + &dsmode, &sp)) == GRPC_ERROR_NONE) { + ++fd_index; + requested_port = *out_port = sp->port; + if (dsmode == GRPC_DSMODE_DUALSTACK || dsmode == GRPC_DSMODE_IPV4) { + return GRPC_ERROR_NONE; + } + } + /* If we got a v6-only socket or nothing, try adding 0.0.0.0. 
*/ + grpc_sockaddr_set_port(&wild4, requested_port); + if ((v4_err = grpc_tcp_server_add_addr(s, &wild4, port_index, fd_index, + &dsmode, &sp2)) == GRPC_ERROR_NONE) { + *out_port = sp2->port; + if (sp != nullptr) { + sp2->is_sibling = 1; + sp->sibling = sp2; + } + } + if (*out_port > 0) { + if (v6_err != GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, + "Failed to add :: listener, " + "the environment may not support IPv6: %s", + grpc_error_std_string(v6_err).c_str()); + GRPC_ERROR_UNREF(v6_err); + } + if (v4_err != GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, + "Failed to add 0.0.0.0 listener, " + "the environment may not support IPv4: %s", + grpc_error_std_string(v4_err).c_str()); + GRPC_ERROR_UNREF(v4_err); + } + return GRPC_ERROR_NONE; + } else { + grpc_error_handle root_err = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Failed to add any wildcard listeners"); + GPR_ASSERT(v6_err != GRPC_ERROR_NONE && v4_err != GRPC_ERROR_NONE); + root_err = grpc_error_add_child(root_err, v6_err); + root_err = grpc_error_add_child(root_err, v4_err); + return root_err; + } +} + +static grpc_error_handle clone_port(grpc_tcp_listener* listener, + unsigned count) { + grpc_tcp_listener* sp = nullptr; + std::string addr_str; + grpc_error_handle err; + + for (grpc_tcp_listener* l = listener->next; l && l->is_sibling; l = l->next) { + l->fd_index += count; + } + + for (unsigned i = 0; i < count; i++) { + int fd = -1; + int port = -1; + grpc_dualstack_mode dsmode; + err = grpc_create_dualstack_socket(&listener->addr, SOCK_STREAM, 0, &dsmode, + &fd); + if (err != GRPC_ERROR_NONE) return err; + err = grpc_tcp_server_prepare_socket(listener->server, fd, &listener->addr, + true, &port); + if (err != GRPC_ERROR_NONE) return err; + listener->server->nports++; + addr_str = grpc_sockaddr_to_string(&listener->addr, true); + sp = static_cast(gpr_malloc(sizeof(grpc_tcp_listener))); + sp->next = listener->next; + listener->next = sp; + /* sp (the new listener) is a sibling of 'listener' (the original + listener). */ + sp->is_sibling = 1; + sp->sibling = listener->sibling; + listener->sibling = sp; + sp->server = listener->server; + sp->fd = fd; + sp->emfd = grpc_fd_create( + fd, + absl::StrFormat("tcp-server-listener:%s/clone-%d", addr_str.c_str(), i) + .c_str(), + true); + memcpy(&sp->addr, &listener->addr, sizeof(grpc_resolved_address)); + sp->port = port; + sp->port_index = listener->port_index; + sp->fd_index = listener->fd_index + count - i; + GPR_ASSERT(sp->emfd); + while (listener->server->tail->next != nullptr) { + listener->server->tail = listener->server->tail->next; + } + } + + return GRPC_ERROR_NONE; +} + +static grpc_error_handle tcp_server_add_port(grpc_tcp_server* s, + const grpc_resolved_address* addr, + int* out_port) { + GPR_ASSERT(addr->len <= GRPC_MAX_SOCKADDR_SIZE); + grpc_tcp_listener* sp; + grpc_resolved_address sockname_temp; + grpc_resolved_address addr6_v4mapped; + int requested_port = grpc_sockaddr_get_port(addr); + unsigned port_index = 0; + grpc_dualstack_mode dsmode; + grpc_error_handle err; + *out_port = -1; + if (s->tail != nullptr) { + port_index = s->tail->port_index + 1; + } + grpc_unlink_if_unix_domain_socket(addr); + + /* Check if this is a wildcard port, and if so, try to keep the port the same + as some previously created listener. 
*/ + if (requested_port == 0) { + for (sp = s->head; sp; sp = sp->next) { + sockname_temp.len = + static_cast(sizeof(struct sockaddr_storage)); + if (0 == + getsockname(sp->fd, + reinterpret_cast(&sockname_temp.addr), + &sockname_temp.len)) { + int used_port = grpc_sockaddr_get_port(&sockname_temp); + if (used_port > 0) { + memcpy(&sockname_temp, addr, sizeof(grpc_resolved_address)); + grpc_sockaddr_set_port(&sockname_temp, used_port); + requested_port = used_port; + addr = &sockname_temp; + break; + } + } + } + } + if (grpc_sockaddr_is_wildcard(addr, &requested_port)) { + return add_wildcard_addrs_to_server(s, port_index, requested_port, + out_port); + } + if (grpc_sockaddr_to_v4mapped(addr, &addr6_v4mapped)) { + addr = &addr6_v4mapped; + } + if ((err = grpc_tcp_server_add_addr(s, addr, port_index, 0, &dsmode, &sp)) == + GRPC_ERROR_NONE) { + *out_port = sp->port; + } + return err; +} + +/* Return listener at port_index or NULL. Should only be called with s->mu + locked. */ +static grpc_tcp_listener* get_port_index(grpc_tcp_server* s, + unsigned port_index) { + unsigned num_ports = 0; + grpc_tcp_listener* sp; + for (sp = s->head; sp; sp = sp->next) { + if (!sp->is_sibling) { + if (++num_ports > port_index) { + return sp; + } + } + } + return nullptr; +} + +unsigned tcp_server_port_fd_count(grpc_tcp_server* s, unsigned port_index) { + unsigned num_fds = 0; + gpr_mu_lock(&s->mu); + grpc_tcp_listener* sp = get_port_index(s, port_index); + for (; sp; sp = sp->sibling) { + ++num_fds; + } + gpr_mu_unlock(&s->mu); + return num_fds; +} + +static int tcp_server_port_fd(grpc_tcp_server* s, unsigned port_index, + unsigned fd_index) { + gpr_mu_lock(&s->mu); + grpc_tcp_listener* sp = get_port_index(s, port_index); + for (; sp; sp = sp->sibling, --fd_index) { + if (fd_index == 0) { + gpr_mu_unlock(&s->mu); + return sp->fd; + } + } + gpr_mu_unlock(&s->mu); + return -1; +} + +static void tcp_server_start(grpc_tcp_server* s, + const std::vector* pollsets, + grpc_tcp_server_cb on_accept_cb, + void* on_accept_cb_arg) { + size_t i; + grpc_tcp_listener* sp; + GPR_ASSERT(on_accept_cb); + gpr_mu_lock(&s->mu); + GPR_ASSERT(!s->on_accept_cb); + GPR_ASSERT(s->active_ports == 0); + s->on_accept_cb = on_accept_cb; + s->on_accept_cb_arg = on_accept_cb_arg; + s->pollsets = pollsets; + sp = s->head; + while (sp != nullptr) { + if (s->so_reuseport && !grpc_is_unix_socket(&sp->addr) && + pollsets->size() > 1) { + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "clone_port", clone_port(sp, (unsigned)(pollsets->size() - 1)))); + for (i = 0; i < pollsets->size(); i++) { + grpc_pollset_add_fd((*pollsets)[i], sp->emfd); + GRPC_CLOSURE_INIT(&sp->read_closure, on_read, sp, + grpc_schedule_on_exec_ctx); + grpc_fd_notify_on_read(sp->emfd, &sp->read_closure); + s->active_ports++; + sp = sp->next; + } + } else { + for (i = 0; i < pollsets->size(); i++) { + grpc_pollset_add_fd((*pollsets)[i], sp->emfd); + } + GRPC_CLOSURE_INIT(&sp->read_closure, on_read, sp, + grpc_schedule_on_exec_ctx); + grpc_fd_notify_on_read(sp->emfd, &sp->read_closure); + s->active_ports++; + sp = sp->next; + } + } + gpr_mu_unlock(&s->mu); +} + +grpc_tcp_server* tcp_server_ref(grpc_tcp_server* s) { + gpr_ref_non_zero(&s->refs); + return s; +} + +static void tcp_server_shutdown_starting_add(grpc_tcp_server* s, + grpc_closure* shutdown_starting) { + gpr_mu_lock(&s->mu); + grpc_closure_list_append(&s->shutdown_starting, shutdown_starting, + GRPC_ERROR_NONE); + gpr_mu_unlock(&s->mu); +} + +static void tcp_server_unref(grpc_tcp_server* s) { + if (gpr_unref(&s->refs)) { + 
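+    // Last reference dropped: stop the listeners, run the shutdown_starting
+    // closures under the lock, then destroy the server.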
grpc_tcp_server_shutdown_listeners(s); + gpr_mu_lock(&s->mu); + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &s->shutdown_starting); + gpr_mu_unlock(&s->mu); + tcp_server_destroy(s); + } +} + +static void tcp_server_shutdown_listeners(grpc_tcp_server* s) { + gpr_mu_lock(&s->mu); + s->shutdown_listeners = true; + /* shutdown all fd's */ + if (s->active_ports) { + grpc_tcp_listener* sp; + for (sp = s->head; sp; sp = sp->next) { + grpc_fd_shutdown(sp->emfd, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server shutdown")); + } + } + gpr_mu_unlock(&s->mu); +} + +namespace { +class ExternalConnectionHandler : public grpc_core::TcpServerFdHandler { + public: + explicit ExternalConnectionHandler(grpc_tcp_server* s) : s_(s) {} + + // TODO(yangg) resolve duplicate code with on_read + void Handle(int listener_fd, int fd, grpc_byte_buffer* buf) override { + grpc_pollset* read_notifier_pollset; + grpc_resolved_address addr; + memset(&addr, 0, sizeof(addr)); + addr.len = static_cast(sizeof(struct sockaddr_storage)); + grpc_core::ExecCtx exec_ctx; + + if (getpeername(fd, reinterpret_cast(addr.addr), + &(addr.len)) < 0) { + gpr_log(GPR_ERROR, "Failed getpeername: %s", strerror(errno)); + close(fd); + return; + } + (void)grpc_set_socket_no_sigpipe_if_possible(fd); + std::string addr_str = grpc_sockaddr_to_uri(&addr); + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_INFO, "SERVER_CONNECT: incoming external connection: %s", + addr_str.c_str()); + } + std::string name = absl::StrCat("tcp-server-connection:", addr_str); + grpc_fd* fdobj = grpc_fd_create(fd, name.c_str(), true); + read_notifier_pollset = + (*(s_->pollsets))[static_cast(gpr_atm_no_barrier_fetch_add( + &s_->next_pollset_to_assign, 1)) % + s_->pollsets->size()]; + grpc_pollset_add_fd(read_notifier_pollset, fdobj); + grpc_tcp_server_acceptor* acceptor = + static_cast(gpr_malloc(sizeof(*acceptor))); + acceptor->from_server = s_; + acceptor->port_index = -1; + acceptor->fd_index = -1; + acceptor->external_connection = true; + acceptor->listener_fd = listener_fd; + acceptor->pending_data = buf; + s_->on_accept_cb( + s_->on_accept_cb_arg, + grpc_tcp_create( + fdobj, s_->channel_args, addr_str.c_str(), + grpc_slice_allocator_factory_create_slice_allocator( + s_->slice_allocator_factory, addr_str, s_->channel_args)), + read_notifier_pollset, acceptor); + } + + private: + grpc_tcp_server* s_; +}; +} // namespace + +static grpc_core::TcpServerFdHandler* tcp_server_create_fd_handler( + grpc_tcp_server* s) { + s->fd_handler = new ExternalConnectionHandler(s); + return s->fd_handler; +} + +grpc_tcp_server_vtable grpc_posix_tcp_server_vtable = { + tcp_server_create, tcp_server_start, + tcp_server_add_port, tcp_server_create_fd_handler, + tcp_server_port_fd_count, tcp_server_port_fd, + tcp_server_ref, tcp_server_shutdown_starting_add, + tcp_server_unref, tcp_server_shutdown_listeners}; + +#endif /* GRPC_POSIX_SOCKET_TCP_SERVER */ diff --git a/src/core/lib/iomgr/tcp_server_utils_posix_common.cc b/src/core/lib/iomgr/tcp_server_utils_posix_common.cc new file mode 100644 index 00000000..7501596e --- /dev/null +++ b/src/core/lib/iomgr/tcp_server_utils_posix_common.cc @@ -0,0 +1,223 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_SOCKET_TCP_SERVER_UTILS_COMMON + +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/tcp_server_utils_posix.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" + +#define MIN_SAFE_ACCEPT_QUEUE_SIZE 100 + +static gpr_once s_init_max_accept_queue_size = GPR_ONCE_INIT; +static int s_max_accept_queue_size; + +/* get max listen queue size on linux */ +static void init_max_accept_queue_size(void) { + int n = SOMAXCONN; + char buf[64]; + FILE* fp = fopen("/proc/sys/net/core/somaxconn", "r"); + if (fp == nullptr) { + /* 2.4 kernel. */ + s_max_accept_queue_size = SOMAXCONN; + return; + } + if (fgets(buf, sizeof buf, fp)) { + char* end; + long i = strtol(buf, &end, 10); + if (i > 0 && i <= INT_MAX && end && *end == '\n') { + n = static_cast(i); + } + } + fclose(fp); + s_max_accept_queue_size = n; + + if (s_max_accept_queue_size < MIN_SAFE_ACCEPT_QUEUE_SIZE) { + gpr_log(GPR_INFO, + "Suspiciously small accept queue (%d) will probably lead to " + "connection drops", + s_max_accept_queue_size); + } +} + +static int get_max_accept_queue_size(void) { + gpr_once_init(&s_init_max_accept_queue_size, init_max_accept_queue_size); + return s_max_accept_queue_size; +} + +static grpc_error_handle add_socket_to_server(grpc_tcp_server* s, int fd, + const grpc_resolved_address* addr, + unsigned port_index, + unsigned fd_index, + grpc_tcp_listener** listener) { + grpc_tcp_listener* sp = nullptr; + int port = -1; + + grpc_error_handle err = + grpc_tcp_server_prepare_socket(s, fd, addr, s->so_reuseport, &port); + if (err == GRPC_ERROR_NONE) { + GPR_ASSERT(port > 0); + std::string addr_str = grpc_sockaddr_to_string(addr, true); + std::string name = absl::StrCat("tcp-server-listener:", addr_str); + gpr_mu_lock(&s->mu); + s->nports++; + GPR_ASSERT(!s->on_accept_cb && "must add ports before starting server"); + sp = static_cast(gpr_malloc(sizeof(grpc_tcp_listener))); + sp->next = nullptr; + if (s->head == nullptr) { + s->head = sp; + } else { + s->tail->next = sp; + } + s->tail = sp; + sp->server = s; + sp->fd = fd; + sp->emfd = grpc_fd_create(fd, name.c_str(), true); + memcpy(&sp->addr, addr, sizeof(grpc_resolved_address)); + sp->port = port; + sp->port_index = port_index; + sp->fd_index = fd_index; + sp->is_sibling = 0; + sp->sibling = nullptr; + GPR_ASSERT(sp->emfd); + gpr_mu_unlock(&s->mu); + } + + *listener = sp; + return err; +} + +/* If successful, add a listener to s for addr, set *dsmode for the socket, and + return the *listener. 
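The backlog logic above parses /proc/sys/net/core/somaxconn and warns when the value looks dangerously small. A minimal standalone sketch of the same idea, assuming Linux and using only the C standard library; the 100-entry floor mirrors MIN_SAFE_ACCEPT_QUEUE_SIZE above:

```cpp
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <sys/socket.h>  // SOMAXCONN

// Returns a listen() backlog: the kernel's somaxconn value if readable,
// otherwise the compile-time SOMAXCONN fallback, as in
// init_max_accept_queue_size() above.
static int max_accept_queue_size() {
  int n = SOMAXCONN;
  if (FILE* fp = std::fopen("/proc/sys/net/core/somaxconn", "r")) {
    char buf[64];
    if (std::fgets(buf, sizeof buf, fp)) {
      char* end = nullptr;
      long v = std::strtol(buf, &end, 10);
      if (v > 0 && v <= INT_MAX && end && *end == '\n') {
        n = static_cast<int>(v);
      }
    }
    std::fclose(fp);
  }
  if (n < 100) {
    std::fprintf(stderr, "suspiciously small accept queue: %d\n", n);
  }
  return n;
}

int main() { std::printf("listen backlog: %d\n", max_accept_queue_size()); }
```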
*/ +grpc_error_handle grpc_tcp_server_add_addr(grpc_tcp_server* s, + const grpc_resolved_address* addr, + unsigned port_index, + unsigned fd_index, + grpc_dualstack_mode* dsmode, + grpc_tcp_listener** listener) { + grpc_resolved_address addr4_copy; + int fd; + grpc_error_handle err = + grpc_create_dualstack_socket(addr, SOCK_STREAM, 0, dsmode, &fd); + if (err != GRPC_ERROR_NONE) { + return err; + } + if (*dsmode == GRPC_DSMODE_IPV4 && + grpc_sockaddr_is_v4mapped(addr, &addr4_copy)) { + addr = &addr4_copy; + } + return add_socket_to_server(s, fd, addr, port_index, fd_index, listener); +} + +/* Prepare a recently-created socket for listening. */ +grpc_error_handle grpc_tcp_server_prepare_socket( + grpc_tcp_server* s, int fd, const grpc_resolved_address* addr, + bool so_reuseport, int* port) { + grpc_resolved_address sockname_temp; + grpc_error_handle err = GRPC_ERROR_NONE; + + GPR_ASSERT(fd >= 0); + + if (so_reuseport && !grpc_is_unix_socket(addr)) { + err = grpc_set_socket_reuse_port(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + } + +#ifdef GRPC_LINUX_ERRQUEUE + err = grpc_set_socket_zerocopy(fd); + if (err != GRPC_ERROR_NONE) { + /* it's not fatal, so just log it. */ + gpr_log(GPR_DEBUG, "Node does not support SO_ZEROCOPY, continuing."); + GRPC_ERROR_UNREF(err); + } +#endif + err = grpc_set_socket_nonblocking(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + err = grpc_set_socket_cloexec(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + if (!grpc_is_unix_socket(addr)) { + err = grpc_set_socket_low_latency(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + err = grpc_set_socket_reuse_addr(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; + err = grpc_set_socket_tcp_user_timeout(fd, s->channel_args, + false /* is_client */); + if (err != GRPC_ERROR_NONE) goto error; + } + err = grpc_set_socket_no_sigpipe_if_possible(fd); + if (err != GRPC_ERROR_NONE) goto error; + + err = grpc_apply_socket_mutator_in_args(fd, GRPC_FD_SERVER_LISTENER_USAGE, + s->channel_args); + if (err != GRPC_ERROR_NONE) goto error; + + if (bind(fd, reinterpret_cast(const_cast(addr->addr)), + addr->len) < 0) { + err = GRPC_OS_ERROR(errno, "bind"); + goto error; + } + + if (listen(fd, get_max_accept_queue_size()) < 0) { + err = GRPC_OS_ERROR(errno, "listen"); + goto error; + } + + sockname_temp.len = static_cast(sizeof(struct sockaddr_storage)); + + if (getsockname(fd, reinterpret_cast(sockname_temp.addr), + &sockname_temp.len) < 0) { + err = GRPC_OS_ERROR(errno, "getsockname"); + goto error; + } + + *port = grpc_sockaddr_get_port(&sockname_temp); + return GRPC_ERROR_NONE; + +error: + GPR_ASSERT(err != GRPC_ERROR_NONE); + if (fd >= 0) { + close(fd); + } + grpc_error_handle ret = + grpc_error_set_int(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Unable to configure socket", &err, 1), + GRPC_ERROR_INT_FD, fd); + GRPC_ERROR_UNREF(err); + return ret; +} + +#endif /* GRPC_POSIX_SOCKET_TCP_SERVER_UTILS_COMMON */ diff --git a/src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.cc b/src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.cc new file mode 100644 index 00000000..0a8cb7ae --- /dev/null +++ b/src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.cc @@ -0,0 +1,175 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_HAVE_IFADDRS + +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/tcp_server_utils_posix.h" + +/* Return the listener in s with address addr or NULL. */ +static grpc_tcp_listener* find_listener_with_addr(grpc_tcp_server* s, + grpc_resolved_address* addr) { + grpc_tcp_listener* l; + gpr_mu_lock(&s->mu); + for (l = s->head; l != nullptr; l = l->next) { + if (l->addr.len != addr->len) { + continue; + } + if (memcmp(l->addr.addr, addr->addr, addr->len) == 0) { + break; + } + } + gpr_mu_unlock(&s->mu); + return l; +} + +/* Bind to "::" to get a port number not used by any address. */ +static grpc_error_handle get_unused_port(int* port) { + grpc_resolved_address wild; + grpc_sockaddr_make_wildcard6(0, &wild); + grpc_dualstack_mode dsmode; + int fd; + grpc_error_handle err = + grpc_create_dualstack_socket(&wild, SOCK_STREAM, 0, &dsmode, &fd); + if (err != GRPC_ERROR_NONE) { + return err; + } + if (dsmode == GRPC_DSMODE_IPV4) { + grpc_sockaddr_make_wildcard4(0, &wild); + } + if (bind(fd, reinterpret_cast(wild.addr), wild.len) != + 0) { + err = GRPC_OS_ERROR(errno, "bind"); + close(fd); + return err; + } + if (getsockname(fd, reinterpret_cast(wild.addr), &wild.len) != + 0) { + err = GRPC_OS_ERROR(errno, "getsockname"); + close(fd); + return err; + } + close(fd); + *port = grpc_sockaddr_get_port(&wild); + return *port <= 0 ? GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad port") + : GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_tcp_server_add_all_local_addrs(grpc_tcp_server* s, + unsigned port_index, + int requested_port, + int* out_port) { + struct ifaddrs* ifa = nullptr; + struct ifaddrs* ifa_it; + unsigned fd_index = 0; + grpc_tcp_listener* sp = nullptr; + grpc_error_handle err = GRPC_ERROR_NONE; + if (requested_port == 0) { + /* Note: There could be a race where some local addrs can listen on the + selected port and some can't. The sane way to handle this would be to + retry by recreating the whole grpc_tcp_server. Backing out individual + listeners and orphaning the FDs looks like too much trouble. */ + if ((err = get_unused_port(&requested_port)) != GRPC_ERROR_NONE) { + return err; + } else if (requested_port <= 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Bad get_unused_port()"); + } + gpr_log(GPR_DEBUG, "Picked unused port %d", requested_port); + } + if (getifaddrs(&ifa) != 0 || ifa == nullptr) { + return GRPC_OS_ERROR(errno, "getifaddrs"); + } + for (ifa_it = ifa; ifa_it != nullptr; ifa_it = ifa_it->ifa_next) { + grpc_resolved_address addr; + grpc_dualstack_mode dsmode; + grpc_tcp_listener* new_sp = nullptr; + const char* ifa_name = (ifa_it->ifa_name ? 
ifa_it->ifa_name : ""); + if (ifa_it->ifa_addr == nullptr) { + continue; + } else if (ifa_it->ifa_addr->sa_family == AF_INET) { + addr.len = static_cast(sizeof(grpc_sockaddr_in)); + } else if (ifa_it->ifa_addr->sa_family == AF_INET6) { + addr.len = static_cast(sizeof(grpc_sockaddr_in6)); + } else { + continue; + } + memcpy(addr.addr, ifa_it->ifa_addr, addr.len); + if (!grpc_sockaddr_set_port(&addr, requested_port)) { + /* Should never happen, because we check sa_family above. */ + err = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to set port"); + break; + } + std::string addr_str = grpc_sockaddr_to_string(&addr, false); + gpr_log(GPR_DEBUG, + "Adding local addr from interface %s flags 0x%x to server: %s", + ifa_name, ifa_it->ifa_flags, addr_str.c_str()); + /* We could have multiple interfaces with the same address (e.g., bonding), + so look for duplicates. */ + if (find_listener_with_addr(s, &addr) != nullptr) { + gpr_log(GPR_DEBUG, "Skipping duplicate addr %s on interface %s", + addr_str.c_str(), ifa_name); + continue; + } + if ((err = grpc_tcp_server_add_addr(s, &addr, port_index, fd_index, &dsmode, + &new_sp)) != GRPC_ERROR_NONE) { + grpc_error_handle root_err = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Failed to add listener: ", addr_str)); + err = grpc_error_add_child(root_err, err); + break; + } else { + GPR_ASSERT(requested_port == new_sp->port); + ++fd_index; + if (sp != nullptr) { + new_sp->is_sibling = 1; + sp->sibling = new_sp; + } + sp = new_sp; + } + } + freeifaddrs(ifa); + if (err != GRPC_ERROR_NONE) { + return err; + } else if (sp == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("No local addresses"); + } else { + *out_port = sp->port; + return GRPC_ERROR_NONE; + } +} + +bool grpc_tcp_server_have_ifaddrs(void) { return true; } + +#endif /* GRPC_HAVE_IFADDRS */ diff --git a/src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.cc b/src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.cc new file mode 100644 index 00000000..beaf489e --- /dev/null +++ b/src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.cc @@ -0,0 +1,36 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#if defined(GRPC_POSIX_SOCKET) && !defined(GRPC_HAVE_IFADDRS) + +#include "src/core/lib/iomgr/tcp_server_utils_posix.h" + +grpc_error_handle grpc_tcp_server_add_all_local_addrs(grpc_tcp_server* /*s*/, + unsigned /*port_index*/, + int /*requested_port*/, + int* /*out_port*/) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("no ifaddrs available"); +} + +bool grpc_tcp_server_have_ifaddrs(void) { return false; } + +#endif /* defined(GRPC_POSIX_SOCKET) && !defined(GRPC_HAVE_IFADDRS) */ diff --git a/src/core/lib/iomgr/tcp_server_windows.cc b/src/core/lib/iomgr/tcp_server_windows.cc new file mode 100644 index 00000000..e853c4f6 --- /dev/null +++ b/src/core/lib/iomgr/tcp_server_windows.cc @@ -0,0 +1,568 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/pollset_windows.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/tcp_windows.h" +#include "src/core/lib/slice/slice_internal.h" + +#define MIN_SAFE_ACCEPT_QUEUE_SIZE 100 + +/* one listening port */ +typedef struct grpc_tcp_listener grpc_tcp_listener; +struct grpc_tcp_listener { + /* This seemingly magic number comes from AcceptEx's documentation. each + address buffer needs to have at least 16 more bytes at their end. */ + uint8_t addresses[(sizeof(grpc_sockaddr_in6) + 16) * 2]; + /* This will hold the socket for the next accept. */ + SOCKET new_socket; + /* The listener winsocket. */ + grpc_winsocket* socket; + /* The actual TCP port number. */ + int port; + unsigned port_index; + grpc_tcp_server* server; + /* The cached AcceptEx for that port. */ + LPFN_ACCEPTEX AcceptEx; + int shutting_down; + int outstanding_calls; + /* closure for socket notification of accept being ready */ + grpc_closure on_accept; + /* linked list */ + struct grpc_tcp_listener* next; +}; + +/* the overall server */ +struct grpc_tcp_server { + gpr_refcount refs; + /* Called whenever accept() succeeds on a server port. */ + grpc_tcp_server_cb on_accept_cb; + void* on_accept_cb_arg; + + gpr_mu mu; + + /* active port count: how many ports are actually still listening */ + int active_ports; + + /* linked list of server ports */ + grpc_tcp_listener* head; + grpc_tcp_listener* tail; + + /* List of closures passed to shutdown_starting_add(). */ + grpc_closure_list shutdown_starting; + + /* shutdown callback */ + grpc_closure* shutdown_complete; + + grpc_channel_args* channel_args; + grpc_slice_allocator_factory* slice_allocator_factory; +}; + +/* Public function. Allocates the proper data structures to hold a + grpc_tcp_server. 
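The `addresses` member of the listener above is sized from AcceptEx's documented rule that each of the two returned addresses needs at least 16 bytes of slack beyond the largest address in use. A small sketch of that buffer layout with a hypothetical struct name (not the gRPC type):

```cpp
#include <winsock2.h>
#include <ws2tcpip.h>  // sockaddr_in6
#include <cstddef>

// AcceptEx writes the local and the remote address into one caller-supplied
// buffer, and its documentation requires each address slot to be at least
// 16 bytes larger than the largest address used (sockaddr_in6 here). That is
// where the "(sizeof(grpc_sockaddr_in6) + 16) * 2" above comes from.
struct PendingAccept {
  static constexpr std::size_t kSlot = sizeof(sockaddr_in6) + 16;
  unsigned char addresses[kSlot * 2];    // [0, kSlot) local, [kSlot, 2*kSlot) remote
  OVERLAPPED overlapped{};               // completion handle for the AcceptEx call
  SOCKET accept_socket = INVALID_SOCKET; // pre-created socket for the next peer
};
```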
*/ +static grpc_error_handle tcp_server_create( + grpc_closure* shutdown_complete, const grpc_channel_args* args, + grpc_slice_allocator_factory* slice_allocator_factory, + grpc_tcp_server** server) { + grpc_tcp_server* s = (grpc_tcp_server*)gpr_malloc(sizeof(grpc_tcp_server)); + s->channel_args = grpc_channel_args_copy(args); + gpr_ref_init(&s->refs, 1); + gpr_mu_init(&s->mu); + s->active_ports = 0; + s->on_accept_cb = NULL; + s->on_accept_cb_arg = NULL; + s->head = NULL; + s->tail = NULL; + s->shutdown_starting.head = NULL; + s->shutdown_starting.tail = NULL; + s->shutdown_complete = shutdown_complete; + s->slice_allocator_factory = slice_allocator_factory; + *server = s; + return GRPC_ERROR_NONE; +} + +static void destroy_server(void* arg, grpc_error_handle error) { + grpc_tcp_server* s = (grpc_tcp_server*)arg; + + /* Now that the accepts have been aborted, we can destroy the sockets. + The IOCP won't get notified on these, so we can flag them as already + closed by the system. */ + while (s->head) { + grpc_tcp_listener* sp = s->head; + s->head = sp->next; + sp->next = NULL; + grpc_winsocket_destroy(sp->socket); + gpr_free(sp); + } + grpc_channel_args_destroy(s->channel_args); + gpr_mu_destroy(&s->mu); + gpr_free(s); +} + +static void finish_shutdown_locked(grpc_tcp_server* s) { + if (s->shutdown_complete != NULL) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, s->shutdown_complete, + GRPC_ERROR_NONE); + } + + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_CREATE(destroy_server, s, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); +} + +static grpc_tcp_server* tcp_server_ref(grpc_tcp_server* s) { + gpr_ref_non_zero(&s->refs); + return s; +} + +static void tcp_server_shutdown_starting_add(grpc_tcp_server* s, + grpc_closure* shutdown_starting) { + gpr_mu_lock(&s->mu); + grpc_closure_list_append(&s->shutdown_starting, shutdown_starting, + GRPC_ERROR_NONE); + gpr_mu_unlock(&s->mu); +} + +static void tcp_server_destroy(grpc_tcp_server* s) { + grpc_tcp_listener* sp; + gpr_mu_lock(&s->mu); + grpc_slice_allocator_factory_destroy(s->slice_allocator_factory); + /* First, shutdown all fd's. This will queue abortion calls for all + of the pending accepts due to the normal operation mechanism. */ + if (s->active_ports == 0) { + finish_shutdown_locked(s); + } else { + for (sp = s->head; sp; sp = sp->next) { + sp->shutting_down = 1; + grpc_winsocket_shutdown(sp->socket); + } + } + gpr_mu_unlock(&s->mu); +} + +static void tcp_server_unref(grpc_tcp_server* s) { + if (gpr_unref(&s->refs)) { + grpc_tcp_server_shutdown_listeners(s); + gpr_mu_lock(&s->mu); + grpc_core::ExecCtx::RunList(DEBUG_LOCATION, &s->shutdown_starting); + gpr_mu_unlock(&s->mu); + tcp_server_destroy(s); + } +} + +/* Prepare (bind) a recently-created socket for listening. 
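tcp_server_unref above runs the queued "shutdown starting" closures exactly once, when the last reference is dropped, and only then destroys the server. A minimal standalone sketch of that ownership pattern using the standard library; the class and method names are hypothetical:

```cpp
#include <atomic>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>

// Refcounted object that collects "shutdown starting" callbacks; the last
// Unref() runs them and then destroys the object, as tcp_server_unref() does.
class RefCountedServer {
 public:
  void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }
  void AddShutdownStarting(std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mu_);
    shutdown_starting_.push_back(std::move(cb));
  }
  void Unref() {
    if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
      std::vector<std::function<void()>> cbs;
      {
        std::lock_guard<std::mutex> lock(mu_);
        cbs.swap(shutdown_starting_);
      }
      for (auto& cb : cbs) cb();
      delete this;  // the real code calls tcp_server_destroy(s) here
    }
  }

 private:
  ~RefCountedServer() { std::puts("server destroyed"); }
  std::atomic<int> refs_{1};
  std::mutex mu_;
  std::vector<std::function<void()>> shutdown_starting_;
};

int main() {
  auto* s = new RefCountedServer;
  s->AddShutdownStarting([] { std::puts("shutdown starting"); });
  s->Ref();
  s->Unref();  // one reference still outstanding
  s->Unref();  // last reference: runs callbacks, then destroys
}
```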
*/ +static grpc_error_handle prepare_socket(SOCKET sock, + const grpc_resolved_address* addr, + int* port) { + grpc_resolved_address sockname_temp; + grpc_error_handle error = GRPC_ERROR_NONE; + int sockname_temp_len; + + error = grpc_tcp_prepare_socket(sock); + if (error != GRPC_ERROR_NONE) { + goto failure; + } + + if (bind(sock, (const grpc_sockaddr*)addr->addr, (int)addr->len) == + SOCKET_ERROR) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "bind"); + goto failure; + } + + if (listen(sock, SOMAXCONN) == SOCKET_ERROR) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "listen"); + goto failure; + } + + sockname_temp_len = sizeof(struct sockaddr_storage); + if (getsockname(sock, (grpc_sockaddr*)sockname_temp.addr, + &sockname_temp_len) == SOCKET_ERROR) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "getsockname"); + goto failure; + } + sockname_temp.len = (size_t)sockname_temp_len; + + *port = grpc_sockaddr_get_port(&sockname_temp); + return GRPC_ERROR_NONE; + +failure: + GPR_ASSERT(error != GRPC_ERROR_NONE); + grpc_error_set_int( + grpc_error_set_str(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed to prepare server socket", &error, 1), + GRPC_ERROR_STR_TARGET_ADDRESS, + grpc_sockaddr_to_uri(addr)), + GRPC_ERROR_INT_FD, (intptr_t)sock); + GRPC_ERROR_UNREF(error); + if (sock != INVALID_SOCKET) closesocket(sock); + return error; +} + +static void decrement_active_ports_and_notify_locked(grpc_tcp_listener* sp) { + sp->shutting_down = 0; + GPR_ASSERT(sp->server->active_ports > 0); + if (0 == --sp->server->active_ports) { + finish_shutdown_locked(sp->server); + } +} + +/* In order to do an async accept, we need to create a socket first which + will be the one assigned to the new incoming connection. */ +static grpc_error_handle start_accept_locked(grpc_tcp_listener* port) { + SOCKET sock = INVALID_SOCKET; + BOOL success; + DWORD addrlen = sizeof(grpc_sockaddr_in6) + 16; + DWORD bytes_received = 0; + grpc_error_handle error = GRPC_ERROR_NONE; + + if (port->shutting_down) { + return GRPC_ERROR_NONE; + } + + sock = WSASocket(AF_INET6, SOCK_STREAM, IPPROTO_TCP, NULL, 0, + grpc_get_default_wsa_socket_flags()); + if (sock == INVALID_SOCKET) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "WSASocket"); + goto failure; + } + + error = grpc_tcp_prepare_socket(sock); + if (error != GRPC_ERROR_NONE) goto failure; + + /* Start the "accept" asynchronously. */ + success = port->AcceptEx(port->socket->socket, sock, port->addresses, 0, + addrlen, addrlen, &bytes_received, + &port->socket->read_info.overlapped); + + /* It is possible to get an accept immediately without delay. However, we + will still get an IOCP notification for it. So let's just ignore it. */ + if (!success) { + int last_error = WSAGetLastError(); + if (last_error != ERROR_IO_PENDING) { + error = GRPC_WSA_ERROR(last_error, "AcceptEx"); + goto failure; + } + } + + /* We're ready to do the accept. Calling grpc_socket_notify_on_read may + immediately process an accept that happened in the meantime. */ + port->new_socket = sock; + grpc_socket_notify_on_read(port->socket, &port->on_accept); + port->outstanding_calls++; + return error; + +failure: + GPR_ASSERT(error != GRPC_ERROR_NONE); + if (sock != INVALID_SOCKET) closesocket(sock); + return error; +} + +/* Event manager callback when reads are ready. 
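start_accept_locked above pre-creates the socket that will carry the next connection and posts an AcceptEx on it, treating ERROR_IO_PENDING as the normal case. A heavily trimmed Winsock sketch of that call sequence; it assumes linking against Ws2_32 and Mswsock, uses hypothetical cleanup, and omits the IOCP that would normally receive the completion:

```cpp
#include <winsock2.h>
#include <ws2tcpip.h>
#include <mswsock.h>  // AcceptEx
#include <cstdio>

// Post one asynchronous accept: the accept socket is created up front and
// handed to AcceptEx together with an OVERLAPPED; ERROR_IO_PENDING simply
// means the completion will be delivered later (via an IOCP in real code).
int main() {
  WSADATA wsa;
  if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) return 1;

  SOCKET listener = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
  sockaddr_in6 addr = {};
  addr.sin6_family = AF_INET6;
  addr.sin6_addr = in6addr_any;
  addr.sin6_port = 0;  // kernel-assigned port
  if (bind(listener, reinterpret_cast<sockaddr*>(&addr),
           static_cast<int>(sizeof(addr))) != 0 ||
      listen(listener, SOMAXCONN) != 0) {
    return 1;
  }

  SOCKET next = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
  // Each address slot needs sizeof(sockaddr_in6) + 16 bytes (AcceptEx rule).
  char addresses[(sizeof(sockaddr_in6) + 16) * 2];
  DWORD bytes_received = 0;
  OVERLAPPED overlapped = {};

  BOOL ok = AcceptEx(listener, next, addresses, /*dwReceiveDataLength=*/0,
                     static_cast<DWORD>(sizeof(sockaddr_in6) + 16),
                     static_cast<DWORD>(sizeof(sockaddr_in6) + 16),
                     &bytes_received, &overlapped);
  if (!ok && WSAGetLastError() != ERROR_IO_PENDING) {
    std::printf("AcceptEx failed: %d\n", WSAGetLastError());
    return 1;
  }
  std::puts("accept posted; a real server now waits for the IOCP completion");
  // Cleanup (a real server keeps both sockets alive until the completion).
  closesocket(next);
  closesocket(listener);
  WSACleanup();
}
```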
*/ +static void on_accept(void* arg, grpc_error_handle error) { + grpc_tcp_listener* sp = (grpc_tcp_listener*)arg; + SOCKET sock = sp->new_socket; + grpc_winsocket_callback_info* info = &sp->socket->read_info; + grpc_endpoint* ep = NULL; + grpc_resolved_address peer_name; + DWORD transfered_bytes; + DWORD flags; + BOOL wsa_success; + int err; + + gpr_mu_lock(&sp->server->mu); + + peer_name.len = sizeof(struct sockaddr_storage); + + /* The general mechanism for shutting down is to queue abortion calls. While + this is necessary in the read/write case, it's useless for the accept + case. We only need to adjust the pending callback count */ + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, "Skipping on_accept due to error: %s", + grpc_error_std_string(error).c_str()); + + gpr_mu_unlock(&sp->server->mu); + return; + } + /* The IOCP notified us of a completed operation. Let's grab the results, + and act accordingly. */ + transfered_bytes = 0; + wsa_success = WSAGetOverlappedResult(sock, &info->overlapped, + &transfered_bytes, FALSE, &flags); + if (!wsa_success) { + if (!sp->shutting_down) { + char* utf8_message = gpr_format_message(WSAGetLastError()); + gpr_log(GPR_ERROR, "on_accept error: %s", utf8_message); + gpr_free(utf8_message); + } + closesocket(sock); + } else { + if (!sp->shutting_down) { + err = setsockopt(sock, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, + (char*)&sp->socket->socket, sizeof(sp->socket->socket)); + if (err) { + char* utf8_message = gpr_format_message(WSAGetLastError()); + gpr_log(GPR_ERROR, "setsockopt error: %s", utf8_message); + gpr_free(utf8_message); + } + int peer_name_len = (int)peer_name.len; + err = getpeername(sock, (grpc_sockaddr*)peer_name.addr, &peer_name_len); + peer_name.len = (size_t)peer_name_len; + std::string peer_name_string; + if (!err) { + peer_name_string = grpc_sockaddr_to_uri(&peer_name); + } else { + char* utf8_message = gpr_format_message(WSAGetLastError()); + gpr_log(GPR_ERROR, "getpeername error: %s", utf8_message); + gpr_free(utf8_message); + } + std::string fd_name = absl::StrCat("tcp_server:", peer_name_string); + ep = grpc_tcp_create( + grpc_winsocket_create(sock, fd_name.c_str()), + sp->server->channel_args, peer_name_string.c_str(), + grpc_slice_allocator_factory_create_slice_allocator( + sp->server->slice_allocator_factory, peer_name_string)); + } else { + closesocket(sock); + } + } + + /* The only time we should call our callback, is where we successfully + managed to accept a connection, and created an endpoint. */ + if (ep) { + // Create acceptor. + grpc_tcp_server_acceptor* acceptor = + (grpc_tcp_server_acceptor*)gpr_malloc(sizeof(*acceptor)); + acceptor->from_server = sp->server; + acceptor->port_index = sp->port_index; + acceptor->fd_index = 0; + acceptor->external_connection = false; + sp->server->on_accept_cb(sp->server->on_accept_cb_arg, ep, NULL, acceptor); + } + /* As we were notified from the IOCP of one and exactly one accept, + the former socked we created has now either been destroy or assigned + to the new connection. We need to create a new one for the next + connection. 
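After an AcceptEx completion, on_accept above must call setsockopt(SO_UPDATE_ACCEPT_CONTEXT) before the accepted socket behaves like one returned by accept(); without it, calls such as getpeername() or shutdown() can fail. A small hedged helper showing just that step, with a hypothetical function name:

```cpp
#include <winsock2.h>
#include <mswsock.h>

// Associate an AcceptEx-accepted socket with its listening socket so that
// getpeername()/getsockname()/shutdown() work on it.
// Returns 0 on success, the WSA error code otherwise.
static int FinishAcceptExSocket(SOCKET accepted, SOCKET listener) {
  if (setsockopt(accepted, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
                 reinterpret_cast<const char*>(&listener),
                 static_cast<int>(sizeof(listener))) == SOCKET_ERROR) {
    return WSAGetLastError();
  }
  return 0;
}
```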
*/ + GPR_ASSERT(GRPC_LOG_IF_ERROR("start_accept", start_accept_locked(sp))); + if (0 == --sp->outstanding_calls) { + decrement_active_ports_and_notify_locked(sp); + } + gpr_mu_unlock(&sp->server->mu); +} + +static grpc_error_handle add_socket_to_server(grpc_tcp_server* s, SOCKET sock, + const grpc_resolved_address* addr, + unsigned port_index, + grpc_tcp_listener** listener) { + grpc_tcp_listener* sp = NULL; + int port = -1; + int status; + GUID guid = WSAID_ACCEPTEX; + DWORD ioctl_num_bytes; + LPFN_ACCEPTEX AcceptEx; + grpc_error_handle error = GRPC_ERROR_NONE; + + /* We need to grab the AcceptEx pointer for that port, as it may be + interface-dependent. We'll cache it to avoid doing that again. */ + status = + WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid), + &AcceptEx, sizeof(AcceptEx), &ioctl_num_bytes, NULL, NULL); + + if (status != 0) { + char* utf8_message = gpr_format_message(WSAGetLastError()); + gpr_log(GPR_ERROR, "on_connect error: %s", utf8_message); + gpr_free(utf8_message); + closesocket(sock); + return GRPC_ERROR_NONE; + } + + error = prepare_socket(sock, addr, &port); + if (error != GRPC_ERROR_NONE) { + return error; + } + + GPR_ASSERT(port >= 0); + gpr_mu_lock(&s->mu); + GPR_ASSERT(!s->on_accept_cb && "must add ports before starting server"); + sp = (grpc_tcp_listener*)gpr_malloc(sizeof(grpc_tcp_listener)); + sp->next = NULL; + if (s->head == NULL) { + s->head = sp; + } else { + s->tail->next = sp; + } + s->tail = sp; + sp->server = s; + sp->socket = grpc_winsocket_create(sock, "listener"); + sp->shutting_down = 0; + sp->outstanding_calls = 0; + sp->AcceptEx = AcceptEx; + sp->new_socket = INVALID_SOCKET; + sp->port = port; + sp->port_index = port_index; + GRPC_CLOSURE_INIT(&sp->on_accept, on_accept, sp, grpc_schedule_on_exec_ctx); + GPR_ASSERT(sp->socket); + gpr_mu_unlock(&s->mu); + *listener = sp; + + return GRPC_ERROR_NONE; +} + +static grpc_error_handle tcp_server_add_port(grpc_tcp_server* s, + const grpc_resolved_address* addr, + int* port) { + grpc_tcp_listener* sp = NULL; + SOCKET sock; + grpc_resolved_address addr6_v4mapped; + grpc_resolved_address wildcard; + grpc_resolved_address* allocated_addr = NULL; + grpc_resolved_address sockname_temp; + unsigned port_index = 0; + grpc_error_handle error = GRPC_ERROR_NONE; + + if (s->tail != NULL) { + port_index = s->tail->port_index + 1; + } + + /* Check if this is a wildcard port, and if so, try to keep the port the same + as some previously created listener. */ + if (grpc_sockaddr_get_port(addr) == 0) { + for (sp = s->head; sp; sp = sp->next) { + int sockname_temp_len = sizeof(struct sockaddr_storage); + if (0 == getsockname(sp->socket->socket, + (grpc_sockaddr*)sockname_temp.addr, + &sockname_temp_len)) { + sockname_temp.len = (size_t)sockname_temp_len; + *port = grpc_sockaddr_get_port(&sockname_temp); + if (*port > 0) { + allocated_addr = + (grpc_resolved_address*)gpr_malloc(sizeof(grpc_resolved_address)); + memcpy(allocated_addr, addr, sizeof(grpc_resolved_address)); + grpc_sockaddr_set_port(allocated_addr, *port); + addr = allocated_addr; + break; + } + } + } + } + + if (grpc_sockaddr_to_v4mapped(addr, &addr6_v4mapped)) { + addr = &addr6_v4mapped; + } + + /* Treat :: or 0.0.0.0 as a family-agnostic wildcard. 
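add_socket_to_server above fetches the AcceptEx entry point per socket with WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER), since the extension function can be provider-specific. A trimmed standalone sketch of that lookup; the helper name is hypothetical:

```cpp
#include <winsock2.h>
#include <mswsock.h>  // WSAID_ACCEPTEX, LPFN_ACCEPTEX
#include <cstdio>

// Resolve the provider-specific AcceptEx pointer for a given socket,
// mirroring the WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) call above.
static LPFN_ACCEPTEX LoadAcceptEx(SOCKET sock) {
  GUID guid = WSAID_ACCEPTEX;
  LPFN_ACCEPTEX accept_ex = nullptr;
  DWORD bytes = 0;
  if (WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
               &accept_ex, sizeof(accept_ex), &bytes, nullptr, nullptr) != 0) {
    return nullptr;
  }
  return accept_ex;
}

int main() {
  WSADATA wsa;
  if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) return 1;
  SOCKET s = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
  LPFN_ACCEPTEX accept_ex = LoadAcceptEx(s);
  std::printf("AcceptEx %s\n", accept_ex ? "resolved" : "not available");
  closesocket(s);
  WSACleanup();
}
```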
*/ + if (grpc_sockaddr_is_wildcard(addr, port)) { + grpc_sockaddr_make_wildcard6(*port, &wildcard); + + addr = &wildcard; + } + + sock = WSASocket(AF_INET6, SOCK_STREAM, IPPROTO_TCP, NULL, 0, + grpc_get_default_wsa_socket_flags()); + if (sock == INVALID_SOCKET) { + error = GRPC_WSA_ERROR(WSAGetLastError(), "WSASocket"); + goto done; + } + + error = add_socket_to_server(s, sock, addr, port_index, &sp); + +done: + gpr_free(allocated_addr); + + if (error != GRPC_ERROR_NONE) { + grpc_error_handle error_out = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failed to add port to server", &error, 1); + GRPC_ERROR_UNREF(error); + error = error_out; + *port = -1; + } else { + GPR_ASSERT(sp != NULL); + *port = sp->port; + } + return error; +} + +static void tcp_server_start(grpc_tcp_server* s, + const std::vector* /*pollsets*/, + grpc_tcp_server_cb on_accept_cb, + void* on_accept_cb_arg) { + grpc_tcp_listener* sp; + GPR_ASSERT(on_accept_cb); + gpr_mu_lock(&s->mu); + GPR_ASSERT(!s->on_accept_cb); + GPR_ASSERT(s->active_ports == 0); + s->on_accept_cb = on_accept_cb; + s->on_accept_cb_arg = on_accept_cb_arg; + for (sp = s->head; sp; sp = sp->next) { + GPR_ASSERT(GRPC_LOG_IF_ERROR("start_accept", start_accept_locked(sp))); + s->active_ports++; + } + gpr_mu_unlock(&s->mu); +} + +static unsigned tcp_server_port_fd_count(grpc_tcp_server* s, + unsigned port_index) { + return 0; +} + +static int tcp_server_port_fd(grpc_tcp_server* s, unsigned port_index, + unsigned fd_index) { + return -1; +} + +static grpc_core::TcpServerFdHandler* tcp_server_create_fd_handler( + grpc_tcp_server* s) { + return nullptr; +} + +static void tcp_server_shutdown_listeners(grpc_tcp_server* s) {} + +grpc_tcp_server_vtable grpc_windows_tcp_server_vtable = { + tcp_server_create, tcp_server_start, + tcp_server_add_port, tcp_server_create_fd_handler, + tcp_server_port_fd_count, tcp_server_port_fd, + tcp_server_ref, tcp_server_shutdown_starting_add, + tcp_server_unref, tcp_server_shutdown_listeners}; +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/tcp_windows.cc b/src/core/lib/iomgr/tcp_windows.cc new file mode 100644 index 00000000..4f2c151d --- /dev/null +++ b/src/core/lib/iomgr/tcp_windows.cc @@ -0,0 +1,530 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_WINSOCK_SOCKET + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_windows.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +#if defined(__MSYS__) && defined(GPR_ARCH_64) +/* Nasty workaround for nasty bug when using the 64 bits msys compiler + in conjunction with Microsoft Windows headers. */ +#define GRPC_FIONBIO _IOW('f', 126, uint32_t) +#else +#define GRPC_FIONBIO FIONBIO +#endif + +extern grpc_core::TraceFlag grpc_tcp_trace; + +grpc_error_handle grpc_tcp_set_non_block(SOCKET sock) { + int status; + uint32_t param = 1; + DWORD ret; + status = WSAIoctl(sock, GRPC_FIONBIO, ¶m, sizeof(param), NULL, 0, &ret, + NULL, NULL); + return status == 0 + ? GRPC_ERROR_NONE + : GRPC_WSA_ERROR(WSAGetLastError(), "WSAIoctl(GRPC_FIONBIO)"); +} + +static grpc_error_handle set_dualstack(SOCKET sock) { + int status; + unsigned long param = 0; + status = setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)¶m, + sizeof(param)); + return status == 0 + ? GRPC_ERROR_NONE + : GRPC_WSA_ERROR(WSAGetLastError(), "setsockopt(IPV6_V6ONLY)"); +} + +static grpc_error_handle enable_socket_low_latency(SOCKET sock) { + int status; + BOOL param = TRUE; + status = ::setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast(¶m), sizeof(param)); + if (status == SOCKET_ERROR) { + status = WSAGetLastError(); + } + return status == 0 ? GRPC_ERROR_NONE + : GRPC_WSA_ERROR(status, "setsockopt(TCP_NODELAY)"); +} + +grpc_error_handle grpc_tcp_prepare_socket(SOCKET sock) { + grpc_error_handle err; + err = grpc_tcp_set_non_block(sock); + if (err != GRPC_ERROR_NONE) return err; + err = set_dualstack(sock); + if (err != GRPC_ERROR_NONE) return err; + err = enable_socket_low_latency(sock); + if (err != GRPC_ERROR_NONE) return err; + return GRPC_ERROR_NONE; +} + +typedef struct grpc_tcp { + /* This is our C++ class derivation emulation. */ + grpc_endpoint base; + /* The one socket this endpoint is using. */ + grpc_winsocket* socket; + /* Refcounting how many operations are in progress. */ + gpr_refcount refcount; + + grpc_closure on_read; + grpc_closure on_write; + + grpc_closure* read_cb; + grpc_closure* write_cb; + + /* garbage after the last read */ + grpc_slice_buffer last_read_buffer; + + grpc_slice_buffer* write_slices; + grpc_slice_buffer* read_slices; + + grpc_slice_allocator* slice_allocator; + + /* The IO Completion Port runs from another thread. We need some mechanism + to protect ourselves when requesting a shutdown. 
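grpc_tcp_prepare_socket above puts the Winsock socket into non-blocking mode, clears IPV6_V6ONLY so one AF_INET6 socket also serves IPv4-mapped peers, and disables Nagle. A minimal sketch of those three steps using plain Winsock calls; it uses ioctlsocket(FIONBIO) rather than the WSAIoctl/GRPC_FIONBIO workaround in the patch, and the helper name is hypothetical:

```cpp
#include <winsock2.h>
#include <ws2tcpip.h>  // IPPROTO_IPV6, IPV6_V6ONLY

// Configure a freshly created AF_INET6 socket the way grpc_tcp_prepare_socket
// does: non-blocking, dual-stack, TCP_NODELAY. Returns 0 or a WSA error code.
static int PrepareSocket(SOCKET sock) {
  u_long non_blocking = 1;
  if (ioctlsocket(sock, FIONBIO, &non_blocking) != 0) return WSAGetLastError();

  DWORD v6only = 0;  // 0 == also accept IPv4-mapped peers
  if (setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY,
                 reinterpret_cast<const char*>(&v6only),
                 sizeof(v6only)) != 0) {
    return WSAGetLastError();
  }

  BOOL nodelay = TRUE;  // disable Nagle for lower latency
  if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY,
                 reinterpret_cast<const char*>(&nodelay),
                 sizeof(nodelay)) != 0) {
    return WSAGetLastError();
  }
  return 0;
}
```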
*/ + gpr_mu mu; + int shutting_down; + grpc_error_handle shutdown_error; + + std::string peer_string; + std::string local_address; +} grpc_tcp; + +static void tcp_free(grpc_tcp* tcp) { + grpc_winsocket_destroy(tcp->socket); + gpr_mu_destroy(&tcp->mu); + grpc_slice_buffer_destroy_internal(&tcp->last_read_buffer); + grpc_slice_allocator_destroy(tcp->slice_allocator); + if (tcp->shutting_down) GRPC_ERROR_UNREF(tcp->shutdown_error); + delete tcp; +} + +#ifndef NDEBUG +#define TCP_UNREF(tcp, reason) tcp_unref((tcp), (reason), __FILE__, __LINE__) +#define TCP_REF(tcp, reason) tcp_ref((tcp), (reason), __FILE__, __LINE__) +static void tcp_unref(grpc_tcp* tcp, const char* reason, const char* file, + int line) { + if (grpc_tcp_trace.enabled()) { + gpr_atm val = gpr_atm_no_barrier_load(&tcp->refcount.count); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "TCP unref %p : %s %" PRIdPTR " -> %" PRIdPTR, tcp, reason, val, + val - 1); + } + if (gpr_unref(&tcp->refcount)) { + tcp_free(tcp); + } +} + +static void tcp_ref(grpc_tcp* tcp, const char* reason, const char* file, + int line) { + if (grpc_tcp_trace.enabled()) { + gpr_atm val = gpr_atm_no_barrier_load(&tcp->refcount.count); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "TCP ref %p : %s %" PRIdPTR " -> %" PRIdPTR, tcp, reason, val, + val + 1); + } + gpr_ref(&tcp->refcount); +} +#else +#define TCP_UNREF(tcp, reason) tcp_unref((tcp)) +#define TCP_REF(tcp, reason) tcp_ref((tcp)) +static void tcp_unref(grpc_tcp* tcp) { + if (gpr_unref(&tcp->refcount)) { + tcp_free(tcp); + } +} + +static void tcp_ref(grpc_tcp* tcp) { gpr_ref(&tcp->refcount); } +#endif + +/* Asynchronous callback from the IOCP, or the background thread. */ +static void on_read(void* tcpp, grpc_error_handle error) { + grpc_tcp* tcp = (grpc_tcp*)tcpp; + grpc_closure* cb = tcp->read_cb; + grpc_winsocket* socket = tcp->socket; + grpc_winsocket_callback_info* info = &socket->read_info; + + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_INFO, "TCP:%p on_read", tcp); + } + + (void)GRPC_ERROR_REF(error); + + if (error == GRPC_ERROR_NONE) { + if (info->wsa_error != 0 && !tcp->shutting_down) { + char* utf8_message = gpr_format_message(info->wsa_error); + error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(utf8_message); + gpr_free(utf8_message); + grpc_slice_buffer_reset_and_unref_internal(tcp->read_slices); + } else { + if (info->bytes_transferred != 0 && !tcp->shutting_down) { + GPR_ASSERT((size_t)info->bytes_transferred <= tcp->read_slices->length); + if (static_cast(info->bytes_transferred) != + tcp->read_slices->length) { + grpc_slice_buffer_trim_end( + tcp->read_slices, + tcp->read_slices->length - + static_cast(info->bytes_transferred), + &tcp->last_read_buffer); + } + GPR_ASSERT((size_t)info->bytes_transferred == tcp->read_slices->length); + + if (grpc_tcp_trace.enabled()) { + size_t i; + for (i = 0; i < tcp->read_slices->count; i++) { + char* dump = grpc_dump_slice(tcp->read_slices->slices[i], + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "READ %p (peer=%s): %s", tcp, + tcp->peer_string.c_str(), dump); + gpr_free(dump); + } + } + } else { + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_INFO, "TCP:%p unref read_slice", tcp); + } + grpc_slice_buffer_reset_and_unref_internal(tcp->read_slices); + error = tcp->shutting_down + ? 
GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "TCP stream shutting down", &tcp->shutdown_error, 1) + : GRPC_ERROR_CREATE_FROM_STATIC_STRING("End of TCP stream"); + } + } + } + + tcp->read_cb = NULL; + TCP_UNREF(tcp, "read"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); +} + +#define DEFAULT_TARGET_READ_SIZE 8192 +#define MAX_WSABUF_COUNT 16 +static void win_read(grpc_endpoint* ep, grpc_slice_buffer* read_slices, + grpc_closure* cb, bool urgent) { + grpc_tcp* tcp = (grpc_tcp*)ep; + grpc_winsocket* handle = tcp->socket; + grpc_winsocket_callback_info* info = &handle->read_info; + int status; + DWORD bytes_read = 0; + DWORD flags = 0; + WSABUF buffers[MAX_WSABUF_COUNT]; + size_t i; + + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_INFO, "TCP:%p win_read", tcp); + } + + if (tcp->shutting_down) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, cb, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "TCP socket is shutting down", &tcp->shutdown_error, 1)); + return; + } + + tcp->read_cb = cb; + tcp->read_slices = read_slices; + grpc_slice_buffer_reset_and_unref_internal(read_slices); + grpc_slice_buffer_swap(read_slices, &tcp->last_read_buffer); + + if (tcp->read_slices->length < DEFAULT_TARGET_READ_SIZE / 2 && + tcp->read_slices->count < MAX_WSABUF_COUNT) { + // TODO(jtattermusch): slice should be allocated using resource quota + grpc_slice_buffer_add(tcp->read_slices, + GRPC_SLICE_MALLOC(DEFAULT_TARGET_READ_SIZE)); + } + + GPR_ASSERT(tcp->read_slices->count <= MAX_WSABUF_COUNT); + for (i = 0; i < tcp->read_slices->count; i++) { + buffers[i].len = (ULONG)GRPC_SLICE_LENGTH( + tcp->read_slices->slices[i]); // we know slice size fits in 32bit. + buffers[i].buf = (char*)GRPC_SLICE_START_PTR(tcp->read_slices->slices[i]); + } + + TCP_REF(tcp, "read"); + + /* First let's try a synchronous, non-blocking read. */ + status = WSARecv(tcp->socket->socket, buffers, (DWORD)tcp->read_slices->count, + &bytes_read, &flags, NULL, NULL); + info->wsa_error = status == 0 ? 0 : WSAGetLastError(); + + /* Did we get data immediately ? Yay. */ + if (info->wsa_error != WSAEWOULDBLOCK) { + info->bytes_transferred = bytes_read; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &tcp->on_read, GRPC_ERROR_NONE); + return; + } + + /* Otherwise, let's retry, by queuing a read. */ + memset(&tcp->socket->read_info.overlapped, 0, sizeof(OVERLAPPED)); + status = WSARecv(tcp->socket->socket, buffers, (DWORD)tcp->read_slices->count, + &bytes_read, &flags, &info->overlapped, NULL); + + if (status != 0) { + int wsa_error = WSAGetLastError(); + if (wsa_error != WSA_IO_PENDING) { + info->wsa_error = wsa_error; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &tcp->on_read, + GRPC_WSA_ERROR(info->wsa_error, "WSARecv")); + return; + } + } + + grpc_socket_notify_on_read(tcp->socket, &tcp->on_read); +} + +/* Asynchronous callback from the IOCP, or the background thread. 
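win_read above first attempts a plain WSARecv, and only queues an overlapped read when that returns WSAEWOULDBLOCK; WSA_IO_PENDING on the second call means the IOCP will deliver the result later. A compressed sketch of that control flow, assuming a non-blocking, already-connected socket; names are hypothetical and completion delivery is left to the caller:

```cpp
#include <winsock2.h>
#include <cstring>

// Outcome of attempting a read on a non-blocking, IOCP-associated socket.
enum class ReadStart { kCompletedInline, kPending, kFailed };

// Mirrors win_read's two-step pattern: try an immediate WSARecv, and only
// fall back to an overlapped WSARecv when no data is available yet.
static ReadStart StartRead(SOCKET sock, WSABUF* buffers, DWORD buffer_count,
                           OVERLAPPED* overlapped, DWORD* bytes_read) {
  DWORD flags = 0;
  if (WSARecv(sock, buffers, buffer_count, bytes_read, &flags, nullptr,
              nullptr) == 0) {
    return ReadStart::kCompletedInline;  // data was already available
  }
  if (WSAGetLastError() != WSAEWOULDBLOCK) return ReadStart::kFailed;

  // No data yet: queue the same read asynchronously and let the IOCP
  // completion (e.g. GetQueuedCompletionStatus) report the result later.
  std::memset(overlapped, 0, sizeof(*overlapped));
  flags = 0;
  if (WSARecv(sock, buffers, buffer_count, bytes_read, &flags, overlapped,
              nullptr) == 0 ||
      WSAGetLastError() == WSA_IO_PENDING) {
    return ReadStart::kPending;
  }
  return ReadStart::kFailed;
}
```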
*/ +static void on_write(void* tcpp, grpc_error_handle error) { + grpc_tcp* tcp = (grpc_tcp*)tcpp; + grpc_winsocket* handle = tcp->socket; + grpc_winsocket_callback_info* info = &handle->write_info; + grpc_closure* cb; + + if (grpc_tcp_trace.enabled()) { + gpr_log(GPR_INFO, "TCP:%p on_write", tcp); + } + + (void)GRPC_ERROR_REF(error); + + gpr_mu_lock(&tcp->mu); + cb = tcp->write_cb; + tcp->write_cb = NULL; + gpr_mu_unlock(&tcp->mu); + + if (error == GRPC_ERROR_NONE) { + if (info->wsa_error != 0) { + error = GRPC_WSA_ERROR(info->wsa_error, "WSASend"); + } else { + GPR_ASSERT(info->bytes_transferred == tcp->write_slices->length); + } + } + + TCP_UNREF(tcp, "write"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); +} + +/* Initiates a write. */ +static void win_write(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* arg) { + grpc_tcp* tcp = (grpc_tcp*)ep; + grpc_winsocket* socket = tcp->socket; + grpc_winsocket_callback_info* info = &socket->write_info; + unsigned i; + DWORD bytes_sent; + int status; + WSABUF local_buffers[MAX_WSABUF_COUNT]; + WSABUF* allocated = NULL; + WSABUF* buffers = local_buffers; + size_t len; + + if (grpc_tcp_trace.enabled()) { + size_t i; + for (i = 0; i < slices->count; i++) { + char* data = + grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "WRITE %p (peer=%s): %s", tcp, tcp->peer_string.c_str(), + data); + gpr_free(data); + } + } + + if (tcp->shutting_down) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, cb, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "TCP socket is shutting down", &tcp->shutdown_error, 1)); + return; + } + + tcp->write_cb = cb; + tcp->write_slices = slices; + GPR_ASSERT(tcp->write_slices->count <= UINT_MAX); + if (tcp->write_slices->count > GPR_ARRAY_SIZE(local_buffers)) { + buffers = (WSABUF*)gpr_malloc(sizeof(WSABUF) * tcp->write_slices->count); + allocated = buffers; + } + + for (i = 0; i < tcp->write_slices->count; i++) { + len = GRPC_SLICE_LENGTH(tcp->write_slices->slices[i]); + GPR_ASSERT(len <= ULONG_MAX); + buffers[i].len = (ULONG)len; + buffers[i].buf = (char*)GRPC_SLICE_START_PTR(tcp->write_slices->slices[i]); + } + + /* First, let's try a synchronous, non-blocking write. */ + status = WSASend(socket->socket, buffers, (DWORD)tcp->write_slices->count, + &bytes_sent, 0, NULL, NULL); + info->wsa_error = status == 0 ? 0 : WSAGetLastError(); + + /* We would kind of expect to get a WSAEWOULDBLOCK here, especially on a busy + connection that has its send queue filled up. But if we don't, then we can + avoid doing an async write operation at all. */ + if (info->wsa_error != WSAEWOULDBLOCK) { + grpc_error_handle error = status == 0 + ? GRPC_ERROR_NONE + : GRPC_WSA_ERROR(info->wsa_error, "WSASend"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + if (allocated) gpr_free(allocated); + return; + } + + TCP_REF(tcp, "write"); + + /* If we got a WSAEWOULDBLOCK earlier, then we need to re-do the same + operation, this time asynchronously. */ + memset(&socket->write_info.overlapped, 0, sizeof(OVERLAPPED)); + status = WSASend(socket->socket, buffers, (DWORD)tcp->write_slices->count, + &bytes_sent, 0, &socket->write_info.overlapped, NULL); + if (allocated) gpr_free(allocated); + + if (status != 0) { + int wsa_error = WSAGetLastError(); + if (wsa_error != WSA_IO_PENDING) { + TCP_UNREF(tcp, "write"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, + GRPC_WSA_ERROR(wsa_error, "WSASend")); + return; + } + } + + /* As all is now setup, we can now ask for the IOCP notification. 
It may + trigger the callback immediately however, but no matter. */ + grpc_socket_notify_on_write(socket, &tcp->on_write); +} + +static void win_add_to_pollset(grpc_endpoint* ep, grpc_pollset* ps) { + grpc_tcp* tcp; + (void)ps; + tcp = (grpc_tcp*)ep; + grpc_iocp_add_socket(tcp->socket); +} + +static void win_add_to_pollset_set(grpc_endpoint* ep, grpc_pollset_set* pss) { + grpc_tcp* tcp; + (void)pss; + tcp = (grpc_tcp*)ep; + grpc_iocp_add_socket(tcp->socket); +} + +static void win_delete_from_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pss) {} + +/* Initiates a shutdown of the TCP endpoint. This will queue abort callbacks + for the potential read and write operations. It is up to the caller to + guarantee this isn't called in parallel to a read or write request, so + we're not going to protect against these. However the IO Completion Port + callback will happen from another thread, so we need to protect against + concurrent access of the data structure in that regard. */ +static void win_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + grpc_tcp* tcp = (grpc_tcp*)ep; + gpr_mu_lock(&tcp->mu); + /* At that point, what may happen is that we're already inside the IOCP + callback. See the comments in on_read and on_write. */ + if (!tcp->shutting_down) { + tcp->shutting_down = 1; + tcp->shutdown_error = why; + } else { + GRPC_ERROR_UNREF(why); + } + grpc_winsocket_shutdown(tcp->socket); + gpr_mu_unlock(&tcp->mu); +} + +static void win_destroy(grpc_endpoint* ep) { + grpc_tcp* tcp = (grpc_tcp*)ep; + grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); + TCP_UNREF(tcp, "destroy"); +} + +static absl::string_view win_get_peer(grpc_endpoint* ep) { + grpc_tcp* tcp = (grpc_tcp*)ep; + return tcp->peer_string; +} + +static absl::string_view win_get_local_address(grpc_endpoint* ep) { + grpc_tcp* tcp = (grpc_tcp*)ep; + return tcp->local_address; +} + +static int win_get_fd(grpc_endpoint* ep) { return -1; } + +static bool win_can_track_err(grpc_endpoint* ep) { return false; } + +static grpc_endpoint_vtable vtable = {win_read, + win_write, + win_add_to_pollset, + win_add_to_pollset_set, + win_delete_from_pollset_set, + win_shutdown, + win_destroy, + win_get_peer, + win_get_local_address, + win_get_fd, + win_can_track_err}; + +grpc_endpoint* grpc_tcp_create(grpc_winsocket* socket, + grpc_channel_args* channel_args, + const char* peer_string, + grpc_slice_allocator* slice_allocator) { + grpc_tcp* tcp = new grpc_tcp; + memset(tcp, 0, sizeof(grpc_tcp)); + tcp->base.vtable = &vtable; + tcp->socket = socket; + gpr_mu_init(&tcp->mu); + gpr_ref_init(&tcp->refcount, 1); + GRPC_CLOSURE_INIT(&tcp->on_read, on_read, tcp, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&tcp->on_write, on_write, tcp, grpc_schedule_on_exec_ctx); + grpc_resolved_address resolved_local_addr; + resolved_local_addr.len = sizeof(resolved_local_addr.addr); + if (getsockname(tcp->socket->socket, + reinterpret_cast(resolved_local_addr.addr), + &resolved_local_addr.len) < 0) { + tcp->local_address = ""; + } else { + tcp->local_address = grpc_sockaddr_to_uri(&resolved_local_addr); + } + tcp->peer_string = peer_string; + grpc_slice_buffer_init(&tcp->last_read_buffer); + tcp->slice_allocator = slice_allocator; + return &tcp->base; +} + +#endif /* GRPC_WINSOCK_SOCKET */ diff --git a/src/core/lib/iomgr/time_averaged_stats.cc b/src/core/lib/iomgr/time_averaged_stats.cc new file mode 100644 index 00000000..6369e48d --- /dev/null +++ b/src/core/lib/iomgr/time_averaged_stats.cc @@ -0,0 +1,64 @@ +/* + * + * Copyright 2015 gRPC 
authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/time_averaged_stats.h" + +void grpc_time_averaged_stats_init(grpc_time_averaged_stats* stats, + double init_avg, double regress_weight, + double persistence_factor) { + stats->init_avg = init_avg; + stats->regress_weight = regress_weight; + stats->persistence_factor = persistence_factor; + stats->batch_total_value = 0; + stats->batch_num_samples = 0; + stats->aggregate_total_weight = 0; + stats->aggregate_weighted_avg = init_avg; +} + +void grpc_time_averaged_stats_add_sample(grpc_time_averaged_stats* stats, + double value) { + stats->batch_total_value += value; + ++stats->batch_num_samples; +} + +double grpc_time_averaged_stats_update_average( + grpc_time_averaged_stats* stats) { + /* Start with the current batch: */ + double weighted_sum = stats->batch_total_value; + double total_weight = stats->batch_num_samples; + if (stats->regress_weight > 0) { + /* Add in the regression towards init_avg_: */ + weighted_sum += stats->regress_weight * stats->init_avg; + total_weight += stats->regress_weight; + } + if (stats->persistence_factor > 0) { + /* Add in the persistence: */ + const double prev_sample_weight = + stats->persistence_factor * stats->aggregate_total_weight; + weighted_sum += prev_sample_weight * stats->aggregate_weighted_avg; + total_weight += prev_sample_weight; + } + stats->aggregate_weighted_avg = + (total_weight > 0) ? (weighted_sum / total_weight) : stats->init_avg; + stats->aggregate_total_weight = total_weight; + stats->batch_num_samples = 0; + stats->batch_total_value = 0; + return stats->aggregate_weighted_avg; +} diff --git a/src/core/lib/iomgr/timer.cc b/src/core/lib/iomgr/timer.cc new file mode 100644 index 00000000..6506d302 --- /dev/null +++ b/src/core/lib/iomgr/timer.cc @@ -0,0 +1,46 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
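grpc_time_averaged_stats above blends the current batch of samples with a regression toward init_avg and a decayed carry-over of the previous aggregate. A small self-contained re-statement of update_average with concrete numbers; this is a standalone class, not the gRPC type:

```cpp
#include <cstdio>

// Same update rule as grpc_time_averaged_stats_update_average():
//   weighted_sum = batch_total + regress_weight * init_avg
//                + persistence_factor * prev_total_weight * prev_avg
//   new_avg      = weighted_sum / total_weight
struct TimeAveragedStats {
  double init_avg, regress_weight, persistence_factor;
  double batch_total_value = 0, batch_num_samples = 0;
  double aggregate_total_weight = 0, aggregate_weighted_avg;

  TimeAveragedStats(double init, double regress, double persistence)
      : init_avg(init), regress_weight(regress),
        persistence_factor(persistence), aggregate_weighted_avg(init) {}

  void AddSample(double value) {
    batch_total_value += value;
    ++batch_num_samples;
  }

  double UpdateAverage() {
    double weighted_sum = batch_total_value + regress_weight * init_avg;
    double total_weight = batch_num_samples + regress_weight;
    const double prev_weight = persistence_factor * aggregate_total_weight;
    weighted_sum += prev_weight * aggregate_weighted_avg;
    total_weight += prev_weight;
    aggregate_weighted_avg =
        total_weight > 0 ? weighted_sum / total_weight : init_avg;
    aggregate_total_weight = total_weight;
    batch_total_value = batch_num_samples = 0;
    return aggregate_weighted_avg;
  }
};

int main() {
  TimeAveragedStats stats(/*init=*/10.0, /*regress=*/0.5, /*persistence=*/0.3);
  stats.AddSample(100.0);
  stats.AddSample(200.0);
  // batch: sum=300, n=2; plus regression 0.5*10 -> (300 + 5) / 2.5 = 122
  std::printf("after first batch: %.1f\n", stats.UpdateAverage());
  stats.AddSample(50.0);
  std::printf("after second batch: %.3f\n", stats.UpdateAverage());
}
```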
+ * + */ + +#include + +#include "src/core/lib/iomgr/timer.h" + +#include "src/core/lib/iomgr/timer_manager.h" + +grpc_timer_vtable* grpc_timer_impl; + +void grpc_set_timer_impl(grpc_timer_vtable* vtable) { + grpc_timer_impl = vtable; +} + +void grpc_timer_init(grpc_timer* timer, grpc_millis deadline, + grpc_closure* closure) { + grpc_timer_impl->init(timer, deadline, closure); +} + +void grpc_timer_cancel(grpc_timer* timer) { grpc_timer_impl->cancel(timer); } + +grpc_timer_check_result grpc_timer_check(grpc_millis* next) { + return grpc_timer_impl->check(next); +} + +void grpc_timer_list_init() { grpc_timer_impl->list_init(); } + +void grpc_timer_list_shutdown() { grpc_timer_impl->list_shutdown(); } + +void grpc_timer_consume_kick() { grpc_timer_impl->consume_kick(); } diff --git a/src/core/lib/iomgr/timer_custom.cc b/src/core/lib/iomgr/timer_custom.cc new file mode 100644 index 00000000..90629665 --- /dev/null +++ b/src/core/lib/iomgr/timer_custom.cc @@ -0,0 +1,96 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/timer_custom.h" + +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/iomgr_custom.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/timer.h" + +static grpc_custom_timer_vtable* custom_timer_impl; + +void grpc_custom_timer_callback(grpc_custom_timer* t, + grpc_error_handle /*error*/) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_timer* timer = t->original; + GPR_ASSERT(timer->pending); + timer->pending = false; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, timer->closure, GRPC_ERROR_NONE); + custom_timer_impl->stop(t); + gpr_free(t); +} + +static void timer_init(grpc_timer* timer, grpc_millis deadline, + grpc_closure* closure) { + uint64_t timeout; + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + if (deadline <= grpc_core::ExecCtx::Get()->Now()) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + timer->pending = false; + return; + } else { + timeout = deadline - now; + } + timer->pending = true; + timer->closure = closure; + grpc_custom_timer* timer_wrapper = + static_cast(gpr_malloc(sizeof(grpc_custom_timer))); + timer_wrapper->timeout_ms = timeout; + timer->custom_timer = timer_wrapper; + timer_wrapper->original = timer; + custom_timer_impl->start(timer_wrapper); +} + +static void timer_cancel(grpc_timer* timer) { + GRPC_CUSTOM_IOMGR_ASSERT_SAME_THREAD(); + grpc_custom_timer* tw = static_cast(timer->custom_timer); + if (timer->pending) { + timer->pending = false; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, timer->closure, + GRPC_ERROR_CANCELLED); + custom_timer_impl->stop(tw); + gpr_free(tw); + } +} + +static grpc_timer_check_result timer_check(grpc_millis* /*next*/) { + return GRPC_TIMERS_NOT_CHECKED; +} + +static void timer_list_init() {} +static void 
timer_list_shutdown() {} + +static void timer_consume_kick(void) {} + +static grpc_timer_vtable custom_timer_vtable = { + timer_init, timer_cancel, timer_check, + timer_list_init, timer_list_shutdown, timer_consume_kick}; + +void grpc_custom_timer_init(grpc_custom_timer_vtable* impl) { + custom_timer_impl = impl; + grpc_set_timer_impl(&custom_timer_vtable); +} diff --git a/src/core/lib/iomgr/timer_generic.cc b/src/core/lib/iomgr/timer_generic.cc new file mode 100644 index 00000000..74761e24 --- /dev/null +++ b/src/core/lib/iomgr/timer_generic.cc @@ -0,0 +1,718 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/spinlock.h" +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/time_averaged_stats.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/timer_heap.h" + +#define INVALID_HEAP_INDEX 0xffffffffu + +#define ADD_DEADLINE_SCALE 0.33 +#define MIN_QUEUE_WINDOW_DURATION 0.01 +#define MAX_QUEUE_WINDOW_DURATION 1.0 + +grpc_core::TraceFlag grpc_timer_trace(false, "timer"); +grpc_core::TraceFlag grpc_timer_check_trace(false, "timer_check"); + +/* A "timer shard". Contains a 'heap' and a 'list' of timers. All timers with + * deadlines earlier than 'queue_deadline_cap' are maintained in the heap and + * others are maintained in the list (unordered). This helps to keep the number + * of elements in the heap low. + * + * The 'queue_deadline_cap' gets recomputed periodically based on the timer + * stats maintained in 'stats' and the relevant timers are then moved from the + * 'list' to 'heap'. + */ +struct timer_shard { + gpr_mu mu; + grpc_time_averaged_stats stats; + /* All and only timers with deadlines < this will be in the heap. */ + grpc_millis queue_deadline_cap; + /* The deadline of the next timer due in this shard. */ + grpc_millis min_deadline; + /* Index of this timer_shard in the g_shard_queue. */ + uint32_t shard_queue_index; + /* This holds all timers with deadlines < queue_deadline_cap. Timers in this + list have the top bit of their deadline set to 0. */ + grpc_timer_heap heap; + /* This holds timers whose deadline is >= queue_deadline_cap. */ + grpc_timer list; +}; +static size_t g_num_shards; + +/* Array of timer shards. Whenever a timer (grpc_timer *) is added, its address + * is hashed to select the timer shard to add the timer to */ +static timer_shard* g_shards; + +/* Maintains a sorted list of timer shards (sorted by their min_deadline, i.e + * the deadline of the next timer in each shard). 
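The generic timer code above spreads timers across shards by hashing the grpc_timer address, so unrelated timers contend on different mutexes. A minimal standalone sketch of that sharding decision, with hypothetical names and the per-shard heap/list elided:

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>
#include <mutex>

// A fixed pool of shards, each with its own lock; a timer's address decides
// which shard (and therefore which mutex and heap/list) it belongs to.
constexpr std::size_t kNumShards = 8;

struct TimerShard {
  std::mutex mu;
  // ... heap of near-deadline timers and unordered list of far ones ...
};

static TimerShard g_shards[kNumShards];

static TimerShard* ShardFor(const void* timer_address) {
  return &g_shards[std::hash<const void*>{}(timer_address) % kNumShards];
}

int main() {
  int timer_a, timer_b;  // stand-ins for two grpc_timer objects
  std::printf("timer_a -> shard %td\n", ShardFor(&timer_a) - g_shards);
  std::printf("timer_b -> shard %td\n", ShardFor(&timer_b) - g_shards);
}
```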
+ * Access to this is protected by g_shared_mutables.mu */ +static timer_shard** g_shard_queue; + +#ifndef NDEBUG + +/* == DEBUG ONLY: hash table for duplicate timer detection == */ + +#define NUM_HASH_BUCKETS 1009 /* Prime number close to 1000 */ + +static gpr_mu g_hash_mu[NUM_HASH_BUCKETS]; /* One mutex per bucket */ +static grpc_timer* g_timer_ht[NUM_HASH_BUCKETS] = {nullptr}; + +static void init_timer_ht() { + for (int i = 0; i < NUM_HASH_BUCKETS; i++) { + gpr_mu_init(&g_hash_mu[i]); + } +} + +static void destroy_timer_ht() { + for (int i = 0; i < NUM_HASH_BUCKETS; i++) { + gpr_mu_destroy(&g_hash_mu[i]); + } +} + +static bool is_in_ht(grpc_timer* t) { + size_t i = grpc_core::HashPointer(t, NUM_HASH_BUCKETS); + + gpr_mu_lock(&g_hash_mu[i]); + grpc_timer* p = g_timer_ht[i]; + while (p != nullptr && p != t) { + p = p->hash_table_next; + } + gpr_mu_unlock(&g_hash_mu[i]); + + return (p == t); +} + +static void add_to_ht(grpc_timer* t) { + GPR_ASSERT(!t->hash_table_next); + size_t i = grpc_core::HashPointer(t, NUM_HASH_BUCKETS); + + gpr_mu_lock(&g_hash_mu[i]); + grpc_timer* p = g_timer_ht[i]; + while (p != nullptr && p != t) { + p = p->hash_table_next; + } + + if (p == t) { + grpc_closure* c = t->closure; + gpr_log(GPR_ERROR, + "** Duplicate timer (%p) being added. Closure: (%p), created at: " + "(%s:%d), scheduled at: (%s:%d) **", + t, c, c->file_created, c->line_created, c->file_initiated, + c->line_initiated); + abort(); + } + + /* Timer not present in the bucket. Insert at head of the list */ + t->hash_table_next = g_timer_ht[i]; + g_timer_ht[i] = t; + gpr_mu_unlock(&g_hash_mu[i]); +} + +static void remove_from_ht(grpc_timer* t) { + size_t i = grpc_core::HashPointer(t, NUM_HASH_BUCKETS); + bool removed = false; + + gpr_mu_lock(&g_hash_mu[i]); + if (g_timer_ht[i] == t) { + g_timer_ht[i] = g_timer_ht[i]->hash_table_next; + removed = true; + } else if (g_timer_ht[i] != nullptr) { + grpc_timer* p = g_timer_ht[i]; + while (p->hash_table_next != nullptr && p->hash_table_next != t) { + p = p->hash_table_next; + } + + if (p->hash_table_next == t) { + p->hash_table_next = t->hash_table_next; + removed = true; + } + } + gpr_mu_unlock(&g_hash_mu[i]); + + if (!removed) { + grpc_closure* c = t->closure; + gpr_log(GPR_ERROR, + "** Removing timer (%p) that is not added to hash table. Closure " + "(%p), created at: (%s:%d), scheduled at: (%s:%d) **", + t, c, c->file_created, c->line_created, c->file_initiated, + c->line_initiated); + abort(); + } + + t->hash_table_next = nullptr; +} + +/* If a timer is added to a timer shard (either heap or a list), it must + * be pending. A timer is added to hash table only-if it is added to the + * timer shard. + * Therefore, if timer->pending is false, it cannot be in hash table */ +static void validate_non_pending_timer(grpc_timer* t) { + if (!t->pending && is_in_ht(t)) { + grpc_closure* c = t->closure; + gpr_log(GPR_ERROR, + "** gpr_timer_cancel() called on a non-pending timer (%p) which " + "is in the hash table. 
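// The NDEBUG-only hash table above exists purely to catch API misuse: adding
// the same grpc_timer twice, or cancelling one that was never added. A small
// sketch of the same idea, hashing the pointer value into a fixed number of
// buckets with one mutex per bucket (toy types, not the gRPC helpers):
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <unordered_set>

constexpr std::size_t kBuckets = 1009;  // prime, like NUM_HASH_BUCKETS above

struct DebugTimerSet {
  std::mutex mu[kBuckets];
  std::unordered_set<const void*> bucket[kBuckets];

  static std::size_t BucketOf(const void* p) {
    return (reinterpret_cast<std::uintptr_t>(p) >> 4) % kBuckets;
  }
  void Add(const void* t) {
    std::size_t i = BucketOf(t);
    std::lock_guard<std::mutex> lock(mu[i]);
    const bool inserted = bucket[i].insert(t).second;
    assert(inserted && "duplicate timer added");
    (void)inserted;
  }
  void Remove(const void* t) {
    std::size_t i = BucketOf(t);
    std::lock_guard<std::mutex> lock(mu[i]);
    const bool erased = bucket[i].erase(t) == 1;
    assert(erased && "removing a timer that was never added");
    (void)erased;
  }
};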
Closure: (%p), created at: (%s:%d), " + "scheduled at: (%s:%d) **", + t, c, c->file_created, c->line_created, c->file_initiated, + c->line_initiated); + abort(); + } +} + +#define INIT_TIMER_HASH_TABLE() init_timer_ht() +#define DESTROY_TIMER_HASH_TABLE() destroy_timer_ht() +#define ADD_TO_HASH_TABLE(t) add_to_ht((t)) +#define REMOVE_FROM_HASH_TABLE(t) remove_from_ht((t)) +#define VALIDATE_NON_PENDING_TIMER(t) validate_non_pending_timer((t)) + +#else + +#define INIT_TIMER_HASH_TABLE() +#define DESTROY_TIMER_HASH_TABLE() +#define ADD_TO_HASH_TABLE(t) +#define REMOVE_FROM_HASH_TABLE(t) +#define VALIDATE_NON_PENDING_TIMER(t) + +#endif + +/* Thread local variable that stores the deadline of the next timer the thread + * has last-seen. This is an optimization to prevent the thread from checking + * shared_mutables.min_timer (which requires acquiring shared_mutables.mu lock, + * an expensive operation) */ +static GPR_THREAD_LOCAL(grpc_millis) g_last_seen_min_timer; + +struct shared_mutables { + /* The deadline of the next timer due across all timer shards */ + grpc_millis min_timer; + /* Allow only one run_some_expired_timers at once */ + gpr_spinlock checker_mu; + bool initialized; + /* Protects g_shard_queue (and the shared_mutables struct itself) */ + gpr_mu mu; +} GPR_ALIGN_STRUCT(GPR_CACHELINE_SIZE); + +static struct shared_mutables g_shared_mutables; + +static grpc_millis saturating_add(grpc_millis a, grpc_millis b) { + if (a > GRPC_MILLIS_INF_FUTURE - b) { + return GRPC_MILLIS_INF_FUTURE; + } + return a + b; +} + +static grpc_timer_check_result run_some_expired_timers(grpc_millis now, + grpc_millis* next, + grpc_error_handle error); + +static grpc_millis compute_min_deadline(timer_shard* shard) { + return grpc_timer_heap_is_empty(&shard->heap) + ? 
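// saturating_add above keeps deadline arithmetic from overflowing past
// GRPC_MILLIS_INF_FUTURE when the queue cap is padded. The same guard written
// against std::numeric_limits (generic sketch; assumes b >= 0, as in the
// callers above):
#include <cstdint>
#include <limits>

int64_t SaturatingAddMs(int64_t a, int64_t b) {
  const int64_t kInfFuture = std::numeric_limits<int64_t>::max();
  if (a > kInfFuture - b) return kInfFuture;  // a + b would overflow: clamp
  return a + b;
}
// e.g. SaturatingAddMs(kInfFuture - 1, 5) == kInfFuture,
//      SaturatingAddMs(10, 5) == 15.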
saturating_add(shard->queue_deadline_cap, 1) + : grpc_timer_heap_top(&shard->heap)->deadline; +} + +static void timer_list_init() { + uint32_t i; + + g_num_shards = grpc_core::Clamp(2 * gpr_cpu_num_cores(), 1u, 32u); + g_shards = + static_cast(gpr_zalloc(g_num_shards * sizeof(*g_shards))); + g_shard_queue = static_cast( + gpr_zalloc(g_num_shards * sizeof(*g_shard_queue))); + + g_shared_mutables.initialized = true; + g_shared_mutables.checker_mu = GPR_SPINLOCK_INITIALIZER; + gpr_mu_init(&g_shared_mutables.mu); + g_shared_mutables.min_timer = grpc_core::ExecCtx::Get()->Now(); + + g_last_seen_min_timer = 0; + + for (i = 0; i < g_num_shards; i++) { + timer_shard* shard = &g_shards[i]; + gpr_mu_init(&shard->mu); + grpc_time_averaged_stats_init(&shard->stats, 1.0 / ADD_DEADLINE_SCALE, 0.1, + 0.5); + shard->queue_deadline_cap = g_shared_mutables.min_timer; + shard->shard_queue_index = i; + grpc_timer_heap_init(&shard->heap); + shard->list.next = shard->list.prev = &shard->list; + shard->min_deadline = compute_min_deadline(shard); + g_shard_queue[i] = shard; + } + + INIT_TIMER_HASH_TABLE(); +} + +static void timer_list_shutdown() { + size_t i; + run_some_expired_timers( + GRPC_MILLIS_INF_FUTURE, nullptr, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Timer list shutdown")); + for (i = 0; i < g_num_shards; i++) { + timer_shard* shard = &g_shards[i]; + gpr_mu_destroy(&shard->mu); + grpc_timer_heap_destroy(&shard->heap); + } + gpr_mu_destroy(&g_shared_mutables.mu); + gpr_free(g_shards); + gpr_free(g_shard_queue); + g_shared_mutables.initialized = false; + + DESTROY_TIMER_HASH_TABLE(); +} + +/* returns true if the first element in the list */ +static void list_join(grpc_timer* head, grpc_timer* timer) { + timer->next = head; + timer->prev = head->prev; + timer->next->prev = timer->prev->next = timer; +} + +static void list_remove(grpc_timer* timer) { + timer->next->prev = timer->prev; + timer->prev->next = timer->next; +} + +static void swap_adjacent_shards_in_queue(uint32_t first_shard_queue_index) { + timer_shard* temp; + temp = g_shard_queue[first_shard_queue_index]; + g_shard_queue[first_shard_queue_index] = + g_shard_queue[first_shard_queue_index + 1]; + g_shard_queue[first_shard_queue_index + 1] = temp; + g_shard_queue[first_shard_queue_index]->shard_queue_index = + first_shard_queue_index; + g_shard_queue[first_shard_queue_index + 1]->shard_queue_index = + first_shard_queue_index + 1; +} + +static void note_deadline_change(timer_shard* shard) { + while (shard->shard_queue_index > 0 && + shard->min_deadline < + g_shard_queue[shard->shard_queue_index - 1]->min_deadline) { + swap_adjacent_shards_in_queue(shard->shard_queue_index - 1); + } + while (shard->shard_queue_index < g_num_shards - 1 && + shard->min_deadline > + g_shard_queue[shard->shard_queue_index + 1]->min_deadline) { + swap_adjacent_shards_in_queue(shard->shard_queue_index); + } +} + +void grpc_timer_init_unset(grpc_timer* timer) { timer->pending = false; } + +static void timer_init(grpc_timer* timer, grpc_millis deadline, + grpc_closure* closure) { + int is_first_timer = 0; + timer_shard* shard = &g_shards[grpc_core::HashPointer(timer, g_num_shards)]; + timer->closure = closure; + timer->deadline = deadline; + +#ifndef NDEBUG + timer->hash_table_next = nullptr; +#endif + + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_trace)) { + gpr_log(GPR_INFO, "TIMER %p: SET %" PRId64 " now %" PRId64 " call %p[%p]", + timer, deadline, grpc_core::ExecCtx::Get()->Now(), closure, + closure->cb); + } + + if (!g_shared_mutables.initialized) { + timer->pending = 
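// list_join/list_remove above implement a circular, intrusive doubly linked
// list: shard->list acts as a sentinel whose next/prev initially point at
// itself, so insertion and removal need no null checks or special cases. A
// standalone sketch with a toy node type:
struct ToyNode {
  ToyNode* next;
  ToyNode* prev;
  ToyNode() : next(this), prev(this) {}  // a lone node is its own sentinel
};

// Insert `node` just before `head` (i.e. at the tail of head's list).
void ListJoin(ToyNode* head, ToyNode* node) {
  node->next = head;
  node->prev = head->prev;
  node->prev->next = node;
  node->next->prev = node;
}

void ListRemove(ToyNode* node) {
  node->next->prev = node->prev;
  node->prev->next = node->next;
}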
false; + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, timer->closure, + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Attempt to create timer before initialization")); + return; + } + + gpr_mu_lock(&shard->mu); + timer->pending = true; + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + if (deadline <= now) { + timer->pending = false; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, timer->closure, GRPC_ERROR_NONE); + gpr_mu_unlock(&shard->mu); + /* early out */ + return; + } + + grpc_time_averaged_stats_add_sample( + &shard->stats, static_cast(deadline - now) / 1000.0); + + ADD_TO_HASH_TABLE(timer); + + if (deadline < shard->queue_deadline_cap) { + is_first_timer = grpc_timer_heap_add(&shard->heap, timer); + } else { + timer->heap_index = INVALID_HEAP_INDEX; + list_join(&shard->list, timer); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_trace)) { + gpr_log(GPR_INFO, + " .. add to shard %d with queue_deadline_cap=%" PRId64 + " => is_first_timer=%s", + static_cast(shard - g_shards), shard->queue_deadline_cap, + is_first_timer ? "true" : "false"); + } + gpr_mu_unlock(&shard->mu); + + /* Deadline may have decreased, we need to adjust the main queue. Note + that there is a potential racy unlocked region here. There could be a + reordering of multiple grpc_timer_init calls, at this point, but the < test + below should ensure that we err on the side of caution. There could + also be a race with grpc_timer_check, which might beat us to the lock. In + that case, it is possible that the timer that we added will have already + run by the time we hold the lock, but that too is a safe error. + Finally, it's possible that the grpc_timer_check that intervened failed to + trigger the new timer because the min_deadline hadn't yet been reduced. + In that case, the timer will simply have to wait for the next + grpc_timer_check. */ + if (is_first_timer) { + gpr_mu_lock(&g_shared_mutables.mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_trace)) { + gpr_log(GPR_INFO, " .. old shard min_deadline=%" PRId64, + shard->min_deadline); + } + if (deadline < shard->min_deadline) { + grpc_millis old_min_deadline = g_shard_queue[0]->min_deadline; + shard->min_deadline = deadline; + note_deadline_change(shard); + if (shard->shard_queue_index == 0 && deadline < old_min_deadline) { +#if GPR_ARCH_64 + // TODO(sreek): Using c-style cast here. static_cast<> gives an error + // (on mac platforms complaining that gpr_atm* is (long *) while + // (&g_shared_mutables.min_timer) is a (long long *). The cast should be + // safe since we know that both are pointer types and 64-bit wide. + gpr_atm_no_barrier_store((gpr_atm*)(&g_shared_mutables.min_timer), + deadline); +#else + // On 32-bit systems, gpr_atm_no_barrier_store does not work on 64-bit + // types (like grpc_millis). So all reads and writes to + // g_shared_mutables.min_timer varialbe under g_shared_mutables.mu + g_shared_mutables.min_timer = deadline; +#endif + grpc_kick_poller(); + } + } + gpr_mu_unlock(&g_shared_mutables.mu); + } +} + +static void timer_consume_kick(void) { + /* Force re-evaluation of last seen min */ + g_last_seen_min_timer = 0; +} + +static void timer_cancel(grpc_timer* timer) { + if (!g_shared_mutables.initialized) { + /* must have already been cancelled, also the shard mutex is invalid */ + return; + } + + timer_shard* shard = &g_shards[grpc_core::HashPointer(timer, g_num_shards)]; + gpr_mu_lock(&shard->mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_trace)) { + gpr_log(GPR_INFO, "TIMER %p: CANCEL pending=%s", timer, + timer->pending ? 
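// The #if GPR_ARCH_64 block above publishes the new global minimum deadline
// with gpr_atm_no_barrier_store, and falls back to taking g_shared_mutables.mu
// on 32-bit targets where gpr_atm cannot hold a 64-bit grpc_millis. As a
// sketch only (not how this file is written), std::atomic<int64_t> expresses
// the same publish/read intent portably:
#include <atomic>
#include <cstdint>

std::atomic<int64_t> g_min_deadline_ms{0};

// Called with the shard-queue lock held after an earlier timer was added.
void PublishMinDeadline(int64_t deadline_ms) {
  g_min_deadline_ms.store(deadline_ms, std::memory_order_relaxed);
}

// Cheap read on the timer-check fast path.
int64_t LoadMinDeadline() {
  return g_min_deadline_ms.load(std::memory_order_relaxed);
}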
"true" : "false"); + } + + if (timer->pending) { + REMOVE_FROM_HASH_TABLE(timer); + + grpc_core::ExecCtx::Run(DEBUG_LOCATION, timer->closure, + GRPC_ERROR_CANCELLED); + timer->pending = false; + if (timer->heap_index == INVALID_HEAP_INDEX) { + list_remove(timer); + } else { + grpc_timer_heap_remove(&shard->heap, timer); + } + } else { + VALIDATE_NON_PENDING_TIMER(timer); + } + gpr_mu_unlock(&shard->mu); +} + +/* Rebalances the timer shard by computing a new 'queue_deadline_cap' and moving + all relevant timers in shard->list (i.e timers with deadlines earlier than + 'queue_deadline_cap') into into shard->heap. + Returns 'true' if shard->heap has at least ONE element + REQUIRES: shard->mu locked */ +static bool refill_heap(timer_shard* shard, grpc_millis now) { + /* Compute the new queue window width and bound by the limits: */ + double computed_deadline_delta = + grpc_time_averaged_stats_update_average(&shard->stats) * + ADD_DEADLINE_SCALE; + double deadline_delta = + grpc_core::Clamp(computed_deadline_delta, MIN_QUEUE_WINDOW_DURATION, + MAX_QUEUE_WINDOW_DURATION); + grpc_timer *timer, *next; + + /* Compute the new cap and put all timers under it into the queue: */ + shard->queue_deadline_cap = + saturating_add(std::max(now, shard->queue_deadline_cap), + static_cast(deadline_delta * 1000.0)); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, " .. shard[%d]->queue_deadline_cap --> %" PRId64, + static_cast(shard - g_shards), shard->queue_deadline_cap); + } + for (timer = shard->list.next; timer != &shard->list; timer = next) { + next = timer->next; + + if (timer->deadline < shard->queue_deadline_cap) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, " .. add timer with deadline %" PRId64 " to heap", + timer->deadline); + } + list_remove(timer); + grpc_timer_heap_add(&shard->heap, timer); + } + } + return !grpc_timer_heap_is_empty(&shard->heap); +} + +/* This pops the next non-cancelled timer with deadline <= now from the + queue, or returns NULL if there isn't one. + REQUIRES: shard->mu locked */ +static grpc_timer* pop_one(timer_shard* shard, grpc_millis now) { + grpc_timer* timer; + for (;;) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, " .. shard[%d]: heap_empty=%s", + static_cast(shard - g_shards), + grpc_timer_heap_is_empty(&shard->heap) ? "true" : "false"); + } + if (grpc_timer_heap_is_empty(&shard->heap)) { + if (now < shard->queue_deadline_cap) return nullptr; + if (!refill_heap(shard, now)) return nullptr; + } + timer = grpc_timer_heap_top(&shard->heap); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, + " .. 
check top timer deadline=%" PRId64 " now=%" PRId64, + timer->deadline, now); + } + if (timer->deadline > now) return nullptr; + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_trace)) { + gpr_log(GPR_INFO, "TIMER %p: FIRE %" PRId64 "ms late", timer, + now - timer->deadline); + } + timer->pending = false; + grpc_timer_heap_pop(&shard->heap); + return timer; + } +} + +/* REQUIRES: shard->mu unlocked */ +static size_t pop_timers(timer_shard* shard, grpc_millis now, + grpc_millis* new_min_deadline, + grpc_error_handle error) { + size_t n = 0; + grpc_timer* timer; + gpr_mu_lock(&shard->mu); + while ((timer = pop_one(shard, now))) { + REMOVE_FROM_HASH_TABLE(timer); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, timer->closure, + GRPC_ERROR_REF(error)); + n++; + } + *new_min_deadline = compute_min_deadline(shard); + gpr_mu_unlock(&shard->mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, " .. shard[%d] popped %" PRIdPTR, + static_cast(shard - g_shards), n); + } + return n; +} + +static grpc_timer_check_result run_some_expired_timers( + grpc_millis now, grpc_millis* next, grpc_error_handle error) { + grpc_timer_check_result result = GRPC_TIMERS_NOT_CHECKED; + +#if GPR_ARCH_64 + // TODO(sreek): Using c-style cast here. static_cast<> gives an error (on + // mac platforms complaining that gpr_atm* is (long *) while + // (&g_shared_mutables.min_timer) is a (long long *). The cast should be + // safe since we know that both are pointer types and 64-bit wide + grpc_millis min_timer = static_cast( + gpr_atm_no_barrier_load((gpr_atm*)(&g_shared_mutables.min_timer))); +#else + // On 32-bit systems, gpr_atm_no_barrier_load does not work on 64-bit types + // (like grpc_millis). So all reads and writes to g_shared_mutables.min_timer + // are done under g_shared_mutables.mu + gpr_mu_lock(&g_shared_mutables.mu); + grpc_millis min_timer = g_shared_mutables.min_timer; + gpr_mu_unlock(&g_shared_mutables.mu); +#endif + g_last_seen_min_timer = min_timer; + + if (now < min_timer) { + if (next != nullptr) *next = std::min(*next, min_timer); + return GRPC_TIMERS_CHECKED_AND_EMPTY; + } + + if (gpr_spinlock_trylock(&g_shared_mutables.checker_mu)) { + gpr_mu_lock(&g_shared_mutables.mu); + result = GRPC_TIMERS_CHECKED_AND_EMPTY; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, " .. shard[%d]->min_deadline = %" PRId64, + static_cast(g_shard_queue[0] - g_shards), + g_shard_queue[0]->min_deadline); + } + + while (g_shard_queue[0]->min_deadline < now || + (now != GRPC_MILLIS_INF_FUTURE && + g_shard_queue[0]->min_deadline == now)) { + grpc_millis new_min_deadline; + + /* For efficiency, we pop as many available timers as we can from the + shard. This may violate perfect timer deadline ordering, but that + shouldn't be a big deal because we don't make ordering guarantees. */ + if (pop_timers(g_shard_queue[0], now, &new_min_deadline, error) > 0) { + result = GRPC_TIMERS_FIRED; + } + + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, + " .. result --> %d" + ", shard[%d]->min_deadline %" PRId64 " --> %" PRId64 + ", now=%" PRId64, + result, static_cast(g_shard_queue[0] - g_shards), + g_shard_queue[0]->min_deadline, new_min_deadline, now); + } + + /* An grpc_timer_init() on the shard could intervene here, adding a new + timer that is earlier than new_min_deadline. However, + grpc_timer_init() will block on the mutex before it can call + set_min_deadline, so this one will complete first and then the Addtimer + will reduce the min_deadline (perhaps unnecessarily). 
*/ + g_shard_queue[0]->min_deadline = new_min_deadline; + note_deadline_change(g_shard_queue[0]); + } + + if (next) { + *next = std::min(*next, g_shard_queue[0]->min_deadline); + } + +#if GPR_ARCH_64 + // TODO(sreek): Using c-style cast here. static_cast<> gives an error (on + // mac platforms complaining that gpr_atm* is (long *) while + // (&g_shared_mutables.min_timer) is a (long long *). The cast should be + // safe since we know that both are pointer types and 64-bit wide + gpr_atm_no_barrier_store((gpr_atm*)(&g_shared_mutables.min_timer), + g_shard_queue[0]->min_deadline); +#else + // On 32-bit systems, gpr_atm_no_barrier_store does not work on 64-bit + // types (like grpc_millis). So all reads and writes to + // g_shared_mutables.min_timer are done under g_shared_mutables.mu + g_shared_mutables.min_timer = g_shard_queue[0]->min_deadline; +#endif + gpr_mu_unlock(&g_shared_mutables.mu); + gpr_spinlock_unlock(&g_shared_mutables.checker_mu); + } + + GRPC_ERROR_UNREF(error); + + return result; +} + +static grpc_timer_check_result timer_check(grpc_millis* next) { + // prelude + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + + /* fetch from a thread-local first: this avoids contention on a globally + mutable cacheline in the common case */ + grpc_millis min_timer = g_last_seen_min_timer; + + if (now < min_timer) { + if (next != nullptr) { + *next = std::min(*next, min_timer); + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "TIMER CHECK SKIP: now=%" PRId64 " min_timer=%" PRId64, + now, min_timer); + } + return GRPC_TIMERS_CHECKED_AND_EMPTY; + } + + grpc_error_handle shutdown_error = + now != GRPC_MILLIS_INF_FUTURE + ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING("Shutting down timer system"); + + // tracing + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + std::string next_str; + if (next == nullptr) { + next_str = "NULL"; + } else { + next_str = absl::StrCat(*next); + } +#if GPR_ARCH_64 + gpr_log(GPR_INFO, + "TIMER CHECK BEGIN: now=%" PRId64 " next=%s tls_min=%" PRId64 + " glob_min=%" PRId64, + now, next_str.c_str(), min_timer, + static_cast(gpr_atm_no_barrier_load( + (gpr_atm*)(&g_shared_mutables.min_timer)))); +#else + gpr_log(GPR_INFO, "TIMER CHECK BEGIN: now=%" PRId64 " next=%s min=%" PRId64, + now, next_str.c_str(), min_timer); +#endif + } + // actual code + grpc_timer_check_result r = + run_some_expired_timers(now, next, shutdown_error); + // tracing + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + std::string next_str; + if (next == nullptr) { + next_str = "NULL"; + } else { + next_str = absl::StrCat(*next); + } + gpr_log(GPR_INFO, "TIMER CHECK END: r=%d; next=%s", r, next_str.c_str()); + } + return r; +} + +grpc_timer_vtable grpc_generic_timer_vtable = { + timer_init, timer_cancel, timer_check, + timer_list_init, timer_list_shutdown, timer_consume_kick}; diff --git a/src/core/lib/iomgr/timer_heap.cc b/src/core/lib/iomgr/timer_heap.cc new file mode 100644 index 00000000..604bfe1b --- /dev/null +++ b/src/core/lib/iomgr/timer_heap.cc @@ -0,0 +1,134 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
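// timer_check above consults g_last_seen_min_timer first: a thread-local copy
// of the global minimum deadline lets the common "nothing is due yet" case
// return without touching the shared (and contended) cacheline. A minimal
// sketch of that pattern, using std::atomic and thread_local in place of
// gpr_atm and GPR_THREAD_LOCAL:
#include <atomic>
#include <cstdint>

std::atomic<int64_t> g_shared_min_deadline{INT64_MAX};
thread_local int64_t tls_last_seen_min = 0;

bool AnythingPossiblyDue(int64_t now_ms) {
  if (now_ms < tls_last_seen_min) return false;  // fast path, no shared state
  // Slow path: refresh the thread-local copy from the shared value.
  tls_last_seen_min = g_shared_min_deadline.load(std::memory_order_relaxed);
  return now_ms >= tls_last_seen_min;
}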
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/timer_heap.h" + +#include + +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/port.h" + +/* Adjusts a heap so as to move a hole at position i closer to the root, + until a suitable position is found for element t. Then, copies t into that + position. This functor is called each time immediately after modifying a + value in the underlying container, with the offset of the modified element as + its argument. */ +static void adjust_upwards(grpc_timer** first, uint32_t i, grpc_timer* t) { + while (i > 0) { + uint32_t parent = static_cast((static_cast(i) - 1) / 2); + if (first[parent]->deadline <= t->deadline) break; + first[i] = first[parent]; + first[i]->heap_index = i; + i = parent; + } + first[i] = t; + t->heap_index = i; +} + +/* Adjusts a heap so as to move a hole at position i farther away from the root, + until a suitable position is found for element t. Then, copies t into that + position. */ +static void adjust_downwards(grpc_timer** first, uint32_t i, uint32_t length, + grpc_timer* t) { + for (;;) { + uint32_t left_child = 1u + 2u * i; + if (left_child >= length) break; + uint32_t right_child = left_child + 1; + uint32_t next_i = right_child < length && first[left_child]->deadline > + first[right_child]->deadline + ? right_child + : left_child; + if (t->deadline <= first[next_i]->deadline) break; + first[i] = first[next_i]; + first[i]->heap_index = i; + i = next_i; + } + first[i] = t; + t->heap_index = i; +} + +#define SHRINK_MIN_ELEMS 8 +#define SHRINK_FULLNESS_FACTOR 2 + +static void maybe_shrink(grpc_timer_heap* heap) { + if (heap->timer_count >= 8 && + heap->timer_count <= heap->timer_capacity / SHRINK_FULLNESS_FACTOR / 2) { + heap->timer_capacity = heap->timer_count * SHRINK_FULLNESS_FACTOR; + heap->timers = static_cast( + gpr_realloc(heap->timers, heap->timer_capacity * sizeof(grpc_timer*))); + } +} + +static void note_changed_priority(grpc_timer_heap* heap, grpc_timer* timer) { + uint32_t i = timer->heap_index; + uint32_t parent = static_cast((static_cast(i) - 1) / 2); + if (heap->timers[parent]->deadline > timer->deadline) { + adjust_upwards(heap->timers, i, timer); + } else { + adjust_downwards(heap->timers, i, heap->timer_count, timer); + } +} + +void grpc_timer_heap_init(grpc_timer_heap* heap) { + memset(heap, 0, sizeof(*heap)); +} + +void grpc_timer_heap_destroy(grpc_timer_heap* heap) { gpr_free(heap->timers); } + +bool grpc_timer_heap_add(grpc_timer_heap* heap, grpc_timer* timer) { + if (heap->timer_count == heap->timer_capacity) { + heap->timer_capacity = + std::max(heap->timer_capacity + 1, heap->timer_capacity * 3 / 2); + heap->timers = static_cast( + gpr_realloc(heap->timers, heap->timer_capacity * sizeof(grpc_timer*))); + } + timer->heap_index = heap->timer_count; + adjust_upwards(heap->timers, heap->timer_count, timer); + heap->timer_count++; + return timer->heap_index == 0; +} + +void grpc_timer_heap_remove(grpc_timer_heap* heap, grpc_timer* timer) { + uint32_t i = timer->heap_index; + if (i == heap->timer_count - 1) { + heap->timer_count--; + maybe_shrink(heap); + 
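// adjust_upwards/adjust_downwards above are the classic array binary-heap
// sift operations (parent (i-1)/2, children 2i+1 and 2i+2); the only twist is
// that every moved element's heap_index is updated so a timer can later be
// removed by index. Sift-up over plain deadlines, without the index
// bookkeeping (sketch):
#include <cstddef>
#include <cstdint>
#include <vector>

void SiftUp(std::vector<int64_t>& heap, std::size_t i) {
  const int64_t value = heap[i];
  while (i > 0) {
    const std::size_t parent = (i - 1) / 2;
    if (heap[parent] <= value) break;  // heap property holds, stop
    heap[i] = heap[parent];            // move the "hole" one level up
    i = parent;
  }
  heap[i] = value;
}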
return; + } + heap->timers[i] = heap->timers[heap->timer_count - 1]; + heap->timers[i]->heap_index = i; + heap->timer_count--; + maybe_shrink(heap); + note_changed_priority(heap, heap->timers[i]); +} + +bool grpc_timer_heap_is_empty(grpc_timer_heap* heap) { + return heap->timer_count == 0; +} + +grpc_timer* grpc_timer_heap_top(grpc_timer_heap* heap) { + return heap->timers[0]; +} + +void grpc_timer_heap_pop(grpc_timer_heap* heap) { + grpc_timer_heap_remove(heap, grpc_timer_heap_top(heap)); +} diff --git a/src/core/lib/iomgr/timer_manager.cc b/src/core/lib/iomgr/timer_manager.cc new file mode 100644 index 00000000..317e4634 --- /dev/null +++ b/src/core/lib/iomgr/timer_manager.cc @@ -0,0 +1,363 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/timer_manager.h" + +#include + +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/timer.h" + +struct completed_thread { + grpc_core::Thread thd; + completed_thread* next; +}; + +extern grpc_core::TraceFlag grpc_timer_check_trace; + +// global mutex +static gpr_mu g_mu; +// are we multi-threaded +static bool g_threaded; +// cv to wait until a thread is needed +static gpr_cv g_cv_wait; +// cv for notification when threading ends +static gpr_cv g_cv_shutdown; +// number of threads in the system +static int g_thread_count; +// number of threads sitting around waiting +static int g_waiter_count; +// linked list of threads that have completed (and need joining) +static completed_thread* g_completed_threads; +// was the manager kicked by the timer system +static bool g_kicked; +// is there a thread waiting until the next timer should fire? 
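// grpc_timer_heap_remove above deletes from the middle of the heap by moving
// the last element into the vacated slot and re-sifting it in whichever
// direction the heap property requires; removing the final element needs no
// fix-up at all. The same idea over a plain min-heap of deadlines (sketch):
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

void HeapRemoveAt(std::vector<int64_t>& heap, std::size_t i) {
  if (i == heap.size() - 1) {  // removing the tail: just shrink
    heap.pop_back();
    return;
  }
  heap[i] = heap.back();
  heap.pop_back();
  if (i > 0 && heap[i] < heap[(i - 1) / 2]) {
    // Smaller than its parent: sift up.
    while (i > 0 && heap[i] < heap[(i - 1) / 2]) {
      std::swap(heap[i], heap[(i - 1) / 2]);
      i = (i - 1) / 2;
    }
  } else {
    // Otherwise sift down toward the smaller child.
    for (;;) {
      const std::size_t left = 2 * i + 1;
      if (left >= heap.size()) break;
      std::size_t child = left;
      if (left + 1 < heap.size() && heap[left + 1] < heap[left]) child = left + 1;
      if (heap[i] <= heap[child]) break;
      std::swap(heap[i], heap[child]);
      i = child;
    }
  }
}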
+static bool g_has_timed_waiter; +// the deadline of the current timed waiter thread (only relevant if +// g_has_timed_waiter is true) +static grpc_millis g_timed_waiter_deadline; +// generation counter to track which thread is waiting for the next timer +static uint64_t g_timed_waiter_generation; +// number of timer wakeups +static uint64_t g_wakeups; + +static void timer_thread(void* completed_thread_ptr); + +static void gc_completed_threads(void) { + if (g_completed_threads != nullptr) { + completed_thread* to_gc = g_completed_threads; + g_completed_threads = nullptr; + gpr_mu_unlock(&g_mu); + while (to_gc != nullptr) { + to_gc->thd.Join(); + completed_thread* next = to_gc->next; + gpr_free(to_gc); + to_gc = next; + } + gpr_mu_lock(&g_mu); + } +} + +static void start_timer_thread_and_unlock(void) { + GPR_ASSERT(g_threaded); + ++g_waiter_count; + ++g_thread_count; + gpr_mu_unlock(&g_mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "Spawn timer thread"); + } + completed_thread* ct = + static_cast(gpr_malloc(sizeof(*ct))); + ct->thd = grpc_core::Thread("grpc_global_timer", timer_thread, ct); + ct->thd.Start(); +} + +void grpc_timer_manager_tick() { + grpc_core::ExecCtx exec_ctx; + grpc_timer_check(nullptr); +} + +static void run_some_timers() { + // In the case of timers, the ExecCtx for the thread is declared + // in the timer thread itself, but this is the point where we + // could start seeing application-level callbacks. No need to + // create a new ExecCtx, though, since there already is one and it is + // flushed (but not destructed) in this function itself + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx( + GRPC_APP_CALLBACK_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + + // if there's something to execute... + gpr_mu_lock(&g_mu); + // remove a waiter from the pool, and start another thread if necessary + --g_waiter_count; + if (g_waiter_count == 0 && g_threaded) { + // The number of timer threads is always increasing until all the threads + // are stopped. In rare cases, if a large number of timers fire + // simultaneously, we may end up using a large number of threads. + start_timer_thread_and_unlock(); + } else { + // if there's no thread waiting with a timeout, kick an existing untimed + // waiter so that the next deadline is not missed + if (!g_has_timed_waiter) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "kick untimed waiter"); + } + gpr_cv_signal(&g_cv_wait); + } + gpr_mu_unlock(&g_mu); + } + // without our lock, flush the exec_ctx + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "flush exec_ctx"); + } + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&g_mu); + // garbage collect any threads that are dead + gc_completed_threads(); + // get ready to wait again + ++g_waiter_count; + gpr_mu_unlock(&g_mu); +} + +// wait until 'next' (or forever if there is already a timed waiter in the pool) +// returns true if the thread should continue executing (false if it should +// shutdown) +static bool wait_until(grpc_millis next) { + gpr_mu_lock(&g_mu); + // if we're not threaded anymore, leave + if (!g_threaded) { + gpr_mu_unlock(&g_mu); + return false; + } + + // If g_kicked is true at this point, it means there was a kick from the timer + // system that the timer-manager threads here missed. We cannot trust 'next' + // here any longer (since there might be an earlier deadline). 
So if g_kicked + // is true at this point, we should quickly exit this and get the next + // deadline from the timer system + + if (!g_kicked) { + // if there's no timed waiter, we should become one: that waiter waits + // only until the next timer should expire. All other timers wait forever + // + // 'g_timed_waiter_generation' is a global generation counter. The idea here + // is that the thread becoming a timed-waiter increments and stores this + // global counter locally in 'my_timed_waiter_generation' before going to + // sleep. After waking up, if my_timed_waiter_generation == + // g_timed_waiter_generation, it can be sure that it was the timed_waiter + // thread (and that no other thread took over while this was asleep) + // + // Initialize my_timed_waiter_generation to some value that is NOT equal to + // g_timed_waiter_generation + uint64_t my_timed_waiter_generation = g_timed_waiter_generation - 1; + + /* If there's no timed waiter, we should become one: that waiter waits only + until the next timer should expire. All other timer threads wait forever + unless their 'next' is earlier than the current timed-waiter's deadline + (in which case the thread with earlier 'next' takes over as the new timed + waiter) */ + if (next != GRPC_MILLIS_INF_FUTURE) { + if (!g_has_timed_waiter || (next < g_timed_waiter_deadline)) { + my_timed_waiter_generation = ++g_timed_waiter_generation; + g_has_timed_waiter = true; + g_timed_waiter_deadline = next; + + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + grpc_millis wait_time = next - grpc_core::ExecCtx::Get()->Now(); + gpr_log(GPR_INFO, "sleep for a %" PRId64 " milliseconds", wait_time); + } + } else { // g_timed_waiter == true && next >= g_timed_waiter_deadline + next = GRPC_MILLIS_INF_FUTURE; + } + } + + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace) && + next == GRPC_MILLIS_INF_FUTURE) { + gpr_log(GPR_INFO, "sleep until kicked"); + } + + gpr_cv_wait(&g_cv_wait, &g_mu, + grpc_millis_to_timespec(next, GPR_CLOCK_MONOTONIC)); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "wait ended: was_timed:%d kicked:%d", + my_timed_waiter_generation == g_timed_waiter_generation, + g_kicked); + } + // if this was the timed waiter, then we need to check timers, and flag + // that there's now no timed waiter... we'll look for a replacement if + // there's work to do after checking timers (code above) + if (my_timed_waiter_generation == g_timed_waiter_generation) { + ++g_wakeups; + g_has_timed_waiter = false; + g_timed_waiter_deadline = GRPC_MILLIS_INF_FUTURE; + } + } + + // if this was a kick from the timer system, consume it (and don't stop + // this thread yet) + if (g_kicked) { + grpc_timer_consume_kick(); + g_kicked = false; + } + + gpr_mu_unlock(&g_mu); + return true; +} + +static void timer_main_loop() { + for (;;) { + grpc_millis next = GRPC_MILLIS_INF_FUTURE; + grpc_core::ExecCtx::Get()->InvalidateNow(); + + // check timer state, updates next to the next time to run a check + switch (grpc_timer_check(&next)) { + case GRPC_TIMERS_FIRED: + run_some_timers(); + break; + case GRPC_TIMERS_NOT_CHECKED: + /* This case only happens under contention, meaning more than one timer + manager thread checked timers concurrently. + + If that happens, we're guaranteed that some other thread has just + checked timers, and this will avalanche into some other thread seeing + empty timers and doing a timed sleep. + + Consequently, we can just sleep forever here and be happy at some + saved wakeup cycles. 
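// wait_until above keeps exactly one "timed waiter" per process: the thread
// holding the earliest deadline sleeps with a timeout, every other timer
// thread sleeps unbounded, and the generation counter tells a woken thread
// whether it was still the timed waiter. A condensed sketch of that protocol
// with std::condition_variable (hypothetical globals, not the gpr_cv-based
// implementation; kicking and shutdown are omitted):
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>

using Clock = std::chrono::steady_clock;

std::mutex g_wait_mu;
std::condition_variable g_wait_cv;
bool g_has_timed_waiter_sketch = false;
Clock::time_point g_timed_waiter_deadline_sketch = Clock::time_point::max();
uint64_t g_timed_waiter_generation_sketch = 0;

// Returns true if this thread was (still) the timed waiter when it woke up.
bool WaitForWork(Clock::time_point my_deadline) {
  std::unique_lock<std::mutex> lock(g_wait_mu);
  uint64_t my_generation = g_timed_waiter_generation_sketch - 1;  // "not me"
  Clock::time_point sleep_until = Clock::time_point::max();
  if (my_deadline != Clock::time_point::max() &&
      (!g_has_timed_waiter_sketch ||
       my_deadline < g_timed_waiter_deadline_sketch)) {
    my_generation = ++g_timed_waiter_generation_sketch;  // take over the role
    g_has_timed_waiter_sketch = true;
    g_timed_waiter_deadline_sketch = my_deadline;
    sleep_until = my_deadline;
  }
  g_wait_cv.wait_until(lock, sleep_until);
  const bool was_timed_waiter =
      (my_generation == g_timed_waiter_generation_sketch);
  if (was_timed_waiter) {  // give up the role so another thread can take it
    g_has_timed_waiter_sketch = false;
    g_timed_waiter_deadline_sketch = Clock::time_point::max();
  }
  return was_timed_waiter;
}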
*/ + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "timers not checked: expect another thread to"); + } + next = GRPC_MILLIS_INF_FUTURE; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_TIMERS_CHECKED_AND_EMPTY: + if (!wait_until(next)) { + return; + } + break; + } + } +} + +static void timer_thread_cleanup(completed_thread* ct) { + gpr_mu_lock(&g_mu); + // terminate the thread: drop the waiter count, thread count, and let whomever + // stopped the threading stuff know that we're done + --g_waiter_count; + --g_thread_count; + if (0 == g_thread_count) { + gpr_cv_signal(&g_cv_shutdown); + } + ct->next = g_completed_threads; + g_completed_threads = ct; + gpr_mu_unlock(&g_mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "End timer thread"); + } +} + +static void timer_thread(void* completed_thread_ptr) { + // this threads exec_ctx: we try to run things through to completion here + // since it's easy to spin up new threads + grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + timer_main_loop(); + + timer_thread_cleanup(static_cast(completed_thread_ptr)); +} + +static void start_threads(void) { + gpr_mu_lock(&g_mu); + if (!g_threaded) { + g_threaded = true; + start_timer_thread_and_unlock(); + } else { + gpr_mu_unlock(&g_mu); + } +} + +void grpc_timer_manager_init(void) { + gpr_mu_init(&g_mu); + gpr_cv_init(&g_cv_wait); + gpr_cv_init(&g_cv_shutdown); + g_threaded = false; + g_thread_count = 0; + g_waiter_count = 0; + g_completed_threads = nullptr; + + g_has_timed_waiter = false; + g_timed_waiter_deadline = GRPC_MILLIS_INF_FUTURE; + + start_threads(); +} + +static void stop_threads(void) { + gpr_mu_lock(&g_mu); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "stop timer threads: threaded=%d", g_threaded); + } + if (g_threaded) { + g_threaded = false; + gpr_cv_broadcast(&g_cv_wait); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "num timer threads: %d", g_thread_count); + } + while (g_thread_count > 0) { + gpr_cv_wait(&g_cv_shutdown, &g_mu, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_timer_check_trace)) { + gpr_log(GPR_INFO, "num timer threads: %d", g_thread_count); + } + gc_completed_threads(); + } + } + g_wakeups = 0; + gpr_mu_unlock(&g_mu); +} + +void grpc_timer_manager_shutdown(void) { + stop_threads(); + + gpr_mu_destroy(&g_mu); + gpr_cv_destroy(&g_cv_wait); + gpr_cv_destroy(&g_cv_shutdown); +} + +void grpc_timer_manager_set_threading(bool enabled) { + if (enabled) { + start_threads(); + } else { + stop_threads(); + } +} + +void grpc_kick_poller(void) { + gpr_mu_lock(&g_mu); + g_kicked = true; + g_has_timed_waiter = false; + g_timed_waiter_deadline = GRPC_MILLIS_INF_FUTURE; + ++g_timed_waiter_generation; + gpr_cv_signal(&g_cv_wait); + gpr_mu_unlock(&g_mu); +} + +uint64_t grpc_timer_manager_get_wakeups_testonly(void) { return g_wakeups; } diff --git a/src/core/lib/iomgr/unix_sockets_posix.cc b/src/core/lib/iomgr/unix_sockets_posix.cc new file mode 100644 index 00000000..0154831c --- /dev/null +++ b/src/core/lib/iomgr/unix_sockets_posix.cc @@ -0,0 +1,108 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_HAVE_UNIX_SOCKET + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" + +void grpc_create_socketpair_if_unix(int sv[2]) { + GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv) == 0); +} + +grpc_error_handle grpc_resolve_unix_domain_address( + const char* name, grpc_resolved_addresses** addresses) { + *addresses = static_cast( + gpr_malloc(sizeof(grpc_resolved_addresses))); + (*addresses)->naddrs = 1; + (*addresses)->addrs = static_cast( + gpr_malloc(sizeof(grpc_resolved_address))); + return grpc_core::UnixSockaddrPopulate(name, (*addresses)->addrs); +} + +grpc_error_handle grpc_resolve_unix_abstract_domain_address( + const absl::string_view name, grpc_resolved_addresses** addresses) { + *addresses = static_cast( + gpr_malloc(sizeof(grpc_resolved_addresses))); + (*addresses)->naddrs = 1; + (*addresses)->addrs = static_cast( + gpr_malloc(sizeof(grpc_resolved_address))); + return grpc_core::UnixAbstractSockaddrPopulate(name, (*addresses)->addrs); +} + +int grpc_is_unix_socket(const grpc_resolved_address* resolved_addr) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + return addr->sa_family == AF_UNIX; +} + +void grpc_unlink_if_unix_domain_socket( + const grpc_resolved_address* resolved_addr) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + if (addr->sa_family != AF_UNIX) { + return; + } + struct sockaddr_un* un = reinterpret_cast( + const_cast(resolved_addr->addr)); + + // There is nothing to unlink for an abstract unix socket + if (un->sun_path[0] == '\0' && un->sun_path[1] != '\0') { + return; + } + + struct stat st; + if (stat(un->sun_path, &st) == 0 && (st.st_mode & S_IFMT) == S_IFSOCK) { + unlink(un->sun_path); + } +} + +std::string grpc_sockaddr_to_uri_unix_if_possible( + const grpc_resolved_address* resolved_addr) { + const grpc_sockaddr* addr = + reinterpret_cast(resolved_addr->addr); + if (addr->sa_family != AF_UNIX) { + return ""; + } + const auto* unix_addr = reinterpret_cast(addr); + if (unix_addr->sun_path[0] == '\0' && unix_addr->sun_path[1] != '\0') { + return absl::StrCat( + "unix-abstract:", + absl::string_view( + unix_addr->sun_path + 1, + resolved_addr->len - sizeof(unix_addr->sun_family) - 1)); + } + return absl::StrCat("unix:", unix_addr->sun_path); +} + +#endif diff --git a/src/core/lib/iomgr/unix_sockets_posix_noop.cc b/src/core/lib/iomgr/unix_sockets_posix_noop.cc new file mode 100644 index 00000000..c2659221 --- /dev/null +++ b/src/core/lib/iomgr/unix_sockets_posix_noop.cc @@ -0,0 +1,62 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
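// grpc_sockaddr_to_uri_unix_if_possible above distinguishes Linux abstract
// sockets (sun_path beginning with a NUL byte) from filesystem sockets and
// prefixes the URI accordingly, which is also why there is nothing to
// unlink() for the abstract case. A simplified sketch taking the raw path
// bytes directly (hypothetical helper, not the gRPC signature):
#include <string>

std::string UnixPathToUri(const std::string& sun_path_bytes) {
  if (!sun_path_bytes.empty() && sun_path_bytes[0] == '\0') {
    // Abstract namespace: drop the leading NUL, keep the remaining name bytes.
    return "unix-abstract:" + sun_path_bytes.substr(1);
  }
  return "unix:" + sun_path_bytes;
}
// e.g. UnixPathToUri(std::string("\0grpc", 5)) == "unix-abstract:grpc",
//      UnixPathToUri("/tmp/sock") == "unix:/tmp/sock".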
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/unix_sockets_posix.h" + +#ifndef GRPC_HAVE_UNIX_SOCKET + +#include + +#include + +void grpc_create_socketpair_if_unix(int /* sv */[2]) { + // TODO: Either implement this for the non-Unix socket case or make + // sure that it is never called in any such case. Until then, leave an + // assertion to notify if this gets called inadvertently + GPR_ASSERT(0); +} + +grpc_error_handle grpc_resolve_unix_domain_address( + const char* /* name */, grpc_resolved_addresses** addresses) { + *addresses = NULL; + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unix domain sockets are not supported on Windows"); +} + +grpc_error_handle grpc_resolve_unix_abstract_domain_address( + absl::string_view, grpc_resolved_addresses** addresses) { + *addresses = NULL; + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unix domain sockets are not supported on Windows"); +} + +int grpc_is_unix_socket(const grpc_resolved_address* /* addr */) { + return false; +} + +void grpc_unlink_if_unix_domain_socket( + const grpc_resolved_address* /* addr */) {} + +std::string grpc_sockaddr_to_uri_unix_if_possible( + const grpc_resolved_address* /* addr */) { + return ""; +} + +#endif diff --git a/src/core/lib/iomgr/wakeup_fd_eventfd.cc b/src/core/lib/iomgr/wakeup_fd_eventfd.cc new file mode 100644 index 00000000..3951fe6f --- /dev/null +++ b/src/core/lib/iomgr/wakeup_fd_eventfd.cc @@ -0,0 +1,82 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_LINUX_EVENTFD + +#include +#include +#include + +#include + +#include "src/core/lib/iomgr/wakeup_fd_posix.h" +#include "src/core/lib/profiling/timers.h" + +static grpc_error_handle eventfd_create(grpc_wakeup_fd* fd_info) { + fd_info->read_fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + fd_info->write_fd = -1; + if (fd_info->read_fd < 0) { + return GRPC_OS_ERROR(errno, "eventfd"); + } + return GRPC_ERROR_NONE; +} + +static grpc_error_handle eventfd_consume(grpc_wakeup_fd* fd_info) { + eventfd_t value; + int err; + do { + err = eventfd_read(fd_info->read_fd, &value); + } while (err < 0 && errno == EINTR); + if (err < 0 && errno != EAGAIN) { + return GRPC_OS_ERROR(errno, "eventfd_read"); + } + return GRPC_ERROR_NONE; +} + +static grpc_error_handle eventfd_wakeup(grpc_wakeup_fd* fd_info) { + GPR_TIMER_SCOPE("eventfd_wakeup", 0); + int err; + do { + err = eventfd_write(fd_info->read_fd, 1); + } while (err < 0 && errno == EINTR); + if (err < 0) { + return GRPC_OS_ERROR(errno, "eventfd_write"); + } + return GRPC_ERROR_NONE; +} + +static void eventfd_destroy(grpc_wakeup_fd* fd_info) { + if (fd_info->read_fd != 0) close(fd_info->read_fd); +} + +static int eventfd_check_availability(void) { + const int efd = eventfd(0, 0); + const int is_available = efd >= 0; + if (is_available) close(efd); + return is_available; +} + +const grpc_wakeup_fd_vtable grpc_specialized_wakeup_fd_vtable = { + eventfd_create, eventfd_consume, eventfd_wakeup, eventfd_destroy, + eventfd_check_availability}; + +#endif /* GRPC_LINUX_EVENTFD */ diff --git a/src/core/lib/iomgr/wakeup_fd_nospecial.cc b/src/core/lib/iomgr/wakeup_fd_nospecial.cc new file mode 100644 index 00000000..aff61eaa --- /dev/null +++ b/src/core/lib/iomgr/wakeup_fd_nospecial.cc @@ -0,0 +1,39 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* + * This is a phony file to provide an invalid specialized_wakeup_fd_vtable on + * systems without anything better than pipe. + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_NO_SPECIAL_WAKEUP_FD + +#include + +#include "src/core/lib/iomgr/wakeup_fd_posix.h" + +static int check_availability_invalid(void) { return 0; } + +const grpc_wakeup_fd_vtable grpc_specialized_wakeup_fd_vtable = { + nullptr, nullptr, nullptr, nullptr, check_availability_invalid}; + +#endif /* GRPC_POSIX_NO_SPECIAL_WAKEUP_FD */ diff --git a/src/core/lib/iomgr/wakeup_fd_pipe.cc b/src/core/lib/iomgr/wakeup_fd_pipe.cc new file mode 100644 index 00000000..e25a9448 --- /dev/null +++ b/src/core/lib/iomgr/wakeup_fd_pipe.cc @@ -0,0 +1,99 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
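// The eventfd-based wakeup fd above needs only a single descriptor: writing
// any non-zero value to its counter makes it poll as readable (waking the
// poller), and one read drains it again, which is why write_fd stays -1. A
// minimal Linux-only usage sketch (not the gRPC vtable wiring):
#include <sys/eventfd.h>

int MakeWakeupFd() { return eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); }

void Wake(int fd) {
  eventfd_write(fd, 1);  // adds 1 to the counter; epoll/poll reports readable
}

void Drain(int fd) {
  eventfd_t value;
  eventfd_read(fd, &value);  // resets the counter; EAGAIN ignored in this sketch
}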
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_WAKEUP_FD + +#include +#include +#include + +#include + +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/iomgr/wakeup_fd_pipe.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" + +static grpc_error_handle pipe_init(grpc_wakeup_fd* fd_info) { + int pipefd[2]; + int r = pipe(pipefd); + if (0 != r) { + gpr_log(GPR_ERROR, "pipe creation failed (%d): %s", errno, strerror(errno)); + return GRPC_OS_ERROR(errno, "pipe"); + } + grpc_error_handle err; + err = grpc_set_socket_nonblocking(pipefd[0], 1); + if (err != GRPC_ERROR_NONE) return err; + err = grpc_set_socket_nonblocking(pipefd[1], 1); + if (err != GRPC_ERROR_NONE) return err; + fd_info->read_fd = pipefd[0]; + fd_info->write_fd = pipefd[1]; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle pipe_consume(grpc_wakeup_fd* fd_info) { + char buf[128]; + ssize_t r; + + for (;;) { + r = read(fd_info->read_fd, buf, sizeof(buf)); + if (r > 0) continue; + if (r == 0) return GRPC_ERROR_NONE; + switch (errno) { + case EAGAIN: + return GRPC_ERROR_NONE; + case EINTR: + continue; + default: + return GRPC_OS_ERROR(errno, "read"); + } + } +} + +static grpc_error_handle pipe_wakeup(grpc_wakeup_fd* fd_info) { + char c = 0; + while (write(fd_info->write_fd, &c, 1) != 1 && errno == EINTR) { + } + return GRPC_ERROR_NONE; +} + +static void pipe_destroy(grpc_wakeup_fd* fd_info) { + if (fd_info->read_fd != 0) close(fd_info->read_fd); + if (fd_info->write_fd != 0) close(fd_info->write_fd); +} + +static int pipe_check_availability(void) { + grpc_wakeup_fd fd; + fd.read_fd = fd.write_fd = -1; + + if (pipe_init(&fd) == GRPC_ERROR_NONE) { + pipe_destroy(&fd); + return 1; + } else { + return 0; + } +} + +const grpc_wakeup_fd_vtable grpc_pipe_wakeup_fd_vtable = { + pipe_init, pipe_consume, pipe_wakeup, pipe_destroy, + pipe_check_availability}; + +#endif /* GPR_POSIX_WAKUP_FD */ diff --git a/src/core/lib/iomgr/wakeup_fd_posix.cc b/src/core/lib/iomgr/wakeup_fd_posix.cc new file mode 100644 index 00000000..89bd51ba --- /dev/null +++ b/src/core/lib/iomgr/wakeup_fd_posix.cc @@ -0,0 +1,70 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/iomgr/port.h" + +#ifdef GRPC_POSIX_WAKEUP_FD + +#include + +#include "src/core/lib/iomgr/wakeup_fd_pipe.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" + +static const grpc_wakeup_fd_vtable* wakeup_fd_vtable = nullptr; + +int grpc_allow_specialized_wakeup_fd = 1; +int grpc_allow_pipe_wakeup_fd = 1; + +int has_real_wakeup_fd = 1; +int cv_wakeup_fds_enabled = 0; + +void grpc_wakeup_fd_global_init(void) { + if (grpc_allow_specialized_wakeup_fd && + grpc_specialized_wakeup_fd_vtable.check_availability()) { + wakeup_fd_vtable = &grpc_specialized_wakeup_fd_vtable; + } else if (grpc_allow_pipe_wakeup_fd && + grpc_pipe_wakeup_fd_vtable.check_availability()) { + wakeup_fd_vtable = &grpc_pipe_wakeup_fd_vtable; + } else { + has_real_wakeup_fd = 0; + } +} + +void grpc_wakeup_fd_global_destroy(void) { wakeup_fd_vtable = nullptr; } + +int grpc_has_wakeup_fd(void) { return has_real_wakeup_fd; } + +grpc_error_handle grpc_wakeup_fd_init(grpc_wakeup_fd* fd_info) { + return wakeup_fd_vtable->init(fd_info); +} + +grpc_error_handle grpc_wakeup_fd_consume_wakeup(grpc_wakeup_fd* fd_info) { + return wakeup_fd_vtable->consume(fd_info); +} + +grpc_error_handle grpc_wakeup_fd_wakeup(grpc_wakeup_fd* fd_info) { + return wakeup_fd_vtable->wakeup(fd_info); +} + +void grpc_wakeup_fd_destroy(grpc_wakeup_fd* fd_info) { + wakeup_fd_vtable->destroy(fd_info); +} + +#endif /* GRPC_POSIX_WAKEUP_FD */ diff --git a/src/core/lib/iomgr/work_serializer.cc b/src/core/lib/iomgr/work_serializer.cc new file mode 100644 index 00000000..92a711de --- /dev/null +++ b/src/core/lib/iomgr/work_serializer.cc @@ -0,0 +1,155 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/work_serializer.h" + +namespace grpc_core { + +DebugOnlyTraceFlag grpc_work_serializer_trace(false, "work_serializer"); + +struct CallbackWrapper { + CallbackWrapper(std::function cb, const grpc_core::DebugLocation& loc) + : callback(std::move(cb)), location(loc) {} + + MultiProducerSingleConsumerQueue::Node mpscq_node; + const std::function callback; + const DebugLocation location; +}; + +class WorkSerializer::WorkSerializerImpl : public Orphanable { + public: + void Run(std::function callback, + const grpc_core::DebugLocation& location); + + void Orphan() override; + + private: + void DrainQueue(); + + // An initial size of 1 keeps track of whether the work serializer has been + // orphaned. + std::atomic size_{1}; + MultiProducerSingleConsumerQueue queue_; +}; + +void WorkSerializer::WorkSerializerImpl::Run( + std::function callback, const grpc_core::DebugLocation& location) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, "WorkSerializer::Run() %p Scheduling callback [%s:%d]", + this, location.file(), location.line()); + } + const size_t prev_size = size_.fetch_add(1); + // The work serializer should not have been orphaned. 
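// WorkSerializer::Run above (and DrainQueue below) use size_.fetch_add(1) as a
// combined refcount and ownership test: the caller that finds the previous
// count at 1 (just the "not yet orphaned" baseline) runs its callback inline
// and then drains everything queued behind it; every other caller only
// enqueues. A stripped-down sketch of that protocol with a mutex-guarded
// deque standing in for the MPSC queue (orphan handling omitted):
#include <atomic>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>

class ToySerializer {
 public:
  void Run(std::function<void()> cb) {
    if (size_.fetch_add(1) == 1) {
      cb();     // no other callback in flight: run inline...
      Drain();  // ...then work off anything that arrived meanwhile
    } else {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(std::move(cb));
    }
  }

 private:
  void Drain() {
    // prev == 2 means only the baseline plus the just-run callback remain.
    while (size_.fetch_sub(1) > 2) {
      std::function<void()> cb;
      for (;;) {  // a racing Run() may bump size_ before its push lands
        std::lock_guard<std::mutex> lock(mu_);
        if (!queue_.empty()) {
          cb = std::move(queue_.front());
          queue_.pop_front();
          break;
        }
      }
      cb();
    }
  }

  std::atomic<std::size_t> size_{1};  // 1 == alive but idle, as in the real class
  std::mutex mu_;
  std::deque<std::function<void()>> queue_;
};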
+ GPR_DEBUG_ASSERT(prev_size > 0); + if (prev_size == 1) { + // There is no other closure executing right now on this work serializer. + // Execute this closure immediately. + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, " Executing immediately"); + } + callback(); + // Loan this thread to the work serializer thread and drain the queue. + DrainQueue(); + } else { + CallbackWrapper* cb_wrapper = + new CallbackWrapper(std::move(callback), location); + // There already are closures executing on this work serializer. Simply add + // this closure to the queue. + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, " Scheduling on queue : item %p", cb_wrapper); + } + queue_.Push(&cb_wrapper->mpscq_node); + } +} + +void WorkSerializer::WorkSerializerImpl::Orphan() { + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, "WorkSerializer::Orphan() %p", this); + } + size_t prev_size = size_.fetch_sub(1); + if (prev_size == 1) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, " Destroying"); + } + delete this; + } +} + +// The thread that calls this loans itself to the work serializer so as to +// execute all the scheduled callback. This is called from within +// WorkSerializer::Run() after executing a callback immediately, and hence size_ +// is at least 1. +void WorkSerializer::WorkSerializerImpl::DrainQueue() { + while (true) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, "WorkSerializer::DrainQueue() %p", this); + } + size_t prev_size = size_.fetch_sub(1); + GPR_DEBUG_ASSERT(prev_size >= 1); + // It is possible that while draining the queue, one of the callbacks ended + // up orphaning the work serializer. In that case, delete the object. + if (prev_size == 1) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, " Queue Drained. Destroying"); + } + delete this; + return; + } + if (prev_size == 2) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, " Queue Drained"); + } + return; + } + // There is at least one callback on the queue. Pop the callback from the + // queue and execute it. + CallbackWrapper* cb_wrapper = nullptr; + bool empty_unused; + while ((cb_wrapper = reinterpret_cast( + queue_.PopAndCheckEnd(&empty_unused))) == nullptr) { + // This can happen either due to a race condition within the mpscq + // implementation or because of a race with Run() + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, " Queue returned nullptr, trying again"); + } + } + if (GRPC_TRACE_FLAG_ENABLED(grpc_work_serializer_trace)) { + gpr_log(GPR_INFO, " Running item %p : callback scheduled at [%s:%d]", + cb_wrapper, cb_wrapper->location.file(), + cb_wrapper->location.line()); + } + cb_wrapper->callback(); + delete cb_wrapper; + } +} + +// WorkSerializer + +WorkSerializer::WorkSerializer() + : impl_(MakeOrphanable()) {} + +WorkSerializer::~WorkSerializer() {} + +void WorkSerializer::Run(std::function callback, + const grpc_core::DebugLocation& location) { + impl_->Run(std::move(callback), location); +} + +} // namespace grpc_core diff --git a/src/core/lib/json/json_reader.cc b/src/core/lib/json/json_reader.cc new file mode 100644 index 00000000..b8425f28 --- /dev/null +++ b/src/core/lib/json/json_reader.cc @@ -0,0 +1,849 @@ +/* + * + * Copyright 2015-2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include + +#include "src/core/lib/json/json.h" + +#define GRPC_JSON_MAX_DEPTH 255 +#define GRPC_JSON_MAX_ERRORS 16 + +namespace grpc_core { + +namespace { + +class JsonReader { + public: + static grpc_error_handle Parse(absl::string_view input, Json* output); + + private: + enum class Status { + GRPC_JSON_DONE, /* The parser finished successfully. */ + GRPC_JSON_PARSE_ERROR, /* The parser found an error in the json stream. */ + GRPC_JSON_INTERNAL_ERROR /* The parser got an internal error. */ + }; + + enum class State { + GRPC_JSON_STATE_OBJECT_KEY_BEGIN, + GRPC_JSON_STATE_OBJECT_KEY_STRING, + GRPC_JSON_STATE_OBJECT_KEY_END, + GRPC_JSON_STATE_VALUE_BEGIN, + GRPC_JSON_STATE_VALUE_STRING, + GRPC_JSON_STATE_STRING_ESCAPE, + GRPC_JSON_STATE_STRING_ESCAPE_U1, + GRPC_JSON_STATE_STRING_ESCAPE_U2, + GRPC_JSON_STATE_STRING_ESCAPE_U3, + GRPC_JSON_STATE_STRING_ESCAPE_U4, + GRPC_JSON_STATE_VALUE_NUMBER, + GRPC_JSON_STATE_VALUE_NUMBER_WITH_DECIMAL, + GRPC_JSON_STATE_VALUE_NUMBER_ZERO, + GRPC_JSON_STATE_VALUE_NUMBER_DOT, + GRPC_JSON_STATE_VALUE_NUMBER_E, + GRPC_JSON_STATE_VALUE_NUMBER_EPM, + GRPC_JSON_STATE_VALUE_TRUE_R, + GRPC_JSON_STATE_VALUE_TRUE_U, + GRPC_JSON_STATE_VALUE_TRUE_E, + GRPC_JSON_STATE_VALUE_FALSE_A, + GRPC_JSON_STATE_VALUE_FALSE_L, + GRPC_JSON_STATE_VALUE_FALSE_S, + GRPC_JSON_STATE_VALUE_FALSE_E, + GRPC_JSON_STATE_VALUE_NULL_U, + GRPC_JSON_STATE_VALUE_NULL_L1, + GRPC_JSON_STATE_VALUE_NULL_L2, + GRPC_JSON_STATE_VALUE_END, + GRPC_JSON_STATE_END + }; + + /* The first non-unicode value is 0x110000. But let's pick + * a value high enough to start our error codes from. These + * values are safe to return from the read_char function. 
+ */ + static constexpr uint32_t GRPC_JSON_READ_CHAR_EOF = 0x7ffffff0; + + explicit JsonReader(absl::string_view input) + : original_input_(reinterpret_cast(input.data())), + input_(original_input_), + remaining_input_(input.size()) {} + + Status Run(); + uint32_t ReadChar(); + bool IsComplete(); + + size_t CurrentIndex() const { return input_ - original_input_ - 1; } + + void StringAddChar(uint32_t c); + void StringAddUtf32(uint32_t c); + + Json* CreateAndLinkValue(); + bool StartContainer(Json::Type type); + void EndContainer(); + void SetKey(); + void SetString(); + bool SetNumber(); + void SetTrue(); + void SetFalse(); + void SetNull(); + + const uint8_t* original_input_; + const uint8_t* input_; + size_t remaining_input_; + + State state_ = State::GRPC_JSON_STATE_VALUE_BEGIN; + bool escaped_string_was_key_ = false; + bool container_just_begun_ = false; + uint16_t unicode_char_ = 0; + uint16_t unicode_high_surrogate_ = 0; + std::vector errors_; + bool truncated_errors_ = false; + + Json root_value_; + std::vector stack_; + + std::string key_; + std::string string_; +}; + +void JsonReader::StringAddChar(uint32_t c) { + string_.push_back(static_cast(c)); +} + +void JsonReader::StringAddUtf32(uint32_t c) { + if (c <= 0x7f) { + StringAddChar(c); + } else if (c <= 0x7ff) { + uint32_t b1 = 0xc0 | ((c >> 6) & 0x1f); + uint32_t b2 = 0x80 | (c & 0x3f); + StringAddChar(b1); + StringAddChar(b2); + } else if (c <= 0xffff) { + uint32_t b1 = 0xe0 | ((c >> 12) & 0x0f); + uint32_t b2 = 0x80 | ((c >> 6) & 0x3f); + uint32_t b3 = 0x80 | (c & 0x3f); + StringAddChar(b1); + StringAddChar(b2); + StringAddChar(b3); + } else if (c <= 0x1fffff) { + uint32_t b1 = 0xf0 | ((c >> 18) & 0x07); + uint32_t b2 = 0x80 | ((c >> 12) & 0x3f); + uint32_t b3 = 0x80 | ((c >> 6) & 0x3f); + uint32_t b4 = 0x80 | (c & 0x3f); + StringAddChar(b1); + StringAddChar(b2); + StringAddChar(b3); + StringAddChar(b4); + } +} + +uint32_t JsonReader::ReadChar() { + if (remaining_input_ == 0) return GRPC_JSON_READ_CHAR_EOF; + const uint32_t r = *input_++; + --remaining_input_; + if (r == 0) { + remaining_input_ = 0; + return GRPC_JSON_READ_CHAR_EOF; + } + return r; +} + +Json* JsonReader::CreateAndLinkValue() { + Json* value; + if (stack_.empty()) { + value = &root_value_; + } else { + Json* parent = stack_.back(); + if (parent->type() == Json::Type::OBJECT) { + if (parent->object_value().find(key_) != parent->object_value().end()) { + if (errors_.size() == GRPC_JSON_MAX_ERRORS) { + truncated_errors_ = true; + } else { + errors_.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("duplicate key \"%s\" at index %" PRIuPTR, key_, + CurrentIndex()))); + } + } + value = &(*parent->mutable_object())[std::move(key_)]; + } else { + GPR_ASSERT(parent->type() == Json::Type::ARRAY); + parent->mutable_array()->emplace_back(); + value = &parent->mutable_array()->back(); + } + } + return value; +} + +bool JsonReader::StartContainer(Json::Type type) { + if (stack_.size() == GRPC_JSON_MAX_DEPTH) { + if (errors_.size() == GRPC_JSON_MAX_ERRORS) { + truncated_errors_ = true; + } else { + errors_.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("exceeded max stack depth (%d) at index %" PRIuPTR, + GRPC_JSON_MAX_DEPTH, CurrentIndex()))); + } + return false; + } + Json* value = CreateAndLinkValue(); + if (type == Json::Type::OBJECT) { + *value = Json::Object(); + } else { + GPR_ASSERT(type == Json::Type::ARRAY); + *value = Json::Array(); + } + stack_.push_back(value); + return true; +} + +void JsonReader::EndContainer() { + 
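+  // A matching StartContainer() always precedes this call, so the stack
+  // cannot be empty here.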
GPR_ASSERT(!stack_.empty()); + stack_.pop_back(); +} + +void JsonReader::SetKey() { + key_ = std::move(string_); + string_.clear(); +} + +void JsonReader::SetString() { + Json* value = CreateAndLinkValue(); + *value = std::move(string_); + string_.clear(); +} + +bool JsonReader::SetNumber() { + Json* value = CreateAndLinkValue(); + *value = Json(string_, /*is_number=*/true); + string_.clear(); + return true; +} + +void JsonReader::SetTrue() { + Json* value = CreateAndLinkValue(); + *value = true; + string_.clear(); +} + +void JsonReader::SetFalse() { + Json* value = CreateAndLinkValue(); + *value = false; + string_.clear(); +} + +void JsonReader::SetNull() { CreateAndLinkValue(); } + +bool JsonReader::IsComplete() { + return (stack_.empty() && (state_ == State::GRPC_JSON_STATE_END || + state_ == State::GRPC_JSON_STATE_VALUE_END)); +} + +/* Call this function to start parsing the input. It will return the following: + * . GRPC_JSON_DONE if the input got eof, and the parsing finished + * successfully. + * . GRPC_JSON_PARSE_ERROR if the input was somehow invalid. + * . GRPC_JSON_INTERNAL_ERROR if the parser somehow ended into an invalid + * internal state. + */ +JsonReader::Status JsonReader::Run() { + uint32_t c; + + /* This state-machine is a strict implementation of ECMA-404 */ + while (true) { + c = ReadChar(); + switch (c) { + /* Let's process the error case first. */ + case GRPC_JSON_READ_CHAR_EOF: + if (IsComplete()) { + return Status::GRPC_JSON_DONE; + } + return Status::GRPC_JSON_PARSE_ERROR; + + /* Processing whitespaces. */ + case ' ': + case '\t': + case '\n': + case '\r': + switch (state_) { + case State::GRPC_JSON_STATE_OBJECT_KEY_BEGIN: + case State::GRPC_JSON_STATE_OBJECT_KEY_END: + case State::GRPC_JSON_STATE_VALUE_BEGIN: + case State::GRPC_JSON_STATE_VALUE_END: + case State::GRPC_JSON_STATE_END: + break; + + case State::GRPC_JSON_STATE_OBJECT_KEY_STRING: + case State::GRPC_JSON_STATE_VALUE_STRING: + if (c != ' ') return Status::GRPC_JSON_PARSE_ERROR; + if (unicode_high_surrogate_ != 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + StringAddChar(c); + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER: + case State::GRPC_JSON_STATE_VALUE_NUMBER_WITH_DECIMAL: + case State::GRPC_JSON_STATE_VALUE_NUMBER_ZERO: + case State::GRPC_JSON_STATE_VALUE_NUMBER_EPM: + if (!SetNumber()) return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_END; + break; + + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + /* Value, object or array terminations. 
*/ + case ',': + case '}': + case ']': + switch (state_) { + case State::GRPC_JSON_STATE_OBJECT_KEY_STRING: + case State::GRPC_JSON_STATE_VALUE_STRING: + if (unicode_high_surrogate_ != 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + StringAddChar(c); + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER: + case State::GRPC_JSON_STATE_VALUE_NUMBER_WITH_DECIMAL: + case State::GRPC_JSON_STATE_VALUE_NUMBER_ZERO: + case State::GRPC_JSON_STATE_VALUE_NUMBER_EPM: + if (stack_.empty()) { + return Status::GRPC_JSON_PARSE_ERROR; + } else if (c == '}' && + stack_.back()->type() != Json::Type::OBJECT) { + return Status::GRPC_JSON_PARSE_ERROR; + } else if (c == ']' && stack_.back()->type() != Json::Type::ARRAY) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (!SetNumber()) return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_END; + ABSL_FALLTHROUGH_INTENDED; + + case State::GRPC_JSON_STATE_VALUE_END: + case State::GRPC_JSON_STATE_OBJECT_KEY_BEGIN: + case State::GRPC_JSON_STATE_VALUE_BEGIN: + if (c == ',') { + if (state_ != State::GRPC_JSON_STATE_VALUE_END) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (!stack_.empty() && + stack_.back()->type() == Json::Type::OBJECT) { + state_ = State::GRPC_JSON_STATE_OBJECT_KEY_BEGIN; + } else if (!stack_.empty() && + stack_.back()->type() == Json::Type::ARRAY) { + state_ = State::GRPC_JSON_STATE_VALUE_BEGIN; + } else { + return Status::GRPC_JSON_PARSE_ERROR; + } + } else { + if (stack_.empty()) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (c == '}' && stack_.back()->type() != Json::Type::OBJECT) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (c == '}' && + state_ == State::GRPC_JSON_STATE_OBJECT_KEY_BEGIN && + !container_just_begun_) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (c == ']' && stack_.back()->type() != Json::Type::ARRAY) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (c == ']' && state_ == State::GRPC_JSON_STATE_VALUE_BEGIN && + !container_just_begun_) { + return Status::GRPC_JSON_PARSE_ERROR; + } + state_ = State::GRPC_JSON_STATE_VALUE_END; + EndContainer(); + if (stack_.empty()) { + state_ = State::GRPC_JSON_STATE_END; + } + } + break; + + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + /* In-string escaping. */ + case '\\': + switch (state_) { + case State::GRPC_JSON_STATE_OBJECT_KEY_STRING: + escaped_string_was_key_ = true; + state_ = State::GRPC_JSON_STATE_STRING_ESCAPE; + break; + + case State::GRPC_JSON_STATE_VALUE_STRING: + escaped_string_was_key_ = false; + state_ = State::GRPC_JSON_STATE_STRING_ESCAPE; + break; + + /* This is the \\ case. 
*/ + case State::GRPC_JSON_STATE_STRING_ESCAPE: + if (unicode_high_surrogate_ != 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + StringAddChar('\\'); + if (escaped_string_was_key_) { + state_ = State::GRPC_JSON_STATE_OBJECT_KEY_STRING; + } else { + state_ = State::GRPC_JSON_STATE_VALUE_STRING; + } + break; + + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + default: + container_just_begun_ = false; + switch (state_) { + case State::GRPC_JSON_STATE_OBJECT_KEY_BEGIN: + if (c != '"') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_OBJECT_KEY_STRING; + break; + + case State::GRPC_JSON_STATE_OBJECT_KEY_STRING: + if (unicode_high_surrogate_ != 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (c == '"') { + state_ = State::GRPC_JSON_STATE_OBJECT_KEY_END; + SetKey(); + } else { + if (c < 32) return Status::GRPC_JSON_PARSE_ERROR; + StringAddChar(c); + } + break; + + case State::GRPC_JSON_STATE_VALUE_STRING: + if (unicode_high_surrogate_ != 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + if (c == '"') { + state_ = State::GRPC_JSON_STATE_VALUE_END; + SetString(); + } else { + if (c < 32) return Status::GRPC_JSON_PARSE_ERROR; + StringAddChar(c); + } + break; + + case State::GRPC_JSON_STATE_OBJECT_KEY_END: + if (c != ':') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_BEGIN; + break; + + case State::GRPC_JSON_STATE_VALUE_BEGIN: + switch (c) { + case 't': + state_ = State::GRPC_JSON_STATE_VALUE_TRUE_R; + break; + + case 'f': + state_ = State::GRPC_JSON_STATE_VALUE_FALSE_A; + break; + + case 'n': + state_ = State::GRPC_JSON_STATE_VALUE_NULL_U; + break; + + case '"': + state_ = State::GRPC_JSON_STATE_VALUE_STRING; + break; + + case '0': + StringAddChar(c); + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER_ZERO; + break; + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + StringAddChar(c); + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER; + break; + + case '{': + container_just_begun_ = true; + if (!StartContainer(Json::Type::OBJECT)) { + return Status::GRPC_JSON_PARSE_ERROR; + } + state_ = State::GRPC_JSON_STATE_OBJECT_KEY_BEGIN; + break; + + case '[': + container_just_begun_ = true; + if (!StartContainer(Json::Type::ARRAY)) { + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case State::GRPC_JSON_STATE_STRING_ESCAPE: + if (escaped_string_was_key_) { + state_ = State::GRPC_JSON_STATE_OBJECT_KEY_STRING; + } else { + state_ = State::GRPC_JSON_STATE_VALUE_STRING; + } + if (unicode_high_surrogate_ && c != 'u') { + return Status::GRPC_JSON_PARSE_ERROR; + } + switch (c) { + case '"': + case '/': + StringAddChar(c); + break; + case 'b': + StringAddChar('\b'); + break; + case 'f': + StringAddChar('\f'); + break; + case 'n': + StringAddChar('\n'); + break; + case 'r': + StringAddChar('\r'); + break; + case 't': + StringAddChar('\t'); + break; + case 'u': + state_ = State::GRPC_JSON_STATE_STRING_ESCAPE_U1; + unicode_char_ = 0; + break; + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case State::GRPC_JSON_STATE_STRING_ESCAPE_U1: + case State::GRPC_JSON_STATE_STRING_ESCAPE_U2: + case State::GRPC_JSON_STATE_STRING_ESCAPE_U3: + case State::GRPC_JSON_STATE_STRING_ESCAPE_U4: + if ((c >= '0') && (c <= '9')) { + c -= '0'; + } else if ((c >= 'A') && (c <= 'F')) { + c -= 'A' - 10; + } else if ((c >= 'a') && (c <= 'f')) { + c -= 'a' - 10; + } else { + return Status::GRPC_JSON_PARSE_ERROR; 
+ } + unicode_char_ = static_cast(unicode_char_ << 4); + unicode_char_ = static_cast(unicode_char_ | c); + + switch (state_) { + case State::GRPC_JSON_STATE_STRING_ESCAPE_U1: + state_ = State::GRPC_JSON_STATE_STRING_ESCAPE_U2; + break; + case State::GRPC_JSON_STATE_STRING_ESCAPE_U2: + state_ = State::GRPC_JSON_STATE_STRING_ESCAPE_U3; + break; + case State::GRPC_JSON_STATE_STRING_ESCAPE_U3: + state_ = State::GRPC_JSON_STATE_STRING_ESCAPE_U4; + break; + case State::GRPC_JSON_STATE_STRING_ESCAPE_U4: + /* See grpc_json_writer_escape_string to have a description + * of what's going on here. + */ + if ((unicode_char_ & 0xfc00) == 0xd800) { + /* high surrogate utf-16 */ + if (unicode_high_surrogate_ != 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + unicode_high_surrogate_ = unicode_char_; + } else if ((unicode_char_ & 0xfc00) == 0xdc00) { + /* low surrogate utf-16 */ + uint32_t utf32; + if (unicode_high_surrogate_ == 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + utf32 = 0x10000; + utf32 += static_cast( + (unicode_high_surrogate_ - 0xd800) * 0x400); + utf32 += static_cast(unicode_char_ - 0xdc00); + StringAddUtf32(utf32); + unicode_high_surrogate_ = 0; + } else { + /* anything else */ + if (unicode_high_surrogate_ != 0) { + return Status::GRPC_JSON_PARSE_ERROR; + } + StringAddUtf32(unicode_char_); + } + if (escaped_string_was_key_) { + state_ = State::GRPC_JSON_STATE_OBJECT_KEY_STRING; + } else { + state_ = State::GRPC_JSON_STATE_VALUE_STRING; + } + break; + default: + GPR_UNREACHABLE_CODE(return Status::GRPC_JSON_INTERNAL_ERROR); + } + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER: + StringAddChar(c); + switch (c) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + break; + case 'e': + case 'E': + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER_E; + break; + case '.': + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER_DOT; + break; + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER_WITH_DECIMAL: + StringAddChar(c); + switch (c) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + break; + case 'e': + case 'E': + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER_E; + break; + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER_ZERO: + if (c != '.') return Status::GRPC_JSON_PARSE_ERROR; + StringAddChar(c); + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER_DOT; + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER_DOT: + StringAddChar(c); + switch (c) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER_WITH_DECIMAL; + break; + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER_E: + StringAddChar(c); + switch (c) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '+': + case '-': + state_ = State::GRPC_JSON_STATE_VALUE_NUMBER_EPM; + break; + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case State::GRPC_JSON_STATE_VALUE_NUMBER_EPM: + StringAddChar(c); + switch (c) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + break; + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case 
State::GRPC_JSON_STATE_VALUE_TRUE_R: + if (c != 'r') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_TRUE_U; + break; + + case State::GRPC_JSON_STATE_VALUE_TRUE_U: + if (c != 'u') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_TRUE_E; + break; + + case State::GRPC_JSON_STATE_VALUE_TRUE_E: + if (c != 'e') return Status::GRPC_JSON_PARSE_ERROR; + SetTrue(); + state_ = State::GRPC_JSON_STATE_VALUE_END; + break; + + case State::GRPC_JSON_STATE_VALUE_FALSE_A: + if (c != 'a') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_FALSE_L; + break; + + case State::GRPC_JSON_STATE_VALUE_FALSE_L: + if (c != 'l') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_FALSE_S; + break; + + case State::GRPC_JSON_STATE_VALUE_FALSE_S: + if (c != 's') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_FALSE_E; + break; + + case State::GRPC_JSON_STATE_VALUE_FALSE_E: + if (c != 'e') return Status::GRPC_JSON_PARSE_ERROR; + SetFalse(); + state_ = State::GRPC_JSON_STATE_VALUE_END; + break; + + case State::GRPC_JSON_STATE_VALUE_NULL_U: + if (c != 'u') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_NULL_L1; + break; + + case State::GRPC_JSON_STATE_VALUE_NULL_L1: + if (c != 'l') return Status::GRPC_JSON_PARSE_ERROR; + state_ = State::GRPC_JSON_STATE_VALUE_NULL_L2; + break; + + case State::GRPC_JSON_STATE_VALUE_NULL_L2: + if (c != 'l') return Status::GRPC_JSON_PARSE_ERROR; + SetNull(); + state_ = State::GRPC_JSON_STATE_VALUE_END; + break; + + /* All of the VALUE_END cases are handled in the specialized case + * above. */ + case State::GRPC_JSON_STATE_VALUE_END: + switch (c) { + case ',': + case '}': + case ']': + GPR_UNREACHABLE_CODE(return Status::GRPC_JSON_INTERNAL_ERROR); + break; + + default: + return Status::GRPC_JSON_PARSE_ERROR; + } + break; + + case State::GRPC_JSON_STATE_END: + return Status::GRPC_JSON_PARSE_ERROR; + } + } + } + + GPR_UNREACHABLE_CODE(return Status::GRPC_JSON_INTERNAL_ERROR); +} + +grpc_error_handle JsonReader::Parse(absl::string_view input, Json* output) { + JsonReader reader(input); + Status status = reader.Run(); + if (reader.truncated_errors_) { + reader.errors_.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "too many errors encountered during JSON parsing -- fix reported " + "errors and try again to see additional errors")); + } + if (status == Status::GRPC_JSON_INTERNAL_ERROR) { + reader.errors_.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "internal error in JSON parser at index ", reader.CurrentIndex()))); + } else if (status == Status::GRPC_JSON_PARSE_ERROR) { + reader.errors_.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("JSON parse error at index ", reader.CurrentIndex()))); + } + if (!reader.errors_.empty()) { + return GRPC_ERROR_CREATE_FROM_VECTOR("JSON parsing failed", + &reader.errors_); + } + *output = std::move(reader.root_value_); + return GRPC_ERROR_NONE; +} + +} // namespace + +Json Json::Parse(absl::string_view json_str, grpc_error_handle* error) { + Json value; + *error = JsonReader::Parse(json_str, &value); + return value; +} + +} // namespace grpc_core diff --git a/src/core/lib/json/json_util.cc b/src/core/lib/json/json_util.cc new file mode 100644 index 00000000..1c90aeb5 --- /dev/null +++ b/src/core/lib/json/json_util.cc @@ -0,0 +1,58 @@ +// +// +// Copyright 2020 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/lib/json/json_util.h" + +#include + +#include "src/core/lib/gpr/string.h" + +namespace grpc_core { + +bool ParseDurationFromJson(const Json& field, grpc_millis* duration) { + if (field.type() != Json::Type::STRING) return false; + size_t len = field.string_value().size(); + if (field.string_value()[len - 1] != 's') return false; + grpc_core::UniquePtr buf(gpr_strdup(field.string_value().c_str())); + *(buf.get() + len - 1) = '\0'; // Remove trailing 's'. + char* decimal_point = strchr(buf.get(), '.'); + int nanos = 0; + if (decimal_point != nullptr) { + *decimal_point = '\0'; + nanos = gpr_parse_nonnegative_int(decimal_point + 1); + if (nanos == -1) { + return false; + } + int num_digits = static_cast(strlen(decimal_point + 1)); + if (num_digits > 9) { // We don't accept greater precision than nanos. + return false; + } + for (int i = 0; i < (9 - num_digits); ++i) { + nanos *= 10; + } + } + int seconds = + decimal_point == buf.get() ? 0 : gpr_parse_nonnegative_int(buf.get()); + if (seconds == -1) return false; + *duration = seconds * GPR_MS_PER_SEC + nanos / GPR_NS_PER_MS; + return true; +} + +} // namespace grpc_core diff --git a/src/core/lib/json/json_writer.cc b/src/core/lib/json/json_writer.cc new file mode 100644 index 00000000..7522802a --- /dev/null +++ b/src/core/lib/json/json_writer.cc @@ -0,0 +1,335 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "absl/strings/string_view.h" + +#include "src/core/lib/json/json.h" + +namespace grpc_core { + +namespace { + +/* The idea of the writer is basically symmetrical of the reader. While the + * reader emits various calls to your code, the writer takes basically the + * same calls and emit json out of it. It doesn't try to make any check on + * the order of the calls you do on it. Meaning you can theorically force + * it to generate invalid json. + * + * Also, unlike the reader, the writer expects UTF-8 encoded input strings. + * These strings will be UTF-8 validated, and any invalid character will + * cut the conversion short, before any invalid UTF-8 sequence, thus forming + * a valid UTF-8 string overall. 
+ */ +class JsonWriter { + public: + static std::string Dump(const Json& value, int indent); + + private: + explicit JsonWriter(int indent) : indent_(indent) {} + + void OutputCheck(size_t needed); + void OutputChar(char c); + void OutputString(const absl::string_view str); + void OutputIndent(); + void ValueEnd(); + void EscapeUtf16(uint16_t utf16); + void EscapeString(const std::string& string); + void ContainerBegins(Json::Type type); + void ContainerEnds(Json::Type type); + void ObjectKey(const std::string& string); + void ValueRaw(const std::string& string); + void ValueString(const std::string& string); + + void DumpObject(const Json::Object& object); + void DumpArray(const Json::Array& array); + void DumpValue(const Json& value); + + int indent_; + int depth_ = 0; + bool container_empty_ = true; + bool got_key_ = false; + std::string output_; +}; + +/* This function checks if there's enough space left in the output buffer, + * and will enlarge it if necessary. We're only allocating chunks of 256 + * bytes at a time (or multiples thereof). + */ +void JsonWriter::OutputCheck(size_t needed) { + size_t free_space = output_.capacity() - output_.size(); + if (free_space >= needed) return; + needed -= free_space; + /* Round up by 256 bytes. */ + needed = (needed + 0xff) & ~0xffU; + output_.reserve(output_.capacity() + needed); +} + +void JsonWriter::OutputChar(char c) { + OutputCheck(1); + output_.push_back(c); +} + +void JsonWriter::OutputString(const absl::string_view str) { + OutputCheck(str.size()); + output_.append(str.data(), str.size()); +} + +void JsonWriter::OutputIndent() { + static const char spacesstr[] = + " " + " " + " " + " "; + unsigned spaces = static_cast(depth_ * indent_); + if (indent_ == 0) return; + if (got_key_) { + OutputChar(' '); + return; + } + while (spaces >= (sizeof(spacesstr) - 1)) { + OutputString(absl::string_view(spacesstr, sizeof(spacesstr) - 1)); + spaces -= static_cast(sizeof(spacesstr) - 1); + } + if (spaces == 0) return; + OutputString( + absl::string_view(spacesstr + sizeof(spacesstr) - 1 - spaces, spaces)); +} + +void JsonWriter::ValueEnd() { + if (container_empty_) { + container_empty_ = false; + if (indent_ == 0 || depth_ == 0) return; + OutputChar('\n'); + } else { + OutputChar(','); + if (indent_ == 0) return; + OutputChar('\n'); + } +} + +void JsonWriter::EscapeUtf16(uint16_t utf16) { + static const char hex[] = "0123456789abcdef"; + OutputString(absl::string_view("\\u", 2)); + OutputChar(hex[(utf16 >> 12) & 0x0f]); + OutputChar(hex[(utf16 >> 8) & 0x0f]); + OutputChar(hex[(utf16 >> 4) & 0x0f]); + OutputChar(hex[(utf16)&0x0f]); +} + +void JsonWriter::EscapeString(const std::string& string) { + OutputChar('"'); + for (size_t idx = 0; idx < string.size(); ++idx) { + uint8_t c = static_cast(string[idx]); + if (c == 0) { + break; + } else if (c >= 32 && c <= 126) { + if (c == '\\' || c == '"') OutputChar('\\'); + OutputChar(static_cast(c)); + } else if (c < 32 || c == 127) { + switch (c) { + case '\b': + OutputString(absl::string_view("\\b", 2)); + break; + case '\f': + OutputString(absl::string_view("\\f", 2)); + break; + case '\n': + OutputString(absl::string_view("\\n", 2)); + break; + case '\r': + OutputString(absl::string_view("\\r", 2)); + break; + case '\t': + OutputString(absl::string_view("\\t", 2)); + break; + default: + EscapeUtf16(c); + break; + } + } else { + uint32_t utf32 = 0; + int extra = 0; + int i; + int valid = 1; + if ((c & 0xe0) == 0xc0) { + utf32 = c & 0x1f; + extra = 1; + } else if ((c & 0xf0) == 0xe0) { + utf32 = c & 0x0f; + 
extra = 2; + } else if ((c & 0xf8) == 0xf0) { + utf32 = c & 0x07; + extra = 3; + } else { + break; + } + for (i = 0; i < extra; i++) { + utf32 <<= 6; + ++idx; + /* Breaks out and bail if we hit the end of the string. */ + if (idx == string.size()) { + valid = 0; + break; + } + c = static_cast(string[idx]); + /* Breaks out and bail on any invalid UTF-8 sequence, including \0. */ + if ((c & 0xc0) != 0x80) { + valid = 0; + break; + } + utf32 |= c & 0x3f; + } + if (!valid) break; + /* The range 0xd800 - 0xdfff is reserved by the surrogates ad vitam. + * Any other range is technically reserved for future usage, so if we + * don't want the software to break in the future, we have to allow + * anything else. The first non-unicode character is 0x110000. */ + if (((utf32 >= 0xd800) && (utf32 <= 0xdfff)) || (utf32 >= 0x110000)) { + break; + } + if (utf32 >= 0x10000) { + /* If utf32 contains a character that is above 0xffff, it needs to be + * broken down into a utf-16 surrogate pair. A surrogate pair is first + * a high surrogate, followed by a low surrogate. Each surrogate holds + * 10 bits of usable data, thus allowing a total of 20 bits of data. + * The high surrogate marker is 0xd800, while the low surrogate marker + * is 0xdc00. The low 10 bits of each will be the usable data. + * + * After re-combining the 20 bits of data, one has to add 0x10000 to + * the resulting value, in order to obtain the original character. + * This is obviously because the range 0x0000 - 0xffff can be written + * without any special trick. + * + * Since 0x10ffff is the highest allowed character, we're working in + * the range 0x00000 - 0xfffff after we decrement it by 0x10000. + * That range is exactly 20 bits. + */ + utf32 -= 0x10000; + EscapeUtf16(static_cast(0xd800 | (utf32 >> 10))); + EscapeUtf16(static_cast(0xdc00 | (utf32 & 0x3ff))); + } else { + EscapeUtf16(static_cast(utf32)); + } + } + } + OutputChar('"'); +} + +void JsonWriter::ContainerBegins(Json::Type type) { + if (!got_key_) ValueEnd(); + OutputIndent(); + OutputChar(type == Json::Type::OBJECT ? '{' : '['); + container_empty_ = true; + got_key_ = false; + depth_++; +} + +void JsonWriter::ContainerEnds(Json::Type type) { + if (indent_ && !container_empty_) OutputChar('\n'); + depth_--; + if (!container_empty_) OutputIndent(); + OutputChar(type == Json::Type::OBJECT ? 
'}' : ']'); + container_empty_ = false; + got_key_ = false; +} + +void JsonWriter::ObjectKey(const std::string& string) { + ValueEnd(); + OutputIndent(); + EscapeString(string); + OutputChar(':'); + got_key_ = true; +} + +void JsonWriter::ValueRaw(const std::string& string) { + if (!got_key_) ValueEnd(); + OutputIndent(); + OutputString(string); + got_key_ = false; +} + +void JsonWriter::ValueString(const std::string& string) { + if (!got_key_) ValueEnd(); + OutputIndent(); + EscapeString(string); + got_key_ = false; +} + +void JsonWriter::DumpObject(const Json::Object& object) { + ContainerBegins(Json::Type::OBJECT); + for (const auto& p : object) { + ObjectKey(p.first.data()); + DumpValue(p.second); + } + ContainerEnds(Json::Type::OBJECT); +} + +void JsonWriter::DumpArray(const Json::Array& array) { + ContainerBegins(Json::Type::ARRAY); + for (const auto& v : array) { + DumpValue(v); + } + ContainerEnds(Json::Type::ARRAY); +} + +void JsonWriter::DumpValue(const Json& value) { + switch (value.type()) { + case Json::Type::OBJECT: + DumpObject(value.object_value()); + break; + case Json::Type::ARRAY: + DumpArray(value.array_value()); + break; + case Json::Type::STRING: + ValueString(value.string_value()); + break; + case Json::Type::NUMBER: + ValueRaw(value.string_value()); + break; + case Json::Type::JSON_TRUE: + ValueRaw(std::string("true", 4)); + break; + case Json::Type::JSON_FALSE: + ValueRaw(std::string("false", 5)); + break; + case Json::Type::JSON_NULL: + ValueRaw(std::string("null", 4)); + break; + default: + GPR_UNREACHABLE_CODE(abort()); + } +} + +std::string JsonWriter::Dump(const Json& value, int indent) { + JsonWriter writer(indent); + writer.DumpValue(value); + return std::move(writer.output_); +} + +} // namespace + +std::string Json::Dump(int indent) const { + return JsonWriter::Dump(*this, indent); +} + +} // namespace grpc_core diff --git a/src/core/lib/matchers/matchers.cc b/src/core/lib/matchers/matchers.cc new file mode 100644 index 00000000..b266176f --- /dev/null +++ b/src/core/lib/matchers/matchers.cc @@ -0,0 +1,327 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
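+
+// Implements StringMatcher (exact / prefix / suffix / contains / safe-regex
+// matching on strings) and HeaderMatcher (string, numeric-range, and presence
+// matching on header values).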
+ +#include + +#include "src/core/lib/matchers/matchers.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" + +namespace grpc_core { + +// +// StringMatcher +// + +absl::StatusOr StringMatcher::Create(Type type, + absl::string_view matcher, + bool case_sensitive) { + if (type == Type::kSafeRegex) { + auto regex_matcher = absl::make_unique(std::string(matcher)); + if (!regex_matcher->ok()) { + return absl::InvalidArgumentError( + "Invalid regex string specified in matcher."); + } + return StringMatcher(std::move(regex_matcher)); + } else { + return StringMatcher(type, matcher, case_sensitive); + } +} + +StringMatcher::StringMatcher(Type type, absl::string_view matcher, + bool case_sensitive) + : type_(type), string_matcher_(matcher), case_sensitive_(case_sensitive) {} + +StringMatcher::StringMatcher(std::unique_ptr regex_matcher) + : type_(Type::kSafeRegex), regex_matcher_(std::move(regex_matcher)) {} + +StringMatcher::StringMatcher(const StringMatcher& other) + : type_(other.type_), case_sensitive_(other.case_sensitive_) { + if (type_ == Type::kSafeRegex) { + regex_matcher_ = absl::make_unique(other.regex_matcher_->pattern()); + } else { + string_matcher_ = other.string_matcher_; + } +} + +StringMatcher& StringMatcher::operator=(const StringMatcher& other) { + type_ = other.type_; + if (type_ == Type::kSafeRegex) { + regex_matcher_ = absl::make_unique(other.regex_matcher_->pattern()); + } else { + string_matcher_ = other.string_matcher_; + } + case_sensitive_ = other.case_sensitive_; + return *this; +} + +StringMatcher::StringMatcher(StringMatcher&& other) noexcept + : type_(other.type_), case_sensitive_(other.case_sensitive_) { + if (type_ == Type::kSafeRegex) { + regex_matcher_ = std::move(other.regex_matcher_); + } else { + string_matcher_ = std::move(other.string_matcher_); + } +} + +StringMatcher& StringMatcher::operator=(StringMatcher&& other) noexcept { + type_ = other.type_; + if (type_ == Type::kSafeRegex) { + regex_matcher_ = std::move(other.regex_matcher_); + } else { + string_matcher_ = std::move(other.string_matcher_); + } + case_sensitive_ = other.case_sensitive_; + return *this; +} + +bool StringMatcher::operator==(const StringMatcher& other) const { + if (type_ != other.type_ || case_sensitive_ != other.case_sensitive_) { + return false; + } + if (type_ == Type::kSafeRegex) { + return regex_matcher_->pattern() == other.regex_matcher_->pattern(); + } else { + return string_matcher_ == other.string_matcher_; + } +} + +bool StringMatcher::Match(absl::string_view value) const { + switch (type_) { + case Type::kExact: + return case_sensitive_ ? value == string_matcher_ + : absl::EqualsIgnoreCase(value, string_matcher_); + case StringMatcher::Type::kPrefix: + return case_sensitive_ + ? absl::StartsWith(value, string_matcher_) + : absl::StartsWithIgnoreCase(value, string_matcher_); + case StringMatcher::Type::kSuffix: + return case_sensitive_ ? absl::EndsWith(value, string_matcher_) + : absl::EndsWithIgnoreCase(value, string_matcher_); + case StringMatcher::Type::kContains: + return case_sensitive_ + ? 
absl::StrContains(value, string_matcher_) + : absl::StrContains(absl::AsciiStrToLower(value), + absl::AsciiStrToLower(string_matcher_)); + case StringMatcher::Type::kSafeRegex: + return RE2::FullMatch(std::string(value), *regex_matcher_); + default: + return false; + } +} + +std::string StringMatcher::ToString() const { + switch (type_) { + case Type::kExact: + return absl::StrFormat("StringMatcher{exact=%s%s}", string_matcher_, + case_sensitive_ ? "" : ", case_sensitive=false"); + case Type::kPrefix: + return absl::StrFormat("StringMatcher{prefix=%s%s}", string_matcher_, + case_sensitive_ ? "" : ", case_sensitive=false"); + case Type::kSuffix: + return absl::StrFormat("StringMatcher{suffix=%s%s}", string_matcher_, + case_sensitive_ ? "" : ", case_sensitive=false"); + case Type::kContains: + return absl::StrFormat("StringMatcher{contains=%s%s}", string_matcher_, + case_sensitive_ ? "" : ", case_sensitive=false"); + case Type::kSafeRegex: + return absl::StrFormat("StringMatcher{safe_regex=%s}", + regex_matcher_->pattern()); + default: + return ""; + } +} + +// +// HeaderMatcher +// + +absl::StatusOr HeaderMatcher::Create( + absl::string_view name, Type type, absl::string_view matcher, + int64_t range_start, int64_t range_end, bool present_match, + bool invert_match) { + if (static_cast(type) < 5) { + // Only for EXACT, PREFIX, SUFFIX, SAFE_REGEX and CONTAINS. + absl::StatusOr string_matcher = + StringMatcher::Create(static_cast(type), matcher, + /*case_sensitive=*/true); + if (!string_matcher.ok()) { + return string_matcher.status(); + } + return HeaderMatcher(name, type, std::move(string_matcher.value()), + invert_match); + } else if (type == Type::kRange) { + if (range_start > range_end) { + return absl::InvalidArgumentError( + "Invalid range specifier specified: end cannot be smaller than " + "start."); + } + return HeaderMatcher(name, range_start, range_end, invert_match); + } else { + return HeaderMatcher(name, present_match, invert_match); + } +} + +HeaderMatcher::HeaderMatcher(absl::string_view name, Type type, + StringMatcher string_matcher, bool invert_match) + : name_(name), + type_(type), + matcher_(std::move(string_matcher)), + invert_match_(invert_match) {} + +HeaderMatcher::HeaderMatcher(absl::string_view name, int64_t range_start, + int64_t range_end, bool invert_match) + : name_(name), + type_(Type::kRange), + range_start_(range_start), + range_end_(range_end), + invert_match_(invert_match) {} + +HeaderMatcher::HeaderMatcher(absl::string_view name, bool present_match, + bool invert_match) + : name_(name), + type_(Type::kPresent), + present_match_(present_match), + invert_match_(invert_match) {} + +HeaderMatcher::HeaderMatcher(const HeaderMatcher& other) + : name_(other.name_), + type_(other.type_), + invert_match_(other.invert_match_) { + switch (type_) { + case Type::kRange: + range_start_ = other.range_start_; + range_end_ = other.range_end_; + break; + case Type::kPresent: + present_match_ = other.present_match_; + break; + default: + matcher_ = other.matcher_; + } +} + +HeaderMatcher& HeaderMatcher::operator=(const HeaderMatcher& other) { + name_ = other.name_; + type_ = other.type_; + invert_match_ = other.invert_match_; + switch (type_) { + case Type::kRange: + range_start_ = other.range_start_; + range_end_ = other.range_end_; + break; + case Type::kPresent: + present_match_ = other.present_match_; + break; + default: + matcher_ = other.matcher_; + } + return *this; +} + +HeaderMatcher::HeaderMatcher(HeaderMatcher&& other) noexcept + : name_(std::move(other.name_)), + 
type_(other.type_), + invert_match_(other.invert_match_) { + switch (type_) { + case Type::kRange: + range_start_ = other.range_start_; + range_end_ = other.range_end_; + break; + case Type::kPresent: + present_match_ = other.present_match_; + break; + default: + matcher_ = std::move(other.matcher_); + } +} + +HeaderMatcher& HeaderMatcher::operator=(HeaderMatcher&& other) noexcept { + name_ = std::move(other.name_); + type_ = other.type_; + invert_match_ = other.invert_match_; + switch (type_) { + case Type::kRange: + range_start_ = other.range_start_; + range_end_ = other.range_end_; + break; + case Type::kPresent: + present_match_ = other.present_match_; + break; + default: + matcher_ = std::move(other.matcher_); + } + return *this; +} + +bool HeaderMatcher::operator==(const HeaderMatcher& other) const { + if (name_ != other.name_) return false; + if (type_ != other.type_) return false; + if (invert_match_ != other.invert_match_) return false; + switch (type_) { + case Type::kRange: + return range_start_ == other.range_start_ && + range_end_ == other.range_end_; + case Type::kPresent: + return present_match_ == other.present_match_; + default: + return matcher_ == other.matcher_; + } +} + +bool HeaderMatcher::Match( + const absl::optional& value) const { + bool match; + if (type_ == Type::kPresent) { + match = value.has_value() == present_match_; + } else if (!value.has_value()) { + // All other types fail to match if field is not present. + match = false; + } else if (type_ == Type::kRange) { + int64_t int_value; + match = absl::SimpleAtoi(value.value(), &int_value) && + int_value >= range_start_ && int_value < range_end_; + } else { + match = matcher_.Match(value.value()); + } + return match != invert_match_; +} + +std::string HeaderMatcher::ToString() const { + switch (type_) { + case Type::kRange: + return absl::StrFormat("HeaderMatcher{%s %srange=[%d, %d]}", name_, + invert_match_ ? "not " : "", range_start_, + range_end_); + case Type::kPresent: + return absl::StrFormat("HeaderMatcher{%s %spresent=%s}", name_, + invert_match_ ? "not " : "", + present_match_ ? "true" : "false"); + case Type::kExact: + case Type::kPrefix: + case Type::kSuffix: + case Type::kSafeRegex: + case Type::kContains: + return absl::StrFormat("HeaderMatcher{%s %s%s}", name_, + invert_match_ ? "not " : "", matcher_.ToString()); + default: + return ""; + } +} + +} // namespace grpc_core diff --git a/src/core/lib/profiling/basic_timers.cc b/src/core/lib/profiling/basic_timers.cc new file mode 100644 index 00000000..f98624bf --- /dev/null +++ b/src/core/lib/profiling/basic_timers.cc @@ -0,0 +1,295 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/profiling/timers.h" + +#ifdef GRPC_BASIC_PROFILER + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/gprpp/global_config.h" +#include "src/core/lib/profiling/timers.h" + +typedef enum { BEGIN = '{', END = '}', MARK = '.' } marker_type; + +typedef struct gpr_timer_entry { + gpr_timespec tm; + const char* tagstr; + const char* file; + short line; + char type; + uint8_t important; + int thd; +} gpr_timer_entry; + +#define MAX_COUNT 1000000 + +typedef struct gpr_timer_log { + size_t num_entries; + struct gpr_timer_log* next; + struct gpr_timer_log* prev; + gpr_timer_entry log[MAX_COUNT]; +} gpr_timer_log; + +typedef struct gpr_timer_log_list { + gpr_timer_log* head; + /* valid iff head!=NULL */ + gpr_timer_log* tail; +} gpr_timer_log_list; + +static GPR_THREAD_LOCAL(gpr_timer_log*) g_thread_log; +static gpr_once g_once_init = GPR_ONCE_INIT; +static FILE* output_file; +static const char* output_filename_or_null = NULL; +static pthread_mutex_t g_mu; +static pthread_cond_t g_cv; +static gpr_timer_log_list g_in_progress_logs; +static gpr_timer_log_list g_done_logs; +static int g_shutdown; +static pthread_t g_writing_thread; +static GPR_THREAD_LOCAL(int) g_thread_id; +static int g_next_thread_id; +static int g_writing_enabled = 1; + +GPR_GLOBAL_CONFIG_DEFINE_STRING(grpc_latency_trace, "latency_trace.txt", + "Output file name for latency trace") + +static const char* output_filename() { + if (output_filename_or_null == NULL) { + grpc_core::UniquePtr value = + GPR_GLOBAL_CONFIG_GET(grpc_latency_trace); + if (strlen(value.get()) > 0) { + output_filename_or_null = value.release(); + } else { + output_filename_or_null = "latency_trace.txt"; + } + } + return output_filename_or_null; +} + +static int timer_log_push_back(gpr_timer_log_list* list, gpr_timer_log* log) { + if (list->head == NULL) { + list->head = list->tail = log; + log->next = log->prev = NULL; + return 1; + } else { + log->prev = list->tail; + log->next = NULL; + list->tail->next = log; + list->tail = log; + return 0; + } +} + +static gpr_timer_log* timer_log_pop_front(gpr_timer_log_list* list) { + gpr_timer_log* out = list->head; + if (out != NULL) { + list->head = out->next; + if (list->head != NULL) { + list->head->prev = NULL; + } else { + list->tail = NULL; + } + } + return out; +} + +static void timer_log_remove(gpr_timer_log_list* list, gpr_timer_log* log) { + if (log->prev == NULL) { + list->head = log->next; + if (list->head != NULL) { + list->head->prev = NULL; + } + } else { + log->prev->next = log->next; + } + if (log->next == NULL) { + list->tail = log->prev; + if (list->tail != NULL) { + list->tail->next = NULL; + } + } else { + log->next->prev = log->prev; + } +} + +static void write_log(gpr_timer_log* log) { + size_t i; + if (output_file == NULL) { + output_file = fopen(output_filename(), "w"); + } + for (i = 0; i < log->num_entries; i++) { + gpr_timer_entry* entry = &(log->log[i]); + if (gpr_time_cmp(entry->tm, gpr_time_0(entry->tm.clock_type)) < 0) { + entry->tm = gpr_time_0(entry->tm.clock_type); + } + fprintf(output_file, + "{\"t\": %" PRId64 + ".%09d, \"thd\": \"%d\", \"type\": \"%c\", \"tag\": " + "\"%s\", \"file\": \"%s\", \"line\": %d, \"imp\": %d}\n", + entry->tm.tv_sec, entry->tm.tv_nsec, entry->thd, entry->type, + entry->tagstr, entry->file, entry->line, entry->important); + } +} + +static void* writing_thread(void* unused) { + gpr_timer_log* log; + pthread_mutex_lock(&g_mu); 
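+  // Writer thread: block until a completed log buffer is handed off, write it
+  // out while the lock is released, and exit once shutdown is signalled (any
+  // remaining logs are flushed by finish_writing()).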
+ for (;;) { + while ((log = timer_log_pop_front(&g_done_logs)) == NULL && !g_shutdown) { + pthread_cond_wait(&g_cv, &g_mu); + } + if (log != NULL) { + pthread_mutex_unlock(&g_mu); + write_log(log); + free(log); + pthread_mutex_lock(&g_mu); + } + if (g_shutdown) { + pthread_mutex_unlock(&g_mu); + return NULL; + } + } +} + +static void flush_logs(gpr_timer_log_list* list) { + gpr_timer_log* log; + while ((log = timer_log_pop_front(list)) != NULL) { + write_log(log); + free(log); + } +} + +static void finish_writing(void) { + pthread_mutex_lock(&g_mu); + g_shutdown = 1; + pthread_cond_signal(&g_cv); + pthread_mutex_unlock(&g_mu); + pthread_join(g_writing_thread, NULL); + + gpr_log(GPR_INFO, "flushing logs"); + + pthread_mutex_lock(&g_mu); + flush_logs(&g_done_logs); + flush_logs(&g_in_progress_logs); + pthread_mutex_unlock(&g_mu); + + if (output_file) { + fclose(output_file); + } +} + +void gpr_timers_set_log_filename(const char* filename) { + output_filename_or_null = filename; +} + +static void init_output() { + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE); + pthread_create(&g_writing_thread, &attr, &writing_thread, NULL); + pthread_attr_destroy(&attr); + + atexit(finish_writing); +} + +static void rotate_log() { + /* Using malloc here, as this code could end up being called by gpr_malloc */ + gpr_timer_log* log = static_cast(malloc(sizeof(*log))); + gpr_once_init(&g_once_init, init_output); + log->num_entries = 0; + pthread_mutex_lock(&g_mu); + if (g_thread_log != NULL) { + timer_log_remove(&g_in_progress_logs, g_thread_log); + if (timer_log_push_back(&g_done_logs, g_thread_log)) { + pthread_cond_signal(&g_cv); + } + } else { + g_thread_id = g_next_thread_id++; + } + timer_log_push_back(&g_in_progress_logs, log); + pthread_mutex_unlock(&g_mu); + g_thread_log = log; +} + +static void gpr_timers_log_add(const char* tagstr, marker_type type, + int important, const char* file, int line) { + gpr_timer_entry* entry; + + if (!g_writing_enabled) { + return; + } + + if (g_thread_log == NULL || g_thread_log->num_entries == MAX_COUNT) { + rotate_log(); + } + + entry = &g_thread_log->log[g_thread_log->num_entries++]; + + entry->tm = gpr_now(GPR_CLOCK_PRECISE); + entry->tagstr = tagstr; + entry->type = type; + entry->file = file; + entry->line = (short)line; + entry->important = important != 0; + entry->thd = g_thread_id; +} + +/* Latency profiler API implementation. */ +void gpr_timer_add_mark(const char* tagstr, int important, const char* file, + int line) { + gpr_timers_log_add(tagstr, MARK, important, file, line); +} + +void gpr_timer_begin(const char* tagstr, int important, const char* file, + int line) { + gpr_timers_log_add(tagstr, BEGIN, important, file, line); +} + +void gpr_timer_end(const char* tagstr, int important, const char* file, + int line) { + gpr_timers_log_add(tagstr, END, important, file, line); +} + +void gpr_timer_set_enabled(int enabled) { g_writing_enabled = enabled; } + +/* Basic profiler specific API functions. 
*/ +void gpr_timers_global_init(void) {} + +void gpr_timers_global_destroy(void) {} + +#else /* !GRPC_BASIC_PROFILER */ +void gpr_timers_global_init(void) {} + +void gpr_timers_global_destroy(void) {} + +void gpr_timers_set_log_filename(const char* /*filename*/) {} + +void gpr_timer_set_enabled(int /*enabled*/) {} +#endif /* GRPC_BASIC_PROFILER */ diff --git a/src/core/lib/profiling/stap_timers.cc b/src/core/lib/profiling/stap_timers.cc new file mode 100644 index 00000000..a00dbaf2 --- /dev/null +++ b/src/core/lib/profiling/stap_timers.cc @@ -0,0 +1,50 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GRPC_STAP_PROFILER + +#include + +#include "src/core/lib/profiling/timers.h" +/* Generated from src/core/profiling/stap_probes.d */ +#include "src/core/lib/profiling/stap_probes.h" + +/* Latency profiler API implementation. */ +void gpr_timer_add_mark(int tag, const char* tagstr, void* id, const char* file, + int line) { + _STAP_ADD_MARK(tag); +} + +void gpr_timer_add_important_mark(int tag, const char* tagstr, void* id, + const char* file, int line) { + _STAP_ADD_IMPORTANT_MARK(tag); +} + +void gpr_timer_begin(int tag, const char* tagstr, void* id, const char* file, + int line) { + _STAP_TIMING_NS_BEGIN(tag); +} + +void gpr_timer_end(int tag, const char* tagstr, void* id, const char* file, + int line) { + _STAP_TIMING_NS_END(tag); +} + +#endif /* GRPC_STAP_PROFILER */ diff --git a/src/core/lib/promise/activity.cc b/src/core/lib/promise/activity.cc new file mode 100644 index 00000000..979d0859 --- /dev/null +++ b/src/core/lib/promise/activity.cc @@ -0,0 +1,114 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/promise/activity.h" + +#include "src/core/lib/gprpp/atomic_utils.h" + +namespace grpc_core { + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +ABSL_CONST_INIT GPR_THREAD_LOCAL(Activity*) Activity::g_current_activity_ = + nullptr; +Waker::Unwakeable Waker::unwakeable_; + +/////////////////////////////////////////////////////////////////////////////// +// HELPER TYPES + +// Weak handle to an Activity. +// Handle can persist while Activity goes away. +class Activity::Handle final : public Wakeable { + public: + explicit Handle(Activity* activity) : activity_(activity) {} + + // Ref the Handle (not the activity). 
+ void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); } + + // Activity is going away... drop its reference and sever the connection back. + void DropActivity() ABSL_LOCKS_EXCLUDED(mu_) { + mu_.Lock(); + GPR_ASSERT(activity_ != nullptr); + activity_ = nullptr; + mu_.Unlock(); + Unref(); + } + + // Activity needs to wake up (if it still exists!) - wake it up, and drop the + // ref that was kept for this handle. + void Wakeup() override ABSL_LOCKS_EXCLUDED(mu_) { + mu_.Lock(); + // Note that activity refcount can drop to zero, but we could win the lock + // against DropActivity, so we need to only increase activities refcount if + // it is non-zero. + if (activity_ && activity_->RefIfNonzero()) { + Activity* activity = activity_; + mu_.Unlock(); + // Activity still exists and we have a reference: wake it up, which will + // drop the ref. + activity->Wakeup(); + } else { + // Could not get the activity - it's either gone or going. No need to wake + // it up! + mu_.Unlock(); + } + // Drop the ref to the handle (we have one ref = one wakeup semantics). + Unref(); + } + + void Drop() override { Unref(); } + + private: + // Unref the Handle (not the activity). + void Unref() { + if (1 == refs_.fetch_sub(1, std::memory_order_acq_rel)) { + delete this; + } + } + + // Two initial refs: one for the waiter that caused instantiation, one for the + // activity. + std::atomic refs_{2}; + Mutex mu_ ABSL_ACQUIRED_AFTER(activity_->mu_); + Activity* activity_ ABSL_GUARDED_BY(mu_); +}; + +/////////////////////////////////////////////////////////////////////////////// +// ACTIVITY IMPLEMENTATION + +bool Activity::RefIfNonzero() { return IncrementIfNonzero(&refs_); } + +Activity::Handle* Activity::RefHandle() { + if (handle_ == nullptr) { + // No handle created yet - construct it and return it. + handle_ = new Handle(this); + return handle_; + } else { + // Already had to create a handle, ref & return it. + handle_->Ref(); + return handle_; + } +} + +void Activity::DropHandle() { + handle_->DropActivity(); + handle_ = nullptr; +} + +Waker Activity::MakeNonOwningWaker() { return Waker(RefHandle()); } + +} // namespace grpc_core diff --git a/src/core/lib/resource_quota/memory_quota.cc b/src/core/lib/resource_quota/memory_quota.cc new file mode 100644 index 00000000..c72c09e9 --- /dev/null +++ b/src/core/lib/resource_quota/memory_quota.cc @@ -0,0 +1,408 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/resource_quota/memory_quota.h" + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/promise/exec_ctx_wakeup_scheduler.h" +#include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/race.h" +#include "src/core/lib/promise/seq.h" + +namespace grpc_core { + +// Maximum number of bytes an allocator will request from a quota in one step. +// Larger allocations than this will require multiple allocation requests. 
+static constexpr size_t kMaxReplenishBytes = 1024 * 1024; + +// Minimum number of bytes an allocator will request from a quota in one step. +static constexpr size_t kMinReplenishBytes = 4096; + +// +// Reclaimer +// + +ReclamationSweep::~ReclamationSweep() { + if (memory_quota_ != nullptr) { + memory_quota_->FinishReclamation(sweep_token_); + } +} + +// +// ReclaimerQueue +// + +const ReclaimerQueue::Index ReclaimerQueue::kInvalidIndex; + +void ReclaimerQueue::Insert( + std::shared_ptr allocator, + ReclamationFunction reclaimer, Index* index) { + ReleasableMutexLock lock(&mu_); + if (*index < entries_.size() && entries_[*index].allocator == allocator) { + entries_[*index].reclaimer.swap(reclaimer); + lock.Release(); + reclaimer({}); + return; + } + if (free_entries_.empty()) { + *index = entries_.size(); + entries_.emplace_back(std::move(allocator), std::move(reclaimer)); + } else { + *index = free_entries_.back(); + free_entries_.pop_back(); + Entry& entry = entries_[*index]; + entry.allocator = std::move(allocator); + entry.reclaimer = std::move(reclaimer); + } + if (queue_.empty()) waker_.Wakeup(); + queue_.push(*index); +} + +ReclamationFunction ReclaimerQueue::Cancel( + Index index, EventEngineMemoryAllocatorImpl* allocator) { + MutexLock lock(&mu_); + if (index >= entries_.size()) return nullptr; + Entry& entry = entries_[index]; + if (entry.allocator.get() != allocator) return {}; + entry.allocator.reset(); + return std::move(entry.reclaimer); +} + +Poll ReclaimerQueue::PollNext() { + MutexLock lock(&mu_); + while (true) { + if (queue_.empty()) { + waker_ = Activity::current()->MakeNonOwningWaker(); + return Pending{}; + } + Index index = queue_.front(); + queue_.pop(); + free_entries_.push_back(index); + Entry& entry = entries_[index]; + if (entry.allocator != nullptr) { + entry.allocator.reset(); + return std::move(entry.reclaimer); + } + } +} + +// +// GrpcMemoryAllocatorImpl +// + +GrpcMemoryAllocatorImpl::GrpcMemoryAllocatorImpl( + std::shared_ptr memory_quota) + : memory_quota_(memory_quota) { + memory_quota_->Take(taken_bytes_); +} + +GrpcMemoryAllocatorImpl::~GrpcMemoryAllocatorImpl() { + GPR_ASSERT(free_bytes_.load(std::memory_order_acquire) + + sizeof(GrpcMemoryAllocatorImpl) == + taken_bytes_); + memory_quota_->Return(taken_bytes_); +} + +void GrpcMemoryAllocatorImpl::Shutdown() { + std::shared_ptr memory_quota; + ReclaimerQueue::Index reclamation_indices[kNumReclamationPasses]; + { + MutexLock lock(&memory_quota_mu_); + GPR_ASSERT(!shutdown_); + shutdown_ = true; + memory_quota = memory_quota_; + for (size_t i = 0; i < kNumReclamationPasses; i++) { + reclamation_indices[i] = absl::exchange(reclamation_indices_[i], + ReclaimerQueue::kInvalidIndex); + } + } + for (size_t i = 0; i < kNumReclamationPasses; i++) { + auto fn = memory_quota->CancelReclaimer(i, reclamation_indices[i], this); + if (fn != nullptr) fn({}); + } +} + +size_t GrpcMemoryAllocatorImpl::Reserve(MemoryRequest request) { + // Validate request - performed here so we don't bloat the generated code with + // inlined asserts. + GPR_ASSERT(request.min() <= request.max()); + GPR_ASSERT(request.max() <= MemoryRequest::max_allowed_size()); + while (true) { + // Attempt to reserve memory from our pool. + auto reservation = TryReserve(request); + if (reservation.has_value()) return *reservation; + // If that failed, grab more from the quota and retry. + Replenish(); + } +} + +absl::optional GrpcMemoryAllocatorImpl::TryReserve( + MemoryRequest request) { + // How much memory should we request? 
(see the scaling below) + size_t scaled_size_over_min = request.max() - request.min(); + // Scale the request down according to memory pressure if we have that + // flexibility. + if (scaled_size_over_min != 0) { + double pressure; + { + MutexLock lock(&memory_quota_mu_); + pressure = memory_quota_->InstantaneousPressure(); + } + // Reduce allocation size proportional to the pressure > 80% usage. + if (pressure > 0.8) { + scaled_size_over_min = + std::min(scaled_size_over_min, + static_cast((request.max() - request.min()) * + (1.0 - pressure) / 0.2)); + } + } + + // How much do we want to reserve? + const size_t reserve = request.min() + scaled_size_over_min; + // See how many bytes are available. + size_t available = free_bytes_.load(std::memory_order_acquire); + while (true) { + // Does the current free pool satisfy the request? + if (available < reserve) { + return {}; + } + // Try to reserve the requested amount. + // If the amount of free memory changed through this loop, then available + // will be set to the new value and we'll repeat. + if (free_bytes_.compare_exchange_weak(available, available - reserve, + std::memory_order_acq_rel, + std::memory_order_acquire)) { + return reserve; + } + } +} + +void GrpcMemoryAllocatorImpl::Replenish() { + MutexLock lock(&memory_quota_mu_); + GPR_ASSERT(!shutdown_); + // Attempt a fairly low rate exponential growth request size, bounded between + // some reasonable limits declared at top of file. + auto amount = Clamp(taken_bytes_ / 3, kMinReplenishBytes, kMaxReplenishBytes); + // Take the requested amount from the quota. + memory_quota_->Take(amount); + // Record that we've taken it. + taken_bytes_ += amount; + // Add the taken amount to the free pool. + free_bytes_.fetch_add(amount, std::memory_order_acq_rel); + // See if we can add ourselves as a reclaimer. + MaybeRegisterReclaimerLocked(); +} + +void GrpcMemoryAllocatorImpl::MaybeRegisterReclaimer() { + MutexLock lock(&memory_quota_mu_); + MaybeRegisterReclaimerLocked(); +} + +void GrpcMemoryAllocatorImpl::MaybeRegisterReclaimerLocked() { + // If the reclaimer is already registered, then there's nothing to do. + if (reclamation_indices_[0] != ReclaimerQueue::kInvalidIndex) return; + if (shutdown_) return; + // Grab references to the things we'll need + auto self = shared_from_this(); + memory_quota_->InsertReclaimer( + 0, self, + [self](absl::optional sweep) { + if (!sweep.has_value()) return; + auto* p = static_cast(self.get()); + MutexLock lock(&p->memory_quota_mu_); + // Figure out how many bytes we can return to the quota. + size_t return_bytes = + p->free_bytes_.exchange(0, std::memory_order_acq_rel); + if (return_bytes == 0) return; + // Subtract that from our outstanding balance. + p->taken_bytes_ -= return_bytes; + // And return them to the quota. + p->memory_quota_->Return(return_bytes); + }, + &reclamation_indices_[0]); +} + +void GrpcMemoryAllocatorImpl::Rebind( + std::shared_ptr memory_quota) { + MutexLock lock(&memory_quota_mu_); + GPR_ASSERT(!shutdown_); + if (memory_quota_ == memory_quota) return; + // Return memory to the original memory quota. + memory_quota_->Return(taken_bytes_); + // Fetch back any reclaimers that are queued. + ReclamationFunction reclaimers[kNumReclamationPasses]; + for (size_t i = 0; i < kNumReclamationPasses; i++) { + reclaimers[i] = + memory_quota_->CancelReclaimer(i, reclamation_indices_[i], this); + } + // Switch to the new memory quota, leaving the old one in memory_quota so that + // when we unref it, we are outside of lock. 
+ memory_quota_.swap(memory_quota); + // Drop our freed memory down to zero, to avoid needing to ask the new + // quota for memory we're not currently using. + taken_bytes_ -= free_bytes_.exchange(0, std::memory_order_acq_rel); + // And let the new quota know how much we're already using. + memory_quota_->Take(taken_bytes_); + // Reinsert active reclaimers. + for (size_t i = 0; i < kNumReclamationPasses; i++) { + if (reclaimers[i] == nullptr) continue; + memory_quota_->InsertReclaimer(i, shared_from_this(), + std::move(reclaimers[i]), + &reclamation_indices_[i]); + } +} + +void GrpcMemoryAllocatorImpl::PostReclaimer(ReclamationPass pass, + ReclamationFunction fn) { + MutexLock lock(&memory_quota_mu_); + GPR_ASSERT(!shutdown_); + auto pass_num = static_cast(pass); + memory_quota_->InsertReclaimer(pass_num, shared_from_this(), std::move(fn), + &reclamation_indices_[pass_num]); +} + +// +// MemoryOwner +// + +void MemoryOwner::Rebind(MemoryQuota* quota) { + static_cast(allocator_.get_internal_impl_ptr()) + ->Rebind(quota->memory_quota_); +} + +// +// BasicMemoryQuota +// + +class BasicMemoryQuota::WaitForSweepPromise { + public: + WaitForSweepPromise(std::shared_ptr memory_quota, + uint64_t token) + : memory_quota_(std::move(memory_quota)), token_(token) {} + + struct Empty {}; + Poll operator()() { + if (memory_quota_->reclamation_counter_.load(std::memory_order_relaxed) != + token_) { + return Empty{}; + } else { + return Pending{}; + } + } + + private: + std::shared_ptr memory_quota_; + uint64_t token_; +}; + +void BasicMemoryQuota::Start() { + auto self = shared_from_this(); + + // Reclamation loop: + // basically, wait until we are in overcommit (free_bytes_ < 0), and then: + // while (free_bytes_ < 0) reclaim_memory() + // ... and repeat + auto reclamation_loop = Loop(Seq( + [self]() -> Poll { + // If there's free memory we no longer need to reclaim memory! + if (self->free_bytes_.load(std::memory_order_acquire) > 0) { + return Pending{}; + } + return 0; + }, + [self]() { + // Race biases to the first thing that completes... so this will + // choose the highest priority/least destructive thing to do that's + // available. + return Race(self->reclaimers_[0].Next(), self->reclaimers_[1].Next(), + self->reclaimers_[2].Next(), self->reclaimers_[3].Next()); + }, + [self](ReclamationFunction reclaimer) { + // One of the reclaimer queues gave us a way to get back memory. + // Call the reclaimer with a token that contains enough to wake us + // up again. + const uint64_t token = + self->reclamation_counter_.fetch_add(1, std::memory_order_relaxed) + + 1; + reclaimer(ReclamationSweep(self, token)); + // Return a promise that will wait for our barrier. This will be + // awoken by the token above being destroyed. So, once that token is + // destroyed, we'll be able to proceed. + return WaitForSweepPromise(self, token); + }, + []() -> LoopCtl { + // Continue the loop! + return Continue{}; + })); + + reclaimer_activity_ = + MakeActivity(std::move(reclamation_loop), ExecCtxWakeupScheduler(), + [](absl::Status status) { + GPR_ASSERT(status.code() == absl::StatusCode::kCancelled); + }); +} + +void BasicMemoryQuota::Stop() { reclaimer_activity_.reset(); } + +void BasicMemoryQuota::SetSize(size_t new_size) { + size_t old_size = quota_size_.exchange(new_size, std::memory_order_relaxed); + if (old_size < new_size) { + // We're growing the quota. + Return(new_size - old_size); + } else { + // We're shrinking the quota. 
+ Take(old_size - new_size); + } +} + +void BasicMemoryQuota::Take(size_t amount) { + // If there's a request for nothing, then do nothing! + if (amount == 0) return; + GPR_DEBUG_ASSERT(amount <= std::numeric_limits::max()); + // Grab memory from the quota. + auto prior = free_bytes_.fetch_sub(amount, std::memory_order_acq_rel); + // If we push into overcommit, awake the reclaimer. + if (prior >= 0 && prior < static_cast(amount)) { + if (reclaimer_activity_ != nullptr) reclaimer_activity_->ForceWakeup(); + } +} + +void BasicMemoryQuota::FinishReclamation(uint64_t token) { + uint64_t current = reclamation_counter_.load(std::memory_order_relaxed); + if (current != token) return; + if (reclamation_counter_.compare_exchange_strong(current, current + 1, + std::memory_order_relaxed, + std::memory_order_relaxed)) { + if (reclaimer_activity_ != nullptr) reclaimer_activity_->ForceWakeup(); + } +} + +void BasicMemoryQuota::Return(size_t amount) { + free_bytes_.fetch_add(amount, std::memory_order_relaxed); +} + +size_t BasicMemoryQuota::InstantaneousPressure() const { + double free = free_bytes_.load(); + if (free < 0) free = 0; + double size = quota_size_.load(); + if (size < 1) return 1.0; + double pressure = (size - free) / size; + if (pressure < 0.0) pressure = 0.0; + if (pressure > 1.0) pressure = 1.0; + return pressure; +} + +} // namespace grpc_core diff --git a/src/core/lib/resource_quota/resource_quota.cc b/src/core/lib/resource_quota/resource_quota.cc new file mode 100644 index 00000000..cf4229bb --- /dev/null +++ b/src/core/lib/resource_quota/resource_quota.cc @@ -0,0 +1,27 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/resource_quota/resource_quota.h" + +namespace grpc_core { + +ResourceQuota::ResourceQuota() + : memory_quota_(std::make_shared()), + thread_quota_(MakeRefCounted()) {} + +ResourceQuota::~ResourceQuota() = default; + +} // namespace grpc_core diff --git a/src/core/lib/resource_quota/thread_quota.cc b/src/core/lib/resource_quota/thread_quota.cc new file mode 100644 index 00000000..c935be0d --- /dev/null +++ b/src/core/lib/resource_quota/thread_quota.cc @@ -0,0 +1,43 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
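
Editor's note, not part of the patch: the allocator sizing in memory_quota.cc above combines two pieces of arithmetic, the quota pressure in BasicMemoryQuota::InstantaneousPressure() and the pressure-based trim in GrpcMemoryAllocatorImpl::TryReserve(). A minimal standalone sketch of that arithmetic, with hypothetical names, follows; it is illustrative only and makes no use of the real classes.

// Sketch of the sizing math above; names here are hypothetical.
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Pressure is (used / quota size), clamped to [0, 1], mirroring
// InstantaneousPressure().
double Pressure(double free_bytes, double quota_size) {
  if (free_bytes < 0) free_bytes = 0;
  if (quota_size < 1) return 1.0;
  return std::min(1.0, std::max(0.0, (quota_size - free_bytes) / quota_size));
}

// Above 80% pressure the flexible part of a reservation (max - min) shrinks
// linearly and reaches zero at 100% pressure, the same shape as TryReserve().
size_t ScaledReserve(size_t min, size_t max, double pressure) {
  size_t over_min = max - min;
  if (over_min != 0 && pressure > 0.8) {
    over_min = std::min(
        over_min, static_cast<size_t>((max - min) * (1.0 - pressure) / 0.2));
  }
  return min + over_min;
}

int main() {
  // For example, a 1 MiB quota with 100 KiB free sits at roughly 90% pressure,
  // so a request for [1 KiB, 64 KiB] is trimmed to roughly 32 KiB.
  double p = Pressure(100.0 * 1024, 1024.0 * 1024);
  std::printf("pressure=%.3f reserve=%zu\n", p,
              ScaledReserve(1024, 64 * 1024, p));
}
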
+ +#include + +#include "src/core/lib/resource_quota/thread_quota.h" + +namespace grpc_core { + +ThreadQuota::ThreadQuota() = default; + +ThreadQuota::~ThreadQuota() = default; + +void ThreadQuota::SetMax(size_t new_max) { + MutexLock lock(&mu_); + max_ = new_max; +} + +bool ThreadQuota::Reserve(size_t num_threads) { + MutexLock lock(&mu_); + if (allocated_ + num_threads > max_) return false; + allocated_ += num_threads; + return true; +} + +void ThreadQuota::Release(size_t num_threads) { + MutexLock lock(&mu_); + GPR_ASSERT(num_threads <= allocated_); + allocated_ -= num_threads; +} + +} // namespace grpc_core diff --git a/src/core/lib/security/authorization/authorization_policy_provider_null_vtable.cc b/src/core/lib/security/authorization/authorization_policy_provider_null_vtable.cc new file mode 100644 index 00000000..1f1ccc98 --- /dev/null +++ b/src/core/lib/security/authorization/authorization_policy_provider_null_vtable.cc @@ -0,0 +1,24 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +// Wrapper API declared in grpc.h + +// Required only for insecure build targets. +const grpc_arg_pointer_vtable* grpc_authorization_policy_provider_arg_vtable() { + return nullptr; +} diff --git a/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc b/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc new file mode 100644 index 00000000..f8a4426e --- /dev/null +++ b/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc @@ -0,0 +1,46 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
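
Editor's note, not part of the patch: the ThreadQuota added above rejects a reservation rather than blocking, so callers are expected to handle a false return. The stand-in class below (hypothetical, not the real type) shows that contract in isolation.

// Sketch of the Reserve()/Release() contract; ThreadQuotaLike is hypothetical.
#include <cstddef>
#include <cstdio>
#include <limits>
#include <mutex>

class ThreadQuotaLike {
 public:
  void SetMax(size_t new_max) {
    std::lock_guard<std::mutex> lock(mu_);
    max_ = new_max;
  }
  bool Reserve(size_t n) {
    std::lock_guard<std::mutex> lock(mu_);
    if (allocated_ + n > max_) return false;  // over quota: refuse, don't block
    allocated_ += n;
    return true;
  }
  void Release(size_t n) {
    std::lock_guard<std::mutex> lock(mu_);
    allocated_ -= n;
  }

 private:
  std::mutex mu_;
  size_t allocated_ = 0;
  size_t max_ = std::numeric_limits<size_t>::max();
};

int main() {
  ThreadQuotaLike quota;
  quota.SetMax(2);
  // Prints "1 1 0": the third reservation is refused once the quota is full.
  std::printf("%d %d %d\n", quota.Reserve(1), quota.Reserve(1),
              quota.Reserve(1));
  quota.Release(2);
}
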
+ +#include + +#include + +#include "src/core/lib/security/authorization/authorization_policy_provider.h" + +namespace { + +void* ProviderArgCopy(void* p) { + grpc_authorization_policy_provider* provider = + static_cast(p); + provider->Ref().release(); + return provider; +} + +void ProviderArgDestroy(void* p) { + grpc_authorization_policy_provider* provider = + static_cast(p); + provider->Unref(); +} + +int ProviderArgCmp(void* p, void* q) { return grpc_core::QsortCompare(p, q); } + +} // namespace + +// Wrapper API declared in grpc.h + +const grpc_arg_pointer_vtable* grpc_authorization_policy_provider_arg_vtable() { + static const grpc_arg_pointer_vtable vtable = { + ProviderArgCopy, ProviderArgDestroy, ProviderArgCmp}; + return &vtable; +} diff --git a/src/core/lib/security/authorization/cel_authorization_engine.cc b/src/core/lib/security/authorization/cel_authorization_engine.cc new file mode 100644 index 00000000..a5d0a92e --- /dev/null +++ b/src/core/lib/security/authorization/cel_authorization_engine.cc @@ -0,0 +1,179 @@ +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/security/authorization/cel_authorization_engine.h" + +#include "absl/memory/memory.h" + +#include "src/core/lib/address_utils/sockaddr_utils.h" + +namespace grpc_core { + +namespace { + +// Symbols for traversing Envoy Attributes +constexpr char kUrlPath[] = "url_path"; +constexpr char kHost[] = "host"; +constexpr char kMethod[] = "method"; +constexpr char kHeaders[] = "headers"; +constexpr char kSourceAddress[] = "source_address"; +constexpr char kSourcePort[] = "source_port"; +constexpr char kDestinationAddress[] = "destination_address"; +constexpr char kDestinationPort[] = "destination_port"; +constexpr char kSpiffeId[] = "spiffe_id"; +constexpr char kCertServerName[] = "cert_server_name"; + +} // namespace + +std::unique_ptr +CelAuthorizationEngine::CreateCelAuthorizationEngine( + const std::vector& rbac_policies) { + if (rbac_policies.empty() || rbac_policies.size() > 2) { + gpr_log(GPR_ERROR, + "Invalid rbac policies vector. Must contain either one or two rbac " + "policies."); + return nullptr; + } else if (rbac_policies.size() == 2 && + (envoy_config_rbac_v3_RBAC_action(rbac_policies[0]) != kDeny || + envoy_config_rbac_v3_RBAC_action(rbac_policies[1]) != kAllow)) { + gpr_log(GPR_ERROR, + "Invalid rbac policies vector. Must contain one deny \ + policy and one allow policy, in that order."); + return nullptr; + } else { + return absl::make_unique(rbac_policies); + } +} + +CelAuthorizationEngine::CelAuthorizationEngine( + const std::vector& rbac_policies) { + for (const auto& rbac_policy : rbac_policies) { + // Extract array of policies and store their condition fields in either + // allow_if_matched_ or deny_if_matched_, depending on the policy action. 
+ upb::Arena temp_arena; + size_t policy_num = UPB_MAP_BEGIN; + const envoy_config_rbac_v3_RBAC_PoliciesEntry* policy_entry; + while ((policy_entry = envoy_config_rbac_v3_RBAC_policies_next( + rbac_policy, &policy_num)) != nullptr) { + const upb_strview policy_name_strview = + envoy_config_rbac_v3_RBAC_PoliciesEntry_key(policy_entry); + const std::string policy_name(policy_name_strview.data, + policy_name_strview.size); + const envoy_config_rbac_v3_Policy* policy = + envoy_config_rbac_v3_RBAC_PoliciesEntry_value(policy_entry); + const google_api_expr_v1alpha1_Expr* condition = + envoy_config_rbac_v3_Policy_condition(policy); + // Parse condition to make a pointer tied to the lifetime of arena_. + size_t serial_len; + const char* serialized = google_api_expr_v1alpha1_Expr_serialize( + condition, temp_arena.ptr(), &serial_len); + const google_api_expr_v1alpha1_Expr* parsed_condition = + google_api_expr_v1alpha1_Expr_parse(serialized, serial_len, + arena_.ptr()); + if (envoy_config_rbac_v3_RBAC_action(rbac_policy) == kAllow) { + allow_if_matched_.insert(std::make_pair(policy_name, parsed_condition)); + } else { + deny_if_matched_.insert(std::make_pair(policy_name, parsed_condition)); + } + } + } +} + +std::unique_ptr CelAuthorizationEngine::CreateActivation( + const EvaluateArgs& args) { + std::unique_ptr activation; + for (const auto& elem : envoy_attributes_) { + if (elem == kUrlPath) { + absl::string_view url_path(args.GetPath()); + if (!url_path.empty()) { + activation->InsertValue(kUrlPath, + mock_cel::CelValue::CreateStringView(url_path)); + } + } else if (elem == kHost) { + absl::string_view host(args.GetHost()); + if (!host.empty()) { + activation->InsertValue(kHost, + mock_cel::CelValue::CreateStringView(host)); + } + } else if (elem == kMethod) { + absl::string_view method(args.GetMethod()); + if (!method.empty()) { + activation->InsertValue(kMethod, + mock_cel::CelValue::CreateStringView(method)); + } + } else if (elem == kHeaders) { + std::multimap headers = + args.GetHeaders(); + std::vector> + header_items; + for (const auto& header_key : header_keys_) { + auto header_item = headers.find(header_key); + if (header_item != headers.end()) { + header_items.push_back( + std::pair( + mock_cel::CelValue::CreateStringView(header_key), + mock_cel::CelValue::CreateStringView(header_item->second))); + } + } + headers_ = mock_cel::ContainerBackedMapImpl::Create( + absl::Span>( + header_items)); + activation->InsertValue(kHeaders, + mock_cel::CelValue::CreateMap(headers_.get())); + } else if (elem == kSourceAddress) { + absl::string_view source_address(args.GetPeerAddressString()); + if (!source_address.empty()) { + activation->InsertValue( + kSourceAddress, + mock_cel::CelValue::CreateStringView(source_address)); + } + } else if (elem == kSourcePort) { + activation->InsertValue( + kSourcePort, mock_cel::CelValue::CreateInt64(args.GetPeerPort())); + } else if (elem == kDestinationAddress) { + absl::string_view destination_address(args.GetLocalAddressString()); + if (!destination_address.empty()) { + activation->InsertValue( + kDestinationAddress, + mock_cel::CelValue::CreateStringView(destination_address)); + } + } else if (elem == kDestinationPort) { + activation->InsertValue(kDestinationPort, mock_cel::CelValue::CreateInt64( + args.GetLocalPort())); + } else if (elem == kSpiffeId) { + absl::string_view spiffe_id(args.GetSpiffeId()); + if (!spiffe_id.empty()) { + activation->InsertValue( + kSpiffeId, mock_cel::CelValue::CreateStringView(spiffe_id)); + } + } else if (elem == kCertServerName) 
{ + absl::string_view cert_server_name(args.GetCommonName()); + if (!cert_server_name.empty()) { + activation->InsertValue( + kCertServerName, + mock_cel::CelValue::CreateStringView(cert_server_name)); + } + } else { + gpr_log(GPR_ERROR, + "Error: Authorization engine does not support evaluating " + "attribute %s.", + elem.c_str()); + } + } + return activation; +} + +} // namespace grpc_core diff --git a/src/core/lib/security/authorization/evaluate_args.cc b/src/core/lib/security/authorization/evaluate_args.cc new file mode 100644 index 00000000..80b713c0 --- /dev/null +++ b/src/core/lib/security/authorization/evaluate_args.cc @@ -0,0 +1,213 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/security/authorization/evaluate_args.h" + +#include "absl/strings/numbers.h" + +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/security/credentials/tls/tls_utils.h" +#include "src/core/lib/slice/slice_utils.h" + +namespace grpc_core { + +namespace { + +EvaluateArgs::PerChannelArgs::Address ParseEndpointUri( + absl::string_view uri_text) { + EvaluateArgs::PerChannelArgs::Address address; + absl::StatusOr uri = URI::Parse(uri_text); + if (!uri.ok()) { + gpr_log(GPR_DEBUG, "Failed to parse uri."); + return address; + } + absl::string_view host_view; + absl::string_view port_view; + if (!SplitHostPort(uri->path(), &host_view, &port_view)) { + gpr_log(GPR_DEBUG, "Failed to split %s into host and port.", + uri->path().c_str()); + return address; + } + if (!absl::SimpleAtoi(port_view, &address.port)) { + gpr_log(GPR_DEBUG, "Port %s is out of range or null.", + std::string(port_view).c_str()); + } + address.address_str = std::string(host_view); + grpc_error_handle error = grpc_string_to_sockaddr( + &address.address, address.address_str.c_str(), address.port); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_DEBUG, "Address %s is not IPv4/IPv6. 
Error: %s", + address.address_str.c_str(), grpc_error_std_string(error).c_str()); + } + GRPC_ERROR_UNREF(error); + return address; +} + +} // namespace + +EvaluateArgs::PerChannelArgs::PerChannelArgs(grpc_auth_context* auth_context, + grpc_endpoint* endpoint) { + if (auth_context != nullptr) { + transport_security_type = GetAuthPropertyValue( + auth_context, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME); + spiffe_id = + GetAuthPropertyValue(auth_context, GRPC_PEER_SPIFFE_ID_PROPERTY_NAME); + uri_sans = GetAuthPropertyArray(auth_context, GRPC_PEER_URI_PROPERTY_NAME); + dns_sans = GetAuthPropertyArray(auth_context, GRPC_PEER_DNS_PROPERTY_NAME); + common_name = + GetAuthPropertyValue(auth_context, GRPC_X509_CN_PROPERTY_NAME); + } + if (endpoint != nullptr) { + local_address = ParseEndpointUri(grpc_endpoint_get_local_address(endpoint)); + peer_address = ParseEndpointUri(grpc_endpoint_get_peer(endpoint)); + } +} + +absl::string_view EvaluateArgs::GetPath() const { + absl::string_view path; + if (metadata_ != nullptr && + metadata_->legacy_index()->named.path != nullptr) { + grpc_linked_mdelem* elem = metadata_->legacy_index()->named.path; + const grpc_slice& val = GRPC_MDVALUE(elem->md); + path = StringViewFromSlice(val); + } + return path; +} + +absl::string_view EvaluateArgs::GetHost() const { + absl::string_view host; + if (metadata_ != nullptr && + metadata_->legacy_index()->named.host != nullptr) { + grpc_linked_mdelem* elem = metadata_->legacy_index()->named.host; + const grpc_slice& val = GRPC_MDVALUE(elem->md); + host = StringViewFromSlice(val); + } + return host; +} + +absl::string_view EvaluateArgs::GetMethod() const { + absl::string_view method; + if (metadata_ != nullptr && + metadata_->legacy_index()->named.method != nullptr) { + grpc_linked_mdelem* elem = metadata_->legacy_index()->named.method; + const grpc_slice& val = GRPC_MDVALUE(elem->md); + method = StringViewFromSlice(val); + } + return method; +} + +std::multimap EvaluateArgs::GetHeaders() + const { + std::multimap headers; + if (metadata_ == nullptr) { + return headers; + } + metadata_->ForEach([&](grpc_mdelem md) { + const grpc_slice& key = GRPC_MDKEY(md); + const grpc_slice& val = GRPC_MDVALUE(md); + headers.emplace(StringViewFromSlice(key), StringViewFromSlice(val)); + }); + return headers; +} + +absl::optional EvaluateArgs::GetHeaderValue( + absl::string_view key, std::string* concatenated_value) const { + if (metadata_ == nullptr) { + return absl::nullopt; + } + return metadata_->GetValue(key, concatenated_value); +} + +grpc_resolved_address EvaluateArgs::GetLocalAddress() const { + if (channel_args_ == nullptr) { + return {}; + } + return channel_args_->local_address.address; +} + +absl::string_view EvaluateArgs::GetLocalAddressString() const { + if (channel_args_ == nullptr) { + return ""; + } + return channel_args_->local_address.address_str; +} + +int EvaluateArgs::GetLocalPort() const { + if (channel_args_ == nullptr) { + return 0; + } + return channel_args_->local_address.port; +} + +grpc_resolved_address EvaluateArgs::GetPeerAddress() const { + if (channel_args_ == nullptr) { + return {}; + } + return channel_args_->peer_address.address; +} + +absl::string_view EvaluateArgs::GetPeerAddressString() const { + if (channel_args_ == nullptr) { + return ""; + } + return channel_args_->peer_address.address_str; +} + +int EvaluateArgs::GetPeerPort() const { + if (channel_args_ == nullptr) { + return 0; + } + return channel_args_->peer_address.port; +} + +absl::string_view EvaluateArgs::GetTransportSecurityType() const { 
+ if (channel_args_ == nullptr) { + return ""; + } + return channel_args_->transport_security_type; +} + +absl::string_view EvaluateArgs::GetSpiffeId() const { + if (channel_args_ == nullptr) { + return ""; + } + return channel_args_->spiffe_id; +} + +std::vector EvaluateArgs::GetUriSans() const { + if (channel_args_ == nullptr) { + return {}; + } + return channel_args_->uri_sans; +} + +std::vector EvaluateArgs::GetDnsSans() const { + if (channel_args_ == nullptr) { + return {}; + } + return channel_args_->dns_sans; +} + +absl::string_view EvaluateArgs::GetCommonName() const { + if (channel_args_ == nullptr) { + return ""; + } + return channel_args_->common_name; +} + +} // namespace grpc_core diff --git a/src/core/lib/security/authorization/grpc_authorization_engine.cc b/src/core/lib/security/authorization/grpc_authorization_engine.cc new file mode 100644 index 00000000..34fc9767 --- /dev/null +++ b/src/core/lib/security/authorization/grpc_authorization_engine.cc @@ -0,0 +1,49 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/security/authorization/grpc_authorization_engine.h" + +namespace grpc_core { + +GrpcAuthorizationEngine::GrpcAuthorizationEngine(Rbac policy) + : action_(policy.action) { + for (auto& sub_policy : policy.policies) { + Policy policy; + policy.name = sub_policy.first; + policy.matcher = absl::make_unique( + std::move(sub_policy.second)); + policies_.push_back(std::move(policy)); + } +} + +AuthorizationEngine::Decision GrpcAuthorizationEngine::Evaluate( + const EvaluateArgs& args) const { + Decision decision; + bool matches = false; + for (const auto& policy : policies_) { + if (policy.matcher->Matches(args)) { + matches = true; + decision.matching_policy_name = policy.name; + break; + } + } + decision.type = (matches == (action_ == Rbac::Action::kAllow)) + ? Decision::Type::kAllow + : Decision::Type::kDeny; + return decision; +} + +} // namespace grpc_core diff --git a/src/core/lib/security/authorization/grpc_authorization_policy_provider.cc b/src/core/lib/security/authorization/grpc_authorization_policy_provider.cc new file mode 100644 index 00000000..c7236aac --- /dev/null +++ b/src/core/lib/security/authorization/grpc_authorization_policy_provider.cc @@ -0,0 +1,193 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "src/core/lib/security/authorization/grpc_authorization_policy_provider.h" + +#include +#include + +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/authorization/grpc_authorization_engine.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +extern TraceFlag grpc_sdk_authz_trace; + +absl::StatusOr> +StaticDataAuthorizationPolicyProvider::Create(absl::string_view authz_policy) { + auto policies_or = GenerateRbacPolicies(authz_policy); + if (!policies_or.ok()) { + return policies_or.status(); + } + return MakeRefCounted( + std::move(*policies_or)); +} + +StaticDataAuthorizationPolicyProvider::StaticDataAuthorizationPolicyProvider( + RbacPolicies policies) + : allow_engine_(MakeRefCounted( + std::move(policies.allow_policy))), + deny_engine_(MakeRefCounted( + std::move(policies.deny_policy))) {} + +namespace { + +absl::StatusOr ReadPolicyFromFile(absl::string_view policy_path) { + grpc_slice policy_slice = grpc_empty_slice(); + grpc_error_handle error = + grpc_load_file(std::string(policy_path).c_str(), 0, &policy_slice); + if (error != GRPC_ERROR_NONE) { + absl::Status status = + absl::InvalidArgumentError(grpc_error_std_string(error)); + GRPC_ERROR_UNREF(error); + return status; + } + std::string policy_contents(StringViewFromSlice(policy_slice)); + grpc_slice_unref_internal(policy_slice); + return policy_contents; +} + +gpr_timespec TimeoutSecondsToDeadline(int64_t seconds) { + return gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(seconds, GPR_TIMESPAN)); +} + +} // namespace + +absl::StatusOr> +FileWatcherAuthorizationPolicyProvider::Create( + absl::string_view authz_policy_path, unsigned int refresh_interval_sec) { + GPR_ASSERT(!authz_policy_path.empty()); + GPR_ASSERT(refresh_interval_sec > 0); + absl::Status status; + auto provider = MakeRefCounted( + authz_policy_path, refresh_interval_sec, &status); + if (!status.ok()) return status; + return provider; +} + +FileWatcherAuthorizationPolicyProvider::FileWatcherAuthorizationPolicyProvider( + absl::string_view authz_policy_path, unsigned int refresh_interval_sec, + absl::Status* status) + : authz_policy_path_(std::string(authz_policy_path)), + refresh_interval_sec_(refresh_interval_sec) { + gpr_event_init(&shutdown_event_); + // Initial read is done synchronously. + *status = ForceUpdate(); + if (!status->ok()) { + return; + } + auto thread_lambda = [](void* arg) { + WeakRefCountedPtr provider( + static_cast(arg)); + GPR_ASSERT(provider != nullptr); + while (true) { + void* value = gpr_event_wait( + &provider->shutdown_event_, + TimeoutSecondsToDeadline(provider->refresh_interval_sec_)); + if (value != nullptr) { + return; + } + absl::Status status = provider->ForceUpdate(); + if (GRPC_TRACE_FLAG_ENABLED(grpc_sdk_authz_trace) && !status.ok()) { + gpr_log(GPR_ERROR, + "authorization policy reload status. 
code=%d error_details=%s", + status.code(), std::string(status.message()).c_str()); + } + } + }; + refresh_thread_ = absl::make_unique( + "FileWatcherAuthorizationPolicyProvider_refreshing_thread", thread_lambda, + WeakRef().release()); + refresh_thread_->Start(); +} + +absl::Status FileWatcherAuthorizationPolicyProvider::ForceUpdate() { + absl::StatusOr file_contents = + ReadPolicyFromFile(authz_policy_path_); + if (!file_contents.ok()) { + return file_contents.status(); + } + if (file_contents_ == *file_contents) { + return absl::OkStatus(); + } + file_contents_ = std::move(*file_contents); + auto rbac_policies_or = GenerateRbacPolicies(file_contents_); + if (!rbac_policies_or.ok()) { + return rbac_policies_or.status(); + } + grpc_core::MutexLock lock(&mu_); + allow_engine_ = MakeRefCounted( + std::move(rbac_policies_or->allow_policy)); + deny_engine_ = MakeRefCounted( + std::move(rbac_policies_or->deny_policy)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_sdk_authz_trace)) { + gpr_log(GPR_INFO, + "authorization policy reload status: successfully loaded new " + "policy\n%s", + file_contents_.c_str()); + } + return absl::OkStatus(); +} + +void FileWatcherAuthorizationPolicyProvider::Orphan() { + gpr_event_set(&shutdown_event_, reinterpret_cast(1)); + if (refresh_thread_ != nullptr) { + refresh_thread_->Join(); + } +} + +} // namespace grpc_core + +// Wrapper APIs declared in grpc_security.h + +grpc_authorization_policy_provider* +grpc_authorization_policy_provider_static_data_create( + const char* authz_policy, grpc_status_code* code, + const char** error_details) { + GPR_ASSERT(authz_policy != nullptr); + auto provider_or = + grpc_core::StaticDataAuthorizationPolicyProvider::Create(authz_policy); + if (!provider_or.ok()) { + *code = static_cast(provider_or.status().code()); + *error_details = + gpr_strdup(std::string(provider_or.status().message()).c_str()); + return nullptr; + } + return provider_or->release(); +} + +grpc_authorization_policy_provider* +grpc_authorization_policy_provider_file_watcher_create( + const char* authz_policy_path, unsigned int refresh_interval_sec, + grpc_status_code* code, const char** error_details) { + GPR_ASSERT(authz_policy_path != nullptr); + auto provider_or = grpc_core::FileWatcherAuthorizationPolicyProvider::Create( + authz_policy_path, refresh_interval_sec); + if (!provider_or.ok()) { + *code = static_cast(provider_or.status().code()); + *error_details = + gpr_strdup(std::string(provider_or.status().message()).c_str()); + return nullptr; + } + return provider_or->release(); +} + +void grpc_authorization_policy_provider_release( + grpc_authorization_policy_provider* provider) { + if (provider != nullptr) provider->Unref(); +} diff --git a/src/core/lib/security/authorization/matchers.cc b/src/core/lib/security/authorization/matchers.cc new file mode 100644 index 00000000..202e5918 --- /dev/null +++ b/src/core/lib/security/authorization/matchers.cc @@ -0,0 +1,225 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
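
Editor's note, not part of the patch: the C wrappers above are the public entry points for this provider. A minimal usage sketch follows; the policy string follows the gRPC authorization policy JSON schema, and per ParseRules() later in this patch a rule with no "source" or "request" section matches any principal and any request.

// Sketch of calling the wrapper API shown above.
#include <cstdio>

#include <grpc/grpc_security.h>
#include <grpc/status.h>

int main() {
  const char* policy =
      "{"
      "  \"name\": \"example_authz\","
      "  \"allow_rules\": [ { \"name\": \"allow_all\" } ]"
      "}";
  grpc_status_code code = GRPC_STATUS_OK;
  const char* error_details = nullptr;
  grpc_authorization_policy_provider* provider =
      grpc_authorization_policy_provider_static_data_create(policy, &code,
                                                            &error_details);
  if (provider == nullptr) {
    std::printf("policy rejected: code=%d details=%s\n", code, error_details);
    return 1;
  }
  // Normally the provider would be attached to server credentials; here we
  // simply drop our reference again.
  grpc_authorization_policy_provider_release(provider);
  return 0;
}
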
+ +#include + +#include "src/core/lib/security/authorization/matchers.h" + +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" + +namespace grpc_core { + +std::unique_ptr AuthorizationMatcher::Create( + Rbac::Permission permission) { + switch (permission.type) { + case Rbac::Permission::RuleType::kAnd: { + std::vector> matchers; + for (const auto& rule : permission.permissions) { + matchers.push_back(AuthorizationMatcher::Create(std::move(*rule))); + } + return absl::make_unique(std::move(matchers)); + } + case Rbac::Permission::RuleType::kOr: { + std::vector> matchers; + for (const auto& rule : permission.permissions) { + matchers.push_back(AuthorizationMatcher::Create(std::move(*rule))); + } + return absl::make_unique(std::move(matchers)); + } + case Rbac::Permission::RuleType::kNot: + return absl::make_unique( + AuthorizationMatcher::Create(std::move(*permission.permissions[0]))); + case Rbac::Permission::RuleType::kAny: + return absl::make_unique(); + case Rbac::Permission::RuleType::kHeader: + return absl::make_unique( + std::move(permission.header_matcher)); + case Rbac::Permission::RuleType::kPath: + return absl::make_unique( + std::move(permission.string_matcher)); + case Rbac::Permission::RuleType::kDestIp: + return absl::make_unique( + IpAuthorizationMatcher::Type::kDestIp, std::move(permission.ip)); + case Rbac::Permission::RuleType::kDestPort: + return absl::make_unique(permission.port); + case Rbac::Permission::RuleType::kReqServerName: + return absl::make_unique( + std::move(permission.string_matcher)); + } + return nullptr; +} + +std::unique_ptr AuthorizationMatcher::Create( + Rbac::Principal principal) { + switch (principal.type) { + case Rbac::Principal::RuleType::kAnd: { + std::vector> matchers; + for (const auto& id : principal.principals) { + matchers.push_back(AuthorizationMatcher::Create(std::move(*id))); + } + return absl::make_unique(std::move(matchers)); + } + case Rbac::Principal::RuleType::kOr: { + std::vector> matchers; + for (const auto& id : principal.principals) { + matchers.push_back(AuthorizationMatcher::Create(std::move(*id))); + } + return absl::make_unique(std::move(matchers)); + } + case Rbac::Principal::RuleType::kNot: + return absl::make_unique( + AuthorizationMatcher::Create(std::move(*principal.principals[0]))); + case Rbac::Principal::RuleType::kAny: + return absl::make_unique(); + case Rbac::Principal::RuleType::kPrincipalName: + return absl::make_unique( + std::move(principal.string_matcher)); + case Rbac::Principal::RuleType::kSourceIp: + return absl::make_unique( + IpAuthorizationMatcher::Type::kSourceIp, std::move(principal.ip)); + case Rbac::Principal::RuleType::kDirectRemoteIp: + return absl::make_unique( + IpAuthorizationMatcher::Type::kDirectRemoteIp, + std::move(principal.ip)); + case Rbac::Principal::RuleType::kRemoteIp: + return absl::make_unique( + IpAuthorizationMatcher::Type::kRemoteIp, std::move(principal.ip)); + case Rbac::Principal::RuleType::kHeader: + return absl::make_unique( + std::move(principal.header_matcher)); + case Rbac::Principal::RuleType::kPath: + return absl::make_unique( + std::move(principal.string_matcher)); + } + return nullptr; +} + +bool AndAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + for (const auto& matcher : matchers_) { + if (!matcher->Matches(args)) { + return false; + } + } + return true; +} + +bool OrAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + for (const auto& matcher : matchers_) { + if (matcher->Matches(args)) { + return true; + } + } + return 
false; +} + +bool NotAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + return !matcher_->Matches(args); +} + +bool HeaderAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + std::string concatenated_value; + return matcher_.Match( + args.GetHeaderValue(matcher_.name(), &concatenated_value)); +} + +IpAuthorizationMatcher::IpAuthorizationMatcher(Type type, Rbac::CidrRange range) + : type_(type), prefix_len_(range.prefix_len) { + grpc_error_handle error = + grpc_string_to_sockaddr(&subnet_address_, range.address_prefix.c_str(), + /*port does not matter here*/ 0); + if (error == GRPC_ERROR_NONE) { + grpc_sockaddr_mask_bits(&subnet_address_, prefix_len_); + } else { + gpr_log(GPR_DEBUG, "CidrRange address %s is not IPv4/IPv6. Error: %s", + range.address_prefix.c_str(), grpc_error_std_string(error).c_str()); + } + GRPC_ERROR_UNREF(error); +} + +bool IpAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + grpc_resolved_address address; + switch (type_) { + case Type::kDestIp: { + address = args.GetLocalAddress(); + break; + } + case Type::kSourceIp: + case Type::kDirectRemoteIp: { + address = args.GetPeerAddress(); + break; + } + default: { + // Currently we do not support matching rules containing "remote_ip". + return false; + } + } + return grpc_sockaddr_match_subnet(&address, &subnet_address_, prefix_len_); +} + +bool PortAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + return port_ == args.GetLocalPort(); +} + +bool AuthenticatedAuthorizationMatcher::Matches( + const EvaluateArgs& args) const { + if (args.GetTransportSecurityType() != GRPC_SSL_TRANSPORT_SECURITY_TYPE && + args.GetTransportSecurityType() != GRPC_TLS_TRANSPORT_SECURITY_TYPE) { + // Connection is not authenticated. + return false; + } + if (matcher_.string_matcher().empty()) { + // Allows any authenticated user. + return true; + } + std::vector uri_sans = args.GetUriSans(); + if (!uri_sans.empty()) { + for (const auto& uri : uri_sans) { + if (matcher_.Match(uri)) { + return true; + } + } + } + std::vector dns_sans = args.GetDnsSans(); + if (!dns_sans.empty()) { + for (const auto& dns : dns_sans) { + if (matcher_.Match(dns)) { + return true; + } + } + } + // TODO(ashithasantosh): Check Subject field from certificate. + return false; +} + +bool ReqServerNameAuthorizationMatcher::Matches(const EvaluateArgs&) const { + // Currently we do not support matching rules containing + // "requested_server_name". + return false; +} + +bool PathAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + absl::string_view path = args.GetPath(); + if (!path.empty()) { + return matcher_.Match(path); + } + return false; +} + +bool PolicyAuthorizationMatcher::Matches(const EvaluateArgs& args) const { + return permissions_->Matches(args) && principals_->Matches(args); +} + +} // namespace grpc_core diff --git a/src/core/lib/security/authorization/rbac_policy.cc b/src/core/lib/security/authorization/rbac_policy.cc new file mode 100644 index 00000000..a6f7ff2c --- /dev/null +++ b/src/core/lib/security/authorization/rbac_policy.cc @@ -0,0 +1,318 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/security/authorization/rbac_policy.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +namespace grpc_core { + +// +// Rbac +// + +Rbac::Rbac(Rbac::Action action, std::map policies) + : action(action), policies(std::move(policies)) {} + +Rbac::Rbac(Rbac&& other) noexcept + : action(other.action), policies(std::move(other.policies)) {} + +Rbac& Rbac::operator=(Rbac&& other) noexcept { + action = other.action; + policies = std::move(other.policies); + return *this; +} + +std::string Rbac::ToString() const { + std::vector contents; + contents.push_back(absl::StrFormat( + "Rbac action=%s{", action == Rbac::Action::kAllow ? "Allow" : "Deny")); + for (const auto& p : policies) { + contents.push_back(absl::StrFormat("{\n policy_name=%s\n%s\n}", p.first, + p.second.ToString())); + } + contents.push_back("}"); + return absl::StrJoin(contents, "\n"); +} + +// +// CidrRange +// + +Rbac::CidrRange::CidrRange(std::string address_prefix, uint32_t prefix_len) + : address_prefix(std::move(address_prefix)), prefix_len(prefix_len) {} + +Rbac::CidrRange::CidrRange(Rbac::CidrRange&& other) noexcept + : address_prefix(std::move(other.address_prefix)), + prefix_len(other.prefix_len) {} + +Rbac::CidrRange& Rbac::CidrRange::operator=(Rbac::CidrRange&& other) noexcept { + address_prefix = std::move(other.address_prefix); + prefix_len = other.prefix_len; + return *this; +} + +std::string Rbac::CidrRange::ToString() const { + return absl::StrFormat("CidrRange{address_prefix=%s,prefix_len=%d}", + address_prefix, prefix_len); +} + +// +// Permission +// + +Rbac::Permission::Permission( + Permission::RuleType type, + std::vector> permissions) + : type(type), permissions(std::move(permissions)) {} +Rbac::Permission::Permission(Permission::RuleType type, Permission permission) + : type(type) { + permissions.push_back( + absl::make_unique(std::move(permission))); +} +Rbac::Permission::Permission(Permission::RuleType type) : type(type) {} +Rbac::Permission::Permission(Permission::RuleType type, + HeaderMatcher header_matcher) + : type(type), header_matcher(std::move(header_matcher)) {} +Rbac::Permission::Permission(Permission::RuleType type, + StringMatcher string_matcher) + : type(type), string_matcher(std::move(string_matcher)) {} +Rbac::Permission::Permission(Permission::RuleType type, CidrRange ip) + : type(type), ip(std::move(ip)) {} +Rbac::Permission::Permission(Permission::RuleType type, int port) + : type(type), port(port) {} + +Rbac::Permission::Permission(Rbac::Permission&& other) noexcept + : type(other.type) { + switch (type) { + case RuleType::kAnd: + case RuleType::kOr: + case RuleType::kNot: + permissions = std::move(other.permissions); + break; + case RuleType::kAny: + break; + case RuleType::kHeader: + header_matcher = std::move(other.header_matcher); + break; + case RuleType::kPath: + case RuleType::kReqServerName: + string_matcher = std::move(other.string_matcher); + break; + case RuleType::kDestIp: + ip = std::move(other.ip); + break; + default: + port = other.port; + } +} + +Rbac::Permission& 
Rbac::Permission::operator=( + Rbac::Permission&& other) noexcept { + type = other.type; + switch (type) { + case RuleType::kAnd: + case RuleType::kOr: + case RuleType::kNot: + permissions = std::move(other.permissions); + break; + case RuleType::kAny: + break; + case RuleType::kHeader: + header_matcher = std::move(other.header_matcher); + break; + case RuleType::kPath: + case RuleType::kReqServerName: + string_matcher = std::move(other.string_matcher); + break; + case RuleType::kDestIp: + ip = std::move(other.ip); + break; + default: + port = other.port; + } + return *this; +} + +std::string Rbac::Permission::ToString() const { + switch (type) { + case RuleType::kAnd: { + std::vector contents; + contents.reserve(permissions.size()); + for (const auto& permission : permissions) { + contents.push_back(permission->ToString()); + } + return absl::StrFormat("and=[%s]", absl::StrJoin(contents, ",")); + } + case RuleType::kOr: { + std::vector contents; + contents.reserve(permissions.size()); + for (const auto& permission : permissions) { + contents.push_back(permission->ToString()); + } + return absl::StrFormat("or=[%s]", absl::StrJoin(contents, ",")); + } + case RuleType::kNot: + return absl::StrFormat("not %s", permissions[0]->ToString()); + case RuleType::kAny: + return "any"; + case RuleType::kHeader: + return absl::StrFormat("header=%s", header_matcher.ToString()); + case RuleType::kPath: + return absl::StrFormat("path=%s", string_matcher.ToString()); + case RuleType::kDestIp: + return absl::StrFormat("dest_ip=%s", ip.ToString()); + case RuleType::kDestPort: + return absl::StrFormat("dest_port=%d", port); + case RuleType::kReqServerName: + return absl::StrFormat("requested_server_name=%s", + string_matcher.ToString()); + default: + return ""; + } +} + +// +// Principal +// + +Rbac::Principal::Principal(Principal::RuleType type, + std::vector> principals) + : type(type), principals(std::move(principals)) {} +Rbac::Principal::Principal(Principal::RuleType type, Principal principal) + : type(type) { + principals.push_back( + absl::make_unique(std::move(principal))); +} +Rbac::Principal::Principal(Principal::RuleType type) : type(type) {} +Rbac::Principal::Principal(Principal::RuleType type, + StringMatcher string_matcher) + : type(type), string_matcher(std::move(string_matcher)) {} +Rbac::Principal::Principal(Principal::RuleType type, CidrRange ip) + : type(type), ip(std::move(ip)) {} +Rbac::Principal::Principal(Principal::RuleType type, + HeaderMatcher header_matcher) + : type(type), header_matcher(std::move(header_matcher)) {} + +Rbac::Principal::Principal(Rbac::Principal&& other) noexcept + : type(other.type) { + switch (type) { + case RuleType::kAnd: + case RuleType::kOr: + case RuleType::kNot: + principals = std::move(other.principals); + break; + case RuleType::kAny: + break; + case RuleType::kHeader: + header_matcher = std::move(other.header_matcher); + break; + case RuleType::kPrincipalName: + case RuleType::kPath: + string_matcher = std::move(other.string_matcher); + break; + default: + ip = std::move(other.ip); + } +} + +Rbac::Principal& Rbac::Principal::operator=(Rbac::Principal&& other) noexcept { + type = other.type; + switch (type) { + case RuleType::kAnd: + case RuleType::kOr: + case RuleType::kNot: + principals = std::move(other.principals); + break; + case RuleType::kAny: + break; + case RuleType::kHeader: + header_matcher = std::move(other.header_matcher); + break; + case RuleType::kPrincipalName: + case RuleType::kPath: + string_matcher = std::move(other.string_matcher); + 
break; + default: + ip = std::move(other.ip); + } + return *this; +} + +std::string Rbac::Principal::ToString() const { + switch (type) { + case RuleType::kAnd: { + std::vector contents; + contents.reserve(principals.size()); + for (const auto& principal : principals) { + contents.push_back(principal->ToString()); + } + return absl::StrFormat("and=[%s]", absl::StrJoin(contents, ",")); + } + case RuleType::kOr: { + std::vector contents; + contents.reserve(principals.size()); + for (const auto& principal : principals) { + contents.push_back(principal->ToString()); + } + return absl::StrFormat("or=[%s]", absl::StrJoin(contents, ",")); + } + case RuleType::kNot: + return absl::StrFormat("not %s", principals[0]->ToString()); + case RuleType::kAny: + return "any"; + case RuleType::kPrincipalName: + return absl::StrFormat("principal_name=%s", string_matcher.ToString()); + case RuleType::kSourceIp: + return absl::StrFormat("source_ip=%s", ip.ToString()); + case RuleType::kDirectRemoteIp: + return absl::StrFormat("direct_remote_ip=%s", ip.ToString()); + case RuleType::kRemoteIp: + return absl::StrFormat("remote_ip=%s", ip.ToString()); + case RuleType::kHeader: + return absl::StrFormat("header=%s", header_matcher.ToString()); + case RuleType::kPath: + return absl::StrFormat("path=%s", string_matcher.ToString()); + default: + return ""; + } +} + +// +// Policy +// + +Rbac::Policy::Policy(Permission permissions, Principal principals) + : permissions(std::move(permissions)), principals(std::move(principals)) {} + +Rbac::Policy::Policy(Rbac::Policy&& other) noexcept + : permissions(std::move(other.permissions)), + principals(std::move(other.principals)) {} + +Rbac::Policy& Rbac::Policy::operator=(Rbac::Policy&& other) noexcept { + permissions = std::move(other.permissions); + principals = std::move(other.principals); + return *this; +} + +std::string Rbac::Policy::ToString() const { + return absl::StrFormat( + " Policy {\n Permissions{%s}\n Principals{%s}\n }", + permissions.ToString(), principals.ToString()); +} + +} // namespace grpc_core diff --git a/src/core/lib/security/authorization/rbac_translator.cc b/src/core/lib/security/authorization/rbac_translator.cc new file mode 100644 index 00000000..ef82af75 --- /dev/null +++ b/src/core/lib/security/authorization/rbac_translator.cc @@ -0,0 +1,372 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
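
Editor's note, not part of the patch: the Rbac constructors and ToString() above can be exercised directly. The sketch below builds the smallest possible policy (one allow rule matching any principal and any request) and prints it; it assumes compilation inside the gRPC source tree so the internal header path resolves.

// Sketch using the Rbac types defined above (in-tree build assumed).
#include <cstdio>
#include <map>
#include <string>
#include <utility>

#include "src/core/lib/security/authorization/rbac_policy.h"

int main() {
  using grpc_core::Rbac;
  // "Any permission" and "any principal" correspond to RuleType::kAny.
  Rbac::Policy policy(Rbac::Permission(Rbac::Permission::RuleType::kAny),
                      Rbac::Principal(Rbac::Principal::RuleType::kAny));
  std::map<std::string, Rbac::Policy> policies;
  policies.emplace("allow_all", std::move(policy));
  Rbac rbac(Rbac::Action::kAllow, std::move(policies));
  std::printf("%s\n", rbac.ToString().c_str());
}
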
+ +#include + +#include "src/core/lib/security/authorization/rbac_translator.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/strip.h" + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/matchers/matchers.h" + +namespace grpc_core { + +namespace { + +absl::string_view GetMatcherType(absl::string_view value, + StringMatcher::Type* type) { + if (value == "*") { + *type = StringMatcher::Type::kPrefix; + return ""; + } else if (absl::StartsWith(value, "*")) { + *type = StringMatcher::Type::kSuffix; + return absl::StripPrefix(value, "*"); + } else if (absl::EndsWith(value, "*")) { + *type = StringMatcher::Type::kPrefix; + return absl::StripSuffix(value, "*"); + } + *type = StringMatcher::Type::kExact; + return value; +} + +absl::StatusOr GetStringMatcher(absl::string_view value) { + StringMatcher::Type type; + absl::string_view matcher = GetMatcherType(value, &type); + return StringMatcher::Create(type, matcher); +} + +absl::StatusOr GetHeaderMatcher(absl::string_view name, + absl::string_view value) { + StringMatcher::Type type; + absl::string_view matcher = GetMatcherType(value, &type); + return HeaderMatcher::Create(name, static_cast(type), + matcher); +} + +bool IsUnsupportedHeader(absl::string_view header_name) { + static const char* const kUnsupportedHeaders[] = {"host", + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade"}; + for (size_t i = 0; i < GPR_ARRAY_SIZE(kUnsupportedHeaders); ++i) { + if (absl::EqualsIgnoreCase(header_name, kUnsupportedHeaders[i])) { + return true; + } + } + return false; +} + +absl::StatusOr ParsePrincipalsArray(const Json& json) { + std::vector> principal_names; + for (size_t i = 0; i < json.array_value().size(); ++i) { + const Json& child = json.array_value().at(i); + if (child.type() != Json::Type::STRING) { + return absl::InvalidArgumentError( + absl::StrCat("\"principals\" ", i, ": is not a string.")); + } + auto matcher_or = GetStringMatcher(child.string_value()); + if (!matcher_or.ok()) { + return absl::Status(matcher_or.status().code(), + absl::StrCat("\"principals\" ", i, ": ", + matcher_or.status().message())); + } + principal_names.push_back(absl::make_unique( + Rbac::Principal::RuleType::kPrincipalName, + std::move(matcher_or.value()))); + } + return Rbac::Principal(Rbac::Principal::RuleType::kOr, + std::move(principal_names)); +} + +absl::StatusOr ParsePeer(const Json& json) { + std::vector> peer; + auto it = json.object_value().find("principals"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + return absl::InvalidArgumentError("\"principals\" is not an array."); + } + auto principal_names_or = ParsePrincipalsArray(it->second); + if (!principal_names_or.ok()) return principal_names_or.status(); + if (!principal_names_or.value().principals.empty()) { + peer.push_back(absl::make_unique( + std::move(principal_names_or.value()))); + } + } + if (peer.empty()) { + return Rbac::Principal(Rbac::Principal::RuleType::kAny); + } + return Rbac::Principal(Rbac::Principal::RuleType::kAnd, std::move(peer)); +} + +absl::StatusOr ParseHeaderValues( + const Json& json, absl::string_view header_name) { + if (json.array_value().empty()) { + return absl::InvalidArgumentError("\"values\" list is empty."); + } + std::vector> values; + for (size_t i = 0; i < json.array_value().size(); ++i) { + const Json& child = json.array_value().at(i); + if (child.type() != Json::Type::STRING) { 
+ return absl::InvalidArgumentError( + absl::StrCat("\"values\" ", i, ": is not a string.")); + } + auto matcher_or = GetHeaderMatcher(header_name, child.string_value()); + if (!matcher_or.ok()) { + return absl::Status( + matcher_or.status().code(), + absl::StrCat("\"values\" ", i, ": ", matcher_or.status().message())); + } + values.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, std::move(matcher_or.value()))); + } + return Rbac::Permission(Rbac::Permission::RuleType::kOr, std::move(values)); +} + +absl::StatusOr ParseHeaders(const Json& json) { + auto it = json.object_value().find("key"); + if (it == json.object_value().end()) { + return absl::InvalidArgumentError("\"key\" is not present."); + } + if (it->second.type() != Json::Type::STRING) { + return absl::InvalidArgumentError("\"key\" is not a string."); + } + absl::string_view header_name = it->second.string_value(); + if (absl::StartsWith(header_name, ":") || + absl::StartsWith(header_name, "grpc-") || + IsUnsupportedHeader(header_name)) { + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported \"key\" %s.", header_name)); + } + it = json.object_value().find("values"); + if (it == json.object_value().end()) { + return absl::InvalidArgumentError("\"values\" is not present."); + } + if (it->second.type() != Json::Type::ARRAY) { + return absl::InvalidArgumentError("\"values\" is not an array."); + } + return ParseHeaderValues(it->second, header_name); +} + +absl::StatusOr ParseHeadersArray(const Json& json) { + std::vector> headers; + for (size_t i = 0; i < json.array_value().size(); ++i) { + const Json& child = json.array_value().at(i); + if (child.type() != Json::Type::OBJECT) { + return absl::InvalidArgumentError( + absl::StrCat("\"headers\" ", i, ": is not an object.")); + } + auto headers_or = ParseHeaders(child); + if (!headers_or.ok()) { + return absl::Status( + headers_or.status().code(), + absl::StrCat("\"headers\" ", i, ": ", headers_or.status().message())); + } + headers.push_back( + absl::make_unique(std::move(headers_or.value()))); + } + return Rbac::Permission(Rbac::Permission::RuleType::kAnd, std::move(headers)); +} + +absl::StatusOr ParsePathsArray(const Json& json) { + std::vector> paths; + for (size_t i = 0; i < json.array_value().size(); ++i) { + const Json& child = json.array_value().at(i); + if (child.type() != Json::Type::STRING) { + return absl::InvalidArgumentError( + absl::StrCat("\"paths\" ", i, ": is not a string.")); + } + auto matcher_or = GetStringMatcher(child.string_value()); + if (!matcher_or.ok()) { + return absl::Status( + matcher_or.status().code(), + absl::StrCat("\"paths\" ", i, ": ", matcher_or.status().message())); + } + paths.push_back(absl::make_unique( + Rbac::Permission::RuleType::kPath, std::move(matcher_or.value()))); + } + return Rbac::Permission(Rbac::Permission::RuleType::kOr, std::move(paths)); +} + +absl::StatusOr ParseRequest(const Json& json) { + std::vector> request; + auto it = json.object_value().find("paths"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + return absl::InvalidArgumentError("\"paths\" is not an array."); + } + auto paths_or = ParsePathsArray(it->second); + if (!paths_or.ok()) return paths_or.status(); + if (!paths_or.value().permissions.empty()) { + request.push_back( + absl::make_unique(std::move(paths_or.value()))); + } + } + it = json.object_value().find("headers"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::ARRAY) { + return 
absl::InvalidArgumentError("\"headers\" is not an array."); + } + auto headers_or = ParseHeadersArray(it->second); + if (!headers_or.ok()) return headers_or.status(); + if (!headers_or.value().permissions.empty()) { + request.push_back( + absl::make_unique(std::move(headers_or.value()))); + } + } + if (request.empty()) { + return Rbac::Permission(Rbac::Permission::RuleType::kAny); + } + return Rbac::Permission(Rbac::Permission::RuleType::kAnd, std::move(request)); +} + +absl::StatusOr ParseRules(const Json& json) { + Rbac::Principal principals; + auto it = json.object_value().find("source"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::OBJECT) { + return absl::InvalidArgumentError("\"source\" is not an object."); + } + auto peer_or = ParsePeer(it->second); + if (!peer_or.ok()) return peer_or.status(); + principals = std::move(peer_or.value()); + } else { + principals = Rbac::Principal(Rbac::Principal::RuleType::kAny); + } + Rbac::Permission permissions; + it = json.object_value().find("request"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::OBJECT) { + return absl::InvalidArgumentError("\"request\" is not an object."); + } + auto request_or = ParseRequest(it->second); + if (!request_or.ok()) return request_or.status(); + permissions = std::move(request_or.value()); + } else { + permissions = Rbac::Permission(Rbac::Permission::RuleType::kAny); + } + return Rbac::Policy(std::move(permissions), std::move(principals)); +} + +absl::StatusOr> ParseRulesArray( + const Json& json, absl::string_view name) { + std::map policies; + for (size_t i = 0; i < json.array_value().size(); ++i) { + const Json& child = json.array_value().at(i); + if (child.type() != Json::Type::OBJECT) { + return absl::InvalidArgumentError( + absl::StrCat("rules ", i, ": is not an object.")); + } + auto it = child.object_value().find("name"); + if (it == child.object_value().end()) { + return absl::InvalidArgumentError( + absl::StrCat("rules ", i, ": \"name\" is not present.")); + } + if (it->second.type() != Json::Type::STRING) { + return absl::InvalidArgumentError( + absl::StrCat("rules ", i, ": \"name\" is not a string.")); + } + std::string policy_name = + std::string(name) + "_" + it->second.string_value(); + auto policy_or = ParseRules(child); + if (!policy_or.ok()) { + return absl::Status( + policy_or.status().code(), + absl::StrCat("rules ", i, ": ", policy_or.status().message())); + } + policies[policy_name] = std::move(policy_or.value()); + } + return std::move(policies); +} + +absl::StatusOr ParseDenyRulesArray(const Json& json, + absl::string_view name) { + auto policies_or = ParseRulesArray(json, name); + if (!policies_or.ok()) return policies_or.status(); + return Rbac(Rbac::Action::kDeny, std::move(policies_or.value())); +} + +absl::StatusOr ParseAllowRulesArray(const Json& json, + absl::string_view name) { + auto policies_or = ParseRulesArray(json, name); + if (!policies_or.ok()) return policies_or.status(); + return Rbac(Rbac::Action::kAllow, std::move(policies_or.value())); +} + +} // namespace + +absl::StatusOr GenerateRbacPolicies( + absl::string_view authz_policy) { + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(authz_policy, &error); + if (error != GRPC_ERROR_NONE) { + absl::Status status = absl::InvalidArgumentError( + absl::StrCat("Failed to parse SDK authorization policy. 
Error: ", + grpc_error_std_string(error))); + GRPC_ERROR_UNREF(error); + return status; + } + if (json.type() != Json::Type::OBJECT) { + return absl::InvalidArgumentError( + "SDK authorization policy is not an object."); + } + auto it = json.mutable_object()->find("name"); + if (it == json.mutable_object()->end()) { + return absl::InvalidArgumentError("\"name\" field is not present."); + } + if (it->second.type() != Json::Type::STRING) { + return absl::InvalidArgumentError("\"name\" is not a string."); + } + absl::string_view name = it->second.string_value(); + RbacPolicies rbac_policies; + it = json.mutable_object()->find("deny_rules"); + if (it != json.mutable_object()->end()) { + if (it->second.type() != Json::Type::ARRAY) { + return absl::InvalidArgumentError("\"deny_rules\" is not an array."); + } + auto deny_policy_or = ParseDenyRulesArray(it->second, name); + if (!deny_policy_or.ok()) { + return absl::Status( + deny_policy_or.status().code(), + absl::StrCat("deny_", deny_policy_or.status().message())); + } + rbac_policies.deny_policy = std::move(deny_policy_or.value()); + } else { + rbac_policies.deny_policy.action = Rbac::Action::kDeny; + } + it = json.mutable_object()->find("allow_rules"); + if (it == json.mutable_object()->end()) { + return absl::InvalidArgumentError("\"allow_rules\" is not present."); + } + if (it->second.type() != Json::Type::ARRAY) { + return absl::InvalidArgumentError("\"allow_rules\" is not an array."); + } + auto allow_policy_or = ParseAllowRulesArray(it->second, name); + if (!allow_policy_or.ok()) { + return absl::Status( + allow_policy_or.status().code(), + absl::StrCat("allow_", allow_policy_or.status().message())); + } + rbac_policies.allow_policy = std::move(allow_policy_or.value()); + return std::move(rbac_policies); +} + +} // namespace grpc_core diff --git a/src/core/lib/security/authorization/sdk_server_authz_filter.cc b/src/core/lib/security/authorization/sdk_server_authz_filter.cc new file mode 100644 index 00000000..efd17550 --- /dev/null +++ b/src/core/lib/security/authorization/sdk_server_authz_filter.cc @@ -0,0 +1,171 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/lib/security/authorization/sdk_server_authz_filter.h"
+
+#include "src/core/lib/security/authorization/evaluate_args.h"
+#include "src/core/lib/transport/transport.h"
+
+namespace grpc_core {
+
+TraceFlag grpc_sdk_authz_trace(false, "sdk_authz");
+
+SdkServerAuthzFilter::SdkServerAuthzFilter(
+    RefCountedPtr<grpc_auth_context> auth_context, grpc_endpoint* endpoint,
+    RefCountedPtr<grpc_authorization_policy_provider> provider)
+    : auth_context_(std::move(auth_context)),
+      per_channel_evaluate_args_(auth_context_.get(), endpoint),
+      provider_(std::move(provider)) {}
+
+grpc_error_handle SdkServerAuthzFilter::Init(grpc_channel_element* elem,
+                                             grpc_channel_element_args* args) {
+  GPR_ASSERT(!args->is_last);
+  grpc_auth_context* auth_context =
+      grpc_find_auth_context_in_args(args->channel_args);
+  grpc_authorization_policy_provider* provider =
+      grpc_channel_args_find_pointer<grpc_authorization_policy_provider>(
+          args->channel_args, GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER);
+  if (provider == nullptr) {
+    return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+        "Failed to get authorization provider.");
+  }
+  // grpc_endpoint isn't needed because the current SDK authorization policy
+  // does not support any rules that require looking at source or destination
+  // addresses.
+  new (elem->channel_data) SdkServerAuthzFilter(
+      auth_context != nullptr ? auth_context->Ref() : nullptr,
+      /*endpoint=*/nullptr, provider->Ref());
+  return GRPC_ERROR_NONE;
+}
+
+void SdkServerAuthzFilter::Destroy(grpc_channel_element* elem) {
+  auto* chand = static_cast<SdkServerAuthzFilter*>(elem->channel_data);
+  chand->~SdkServerAuthzFilter();
+}
+
+SdkServerAuthzFilter::CallData::CallData(grpc_call_element* elem) {
+  GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_, RecvInitialMetadataReady,
+                    elem, grpc_schedule_on_exec_ctx);
+}
+
+void SdkServerAuthzFilter::CallData::StartTransportStreamOpBatch(
+    grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
+  auto* calld = static_cast<CallData*>(elem->call_data);
+  if (batch->recv_initial_metadata) {
+    // Inject our callback.
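+    // The batch's recv_initial_metadata_ready closure is replaced with
+    // recv_initial_metadata_ready_ so that RecvInitialMetadataReady() can run
+    // the authorization check (IsAuthorized) over the received metadata before
+    // handing control back to the original callback saved below.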
+    calld->recv_initial_metadata_batch_ =
+        batch->payload->recv_initial_metadata.recv_initial_metadata;
+    calld->original_recv_initial_metadata_ready_ =
+        batch->payload->recv_initial_metadata.recv_initial_metadata_ready;
+    batch->payload->recv_initial_metadata.recv_initial_metadata_ready =
+        &calld->recv_initial_metadata_ready_;
+  }
+  grpc_call_next_op(elem, batch);
+}
+
+grpc_error_handle SdkServerAuthzFilter::CallData::Init(
+    grpc_call_element* elem, const grpc_call_element_args*) {
+  new (elem->call_data) CallData(elem);
+  return GRPC_ERROR_NONE;
+}
+
+void SdkServerAuthzFilter::CallData::Destroy(
+    grpc_call_element* elem, const grpc_call_final_info* /*final_info*/,
+    grpc_closure* /*ignored*/) {
+  CallData* calld = static_cast<CallData*>(elem->call_data);
+  calld->~CallData();
+}
+
+bool SdkServerAuthzFilter::CallData::IsAuthorized(SdkServerAuthzFilter* chand) {
+  EvaluateArgs args(recv_initial_metadata_batch_,
+                    &chand->per_channel_evaluate_args_);
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_sdk_authz_trace)) {
+    gpr_log(
+        GPR_DEBUG,
+        "checking request: url_path=%s, transport_security_type=%s, "
+        "uri_sans=[%s], dns_sans=[%s], local_address=%s:%d, peer_address=%s:%d",
+        std::string(args.GetPath()).c_str(),
+        std::string(args.GetTransportSecurityType()).c_str(),
+        absl::StrJoin(args.GetUriSans(), ",").c_str(),
+        absl::StrJoin(args.GetDnsSans(), ",").c_str(),
+        std::string(args.GetLocalAddressString()).c_str(), args.GetLocalPort(),
+        std::string(args.GetPeerAddressString()).c_str(), args.GetPeerPort());
+  }
+  grpc_authorization_policy_provider::AuthorizationEngines engines =
+      chand->provider_->engines();
+  if (engines.deny_engine != nullptr) {
+    AuthorizationEngine::Decision decision =
+        engines.deny_engine->Evaluate(args);
+    if (decision.type == AuthorizationEngine::Decision::Type::kDeny) {
+      if (GRPC_TRACE_FLAG_ENABLED(grpc_sdk_authz_trace)) {
+        gpr_log(GPR_INFO, "chand=%p calld=%p: request denied by policy %s.",
+                chand, this, decision.matching_policy_name.c_str());
+      }
+      return false;
+    }
+  }
+  if (engines.allow_engine != nullptr) {
+    AuthorizationEngine::Decision decision =
+        engines.allow_engine->Evaluate(args);
+    if (decision.type == AuthorizationEngine::Decision::Type::kAllow) {
+      if (GRPC_TRACE_FLAG_ENABLED(grpc_sdk_authz_trace)) {
+        gpr_log(GPR_INFO, "chand=%p calld=%p: request allowed by policy %s.",
+                chand, this, decision.matching_policy_name.c_str());
+      }
+      return true;
+    }
+  }
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_sdk_authz_trace)) {
+    gpr_log(GPR_INFO,
+            "chand=%p calld=%p: request denied, no matching policy found.",
+            chand, this);
+  }
+  return false;
+}
+
+void SdkServerAuthzFilter::CallData::RecvInitialMetadataReady(
+    void* arg, grpc_error_handle error) {
+  grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
+  auto* chand = static_cast<SdkServerAuthzFilter*>(elem->channel_data);
+  auto* calld = static_cast<CallData*>(elem->call_data);
+  if (error == GRPC_ERROR_NONE) {
+    if (!calld->IsAuthorized(chand)) {
+      error = grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+                                     "Unauthorized RPC request rejected."),
+                                 GRPC_ERROR_INT_GRPC_STATUS,
+                                 GRPC_STATUS_PERMISSION_DENIED);
+    }
+  } else {
+    (void)GRPC_ERROR_REF(error);
+  }
+  Closure::Run(DEBUG_LOCATION, calld->original_recv_initial_metadata_ready_,
+               error);
+}
+
+const grpc_channel_filter SdkServerAuthzFilter::kFilterVtable = {
+    SdkServerAuthzFilter::CallData::StartTransportStreamOpBatch,
+    grpc_channel_next_op,
+    sizeof(SdkServerAuthzFilter::CallData),
+    SdkServerAuthzFilter::CallData::Init,
+    grpc_call_stack_ignore_set_pollset_or_pollset_set,
SdkServerAuthzFilter::CallData::Destroy, + sizeof(SdkServerAuthzFilter), + SdkServerAuthzFilter::Init, + SdkServerAuthzFilter::Destroy, + grpc_channel_next_get_info, + "sdk-server-authz"}; + +} // namespace grpc_core diff --git a/src/core/lib/security/context/security_context.cc b/src/core/lib/security/context/security_context.cc new file mode 100644 index 00000000..257f5c25 --- /dev/null +++ b/src/core/lib/security/context/security_context.cc @@ -0,0 +1,325 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/context/security_context.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/arena.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call.h" + +grpc_core::DebugOnlyTraceFlag grpc_trace_auth_context_refcount( + false, "auth_context_refcount"); + +/* --- grpc_call --- */ + +grpc_call_error grpc_call_set_credentials(grpc_call* call, + grpc_call_credentials* creds) { + grpc_core::ExecCtx exec_ctx; + grpc_client_security_context* ctx = nullptr; + GRPC_API_TRACE("grpc_call_set_credentials(call=%p, creds=%p)", 2, + (call, creds)); + if (!grpc_call_is_client(call)) { + gpr_log(GPR_ERROR, "Method is client-side only."); + return GRPC_CALL_ERROR_NOT_ON_SERVER; + } + ctx = static_cast( + grpc_call_context_get(call, GRPC_CONTEXT_SECURITY)); + if (ctx == nullptr) { + ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds); + grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx, + grpc_client_security_context_destroy); + } else { + ctx->creds = creds != nullptr ? 
creds->Ref() : nullptr; + } + + return GRPC_CALL_OK; +} + +grpc_auth_context* grpc_call_auth_context(grpc_call* call) { + void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY); + GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call)); + if (sec_ctx == nullptr) return nullptr; + if (grpc_call_is_client(call)) { + auto* sc = static_cast(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context client") + .release(); + } + } else { + auto* sc = static_cast(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context server") + .release(); + } + } +} + +void grpc_auth_context_release(grpc_auth_context* context) { + GRPC_API_TRACE("grpc_auth_context_release(context=%p)", 1, (context)); + if (context == nullptr) return; + context->Unref(DEBUG_LOCATION, "grpc_auth_context_unref"); +} + +/* --- grpc_client_security_context --- */ +grpc_client_security_context::~grpc_client_security_context() { + auth_context.reset(DEBUG_LOCATION, "client_security_context"); + if (extension.instance != nullptr && extension.destroy != nullptr) { + extension.destroy(extension.instance); + } +} + +grpc_client_security_context* grpc_client_security_context_create( + grpc_core::Arena* arena, grpc_call_credentials* creds) { + return arena->New( + creds != nullptr ? creds->Ref() : nullptr); +} + +void grpc_client_security_context_destroy(void* ctx) { + grpc_core::ExecCtx exec_ctx; + grpc_client_security_context* c = + static_cast(ctx); + c->~grpc_client_security_context(); +} + +/* --- grpc_server_security_context --- */ +grpc_server_security_context::~grpc_server_security_context() { + auth_context.reset(DEBUG_LOCATION, "server_security_context"); + if (extension.instance != nullptr && extension.destroy != nullptr) { + extension.destroy(extension.instance); + } +} + +grpc_server_security_context* grpc_server_security_context_create( + grpc_core::Arena* arena) { + return arena->New(); +} + +void grpc_server_security_context_destroy(void* ctx) { + grpc_server_security_context* c = + static_cast(ctx); + c->~grpc_server_security_context(); +} + +/* --- grpc_auth_context --- */ + +static grpc_auth_property_iterator empty_iterator = {nullptr, 0, nullptr}; + +const char* grpc_auth_context_peer_identity_property_name( + const grpc_auth_context* ctx) { + GRPC_API_TRACE("grpc_auth_context_peer_identity_property_name(ctx=%p)", 1, + (ctx)); + return ctx->peer_identity_property_name(); +} + +int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx, + const char* name) { + grpc_auth_property_iterator it = + grpc_auth_context_find_properties_by_name(ctx, name); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + GRPC_API_TRACE( + "grpc_auth_context_set_peer_identity_property_name(ctx=%p, name=%s)", 2, + (ctx, name)); + if (prop == nullptr) { + gpr_log(GPR_ERROR, "Property name %s not found in auth context.", + name != nullptr ? 
name : "NULL"); + return 0; + } + ctx->set_peer_identity_property_name(prop->name); + return 1; +} + +int grpc_auth_context_peer_is_authenticated(const grpc_auth_context* ctx) { + GRPC_API_TRACE("grpc_auth_context_peer_is_authenticated(ctx=%p)", 1, (ctx)); + return ctx->is_authenticated(); +} + +grpc_auth_property_iterator grpc_auth_context_property_iterator( + const grpc_auth_context* ctx) { + grpc_auth_property_iterator it = empty_iterator; + GRPC_API_TRACE("grpc_auth_context_property_iterator(ctx=%p)", 1, (ctx)); + if (ctx == nullptr) return it; + it.ctx = ctx; + return it; +} + +const grpc_auth_property* grpc_auth_property_iterator_next( + grpc_auth_property_iterator* it) { + GRPC_API_TRACE("grpc_auth_property_iterator_next(it=%p)", 1, (it)); + if (it == nullptr || it->ctx == nullptr) return nullptr; + while (it->index == it->ctx->properties().count) { + if (it->ctx->chained() == nullptr) return nullptr; + it->ctx = it->ctx->chained(); + it->index = 0; + } + if (it->name == nullptr) { + return &it->ctx->properties().array[it->index++]; + } else { + while (it->index < it->ctx->properties().count) { + const grpc_auth_property* prop = + &it->ctx->properties().array[it->index++]; + GPR_ASSERT(prop->name != nullptr); + if (strcmp(it->name, prop->name) == 0) { + return prop; + } + } + /* We could not find the name, try another round. */ + return grpc_auth_property_iterator_next(it); + } +} + +grpc_auth_property_iterator grpc_auth_context_find_properties_by_name( + const grpc_auth_context* ctx, const char* name) { + grpc_auth_property_iterator it = empty_iterator; + GRPC_API_TRACE("grpc_auth_context_find_properties_by_name(ctx=%p, name=%s)", + 2, (ctx, name)); + if (ctx == nullptr || name == nullptr) return empty_iterator; + it.ctx = ctx; + it.name = name; + return it; +} + +grpc_auth_property_iterator grpc_auth_context_peer_identity( + const grpc_auth_context* ctx) { + GRPC_API_TRACE("grpc_auth_context_peer_identity(ctx=%p)", 1, (ctx)); + if (ctx == nullptr) return empty_iterator; + return grpc_auth_context_find_properties_by_name( + ctx, ctx->peer_identity_property_name()); +} + +void grpc_auth_context::ensure_capacity() { + if (properties_.count == properties_.capacity) { + properties_.capacity = + std::max(properties_.capacity + 8, properties_.capacity * 2); + properties_.array = static_cast(gpr_realloc( + properties_.array, properties_.capacity * sizeof(grpc_auth_property))); + } +} + +void grpc_auth_context::add_property(const char* name, const char* value, + size_t value_length) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; + prop->name = gpr_strdup(name); + prop->value = static_cast(gpr_malloc(value_length + 1)); + memcpy(prop->value, value, value_length); + prop->value[value_length] = '\0'; + prop->value_length = value_length; +} + +void grpc_auth_context_add_property(grpc_auth_context* ctx, const char* name, + const char* value, size_t value_length) { + GRPC_API_TRACE( + "grpc_auth_context_add_property(ctx=%p, name=%s, value=%*.*s, " + "value_length=%lu)", + 6, + (ctx, name, (int)value_length, (int)value_length, value, + (unsigned long)value_length)); + ctx->add_property(name, value, value_length); +} + +void grpc_auth_context::add_cstring_property(const char* name, + const char* value) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; + prop->name = gpr_strdup(name); + prop->value = gpr_strdup(value); + prop->value_length = strlen(value); +} + +void 
grpc_auth_context_add_cstring_property(grpc_auth_context* ctx, + const char* name, + const char* value) { + GRPC_API_TRACE( + "grpc_auth_context_add_cstring_property(ctx=%p, name=%s, value=%s)", 3, + (ctx, name, value)); + ctx->add_cstring_property(name, value); +} + +void grpc_auth_property_reset(grpc_auth_property* property) { + gpr_free(property->name); + gpr_free(property->value); + memset(property, 0, sizeof(grpc_auth_property)); +} + +static void auth_context_pointer_arg_destroy(void* p) { + if (p != nullptr) { + static_cast(p)->Unref(DEBUG_LOCATION, + "auth_context_pointer_arg"); + } +} + +static void* auth_context_pointer_arg_copy(void* p) { + auto* ctx = static_cast(p); + return ctx == nullptr + ? nullptr + : ctx->Ref(DEBUG_LOCATION, "auth_context_pointer_arg").release(); +} + +static int auth_context_pointer_cmp(void* a, void* b) { + return grpc_core::QsortCompare(a, b); +} + +static const grpc_arg_pointer_vtable auth_context_pointer_vtable = { + auth_context_pointer_arg_copy, auth_context_pointer_arg_destroy, + auth_context_pointer_cmp}; + +grpc_arg grpc_auth_context_to_arg(grpc_auth_context* c) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_AUTH_CONTEXT_ARG), c, + &auth_context_pointer_vtable); +} + +grpc_auth_context* grpc_auth_context_from_arg(const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_AUTH_CONTEXT_ARG) != 0) return nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, "Invalid type %d for arg %s", arg->type, + GRPC_AUTH_CONTEXT_ARG); + return nullptr; + } + return static_cast(arg->value.pointer.p); +} + +grpc_auth_context* grpc_find_auth_context_in_args( + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_auth_context* p = grpc_auth_context_from_arg(&args->args[i]); + if (p != nullptr) return p; + } + return nullptr; +} diff --git a/src/core/lib/security/credentials/alts/alts_credentials.cc b/src/core/lib/security/credentials/alts/alts_credentials.cc new file mode 100644 index 00000000..30acd749 --- /dev/null +++ b/src/core/lib/security/credentials/alts/alts_credentials.cc @@ -0,0 +1,111 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/alts/alts_credentials.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" +#include "src/core/lib/security/security_connector/alts/alts_security_connector.h" + +#define GRPC_CREDENTIALS_TYPE_ALTS "Alts" +#define GRPC_ALTS_HANDSHAKER_SERVICE_URL "metadata.google.internal.:8080" + +grpc_alts_credentials::grpc_alts_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? 
gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) { + grpc_alts_set_rpc_protocol_versions(&options_->rpc_versions); +} + +grpc_alts_credentials::~grpc_alts_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); +} + +grpc_core::RefCountedPtr +grpc_alts_credentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target_name, const grpc_channel_args* /*args*/, + grpc_channel_args** /*new_args*/) { + return grpc_alts_channel_security_connector_create( + this->Ref(), std::move(call_creds), target_name); +} + +grpc_alts_server_credentials::grpc_alts_server_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) { + grpc_alts_set_rpc_protocol_versions(&options_->rpc_versions); +} + +grpc_core::RefCountedPtr +grpc_alts_server_credentials::create_security_connector( + const grpc_channel_args* /* args */) { + return grpc_alts_server_security_connector_create(this->Ref()); +} + +grpc_alts_server_credentials::~grpc_alts_server_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); +} + +grpc_channel_credentials* grpc_alts_credentials_create_customized( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url, bool enable_untrusted_alts) { + if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { + return nullptr; + } + return new grpc_alts_credentials(options, handshaker_service_url); +} + +grpc_server_credentials* grpc_alts_server_credentials_create_customized( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url, bool enable_untrusted_alts) { + if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { + return nullptr; + } + return new grpc_alts_server_credentials(options, handshaker_service_url); +} + +grpc_channel_credentials* grpc_alts_credentials_create( + const grpc_alts_credentials_options* options) { + return grpc_alts_credentials_create_customized( + options, GRPC_ALTS_HANDSHAKER_SERVICE_URL, false); +} + +grpc_server_credentials* grpc_alts_server_credentials_create( + const grpc_alts_credentials_options* options) { + return grpc_alts_server_credentials_create_customized( + options, GRPC_ALTS_HANDSHAKER_SERVICE_URL, false); +} diff --git a/src/core/lib/security/credentials/alts/check_gcp_environment.cc b/src/core/lib/security/credentials/alts/check_gcp_environment.cc new file mode 100644 index 00000000..f3370a3c --- /dev/null +++ b/src/core/lib/security/credentials/alts/check_gcp_environment.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" + +#include +#include +#include + +#include +#include + +const size_t kBiosDataBufferSize = 256; + +static char* trim(const char* src) { + if (src == nullptr || *src == '\0') { + return nullptr; + } + char* des = nullptr; + size_t start = 0, end = strlen(src) - 1; + /* find the last character that is not a whitespace. */ + while (end != 0 && isspace(src[end])) { + end--; + } + /* find the first character that is not a whitespace. */ + while (start < strlen(src) && isspace(src[start])) { + start++; + } + if (start <= end) { + des = static_cast( + gpr_zalloc(sizeof(char) * (end - start + 2 /* '\0' */))); + memcpy(des, src + start, end - start + 1); + } + return des; +} + +namespace grpc_core { +namespace internal { + +char* read_bios_file(const char* bios_file) { + FILE* fp = fopen(bios_file, "r"); + if (!fp) { + gpr_log(GPR_INFO, "BIOS data file does not exist or cannot be opened."); + return nullptr; + } + char buf[kBiosDataBufferSize + 1]; + size_t ret = fread(buf, sizeof(char), kBiosDataBufferSize, fp); + buf[ret] = '\0'; + char* trimmed_buf = trim(buf); + fclose(fp); + return trimmed_buf; +} + +} // namespace internal +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/alts/check_gcp_environment_linux.cc b/src/core/lib/security/credentials/alts/check_gcp_environment_linux.cc new file mode 100644 index 00000000..0c6f5561 --- /dev/null +++ b/src/core/lib/security/credentials/alts/check_gcp_environment_linux.cc @@ -0,0 +1,68 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#ifdef GPR_LINUX + +#include + +#include +#include + +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" + +#define GRPC_ALTS_EXPECT_NAME_GOOGLE "Google" +#define GRPC_ALTS_EXPECT_NAME_GCE "Google Compute Engine" +#define GRPC_ALTS_PRODUCT_NAME_FILE "/sys/class/dmi/id/product_name" + +static bool g_compute_engine_detection_done = false; +static bool g_is_on_compute_engine = false; +static gpr_mu g_mu; +static gpr_once g_once = GPR_ONCE_INIT; + +namespace grpc_core { +namespace internal { + +bool check_bios_data(const char* bios_data_file) { + char* bios_data = read_bios_file(bios_data_file); + bool result = + bios_data && ((!strcmp(bios_data, GRPC_ALTS_EXPECT_NAME_GOOGLE)) || + (!strcmp(bios_data, GRPC_ALTS_EXPECT_NAME_GCE))); + gpr_free(bios_data); + return result; +} + +} // namespace internal +} // namespace grpc_core + +static void init_mu(void) { gpr_mu_init(&g_mu); } + +bool grpc_alts_is_running_on_gcp() { + gpr_once_init(&g_once, init_mu); + gpr_mu_lock(&g_mu); + if (!g_compute_engine_detection_done) { + g_is_on_compute_engine = + grpc_core::internal::check_bios_data(GRPC_ALTS_PRODUCT_NAME_FILE); + g_compute_engine_detection_done = true; + } + gpr_mu_unlock(&g_mu); + return g_is_on_compute_engine; +} + +#endif // GPR_LINUX diff --git a/src/core/lib/security/credentials/alts/check_gcp_environment_no_op.cc b/src/core/lib/security/credentials/alts/check_gcp_environment_no_op.cc new file mode 100644 index 00000000..cb401025 --- /dev/null +++ b/src/core/lib/security/credentials/alts/check_gcp_environment_no_op.cc @@ -0,0 +1,33 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#if !defined(GPR_LINUX) && !defined(GPR_WINDOWS) + +#include + +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" + +bool grpc_alts_is_running_on_gcp() { + gpr_log(GPR_INFO, + "ALTS: Platforms other than Linux and Windows are not supported"); + return false; +} + +#endif // !defined(LINUX) && !defined(GPR_WINDOWS) diff --git a/src/core/lib/security/credentials/alts/check_gcp_environment_windows.cc b/src/core/lib/security/credentials/alts/check_gcp_environment_windows.cc new file mode 100644 index 00000000..5d2bdc14 --- /dev/null +++ b/src/core/lib/security/credentials/alts/check_gcp_environment_windows.cc @@ -0,0 +1,102 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#ifdef GPR_WINDOWS + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" + +namespace grpc_core { +namespace internal { + +bool check_bios_data(const char*) { return false; } + +bool check_windows_registry_product_name(HKEY root_key, + const char* reg_key_path, + const char* reg_key_name) { + const size_t kProductNameBufferSize = 256; + char const expected_substr[] = "Google"; + + // Get the size of the string first to allocate our buffer. This includes + // enough space for the trailing NUL character that will be included. + DWORD buffer_size{}; + auto rc = ::RegGetValueA( + root_key, reg_key_path, reg_key_name, RRF_RT_REG_SZ, + nullptr, // We know the type will be REG_SZ. + nullptr, // We're only fetching the size; no buffer given yet. + &buffer_size); // Fetch the size in bytes of the value, if it exists. + if (rc != 0) { + return false; + } + + if (buffer_size > kProductNameBufferSize) { + return false; + } + + // Retrieve the product name string. + char buffer[kProductNameBufferSize]; + buffer_size = kProductNameBufferSize; + rc = ::RegGetValueA( + root_key, reg_key_path, reg_key_name, RRF_RT_REG_SZ, + nullptr, // We know the type will be REG_SZ. + static_cast(buffer), // Fetch the string value this time. + &buffer_size); // The string size in bytes, not including trailing NUL. + if (rc != 0) { + return false; + } + + return strstr(buffer, expected_substr) != nullptr; +} + +} // namespace internal +} // namespace grpc_core + +static bool g_compute_engine_detection_done = false; +static bool g_is_on_compute_engine = false; +static gpr_mu g_mu; +static gpr_once g_once = GPR_ONCE_INIT; + +static void init_mu(void) { gpr_mu_init(&g_mu); } + +bool grpc_alts_is_running_on_gcp() { + char const reg_key_path[] = "SYSTEM\\HardwareConfig\\Current\\"; + char const reg_key_name[] = "SystemProductName"; + + gpr_once_init(&g_once, init_mu); + gpr_mu_lock(&g_mu); + if (!g_compute_engine_detection_done) { + g_is_on_compute_engine = + grpc_core::internal::check_windows_registry_product_name( + HKEY_LOCAL_MACHINE, reg_key_path, reg_key_name); + g_compute_engine_detection_done = true; + } + gpr_mu_unlock(&g_mu); + return g_is_on_compute_engine; +} + +#endif // GPR_WINDOWS diff --git a/src/core/lib/security/credentials/alts/grpc_alts_credentials_client_options.cc b/src/core/lib/security/credentials/alts/grpc_alts_credentials_client_options.cc new file mode 100644 index 00000000..118d18d1 --- /dev/null +++ b/src/core/lib/security/credentials/alts/grpc_alts_credentials_client_options.cc @@ -0,0 +1,127 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include +#include +#include + +#include "src/core/lib/security/credentials/alts/grpc_alts_credentials_options.h" +#include "src/core/tsi/alts/handshaker/transport_security_common_api.h" + +static grpc_alts_credentials_options* alts_client_options_copy( + const grpc_alts_credentials_options* options); + +static void alts_client_options_destroy(grpc_alts_credentials_options* options); + +static target_service_account* target_service_account_create( + const char* service_account) { + if (service_account == nullptr) { + return nullptr; + } + auto* sa = static_cast( + gpr_zalloc(sizeof(target_service_account))); + sa->data = gpr_strdup(service_account); + return sa; +} + +void grpc_alts_credentials_client_options_add_target_service_account( + grpc_alts_credentials_options* options, const char* service_account) { + if (options == nullptr || service_account == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_alts_credentials_client_options_add_target_service_account()"); + return; + } + auto client_options = + reinterpret_cast(options); + target_service_account* node = target_service_account_create(service_account); + node->next = client_options->target_account_list_head; + client_options->target_account_list_head = node; +} + +static void target_service_account_destroy( + target_service_account* service_account) { + if (service_account == nullptr) { + return; + } + gpr_free(service_account->data); + gpr_free(service_account); +} + +static const grpc_alts_credentials_options_vtable vtable = { + alts_client_options_copy, alts_client_options_destroy}; + +grpc_alts_credentials_options* grpc_alts_credentials_client_options_create( + void) { + auto client_options = static_cast( + gpr_zalloc(sizeof(grpc_alts_credentials_client_options))); + client_options->base.vtable = &vtable; + return &client_options->base; +} + +static grpc_alts_credentials_options* alts_client_options_copy( + const grpc_alts_credentials_options* options) { + if (options == nullptr) { + return nullptr; + } + grpc_alts_credentials_options* new_options = + grpc_alts_credentials_client_options_create(); + auto new_client_options = + reinterpret_cast(new_options); + /* Copy target service accounts. */ + target_service_account* prev = nullptr; + auto node = + (reinterpret_cast(options)) + ->target_account_list_head; + while (node != nullptr) { + target_service_account* new_node = + target_service_account_create(node->data); + if (prev == nullptr) { + new_client_options->target_account_list_head = new_node; + } else { + prev->next = new_node; + } + prev = new_node; + node = node->next; + } + /* Copy rpc protocol versions. 
*/ + grpc_gcp_rpc_protocol_versions_copy(&options->rpc_versions, + &new_options->rpc_versions); + return new_options; +} + +static void alts_client_options_destroy( + grpc_alts_credentials_options* options) { + if (options == nullptr) { + return; + } + auto* client_options = + reinterpret_cast(options); + target_service_account* node = client_options->target_account_list_head; + while (node != nullptr) { + target_service_account* next_node = node->next; + target_service_account_destroy(node); + node = next_node; + } +} diff --git a/src/core/lib/security/credentials/alts/grpc_alts_credentials_options.cc b/src/core/lib/security/credentials/alts/grpc_alts_credentials_options.cc new file mode 100644 index 00000000..d4281715 --- /dev/null +++ b/src/core/lib/security/credentials/alts/grpc_alts_credentials_options.cc @@ -0,0 +1,46 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/alts/grpc_alts_credentials_options.h" + +#include +#include + +grpc_alts_credentials_options* grpc_alts_credentials_options_copy( + const grpc_alts_credentials_options* options) { + if (options != nullptr && options->vtable != nullptr && + options->vtable->copy != nullptr) { + return options->vtable->copy(options); + } + /* An error occurred. */ + gpr_log(GPR_ERROR, + "Invalid arguments to grpc_alts_credentials_options_copy()"); + return nullptr; +} + +void grpc_alts_credentials_options_destroy( + grpc_alts_credentials_options* options) { + if (options != nullptr) { + if (options->vtable != nullptr && options->vtable->destruct != nullptr) { + options->vtable->destruct(options); + } + gpr_free(options); + } +} diff --git a/src/core/lib/security/credentials/alts/grpc_alts_credentials_server_options.cc b/src/core/lib/security/credentials/alts/grpc_alts_credentials_server_options.cc new file mode 100644 index 00000000..72c65421 --- /dev/null +++ b/src/core/lib/security/credentials/alts/grpc_alts_credentials_server_options.cc @@ -0,0 +1,59 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include +#include + +#include "src/core/lib/security/credentials/alts/grpc_alts_credentials_options.h" +#include "src/core/tsi/alts/handshaker/transport_security_common_api.h" + +static grpc_alts_credentials_options* alts_server_options_copy( + const grpc_alts_credentials_options* options); + +static void alts_server_options_destroy( + grpc_alts_credentials_options* /*options*/) {} + +static const grpc_alts_credentials_options_vtable vtable = { + alts_server_options_copy, alts_server_options_destroy}; + +grpc_alts_credentials_options* grpc_alts_credentials_server_options_create( + void) { + grpc_alts_credentials_server_options* server_options = + static_cast( + gpr_zalloc(sizeof(*server_options))); + server_options->base.vtable = &vtable; + return &server_options->base; +} + +static grpc_alts_credentials_options* alts_server_options_copy( + const grpc_alts_credentials_options* options) { + if (options == nullptr) { + return nullptr; + } + grpc_alts_credentials_options* new_options = + grpc_alts_credentials_server_options_create(); + /* Copy rpc protocol versions. */ + grpc_gcp_rpc_protocol_versions_copy(&options->rpc_versions, + &new_options->rpc_versions); + return new_options; +} diff --git a/src/core/lib/security/credentials/composite/composite_credentials.cc b/src/core/lib/security/credentials/composite/composite_credentials.cc new file mode 100644 index 00000000..83bc728c --- /dev/null +++ b/src/core/lib/security/credentials/composite/composite_credentials.cc @@ -0,0 +1,230 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/composite/composite_credentials.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/surface/api_trace.h" + +/* -- Composite call credentials. 
-- */ + +static void composite_call_metadata_cb(void* arg, grpc_error_handle error); + +namespace { +struct grpc_composite_call_credentials_metadata_context { + grpc_composite_call_credentials_metadata_context( + grpc_composite_call_credentials* composite_creds, + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata) + : composite_creds(composite_creds), + pollent(pollent), + auth_md_context(auth_md_context), + md_array(md_array), + on_request_metadata(on_request_metadata) { + GRPC_CLOSURE_INIT(&internal_on_request_metadata, composite_call_metadata_cb, + this, grpc_schedule_on_exec_ctx); + } + + grpc_composite_call_credentials* composite_creds; + size_t creds_index = 0; + grpc_polling_entity* pollent; + grpc_auth_metadata_context auth_md_context; + grpc_credentials_mdelem_array* md_array; + grpc_closure* on_request_metadata; + grpc_closure internal_on_request_metadata; +}; +} // namespace + +static void composite_call_metadata_cb(void* arg, grpc_error_handle error) { + grpc_composite_call_credentials_metadata_context* ctx = + static_cast(arg); + if (error == GRPC_ERROR_NONE) { + const grpc_composite_call_credentials::CallCredentialsList& inner = + ctx->composite_creds->inner(); + /* See if we need to get some more metadata. */ + if (ctx->creds_index < inner.size()) { + if (inner[ctx->creds_index++]->get_request_metadata( + ctx->pollent, ctx->auth_md_context, ctx->md_array, + &ctx->internal_on_request_metadata, &error)) { + // Synchronous response, so call ourselves recursively. + composite_call_metadata_cb(arg, error); + GRPC_ERROR_UNREF(error); + } + return; + } + // We're done! + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, ctx->on_request_metadata, + GRPC_ERROR_REF(error)); + delete ctx; +} + +bool grpc_composite_call_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error_handle* error) { + grpc_composite_call_credentials_metadata_context* ctx; + ctx = new grpc_composite_call_credentials_metadata_context( + this, pollent, auth_md_context, md_array, on_request_metadata); + bool synchronous = true; + const CallCredentialsList& inner = ctx->composite_creds->inner(); + while (ctx->creds_index < inner.size()) { + if (inner[ctx->creds_index++]->get_request_metadata( + ctx->pollent, ctx->auth_md_context, ctx->md_array, + &ctx->internal_on_request_metadata, error)) { + if (*error != GRPC_ERROR_NONE) break; + } else { + synchronous = false; // Async return. + break; + } + } + if (synchronous) delete ctx; + return synchronous; +} + +void grpc_composite_call_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error_handle error) { + for (size_t i = 0; i < inner_.size(); ++i) { + inner_[i]->cancel_get_request_metadata(md_array, GRPC_ERROR_REF(error)); + } + GRPC_ERROR_UNREF(error); +} + +std::string grpc_composite_call_credentials::debug_string() { + std::vector outputs; + for (auto& inner_cred : inner_) { + outputs.emplace_back(inner_cred->debug_string()); + } + return absl::StrCat("CompositeCallCredentials{", absl::StrJoin(outputs, ","), + "}"); +} + +static size_t get_creds_array_size(const grpc_call_credentials* creds, + bool is_composite) { + return is_composite + ? 
static_cast(creds) + ->inner() + .size() + : 1; +} + +void grpc_composite_call_credentials::push_to_inner( + grpc_core::RefCountedPtr creds, bool is_composite) { + if (!is_composite) { + inner_.push_back(std::move(creds)); + return; + } + auto composite_creds = + static_cast(creds.get()); + for (size_t i = 0; i < composite_creds->inner().size(); ++i) { + inner_.push_back(composite_creds->inner_[i]); + } +} + +grpc_composite_call_credentials::grpc_composite_call_credentials( + grpc_core::RefCountedPtr creds1, + grpc_core::RefCountedPtr creds2) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) { + const bool creds1_is_composite = + strcmp(creds1->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const bool creds2_is_composite = + strcmp(creds2->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const size_t size = get_creds_array_size(creds1.get(), creds1_is_composite) + + get_creds_array_size(creds2.get(), creds2_is_composite); + inner_.reserve(size); + push_to_inner(std::move(creds1), creds1_is_composite); + push_to_inner(std::move(creds2), creds2_is_composite); + min_security_level_ = GRPC_SECURITY_NONE; + for (size_t i = 0; i < inner_.size(); ++i) { + if (static_cast(min_security_level_) < + static_cast(inner_[i]->min_security_level())) { + min_security_level_ = inner_[i]->min_security_level(); + } + } +} + +static grpc_core::RefCountedPtr +composite_call_credentials_create( + grpc_core::RefCountedPtr creds1, + grpc_core::RefCountedPtr creds2) { + return grpc_core::MakeRefCounted( + std::move(creds1), std::move(creds2)); +} + +grpc_call_credentials* grpc_composite_call_credentials_create( + grpc_call_credentials* creds1, grpc_call_credentials* creds2, + void* reserved) { + GRPC_API_TRACE( + "grpc_composite_call_credentials_create(creds1=%p, creds2=%p, " + "reserved=%p)", + 3, (creds1, creds2, reserved)); + GPR_ASSERT(reserved == nullptr); + GPR_ASSERT(creds1 != nullptr); + GPR_ASSERT(creds2 != nullptr); + + return composite_call_credentials_create(creds1->Ref(), creds2->Ref()) + .release(); +} + +/* -- Composite channel credentials. -- */ + +grpc_core::RefCountedPtr +grpc_composite_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) { + GPR_ASSERT(inner_creds_ != nullptr && call_creds_ != nullptr); + /* If we are passed a call_creds, create a call composite to pass it + downstream. 
*/ + if (call_creds != nullptr) { + return inner_creds_->create_security_connector( + composite_call_credentials_create(call_creds_, std::move(call_creds)), + target, args, new_args); + } else { + return inner_creds_->create_security_connector(call_creds_, target, args, + new_args); + } +} + +grpc_channel_credentials* grpc_composite_channel_credentials_create( + grpc_channel_credentials* channel_creds, grpc_call_credentials* call_creds, + void* reserved) { + GPR_ASSERT(channel_creds != nullptr && call_creds != nullptr && + reserved == nullptr); + GRPC_API_TRACE( + "grpc_composite_channel_credentials_create(channel_creds=%p, " + "call_creds=%p, reserved=%p)", + 3, (channel_creds, call_creds, reserved)); + return new grpc_composite_channel_credentials(channel_creds->Ref(), + call_creds->Ref()); +} diff --git a/src/core/lib/security/credentials/credentials.cc b/src/core/lib/security/credentials/credentials.cc new file mode 100644 index 00000000..01e409cf --- /dev/null +++ b/src/core/lib/security/credentials/credentials.cc @@ -0,0 +1,164 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/credentials.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/http/parser.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/surface/api_trace.h" + +/* -- Common. 
-- */ + +void grpc_channel_credentials_release(grpc_channel_credentials* creds) { + GRPC_API_TRACE("grpc_channel_credentials_release(creds=%p)", 1, (creds)); + grpc_core::ExecCtx exec_ctx; + if (creds) creds->Unref(); +} + +void grpc_call_credentials_release(grpc_call_credentials* creds) { + GRPC_API_TRACE("grpc_call_credentials_release(creds=%p)", 1, (creds)); + grpc_core::ExecCtx exec_ctx; + if (creds) creds->Unref(); +} + +static void credentials_pointer_arg_destroy(void* p) { + static_cast(p)->Unref(); +} + +static void* credentials_pointer_arg_copy(void* p) { + return static_cast(p)->Ref().release(); +} + +static int credentials_pointer_cmp(void* a, void* b) { + return grpc_core::QsortCompare(a, b); +} + +static const grpc_arg_pointer_vtable credentials_pointer_vtable = { + credentials_pointer_arg_copy, credentials_pointer_arg_destroy, + credentials_pointer_cmp}; + +grpc_arg grpc_channel_credentials_to_arg( + grpc_channel_credentials* credentials) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_CHANNEL_CREDENTIALS), credentials, + &credentials_pointer_vtable); +} + +grpc_channel_credentials* grpc_channel_credentials_from_arg( + const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_ARG_CHANNEL_CREDENTIALS) != 0) return nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, "Invalid type %d for arg %s", arg->type, + GRPC_ARG_CHANNEL_CREDENTIALS); + return nullptr; + } + return static_cast(arg->value.pointer.p); +} + +grpc_channel_credentials* grpc_channel_credentials_find_in_args( + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_channel_credentials* credentials = + grpc_channel_credentials_from_arg(&args->args[i]); + if (credentials != nullptr) return credentials; + } + return nullptr; +} + +void grpc_server_credentials_release(grpc_server_credentials* creds) { + GRPC_API_TRACE("grpc_server_credentials_release(creds=%p)", 1, (creds)); + grpc_core::ExecCtx exec_ctx; + if (creds) creds->Unref(); +} + +void grpc_server_credentials::set_auth_metadata_processor( + const grpc_auth_metadata_processor& processor) { + GRPC_API_TRACE( + "grpc_server_credentials_set_auth_metadata_processor(" + "creds=%p, " + "processor=grpc_auth_metadata_processor { process: %p, state: %p })", + 3, (this, (void*)(intptr_t)processor.process, processor.state)); + DestroyProcessor(); + processor_ = processor; +} + +void grpc_server_credentials_set_auth_metadata_processor( + grpc_server_credentials* creds, grpc_auth_metadata_processor processor) { + GPR_DEBUG_ASSERT(creds != nullptr); + creds->set_auth_metadata_processor(processor); +} + +static void server_credentials_pointer_arg_destroy(void* p) { + static_cast(p)->Unref(); +} + +static void* server_credentials_pointer_arg_copy(void* p) { + return static_cast(p)->Ref().release(); +} + +static int server_credentials_pointer_cmp(void* a, void* b) { + return grpc_core::QsortCompare(a, b); +} + +static const grpc_arg_pointer_vtable cred_ptr_vtable = { + server_credentials_pointer_arg_copy, server_credentials_pointer_arg_destroy, + server_credentials_pointer_cmp}; + +grpc_arg grpc_server_credentials_to_arg(grpc_server_credentials* c) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_SERVER_CREDENTIALS_ARG), c, &cred_ptr_vtable); +} + +grpc_server_credentials* grpc_server_credentials_from_arg(const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_SERVER_CREDENTIALS_ARG) != 0) return nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, 
"Invalid type %d for arg %s", arg->type, + GRPC_SERVER_CREDENTIALS_ARG); + return nullptr; + } + return static_cast(arg->value.pointer.p); +} + +grpc_server_credentials* grpc_find_server_credentials_in_args( + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_server_credentials* p = + grpc_server_credentials_from_arg(&args->args[i]); + if (p != nullptr) return p; + } + return nullptr; +} diff --git a/src/core/lib/security/credentials/credentials_metadata.cc b/src/core/lib/security/credentials/credentials_metadata.cc new file mode 100644 index 00000000..9d0284d3 --- /dev/null +++ b/src/core/lib/security/credentials/credentials_metadata.cc @@ -0,0 +1,61 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/slice/slice_internal.h" + +static void mdelem_list_ensure_capacity(grpc_credentials_mdelem_array* list, + size_t additional_space_needed) { + size_t target_size = list->size + additional_space_needed; + // Find the next power of two greater than the target size (i.e., + // whenever we add more space, we double what we already have). + size_t new_size = 2; + while (new_size < target_size) { + new_size *= 2; + } + list->md = static_cast( + gpr_realloc(list->md, sizeof(grpc_mdelem) * new_size)); +} + +void grpc_credentials_mdelem_array_add(grpc_credentials_mdelem_array* list, + grpc_mdelem md) { + mdelem_list_ensure_capacity(list, 1); + list->md[list->size++] = GRPC_MDELEM_REF(md); +} + +void grpc_credentials_mdelem_array_append(grpc_credentials_mdelem_array* dst, + grpc_credentials_mdelem_array* src) { + mdelem_list_ensure_capacity(dst, src->size); + for (size_t i = 0; i < src->size; ++i) { + dst->md[dst->size++] = GRPC_MDELEM_REF(src->md[i]); + } +} + +void grpc_credentials_mdelem_array_destroy( + grpc_credentials_mdelem_array* list) { + for (size_t i = 0; i < list->size; ++i) { + GRPC_MDELEM_UNREF(list->md[i]); + } + gpr_free(list->md); +} diff --git a/src/core/lib/security/credentials/external/aws_external_account_credentials.cc b/src/core/lib/security/credentials/external/aws_external_account_credentials.cc new file mode 100644 index 00000000..03dd4d59 --- /dev/null +++ b/src/core/lib/security/credentials/external/aws_external_account_credentials.cc @@ -0,0 +1,404 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +#include + +#include "src/core/lib/security/credentials/external/aws_external_account_credentials.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" + +#include "src/core/lib/gpr/env.h" + +namespace grpc_core { + +namespace { + +const char* kExpectedEnvironmentId = "aws1"; + +const char* kRegionEnvVar = "AWS_REGION"; +const char* kDefaultRegionEnvVar = "AWS_DEFAULT_REGION"; +const char* kAccessKeyIdEnvVar = "AWS_ACCESS_KEY_ID"; +const char* kSecretAccessKeyEnvVar = "AWS_SECRET_ACCESS_KEY"; +const char* kSessionTokenEnvVar = "AWS_SESSION_TOKEN"; + +std::string UrlEncode(const absl::string_view& s) { + const char* hex = "0123456789ABCDEF"; + std::string result; + result.reserve(s.length()); + for (auto c : s) { + if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || c == '-' || c == '_' || c == '!' || + c == '\'' || c == '(' || c == ')' || c == '*' || c == '~' || c == '.') { + result.push_back(c); + } else { + result.push_back('%'); + result.push_back(hex[static_cast(c) >> 4]); + result.push_back(hex[static_cast(c) & 15]); + } + } + return result; +} + +} // namespace + +RefCountedPtr +AwsExternalAccountCredentials::Create(Options options, + std::vector scopes, + grpc_error_handle* error) { + auto creds = MakeRefCounted( + std::move(options), std::move(scopes), error); + if (*error == GRPC_ERROR_NONE) { + return creds; + } else { + return nullptr; + } +} + +AwsExternalAccountCredentials::AwsExternalAccountCredentials( + Options options, std::vector scopes, grpc_error_handle* error) + : ExternalAccountCredentials(options, std::move(scopes)) { + audience_ = options.audience; + auto it = options.credential_source.object_value().find("environment_id"); + if (it == options.credential_source.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "environment_id field not present."); + return; + } + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "environment_id field must be a string."); + return; + } + if (it->second.string_value() != kExpectedEnvironmentId) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("environment_id does not match."); + return; + } + it = options.credential_source.object_value().find("region_url"); + if (it == options.credential_source.object_value().end()) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("region_url field not present."); + return; + } + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "region_url field must be a string."); + return; + } + region_url_ = it->second.string_value(); + it = options.credential_source.object_value().find("url"); + if (it != options.credential_source.object_value().end() && + it->second.type() == Json::Type::STRING) { + url_ = it->second.string_value(); + } + it = options.credential_source.object_value().find( + "regional_cred_verification_url"); + if (it == options.credential_source.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "regional_cred_verification_url field not present."); + return; + } + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "regional_cred_verification_url field must be a string."); + return; + } + regional_cred_verification_url_ = it->second.string_value(); +} + +void 
AwsExternalAccountCredentials::RetrieveSubjectToken( + HTTPRequestContext* ctx, const Options& /*options*/, + std::function cb) { + if (ctx == nullptr) { + FinishRetrieveSubjectToken( + "", + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Missing HTTPRequestContext to start subject token retrieval.")); + return; + } + ctx_ = ctx; + cb_ = cb; + if (signer_ != nullptr) { + BuildSubjectToken(); + } else { + RetrieveRegion(); + } +} + +void AwsExternalAccountCredentials::RetrieveRegion() { + UniquePtr region_from_env(gpr_getenv(kRegionEnvVar)); + if (region_from_env == nullptr) { + region_from_env = UniquePtr(gpr_getenv(kDefaultRegionEnvVar)); + } + if (region_from_env != nullptr) { + region_ = std::string(region_from_env.get()); + if (url_.empty()) { + RetrieveSigningKeys(); + } else { + RetrieveRoleName(); + } + return; + } + absl::StatusOr uri = URI::Parse(region_url_); + if (!uri.ok()) { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Invalid region url. %s", uri.status().ToString()))); + return; + } + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(uri->authority().c_str()); + request.http.path = gpr_strdup(uri->path().c_str()); + request.handshaker = + uri->scheme() == "https" ? &grpc_httpcli_ssl : &grpc_httpcli_plaintext; + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("external_account_credentials"); + grpc_http_response_destroy(&ctx_->response); + ctx_->response = {}; + GRPC_CLOSURE_INIT(&ctx_->closure, OnRetrieveRegion, this, nullptr); + grpc_httpcli_get(ctx_->httpcli_context, ctx_->pollent, resource_quota, + &request, ctx_->deadline, &ctx_->closure, &ctx_->response); + grpc_http_request_destroy(&request.http); +} + +void AwsExternalAccountCredentials::OnRetrieveRegion(void* arg, + grpc_error_handle error) { + AwsExternalAccountCredentials* self = + static_cast(arg); + self->OnRetrieveRegionInternal(GRPC_ERROR_REF(error)); +} + +void AwsExternalAccountCredentials::OnRetrieveRegionInternal( + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + FinishRetrieveSubjectToken("", error); + return; + } + // Remove the last letter of availability zone to get pure region + absl::string_view response_body(ctx_->response.body, + ctx_->response.body_length); + region_ = std::string(response_body.substr(0, response_body.size() - 1)); + if (url_.empty()) { + RetrieveSigningKeys(); + } else { + RetrieveRoleName(); + } +} + +void AwsExternalAccountCredentials::RetrieveRoleName() { + absl::StatusOr uri = URI::Parse(url_); + if (!uri.ok()) { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Invalid url: %s.", uri.status().ToString()))); + return; + } + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(uri->authority().c_str()); + request.http.path = gpr_strdup(uri->path().c_str()); + request.handshaker = + uri->scheme() == "https" ? 
&grpc_httpcli_ssl : &grpc_httpcli_plaintext; + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("external_account_credentials"); + grpc_http_response_destroy(&ctx_->response); + ctx_->response = {}; + GRPC_CLOSURE_INIT(&ctx_->closure, OnRetrieveRoleName, this, nullptr); + grpc_httpcli_get(ctx_->httpcli_context, ctx_->pollent, resource_quota, + &request, ctx_->deadline, &ctx_->closure, &ctx_->response); + grpc_http_request_destroy(&request.http); +} + +void AwsExternalAccountCredentials::OnRetrieveRoleName( + void* arg, grpc_error_handle error) { + AwsExternalAccountCredentials* self = + static_cast(arg); + self->OnRetrieveRoleNameInternal(GRPC_ERROR_REF(error)); +} + +void AwsExternalAccountCredentials::OnRetrieveRoleNameInternal( + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + FinishRetrieveSubjectToken("", error); + return; + } + role_name_ = std::string(ctx_->response.body, ctx_->response.body_length); + RetrieveSigningKeys(); +} + +void AwsExternalAccountCredentials::RetrieveSigningKeys() { + UniquePtr access_key_id_from_env(gpr_getenv(kAccessKeyIdEnvVar)); + UniquePtr secret_access_key_from_env( + gpr_getenv(kSecretAccessKeyEnvVar)); + UniquePtr token_from_env(gpr_getenv(kSessionTokenEnvVar)); + if (access_key_id_from_env != nullptr && + secret_access_key_from_env != nullptr && token_from_env != nullptr) { + access_key_id_ = std::string(access_key_id_from_env.get()); + secret_access_key_ = std::string(secret_access_key_from_env.get()); + token_ = std::string(token_from_env.get()); + BuildSubjectToken(); + return; + } + if (role_name_.empty()) { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Missing role name when retrieving signing keys.")); + return; + } + std::string url_with_role_name = absl::StrCat(url_, "/", role_name_); + absl::StatusOr uri = URI::Parse(url_with_role_name); + if (!uri.ok()) { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Invalid url with role name: %s.", uri.status().ToString()))); + return; + } + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(uri->authority().c_str()); + request.http.path = gpr_strdup(uri->path().c_str()); + request.handshaker = + uri->scheme() == "https" ? 
&grpc_httpcli_ssl : &grpc_httpcli_plaintext; + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("external_account_credentials"); + grpc_http_response_destroy(&ctx_->response); + ctx_->response = {}; + GRPC_CLOSURE_INIT(&ctx_->closure, OnRetrieveSigningKeys, this, nullptr); + grpc_httpcli_get(ctx_->httpcli_context, ctx_->pollent, resource_quota, + &request, ctx_->deadline, &ctx_->closure, &ctx_->response); + grpc_http_request_destroy(&request.http); +} + +void AwsExternalAccountCredentials::OnRetrieveSigningKeys( + void* arg, grpc_error_handle error) { + AwsExternalAccountCredentials* self = + static_cast(arg); + self->OnRetrieveSigningKeysInternal(GRPC_ERROR_REF(error)); +} + +void AwsExternalAccountCredentials::OnRetrieveSigningKeysInternal( + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + FinishRetrieveSubjectToken("", error); + return; + } + absl::string_view response_body(ctx_->response.body, + ctx_->response.body_length); + Json json = Json::Parse(response_body, &error); + if (error != GRPC_ERROR_NONE || json.type() != Json::Type::OBJECT) { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Invalid retrieve signing keys response.", &error, 1)); + GRPC_ERROR_UNREF(error); + return; + } + auto it = json.object_value().find("AccessKeyId"); + if (it != json.object_value().end() && + it->second.type() == Json::Type::STRING) { + access_key_id_ = it->second.string_value(); + } else { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Missing or invalid AccessKeyId in %s.", response_body))); + return; + } + it = json.object_value().find("SecretAccessKey"); + if (it != json.object_value().end() && + it->second.type() == Json::Type::STRING) { + secret_access_key_ = it->second.string_value(); + } else { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Missing or invalid SecretAccessKey in %s.", response_body))); + return; + } + it = json.object_value().find("Token"); + if (it != json.object_value().end() && + it->second.type() == Json::Type::STRING) { + token_ = it->second.string_value(); + } else { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Missing or invalid Token in %s.", response_body))); + return; + } + BuildSubjectToken(); +} + +void AwsExternalAccountCredentials::BuildSubjectToken() { + grpc_error_handle error = GRPC_ERROR_NONE; + if (signer_ == nullptr) { + cred_verification_url_ = absl::StrReplaceAll( + regional_cred_verification_url_, {{"{region}", region_}}); + signer_ = absl::make_unique( + access_key_id_, secret_access_key_, token_, "POST", + cred_verification_url_, region_, "", + std::map(), &error); + if (error != GRPC_ERROR_NONE) { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Creating aws request signer failed.", &error, 1)); + GRPC_ERROR_UNREF(error); + return; + } + } + auto signed_headers = signer_->GetSignedRequestHeaders(); + if (error != GRPC_ERROR_NONE) { + FinishRetrieveSubjectToken("", + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Invalid getting signed request" + "headers.", + &error, 1)); + GRPC_ERROR_UNREF(error); + return; + } + // Construct subject token + Json::Array headers; + headers.push_back(Json( + {{"key", "Authorization"}, {"value", signed_headers["Authorization"]}})); + headers.push_back(Json({{"key", "host"}, {"value", signed_headers["host"]}})); + headers.push_back( + Json({{"key", "x-amz-date"}, {"value", 
signed_headers["x-amz-date"]}})); + headers.push_back(Json({{"key", "x-amz-security-token"}, + {"value", signed_headers["x-amz-security-token"]}})); + headers.push_back( + Json({{"key", "x-goog-cloud-target-resource"}, {"value", audience_}})); + Json::Object object{{"url", Json(cred_verification_url_)}, + {"method", Json("POST")}, + {"headers", Json(headers)}}; + Json subject_token_json(object); + std::string subject_token = UrlEncode(subject_token_json.Dump()); + FinishRetrieveSubjectToken(subject_token, GRPC_ERROR_NONE); +} + +void AwsExternalAccountCredentials::FinishRetrieveSubjectToken( + std::string subject_token, grpc_error_handle error) { + // Reset context + ctx_ = nullptr; + // Move object state into local variables. + auto cb = cb_; + cb_ = nullptr; + // Invoke the callback. + if (error != GRPC_ERROR_NONE) { + cb("", error); + } else { + cb(subject_token, GRPC_ERROR_NONE); + } +} + +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/external/aws_request_signer.cc b/src/core/lib/security/credentials/external/aws_request_signer.cc new file mode 100644 index 00000000..bc4920f3 --- /dev/null +++ b/src/core/lib/security/credentials/external/aws_request_signer.cc @@ -0,0 +1,214 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +#include + +#include "src/core/lib/security/credentials/external/aws_request_signer.h" + +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" + +namespace grpc_core { + +namespace { + +const char kAlgorithm[] = "AWS4-HMAC-SHA256"; +const char kDateFormat[] = "%a, %d %b %E4Y %H:%M:%S %Z"; +const char kXAmzDateFormat[] = "%Y%m%dT%H%M%SZ"; + +void SHA256(const std::string& str, unsigned char out[SHA256_DIGEST_LENGTH]) { + SHA256_CTX sha256; + SHA256_Init(&sha256); + SHA256_Update(&sha256, str.c_str(), str.size()); + SHA256_Final(out, &sha256); +} + +std::string SHA256Hex(const std::string& str) { + unsigned char hash[SHA256_DIGEST_LENGTH]; + SHA256(str, hash); + std::string hash_str(reinterpret_cast(hash), + SHA256_DIGEST_LENGTH); + return absl::BytesToHexString(hash_str); +} + +std::string HMAC(const std::string& key, const std::string& msg) { + unsigned int len; + unsigned char digest[EVP_MAX_MD_SIZE]; + HMAC(EVP_sha256(), key.c_str(), key.length(), + reinterpret_cast(msg.c_str()), msg.length(), + digest, &len); + return std::string(digest, digest + len); +} + +} // namespace + +AwsRequestSigner::AwsRequestSigner( + std::string access_key_id, std::string secret_access_key, std::string token, + std::string method, std::string url, std::string region, + std::string request_payload, + std::map additional_headers, + grpc_error_handle* error) + : access_key_id_(std::move(access_key_id)), + secret_access_key_(std::move(secret_access_key)), + token_(std::move(token)), + method_(std::move(method)), + region_(std::move(region)), + request_payload_(std::move(request_payload)), + additional_headers_(std::move(additional_headers)) { + auto amz_date_it = additional_headers_.find("x-amz-date"); + auto date_it = additional_headers_.find("date"); + if (amz_date_it != additional_headers_.end() && + date_it != additional_headers_.end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Only one of {date, x-amz-date} can be specified, not both."); + return; + } + if (amz_date_it != additional_headers_.end()) { + static_request_date_ = amz_date_it->second; + } else if (date_it != additional_headers_.end()) { + absl::Time request_date; + std::string err_str; + if (!absl::ParseTime(kDateFormat, date_it->second, &request_date, + &err_str)) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(err_str.c_str()); + return; + } + static_request_date_ = + absl::FormatTime(kXAmzDateFormat, request_date, absl::UTCTimeZone()); + } + absl::StatusOr tmp_url = URI::Parse(url); + if (!tmp_url.ok()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Invalid Aws request url."); + return; + } + url_ = tmp_url.value(); +} + +std::map AwsRequestSigner::GetSignedRequestHeaders() { + std::string request_date_full; + if (!static_request_date_.empty()) { + if (!request_headers_.empty()) { + return request_headers_; + } + request_date_full = static_request_date_; + } else { + absl::Time request_date = absl::Now(); + request_date_full = + absl::FormatTime(kXAmzDateFormat, request_date, absl::UTCTimeZone()); + } + std::string request_date_short = request_date_full.substr(0, 8); + // TASK 1: Create a canonical request for Signature Version 4 + // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + std::vector canonical_request_vector; + // 1. 
HTTPRequestMethod + canonical_request_vector.emplace_back(method_); + canonical_request_vector.emplace_back("\n"); + // 2. CanonicalURI + canonical_request_vector.emplace_back( + url_.path().empty() ? "/" : absl::string_view(url_.path())); + canonical_request_vector.emplace_back("\n"); + // 3. CanonicalQueryString + std::vector query_vector; + for (const URI::QueryParam& query_kv : url_.query_parameter_pairs()) { + query_vector.emplace_back(absl::StrCat(query_kv.key, "=", query_kv.value)); + } + std::string query = absl::StrJoin(query_vector, "&"); + canonical_request_vector.emplace_back(query); + canonical_request_vector.emplace_back("\n"); + // 4. CanonicalHeaders + if (request_headers_.empty()) { + request_headers_.insert({"host", url_.authority()}); + if (!token_.empty()) { + request_headers_.insert({"x-amz-security-token", token_}); + } + for (const auto& header : additional_headers_) { + request_headers_.insert( + {absl::AsciiStrToLower(header.first), header.second}); + } + } + if (additional_headers_.find("date") == additional_headers_.end()) { + request_headers_["x-amz-date"] = request_date_full; + } + std::vector canonical_headers_vector; + for (const auto& header : request_headers_) { + canonical_headers_vector.emplace_back(header.first); + canonical_headers_vector.emplace_back(":"); + canonical_headers_vector.emplace_back(header.second); + canonical_headers_vector.emplace_back("\n"); + } + std::string canonical_headers = absl::StrJoin(canonical_headers_vector, ""); + canonical_request_vector.emplace_back(canonical_headers); + canonical_request_vector.emplace_back("\n"); + // 5. SignedHeaders + std::vector signed_headers_vector; + for (const auto& header : request_headers_) { + signed_headers_vector.emplace_back(header.first); + } + std::string signed_headers = absl::StrJoin(signed_headers_vector, ";"); + canonical_request_vector.emplace_back(signed_headers); + canonical_request_vector.emplace_back("\n"); + // 6. RequestPayload + std::string hashed_request_payload = SHA256Hex(request_payload_); + canonical_request_vector.emplace_back(hashed_request_payload); + std::string canonical_request = absl::StrJoin(canonical_request_vector, ""); + // TASK 2: Create a string to sign for Signature Version 4 + // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html + std::vector string_to_sign_vector; + // 1. Algorithm + string_to_sign_vector.emplace_back("AWS4-HMAC-SHA256"); + string_to_sign_vector.emplace_back("\n"); + // 2. RequestDateTime + string_to_sign_vector.emplace_back(request_date_full); + string_to_sign_vector.emplace_back("\n"); + // 3. CredentialScope + std::pair host_parts = + absl::StrSplit(url_.authority(), absl::MaxSplits('.', 1)); + std::string service_name(host_parts.first); + std::string credential_scope = absl::StrFormat( + "%s/%s/%s/aws4_request", request_date_short, region_, service_name); + string_to_sign_vector.emplace_back(credential_scope); + string_to_sign_vector.emplace_back("\n"); + // 4. HashedCanonicalRequest + std::string hashed_canonical_request = SHA256Hex(canonical_request); + string_to_sign_vector.emplace_back(hashed_canonical_request); + std::string string_to_sign = absl::StrJoin(string_to_sign_vector, ""); + // TASK 3: Task 3: Calculate the signature for AWS Signature Version 4 + // https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + // 1. Derive your signing key. 
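// For reference (illustrative comment, not part of the patch): the lines
// below implement the standard Signature Version 4 key-derivation chain,
// where HMAC denotes HMAC-SHA256:
//   kDate     = HMAC("AWS4" + secret_access_key, request_date_short)  // YYYYMMDD
//   kRegion   = HMAC(kDate, region)
//   kService  = HMAC(kRegion, service)
//   kSigning  = HMAC(kService, "aws4_request")
//   signature = HexEncode(HMAC(kSigning, string_to_sign))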
+ std::string date = HMAC("AWS4" + secret_access_key_, request_date_short); + std::string region = HMAC(date, region_); + std::string service = HMAC(region, service_name); + std::string signing = HMAC(service, "aws4_request"); + // 2. Calculate the signature. + std::string signature_str = HMAC(signing, string_to_sign); + std::string signature = absl::BytesToHexString(signature_str); + // TASK 4: Add the signature to the HTTP request + // https://docs.aws.amazon.com/general/latest/gr/sigv4-add-signature-to-request.html + std::string authorization_header = absl::StrFormat( + "%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", kAlgorithm, + access_key_id_, credential_scope, signed_headers, signature); + request_headers_["Authorization"] = authorization_header; + return request_headers_; +} + +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/external/external_account_credentials.cc b/src/core/lib/security/credentials/external/external_account_credentials.cc new file mode 100644 index 00000000..dbcb1238 --- /dev/null +++ b/src/core/lib/security/credentials/external/external_account_credentials.cc @@ -0,0 +1,527 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#include + +#include "src/core/lib/security/credentials/external/external_account_credentials.h" + +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" + +#include "src/core/lib/http/parser.h" +#include "src/core/lib/security/credentials/external/aws_external_account_credentials.h" +#include "src/core/lib/security/credentials/external/file_external_account_credentials.h" +#include "src/core/lib/security/credentials/external/url_external_account_credentials.h" +#include "src/core/lib/security/util/json_util.h" +#include "src/core/lib/slice/b64.h" + +#define EXTERNAL_ACCOUNT_CREDENTIALS_GRANT_TYPE \ + "urn:ietf:params:oauth:grant-type:token-exchange" +#define EXTERNAL_ACCOUNT_CREDENTIALS_REQUESTED_TOKEN_TYPE \ + "urn:ietf:params:oauth:token-type:access_token" +#define GOOGLE_CLOUD_PLATFORM_DEFAULT_SCOPE \ + "https://www.googleapis.com/auth/cloud-platform" + +namespace grpc_core { + +namespace { + +std::string UrlEncode(const absl::string_view& s) { + const char* hex = "0123456789ABCDEF"; + std::string result; + result.reserve(s.length()); + for (auto c : s) { + if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || c == '-' || c == '_' || c == '!' 
|| + c == '\'' || c == '(' || c == ')' || c == '*' || c == '~' || c == '.') { + result.push_back(c); + } else { + result.push_back('%'); + result.push_back(hex[static_cast(c) >> 4]); + result.push_back(hex[static_cast(c) & 15]); + } + } + return result; +} + +// Expression to match: +// //iam.googleapis.com/locations/[^/]+/workforcePools/[^/]+/providers/.+ +bool MatchWorkforcePoolAudience(absl::string_view audience) { + // Match "//iam.googleapis.com/locations/" + if (!absl::ConsumePrefix(&audience, "//iam.googleapis.com")) return false; + if (!absl::ConsumePrefix(&audience, "/locations/")) return false; + // Match "[^/]+/workforcePools/" + std::pair workforce_pools_split_result = + absl::StrSplit(audience, absl::MaxSplits("/workforcePools/", 1)); + if (absl::StrContains(workforce_pools_split_result.first, '/')) return false; + // Match "[^/]+/providers/.+" + std::pair providers_split_result = + absl::StrSplit(workforce_pools_split_result.second, + absl::MaxSplits("/providers/", 1)); + return !absl::StrContains(providers_split_result.first, '/'); +} + +} // namespace + +RefCountedPtr ExternalAccountCredentials::Create( + const Json& json, std::vector scopes, + grpc_error_handle* error) { + GPR_ASSERT(*error == GRPC_ERROR_NONE); + Options options; + options.type = GRPC_AUTH_JSON_TYPE_INVALID; + if (json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid json to construct credentials options."); + return nullptr; + } + auto it = json.object_value().find("type"); + if (it == json.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("type field not present."); + return nullptr; + } + if (it->second.type() != Json::Type::STRING) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("type field must be a string."); + return nullptr; + } + if (it->second.string_value() != GRPC_AUTH_JSON_TYPE_EXTERNAL_ACCOUNT) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Invalid credentials json type."); + return nullptr; + } + options.type = GRPC_AUTH_JSON_TYPE_EXTERNAL_ACCOUNT; + it = json.object_value().find("audience"); + if (it == json.object_value().end()) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("audience field not present."); + return nullptr; + } + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "audience field must be a string."); + return nullptr; + } + options.audience = it->second.string_value(); + it = json.object_value().find("subject_token_type"); + if (it == json.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "subject_token_type field not present."); + return nullptr; + } + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "subject_token_type field must be a string."); + return nullptr; + } + options.subject_token_type = it->second.string_value(); + it = json.object_value().find("service_account_impersonation_url"); + if (it != json.object_value().end()) { + options.service_account_impersonation_url = it->second.string_value(); + } + it = json.object_value().find("token_url"); + if (it == json.object_value().end()) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("token_url field not present."); + return nullptr; + } + if (it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "token_url field must be a string."); + return nullptr; + } + options.token_url = it->second.string_value(); + it = json.object_value().find("token_info_url"); + if (it != 
json.object_value().end()) { + options.token_info_url = it->second.string_value(); + } + it = json.object_value().find("credential_source"); + if (it == json.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "credential_source field not present."); + return nullptr; + } + options.credential_source = it->second; + it = json.object_value().find("quota_project_id"); + if (it != json.object_value().end()) { + options.quota_project_id = it->second.string_value(); + } + it = json.object_value().find("client_id"); + if (it != json.object_value().end()) { + options.client_id = it->second.string_value(); + } + it = json.object_value().find("client_secret"); + if (it != json.object_value().end()) { + options.client_secret = it->second.string_value(); + } + it = json.object_value().find("workforce_pool_user_project"); + if (it != json.object_value().end()) { + if (MatchWorkforcePoolAudience(options.audience)) { + options.workforce_pool_user_project = it->second.string_value(); + } else { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "workforce_pool_user_project should not be set for non-workforce " + "pool credentials"); + return nullptr; + } + } + RefCountedPtr creds; + if (options.credential_source.object_value().find("environment_id") != + options.credential_source.object_value().end()) { + creds = MakeRefCounted( + std::move(options), std::move(scopes), error); + } else if (options.credential_source.object_value().find("file") != + options.credential_source.object_value().end()) { + creds = MakeRefCounted( + std::move(options), std::move(scopes), error); + } else if (options.credential_source.object_value().find("url") != + options.credential_source.object_value().end()) { + creds = MakeRefCounted( + std::move(options), std::move(scopes), error); + } else { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid options credential source to create " + "ExternalAccountCredentials."); + } + if (*error == GRPC_ERROR_NONE) { + return creds; + } else { + return nullptr; + } +} + +ExternalAccountCredentials::ExternalAccountCredentials( + Options options, std::vector scopes) + : options_(std::move(options)) { + if (scopes.empty()) { + scopes.push_back(GOOGLE_CLOUD_PLATFORM_DEFAULT_SCOPE); + } + scopes_ = std::move(scopes); +} + +ExternalAccountCredentials::~ExternalAccountCredentials() {} + +std::string ExternalAccountCredentials::debug_string() { + return absl::StrFormat("ExternalAccountCredentials{Audience:%s,%s}", + options_.audience, + grpc_oauth2_token_fetcher_credentials::debug_string()); +} + +// The token fetching flow: +// 1. Retrieve subject token - Subclass's RetrieveSubjectToken() gets called +// and the subject token is received in OnRetrieveSubjectTokenInternal(). +// 2. Exchange token - ExchangeToken() gets called with the +// subject token from #1. Receive the response in OnExchangeTokenInternal(). +// 3. (Optional) Impersonate service account - ImpersenateServiceAccount() gets +// called with the access token of the response from #2. Get an impersonated +// access token in OnImpersenateServiceAccountInternal(). +// 4. Finish token fetch - Return back the response that contains an access +// token in FinishTokenFetch(). +// TODO(chuanr): Avoid starting the remaining requests if the channel gets shut +// down. 
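// Illustrative usage sketch (not part of this patch): constructing call
// credentials from an external-account JSON config with a url-based
// credential source via the C-core entry point defined at the end of this
// file. All field values other than "type" are placeholders.
//
//   const char* config =
//       "{\"type\": \"external_account\","
//       " \"audience\": \"//iam.googleapis.com/projects/123/locations/global/"
//       "workloadIdentityPools/my-pool/providers/my-provider\","
//       " \"subject_token_type\": \"urn:ietf:params:oauth:token-type:jwt\","
//       " \"token_url\": \"https://sts.googleapis.com/v1/token\","
//       " \"credential_source\": {\"url\": \"https://localhost:8080/token\"}}";
//   grpc_call_credentials* creds = grpc_external_account_credentials_create(
//       config, "https://www.googleapis.com/auth/cloud-platform");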
+void ExternalAccountCredentials::fetch_oauth2( + grpc_credentials_metadata_request* metadata_req, + grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, + grpc_iomgr_cb_func response_cb, grpc_millis deadline) { + GPR_ASSERT(ctx_ == nullptr); + ctx_ = new HTTPRequestContext(httpcli_context, pollent, deadline); + metadata_req_ = metadata_req; + response_cb_ = response_cb; + auto cb = [this](std::string token, grpc_error_handle error) { + OnRetrieveSubjectTokenInternal(token, error); + }; + RetrieveSubjectToken(ctx_, options_, cb); +} + +void ExternalAccountCredentials::OnRetrieveSubjectTokenInternal( + absl::string_view subject_token, grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + FinishTokenFetch(error); + } else { + ExchangeToken(subject_token); + } +} + +void ExternalAccountCredentials::ExchangeToken( + absl::string_view subject_token) { + absl::StatusOr uri = URI::Parse(options_.token_url); + if (!uri.ok()) { + FinishTokenFetch(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Invalid token url: %s. Error: %s", options_.token_url, + uri.status().ToString()))); + return; + } + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(uri->authority().c_str()); + request.http.path = gpr_strdup(uri->path().c_str()); + grpc_http_header* headers = nullptr; + if (!options_.client_id.empty() && !options_.client_secret.empty()) { + request.http.hdr_count = 2; + headers = static_cast( + gpr_malloc(sizeof(grpc_http_header) * request.http.hdr_count)); + headers[0].key = gpr_strdup("Content-Type"); + headers[0].value = gpr_strdup("application/x-www-form-urlencoded"); + std::string raw_cred = + absl::StrFormat("%s:%s", options_.client_id, options_.client_secret); + char* encoded_cred = + grpc_base64_encode(raw_cred.c_str(), raw_cred.length(), 0, 0); + std::string str = absl::StrFormat("Basic %s", std::string(encoded_cred)); + headers[1].key = gpr_strdup("Authorization"); + headers[1].value = gpr_strdup(str.c_str()); + gpr_free(encoded_cred); + } else { + request.http.hdr_count = 1; + headers = static_cast( + gpr_malloc(sizeof(grpc_http_header) * request.http.hdr_count)); + headers[0].key = gpr_strdup("Content-Type"); + headers[0].value = gpr_strdup("application/x-www-form-urlencoded"); + } + request.http.hdrs = headers; + request.handshaker = + uri->scheme() == "https" ? 
&grpc_httpcli_ssl : &grpc_httpcli_plaintext; + std::vector body_parts; + body_parts.push_back( + absl::StrFormat("audience=%s", UrlEncode(options_.audience).c_str())); + body_parts.push_back(absl::StrFormat( + "grant_type=%s", + UrlEncode(EXTERNAL_ACCOUNT_CREDENTIALS_GRANT_TYPE).c_str())); + body_parts.push_back(absl::StrFormat( + "requested_token_type=%s", + UrlEncode(EXTERNAL_ACCOUNT_CREDENTIALS_REQUESTED_TOKEN_TYPE).c_str())); + body_parts.push_back(absl::StrFormat( + "subject_token_type=%s", UrlEncode(options_.subject_token_type).c_str())); + body_parts.push_back( + absl::StrFormat("subject_token=%s", UrlEncode(subject_token).c_str())); + std::string scope = GOOGLE_CLOUD_PLATFORM_DEFAULT_SCOPE; + if (options_.service_account_impersonation_url.empty()) { + scope = absl::StrJoin(scopes_, " "); + } + body_parts.push_back(absl::StrFormat("scope=%s", UrlEncode(scope).c_str())); + Json::Object addtional_options_json_object; + if (options_.client_id.empty() && options_.client_secret.empty()) { + addtional_options_json_object["userProject"] = + options_.workforce_pool_user_project; + } + Json addtional_options_json(std::move(addtional_options_json_object)); + body_parts.push_back(absl::StrFormat( + "options=%s", UrlEncode(addtional_options_json.Dump()).c_str())); + std::string body = absl::StrJoin(body_parts, "&"); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("external_account_credentials"); + grpc_http_response_destroy(&ctx_->response); + ctx_->response = {}; + GRPC_CLOSURE_INIT(&ctx_->closure, OnExchangeToken, this, nullptr); + grpc_httpcli_post(ctx_->httpcli_context, ctx_->pollent, resource_quota, + &request, body.c_str(), body.size(), ctx_->deadline, + &ctx_->closure, &ctx_->response); + grpc_http_request_destroy(&request.http); +} + +void ExternalAccountCredentials::OnExchangeToken(void* arg, + grpc_error_handle error) { + ExternalAccountCredentials* self = + static_cast(arg); + self->OnExchangeTokenInternal(GRPC_ERROR_REF(error)); +} + +void ExternalAccountCredentials::OnExchangeTokenInternal( + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + FinishTokenFetch(error); + } else { + if (options_.service_account_impersonation_url.empty()) { + metadata_req_->response = ctx_->response; + metadata_req_->response.body = gpr_strdup( + std::string(ctx_->response.body, ctx_->response.body_length).c_str()); + metadata_req_->response.hdrs = static_cast( + gpr_malloc(sizeof(grpc_http_header) * ctx_->response.hdr_count)); + for (size_t i = 0; i < ctx_->response.hdr_count; i++) { + metadata_req_->response.hdrs[i].key = + gpr_strdup(ctx_->response.hdrs[i].key); + metadata_req_->response.hdrs[i].value = + gpr_strdup(ctx_->response.hdrs[i].value); + } + FinishTokenFetch(GRPC_ERROR_NONE); + } else { + ImpersenateServiceAccount(); + } + } +} + +void ExternalAccountCredentials::ImpersenateServiceAccount() { + grpc_error_handle error = GRPC_ERROR_NONE; + absl::string_view response_body(ctx_->response.body, + ctx_->response.body_length); + Json json = Json::Parse(response_body, &error); + if (error != GRPC_ERROR_NONE || json.type() != Json::Type::OBJECT) { + FinishTokenFetch(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Invalid token exchange response.", &error, 1)); + GRPC_ERROR_UNREF(error); + return; + } + auto it = json.object_value().find("access_token"); + if (it == json.object_value().end() || + it->second.type() != Json::Type::STRING) { + FinishTokenFetch(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Missing or invalid access_token in %s.", 
response_body))); + return; + } + std::string access_token = it->second.string_value(); + absl::StatusOr uri = + URI::Parse(options_.service_account_impersonation_url); + if (!uri.ok()) { + FinishTokenFetch(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Invalid service account impersonation url: %s. Error: %s", + options_.service_account_impersonation_url, uri.status().ToString()))); + return; + } + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(uri->authority().c_str()); + request.http.path = gpr_strdup(uri->path().c_str()); + request.http.hdr_count = 2; + grpc_http_header* headers = static_cast( + gpr_malloc(sizeof(grpc_http_header) * request.http.hdr_count)); + headers[0].key = gpr_strdup("Content-Type"); + headers[0].value = gpr_strdup("application/x-www-form-urlencoded"); + std::string str = absl::StrFormat("Bearer %s", access_token); + headers[1].key = gpr_strdup("Authorization"); + headers[1].value = gpr_strdup(str.c_str()); + request.http.hdrs = headers; + request.handshaker = + uri->scheme() == "https" ? &grpc_httpcli_ssl : &grpc_httpcli_plaintext; + std::string scope = absl::StrJoin(scopes_, " "); + std::string body = absl::StrFormat("scope=%s", scope); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("external_account_credentials"); + grpc_http_response_destroy(&ctx_->response); + ctx_->response = {}; + GRPC_CLOSURE_INIT(&ctx_->closure, OnImpersenateServiceAccount, this, nullptr); + grpc_httpcli_post(ctx_->httpcli_context, ctx_->pollent, resource_quota, + &request, body.c_str(), body.size(), ctx_->deadline, + &ctx_->closure, &ctx_->response); + grpc_http_request_destroy(&request.http); +} + +void ExternalAccountCredentials::OnImpersenateServiceAccount( + void* arg, grpc_error_handle error) { + ExternalAccountCredentials* self = + static_cast(arg); + self->OnImpersenateServiceAccountInternal(GRPC_ERROR_REF(error)); +} + +void ExternalAccountCredentials::OnImpersenateServiceAccountInternal( + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + FinishTokenFetch(error); + return; + } + absl::string_view response_body(ctx_->response.body, + ctx_->response.body_length); + Json json = Json::Parse(response_body, &error); + if (error != GRPC_ERROR_NONE || json.type() != Json::Type::OBJECT) { + FinishTokenFetch(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Invalid service account impersonation response.", &error, 1)); + GRPC_ERROR_UNREF(error); + return; + } + auto it = json.object_value().find("accessToken"); + if (it == json.object_value().end() || + it->second.type() != Json::Type::STRING) { + FinishTokenFetch(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Missing or invalid accessToken in %s.", response_body))); + return; + } + std::string access_token = it->second.string_value(); + it = json.object_value().find("expireTime"); + if (it == json.object_value().end() || + it->second.type() != Json::Type::STRING) { + FinishTokenFetch(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Missing or invalid expireTime in %s.", response_body))); + return; + } + std::string expire_time = it->second.string_value(); + absl::Time t; + if (!absl::ParseTime(absl::RFC3339_full, expire_time, &t, nullptr)) { + FinishTokenFetch(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid expire time of service account impersonation response.")); + return; + } + int expire_in = (t - absl::Now()) / absl::Seconds(1); + std::string body = absl::StrFormat( + 
"{\"access_token\":\"%s\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", + access_token, expire_in); + metadata_req_->response = ctx_->response; + metadata_req_->response.body = gpr_strdup(body.c_str()); + metadata_req_->response.body_length = body.length(); + metadata_req_->response.hdrs = static_cast( + gpr_malloc(sizeof(grpc_http_header) * ctx_->response.hdr_count)); + for (size_t i = 0; i < ctx_->response.hdr_count; i++) { + metadata_req_->response.hdrs[i].key = + gpr_strdup(ctx_->response.hdrs[i].key); + metadata_req_->response.hdrs[i].value = + gpr_strdup(ctx_->response.hdrs[i].value); + } + FinishTokenFetch(GRPC_ERROR_NONE); +} + +void ExternalAccountCredentials::FinishTokenFetch(grpc_error_handle error) { + GRPC_LOG_IF_ERROR("Fetch external account credentials access token", + GRPC_ERROR_REF(error)); + // Move object state into local variables. + auto* cb = response_cb_; + response_cb_ = nullptr; + auto* metadata_req = metadata_req_; + metadata_req_ = nullptr; + auto* ctx = ctx_; + ctx_ = nullptr; + // Invoke the callback. + cb(metadata_req, error); + // Delete context. + delete ctx; + GRPC_ERROR_UNREF(error); +} + +} // namespace grpc_core + +grpc_call_credentials* grpc_external_account_credentials_create( + const char* json_string, const char* scopes_string) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json json = grpc_core::Json::Parse(json_string, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "External account credentials creation failed. Error: %s.", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + return nullptr; + } + std::vector scopes = absl::StrSplit(scopes_string, ','); + auto creds = grpc_core::ExternalAccountCredentials::Create( + json, std::move(scopes), &error) + .release(); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "External account credentials creation failed. Error: %s.", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + return nullptr; + } + return creds; +} diff --git a/src/core/lib/security/credentials/external/file_external_account_credentials.cc b/src/core/lib/security/credentials/external/file_external_account_credentials.cc new file mode 100644 index 00000000..d596d294 --- /dev/null +++ b/src/core/lib/security/credentials/external/file_external_account_credentials.cc @@ -0,0 +1,136 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +#include + +#include "src/core/lib/security/credentials/external/file_external_account_credentials.h" + +#include + +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_utils.h" + +namespace grpc_core { + +RefCountedPtr +FileExternalAccountCredentials::Create(Options options, + std::vector scopes, + grpc_error_handle* error) { + auto creds = MakeRefCounted( + std::move(options), std::move(scopes), error); + if (*error == GRPC_ERROR_NONE) { + return creds; + } else { + return nullptr; + } +} + +FileExternalAccountCredentials::FileExternalAccountCredentials( + Options options, std::vector scopes, grpc_error_handle* error) + : ExternalAccountCredentials(options, std::move(scopes)) { + auto it = options.credential_source.object_value().find("file"); + if (it == options.credential_source.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("file field not present."); + return; + } + if (it->second.type() != Json::Type::STRING) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("file field must be a string."); + return; + } + file_ = it->second.string_value(); + it = options.credential_source.object_value().find("format"); + if (it != options.credential_source.object_value().end()) { + const Json& format_json = it->second; + if (format_json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "The JSON value of credential source format is not an object."); + return; + } + auto format_it = format_json.object_value().find("type"); + if (format_it == format_json.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.type field not present."); + return; + } + if (format_it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.type field must be a string."); + return; + } + format_type_ = format_it->second.string_value(); + if (format_type_ == "json") { + format_it = format_json.object_value().find("subject_token_field_name"); + if (format_it == format_json.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.subject_token_field_name field must be present if the " + "format is in Json."); + return; + } + if (format_it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.subject_token_field_name field must be a string."); + return; + } + format_subject_token_field_name_ = format_it->second.string_value(); + } + } +} + +void FileExternalAccountCredentials::RetrieveSubjectToken( + HTTPRequestContext* /*ctx*/, const Options& /*options*/, + std::function cb) { + struct SliceWrapper { + ~SliceWrapper() { grpc_slice_unref_internal(slice); } + grpc_slice slice = grpc_empty_slice(); + }; + SliceWrapper content_slice; + // To retrieve the subject token, we read the file every time we make a + // request because it may have changed since the last request. 
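// For reference (illustrative comment, not part of the patch): with a
// credential source such as
//   {"file": "/var/run/secrets/token.json",
//    "format": {"type": "json", "subject_token_field_name": "access_token"}}
// a file body like {"access_token": "eyJhbGciOi..."} yields the value of
// "access_token" as the subject token; if no "format" is configured, the
// entire file body is returned verbatim.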
+ grpc_error_handle error = + grpc_load_file(file_.c_str(), 0, &content_slice.slice); + if (error != GRPC_ERROR_NONE) { + cb("", error); + return; + } + absl::string_view content = StringViewFromSlice(content_slice.slice); + if (format_type_ == "json") { + Json content_json = Json::Parse(content, &error); + if (error != GRPC_ERROR_NONE || content_json.type() != Json::Type::OBJECT) { + cb("", GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "The content of the file is not a valid json object.")); + GRPC_ERROR_UNREF(error); + return; + } + auto content_it = + content_json.object_value().find(format_subject_token_field_name_); + if (content_it == content_json.object_value().end()) { + cb("", GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Subject token field not present.")); + return; + } + if (content_it->second.type() != Json::Type::STRING) { + cb("", GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Subject token field must be a string.")); + return; + } + cb(content_it->second.string_value(), GRPC_ERROR_NONE); + return; + } + cb(std::string(content), GRPC_ERROR_NONE); +} + +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/external/url_external_account_credentials.cc b/src/core/lib/security/credentials/external/url_external_account_credentials.cc new file mode 100644 index 00000000..b48bf71f --- /dev/null +++ b/src/core/lib/security/credentials/external/url_external_account_credentials.cc @@ -0,0 +1,211 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#include + +#include "src/core/lib/security/credentials/external/url_external_account_credentials.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" + +namespace grpc_core { + +RefCountedPtr +UrlExternalAccountCredentials::Create(Options options, + std::vector scopes, + grpc_error_handle* error) { + auto creds = MakeRefCounted( + std::move(options), std::move(scopes), error); + if (*error == GRPC_ERROR_NONE) { + return creds; + } else { + return nullptr; + } +} + +UrlExternalAccountCredentials::UrlExternalAccountCredentials( + Options options, std::vector scopes, grpc_error_handle* error) + : ExternalAccountCredentials(options, std::move(scopes)) { + auto it = options.credential_source.object_value().find("url"); + if (it == options.credential_source.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("url field not present."); + return; + } + if (it->second.type() != Json::Type::STRING) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("url field must be a string."); + return; + } + absl::StatusOr tmp_url = URI::Parse(it->second.string_value()); + if (!tmp_url.ok()) { + *error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Invalid credential source url. 
Error: %s", + tmp_url.status().ToString())); + return; + } + url_ = *tmp_url; + // The url must follow the format of :/// + std::vector v = + absl::StrSplit(it->second.string_value(), absl::MaxSplits('/', 3)); + url_full_path_ = absl::StrCat("/", v[3]); + it = options.credential_source.object_value().find("headers"); + if (it != options.credential_source.object_value().end()) { + if (it->second.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "The JSON value of credential source headers is not an object."); + return; + } + for (auto const& header : it->second.object_value()) { + headers_[header.first] = header.second.string_value(); + } + } + it = options.credential_source.object_value().find("format"); + if (it != options.credential_source.object_value().end()) { + const Json& format_json = it->second; + if (format_json.type() != Json::Type::OBJECT) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "The JSON value of credential source format is not an object."); + return; + } + auto format_it = format_json.object_value().find("type"); + if (format_it == format_json.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.type field not present."); + return; + } + if (format_it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.type field must be a string."); + return; + } + format_type_ = format_it->second.string_value(); + if (format_type_ == "json") { + format_it = format_json.object_value().find("subject_token_field_name"); + if (format_it == format_json.object_value().end()) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.subject_token_field_name field must be present if the " + "format is in Json."); + return; + } + if (format_it->second.type() != Json::Type::STRING) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "format.subject_token_field_name field must be a string."); + return; + } + format_subject_token_field_name_ = format_it->second.string_value(); + } + } +} + +void UrlExternalAccountCredentials::RetrieveSubjectToken( + HTTPRequestContext* ctx, const Options& /*options*/, + std::function cb) { + if (ctx == nullptr) { + FinishRetrieveSubjectToken( + "", + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Missing HTTPRequestContext to start subject token retrieval.")); + return; + } + ctx_ = ctx; + cb_ = cb; + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(url_.authority().c_str()); + request.http.path = gpr_strdup(url_full_path_.c_str()); + grpc_http_header* headers = nullptr; + request.http.hdr_count = headers_.size(); + headers = static_cast( + gpr_malloc(sizeof(grpc_http_header) * request.http.hdr_count)); + int i = 0; + for (auto const& header : headers_) { + headers[i].key = gpr_strdup(header.first.c_str()); + headers[i].value = gpr_strdup(header.second.c_str()); + ++i; + } + request.http.hdrs = headers; + request.handshaker = + url_.scheme() == "https" ? 
&grpc_httpcli_ssl : &grpc_httpcli_plaintext; + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("external_account_credentials"); + grpc_http_response_destroy(&ctx_->response); + ctx_->response = {}; + GRPC_CLOSURE_INIT(&ctx_->closure, OnRetrieveSubjectToken, this, nullptr); + grpc_httpcli_get(ctx_->httpcli_context, ctx_->pollent, resource_quota, + &request, ctx_->deadline, &ctx_->closure, &ctx_->response); + grpc_http_request_destroy(&request.http); +} + +void UrlExternalAccountCredentials::OnRetrieveSubjectToken( + void* arg, grpc_error_handle error) { + UrlExternalAccountCredentials* self = + static_cast(arg); + self->OnRetrieveSubjectTokenInternal(GRPC_ERROR_REF(error)); +} + +void UrlExternalAccountCredentials::OnRetrieveSubjectTokenInternal( + grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + FinishRetrieveSubjectToken("", error); + return; + } + absl::string_view response_body(ctx_->response.body, + ctx_->response.body_length); + if (format_type_ == "json") { + grpc_error_handle error = GRPC_ERROR_NONE; + Json response_json = Json::Parse(response_body, &error); + if (error != GRPC_ERROR_NONE || + response_json.type() != Json::Type::OBJECT) { + FinishRetrieveSubjectToken( + "", GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "The format of response is not a valid json object.")); + return; + } + auto response_it = + response_json.object_value().find(format_subject_token_field_name_); + if (response_it == response_json.object_value().end()) { + FinishRetrieveSubjectToken("", GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Subject token field not present.")); + return; + } + if (response_it->second.type() != Json::Type::STRING) { + FinishRetrieveSubjectToken("", + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Subject token field must be a string.")); + return; + } + FinishRetrieveSubjectToken(response_it->second.string_value(), error); + return; + } + FinishRetrieveSubjectToken(std::string(response_body), GRPC_ERROR_NONE); +} + +void UrlExternalAccountCredentials::FinishRetrieveSubjectToken( + std::string subject_token, grpc_error_handle error) { + // Reset context + ctx_ = nullptr; + // Move object state into local variables. + auto cb = cb_; + cb_ = nullptr; + // Invoke the callback. + if (error != GRPC_ERROR_NONE) { + cb("", error); + } else { + cb(subject_token, GRPC_ERROR_NONE); + } +} + +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc new file mode 100644 index 00000000..57a897f0 --- /dev/null +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -0,0 +1,113 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/credentials/fake/fake_credentials.h" + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/security/security_connector/fake/fake_security_connector.h" + +/* -- Fake transport security credentials. -- */ + +namespace { +class grpc_fake_channel_credentials final : public grpc_channel_credentials { + public: + grpc_fake_channel_credentials() + : grpc_channel_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_channel_credentials() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** /*new_args*/) override { + return grpc_fake_channel_security_connector_create( + this->Ref(), std::move(call_creds), target, args); + } +}; + +class grpc_fake_server_credentials final : public grpc_server_credentials { + public: + grpc_fake_server_credentials() + : grpc_server_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_server_credentials() override = default; + + grpc_core::RefCountedPtr + create_security_connector(const grpc_channel_args* /*args*/) override { + return grpc_fake_server_security_connector_create(this->Ref()); + } +}; +} // namespace + +grpc_channel_credentials* grpc_fake_transport_security_credentials_create() { + return new grpc_fake_channel_credentials(); +} + +grpc_server_credentials* +grpc_fake_transport_security_server_credentials_create() { + return new grpc_fake_server_credentials(); +} + +grpc_arg grpc_fake_transport_expected_targets_arg(char* expected_targets) { + return grpc_channel_arg_string_create( + const_cast(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS), + expected_targets); +} + +const char* grpc_fake_transport_get_expected_targets( + const grpc_channel_args* args) { + const grpc_arg* expected_target_arg = + grpc_channel_args_find(args, GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS); + return grpc_channel_arg_get_string(expected_target_arg); +} + +/* -- Metadata-only test credentials. -- */ + +bool grpc_md_only_test_credentials::get_request_metadata( + grpc_polling_entity* /*pollent*/, grpc_auth_metadata_context /*context*/, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error_handle* /*error*/) { + grpc_credentials_mdelem_array_add(md_array, md_); + if (is_async_) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_request_metadata, + GRPC_ERROR_NONE); + return false; + } + return true; +} + +void grpc_md_only_test_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* /*md_array*/, grpc_error_handle error) { + GRPC_ERROR_UNREF(error); +} + +grpc_call_credentials* grpc_md_only_test_credentials_create( + const char* md_key, const char* md_value, bool is_async) { + return new grpc_md_only_test_credentials(md_key, md_value, is_async); +} diff --git a/src/core/lib/security/credentials/google_default/credentials_generic.cc b/src/core/lib/security/credentials/google_default/credentials_generic.cc new file mode 100644 index 00000000..1b56824d --- /dev/null +++ b/src/core/lib/security/credentials/google_default/credentials_generic.cc @@ -0,0 +1,42 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/security/credentials/google_default/google_default_credentials.h" + +std::string grpc_get_well_known_google_credentials_file_path_impl(void) { + char* base = gpr_getenv(GRPC_GOOGLE_CREDENTIALS_PATH_ENV_VAR); + if (base == nullptr) { + gpr_log(GPR_ERROR, "Could not get " GRPC_GOOGLE_CREDENTIALS_PATH_ENV_VAR + " environment variable."); + return ""; + } + std::string result = + absl::StrCat(base, "/", GRPC_GOOGLE_CREDENTIALS_PATH_SUFFIX); + gpr_free(base); + return result; +} diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.cc b/src/core/lib/security/credentials/google_default/google_default_credentials.cc new file mode 100644 index 00000000..18710c5e --- /dev/null +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.cc @@ -0,0 +1,465 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/google_default/google_default_credentials.h" + +#include + +#include "absl/strings/match.h" +#include "absl/strings/strip.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.h" +#include "src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/http/parser.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/security/credentials/alts/alts_credentials.h" +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/external/external_account_credentials.h" +#include "src/core/lib/security/credentials/jwt/jwt_credentials.h" +#include "src/core/lib/security/credentials/oauth2/oauth2_credentials.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/api_trace.h" + +using grpc_core::Json; + +/* -- Constants. -- */ + +#define GRPC_COMPUTE_ENGINE_DETECTION_HOST "metadata.google.internal." 
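For orientation, google_default_credentials.cc below implements Application Default Credentials: it first tries the key file named by GRPC_GOOGLE_CREDENTIALS_ENV_VAR, then the well-known path composed in credentials_generic.cc above, and finally probes the GCE metadata host defined here. A minimal caller-side sketch of the C-core API under those assumptions; the target address is a placeholder and the snippet is illustrative, not part of this patch.

#include <grpc/grpc.h>
#include <grpc/grpc_security.h>

int main() {
  grpc_init();
  // Passing nullptr lets the library build the default call credentials
  // (service-account JWT / refresh token from a key file, or GCE metadata
  // credentials when running on GCP).
  grpc_channel_credentials* creds =
      grpc_google_default_credentials_create(/*call_credentials=*/nullptr);
  if (creds != nullptr) {
    // Hypothetical target; replace with a real backend address.
    grpc_channel* channel = grpc_secure_channel_create(
        creds, "greeter.example.com:443", /*args=*/nullptr, /*reserved=*/nullptr);
    // ... issue RPCs over `channel` ...
    grpc_channel_destroy(channel);
    grpc_channel_credentials_release(creds);
  }
  grpc_shutdown();
  return 0;
}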
+#define GRPC_GOOGLE_CREDENTIAL_CREATION_ERROR \ + "Failed to create Google credentials" + +/* -- Default credentials. -- */ + +/* A sticky bit that will be set only if the result of metadata server detection + * is positive. We do not set the bit if the result is negative. Because it + * means the detection is done via network test that is unreliable and the + * unreliable result should not be referred by successive calls. */ +static int g_metadata_server_available = 0; +static grpc_core::Mutex* g_state_mu; +/* Protect a metadata_server_detector instance that can be modified by more than + * one gRPC threads */ +static gpr_mu* g_polling_mu; +static gpr_once g_once = GPR_ONCE_INIT; +static grpc_core::internal::grpc_gce_tenancy_checker g_gce_tenancy_checker = + grpc_alts_is_running_on_gcp; + +static void init_default_credentials(void) { + g_state_mu = new grpc_core::Mutex(); +} + +struct metadata_server_detector { + grpc_polling_entity pollent; + int is_done; + int success; + grpc_http_response response; +}; +grpc_core::RefCountedPtr +grpc_google_default_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) { + const bool is_grpclb_load_balancer = grpc_channel_args_find_bool( + args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER, false); + const bool is_backend_from_grpclb_load_balancer = grpc_channel_args_find_bool( + args, GRPC_ARG_ADDRESS_IS_BACKEND_FROM_GRPCLB_LOAD_BALANCER, false); + const char* xds_cluster = + grpc_channel_args_find_string(args, GRPC_ARG_XDS_CLUSTER_NAME); + const bool is_xds_non_cfe_cluster = + xds_cluster != nullptr && !absl::StartsWith(xds_cluster, "google_cfe_"); + const bool use_alts = is_grpclb_load_balancer || + is_backend_from_grpclb_load_balancer || + is_xds_non_cfe_cluster; + /* Return failure if ALTS is selected but not running on GCE. */ + if (use_alts && alts_creds_ == nullptr) { + gpr_log(GPR_ERROR, "ALTS is selected, but not running on GCE."); + return nullptr; + } + grpc_core::RefCountedPtr sc = + use_alts ? alts_creds_->create_security_connector(call_creds, target, + args, new_args) + : ssl_creds_->create_security_connector(call_creds, target, args, + new_args); + /* grpclb-specific channel args are removed from the channel args set + * to ensure backends and fallback adresses will have the same set of channel + * args. By doing that, it guarantees the connections to backends will not be + * torn down and re-connected when switching in and out of fallback mode. 
+ */ + if (use_alts) { + static const char* args_to_remove[] = { + GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER, + GRPC_ARG_ADDRESS_IS_BACKEND_FROM_GRPCLB_LOAD_BALANCER, + }; + *new_args = grpc_channel_args_copy_and_add_and_remove( + args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), nullptr, 0); + } + return sc; +} + +grpc_channel_args* grpc_google_default_channel_credentials::update_arguments( + grpc_channel_args* args) { + grpc_channel_args* updated = args; + if (grpc_channel_args_find(args, GRPC_ARG_DNS_ENABLE_SRV_QUERIES) == + nullptr) { + grpc_arg new_srv_arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_DNS_ENABLE_SRV_QUERIES), true); + updated = grpc_channel_args_copy_and_add(args, &new_srv_arg, 1); + grpc_channel_args_destroy(args); + } + return updated; +} + +static void on_metadata_server_detection_http_response( + void* user_data, grpc_error_handle error) { + metadata_server_detector* detector = + static_cast(user_data); + if (error == GRPC_ERROR_NONE && detector->response.status == 200 && + detector->response.hdr_count > 0) { + /* Internet providers can return a generic response to all requests, so + it is necessary to check that metadata header is present also. */ + size_t i; + for (i = 0; i < detector->response.hdr_count; i++) { + grpc_http_header* header = &detector->response.hdrs[i]; + if (strcmp(header->key, "Metadata-Flavor") == 0 && + strcmp(header->value, "Google") == 0) { + detector->success = 1; + break; + } + } + } + gpr_mu_lock(g_polling_mu); + detector->is_done = 1; + GRPC_LOG_IF_ERROR( + "Pollset kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&detector->pollent), + nullptr)); + gpr_mu_unlock(g_polling_mu); +} + +static void destroy_pollset(void* p, grpc_error_handle /*e*/) { + grpc_pollset_destroy(static_cast(p)); +} + +static int is_metadata_server_reachable() { + metadata_server_detector detector; + grpc_httpcli_request request; + grpc_httpcli_context context; + grpc_closure destroy_closure; + /* The http call is local. If it takes more than one sec, it is for sure not + on compute engine. */ + grpc_millis max_detection_delay = GPR_MS_PER_SEC; + grpc_pollset* pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &g_polling_mu); + detector.pollent = grpc_polling_entity_create_from_pollset(pollset); + detector.is_done = 0; + detector.success = 0; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(GRPC_COMPUTE_ENGINE_DETECTION_HOST); + request.http.path = const_cast("/"); + grpc_httpcli_context_init(&context); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("google_default_credentials"); + grpc_httpcli_get( + &context, &detector.pollent, resource_quota, &request, + grpc_core::ExecCtx::Get()->Now() + max_detection_delay, + GRPC_CLOSURE_CREATE(on_metadata_server_detection_http_response, &detector, + grpc_schedule_on_exec_ctx), + &detector.response); + grpc_core::ExecCtx::Get()->Flush(); + /* Block until we get the response. This is not ideal but this should only be + called once for the lifetime of the process by the default credentials. 
*/ + gpr_mu_lock(g_polling_mu); + while (!detector.is_done) { + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(grpc_polling_entity_pollset(&detector.pollent), + &worker, GRPC_MILLIS_INF_FUTURE))) { + detector.is_done = 1; + detector.success = 0; + } + } + gpr_mu_unlock(g_polling_mu); + grpc_httpcli_context_destroy(&context); + GRPC_CLOSURE_INIT(&destroy_closure, destroy_pollset, + grpc_polling_entity_pollset(&detector.pollent), + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(grpc_polling_entity_pollset(&detector.pollent), + &destroy_closure); + g_polling_mu = nullptr; + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(grpc_polling_entity_pollset(&detector.pollent)); + grpc_http_response_destroy(&detector.response); + return detector.success; +} + +namespace { + +bool ValidateUrlField(const Json& json, const std::string& field) { + auto it = json.object_value().find(field); + if (it == json.object_value().end()) { + return true; + } + if (it->second.type() != Json::Type::STRING || + it->second.string_value().empty()) { + return false; + } + absl::StatusOr url = + grpc_core::URI::Parse(it->second.string_value()); + if (!url.ok()) return false; + if (!absl::EqualsIgnoreCase(url->scheme(), "https")) { + return false; + } + absl::string_view host; + absl::string_view port; + grpc_core::SplitHostPort(url->authority(), &host, &port); + if (absl::ConsumeSuffix(&host, ".googleapis.com")) { + if (host == "sts" || host == "iamcredentials") { + return true; + } else if (absl::StartsWith(host, "sts.") || + absl::StartsWith(host, "iamcredentials.")) { + return true; + } else if (absl::EndsWith(host, ".sts") || + absl::EndsWith(host, ".iamcredentials")) { + return true; + } else if (absl::EndsWith(host, "-sts") || + absl::EndsWith(host, "-iamcredentials")) { + return true; + } + } + return false; +} + +bool ValidateExteralAccountCredentials(const Json& json) { + return json.type() == Json::Type::OBJECT && + ValidateUrlField(json, "token_url") && + ValidateUrlField(json, "service_account_impersonation_url") && + ValidateUrlField(json, "token_info_url"); +} + +} // namespace + +/* Takes ownership of creds_path if not NULL. */ +static grpc_error_handle create_default_creds_from_path( + const std::string& creds_path, + grpc_core::RefCountedPtr* creds) { + grpc_auth_json_key key; + grpc_auth_refresh_token token; + grpc_core::RefCountedPtr result; + grpc_slice creds_data = grpc_empty_slice(); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json; + if (creds_path.empty()) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("creds_path unset"); + goto end; + } + error = grpc_load_file(creds_path.c_str(), 0, &creds_data); + if (error != GRPC_ERROR_NONE) goto end; + json = Json::Parse(grpc_core::StringViewFromSlice(creds_data), &error); + if (error != GRPC_ERROR_NONE) goto end; + if (json.type() != Json::Type::OBJECT) { + error = grpc_error_set_str( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to parse JSON"), + GRPC_ERROR_STR_RAW_BYTES, grpc_core::StringViewFromSlice(creds_data)); + goto end; + } + + /* First, try an auth json key. 
*/ + key = grpc_auth_json_key_create_from_json(json); + if (grpc_auth_json_key_is_valid(&key)) { + result = + grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + key, grpc_max_auth_token_lifetime()); + if (result == nullptr) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "grpc_service_account_jwt_access_credentials_create_from_auth_json_" + "key failed"); + } + goto end; + } + + /* Then try a refresh token if the auth json key was invalid. */ + token = grpc_auth_refresh_token_create_from_json(json); + if (grpc_auth_refresh_token_is_valid(&token)) { + result = + grpc_refresh_token_credentials_create_from_auth_refresh_token(token); + if (result == nullptr) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "grpc_refresh_token_credentials_create_from_auth_refresh_token " + "failed"); + } + goto end; + } + + /* Finally try an external account credentials.*/ + if (!ValidateExteralAccountCredentials(json)) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid external account credentials format."); + goto end; + } + result = grpc_core::ExternalAccountCredentials::Create(json, {}, &error); + +end: + GPR_ASSERT((result == nullptr) + (error == GRPC_ERROR_NONE) == 1); + grpc_slice_unref_internal(creds_data); + *creds = result; + return error; +} + +static void update_tenancy() { + gpr_once_init(&g_once, init_default_credentials); + grpc_core::MutexLock lock(g_state_mu); + + /* Try a platform-provided hint for GCE. */ + if (!g_metadata_server_available) { + g_metadata_server_available = g_gce_tenancy_checker(); + } + /* TODO: Add a platform-provided hint for GAE. */ + + /* Do a network test for metadata server. */ + if (!g_metadata_server_available) { + g_metadata_server_available = is_metadata_server_reachable(); + } +} + +static bool metadata_server_available() { + grpc_core::MutexLock lock(g_state_mu); + return static_cast(g_metadata_server_available); +} + +static grpc_core::RefCountedPtr make_default_call_creds( + grpc_error_handle* error) { + grpc_core::RefCountedPtr call_creds; + grpc_error_handle err; + + /* First, try the environment variable. */ + char* path_from_env = gpr_getenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR); + if (path_from_env != nullptr) { + err = create_default_creds_from_path(path_from_env, &call_creds); + gpr_free(path_from_env); + if (err == GRPC_ERROR_NONE) return call_creds; + *error = grpc_error_add_child(*error, err); + } + + /* Then the well-known file. 
*/ + err = create_default_creds_from_path( + grpc_get_well_known_google_credentials_file_path(), &call_creds); + if (err == GRPC_ERROR_NONE) return call_creds; + *error = grpc_error_add_child(*error, err); + + update_tenancy(); + + if (metadata_server_available()) { + call_creds = grpc_core::RefCountedPtr( + grpc_google_compute_engine_credentials_create(nullptr)); + if (call_creds == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + GRPC_GOOGLE_CREDENTIAL_CREATION_ERROR); + *error = grpc_error_add_child( + *error, GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Failed to get credentials from network")); + } + } + + return call_creds; +} + +grpc_channel_credentials* grpc_google_default_credentials_create( + grpc_call_credentials* call_credentials) { + grpc_channel_credentials* result = nullptr; + grpc_core::RefCountedPtr call_creds(call_credentials); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::ExecCtx exec_ctx; + + GRPC_API_TRACE("grpc_google_default_credentials_create(%p)", 1, + (call_credentials)); + + if (call_creds == nullptr) { + call_creds = make_default_call_creds(&error); + } + + if (call_creds != nullptr) { + /* Create google default credentials. */ + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); + GPR_ASSERT(ssl_creds != nullptr); + grpc_alts_credentials_options* options = + grpc_alts_credentials_client_options_create(); + grpc_channel_credentials* alts_creds = + grpc_alts_credentials_create(options); + grpc_alts_credentials_options_destroy(options); + auto creds = + grpc_core::MakeRefCounted( + grpc_core::RefCountedPtr(alts_creds), + grpc_core::RefCountedPtr(ssl_creds)); + result = grpc_composite_channel_credentials_create( + creds.get(), call_creds.get(), nullptr); + GPR_ASSERT(result != nullptr); + } else { + gpr_log(GPR_ERROR, "Could not create google default credentials: %s", + grpc_error_std_string(error).c_str()); + } + GRPC_ERROR_UNREF(error); + return result; +} + +namespace grpc_core { +namespace internal { + +void set_gce_tenancy_checker_for_testing(grpc_gce_tenancy_checker checker) { + g_gce_tenancy_checker = checker; +} + +void grpc_flush_cached_google_default_credentials(void) { + grpc_core::ExecCtx exec_ctx; + gpr_once_init(&g_once, init_default_credentials); + grpc_core::MutexLock lock(g_state_mu); + g_metadata_server_available = 0; +} + +} // namespace internal +} // namespace grpc_core + +/* -- Well known credentials path. -- */ + +static grpc_well_known_credentials_path_getter creds_path_getter = nullptr; + +std::string grpc_get_well_known_google_credentials_file_path(void) { + if (creds_path_getter != nullptr) return creds_path_getter(); + return grpc_get_well_known_google_credentials_file_path_impl(); +} + +void grpc_override_well_known_credentials_path_getter( + grpc_well_known_credentials_path_getter getter) { + creds_path_getter = getter; +} diff --git a/src/core/lib/security/credentials/iam/iam_credentials.cc b/src/core/lib/security/credentials/iam/iam_credentials.cc new file mode 100644 index 00000000..36e7f11f --- /dev/null +++ b/src/core/lib/security/credentials/iam/iam_credentials.cc @@ -0,0 +1,81 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/iam/iam_credentials.h" + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/surface/api_trace.h" + +grpc_google_iam_credentials::~grpc_google_iam_credentials() { + grpc_credentials_mdelem_array_destroy(&md_array_); +} + +bool grpc_google_iam_credentials::get_request_metadata( + grpc_polling_entity* /*pollent*/, grpc_auth_metadata_context /*context*/, + grpc_credentials_mdelem_array* md_array, + grpc_closure* /*on_request_metadata*/, grpc_error_handle* /*error*/) { + grpc_credentials_mdelem_array_append(md_array, &md_array_); + return true; +} + +void grpc_google_iam_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* /*md_array*/, grpc_error_handle error) { + GRPC_ERROR_UNREF(error); +} + +grpc_google_iam_credentials::grpc_google_iam_credentials( + const char* token, const char* authority_selector) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_IAM), + debug_string_(absl::StrFormat( + "GoogleIAMCredentials{Token:%s,AuthoritySelector:%s}", + token != nullptr ? "present" : "absent", authority_selector)) { + grpc_mdelem md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY), + grpc_slice_from_copied_string(token)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); + md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY), + grpc_slice_from_copied_string(authority_selector)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); +} + +grpc_call_credentials* grpc_google_iam_credentials_create( + const char* token, const char* authority_selector, void* reserved) { + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE( + "grpc_iam_credentials_create(token=%s, authority_selector=%s, " + "reserved=%p)", + 3, (token, authority_selector, reserved)); + GPR_ASSERT(reserved == nullptr); + GPR_ASSERT(token != nullptr); + GPR_ASSERT(authority_selector != nullptr); + return grpc_core::MakeRefCounted( + token, authority_selector) + .release(); +} diff --git a/src/core/lib/security/credentials/insecure/insecure_credentials.cc b/src/core/lib/security/credentials/insecure/insecure_credentials.cc new file mode 100644 index 00000000..4cf500e6 --- /dev/null +++ b/src/core/lib/security/credentials/insecure/insecure_credentials.cc @@ -0,0 +1,64 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include + +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/insecure/insecure_security_connector.h" + +namespace grpc_core { +namespace { + +constexpr char kCredentialsTypeInsecure[] = "insecure"; + +class InsecureCredentials final : public grpc_channel_credentials { + public: + InsecureCredentials() : grpc_channel_credentials(kCredentialsTypeInsecure) {} + + RefCountedPtr create_security_connector( + RefCountedPtr call_creds, + const char* /* target_name */, const grpc_channel_args* /* args */, + grpc_channel_args** /* new_args */) override { + return MakeRefCounted( + Ref(), std::move(call_creds)); + } +}; + +class InsecureServerCredentials final : public grpc_server_credentials { + public: + InsecureServerCredentials() + : grpc_server_credentials(kCredentialsTypeInsecure) {} + + RefCountedPtr create_security_connector( + const grpc_channel_args* /* args */) override { + return MakeRefCounted(Ref()); + } +}; + +} // namespace +} // namespace grpc_core + +grpc_channel_credentials* grpc_insecure_credentials_create() { + return new grpc_core::InsecureCredentials(); +} + +grpc_server_credentials* grpc_insecure_server_credentials_create() { + return new grpc_core::InsecureServerCredentials(); +} diff --git a/src/core/lib/security/credentials/jwt/json_token.cc b/src/core/lib/security/credentials/jwt/json_token.cc new file mode 100644 index 00000000..f44ef0f6 --- /dev/null +++ b/src/core/lib/security/credentials/jwt/json_token.cc @@ -0,0 +1,288 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/jwt/json_token.h" + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/security/util/json_util.h" +#include "src/core/lib/slice/b64.h" + +extern "C" { +#include +#include +#include +} + +using grpc_core::Json; + +/* --- Constants. --- */ + +/* 1 hour max. */ +gpr_timespec grpc_max_auth_token_lifetime() { + gpr_timespec out; + out.tv_sec = 3600; + out.tv_nsec = 0; + out.clock_type = GPR_TIMESPAN; + return out; +} + +#define GRPC_JWT_RSA_SHA256_ALGORITHM "RS256" +#define GRPC_JWT_TYPE "JWT" + +/* --- Override for testing. --- */ + +static grpc_jwt_encode_and_sign_override g_jwt_encode_and_sign_override = + nullptr; + +/* --- grpc_auth_json_key. 
--- */ + +int grpc_auth_json_key_is_valid(const grpc_auth_json_key* json_key) { + return (json_key != nullptr) && + strcmp(json_key->type, GRPC_AUTH_JSON_TYPE_INVALID) != 0; +} + +grpc_auth_json_key grpc_auth_json_key_create_from_json(const Json& json) { + grpc_auth_json_key result; + BIO* bio = nullptr; + const char* prop_value; + int success = 0; + grpc_error_handle error = GRPC_ERROR_NONE; + + memset(&result, 0, sizeof(grpc_auth_json_key)); + result.type = GRPC_AUTH_JSON_TYPE_INVALID; + if (json.type() == Json::Type::JSON_NULL) { + gpr_log(GPR_ERROR, "Invalid json."); + goto end; + } + + prop_value = grpc_json_get_string_property(json, "type", &error); + GRPC_LOG_IF_ERROR("JSON key parsing", error); + if (prop_value == nullptr || + strcmp(prop_value, GRPC_AUTH_JSON_TYPE_SERVICE_ACCOUNT) != 0) { + goto end; + } + result.type = GRPC_AUTH_JSON_TYPE_SERVICE_ACCOUNT; + + if (!grpc_copy_json_string_property(json, "private_key_id", + &result.private_key_id) || + !grpc_copy_json_string_property(json, "client_id", &result.client_id) || + !grpc_copy_json_string_property(json, "client_email", + &result.client_email)) { + goto end; + } + + prop_value = grpc_json_get_string_property(json, "private_key", &error); + GRPC_LOG_IF_ERROR("JSON key parsing", error); + if (prop_value == nullptr) { + goto end; + } + bio = BIO_new(BIO_s_mem()); + success = BIO_puts(bio, prop_value); + if ((success < 0) || (static_cast(success) != strlen(prop_value))) { + gpr_log(GPR_ERROR, "Could not write into openssl BIO."); + goto end; + } + result.private_key = + PEM_read_bio_RSAPrivateKey(bio, nullptr, nullptr, const_cast("")); + if (result.private_key == nullptr) { + gpr_log(GPR_ERROR, "Could not deserialize private key."); + goto end; + } + success = 1; + +end: + if (bio != nullptr) BIO_free(bio); + if (!success) grpc_auth_json_key_destruct(&result); + return result; +} + +grpc_auth_json_key grpc_auth_json_key_create_from_string( + const char* json_string) { + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_string, &error); + GRPC_LOG_IF_ERROR("JSON key parsing", error); + return grpc_auth_json_key_create_from_json(json); +} + +void grpc_auth_json_key_destruct(grpc_auth_json_key* json_key) { + if (json_key == nullptr) return; + json_key->type = GRPC_AUTH_JSON_TYPE_INVALID; + if (json_key->client_id != nullptr) { + gpr_free(json_key->client_id); + json_key->client_id = nullptr; + } + if (json_key->private_key_id != nullptr) { + gpr_free(json_key->private_key_id); + json_key->private_key_id = nullptr; + } + if (json_key->client_email != nullptr) { + gpr_free(json_key->client_email); + json_key->client_email = nullptr; + } + if (json_key->private_key != nullptr) { + RSA_free(json_key->private_key); + json_key->private_key = nullptr; + } +} + +/* --- jwt encoding and signature. 
--- */ + +static char* encoded_jwt_header(const char* key_id, const char* algorithm) { + Json json = Json::Object{ + {"alg", algorithm}, + {"typ", GRPC_JWT_TYPE}, + {"kid", key_id}, + }; + std::string json_str = json.Dump(); + return grpc_base64_encode(json_str.c_str(), json_str.size(), 1, 0); +} + +static char* encoded_jwt_claim(const grpc_auth_json_key* json_key, + const char* audience, + gpr_timespec token_lifetime, const char* scope) { + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + gpr_timespec expiration = gpr_time_add(now, token_lifetime); + if (gpr_time_cmp(token_lifetime, grpc_max_auth_token_lifetime()) > 0) { + gpr_log(GPR_INFO, "Cropping token lifetime to maximum allowed value."); + expiration = gpr_time_add(now, grpc_max_auth_token_lifetime()); + } + + Json::Object object = { + {"iss", json_key->client_email}, + {"aud", audience}, + {"iat", now.tv_sec}, + {"exp", expiration.tv_sec}, + }; + if (scope != nullptr) { + object["scope"] = scope; + } else { + /* Unscoped JWTs need a sub field. */ + object["sub"] = json_key->client_email; + } + + Json json(object); + std::string json_str = json.Dump(); + return grpc_base64_encode(json_str.c_str(), json_str.size(), 1, 0); +} + +static char* dot_concat_and_free_strings(char* str1, char* str2) { + size_t str1_len = strlen(str1); + size_t str2_len = strlen(str2); + size_t result_len = str1_len + 1 /* dot */ + str2_len; + char* result = + static_cast(gpr_malloc(result_len + 1 /* NULL terminated */)); + char* current = result; + memcpy(current, str1, str1_len); + current += str1_len; + *(current++) = '.'; + memcpy(current, str2, str2_len); + current += str2_len; + GPR_ASSERT(current >= result); + GPR_ASSERT((uintptr_t)(current - result) == result_len); + *current = '\0'; + gpr_free(str1); + gpr_free(str2); + return result; +} + +const EVP_MD* openssl_digest_from_algorithm(const char* algorithm) { + if (strcmp(algorithm, GRPC_JWT_RSA_SHA256_ALGORITHM) == 0) { + return EVP_sha256(); + } else { + gpr_log(GPR_ERROR, "Unknown algorithm %s.", algorithm); + return nullptr; + } +} + +char* compute_and_encode_signature(const grpc_auth_json_key* json_key, + const char* signature_algorithm, + const char* to_sign) { + const EVP_MD* md = openssl_digest_from_algorithm(signature_algorithm); + EVP_MD_CTX* md_ctx = nullptr; + EVP_PKEY* key = EVP_PKEY_new(); + size_t sig_len = 0; + unsigned char* sig = nullptr; + char* result = nullptr; + if (md == nullptr) return nullptr; + md_ctx = EVP_MD_CTX_create(); + if (md_ctx == nullptr) { + gpr_log(GPR_ERROR, "Could not create MD_CTX"); + goto end; + } + EVP_PKEY_set1_RSA(key, json_key->private_key); + if (EVP_DigestSignInit(md_ctx, nullptr, md, nullptr, key) != 1) { + gpr_log(GPR_ERROR, "DigestInit failed."); + goto end; + } + if (EVP_DigestSignUpdate(md_ctx, to_sign, strlen(to_sign)) != 1) { + gpr_log(GPR_ERROR, "DigestUpdate failed."); + goto end; + } + if (EVP_DigestSignFinal(md_ctx, nullptr, &sig_len) != 1) { + gpr_log(GPR_ERROR, "DigestFinal (get signature length) failed."); + goto end; + } + sig = static_cast(gpr_malloc(sig_len)); + if (EVP_DigestSignFinal(md_ctx, sig, &sig_len) != 1) { + gpr_log(GPR_ERROR, "DigestFinal (signature compute) failed."); + goto end; + } + result = grpc_base64_encode(sig, sig_len, 1, 0); + +end: + if (key != nullptr) EVP_PKEY_free(key); + if (md_ctx != nullptr) EVP_MD_CTX_destroy(md_ctx); + if (sig != nullptr) gpr_free(sig); + return result; +} + +char* grpc_jwt_encode_and_sign(const grpc_auth_json_key* json_key, + const char* audience, + gpr_timespec token_lifetime, const 
char* scope) { + if (g_jwt_encode_and_sign_override != nullptr) { + return g_jwt_encode_and_sign_override(json_key, audience, token_lifetime, + scope); + } else { + const char* sig_algo = GRPC_JWT_RSA_SHA256_ALGORITHM; + char* to_sign = dot_concat_and_free_strings( + encoded_jwt_header(json_key->private_key_id, sig_algo), + encoded_jwt_claim(json_key, audience, token_lifetime, scope)); + char* sig = compute_and_encode_signature(json_key, sig_algo, to_sign); + if (sig == nullptr) { + gpr_free(to_sign); + return nullptr; + } + return dot_concat_and_free_strings(to_sign, sig); + } +} + +void grpc_jwt_encode_and_sign_set_override( + grpc_jwt_encode_and_sign_override func) { + g_jwt_encode_and_sign_override = func; +} diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.cc b/src/core/lib/security/credentials/jwt/jwt_credentials.cc new file mode 100644 index 00000000..f9132578 --- /dev/null +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/jwt/jwt_credentials.h" + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/uri/uri_parser.h" + +using grpc_core::Json; + +void grpc_service_account_jwt_access_credentials::reset_cache() { + GRPC_MDELEM_UNREF(cached_.jwt_md); + cached_.jwt_md = GRPC_MDNULL; + cached_.service_url.clear(); + cached_.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); +} + +grpc_service_account_jwt_access_credentials:: + ~grpc_service_account_jwt_access_credentials() { + grpc_auth_json_key_destruct(&key_); + reset_cache(); + gpr_mu_destroy(&cache_mu_); +} + +bool grpc_service_account_jwt_access_credentials::get_request_metadata( + grpc_polling_entity* /*pollent*/, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* /*on_request_metadata*/, grpc_error_handle* error) { + gpr_timespec refresh_threshold = gpr_time_from_seconds( + GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN); + + // Remove service name from service_url to follow the audience format + // dictated in https://google.aip.dev/auth/4111. + absl::StatusOr uri = + grpc_core::RemoveServiceNameFromJwtUri(context.service_url); + if (!uri.ok()) { + *error = absl_status_to_grpc_error(uri.status()); + return true; + } + /* See if we can return a cached jwt. 
*/ + grpc_mdelem jwt_md = GRPC_MDNULL; + { + gpr_mu_lock(&cache_mu_); + if (!cached_.service_url.empty() && cached_.service_url == *uri && + !GRPC_MDISNULL(cached_.jwt_md) && + (gpr_time_cmp( + gpr_time_sub(cached_.jwt_expiration, gpr_now(GPR_CLOCK_REALTIME)), + refresh_threshold) > 0)) { + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); + } + gpr_mu_unlock(&cache_mu_); + } + + if (GRPC_MDISNULL(jwt_md)) { + char* jwt = nullptr; + /* Generate a new jwt. */ + gpr_mu_lock(&cache_mu_); + reset_cache(); + jwt = grpc_jwt_encode_and_sign(&key_, uri->c_str(), jwt_lifetime_, nullptr); + if (jwt != nullptr) { + std::string md_value = absl::StrCat("Bearer ", jwt); + gpr_free(jwt); + cached_.jwt_expiration = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), jwt_lifetime_); + cached_.service_url = std::move(*uri); + cached_.jwt_md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), + grpc_slice_from_cpp_string(std::move(md_value))); + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); + } + gpr_mu_unlock(&cache_mu_); + } + + if (!GRPC_MDISNULL(jwt_md)) { + grpc_credentials_mdelem_array_add(md_array, jwt_md); + GRPC_MDELEM_UNREF(jwt_md); + } else { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Could not generate JWT."); + } + return true; +} + +void grpc_service_account_jwt_access_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* /*md_array*/, grpc_error_handle error) { + GRPC_ERROR_UNREF(error); +} + +grpc_service_account_jwt_access_credentials:: + grpc_service_account_jwt_access_credentials(grpc_auth_json_key key, + gpr_timespec token_lifetime) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_JWT), key_(key) { + gpr_timespec max_token_lifetime = grpc_max_auth_token_lifetime(); + if (gpr_time_cmp(token_lifetime, max_token_lifetime) > 0) { + gpr_log(GPR_INFO, + "Cropping token lifetime to maximum allowed value (%d secs).", + static_cast(max_token_lifetime.tv_sec)); + token_lifetime = grpc_max_auth_token_lifetime(); + } + jwt_lifetime_ = token_lifetime; + gpr_mu_init(&cache_mu_); + reset_cache(); +} + +grpc_core::RefCountedPtr +grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key key, gpr_timespec token_lifetime) { + if (!grpc_auth_json_key_is_valid(&key)) { + gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation"); + return nullptr; + } + return grpc_core::MakeRefCounted( + key, token_lifetime); +} + +static char* redact_private_key(const char* json_key) { + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_key, &error); + if (error != GRPC_ERROR_NONE || json.type() != Json::Type::OBJECT) { + GRPC_ERROR_UNREF(error); + return gpr_strdup(""); + } + (*json.mutable_object())["private_key"] = ""; + return gpr_strdup(json.Dump(/*indent=*/2).c_str()); +} + +grpc_call_credentials* grpc_service_account_jwt_access_credentials_create( + const char* json_key, gpr_timespec token_lifetime, void* reserved) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_api_trace)) { + char* clean_json = redact_private_key(json_key); + gpr_log(GPR_INFO, + "grpc_service_account_jwt_access_credentials_create(" + "json_key=%s, " + "token_lifetime=" + "gpr_timespec { tv_sec: %" PRId64 + ", tv_nsec: %d, clock_type: %d }, " + "reserved=%p)", + clean_json, token_lifetime.tv_sec, token_lifetime.tv_nsec, + static_cast(token_lifetime.clock_type), reserved); + gpr_free(clean_json); + } + GPR_ASSERT(reserved == nullptr); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + return 
grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key_create_from_string(json_key), token_lifetime) + .release(); +} + +namespace grpc_core { + +absl::StatusOr RemoveServiceNameFromJwtUri(absl::string_view uri) { + auto parsed = grpc_core::URI::Parse(uri); + if (!parsed.ok()) { + return parsed.status(); + } + return absl::StrFormat("%s://%s/", parsed->scheme(), parsed->authority()); +} + +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/jwt/jwt_verifier.cc b/src/core/lib/security/credentials/jwt/jwt_verifier.cc new file mode 100644 index 00000000..010ef784 --- /dev/null +++ b/src/core/lib/security/credentials/jwt/jwt_verifier.cc @@ -0,0 +1,922 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/jwt/jwt_verifier.h" + +#include +#include + +#include +#include +#include +#include + +extern "C" { +#include +#include +#include +} + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/slice/b64.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/ssl_types.h" + +using grpc_core::Json; + +/* --- Utils. 
--- */ + +const char* grpc_jwt_verifier_status_to_string( + grpc_jwt_verifier_status status) { + switch (status) { + case GRPC_JWT_VERIFIER_OK: + return "OK"; + case GRPC_JWT_VERIFIER_BAD_SIGNATURE: + return "BAD_SIGNATURE"; + case GRPC_JWT_VERIFIER_BAD_FORMAT: + return "BAD_FORMAT"; + case GRPC_JWT_VERIFIER_BAD_AUDIENCE: + return "BAD_AUDIENCE"; + case GRPC_JWT_VERIFIER_KEY_RETRIEVAL_ERROR: + return "KEY_RETRIEVAL_ERROR"; + case GRPC_JWT_VERIFIER_TIME_CONSTRAINT_FAILURE: + return "TIME_CONSTRAINT_FAILURE"; + case GRPC_JWT_VERIFIER_GENERIC_ERROR: + return "GENERIC_ERROR"; + default: + return "UNKNOWN"; + } +} + +static const EVP_MD* evp_md_from_alg(const char* alg) { + if (strcmp(alg, "RS256") == 0) { + return EVP_sha256(); + } else if (strcmp(alg, "RS384") == 0) { + return EVP_sha384(); + } else if (strcmp(alg, "RS512") == 0) { + return EVP_sha512(); + } else { + return nullptr; + } +} + +static Json parse_json_part_from_jwt(const char* str, size_t len) { + grpc_slice slice = grpc_base64_decode_with_len(str, len, 1); + if (GRPC_SLICE_IS_EMPTY(slice)) { + gpr_log(GPR_ERROR, "Invalid base64."); + return Json(); // JSON null + } + absl::string_view string = grpc_core::StringViewFromSlice(slice); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(string, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parse error: %s", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + json = Json(); // JSON null + } + grpc_slice_unref_internal(slice); + return json; +} + +static const char* validate_string_field(const Json& json, const char* key) { + if (json.type() != Json::Type::STRING) { + gpr_log(GPR_ERROR, "Invalid %s field", key); + return nullptr; + } + return json.string_value().c_str(); +} + +static gpr_timespec validate_time_field(const Json& json, const char* key) { + gpr_timespec result = gpr_time_0(GPR_CLOCK_REALTIME); + if (json.type() != Json::Type::NUMBER) { + gpr_log(GPR_ERROR, "Invalid %s field", key); + return result; + } + result.tv_sec = strtol(json.string_value().c_str(), nullptr, 10); + return result; +} + +/* --- JOSE header. see http://tools.ietf.org/html/rfc7515#section-4 --- */ + +struct jose_header { + const char* alg; + const char* kid; + const char* typ; + /* TODO(jboeuf): Add others as needed (jku, jwk, x5u, x5c and so on...). */ + grpc_core::ManualConstructor json; +}; +static void jose_header_destroy(jose_header* h) { + h->json.Destroy(); + gpr_free(h); +} + +static jose_header* jose_header_from_json(Json json) { + const char* alg_value; + Json::Object::const_iterator it; + jose_header* h = grpc_core::Zalloc(); + if (json.type() != Json::Type::OBJECT) { + gpr_log(GPR_ERROR, "JSON value is not an object"); + goto error; + } + // Check alg field. + it = json.object_value().find("alg"); + if (it == json.object_value().end()) { + gpr_log(GPR_ERROR, "Missing alg field."); + goto error; + } + /* We only support RSA-1.5 signatures for now. + Beware of this if we add HMAC support: + https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/ + */ + alg_value = it->second.string_value().c_str(); + if (it->second.type() != Json::Type::STRING || + strncmp(alg_value, "RS", 2) != 0 || + evp_md_from_alg(alg_value) == nullptr) { + gpr_log(GPR_ERROR, "Invalid alg field"); + goto error; + } + h->alg = alg_value; + // Check typ field. 
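  // For reference, a header this parser accepts looks like (illustrative
  // values only, not taken from this patch):
  //   {"alg": "RS256", "typ": "JWT", "kid": "key-1"}
  // "alg" is mandatory and restricted to the RS256/RS384/RS512 algorithms
  // handled above; "typ" and "kid" are optional and are merely validated as
  // strings when present.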
+ it = json.object_value().find("typ"); + if (it != json.object_value().end()) { + h->typ = validate_string_field(it->second, "typ"); + if (h->typ == nullptr) goto error; + } + // Check kid field. + it = json.object_value().find("kid"); + if (it != json.object_value().end()) { + h->kid = validate_string_field(it->second, "kid"); + if (h->kid == nullptr) goto error; + } + h->json.Init(std::move(json)); + return h; + +error: + jose_header_destroy(h); + return nullptr; +} + +/* --- JWT claims. see http://tools.ietf.org/html/rfc7519#section-4.1 */ + +struct grpc_jwt_claims { + /* Well known properties already parsed. */ + const char* sub; + const char* iss; + const char* aud; + const char* jti; + gpr_timespec iat; + gpr_timespec exp; + gpr_timespec nbf; + + grpc_core::ManualConstructor json; +}; + +void grpc_jwt_claims_destroy(grpc_jwt_claims* claims) { + claims->json.Destroy(); + gpr_free(claims); +} + +const Json* grpc_jwt_claims_json(const grpc_jwt_claims* claims) { + if (claims == nullptr) return nullptr; + return claims->json.get(); +} + +const char* grpc_jwt_claims_subject(const grpc_jwt_claims* claims) { + if (claims == nullptr) return nullptr; + return claims->sub; +} + +const char* grpc_jwt_claims_issuer(const grpc_jwt_claims* claims) { + if (claims == nullptr) return nullptr; + return claims->iss; +} + +const char* grpc_jwt_claims_id(const grpc_jwt_claims* claims) { + if (claims == nullptr) return nullptr; + return claims->jti; +} + +const char* grpc_jwt_claims_audience(const grpc_jwt_claims* claims) { + if (claims == nullptr) return nullptr; + return claims->aud; +} + +gpr_timespec grpc_jwt_claims_issued_at(const grpc_jwt_claims* claims) { + if (claims == nullptr) return gpr_inf_past(GPR_CLOCK_REALTIME); + return claims->iat; +} + +gpr_timespec grpc_jwt_claims_expires_at(const grpc_jwt_claims* claims) { + if (claims == nullptr) return gpr_inf_future(GPR_CLOCK_REALTIME); + return claims->exp; +} + +gpr_timespec grpc_jwt_claims_not_before(const grpc_jwt_claims* claims) { + if (claims == nullptr) return gpr_inf_past(GPR_CLOCK_REALTIME); + return claims->nbf; +} + +grpc_jwt_claims* grpc_jwt_claims_from_json(Json json) { + grpc_jwt_claims* claims = grpc_core::Zalloc(); + claims->json.Init(std::move(json)); + claims->iat = gpr_inf_past(GPR_CLOCK_REALTIME); + claims->nbf = gpr_inf_past(GPR_CLOCK_REALTIME); + claims->exp = gpr_inf_future(GPR_CLOCK_REALTIME); + + /* Per the spec, all fields are optional. 
*/ + for (const auto& p : claims->json->object_value()) { + if (p.first == "sub") { + claims->sub = validate_string_field(p.second, "sub"); + if (claims->sub == nullptr) goto error; + } else if (p.first == "iss") { + claims->iss = validate_string_field(p.second, "iss"); + if (claims->iss == nullptr) goto error; + } else if (p.first == "aud") { + claims->aud = validate_string_field(p.second, "aud"); + if (claims->aud == nullptr) goto error; + } else if (p.first == "jti") { + claims->jti = validate_string_field(p.second, "jti"); + if (claims->jti == nullptr) goto error; + } else if (p.first == "iat") { + claims->iat = validate_time_field(p.second, "iat"); + if (gpr_time_cmp(claims->iat, gpr_time_0(GPR_CLOCK_REALTIME)) == 0) { + goto error; + } + } else if (p.first == "exp") { + claims->exp = validate_time_field(p.second, "exp"); + if (gpr_time_cmp(claims->exp, gpr_time_0(GPR_CLOCK_REALTIME)) == 0) { + goto error; + } + } else if (p.first == "nbf") { + claims->nbf = validate_time_field(p.second, "nbf"); + if (gpr_time_cmp(claims->nbf, gpr_time_0(GPR_CLOCK_REALTIME)) == 0) { + goto error; + } + } + } + return claims; + +error: + grpc_jwt_claims_destroy(claims); + return nullptr; +} + +grpc_jwt_verifier_status grpc_jwt_claims_check(const grpc_jwt_claims* claims, + const char* audience) { + gpr_timespec skewed_now; + int audience_ok; + + GPR_ASSERT(claims != nullptr); + + skewed_now = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), grpc_jwt_verifier_clock_skew); + if (gpr_time_cmp(skewed_now, claims->nbf) < 0) { + gpr_log(GPR_ERROR, "JWT is not valid yet."); + return GRPC_JWT_VERIFIER_TIME_CONSTRAINT_FAILURE; + } + skewed_now = + gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), grpc_jwt_verifier_clock_skew); + if (gpr_time_cmp(skewed_now, claims->exp) > 0) { + gpr_log(GPR_ERROR, "JWT is expired."); + return GRPC_JWT_VERIFIER_TIME_CONSTRAINT_FAILURE; + } + + /* This should be probably up to the upper layer to decide but let's harcode + the 99% use case here for email issuers, where the JWT must be self + issued. */ + if (grpc_jwt_issuer_email_domain(claims->iss) != nullptr && + claims->sub != nullptr && strcmp(claims->iss, claims->sub) != 0) { + gpr_log(GPR_ERROR, + "Email issuer (%s) cannot assert another subject (%s) than itself.", + claims->iss, claims->sub); + return GRPC_JWT_VERIFIER_BAD_SUBJECT; + } + + if (audience == nullptr) { + audience_ok = claims->aud == nullptr; + } else { + audience_ok = claims->aud != nullptr && strcmp(audience, claims->aud) == 0; + } + if (!audience_ok) { + gpr_log(GPR_ERROR, "Audience mismatch: expected %s and found %s.", + audience == nullptr ? "NULL" : audience, + claims->aud == nullptr ? "NULL" : claims->aud); + return GRPC_JWT_VERIFIER_BAD_AUDIENCE; + } + return GRPC_JWT_VERIFIER_OK; +} + +/* --- verifier_cb_ctx object. --- */ + +typedef enum { + HTTP_RESPONSE_OPENID = 0, + HTTP_RESPONSE_KEYS, + HTTP_RESPONSE_COUNT /* must be last */ +} http_response_index; + +struct verifier_cb_ctx { + grpc_jwt_verifier* verifier; + grpc_polling_entity pollent; + jose_header* header; + grpc_jwt_claims* claims; + char* audience; + grpc_slice signature; + grpc_slice signed_data; + void* user_data; + grpc_jwt_verification_done_cb user_cb; + grpc_http_response responses[HTTP_RESPONSE_COUNT]; +}; +/* Takes ownership of the header, claims and signature. 
*/ +static verifier_cb_ctx* verifier_cb_ctx_create( + grpc_jwt_verifier* verifier, grpc_pollset* pollset, jose_header* header, + grpc_jwt_claims* claims, const char* audience, const grpc_slice& signature, + const char* signed_jwt, size_t signed_jwt_len, void* user_data, + grpc_jwt_verification_done_cb cb) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + verifier_cb_ctx* ctx = new verifier_cb_ctx(); + ctx->verifier = verifier; + ctx->pollent = grpc_polling_entity_create_from_pollset(pollset); + ctx->header = header; + ctx->audience = gpr_strdup(audience); + ctx->claims = claims; + ctx->signature = signature; + ctx->signed_data = grpc_slice_from_copied_buffer(signed_jwt, signed_jwt_len); + ctx->user_data = user_data; + ctx->user_cb = cb; + return ctx; +} + +void verifier_cb_ctx_destroy(verifier_cb_ctx* ctx) { + if (ctx->audience != nullptr) gpr_free(ctx->audience); + if (ctx->claims != nullptr) grpc_jwt_claims_destroy(ctx->claims); + grpc_slice_unref_internal(ctx->signature); + grpc_slice_unref_internal(ctx->signed_data); + jose_header_destroy(ctx->header); + for (size_t i = 0; i < HTTP_RESPONSE_COUNT; i++) { + grpc_http_response_destroy(&ctx->responses[i]); + } + /* TODO: see what to do with claims... */ + delete ctx; +} + +/* --- grpc_jwt_verifier object. --- */ + +/* Clock skew defaults to one minute. */ +gpr_timespec grpc_jwt_verifier_clock_skew = {60, 0, GPR_TIMESPAN}; + +/* Max delay defaults to one minute. */ +grpc_millis grpc_jwt_verifier_max_delay = 60 * GPR_MS_PER_SEC; + +struct email_key_mapping { + char* email_domain; + char* key_url_prefix; +}; +struct grpc_jwt_verifier { + email_key_mapping* mappings; + size_t num_mappings; /* Should be very few, linear search ok. */ + size_t allocated_mappings; + grpc_httpcli_context http_ctx; +}; + +static Json json_from_http(const grpc_httpcli_response* response) { + if (response == nullptr) { + gpr_log(GPR_ERROR, "HTTP response is NULL."); + return Json(); // JSON null + } + if (response->status != 200) { + gpr_log(GPR_ERROR, "Call to http server failed with error %d.", + response->status); + return Json(); // JSON null + } + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse( + absl::string_view(response->body, response->body_length), &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Invalid JSON found in response."); + return Json(); // JSON null + } + return json; +} + +static const Json* find_property_by_name(const Json& json, const char* name) { + auto it = json.object_value().find(name); + if (it == json.object_value().end()) { + return nullptr; + } + return &it->second; +} + +static EVP_PKEY* extract_pkey_from_x509(const char* x509_str) { + X509* x509 = nullptr; + EVP_PKEY* result = nullptr; + BIO* bio = BIO_new(BIO_s_mem()); + size_t len = strlen(x509_str); + GPR_ASSERT(len < INT_MAX); + BIO_write(bio, x509_str, static_cast(len)); + x509 = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); + if (x509 == nullptr) { + gpr_log(GPR_ERROR, "Unable to parse x509 cert."); + goto end; + } + result = X509_get_pubkey(x509); + if (result == nullptr) { + gpr_log(GPR_ERROR, "Cannot find public key in X509 cert."); + } + +end: + BIO_free(bio); + X509_free(x509); + return result; +} + +static BIGNUM* bignum_from_base64(const char* b64) { + BIGNUM* result = nullptr; + grpc_slice bin; + + if (b64 == nullptr) return nullptr; + bin = grpc_base64_decode(b64, 1); + if (GRPC_SLICE_IS_EMPTY(bin)) { + gpr_log(GPR_ERROR, "Invalid base64 for big num."); + return nullptr; + } + 
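  /* Per RFC 7518 section 6.3.1, the JWK "n" and "e" fields decoded here are
     base64url encodings of big-endian integers, so the raw decoded bytes can
     be handed directly to BN_bin2bn to build the OpenSSL BIGNUM. */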
result = BN_bin2bn(GRPC_SLICE_START_PTR(bin), + TSI_SIZE_AS_SIZE(GRPC_SLICE_LENGTH(bin)), nullptr); + grpc_slice_unref_internal(bin); + return result; +} + +#if OPENSSL_VERSION_NUMBER < 0x10100000L + +// Provide compatibility across OpenSSL 1.02 and 1.1. +static int RSA_set0_key(RSA* r, BIGNUM* n, BIGNUM* e, BIGNUM* d) { + /* If the fields n and e in r are NULL, the corresponding input + * parameters MUST be non-NULL for n and e. d may be + * left NULL (in case only the public key is used). + */ + if ((r->n == nullptr && n == nullptr) || (r->e == nullptr && e == nullptr)) { + return 0; + } + + if (n != nullptr) { + BN_free(r->n); + r->n = n; + } + if (e != nullptr) { + BN_free(r->e); + r->e = e; + } + if (d != nullptr) { + BN_free(r->d); + r->d = d; + } + + return 1; +} +#endif // OPENSSL_VERSION_NUMBER < 0x10100000L + +static EVP_PKEY* pkey_from_jwk(const Json& json, const char* kty) { + RSA* rsa = nullptr; + EVP_PKEY* result = nullptr; + BIGNUM* tmp_n = nullptr; + BIGNUM* tmp_e = nullptr; + Json::Object::const_iterator it; + + GPR_ASSERT(json.type() == Json::Type::OBJECT); + GPR_ASSERT(kty != nullptr); + if (strcmp(kty, "RSA") != 0) { + gpr_log(GPR_ERROR, "Unsupported key type %s.", kty); + goto end; + } + rsa = RSA_new(); + if (rsa == nullptr) { + gpr_log(GPR_ERROR, "Could not create rsa key."); + goto end; + } + it = json.object_value().find("n"); + if (it == json.object_value().end()) { + gpr_log(GPR_ERROR, "Missing RSA public key field."); + goto end; + } + tmp_n = bignum_from_base64(validate_string_field(it->second, "n")); + if (tmp_n == nullptr) goto end; + it = json.object_value().find("e"); + if (it == json.object_value().end()) { + gpr_log(GPR_ERROR, "Missing RSA public key field."); + goto end; + } + tmp_e = bignum_from_base64(validate_string_field(it->second, "e")); + if (tmp_e == nullptr) goto end; + if (!RSA_set0_key(rsa, tmp_n, tmp_e, nullptr)) { + gpr_log(GPR_ERROR, "Cannot set RSA key from inputs."); + goto end; + } + /* RSA_set0_key takes ownership on success. */ + tmp_n = nullptr; + tmp_e = nullptr; + result = EVP_PKEY_new(); + EVP_PKEY_set1_RSA(result, rsa); /* uprefs rsa. */ + +end: + RSA_free(rsa); + BN_free(tmp_n); + BN_free(tmp_e); + return result; +} + +static EVP_PKEY* find_verification_key(const Json& json, const char* header_alg, + const char* header_kid) { + /* Try to parse the json as a JWK set: + https://tools.ietf.org/html/rfc7517#section-5. */ + const Json* jwt_keys = find_property_by_name(json, "keys"); + if (jwt_keys == nullptr) { + /* Use the google proprietary format which is: + { : , : , ... } */ + const Json* cur = find_property_by_name(json, header_kid); + if (cur == nullptr) return nullptr; + return extract_pkey_from_x509(cur->string_value().c_str()); + } + if (jwt_keys->type() != Json::Type::ARRAY) { + gpr_log(GPR_ERROR, + "Unexpected value type of keys property in jwks key set."); + return nullptr; + } + /* Key format is specified in: + https://tools.ietf.org/html/rfc7518#section-6. 
*/ + for (const Json& jkey : jwt_keys->array_value()) { + if (jkey.type() != Json::Type::OBJECT) continue; + const char* alg = nullptr; + auto it = jkey.object_value().find("alg"); + if (it != jkey.object_value().end()) { + alg = validate_string_field(it->second, "alg"); + } + const char* kid = nullptr; + it = jkey.object_value().find("kid"); + if (it != jkey.object_value().end()) { + kid = validate_string_field(it->second, "kid"); + } + const char* kty = nullptr; + it = jkey.object_value().find("kty"); + if (it != jkey.object_value().end()) { + kty = validate_string_field(it->second, "kty"); + } + if (alg != nullptr && kid != nullptr && kty != nullptr && + strcmp(kid, header_kid) == 0 && strcmp(alg, header_alg) == 0) { + return pkey_from_jwk(jkey, kty); + } + } + gpr_log(GPR_ERROR, + "Could not find matching key in key set for kid=%s and alg=%s", + header_kid, header_alg); + return nullptr; +} + +static int verify_jwt_signature(EVP_PKEY* key, const char* alg, + const grpc_slice& signature, + const grpc_slice& signed_data) { + EVP_MD_CTX* md_ctx = EVP_MD_CTX_create(); + const EVP_MD* md = evp_md_from_alg(alg); + int result = 0; + + GPR_ASSERT(md != nullptr); /* Checked before. */ + if (md_ctx == nullptr) { + gpr_log(GPR_ERROR, "Could not create EVP_MD_CTX."); + goto end; + } + if (EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key) != 1) { + gpr_log(GPR_ERROR, "EVP_DigestVerifyInit failed."); + goto end; + } + if (EVP_DigestVerifyUpdate(md_ctx, GRPC_SLICE_START_PTR(signed_data), + GRPC_SLICE_LENGTH(signed_data)) != 1) { + gpr_log(GPR_ERROR, "EVP_DigestVerifyUpdate failed."); + goto end; + } + if (EVP_DigestVerifyFinal(md_ctx, GRPC_SLICE_START_PTR(signature), + GRPC_SLICE_LENGTH(signature)) != 1) { + gpr_log(GPR_ERROR, "JWT signature verification failed."); + goto end; + } + result = 1; + +end: + EVP_MD_CTX_destroy(md_ctx); + return result; +} + +static void on_keys_retrieved(void* user_data, grpc_error_handle /*error*/) { + verifier_cb_ctx* ctx = static_cast(user_data); + Json json = json_from_http(&ctx->responses[HTTP_RESPONSE_KEYS]); + EVP_PKEY* verification_key = nullptr; + grpc_jwt_verifier_status status = GRPC_JWT_VERIFIER_GENERIC_ERROR; + grpc_jwt_claims* claims = nullptr; + + if (json.type() == Json::Type::JSON_NULL) { + status = GRPC_JWT_VERIFIER_KEY_RETRIEVAL_ERROR; + goto end; + } + verification_key = + find_verification_key(json, ctx->header->alg, ctx->header->kid); + if (verification_key == nullptr) { + gpr_log(GPR_ERROR, "Could not find verification key with kid %s.", + ctx->header->kid); + status = GRPC_JWT_VERIFIER_KEY_RETRIEVAL_ERROR; + goto end; + } + + if (!verify_jwt_signature(verification_key, ctx->header->alg, ctx->signature, + ctx->signed_data)) { + status = GRPC_JWT_VERIFIER_BAD_SIGNATURE; + goto end; + } + + status = grpc_jwt_claims_check(ctx->claims, ctx->audience); + if (status == GRPC_JWT_VERIFIER_OK) { + /* Pass ownership. */ + claims = ctx->claims; + ctx->claims = nullptr; + } + +end: + EVP_PKEY_free(verification_key); + ctx->user_cb(ctx->user_data, status, claims); + verifier_cb_ctx_destroy(ctx); +} + +static void on_openid_config_retrieved(void* user_data, + grpc_error_handle /*error*/) { + verifier_cb_ctx* ctx = static_cast(user_data); + const grpc_http_response* response = &ctx->responses[HTTP_RESPONSE_OPENID]; + Json json = json_from_http(response); + grpc_httpcli_request req; + const char* jwks_uri; + grpc_resource_quota* resource_quota = nullptr; + const Json* cur; + + /* TODO(jboeuf): Cache the jwks_uri in order to avoid this hop next time. 
*/ + if (json.type() == Json::Type::JSON_NULL) goto error; + cur = find_property_by_name(json, "jwks_uri"); + if (cur == nullptr) { + gpr_log(GPR_ERROR, "Could not find jwks_uri in openid config."); + goto error; + } + jwks_uri = validate_string_field(*cur, "jwks_uri"); + if (jwks_uri == nullptr) goto error; + if (strstr(jwks_uri, "https://") != jwks_uri) { + gpr_log(GPR_ERROR, "Invalid non https jwks_uri: %s.", jwks_uri); + goto error; + } + jwks_uri += 8; + req.handshaker = &grpc_httpcli_ssl; + req.host = gpr_strdup(jwks_uri); + req.http.path = const_cast(strchr(jwks_uri, '/')); + if (req.http.path == nullptr) { + req.http.path = const_cast(""); + } else { + *(req.host + (req.http.path - jwks_uri)) = '\0'; + } + + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. */ + resource_quota = grpc_resource_quota_create("jwt_verifier"); + grpc_httpcli_get( + &ctx->verifier->http_ctx, &ctx->pollent, resource_quota, &req, + grpc_core::ExecCtx::Get()->Now() + grpc_jwt_verifier_max_delay, + GRPC_CLOSURE_CREATE(on_keys_retrieved, ctx, grpc_schedule_on_exec_ctx), + &ctx->responses[HTTP_RESPONSE_KEYS]); + gpr_free(req.host); + return; + +error: + ctx->user_cb(ctx->user_data, GRPC_JWT_VERIFIER_KEY_RETRIEVAL_ERROR, nullptr); + verifier_cb_ctx_destroy(ctx); +} + +static email_key_mapping* verifier_get_mapping(grpc_jwt_verifier* v, + const char* email_domain) { + size_t i; + if (v->mappings == nullptr) return nullptr; + for (i = 0; i < v->num_mappings; i++) { + if (strcmp(email_domain, v->mappings[i].email_domain) == 0) { + return &v->mappings[i]; + } + } + return nullptr; +} + +static void verifier_put_mapping(grpc_jwt_verifier* v, const char* email_domain, + const char* key_url_prefix) { + email_key_mapping* mapping = verifier_get_mapping(v, email_domain); + GPR_ASSERT(v->num_mappings < v->allocated_mappings); + if (mapping != nullptr) { + gpr_free(mapping->key_url_prefix); + mapping->key_url_prefix = gpr_strdup(key_url_prefix); + return; + } + v->mappings[v->num_mappings].email_domain = gpr_strdup(email_domain); + v->mappings[v->num_mappings].key_url_prefix = gpr_strdup(key_url_prefix); + v->num_mappings++; + GPR_ASSERT(v->num_mappings <= v->allocated_mappings); +} + +/* Very non-sophisticated way to detect an email address. Should be good + enough for now... */ +const char* grpc_jwt_issuer_email_domain(const char* issuer) { + const char* at_sign = strchr(issuer, '@'); + if (at_sign == nullptr) return nullptr; + const char* email_domain = at_sign + 1; + if (*email_domain == '\0') return nullptr; + const char* dot = strrchr(email_domain, '.'); + if (dot == nullptr || dot == email_domain) return email_domain; + GPR_ASSERT(dot > email_domain); + /* There may be a subdomain, we just want the domain. */ + dot = static_cast( + gpr_memrchr(email_domain, '.', static_cast(dot - email_domain))); + if (dot == nullptr) return email_domain; + return dot + 1; +} + +/* Takes ownership of ctx. 
*/ +static void retrieve_key_and_verify(verifier_cb_ctx* ctx) { + const char* email_domain; + grpc_closure* http_cb; + char* path_prefix = nullptr; + const char* iss; + grpc_httpcli_request req; + grpc_resource_quota* resource_quota = nullptr; + memset(&req, 0, sizeof(grpc_httpcli_request)); + req.handshaker = &grpc_httpcli_ssl; + http_response_index rsp_idx; + + GPR_ASSERT(ctx != nullptr && ctx->header != nullptr && + ctx->claims != nullptr); + iss = ctx->claims->iss; + if (ctx->header->kid == nullptr) { + gpr_log(GPR_ERROR, "Missing kid in jose header."); + goto error; + } + if (iss == nullptr) { + gpr_log(GPR_ERROR, "Missing iss in claims."); + goto error; + } + + /* This code relies on: + https://openid.net/specs/openid-connect-discovery-1_0.html + Nobody seems to implement the account/email/webfinger part 2. of the spec + so we will rely instead on email/url mappings if we detect such an issuer. + Part 4, on the other hand is implemented by both google and salesforce. */ + email_domain = grpc_jwt_issuer_email_domain(iss); + if (email_domain != nullptr) { + email_key_mapping* mapping; + GPR_ASSERT(ctx->verifier != nullptr); + mapping = verifier_get_mapping(ctx->verifier, email_domain); + if (mapping == nullptr) { + gpr_log(GPR_ERROR, "Missing mapping for issuer email."); + goto error; + } + req.host = gpr_strdup(mapping->key_url_prefix); + path_prefix = strchr(req.host, '/'); + if (path_prefix == nullptr) { + gpr_asprintf(&req.http.path, "/%s", iss); + } else { + *(path_prefix++) = '\0'; + gpr_asprintf(&req.http.path, "/%s/%s", path_prefix, iss); + } + http_cb = + GRPC_CLOSURE_CREATE(on_keys_retrieved, ctx, grpc_schedule_on_exec_ctx); + rsp_idx = HTTP_RESPONSE_KEYS; + } else { + req.host = gpr_strdup(strstr(iss, "https://") == iss ? iss + 8 : iss); + path_prefix = strchr(req.host, '/'); + if (path_prefix == nullptr) { + req.http.path = gpr_strdup(GRPC_OPENID_CONFIG_URL_SUFFIX); + } else { + *(path_prefix++) = 0; + gpr_asprintf(&req.http.path, "/%s%s", path_prefix, + GRPC_OPENID_CONFIG_URL_SUFFIX); + } + http_cb = GRPC_CLOSURE_CREATE(on_openid_config_retrieved, ctx, + grpc_schedule_on_exec_ctx); + rsp_idx = HTTP_RESPONSE_OPENID; + } + + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. 
*/ + resource_quota = grpc_resource_quota_create("jwt_verifier"); + grpc_httpcli_get( + &ctx->verifier->http_ctx, &ctx->pollent, resource_quota, &req, + grpc_core::ExecCtx::Get()->Now() + grpc_jwt_verifier_max_delay, http_cb, + &ctx->responses[rsp_idx]); + gpr_free(req.host); + gpr_free(req.http.path); + return; + +error: + ctx->user_cb(ctx->user_data, GRPC_JWT_VERIFIER_KEY_RETRIEVAL_ERROR, nullptr); + verifier_cb_ctx_destroy(ctx); +} + +void grpc_jwt_verifier_verify(grpc_jwt_verifier* verifier, + grpc_pollset* pollset, const char* jwt, + const char* audience, + grpc_jwt_verification_done_cb cb, + void* user_data) { + const char* dot = nullptr; + jose_header* header = nullptr; + grpc_jwt_claims* claims = nullptr; + grpc_slice signature; + size_t signed_jwt_len; + const char* cur = jwt; + Json json; + + GPR_ASSERT(verifier != nullptr && jwt != nullptr && audience != nullptr && + cb != nullptr); + dot = strchr(cur, '.'); + if (dot == nullptr) goto error; + json = parse_json_part_from_jwt(cur, static_cast(dot - cur)); + if (json.type() == Json::Type::JSON_NULL) goto error; + header = jose_header_from_json(std::move(json)); + if (header == nullptr) goto error; + + cur = dot + 1; + dot = strchr(cur, '.'); + if (dot == nullptr) goto error; + json = parse_json_part_from_jwt(cur, static_cast(dot - cur)); + if (json.type() == Json::Type::JSON_NULL) goto error; + claims = grpc_jwt_claims_from_json(std::move(json)); + if (claims == nullptr) goto error; + + signed_jwt_len = static_cast(dot - jwt); + cur = dot + 1; + signature = grpc_base64_decode(cur, 1); + if (GRPC_SLICE_IS_EMPTY(signature)) goto error; + retrieve_key_and_verify( + verifier_cb_ctx_create(verifier, pollset, header, claims, audience, + signature, jwt, signed_jwt_len, user_data, cb)); + return; + +error: + if (header != nullptr) jose_header_destroy(header); + if (claims != nullptr) grpc_jwt_claims_destroy(claims); + cb(user_data, GRPC_JWT_VERIFIER_BAD_FORMAT, nullptr); +} + +grpc_jwt_verifier* grpc_jwt_verifier_create( + const grpc_jwt_verifier_email_domain_key_url_mapping* mappings, + size_t num_mappings) { + grpc_jwt_verifier* v = grpc_core::Zalloc(); + grpc_httpcli_context_init(&v->http_ctx); + + /* We know at least of one mapping. */ + v->allocated_mappings = 1 + num_mappings; + v->mappings = static_cast( + gpr_malloc(v->allocated_mappings * sizeof(email_key_mapping))); + verifier_put_mapping(v, GRPC_GOOGLE_SERVICE_ACCOUNTS_EMAIL_DOMAIN, + GRPC_GOOGLE_SERVICE_ACCOUNTS_KEY_URL_PREFIX); + /* User-Provided mappings. */ + if (mappings != nullptr) { + size_t i; + for (i = 0; i < num_mappings; i++) { + verifier_put_mapping(v, mappings[i].email_domain, + mappings[i].key_url_prefix); + } + } + return v; +} + +void grpc_jwt_verifier_destroy(grpc_jwt_verifier* v) { + size_t i; + if (v == nullptr) return; + grpc_httpcli_context_destroy(&v->http_ctx); + if (v->mappings != nullptr) { + for (i = 0; i < v->num_mappings; i++) { + gpr_free(v->mappings[i].email_domain); + gpr_free(v->mappings[i].key_url_prefix); + } + gpr_free(v->mappings); + } + gpr_free(v); +} diff --git a/src/core/lib/security/credentials/local/local_credentials.cc b/src/core/lib/security/credentials/local/local_credentials.cc new file mode 100644 index 00000000..84caf1c3 --- /dev/null +++ b/src/core/lib/security/credentials/local/local_credentials.cc @@ -0,0 +1,65 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/lib/security/credentials/local/local_credentials.h"
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/security/security_connector/local/local_security_connector.h"
+
+#define GRPC_CREDENTIALS_TYPE_LOCAL "Local"
+
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_local_credentials::create_security_connector(
+    grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+    const char* target_name, const grpc_channel_args* args,
+    grpc_channel_args** /*new_args*/) {
+  return grpc_local_channel_security_connector_create(
+      this->Ref(), std::move(request_metadata_creds), args, target_name);
+}
+
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_local_server_credentials::create_security_connector(
+    const grpc_channel_args* /* args */) {
+  return grpc_local_server_security_connector_create(this->Ref());
+}
+
+grpc_local_credentials::grpc_local_credentials(
+    grpc_local_connect_type connect_type)
+    : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_LOCAL),
+      connect_type_(connect_type) {}
+
+grpc_channel_credentials* grpc_local_credentials_create(
+    grpc_local_connect_type connect_type) {
+  return new grpc_local_credentials(connect_type);
+}
+
+grpc_local_server_credentials::grpc_local_server_credentials(
+    grpc_local_connect_type connect_type)
+    : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_LOCAL),
+      connect_type_(connect_type) {}
+
+grpc_server_credentials* grpc_local_server_credentials_create(
+    grpc_local_connect_type connect_type) {
+  return new grpc_local_server_credentials(connect_type);
+}
diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc
new file mode 100644
index 00000000..df467fe4
--- /dev/null
+++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc
@@ -0,0 +1,753 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +#include "src/core/lib/security/credentials/oauth2/oauth2_credentials.h" + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/security/util/json_util.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/uri/uri_parser.h" + +using grpc_core::Json; + +// +// Auth Refresh Token. +// + +int grpc_auth_refresh_token_is_valid( + const grpc_auth_refresh_token* refresh_token) { + return (refresh_token != nullptr) && + strcmp(refresh_token->type, GRPC_AUTH_JSON_TYPE_INVALID) != 0; +} + +grpc_auth_refresh_token grpc_auth_refresh_token_create_from_json( + const Json& json) { + grpc_auth_refresh_token result; + const char* prop_value; + int success = 0; + grpc_error_handle error = GRPC_ERROR_NONE; + + memset(&result, 0, sizeof(grpc_auth_refresh_token)); + result.type = GRPC_AUTH_JSON_TYPE_INVALID; + if (json.type() != Json::Type::OBJECT) { + gpr_log(GPR_ERROR, "Invalid json."); + goto end; + } + + prop_value = grpc_json_get_string_property(json, "type", &error); + GRPC_LOG_IF_ERROR("Parsing refresh token", error); + if (prop_value == nullptr || + strcmp(prop_value, GRPC_AUTH_JSON_TYPE_AUTHORIZED_USER) != 0) { + goto end; + } + result.type = GRPC_AUTH_JSON_TYPE_AUTHORIZED_USER; + + if (!grpc_copy_json_string_property(json, "client_secret", + &result.client_secret) || + !grpc_copy_json_string_property(json, "client_id", &result.client_id) || + !grpc_copy_json_string_property(json, "refresh_token", + &result.refresh_token)) { + goto end; + } + success = 1; + +end: + if (!success) grpc_auth_refresh_token_destruct(&result); + return result; +} + +grpc_auth_refresh_token grpc_auth_refresh_token_create_from_string( + const char* json_string) { + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_string, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parsing failed: %s", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + } + return grpc_auth_refresh_token_create_from_json(json); +} + +void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token) { + if (refresh_token == nullptr) return; + refresh_token->type = GRPC_AUTH_JSON_TYPE_INVALID; + if (refresh_token->client_id != nullptr) { + gpr_free(refresh_token->client_id); + refresh_token->client_id = nullptr; + } + if (refresh_token->client_secret != nullptr) { + gpr_free(refresh_token->client_secret); + refresh_token->client_secret = nullptr; + } + if (refresh_token->refresh_token != nullptr) { + gpr_free(refresh_token->refresh_token); + refresh_token->refresh_token = nullptr; + } +} + +// +// Oauth2 Token Fetcher credentials. 
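For orientation, here is a minimal caller-side sketch (illustrative only, not part of this patch) of how a token-fetcher credential produced by this file is typically combined with channel credentials through the public C API; the target address is a placeholder, error handling is omitted, and grpc_init() is assumed to have been called.

#include <grpc/grpc.h>
#include <grpc/grpc_security.h>

// Sketch only: pair Compute Engine token-fetcher call credentials with SSL
// channel credentials. "greeter.example.com:443" is a placeholder target.
grpc_channel* make_authenticated_channel() {
  grpc_channel_credentials* ssl_creds = grpc_ssl_credentials_create_ex(
      /*pem_root_certs=*/nullptr, /*pem_key_cert_pair=*/nullptr,
      /*verify_options=*/nullptr, /*reserved=*/nullptr);
  grpc_call_credentials* call_creds =
      grpc_google_compute_engine_credentials_create(/*reserved=*/nullptr);
  grpc_channel_credentials* composite =
      grpc_composite_channel_credentials_create(ssl_creds, call_creds,
                                                /*reserved=*/nullptr);
  grpc_channel* channel = grpc_secure_channel_create(
      composite, "greeter.example.com:443", /*args=*/nullptr,
      /*reserved=*/nullptr);
  // The composite credentials hold their own references to both pieces.
  grpc_channel_credentials_release(ssl_creds);
  grpc_call_credentials_release(call_creds);
  grpc_channel_credentials_release(composite);
  return channel;
}

Each RPC on such a channel goes through get_request_metadata() below, which serves the cached token while it is still fresh and otherwise schedules an HTTP fetch.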
+// + +grpc_oauth2_token_fetcher_credentials:: + ~grpc_oauth2_token_fetcher_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); + gpr_mu_destroy(&mu_); + grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_)); + grpc_httpcli_context_destroy(&httpcli_context_); +} + +grpc_credentials_status +grpc_oauth2_token_fetcher_credentials_parse_server_response( + const grpc_http_response* response, grpc_mdelem* token_md, + grpc_millis* token_lifetime) { + char* null_terminated_body = nullptr; + grpc_credentials_status status = GRPC_CREDENTIALS_OK; + Json json; + + if (response == nullptr) { + gpr_log(GPR_ERROR, "Received NULL response."); + status = GRPC_CREDENTIALS_ERROR; + goto end; + } + + if (response->body_length > 0) { + null_terminated_body = + static_cast(gpr_malloc(response->body_length + 1)); + null_terminated_body[response->body_length] = '\0'; + memcpy(null_terminated_body, response->body, response->body_length); + } + + if (response->status != 200) { + gpr_log(GPR_ERROR, "Call to http server ended with error %d [%s].", + response->status, + null_terminated_body != nullptr ? null_terminated_body : ""); + status = GRPC_CREDENTIALS_ERROR; + goto end; + } else { + const char* access_token = nullptr; + const char* token_type = nullptr; + const char* expires_in = nullptr; + Json::Object::const_iterator it; + grpc_error_handle error = GRPC_ERROR_NONE; + json = Json::Parse(null_terminated_body, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Could not parse JSON from %s: %s", + null_terminated_body, grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + status = GRPC_CREDENTIALS_ERROR; + goto end; + } + if (json.type() != Json::Type::OBJECT) { + gpr_log(GPR_ERROR, "Response should be a JSON object"); + status = GRPC_CREDENTIALS_ERROR; + goto end; + } + it = json.object_value().find("access_token"); + if (it == json.object_value().end() || + it->second.type() != Json::Type::STRING) { + gpr_log(GPR_ERROR, "Missing or invalid access_token in JSON."); + status = GRPC_CREDENTIALS_ERROR; + goto end; + } + access_token = it->second.string_value().c_str(); + it = json.object_value().find("token_type"); + if (it == json.object_value().end() || + it->second.type() != Json::Type::STRING) { + gpr_log(GPR_ERROR, "Missing or invalid token_type in JSON."); + status = GRPC_CREDENTIALS_ERROR; + goto end; + } + token_type = it->second.string_value().c_str(); + it = json.object_value().find("expires_in"); + if (it == json.object_value().end() || + it->second.type() != Json::Type::NUMBER) { + gpr_log(GPR_ERROR, "Missing or invalid expires_in in JSON."); + status = GRPC_CREDENTIALS_ERROR; + goto end; + } + expires_in = it->second.string_value().c_str(); + *token_lifetime = strtol(expires_in, nullptr, 10) * GPR_MS_PER_SEC; + if (!GRPC_MDISNULL(*token_md)) GRPC_MDELEM_UNREF(*token_md); + *token_md = grpc_mdelem_from_slices( + grpc_core::ExternallyManagedSlice(GRPC_AUTHORIZATION_METADATA_KEY), + grpc_slice_from_cpp_string( + absl::StrCat(token_type, " ", access_token))); + status = GRPC_CREDENTIALS_OK; + } + +end: + if (status != GRPC_CREDENTIALS_OK && !GRPC_MDISNULL(*token_md)) { + GRPC_MDELEM_UNREF(*token_md); + *token_md = GRPC_MDNULL; + } + gpr_free(null_terminated_body); + return status; +} + +static void on_oauth2_token_fetcher_http_response(void* user_data, + grpc_error_handle error) { + GRPC_LOG_IF_ERROR("oauth_fetch", GRPC_ERROR_REF(error)); + grpc_credentials_metadata_request* r = + static_cast(user_data); + grpc_oauth2_token_fetcher_credentials* c = + 
reinterpret_cast(r->creds.get()); + c->on_http_response(r, error); +} + +void grpc_oauth2_token_fetcher_credentials::on_http_response( + grpc_credentials_metadata_request* r, grpc_error_handle error) { + grpc_mdelem access_token_md = GRPC_MDNULL; + grpc_millis token_lifetime = 0; + grpc_credentials_status status = + error == GRPC_ERROR_NONE + ? grpc_oauth2_token_fetcher_credentials_parse_server_response( + &r->response, &access_token_md, &token_lifetime) + : GRPC_CREDENTIALS_ERROR; + // Update cache and grab list of pending requests. + gpr_mu_lock(&mu_); + token_fetch_pending_ = false; + access_token_md_ = GRPC_MDELEM_REF(access_token_md); + token_expiration_ = + status == GRPC_CREDENTIALS_OK + ? gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(token_lifetime, GPR_TIMESPAN)) + : gpr_inf_past(GPR_CLOCK_MONOTONIC); + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; + pending_requests_ = nullptr; + gpr_mu_unlock(&mu_); + // Invoke callbacks for all pending requests. + while (pending_request != nullptr) { + grpc_error_handle new_error = GRPC_ERROR_NONE; + if (status == GRPC_CREDENTIALS_OK) { + grpc_credentials_mdelem_array_add(pending_request->md_array, + access_token_md); + } else { + new_error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + pending_request->on_request_metadata, new_error); + grpc_polling_entity_del_from_pollset_set( + pending_request->pollent, grpc_polling_entity_pollset_set(&pollent_)); + grpc_oauth2_pending_get_request_metadata* prev = pending_request; + pending_request = pending_request->next; + gpr_free(prev); + } + GRPC_MDELEM_UNREF(access_token_md); + Unref(); + grpc_credentials_metadata_request_destroy(r); +} + +bool grpc_oauth2_token_fetcher_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context /*context*/, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error_handle* /*error*/) { + // Check if we can use the cached token. + grpc_millis refresh_threshold = + GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS * GPR_MS_PER_SEC; + grpc_mdelem cached_access_token_md = GRPC_MDNULL; + gpr_mu_lock(&mu_); + if (!GRPC_MDISNULL(access_token_md_) && + gpr_time_cmp( + gpr_time_sub(token_expiration_, gpr_now(GPR_CLOCK_MONOTONIC)), + gpr_time_from_seconds(GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, + GPR_TIMESPAN)) > 0) { + cached_access_token_md = GRPC_MDELEM_REF(access_token_md_); + } + if (!GRPC_MDISNULL(cached_access_token_md)) { + gpr_mu_unlock(&mu_); + grpc_credentials_mdelem_array_add(md_array, cached_access_token_md); + GRPC_MDELEM_UNREF(cached_access_token_md); + return true; + } + // Couldn't get the token from the cache. + // Add request to pending_requests_ and start a new fetch if needed. 
+ grpc_oauth2_pending_get_request_metadata* pending_request = + static_cast( + gpr_malloc(sizeof(*pending_request))); + pending_request->md_array = md_array; + pending_request->on_request_metadata = on_request_metadata; + pending_request->pollent = pollent; + grpc_polling_entity_add_to_pollset_set( + pollent, grpc_polling_entity_pollset_set(&pollent_)); + pending_request->next = pending_requests_; + pending_requests_ = pending_request; + bool start_fetch = false; + if (!token_fetch_pending_) { + token_fetch_pending_ = true; + start_fetch = true; + } + gpr_mu_unlock(&mu_); + if (start_fetch) { + Ref().release(); + fetch_oauth2(grpc_credentials_metadata_request_create(this->Ref()), + &httpcli_context_, &pollent_, + on_oauth2_token_fetcher_http_response, + grpc_core::ExecCtx::Get()->Now() + refresh_threshold); + } + return false; +} + +void grpc_oauth2_token_fetcher_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error_handle error) { + gpr_mu_lock(&mu_); + grpc_oauth2_pending_get_request_metadata* prev = nullptr; + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; + while (pending_request != nullptr) { + if (pending_request->md_array == md_array) { + // Remove matching pending request from the list. + if (prev != nullptr) { + prev->next = pending_request->next; + } else { + pending_requests_ = pending_request->next; + } + // Invoke the callback immediately with an error. + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + pending_request->on_request_metadata, + GRPC_ERROR_REF(error)); + gpr_free(pending_request); + break; + } + prev = pending_request; + pending_request = pending_request->next; + } + gpr_mu_unlock(&mu_); + GRPC_ERROR_UNREF(error); +} + +grpc_oauth2_token_fetcher_credentials::grpc_oauth2_token_fetcher_credentials() + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2), + token_expiration_(gpr_inf_past(GPR_CLOCK_MONOTONIC)), + pollent_(grpc_polling_entity_create_from_pollset_set( + grpc_pollset_set_create())) { + gpr_mu_init(&mu_); + grpc_httpcli_context_init(&httpcli_context_); +} + +std::string grpc_oauth2_token_fetcher_credentials::debug_string() { + return "OAuth2TokenFetcherCredentials"; +} + +// +// Google Compute Engine credentials. +// + +namespace { + +class grpc_compute_engine_token_fetcher_credentials + : public grpc_oauth2_token_fetcher_credentials { + public: + grpc_compute_engine_token_fetcher_credentials() = default; + ~grpc_compute_engine_token_fetcher_credentials() override = default; + + protected: + void fetch_oauth2(grpc_credentials_metadata_request* metadata_req, + grpc_httpcli_context* http_context, + grpc_polling_entity* pollent, + grpc_iomgr_cb_func response_cb, + grpc_millis deadline) override { + grpc_http_header header = {const_cast("Metadata-Flavor"), + const_cast("Google")}; + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(GRPC_COMPUTE_ENGINE_METADATA_HOST); + request.http.path = + const_cast(GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH); + request.http.hdr_count = 1; + request.http.hdrs = &header; + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. 
*/ + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("oauth2_credentials"); + grpc_httpcli_get(http_context, pollent, resource_quota, &request, deadline, + GRPC_CLOSURE_INIT(&http_get_cb_closure_, response_cb, + metadata_req, grpc_schedule_on_exec_ctx), + &metadata_req->response); + } + + std::string debug_string() override { + return absl::StrFormat( + "GoogleComputeEngineTokenFetcherCredentials{%s}", + grpc_oauth2_token_fetcher_credentials::debug_string()); + } + + private: + grpc_closure http_get_cb_closure_; +}; + +} // namespace + +grpc_call_credentials* grpc_google_compute_engine_credentials_create( + void* reserved) { + GRPC_API_TRACE("grpc_compute_engine_credentials_create(reserved=%p)", 1, + (reserved)); + GPR_ASSERT(reserved == nullptr); + return grpc_core::MakeRefCounted< + grpc_compute_engine_token_fetcher_credentials>() + .release(); +} + +// +// Google Refresh Token credentials. +// + +grpc_google_refresh_token_credentials:: + ~grpc_google_refresh_token_credentials() { + grpc_auth_refresh_token_destruct(&refresh_token_); +} + +void grpc_google_refresh_token_credentials::fetch_oauth2( + grpc_credentials_metadata_request* metadata_req, + grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, + grpc_iomgr_cb_func response_cb, grpc_millis deadline) { + grpc_http_header header = { + const_cast("Content-Type"), + const_cast("application/x-www-form-urlencoded")}; + grpc_httpcli_request request; + std::string body = absl::StrFormat( + GRPC_REFRESH_TOKEN_POST_BODY_FORMAT_STRING, refresh_token_.client_id, + refresh_token_.client_secret, refresh_token_.refresh_token); + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(GRPC_GOOGLE_OAUTH2_SERVICE_HOST); + request.http.path = const_cast(GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH); + request.http.hdr_count = 1; + request.http.hdrs = &header; + request.handshaker = &grpc_httpcli_ssl; + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. 
+   */
+  grpc_resource_quota* resource_quota =
+      grpc_resource_quota_create("oauth2_credentials_refresh");
+  grpc_httpcli_post(httpcli_context, pollent, resource_quota, &request,
+                    body.c_str(), body.size(), deadline,
+                    GRPC_CLOSURE_INIT(&http_post_cb_closure_, response_cb,
+                                      metadata_req, grpc_schedule_on_exec_ctx),
+                    &metadata_req->response);
+}
+
+grpc_google_refresh_token_credentials::grpc_google_refresh_token_credentials(
+    grpc_auth_refresh_token refresh_token)
+    : refresh_token_(refresh_token) {}
+
+grpc_core::RefCountedPtr<grpc_call_credentials>
+grpc_refresh_token_credentials_create_from_auth_refresh_token(
+    grpc_auth_refresh_token refresh_token) {
+  if (!grpc_auth_refresh_token_is_valid(&refresh_token)) {
+    gpr_log(GPR_ERROR, "Invalid input for refresh token credentials creation");
+    return nullptr;
+  }
+  return grpc_core::MakeRefCounted<grpc_google_refresh_token_credentials>(
+      refresh_token);
+}
+
+std::string grpc_google_refresh_token_credentials::debug_string() {
+  return absl::StrFormat("GoogleRefreshToken{ClientID:%s,%s}",
+                         refresh_token_.client_id,
+                         grpc_oauth2_token_fetcher_credentials::debug_string());
+}
+
+static std::string create_loggable_refresh_token(
+    grpc_auth_refresh_token* token) {
+  if (strcmp(token->type, GRPC_AUTH_JSON_TYPE_INVALID) == 0) {
+    return "<Invalid json token>";
+  }
+  return absl::StrFormat(
+      "{\n type: %s\n client_id: %s\n client_secret: "
+      "<redacted>\n refresh_token: <redacted>\n}",
+      token->type, token->client_id);
+}
+
+grpc_call_credentials* grpc_google_refresh_token_credentials_create(
+    const char* json_refresh_token, void* reserved) {
+  grpc_auth_refresh_token token =
+      grpc_auth_refresh_token_create_from_string(json_refresh_token);
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_api_trace)) {
+    gpr_log(GPR_INFO,
+            "grpc_refresh_token_credentials_create(json_refresh_token=%s, "
+            "reserved=%p)",
+            create_loggable_refresh_token(&token).c_str(), reserved);
+  }
+  GPR_ASSERT(reserved == nullptr);
+  return grpc_refresh_token_credentials_create_from_auth_refresh_token(token)
+      .release();
+}
+
+//
+// STS credentials.
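As a usage sketch (the endpoint, file path, audience, and token type below are illustrative placeholders, not values taken from this patch), the STS flow that follows is configured entirely through grpc_sts_credentials_options:

#include <grpc/grpc_security.h>

// Sketch only: exchange a locally stored subject token for an access token.
grpc_call_credentials* make_sts_call_credentials() {
  grpc_sts_credentials_options options = {};  // zero-init: optional fields stay null
  options.token_exchange_service_uri = "https://sts.example.com/v1/token";
  options.audience = "https://service.example.com";
  options.subject_token_path = "/var/run/secrets/tokens/subject_token";
  options.subject_token_type = "urn:ietf:params:oauth:token-type:jwt";
  return grpc_sts_credentials_create(&options, /*reserved=*/nullptr);
}

ValidateStsCredentialsOptions() below rejects a missing or non-http(s) token_exchange_service_uri and missing subject-token settings before the fetcher is ever constructed.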
+// + +namespace grpc_core { + +namespace { + +void MaybeAddToBody(const char* field_name, const char* field, + std::vector* body) { + if (field == nullptr || strlen(field) == 0) return; + body->push_back(absl::StrFormat("&%s=%s", field_name, field)); +} + +grpc_error_handle LoadTokenFile(const char* path, gpr_slice* token) { + grpc_error_handle err = grpc_load_file(path, 1, token); + if (err != GRPC_ERROR_NONE) return err; + if (GRPC_SLICE_LENGTH(*token) == 0) { + gpr_log(GPR_ERROR, "Token file %s is empty", path); + err = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Token file is empty."); + } + return err; +} + +class StsTokenFetcherCredentials + : public grpc_oauth2_token_fetcher_credentials { + public: + StsTokenFetcherCredentials(URI sts_url, + const grpc_sts_credentials_options* options) + : sts_url_(std::move(sts_url)), + resource_(gpr_strdup(options->resource)), + audience_(gpr_strdup(options->audience)), + scope_(gpr_strdup(options->scope)), + requested_token_type_(gpr_strdup(options->requested_token_type)), + subject_token_path_(gpr_strdup(options->subject_token_path)), + subject_token_type_(gpr_strdup(options->subject_token_type)), + actor_token_path_(gpr_strdup(options->actor_token_path)), + actor_token_type_(gpr_strdup(options->actor_token_type)) {} + + std::string debug_string() override { + return absl::StrFormat( + "StsTokenFetcherCredentials{Path:%s,Authority:%s,%s}", sts_url_.path(), + sts_url_.authority(), + grpc_oauth2_token_fetcher_credentials::debug_string()); + } + + private: + void fetch_oauth2(grpc_credentials_metadata_request* metadata_req, + grpc_httpcli_context* http_context, + grpc_polling_entity* pollent, + grpc_iomgr_cb_func response_cb, + grpc_millis deadline) override { + char* body = nullptr; + size_t body_length = 0; + grpc_error_handle err = FillBody(&body, &body_length); + if (err != GRPC_ERROR_NONE) { + response_cb(metadata_req, err); + GRPC_ERROR_UNREF(err); + return; + } + grpc_http_header header = { + const_cast("Content-Type"), + const_cast("application/x-www-form-urlencoded")}; + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = const_cast(sts_url_.authority().c_str()); + request.http.path = const_cast(sts_url_.path().c_str()); + request.http.hdr_count = 1; + request.http.hdrs = &header; + request.handshaker = (sts_url_.scheme() == "https") + ? &grpc_httpcli_ssl + : &grpc_httpcli_plaintext; + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. 
*/ + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("oauth2_credentials_refresh"); + grpc_httpcli_post( + http_context, pollent, resource_quota, &request, body, body_length, + deadline, + GRPC_CLOSURE_INIT(&http_post_cb_closure_, response_cb, metadata_req, + grpc_schedule_on_exec_ctx), + &metadata_req->response); + gpr_free(body); + } + + grpc_error_handle FillBody(char** body, size_t* body_length) { + *body = nullptr; + std::vector body_parts; + grpc_slice subject_token = grpc_empty_slice(); + grpc_slice actor_token = grpc_empty_slice(); + grpc_error_handle err = GRPC_ERROR_NONE; + + auto cleanup = [&body, &body_length, &body_parts, &subject_token, + &actor_token, &err]() { + if (err == GRPC_ERROR_NONE) { + std::string body_str = absl::StrJoin(body_parts, ""); + *body = gpr_strdup(body_str.c_str()); + *body_length = body_str.size(); + } + grpc_slice_unref_internal(subject_token); + grpc_slice_unref_internal(actor_token); + return err; + }; + + err = LoadTokenFile(subject_token_path_.get(), &subject_token); + if (err != GRPC_ERROR_NONE) return cleanup(); + body_parts.push_back(absl::StrFormat( + GRPC_STS_POST_MINIMAL_BODY_FORMAT_STRING, + reinterpret_cast(GRPC_SLICE_START_PTR(subject_token)), + subject_token_type_.get())); + MaybeAddToBody("resource", resource_.get(), &body_parts); + MaybeAddToBody("audience", audience_.get(), &body_parts); + MaybeAddToBody("scope", scope_.get(), &body_parts); + MaybeAddToBody("requested_token_type", requested_token_type_.get(), + &body_parts); + if ((actor_token_path_ != nullptr) && *actor_token_path_ != '\0') { + err = LoadTokenFile(actor_token_path_.get(), &actor_token); + if (err != GRPC_ERROR_NONE) return cleanup(); + MaybeAddToBody( + "actor_token", + reinterpret_cast(GRPC_SLICE_START_PTR(actor_token)), + &body_parts); + MaybeAddToBody("actor_token_type", actor_token_type_.get(), &body_parts); + } + return cleanup(); + } + + URI sts_url_; + grpc_closure http_post_cb_closure_; + grpc_core::UniquePtr resource_; + grpc_core::UniquePtr audience_; + grpc_core::UniquePtr scope_; + grpc_core::UniquePtr requested_token_type_; + grpc_core::UniquePtr subject_token_path_; + grpc_core::UniquePtr subject_token_type_; + grpc_core::UniquePtr actor_token_path_; + grpc_core::UniquePtr actor_token_type_; +}; + +} // namespace + +absl::StatusOr ValidateStsCredentialsOptions( + const grpc_sts_credentials_options* options) { + absl::InlinedVector error_list; + absl::StatusOr sts_url = + URI::Parse(options->token_exchange_service_uri == nullptr + ? "" + : options->token_exchange_service_uri); + if (!sts_url.ok()) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrFormat("Invalid or missing STS endpoint URL. 
Error: %s", + sts_url.status().ToString()))); + } else if (sts_url->scheme() != "https" && sts_url->scheme() != "http") { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid URI scheme, must be https to http.")); + } + if (options->subject_token_path == nullptr || + strlen(options->subject_token_path) == 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "subject_token needs to be specified")); + } + if (options->subject_token_type == nullptr || + strlen(options->subject_token_type) == 0) { + error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "subject_token_type needs to be specified")); + } + if (error_list.empty()) { + return sts_url; + } + auto grpc_error_vec = GRPC_ERROR_CREATE_FROM_VECTOR( + "Invalid STS Credentials Options", &error_list); + auto retval = + absl::InvalidArgumentError(grpc_error_std_string(grpc_error_vec)); + GRPC_ERROR_UNREF(grpc_error_vec); + return retval; +} + +} // namespace grpc_core + +grpc_call_credentials* grpc_sts_credentials_create( + const grpc_sts_credentials_options* options, void* reserved) { + GPR_ASSERT(reserved == nullptr); + absl::StatusOr sts_url = + grpc_core::ValidateStsCredentialsOptions(options); + if (!sts_url.ok()) { + gpr_log(GPR_ERROR, "STS Credentials creation failed. Error: %s.", + sts_url.status().ToString().c_str()); + return nullptr; + } + return grpc_core::MakeRefCounted( + std::move(*sts_url), options) + .release(); +} + +// +// Oauth2 Access Token credentials. +// + +grpc_access_token_credentials::~grpc_access_token_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); +} + +bool grpc_access_token_credentials::get_request_metadata( + grpc_polling_entity* /*pollent*/, grpc_auth_metadata_context /*context*/, + grpc_credentials_mdelem_array* md_array, + grpc_closure* /*on_request_metadata*/, grpc_error_handle* /*error*/) { + grpc_credentials_mdelem_array_add(md_array, access_token_md_); + return true; +} + +void grpc_access_token_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* /*md_array*/, grpc_error_handle error) { + GRPC_ERROR_UNREF(error); +} + +grpc_access_token_credentials::grpc_access_token_credentials( + const char* access_token) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) { + grpc_core::ExecCtx exec_ctx; + access_token_md_ = grpc_mdelem_from_slices( + grpc_core::ExternallyManagedSlice(GRPC_AUTHORIZATION_METADATA_KEY), + grpc_slice_from_cpp_string(absl::StrCat("Bearer ", access_token))); +} + +std::string grpc_access_token_credentials::debug_string() { + bool access_token_present = !GRPC_MDISNULL(access_token_md_); + return absl::StrFormat("AccessTokenCredentials{Token:%s}", + access_token_present ? "present" : "absent"); +} + +grpc_call_credentials* grpc_access_token_credentials_create( + const char* access_token, void* reserved) { + GRPC_API_TRACE( + "grpc_access_token_credentials_create(access_token=, " + "reserved=%p)", + 1, (reserved)); + GPR_ASSERT(reserved == nullptr); + return grpc_core::MakeRefCounted(access_token) + .release(); +} diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.cc b/src/core/lib/security/credentials/plugin/plugin_credentials.cc new file mode 100644 index 00000000..370ccb4f --- /dev/null +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.cc @@ -0,0 +1,267 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/plugin/plugin_credentials.h" + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/validate_metadata.h" + +grpc_core::TraceFlag grpc_plugin_credentials_trace(false, "plugin_credentials"); + +grpc_plugin_credentials::~grpc_plugin_credentials() { + gpr_mu_destroy(&mu_); + if (plugin_.state != nullptr && plugin_.destroy != nullptr) { + plugin_.destroy(plugin_.state); + } +} + +std::string grpc_plugin_credentials::debug_string() { + char* debug_c_str = nullptr; + if (plugin_.debug_string != nullptr) { + debug_c_str = plugin_.debug_string(plugin_.state); + } + std::string debug_str( + debug_c_str != nullptr + ? debug_c_str + : "grpc_plugin_credentials did not provide a debug string"); + gpr_free(debug_c_str); + return debug_str; +} + +void grpc_plugin_credentials::pending_request_remove_locked( + pending_request* pending_request) { + if (pending_request->prev == nullptr) { + pending_requests_ = pending_request->next; + } else { + pending_request->prev->next = pending_request->next; + } + if (pending_request->next != nullptr) { + pending_request->next->prev = pending_request->prev; + } +} + +// Checks if the request has been cancelled. +// If not, removes it from the pending list, so that it cannot be +// cancelled out from under us. +// When this returns, r->cancelled indicates whether the request was +// cancelled before completion. +void grpc_plugin_credentials::pending_request_complete(pending_request* r) { + GPR_DEBUG_ASSERT(r->creds == this); + gpr_mu_lock(&mu_); + if (!r->cancelled) pending_request_remove_locked(r); + gpr_mu_unlock(&mu_); + // Ref to credentials not needed anymore. 
+ Unref(); +} + +static grpc_error_handle process_plugin_result( + grpc_plugin_credentials::pending_request* r, const grpc_metadata* md, + size_t num_md, grpc_status_code status, const char* error_details) { + grpc_error_handle error = GRPC_ERROR_NONE; + if (status != GRPC_STATUS_OK) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Getting metadata from plugin failed with error: ", error_details)); + } else { + bool seen_illegal_header = false; + for (size_t i = 0; i < num_md; ++i) { + if (!GRPC_LOG_IF_ERROR("validate_metadata_from_plugin", + grpc_validate_header_key_is_legal(md[i].key))) { + seen_illegal_header = true; + break; + } else if (!grpc_is_binary_header_internal(md[i].key) && + !GRPC_LOG_IF_ERROR( + "validate_metadata_from_plugin", + grpc_validate_header_nonbin_value_is_legal(md[i].value))) { + gpr_log(GPR_ERROR, "Plugin added invalid metadata value."); + seen_illegal_header = true; + break; + } + } + if (seen_illegal_header) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Illegal metadata"); + } else { + for (size_t i = 0; i < num_md; ++i) { + grpc_mdelem mdelem = + grpc_mdelem_create(md[i].key, md[i].value, nullptr); + grpc_credentials_mdelem_array_add(r->md_array, mdelem); + GRPC_MDELEM_UNREF(mdelem); + } + } + } + return error; +} + +static void plugin_md_request_metadata_ready(void* request, + const grpc_metadata* md, + size_t num_md, + grpc_status_code status, + const char* error_details) { + /* called from application code */ + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_FINISHED | + GRPC_EXEC_CTX_FLAG_THREAD_RESOURCE_LOOP); + grpc_plugin_credentials::pending_request* r = + static_cast(request); + if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) { + gpr_log(GPR_INFO, + "plugin_credentials[%p]: request %p: plugin returned " + "asynchronously", + r->creds, r); + } + // Remove request from pending list if not previously cancelled. + r->creds->pending_request_complete(r); + // If it has not been cancelled, process it. + if (!r->cancelled) { + grpc_error_handle error = + process_plugin_result(r, md, num_md, status, error_details); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_request_metadata, error); + } else if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) { + gpr_log(GPR_INFO, + "plugin_credentials[%p]: request %p: plugin was previously " + "cancelled", + r->creds, r); + } + gpr_free(r); +} + +bool grpc_plugin_credentials::get_request_metadata( + grpc_polling_entity* /*pollent*/, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error_handle* error) { + bool retval = true; // Synchronous return. + if (plugin_.get_metadata != nullptr) { + // Create pending_request object. + pending_request* request = grpc_core::Zalloc(); + request->creds = this; + request->md_array = md_array; + request->on_request_metadata = on_request_metadata; + // Add it to the pending list. + gpr_mu_lock(&mu_); + if (pending_requests_ != nullptr) { + pending_requests_->prev = request; + } + request->next = pending_requests_; + pending_requests_ = request; + gpr_mu_unlock(&mu_); + // Invoke the plugin. The callback holds a ref to us. 
+ if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) { + gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: invoking plugin", + this, request); + } + Ref().release(); + grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX]; + size_t num_creds_md = 0; + grpc_status_code status = GRPC_STATUS_OK; + const char* error_details = nullptr; + if (!plugin_.get_metadata( + plugin_.state, context, plugin_md_request_metadata_ready, request, + creds_md, &num_creds_md, &status, &error_details)) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) { + gpr_log(GPR_INFO, + "plugin_credentials[%p]: request %p: plugin will return " + "asynchronously", + this, request); + } + return false; // Asynchronous return. + } + // Returned synchronously. + // Remove request from pending list if not previously cancelled. + request->creds->pending_request_complete(request); + // If the request was cancelled, the error will have been returned + // asynchronously by plugin_cancel_get_request_metadata(), so return + // false. Otherwise, process the result. + if (request->cancelled) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) { + gpr_log(GPR_INFO, + "plugin_credentials[%p]: request %p was cancelled, error " + "will be returned asynchronously", + this, request); + } + retval = false; + } else { + if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) { + gpr_log(GPR_INFO, + "plugin_credentials[%p]: request %p: plugin returned " + "synchronously", + this, request); + } + *error = process_plugin_result(request, creds_md, num_creds_md, status, + error_details); + } + // Clean up. + for (size_t i = 0; i < num_creds_md; ++i) { + grpc_slice_unref_internal(creds_md[i].key); + grpc_slice_unref_internal(creds_md[i].value); + } + gpr_free(const_cast(error_details)); + gpr_free(request); + } + return retval; +} + +void grpc_plugin_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error_handle error) { + gpr_mu_lock(&mu_); + for (pending_request* pending_request = pending_requests_; + pending_request != nullptr; pending_request = pending_request->next) { + if (pending_request->md_array == md_array) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) { + gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", this, + pending_request); + } + pending_request->cancelled = true; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + pending_request->on_request_metadata, + GRPC_ERROR_REF(error)); + pending_request_remove_locked(pending_request); + break; + } + } + gpr_mu_unlock(&mu_); + GRPC_ERROR_UNREF(error); +} + +grpc_plugin_credentials::grpc_plugin_credentials( + grpc_metadata_credentials_plugin plugin, + grpc_security_level min_security_level) + : grpc_call_credentials(plugin.type, min_security_level), plugin_(plugin) { + gpr_mu_init(&mu_); +} + +grpc_call_credentials* grpc_metadata_credentials_create_from_plugin( + grpc_metadata_credentials_plugin plugin, + grpc_security_level min_security_level, void* reserved) { + GRPC_API_TRACE("grpc_metadata_credentials_create_from_plugin(reserved=%p)", 1, + (reserved)); + GPR_ASSERT(reserved == nullptr); + return new grpc_plugin_credentials(plugin, min_security_level); +} diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc new file mode 100644 index 00000000..093f1ad0 --- /dev/null +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -0,0 +1,385 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/lib/security/credentials/ssl/ssl_credentials.h"
+
+#include <string.h>
+
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/string_util.h>
+
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/surface/api_trace.h"
+#include "src/core/tsi/ssl_transport_security.h"
+
+//
+// SSL Channel Credentials.
+//
+
+void grpc_tsi_ssl_pem_key_cert_pairs_destroy(tsi_ssl_pem_key_cert_pair* kp,
+                                             size_t num_key_cert_pairs) {
+  if (kp == nullptr) return;
+  for (size_t i = 0; i < num_key_cert_pairs; i++) {
+    gpr_free(const_cast<char*>(kp[i].private_key));
+    gpr_free(const_cast<char*>(kp[i].cert_chain));
+  }
+  gpr_free(kp);
+}
+
+grpc_ssl_credentials::grpc_ssl_credentials(
+    const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
+    const grpc_ssl_verify_peer_options* verify_options)
+    : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) {
+  build_config(pem_root_certs, pem_key_cert_pair, verify_options);
+}
+
+grpc_ssl_credentials::~grpc_ssl_credentials() {
+  gpr_free(config_.pem_root_certs);
+  grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pair, 1);
+  if (config_.verify_options.verify_peer_destruct != nullptr) {
+    config_.verify_options.verify_peer_destruct(
+        config_.verify_options.verify_peer_callback_userdata);
+  }
+}
+
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_ssl_credentials::create_security_connector(
+    grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
+    const char* target, const grpc_channel_args* args,
+    grpc_channel_args** new_args) {
+  const char* overridden_target_name = nullptr;
+  tsi_ssl_session_cache* ssl_session_cache = nullptr;
+  for (size_t i = 0; args && i < args->num_args; i++) {
+    grpc_arg* arg = &args->args[i];
+    if (strcmp(arg->key, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG) == 0 &&
+        arg->type == GRPC_ARG_STRING) {
+      overridden_target_name = arg->value.string;
+    }
+    if (strcmp(arg->key, GRPC_SSL_SESSION_CACHE_ARG) == 0 &&
+        arg->type == GRPC_ARG_POINTER) {
+      ssl_session_cache =
+          static_cast<tsi_ssl_session_cache*>(arg->value.pointer.p);
+    }
+  }
+  grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
+      grpc_ssl_channel_security_connector_create(
+          this->Ref(), std::move(call_creds), &config_, target,
+          overridden_target_name, ssl_session_cache);
+  if (sc == nullptr) {
+    return sc;
+  }
+  grpc_arg new_arg = grpc_channel_arg_string_create(
+      const_cast<char*>(GRPC_ARG_HTTP2_SCHEME), const_cast<char*>("https"));
+  *new_args = grpc_channel_args_copy_and_add(args, &new_arg, 1);
+  return sc;
+}
+
+void grpc_ssl_credentials::build_config(
+    const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
+    const grpc_ssl_verify_peer_options* verify_options) {
+  config_.pem_root_certs = gpr_strdup(pem_root_certs);
+  if (pem_key_cert_pair != nullptr) {
+    GPR_ASSERT(pem_key_cert_pair->private_key != nullptr);
+    GPR_ASSERT(pem_key_cert_pair->cert_chain != nullptr);
+    config_.pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>(
+        gpr_zalloc(sizeof(tsi_ssl_pem_key_cert_pair)));
+    config_.pem_key_cert_pair->cert_chain =
+        gpr_strdup(pem_key_cert_pair->cert_chain);
config_.pem_key_cert_pair->private_key = + gpr_strdup(pem_key_cert_pair->private_key); + } else { + config_.pem_key_cert_pair = nullptr; + } + if (verify_options != nullptr) { + memcpy(&config_.verify_options, verify_options, + sizeof(verify_peer_options)); + } else { + // Otherwise set all options to default values + memset(&config_.verify_options, 0, sizeof(verify_peer_options)); + } +} + +void grpc_ssl_credentials::set_min_tls_version( + grpc_tls_version min_tls_version) { + config_.min_tls_version = min_tls_version; +} + +void grpc_ssl_credentials::set_max_tls_version( + grpc_tls_version max_tls_version) { + config_.max_tls_version = max_tls_version; +} + +/* Deprecated in favor of grpc_ssl_credentials_create_ex. Will be removed + * once all of its call sites are migrated to grpc_ssl_credentials_create_ex. */ +grpc_channel_credentials* grpc_ssl_credentials_create( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options, void* reserved) { + GRPC_API_TRACE( + "grpc_ssl_credentials_create(pem_root_certs=%s, " + "pem_key_cert_pair=%p, " + "verify_options=%p, " + "reserved=%p)", + 4, (pem_root_certs, pem_key_cert_pair, verify_options, reserved)); + GPR_ASSERT(reserved == nullptr); + + return new grpc_ssl_credentials( + pem_root_certs, pem_key_cert_pair, + reinterpret_cast(verify_options)); +} + +grpc_channel_credentials* grpc_ssl_credentials_create_ex( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const grpc_ssl_verify_peer_options* verify_options, void* reserved) { + GRPC_API_TRACE( + "grpc_ssl_credentials_create(pem_root_certs=%s, " + "pem_key_cert_pair=%p, " + "verify_options=%p, " + "reserved=%p)", + 4, (pem_root_certs, pem_key_cert_pair, verify_options, reserved)); + GPR_ASSERT(reserved == nullptr); + + return new grpc_ssl_credentials(pem_root_certs, pem_key_cert_pair, + verify_options); +} + +// +// SSL Server Credentials. 
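For context, a brief server-side sketch (illustrative only; the PEM buffers are assumed to be loaded elsewhere) of how the credentials built below are typically used:

#include <grpc/grpc.h>
#include <grpc/grpc_security.h>

// Sketch only: one key/cert pair, client certificates required and verified.
grpc_server_credentials* make_server_credentials(const char* root_certs_pem,
                                                 const char* server_key_pem,
                                                 const char* server_cert_pem) {
  grpc_ssl_pem_key_cert_pair pair = {server_key_pem, server_cert_pem};
  return grpc_ssl_server_credentials_create_ex(
      root_certs_pem, &pair, /*num_key_cert_pairs=*/1,
      GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY,
      /*reserved=*/nullptr);
}

The returned object is then passed to grpc_server_add_secure_http2_port() and released with grpc_server_credentials_release() once the server holds its own reference.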
+// + +struct grpc_ssl_server_credentials_options { + grpc_ssl_client_certificate_request_type client_certificate_request; + grpc_ssl_server_certificate_config* certificate_config; + grpc_ssl_server_certificate_config_fetcher* certificate_config_fetcher; +}; + +grpc_ssl_server_credentials::grpc_ssl_server_credentials( + const grpc_ssl_server_credentials_options& options) + : grpc_server_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) { + if (options.certificate_config_fetcher != nullptr) { + config_.client_certificate_request = options.client_certificate_request; + certificate_config_fetcher_ = *options.certificate_config_fetcher; + } else { + build_config(options.certificate_config->pem_root_certs, + options.certificate_config->pem_key_cert_pairs, + options.certificate_config->num_key_cert_pairs, + options.client_certificate_request); + } +} + +grpc_ssl_server_credentials::~grpc_ssl_server_credentials() { + grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pairs, + config_.num_key_cert_pairs); + gpr_free(config_.pem_root_certs); +} +grpc_core::RefCountedPtr +grpc_ssl_server_credentials::create_security_connector( + const grpc_channel_args* /* args */) { + return grpc_ssl_server_security_connector_create(this->Ref()); +} + +tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( + const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, + size_t num_key_cert_pairs) { + tsi_ssl_pem_key_cert_pair* tsi_pairs = nullptr; + if (num_key_cert_pairs > 0) { + GPR_ASSERT(pem_key_cert_pairs != nullptr); + tsi_pairs = static_cast( + gpr_zalloc(num_key_cert_pairs * sizeof(tsi_ssl_pem_key_cert_pair))); + } + for (size_t i = 0; i < num_key_cert_pairs; i++) { + GPR_ASSERT(pem_key_cert_pairs[i].private_key != nullptr); + GPR_ASSERT(pem_key_cert_pairs[i].cert_chain != nullptr); + tsi_pairs[i].cert_chain = gpr_strdup(pem_key_cert_pairs[i].cert_chain); + tsi_pairs[i].private_key = gpr_strdup(pem_key_cert_pairs[i].private_key); + } + return tsi_pairs; +} + +void grpc_ssl_server_credentials::build_config( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, + size_t num_key_cert_pairs, + grpc_ssl_client_certificate_request_type client_certificate_request) { + config_.client_certificate_request = client_certificate_request; + config_.pem_root_certs = gpr_strdup(pem_root_certs); + config_.pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + pem_key_cert_pairs, num_key_cert_pairs); + config_.num_key_cert_pairs = num_key_cert_pairs; +} + +void grpc_ssl_server_credentials::set_min_tls_version( + grpc_tls_version min_tls_version) { + config_.min_tls_version = min_tls_version; +} + +void grpc_ssl_server_credentials::set_max_tls_version( + grpc_tls_version max_tls_version) { + config_.max_tls_version = max_tls_version; +} + +grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create( + const char* pem_root_certs, + const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, + size_t num_key_cert_pairs) { + grpc_ssl_server_certificate_config* config = + static_cast( + gpr_zalloc(sizeof(grpc_ssl_server_certificate_config))); + config->pem_root_certs = gpr_strdup(pem_root_certs); + if (num_key_cert_pairs > 0) { + GPR_ASSERT(pem_key_cert_pairs != nullptr); + config->pem_key_cert_pairs = static_cast( + gpr_zalloc(num_key_cert_pairs * sizeof(grpc_ssl_pem_key_cert_pair))); + } + config->num_key_cert_pairs = num_key_cert_pairs; + for (size_t i = 0; i < num_key_cert_pairs; i++) { + GPR_ASSERT(pem_key_cert_pairs[i].private_key != nullptr); + 
GPR_ASSERT(pem_key_cert_pairs[i].cert_chain != nullptr); + config->pem_key_cert_pairs[i].cert_chain = + gpr_strdup(pem_key_cert_pairs[i].cert_chain); + config->pem_key_cert_pairs[i].private_key = + gpr_strdup(pem_key_cert_pairs[i].private_key); + } + return config; +} + +void grpc_ssl_server_certificate_config_destroy( + grpc_ssl_server_certificate_config* config) { + if (config == nullptr) return; + for (size_t i = 0; i < config->num_key_cert_pairs; i++) { + gpr_free(const_cast(config->pem_key_cert_pairs[i].private_key)); + gpr_free(const_cast(config->pem_key_cert_pairs[i].cert_chain)); + } + gpr_free(config->pem_key_cert_pairs); + gpr_free(config->pem_root_certs); + gpr_free(config); +} + +grpc_ssl_server_credentials_options* +grpc_ssl_server_credentials_create_options_using_config( + grpc_ssl_client_certificate_request_type client_certificate_request, + grpc_ssl_server_certificate_config* config) { + grpc_ssl_server_credentials_options* options = nullptr; + if (config == nullptr) { + gpr_log(GPR_ERROR, "Certificate config must not be NULL."); + goto done; + } + options = static_cast( + gpr_zalloc(sizeof(grpc_ssl_server_credentials_options))); + options->client_certificate_request = client_certificate_request; + options->certificate_config = config; +done: + return options; +} + +grpc_ssl_server_credentials_options* +grpc_ssl_server_credentials_create_options_using_config_fetcher( + grpc_ssl_client_certificate_request_type client_certificate_request, + grpc_ssl_server_certificate_config_callback cb, void* user_data) { + if (cb == nullptr) { + gpr_log(GPR_ERROR, "Invalid certificate config callback parameter."); + return nullptr; + } + + grpc_ssl_server_certificate_config_fetcher* fetcher = + static_cast( + gpr_zalloc(sizeof(grpc_ssl_server_certificate_config_fetcher))); + fetcher->cb = cb; + fetcher->user_data = user_data; + + grpc_ssl_server_credentials_options* options = + static_cast( + gpr_zalloc(sizeof(grpc_ssl_server_credentials_options))); + options->client_certificate_request = client_certificate_request; + options->certificate_config_fetcher = fetcher; + + return options; +} + +grpc_server_credentials* grpc_ssl_server_credentials_create( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, + size_t num_key_cert_pairs, int force_client_auth, void* reserved) { + return grpc_ssl_server_credentials_create_ex( + pem_root_certs, pem_key_cert_pairs, num_key_cert_pairs, + force_client_auth + ? 
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY + : GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, + reserved); +} + +grpc_server_credentials* grpc_ssl_server_credentials_create_ex( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, + size_t num_key_cert_pairs, + grpc_ssl_client_certificate_request_type client_certificate_request, + void* reserved) { + GRPC_API_TRACE( + "grpc_ssl_server_credentials_create_ex(" + "pem_root_certs=%s, pem_key_cert_pairs=%p, num_key_cert_pairs=%lu, " + "client_certificate_request=%d, reserved=%p)", + 5, + (pem_root_certs, pem_key_cert_pairs, (unsigned long)num_key_cert_pairs, + client_certificate_request, reserved)); + GPR_ASSERT(reserved == nullptr); + + grpc_ssl_server_certificate_config* cert_config = + grpc_ssl_server_certificate_config_create( + pem_root_certs, pem_key_cert_pairs, num_key_cert_pairs); + grpc_ssl_server_credentials_options* options = + grpc_ssl_server_credentials_create_options_using_config( + client_certificate_request, cert_config); + + return grpc_ssl_server_credentials_create_with_options(options); +} + +grpc_server_credentials* grpc_ssl_server_credentials_create_with_options( + grpc_ssl_server_credentials_options* options) { + grpc_server_credentials* retval = nullptr; + + if (options == nullptr) { + gpr_log(GPR_ERROR, + "Invalid options trying to create SSL server credentials."); + goto done; + } + + if (options->certificate_config == nullptr && + options->certificate_config_fetcher == nullptr) { + gpr_log(GPR_ERROR, + "SSL server credentials options must specify either " + "certificate config or fetcher."); + goto done; + } else if (options->certificate_config_fetcher != nullptr && + options->certificate_config_fetcher->cb == nullptr) { + gpr_log(GPR_ERROR, "Certificate config fetcher callback must not be NULL."); + goto done; + } + + retval = new grpc_ssl_server_credentials(*options); + +done: + grpc_ssl_server_credentials_options_destroy(options); + return retval; +} + +void grpc_ssl_server_credentials_options_destroy( + grpc_ssl_server_credentials_options* o) { + if (o == nullptr) return; + gpr_free(o->certificate_config_fetcher); + grpc_ssl_server_certificate_config_destroy(o->certificate_config); + gpr_free(o); +} diff --git a/src/core/lib/security/credentials/tls/grpc_tls_certificate_distributor.cc b/src/core/lib/security/credentials/tls/grpc_tls_certificate_distributor.cc new file mode 100644 index 00000000..5fa17199 --- /dev/null +++ b/src/core/lib/security/credentials/tls/grpc_tls_certificate_distributor.cc @@ -0,0 +1,348 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_distributor.h" + +#include +#include + +#include +#include +#include + +void grpc_tls_certificate_distributor::SetKeyMaterials( + const std::string& cert_name, absl::optional pem_root_certs, + absl::optional pem_key_cert_pairs) { + GPR_ASSERT(pem_root_certs.has_value() || pem_key_cert_pairs.has_value()); + grpc_core::MutexLock lock(&mu_); + auto& cert_info = certificate_info_map_[cert_name]; + if (pem_root_certs.has_value()) { + // Successful credential updates will clear any pre-existing error. + cert_info.SetRootError(GRPC_ERROR_NONE); + for (auto* watcher_ptr : cert_info.root_cert_watchers) { + GPR_ASSERT(watcher_ptr != nullptr); + const auto watcher_it = watchers_.find(watcher_ptr); + GPR_ASSERT(watcher_it != watchers_.end()); + GPR_ASSERT(watcher_it->second.root_cert_name.has_value()); + absl::optional + pem_key_cert_pairs_to_report; + if (pem_key_cert_pairs.has_value() && + watcher_it->second.identity_cert_name == cert_name) { + pem_key_cert_pairs_to_report = pem_key_cert_pairs; + } else if (watcher_it->second.identity_cert_name.has_value()) { + auto& identity_cert_info = + certificate_info_map_[*watcher_it->second.identity_cert_name]; + if (!identity_cert_info.pem_key_cert_pairs.empty()) { + pem_key_cert_pairs_to_report = identity_cert_info.pem_key_cert_pairs; + } + } + watcher_ptr->OnCertificatesChanged( + pem_root_certs, std::move(pem_key_cert_pairs_to_report)); + } + cert_info.pem_root_certs = std::move(*pem_root_certs); + } + if (pem_key_cert_pairs.has_value()) { + // Successful credential updates will clear any pre-existing error. + cert_info.SetIdentityError(GRPC_ERROR_NONE); + for (const auto watcher_ptr : cert_info.identity_cert_watchers) { + GPR_ASSERT(watcher_ptr != nullptr); + const auto watcher_it = watchers_.find(watcher_ptr); + GPR_ASSERT(watcher_it != watchers_.end()); + GPR_ASSERT(watcher_it->second.identity_cert_name.has_value()); + absl::optional pem_root_certs_to_report; + if (pem_root_certs.has_value() && + watcher_it->second.root_cert_name == cert_name) { + // In this case, We've already sent the credential updates at the time + // when checking pem_root_certs, so we will skip here. 
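+        // (That earlier notification already included both the new root
+        // certificates and the new key-cert pairs for this watcher.)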
+ continue; + } else if (watcher_it->second.root_cert_name.has_value()) { + auto& root_cert_info = + certificate_info_map_[*watcher_it->second.root_cert_name]; + if (!root_cert_info.pem_root_certs.empty()) { + pem_root_certs_to_report = root_cert_info.pem_root_certs; + } + } + watcher_ptr->OnCertificatesChanged(pem_root_certs_to_report, + pem_key_cert_pairs); + } + cert_info.pem_key_cert_pairs = std::move(*pem_key_cert_pairs); + } +} + +bool grpc_tls_certificate_distributor::HasRootCerts( + const std::string& root_cert_name) { + grpc_core::MutexLock lock(&mu_); + const auto it = certificate_info_map_.find(root_cert_name); + return it != certificate_info_map_.end() && + !it->second.pem_root_certs.empty(); +}; + +bool grpc_tls_certificate_distributor::HasKeyCertPairs( + const std::string& identity_cert_name) { + grpc_core::MutexLock lock(&mu_); + const auto it = certificate_info_map_.find(identity_cert_name); + return it != certificate_info_map_.end() && + !it->second.pem_key_cert_pairs.empty(); +}; + +void grpc_tls_certificate_distributor::SetErrorForCert( + const std::string& cert_name, + absl::optional root_cert_error, + absl::optional identity_cert_error) { + GPR_ASSERT(root_cert_error.has_value() || identity_cert_error.has_value()); + grpc_core::MutexLock lock(&mu_); + CertificateInfo& cert_info = certificate_info_map_[cert_name]; + if (root_cert_error.has_value()) { + for (auto* watcher_ptr : cert_info.root_cert_watchers) { + GPR_ASSERT(watcher_ptr != nullptr); + const auto watcher_it = watchers_.find(watcher_ptr); + GPR_ASSERT(watcher_it != watchers_.end()); + // identity_cert_error_to_report is the error of the identity cert this + // watcher is watching, if there is any. + grpc_error_handle identity_cert_error_to_report = GRPC_ERROR_NONE; + if (identity_cert_error.has_value() && + watcher_it->second.identity_cert_name == cert_name) { + identity_cert_error_to_report = *identity_cert_error; + } else if (watcher_it->second.identity_cert_name.has_value()) { + auto& identity_cert_info = + certificate_info_map_[*watcher_it->second.identity_cert_name]; + identity_cert_error_to_report = identity_cert_info.identity_cert_error; + } + watcher_ptr->OnError(GRPC_ERROR_REF(*root_cert_error), + GRPC_ERROR_REF(identity_cert_error_to_report)); + } + cert_info.SetRootError(*root_cert_error); + } + if (identity_cert_error.has_value()) { + for (auto* watcher_ptr : cert_info.identity_cert_watchers) { + GPR_ASSERT(watcher_ptr != nullptr); + const auto watcher_it = watchers_.find(watcher_ptr); + GPR_ASSERT(watcher_it != watchers_.end()); + // root_cert_error_to_report is the error of the root cert this watcher is + // watching, if there is any. + grpc_error_handle root_cert_error_to_report = GRPC_ERROR_NONE; + if (root_cert_error.has_value() && + watcher_it->second.root_cert_name == cert_name) { + // In this case, We've already sent the error updates at the time when + // checking root_cert_error, so we will skip here. 
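+        // (That earlier OnError call already carried both the root and the
+        // identity errors for this watcher.)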
+ continue; + } else if (watcher_it->second.root_cert_name.has_value()) { + auto& root_cert_info = + certificate_info_map_[*watcher_it->second.root_cert_name]; + root_cert_error_to_report = root_cert_info.root_cert_error; + } + watcher_ptr->OnError(GRPC_ERROR_REF(root_cert_error_to_report), + GRPC_ERROR_REF(*identity_cert_error)); + } + cert_info.SetIdentityError(*identity_cert_error); + } +}; + +void grpc_tls_certificate_distributor::SetError(grpc_error_handle error) { + GPR_ASSERT(error != GRPC_ERROR_NONE); + grpc_core::MutexLock lock(&mu_); + for (const auto& watcher : watchers_) { + const auto watcher_ptr = watcher.first; + GPR_ASSERT(watcher_ptr != nullptr); + const auto& watcher_info = watcher.second; + watcher_ptr->OnError( + watcher_info.root_cert_name.has_value() ? GRPC_ERROR_REF(error) + : GRPC_ERROR_NONE, + watcher_info.identity_cert_name.has_value() ? GRPC_ERROR_REF(error) + : GRPC_ERROR_NONE); + } + for (auto& cert_info_entry : certificate_info_map_) { + auto& cert_info = cert_info_entry.second; + cert_info.SetRootError(GRPC_ERROR_REF(error)); + cert_info.SetIdentityError(GRPC_ERROR_REF(error)); + } + GRPC_ERROR_UNREF(error); +}; + +void grpc_tls_certificate_distributor::WatchTlsCertificates( + std::unique_ptr watcher, + absl::optional root_cert_name, + absl::optional identity_cert_name) { + bool start_watching_root_cert = false; + bool already_watching_identity_for_root_cert = false; + bool start_watching_identity_cert = false; + bool already_watching_root_for_identity_cert = false; + GPR_ASSERT(root_cert_name.has_value() || identity_cert_name.has_value()); + TlsCertificatesWatcherInterface* watcher_ptr = watcher.get(); + GPR_ASSERT(watcher_ptr != nullptr); + // Update watchers_ and certificate_info_map_. + { + grpc_core::MutexLock lock(&mu_); + const auto watcher_it = watchers_.find(watcher_ptr); + // The caller needs to cancel the watcher first if it wants to re-register + // the watcher. + GPR_ASSERT(watcher_it == watchers_.end()); + watchers_[watcher_ptr] = {std::move(watcher), root_cert_name, + identity_cert_name}; + absl::optional updated_root_certs; + absl::optional updated_identity_pairs; + grpc_error_handle root_error = GRPC_ERROR_NONE; + grpc_error_handle identity_error = GRPC_ERROR_NONE; + if (root_cert_name.has_value()) { + CertificateInfo& cert_info = certificate_info_map_[*root_cert_name]; + start_watching_root_cert = cert_info.root_cert_watchers.empty(); + already_watching_identity_for_root_cert = + !cert_info.identity_cert_watchers.empty(); + cert_info.root_cert_watchers.insert(watcher_ptr); + root_error = GRPC_ERROR_REF(cert_info.root_cert_error); + // Empty credentials will be treated as no updates. + if (!cert_info.pem_root_certs.empty()) { + updated_root_certs = cert_info.pem_root_certs; + } + } + if (identity_cert_name.has_value()) { + CertificateInfo& cert_info = certificate_info_map_[*identity_cert_name]; + start_watching_identity_cert = cert_info.identity_cert_watchers.empty(); + already_watching_root_for_identity_cert = + !cert_info.root_cert_watchers.empty(); + cert_info.identity_cert_watchers.insert(watcher_ptr); + identity_error = GRPC_ERROR_REF(cert_info.identity_cert_error); + // Empty credentials will be treated as no updates. + if (!cert_info.pem_key_cert_pairs.empty()) { + updated_identity_pairs = cert_info.pem_key_cert_pairs; + } + } + // Notify this watcher if the certs it is watching already had some + // contents. 
Note that an *_cert_error in cert_info only indicates error + // occurred while trying to fetch the latest cert, but the updated_*_certs + // should always be valid. So we will send the updates regardless of + // *_cert_error. + if (updated_root_certs.has_value() || updated_identity_pairs.has_value()) { + watcher_ptr->OnCertificatesChanged(updated_root_certs, + std::move(updated_identity_pairs)); + } + // Notify this watcher if the certs it is watching already had some errors. + if (root_error != GRPC_ERROR_NONE || identity_error != GRPC_ERROR_NONE) { + watcher_ptr->OnError(GRPC_ERROR_REF(root_error), + GRPC_ERROR_REF(identity_error)); + } + GRPC_ERROR_UNREF(root_error); + GRPC_ERROR_UNREF(identity_error); + } + // Invoke watch status callback if needed. + { + grpc_core::MutexLock lock(&callback_mu_); + if (watch_status_callback_ != nullptr) { + if (root_cert_name == identity_cert_name && + (start_watching_root_cert || start_watching_identity_cert)) { + watch_status_callback_(*root_cert_name, start_watching_root_cert, + start_watching_identity_cert); + } else { + if (start_watching_root_cert) { + watch_status_callback_(*root_cert_name, true, + already_watching_identity_for_root_cert); + } + if (start_watching_identity_cert) { + watch_status_callback_(*identity_cert_name, + already_watching_root_for_identity_cert, true); + } + } + } + } +}; + +void grpc_tls_certificate_distributor::CancelTlsCertificatesWatch( + TlsCertificatesWatcherInterface* watcher) { + absl::optional root_cert_name; + absl::optional identity_cert_name; + bool stop_watching_root_cert = false; + bool already_watching_identity_for_root_cert = false; + bool stop_watching_identity_cert = false; + bool already_watching_root_for_identity_cert = false; + // Update watchers_ and certificate_info_map_. + { + grpc_core::MutexLock lock(&mu_); + auto it = watchers_.find(watcher); + if (it == watchers_.end()) return; + WatcherInfo& watcher_info = it->second; + root_cert_name = std::move(watcher_info.root_cert_name); + identity_cert_name = std::move(watcher_info.identity_cert_name); + watchers_.erase(it); + if (root_cert_name.has_value()) { + auto it = certificate_info_map_.find(*root_cert_name); + GPR_ASSERT(it != certificate_info_map_.end()); + CertificateInfo& cert_info = it->second; + cert_info.root_cert_watchers.erase(watcher); + stop_watching_root_cert = cert_info.root_cert_watchers.empty(); + already_watching_identity_for_root_cert = + !cert_info.identity_cert_watchers.empty(); + if (stop_watching_root_cert && !already_watching_identity_for_root_cert) { + certificate_info_map_.erase(it); + } + } + if (identity_cert_name.has_value()) { + auto it = certificate_info_map_.find(*identity_cert_name); + GPR_ASSERT(it != certificate_info_map_.end()); + CertificateInfo& cert_info = it->second; + cert_info.identity_cert_watchers.erase(watcher); + stop_watching_identity_cert = cert_info.identity_cert_watchers.empty(); + already_watching_root_for_identity_cert = + !cert_info.root_cert_watchers.empty(); + if (stop_watching_identity_cert && + !already_watching_root_for_identity_cert) { + certificate_info_map_.erase(it); + } + } + } + // Invoke watch status callback if needed. 
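+  // The callback reports, for each affected certificate name, whether its
+  // root and/or identity certificates still have at least one watcher, so the
+  // provider can stop fetching credentials that nobody consumes.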
+ { + grpc_core::MutexLock lock(&callback_mu_); + if (watch_status_callback_ != nullptr) { + if (root_cert_name == identity_cert_name && + (stop_watching_root_cert || stop_watching_identity_cert)) { + watch_status_callback_(*root_cert_name, !stop_watching_root_cert, + !stop_watching_identity_cert); + } else { + if (stop_watching_root_cert) { + watch_status_callback_(*root_cert_name, false, + already_watching_identity_for_root_cert); + } + if (stop_watching_identity_cert) { + watch_status_callback_(*identity_cert_name, + already_watching_root_for_identity_cert, + false); + } + } + } + } +}; + +/** -- Wrapper APIs declared in grpc_security.h -- **/ + +grpc_tls_identity_pairs* grpc_tls_identity_pairs_create() { + return new grpc_tls_identity_pairs(); +} + +void grpc_tls_identity_pairs_add_pair(grpc_tls_identity_pairs* pairs, + const char* private_key, + const char* cert_chain) { + GPR_ASSERT(pairs != nullptr); + GPR_ASSERT(private_key != nullptr); + GPR_ASSERT(cert_chain != nullptr); + pairs->pem_key_cert_pairs.emplace_back(private_key, cert_chain); +} + +void grpc_tls_identity_pairs_destroy(grpc_tls_identity_pairs* pairs) { + GPR_ASSERT(pairs != nullptr); + delete pairs; +} diff --git a/src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.cc b/src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.cc new file mode 100644 index 00000000..65357eaa --- /dev/null +++ b/src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.cc @@ -0,0 +1,455 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.h" + +#include + +#include +#include +#include + +#include "src/core/lib/gprpp/stat.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" + +namespace grpc_core { + +StaticDataCertificateProvider::StaticDataCertificateProvider( + std::string root_certificate, + grpc_core::PemKeyCertPairList pem_key_cert_pairs) + : distributor_(MakeRefCounted()), + root_certificate_(std::move(root_certificate)), + pem_key_cert_pairs_(std::move(pem_key_cert_pairs)) { + distributor_->SetWatchStatusCallback([this](std::string cert_name, + bool root_being_watched, + bool identity_being_watched) { + grpc_core::MutexLock lock(&mu_); + absl::optional root_certificate; + absl::optional pem_key_cert_pairs; + StaticDataCertificateProvider::WatcherInfo& info = watcher_info_[cert_name]; + if (!info.root_being_watched && root_being_watched && + !root_certificate_.empty()) { + root_certificate = root_certificate_; + } + info.root_being_watched = root_being_watched; + if (!info.identity_being_watched && identity_being_watched && + !pem_key_cert_pairs_.empty()) { + pem_key_cert_pairs = pem_key_cert_pairs_; + } + info.identity_being_watched = identity_being_watched; + if (!info.root_being_watched && !info.identity_being_watched) { + watcher_info_.erase(cert_name); + } + const bool root_has_update = root_certificate.has_value(); + const bool identity_has_update = pem_key_cert_pairs.has_value(); + if (root_has_update || identity_has_update) { + distributor_->SetKeyMaterials(cert_name, std::move(root_certificate), + std::move(pem_key_cert_pairs)); + } + grpc_error_handle root_cert_error = GRPC_ERROR_NONE; + grpc_error_handle identity_cert_error = GRPC_ERROR_NONE; + if (root_being_watched && !root_has_update) { + root_cert_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unable to get latest root certificates."); + } + if (identity_being_watched && !identity_has_update) { + identity_cert_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unable to get latest identity certificates."); + } + if (root_cert_error != GRPC_ERROR_NONE || + identity_cert_error != GRPC_ERROR_NONE) { + distributor_->SetErrorForCert(cert_name, root_cert_error, + identity_cert_error); + } + }); +} + +StaticDataCertificateProvider::~StaticDataCertificateProvider() { + // Reset distributor's callback to make sure the callback won't be invoked + // again after this object(provider) is destroyed. + distributor_->SetWatchStatusCallback(nullptr); +} + +namespace { + +gpr_timespec TimeoutSecondsToDeadline(int64_t seconds) { + return gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(seconds, GPR_TIMESPAN)); +} + +} // namespace + +FileWatcherCertificateProvider::FileWatcherCertificateProvider( + std::string private_key_path, std::string identity_certificate_path, + std::string root_cert_path, unsigned int refresh_interval_sec) + : private_key_path_(std::move(private_key_path)), + identity_certificate_path_(std::move(identity_certificate_path)), + root_cert_path_(std::move(root_cert_path)), + refresh_interval_sec_(refresh_interval_sec), + distributor_(MakeRefCounted()) { + // Private key and identity cert files must be both set or both unset. + GPR_ASSERT(private_key_path_.empty() == identity_certificate_path_.empty()); + // Must be watching either root or identity certs. 
+ GPR_ASSERT(!private_key_path_.empty() || !root_cert_path_.empty()); + gpr_event_init(&shutdown_event_); + ForceUpdate(); + auto thread_lambda = [](void* arg) { + FileWatcherCertificateProvider* provider = + static_cast(arg); + GPR_ASSERT(provider != nullptr); + while (true) { + void* value = gpr_event_wait( + &provider->shutdown_event_, + TimeoutSecondsToDeadline(provider->refresh_interval_sec_)); + if (value != nullptr) { + return; + }; + provider->ForceUpdate(); + } + }; + refresh_thread_ = grpc_core::Thread( + "FileWatcherCertificateProvider_refreshing_thread", thread_lambda, this); + refresh_thread_.Start(); + distributor_->SetWatchStatusCallback([this](std::string cert_name, + bool root_being_watched, + bool identity_being_watched) { + grpc_core::MutexLock lock(&mu_); + absl::optional root_certificate; + absl::optional pem_key_cert_pairs; + FileWatcherCertificateProvider::WatcherInfo& info = + watcher_info_[cert_name]; + if (!info.root_being_watched && root_being_watched && + !root_certificate_.empty()) { + root_certificate = root_certificate_; + } + info.root_being_watched = root_being_watched; + if (!info.identity_being_watched && identity_being_watched && + !pem_key_cert_pairs_.empty()) { + pem_key_cert_pairs = pem_key_cert_pairs_; + } + info.identity_being_watched = identity_being_watched; + if (!info.root_being_watched && !info.identity_being_watched) { + watcher_info_.erase(cert_name); + } + ExecCtx exec_ctx; + if (root_certificate.has_value() || pem_key_cert_pairs.has_value()) { + distributor_->SetKeyMaterials(cert_name, root_certificate, + pem_key_cert_pairs); + } + grpc_error_handle root_cert_error = GRPC_ERROR_NONE; + grpc_error_handle identity_cert_error = GRPC_ERROR_NONE; + if (root_being_watched && !root_certificate.has_value()) { + root_cert_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unable to get latest root certificates."); + } + if (identity_being_watched && !pem_key_cert_pairs.has_value()) { + identity_cert_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unable to get latest identity certificates."); + } + if (root_cert_error != GRPC_ERROR_NONE || + identity_cert_error != GRPC_ERROR_NONE) { + distributor_->SetErrorForCert(cert_name, root_cert_error, + identity_cert_error); + } + }); +} + +FileWatcherCertificateProvider::~FileWatcherCertificateProvider() { + // Reset distributor's callback to make sure the callback won't be invoked + // again after this object(provider) is destroyed. 
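+  // Signalling the shutdown event below also wakes the refresh thread so it
+  // can be joined before the remaining members are destroyed.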
+ distributor_->SetWatchStatusCallback(nullptr); + gpr_event_set(&shutdown_event_, reinterpret_cast(1)); + refresh_thread_.Join(); +} + +void FileWatcherCertificateProvider::ForceUpdate() { + absl::optional root_certificate; + absl::optional pem_key_cert_pairs; + if (!root_cert_path_.empty()) { + root_certificate = ReadRootCertificatesFromFile(root_cert_path_); + } + if (!private_key_path_.empty()) { + pem_key_cert_pairs = ReadIdentityKeyCertPairFromFiles( + private_key_path_, identity_certificate_path_); + } + grpc_core::MutexLock lock(&mu_); + const bool root_cert_changed = + (!root_certificate.has_value() && !root_certificate_.empty()) || + (root_certificate.has_value() && root_certificate_ != *root_certificate); + if (root_cert_changed) { + if (root_certificate.has_value()) { + root_certificate_ = std::move(*root_certificate); + } else { + root_certificate_ = ""; + } + } + const bool identity_cert_changed = + (!pem_key_cert_pairs.has_value() && !pem_key_cert_pairs_.empty()) || + (pem_key_cert_pairs.has_value() && + pem_key_cert_pairs_ != *pem_key_cert_pairs); + if (identity_cert_changed) { + if (pem_key_cert_pairs.has_value()) { + pem_key_cert_pairs_ = std::move(*pem_key_cert_pairs); + } else { + pem_key_cert_pairs_ = {}; + } + } + if (root_cert_changed || identity_cert_changed) { + ExecCtx exec_ctx; + grpc_error_handle root_cert_error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unable to get latest root certificates."); + grpc_error_handle identity_cert_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Unable to get latest identity certificates."); + for (const auto& p : watcher_info_) { + const std::string& cert_name = p.first; + const WatcherInfo& info = p.second; + absl::optional root_to_report; + absl::optional identity_to_report; + // Set key materials to the distributor if their contents changed. + if (info.root_being_watched && !root_certificate_.empty() && + root_cert_changed) { + root_to_report = root_certificate_; + } + if (info.identity_being_watched && !pem_key_cert_pairs_.empty() && + identity_cert_changed) { + identity_to_report = pem_key_cert_pairs_; + } + if (root_to_report.has_value() || identity_to_report.has_value()) { + distributor_->SetKeyMaterials(cert_name, std::move(root_to_report), + std::move(identity_to_report)); + } + // Report errors to the distributor if the contents are empty. + const bool report_root_error = + info.root_being_watched && root_certificate_.empty(); + const bool report_identity_error = + info.identity_being_watched && pem_key_cert_pairs_.empty(); + if (report_root_error || report_identity_error) { + distributor_->SetErrorForCert( + cert_name, + report_root_error ? GRPC_ERROR_REF(root_cert_error) + : GRPC_ERROR_NONE, + report_identity_error ? GRPC_ERROR_REF(identity_cert_error) + : GRPC_ERROR_NONE); + } + } + GRPC_ERROR_UNREF(root_cert_error); + GRPC_ERROR_UNREF(identity_cert_error); + } +} + +absl::optional +FileWatcherCertificateProvider::ReadRootCertificatesFromFile( + const std::string& root_cert_full_path) { + // Read the root file. 
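+  // A failed read is logged and returned as absl::nullopt; ForceUpdate()
+  // treats the missing value as empty credentials.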
+ grpc_slice root_slice = grpc_empty_slice(); + grpc_error_handle root_error = + grpc_load_file(root_cert_full_path.c_str(), 0, &root_slice); + if (root_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Reading file %s failed: %s", + root_cert_full_path.c_str(), + grpc_error_std_string(root_error).c_str()); + GRPC_ERROR_UNREF(root_error); + return absl::nullopt; + } + std::string root_cert(StringViewFromSlice(root_slice)); + grpc_slice_unref_internal(root_slice); + return root_cert; +} + +namespace { + +// This helper function gets the last-modified time of |filename|. When failed, +// it logs the error and returns 0. +time_t GetModificationTime(const char* filename) { + time_t ts = 0; + absl::Status status = grpc_core::GetFileModificationTime(filename, &ts); + return ts; +} + +} // namespace + +absl::optional +FileWatcherCertificateProvider::ReadIdentityKeyCertPairFromFiles( + const std::string& private_key_path, + const std::string& identity_certificate_path) { + struct SliceWrapper { + grpc_slice slice = grpc_empty_slice(); + ~SliceWrapper() { grpc_slice_unref_internal(slice); } + }; + const int kNumRetryAttempts = 3; + for (int i = 0; i < kNumRetryAttempts; ++i) { + // TODO(ZhenLian): replace the timestamp approach with key-match approach + // once the latter is implemented. + // Checking the last modification of identity files before reading. + time_t identity_key_ts_before = + GetModificationTime(private_key_path.c_str()); + if (identity_key_ts_before == 0) { + gpr_log( + GPR_ERROR, + "Failed to get the file's modification time of %s. Start retrying...", + private_key_path.c_str()); + continue; + } + time_t identity_cert_ts_before = + GetModificationTime(identity_certificate_path.c_str()); + if (identity_cert_ts_before == 0) { + gpr_log( + GPR_ERROR, + "Failed to get the file's modification time of %s. Start retrying...", + identity_certificate_path.c_str()); + continue; + } + // Read the identity files. + SliceWrapper key_slice, cert_slice; + grpc_error_handle key_error = + grpc_load_file(private_key_path.c_str(), 0, &key_slice.slice); + if (key_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Reading file %s failed: %s. Start retrying...", + private_key_path.c_str(), + grpc_error_std_string(key_error).c_str()); + GRPC_ERROR_UNREF(key_error); + continue; + } + grpc_error_handle cert_error = + grpc_load_file(identity_certificate_path.c_str(), 0, &cert_slice.slice); + if (cert_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Reading file %s failed: %s. Start retrying...", + identity_certificate_path.c_str(), + grpc_error_std_string(cert_error).c_str()); + GRPC_ERROR_UNREF(cert_error); + continue; + } + std::string private_key(StringViewFromSlice(key_slice.slice)); + std::string cert_chain(StringViewFromSlice(cert_slice.slice)); + PemKeyCertPairList identity_pairs; + identity_pairs.emplace_back(private_key, cert_chain); + // Checking the last modification of identity files before reading. + time_t identity_key_ts_after = + GetModificationTime(private_key_path.c_str()); + if (identity_key_ts_before != identity_key_ts_after) { + gpr_log(GPR_ERROR, + "Last modified time before and after reading %s is not the same. " + "Start retrying...", + private_key_path.c_str()); + continue; + } + time_t identity_cert_ts_after = + GetModificationTime(identity_certificate_path.c_str()); + if (identity_cert_ts_before != identity_cert_ts_after) { + gpr_log(GPR_ERROR, + "Last modified time before and after reading %s is not the same. 
" + "Start retrying...", + identity_certificate_path.c_str()); + continue; + } + return identity_pairs; + } + gpr_log(GPR_ERROR, + "All retry attempts failed. Will try again after the next interval."); + return absl::nullopt; +} + +absl::StatusOr PrivateKeyAndCertificateMatch( + absl::string_view private_key, absl::string_view cert_chain) { + if (private_key.empty()) { + return absl::InvalidArgumentError("Private key string is empty."); + } + if (cert_chain.empty()) { + return absl::InvalidArgumentError("Certificate string is empty."); + } + BIO* cert_bio = BIO_new_mem_buf(cert_chain.data(), cert_chain.size()); + if (cert_bio == nullptr) { + return absl::InvalidArgumentError( + "Conversion from certificate string to BIO failed."); + } + // Reads the first cert from the cert_chain which is expected to be the leaf + // cert + X509* x509 = PEM_read_bio_X509(cert_bio, nullptr, nullptr, nullptr); + BIO_free(cert_bio); + if (x509 == nullptr) { + return absl::InvalidArgumentError( + "Conversion from PEM string to X509 failed."); + } + EVP_PKEY* public_evp_pkey = X509_get_pubkey(x509); + X509_free(x509); + if (public_evp_pkey == nullptr) { + return absl::InvalidArgumentError( + "Extraction of public key from x.509 certificate failed."); + } + BIO* private_key_bio = + BIO_new_mem_buf(private_key.data(), private_key.size()); + if (private_key_bio == nullptr) { + EVP_PKEY_free(public_evp_pkey); + return absl::InvalidArgumentError( + "Conversion from private key string to BIO failed."); + } + EVP_PKEY* private_evp_pkey = + PEM_read_bio_PrivateKey(private_key_bio, nullptr, nullptr, nullptr); + BIO_free(private_key_bio); + if (private_evp_pkey == nullptr) { + EVP_PKEY_free(public_evp_pkey); + return absl::InvalidArgumentError( + "Conversion from PEM string to EVP_PKEY failed."); + } + bool result = EVP_PKEY_cmp(private_evp_pkey, public_evp_pkey) == 1; + EVP_PKEY_free(private_evp_pkey); + EVP_PKEY_free(public_evp_pkey); + return result; +} + +} // namespace grpc_core + +/** -- Wrapper APIs declared in grpc_security.h -- **/ + +grpc_tls_certificate_provider* grpc_tls_certificate_provider_static_data_create( + const char* root_certificate, grpc_tls_identity_pairs* pem_key_cert_pairs) { + GPR_ASSERT(root_certificate != nullptr || pem_key_cert_pairs != nullptr); + grpc_core::ExecCtx exec_ctx; + grpc_core::PemKeyCertPairList identity_pairs_core; + if (pem_key_cert_pairs != nullptr) { + identity_pairs_core = std::move(pem_key_cert_pairs->pem_key_cert_pairs); + delete pem_key_cert_pairs; + } + std::string root_cert_core; + if (root_certificate != nullptr) { + root_cert_core = root_certificate; + } + return new grpc_core::StaticDataCertificateProvider( + std::move(root_cert_core), std::move(identity_pairs_core)); +} + +grpc_tls_certificate_provider* +grpc_tls_certificate_provider_file_watcher_create( + const char* private_key_path, const char* identity_certificate_path, + const char* root_cert_path, unsigned int refresh_interval_sec) { + grpc_core::ExecCtx exec_ctx; + return new grpc_core::FileWatcherCertificateProvider( + private_key_path == nullptr ? "" : private_key_path, + identity_certificate_path == nullptr ? "" : identity_certificate_path, + root_cert_path == nullptr ? 
"" : root_cert_path, refresh_interval_sec); +} + +void grpc_tls_certificate_provider_release( + grpc_tls_certificate_provider* provider) { + GRPC_API_TRACE("grpc_tls_certificate_provider_release(provider=%p)", 1, + (provider)); + grpc_core::ExecCtx exec_ctx; + if (provider != nullptr) provider->Unref(); +} diff --git a/src/core/lib/security/credentials/tls/grpc_tls_credentials_options.cc b/src/core/lib/security/credentials/tls/grpc_tls_credentials_options.cc new file mode 100644 index 00000000..bf26e707 --- /dev/null +++ b/src/core/lib/security/credentials/tls/grpc_tls_credentials_options.cc @@ -0,0 +1,177 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h" + +#include +#include + +#include +#include +#include + +#include "src/core/lib/surface/api_trace.h" + +/** -- gRPC TLS server authorization check API implementation. -- **/ +grpc_tls_server_authorization_check_config:: + grpc_tls_server_authorization_check_config( + const void* config_user_data, + int (*schedule)(void* config_user_data, + grpc_tls_server_authorization_check_arg* arg), + void (*cancel)(void* config_user_data, + grpc_tls_server_authorization_check_arg* arg), + void (*destruct)(void* config_user_data)) + : config_user_data_(const_cast(config_user_data)), + schedule_(schedule), + cancel_(cancel), + destruct_(destruct) {} + +grpc_tls_server_authorization_check_config:: + ~grpc_tls_server_authorization_check_config() { + if (destruct_ != nullptr) { + destruct_(config_user_data_); + } +} + +int grpc_tls_server_authorization_check_config::Schedule( + grpc_tls_server_authorization_check_arg* arg) const { + if (schedule_ == nullptr) { + gpr_log(GPR_ERROR, "schedule API is nullptr"); + if (arg != nullptr) { + arg->status = GRPC_STATUS_NOT_FOUND; + arg->error_details->set_error_details( + "schedule API in server authorization check config is nullptr"); + } + return 1; + } + if (arg != nullptr && context_ != nullptr) { + arg->config = const_cast(this); + } + return schedule_(config_user_data_, arg); +} + +void grpc_tls_server_authorization_check_config::Cancel( + grpc_tls_server_authorization_check_arg* arg) const { + if (cancel_ == nullptr) { + gpr_log(GPR_ERROR, "cancel API is nullptr."); + if (arg != nullptr) { + arg->status = GRPC_STATUS_NOT_FOUND; + arg->error_details->set_error_details( + "schedule API in server authorization check config is nullptr"); + } + return; + } + if (arg != nullptr) { + arg->config = const_cast(this); + } + cancel_(config_user_data_, arg); +} + +/** -- Wrapper APIs declared in grpc_security.h -- **/ + +grpc_tls_credentials_options* grpc_tls_credentials_options_create() { + grpc_core::ExecCtx exec_ctx; + return new grpc_tls_credentials_options(); +} + +void grpc_tls_credentials_options_set_cert_request_type( + grpc_tls_credentials_options* options, + grpc_ssl_client_certificate_request_type type) { + GPR_ASSERT(options != nullptr); + 
options->set_cert_request_type(type); +} + +void grpc_tls_credentials_options_set_server_verification_option( + grpc_tls_credentials_options* options, + grpc_tls_server_verification_option server_verification_option) { + GPR_ASSERT(options != nullptr); + options->set_server_verification_option(server_verification_option); +} + +void grpc_tls_credentials_options_set_certificate_provider( + grpc_tls_credentials_options* options, + grpc_tls_certificate_provider* provider) { + GPR_ASSERT(options != nullptr); + GPR_ASSERT(provider != nullptr); + grpc_core::ExecCtx exec_ctx; + options->set_certificate_provider( + provider->Ref(DEBUG_LOCATION, "set_certificate_provider")); +} + +void grpc_tls_credentials_options_watch_root_certs( + grpc_tls_credentials_options* options) { + GPR_ASSERT(options != nullptr); + options->set_watch_root_cert(true); +} + +void grpc_tls_credentials_options_set_root_cert_name( + grpc_tls_credentials_options* options, const char* root_cert_name) { + GPR_ASSERT(options != nullptr); + options->set_root_cert_name(root_cert_name); +} + +void grpc_tls_credentials_options_watch_identity_key_cert_pairs( + grpc_tls_credentials_options* options) { + GPR_ASSERT(options != nullptr); + options->set_watch_identity_pair(true); +} + +void grpc_tls_credentials_options_set_identity_cert_name( + grpc_tls_credentials_options* options, const char* identity_cert_name) { + GPR_ASSERT(options != nullptr); + options->set_identity_cert_name(identity_cert_name); +} + +void grpc_tls_credentials_options_set_server_authorization_check_config( + grpc_tls_credentials_options* options, + grpc_tls_server_authorization_check_config* config) { + GPR_ASSERT(options != nullptr); + GPR_ASSERT(config != nullptr); + grpc_core::ExecCtx exec_ctx; + options->set_server_authorization_check_config(config->Ref()); +} + +grpc_tls_server_authorization_check_config* +grpc_tls_server_authorization_check_config_create( + const void* config_user_data, + int (*schedule)(void* config_user_data, + grpc_tls_server_authorization_check_arg* arg), + void (*cancel)(void* config_user_data, + grpc_tls_server_authorization_check_arg* arg), + void (*destruct)(void* config_user_data)) { + if (schedule == nullptr) { + gpr_log(GPR_ERROR, + "Schedule API is nullptr in creating TLS server authorization " + "check config."); + return nullptr; + } + grpc_core::ExecCtx exec_ctx; + return new grpc_tls_server_authorization_check_config( + config_user_data, schedule, cancel, destruct); +} + +void grpc_tls_server_authorization_check_config_release( + grpc_tls_server_authorization_check_config* config) { + GRPC_API_TRACE( + "grpc_tls_server_authorization_check_config_release(config=%p)", 1, + (config)); + grpc_core::ExecCtx exec_ctx; + if (config != nullptr) config->Unref(); +} diff --git a/src/core/lib/security/credentials/tls/tls_credentials.cc b/src/core/lib/security/credentials/tls/tls_credentials.cc new file mode 100644 index 00000000..f5b05d8a --- /dev/null +++ b/src/core/lib/security/credentials/tls/tls_credentials.cc @@ -0,0 +1,133 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/tls/tls_credentials.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/security/security_connector/tls/tls_security_connector.h" + +#define GRPC_CREDENTIALS_TYPE_TLS "Tls" + +namespace { + +bool CredentialOptionSanityCheck(const grpc_tls_credentials_options* options, + bool is_client) { + if (options == nullptr) { + gpr_log(GPR_ERROR, "TLS credentials options is nullptr."); + return false; + } + // TODO(ZhenLian): remove this when it is also supported on server side. + if (!is_client && options->server_authorization_check_config() != nullptr) { + gpr_log(GPR_INFO, + "Server's credentials options should not contain server " + "authorization check config."); + } + if (options->server_verification_option() != GRPC_TLS_SERVER_VERIFICATION && + options->server_authorization_check_config() == nullptr) { + gpr_log(GPR_ERROR, + "Should provider custom verifications if bypassing default ones."); + return false; + } + return true; +} + +} // namespace + +TlsCredentials::TlsCredentials( + grpc_core::RefCountedPtr options) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_TLS), + options_(std::move(options)) {} + +TlsCredentials::~TlsCredentials() {} + +grpc_core::RefCountedPtr +TlsCredentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target_name, const grpc_channel_args* args, + grpc_channel_args** new_args) { + const char* overridden_target_name = nullptr; + tsi_ssl_session_cache* ssl_session_cache = nullptr; + for (size_t i = 0; args != nullptr && i < args->num_args; i++) { + grpc_arg* arg = &args->args[i]; + if (strcmp(arg->key, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG) == 0 && + arg->type == GRPC_ARG_STRING) { + overridden_target_name = arg->value.string; + } + if (strcmp(arg->key, GRPC_SSL_SESSION_CACHE_ARG) == 0 && + arg->type == GRPC_ARG_POINTER) { + ssl_session_cache = + static_cast(arg->value.pointer.p); + } + } + grpc_core::RefCountedPtr sc = + grpc_core::TlsChannelSecurityConnector::CreateTlsChannelSecurityConnector( + this->Ref(), options_, std::move(call_creds), target_name, + overridden_target_name, ssl_session_cache); + if (sc == nullptr) { + return nullptr; + } + if (args != nullptr) { + grpc_arg new_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_HTTP2_SCHEME), const_cast("https")); + *new_args = grpc_channel_args_copy_and_add(args, &new_arg, 1); + } + return sc; +} + +TlsServerCredentials::TlsServerCredentials( + grpc_core::RefCountedPtr options) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_TLS), + options_(std::move(options)) {} + +TlsServerCredentials::~TlsServerCredentials() {} + +grpc_core::RefCountedPtr +TlsServerCredentials::create_security_connector( + const grpc_channel_args* /* args */) { + return grpc_core::TlsServerSecurityConnector:: + CreateTlsServerSecurityConnector(this->Ref(), options_); +} + +/** -- Wrapper APIs declared in grpc_security.h -- **/ + +grpc_channel_credentials* grpc_tls_credentials_create( + grpc_tls_credentials_options* options) { + if (!CredentialOptionSanityCheck(options, true /* is_client */)) { + return nullptr; + } + return new TlsCredentials( + grpc_core::RefCountedPtr(options)); +} + +grpc_server_credentials* grpc_tls_server_credentials_create( + grpc_tls_credentials_options* options) { + if (!CredentialOptionSanityCheck(options, false /* 
is_client */)) { + return nullptr; + } + return new TlsServerCredentials( + grpc_core::RefCountedPtr(options)); +} diff --git a/src/core/lib/security/credentials/tls/tls_utils.cc b/src/core/lib/security/credentials/tls/tls_utils.cc new file mode 100644 index 00000000..2dcc1974 --- /dev/null +++ b/src/core/lib/security/credentials/tls/tls_utils.cc @@ -0,0 +1,123 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/lib/security/credentials/tls/tls_utils.h" + +#include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" + +namespace grpc_core { + +// Based on +// https://github.com/grpc/grpc-java/blob/ca12e7a339add0ef48202fb72434b9dc0df41756/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java#L62 +bool VerifySubjectAlternativeName(absl::string_view subject_alternative_name, + const std::string& matcher) { + if (subject_alternative_name.empty() || + absl::StartsWith(subject_alternative_name, ".")) { + // Illegal pattern/domain name + return false; + } + if (matcher.empty() || absl::StartsWith(matcher, ".")) { + // Illegal domain name + return false; + } + // Normalize \a subject_alternative_name and \a matcher by turning them into + // absolute domain names if they are not yet absolute. This is needed because + // server certificates do not normally contain absolute names or patterns, but + // they should be treated as absolute. At the same time, any + // subject_alternative_name presented to this method should also be treated as + // absolute for the purposes of matching to the server certificate. + std::string normalized_san = + absl::EndsWith(subject_alternative_name, ".") + ? std::string(subject_alternative_name) + : absl::StrCat(subject_alternative_name, "."); + std::string normalized_matcher = + absl::EndsWith(matcher, ".") ? matcher : absl::StrCat(matcher, "."); + absl::AsciiStrToLower(&normalized_san); + absl::AsciiStrToLower(&normalized_matcher); + if (!absl::StrContains(normalized_san, "*")) { + return normalized_san == normalized_matcher; + } + // WILDCARD PATTERN RULES: + // 1. Asterisk (*) is only permitted in the left-most domain name label and + // must be the only character in that label (i.e., must match the whole + // left-most label). For example, *.example.com is permitted, while + // *a.example.com, a*.example.com, a*b.example.com, a.*.example.com are + // not permitted. + // 2. Asterisk (*) cannot match across domain name labels. + // For example, *.example.com matches test.example.com but does not match + // sub.test.example.com. + // 3. Wildcard patterns for single-label domain names are not permitted. + if (!absl::StartsWith(normalized_san, "*.")) { + // Asterisk (*) is only permitted in the left-most domain name label and + // must be the only character in that label + return false; + } + if (normalized_san == "*.") { + // Wildcard pattern for single-label domain name -- not permitted. 
+ return false; + } + absl::string_view suffix = absl::string_view(normalized_san).substr(1); + if (absl::StrContains(suffix, "*")) { + // Asterisk (*) is not permitted in the suffix + return false; + } + if (!absl::EndsWith(normalized_matcher, suffix)) return false; + int suffix_start_index = normalized_matcher.length() - suffix.length(); + // Asterisk matching across domain labels is not permitted. + return suffix_start_index <= 0 /* should not happen */ || + normalized_matcher.find_last_of('.', suffix_start_index - 1) == + std::string::npos; +} + +absl::string_view GetAuthPropertyValue(grpc_auth_context* context, + const char* property_name) { + grpc_auth_property_iterator it = + grpc_auth_context_find_properties_by_name(context, property_name); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr) { + gpr_log(GPR_DEBUG, "No value found for %s property.", property_name); + return ""; + } + if (grpc_auth_property_iterator_next(&it) != nullptr) { + gpr_log(GPR_DEBUG, "Multiple values found for %s property.", property_name); + return ""; + } + return absl::string_view(prop->value, prop->value_length); +} + +std::vector GetAuthPropertyArray(grpc_auth_context* context, + const char* property_name) { + std::vector values; + grpc_auth_property_iterator it = + grpc_auth_context_find_properties_by_name(context, property_name); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + while (prop != nullptr) { + values.emplace_back(prop->value, prop->value_length); + prop = grpc_auth_property_iterator_next(&it); + } + if (values.empty()) { + gpr_log(GPR_DEBUG, "No value found for %s property.", property_name); + } + return values; +} + +} // namespace grpc_core diff --git a/src/core/lib/security/credentials/xds/xds_credentials.cc b/src/core/lib/security/credentials/xds/xds_credentials.cc new file mode 100644 index 00000000..15b8e900 --- /dev/null +++ b/src/core/lib/security/credentials/xds/xds_credentials.cc @@ -0,0 +1,244 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include "src/core/lib/security/credentials/xds/xds_credentials.h" + +#include "src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h" +#include "src/core/ext/xds/xds_certificate_provider.h" +#include "src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h" +#include "src/core/lib/security/credentials/tls/tls_credentials.h" +#include "src/core/lib/security/credentials/tls/tls_utils.h" +#include "src/core/lib/uri/uri_parser.h" + +namespace grpc_core { + +const char kCredentialsTypeXds[] = "Xds"; + +namespace { + +bool XdsVerifySubjectAlternativeNames( + const char* const* subject_alternative_names, + size_t subject_alternative_names_size, + const std::vector& matchers) { + if (matchers.empty()) return true; + for (size_t i = 0; i < subject_alternative_names_size; ++i) { + for (const auto& matcher : matchers) { + if (matcher.type() == StringMatcher::Type::kExact) { + // For Exact match, use DNS rules for verifying SANs + // TODO(zhenlian): Right now, the SSL layer does not save the type of + // the SAN, so we are doing a DNS style verification for all SANs when + // the type is EXACT. When we expose the SAN type, change this to only + // do this verification when the SAN type is DNS and match type is + // kExact. For all other cases, we should use matcher.Match(). + if (VerifySubjectAlternativeName(subject_alternative_names[i], + matcher.string_matcher())) { + return true; + } + } else { + if (matcher.Match(subject_alternative_names[i])) { + return true; + } + } + } + } + return false; +} + +class ServerAuthCheck { + public: + ServerAuthCheck( + RefCountedPtr xds_certificate_provider, + std::string cluster_name) + : xds_certificate_provider_(std::move(xds_certificate_provider)), + cluster_name_(std::move(cluster_name)) {} + + static int Schedule(void* config_user_data, + grpc_tls_server_authorization_check_arg* arg) { + return static_cast(config_user_data)->ScheduleImpl(arg); + } + + static void Destroy(void* config_user_data) { + delete static_cast(config_user_data); + } + + private: + int ScheduleImpl(grpc_tls_server_authorization_check_arg* arg) { + if (XdsVerifySubjectAlternativeNames( + arg->subject_alternative_names, arg->subject_alternative_names_size, + xds_certificate_provider_->GetSanMatchers(cluster_name_))) { + arg->success = 1; + arg->status = GRPC_STATUS_OK; + } else { + arg->success = 0; + arg->status = GRPC_STATUS_UNAUTHENTICATED; + if (arg->error_details) { + arg->error_details->set_error_details( + "SANs from certificate did not match SANs from xDS control plane"); + } + } + return 0; /* synchronous check */ + } + + RefCountedPtr xds_certificate_provider_; + std::string cluster_name_; +}; + +} // namespace + +bool TestOnlyXdsVerifySubjectAlternativeNames( + const char* const* subject_alternative_names, + size_t subject_alternative_names_size, + const std::vector& matchers) { + return XdsVerifySubjectAlternativeNames( + subject_alternative_names, subject_alternative_names_size, matchers); +} + +// +// XdsCredentials +// + +RefCountedPtr +XdsCredentials::create_security_connector( + RefCountedPtr call_creds, const char* target_name, + const grpc_channel_args* args, grpc_channel_args** new_args) { + struct ChannelArgsDeleter { + const grpc_channel_args* args; + bool owned; + ~ChannelArgsDeleter() { + if (owned) grpc_channel_args_destroy(args); + } + }; + ChannelArgsDeleter temp_args{args, false}; + // TODO(yashykt): This arg will no longer need to be added after b/173119596 + // is fixed. 
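+  // Until then, the channel target is installed as
+  // GRPC_SSL_TARGET_NAME_OVERRIDE_ARG below whenever the application has not
+  // already provided one.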
+ grpc_arg override_arg = grpc_channel_arg_string_create( + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + const_cast(target_name)); + const char* override_arg_name = GRPC_SSL_TARGET_NAME_OVERRIDE_ARG; + if (grpc_channel_args_find(args, override_arg_name) == nullptr) { + temp_args.args = grpc_channel_args_copy_and_add_and_remove( + args, &override_arg_name, 1, &override_arg, 1); + temp_args.owned = true; + } + RefCountedPtr security_connector; + auto xds_certificate_provider = + XdsCertificateProvider::GetFromChannelArgs(args); + if (xds_certificate_provider != nullptr) { + std::string cluster_name = + grpc_channel_args_find_string(args, GRPC_ARG_XDS_CLUSTER_NAME); + GPR_ASSERT(cluster_name.data() != nullptr); + const bool watch_root = + xds_certificate_provider->ProvidesRootCerts(cluster_name); + const bool watch_identity = + xds_certificate_provider->ProvidesIdentityCerts(cluster_name); + if (watch_root || watch_identity) { + auto tls_credentials_options = + MakeRefCounted(); + tls_credentials_options->set_certificate_provider( + xds_certificate_provider); + if (watch_root) { + tls_credentials_options->set_watch_root_cert(true); + tls_credentials_options->set_root_cert_name(cluster_name); + } + if (watch_identity) { + tls_credentials_options->set_watch_identity_pair(true); + tls_credentials_options->set_identity_cert_name(cluster_name); + } + tls_credentials_options->set_server_verification_option( + GRPC_TLS_SKIP_HOSTNAME_VERIFICATION); + auto* server_auth_check = new ServerAuthCheck(xds_certificate_provider, + std::move(cluster_name)); + tls_credentials_options->set_server_authorization_check_config( + MakeRefCounted( + server_auth_check, ServerAuthCheck::Schedule, nullptr, + ServerAuthCheck::Destroy)); + // TODO(yashkt): Creating a new TlsCreds object each time we create a + // security connector means that the security connector's cmp() method + // returns unequal for each instance, which means that every time an LB + // policy updates, all the subchannels will be recreated. This is + // going to lead to a lot of connection churn. Instead, we should + // either (a) change the TLS security connector's cmp() method to be + // smarter somehow, so that it compares unequal only when the + // tls_credentials_options have changed, or (b) cache the TlsCreds + // objects in the XdsCredentials object so that we can reuse the + // same one when creating new security connectors, swapping out the + // TlsCreds object only when the tls_credentials_options change. + // Option (a) would probably be better, although it may require some + // structural changes to the security connector API. + auto tls_credentials = + MakeRefCounted(std::move(tls_credentials_options)); + return tls_credentials->create_security_connector( + std::move(call_creds), target_name, temp_args.args, new_args); + } + } + GPR_ASSERT(fallback_credentials_ != nullptr); + return fallback_credentials_->create_security_connector( + std::move(call_creds), target_name, temp_args.args, new_args); +} + +// +// XdsServerCredentials +// + +RefCountedPtr +XdsServerCredentials::create_security_connector(const grpc_channel_args* args) { + auto xds_certificate_provider = + XdsCertificateProvider::GetFromChannelArgs(args); + // Identity certs are a must for TLS. 
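+  // If the xDS certificate provider cannot supply them, fall through to the
+  // fallback credentials at the bottom of this method.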
+ if (xds_certificate_provider != nullptr && + xds_certificate_provider->ProvidesIdentityCerts("")) { + auto tls_credentials_options = + MakeRefCounted(); + tls_credentials_options->set_watch_identity_pair(true); + tls_credentials_options->set_certificate_provider(xds_certificate_provider); + if (xds_certificate_provider->ProvidesRootCerts("")) { + tls_credentials_options->set_watch_root_cert(true); + if (xds_certificate_provider->GetRequireClientCertificate("")) { + tls_credentials_options->set_cert_request_type( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + } else { + tls_credentials_options->set_cert_request_type( + GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY); + } + } else { + // Do not request client certificate if there is no way to verify. + tls_credentials_options->set_cert_request_type( + GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE); + } + auto tls_credentials = MakeRefCounted( + std::move(tls_credentials_options)); + return tls_credentials->create_security_connector(args); + } + return fallback_credentials_->create_security_connector(args); +} + +} // namespace grpc_core + +grpc_channel_credentials* grpc_xds_credentials_create( + grpc_channel_credentials* fallback_credentials) { + GPR_ASSERT(fallback_credentials != nullptr); + return new grpc_core::XdsCredentials(fallback_credentials->Ref()); +} + +grpc_server_credentials* grpc_xds_server_credentials_create( + grpc_server_credentials* fallback_credentials) { + GPR_ASSERT(fallback_credentials != nullptr); + return new grpc_core::XdsServerCredentials(fallback_credentials->Ref()); +} diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc new file mode 100644 index 00000000..131436df --- /dev/null +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -0,0 +1,311 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/security_connector/alts/alts_security_connector.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/credentials/alts/alts_credentials.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/transport.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/core/tsi/transport_security.h" + +void grpc_alts_set_rpc_protocol_versions( + grpc_gcp_rpc_protocol_versions* rpc_versions) { + grpc_gcp_rpc_protocol_versions_set_max(rpc_versions, + GRPC_PROTOCOL_VERSION_MAX_MAJOR, + GRPC_PROTOCOL_VERSION_MAX_MINOR); + grpc_gcp_rpc_protocol_versions_set_min(rpc_versions, + GRPC_PROTOCOL_VERSION_MIN_MAJOR, + GRPC_PROTOCOL_VERSION_MIN_MINOR); +} + +namespace { + +void alts_check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + *auth_context = + grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(&peer); + tsi_peer_destruct(&peer); + grpc_error_handle error = + *auth_context != nullptr + ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not get ALTS auth context from TSI peer"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); +} + +class grpc_alts_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_alts_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(GRPC_ALTS_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) {} + + ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* interested_parties, + grpc_core::HandshakeManager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_credentials* creds = + static_cast(channel_creds()); + size_t user_specified_max_frame_size = 0; + const grpc_arg* arg = + grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE); + if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) { + user_specified_max_frame_size = grpc_channel_arg_get_integer( + arg, {0, 0, std::numeric_limits::max()}); + } + GPR_ASSERT(alts_tsi_handshaker_create( + creds->options(), target_name_, + creds->handshaker_service_url(), true, interested_parties, + &handshaker, user_specified_max_frame_size) == TSI_OK); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this, args)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + alts_check_peer(peer, auth_context, on_peer_checked); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + bool check_call_host(absl::string_view host, + grpc_auth_context* /*auth_context*/, + grpc_closure* /*on_call_host_checked*/, + grpc_error_handle* error) override { + if (host.empty() || host != 
target_name_) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "ALTS call host does not match target name"); + } + return true; + } + + void cancel_check_call_host(grpc_closure* /*on_call_host_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + private: + char* target_name_; +}; + +class grpc_alts_server_security_connector final + : public grpc_server_security_connector { + public: + explicit grpc_alts_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(GRPC_ALTS_URL_SCHEME, + std::move(server_creds)) {} + + ~grpc_alts_server_security_connector() override = default; + + void add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* interested_parties, + grpc_core::HandshakeManager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_server_credentials* creds = + static_cast(server_creds()); + size_t user_specified_max_frame_size = 0; + const grpc_arg* arg = + grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE); + if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) { + user_specified_max_frame_size = grpc_channel_arg_get_integer( + arg, {0, 0, std::numeric_limits::max()}); + } + GPR_ASSERT(alts_tsi_handshaker_create( + creds->options(), nullptr, creds->handshaker_service_url(), + false, interested_parties, &handshaker, + user_specified_max_frame_size) == TSI_OK); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this, args)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + alts_check_peer(peer, auth_context, on_peer_checked); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast(other)); + } +}; +} // namespace + +namespace grpc_core { +namespace internal { +grpc_core::RefCountedPtr +grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) { + if (peer == nullptr) { + gpr_log(GPR_ERROR, + "Invalid arguments to grpc_alts_auth_context_from_tsi_peer()"); + return nullptr; + } + /* Validate certificate type. */ + const tsi_peer_property* cert_type_prop = + tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); + if (cert_type_prop == nullptr || + strncmp(cert_type_prop->value.data, TSI_ALTS_CERTIFICATE_TYPE, + cert_type_prop->value.length) != 0) { + gpr_log(GPR_ERROR, "Invalid or missing certificate type property."); + return nullptr; + } + /* Check if security level exists. */ + const tsi_peer_property* security_level_prop = + tsi_peer_get_property_by_name(peer, TSI_SECURITY_LEVEL_PEER_PROPERTY); + if (security_level_prop == nullptr) { + gpr_log(GPR_ERROR, "Missing security level property."); + return nullptr; + } + /* Validate RPC protocol versions. 
*/ + const tsi_peer_property* rpc_versions_prop = + tsi_peer_get_property_by_name(peer, TSI_ALTS_RPC_VERSIONS); + if (rpc_versions_prop == nullptr) { + gpr_log(GPR_ERROR, "Missing rpc protocol versions property."); + return nullptr; + } + grpc_gcp_rpc_protocol_versions local_versions, peer_versions; + grpc_alts_set_rpc_protocol_versions(&local_versions); + grpc_slice slice = grpc_slice_from_copied_buffer( + rpc_versions_prop->value.data, rpc_versions_prop->value.length); + bool decode_result = + grpc_gcp_rpc_protocol_versions_decode(slice, &peer_versions); + grpc_slice_unref_internal(slice); + if (!decode_result) { + gpr_log(GPR_ERROR, "Invalid peer rpc protocol versions."); + return nullptr; + } + /* TODO: Pass highest common rpc protocol version to grpc caller. */ + bool check_result = grpc_gcp_rpc_protocol_versions_check( + &local_versions, &peer_versions, nullptr); + if (!check_result) { + gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions."); + return nullptr; + } + /* Validate ALTS Context. */ + const tsi_peer_property* alts_context_prop = + tsi_peer_get_property_by_name(peer, TSI_ALTS_CONTEXT); + if (alts_context_prop == nullptr) { + gpr_log(GPR_ERROR, "Missing alts context property."); + return nullptr; + } + /* Create auth context. */ + auto ctx = grpc_core::MakeRefCounted(nullptr); + grpc_auth_context_add_cstring_property( + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_ALTS_TRANSPORT_SECURITY_TYPE); + size_t i = 0; + for (i = 0; i < peer->property_count; i++) { + const tsi_peer_property* tsi_prop = &peer->properties[i]; + /* Add service account to auth context. */ + if (strcmp(tsi_prop->name, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property( + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, + tsi_prop->value.data, tsi_prop->value.length); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); + } + /* Add alts context to auth context. */ + if (strcmp(tsi_prop->name, TSI_ALTS_CONTEXT) == 0) { + grpc_auth_context_add_property(ctx.get(), TSI_ALTS_CONTEXT, + tsi_prop->value.data, + tsi_prop->value.length); + } + /* Add security level to auth context. 
*/ + if (strcmp(tsi_prop->name, TSI_SECURITY_LEVEL_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property( + ctx.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME, + tsi_prop->value.data, tsi_prop->value.length); + } + } + if (!grpc_auth_context_peer_is_authenticated(ctx.get())) { + gpr_log(GPR_ERROR, "Invalid unauthenticated peer."); + ctx.reset(DEBUG_LOCATION, "test"); + return nullptr; + } + return ctx; +} + +} // namespace internal +} // namespace grpc_core + +grpc_core::RefCountedPtr +grpc_alts_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid arguments to grpc_alts_channel_security_connector_create()"); + return nullptr; + } + return grpc_core::MakeRefCounted( + std::move(channel_creds), std::move(request_metadata_creds), target_name); +} + +grpc_core::RefCountedPtr +grpc_alts_server_security_connector_create( + grpc_core::RefCountedPtr server_creds) { + if (server_creds == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid arguments to grpc_alts_server_security_connector_create()"); + return nullptr; + } + return grpc_core::MakeRefCounted( + std::move(server_creds)); +} diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.cc b/src/core/lib/security/security_connector/fake/fake_security_connector.cc new file mode 100644 index 00000000..12e4d045 --- /dev/null +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.cc @@ -0,0 +1,327 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/security_connector/fake/fake_security_connector.h" + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.h" +#include "src/core/ext/transport/chttp2/alpn/alpn.h" +#include "src/core/ext/xds/xds_channel_args.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/tsi/fake_transport_security.h" + +namespace { +class grpc_fake_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_fake_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target, const grpc_channel_args* args) + : grpc_channel_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_(gpr_strdup(target)), + expected_targets_( + gpr_strdup(grpc_fake_transport_get_expected_targets(args))), + is_lb_channel_(grpc_channel_args_find( + args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER) != + nullptr) { + const grpc_arg* target_name_override_arg = + grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); + if (target_name_override_arg != nullptr) { + target_name_override_ = + gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg)); + } else { + target_name_override_ = nullptr; + } + } + + ~grpc_fake_channel_security_connector() override { + gpr_free(target_); + gpr_free(expected_targets_); + if (target_name_override_ != nullptr) gpr_free(target_name_override_); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override; + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = strcmp(target_, other->target_); + if (c != 0) return c; + if (expected_targets_ == nullptr || other->expected_targets_ == nullptr) { + c = grpc_core::QsortCompare(expected_targets_, other->expected_targets_); + } else { + c = strcmp(expected_targets_, other->expected_targets_); + } + if (c != 0) return c; + return grpc_core::QsortCompare(is_lb_channel_, other->is_lb_channel_); + } + + void add_handshakers(const grpc_channel_args* args, + grpc_pollset_set* /*interested_parties*/, + grpc_core::HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate( + tsi_create_fake_handshaker(/*is_client=*/true), this, args)); + } + + bool check_call_host(absl::string_view host, + grpc_auth_context* /*auth_context*/, + grpc_closure* /*on_call_host_checked*/, + grpc_error_handle* /*error*/) override { + absl::string_view authority_hostname; + absl::string_view authority_ignored_port; + absl::string_view target_hostname; + absl::string_view target_ignored_port; + grpc_core::SplitHostPort(host, 
&authority_hostname, + &authority_ignored_port); + grpc_core::SplitHostPort(target_, &target_hostname, &target_ignored_port); + if (target_name_override_ != nullptr) { + absl::string_view fake_security_target_name_override_hostname; + absl::string_view fake_security_target_name_override_ignored_port; + grpc_core::SplitHostPort( + target_name_override_, &fake_security_target_name_override_hostname, + &fake_security_target_name_override_ignored_port); + if (authority_hostname != fake_security_target_name_override_hostname) { + gpr_log(GPR_ERROR, + "Authority (host) '%s' != Fake Security Target override '%s'", + host.data(), + fake_security_target_name_override_hostname.data()); + abort(); + } + } else if (authority_hostname != target_hostname) { + gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'", host.data(), + target_); + abort(); + } + return true; + } + + void cancel_check_call_host(grpc_closure* /*on_call_host_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + char* target() const { return target_; } + char* expected_targets() const { return expected_targets_; } + bool is_lb_channel() const { return is_lb_channel_; } + char* target_name_override() const { return target_name_override_; } + + private: + bool fake_check_target(const char* target, const char* set_str) const { + GPR_ASSERT(target != nullptr); + char** set = nullptr; + size_t set_size = 0; + gpr_string_split(set_str, ",", &set, &set_size); + bool found = false; + for (size_t i = 0; i < set_size; ++i) { + if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true; + } + for (size_t i = 0; i < set_size; ++i) { + gpr_free(set[i]); + } + gpr_free(set); + return found; + } + + void fake_secure_name_check() const { + if (expected_targets_ == nullptr) return; + char** lbs_and_backends = nullptr; + size_t lbs_and_backends_size = 0; + bool success = false; + gpr_string_split(expected_targets_, ";", &lbs_and_backends, + &lbs_and_backends_size); + if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) { + gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'", + expected_targets_); + goto done; + } + if (is_lb_channel_) { + if (lbs_and_backends_size != 2) { + gpr_log(GPR_ERROR, + "Invalid expected targets arg value: '%s'. 
Expectations for LB " + "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...", + expected_targets_); + goto done; + } + if (!fake_check_target(target_, lbs_and_backends[1])) { + gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'", + target_, lbs_and_backends[1]); + goto done; + } + success = true; + } else { + if (!fake_check_target(target_, lbs_and_backends[0])) { + gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'", + target_, lbs_and_backends[0]); + goto done; + } + success = true; + } + done: + for (size_t i = 0; i < lbs_and_backends_size; ++i) { + gpr_free(lbs_and_backends[i]); + } + gpr_free(lbs_and_backends); + if (!success) abort(); + } + + char* target_; + char* expected_targets_; + bool is_lb_channel_; + char* target_name_override_; +}; + +static void fake_check_peer( + grpc_security_connector* /*sc*/, tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + const char* prop_name; + grpc_error_handle error = GRPC_ERROR_NONE; + *auth_context = nullptr; + if (peer.property_count != 2) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Fake peers should only have 2 properties."); + goto end; + } + prop_name = peer.properties[0].name; + if (prop_name == nullptr || + strcmp(prop_name, TSI_CERTIFICATE_TYPE_PEER_PROPERTY) != 0) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unexpected property in fake peer: ", + prop_name == nullptr ? "" : prop_name)); + goto end; + } + if (strncmp(peer.properties[0].value.data, TSI_FAKE_CERTIFICATE_TYPE, + peer.properties[0].value.length) != 0) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid value for cert type property."); + goto end; + } + prop_name = peer.properties[1].name; + if (prop_name == nullptr || + strcmp(prop_name, TSI_SECURITY_LEVEL_PEER_PROPERTY) != 0) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Unexpected property in fake peer: ", + prop_name == nullptr ? 
"" : prop_name)); + goto end; + } + if (strncmp(peer.properties[1].value.data, TSI_FAKE_SECURITY_LEVEL, + peer.properties[1].value.length) != 0) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid value for security level property."); + goto end; + } + + *auth_context = grpc_core::MakeRefCounted(nullptr); + grpc_auth_context_add_cstring_property( + auth_context->get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_FAKE_TRANSPORT_SECURITY_TYPE); + grpc_auth_context_add_cstring_property( + auth_context->get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME, + TSI_FAKE_SECURITY_LEVEL); +end: + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + tsi_peer_destruct(&peer); +} + +void grpc_fake_channel_security_connector::check_peer( + tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + fake_check_peer(this, peer, auth_context, on_peer_checked); + fake_secure_name_check(); +} + +class grpc_fake_server_security_connector + : public grpc_server_security_connector { + public: + explicit grpc_fake_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(server_creds)) {} + ~grpc_fake_server_security_connector() override = default; + + void check_peer(tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + fake_check_peer(this, peer, auth_context, on_peer_checked); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + void add_handshakers(const grpc_channel_args* args, + grpc_pollset_set* /*interested_parties*/, + grpc_core::HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate( + tsi_create_fake_handshaker(/*=is_client*/ false), this, args)); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast(other)); + } +}; +} // namespace + +grpc_core::RefCountedPtr +grpc_fake_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target, const grpc_channel_args* args) { + return grpc_core::MakeRefCounted( + std::move(channel_creds), std::move(request_metadata_creds), target, + args); +} + +grpc_core::RefCountedPtr +grpc_fake_server_security_connector_create( + grpc_core::RefCountedPtr server_creds) { + return grpc_core::MakeRefCounted( + std::move(server_creds)); +} diff --git a/src/core/lib/security/security_connector/insecure/insecure_security_connector.cc b/src/core/lib/security/security_connector/insecure/insecure_security_connector.cc new file mode 100644 index 00000000..360764e1 --- /dev/null +++ b/src/core/lib/security/security_connector/insecure/insecure_security_connector.cc @@ -0,0 +1,121 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "src/core/lib/security/security_connector/insecure/insecure_security_connector.h" + +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/tsi/local_transport_security.h" + +namespace grpc_core { + +const char kInsecureTransportSecurityType[] = "insecure"; + +namespace { + +RefCountedPtr MakeAuthContext() { + auto ctx = MakeRefCounted(nullptr); + grpc_auth_context_add_cstring_property( + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + kInsecureTransportSecurityType); + const char* security_level = tsi_security_level_to_string(TSI_SECURITY_NONE); + grpc_auth_context_add_property(ctx.get(), + GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME, + security_level, strlen(security_level)); + return ctx; +} + +} // namespace + +RefCountedPtr TestOnlyMakeInsecureAuthContext() { + return MakeAuthContext(); +} + +// check_call_host and cancel_check_call_host are no-ops since we want to +// provide an insecure channel. +bool InsecureChannelSecurityConnector::check_call_host( + absl::string_view /*host*/, grpc_auth_context* /*auth_context*/, + grpc_closure* /*on_call_host_checked*/, grpc_error_handle* error) { + *error = GRPC_ERROR_NONE; + return true; +} + +void InsecureChannelSecurityConnector::cancel_check_call_host( + grpc_closure* /*on_call_host_checked*/, grpc_error_handle error) { + GRPC_ERROR_UNREF(error); +} + +// add_handshakers should have been a no-op but we need to add a minimalist +// security handshaker so that check_peer is invoked and an auth_context is +// created with the security level of TSI_SECURITY_NONE. +void InsecureChannelSecurityConnector::add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* /* interested_parties */, + HandshakeManager* handshake_manager) { + tsi_handshaker* handshaker = nullptr; + // Re-use local_tsi_handshaker_create as a minimalist handshaker. + GPR_ASSERT(tsi_local_handshaker_create(true /* is_client */, &handshaker) == + TSI_OK); + handshake_manager->Add(SecurityHandshakerCreate(handshaker, this, args)); +} + +void InsecureChannelSecurityConnector::check_peer( + tsi_peer peer, grpc_endpoint* /*ep*/, + RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + *auth_context = MakeAuthContext(); + tsi_peer_destruct(&peer); + ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, GRPC_ERROR_NONE); +} + +int InsecureChannelSecurityConnector::cmp( + const grpc_security_connector* other_sc) const { + return channel_security_connector_cmp( + static_cast(other_sc)); +} + +// add_handshakers should have been a no-op but we need to add a minimalist +// security handshaker so that check_peer is invoked and an auth_context is +// created with the security level of TSI_SECURITY_NONE. +void InsecureServerSecurityConnector::add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* /* interested_parties */, + grpc_core::HandshakeManager* handshake_manager) { + tsi_handshaker* handshaker = nullptr; + // Re-use local_tsi_handshaker_create as a minimalist handshaker. 
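The insecure connectors still run a (local) TSI handshake purely so that check_peer() fires and an auth context advertising TSI_SECURITY_NONE gets created; the two properties attached by MakeAuthContext() are all a later auth filter sees. Here is a small sketch of that resulting context, using an ordinary map as a hypothetical stand-in for grpc_auth_context; the literal key strings are illustrative, since the real code uses the GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME and GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME macros.

#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for the auth context built by MakeAuthContext():
// one property naming the transport security type ("insecure") and one naming
// the security level reported by TSI (TSI_SECURITY_NONE).
std::map<std::string, std::string> MakeInsecureAuthProperties() {
  return {
      {"transport_security_type", "insecure"},
      {"transport_security_level", "TSI_SECURITY_NONE"},
  };
}

int main() {
  for (const auto& property : MakeInsecureAuthProperties()) {
    std::cout << property.first << " = " << property.second << "\n";
  }
}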
+ GPR_ASSERT(tsi_local_handshaker_create(false /* is_client */, &handshaker) == + TSI_OK); + handshake_manager->Add(SecurityHandshakerCreate(handshaker, this, args)); +} + +void InsecureServerSecurityConnector::check_peer( + tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + *auth_context = MakeAuthContext(); + tsi_peer_destruct(&peer); + ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, GRPC_ERROR_NONE); +} + +int InsecureServerSecurityConnector::cmp( + const grpc_security_connector* other) const { + return server_security_connector_cmp( + static_cast(other)); +} + +} // namespace grpc_core diff --git a/src/core/lib/security/security_connector/load_system_roots_fallback.cc b/src/core/lib/security/security_connector/load_system_roots_fallback.cc new file mode 100644 index 00000000..f448d3fc --- /dev/null +++ b/src/core/lib/security/security_connector/load_system_roots_fallback.cc @@ -0,0 +1,33 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/security/security_connector/load_system_roots.h" + +#if !defined(GPR_LINUX) && !defined(GPR_ANDROID) + +namespace grpc_core { + +grpc_slice LoadSystemRootCerts() { return grpc_empty_slice(); } + +} // namespace grpc_core + +#endif /* !(GPR_LINUX || GPR_ANDROID) */ diff --git a/src/core/lib/security/security_connector/load_system_roots_linux.cc b/src/core/lib/security/security_connector/load_system_roots_linux.cc new file mode 100644 index 00000000..da311da7 --- /dev/null +++ b/src/core/lib/security/security_connector/load_system_roots_linux.cc @@ -0,0 +1,171 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/security_connector/load_system_roots_linux.h" + +#include + +#if defined(GPR_LINUX) || defined(GPR_ANDROID) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/global_config.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/security_connector/load_system_roots.h" + +GPR_GLOBAL_CONFIG_DEFINE_STRING(grpc_system_ssl_roots_dir, "", + "Custom directory to SSL Roots"); + +namespace grpc_core { +namespace { + +const char* kLinuxCertFiles[] = { + "/etc/ssl/certs/ca-certificates.crt", "/etc/pki/tls/certs/ca-bundle.crt", + "/etc/ssl/ca-bundle.pem", "/etc/pki/tls/cacert.pem", + "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem"}; +const char* kLinuxCertDirectories[] = { + "/etc/ssl/certs", "/system/etc/security/cacerts", "/usr/local/share/certs", + "/etc/pki/tls/certs", "/etc/openssl/certs"}; + +grpc_slice GetSystemRootCerts() { + grpc_slice valid_bundle_slice = grpc_empty_slice(); + size_t num_cert_files_ = GPR_ARRAY_SIZE(kLinuxCertFiles); + for (size_t i = 0; i < num_cert_files_; i++) { + grpc_error_handle error = + grpc_load_file(kLinuxCertFiles[i], 1, &valid_bundle_slice); + if (error == GRPC_ERROR_NONE) { + return valid_bundle_slice; + } else { + GRPC_ERROR_UNREF(error); + } + } + return grpc_empty_slice(); +} + +} // namespace + +void GetAbsoluteFilePath(const char* valid_file_dir, + const char* file_entry_name, char* path_buffer) { + if (valid_file_dir != nullptr && file_entry_name != nullptr) { + int path_len = snprintf(path_buffer, MAXPATHLEN, "%s/%s", valid_file_dir, + file_entry_name); + if (path_len == 0) { + gpr_log(GPR_ERROR, "failed to get absolute path for file: %s", + file_entry_name); + } + } +} + +grpc_slice CreateRootCertsBundle(const char* certs_directory) { + grpc_slice bundle_slice = grpc_empty_slice(); + if (certs_directory == nullptr) { + return bundle_slice; + } + DIR* ca_directory = opendir(certs_directory); + if (ca_directory == nullptr) { + return bundle_slice; + } + struct FileData { + char path[MAXPATHLEN]; + off_t size; + }; + absl::InlinedVector roots_filenames; + size_t total_bundle_size = 0; + struct dirent* directory_entry; + while ((directory_entry = readdir(ca_directory)) != nullptr) { + struct stat dir_entry_stat; + const char* file_entry_name = directory_entry->d_name; + FileData file_data; + GetAbsoluteFilePath(certs_directory, file_entry_name, file_data.path); + int stat_return = stat(file_data.path, &dir_entry_stat); + if (stat_return == -1 || !S_ISREG(dir_entry_stat.st_mode)) { + // no subdirectories. + if (stat_return == -1) { + gpr_log(GPR_ERROR, "failed to get status for file: %s", file_data.path); + } + continue; + } + file_data.size = dir_entry_stat.st_size; + total_bundle_size += file_data.size; + roots_filenames.push_back(file_data); + } + closedir(ca_directory); + char* bundle_string = static_cast(gpr_zalloc(total_bundle_size + 1)); + size_t bytes_read = 0; + for (size_t i = 0; i < roots_filenames.size(); i++) { + int file_descriptor = open(roots_filenames[i].path, O_RDONLY); + if (file_descriptor != -1) { + // Read file into bundle. 
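CreateRootCertsBundle() above walks a certificate directory, skips anything that is not a regular file, and concatenates the remaining files into one in-memory PEM bundle. The following is a simplified, self-contained sketch of the same idea using std::filesystem instead of dirent/stat, with minimal error handling; the directory path in main() is just one of the well-known locations listed in kLinuxCertDirectories.

#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <system_error>

// Simplified sketch of CreateRootCertsBundle(): concatenate every regular
// file in a directory into a single PEM bundle string.
std::string BundleCertsInDirectory(const std::string& certs_directory) {
  std::string bundle;
  std::error_code ec;
  for (const auto& entry :
       std::filesystem::directory_iterator(certs_directory, ec)) {
    if (!entry.is_regular_file()) continue;  // skip subdirectories etc.
    std::ifstream cert_file(entry.path());
    std::stringstream contents;
    contents << cert_file.rdbuf();  // read the whole file
    bundle += contents.str();
  }
  return bundle;
}

int main() {
  std::string bundle = BundleCertsInDirectory("/etc/ssl/certs");
  std::cout << "bundle size: " << bundle.size() << " bytes\n";
}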
+ size_t cert_file_size = roots_filenames[i].size; + int read_ret = + read(file_descriptor, bundle_string + bytes_read, cert_file_size); + if (read_ret != -1) { + bytes_read += read_ret; + } else { + gpr_log(GPR_ERROR, "failed to read file: %s", roots_filenames[i].path); + } + } + } + bundle_slice = grpc_slice_new(bundle_string, bytes_read, gpr_free); + return bundle_slice; +} + +grpc_slice LoadSystemRootCerts() { + grpc_slice result = grpc_empty_slice(); + // Prioritize user-specified custom directory if flag is set. + grpc_core::UniquePtr custom_dir = + GPR_GLOBAL_CONFIG_GET(grpc_system_ssl_roots_dir); + if (strlen(custom_dir.get()) > 0) { + result = CreateRootCertsBundle(custom_dir.get()); + } + // If the custom directory is empty/invalid/not specified, fallback to + // distribution-specific directory. + if (GRPC_SLICE_IS_EMPTY(result)) { + result = GetSystemRootCerts(); + } + if (GRPC_SLICE_IS_EMPTY(result)) { + for (size_t i = 0; i < GPR_ARRAY_SIZE(kLinuxCertDirectories); i++) { + result = CreateRootCertsBundle(kLinuxCertDirectories[i]); + if (!GRPC_SLICE_IS_EMPTY(result)) { + break; + } + } + } + return result; +} + +} // namespace grpc_core + +#endif /* GPR_LINUX || GPR_ANDROID */ diff --git a/src/core/lib/security/security_connector/local/local_security_connector.cc b/src/core/lib/security/security_connector/local/local_security_connector.cc new file mode 100644 index 00000000..aec493c8 --- /dev/null +++ b/src/core/lib/security/security_connector/local/local_security_connector.cc @@ -0,0 +1,294 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/security_connector/local/local_security_connector.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" +#include "src/core/lib/security/credentials/local/local_credentials.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/lib/uri/uri_parser.h" +#include "src/core/tsi/local_transport_security.h" + +#define GRPC_UDS_URI_PATTERN "unix:" +#define GRPC_LOCAL_TRANSPORT_SECURITY_TYPE "local" + +namespace { + +grpc_core::RefCountedPtr local_auth_context_create( + const tsi_peer* peer) { + /* Create auth context. 
*/ + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + grpc_auth_context_add_cstring_property( + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_LOCAL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1); + GPR_ASSERT(peer->property_count == 1); + const tsi_peer_property* prop = &peer->properties[0]; + GPR_ASSERT(prop != nullptr); + GPR_ASSERT(strcmp(prop->name, TSI_SECURITY_LEVEL_PEER_PROPERTY) == 0); + grpc_auth_context_add_property(ctx.get(), + GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME, + prop->value.data, prop->value.length); + return ctx; +} + +void local_check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked, + grpc_local_connect_type type) { + grpc_resolved_address resolved_addr; + bool is_endpoint_local = false; + absl::string_view local_addr = grpc_endpoint_get_local_address(ep); + absl::StatusOr uri = grpc_core::URI::Parse(local_addr); + if (!uri.ok() || !grpc_parse_uri(*uri, &resolved_addr)) { + gpr_log(GPR_ERROR, "Could not parse endpoint address: %s", + std::string(local_addr.data(), local_addr.size()).c_str()); + } else { + grpc_resolved_address addr_normalized; + grpc_resolved_address* addr = + grpc_sockaddr_is_v4mapped(&resolved_addr, &addr_normalized) + ? &addr_normalized + : &resolved_addr; + grpc_sockaddr* sock_addr = reinterpret_cast(&addr->addr); + // UDS + if (type == UDS && grpc_is_unix_socket(addr)) { + is_endpoint_local = true; + // IPV4 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast(sock_addr); + if (grpc_htonl(addr4->sin_addr.s_addr) == INADDR_LOOPBACK) { + is_endpoint_local = true; + } + // IPv6 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast(addr); + if (memcmp(&addr6->sin6_addr, &in6addr_loopback, + sizeof(in6addr_loopback)) == 0) { + is_endpoint_local = true; + } + } + } + grpc_error_handle error; + if (!is_endpoint_local) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Endpoint is neither UDS or TCP loopback address."); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + return; + } + // Add TSI_SECURITY_LEVEL_PEER_PROPERTY type peer property. + size_t new_property_count = peer.property_count + 1; + tsi_peer_property* new_properties = static_cast( + gpr_zalloc(sizeof(*new_properties) * new_property_count)); + for (size_t i = 0; i < peer.property_count; i++) { + new_properties[i] = peer.properties[i]; + } + if (peer.properties != nullptr) gpr_free(peer.properties); + peer.properties = new_properties; + // TODO(yihuazhang): Set security level of local TCP to TSI_SECURITY_NONE. + const char* security_level = + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY); + tsi_result result = tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, security_level, + &peer.properties[peer.property_count]); + if (result != TSI_OK) return; + peer.property_count++; + /* Create an auth context which is necessary to pass the santiy check in + * {client, server}_auth_filter that verifies if the peer's auth context is + * obtained during handshakes. The auth context is only checked for its + * existence and not actually used. + */ + *auth_context = local_auth_context_create(&peer); + tsi_peer_destruct(&peer); + error = *auth_context != nullptr ? 
GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not create local auth context"); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); +} + +class grpc_local_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_local_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(nullptr, std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) {} + + ~grpc_local_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* /*interested_parties*/, + grpc_core::HandshakeManager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(tsi_local_handshaker_create(true /* is_client */, &handshaker) == + TSI_OK); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this, args)); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast( + other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + grpc_local_credentials* creds = + reinterpret_cast(mutable_channel_creds()); + local_check_peer(peer, ep, auth_context, on_peer_checked, + creds->connect_type()); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + bool check_call_host(absl::string_view host, + grpc_auth_context* /*auth_context*/, + grpc_closure* /*on_call_host_checked*/, + grpc_error_handle* error) override { + if (host.empty() || host != target_name_) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "local call host does not match target name"); + } + return true; + } + + void cancel_check_call_host(grpc_closure* /*on_call_host_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + const char* target_name() const { return target_name_; } + + private: + char* target_name_; +}; + +class grpc_local_server_security_connector final + : public grpc_server_security_connector { + public: + explicit grpc_local_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(nullptr, std::move(server_creds)) {} + ~grpc_local_server_security_connector() override = default; + + void add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* /*interested_parties*/, + grpc_core::HandshakeManager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(tsi_local_handshaker_create(false /* is_client */, + &handshaker) == TSI_OK); + handshake_manager->Add( + grpc_core::SecurityHandshakerCreate(handshaker, this, args)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + grpc_local_server_credentials* creds = + static_cast(mutable_server_creds()); + local_check_peer(peer, ep, auth_context, on_peer_checked, + creds->connect_type()); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + int cmp(const grpc_security_connector* other) const override { + 
return server_security_connector_cmp( + static_cast(other)); + } +}; +} // namespace + +grpc_core::RefCountedPtr +grpc_local_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const grpc_channel_args* args, const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid arguments to grpc_local_channel_security_connector_create()"); + return nullptr; + } + // Perform sanity check on UDS address. For TCP local connection, the check + // will be done during check_peer procedure. + grpc_local_credentials* creds = + static_cast(channel_creds.get()); + const grpc_arg* server_uri_arg = + grpc_channel_args_find(args, GRPC_ARG_SERVER_URI); + const char* server_uri_str = grpc_channel_arg_get_string(server_uri_arg); + if (creds->connect_type() == UDS && + strncmp(GRPC_UDS_URI_PATTERN, server_uri_str, + strlen(GRPC_UDS_URI_PATTERN)) != 0) { + gpr_log(GPR_ERROR, + "Invalid UDS target name to " + "grpc_local_channel_security_connector_create()"); + return nullptr; + } + return grpc_core::MakeRefCounted( + channel_creds, request_metadata_creds, target_name); +} + +grpc_core::RefCountedPtr +grpc_local_server_security_connector_create( + grpc_core::RefCountedPtr server_creds) { + if (server_creds == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid arguments to grpc_local_server_security_connector_create()"); + return nullptr; + } + return grpc_core::MakeRefCounted( + std::move(server_creds)); +} diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc new file mode 100644 index 00000000..67ce84ac --- /dev/null +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -0,0 +1,137 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/security_connector/security_connector.h" + +#include +#include +#include +#include + +#include "src/core/ext/transport/chttp2/alpn/alpn.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/load_system_roots.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/security/transport/security_handshaker.h" + +grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount( + false, "security_connector_refcount"); + +grpc_server_security_connector::grpc_server_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr server_creds) + : grpc_security_connector(url_scheme), + server_creds_(std::move(server_creds)) {} + +grpc_server_security_connector::~grpc_server_security_connector() = default; + +grpc_channel_security_connector::grpc_channel_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds) + : grpc_security_connector(url_scheme), + channel_creds_(std::move(channel_creds)), + request_metadata_creds_(std::move(request_metadata_creds)) {} + +grpc_channel_security_connector::~grpc_channel_security_connector() {} + +int grpc_security_connector_cmp(const grpc_security_connector* sc, + const grpc_security_connector* other) { + if (sc == nullptr || other == nullptr) { + return grpc_core::QsortCompare(sc, other); + } + return sc->cmp(other); +} + +int grpc_channel_security_connector::channel_security_connector_cmp( + const grpc_channel_security_connector* other) const { + const grpc_channel_security_connector* other_sc = + static_cast(other); + GPR_ASSERT(channel_creds() != nullptr); + GPR_ASSERT(other_sc->channel_creds() != nullptr); + int c = grpc_core::QsortCompare(channel_creds(), other_sc->channel_creds()); + if (c != 0) return c; + return grpc_core::QsortCompare(request_metadata_creds(), + other_sc->request_metadata_creds()); +} + +int grpc_server_security_connector::server_security_connector_cmp( + const grpc_server_security_connector* other) const { + const grpc_server_security_connector* other_sc = + static_cast(other); + GPR_ASSERT(server_creds() != nullptr); + GPR_ASSERT(other_sc->server_creds() != nullptr); + return grpc_core::QsortCompare(server_creds(), other_sc->server_creds()); +} + +static void connector_arg_destroy(void* p) { + if (p == nullptr) return; + static_cast(p)->Unref(DEBUG_LOCATION, + "connector_arg_destroy"); +} + +static void* connector_arg_copy(void* p) { + if (p == nullptr) return nullptr; + return static_cast(p) + ->Ref(DEBUG_LOCATION, "connector_arg_copy") + .release(); +} + +static int connector_cmp(void* a, void* b) { + return static_cast(a)->cmp( + static_cast(b)); +} + +static const grpc_arg_pointer_vtable connector_arg_vtable = { + connector_arg_copy, connector_arg_destroy, connector_cmp}; + +grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_SECURITY_CONNECTOR), sc, + &connector_arg_vtable); +} + +grpc_security_connector* grpc_security_connector_from_arg(const grpc_arg* arg) { + if (strcmp(arg->key, GRPC_ARG_SECURITY_CONNECTOR) != 0) return 
nullptr; + if (arg->type != GRPC_ARG_POINTER) { + gpr_log(GPR_ERROR, "Invalid type %d for arg %s", arg->type, + GRPC_ARG_SECURITY_CONNECTOR); + return nullptr; + } + return static_cast(arg->value.pointer.p); +} + +grpc_security_connector* grpc_security_connector_find_in_args( + const grpc_channel_args* args) { + size_t i; + if (args == nullptr) return nullptr; + for (i = 0; i < args->num_args; i++) { + grpc_security_connector* sc = + grpc_security_connector_from_arg(&args->args[i]); + if (sc != nullptr) return sc; + } + return nullptr; +} diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc new file mode 100644 index 00000000..c8581490 --- /dev/null +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -0,0 +1,453 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/security_connector/ssl/ssl_security_connector.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/alpn/alpn.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/ssl/ssl_credentials.h" +#include "src/core/lib/security/security_connector/load_system_roots.h" +#include "src/core/lib/security/security_connector/ssl_utils.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/tsi/ssl_transport_security.h" +#include "src/core/tsi/transport_security.h" + +namespace { +grpc_error_handle ssl_check_peer( + const char* peer_name, const tsi_peer* peer, + grpc_core::RefCountedPtr* auth_context) { + grpc_error_handle error = grpc_ssl_check_alpn(peer); + if (error != GRPC_ERROR_NONE) { + return error; + } + /* Check the peer name if specified. 
*/ + if (peer_name != nullptr && !grpc_ssl_host_matches_name(peer, peer_name)) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Peer name ", peer_name, " is not in peer certificate")); + } + *auth_context = + grpc_ssl_peer_to_auth_context(peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + return GRPC_ERROR_NONE; +} + +class grpc_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_ssl_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name) + : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + overridden_target_name_( + overridden_target_name == nullptr ? "" : overridden_target_name), + verify_options_(&config->verify_options) { + absl::string_view host; + absl::string_view port; + grpc_core::SplitHostPort(target_name, &host, &port); + target_name_ = std::string(host); + } + + ~grpc_ssl_channel_security_connector() override { + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); + } + + grpc_security_status InitializeHandshakerFactory( + const grpc_ssl_config* config, const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store, + tsi_ssl_session_cache* ssl_session_cache) { + bool has_key_cert_pair = + config->pem_key_cert_pair != nullptr && + config->pem_key_cert_pair->private_key != nullptr && + config->pem_key_cert_pair->cert_chain != nullptr; + tsi_ssl_client_handshaker_options options; + GPR_DEBUG_ASSERT(pem_root_certs != nullptr); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + options.alpn_protocols = + grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + if (has_key_cert_pair) { + options.pem_key_cert_pair = config->pem_key_cert_pair; + } + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + options.min_tls_version = grpc_get_tsi_tls_version(config->min_tls_version); + options.max_tls_version = grpc_get_tsi_tls_version(config->max_tls_version); + const tsi_result result = + tsi_create_ssl_client_handshaker_factory_with_options( + &options, &client_handshaker_factory_); + gpr_free(options.alpn_protocols); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; + } + + void add_handshakers(const grpc_channel_args* args, + grpc_pollset_set* /*interested_parties*/, + grpc_core::HandshakeManager* handshake_mgr) override { + // Instantiate TSI handshaker. + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + client_handshaker_factory_, + overridden_target_name_.empty() ? target_name_.c_str() + : overridden_target_name_.c_str(), + &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this, args)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + const char* target_name = overridden_target_name_.empty() + ? 
target_name_.c_str() + : overridden_target_name_.c_str(); + grpc_error_handle error = ssl_check_peer(target_name, &peer, auth_context); + if (error == GRPC_ERROR_NONE && + verify_options_->verify_peer_callback != nullptr) { + const tsi_peer_property* p = + tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); + if (p == nullptr) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Cannot check peer: missing pem cert property."); + } else { + char* peer_pem = static_cast(gpr_malloc(p->value.length + 1)); + memcpy(peer_pem, p->value.data, p->value.length); + peer_pem[p->value.length] = '\0'; + int callback_status = verify_options_->verify_peer_callback( + target_name, peer_pem, + verify_options_->verify_peer_callback_userdata); + gpr_free(peer_pem); + if (callback_status) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrFormat( + "Verify peer callback returned a failure (%d)", callback_status)); + } + } + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + tsi_peer_destruct(&peer); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = target_name_.compare(other->target_name_); + if (c != 0) return c; + return overridden_target_name_.compare(other->overridden_target_name_); + } + + bool check_call_host(absl::string_view host, grpc_auth_context* auth_context, + grpc_closure* /*on_call_host_checked*/, + grpc_error_handle* error) override { + return grpc_ssl_check_call_host(host, target_name_.c_str(), + overridden_target_name_.c_str(), + auth_context, error); + } + + void cancel_check_call_host(grpc_closure* /*on_call_host_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + private: + tsi_ssl_client_handshaker_factory* client_handshaker_factory_; + std::string target_name_; + std::string overridden_target_name_; + const verify_peer_options* verify_options_; +}; + +class grpc_ssl_server_security_connector + : public grpc_server_security_connector { + public: + explicit grpc_ssl_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(GRPC_SSL_URL_SCHEME, + std::move(server_creds)) {} + + ~grpc_ssl_server_security_connector() override { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } + + bool has_cert_config_fetcher() const { + return static_cast(server_creds()) + ->has_cert_config_fetcher(); + } + + const tsi_ssl_server_handshaker_factory* server_handshaker_factory() const { + return server_handshaker_factory_; + } + + grpc_security_status InitializeHandshakerFactory() { + if (has_cert_config_fetcher()) { + // Load initial credentials from certificate_config_fetcher: + if (!try_fetch_ssl_server_credentials()) { + gpr_log(GPR_ERROR, + "Failed loading SSL server credentials from fetcher."); + return GRPC_SECURITY_ERROR; + } + } else { + auto* server_credentials = + static_cast(server_creds()); + size_t num_alpn_protocols = 0; + const char** alpn_protocol_strings = + grpc_fill_alpn_protocol_strings(&num_alpn_protocols); + tsi_ssl_server_handshaker_options options; + options.pem_key_cert_pairs = + server_credentials->config().pem_key_cert_pairs; + options.num_key_cert_pairs = + server_credentials->config().num_key_cert_pairs; + options.pem_client_root_certs = + 
server_credentials->config().pem_root_certs; + options.client_certificate_request = + grpc_get_tsi_client_certificate_request_type( + server_credentials->config().client_certificate_request); + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.alpn_protocols = alpn_protocol_strings; + options.num_alpn_protocols = static_cast(num_alpn_protocols); + options.min_tls_version = grpc_get_tsi_tls_version( + server_credentials->config().min_tls_version); + options.max_tls_version = grpc_get_tsi_tls_version( + server_credentials->config().max_tls_version); + const tsi_result result = + tsi_create_ssl_server_handshaker_factory_with_options( + &options, &server_handshaker_factory_); + gpr_free(alpn_protocol_strings); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + } + return GRPC_SECURITY_OK; + } + + void add_handshakers(const grpc_channel_args* args, + grpc_pollset_set* /*interested_parties*/, + grpc_core::HandshakeManager* handshake_mgr) override { + // Instantiate TSI handshaker. + try_fetch_ssl_server_credentials(); + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( + server_handshaker_factory_, &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this, args)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* /*ep*/, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + grpc_error_handle error = ssl_check_peer(nullptr, &peer, auth_context); + tsi_peer_destruct(&peer); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + } + + void cancel_check_peer(grpc_closure* /*on_peer_checked*/, + grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast(other)); + } + + private: + /* Attempts to fetch the server certificate config if a callback is available. + * Current certificate config will continue to be used if the callback returns + * an error. Returns true if new credentials were successfully loaded. */ + bool try_fetch_ssl_server_credentials() { + grpc_ssl_server_certificate_config* certificate_config = nullptr; + bool status; + if (!has_cert_config_fetcher()) return false; + + grpc_core::MutexLock lock(&mu_); + grpc_ssl_server_credentials* server_creds = + static_cast(this->mutable_server_creds()); + grpc_ssl_certificate_config_reload_status cb_result = + server_creds->FetchCertConfig(&certificate_config); + if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) { + gpr_log(GPR_DEBUG, "No change in SSL server credentials."); + status = false; + } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) { + status = try_replace_server_handshaker_factory(certificate_config); + } else { + // Log error, continue using previously-loaded credentials. 
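+      // (Any other status, e.g. GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL,
+      // lands here.)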
+ gpr_log(GPR_ERROR, + "Failed fetching new server credentials, continuing to " + "use previously-loaded credentials."); + status = false; + } + + if (certificate_config != nullptr) { + grpc_ssl_server_certificate_config_destroy(certificate_config); + } + return status; + } + + /* Attempts to replace the server_handshaker_factory with a new factory using + * the provided grpc_ssl_server_certificate_config. Should new factory + * creation fail, the existing factory will not be replaced. Returns true on + * success (new factory created). */ + bool try_replace_server_handshaker_factory( + const grpc_ssl_server_certificate_config* config) { + if (config == nullptr) { + gpr_log(GPR_ERROR, + "Server certificate config callback returned invalid (NULL) " + "config."); + return false; + } + gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config); + + size_t num_alpn_protocols = 0; + const char** alpn_protocol_strings = + grpc_fill_alpn_protocol_strings(&num_alpn_protocols); + tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr; + const grpc_ssl_server_credentials* server_creds = + static_cast(this->server_creds()); + GPR_DEBUG_ASSERT(config->pem_root_certs != nullptr); + tsi_ssl_server_handshaker_options options; + options.pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + config->pem_key_cert_pairs, config->num_key_cert_pairs); + options.num_key_cert_pairs = config->num_key_cert_pairs; + options.pem_client_root_certs = config->pem_root_certs; + options.client_certificate_request = + grpc_get_tsi_client_certificate_request_type( + server_creds->config().client_certificate_request); + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.alpn_protocols = alpn_protocol_strings; + options.num_alpn_protocols = static_cast(num_alpn_protocols); + tsi_result result = tsi_create_ssl_server_handshaker_factory_with_options( + &options, &new_handshaker_factory); + grpc_tsi_ssl_pem_key_cert_pairs_destroy( + const_cast(options.pem_key_cert_pairs), + options.num_key_cert_pairs); + gpr_free(alpn_protocol_strings); + + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return false; + } + set_server_handshaker_factory(new_handshaker_factory); + return true; + } + + void set_server_handshaker_factory( + tsi_ssl_server_handshaker_factory* new_factory) { + if (server_handshaker_factory_) { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } + server_handshaker_factory_ = new_factory; + } + + grpc_core::Mutex mu_; + tsi_ssl_server_handshaker_factory* server_handshaker_factory_ = nullptr; +}; +} // namespace + +grpc_core::RefCountedPtr +grpc_ssl_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache) { + if (config == nullptr || target_name == nullptr) { + gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); + return nullptr; + } + + const char* pem_root_certs; + const tsi_ssl_root_certs_store* root_store; + if (config->pem_root_certs == nullptr) { + // Use default root certificates. 
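+    // These are resolved by DefaultSslRootStore::ComputePemRootCerts() (see
+    // ssl_utils.cc in this change), in order: the
+    // grpc_default_ssl_roots_file_path config, the SSL roots override
+    // callback, the OS trust store (unless grpc_not_use_system_ssl_roots is
+    // set), and finally the installed roots.pem fallback.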
+ pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); + if (pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + return nullptr; + } + root_store = grpc_core::DefaultSslRootStore::GetRootStore(); + } else { + pem_root_certs = config->pem_root_certs; + root_store = nullptr; + } + + grpc_core::RefCountedPtr c = + grpc_core::MakeRefCounted( + std::move(channel_creds), std::move(request_metadata_creds), config, + target_name, overridden_target_name); + const grpc_security_status result = c->InitializeHandshakerFactory( + config, pem_root_certs, root_store, ssl_session_cache); + if (result != GRPC_SECURITY_OK) { + return nullptr; + } + return c; +} + +grpc_core::RefCountedPtr +grpc_ssl_server_security_connector_create( + grpc_core::RefCountedPtr server_credentials) { + GPR_ASSERT(server_credentials != nullptr); + grpc_core::RefCountedPtr c = + grpc_core::MakeRefCounted( + std::move(server_credentials)); + const grpc_security_status retval = c->InitializeHandshakerFactory(); + if (retval != GRPC_SECURITY_OK) { + return nullptr; + } + return c; +} diff --git a/src/core/lib/security/security_connector/ssl_utils.cc b/src/core/lib/security/security_connector/ssl_utils.cc new file mode 100644 index 00000000..4fe689de --- /dev/null +++ b/src/core/lib/security/security_connector/ssl_utils.cc @@ -0,0 +1,611 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/security_connector/ssl_utils.h" + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/ext/transport/chttp2/alpn/alpn.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/security_connector/load_system_roots.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "src/core/tsi/ssl_transport_security.h" + +/* -- Constants. -- */ + +#if defined(GRPC_ROOT_PEM_PATH) +static const char* installed_roots_path = GRPC_ROOT_PEM_PATH; +#elif defined(INSTALL_PREFIX) +static const char* installed_roots_path = + INSTALL_PREFIX "/usr/share/grpc/roots.pem"; +#else +static const char* installed_roots_path = "/usr/share/grpc/roots.pem"; +#endif + +#ifndef TSI_OPENSSL_ALPN_SUPPORT +#define TSI_OPENSSL_ALPN_SUPPORT 1 +#endif + +/* -- Overridden default roots. -- */ + +static grpc_ssl_roots_override_callback ssl_roots_override_cb = nullptr; + +void grpc_set_ssl_roots_override_callback(grpc_ssl_roots_override_callback cb) { + ssl_roots_override_cb = cb; +} + +/* -- Cipher suites. -- */ + +static gpr_once cipher_suites_once = GPR_ONCE_INIT; +static const char* cipher_suites = nullptr; + +// All cipher suites for default are compliant with HTTP2. 
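+// The list below is only the compiled-in default; like other GPR global
+// config strings it can typically be overridden at run time, e.g. via the
+// GRPC_SSL_CIPHER_SUITES environment variable (command shown is illustrative
+// only):
+//   GRPC_SSL_CIPHER_SUITES="ECDHE-ECDSA-AES128-GCM-SHA256" ./my_grpc_server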
+GPR_GLOBAL_CONFIG_DEFINE_STRING( + grpc_ssl_cipher_suites, + "TLS_AES_128_GCM_SHA256:" + "TLS_AES_256_GCM_SHA384:" + "TLS_CHACHA20_POLY1305_SHA256:" + "ECDHE-ECDSA-AES128-GCM-SHA256:" + "ECDHE-ECDSA-AES256-GCM-SHA384:" + "ECDHE-RSA-AES128-GCM-SHA256:" + "ECDHE-RSA-AES256-GCM-SHA384", + "A colon separated list of cipher suites to use with OpenSSL") + +static void init_cipher_suites(void) { + grpc_core::UniquePtr value = + GPR_GLOBAL_CONFIG_GET(grpc_ssl_cipher_suites); + cipher_suites = value.release(); +} + +/* --- Util --- */ + +const char* grpc_get_ssl_cipher_suites(void) { + gpr_once_init(&cipher_suites_once, init_cipher_suites); + return cipher_suites; +} + +grpc_security_level grpc_tsi_security_level_string_to_enum( + const char* security_level) { + if (strcmp(security_level, "TSI_INTEGRITY_ONLY") == 0) { + return GRPC_INTEGRITY_ONLY; + } else if (strcmp(security_level, "TSI_PRIVACY_AND_INTEGRITY") == 0) { + return GRPC_PRIVACY_AND_INTEGRITY; + } + return GRPC_SECURITY_NONE; +} + +const char* grpc_security_level_to_string(grpc_security_level security_level) { + if (security_level == GRPC_PRIVACY_AND_INTEGRITY) { + return "GRPC_PRIVACY_AND_INTEGRITY"; + } else if (security_level == GRPC_INTEGRITY_ONLY) { + return "GRPC_INTEGRITY_ONLY"; + } + return "GRPC_SECURITY_NONE"; +} + +bool grpc_check_security_level(grpc_security_level channel_level, + grpc_security_level call_cred_level) { + return static_cast(channel_level) >= static_cast(call_cred_level); +} + +tsi_client_certificate_request_type +grpc_get_tsi_client_certificate_request_type( + grpc_ssl_client_certificate_request_type grpc_request_type) { + switch (grpc_request_type) { + case GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE: + return TSI_DONT_REQUEST_CLIENT_CERTIFICATE; + + case GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + return TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY; + + case GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY: + return TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY; + + case GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + return TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY; + + case GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY: + return TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; + + default: + return TSI_DONT_REQUEST_CLIENT_CERTIFICATE; + } +} + +tsi_tls_version grpc_get_tsi_tls_version(grpc_tls_version tls_version) { + switch (tls_version) { + case grpc_tls_version::TLS1_2: + return tsi_tls_version::TSI_TLS1_2; + case grpc_tls_version::TLS1_3: + return tsi_tls_version::TSI_TLS1_3; + default: + gpr_log(GPR_INFO, "Falling back to TLS 1.2."); + return tsi_tls_version::TSI_TLS1_2; + } +} + +grpc_error_handle grpc_ssl_check_alpn(const tsi_peer* peer) { +#if TSI_OPENSSL_ALPN_SUPPORT + /* Check the ALPN if ALPN is supported. */ + const tsi_peer_property* p = + tsi_peer_get_property_by_name(peer, TSI_SSL_ALPN_SELECTED_PROTOCOL); + if (p == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Cannot check peer: missing selected ALPN property."); + } + if (!grpc_chttp2_is_alpn_version_supported(p->value.data, p->value.length)) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Cannot check peer: invalid ALPN value."); + } +#endif /* TSI_OPENSSL_ALPN_SUPPORT */ + return GRPC_ERROR_NONE; +} + +grpc_error_handle grpc_ssl_check_peer_name(absl::string_view peer_name, + const tsi_peer* peer) { + /* Check the peer name if specified. 
*/ + if (!peer_name.empty() && !grpc_ssl_host_matches_name(peer, peer_name)) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Peer name ", peer_name, " is not in peer certificate")); + } + return GRPC_ERROR_NONE; +} + +bool grpc_ssl_check_call_host(absl::string_view host, + absl::string_view target_name, + absl::string_view overridden_target_name, + grpc_auth_context* auth_context, + grpc_error_handle* error) { + grpc_security_status status = GRPC_SECURITY_ERROR; + tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context); + if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK; + /* If the target name was overridden, then the original target_name was + 'checked' transitively during the previous peer check at the end of the + handshake. */ + if (!overridden_target_name.empty() && host == target_name) { + status = GRPC_SECURITY_OK; + } + if (status != GRPC_SECURITY_OK) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "call host does not match SSL server name"); + } + grpc_shallow_peer_destruct(&peer); + return true; +} + +const char** grpc_fill_alpn_protocol_strings(size_t* num_alpn_protocols) { + GPR_ASSERT(num_alpn_protocols != nullptr); + *num_alpn_protocols = grpc_chttp2_num_alpn_versions(); + const char** alpn_protocol_strings = static_cast( + gpr_malloc(sizeof(const char*) * (*num_alpn_protocols))); + for (size_t i = 0; i < *num_alpn_protocols; i++) { + alpn_protocol_strings[i] = grpc_chttp2_get_alpn_version_index(i); + } + return alpn_protocol_strings; +} + +int grpc_ssl_host_matches_name(const tsi_peer* peer, + absl::string_view peer_name) { + absl::string_view allocated_name; + absl::string_view ignored_port; + grpc_core::SplitHostPort(peer_name, &allocated_name, &ignored_port); + if (allocated_name.empty()) return 0; + + // IPv6 zone-id should not be included in comparisons. + const size_t zone_id = allocated_name.find('%'); + if (zone_id != absl::string_view::npos) { + allocated_name.remove_suffix(allocated_name.size() - zone_id); + } + return tsi_ssl_peer_matches_name(peer, allocated_name); +} + +int grpc_ssl_cmp_target_name(absl::string_view target_name, + absl::string_view other_target_name, + absl::string_view overridden_target_name, + absl::string_view other_overridden_target_name) { + int c = target_name.compare(other_target_name); + if (c != 0) return c; + return overridden_target_name.compare(other_overridden_target_name); +} + +static bool IsSpiffeId(absl::string_view uri) { + // Return false without logging for a non-spiffe uri scheme. + if (!absl::StartsWith(uri, "spiffe://")) { + return false; + }; + if (uri.size() > 2048) { + gpr_log(GPR_INFO, "Invalid SPIFFE ID: ID longer than 2048 bytes."); + return false; + } + std::vector splits = absl::StrSplit(uri, '/'); + if (splits.size() < 4 || splits[3] == "") { + gpr_log(GPR_INFO, "Invalid SPIFFE ID: workload id is empty."); + return false; + } + if (splits[2].size() > 255) { + gpr_log(GPR_INFO, "Invalid SPIFFE ID: domain longer than 255 characters."); + return false; + } + return true; +} + +grpc_core::RefCountedPtr grpc_ssl_peer_to_auth_context( + const tsi_peer* peer, const char* transport_security_type) { + size_t i; + const char* peer_identity_property_name = nullptr; + + /* The caller has checked the certificate type property. 
*/ + GPR_ASSERT(peer->property_count >= 1); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + grpc_auth_context_add_cstring_property( + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + transport_security_type); + const char* spiffe_data = nullptr; + size_t spiffe_length = 0; + int uri_count = 0; + bool has_spiffe_id = false; + for (i = 0; i < peer->property_count; i++) { + const tsi_peer_property* prop = &peer->properties[i]; + if (prop->name == nullptr) continue; + if (strcmp(prop->name, TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) { + /* If there is no subject alt name, have the CN as the identity. */ + if (peer_identity_property_name == nullptr) { + peer_identity_property_name = GRPC_X509_CN_PROPERTY_NAME; + } + grpc_auth_context_add_property(ctx.get(), GRPC_X509_CN_PROPERTY_NAME, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) { + peer_identity_property_name = GRPC_X509_SAN_PROPERTY_NAME; + grpc_auth_context_add_property(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx.get(), + GRPC_X509_PEM_CERT_PROPERTY_NAME, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_X509_PEM_CERT_CHAIN_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx.get(), + GRPC_X509_PEM_CERT_CHAIN_PROPERTY_NAME, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_SSL_SESSION_REUSED_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx.get(), + GRPC_SSL_SESSION_REUSED_PROPERTY, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_SECURITY_LEVEL_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property( + ctx.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_X509_DNS_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx.get(), GRPC_PEER_DNS_PROPERTY_NAME, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_X509_URI_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx.get(), GRPC_PEER_URI_PROPERTY_NAME, + prop->value.data, prop->value.length); + uri_count++; + absl::string_view spiffe_id(prop->value.data, prop->value.length); + if (IsSpiffeId(spiffe_id)) { + spiffe_data = prop->value.data; + spiffe_length = prop->value.length; + has_spiffe_id = true; + } + } else if (strcmp(prop->name, TSI_X509_EMAIL_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx.get(), GRPC_PEER_EMAIL_PROPERTY_NAME, + prop->value.data, prop->value.length); + } else if (strcmp(prop->name, TSI_X509_IP_PEER_PROPERTY) == 0) { + grpc_auth_context_add_property(ctx.get(), GRPC_PEER_IP_PROPERTY_NAME, + prop->value.data, prop->value.length); + } + } + if (peer_identity_property_name != nullptr) { + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( + ctx.get(), peer_identity_property_name) == 1); + } + // A valid SPIFFE certificate can only have exact one URI SAN field. 
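+  // For example (URI illustrative only), a peer whose single URI SAN is
+  // "spiffe://example.org/workload/wl-1" passes the IsSpiffeId() check above
+  // and gets a GRPC_PEER_SPIFFE_ID_PROPERTY_NAME property; a peer presenting
+  // several URI SANs does not.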
+ if (has_spiffe_id) { + if (uri_count == 1) { + GPR_ASSERT(spiffe_length > 0); + GPR_ASSERT(spiffe_data != nullptr); + grpc_auth_context_add_property(ctx.get(), + GRPC_PEER_SPIFFE_ID_PROPERTY_NAME, + spiffe_data, spiffe_length); + } else { + gpr_log(GPR_INFO, "Invalid SPIFFE ID: multiple URI SANs."); + } + } + return ctx; +} + +static void add_shallow_auth_property_to_peer(tsi_peer* peer, + const grpc_auth_property* prop, + const char* tsi_prop_name) { + tsi_peer_property* tsi_prop = &peer->properties[peer->property_count++]; + tsi_prop->name = const_cast(tsi_prop_name); + tsi_prop->value.data = prop->value; + tsi_prop->value.length = prop->value_length; +} + +tsi_peer grpc_shallow_peer_from_ssl_auth_context( + const grpc_auth_context* auth_context) { + size_t max_num_props = 0; + grpc_auth_property_iterator it; + const grpc_auth_property* prop; + tsi_peer peer; + memset(&peer, 0, sizeof(peer)); + + it = grpc_auth_context_property_iterator(auth_context); + while (grpc_auth_property_iterator_next(&it) != nullptr) max_num_props++; + + if (max_num_props > 0) { + peer.properties = static_cast( + gpr_malloc(max_num_props * sizeof(tsi_peer_property))); + it = grpc_auth_context_property_iterator(auth_context); + while ((prop = grpc_auth_property_iterator_next(&it)) != nullptr) { + if (strcmp(prop->name, GRPC_X509_SAN_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer( + &peer, prop, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY); + } else if (strcmp(prop->name, GRPC_X509_CN_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer( + &peer, prop, TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY); + } else if (strcmp(prop->name, GRPC_X509_PEM_CERT_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_X509_PEM_CERT_PROPERTY); + } else if (strcmp(prop->name, + GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_SECURITY_LEVEL_PEER_PROPERTY); + } else if (strcmp(prop->name, GRPC_X509_PEM_CERT_CHAIN_PROPERTY_NAME) == + 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_X509_PEM_CERT_CHAIN_PROPERTY); + } else if (strcmp(prop->name, GRPC_PEER_DNS_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_X509_DNS_PEER_PROPERTY); + } else if (strcmp(prop->name, GRPC_PEER_URI_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_X509_URI_PEER_PROPERTY); + } else if (strcmp(prop->name, GRPC_PEER_SPIFFE_ID_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_X509_URI_PEER_PROPERTY); + } else if (strcmp(prop->name, GRPC_PEER_EMAIL_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_X509_EMAIL_PEER_PROPERTY); + } else if (strcmp(prop->name, GRPC_PEER_IP_PROPERTY_NAME) == 0) { + add_shallow_auth_property_to_peer(&peer, prop, + TSI_X509_IP_PEER_PROPERTY); + } + } + } + return peer; +} + +void grpc_shallow_peer_destruct(tsi_peer* peer) { + if (peer->properties != nullptr) gpr_free(peer->properties); +} + +grpc_security_status grpc_ssl_tsi_client_handshaker_factory_init( + tsi_ssl_pem_key_cert_pair* pem_key_cert_pair, const char* pem_root_certs, + bool skip_server_certificate_verification, tsi_tls_version min_tls_version, + tsi_tls_version max_tls_version, tsi_ssl_session_cache* ssl_session_cache, + tsi_ssl_client_handshaker_factory** handshaker_factory) { + const char* root_certs; + const tsi_ssl_root_certs_store* root_store; + if (pem_root_certs == nullptr) { + gpr_log(GPR_INFO, + "No root certificates specified; use ones 
stored in system default " + "locations instead"); + // Use default root certificates. + root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); + if (root_certs == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + return GRPC_SECURITY_ERROR; + } + root_store = grpc_core::DefaultSslRootStore::GetRootStore(); + } else { + root_certs = pem_root_certs; + root_store = nullptr; + } + bool has_key_cert_pair = pem_key_cert_pair != nullptr && + pem_key_cert_pair->private_key != nullptr && + pem_key_cert_pair->cert_chain != nullptr; + tsi_ssl_client_handshaker_options options; + GPR_DEBUG_ASSERT(root_certs != nullptr); + options.pem_root_certs = root_certs; + options.root_store = root_store; + options.alpn_protocols = + grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + if (has_key_cert_pair) { + options.pem_key_cert_pair = pem_key_cert_pair; + } + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + options.skip_server_certificate_verification = + skip_server_certificate_verification; + options.min_tls_version = min_tls_version; + options.max_tls_version = max_tls_version; + const tsi_result result = + tsi_create_ssl_client_handshaker_factory_with_options(&options, + handshaker_factory); + gpr_free(options.alpn_protocols); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; +} + +grpc_security_status grpc_ssl_tsi_server_handshaker_factory_init( + tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs, + const char* pem_root_certs, + grpc_ssl_client_certificate_request_type client_certificate_request, + tsi_tls_version min_tls_version, tsi_tls_version max_tls_version, + tsi_ssl_server_handshaker_factory** handshaker_factory) { + size_t num_alpn_protocols = 0; + const char** alpn_protocol_strings = + grpc_fill_alpn_protocol_strings(&num_alpn_protocols); + tsi_ssl_server_handshaker_options options; + options.pem_key_cert_pairs = pem_key_cert_pairs; + options.num_key_cert_pairs = num_key_cert_pairs; + options.pem_client_root_certs = pem_root_certs; + options.client_certificate_request = + grpc_get_tsi_client_certificate_request_type(client_certificate_request); + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.alpn_protocols = alpn_protocol_strings; + options.num_alpn_protocols = static_cast(num_alpn_protocols); + options.min_tls_version = min_tls_version; + options.max_tls_version = max_tls_version; + const tsi_result result = + tsi_create_ssl_server_handshaker_factory_with_options(&options, + handshaker_factory); + gpr_free(alpn_protocol_strings); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; +} + +/* --- Ssl cache implementation. --- */ + +grpc_ssl_session_cache* grpc_ssl_session_cache_create_lru(size_t capacity) { + tsi_ssl_session_cache* cache = tsi_ssl_session_cache_create_lru(capacity); + return reinterpret_cast(cache); +} + +void grpc_ssl_session_cache_destroy(grpc_ssl_session_cache* cache) { + tsi_ssl_session_cache* tsi_cache = + reinterpret_cast(cache); + tsi_ssl_session_cache_unref(tsi_cache); +} + +static void* grpc_ssl_session_cache_arg_copy(void* p) { + tsi_ssl_session_cache* tsi_cache = + reinterpret_cast(p); + // destroy call below will unref the pointer. 
+ tsi_ssl_session_cache_ref(tsi_cache); + return p; +} + +static void grpc_ssl_session_cache_arg_destroy(void* p) { + tsi_ssl_session_cache* tsi_cache = + reinterpret_cast(p); + tsi_ssl_session_cache_unref(tsi_cache); +} + +static int grpc_ssl_session_cache_arg_cmp(void* p, void* q) { + return grpc_core::QsortCompare(p, q); +} + +grpc_arg grpc_ssl_session_cache_create_channel_arg( + grpc_ssl_session_cache* cache) { + static const grpc_arg_pointer_vtable vtable = { + grpc_ssl_session_cache_arg_copy, + grpc_ssl_session_cache_arg_destroy, + grpc_ssl_session_cache_arg_cmp, + }; + return grpc_channel_arg_pointer_create( + const_cast(GRPC_SSL_SESSION_CACHE_ARG), cache, &vtable); +} + +/* --- Default SSL root store implementation. --- */ + +namespace grpc_core { + +tsi_ssl_root_certs_store* DefaultSslRootStore::default_root_store_; +grpc_slice DefaultSslRootStore::default_pem_root_certs_; + +const tsi_ssl_root_certs_store* DefaultSslRootStore::GetRootStore() { + InitRootStore(); + return default_root_store_; +} + +const char* DefaultSslRootStore::GetPemRootCerts() { + InitRootStore(); + return GRPC_SLICE_IS_EMPTY(default_pem_root_certs_) + ? nullptr + : reinterpret_cast + GRPC_SLICE_START_PTR(default_pem_root_certs_); +} + +grpc_slice DefaultSslRootStore::ComputePemRootCerts() { + grpc_slice result = grpc_empty_slice(); + const bool not_use_system_roots = + GPR_GLOBAL_CONFIG_GET(grpc_not_use_system_ssl_roots); + // First try to load the roots from the configuration. + grpc_core::UniquePtr default_root_certs_path = + GPR_GLOBAL_CONFIG_GET(grpc_default_ssl_roots_file_path); + if (strlen(default_root_certs_path.get()) > 0) { + GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(default_root_certs_path.get(), 1, &result)); + } + // Try overridden roots if needed. + grpc_ssl_roots_override_result ovrd_res = GRPC_SSL_ROOTS_OVERRIDE_FAIL; + if (GRPC_SLICE_IS_EMPTY(result) && ssl_roots_override_cb != nullptr) { + char* pem_root_certs = nullptr; + ovrd_res = ssl_roots_override_cb(&pem_root_certs); + if (ovrd_res == GRPC_SSL_ROOTS_OVERRIDE_OK) { + GPR_ASSERT(pem_root_certs != nullptr); + result = grpc_slice_from_copied_buffer( + pem_root_certs, + strlen(pem_root_certs) + 1); // nullptr terminator. + } + gpr_free(pem_root_certs); + } + // Try loading roots from OS trust store if flag is enabled. + if (GRPC_SLICE_IS_EMPTY(result) && !not_use_system_roots) { + result = LoadSystemRootCerts(); + } + // Fallback to roots manually shipped with gRPC. + if (GRPC_SLICE_IS_EMPTY(result) && + ovrd_res != GRPC_SSL_ROOTS_OVERRIDE_FAIL_PERMANENTLY) { + GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(installed_roots_path, 1, &result)); + } + return result; +} + +void DefaultSslRootStore::InitRootStore() { + static gpr_once once = GPR_ONCE_INIT; + gpr_once_init(&once, DefaultSslRootStore::InitRootStoreOnce); +} + +void DefaultSslRootStore::InitRootStoreOnce() { + default_pem_root_certs_ = ComputePemRootCerts(); + if (!GRPC_SLICE_IS_EMPTY(default_pem_root_certs_)) { + default_root_store_ = + tsi_ssl_root_certs_store_create(reinterpret_cast( + GRPC_SLICE_START_PTR(default_pem_root_certs_))); + } +} + +} // namespace grpc_core diff --git a/src/core/lib/security/security_connector/ssl_utils_config.cc b/src/core/lib/security/security_connector/ssl_utils_config.cc new file mode 100644 index 00000000..2d056a78 --- /dev/null +++ b/src/core/lib/security/security_connector/ssl_utils_config.cc @@ -0,0 +1,32 @@ +/* + * + * Copyright 2019 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/security_connector/ssl_utils_config.h" + +/** Config variable that points to the default SSL roots file. This file + must be a PEM encoded file with all the roots such as the one that can be + downloaded from https://pki.google.com/roots.pem. */ +GPR_GLOBAL_CONFIG_DEFINE_STRING(grpc_default_ssl_roots_file_path, "", + "Path to the default SSL roots file."); + +/** Config variable used as a flag to enable/disable loading system root + certificates from the OS trust store. */ +GPR_GLOBAL_CONFIG_DEFINE_BOOL(grpc_not_use_system_ssl_roots, false, + "Disable loading system root certificates."); diff --git a/src/core/lib/security/security_connector/tls/tls_security_connector.cc b/src/core/lib/security/security_connector/tls/tls_security_connector.cc new file mode 100644 index 00000000..760c610a --- /dev/null +++ b/src/core/lib/security/security_connector/tls/tls_security_connector.cc @@ -0,0 +1,667 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/security_connector/tls/tls_security_connector.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/security/credentials/ssl/ssl_credentials.h" +#include "src/core/lib/security/credentials/tls/tls_credentials.h" +#include "src/core/lib/security/security_connector/ssl_utils.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/transport.h" +#include "src/core/tsi/ssl_transport_security.h" +#include "src/core/tsi/transport_security.h" + +namespace grpc_core { + +namespace { + +tsi_ssl_pem_key_cert_pair* ConvertToTsiPemKeyCertPair( + const PemKeyCertPairList& cert_pair_list) { + tsi_ssl_pem_key_cert_pair* tsi_pairs = nullptr; + size_t num_key_cert_pairs = cert_pair_list.size(); + if (num_key_cert_pairs > 0) { + GPR_ASSERT(cert_pair_list.data() != nullptr); + tsi_pairs = static_cast( + gpr_zalloc(num_key_cert_pairs * sizeof(tsi_ssl_pem_key_cert_pair))); + } + for (size_t i = 0; i < num_key_cert_pairs; i++) { + GPR_ASSERT(!cert_pair_list[i].private_key().empty()); + GPR_ASSERT(!cert_pair_list[i].cert_chain().empty()); + tsi_pairs[i].cert_chain = + gpr_strdup(cert_pair_list[i].cert_chain().c_str()); + tsi_pairs[i].private_key = + gpr_strdup(cert_pair_list[i].private_key().c_str()); + } + return tsi_pairs; +} + +} // namespace + +// -------------------channel security connector------------------- +RefCountedPtr +TlsChannelSecurityConnector::CreateTlsChannelSecurityConnector( + RefCountedPtr channel_creds, + RefCountedPtr options, + RefCountedPtr request_metadata_creds, + const char* target_name, const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache) { + if (channel_creds == nullptr) { + gpr_log(GPR_ERROR, + "channel_creds is nullptr in " + "TlsChannelSecurityConnectorCreate()"); + return nullptr; + } + if (options == nullptr) { + gpr_log(GPR_ERROR, + "options is nullptr in " + "TlsChannelSecurityConnectorCreate()"); + return nullptr; + } + if (target_name == nullptr) { + gpr_log(GPR_ERROR, + "target_name is nullptr in " + "TlsChannelSecurityConnectorCreate()"); + return nullptr; + } + return MakeRefCounted( + std::move(channel_creds), std::move(options), + std::move(request_metadata_creds), target_name, overridden_target_name, + ssl_session_cache); +} + +TlsChannelSecurityConnector::TlsChannelSecurityConnector( + RefCountedPtr channel_creds, + RefCountedPtr options, + RefCountedPtr request_metadata_creds, + const char* target_name, const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache) + : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + options_(std::move(options)), + overridden_target_name_( + overridden_target_name == nullptr ? "" : overridden_target_name), + ssl_session_cache_(ssl_session_cache) { + if (ssl_session_cache_ != nullptr) { + tsi_ssl_session_cache_ref(ssl_session_cache_); + } + check_arg_ = ServerAuthorizationCheckArgCreate(this); + absl::string_view host; + absl::string_view port; + SplitHostPort(target_name, &host, &port); + target_name_ = std::string(host); + // Create a watcher. + auto watcher_ptr = absl::make_unique(this); + certificate_watcher_ = watcher_ptr.get(); + // Register the watcher with the distributor. 
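+  // Once registered, the distributor invokes the watcher's
+  // OnCertificatesChanged()/OnError() callbacks (defined further below) as
+  // the watched root and/or identity credentials become available or change.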
+ grpc_tls_certificate_distributor* distributor = + options_->certificate_distributor(); + absl::optional watched_root_cert_name; + if (options_->watch_root_cert()) { + watched_root_cert_name = options_->root_cert_name(); + } + absl::optional watched_identity_cert_name; + if (options_->watch_identity_pair()) { + watched_identity_cert_name = options_->identity_cert_name(); + } + // We will use the root certs stored in system default locations if not + // watching root certs on the client side. We will handle this case + // differently here, because "watching a default roots without the identity + // certs" is a valid case(and hence we will need to call + // OnCertificatesChanged), but it requires nothing from the provider, and + // hence no need to register the watcher. + bool use_default_roots = !options_->watch_root_cert(); + if (use_default_roots && !options_->watch_identity_pair()) { + watcher_ptr->OnCertificatesChanged(absl::nullopt, absl::nullopt); + } else { + distributor->WatchTlsCertificates(std::move(watcher_ptr), + watched_root_cert_name, + watched_identity_cert_name); + } +} + +TlsChannelSecurityConnector::~TlsChannelSecurityConnector() { + if (ssl_session_cache_ != nullptr) { + tsi_ssl_session_cache_unref(ssl_session_cache_); + } + // Cancel all the watchers. + grpc_tls_certificate_distributor* distributor = + options_->certificate_distributor(); + if (distributor != nullptr) { + distributor->CancelTlsCertificatesWatch(certificate_watcher_); + } + if (client_handshaker_factory_ != nullptr) { + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); + } + if (check_arg_ != nullptr) { + ServerAuthorizationCheckArgDestroy(check_arg_); + } +} + +void TlsChannelSecurityConnector::add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* /*interested_parties*/, + HandshakeManager* handshake_mgr) { + MutexLock lock(&mu_); + tsi_handshaker* tsi_hs = nullptr; + if (client_handshaker_factory_ != nullptr) { + // Instantiate TSI handshaker. + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + client_handshaker_factory_, + overridden_target_name_.empty() ? target_name_.c_str() + : overridden_target_name_.c_str(), + &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + } + } + // If tsi_hs is null, this will add a failing handshaker. + handshake_mgr->Add(SecurityHandshakerCreate(tsi_hs, this, args)); +} + +void TlsChannelSecurityConnector::check_peer( + tsi_peer peer, grpc_endpoint* /*ep*/, + RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + const char* target_name = overridden_target_name_.empty() + ? target_name_.c_str() + : overridden_target_name_.c_str(); + grpc_error_handle error = grpc_ssl_check_alpn(&peer); + if (error != GRPC_ERROR_NONE) { + ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + tsi_peer_destruct(&peer); + return; + } + *auth_context = + grpc_ssl_peer_to_auth_context(&peer, GRPC_TLS_TRANSPORT_SECURITY_TYPE); + if (options_->server_verification_option() == GRPC_TLS_SERVER_VERIFICATION) { + /* Do the default host name check if specifying the target name. */ + error = internal::TlsCheckHostName(target_name, &peer); + if (error != GRPC_ERROR_NONE) { + ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + tsi_peer_destruct(&peer); + return; + } + } + /* Do the custom server authorization check, if specified by the user. 
*/ + const grpc_tls_server_authorization_check_config* config = + options_->server_authorization_check_config(); + /* If server authorization config is not null, use it to perform + * server authorization check. */ + if (config != nullptr) { + const tsi_peer_property* p = + tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); + if (p == nullptr) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Cannot check peer: missing pem cert property."); + } else { + char* peer_pem = static_cast(gpr_zalloc(p->value.length + 1)); + memcpy(peer_pem, p->value.data, p->value.length); + GPR_ASSERT(check_arg_ != nullptr); + check_arg_->peer_cert = check_arg_->peer_cert == nullptr + ? gpr_strdup(peer_pem) + : check_arg_->peer_cert; + check_arg_->target_name = check_arg_->target_name == nullptr + ? gpr_strdup(target_name) + : check_arg_->target_name; + on_peer_checked_ = on_peer_checked; + gpr_free(peer_pem); + const tsi_peer_property* chain = tsi_peer_get_property_by_name( + &peer, TSI_X509_PEM_CERT_CHAIN_PROPERTY); + if (chain != nullptr) { + char* peer_pem_chain = + static_cast(gpr_zalloc(chain->value.length + 1)); + memcpy(peer_pem_chain, chain->value.data, chain->value.length); + check_arg_->peer_cert_full_chain = + check_arg_->peer_cert_full_chain == nullptr + ? gpr_strdup(peer_pem_chain) + : check_arg_->peer_cert_full_chain; + gpr_free(peer_pem_chain); + } + // TODO(zhenlian) - This should be cleaned up as part of the custom + // verification changes. Fill in the subject alternative names + std::vector subject_alternative_names; + for (size_t i = 0; i < peer.property_count; i++) { + const tsi_peer_property* prop = &peer.properties[i]; + if (strcmp(prop->name, + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) { + char* san = new char[prop->value.length + 1]; + memcpy(san, prop->value.data, prop->value.length); + san[prop->value.length] = '\0'; + subject_alternative_names.emplace_back(san); + } + } + if (check_arg_->subject_alternative_names != nullptr) { + for (size_t i = 0; i < check_arg_->subject_alternative_names_size; + ++i) { + delete[] check_arg_->subject_alternative_names[i]; + } + delete[] check_arg_->subject_alternative_names; + } + check_arg_->subject_alternative_names_size = + subject_alternative_names.size(); + if (subject_alternative_names.empty()) { + check_arg_->subject_alternative_names = nullptr; + } else { + check_arg_->subject_alternative_names = + new char*[check_arg_->subject_alternative_names_size]; + for (size_t i = 0; i < check_arg_->subject_alternative_names_size; + ++i) { + check_arg_->subject_alternative_names[i] = + subject_alternative_names[i]; + } + } + int callback_status = config->Schedule(check_arg_); + /* Server authorization check is handled asynchronously. */ + if (callback_status) { + tsi_peer_destruct(&peer); + return; + } + /* Server authorization check is handled synchronously. 
*/ + error = ProcessServerAuthorizationCheckResult(check_arg_); + } + } + ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); + tsi_peer_destruct(&peer); +} + +int TlsChannelSecurityConnector::cmp( + const grpc_security_connector* other_sc) const { + auto* other = reinterpret_cast(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) { + return c; + } + return grpc_ssl_cmp_target_name( + target_name_.c_str(), other->target_name_.c_str(), + overridden_target_name_.c_str(), other->overridden_target_name_.c_str()); +} + +bool TlsChannelSecurityConnector::check_call_host( + absl::string_view host, grpc_auth_context* auth_context, + grpc_closure* /*on_call_host_checked*/, grpc_error_handle* error) { + if (options_->server_verification_option() == + GRPC_TLS_SKIP_HOSTNAME_VERIFICATION || + options_->server_verification_option() == + GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION) { + return true; + } + return grpc_ssl_check_call_host(host, target_name_.c_str(), + overridden_target_name_.c_str(), auth_context, + error); +} + +void TlsChannelSecurityConnector::cancel_check_call_host( + grpc_closure* /*on_call_host_checked*/, grpc_error_handle error) { + GRPC_ERROR_UNREF(error); +} + +void TlsChannelSecurityConnector::TlsChannelCertificateWatcher:: + OnCertificatesChanged(absl::optional root_certs, + absl::optional key_cert_pairs) { + GPR_ASSERT(security_connector_ != nullptr); + MutexLock lock(&security_connector_->mu_); + if (root_certs.has_value()) { + security_connector_->pem_root_certs_ = root_certs; + } + if (key_cert_pairs.has_value()) { + security_connector_->pem_key_cert_pair_list_ = std::move(key_cert_pairs); + } + const bool root_ready = !security_connector_->options_->watch_root_cert() || + security_connector_->pem_root_certs_.has_value(); + const bool identity_ready = + !security_connector_->options_->watch_identity_pair() || + security_connector_->pem_key_cert_pair_list_.has_value(); + if (root_ready && identity_ready) { + if (security_connector_->UpdateHandshakerFactoryLocked() != + GRPC_SECURITY_OK) { + gpr_log(GPR_ERROR, "Update handshaker factory failed."); + } + } +} + +// TODO(ZhenLian): implement the logic to signal waiting handshakers once +// BlockOnInitialCredentialHandshaker is implemented. +void TlsChannelSecurityConnector::TlsChannelCertificateWatcher::OnError( + grpc_error_handle root_cert_error, grpc_error_handle identity_cert_error) { + if (root_cert_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "TlsChannelCertificateWatcher getting root_cert_error: %s", + grpc_error_std_string(root_cert_error).c_str()); + } + if (identity_cert_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "TlsChannelCertificateWatcher getting identity_cert_error: %s", + grpc_error_std_string(identity_cert_error).c_str()); + } + GRPC_ERROR_UNREF(root_cert_error); + GRPC_ERROR_UNREF(identity_cert_error); +} + +// TODO(ZhenLian): implement the logic to signal waiting handshakers once +// BlockOnInitialCredentialHandshaker is implemented. +grpc_security_status +TlsChannelSecurityConnector::UpdateHandshakerFactoryLocked() { + bool skip_server_certificate_verification = + options_->server_verification_option() == + GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION; + /* Free the client handshaker factory if exists. 
*/ + if (client_handshaker_factory_ != nullptr) { + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); + } + std::string pem_root_certs; + if (pem_root_certs_.has_value()) { + // TODO(ZhenLian): update the underlying TSI layer to use C++ types like + // std::string and absl::string_view to avoid making another copy here. + pem_root_certs = std::string(*pem_root_certs_); + } + tsi_ssl_pem_key_cert_pair* pem_key_cert_pair = nullptr; + if (pem_key_cert_pair_list_.has_value()) { + pem_key_cert_pair = ConvertToTsiPemKeyCertPair(*pem_key_cert_pair_list_); + } + bool use_default_roots = !options_->watch_root_cert(); + grpc_security_status status = grpc_ssl_tsi_client_handshaker_factory_init( + pem_key_cert_pair, + pem_root_certs.empty() || use_default_roots ? nullptr + : pem_root_certs.c_str(), + skip_server_certificate_verification, + grpc_get_tsi_tls_version(options_->min_tls_version()), + grpc_get_tsi_tls_version(options_->max_tls_version()), ssl_session_cache_, + &client_handshaker_factory_); + /* Free memory. */ + if (pem_key_cert_pair != nullptr) { + grpc_tsi_ssl_pem_key_cert_pairs_destroy(pem_key_cert_pair, 1); + } + return status; +} + +void TlsChannelSecurityConnector::ServerAuthorizationCheckDone( + grpc_tls_server_authorization_check_arg* arg) { + GPR_ASSERT(arg != nullptr); + ExecCtx exec_ctx; + grpc_error_handle error = ProcessServerAuthorizationCheckResult(arg); + TlsChannelSecurityConnector* connector = + static_cast(arg->cb_user_data); + ExecCtx::Run(DEBUG_LOCATION, connector->on_peer_checked_, error); +} + +grpc_error_handle +TlsChannelSecurityConnector::ProcessServerAuthorizationCheckResult( + grpc_tls_server_authorization_check_arg* arg) { + grpc_error_handle error = GRPC_ERROR_NONE; + /* Server authorization check is cancelled by caller. */ + if (arg->status == GRPC_STATUS_CANCELLED) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Server authorization check is cancelled by the caller " + "with error: ", + arg->error_details->error_details())); + } else if (arg->status == GRPC_STATUS_OK) { + /* Server authorization check completed successfully but returned check + * failure. */ + if (!arg->success) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Server authorization check failed with error: ", + arg->error_details->error_details())); + } + /* Server authorization check did not complete correctly. 
*/ + } else { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Server authorization check did not finish correctly with error: ", + arg->error_details->error_details())); + } + return error; +} + +grpc_tls_server_authorization_check_arg* +TlsChannelSecurityConnector::ServerAuthorizationCheckArgCreate( + void* user_data) { + grpc_tls_server_authorization_check_arg* arg = + new grpc_tls_server_authorization_check_arg(); + arg->target_name = nullptr; + arg->peer_cert = nullptr; + arg->peer_cert_full_chain = nullptr; + arg->subject_alternative_names = nullptr; + arg->subject_alternative_names_size = 0; + arg->error_details = new grpc_tls_error_details(); + arg->cb = ServerAuthorizationCheckDone; + arg->cb_user_data = user_data; + arg->status = GRPC_STATUS_OK; + return arg; +} + +void TlsChannelSecurityConnector::ServerAuthorizationCheckArgDestroy( + grpc_tls_server_authorization_check_arg* arg) { + if (arg == nullptr) { + return; + } + gpr_free(const_cast(arg->target_name)); + gpr_free(const_cast(arg->peer_cert)); + gpr_free(const_cast(arg->peer_cert_full_chain)); + for (size_t i = 0; i < arg->subject_alternative_names_size; ++i) { + delete[] arg->subject_alternative_names[i]; + } + delete[] arg->subject_alternative_names; + delete arg->error_details; + if (arg->destroy_context != nullptr) { + arg->destroy_context(arg->context); + } + delete arg; +} + +// -------------------server security connector------------------- +RefCountedPtr +TlsServerSecurityConnector::CreateTlsServerSecurityConnector( + RefCountedPtr server_creds, + RefCountedPtr options) { + if (server_creds == nullptr) { + gpr_log(GPR_ERROR, + "server_creds is nullptr in " + "TlsServerSecurityConnectorCreate()"); + return nullptr; + } + if (options == nullptr) { + gpr_log(GPR_ERROR, + "options is nullptr in " + "TlsServerSecurityConnectorCreate()"); + return nullptr; + } + return MakeRefCounted(std::move(server_creds), + std::move(options)); +} + +TlsServerSecurityConnector::TlsServerSecurityConnector( + RefCountedPtr server_creds, + RefCountedPtr options) + : grpc_server_security_connector(GRPC_SSL_URL_SCHEME, + std::move(server_creds)), + options_(std::move(options)) { + // Create a watcher. + auto watcher_ptr = absl::make_unique(this); + certificate_watcher_ = watcher_ptr.get(); + // Register the watcher with the distributor. + grpc_tls_certificate_distributor* distributor = + options_->certificate_distributor(); + absl::optional watched_root_cert_name; + if (options_->watch_root_cert()) { + watched_root_cert_name = options_->root_cert_name(); + } + absl::optional watched_identity_cert_name; + if (options_->watch_identity_pair()) { + watched_identity_cert_name = options_->identity_cert_name(); + } + // Server side won't use default system roots at any time. + distributor->WatchTlsCertificates(std::move(watcher_ptr), + watched_root_cert_name, + watched_identity_cert_name); +} + +TlsServerSecurityConnector::~TlsServerSecurityConnector() { + // Cancel all the watchers. 
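+  // Unlike the channel connector, the server connector registers its watcher
+  // unconditionally in the constructor (it never falls back to default
+  // system roots), so the watch is cancelled unconditionally here.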
+ grpc_tls_certificate_distributor* distributor = + options_->certificate_distributor(); + distributor->CancelTlsCertificatesWatch(certificate_watcher_); + if (server_handshaker_factory_ != nullptr) { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } +} + +void TlsServerSecurityConnector::add_handshakers( + const grpc_channel_args* args, grpc_pollset_set* /*interested_parties*/, + HandshakeManager* handshake_mgr) { + MutexLock lock(&mu_); + tsi_handshaker* tsi_hs = nullptr; + if (server_handshaker_factory_ != nullptr) { + // Instantiate TSI handshaker. + tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( + server_handshaker_factory_, &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + } + } + // If tsi_hs is null, this will add a failing handshaker. + handshake_mgr->Add(SecurityHandshakerCreate(tsi_hs, this, args)); +} + +void TlsServerSecurityConnector::check_peer( + tsi_peer peer, grpc_endpoint* /*ep*/, + RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + grpc_error_handle error = grpc_ssl_check_alpn(&peer); + *auth_context = + grpc_ssl_peer_to_auth_context(&peer, GRPC_TLS_TRANSPORT_SECURITY_TYPE); + tsi_peer_destruct(&peer); + ExecCtx::Run(DEBUG_LOCATION, on_peer_checked, error); +} + +int TlsServerSecurityConnector::cmp( + const grpc_security_connector* other) const { + return server_security_connector_cmp( + static_cast(other)); +} + +void TlsServerSecurityConnector::TlsServerCertificateWatcher:: + OnCertificatesChanged(absl::optional root_certs, + absl::optional key_cert_pairs) { + GPR_ASSERT(security_connector_ != nullptr); + MutexLock lock(&security_connector_->mu_); + if (root_certs.has_value()) { + security_connector_->pem_root_certs_ = root_certs; + } + if (key_cert_pairs.has_value()) { + security_connector_->pem_key_cert_pair_list_ = std::move(key_cert_pairs); + } + bool root_being_watched = security_connector_->options_->watch_root_cert(); + bool root_has_value = security_connector_->pem_root_certs_.has_value(); + bool identity_being_watched = + security_connector_->options_->watch_identity_pair(); + bool identity_has_value = + security_connector_->pem_key_cert_pair_list_.has_value(); + if ((root_being_watched && root_has_value && identity_being_watched && + identity_has_value) || + (root_being_watched && root_has_value && !identity_being_watched) || + (!root_being_watched && identity_being_watched && identity_has_value)) { + if (security_connector_->UpdateHandshakerFactoryLocked() != + GRPC_SECURITY_OK) { + gpr_log(GPR_ERROR, "Update handshaker factory failed."); + } + } +} + +// TODO(ZhenLian): implement the logic to signal waiting handshakers once +// BlockOnInitialCredentialHandshaker is implemented. +void TlsServerSecurityConnector::TlsServerCertificateWatcher::OnError( + grpc_error_handle root_cert_error, grpc_error_handle identity_cert_error) { + if (root_cert_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "TlsServerCertificateWatcher getting root_cert_error: %s", + grpc_error_std_string(root_cert_error).c_str()); + } + if (identity_cert_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "TlsServerCertificateWatcher getting identity_cert_error: %s", + grpc_error_std_string(identity_cert_error).c_str()); + } + GRPC_ERROR_UNREF(root_cert_error); + GRPC_ERROR_UNREF(identity_cert_error); +} + +// TODO(ZhenLian): implement the logic to signal waiting handshakers once +// BlockOnInitialCredentialHandshaker is implemented. 
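+// Rebuilds server_handshaker_factory_ from the most recently delivered
+// credentials: identity key/cert pairs are required, root certs are optional
+// (they are only needed for verifying client certificates). Called under mu_
+// from OnCertificatesChanged() above.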
+grpc_security_status +TlsServerSecurityConnector::UpdateHandshakerFactoryLocked() { + /* Free the server handshaker factory if exists. */ + if (server_handshaker_factory_ != nullptr) { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } + // The identity certs on the server side shouldn't be empty. + GPR_ASSERT(pem_key_cert_pair_list_.has_value()); + GPR_ASSERT(!(*pem_key_cert_pair_list_).empty()); + std::string pem_root_certs; + if (pem_root_certs_.has_value()) { + // TODO(ZhenLian): update the underlying TSI layer to use C++ types like + // std::string and absl::string_view to avoid making another copy here. + pem_root_certs = std::string(*pem_root_certs_); + } + tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr; + pem_key_cert_pairs = ConvertToTsiPemKeyCertPair(*pem_key_cert_pair_list_); + size_t num_key_cert_pairs = (*pem_key_cert_pair_list_).size(); + grpc_security_status status = grpc_ssl_tsi_server_handshaker_factory_init( + pem_key_cert_pairs, num_key_cert_pairs, + pem_root_certs.empty() ? nullptr : pem_root_certs.c_str(), + options_->cert_request_type(), + grpc_get_tsi_tls_version(options_->min_tls_version()), + grpc_get_tsi_tls_version(options_->max_tls_version()), + &server_handshaker_factory_); + /* Free memory. */ + grpc_tsi_ssl_pem_key_cert_pairs_destroy(pem_key_cert_pairs, + num_key_cert_pairs); + return status; +} + +namespace internal { + +grpc_error_handle TlsCheckHostName(const char* peer_name, + const tsi_peer* peer) { + /* Check the peer name if specified. */ + if (peer_name != nullptr && !grpc_ssl_host_matches_name(peer, peer_name)) { + return GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Peer name ", peer_name, " is not in peer certificate")); + } + return GRPC_ERROR_NONE; +} + +} // namespace internal + +} // namespace grpc_core diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc new file mode 100644 index 00000000..fdf10e8f --- /dev/null +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -0,0 +1,473 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/security/security_connector/ssl_utils.h" +#include "src/core/lib/security/transport/auth_filters.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/transport/static_metadata.h" + +#define MAX_CREDENTIALS_METADATA_COUNT 4 + +namespace { + +/* We can have a per-channel credentials. 
*/ +struct channel_data { + channel_data(grpc_channel_security_connector* security_connector, + grpc_auth_context* auth_context) + : security_connector( + security_connector->Ref(DEBUG_LOCATION, "client_auth_filter")), + auth_context(auth_context->Ref(DEBUG_LOCATION, "client_auth_filter")) {} + ~channel_data() { + security_connector.reset(DEBUG_LOCATION, "client_auth_filter"); + auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); + } + + grpc_core::RefCountedPtr security_connector; + grpc_core::RefCountedPtr auth_context; +}; + +/* We can have a per-call credentials. */ +struct call_data { + call_data(grpc_call_element* elem, const grpc_call_element_args& args) + : owning_call(args.call_stack), call_combiner(args.call_combiner) { + channel_data* chand = static_cast(elem->channel_data); + GPR_ASSERT(args.context != nullptr); + if (args.context[GRPC_CONTEXT_SECURITY].value == nullptr) { + args.context[GRPC_CONTEXT_SECURITY].value = + grpc_client_security_context_create(args.arena, /*creds=*/nullptr); + args.context[GRPC_CONTEXT_SECURITY].destroy = + grpc_client_security_context_destroy; + } + grpc_client_security_context* sec_ctx = + static_cast( + args.context[GRPC_CONTEXT_SECURITY].value); + sec_ctx->auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); + sec_ctx->auth_context = + chand->auth_context->Ref(DEBUG_LOCATION, "client_auth_filter"); + } + + // This method is technically the dtor of this class. However, since + // `get_request_metadata_cancel_closure` can run in parallel to + // `destroy_call_elem`, we cannot call the dtor in them. Otherwise, + // fields will be accessed after calling dtor, and msan correctly complains + // that the memory is not initialized. + void destroy() { + grpc_credentials_mdelem_array_destroy(&md_array); + creds.reset(); + grpc_slice_unref_internal(host); + grpc_slice_unref_internal(method); + grpc_auth_metadata_context_reset(&auth_md_context); + } + + grpc_call_stack* owning_call; + grpc_core::CallCombiner* call_combiner; + grpc_core::RefCountedPtr creds; + grpc_slice host = grpc_empty_slice(); + grpc_slice method = grpc_empty_slice(); + /* pollset{_set} bound to this call; if we need to make external + network requests, they should be done under a pollset added to this + pollset_set so that work can progress when this call wants work to progress + */ + grpc_polling_entity* pollent = nullptr; + grpc_credentials_mdelem_array md_array; + grpc_linked_mdelem md_links[MAX_CREDENTIALS_METADATA_COUNT] = {}; + grpc_auth_metadata_context auth_md_context = + grpc_auth_metadata_context(); // Zero-initialize the C struct. 
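  // [Editor's note] auth_md_context above is filled in later by
  // grpc_auth_metadata_context_build() (defined below). A worked example of
  // the values it ends up holding, assuming GRPC_SSL_URL_SCHEME resolves to
  // the "https" scheme used by the SSL/TLS connectors:
  //
  //   :path      = "/helloworld.Greeter/SayHello"
  //   :authority = "example.com:443"
  //   url_scheme = "https"
  //
  //   => service_url          = "https://example.com/helloworld.Greeter"
  //      method_name          = "SayHello"
  //      channel_auth_context = a ref taken on the channel's auth context
  //
  // The default TLS port 443 is stripped from the host, and the method name
  // is split off at the last '/' of the :path value.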
+ grpc_closure async_result_closure; + grpc_closure check_call_host_cancel_closure; + grpc_closure get_request_metadata_cancel_closure; +}; + +} // namespace + +void grpc_auth_metadata_context_copy(grpc_auth_metadata_context* from, + grpc_auth_metadata_context* to) { + grpc_auth_metadata_context_reset(to); + to->channel_auth_context = from->channel_auth_context; + if (to->channel_auth_context != nullptr) { + const_cast(to->channel_auth_context) + ->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context_copy") + .release(); + } + to->service_url = gpr_strdup(from->service_url); + to->method_name = gpr_strdup(from->method_name); +} + +void grpc_auth_metadata_context_reset( + grpc_auth_metadata_context* auth_md_context) { + if (auth_md_context->service_url != nullptr) { + gpr_free(const_cast(auth_md_context->service_url)); + auth_md_context->service_url = nullptr; + } + if (auth_md_context->method_name != nullptr) { + gpr_free(const_cast(auth_md_context->method_name)); + auth_md_context->method_name = nullptr; + } + if (auth_md_context->channel_auth_context != nullptr) { + const_cast(auth_md_context->channel_auth_context) + ->Unref(DEBUG_LOCATION, "grpc_auth_metadata_context"); + auth_md_context->channel_auth_context = nullptr; + } +} + +static void add_error(grpc_error_handle* combined, grpc_error_handle error) { + if (error == GRPC_ERROR_NONE) return; + if (*combined == GRPC_ERROR_NONE) { + *combined = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Client auth metadata plugin error"); + } + *combined = grpc_error_add_child(*combined, error); +} + +static void on_credentials_metadata(void* arg, grpc_error_handle input_error) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + grpc_call_element* elem = + static_cast(batch->handler_private.extra_arg); + call_data* calld = static_cast(elem->call_data); + grpc_auth_metadata_context_reset(&calld->auth_md_context); + grpc_error_handle error = GRPC_ERROR_REF(input_error); + if (error == GRPC_ERROR_NONE) { + GPR_ASSERT(calld->md_array.size <= MAX_CREDENTIALS_METADATA_COUNT); + GPR_ASSERT(batch->send_initial_metadata); + grpc_metadata_batch* mdb = + batch->payload->send_initial_metadata.send_initial_metadata; + for (size_t i = 0; i < calld->md_array.size; ++i) { + add_error(&error, grpc_metadata_batch_add_tail( + mdb, &calld->md_links[i], + GRPC_MDELEM_REF(calld->md_array.md[i]))); + } + } + if (error == GRPC_ERROR_NONE) { + grpc_call_next_op(elem, batch); + } else { + error = grpc_error_set_int(error, GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAVAILABLE); + grpc_transport_stream_op_batch_finish_with_failure(batch, error, + calld->call_combiner); + } + GRPC_CALL_STACK_UNREF(calld->owning_call, "get_request_metadata"); +} + +void grpc_auth_metadata_context_build( + const char* url_scheme, const grpc_slice& call_host, + const grpc_slice& call_method, grpc_auth_context* auth_context, + grpc_auth_metadata_context* auth_md_context) { + char* service = grpc_slice_to_c_string(call_method); + char* last_slash = strrchr(service, '/'); + char* method_name = nullptr; + char* service_url = nullptr; + grpc_auth_metadata_context_reset(auth_md_context); + if (last_slash == nullptr) { + gpr_log(GPR_ERROR, "No '/' found in fully qualified method name"); + service[0] = '\0'; + method_name = gpr_strdup(""); + } else if (last_slash == service) { + method_name = gpr_strdup(""); + } else { + *last_slash = '\0'; + method_name = gpr_strdup(last_slash + 1); + } + char* host_and_port = grpc_slice_to_c_string(call_host); + if (url_scheme != nullptr && strcmp(url_scheme, 
GRPC_SSL_URL_SCHEME) == 0) { + /* Remove the port if it is 443. */ + char* port_delimiter = strrchr(host_and_port, ':'); + if (port_delimiter != nullptr && strcmp(port_delimiter + 1, "443") == 0) { + *port_delimiter = '\0'; + } + } + gpr_asprintf(&service_url, "%s://%s%s", + url_scheme == nullptr ? "" : url_scheme, host_and_port, service); + auth_md_context->service_url = service_url; + auth_md_context->method_name = method_name; + auth_md_context->channel_auth_context = + auth_context == nullptr + ? nullptr + : auth_context->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context") + .release(); + gpr_free(service); + gpr_free(host_and_port); +} + +static void cancel_get_request_metadata(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + call_data* calld = static_cast(elem->call_data); + if (error != GRPC_ERROR_NONE) { + calld->creds->cancel_get_request_metadata(&calld->md_array, + GRPC_ERROR_REF(error)); + } + GRPC_CALL_STACK_UNREF(calld->owning_call, "cancel_get_request_metadata"); +} + +static void send_security_metadata(grpc_call_element* elem, + grpc_transport_stream_op_batch* batch) { + call_data* calld = static_cast(elem->call_data); + channel_data* chand = static_cast(elem->channel_data); + grpc_client_security_context* ctx = + static_cast( + batch->payload->context[GRPC_CONTEXT_SECURITY].value); + grpc_call_credentials* channel_call_creds = + chand->security_connector->mutable_request_metadata_creds(); + int call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr); + + if (channel_call_creds == nullptr && !call_creds_has_md) { + /* Skip sending metadata altogether. */ + grpc_call_next_op(elem, batch); + return; + } + + if (channel_call_creds != nullptr && call_creds_has_md) { + calld->creds = grpc_core::RefCountedPtr( + grpc_composite_call_credentials_create(channel_call_creds, + ctx->creds.get(), nullptr)); + if (calld->creds == nullptr) { + grpc_transport_stream_op_batch_finish_with_failure( + batch, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Incompatible credentials set on channel and call."), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAUTHENTICATED), + calld->call_combiner); + return; + } + } else { + calld->creds = + call_creds_has_md ? ctx->creds->Ref() : channel_call_creds->Ref(); + } + + /* Check security level of call credential and channel, and do not send + * metadata if the check fails. 
*/ + grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( + chand->auth_context.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr) { + grpc_transport_stream_op_batch_finish_with_failure( + batch, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Established channel does not have an auth property " + "representing a security level."), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAUTHENTICATED), + calld->call_combiner); + return; + } + grpc_security_level call_cred_security_level = + calld->creds->min_security_level(); + int is_security_level_ok = grpc_check_security_level( + grpc_tsi_security_level_string_to_enum(prop->value), + call_cred_security_level); + if (!is_security_level_ok) { + grpc_transport_stream_op_batch_finish_with_failure( + batch, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Established channel does not have a sufficient " + "security level to transfer call credential."), + GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNAUTHENTICATED), + calld->call_combiner); + return; + } + + grpc_auth_metadata_context_build( + chand->security_connector->url_scheme(), calld->host, calld->method, + chand->auth_context.get(), &calld->auth_md_context); + + GPR_ASSERT(calld->pollent != nullptr); + GRPC_CALL_STACK_REF(calld->owning_call, "get_request_metadata"); + GRPC_CLOSURE_INIT(&calld->async_result_closure, on_credentials_metadata, + batch, grpc_schedule_on_exec_ctx); + grpc_error_handle error = GRPC_ERROR_NONE; + if (calld->creds->get_request_metadata( + calld->pollent, calld->auth_md_context, &calld->md_array, + &calld->async_result_closure, &error)) { + // Synchronous return; invoke on_credentials_metadata() directly. + on_credentials_metadata(batch, error); + GRPC_ERROR_UNREF(error); + } else { + // Async return; register cancellation closure with call combiner. 
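  // [Editor's aside] The grpc_check_security_level() call above enforces that
  // the channel is at least as secure as the call credential demands before
  // any credential-derived metadata leaves the client. Concretely (the level
  // pairings are illustrative assumptions about typical configurations):
  //
  //   channel security level        creds->min_security_level()     metadata sent?
  //   privacy and integrity (TLS)   GRPC_PRIVACY_AND_INTEGRITY      yes
  //   none (insecure transport)     GRPC_SECURITY_NONE              yes
  //   none (insecure transport)     GRPC_PRIVACY_AND_INTEGRITY      no: the batch
  //                                                                 fails UNAUTHENTICATED
  //
  // Token-based call credentials commonly require GRPC_PRIVACY_AND_INTEGRITY,
  // which is why they are rejected on insecure channels rather than having
  // the token sent in cleartext.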
+ // TODO(yashykt): We would not need this ref if call combiners used + // Closure::Run() instead of ExecCtx::Run() + GRPC_CALL_STACK_REF(calld->owning_call, "cancel_get_request_metadata"); + calld->call_combiner->SetNotifyOnCancel(GRPC_CLOSURE_INIT( + &calld->get_request_metadata_cancel_closure, + cancel_get_request_metadata, elem, grpc_schedule_on_exec_ctx)); + } +} + +static void on_host_checked(void* arg, grpc_error_handle error) { + grpc_transport_stream_op_batch* batch = + static_cast(arg); + grpc_call_element* elem = + static_cast(batch->handler_private.extra_arg); + call_data* calld = static_cast(elem->call_data); + if (error == GRPC_ERROR_NONE) { + send_security_metadata(elem, batch); + } else { + grpc_transport_stream_op_batch_finish_with_failure( + batch, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Invalid host ", grpc_core::StringViewFromSlice(calld->host), + " set in :authority metadata.")), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAUTHENTICATED), + calld->call_combiner); + } + GRPC_CALL_STACK_UNREF(calld->owning_call, "check_call_host"); +} + +static void cancel_check_call_host(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + call_data* calld = static_cast(elem->call_data); + channel_data* chand = static_cast(elem->channel_data); + if (error != GRPC_ERROR_NONE) { + chand->security_connector->cancel_check_call_host( + &calld->async_result_closure, GRPC_ERROR_REF(error)); + } + GRPC_CALL_STACK_UNREF(calld->owning_call, "cancel_check_call_host"); +} + +static void client_auth_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + GPR_TIMER_SCOPE("auth_start_transport_stream_op_batch", 0); + + /* grab pointers to our data from the call element */ + call_data* calld = static_cast(elem->call_data); + channel_data* chand = static_cast(elem->channel_data); + + if (batch->send_initial_metadata) { + grpc_metadata_batch* metadata = + batch->payload->send_initial_metadata.send_initial_metadata; + if (metadata->legacy_index()->named.path != nullptr) { + calld->method = grpc_slice_ref_internal( + GRPC_MDVALUE(metadata->legacy_index()->named.path->md)); + } + if (metadata->legacy_index()->named.authority != nullptr) { + calld->host = grpc_slice_ref_internal( + GRPC_MDVALUE(metadata->legacy_index()->named.authority->md)); + batch->handler_private.extra_arg = elem; + GRPC_CALL_STACK_REF(calld->owning_call, "check_call_host"); + GRPC_CLOSURE_INIT(&calld->async_result_closure, on_host_checked, batch, + grpc_schedule_on_exec_ctx); + absl::string_view call_host(grpc_core::StringViewFromSlice(calld->host)); + grpc_error_handle error = GRPC_ERROR_NONE; + if (chand->security_connector->check_call_host( + call_host, chand->auth_context.get(), + &calld->async_result_closure, &error)) { + // Synchronous return; invoke on_host_checked() directly. + on_host_checked(batch, error); + GRPC_ERROR_UNREF(error); + } else { + // Async return; register cancellation closure with call combiner. 
+ // TODO(yashykt): We would not need this ref if call combiners used + // Closure::Run() instead of ExecCtx::Run() + GRPC_CALL_STACK_REF(calld->owning_call, "cancel_check_call_host"); + calld->call_combiner->SetNotifyOnCancel(GRPC_CLOSURE_INIT( + &calld->check_call_host_cancel_closure, cancel_check_call_host, + elem, grpc_schedule_on_exec_ctx)); + } + return; /* early exit */ + } + } + + /* pass control down the stack */ + grpc_call_next_op(elem, batch); +} + +/* Constructor for call_data */ +static grpc_error_handle client_auth_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + new (elem->call_data) call_data(elem, *args); + return GRPC_ERROR_NONE; +} + +static void client_auth_set_pollset_or_pollset_set( + grpc_call_element* elem, grpc_polling_entity* pollent) { + call_data* calld = static_cast(elem->call_data); + calld->pollent = pollent; +} + +/* Destructor for call_data */ +static void client_auth_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + call_data* calld = static_cast(elem->call_data); + calld->destroy(); +} + +/* Constructor for channel_data */ +static grpc_error_handle client_auth_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + /* The first and the last filters tend to be implemented differently to + handle the case that there's no 'next' filter to call on the up or down + path */ + GPR_ASSERT(!args->is_last); + grpc_security_connector* sc = + grpc_security_connector_find_in_args(args->channel_args); + if (sc == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Security connector missing from client auth filter args"); + } + grpc_auth_context* auth_context = + grpc_find_auth_context_in_args(args->channel_args); + if (auth_context == nullptr) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Auth context missing from client auth filter args"); + } + new (elem->channel_data) channel_data( + static_cast(sc), auth_context); + return GRPC_ERROR_NONE; +} + +/* Destructor for channel data */ +static void client_auth_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* chand = static_cast(elem->channel_data); + chand->~channel_data(); +} + +const grpc_channel_filter grpc_client_auth_filter = { + client_auth_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + client_auth_init_call_elem, + client_auth_set_pollset_or_pollset_set, + client_auth_destroy_call_elem, + sizeof(channel_data), + client_auth_init_channel_elem, + client_auth_destroy_channel_elem, + grpc_channel_next_get_info, + "client-auth"}; diff --git a/src/core/lib/security/transport/secure_endpoint.cc b/src/core/lib/security/transport/secure_endpoint.cc new file mode 100644 index 00000000..c60ecfba --- /dev/null +++ b/src/core/lib/security/transport/secure_endpoint.cc @@ -0,0 +1,442 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/transport/secure_endpoint.h" + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/security/transport/tsi_error.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/tsi/transport_security_grpc.h" + +#define STAGING_BUFFER_SIZE 8192 + +static void on_read(void* user_data, grpc_error_handle error); + +namespace { +struct secure_endpoint { + secure_endpoint(const grpc_endpoint_vtable* vtable, + tsi_frame_protector* protector, + tsi_zero_copy_grpc_protector* zero_copy_protector, + grpc_endpoint* transport, grpc_slice* leftover_slices, + size_t leftover_nslices) + : wrapped_ep(transport), + protector(protector), + zero_copy_protector(zero_copy_protector) { + base.vtable = vtable; + gpr_mu_init(&protector_mu); + GRPC_CLOSURE_INIT(&on_read, ::on_read, this, grpc_schedule_on_exec_ctx); + grpc_slice_buffer_init(&source_buffer); + grpc_slice_buffer_init(&leftover_bytes); + for (size_t i = 0; i < leftover_nslices; i++) { + grpc_slice_buffer_add(&leftover_bytes, + grpc_slice_ref_internal(leftover_slices[i])); + } + grpc_slice_buffer_init(&output_buffer); + gpr_ref_init(&ref, 1); + } + + ~secure_endpoint() { + grpc_endpoint_destroy(wrapped_ep); + tsi_frame_protector_destroy(protector); + tsi_zero_copy_grpc_protector_destroy(zero_copy_protector); + grpc_slice_buffer_destroy_internal(&source_buffer); + grpc_slice_buffer_destroy_internal(&leftover_bytes); + grpc_slice_unref_internal(read_staging_buffer); + grpc_slice_unref_internal(write_staging_buffer); + grpc_slice_buffer_destroy_internal(&output_buffer); + gpr_mu_destroy(&protector_mu); + } + + grpc_endpoint base; + grpc_endpoint* wrapped_ep; + struct tsi_frame_protector* protector; + struct tsi_zero_copy_grpc_protector* zero_copy_protector; + gpr_mu protector_mu; + /* saved upper level callbacks and user_data. */ + grpc_closure* read_cb = nullptr; + grpc_closure* write_cb = nullptr; + grpc_closure on_read; + grpc_slice_buffer* read_buffer = nullptr; + grpc_slice_buffer source_buffer; + /* saved handshaker leftover data to unprotect. 
*/ + grpc_slice_buffer leftover_bytes; + /* buffers for read and write */ + grpc_slice read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); + grpc_slice write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); + grpc_slice_buffer output_buffer; + + gpr_refcount ref; +}; +} // namespace + +grpc_core::TraceFlag grpc_trace_secure_endpoint(false, "secure_endpoint"); + +static void destroy(secure_endpoint* ep) { delete ep; } + +#ifndef NDEBUG +#define SECURE_ENDPOINT_UNREF(ep, reason) \ + secure_endpoint_unref((ep), (reason), __FILE__, __LINE__) +#define SECURE_ENDPOINT_REF(ep, reason) \ + secure_endpoint_ref((ep), (reason), __FILE__, __LINE__) +static void secure_endpoint_unref(secure_endpoint* ep, const char* reason, + const char* file, int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) { + gpr_atm val = gpr_atm_no_barrier_load(&ep->ref.count); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "SECENDP unref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, reason, val, + val - 1); + } + if (gpr_unref(&ep->ref)) { + destroy(ep); + } +} + +static void secure_endpoint_ref(secure_endpoint* ep, const char* reason, + const char* file, int line) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) { + gpr_atm val = gpr_atm_no_barrier_load(&ep->ref.count); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "SECENDP ref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, reason, val, + val + 1); + } + gpr_ref(&ep->ref); +} +#else +#define SECURE_ENDPOINT_UNREF(ep, reason) secure_endpoint_unref((ep)) +#define SECURE_ENDPOINT_REF(ep, reason) secure_endpoint_ref((ep)) +static void secure_endpoint_unref(secure_endpoint* ep) { + if (gpr_unref(&ep->ref)) { + destroy(ep); + } +} + +static void secure_endpoint_ref(secure_endpoint* ep) { gpr_ref(&ep->ref); } +#endif + +static void flush_read_staging_buffer(secure_endpoint* ep, uint8_t** cur, + uint8_t** end) { + grpc_slice_buffer_add(ep->read_buffer, ep->read_staging_buffer); + ep->read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); + *cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer); + *end = GRPC_SLICE_END_PTR(ep->read_staging_buffer); +} + +static void call_read_cb(secure_endpoint* ep, grpc_error_handle error) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) { + size_t i; + for (i = 0; i < ep->read_buffer->count; i++) { + char* data = grpc_dump_slice(ep->read_buffer->slices[i], + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "READ %p: %s", ep, data); + gpr_free(data); + } + } + ep->read_buffer = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, ep->read_cb, error); + SECURE_ENDPOINT_UNREF(ep, "read"); +} + +static void on_read(void* user_data, grpc_error_handle error) { + unsigned i; + uint8_t keep_looping = 0; + tsi_result result = TSI_OK; + secure_endpoint* ep = static_cast(user_data); + uint8_t* cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer); + uint8_t* end = GRPC_SLICE_END_PTR(ep->read_staging_buffer); + + if (error != GRPC_ERROR_NONE) { + grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer); + call_read_cb(ep, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Secure read failed", &error, 1)); + return; + } + + if (ep->zero_copy_protector != nullptr) { + // Use zero-copy grpc protector to unprotect. + result = tsi_zero_copy_grpc_protector_unprotect( + ep->zero_copy_protector, &ep->source_buffer, ep->read_buffer); + } else { + // Use frame protector to unprotect. 
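      // [Editor's aside] The loop below is the usual chunking pattern for
      // tsi_frame_protector: each tsi_frame_protector_unprotect() call
      // consumes part of the encrypted input (processed_message_size is
      // in/out) and appends plaintext to the staging slice, which is flushed
      // into read_buffer whenever it fills. Stripped to its essentials, with
      // hypothetical buffer names:
      //
      //   while (in_len > 0) {
      //     size_t out_len = static_cast<size_t>(end - cur);  // staging space
      //     size_t consumed = in_len;
      //     tsi_result r = tsi_frame_protector_unprotect(protector, in,
      //                                                  &consumed, cur,
      //                                                  &out_len);
      //     if (r != TSI_OK) break;
      //     in += consumed; in_len -= consumed; cur += out_len;
      //     if (cur == end) flush_read_staging_buffer(ep, &cur, &end);
      //   }
      //
      // The real loop additionally re-enters with keep_looping set so that
      // bytes buffered inside the protector (for example when the staging
      // slice ran out mid-frame) are drained before moving on.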
+ /* TODO(yangg) check error, maybe bail out early */ + for (i = 0; i < ep->source_buffer.count; i++) { + grpc_slice encrypted = ep->source_buffer.slices[i]; + uint8_t* message_bytes = GRPC_SLICE_START_PTR(encrypted); + size_t message_size = GRPC_SLICE_LENGTH(encrypted); + + while (message_size > 0 || keep_looping) { + size_t unprotected_buffer_size_written = static_cast(end - cur); + size_t processed_message_size = message_size; + gpr_mu_lock(&ep->protector_mu); + result = tsi_frame_protector_unprotect( + ep->protector, message_bytes, &processed_message_size, cur, + &unprotected_buffer_size_written); + gpr_mu_unlock(&ep->protector_mu); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Decryption error: %s", + tsi_result_to_string(result)); + break; + } + message_bytes += processed_message_size; + message_size -= processed_message_size; + cur += unprotected_buffer_size_written; + + if (cur == end) { + flush_read_staging_buffer(ep, &cur, &end); + /* Force to enter the loop again to extract buffered bytes in + protector. The bytes could be buffered because of running out of + staging_buffer. If this happens at the end of all slices, doing + another unprotect avoids leaving data in the protector. */ + keep_looping = 1; + } else if (unprotected_buffer_size_written > 0) { + keep_looping = 1; + } else { + keep_looping = 0; + } + } + if (result != TSI_OK) break; + } + + if (cur != GRPC_SLICE_START_PTR(ep->read_staging_buffer)) { + grpc_slice_buffer_add( + ep->read_buffer, + grpc_slice_split_head( + &ep->read_staging_buffer, + static_cast( + cur - GRPC_SLICE_START_PTR(ep->read_staging_buffer)))); + } + } + + /* TODO(yangg) experiment with moving this block after read_cb to see if it + helps latency */ + grpc_slice_buffer_reset_and_unref_internal(&ep->source_buffer); + + if (result != TSI_OK) { + grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer); + call_read_cb( + ep, grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Unwrap failed"), result)); + return; + } + + call_read_cb(ep, GRPC_ERROR_NONE); +} + +static void endpoint_read(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool urgent) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + ep->read_cb = cb; + ep->read_buffer = slices; + grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer); + + SECURE_ENDPOINT_REF(ep, "read"); + if (ep->leftover_bytes.count) { + grpc_slice_buffer_swap(&ep->leftover_bytes, &ep->source_buffer); + GPR_ASSERT(ep->leftover_bytes.count == 0); + on_read(ep, GRPC_ERROR_NONE); + return; + } + + grpc_endpoint_read(ep->wrapped_ep, &ep->source_buffer, &ep->on_read, urgent); +} + +static void flush_write_staging_buffer(secure_endpoint* ep, uint8_t** cur, + uint8_t** end) { + grpc_slice_buffer_add(&ep->output_buffer, ep->write_staging_buffer); + ep->write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE); + *cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer); + *end = GRPC_SLICE_END_PTR(ep->write_staging_buffer); +} + +static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* arg) { + GPR_TIMER_SCOPE("secure_endpoint.endpoint_write", 0); + + unsigned i; + tsi_result result = TSI_OK; + secure_endpoint* ep = reinterpret_cast(secure_ep); + uint8_t* cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer); + uint8_t* end = GRPC_SLICE_END_PTR(ep->write_staging_buffer); + + grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_secure_endpoint)) { + for (i = 0; i < 
slices->count; i++) { + char* data = + grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_INFO, "WRITE %p: %s", ep, data); + gpr_free(data); + } + } + + if (ep->zero_copy_protector != nullptr) { + // Use zero-copy grpc protector to protect. + result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector, + slices, &ep->output_buffer); + } else { + // Use frame protector to protect. + for (i = 0; i < slices->count; i++) { + grpc_slice plain = slices->slices[i]; + uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain); + size_t message_size = GRPC_SLICE_LENGTH(plain); + while (message_size > 0) { + size_t protected_buffer_size_to_send = static_cast(end - cur); + size_t processed_message_size = message_size; + gpr_mu_lock(&ep->protector_mu); + result = tsi_frame_protector_protect(ep->protector, message_bytes, + &processed_message_size, cur, + &protected_buffer_size_to_send); + gpr_mu_unlock(&ep->protector_mu); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Encryption error: %s", + tsi_result_to_string(result)); + break; + } + message_bytes += processed_message_size; + message_size -= processed_message_size; + cur += protected_buffer_size_to_send; + + if (cur == end) { + flush_write_staging_buffer(ep, &cur, &end); + } + } + if (result != TSI_OK) break; + } + if (result == TSI_OK) { + size_t still_pending_size; + do { + size_t protected_buffer_size_to_send = static_cast(end - cur); + gpr_mu_lock(&ep->protector_mu); + result = tsi_frame_protector_protect_flush( + ep->protector, cur, &protected_buffer_size_to_send, + &still_pending_size); + gpr_mu_unlock(&ep->protector_mu); + if (result != TSI_OK) break; + cur += protected_buffer_size_to_send; + if (cur == end) { + flush_write_staging_buffer(ep, &cur, &end); + } + } while (still_pending_size > 0); + if (cur != GRPC_SLICE_START_PTR(ep->write_staging_buffer)) { + grpc_slice_buffer_add( + &ep->output_buffer, + grpc_slice_split_head( + &ep->write_staging_buffer, + static_cast( + cur - GRPC_SLICE_START_PTR(ep->write_staging_buffer)))); + } + } + } + + if (result != TSI_OK) { + /* TODO(yangg) do different things according to the error type? 
*/ + grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, cb, + grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Wrap failed"), result)); + return; + } + + grpc_endpoint_write(ep->wrapped_ep, &ep->output_buffer, cb, arg); +} + +static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error_handle why) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + grpc_endpoint_shutdown(ep->wrapped_ep, why); +} + +static void endpoint_destroy(grpc_endpoint* secure_ep) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + SECURE_ENDPOINT_UNREF(ep, "destroy"); +} + +static void endpoint_add_to_pollset(grpc_endpoint* secure_ep, + grpc_pollset* pollset) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + grpc_endpoint_add_to_pollset(ep->wrapped_ep, pollset); +} + +static void endpoint_add_to_pollset_set(grpc_endpoint* secure_ep, + grpc_pollset_set* pollset_set) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + grpc_endpoint_add_to_pollset_set(ep->wrapped_ep, pollset_set); +} + +static void endpoint_delete_from_pollset_set(grpc_endpoint* secure_ep, + grpc_pollset_set* pollset_set) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + grpc_endpoint_delete_from_pollset_set(ep->wrapped_ep, pollset_set); +} + +static absl::string_view endpoint_get_peer(grpc_endpoint* secure_ep) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + return grpc_endpoint_get_peer(ep->wrapped_ep); +} + +static absl::string_view endpoint_get_local_address(grpc_endpoint* secure_ep) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + return grpc_endpoint_get_local_address(ep->wrapped_ep); +} + +static int endpoint_get_fd(grpc_endpoint* secure_ep) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + return grpc_endpoint_get_fd(ep->wrapped_ep); +} + +static bool endpoint_can_track_err(grpc_endpoint* secure_ep) { + secure_endpoint* ep = reinterpret_cast(secure_ep); + return grpc_endpoint_can_track_err(ep->wrapped_ep); +} + +static const grpc_endpoint_vtable vtable = {endpoint_read, + endpoint_write, + endpoint_add_to_pollset, + endpoint_add_to_pollset_set, + endpoint_delete_from_pollset_set, + endpoint_shutdown, + endpoint_destroy, + endpoint_get_peer, + endpoint_get_local_address, + endpoint_get_fd, + endpoint_can_track_err}; + +grpc_endpoint* grpc_secure_endpoint_create( + struct tsi_frame_protector* protector, + struct tsi_zero_copy_grpc_protector* zero_copy_protector, + grpc_endpoint* to_wrap, grpc_slice* leftover_slices, + size_t leftover_nslices) { + secure_endpoint* ep = + new secure_endpoint(&vtable, protector, zero_copy_protector, to_wrap, + leftover_slices, leftover_nslices); + return &ep->base; +} diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc new file mode 100644 index 00000000..0f9114d8 --- /dev/null +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -0,0 +1,642 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/transport/security_handshaker.h" + +#include +#include + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/transport/secure_endpoint.h" +#include "src/core/lib/security/transport/tsi_error.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/transport_security_grpc.h" + +#define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256 + +namespace grpc_core { + +namespace { + +class SecurityHandshaker : public Handshaker { + public: + SecurityHandshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector, + const grpc_channel_args* args); + ~SecurityHandshaker() override; + void Shutdown(grpc_error_handle why) override; + void DoHandshake(grpc_tcp_server_acceptor* acceptor, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override; + const char* name() const override { return "security"; } + + private: + grpc_error_handle DoHandshakerNextLocked(const unsigned char* bytes_received, + size_t bytes_received_size); + + grpc_error_handle OnHandshakeNextDoneLocked( + tsi_result result, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result); + void HandshakeFailedLocked(grpc_error_handle error); + void CleanupArgsForFailureLocked(); + + static void OnHandshakeDataReceivedFromPeerFn(void* arg, + grpc_error_handle error); + static void OnHandshakeDataSentToPeerFn(void* arg, grpc_error_handle error); + static void OnHandshakeDataReceivedFromPeerFnScheduler( + void* arg, grpc_error_handle error); + static void OnHandshakeDataSentToPeerFnScheduler(void* arg, + grpc_error_handle error); + static void OnHandshakeNextDoneGrpcWrapper( + tsi_result result, void* user_data, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result); + static void OnPeerCheckedFn(void* arg, grpc_error_handle error); + void OnPeerCheckedInner(grpc_error_handle error); + size_t MoveReadBufferIntoHandshakeBuffer(); + grpc_error_handle CheckPeerLocked(); + + // State set at creation time. + tsi_handshaker* handshaker_; + RefCountedPtr connector_; + + Mutex mu_; + + bool is_shutdown_ = false; + // Endpoint and read buffer to destroy after a shutdown. + grpc_endpoint* endpoint_to_destroy_ = nullptr; + grpc_slice_buffer* read_buffer_to_destroy_ = nullptr; + + // State saved while performing the handshake. 
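  // [Editor's note] max_frame_size_ below is populated by the constructor
  // from the GRPC_ARG_TSI_MAX_FRAME_SIZE channel argument; 0 means "let the
  // TSI implementation pick its default" (see the
  // `max_frame_size_ == 0 ? nullptr : &max_frame_size_` logic in
  // OnPeerCheckedInner()). A caller that wants, say, 16 KiB frames could pass
  // the argument when building the channel or server; a minimal sketch using
  // the core channel-args helpers:
  //
  //   grpc_arg arg = grpc_channel_arg_integer_create(
  //       const_cast<char*>(GRPC_ARG_TSI_MAX_FRAME_SIZE), 16384);
  //   grpc_channel_args channel_args = {1, &arg};
  //   // ... pass &channel_args to whichever channel/server creation API is
  //   // in use.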
+ HandshakerArgs* args_ = nullptr; + grpc_closure* on_handshake_done_ = nullptr; + + size_t handshake_buffer_size_; + unsigned char* handshake_buffer_; + grpc_slice_buffer outgoing_; + grpc_closure on_handshake_data_sent_to_peer_; + grpc_closure on_handshake_data_received_from_peer_; + grpc_closure on_peer_checked_; + RefCountedPtr auth_context_; + tsi_handshaker_result* handshaker_result_ = nullptr; + size_t max_frame_size_ = 0; +}; + +SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector, + const grpc_channel_args* args) + : handshaker_(handshaker), + connector_(connector->Ref(DEBUG_LOCATION, "handshake")), + handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), + handshake_buffer_( + static_cast(gpr_malloc(handshake_buffer_size_))), + max_frame_size_(grpc_channel_args_find_integer( + args, GRPC_ARG_TSI_MAX_FRAME_SIZE, + {0, 0, std::numeric_limits::max()})) { + grpc_slice_buffer_init(&outgoing_); + GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn, + this, grpc_schedule_on_exec_ctx); +} + +SecurityHandshaker::~SecurityHandshaker() { + tsi_handshaker_destroy(handshaker_); + tsi_handshaker_result_destroy(handshaker_result_); + if (endpoint_to_destroy_ != nullptr) { + grpc_endpoint_destroy(endpoint_to_destroy_); + } + if (read_buffer_to_destroy_ != nullptr) { + grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_); + gpr_free(read_buffer_to_destroy_); + } + gpr_free(handshake_buffer_); + grpc_slice_buffer_destroy_internal(&outgoing_); + auth_context_.reset(DEBUG_LOCATION, "handshake"); + connector_.reset(DEBUG_LOCATION, "handshake"); +} + +size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() { + size_t bytes_in_read_buffer = args_->read_buffer->length; + if (handshake_buffer_size_ < bytes_in_read_buffer) { + handshake_buffer_ = static_cast( + gpr_realloc(handshake_buffer_, bytes_in_read_buffer)); + handshake_buffer_size_ = bytes_in_read_buffer; + } + size_t offset = 0; + while (args_->read_buffer->count > 0) { + grpc_slice* next_slice = grpc_slice_buffer_peek_first(args_->read_buffer); + memcpy(handshake_buffer_ + offset, GRPC_SLICE_START_PTR(*next_slice), + GRPC_SLICE_LENGTH(*next_slice)); + offset += GRPC_SLICE_LENGTH(*next_slice); + grpc_slice_buffer_remove_first(args_->read_buffer); + } + return bytes_in_read_buffer; +} + +// Set args_ fields to NULL, saving the endpoint and read buffer for +// later destruction. +void SecurityHandshaker::CleanupArgsForFailureLocked() { + endpoint_to_destroy_ = args_->endpoint; + args_->endpoint = nullptr; + read_buffer_to_destroy_ = args_->read_buffer; + args_->read_buffer = nullptr; + grpc_channel_args_destroy(args_->args); + args_->args = nullptr; +} + +// If the handshake failed or we're shutting down, clean up and invoke the +// callback with the error. +void SecurityHandshaker::HandshakeFailedLocked(grpc_error_handle error) { + if (error == GRPC_ERROR_NONE) { + // If we were shut down after the handshake succeeded but before an + // endpoint callback was invoked, we need to generate our own error. + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown"); + } + gpr_log(GPR_DEBUG, "Security handshake failed: %s", + grpc_error_std_string(error).c_str()); + if (!is_shutdown_) { + tsi_handshaker_shutdown(handshaker_); + // TODO(ctiller): It is currently necessary to shutdown endpoints + // before destroying them, even if we know that there are no + // pending read/write callbacks. This should be fixed, at which + // point this can be removed. 
+ grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error)); + // Not shutting down, so the write failed. Clean up before + // invoking the callback. + CleanupArgsForFailureLocked(); + // Set shutdown to true so that subsequent calls to + // security_handshaker_shutdown() do nothing. + is_shutdown_ = true; + } + // Invoke callback. + ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, error); +} + +namespace { + +RefCountedPtr +MakeChannelzSecurityFromAuthContext(grpc_auth_context* auth_context) { + RefCountedPtr security = + MakeRefCounted(); + // TODO(yashykt): Currently, we are assuming TLS by default and are only able + // to fill in the remote certificate but we should ideally be able to fill in + // other fields in + // https://github.com/grpc/grpc/blob/fcd43e90304862a823316b224ee733d17a8cfd90/src/proto/grpc/channelz/channelz.proto#L326 + // from grpc_auth_context. + security->type = channelz::SocketNode::Security::ModelType::kTls; + security->tls = absl::make_optional(); + grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( + auth_context, GRPC_X509_PEM_CERT_PROPERTY_NAME); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop != nullptr) { + security->tls->remote_certificate = + std::string(prop->value, prop->value_length); + } + return security; +} + +} // namespace + +void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) { + MutexLock lock(&mu_); + if (error != GRPC_ERROR_NONE || is_shutdown_) { + HandshakeFailedLocked(error); + return; + } + // Get unused bytes. + const unsigned char* unused_bytes = nullptr; + size_t unused_bytes_size = 0; + tsi_result result = tsi_handshaker_result_get_unused_bytes( + handshaker_result_, &unused_bytes, &unused_bytes_size); + if (result != TSI_OK) { + HandshakeFailedLocked(grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "TSI handshaker result does not provide unused bytes"), + result)); + return; + } + // Check whether we need to wrap the endpoint. + tsi_frame_protector_type frame_protector_type; + result = tsi_handshaker_result_get_frame_protector_type( + handshaker_result_, &frame_protector_type); + if (result != TSI_OK) { + HandshakeFailedLocked(grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "TSI handshaker result does not implement " + "get_frame_protector_type"), + result)); + return; + } + tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; + tsi_frame_protector* protector = nullptr; + switch (frame_protector_type) { + case TSI_FRAME_PROTECTOR_ZERO_COPY: + ABSL_FALLTHROUGH_INTENDED; + case TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY: + // Create zero-copy frame protector. + result = tsi_handshaker_result_create_zero_copy_grpc_protector( + handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_, + &zero_copy_protector); + if (result != TSI_OK) { + HandshakeFailedLocked(grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Zero-copy frame protector creation failed"), + result)); + return; + } + break; + case TSI_FRAME_PROTECTOR_NORMAL: + // Create normal frame protector. + result = tsi_handshaker_result_create_frame_protector( + handshaker_result_, max_frame_size_ == 0 ? 
nullptr : &max_frame_size_, + &protector); + if (result != TSI_OK) { + HandshakeFailedLocked( + grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Frame protector creation failed"), + result)); + return; + } + break; + case TSI_FRAME_PROTECTOR_NONE: + break; + } + // If we have a frame protector, create a secure endpoint. + if (zero_copy_protector != nullptr || protector != nullptr) { + if (unused_bytes_size > 0) { + grpc_slice slice = grpc_slice_from_copied_buffer( + reinterpret_cast(unused_bytes), unused_bytes_size); + args_->endpoint = grpc_secure_endpoint_create( + protector, zero_copy_protector, args_->endpoint, &slice, 1); + grpc_slice_unref_internal(slice); + } else { + args_->endpoint = grpc_secure_endpoint_create( + protector, zero_copy_protector, args_->endpoint, nullptr, 0); + } + } else if (unused_bytes_size > 0) { + // Not wrapping the endpoint, so just pass along unused bytes. + grpc_slice slice = grpc_slice_from_copied_buffer( + reinterpret_cast(unused_bytes), unused_bytes_size); + grpc_slice_buffer_add(args_->read_buffer, slice); + } + // Done with handshaker result. + tsi_handshaker_result_destroy(handshaker_result_); + handshaker_result_ = nullptr; + // Add auth context to channel args. + absl::InlinedVector args_to_add; + args_to_add.push_back(grpc_auth_context_to_arg(auth_context_.get())); + auto security = MakeChannelzSecurityFromAuthContext(auth_context_.get()); + args_to_add.push_back(security->MakeChannelArg()); + grpc_channel_args* tmp_args = args_->args; + args_->args = grpc_channel_args_copy_and_add(tmp_args, args_to_add.data(), + args_to_add.size()); + grpc_channel_args_destroy(tmp_args); + // Invoke callback. + ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, GRPC_ERROR_NONE); + // Set shutdown to true so that subsequent calls to + // security_handshaker_shutdown() do nothing. + is_shutdown_ = true; +} + +void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error_handle error) { + RefCountedPtr(static_cast(arg)) + ->OnPeerCheckedInner(GRPC_ERROR_REF(error)); +} + +grpc_error_handle SecurityHandshaker::CheckPeerLocked() { + tsi_peer peer; + tsi_result result = + tsi_handshaker_result_extract_peer(handshaker_result_, &peer); + if (result != TSI_OK) { + return grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result); + } + connector_->check_peer(peer, args_->endpoint, &auth_context_, + &on_peer_checked_); + return GRPC_ERROR_NONE; +} + +grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked( + tsi_result result, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) { + grpc_error_handle error = GRPC_ERROR_NONE; + // Handshaker was shutdown. + if (is_shutdown_) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown"); + } + // Read more if we need to. + if (result == TSI_INCOMPLETE_DATA) { + GPR_ASSERT(bytes_to_send_size == 0); + grpc_endpoint_read( + args_->endpoint, args_->read_buffer, + GRPC_CLOSURE_INIT( + &on_handshake_data_received_from_peer_, + &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler, + this, grpc_schedule_on_exec_ctx), + /*urgent=*/true); + return error; + } + if (result != TSI_OK) { + return grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake failed"), result); + } + // Update handshaker result. 
+ if (handshaker_result != nullptr) { + GPR_ASSERT(handshaker_result_ == nullptr); + handshaker_result_ = handshaker_result; + } + if (bytes_to_send_size > 0) { + // Send data to peer, if needed. + grpc_slice to_send = grpc_slice_from_copied_buffer( + reinterpret_cast(bytes_to_send), bytes_to_send_size); + grpc_slice_buffer_reset_and_unref_internal(&outgoing_); + grpc_slice_buffer_add(&outgoing_, to_send); + grpc_endpoint_write( + args_->endpoint, &outgoing_, + GRPC_CLOSURE_INIT( + &on_handshake_data_sent_to_peer_, + &SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler, this, + grpc_schedule_on_exec_ctx), + nullptr); + } else if (handshaker_result == nullptr) { + // There is nothing to send, but need to read from peer. + grpc_endpoint_read( + args_->endpoint, args_->read_buffer, + GRPC_CLOSURE_INIT( + &on_handshake_data_received_from_peer_, + &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler, + this, grpc_schedule_on_exec_ctx), + /*urgent=*/true); + } else { + // Handshake has finished, check peer and so on. + error = CheckPeerLocked(); + } + return error; +} + +void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper( + tsi_result result, void* user_data, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) { + RefCountedPtr h( + static_cast(user_data)); + MutexLock lock(&h->mu_); + grpc_error_handle error = h->OnHandshakeNextDoneLocked( + result, bytes_to_send, bytes_to_send_size, handshaker_result); + if (error != GRPC_ERROR_NONE) { + h->HandshakeFailedLocked(error); + } else { + h.release(); // Avoid unref + } +} + +grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked( + const unsigned char* bytes_received, size_t bytes_received_size) { + // Invoke TSI handshaker. + const unsigned char* bytes_to_send = nullptr; + size_t bytes_to_send_size = 0; + tsi_handshaker_result* hs_result = nullptr; + tsi_result result = tsi_handshaker_next( + handshaker_, bytes_received, bytes_received_size, &bytes_to_send, + &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this); + if (result == TSI_ASYNC) { + // Handshaker operating asynchronously. Nothing else to do here; + // callback will be invoked in a TSI thread. + return GRPC_ERROR_NONE; + } + // Handshaker returned synchronously. Invoke callback directly in + // this thread with our existing exec_ctx. + return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size, + hs_result); +} + +// This callback might be run inline while we are still holding on to the mutex, +// so schedule OnHandshakeDataReceivedFromPeerFn on ExecCtx to avoid a deadlock. +void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler( + void* arg, grpc_error_handle error) { + SecurityHandshaker* h = static_cast(arg); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer_, + &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn, + h, grpc_schedule_on_exec_ctx), + GRPC_ERROR_REF(error)); +} + +void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn( + void* arg, grpc_error_handle error) { + RefCountedPtr h(static_cast(arg)); + MutexLock lock(&h->mu_); + if (error != GRPC_ERROR_NONE || h->is_shutdown_) { + h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Handshake read failed", &error, 1)); + return; + } + // Copy all slices received. + size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer(); + // Call TSI handshaker. 
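  // [Editor's aside] The *FnScheduler wrappers above exist because
  // grpc_endpoint_read()/grpc_endpoint_write() may run their closure inline,
  // on the current thread, while mu_ is still held by the code that issued
  // the I/O; calling OnHandshakeData*Fn directly would then try to re-acquire
  // the non-reentrant mu_ and deadlock. Re-scheduling through ExecCtx::Run()
  // defers the real callback until after the lock has been released. The
  // pattern in isolation (hypothetical member and function names):
  //
  //   static void OnIoDoneScheduler(void* arg, grpc_error_handle error) {
  //     auto* self = static_cast<SecurityHandshaker*>(arg);
  //     grpc_core::ExecCtx::Run(
  //         DEBUG_LOCATION,
  //         GRPC_CLOSURE_INIT(&self->on_io_done_, OnIoDone, self,
  //                           grpc_schedule_on_exec_ctx),
  //         GRPC_ERROR_REF(error));  // OnIoDone acquires mu_ later, safely.
  //   }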
+ error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size); + if (error != GRPC_ERROR_NONE) { + h->HandshakeFailedLocked(error); + } else { + h.release(); // Avoid unref + } +} + +// This callback might be run inline while we are still holding on to the mutex, +// so schedule OnHandshakeDataSentToPeerFn on ExecCtx to avoid a deadlock. +void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler( + void* arg, grpc_error_handle error) { + SecurityHandshaker* h = static_cast(arg); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer_, + &SecurityHandshaker::OnHandshakeDataSentToPeerFn, h, + grpc_schedule_on_exec_ctx), + GRPC_ERROR_REF(error)); +} + +void SecurityHandshaker::OnHandshakeDataSentToPeerFn(void* arg, + grpc_error_handle error) { + RefCountedPtr h(static_cast(arg)); + MutexLock lock(&h->mu_); + if (error != GRPC_ERROR_NONE || h->is_shutdown_) { + h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Handshake write failed", &error, 1)); + return; + } + // We may be done. + if (h->handshaker_result_ == nullptr) { + grpc_endpoint_read( + h->args_->endpoint, h->args_->read_buffer, + GRPC_CLOSURE_INIT( + &h->on_handshake_data_received_from_peer_, + &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler, + h.get(), grpc_schedule_on_exec_ctx), + /*urgent=*/true); + } else { + error = h->CheckPeerLocked(); + if (error != GRPC_ERROR_NONE) { + h->HandshakeFailedLocked(error); + return; + } + } + h.release(); // Avoid unref +} + +// +// public handshaker API +// + +void SecurityHandshaker::Shutdown(grpc_error_handle why) { + MutexLock lock(&mu_); + if (!is_shutdown_) { + is_shutdown_ = true; + connector_->cancel_check_peer(&on_peer_checked_, GRPC_ERROR_REF(why)); + tsi_handshaker_shutdown(handshaker_); + grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why)); + CleanupArgsForFailureLocked(); + } + GRPC_ERROR_UNREF(why); +} + +void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, + grpc_closure* on_handshake_done, + HandshakerArgs* args) { + auto ref = Ref(); + MutexLock lock(&mu_); + args_ = args; + on_handshake_done_ = on_handshake_done; + size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer(); + grpc_error_handle error = + DoHandshakerNextLocked(handshake_buffer_, bytes_received_size); + if (error != GRPC_ERROR_NONE) { + HandshakeFailedLocked(error); + } else { + ref.release(); // Avoid unref + } +} + +// +// FailHandshaker +// + +class FailHandshaker : public Handshaker { + public: + const char* name() const override { return "security_fail"; } + void Shutdown(grpc_error_handle why) override { GRPC_ERROR_UNREF(why); } + void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override { + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Failed to create security handshaker"); + grpc_endpoint_shutdown(args->endpoint, GRPC_ERROR_REF(error)); + grpc_endpoint_destroy(args->endpoint); + args->endpoint = nullptr; + grpc_channel_args_destroy(args->args); + args->args = nullptr; + grpc_slice_buffer_destroy_internal(args->read_buffer); + gpr_free(args->read_buffer); + args->read_buffer = nullptr; + ExecCtx::Run(DEBUG_LOCATION, on_handshake_done, error); + } + + private: + ~FailHandshaker() override = default; +}; + +// +// handshaker factories +// + +class ClientSecurityHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* args, + 
grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) override { + auto* security_connector = + reinterpret_cast( + grpc_security_connector_find_in_args(args)); + if (security_connector) { + security_connector->add_handshakers(args, interested_parties, + handshake_mgr); + } + } + ~ClientSecurityHandshakerFactory() override = default; +}; + +class ServerSecurityHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* args, + grpc_pollset_set* interested_parties, + HandshakeManager* handshake_mgr) override { + auto* security_connector = + reinterpret_cast( + grpc_security_connector_find_in_args(args)); + if (security_connector) { + security_connector->add_handshakers(args, interested_parties, + handshake_mgr); + } + } + ~ServerSecurityHandshakerFactory() override = default; +}; + +} // namespace + +// +// exported functions +// + +RefCountedPtr SecurityHandshakerCreate( + tsi_handshaker* handshaker, grpc_security_connector* connector, + const grpc_channel_args* args) { + // If no TSI handshaker was created, return a handshaker that always fails. + // Otherwise, return a real security handshaker. + if (handshaker == nullptr) { + return MakeRefCounted(); + } else { + return MakeRefCounted(handshaker, connector, args); + } +} + +void SecurityRegisterHandshakerFactories(CoreConfiguration::Builder* builder) { + builder->handshaker_registry()->RegisterHandshakerFactory( + false /* at_start */, HANDSHAKER_CLIENT, + absl::make_unique()); + builder->handshaker_registry()->RegisterHandshakerFactory( + false /* at_start */, HANDSHAKER_SERVER, + absl::make_unique()); +} + +} // namespace grpc_core + +grpc_handshaker* grpc_security_handshaker_create( + tsi_handshaker* handshaker, grpc_security_connector* connector, + const grpc_channel_args* args) { + return SecurityHandshakerCreate(handshaker, connector, args).release(); +} diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc new file mode 100644 index 00000000..feae3f7d --- /dev/null +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -0,0 +1,331 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include +#include + +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/transport/auth_filters.h" +#include "src/core/lib/slice/slice_internal.h" + +static void recv_initial_metadata_ready(void* arg, grpc_error_handle error); +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle error); + +namespace { +enum async_state { + STATE_INIT = 0, + STATE_DONE, + STATE_CANCELLED, +}; + +struct channel_data { + channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds) + : auth_context(auth_context->Ref()), creds(creds->Ref()) {} + ~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); } + + grpc_core::RefCountedPtr auth_context; + grpc_core::RefCountedPtr creds; +}; + +struct call_data { + call_data(grpc_call_element* elem, const grpc_call_element_args& args) + : call_combiner(args.call_combiner), owning_call(args.call_stack) { + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready, + ::recv_initial_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, + ::recv_trailing_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + // Create server security context. Set its auth context from channel + // data and save it in the call context. + grpc_server_security_context* server_ctx = + grpc_server_security_context_create(args.arena); + channel_data* chand = static_cast(elem->channel_data); + server_ctx->auth_context = + chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter"); + if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) { + args.context[GRPC_CONTEXT_SECURITY].destroy( + args.context[GRPC_CONTEXT_SECURITY].value); + } + args.context[GRPC_CONTEXT_SECURITY].value = server_ctx; + args.context[GRPC_CONTEXT_SECURITY].destroy = + grpc_server_security_context_destroy; + } + + ~call_data() { GRPC_ERROR_UNREF(recv_initial_metadata_error); } + + grpc_core::CallCombiner* call_combiner; + grpc_call_stack* owning_call; + grpc_transport_stream_op_batch* recv_initial_metadata_batch; + grpc_closure* original_recv_initial_metadata_ready; + grpc_closure recv_initial_metadata_ready; + grpc_error_handle recv_initial_metadata_error = GRPC_ERROR_NONE; + grpc_closure recv_trailing_metadata_ready; + grpc_closure* original_recv_trailing_metadata_ready; + grpc_error_handle recv_trailing_metadata_error; + bool seen_recv_trailing_metadata_ready = false; + grpc_metadata_array md; + const grpc_metadata* consumed_md; + size_t num_consumed_md; + grpc_closure cancel_closure; + gpr_atm state = STATE_INIT; // async_state +}; + +} // namespace + +static grpc_metadata_array metadata_batch_to_md_array( + const grpc_metadata_batch* batch) { + grpc_metadata_array result; + grpc_metadata_array_init(&result); + batch->ForEach([&](grpc_mdelem md) { + grpc_metadata* usr_md = nullptr; + grpc_slice key = GRPC_MDKEY(md); + grpc_slice value = GRPC_MDVALUE(md); + if (result.count == result.capacity) { + result.capacity = std::max(result.capacity + 8, result.capacity * 2); + result.metadata = static_cast(gpr_realloc( + result.metadata, result.capacity * sizeof(grpc_metadata))); + } + usr_md = &result.metadata[result.count++]; + usr_md->key = grpc_slice_ref_internal(key); + usr_md->value = grpc_slice_ref_internal(value); + }); + return result; +} + +static grpc_filtered_mdelem remove_consumed_md(void* user_data, + grpc_mdelem md) { + grpc_call_element* elem = static_cast(user_data); + 
call_data* calld = static_cast(elem->call_data); + size_t i; + for (i = 0; i < calld->num_consumed_md; i++) { + const grpc_metadata* consumed_md = &calld->consumed_md[i]; + if (grpc_slice_eq(GRPC_MDKEY(md), consumed_md->key) && + grpc_slice_eq(GRPC_MDVALUE(md), consumed_md->value)) { + return GRPC_FILTERED_REMOVE(); + } + } + return GRPC_FILTERED_MDELEM(md); +} + +static void on_md_processing_done_inner(grpc_call_element* elem, + const grpc_metadata* consumed_md, + size_t num_consumed_md, + const grpc_metadata* response_md, + size_t num_response_md, + grpc_error_handle error) { + call_data* calld = static_cast(elem->call_data); + grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch; + /* TODO(jboeuf): Implement support for response_md. */ + if (response_md != nullptr && num_response_md > 0) { + gpr_log(GPR_INFO, + "response_md in auth metadata processing not supported for now. " + "Ignoring..."); + } + if (error == GRPC_ERROR_NONE) { + calld->consumed_md = consumed_md; + calld->num_consumed_md = num_consumed_md; + error = grpc_metadata_batch_filter( + batch->payload->recv_initial_metadata.recv_initial_metadata, + remove_consumed_md, elem, "Response metadata filtering error"); + } + calld->recv_initial_metadata_error = GRPC_ERROR_REF(error); + grpc_closure* closure = calld->original_recv_initial_metadata_ready; + calld->original_recv_initial_metadata_ready = nullptr; + if (calld->seen_recv_trailing_metadata_ready) { + GRPC_CALL_COMBINER_START(calld->call_combiner, + &calld->recv_trailing_metadata_ready, + calld->recv_trailing_metadata_error, + "continue recv_trailing_metadata_ready"); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, error); +} + +// Called from application code. +static void on_md_processing_done( + void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md, + const grpc_metadata* response_md, size_t num_response_md, + grpc_status_code status, const char* error_details) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + // If the call was not cancelled while we were in flight, process the result. + if (gpr_atm_full_cas(&calld->state, static_cast(STATE_INIT), + static_cast(STATE_DONE))) { + grpc_error_handle error = GRPC_ERROR_NONE; + if (status != GRPC_STATUS_OK) { + if (error_details == nullptr) { + error_details = "Authentication metadata processing failed."; + } + error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_COPIED_STRING(error_details), + GRPC_ERROR_INT_GRPC_STATUS, status); + } + on_md_processing_done_inner(elem, consumed_md, num_consumed_md, response_md, + num_response_md, error); + } + // Clean up. + for (size_t i = 0; i < calld->md.count; i++) { + grpc_slice_unref_internal(calld->md.metadata[i].key); + grpc_slice_unref_internal(calld->md.metadata[i].value); + } + grpc_metadata_array_destroy(&calld->md); + GRPC_CALL_STACK_UNREF(calld->owning_call, "server_auth_metadata"); +} + +static void cancel_call(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + call_data* calld = static_cast(elem->call_data); + // If the result was not already processed, invoke the callback now. 
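  // Note: calld->state is a small state machine shared with
  // on_md_processing_done() above. Both paths race to move it out of
  // STATE_INIT, and whichever side wins the compare-and-swap below is the
  // one that finishes the recv_initial_metadata batch; the loser becomes a
  // no-op.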
+ if (error != GRPC_ERROR_NONE && + gpr_atm_full_cas(&calld->state, static_cast(STATE_INIT), + static_cast(STATE_CANCELLED))) { + on_md_processing_done_inner(elem, nullptr, 0, nullptr, 0, + GRPC_ERROR_REF(error)); + } + GRPC_CALL_STACK_UNREF(calld->owning_call, "cancel_call"); +} + +static void recv_initial_metadata_ready(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + channel_data* chand = static_cast(elem->channel_data); + call_data* calld = static_cast(elem->call_data); + grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch; + if (error == GRPC_ERROR_NONE) { + if (chand->creds != nullptr && + chand->creds->auth_metadata_processor().process != nullptr) { + // We're calling out to the application, so we need to make sure + // to drop the call combiner early if we get cancelled. + // TODO(yashykt): We would not need this ref if call combiners used + // Closure::Run() instead of ExecCtx::Run() + GRPC_CALL_STACK_REF(calld->owning_call, "cancel_call"); + GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem, + grpc_schedule_on_exec_ctx); + calld->call_combiner->SetNotifyOnCancel(&calld->cancel_closure); + GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata"); + calld->md = metadata_batch_to_md_array( + batch->payload->recv_initial_metadata.recv_initial_metadata); + chand->creds->auth_metadata_processor().process( + chand->creds->auth_metadata_processor().state, + chand->auth_context.get(), calld->md.metadata, calld->md.count, + on_md_processing_done, elem); + return; + } + } + grpc_closure* closure = calld->original_recv_initial_metadata_ready; + calld->original_recv_initial_metadata_ready = nullptr; + if (calld->seen_recv_trailing_metadata_ready) { + GRPC_CALL_COMBINER_START(calld->call_combiner, + &calld->recv_trailing_metadata_ready, + calld->recv_trailing_metadata_error, + "continue recv_trailing_metadata_ready"); + } + grpc_core::Closure::Run(DEBUG_LOCATION, closure, GRPC_ERROR_REF(error)); +} + +static void recv_trailing_metadata_ready(void* user_data, + grpc_error_handle err) { + grpc_call_element* elem = static_cast(user_data); + call_data* calld = static_cast(elem->call_data); + if (calld->original_recv_initial_metadata_ready != nullptr) { + calld->recv_trailing_metadata_error = GRPC_ERROR_REF(err); + calld->seen_recv_trailing_metadata_ready = true; + GRPC_CALL_COMBINER_STOP(calld->call_combiner, + "deferring recv_trailing_metadata_ready until " + "after recv_initial_metadata_ready"); + return; + } + err = grpc_error_add_child( + GRPC_ERROR_REF(err), GRPC_ERROR_REF(calld->recv_initial_metadata_error)); + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->original_recv_trailing_metadata_ready, err); +} + +static void server_auth_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + call_data* calld = static_cast(elem->call_data); + if (batch->recv_initial_metadata) { + // Inject our callback. 
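    // As a rough sketch of the application-side counterpart of this filter
    // (purely illustrative; `my_check_metadata` and `creds` are placeholder
    // names and error handling is elided), a server opts into the
    // on_md_processing_done() path by attaching an auth metadata processor
    // to its credentials:
    //
    //   static void my_check_metadata(void* state, grpc_auth_context* context,
    //                                 const grpc_metadata* md, size_t md_count,
    //                                 grpc_process_auth_metadata_done_cb cb,
    //                                 void* user_data) {
    //     // Inspect `md` here, then report the verdict exactly once.
    //     cb(user_data, /*consumed_md=*/nullptr, 0, /*response_md=*/nullptr, 0,
    //        GRPC_STATUS_OK, /*error_details=*/nullptr);
    //   }
    //
    //   grpc_auth_metadata_processor processor = {my_check_metadata,
    //                                             /*destroy=*/nullptr,
    //                                             /*state=*/nullptr};
    //   grpc_server_credentials_set_auth_metadata_processor(creds, processor);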
+ calld->recv_initial_metadata_batch = batch; + calld->original_recv_initial_metadata_ready = + batch->payload->recv_initial_metadata.recv_initial_metadata_ready; + batch->payload->recv_initial_metadata.recv_initial_metadata_ready = + &calld->recv_initial_metadata_ready; + } + if (batch->recv_trailing_metadata) { + calld->original_recv_trailing_metadata_ready = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready; + } + grpc_call_next_op(elem, batch); +} + +/* Constructor for call_data */ +static grpc_error_handle server_auth_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + new (elem->call_data) call_data(elem, *args); + return GRPC_ERROR_NONE; +} + +/* Destructor for call_data */ +static void server_auth_destroy_call_elem( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + call_data* calld = static_cast(elem->call_data); + calld->~call_data(); +} + +/* Constructor for channel_data */ +static grpc_error_handle server_auth_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + GPR_ASSERT(!args->is_last); + grpc_auth_context* auth_context = + grpc_find_auth_context_in_args(args->channel_args); + GPR_ASSERT(auth_context != nullptr); + grpc_server_credentials* creds = + grpc_find_server_credentials_in_args(args->channel_args); + new (elem->channel_data) channel_data(auth_context, creds); + return GRPC_ERROR_NONE; +} + +/* Destructor for channel data */ +static void server_auth_destroy_channel_elem(grpc_channel_element* elem) { + channel_data* chand = static_cast(elem->channel_data); + chand->~channel_data(); +} + +const grpc_channel_filter grpc_server_auth_filter = { + server_auth_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + server_auth_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + server_auth_destroy_call_elem, + sizeof(channel_data), + server_auth_init_channel_elem, + server_auth_destroy_channel_elem, + grpc_channel_next_get_info, + "server-auth"}; diff --git a/src/core/lib/security/transport/tsi_error.cc b/src/core/lib/security/transport/tsi_error.cc new file mode 100644 index 00000000..23b10811 --- /dev/null +++ b/src/core/lib/security/transport/tsi_error.cc @@ -0,0 +1,28 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/transport/tsi_error.h" + +grpc_error_handle grpc_set_tsi_error_result(grpc_error_handle error, + tsi_result result) { + return grpc_error_set_int(grpc_error_set_str(error, GRPC_ERROR_STR_TSI_ERROR, + tsi_result_to_string(result)), + GRPC_ERROR_INT_TSI_CODE, result); +} diff --git a/src/core/lib/security/util/json_util.cc b/src/core/lib/security/util/json_util.cc new file mode 100644 index 00000000..25223eef --- /dev/null +++ b/src/core/lib/security/util/json_util.cc @@ -0,0 +1,70 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/util/json_util.h" + +#include + +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/lib/iomgr/error.h" + +const char* grpc_json_get_string_property(const grpc_core::Json& json, + const char* prop_name, + grpc_error_handle* error) { + if (json.type() != grpc_core::Json::Type::OBJECT) { + if (error != nullptr) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("JSON value is not an object"); + } + return nullptr; + } + auto it = json.object_value().find(prop_name); + if (it == json.object_value().end()) { + if (error != nullptr) { + *error = GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Property ", prop_name, " not found in JSON object.")); + } + return nullptr; + } + if (it->second.type() != grpc_core::Json::Type::STRING) { + if (error != nullptr) { + *error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Property ", prop_name, " n JSON object is not a string.")); + } + return nullptr; + } + return it->second.string_value().c_str(); +} + +bool grpc_copy_json_string_property(const grpc_core::Json& json, + const char* prop_name, + char** copied_value) { + grpc_error_handle error = GRPC_ERROR_NONE; + const char* prop_value = + grpc_json_get_string_property(json, prop_name, &error); + GRPC_LOG_IF_ERROR("Could not copy JSON property", error); + if (prop_value == nullptr) return false; + *copied_value = gpr_strdup(prop_value); + return true; +} diff --git a/src/core/lib/slice/b64.cc b/src/core/lib/slice/b64.cc new file mode 100644 index 00000000..0a06a05c --- /dev/null +++ b/src/core/lib/slice/b64.cc @@ -0,0 +1,239 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/slice/b64.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/slice/slice_internal.h" + +/* --- Constants. --- */ + +static const int8_t base64_bytes[] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1, -1, 0x3F, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1, + -1, 0x7F, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, + 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, -1, + -1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, + 0x31, 0x32, 0x33, -1, -1, -1, -1, -1}; + +static const char base64_url_unsafe_chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static const char base64_url_safe_chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + +#define GRPC_BASE64_PAD_CHAR '=' +#define GRPC_BASE64_PAD_BYTE 0x7F +#define GRPC_BASE64_MULTILINE_LINE_LEN 76 +#define GRPC_BASE64_MULTILINE_NUM_BLOCKS (GRPC_BASE64_MULTILINE_LINE_LEN / 4) + +/* --- base64 functions. --- */ + +char* grpc_base64_encode(const void* vdata, size_t data_size, int url_safe, + int multiline) { + size_t result_projected_size = + grpc_base64_estimate_encoded_size(data_size, multiline); + char* result = static_cast(gpr_malloc(result_projected_size)); + grpc_base64_encode_core(result, vdata, data_size, url_safe, multiline); + return result; +} + +size_t grpc_base64_estimate_encoded_size(size_t data_size, int multiline) { + size_t result_projected_size = + 4 * ((data_size + 3) / 3) + + 2 * (multiline ? (data_size / (3 * GRPC_BASE64_MULTILINE_NUM_BLOCKS)) + : 0) + + 1; + return result_projected_size; +} + +void grpc_base64_encode_core(char* result, const void* vdata, size_t data_size, + int url_safe, int multiline) { + const unsigned char* data = static_cast(vdata); + const char* base64_chars = + url_safe ? base64_url_safe_chars : base64_url_unsafe_chars; + const size_t result_projected_size = + grpc_base64_estimate_encoded_size(data_size, multiline); + + char* current = result; + size_t num_blocks = 0; + size_t i = 0; + + /* Encode each block. */ + while (data_size >= 3) { + *current++ = base64_chars[(data[i] >> 2) & 0x3F]; + *current++ = + base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)]; + *current++ = + base64_chars[((data[i + 1] & 0x0F) << 2) | ((data[i + 2] >> 6) & 0x03)]; + *current++ = base64_chars[data[i + 2] & 0x3F]; + + data_size -= 3; + i += 3; + if (multiline && (++num_blocks == GRPC_BASE64_MULTILINE_NUM_BLOCKS)) { + *current++ = '\r'; + *current++ = '\n'; + num_blocks = 0; + } + } + + /* Take care of the tail. 
*/ + if (data_size == 2) { + *current++ = base64_chars[(data[i] >> 2) & 0x3F]; + *current++ = + base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)]; + *current++ = base64_chars[(data[i + 1] & 0x0F) << 2]; + *current++ = GRPC_BASE64_PAD_CHAR; + } else if (data_size == 1) { + *current++ = base64_chars[(data[i] >> 2) & 0x3F]; + *current++ = base64_chars[(data[i] & 0x03) << 4]; + *current++ = GRPC_BASE64_PAD_CHAR; + *current++ = GRPC_BASE64_PAD_CHAR; + } + + GPR_ASSERT(current >= result); + GPR_ASSERT((uintptr_t)(current - result) < result_projected_size); + result[current - result] = '\0'; +} + +grpc_slice grpc_base64_decode(const char* b64, int url_safe) { + return grpc_base64_decode_with_len(b64, strlen(b64), url_safe); +} + +static void decode_one_char(const unsigned char* codes, unsigned char* result, + size_t* result_offset) { + uint32_t packed = (static_cast(codes[0]) << 2) | + (static_cast(codes[1]) >> 4); + result[(*result_offset)++] = static_cast(packed); +} + +static void decode_two_chars(const unsigned char* codes, unsigned char* result, + size_t* result_offset) { + uint32_t packed = (static_cast(codes[0]) << 10) | + (static_cast(codes[1]) << 4) | + (static_cast(codes[2]) >> 2); + result[(*result_offset)++] = static_cast(packed >> 8); + result[(*result_offset)++] = static_cast(packed); +} + +static int decode_group(const unsigned char* codes, size_t num_codes, + unsigned char* result, size_t* result_offset) { + GPR_ASSERT(num_codes <= 4); + + /* Short end groups that may not have padding. */ + if (num_codes == 1) { + gpr_log(GPR_ERROR, "Invalid group. Must be at least 2 bytes."); + return 0; + } + if (num_codes == 2) { + decode_one_char(codes, result, result_offset); + return 1; + } + if (num_codes == 3) { + decode_two_chars(codes, result, result_offset); + return 1; + } + + /* Regular 4 byte groups with padding or not. */ + GPR_ASSERT(num_codes == 4); + if (codes[0] == GRPC_BASE64_PAD_BYTE || codes[1] == GRPC_BASE64_PAD_BYTE) { + gpr_log(GPR_ERROR, "Invalid padding detected."); + return 0; + } + if (codes[2] == GRPC_BASE64_PAD_BYTE) { + if (codes[3] == GRPC_BASE64_PAD_BYTE) { + decode_one_char(codes, result, result_offset); + } else { + gpr_log(GPR_ERROR, "Invalid padding detected."); + return 0; + } + } else if (codes[3] == GRPC_BASE64_PAD_BYTE) { + decode_two_chars(codes, result, result_offset); + } else { + /* No padding. 
*/ + uint32_t packed = (static_cast(codes[0]) << 18) | + (static_cast(codes[1]) << 12) | + (static_cast(codes[2]) << 6) | codes[3]; + result[(*result_offset)++] = static_cast(packed >> 16); + result[(*result_offset)++] = static_cast(packed >> 8); + result[(*result_offset)++] = static_cast(packed); + } + return 1; +} + +grpc_slice grpc_base64_decode_with_len(const char* b64, size_t b64_len, + int url_safe) { + grpc_slice result = GRPC_SLICE_MALLOC(b64_len); + unsigned char* current = GRPC_SLICE_START_PTR(result); + size_t result_size = 0; + unsigned char codes[4]; + size_t num_codes = 0; + + while (b64_len--) { + unsigned char c = static_cast(*b64++); + signed char code; + if (c >= GPR_ARRAY_SIZE(base64_bytes)) continue; + if (url_safe) { + if (c == '+' || c == '/') { + gpr_log(GPR_ERROR, "Invalid character for url safe base64 %c", c); + goto fail; + } + if (c == '-') { + c = '+'; + } else if (c == '_') { + c = '/'; + } + } + code = base64_bytes[c]; + if (code == -1) { + if (c != '\r' && c != '\n') { + gpr_log(GPR_ERROR, "Invalid character %c", c); + goto fail; + } + } else { + codes[num_codes++] = static_cast(code); + if (num_codes == 4) { + if (!decode_group(codes, num_codes, current, &result_size)) goto fail; + num_codes = 0; + } + } + } + + if (num_codes != 0 && + !decode_group(codes, num_codes, current, &result_size)) { + goto fail; + } + GRPC_SLICE_SET_LENGTH(result, result_size); + return result; + +fail: + grpc_slice_unref_internal(result); + return grpc_empty_slice(); +} diff --git a/src/core/lib/slice/percent_encoding.cc b/src/core/lib/slice/percent_encoding.cc new file mode 100644 index 00000000..5ec29350 --- /dev/null +++ b/src/core/lib/slice/percent_encoding.cc @@ -0,0 +1,212 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/slice/percent_encoding.h" + +#include + +#include + +#include "src/core/lib/gprpp/bitset.h" +#include "src/core/lib/slice/slice_internal.h" + +#if __cplusplus > 201103l +#define GRPC_PCTENCODE_CONSTEXPR_FN constexpr +#define GRPC_PCTENCODE_CONSTEXPR_VALUE constexpr +#else +#define GRPC_PCTENCODE_CONSTEXPR_FN +#define GRPC_PCTENCODE_CONSTEXPR_VALUE const +#endif + +namespace grpc_core { + +namespace { +class UrlTable : public BitSet<256> { + public: + GRPC_PCTENCODE_CONSTEXPR_FN UrlTable() { + for (int i = 'a'; i <= 'z'; i++) set(i); + for (int i = 'A'; i <= 'Z'; i++) set(i); + for (int i = '0'; i <= '9'; i++) set(i); + set('-'); + set('_'); + set('.'); + set('~'); + } +}; + +static GRPC_PCTENCODE_CONSTEXPR_VALUE UrlTable g_url_table; + +class CompatibleTable : public BitSet<256> { + public: + GRPC_PCTENCODE_CONSTEXPR_FN CompatibleTable() { + for (int i = 32; i <= 126; i++) { + if (i == '%') continue; + set(i); + } + } +}; + +static GRPC_PCTENCODE_CONSTEXPR_VALUE CompatibleTable g_compatible_table; + +// Map PercentEncodingType to a lookup table of legal symbols for that encoding. 
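// Before the table-selection helper below, a minimal round-trip sketch for the
// PercentEncodeSlice/PercentDecodeSlice functions defined further down (purely
// illustrative; assumes the declarations from percent_encoding.h and
// grpc/slice.h are in scope):
void PercentEncodingRoundTripExample() {
  grpc_slice raw = grpc_slice_from_static_string("100% of the time");
  grpc_slice encoded =
      PercentEncodeSlice(raw, PercentEncodingType::URL);
  // encoded now reads "100%25%20of%20the%20time": '%' and ' ' are not in the
  // URL table above, so each is replaced by a '%' plus two hex digits.
  absl::optional<grpc_slice> decoded =
      PercentDecodeSlice(encoded, PercentEncodingType::URL);
  GPR_ASSERT(decoded.has_value() && grpc_slice_eq(raw, *decoded));
  grpc_slice_unref(*decoded);
  grpc_slice_unref(encoded);
  grpc_slice_unref(raw);
}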
+const BitSet<256>& LookupTableForPercentEncodingType(PercentEncodingType type) { + switch (type) { + case PercentEncodingType::URL: + return g_url_table; + case PercentEncodingType::Compatible: + return g_compatible_table; + } + // Crash if a bad PercentEncodingType was passed in. + GPR_UNREACHABLE_CODE(abort()); +} +} // namespace + +grpc_slice PercentEncodeSlice(const grpc_slice& slice, + PercentEncodingType type) { + static const uint8_t hex[] = "0123456789ABCDEF"; + + const BitSet<256>& lut = LookupTableForPercentEncodingType(type); + + // first pass: count the number of bytes needed to output this string + size_t output_length = 0; + const uint8_t* slice_start = GRPC_SLICE_START_PTR(slice); + const uint8_t* slice_end = GRPC_SLICE_END_PTR(slice); + const uint8_t* p; + bool any_reserved_bytes = false; + for (p = slice_start; p < slice_end; p++) { + bool unres = lut.is_set(*p); + output_length += unres ? 1 : 3; + any_reserved_bytes |= !unres; + } + // no unreserved bytes: return the string unmodified + if (!any_reserved_bytes) { + return grpc_slice_ref_internal(slice); + } + // second pass: actually encode + grpc_slice out = GRPC_SLICE_MALLOC(output_length); + uint8_t* q = GRPC_SLICE_START_PTR(out); + for (p = slice_start; p < slice_end; p++) { + if (lut.is_set(*p)) { + *q++ = *p; + } else { + *q++ = '%'; + *q++ = hex[*p >> 4]; + *q++ = hex[*p & 15]; + } + } + GPR_ASSERT(q == GRPC_SLICE_END_PTR(out)); + return out; +} + +static bool valid_hex(const uint8_t* p, const uint8_t* end) { + if (p >= end) return false; + return (*p >= '0' && *p <= '9') || (*p >= 'a' && *p <= 'f') || + (*p >= 'A' && *p <= 'F'); +} + +static uint8_t dehex(uint8_t c) { + if (c >= '0' && c <= '9') return static_cast(c - '0'); + if (c >= 'A' && c <= 'F') return static_cast(c - 'A' + 10); + if (c >= 'a' && c <= 'f') return static_cast(c - 'a' + 10); + GPR_UNREACHABLE_CODE(return 255); +} + +absl::optional PercentDecodeSlice(const grpc_slice& slice_in, + PercentEncodingType type) { + const uint8_t* p = GRPC_SLICE_START_PTR(slice_in); + const uint8_t* in_end = GRPC_SLICE_END_PTR(slice_in); + size_t out_length = 0; + bool any_percent_encoded_stuff = false; + const BitSet<256>& lut = LookupTableForPercentEncodingType(type); + while (p != in_end) { + if (*p == '%') { + if (!valid_hex(++p, in_end)) return {}; + if (!valid_hex(++p, in_end)) return {}; + p++; + out_length++; + any_percent_encoded_stuff = true; + } else if (lut.is_set(*p)) { + p++; + out_length++; + } else { + return {}; + } + } + if (!any_percent_encoded_stuff) { + return grpc_slice_ref_internal(slice_in); + } + p = GRPC_SLICE_START_PTR(slice_in); + grpc_slice slice_out = GRPC_SLICE_MALLOC(out_length); + uint8_t* q = GRPC_SLICE_START_PTR(slice_out); + while (p != in_end) { + if (*p == '%') { + *q++ = static_cast(dehex(p[1]) << 4) | (dehex(p[2])); + p += 3; + } else { + *q++ = *p++; + } + } + GPR_ASSERT(q == GRPC_SLICE_END_PTR(slice_out)); + return slice_out; +} + +grpc_slice PermissivePercentDecodeSlice(const grpc_slice& slice_in) { + const uint8_t* p = GRPC_SLICE_START_PTR(slice_in); + const uint8_t* in_end = GRPC_SLICE_END_PTR(slice_in); + size_t out_length = 0; + bool any_percent_encoded_stuff = false; + while (p != in_end) { + if (*p == '%') { + if (!valid_hex(p + 1, in_end) || !valid_hex(p + 2, in_end)) { + p++; + out_length++; + } else { + p += 3; + out_length++; + any_percent_encoded_stuff = true; + } + } else { + p++; + out_length++; + } + } + if (!any_percent_encoded_stuff) { + return grpc_slice_ref_internal(slice_in); + } + p = 
GRPC_SLICE_START_PTR(slice_in); + grpc_slice out = GRPC_SLICE_MALLOC(out_length); + uint8_t* q = GRPC_SLICE_START_PTR(out); + while (p != in_end) { + if (*p == '%') { + if (!valid_hex(p + 1, in_end) || !valid_hex(p + 2, in_end)) { + *q++ = *p++; + } else { + *q++ = static_cast(dehex(p[1]) << 4) | (dehex(p[2])); + p += 3; + } + } else { + *q++ = *p++; + } + } + GPR_ASSERT(q == GRPC_SLICE_END_PTR(out)); + return out; +} + +} // namespace grpc_core diff --git a/src/core/lib/slice/slice.cc b/src/core/lib/slice/slice.cc new file mode 100644 index 00000000..8f15a066 --- /dev/null +++ b/src/core/lib/slice/slice.cc @@ -0,0 +1,592 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/slice/slice_internal.h" + +char* grpc_slice_to_c_string(grpc_slice slice) { + char* out = static_cast(gpr_malloc(GRPC_SLICE_LENGTH(slice) + 1)); + memcpy(out, GRPC_SLICE_START_PTR(slice), GRPC_SLICE_LENGTH(slice)); + out[GRPC_SLICE_LENGTH(slice)] = 0; + return out; +} + +grpc_slice grpc_empty_slice(void) { return grpc_core::UnmanagedMemorySlice(); } + +grpc_slice grpc_slice_copy(grpc_slice s) { + grpc_slice out = GRPC_SLICE_MALLOC(GRPC_SLICE_LENGTH(s)); + memcpy(GRPC_SLICE_START_PTR(out), GRPC_SLICE_START_PTR(s), + GRPC_SLICE_LENGTH(s)); + return out; +} + +namespace grpc_core { + +/* grpc_slice_from_static_string support structure - a refcount that does + nothing */ +grpc_slice_refcount kNoopRefcount(grpc_slice_refcount::Type::NOP); +static_assert(std::is_trivially_destructible::value, + "kNoopRefcount must be trivially destructible."); + +/* grpc_slice_new support structures - we create a refcount object extended + with the user provided data pointer & destroy function */ +class NewSliceRefcount { + public: + static void Destroy(void* arg) { delete static_cast(arg); } + + NewSliceRefcount(void (*destroy)(void*), void* user_data) + : base_(grpc_slice_refcount::Type::REGULAR, &refs_, Destroy, this, + &base_), + user_destroy_(destroy), + user_data_(user_data) {} + ~NewSliceRefcount() { user_destroy_(user_data_); } + + grpc_slice_refcount* base_refcount() { return &base_; } + + private: + grpc_slice_refcount base_; + RefCount refs_; + void (*user_destroy_)(void*); + void* user_data_; +}; + +} // namespace grpc_core + +size_t grpc_slice_memory_usage(grpc_slice s) { + if (s.refcount == nullptr || s.refcount == &grpc_core::kNoopRefcount) { + return 0; + } else { + return s.data.refcounted.length; + } +} + +grpc_slice grpc_slice_from_static_buffer(const void* s, size_t len) { + return grpc_core::ExternallyManagedSlice(s, len); +} + +grpc_slice grpc_slice_from_static_string(const char* s) { + return grpc_core::ExternallyManagedSlice(s, strlen(s)); +} + +grpc_slice grpc_slice_new_with_user_data(void* p, size_t len, + void (*destroy)(void*), + void* user_data) { + grpc_slice slice; + slice.refcount = + (new 
grpc_core::NewSliceRefcount(destroy, user_data))->base_refcount(); + slice.data.refcounted.bytes = static_cast(p); + slice.data.refcounted.length = len; + return slice; +} + +grpc_slice grpc_slice_new(void* p, size_t len, void (*destroy)(void*)) { + /* Pass "p" to *destroy when the slice is no longer needed. */ + return grpc_slice_new_with_user_data(p, len, destroy, p); +} + +namespace grpc_core { +/* grpc_slice_new_with_len support structures - we create a refcount object + extended with the user provided data pointer & destroy function */ +class NewWithLenSliceRefcount { + public: + static void Destroy(void* arg) { + delete static_cast(arg); + } + + NewWithLenSliceRefcount(void (*destroy)(void*, size_t), void* user_data, + size_t user_length) + : base_(grpc_slice_refcount::Type::REGULAR, &refs_, Destroy, this, + &base_), + user_data_(user_data), + user_length_(user_length), + user_destroy_(destroy) {} + ~NewWithLenSliceRefcount() { user_destroy_(user_data_, user_length_); } + + grpc_slice_refcount* base_refcount() { return &base_; } + + private: + grpc_slice_refcount base_; + RefCount refs_; + void* user_data_; + size_t user_length_; + void (*user_destroy_)(void*, size_t); +}; + +/** grpc_slice_from_moved_(string|buffer) ref count .*/ +class MovedStringSliceRefCount { + public: + explicit MovedStringSliceRefCount(grpc_core::UniquePtr&& str) + : base_(grpc_slice_refcount::Type::REGULAR, &refs_, Destroy, this, + &base_), + str_(std::move(str)) {} + + grpc_slice_refcount* base_refcount() { return &base_; } + + private: + static void Destroy(void* arg) { + delete static_cast(arg); + } + + grpc_slice_refcount base_; + grpc_core::RefCount refs_; + grpc_core::UniquePtr str_; +}; + +// grpc_slice_from_cpp_string() ref count. +class MovedCppStringSliceRefCount { + public: + explicit MovedCppStringSliceRefCount(std::string&& str) + : base_(grpc_slice_refcount::Type::REGULAR, &refs_, Destroy, this, + &base_), + str_(std::move(str)) {} + + grpc_slice_refcount* base_refcount() { return &base_; } + + private: + static void Destroy(void* arg) { + delete static_cast(arg); + } + + grpc_slice_refcount base_; + grpc_core::RefCount refs_; + std::string str_; +}; + +} // namespace grpc_core + +grpc_slice grpc_slice_new_with_len(void* p, size_t len, + void (*destroy)(void*, size_t)) { + grpc_slice slice; + slice.refcount = (new grpc_core::NewWithLenSliceRefcount(destroy, p, len)) + ->base_refcount(); + slice.data.refcounted.bytes = static_cast(p); + slice.data.refcounted.length = len; + return slice; +} + +grpc_core::UnmanagedMemorySlice::UnmanagedMemorySlice(const char* source, + size_t length) { + if (length <= sizeof(data.inlined.bytes)) { + refcount = nullptr; + data.inlined.length = static_cast(length); + } else { + HeapInit(length); + } + if (length > 0) { + memcpy(GRPC_SLICE_START_PTR(*this), source, length); + } +} + +grpc_core::UnmanagedMemorySlice::UnmanagedMemorySlice(const char* source) + : grpc_core::UnmanagedMemorySlice::UnmanagedMemorySlice(source, + strlen(source)) {} + +grpc_slice grpc_slice_from_copied_buffer(const char* source, size_t length) { + grpc_slice slice; + if (length <= sizeof(slice.data.inlined.bytes)) { + slice.refcount = nullptr; + slice.data.inlined.length = length; + } else { + // Create a ref-counted slice. 
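    // (Buffers that fit in data.inlined.bytes take the refcount == nullptr
    // branch above; only larger buffers pay for the heap-backed refcount
    // created below.)
    //
    // A purely illustrative sketch contrasting the ownership models in this
    // file (helper name and string contents are arbitrary):
    //
    //   void SliceOwnershipExample() {
    //     // Zero-copy: the literal must outlive every use of the slice.
    //     grpc_slice s1 = grpc_slice_from_static_string("static bytes");
    //     // Copying: short payloads are stored inline, longer ones get a
    //     // heap-allocated refcount.
    //     grpc_slice s2 = grpc_slice_from_copied_string("copied bytes");
    //     // Adopting caller-owned memory: the destroy callback runs when the
    //     // last reference goes away.
    //     char* owned = gpr_strdup("caller-owned bytes");
    //     grpc_slice s3 = grpc_slice_new_with_user_data(owned, strlen(owned),
    //                                                   gpr_free, owned);
    //     grpc_slice_unref(s3);
    //     grpc_slice_unref(s2);
    //     grpc_slice_unref(s1);
    //   }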
+ slice = grpc_core::UnmanagedMemorySlice( + length, grpc_core::UnmanagedMemorySlice::ForceHeapAllocation()); + } + memcpy(GRPC_SLICE_START_PTR(slice), source, length); + return slice; +} + +grpc_slice grpc_slice_from_copied_string(const char* source) { + return grpc_slice_from_copied_buffer(source, strlen(source)); +} + +grpc_slice grpc_slice_from_moved_buffer(grpc_core::UniquePtr p, + size_t len) { + uint8_t* ptr = reinterpret_cast(p.get()); + grpc_slice slice; + if (len <= sizeof(slice.data.inlined.bytes)) { + slice.refcount = nullptr; + slice.data.inlined.length = len; + memcpy(GRPC_SLICE_START_PTR(slice), ptr, len); + } else { + slice.refcount = (new grpc_core::MovedStringSliceRefCount(std::move(p))) + ->base_refcount(); + slice.data.refcounted.bytes = ptr; + slice.data.refcounted.length = len; + } + return slice; +} + +grpc_slice grpc_slice_from_moved_string(grpc_core::UniquePtr p) { + const size_t len = strlen(p.get()); + return grpc_slice_from_moved_buffer(std::move(p), len); +} + +grpc_slice grpc_slice_from_cpp_string(std::string str) { + grpc_slice slice; + if (str.size() <= sizeof(slice.data.inlined.bytes)) { + slice.refcount = nullptr; + slice.data.inlined.length = str.size(); + memcpy(GRPC_SLICE_START_PTR(slice), str.data(), str.size()); + } else { + slice.data.refcounted.bytes = + reinterpret_cast(const_cast(str.data())); + slice.data.refcounted.length = str.size(); + slice.refcount = + (new grpc_core::MovedCppStringSliceRefCount(std::move(str))) + ->base_refcount(); + } + return slice; +} + +namespace { + +class MallocRefCount { + public: + static void Destroy(void* arg) { + MallocRefCount* r = static_cast(arg); + r->~MallocRefCount(); + gpr_free(r); + } + + MallocRefCount() + : base_(grpc_slice_refcount::Type::REGULAR, &refs_, Destroy, this, + &base_) {} + ~MallocRefCount() = default; + + grpc_slice_refcount* base_refcount() { return &base_; } + + private: + grpc_slice_refcount base_; + grpc_core::RefCount refs_; +}; + +} // namespace + +grpc_slice grpc_slice_malloc_large(size_t length) { + return grpc_core::UnmanagedMemorySlice( + length, grpc_core::UnmanagedMemorySlice::ForceHeapAllocation()); +} + +void grpc_core::UnmanagedMemorySlice::HeapInit(size_t length) { + /* Memory layout used by the slice created here: + + +-----------+----------------------------------------------------------+ + | refcount | bytes | + +-----------+----------------------------------------------------------+ + + refcount is a malloc_refcount + bytes is an array of bytes of the requested length + Both parts are placed in the same allocation returned from gpr_malloc */ + auto* rc = + static_cast(gpr_malloc(sizeof(MallocRefCount) + length)); + + /* Initial refcount on rc is 1 - and it's up to the caller to release + this reference. */ + new (rc) MallocRefCount(); + + /* Build up the slice to be returned. */ + /* The slices refcount points back to the allocated block. 
*/ + refcount = rc->base_refcount(); + /* The data bytes are placed immediately after the refcount struct */ + data.refcounted.bytes = reinterpret_cast(rc + 1); + /* And the length of the block is set to the requested length */ + data.refcounted.length = length; +} + +grpc_slice grpc_slice_malloc(size_t length) { + return grpc_core::UnmanagedMemorySlice(length); +} + +grpc_core::UnmanagedMemorySlice::UnmanagedMemorySlice(size_t length) { + if (length > sizeof(data.inlined.bytes)) { + HeapInit(length); + } else { + /* small slice: just inline the data */ + refcount = nullptr; + data.inlined.length = static_cast(length); + } +} + +template +static Slice sub_no_ref(const Slice& source, size_t begin, size_t end) { + Slice subset; + + GPR_ASSERT(end >= begin); + + if (source.refcount) { + /* Enforce preconditions */ + GPR_ASSERT(source.data.refcounted.length >= end); + + /* Build the result */ + subset.refcount = source.refcount->sub_refcount(); + /* Point into the source array */ + subset.data.refcounted.bytes = source.data.refcounted.bytes + begin; + subset.data.refcounted.length = end - begin; + } else { + /* Enforce preconditions */ + GPR_ASSERT(source.data.inlined.length >= end); + subset.refcount = nullptr; + subset.data.inlined.length = static_cast(end - begin); + memcpy(subset.data.inlined.bytes, source.data.inlined.bytes + begin, + end - begin); + } + return subset; +} + +grpc_slice grpc_slice_sub_no_ref(grpc_slice source, size_t begin, size_t end) { + return sub_no_ref(source, begin, end); +} + +grpc_core::UnmanagedMemorySlice grpc_slice_sub_no_ref( + const grpc_core::UnmanagedMemorySlice& source, size_t begin, size_t end) { + return sub_no_ref(source, begin, end); +} + +grpc_slice grpc_slice_sub(grpc_slice source, size_t begin, size_t end) { + grpc_slice subset; + + if (end - begin <= sizeof(subset.data.inlined.bytes)) { + subset.refcount = nullptr; + subset.data.inlined.length = static_cast(end - begin); + memcpy(subset.data.inlined.bytes, GRPC_SLICE_START_PTR(source) + begin, + end - begin); + } else { + subset = grpc_slice_sub_no_ref(source, begin, end); + /* Bump the refcount */ + subset.refcount->Ref(); + } + return subset; +} + +grpc_slice grpc_slice_split_tail_maybe_ref(grpc_slice* source, size_t split, + grpc_slice_ref_whom ref_whom) { + grpc_slice tail; + + if (source->refcount == nullptr) { + /* inlined data, copy it out */ + GPR_ASSERT(source->data.inlined.length >= split); + tail.refcount = nullptr; + tail.data.inlined.length = + static_cast(source->data.inlined.length - split); + memcpy(tail.data.inlined.bytes, source->data.inlined.bytes + split, + tail.data.inlined.length); + source->data.inlined.length = static_cast(split); + } else { + size_t tail_length = source->data.refcounted.length - split; + GPR_ASSERT(source->data.refcounted.length >= split); + if (tail_length < sizeof(tail.data.inlined.bytes) && + ref_whom != GRPC_SLICE_REF_TAIL) { + /* Copy out the bytes - it'll be cheaper than refcounting */ + tail.refcount = nullptr; + tail.data.inlined.length = static_cast(tail_length); + memcpy(tail.data.inlined.bytes, source->data.refcounted.bytes + split, + tail_length); + source->refcount = source->refcount->sub_refcount(); + } else { + /* Build the result */ + switch (ref_whom) { + case GRPC_SLICE_REF_TAIL: + tail.refcount = source->refcount->sub_refcount(); + source->refcount = &grpc_core::kNoopRefcount; + break; + case GRPC_SLICE_REF_HEAD: + tail.refcount = &grpc_core::kNoopRefcount; + source->refcount = source->refcount->sub_refcount(); + break; + case 
GRPC_SLICE_REF_BOTH: + tail.refcount = source->refcount->sub_refcount(); + source->refcount = source->refcount->sub_refcount(); + /* Bump the refcount */ + tail.refcount->Ref(); + break; + } + /* Point into the source array */ + tail.data.refcounted.bytes = source->data.refcounted.bytes + split; + tail.data.refcounted.length = tail_length; + } + source->data.refcounted.length = split; + } + + return tail; +} + +grpc_slice grpc_slice_split_tail(grpc_slice* source, size_t split) { + return grpc_slice_split_tail_maybe_ref(source, split, GRPC_SLICE_REF_BOTH); +} + +grpc_slice grpc_slice_split_head(grpc_slice* source, size_t split) { + grpc_slice head; + + if (source->refcount == nullptr) { + GPR_ASSERT(source->data.inlined.length >= split); + + head.refcount = nullptr; + head.data.inlined.length = static_cast(split); + memcpy(head.data.inlined.bytes, source->data.inlined.bytes, split); + source->data.inlined.length = + static_cast(source->data.inlined.length - split); + memmove(source->data.inlined.bytes, source->data.inlined.bytes + split, + source->data.inlined.length); + } else if (split < sizeof(head.data.inlined.bytes)) { + GPR_ASSERT(source->data.refcounted.length >= split); + + head.refcount = nullptr; + head.data.inlined.length = static_cast(split); + memcpy(head.data.inlined.bytes, source->data.refcounted.bytes, split); + source->refcount = source->refcount->sub_refcount(); + source->data.refcounted.bytes += split; + source->data.refcounted.length -= split; + } else { + GPR_ASSERT(source->data.refcounted.length >= split); + + /* Build the result */ + head.refcount = source->refcount->sub_refcount(); + /* Bump the refcount */ + head.refcount->Ref(); + /* Point into the source array */ + head.data.refcounted.bytes = source->data.refcounted.bytes; + head.data.refcounted.length = split; + source->refcount = source->refcount->sub_refcount(); + source->data.refcounted.bytes += split; + source->data.refcounted.length -= split; + } + + return head; +} + +int grpc_slice_default_eq_impl(grpc_slice a, grpc_slice b) { + if (GRPC_SLICE_LENGTH(a) != GRPC_SLICE_LENGTH(b)) return false; + if (GRPC_SLICE_LENGTH(a) == 0) return true; + return 0 == memcmp(GRPC_SLICE_START_PTR(a), GRPC_SLICE_START_PTR(b), + GRPC_SLICE_LENGTH(a)); +} + +int grpc_slice_eq(grpc_slice a, grpc_slice b) { + if (a.refcount && b.refcount && + a.refcount->GetType() == b.refcount->GetType()) { + return a.refcount->Eq(a, b); + } + return grpc_slice_default_eq_impl(a, b); +} + +int grpc_slice_differs_refcounted(const grpc_slice& a, + const grpc_slice& b_not_inline) { + size_t a_len; + const uint8_t* a_ptr; + if (a.refcount) { + a_len = a.data.refcounted.length; + a_ptr = a.data.refcounted.bytes; + } else { + a_len = a.data.inlined.length; + a_ptr = &a.data.inlined.bytes[0]; + } + if (a_len != b_not_inline.data.refcounted.length) { + return true; + } + if (a_len == 0) { + return false; + } + // This check *must* occur after the a_len == 0 check + // to retain compatibility with grpc_slice_eq. 
+ if (a_ptr == nullptr) { + return true; + } + return memcmp(a_ptr, b_not_inline.data.refcounted.bytes, a_len); +} + +int grpc_slice_cmp(grpc_slice a, grpc_slice b) { + int d = static_cast(GRPC_SLICE_LENGTH(a) - GRPC_SLICE_LENGTH(b)); + if (d != 0) return d; + return memcmp(GRPC_SLICE_START_PTR(a), GRPC_SLICE_START_PTR(b), + GRPC_SLICE_LENGTH(a)); +} + +int grpc_slice_str_cmp(grpc_slice a, const char* b) { + size_t b_length = strlen(b); + int d = static_cast(GRPC_SLICE_LENGTH(a) - b_length); + if (d != 0) return d; + return memcmp(GRPC_SLICE_START_PTR(a), b, b_length); +} + +int grpc_slice_is_equivalent(grpc_slice a, grpc_slice b) { + if (a.refcount == nullptr || b.refcount == nullptr) { + return grpc_slice_eq(a, b); + } + return a.data.refcounted.length == b.data.refcounted.length && + a.data.refcounted.bytes == b.data.refcounted.bytes; +} + +int grpc_slice_buf_start_eq(grpc_slice a, const void* b, size_t len) { + if (GRPC_SLICE_LENGTH(a) < len) return 0; + return 0 == memcmp(GRPC_SLICE_START_PTR(a), b, len); +} + +int grpc_slice_rchr(grpc_slice s, char c) { + const char* b = reinterpret_cast GRPC_SLICE_START_PTR(s); + int i; + for (i = static_cast GRPC_SLICE_LENGTH(s) - 1; i != -1 && b[i] != c; + i--) { + } + return i; +} + +int grpc_slice_chr(grpc_slice s, char c) { + const char* b = reinterpret_cast GRPC_SLICE_START_PTR(s); + const char* p = static_cast(memchr(b, c, GRPC_SLICE_LENGTH(s))); + return p == nullptr ? -1 : static_cast(p - b); +} + +int grpc_slice_slice(grpc_slice haystack, grpc_slice needle) { + size_t haystack_len = GRPC_SLICE_LENGTH(haystack); + const uint8_t* haystack_bytes = GRPC_SLICE_START_PTR(haystack); + size_t needle_len = GRPC_SLICE_LENGTH(needle); + const uint8_t* needle_bytes = GRPC_SLICE_START_PTR(needle); + + if (haystack_len == 0 || needle_len == 0) return -1; + if (haystack_len < needle_len) return -1; + if (haystack_len == needle_len) { + return grpc_slice_eq(haystack, needle) ? 0 : -1; + } + if (needle_len == 1) { + return grpc_slice_chr(haystack, static_cast(*needle_bytes)); + } + + const uint8_t* last = haystack_bytes + haystack_len - needle_len; + for (const uint8_t* cur = haystack_bytes; cur != last; ++cur) { + if (0 == memcmp(cur, needle_bytes, needle_len)) { + return static_cast(cur - haystack_bytes); + } + } + return -1; +} + +grpc_slice grpc_slice_dup(grpc_slice a) { + grpc_slice copy = GRPC_SLICE_MALLOC(GRPC_SLICE_LENGTH(a)); + memcpy(GRPC_SLICE_START_PTR(copy), GRPC_SLICE_START_PTR(a), + GRPC_SLICE_LENGTH(a)); + return copy; +} diff --git a/src/core/lib/slice/slice_api.cc b/src/core/lib/slice/slice_api.cc new file mode 100644 index 00000000..28827d20 --- /dev/null +++ b/src/core/lib/slice/slice_api.cc @@ -0,0 +1,39 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" + +/* Public API */ +grpc_slice grpc_slice_ref(grpc_slice slice) { + return grpc_slice_ref_internal(slice); +} + +/* Public API */ +void grpc_slice_unref(grpc_slice slice) { + if (grpc_core::ExecCtx::Get() == nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_slice_unref_internal(slice); + } else { + grpc_slice_unref_internal(slice); + } +} diff --git a/src/core/lib/slice/slice_buffer.cc b/src/core/lib/slice/slice_buffer.cc new file mode 100644 index 00000000..d461fe64 --- /dev/null +++ b/src/core/lib/slice/slice_buffer.cc @@ -0,0 +1,413 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" + +/* grow a buffer; requires GRPC_SLICE_BUFFER_INLINE_ELEMENTS > 1 */ +#define GROW(x) (3 * (x) / 2) + +/* Typically, we do not actually need to embiggen (by calling + * memmove/malloc/realloc) - only if we were up against the full capacity of the + * slice buffer. If do_embiggen is inlined, the compiler clobbers multiple + * registers pointlessly in the common case. 
*/ +static void GPR_ATTRIBUTE_NOINLINE do_embiggen(grpc_slice_buffer* sb, + const size_t slice_count, + const size_t slice_offset) { + if (slice_offset != 0) { + /* Make room by moving elements if there's still space unused */ + memmove(sb->base_slices, sb->slices, sb->count * sizeof(grpc_slice)); + sb->slices = sb->base_slices; + } else { + /* Allocate more memory if no more space is available */ + const size_t new_capacity = GROW(sb->capacity); + sb->capacity = new_capacity; + if (sb->base_slices == sb->inlined) { + sb->base_slices = static_cast( + gpr_malloc(new_capacity * sizeof(grpc_slice))); + memcpy(sb->base_slices, sb->inlined, slice_count * sizeof(grpc_slice)); + } else { + sb->base_slices = static_cast( + gpr_realloc(sb->base_slices, new_capacity * sizeof(grpc_slice))); + } + + sb->slices = sb->base_slices + slice_offset; + } +} + +static void maybe_embiggen(grpc_slice_buffer* sb) { + if (sb->count == 0) { + sb->slices = sb->base_slices; + return; + } + + /* How far away from sb->base_slices is sb->slices pointer */ + size_t slice_offset = static_cast(sb->slices - sb->base_slices); + size_t slice_count = sb->count + slice_offset; + if (GPR_UNLIKELY(slice_count == sb->capacity)) { + do_embiggen(sb, slice_count, slice_offset); + } +} + +void grpc_slice_buffer_init(grpc_slice_buffer* sb) { + sb->count = 0; + sb->length = 0; + sb->capacity = GRPC_SLICE_BUFFER_INLINE_ELEMENTS; + sb->base_slices = sb->slices = sb->inlined; +} + +void grpc_slice_buffer_destroy_internal(grpc_slice_buffer* sb) { + grpc_slice_buffer_reset_and_unref_internal(sb); + if (sb->base_slices != sb->inlined) { + gpr_free(sb->base_slices); + } +} + +void grpc_slice_buffer_destroy(grpc_slice_buffer* sb) { + if (grpc_core::ExecCtx::Get() == nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_slice_buffer_destroy_internal(sb); + } else { + grpc_slice_buffer_destroy_internal(sb); + } +} + +uint8_t* grpc_slice_buffer_tiny_add(grpc_slice_buffer* sb, size_t n) { + grpc_slice* back; + uint8_t* out; + + sb->length += n; + + if (sb->count == 0) goto add_first; + back = &sb->slices[sb->count - 1]; + if (back->refcount) goto add_new; + if ((back->data.inlined.length + n) > sizeof(back->data.inlined.bytes)) { + goto add_new; + } + out = back->data.inlined.bytes + back->data.inlined.length; + back->data.inlined.length = + static_cast(back->data.inlined.length + n); + return out; + +add_new: + maybe_embiggen(sb); +add_first: + back = &sb->slices[sb->count]; + sb->count++; + back->refcount = nullptr; + back->data.inlined.length = static_cast(n); + return back->data.inlined.bytes; +} + +size_t grpc_slice_buffer_add_indexed(grpc_slice_buffer* sb, grpc_slice s) { + size_t out = sb->count; + maybe_embiggen(sb); + sb->slices[out] = s; + sb->length += GRPC_SLICE_LENGTH(s); + sb->count = out + 1; + return out; +} + +void grpc_slice_buffer_add(grpc_slice_buffer* sb, grpc_slice s) { + size_t n = sb->count; + /* if both the last slice in the slice buffer and the slice being added + are inlined (that is, that they carry their data inside the slice data + structure), and the back slice is not full, then concatenate directly + into the back slice, preventing many small slices being passed into + writes */ + if (!s.refcount && n) { + grpc_slice* back = &sb->slices[n - 1]; + if (!back->refcount && + back->data.inlined.length < GRPC_SLICE_INLINED_SIZE) { + if (s.data.inlined.length + back->data.inlined.length <= + GRPC_SLICE_INLINED_SIZE) { + memcpy(back->data.inlined.bytes + back->data.inlined.length, + s.data.inlined.bytes, s.data.inlined.length); 
+ back->data.inlined.length = static_cast( + back->data.inlined.length + s.data.inlined.length); + } else { + size_t cp1 = GRPC_SLICE_INLINED_SIZE - back->data.inlined.length; + memcpy(back->data.inlined.bytes + back->data.inlined.length, + s.data.inlined.bytes, cp1); + back->data.inlined.length = GRPC_SLICE_INLINED_SIZE; + maybe_embiggen(sb); + back = &sb->slices[n]; + sb->count = n + 1; + back->refcount = nullptr; + back->data.inlined.length = + static_cast(s.data.inlined.length - cp1); + memcpy(back->data.inlined.bytes, s.data.inlined.bytes + cp1, + s.data.inlined.length - cp1); + } + sb->length += s.data.inlined.length; + return; /* early out */ + } + } + grpc_slice_buffer_add_indexed(sb, s); +} + +void grpc_slice_buffer_addn(grpc_slice_buffer* sb, grpc_slice* s, size_t n) { + size_t i; + for (i = 0; i < n; i++) { + grpc_slice_buffer_add(sb, s[i]); + } +} + +void grpc_slice_buffer_pop(grpc_slice_buffer* sb) { + if (sb->count != 0) { + size_t count = --sb->count; + sb->length -= GRPC_SLICE_LENGTH(sb->slices[count]); + } +} + +void grpc_slice_buffer_reset_and_unref_internal(grpc_slice_buffer* sb) { + size_t i; + for (i = 0; i < sb->count; i++) { + grpc_slice_unref_internal(sb->slices[i]); + } + + sb->count = 0; + sb->length = 0; + sb->slices = sb->base_slices; +} + +void grpc_slice_buffer_reset_and_unref(grpc_slice_buffer* sb) { + if (grpc_core::ExecCtx::Get() == nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_slice_buffer_reset_and_unref_internal(sb); + } else { + grpc_slice_buffer_reset_and_unref_internal(sb); + } +} + +void grpc_slice_buffer_swap(grpc_slice_buffer* a, grpc_slice_buffer* b) { + size_t a_offset = static_cast(a->slices - a->base_slices); + size_t b_offset = static_cast(b->slices - b->base_slices); + + size_t a_count = a->count + a_offset; + size_t b_count = b->count + b_offset; + + if (a->base_slices == a->inlined) { + if (b->base_slices == b->inlined) { + /* swap contents of inlined buffer */ + grpc_slice temp[GRPC_SLICE_BUFFER_INLINE_ELEMENTS]; + memcpy(temp, a->base_slices, a_count * sizeof(grpc_slice)); + memcpy(a->base_slices, b->base_slices, b_count * sizeof(grpc_slice)); + memcpy(b->base_slices, temp, a_count * sizeof(grpc_slice)); + } else { + /* a is inlined, b is not - copy a inlined into b, fix pointers */ + a->base_slices = b->base_slices; + b->base_slices = b->inlined; + memcpy(b->base_slices, a->inlined, a_count * sizeof(grpc_slice)); + } + } else if (b->base_slices == b->inlined) { + /* b is inlined, a is not - copy b inlined int a, fix pointers */ + b->base_slices = a->base_slices; + a->base_slices = a->inlined; + memcpy(a->base_slices, b->inlined, b_count * sizeof(grpc_slice)); + } else { + /* no inlining: easy swap */ + std::swap(a->base_slices, b->base_slices); + } + + /* Update the slices pointers (cannot do a std::swap on slices fields here). + * Also note that since the base_slices pointers are already swapped we need + * use 'b_offset' for 'a->base_slices' and vice versa */ + a->slices = a->base_slices + b_offset; + b->slices = b->base_slices + a_offset; + + /* base_slices and slices fields are correctly set. Swap all other fields */ + std::swap(a->count, b->count); + std::swap(a->capacity, b->capacity); + std::swap(a->length, b->length); +} + +void grpc_slice_buffer_move_into(grpc_slice_buffer* src, + grpc_slice_buffer* dst) { + /* anything to move? */ + if (src->count == 0) { + return; + } + /* anything in dst? 
*/ + if (dst->count == 0) { + grpc_slice_buffer_swap(src, dst); + return; + } + /* both buffers have data - copy, and reset src */ + grpc_slice_buffer_addn(dst, src->slices, src->count); + src->count = 0; + src->length = 0; +} + +template +static void slice_buffer_move_first_maybe_ref(grpc_slice_buffer* src, size_t n, + grpc_slice_buffer* dst) { + GPR_ASSERT(src->length >= n); + if (src->length == n) { + grpc_slice_buffer_move_into(src, dst); + return; + } + + size_t output_len = dst->length + n; + size_t new_input_len = src->length - n; + + while (src->count > 0) { + grpc_slice slice = grpc_slice_buffer_take_first(src); + size_t slice_len = GRPC_SLICE_LENGTH(slice); + if (n > slice_len) { + grpc_slice_buffer_add(dst, slice); + n -= slice_len; + } else if (n == slice_len) { + grpc_slice_buffer_add(dst, slice); + break; + } else if (incref) { /* n < slice_len */ + grpc_slice_buffer_undo_take_first( + src, grpc_slice_split_tail_maybe_ref(&slice, n, GRPC_SLICE_REF_BOTH)); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == n); + grpc_slice_buffer_add(dst, slice); + break; + } else { /* n < slice_len */ + grpc_slice_buffer_undo_take_first( + src, grpc_slice_split_tail_maybe_ref(&slice, n, GRPC_SLICE_REF_TAIL)); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == n); + grpc_slice_buffer_add_indexed(dst, slice); + break; + } + } + GPR_ASSERT(dst->length == output_len); + GPR_ASSERT(src->length == new_input_len); + GPR_ASSERT(src->count > 0); +} + +void grpc_slice_buffer_move_first(grpc_slice_buffer* src, size_t n, + grpc_slice_buffer* dst) { + slice_buffer_move_first_maybe_ref(src, n, dst); +} + +void grpc_slice_buffer_move_first_no_ref(grpc_slice_buffer* src, size_t n, + grpc_slice_buffer* dst) { + slice_buffer_move_first_maybe_ref(src, n, dst); +} + +void grpc_slice_buffer_move_first_into_buffer(grpc_slice_buffer* src, size_t n, + void* dst) { + char* dstp = static_cast(dst); + GPR_ASSERT(src->length >= n); + + while (n > 0) { + grpc_slice slice = grpc_slice_buffer_take_first(src); + size_t slice_len = GRPC_SLICE_LENGTH(slice); + if (slice_len > n) { + memcpy(dstp, GRPC_SLICE_START_PTR(slice), n); + grpc_slice_buffer_undo_take_first( + src, grpc_slice_sub_no_ref(slice, n, slice_len)); + n = 0; + } else if (slice_len == n) { + memcpy(dstp, GRPC_SLICE_START_PTR(slice), n); + grpc_slice_unref_internal(slice); + n = 0; + } else { + memcpy(dstp, GRPC_SLICE_START_PTR(slice), slice_len); + dstp += slice_len; + n -= slice_len; + grpc_slice_unref_internal(slice); + } + } +} + +void grpc_slice_buffer_trim_end(grpc_slice_buffer* sb, size_t n, + grpc_slice_buffer* garbage) { + GPR_ASSERT(n <= sb->length); + sb->length -= n; + for (;;) { + size_t idx = sb->count - 1; + grpc_slice slice = sb->slices[idx]; + size_t slice_len = GRPC_SLICE_LENGTH(slice); + if (slice_len > n) { + sb->slices[idx] = grpc_slice_split_head(&slice, slice_len - n); + if (garbage) { + grpc_slice_buffer_add_indexed(garbage, slice); + } else { + grpc_slice_unref_internal(slice); + } + return; + } else if (slice_len == n) { + if (garbage) { + grpc_slice_buffer_add_indexed(garbage, slice); + } else { + grpc_slice_unref_internal(slice); + } + sb->count = idx; + return; + } else { + if (garbage) { + grpc_slice_buffer_add_indexed(garbage, slice); + } else { + grpc_slice_unref_internal(slice); + } + n -= slice_len; + sb->count = idx; + } + } +} + +grpc_slice grpc_slice_buffer_take_first(grpc_slice_buffer* sb) { + grpc_slice slice; + GPR_ASSERT(sb->count > 0); + slice = sb->slices[0]; + sb->slices++; + sb->count--; + sb->length -= GRPC_SLICE_LENGTH(slice); + + 
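  // A short, purely illustrative sketch of the public slice-buffer API
  // implemented in this file (assumes the grpc/slice_buffer.h declarations are
  // in scope; the byte counts below follow from the two slices added):
  //
  //   void SliceBufferExample() {
  //     grpc_slice_buffer src;
  //     grpc_slice_buffer dst;
  //     grpc_slice_buffer_init(&src);
  //     grpc_slice_buffer_init(&dst);
  //     grpc_slice_buffer_add(&src, grpc_slice_from_static_string("hello "));
  //     grpc_slice_buffer_add(&src, grpc_slice_from_static_string("world"));
  //     // Move the first 6 bytes ("hello ") into dst, leaving "world" behind.
  //     grpc_slice_buffer_move_first(&src, 6, &dst);
  //     GPR_ASSERT(dst.length == 6 && src.length == 5);
  //     grpc_slice_buffer_destroy(&dst);
  //     grpc_slice_buffer_destroy(&src);
  //   }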
return slice; +} + +void grpc_slice_buffer_remove_first(grpc_slice_buffer* sb) { + GPR_DEBUG_ASSERT(sb->count > 0); + sb->length -= GRPC_SLICE_LENGTH(sb->slices[0]); + grpc_slice_unref_internal(sb->slices[0]); + sb->slices++; + if (--sb->count == 0) { + sb->slices = sb->base_slices; + } +} + +void grpc_slice_buffer_sub_first(grpc_slice_buffer* sb, size_t begin, + size_t end) { + // TODO(soheil): Introduce a ptr version for sub. + sb->length -= GRPC_SLICE_LENGTH(sb->slices[0]); + sb->slices[0] = grpc_slice_sub_no_ref(sb->slices[0], begin, end); + sb->length += end - begin; +} + +void grpc_slice_buffer_undo_take_first(grpc_slice_buffer* sb, + grpc_slice slice) { + sb->slices--; + sb->slices[0] = slice; + sb->count++; + sb->length += GRPC_SLICE_LENGTH(slice); +} diff --git a/src/core/lib/slice/slice_intern.cc b/src/core/lib/slice/slice_intern.cc new file mode 100644 index 00000000..d3744570 --- /dev/null +++ b/src/core/lib/slice/slice_intern.cc @@ -0,0 +1,367 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/murmur_hash.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/iomgr_internal.h" /* for iomgr_abort_on_leaks() */ +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/slice/slice_utils.h" +#include "src/core/lib/transport/static_metadata.h" + +#define LOG2_SHARD_COUNT 5 +#define SHARD_COUNT (1 << LOG2_SHARD_COUNT) +#define INITIAL_SHARD_CAPACITY 8 + +#define TABLE_IDX(hash, capacity) (((hash) >> LOG2_SHARD_COUNT) % (capacity)) +#define SHARD_IDX(hash) ((hash) & ((1 << LOG2_SHARD_COUNT) - 1)) + +using grpc_core::InternedSliceRefcount; + +typedef struct slice_shard { + grpc_core::Mutex mu; + InternedSliceRefcount** strs; + size_t count; + size_t capacity; +} slice_shard; + +static slice_shard* g_shards; + +struct static_metadata_hash_ent { + uint32_t hash; + uint32_t idx; +}; +static static_metadata_hash_ent + static_metadata_hash[4 * GRPC_STATIC_MDSTR_COUNT]; +static uint32_t max_static_metadata_hash_probe; +uint32_t grpc_static_metadata_hash_values[GRPC_STATIC_MDSTR_COUNT]; + +namespace grpc_core { + +/* hash seed: decided at initialization time */ +uint32_t g_hash_seed; +static bool g_forced_hash_seed = false; + +InternedSliceRefcount::~InternedSliceRefcount() { + slice_shard* shard = &g_shards[SHARD_IDX(this->hash)]; + MutexLock lock(&shard->mu); + InternedSliceRefcount** prev_next; + InternedSliceRefcount* cur; + for (prev_next = &shard->strs[TABLE_IDX(this->hash, shard->capacity)], + cur = *prev_next; + cur != this; prev_next = &cur->bucket_next, cur = cur->bucket_next) { + } + *prev_next = cur->bucket_next; + shard->count--; +} + +} // namespace grpc_core + +static void grow_shard(slice_shard* shard) { + GPR_TIMER_SCOPE("grow_strtab", 0); + + size_t capacity = shard->capacity * 2; + size_t i; + 
InternedSliceRefcount** strtab; + InternedSliceRefcount *s, *next; + + strtab = static_cast( + gpr_zalloc(sizeof(InternedSliceRefcount*) * capacity)); + + for (i = 0; i < shard->capacity; i++) { + for (s = shard->strs[i]; s; s = next) { + size_t idx = TABLE_IDX(s->hash, capacity); + next = s->bucket_next; + s->bucket_next = strtab[idx]; + strtab[idx] = s; + } + } + gpr_free(shard->strs); + shard->strs = strtab; + shard->capacity = capacity; +} + +grpc_core::InternedSlice::InternedSlice(InternedSliceRefcount* s) { + refcount = &s->base; + data.refcounted.bytes = reinterpret_cast(s + 1); + data.refcounted.length = s->length; +} + +uint32_t grpc_slice_default_hash_impl(grpc_slice s) { + return gpr_murmur_hash3(GRPC_SLICE_START_PTR(s), GRPC_SLICE_LENGTH(s), + grpc_core::g_hash_seed); +} + +uint32_t grpc_static_slice_hash(grpc_slice s) { + return grpc_static_metadata_hash_values[GRPC_STATIC_METADATA_INDEX(s)]; +} + +int grpc_static_slice_eq(grpc_slice a, grpc_slice b) { + return GRPC_STATIC_METADATA_INDEX(a) == GRPC_STATIC_METADATA_INDEX(b); +} + +uint32_t grpc_slice_hash(grpc_slice s) { return grpc_slice_hash_internal(s); } + +grpc_slice grpc_slice_maybe_static_intern(grpc_slice slice, + bool* returned_slice_is_different) { + if (GRPC_IS_STATIC_METADATA_STRING(slice)) { + return slice; + } + + uint32_t hash = grpc_slice_hash_internal(slice); + for (uint32_t i = 0; i <= max_static_metadata_hash_probe; i++) { + static_metadata_hash_ent ent = + static_metadata_hash[(hash + i) % GPR_ARRAY_SIZE(static_metadata_hash)]; + if (ent.hash == hash && ent.idx < GRPC_STATIC_MDSTR_COUNT && + grpc_slice_eq_static_interned( + slice, grpc_core::g_static_metadata_slice_table[ent.idx])) { + *returned_slice_is_different = true; + return grpc_core::g_static_metadata_slice_table[ent.idx]; + } + } + + return slice; +} + +grpc_slice grpc_slice_intern(grpc_slice slice) { + /* TODO(arjunroy): At present, this is capable of returning either a static or + an interned slice. This yields weirdness like the constructor for + ManagedMemorySlice instantiating itself as an instance of a derived type + (StaticMetadataSlice or InternedSlice). Should reexamine. */ + return grpc_core::ManagedMemorySlice(&slice); +} + +// Attempt to see if the provided slice or string matches a static slice. +// SliceArgs is either a const grpc_slice& or const pair&. +// In either case, hash is the pre-computed hash value. +// +// Returns: a matching static slice, or null. +template +static const grpc_core::StaticMetadataSlice* MatchStaticSlice( + uint32_t hash, const SliceArgs& args) { + for (uint32_t i = 0; i <= max_static_metadata_hash_probe; i++) { + static_metadata_hash_ent ent = + static_metadata_hash[(hash + i) % GPR_ARRAY_SIZE(static_metadata_hash)]; + if (ent.hash == hash && ent.idx < GRPC_STATIC_MDSTR_COUNT && + grpc_core::g_static_metadata_slice_table[ent.idx] == args) { + return &grpc_core::g_static_metadata_slice_table[ent.idx]; + } + } + return nullptr; +} + +// Helper methods to enable us to select appropriately overloaded slice methods +// whether we're dealing with a slice, or a buffer with length, when interning +// strings. Helpers for FindOrCreateInternedSlice(). 
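
A usage-level sketch of what the static-table probe above buys callers (assumptions: plain C-core public API as declared in grpc/slice.h; whether a given string is served from the static table or the sharded intern table is an internal detail of the generated metadata):

#include <grpc/grpc.h>
#include <grpc/slice.h>
#include <grpc/support/log.h>

int main() {
  grpc_init();
  // Two slices created independently but carrying the same bytes...
  grpc_slice copied = grpc_slice_from_copied_string("te");
  grpc_slice a = grpc_slice_intern(grpc_slice_from_static_string("te"));
  grpc_slice b = grpc_slice_intern(copied);
  // ...intern to shared storage, so equality checks can short-circuit on the
  // shared refcount/pointer instead of always comparing bytes.
  GPR_ASSERT(grpc_slice_eq(a, b));
  grpc_slice_unref(copied);
  grpc_slice_unref(a);
  grpc_slice_unref(b);
  grpc_shutdown();
  return 0;
}
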
+static const char* GetBuffer(const std::pair& buflen) { + return buflen.first; +} +static size_t GetLength(const std::pair& buflen) { + return buflen.second; +} +static const void* GetBuffer(const grpc_slice& slice) { + return GRPC_SLICE_START_PTR(slice); +} +static size_t GetLength(const grpc_slice& slice) { + return GRPC_SLICE_LENGTH(slice); +} + +// Creates an interned slice for a string that does not currently exist in the +// intern table. SliceArgs is either a const grpc_slice& or a const +// pair&. Hash is the pre-computed hash value. We must +// already hold the shard lock. Helper for FindOrCreateInternedSlice(). +// +// Returns: a newly interned slice. +template +static InternedSliceRefcount* InternNewStringLocked(slice_shard* shard, + size_t shard_idx, + uint32_t hash, + const SliceArgs& args) { + /* string data goes after the internal_string header */ + size_t len = GetLength(args); + const void* buffer = GetBuffer(args); + InternedSliceRefcount* s = + static_cast(gpr_malloc(sizeof(*s) + len)); + new (s) grpc_core::InternedSliceRefcount(len, hash, shard->strs[shard_idx]); + // TODO(arjunroy): Investigate why hpack tried to intern the nullptr string. + // https://github.com/grpc/grpc/pull/20110#issuecomment-526729282 + if (len > 0) { + memcpy(reinterpret_cast(s + 1), buffer, len); + } + shard->strs[shard_idx] = s; + shard->count++; + if (shard->count > shard->capacity * 2) { + grow_shard(shard); + } + return s; +} + +// Attempt to see if the provided slice or string matches an existing interned +// slice. SliceArgs... is either a const grpc_slice& or a string and length. In +// either case, hash is the pre-computed hash value. We must already hold the +// shard lock. Helper for FindOrCreateInternedSlice(). +// +// Returns: a pre-existing matching static slice, or null. +template +static InternedSliceRefcount* MatchInternedSliceLocked(uint32_t hash, + size_t idx, + const SliceArgs& args) { + InternedSliceRefcount* s; + slice_shard* shard = &g_shards[SHARD_IDX(hash)]; + /* search for an existing string */ + for (s = shard->strs[idx]; s; s = s->bucket_next) { + if (s->hash == hash && grpc_core::InternedSlice(s) == args) { + if (s->refcnt.RefIfNonZero()) { + return s; + } + } + } + return nullptr; +} + +// Attempt to see if the provided slice or string matches an existing interned +// slice, and failing that, create an interned slice with its contents. Returns +// either the existing matching interned slice or the newly created one. +// SliceArgs is either a const grpc_slice& or const pair&. +// In either case, hash is the pre-computed hash value. We do not hold the +// shard lock here, but do take it. +// +// Returns: an interned slice, either pre-existing/matched or newly created. 
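
Before the implementation that follows, a minimal standalone sketch of the same find-or-create shape in plain C++ (hypothetical InternTable type, not gRPC code): hash once, lock only the owning shard, search that shard for a still-live entry, and create a new one only on a miss. Here weak_ptr::lock() plays the role that RefIfNonZero() plays above.

#include <array>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

class InternTable {
 public:
  std::shared_ptr<const std::string> Intern(const std::string& s) {
    const size_t hash = std::hash<std::string>()(s);
    Shard& shard = shards_[hash % kShardCount];           // pick one shard
    std::lock_guard<std::mutex> lock(shard.mu);           // lock only that shard
    auto it = shard.entries.find(s);
    if (it != shard.entries.end()) {
      if (auto existing = it->second.lock()) {            // "RefIfNonZero"
        return existing;                                  // matched a live entry
      }
    }
    auto fresh = std::make_shared<const std::string>(s);  // miss: create
    shard.entries[s] = fresh;                             // publish under the lock
    return fresh;
  }

 private:
  static constexpr size_t kShardCount = 32;  // mirrors SHARD_COUNT above
  struct Shard {
    std::mutex mu;
    std::unordered_map<std::string, std::weak_ptr<const std::string>> entries;
  };
  std::array<Shard, kShardCount> shards_;
};
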
+template +static InternedSliceRefcount* FindOrCreateInternedSlice(uint32_t hash, + const SliceArgs& args) { + slice_shard* shard = &g_shards[SHARD_IDX(hash)]; + grpc_core::MutexLock lock(&shard->mu); + const size_t idx = TABLE_IDX(hash, shard->capacity); + InternedSliceRefcount* s = MatchInternedSliceLocked(hash, idx, args); + if (s == nullptr) { + s = InternNewStringLocked(shard, idx, hash, args); + } + return s; +} + +grpc_core::ManagedMemorySlice::ManagedMemorySlice(const char* string) + : grpc_core::ManagedMemorySlice::ManagedMemorySlice(string, + strlen(string)) {} + +grpc_core::ManagedMemorySlice::ManagedMemorySlice(const char* buf, size_t len) { + GPR_TIMER_SCOPE("grpc_slice_intern", 0); + const uint32_t hash = gpr_murmur_hash3(buf, len, g_hash_seed); + const StaticMetadataSlice* static_slice = + MatchStaticSlice(hash, std::pair(buf, len)); + if (static_slice) { + *this = *static_slice; + } else { + *this = grpc_core::InternedSlice(FindOrCreateInternedSlice( + hash, std::pair(buf, len))); + } +} + +grpc_core::ManagedMemorySlice::ManagedMemorySlice(const grpc_slice* slice_ptr) { + GPR_TIMER_SCOPE("grpc_slice_intern", 0); + const grpc_slice& slice = *slice_ptr; + if (GRPC_IS_STATIC_METADATA_STRING(slice)) { + *this = static_cast(slice); + return; + } + const uint32_t hash = grpc_slice_hash_internal(slice); + const StaticMetadataSlice* static_slice = MatchStaticSlice(hash, slice); + if (static_slice) { + *this = *static_slice; + } else { + *this = grpc_core::InternedSlice(FindOrCreateInternedSlice(hash, slice)); + } +} + +void grpc_test_only_set_slice_hash_seed(uint32_t seed) { + grpc_core::g_hash_seed = seed; + grpc_core::g_forced_hash_seed = true; +} + +void grpc_slice_intern_init(void) { + if (!grpc_core::g_forced_hash_seed) { + grpc_core::g_hash_seed = + static_cast(gpr_now(GPR_CLOCK_REALTIME).tv_nsec); + } + g_shards = new slice_shard[SHARD_COUNT]; + for (size_t i = 0; i < SHARD_COUNT; i++) { + slice_shard* shard = &g_shards[i]; + shard->count = 0; + shard->capacity = INITIAL_SHARD_CAPACITY; + shard->strs = static_cast( + gpr_zalloc(sizeof(*shard->strs) * shard->capacity)); + } + for (size_t i = 0; i < GPR_ARRAY_SIZE(static_metadata_hash); i++) { + static_metadata_hash[i].hash = 0; + static_metadata_hash[i].idx = GRPC_STATIC_MDSTR_COUNT; + } + max_static_metadata_hash_probe = 0; + for (size_t i = 0; i < GRPC_STATIC_MDSTR_COUNT; i++) { + grpc_static_metadata_hash_values[i] = grpc_slice_default_hash_internal( + grpc_core::g_static_metadata_slice_table[i]); + for (size_t j = 0; j < GPR_ARRAY_SIZE(static_metadata_hash); j++) { + size_t slot = (grpc_static_metadata_hash_values[i] + j) % + GPR_ARRAY_SIZE(static_metadata_hash); + if (static_metadata_hash[slot].idx == GRPC_STATIC_MDSTR_COUNT) { + static_metadata_hash[slot].hash = grpc_static_metadata_hash_values[i]; + static_metadata_hash[slot].idx = static_cast(i); + if (j > max_static_metadata_hash_probe) { + max_static_metadata_hash_probe = static_cast(j); + } + break; + } + } + } + // Handle KV hash for all static mdelems. 
+ for (size_t i = 0; i < GRPC_STATIC_MDELEM_COUNT; ++i) { + grpc_core::g_static_mdelem_table[i].HashInit(); + } +} + +void grpc_slice_intern_shutdown(void) { + for (size_t i = 0; i < SHARD_COUNT; i++) { + slice_shard* shard = &g_shards[i]; + /* TODO(ctiller): GPR_ASSERT(shard->count == 0); */ + if (shard->count != 0) { + gpr_log(GPR_DEBUG, "WARNING: %" PRIuPTR " metadata strings were leaked", + shard->count); + for (size_t j = 0; j < shard->capacity; j++) { + for (InternedSliceRefcount* s = shard->strs[j]; s; s = s->bucket_next) { + char* text = grpc_dump_slice(grpc_core::InternedSlice(s), + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "LEAKED: %s", text); + gpr_free(text); + } + } + if (grpc_iomgr_abort_on_leaks()) { + abort(); + } + } + gpr_free(shard->strs); + } + delete[] g_shards; +} diff --git a/src/core/lib/slice/slice_refcount.cc b/src/core/lib/slice/slice_refcount.cc new file mode 100644 index 00000000..10de62b7 --- /dev/null +++ b/src/core/lib/slice/slice_refcount.cc @@ -0,0 +1,17 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/slice/slice_refcount.h" diff --git a/src/core/lib/slice/slice_split.cc b/src/core/lib/slice/slice_split.cc new file mode 100644 index 00000000..d2a951cb --- /dev/null +++ b/src/core/lib/slice/slice_split.cc @@ -0,0 +1,100 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/slice/slice_split.h" + +#include + +#include + +/** Finds the initial (\a begin) and final (\a end) offsets of the next + * substring from \a str + \a read_offset until the next \a sep or the end of \a + * str. + * + * Returns 1 and updates \a begin and \a end. Returns 0 otherwise. 
*/ +static int slice_find_separator_offset(const grpc_slice str, const char* sep, + const size_t read_offset, size_t* begin, + size_t* end) { + size_t i; + const uint8_t* str_ptr = GRPC_SLICE_START_PTR(str) + read_offset; + const size_t str_len = GRPC_SLICE_LENGTH(str) - read_offset; + const size_t sep_len = strlen(sep); + if (str_len < sep_len) { + return 0; + } + + for (i = 0; i <= str_len - sep_len; i++) { + if (memcmp(str_ptr + i, sep, sep_len) == 0) { + *begin = read_offset; + *end = read_offset + i; + return 1; + } + } + return 0; +} + +static void skip_leading_trailing_spaces(const uint8_t* str_buffer, + size_t* begin, size_t* end) { + while (*begin < *end && str_buffer[*begin] == ' ') { + (*begin)++; + } + while (*begin < *end && str_buffer[*end - 1] == ' ') { + (*end)--; + } +} + +static void grpc_slice_split_inner(grpc_slice str, const char* sep, + grpc_slice_buffer* dst, bool no_space) { + const size_t sep_len = strlen(sep); + size_t begin, end; + const uint8_t* str_buffer = GRPC_SLICE_START_PTR(str); + size_t sep_pos; + + GPR_ASSERT(sep_len > 0); + + if (slice_find_separator_offset(str, sep, 0, &begin, &end) != 0) { + do { + sep_pos = end; + if (no_space) { + skip_leading_trailing_spaces(str_buffer, &begin, &end); + } + grpc_slice_buffer_add_indexed(dst, grpc_slice_sub(str, begin, end)); + } while (slice_find_separator_offset(str, sep, sep_pos + sep_len, &begin, + &end) != 0); + begin = sep_pos + sep_len; + end = GRPC_SLICE_LENGTH(str); + if (no_space) { + skip_leading_trailing_spaces(str_buffer, &begin, &end); + } + grpc_slice_buffer_add_indexed(dst, grpc_slice_sub(str, begin, end)); + } else { /* no sep found, add whole input */ + begin = 0; + end = GRPC_SLICE_LENGTH(str); + if (no_space) { + skip_leading_trailing_spaces(str_buffer, &begin, &end); + } + grpc_slice_buffer_add_indexed(dst, grpc_slice_sub(str, begin, end)); + } +} + +void grpc_slice_split(grpc_slice str, const char* sep, grpc_slice_buffer* dst) { + grpc_slice_split_inner(str, sep, dst, false); +} + +void grpc_slice_split_without_space(grpc_slice str, const char* sep, + grpc_slice_buffer* dst) { + grpc_slice_split_inner(str, sep, dst, true); +} diff --git a/src/core/lib/slice/slice_string_helpers.cc b/src/core/lib/slice/slice_string_helpers.cc new file mode 100644 index 00000000..35744d0a --- /dev/null +++ b/src/core/lib/slice/slice_string_helpers.cc @@ -0,0 +1,44 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/slice/slice_string_helpers.h" + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/slice/slice_internal.h" + +char* grpc_dump_slice(const grpc_slice& s, uint32_t flags) { + return gpr_dump(reinterpret_cast GRPC_SLICE_START_PTR(s), + GRPC_SLICE_LENGTH(s), flags); +} + +grpc_slice grpc_dump_slice_to_slice(const grpc_slice& s, uint32_t flags) { + size_t len; + grpc_core::UniquePtr ptr( + gpr_dump_return_len(reinterpret_cast GRPC_SLICE_START_PTR(s), + GRPC_SLICE_LENGTH(s), flags, &len)); + return grpc_slice_from_moved_buffer(std::move(ptr), len); +} + +bool grpc_parse_slice_to_uint32(grpc_slice str, uint32_t* result) { + return gpr_parse_bytes_to_uint32( + reinterpret_cast GRPC_SLICE_START_PTR(str), + GRPC_SLICE_LENGTH(str), result) != 0; +} diff --git a/src/core/lib/slice/static_slice.cc b/src/core/lib/slice/static_slice.cc new file mode 100644 index 00000000..40681b11 --- /dev/null +++ b/src/core/lib/slice/static_slice.cc @@ -0,0 +1,529 @@ +/* + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * WARNING: Auto-generated code. + * + * To make changes to this file, change + * tools/codegen/core/gen_static_metadata.py, and then re-run it. + * + * See metadata.h for an explanation of the interface here, and metadata.cc for + * an explanation of what's going on. 
+ */ + +#include + +#include "src/core/lib/slice/static_slice.h" + +namespace grpc_core { +const uint8_t g_static_metadata_bytes[] = { + 58, 112, 97, 116, 104, 58, 109, 101, 116, 104, 111, 100, 58, 115, 116, + 97, 116, 117, 115, 58, 97, 117, 116, 104, 111, 114, 105, 116, 121, 58, + 115, 99, 104, 101, 109, 101, 103, 114, 112, 99, 45, 109, 101, 115, 115, + 97, 103, 101, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 103, + 114, 112, 99, 45, 112, 97, 121, 108, 111, 97, 100, 45, 98, 105, 110, + 103, 114, 112, 99, 45, 101, 110, 99, 111, 100, 105, 110, 103, 103, 114, + 112, 99, 45, 97, 99, 99, 101, 112, 116, 45, 101, 110, 99, 111, 100, + 105, 110, 103, 103, 114, 112, 99, 45, 115, 101, 114, 118, 101, 114, 45, + 115, 116, 97, 116, 115, 45, 98, 105, 110, 103, 114, 112, 99, 45, 116, + 97, 103, 115, 45, 98, 105, 110, 103, 114, 112, 99, 45, 116, 114, 97, + 99, 101, 45, 98, 105, 110, 99, 111, 110, 116, 101, 110, 116, 45, 116, + 121, 112, 101, 99, 111, 110, 116, 101, 110, 116, 45, 101, 110, 99, 111, + 100, 105, 110, 103, 97, 99, 99, 101, 112, 116, 45, 101, 110, 99, 111, + 100, 105, 110, 103, 103, 114, 112, 99, 45, 105, 110, 116, 101, 114, 110, + 97, 108, 45, 101, 110, 99, 111, 100, 105, 110, 103, 45, 114, 101, 113, + 117, 101, 115, 116, 103, 114, 112, 99, 45, 105, 110, 116, 101, 114, 110, + 97, 108, 45, 115, 116, 114, 101, 97, 109, 45, 101, 110, 99, 111, 100, + 105, 110, 103, 45, 114, 101, 113, 117, 101, 115, 116, 117, 115, 101, 114, + 45, 97, 103, 101, 110, 116, 104, 111, 115, 116, 103, 114, 112, 99, 45, + 112, 114, 101, 118, 105, 111, 117, 115, 45, 114, 112, 99, 45, 97, 116, + 116, 101, 109, 112, 116, 115, 103, 114, 112, 99, 45, 114, 101, 116, 114, + 121, 45, 112, 117, 115, 104, 98, 97, 99, 107, 45, 109, 115, 120, 45, + 101, 110, 100, 112, 111, 105, 110, 116, 45, 108, 111, 97, 100, 45, 109, + 101, 116, 114, 105, 99, 115, 45, 98, 105, 110, 103, 114, 112, 99, 45, + 116, 105, 109, 101, 111, 117, 116, 49, 50, 51, 52, 103, 114, 112, 99, + 46, 119, 97, 105, 116, 95, 102, 111, 114, 95, 114, 101, 97, 100, 121, + 103, 114, 112, 99, 46, 116, 105, 109, 101, 111, 117, 116, 103, 114, 112, + 99, 46, 109, 97, 120, 95, 114, 101, 113, 117, 101, 115, 116, 95, 109, + 101, 115, 115, 97, 103, 101, 95, 98, 121, 116, 101, 115, 103, 114, 112, + 99, 46, 109, 97, 120, 95, 114, 101, 115, 112, 111, 110, 115, 101, 95, + 109, 101, 115, 115, 97, 103, 101, 95, 98, 121, 116, 101, 115, 47, 103, + 114, 112, 99, 46, 108, 98, 46, 118, 49, 46, 76, 111, 97, 100, 66, + 97, 108, 97, 110, 99, 101, 114, 47, 66, 97, 108, 97, 110, 99, 101, + 76, 111, 97, 100, 47, 101, 110, 118, 111, 121, 46, 115, 101, 114, 118, + 105, 99, 101, 46, 108, 111, 97, 100, 95, 115, 116, 97, 116, 115, 46, + 118, 50, 46, 76, 111, 97, 100, 82, 101, 112, 111, 114, 116, 105, 110, + 103, 83, 101, 114, 118, 105, 99, 101, 47, 83, 116, 114, 101, 97, 109, + 76, 111, 97, 100, 83, 116, 97, 116, 115, 47, 101, 110, 118, 111, 121, + 46, 115, 101, 114, 118, 105, 99, 101, 46, 108, 111, 97, 100, 95, 115, + 116, 97, 116, 115, 46, 118, 51, 46, 76, 111, 97, 100, 82, 101, 112, + 111, 114, 116, 105, 110, 103, 83, 101, 114, 118, 105, 99, 101, 47, 83, + 116, 114, 101, 97, 109, 76, 111, 97, 100, 83, 116, 97, 116, 115, 47, + 103, 114, 112, 99, 46, 104, 101, 97, 108, 116, 104, 46, 118, 49, 46, + 72, 101, 97, 108, 116, 104, 47, 87, 97, 116, 99, 104, 47, 101, 110, + 118, 111, 121, 46, 115, 101, 114, 118, 105, 99, 101, 46, 100, 105, 115, + 99, 111, 118, 101, 114, 121, 46, 118, 50, 46, 65, 103, 103, 114, 101, + 103, 97, 116, 101, 100, 68, 105, 115, 99, 111, 118, 101, 114, 121, 83, + 101, 114, 
118, 105, 99, 101, 47, 83, 116, 114, 101, 97, 109, 65, 103, + 103, 114, 101, 103, 97, 116, 101, 100, 82, 101, 115, 111, 117, 114, 99, + 101, 115, 47, 101, 110, 118, 111, 121, 46, 115, 101, 114, 118, 105, 99, + 101, 46, 100, 105, 115, 99, 111, 118, 101, 114, 121, 46, 118, 51, 46, + 65, 103, 103, 114, 101, 103, 97, 116, 101, 100, 68, 105, 115, 99, 111, + 118, 101, 114, 121, 83, 101, 114, 118, 105, 99, 101, 47, 83, 116, 114, + 101, 97, 109, 65, 103, 103, 114, 101, 103, 97, 116, 101, 100, 82, 101, + 115, 111, 117, 114, 99, 101, 115, 100, 101, 102, 108, 97, 116, 101, 103, + 122, 105, 112, 115, 116, 114, 101, 97, 109, 47, 103, 122, 105, 112, 116, + 101, 116, 114, 97, 105, 108, 101, 114, 115, 71, 69, 84, 80, 79, 83, + 84, 47, 47, 105, 110, 100, 101, 120, 46, 104, 116, 109, 108, 104, 116, + 116, 112, 104, 116, 116, 112, 115, 50, 48, 48, 50, 48, 52, 50, 48, + 54, 51, 48, 52, 52, 48, 48, 52, 48, 52, 53, 48, 48, 97, 99, + 99, 101, 112, 116, 45, 99, 104, 97, 114, 115, 101, 116, 103, 122, 105, + 112, 44, 32, 100, 101, 102, 108, 97, 116, 101, 97, 99, 99, 101, 112, + 116, 45, 108, 97, 110, 103, 117, 97, 103, 101, 97, 99, 99, 101, 112, + 116, 45, 114, 97, 110, 103, 101, 115, 97, 99, 99, 101, 112, 116, 97, + 99, 99, 101, 115, 115, 45, 99, 111, 110, 116, 114, 111, 108, 45, 97, + 108, 108, 111, 119, 45, 111, 114, 105, 103, 105, 110, 97, 103, 101, 97, + 108, 108, 111, 119, 97, 117, 116, 104, 111, 114, 105, 122, 97, 116, 105, + 111, 110, 99, 97, 99, 104, 101, 45, 99, 111, 110, 116, 114, 111, 108, + 99, 111, 110, 116, 101, 110, 116, 45, 100, 105, 115, 112, 111, 115, 105, + 116, 105, 111, 110, 99, 111, 110, 116, 101, 110, 116, 45, 108, 97, 110, + 103, 117, 97, 103, 101, 99, 111, 110, 116, 101, 110, 116, 45, 108, 101, + 110, 103, 116, 104, 99, 111, 110, 116, 101, 110, 116, 45, 108, 111, 99, + 97, 116, 105, 111, 110, 99, 111, 110, 116, 101, 110, 116, 45, 114, 97, + 110, 103, 101, 99, 111, 111, 107, 105, 101, 100, 97, 116, 101, 101, 116, + 97, 103, 101, 120, 112, 101, 99, 116, 101, 120, 112, 105, 114, 101, 115, + 102, 114, 111, 109, 105, 102, 45, 109, 97, 116, 99, 104, 105, 102, 45, + 109, 111, 100, 105, 102, 105, 101, 100, 45, 115, 105, 110, 99, 101, 105, + 102, 45, 110, 111, 110, 101, 45, 109, 97, 116, 99, 104, 105, 102, 45, + 114, 97, 110, 103, 101, 105, 102, 45, 117, 110, 109, 111, 100, 105, 102, + 105, 101, 100, 45, 115, 105, 110, 99, 101, 108, 97, 115, 116, 45, 109, + 111, 100, 105, 102, 105, 101, 100, 108, 105, 110, 107, 108, 111, 99, 97, + 116, 105, 111, 110, 109, 97, 120, 45, 102, 111, 114, 119, 97, 114, 100, + 115, 112, 114, 111, 120, 121, 45, 97, 117, 116, 104, 101, 110, 116, 105, + 99, 97, 116, 101, 112, 114, 111, 120, 121, 45, 97, 117, 116, 104, 111, + 114, 105, 122, 97, 116, 105, 111, 110, 114, 97, 110, 103, 101, 114, 101, + 102, 101, 114, 101, 114, 114, 101, 102, 114, 101, 115, 104, 114, 101, 116, + 114, 121, 45, 97, 102, 116, 101, 114, 115, 101, 114, 118, 101, 114, 115, + 101, 116, 45, 99, 111, 111, 107, 105, 101, 115, 116, 114, 105, 99, 116, + 45, 116, 114, 97, 110, 115, 112, 111, 114, 116, 45, 115, 101, 99, 117, + 114, 105, 116, 121, 116, 114, 97, 110, 115, 102, 101, 114, 45, 101, 110, + 99, 111, 100, 105, 110, 103, 118, 97, 114, 121, 118, 105, 97, 119, 119, + 119, 45, 97, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 101, 48, + 105, 100, 101, 110, 116, 105, 116, 121, 97, 112, 112, 108, 105, 99, 97, + 116, 105, 111, 110, 47, 103, 114, 112, 99, 103, 114, 112, 99, 80, 85, + 84, 108, 98, 45, 99, 111, 115, 116, 45, 98, 105, 110, 105, 100, 101, + 110, 116, 105, 116, 121, 44, 100, 101, 102, 108, 97, 
116, 101, 105, 100, + 101, 110, 116, 105, 116, 121, 44, 103, 122, 105, 112, 100, 101, 102, 108, + 97, 116, 101, 44, 103, 122, 105, 112, 105, 100, 101, 110, 116, 105, 116, + 121, 44, 100, 101, 102, 108, 97, 116, 101, 44, 103, 122, 105, 112}; + +grpc_slice_refcount grpc_core::StaticSliceRefcount::kStaticSubRefcount; + +StaticSliceRefcount g_static_metadata_slice_refcounts[GRPC_STATIC_MDSTR_COUNT] = + { + + StaticSliceRefcount(0), StaticSliceRefcount(1), + StaticSliceRefcount(2), StaticSliceRefcount(3), + StaticSliceRefcount(4), StaticSliceRefcount(5), + StaticSliceRefcount(6), StaticSliceRefcount(7), + StaticSliceRefcount(8), StaticSliceRefcount(9), + StaticSliceRefcount(10), StaticSliceRefcount(11), + StaticSliceRefcount(12), StaticSliceRefcount(13), + StaticSliceRefcount(14), StaticSliceRefcount(15), + StaticSliceRefcount(16), StaticSliceRefcount(17), + StaticSliceRefcount(18), StaticSliceRefcount(19), + StaticSliceRefcount(20), StaticSliceRefcount(21), + StaticSliceRefcount(22), StaticSliceRefcount(23), + StaticSliceRefcount(24), StaticSliceRefcount(25), + StaticSliceRefcount(26), StaticSliceRefcount(27), + StaticSliceRefcount(28), StaticSliceRefcount(29), + StaticSliceRefcount(30), StaticSliceRefcount(31), + StaticSliceRefcount(32), StaticSliceRefcount(33), + StaticSliceRefcount(34), StaticSliceRefcount(35), + StaticSliceRefcount(36), StaticSliceRefcount(37), + StaticSliceRefcount(38), StaticSliceRefcount(39), + StaticSliceRefcount(40), StaticSliceRefcount(41), + StaticSliceRefcount(42), StaticSliceRefcount(43), + StaticSliceRefcount(44), StaticSliceRefcount(45), + StaticSliceRefcount(46), StaticSliceRefcount(47), + StaticSliceRefcount(48), StaticSliceRefcount(49), + StaticSliceRefcount(50), StaticSliceRefcount(51), + StaticSliceRefcount(52), StaticSliceRefcount(53), + StaticSliceRefcount(54), StaticSliceRefcount(55), + StaticSliceRefcount(56), StaticSliceRefcount(57), + StaticSliceRefcount(58), StaticSliceRefcount(59), + StaticSliceRefcount(60), StaticSliceRefcount(61), + StaticSliceRefcount(62), StaticSliceRefcount(63), + StaticSliceRefcount(64), StaticSliceRefcount(65), + StaticSliceRefcount(66), StaticSliceRefcount(67), + StaticSliceRefcount(68), StaticSliceRefcount(69), + StaticSliceRefcount(70), StaticSliceRefcount(71), + StaticSliceRefcount(72), StaticSliceRefcount(73), + StaticSliceRefcount(74), StaticSliceRefcount(75), + StaticSliceRefcount(76), StaticSliceRefcount(77), + StaticSliceRefcount(78), StaticSliceRefcount(79), + StaticSliceRefcount(80), StaticSliceRefcount(81), + StaticSliceRefcount(82), StaticSliceRefcount(83), + StaticSliceRefcount(84), StaticSliceRefcount(85), + StaticSliceRefcount(86), StaticSliceRefcount(87), + StaticSliceRefcount(88), StaticSliceRefcount(89), + StaticSliceRefcount(90), StaticSliceRefcount(91), + StaticSliceRefcount(92), StaticSliceRefcount(93), + StaticSliceRefcount(94), StaticSliceRefcount(95), + StaticSliceRefcount(96), StaticSliceRefcount(97), + StaticSliceRefcount(98), StaticSliceRefcount(99), + StaticSliceRefcount(100), StaticSliceRefcount(101), + StaticSliceRefcount(102), StaticSliceRefcount(103), + StaticSliceRefcount(104), StaticSliceRefcount(105), + StaticSliceRefcount(106), StaticSliceRefcount(107), + StaticSliceRefcount(108), StaticSliceRefcount(109), +}; + +const StaticMetadataSlice + g_static_metadata_slice_table[GRPC_STATIC_MDSTR_COUNT] = { + + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[0].base, 5, + g_static_metadata_bytes + 0), + grpc_core::StaticMetadataSlice( + 
&g_static_metadata_slice_refcounts[1].base, 7, + g_static_metadata_bytes + 5), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[3].base, 10, + g_static_metadata_bytes + 19), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[4].base, 7, + g_static_metadata_bytes + 29), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[5].base, 12, + g_static_metadata_bytes + 36), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[6].base, 11, + g_static_metadata_bytes + 48), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[7].base, 16, + g_static_metadata_bytes + 59), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[8].base, 13, + g_static_metadata_bytes + 75), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[10].base, 21, + g_static_metadata_bytes + 108), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[11].base, 13, + g_static_metadata_bytes + 129), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[12].base, 14, + g_static_metadata_bytes + 142), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[13].base, 12, + g_static_metadata_bytes + 156), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[14].base, 16, + g_static_metadata_bytes + 168), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[15].base, 15, + g_static_metadata_bytes + 184), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[16].base, 30, + g_static_metadata_bytes + 199), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[17].base, 37, + g_static_metadata_bytes + 229), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[18].base, 10, + g_static_metadata_bytes + 266), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[19].base, 4, + g_static_metadata_bytes + 276), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[20].base, 26, + g_static_metadata_bytes + 280), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[21].base, 22, + g_static_metadata_bytes + 306), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[22].base, 27, + g_static_metadata_bytes + 328), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[23].base, 12, + g_static_metadata_bytes + 355), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[24].base, 1, + g_static_metadata_bytes + 367), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[25].base, 1, + g_static_metadata_bytes + 368), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[26].base, 1, + g_static_metadata_bytes + 369), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[27].base, 1, + g_static_metadata_bytes + 370), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[29].base, 19, + g_static_metadata_bytes + 371), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[30].base, 12, + g_static_metadata_bytes + 390), + grpc_core::StaticMetadataSlice( + 
&g_static_metadata_slice_refcounts[31].base, 30, + g_static_metadata_bytes + 402), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[32].base, 31, + g_static_metadata_bytes + 432), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[33].base, 36, + g_static_metadata_bytes + 463), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[34].base, 65, + g_static_metadata_bytes + 499), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[35].base, 65, + g_static_metadata_bytes + 564), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[36].base, 28, + g_static_metadata_bytes + 629), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[37].base, 80, + g_static_metadata_bytes + 657), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[38].base, 80, + g_static_metadata_bytes + 737), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[39].base, 7, + g_static_metadata_bytes + 817), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[40].base, 4, + g_static_metadata_bytes + 824), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[41].base, 11, + g_static_metadata_bytes + 828), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[42].base, 2, + g_static_metadata_bytes + 839), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[43].base, 8, + g_static_metadata_bytes + 841), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[44].base, 3, + g_static_metadata_bytes + 849), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[45].base, 4, + g_static_metadata_bytes + 852), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[46].base, 1, + g_static_metadata_bytes + 856), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[47].base, 11, + g_static_metadata_bytes + 857), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[48].base, 4, + g_static_metadata_bytes + 868), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[49].base, 5, + g_static_metadata_bytes + 872), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[50].base, 3, + g_static_metadata_bytes + 877), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[51].base, 3, + g_static_metadata_bytes + 880), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[52].base, 3, + g_static_metadata_bytes + 883), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[53].base, 3, + g_static_metadata_bytes + 886), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[54].base, 3, + g_static_metadata_bytes + 889), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[55].base, 3, + g_static_metadata_bytes + 892), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[56].base, 3, + g_static_metadata_bytes + 895), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[57].base, 14, + g_static_metadata_bytes + 898), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[58].base, 13, + g_static_metadata_bytes + 912), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[59].base, 15, + g_static_metadata_bytes + 925), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[60].base, 13, + g_static_metadata_bytes + 940), + grpc_core::StaticMetadataSlice( + 
&g_static_metadata_slice_refcounts[61].base, 6, + g_static_metadata_bytes + 953), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[62].base, 27, + g_static_metadata_bytes + 959), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[63].base, 3, + g_static_metadata_bytes + 986), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[64].base, 5, + g_static_metadata_bytes + 989), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[65].base, 13, + g_static_metadata_bytes + 994), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[66].base, 13, + g_static_metadata_bytes + 1007), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[67].base, 19, + g_static_metadata_bytes + 1020), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[68].base, 16, + g_static_metadata_bytes + 1039), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[69].base, 14, + g_static_metadata_bytes + 1055), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[70].base, 16, + g_static_metadata_bytes + 1069), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[71].base, 13, + g_static_metadata_bytes + 1085), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[72].base, 6, + g_static_metadata_bytes + 1098), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[73].base, 4, + g_static_metadata_bytes + 1104), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[74].base, 4, + g_static_metadata_bytes + 1108), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[75].base, 6, + g_static_metadata_bytes + 1112), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[76].base, 7, + g_static_metadata_bytes + 1118), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[77].base, 4, + g_static_metadata_bytes + 1125), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[78].base, 8, + g_static_metadata_bytes + 1129), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[79].base, 17, + g_static_metadata_bytes + 1137), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[80].base, 13, + g_static_metadata_bytes + 1154), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[81].base, 8, + g_static_metadata_bytes + 1167), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[82].base, 19, + g_static_metadata_bytes + 1175), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[83].base, 13, + g_static_metadata_bytes + 1194), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[84].base, 4, + g_static_metadata_bytes + 1207), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[85].base, 8, + g_static_metadata_bytes + 1211), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[86].base, 12, + g_static_metadata_bytes + 1219), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[87].base, 18, + g_static_metadata_bytes + 1231), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[88].base, 19, + g_static_metadata_bytes + 1249), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[89].base, 5, + g_static_metadata_bytes + 1268), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[90].base, 7, + g_static_metadata_bytes + 1273), + 
grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[91].base, 7, + g_static_metadata_bytes + 1280), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[92].base, 11, + g_static_metadata_bytes + 1287), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[93].base, 6, + g_static_metadata_bytes + 1298), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[94].base, 10, + g_static_metadata_bytes + 1304), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[95].base, 25, + g_static_metadata_bytes + 1314), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[96].base, 17, + g_static_metadata_bytes + 1339), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[97].base, 4, + g_static_metadata_bytes + 1356), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[98].base, 3, + g_static_metadata_bytes + 1360), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[99].base, 16, + g_static_metadata_bytes + 1363), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[100].base, 1, + g_static_metadata_bytes + 1379), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[101].base, 8, + g_static_metadata_bytes + 1380), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[102].base, 16, + g_static_metadata_bytes + 1388), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[103].base, 4, + g_static_metadata_bytes + 1404), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[104].base, 3, + g_static_metadata_bytes + 1408), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[105].base, 11, + g_static_metadata_bytes + 1411), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[106].base, 16, + g_static_metadata_bytes + 1422), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[107].base, 13, + g_static_metadata_bytes + 1438), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[108].base, 12, + g_static_metadata_bytes + 1451), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[109].base, 21, + g_static_metadata_bytes + 1463), +}; +} // namespace grpc_core diff --git a/src/core/lib/surface/api_trace.cc b/src/core/lib/surface/api_trace.cc new file mode 100644 index 00000000..c40b9e1b --- /dev/null +++ b/src/core/lib/surface/api_trace.cc @@ -0,0 +1,25 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/api_trace.h" + +#include "src/core/lib/debug/trace.h" + +grpc_core::TraceFlag grpc_api_trace(false, "api"); diff --git a/src/core/lib/surface/builtins.cc b/src/core/lib/surface/builtins.cc new file mode 100644 index 00000000..b195ef9b --- /dev/null +++ b/src/core/lib/surface/builtins.cc @@ -0,0 +1,49 @@ +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/surface/builtins.h" + +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/lame_client.h" +#include "src/core/lib/surface/server.h" + +namespace grpc_core { + +void RegisterBuiltins(CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, + GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + grpc_add_connected_filter); + builder->channel_init()->RegisterStage(GRPC_CLIENT_DIRECT_CHANNEL, + GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + grpc_add_connected_filter); + builder->channel_init()->RegisterStage(GRPC_SERVER_CHANNEL, + GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + grpc_add_connected_filter); + builder->channel_init()->RegisterStage( + GRPC_CLIENT_LAME_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + [](grpc_channel_stack_builder* builder) { + return grpc_channel_stack_builder_append_filter( + builder, &grpc_lame_filter, nullptr, nullptr); + }); + builder->channel_init()->RegisterStage( + GRPC_SERVER_CHANNEL, INT_MAX, [](grpc_channel_stack_builder* builder) { + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_core::Server::kServerTopFilter, nullptr, nullptr); + }); +} + +} // namespace grpc_core diff --git a/src/core/lib/surface/byte_buffer.cc b/src/core/lib/surface/byte_buffer.cc new file mode 100644 index 00000000..6246796e --- /dev/null +++ b/src/core/lib/surface/byte_buffer.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" + +grpc_byte_buffer* grpc_raw_byte_buffer_create(grpc_slice* slices, + size_t nslices) { + return grpc_raw_compressed_byte_buffer_create(slices, nslices, + GRPC_COMPRESS_NONE); +} + +grpc_byte_buffer* grpc_raw_compressed_byte_buffer_create( + grpc_slice* slices, size_t nslices, + grpc_compression_algorithm compression) { + size_t i; + grpc_byte_buffer* bb = + static_cast(gpr_malloc(sizeof(grpc_byte_buffer))); + bb->type = GRPC_BB_RAW; + bb->data.raw.compression = compression; + grpc_slice_buffer_init(&bb->data.raw.slice_buffer); + for (i = 0; i < nslices; i++) { + grpc_slice_ref_internal(slices[i]); + grpc_slice_buffer_add(&bb->data.raw.slice_buffer, slices[i]); + } + return bb; +} + +grpc_byte_buffer* grpc_raw_byte_buffer_from_reader( + grpc_byte_buffer_reader* reader) { + grpc_byte_buffer* bb = + static_cast(gpr_malloc(sizeof(grpc_byte_buffer))); + grpc_slice slice; + bb->type = GRPC_BB_RAW; + bb->data.raw.compression = GRPC_COMPRESS_NONE; + grpc_slice_buffer_init(&bb->data.raw.slice_buffer); + + while (grpc_byte_buffer_reader_next(reader, &slice)) { + grpc_slice_buffer_add(&bb->data.raw.slice_buffer, slice); + } + return bb; +} + +grpc_byte_buffer* grpc_byte_buffer_copy(grpc_byte_buffer* bb) { + switch (bb->type) { + case GRPC_BB_RAW: + return grpc_raw_compressed_byte_buffer_create( + bb->data.raw.slice_buffer.slices, bb->data.raw.slice_buffer.count, + bb->data.raw.compression); + } + GPR_UNREACHABLE_CODE(return nullptr); +} + +void grpc_byte_buffer_destroy(grpc_byte_buffer* bb) { + if (!bb) return; + grpc_core::ExecCtx exec_ctx; + switch (bb->type) { + case GRPC_BB_RAW: + grpc_slice_buffer_destroy_internal(&bb->data.raw.slice_buffer); + break; + } + gpr_free(bb); +} + +size_t grpc_byte_buffer_length(grpc_byte_buffer* bb) { + switch (bb->type) { + case GRPC_BB_RAW: + return bb->data.raw.slice_buffer.length; + } + GPR_UNREACHABLE_CODE(return 0); +} diff --git a/src/core/lib/surface/byte_buffer_reader.cc b/src/core/lib/surface/byte_buffer_reader.cc new file mode 100644 index 00000000..207aaef8 --- /dev/null +++ b/src/core/lib/surface/byte_buffer_reader.cc @@ -0,0 +1,101 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" + +int grpc_byte_buffer_reader_init(grpc_byte_buffer_reader* reader, + grpc_byte_buffer* buffer) { + reader->buffer_in = buffer; + switch (reader->buffer_in->type) { + case GRPC_BB_RAW: + reader->buffer_out = reader->buffer_in; + reader->current.index = 0; + break; + } + return 1; +} + +void grpc_byte_buffer_reader_destroy(grpc_byte_buffer_reader* reader) { + reader->buffer_out = nullptr; +} + +int grpc_byte_buffer_reader_peek(grpc_byte_buffer_reader* reader, + grpc_slice** slice) { + switch (reader->buffer_in->type) { + case GRPC_BB_RAW: { + grpc_slice_buffer* slice_buffer; + slice_buffer = &reader->buffer_out->data.raw.slice_buffer; + if (reader->current.index < slice_buffer->count) { + *slice = &slice_buffer->slices[reader->current.index]; + reader->current.index += 1; + return 1; + } + break; + } + } + return 0; +} + +int grpc_byte_buffer_reader_next(grpc_byte_buffer_reader* reader, + grpc_slice* slice) { + switch (reader->buffer_in->type) { + case GRPC_BB_RAW: { + grpc_slice_buffer* slice_buffer; + slice_buffer = &reader->buffer_out->data.raw.slice_buffer; + if (reader->current.index < slice_buffer->count) { + *slice = grpc_slice_ref_internal( + slice_buffer->slices[reader->current.index]); + reader->current.index += 1; + return 1; + } + break; + } + } + return 0; +} + +grpc_slice grpc_byte_buffer_reader_readall(grpc_byte_buffer_reader* reader) { + grpc_slice in_slice; + size_t bytes_read = 0; + const size_t input_size = grpc_byte_buffer_length(reader->buffer_out); + grpc_slice out_slice = GRPC_SLICE_MALLOC(input_size); + uint8_t* const outbuf = GRPC_SLICE_START_PTR(out_slice); /* just an alias */ + + grpc_core::ExecCtx exec_ctx; + while (grpc_byte_buffer_reader_next(reader, &in_slice) != 0) { + const size_t slice_length = GRPC_SLICE_LENGTH(in_slice); + memcpy(&(outbuf[bytes_read]), GRPC_SLICE_START_PTR(in_slice), slice_length); + bytes_read += slice_length; + grpc_slice_unref_internal(in_slice); + GPR_ASSERT(bytes_read <= input_size); + } + + return out_slice; +} diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc new file mode 100644 index 00000000..2a4ec680 --- /dev/null +++ b/src/core/lib/surface/call.cc @@ -0,0 +1,2055 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/surface/call.h" + +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/compression/algorithm_metadata.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/alloc.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/time_precise.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/arena.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_split.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/slice/slice_utils.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call_test_only.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/surface/validate_metadata.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_metadata.h" +#include "src/core/lib/transport/transport.h" + +/** The maximum number of concurrent batches possible. + Based upon the maximum number of individually queueable ops in the batch + api: + - initial metadata send + - message send + - status/close send (depending on client/server) + - initial metadata recv + - message recv + - status/close recv (depending on client/server) */ +#define MAX_CONCURRENT_BATCHES 6 + +#define MAX_SEND_EXTRA_METADATA_COUNT 3 + +// Used to create arena for the first call. +#define ESTIMATED_MDELEM_COUNT 16 + +struct batch_control { + batch_control() = default; + + grpc_call* call = nullptr; + grpc_transport_stream_op_batch op; + /* Share memory for cq_completion and notify_tag as they are never needed + simultaneously. Each byte used in this data structure count as six bytes + per call, so any savings we can make are worthwhile, + + We use notify_tag to determine whether or not to send notification to the + completion queue. Once we've made that determination, we can reuse the + memory for cq_completion. */ + union { + grpc_cq_completion cq_completion; + struct { + /* Any given op indicates completion by either (a) calling a closure or + (b) sending a notification on the call's completion queue. If + \a is_closure is true, \a tag indicates a closure to be invoked; + otherwise, \a tag indicates the tag to be used in the notification to + be sent to the completion queue. 
*/ + void* tag; + bool is_closure; + } notify_tag; + } completion_data; + grpc_closure start_batch; + grpc_closure finish_batch; + std::atomic steps_to_complete{0}; + AtomicError batch_error; + void set_num_steps_to_complete(uintptr_t steps) { + steps_to_complete.store(steps, std::memory_order_release); + } + bool completed_batch_step() { + return steps_to_complete.fetch_sub(1, std::memory_order_acq_rel) == 1; + } +}; + +struct parent_call { + parent_call() { gpr_mu_init(&child_list_mu); } + ~parent_call() { gpr_mu_destroy(&child_list_mu); } + + gpr_mu child_list_mu; + grpc_call* first_child = nullptr; +}; + +struct child_call { + explicit child_call(grpc_call* parent) : parent(parent) {} + grpc_call* parent; + /** siblings: children of the same parent form a list, and this list is + protected under + parent->mu */ + grpc_call* sibling_next = nullptr; + grpc_call* sibling_prev = nullptr; +}; + +#define RECV_NONE ((gpr_atm)0) +#define RECV_INITIAL_METADATA_FIRST ((gpr_atm)1) + +struct grpc_call { + grpc_call(grpc_core::Arena* arena, const grpc_call_create_args& args) + : arena(arena), + cq(args.cq), + channel(args.channel), + is_client(args.server_transport_data == nullptr), + stream_op_payload(context) {} + + ~grpc_call() { + for (int i = 0; i < GRPC_CONTEXT_COUNT; ++i) { + if (context[i].destroy) { + context[i].destroy(context[i].value); + } + } + gpr_free(static_cast(const_cast(final_info.error_string))); + } + + grpc_core::RefCount ext_ref; + grpc_core::Arena* arena; + grpc_core::CallCombiner call_combiner; + grpc_completion_queue* cq; + grpc_polling_entity pollent; + grpc_channel* channel; + gpr_cycle_counter start_time = gpr_get_cycle_counter(); + /* parent_call* */ gpr_atm parent_call_atm = 0; + child_call* child = nullptr; + + /* client or server call */ + bool is_client; + /** has grpc_call_unref been called */ + bool destroy_called = false; + /** flag indicating that cancellation is inherited */ + bool cancellation_is_inherited = false; + // Trailers-only response status + bool is_trailers_only = false; + /** which ops are in-flight */ + bool sent_initial_metadata = false; + bool sending_message = false; + bool sent_final_op = false; + bool received_initial_metadata = false; + bool receiving_message = false; + bool requested_final_op = false; + gpr_atm any_ops_sent_atm = 0; + gpr_atm received_final_op_atm = 0; + + batch_control* active_batches[MAX_CONCURRENT_BATCHES] = {}; + grpc_transport_stream_op_batch_payload stream_op_payload; + + /* first idx: is_receiving, second idx: is_trailing */ + grpc_metadata_batch send_initial_metadata{arena}; + grpc_metadata_batch send_trailing_metadata{arena}; + grpc_metadata_batch recv_initial_metadata{arena}; + grpc_metadata_batch recv_trailing_metadata{arena}; + + /* Buffered read metadata waiting to be returned to the application. + Element 0 is initial metadata, element 1 is trailing metadata. */ + grpc_metadata_array* buffered_metadata[2] = {}; + + grpc_metadata compression_md; + + // A char* indicating the peer name. + gpr_atm peer_string = 0; + + /* Call data useful used for reporting. 
Only valid after the call has + * completed */ + grpc_call_final_info final_info; + + /* Compression algorithm for *incoming* data */ + grpc_message_compression_algorithm incoming_message_compression_algorithm = + GRPC_MESSAGE_COMPRESS_NONE; + /* Stream compression algorithm for *incoming* data */ + grpc_stream_compression_algorithm incoming_stream_compression_algorithm = + GRPC_STREAM_COMPRESS_NONE; + /* Supported encodings (compression algorithms), a bitset. + * Always support no compression. */ + uint32_t encodings_accepted_by_peer = 1 << GRPC_MESSAGE_COMPRESS_NONE; + /* Supported stream encodings (stream compression algorithms), a bitset */ + uint32_t stream_encodings_accepted_by_peer = 0; + + /* Contexts for various subsystems (security, tracing, ...). */ + grpc_call_context_element context[GRPC_CONTEXT_COUNT] = {}; + + /* for the client, extra metadata is initial metadata; for the + server, it's trailing metadata */ + grpc_linked_mdelem send_extra_metadata[MAX_SEND_EXTRA_METADATA_COUNT]; + int send_extra_metadata_count; + grpc_millis send_deadline; + + grpc_core::ManualConstructor sending_stream; + + grpc_core::OrphanablePtr receiving_stream; + bool call_failed_before_recv_message = false; + grpc_byte_buffer** receiving_buffer = nullptr; + grpc_slice receiving_slice = grpc_empty_slice(); + grpc_closure receiving_slice_ready; + grpc_closure receiving_stream_ready; + grpc_closure receiving_initial_metadata_ready; + grpc_closure receiving_trailing_metadata_ready; + uint32_t test_only_last_message_flags = 0; + // Status about operation of call + bool sent_server_trailing_metadata = false; + gpr_atm cancelled_with_error = 0; + + grpc_closure release_call; + + union { + struct { + grpc_status_code* status; + grpc_slice* status_details; + const char** error_string; + } client; + struct { + int* cancelled; + // backpointer to owning server if this is a server side call. 
+ grpc_core::Server* core_server; + } server; + } final_op; + AtomicError status_error; + + /* recv_state can contain one of the following values: + RECV_NONE : : no initial metadata and messages received + RECV_INITIAL_METADATA_FIRST : received initial metadata first + a batch_control* : received messages first + + +------1------RECV_NONE------3-----+ + | | + | | + v v + RECV_INITIAL_METADATA_FIRST receiving_stream_ready_bctlp + | ^ | ^ + | | | | + +-----2-----+ +-----4-----+ + + For 1, 4: See receiving_initial_metadata_ready() function + For 2, 3: See receiving_stream_ready() function */ + gpr_atm recv_state = 0; +}; + +grpc_core::TraceFlag grpc_call_error_trace(false, "call_error"); +grpc_core::TraceFlag grpc_compression_trace(false, "compression"); + +#define CALL_STACK_FROM_CALL(call) \ + (grpc_call_stack*)((char*)(call) + \ + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_call))) +#define CALL_FROM_CALL_STACK(call_stack) \ + (grpc_call*)(((char*)(call_stack)) - \ + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_call))) + +#define CALL_ELEM_FROM_CALL(call, idx) \ + grpc_call_stack_element(CALL_STACK_FROM_CALL(call), idx) +#define CALL_FROM_TOP_ELEM(top_elem) \ + CALL_FROM_CALL_STACK(grpc_call_stack_from_top_element(top_elem)) + +static void execute_batch(grpc_call* call, + grpc_transport_stream_op_batch* batch, + grpc_closure* start_batch_closure); + +static void cancel_with_status(grpc_call* c, grpc_status_code status, + const char* description); +static void cancel_with_error(grpc_call* c, grpc_error_handle error); +static void destroy_call(void* call_stack, grpc_error_handle error); +static void receiving_slice_ready(void* bctlp, grpc_error_handle error); +static void set_final_status(grpc_call* call, grpc_error_handle error); +static void process_data_after_md(batch_control* bctl); +static void post_batch_completion(batch_control* bctl); + +static void add_init_error(grpc_error_handle* composite, + grpc_error_handle new_err) { + if (new_err == GRPC_ERROR_NONE) return; + if (*composite == GRPC_ERROR_NONE) { + *composite = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Call creation failed"); + } + *composite = grpc_error_add_child(*composite, new_err); +} + +void* grpc_call_arena_alloc(grpc_call* call, size_t size) { + return call->arena->Alloc(size); +} + +static parent_call* get_or_create_parent_call(grpc_call* call) { + parent_call* p = + reinterpret_cast(gpr_atm_acq_load(&call->parent_call_atm)); + if (p == nullptr) { + p = call->arena->New(); + if (!gpr_atm_rel_cas(&call->parent_call_atm, + reinterpret_cast(nullptr), + reinterpret_cast(p))) { + p->~parent_call(); + p = reinterpret_cast( + gpr_atm_acq_load(&call->parent_call_atm)); + } + } + return p; +} + +static parent_call* get_parent_call(grpc_call* call) { + return reinterpret_cast( + gpr_atm_acq_load(&call->parent_call_atm)); +} + +size_t grpc_call_get_initial_size_estimate() { + return sizeof(grpc_call) + sizeof(batch_control) * MAX_CONCURRENT_BATCHES + + sizeof(grpc_linked_mdelem) * ESTIMATED_MDELEM_COUNT; +} + +grpc_error_handle grpc_call_create(const grpc_call_create_args* args, + grpc_call** out_call) { + GPR_TIMER_SCOPE("grpc_call_create", 0); + + GRPC_CHANNEL_INTERNAL_REF(args->channel, "call"); + + grpc_core::Arena* arena; + grpc_call* call; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_channel_stack* channel_stack = + grpc_channel_get_channel_stack(args->channel); + size_t initial_size = grpc_channel_get_call_size_estimate(args->channel); + GRPC_STATS_INC_CALL_INITIAL_SIZE(initial_size); + size_t call_and_stack_size = + 
GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(grpc_call)) + + channel_stack->call_stack_size; + size_t call_alloc_size = + call_and_stack_size + (args->parent ? sizeof(child_call) : 0); + + std::pair arena_with_call = + grpc_core::Arena::CreateWithAlloc(initial_size, call_alloc_size); + arena = arena_with_call.first; + call = new (arena_with_call.second) grpc_call(arena, *args); + *out_call = call; + grpc_slice path = grpc_empty_slice(); + if (call->is_client) { + call->final_op.client.status_details = nullptr; + call->final_op.client.status = nullptr; + call->final_op.client.error_string = nullptr; + GRPC_STATS_INC_CLIENT_CALLS_CREATED(); + GPR_ASSERT(args->add_initial_metadata_count < + MAX_SEND_EXTRA_METADATA_COUNT); + for (size_t i = 0; i < args->add_initial_metadata_count; i++) { + call->send_extra_metadata[i].md = args->add_initial_metadata[i]; + if (grpc_slice_eq_static_interned( + GRPC_MDKEY(args->add_initial_metadata[i]), GRPC_MDSTR_PATH)) { + path = grpc_slice_ref_internal( + GRPC_MDVALUE(args->add_initial_metadata[i])); + } + } + call->send_extra_metadata_count = + static_cast(args->add_initial_metadata_count); + } else { + GRPC_STATS_INC_SERVER_CALLS_CREATED(); + call->final_op.server.cancelled = nullptr; + call->final_op.server.core_server = args->server; + GPR_ASSERT(args->add_initial_metadata_count == 0); + call->send_extra_metadata_count = 0; + } + + grpc_millis send_deadline = args->send_deadline; + bool immediately_cancel = false; + + if (args->parent != nullptr) { + call->child = new (reinterpret_cast(arena_with_call.second) + + call_and_stack_size) child_call(args->parent); + + GRPC_CALL_INTERNAL_REF(args->parent, "child"); + GPR_ASSERT(call->is_client); + GPR_ASSERT(!args->parent->is_client); + + if (args->propagation_mask & GRPC_PROPAGATE_DEADLINE) { + send_deadline = std::min(send_deadline, args->parent->send_deadline); + } + /* for now GRPC_PROPAGATE_TRACING_CONTEXT *MUST* be passed with + * GRPC_PROPAGATE_STATS_CONTEXT */ + /* TODO(ctiller): This should change to use the appropriate census start_op + * call. */ + if (args->propagation_mask & GRPC_PROPAGATE_CENSUS_TRACING_CONTEXT) { + if (0 == (args->propagation_mask & GRPC_PROPAGATE_CENSUS_STATS_CONTEXT)) { + add_init_error(&error, GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Census tracing propagation requested " + "without Census context propagation")); + } + grpc_call_context_set(call, GRPC_CONTEXT_TRACING, + args->parent->context[GRPC_CONTEXT_TRACING].value, + nullptr); + } else if (args->propagation_mask & GRPC_PROPAGATE_CENSUS_STATS_CONTEXT) { + add_init_error(&error, GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Census context propagation requested " + "without Census tracing propagation")); + } + if (args->propagation_mask & GRPC_PROPAGATE_CANCELLATION) { + call->cancellation_is_inherited = true; + if (gpr_atm_acq_load(&args->parent->received_final_op_atm)) { + immediately_cancel = true; + } + } + } + call->send_deadline = send_deadline; + /* initial refcount dropped by grpc_call_unref */ + grpc_call_element_args call_args = {CALL_STACK_FROM_CALL(call), + args->server_transport_data, + call->context, + path, + call->start_time, + send_deadline, + call->arena, + &call->call_combiner}; + add_init_error(&error, grpc_call_stack_init(channel_stack, 1, destroy_call, + call, &call_args)); + // Publish this call to parent only after the call stack has been initialized. 
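  // Aside: an illustrative sketch (not code from this diff) of how such a
  // parent/child call is typically requested through the public C API, e.g.
  // by a server handler that fans out to a backend; the propagation_mask
  // handling above then carries the parent's deadline and cancellation into
  // the child. backend_channel, server_call, cq and deadline are assumed to
  // come from the surrounding application:
  //
  //   grpc_call* child = grpc_channel_create_call(
  //       backend_channel, /*parent_call=*/server_call,
  //       GRPC_PROPAGATE_DEFAULTS, cq,
  //       grpc_slice_from_static_string("/helloworld.Greeter/SayHello"),
  //       /*host=*/nullptr, deadline, /*reserved=*/nullptr);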
+ if (args->parent != nullptr) { + child_call* cc = call->child; + parent_call* pc = get_or_create_parent_call(args->parent); + gpr_mu_lock(&pc->child_list_mu); + if (pc->first_child == nullptr) { + pc->first_child = call; + cc->sibling_next = cc->sibling_prev = call; + } else { + cc->sibling_next = pc->first_child; + cc->sibling_prev = pc->first_child->child->sibling_prev; + cc->sibling_next->child->sibling_prev = + cc->sibling_prev->child->sibling_next = call; + } + gpr_mu_unlock(&pc->child_list_mu); + } + + if (error != GRPC_ERROR_NONE) { + cancel_with_error(call, GRPC_ERROR_REF(error)); + } + if (immediately_cancel) { + cancel_with_error(call, GRPC_ERROR_CANCELLED); + } + if (args->cq != nullptr) { + GPR_ASSERT(args->pollset_set_alternative == nullptr && + "Only one of 'cq' and 'pollset_set_alternative' should be " + "non-nullptr."); + GRPC_CQ_INTERNAL_REF(args->cq, "bind"); + call->pollent = + grpc_polling_entity_create_from_pollset(grpc_cq_pollset(args->cq)); + } + if (args->pollset_set_alternative != nullptr) { + call->pollent = grpc_polling_entity_create_from_pollset_set( + args->pollset_set_alternative); + } + if (!grpc_polling_entity_is_empty(&call->pollent)) { + grpc_call_stack_set_pollset_or_pollset_set(CALL_STACK_FROM_CALL(call), + &call->pollent); + } + + if (call->is_client) { + grpc_core::channelz::ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(call->channel); + if (channelz_channel != nullptr) { + channelz_channel->RecordCallStarted(); + } + } else if (call->final_op.server.core_server != nullptr) { + grpc_core::channelz::ServerNode* channelz_node = + call->final_op.server.core_server->channelz_node(); + if (channelz_node != nullptr) { + channelz_node->RecordCallStarted(); + } + } + + grpc_slice_unref_internal(path); + + return error; +} + +void grpc_call_set_completion_queue(grpc_call* call, + grpc_completion_queue* cq) { + GPR_ASSERT(cq); + + if (grpc_polling_entity_pollset_set(&call->pollent) != nullptr) { + gpr_log(GPR_ERROR, "A pollset_set is already registered for this call."); + abort(); + } + call->cq = cq; + GRPC_CQ_INTERNAL_REF(cq, "bind"); + call->pollent = grpc_polling_entity_create_from_pollset(grpc_cq_pollset(cq)); + grpc_call_stack_set_pollset_or_pollset_set(CALL_STACK_FROM_CALL(call), + &call->pollent); +} + +#ifndef NDEBUG +#define REF_REASON reason +#define REF_ARG , const char* reason +#else +#define REF_REASON "" +#define REF_ARG +#endif +void grpc_call_internal_ref(grpc_call* c REF_ARG) { + GRPC_CALL_STACK_REF(CALL_STACK_FROM_CALL(c), REF_REASON); +} +void grpc_call_internal_unref(grpc_call* c REF_ARG) { + GRPC_CALL_STACK_UNREF(CALL_STACK_FROM_CALL(c), REF_REASON); +} + +static void release_call(void* call, grpc_error_handle /*error*/) { + grpc_call* c = static_cast(call); + grpc_channel* channel = c->channel; + grpc_core::Arena* arena = c->arena; + c->~grpc_call(); + grpc_channel_update_call_size_estimate(channel, arena->Destroy()); + GRPC_CHANNEL_INTERNAL_UNREF(channel, "call"); +} + +static void destroy_call(void* call, grpc_error_handle /*error*/) { + GPR_TIMER_SCOPE("destroy_call", 0); + grpc_call* c = static_cast(call); + c->recv_initial_metadata.Clear(); + c->recv_trailing_metadata.Clear(); + c->receiving_stream.reset(); + parent_call* pc = get_parent_call(c); + if (pc != nullptr) { + pc->~parent_call(); + } + for (int i = 0; i < c->send_extra_metadata_count; i++) { + GRPC_MDELEM_UNREF(c->send_extra_metadata[i].md); + } + if (c->cq) { + GRPC_CQ_INTERNAL_UNREF(c->cq, "bind"); + } + + grpc_error_handle status_error = 
c->status_error.get(); + grpc_error_get_status(status_error, c->send_deadline, + &c->final_info.final_status, nullptr, nullptr, + &(c->final_info.error_string)); + c->status_error.set(GRPC_ERROR_NONE); + c->final_info.stats.latency = + gpr_cycle_counter_sub(gpr_get_cycle_counter(), c->start_time); + grpc_call_stack_destroy(CALL_STACK_FROM_CALL(c), &c->final_info, + GRPC_CLOSURE_INIT(&c->release_call, release_call, c, + grpc_schedule_on_exec_ctx)); +} + +void grpc_call_ref(grpc_call* c) { c->ext_ref.Ref(); } + +void grpc_call_unref(grpc_call* c) { + if (GPR_LIKELY(!c->ext_ref.Unref())) return; + + GPR_TIMER_SCOPE("grpc_call_unref", 0); + + child_call* cc = c->child; + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + + GRPC_API_TRACE("grpc_call_unref(c=%p)", 1, (c)); + + if (cc) { + parent_call* pc = get_parent_call(cc->parent); + gpr_mu_lock(&pc->child_list_mu); + if (c == pc->first_child) { + pc->first_child = cc->sibling_next; + if (c == pc->first_child) { + pc->first_child = nullptr; + } + } + cc->sibling_prev->child->sibling_next = cc->sibling_next; + cc->sibling_next->child->sibling_prev = cc->sibling_prev; + gpr_mu_unlock(&pc->child_list_mu); + GRPC_CALL_INTERNAL_UNREF(cc->parent, "child"); + } + + GPR_ASSERT(!c->destroy_called); + c->destroy_called = true; + bool cancel = gpr_atm_acq_load(&c->any_ops_sent_atm) != 0 && + gpr_atm_acq_load(&c->received_final_op_atm) == 0; + if (cancel) { + cancel_with_error(c, GRPC_ERROR_CANCELLED); + } else { + // Unset the call combiner cancellation closure. This has the + // effect of scheduling the previously set cancellation closure, if + // any, so that it can release any internal references it may be + // holding to the call stack. + c->call_combiner.SetNotifyOnCancel(nullptr); + } + GRPC_CALL_INTERNAL_UNREF(c, "destroy"); +} + +grpc_call_error grpc_call_cancel(grpc_call* call, void* reserved) { + GRPC_API_TRACE("grpc_call_cancel(call=%p, reserved=%p)", 2, (call, reserved)); + GPR_ASSERT(!reserved); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + cancel_with_error(call, GRPC_ERROR_CANCELLED); + return GRPC_CALL_OK; +} + +// This is called via the call combiner to start sending a batch down +// the filter stack. +static void execute_batch_in_call_combiner(void* arg, + grpc_error_handle /*ignored*/) { + GPR_TIMER_SCOPE("execute_batch_in_call_combiner", 0); + grpc_transport_stream_op_batch* batch = + static_cast(arg); + grpc_call* call = static_cast(batch->handler_private.extra_arg); + grpc_call_element* elem = CALL_ELEM_FROM_CALL(call, 0); + GRPC_CALL_LOG_OP(GPR_INFO, elem, batch); + elem->filter->start_transport_stream_op_batch(elem, batch); +} + +// start_batch_closure points to a caller-allocated closure to be used +// for entering the call combiner. 
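Aside: grpc_call_cancel() above and grpc_call_cancel_with_status() below are
the public cancellation entry points. A minimal, hypothetical caller (a
sketch, not code from this diff; example_abandon_call is an invented name)
looks roughly like:

static void example_abandon_call(grpc_call* call) {
  // Cancellation is thread-safe and may be issued at any point before
  // grpc_call_unref(); the call is then terminated with the given status
  // and description.
  grpc_call_error rc = grpc_call_cancel_with_status(
      call, GRPC_STATUS_CANCELLED, "caller gave up", /*reserved=*/nullptr);
  GPR_ASSERT(rc == GRPC_CALL_OK);
}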
+static void execute_batch(grpc_call* call, + grpc_transport_stream_op_batch* batch, + grpc_closure* start_batch_closure) { + batch->handler_private.extra_arg = call; + GRPC_CLOSURE_INIT(start_batch_closure, execute_batch_in_call_combiner, batch, + grpc_schedule_on_exec_ctx); + GRPC_CALL_COMBINER_START(&call->call_combiner, start_batch_closure, + GRPC_ERROR_NONE, "executing batch"); +} + +char* grpc_call_get_peer(grpc_call* call) { + char* peer_string = + reinterpret_cast(gpr_atm_acq_load(&call->peer_string)); + if (peer_string != nullptr) return gpr_strdup(peer_string); + peer_string = grpc_channel_get_target(call->channel); + if (peer_string != nullptr) return peer_string; + return gpr_strdup("unknown"); +} + +grpc_call* grpc_call_from_top_element(grpc_call_element* surface_element) { + return CALL_FROM_TOP_ELEM(surface_element); +} + +/******************************************************************************* + * CANCELLATION + */ + +grpc_call_error grpc_call_cancel_with_status(grpc_call* c, + grpc_status_code status, + const char* description, + void* reserved) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE( + "grpc_call_cancel_with_status(" + "c=%p, status=%d, description=%s, reserved=%p)", + 4, (c, (int)status, description, reserved)); + GPR_ASSERT(reserved == nullptr); + cancel_with_status(c, status, description); + return GRPC_CALL_OK; +} + +struct cancel_state { + grpc_call* call; + grpc_closure start_batch; + grpc_closure finish_batch; +}; +// The on_complete callback used when sending a cancel_stream batch down +// the filter stack. Yields the call combiner when the batch is done. +static void done_termination(void* arg, grpc_error_handle /*error*/) { + cancel_state* state = static_cast(arg); + GRPC_CALL_COMBINER_STOP(&state->call->call_combiner, + "on_complete for cancel_stream op"); + GRPC_CALL_INTERNAL_UNREF(state->call, "termination"); + gpr_free(state); +} + +static void cancel_with_error(grpc_call* c, grpc_error_handle error) { + if (!gpr_atm_rel_cas(&c->cancelled_with_error, 0, 1)) { + GRPC_ERROR_UNREF(error); + return; + } + GRPC_CALL_INTERNAL_REF(c, "termination"); + // Inform the call combiner of the cancellation, so that it can cancel + // any in-flight asynchronous actions that may be holding the call + // combiner. This ensures that the cancel_stream batch can be sent + // down the filter stack in a timely manner. + c->call_combiner.Cancel(GRPC_ERROR_REF(error)); + cancel_state* state = static_cast(gpr_malloc(sizeof(*state))); + state->call = c; + GRPC_CLOSURE_INIT(&state->finish_batch, done_termination, state, + grpc_schedule_on_exec_ctx); + grpc_transport_stream_op_batch* op = + grpc_make_transport_stream_op(&state->finish_batch); + op->cancel_stream = true; + op->payload->cancel_stream.cancel_error = error; + execute_batch(c, op, &state->start_batch); +} + +void grpc_call_cancel_internal(grpc_call* call) { + cancel_with_error(call, GRPC_ERROR_CANCELLED); +} + +static grpc_error_handle error_from_status(grpc_status_code status, + const char* description) { + // copying 'description' is needed to ensure the grpc_call_cancel_with_status + // guarantee that can be short-lived. 
+ return grpc_error_set_int( + grpc_error_set_str(GRPC_ERROR_CREATE_FROM_COPIED_STRING(description), + GRPC_ERROR_STR_GRPC_MESSAGE, description), + GRPC_ERROR_INT_GRPC_STATUS, status); +} + +static void cancel_with_status(grpc_call* c, grpc_status_code status, + const char* description) { + cancel_with_error(c, error_from_status(status, description)); +} + +static void set_final_status(grpc_call* call, grpc_error_handle error) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_call_error_trace)) { + gpr_log(GPR_DEBUG, "set_final_status %s", call->is_client ? "CLI" : "SVR"); + gpr_log(GPR_DEBUG, "%s", grpc_error_std_string(error).c_str()); + } + if (call->is_client) { + std::string status_details; + grpc_error_get_status(error, call->send_deadline, + call->final_op.client.status, &status_details, + nullptr, call->final_op.client.error_string); + *call->final_op.client.status_details = + grpc_slice_from_cpp_string(std::move(status_details)); + call->status_error.set(error); + GRPC_ERROR_UNREF(error); + grpc_core::channelz::ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(call->channel); + if (channelz_channel != nullptr) { + if (*call->final_op.client.status != GRPC_STATUS_OK) { + channelz_channel->RecordCallFailed(); + } else { + channelz_channel->RecordCallSucceeded(); + } + } + } else { + *call->final_op.server.cancelled = + error != GRPC_ERROR_NONE || !call->sent_server_trailing_metadata; + grpc_core::channelz::ServerNode* channelz_node = + call->final_op.server.core_server->channelz_node(); + if (channelz_node != nullptr) { + if (*call->final_op.server.cancelled || !call->status_error.ok()) { + channelz_node->RecordCallFailed(); + } else { + channelz_node->RecordCallSucceeded(); + } + } + GRPC_ERROR_UNREF(error); + } +} + +/******************************************************************************* + * COMPRESSION + */ + +static void set_incoming_message_compression_algorithm( + grpc_call* call, grpc_message_compression_algorithm algo) { + GPR_ASSERT(algo < GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT); + call->incoming_message_compression_algorithm = algo; +} + +static void set_incoming_stream_compression_algorithm( + grpc_call* call, grpc_stream_compression_algorithm algo) { + GPR_ASSERT(algo < GRPC_STREAM_COMPRESS_ALGORITHMS_COUNT); + call->incoming_stream_compression_algorithm = algo; +} + +grpc_compression_algorithm grpc_call_test_only_get_compression_algorithm( + grpc_call* call) { + grpc_compression_algorithm algorithm = GRPC_COMPRESS_NONE; + grpc_compression_algorithm_from_message_stream_compression_algorithm( + &algorithm, call->incoming_message_compression_algorithm, + call->incoming_stream_compression_algorithm); + return algorithm; +} + +static grpc_compression_algorithm compression_algorithm_for_level_locked( + grpc_call* call, grpc_compression_level level) { + return grpc_compression_algorithm_for_level(level, + call->encodings_accepted_by_peer); +} + +uint32_t grpc_call_test_only_get_message_flags(grpc_call* call) { + uint32_t flags; + flags = call->test_only_last_message_flags; + return flags; +} + +static void destroy_encodings_accepted_by_peer(void* /*p*/) {} + +static void set_encodings_accepted_by_peer(grpc_call* /*call*/, + grpc_mdelem mdel, + uint32_t* encodings_accepted_by_peer, + bool stream_encoding) { + size_t i; + uint32_t algorithm; + grpc_slice_buffer accept_encoding_parts; + grpc_slice accept_encoding_slice; + void* accepted_user_data; + + accepted_user_data = + grpc_mdelem_get_user_data(mdel, destroy_encodings_accepted_by_peer); + if (accepted_user_data != 
nullptr) { + *encodings_accepted_by_peer = static_cast( + reinterpret_cast(accepted_user_data) - 1); + return; + } + + *encodings_accepted_by_peer = 0; + + accept_encoding_slice = GRPC_MDVALUE(mdel); + grpc_slice_buffer_init(&accept_encoding_parts); + grpc_slice_split_without_space(accept_encoding_slice, ",", + &accept_encoding_parts); + + grpc_core::SetBit(encodings_accepted_by_peer, GRPC_COMPRESS_NONE); + for (i = 0; i < accept_encoding_parts.count; i++) { + int r; + grpc_slice accept_encoding_entry_slice = accept_encoding_parts.slices[i]; + if (!stream_encoding) { + r = grpc_message_compression_algorithm_parse( + accept_encoding_entry_slice, + reinterpret_cast(&algorithm)); + } else { + r = grpc_stream_compression_algorithm_parse( + accept_encoding_entry_slice, + reinterpret_cast(&algorithm)); + } + if (r) { + grpc_core::SetBit(encodings_accepted_by_peer, algorithm); + } else { + char* accept_encoding_entry_str = + grpc_slice_to_c_string(accept_encoding_entry_slice); + gpr_log(GPR_DEBUG, + "Unknown entry in accept encoding metadata: '%s'. Ignoring.", + accept_encoding_entry_str); + gpr_free(accept_encoding_entry_str); + } + } + + grpc_slice_buffer_destroy_internal(&accept_encoding_parts); + + grpc_mdelem_set_user_data( + mdel, destroy_encodings_accepted_by_peer, + reinterpret_cast( + static_cast(*encodings_accepted_by_peer) + 1)); +} + +uint32_t grpc_call_test_only_get_encodings_accepted_by_peer(grpc_call* call) { + uint32_t encodings_accepted_by_peer; + encodings_accepted_by_peer = call->encodings_accepted_by_peer; + return encodings_accepted_by_peer; +} + +grpc_stream_compression_algorithm +grpc_call_test_only_get_incoming_stream_encodings(grpc_call* call) { + return call->incoming_stream_compression_algorithm; +} + +static grpc_linked_mdelem* linked_from_md(grpc_metadata* md) { + return reinterpret_cast(&md->internal_data); +} + +static grpc_metadata* get_md_elem(grpc_metadata* metadata, + grpc_metadata* additional_metadata, int i, + int count) { + grpc_metadata* res = + i < count ? &metadata[i] : &additional_metadata[i - count]; + GPR_ASSERT(res); + return res; +} + +static int prepare_application_metadata(grpc_call* call, int count, + grpc_metadata* metadata, + int is_trailing, + int prepend_extra_metadata, + grpc_metadata* additional_metadata, + int additional_metadata_count) { + int total_count = count + additional_metadata_count; + int i; + grpc_metadata_batch* batch = is_trailing ? &call->send_trailing_metadata + : &call->send_initial_metadata; + for (i = 0; i < total_count; i++) { + grpc_metadata* md = get_md_elem(metadata, additional_metadata, i, count); + grpc_linked_mdelem* l = linked_from_md(md); + GPR_ASSERT(sizeof(grpc_linked_mdelem) == sizeof(md->internal_data)); + if (!GRPC_LOG_IF_ERROR("validate_metadata", + grpc_validate_header_key_is_legal(md->key))) { + break; + } else if (!grpc_is_binary_header_internal(md->key) && + !GRPC_LOG_IF_ERROR( + "validate_metadata", + grpc_validate_header_nonbin_value_is_legal(md->value))) { + break; + } else if (GRPC_SLICE_LENGTH(md->value) >= UINT32_MAX) { + // HTTP2 hpack encoding has a maximum limit. 
+ break; + } + l->md = grpc_mdelem_from_grpc_metadata(const_cast(md)); + } + if (i != total_count) { + for (int j = 0; j < i; j++) { + grpc_metadata* md = get_md_elem(metadata, additional_metadata, j, count); + grpc_linked_mdelem* l = linked_from_md(md); + GRPC_MDELEM_UNREF(l->md); + } + return 0; + } + if (prepend_extra_metadata) { + if (call->send_extra_metadata_count == 0) { + prepend_extra_metadata = 0; + } else { + for (i = 0; i < call->send_extra_metadata_count; i++) { + GRPC_LOG_IF_ERROR("prepare_application_metadata", + batch->LinkTail(&call->send_extra_metadata[i])); + } + } + } + for (i = 0; i < total_count; i++) { + grpc_metadata* md = get_md_elem(metadata, additional_metadata, i, count); + grpc_linked_mdelem* l = linked_from_md(md); + grpc_error_handle error = batch->LinkTail(l); + if (error != GRPC_ERROR_NONE) { + GRPC_MDELEM_UNREF(l->md); + } + GRPC_LOG_IF_ERROR("prepare_application_metadata", error); + } + call->send_extra_metadata_count = 0; + + return 1; +} + +static grpc_message_compression_algorithm decode_message_compression( + grpc_mdelem md) { + grpc_message_compression_algorithm algorithm = + grpc_message_compression_algorithm_from_slice(GRPC_MDVALUE(md)); + if (algorithm == GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) { + char* md_c_str = grpc_slice_to_c_string(GRPC_MDVALUE(md)); + gpr_log(GPR_ERROR, + "Invalid incoming message compression algorithm: '%s'. " + "Interpreting incoming data as uncompressed.", + md_c_str); + gpr_free(md_c_str); + return GRPC_MESSAGE_COMPRESS_NONE; + } + return algorithm; +} + +static grpc_stream_compression_algorithm decode_stream_compression( + grpc_mdelem md) { + grpc_stream_compression_algorithm algorithm = + grpc_stream_compression_algorithm_from_slice(GRPC_MDVALUE(md)); + if (algorithm == GRPC_STREAM_COMPRESS_ALGORITHMS_COUNT) { + char* md_c_str = grpc_slice_to_c_string(GRPC_MDVALUE(md)); + gpr_log(GPR_ERROR, + "Invalid incoming stream compression algorithm: '%s'. 
Interpreting " + "incoming data as uncompressed.", + md_c_str); + gpr_free(md_c_str); + return GRPC_STREAM_COMPRESS_NONE; + } + return algorithm; +} + +static void publish_app_metadata(grpc_call* call, grpc_metadata_batch* b, + int is_trailing) { + if (b->non_deadline_count() == 0) return; + if (!call->is_client && is_trailing) return; + if (is_trailing && call->buffered_metadata[1] == nullptr) return; + GPR_TIMER_SCOPE("publish_app_metadata", 0); + grpc_metadata_array* dest; + grpc_metadata* mdusr; + dest = call->buffered_metadata[is_trailing]; + if (dest->count + b->non_deadline_count() > dest->capacity) { + dest->capacity = std::max(dest->capacity + b->non_deadline_count(), + dest->capacity * 3 / 2); + dest->metadata = static_cast( + gpr_realloc(dest->metadata, sizeof(grpc_metadata) * dest->capacity)); + } + b->ForEach([&](grpc_mdelem md) { + mdusr = &dest->metadata[dest->count++]; + /* we pass back borrowed slices that are valid whilst the call is valid */ + mdusr->key = GRPC_MDKEY(md); + mdusr->value = GRPC_MDVALUE(md); + }); +} + +static void recv_initial_filter(grpc_call* call, grpc_metadata_batch* b) { + if (b->legacy_index()->named.content_encoding != nullptr) { + GPR_TIMER_SCOPE("incoming_stream_compression_algorithm", 0); + set_incoming_stream_compression_algorithm( + call, decode_stream_compression( + b->legacy_index()->named.content_encoding->md)); + b->Remove(GRPC_BATCH_CONTENT_ENCODING); + } + if (b->legacy_index()->named.grpc_encoding != nullptr) { + GPR_TIMER_SCOPE("incoming_message_compression_algorithm", 0); + set_incoming_message_compression_algorithm( + call, + decode_message_compression(b->legacy_index()->named.grpc_encoding->md)); + b->Remove(GRPC_BATCH_GRPC_ENCODING); + } + uint32_t message_encodings_accepted_by_peer = 1u; + uint32_t stream_encodings_accepted_by_peer = 1u; + if (b->legacy_index()->named.grpc_accept_encoding != nullptr) { + GPR_TIMER_SCOPE("encodings_accepted_by_peer", 0); + set_encodings_accepted_by_peer( + call, b->legacy_index()->named.grpc_accept_encoding->md, + &message_encodings_accepted_by_peer, false); + b->Remove(GRPC_BATCH_GRPC_ACCEPT_ENCODING); + } + if (b->legacy_index()->named.accept_encoding != nullptr) { + GPR_TIMER_SCOPE("stream_encodings_accepted_by_peer", 0); + set_encodings_accepted_by_peer(call, + b->legacy_index()->named.accept_encoding->md, + &stream_encodings_accepted_by_peer, true); + b->Remove(GRPC_BATCH_ACCEPT_ENCODING); + } + call->encodings_accepted_by_peer = + grpc_compression_bitset_from_message_stream_compression_bitset( + message_encodings_accepted_by_peer, + stream_encodings_accepted_by_peer); + publish_app_metadata(call, b, false); +} + +static void recv_trailing_filter(void* args, grpc_metadata_batch* b, + grpc_error_handle batch_error) { + grpc_call* call = static_cast(args); + if (batch_error != GRPC_ERROR_NONE) { + set_final_status(call, batch_error); + } else if (b->legacy_index()->named.grpc_status != nullptr) { + grpc_status_code status_code = grpc_get_status_code_from_metadata( + b->legacy_index()->named.grpc_status->md); + grpc_error_handle error = GRPC_ERROR_NONE; + if (status_code != GRPC_STATUS_OK) { + char* peer = grpc_call_get_peer(call); + error = grpc_error_set_int(GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "Error received from peer ", peer)), + GRPC_ERROR_INT_GRPC_STATUS, + static_cast(status_code)); + gpr_free(peer); + } + if (b->legacy_index()->named.grpc_message != nullptr) { + error = grpc_error_set_str( + error, GRPC_ERROR_STR_GRPC_MESSAGE, + grpc_core::StringViewFromSlice( + 
GRPC_MDVALUE(b->legacy_index()->named.grpc_message->md))); + b->Remove(GRPC_BATCH_GRPC_MESSAGE); + } else if (error != GRPC_ERROR_NONE) { + error = grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, ""); + } + set_final_status(call, GRPC_ERROR_REF(error)); + b->Remove(GRPC_BATCH_GRPC_STATUS); + GRPC_ERROR_UNREF(error); + } else if (!call->is_client) { + set_final_status(call, GRPC_ERROR_NONE); + } else { + gpr_log(GPR_DEBUG, + "Received trailing metadata with no error and no status"); + set_final_status( + call, grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("No status received"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNKNOWN)); + } + publish_app_metadata(call, b, true); +} + +grpc_core::Arena* grpc_call_get_arena(grpc_call* call) { return call->arena; } + +grpc_call_stack* grpc_call_get_call_stack(grpc_call* call) { + return CALL_STACK_FROM_CALL(call); +} + +/******************************************************************************* + * BATCH API IMPLEMENTATION + */ + +static bool are_write_flags_valid(uint32_t flags) { + /* check that only bits in GRPC_WRITE_(INTERNAL?)_USED_MASK are set */ + const uint32_t allowed_write_positions = + (GRPC_WRITE_USED_MASK | GRPC_WRITE_INTERNAL_USED_MASK); + const uint32_t invalid_positions = ~allowed_write_positions; + return !(flags & invalid_positions); +} + +static bool are_initial_metadata_flags_valid(uint32_t flags, bool is_client) { + /* check that only bits in GRPC_WRITE_(INTERNAL?)_USED_MASK are set */ + uint32_t invalid_positions = ~GRPC_INITIAL_METADATA_USED_MASK; + if (!is_client) { + invalid_positions |= GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST; + } + return !(flags & invalid_positions); +} + +static size_t batch_slot_for_op(grpc_op_type type) { + switch (type) { + case GRPC_OP_SEND_INITIAL_METADATA: + return 0; + case GRPC_OP_SEND_MESSAGE: + return 1; + case GRPC_OP_SEND_CLOSE_FROM_CLIENT: + case GRPC_OP_SEND_STATUS_FROM_SERVER: + return 2; + case GRPC_OP_RECV_INITIAL_METADATA: + return 3; + case GRPC_OP_RECV_MESSAGE: + return 4; + case GRPC_OP_RECV_CLOSE_ON_SERVER: + case GRPC_OP_RECV_STATUS_ON_CLIENT: + return 5; + } + GPR_UNREACHABLE_CODE(return 123456789); +} + +static batch_control* reuse_or_allocate_batch_control(grpc_call* call, + const grpc_op* ops) { + size_t slot_idx = batch_slot_for_op(ops[0].op); + batch_control** pslot = &call->active_batches[slot_idx]; + batch_control* bctl; + if (*pslot != nullptr) { + bctl = *pslot; + if (bctl->call != nullptr) { + return nullptr; + } + bctl->~batch_control(); + bctl->op = {}; + new (&bctl->batch_error) AtomicError(); + } else { + bctl = call->arena->New(); + *pslot = bctl; + } + bctl->call = call; + bctl->op.payload = &call->stream_op_payload; + return bctl; +} + +static void finish_batch_completion(void* user_data, + grpc_cq_completion* /*storage*/) { + batch_control* bctl = static_cast(user_data); + grpc_call* call = bctl->call; + bctl->call = nullptr; + GRPC_CALL_INTERNAL_UNREF(call, "completion"); +} + +static void reset_batch_errors(batch_control* bctl) { + bctl->batch_error.set(GRPC_ERROR_NONE); +} + +static void post_batch_completion(batch_control* bctl) { + grpc_call* next_child_call; + grpc_call* call = bctl->call; + grpc_error_handle error = GRPC_ERROR_REF(bctl->batch_error.get()); + + if (bctl->op.send_initial_metadata) { + call->send_initial_metadata.Clear(); + } + if (bctl->op.send_message) { + if (bctl->op.payload->send_message.stream_write_closed && + error == GRPC_ERROR_NONE) { + error = grpc_error_add_child( + error, 
GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Attempt to send message after stream was closed.")); + } + call->sending_message = false; + } + if (bctl->op.send_trailing_metadata) { + call->send_trailing_metadata.Clear(); + } + if (bctl->op.recv_trailing_metadata) { + /* propagate cancellation to any interested children */ + gpr_atm_rel_store(&call->received_final_op_atm, 1); + parent_call* pc = get_parent_call(call); + if (pc != nullptr) { + grpc_call* child; + gpr_mu_lock(&pc->child_list_mu); + child = pc->first_child; + if (child != nullptr) { + do { + next_child_call = child->child->sibling_next; + if (child->cancellation_is_inherited) { + GRPC_CALL_INTERNAL_REF(child, "propagate_cancel"); + cancel_with_error(child, GRPC_ERROR_CANCELLED); + GRPC_CALL_INTERNAL_UNREF(child, "propagate_cancel"); + } + child = next_child_call; + } while (child != pc->first_child); + } + gpr_mu_unlock(&pc->child_list_mu); + } + GRPC_ERROR_UNREF(error); + error = GRPC_ERROR_NONE; + } + if (error != GRPC_ERROR_NONE && bctl->op.recv_message && + *call->receiving_buffer != nullptr) { + grpc_byte_buffer_destroy(*call->receiving_buffer); + *call->receiving_buffer = nullptr; + } + reset_batch_errors(bctl); + + if (bctl->completion_data.notify_tag.is_closure) { + /* unrefs error */ + bctl->call = nullptr; + grpc_core::Closure::Run( + DEBUG_LOCATION, + static_cast(bctl->completion_data.notify_tag.tag), + error); + GRPC_CALL_INTERNAL_UNREF(call, "completion"); + } else { + /* unrefs error */ + grpc_cq_end_op(bctl->call->cq, bctl->completion_data.notify_tag.tag, error, + finish_batch_completion, bctl, + &bctl->completion_data.cq_completion); + } +} + +static void finish_batch_step(batch_control* bctl) { + if (GPR_UNLIKELY(bctl->completed_batch_step())) { + post_batch_completion(bctl); + } +} + +static void continue_receiving_slices(batch_control* bctl) { + grpc_error_handle error; + grpc_call* call = bctl->call; + for (;;) { + size_t remaining = call->receiving_stream->length() - + (*call->receiving_buffer)->data.raw.slice_buffer.length; + if (remaining == 0) { + call->receiving_message = false; + call->receiving_stream.reset(); + finish_batch_step(bctl); + return; + } + if (call->receiving_stream->Next(remaining, &call->receiving_slice_ready)) { + error = call->receiving_stream->Pull(&call->receiving_slice); + if (error == GRPC_ERROR_NONE) { + grpc_slice_buffer_add(&(*call->receiving_buffer)->data.raw.slice_buffer, + call->receiving_slice); + } else { + call->receiving_stream.reset(); + grpc_byte_buffer_destroy(*call->receiving_buffer); + *call->receiving_buffer = nullptr; + call->receiving_message = false; + finish_batch_step(bctl); + GRPC_ERROR_UNREF(error); + return; + } + } else { + return; + } + } +} + +static void receiving_slice_ready(void* bctlp, grpc_error_handle error) { + batch_control* bctl = static_cast(bctlp); + grpc_call* call = bctl->call; + bool release_error = false; + + if (error == GRPC_ERROR_NONE) { + grpc_slice slice; + error = call->receiving_stream->Pull(&slice); + if (error == GRPC_ERROR_NONE) { + grpc_slice_buffer_add(&(*call->receiving_buffer)->data.raw.slice_buffer, + slice); + continue_receiving_slices(bctl); + } else { + /* Error returned by ByteStream::Pull() needs to be released manually */ + release_error = true; + } + } + + if (error != GRPC_ERROR_NONE) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures)) { + GRPC_LOG_IF_ERROR("receiving_slice_ready", GRPC_ERROR_REF(error)); + } + call->receiving_stream.reset(); + grpc_byte_buffer_destroy(*call->receiving_buffer); + 
*call->receiving_buffer = nullptr; + call->receiving_message = false; + finish_batch_step(bctl); + if (release_error) { + GRPC_ERROR_UNREF(error); + } + } +} + +static void process_data_after_md(batch_control* bctl) { + grpc_call* call = bctl->call; + if (call->receiving_stream == nullptr) { + *call->receiving_buffer = nullptr; + call->receiving_message = false; + finish_batch_step(bctl); + } else { + call->test_only_last_message_flags = call->receiving_stream->flags(); + if ((call->receiving_stream->flags() & GRPC_WRITE_INTERNAL_COMPRESS) && + (call->incoming_message_compression_algorithm > + GRPC_MESSAGE_COMPRESS_NONE)) { + grpc_compression_algorithm algo; + GPR_ASSERT( + grpc_compression_algorithm_from_message_stream_compression_algorithm( + &algo, call->incoming_message_compression_algorithm, + (grpc_stream_compression_algorithm)0)); + *call->receiving_buffer = + grpc_raw_compressed_byte_buffer_create(nullptr, 0, algo); + } else { + *call->receiving_buffer = grpc_raw_byte_buffer_create(nullptr, 0); + } + GRPC_CLOSURE_INIT(&call->receiving_slice_ready, receiving_slice_ready, bctl, + grpc_schedule_on_exec_ctx); + continue_receiving_slices(bctl); + } +} + +static void receiving_stream_ready(void* bctlp, grpc_error_handle error) { + batch_control* bctl = static_cast(bctlp); + grpc_call* call = bctl->call; + if (error != GRPC_ERROR_NONE) { + call->receiving_stream.reset(); + if (bctl->batch_error.ok()) { + bctl->batch_error.set(error); + } + cancel_with_error(call, GRPC_ERROR_REF(error)); + } + /* If recv_state is RECV_NONE, we will save the batch_control + * object with rel_cas, and will not use it after the cas. Its corresponding + * acq_load is in receiving_initial_metadata_ready() */ + if (error != GRPC_ERROR_NONE || call->receiving_stream == nullptr || + !gpr_atm_rel_cas(&call->recv_state, RECV_NONE, + reinterpret_cast(bctlp))) { + process_data_after_md(bctl); + } +} + +// The recv_message_ready callback used when sending a batch containing +// a recv_message op down the filter stack. Yields the call combiner +// before processing the received message. 
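Aside: the receive path that lands in the callbacks below is driven by the
public batch API. A hypothetical caller (a sketch, not code from this diff;
example_request_message is an invented name) would queue a receive like this:

static void example_request_message(grpc_call* call, void* tag,
                                    grpc_byte_buffer** recv_payload) {
  grpc_op op = {};
  op.op = GRPC_OP_RECV_MESSAGE;
  op.data.recv_message.recv_message = recv_payload;
  // The tag is surfaced on the call's completion queue once a message has
  // arrived; at end of stream *recv_payload is left as nullptr.
  GPR_ASSERT(GRPC_CALL_OK ==
             grpc_call_start_batch(call, &op, 1, tag, /*reserved=*/nullptr));
}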
+static void receiving_stream_ready_in_call_combiner(void* bctlp, + grpc_error_handle error) { + batch_control* bctl = static_cast(bctlp); + grpc_call* call = bctl->call; + GRPC_CALL_COMBINER_STOP(&call->call_combiner, "recv_message_ready"); + receiving_stream_ready(bctlp, error); +} + +static void GPR_ATTRIBUTE_NOINLINE +handle_both_stream_and_msg_compression_set(grpc_call* call) { + std::string error_msg = absl::StrFormat( + "Incoming stream has both stream compression (%d) and message " + "compression (%d).", + call->incoming_stream_compression_algorithm, + call->incoming_message_compression_algorithm); + gpr_log(GPR_ERROR, "%s", error_msg.c_str()); + cancel_with_status(call, GRPC_STATUS_INTERNAL, error_msg.c_str()); +} + +static void GPR_ATTRIBUTE_NOINLINE +handle_error_parsing_compression_algorithm(grpc_call* call) { + std::string error_msg = absl::StrFormat( + "Error in incoming message compression (%d) or stream " + "compression (%d).", + call->incoming_stream_compression_algorithm, + call->incoming_message_compression_algorithm); + cancel_with_status(call, GRPC_STATUS_INTERNAL, error_msg.c_str()); +} + +static void GPR_ATTRIBUTE_NOINLINE handle_invalid_compression( + grpc_call* call, grpc_compression_algorithm compression_algorithm) { + std::string error_msg = absl::StrFormat( + "Invalid compression algorithm value '%d'.", compression_algorithm); + gpr_log(GPR_ERROR, "%s", error_msg.c_str()); + cancel_with_status(call, GRPC_STATUS_UNIMPLEMENTED, error_msg.c_str()); +} + +static void GPR_ATTRIBUTE_NOINLINE handle_compression_algorithm_disabled( + grpc_call* call, grpc_compression_algorithm compression_algorithm) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + std::string error_msg = + absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); + gpr_log(GPR_ERROR, "%s", error_msg.c_str()); + cancel_with_status(call, GRPC_STATUS_UNIMPLEMENTED, error_msg.c_str()); +} + +static void GPR_ATTRIBUTE_NOINLINE handle_compression_algorithm_not_accepted( + grpc_call* call, grpc_compression_algorithm compression_algorithm) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + gpr_log(GPR_ERROR, + "Compression algorithm ('%s') not present in the bitset of " + "accepted encodings ('0x%x')", + algo_name, call->encodings_accepted_by_peer); +} + +static void validate_filtered_metadata(batch_control* bctl) { + grpc_compression_algorithm compression_algorithm; + grpc_call* call = bctl->call; + if (GPR_UNLIKELY(call->incoming_stream_compression_algorithm != + GRPC_STREAM_COMPRESS_NONE && + call->incoming_message_compression_algorithm != + GRPC_MESSAGE_COMPRESS_NONE)) { + handle_both_stream_and_msg_compression_set(call); + } else if ( + GPR_UNLIKELY( + grpc_compression_algorithm_from_message_stream_compression_algorithm( + &compression_algorithm, + call->incoming_message_compression_algorithm, + call->incoming_stream_compression_algorithm) == 0)) { + handle_error_parsing_compression_algorithm(call); + } else { + const grpc_compression_options compression_options = + grpc_channel_compression_options(call->channel); + if (GPR_UNLIKELY(compression_algorithm >= GRPC_COMPRESS_ALGORITHMS_COUNT)) { + handle_invalid_compression(call, compression_algorithm); + } else if (GPR_UNLIKELY( + grpc_compression_options_is_algorithm_enabled_internal( + &compression_options, compression_algorithm) == 0)) { + /* check if algorithm is supported by current channel config */ + 
handle_compression_algorithm_disabled(call, compression_algorithm); + } + /* GRPC_COMPRESS_NONE is always set. */ + GPR_DEBUG_ASSERT(call->encodings_accepted_by_peer != 0); + if (GPR_UNLIKELY(!grpc_core::GetBit(call->encodings_accepted_by_peer, + compression_algorithm))) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { + handle_compression_algorithm_not_accepted(call, compression_algorithm); + } + } + } +} + +static void receiving_initial_metadata_ready(void* bctlp, + grpc_error_handle error) { + batch_control* bctl = static_cast(bctlp); + grpc_call* call = bctl->call; + + GRPC_CALL_COMBINER_STOP(&call->call_combiner, "recv_initial_metadata_ready"); + + if (error == GRPC_ERROR_NONE) { + grpc_metadata_batch* md = &call->recv_initial_metadata; + recv_initial_filter(call, md); + + /* TODO(ctiller): this could be moved into recv_initial_filter now */ + GPR_TIMER_SCOPE("validate_filtered_metadata", 0); + validate_filtered_metadata(bctl); + + absl::optional deadline = + md->get(grpc_core::GrpcTimeoutMetadata()); + if (deadline.has_value() && !call->is_client) { + call->send_deadline = *deadline; + } + } else { + if (bctl->batch_error.ok()) { + bctl->batch_error.set(error); + } + cancel_with_error(call, GRPC_ERROR_REF(error)); + } + + grpc_closure* saved_rsr_closure = nullptr; + while (true) { + gpr_atm rsr_bctlp = gpr_atm_acq_load(&call->recv_state); + /* Should only receive initial metadata once */ + GPR_ASSERT(rsr_bctlp != 1); + if (rsr_bctlp == 0) { + /* We haven't seen initial metadata and messages before, thus initial + * metadata is received first. + * no_barrier_cas is used, as this function won't access the batch_control + * object saved by receiving_stream_ready() if the initial metadata is + * received first. */ + if (gpr_atm_no_barrier_cas(&call->recv_state, RECV_NONE, + RECV_INITIAL_METADATA_FIRST)) { + break; + } + } else { + /* Already received messages */ + saved_rsr_closure = + GRPC_CLOSURE_CREATE(receiving_stream_ready, (batch_control*)rsr_bctlp, + grpc_schedule_on_exec_ctx); + /* No need to modify recv_state */ + break; + } + } + if (saved_rsr_closure != nullptr) { + grpc_core::Closure::Run(DEBUG_LOCATION, saved_rsr_closure, + GRPC_ERROR_REF(error)); + } + + finish_batch_step(bctl); +} + +static void receiving_trailing_metadata_ready(void* bctlp, + grpc_error_handle error) { + batch_control* bctl = static_cast(bctlp); + grpc_call* call = bctl->call; + GRPC_CALL_COMBINER_STOP(&call->call_combiner, "recv_trailing_metadata_ready"); + grpc_metadata_batch* md = &call->recv_trailing_metadata; + recv_trailing_filter(call, md, GRPC_ERROR_REF(error)); + finish_batch_step(bctl); +} + +static void finish_batch(void* bctlp, grpc_error_handle error) { + batch_control* bctl = static_cast(bctlp); + grpc_call* call = bctl->call; + GRPC_CALL_COMBINER_STOP(&call->call_combiner, "on_complete"); + if (bctl->batch_error.ok()) { + bctl->batch_error.set(error); + } + if (error != GRPC_ERROR_NONE) { + cancel_with_error(call, GRPC_ERROR_REF(error)); + } + finish_batch_step(bctl); +} + +static void free_no_op_completion(void* /*p*/, grpc_cq_completion* completion) { + gpr_free(completion); +} + +static grpc_call_error call_start_batch(grpc_call* call, const grpc_op* ops, + size_t nops, void* notify_tag, + int is_notify_tag_closure) { + GPR_TIMER_SCOPE("call_start_batch", 0); + + size_t i; + const grpc_op* op; + batch_control* bctl; + bool has_send_ops = false; + int num_recv_ops = 0; + grpc_call_error error = GRPC_CALL_OK; + grpc_transport_stream_op_batch* stream_op; + 
grpc_transport_stream_op_batch_payload* stream_op_payload; + + GRPC_CALL_LOG_BATCH(GPR_INFO, ops, nops); + + if (nops == 0) { + if (!is_notify_tag_closure) { + GPR_ASSERT(grpc_cq_begin_op(call->cq, notify_tag)); + grpc_cq_end_op(call->cq, notify_tag, GRPC_ERROR_NONE, + free_no_op_completion, nullptr, + static_cast( + gpr_malloc(sizeof(grpc_cq_completion)))); + } else { + grpc_core::Closure::Run(DEBUG_LOCATION, + static_cast(notify_tag), + GRPC_ERROR_NONE); + } + error = GRPC_CALL_OK; + goto done; + } + + bctl = reuse_or_allocate_batch_control(call, ops); + if (bctl == nullptr) { + return GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + } + bctl->completion_data.notify_tag.tag = notify_tag; + bctl->completion_data.notify_tag.is_closure = + static_cast(is_notify_tag_closure != 0); + + stream_op = &bctl->op; + stream_op_payload = &call->stream_op_payload; + + /* rewrite batch ops into a transport op */ + for (i = 0; i < nops; i++) { + op = &ops[i]; + if (op->reserved != nullptr) { + error = GRPC_CALL_ERROR; + goto done_with_error; + } + switch (op->op) { + case GRPC_OP_SEND_INITIAL_METADATA: { + /* Flag validation: currently allow no flags */ + if (!are_initial_metadata_flags_valid(op->flags, call->is_client)) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if (call->sent_initial_metadata) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + // TODO(juanlishen): If the user has already specified a compression + // algorithm by setting the initial metadata with key of + // GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, we shouldn't override that + // with the compression algorithm mapped from compression level. + /* process compression level */ + grpc_metadata& compression_md = call->compression_md; + compression_md.key = grpc_empty_slice(); + compression_md.value = grpc_empty_slice(); + size_t additional_metadata_count = 0; + grpc_compression_level effective_compression_level = + GRPC_COMPRESS_LEVEL_NONE; + bool level_set = false; + if (op->data.send_initial_metadata.maybe_compression_level.is_set) { + effective_compression_level = + op->data.send_initial_metadata.maybe_compression_level.level; + level_set = true; + } else { + const grpc_compression_options copts = + grpc_channel_compression_options(call->channel); + if (copts.default_level.is_set) { + level_set = true; + effective_compression_level = copts.default_level.level; + } + } + // Currently, only server side supports compression level setting. + if (level_set && !call->is_client) { + const grpc_compression_algorithm calgo = + compression_algorithm_for_level_locked( + call, effective_compression_level); + // The following metadata will be checked and removed by the message + // compression filter. It will be used as the call's compression + // algorithm. + compression_md.key = GRPC_MDSTR_GRPC_INTERNAL_ENCODING_REQUEST; + compression_md.value = grpc_compression_algorithm_slice(calgo); + additional_metadata_count++; + } + if (op->data.send_initial_metadata.count + additional_metadata_count > + INT_MAX) { + error = GRPC_CALL_ERROR_INVALID_METADATA; + goto done_with_error; + } + stream_op->send_initial_metadata = true; + call->sent_initial_metadata = true; + if (!prepare_application_metadata( + call, static_cast(op->data.send_initial_metadata.count), + op->data.send_initial_metadata.metadata, 0, call->is_client, + &compression_md, static_cast(additional_metadata_count))) { + error = GRPC_CALL_ERROR_INVALID_METADATA; + goto done_with_error; + } + /* TODO(ctiller): just make these the same variable? 
*/ + if (call->is_client && call->send_deadline != GRPC_MILLIS_INF_FUTURE) { + call->send_initial_metadata.Set(grpc_core::GrpcTimeoutMetadata(), + call->send_deadline); + } + stream_op_payload->send_initial_metadata.send_initial_metadata = + &call->send_initial_metadata; + stream_op_payload->send_initial_metadata.send_initial_metadata_flags = + op->flags; + if (call->is_client) { + stream_op_payload->send_initial_metadata.peer_string = + &call->peer_string; + } + has_send_ops = true; + break; + } + case GRPC_OP_SEND_MESSAGE: { + if (!are_write_flags_valid(op->flags)) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if (op->data.send_message.send_message == nullptr) { + error = GRPC_CALL_ERROR_INVALID_MESSAGE; + goto done_with_error; + } + if (call->sending_message) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + uint32_t flags = op->flags; + /* If the outgoing buffer is already compressed, mark it as so in the + flags. These will be picked up by the compression filter and further + (wasteful) attempts at compression skipped. */ + if (op->data.send_message.send_message->data.raw.compression > + GRPC_COMPRESS_NONE) { + flags |= GRPC_WRITE_INTERNAL_COMPRESS; + } + stream_op->send_message = true; + call->sending_message = true; + call->sending_stream.Init( + &op->data.send_message.send_message->data.raw.slice_buffer, flags); + stream_op_payload->send_message.send_message.reset( + call->sending_stream.get()); + has_send_ops = true; + break; + } + case GRPC_OP_SEND_CLOSE_FROM_CLIENT: { + /* Flag validation: currently allow no flags */ + if (op->flags != 0) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if (!call->is_client) { + error = GRPC_CALL_ERROR_NOT_ON_SERVER; + goto done_with_error; + } + if (call->sent_final_op) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + stream_op->send_trailing_metadata = true; + call->sent_final_op = true; + stream_op_payload->send_trailing_metadata.send_trailing_metadata = + &call->send_trailing_metadata; + has_send_ops = true; + break; + } + case GRPC_OP_SEND_STATUS_FROM_SERVER: { + /* Flag validation: currently allow no flags */ + if (op->flags != 0) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if (call->is_client) { + error = GRPC_CALL_ERROR_NOT_ON_CLIENT; + goto done_with_error; + } + if (call->sent_final_op) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + if (op->data.send_status_from_server.trailing_metadata_count > + INT_MAX) { + error = GRPC_CALL_ERROR_INVALID_METADATA; + goto done_with_error; + } + stream_op->send_trailing_metadata = true; + call->sent_final_op = true; + GPR_ASSERT(call->send_extra_metadata_count == 0); + call->send_extra_metadata_count = 1; + call->send_extra_metadata[0].md = grpc_get_reffed_status_elem( + op->data.send_status_from_server.status); + grpc_error_handle status_error = + op->data.send_status_from_server.status == GRPC_STATUS_OK + ? 
GRPC_ERROR_NONE + : grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Server returned error"), + GRPC_ERROR_INT_GRPC_STATUS, + static_cast( + op->data.send_status_from_server.status)); + if (op->data.send_status_from_server.status_details != nullptr) { + call->send_extra_metadata[1].md = grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_MESSAGE, + grpc_slice_copy( + *op->data.send_status_from_server.status_details)); + call->send_extra_metadata_count++; + if (status_error != GRPC_ERROR_NONE) { + char* msg = grpc_slice_to_c_string( + GRPC_MDVALUE(call->send_extra_metadata[1].md)); + status_error = grpc_error_set_str(status_error, + GRPC_ERROR_STR_GRPC_MESSAGE, msg); + gpr_free(msg); + } + } + + call->status_error.set(status_error); + GRPC_ERROR_UNREF(status_error); + + if (!prepare_application_metadata( + call, + static_cast( + op->data.send_status_from_server.trailing_metadata_count), + op->data.send_status_from_server.trailing_metadata, 1, 1, + nullptr, 0)) { + for (int n = 0; n < call->send_extra_metadata_count; n++) { + GRPC_MDELEM_UNREF(call->send_extra_metadata[n].md); + } + call->send_extra_metadata_count = 0; + error = GRPC_CALL_ERROR_INVALID_METADATA; + goto done_with_error; + } + stream_op_payload->send_trailing_metadata.send_trailing_metadata = + &call->send_trailing_metadata; + stream_op_payload->send_trailing_metadata.sent = + &call->sent_server_trailing_metadata; + has_send_ops = true; + break; + } + case GRPC_OP_RECV_INITIAL_METADATA: { + /* Flag validation: currently allow no flags */ + if (op->flags != 0) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if (call->received_initial_metadata) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + call->received_initial_metadata = true; + call->buffered_metadata[0] = + op->data.recv_initial_metadata.recv_initial_metadata; + GRPC_CLOSURE_INIT(&call->receiving_initial_metadata_ready, + receiving_initial_metadata_ready, bctl, + grpc_schedule_on_exec_ctx); + stream_op->recv_initial_metadata = true; + stream_op_payload->recv_initial_metadata.recv_initial_metadata = + &call->recv_initial_metadata; + stream_op_payload->recv_initial_metadata.recv_initial_metadata_ready = + &call->receiving_initial_metadata_ready; + if (call->is_client) { + stream_op_payload->recv_initial_metadata.trailing_metadata_available = + &call->is_trailers_only; + } else { + stream_op_payload->recv_initial_metadata.peer_string = + &call->peer_string; + } + ++num_recv_ops; + break; + } + case GRPC_OP_RECV_MESSAGE: { + /* Flag validation: currently allow no flags */ + if (op->flags != 0) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if (call->receiving_message) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + call->receiving_message = true; + stream_op->recv_message = true; + call->receiving_buffer = op->data.recv_message.recv_message; + stream_op_payload->recv_message.recv_message = &call->receiving_stream; + stream_op_payload->recv_message.call_failed_before_recv_message = + &call->call_failed_before_recv_message; + GRPC_CLOSURE_INIT(&call->receiving_stream_ready, + receiving_stream_ready_in_call_combiner, bctl, + grpc_schedule_on_exec_ctx); + stream_op_payload->recv_message.recv_message_ready = + &call->receiving_stream_ready; + ++num_recv_ops; + break; + } + case GRPC_OP_RECV_STATUS_ON_CLIENT: { + /* Flag validation: currently allow no flags */ + if (op->flags != 0) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if 
(!call->is_client) { + error = GRPC_CALL_ERROR_NOT_ON_SERVER; + goto done_with_error; + } + if (call->requested_final_op) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + call->requested_final_op = true; + call->buffered_metadata[1] = + op->data.recv_status_on_client.trailing_metadata; + call->final_op.client.status = op->data.recv_status_on_client.status; + call->final_op.client.status_details = + op->data.recv_status_on_client.status_details; + call->final_op.client.error_string = + op->data.recv_status_on_client.error_string; + stream_op->recv_trailing_metadata = true; + stream_op_payload->recv_trailing_metadata.recv_trailing_metadata = + &call->recv_trailing_metadata; + stream_op_payload->recv_trailing_metadata.collect_stats = + &call->final_info.stats.transport_stream_stats; + GRPC_CLOSURE_INIT(&call->receiving_trailing_metadata_ready, + receiving_trailing_metadata_ready, bctl, + grpc_schedule_on_exec_ctx); + stream_op_payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &call->receiving_trailing_metadata_ready; + ++num_recv_ops; + break; + } + case GRPC_OP_RECV_CLOSE_ON_SERVER: { + /* Flag validation: currently allow no flags */ + if (op->flags != 0) { + error = GRPC_CALL_ERROR_INVALID_FLAGS; + goto done_with_error; + } + if (call->is_client) { + error = GRPC_CALL_ERROR_NOT_ON_CLIENT; + goto done_with_error; + } + if (call->requested_final_op) { + error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; + goto done_with_error; + } + call->requested_final_op = true; + call->final_op.server.cancelled = + op->data.recv_close_on_server.cancelled; + stream_op->recv_trailing_metadata = true; + stream_op_payload->recv_trailing_metadata.recv_trailing_metadata = + &call->recv_trailing_metadata; + stream_op_payload->recv_trailing_metadata.collect_stats = + &call->final_info.stats.transport_stream_stats; + GRPC_CLOSURE_INIT(&call->receiving_trailing_metadata_ready, + receiving_trailing_metadata_ready, bctl, + grpc_schedule_on_exec_ctx); + stream_op_payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &call->receiving_trailing_metadata_ready; + ++num_recv_ops; + break; + } + } + } + + GRPC_CALL_INTERNAL_REF(call, "completion"); + if (!is_notify_tag_closure) { + GPR_ASSERT(grpc_cq_begin_op(call->cq, notify_tag)); + } + bctl->set_num_steps_to_complete((has_send_ops ? 
1 : 0) + num_recv_ops); + + if (has_send_ops) { + GRPC_CLOSURE_INIT(&bctl->finish_batch, finish_batch, bctl, + grpc_schedule_on_exec_ctx); + stream_op->on_complete = &bctl->finish_batch; + } + + gpr_atm_rel_store(&call->any_ops_sent_atm, 1); + execute_batch(call, stream_op, &bctl->start_batch); + +done: + return error; + +done_with_error: + /* reverse any mutations that occurred */ + if (stream_op->send_initial_metadata) { + call->sent_initial_metadata = false; + call->send_initial_metadata.Clear(); + } + if (stream_op->send_message) { + call->sending_message = false; + call->sending_stream->Orphan(); + } + if (stream_op->send_trailing_metadata) { + call->sent_final_op = false; + call->send_trailing_metadata.Clear(); + } + if (stream_op->recv_initial_metadata) { + call->received_initial_metadata = false; + } + if (stream_op->recv_message) { + call->receiving_message = false; + } + if (stream_op->recv_trailing_metadata) { + call->requested_final_op = false; + } + goto done; +} + +grpc_call_error grpc_call_start_batch(grpc_call* call, const grpc_op* ops, + size_t nops, void* tag, void* reserved) { + grpc_call_error err; + + GRPC_API_TRACE( + "grpc_call_start_batch(call=%p, ops=%p, nops=%lu, tag=%p, " + "reserved=%p)", + 5, (call, ops, (unsigned long)nops, tag, reserved)); + + if (reserved != nullptr) { + err = GRPC_CALL_ERROR; + } else { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + err = call_start_batch(call, ops, nops, tag, 0); + } + + return err; +} + +grpc_call_error grpc_call_start_batch_and_execute(grpc_call* call, + const grpc_op* ops, + size_t nops, + grpc_closure* closure) { + return call_start_batch(call, ops, nops, closure, 1); +} + +void grpc_call_context_set(grpc_call* call, grpc_context_index elem, + void* value, void (*destroy)(void* value)) { + if (call->context[elem].destroy) { + call->context[elem].destroy(call->context[elem].value); + } + call->context[elem].value = value; + call->context[elem].destroy = destroy; +} + +void* grpc_call_context_get(grpc_call* call, grpc_context_index elem) { + return call->context[elem].value; +} + +uint8_t grpc_call_is_client(grpc_call* call) { return call->is_client; } + +grpc_compression_algorithm grpc_call_compression_for_level( + grpc_call* call, grpc_compression_level level) { + grpc_compression_algorithm algo = + compression_algorithm_for_level_locked(call, level); + return algo; +} + +bool grpc_call_is_trailers_only(const grpc_call* call) { + bool result = call->is_trailers_only; + GPR_DEBUG_ASSERT(!result || call->recv_initial_metadata.empty()); + return result; +} + +int grpc_call_failed_before_recv_message(const grpc_call* c) { + return c->call_failed_before_recv_message; +} + +const char* grpc_call_error_to_string(grpc_call_error error) { + switch (error) { + case GRPC_CALL_ERROR: + return "GRPC_CALL_ERROR"; + case GRPC_CALL_ERROR_ALREADY_ACCEPTED: + return "GRPC_CALL_ERROR_ALREADY_ACCEPTED"; + case GRPC_CALL_ERROR_ALREADY_FINISHED: + return "GRPC_CALL_ERROR_ALREADY_FINISHED"; + case GRPC_CALL_ERROR_ALREADY_INVOKED: + return "GRPC_CALL_ERROR_ALREADY_INVOKED"; + case GRPC_CALL_ERROR_BATCH_TOO_BIG: + return "GRPC_CALL_ERROR_BATCH_TOO_BIG"; + case GRPC_CALL_ERROR_INVALID_FLAGS: + return "GRPC_CALL_ERROR_INVALID_FLAGS"; + case GRPC_CALL_ERROR_INVALID_MESSAGE: + return "GRPC_CALL_ERROR_INVALID_MESSAGE"; + case GRPC_CALL_ERROR_INVALID_METADATA: + return "GRPC_CALL_ERROR_INVALID_METADATA"; + case GRPC_CALL_ERROR_NOT_INVOKED: + return "GRPC_CALL_ERROR_NOT_INVOKED"; + case 
GRPC_CALL_ERROR_NOT_ON_CLIENT: + return "GRPC_CALL_ERROR_NOT_ON_CLIENT"; + case GRPC_CALL_ERROR_NOT_ON_SERVER: + return "GRPC_CALL_ERROR_NOT_ON_SERVER"; + case GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE: + return "GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE"; + case GRPC_CALL_ERROR_PAYLOAD_TYPE_MISMATCH: + return "GRPC_CALL_ERROR_PAYLOAD_TYPE_MISMATCH"; + case GRPC_CALL_ERROR_TOO_MANY_OPERATIONS: + return "GRPC_CALL_ERROR_TOO_MANY_OPERATIONS"; + case GRPC_CALL_ERROR_COMPLETION_QUEUE_SHUTDOWN: + return "GRPC_CALL_ERROR_COMPLETION_QUEUE_SHUTDOWN"; + case GRPC_CALL_OK: + return "GRPC_CALL_OK"; + } + GPR_UNREACHABLE_CODE(return "GRPC_CALL_ERROR_UNKNOW"); +} diff --git a/src/core/lib/surface/call_details.cc b/src/core/lib/surface/call_details.cc new file mode 100644 index 00000000..ce73c3a5 --- /dev/null +++ b/src/core/lib/surface/call_details.cc @@ -0,0 +1,41 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" + +void grpc_call_details_init(grpc_call_details* details) { + GRPC_API_TRACE("grpc_call_details_init(details=%p)", 1, (details)); + details->method = grpc_empty_slice(); + details->host = grpc_empty_slice(); +} + +void grpc_call_details_destroy(grpc_call_details* details) { + GRPC_API_TRACE("grpc_call_details_destroy(details=%p)", 1, (details)); + grpc_core::ExecCtx exec_ctx; + grpc_slice_unref_internal(details->method); + grpc_slice_unref_internal(details->host); +} diff --git a/src/core/lib/surface/call_log_batch.cc b/src/core/lib/surface/call_log_batch.cc new file mode 100644 index 00000000..6bc1881a --- /dev/null +++ b/src/core/lib/surface/call_log_batch.cc @@ -0,0 +1,111 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
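To make the batch surface implemented in call.cc above concrete, here is a minimal sketch (not part of the patch) of how an application drives grpc_call_start_batch() and interprets grpc_call_error_to_string(). The channel, completion queue, and method path are assumed placeholders, the metadata array is assumed to have been grpc_metadata_array_init()ed by the caller, and error handling is trimmed.

#include <string.h>

#include <grpc/grpc.h>
#include <grpc/support/log.h>

// Sketch only: starts a SEND_INITIAL_METADATA + RECV_INITIAL_METADATA batch
// on a client call; both ops complete together as a single event carrying
// `tag`.
static void start_unary_batch(grpc_channel* channel, grpc_completion_queue* cq,
                              grpc_metadata_array* initial_md, void* tag) {
  grpc_call* call = grpc_channel_create_call(
      channel, /*parent_call=*/nullptr, GRPC_PROPAGATE_DEFAULTS, cq,
      grpc_slice_from_static_string("/helloworld.Greeter/SayHello"),
      /*host=*/nullptr, gpr_inf_future(GPR_CLOCK_REALTIME),
      /*reserved=*/nullptr);

  grpc_op ops[2];
  memset(ops, 0, sizeof(ops));
  ops[0].op = GRPC_OP_SEND_INITIAL_METADATA;
  ops[0].data.send_initial_metadata.count = 0;
  ops[1].op = GRPC_OP_RECV_INITIAL_METADATA;
  ops[1].data.recv_initial_metadata.recv_initial_metadata = initial_md;

  grpc_call_error rc = grpc_call_start_batch(
      call, ops, sizeof(ops) / sizeof(*ops), tag, /*reserved=*/nullptr);
  if (rc != GRPC_CALL_OK) {
    // Validation failures (bad flags, duplicate ops, ...) are reported here
    // synchronously; transport failures surface via the completion queue.
    gpr_log(GPR_ERROR, "start_batch failed: %s",
            grpc_call_error_to_string(rc));
  }
  // `tag` is later returned by grpc_completion_queue_next() when the batch
  // finishes; the call still needs grpc_call_unref() when the caller is done.
}

Either path matches the code above: the validation branches return one of the grpc_call_error codes, while an accepted batch completes asynchronously on the completion queue.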
+ * + */ + +#include + +#include + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/call.h" + +static void add_metadata(const grpc_metadata* md, size_t count, + std::vector* b) { + if (md == nullptr) { + b->push_back("(nil)"); + return; + } + for (size_t i = 0; i < count; i++) { + b->push_back("\nkey="); + b->push_back(std::string(grpc_core::StringViewFromSlice(md[i].key))); + b->push_back(" value="); + char* dump = grpc_dump_slice(md[i].value, GPR_DUMP_HEX | GPR_DUMP_ASCII); + b->push_back(dump); + gpr_free(dump); + } +} + +static std::string grpc_op_string(const grpc_op* op) { + std::vector parts; + switch (op->op) { + case GRPC_OP_SEND_INITIAL_METADATA: + parts.push_back("SEND_INITIAL_METADATA"); + add_metadata(op->data.send_initial_metadata.metadata, + op->data.send_initial_metadata.count, &parts); + break; + case GRPC_OP_SEND_MESSAGE: + parts.push_back(absl::StrFormat("SEND_MESSAGE ptr=%p", + op->data.send_message.send_message)); + break; + case GRPC_OP_SEND_CLOSE_FROM_CLIENT: + parts.push_back("SEND_CLOSE_FROM_CLIENT"); + break; + case GRPC_OP_SEND_STATUS_FROM_SERVER: + parts.push_back( + absl::StrFormat("SEND_STATUS_FROM_SERVER status=%d details=", + op->data.send_status_from_server.status)); + if (op->data.send_status_from_server.status_details != nullptr) { + char* dump = grpc_dump_slice( + *op->data.send_status_from_server.status_details, GPR_DUMP_ASCII); + parts.push_back(dump); + gpr_free(dump); + } else { + parts.push_back("(null)"); + } + add_metadata(op->data.send_status_from_server.trailing_metadata, + op->data.send_status_from_server.trailing_metadata_count, + &parts); + break; + case GRPC_OP_RECV_INITIAL_METADATA: + parts.push_back(absl::StrFormat( + "RECV_INITIAL_METADATA ptr=%p", + op->data.recv_initial_metadata.recv_initial_metadata)); + break; + case GRPC_OP_RECV_MESSAGE: + parts.push_back(absl::StrFormat("RECV_MESSAGE ptr=%p", + op->data.recv_message.recv_message)); + break; + case GRPC_OP_RECV_STATUS_ON_CLIENT: + parts.push_back(absl::StrFormat( + "RECV_STATUS_ON_CLIENT metadata=%p status=%p details=%p", + op->data.recv_status_on_client.trailing_metadata, + op->data.recv_status_on_client.status, + op->data.recv_status_on_client.status_details)); + break; + case GRPC_OP_RECV_CLOSE_ON_SERVER: + parts.push_back(absl::StrFormat("RECV_CLOSE_ON_SERVER cancelled=%p", + op->data.recv_close_on_server.cancelled)); + } + return absl::StrJoin(parts, ""); +} + +void grpc_call_log_batch(const char* file, int line, gpr_log_severity severity, + const grpc_op* ops, size_t nops) { + for (size_t i = 0; i < nops; i++) { + gpr_log(file, line, severity, "ops[%" PRIuPTR "]: %s", i, + grpc_op_string(&ops[i]).c_str()); + } +} diff --git a/src/core/lib/surface/channel.cc b/src/core/lib/surface/channel.cc new file mode 100644 index 00000000..f4af87e6 --- /dev/null +++ b/src/core/lib/surface/channel.cc @@ -0,0 +1,535 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
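The grpc_op_string()/grpc_call_log_batch() helpers above only produce output when the "api" tracer is active. As a rough sketch, and assuming gRPC's usual GRPC_TRACE / GRPC_VERBOSITY environment-variable convention plus POSIX setenv (neither is shown in this patch), a test binary could enable the logging before grpc_init():

#include <cstdlib>

#include <grpc/grpc.h>

int main() {
  // Assumed convention: "api" turns on grpc_api_trace, which gates the batch
  // logging above; verbosity must admit INFO-level gpr_log output.
  setenv("GRPC_TRACE", "api", /*overwrite=*/1);
  setenv("GRPC_VERBOSITY", "INFO", /*overwrite=*/1);
  grpc_init();
  // ... create channels and calls; each grpc_call_start_batch() now logs its
  // ops through grpc_call_log_batch().
  grpc_shutdown();
  return 0;
}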
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/channel.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_trace.h" +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/channel/channelz_registry.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/resource_quota.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/transport/static_metadata.h" + +/** Cache grpc-status: X mdelems for X = 0..NUM_CACHED_STATUS_ELEMS. + * Avoids needing to take a metadata context lock for sending status + * if the status code is <= NUM_CACHED_STATUS_ELEMS. + * Sized to allow the most commonly used codes to fit in + * (OK, Cancelled, Unknown). */ +#define NUM_CACHED_STATUS_ELEMS 3 + +static void destroy_channel(void* arg, grpc_error_handle error); + +grpc_channel* grpc_channel_create_with_builder( + grpc_channel_stack_builder* builder, + grpc_channel_stack_type channel_stack_type, + grpc_resource_user* resource_user, size_t preallocated_bytes, + grpc_error_handle* error) { + char* target = gpr_strdup(grpc_channel_stack_builder_get_target(builder)); + grpc_channel_args* args = grpc_channel_args_copy( + grpc_channel_stack_builder_get_channel_arguments(builder)); + grpc_channel* channel; + if (channel_stack_type == GRPC_SERVER_CHANNEL) { + GRPC_STATS_INC_SERVER_CHANNELS_CREATED(); + } else { + GRPC_STATS_INC_CLIENT_CHANNELS_CREATED(); + } + grpc_error_handle builder_error = grpc_channel_stack_builder_finish( + builder, sizeof(grpc_channel), 1, destroy_channel, nullptr, + reinterpret_cast(&channel)); + if (builder_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "channel stack builder failed: %s", + grpc_error_std_string(builder_error).c_str()); + GPR_ASSERT(channel == nullptr); + if (error != nullptr) { + *error = builder_error; + } else { + GRPC_ERROR_UNREF(builder_error); + } + gpr_free(target); + grpc_channel_args_destroy(args); + if (resource_user != nullptr) { + if (preallocated_bytes > 0) { + grpc_resource_user_free(resource_user, preallocated_bytes); + } + grpc_resource_user_unref(resource_user); + } + return nullptr; + } + channel->target = target; + channel->resource_user = resource_user; + channel->preallocated_bytes = preallocated_bytes; + channel->is_client = grpc_channel_stack_type_is_client(channel_stack_type); + channel->registration_table.Init(); + + gpr_atm_no_barrier_store( + &channel->call_size_estimate, + (gpr_atm)CHANNEL_STACK_FROM_CHANNEL(channel)->call_stack_size + + grpc_call_get_initial_size_estimate()); + + grpc_compression_options_init(&channel->compression_options); + for (size_t i = 0; i < args->num_args; i++) { + if (0 == + strcmp(args->args[i].key, GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL)) { + channel->compression_options.default_level.is_set = true; + channel->compression_options.default_level.level = + static_cast(grpc_channel_arg_get_integer( + &args->args[i], + {GRPC_COMPRESS_LEVEL_NONE, GRPC_COMPRESS_LEVEL_NONE, + GRPC_COMPRESS_LEVEL_COUNT - 
1})); + } else if (0 == strcmp(args->args[i].key, + GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM)) { + channel->compression_options.default_algorithm.is_set = true; + channel->compression_options.default_algorithm.algorithm = + static_cast(grpc_channel_arg_get_integer( + &args->args[i], {GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, + GRPC_COMPRESS_ALGORITHMS_COUNT - 1})); + } else if (0 == + strcmp(args->args[i].key, + GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET)) { + channel->compression_options.enabled_algorithms_bitset = + static_cast(args->args[i].value.integer) | + 0x1; /* always support no compression */ + } else if (0 == strcmp(args->args[i].key, GRPC_ARG_CHANNELZ_CHANNEL_NODE)) { + if (args->args[i].type == GRPC_ARG_POINTER) { + GPR_ASSERT(args->args[i].value.pointer.p != nullptr); + channel->channelz_node = static_cast( + args->args[i].value.pointer.p) + ->Ref(); + } else { + gpr_log(GPR_DEBUG, + GRPC_ARG_CHANNELZ_CHANNEL_NODE " should be a pointer"); + } + } + } + + grpc_channel_args_destroy(args); + return channel; +} + +static grpc_core::UniquePtr get_default_authority( + const grpc_channel_args* input_args) { + bool has_default_authority = false; + char* ssl_override = nullptr; + grpc_core::UniquePtr default_authority; + const size_t num_args = input_args != nullptr ? input_args->num_args : 0; + for (size_t i = 0; i < num_args; ++i) { + if (0 == strcmp(input_args->args[i].key, GRPC_ARG_DEFAULT_AUTHORITY)) { + has_default_authority = true; + } else if (0 == strcmp(input_args->args[i].key, + GRPC_SSL_TARGET_NAME_OVERRIDE_ARG)) { + ssl_override = grpc_channel_arg_get_string(&input_args->args[i]); + } + } + if (!has_default_authority && ssl_override != nullptr) { + default_authority.reset(gpr_strdup(ssl_override)); + } + return default_authority; +} + +static grpc_channel_args* build_channel_args( + const grpc_channel_args* input_args, char* default_authority) { + grpc_arg new_args[1]; + size_t num_new_args = 0; + if (default_authority != nullptr) { + new_args[num_new_args++] = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), default_authority); + } + return grpc_channel_args_copy_and_add(input_args, new_args, num_new_args); +} + +namespace { + +void* channelz_node_copy(void* p) { + grpc_core::channelz::ChannelNode* node = + static_cast(p); + node->Ref().release(); + return p; +} +void channelz_node_destroy(void* p) { + grpc_core::channelz::ChannelNode* node = + static_cast(p); + node->Unref(); +} +int channelz_node_cmp(void* p1, void* p2) { + return grpc_core::QsortCompare(p1, p2); +} +const grpc_arg_pointer_vtable channelz_node_arg_vtable = { + channelz_node_copy, channelz_node_destroy, channelz_node_cmp}; + +void CreateChannelzNode(grpc_channel_stack_builder* builder) { + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + // Check whether channelz is enabled. + const bool channelz_enabled = grpc_channel_args_find_bool( + args, GRPC_ARG_ENABLE_CHANNELZ, GRPC_ENABLE_CHANNELZ_DEFAULT); + if (!channelz_enabled) return; + // Get parameters needed to create the channelz node. + const size_t channel_tracer_max_memory = grpc_channel_args_find_integer( + args, GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, + {GRPC_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE_DEFAULT, 0, INT_MAX}); + const bool is_internal_channel = grpc_channel_args_find_bool( + args, GRPC_ARG_CHANNELZ_IS_INTERNAL_CHANNEL, false); + // Create the channelz node. 
+ const char* target = grpc_channel_stack_builder_get_target(builder); + grpc_core::RefCountedPtr channelz_node = + grpc_core::MakeRefCounted( + target != nullptr ? target : "", channel_tracer_max_memory, + is_internal_channel); + channelz_node->AddTraceEvent( + grpc_core::channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string("Channel created")); + // Add channelz node to channel args. + // We remove the is_internal_channel arg, since we no longer need it. + grpc_arg new_arg = grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_CHANNELZ_CHANNEL_NODE), channelz_node.get(), + &channelz_node_arg_vtable); + const char* args_to_remove[] = {GRPC_ARG_CHANNELZ_IS_INTERNAL_CHANNEL}; + grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( + args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), &new_arg, 1); + grpc_channel_stack_builder_set_channel_arguments(builder, new_args); + grpc_channel_args_destroy(new_args); +} + +} // namespace + +grpc_channel* grpc_channel_create(const char* target, + const grpc_channel_args* input_args, + grpc_channel_stack_type channel_stack_type, + grpc_transport* optional_transport, + grpc_resource_user* resource_user, + size_t preallocated_bytes, + grpc_error_handle* error) { + // We need to make sure that grpc_shutdown() does not shut things down + // until after the channel is destroyed. However, the channel may not + // actually be destroyed by the time grpc_channel_destroy() returns, + // since there may be other existing refs to the channel. If those + // refs are held by things that are visible to the wrapped language + // (such as outstanding calls on the channel), then the wrapped + // language can be responsible for making sure that grpc_shutdown() + // does not run until after those refs are released. However, the + // channel may also have refs to itself held internally for various + // things that need to be cleaned up at channel destruction (e.g., + // LB policies, subchannels, etc), and because these refs are not + // visible to the wrapped language, it cannot be responsible for + // deferring grpc_shutdown() until after they are released. To + // accommodate that, we call grpc_init() here and then call + // grpc_shutdown() when the channel is actually destroyed, thus + // ensuring that shutdown is deferred until that point. + grpc_init(); + grpc_channel_stack_builder* builder = grpc_channel_stack_builder_create(); + const grpc_core::UniquePtr default_authority = + get_default_authority(input_args); + grpc_channel_args* args = + build_channel_args(input_args, default_authority.get()); + if (grpc_channel_stack_type_is_client(channel_stack_type)) { + auto channel_args_mutator = + grpc_channel_args_get_client_channel_creation_mutator(); + if (channel_args_mutator != nullptr) { + args = channel_args_mutator(target, args, channel_stack_type); + } + } + grpc_channel_stack_builder_set_channel_arguments(builder, args); + grpc_channel_args_destroy(args); + grpc_channel_stack_builder_set_target(builder, target); + grpc_channel_stack_builder_set_transport(builder, optional_transport); + if (!grpc_core::CoreConfiguration::Get().channel_init().CreateStack( + builder, channel_stack_type)) { + grpc_channel_stack_builder_destroy(builder); + if (resource_user != nullptr) { + if (preallocated_bytes > 0) { + grpc_resource_user_free(resource_user, preallocated_bytes); + } + grpc_resource_user_unref(resource_user); + } + grpc_shutdown(); // Since we won't call destroy_channel(). 
+ return nullptr; + } + // We only need to do this for clients here. For servers, this will be + // done in src/core/lib/surface/server.cc. + if (grpc_channel_stack_type_is_client(channel_stack_type)) { + CreateChannelzNode(builder); + } + grpc_channel* channel = grpc_channel_create_with_builder( + builder, channel_stack_type, resource_user, preallocated_bytes, error); + if (channel == nullptr) { + grpc_shutdown(); // Since we won't call destroy_channel(). + } + return channel; +} + +size_t grpc_channel_get_call_size_estimate(grpc_channel* channel) { +#define ROUND_UP_SIZE 256 + /* We round up our current estimate to the NEXT value of ROUND_UP_SIZE. + This ensures: + 1. a consistent size allocation when our estimate is drifting slowly + (which is common) - which tends to help most allocators reuse memory + 2. a small amount of allowed growth over the estimate without hitting + the arena size doubling case, reducing overall memory usage */ + return (static_cast( + gpr_atm_no_barrier_load(&channel->call_size_estimate)) + + 2 * ROUND_UP_SIZE) & + ~static_cast(ROUND_UP_SIZE - 1); +} + +void grpc_channel_update_call_size_estimate(grpc_channel* channel, + size_t size) { + size_t cur = static_cast( + gpr_atm_no_barrier_load(&channel->call_size_estimate)); + if (cur < size) { + /* size grew: update estimate */ + gpr_atm_no_barrier_cas(&channel->call_size_estimate, + static_cast(cur), + static_cast(size)); + /* if we lose: never mind, something else will likely update soon enough */ + } else if (cur == size) { + /* no change: holding pattern */ + } else if (cur > 0) { + /* size shrank: decrease estimate */ + gpr_atm_no_barrier_cas( + &channel->call_size_estimate, static_cast(cur), + static_cast(std::min(cur - 1, (255 * cur + size) / 256))); + /* if we lose: never mind, something else will likely update soon enough */ + } +} + +char* grpc_channel_get_target(grpc_channel* channel) { + GRPC_API_TRACE("grpc_channel_get_target(channel=%p)", 1, (channel)); + return gpr_strdup(channel->target); +} + +void grpc_channel_get_info(grpc_channel* channel, + const grpc_channel_info* channel_info) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_channel_element* elem = + grpc_channel_stack_element(CHANNEL_STACK_FROM_CHANNEL(channel), 0); + elem->filter->get_channel_info(elem, channel_info); +} + +void grpc_channel_reset_connect_backoff(grpc_channel* channel) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_channel_reset_connect_backoff(channel=%p)", 1, + (channel)); + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->reset_connect_backoff = true; + grpc_channel_element* elem = + grpc_channel_stack_element(CHANNEL_STACK_FROM_CHANNEL(channel), 0); + elem->filter->start_transport_op(elem, op); +} + +static grpc_call* grpc_channel_create_call_internal( + grpc_channel* channel, grpc_call* parent_call, uint32_t propagation_mask, + grpc_completion_queue* cq, grpc_pollset_set* pollset_set_alternative, + grpc_mdelem path_mdelem, grpc_mdelem authority_mdelem, + grpc_millis deadline) { + grpc_mdelem send_metadata[2]; + size_t num_metadata = 0; + + GPR_ASSERT(channel->is_client); + GPR_ASSERT(!(cq != nullptr && pollset_set_alternative != nullptr)); + + send_metadata[num_metadata++] = path_mdelem; + if (!GRPC_MDISNULL(authority_mdelem)) { + send_metadata[num_metadata++] = authority_mdelem; + } + + grpc_call_create_args args; + args.channel = channel; + args.server = nullptr; + args.parent = 
parent_call; + args.propagation_mask = propagation_mask; + args.cq = cq; + args.pollset_set_alternative = pollset_set_alternative; + args.server_transport_data = nullptr; + args.add_initial_metadata = send_metadata; + args.add_initial_metadata_count = num_metadata; + args.send_deadline = deadline; + + grpc_call* call; + GRPC_LOG_IF_ERROR("call_create", grpc_call_create(&args, &call)); + return call; +} + +grpc_call* grpc_channel_create_call(grpc_channel* channel, + grpc_call* parent_call, + uint32_t propagation_mask, + grpc_completion_queue* completion_queue, + grpc_slice method, const grpc_slice* host, + gpr_timespec deadline, void* reserved) { + GPR_ASSERT(!reserved); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_call* call = grpc_channel_create_call_internal( + channel, parent_call, propagation_mask, completion_queue, nullptr, + grpc_mdelem_create(GRPC_MDSTR_PATH, method, nullptr), + host != nullptr ? grpc_mdelem_create(GRPC_MDSTR_AUTHORITY, *host, nullptr) + : GRPC_MDNULL, + grpc_timespec_to_millis_round_up(deadline)); + + return call; +} + +grpc_call* grpc_channel_create_pollset_set_call( + grpc_channel* channel, grpc_call* parent_call, uint32_t propagation_mask, + grpc_pollset_set* pollset_set, const grpc_slice& method, + const grpc_slice* host, grpc_millis deadline, void* reserved) { + GPR_ASSERT(!reserved); + return grpc_channel_create_call_internal( + channel, parent_call, propagation_mask, nullptr, pollset_set, + grpc_mdelem_create(GRPC_MDSTR_PATH, method, nullptr), + host != nullptr ? grpc_mdelem_create(GRPC_MDSTR_AUTHORITY, *host, nullptr) + : GRPC_MDNULL, + deadline); +} + +namespace grpc_core { + +RegisteredCall::RegisteredCall(const char* method_arg, const char* host_arg) + : path(method_arg != nullptr && method_arg[0] != 0 + ? grpc_mdelem_from_slices( + GRPC_MDSTR_PATH, grpc_slice_from_copied_string(method_arg)) + : GRPC_MDNULL), + authority( + host_arg != nullptr && host_arg[0] != 0 + ? grpc_mdelem_from_slices(GRPC_MDSTR_AUTHORITY, + grpc_slice_from_copied_string(host_arg)) + : GRPC_MDNULL) {} + +RegisteredCall::RegisteredCall(const RegisteredCall& other) + : path(GRPC_MDELEM_REF(other.path)), + authority(GRPC_MDELEM_REF(other.authority)) {} + +RegisteredCall::~RegisteredCall() { + GRPC_MDELEM_UNREF(path); + GRPC_MDELEM_UNREF(authority); +} + +} // namespace grpc_core + +void* grpc_channel_register_call(grpc_channel* channel, const char* method, + const char* host, void* reserved) { + GRPC_API_TRACE( + "grpc_channel_register_call(channel=%p, method=%s, host=%s, reserved=%p)", + 4, (channel, method, host, reserved)); + GPR_ASSERT(!reserved); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + + grpc_core::MutexLock lock(&channel->registration_table->mu); + channel->registration_table->method_registration_attempts++; + auto key = std::make_pair(std::string(host != nullptr ? host : ""), + std::string(method != nullptr ? 
method : "")); + auto rc_posn = channel->registration_table->map.find(key); + if (rc_posn != channel->registration_table->map.end()) { + return &rc_posn->second; + } + auto insertion_result = channel->registration_table->map.insert( + {std::move(key), grpc_core::RegisteredCall(method, host)}); + return &insertion_result.first->second; +} + +grpc_call* grpc_channel_create_registered_call( + grpc_channel* channel, grpc_call* parent_call, uint32_t propagation_mask, + grpc_completion_queue* completion_queue, void* registered_call_handle, + gpr_timespec deadline, void* reserved) { + grpc_core::RegisteredCall* rc = + static_cast(registered_call_handle); + GRPC_API_TRACE( + "grpc_channel_create_registered_call(" + "channel=%p, parent_call=%p, propagation_mask=%x, completion_queue=%p, " + "registered_call_handle=%p, " + "deadline=gpr_timespec { tv_sec: %" PRId64 + ", tv_nsec: %d, clock_type: %d }, " + "reserved=%p)", + 9, + (channel, parent_call, (unsigned)propagation_mask, completion_queue, + registered_call_handle, deadline.tv_sec, deadline.tv_nsec, + (int)deadline.clock_type, reserved)); + GPR_ASSERT(!reserved); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_call* call = grpc_channel_create_call_internal( + channel, parent_call, propagation_mask, completion_queue, nullptr, + GRPC_MDELEM_REF(rc->path), GRPC_MDELEM_REF(rc->authority), + grpc_timespec_to_millis_round_up(deadline)); + + return call; +} + +static void destroy_channel(void* arg, grpc_error_handle /*error*/) { + grpc_channel* channel = static_cast(arg); + if (channel->channelz_node != nullptr) { + channel->channelz_node->AddTraceEvent( + grpc_core::channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string("Channel destroyed")); + channel->channelz_node.reset(); + } + grpc_channel_stack_destroy(CHANNEL_STACK_FROM_CHANNEL(channel)); + channel->registration_table.Destroy(); + if (channel->resource_user != nullptr) { + if (channel->preallocated_bytes > 0) { + grpc_resource_user_free(channel->resource_user, + channel->preallocated_bytes); + } + grpc_resource_user_unref(channel->resource_user); + } + gpr_free(channel->target); + gpr_free(channel); + // See comment in grpc_channel_create() for why we do this. + grpc_shutdown(); +} + +void grpc_channel_destroy_internal(grpc_channel* channel) { + grpc_transport_op* op = grpc_make_transport_op(nullptr); + grpc_channel_element* elem; + GRPC_API_TRACE("grpc_channel_destroy(channel=%p)", 1, (channel)); + op->disconnect_with_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Channel Destroyed"); + elem = grpc_channel_stack_element(CHANNEL_STACK_FROM_CHANNEL(channel), 0); + elem->filter->start_transport_op(elem, op); + GRPC_CHANNEL_INTERNAL_UNREF(channel, "channel"); +} + +void grpc_channel_destroy(grpc_channel* channel) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_channel_destroy_internal(channel); +} diff --git a/src/core/lib/surface/channel_init.cc b/src/core/lib/surface/channel_init.cc new file mode 100644 index 00000000..beeb23d4 --- /dev/null +++ b/src/core/lib/surface/channel_init.cc @@ -0,0 +1,56 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/channel_init.h" + +#include + +namespace grpc_core { + +void ChannelInit::Builder::RegisterStage(grpc_channel_stack_type type, + int priority, Stage stage) { + slots_[type].emplace_back(std::move(stage), priority); +} + +ChannelInit ChannelInit::Builder::Build() { + ChannelInit result; + for (int i = 0; i < GRPC_NUM_CHANNEL_STACK_TYPES; i++) { + auto& slots = slots_[i]; + std::stable_sort( + slots.begin(), slots.end(), + [](const Slot& a, const Slot& b) { return a.priority < b.priority; }); + auto& result_slots = result.slots_[i]; + result_slots.reserve(slots.size()); + for (auto& slot : slots) { + result_slots.emplace_back(std::move(slot.stage)); + } + } + return result; +} + +bool ChannelInit::CreateStack(grpc_channel_stack_builder* builder, + grpc_channel_stack_type type) const { + for (const auto& stage : slots_[type]) { + if (!stage(builder)) return false; + } + return true; +} + +} // namespace grpc_core diff --git a/src/core/lib/surface/channel_ping.cc b/src/core/lib/surface/channel_ping.cc new file mode 100644 index 00000000..22c49386 --- /dev/null +++ b/src/core/lib/surface/channel_ping.cc @@ -0,0 +1,63 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
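For the ChannelInit builder implemented above, here is a sketch of how a stage might be registered. It assumes Stage is a std::function-like callable taking a grpc_channel_stack_builder* and returning bool (as the stage(builder) call in CreateStack() implies), and my_filter is a hypothetical grpc_channel_filter standing in for a real one.

#include "src/core/lib/channel/channel_stack_builder.h"
#include "src/core/lib/surface/channel_init.h"

// Hypothetical filter definition assumed to be provided elsewhere.
extern const grpc_channel_filter my_filter;

static grpc_core::ChannelInit BuildExampleInit() {
  grpc_core::ChannelInit::Builder builder;
  // Lower priority values run earlier, thanks to the stable_sort in Build().
  builder.RegisterStage(GRPC_CLIENT_CHANNEL, /*priority=*/10000,
                        [](grpc_channel_stack_builder* b) {
                          // Returning false aborts stack construction.
                          return grpc_channel_stack_builder_append_filter(
                              b, &my_filter, /*post_init=*/nullptr,
                              /*user_data=*/nullptr);
                        });
  return builder.Build();
}

// During channel construction the result is then consulted roughly as:
//   init.CreateStack(stack_builder, GRPC_CLIENT_CHANNEL);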
+ * + */ + +#include + +#include + +#include +#include + +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" + +struct ping_result { + grpc_closure closure; + void* tag; + grpc_completion_queue* cq; + grpc_cq_completion completion_storage; +}; +static void ping_destroy(void* arg, grpc_cq_completion* /*storage*/) { + gpr_free(arg); +} + +static void ping_done(void* arg, grpc_error_handle error) { + ping_result* pr = static_cast(arg); + grpc_cq_end_op(pr->cq, pr->tag, GRPC_ERROR_REF(error), ping_destroy, pr, + &pr->completion_storage); +} + +void grpc_channel_ping(grpc_channel* channel, grpc_completion_queue* cq, + void* tag, void* reserved) { + GRPC_API_TRACE("grpc_channel_ping(channel=%p, cq=%p, tag=%p, reserved=%p)", 4, + (channel, cq, tag, reserved)); + grpc_transport_op* op = grpc_make_transport_op(nullptr); + ping_result* pr = static_cast(gpr_malloc(sizeof(*pr))); + grpc_channel_element* top_elem = + grpc_channel_stack_element(grpc_channel_get_channel_stack(channel), 0); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(reserved == nullptr); + pr->tag = tag; + pr->cq = cq; + GRPC_CLOSURE_INIT(&pr->closure, ping_done, pr, grpc_schedule_on_exec_ctx); + op->send_ping.on_ack = &pr->closure; + op->bind_pollset = grpc_cq_pollset(cq); + GPR_ASSERT(grpc_cq_begin_op(cq, tag)); + top_elem->filter->start_transport_op(top_elem, op); +} diff --git a/src/core/lib/surface/channel_stack_type.cc b/src/core/lib/surface/channel_stack_type.cc new file mode 100644 index 00000000..ecbc3ef6 --- /dev/null +++ b/src/core/lib/surface/channel_stack_type.cc @@ -0,0 +1,59 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/channel_stack_type.h" + +#include + +bool grpc_channel_stack_type_is_client(grpc_channel_stack_type type) { + switch (type) { + case GRPC_CLIENT_CHANNEL: + return true; + case GRPC_CLIENT_SUBCHANNEL: + return true; + case GRPC_CLIENT_LAME_CHANNEL: + return true; + case GRPC_CLIENT_DIRECT_CHANNEL: + return true; + case GRPC_SERVER_CHANNEL: + return false; + case GRPC_NUM_CHANNEL_STACK_TYPES: + break; + } + GPR_UNREACHABLE_CODE(return true;); +} + +const char* grpc_channel_stack_type_string(grpc_channel_stack_type type) { + switch (type) { + case GRPC_CLIENT_CHANNEL: + return "CLIENT_CHANNEL"; + case GRPC_CLIENT_SUBCHANNEL: + return "CLIENT_SUBCHANNEL"; + case GRPC_SERVER_CHANNEL: + return "SERVER_CHANNEL"; + case GRPC_CLIENT_LAME_CHANNEL: + return "CLIENT_LAME_CHANNEL"; + case GRPC_CLIENT_DIRECT_CHANNEL: + return "CLIENT_DIRECT_CHANNEL"; + case GRPC_NUM_CHANNEL_STACK_TYPES: + break; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} diff --git a/src/core/lib/surface/completion_queue.cc b/src/core/lib/surface/completion_queue.cc new file mode 100644 index 00000000..847b0fee --- /dev/null +++ b/src/core/lib/surface/completion_queue.cc @@ -0,0 +1,1429 @@ +/* + * + * Copyright 2015-2016 gRPC authors. 
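grpc_channel_ping() above completes through a completion queue like any other operation. The following sketch drives it from a GRPC_CQ_NEXT queue; `channel` is assumed to be an already-connected client channel, since the acknowledgement comes from the underlying transport.

#include <grpc/grpc.h>

static bool ping_channel(grpc_channel* channel) {
  grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr);
  void* ping_tag = reinterpret_cast<void*>(0x1);
  grpc_channel_ping(channel, cq, ping_tag, /*reserved=*/nullptr);

  // Block until the ping's completion (or queue shutdown) arrives.
  grpc_event ev = grpc_completion_queue_next(
      cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr);
  bool acked = ev.type == GRPC_OP_COMPLETE && ev.tag == ping_tag &&
               ev.success != 0;

  // Shut down and drain the queue before destroying it.
  grpc_completion_queue_shutdown(cq);
  while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME),
                                    nullptr)
             .type != GRPC_QUEUE_SHUTDOWN) {
  }
  grpc_completion_queue_destroy(cq);
  return acked;
}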
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include "src/core/lib/surface/completion_queue.h" + +#include +#include +#include + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/spinlock.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/event_string.h" + +grpc_core::TraceFlag grpc_trace_operation_failures(false, "op_failure"); +grpc_core::DebugOnlyTraceFlag grpc_trace_pending_tags(false, "pending_tags"); +grpc_core::DebugOnlyTraceFlag grpc_trace_cq_refcount(false, "cq_refcount"); + +namespace { + +// Specifies a cq thread local cache. +// The first event that occurs on a thread +// with a cq cache will go into that cache, and +// will only be returned on the thread that initialized the cache. +// NOTE: Only one event will ever be cached. +static GPR_THREAD_LOCAL(grpc_cq_completion*) g_cached_event; +static GPR_THREAD_LOCAL(grpc_completion_queue*) g_cached_cq; + +struct plucker { + grpc_pollset_worker** worker; + void* tag; +}; +struct cq_poller_vtable { + bool can_get_pollset; + bool can_listen; + size_t (*size)(void); + void (*init)(grpc_pollset* pollset, gpr_mu** mu); + grpc_error_handle (*kick)(grpc_pollset* pollset, + grpc_pollset_worker* specific_worker); + grpc_error_handle (*work)(grpc_pollset* pollset, grpc_pollset_worker** worker, + grpc_millis deadline); + void (*shutdown)(grpc_pollset* pollset, grpc_closure* closure); + void (*destroy)(grpc_pollset* pollset); +}; +typedef struct non_polling_worker { + gpr_cv cv; + bool kicked; + struct non_polling_worker* next; + struct non_polling_worker* prev; +} non_polling_worker; + +struct non_polling_poller { + gpr_mu mu; + bool kicked_without_poller; + non_polling_worker* root; + grpc_closure* shutdown; +}; +size_t non_polling_poller_size(void) { return sizeof(non_polling_poller); } + +void non_polling_poller_init(grpc_pollset* pollset, gpr_mu** mu) { + non_polling_poller* npp = reinterpret_cast(pollset); + gpr_mu_init(&npp->mu); + *mu = &npp->mu; +} + +void non_polling_poller_destroy(grpc_pollset* pollset) { + non_polling_poller* npp = reinterpret_cast(pollset); + gpr_mu_destroy(&npp->mu); +} + +grpc_error_handle non_polling_poller_work(grpc_pollset* pollset, + grpc_pollset_worker** worker, + grpc_millis deadline) { + non_polling_poller* npp = reinterpret_cast(pollset); + if (npp->shutdown) return GRPC_ERROR_NONE; + if (npp->kicked_without_poller) { + npp->kicked_without_poller = false; + return GRPC_ERROR_NONE; + } + non_polling_worker w; + gpr_cv_init(&w.cv); + if (worker != nullptr) *worker = reinterpret_cast(&w); + if (npp->root == 
nullptr) { + npp->root = w.next = w.prev = &w; + } else { + w.next = npp->root; + w.prev = w.next->prev; + w.next->prev = w.prev->next = &w; + } + w.kicked = false; + gpr_timespec deadline_ts = + grpc_millis_to_timespec(deadline, GPR_CLOCK_MONOTONIC); + while (!npp->shutdown && !w.kicked && + !gpr_cv_wait(&w.cv, &npp->mu, deadline_ts)) { + } + grpc_core::ExecCtx::Get()->InvalidateNow(); + if (&w == npp->root) { + npp->root = w.next; + if (&w == npp->root) { + if (npp->shutdown) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, npp->shutdown, GRPC_ERROR_NONE); + } + npp->root = nullptr; + } + } + w.next->prev = w.prev; + w.prev->next = w.next; + gpr_cv_destroy(&w.cv); + if (worker != nullptr) *worker = nullptr; + return GRPC_ERROR_NONE; +} + +grpc_error_handle non_polling_poller_kick( + grpc_pollset* pollset, grpc_pollset_worker* specific_worker) { + non_polling_poller* p = reinterpret_cast(pollset); + if (specific_worker == nullptr) { + specific_worker = reinterpret_cast(p->root); + } + if (specific_worker != nullptr) { + non_polling_worker* w = + reinterpret_cast(specific_worker); + if (!w->kicked) { + w->kicked = true; + gpr_cv_signal(&w->cv); + } + } else { + p->kicked_without_poller = true; + } + return GRPC_ERROR_NONE; +} + +void non_polling_poller_shutdown(grpc_pollset* pollset, grpc_closure* closure) { + non_polling_poller* p = reinterpret_cast(pollset); + GPR_ASSERT(closure != nullptr); + p->shutdown = closure; + if (p->root == nullptr) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + } else { + non_polling_worker* w = p->root; + do { + gpr_cv_signal(&w->cv); + w = w->next; + } while (w != p->root); + } +} + +const cq_poller_vtable g_poller_vtable_by_poller_type[] = { + /* GRPC_CQ_DEFAULT_POLLING */ + {true, true, grpc_pollset_size, grpc_pollset_init, grpc_pollset_kick, + grpc_pollset_work, grpc_pollset_shutdown, grpc_pollset_destroy}, + /* GRPC_CQ_NON_LISTENING */ + {true, false, grpc_pollset_size, grpc_pollset_init, grpc_pollset_kick, + grpc_pollset_work, grpc_pollset_shutdown, grpc_pollset_destroy}, + /* GRPC_CQ_NON_POLLING */ + {false, false, non_polling_poller_size, non_polling_poller_init, + non_polling_poller_kick, non_polling_poller_work, + non_polling_poller_shutdown, non_polling_poller_destroy}, +}; + +} // namespace + +struct cq_vtable { + grpc_cq_completion_type cq_completion_type; + size_t data_size; + void (*init)(void* data, grpc_completion_queue_functor* shutdown_callback); + void (*shutdown)(grpc_completion_queue* cq); + void (*destroy)(void* data); + bool (*begin_op)(grpc_completion_queue* cq, void* tag); + void (*end_op)(grpc_completion_queue* cq, void* tag, grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), + void* done_arg, grpc_cq_completion* storage, bool internal); + grpc_event (*next)(grpc_completion_queue* cq, gpr_timespec deadline, + void* reserved); + grpc_event (*pluck)(grpc_completion_queue* cq, void* tag, + gpr_timespec deadline, void* reserved); +}; + +namespace { + +/* Queue that holds the cq_completion_events. Internally uses + * MultiProducerSingleConsumerQueue (a lockfree multiproducer single consumer + * queue). It uses a queue_lock to support multiple consumers. + * Only used in completion queues whose completion_type is GRPC_CQ_NEXT */ +class CqEventQueue { + public: + CqEventQueue() = default; + ~CqEventQueue() = default; + + /* Note: The counter is not incremented/decremented atomically with push/pop. 
+ * The count is only eventually consistent */ + intptr_t num_items() const { + return num_queue_items_.load(std::memory_order_relaxed); + } + + bool Push(grpc_cq_completion* c); + grpc_cq_completion* Pop(); + + private: + /* Spinlock to serialize consumers i.e pop() operations */ + gpr_spinlock queue_lock_ = GPR_SPINLOCK_INITIALIZER; + + grpc_core::MultiProducerSingleConsumerQueue queue_; + + /* A lazy counter of number of items in the queue. This is NOT atomically + incremented/decremented along with push/pop operations and hence is only + eventually consistent */ + std::atomic num_queue_items_{0}; +}; + +struct cq_next_data { + ~cq_next_data() { + GPR_ASSERT(queue.num_items() == 0); +#ifndef NDEBUG + if (pending_events.load(std::memory_order_acquire) != 0) { + gpr_log(GPR_ERROR, "Destroying CQ without draining it fully."); + } +#endif + } + + /** Completed events for completion-queues of type GRPC_CQ_NEXT */ + CqEventQueue queue; + + /** Counter of how many things have ever been queued on this completion queue + useful for avoiding locks to check the queue */ + std::atomic things_queued_ever{0}; + + /** Number of outstanding events (+1 if not shut down) + Initial count is dropped by grpc_completion_queue_shutdown */ + std::atomic pending_events{1}; + + /** 0 initially. 1 once we initiated shutdown */ + bool shutdown_called = false; +}; + +struct cq_pluck_data { + cq_pluck_data() { + completed_tail = &completed_head; + completed_head.next = reinterpret_cast(completed_tail); + } + + ~cq_pluck_data() { + GPR_ASSERT(completed_head.next == + reinterpret_cast(&completed_head)); +#ifndef NDEBUG + if (pending_events.load(std::memory_order_acquire) != 0) { + gpr_log(GPR_ERROR, "Destroying CQ without draining it fully."); + } +#endif + } + + /** Completed events for completion-queues of type GRPC_CQ_PLUCK */ + grpc_cq_completion completed_head; + grpc_cq_completion* completed_tail; + + /** Number of pending events (+1 if we're not shutdown). + Initial count is dropped by grpc_completion_queue_shutdown. */ + std::atomic pending_events{1}; + + /** Counter of how many things have ever been queued on this completion queue + useful for avoiding locks to check the queue */ + std::atomic things_queued_ever{0}; + + /** 0 initially. 1 once we completed shutting */ + /* TODO: (sreek) This is not needed since (shutdown == 1) if and only if + * (pending_events == 0). So consider removing this in future and use + * pending_events */ + std::atomic shutdown{false}; + + /** 0 initially. 1 once we initiated shutdown */ + bool shutdown_called = false; + + int num_pluckers = 0; + plucker pluckers[GRPC_MAX_COMPLETION_QUEUE_PLUCKERS]; +}; + +struct cq_callback_data { + explicit cq_callback_data(grpc_completion_queue_functor* shutdown_callback) + : shutdown_callback(shutdown_callback) {} + + ~cq_callback_data() { +#ifndef NDEBUG + if (pending_events.load(std::memory_order_acquire) != 0) { + gpr_log(GPR_ERROR, "Destroying CQ without draining it fully."); + } +#endif + } + + /** No actual completed events queue, unlike other types */ + + /** Number of pending events (+1 if we're not shutdown). + Initial count is dropped by grpc_completion_queue_shutdown. */ + std::atomic pending_events{1}; + + /** 0 initially. 
1 once we initiated shutdown */ + bool shutdown_called = false; + + /** A callback that gets invoked when the CQ completes shutdown */ + grpc_completion_queue_functor* shutdown_callback; +}; + +} // namespace + +/* Completion queue structure */ +struct grpc_completion_queue { + /** Once owning_refs drops to zero, we will destroy the cq */ + grpc_core::RefCount owning_refs; + + gpr_mu* mu; + + const cq_vtable* vtable; + const cq_poller_vtable* poller_vtable; + +#ifndef NDEBUG + void** outstanding_tags; + size_t outstanding_tag_count; + size_t outstanding_tag_capacity; +#endif + + grpc_closure pollset_shutdown_done; + int num_polls; +}; + +/* Forward declarations */ +static void cq_finish_shutdown_next(grpc_completion_queue* cq); +static void cq_finish_shutdown_pluck(grpc_completion_queue* cq); +static void cq_finish_shutdown_callback(grpc_completion_queue* cq); +static void cq_shutdown_next(grpc_completion_queue* cq); +static void cq_shutdown_pluck(grpc_completion_queue* cq); +static void cq_shutdown_callback(grpc_completion_queue* cq); + +static bool cq_begin_op_for_next(grpc_completion_queue* cq, void* tag); +static bool cq_begin_op_for_pluck(grpc_completion_queue* cq, void* tag); +static bool cq_begin_op_for_callback(grpc_completion_queue* cq, void* tag); + +// A cq_end_op function is called when an operation on a given CQ with +// a given tag has completed. The storage argument is a reference to the +// space reserved for this completion as it is placed into the corresponding +// queue. The done argument is a callback that will be invoked when it is +// safe to free up that storage. The storage MUST NOT be freed until the +// done callback is invoked. +static void cq_end_op_for_next( + grpc_completion_queue* cq, void* tag, grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), void* done_arg, + grpc_cq_completion* storage, bool internal); + +static void cq_end_op_for_pluck( + grpc_completion_queue* cq, void* tag, grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), void* done_arg, + grpc_cq_completion* storage, bool internal); + +static void cq_end_op_for_callback( + grpc_completion_queue* cq, void* tag, grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), void* done_arg, + grpc_cq_completion* storage, bool internal); + +static grpc_event cq_next(grpc_completion_queue* cq, gpr_timespec deadline, + void* reserved); + +static grpc_event cq_pluck(grpc_completion_queue* cq, void* tag, + gpr_timespec deadline, void* reserved); + +// Note that cq_init_next and cq_init_pluck do not use the shutdown_callback +static void cq_init_next(void* data, + grpc_completion_queue_functor* shutdown_callback); +static void cq_init_pluck(void* data, + grpc_completion_queue_functor* shutdown_callback); +static void cq_init_callback(void* data, + grpc_completion_queue_functor* shutdown_callback); +static void cq_destroy_next(void* data); +static void cq_destroy_pluck(void* data); +static void cq_destroy_callback(void* data); + +/* Completion queue vtables based on the completion-type */ +static const cq_vtable g_cq_vtable[] = { + /* GRPC_CQ_NEXT */ + {GRPC_CQ_NEXT, sizeof(cq_next_data), cq_init_next, cq_shutdown_next, + cq_destroy_next, cq_begin_op_for_next, cq_end_op_for_next, cq_next, + nullptr}, + /* GRPC_CQ_PLUCK */ + {GRPC_CQ_PLUCK, sizeof(cq_pluck_data), cq_init_pluck, cq_shutdown_pluck, + cq_destroy_pluck, cq_begin_op_for_pluck, cq_end_op_for_pluck, nullptr, + cq_pluck}, + /* GRPC_CQ_CALLBACK */ + 
{GRPC_CQ_CALLBACK, sizeof(cq_callback_data), cq_init_callback, + cq_shutdown_callback, cq_destroy_callback, cq_begin_op_for_callback, + cq_end_op_for_callback, nullptr, nullptr}, +}; + +#define DATA_FROM_CQ(cq) ((void*)((cq) + 1)) +#define POLLSET_FROM_CQ(cq) \ + ((grpc_pollset*)((cq)->vtable->data_size + (char*)DATA_FROM_CQ(cq))) + +grpc_core::TraceFlag grpc_cq_pluck_trace(false, "queue_pluck"); + +#define GRPC_SURFACE_TRACE_RETURNED_EVENT(cq, event) \ + do { \ + if (GRPC_TRACE_FLAG_ENABLED(grpc_api_trace) && \ + (GRPC_TRACE_FLAG_ENABLED(grpc_cq_pluck_trace) || \ + (event)->type != GRPC_QUEUE_TIMEOUT)) { \ + gpr_log(GPR_INFO, "RETURN_EVENT[%p]: %s", cq, \ + grpc_event_string(event).c_str()); \ + } \ + } while (0) + +static void on_pollset_shutdown_done(void* arg, grpc_error_handle error); + +void grpc_cq_global_init() {} + +void grpc_completion_queue_thread_local_cache_init(grpc_completion_queue* cq) { + if (g_cached_cq == nullptr) { + g_cached_event = nullptr; + g_cached_cq = cq; + } +} + +int grpc_completion_queue_thread_local_cache_flush(grpc_completion_queue* cq, + void** tag, int* ok) { + grpc_cq_completion* storage = g_cached_event; + int ret = 0; + if (storage != nullptr && g_cached_cq == cq) { + *tag = storage->tag; + grpc_core::ExecCtx exec_ctx; + *ok = (storage->next & static_cast(1)) == 1; + storage->done(storage->done_arg, storage); + ret = 1; + cq_next_data* cqd = static_cast DATA_FROM_CQ(cq); + if (cqd->pending_events.fetch_sub(1, std::memory_order_acq_rel) == 1) { + GRPC_CQ_INTERNAL_REF(cq, "shutting_down"); + gpr_mu_lock(cq->mu); + cq_finish_shutdown_next(cq); + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down"); + } + } + g_cached_event = nullptr; + g_cached_cq = nullptr; + + return ret; +} + +bool CqEventQueue::Push(grpc_cq_completion* c) { + queue_.Push( + reinterpret_cast(c)); + return num_queue_items_.fetch_add(1, std::memory_order_relaxed) == 0; +} + +grpc_cq_completion* CqEventQueue::Pop() { + grpc_cq_completion* c = nullptr; + + if (gpr_spinlock_trylock(&queue_lock_)) { + GRPC_STATS_INC_CQ_EV_QUEUE_TRYLOCK_SUCCESSES(); + + bool is_empty = false; + c = reinterpret_cast(queue_.PopAndCheckEnd(&is_empty)); + gpr_spinlock_unlock(&queue_lock_); + + if (c == nullptr && !is_empty) { + GRPC_STATS_INC_CQ_EV_QUEUE_TRANSIENT_POP_FAILURES(); + } + } else { + GRPC_STATS_INC_CQ_EV_QUEUE_TRYLOCK_FAILURES(); + } + + if (c) { + num_queue_items_.fetch_sub(1, std::memory_order_relaxed); + } + + return c; +} + +grpc_completion_queue* grpc_completion_queue_create_internal( + grpc_cq_completion_type completion_type, grpc_cq_polling_type polling_type, + grpc_completion_queue_functor* shutdown_callback) { + GPR_TIMER_SCOPE("grpc_completion_queue_create_internal", 0); + + grpc_completion_queue* cq; + + GRPC_API_TRACE( + "grpc_completion_queue_create_internal(completion_type=%d, " + "polling_type=%d)", + 2, (completion_type, polling_type)); + + const cq_vtable* vtable = &g_cq_vtable[completion_type]; + const cq_poller_vtable* poller_vtable = + &g_poller_vtable_by_poller_type[polling_type]; + + grpc_core::ExecCtx exec_ctx; + GRPC_STATS_INC_CQS_CREATED(); + + cq = static_cast( + gpr_zalloc(sizeof(grpc_completion_queue) + vtable->data_size + + poller_vtable->size())); + + cq->vtable = vtable; + cq->poller_vtable = poller_vtable; + + /* One for destroy(), one for pollset_shutdown */ + new (&cq->owning_refs) grpc_core::RefCount(2); + + poller_vtable->init(POLLSET_FROM_CQ(cq), &cq->mu); + vtable->init(DATA_FROM_CQ(cq), shutdown_callback); + + 
GRPC_CLOSURE_INIT(&cq->pollset_shutdown_done, on_pollset_shutdown_done, cq, + grpc_schedule_on_exec_ctx); + return cq; +} + +static void cq_init_next(void* data, + grpc_completion_queue_functor* /*shutdown_callback*/) { + new (data) cq_next_data(); +} + +static void cq_destroy_next(void* data) { + cq_next_data* cqd = static_cast(data); + cqd->~cq_next_data(); +} + +static void cq_init_pluck( + void* data, grpc_completion_queue_functor* /*shutdown_callback*/) { + new (data) cq_pluck_data(); +} + +static void cq_destroy_pluck(void* data) { + cq_pluck_data* cqd = static_cast(data); + cqd->~cq_pluck_data(); +} + +static void cq_init_callback(void* data, + grpc_completion_queue_functor* shutdown_callback) { + new (data) cq_callback_data(shutdown_callback); +} + +static void cq_destroy_callback(void* data) { + cq_callback_data* cqd = static_cast(data); + cqd->~cq_callback_data(); +} + +grpc_cq_completion_type grpc_get_cq_completion_type(grpc_completion_queue* cq) { + return cq->vtable->cq_completion_type; +} + +int grpc_get_cq_poll_num(grpc_completion_queue* cq) { + int cur_num_polls; + gpr_mu_lock(cq->mu); + cur_num_polls = cq->num_polls; + gpr_mu_unlock(cq->mu); + return cur_num_polls; +} + +#ifndef NDEBUG +void grpc_cq_internal_ref(grpc_completion_queue* cq, const char* reason, + const char* file, int line) { + grpc_core::DebugLocation debug_location(file, line); +#else +void grpc_cq_internal_ref(grpc_completion_queue* cq) { + grpc_core::DebugLocation debug_location; + const char* reason = nullptr; +#endif + cq->owning_refs.Ref(debug_location, reason); +} + +static void on_pollset_shutdown_done(void* arg, grpc_error_handle /*error*/) { + grpc_completion_queue* cq = static_cast(arg); + GRPC_CQ_INTERNAL_UNREF(cq, "pollset_destroy"); +} + +#ifndef NDEBUG +void grpc_cq_internal_unref(grpc_completion_queue* cq, const char* reason, + const char* file, int line) { + grpc_core::DebugLocation debug_location(file, line); +#else +void grpc_cq_internal_unref(grpc_completion_queue* cq) { + grpc_core::DebugLocation debug_location; + const char* reason = nullptr; +#endif + if (GPR_UNLIKELY(cq->owning_refs.Unref(debug_location, reason))) { + cq->vtable->destroy(DATA_FROM_CQ(cq)); + cq->poller_vtable->destroy(POLLSET_FROM_CQ(cq)); +#ifndef NDEBUG + gpr_free(cq->outstanding_tags); +#endif + gpr_free(cq); + } +} + +#ifndef NDEBUG +static void cq_check_tag(grpc_completion_queue* cq, void* tag, bool lock_cq) { + int found = 0; + if (lock_cq) { + gpr_mu_lock(cq->mu); + } + + for (int i = 0; i < static_cast(cq->outstanding_tag_count); i++) { + if (cq->outstanding_tags[i] == tag) { + cq->outstanding_tag_count--; + std::swap(cq->outstanding_tags[i], + cq->outstanding_tags[cq->outstanding_tag_count]); + found = 1; + break; + } + } + + if (lock_cq) { + gpr_mu_unlock(cq->mu); + } + + GPR_ASSERT(found); +} +#else +static void cq_check_tag(grpc_completion_queue* /*cq*/, void* /*tag*/, + bool /*lock_cq*/) {} +#endif + +static bool cq_begin_op_for_next(grpc_completion_queue* cq, void* /*tag*/) { + cq_next_data* cqd = static_cast DATA_FROM_CQ(cq); + return grpc_core::IncrementIfNonzero(&cqd->pending_events); +} + +static bool cq_begin_op_for_pluck(grpc_completion_queue* cq, void* /*tag*/) { + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + return grpc_core::IncrementIfNonzero(&cqd->pending_events); +} + +static bool cq_begin_op_for_callback(grpc_completion_queue* cq, void* /*tag*/) { + cq_callback_data* cqd = static_cast DATA_FROM_CQ(cq); + return grpc_core::IncrementIfNonzero(&cqd->pending_events); +} + +bool 
grpc_cq_begin_op(grpc_completion_queue* cq, void* tag) { +#ifndef NDEBUG + gpr_mu_lock(cq->mu); + if (cq->outstanding_tag_count == cq->outstanding_tag_capacity) { + cq->outstanding_tag_capacity = + std::max(size_t(4), 2 * cq->outstanding_tag_capacity); + cq->outstanding_tags = static_cast(gpr_realloc( + cq->outstanding_tags, + sizeof(*cq->outstanding_tags) * cq->outstanding_tag_capacity)); + } + cq->outstanding_tags[cq->outstanding_tag_count++] = tag; + gpr_mu_unlock(cq->mu); +#endif + return cq->vtable->begin_op(cq, tag); +} + +/* Queue a GRPC_OP_COMPLETED operation to a completion queue (with a + * completion + * type of GRPC_CQ_NEXT) */ +static void cq_end_op_for_next( + grpc_completion_queue* cq, void* tag, grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), void* done_arg, + grpc_cq_completion* storage, bool /*internal*/) { + GPR_TIMER_SCOPE("cq_end_op_for_next", 0); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_api_trace) || + (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures) && + error != GRPC_ERROR_NONE)) { + std::string errmsg = grpc_error_std_string(error); + GRPC_API_TRACE( + "cq_end_op_for_next(cq=%p, tag=%p, error=%s, " + "done=%p, done_arg=%p, storage=%p)", + 6, (cq, tag, errmsg.c_str(), done, done_arg, storage)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures) && + error != GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, "Operation failed: tag=%p, error=%s", tag, + errmsg.c_str()); + } + } + cq_next_data* cqd = static_cast DATA_FROM_CQ(cq); + int is_success = (error == GRPC_ERROR_NONE); + + storage->tag = tag; + storage->done = done; + storage->done_arg = done_arg; + storage->next = static_cast(is_success); + + cq_check_tag(cq, tag, true); /* Used in debug builds only */ + + if (g_cached_cq == cq && g_cached_event == nullptr) { + g_cached_event = storage; + } else { + /* Add the completion to the queue */ + bool is_first = cqd->queue.Push(storage); + cqd->things_queued_ever.fetch_add(1, std::memory_order_relaxed); + /* Since we do not hold the cq lock here, it is important to do an 'acquire' + load here (instead of a 'no_barrier' load) to match with the release + store + (done via pending_events.fetch_sub(1, ACQ_REL)) in cq_shutdown_next + */ + if (cqd->pending_events.load(std::memory_order_acquire) != 1) { + /* Only kick if this is the first item queued */ + if (is_first) { + gpr_mu_lock(cq->mu); + grpc_error_handle kick_error = + cq->poller_vtable->kick(POLLSET_FROM_CQ(cq), nullptr); + gpr_mu_unlock(cq->mu); + + if (kick_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Kick failed: %s", + grpc_error_std_string(kick_error).c_str()); + GRPC_ERROR_UNREF(kick_error); + } + } + if (cqd->pending_events.fetch_sub(1, std::memory_order_acq_rel) == 1) { + GRPC_CQ_INTERNAL_REF(cq, "shutting_down"); + gpr_mu_lock(cq->mu); + cq_finish_shutdown_next(cq); + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down"); + } + } else { + GRPC_CQ_INTERNAL_REF(cq, "shutting_down"); + cqd->pending_events.store(0, std::memory_order_release); + gpr_mu_lock(cq->mu); + cq_finish_shutdown_next(cq); + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down"); + } + } + + GRPC_ERROR_UNREF(error); +} + +/* Queue a GRPC_OP_COMPLETED operation to a completion queue (with a + * completion + * type of GRPC_CQ_PLUCK) */ +static void cq_end_op_for_pluck( + grpc_completion_queue* cq, void* tag, grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), void* done_arg, + grpc_cq_completion* storage, bool /*internal*/) { + 
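// grpc_cq_begin_op() above and grpc_cq_end_op() below are always used as a
// pair by completion producers inside the core: begin_op reserves a pending
// event (and fails once shutdown has started), end_op publishes the result
// together with a grpc_cq_completion whose `done` callback releases the
// storage. A hedged sketch of that pairing using only declarations from this
// file's header; the function and type names here are illustrative, an
// ExecCtx must be on the stack when grpc_cq_end_op runs, and the real
// producers are the call and transport code rather than free functions:

#include <grpc/support/port_platform.h>

#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/surface/completion_queue.h"

namespace {

struct PendingOp {
  grpc_cq_completion completion;  // owned until the done callback fires
};

void PendingOpDone(void* arg, grpc_cq_completion* /*storage*/) {
  delete static_cast<PendingOp*>(arg);  // storage is no longer referenced
}

bool StartIllustrativeOp(grpc_completion_queue* cq, void* tag) {
  // Returns false if the queue is already shutting down.
  return grpc_cq_begin_op(cq, tag);
}

void FinishIllustrativeOp(grpc_completion_queue* cq, void* tag, bool ok) {
  grpc_core::ExecCtx exec_ctx;
  PendingOp* op = new PendingOp();
  grpc_cq_end_op(cq, tag, ok ? GRPC_ERROR_NONE : GRPC_ERROR_CANCELLED,
                 PendingOpDone, op, &op->completion, /*internal=*/false);
}

}  // namespace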
GPR_TIMER_SCOPE("cq_end_op_for_pluck", 0); + + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + int is_success = (error == GRPC_ERROR_NONE); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_api_trace) || + (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures) && + error != GRPC_ERROR_NONE)) { + std::string errmsg = grpc_error_std_string(error).c_str(); + GRPC_API_TRACE( + "cq_end_op_for_pluck(cq=%p, tag=%p, error=%s, " + "done=%p, done_arg=%p, storage=%p)", + 6, (cq, tag, errmsg.c_str(), done, done_arg, storage)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures) && + error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Operation failed: tag=%p, error=%s", tag, + errmsg.c_str()); + } + } + + storage->tag = tag; + storage->done = done; + storage->done_arg = done_arg; + storage->next = reinterpret_cast(&cqd->completed_head) | + static_cast(is_success); + + gpr_mu_lock(cq->mu); + cq_check_tag(cq, tag, false); /* Used in debug builds only */ + + /* Add to the list of completions */ + cqd->things_queued_ever.fetch_add(1, std::memory_order_relaxed); + cqd->completed_tail->next = + reinterpret_cast(storage) | (1u & cqd->completed_tail->next); + cqd->completed_tail = storage; + + if (cqd->pending_events.fetch_sub(1, std::memory_order_acq_rel) == 1) { + cq_finish_shutdown_pluck(cq); + gpr_mu_unlock(cq->mu); + } else { + grpc_pollset_worker* pluck_worker = nullptr; + for (int i = 0; i < cqd->num_pluckers; i++) { + if (cqd->pluckers[i].tag == tag) { + pluck_worker = *cqd->pluckers[i].worker; + break; + } + } + + grpc_error_handle kick_error = + cq->poller_vtable->kick(POLLSET_FROM_CQ(cq), pluck_worker); + gpr_mu_unlock(cq->mu); + if (kick_error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Kick failed: %s", + grpc_error_std_string(kick_error).c_str()); + GRPC_ERROR_UNREF(kick_error); + } + } + + GRPC_ERROR_UNREF(error); +} + +static void functor_callback(void* arg, grpc_error_handle error) { + auto* functor = static_cast(arg); + functor->functor_run(functor, error == GRPC_ERROR_NONE); +} + +/* Complete an event on a completion queue of type GRPC_CQ_CALLBACK */ +static void cq_end_op_for_callback( + grpc_completion_queue* cq, void* tag, grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), void* done_arg, + grpc_cq_completion* storage, bool internal) { + GPR_TIMER_SCOPE("cq_end_op_for_callback", 0); + + cq_callback_data* cqd = static_cast DATA_FROM_CQ(cq); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_api_trace) || + (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures) && + error != GRPC_ERROR_NONE)) { + std::string errmsg = grpc_error_std_string(error); + GRPC_API_TRACE( + "cq_end_op_for_callback(cq=%p, tag=%p, error=%s, " + "done=%p, done_arg=%p, storage=%p)", + 6, (cq, tag, errmsg.c_str(), done, done_arg, storage)); + if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_operation_failures) && + error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Operation failed: tag=%p, error=%s", tag, + errmsg.c_str()); + } + } + + // The callback-based CQ isn't really a queue at all and thus has no need + // for reserved storage. Invoke the done callback right away to release it. + done(done_arg, storage); + + cq_check_tag(cq, tag, true); /* Used in debug builds only */ + + if (cqd->pending_events.fetch_sub(1, std::memory_order_acq_rel) == 1) { + cq_finish_shutdown_callback(cq); + } + + // If possible, schedule the callback onto an existing thread-local + // ApplicationCallbackExecCtx, which is a work queue. This is possible for: + // 1. 
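// For a GRPC_CQ_CALLBACK queue the tag handed to cq_end_op_for_callback above
// is not opaque: it is treated as a grpc_completion_queue_functor, and the
// completion is delivered by running functor_run(functor, ok) rather than by
// queueing an event. A minimal functor, used here as the shutdown callback of
// a callback CQ (a sketch: assumes grpc_init() has already been called, and
// the ShutdownDone name is illustrative):

#include <grpc/grpc.h>

struct ShutdownDone : grpc_completion_queue_functor {
  ShutdownDone() {
    functor_run = &ShutdownDone::Run;  // runs exactly once, at shutdown
    inlineable = false;                // let gRPC choose the execution thread
  }
  static void Run(grpc_completion_queue_functor* self, int /*ok*/) {
    static_cast<ShutdownDone*>(self)->done = true;
  }
  bool done = false;
};

void CallbackCqLifetimeSketch() {
  // The functor must outlive the queue until it has run; static storage keeps
  // this sketch simple.
  static ShutdownDone shutdown_done;
  grpc_completion_queue* cq =
      grpc_completion_queue_create_for_callback(&shutdown_done, nullptr);
  grpc_completion_queue_shutdown(cq);  // shutdown_done.Run fires once drained
  grpc_completion_queue_destroy(cq);
}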
The callback is internally-generated and there is an ACEC available + // 2. The callback is marked inlineable and there is an ACEC available + // 3. We are already running in a background poller thread (which always has + // an ACEC available at the base of the stack). + auto* functor = static_cast(tag); + if (((internal || functor->inlineable) && + grpc_core::ApplicationCallbackExecCtx::Available()) || + grpc_iomgr_is_any_background_poller_thread()) { + grpc_core::ApplicationCallbackExecCtx::Enqueue(functor, + (error == GRPC_ERROR_NONE)); + GRPC_ERROR_UNREF(error); + return; + } + + // Schedule the callback on a closure if not internal or triggered + // from a background poller thread. + grpc_core::Executor::Run( + GRPC_CLOSURE_CREATE(functor_callback, functor, nullptr), error); +} + +void grpc_cq_end_op(grpc_completion_queue* cq, void* tag, + grpc_error_handle error, + void (*done)(void* done_arg, grpc_cq_completion* storage), + void* done_arg, grpc_cq_completion* storage, + bool internal) { + cq->vtable->end_op(cq, tag, error, done, done_arg, storage, internal); +} + +struct cq_is_finished_arg { + gpr_atm last_seen_things_queued_ever; + grpc_completion_queue* cq; + grpc_millis deadline; + grpc_cq_completion* stolen_completion; + void* tag; /* for pluck */ + bool first_loop; +}; +class ExecCtxNext : public grpc_core::ExecCtx { + public: + explicit ExecCtxNext(void* arg) + : ExecCtx(0), check_ready_to_finish_arg_(arg) {} + + bool CheckReadyToFinish() override { + cq_is_finished_arg* a = + static_cast(check_ready_to_finish_arg_); + grpc_completion_queue* cq = a->cq; + cq_next_data* cqd = static_cast DATA_FROM_CQ(cq); + GPR_ASSERT(a->stolen_completion == nullptr); + + intptr_t current_last_seen_things_queued_ever = + cqd->things_queued_ever.load(std::memory_order_relaxed); + + if (current_last_seen_things_queued_ever != + a->last_seen_things_queued_ever) { + a->last_seen_things_queued_ever = + cqd->things_queued_ever.load(std::memory_order_relaxed); + + /* Pop a cq_completion from the queue. Returns NULL if the queue is empty + * might return NULL in some cases even if the queue is not empty; but + * that + * is ok and doesn't affect correctness. 
Might effect the tail latencies a + * bit) */ + a->stolen_completion = cqd->queue.Pop(); + if (a->stolen_completion != nullptr) { + return true; + } + } + return !a->first_loop && a->deadline < grpc_core::ExecCtx::Get()->Now(); + } + + private: + void* check_ready_to_finish_arg_; +}; + +#ifndef NDEBUG +static void dump_pending_tags(grpc_completion_queue* cq) { + if (!GRPC_TRACE_FLAG_ENABLED(grpc_trace_pending_tags)) return; + std::vector parts; + parts.push_back("PENDING TAGS:"); + gpr_mu_lock(cq->mu); + for (size_t i = 0; i < cq->outstanding_tag_count; i++) { + parts.push_back(absl::StrFormat(" %p", cq->outstanding_tags[i])); + } + gpr_mu_unlock(cq->mu); + gpr_log(GPR_DEBUG, "%s", absl::StrJoin(parts, "").c_str()); +} +#else +static void dump_pending_tags(grpc_completion_queue* /*cq*/) {} +#endif + +static grpc_event cq_next(grpc_completion_queue* cq, gpr_timespec deadline, + void* reserved) { + GPR_TIMER_SCOPE("grpc_completion_queue_next", 0); + + grpc_event ret; + cq_next_data* cqd = static_cast DATA_FROM_CQ(cq); + + GRPC_API_TRACE( + "grpc_completion_queue_next(" + "cq=%p, " + "deadline=gpr_timespec { tv_sec: %" PRId64 + ", tv_nsec: %d, clock_type: %d }, " + "reserved=%p)", + 5, + (cq, deadline.tv_sec, deadline.tv_nsec, (int)deadline.clock_type, + reserved)); + GPR_ASSERT(!reserved); + + dump_pending_tags(cq); + + GRPC_CQ_INTERNAL_REF(cq, "next"); + + grpc_millis deadline_millis = grpc_timespec_to_millis_round_up(deadline); + cq_is_finished_arg is_finished_arg = { + cqd->things_queued_ever.load(std::memory_order_relaxed), + cq, + deadline_millis, + nullptr, + nullptr, + true}; + ExecCtxNext exec_ctx(&is_finished_arg); + for (;;) { + grpc_millis iteration_deadline = deadline_millis; + + if (is_finished_arg.stolen_completion != nullptr) { + grpc_cq_completion* c = is_finished_arg.stolen_completion; + is_finished_arg.stolen_completion = nullptr; + ret.type = GRPC_OP_COMPLETE; + ret.success = c->next & 1u; + ret.tag = c->tag; + c->done(c->done_arg, c); + break; + } + + grpc_cq_completion* c = cqd->queue.Pop(); + + if (c != nullptr) { + ret.type = GRPC_OP_COMPLETE; + ret.success = c->next & 1u; + ret.tag = c->tag; + c->done(c->done_arg, c); + break; + } else { + /* If c == NULL it means either the queue is empty OR in an transient + inconsistent state. If it is the latter, we shold do a 0-timeout poll + so that the thread comes back quickly from poll to make a second + attempt at popping. Not doing this can potentially deadlock this + thread forever (if the deadline is infinity) */ + if (cqd->queue.num_items() > 0) { + iteration_deadline = 0; + } + } + + if (cqd->pending_events.load(std::memory_order_acquire) == 0) { + /* Before returning, check if the queue has any items left over (since + MultiProducerSingleConsumerQueue::Pop() can sometimes return NULL + even if the queue is not empty. If so, keep retrying but do not + return GRPC_QUEUE_SHUTDOWN */ + if (cqd->queue.num_items() > 0) { + /* Go to the beginning of the loop. 
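// cq_next() above is the engine behind the public grpc_completion_queue_next()
// API. From the application side the usual shape is a drain loop keyed on the
// three event types; a sketch (tag meaning is application-defined, and the
// one-second poll deadline is an arbitrary choice):

#include <grpc/grpc.h>
#include <grpc/support/time.h>

void NextDrainLoopSketch(grpc_completion_queue* cq) {
  for (;;) {
    grpc_event ev = grpc_completion_queue_next(
        cq,
        gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
                     gpr_time_from_seconds(1, GPR_TIMESPAN)),
        nullptr);
    switch (ev.type) {
      case GRPC_OP_COMPLETE:
        // ev.tag names the finished operation; ev.success is its 1-bit status.
        break;
      case GRPC_QUEUE_TIMEOUT:
        break;  // nothing became ready before the deadline; loop and re-poll
      case GRPC_QUEUE_SHUTDOWN:
        return;  // queue fully drained after grpc_completion_queue_shutdown()
    }
  }
}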
No point doing a poll because + (cq->shutdown == true) is only possible when there is no pending + work (i.e cq->pending_events == 0) and any outstanding completion + events should have already been queued on this cq */ + continue; + } + + ret.type = GRPC_QUEUE_SHUTDOWN; + ret.success = 0; + break; + } + + if (!is_finished_arg.first_loop && + grpc_core::ExecCtx::Get()->Now() >= deadline_millis) { + ret.type = GRPC_QUEUE_TIMEOUT; + ret.success = 0; + dump_pending_tags(cq); + break; + } + + /* The main polling work happens in grpc_pollset_work */ + gpr_mu_lock(cq->mu); + cq->num_polls++; + grpc_error_handle err = cq->poller_vtable->work( + POLLSET_FROM_CQ(cq), nullptr, iteration_deadline); + gpr_mu_unlock(cq->mu); + + if (err != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Completion queue next failed: %s", + grpc_error_std_string(err).c_str()); + GRPC_ERROR_UNREF(err); + if (err == GRPC_ERROR_CANCELLED) { + ret.type = GRPC_QUEUE_SHUTDOWN; + } else { + ret.type = GRPC_QUEUE_TIMEOUT; + } + ret.success = 0; + dump_pending_tags(cq); + break; + } + is_finished_arg.first_loop = false; + } + + if (cqd->queue.num_items() > 0 && + cqd->pending_events.load(std::memory_order_acquire) > 0) { + gpr_mu_lock(cq->mu); + (void)cq->poller_vtable->kick(POLLSET_FROM_CQ(cq), nullptr); + gpr_mu_unlock(cq->mu); + } + + GRPC_SURFACE_TRACE_RETURNED_EVENT(cq, &ret); + GRPC_CQ_INTERNAL_UNREF(cq, "next"); + + GPR_ASSERT(is_finished_arg.stolen_completion == nullptr); + + return ret; +} + +/* Finishes the completion queue shutdown. This means that there are no more + completion events / tags expected from the completion queue + - Must be called under completion queue lock + - Must be called only once in completion queue's lifetime + - grpc_completion_queue_shutdown() MUST have been called before calling + this function */ +static void cq_finish_shutdown_next(grpc_completion_queue* cq) { + cq_next_data* cqd = static_cast DATA_FROM_CQ(cq); + + GPR_ASSERT(cqd->shutdown_called); + GPR_ASSERT(cqd->pending_events.load(std::memory_order_relaxed) == 0); + + cq->poller_vtable->shutdown(POLLSET_FROM_CQ(cq), &cq->pollset_shutdown_done); +} + +static void cq_shutdown_next(grpc_completion_queue* cq) { + cq_next_data* cqd = static_cast DATA_FROM_CQ(cq); + + /* Need an extra ref for cq here because: + * We call cq_finish_shutdown_next() below, that would call pollset shutdown. + * Pollset shutdown decrements the cq ref count which can potentially destroy + * the cq (if that happens to be the last ref). 
+ * Creating an extra ref here prevents the cq from getting destroyed while + * this function is still active */ + GRPC_CQ_INTERNAL_REF(cq, "shutting_down"); + gpr_mu_lock(cq->mu); + if (cqd->shutdown_called) { + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down"); + return; + } + cqd->shutdown_called = true; + /* Doing acq/release fetch_sub here to match with + * cq_begin_op_for_next and cq_end_op_for_next functions which read/write + * on this counter without necessarily holding a lock on cq */ + if (cqd->pending_events.fetch_sub(1, std::memory_order_acq_rel) == 1) { + cq_finish_shutdown_next(cq); + } + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down"); +} + +grpc_event grpc_completion_queue_next(grpc_completion_queue* cq, + gpr_timespec deadline, void* reserved) { + return cq->vtable->next(cq, deadline, reserved); +} + +static int add_plucker(grpc_completion_queue* cq, void* tag, + grpc_pollset_worker** worker) { + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + if (cqd->num_pluckers == GRPC_MAX_COMPLETION_QUEUE_PLUCKERS) { + return 0; + } + cqd->pluckers[cqd->num_pluckers].tag = tag; + cqd->pluckers[cqd->num_pluckers].worker = worker; + cqd->num_pluckers++; + return 1; +} + +static void del_plucker(grpc_completion_queue* cq, void* tag, + grpc_pollset_worker** worker) { + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + for (int i = 0; i < cqd->num_pluckers; i++) { + if (cqd->pluckers[i].tag == tag && cqd->pluckers[i].worker == worker) { + cqd->num_pluckers--; + std::swap(cqd->pluckers[i], cqd->pluckers[cqd->num_pluckers]); + return; + } + } + GPR_UNREACHABLE_CODE(return ); +} + +class ExecCtxPluck : public grpc_core::ExecCtx { + public: + explicit ExecCtxPluck(void* arg) + : ExecCtx(0), check_ready_to_finish_arg_(arg) {} + + bool CheckReadyToFinish() override { + cq_is_finished_arg* a = + static_cast(check_ready_to_finish_arg_); + grpc_completion_queue* cq = a->cq; + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + + GPR_ASSERT(a->stolen_completion == nullptr); + gpr_atm current_last_seen_things_queued_ever = + cqd->things_queued_ever.load(std::memory_order_relaxed); + if (current_last_seen_things_queued_ever != + a->last_seen_things_queued_ever) { + gpr_mu_lock(cq->mu); + a->last_seen_things_queued_ever = + cqd->things_queued_ever.load(std::memory_order_relaxed); + grpc_cq_completion* c; + grpc_cq_completion* prev = &cqd->completed_head; + while ((c = reinterpret_cast( + prev->next & ~static_cast(1))) != + &cqd->completed_head) { + if (c->tag == a->tag) { + prev->next = (prev->next & static_cast(1)) | + (c->next & ~static_cast(1)); + if (c == cqd->completed_tail) { + cqd->completed_tail = prev; + } + gpr_mu_unlock(cq->mu); + a->stolen_completion = c; + return true; + } + prev = c; + } + gpr_mu_unlock(cq->mu); + } + return !a->first_loop && a->deadline < grpc_core::ExecCtx::Get()->Now(); + } + + private: + void* check_ready_to_finish_arg_; +}; + +static grpc_event cq_pluck(grpc_completion_queue* cq, void* tag, + gpr_timespec deadline, void* reserved) { + GPR_TIMER_SCOPE("grpc_completion_queue_pluck", 0); + + grpc_event ret; + grpc_cq_completion* c; + grpc_cq_completion* prev; + grpc_pollset_worker* worker = nullptr; + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + + if (GRPC_TRACE_FLAG_ENABLED(grpc_cq_pluck_trace)) { + GRPC_API_TRACE( + "grpc_completion_queue_pluck(" + "cq=%p, tag=%p, " + "deadline=gpr_timespec { tv_sec: %" PRId64 + ", tv_nsec: %d, clock_type: %d }, " + "reserved=%p)", + 6, + (cq, tag, deadline.tv_sec, 
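// cq_pluck() waits for one specific tag rather than for whatever completes
// next, which is how the synchronous API picks up exactly the completion it
// just asked for; at most GRPC_MAX_COMPLETION_QUEUE_PLUCKERS threads may pluck
// the same queue concurrently. A sketch of the calling pattern (assumes a
// queue created with grpc_completion_queue_create_for_pluck):

#include <grpc/grpc.h>
#include <grpc/support/time.h>

bool WaitForTagSketch(grpc_completion_queue* pluck_cq, void* tag) {
  grpc_event ev = grpc_completion_queue_pluck(
      pluck_cq, tag, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr);
  // With an infinite deadline the only outcomes are the requested completion
  // or queue shutdown.
  return ev.type == GRPC_OP_COMPLETE && ev.success != 0;
}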
deadline.tv_nsec, (int)deadline.clock_type, + reserved)); + } + GPR_ASSERT(!reserved); + + dump_pending_tags(cq); + + GRPC_CQ_INTERNAL_REF(cq, "pluck"); + gpr_mu_lock(cq->mu); + grpc_millis deadline_millis = grpc_timespec_to_millis_round_up(deadline); + cq_is_finished_arg is_finished_arg = { + cqd->things_queued_ever.load(std::memory_order_relaxed), + cq, + deadline_millis, + nullptr, + tag, + true}; + ExecCtxPluck exec_ctx(&is_finished_arg); + for (;;) { + if (is_finished_arg.stolen_completion != nullptr) { + gpr_mu_unlock(cq->mu); + c = is_finished_arg.stolen_completion; + is_finished_arg.stolen_completion = nullptr; + ret.type = GRPC_OP_COMPLETE; + ret.success = c->next & 1u; + ret.tag = c->tag; + c->done(c->done_arg, c); + break; + } + prev = &cqd->completed_head; + while ((c = reinterpret_cast( + prev->next & ~static_cast(1))) != + &cqd->completed_head) { + if (c->tag == tag) { + prev->next = (prev->next & static_cast(1)) | + (c->next & ~static_cast(1)); + if (c == cqd->completed_tail) { + cqd->completed_tail = prev; + } + gpr_mu_unlock(cq->mu); + ret.type = GRPC_OP_COMPLETE; + ret.success = c->next & 1u; + ret.tag = c->tag; + c->done(c->done_arg, c); + goto done; + } + prev = c; + } + if (cqd->shutdown.load(std::memory_order_relaxed)) { + gpr_mu_unlock(cq->mu); + ret.type = GRPC_QUEUE_SHUTDOWN; + ret.success = 0; + break; + } + if (!add_plucker(cq, tag, &worker)) { + gpr_log(GPR_DEBUG, + "Too many outstanding grpc_completion_queue_pluck calls: maximum " + "is %d", + GRPC_MAX_COMPLETION_QUEUE_PLUCKERS); + gpr_mu_unlock(cq->mu); + /* TODO(ctiller): should we use a different result here */ + ret.type = GRPC_QUEUE_TIMEOUT; + ret.success = 0; + dump_pending_tags(cq); + break; + } + if (!is_finished_arg.first_loop && + grpc_core::ExecCtx::Get()->Now() >= deadline_millis) { + del_plucker(cq, tag, &worker); + gpr_mu_unlock(cq->mu); + ret.type = GRPC_QUEUE_TIMEOUT; + ret.success = 0; + dump_pending_tags(cq); + break; + } + cq->num_polls++; + grpc_error_handle err = + cq->poller_vtable->work(POLLSET_FROM_CQ(cq), &worker, deadline_millis); + if (err != GRPC_ERROR_NONE) { + del_plucker(cq, tag, &worker); + gpr_mu_unlock(cq->mu); + gpr_log(GPR_ERROR, "Completion queue pluck failed: %s", + grpc_error_std_string(err).c_str()); + GRPC_ERROR_UNREF(err); + ret.type = GRPC_QUEUE_TIMEOUT; + ret.success = 0; + dump_pending_tags(cq); + break; + } + is_finished_arg.first_loop = false; + del_plucker(cq, tag, &worker); + } +done: + GRPC_SURFACE_TRACE_RETURNED_EVENT(cq, &ret); + GRPC_CQ_INTERNAL_UNREF(cq, "pluck"); + + GPR_ASSERT(is_finished_arg.stolen_completion == nullptr); + + return ret; +} + +grpc_event grpc_completion_queue_pluck(grpc_completion_queue* cq, void* tag, + gpr_timespec deadline, void* reserved) { + return cq->vtable->pluck(cq, tag, deadline, reserved); +} + +static void cq_finish_shutdown_pluck(grpc_completion_queue* cq) { + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + + GPR_ASSERT(cqd->shutdown_called); + GPR_ASSERT(!cqd->shutdown.load(std::memory_order_relaxed)); + cqd->shutdown.store(true, std::memory_order_relaxed); + + cq->poller_vtable->shutdown(POLLSET_FROM_CQ(cq), &cq->pollset_shutdown_done); +} + +/* NOTE: This function is almost exactly identical to cq_shutdown_next() but + * merging them is a bit tricky and probably not worth it */ +static void cq_shutdown_pluck(grpc_completion_queue* cq) { + cq_pluck_data* cqd = static_cast DATA_FROM_CQ(cq); + + /* Need an extra ref for cq here because: + * We call cq_finish_shutdown_pluck() below, that would call pollset shutdown. 
+ * Pollset shutdown decrements the cq ref count which can potentially destroy + * the cq (if that happens to be the last ref). + * Creating an extra ref here prevents the cq from getting destroyed while + * this function is still active */ + GRPC_CQ_INTERNAL_REF(cq, "shutting_down (pluck cq)"); + gpr_mu_lock(cq->mu); + if (cqd->shutdown_called) { + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down (pluck cq)"); + return; + } + cqd->shutdown_called = true; + if (cqd->pending_events.fetch_sub(1, std::memory_order_acq_rel) == 1) { + cq_finish_shutdown_pluck(cq); + } + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down (pluck cq)"); +} + +static void cq_finish_shutdown_callback(grpc_completion_queue* cq) { + cq_callback_data* cqd = static_cast DATA_FROM_CQ(cq); + auto* callback = cqd->shutdown_callback; + + GPR_ASSERT(cqd->shutdown_called); + + cq->poller_vtable->shutdown(POLLSET_FROM_CQ(cq), &cq->pollset_shutdown_done); + if (grpc_iomgr_is_any_background_poller_thread()) { + grpc_core::ApplicationCallbackExecCtx::Enqueue(callback, true); + return; + } + + // Schedule the callback on a closure if not internal or triggered + // from a background poller thread. + grpc_core::Executor::Run( + GRPC_CLOSURE_CREATE(functor_callback, callback, nullptr), + GRPC_ERROR_NONE); +} + +static void cq_shutdown_callback(grpc_completion_queue* cq) { + cq_callback_data* cqd = static_cast DATA_FROM_CQ(cq); + + /* Need an extra ref for cq here because: + * We call cq_finish_shutdown_callback() below, which calls pollset shutdown. + * Pollset shutdown decrements the cq ref count which can potentially destroy + * the cq (if that happens to be the last ref). + * Creating an extra ref here prevents the cq from getting destroyed while + * this function is still active */ + GRPC_CQ_INTERNAL_REF(cq, "shutting_down (callback cq)"); + gpr_mu_lock(cq->mu); + if (cqd->shutdown_called) { + gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down (callback cq)"); + return; + } + cqd->shutdown_called = true; + if (cqd->pending_events.fetch_sub(1, std::memory_order_acq_rel) == 1) { + gpr_mu_unlock(cq->mu); + cq_finish_shutdown_callback(cq); + } else { + gpr_mu_unlock(cq->mu); + } + GRPC_CQ_INTERNAL_UNREF(cq, "shutting_down (callback cq)"); +} + +/* Shutdown simply drops a ref that we reserved at creation time; if we drop + to zero here, then enter shutdown mode and wake up any waiters */ +void grpc_completion_queue_shutdown(grpc_completion_queue* cq) { + GPR_TIMER_SCOPE("grpc_completion_queue_shutdown", 0); + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_completion_queue_shutdown(cq=%p)", 1, (cq)); + cq->vtable->shutdown(cq); +} + +void grpc_completion_queue_destroy(grpc_completion_queue* cq) { + GPR_TIMER_SCOPE("grpc_completion_queue_destroy", 0); + GRPC_API_TRACE("grpc_completion_queue_destroy(cq=%p)", 1, (cq)); + grpc_completion_queue_shutdown(cq); + + grpc_core::ExecCtx exec_ctx; + GRPC_CQ_INTERNAL_UNREF(cq, "destroy"); +} + +grpc_pollset* grpc_cq_pollset(grpc_completion_queue* cq) { + return cq->poller_vtable->can_get_pollset ? 
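// grpc_completion_queue_shutdown() above drops the reference reserved at
// creation time; a next-type queue should then be drained until it reports
// GRPC_QUEUE_SHUTDOWN before grpc_completion_queue_destroy() is called, so
// that no completion is lost. The canonical teardown order, as a sketch:

#include <grpc/grpc.h>
#include <grpc/support/time.h>

void ShutdownAndDestroySketch(grpc_completion_queue* cq) {
  grpc_completion_queue_shutdown(cq);  // no new operations may begin now
  grpc_event ev;
  do {
    // Drain everything that was already queued.
    ev = grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_MONOTONIC),
                                    nullptr);
  } while (ev.type != GRPC_QUEUE_SHUTDOWN);
  grpc_completion_queue_destroy(cq);
}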
POLLSET_FROM_CQ(cq) : nullptr; +} + +bool grpc_cq_can_listen(grpc_completion_queue* cq) { + return cq->poller_vtable->can_listen; +} diff --git a/src/core/lib/surface/completion_queue_factory.cc b/src/core/lib/surface/completion_queue_factory.cc new file mode 100644 index 00000000..80aaba24 --- /dev/null +++ b/src/core/lib/surface/completion_queue_factory.cc @@ -0,0 +1,88 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/completion_queue_factory.h" + +#include + +#include "src/core/lib/surface/completion_queue.h" + +/* + * == Default completion queue factory implementation == + */ + +static grpc_completion_queue* default_create( + const grpc_completion_queue_factory* /*factory*/, + const grpc_completion_queue_attributes* attr) { + return grpc_completion_queue_create_internal( + attr->cq_completion_type, attr->cq_polling_type, attr->cq_shutdown_cb); +} + +static grpc_completion_queue_factory_vtable default_vtable = {default_create}; + +static const grpc_completion_queue_factory g_default_cq_factory = { + "Default Factory", nullptr, &default_vtable}; + +/* + * == Completion queue factory APIs + */ + +const grpc_completion_queue_factory* grpc_completion_queue_factory_lookup( + const grpc_completion_queue_attributes* attributes) { + GPR_ASSERT(attributes->version >= 1 && + attributes->version <= GRPC_CQ_CURRENT_VERSION); + + /* The default factory can handle version 1 of the attributes structure. 
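// The factory indirection above exists so callers can pick an explicit
// completion-type / polling-type combination instead of going through the
// create_for_* convenience wrappers. A sketch using the public lookup+create
// pair (the non-listening polling mode is just an example choice):

#include <grpc/grpc.h>

grpc_completion_queue* CreateNonListeningNextCqSketch() {
  grpc_completion_queue_attributes attr;
  attr.version = GRPC_CQ_CURRENT_VERSION;
  attr.cq_completion_type = GRPC_CQ_NEXT;
  attr.cq_polling_type = GRPC_CQ_NON_LISTENING;  // polls but never listens
  attr.cq_shutdown_cb = nullptr;                 // only used by callback CQs
  const grpc_completion_queue_factory* factory =
      grpc_completion_queue_factory_lookup(&attr);
  return grpc_completion_queue_create(factory, &attr, nullptr);
}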
We + may have to change this as more fields are added to the structure */ + return &g_default_cq_factory; +} + +/* + * == Completion queue creation APIs == + */ + +grpc_completion_queue* grpc_completion_queue_create_for_next(void* reserved) { + GPR_ASSERT(!reserved); + grpc_completion_queue_attributes attr = {1, GRPC_CQ_NEXT, + GRPC_CQ_DEFAULT_POLLING, nullptr}; + return g_default_cq_factory.vtable->create(&g_default_cq_factory, &attr); +} + +grpc_completion_queue* grpc_completion_queue_create_for_pluck(void* reserved) { + GPR_ASSERT(!reserved); + grpc_completion_queue_attributes attr = {1, GRPC_CQ_PLUCK, + GRPC_CQ_DEFAULT_POLLING, nullptr}; + return g_default_cq_factory.vtable->create(&g_default_cq_factory, &attr); +} + +grpc_completion_queue* grpc_completion_queue_create_for_callback( + grpc_completion_queue_functor* shutdown_callback, void* reserved) { + GPR_ASSERT(!reserved); + grpc_completion_queue_attributes attr = { + 2, GRPC_CQ_CALLBACK, GRPC_CQ_DEFAULT_POLLING, shutdown_callback}; + return g_default_cq_factory.vtable->create(&g_default_cq_factory, &attr); +} + +grpc_completion_queue* grpc_completion_queue_create( + const grpc_completion_queue_factory* factory, + const grpc_completion_queue_attributes* attr, void* reserved) { + GPR_ASSERT(!reserved); + return factory->vtable->create(factory, attr); +} diff --git a/src/core/lib/surface/event_string.cc b/src/core/lib/surface/event_string.cc new file mode 100644 index 00000000..da40734a --- /dev/null +++ b/src/core/lib/surface/event_string.cc @@ -0,0 +1,62 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/event_string.h" + +#include + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" + +static void addhdr(grpc_event* ev, std::vector* buf) { + buf->push_back(absl::StrFormat("tag:%p", ev->tag)); +} + +static const char* errstr(int success) { return success ? "OK" : "ERROR"; } + +static void adderr(int success, std::vector* buf) { + buf->push_back(absl::StrFormat(" %s", errstr(success))); +} + +std::string grpc_event_string(grpc_event* ev) { + if (ev == nullptr) return "null"; + std::vector out; + switch (ev->type) { + case GRPC_QUEUE_TIMEOUT: + out.push_back("QUEUE_TIMEOUT"); + break; + case GRPC_QUEUE_SHUTDOWN: + out.push_back("QUEUE_SHUTDOWN"); + break; + case GRPC_OP_COMPLETE: + out.push_back("OP_COMPLETE: "); + addhdr(ev, &out); + adderr(ev->success, &out); + break; + } + return absl::StrJoin(out, ""); +} diff --git a/src/core/lib/surface/init.cc b/src/core/lib/surface/init.cc new file mode 100644 index 00000000..3bc610c2 --- /dev/null +++ b/src/core/lib/surface/init.cc @@ -0,0 +1,220 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/init.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channelz_registry.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/fork.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/http/parser.h" +#include "src/core/lib/iomgr/call_combiner.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/resource_quota.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/lame_client.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/transport/bdp_estimator.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/transport_impl.h" + +/* (generated) built in registry of plugins */ +extern void grpc_register_built_in_plugins(void); + +#define MAX_PLUGINS 128 + +static gpr_once g_basic_init = GPR_ONCE_INIT; +static grpc_core::Mutex* g_init_mu; +static int g_initializations ABSL_GUARDED_BY(g_init_mu) = 0; +static grpc_core::CondVar* g_shutting_down_cv; +static bool g_shutting_down ABSL_GUARDED_BY(g_init_mu) = false; + +static void do_basic_init(void) { + gpr_log_verbosity_init(); + g_init_mu = new grpc_core::Mutex(); + g_shutting_down_cv = new grpc_core::CondVar(); + grpc_register_built_in_plugins(); + grpc_cq_global_init(); + grpc_core::grpc_executor_global_init(); + gpr_time_init(); +} + +typedef struct grpc_plugin { + void (*init)(); + void (*destroy)(); +} grpc_plugin; + +static grpc_plugin g_all_of_the_plugins[MAX_PLUGINS]; +static int g_number_of_plugins = 0; + +void grpc_register_plugin(void (*init)(void), void (*destroy)(void)) { + GRPC_API_TRACE("grpc_register_plugin(init=%p, destroy=%p)", 2, + ((void*)(intptr_t)init, (void*)(intptr_t)destroy)); + GPR_ASSERT(g_number_of_plugins != MAX_PLUGINS); + g_all_of_the_plugins[g_number_of_plugins].init = init; + g_all_of_the_plugins[g_number_of_plugins].destroy = destroy; + g_number_of_plugins++; +} + +void grpc_init(void) { + gpr_once_init(&g_basic_init, do_basic_init); + + grpc_core::MutexLock lock(g_init_mu); + if (++g_initializations == 1) { + if (g_shutting_down) { + g_shutting_down = false; + g_shutting_down_cv->SignalAll(); + } + grpc_core::Fork::GlobalInit(); + grpc_fork_handlers_auto_register(); + grpc_stats_init(); + grpc_slice_intern_init(); + grpc_mdctx_global_init(); + grpc_core::channelz::ChannelzRegistry::Init(); + grpc_security_pre_init(); + grpc_core::ApplicationCallbackExecCtx::GlobalInit(); + grpc_core::ExecCtx::GlobalInit(); + grpc_iomgr_init(); + gpr_timers_global_init(); + for (int i = 
0; i < g_number_of_plugins; i++) { + if (g_all_of_the_plugins[i].init != nullptr) { + g_all_of_the_plugins[i].init(); + } + } + grpc_tracer_init(); + grpc_iomgr_start(); + } + + GRPC_API_TRACE("grpc_init(void)", 0, ()); +} + +void grpc_shutdown_internal_locked(void) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(g_init_mu) { + int i; + { + grpc_core::ExecCtx exec_ctx(0); + grpc_iomgr_shutdown_background_closure(); + { + grpc_timer_manager_set_threading(false); // shutdown timer_manager thread + for (i = g_number_of_plugins; i >= 0; i--) { + if (g_all_of_the_plugins[i].destroy != nullptr) { + g_all_of_the_plugins[i].destroy(); + } + } + } + grpc_iomgr_shutdown(); + gpr_timers_global_destroy(); + grpc_tracer_shutdown(); + grpc_mdctx_global_shutdown(); + grpc_slice_intern_shutdown(); + grpc_core::channelz::ChannelzRegistry::Shutdown(); + grpc_stats_shutdown(); + grpc_core::Fork::GlobalShutdown(); + } + grpc_core::ExecCtx::GlobalShutdown(); + grpc_core::ApplicationCallbackExecCtx::GlobalShutdown(); + g_shutting_down = false; + g_shutting_down_cv->SignalAll(); +} + +void grpc_shutdown_internal(void* /*ignored*/) { + GRPC_API_TRACE("grpc_shutdown_internal", 0, ()); + grpc_core::MutexLock lock(g_init_mu); + // We have released lock from the shutdown thread and it is possible that + // another grpc_init has been called, and do nothing if that is the case. + if (--g_initializations != 0) { + return; + } + grpc_shutdown_internal_locked(); +} + +void grpc_shutdown(void) { + GRPC_API_TRACE("grpc_shutdown(void)", 0, ()); + grpc_core::MutexLock lock(g_init_mu); + + if (--g_initializations == 0) { + grpc_core::ApplicationCallbackExecCtx* acec = + grpc_core::ApplicationCallbackExecCtx::Get(); + if (!grpc_iomgr_is_any_background_poller_thread() && + (acec == nullptr || + (acec->Flags() & GRPC_APP_CALLBACK_EXEC_CTX_FLAG_IS_INTERNAL_THREAD) == + 0)) { + // just run clean-up when this is called on non-executor thread. + gpr_log(GPR_DEBUG, "grpc_shutdown starts clean-up now"); + g_shutting_down = true; + grpc_shutdown_internal_locked(); + } else { + // spawn a detached thread to do the actual clean up in case we are + // currently in an executor thread. + gpr_log(GPR_DEBUG, "grpc_shutdown spawns clean-up thread"); + g_initializations++; + g_shutting_down = true; + grpc_core::Thread cleanup_thread( + "grpc_shutdown", grpc_shutdown_internal, nullptr, nullptr, + grpc_core::Thread::Options().set_joinable(false).set_tracked(false)); + cleanup_thread.Start(); + } + } +} + +void grpc_shutdown_blocking(void) { + GRPC_API_TRACE("grpc_shutdown_blocking(void)", 0, ()); + grpc_core::MutexLock lock(g_init_mu); + if (--g_initializations == 0) { + g_shutting_down = true; + grpc_shutdown_internal_locked(); + } +} + +int grpc_is_initialized(void) { + int r; + gpr_once_init(&g_basic_init, do_basic_init); + grpc_core::MutexLock lock(g_init_mu); + r = g_initializations > 0; + return r; +} + +void grpc_maybe_wait_for_async_shutdown(void) { + gpr_once_init(&g_basic_init, do_basic_init); + grpc_core::MutexLock lock(g_init_mu); + while (g_shutting_down) { + g_shutting_down_cv->Wait(g_init_mu); + } +} diff --git a/src/core/lib/surface/init_secure.cc b/src/core/lib/surface/init_secure.cc new file mode 100644 index 00000000..cd336fd4 --- /dev/null +++ b/src/core/lib/surface/init_secure.cc @@ -0,0 +1,103 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
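// grpc_init()/grpc_shutdown() above are reference counted, and plugin hooks
// registered via grpc_register_plugin() run inside the first init and the
// last shutdown. A sketch of the intended call pattern (MyPluginInit and
// MyPluginDestroy are hypothetical hooks, and
// grpc_maybe_wait_for_async_shutdown() is the helper defined above for
// waiting out an asynchronous clean-up):

#include <grpc/grpc.h>
#include <grpc/support/log.h>

static void MyPluginInit() { gpr_log(GPR_INFO, "plugin init"); }
static void MyPluginDestroy() { gpr_log(GPR_INFO, "plugin destroy"); }

void InitShutdownSketch() {
  grpc_register_plugin(MyPluginInit, MyPluginDestroy);  // before the first init
  grpc_init();      // 0 -> 1: performs the full one-time initialization
  grpc_init();      // 1 -> 2: cheap, only bumps the count
  GPR_ASSERT(grpc_is_initialized());
  grpc_shutdown();  // 2 -> 1: the library stays up
  grpc_shutdown();  // 1 -> 0: clean-up runs, possibly on a detached thread
  grpc_maybe_wait_for_async_shutdown();  // block until that clean-up finishes
}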
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/security/authorization/sdk_server_authz_filter.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/plugin/plugin_credentials.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/security/transport/auth_filters.h" +#include "src/core/lib/security/transport/secure_endpoint.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "src/core/lib/surface/init.h" +#include "src/core/tsi/transport_security_interface.h" + +void grpc_security_pre_init(void) {} + +static bool maybe_prepend_client_auth_filter( + grpc_channel_stack_builder* builder) { + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (args) { + for (size_t i = 0; i < args->num_args; i++) { + if (0 == strcmp(GRPC_ARG_SECURITY_CONNECTOR, args->args[i].key)) { + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_client_auth_filter, nullptr, nullptr); + } + } + } + return true; +} + +static bool maybe_prepend_server_auth_filter( + grpc_channel_stack_builder* builder) { + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (args) { + for (size_t i = 0; i < args->num_args; i++) { + if (0 == strcmp(GRPC_SERVER_CREDENTIALS_ARG, args->args[i].key)) { + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_server_auth_filter, nullptr, nullptr); + } + } + } + return true; +} + +static bool maybe_prepend_sdk_server_authz_filter( + grpc_channel_stack_builder* builder) { + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + const auto* provider = + grpc_channel_args_find_pointer( + args, GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER); + if (provider != nullptr) { + return grpc_channel_stack_builder_prepend_filter( + builder, &grpc_core::SdkServerAuthzFilter::kFilterVtable, nullptr, + nullptr); + } + return true; +} + +namespace grpc_core { +void RegisterSecurityFilters(CoreConfiguration::Builder* builder) { + // Register the auth client with a priority < INT_MAX to allow the authority + // filter -on which the auth filter depends- to be higher on the channel + // stack. + builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, INT_MAX - 1, + maybe_prepend_client_auth_filter); + builder->channel_init()->RegisterStage(GRPC_CLIENT_DIRECT_CHANNEL, + INT_MAX - 1, + maybe_prepend_client_auth_filter); + builder->channel_init()->RegisterStage(GRPC_SERVER_CHANNEL, INT_MAX - 1, + maybe_prepend_server_auth_filter); + // Register the SdkServerAuthzFilter with a priority less than + // server_auth_filter to allow server_auth_filter on which the sdk filter + // depends on to be higher on the channel stack. 
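// The maybe_prepend_* stage functions above all follow the same recipe: scan
// the builder's channel args for a marker key (GRPC_ARG_SECURITY_CONNECTOR,
// GRPC_SERVER_CREDENTIALS_ARG, or the authorization-policy provider) and, only
// if it is present, prepend the corresponding filter. The scan itself is just
// this loop over grpc_channel_args (illustrative standalone helper, not part
// of this file):

#include <string.h>

#include <grpc/grpc.h>

static bool ChannelArgsContainKeySketch(const grpc_channel_args* args,
                                        const char* key) {
  if (args == nullptr) return false;
  for (size_t i = 0; i < args->num_args; i++) {
    if (strcmp(args->args[i].key, key) == 0) return true;
  }
  return false;
}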
+ builder->channel_init()->RegisterStage(GRPC_SERVER_CHANNEL, INT_MAX - 2, + maybe_prepend_sdk_server_authz_filter); +} +} // namespace grpc_core diff --git a/src/core/lib/surface/init_unsecure.cc b/src/core/lib/surface/init_unsecure.cc new file mode 100644 index 00000000..c9e1688d --- /dev/null +++ b/src/core/lib/surface/init_unsecure.cc @@ -0,0 +1,27 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/surface/init.h" + +void grpc_security_pre_init(void) {} + +void grpc_register_security_filters(void) {} diff --git a/src/core/lib/surface/lame_client.cc b/src/core/lib/surface/lame_client.cc new file mode 100644 index 00000000..a1fbd069 --- /dev/null +++ b/src/core/lib/surface/lame_client.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/surface/lame_client.h" + +#include + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/static_metadata.h" + +#define GRPC_ARG_LAME_FILTER_ERROR "grpc.lame_filter_error" + +namespace grpc_core { + +namespace { + +struct ChannelData { + explicit ChannelData(grpc_channel_element_args* args) + : state_tracker("lame_channel", GRPC_CHANNEL_SHUTDOWN) { + grpc_error_handle* err = grpc_channel_args_find_pointer( + args->channel_args, GRPC_ARG_LAME_FILTER_ERROR); + if (err != nullptr) error = GRPC_ERROR_REF(*err); + } + + ~ChannelData() { GRPC_ERROR_UNREF(error); } + + grpc_error_handle error = GRPC_ERROR_NONE; + Mutex mu; + ConnectivityStateTracker state_tracker; +}; + +struct CallData { + CallCombiner* call_combiner; +}; + +static void lame_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + CallData* calld = static_cast(elem->call_data); + ChannelData* chand = static_cast(elem->channel_data); + grpc_transport_stream_op_batch_finish_with_failure( + op, GRPC_ERROR_REF(chand->error), calld->call_combiner); +} + +static void lame_get_channel_info(grpc_channel_element* /*elem*/, + const grpc_channel_info* /*channel_info*/) {} + +static void lame_start_transport_op(grpc_channel_element* elem, + grpc_transport_op* op) { + ChannelData* chand = static_cast(elem->channel_data); + { + MutexLock lock(&chand->mu); + if (op->start_connectivity_watch != nullptr) { + chand->state_tracker.AddWatcher(op->start_connectivity_watch_state, + std::move(op->start_connectivity_watch)); + } + if (op->stop_connectivity_watch != nullptr) { + chand->state_tracker.RemoveWatcher(op->stop_connectivity_watch); + } + } + if (op->send_ping.on_initiate != nullptr) { + ExecCtx::Run(DEBUG_LOCATION, op->send_ping.on_initiate, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("lame client channel")); + } + if (op->send_ping.on_ack != nullptr) { + ExecCtx::Run(DEBUG_LOCATION, op->send_ping.on_ack, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("lame client channel")); + } + GRPC_ERROR_UNREF(op->disconnect_with_error); + if (op->on_consumed != nullptr) { + ExecCtx::Run(DEBUG_LOCATION, op->on_consumed, GRPC_ERROR_NONE); + } +} + +static grpc_error_handle lame_init_call_elem( + grpc_call_element* elem, const grpc_call_element_args* args) { + CallData* calld = static_cast(elem->call_data); + calld->call_combiner = args->call_combiner; + return GRPC_ERROR_NONE; +} + +static void lame_destroy_call_elem(grpc_call_element* /*elem*/, + const grpc_call_final_info* /*final_info*/, + grpc_closure* then_schedule_closure) { + ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, GRPC_ERROR_NONE); +} + +static grpc_error_handle lame_init_channel_elem( + grpc_channel_element* elem, grpc_channel_element_args* args) { + new (elem->channel_data) ChannelData(args); + return GRPC_ERROR_NONE; +} + +static void lame_destroy_channel_elem(grpc_channel_element* elem) { + ChannelData* chand = static_cast(elem->channel_data); + chand->~ChannelData(); +} + +// Channel arg vtable for a grpc_error_handle. 
+void* ErrorCopy(void* p) { + grpc_error_handle* new_error = nullptr; + if (p != nullptr) { + grpc_error_handle* error = static_cast(p); + new_error = new grpc_error_handle(); + *new_error = GRPC_ERROR_REF(*error); + } + return new_error; +} +void ErrorDestroy(void* p) { + if (p != nullptr) { + grpc_error_handle* error = static_cast(p); + GRPC_ERROR_UNREF(*error); + delete error; + } +} +int ErrorCompare(void* p, void* q) { return grpc_core::QsortCompare(p, q); } +const grpc_arg_pointer_vtable kLameFilterErrorArgVtable = { + ErrorCopy, ErrorDestroy, ErrorCompare}; + +} // namespace + +grpc_arg MakeLameClientErrorArg(grpc_error_handle* error) { + return grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_LAME_FILTER_ERROR), error, + &kLameFilterErrorArgVtable); +} + +} // namespace grpc_core + +const grpc_channel_filter grpc_lame_filter = { + grpc_core::lame_start_transport_stream_op_batch, + grpc_core::lame_start_transport_op, + sizeof(grpc_core::CallData), + grpc_core::lame_init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + grpc_core::lame_destroy_call_elem, + sizeof(grpc_core::ChannelData), + grpc_core::lame_init_channel_elem, + grpc_core::lame_destroy_channel_elem, + grpc_core::lame_get_channel_info, + "lame-client", +}; + +#define CHANNEL_STACK_FROM_CHANNEL(c) ((grpc_channel_stack*)((c) + 1)) + +grpc_channel* grpc_lame_client_channel_create(const char* target, + grpc_status_code error_code, + const char* error_message) { + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE( + "grpc_lame_client_channel_create(target=%s, error_code=%d, " + "error_message=%s)", + 3, (target, (int)error_code, error_message)); + grpc_error_handle error = grpc_error_set_str( + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("lame client channel"), + GRPC_ERROR_INT_GRPC_STATUS, error_code), + GRPC_ERROR_STR_GRPC_MESSAGE, error_message); + grpc_arg error_arg = grpc_core::MakeLameClientErrorArg(&error); + grpc_channel_args args = {1, &error_arg}; + grpc_channel* channel = grpc_channel_create( + target, &args, GRPC_CLIENT_LAME_CHANNEL, nullptr, nullptr, 0, nullptr); + GRPC_ERROR_UNREF(error); + return channel; +} diff --git a/src/core/lib/surface/metadata_array.cc b/src/core/lib/surface/metadata_array.cc new file mode 100644 index 00000000..3633382a --- /dev/null +++ b/src/core/lib/surface/metadata_array.cc @@ -0,0 +1,36 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
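// grpc_lame_client_channel_create() above builds a channel on which every RPC
// fails immediately with the configured status; the library hands back such a
// channel when a real one cannot be constructed. A sketch of using the public
// API directly (assumes grpc_init() has been called; the target string and
// status are arbitrary examples):

#include <grpc/grpc.h>
#include <grpc/status.h>

void LameChannelSketch() {
  grpc_channel* channel = grpc_lame_client_channel_create(
      "lame://rejected-target", GRPC_STATUS_INVALID_ARGUMENT,
      "channel arguments were rejected");
  // Calls created on `channel` complete at once with the status above; the
  // channel itself is owned and released like any other.
  grpc_channel_destroy(channel);
}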
+ * + */ + +#include + +#include + +#include +#include + +#include "src/core/lib/surface/api_trace.h" + +void grpc_metadata_array_init(grpc_metadata_array* array) { + GRPC_API_TRACE("grpc_metadata_array_init(array=%p)", 1, (array)); + memset(array, 0, sizeof(*array)); +} + +void grpc_metadata_array_destroy(grpc_metadata_array* array) { + GRPC_API_TRACE("grpc_metadata_array_destroy(array=%p)", 1, (array)); + gpr_free(array->metadata); +} diff --git a/src/core/lib/surface/server.cc b/src/core/lib/surface/server.cc new file mode 100644 index 00000000..5f85523c --- /dev/null +++ b/src/core/lib/surface/server.cc @@ -0,0 +1,1608 @@ +// +// Copyright 2015-2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/lib/surface/server.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/optional.h" + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/spinlock.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/mpscq.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/init.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/static_metadata.h" + +namespace grpc_core { + +TraceFlag grpc_server_channel_trace(false, "server_channel"); + +// +// Server::RequestedCall +// + +struct Server::RequestedCall { + enum class Type { BATCH_CALL, REGISTERED_CALL }; + + RequestedCall(void* tag_arg, grpc_completion_queue* call_cq, + grpc_call** call_arg, grpc_metadata_array* initial_md, + grpc_call_details* details) + : type(Type::BATCH_CALL), + tag(tag_arg), + cq_bound_to_call(call_cq), + call(call_arg), + initial_metadata(initial_md) { + details->reserved = nullptr; + data.batch.details = details; + } + + RequestedCall(void* tag_arg, grpc_completion_queue* call_cq, + grpc_call** call_arg, grpc_metadata_array* initial_md, + RegisteredMethod* rm, gpr_timespec* deadline, + grpc_byte_buffer** optional_payload) + : type(Type::REGISTERED_CALL), + tag(tag_arg), + cq_bound_to_call(call_cq), + call(call_arg), + initial_metadata(initial_md) { + data.registered.method = rm; + data.registered.deadline = deadline; + data.registered.optional_payload = optional_payload; + } + + MultiProducerSingleConsumerQueue::Node mpscq_node; + const Type type; + void* const tag; + grpc_completion_queue* const cq_bound_to_call; + grpc_call** const call; + grpc_cq_completion completion; + grpc_metadata_array* const initial_metadata; + union { + struct { + 
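// grpc_metadata_array_init()/grpc_metadata_array_destroy() above are the only
// two calls an application makes on this type: the array starts out empty and
// gRPC itself allocates and fills array->metadata while an operation that
// receives metadata is in flight. A sketch of the lifecycle (assumes the array
// is handed to some receive op between the two calls):

#include <grpc/grpc.h>

void MetadataArraySketch() {
  grpc_metadata_array recv_metadata;
  grpc_metadata_array_init(&recv_metadata);
  // ... point a GRPC_OP_RECV_INITIAL_METADATA slot (or a request-call slot) at
  // &recv_metadata and wait for its completion; gRPC fills in count/metadata.
  grpc_metadata_array_destroy(&recv_metadata);
}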
grpc_call_details* details; + } batch; + struct { + RegisteredMethod* method; + gpr_timespec* deadline; + grpc_byte_buffer** optional_payload; + } registered; + } data; +}; + +// +// Server::RegisteredMethod +// + +struct Server::RegisteredMethod { + RegisteredMethod( + const char* method_arg, const char* host_arg, + grpc_server_register_method_payload_handling payload_handling_arg, + uint32_t flags_arg) + : method(method_arg == nullptr ? "" : method_arg), + host(host_arg == nullptr ? "" : host_arg), + payload_handling(payload_handling_arg), + flags(flags_arg) {} + + ~RegisteredMethod() = default; + + const std::string method; + const std::string host; + const grpc_server_register_method_payload_handling payload_handling; + const uint32_t flags; + // One request matcher per method. + std::unique_ptr matcher; +}; + +// +// Server::RequestMatcherInterface +// + +// RPCs that come in from the transport must be matched against RPC requests +// from the application. An incoming request from the application can be matched +// to an RPC that has already arrived or can be queued up for later use. +// Likewise, an RPC coming in from the transport can either be matched to a +// request that already arrived from the application or can be queued up for +// later use (marked pending). If there is a match, the request's tag is posted +// on the request's notification CQ. +// +// RequestMatcherInterface is the base class to provide this functionality. +class Server::RequestMatcherInterface { + public: + virtual ~RequestMatcherInterface() {} + + // Unref the calls associated with any incoming RPCs in the pending queue (not + // yet matched to an application-requested RPC). + virtual void ZombifyPending() = 0; + + // Mark all application-requested RPCs failed if they have not been matched to + // an incoming RPC. The error parameter indicates why the RPCs are being + // failed (always server shutdown in all current implementations). + virtual void KillRequests(grpc_error_handle error) = 0; + + // How many request queues are supported by this matcher. This is an abstract + // concept that essentially maps to gRPC completion queues. + virtual size_t request_queue_count() const = 0; + + // This function is invoked when the application requests a new RPC whose + // information is in the call parameter. The request_queue_index marks the + // queue onto which to place this RPC, and is typically associated with a gRPC + // CQ. If there are pending RPCs waiting to be matched, publish one (match it + // and notify the CQ). + virtual void RequestCallWithPossiblePublish(size_t request_queue_index, + RequestedCall* call) = 0; + + // This function is invoked on an incoming RPC, represented by the calld + // object. The RequestMatcher will try to match it against an + // application-requested RPC if possible or will place it in the pending queue + // otherwise. To enable some measure of fairness between server CQs, the match + // is done starting at the start_request_queue_index parameter in a cyclic + // order rather than always starting at 0. 
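// The application-side source of a RequestedCall is grpc_server_request_call():
// the server parks the request until the matcher pairs it with an incoming
// RPC, then posts the tag on the notification queue. A sketch of one
// request/accept round via the public API (assumes `cq` is a next-type queue
// that was registered with the server before grpc_server_start(); error
// handling trimmed):

#include <grpc/grpc.h>
#include <grpc/support/log.h>
#include <grpc/support/time.h>

void RequestOneCallSketch(grpc_server* server, grpc_completion_queue* cq) {
  grpc_call* call = nullptr;
  grpc_call_details details;
  grpc_metadata_array request_metadata;
  grpc_call_details_init(&details);
  grpc_metadata_array_init(&request_metadata);

  void* tag = &call;  // any pointer unique to this request works as the tag
  GPR_ASSERT(grpc_server_request_call(server, &call, &details,
                                      &request_metadata, cq, cq,
                                      tag) == GRPC_CALL_OK);

  // The request matcher publishes `tag` once an incoming RPC is paired with
  // this slot (or fails it at shutdown).
  grpc_event ev = grpc_completion_queue_next(
      cq, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr);
  if (ev.type == GRPC_OP_COMPLETE && ev.success != 0) {
    // `call`, `details` and `request_metadata` now describe the new RPC.
  }

  grpc_call_details_destroy(&details);
  grpc_metadata_array_destroy(&request_metadata);
  if (call != nullptr) grpc_call_unref(call);
}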
+ virtual void MatchOrQueue(size_t start_request_queue_index, + CallData* calld) = 0; + + // Returns the server associated with this request matcher + virtual Server* server() const = 0; +}; + +// The RealRequestMatcher is an implementation of RequestMatcherInterface that +// actually uses all the features of RequestMatcherInterface: expecting the +// application to explicitly request RPCs and then matching those to incoming +// RPCs, along with a slow path by which incoming RPCs are put on a locked +// pending list if they aren't able to be matched to an application request. +class Server::RealRequestMatcher : public RequestMatcherInterface { + public: + explicit RealRequestMatcher(Server* server) + : server_(server), requests_per_cq_(server->cqs_.size()) {} + + ~RealRequestMatcher() override { + for (LockedMultiProducerSingleConsumerQueue& queue : requests_per_cq_) { + GPR_ASSERT(queue.Pop() == nullptr); + } + } + + void ZombifyPending() override { + while (!pending_.empty()) { + CallData* calld = pending_.front(); + calld->SetState(CallData::CallState::ZOMBIED); + calld->KillZombie(); + pending_.pop(); + } + } + + void KillRequests(grpc_error_handle error) override { + for (size_t i = 0; i < requests_per_cq_.size(); i++) { + RequestedCall* rc; + while ((rc = reinterpret_cast( + requests_per_cq_[i].Pop())) != nullptr) { + server_->FailCall(i, rc, GRPC_ERROR_REF(error)); + } + } + GRPC_ERROR_UNREF(error); + } + + size_t request_queue_count() const override { + return requests_per_cq_.size(); + } + + void RequestCallWithPossiblePublish(size_t request_queue_index, + RequestedCall* call) override { + if (requests_per_cq_[request_queue_index].Push(&call->mpscq_node)) { + /* this was the first queued request: we need to lock and start + matching calls */ + struct PendingCall { + RequestedCall* rc = nullptr; + CallData* calld; + }; + auto pop_next_pending = [this, request_queue_index] { + PendingCall pending_call; + { + MutexLock lock(&server_->mu_call_); + if (!pending_.empty()) { + pending_call.rc = reinterpret_cast( + requests_per_cq_[request_queue_index].Pop()); + if (pending_call.rc != nullptr) { + pending_call.calld = pending_.front(); + pending_.pop(); + } + } + } + return pending_call; + }; + while (true) { + PendingCall next_pending = pop_next_pending(); + if (next_pending.rc == nullptr) break; + if (!next_pending.calld->MaybeActivate()) { + // Zombied Call + next_pending.calld->KillZombie(); + } else { + next_pending.calld->Publish(request_queue_index, next_pending.rc); + } + } + } + } + + void MatchOrQueue(size_t start_request_queue_index, + CallData* calld) override { + for (size_t i = 0; i < requests_per_cq_.size(); i++) { + size_t cq_idx = (start_request_queue_index + i) % requests_per_cq_.size(); + RequestedCall* rc = + reinterpret_cast(requests_per_cq_[cq_idx].TryPop()); + if (rc != nullptr) { + GRPC_STATS_INC_SERVER_CQS_CHECKED(i); + calld->SetState(CallData::CallState::ACTIVATED); + calld->Publish(cq_idx, rc); + return; + } + } + // No cq to take the request found; queue it on the slow list. + GRPC_STATS_INC_SERVER_SLOWPATH_REQUESTS_QUEUED(); + // We need to ensure that all the queues are empty. We do this under + // the server mu_call_ lock to ensure that if something is added to + // an empty request queue, it will block until the call is actually + // added to the pending list. 
+ RequestedCall* rc = nullptr; + size_t cq_idx = 0; + size_t loop_count; + { + MutexLock lock(&server_->mu_call_); + for (loop_count = 0; loop_count < requests_per_cq_.size(); loop_count++) { + cq_idx = + (start_request_queue_index + loop_count) % requests_per_cq_.size(); + rc = reinterpret_cast(requests_per_cq_[cq_idx].Pop()); + if (rc != nullptr) { + break; + } + } + if (rc == nullptr) { + calld->SetState(CallData::CallState::PENDING); + pending_.push(calld); + return; + } + } + GRPC_STATS_INC_SERVER_CQS_CHECKED(loop_count + requests_per_cq_.size()); + calld->SetState(CallData::CallState::ACTIVATED); + calld->Publish(cq_idx, rc); + } + + Server* server() const override { return server_; } + + private: + Server* const server_; + std::queue pending_; + std::vector requests_per_cq_; +}; + +// AllocatingRequestMatchers don't allow the application to request an RPC in +// advance or queue up any incoming RPC for later match. Instead, MatchOrQueue +// will call out to an allocation function passed in at the construction of the +// object. These request matchers are designed for the C++ callback API, so they +// only support 1 completion queue (passed in at the constructor). They are also +// used for the sync API. +class Server::AllocatingRequestMatcherBase : public RequestMatcherInterface { + public: + AllocatingRequestMatcherBase(Server* server, grpc_completion_queue* cq) + : server_(server), cq_(cq) { + size_t idx; + for (idx = 0; idx < server->cqs_.size(); idx++) { + if (server->cqs_[idx] == cq) { + break; + } + } + GPR_ASSERT(idx < server->cqs_.size()); + cq_idx_ = idx; + } + + void ZombifyPending() override {} + + void KillRequests(grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + size_t request_queue_count() const override { return 0; } + + void RequestCallWithPossiblePublish(size_t /*request_queue_index*/, + RequestedCall* /*call*/) final { + GPR_ASSERT(false); + } + + Server* server() const override { return server_; } + + // Supply the completion queue related to this request matcher + grpc_completion_queue* cq() const { return cq_; } + + // Supply the completion queue's index relative to the server. + size_t cq_idx() const { return cq_idx_; } + + private: + Server* const server_; + grpc_completion_queue* const cq_; + size_t cq_idx_; +}; + +// An allocating request matcher for non-registered methods (used for generic +// API and unimplemented RPCs). +class Server::AllocatingRequestMatcherBatch + : public AllocatingRequestMatcherBase { + public: + AllocatingRequestMatcherBatch(Server* server, grpc_completion_queue* cq, + std::function allocator) + : AllocatingRequestMatcherBase(server, cq), + allocator_(std::move(allocator)) {} + + void MatchOrQueue(size_t /*start_request_queue_index*/, + CallData* calld) override { + if (server()->ShutdownRefOnRequest()) { + BatchCallAllocation call_info = allocator_(); + GPR_ASSERT(server()->ValidateServerRequest( + cq(), static_cast(call_info.tag), nullptr, + nullptr) == GRPC_CALL_OK); + RequestedCall* rc = new RequestedCall( + static_cast(call_info.tag), call_info.cq, call_info.call, + call_info.initial_metadata, call_info.details); + calld->SetState(CallData::CallState::ACTIVATED); + calld->Publish(cq_idx(), rc); + } else { + calld->FailCallCreation(); + } + server()->ShutdownUnrefOnRequest(); + } + + private: + std::function allocator_; +}; + +// An allocating request matcher for registered methods. 
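+// These are installed per method via Server::SetRegisteredMethodAllocator();
+// the allocator returns a RegisteredCallAllocation describing where the
+// incoming call, its initial metadata, deadline, and optional payload should
+// be written before the tag is posted on the supplied completion queue.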
+class Server::AllocatingRequestMatcherRegistered + : public AllocatingRequestMatcherBase { + public: + AllocatingRequestMatcherRegistered( + Server* server, grpc_completion_queue* cq, RegisteredMethod* rm, + std::function allocator) + : AllocatingRequestMatcherBase(server, cq), + registered_method_(rm), + allocator_(std::move(allocator)) {} + + void MatchOrQueue(size_t /*start_request_queue_index*/, + CallData* calld) override { + if (server()->ShutdownRefOnRequest()) { + RegisteredCallAllocation call_info = allocator_(); + GPR_ASSERT(server()->ValidateServerRequest( + cq(), call_info.tag, call_info.optional_payload, + registered_method_) == GRPC_CALL_OK); + RequestedCall* rc = + new RequestedCall(call_info.tag, call_info.cq, call_info.call, + call_info.initial_metadata, registered_method_, + call_info.deadline, call_info.optional_payload); + calld->SetState(CallData::CallState::ACTIVATED); + calld->Publish(cq_idx(), rc); + } else { + calld->FailCallCreation(); + } + server()->ShutdownUnrefOnRequest(); + } + + private: + RegisteredMethod* const registered_method_; + std::function allocator_; +}; + +// +// ChannelBroadcaster +// + +namespace { + +class ChannelBroadcaster { + public: + // This can have an empty constructor and destructor since we want to control + // when the actual setup and shutdown broadcast take place. + + // Copies over the channels from the locked server. + void FillChannelsLocked(std::vector channels) { + GPR_DEBUG_ASSERT(channels_.empty()); + channels_ = std::move(channels); + } + + // Broadcasts a shutdown on each channel. + void BroadcastShutdown(bool send_goaway, grpc_error_handle force_disconnect) { + for (grpc_channel* channel : channels_) { + SendShutdown(channel, send_goaway, GRPC_ERROR_REF(force_disconnect)); + GRPC_CHANNEL_INTERNAL_UNREF(channel, "broadcast"); + } + channels_.clear(); // just for safety against double broadcast + GRPC_ERROR_UNREF(force_disconnect); + } + + private: + struct ShutdownCleanupArgs { + grpc_closure closure; + grpc_slice slice; + }; + + static void ShutdownCleanup(void* arg, grpc_error_handle /*error*/) { + ShutdownCleanupArgs* a = static_cast(arg); + grpc_slice_unref_internal(a->slice); + delete a; + } + + static void SendShutdown(grpc_channel* channel, bool send_goaway, + grpc_error_handle send_disconnect) { + ShutdownCleanupArgs* sc = new ShutdownCleanupArgs; + GRPC_CLOSURE_INIT(&sc->closure, ShutdownCleanup, sc, + grpc_schedule_on_exec_ctx); + grpc_transport_op* op = grpc_make_transport_op(&sc->closure); + grpc_channel_element* elem; + op->goaway_error = + send_goaway + ? 
grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server shutdown"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_OK) + : GRPC_ERROR_NONE; + op->set_accept_stream = true; + sc->slice = grpc_slice_from_copied_string("Server shutdown"); + op->disconnect_with_error = send_disconnect; + elem = + grpc_channel_stack_element(grpc_channel_get_channel_stack(channel), 0); + elem->filter->start_transport_op(elem, op); + } + + std::vector channels_; +}; + +} // namespace + +// +// Server +// + +const grpc_channel_filter Server::kServerTopFilter = { + Server::CallData::StartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(Server::CallData), + Server::CallData::InitCallElement, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + Server::CallData::DestroyCallElement, + sizeof(Server::ChannelData), + Server::ChannelData::InitChannelElement, + Server::ChannelData::DestroyChannelElement, + grpc_channel_next_get_info, + "server", +}; + +namespace { + +RefCountedPtr CreateChannelzNode( + const grpc_channel_args* args) { + RefCountedPtr channelz_node; + if (grpc_channel_args_find_bool(args, GRPC_ARG_ENABLE_CHANNELZ, + GRPC_ENABLE_CHANNELZ_DEFAULT)) { + size_t channel_tracer_max_memory = grpc_channel_args_find_integer( + args, GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, + {GRPC_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE_DEFAULT, 0, INT_MAX}); + channelz_node = + MakeRefCounted(channel_tracer_max_memory); + channelz_node->AddTraceEvent( + channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string("Server created")); + } + return channelz_node; +} + +} // namespace + +Server::Server(const grpc_channel_args* args) + : channel_args_(grpc_channel_args_copy(args)), + channelz_node_(CreateChannelzNode(args)) {} + +Server::~Server() { + grpc_channel_args_destroy(channel_args_); + // Remove the cq pollsets from the config_fetcher. + if (started_ && config_fetcher_ != nullptr && + config_fetcher_->interested_parties() != nullptr) { + for (grpc_pollset* pollset : pollsets_) { + grpc_pollset_set_del_pollset(config_fetcher_->interested_parties(), + pollset); + } + } + for (size_t i = 0; i < cqs_.size(); i++) { + GRPC_CQ_INTERNAL_UNREF(cqs_[i], "server"); + } +} + +void Server::AddListener(OrphanablePtr listener) { + channelz::ListenSocketNode* listen_socket_node = + listener->channelz_listen_socket_node(); + if (listen_socket_node != nullptr && channelz_node_ != nullptr) { + channelz_node_->AddChildListenSocket(listen_socket_node->Ref()); + } + listeners_.emplace_back(std::move(listener)); +} + +void Server::Start() { + started_ = true; + for (grpc_completion_queue* cq : cqs_) { + if (grpc_cq_can_listen(cq)) { + pollsets_.push_back(grpc_cq_pollset(cq)); + } + } + if (unregistered_request_matcher_ == nullptr) { + unregistered_request_matcher_ = absl::make_unique(this); + } + for (std::unique_ptr& rm : registered_methods_) { + if (rm->matcher == nullptr) { + rm->matcher = absl::make_unique(this); + } + } + { + MutexLock lock(&mu_global_); + starting_ = true; + } + // Register the interested parties from the config fetcher to the cq pollsets + // before starting listeners so that config fetcher is being polled when the + // listeners start watch the fetcher. 
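+  // (For example, an xDS-enabled server's config fetcher issues its own
+  // watch RPCs; adding the server pollsets to its interested_parties() keeps
+  // those RPCs polled while the listeners below wait for a serving-status
+  // update.)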
+ if (config_fetcher_ != nullptr && + config_fetcher_->interested_parties() != nullptr) { + for (grpc_pollset* pollset : pollsets_) { + grpc_pollset_set_add_pollset(config_fetcher_->interested_parties(), + pollset); + } + } + for (auto& listener : listeners_) { + listener.listener->Start(this, &pollsets_); + } + MutexLock lock(&mu_global_); + starting_ = false; + starting_cv_.Signal(); +} + +grpc_error_handle Server::SetupTransport( + grpc_transport* transport, grpc_pollset* accepting_pollset, + const grpc_channel_args* args, + const RefCountedPtr& socket_node, + grpc_resource_user* resource_user, size_t preallocated_bytes) { + // Create channel. + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_channel* channel = + grpc_channel_create(nullptr, args, GRPC_SERVER_CHANNEL, transport, + resource_user, preallocated_bytes, &error); + if (channel == nullptr) { + return error; + } + ChannelData* chand = static_cast( + grpc_channel_stack_element(grpc_channel_get_channel_stack(channel), 0) + ->channel_data); + // Set up CQs. + size_t cq_idx; + for (cq_idx = 0; cq_idx < cqs_.size(); cq_idx++) { + if (grpc_cq_pollset(cqs_[cq_idx]) == accepting_pollset) break; + } + if (cq_idx == cqs_.size()) { + // Completion queue not found. Pick a random one to publish new calls to. + cq_idx = static_cast(rand()) % cqs_.size(); + } + // Set up channelz node. + intptr_t channelz_socket_uuid = 0; + if (socket_node != nullptr) { + channelz_socket_uuid = socket_node->uuid(); + channelz_node_->AddChildSocket(socket_node); + } + // Initialize chand. + chand->InitTransport(Ref(), channel, cq_idx, transport, channelz_socket_uuid); + return GRPC_ERROR_NONE; +} + +bool Server::HasOpenConnections() { + MutexLock lock(&mu_global_); + return !channels_.empty(); +} + +void Server::SetRegisteredMethodAllocator( + grpc_completion_queue* cq, void* method_tag, + std::function allocator) { + RegisteredMethod* rm = static_cast(method_tag); + rm->matcher = absl::make_unique( + this, cq, rm, std::move(allocator)); +} + +void Server::SetBatchMethodAllocator( + grpc_completion_queue* cq, std::function allocator) { + GPR_DEBUG_ASSERT(unregistered_request_matcher_ == nullptr); + unregistered_request_matcher_ = + absl::make_unique(this, cq, + std::move(allocator)); +} + +void Server::RegisterCompletionQueue(grpc_completion_queue* cq) { + for (grpc_completion_queue* queue : cqs_) { + if (queue == cq) return; + } + GRPC_CQ_INTERNAL_REF(cq, "server"); + cqs_.push_back(cq); +} + +namespace { + +bool streq(const std::string& a, const char* b) { + return (a.empty() && b == nullptr) || + ((b != nullptr) && !strcmp(a.c_str(), b)); +} + +} // namespace + +Server::RegisteredMethod* Server::RegisterMethod( + const char* method, const char* host, + grpc_server_register_method_payload_handling payload_handling, + uint32_t flags) { + if (!method) { + gpr_log(GPR_ERROR, + "grpc_server_register_method method string cannot be NULL"); + return nullptr; + } + for (std::unique_ptr& m : registered_methods_) { + if (streq(m->method, method) && streq(m->host, host)) { + gpr_log(GPR_ERROR, "duplicate registration for %s@%s", method, + host ? 
host : "*"); + return nullptr; + } + } + if ((flags & ~GRPC_INITIAL_METADATA_USED_MASK) != 0) { + gpr_log(GPR_ERROR, "grpc_server_register_method invalid flags 0x%08x", + flags); + return nullptr; + } + registered_methods_.emplace_back(absl::make_unique( + method, host, payload_handling, flags)); + return registered_methods_.back().get(); +} + +void Server::DoneRequestEvent(void* req, grpc_cq_completion* /*c*/) { + delete static_cast(req); +} + +void Server::FailCall(size_t cq_idx, RequestedCall* rc, + grpc_error_handle error) { + *rc->call = nullptr; + rc->initial_metadata->count = 0; + GPR_ASSERT(error != GRPC_ERROR_NONE); + grpc_cq_end_op(cqs_[cq_idx], rc->tag, error, DoneRequestEvent, rc, + &rc->completion); +} + +// Before calling MaybeFinishShutdown(), we must hold mu_global_ and not +// hold mu_call_. +void Server::MaybeFinishShutdown() { + if (!ShutdownReady() || shutdown_published_) { + return; + } + { + MutexLock lock(&mu_call_); + KillPendingWorkLocked( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server Shutdown")); + } + if (!channels_.empty() || listeners_destroyed_ < listeners_.size()) { + if (gpr_time_cmp(gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), + last_shutdown_message_time_), + gpr_time_from_seconds(1, GPR_TIMESPAN)) >= 0) { + last_shutdown_message_time_ = gpr_now(GPR_CLOCK_REALTIME); + gpr_log(GPR_DEBUG, + "Waiting for %" PRIuPTR " channels and %" PRIuPTR "/%" PRIuPTR + " listeners to be destroyed before shutting down server", + channels_.size(), listeners_.size() - listeners_destroyed_, + listeners_.size()); + } + return; + } + shutdown_published_ = true; + for (auto& shutdown_tag : shutdown_tags_) { + Ref().release(); + grpc_cq_end_op(shutdown_tag.cq, shutdown_tag.tag, GRPC_ERROR_NONE, + DoneShutdownEvent, this, &shutdown_tag.completion); + } +} + +void Server::KillPendingWorkLocked(grpc_error_handle error) { + if (started_) { + unregistered_request_matcher_->KillRequests(GRPC_ERROR_REF(error)); + unregistered_request_matcher_->ZombifyPending(); + for (std::unique_ptr& rm : registered_methods_) { + rm->matcher->KillRequests(GRPC_ERROR_REF(error)); + rm->matcher->ZombifyPending(); + } + } + GRPC_ERROR_UNREF(error); +} + +std::vector Server::GetChannelsLocked() const { + std::vector channels; + channels.reserve(channels_.size()); + for (const ChannelData* chand : channels_) { + channels.push_back(chand->channel()); + GRPC_CHANNEL_INTERNAL_REF(chand->channel(), "broadcast"); + } + return channels; +} + +void Server::ListenerDestroyDone(void* arg, grpc_error_handle /*error*/) { + Server* server = static_cast(arg); + MutexLock lock(&server->mu_global_); + server->listeners_destroyed_++; + server->MaybeFinishShutdown(); +} + +namespace { + +void DonePublishedShutdown(void* /*done_arg*/, grpc_cq_completion* storage) { + delete storage; +} + +} // namespace + +// - Kills all pending requests-for-incoming-RPC-calls (i.e., the requests made +// via grpc_server_request_call() and grpc_server_request_registered_call() +// will now be cancelled). See KillPendingWorkLocked(). +// +// - Shuts down the listeners (i.e., the server will no longer listen on the +// port for new incoming channels). +// +// - Iterates through all channels on the server and sends shutdown msg (see +// ChannelBroadcaster::BroadcastShutdown() for details) to the clients via +// the transport layer. The transport layer then guarantees the following: +// -- Sends shutdown to the client (e.g., HTTP2 transport sends GOAWAY). 
+// -- If the server has outstanding calls that are in the process, the +// connection is NOT closed until the server is done with all those calls. +// -- Once there are no more calls in progress, the channel is closed. +void Server::ShutdownAndNotify(grpc_completion_queue* cq, void* tag) { + absl::Notification* await_requests = nullptr; + ChannelBroadcaster broadcaster; + { + // Wait for startup to be finished. Locks mu_global. + MutexLock lock(&mu_global_); + while (starting_) { + starting_cv_.Wait(&mu_global_); + } + // Stay locked, and gather up some stuff to do. + GPR_ASSERT(grpc_cq_begin_op(cq, tag)); + if (shutdown_published_) { + grpc_cq_end_op(cq, tag, GRPC_ERROR_NONE, DonePublishedShutdown, nullptr, + new grpc_cq_completion); + return; + } + shutdown_tags_.emplace_back(tag, cq); + if (ShutdownCalled()) { + return; + } + last_shutdown_message_time_ = gpr_now(GPR_CLOCK_REALTIME); + broadcaster.FillChannelsLocked(GetChannelsLocked()); + // Collect all unregistered then registered calls. + { + MutexLock lock(&mu_call_); + KillPendingWorkLocked( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server Shutdown")); + } + await_requests = ShutdownUnrefOnShutdownCall(); + } + // We expect no new requests but there can still be requests in-flight. + // Wait for them to complete before proceeding. + if (await_requests != nullptr) { + await_requests->WaitForNotification(); + } + // Shutdown listeners. + for (auto& listener : listeners_) { + channelz::ListenSocketNode* channelz_listen_socket_node = + listener.listener->channelz_listen_socket_node(); + if (channelz_node_ != nullptr && channelz_listen_socket_node != nullptr) { + channelz_node_->RemoveChildListenSocket( + channelz_listen_socket_node->uuid()); + } + GRPC_CLOSURE_INIT(&listener.destroy_done, ListenerDestroyDone, this, + grpc_schedule_on_exec_ctx); + listener.listener->SetOnDestroyDone(&listener.destroy_done); + listener.listener.reset(); + } + broadcaster.BroadcastShutdown(/*send_goaway=*/true, GRPC_ERROR_NONE); +} + +void Server::CancelAllCalls() { + ChannelBroadcaster broadcaster; + { + MutexLock lock(&mu_global_); + broadcaster.FillChannelsLocked(GetChannelsLocked()); + } + broadcaster.BroadcastShutdown( + /*send_goaway=*/false, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Cancelling all calls")); +} + +void Server::Orphan() { + { + MutexLock lock(&mu_global_); + GPR_ASSERT(ShutdownCalled() || listeners_.empty()); + GPR_ASSERT(listeners_destroyed_ == listeners_.size()); + } + Unref(); +} + +grpc_call_error Server::ValidateServerRequest( + grpc_completion_queue* cq_for_notification, void* tag, + grpc_byte_buffer** optional_payload, RegisteredMethod* rm) { + if ((rm == nullptr && optional_payload != nullptr) || + ((rm != nullptr) && ((optional_payload == nullptr) != + (rm->payload_handling == GRPC_SRM_PAYLOAD_NONE)))) { + return GRPC_CALL_ERROR_PAYLOAD_TYPE_MISMATCH; + } + if (!grpc_cq_begin_op(cq_for_notification, tag)) { + return GRPC_CALL_ERROR_COMPLETION_QUEUE_SHUTDOWN; + } + return GRPC_CALL_OK; +} + +grpc_call_error Server::ValidateServerRequestAndCq( + size_t* cq_idx, grpc_completion_queue* cq_for_notification, void* tag, + grpc_byte_buffer** optional_payload, RegisteredMethod* rm) { + size_t idx; + for (idx = 0; idx < cqs_.size(); idx++) { + if (cqs_[idx] == cq_for_notification) { + break; + } + } + if (idx == cqs_.size()) { + return GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE; + } + grpc_call_error error = + ValidateServerRequest(cq_for_notification, tag, optional_payload, rm); + if (error != GRPC_CALL_OK) { + return error; + } + 
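+  // (These codes surface unchanged through the public API: for example
+  //    grpc_call_error err = grpc_server_request_call(
+  //        server, &call, &details, &metadata, cq, cq, tag);
+  //  yields GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE if cq was never
+  //  registered via grpc_server_register_completion_queue(), while
+  //  grpc_server_request_registered_call() can additionally yield
+  //  GRPC_CALL_ERROR_PAYLOAD_TYPE_MISMATCH if optional_payload disagrees
+  //  with the method's payload_handling.)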
*cq_idx = idx; + return GRPC_CALL_OK; +} + +grpc_call_error Server::QueueRequestedCall(size_t cq_idx, RequestedCall* rc) { + if (ShutdownCalled()) { + FailCall(cq_idx, rc, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server Shutdown")); + return GRPC_CALL_OK; + } + RequestMatcherInterface* rm; + switch (rc->type) { + case RequestedCall::Type::BATCH_CALL: + rm = unregistered_request_matcher_.get(); + break; + case RequestedCall::Type::REGISTERED_CALL: + rm = rc->data.registered.method->matcher.get(); + break; + } + rm->RequestCallWithPossiblePublish(cq_idx, rc); + return GRPC_CALL_OK; +} + +grpc_call_error Server::RequestCall(grpc_call** call, + grpc_call_details* details, + grpc_metadata_array* request_metadata, + grpc_completion_queue* cq_bound_to_call, + grpc_completion_queue* cq_for_notification, + void* tag) { + size_t cq_idx; + grpc_call_error error = ValidateServerRequestAndCq( + &cq_idx, cq_for_notification, tag, nullptr, nullptr); + if (error != GRPC_CALL_OK) { + return error; + } + RequestedCall* rc = + new RequestedCall(tag, cq_bound_to_call, call, request_metadata, details); + return QueueRequestedCall(cq_idx, rc); +} + +grpc_call_error Server::RequestRegisteredCall( + RegisteredMethod* rm, grpc_call** call, gpr_timespec* deadline, + grpc_metadata_array* request_metadata, grpc_byte_buffer** optional_payload, + grpc_completion_queue* cq_bound_to_call, + grpc_completion_queue* cq_for_notification, void* tag_new) { + size_t cq_idx; + grpc_call_error error = ValidateServerRequestAndCq( + &cq_idx, cq_for_notification, tag_new, optional_payload, rm); + if (error != GRPC_CALL_OK) { + return error; + } + RequestedCall* rc = + new RequestedCall(tag_new, cq_bound_to_call, call, request_metadata, rm, + deadline, optional_payload); + return QueueRequestedCall(cq_idx, rc); +} + +// +// Server::ChannelData::ConnectivityWatcher +// + +class Server::ChannelData::ConnectivityWatcher + : public AsyncConnectivityStateWatcherInterface { + public: + explicit ConnectivityWatcher(ChannelData* chand) : chand_(chand) { + GRPC_CHANNEL_INTERNAL_REF(chand_->channel_, "connectivity"); + } + + ~ConnectivityWatcher() override { + GRPC_CHANNEL_INTERNAL_UNREF(chand_->channel_, "connectivity"); + } + + private: + void OnConnectivityStateChange(grpc_connectivity_state new_state, + const absl::Status& /*status*/) override { + // Don't do anything until we are being shut down. + if (new_state != GRPC_CHANNEL_SHUTDOWN) return; + // Shut down channel. 
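+    // (Destroy() must run under mu_global_: it unlinks the channel from
+    // server_->channels_ and may complete a pending server shutdown via
+    // MaybeFinishShutdown().)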
+ MutexLock lock(&chand_->server_->mu_global_); + chand_->Destroy(); + } + + ChannelData* chand_; +}; + +// +// Server::ChannelData +// + +Server::ChannelData::~ChannelData() { + if (registered_methods_ != nullptr) { + for (const ChannelRegisteredMethod& crm : *registered_methods_) { + grpc_slice_unref_internal(crm.method); + GPR_DEBUG_ASSERT(crm.method.refcount == &kNoopRefcount || + crm.method.refcount == nullptr); + if (crm.has_host) { + grpc_slice_unref_internal(crm.host); + GPR_DEBUG_ASSERT(crm.host.refcount == &kNoopRefcount || + crm.host.refcount == nullptr); + } + } + registered_methods_.reset(); + } + if (server_ != nullptr) { + if (server_->channelz_node_ != nullptr && channelz_socket_uuid_ != 0) { + server_->channelz_node_->RemoveChildSocket(channelz_socket_uuid_); + } + { + MutexLock lock(&server_->mu_global_); + if (list_position_.has_value()) { + server_->channels_.erase(*list_position_); + list_position_.reset(); + } + server_->MaybeFinishShutdown(); + } + } +} + +void Server::ChannelData::InitTransport(RefCountedPtr server, + grpc_channel* channel, size_t cq_idx, + grpc_transport* transport, + intptr_t channelz_socket_uuid) { + server_ = std::move(server); + channel_ = channel; + cq_idx_ = cq_idx; + channelz_socket_uuid_ = channelz_socket_uuid; + // Build a lookup table phrased in terms of mdstr's in this channels context + // to quickly find registered methods. + size_t num_registered_methods = server_->registered_methods_.size(); + if (num_registered_methods > 0) { + uint32_t max_probes = 0; + size_t slots = 2 * num_registered_methods; + registered_methods_ = + absl::make_unique>(slots); + for (std::unique_ptr& rm : server_->registered_methods_) { + ExternallyManagedSlice host; + ExternallyManagedSlice method(rm->method.c_str()); + const bool has_host = !rm->host.empty(); + if (has_host) { + host = ExternallyManagedSlice(rm->host.c_str()); + } + uint32_t hash = + GRPC_MDSTR_KV_HASH(has_host ? host.Hash() : 0, method.Hash()); + uint32_t probes = 0; + for (probes = 0; (*registered_methods_)[(hash + probes) % slots] + .server_registered_method != nullptr; + probes++) { + } + if (probes > max_probes) max_probes = probes; + ChannelRegisteredMethod* crm = + &(*registered_methods_)[(hash + probes) % slots]; + crm->server_registered_method = rm.get(); + crm->flags = rm->flags; + crm->has_host = has_host; + if (has_host) { + crm->host = host; + } + crm->method = method; + } + GPR_ASSERT(slots <= UINT32_MAX); + registered_method_max_probes_ = max_probes; + } + // Publish channel. + { + MutexLock lock(&server_->mu_global_); + server_->channels_.push_front(this); + list_position_ = server_->channels_.begin(); + } + // Start accept_stream transport op. 
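+  // (Once this op is performed, the transport calls AcceptStream() for every
+  // new incoming stream; AcceptStream() creates a grpc_call on this channel
+  // and hands it to CallData::Start() on the top server filter element.)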
+ grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->set_accept_stream = true; + op->set_accept_stream_fn = AcceptStream; + op->set_accept_stream_user_data = this; + op->start_connectivity_watch = MakeOrphanable(this); + if (server_->ShutdownCalled()) { + op->disconnect_with_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Server shutdown"); + } + grpc_transport_perform_op(transport, op); +} + +Server::ChannelRegisteredMethod* Server::ChannelData::GetRegisteredMethod( + const grpc_slice& host, const grpc_slice& path, bool is_idempotent) { + if (registered_methods_ == nullptr) return nullptr; + /* TODO(ctiller): unify these two searches */ + /* check for an exact match with host */ + uint32_t hash = GRPC_MDSTR_KV_HASH(grpc_slice_hash_internal(host), + grpc_slice_hash_internal(path)); + for (size_t i = 0; i <= registered_method_max_probes_; i++) { + ChannelRegisteredMethod* rm = + &(*registered_methods_)[(hash + i) % registered_methods_->size()]; + if (rm->server_registered_method == nullptr) break; + if (!rm->has_host) continue; + if (rm->host != host) continue; + if (rm->method != path) continue; + if ((rm->flags & GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST) && + !is_idempotent) { + continue; + } + return rm; + } + /* check for a wildcard method definition (no host set) */ + hash = GRPC_MDSTR_KV_HASH(0, grpc_slice_hash_internal(path)); + for (size_t i = 0; i <= registered_method_max_probes_; i++) { + ChannelRegisteredMethod* rm = + &(*registered_methods_)[(hash + i) % registered_methods_->size()]; + if (rm->server_registered_method == nullptr) break; + if (rm->has_host) continue; + if (rm->method != path) continue; + if ((rm->flags & GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST) && + !is_idempotent) { + continue; + } + return rm; + } + return nullptr; +} + +void Server::ChannelData::AcceptStream(void* arg, grpc_transport* /*transport*/, + const void* transport_server_data) { + auto* chand = static_cast(arg); + /* create a call */ + grpc_call_create_args args; + args.channel = chand->channel_; + args.server = chand->server_.get(); + args.parent = nullptr; + args.propagation_mask = 0; + args.cq = nullptr; + args.pollset_set_alternative = nullptr; + args.server_transport_data = transport_server_data; + args.add_initial_metadata = nullptr; + args.add_initial_metadata_count = 0; + args.send_deadline = GRPC_MILLIS_INF_FUTURE; + grpc_call* call; + grpc_error_handle error = grpc_call_create(&args, &call); + grpc_call_element* elem = + grpc_call_stack_element(grpc_call_get_call_stack(call), 0); + auto* calld = static_cast(elem->call_data); + if (error != GRPC_ERROR_NONE) { + GRPC_ERROR_UNREF(error); + calld->FailCallCreation(); + return; + } + calld->Start(elem); +} + +void Server::ChannelData::FinishDestroy(void* arg, + grpc_error_handle /*error*/) { + auto* chand = static_cast(arg); + Server* server = chand->server_.get(); + GRPC_CHANNEL_INTERNAL_UNREF(chand->channel_, "server"); + server->Unref(); +} + +void Server::ChannelData::Destroy() { + if (!list_position_.has_value()) return; + GPR_ASSERT(server_ != nullptr); + server_->channels_.erase(*list_position_); + list_position_.reset(); + server_->Ref().release(); + server_->MaybeFinishShutdown(); + GRPC_CLOSURE_INIT(&finish_destroy_channel_closure_, FinishDestroy, this, + grpc_schedule_on_exec_ctx); + if (GRPC_TRACE_FLAG_ENABLED(grpc_server_channel_trace)) { + gpr_log(GPR_INFO, "Disconnected client"); + } + grpc_transport_op* op = + grpc_make_transport_op(&finish_destroy_channel_closure_); + op->set_accept_stream = true; + 
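+  // (set_accept_stream is set without an accept_stream_fn, which tells the
+  // transport to stop delivering new incoming streams to this channel;
+  // finish_destroy_channel_closure_ then runs once the op completes.)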
grpc_channel_next_op( + grpc_channel_stack_element(grpc_channel_get_channel_stack(channel_), 0), + op); +} + +grpc_error_handle Server::ChannelData::InitChannelElement( + grpc_channel_element* elem, grpc_channel_element_args* args) { + GPR_ASSERT(args->is_first); + GPR_ASSERT(!args->is_last); + new (elem->channel_data) ChannelData(); + return GRPC_ERROR_NONE; +} + +void Server::ChannelData::DestroyChannelElement(grpc_channel_element* elem) { + auto* chand = static_cast(elem->channel_data); + chand->~ChannelData(); +} + +// +// Server::CallData +// + +Server::CallData::CallData(grpc_call_element* elem, + const grpc_call_element_args& args, + RefCountedPtr server) + : server_(std::move(server)), + call_(grpc_call_from_top_element(elem)), + call_combiner_(args.call_combiner) { + GRPC_CLOSURE_INIT(&recv_initial_metadata_ready_, RecvInitialMetadataReady, + elem, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReady, + elem, grpc_schedule_on_exec_ctx); +} + +Server::CallData::~CallData() { + GPR_ASSERT(state_.load(std::memory_order_relaxed) != CallState::PENDING); + GRPC_ERROR_UNREF(recv_initial_metadata_error_); + if (host_.has_value()) { + grpc_slice_unref_internal(*host_); + } + if (path_.has_value()) { + grpc_slice_unref_internal(*path_); + } + grpc_metadata_array_destroy(&initial_metadata_); + grpc_byte_buffer_destroy(payload_); +} + +void Server::CallData::SetState(CallState state) { + state_.store(state, std::memory_order_relaxed); +} + +bool Server::CallData::MaybeActivate() { + CallState expected = CallState::PENDING; + return state_.compare_exchange_strong(expected, CallState::ACTIVATED, + std::memory_order_acq_rel, + std::memory_order_relaxed); +} + +void Server::CallData::FailCallCreation() { + CallState expected_not_started = CallState::NOT_STARTED; + CallState expected_pending = CallState::PENDING; + if (state_.compare_exchange_strong(expected_not_started, CallState::ZOMBIED, + std::memory_order_acq_rel, + std::memory_order_acquire)) { + KillZombie(); + } else if (state_.compare_exchange_strong( + expected_pending, CallState::ZOMBIED, + std::memory_order_acq_rel, std::memory_order_relaxed)) { + // Zombied call will be destroyed when it's removed from the pending + // queue... later. 
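+    // (See RealRequestMatcher::ZombifyPending(), which pops each pending
+    // CallData at shutdown and runs KillZombie() on it.)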
+ } +} + +void Server::CallData::Start(grpc_call_element* elem) { + grpc_op op; + op.op = GRPC_OP_RECV_INITIAL_METADATA; + op.flags = 0; + op.reserved = nullptr; + op.data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_; + GRPC_CLOSURE_INIT(&recv_initial_metadata_batch_complete_, + RecvInitialMetadataBatchComplete, elem, + grpc_schedule_on_exec_ctx); + grpc_call_start_batch_and_execute(call_, &op, 1, + &recv_initial_metadata_batch_complete_); +} + +void Server::CallData::Publish(size_t cq_idx, RequestedCall* rc) { + grpc_call_set_completion_queue(call_, rc->cq_bound_to_call); + *rc->call = call_; + cq_new_ = server_->cqs_[cq_idx]; + std::swap(*rc->initial_metadata, initial_metadata_); + switch (rc->type) { + case RequestedCall::Type::BATCH_CALL: + GPR_ASSERT(host_.has_value()); + GPR_ASSERT(path_.has_value()); + rc->data.batch.details->host = grpc_slice_ref_internal(*host_); + rc->data.batch.details->method = grpc_slice_ref_internal(*path_); + rc->data.batch.details->deadline = + grpc_millis_to_timespec(deadline_, GPR_CLOCK_MONOTONIC); + rc->data.batch.details->flags = recv_initial_metadata_flags_; + break; + case RequestedCall::Type::REGISTERED_CALL: + *rc->data.registered.deadline = + grpc_millis_to_timespec(deadline_, GPR_CLOCK_MONOTONIC); + if (rc->data.registered.optional_payload != nullptr) { + *rc->data.registered.optional_payload = payload_; + payload_ = nullptr; + } + break; + default: + GPR_UNREACHABLE_CODE(return ); + } + grpc_cq_end_op(cq_new_, rc->tag, GRPC_ERROR_NONE, Server::DoneRequestEvent, + rc, &rc->completion, true); +} + +void Server::CallData::PublishNewRpc(void* arg, grpc_error_handle error) { + grpc_call_element* call_elem = static_cast(arg); + auto* calld = static_cast(call_elem->call_data); + auto* chand = static_cast(call_elem->channel_data); + RequestMatcherInterface* rm = calld->matcher_; + Server* server = rm->server(); + if (error != GRPC_ERROR_NONE || server->ShutdownCalled()) { + calld->state_.store(CallState::ZOMBIED, std::memory_order_relaxed); + calld->KillZombie(); + return; + } + rm->MatchOrQueue(chand->cq_idx(), calld); +} + +namespace { + +void KillZombieClosure(void* call, grpc_error_handle /*error*/) { + grpc_call_unref(static_cast(call)); +} + +} // namespace + +void Server::CallData::KillZombie() { + GRPC_CLOSURE_INIT(&kill_zombie_closure_, KillZombieClosure, call_, + grpc_schedule_on_exec_ctx); + ExecCtx::Run(DEBUG_LOCATION, &kill_zombie_closure_, GRPC_ERROR_NONE); +} + +void Server::CallData::StartNewRpc(grpc_call_element* elem) { + auto* chand = static_cast(elem->channel_data); + if (server_->ShutdownCalled()) { + state_.store(CallState::ZOMBIED, std::memory_order_relaxed); + KillZombie(); + return; + } + // Find request matcher. + matcher_ = server_->unregistered_request_matcher_.get(); + grpc_server_register_method_payload_handling payload_handling = + GRPC_SRM_PAYLOAD_NONE; + if (path_.has_value() && host_.has_value()) { + ChannelRegisteredMethod* rm = + chand->GetRegisteredMethod(*host_, *path_, + (recv_initial_metadata_flags_ & + GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST)); + if (rm != nullptr) { + matcher_ = rm->server_registered_method->matcher.get(); + payload_handling = rm->server_registered_method->payload_handling; + } + } + // Start recv_message op if needed. 
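+  // (payload_handling comes from grpc_server_register_method():
+  // GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER pre-reads the first message so
+  // it can later be handed back through the request's optional_payload when
+  // the call is published; GRPC_SRM_PAYLOAD_NONE publishes the RPC right
+  // away.)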
+ switch (payload_handling) { + case GRPC_SRM_PAYLOAD_NONE: + PublishNewRpc(elem, GRPC_ERROR_NONE); + break; + case GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER: { + grpc_op op; + op.op = GRPC_OP_RECV_MESSAGE; + op.flags = 0; + op.reserved = nullptr; + op.data.recv_message.recv_message = &payload_; + GRPC_CLOSURE_INIT(&publish_, PublishNewRpc, elem, + grpc_schedule_on_exec_ctx); + grpc_call_start_batch_and_execute(call_, &op, 1, &publish_); + break; + } + } +} + +void Server::CallData::RecvInitialMetadataBatchComplete( + void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + auto* calld = static_cast(elem->call_data); + if (error != GRPC_ERROR_NONE) { + calld->FailCallCreation(); + return; + } + calld->StartNewRpc(elem); +} + +void Server::CallData::StartTransportStreamOpBatchImpl( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + if (batch->recv_initial_metadata) { + GPR_ASSERT(batch->payload->recv_initial_metadata.recv_flags == nullptr); + recv_initial_metadata_ = + batch->payload->recv_initial_metadata.recv_initial_metadata; + original_recv_initial_metadata_ready_ = + batch->payload->recv_initial_metadata.recv_initial_metadata_ready; + batch->payload->recv_initial_metadata.recv_initial_metadata_ready = + &recv_initial_metadata_ready_; + batch->payload->recv_initial_metadata.recv_flags = + &recv_initial_metadata_flags_; + } + if (batch->recv_trailing_metadata) { + original_recv_trailing_metadata_ready_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &recv_trailing_metadata_ready_; + } + grpc_call_next_op(elem, batch); +} + +void Server::CallData::RecvInitialMetadataReady(void* arg, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + CallData* calld = static_cast(elem->call_data); + if (error == GRPC_ERROR_NONE) { + GPR_DEBUG_ASSERT( + calld->recv_initial_metadata_->legacy_index()->named.path != nullptr); + GPR_DEBUG_ASSERT( + calld->recv_initial_metadata_->legacy_index()->named.authority != + nullptr); + calld->path_.emplace(grpc_slice_ref_internal(GRPC_MDVALUE( + calld->recv_initial_metadata_->legacy_index()->named.path->md))); + calld->host_.emplace(grpc_slice_ref_internal(GRPC_MDVALUE( + calld->recv_initial_metadata_->legacy_index()->named.authority->md))); + calld->recv_initial_metadata_->Remove(GRPC_BATCH_PATH); + calld->recv_initial_metadata_->Remove(GRPC_BATCH_AUTHORITY); + } else { + (void)GRPC_ERROR_REF(error); + } + auto op_deadline = calld->recv_initial_metadata_->get(GrpcTimeoutMetadata()); + if (op_deadline.has_value()) { + calld->deadline_ = *op_deadline; + } + if (calld->host_.has_value() && calld->path_.has_value()) { + /* do nothing */ + } else { + /* Pass the error reference to calld->recv_initial_metadata_error */ + grpc_error_handle src_error = error; + error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Missing :authority or :path", &src_error, 1); + GRPC_ERROR_UNREF(src_error); + calld->recv_initial_metadata_error_ = GRPC_ERROR_REF(error); + } + grpc_closure* closure = calld->original_recv_initial_metadata_ready_; + calld->original_recv_initial_metadata_ready_ = nullptr; + if (calld->seen_recv_trailing_metadata_ready_) { + GRPC_CALL_COMBINER_START(calld->call_combiner_, + &calld->recv_trailing_metadata_ready_, + calld->recv_trailing_metadata_error_, + "continue server recv_trailing_metadata_ready"); + } + Closure::Run(DEBUG_LOCATION, closure, error); +} + +void 
Server::CallData::RecvTrailingMetadataReady(void* arg, + grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + CallData* calld = static_cast(elem->call_data); + if (calld->original_recv_initial_metadata_ready_ != nullptr) { + calld->recv_trailing_metadata_error_ = GRPC_ERROR_REF(error); + calld->seen_recv_trailing_metadata_ready_ = true; + GRPC_CLOSURE_INIT(&calld->recv_trailing_metadata_ready_, + RecvTrailingMetadataReady, elem, + grpc_schedule_on_exec_ctx); + GRPC_CALL_COMBINER_STOP(calld->call_combiner_, + "deferring server recv_trailing_metadata_ready " + "until after recv_initial_metadata_ready"); + return; + } + error = + grpc_error_add_child(GRPC_ERROR_REF(error), + GRPC_ERROR_REF(calld->recv_initial_metadata_error_)); + Closure::Run(DEBUG_LOCATION, calld->original_recv_trailing_metadata_ready_, + error); +} + +grpc_error_handle Server::CallData::InitCallElement( + grpc_call_element* elem, const grpc_call_element_args* args) { + auto* chand = static_cast(elem->channel_data); + new (elem->call_data) Server::CallData(elem, *args, chand->server()); + return GRPC_ERROR_NONE; +} + +void Server::CallData::DestroyCallElement( + grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + auto* calld = static_cast(elem->call_data); + calld->~CallData(); +} + +void Server::CallData::StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + auto* calld = static_cast(elem->call_data); + calld->StartTransportStreamOpBatchImpl(elem, batch); +} + +} // namespace grpc_core + +// +// C-core API +// + +grpc_server* grpc_server_create(const grpc_channel_args* args, void* reserved) { + grpc_core::ExecCtx exec_ctx; + args = grpc_channel_args_remove_grpc_internal(args); + GRPC_API_TRACE("grpc_server_create(%p, %p)", 2, (args, reserved)); + grpc_server* c_server = new grpc_server; + c_server->core_server = grpc_core::MakeOrphanable(args); + grpc_channel_args_destroy(args); + return c_server; +} + +void grpc_server_register_completion_queue(grpc_server* server, + grpc_completion_queue* cq, + void* reserved) { + GRPC_API_TRACE( + "grpc_server_register_completion_queue(server=%p, cq=%p, reserved=%p)", 3, + (server, cq, reserved)); + GPR_ASSERT(!reserved); + auto cq_type = grpc_get_cq_completion_type(cq); + if (cq_type != GRPC_CQ_NEXT && cq_type != GRPC_CQ_CALLBACK) { + gpr_log(GPR_INFO, + "Completion queue of type %d is being registered as a " + "server-completion-queue", + static_cast(cq_type)); + /* Ideally we should log an error and abort but ruby-wrapped-language API + calls grpc_completion_queue_pluck() on server completion queues */ + } + server->core_server->RegisterCompletionQueue(cq); +} + +void* grpc_server_register_method( + grpc_server* server, const char* method, const char* host, + grpc_server_register_method_payload_handling payload_handling, + uint32_t flags) { + GRPC_API_TRACE( + "grpc_server_register_method(server=%p, method=%s, host=%s, " + "flags=0x%08x)", + 4, (server, method, host, flags)); + return server->core_server->RegisterMethod(method, host, payload_handling, + flags); +} + +void grpc_server_start(grpc_server* server) { + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_server_start(server=%p)", 1, (server)); + server->core_server->Start(); +} + +void grpc_server_shutdown_and_notify(grpc_server* server, + grpc_completion_queue* cq, void* tag) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + 
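+  // (Typical caller sequence, sketched:
+  //    grpc_server_shutdown_and_notify(server, cq, tag);
+  //    // ... drain cq; call grpc_server_cancel_all_calls(server) if `tag`
+  //    // does not appear promptly ...
+  //    grpc_server_destroy(server);   // only after `tag` has been received
+  // )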
GRPC_API_TRACE("grpc_server_shutdown_and_notify(server=%p, cq=%p, tag=%p)", 3, + (server, cq, tag)); + server->core_server->ShutdownAndNotify(cq, tag); +} + +void grpc_server_cancel_all_calls(grpc_server* server) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_server_cancel_all_calls(server=%p)", 1, (server)); + server->core_server->CancelAllCalls(); +} + +void grpc_server_destroy(grpc_server* server) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_server_destroy(server=%p)", 1, (server)); + delete server; +} + +grpc_call_error grpc_server_request_call( + grpc_server* server, grpc_call** call, grpc_call_details* details, + grpc_metadata_array* request_metadata, + grpc_completion_queue* cq_bound_to_call, + grpc_completion_queue* cq_for_notification, void* tag) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_STATS_INC_SERVER_REQUESTED_CALLS(); + GRPC_API_TRACE( + "grpc_server_request_call(" + "server=%p, call=%p, details=%p, initial_metadata=%p, " + "cq_bound_to_call=%p, cq_for_notification=%p, tag=%p)", + 7, + (server, call, details, request_metadata, cq_bound_to_call, + cq_for_notification, tag)); + return server->core_server->RequestCall(call, details, request_metadata, + cq_bound_to_call, cq_for_notification, + tag); +} + +grpc_call_error grpc_server_request_registered_call( + grpc_server* server, void* registered_method, grpc_call** call, + gpr_timespec* deadline, grpc_metadata_array* request_metadata, + grpc_byte_buffer** optional_payload, + grpc_completion_queue* cq_bound_to_call, + grpc_completion_queue* cq_for_notification, void* tag_new) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_STATS_INC_SERVER_REQUESTED_CALLS(); + auto* rm = + static_cast(registered_method); + GRPC_API_TRACE( + "grpc_server_request_registered_call(" + "server=%p, registered_method=%p, call=%p, deadline=%p, " + "request_metadata=%p, " + "optional_payload=%p, cq_bound_to_call=%p, cq_for_notification=%p, " + "tag=%p)", + 9, + (server, registered_method, call, deadline, request_metadata, + optional_payload, cq_bound_to_call, cq_for_notification, tag_new)); + return server->core_server->RequestRegisteredCall( + rm, call, deadline, request_metadata, optional_payload, cq_bound_to_call, + cq_for_notification, tag_new); +} + +void grpc_server_set_config_fetcher( + grpc_server* server, grpc_server_config_fetcher* server_config_fetcher) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_server_set_config_fetcher(server=%p, config_fetcher=%p)", + 2, (server, server_config_fetcher)); + server->core_server->set_config_fetcher( + std::unique_ptr(server_config_fetcher)); +} + +void grpc_server_config_fetcher_destroy( + grpc_server_config_fetcher* server_config_fetcher) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_API_TRACE("grpc_server_config_fetcher_destroy(config_fetcher=%p)", 1, + (server_config_fetcher)); + delete server_config_fetcher; +} diff --git a/src/core/lib/surface/validate_metadata.cc b/src/core/lib/surface/validate_metadata.cc new file mode 100644 index 00000000..077acb11 --- /dev/null +++ b/src/core/lib/surface/validate_metadata.cc @@ -0,0 +1,136 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/validate_metadata.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gprpp/bitset.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +#if __cplusplus > 201103l +#define GRPC_VALIDATE_METADATA_CONSTEXPR_FN constexpr +#define GRPC_VALIDATE_METADATA_CONSTEXPR_VALUE constexpr +#else +#define GRPC_VALIDATE_METADATA_CONSTEXPR_FN +#define GRPC_VALIDATE_METADATA_CONSTEXPR_VALUE const +#endif + +static grpc_error_handle conforms_to(const grpc_slice& slice, + const grpc_core::BitSet<256>& legal_bits, + const char* err_desc) { + const uint8_t* p = GRPC_SLICE_START_PTR(slice); + const uint8_t* e = GRPC_SLICE_END_PTR(slice); + for (; p != e; p++) { + if (!legal_bits.is_set(*p)) { + size_t len; + grpc_core::UniquePtr ptr(gpr_dump_return_len( + reinterpret_cast GRPC_SLICE_START_PTR(slice), + GRPC_SLICE_LENGTH(slice), GPR_DUMP_HEX | GPR_DUMP_ASCII, &len)); + grpc_error_handle error = grpc_error_set_str( + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_COPIED_STRING(err_desc), + GRPC_ERROR_INT_OFFSET, + p - GRPC_SLICE_START_PTR(slice)), + GRPC_ERROR_STR_RAW_BYTES, absl::string_view(ptr.get(), len)); + return error; + } + } + return GRPC_ERROR_NONE; +} + +static int error2int(grpc_error_handle error) { + int r = (error == GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(error); + return r; +} + +namespace { +class LegalHeaderKeyBits : public grpc_core::BitSet<256> { + public: + GRPC_VALIDATE_METADATA_CONSTEXPR_FN LegalHeaderKeyBits() { + for (int i = 'a'; i <= 'z'; i++) set(i); + for (int i = '0'; i <= '9'; i++) set(i); + set('-'); + set('_'); + set('.'); + } +}; +static GRPC_VALIDATE_METADATA_CONSTEXPR_VALUE LegalHeaderKeyBits + g_legal_header_key_bits; +} // namespace + +grpc_error_handle grpc_validate_header_key_is_legal(const grpc_slice& slice) { + if (GRPC_SLICE_LENGTH(slice) == 0) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Metadata keys cannot be zero length"); + } + if (GRPC_SLICE_LENGTH(slice) > UINT32_MAX) { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Metadata keys cannot be larger than UINT32_MAX"); + } + if (GRPC_SLICE_START_PTR(slice)[0] == ':') { + return GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Metadata keys cannot start with :"); + } + return conforms_to(slice, g_legal_header_key_bits, "Illegal header key"); +} + +int grpc_header_key_is_legal(grpc_slice slice) { + return error2int(grpc_validate_header_key_is_legal(slice)); +} + +namespace { +class LegalHeaderNonBinValueBits : public grpc_core::BitSet<256> { + public: + GRPC_VALIDATE_METADATA_CONSTEXPR_FN LegalHeaderNonBinValueBits() { + for (int i = 32; i <= 126; i++) { + set(i); + } + } +}; +static GRPC_VALIDATE_METADATA_CONSTEXPR_VALUE LegalHeaderNonBinValueBits + g_legal_header_non_bin_value_bits; +} // namespace + +grpc_error_handle grpc_validate_header_nonbin_value_is_legal( + 
const grpc_slice& slice) { + return conforms_to(slice, g_legal_header_non_bin_value_bits, + "Illegal header value"); +} + +int grpc_header_nonbin_value_is_legal(grpc_slice slice) { + return error2int(grpc_validate_header_nonbin_value_is_legal(slice)); +} + +int grpc_is_binary_header_internal(const grpc_slice& slice) { + return grpc_key_is_binary_header(GRPC_SLICE_START_PTR(slice), + GRPC_SLICE_LENGTH(slice)); +} + +int grpc_is_binary_header(grpc_slice slice) { + return grpc_is_binary_header_internal(slice); +} diff --git a/src/core/lib/surface/version.cc b/src/core/lib/surface/version.cc new file mode 100644 index 00000000..c8b2fdb4 --- /dev/null +++ b/src/core/lib/surface/version.cc @@ -0,0 +1,28 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This file is autogenerated from: + templates/src/core/surface/version.c.template */ + +#include + +#include + +const char* grpc_version_string(void) { return "19.1.0"; } + +const char* grpc_g_stands_for(void) { return "granola"; } diff --git a/src/core/lib/transport/bdp_estimator.cc b/src/core/lib/transport/bdp_estimator.cc new file mode 100644 index 00000000..853eca69 --- /dev/null +++ b/src/core/lib/transport/bdp_estimator.cc @@ -0,0 +1,87 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/bdp_estimator.h" + +#include +#include + +#include "src/core/lib/gpr/useful.h" + +grpc_core::TraceFlag grpc_bdp_estimator_trace(false, "bdp_estimator"); + +namespace grpc_core { + +BdpEstimator::BdpEstimator(const char* name) + : ping_state_(PingState::UNSCHEDULED), + accumulator_(0), + estimate_(65536), + ping_start_time_(gpr_time_0(GPR_CLOCK_MONOTONIC)), + inter_ping_delay_(100), // start at 100ms + stable_estimate_count_(0), + bw_est_(0), + name_(name) {} + +grpc_millis BdpEstimator::CompletePing() { + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec dt_ts = gpr_time_sub(now, ping_start_time_); + double dt = static_cast(dt_ts.tv_sec) + + 1e-9 * static_cast(dt_ts.tv_nsec); + double bw = dt > 0 ? 
(static_cast(accumulator_) / dt) : 0; + int start_inter_ping_delay = inter_ping_delay_; + if (GRPC_TRACE_FLAG_ENABLED(grpc_bdp_estimator_trace)) { + gpr_log(GPR_INFO, + "bdp[%s]:complete acc=%" PRId64 " est=%" PRId64 + " dt=%lf bw=%lfMbs bw_est=%lfMbs", + name_, accumulator_, estimate_, dt, bw / 125000.0, + bw_est_ / 125000.0); + } + GPR_ASSERT(ping_state_ == PingState::STARTED); + if (accumulator_ > 2 * estimate_ / 3 && bw > bw_est_) { + estimate_ = std::max(accumulator_, estimate_ * 2); + bw_est_ = bw; + if (GRPC_TRACE_FLAG_ENABLED(grpc_bdp_estimator_trace)) { + gpr_log(GPR_INFO, "bdp[%s]: estimate increased to %" PRId64, name_, + estimate_); + } + inter_ping_delay_ /= 2; // if the ping estimate changes, + // exponentially get faster at probing + } else if (inter_ping_delay_ < 10000) { + stable_estimate_count_++; + if (stable_estimate_count_ >= 2) { + inter_ping_delay_ += + 100 + static_cast(rand() * 100.0 / + RAND_MAX); // if the ping estimate is steady, + // slowly ramp down the probe time + } + } + if (start_inter_ping_delay != inter_ping_delay_) { + stable_estimate_count_ = 0; + if (GRPC_TRACE_FLAG_ENABLED(grpc_bdp_estimator_trace)) { + gpr_log(GPR_INFO, "bdp[%s]:update_inter_time to %dms", name_, + inter_ping_delay_); + } + } + ping_state_ = PingState::UNSCHEDULED; + accumulator_ = 0; + return grpc_core::ExecCtx::Get()->Now() + inter_ping_delay_; +} + +} // namespace grpc_core diff --git a/src/core/lib/transport/byte_stream.cc b/src/core/lib/transport/byte_stream.cc new file mode 100644 index 00000000..3cc5275e --- /dev/null +++ b/src/core/lib/transport/byte_stream.cc @@ -0,0 +1,158 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/byte_stream.h" + +#include +#include + +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +// +// SliceBufferByteStream +// + +SliceBufferByteStream::SliceBufferByteStream(grpc_slice_buffer* slice_buffer, + uint32_t flags) + : ByteStream(static_cast(slice_buffer->length), flags) { + GPR_ASSERT(slice_buffer->length <= UINT32_MAX); + grpc_slice_buffer_init(&backing_buffer_); + grpc_slice_buffer_swap(slice_buffer, &backing_buffer_); +} + +SliceBufferByteStream::~SliceBufferByteStream() {} + +void SliceBufferByteStream::Orphan() { + grpc_slice_buffer_destroy_internal(&backing_buffer_); + GRPC_ERROR_UNREF(shutdown_error_); + // Note: We do not actually delete the object here, since + // SliceBufferByteStream is usually allocated as part of a larger + // object and has an OrphanablePtr of itself passed down through the + // filter stack. 
+} + +bool SliceBufferByteStream::Next(size_t /*max_size_hint*/, + grpc_closure* /*on_complete*/) { + GPR_DEBUG_ASSERT(backing_buffer_.count > 0); + return true; +} + +grpc_error_handle SliceBufferByteStream::Pull(grpc_slice* slice) { + if (GPR_UNLIKELY(shutdown_error_ != GRPC_ERROR_NONE)) { + return GRPC_ERROR_REF(shutdown_error_); + } + *slice = grpc_slice_buffer_take_first(&backing_buffer_); + return GRPC_ERROR_NONE; +} + +void SliceBufferByteStream::Shutdown(grpc_error_handle error) { + GRPC_ERROR_UNREF(shutdown_error_); + shutdown_error_ = error; +} + +// +// ByteStreamCache +// + +ByteStreamCache::ByteStreamCache(OrphanablePtr underlying_stream) + : underlying_stream_(std::move(underlying_stream)), + length_(underlying_stream_->length()), + flags_(underlying_stream_->flags()) { + grpc_slice_buffer_init(&cache_buffer_); +} + +ByteStreamCache::~ByteStreamCache() { Destroy(); } + +void ByteStreamCache::Destroy() { + underlying_stream_.reset(); + if (cache_buffer_.length > 0) { + grpc_slice_buffer_destroy_internal(&cache_buffer_); + } +} + +// +// ByteStreamCache::CachingByteStream +// + +ByteStreamCache::CachingByteStream::CachingByteStream(ByteStreamCache* cache) + : ByteStream(cache->length_, cache->flags_), cache_(cache) {} + +ByteStreamCache::CachingByteStream::~CachingByteStream() {} + +void ByteStreamCache::CachingByteStream::Orphan() { + GRPC_ERROR_UNREF(shutdown_error_); + // Note: We do not actually delete the object here, since + // CachingByteStream is usually allocated as part of a larger + // object and has an OrphanablePtr of itself passed down through the + // filter stack. +} + +bool ByteStreamCache::CachingByteStream::Next(size_t max_size_hint, + grpc_closure* on_complete) { + if (shutdown_error_ != GRPC_ERROR_NONE) return true; + if (cursor_ < cache_->cache_buffer_.count) return true; + GPR_ASSERT(cache_->underlying_stream_ != nullptr); + return cache_->underlying_stream_->Next(max_size_hint, on_complete); +} + +grpc_error_handle ByteStreamCache::CachingByteStream::Pull(grpc_slice* slice) { + if (shutdown_error_ != GRPC_ERROR_NONE) { + return GRPC_ERROR_REF(shutdown_error_); + } + if (cursor_ < cache_->cache_buffer_.count) { + *slice = grpc_slice_ref_internal(cache_->cache_buffer_.slices[cursor_]); + ++cursor_; + offset_ += GRPC_SLICE_LENGTH(*slice); + return GRPC_ERROR_NONE; + } + GPR_ASSERT(cache_->underlying_stream_ != nullptr); + grpc_error_handle error = cache_->underlying_stream_->Pull(slice); + if (error == GRPC_ERROR_NONE) { + grpc_slice_buffer_add(&cache_->cache_buffer_, + grpc_slice_ref_internal(*slice)); + ++cursor_; + offset_ += GRPC_SLICE_LENGTH(*slice); + // Orphan the underlying stream if it's been drained. + if (offset_ == cache_->underlying_stream_->length()) { + cache_->underlying_stream_.reset(); + } + } + return error; +} + +void ByteStreamCache::CachingByteStream::Shutdown(grpc_error_handle error) { + GRPC_ERROR_UNREF(shutdown_error_); + shutdown_error_ = GRPC_ERROR_REF(error); + if (cache_->underlying_stream_ != nullptr) { + cache_->underlying_stream_->Shutdown(error); + } +} + +void ByteStreamCache::CachingByteStream::Reset() { + cursor_ = 0; + offset_ = 0; +} + +} // namespace grpc_core diff --git a/src/core/lib/transport/connectivity_state.cc b/src/core/lib/transport/connectivity_state.cc new file mode 100644 index 00000000..ffad92e8 --- /dev/null +++ b/src/core/lib/transport/connectivity_state.cc @@ -0,0 +1,188 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/connectivity_state.h" + +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/exec_ctx.h" + +namespace grpc_core { + +TraceFlag grpc_connectivity_state_trace(false, "connectivity_state"); + +const char* ConnectivityStateName(grpc_connectivity_state state) { + switch (state) { + case GRPC_CHANNEL_IDLE: + return "IDLE"; + case GRPC_CHANNEL_CONNECTING: + return "CONNECTING"; + case GRPC_CHANNEL_READY: + return "READY"; + case GRPC_CHANNEL_TRANSIENT_FAILURE: + return "TRANSIENT_FAILURE"; + case GRPC_CHANNEL_SHUTDOWN: + return "SHUTDOWN"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +// +// AsyncConnectivityStateWatcherInterface +// + +// A fire-and-forget class to asynchronously deliver a connectivity +// state notification to a watcher. +class AsyncConnectivityStateWatcherInterface::Notifier { + public: + Notifier(RefCountedPtr watcher, + grpc_connectivity_state state, const absl::Status& status, + const std::shared_ptr& work_serializer) + : watcher_(std::move(watcher)), state_(state), status_(status) { + if (work_serializer != nullptr) { + work_serializer->Run( + [this]() { SendNotification(this, GRPC_ERROR_NONE); }, + DEBUG_LOCATION); + } else { + GRPC_CLOSURE_INIT(&closure_, SendNotification, this, + grpc_schedule_on_exec_ctx); + ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); + } + } + + private: + static void SendNotification(void* arg, grpc_error_handle /*ignored*/) { + Notifier* self = static_cast(arg); + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, "watcher %p: delivering async notification for %s (%s)", + self->watcher_.get(), ConnectivityStateName(self->state_), + self->status_.ToString().c_str()); + } + self->watcher_->OnConnectivityStateChange(self->state_, self->status_); + delete self; + } + + RefCountedPtr watcher_; + const grpc_connectivity_state state_; + const absl::Status status_; + grpc_closure closure_; +}; + +void AsyncConnectivityStateWatcherInterface::Notify( + grpc_connectivity_state state, const absl::Status& status) { + new Notifier(Ref(), state, status, + work_serializer_); // Deletes itself when done. 
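  // The Notifier above is deliberately fire-and-forget: it is heap-allocated,
  // delivers exactly one (state, status) pair either on the watcher's
  // WorkSerializer or via an ExecCtx closure, and then deletes itself. A
  // minimal sketch of the same idiom, using hypothetical names only:
  //
  //   struct SelfDeletingNotifier {
  //     void Run() { Deliver(); delete this; }  // run exactly once, then self-destruct
  //     void Deliver() { /* invoke the watcher's callback here */ }
  //   };
  //   (new SelfDeletingNotifier())->Run();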
+} + +// +// ConnectivityStateTracker +// + +ConnectivityStateTracker::~ConnectivityStateTracker() { + grpc_connectivity_state current_state = + state_.load(std::memory_order_relaxed); + if (current_state == GRPC_CHANNEL_SHUTDOWN) return; + for (const auto& p : watchers_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, + "ConnectivityStateTracker %s[%p]: notifying watcher %p: %s -> %s", + name_, this, p.first, ConnectivityStateName(current_state), + ConnectivityStateName(GRPC_CHANNEL_SHUTDOWN)); + } + p.second->Notify(GRPC_CHANNEL_SHUTDOWN, absl::Status()); + } +} + +void ConnectivityStateTracker::AddWatcher( + grpc_connectivity_state initial_state, + OrphanablePtr watcher) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, "ConnectivityStateTracker %s[%p]: add watcher %p", name_, + this, watcher.get()); + } + grpc_connectivity_state current_state = + state_.load(std::memory_order_relaxed); + if (initial_state != current_state) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, + "ConnectivityStateTracker %s[%p]: notifying watcher %p: %s -> %s", + name_, this, watcher.get(), ConnectivityStateName(initial_state), + ConnectivityStateName(current_state)); + } + watcher->Notify(current_state, status_); + } + // If we're in state SHUTDOWN, don't add the watcher, so that it will + // be orphaned immediately. + if (current_state != GRPC_CHANNEL_SHUTDOWN) { + watchers_.insert(std::make_pair(watcher.get(), std::move(watcher))); + } +} + +void ConnectivityStateTracker::RemoveWatcher( + ConnectivityStateWatcherInterface* watcher) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, "ConnectivityStateTracker %s[%p]: remove watcher %p", + name_, this, watcher); + } + watchers_.erase(watcher); +} + +void ConnectivityStateTracker::SetState(grpc_connectivity_state state, + const absl::Status& status, + const char* reason) { + grpc_connectivity_state current_state = + state_.load(std::memory_order_relaxed); + if (state == current_state) return; + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, "ConnectivityStateTracker %s[%p]: %s -> %s (%s, %s)", + name_, this, ConnectivityStateName(current_state), + ConnectivityStateName(state), reason, status.ToString().c_str()); + } + state_.store(state, std::memory_order_relaxed); + status_ = status; + for (const auto& p : watchers_) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, + "ConnectivityStateTracker %s[%p]: notifying watcher %p: %s -> %s", + name_, this, p.first, ConnectivityStateName(current_state), + ConnectivityStateName(state)); + } + p.second->Notify(state, status); + } + // If the new state is SHUTDOWN, orphan all of the watchers. This + // avoids the need for the callers to explicitly cancel them. 
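  // Put differently: once a watcher has been handed to AddWatcher(), it is
  // released either by an explicit RemoveWatcher() call or automatically when
  // the tracker reaches SHUTDOWN. A hedged lifecycle sketch (the watcher type
  // and status values are illustrative):
  //
  //   ConnectivityStateTracker tracker("example_tracker");
  //   tracker.AddWatcher(GRPC_CHANNEL_IDLE,
  //                      MakeOrphanable<MyWatcher>());  // MyWatcher is hypothetical
  //   tracker.SetState(GRPC_CHANNEL_READY, absl::OkStatus(), "connected");
  //   tracker.SetState(GRPC_CHANNEL_SHUTDOWN, absl::OkStatus(), "going away");
  //   // ^ the second SetState() clears watchers_, orphaning MyWatcher for us.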
+ if (state == GRPC_CHANNEL_SHUTDOWN) watchers_.clear(); +} + +grpc_connectivity_state ConnectivityStateTracker::state() const { + grpc_connectivity_state state = state_.load(std::memory_order_relaxed); + if (GRPC_TRACE_FLAG_ENABLED(grpc_connectivity_state_trace)) { + gpr_log(GPR_INFO, "ConnectivityStateTracker %s[%p]: get current state: %s", + name_, this, ConnectivityStateName(state)); + } + return state; +} + +} // namespace grpc_core diff --git a/src/core/lib/transport/error_utils.cc b/src/core/lib/transport/error_utils.cc new file mode 100644 index 00000000..2d2b67e3 --- /dev/null +++ b/src/core/lib/transport/error_utils.cc @@ -0,0 +1,191 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/error_utils.h" + +#include + +#include "src/core/lib/gprpp/status_helper.h" +#include "src/core/lib/iomgr/error_internal.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/status_conversion.h" + +static grpc_error_handle recursively_find_error_with_field( + grpc_error_handle error, grpc_error_ints which) { + intptr_t unused; + // If the error itself has a status code, return it. + if (grpc_error_get_int(error, which, &unused)) { + return error; + } +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + std::vector children = grpc_core::StatusGetChildren(error); + for (const absl::Status& child : children) { + grpc_error_handle result = recursively_find_error_with_field(child, which); + if (result != GRPC_ERROR_NONE) return result; + } +#else + if (grpc_error_is_special(error)) return GRPC_ERROR_NONE; + // Otherwise, search through its children. + uint8_t slot = error->first_err; + while (slot != UINT8_MAX) { + grpc_linked_error* lerr = + reinterpret_cast(error->arena + slot); + grpc_error_handle result = + recursively_find_error_with_field(lerr->err, which); + if (result) return result; + slot = lerr->next; + } +#endif + return GRPC_ERROR_NONE; +} + +void grpc_error_get_status(grpc_error_handle error, grpc_millis deadline, + grpc_status_code* code, std::string* message, + grpc_http2_error_code* http_error, + const char** error_string) { + // Fast path: We expect no error. + if (GPR_LIKELY(error == GRPC_ERROR_NONE)) { + if (code != nullptr) *code = GRPC_STATUS_OK; + if (message != nullptr) { + // Normally, we call grpc_error_get_str( + // error, GRPC_ERROR_STR_GRPC_MESSAGE, message). + // We can fastpath since we know that: + // 1) Error is null + // 2) which == GRPC_ERROR_STR_GRPC_MESSAGE + // 3) The resulting message is statically known. + // 4) Said resulting message is "". + // This means 3 movs, instead of 10s of instructions and a strlen. + *message = ""; + } + if (http_error != nullptr) { + *http_error = GRPC_HTTP2_NO_ERROR; + } + return; + } + + // Start with the parent error and recurse through the tree of children + // until we find the first one that has a status code. 
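  // The search order is: an explicit grpc-status anywhere in the tree wins;
  // otherwise an HTTP/2 error code is located and converted; otherwise we
  // fall back to the root error itself. A hedged caller-side sketch:
  //
  //   grpc_status_code code;
  //   std::string msg;
  //   grpc_error_get_status(error, /*deadline=*/GRPC_MILLIS_INF_FUTURE, &code,
  //                         &msg, /*http_error=*/nullptr,
  //                         /*error_string=*/nullptr);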
+ grpc_error_handle found_error = + recursively_find_error_with_field(error, GRPC_ERROR_INT_GRPC_STATUS); + if (found_error == GRPC_ERROR_NONE) { + /// If no grpc-status exists, retry through the tree to find a http2 error + /// code + found_error = + recursively_find_error_with_field(error, GRPC_ERROR_INT_HTTP2_ERROR); + } + + // If we found an error with a status code above, use that; otherwise, + // fall back to using the parent error. + if (found_error == GRPC_ERROR_NONE) found_error = error; + + grpc_status_code status = GRPC_STATUS_UNKNOWN; + intptr_t integer; + if (grpc_error_get_int(found_error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } else if (grpc_error_get_int(found_error, GRPC_ERROR_INT_HTTP2_ERROR, + &integer)) { + status = grpc_http2_error_to_grpc_status( + static_cast(integer), deadline); + } else { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + status = static_cast(found_error.code()); +#endif + } + if (code != nullptr) *code = status; + + if (error_string != nullptr && status != GRPC_STATUS_OK) { + *error_string = gpr_strdup(grpc_error_std_string(error).c_str()); + } + + if (http_error != nullptr) { + if (grpc_error_get_int(found_error, GRPC_ERROR_INT_HTTP2_ERROR, &integer)) { + *http_error = static_cast(integer); + } else if (grpc_error_get_int(found_error, GRPC_ERROR_INT_GRPC_STATUS, + &integer)) { + *http_error = + grpc_status_to_http2_error(static_cast(integer)); + } else { + *http_error = found_error == GRPC_ERROR_NONE ? GRPC_HTTP2_NO_ERROR + : GRPC_HTTP2_INTERNAL_ERROR; + } + } + + // If the error has a status message, use it. Otherwise, fall back to + // the error description. + if (message != nullptr) { + if (!grpc_error_get_str(found_error, GRPC_ERROR_STR_GRPC_MESSAGE, + message)) { + if (!grpc_error_get_str(found_error, GRPC_ERROR_STR_DESCRIPTION, + message)) { +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + *message = grpc_error_std_string(error); +#else + *message = "unknown error"; +#endif + } + } + } +} + +absl::Status grpc_error_to_absl_status(grpc_error_handle error) { + grpc_status_code status; + // TODO(yashykt): This should be updated once we decide on how to use the + // absl::Status payload to capture all the contents of grpc_error. 
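  // Until then the conversion is lossy: only the status code and the message
  // survive a round trip. A hedged sketch of the two directions (see
  // absl_status_to_grpc_error() below):
  //
  //   absl::Status s = grpc_error_to_absl_status(error);
  //   grpc_error_handle back = absl_status_to_grpc_error(s);
  //   // 'back' carries s.code() under GRPC_ERROR_INT_GRPC_STATUS and
  //   // s.message() as its description; child errors are not preserved.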
+ std::string message; + grpc_error_get_status(error, GRPC_MILLIS_INF_FUTURE, &status, &message, + nullptr /* http_error */, nullptr /* error_string */); + return absl::Status(static_cast(status), message); +} + +grpc_error_handle absl_status_to_grpc_error(absl::Status status) { + // Special error checks + if (status.ok()) { + return GRPC_ERROR_NONE; + } + return grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STRING_VIEW(status.message()), + GRPC_ERROR_INT_GRPC_STATUS, static_cast(status.code())); +} + +bool grpc_error_has_clear_grpc_status(grpc_error_handle error) { + intptr_t unused; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &unused)) { + return true; + } +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS + std::vector children = grpc_core::StatusGetChildren(error); + for (const absl::Status& child : children) { + if (grpc_error_has_clear_grpc_status(child)) { + return true; + } + } +#else + uint8_t slot = error->first_err; + while (slot != UINT8_MAX) { + grpc_linked_error* lerr = + reinterpret_cast(error->arena + slot); + if (grpc_error_has_clear_grpc_status(lerr->err)) { + return true; + } + slot = lerr->next; + } +#endif + return false; +} diff --git a/src/core/lib/transport/metadata.cc b/src/core/lib/transport/metadata.cc new file mode 100644 index 00000000..004e44af --- /dev/null +++ b/src/core/lib/transport/metadata.cc @@ -0,0 +1,714 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/metadata.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/murmur_hash.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/static_metadata.h" + +using grpc_core::AllocatedMetadata; +using grpc_core::InternedMetadata; +using grpc_core::StaticMetadata; +using grpc_core::UserData; + +/* There are two kinds of mdelem and mdstr instances. + * Static instances are declared in static_metadata.{h,c} and + * are initialized by grpc_mdctx_global_init(). + * Dynamic instances are stored in hash tables on grpc_mdctx, and are backed + * by internal_string and internal_element structures. + * Internal helper functions here-in (is_mdstr_static, is_mdelem_static) are + * used to determine which kind of element a pointer refers to. 
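 *
 * In practice callers do not pick the storage class themselves: they call
 * grpc_mdelem_from_slices() / grpc_mdelem_create() and the appropriate kind
 * of element is chosen automatically (a precomputed static element, an entry
 * in the sharded intern table below, or a standalone AllocatedMetadata).
 * A hedged usage sketch, with illustrative slice contents:
 *
 *   grpc_slice k = grpc_slice_from_static_string("my-key");
 *   grpc_slice v = grpc_slice_from_copied_string("my-value");
 *   grpc_mdelem md = grpc_mdelem_from_slices(k, v);  // consumes the refs on k and v
 *   GRPC_MDELEM_UNREF(md);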
+ */ + +grpc_core::DebugOnlyTraceFlag grpc_trace_metadata(false, "metadata"); + +#ifndef NDEBUG +#define DEBUG_ARGS , const char *file, int line +#define FWD_DEBUG_ARGS file, line + +void grpc_mdelem_trace_ref(void* md, const grpc_slice& key, + const grpc_slice& value, intptr_t refcnt, + const char* file, int line) { + if (grpc_trace_metadata.enabled()) { + char* key_str = grpc_slice_to_c_string(key); + char* value_str = grpc_slice_to_c_string(value); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "mdelem REF:%p:%" PRIdPTR "->%" PRIdPTR ": '%s' = '%s'", md, + refcnt, refcnt + 1, key_str, value_str); + gpr_free(key_str); + gpr_free(value_str); + } +} + +void grpc_mdelem_trace_unref(void* md, const grpc_slice& key, + const grpc_slice& value, intptr_t refcnt, + const char* file, int line) { + if (grpc_trace_metadata.enabled()) { + char* key_str = grpc_slice_to_c_string(key); + char* value_str = grpc_slice_to_c_string(value); + gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, + "mdelem UNREF:%p:%" PRIdPTR "->%" PRIdPTR ": '%s' = '%s'", md, + refcnt, refcnt - 1, key_str, value_str); + gpr_free(key_str); + gpr_free(value_str); + } +} + +#else // ifndef NDEBUG +#define DEBUG_ARGS +#define FWD_DEBUG_ARGS +#endif // ifndef NDEBUG + +#define INITIAL_SHARD_CAPACITY 8 +#define LOG2_SHARD_COUNT 4 +#define SHARD_COUNT ((size_t)(1 << LOG2_SHARD_COUNT)) + +#define TABLE_IDX(hash, capacity) (((hash) >> (LOG2_SHARD_COUNT)) % (capacity)) +#define SHARD_IDX(hash) ((hash) & ((1 << (LOG2_SHARD_COUNT)) - 1)) + +void StaticMetadata::HashInit() { + uint32_t k_hash = grpc_slice_hash_internal(kv_.key); + uint32_t v_hash = grpc_slice_hash_internal(kv_.value); + hash_ = GRPC_MDSTR_KV_HASH(k_hash, v_hash); +} + +AllocatedMetadata::AllocatedMetadata(const grpc_slice& key, + const grpc_slice& value) + : RefcountedMdBase(grpc_slice_ref_internal(key), + grpc_slice_ref_internal(value)) { +#ifndef NDEBUG + TraceAtStart("ALLOC_MD"); +#endif +} + +AllocatedMetadata::AllocatedMetadata(const grpc_slice& key, + const grpc_slice& value, const NoRefKey*) + : RefcountedMdBase(key, grpc_slice_ref_internal(value)) { +#ifndef NDEBUG + TraceAtStart("ALLOC_MD_NOREF_KEY"); +#endif +} + +AllocatedMetadata::AllocatedMetadata( + const grpc_core::ManagedMemorySlice& key, + const grpc_core::UnmanagedMemorySlice& value) + : RefcountedMdBase(key, value) { +#ifndef NDEBUG + TraceAtStart("ALLOC_MD_NOREF_KEY_VAL"); +#endif +} + +AllocatedMetadata::AllocatedMetadata( + const grpc_core::ExternallyManagedSlice& key, + const grpc_core::UnmanagedMemorySlice& value) + : RefcountedMdBase(key, value) { +#ifndef NDEBUG + TraceAtStart("ALLOC_MD_NOREF_KEY_VAL"); +#endif +} + +AllocatedMetadata::~AllocatedMetadata() { + grpc_slice_unref_internal(key()); + grpc_slice_unref_internal(value()); + void* user_data = user_data_.data.load(std::memory_order_relaxed); + if (user_data) { + destroy_user_data_func destroy_user_data = + user_data_.destroy_user_data.load(std::memory_order_relaxed); + destroy_user_data(user_data); + } +} + +#ifndef NDEBUG +void grpc_core::RefcountedMdBase::TraceAtStart(const char* tag) { + if (grpc_trace_metadata.enabled()) { + char* key_str = grpc_slice_to_c_string(key()); + char* value_str = grpc_slice_to_c_string(value()); + gpr_log(GPR_DEBUG, "mdelem %s:%p:%" PRIdPTR ": '%s' = '%s'", tag, this, + RefValue(), key_str, value_str); + gpr_free(key_str); + gpr_free(value_str); + } +} +#endif + +InternedMetadata::InternedMetadata(const grpc_slice& key, + const grpc_slice& value, uint32_t hash, + InternedMetadata* next) + : 
RefcountedMdBase(grpc_slice_ref_internal(key), + grpc_slice_ref_internal(value), hash), + link_(next) { +#ifndef NDEBUG + TraceAtStart("INTERNED_MD"); +#endif +} + +InternedMetadata::InternedMetadata(const grpc_slice& key, + const grpc_slice& value, uint32_t hash, + InternedMetadata* next, const NoRefKey*) + : RefcountedMdBase(key, grpc_slice_ref_internal(value), hash), link_(next) { +#ifndef NDEBUG + TraceAtStart("INTERNED_MD_NOREF_KEY"); +#endif +} + +InternedMetadata::~InternedMetadata() { + grpc_slice_unref_internal(key()); + grpc_slice_unref_internal(value()); + void* user_data = user_data_.data.load(std::memory_order_relaxed); + if (user_data) { + destroy_user_data_func destroy_user_data = + user_data_.destroy_user_data.load(std::memory_order_relaxed); + destroy_user_data(user_data); + } +} + +size_t InternedMetadata::CleanupLinkedMetadata( + InternedMetadata::BucketLink* head) { + size_t num_freed = 0; + InternedMetadata::BucketLink* prev_next = head; + InternedMetadata *md, *next; + + for (md = head->next; md; md = next) { + next = md->link_.next; + if (md->AllRefsDropped()) { + prev_next->next = next; + delete md; + num_freed++; + } else { + prev_next = &md->link_; + } + } + return num_freed; +} + +typedef struct mdtab_shard { + gpr_mu mu; + InternedMetadata::BucketLink* elems; + size_t count; + size_t capacity; + /** Estimate of the number of unreferenced mdelems in the hash table. + This will eventually converge to the exact number, but it's instantaneous + accuracy is not guaranteed */ + gpr_atm free_estimate; +} mdtab_shard; + +static mdtab_shard g_shards[SHARD_COUNT]; + +static void gc_mdtab(mdtab_shard* shard); + +void grpc_mdctx_global_init(void) { + /* initialize shards */ + for (size_t i = 0; i < SHARD_COUNT; i++) { + mdtab_shard* shard = &g_shards[i]; + gpr_mu_init(&shard->mu); + shard->count = 0; + gpr_atm_no_barrier_store(&shard->free_estimate, 0); + shard->capacity = INITIAL_SHARD_CAPACITY; + shard->elems = static_cast( + gpr_zalloc(sizeof(*shard->elems) * shard->capacity)); + } +} + +void grpc_mdctx_global_shutdown() { + for (size_t i = 0; i < SHARD_COUNT; i++) { + mdtab_shard* shard = &g_shards[i]; + gpr_mu_destroy(&shard->mu); + gc_mdtab(shard); + if (shard->count != 0) { + gpr_log(GPR_ERROR, "WARNING: %" PRIuPTR " metadata elements were leaked", + shard->count); + for (size_t i = 0; i < shard->capacity; i++) { + for (InternedMetadata* md = shard->elems[i].next; md; + md = md->bucket_next()) { + char* key_str = grpc_slice_to_c_string(md->key()); + char* value_str = grpc_slice_to_c_string(md->value()); + gpr_log(GPR_ERROR, "mdelem '%s' = '%s'", key_str, value_str); + gpr_free(key_str); + gpr_free(value_str); + } + } + if (grpc_iomgr_abort_on_leaks()) { + abort(); + } + } + // For ASAN builds, we don't want to crash here, because that will + // prevent ASAN from providing leak detection information, which is + // far more useful than this simple assertion. 
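  // For reference, the table being swept here is the sharded intern table: an
  // element's combined key/value hash selects one of SHARD_COUNT shards via
  // its low LOG2_SHARD_COUNT bits, and the remaining bits pick a bucket within
  // the shard. A hedged sketch of that mapping, mirroring SHARD_IDX/TABLE_IDX:
  //
  //   uint32_t hash = GRPC_MDSTR_KV_HASH(grpc_slice_hash_internal(key),
  //                                      grpc_slice_hash_internal(value));
  //   mdtab_shard* shard = &g_shards[SHARD_IDX(hash)];    // low 4 bits
  //   size_t bucket = TABLE_IDX(hash, shard->capacity);   // rest, mod capacity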
+#ifndef GRPC_ASAN_ENABLED + GPR_DEBUG_ASSERT(shard->count == 0); +#endif + gpr_free(shard->elems); + } +} + +#ifndef NDEBUG +static int is_mdelem_static(grpc_mdelem e) { + return reinterpret_cast(GRPC_MDELEM_DATA(e)) >= + &grpc_core::g_static_mdelem_table[0] && + reinterpret_cast(GRPC_MDELEM_DATA(e)) < + &grpc_core::g_static_mdelem_table[GRPC_STATIC_MDELEM_COUNT]; +} +#endif + +void InternedMetadata::RefWithShardLocked(mdtab_shard* shard) { +#ifndef NDEBUG + if (grpc_trace_metadata.enabled()) { + char* key_str = grpc_slice_to_c_string(key()); + char* value_str = grpc_slice_to_c_string(value()); + intptr_t value = RefValue(); + gpr_log(__FILE__, __LINE__, GPR_LOG_SEVERITY_DEBUG, + "mdelem REF:%p:%" PRIdPTR "->%" PRIdPTR ": '%s' = '%s'", this, + value, value + 1, key_str, value_str); + gpr_free(key_str); + gpr_free(value_str); + } +#endif + if (FirstRef()) { + gpr_atm_no_barrier_fetch_add(&shard->free_estimate, -1); + } +} + +static void gc_mdtab(mdtab_shard* shard) { + GPR_TIMER_SCOPE("gc_mdtab", 0); + size_t num_freed = 0; + for (size_t i = 0; i < shard->capacity; ++i) { + intptr_t freed = InternedMetadata::CleanupLinkedMetadata(&shard->elems[i]); + num_freed += freed; + shard->count -= freed; + } + gpr_atm_no_barrier_fetch_add(&shard->free_estimate, + -static_cast(num_freed)); +} + +static void grow_mdtab(mdtab_shard* shard) { + GPR_TIMER_SCOPE("grow_mdtab", 0); + + size_t capacity = shard->capacity * 2; + size_t i; + InternedMetadata::BucketLink* mdtab; + InternedMetadata *md, *next; + uint32_t hash; + + mdtab = static_cast( + gpr_zalloc(sizeof(InternedMetadata::BucketLink) * capacity)); + + for (i = 0; i < shard->capacity; i++) { + for (md = shard->elems[i].next; md; md = next) { + size_t idx; + hash = md->hash(); + next = md->bucket_next(); + idx = TABLE_IDX(hash, capacity); + md->set_bucket_next(mdtab[idx].next); + mdtab[idx].next = md; + } + } + gpr_free(shard->elems); + shard->elems = mdtab; + shard->capacity = capacity; +} + +static void rehash_mdtab(mdtab_shard* shard) { + if (gpr_atm_no_barrier_load(&shard->free_estimate) > + static_cast(shard->capacity / 4)) { + gc_mdtab(shard); + } else { + grow_mdtab(shard); + } +} + +template +static grpc_mdelem md_create_maybe_static(const grpc_slice& key, + const grpc_slice& value); +template +static grpc_mdelem md_create_must_intern(const grpc_slice& key, + const grpc_slice& value, + uint32_t hash); + +template +static grpc_mdelem md_create( + const grpc_slice& key, const grpc_slice& value, + grpc_mdelem_data* compatible_external_backing_store) { + // Ensure slices are, in fact, static if we claimed they were. + GPR_DEBUG_ASSERT(!key_definitely_static || + GRPC_IS_STATIC_METADATA_STRING(key)); + GPR_DEBUG_ASSERT(!value_definitely_static || + GRPC_IS_STATIC_METADATA_STRING(value)); + const bool key_is_interned = + key_definitely_static || grpc_slice_is_interned(key); + const bool value_is_interned = + value_definitely_static || grpc_slice_is_interned(value); + // External storage if either slice is not interned and the caller already + // created a backing store. If no backing store, we allocate one. + if (!key_is_interned || !value_is_interned) { + if (compatible_external_backing_store != nullptr) { + // Caller provided backing store. + return GRPC_MAKE_MDELEM(compatible_external_backing_store, + GRPC_MDELEM_STORAGE_EXTERNAL); + } else { + // We allocate backing store. + return key_definitely_static + ? 
GRPC_MAKE_MDELEM( + new AllocatedMetadata( + key, value, + static_cast( + nullptr)), + GRPC_MDELEM_STORAGE_ALLOCATED) + : GRPC_MAKE_MDELEM(new AllocatedMetadata(key, value), + GRPC_MDELEM_STORAGE_ALLOCATED); + } + } + return md_create_maybe_static( + key, value); +} + +template +static grpc_mdelem md_create_maybe_static(const grpc_slice& key, + const grpc_slice& value) { + // Ensure slices are, in fact, static if we claimed they were. + GPR_DEBUG_ASSERT(!key_definitely_static || + GRPC_IS_STATIC_METADATA_STRING(key)); + GPR_DEBUG_ASSERT(!value_definitely_static || + GRPC_IS_STATIC_METADATA_STRING(value)); + GPR_DEBUG_ASSERT(key.refcount != nullptr); + GPR_DEBUG_ASSERT(value.refcount != nullptr); + + const bool key_is_static_mdstr = + key_definitely_static || + key.refcount->GetType() == grpc_slice_refcount::Type::STATIC; + const bool value_is_static_mdstr = + value_definitely_static || + value.refcount->GetType() == grpc_slice_refcount::Type::STATIC; + + const intptr_t kidx = GRPC_STATIC_METADATA_INDEX(key); + + // Not all static slice input yields a statically stored metadata element. + if (key_is_static_mdstr && value_is_static_mdstr) { + grpc_mdelem static_elem = grpc_static_mdelem_for_static_strings( + kidx, GRPC_STATIC_METADATA_INDEX(value)); + if (!GRPC_MDISNULL(static_elem)) { + return static_elem; + } + } + + uint32_t khash = key_definitely_static + ? grpc_static_metadata_hash_values[kidx] + : grpc_slice_hash_refcounted(key); + + uint32_t hash = GRPC_MDSTR_KV_HASH(khash, grpc_slice_hash_refcounted(value)); + return md_create_must_intern(key, value, hash); +} + +template +static grpc_mdelem md_create_must_intern(const grpc_slice& key, + const grpc_slice& value, + uint32_t hash) { + // Here, we know both key and value are both at least interned, and both + // possibly static. We know that anything inside the shared interned table is + // also at least interned (and maybe static). Note that equality for a static + // and interned slice implies that they are both the same exact slice. + // The same applies to a pair of interned slices, or a pair of static slices. + // Rather than run the full equality check, we can therefore just do a pointer + // comparison of the refcounts. + InternedMetadata* md; + mdtab_shard* shard = &g_shards[SHARD_IDX(hash)]; + size_t idx; + + GPR_TIMER_SCOPE("grpc_mdelem_from_metadata_strings", 0); + + gpr_mu_lock(&shard->mu); + + idx = TABLE_IDX(hash, shard->capacity); + /* search for an existing pair */ + for (md = shard->elems[idx].next; md; md = md->bucket_next()) { + if (grpc_slice_static_interned_equal(key, md->key()) && + grpc_slice_static_interned_equal(value, md->value())) { + md->RefWithShardLocked(shard); + gpr_mu_unlock(&shard->mu); + return GRPC_MAKE_MDELEM(md, GRPC_MDELEM_STORAGE_INTERNED); + } + } + + /* not found: create a new pair */ + md = key_definitely_static + ? 
new InternedMetadata( + key, value, hash, shard->elems[idx].next, + static_cast(nullptr)) + : new InternedMetadata(key, value, hash, shard->elems[idx].next); + shard->elems[idx].next = md; + shard->count++; + + if (shard->count > shard->capacity * 2) { + rehash_mdtab(shard); + } + + gpr_mu_unlock(&shard->mu); + + return GRPC_MAKE_MDELEM(md, GRPC_MDELEM_STORAGE_INTERNED); +} + +grpc_mdelem grpc_mdelem_create( + const grpc_slice& key, const grpc_slice& value, + grpc_mdelem_data* compatible_external_backing_store) { + return md_create(key, value, compatible_external_backing_store); +} + +grpc_mdelem grpc_mdelem_create( + const grpc_core::StaticMetadataSlice& key, const grpc_slice& value, + grpc_mdelem_data* compatible_external_backing_store) { + return md_create(key, value, compatible_external_backing_store); +} + +/* Create grpc_mdelem from provided slices. We specify via template parameter + whether we know that the input key is static or not. If it is, we short + circuit various comparisons and a no-op unref. */ +template +static grpc_mdelem md_from_slices(const grpc_slice& key, + const grpc_slice& value) { + // Ensure key is, in fact, static if we claimed it was. + GPR_DEBUG_ASSERT(!key_definitely_static || + GRPC_IS_STATIC_METADATA_STRING(key)); + grpc_mdelem out = md_create(key, value, nullptr); + if (!key_definitely_static) { + grpc_slice_unref_internal(key); + } + grpc_slice_unref_internal(value); + return out; +} + +grpc_mdelem grpc_mdelem_from_slices(const grpc_slice& key, + const grpc_slice& value) { + return md_from_slices(key, value); +} + +grpc_mdelem grpc_mdelem_from_slices(const grpc_core::StaticMetadataSlice& key, + const grpc_slice& value) { + return md_from_slices(key, value); +} + +grpc_mdelem grpc_mdelem_from_slices( + const grpc_core::StaticMetadataSlice& key, + const grpc_core::StaticMetadataSlice& value) { + grpc_mdelem out = md_create_maybe_static(key, value); + return out; +} + +grpc_mdelem grpc_mdelem_from_slices( + const grpc_core::StaticMetadataSlice& key, + const grpc_core::ManagedMemorySlice& value) { + // TODO(arjunroy): We can save the unref if md_create_maybe_static ended up + // creating a new interned metadata. But otherwise - we need this here. + grpc_mdelem out = md_create_maybe_static(key, value); + grpc_slice_unref_internal(value); + return out; +} + +grpc_mdelem grpc_mdelem_from_slices( + const grpc_core::ManagedMemorySlice& key, + const grpc_core::ManagedMemorySlice& value) { + grpc_mdelem out = md_create_maybe_static(key, value); + // TODO(arjunroy): We can save the unref if md_create_maybe_static ended up + // creating a new interned metadata. But otherwise - we need this here. + grpc_slice_unref_internal(key); + grpc_slice_unref_internal(value); + return out; +} + +grpc_mdelem grpc_mdelem_from_grpc_metadata(grpc_metadata* metadata) { + bool key_changed = false; + grpc_slice key_slice = + grpc_slice_maybe_static_intern(metadata->key, &key_changed); + bool value_changed = false; + grpc_slice* unref_slice = nullptr; + grpc_slice value_slice = + grpc_slice_maybe_static_intern(metadata->value, &value_changed); + // If key or value changed, but the other didn't.... AND the other is a NOP + // refcount, then we need to convert it to a slice with a refcount else we run + // the risk of leaving a dangling reference to that metadata on the heap via + // this mdelem. 
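  // Concretely: if the key matches a statically interned metadata string but
  // the value does not, and the value slice carries a no-op refcount (it only
  // points at storage the caller owns), then dropping the caller's backing
  // store below would leave the mdelem aliasing that storage. Copying the
  // no-op slice gives the mdelem its own bytes. A hedged sketch of such an
  // input (assuming "user-agent" is one of the static metadata strings):
  //
  //   grpc_metadata m;
  //   m.key = grpc_slice_from_static_string("user-agent");   // interned -> key_changed
  //   m.value = grpc_slice_from_static_string("my-app/1.0"); // NOP refcount, unchanged
  //   grpc_mdelem md = grpc_mdelem_from_grpc_metadata(&m);   // value is copied here
  //   GRPC_MDELEM_UNREF(md);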
+ if (key_changed && !value_changed && value_slice.refcount != nullptr && + value_slice.refcount->GetType() == grpc_slice_refcount::Type::NOP) { + value_slice = grpc_slice_copy(value_slice); + unref_slice = &value_slice; + value_changed = true; + } else if (!key_changed && value_changed && key_slice.refcount != nullptr && + key_slice.refcount->GetType() == grpc_slice_refcount::Type::NOP) { + key_slice = grpc_slice_copy(key_slice); + unref_slice = &key_slice; + key_changed = true; + } + auto mdelem = + grpc_mdelem_create(key_slice, value_slice, + key_changed || value_changed + ? nullptr + : reinterpret_cast(metadata)); + if (unref_slice != nullptr) grpc_slice_unref_internal(*unref_slice); + return mdelem; +} + +static void* get_user_data(UserData* user_data, void (*destroy_func)(void*)) { + if (user_data->destroy_user_data.load(std::memory_order_acquire) == + destroy_func) { + return user_data->data.load(std::memory_order_relaxed); + } else { + return nullptr; + } +} + +void* grpc_mdelem_get_user_data(grpc_mdelem md, void (*destroy_func)(void*)) { + switch (GRPC_MDELEM_STORAGE(md)) { + case GRPC_MDELEM_STORAGE_EXTERNAL: + return nullptr; + case GRPC_MDELEM_STORAGE_STATIC: + return reinterpret_cast( + grpc_static_mdelem_user_data + [reinterpret_cast( + GRPC_MDELEM_DATA(md)) - + grpc_core::g_static_mdelem_table]); + case GRPC_MDELEM_STORAGE_ALLOCATED: { + auto* am = reinterpret_cast(GRPC_MDELEM_DATA(md)); + return get_user_data(am->user_data(), destroy_func); + } + case GRPC_MDELEM_STORAGE_INTERNED: { + auto* im = reinterpret_cast GRPC_MDELEM_DATA(md); + return get_user_data(im->user_data(), destroy_func); + } + } + GPR_UNREACHABLE_CODE(return nullptr); +} + +static void* set_user_data(UserData* ud, void (*destroy_func)(void*), + void* data) { + GPR_ASSERT((data == nullptr) == (destroy_func == nullptr)); + grpc_core::ReleasableMutexLock lock(&ud->mu_user_data); + if (ud->destroy_user_data.load(std::memory_order_relaxed)) { + /* user data can only be set once */ + lock.Release(); + if (destroy_func != nullptr) { + destroy_func(data); + } + return ud->data.load(std::memory_order_relaxed); + } + ud->data.store(data, std::memory_order_relaxed); + ud->destroy_user_data.store(destroy_func, std::memory_order_release); + return data; +} + +void* grpc_mdelem_set_user_data(grpc_mdelem md, void (*destroy_func)(void*), + void* data) { + switch (GRPC_MDELEM_STORAGE(md)) { + case GRPC_MDELEM_STORAGE_EXTERNAL: + destroy_func(data); + return nullptr; + case GRPC_MDELEM_STORAGE_STATIC: + destroy_func(data); + return reinterpret_cast( + grpc_static_mdelem_user_data + [reinterpret_cast( + GRPC_MDELEM_DATA(md)) - + grpc_core::g_static_mdelem_table]); + case GRPC_MDELEM_STORAGE_ALLOCATED: { + auto* am = reinterpret_cast(GRPC_MDELEM_DATA(md)); + return set_user_data(am->user_data(), destroy_func, data); + } + case GRPC_MDELEM_STORAGE_INTERNED: { + auto* im = reinterpret_cast GRPC_MDELEM_DATA(md); + GPR_DEBUG_ASSERT(!is_mdelem_static(md)); + return set_user_data(im->user_data(), destroy_func, data); + } + } + GPR_UNREACHABLE_CODE(return nullptr); +} + +bool grpc_mdelem_eq(grpc_mdelem a, grpc_mdelem b) { + if (a.payload == b.payload) return true; + if (GRPC_MDELEM_IS_INTERNED(a) && GRPC_MDELEM_IS_INTERNED(b)) return false; + if (GRPC_MDISNULL(a) || GRPC_MDISNULL(b)) return false; + return grpc_slice_eq(GRPC_MDKEY(a), GRPC_MDKEY(b)) && + grpc_slice_eq(GRPC_MDVALUE(a), GRPC_MDVALUE(b)); +} + +static void note_disposed_interned_metadata(uint32_t hash) { + mdtab_shard* shard = &g_shards[SHARD_IDX(hash)]; + 
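  // Bump this shard's (approximate) count of unreferenced interned elements;
  // gc_mdtab()/rehash_mdtab() consult free_estimate to decide whether to
  // garbage-collect before growing the table.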
gpr_atm_no_barrier_fetch_add(&shard->free_estimate, 1); +} + +void grpc_mdelem_do_unref(grpc_mdelem gmd DEBUG_ARGS) { + switch (GRPC_MDELEM_STORAGE(gmd)) { + case GRPC_MDELEM_STORAGE_EXTERNAL: + case GRPC_MDELEM_STORAGE_STATIC: + return; + case GRPC_MDELEM_STORAGE_INTERNED: { + auto* md = reinterpret_cast GRPC_MDELEM_DATA(gmd); + uint32_t hash = md->hash(); + if (GPR_UNLIKELY(md->Unref(FWD_DEBUG_ARGS))) { + /* once the refcount hits zero, some other thread can come along and + free md at any time: it's unsafe from this point on to access it */ + note_disposed_interned_metadata(hash); + } + break; + } + case GRPC_MDELEM_STORAGE_ALLOCATED: { + auto* md = reinterpret_cast GRPC_MDELEM_DATA(gmd); + if (GPR_UNLIKELY(md->Unref(FWD_DEBUG_ARGS))) { + delete md; + } + break; + } + } +} + +void grpc_mdelem_on_final_unref(grpc_mdelem_data_storage storage, void* ptr, + uint32_t hash DEBUG_ARGS) { +#ifndef NDEBUG + (void)file; + (void)line; +#endif + switch (storage) { + case GRPC_MDELEM_STORAGE_EXTERNAL: + case GRPC_MDELEM_STORAGE_STATIC: + return; + case GRPC_MDELEM_STORAGE_INTERNED: { + note_disposed_interned_metadata(hash); + break; + } + case GRPC_MDELEM_STORAGE_ALLOCATED: { + delete reinterpret_cast(ptr); + break; + } + } +} diff --git a/src/core/lib/transport/metadata_batch.cc b/src/core/lib/transport/metadata_batch.cc new file mode 100644 index 00000000..c651faa7 --- /dev/null +++ b/src/core/lib/transport/metadata_batch.cc @@ -0,0 +1,94 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/metadata_batch.h" + +#include +#include + +#include "absl/container/inlined_vector.h" + +#include +#include + +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" + +void grpc_metadata_batch_set_value(grpc_linked_mdelem* storage, + const grpc_slice& value) { + grpc_mdelem old_mdelem = storage->md; + grpc_mdelem new_mdelem = grpc_mdelem_from_slices( + grpc_slice_ref_internal(GRPC_MDKEY(old_mdelem)), value); + storage->md = new_mdelem; + GRPC_MDELEM_UNREF(old_mdelem); +} + +namespace { + +class CopySink { + public: + explicit CopySink(grpc_metadata_batch* dst) : dst_(dst) {} + + void Encode(grpc_mdelem md) { + // If the mdelem is not external, take a ref. + // Otherwise, create a new copy, holding its own refs to the + // underlying slices. + if (GRPC_MDELEM_STORAGE(md) != GRPC_MDELEM_STORAGE_EXTERNAL) { + md = GRPC_MDELEM_REF(md); + } else { + md = grpc_mdelem_from_slices(grpc_slice_copy(GRPC_MDKEY(md)), + grpc_slice_copy(GRPC_MDVALUE(md))); + } + // Error unused in non-debug builds. + grpc_error_handle GRPC_UNUSED error = dst_->Append(md); + // The only way that Append() can fail is if + // there's a duplicate entry for a callout. However, that can't be + // the case here, because we would not have been allowed to create + // a source batch that had that kind of conflict. 
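  // CopySink is consumed only by grpc_metadata_batch_copy() below. A hedged
  // usage sketch (construction details are an assumption; in this revision
  // the batch type is arena-backed):
  //
  //   grpc_core::Arena* arena = /* obtained from the call or channel */;
  //   grpc_metadata_batch src(arena), dst(arena);
  //   /* populate src */
  //   grpc_metadata_batch_copy(&src, &dst);  // dst gets refs/copies of src's elements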
+ GPR_DEBUG_ASSERT(error == GRPC_ERROR_NONE); + } + + template + void Encode(T trait, V value) { + dst_->Set(trait, value); + } + + private: + grpc_metadata_batch* dst_; +}; + +} // namespace + +void grpc_metadata_batch_copy(const grpc_metadata_batch* src, + grpc_metadata_batch* dst) { + dst->Clear(); + CopySink sink(dst); + src->Encode(&sink); +} + +grpc_error_handle grpc_attach_md_to_error(grpc_error_handle src, + grpc_mdelem md) { + grpc_error_handle out = grpc_error_set_str( + grpc_error_set_str(src, GRPC_ERROR_STR_KEY, + grpc_core::StringViewFromSlice(GRPC_MDKEY(md))), + GRPC_ERROR_STR_VALUE, grpc_core::StringViewFromSlice(GRPC_MDVALUE(md))); + return out; +} diff --git a/src/core/lib/transport/pid_controller.cc b/src/core/lib/transport/pid_controller.cc new file mode 100644 index 00000000..c187a2be --- /dev/null +++ b/src/core/lib/transport/pid_controller.cc @@ -0,0 +1,51 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/pid_controller.h" + +#include "src/core/lib/gpr/useful.h" + +namespace grpc_core { + +PidController::PidController(const Args& args) + : last_control_value_(args.initial_control_value()), args_(args) {} + +double PidController::Update(double error, double dt) { + if (dt <= 0) return last_control_value_; + /* integrate error using the trapezoid rule */ + error_integral_ += dt * (last_error_ + error) * 0.5; + error_integral_ = grpc_core::Clamp(error_integral_, -args_.integral_range(), + args_.integral_range()); + double diff_error = (error - last_error_) / dt; + /* calculate derivative of control value vs time */ + double dc_dt = args_.gain_p() * error + args_.gain_i() * error_integral_ + + args_.gain_d() * diff_error; + /* and perform trapezoidal integration */ + double new_control_value = + last_control_value_ + dt * (last_dc_dt_ + dc_dt) * 0.5; + new_control_value = grpc_core::Clamp( + new_control_value, args_.min_control_value(), args_.max_control_value()); + last_error_ = error; + last_dc_dt_ = dc_dt; + last_control_value_ = new_control_value; + return new_control_value; +} + +} // namespace grpc_core diff --git a/src/core/lib/transport/static_metadata.cc b/src/core/lib/transport/static_metadata.cc new file mode 100644 index 00000000..71f4deef --- /dev/null +++ b/src/core/lib/transport/static_metadata.cc @@ -0,0 +1,1117 @@ +/* + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * WARNING: Auto-generated code. 
+ * + * To make changes to this file, change + * tools/codegen/core/gen_static_metadata.py, and then re-run it. + * + * See metadata.h for an explanation of the interface here, and metadata.cc for + * an explanation of what's going on. + */ + +#include + +#include "src/core/lib/transport/static_metadata.h" + +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { +StaticMetadata g_static_mdelem_table[GRPC_STATIC_MDELEM_COUNT] = { + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[3].base, 10, + g_static_metadata_bytes + 19), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 0), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[1].base, 7, + g_static_metadata_bytes + 5), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[44].base, 3, + g_static_metadata_bytes + 849), + 1), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[1].base, 7, + g_static_metadata_bytes + 5), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[45].base, 4, + g_static_metadata_bytes + 852), + 2), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[0].base, 5, + g_static_metadata_bytes + 0), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[46].base, 1, + g_static_metadata_bytes + 856), + 3), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[0].base, 5, + g_static_metadata_bytes + 0), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[47].base, 11, + g_static_metadata_bytes + 857), + 4), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[4].base, 7, + g_static_metadata_bytes + 29), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[48].base, 4, + g_static_metadata_bytes + 868), + 5), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[4].base, 7, + g_static_metadata_bytes + 29), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[49].base, 5, + g_static_metadata_bytes + 872), + 6), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[50].base, 3, + g_static_metadata_bytes + 877), + 7), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[51].base, 3, + g_static_metadata_bytes + 880), + 8), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[52].base, 3, + g_static_metadata_bytes + 883), + 9), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[53].base, 3, + g_static_metadata_bytes + 886), + 10), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[54].base, 3, + g_static_metadata_bytes + 889), + 11), + 
StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[55].base, 3, + g_static_metadata_bytes + 892), + 12), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[2].base, 7, + g_static_metadata_bytes + 12), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[56].base, 3, + g_static_metadata_bytes + 895), + 13), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[57].base, 14, + g_static_metadata_bytes + 898), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 14), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[15].base, 15, + g_static_metadata_bytes + 184), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[58].base, 13, + g_static_metadata_bytes + 912), + 15), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[59].base, 15, + g_static_metadata_bytes + 925), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 16), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[60].base, 13, + g_static_metadata_bytes + 940), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 17), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[61].base, 6, + g_static_metadata_bytes + 953), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 18), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[62].base, 27, + g_static_metadata_bytes + 959), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 19), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[63].base, 3, + g_static_metadata_bytes + 986), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 20), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[64].base, 5, + g_static_metadata_bytes + 989), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 21), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[65].base, 13, + g_static_metadata_bytes + 994), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 22), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[66].base, 13, + g_static_metadata_bytes + 1007), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 23), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[67].base, 19, + g_static_metadata_bytes + 1020), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 24), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[14].base, 16, + g_static_metadata_bytes + 168), + grpc_core::StaticMetadataSlice( + 
&g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 25), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[68].base, 16, + g_static_metadata_bytes + 1039), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 26), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[69].base, 14, + g_static_metadata_bytes + 1055), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 27), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[70].base, 16, + g_static_metadata_bytes + 1069), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 28), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[71].base, 13, + g_static_metadata_bytes + 1085), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 29), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[13].base, 12, + g_static_metadata_bytes + 156), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 30), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[72].base, 6, + g_static_metadata_bytes + 1098), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 31), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[73].base, 4, + g_static_metadata_bytes + 1104), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 32), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[74].base, 4, + g_static_metadata_bytes + 1108), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 33), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[75].base, 6, + g_static_metadata_bytes + 1112), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 34), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[76].base, 7, + g_static_metadata_bytes + 1118), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 35), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[77].base, 4, + g_static_metadata_bytes + 1125), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 36), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[19].base, 4, + g_static_metadata_bytes + 276), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 37), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[78].base, 8, + g_static_metadata_bytes + 1129), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 38), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[79].base, 
17, + g_static_metadata_bytes + 1137), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 39), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[80].base, 13, + g_static_metadata_bytes + 1154), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 40), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[81].base, 8, + g_static_metadata_bytes + 1167), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 41), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[82].base, 19, + g_static_metadata_bytes + 1175), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 42), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[83].base, 13, + g_static_metadata_bytes + 1194), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 43), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[84].base, 4, + g_static_metadata_bytes + 1207), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 44), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[85].base, 8, + g_static_metadata_bytes + 1211), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 45), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[86].base, 12, + g_static_metadata_bytes + 1219), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 46), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[87].base, 18, + g_static_metadata_bytes + 1231), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 47), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[88].base, 19, + g_static_metadata_bytes + 1249), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 48), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[89].base, 5, + g_static_metadata_bytes + 1268), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 49), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[90].base, 7, + g_static_metadata_bytes + 1273), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 50), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[91].base, 7, + g_static_metadata_bytes + 1280), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 51), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[92].base, 11, + g_static_metadata_bytes + 1287), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 52), + 
StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[93].base, 6, + g_static_metadata_bytes + 1298), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 53), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[94].base, 10, + g_static_metadata_bytes + 1304), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 54), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[95].base, 25, + g_static_metadata_bytes + 1314), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 55), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[96].base, 17, + g_static_metadata_bytes + 1339), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 56), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[18].base, 10, + g_static_metadata_bytes + 266), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 57), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[97].base, 4, + g_static_metadata_bytes + 1356), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 58), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[98].base, 3, + g_static_metadata_bytes + 1360), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 59), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[99].base, 16, + g_static_metadata_bytes + 1363), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 60), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[6].base, 11, + g_static_metadata_bytes + 48), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[100].base, 1, + g_static_metadata_bytes + 1379), + 61), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[6].base, 11, + g_static_metadata_bytes + 48), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[24].base, 1, + g_static_metadata_bytes + 367), + 62), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[6].base, 11, + g_static_metadata_bytes + 48), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[25].base, 1, + g_static_metadata_bytes + 368), + 63), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[8].base, 13, + g_static_metadata_bytes + 75), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[101].base, 8, + g_static_metadata_bytes + 1380), + 64), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[8].base, 13, + g_static_metadata_bytes + 75), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[40].base, 4, + g_static_metadata_bytes + 824), + 65), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[8].base, 13, + g_static_metadata_bytes + 75), + grpc_core::StaticMetadataSlice( + 
&g_static_metadata_slice_refcounts[39].base, 7, + g_static_metadata_bytes + 817), + 66), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[13].base, 12, + g_static_metadata_bytes + 156), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[102].base, 16, + g_static_metadata_bytes + 1388), + 67), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[4].base, 7, + g_static_metadata_bytes + 29), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[103].base, 4, + g_static_metadata_bytes + 1404), + 68), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[1].base, 7, + g_static_metadata_bytes + 5), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[104].base, 3, + g_static_metadata_bytes + 1408), + 69), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[15].base, 15, + g_static_metadata_bytes + 184), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 70), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[14].base, 16, + g_static_metadata_bytes + 168), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[101].base, 8, + g_static_metadata_bytes + 1380), + 71), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[14].base, 16, + g_static_metadata_bytes + 168), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[40].base, 4, + g_static_metadata_bytes + 824), + 72), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[105].base, 11, + g_static_metadata_bytes + 1411), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[28].base, 0, + g_static_metadata_bytes + 371), + 73), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[101].base, 8, + g_static_metadata_bytes + 1380), + 74), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[39].base, 7, + g_static_metadata_bytes + 817), + 75), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[106].base, 16, + g_static_metadata_bytes + 1422), + 76), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[40].base, 4, + g_static_metadata_bytes + 824), + 77), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[107].base, 13, + g_static_metadata_bytes + 1438), + 78), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[108].base, 12, + g_static_metadata_bytes + 1451), + 79), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[9].base, 
20, + g_static_metadata_bytes + 88), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[109].base, 21, + g_static_metadata_bytes + 1463), + 80), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[15].base, 15, + g_static_metadata_bytes + 184), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[101].base, 8, + g_static_metadata_bytes + 1380), + 81), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[15].base, 15, + g_static_metadata_bytes + 184), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[40].base, 4, + g_static_metadata_bytes + 824), + 82), + StaticMetadata(grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[15].base, 15, + g_static_metadata_bytes + 184), + grpc_core::StaticMetadataSlice( + &g_static_metadata_slice_refcounts[107].base, 13, + g_static_metadata_bytes + 1438), + 83), +}; + +/* Warning: the core static metadata currently operates under the soft +constraint that the first GRPC_CHTTP2_LAST_STATIC_ENTRY (61) entries must +contain metadata specified by the http2 hpack standard. The CHTTP2 transport +reads the core metadata with this assumption in mind. If the order of the core +static metadata is to be changed, then the CHTTP2 transport must be changed as +well to stop relying on the core metadata. */ + +grpc_mdelem g_static_mdelem_manifested[GRPC_STATIC_MDELEM_COUNT] = { + // clang-format off + /* GRPC_MDELEM_AUTHORITY_EMPTY: + ":authority": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[0].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_METHOD_GET: + ":method": "GET" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[1].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_METHOD_POST: + ":method": "POST" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[2].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_PATH_SLASH: + ":path": "/" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[3].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_PATH_SLASH_INDEX_DOT_HTML: + ":path": "/index.html" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[4].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_SCHEME_HTTP: + ":scheme": "http" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[5].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_SCHEME_HTTPS: + ":scheme": "https" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[6].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STATUS_200: + ":status": "200" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[7].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STATUS_204: + ":status": "204" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[8].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STATUS_206: + ":status": "206" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[9].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STATUS_304: + ":status": "304" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[10].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STATUS_400: + ":status": "400" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[11].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STATUS_404: + ":status": "404" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[12].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STATUS_500: + ":status": "500" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[13].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_CHARSET_EMPTY: + "accept-charset": "" */ + GRPC_MAKE_MDELEM( + 
&g_static_mdelem_table[14].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_ENCODING_GZIP_COMMA_DEFLATE: + "accept-encoding": "gzip, deflate" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[15].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_LANGUAGE_EMPTY: + "accept-language": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[16].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_RANGES_EMPTY: + "accept-ranges": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[17].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_EMPTY: + "accept": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[18].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCESS_CONTROL_ALLOW_ORIGIN_EMPTY: + "access-control-allow-origin": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[19].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_AGE_EMPTY: + "age": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[20].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ALLOW_EMPTY: + "allow": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[21].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_AUTHORIZATION_EMPTY: + "authorization": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[22].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CACHE_CONTROL_EMPTY: + "cache-control": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[23].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_DISPOSITION_EMPTY: + "content-disposition": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[24].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_ENCODING_EMPTY: + "content-encoding": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[25].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_LANGUAGE_EMPTY: + "content-language": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[26].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_LENGTH_EMPTY: + "content-length": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[27].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_LOCATION_EMPTY: + "content-location": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[28].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_RANGE_EMPTY: + "content-range": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[29].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_TYPE_EMPTY: + "content-type": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[30].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_COOKIE_EMPTY: + "cookie": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[31].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_DATE_EMPTY: + "date": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[32].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ETAG_EMPTY: + "etag": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[33].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_EXPECT_EMPTY: + "expect": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[34].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_EXPIRES_EMPTY: + "expires": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[35].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_FROM_EMPTY: + "from": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[36].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_HOST_EMPTY: + "host": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[37].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_IF_MATCH_EMPTY: + "if-match": "" */ + GRPC_MAKE_MDELEM( + 
&g_static_mdelem_table[38].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_IF_MODIFIED_SINCE_EMPTY: + "if-modified-since": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[39].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_IF_NONE_MATCH_EMPTY: + "if-none-match": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[40].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_IF_RANGE_EMPTY: + "if-range": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[41].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_IF_UNMODIFIED_SINCE_EMPTY: + "if-unmodified-since": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[42].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_LAST_MODIFIED_EMPTY: + "last-modified": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[43].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_LINK_EMPTY: + "link": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[44].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_LOCATION_EMPTY: + "location": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[45].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_MAX_FORWARDS_EMPTY: + "max-forwards": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[46].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_PROXY_AUTHENTICATE_EMPTY: + "proxy-authenticate": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[47].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_PROXY_AUTHORIZATION_EMPTY: + "proxy-authorization": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[48].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_RANGE_EMPTY: + "range": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[49].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_REFERER_EMPTY: + "referer": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[50].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_REFRESH_EMPTY: + "refresh": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[51].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_RETRY_AFTER_EMPTY: + "retry-after": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[52].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_SERVER_EMPTY: + "server": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[53].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_SET_COOKIE_EMPTY: + "set-cookie": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[54].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_STRICT_TRANSPORT_SECURITY_EMPTY: + "strict-transport-security": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[55].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_TRANSFER_ENCODING_EMPTY: + "transfer-encoding": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[56].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_USER_AGENT_EMPTY: + "user-agent": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[57].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_VARY_EMPTY: + "vary": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[58].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_VIA_EMPTY: + "via": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[59].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_WWW_AUTHENTICATE_EMPTY: + "www-authenticate": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[60].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_STATUS_0: + "grpc-status": "0" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[61].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_STATUS_1: + "grpc-status": "1" */ + GRPC_MAKE_MDELEM( + 
&g_static_mdelem_table[62].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_STATUS_2: + "grpc-status": "2" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[63].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ENCODING_IDENTITY: + "grpc-encoding": "identity" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[64].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ENCODING_GZIP: + "grpc-encoding": "gzip" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[65].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ENCODING_DEFLATE: + "grpc-encoding": "deflate" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[66].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC: + "content-type": "application/grpc" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[67].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_SCHEME_GRPC: + ":scheme": "grpc" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[68].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_METHOD_PUT: + ":method": "PUT" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[69].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_ENCODING_EMPTY: + "accept-encoding": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[70].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_ENCODING_IDENTITY: + "content-encoding": "identity" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[71].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_CONTENT_ENCODING_GZIP: + "content-encoding": "gzip" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[72].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_LB_COST_BIN_EMPTY: + "lb-cost-bin": "" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[73].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY: + "grpc-accept-encoding": "identity" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[74].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ACCEPT_ENCODING_DEFLATE: + "grpc-accept-encoding": "deflate" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[75].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_DEFLATE: + "grpc-accept-encoding": "identity,deflate" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[76].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ACCEPT_ENCODING_GZIP: + "grpc-accept-encoding": "gzip" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[77].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_GZIP: + "grpc-accept-encoding": "identity,gzip" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[78].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ACCEPT_ENCODING_DEFLATE_COMMA_GZIP: + "grpc-accept-encoding": "deflate,gzip" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[79].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_DEFLATE_COMMA_GZIP: + "grpc-accept-encoding": "identity,deflate,gzip" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[80].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_ENCODING_IDENTITY: + "accept-encoding": "identity" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[81].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_ENCODING_GZIP: + "accept-encoding": "gzip" */ + GRPC_MAKE_MDELEM( + &g_static_mdelem_table[82].data(), + GRPC_MDELEM_STORAGE_STATIC), + /* GRPC_MDELEM_ACCEPT_ENCODING_IDENTITY_COMMA_GZIP: + "accept-encoding": "identity,gzip" */ + GRPC_MAKE_MDELEM( + 
&g_static_mdelem_table[83].data(), + GRPC_MDELEM_STORAGE_STATIC) + // clang-format on +}; +} // namespace grpc_core + +uintptr_t grpc_static_mdelem_user_data[GRPC_STATIC_MDELEM_COUNT] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 4, 6, 6, 8, 8, 2, 4, 4}; + +static const int8_t elems_r[] = { + 18, 11, -8, 0, 3, -75, -51, 0, 7, -4, 0, 0, 0, 12, -1, -2, + 0, 0, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, -50, 0, -55, -75, -76, -77, 0, + 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 28, 19, 18, 17, 16, 17, + 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, + 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, -11, 0}; +static uint32_t elems_phash(uint32_t i) { + i -= 46; + uint32_t x = i % 108; + uint32_t y = i / 108; + uint32_t h = x; + if (y < GPR_ARRAY_SIZE(elems_r)) { + uint32_t delta = static_cast(elems_r[y]); + h += delta; + } + return h; +} + +static const uint16_t elem_keys[] = { + 270, 271, 272, 273, 274, 275, 276, 1029, 1030, 1568, 1678, + 154, 155, 488, 489, 760, 919, 920, 46, 47, 1458, 1580, + 1690, 684, 685, 2008, 2118, 6628, 6738, 6848, 6958, 7068, 7178, + 7288, 7398, 7508, 7618, 7728, 7838, 7948, 1708, 8168, 8278, 8388, + 8498, 6518, 6298, 8608, 8058, 8718, 8828, 8938, 9048, 9158, 9268, + 9378, 9488, 9598, 9708, 9818, 9928, 10038, 10148, 10258, 10368, 10478, + 10588, 10698, 543, 1091, 10808, 214, 10918, 11578, 1096, 1097, 1098, + 1099, 981, 0, 0, 0, 1641, 1751, 0, 0, 0, 0, + 358, 1757, 0, 0, 0, 0, 1532}; +static const uint8_t elem_idxs[] = { + 7, 8, 9, 10, 11, 12, 13, 75, 77, 25, 70, 1, 2, 5, 6, 61, + 66, 65, 3, 4, 30, 72, 82, 62, 63, 57, 37, 17, 18, 19, 20, 21, + 22, 23, 24, 26, 27, 28, 29, 31, 15, 33, 34, 35, 36, 16, 14, 38, + 32, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 58, 68, 74, 59, 69, 60, 73, 76, 78, 79, 80, 64, 255, + 255, 255, 71, 81, 255, 255, 255, 255, 0, 83, 255, 255, 255, 255, 67}; + +grpc_mdelem grpc_static_mdelem_for_static_strings(intptr_t a, intptr_t b) { + if (a == -1 || b == -1) return GRPC_MDNULL; + uint32_t k = static_cast(a * 110 + b); + uint32_t h = elems_phash(k); + return h < GPR_ARRAY_SIZE(elem_keys) && elem_keys[h] == k && + elem_idxs[h] != 255 + ? GRPC_MAKE_MDELEM( + &grpc_core::g_static_mdelem_table[elem_idxs[h]].data(), + GRPC_MDELEM_STORAGE_STATIC) + : GRPC_MDNULL; +} + +const uint8_t grpc_static_accept_encoding_metadata[8] = {0, 74, 75, 76, + 77, 78, 79, 80}; + +const uint8_t grpc_static_accept_stream_encoding_metadata[4] = {0, 81, 82, 83}; diff --git a/src/core/lib/transport/status_conversion.cc b/src/core/lib/transport/status_conversion.cc new file mode 100644 index 00000000..46c6cd8c --- /dev/null +++ b/src/core/lib/transport/status_conversion.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/transport/status_conversion.h" + +grpc_http2_error_code grpc_status_to_http2_error(grpc_status_code status) { + switch (status) { + case GRPC_STATUS_OK: + return GRPC_HTTP2_NO_ERROR; + case GRPC_STATUS_CANCELLED: + return GRPC_HTTP2_CANCEL; + case GRPC_STATUS_DEADLINE_EXCEEDED: + return GRPC_HTTP2_CANCEL; + case GRPC_STATUS_RESOURCE_EXHAUSTED: + return GRPC_HTTP2_ENHANCE_YOUR_CALM; + case GRPC_STATUS_PERMISSION_DENIED: + return GRPC_HTTP2_INADEQUATE_SECURITY; + case GRPC_STATUS_UNAVAILABLE: + return GRPC_HTTP2_REFUSED_STREAM; + default: + return GRPC_HTTP2_INTERNAL_ERROR; + } +} + +grpc_status_code grpc_http2_error_to_grpc_status(grpc_http2_error_code error, + grpc_millis deadline) { + switch (error) { + case GRPC_HTTP2_NO_ERROR: + /* should never be received */ + return GRPC_STATUS_INTERNAL; + case GRPC_HTTP2_CANCEL: + /* http2 cancel translates to STATUS_CANCELLED iff deadline hasn't been + * exceeded */ + return grpc_core::ExecCtx::Get()->Now() > deadline + ? GRPC_STATUS_DEADLINE_EXCEEDED + : GRPC_STATUS_CANCELLED; + case GRPC_HTTP2_ENHANCE_YOUR_CALM: + return GRPC_STATUS_RESOURCE_EXHAUSTED; + case GRPC_HTTP2_INADEQUATE_SECURITY: + return GRPC_STATUS_PERMISSION_DENIED; + case GRPC_HTTP2_REFUSED_STREAM: + return GRPC_STATUS_UNAVAILABLE; + default: + return GRPC_STATUS_INTERNAL; + } +} + +grpc_status_code grpc_http2_status_to_grpc_status(int status) { + switch (status) { + /* these HTTP2 status codes are called out explicitly in status.proto */ + case 200: + return GRPC_STATUS_OK; + case 400: + return GRPC_STATUS_INTERNAL; + case 401: + return GRPC_STATUS_UNAUTHENTICATED; + case 403: + return GRPC_STATUS_PERMISSION_DENIED; + case 404: + return GRPC_STATUS_UNIMPLEMENTED; + case 429: + return GRPC_STATUS_UNAVAILABLE; + case 502: + return GRPC_STATUS_UNAVAILABLE; + case 503: + return GRPC_STATUS_UNAVAILABLE; + case 504: + return GRPC_STATUS_UNAVAILABLE; + /* everything else is unknown */ + default: + return GRPC_STATUS_UNKNOWN; + } +} + +int grpc_status_to_http2_status(grpc_status_code /*status*/) { return 200; } diff --git a/src/core/lib/transport/status_metadata.cc b/src/core/lib/transport/status_metadata.cc new file mode 100644 index 00000000..fb706464 --- /dev/null +++ b/src/core/lib/transport/status_metadata.cc @@ -0,0 +1,63 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include
+
+#include "src/core/lib/transport/status_metadata.h"
+
+#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/slice/slice_string_helpers.h"
+#include "src/core/lib/transport/static_metadata.h"
+
+/* we offset status by a small amount when storing it into transport metadata
+   as metadata cannot store a 0 value (which is used as OK for
+   grpc_status_codes) */
+#define STATUS_OFFSET 1
+
+static void destroy_status(void* /*ignored*/) {}
+
+grpc_status_code grpc_get_status_code_from_metadata(grpc_mdelem md) {
+  if (grpc_mdelem_static_value_eq(md, GRPC_MDELEM_GRPC_STATUS_0)) {
+    return GRPC_STATUS_OK;
+  }
+  if (grpc_mdelem_static_value_eq(md, GRPC_MDELEM_GRPC_STATUS_1)) {
+    return GRPC_STATUS_CANCELLED;
+  }
+  if (grpc_mdelem_static_value_eq(md, GRPC_MDELEM_GRPC_STATUS_2)) {
+    return GRPC_STATUS_UNKNOWN;
+  }
+  void* user_data = grpc_mdelem_get_user_data(md, destroy_status);
+  if (user_data != nullptr) {
+    return static_cast<grpc_status_code>(
+        reinterpret_cast<intptr_t>(user_data) - STATUS_OFFSET);
+  }
+  uint32_t status;
+  if (!grpc_parse_slice_to_uint32(GRPC_MDVALUE(md), &status)) {
+    status = GRPC_STATUS_UNKNOWN; /* could not parse status code */
+  }
+  grpc_mdelem_set_user_data(md, destroy_status,
+                            reinterpret_cast<void*>(status + STATUS_OFFSET));
+  return static_cast<grpc_status_code>(status);
+}
+
+grpc_mdelem grpc_get_reffed_status_elem_slowpath(int status_code) {
+  char tmp[GPR_LTOA_MIN_BUFSIZE];
+  gpr_ltoa(status_code, tmp);
+  return grpc_mdelem_from_slices(GRPC_MDSTR_GRPC_STATUS,
+                                 grpc_core::UnmanagedMemorySlice(tmp));
+}
diff --git a/src/core/lib/transport/timeout_encoding.cc b/src/core/lib/transport/timeout_encoding.cc
new file mode 100644
index 00000000..27400638
--- /dev/null
+++ b/src/core/lib/transport/timeout_encoding.cc
@@ -0,0 +1,151 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +#include "src/core/lib/transport/timeout_encoding.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" + +static int64_t round_up(int64_t x, int64_t divisor) { + return (x / divisor + (x % divisor != 0)) * divisor; +} + +/* round an integer up to the next value with three significant figures */ +static int64_t round_up_to_three_sig_figs(int64_t x) { + if (x < 1000) return x; + if (x < 10000) return round_up(x, 10); + if (x < 100000) return round_up(x, 100); + if (x < 1000000) return round_up(x, 1000); + if (x < 10000000) return round_up(x, 10000); + if (x < 100000000) return round_up(x, 100000); + if (x < 1000000000) return round_up(x, 1000000); + return round_up(x, 10000000); +} + +/* encode our minimum viable timeout value */ +static void enc_tiny(char* buffer) { memcpy(buffer, "1n", 3); } + +/* encode our maximum timeout value, about 1157 days */ +static void enc_huge(char* buffer) { memcpy(buffer, "99999999S", 10); } + +static void enc_ext(char* buffer, int64_t value, char ext) { + int n = int64_ttoa(value, buffer); + buffer[n] = ext; + buffer[n + 1] = 0; +} + +static void enc_seconds(char* buffer, int64_t sec) { + sec = round_up_to_three_sig_figs(sec); + if (sec % 3600 == 0) { + enc_ext(buffer, sec / 3600, 'H'); + } else if (sec % 60 == 0) { + enc_ext(buffer, sec / 60, 'M'); + } else { + enc_ext(buffer, sec, 'S'); + } +} + +static void enc_millis(char* buffer, int64_t x) { + x = round_up_to_three_sig_figs(x); + if (x < GPR_MS_PER_SEC) { + enc_ext(buffer, x, 'm'); + } else { + if (x % GPR_MS_PER_SEC == 0) { + enc_seconds(buffer, x / GPR_MS_PER_SEC); + } else { + enc_ext(buffer, x, 'm'); + } + } +} + +void grpc_http2_encode_timeout(grpc_millis timeout, char* buffer) { + const grpc_millis kMaxTimeout = 99999999000; + if (timeout <= 0) { + enc_tiny(buffer); + } else if (timeout < 1000 * GPR_MS_PER_SEC) { + enc_millis(buffer, timeout); + } else if (timeout >= kMaxTimeout) { + enc_huge(buffer); + } else { + enc_seconds(buffer, + timeout / GPR_MS_PER_SEC + (timeout % GPR_MS_PER_SEC != 0)); + } +} + +static int is_all_whitespace(const char* p, const char* end) { + while (p != end && *p == ' ') p++; + return p == end; +} + +int grpc_http2_decode_timeout(const grpc_slice& text, grpc_millis* timeout) { + grpc_millis x = 0; + const uint8_t* p = GRPC_SLICE_START_PTR(text); + const uint8_t* end = GRPC_SLICE_END_PTR(text); + int have_digit = 0; + /* skip whitespace */ + for (; p != end && *p == ' '; p++) { + } + /* decode numeric part */ + for (; p != end && *p >= '0' && *p <= '9'; p++) { + int32_t digit = static_cast(*p - static_cast('0')); + have_digit = 1; + /* spec allows max. 
8 digits, but we allow values up to 1,000,000,000 */ + if (x >= (100 * 1000 * 1000)) { + if (x != (100 * 1000 * 1000) || digit != 0) { + *timeout = GRPC_MILLIS_INF_FUTURE; + return 1; + } + } + x = x * 10 + digit; + } + if (!have_digit) return 0; + /* skip whitespace */ + for (; p != end && *p == ' '; p++) { + } + if (p == end) return 0; + /* decode unit specifier */ + switch (*p) { + case 'n': + *timeout = x / GPR_NS_PER_MS + (x % GPR_NS_PER_MS != 0); + break; + case 'u': + *timeout = x / GPR_US_PER_MS + (x % GPR_US_PER_MS != 0); + break; + case 'm': + *timeout = x; + break; + case 'S': + *timeout = x * GPR_MS_PER_SEC; + break; + case 'M': + *timeout = x * 60 * GPR_MS_PER_SEC; + break; + case 'H': + *timeout = x * 60 * 60 * GPR_MS_PER_SEC; + break; + default: + return 0; + } + p++; + return is_all_whitespace(reinterpret_cast(p), + reinterpret_cast(end)); +} diff --git a/src/core/lib/transport/transport.cc b/src/core/lib/transport/transport.cc new file mode 100644 index 00000000..5555b62e --- /dev/null +++ b/src/core/lib/transport/transport.cc @@ -0,0 +1,261 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/transport/transport.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/alloc.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/transport_impl.h" + +grpc_core::DebugOnlyTraceFlag grpc_trace_stream_refcount(false, + "stream_refcount"); + +void grpc_stream_destroy(grpc_stream_refcount* refcount) { + if (!grpc_iomgr_is_any_background_poller_thread() && + (grpc_core::ExecCtx::Get()->flags() & + GRPC_EXEC_CTX_FLAG_THREAD_RESOURCE_LOOP)) { + /* Ick. + The thread we're running on MAY be owned (indirectly) by a call-stack. + If that's the case, destroying the call-stack MAY try to destroy the + thread, which is a tangled mess that we just don't want to ever have to + cope with. + Throw this over to the executor (on a core-owned thread) and process it + there. 
*/ + grpc_core::Executor::Run(&refcount->destroy, GRPC_ERROR_NONE); + } else { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &refcount->destroy, + GRPC_ERROR_NONE); + } +} + +void slice_stream_destroy(void* arg) { + grpc_stream_destroy(static_cast(arg)); +} + +#define STREAM_REF_FROM_SLICE_REF(p) \ + ((grpc_stream_refcount*)(((uint8_t*)(p)) - \ + offsetof(grpc_stream_refcount, slice_refcount))) + +grpc_slice grpc_slice_from_stream_owned_buffer(grpc_stream_refcount* refcount, + void* buffer, size_t length) { +#ifndef NDEBUG + grpc_stream_ref(STREAM_REF_FROM_SLICE_REF(&refcount->slice_refcount), + "slice"); +#else + grpc_stream_ref(STREAM_REF_FROM_SLICE_REF(&refcount->slice_refcount)); +#endif + grpc_slice res; + res.refcount = &refcount->slice_refcount; + res.data.refcounted.bytes = static_cast(buffer); + res.data.refcounted.length = length; + return res; +} + +#ifndef NDEBUG +void grpc_stream_ref_init(grpc_stream_refcount* refcount, int /*initial_refs*/, + grpc_iomgr_cb_func cb, void* cb_arg, + const char* object_type) { + refcount->object_type = object_type; +#else +void grpc_stream_ref_init(grpc_stream_refcount* refcount, int /*initial_refs*/, + grpc_iomgr_cb_func cb, void* cb_arg) { +#endif + GRPC_CLOSURE_INIT(&refcount->destroy, cb, cb_arg, grpc_schedule_on_exec_ctx); + + new (&refcount->refs) grpc_core::RefCount( + 1, GRPC_TRACE_FLAG_ENABLED(grpc_trace_stream_refcount) ? "stream_refcount" + : nullptr); + new (&refcount->slice_refcount) grpc_slice_refcount( + grpc_slice_refcount::Type::REGULAR, &refcount->refs, slice_stream_destroy, + refcount, &refcount->slice_refcount); +} + +static void move64(uint64_t* from, uint64_t* to) { + *to += *from; + *from = 0; +} + +void grpc_transport_move_one_way_stats(grpc_transport_one_way_stats* from, + grpc_transport_one_way_stats* to) { + move64(&from->framing_bytes, &to->framing_bytes); + move64(&from->data_bytes, &to->data_bytes); + move64(&from->header_bytes, &to->header_bytes); +} + +void grpc_transport_move_stats(grpc_transport_stream_stats* from, + grpc_transport_stream_stats* to) { + grpc_transport_move_one_way_stats(&from->incoming, &to->incoming); + grpc_transport_move_one_way_stats(&from->outgoing, &to->outgoing); +} + +size_t grpc_transport_stream_size(grpc_transport* transport) { + return GPR_ROUND_UP_TO_ALIGNMENT_SIZE(transport->vtable->sizeof_stream); +} + +void grpc_transport_destroy(grpc_transport* transport) { + transport->vtable->destroy(transport); +} + +int grpc_transport_init_stream(grpc_transport* transport, grpc_stream* stream, + grpc_stream_refcount* refcount, + const void* server_data, + grpc_core::Arena* arena) { + return transport->vtable->init_stream(transport, stream, refcount, + server_data, arena); +} + +void grpc_transport_perform_stream_op(grpc_transport* transport, + grpc_stream* stream, + grpc_transport_stream_op_batch* op) { + transport->vtable->perform_stream_op(transport, stream, op); +} + +void grpc_transport_perform_op(grpc_transport* transport, + grpc_transport_op* op) { + transport->vtable->perform_op(transport, op); +} + +void grpc_transport_set_pops(grpc_transport* transport, grpc_stream* stream, + grpc_polling_entity* pollent) { + grpc_pollset* pollset; + grpc_pollset_set* pollset_set; + if ((pollset = grpc_polling_entity_pollset(pollent)) != nullptr) { + transport->vtable->set_pollset(transport, stream, pollset); + } else if ((pollset_set = grpc_polling_entity_pollset_set(pollent)) != + nullptr) { + transport->vtable->set_pollset_set(transport, stream, pollset_set); + } else { + // No-op for empty pollset. 
Empty pollset is possible when using + // non-fd-based event engines such as CFStream. + } +} + +void grpc_transport_destroy_stream(grpc_transport* transport, + grpc_stream* stream, + grpc_closure* then_schedule_closure) { + transport->vtable->destroy_stream(transport, stream, then_schedule_closure); +} + +grpc_endpoint* grpc_transport_get_endpoint(grpc_transport* transport) { + return transport->vtable->get_endpoint(transport); +} + +// This comment should be sung to the tune of +// "Supercalifragilisticexpialidocious": +// +// grpc_transport_stream_op_batch_finish_with_failure +// is a function that must always unref cancel_error +// though it lives in lib, it handles transport stream ops sure +// it's grpc_transport_stream_op_batch_finish_with_failure +void grpc_transport_stream_op_batch_finish_with_failure( + grpc_transport_stream_op_batch* batch, grpc_error_handle error, + grpc_core::CallCombiner* call_combiner) { + if (batch->send_message) { + batch->payload->send_message.send_message.reset(); + } + if (batch->cancel_stream) { + GRPC_ERROR_UNREF(batch->payload->cancel_stream.cancel_error); + } + // Construct a list of closures to execute. + grpc_core::CallCombinerClosureList closures; + if (batch->recv_initial_metadata) { + closures.Add( + batch->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_REF(error), "failing recv_initial_metadata_ready"); + } + if (batch->recv_message) { + closures.Add(batch->payload->recv_message.recv_message_ready, + GRPC_ERROR_REF(error), "failing recv_message_ready"); + } + if (batch->recv_trailing_metadata) { + closures.Add( + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready, + GRPC_ERROR_REF(error), "failing recv_trailing_metadata_ready"); + } + if (batch->on_complete != nullptr) { + closures.Add(batch->on_complete, GRPC_ERROR_REF(error), + "failing on_complete"); + } + // Execute closures. 
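+  // Each closure added above holds its own reference to `error` (taken via
+  // GRPC_ERROR_REF), so the original reference is released once the closures
+  // have been handed off to the call combiner.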
+ closures.RunClosures(call_combiner); + GRPC_ERROR_UNREF(error); +} + +struct made_transport_op { + grpc_closure outer_on_complete; + grpc_closure* inner_on_complete = nullptr; + grpc_transport_op op; + made_transport_op() { + memset(&outer_on_complete, 0, sizeof(outer_on_complete)); + } +}; + +static void destroy_made_transport_op(void* arg, grpc_error_handle error) { + made_transport_op* op = static_cast(arg); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->inner_on_complete, + GRPC_ERROR_REF(error)); + delete op; +} + +grpc_transport_op* grpc_make_transport_op(grpc_closure* on_complete) { + made_transport_op* op = new made_transport_op(); + GRPC_CLOSURE_INIT(&op->outer_on_complete, destroy_made_transport_op, op, + grpc_schedule_on_exec_ctx); + op->inner_on_complete = on_complete; + op->op.on_consumed = &op->outer_on_complete; + return &op->op; +} + +struct made_transport_stream_op { + grpc_closure outer_on_complete; + grpc_closure* inner_on_complete = nullptr; + grpc_transport_stream_op_batch op; + grpc_transport_stream_op_batch_payload payload{nullptr}; +}; +static void destroy_made_transport_stream_op(void* arg, + grpc_error_handle error) { + made_transport_stream_op* op = static_cast(arg); + grpc_closure* c = op->inner_on_complete; + delete op; + grpc_core::Closure::Run(DEBUG_LOCATION, c, GRPC_ERROR_REF(error)); +} + +grpc_transport_stream_op_batch* grpc_make_transport_stream_op( + grpc_closure* on_complete) { + made_transport_stream_op* op = new made_transport_stream_op(); + op->op.payload = &op->payload; + GRPC_CLOSURE_INIT(&op->outer_on_complete, destroy_made_transport_stream_op, + op, grpc_schedule_on_exec_ctx); + op->inner_on_complete = on_complete; + op->op.on_complete = &op->outer_on_complete; + return &op->op; +} diff --git a/src/core/lib/transport/transport_op_string.cc b/src/core/lib/transport/transport_op_string.cc new file mode 100644 index 00000000..230d43e4 --- /dev/null +++ b/src/core/lib/transport/transport_op_string.cc @@ -0,0 +1,189 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+
+#include
+#include
+
+#include "src/core/lib/channel/channel_stack.h"
+#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/slice/slice_string_helpers.h"
+#include "src/core/lib/transport/connectivity_state.h"
+
+/* These routines are here to facilitate debugging - they produce string
+   representations of various transport data structures */
+
+namespace {
+class MetadataListEncoder {
+ public:
+  explicit MetadataListEncoder(std::vector<std::string>* out) : out_(out) {}
+  void Encode(const grpc_mdelem& md) {
+    MaybeAddComma();
+    out_->push_back("key=");
+    char* dump = grpc_dump_slice(GRPC_MDKEY(md), GPR_DUMP_HEX | GPR_DUMP_ASCII);
+    out_->push_back(dump);
+    gpr_free(dump);
+    out_->push_back(" value=");
+    dump = grpc_dump_slice(GRPC_MDVALUE(md), GPR_DUMP_HEX | GPR_DUMP_ASCII);
+    out_->push_back(dump);
+    gpr_free(dump);
+  }
+
+  void Encode(grpc_core::GrpcTimeoutMetadata, grpc_millis deadline) {
+    MaybeAddComma();
+    out_->push_back(absl::StrFormat("deadline=%" PRId64, deadline));
+  }
+
+  template <typename Which>
+  void Encode(Which, typename Which::ValueType value) {
+    MaybeAddComma();
+    out_->push_back(
+        absl::StrCat(Which::key(), "=", Which::DisplayValue(value)));
+  }
+
+ private:
+  void MaybeAddComma() {
+    if (out_->size() != initial_size_) out_->push_back(", ");
+  }
+  std::vector<std::string>* const out_;
+  const size_t initial_size_ = out_->size();
+};
+}  // namespace
+
+static void put_metadata_list(const grpc_metadata_batch& md,
+                              std::vector<std::string>* out) {
+  MetadataListEncoder encoder(out);
+  md.Encode(&encoder);
+}
+
+std::string grpc_transport_stream_op_batch_string(
+    grpc_transport_stream_op_batch* op) {
+  std::vector<std::string> out;
+
+  if (op->send_initial_metadata) {
+    out.push_back(" SEND_INITIAL_METADATA{");
+    put_metadata_list(*op->payload->send_initial_metadata.send_initial_metadata,
+                      &out);
+    out.push_back("}");
+  }
+
+  if (op->send_message) {
+    if (op->payload->send_message.send_message != nullptr) {
+      out.push_back(
+          absl::StrFormat(" SEND_MESSAGE:flags=0x%08x:len=%d",
+                          op->payload->send_message.send_message->flags(),
+                          op->payload->send_message.send_message->length()));
+    } else {
+      // This can happen when we check a batch after the transport has
+      // processed and cleared the send_message op.
+ out.push_back(" SEND_MESSAGE(flag and length unknown, already orphaned)"); + } + } + + if (op->send_trailing_metadata) { + out.push_back(" SEND_TRAILING_METADATA{"); + put_metadata_list( + *op->payload->send_trailing_metadata.send_trailing_metadata, &out); + out.push_back("}"); + } + + if (op->recv_initial_metadata) { + out.push_back(" RECV_INITIAL_METADATA"); + } + + if (op->recv_message) { + out.push_back(" RECV_MESSAGE"); + } + + if (op->recv_trailing_metadata) { + out.push_back(" RECV_TRAILING_METADATA"); + } + + if (op->cancel_stream) { + out.push_back(absl::StrCat( + " CANCEL:", + grpc_error_std_string(op->payload->cancel_stream.cancel_error))); + } + + return absl::StrJoin(out, ""); +} + +std::string grpc_transport_op_string(grpc_transport_op* op) { + std::vector out; + + if (op->start_connectivity_watch != nullptr) { + out.push_back(absl::StrFormat( + " START_CONNECTIVITY_WATCH:watcher=%p:from=%s", + op->start_connectivity_watch.get(), + grpc_core::ConnectivityStateName(op->start_connectivity_watch_state))); + } + + if (op->stop_connectivity_watch != nullptr) { + out.push_back(absl::StrFormat(" STOP_CONNECTIVITY_WATCH:watcher=%p", + op->stop_connectivity_watch)); + } + + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + out.push_back(absl::StrCat( + " DISCONNECT:", grpc_error_std_string(op->disconnect_with_error))); + } + + if (op->goaway_error != GRPC_ERROR_NONE) { + out.push_back(absl::StrCat(" SEND_GOAWAY:%s", + grpc_error_std_string(op->goaway_error))); + } + + if (op->set_accept_stream) { + out.push_back(absl::StrFormat(" SET_ACCEPT_STREAM:%p(%p,...)", + op->set_accept_stream_fn, + op->set_accept_stream_user_data)); + } + + if (op->bind_pollset != nullptr) { + out.push_back(" BIND_POLLSET"); + } + + if (op->bind_pollset_set != nullptr) { + out.push_back(" BIND_POLLSET_SET"); + } + + if (op->send_ping.on_initiate != nullptr || op->send_ping.on_ack != nullptr) { + out.push_back(" SEND_PING"); + } + + return absl::StrJoin(out, ""); +} + +void grpc_call_log_op(const char* file, int line, gpr_log_severity severity, + grpc_call_element* elem, + grpc_transport_stream_op_batch* op) { + gpr_log(file, line, severity, "OP[%s:%p]: %s", elem->filter->name, elem, + grpc_transport_stream_op_batch_string(op).c_str()); +} diff --git a/src/core/lib/uri/uri_parser.cc b/src/core/lib/uri/uri_parser.cc new file mode 100644 index 00000000..ea103306 --- /dev/null +++ b/src/core/lib/uri/uri_parser.cc @@ -0,0 +1,191 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/uri/uri_parser.h" + +#include + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" + +#include + +#include "src/core/lib/gpr/string.h" + +namespace grpc_core { +namespace { + +// Similar to `grpc_permissive_percent_decode_slice`, this %-decodes all valid +// triplets, and passes through the rest verbatim. 
+std::string PercentDecode(absl::string_view str) {
+  if (str.empty() || !absl::StrContains(str, "%")) {
+    return std::string(str);
+  }
+  std::string out;
+  std::string unescaped;
+  out.reserve(str.size());
+  for (size_t i = 0; i < str.length(); i++) {
+    unescaped = "";
+    if (str[i] != '%') {
+      out += str[i];
+      continue;
+    }
+    if (i + 3 >= str.length() ||
+        !absl::CUnescape(absl::StrCat("\\x", str.substr(i + 1, 2)),
+                         &unescaped) ||
+        unescaped.length() > 1) {
+      out += str[i];
+    } else {
+      out += unescaped[0];
+      i += 2;
+    }
+  }
+  return out;
+}
+
+// Checks if this string is made up of pchars, '/', '?', and '%' exclusively.
+// See https://tools.ietf.org/html/rfc3986#section-3.4
+bool IsPCharString(absl::string_view str) {
+  return (str.find_first_not_of("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+                                "abcdefghijklmnopqrstuvwxyz"
+                                "0123456789"
+                                "?/:@\\-._~!$&'()*+,;=%") ==
+          absl::string_view::npos);
+}
+
+absl::Status MakeInvalidURIStatus(absl::string_view part_name,
+                                  absl::string_view uri,
+                                  absl::string_view extra) {
+  return absl::InvalidArgumentError(absl::StrFormat(
+      "Could not parse '%s' from uri '%s'. %s", part_name, uri, extra));
+}
+}  // namespace
+
+absl::StatusOr<URI> URI::Parse(absl::string_view uri_text) {
+  absl::StatusOr<std::string> decoded;
+  absl::string_view remaining = uri_text;
+  // parse scheme
+  size_t idx = remaining.find(':');
+  if (idx == remaining.npos || idx == 0) {
+    return MakeInvalidURIStatus("scheme", uri_text, "Scheme not found.");
+  }
+  std::string scheme(remaining.substr(0, idx));
+  if (scheme.find_first_not_of("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+                               "abcdefghijklmnopqrstuvwxyz"
+                               "0123456789+-.") != std::string::npos) {
+    return MakeInvalidURIStatus("scheme", uri_text,
+                                "Scheme contains invalid characters.");
+  }
+  if (!isalpha(scheme[0])) {
+    return MakeInvalidURIStatus(
+        "scheme", uri_text,
+        "Scheme must begin with an alpha character [A-Za-z].");
+  }
+  remaining.remove_prefix(scheme.length() + 1);
+  // parse authority
+  std::string authority;
+  if (absl::StartsWith(remaining, "//")) {
+    remaining.remove_prefix(2);
+    authority =
+        PercentDecode(remaining.substr(0, remaining.find_first_of("/?#")));
+    remaining.remove_prefix(authority.length());
+  }
+  // parse path
+  std::string path;
+  if (!remaining.empty()) {
+    path = PercentDecode(remaining.substr(0, remaining.find_first_of("?#")));
+    remaining.remove_prefix(path.length());
+  }
+  // parse query
+  std::vector<QueryParam> query_param_pairs;
+  if (!remaining.empty() && remaining[0] == '?') {
+    remaining.remove_prefix(1);
+    absl::string_view tmp_query = remaining.substr(0, remaining.find('#'));
+    if (tmp_query.empty()) {
+      return MakeInvalidURIStatus("query", uri_text, "Invalid query string.");
+    }
+    if (!IsPCharString(tmp_query)) {
+      return MakeInvalidURIStatus("query string", uri_text,
+                                  "Query string contains invalid characters.");
+    }
+    for (absl::string_view query_param : absl::StrSplit(tmp_query, '&')) {
+      const std::pair<absl::string_view, absl::string_view> possible_kv =
+          absl::StrSplit(query_param, absl::MaxSplits('=', 1));
+      if (possible_kv.first.empty()) continue;
+      query_param_pairs.push_back({PercentDecode(possible_kv.first),
+                                   PercentDecode(possible_kv.second)});
+    }
+    remaining.remove_prefix(tmp_query.length());
+  }
+  std::string fragment;
+  if (!remaining.empty() && remaining[0] == '#') {
+    remaining.remove_prefix(1);
+    if (!IsPCharString(remaining)) {
+      return MakeInvalidURIStatus("fragment", uri_text,
+                                  "Fragment contains invalid characters.");
+    }
+    fragment = PercentDecode(remaining);
+  }
+  return URI(std::move(scheme), std::move(authority), std::move(path),
+
std::move(query_param_pairs), std::move(fragment)); +} + +URI::URI(std::string scheme, std::string authority, std::string path, + std::vector query_parameter_pairs, std::string fragment) + : scheme_(std::move(scheme)), + authority_(std::move(authority)), + path_(std::move(path)), + query_parameter_pairs_(std::move(query_parameter_pairs)), + fragment_(std::move(fragment)) { + for (const auto& kv : query_parameter_pairs_) { + query_parameter_map_[kv.key] = kv.value; + } +} + +URI::URI(const URI& other) + : scheme_(other.scheme_), + authority_(other.authority_), + path_(other.path_), + query_parameter_pairs_(other.query_parameter_pairs_), + fragment_(other.fragment_) { + for (const auto& kv : query_parameter_pairs_) { + query_parameter_map_[kv.key] = kv.value; + } +} + +URI& URI::operator=(const URI& other) { + if (this == &other) { + return *this; + } + scheme_ = other.scheme_; + authority_ = other.authority_; + path_ = other.path_; + query_parameter_pairs_ = other.query_parameter_pairs_; + fragment_ = other.fragment_; + for (const auto& kv : query_parameter_pairs_) { + query_parameter_map_[kv.key] = kv.value; + } + return *this; +} +} // namespace grpc_core diff --git a/src/core/plugin_registry/grpc_plugin_registry.cc b/src/core/plugin_registry/grpc_plugin_registry.cc new file mode 100644 index 00000000..734c585d --- /dev/null +++ b/src/core/plugin_registry/grpc_plugin_registry.cc @@ -0,0 +1,193 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/builtins.h" + +void grpc_chttp2_plugin_init(void); +void grpc_chttp2_plugin_shutdown(void); +void grpc_client_channel_init(void); +void grpc_client_channel_shutdown(void); +void grpc_inproc_plugin_init(void); +void grpc_inproc_plugin_shutdown(void); +void grpc_resolver_fake_init(void); +void grpc_resolver_fake_shutdown(void); +void grpc_lb_policy_grpclb_init(void); +void grpc_lb_policy_grpclb_shutdown(void); +void grpc_lb_policy_priority_init(void); +void grpc_lb_policy_priority_shutdown(void); +void grpc_lb_policy_weighted_target_init(void); +void grpc_lb_policy_weighted_target_shutdown(void); +void grpc_lb_policy_pick_first_init(void); +void grpc_lb_policy_pick_first_shutdown(void); +void grpc_lb_policy_round_robin_init(void); +void grpc_lb_policy_round_robin_shutdown(void); +void grpc_resolver_dns_ares_init(void); +void grpc_resolver_dns_ares_shutdown(void); +void grpc_resolver_dns_native_init(void); +void grpc_resolver_dns_native_shutdown(void); +void grpc_resolver_sockaddr_init(void); +void grpc_resolver_sockaddr_shutdown(void); +void grpc_message_size_filter_init(void); +void grpc_message_size_filter_shutdown(void); +namespace grpc_core { +void FaultInjectionFilterInit(void); +void FaultInjectionFilterShutdown(void); +void GrpcLbPolicyRingHashInit(void); +void GrpcLbPolicyRingHashShutdown(void); +void RlsLbPluginInit(); +void RlsLbPluginShutdown(); +void ServiceConfigParserInit(void); +void ServiceConfigParserShutdown(void); +} // namespace grpc_core + +#ifndef GRPC_NO_XDS +namespace grpc_core { +void XdsClientGlobalInit(); +void XdsClientGlobalShutdown(); +} // namespace grpc_core +void grpc_certificate_provider_registry_init(void); +void grpc_certificate_provider_registry_shutdown(void); +namespace grpc_core { +void FileWatcherCertificateProviderInit(); +void FileWatcherCertificateProviderShutdown(); +} // namespace grpc_core +void grpc_lb_policy_cds_init(void); +void grpc_lb_policy_cds_shutdown(void); +void grpc_lb_policy_xds_cluster_impl_init(void); +void grpc_lb_policy_xds_cluster_impl_shutdown(void); +void grpc_lb_policy_xds_cluster_resolver_init(void); +void grpc_lb_policy_xds_cluster_resolver_shutdown(void); +void grpc_lb_policy_xds_cluster_manager_init(void); +void grpc_lb_policy_xds_cluster_manager_shutdown(void); +void grpc_resolver_xds_init(void); +void grpc_resolver_xds_shutdown(void); +namespace grpc_core { +void GoogleCloud2ProdResolverInit(); +void GoogleCloud2ProdResolverShutdown(); +} // namespace grpc_core +#endif + +#ifdef GPR_SUPPORT_BINDER_TRANSPORT +void grpc_resolver_binder_init(void); +void grpc_resolver_binder_shutdown(void); +#endif + +void grpc_register_built_in_plugins(void) { + grpc_register_plugin(grpc_chttp2_plugin_init, grpc_chttp2_plugin_shutdown); + grpc_register_plugin(grpc_core::ServiceConfigParserInit, + grpc_core::ServiceConfigParserShutdown); + grpc_register_plugin(grpc_client_channel_init, grpc_client_channel_shutdown); + grpc_register_plugin(grpc_inproc_plugin_init, grpc_inproc_plugin_shutdown); + grpc_register_plugin(grpc_resolver_fake_init, grpc_resolver_fake_shutdown); + grpc_register_plugin(grpc_lb_policy_grpclb_init, + grpc_lb_policy_grpclb_shutdown); + grpc_register_plugin(grpc_core::RlsLbPluginInit, + grpc_core::RlsLbPluginShutdown); + grpc_register_plugin(grpc_lb_policy_priority_init, + grpc_lb_policy_priority_shutdown); + grpc_register_plugin(grpc_lb_policy_weighted_target_init, + 
grpc_lb_policy_weighted_target_shutdown); + grpc_register_plugin(grpc_lb_policy_pick_first_init, + grpc_lb_policy_pick_first_shutdown); + grpc_register_plugin(grpc_lb_policy_round_robin_init, + grpc_lb_policy_round_robin_shutdown); + grpc_register_plugin(grpc_core::GrpcLbPolicyRingHashInit, + grpc_core::GrpcLbPolicyRingHashShutdown); + grpc_register_plugin(grpc_resolver_dns_ares_init, + grpc_resolver_dns_ares_shutdown); + grpc_register_plugin(grpc_resolver_dns_native_init, + grpc_resolver_dns_native_shutdown); + grpc_register_plugin(grpc_resolver_sockaddr_init, + grpc_resolver_sockaddr_shutdown); + grpc_register_plugin(grpc_message_size_filter_init, + grpc_message_size_filter_shutdown); + grpc_register_plugin(grpc_core::FaultInjectionFilterInit, + grpc_core::FaultInjectionFilterShutdown); +#ifndef GRPC_NO_XDS + grpc_register_plugin(grpc_core::XdsClientGlobalInit, + grpc_core::XdsClientGlobalShutdown); + grpc_register_plugin(grpc_certificate_provider_registry_init, + grpc_certificate_provider_registry_shutdown); + grpc_register_plugin(grpc_core::FileWatcherCertificateProviderInit, + grpc_core::FileWatcherCertificateProviderShutdown); + grpc_register_plugin(grpc_lb_policy_cds_init, grpc_lb_policy_cds_shutdown); + grpc_register_plugin(grpc_lb_policy_xds_cluster_impl_init, + grpc_lb_policy_xds_cluster_impl_shutdown); + grpc_register_plugin(grpc_lb_policy_xds_cluster_resolver_init, + grpc_lb_policy_xds_cluster_resolver_shutdown); + grpc_register_plugin(grpc_lb_policy_xds_cluster_manager_init, + grpc_lb_policy_xds_cluster_manager_shutdown); + grpc_register_plugin(grpc_resolver_xds_init, grpc_resolver_xds_shutdown); + grpc_register_plugin(grpc_core::GoogleCloud2ProdResolverInit, + grpc_core::GoogleCloud2ProdResolverShutdown); +#endif + +#ifdef GPR_SUPPORT_BINDER_TRANSPORT + grpc_register_plugin(grpc_resolver_binder_init, + grpc_resolver_binder_shutdown); +#endif +} + +namespace grpc_core { + +extern void BuildClientChannelConfiguration( + CoreConfiguration::Builder* builder); +extern void SecurityRegisterHandshakerFactories( + CoreConfiguration::Builder* builder); +extern void RegisterClientAuthorityFilter(CoreConfiguration::Builder* builder); +extern void RegisterClientIdleFilter(CoreConfiguration::Builder* builder); +extern void RegisterDeadlineFilter(CoreConfiguration::Builder* builder); +extern void RegisterGrpcLbLoadReportingFilter( + CoreConfiguration::Builder* builder); +extern void RegisterHttpFilters(CoreConfiguration::Builder* builder); +extern void RegisterMaxAgeFilter(CoreConfiguration::Builder* builder); +extern void RegisterMessageSizeFilter(CoreConfiguration::Builder* builder); +extern void RegisterSecurityFilters(CoreConfiguration::Builder* builder); +extern void RegisterServiceConfigChannelArgFilter( + CoreConfiguration::Builder* builder); +#ifndef GRPC_NO_XDS +extern void RegisterXdsChannelStackModifier( + CoreConfiguration::Builder* builder); +#endif + +void BuildCoreConfiguration(CoreConfiguration::Builder* builder) { + BuildClientChannelConfiguration(builder); + SecurityRegisterHandshakerFactories(builder); + RegisterClientAuthorityFilter(builder); + RegisterClientIdleFilter(builder); + RegisterGrpcLbLoadReportingFilter(builder); + RegisterHttpFilters(builder); + RegisterMaxAgeFilter(builder); + RegisterDeadlineFilter(builder); + RegisterMessageSizeFilter(builder); + RegisterServiceConfigChannelArgFilter(builder); + #ifndef GRPC_NO_XDS + RegisterXdsChannelStackModifier(builder); + #endif + // Run last so it gets a consistent location. 
+ // TODO(ctiller): Is this actually necessary? + RegisterSecurityFilters(builder); + RegisterBuiltins(builder); +} + +} // namespace grpc_core diff --git a/src/core/plugin_registry/grpc_unsecure_plugin_registry.cc b/src/core/plugin_registry/grpc_unsecure_plugin_registry.cc new file mode 100644 index 00000000..e8a75567 --- /dev/null +++ b/src/core/plugin_registry/grpc_unsecure_plugin_registry.cc @@ -0,0 +1,123 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/builtins.h" + +void grpc_chttp2_plugin_init(void); +void grpc_chttp2_plugin_shutdown(void); +void grpc_client_channel_init(void); +void grpc_client_channel_shutdown(void); +void grpc_inproc_plugin_init(void); +void grpc_inproc_plugin_shutdown(void); +void grpc_resolver_dns_ares_init(void); +void grpc_resolver_dns_ares_shutdown(void); +void grpc_resolver_dns_native_init(void); +void grpc_resolver_dns_native_shutdown(void); +void grpc_resolver_sockaddr_init(void); +void grpc_resolver_sockaddr_shutdown(void); +void grpc_resolver_fake_init(void); +void grpc_resolver_fake_shutdown(void); +void grpc_lb_policy_grpclb_init(void); +void grpc_lb_policy_grpclb_shutdown(void); +void grpc_lb_policy_priority_init(void); +void grpc_lb_policy_priority_shutdown(void); +void grpc_lb_policy_weighted_target_init(void); +void grpc_lb_policy_weighted_target_shutdown(void); +void grpc_lb_policy_pick_first_init(void); +void grpc_lb_policy_pick_first_shutdown(void); +void grpc_lb_policy_round_robin_init(void); +void grpc_lb_policy_round_robin_shutdown(void); +void grpc_message_size_filter_init(void); +void grpc_message_size_filter_shutdown(void); +namespace grpc_core { +void FaultInjectionFilterInit(void); +void FaultInjectionFilterShutdown(void); +void GrpcLbPolicyRingHashInit(void); +void GrpcLbPolicyRingHashShutdown(void); +void ServiceConfigParserInit(void); +void ServiceConfigParserShutdown(void); +} // namespace grpc_core + +void grpc_register_built_in_plugins(void) { + grpc_register_plugin(grpc_chttp2_plugin_init, grpc_chttp2_plugin_shutdown); + grpc_register_plugin(grpc_core::ServiceConfigParserInit, + grpc_core::ServiceConfigParserShutdown); + grpc_register_plugin(grpc_client_channel_init, grpc_client_channel_shutdown); + grpc_register_plugin(grpc_inproc_plugin_init, grpc_inproc_plugin_shutdown); + grpc_register_plugin(grpc_resolver_dns_ares_init, + grpc_resolver_dns_ares_shutdown); + grpc_register_plugin(grpc_resolver_dns_native_init, + grpc_resolver_dns_native_shutdown); + grpc_register_plugin(grpc_resolver_sockaddr_init, + grpc_resolver_sockaddr_shutdown); + grpc_register_plugin(grpc_resolver_fake_init, grpc_resolver_fake_shutdown); + grpc_register_plugin(grpc_lb_policy_grpclb_init, + grpc_lb_policy_grpclb_shutdown); + grpc_register_plugin(grpc_lb_policy_priority_init, + grpc_lb_policy_priority_shutdown); + grpc_register_plugin(grpc_lb_policy_weighted_target_init, + 
grpc_lb_policy_weighted_target_shutdown); + grpc_register_plugin(grpc_lb_policy_pick_first_init, + grpc_lb_policy_pick_first_shutdown); + grpc_register_plugin(grpc_lb_policy_round_robin_init, + grpc_lb_policy_round_robin_shutdown); + grpc_register_plugin(grpc_core::GrpcLbPolicyRingHashInit, + grpc_core::GrpcLbPolicyRingHashShutdown); + grpc_register_plugin(grpc_message_size_filter_init, + grpc_message_size_filter_shutdown); + grpc_register_plugin(grpc_core::FaultInjectionFilterInit, + grpc_core::FaultInjectionFilterShutdown); +} + +namespace grpc_core { + +extern void BuildClientChannelConfiguration( + CoreConfiguration::Builder* builder); +extern void RegisterClientAuthorityFilter(CoreConfiguration::Builder* builder); +extern void RegisterClientIdleFilter(CoreConfiguration::Builder* builder); +extern void RegisterDeadlineFilter(CoreConfiguration::Builder* builder); +extern void RegisterGrpcLbLoadReportingFilter( + CoreConfiguration::Builder* builder); +extern void RegisterHttpFilters(CoreConfiguration::Builder* builder); +extern void RegisterMaxAgeFilter(CoreConfiguration::Builder* builder); +extern void RegisterMessageSizeFilter(CoreConfiguration::Builder* builder); +extern void RegisterSecurityFilters(CoreConfiguration::Builder* builder); +extern void RegisterServiceConfigChannelArgFilter( + CoreConfiguration::Builder* builder); + +void BuildCoreConfiguration(CoreConfiguration::Builder* builder) { + BuildClientChannelConfiguration(builder); + RegisterClientAuthorityFilter(builder); + RegisterClientIdleFilter(builder); + RegisterGrpcLbLoadReportingFilter(builder); + RegisterHttpFilters(builder); + RegisterMaxAgeFilter(builder); + RegisterDeadlineFilter(builder); + RegisterMessageSizeFilter(builder); + RegisterServiceConfigChannelArgFilter(builder); + // Run last so it gets a consistent location. + // TODO(ctiller): Is this actually necessary? + RegisterBuiltins(builder); +} + +} // namespace grpc_core diff --git a/src/core/tsi/alts/crypt/aes_gcm.cc b/src/core/tsi/alts/crypt/aes_gcm.cc new file mode 100644 index 00000000..4eaaee28 --- /dev/null +++ b/src/core/tsi/alts/crypt/aes_gcm.cc @@ -0,0 +1,690 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include "src/core/tsi/alts/crypt/gsec.h" + +constexpr size_t kKdfKeyLen = 32; +constexpr size_t kKdfCounterLen = 6; +constexpr size_t kKdfCounterOffset = 2; +constexpr size_t kRekeyAeadKeyLen = kAes128GcmKeyLength; + +/* Struct for additional data required if rekeying is enabled. */ +struct gsec_aes_gcm_aead_rekey_data { + uint8_t kdf_counter[kKdfCounterLen]; + uint8_t nonce_mask[kAesGcmNonceLength]; +}; + +/* Main struct for AES_GCM crypter interface. 
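   It embeds the generic gsec_aead_crypter as its first member so the vtable
   dispatch in gsec.cc can cast between the two, and it owns the AEAD key, the
   optional rekey data (non-null only when rekeying is enabled) and a reusable
   OpenSSL EVP_CIPHER_CTX.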
*/ +struct gsec_aes_gcm_aead_crypter { + gsec_aead_crypter crypter; + size_t key_length; + size_t nonce_length; + size_t tag_length; + uint8_t* key; + gsec_aes_gcm_aead_rekey_data* rekey_data; + EVP_CIPHER_CTX* ctx; +}; + +static char* aes_gcm_get_openssl_errors() { + BIO* bio = BIO_new(BIO_s_mem()); + ERR_print_errors(bio); + BUF_MEM* mem = nullptr; + char* error_msg = nullptr; + BIO_get_mem_ptr(bio, &mem); + if (mem != nullptr) { + error_msg = static_cast(gpr_malloc(mem->length + 1)); + memcpy(error_msg, mem->data, mem->length); + error_msg[mem->length] = '\0'; + } + BIO_free_all(bio); + return error_msg; +} + +static void aes_gcm_format_errors(const char* error_msg, char** error_details) { + if (error_details == nullptr) { + return; + } + unsigned long error = ERR_get_error(); + if (error == 0 && error_msg != nullptr) { + *error_details = static_cast(gpr_malloc(strlen(error_msg) + 1)); + memcpy(*error_details, error_msg, strlen(error_msg) + 1); + return; + } + char* openssl_errors = aes_gcm_get_openssl_errors(); + if (openssl_errors != nullptr && error_msg != nullptr) { + size_t len = strlen(error_msg) + strlen(openssl_errors) + 2; /* ", " */ + *error_details = static_cast(gpr_malloc(len + 1)); + snprintf(*error_details, len + 1, "%s, %s", error_msg, openssl_errors); + gpr_free(openssl_errors); + } +} + +static grpc_status_code gsec_aes_gcm_aead_crypter_max_ciphertext_and_tag_length( + const gsec_aead_crypter* crypter, size_t plaintext_length, + size_t* max_ciphertext_and_tag_length, char** error_details) { + if (max_ciphertext_and_tag_length == nullptr) { + aes_gcm_format_errors("max_ciphertext_and_tag_length is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast( + const_cast(crypter)); + *max_ciphertext_and_tag_length = + plaintext_length + aes_gcm_crypter->tag_length; + return GRPC_STATUS_OK; +} + +static grpc_status_code gsec_aes_gcm_aead_crypter_max_plaintext_length( + const gsec_aead_crypter* crypter, size_t ciphertext_and_tag_length, + size_t* max_plaintext_length, char** error_details) { + if (max_plaintext_length == nullptr) { + aes_gcm_format_errors("max_plaintext_length is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast( + const_cast(crypter)); + if (ciphertext_and_tag_length < aes_gcm_crypter->tag_length) { + *max_plaintext_length = 0; + aes_gcm_format_errors( + "ciphertext_and_tag_length is smaller than tag_length.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + *max_plaintext_length = + ciphertext_and_tag_length - aes_gcm_crypter->tag_length; + return GRPC_STATUS_OK; +} + +static grpc_status_code gsec_aes_gcm_aead_crypter_nonce_length( + const gsec_aead_crypter* crypter, size_t* nonce_length, + char** error_details) { + if (nonce_length == nullptr) { + aes_gcm_format_errors("nonce_length is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast( + const_cast(crypter)); + *nonce_length = aes_gcm_crypter->nonce_length; + return GRPC_STATUS_OK; +} + +static grpc_status_code gsec_aes_gcm_aead_crypter_key_length( + const gsec_aead_crypter* crypter, size_t* key_length, + char** error_details) { + if (key_length == nullptr) { + aes_gcm_format_errors("key_length is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast( + 
const_cast(crypter)); + *key_length = aes_gcm_crypter->key_length; + return GRPC_STATUS_OK; +} + +static grpc_status_code gsec_aes_gcm_aead_crypter_tag_length( + const gsec_aead_crypter* crypter, size_t* tag_length, + char** error_details) { + if (tag_length == nullptr) { + aes_gcm_format_errors("tag_length is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast( + const_cast(crypter)); + *tag_length = aes_gcm_crypter->tag_length; + return GRPC_STATUS_OK; +} + +static void aes_gcm_mask_nonce(uint8_t* dst, const uint8_t* nonce, + const uint8_t* mask) { + uint64_t mask1; + uint32_t mask2; + memcpy(&mask1, mask, sizeof(mask1)); + memcpy(&mask2, mask + sizeof(mask1), sizeof(mask2)); + uint64_t nonce1; + uint32_t nonce2; + memcpy(&nonce1, nonce, sizeof(nonce1)); + memcpy(&nonce2, nonce + sizeof(nonce1), sizeof(nonce2)); + nonce1 ^= mask1; + nonce2 ^= mask2; + memcpy(dst, &nonce1, sizeof(nonce1)); + memcpy(dst + sizeof(nonce1), &nonce2, sizeof(nonce2)); +} + +static grpc_status_code aes_gcm_derive_aead_key(uint8_t* dst, + const uint8_t* kdf_key, + const uint8_t* kdf_counter) { + unsigned char buf[EVP_MAX_MD_SIZE]; + unsigned char ctr = 1; +#if OPENSSL_VERSION_NUMBER < 0x10100000L + HMAC_CTX hmac; + HMAC_CTX_init(&hmac); + if (!HMAC_Init_ex(&hmac, kdf_key, kKdfKeyLen, EVP_sha256(), nullptr) || + !HMAC_Update(&hmac, kdf_counter, kKdfCounterLen) || + !HMAC_Update(&hmac, &ctr, 1) || !HMAC_Final(&hmac, buf, nullptr)) { + HMAC_CTX_cleanup(&hmac); + return GRPC_STATUS_INTERNAL; + } + HMAC_CTX_cleanup(&hmac); +#else + HMAC_CTX* hmac = HMAC_CTX_new(); + if (hmac == nullptr) { + return GRPC_STATUS_INTERNAL; + } + if (!HMAC_Init_ex(hmac, kdf_key, kKdfKeyLen, EVP_sha256(), nullptr) || + !HMAC_Update(hmac, kdf_counter, kKdfCounterLen) || + !HMAC_Update(hmac, &ctr, 1) || !HMAC_Final(hmac, buf, nullptr)) { + HMAC_CTX_free(hmac); + return GRPC_STATUS_INTERNAL; + } + HMAC_CTX_free(hmac); +#endif + memcpy(dst, buf, kRekeyAeadKeyLen); + return GRPC_STATUS_OK; +} + +static grpc_status_code aes_gcm_rekey_if_required( + gsec_aes_gcm_aead_crypter* aes_gcm_crypter, const uint8_t* nonce, + char** error_details) { + // If rekey_data is nullptr, then rekeying is not supported and not required. + // If bytes 2-7 of kdf_counter differ from the (per message) nonce, then the + // encryption key is recomputed from a new kdf_counter to ensure that we don't + // encrypt more than 2^16 messages per encryption key (in each direction). 
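  // In outline, aes_gcm_derive_aead_key() below computes
  //   new_key = first kRekeyAeadKeyLen bytes of
  //             HMAC_SHA256(kdf_key, kdf_counter || 0x01)
  // so every distinct 6-byte kdf_counter taken from the nonce yields a fresh
  // AES-128-GCM record key, while the low-order nonce bytes still vary per
  // message under that key.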
+ if (aes_gcm_crypter->rekey_data == nullptr || + memcmp(aes_gcm_crypter->rekey_data->kdf_counter, + nonce + kKdfCounterOffset, kKdfCounterLen) == 0) { + return GRPC_STATUS_OK; + } + memcpy(aes_gcm_crypter->rekey_data->kdf_counter, nonce + kKdfCounterOffset, + kKdfCounterLen); + uint8_t aead_key[kRekeyAeadKeyLen]; + if (aes_gcm_derive_aead_key(aead_key, aes_gcm_crypter->key, + aes_gcm_crypter->rekey_data->kdf_counter) != + GRPC_STATUS_OK) { + aes_gcm_format_errors("Rekeying failed in key derivation.", error_details); + return GRPC_STATUS_INTERNAL; + } + if (!EVP_DecryptInit_ex(aes_gcm_crypter->ctx, nullptr, nullptr, aead_key, + nullptr)) { + aes_gcm_format_errors("Rekeying failed in context update.", error_details); + return GRPC_STATUS_INTERNAL; + } + return GRPC_STATUS_OK; +} + +static grpc_status_code gsec_aes_gcm_aead_crypter_encrypt_iovec( + gsec_aead_crypter* crypter, const uint8_t* nonce, size_t nonce_length, + const struct iovec* aad_vec, size_t aad_vec_length, + const struct iovec* plaintext_vec, size_t plaintext_vec_length, + struct iovec ciphertext_vec, size_t* ciphertext_bytes_written, + char** error_details) { + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast(crypter); + // Input checks + if (nonce == nullptr) { + aes_gcm_format_errors("Nonce buffer is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (kAesGcmNonceLength != nonce_length) { + aes_gcm_format_errors("Nonce buffer has the wrong length.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (aad_vec_length > 0 && aad_vec == nullptr) { + aes_gcm_format_errors("Non-zero aad_vec_length but aad_vec is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (plaintext_vec_length > 0 && plaintext_vec == nullptr) { + aes_gcm_format_errors( + "Non-zero plaintext_vec_length but plaintext_vec is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (ciphertext_bytes_written == nullptr) { + aes_gcm_format_errors("bytes_written is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + *ciphertext_bytes_written = 0; + // rekey if required + if (aes_gcm_rekey_if_required(aes_gcm_crypter, nonce, error_details) != + GRPC_STATUS_OK) { + return GRPC_STATUS_INTERNAL; + } + // mask nonce if required + const uint8_t* nonce_aead = nonce; + uint8_t nonce_masked[kAesGcmNonceLength]; + if (aes_gcm_crypter->rekey_data != nullptr) { + aes_gcm_mask_nonce(nonce_masked, aes_gcm_crypter->rekey_data->nonce_mask, + nonce); + nonce_aead = nonce_masked; + } + // init openssl context + if (!EVP_EncryptInit_ex(aes_gcm_crypter->ctx, nullptr, nullptr, nullptr, + nonce_aead)) { + aes_gcm_format_errors("Initializing nonce failed", error_details); + return GRPC_STATUS_INTERNAL; + } + // process aad + size_t i; + for (i = 0; i < aad_vec_length; i++) { + const uint8_t* aad = static_cast(aad_vec[i].iov_base); + size_t aad_length = aad_vec[i].iov_len; + if (aad_length == 0) { + continue; + } + size_t aad_bytes_read = 0; + if (aad == nullptr) { + aes_gcm_format_errors("aad is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (!EVP_EncryptUpdate(aes_gcm_crypter->ctx, nullptr, + reinterpret_cast(&aad_bytes_read), aad, + static_cast(aad_length)) || + aad_bytes_read != aad_length) { + aes_gcm_format_errors("Setting authenticated associated data failed", + error_details); + return GRPC_STATUS_INTERNAL; + } + } + uint8_t* ciphertext = static_cast(ciphertext_vec.iov_base); + size_t ciphertext_length = ciphertext_vec.iov_len; + if 
(ciphertext == nullptr) { + aes_gcm_format_errors("ciphertext is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + // process plaintext + for (i = 0; i < plaintext_vec_length; i++) { + const uint8_t* plaintext = static_cast(plaintext_vec[i].iov_base); + size_t plaintext_length = plaintext_vec[i].iov_len; + if (plaintext == nullptr) { + if (plaintext_length == 0) { + continue; + } + aes_gcm_format_errors("plaintext is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (ciphertext_length < plaintext_length) { + aes_gcm_format_errors( + "ciphertext is not large enough to hold the result.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + int bytes_written = 0; + int bytes_to_write = static_cast(plaintext_length); + if (!EVP_EncryptUpdate(aes_gcm_crypter->ctx, ciphertext, &bytes_written, + plaintext, bytes_to_write)) { + aes_gcm_format_errors("Encrypting plaintext failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + if (bytes_written > bytes_to_write) { + aes_gcm_format_errors("More bytes written than expected.", error_details); + return GRPC_STATUS_INTERNAL; + } + ciphertext += bytes_written; + ciphertext_length -= bytes_written; + } + int bytes_written_temp = 0; + if (!EVP_EncryptFinal_ex(aes_gcm_crypter->ctx, nullptr, + &bytes_written_temp)) { + aes_gcm_format_errors("Finalizing encryption failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + if (bytes_written_temp != 0) { + aes_gcm_format_errors("Openssl wrote some unexpected bytes.", + error_details); + return GRPC_STATUS_INTERNAL; + } + if (ciphertext_length < kAesGcmTagLength) { + aes_gcm_format_errors("ciphertext is too small to hold a tag.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + + if (!EVP_CIPHER_CTX_ctrl(aes_gcm_crypter->ctx, EVP_CTRL_GCM_GET_TAG, + kAesGcmTagLength, ciphertext)) { + aes_gcm_format_errors("Writing tag failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + ciphertext += kAesGcmTagLength; + ciphertext_length -= kAesGcmTagLength; + *ciphertext_bytes_written = ciphertext_vec.iov_len - ciphertext_length; + return GRPC_STATUS_OK; +} + +static grpc_status_code gsec_aes_gcm_aead_crypter_decrypt_iovec( + gsec_aead_crypter* crypter, const uint8_t* nonce, size_t nonce_length, + const struct iovec* aad_vec, size_t aad_vec_length, + const struct iovec* ciphertext_vec, size_t ciphertext_vec_length, + struct iovec plaintext_vec, size_t* plaintext_bytes_written, + char** error_details) { + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast( + const_cast(crypter)); + if (nonce == nullptr) { + aes_gcm_format_errors("Nonce buffer is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (kAesGcmNonceLength != nonce_length) { + aes_gcm_format_errors("Nonce buffer has the wrong length.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (aad_vec_length > 0 && aad_vec == nullptr) { + aes_gcm_format_errors("Non-zero aad_vec_length but aad_vec is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (ciphertext_vec_length > 0 && ciphertext_vec == nullptr) { + aes_gcm_format_errors( + "Non-zero plaintext_vec_length but plaintext_vec is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + // Compute the total length so we can ensure we don't pass the tag into + // EVP_decrypt. 
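  // The last kAesGcmTagLength bytes spread across ciphertext_vec are the GCM
  // authentication tag rather than ciphertext; they are collected further down
  // and handed to OpenSSL via EVP_CTRL_GCM_SET_TAG so that
  // EVP_DecryptFinal_ex can verify them.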
+ size_t total_ciphertext_length = 0; + size_t i; + for (i = 0; i < ciphertext_vec_length; i++) { + total_ciphertext_length += ciphertext_vec[i].iov_len; + } + if (total_ciphertext_length < kAesGcmTagLength) { + aes_gcm_format_errors("ciphertext is too small to hold a tag.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (plaintext_bytes_written == nullptr) { + aes_gcm_format_errors("bytes_written is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + *plaintext_bytes_written = 0; + // rekey if required + if (aes_gcm_rekey_if_required(aes_gcm_crypter, nonce, error_details) != + GRPC_STATUS_OK) { + aes_gcm_format_errors("Rekeying failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + // mask nonce if required + const uint8_t* nonce_aead = nonce; + uint8_t nonce_masked[kAesGcmNonceLength]; + if (aes_gcm_crypter->rekey_data != nullptr) { + aes_gcm_mask_nonce(nonce_masked, aes_gcm_crypter->rekey_data->nonce_mask, + nonce); + nonce_aead = nonce_masked; + } + // init openssl context + if (!EVP_DecryptInit_ex(aes_gcm_crypter->ctx, nullptr, nullptr, nullptr, + nonce_aead)) { + aes_gcm_format_errors("Initializing nonce failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + // process aad + for (i = 0; i < aad_vec_length; i++) { + const uint8_t* aad = static_cast(aad_vec[i].iov_base); + size_t aad_length = aad_vec[i].iov_len; + if (aad_length == 0) { + continue; + } + size_t aad_bytes_read = 0; + if (aad == nullptr) { + aes_gcm_format_errors("aad is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (!EVP_DecryptUpdate(aes_gcm_crypter->ctx, nullptr, + reinterpret_cast(&aad_bytes_read), aad, + static_cast(aad_length)) || + aad_bytes_read != aad_length) { + aes_gcm_format_errors("Setting authenticated associated data failed.", + error_details); + return GRPC_STATUS_INTERNAL; + } + } + // process ciphertext + uint8_t* plaintext = static_cast(plaintext_vec.iov_base); + size_t plaintext_length = plaintext_vec.iov_len; + if (plaintext_length > 0 && plaintext == nullptr) { + aes_gcm_format_errors( + "plaintext is nullptr, but plaintext_length is positive.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + const uint8_t* ciphertext = nullptr; + size_t ciphertext_length = 0; + for (i = 0; + i < ciphertext_vec_length && total_ciphertext_length > kAesGcmTagLength; + i++) { + ciphertext = static_cast(ciphertext_vec[i].iov_base); + ciphertext_length = ciphertext_vec[i].iov_len; + if (ciphertext == nullptr) { + if (ciphertext_length == 0) { + continue; + } + aes_gcm_format_errors("ciphertext is nullptr.", error_details); + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + return GRPC_STATUS_INVALID_ARGUMENT; + } + size_t bytes_written = 0; + size_t bytes_to_write = ciphertext_length; + // Don't include the tag + if (bytes_to_write > total_ciphertext_length - kAesGcmTagLength) { + bytes_to_write = total_ciphertext_length - kAesGcmTagLength; + } + if (plaintext_length < bytes_to_write) { + aes_gcm_format_errors( + "Not enough plaintext buffer to hold encrypted ciphertext.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (!EVP_DecryptUpdate(aes_gcm_crypter->ctx, plaintext, + reinterpret_cast(&bytes_written), ciphertext, + static_cast(bytes_to_write))) { + aes_gcm_format_errors("Decrypting ciphertext failed.", error_details); + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + return GRPC_STATUS_INTERNAL; + } + if (bytes_written > ciphertext_length) { + aes_gcm_format_errors("More 
bytes written than expected.", error_details); + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + return GRPC_STATUS_INTERNAL; + } + ciphertext += bytes_written; + ciphertext_length -= bytes_written; + total_ciphertext_length -= bytes_written; + plaintext += bytes_written; + plaintext_length -= bytes_written; + } + if (total_ciphertext_length > kAesGcmTagLength) { + aes_gcm_format_errors( + "Not enough plaintext buffer to hold encrypted ciphertext.", + error_details); + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + return GRPC_STATUS_INVALID_ARGUMENT; + } + uint8_t tag[kAesGcmTagLength]; + uint8_t* tag_tmp = tag; + if (ciphertext_length > 0) { + memcpy(tag_tmp, ciphertext, ciphertext_length); + tag_tmp += ciphertext_length; + total_ciphertext_length -= ciphertext_length; + } + for (; i < ciphertext_vec_length; i++) { + ciphertext = static_cast(ciphertext_vec[i].iov_base); + ciphertext_length = ciphertext_vec[i].iov_len; + if (ciphertext == nullptr) { + if (ciphertext_length == 0) { + continue; + } + aes_gcm_format_errors("ciphertext is nullptr.", error_details); + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + return GRPC_STATUS_INVALID_ARGUMENT; + } + memcpy(tag_tmp, ciphertext, ciphertext_length); + tag_tmp += ciphertext_length; + total_ciphertext_length -= ciphertext_length; + } + if (!EVP_CIPHER_CTX_ctrl(aes_gcm_crypter->ctx, EVP_CTRL_GCM_SET_TAG, + kAesGcmTagLength, reinterpret_cast(tag))) { + aes_gcm_format_errors("Setting tag failed.", error_details); + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + return GRPC_STATUS_INTERNAL; + } + int bytes_written_temp = 0; + if (!EVP_DecryptFinal_ex(aes_gcm_crypter->ctx, nullptr, + &bytes_written_temp)) { + aes_gcm_format_errors("Checking tag failed.", error_details); + if (plaintext_vec.iov_base != nullptr) { + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + } + return GRPC_STATUS_FAILED_PRECONDITION; + } + if (bytes_written_temp != 0) { + aes_gcm_format_errors("Openssl wrote some unexpected bytes.", + error_details); + memset(plaintext_vec.iov_base, 0x00, plaintext_vec.iov_len); + return GRPC_STATUS_INTERNAL; + } + *plaintext_bytes_written = plaintext_vec.iov_len - plaintext_length; + return GRPC_STATUS_OK; +} + +static void gsec_aes_gcm_aead_crypter_destroy(gsec_aead_crypter* crypter) { + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + reinterpret_cast( + const_cast(crypter)); + gpr_free(aes_gcm_crypter->key); + gpr_free(aes_gcm_crypter->rekey_data); + EVP_CIPHER_CTX_free(aes_gcm_crypter->ctx); +} + +static const gsec_aead_crypter_vtable vtable = { + gsec_aes_gcm_aead_crypter_encrypt_iovec, + gsec_aes_gcm_aead_crypter_decrypt_iovec, + gsec_aes_gcm_aead_crypter_max_ciphertext_and_tag_length, + gsec_aes_gcm_aead_crypter_max_plaintext_length, + gsec_aes_gcm_aead_crypter_nonce_length, + gsec_aes_gcm_aead_crypter_key_length, + gsec_aes_gcm_aead_crypter_tag_length, + gsec_aes_gcm_aead_crypter_destroy}; + +static grpc_status_code aes_gcm_new_evp_cipher_ctx( + gsec_aes_gcm_aead_crypter* aes_gcm_crypter, char** error_details) { + const EVP_CIPHER* cipher = nullptr; + bool is_rekey = aes_gcm_crypter->rekey_data != nullptr; + switch (is_rekey ? 
kRekeyAeadKeyLen : aes_gcm_crypter->key_length) { + case kAes128GcmKeyLength: + cipher = EVP_aes_128_gcm(); + break; + case kAes256GcmKeyLength: + cipher = EVP_aes_256_gcm(); + break; + } + const uint8_t* aead_key = aes_gcm_crypter->key; + uint8_t aead_key_rekey[kRekeyAeadKeyLen]; + if (is_rekey) { + if (aes_gcm_derive_aead_key(aead_key_rekey, aes_gcm_crypter->key, + aes_gcm_crypter->rekey_data->kdf_counter) != + GRPC_STATUS_OK) { + aes_gcm_format_errors("Deriving key failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + aead_key = aead_key_rekey; + } + if (!EVP_DecryptInit_ex(aes_gcm_crypter->ctx, cipher, nullptr, aead_key, + nullptr)) { + aes_gcm_format_errors("Setting key failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + if (!EVP_CIPHER_CTX_ctrl(aes_gcm_crypter->ctx, EVP_CTRL_GCM_SET_IVLEN, + static_cast(aes_gcm_crypter->nonce_length), + nullptr)) { + aes_gcm_format_errors("Setting nonce length failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + return GRPC_STATUS_OK; +} + +grpc_status_code gsec_aes_gcm_aead_crypter_create(const uint8_t* key, + size_t key_length, + size_t nonce_length, + size_t tag_length, bool rekey, + gsec_aead_crypter** crypter, + char** error_details) { + if (key == nullptr) { + aes_gcm_format_errors("key is nullptr.", error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + if (crypter == nullptr) { + aes_gcm_format_errors("crypter is nullptr.", error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + *crypter = nullptr; + if ((rekey && key_length != kAes128GcmRekeyKeyLength) || + (!rekey && key_length != kAes128GcmKeyLength && + key_length != kAes256GcmKeyLength) || + (tag_length != kAesGcmTagLength) || + (nonce_length != kAesGcmNonceLength)) { + aes_gcm_format_errors( + "Invalid key and/or nonce and/or tag length are provided at AEAD " + "crypter instance construction time.", + error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + gsec_aes_gcm_aead_crypter* aes_gcm_crypter = + static_cast( + gpr_malloc(sizeof(gsec_aes_gcm_aead_crypter))); + aes_gcm_crypter->crypter.vtable = &vtable; + aes_gcm_crypter->nonce_length = nonce_length; + aes_gcm_crypter->tag_length = tag_length; + if (rekey) { + aes_gcm_crypter->key_length = kKdfKeyLen; + aes_gcm_crypter->rekey_data = static_cast( + gpr_malloc(sizeof(gsec_aes_gcm_aead_rekey_data))); + memcpy(aes_gcm_crypter->rekey_data->nonce_mask, key + kKdfKeyLen, + kAesGcmNonceLength); + // Set kdf_counter to all-zero for initial key derivation. + memset(aes_gcm_crypter->rekey_data->kdf_counter, 0, kKdfCounterLen); + } else { + aes_gcm_crypter->key_length = key_length; + aes_gcm_crypter->rekey_data = nullptr; + } + aes_gcm_crypter->key = + static_cast(gpr_malloc(aes_gcm_crypter->key_length)); + memcpy(aes_gcm_crypter->key, key, aes_gcm_crypter->key_length); + aes_gcm_crypter->ctx = EVP_CIPHER_CTX_new(); + grpc_status_code status = + aes_gcm_new_evp_cipher_ctx(aes_gcm_crypter, error_details); + if (status != GRPC_STATUS_OK) { + gsec_aes_gcm_aead_crypter_destroy(&aes_gcm_crypter->crypter); + gpr_free(aes_gcm_crypter); + return status; + } + *crypter = &aes_gcm_crypter->crypter; + return GRPC_STATUS_OK; +} diff --git a/src/core/tsi/alts/crypt/gsec.cc b/src/core/tsi/alts/crypt/gsec.cc new file mode 100644 index 00000000..d4de9cdb --- /dev/null +++ b/src/core/tsi/alts/crypt/gsec.cc @@ -0,0 +1,190 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/crypt/gsec.h" + +#include +#include + +#include + +static const char vtable_error_msg[] = + "crypter or crypter->vtable has not been initialized properly"; + +static void maybe_copy_error_msg(const char* src, char** dst) { + if (dst != nullptr && src != nullptr) { + *dst = static_cast(gpr_malloc(strlen(src) + 1)); + memcpy(*dst, src, strlen(src) + 1); + } +} + +grpc_status_code gsec_aead_crypter_encrypt( + gsec_aead_crypter* crypter, const uint8_t* nonce, size_t nonce_length, + const uint8_t* aad, size_t aad_length, const uint8_t* plaintext, + size_t plaintext_length, uint8_t* ciphertext_and_tag, + size_t ciphertext_and_tag_length, size_t* bytes_written, + char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->encrypt_iovec != nullptr) { + struct iovec aad_vec = {const_cast(aad), aad_length}; + struct iovec plaintext_vec = {const_cast(plaintext), + plaintext_length}; + struct iovec ciphertext_vec = {ciphertext_and_tag, + ciphertext_and_tag_length}; + return crypter->vtable->encrypt_iovec( + crypter, nonce, nonce_length, &aad_vec, 1, &plaintext_vec, 1, + ciphertext_vec, bytes_written, error_details); + } + /* An error occurred. */ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_encrypt_iovec( + gsec_aead_crypter* crypter, const uint8_t* nonce, size_t nonce_length, + const struct iovec* aad_vec, size_t aad_vec_length, + const struct iovec* plaintext_vec, size_t plaintext_vec_length, + struct iovec ciphertext_vec, size_t* ciphertext_bytes_written, + char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->encrypt_iovec != nullptr) { + return crypter->vtable->encrypt_iovec( + crypter, nonce, nonce_length, aad_vec, aad_vec_length, plaintext_vec, + plaintext_vec_length, ciphertext_vec, ciphertext_bytes_written, + error_details); + } + /* An error occurred. */ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_decrypt( + gsec_aead_crypter* crypter, const uint8_t* nonce, size_t nonce_length, + const uint8_t* aad, size_t aad_length, const uint8_t* ciphertext_and_tag, + size_t ciphertext_and_tag_length, uint8_t* plaintext, + size_t plaintext_length, size_t* bytes_written, char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->decrypt_iovec != nullptr) { + struct iovec aad_vec = {const_cast(aad), aad_length}; + struct iovec ciphertext_vec = {const_cast(ciphertext_and_tag), + ciphertext_and_tag_length}; + struct iovec plaintext_vec = {plaintext, plaintext_length}; + return crypter->vtable->decrypt_iovec( + crypter, nonce, nonce_length, &aad_vec, 1, &ciphertext_vec, 1, + plaintext_vec, bytes_written, error_details); + } + /* An error occurred. 
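   Reaching this branch means crypter or its vtable was left uninitialized; as
   in the encrypt wrappers above, this is reported as
   GRPC_STATUS_INVALID_ARGUMENT together with vtable_error_msg instead of
   dereferencing a null pointer.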
*/ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_decrypt_iovec( + gsec_aead_crypter* crypter, const uint8_t* nonce, size_t nonce_length, + const struct iovec* aad_vec, size_t aad_vec_length, + const struct iovec* ciphertext_vec, size_t ciphertext_vec_length, + struct iovec plaintext_vec, size_t* plaintext_bytes_written, + char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->encrypt_iovec != nullptr) { + return crypter->vtable->decrypt_iovec( + crypter, nonce, nonce_length, aad_vec, aad_vec_length, ciphertext_vec, + ciphertext_vec_length, plaintext_vec, plaintext_bytes_written, + error_details); + } + /* An error occurred. */ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_max_ciphertext_and_tag_length( + const gsec_aead_crypter* crypter, size_t plaintext_length, + size_t* max_ciphertext_and_tag_length_to_return, char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->max_ciphertext_and_tag_length != nullptr) { + return crypter->vtable->max_ciphertext_and_tag_length( + crypter, plaintext_length, max_ciphertext_and_tag_length_to_return, + error_details); + } + /* An error occurred. */ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_max_plaintext_length( + const gsec_aead_crypter* crypter, size_t ciphertext_and_tag_length, + size_t* max_plaintext_length_to_return, char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->max_plaintext_length != nullptr) { + return crypter->vtable->max_plaintext_length( + crypter, ciphertext_and_tag_length, max_plaintext_length_to_return, + error_details); + } + /* An error occurred. */ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_nonce_length( + const gsec_aead_crypter* crypter, size_t* nonce_length_to_return, + char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->nonce_length != nullptr) { + return crypter->vtable->nonce_length(crypter, nonce_length_to_return, + error_details); + } + /* An error occurred. */ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_key_length(const gsec_aead_crypter* crypter, + size_t* key_length_to_return, + char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->key_length != nullptr) { + return crypter->vtable->key_length(crypter, key_length_to_return, + error_details); + } + /* An error occurred */ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +grpc_status_code gsec_aead_crypter_tag_length(const gsec_aead_crypter* crypter, + size_t* tag_length_to_return, + char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->tag_length != nullptr) { + return crypter->vtable->tag_length(crypter, tag_length_to_return, + error_details); + } + /* An error occurred. 
*/ + maybe_copy_error_msg(vtable_error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +void gsec_aead_crypter_destroy(gsec_aead_crypter* crypter) { + if (crypter != nullptr) { + if (crypter->vtable != nullptr && crypter->vtable->destruct != nullptr) { + crypter->vtable->destruct(crypter); + } + gpr_free(crypter); + } +} diff --git a/src/core/tsi/alts/frame_protector/alts_counter.cc b/src/core/tsi/alts/frame_protector/alts_counter.cc new file mode 100644 index 00000000..de163e3e --- /dev/null +++ b/src/core/tsi/alts/frame_protector/alts_counter.cc @@ -0,0 +1,118 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/frame_protector/alts_counter.h" + +#include + +#include + +static void maybe_copy_error_msg(const char* src, char** dst) { + if (dst != nullptr && src != nullptr) { + *dst = static_cast(gpr_malloc(strlen(src) + 1)); + memcpy(*dst, src, strlen(src) + 1); + } +} + +grpc_status_code alts_counter_create(bool is_client, size_t counter_size, + size_t overflow_size, + alts_counter** crypter_counter, + char** error_details) { + /* Perform input sanity check. */ + if (counter_size == 0) { + const char error_msg[] = "counter_size is invalid."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (overflow_size == 0 || overflow_size >= counter_size) { + const char error_msg[] = "overflow_size is invalid."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (crypter_counter == nullptr) { + const char error_msg[] = "crypter_counter is nullptr."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + *crypter_counter = + static_cast(gpr_malloc(sizeof(**crypter_counter))); + (*crypter_counter)->size = counter_size; + (*crypter_counter)->overflow_size = overflow_size; + (*crypter_counter)->counter = + static_cast(gpr_zalloc(counter_size)); + if (is_client) { + ((*crypter_counter)->counter)[counter_size - 1] = 0x80; + } + return GRPC_STATUS_OK; +} + +grpc_status_code alts_counter_increment(alts_counter* crypter_counter, + bool* is_overflow, + char** error_details) { + /* Perform input sanity check. */ + if (crypter_counter == nullptr) { + const char error_msg[] = "crypter_counter is nullptr."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (is_overflow == nullptr) { + const char error_msg[] = "is_overflow is nullptr."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + /* Increment the internal counter. */ + size_t i = 0; + for (; i < crypter_counter->overflow_size; i++) { + (crypter_counter->counter)[i]++; + if ((crypter_counter->counter)[i] != 0x00) { + break; + } + } + /** + * If the lower overflow_size bytes are all zero, the counter has overflowed. 
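 * The counter is little-endian over its first overflow_size bytes: byte 0 is
 * bumped first and the carry only propagates while a byte wraps to 0x00, so
 * with overflow_size == 5, for instance, roughly 2^40 increments are possible
 * before all five low bytes wrap in the same call and trip this check. (The
 * client side also sets 0x80 in the most significant counter byte at creation
 * time, presumably to keep client and server nonces disjoint.)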
+ */ + if (i == crypter_counter->overflow_size) { + *is_overflow = true; + return GRPC_STATUS_FAILED_PRECONDITION; + } + *is_overflow = false; + return GRPC_STATUS_OK; +} + +size_t alts_counter_get_size(alts_counter* crypter_counter) { + if (crypter_counter == nullptr) { + return 0; + } + return crypter_counter->size; +} + +unsigned char* alts_counter_get_counter(alts_counter* crypter_counter) { + if (crypter_counter == nullptr) { + return nullptr; + } + return crypter_counter->counter; +} + +void alts_counter_destroy(alts_counter* crypter_counter) { + if (crypter_counter != nullptr) { + gpr_free(crypter_counter->counter); + gpr_free(crypter_counter); + } +} diff --git a/src/core/tsi/alts/frame_protector/alts_crypter.cc b/src/core/tsi/alts/frame_protector/alts_crypter.cc new file mode 100644 index 00000000..56f05121 --- /dev/null +++ b/src/core/tsi/alts/frame_protector/alts_crypter.cc @@ -0,0 +1,66 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/frame_protector/alts_crypter.h" + +#include + +#include + +static void maybe_copy_error_msg(const char* src, char** dst) { + if (dst != nullptr && src != nullptr) { + *dst = static_cast(gpr_malloc(strlen(src) + 1)); + memcpy(*dst, src, strlen(src) + 1); + } +} + +grpc_status_code alts_crypter_process_in_place( + alts_crypter* crypter, unsigned char* data, size_t data_allocated_size, + size_t data_size, size_t* output_size, char** error_details) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->process_in_place != nullptr) { + return crypter->vtable->process_in_place(crypter, data, data_allocated_size, + data_size, output_size, + error_details); + } + /* An error occurred. */ + const char error_msg[] = + "crypter or crypter->vtable has not been initialized properly."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; +} + +size_t alts_crypter_num_overhead_bytes(const alts_crypter* crypter) { + if (crypter != nullptr && crypter->vtable != nullptr && + crypter->vtable->num_overhead_bytes != nullptr) { + return crypter->vtable->num_overhead_bytes(crypter); + } + /* An error occurred. */ + return 0; +} + +void alts_crypter_destroy(alts_crypter* crypter) { + if (crypter != nullptr) { + if (crypter->vtable != nullptr && crypter->vtable->destruct != nullptr) { + crypter->vtable->destruct(crypter); + } + gpr_free(crypter); + } +} diff --git a/src/core/tsi/alts/frame_protector/alts_frame_protector.cc b/src/core/tsi/alts/frame_protector/alts_frame_protector.cc new file mode 100644 index 00000000..aae92a47 --- /dev/null +++ b/src/core/tsi/alts/frame_protector/alts_frame_protector.cc @@ -0,0 +1,408 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/frame_protector/alts_frame_protector.h" + +#include +#include + +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/tsi/alts/crypt/gsec.h" +#include "src/core/tsi/alts/frame_protector/alts_crypter.h" +#include "src/core/tsi/alts/frame_protector/frame_handler.h" +#include "src/core/tsi/transport_security.h" + +constexpr size_t kMinFrameLength = 1024; +constexpr size_t kDefaultFrameLength = 16 * 1024; +constexpr size_t kMaxFrameLength = 1024 * 1024; + +// Limit k on number of frames such that at most 2^(8 * k) frames can be sent. +constexpr size_t kAltsRecordProtocolRekeyFrameLimit = 8; +constexpr size_t kAltsRecordProtocolFrameLimit = 5; + +/* Main struct for alts_frame_protector. */ +struct alts_frame_protector { + tsi_frame_protector base; + alts_crypter* seal_crypter; + alts_crypter* unseal_crypter; + alts_frame_writer* writer; + alts_frame_reader* reader; + unsigned char* in_place_protect_buffer; + unsigned char* in_place_unprotect_buffer; + size_t in_place_protect_bytes_buffered; + size_t in_place_unprotect_bytes_processed; + size_t max_protected_frame_size; + size_t max_unprotected_frame_size; + size_t overhead_length; + size_t counter_overflow; +}; + +static tsi_result seal(alts_frame_protector* impl) { + char* error_details = nullptr; + size_t output_size = 0; + grpc_status_code status = alts_crypter_process_in_place( + impl->seal_crypter, impl->in_place_protect_buffer, + impl->max_protected_frame_size, impl->in_place_protect_bytes_buffered, + &output_size, &error_details); + impl->in_place_protect_bytes_buffered = output_size; + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "%s", error_details); + gpr_free(error_details); + return TSI_INTERNAL_ERROR; + } + return TSI_OK; +} + +static size_t max_encrypted_payload_bytes(alts_frame_protector* impl) { + return impl->max_protected_frame_size - kFrameHeaderSize; +} + +static tsi_result alts_protect_flush(tsi_frame_protector* self, + unsigned char* protected_output_frames, + size_t* protected_output_frames_size, + size_t* still_pending_size) { + if (self == nullptr || protected_output_frames == nullptr || + protected_output_frames_size == nullptr || + still_pending_size == nullptr) { + gpr_log(GPR_ERROR, "Invalid nullptr arguments to alts_protect_flush()."); + return TSI_INVALID_ARGUMENT; + } + alts_frame_protector* impl = reinterpret_cast(self); + /** + * If there's nothing to flush (i.e., in_place_protect_buffer is empty), + * we're done. + */ + if (impl->in_place_protect_bytes_buffered == 0) { + *protected_output_frames_size = 0; + *still_pending_size = 0; + return TSI_OK; + } + /** + * If a new frame can start being processed, we encrypt the payload and reset + * the frame writer to point to in_place_protect_buffer that holds the newly + * sealed frame. 
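 * A sketch of how a caller typically drives this flush, assuming the standard
 * tsi_frame_protector API (send() is a hypothetical transport write):
 *
 *   size_t still_pending = 0;
 *   do {
 *     size_t out_size = out_capacity;
 *     tsi_frame_protector_protect_flush(protector, out, &out_size,
 *                                       &still_pending);
 *     send(out, out_size);
 *   } while (still_pending > 0);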
+ */ + if (alts_is_frame_writer_done(impl->writer)) { + tsi_result result = seal(impl); + if (result != TSI_OK) { + return result; + } + if (!alts_reset_frame_writer(impl->writer, impl->in_place_protect_buffer, + impl->in_place_protect_bytes_buffered)) { + gpr_log(GPR_ERROR, "Couldn't reset frame writer."); + return TSI_INTERNAL_ERROR; + } + } + /** + * Write the sealed frame as much as possible to protected_output_frames. It's + * possible a frame will not be written out completely by a single flush + * (i.e., still_pending_size != 0), in which case the flush should be called + * iteratively until a complete frame has been written out. + */ + size_t written_frame_bytes = *protected_output_frames_size; + if (!alts_write_frame_bytes(impl->writer, protected_output_frames, + &written_frame_bytes)) { + gpr_log(GPR_ERROR, "Couldn't write frame bytes."); + return TSI_INTERNAL_ERROR; + } + *protected_output_frames_size = written_frame_bytes; + *still_pending_size = alts_get_num_writer_bytes_remaining(impl->writer); + /** + * If the current frame has been finished processing (i.e., sealed and written + * out completely), we empty in_place_protect_buffer. + */ + if (alts_is_frame_writer_done(impl->writer)) { + impl->in_place_protect_bytes_buffered = 0; + } + return TSI_OK; +} + +static tsi_result alts_protect(tsi_frame_protector* self, + const unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size, + unsigned char* protected_output_frames, + size_t* protected_output_frames_size) { + if (self == nullptr || unprotected_bytes == nullptr || + unprotected_bytes_size == nullptr || protected_output_frames == nullptr || + protected_output_frames_size == nullptr) { + gpr_log(GPR_ERROR, "Invalid nullptr arguments to alts_protect()."); + return TSI_INVALID_ARGUMENT; + } + alts_frame_protector* impl = reinterpret_cast(self); + + /** + * If more payload can be buffered, we buffer it as much as possible to + * in_place_protect_buffer. + */ + if (impl->in_place_protect_bytes_buffered + impl->overhead_length < + max_encrypted_payload_bytes(impl)) { + size_t bytes_to_buffer = std::min( + *unprotected_bytes_size, max_encrypted_payload_bytes(impl) - + impl->in_place_protect_bytes_buffered - + impl->overhead_length); + *unprotected_bytes_size = bytes_to_buffer; + if (bytes_to_buffer > 0) { + memcpy( + impl->in_place_protect_buffer + impl->in_place_protect_bytes_buffered, + unprotected_bytes, bytes_to_buffer); + impl->in_place_protect_bytes_buffered += bytes_to_buffer; + } + } else { + *unprotected_bytes_size = 0; + } + /** + * If a full frame has been buffered, we output it. If the first condition + * holds, then there exists an unencrypted full frame. If the second + * condition holds, then there exists a full frame that has already been + * encrypted. 
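 * (The second case covers a frame that was sealed by an earlier flush but not
 * yet fully written out: seal() grows the buffered byte count by
 * overhead_length, so the already-encrypted frame now equals
 * max_encrypted_payload_bytes(impl) exactly.)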
+ */ + if (max_encrypted_payload_bytes(impl) == + impl->in_place_protect_bytes_buffered + impl->overhead_length || + max_encrypted_payload_bytes(impl) == + impl->in_place_protect_bytes_buffered) { + size_t still_pending_size = 0; + return alts_protect_flush(self, protected_output_frames, + protected_output_frames_size, + &still_pending_size); + } else { + *protected_output_frames_size = 0; + return TSI_OK; + } +} + +static tsi_result unseal(alts_frame_protector* impl) { + char* error_details = nullptr; + size_t output_size = 0; + grpc_status_code status = alts_crypter_process_in_place( + impl->unseal_crypter, impl->in_place_unprotect_buffer, + impl->max_unprotected_frame_size, + alts_get_output_bytes_read(impl->reader), &output_size, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "%s", error_details); + gpr_free(error_details); + return TSI_DATA_CORRUPTED; + } + return TSI_OK; +} + +static void ensure_buffer_size(alts_frame_protector* impl) { + if (!alts_has_read_frame_length(impl->reader)) { + return; + } + size_t buffer_space_remaining = impl->max_unprotected_frame_size - + alts_get_output_bytes_read(impl->reader); + /** + * Check if we need to resize in_place_unprotect_buffer in order to hold + * remaining bytes of a full frame. + */ + if (buffer_space_remaining < alts_get_reader_bytes_remaining(impl->reader)) { + size_t buffer_len = alts_get_output_bytes_read(impl->reader) + + alts_get_reader_bytes_remaining(impl->reader); + unsigned char* buffer = static_cast(gpr_malloc(buffer_len)); + memcpy(buffer, impl->in_place_unprotect_buffer, + alts_get_output_bytes_read(impl->reader)); + impl->max_unprotected_frame_size = buffer_len; + gpr_free(impl->in_place_unprotect_buffer); + impl->in_place_unprotect_buffer = buffer; + alts_reset_reader_output_buffer( + impl->reader, buffer + alts_get_output_bytes_read(impl->reader)); + } +} + +static tsi_result alts_unprotect(tsi_frame_protector* self, + const unsigned char* protected_frames_bytes, + size_t* protected_frames_bytes_size, + unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size) { + if (self == nullptr || protected_frames_bytes == nullptr || + protected_frames_bytes_size == nullptr || unprotected_bytes == nullptr || + unprotected_bytes_size == nullptr) { + gpr_log(GPR_ERROR, "Invalid nullptr arguments to alts_unprotect()."); + return TSI_INVALID_ARGUMENT; + } + alts_frame_protector* impl = reinterpret_cast(self); + /** + * If a new frame can start being processed, we reset the frame reader to + * point to in_place_unprotect_buffer that will be used to hold deframed + * result. + */ + if (alts_is_frame_reader_done(impl->reader) && + ((alts_get_output_buffer(impl->reader) == nullptr) || + (alts_get_output_bytes_read(impl->reader) == + impl->in_place_unprotect_bytes_processed + impl->overhead_length))) { + if (!alts_reset_frame_reader(impl->reader, + impl->in_place_unprotect_buffer)) { + gpr_log(GPR_ERROR, "Couldn't reset frame reader."); + return TSI_INTERNAL_ERROR; + } + impl->in_place_unprotect_bytes_processed = 0; + } + /** + * If a full frame has not yet been read, we read more bytes from + * protected_frames_bytes until a full frame has been read. We also need to + * make sure in_place_unprotect_buffer is large enough to hold a complete + * frame. 
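 * (ensure_buffer_size() below performs the resize: once the frame length has
 * been read from the frame header, in_place_unprotect_buffer is grown to fit
 * the whole frame and the reader's output pointer is re-pointed at the new
 * buffer.)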
+ */ + if (!alts_is_frame_reader_done(impl->reader)) { + ensure_buffer_size(impl); + *protected_frames_bytes_size = + std::min(impl->max_unprotected_frame_size - + alts_get_output_bytes_read(impl->reader), + *protected_frames_bytes_size); + size_t read_frames_bytes_size = *protected_frames_bytes_size; + if (!alts_read_frame_bytes(impl->reader, protected_frames_bytes, + &read_frames_bytes_size)) { + gpr_log(GPR_ERROR, "Failed to process frame."); + return TSI_INTERNAL_ERROR; + } + *protected_frames_bytes_size = read_frames_bytes_size; + } else { + *protected_frames_bytes_size = 0; + } + /** + * If a full frame has been read, we unseal it, and write out the + * deframed result to unprotected_bytes. + */ + if (alts_is_frame_reader_done(impl->reader)) { + if (impl->in_place_unprotect_bytes_processed == 0) { + tsi_result result = unseal(impl); + if (result != TSI_OK) { + return result; + } + } + size_t bytes_to_write = std::min( + *unprotected_bytes_size, alts_get_output_bytes_read(impl->reader) - + impl->in_place_unprotect_bytes_processed - + impl->overhead_length); + if (bytes_to_write > 0) { + memcpy(unprotected_bytes, + impl->in_place_unprotect_buffer + + impl->in_place_unprotect_bytes_processed, + bytes_to_write); + } + *unprotected_bytes_size = bytes_to_write; + impl->in_place_unprotect_bytes_processed += bytes_to_write; + return TSI_OK; + } else { + *unprotected_bytes_size = 0; + return TSI_OK; + } +} + +static void alts_destroy(tsi_frame_protector* self) { + alts_frame_protector* impl = reinterpret_cast(self); + if (impl != nullptr) { + alts_crypter_destroy(impl->seal_crypter); + alts_crypter_destroy(impl->unseal_crypter); + gpr_free(impl->in_place_protect_buffer); + gpr_free(impl->in_place_unprotect_buffer); + alts_destroy_frame_writer(impl->writer); + alts_destroy_frame_reader(impl->reader); + gpr_free(impl); + } +} + +static const tsi_frame_protector_vtable alts_frame_protector_vtable = { + alts_protect, alts_protect_flush, alts_unprotect, alts_destroy}; + +static grpc_status_code create_alts_crypters(const uint8_t* key, + size_t key_size, bool is_client, + bool is_rekey, + alts_frame_protector* impl, + char** error_details) { + grpc_status_code status; + gsec_aead_crypter* aead_crypter_seal = nullptr; + gsec_aead_crypter* aead_crypter_unseal = nullptr; + status = gsec_aes_gcm_aead_crypter_create(key, key_size, kAesGcmNonceLength, + kAesGcmTagLength, is_rekey, + &aead_crypter_seal, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + status = gsec_aes_gcm_aead_crypter_create( + key, key_size, kAesGcmNonceLength, kAesGcmTagLength, is_rekey, + &aead_crypter_unseal, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + size_t overflow_size = is_rekey ? 
kAltsRecordProtocolRekeyFrameLimit + : kAltsRecordProtocolFrameLimit; + status = alts_seal_crypter_create(aead_crypter_seal, is_client, overflow_size, + &impl->seal_crypter, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + status = + alts_unseal_crypter_create(aead_crypter_unseal, is_client, overflow_size, + &impl->unseal_crypter, error_details); + return status; +} + +tsi_result alts_create_frame_protector(const uint8_t* key, size_t key_size, + bool is_client, bool is_rekey, + size_t* max_protected_frame_size, + tsi_frame_protector** self) { + if (key == nullptr || self == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to alts_create_frame_protector()."); + return TSI_INTERNAL_ERROR; + } + char* error_details = nullptr; + alts_frame_protector* impl = grpc_core::Zalloc(); + grpc_status_code status = create_alts_crypters( + key, key_size, is_client, is_rekey, impl, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to create ALTS crypters, %s.", error_details); + gpr_free(error_details); + return TSI_INTERNAL_ERROR; + } + /** + * Set maximum frame size to be used by a frame protector. If it is nullptr, a + * default frame size will be used. Otherwise, the provided frame size will be + * adjusted (if not falling into a valid frame range) and used. + */ + size_t max_protected_frame_size_to_set = kDefaultFrameLength; + if (max_protected_frame_size != nullptr) { + *max_protected_frame_size = + std::min(*max_protected_frame_size, kMaxFrameLength); + *max_protected_frame_size = + std::max(*max_protected_frame_size, kMinFrameLength); + max_protected_frame_size_to_set = *max_protected_frame_size; + } + impl->max_protected_frame_size = max_protected_frame_size_to_set; + impl->max_unprotected_frame_size = max_protected_frame_size_to_set; + impl->in_place_protect_bytes_buffered = 0; + impl->in_place_unprotect_bytes_processed = 0; + impl->in_place_protect_buffer = static_cast( + gpr_malloc(sizeof(unsigned char) * max_protected_frame_size_to_set)); + impl->in_place_unprotect_buffer = static_cast( + gpr_malloc(sizeof(unsigned char) * max_protected_frame_size_to_set)); + impl->overhead_length = alts_crypter_num_overhead_bytes(impl->seal_crypter); + impl->writer = alts_create_frame_writer(); + impl->reader = alts_create_frame_reader(); + impl->base.vtable = &alts_frame_protector_vtable; + *self = &impl->base; + return TSI_OK; +} diff --git a/src/core/tsi/alts/frame_protector/alts_record_protocol_crypter_common.cc b/src/core/tsi/alts/frame_protector/alts_record_protocol_crypter_common.cc new file mode 100644 index 00000000..0574ed50 --- /dev/null +++ b/src/core/tsi/alts/frame_protector/alts_record_protocol_crypter_common.cc @@ -0,0 +1,114 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include
+
+#include "src/core/tsi/alts/frame_protector/alts_record_protocol_crypter_common.h"
+
+#include
+
+static void maybe_copy_error_msg(const char* src, char** dst) {
+  if (dst != nullptr && src != nullptr) {
+    *dst = static_cast<char*>(gpr_malloc(strlen(src) + 1));
+    memcpy(*dst, src, strlen(src) + 1);
+  }
+}
+
+grpc_status_code input_sanity_check(
+    const alts_record_protocol_crypter* rp_crypter, const unsigned char* data,
+    size_t* output_size, char** error_details) {
+  if (rp_crypter == nullptr) {
+    maybe_copy_error_msg("alts_crypter instance is nullptr.", error_details);
+    return GRPC_STATUS_INVALID_ARGUMENT;
+  } else if (data == nullptr) {
+    maybe_copy_error_msg("data is nullptr.", error_details);
+    return GRPC_STATUS_INVALID_ARGUMENT;
+  } else if (output_size == nullptr) {
+    maybe_copy_error_msg("output_size is nullptr.", error_details);
+    return GRPC_STATUS_INVALID_ARGUMENT;
+  }
+  return GRPC_STATUS_OK;
+}
+
+grpc_status_code increment_counter(alts_record_protocol_crypter* rp_crypter,
+                                   char** error_details) {
+  bool is_overflow = false;
+  grpc_status_code status =
+      alts_counter_increment(rp_crypter->ctr, &is_overflow, error_details);
+  if (status != GRPC_STATUS_OK) {
+    return status;
+  }
+  if (is_overflow) {
+    const char error_msg[] =
+        "crypter counter is wrapped. The connection "
+        "should be closed and the key should be deleted.";
+    maybe_copy_error_msg(error_msg, error_details);
+    return GRPC_STATUS_INTERNAL;
+  }
+  return GRPC_STATUS_OK;
+}
+
+size_t alts_record_protocol_crypter_num_overhead_bytes(const alts_crypter* c) {
+  if (c != nullptr) {
+    size_t num_overhead_bytes = 0;
+    char* error_details = nullptr;
+    const alts_record_protocol_crypter* rp_crypter =
+        reinterpret_cast<const alts_record_protocol_crypter*>(c);
+    grpc_status_code status = gsec_aead_crypter_tag_length(
+        rp_crypter->crypter, &num_overhead_bytes, &error_details);
+    if (status == GRPC_STATUS_OK) {
+      return num_overhead_bytes;
+    }
+  }
+  return 0;
+}
+
+void alts_record_protocol_crypter_destruct(alts_crypter* c) {
+  if (c != nullptr) {
+    alts_record_protocol_crypter* rp_crypter =
+        reinterpret_cast<alts_record_protocol_crypter*>(c);
+    alts_counter_destroy(rp_crypter->ctr);
+    gsec_aead_crypter_destroy(rp_crypter->crypter);
+  }
+}
+
+alts_record_protocol_crypter* alts_crypter_create_common(
+    gsec_aead_crypter* crypter, bool is_client, size_t overflow_size,
+    char** error_details) {
+  if (crypter != nullptr) {
+    auto* rp_crypter = static_cast<alts_record_protocol_crypter*>(
+        gpr_malloc(sizeof(alts_record_protocol_crypter)));
+    size_t counter_size = 0;
+    grpc_status_code status =
+        gsec_aead_crypter_nonce_length(crypter, &counter_size, error_details);
+    if (status != GRPC_STATUS_OK) {
+      return nullptr;
+    }
+    /* Create a counter. */
+    status = alts_counter_create(is_client, counter_size, overflow_size,
+                                 &rp_crypter->ctr, error_details);
+    if (status != GRPC_STATUS_OK) {
+      return nullptr;
+    }
+    rp_crypter->crypter = crypter;
+    return rp_crypter;
+  }
+  const char error_msg[] = "crypter is nullptr.";
+  maybe_copy_error_msg(error_msg, error_details);
+  return nullptr;
+}
diff --git a/src/core/tsi/alts/frame_protector/alts_seal_privacy_integrity_crypter.cc b/src/core/tsi/alts/frame_protector/alts_seal_privacy_integrity_crypter.cc
new file mode 100644
index 00000000..f4078316
--- /dev/null
+++ b/src/core/tsi/alts/frame_protector/alts_seal_privacy_integrity_crypter.cc
@@ -0,0 +1,105 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include
+
+#include
+
+#include "src/core/tsi/alts/frame_protector/alts_counter.h"
+#include "src/core/tsi/alts/frame_protector/alts_crypter.h"
+#include "src/core/tsi/alts/frame_protector/alts_record_protocol_crypter_common.h"
+
+static void maybe_copy_error_msg(const char* src, char** dst) {
+  if (dst != nullptr && src != nullptr) {
+    *dst = static_cast<char*>(gpr_malloc(strlen(src) + 1));
+    memcpy(*dst, src, strlen(src) + 1);
+  }
+}
+
+/* Perform input sanity check for a seal operation. */
+static grpc_status_code seal_check(alts_crypter* c, const unsigned char* data,
+                                   size_t data_allocated_size,
+                                   size_t data_size, size_t* output_size,
+                                   char** error_details) {
+  /* Do common input sanity check. */
+  grpc_status_code status = input_sanity_check(
+      reinterpret_cast<const alts_record_protocol_crypter*>(c), data,
+      output_size, error_details);
+  if (status != GRPC_STATUS_OK) return status;
+  /* Do seal-specific check. */
+  size_t num_overhead_bytes = alts_crypter_num_overhead_bytes(
+      reinterpret_cast<const alts_crypter*>(c));
+  if (data_size == 0) {
+    const char error_msg[] = "data_size is zero.";
+    maybe_copy_error_msg(error_msg, error_details);
+    return GRPC_STATUS_INVALID_ARGUMENT;
+  }
+  if (data_size + num_overhead_bytes > data_allocated_size) {
+    const char error_msg[] =
+        "data_allocated_size is smaller than sum of data_size and "
+        "num_overhead_bytes.";
+    maybe_copy_error_msg(error_msg, error_details);
+    return GRPC_STATUS_INVALID_ARGUMENT;
+  }
+  return GRPC_STATUS_OK;
+}
+
+static grpc_status_code alts_seal_crypter_process_in_place(
+    alts_crypter* c, unsigned char* data, size_t data_allocated_size,
+    size_t data_size, size_t* output_size, char** error_details) {
+  grpc_status_code status = seal_check(c, data, data_allocated_size, data_size,
+                                       output_size, error_details);
+  if (status != GRPC_STATUS_OK) {
+    return status;
+  }
+  /* Do AEAD encryption. */
+  alts_record_protocol_crypter* rp_crypter =
+      reinterpret_cast<alts_record_protocol_crypter*>(c);
+  status = gsec_aead_crypter_encrypt(
+      rp_crypter->crypter, alts_counter_get_counter(rp_crypter->ctr),
+      alts_counter_get_size(rp_crypter->ctr), nullptr /* aad */,
+      0 /* aad_length */, data, data_size, data, data_allocated_size,
+      output_size, error_details);
+  if (status != GRPC_STATUS_OK) {
+    return status;
+  }
+  /* Increment the crypter counter.
*/ + return increment_counter(rp_crypter, error_details); +} + +static const alts_crypter_vtable vtable = { + alts_record_protocol_crypter_num_overhead_bytes, + alts_seal_crypter_process_in_place, alts_record_protocol_crypter_destruct}; + +grpc_status_code alts_seal_crypter_create(gsec_aead_crypter* gc, bool is_client, + size_t overflow_size, + alts_crypter** crypter, + char** error_details) { + if (crypter == nullptr) { + const char error_msg[] = "crypter is nullptr."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + alts_record_protocol_crypter* rp_crypter = + alts_crypter_create_common(gc, !is_client, overflow_size, error_details); + if (rp_crypter == nullptr) { + return GRPC_STATUS_FAILED_PRECONDITION; + } + rp_crypter->base.vtable = &vtable; + *crypter = &rp_crypter->base; + return GRPC_STATUS_OK; +} diff --git a/src/core/tsi/alts/frame_protector/alts_unseal_privacy_integrity_crypter.cc b/src/core/tsi/alts/frame_protector/alts_unseal_privacy_integrity_crypter.cc new file mode 100644 index 00000000..a5b70032 --- /dev/null +++ b/src/core/tsi/alts/frame_protector/alts_unseal_privacy_integrity_crypter.cc @@ -0,0 +1,103 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/tsi/alts/frame_protector/alts_counter.h" +#include "src/core/tsi/alts/frame_protector/alts_crypter.h" +#include "src/core/tsi/alts/frame_protector/alts_record_protocol_crypter_common.h" + +static void maybe_copy_error_msg(const char* src, char** dst) { + if (dst != nullptr && src != nullptr) { + *dst = static_cast(gpr_malloc(strlen(src) + 1)); + memcpy(*dst, src, strlen(src) + 1); + } +} + +/* Perform input santity check. */ +static grpc_status_code unseal_check(alts_crypter* c, const unsigned char* data, + size_t /*data_allocated_size*/, + size_t data_size, size_t* output_size, + char** error_details) { + /* Do common input sanity check. */ + grpc_status_code status = input_sanity_check( + reinterpret_cast(c), data, + output_size, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Do unseal-specific input check. */ + size_t num_overhead_bytes = + alts_crypter_num_overhead_bytes(reinterpret_cast(c)); + if (num_overhead_bytes > data_size) { + const char error_msg[] = "data_size is smaller than num_overhead_bytes."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + return GRPC_STATUS_OK; +} + +static grpc_status_code alts_unseal_crypter_process_in_place( + alts_crypter* c, unsigned char* data, size_t data_allocated_size, + size_t data_size, size_t* output_size, char** error_details) { + grpc_status_code status = unseal_check(c, data, data_allocated_size, + data_size, output_size, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Do AEAD decryption. 
*/ + alts_record_protocol_crypter* rp_crypter = + reinterpret_cast(c); + status = gsec_aead_crypter_decrypt( + rp_crypter->crypter, alts_counter_get_counter(rp_crypter->ctr), + alts_counter_get_size(rp_crypter->ctr), nullptr /* aad */, + 0 /* aad_length */, data, data_size, data, data_allocated_size, + output_size, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Increment the crypter counter. */ + return increment_counter(rp_crypter, error_details); +} + +static const alts_crypter_vtable vtable = { + alts_record_protocol_crypter_num_overhead_bytes, + alts_unseal_crypter_process_in_place, + alts_record_protocol_crypter_destruct}; + +grpc_status_code alts_unseal_crypter_create(gsec_aead_crypter* gc, + bool is_client, + size_t overflow_size, + alts_crypter** crypter, + char** error_details) { + if (crypter == nullptr) { + const char error_msg[] = "crypter is nullptr."; + maybe_copy_error_msg(error_msg, error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + alts_record_protocol_crypter* rp_crypter = + alts_crypter_create_common(gc, is_client, overflow_size, error_details); + if (rp_crypter == nullptr) { + return GRPC_STATUS_FAILED_PRECONDITION; + } + rp_crypter->base.vtable = &vtable; + *crypter = &rp_crypter->base; + return GRPC_STATUS_OK; +} diff --git a/src/core/tsi/alts/frame_protector/frame_handler.cc b/src/core/tsi/alts/frame_protector/frame_handler.cc new file mode 100644 index 00000000..0237f247 --- /dev/null +++ b/src/core/tsi/alts/frame_protector/frame_handler.cc @@ -0,0 +1,219 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/frame_protector/frame_handler.h" + +#include +#include +#include + +#include + +#include +#include + +#include "src/core/lib/gprpp/memory.h" + +/* Use little endian to interpret a string of bytes as uint32_t. */ +static uint32_t load_32_le(const unsigned char* buffer) { + return (static_cast(buffer[3]) << 24) | + (static_cast(buffer[2]) << 16) | + (static_cast(buffer[1]) << 8) | + static_cast(buffer[0]); +} + +/* Store uint32_t as a string of little endian bytes. */ +static void store_32_le(uint32_t value, unsigned char* buffer) { + buffer[3] = static_cast(value >> 24) & 0xFF; + buffer[2] = static_cast(value >> 16) & 0xFF; + buffer[1] = static_cast(value >> 8) & 0xFF; + buffer[0] = static_cast(value) & 0xFF; +} + +/* Frame writer implementation. 
*/ +alts_frame_writer* alts_create_frame_writer() { + return grpc_core::Zalloc(); +} + +bool alts_reset_frame_writer(alts_frame_writer* writer, + const unsigned char* buffer, size_t length) { + if (buffer == nullptr) return false; + size_t max_input_size = SIZE_MAX - kFrameLengthFieldSize; + if (length > max_input_size) { + gpr_log(GPR_ERROR, "length must be at most %zu", max_input_size); + return false; + } + writer->input_buffer = buffer; + writer->input_size = length; + writer->input_bytes_written = 0; + writer->header_bytes_written = 0; + store_32_le( + static_cast(writer->input_size + kFrameMessageTypeFieldSize), + writer->header_buffer); + store_32_le(kFrameMessageType, writer->header_buffer + kFrameLengthFieldSize); + return true; +} + +bool alts_write_frame_bytes(alts_frame_writer* writer, unsigned char* output, + size_t* bytes_size) { + if (bytes_size == nullptr || output == nullptr) return false; + if (alts_is_frame_writer_done(writer)) { + *bytes_size = 0; + return true; + } + size_t bytes_written = 0; + /* Write some header bytes, if needed. */ + if (writer->header_bytes_written != sizeof(writer->header_buffer)) { + size_t bytes_to_write = + std::min(*bytes_size, + sizeof(writer->header_buffer) - writer->header_bytes_written); + memcpy(output, writer->header_buffer + writer->header_bytes_written, + bytes_to_write); + bytes_written += bytes_to_write; + *bytes_size -= bytes_to_write; + writer->header_bytes_written += bytes_to_write; + output += bytes_to_write; + if (writer->header_bytes_written != sizeof(writer->header_buffer)) { + *bytes_size = bytes_written; + return true; + } + } + /* Write some non-header bytes. */ + size_t bytes_to_write = + std::min(writer->input_size - writer->input_bytes_written, *bytes_size); + memcpy(output, writer->input_buffer, bytes_to_write); + writer->input_buffer += bytes_to_write; + bytes_written += bytes_to_write; + writer->input_bytes_written += bytes_to_write; + *bytes_size = bytes_written; + return true; +} + +bool alts_is_frame_writer_done(alts_frame_writer* writer) { + return writer->input_buffer == nullptr || + writer->input_size == writer->input_bytes_written; +} + +size_t alts_get_num_writer_bytes_remaining(alts_frame_writer* writer) { + return (sizeof(writer->header_buffer) - writer->header_bytes_written) + + (writer->input_size - writer->input_bytes_written); +} + +void alts_destroy_frame_writer(alts_frame_writer* writer) { gpr_free(writer); } + +/* Frame reader implementation. */ +alts_frame_reader* alts_create_frame_reader() { + alts_frame_reader* reader = grpc_core::Zalloc(); + return reader; +} + +bool alts_is_frame_reader_done(alts_frame_reader* reader) { + return reader->output_buffer == nullptr || + (reader->header_bytes_read == sizeof(reader->header_buffer) && + reader->bytes_remaining == 0); +} + +bool alts_has_read_frame_length(alts_frame_reader* reader) { + return sizeof(reader->header_buffer) == reader->header_bytes_read; +} + +size_t alts_get_reader_bytes_remaining(alts_frame_reader* reader) { + return alts_has_read_frame_length(reader) ? 
reader->bytes_remaining : 0; +} + +void alts_reset_reader_output_buffer(alts_frame_reader* reader, + unsigned char* buffer) { + reader->output_buffer = buffer; +} + +bool alts_reset_frame_reader(alts_frame_reader* reader, unsigned char* buffer) { + if (buffer == nullptr) return false; + reader->output_buffer = buffer; + reader->bytes_remaining = 0; + reader->header_bytes_read = 0; + reader->output_bytes_read = 0; + return true; +} + +bool alts_read_frame_bytes(alts_frame_reader* reader, + const unsigned char* bytes, size_t* bytes_size) { + if (bytes_size == nullptr) return false; + if (bytes == nullptr) { + *bytes_size = 0; + return false; + } + if (alts_is_frame_reader_done(reader)) { + *bytes_size = 0; + return true; + } + size_t bytes_processed = 0; + /* Process the header, if needed. */ + if (reader->header_bytes_read != sizeof(reader->header_buffer)) { + size_t bytes_to_write = std::min( + *bytes_size, sizeof(reader->header_buffer) - reader->header_bytes_read); + memcpy(reader->header_buffer + reader->header_bytes_read, bytes, + bytes_to_write); + reader->header_bytes_read += bytes_to_write; + bytes_processed += bytes_to_write; + bytes += bytes_to_write; + *bytes_size -= bytes_to_write; + if (reader->header_bytes_read != sizeof(reader->header_buffer)) { + *bytes_size = bytes_processed; + return true; + } + size_t frame_length = load_32_le(reader->header_buffer); + if (frame_length < kFrameMessageTypeFieldSize || + frame_length > kFrameMaxSize) { + gpr_log(GPR_ERROR, + "Bad frame length (should be at least %zu, and at most %zu)", + kFrameMessageTypeFieldSize, kFrameMaxSize); + *bytes_size = 0; + return false; + } + size_t message_type = + load_32_le(reader->header_buffer + kFrameLengthFieldSize); + if (message_type != kFrameMessageType) { + gpr_log(GPR_ERROR, "Unsupported message type %zu (should be %zu)", + message_type, kFrameMessageType); + *bytes_size = 0; + return false; + } + reader->bytes_remaining = frame_length - kFrameMessageTypeFieldSize; + } + /* Process the non-header bytes. */ + size_t bytes_to_write = std::min(*bytes_size, reader->bytes_remaining); + memcpy(reader->output_buffer, bytes, bytes_to_write); + reader->output_buffer += bytes_to_write; + bytes_processed += bytes_to_write; + reader->bytes_remaining -= bytes_to_write; + reader->output_bytes_read += bytes_to_write; + *bytes_size = bytes_processed; + return true; +} + +size_t alts_get_output_bytes_read(alts_frame_reader* reader) { + return reader->output_bytes_read; +} + +unsigned char* alts_get_output_buffer(alts_frame_reader* reader) { + return reader->output_buffer; +} + +void alts_destroy_frame_reader(alts_frame_reader* reader) { gpr_free(reader); } diff --git a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc new file mode 100644 index 00000000..87fd18ed --- /dev/null +++ b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc @@ -0,0 +1,903 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
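The frame writer and reader above agree on a small wire format: a little-endian length field, a little-endian message-type field, then the payload, where the length counts the message-type field plus the payload but not itself; the reader rejects lengths below kFrameMessageTypeFieldSize or above kFrameMaxSize. A standalone sketch of header construction follows, assuming the conventional 4-byte field widths and a placeholder message-type value (the real constants live in frame_handler.h, which is not part of this excerpt).

#include <cstddef>
#include <cstdint>

// Field widths mirror kFrameLengthFieldSize and kFrameMessageTypeFieldSize;
// the message-type value below is a placeholder, not taken from this patch.
constexpr size_t kLengthFieldSize = 4;
constexpr size_t kMessageTypeFieldSize = 4;
constexpr uint32_t kPlaceholderMessageType = 6;

// Builds the 8-byte frame header for a payload of |payload_size| bytes. The
// length field counts the message-type field plus the payload, not itself.
void write_frame_header(uint32_t payload_size, unsigned char* header) {
  const uint32_t length = payload_size + kMessageTypeFieldSize;
  for (size_t i = 0; i < kLengthFieldSize; ++i) {
    header[i] = static_cast<unsigned char>(length >> (8 * i));
  }
  for (size_t i = 0; i < kMessageTypeFieldSize; ++i) {
    header[kLengthFieldSize + i] =
        static_cast<unsigned char>(kPlaceholderMessageType >> (8 * i));
  }
}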
+ * + */ + +#include + +#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" + +#include + +#include "upb/upb.hpp" + +#include +#include +#include + +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/tsi/alts/handshaker/alts_shared_resource.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" + +#define TSI_ALTS_INITIAL_BUFFER_SIZE 256 + +const int kHandshakerClientOpNum = 4; + +struct alts_handshaker_client { + const alts_handshaker_client_vtable* vtable; +}; + +struct recv_message_result { + tsi_result status; + const unsigned char* bytes_to_send; + size_t bytes_to_send_size; + tsi_handshaker_result* result; +}; + +typedef struct alts_grpc_handshaker_client { + alts_handshaker_client base; + /* One ref is held by the entity that created this handshaker_client, and + * another ref is held by the pending RECEIVE_STATUS_ON_CLIENT op. */ + gpr_refcount refs; + alts_tsi_handshaker* handshaker; + grpc_call* call; + /* A pointer to a function handling the interaction with handshaker service. + * That is, it points to grpc_call_start_batch_and_execute when the handshaker + * client is used in a non-testing use case and points to a custom function + * that validates the data to be sent to handshaker service in a testing use + * case. */ + alts_grpc_caller grpc_caller; + /* A gRPC closure to be scheduled when the response from handshaker service + * is received. It will be initialized with the injected grpc RPC callback. */ + grpc_closure on_handshaker_service_resp_recv; + /* Buffers containing information to be sent (or received) to (or from) the + * handshaker service. */ + grpc_byte_buffer* send_buffer = nullptr; + grpc_byte_buffer* recv_buffer = nullptr; + grpc_status_code status = GRPC_STATUS_OK; + /* Initial metadata to be received from handshaker service. */ + grpc_metadata_array recv_initial_metadata; + /* A callback function provided by an application to be invoked when response + * is received from handshaker service. */ + tsi_handshaker_on_next_done_cb cb; + void* user_data; + /* ALTS credential options passed in from the caller. */ + grpc_alts_credentials_options* options; + /* target name information to be passed to handshaker service for server + * authorization check. */ + grpc_slice target_name; + /* boolean flag indicating if the handshaker client is used at client + * (is_client = true) or server (is_client = false) side. */ + bool is_client; + /* a temporary store for data received from handshaker service used to extract + * unused data. */ + grpc_slice recv_bytes; + /* a buffer containing data to be sent to the grpc client or server's peer. */ + unsigned char* buffer; + size_t buffer_size; + /** callback for receiving handshake call status */ + grpc_closure on_status_received; + /** gRPC status code of handshake call */ + grpc_status_code handshake_status_code = GRPC_STATUS_OK; + /** gRPC status details of handshake call */ + grpc_slice handshake_status_details; + /* mu synchronizes all fields below including their internal fields. */ + grpc_core::Mutex mu; + /* indicates if the handshaker call's RECV_STATUS_ON_CLIENT op is done. */ + bool receive_status_finished = false; + /* if non-null, contains arguments to complete a TSI next callback. */ + recv_message_result* pending_recv_message_result = nullptr; + /* Maximum frame size used by frame protector. 
*/ + size_t max_frame_size; +} alts_grpc_handshaker_client; + +static void handshaker_client_send_buffer_destroy( + alts_grpc_handshaker_client* client) { + GPR_ASSERT(client != nullptr); + grpc_byte_buffer_destroy(client->send_buffer); + client->send_buffer = nullptr; +} + +static bool is_handshake_finished_properly(grpc_gcp_HandshakerResp* resp) { + GPR_ASSERT(resp != nullptr); + return grpc_gcp_HandshakerResp_result(resp) != nullptr; +} + +static void alts_grpc_handshaker_client_unref( + alts_grpc_handshaker_client* client) { + if (gpr_unref(&client->refs)) { + if (client->base.vtable != nullptr && + client->base.vtable->destruct != nullptr) { + client->base.vtable->destruct(&client->base); + } + grpc_byte_buffer_destroy(client->send_buffer); + grpc_byte_buffer_destroy(client->recv_buffer); + client->send_buffer = nullptr; + client->recv_buffer = nullptr; + grpc_metadata_array_destroy(&client->recv_initial_metadata); + grpc_slice_unref_internal(client->recv_bytes); + grpc_slice_unref_internal(client->target_name); + grpc_alts_credentials_options_destroy(client->options); + gpr_free(client->buffer); + grpc_slice_unref_internal(client->handshake_status_details); + delete client; + } +} + +static void maybe_complete_tsi_next( + alts_grpc_handshaker_client* client, bool receive_status_finished, + recv_message_result* pending_recv_message_result) { + recv_message_result* r; + { + grpc_core::MutexLock lock(&client->mu); + client->receive_status_finished |= receive_status_finished; + if (pending_recv_message_result != nullptr) { + GPR_ASSERT(client->pending_recv_message_result == nullptr); + client->pending_recv_message_result = pending_recv_message_result; + } + if (client->pending_recv_message_result == nullptr) { + return; + } + const bool have_final_result = + client->pending_recv_message_result->result != nullptr || + client->pending_recv_message_result->status != TSI_OK; + if (have_final_result && !client->receive_status_finished) { + // If we've received the final message from the handshake + // server, or we're about to invoke the TSI next callback + // with a status other than TSI_OK (which terminates the + // handshake), then first wait for the RECV_STATUS op to complete. + return; + } + r = client->pending_recv_message_result; + client->pending_recv_message_result = nullptr; + } + client->cb(r->status, client->user_data, r->bytes_to_send, + r->bytes_to_send_size, r->result); + gpr_free(r); +} + +static void handle_response_done(alts_grpc_handshaker_client* client, + tsi_result status, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + recv_message_result* p = grpc_core::Zalloc(); + p->status = status; + p->bytes_to_send = bytes_to_send; + p->bytes_to_send_size = bytes_to_send_size; + p->result = result; + maybe_complete_tsi_next(client, false /* receive_status_finished */, + p /* pending_recv_message_result */); +} + +void alts_handshaker_client_handle_response(alts_handshaker_client* c, + bool is_ok) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + grpc_byte_buffer* recv_buffer = client->recv_buffer; + grpc_status_code status = client->status; + alts_tsi_handshaker* handshaker = client->handshaker; + /* Invalid input check. 
*/ + if (client->cb == nullptr) { + gpr_log(GPR_ERROR, + "client->cb is nullptr in alts_tsi_handshaker_handle_response()"); + return; + } + if (handshaker == nullptr) { + gpr_log(GPR_ERROR, + "handshaker is nullptr in alts_tsi_handshaker_handle_response()"); + handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr); + return; + } + /* TSI handshake has been shutdown. */ + if (alts_tsi_handshaker_has_shutdown(handshaker)) { + gpr_log(GPR_ERROR, "TSI handshake shutdown"); + handle_response_done(client, TSI_HANDSHAKE_SHUTDOWN, nullptr, 0, nullptr); + return; + } + /* Failed grpc call check. */ + if (!is_ok || status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "grpc call made to handshaker service failed"); + handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr); + return; + } + if (recv_buffer == nullptr) { + gpr_log(GPR_ERROR, + "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()"); + handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr); + return; + } + upb::Arena arena; + grpc_gcp_HandshakerResp* resp = + alts_tsi_utils_deserialize_response(recv_buffer, arena.ptr()); + grpc_byte_buffer_destroy(client->recv_buffer); + client->recv_buffer = nullptr; + /* Invalid handshaker response check. */ + if (resp == nullptr) { + gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed"); + handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr); + return; + } + const grpc_gcp_HandshakerStatus* resp_status = + grpc_gcp_HandshakerResp_status(resp); + if (resp_status == nullptr) { + gpr_log(GPR_ERROR, "No status in HandshakerResp"); + handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr); + return; + } + upb_strview out_frames = grpc_gcp_HandshakerResp_out_frames(resp); + unsigned char* bytes_to_send = nullptr; + size_t bytes_to_send_size = 0; + if (out_frames.size > 0) { + bytes_to_send_size = out_frames.size; + while (bytes_to_send_size > client->buffer_size) { + client->buffer_size *= 2; + client->buffer = static_cast( + gpr_realloc(client->buffer, client->buffer_size)); + } + memcpy(client->buffer, out_frames.data, bytes_to_send_size); + bytes_to_send = client->buffer; + } + tsi_handshaker_result* result = nullptr; + if (is_handshake_finished_properly(resp)) { + tsi_result status = + alts_tsi_handshaker_result_create(resp, client->is_client, &result); + if (status != TSI_OK) { + gpr_log(GPR_ERROR, "alts_tsi_handshaker_result_create() failed"); + handle_response_done(client, status, nullptr, 0, nullptr); + return; + } + alts_tsi_handshaker_result_set_unused_bytes( + result, &client->recv_bytes, + grpc_gcp_HandshakerResp_bytes_consumed(resp)); + } + grpc_status_code code = static_cast( + grpc_gcp_HandshakerStatus_code(resp_status)); + if (code != GRPC_STATUS_OK) { + upb_strview details = grpc_gcp_HandshakerStatus_details(resp_status); + if (details.size > 0) { + char* error_details = static_cast(gpr_zalloc(details.size + 1)); + memcpy(error_details, details.data, details.size); + gpr_log(GPR_ERROR, "Error from handshaker service:%s", error_details); + gpr_free(error_details); + } + } + // TODO(apolcyn): consider short ciruiting handle_response_done and + // invoking the TSI callback directly if we aren't done yet, if + // handle_response_done's allocation per message received causes + // a performance issue. 
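The loop above grows client->buffer geometrically until the out_frames payload fits, so the number of reallocations stays logarithmic in the largest response seen. Distilled into a standalone helper (the name and signature are illustrative; it assumes a non-zero starting capacity, 256 bytes in this patch):

#include <cstddef>

#include <grpc/support/alloc.h>

// Doubles |*capacity| until |needed| bytes fit, reallocating |*buf| once.
void ensure_capacity(unsigned char** buf, size_t* capacity, size_t needed) {
  size_t new_capacity = *capacity;
  while (needed > new_capacity) {
    new_capacity *= 2;
  }
  if (new_capacity != *capacity) {
    *buf = static_cast<unsigned char*>(gpr_realloc(*buf, new_capacity));
    *capacity = new_capacity;
  }
}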
+ handle_response_done(client, alts_tsi_utils_convert_to_tsi_result(code), + bytes_to_send, bytes_to_send_size, result); +} + +static tsi_result continue_make_grpc_call(alts_grpc_handshaker_client* client, + bool is_start) { + GPR_ASSERT(client != nullptr); + grpc_op ops[kHandshakerClientOpNum]; + memset(ops, 0, sizeof(ops)); + grpc_op* op = ops; + if (is_start) { + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = nullptr; + op->data.recv_status_on_client.status = &client->handshake_status_code; + op->data.recv_status_on_client.status_details = + &client->handshake_status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(op - ops <= kHandshakerClientOpNum); + gpr_ref(&client->refs); + grpc_call_error call_error = + client->grpc_caller(client->call, ops, static_cast(op - ops), + &client->on_status_received); + // TODO(apolcyn): return the error here instead, as done for other ops? + GPR_ASSERT(call_error == GRPC_CALL_OK); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + GPR_ASSERT(op - ops <= kHandshakerClientOpNum); + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &client->recv_initial_metadata; + op++; + GPR_ASSERT(op - ops <= kHandshakerClientOpNum); + } + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = client->send_buffer; + op++; + GPR_ASSERT(op - ops <= kHandshakerClientOpNum); + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &client->recv_buffer; + op++; + GPR_ASSERT(op - ops <= kHandshakerClientOpNum); + GPR_ASSERT(client->grpc_caller != nullptr); + if (client->grpc_caller(client->call, ops, static_cast(op - ops), + &client->on_handshaker_service_resp_recv) != + GRPC_CALL_OK) { + gpr_log(GPR_ERROR, "Start batch operation failed"); + return TSI_INTERNAL_ERROR; + } + return TSI_OK; +} + +// TODO(apolcyn): remove this global queue when we can safely rely +// on a MAX_CONCURRENT_STREAMS setting in the ALTS handshake server to +// limit the number of concurrent handshakes. +namespace { + +class HandshakeQueue { + public: + explicit HandshakeQueue(size_t max_outstanding_handshakes) + : max_outstanding_handshakes_(max_outstanding_handshakes) {} + + void RequestHandshake(alts_grpc_handshaker_client* client) { + { + grpc_core::MutexLock lock(&mu_); + if (outstanding_handshakes_ == max_outstanding_handshakes_) { + // Max number already running, add to queue. + queued_handshakes_.push_back(client); + return; + } + // Start the handshake immediately. + ++outstanding_handshakes_; + } + continue_make_grpc_call(client, true /* is_start */); + } + + void HandshakeDone() { + alts_grpc_handshaker_client* client = nullptr; + { + grpc_core::MutexLock lock(&mu_); + if (queued_handshakes_.empty()) { + // Nothing more in queue. Decrement count and return immediately. + --outstanding_handshakes_; + return; + } + // Remove next entry from queue and start the handshake. 
+ client = queued_handshakes_.front(); + queued_handshakes_.pop_front(); + } + continue_make_grpc_call(client, true /* is_start */); + } + + private: + grpc_core::Mutex mu_; + std::list queued_handshakes_; + size_t outstanding_handshakes_ = 0; + const size_t max_outstanding_handshakes_; +}; + +gpr_once g_queued_handshakes_init = GPR_ONCE_INIT; +/* Using separate queues for client and server handshakes is a + * hack that's mainly intended to satisfy the alts_concurrent_connectivity_test, + * which runs many concurrent handshakes where both endpoints + * are in the same process; this situation is problematic with a + * single queue because we have a high chance of using up all outstanding + * slots in the queue, such that there aren't any + * mutual client/server handshakes outstanding at the same time and + * able to make progress. */ +HandshakeQueue* g_client_handshake_queue; +HandshakeQueue* g_server_handshake_queue; + +void DoHandshakeQueuesInit(void) { + const size_t per_queue_max_outstanding_handshakes = 40; + g_client_handshake_queue = + new HandshakeQueue(per_queue_max_outstanding_handshakes); + g_server_handshake_queue = + new HandshakeQueue(per_queue_max_outstanding_handshakes); +} + +void RequestHandshake(alts_grpc_handshaker_client* client, bool is_client) { + gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit); + HandshakeQueue* queue = + is_client ? g_client_handshake_queue : g_server_handshake_queue; + queue->RequestHandshake(client); +} + +void HandshakeDone(bool is_client) { + HandshakeQueue* queue = + is_client ? g_client_handshake_queue : g_server_handshake_queue; + queue->HandshakeDone(); +} + +}; // namespace + +/** + * Populate grpc operation data with the fields of ALTS handshaker client and + * make a grpc call. + */ +static tsi_result make_grpc_call(alts_handshaker_client* c, bool is_start) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + if (is_start) { + RequestHandshake(client, client->is_client); + return TSI_OK; + } else { + return continue_make_grpc_call(client, is_start); + } +} + +static void on_status_received(void* arg, grpc_error_handle error) { + alts_grpc_handshaker_client* client = + static_cast(arg); + if (client->handshake_status_code != GRPC_STATUS_OK) { + // TODO(apolcyn): consider overriding the handshake result's + // status from the final ALTS message with the status here. + char* status_details = + grpc_slice_to_c_string(client->handshake_status_details); + gpr_log(GPR_INFO, + "alts_grpc_handshaker_client:%p on_status_received " + "status:%d details:|%s| error:|%s|", + client, client->handshake_status_code, status_details, + grpc_error_std_string(error).c_str()); + gpr_free(status_details); + } + maybe_complete_tsi_next(client, true /* receive_status_finished */, + nullptr /* pending_recv_message_result */); + HandshakeDone(client->is_client); + alts_grpc_handshaker_client_unref(client); +} + +/* Serializes a grpc_gcp_HandshakerReq message into a buffer and returns newly + * grpc_byte_buffer holding it. 
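The queueing above caps concurrent handshakes per direction and hands a freed slot to the next queued handshake when one finishes. A distilled, generic sketch of the same pattern, using standard-library primitives rather than the types in this patch:

#include <cstddef>
#include <functional>
#include <list>
#include <mutex>
#include <utility>

// At most |max_outstanding| callbacks run at once; every Start() must
// eventually be matched by exactly one Done(). Illustration only.
class BoundedStarter {
 public:
  explicit BoundedStarter(size_t max_outstanding)
      : max_outstanding_(max_outstanding) {}

  void Start(std::function<void()> start_fn) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (outstanding_ == max_outstanding_) {
        queued_.push_back(std::move(start_fn));  // Wait for a free slot.
        return;
      }
      ++outstanding_;
    }
    start_fn();
  }

  void Done() {
    std::function<void()> next;
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (queued_.empty()) {
        --outstanding_;  // Nothing queued: just release the slot.
        return;
      }
      next = std::move(queued_.front());
      queued_.pop_front();
    }
    next();  // Hand the freed slot to the next queued handshake.
  }

 private:
  std::mutex mu_;
  std::list<std::function<void()>> queued_;
  size_t outstanding_ = 0;
  const size_t max_outstanding_;
};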
*/ +static grpc_byte_buffer* get_serialized_handshaker_req( + grpc_gcp_HandshakerReq* req, upb_arena* arena) { + size_t buf_length; + char* buf = grpc_gcp_HandshakerReq_serialize(req, arena, &buf_length); + if (buf == nullptr) { + return nullptr; + } + grpc_slice slice = grpc_slice_from_copied_buffer(buf, buf_length); + grpc_byte_buffer* byte_buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref_internal(slice); + return byte_buffer; +} + +/* Create and populate a client_start handshaker request, then serialize it. */ +static grpc_byte_buffer* get_serialized_start_client( + alts_handshaker_client* c) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + upb::Arena arena; + grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr()); + grpc_gcp_StartClientHandshakeReq* start_client = + grpc_gcp_HandshakerReq_mutable_client_start(req, arena.ptr()); + grpc_gcp_StartClientHandshakeReq_set_handshake_security_protocol( + start_client, grpc_gcp_ALTS); + grpc_gcp_StartClientHandshakeReq_add_application_protocols( + start_client, upb_strview_makez(ALTS_APPLICATION_PROTOCOL), arena.ptr()); + grpc_gcp_StartClientHandshakeReq_add_record_protocols( + start_client, upb_strview_makez(ALTS_RECORD_PROTOCOL), arena.ptr()); + grpc_gcp_RpcProtocolVersions* client_version = + grpc_gcp_StartClientHandshakeReq_mutable_rpc_versions(start_client, + arena.ptr()); + grpc_gcp_RpcProtocolVersions_assign_from_struct( + client_version, arena.ptr(), &client->options->rpc_versions); + grpc_gcp_StartClientHandshakeReq_set_target_name( + start_client, + upb_strview_make(reinterpret_cast( + GRPC_SLICE_START_PTR(client->target_name)), + GRPC_SLICE_LENGTH(client->target_name))); + target_service_account* ptr = + (reinterpret_cast(client->options)) + ->target_account_list_head; + while (ptr != nullptr) { + grpc_gcp_Identity* target_identity = + grpc_gcp_StartClientHandshakeReq_add_target_identities(start_client, + arena.ptr()); + grpc_gcp_Identity_set_service_account(target_identity, + upb_strview_makez(ptr->data)); + ptr = ptr->next; + } + grpc_gcp_StartClientHandshakeReq_set_max_frame_size( + start_client, static_cast(client->max_frame_size)); + return get_serialized_handshaker_req(req, arena.ptr()); +} + +static tsi_result handshaker_client_start_client(alts_handshaker_client* c) { + if (c == nullptr) { + gpr_log(GPR_ERROR, "client is nullptr in handshaker_client_start_client()"); + return TSI_INVALID_ARGUMENT; + } + grpc_byte_buffer* buffer = get_serialized_start_client(c); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + if (buffer == nullptr) { + gpr_log(GPR_ERROR, "get_serialized_start_client() failed"); + return TSI_INTERNAL_ERROR; + } + handshaker_client_send_buffer_destroy(client); + client->send_buffer = buffer; + tsi_result result = make_grpc_call(&client->base, true /* is_start */); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "make_grpc_call() failed"); + } + return result; +} + +/* Create and populate a start_server handshaker request, then serialize it. 
*/ +static grpc_byte_buffer* get_serialized_start_server( + alts_handshaker_client* c, grpc_slice* bytes_received) { + GPR_ASSERT(c != nullptr); + GPR_ASSERT(bytes_received != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + + upb::Arena arena; + grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr()); + + grpc_gcp_StartServerHandshakeReq* start_server = + grpc_gcp_HandshakerReq_mutable_server_start(req, arena.ptr()); + grpc_gcp_StartServerHandshakeReq_add_application_protocols( + start_server, upb_strview_makez(ALTS_APPLICATION_PROTOCOL), arena.ptr()); + grpc_gcp_ServerHandshakeParameters* value = + grpc_gcp_ServerHandshakeParameters_new(arena.ptr()); + grpc_gcp_ServerHandshakeParameters_add_record_protocols( + value, upb_strview_makez(ALTS_RECORD_PROTOCOL), arena.ptr()); + grpc_gcp_StartServerHandshakeReq_handshake_parameters_set( + start_server, grpc_gcp_ALTS, value, arena.ptr()); + grpc_gcp_StartServerHandshakeReq_set_in_bytes( + start_server, upb_strview_make(reinterpret_cast( + GRPC_SLICE_START_PTR(*bytes_received)), + GRPC_SLICE_LENGTH(*bytes_received))); + grpc_gcp_RpcProtocolVersions* server_version = + grpc_gcp_StartServerHandshakeReq_mutable_rpc_versions(start_server, + arena.ptr()); + grpc_gcp_RpcProtocolVersions_assign_from_struct( + server_version, arena.ptr(), &client->options->rpc_versions); + grpc_gcp_StartServerHandshakeReq_set_max_frame_size( + start_server, static_cast(client->max_frame_size)); + return get_serialized_handshaker_req(req, arena.ptr()); +} + +static tsi_result handshaker_client_start_server(alts_handshaker_client* c, + grpc_slice* bytes_received) { + if (c == nullptr || bytes_received == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_server()"); + return TSI_INVALID_ARGUMENT; + } + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + grpc_byte_buffer* buffer = get_serialized_start_server(c, bytes_received); + if (buffer == nullptr) { + gpr_log(GPR_ERROR, "get_serialized_start_server() failed"); + return TSI_INTERNAL_ERROR; + } + handshaker_client_send_buffer_destroy(client); + client->send_buffer = buffer; + tsi_result result = make_grpc_call(&client->base, true /* is_start */); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "make_grpc_call() failed"); + } + return result; +} + +/* Create and populate a next handshaker request, then serialize it. 
*/ +static grpc_byte_buffer* get_serialized_next(grpc_slice* bytes_received) { + GPR_ASSERT(bytes_received != nullptr); + upb::Arena arena; + grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr()); + grpc_gcp_NextHandshakeMessageReq* next = + grpc_gcp_HandshakerReq_mutable_next(req, arena.ptr()); + grpc_gcp_NextHandshakeMessageReq_set_in_bytes( + next, upb_strview_make(reinterpret_cast GRPC_SLICE_START_PTR( + *bytes_received), + GRPC_SLICE_LENGTH(*bytes_received))); + return get_serialized_handshaker_req(req, arena.ptr()); +} + +static tsi_result handshaker_client_next(alts_handshaker_client* c, + grpc_slice* bytes_received) { + if (c == nullptr || bytes_received == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_next()"); + return TSI_INVALID_ARGUMENT; + } + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + grpc_slice_unref_internal(client->recv_bytes); + client->recv_bytes = grpc_slice_ref_internal(*bytes_received); + grpc_byte_buffer* buffer = get_serialized_next(bytes_received); + if (buffer == nullptr) { + gpr_log(GPR_ERROR, "get_serialized_next() failed"); + return TSI_INTERNAL_ERROR; + } + handshaker_client_send_buffer_destroy(client); + client->send_buffer = buffer; + tsi_result result = make_grpc_call(&client->base, false /* is_start */); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "make_grpc_call() failed"); + } + return result; +} + +static void handshaker_client_shutdown(alts_handshaker_client* c) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + if (client->call != nullptr) { + grpc_call_cancel_internal(client->call); + } +} + +static void handshaker_call_unref(void* arg, grpc_error_handle /* error */) { + grpc_call* call = static_cast(arg); + grpc_call_unref(call); +} + +static void handshaker_client_destruct(alts_handshaker_client* c) { + if (c == nullptr) { + return; + } + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + if (client->call != nullptr) { + // Throw this grpc_call_unref over to the ExecCtx so that + // we invoke it at the bottom of the call stack and + // prevent lock inversion problems due to nested ExecCtx flushing. + // TODO(apolcyn): we could remove this indirection and call + // grpc_call_unref inline if there was an internal variant of + // grpc_call_unref that didn't need to flush an ExecCtx. + if (grpc_core::ExecCtx::Get() == nullptr) { + // Unref handshaker call if there is no exec_ctx, e.g., in the case of + // Envoy ALTS transport socket. + grpc_call_unref(client->call); + } else { + // Using existing exec_ctx to unref handshaker call. 
+ grpc_core::ExecCtx::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_CREATE(handshaker_call_unref, client->call, + grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); + } + } +} + +static const alts_handshaker_client_vtable vtable = { + handshaker_client_start_client, handshaker_client_start_server, + handshaker_client_next, handshaker_client_shutdown, + handshaker_client_destruct}; + +alts_handshaker_client* alts_grpc_handshaker_client_create( + alts_tsi_handshaker* handshaker, grpc_channel* channel, + const char* handshaker_service_url, grpc_pollset_set* interested_parties, + grpc_alts_credentials_options* options, const grpc_slice& target_name, + grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb, + void* user_data, alts_handshaker_client_vtable* vtable_for_testing, + bool is_client, size_t max_frame_size) { + if (channel == nullptr || handshaker_service_url == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()"); + return nullptr; + } + alts_grpc_handshaker_client* client = new alts_grpc_handshaker_client(); + memset(&client->base, 0, sizeof(client->base)); + client->base.vtable = + vtable_for_testing == nullptr ? &vtable : vtable_for_testing; + gpr_ref_init(&client->refs, 1); + client->handshaker = handshaker; + client->grpc_caller = grpc_call_start_batch_and_execute; + grpc_metadata_array_init(&client->recv_initial_metadata); + client->cb = cb; + client->user_data = user_data; + client->options = grpc_alts_credentials_options_copy(options); + client->target_name = grpc_slice_copy(target_name); + client->is_client = is_client; + client->recv_bytes = grpc_empty_slice(); + client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE; + client->buffer = static_cast(gpr_zalloc(client->buffer_size)); + client->handshake_status_details = grpc_empty_slice(); + client->max_frame_size = max_frame_size; + grpc_slice slice = grpc_slice_from_copied_string(handshaker_service_url); + client->call = + strcmp(handshaker_service_url, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING) == + 0 + ? 
nullptr + : grpc_channel_create_pollset_set_call( + channel, nullptr, GRPC_PROPAGATE_DEFAULTS, interested_parties, + grpc_slice_from_static_string(ALTS_SERVICE_METHOD), &slice, + GRPC_MILLIS_INF_FUTURE, nullptr); + GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, grpc_cb, client, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&client->on_status_received, on_status_received, client, + grpc_schedule_on_exec_ctx); + grpc_slice_unref_internal(slice); + return &client->base; +} + +namespace grpc_core { +namespace internal { + +void alts_handshaker_client_set_grpc_caller_for_testing( + alts_handshaker_client* c, alts_grpc_caller caller) { + GPR_ASSERT(c != nullptr && caller != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + client->grpc_caller = caller; +} + +grpc_byte_buffer* alts_handshaker_client_get_send_buffer_for_testing( + alts_handshaker_client* c) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + return client->send_buffer; +} + +grpc_byte_buffer** alts_handshaker_client_get_recv_buffer_addr_for_testing( + alts_handshaker_client* c) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + return &client->recv_buffer; +} + +grpc_metadata_array* alts_handshaker_client_get_initial_metadata_for_testing( + alts_handshaker_client* c) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + return &client->recv_initial_metadata; +} + +void alts_handshaker_client_set_recv_bytes_for_testing( + alts_handshaker_client* c, grpc_slice* recv_bytes) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + client->recv_bytes = grpc_slice_ref_internal(*recv_bytes); +} + +void alts_handshaker_client_set_fields_for_testing( + alts_handshaker_client* c, alts_tsi_handshaker* handshaker, + tsi_handshaker_on_next_done_cb cb, void* user_data, + grpc_byte_buffer* recv_buffer, grpc_status_code status) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + client->handshaker = handshaker; + client->cb = cb; + client->user_data = user_data; + client->recv_buffer = recv_buffer; + client->status = status; +} + +void alts_handshaker_client_check_fields_for_testing( + alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb, + void* user_data, bool has_sent_start_message, grpc_slice* recv_bytes) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + GPR_ASSERT(client->cb == cb); + GPR_ASSERT(client->user_data == user_data); + if (recv_bytes != nullptr) { + GPR_ASSERT(grpc_slice_cmp(client->recv_bytes, *recv_bytes) == 0); + } + GPR_ASSERT(alts_tsi_handshaker_get_has_sent_start_message_for_testing( + client->handshaker) == has_sent_start_message); +} + +void alts_handshaker_client_set_vtable_for_testing( + alts_handshaker_client* c, alts_handshaker_client_vtable* vtable) { + GPR_ASSERT(c != nullptr); + GPR_ASSERT(vtable != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + client->base.vtable = vtable; +} + +alts_tsi_handshaker* alts_handshaker_client_get_handshaker_for_testing( + alts_handshaker_client* c) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + return client->handshaker; +} + +void alts_handshaker_client_set_cb_for_testing( + alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + 
reinterpret_cast(c); + client->cb = cb; +} + +grpc_closure* alts_handshaker_client_get_closure_for_testing( + alts_handshaker_client* c) { + GPR_ASSERT(c != nullptr); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + return &client->on_handshaker_service_resp_recv; +} + +void alts_handshaker_client_ref_for_testing(alts_handshaker_client* c) { + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + gpr_ref(&client->refs); +} + +void alts_handshaker_client_on_status_received_for_testing( + alts_handshaker_client* c, grpc_status_code status, + grpc_error_handle error) { + // We first make sure that the handshake queue has been initialized + // here because there are tests that use this API that mock out + // other parts of the alts_handshaker_client in such a way that the + // code path that would normally ensure that the handshake queue + // has been initialized isn't taken. + gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit); + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + client->handshake_status_code = status; + client->handshake_status_details = grpc_empty_slice(); + grpc_core::Closure::Run(DEBUG_LOCATION, &client->on_status_received, error); +} + +} // namespace internal +} // namespace grpc_core + +tsi_result alts_handshaker_client_start_client(alts_handshaker_client* client) { + if (client != nullptr && client->vtable != nullptr && + client->vtable->client_start != nullptr) { + return client->vtable->client_start(client); + } + gpr_log(GPR_ERROR, + "client or client->vtable has not been initialized properly"); + return TSI_INVALID_ARGUMENT; +} + +tsi_result alts_handshaker_client_start_server(alts_handshaker_client* client, + grpc_slice* bytes_received) { + if (client != nullptr && client->vtable != nullptr && + client->vtable->server_start != nullptr) { + return client->vtable->server_start(client, bytes_received); + } + gpr_log(GPR_ERROR, + "client or client->vtable has not been initialized properly"); + return TSI_INVALID_ARGUMENT; +} + +tsi_result alts_handshaker_client_next(alts_handshaker_client* client, + grpc_slice* bytes_received) { + if (client != nullptr && client->vtable != nullptr && + client->vtable->next != nullptr) { + return client->vtable->next(client, bytes_received); + } + gpr_log(GPR_ERROR, + "client or client->vtable has not been initialized properly"); + return TSI_INVALID_ARGUMENT; +} + +void alts_handshaker_client_shutdown(alts_handshaker_client* client) { + if (client != nullptr && client->vtable != nullptr && + client->vtable->shutdown != nullptr) { + client->vtable->shutdown(client); + } +} + +void alts_handshaker_client_destroy(alts_handshaker_client* c) { + if (c != nullptr) { + alts_grpc_handshaker_client* client = + reinterpret_cast(c); + alts_grpc_handshaker_client_unref(client); + } +} diff --git a/src/core/tsi/alts/handshaker/alts_shared_resource.cc b/src/core/tsi/alts/handshaker/alts_shared_resource.cc new file mode 100644 index 00000000..255eda5d --- /dev/null +++ b/src/core/tsi/alts/handshaker/alts_shared_resource.cc @@ -0,0 +1,83 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/handshaker/alts_shared_resource.h" + +#include + +#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" + +static alts_shared_resource_dedicated g_alts_resource_dedicated; + +alts_shared_resource_dedicated* grpc_alts_get_shared_resource_dedicated(void) { + return &g_alts_resource_dedicated; +} + +static void thread_worker(void* /*arg*/) { + while (true) { + grpc_event event = + grpc_completion_queue_next(g_alts_resource_dedicated.cq, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type != GRPC_QUEUE_TIMEOUT); + if (event.type == GRPC_QUEUE_SHUTDOWN) { + break; + } + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + alts_handshaker_client* client = + static_cast(event.tag); + alts_handshaker_client_handle_response(client, event.success); + } +} + +void grpc_alts_shared_resource_dedicated_init() { + g_alts_resource_dedicated.cq = nullptr; + gpr_mu_init(&g_alts_resource_dedicated.mu); +} + +void grpc_alts_shared_resource_dedicated_start( + const char* handshaker_service_url) { + gpr_mu_lock(&g_alts_resource_dedicated.mu); + if (g_alts_resource_dedicated.cq == nullptr) { + g_alts_resource_dedicated.channel = + grpc_insecure_channel_create(handshaker_service_url, nullptr, nullptr); + g_alts_resource_dedicated.cq = + grpc_completion_queue_create_for_next(nullptr); + g_alts_resource_dedicated.thread = + grpc_core::Thread("alts_tsi_handshaker", &thread_worker, nullptr); + g_alts_resource_dedicated.interested_parties = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(g_alts_resource_dedicated.interested_parties, + grpc_cq_pollset(g_alts_resource_dedicated.cq)); + g_alts_resource_dedicated.thread.Start(); + } + gpr_mu_unlock(&g_alts_resource_dedicated.mu); +} + +void grpc_alts_shared_resource_dedicated_shutdown() { + if (g_alts_resource_dedicated.cq != nullptr) { + grpc_pollset_set_del_pollset(g_alts_resource_dedicated.interested_parties, + grpc_cq_pollset(g_alts_resource_dedicated.cq)); + grpc_completion_queue_shutdown(g_alts_resource_dedicated.cq); + g_alts_resource_dedicated.thread.Join(); + grpc_pollset_set_destroy(g_alts_resource_dedicated.interested_parties); + grpc_completion_queue_destroy(g_alts_resource_dedicated.cq); + grpc_channel_destroy(g_alts_resource_dedicated.channel); + } + gpr_mu_destroy(&g_alts_resource_dedicated.mu); +} diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc new file mode 100644 index 00000000..aac64452 --- /dev/null +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -0,0 +1,705 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
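The dedicated-resource functions above are meant to be used in a strict init/start/shutdown order. A rough sketch of that ordering follows, with the caveat that in gRPC these calls are made from the library's own initialization and credential paths rather than by application code:

#include "src/core/tsi/alts/handshaker/alts_shared_resource.h"

// Illustrative ordering only.
void dedicated_resource_lifecycle(const char* handshaker_service_url) {
  grpc_alts_shared_resource_dedicated_init();
  // Lazily spins up the channel, completion queue and worker thread the
  // first time it is called; later calls are no-ops under the mutex.
  grpc_alts_shared_resource_dedicated_start(handshaker_service_url);
  // ... handshakes run against the dedicated completion queue ...
  grpc_alts_shared_resource_dedicated_shutdown();
}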
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" + +#include +#include +#include + +#include "upb/upb.hpp" + +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/tsi/alts/frame_protector/alts_frame_protector.h" +#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" +#include "src/core/tsi/alts/handshaker/alts_shared_resource.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h" + +/* Main struct for ALTS TSI handshaker. */ +struct alts_tsi_handshaker { + tsi_handshaker base; + grpc_slice target_name; + bool is_client; + bool has_sent_start_message = false; + bool has_created_handshaker_client = false; + char* handshaker_service_url; + grpc_pollset_set* interested_parties; + grpc_alts_credentials_options* options; + alts_handshaker_client_vtable* client_vtable_for_testing = nullptr; + grpc_channel* channel = nullptr; + bool use_dedicated_cq; + // mu synchronizes all fields below. Note these are the + // only fields that can be concurrently accessed (due to + // potential concurrency of tsi_handshaker_shutdown and + // tsi_handshaker_next). + grpc_core::Mutex mu; + alts_handshaker_client* client = nullptr; + // shutdown effectively follows base.handshake_shutdown, + // but is synchronized by the mutex of this object. + bool shutdown = false; + // Maximum frame size used by frame protector. + size_t max_frame_size; +}; + +/* Main struct for ALTS TSI handshaker result. */ +typedef struct alts_tsi_handshaker_result { + tsi_handshaker_result base; + char* peer_identity; + char* key_data; + unsigned char* unused_bytes; + size_t unused_bytes_size; + grpc_slice rpc_versions; + bool is_client; + grpc_slice serialized_context; + // Peer's maximum frame size. 
+ size_t max_frame_size; +} alts_tsi_handshaker_result; + +static tsi_result handshaker_result_extract_peer( + const tsi_handshaker_result* self, tsi_peer* peer) { + if (self == nullptr || peer == nullptr) { + gpr_log(GPR_ERROR, "Invalid argument to handshaker_result_extract_peer()"); + return TSI_INVALID_ARGUMENT; + } + alts_tsi_handshaker_result* result = + reinterpret_cast( + const_cast(self)); + GPR_ASSERT(kTsiAltsNumOfPeerProperties == 5); + tsi_result ok = tsi_construct_peer(kTsiAltsNumOfPeerProperties, peer); + int index = 0; + if (ok != TSI_OK) { + gpr_log(GPR_ERROR, "Failed to construct tsi peer"); + return ok; + } + GPR_ASSERT(&peer->properties[index] != nullptr); + ok = tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, + &peer->properties[index]); + if (ok != TSI_OK) { + tsi_peer_destruct(peer); + gpr_log(GPR_ERROR, "Failed to set tsi peer property"); + return ok; + } + index++; + GPR_ASSERT(&peer->properties[index] != nullptr); + ok = tsi_construct_string_peer_property_from_cstring( + TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, result->peer_identity, + &peer->properties[index]); + if (ok != TSI_OK) { + tsi_peer_destruct(peer); + gpr_log(GPR_ERROR, "Failed to set tsi peer property"); + } + index++; + GPR_ASSERT(&peer->properties[index] != nullptr); + ok = tsi_construct_string_peer_property( + TSI_ALTS_RPC_VERSIONS, + reinterpret_cast(GRPC_SLICE_START_PTR(result->rpc_versions)), + GRPC_SLICE_LENGTH(result->rpc_versions), &peer->properties[index]); + if (ok != TSI_OK) { + tsi_peer_destruct(peer); + gpr_log(GPR_ERROR, "Failed to set tsi peer property"); + } + index++; + GPR_ASSERT(&peer->properties[index] != nullptr); + ok = tsi_construct_string_peer_property( + TSI_ALTS_CONTEXT, + reinterpret_cast(GRPC_SLICE_START_PTR(result->serialized_context)), + GRPC_SLICE_LENGTH(result->serialized_context), &peer->properties[index]); + if (ok != TSI_OK) { + tsi_peer_destruct(peer); + gpr_log(GPR_ERROR, "Failed to set tsi peer property"); + } + index++; + GPR_ASSERT(&peer->properties[index] != nullptr); + ok = tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer->properties[index]); + if (ok != TSI_OK) { + tsi_peer_destruct(peer); + gpr_log(GPR_ERROR, "Failed to set tsi peer property"); + } + GPR_ASSERT(++index == kTsiAltsNumOfPeerProperties); + return ok; +} + +static tsi_result handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY; + return TSI_OK; +} + +static tsi_result handshaker_result_create_zero_copy_grpc_protector( + const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, + tsi_zero_copy_grpc_protector** protector) { + if (self == nullptr || protector == nullptr) { + gpr_log(GPR_ERROR, + "Invalid arguments to create_zero_copy_grpc_protector()"); + return TSI_INVALID_ARGUMENT; + } + alts_tsi_handshaker_result* result = + reinterpret_cast( + const_cast(self)); + + // In case the peer does not send max frame size (e.g. peer is gRPC Go or + // peer uses an old binary), the negotiated frame size is set to + // kTsiAltsMinFrameSize (ignoring max_output_protected_frame_size value if + // present). Otherwise, it is based on peer and user specified max frame + // size (if present). 
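  // Illustrative sketch of the negotiation below (not part of the patch),
  // where peer_max stands for result->max_frame_size, user_cap for the
  // caller-supplied max_output_protected_frame_size, and the usual constants
  // kTsiAltsMinFrameSize = 16 KiB and kTsiAltsMaxFrameSize = 128 KiB from
  // alts_tsi_handshaker.h are assumed:
  //
  //   size_t negotiated = kTsiAltsMinFrameSize;
  //   if (peer_max != 0) {
  //     negotiated = std::min(peer_max, user_cap != nullptr ? *user_cap
  //                                                         : kTsiAltsMaxFrameSize);
  //     negotiated = std::max(negotiated, kTsiAltsMinFrameSize);
  //   }
  //
  // For example, a peer maximum of 1 MiB with a user cap of 64 KiB yields
  // 64 KiB, while a peer maximum of 4 KiB is clamped back up to 16 KiB.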
+ size_t max_frame_size = kTsiAltsMinFrameSize; + if (result->max_frame_size) { + size_t peer_max_frame_size = result->max_frame_size; + max_frame_size = std::min(peer_max_frame_size, + max_output_protected_frame_size == nullptr + ? kTsiAltsMaxFrameSize + : *max_output_protected_frame_size); + max_frame_size = std::max(max_frame_size, kTsiAltsMinFrameSize); + } + max_output_protected_frame_size = &max_frame_size; + gpr_log(GPR_DEBUG, + "After Frame Size Negotiation, maximum frame size used by frame " + "protector equals %zu", + *max_output_protected_frame_size); + tsi_result ok = alts_zero_copy_grpc_protector_create( + reinterpret_cast(result->key_data), + kAltsAes128GcmRekeyKeyLength, /*is_rekey=*/true, result->is_client, + /*is_integrity_only=*/false, /*enable_extra_copy=*/false, + max_output_protected_frame_size, protector); + if (ok != TSI_OK) { + gpr_log(GPR_ERROR, "Failed to create zero-copy grpc protector"); + } + return ok; +} + +static tsi_result handshaker_result_create_frame_protector( + const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, + tsi_frame_protector** protector) { + if (self == nullptr || protector == nullptr) { + gpr_log(GPR_ERROR, + "Invalid arguments to handshaker_result_create_frame_protector()"); + return TSI_INVALID_ARGUMENT; + } + alts_tsi_handshaker_result* result = + reinterpret_cast( + const_cast(self)); + tsi_result ok = alts_create_frame_protector( + reinterpret_cast(result->key_data), + kAltsAes128GcmRekeyKeyLength, result->is_client, /*is_rekey=*/true, + max_output_protected_frame_size, protector); + if (ok != TSI_OK) { + gpr_log(GPR_ERROR, "Failed to create frame protector"); + } + return ok; +} + +static tsi_result handshaker_result_get_unused_bytes( + const tsi_handshaker_result* self, const unsigned char** bytes, + size_t* bytes_size) { + if (self == nullptr || bytes == nullptr || bytes_size == nullptr) { + gpr_log(GPR_ERROR, + "Invalid arguments to handshaker_result_get_unused_bytes()"); + return TSI_INVALID_ARGUMENT; + } + alts_tsi_handshaker_result* result = + reinterpret_cast( + const_cast(self)); + *bytes = result->unused_bytes; + *bytes_size = result->unused_bytes_size; + return TSI_OK; +} + +static void handshaker_result_destroy(tsi_handshaker_result* self) { + if (self == nullptr) { + return; + } + alts_tsi_handshaker_result* result = + reinterpret_cast( + const_cast(self)); + gpr_free(result->peer_identity); + gpr_free(result->key_data); + gpr_free(result->unused_bytes); + grpc_slice_unref_internal(result->rpc_versions); + grpc_slice_unref_internal(result->serialized_context); + gpr_free(result); +} + +static const tsi_handshaker_result_vtable result_vtable = { + handshaker_result_extract_peer, + handshaker_result_get_frame_protector_type, + handshaker_result_create_zero_copy_grpc_protector, + handshaker_result_create_frame_protector, + handshaker_result_get_unused_bytes, + handshaker_result_destroy}; + +tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp, + bool is_client, + tsi_handshaker_result** result) { + if (result == nullptr || resp == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()"); + return TSI_INVALID_ARGUMENT; + } + const grpc_gcp_HandshakerResult* hresult = + grpc_gcp_HandshakerResp_result(resp); + const grpc_gcp_Identity* identity = + grpc_gcp_HandshakerResult_peer_identity(hresult); + if (identity == nullptr) { + gpr_log(GPR_ERROR, "Invalid identity"); + return TSI_FAILED_PRECONDITION; + } + upb_strview peer_service_account = + 
grpc_gcp_Identity_service_account(identity); + if (peer_service_account.size == 0) { + gpr_log(GPR_ERROR, "Invalid peer service account"); + return TSI_FAILED_PRECONDITION; + } + upb_strview key_data = grpc_gcp_HandshakerResult_key_data(hresult); + if (key_data.size < kAltsAes128GcmRekeyKeyLength) { + gpr_log(GPR_ERROR, "Bad key length"); + return TSI_FAILED_PRECONDITION; + } + const grpc_gcp_RpcProtocolVersions* peer_rpc_version = + grpc_gcp_HandshakerResult_peer_rpc_versions(hresult); + if (peer_rpc_version == nullptr) { + gpr_log(GPR_ERROR, "Peer does not set RPC protocol versions."); + return TSI_FAILED_PRECONDITION; + } + upb_strview application_protocol = + grpc_gcp_HandshakerResult_application_protocol(hresult); + if (application_protocol.size == 0) { + gpr_log(GPR_ERROR, "Invalid application protocol"); + return TSI_FAILED_PRECONDITION; + } + upb_strview record_protocol = + grpc_gcp_HandshakerResult_record_protocol(hresult); + if (record_protocol.size == 0) { + gpr_log(GPR_ERROR, "Invalid record protocol"); + return TSI_FAILED_PRECONDITION; + } + const grpc_gcp_Identity* local_identity = + grpc_gcp_HandshakerResult_local_identity(hresult); + if (local_identity == nullptr) { + gpr_log(GPR_ERROR, "Invalid local identity"); + return TSI_FAILED_PRECONDITION; + } + upb_strview local_service_account = + grpc_gcp_Identity_service_account(local_identity); + // We don't check if local service account is empty here + // because local identity could be empty in certain situations. + alts_tsi_handshaker_result* sresult = + grpc_core::Zalloc(); + sresult->key_data = + static_cast(gpr_zalloc(kAltsAes128GcmRekeyKeyLength)); + memcpy(sresult->key_data, key_data.data, kAltsAes128GcmRekeyKeyLength); + sresult->peer_identity = + static_cast(gpr_zalloc(peer_service_account.size + 1)); + memcpy(sresult->peer_identity, peer_service_account.data, + peer_service_account.size); + sresult->max_frame_size = grpc_gcp_HandshakerResult_max_frame_size(hresult); + upb::Arena rpc_versions_arena; + bool serialized = grpc_gcp_rpc_protocol_versions_encode( + peer_rpc_version, rpc_versions_arena.ptr(), &sresult->rpc_versions); + if (!serialized) { + gpr_log(GPR_ERROR, "Failed to serialize peer's RPC protocol versions."); + return TSI_FAILED_PRECONDITION; + } + upb::Arena context_arena; + grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr()); + grpc_gcp_AltsContext_set_application_protocol(context, application_protocol); + grpc_gcp_AltsContext_set_record_protocol(context, record_protocol); + // ALTS currently only supports the security level of 2, + // which is "grpc_gcp_INTEGRITY_AND_PRIVACY". 
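  // (For reference: the grpc.gcp.SecurityLevel enum in
  // transport_security_common.proto defines SECURITY_NONE = 0,
  // INTEGRITY_ONLY = 1 and INTEGRITY_AND_PRIVACY = 2, hence the literal 2
  // below.)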
+ grpc_gcp_AltsContext_set_security_level(context, 2); + grpc_gcp_AltsContext_set_peer_service_account(context, peer_service_account); + grpc_gcp_AltsContext_set_local_service_account(context, + local_service_account); + grpc_gcp_AltsContext_set_peer_rpc_versions( + context, const_cast(peer_rpc_version)); + grpc_gcp_Identity* peer_identity = const_cast(identity); + if (peer_identity == nullptr) { + gpr_log(GPR_ERROR, "Null peer identity in ALTS context."); + return TSI_FAILED_PRECONDITION; + } + if (grpc_gcp_Identity_has_attributes(identity)) { + size_t iter = UPB_MAP_BEGIN; + grpc_gcp_Identity_AttributesEntry* peer_attributes_entry = + grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter); + while (peer_attributes_entry != nullptr) { + upb_strview key = grpc_gcp_Identity_AttributesEntry_key( + const_cast( + peer_attributes_entry)); + upb_strview val = grpc_gcp_Identity_AttributesEntry_value( + const_cast( + peer_attributes_entry)); + grpc_gcp_AltsContext_peer_attributes_set(context, key, val, + context_arena.ptr()); + peer_attributes_entry = + grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter); + } + } + size_t serialized_ctx_length; + char* serialized_ctx = grpc_gcp_AltsContext_serialize( + context, context_arena.ptr(), &serialized_ctx_length); + if (serialized_ctx == nullptr) { + gpr_log(GPR_ERROR, "Failed to serialize peer's ALTS context."); + return TSI_FAILED_PRECONDITION; + } + sresult->serialized_context = + grpc_slice_from_copied_buffer(serialized_ctx, serialized_ctx_length); + sresult->is_client = is_client; + sresult->base.vtable = &result_vtable; + *result = &sresult->base; + return TSI_OK; +} + +/* gRPC provided callback used when gRPC thread model is applied. */ +static void on_handshaker_service_resp_recv(void* arg, + grpc_error_handle error) { + alts_handshaker_client* client = static_cast(arg); + if (client == nullptr) { + gpr_log(GPR_ERROR, "ALTS handshaker client is nullptr"); + return; + } + bool success = true; + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, + "ALTS handshaker on_handshaker_service_resp_recv error: %s", + grpc_error_std_string(error).c_str()); + success = false; + } + alts_handshaker_client_handle_response(client, success); +} + +/* gRPC provided callback used when dedicatd CQ and thread are used. + * It serves to safely bring the control back to application. */ +static void on_handshaker_service_resp_recv_dedicated( + void* arg, grpc_error_handle /*error*/) { + alts_shared_resource_dedicated* resource = + grpc_alts_get_shared_resource_dedicated(); + grpc_cq_end_op( + resource->cq, arg, GRPC_ERROR_NONE, + [](void* /*done_arg*/, grpc_cq_completion* /*storage*/) {}, nullptr, + &resource->storage); +} + +/* Returns TSI_OK if and only if no error is encountered. */ +static tsi_result alts_tsi_handshaker_continue_handshaker_next( + alts_tsi_handshaker* handshaker, const unsigned char* received_bytes, + size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb, + void* user_data) { + if (!handshaker->has_created_handshaker_client) { + if (handshaker->channel == nullptr) { + grpc_alts_shared_resource_dedicated_start( + handshaker->handshaker_service_url); + handshaker->interested_parties = + grpc_alts_get_shared_resource_dedicated()->interested_parties; + GPR_ASSERT(handshaker->interested_parties != nullptr); + } + grpc_iomgr_cb_func grpc_cb = handshaker->channel == nullptr + ? on_handshaker_service_resp_recv_dedicated + : on_handshaker_service_resp_recv; + grpc_channel* channel = + handshaker->channel == nullptr + ? 
grpc_alts_get_shared_resource_dedicated()->channel + : handshaker->channel; + alts_handshaker_client* client = alts_grpc_handshaker_client_create( + handshaker, channel, handshaker->handshaker_service_url, + handshaker->interested_parties, handshaker->options, + handshaker->target_name, grpc_cb, cb, user_data, + handshaker->client_vtable_for_testing, handshaker->is_client, + handshaker->max_frame_size); + if (client == nullptr) { + gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); + return TSI_FAILED_PRECONDITION; + } + { + grpc_core::MutexLock lock(&handshaker->mu); + GPR_ASSERT(handshaker->client == nullptr); + handshaker->client = client; + if (handshaker->shutdown) { + gpr_log(GPR_ERROR, "TSI handshake shutdown"); + return TSI_HANDSHAKE_SHUTDOWN; + } + } + handshaker->has_created_handshaker_client = true; + } + if (handshaker->channel == nullptr && + handshaker->client_vtable_for_testing == nullptr) { + GPR_ASSERT(grpc_cq_begin_op(grpc_alts_get_shared_resource_dedicated()->cq, + handshaker->client)); + } + grpc_slice slice = (received_bytes == nullptr || received_bytes_size == 0) + ? grpc_empty_slice() + : grpc_slice_from_copied_buffer( + reinterpret_cast(received_bytes), + received_bytes_size); + tsi_result ok = TSI_OK; + if (!handshaker->has_sent_start_message) { + handshaker->has_sent_start_message = true; + ok = handshaker->is_client + ? alts_handshaker_client_start_client(handshaker->client) + : alts_handshaker_client_start_server(handshaker->client, &slice); + // It's unsafe for the current thread to access any state in handshaker + // at this point, since alts_handshaker_client_start_client/server + // have potentially just started an op batch on the handshake call. + // The completion callback for that batch is unsynchronized and so + // can invoke the TSI next API callback from any thread, at which point + // there is nothing taking ownership of this handshaker to prevent it + // from being destroyed. 
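    // Subsequent passes take the else branch below: bytes received from the
    // peer are forwarded to the handshaker service via
    // alts_handshaker_client_next(), and the same ownership caveat presumably
    // applies once that op batch is in flight.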
+ } else { + ok = alts_handshaker_client_next(handshaker->client, &slice); + } + grpc_slice_unref_internal(slice); + return ok; +} + +struct alts_tsi_handshaker_continue_handshaker_next_args { + alts_tsi_handshaker* handshaker; + std::unique_ptr received_bytes; + size_t received_bytes_size; + tsi_handshaker_on_next_done_cb cb; + void* user_data; + grpc_closure closure; +}; + +static void alts_tsi_handshaker_create_channel( + void* arg, grpc_error_handle /* unused_error */) { + alts_tsi_handshaker_continue_handshaker_next_args* next_args = + static_cast(arg); + alts_tsi_handshaker* handshaker = next_args->handshaker; + GPR_ASSERT(handshaker->channel == nullptr); + handshaker->channel = grpc_insecure_channel_create( + next_args->handshaker->handshaker_service_url, nullptr, nullptr); + tsi_result continue_next_result = + alts_tsi_handshaker_continue_handshaker_next( + handshaker, next_args->received_bytes.get(), + next_args->received_bytes_size, next_args->cb, next_args->user_data); + if (continue_next_result != TSI_OK) { + next_args->cb(continue_next_result, next_args->user_data, nullptr, 0, + nullptr); + } + delete next_args; +} + +static tsi_result handshaker_next( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, + size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/, + tsi_handshaker_on_next_done_cb cb, void* user_data) { + if (self == nullptr || cb == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); + return TSI_INVALID_ARGUMENT; + } + alts_tsi_handshaker* handshaker = + reinterpret_cast(self); + { + grpc_core::MutexLock lock(&handshaker->mu); + if (handshaker->shutdown) { + gpr_log(GPR_ERROR, "TSI handshake shutdown"); + return TSI_HANDSHAKE_SHUTDOWN; + } + } + if (handshaker->channel == nullptr && !handshaker->use_dedicated_cq) { + alts_tsi_handshaker_continue_handshaker_next_args* args = + new alts_tsi_handshaker_continue_handshaker_next_args(); + args->handshaker = handshaker; + args->received_bytes = nullptr; + args->received_bytes_size = received_bytes_size; + if (received_bytes_size > 0) { + args->received_bytes = std::unique_ptr( + static_cast(gpr_zalloc(received_bytes_size))); + memcpy(args->received_bytes.get(), received_bytes, received_bytes_size); + } + args->cb = cb; + args->user_data = user_data; + GRPC_CLOSURE_INIT(&args->closure, alts_tsi_handshaker_create_channel, args, + grpc_schedule_on_exec_ctx); + // We continue this handshaker_next call at the bottom of the ExecCtx just + // so that we can invoke grpc_channel_create at the bottom of the call + // stack. Doing so avoids potential lock cycles between g_init_mu and other + // mutexes within core that might be held on the current call stack + // (note that g_init_mu gets acquired during channel creation). + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &args->closure, GRPC_ERROR_NONE); + } else { + tsi_result ok = alts_tsi_handshaker_continue_handshaker_next( + handshaker, received_bytes, received_bytes_size, cb, user_data); + if (ok != TSI_OK) { + gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests"); + return ok; + } + } + return TSI_ASYNC; +} + +/* + * This API will be invoked by a non-gRPC application, and an ExecCtx needs + * to be explicitly created in order to invoke ALTS handshaker client API's + * that assumes the caller is inside gRPC core. 
+ */ +static tsi_result handshaker_next_dedicated( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** bytes_to_send, + size_t* bytes_to_send_size, tsi_handshaker_result** result, + tsi_handshaker_on_next_done_cb cb, void* user_data) { + grpc_core::ExecCtx exec_ctx; + return handshaker_next(self, received_bytes, received_bytes_size, + bytes_to_send, bytes_to_send_size, result, cb, + user_data); +} + +static void handshaker_shutdown(tsi_handshaker* self) { + GPR_ASSERT(self != nullptr); + alts_tsi_handshaker* handshaker = + reinterpret_cast(self); + grpc_core::MutexLock lock(&handshaker->mu); + if (handshaker->shutdown) { + return; + } + if (handshaker->client != nullptr) { + alts_handshaker_client_shutdown(handshaker->client); + } + handshaker->shutdown = true; +} + +static void handshaker_destroy(tsi_handshaker* self) { + if (self == nullptr) { + return; + } + alts_tsi_handshaker* handshaker = + reinterpret_cast(self); + alts_handshaker_client_destroy(handshaker->client); + grpc_slice_unref_internal(handshaker->target_name); + grpc_alts_credentials_options_destroy(handshaker->options); + if (handshaker->channel != nullptr) { + grpc_channel_destroy_internal(handshaker->channel); + } + gpr_free(handshaker->handshaker_service_url); + delete handshaker; +} + +static const tsi_handshaker_vtable handshaker_vtable = { + nullptr, nullptr, + nullptr, nullptr, + nullptr, handshaker_destroy, + handshaker_next, handshaker_shutdown}; + +static const tsi_handshaker_vtable handshaker_vtable_dedicated = { + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + handshaker_destroy, + handshaker_next_dedicated, + handshaker_shutdown}; + +bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) { + GPR_ASSERT(handshaker != nullptr); + grpc_core::MutexLock lock(&handshaker->mu); + return handshaker->shutdown; +} + +tsi_result alts_tsi_handshaker_create( + const grpc_alts_credentials_options* options, const char* target_name, + const char* handshaker_service_url, bool is_client, + grpc_pollset_set* interested_parties, tsi_handshaker** self, + size_t user_specified_max_frame_size) { + if (handshaker_service_url == nullptr || self == nullptr || + options == nullptr || (is_client && target_name == nullptr)) { + gpr_log(GPR_ERROR, "Invalid arguments to alts_tsi_handshaker_create()"); + return TSI_INVALID_ARGUMENT; + } + bool use_dedicated_cq = interested_parties == nullptr; + alts_tsi_handshaker* handshaker = new alts_tsi_handshaker(); + memset(&handshaker->base, 0, sizeof(handshaker->base)); + handshaker->base.vtable = + use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable; + handshaker->target_name = target_name == nullptr + ? grpc_empty_slice() + : grpc_slice_from_static_string(target_name); + handshaker->is_client = is_client; + handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url); + handshaker->interested_parties = interested_parties; + handshaker->options = grpc_alts_credentials_options_copy(options); + handshaker->use_dedicated_cq = use_dedicated_cq; + handshaker->max_frame_size = user_specified_max_frame_size != 0 + ? 
user_specified_max_frame_size + : kTsiAltsMaxFrameSize; + *self = &handshaker->base; + return TSI_OK; +} + +void alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result* result, + grpc_slice* recv_bytes, + size_t bytes_consumed) { + GPR_ASSERT(recv_bytes != nullptr && result != nullptr); + if (GRPC_SLICE_LENGTH(*recv_bytes) == bytes_consumed) { + return; + } + alts_tsi_handshaker_result* sresult = + reinterpret_cast(result); + sresult->unused_bytes_size = GRPC_SLICE_LENGTH(*recv_bytes) - bytes_consumed; + sresult->unused_bytes = + static_cast(gpr_zalloc(sresult->unused_bytes_size)); + memcpy(sresult->unused_bytes, + GRPC_SLICE_START_PTR(*recv_bytes) + bytes_consumed, + sresult->unused_bytes_size); +} + +namespace grpc_core { +namespace internal { + +bool alts_tsi_handshaker_get_has_sent_start_message_for_testing( + alts_tsi_handshaker* handshaker) { + GPR_ASSERT(handshaker != nullptr); + return handshaker->has_sent_start_message; +} + +void alts_tsi_handshaker_set_client_vtable_for_testing( + alts_tsi_handshaker* handshaker, alts_handshaker_client_vtable* vtable) { + GPR_ASSERT(handshaker != nullptr); + handshaker->client_vtable_for_testing = vtable; +} + +bool alts_tsi_handshaker_get_is_client_for_testing( + alts_tsi_handshaker* handshaker) { + GPR_ASSERT(handshaker != nullptr); + return handshaker->is_client; +} + +alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing( + alts_tsi_handshaker* handshaker) { + return handshaker->client; +} + +} // namespace internal +} // namespace grpc_core diff --git a/src/core/tsi/alts/handshaker/alts_tsi_utils.cc b/src/core/tsi/alts/handshaker/alts_tsi_utils.cc new file mode 100644 index 00000000..f80498db --- /dev/null +++ b/src/core/tsi/alts/handshaker/alts_tsi_utils.cc @@ -0,0 +1,64 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" + +#include <grpc/byte_buffer_reader.h> + +#include "src/core/lib/slice/slice_internal.h" + +tsi_result alts_tsi_utils_convert_to_tsi_result(grpc_status_code code) { + switch (code) { + case GRPC_STATUS_OK: + return TSI_OK; + case GRPC_STATUS_UNKNOWN: + return TSI_UNKNOWN_ERROR; + case GRPC_STATUS_INVALID_ARGUMENT: + return TSI_INVALID_ARGUMENT; + case GRPC_STATUS_NOT_FOUND: + return TSI_NOT_FOUND; + case GRPC_STATUS_INTERNAL: + return TSI_INTERNAL_ERROR; + default: + return TSI_UNKNOWN_ERROR; + } +} + +grpc_gcp_HandshakerResp* alts_tsi_utils_deserialize_response( + grpc_byte_buffer* resp_buffer, upb_arena* arena) { + GPR_ASSERT(resp_buffer != nullptr); + GPR_ASSERT(arena != nullptr); + grpc_byte_buffer_reader bbr; + grpc_byte_buffer_reader_init(&bbr, resp_buffer); + grpc_slice slice = grpc_byte_buffer_reader_readall(&bbr); + size_t buf_size = GPR_SLICE_LENGTH(slice); + void* buf = upb_arena_malloc(arena, buf_size); + memcpy(buf, reinterpret_cast<const char*>(GPR_SLICE_START_PTR(slice)), + buf_size); + grpc_gcp_HandshakerResp* resp = grpc_gcp_HandshakerResp_parse( + reinterpret_cast<char*>(buf), buf_size, arena); + grpc_slice_unref_internal(slice); + grpc_byte_buffer_reader_destroy(&bbr); + if (resp == nullptr) { + gpr_log(GPR_ERROR, "grpc_gcp_handshaker_resp_decode() failed"); + return nullptr; + } + return resp; +} diff --git a/src/core/tsi/alts/handshaker/transport_security_common_api.cc b/src/core/tsi/alts/handshaker/transport_security_common_api.cc new file mode 100644 index 00000000..402bf131 --- /dev/null +++ b/src/core/tsi/alts/handshaker/transport_security_common_api.cc @@ -0,0 +1,223 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * + */ + +#include + +#include "src/core/tsi/alts/handshaker/transport_security_common_api.h" + +#include "upb/upb.hpp" + +bool grpc_gcp_rpc_protocol_versions_set_max( + grpc_gcp_rpc_protocol_versions* versions, uint32_t max_major, + uint32_t max_minor) { + if (versions == nullptr) { + gpr_log(GPR_ERROR, + "versions is nullptr in " + "grpc_gcp_rpc_protocol_versions_set_max()."); + return false; + } + versions->max_rpc_version.major = max_major; + versions->max_rpc_version.minor = max_minor; + return true; +} + +bool grpc_gcp_rpc_protocol_versions_set_min( + grpc_gcp_rpc_protocol_versions* versions, uint32_t min_major, + uint32_t min_minor) { + if (versions == nullptr) { + gpr_log(GPR_ERROR, + "versions is nullptr in " + "grpc_gcp_rpc_protocol_versions_set_min()."); + return false; + } + versions->min_rpc_version.major = min_major; + versions->min_rpc_version.minor = min_minor; + return true; +} + +bool grpc_gcp_rpc_protocol_versions_encode( + const grpc_gcp_rpc_protocol_versions* versions, grpc_slice* slice) { + if (versions == nullptr || slice == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_gcp_rpc_protocol_versions_encode()."); + return false; + } + upb::Arena arena; + grpc_gcp_RpcProtocolVersions* versions_msg = + grpc_gcp_RpcProtocolVersions_new(arena.ptr()); + grpc_gcp_RpcProtocolVersions_assign_from_struct(versions_msg, arena.ptr(), + versions); + return grpc_gcp_rpc_protocol_versions_encode(versions_msg, arena.ptr(), + slice); +} + +bool grpc_gcp_rpc_protocol_versions_encode( + const grpc_gcp_RpcProtocolVersions* versions, upb_arena* arena, + grpc_slice* slice) { + if (versions == nullptr || arena == nullptr || slice == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_gcp_rpc_protocol_versions_encode()."); + return false; + } + size_t buf_length; + char* buf = + grpc_gcp_RpcProtocolVersions_serialize(versions, arena, &buf_length); + if (buf == nullptr) { + return false; + } + *slice = grpc_slice_from_copied_buffer(buf, buf_length); + return true; +} + +bool grpc_gcp_rpc_protocol_versions_decode( + const grpc_slice& slice, grpc_gcp_rpc_protocol_versions* versions) { + if (versions == nullptr) { + gpr_log(GPR_ERROR, + "version is nullptr in " + "grpc_gcp_rpc_protocol_versions_decode()."); + return false; + } + upb::Arena arena; + grpc_gcp_RpcProtocolVersions* versions_msg = + grpc_gcp_RpcProtocolVersions_parse( + reinterpret_cast(GRPC_SLICE_START_PTR(slice)), + GRPC_SLICE_LENGTH(slice), arena.ptr()); + if (versions_msg == nullptr) { + gpr_log(GPR_ERROR, "cannot deserialize RpcProtocolVersions message"); + return false; + } + grpc_gcp_rpc_protocol_versions_assign_from_upb(versions, versions_msg); + return true; +} + +void grpc_gcp_rpc_protocol_versions_assign_from_upb( + grpc_gcp_rpc_protocol_versions* versions, + const grpc_gcp_RpcProtocolVersions* value) { + const grpc_gcp_RpcProtocolVersions_Version* max_version_msg = + grpc_gcp_RpcProtocolVersions_max_rpc_version(value); + if (max_version_msg != nullptr) { + versions->max_rpc_version.major = + grpc_gcp_RpcProtocolVersions_Version_major(max_version_msg); + versions->max_rpc_version.minor = + grpc_gcp_RpcProtocolVersions_Version_minor(max_version_msg); + } else { + versions->max_rpc_version.major = 0; + versions->max_rpc_version.minor = 0; + } + const grpc_gcp_RpcProtocolVersions_Version* min_version_msg = + grpc_gcp_RpcProtocolVersions_min_rpc_version(value); + if (min_version_msg != nullptr) { + versions->min_rpc_version.major = + 
grpc_gcp_RpcProtocolVersions_Version_major(min_version_msg); + versions->min_rpc_version.minor = + grpc_gcp_RpcProtocolVersions_Version_minor(min_version_msg); + } else { + versions->min_rpc_version.major = 0; + versions->min_rpc_version.minor = 0; + } +} + +void grpc_gcp_RpcProtocolVersions_assign_from_struct( + grpc_gcp_RpcProtocolVersions* versions, upb_arena* arena, + const grpc_gcp_rpc_protocol_versions* value) { + grpc_gcp_RpcProtocolVersions_Version* max_version_msg = + grpc_gcp_RpcProtocolVersions_mutable_max_rpc_version(versions, arena); + grpc_gcp_RpcProtocolVersions_Version_set_major(max_version_msg, + value->max_rpc_version.major); + grpc_gcp_RpcProtocolVersions_Version_set_minor(max_version_msg, + value->max_rpc_version.minor); + grpc_gcp_RpcProtocolVersions_Version* min_version_msg = + grpc_gcp_RpcProtocolVersions_mutable_min_rpc_version(versions, arena); + grpc_gcp_RpcProtocolVersions_Version_set_major(min_version_msg, + value->min_rpc_version.major); + grpc_gcp_RpcProtocolVersions_Version_set_minor(min_version_msg, + value->min_rpc_version.minor); +} + +bool grpc_gcp_rpc_protocol_versions_copy( + const grpc_gcp_rpc_protocol_versions* src, + grpc_gcp_rpc_protocol_versions* dst) { + if ((src == nullptr && dst != nullptr) || + (src != nullptr && dst == nullptr)) { + gpr_log(GPR_ERROR, + "Invalid arguments to " + "grpc_gcp_rpc_protocol_versions_copy()."); + return false; + } + if (src == nullptr) { + return true; + } + grpc_gcp_rpc_protocol_versions_set_max(dst, src->max_rpc_version.major, + src->max_rpc_version.minor); + grpc_gcp_rpc_protocol_versions_set_min(dst, src->min_rpc_version.major, + src->min_rpc_version.minor); + return true; +} + +namespace grpc_core { +namespace internal { + +int grpc_gcp_rpc_protocol_version_compare( + const grpc_gcp_rpc_protocol_versions_version* v1, + const grpc_gcp_rpc_protocol_versions_version* v2) { + if ((v1->major > v2->major) || + (v1->major == v2->major && v1->minor > v2->minor)) { + return 1; + } + if ((v1->major < v2->major) || + (v1->major == v2->major && v1->minor < v2->minor)) { + return -1; + } + return 0; +} + +} // namespace internal +} // namespace grpc_core + +bool grpc_gcp_rpc_protocol_versions_check( + const grpc_gcp_rpc_protocol_versions* local_versions, + const grpc_gcp_rpc_protocol_versions* peer_versions, + grpc_gcp_rpc_protocol_versions_version* highest_common_version) { + if (local_versions == nullptr || peer_versions == nullptr) { + gpr_log(GPR_ERROR, + "Invalid arguments to " + "grpc_gcp_rpc_protocol_versions_check()."); + return false; + } + /* max_common_version is MIN(local.max, peer.max) */ + const grpc_gcp_rpc_protocol_versions_version* max_common_version = + grpc_core::internal::grpc_gcp_rpc_protocol_version_compare( + &local_versions->max_rpc_version, &peer_versions->max_rpc_version) > 0 + ? &peer_versions->max_rpc_version + : &local_versions->max_rpc_version; + /* min_common_version is MAX(local.min, peer.min) */ + const grpc_gcp_rpc_protocol_versions_version* min_common_version = + grpc_core::internal::grpc_gcp_rpc_protocol_version_compare( + &local_versions->min_rpc_version, &peer_versions->min_rpc_version) > 0 + ? 
&local_versions->min_rpc_version + : &peer_versions->min_rpc_version; + bool result = grpc_core::internal::grpc_gcp_rpc_protocol_version_compare( + max_common_version, min_common_version) >= 0; + if (result && highest_common_version != nullptr) { + memcpy(highest_common_version, max_common_version, + sizeof(grpc_gcp_rpc_protocol_versions_version)); + } + return result; +} diff --git a/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_integrity_only_record_protocol.cc b/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_integrity_only_record_protocol.cc new file mode 100644 index 00000000..67cdd994 --- /dev/null +++ b/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_integrity_only_record_protocol.cc @@ -0,0 +1,226 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_integrity_only_record_protocol.h" + +#include + +#include +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_common.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.h" + +/* Main struct for alts_grpc_integrity_only_record_protocol. */ +typedef struct alts_grpc_integrity_only_record_protocol { + alts_grpc_record_protocol base; + bool enable_extra_copy; + grpc_slice_buffer data_sb; + unsigned char* tag_buf; +} alts_grpc_integrity_only_record_protocol; + +/* --- alts_grpc_record_protocol methods implementation. --- */ + +static tsi_result alts_grpc_integrity_only_extra_copy_protect( + alts_grpc_record_protocol* rp, grpc_slice_buffer* unprotected_slices, + grpc_slice_buffer* protected_slices) { + /* Allocates memory for protected frame and copies data. */ + size_t data_length = unprotected_slices->length; + size_t protected_frame_size = + unprotected_slices->length + rp->header_length + rp->tag_length; + grpc_slice protected_slice = GRPC_SLICE_MALLOC(protected_frame_size); + uint8_t* data = GRPC_SLICE_START_PTR(protected_slice) + rp->header_length; + for (size_t i = 0; i < unprotected_slices->count; i++) { + memcpy(data, GRPC_SLICE_START_PTR(unprotected_slices->slices[i]), + GRPC_SLICE_LENGTH(unprotected_slices->slices[i])); + data += GRPC_SLICE_LENGTH(unprotected_slices->slices[i]); + } + /* Calls alts_iovec_record_protocol protect. 
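 * In this extra-copy path the whole output frame lives in the single
 * protected_slice allocated above, laid out as
 *   [ header (rp->header_length) | data (data_length) | tag (rp->tag_length) ],
 * and the header and tag iovecs passed to the call point into that same
 * slice.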
*/ + char* error_details = nullptr; + iovec_t header_iovec = {GRPC_SLICE_START_PTR(protected_slice), + rp->header_length}; + iovec_t tag_iovec = { + GRPC_SLICE_START_PTR(protected_slice) + rp->header_length + data_length, + rp->tag_length}; + rp->iovec_buf[0].iov_base = + GRPC_SLICE_START_PTR(protected_slice) + rp->header_length; + rp->iovec_buf[0].iov_len = data_length; + grpc_status_code status = alts_iovec_record_protocol_integrity_only_protect( + rp->iovec_rp, rp->iovec_buf, 1, header_iovec, tag_iovec, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to protect, %s", error_details); + gpr_free(error_details); + return TSI_INTERNAL_ERROR; + } + grpc_slice_buffer_add(protected_slices, protected_slice); + grpc_slice_buffer_reset_and_unref_internal(unprotected_slices); + return TSI_OK; +} + +static tsi_result alts_grpc_integrity_only_protect( + alts_grpc_record_protocol* rp, grpc_slice_buffer* unprotected_slices, + grpc_slice_buffer* protected_slices) { + /* Input sanity check. */ + if (rp == nullptr || unprotected_slices == nullptr || + protected_slices == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to alts_grpc_record_protocol protect."); + return TSI_INVALID_ARGUMENT; + } + alts_grpc_integrity_only_record_protocol* integrity_only_record_protocol = + reinterpret_cast(rp); + if (integrity_only_record_protocol->enable_extra_copy) { + return alts_grpc_integrity_only_extra_copy_protect(rp, unprotected_slices, + protected_slices); + } + /* Allocates memory for header and tag slices. */ + grpc_slice header_slice = GRPC_SLICE_MALLOC(rp->header_length); + grpc_slice tag_slice = GRPC_SLICE_MALLOC(rp->tag_length); + /* Calls alts_iovec_record_protocol protect. */ + char* error_details = nullptr; + iovec_t header_iovec = {GRPC_SLICE_START_PTR(header_slice), + GRPC_SLICE_LENGTH(header_slice)}; + iovec_t tag_iovec = {GRPC_SLICE_START_PTR(tag_slice), + GRPC_SLICE_LENGTH(tag_slice)}; + alts_grpc_record_protocol_convert_slice_buffer_to_iovec(rp, + unprotected_slices); + grpc_status_code status = alts_iovec_record_protocol_integrity_only_protect( + rp->iovec_rp, rp->iovec_buf, unprotected_slices->count, header_iovec, + tag_iovec, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to protect, %s", error_details); + gpr_free(error_details); + return TSI_INTERNAL_ERROR; + } + /* Appends result to protected_slices. */ + grpc_slice_buffer_add(protected_slices, header_slice); + grpc_slice_buffer_move_into(unprotected_slices, protected_slices); + grpc_slice_buffer_add(protected_slices, tag_slice); + return TSI_OK; +} + +static tsi_result alts_grpc_integrity_only_unprotect( + alts_grpc_record_protocol* rp, grpc_slice_buffer* protected_slices, + grpc_slice_buffer* unprotected_slices) { + /* Input sanity check. */ + if (rp == nullptr || protected_slices == nullptr || + unprotected_slices == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid nullptr arguments to alts_grpc_record_protocol unprotect."); + return TSI_INVALID_ARGUMENT; + } + if (protected_slices->length < rp->header_length + rp->tag_length) { + gpr_log(GPR_ERROR, "Protected slices do not have sufficient data."); + return TSI_INVALID_ARGUMENT; + } + /* In this method, rp points to alts_grpc_record_protocol struct + * and integrity_only_record_protocol points to + * alts_grpc_integrity_only_record_protocol struct. */ + alts_grpc_integrity_only_record_protocol* integrity_only_record_protocol = + reinterpret_cast(rp); + /* Strips frame header from protected slices. 
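 * Unprotect reverses the framing: the code below splits the received buffers
 * into [ header | data | tag ], keeps the data slices in data_sb, flattens the
 * header and the tag into contiguous buffers when they span multiple slices,
 * and then has the iovec record protocol verify the tag before data_sb is
 * handed back as the unprotected output.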
*/ + grpc_slice_buffer_reset_and_unref_internal(&rp->header_sb); + grpc_slice_buffer_move_first(protected_slices, rp->header_length, + &rp->header_sb); + GPR_ASSERT(rp->header_sb.length == rp->header_length); + iovec_t header_iovec = alts_grpc_record_protocol_get_header_iovec(rp); + /* Moves protected slices data to data_sb and leaves the remaining tag. */ + grpc_slice_buffer_reset_and_unref_internal( + &integrity_only_record_protocol->data_sb); + grpc_slice_buffer_move_first(protected_slices, + protected_slices->length - rp->tag_length, + &integrity_only_record_protocol->data_sb); + GPR_ASSERT(protected_slices->length == rp->tag_length); + iovec_t tag_iovec = {nullptr, rp->tag_length}; + if (protected_slices->count == 1) { + tag_iovec.iov_base = GRPC_SLICE_START_PTR(protected_slices->slices[0]); + } else { + /* Frame tag is in multiple slices, copies the tag bytes from slice + * buffer to a single flat buffer. */ + alts_grpc_record_protocol_copy_slice_buffer( + protected_slices, integrity_only_record_protocol->tag_buf); + tag_iovec.iov_base = integrity_only_record_protocol->tag_buf; + } + /* Calls alts_iovec_record_protocol unprotect. */ + char* error_details = nullptr; + alts_grpc_record_protocol_convert_slice_buffer_to_iovec( + rp, &integrity_only_record_protocol->data_sb); + grpc_status_code status = alts_iovec_record_protocol_integrity_only_unprotect( + rp->iovec_rp, rp->iovec_buf, + integrity_only_record_protocol->data_sb.count, header_iovec, tag_iovec, + &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to unprotect, %s", error_details); + gpr_free(error_details); + return TSI_INTERNAL_ERROR; + } + grpc_slice_buffer_reset_and_unref_internal(&rp->header_sb); + grpc_slice_buffer_reset_and_unref_internal(protected_slices); + grpc_slice_buffer_move_into(&integrity_only_record_protocol->data_sb, + unprotected_slices); + return TSI_OK; +} + +static void alts_grpc_integrity_only_destruct(alts_grpc_record_protocol* rp) { + if (rp == nullptr) { + return; + } + alts_grpc_integrity_only_record_protocol* integrity_only_rp = + reinterpret_cast(rp); + grpc_slice_buffer_destroy_internal(&integrity_only_rp->data_sb); + gpr_free(integrity_only_rp->tag_buf); +} + +static const alts_grpc_record_protocol_vtable + alts_grpc_integrity_only_record_protocol_vtable = { + alts_grpc_integrity_only_protect, alts_grpc_integrity_only_unprotect, + alts_grpc_integrity_only_destruct}; + +tsi_result alts_grpc_integrity_only_record_protocol_create( + gsec_aead_crypter* crypter, size_t overflow_size, bool is_client, + bool is_protect, bool enable_extra_copy, alts_grpc_record_protocol** rp) { + if (crypter == nullptr || rp == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to alts_grpc_record_protocol create."); + return TSI_INVALID_ARGUMENT; + } + alts_grpc_integrity_only_record_protocol* impl = + static_cast( + gpr_zalloc(sizeof(alts_grpc_integrity_only_record_protocol))); + /* Calls alts_grpc_record_protocol init. */ + tsi_result result = alts_grpc_record_protocol_init( + &impl->base, crypter, overflow_size, is_client, + /*is_integrity_only=*/true, is_protect); + if (result != TSI_OK) { + gpr_free(impl); + return result; + } + impl->enable_extra_copy = enable_extra_copy; + /* Initializes slice buffer for data_sb. */ + grpc_slice_buffer_init(&impl->data_sb); + /* Allocates tag buffer. 
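 * (tag_length bytes, typically 16 for the AES-128-GCM crypters ALTS uses; the
 * buffer is only needed on the unprotect path, when the received tag spans
 * more than one slice and must be flattened first.)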
*/ + impl->tag_buf = + static_cast(gpr_malloc(impl->base.tag_length)); + impl->base.vtable = &alts_grpc_integrity_only_record_protocol_vtable; + *rp = &impl->base; + return TSI_OK; +} diff --git a/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_privacy_integrity_record_protocol.cc b/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_privacy_integrity_record_protocol.cc new file mode 100644 index 00000000..e7890903 --- /dev/null +++ b/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_privacy_integrity_record_protocol.cc @@ -0,0 +1,144 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_privacy_integrity_record_protocol.h" + +#include +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_common.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.h" + +/* Privacy-integrity alts_grpc_record_protocol object uses the same struct + * defined in alts_grpc_record_protocol_common.h. */ + +/* --- alts_grpc_record_protocol methods implementation. --- */ + +static tsi_result alts_grpc_privacy_integrity_protect( + alts_grpc_record_protocol* rp, grpc_slice_buffer* unprotected_slices, + grpc_slice_buffer* protected_slices) { + /* Input sanity check. */ + if (rp == nullptr || unprotected_slices == nullptr || + protected_slices == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to alts_grpc_record_protocol protect."); + return TSI_INVALID_ARGUMENT; + } + /* Allocates memory for output frame. In privacy-integrity protect, the + * protected frame is stored in a newly allocated buffer. */ + size_t protected_frame_size = + unprotected_slices->length + rp->header_length + + alts_iovec_record_protocol_get_tag_length(rp->iovec_rp); + grpc_slice protected_slice = GRPC_SLICE_MALLOC(protected_frame_size); + iovec_t protected_iovec = {GRPC_SLICE_START_PTR(protected_slice), + GRPC_SLICE_LENGTH(protected_slice)}; + /* Calls alts_iovec_record_protocol protect. */ + char* error_details = nullptr; + alts_grpc_record_protocol_convert_slice_buffer_to_iovec(rp, + unprotected_slices); + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_protect( + rp->iovec_rp, rp->iovec_buf, unprotected_slices->count, + protected_iovec, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to protect, %s", error_details); + gpr_free(error_details); + grpc_slice_unref_internal(protected_slice); + return TSI_INTERNAL_ERROR; + } + grpc_slice_buffer_add(protected_slices, protected_slice); + grpc_slice_buffer_reset_and_unref_internal(unprotected_slices); + return TSI_OK; +} + +static tsi_result alts_grpc_privacy_integrity_unprotect( + alts_grpc_record_protocol* rp, grpc_slice_buffer* protected_slices, + grpc_slice_buffer* unprotected_slices) { + /* Input sanity check. 
*/ + if (rp == nullptr || protected_slices == nullptr || + unprotected_slices == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid nullptr arguments to alts_grpc_record_protocol unprotect."); + return TSI_INVALID_ARGUMENT; + } + /* Allocates memory for output frame. In privacy-integrity unprotect, the + * unprotected data are stored in a newly allocated buffer. */ + if (protected_slices->length < rp->header_length + rp->tag_length) { + gpr_log(GPR_ERROR, "Protected slices do not have sufficient data."); + return TSI_INVALID_ARGUMENT; + } + size_t unprotected_frame_size = + protected_slices->length - rp->header_length - rp->tag_length; + grpc_slice unprotected_slice = GRPC_SLICE_MALLOC(unprotected_frame_size); + iovec_t unprotected_iovec = {GRPC_SLICE_START_PTR(unprotected_slice), + GRPC_SLICE_LENGTH(unprotected_slice)}; + /* Strips frame header from protected slices. */ + grpc_slice_buffer_reset_and_unref_internal(&rp->header_sb); + grpc_slice_buffer_move_first(protected_slices, rp->header_length, + &rp->header_sb); + iovec_t header_iovec = alts_grpc_record_protocol_get_header_iovec(rp); + /* Calls alts_iovec_record_protocol unprotect. */ + char* error_details = nullptr; + alts_grpc_record_protocol_convert_slice_buffer_to_iovec(rp, protected_slices); + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_unprotect( + rp->iovec_rp, header_iovec, rp->iovec_buf, protected_slices->count, + unprotected_iovec, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to unprotect, %s", error_details); + gpr_free(error_details); + grpc_slice_unref_internal(unprotected_slice); + return TSI_INTERNAL_ERROR; + } + grpc_slice_buffer_reset_and_unref_internal(&rp->header_sb); + grpc_slice_buffer_reset_and_unref_internal(protected_slices); + grpc_slice_buffer_add(unprotected_slices, unprotected_slice); + return TSI_OK; +} + +static const alts_grpc_record_protocol_vtable + alts_grpc_privacy_integrity_record_protocol_vtable = { + alts_grpc_privacy_integrity_protect, + alts_grpc_privacy_integrity_unprotect, nullptr}; + +tsi_result alts_grpc_privacy_integrity_record_protocol_create( + gsec_aead_crypter* crypter, size_t overflow_size, bool is_client, + bool is_protect, alts_grpc_record_protocol** rp) { + if (crypter == nullptr || rp == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to alts_grpc_record_protocol create."); + return TSI_INVALID_ARGUMENT; + } + auto* impl = static_cast( + gpr_zalloc(sizeof(alts_grpc_record_protocol))); + /* Calls alts_grpc_record_protocol init. */ + tsi_result result = + alts_grpc_record_protocol_init(impl, crypter, overflow_size, is_client, + /*is_integrity_only=*/false, is_protect); + if (result != TSI_OK) { + gpr_free(impl); + return result; + } + impl->vtable = &alts_grpc_privacy_integrity_record_protocol_vtable; + *rp = impl; + return TSI_OK; +} diff --git a/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_common.cc b/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_common.cc new file mode 100644 index 00000000..f601b326 --- /dev/null +++ b/src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_common.cc @@ -0,0 +1,174 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_common.h" + +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" + +const size_t kInitialIovecBufferSize = 8; + +/* Makes sure iovec_buf in alts_grpc_record_protocol is large enough. */ +static void ensure_iovec_buf_size(alts_grpc_record_protocol* rp, + const grpc_slice_buffer* sb) { + GPR_ASSERT(rp != nullptr && sb != nullptr); + if (sb->count <= rp->iovec_buf_length) { + return; + } + /* At least double the iovec buffer size. */ + rp->iovec_buf_length = std::max(sb->count, 2 * rp->iovec_buf_length); + rp->iovec_buf = static_cast( + gpr_realloc(rp->iovec_buf, rp->iovec_buf_length * sizeof(iovec_t))); +} + +/* --- Implementation of methods defined in tsi_grpc_record_protocol_common.h. + * --- */ + +void alts_grpc_record_protocol_convert_slice_buffer_to_iovec( + alts_grpc_record_protocol* rp, const grpc_slice_buffer* sb) { + GPR_ASSERT(rp != nullptr && sb != nullptr); + ensure_iovec_buf_size(rp, sb); + for (size_t i = 0; i < sb->count; i++) { + rp->iovec_buf[i].iov_base = GRPC_SLICE_START_PTR(sb->slices[i]); + rp->iovec_buf[i].iov_len = GRPC_SLICE_LENGTH(sb->slices[i]); + } +} + +void alts_grpc_record_protocol_copy_slice_buffer(const grpc_slice_buffer* src, + unsigned char* dst) { + GPR_ASSERT(src != nullptr && dst != nullptr); + for (size_t i = 0; i < src->count; i++) { + size_t slice_length = GRPC_SLICE_LENGTH(src->slices[i]); + memcpy(dst, GRPC_SLICE_START_PTR(src->slices[i]), slice_length); + dst += slice_length; + } +} + +iovec_t alts_grpc_record_protocol_get_header_iovec( + alts_grpc_record_protocol* rp) { + iovec_t header_iovec = {nullptr, 0}; + if (rp == nullptr) { + return header_iovec; + } + header_iovec.iov_len = rp->header_length; + if (rp->header_sb.count == 1) { + header_iovec.iov_base = GRPC_SLICE_START_PTR(rp->header_sb.slices[0]); + } else { + /* Frame header is in multiple slices, copies the header bytes from slice + * buffer to a single flat buffer. */ + alts_grpc_record_protocol_copy_slice_buffer(&rp->header_sb, rp->header_buf); + header_iovec.iov_base = rp->header_buf; + } + return header_iovec; +} + +tsi_result alts_grpc_record_protocol_init(alts_grpc_record_protocol* rp, + gsec_aead_crypter* crypter, + size_t overflow_size, bool is_client, + bool is_integrity_only, + bool is_protect) { + if (rp == nullptr || crypter == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to alts_grpc_record_protocol init."); + return TSI_INVALID_ARGUMENT; + } + /* Creates alts_iovec_record_protocol. */ + char* error_details = nullptr; + grpc_status_code status = alts_iovec_record_protocol_create( + crypter, overflow_size, is_client, is_integrity_only, is_protect, + &rp->iovec_rp, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to create alts_iovec_record_protocol, %s.", + error_details); + gpr_free(error_details); + return TSI_INTERNAL_ERROR; + } + /* Allocates header slice buffer. 
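 * Together with the header buffer, tag length and iovec scratch buffer set up
 * just below; iovec_buf starts at kInitialIovecBufferSize entries and is at
 * least doubled by ensure_iovec_buf_size() whenever a larger slice buffer is
 * seen.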
*/ + grpc_slice_buffer_init(&rp->header_sb); + /* Allocates header buffer. */ + rp->header_length = alts_iovec_record_protocol_get_header_length(); + rp->header_buf = static_cast(gpr_malloc(rp->header_length)); + rp->tag_length = alts_iovec_record_protocol_get_tag_length(rp->iovec_rp); + /* Allocates iovec buffer. */ + rp->iovec_buf_length = kInitialIovecBufferSize; + rp->iovec_buf = + static_cast(gpr_malloc(rp->iovec_buf_length * sizeof(iovec_t))); + return TSI_OK; +} + +/* --- Implementation of methods defined in tsi_grpc_record_protocol.h. --- */ +tsi_result alts_grpc_record_protocol_protect( + alts_grpc_record_protocol* self, grpc_slice_buffer* unprotected_slices, + grpc_slice_buffer* protected_slices) { + if (grpc_core::ExecCtx::Get() == nullptr || self == nullptr || + self->vtable == nullptr || unprotected_slices == nullptr || + protected_slices == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->protect == nullptr) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->protect(self, unprotected_slices, protected_slices); +} + +tsi_result alts_grpc_record_protocol_unprotect( + alts_grpc_record_protocol* self, grpc_slice_buffer* protected_slices, + grpc_slice_buffer* unprotected_slices) { + if (grpc_core::ExecCtx::Get() == nullptr || self == nullptr || + self->vtable == nullptr || protected_slices == nullptr || + unprotected_slices == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->unprotect == nullptr) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->unprotect(self, protected_slices, unprotected_slices); +} + +void alts_grpc_record_protocol_destroy(alts_grpc_record_protocol* self) { + if (self == nullptr) { + return; + } + if (self->vtable->destruct != nullptr) { + self->vtable->destruct(self); + } + alts_iovec_record_protocol_destroy(self->iovec_rp); + grpc_slice_buffer_destroy_internal(&self->header_sb); + gpr_free(self->header_buf); + gpr_free(self->iovec_buf); + gpr_free(self); +} + +/* Integrity-only and privacy-integrity share the same implementation. No need + * to call vtable. */ +size_t alts_grpc_record_protocol_max_unprotected_data_size( + const alts_grpc_record_protocol* self, size_t max_protected_frame_size) { + if (self == nullptr) { + return 0; + } + return alts_iovec_record_protocol_max_unprotected_data_size( + self->iovec_rp, max_protected_frame_size); +} diff --git a/src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.cc b/src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.cc new file mode 100644 index 00000000..26c18a45 --- /dev/null +++ b/src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.cc @@ -0,0 +1,478 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.h" + +#include +#include + +#include +#include + +#include "src/core/tsi/alts/frame_protector/alts_counter.h" + +struct alts_iovec_record_protocol { + alts_counter* ctr; + gsec_aead_crypter* crypter; + size_t tag_length; + bool is_integrity_only; + bool is_protect; +}; + +/* Copies error message to destination. */ +static void maybe_copy_error_msg(const char* src, char** dst) { + if (dst != nullptr && src != nullptr) { + *dst = static_cast(gpr_malloc(strlen(src) + 1)); + memcpy(*dst, src, strlen(src) + 1); + } +} + +/* Appends error message to destination. */ +static void maybe_append_error_msg(const char* appendix, char** dst) { + if (dst != nullptr && appendix != nullptr) { + int dst_len = static_cast(strlen(*dst)); + *dst = static_cast(realloc(*dst, dst_len + strlen(appendix) + 1)); + assert(*dst != nullptr); + memcpy(*dst + dst_len, appendix, strlen(appendix) + 1); + } +} + +/* Use little endian to interpret a string of bytes as uint32_t. */ +static uint32_t load_32_le(const unsigned char* buffer) { + return (static_cast(buffer[3]) << 24) | + (static_cast(buffer[2]) << 16) | + (static_cast(buffer[1]) << 8) | + static_cast(buffer[0]); +} + +/* Store uint32_t as a string of little endian bytes. */ +static void store_32_le(uint32_t value, unsigned char* buffer) { + buffer[3] = static_cast(value >> 24) & 0xFF; + buffer[2] = static_cast(value >> 16) & 0xFF; + buffer[1] = static_cast(value >> 8) & 0xFF; + buffer[0] = static_cast(value) & 0xFF; +} + +/* Ensures header and tag iovec have sufficient length. */ +static grpc_status_code ensure_header_and_tag_length( + const alts_iovec_record_protocol* rp, iovec_t header, iovec_t tag, + char** error_details) { + if (rp == nullptr) { + return GRPC_STATUS_FAILED_PRECONDITION; + } + if (header.iov_base == nullptr) { + maybe_copy_error_msg("Header is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (header.iov_len != alts_iovec_record_protocol_get_header_length()) { + maybe_copy_error_msg("Header length is incorrect.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (tag.iov_base == nullptr) { + maybe_copy_error_msg("Tag is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (tag.iov_len != rp->tag_length) { + maybe_copy_error_msg("Tag length is incorrect.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + return GRPC_STATUS_OK; +} + +/* Increments crypter counter and checks overflow. */ +static grpc_status_code increment_counter(alts_counter* counter, + char** error_details) { + if (counter == nullptr) { + return GRPC_STATUS_FAILED_PRECONDITION; + } + bool is_overflow = false; + grpc_status_code status = + alts_counter_increment(counter, &is_overflow, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + if (is_overflow) { + maybe_copy_error_msg("Crypter counter is overflowed.", error_details); + return GRPC_STATUS_INTERNAL; + } + return GRPC_STATUS_OK; +} + +/* Given an array of iovec, computes the total length of buffer. */ +static size_t get_total_length(const iovec_t* vec, size_t vec_length) { + size_t total_length = 0; + for (size_t i = 0; i < vec_length; ++i) { + total_length += vec[i].iov_len; + } + return total_length; +} + +/* Writes frame header given data and tag length. 
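 * The header is two little-endian uint32 fields emitted with store_32_le():
 * first the frame length, which counts the 4-byte message-type field plus the
 * payload handed in as data_length (data and tag for integrity-only frames),
 * then the message type itself. For example, 1024 bytes of data with a
 * 16-byte tag gives data_length = 1040 and a stored frame length of
 * 4 + 1040 = 1044, i.e. header bytes {0x14, 0x04, 0x00, 0x00} followed by the
 * message type.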
*/ +static grpc_status_code write_frame_header(size_t data_length, + unsigned char* header, + char** error_details) { + if (header == nullptr) { + maybe_copy_error_msg("Header is nullptr.", error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + size_t frame_length = kZeroCopyFrameMessageTypeFieldSize + data_length; + store_32_le(static_cast(frame_length), header); + store_32_le(kZeroCopyFrameMessageType, + header + kZeroCopyFrameLengthFieldSize); + return GRPC_STATUS_OK; +} + +/* Verifies frame header given protected data length. */ +static grpc_status_code verify_frame_header(size_t data_length, + unsigned char* header, + char** error_details) { + if (header == nullptr) { + maybe_copy_error_msg("Header is nullptr.", error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + size_t frame_length = load_32_le(header); + if (frame_length != kZeroCopyFrameMessageTypeFieldSize + data_length) { + maybe_copy_error_msg("Bad frame length.", error_details); + return GRPC_STATUS_INTERNAL; + } + size_t message_type = load_32_le(header + kZeroCopyFrameLengthFieldSize); + if (message_type != kZeroCopyFrameMessageType) { + maybe_copy_error_msg("Unsupported message type.", error_details); + return GRPC_STATUS_INTERNAL; + } + return GRPC_STATUS_OK; +} + +/* --- alts_iovec_record_protocol methods implementation. --- */ + +size_t alts_iovec_record_protocol_get_header_length() { + return kZeroCopyFrameHeaderSize; +} + +size_t alts_iovec_record_protocol_get_tag_length( + const alts_iovec_record_protocol* rp) { + if (rp != nullptr) { + return rp->tag_length; + } + return 0; +} + +size_t alts_iovec_record_protocol_max_unprotected_data_size( + const alts_iovec_record_protocol* rp, size_t max_protected_frame_size) { + if (rp == nullptr) { + return 0; + } + size_t overhead_bytes_size = + kZeroCopyFrameMessageTypeFieldSize + rp->tag_length; + if (max_protected_frame_size <= overhead_bytes_size) return 0; + return max_protected_frame_size - overhead_bytes_size; +} + +grpc_status_code alts_iovec_record_protocol_integrity_only_protect( + alts_iovec_record_protocol* rp, const iovec_t* unprotected_vec, + size_t unprotected_vec_length, iovec_t header, iovec_t tag, + char** error_details) { + /* Input sanity checks. */ + if (rp == nullptr) { + maybe_copy_error_msg("Input iovec_record_protocol is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (!rp->is_integrity_only) { + maybe_copy_error_msg( + "Integrity-only operations are not allowed for this object.", + error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + if (!rp->is_protect) { + maybe_copy_error_msg("Protect operations are not allowed for this object.", + error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + grpc_status_code status = + ensure_header_and_tag_length(rp, header, tag, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Unprotected data should not be zero length. */ + size_t data_length = + get_total_length(unprotected_vec, unprotected_vec_length); + /* Sets frame header. */ + status = write_frame_header(data_length + rp->tag_length, + static_cast(header.iov_base), + error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Computes frame tag by calling AEAD crypter. 
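+ * In integrity-only mode the payload itself is passed to the crypter in the
+ * AAD position (the plaintext iovec below is empty), so the only output
+ * produced here is the tag; the payload bytes themselves stay in the clear.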
*/ + size_t bytes_written = 0; + status = gsec_aead_crypter_encrypt_iovec( + rp->crypter, alts_counter_get_counter(rp->ctr), + alts_counter_get_size(rp->ctr), unprotected_vec, unprotected_vec_length, + /* plaintext_vec = */ nullptr, /* plaintext_vec_length = */ 0, tag, + &bytes_written, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + if (bytes_written != rp->tag_length) { + maybe_copy_error_msg("Bytes written expects to be the same as tag length.", + error_details); + return GRPC_STATUS_INTERNAL; + } + /* Increments the crypter counter. */ + return increment_counter(rp->ctr, error_details); +} + +grpc_status_code alts_iovec_record_protocol_integrity_only_unprotect( + alts_iovec_record_protocol* rp, const iovec_t* protected_vec, + size_t protected_vec_length, iovec_t header, iovec_t tag, + char** error_details) { + /* Input sanity checks. */ + if (rp == nullptr) { + maybe_copy_error_msg("Input iovec_record_protocol is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (!rp->is_integrity_only) { + maybe_copy_error_msg( + "Integrity-only operations are not allowed for this object.", + error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + if (rp->is_protect) { + maybe_copy_error_msg( + "Unprotect operations are not allowed for this object.", error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + grpc_status_code status = + ensure_header_and_tag_length(rp, header, tag, error_details); + if (status != GRPC_STATUS_OK) return status; + /* Protected data should not be zero length. */ + size_t data_length = get_total_length(protected_vec, protected_vec_length); + /* Verifies frame header. */ + status = verify_frame_header(data_length + rp->tag_length, + static_cast(header.iov_base), + error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Verifies frame tag by calling AEAD crypter. */ + iovec_t plaintext = {nullptr, 0}; + size_t bytes_written = 0; + status = gsec_aead_crypter_decrypt_iovec( + rp->crypter, alts_counter_get_counter(rp->ctr), + alts_counter_get_size(rp->ctr), protected_vec, protected_vec_length, &tag, + 1, plaintext, &bytes_written, error_details); + if (status != GRPC_STATUS_OK || bytes_written != 0) { + maybe_append_error_msg(" Frame tag verification failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + /* Increments the crypter counter. */ + return increment_counter(rp->ctr, error_details); +} + +grpc_status_code alts_iovec_record_protocol_privacy_integrity_protect( + alts_iovec_record_protocol* rp, const iovec_t* unprotected_vec, + size_t unprotected_vec_length, iovec_t protected_frame, + char** error_details) { + /* Input sanity checks. */ + if (rp == nullptr) { + maybe_copy_error_msg("Input iovec_record_protocol is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (rp->is_integrity_only) { + maybe_copy_error_msg( + "Privacy-integrity operations are not allowed for this object.", + error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + if (!rp->is_protect) { + maybe_copy_error_msg("Protect operations are not allowed for this object.", + error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + /* Unprotected data should not be zero length. */ + size_t data_length = + get_total_length(unprotected_vec, unprotected_vec_length); + /* Ensures protected frame iovec has sufficient size. 
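+ * Illustrative arithmetic (header and tag sizes assumed): with an 8-byte
+ * header, 100 bytes of plaintext and a 16-byte tag, the caller must hand in
+ * an iovec of exactly 8 + 100 + 16 = 124 bytes; any other length is rejected
+ * below.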
*/ + if (protected_frame.iov_base == nullptr) { + maybe_copy_error_msg("Protected frame is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (protected_frame.iov_len != + alts_iovec_record_protocol_get_header_length() + data_length + + rp->tag_length) { + maybe_copy_error_msg("Protected frame size is incorrect.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + /* Writer frame header. */ + grpc_status_code status = write_frame_header( + data_length + rp->tag_length, + static_cast(protected_frame.iov_base), error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Encrypt unprotected data by calling AEAD crypter. */ + unsigned char* ciphertext_buffer = + static_cast(protected_frame.iov_base) + + alts_iovec_record_protocol_get_header_length(); + iovec_t ciphertext = {ciphertext_buffer, data_length + rp->tag_length}; + size_t bytes_written = 0; + status = gsec_aead_crypter_encrypt_iovec( + rp->crypter, alts_counter_get_counter(rp->ctr), + alts_counter_get_size(rp->ctr), /* aad_vec = */ nullptr, + /* aad_vec_length = */ 0, unprotected_vec, unprotected_vec_length, + ciphertext, &bytes_written, error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + if (bytes_written != data_length + rp->tag_length) { + maybe_copy_error_msg( + "Bytes written expects to be data length plus tag length.", + error_details); + return GRPC_STATUS_INTERNAL; + } + /* Increments the crypter counter. */ + return increment_counter(rp->ctr, error_details); +} + +grpc_status_code alts_iovec_record_protocol_privacy_integrity_unprotect( + alts_iovec_record_protocol* rp, iovec_t header, + const iovec_t* protected_vec, size_t protected_vec_length, + iovec_t unprotected_data, char** error_details) { + /* Input sanity checks. */ + if (rp == nullptr) { + maybe_copy_error_msg("Input iovec_record_protocol is nullptr.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (rp->is_integrity_only) { + maybe_copy_error_msg( + "Privacy-integrity operations are not allowed for this object.", + error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + if (rp->is_protect) { + maybe_copy_error_msg( + "Unprotect operations are not allowed for this object.", error_details); + return GRPC_STATUS_FAILED_PRECONDITION; + } + /* Protected data size should be no less than tag size. */ + size_t protected_data_length = + get_total_length(protected_vec, protected_vec_length); + if (protected_data_length < rp->tag_length) { + maybe_copy_error_msg( + "Protected data length should be more than the tag length.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + /* Ensures header has sufficient size. */ + if (header.iov_base == nullptr) { + maybe_copy_error_msg("Header is nullptr.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + if (header.iov_len != alts_iovec_record_protocol_get_header_length()) { + maybe_copy_error_msg("Header length is incorrect.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + /* Ensures unprotected data iovec has sufficient size. */ + if (unprotected_data.iov_len != protected_data_length - rp->tag_length) { + maybe_copy_error_msg("Unprotected data size is incorrect.", error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + /* Verify frame header. */ + grpc_status_code status = verify_frame_header( + protected_data_length, static_cast(header.iov_base), + error_details); + if (status != GRPC_STATUS_OK) { + return status; + } + /* Decrypt protected data by calling AEAD crypter. 
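+ * Unlike the integrity-only path there is no AAD here: the entire payload
+ * (ciphertext followed by the tag) is fed to the crypter, which must write
+ * exactly protected_data_length - tag_length plaintext bytes on success.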
*/ + size_t bytes_written = 0; + status = gsec_aead_crypter_decrypt_iovec( + rp->crypter, alts_counter_get_counter(rp->ctr), + alts_counter_get_size(rp->ctr), /* aad_vec = */ nullptr, + /* aad_vec_length = */ 0, protected_vec, protected_vec_length, + unprotected_data, &bytes_written, error_details); + if (status != GRPC_STATUS_OK) { + maybe_append_error_msg(" Frame decryption failed.", error_details); + return GRPC_STATUS_INTERNAL; + } + if (bytes_written != protected_data_length - rp->tag_length) { + maybe_copy_error_msg( + "Bytes written expects to be protected data length minus tag length.", + error_details); + return GRPC_STATUS_INTERNAL; + } + /* Increments the crypter counter. */ + return increment_counter(rp->ctr, error_details); +} + +grpc_status_code alts_iovec_record_protocol_create( + gsec_aead_crypter* crypter, size_t overflow_size, bool is_client, + bool is_integrity_only, bool is_protect, alts_iovec_record_protocol** rp, + char** error_details) { + if (crypter == nullptr || rp == nullptr) { + maybe_copy_error_msg( + "Invalid nullptr arguments to alts_iovec_record_protocol create.", + error_details); + return GRPC_STATUS_INVALID_ARGUMENT; + } + alts_iovec_record_protocol* impl = static_cast( + gpr_zalloc(sizeof(alts_iovec_record_protocol))); + /* Gets counter length. */ + size_t counter_length = 0; + grpc_status_code status = + gsec_aead_crypter_nonce_length(crypter, &counter_length, error_details); + if (status != GRPC_STATUS_OK) { + goto cleanup; + } + /* Creates counters. */ + status = + alts_counter_create(is_protect ? !is_client : is_client, counter_length, + overflow_size, &impl->ctr, error_details); + if (status != GRPC_STATUS_OK) { + goto cleanup; + } + /* Gets tag length. */ + status = + gsec_aead_crypter_tag_length(crypter, &impl->tag_length, error_details); + if (status != GRPC_STATUS_OK) { + goto cleanup; + } + impl->crypter = crypter; + impl->is_integrity_only = is_integrity_only; + impl->is_protect = is_protect; + *rp = impl; + return GRPC_STATUS_OK; +cleanup: + alts_counter_destroy(impl->ctr); + gpr_free(impl); + return GRPC_STATUS_FAILED_PRECONDITION; +} + +void alts_iovec_record_protocol_destroy(alts_iovec_record_protocol* rp) { + if (rp != nullptr) { + alts_counter_destroy(rp->ctr); + gsec_aead_crypter_destroy(rp->crypter); + gpr_free(rp); + } +} diff --git a/src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.cc b/src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.cc new file mode 100644 index 00000000..d5455b6d --- /dev/null +++ b/src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.cc @@ -0,0 +1,307 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h" + +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/alts/crypt/gsec.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_integrity_only_record_protocol.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_privacy_integrity_record_protocol.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.h" +#include "src/core/tsi/transport_security_grpc.h" + +constexpr size_t kMinFrameLength = 1024; +constexpr size_t kDefaultFrameLength = 16 * 1024; +constexpr size_t kMaxFrameLength = 16 * 1024 * 1024; + +/** + * Main struct for alts_zero_copy_grpc_protector. + * We choose to have two alts_grpc_record_protocol objects and two sets of slice + * buffers: one for protect and the other for unprotect, so that protect and + * unprotect can be executed in parallel. Implementations of this object must be + * thread compatible. + */ +typedef struct alts_zero_copy_grpc_protector { + tsi_zero_copy_grpc_protector base; + alts_grpc_record_protocol* record_protocol; + alts_grpc_record_protocol* unrecord_protocol; + size_t max_protected_frame_size; + size_t max_unprotected_data_size; + grpc_slice_buffer unprotected_staging_sb; + grpc_slice_buffer protected_sb; + grpc_slice_buffer protected_staging_sb; + uint32_t parsed_frame_size; +} alts_zero_copy_grpc_protector; + +/** + * Given a slice buffer, parses the first 4 bytes little-endian unsigned frame + * size and returns the total frame size including the frame field. Caller + * needs to make sure the input slice buffer has at least 4 bytes. Returns true + * on success and false on failure. + */ +static bool read_frame_size(const grpc_slice_buffer* sb, + uint32_t* total_frame_size) { + if (sb == nullptr || sb->length < kZeroCopyFrameLengthFieldSize) { + return false; + } + uint8_t frame_size_buffer[kZeroCopyFrameLengthFieldSize]; + uint8_t* buf = frame_size_buffer; + /* Copies the first 4 bytes to a temporary buffer. */ + size_t remaining = kZeroCopyFrameLengthFieldSize; + for (size_t i = 0; i < sb->count; i++) { + size_t slice_length = GRPC_SLICE_LENGTH(sb->slices[i]); + if (remaining <= slice_length) { + memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), remaining); + remaining = 0; + break; + } else { + memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), slice_length); + buf += slice_length; + remaining -= slice_length; + } + } + GPR_ASSERT(remaining == 0); + /* Gets little-endian frame size. */ + uint32_t frame_size = (static_cast(frame_size_buffer[3]) << 24) | + (static_cast(frame_size_buffer[2]) << 16) | + (static_cast(frame_size_buffer[1]) << 8) | + static_cast(frame_size_buffer[0]); + if (frame_size > kMaxFrameLength) { + gpr_log(GPR_ERROR, "Frame size is larger than maximum frame size"); + return false; + } + /* Returns frame size including frame length field. 
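+ * Illustrative numbers: a little-endian value of 120 in the first four bytes
+ * means 120 more bytes complete the frame, so *total_frame_size becomes 124
+ * and the unprotect loop waits until that many bytes are buffered.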
*/ + *total_frame_size = + static_cast(frame_size + kZeroCopyFrameLengthFieldSize); + return true; +} + +/** + * Creates an alts_grpc_record_protocol object, given key, key size, and flags + * to indicate whether the record_protocol object uses the rekeying AEAD, + * whether the object is for client or server, whether the object is for + * integrity-only or privacy-integrity mode, and whether the object is used + * for protect or unprotect. + */ +static tsi_result create_alts_grpc_record_protocol( + const uint8_t* key, size_t key_size, bool is_rekey, bool is_client, + bool is_integrity_only, bool is_protect, bool enable_extra_copy, + alts_grpc_record_protocol** record_protocol) { + if (key == nullptr || record_protocol == nullptr) { + return TSI_INVALID_ARGUMENT; + } + grpc_status_code status; + gsec_aead_crypter* crypter = nullptr; + char* error_details = nullptr; + status = gsec_aes_gcm_aead_crypter_create(key, key_size, kAesGcmNonceLength, + kAesGcmTagLength, is_rekey, + &crypter, &error_details); + if (status != GRPC_STATUS_OK) { + gpr_log(GPR_ERROR, "Failed to create AEAD crypter, %s", error_details); + gpr_free(error_details); + return TSI_INTERNAL_ERROR; + } + size_t overflow_limit = is_rekey ? kAltsRecordProtocolRekeyFrameLimit + : kAltsRecordProtocolFrameLimit; + /* Creates alts_grpc_record_protocol with AEAD crypter ownership transferred. + */ + tsi_result result = is_integrity_only + ? alts_grpc_integrity_only_record_protocol_create( + crypter, overflow_limit, is_client, is_protect, + enable_extra_copy, record_protocol) + : alts_grpc_privacy_integrity_record_protocol_create( + crypter, overflow_limit, is_client, is_protect, + record_protocol); + if (result != TSI_OK) { + gsec_aead_crypter_destroy(crypter); + return result; + } + return TSI_OK; +} + +/* --- tsi_zero_copy_grpc_protector methods implementation. --- */ + +static tsi_result alts_zero_copy_grpc_protector_protect( + tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices, + grpc_slice_buffer* protected_slices) { + if (self == nullptr || unprotected_slices == nullptr || + protected_slices == nullptr) { + gpr_log(GPR_ERROR, "Invalid nullptr arguments to zero-copy grpc protect."); + return TSI_INVALID_ARGUMENT; + } + alts_zero_copy_grpc_protector* protector = + reinterpret_cast(self); + /* Calls alts_grpc_record_protocol protect repeatly. */ + while (unprotected_slices->length > protector->max_unprotected_data_size) { + grpc_slice_buffer_move_first(unprotected_slices, + protector->max_unprotected_data_size, + &protector->unprotected_staging_sb); + tsi_result status = alts_grpc_record_protocol_protect( + protector->record_protocol, &protector->unprotected_staging_sb, + protected_slices); + if (status != TSI_OK) { + return status; + } + } + return alts_grpc_record_protocol_protect( + protector->record_protocol, unprotected_slices, protected_slices); +} + +static tsi_result alts_zero_copy_grpc_protector_unprotect( + tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices, + grpc_slice_buffer* unprotected_slices) { + if (self == nullptr || unprotected_slices == nullptr || + protected_slices == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to zero-copy grpc unprotect."); + return TSI_INVALID_ARGUMENT; + } + alts_zero_copy_grpc_protector* protector = + reinterpret_cast(self); + grpc_slice_buffer_move_into(protected_slices, &protector->protected_sb); + /* Keep unprotecting each frame if possible. 
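+ * The loop below is resumable across calls: parsed_frame_size caches the size
+ * of a partially received frame, the loop exits early while a frame is still
+ * incomplete, and a complete frame is either unprotected straight out of
+ * protected_sb (when it is the only data buffered) or first split off into
+ * protected_staging_sb.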
*/ + while (protector->protected_sb.length >= kZeroCopyFrameLengthFieldSize) { + if (protector->parsed_frame_size == 0) { + /* We have not parsed frame size yet. Parses frame size. */ + if (!read_frame_size(&protector->protected_sb, + &protector->parsed_frame_size)) { + grpc_slice_buffer_reset_and_unref_internal(&protector->protected_sb); + return TSI_DATA_CORRUPTED; + } + } + if (protector->protected_sb.length < protector->parsed_frame_size) break; + /* At this point, protected_sb contains at least one frame of data. */ + tsi_result status; + if (protector->protected_sb.length == protector->parsed_frame_size) { + status = alts_grpc_record_protocol_unprotect(protector->unrecord_protocol, + &protector->protected_sb, + unprotected_slices); + } else { + grpc_slice_buffer_move_first(&protector->protected_sb, + protector->parsed_frame_size, + &protector->protected_staging_sb); + status = alts_grpc_record_protocol_unprotect( + protector->unrecord_protocol, &protector->protected_staging_sb, + unprotected_slices); + } + protector->parsed_frame_size = 0; + if (status != TSI_OK) { + grpc_slice_buffer_reset_and_unref_internal(&protector->protected_sb); + return status; + } + } + return TSI_OK; +} + +static void alts_zero_copy_grpc_protector_destroy( + tsi_zero_copy_grpc_protector* self) { + if (self == nullptr) { + return; + } + alts_zero_copy_grpc_protector* protector = + reinterpret_cast(self); + alts_grpc_record_protocol_destroy(protector->record_protocol); + alts_grpc_record_protocol_destroy(protector->unrecord_protocol); + grpc_slice_buffer_destroy_internal(&protector->unprotected_staging_sb); + grpc_slice_buffer_destroy_internal(&protector->protected_sb); + grpc_slice_buffer_destroy_internal(&protector->protected_staging_sb); + gpr_free(protector); +} + +static tsi_result alts_zero_copy_grpc_protector_max_frame_size( + tsi_zero_copy_grpc_protector* self, size_t* max_frame_size) { + if (self == nullptr || max_frame_size == nullptr) return TSI_INVALID_ARGUMENT; + alts_zero_copy_grpc_protector* protector = + reinterpret_cast(self); + *max_frame_size = protector->max_protected_frame_size; + return TSI_OK; +} + +static const tsi_zero_copy_grpc_protector_vtable + alts_zero_copy_grpc_protector_vtable = { + alts_zero_copy_grpc_protector_protect, + alts_zero_copy_grpc_protector_unprotect, + alts_zero_copy_grpc_protector_destroy, + alts_zero_copy_grpc_protector_max_frame_size}; + +tsi_result alts_zero_copy_grpc_protector_create( + const uint8_t* key, size_t key_size, bool is_rekey, bool is_client, + bool is_integrity_only, bool enable_extra_copy, + size_t* max_protected_frame_size, + tsi_zero_copy_grpc_protector** protector) { + if (grpc_core::ExecCtx::Get() == nullptr || key == nullptr || + protector == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid nullptr arguments to alts_zero_copy_grpc_protector create."); + return TSI_INVALID_ARGUMENT; + } + /* Creates alts_zero_copy_protector. */ + alts_zero_copy_grpc_protector* impl = + static_cast( + gpr_zalloc(sizeof(alts_zero_copy_grpc_protector))); + /* Creates alts_grpc_record_protocol objects. */ + tsi_result status = create_alts_grpc_record_protocol( + key, key_size, is_rekey, is_client, is_integrity_only, + /*is_protect=*/true, enable_extra_copy, &impl->record_protocol); + if (status == TSI_OK) { + status = create_alts_grpc_record_protocol( + key, key_size, is_rekey, is_client, is_integrity_only, + /*is_protect=*/false, enable_extra_copy, &impl->unrecord_protocol); + if (status == TSI_OK) { + /* Sets maximum frame size. 
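+ * The caller-supplied value is clamped to [kMinFrameLength, kMaxFrameLength],
+ * i.e. 1 KiB to 16 MiB; for example a request of 512 is raised to 1024, and a
+ * null max_protected_frame_size falls back to kDefaultFrameLength (16 KiB).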
*/ + size_t max_protected_frame_size_to_set = kDefaultFrameLength; + if (max_protected_frame_size != nullptr) { + *max_protected_frame_size = + std::min(*max_protected_frame_size, kMaxFrameLength); + *max_protected_frame_size = + std::max(*max_protected_frame_size, kMinFrameLength); + max_protected_frame_size_to_set = *max_protected_frame_size; + } + impl->max_protected_frame_size = max_protected_frame_size_to_set; + impl->max_unprotected_data_size = + alts_grpc_record_protocol_max_unprotected_data_size( + impl->record_protocol, max_protected_frame_size_to_set); + GPR_ASSERT(impl->max_unprotected_data_size > 0); + /* Allocates internal slice buffers. */ + grpc_slice_buffer_init(&impl->unprotected_staging_sb); + grpc_slice_buffer_init(&impl->protected_sb); + grpc_slice_buffer_init(&impl->protected_staging_sb); + impl->parsed_frame_size = 0; + impl->base.vtable = &alts_zero_copy_grpc_protector_vtable; + *protector = &impl->base; + return TSI_OK; + } + } + + /* Cleanup if create failed. */ + alts_grpc_record_protocol_destroy(impl->record_protocol); + alts_grpc_record_protocol_destroy(impl->unrecord_protocol); + gpr_free(impl); + return TSI_INTERNAL_ERROR; +} diff --git a/src/core/tsi/fake_transport_security.cc b/src/core/tsi/fake_transport_security.cc new file mode 100644 index 00000000..ce09cdb7 --- /dev/null +++ b/src/core/tsi/fake_transport_security.cc @@ -0,0 +1,809 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/fake_transport_security.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/transport_security_grpc.h" + +/* --- Constants. ---*/ +#define TSI_FAKE_FRAME_HEADER_SIZE 4 +#define TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE 64 +#define TSI_FAKE_DEFAULT_FRAME_SIZE 16384 +#define TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 256 + +/* --- Structure definitions. ---*/ + +/* a frame is encoded like this: + | size | data | + where the size field value is the size of the size field plus the size of + the data encoded in little endian on 4 bytes. 
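+ * Example: the 5-byte payload "hello" is framed as the 9 bytes
+ * 0x09 0x00 0x00 0x00 'h' 'e' 'l' 'l' 'o', since size = 4 + 5 = 9 is stored
+ * little-endian in the first four bytes.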
*/ +struct tsi_fake_frame { + unsigned char* data; + size_t size; + size_t allocated_size; + size_t offset; + int needs_draining; +}; +typedef enum { + TSI_FAKE_CLIENT_INIT = 0, + TSI_FAKE_SERVER_INIT = 1, + TSI_FAKE_CLIENT_FINISHED = 2, + TSI_FAKE_SERVER_FINISHED = 3, + TSI_FAKE_HANDSHAKE_MESSAGE_MAX = 4 +} tsi_fake_handshake_message; + +struct tsi_fake_handshaker { + tsi_handshaker base; + int is_client; + tsi_fake_handshake_message next_message_to_send; + int needs_incoming_message; + tsi_fake_frame incoming_frame; + tsi_fake_frame outgoing_frame; + unsigned char* outgoing_bytes_buffer; + size_t outgoing_bytes_buffer_size; + tsi_result result; +}; +struct tsi_fake_frame_protector { + tsi_frame_protector base; + tsi_fake_frame protect_frame; + tsi_fake_frame unprotect_frame; + size_t max_frame_size; +}; +struct tsi_fake_zero_copy_grpc_protector { + tsi_zero_copy_grpc_protector base; + grpc_slice_buffer header_sb; + grpc_slice_buffer protected_sb; + size_t max_frame_size; + size_t parsed_frame_size; +}; +/* --- Utils. ---*/ + +static const char* tsi_fake_handshake_message_strings[] = { + "CLIENT_INIT", "SERVER_INIT", "CLIENT_FINISHED", "SERVER_FINISHED"}; + +static const char* tsi_fake_handshake_message_to_string(int msg) { + if (msg < 0 || msg >= TSI_FAKE_HANDSHAKE_MESSAGE_MAX) { + gpr_log(GPR_ERROR, "Invalid message %d", msg); + return "UNKNOWN"; + } + return tsi_fake_handshake_message_strings[msg]; +} + +static tsi_result tsi_fake_handshake_message_from_string( + const char* msg_string, tsi_fake_handshake_message* msg) { + for (int i = 0; i < TSI_FAKE_HANDSHAKE_MESSAGE_MAX; i++) { + if (strncmp(msg_string, tsi_fake_handshake_message_strings[i], + strlen(tsi_fake_handshake_message_strings[i])) == 0) { + *msg = static_cast(i); + return TSI_OK; + } + } + gpr_log(GPR_ERROR, "Invalid handshake message."); + return TSI_DATA_CORRUPTED; +} + +static uint32_t load32_little_endian(const unsigned char* buf) { + return (static_cast(buf[0]) | static_cast(buf[1] << 8) | + static_cast(buf[2] << 16) | + static_cast(buf[3] << 24)); +} + +static void store32_little_endian(uint32_t value, unsigned char* buf) { + buf[3] = static_cast((value >> 24) & 0xFF); + buf[2] = static_cast((value >> 16) & 0xFF); + buf[1] = static_cast((value >> 8) & 0xFF); + buf[0] = static_cast((value)&0xFF); +} + +static uint32_t read_frame_size(const grpc_slice_buffer* sb) { + GPR_ASSERT(sb != nullptr && sb->length >= TSI_FAKE_FRAME_HEADER_SIZE); + uint8_t frame_size_buffer[TSI_FAKE_FRAME_HEADER_SIZE]; + uint8_t* buf = frame_size_buffer; + /* Copies the first 4 bytes to a temporary buffer. */ + size_t remaining = TSI_FAKE_FRAME_HEADER_SIZE; + for (size_t i = 0; i < sb->count; i++) { + size_t slice_length = GRPC_SLICE_LENGTH(sb->slices[i]); + if (remaining <= slice_length) { + memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), remaining); + remaining = 0; + break; + } else { + memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), slice_length); + buf += slice_length; + remaining -= slice_length; + } + } + GPR_ASSERT(remaining == 0); + return load32_little_endian(frame_size_buffer); +} + +static void tsi_fake_frame_reset(tsi_fake_frame* frame, int needs_draining) { + frame->offset = 0; + frame->needs_draining = needs_draining; + if (!needs_draining) frame->size = 0; +} + +/* Checks if the frame's allocated size is at least frame->size, and reallocs + * more memory if necessary. 
*/ +static void tsi_fake_frame_ensure_size(tsi_fake_frame* frame) { + if (frame->data == nullptr) { + frame->allocated_size = frame->size; + frame->data = + static_cast(gpr_malloc(frame->allocated_size)); + } else if (frame->size > frame->allocated_size) { + unsigned char* new_data = + static_cast(gpr_realloc(frame->data, frame->size)); + frame->data = new_data; + frame->allocated_size = frame->size; + } +} + +/* Decodes the serialized fake frame contained in incoming_bytes, and fills + * frame with the contents of the decoded frame. + * This method should not be called if frame->needs_framing is not 0. */ +static tsi_result tsi_fake_frame_decode(const unsigned char* incoming_bytes, + size_t* incoming_bytes_size, + tsi_fake_frame* frame) { + size_t available_size = *incoming_bytes_size; + size_t to_read_size = 0; + const unsigned char* bytes_cursor = incoming_bytes; + + if (frame->needs_draining) return TSI_INTERNAL_ERROR; + if (frame->data == nullptr) { + frame->allocated_size = TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE; + frame->data = + static_cast(gpr_malloc(frame->allocated_size)); + } + + if (frame->offset < TSI_FAKE_FRAME_HEADER_SIZE) { + to_read_size = TSI_FAKE_FRAME_HEADER_SIZE - frame->offset; + if (to_read_size > available_size) { + /* Just fill what we can and exit. */ + memcpy(frame->data + frame->offset, bytes_cursor, available_size); + bytes_cursor += available_size; + frame->offset += available_size; + *incoming_bytes_size = static_cast(bytes_cursor - incoming_bytes); + return TSI_INCOMPLETE_DATA; + } + memcpy(frame->data + frame->offset, bytes_cursor, to_read_size); + bytes_cursor += to_read_size; + frame->offset += to_read_size; + available_size -= to_read_size; + frame->size = load32_little_endian(frame->data); + tsi_fake_frame_ensure_size(frame); + } + + to_read_size = frame->size - frame->offset; + if (to_read_size > available_size) { + memcpy(frame->data + frame->offset, bytes_cursor, available_size); + frame->offset += available_size; + bytes_cursor += available_size; + *incoming_bytes_size = static_cast(bytes_cursor - incoming_bytes); + return TSI_INCOMPLETE_DATA; + } + memcpy(frame->data + frame->offset, bytes_cursor, to_read_size); + bytes_cursor += to_read_size; + *incoming_bytes_size = static_cast(bytes_cursor - incoming_bytes); + tsi_fake_frame_reset(frame, 1 /* needs_draining */); + return TSI_OK; +} + +/* Encodes a fake frame into its wire format and places the result in + * outgoing_bytes. outgoing_bytes_size indicates the size of the encoded frame. + * This method should not be called if frame->needs_framing is 0. */ +static tsi_result tsi_fake_frame_encode(unsigned char* outgoing_bytes, + size_t* outgoing_bytes_size, + tsi_fake_frame* frame) { + size_t to_write_size = frame->size - frame->offset; + if (!frame->needs_draining) return TSI_INTERNAL_ERROR; + if (*outgoing_bytes_size < to_write_size) { + memcpy(outgoing_bytes, frame->data + frame->offset, *outgoing_bytes_size); + frame->offset += *outgoing_bytes_size; + return TSI_INCOMPLETE_DATA; + } + memcpy(outgoing_bytes, frame->data + frame->offset, to_write_size); + *outgoing_bytes_size = to_write_size; + tsi_fake_frame_reset(frame, 0 /* needs_draining */); + return TSI_OK; +} + +/* Sets the payload of a fake frame to contain the given data blob, where + * data_size indicates the size of data. 
*/ +static tsi_result tsi_fake_frame_set_data(unsigned char* data, size_t data_size, + tsi_fake_frame* frame) { + frame->offset = 0; + frame->size = data_size + TSI_FAKE_FRAME_HEADER_SIZE; + tsi_fake_frame_ensure_size(frame); + store32_little_endian(static_cast(frame->size), frame->data); + memcpy(frame->data + TSI_FAKE_FRAME_HEADER_SIZE, data, data_size); + tsi_fake_frame_reset(frame, 1 /* needs draining */); + return TSI_OK; +} + +/* Destroys the contents of a fake frame. */ +static void tsi_fake_frame_destruct(tsi_fake_frame* frame) { + if (frame->data != nullptr) gpr_free(frame->data); +} + +/* --- tsi_frame_protector methods implementation. ---*/ + +static tsi_result fake_protector_protect(tsi_frame_protector* self, + const unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size, + unsigned char* protected_output_frames, + size_t* protected_output_frames_size) { + tsi_result result = TSI_OK; + tsi_fake_frame_protector* impl = + reinterpret_cast(self); + unsigned char frame_header[TSI_FAKE_FRAME_HEADER_SIZE]; + tsi_fake_frame* frame = &impl->protect_frame; + size_t saved_output_size = *protected_output_frames_size; + size_t drained_size = 0; + size_t* num_bytes_written = protected_output_frames_size; + *num_bytes_written = 0; + + /* Try to drain first. */ + if (frame->needs_draining) { + drained_size = saved_output_size - *num_bytes_written; + result = + tsi_fake_frame_encode(protected_output_frames, &drained_size, frame); + *num_bytes_written += drained_size; + protected_output_frames += drained_size; + if (result != TSI_OK) { + if (result == TSI_INCOMPLETE_DATA) { + *unprotected_bytes_size = 0; + result = TSI_OK; + } + return result; + } + } + + /* Now process the unprotected_bytes. */ + if (frame->needs_draining) return TSI_INTERNAL_ERROR; + if (frame->size == 0) { + /* New frame, create a header. */ + size_t written_in_frame_size = 0; + store32_little_endian(static_cast(impl->max_frame_size), + frame_header); + written_in_frame_size = TSI_FAKE_FRAME_HEADER_SIZE; + result = tsi_fake_frame_decode(frame_header, &written_in_frame_size, frame); + if (result != TSI_INCOMPLETE_DATA) { + gpr_log(GPR_ERROR, "tsi_fake_frame_decode returned %s", + tsi_result_to_string(result)); + return result; + } + } + result = + tsi_fake_frame_decode(unprotected_bytes, unprotected_bytes_size, frame); + if (result != TSI_OK) { + if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; + return result; + } + + /* Try to drain again. */ + if (!frame->needs_draining) return TSI_INTERNAL_ERROR; + if (frame->offset != 0) return TSI_INTERNAL_ERROR; + drained_size = saved_output_size - *num_bytes_written; + result = tsi_fake_frame_encode(protected_output_frames, &drained_size, frame); + *num_bytes_written += drained_size; + if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; + return result; +} + +static tsi_result fake_protector_protect_flush( + tsi_frame_protector* self, unsigned char* protected_output_frames, + size_t* protected_output_frames_size, size_t* still_pending_size) { + tsi_result result = TSI_OK; + tsi_fake_frame_protector* impl = + reinterpret_cast(self); + tsi_fake_frame* frame = &impl->protect_frame; + if (!frame->needs_draining) { + /* Create a short frame. */ + frame->size = frame->offset; + frame->offset = 0; + frame->needs_draining = 1; + store32_little_endian(static_cast(frame->size), + frame->data); /* Overwrite header. 
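+ * The frame was sized for max_frame_size when it was started, so on flush the
+ * length field is rewritten to the number of bytes actually buffered and the
+ * now-shorter frame is drained as-is.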
*/ + } + result = tsi_fake_frame_encode(protected_output_frames, + protected_output_frames_size, frame); + if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; + *still_pending_size = frame->size - frame->offset; + return result; +} + +static tsi_result fake_protector_unprotect( + tsi_frame_protector* self, const unsigned char* protected_frames_bytes, + size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size) { + tsi_result result = TSI_OK; + tsi_fake_frame_protector* impl = + reinterpret_cast(self); + tsi_fake_frame* frame = &impl->unprotect_frame; + size_t saved_output_size = *unprotected_bytes_size; + size_t drained_size = 0; + size_t* num_bytes_written = unprotected_bytes_size; + *num_bytes_written = 0; + + /* Try to drain first. */ + if (frame->needs_draining) { + /* Go past the header if needed. */ + if (frame->offset == 0) frame->offset = TSI_FAKE_FRAME_HEADER_SIZE; + drained_size = saved_output_size - *num_bytes_written; + result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame); + unprotected_bytes += drained_size; + *num_bytes_written += drained_size; + if (result != TSI_OK) { + if (result == TSI_INCOMPLETE_DATA) { + *protected_frames_bytes_size = 0; + result = TSI_OK; + } + return result; + } + } + + /* Now process the protected_bytes. */ + if (frame->needs_draining) return TSI_INTERNAL_ERROR; + result = tsi_fake_frame_decode(protected_frames_bytes, + protected_frames_bytes_size, frame); + if (result != TSI_OK) { + if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; + return result; + } + + /* Try to drain again. */ + if (!frame->needs_draining) return TSI_INTERNAL_ERROR; + if (frame->offset != 0) return TSI_INTERNAL_ERROR; + frame->offset = TSI_FAKE_FRAME_HEADER_SIZE; /* Go past the header. */ + drained_size = saved_output_size - *num_bytes_written; + result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame); + *num_bytes_written += drained_size; + if (result == TSI_INCOMPLETE_DATA) result = TSI_OK; + return result; +} + +static void fake_protector_destroy(tsi_frame_protector* self) { + tsi_fake_frame_protector* impl = + reinterpret_cast(self); + tsi_fake_frame_destruct(&impl->protect_frame); + tsi_fake_frame_destruct(&impl->unprotect_frame); + gpr_free(self); +} + +static const tsi_frame_protector_vtable frame_protector_vtable = { + fake_protector_protect, + fake_protector_protect_flush, + fake_protector_unprotect, + fake_protector_destroy, +}; + +/* --- tsi_zero_copy_grpc_protector methods implementation. ---*/ + +static tsi_result fake_zero_copy_grpc_protector_protect( + tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices, + grpc_slice_buffer* protected_slices) { + if (self == nullptr || unprotected_slices == nullptr || + protected_slices == nullptr) { + return TSI_INVALID_ARGUMENT; + } + tsi_fake_zero_copy_grpc_protector* impl = + reinterpret_cast(self); + /* Protects each frame. 
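+ * Illustrative arithmetic (assuming the default 16384-byte max frame): a
+ * 20000-byte input becomes one full 16384-byte frame (4-byte header plus
+ * 16380 bytes of data) followed by a 3624-byte frame carrying the remaining
+ * 3620 bytes.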
*/ + while (unprotected_slices->length > 0) { + size_t frame_length = + std::min(impl->max_frame_size, + unprotected_slices->length + TSI_FAKE_FRAME_HEADER_SIZE); + grpc_slice slice = GRPC_SLICE_MALLOC(TSI_FAKE_FRAME_HEADER_SIZE); + store32_little_endian(static_cast(frame_length), + GRPC_SLICE_START_PTR(slice)); + grpc_slice_buffer_add(protected_slices, slice); + size_t data_length = frame_length - TSI_FAKE_FRAME_HEADER_SIZE; + grpc_slice_buffer_move_first(unprotected_slices, data_length, + protected_slices); + } + return TSI_OK; +} + +static tsi_result fake_zero_copy_grpc_protector_unprotect( + tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices, + grpc_slice_buffer* unprotected_slices) { + if (self == nullptr || unprotected_slices == nullptr || + protected_slices == nullptr) { + return TSI_INVALID_ARGUMENT; + } + tsi_fake_zero_copy_grpc_protector* impl = + reinterpret_cast(self); + grpc_slice_buffer_move_into(protected_slices, &impl->protected_sb); + /* Unprotect each frame, if we get a full frame. */ + while (impl->protected_sb.length >= TSI_FAKE_FRAME_HEADER_SIZE) { + if (impl->parsed_frame_size == 0) { + impl->parsed_frame_size = read_frame_size(&impl->protected_sb); + if (impl->parsed_frame_size <= 4) { + gpr_log(GPR_ERROR, "Invalid frame size."); + return TSI_DATA_CORRUPTED; + } + } + /* If we do not have a full frame, return with OK status. */ + if (impl->protected_sb.length < impl->parsed_frame_size) break; + /* Strips header bytes. */ + grpc_slice_buffer_move_first(&impl->protected_sb, + TSI_FAKE_FRAME_HEADER_SIZE, &impl->header_sb); + /* Moves data to unprotected slices. */ + grpc_slice_buffer_move_first( + &impl->protected_sb, + impl->parsed_frame_size - TSI_FAKE_FRAME_HEADER_SIZE, + unprotected_slices); + impl->parsed_frame_size = 0; + grpc_slice_buffer_reset_and_unref_internal(&impl->header_sb); + } + return TSI_OK; +} + +static void fake_zero_copy_grpc_protector_destroy( + tsi_zero_copy_grpc_protector* self) { + if (self == nullptr) return; + tsi_fake_zero_copy_grpc_protector* impl = + reinterpret_cast(self); + grpc_slice_buffer_destroy_internal(&impl->header_sb); + grpc_slice_buffer_destroy_internal(&impl->protected_sb); + gpr_free(impl); +} + +static tsi_result fake_zero_copy_grpc_protector_max_frame_size( + tsi_zero_copy_grpc_protector* self, size_t* max_frame_size) { + if (self == nullptr || max_frame_size == nullptr) return TSI_INVALID_ARGUMENT; + tsi_fake_zero_copy_grpc_protector* impl = + reinterpret_cast(self); + *max_frame_size = impl->max_frame_size; + return TSI_OK; +} + +static const tsi_zero_copy_grpc_protector_vtable + zero_copy_grpc_protector_vtable = { + fake_zero_copy_grpc_protector_protect, + fake_zero_copy_grpc_protector_unprotect, + fake_zero_copy_grpc_protector_destroy, + fake_zero_copy_grpc_protector_max_frame_size, +}; + +/* --- tsi_handshaker_result methods implementation. ---*/ + +struct fake_handshaker_result { + tsi_handshaker_result base; + unsigned char* unused_bytes; + size_t unused_bytes_size; +}; + +static tsi_result fake_handshaker_result_extract_peer( + const tsi_handshaker_result* /*self*/, tsi_peer* peer) { + /* Construct a tsi_peer with 1 property: certificate type, security_level. 
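+ * Concretely, two string properties are populated below: the certificate type
+ * (TSI_FAKE_CERTIFICATE_TYPE) and the security level (TSI_SECURITY_NONE).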
*/ + tsi_result result = tsi_construct_peer(2, peer); + if (result != TSI_OK) return result; + result = tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_FAKE_CERTIFICATE_TYPE, + &peer->properties[0]); + if (result != TSI_OK) tsi_peer_destruct(peer); + result = tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_SECURITY_NONE), &peer->properties[1]); + if (result != TSI_OK) tsi_peer_destruct(peer); + return result; +} + +static tsi_result fake_handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY; + return TSI_OK; +} + +static tsi_result fake_handshaker_result_create_zero_copy_grpc_protector( + const tsi_handshaker_result* /*self*/, + size_t* max_output_protected_frame_size, + tsi_zero_copy_grpc_protector** protector) { + *protector = + tsi_create_fake_zero_copy_grpc_protector(max_output_protected_frame_size); + return TSI_OK; +} + +static tsi_result fake_handshaker_result_create_frame_protector( + const tsi_handshaker_result* /*self*/, + size_t* max_output_protected_frame_size, tsi_frame_protector** protector) { + *protector = tsi_create_fake_frame_protector(max_output_protected_frame_size); + return TSI_OK; +} + +static tsi_result fake_handshaker_result_get_unused_bytes( + const tsi_handshaker_result* self, const unsigned char** bytes, + size_t* bytes_size) { + fake_handshaker_result* result = reinterpret_cast( + const_cast(self)); + *bytes_size = result->unused_bytes_size; + *bytes = result->unused_bytes; + return TSI_OK; +} + +static void fake_handshaker_result_destroy(tsi_handshaker_result* self) { + fake_handshaker_result* result = + reinterpret_cast(self); + gpr_free(result->unused_bytes); + gpr_free(self); +} + +static const tsi_handshaker_result_vtable handshaker_result_vtable = { + fake_handshaker_result_extract_peer, + fake_handshaker_result_get_frame_protector_type, + fake_handshaker_result_create_zero_copy_grpc_protector, + fake_handshaker_result_create_frame_protector, + fake_handshaker_result_get_unused_bytes, + fake_handshaker_result_destroy, +}; + +static tsi_result fake_handshaker_result_create( + const unsigned char* unused_bytes, size_t unused_bytes_size, + tsi_handshaker_result** handshaker_result) { + if ((unused_bytes_size > 0 && unused_bytes == nullptr) || + handshaker_result == nullptr) { + return TSI_INVALID_ARGUMENT; + } + fake_handshaker_result* result = grpc_core::Zalloc(); + result->base.vtable = &handshaker_result_vtable; + if (unused_bytes_size > 0) { + result->unused_bytes = + static_cast(gpr_malloc(unused_bytes_size)); + memcpy(result->unused_bytes, unused_bytes, unused_bytes_size); + } + result->unused_bytes_size = unused_bytes_size; + *handshaker_result = &result->base; + return TSI_OK; +} + +/* --- tsi_handshaker methods implementation. 
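+ * The fake handshake exchanges four framed messages in a fixed order:
+ * CLIENT_INIT, SERVER_INIT, CLIENT_FINISHED, SERVER_FINISHED. Each side sends
+ * every other message, which is why next_message_to_send advances by two once
+ * a frame has been prepared.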
---*/ + +static tsi_result fake_handshaker_get_bytes_to_send_to_peer( + tsi_handshaker* self, unsigned char* bytes, size_t* bytes_size) { + tsi_fake_handshaker* impl = reinterpret_cast(self); + tsi_result result = TSI_OK; + if (impl->needs_incoming_message || impl->result == TSI_OK) { + *bytes_size = 0; + return TSI_OK; + } + if (!impl->outgoing_frame.needs_draining) { + tsi_fake_handshake_message next_message_to_send = + // NOLINTNEXTLINE(bugprone-misplaced-widening-cast) + static_cast(impl->next_message_to_send + 2); + const char* msg_string = + tsi_fake_handshake_message_to_string(impl->next_message_to_send); + result = tsi_fake_frame_set_data( + reinterpret_cast(const_cast(msg_string)), + strlen(msg_string), &impl->outgoing_frame); + if (result != TSI_OK) return result; + if (next_message_to_send > TSI_FAKE_HANDSHAKE_MESSAGE_MAX) { + next_message_to_send = TSI_FAKE_HANDSHAKE_MESSAGE_MAX; + } + if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) { + gpr_log(GPR_INFO, "%s prepared %s.", + impl->is_client ? "Client" : "Server", + tsi_fake_handshake_message_to_string(impl->next_message_to_send)); + } + impl->next_message_to_send = next_message_to_send; + } + result = tsi_fake_frame_encode(bytes, bytes_size, &impl->outgoing_frame); + if (result != TSI_OK) return result; + if (!impl->is_client && + impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) { + /* We're done. */ + if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) { + gpr_log(GPR_INFO, "Server is done."); + } + impl->result = TSI_OK; + } else { + impl->needs_incoming_message = 1; + } + return TSI_OK; +} + +static tsi_result fake_handshaker_process_bytes_from_peer( + tsi_handshaker* self, const unsigned char* bytes, size_t* bytes_size) { + tsi_result result = TSI_OK; + tsi_fake_handshaker* impl = reinterpret_cast(self); + tsi_fake_handshake_message expected_msg = + static_cast(impl->next_message_to_send - 1); + tsi_fake_handshake_message received_msg; + + if (!impl->needs_incoming_message || impl->result == TSI_OK) { + *bytes_size = 0; + return TSI_OK; + } + result = tsi_fake_frame_decode(bytes, bytes_size, &impl->incoming_frame); + if (result != TSI_OK) return result; + + /* We now have a complete frame. */ + result = tsi_fake_handshake_message_from_string( + reinterpret_cast(impl->incoming_frame.data) + + TSI_FAKE_FRAME_HEADER_SIZE, + &received_msg); + if (result != TSI_OK) { + impl->result = result; + return result; + } + if (received_msg != expected_msg) { + gpr_log(GPR_ERROR, "Invalid received message (%s instead of %s)", + tsi_fake_handshake_message_to_string(received_msg), + tsi_fake_handshake_message_to_string(expected_msg)); + } + if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) { + gpr_log(GPR_INFO, "%s received %s.", impl->is_client ? "Client" : "Server", + tsi_fake_handshake_message_to_string(received_msg)); + } + tsi_fake_frame_reset(&impl->incoming_frame, 0 /* needs_draining */); + impl->needs_incoming_message = 0; + if (impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) { + /* We're done. */ + if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) { + gpr_log(GPR_INFO, "%s is done.", impl->is_client ? 
"Client" : "Server"); + } + impl->result = TSI_OK; + } + return TSI_OK; +} + +static tsi_result fake_handshaker_get_result(tsi_handshaker* self) { + tsi_fake_handshaker* impl = reinterpret_cast(self); + return impl->result; +} + +static void fake_handshaker_destroy(tsi_handshaker* self) { + tsi_fake_handshaker* impl = reinterpret_cast(self); + tsi_fake_frame_destruct(&impl->incoming_frame); + tsi_fake_frame_destruct(&impl->outgoing_frame); + gpr_free(impl->outgoing_bytes_buffer); + gpr_free(self); +} + +static tsi_result fake_handshaker_next( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** bytes_to_send, + size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, + tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { + /* Sanity check the arguments. */ + if ((received_bytes_size > 0 && received_bytes == nullptr) || + bytes_to_send == nullptr || bytes_to_send_size == nullptr || + handshaker_result == nullptr) { + return TSI_INVALID_ARGUMENT; + } + tsi_fake_handshaker* handshaker = + reinterpret_cast(self); + tsi_result result = TSI_OK; + + /* Decode and process a handshake frame from the peer. */ + size_t consumed_bytes_size = received_bytes_size; + if (received_bytes_size > 0) { + result = fake_handshaker_process_bytes_from_peer(self, received_bytes, + &consumed_bytes_size); + if (result != TSI_OK) return result; + } + + /* Create a handshake message to send to the peer and encode it as a fake + * frame. */ + size_t offset = 0; + do { + size_t sent_bytes_size = handshaker->outgoing_bytes_buffer_size - offset; + result = fake_handshaker_get_bytes_to_send_to_peer( + self, handshaker->outgoing_bytes_buffer + offset, &sent_bytes_size); + offset += sent_bytes_size; + if (result == TSI_INCOMPLETE_DATA) { + handshaker->outgoing_bytes_buffer_size *= 2; + handshaker->outgoing_bytes_buffer = static_cast( + gpr_realloc(handshaker->outgoing_bytes_buffer, + handshaker->outgoing_bytes_buffer_size)); + } + } while (result == TSI_INCOMPLETE_DATA); + if (result != TSI_OK) return result; + *bytes_to_send = handshaker->outgoing_bytes_buffer; + *bytes_to_send_size = offset; + + /* Check if the handshake was completed. */ + if (fake_handshaker_get_result(self) == TSI_HANDSHAKE_IN_PROGRESS) { + *handshaker_result = nullptr; + } else { + /* Calculate the unused bytes. */ + const unsigned char* unused_bytes = nullptr; + size_t unused_bytes_size = received_bytes_size - consumed_bytes_size; + if (unused_bytes_size > 0) { + unused_bytes = received_bytes + consumed_bytes_size; + } + + /* Create a handshaker_result containing the unused bytes. */ + result = fake_handshaker_result_create(unused_bytes, unused_bytes_size, + handshaker_result); + if (result == TSI_OK) { + /* Indicate that the handshake has completed and that a handshaker_result + * has been created. 
*/ + self->handshaker_result_created = true; + } + } + return result; +} + +static const tsi_handshaker_vtable handshaker_vtable = { + nullptr, /* get_bytes_to_send_to_peer -- deprecated */ + nullptr, /* process_bytes_from_peer -- deprecated */ + nullptr, /* get_result -- deprecated */ + nullptr, /* extract_peer -- deprecated */ + nullptr, /* create_frame_protector -- deprecated */ + fake_handshaker_destroy, + fake_handshaker_next, + nullptr, /* shutdown */ +}; + +tsi_handshaker* tsi_create_fake_handshaker(int is_client) { + tsi_fake_handshaker* impl = grpc_core::Zalloc(); + impl->base.vtable = &handshaker_vtable; + impl->is_client = is_client; + impl->result = TSI_HANDSHAKE_IN_PROGRESS; + impl->outgoing_bytes_buffer_size = + TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE; + impl->outgoing_bytes_buffer = + static_cast(gpr_malloc(impl->outgoing_bytes_buffer_size)); + if (is_client) { + impl->needs_incoming_message = 0; + impl->next_message_to_send = TSI_FAKE_CLIENT_INIT; + } else { + impl->needs_incoming_message = 1; + impl->next_message_to_send = TSI_FAKE_SERVER_INIT; + } + return &impl->base; +} + +tsi_frame_protector* tsi_create_fake_frame_protector( + size_t* max_protected_frame_size) { + tsi_fake_frame_protector* impl = + grpc_core::Zalloc(); + impl->max_frame_size = (max_protected_frame_size == nullptr) + ? TSI_FAKE_DEFAULT_FRAME_SIZE + : *max_protected_frame_size; + impl->base.vtable = &frame_protector_vtable; + return &impl->base; +} + +tsi_zero_copy_grpc_protector* tsi_create_fake_zero_copy_grpc_protector( + size_t* max_protected_frame_size) { + tsi_fake_zero_copy_grpc_protector* impl = + static_cast( + gpr_zalloc(sizeof(*impl))); + grpc_slice_buffer_init(&impl->header_sb); + grpc_slice_buffer_init(&impl->protected_sb); + impl->max_frame_size = (max_protected_frame_size == nullptr) + ? TSI_FAKE_DEFAULT_FRAME_SIZE + : *max_protected_frame_size; + impl->parsed_frame_size = 0; + impl->base.vtable = &zero_copy_grpc_protector_vtable; + return &impl->base; +} diff --git a/src/core/tsi/local_transport_security.cc b/src/core/tsi/local_transport_security.cc new file mode 100644 index 00000000..e12ba016 --- /dev/null +++ b/src/core/tsi/local_transport_security.cc @@ -0,0 +1,178 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/local_transport_security.h" + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/tsi/transport_security_grpc.h" + +namespace { + +/* Main struct for local TSI zero-copy frame protector. */ +typedef struct local_zero_copy_grpc_protector { + tsi_zero_copy_grpc_protector base; +} local_zero_copy_grpc_protector; + +/* Main struct for local TSI handshaker result. */ +typedef struct local_tsi_handshaker_result { + tsi_handshaker_result base; + bool is_client; + unsigned char* unused_bytes; + size_t unused_bytes_size; +} local_tsi_handshaker_result; + +/* Main struct for local TSI handshaker. 
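+ * The local handshaker never exchanges bytes with a peer: handshaker_next()
+ * below reports zero bytes to send and produces a completed result
+ * immediately.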
*/ +typedef struct local_tsi_handshaker { + tsi_handshaker base; + bool is_client; +} local_tsi_handshaker; + +/* --- tsi_handshaker_result methods implementation. --- */ + +static tsi_result handshaker_result_extract_peer( + const tsi_handshaker_result* /*self*/, tsi_peer* /*peer*/) { + return TSI_OK; +} + +static tsi_result handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NONE; + return TSI_OK; +} + +static tsi_result handshaker_result_get_unused_bytes( + const tsi_handshaker_result* self, const unsigned char** bytes, + size_t* bytes_size) { + if (self == nullptr || bytes == nullptr || bytes_size == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to get_unused_bytes()"); + return TSI_INVALID_ARGUMENT; + } + auto* result = reinterpret_cast( + const_cast(self)); + *bytes_size = result->unused_bytes_size; + *bytes = result->unused_bytes; + return TSI_OK; +} + +static void handshaker_result_destroy(tsi_handshaker_result* self) { + if (self == nullptr) { + return; + } + local_tsi_handshaker_result* result = + reinterpret_cast( + const_cast(self)); + gpr_free(result->unused_bytes); + gpr_free(result); +} + +static const tsi_handshaker_result_vtable result_vtable = { + handshaker_result_extract_peer, + handshaker_result_get_frame_protector_type, + nullptr, /* handshaker_result_create_zero_copy_grpc_protector */ + nullptr, /* handshaker_result_create_frame_protector */ + handshaker_result_get_unused_bytes, + handshaker_result_destroy}; + +static tsi_result create_handshaker_result(bool is_client, + const unsigned char* received_bytes, + size_t received_bytes_size, + tsi_handshaker_result** self) { + if (self == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()"); + return TSI_INVALID_ARGUMENT; + } + local_tsi_handshaker_result* result = + grpc_core::Zalloc(); + result->is_client = is_client; + if (received_bytes_size > 0) { + result->unused_bytes = + static_cast(gpr_malloc(received_bytes_size)); + memcpy(result->unused_bytes, received_bytes, received_bytes_size); + } + result->unused_bytes_size = received_bytes_size; + result->base.vtable = &result_vtable; + *self = &result->base; + return TSI_OK; +} + +/* --- tsi_handshaker methods implementation. --- */ + +static tsi_result handshaker_next( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** /*bytes_to_send*/, + size_t* bytes_to_send_size, tsi_handshaker_result** result, + tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { + if (self == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()"); + return TSI_INVALID_ARGUMENT; + } + /* Note that there is no interaction between TSI peers, and all operations are + * local. 
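+ * Any bytes already received from the transport are preserved verbatim as the
+ * result's unused bytes, so the caller can retrieve them later through
+ * get_unused_bytes().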
+ */ + local_tsi_handshaker* handshaker = + reinterpret_cast(self); + *bytes_to_send_size = 0; + create_handshaker_result(handshaker->is_client, received_bytes, + received_bytes_size, result); + return TSI_OK; +} + +static void handshaker_destroy(tsi_handshaker* self) { + if (self == nullptr) { + return; + } + local_tsi_handshaker* handshaker = + reinterpret_cast(self); + gpr_free(handshaker); +} + +static const tsi_handshaker_vtable handshaker_vtable = { + nullptr, /* get_bytes_to_send_to_peer -- deprecated */ + nullptr, /* process_bytes_from_peer -- deprecated */ + nullptr, /* get_result -- deprecated */ + nullptr, /* extract_peer -- deprecated */ + nullptr, /* create_frame_protector -- deprecated */ + handshaker_destroy, + handshaker_next, + nullptr, /* shutdown */ +}; + +} // namespace + +tsi_result tsi_local_handshaker_create(bool is_client, tsi_handshaker** self) { + if (self == nullptr) { + gpr_log(GPR_ERROR, "Invalid arguments to local_tsi_handshaker_create()"); + return TSI_INVALID_ARGUMENT; + } + local_tsi_handshaker* handshaker = grpc_core::Zalloc(); + handshaker->is_client = is_client; + handshaker->base.vtable = &handshaker_vtable; + *self = &handshaker->base; + return TSI_OK; +} diff --git a/src/core/tsi/ssl/session_cache/ssl_session_boringssl.cc b/src/core/tsi/ssl/session_cache/ssl_session_boringssl.cc new file mode 100644 index 00000000..dcfff00c --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session_boringssl.cc @@ -0,0 +1,57 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/ssl/session_cache/ssl_session.h" + +#ifdef OPENSSL_IS_BORINGSSL + +// BoringSSL allows SSL_SESSION to outlive SSL and SSL_CTX objects which are +// re-created by gRPC on every certificate rotation or subchannel creation. +// BoringSSL guarantees that SSL_SESSION is immutable so it's safe to share +// the same original session object between different threads and connections. + +namespace tsi { +namespace { + +class BoringSslCachedSession : public SslCachedSession { + public: + explicit BoringSslCachedSession(SslSessionPtr session) + : session_(std::move(session)) {} + + SslSessionPtr CopySession() const override { + // SslSessionPtr will dereference on destruction. + SSL_SESSION_up_ref(session_.get()); + return SslSessionPtr(session_.get()); + } + + private: + SslSessionPtr session_; +}; + +} // namespace + +std::unique_ptr SslCachedSession::Create( + SslSessionPtr session) { + return absl::make_unique(std::move(session)); +} + +} // namespace tsi + +#endif /* OPENSSL_IS_BORINGSSL */ diff --git a/src/core/tsi/ssl/session_cache/ssl_session_cache.cc b/src/core/tsi/ssl/session_cache/ssl_session_cache.cc new file mode 100644 index 00000000..5657fbc0 --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session_cache.cc @@ -0,0 +1,179 @@ +/* + * + * Copyright 2018 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/ssl/session_cache/ssl_session_cache.h" + +#include +#include + +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/ssl/session_cache/ssl_session.h" + +namespace tsi { + +/// Node for single cached session. +class SslSessionLRUCache::Node { + public: + Node(const std::string& key, SslSessionPtr session) : key_(key) { + SetSession(std::move(session)); + } + + // Not copyable nor movable. + Node(const Node&) = delete; + Node& operator=(const Node&) = delete; + + const std::string& key() const { return key_; } + + /// Returns a copy of the node's cache session. + SslSessionPtr CopySession() const { return session_->CopySession(); } + + /// Set the \a session (which is moved) for the node. + void SetSession(SslSessionPtr session) { + session_ = SslCachedSession::Create(std::move(session)); + } + + private: + friend class SslSessionLRUCache; + + std::string key_; + std::unique_ptr session_; + + Node* next_ = nullptr; + Node* prev_ = nullptr; +}; + +SslSessionLRUCache::SslSessionLRUCache(size_t capacity) : capacity_(capacity) { + GPR_ASSERT(capacity > 0); +} + +SslSessionLRUCache::~SslSessionLRUCache() { + Node* node = use_order_list_head_; + while (node) { + Node* next = node->next_; + delete node; + node = next; + } +} + +size_t SslSessionLRUCache::Size() { + grpc_core::MutexLock lock(&lock_); + return use_order_list_size_; +} + +SslSessionLRUCache::Node* SslSessionLRUCache::FindLocked( + const std::string& key) { + auto it = entry_by_key_.find(key); + if (it == entry_by_key_.end()) { + return nullptr; + } + Node* node = it->second; + // Move to the beginning. + Remove(node); + PushFront(node); + AssertInvariants(); + return node; +} + +void SslSessionLRUCache::Put(const char* key, SslSessionPtr session) { + grpc_core::MutexLock lock(&lock_); + Node* node = FindLocked(key); + if (node != nullptr) { + node->SetSession(std::move(session)); + return; + } + node = new Node(key, std::move(session)); + PushFront(node); + entry_by_key_.emplace(key, node); + AssertInvariants(); + if (use_order_list_size_ > capacity_) { + GPR_ASSERT(use_order_list_tail_); + node = use_order_list_tail_; + Remove(node); + // Order matters, key is destroyed after deleting node. + entry_by_key_.erase(node->key()); + delete node; + AssertInvariants(); + } +} + +SslSessionPtr SslSessionLRUCache::Get(const char* key) { + grpc_core::MutexLock lock(&lock_); + // Key is only used for lookups. 
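+  // A successful lookup also moves the node to the front of the LRU list (see FindLocked), so Get() refreshes the entry's recency.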
+ Node* node = FindLocked(key); + if (node == nullptr) { + return nullptr; + } + return node->CopySession(); +} + +void SslSessionLRUCache::Remove(SslSessionLRUCache::Node* node) { + if (node->prev_ == nullptr) { + use_order_list_head_ = node->next_; + } else { + node->prev_->next_ = node->next_; + } + if (node->next_ == nullptr) { + use_order_list_tail_ = node->prev_; + } else { + node->next_->prev_ = node->prev_; + } + GPR_ASSERT(use_order_list_size_ >= 1); + use_order_list_size_--; +} + +void SslSessionLRUCache::PushFront(SslSessionLRUCache::Node* node) { + if (use_order_list_head_ == nullptr) { + use_order_list_head_ = node; + use_order_list_tail_ = node; + node->next_ = nullptr; + node->prev_ = nullptr; + } else { + node->next_ = use_order_list_head_; + node->next_->prev_ = node; + use_order_list_head_ = node; + node->prev_ = nullptr; + } + use_order_list_size_++; +} + +#ifndef NDEBUG +void SslSessionLRUCache::AssertInvariants() { + size_t size = 0; + Node* prev = nullptr; + Node* current = use_order_list_head_; + while (current != nullptr) { + size++; + GPR_ASSERT(current->prev_ == prev); + auto it = entry_by_key_.find(current->key()); + GPR_ASSERT(it != entry_by_key_.end()); + GPR_ASSERT(it->second == current); + prev = current; + current = current->next_; + } + GPR_ASSERT(prev == use_order_list_tail_); + GPR_ASSERT(size == use_order_list_size_); + GPR_ASSERT(entry_by_key_.size() == use_order_list_size_); +} +#else +void SslSessionLRUCache::AssertInvariants() {} +#endif + +} // namespace tsi diff --git a/src/core/tsi/ssl/session_cache/ssl_session_openssl.cc b/src/core/tsi/ssl/session_cache/ssl_session_openssl.cc new file mode 100644 index 00000000..60675d2d --- /dev/null +++ b/src/core/tsi/ssl/session_cache/ssl_session_openssl.cc @@ -0,0 +1,75 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/tsi/ssl/session_cache/ssl_session.h" + +#ifndef OPENSSL_IS_BORINGSSL + +// OpenSSL invalidates SSL_SESSION on SSL destruction making it pointless +// to cache sessions. The workaround is to serialize (relatively expensive) +// session into binary blob and re-create it from blob on every handshake. +// Note that it's safe to keep serialized session outside of SSL lifetime +// as openssl performs all necessary validation while attempting to use a +// session and creates a new one if something is wrong (e.g. server changed +// set of allowed codecs). 
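The comment above describes the serialize/re-create workaround; the sketch below shows that round trip with plain OpenSSL calls. It is an illustrative aside and not part of this diff: the helper name and the use of std::vector are ours, and error handling is reduced to a null return.

#include <openssl/ssl.h>
#include <vector>

// Encode a session with i2d_SSL_SESSION and immediately rebuild it with
// d2i_SSL_SESSION, mirroring what OpenSslCachedSession::CopySession() relies
// on. Hypothetical helper, for illustration only.
static SSL_SESSION* RoundTripSession(SSL_SESSION* session) {
  int size = i2d_SSL_SESSION(session, nullptr);  // query encoded length
  if (size <= 0) return nullptr;
  std::vector<unsigned char> buf(static_cast<size_t>(size));
  unsigned char* out = buf.data();
  i2d_SSL_SESSION(session, &out);                // encode into buf
  const unsigned char* in = buf.data();
  // The caller owns the returned session and must SSL_SESSION_free() it.
  return d2i_SSL_SESSION(nullptr, &in, size);
}

OpenSslCachedSession below does essentially the same thing, except that it keeps the serialized blob in a grpc_slice and deserializes it again on every CopySession() call.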
+ +namespace tsi { +namespace { + +class OpenSslCachedSession : public SslCachedSession { + public: + OpenSslCachedSession(SslSessionPtr session) { + int size = i2d_SSL_SESSION(session.get(), nullptr); + GPR_ASSERT(size > 0); + grpc_slice slice = grpc_slice_malloc(size_t(size)); + unsigned char* start = GRPC_SLICE_START_PTR(slice); + int second_size = i2d_SSL_SESSION(session.get(), &start); + GPR_ASSERT(size == second_size); + serialized_session_ = slice; + } + + virtual ~OpenSslCachedSession() { grpc_slice_unref(serialized_session_); } + + SslSessionPtr CopySession() const override { + const unsigned char* data = GRPC_SLICE_START_PTR(serialized_session_); + size_t length = GRPC_SLICE_LENGTH(serialized_session_); + SSL_SESSION* session = d2i_SSL_SESSION(nullptr, &data, length); + if (session == nullptr) { + return SslSessionPtr(); + } + return SslSessionPtr(session); + } + + private: + grpc_slice serialized_session_; +}; + +} // namespace + +std::unique_ptr SslCachedSession::Create( + SslSessionPtr session) { + return absl::make_unique(std::move(session)); +} + +} // namespace tsi + +#endif /* OPENSSL_IS_BORINGSSL */ diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc new file mode 100644 index 00000000..5f4a40a3 --- /dev/null +++ b/src/core/tsi/ssl_transport_security.cc @@ -0,0 +1,2260 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/ssl_transport_security.h" + +#include +#include + +/* TODO(jboeuf): refactor inet_ntop into a portability header. */ +/* Note: for whomever reads this and tries to refactor this, this + can't be in grpc, it has to be in gpr. */ +#ifdef GPR_WINDOWS +#include +#else +#include +#include +#endif + +#include + +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" + +#include +#include +#include +#include +#include +#include + +extern "C" { +#include +#include /* For OPENSSL_free */ +#include +#include +#include +#include +#include +#include +} + +#include "src/core/lib/gpr/useful.h" +#include "src/core/tsi/ssl/session_cache/ssl_session_cache.h" +#include "src/core/tsi/ssl_types.h" +#include "src/core/tsi/transport_security.h" + +/* --- Constants. ---*/ + +#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND 16384 +#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND 1024 +#define TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 1024 + +/* Putting a macro like this and littering the source file with #if is really + bad practice. + TODO(jboeuf): refactor all the #if / #endif in a separate module. */ +#ifndef TSI_OPENSSL_ALPN_SUPPORT +#define TSI_OPENSSL_ALPN_SUPPORT 1 +#endif + +/* TODO(jboeuf): I have not found a way to get this number dynamically from the + SSL structure. This is what we would ultimately want though... */ +#define TSI_SSL_MAX_PROTECTION_OVERHEAD 100 + +/* --- Structure definitions. 
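+   (root-certificate store, client and server handshaker factories, handshaker, handshaker result, and frame protector).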
---*/ + +struct tsi_ssl_root_certs_store { + X509_STORE* store; +}; + +struct tsi_ssl_handshaker_factory { + const tsi_ssl_handshaker_factory_vtable* vtable; + gpr_refcount refcount; +}; + +struct tsi_ssl_client_handshaker_factory { + tsi_ssl_handshaker_factory base; + SSL_CTX* ssl_context; + unsigned char* alpn_protocol_list; + size_t alpn_protocol_list_length; + grpc_core::RefCountedPtr session_cache; +}; + +struct tsi_ssl_server_handshaker_factory { + /* Several contexts to support SNI. + The tsi_peer array contains the subject names of the server certificates + associated with the contexts at the same index. */ + tsi_ssl_handshaker_factory base; + SSL_CTX** ssl_contexts; + tsi_peer* ssl_context_x509_subject_names; + size_t ssl_context_count; + unsigned char* alpn_protocol_list; + size_t alpn_protocol_list_length; +}; + +struct tsi_ssl_handshaker { + tsi_handshaker base; + SSL* ssl; + BIO* network_io; + tsi_result result; + unsigned char* outgoing_bytes_buffer; + size_t outgoing_bytes_buffer_size; + tsi_ssl_handshaker_factory* factory_ref; +}; +struct tsi_ssl_handshaker_result { + tsi_handshaker_result base; + SSL* ssl; + BIO* network_io; + unsigned char* unused_bytes; + size_t unused_bytes_size; +}; +struct tsi_ssl_frame_protector { + tsi_frame_protector base; + SSL* ssl; + BIO* network_io; + unsigned char* buffer; + size_t buffer_size; + size_t buffer_offset; +}; +/* --- Library Initialization. ---*/ + +static gpr_once g_init_openssl_once = GPR_ONCE_INIT; +static int g_ssl_ctx_ex_factory_index = -1; +static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'}; +#if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE) +static const char kSslEnginePrefix[] = "engine:"; +#endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000 +static gpr_mu* g_openssl_mutexes = nullptr; +static void openssl_locking_cb(int mode, int type, const char* file, + int line) GRPC_UNUSED; +static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED; + +static void openssl_locking_cb(int mode, int type, const char* file, int line) { + if (mode & CRYPTO_LOCK) { + gpr_mu_lock(&g_openssl_mutexes[type]); + } else { + gpr_mu_unlock(&g_openssl_mutexes[type]); + } +} + +static unsigned long openssl_thread_id_cb(void) { + return static_cast(gpr_thd_currentid()); +} +#endif + +static void init_openssl(void) { +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + OPENSSL_init_ssl(0, nullptr); +#else + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_all_algorithms(); +#endif +#if OPENSSL_VERSION_NUMBER < 0x10100000 + if (!CRYPTO_get_locking_callback()) { + int num_locks = CRYPTO_num_locks(); + GPR_ASSERT(num_locks > 0); + g_openssl_mutexes = static_cast( + gpr_malloc(static_cast(num_locks) * sizeof(gpr_mu))); + for (int i = 0; i < num_locks; i++) { + gpr_mu_init(&g_openssl_mutexes[i]); + } + CRYPTO_set_locking_callback(openssl_locking_cb); + CRYPTO_set_id_callback(openssl_thread_id_cb); + } else { + gpr_log(GPR_INFO, "OpenSSL callback has already been set."); + } +#endif + g_ssl_ctx_ex_factory_index = + SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + GPR_ASSERT(g_ssl_ctx_ex_factory_index != -1); +} + +/* --- Ssl utils. 
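+   (error-code stringification, handshake logging helpers, and X509-to-tsi_peer conversion).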
---*/ + +static const char* ssl_error_string(int error) { + switch (error) { + case SSL_ERROR_NONE: + return "SSL_ERROR_NONE"; + case SSL_ERROR_ZERO_RETURN: + return "SSL_ERROR_ZERO_RETURN"; + case SSL_ERROR_WANT_READ: + return "SSL_ERROR_WANT_READ"; + case SSL_ERROR_WANT_WRITE: + return "SSL_ERROR_WANT_WRITE"; + case SSL_ERROR_WANT_CONNECT: + return "SSL_ERROR_WANT_CONNECT"; + case SSL_ERROR_WANT_ACCEPT: + return "SSL_ERROR_WANT_ACCEPT"; + case SSL_ERROR_WANT_X509_LOOKUP: + return "SSL_ERROR_WANT_X509_LOOKUP"; + case SSL_ERROR_SYSCALL: + return "SSL_ERROR_SYSCALL"; + case SSL_ERROR_SSL: + return "SSL_ERROR_SSL"; + default: + return "Unknown error"; + } +} + +/* TODO(jboeuf): Remove when we are past the debugging phase with this code. */ +static void ssl_log_where_info(const SSL* ssl, int where, int flag, + const char* msg) { + if ((where & flag) && GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) { + gpr_log(GPR_INFO, "%20.20s - %30.30s - %5.10s", msg, + SSL_state_string_long(ssl), SSL_state_string(ssl)); + } +} + +/* Used for debugging. TODO(jboeuf): Remove when code is mature enough. */ +static void ssl_info_callback(const SSL* ssl, int where, int ret) { + if (ret == 0) { + gpr_log(GPR_ERROR, "ssl_info_callback: error occurred.\n"); + return; + } + + ssl_log_where_info(ssl, where, SSL_CB_LOOP, "LOOP"); + ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_START, "HANDSHAKE START"); + ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_DONE, "HANDSHAKE DONE"); +} + +/* Returns 1 if name looks like an IP address, 0 otherwise. + This is a very rough heuristic, and only handles IPv6 in hexadecimal form. */ +static int looks_like_ip_address(absl::string_view name) { + size_t dot_count = 0; + size_t num_size = 0; + for (size_t i = 0; i < name.size(); ++i) { + if (name[i] == ':') { + /* IPv6 Address in hexadecimal form, : is not allowed in DNS names. */ + return 1; + } + if (name[i] >= '0' && name[i] <= '9') { + if (num_size > 3) return 0; + num_size++; + } else if (name[i] == '.') { + if (dot_count > 3 || num_size == 0) return 0; + dot_count++; + num_size = 0; + } else { + return 0; + } + } + if (dot_count < 3 || num_size == 0) return 0; + return 1; +} + +/* Gets the subject CN from an X509 cert. 
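+   On success, *utf8 points to memory allocated by OpenSSL that the caller must release with OPENSSL_free().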
*/ +static tsi_result ssl_get_x509_common_name(X509* cert, unsigned char** utf8, + size_t* utf8_size) { + int common_name_index = -1; + X509_NAME_ENTRY* common_name_entry = nullptr; + ASN1_STRING* common_name_asn1 = nullptr; + X509_NAME* subject_name = X509_get_subject_name(cert); + int utf8_returned_size = 0; + if (subject_name == nullptr) { + gpr_log(GPR_INFO, "Could not get subject name from certificate."); + return TSI_NOT_FOUND; + } + common_name_index = + X509_NAME_get_index_by_NID(subject_name, NID_commonName, -1); + if (common_name_index == -1) { + gpr_log(GPR_INFO, "Could not get common name of subject from certificate."); + return TSI_NOT_FOUND; + } + common_name_entry = X509_NAME_get_entry(subject_name, common_name_index); + if (common_name_entry == nullptr) { + gpr_log(GPR_ERROR, "Could not get common name entry from certificate."); + return TSI_INTERNAL_ERROR; + } + common_name_asn1 = X509_NAME_ENTRY_get_data(common_name_entry); + if (common_name_asn1 == nullptr) { + gpr_log(GPR_ERROR, + "Could not get common name entry asn1 from certificate."); + return TSI_INTERNAL_ERROR; + } + utf8_returned_size = ASN1_STRING_to_UTF8(utf8, common_name_asn1); + if (utf8_returned_size < 0) { + gpr_log(GPR_ERROR, "Could not extract utf8 from asn1 string."); + return TSI_OUT_OF_RESOURCES; + } + *utf8_size = static_cast(utf8_returned_size); + return TSI_OK; +} + +/* Gets the subject CN of an X509 cert as a tsi_peer_property. */ +static tsi_result peer_property_from_x509_common_name( + X509* cert, tsi_peer_property* property) { + unsigned char* common_name; + size_t common_name_size; + tsi_result result = + ssl_get_x509_common_name(cert, &common_name, &common_name_size); + if (result != TSI_OK) { + if (result == TSI_NOT_FOUND) { + common_name = nullptr; + common_name_size = 0; + } else { + return result; + } + } + result = tsi_construct_string_peer_property( + TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY, + common_name == nullptr ? "" : reinterpret_cast(common_name), + common_name_size, property); + OPENSSL_free(common_name); + return result; +} + +/* Gets the X509 cert in PEM format as a tsi_peer_property. */ +static tsi_result add_pem_certificate(X509* cert, tsi_peer_property* property) { + BIO* bio = BIO_new(BIO_s_mem()); + if (!PEM_write_bio_X509(bio, cert)) { + BIO_free(bio); + return TSI_INTERNAL_ERROR; + } + char* contents; + long len = BIO_get_mem_data(bio, &contents); + if (len <= 0) { + BIO_free(bio); + return TSI_INTERNAL_ERROR; + } + tsi_result result = tsi_construct_string_peer_property( + TSI_X509_PEM_CERT_PROPERTY, contents, static_cast(len), property); + BIO_free(bio); + return result; +} + +/* Gets the subject SANs from an X509 cert as a tsi_peer_property. 
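+   DNS, email, URI and IP SANs are stored twice: once under the generic subject-alternative-name property and once under their type-specific property name; any other SAN type is recorded only as "other types of SAN".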
*/ +static tsi_result add_subject_alt_names_properties_to_peer( + tsi_peer* peer, GENERAL_NAMES* subject_alt_names, + size_t subject_alt_name_count, int* current_insert_index) { + size_t i; + tsi_result result = TSI_OK; + + for (i = 0; i < subject_alt_name_count; i++) { + GENERAL_NAME* subject_alt_name = + sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i)); + if (subject_alt_name->type == GEN_DNS || + subject_alt_name->type == GEN_EMAIL || + subject_alt_name->type == GEN_URI) { + unsigned char* name = nullptr; + int name_size; + std::string property_name; + if (subject_alt_name->type == GEN_DNS) { + name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.dNSName); + property_name = TSI_X509_DNS_PEER_PROPERTY; + } else if (subject_alt_name->type == GEN_EMAIL) { + name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.rfc822Name); + property_name = TSI_X509_EMAIL_PEER_PROPERTY; + } else { + name_size = ASN1_STRING_to_UTF8( + &name, subject_alt_name->d.uniformResourceIdentifier); + property_name = TSI_X509_URI_PEER_PROPERTY; + } + if (name_size < 0) { + gpr_log(GPR_ERROR, "Could not get utf8 from asn1 string."); + result = TSI_INTERNAL_ERROR; + break; + } + result = tsi_construct_string_peer_property( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, + reinterpret_cast(name), static_cast(name_size), + &peer->properties[(*current_insert_index)++]); + if (result != TSI_OK) { + OPENSSL_free(name); + break; + } + result = tsi_construct_string_peer_property( + property_name.c_str(), reinterpret_cast(name), + static_cast(name_size), + &peer->properties[(*current_insert_index)++]); + OPENSSL_free(name); + } else if (subject_alt_name->type == GEN_IPADD) { + char ntop_buf[INET6_ADDRSTRLEN]; + int af; + + if (subject_alt_name->d.iPAddress->length == 4) { + af = AF_INET; + } else if (subject_alt_name->d.iPAddress->length == 16) { + af = AF_INET6; + } else { + gpr_log(GPR_ERROR, "SAN IP Address contained invalid IP"); + result = TSI_INTERNAL_ERROR; + break; + } + const char* name = inet_ntop(af, subject_alt_name->d.iPAddress->data, + ntop_buf, INET6_ADDRSTRLEN); + if (name == nullptr) { + gpr_log(GPR_ERROR, "Could not get IP string from asn1 octet."); + result = TSI_INTERNAL_ERROR; + break; + } + + result = tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, name, + &peer->properties[(*current_insert_index)++]); + if (result != TSI_OK) break; + result = tsi_construct_string_peer_property_from_cstring( + TSI_X509_IP_PEER_PROPERTY, name, + &peer->properties[(*current_insert_index)++]); + } else { + result = tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, "other types of SAN", + &peer->properties[(*current_insert_index)++]); + } + if (result != TSI_OK) break; + } + return result; +} + +/* Gets information about the peer's X509 cert as a tsi_peer object. */ +static tsi_result peer_from_x509(X509* cert, int include_certificate_type, + tsi_peer* peer) { + /* TODO(jboeuf): Maybe add more properties. */ + GENERAL_NAMES* subject_alt_names = static_cast( + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); + int subject_alt_name_count = + (subject_alt_names != nullptr) + ? static_cast(sk_GENERAL_NAME_num(subject_alt_names)) + : 0; + size_t property_count; + tsi_result result; + GPR_ASSERT(subject_alt_name_count >= 0); + property_count = (include_certificate_type ? 
static_cast(1) : 0) + + 2 /* common name, certificate */ + + static_cast(subject_alt_name_count); + for (int i = 0; i < subject_alt_name_count; i++) { + GENERAL_NAME* subject_alt_name = + sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i)); + // TODO(zhenlian): Clean up tsi_peer to avoid duplicate entries. + // URI, DNS, email and ip address SAN fields are plumbed to tsi_peer, in + // addition to all SAN fields (results in duplicate values). This code + // snippet updates property_count accordingly. + if (subject_alt_name->type == GEN_URI || + subject_alt_name->type == GEN_DNS || + subject_alt_name->type == GEN_EMAIL || + subject_alt_name->type == GEN_IPADD) { + property_count += 1; + } + } + result = tsi_construct_peer(property_count, peer); + if (result != TSI_OK) return result; + int current_insert_index = 0; + do { + if (include_certificate_type) { + result = tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE, + &peer->properties[current_insert_index++]); + if (result != TSI_OK) break; + } + result = peer_property_from_x509_common_name( + cert, &peer->properties[current_insert_index++]); + if (result != TSI_OK) break; + + result = + add_pem_certificate(cert, &peer->properties[current_insert_index++]); + if (result != TSI_OK) break; + + if (subject_alt_name_count != 0) { + result = add_subject_alt_names_properties_to_peer( + peer, subject_alt_names, static_cast(subject_alt_name_count), + ¤t_insert_index); + if (result != TSI_OK) break; + } + } while (false); + + if (subject_alt_names != nullptr) { + sk_GENERAL_NAME_pop_free(subject_alt_names, GENERAL_NAME_free); + } + if (result != TSI_OK) tsi_peer_destruct(peer); + + GPR_ASSERT((int)peer->property_count == current_insert_index); + return result; +} + +/* Logs the SSL error stack. */ +static void log_ssl_error_stack(void) { + unsigned long err; + while ((err = ERR_get_error()) != 0) { + char details[256]; + ERR_error_string_n(static_cast(err), details, sizeof(details)); + gpr_log(GPR_ERROR, "%s", details); + } +} + +/* Performs an SSL_read and handle errors. */ +static tsi_result do_ssl_read(SSL* ssl, unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size) { + GPR_ASSERT(*unprotected_bytes_size <= INT_MAX); + ERR_clear_error(); + int read_from_ssl = SSL_read(ssl, unprotected_bytes, + static_cast(*unprotected_bytes_size)); + if (read_from_ssl <= 0) { + read_from_ssl = SSL_get_error(ssl, read_from_ssl); + switch (read_from_ssl) { + case SSL_ERROR_ZERO_RETURN: /* Received a close_notify alert. */ + case SSL_ERROR_WANT_READ: /* We need more data to finish the frame. */ + *unprotected_bytes_size = 0; + return TSI_OK; + case SSL_ERROR_WANT_WRITE: + gpr_log( + GPR_ERROR, + "Peer tried to renegotiate SSL connection. This is unsupported."); + return TSI_UNIMPLEMENTED; + case SSL_ERROR_SSL: + gpr_log(GPR_ERROR, "Corruption detected."); + log_ssl_error_stack(); + return TSI_DATA_CORRUPTED; + default: + gpr_log(GPR_ERROR, "SSL_read failed with error %s.", + ssl_error_string(read_from_ssl)); + return TSI_PROTOCOL_FAILURE; + } + } + *unprotected_bytes_size = static_cast(read_from_ssl); + return TSI_OK; +} + +/* Performs an SSL_write and handle errors. 
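+   A renegotiation attempt by the peer (reported as SSL_ERROR_WANT_READ) is rejected with TSI_UNIMPLEMENTED.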
*/ +static tsi_result do_ssl_write(SSL* ssl, unsigned char* unprotected_bytes, + size_t unprotected_bytes_size) { + GPR_ASSERT(unprotected_bytes_size <= INT_MAX); + ERR_clear_error(); + int ssl_write_result = SSL_write(ssl, unprotected_bytes, + static_cast(unprotected_bytes_size)); + if (ssl_write_result < 0) { + ssl_write_result = SSL_get_error(ssl, ssl_write_result); + if (ssl_write_result == SSL_ERROR_WANT_READ) { + gpr_log(GPR_ERROR, + "Peer tried to renegotiate SSL connection. This is unsupported."); + return TSI_UNIMPLEMENTED; + } else { + gpr_log(GPR_ERROR, "SSL_write failed with error %s.", + ssl_error_string(ssl_write_result)); + return TSI_INTERNAL_ERROR; + } + } + return TSI_OK; +} + +/* Loads an in-memory PEM certificate chain into the SSL context. */ +static tsi_result ssl_ctx_use_certificate_chain(SSL_CTX* context, + const char* pem_cert_chain, + size_t pem_cert_chain_size) { + tsi_result result = TSI_OK; + X509* certificate = nullptr; + BIO* pem; + GPR_ASSERT(pem_cert_chain_size <= INT_MAX); + pem = BIO_new_mem_buf(pem_cert_chain, static_cast(pem_cert_chain_size)); + if (pem == nullptr) return TSI_OUT_OF_RESOURCES; + + do { + certificate = + PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast("")); + if (certificate == nullptr) { + result = TSI_INVALID_ARGUMENT; + break; + } + if (!SSL_CTX_use_certificate(context, certificate)) { + result = TSI_INVALID_ARGUMENT; + break; + } + while (true) { + X509* certificate_authority = + PEM_read_bio_X509(pem, nullptr, nullptr, const_cast("")); + if (certificate_authority == nullptr) { + ERR_clear_error(); + break; /* Done reading. */ + } + if (!SSL_CTX_add_extra_chain_cert(context, certificate_authority)) { + X509_free(certificate_authority); + result = TSI_INVALID_ARGUMENT; + break; + } + /* We don't need to free certificate_authority as its ownership has been + transferred to the context. That is not the case for certificate + though. + */ + } + } while (false); + + if (certificate != nullptr) X509_free(certificate); + BIO_free(pem); + return result; +} + +#if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE) +static tsi_result ssl_ctx_use_engine_private_key(SSL_CTX* context, + const char* pem_key, + size_t pem_key_size) { + tsi_result result = TSI_OK; + EVP_PKEY* private_key = nullptr; + ENGINE* engine = nullptr; + char* engine_name = nullptr; + // Parse key which is in following format engine:: + do { + char* engine_start = (char*)pem_key + strlen(kSslEnginePrefix); + char* engine_end = (char*)strchr(engine_start, ':'); + if (engine_end == nullptr) { + result = TSI_INVALID_ARGUMENT; + break; + } + char* key_id = engine_end + 1; + int engine_name_length = engine_end - engine_start; + if (engine_name_length == 0) { + result = TSI_INVALID_ARGUMENT; + break; + } + engine_name = static_cast(gpr_zalloc(engine_name_length + 1)); + memcpy(engine_name, engine_start, engine_name_length); + gpr_log(GPR_DEBUG, "ENGINE key: %s", engine_name); + ENGINE_load_dynamic(); + engine = ENGINE_by_id(engine_name); + if (engine == nullptr) { + // If not available at ENGINE_DIR, use dynamic to load from + // current working directory. 
+ engine = ENGINE_by_id("dynamic"); + if (engine == nullptr) { + gpr_log(GPR_ERROR, "Cannot load dynamic engine"); + result = TSI_INVALID_ARGUMENT; + break; + } + if (!ENGINE_ctrl_cmd_string(engine, "ID", engine_name, 0) || + !ENGINE_ctrl_cmd_string(engine, "DIR_LOAD", "2", 0) || + !ENGINE_ctrl_cmd_string(engine, "DIR_ADD", ".", 0) || + !ENGINE_ctrl_cmd_string(engine, "LIST_ADD", "1", 0) || + !ENGINE_ctrl_cmd_string(engine, "LOAD", NULL, 0)) { + gpr_log(GPR_ERROR, "Cannot find engine"); + result = TSI_INVALID_ARGUMENT; + break; + } + } + if (!ENGINE_set_default(engine, ENGINE_METHOD_ALL)) { + gpr_log(GPR_ERROR, "ENGINE_set_default with ENGINE_METHOD_ALL failed"); + result = TSI_INVALID_ARGUMENT; + break; + } + if (!ENGINE_init(engine)) { + gpr_log(GPR_ERROR, "ENGINE_init failed"); + result = TSI_INVALID_ARGUMENT; + break; + } + private_key = ENGINE_load_private_key(engine, key_id, 0, 0); + if (private_key == nullptr) { + gpr_log(GPR_ERROR, "ENGINE_load_private_key failed"); + result = TSI_INVALID_ARGUMENT; + break; + } + if (!SSL_CTX_use_PrivateKey(context, private_key)) { + gpr_log(GPR_ERROR, "SSL_CTX_use_PrivateKey failed"); + result = TSI_INVALID_ARGUMENT; + break; + } + } while (0); + if (engine != nullptr) ENGINE_free(engine); + if (private_key != nullptr) EVP_PKEY_free(private_key); + if (engine_name != nullptr) gpr_free(engine_name); + return result; +} +#endif /* !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE) */ + +static tsi_result ssl_ctx_use_pem_private_key(SSL_CTX* context, + const char* pem_key, + size_t pem_key_size) { + tsi_result result = TSI_OK; + EVP_PKEY* private_key = nullptr; + BIO* pem; + GPR_ASSERT(pem_key_size <= INT_MAX); + pem = BIO_new_mem_buf(pem_key, static_cast(pem_key_size)); + if (pem == nullptr) return TSI_OUT_OF_RESOURCES; + do { + private_key = + PEM_read_bio_PrivateKey(pem, nullptr, nullptr, const_cast("")); + if (private_key == nullptr) { + result = TSI_INVALID_ARGUMENT; + break; + } + if (!SSL_CTX_use_PrivateKey(context, private_key)) { + result = TSI_INVALID_ARGUMENT; + break; + } + } while (false); + if (private_key != nullptr) EVP_PKEY_free(private_key); + BIO_free(pem); + return result; +} + +/* Loads an in-memory PEM private key into the SSL context. */ +static tsi_result ssl_ctx_use_private_key(SSL_CTX* context, const char* pem_key, + size_t pem_key_size) { +// BoringSSL does not have ENGINE support +#if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE) + if (strncmp(pem_key, kSslEnginePrefix, strlen(kSslEnginePrefix)) == 0) { + return ssl_ctx_use_engine_private_key(context, pem_key, pem_key_size); + } else +#endif /* !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE) */ + { + return ssl_ctx_use_pem_private_key(context, pem_key, pem_key_size); + } +} + +/* Loads in-memory PEM verification certs into the SSL context and optionally + returns the verification cert names (root_names can be NULL). 
*/ +static tsi_result x509_store_load_certs(X509_STORE* cert_store, + const char* pem_roots, + size_t pem_roots_size, + STACK_OF(X509_NAME) * *root_names) { + tsi_result result = TSI_OK; + size_t num_roots = 0; + X509* root = nullptr; + X509_NAME* root_name = nullptr; + BIO* pem; + GPR_ASSERT(pem_roots_size <= INT_MAX); + pem = BIO_new_mem_buf(pem_roots, static_cast(pem_roots_size)); + if (cert_store == nullptr) return TSI_INVALID_ARGUMENT; + if (pem == nullptr) return TSI_OUT_OF_RESOURCES; + if (root_names != nullptr) { + *root_names = sk_X509_NAME_new_null(); + if (*root_names == nullptr) return TSI_OUT_OF_RESOURCES; + } + + while (true) { + root = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast("")); + if (root == nullptr) { + ERR_clear_error(); + break; /* We're at the end of stream. */ + } + if (root_names != nullptr) { + root_name = X509_get_subject_name(root); + if (root_name == nullptr) { + gpr_log(GPR_ERROR, "Could not get name from root certificate."); + result = TSI_INVALID_ARGUMENT; + break; + } + root_name = X509_NAME_dup(root_name); + if (root_name == nullptr) { + result = TSI_OUT_OF_RESOURCES; + break; + } + sk_X509_NAME_push(*root_names, root_name); + root_name = nullptr; + } + ERR_clear_error(); + if (!X509_STORE_add_cert(cert_store, root)) { + unsigned long error = ERR_get_error(); + if (ERR_GET_LIB(error) != ERR_LIB_X509 || + ERR_GET_REASON(error) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { + gpr_log(GPR_ERROR, "Could not add root certificate to ssl context."); + result = TSI_INTERNAL_ERROR; + break; + } + } + X509_free(root); + num_roots++; + } + if (num_roots == 0) { + gpr_log(GPR_ERROR, "Could not load any root certificate."); + result = TSI_INVALID_ARGUMENT; + } + + if (result != TSI_OK) { + if (root != nullptr) X509_free(root); + if (root_names != nullptr) { + sk_X509_NAME_pop_free(*root_names, X509_NAME_free); + *root_names = nullptr; + if (root_name != nullptr) X509_NAME_free(root_name); + } + } + BIO_free(pem); + return result; +} + +static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context, + const char* pem_roots, + size_t pem_roots_size, + STACK_OF(X509_NAME) * + *root_name) { + X509_STORE* cert_store = SSL_CTX_get_cert_store(context); + X509_STORE_set_flags(cert_store, + X509_V_FLAG_PARTIAL_CHAIN | X509_V_FLAG_TRUSTED_FIRST); + return x509_store_load_certs(cert_store, pem_roots, pem_roots_size, + root_name); +} + +/* Populates the SSL context with a private key and a cert chain, and sets the + cipher list and the ephemeral ECDH key. */ +static tsi_result populate_ssl_context( + SSL_CTX* context, const tsi_ssl_pem_key_cert_pair* key_cert_pair, + const char* cipher_list) { + tsi_result result = TSI_OK; + if (key_cert_pair != nullptr) { + if (key_cert_pair->cert_chain != nullptr) { + result = ssl_ctx_use_certificate_chain(context, key_cert_pair->cert_chain, + strlen(key_cert_pair->cert_chain)); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Invalid cert chain file."); + return result; + } + } + if (key_cert_pair->private_key != nullptr) { + result = ssl_ctx_use_private_key(context, key_cert_pair->private_key, + strlen(key_cert_pair->private_key)); + if (result != TSI_OK || !SSL_CTX_check_private_key(context)) { + gpr_log(GPR_ERROR, "Invalid private key."); + return result != TSI_OK ? 
result : TSI_INVALID_ARGUMENT; + } + } + } + if ((cipher_list != nullptr) && + !SSL_CTX_set_cipher_list(context, cipher_list)) { + gpr_log(GPR_ERROR, "Invalid cipher list: %s.", cipher_list); + return TSI_INVALID_ARGUMENT; + } + { + EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + if (!SSL_CTX_set_tmp_ecdh(context, ecdh)) { + gpr_log(GPR_ERROR, "Could not set ephemeral ECDH key."); + EC_KEY_free(ecdh); + return TSI_INTERNAL_ERROR; + } + SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE); + EC_KEY_free(ecdh); + } + return TSI_OK; +} + +/* Extracts the CN and the SANs from an X509 cert as a peer object. */ +tsi_result tsi_ssl_extract_x509_subject_names_from_pem_cert( + const char* pem_cert, tsi_peer* peer) { + tsi_result result = TSI_OK; + X509* cert = nullptr; + BIO* pem; + pem = BIO_new_mem_buf(pem_cert, static_cast(strlen(pem_cert))); + if (pem == nullptr) return TSI_OUT_OF_RESOURCES; + + cert = PEM_read_bio_X509(pem, nullptr, nullptr, const_cast("")); + if (cert == nullptr) { + gpr_log(GPR_ERROR, "Invalid certificate"); + result = TSI_INVALID_ARGUMENT; + } else { + result = peer_from_x509(cert, 0, peer); + } + if (cert != nullptr) X509_free(cert); + BIO_free(pem); + return result; +} + +/* Builds the alpn protocol name list according to rfc 7301. */ +static tsi_result build_alpn_protocol_name_list( + const char** alpn_protocols, uint16_t num_alpn_protocols, + unsigned char** protocol_name_list, size_t* protocol_name_list_length) { + uint16_t i; + unsigned char* current; + *protocol_name_list = nullptr; + *protocol_name_list_length = 0; + if (num_alpn_protocols == 0) return TSI_INVALID_ARGUMENT; + for (i = 0; i < num_alpn_protocols; i++) { + size_t length = + alpn_protocols[i] == nullptr ? 0 : strlen(alpn_protocols[i]); + if (length == 0 || length > 255) { + gpr_log(GPR_ERROR, "Invalid protocol name length: %d.", + static_cast(length)); + return TSI_INVALID_ARGUMENT; + } + *protocol_name_list_length += length + 1; + } + *protocol_name_list = + static_cast(gpr_malloc(*protocol_name_list_length)); + if (*protocol_name_list == nullptr) return TSI_OUT_OF_RESOURCES; + current = *protocol_name_list; + for (i = 0; i < num_alpn_protocols; i++) { + size_t length = strlen(alpn_protocols[i]); + *(current++) = static_cast(length); /* max checked above. */ + memcpy(current, alpn_protocols[i], length); + current += length; + } + /* Safety check. */ + if ((current < *protocol_name_list) || + (static_cast(current - *protocol_name_list) != + *protocol_name_list_length)) { + return TSI_INTERNAL_ERROR; + } + return TSI_OK; +} + +// The verification callback is used for clients that don't really care about +// the server's certificate, but we need to pull it anyway, in case a higher +// layer wants to look at it. In this case the verification may fail, but +// we don't really care. +static int NullVerifyCallback(int /*preverify_ok*/, X509_STORE_CTX* /*ctx*/) { + return 1; +} + +// Sets the min and max TLS version of |ssl_context| to |min_tls_version| and +// |max_tls_version|, respectively. Calling this method is a no-op when using +// OpenSSL versions < 1.1. +static tsi_result tsi_set_min_and_max_tls_versions( + SSL_CTX* ssl_context, tsi_tls_version min_tls_version, + tsi_tls_version max_tls_version) { + if (ssl_context == nullptr) { + gpr_log(GPR_INFO, + "Invalid nullptr argument to |tsi_set_min_and_max_tls_versions|."); + return TSI_INVALID_ARGUMENT; + } +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + // Set the min TLS version of the SSL context if using OpenSSL version + // >= 1.1.0. 
This OpenSSL version is required because the + // |SSL_CTX_set_min_proto_version| and |SSL_CTX_set_max_proto_version| APIs + // only exist in this version range. + switch (min_tls_version) { + case tsi_tls_version::TSI_TLS1_2: + SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION); + break; +#if defined(TLS1_3_VERSION) + // If the library does not support TLS 1.3 and the caller requests a minimum + // of TLS 1.3, then return an error because the caller's request cannot be + // satisfied. + case tsi_tls_version::TSI_TLS1_3: + SSL_CTX_set_min_proto_version(ssl_context, TLS1_3_VERSION); + break; +#endif + default: + gpr_log(GPR_INFO, "TLS version is not supported."); + return TSI_FAILED_PRECONDITION; + } + + // Set the max TLS version of the SSL context. + switch (max_tls_version) { + case tsi_tls_version::TSI_TLS1_2: + SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION); + break; + case tsi_tls_version::TSI_TLS1_3: +#if defined(TLS1_3_VERSION) + SSL_CTX_set_max_proto_version(ssl_context, TLS1_3_VERSION); +#else + // If the library does not support TLS 1.3, then set the max TLS version + // to TLS 1.2 instead. + SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION); +#endif + break; + default: + gpr_log(GPR_INFO, "TLS version is not supported."); + return TSI_FAILED_PRECONDITION; + } +#endif + return TSI_OK; +} + +/* --- tsi_ssl_root_certs_store methods implementation. ---*/ + +tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create( + const char* pem_roots) { + if (pem_roots == nullptr) { + gpr_log(GPR_ERROR, "The root certificates are empty."); + return nullptr; + } + tsi_ssl_root_certs_store* root_store = static_cast( + gpr_zalloc(sizeof(tsi_ssl_root_certs_store))); + if (root_store == nullptr) { + gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store."); + return nullptr; + } + root_store->store = X509_STORE_new(); + if (root_store->store == nullptr) { + gpr_log(GPR_ERROR, "Could not allocate buffer for X509_STORE."); + gpr_free(root_store); + return nullptr; + } + tsi_result result = x509_store_load_certs(root_store->store, pem_roots, + strlen(pem_roots), nullptr); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Could not load root certificates."); + X509_STORE_free(root_store->store); + gpr_free(root_store); + return nullptr; + } + return root_store; +} + +void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) { + if (self == nullptr) return; + X509_STORE_free(self->store); + gpr_free(self); +} + +/* --- tsi_ssl_session_cache methods implementation. ---*/ + +tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) { + /* Pointer will be dereferenced by unref call. */ + return reinterpret_cast( + tsi::SslSessionLRUCache::Create(capacity).release()); +} + +void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) { + /* Pointer will be dereferenced by unref call. */ + reinterpret_cast(cache)->Ref().release(); +} + +void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) { + reinterpret_cast(cache)->Unref(); +} + +/* --- tsi_frame_protector methods implementation. ---*/ + +static tsi_result ssl_protector_protect(tsi_frame_protector* self, + const unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size, + unsigned char* protected_output_frames, + size_t* protected_output_frames_size) { + tsi_ssl_frame_protector* impl = + reinterpret_cast(self); + int read_from_ssl; + size_t available; + tsi_result result = TSI_OK; + + /* First see if we have some pending data in the SSL BIO. 
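+     If so, it is flushed into protected_output_frames before any new unprotected bytes are accepted.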
*/ + int pending_in_ssl = static_cast(BIO_pending(impl->network_io)); + if (pending_in_ssl > 0) { + *unprotected_bytes_size = 0; + GPR_ASSERT(*protected_output_frames_size <= INT_MAX); + read_from_ssl = BIO_read(impl->network_io, protected_output_frames, + static_cast(*protected_output_frames_size)); + if (read_from_ssl < 0) { + gpr_log(GPR_ERROR, + "Could not read from BIO even though some data is pending"); + return TSI_INTERNAL_ERROR; + } + *protected_output_frames_size = static_cast(read_from_ssl); + return TSI_OK; + } + + /* Now see if we can send a complete frame. */ + available = impl->buffer_size - impl->buffer_offset; + if (available > *unprotected_bytes_size) { + /* If we cannot, just copy the data in our internal buffer. */ + memcpy(impl->buffer + impl->buffer_offset, unprotected_bytes, + *unprotected_bytes_size); + impl->buffer_offset += *unprotected_bytes_size; + *protected_output_frames_size = 0; + return TSI_OK; + } + + /* If we can, prepare the buffer, send it to SSL_write and read. */ + memcpy(impl->buffer + impl->buffer_offset, unprotected_bytes, available); + result = do_ssl_write(impl->ssl, impl->buffer, impl->buffer_size); + if (result != TSI_OK) return result; + + GPR_ASSERT(*protected_output_frames_size <= INT_MAX); + read_from_ssl = BIO_read(impl->network_io, protected_output_frames, + static_cast(*protected_output_frames_size)); + if (read_from_ssl < 0) { + gpr_log(GPR_ERROR, "Could not read from BIO after SSL_write."); + return TSI_INTERNAL_ERROR; + } + *protected_output_frames_size = static_cast(read_from_ssl); + *unprotected_bytes_size = available; + impl->buffer_offset = 0; + return TSI_OK; +} + +static tsi_result ssl_protector_protect_flush( + tsi_frame_protector* self, unsigned char* protected_output_frames, + size_t* protected_output_frames_size, size_t* still_pending_size) { + tsi_result result = TSI_OK; + tsi_ssl_frame_protector* impl = + reinterpret_cast(self); + int read_from_ssl = 0; + int pending; + + if (impl->buffer_offset != 0) { + result = do_ssl_write(impl->ssl, impl->buffer, impl->buffer_offset); + if (result != TSI_OK) return result; + impl->buffer_offset = 0; + } + + pending = static_cast(BIO_pending(impl->network_io)); + GPR_ASSERT(pending >= 0); + *still_pending_size = static_cast(pending); + if (*still_pending_size == 0) return TSI_OK; + + GPR_ASSERT(*protected_output_frames_size <= INT_MAX); + read_from_ssl = BIO_read(impl->network_io, protected_output_frames, + static_cast(*protected_output_frames_size)); + if (read_from_ssl <= 0) { + gpr_log(GPR_ERROR, "Could not read from BIO after SSL_write."); + return TSI_INTERNAL_ERROR; + } + *protected_output_frames_size = static_cast(read_from_ssl); + pending = static_cast(BIO_pending(impl->network_io)); + GPR_ASSERT(pending >= 0); + *still_pending_size = static_cast(pending); + return TSI_OK; +} + +static tsi_result ssl_protector_unprotect( + tsi_frame_protector* self, const unsigned char* protected_frames_bytes, + size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size) { + tsi_result result = TSI_OK; + int written_into_ssl = 0; + size_t output_bytes_size = *unprotected_bytes_size; + size_t output_bytes_offset = 0; + tsi_ssl_frame_protector* impl = + reinterpret_cast(self); + + /* First, try to read remaining data from ssl. 
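+     Bytes already decrypted inside the SSL object are drained before any new protected frames are written into the BIO.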
*/ + result = do_ssl_read(impl->ssl, unprotected_bytes, unprotected_bytes_size); + if (result != TSI_OK) return result; + if (*unprotected_bytes_size == output_bytes_size) { + /* We have read everything we could and cannot process any more input. */ + *protected_frames_bytes_size = 0; + return TSI_OK; + } + output_bytes_offset = *unprotected_bytes_size; + unprotected_bytes += output_bytes_offset; + *unprotected_bytes_size = output_bytes_size - output_bytes_offset; + + /* Then, try to write some data to ssl. */ + GPR_ASSERT(*protected_frames_bytes_size <= INT_MAX); + written_into_ssl = BIO_write(impl->network_io, protected_frames_bytes, + static_cast(*protected_frames_bytes_size)); + if (written_into_ssl < 0) { + gpr_log(GPR_ERROR, "Sending protected frame to ssl failed with %d", + written_into_ssl); + return TSI_INTERNAL_ERROR; + } + *protected_frames_bytes_size = static_cast(written_into_ssl); + + /* Now try to read some data again. */ + result = do_ssl_read(impl->ssl, unprotected_bytes, unprotected_bytes_size); + if (result == TSI_OK) { + /* Don't forget to output the total number of bytes read. */ + *unprotected_bytes_size += output_bytes_offset; + } + return result; +} + +static void ssl_protector_destroy(tsi_frame_protector* self) { + tsi_ssl_frame_protector* impl = + reinterpret_cast(self); + if (impl->buffer != nullptr) gpr_free(impl->buffer); + if (impl->ssl != nullptr) SSL_free(impl->ssl); + if (impl->network_io != nullptr) BIO_free(impl->network_io); + gpr_free(self); +} + +static const tsi_frame_protector_vtable frame_protector_vtable = { + ssl_protector_protect, + ssl_protector_protect_flush, + ssl_protector_unprotect, + ssl_protector_destroy, +}; + +/* --- tsi_server_handshaker_factory methods implementation. --- */ + +static void tsi_ssl_handshaker_factory_destroy( + tsi_ssl_handshaker_factory* factory) { + if (factory == nullptr) return; + + if (factory->vtable != nullptr && factory->vtable->destroy != nullptr) { + factory->vtable->destroy(factory); + } + /* Note, we don't free(self) here because this object is always directly + * embedded in another object. If tsi_ssl_handshaker_factory_init allocates + * any memory, it should be free'd here. */ +} + +static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref( + tsi_ssl_handshaker_factory* factory) { + if (factory == nullptr) return nullptr; + gpr_refn(&factory->refcount, 1); + return factory; +} + +static void tsi_ssl_handshaker_factory_unref( + tsi_ssl_handshaker_factory* factory) { + if (factory == nullptr) return; + + if (gpr_unref(&factory->refcount)) { + tsi_ssl_handshaker_factory_destroy(factory); + } +} + +static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {nullptr}; + +/* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for + * allocating memory for the factory. */ +static void tsi_ssl_handshaker_factory_init( + tsi_ssl_handshaker_factory* factory) { + GPR_ASSERT(factory != nullptr); + + factory->vtable = &handshaker_factory_vtable; + gpr_ref_init(&factory->refcount, 1); +} + +/* Gets the X509 cert chain in PEM format as a tsi_peer_property. 
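+   Every certificate in |peer_chain| is PEM-encoded into a single buffer exposed as TSI_X509_PEM_CERT_CHAIN_PROPERTY.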
*/ +tsi_result tsi_ssl_get_cert_chain_contents(STACK_OF(X509) * peer_chain, + tsi_peer_property* property) { + BIO* bio = BIO_new(BIO_s_mem()); + const auto peer_chain_len = sk_X509_num(peer_chain); + for (auto i = decltype(peer_chain_len){0}; i < peer_chain_len; i++) { + if (!PEM_write_bio_X509(bio, sk_X509_value(peer_chain, i))) { + BIO_free(bio); + return TSI_INTERNAL_ERROR; + } + } + char* contents; + long len = BIO_get_mem_data(bio, &contents); + if (len <= 0) { + BIO_free(bio); + return TSI_INTERNAL_ERROR; + } + tsi_result result = tsi_construct_string_peer_property( + TSI_X509_PEM_CERT_CHAIN_PROPERTY, contents, static_cast(len), + property); + BIO_free(bio); + return result; +} + +/* --- tsi_handshaker_result methods implementation. ---*/ +static tsi_result ssl_handshaker_result_extract_peer( + const tsi_handshaker_result* self, tsi_peer* peer) { + tsi_result result = TSI_OK; + const unsigned char* alpn_selected = nullptr; + unsigned int alpn_selected_len; + const tsi_ssl_handshaker_result* impl = + reinterpret_cast(self); + X509* peer_cert = SSL_get_peer_certificate(impl->ssl); + if (peer_cert != nullptr) { + result = peer_from_x509(peer_cert, 1, peer); + X509_free(peer_cert); + if (result != TSI_OK) return result; + } +#if TSI_OPENSSL_ALPN_SUPPORT + SSL_get0_alpn_selected(impl->ssl, &alpn_selected, &alpn_selected_len); +#endif /* TSI_OPENSSL_ALPN_SUPPORT */ + if (alpn_selected == nullptr) { + /* Try npn. */ + SSL_get0_next_proto_negotiated(impl->ssl, &alpn_selected, + &alpn_selected_len); + } + // When called on the client side, the stack also contains the + // peer's certificate; When called on the server side, + // the peer's certificate is not present in the stack + STACK_OF(X509)* peer_chain = SSL_get_peer_cert_chain(impl->ssl); + // 1 is for session reused property. + size_t new_property_count = peer->property_count + 3; + if (alpn_selected != nullptr) new_property_count++; + if (peer_chain != nullptr) new_property_count++; + tsi_peer_property* new_properties = static_cast( + gpr_zalloc(sizeof(*new_properties) * new_property_count)); + for (size_t i = 0; i < peer->property_count; i++) { + new_properties[i] = peer->properties[i]; + } + if (peer->properties != nullptr) gpr_free(peer->properties); + peer->properties = new_properties; + // Add peer chain if available + if (peer_chain != nullptr) { + result = tsi_ssl_get_cert_chain_contents( + peer_chain, &peer->properties[peer->property_count]); + if (result == TSI_OK) peer->property_count++; + } + if (alpn_selected != nullptr) { + result = tsi_construct_string_peer_property( + TSI_SSL_ALPN_SELECTED_PROTOCOL, + reinterpret_cast(alpn_selected), alpn_selected_len, + &peer->properties[peer->property_count]); + if (result != TSI_OK) return result; + peer->property_count++; + } + // Add security_level peer property. + result = tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer->properties[peer->property_count]); + if (result != TSI_OK) return result; + peer->property_count++; + + const char* session_reused = SSL_session_reused(impl->ssl) ? 
"true" : "false"; + result = tsi_construct_string_peer_property_from_cstring( + TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused, + &peer->properties[peer->property_count]); + if (result != TSI_OK) return result; + peer->property_count++; + return result; +} + +static tsi_result ssl_handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* /*self*/, + tsi_frame_protector_type* frame_protector_type) { + *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL; + return TSI_OK; +} + +static tsi_result ssl_handshaker_result_create_frame_protector( + const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, + tsi_frame_protector** protector) { + size_t actual_max_output_protected_frame_size = + TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND; + tsi_ssl_handshaker_result* impl = + reinterpret_cast( + const_cast(self)); + tsi_ssl_frame_protector* protector_impl = + static_cast( + gpr_zalloc(sizeof(*protector_impl))); + + if (max_output_protected_frame_size != nullptr) { + if (*max_output_protected_frame_size > + TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND) { + *max_output_protected_frame_size = + TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND; + } else if (*max_output_protected_frame_size < + TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND) { + *max_output_protected_frame_size = + TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND; + } + actual_max_output_protected_frame_size = *max_output_protected_frame_size; + } + protector_impl->buffer_size = + actual_max_output_protected_frame_size - TSI_SSL_MAX_PROTECTION_OVERHEAD; + protector_impl->buffer = + static_cast(gpr_malloc(protector_impl->buffer_size)); + if (protector_impl->buffer == nullptr) { + gpr_log(GPR_ERROR, + "Could not allocated buffer for tsi_ssl_frame_protector."); + gpr_free(protector_impl); + return TSI_INTERNAL_ERROR; + } + + /* Transfer ownership of ssl and network_io to the frame protector. 
*/ + protector_impl->ssl = impl->ssl; + impl->ssl = nullptr; + protector_impl->network_io = impl->network_io; + impl->network_io = nullptr; + protector_impl->base.vtable = &frame_protector_vtable; + *protector = &protector_impl->base; + return TSI_OK; +} + +static tsi_result ssl_handshaker_result_get_unused_bytes( + const tsi_handshaker_result* self, const unsigned char** bytes, + size_t* bytes_size) { + const tsi_ssl_handshaker_result* impl = + reinterpret_cast(self); + *bytes_size = impl->unused_bytes_size; + *bytes = impl->unused_bytes; + return TSI_OK; +} + +static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) { + tsi_ssl_handshaker_result* impl = + reinterpret_cast(self); + SSL_free(impl->ssl); + BIO_free(impl->network_io); + gpr_free(impl->unused_bytes); + gpr_free(impl); +} + +static const tsi_handshaker_result_vtable handshaker_result_vtable = { + ssl_handshaker_result_extract_peer, + ssl_handshaker_result_get_frame_protector_type, + nullptr, /* create_zero_copy_grpc_protector */ + ssl_handshaker_result_create_frame_protector, + ssl_handshaker_result_get_unused_bytes, + ssl_handshaker_result_destroy, +}; + +static tsi_result ssl_handshaker_result_create( + tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes, + size_t unused_bytes_size, tsi_handshaker_result** handshaker_result) { + if (handshaker == nullptr || handshaker_result == nullptr || + (unused_bytes_size > 0 && unused_bytes == nullptr)) { + return TSI_INVALID_ARGUMENT; + } + tsi_ssl_handshaker_result* result = + grpc_core::Zalloc(); + result->base.vtable = &handshaker_result_vtable; + /* Transfer ownership of ssl and network_io to the handshaker result. */ + result->ssl = handshaker->ssl; + handshaker->ssl = nullptr; + result->network_io = handshaker->network_io; + handshaker->network_io = nullptr; + /* Transfer ownership of |unused_bytes| to the handshaker result. */ + result->unused_bytes = unused_bytes; + result->unused_bytes_size = unused_bytes_size; + *handshaker_result = &result->base; + return TSI_OK; +} + +/* --- tsi_handshaker methods implementation. ---*/ + +static tsi_result ssl_handshaker_get_bytes_to_send_to_peer( + tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size) { + int bytes_read_from_ssl = 0; + if (bytes == nullptr || bytes_size == nullptr || *bytes_size == 0 || + *bytes_size > INT_MAX) { + return TSI_INVALID_ARGUMENT; + } + GPR_ASSERT(*bytes_size <= INT_MAX); + bytes_read_from_ssl = + BIO_read(impl->network_io, bytes, static_cast(*bytes_size)); + if (bytes_read_from_ssl < 0) { + *bytes_size = 0; + if (!BIO_should_retry(impl->network_io)) { + impl->result = TSI_INTERNAL_ERROR; + return impl->result; + } else { + return TSI_OK; + } + } + *bytes_size = static_cast(bytes_read_from_ssl); + return BIO_pending(impl->network_io) == 0 ? 
TSI_OK : TSI_INCOMPLETE_DATA; +} + +static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) { + if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) && + SSL_is_init_finished(impl->ssl)) { + impl->result = TSI_OK; + } + return impl->result; +} + +static tsi_result ssl_handshaker_process_bytes_from_peer( + tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size) { + int bytes_written_into_ssl_size = 0; + if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) { + return TSI_INVALID_ARGUMENT; + } + GPR_ASSERT(*bytes_size <= INT_MAX); + bytes_written_into_ssl_size = + BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size)); + if (bytes_written_into_ssl_size < 0) { + gpr_log(GPR_ERROR, "Could not write to memory BIO."); + impl->result = TSI_INTERNAL_ERROR; + return impl->result; + } + *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size); + + if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) { + impl->result = TSI_OK; + return impl->result; + } else { + ERR_clear_error(); + /* Get ready to get some bytes from SSL. */ + int ssl_result = SSL_do_handshake(impl->ssl); + ssl_result = SSL_get_error(impl->ssl, ssl_result); + switch (ssl_result) { + case SSL_ERROR_WANT_READ: + if (BIO_pending(impl->network_io) == 0) { + /* We need more data. */ + return TSI_INCOMPLETE_DATA; + } else { + return TSI_OK; + } + case SSL_ERROR_NONE: + return TSI_OK; + default: { + char err_str[256]; + ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str)); + gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.", + ssl_error_string(ssl_result), err_str); + impl->result = TSI_PROTOCOL_FAILURE; + return impl->result; + } + } + } +} + +static void ssl_handshaker_destroy(tsi_handshaker* self) { + tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self); + SSL_free(impl->ssl); + BIO_free(impl->network_io); + gpr_free(impl->outgoing_bytes_buffer); + tsi_ssl_handshaker_factory_unref(impl->factory_ref); + gpr_free(impl); +} + +// Removes the bytes remaining in |impl->SSL|'s read BIO and writes them to +// |bytes_remaining|. +static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl, + unsigned char** bytes_remaining, + size_t* bytes_remaining_size) { + if (impl == nullptr || bytes_remaining == nullptr || + bytes_remaining_size == nullptr) { + return TSI_INVALID_ARGUMENT; + } + // Attempt to read all of the bytes in SSL's read BIO. These bytes should + // contain application data records that were appended to a handshake record + // containing the ClientFinished or ServerFinished message. + size_t bytes_in_ssl = BIO_pending(SSL_get_rbio(impl->ssl)); + if (bytes_in_ssl == 0) return TSI_OK; + *bytes_remaining = static_cast<unsigned char*>(gpr_malloc(bytes_in_ssl)); + int bytes_read = BIO_read(SSL_get_rbio(impl->ssl), *bytes_remaining, + static_cast<int>(bytes_in_ssl)); + // If an unexpected number of bytes were read, return an error status and free + // all of the bytes that were read.
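// (A successful BIO_read() here is expected to return exactly |bytes_in_ssl|
// bytes, since BIO_pending() just reported that many buffered bytes; a short
// or negative read therefore signals an inconsistent BIO rather than EOF.)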
+ if (bytes_read < 0 || static_cast(bytes_read) != bytes_in_ssl) { + gpr_log(GPR_ERROR, + "Failed to read the expected number of bytes from SSL object."); + gpr_free(*bytes_remaining); + *bytes_remaining = nullptr; + return TSI_INTERNAL_ERROR; + } + *bytes_remaining_size = static_cast(bytes_read); + return TSI_OK; +} + +static tsi_result ssl_handshaker_next( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** bytes_to_send, + size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, + tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { + /* Input sanity check. */ + if ((received_bytes_size > 0 && received_bytes == nullptr) || + bytes_to_send == nullptr || bytes_to_send_size == nullptr || + handshaker_result == nullptr) { + return TSI_INVALID_ARGUMENT; + } + /* If there are received bytes, process them first. */ + tsi_ssl_handshaker* impl = reinterpret_cast(self); + tsi_result status = TSI_OK; + size_t bytes_consumed = received_bytes_size; + if (received_bytes_size > 0) { + status = ssl_handshaker_process_bytes_from_peer(impl, received_bytes, + &bytes_consumed); + if (status != TSI_OK) return status; + } + /* Get bytes to send to the peer, if available. */ + size_t offset = 0; + do { + size_t to_send_size = impl->outgoing_bytes_buffer_size - offset; + status = ssl_handshaker_get_bytes_to_send_to_peer( + impl, impl->outgoing_bytes_buffer + offset, &to_send_size); + offset += to_send_size; + if (status == TSI_INCOMPLETE_DATA) { + impl->outgoing_bytes_buffer_size *= 2; + impl->outgoing_bytes_buffer = static_cast(gpr_realloc( + impl->outgoing_bytes_buffer, impl->outgoing_bytes_buffer_size)); + } + } while (status == TSI_INCOMPLETE_DATA); + if (status != TSI_OK) return status; + *bytes_to_send = impl->outgoing_bytes_buffer; + *bytes_to_send_size = offset; + /* If handshake completes, create tsi_handshaker_result. */ + if (ssl_handshaker_get_result(impl) == TSI_HANDSHAKE_IN_PROGRESS) { + *handshaker_result = nullptr; + } else { + // Any bytes that remain in |impl->ssl|'s read BIO after the handshake is + // complete must be extracted and set to the unused bytes of the handshaker + // result. This indicates to the gRPC stack that there are bytes from the + // peer that must be processed. + unsigned char* unused_bytes = nullptr; + size_t unused_bytes_size = 0; + status = ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size); + if (status != TSI_OK) return status; + if (unused_bytes_size > received_bytes_size) { + gpr_log(GPR_ERROR, "More unused bytes than received bytes."); + gpr_free(unused_bytes); + return TSI_INTERNAL_ERROR; + } + status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size, + handshaker_result); + if (status == TSI_OK) { + /* Indicates that the handshake has completed and that a handshaker_result + * has been created. */ + self->handshaker_result_created = true; + } + } + return status; +} + +static const tsi_handshaker_vtable handshaker_vtable = { + nullptr, /* get_bytes_to_send_to_peer -- deprecated */ + nullptr, /* process_bytes_from_peer -- deprecated */ + nullptr, /* get_result -- deprecated */ + nullptr, /* extract_peer -- deprecated */ + nullptr, /* create_frame_protector -- deprecated */ + ssl_handshaker_destroy, + ssl_handshaker_next, + nullptr, /* shutdown */ +}; + +/* --- tsi_ssl_handshaker_factory common methods. 
--- */ + +static void tsi_ssl_handshaker_resume_session( + SSL* ssl, tsi::SslSessionLRUCache* session_cache) { + const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (server_name == nullptr) { + return; + } + tsi::SslSessionPtr session = session_cache->Get(server_name); + if (session != nullptr) { + // SSL_set_session internally increments reference counter. + SSL_set_session(ssl, session.get()); + } +} + +static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client, + const char* server_name_indication, + tsi_ssl_handshaker_factory* factory, + tsi_handshaker** handshaker) { + SSL* ssl = SSL_new(ctx); + BIO* network_io = nullptr; + BIO* ssl_io = nullptr; + tsi_ssl_handshaker* impl = nullptr; + *handshaker = nullptr; + if (ctx == nullptr) { + gpr_log(GPR_ERROR, "SSL Context is null. Should never happen."); + return TSI_INTERNAL_ERROR; + } + if (ssl == nullptr) { + return TSI_OUT_OF_RESOURCES; + } + SSL_set_info_callback(ssl, ssl_info_callback); + + if (!BIO_new_bio_pair(&network_io, 0, &ssl_io, 0)) { + gpr_log(GPR_ERROR, "BIO_new_bio_pair failed."); + SSL_free(ssl); + return TSI_OUT_OF_RESOURCES; + } + SSL_set_bio(ssl, ssl_io, ssl_io); + if (is_client) { + int ssl_result; + SSL_set_connect_state(ssl); + if (server_name_indication != nullptr) { + if (!SSL_set_tlsext_host_name(ssl, server_name_indication)) { + gpr_log(GPR_ERROR, "Invalid server name indication %s.", + server_name_indication); + SSL_free(ssl); + BIO_free(network_io); + return TSI_INTERNAL_ERROR; + } + } + tsi_ssl_client_handshaker_factory* client_factory = + reinterpret_cast(factory); + if (client_factory->session_cache != nullptr) { + tsi_ssl_handshaker_resume_session(ssl, + client_factory->session_cache.get()); + } + ERR_clear_error(); + ssl_result = SSL_do_handshake(ssl); + ssl_result = SSL_get_error(ssl, ssl_result); + if (ssl_result != SSL_ERROR_WANT_READ) { + gpr_log(GPR_ERROR, + "Unexpected error received from first SSL_do_handshake call: %s", + ssl_error_string(ssl_result)); + SSL_free(ssl); + BIO_free(network_io); + return TSI_INTERNAL_ERROR; + } + } else { + SSL_set_accept_state(ssl); + } + + impl = grpc_core::Zalloc(); + impl->ssl = ssl; + impl->network_io = network_io; + impl->result = TSI_HANDSHAKE_IN_PROGRESS; + impl->outgoing_bytes_buffer_size = + TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE; + impl->outgoing_bytes_buffer = + static_cast(gpr_zalloc(impl->outgoing_bytes_buffer_size)); + impl->base.vtable = &handshaker_vtable; + impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory); + *handshaker = &impl->base; + return TSI_OK; +} + +static int select_protocol_list(const unsigned char** out, + unsigned char* outlen, + const unsigned char* client_list, + size_t client_list_len, + const unsigned char* server_list, + size_t server_list_len) { + const unsigned char* client_current = client_list; + while (static_cast(client_current - client_list) < + client_list_len) { + unsigned char client_current_len = *(client_current++); + const unsigned char* server_current = server_list; + while ((server_current >= server_list) && + static_cast(server_current - server_list) < + server_list_len) { + unsigned char server_current_len = *(server_current++); + if ((client_current_len == server_current_len) && + !memcmp(client_current, server_current, server_current_len)) { + *out = server_current; + *outlen = server_current_len; + return SSL_TLSEXT_ERR_OK; + } + server_current += server_current_len; + } + client_current += client_current_len; + } + return SSL_TLSEXT_ERR_NOACK; +} + 
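/* Illustrative sketch, not part of this change: one way a caller might drive a
   handshaker obtained from the client or server factories below. The transport
   helpers |send_all| and |recv_some| are hypothetical placeholders for a real
   byte stream; the TSI calls used are the public wrappers defined later in
   this patch (transport_security.cc), backed by the SSL implementations
   above. */
static tsi_result example_drive_handshake(
    tsi_handshaker* handshaker,
    bool (*send_all)(const unsigned char* data, size_t size),
    size_t (*recv_some)(unsigned char* buf, size_t capacity),
    tsi_handshaker_result** result_out) {
  unsigned char recv_buf[1024];
  size_t received = 0;
  tsi_handshaker_result* result = nullptr;
  while (result == nullptr) {
    const unsigned char* to_send = nullptr;
    size_t to_send_size = 0;
    // ssl_handshaker_next() completes synchronously, so the callback and
    // user_data arguments are unused and the status is returned directly.
    tsi_result status = tsi_handshaker_next(
        handshaker, received == 0 ? nullptr : recv_buf, received, &to_send,
        &to_send_size, &result, /*cb=*/nullptr, /*user_data=*/nullptr);
    if (status != TSI_OK) return status;
    if (to_send_size > 0 && !send_all(to_send, to_send_size)) {
      return TSI_INTERNAL_ERROR;
    }
    if (result == nullptr) {
      // Handshake still in progress: wait for more bytes from the peer.
      received = recv_some(recv_buf, sizeof(recv_buf));
      if (received == 0) return TSI_INTERNAL_ERROR;
    }
  }
  // The result then yields the authenticated peer and a frame protector:
  //   tsi_peer peer;
  //   tsi_handshaker_result_extract_peer(result, &peer);
  //   tsi_frame_protector* protector = nullptr;
  //   tsi_handshaker_result_create_frame_protector(result, nullptr, &protector);
  // followed by tsi_peer_destruct(&peer) and, when the connection closes,
  // tsi_frame_protector_destroy() and tsi_handshaker_result_destroy().
  *result_out = result;
  return TSI_OK;
}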
+/* --- tsi_ssl_client_handshaker_factory methods implementation. --- */ + +tsi_result tsi_ssl_client_handshaker_factory_create_handshaker( + tsi_ssl_client_handshaker_factory* factory, + const char* server_name_indication, tsi_handshaker** handshaker) { + return create_tsi_ssl_handshaker(factory->ssl_context, 1, + server_name_indication, &factory->base, + handshaker); +} + +void tsi_ssl_client_handshaker_factory_unref( + tsi_ssl_client_handshaker_factory* factory) { + if (factory == nullptr) return; + tsi_ssl_handshaker_factory_unref(&factory->base); +} + +static void tsi_ssl_client_handshaker_factory_destroy( + tsi_ssl_handshaker_factory* factory) { + if (factory == nullptr) return; + tsi_ssl_client_handshaker_factory* self = + reinterpret_cast(factory); + if (self->ssl_context != nullptr) SSL_CTX_free(self->ssl_context); + if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list); + self->session_cache.reset(); + gpr_free(self); +} + +static int client_handshaker_factory_npn_callback( + SSL* /*ssl*/, unsigned char** out, unsigned char* outlen, + const unsigned char* in, unsigned int inlen, void* arg) { + tsi_ssl_client_handshaker_factory* factory = + static_cast(arg); + return select_protocol_list(const_cast(out), outlen, + factory->alpn_protocol_list, + factory->alpn_protocol_list_length, in, inlen); +} + +/* --- tsi_ssl_server_handshaker_factory methods implementation. --- */ + +tsi_result tsi_ssl_server_handshaker_factory_create_handshaker( + tsi_ssl_server_handshaker_factory* factory, tsi_handshaker** handshaker) { + if (factory->ssl_context_count == 0) return TSI_INVALID_ARGUMENT; + /* Create the handshaker with the first context. We will switch if needed + because of SNI in ssl_server_handshaker_factory_servername_callback. */ + return create_tsi_ssl_handshaker(factory->ssl_contexts[0], 0, nullptr, + &factory->base, handshaker); +} + +void tsi_ssl_server_handshaker_factory_unref( + tsi_ssl_server_handshaker_factory* factory) { + if (factory == nullptr) return; + tsi_ssl_handshaker_factory_unref(&factory->base); +} + +static void tsi_ssl_server_handshaker_factory_destroy( + tsi_ssl_handshaker_factory* factory) { + if (factory == nullptr) return; + tsi_ssl_server_handshaker_factory* self = + reinterpret_cast(factory); + size_t i; + for (i = 0; i < self->ssl_context_count; i++) { + if (self->ssl_contexts[i] != nullptr) { + SSL_CTX_free(self->ssl_contexts[i]); + tsi_peer_destruct(&self->ssl_context_x509_subject_names[i]); + } + } + if (self->ssl_contexts != nullptr) gpr_free(self->ssl_contexts); + if (self->ssl_context_x509_subject_names != nullptr) { + gpr_free(self->ssl_context_x509_subject_names); + } + if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list); + gpr_free(self); +} + +static int does_entry_match_name(absl::string_view entry, + absl::string_view name) { + if (entry.empty()) return 0; + + /* Take care of '.' terminations. */ + if (name.back() == '.') { + name.remove_suffix(1); + } + if (entry.back() == '.') { + entry.remove_suffix(1); + if (entry.empty()) return 0; + } + + if (absl::EqualsIgnoreCase(name, entry)) { + return 1; /* Perfect match. */ + } + if (entry.front() != '*') return 0; + + /* Wildchar subdomain matching. 
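   Only a wildcard that forms the entire leftmost label is accepted (the entry
   must begin with "*."), and it matches exactly one label: "*.example.com"
   matches "foo.example.com" but not "a.b.example.com" or "example.com" itself.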
*/ + if (entry.size() < 3 || entry[1] != '.') { /* At least *.x */ + gpr_log(GPR_ERROR, "Invalid wildchar entry."); + return 0; + } + size_t name_subdomain_pos = name.find('.'); + if (name_subdomain_pos == absl::string_view::npos) return 0; + if (name_subdomain_pos >= name.size() - 2) return 0; + absl::string_view name_subdomain = + name.substr(name_subdomain_pos + 1); /* Starts after the dot. */ + entry.remove_prefix(2); /* Remove *. */ + size_t dot = name_subdomain.find('.'); + if (dot == absl::string_view::npos || dot == name_subdomain.size() - 1) { + gpr_log(GPR_ERROR, "Invalid toplevel subdomain: %s", + std::string(name_subdomain).c_str()); + return 0; + } + if (name_subdomain.back() == '.') { + name_subdomain.remove_suffix(1); + } + return !entry.empty() && absl::EqualsIgnoreCase(name_subdomain, entry); +} + +static int ssl_server_handshaker_factory_servername_callback(SSL* ssl, + int* /*ap*/, + void* arg) { + tsi_ssl_server_handshaker_factory* impl = + static_cast(arg); + size_t i = 0; + const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (servername == nullptr || strlen(servername) == 0) { + return SSL_TLSEXT_ERR_NOACK; + } + + for (i = 0; i < impl->ssl_context_count; i++) { + if (tsi_ssl_peer_matches_name(&impl->ssl_context_x509_subject_names[i], + servername)) { + SSL_set_SSL_CTX(ssl, impl->ssl_contexts[i]); + return SSL_TLSEXT_ERR_OK; + } + } + gpr_log(GPR_ERROR, "No match found for server name: %s.", servername); + return SSL_TLSEXT_ERR_NOACK; +} + +#if TSI_OPENSSL_ALPN_SUPPORT +static int server_handshaker_factory_alpn_callback( + SSL* /*ssl*/, const unsigned char** out, unsigned char* outlen, + const unsigned char* in, unsigned int inlen, void* arg) { + tsi_ssl_server_handshaker_factory* factory = + static_cast(arg); + return select_protocol_list(out, outlen, in, inlen, + factory->alpn_protocol_list, + factory->alpn_protocol_list_length); +} +#endif /* TSI_OPENSSL_ALPN_SUPPORT */ + +static int server_handshaker_factory_npn_advertised_callback( + SSL* /*ssl*/, const unsigned char** out, unsigned int* outlen, void* arg) { + tsi_ssl_server_handshaker_factory* factory = + static_cast(arg); + *out = factory->alpn_protocol_list; + GPR_ASSERT(factory->alpn_protocol_list_length <= UINT_MAX); + *outlen = static_cast(factory->alpn_protocol_list_length); + return SSL_TLSEXT_ERR_OK; +} + +/// This callback is called when new \a session is established and ready to +/// be cached. This session can be reused for new connections to similar +/// servers at later point of time. +/// It's intended to be used with SSL_CTX_sess_set_new_cb function. +/// +/// It returns 1 if callback takes ownership over \a session and 0 otherwise. +static int server_handshaker_factory_new_session_callback( + SSL* ssl, SSL_SESSION* session) { + SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl); + if (ssl_context == nullptr) { + return 0; + } + void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index); + tsi_ssl_client_handshaker_factory* factory = + static_cast(arg); + const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (server_name == nullptr) { + return 0; + } + factory->session_cache->Put(server_name, tsi::SslSessionPtr(session)); + // Return 1 to indicate transferred ownership over the given session. + return 1; +} + +/* --- tsi_ssl_handshaker_factory constructors. 
--- */ + +static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = { + tsi_ssl_client_handshaker_factory_destroy}; + +tsi_result tsi_create_ssl_client_handshaker_factory( + const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair, + const char* pem_root_certs, const char* cipher_suites, + const char** alpn_protocols, uint16_t num_alpn_protocols, + tsi_ssl_client_handshaker_factory** factory) { + tsi_ssl_client_handshaker_options options; + options.pem_key_cert_pair = pem_key_cert_pair; + options.pem_root_certs = pem_root_certs; + options.cipher_suites = cipher_suites; + options.alpn_protocols = alpn_protocols; + options.num_alpn_protocols = num_alpn_protocols; + return tsi_create_ssl_client_handshaker_factory_with_options(&options, + factory); +} + +tsi_result tsi_create_ssl_client_handshaker_factory_with_options( + const tsi_ssl_client_handshaker_options* options, + tsi_ssl_client_handshaker_factory** factory) { + SSL_CTX* ssl_context = nullptr; + tsi_ssl_client_handshaker_factory* impl = nullptr; + tsi_result result = TSI_OK; + + gpr_once_init(&g_init_openssl_once, init_openssl); + + if (factory == nullptr) return TSI_INVALID_ARGUMENT; + *factory = nullptr; + if (options->pem_root_certs == nullptr && options->root_store == nullptr) { + return TSI_INVALID_ARGUMENT; + } + +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + ssl_context = SSL_CTX_new(TLS_method()); +#else + ssl_context = SSL_CTX_new(TLSv1_2_method()); +#endif + if (ssl_context == nullptr) { + log_ssl_error_stack(); + gpr_log(GPR_ERROR, "Could not create ssl context."); + return TSI_INVALID_ARGUMENT; + } + + result = tsi_set_min_and_max_tls_versions( + ssl_context, options->min_tls_version, options->max_tls_version); + if (result != TSI_OK) return result; + + impl = static_cast( + gpr_zalloc(sizeof(*impl))); + tsi_ssl_handshaker_factory_init(&impl->base); + impl->base.vtable = &client_handshaker_factory_vtable; + impl->ssl_context = ssl_context; + if (options->session_cache != nullptr) { + // Unref is called manually on factory destruction. + impl->session_cache = + reinterpret_cast(options->session_cache) + ->Ref(); + SSL_CTX_set_ex_data(ssl_context, g_ssl_ctx_ex_factory_index, impl); + SSL_CTX_sess_set_new_cb(ssl_context, + server_handshaker_factory_new_session_callback); + SSL_CTX_set_session_cache_mode(ssl_context, SSL_SESS_CACHE_CLIENT); + } + + do { + result = populate_ssl_context(ssl_context, options->pem_key_cert_pair, + options->cipher_suites); + if (result != TSI_OK) break; + +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + // X509_STORE_up_ref is only available since OpenSSL 1.1. 
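      // When a caller supplies a pre-built |root_store|, bumping its reference
      // count and installing it with SSL_CTX_set_cert_store() reuses the
      // already-parsed roots; otherwise (or on pre-1.1 OpenSSL) the PEM root
      // certificates are loaded below instead.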
+ if (options->root_store != nullptr) { + X509_STORE_up_ref(options->root_store->store); + SSL_CTX_set_cert_store(ssl_context, options->root_store->store); + } +#endif + if (OPENSSL_VERSION_NUMBER < 0x10100000 || options->root_store == nullptr) { + result = ssl_ctx_load_verification_certs( + ssl_context, options->pem_root_certs, strlen(options->pem_root_certs), + nullptr); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Cannot load server root certificates."); + break; + } + } + + if (options->num_alpn_protocols != 0) { + result = build_alpn_protocol_name_list( + options->alpn_protocols, options->num_alpn_protocols, + &impl->alpn_protocol_list, &impl->alpn_protocol_list_length); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Building alpn list failed with error %s.", + tsi_result_to_string(result)); + break; + } +#if TSI_OPENSSL_ALPN_SUPPORT + GPR_ASSERT(impl->alpn_protocol_list_length < UINT_MAX); + if (SSL_CTX_set_alpn_protos( + ssl_context, impl->alpn_protocol_list, + static_cast(impl->alpn_protocol_list_length))) { + gpr_log(GPR_ERROR, "Could not set alpn protocol list to context."); + result = TSI_INVALID_ARGUMENT; + break; + } +#endif /* TSI_OPENSSL_ALPN_SUPPORT */ + SSL_CTX_set_next_proto_select_cb( + ssl_context, client_handshaker_factory_npn_callback, impl); + } + } while (false); + if (result != TSI_OK) { + tsi_ssl_handshaker_factory_unref(&impl->base); + return result; + } + if (options->skip_server_certificate_verification) { + SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NullVerifyCallback); + } else { + SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, nullptr); + } + /* TODO(jboeuf): Add revocation verification. */ + + *factory = impl; + return TSI_OK; +} + +static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = { + tsi_ssl_server_handshaker_factory_destroy}; + +tsi_result tsi_create_ssl_server_handshaker_factory( + const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs, + size_t num_key_cert_pairs, const char* pem_client_root_certs, + int force_client_auth, const char* cipher_suites, + const char** alpn_protocols, uint16_t num_alpn_protocols, + tsi_ssl_server_handshaker_factory** factory) { + return tsi_create_ssl_server_handshaker_factory_ex( + pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs, + force_client_auth ? 
TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY + : TSI_DONT_REQUEST_CLIENT_CERTIFICATE, + cipher_suites, alpn_protocols, num_alpn_protocols, factory); +} + +tsi_result tsi_create_ssl_server_handshaker_factory_ex( + const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs, + size_t num_key_cert_pairs, const char* pem_client_root_certs, + tsi_client_certificate_request_type client_certificate_request, + const char* cipher_suites, const char** alpn_protocols, + uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) { + tsi_ssl_server_handshaker_options options; + options.pem_key_cert_pairs = pem_key_cert_pairs; + options.num_key_cert_pairs = num_key_cert_pairs; + options.pem_client_root_certs = pem_client_root_certs; + options.client_certificate_request = client_certificate_request; + options.cipher_suites = cipher_suites; + options.alpn_protocols = alpn_protocols; + options.num_alpn_protocols = num_alpn_protocols; + return tsi_create_ssl_server_handshaker_factory_with_options(&options, + factory); +} + +tsi_result tsi_create_ssl_server_handshaker_factory_with_options( + const tsi_ssl_server_handshaker_options* options, + tsi_ssl_server_handshaker_factory** factory) { + tsi_ssl_server_handshaker_factory* impl = nullptr; + tsi_result result = TSI_OK; + size_t i = 0; + + gpr_once_init(&g_init_openssl_once, init_openssl); + + if (factory == nullptr) return TSI_INVALID_ARGUMENT; + *factory = nullptr; + if (options->num_key_cert_pairs == 0 || + options->pem_key_cert_pairs == nullptr) { + return TSI_INVALID_ARGUMENT; + } + + impl = static_cast( + gpr_zalloc(sizeof(*impl))); + tsi_ssl_handshaker_factory_init(&impl->base); + impl->base.vtable = &server_handshaker_factory_vtable; + + impl->ssl_contexts = static_cast( + gpr_zalloc(options->num_key_cert_pairs * sizeof(SSL_CTX*))); + impl->ssl_context_x509_subject_names = static_cast( + gpr_zalloc(options->num_key_cert_pairs * sizeof(tsi_peer))); + if (impl->ssl_contexts == nullptr || + impl->ssl_context_x509_subject_names == nullptr) { + tsi_ssl_handshaker_factory_unref(&impl->base); + return TSI_OUT_OF_RESOURCES; + } + impl->ssl_context_count = options->num_key_cert_pairs; + + if (options->num_alpn_protocols > 0) { + result = build_alpn_protocol_name_list( + options->alpn_protocols, options->num_alpn_protocols, + &impl->alpn_protocol_list, &impl->alpn_protocol_list_length); + if (result != TSI_OK) { + tsi_ssl_handshaker_factory_unref(&impl->base); + return result; + } + } + + for (i = 0; i < options->num_key_cert_pairs; i++) { + do { +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + impl->ssl_contexts[i] = SSL_CTX_new(TLS_method()); +#else + impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method()); +#endif + if (impl->ssl_contexts[i] == nullptr) { + log_ssl_error_stack(); + gpr_log(GPR_ERROR, "Could not create ssl context."); + result = TSI_OUT_OF_RESOURCES; + break; + } + + result = tsi_set_min_and_max_tls_versions(impl->ssl_contexts[i], + options->min_tls_version, + options->max_tls_version); + if (result != TSI_OK) return result; + + result = populate_ssl_context(impl->ssl_contexts[i], + &options->pem_key_cert_pairs[i], + options->cipher_suites); + if (result != TSI_OK) break; + + // TODO(elessar): Provide ability to disable session ticket keys. + + // Allow client cache sessions (it's needed for OpenSSL only). 
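      // OpenSSL scopes cached sessions to a session id context; setting one
      // here allows sessions created from this SSL_CTX to be resumed, and a
      // resumption attempt whose stored context does not match is rejected.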
+ int set_sid_ctx_result = SSL_CTX_set_session_id_context( + impl->ssl_contexts[i], kSslSessionIdContext, + GPR_ARRAY_SIZE(kSslSessionIdContext)); + if (set_sid_ctx_result == 0) { + gpr_log(GPR_ERROR, "Failed to set session id context."); + result = TSI_INTERNAL_ERROR; + break; + } + + if (options->session_ticket_key != nullptr) { + if (SSL_CTX_set_tlsext_ticket_keys( + impl->ssl_contexts[i], + const_cast(options->session_ticket_key), + options->session_ticket_key_size) == 0) { + gpr_log(GPR_ERROR, "Invalid STEK size."); + result = TSI_INVALID_ARGUMENT; + break; + } + } + + if (options->pem_client_root_certs != nullptr) { + STACK_OF(X509_NAME)* root_names = nullptr; + result = ssl_ctx_load_verification_certs( + impl->ssl_contexts[i], options->pem_client_root_certs, + strlen(options->pem_client_root_certs), &root_names); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Invalid verification certs."); + break; + } + SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names); + } + switch (options->client_certificate_request) { + case TSI_DONT_REQUEST_CLIENT_CERTIFICATE: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr); + break; + case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, + NullVerifyCallback); + break; + case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr); + break; + case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + NullVerifyCallback); + break; + case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + nullptr); + break; + } + /* TODO(jboeuf): Add revocation verification. */ + + result = tsi_ssl_extract_x509_subject_names_from_pem_cert( + options->pem_key_cert_pairs[i].cert_chain, + &impl->ssl_context_x509_subject_names[i]); + if (result != TSI_OK) break; + + SSL_CTX_set_tlsext_servername_callback( + impl->ssl_contexts[i], + ssl_server_handshaker_factory_servername_callback); + SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl); +#if TSI_OPENSSL_ALPN_SUPPORT + SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i], + server_handshaker_factory_alpn_callback, impl); +#endif /* TSI_OPENSSL_ALPN_SUPPORT */ + SSL_CTX_set_next_protos_advertised_cb( + impl->ssl_contexts[i], + server_handshaker_factory_npn_advertised_callback, impl); + } while (false); + + if (result != TSI_OK) { + tsi_ssl_handshaker_factory_unref(&impl->base); + return result; + } + } + + *factory = impl; + return TSI_OK; +} + +/* --- tsi_ssl utils. --- */ + +int tsi_ssl_peer_matches_name(const tsi_peer* peer, absl::string_view name) { + size_t i = 0; + size_t san_count = 0; + const tsi_peer_property* cn_property = nullptr; + int like_ip = looks_like_ip_address(name); + + /* Check the SAN first. */ + for (i = 0; i < peer->property_count; i++) { + const tsi_peer_property* property = &peer->properties[i]; + if (property->name == nullptr) continue; + if (strcmp(property->name, + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) { + san_count++; + + absl::string_view entry(property->value.data, property->value.length); + if (!like_ip && does_entry_match_name(entry, name)) { + return 1; + } else if (like_ip && name == entry) { + /* IP Addresses are exact matches only. 
*/ + return 1; + } + } else if (strcmp(property->name, + TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) { + cn_property = property; + } + } + + /* If there's no SAN, try the CN, but only if its not like an IP Address */ + if (san_count == 0 && cn_property != nullptr && !like_ip) { + if (does_entry_match_name(absl::string_view(cn_property->value.data, + cn_property->value.length), + name)) { + return 1; + } + } + + return 0; /* Not found. */ +} + +/* --- Testing support. --- */ +const tsi_ssl_handshaker_factory_vtable* tsi_ssl_handshaker_factory_swap_vtable( + tsi_ssl_handshaker_factory* factory, + tsi_ssl_handshaker_factory_vtable* new_vtable) { + GPR_ASSERT(factory != nullptr); + GPR_ASSERT(factory->vtable != nullptr); + + const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable; + factory->vtable = new_vtable; + return orig_vtable; +} diff --git a/src/core/tsi/transport_security.cc b/src/core/tsi/transport_security.cc new file mode 100644 index 00000000..5d822604 --- /dev/null +++ b/src/core/tsi/transport_security.cc @@ -0,0 +1,384 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/transport_security.h" + +#include +#include + +#include +#include + +/* --- Tracing. --- */ + +grpc_core::TraceFlag tsi_tracing_enabled(false, "tsi"); + +/* --- tsi_result common implementation. --- */ + +const char* tsi_result_to_string(tsi_result result) { + switch (result) { + case TSI_OK: + return "TSI_OK"; + case TSI_UNKNOWN_ERROR: + return "TSI_UNKNOWN_ERROR"; + case TSI_INVALID_ARGUMENT: + return "TSI_INVALID_ARGUMENT"; + case TSI_PERMISSION_DENIED: + return "TSI_PERMISSION_DENIED"; + case TSI_INCOMPLETE_DATA: + return "TSI_INCOMPLETE_DATA"; + case TSI_FAILED_PRECONDITION: + return "TSI_FAILED_PRECONDITION"; + case TSI_UNIMPLEMENTED: + return "TSI_UNIMPLEMENTED"; + case TSI_INTERNAL_ERROR: + return "TSI_INTERNAL_ERROR"; + case TSI_DATA_CORRUPTED: + return "TSI_DATA_CORRUPTED"; + case TSI_NOT_FOUND: + return "TSI_NOT_FOUND"; + case TSI_PROTOCOL_FAILURE: + return "TSI_PROTOCOL_FAILURE"; + case TSI_HANDSHAKE_IN_PROGRESS: + return "TSI_HANDSHAKE_IN_PROGRESS"; + case TSI_OUT_OF_RESOURCES: + return "TSI_OUT_OF_RESOURCES"; + case TSI_ASYNC: + return "TSI_ASYNC"; + default: + return "UNKNOWN"; + } +} + +const char* tsi_security_level_to_string(tsi_security_level security_level) { + switch (security_level) { + case TSI_SECURITY_NONE: + return "TSI_SECURITY_NONE"; + case TSI_INTEGRITY_ONLY: + return "TSI_INTEGRITY_ONLY"; + case TSI_PRIVACY_AND_INTEGRITY: + return "TSI_PRIVACY_AND_INTEGRITY"; + default: + return "UNKNOWN"; + } +} + +/* --- tsi_frame_protector common implementation. --- + + Calls specific implementation after state/input validation. 
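   Each wrapper below returns TSI_INVALID_ARGUMENT when required arguments are
   null and TSI_UNIMPLEMENTED when the corresponding vtable slot is null, so
   implementations may leave optional methods unset.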
*/ + +tsi_result tsi_frame_protector_protect(tsi_frame_protector* self, + const unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size, + unsigned char* protected_output_frames, + size_t* protected_output_frames_size) { + if (self == nullptr || self->vtable == nullptr || + unprotected_bytes == nullptr || unprotected_bytes_size == nullptr || + protected_output_frames == nullptr || + protected_output_frames_size == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->protect == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->protect(self, unprotected_bytes, unprotected_bytes_size, + protected_output_frames, + protected_output_frames_size); +} + +tsi_result tsi_frame_protector_protect_flush( + tsi_frame_protector* self, unsigned char* protected_output_frames, + size_t* protected_output_frames_size, size_t* still_pending_size) { + if (self == nullptr || self->vtable == nullptr || + protected_output_frames == nullptr || + protected_output_frames_size == nullptr || + still_pending_size == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->protect_flush == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->protect_flush(self, protected_output_frames, + protected_output_frames_size, + still_pending_size); +} + +tsi_result tsi_frame_protector_unprotect( + tsi_frame_protector* self, const unsigned char* protected_frames_bytes, + size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes, + size_t* unprotected_bytes_size) { + if (self == nullptr || self->vtable == nullptr || + protected_frames_bytes == nullptr || + protected_frames_bytes_size == nullptr || unprotected_bytes == nullptr || + unprotected_bytes_size == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->unprotect == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->unprotect(self, protected_frames_bytes, + protected_frames_bytes_size, unprotected_bytes, + unprotected_bytes_size); +} + +void tsi_frame_protector_destroy(tsi_frame_protector* self) { + if (self == nullptr) return; + self->vtable->destroy(self); +} + +/* --- tsi_handshaker common implementation. --- + + Calls specific implementation after state/input validation. 
*/ + +tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker* self, + unsigned char* bytes, + size_t* bytes_size) { + if (self == nullptr || self->vtable == nullptr || bytes == nullptr || + bytes_size == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; + if (self->vtable->get_bytes_to_send_to_peer == nullptr) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->get_bytes_to_send_to_peer(self, bytes, bytes_size); +} + +tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker* self, + const unsigned char* bytes, + size_t* bytes_size) { + if (self == nullptr || self->vtable == nullptr || bytes == nullptr || + bytes_size == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; + if (self->vtable->process_bytes_from_peer == nullptr) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->process_bytes_from_peer(self, bytes, bytes_size); +} + +tsi_result tsi_handshaker_get_result(tsi_handshaker* self) { + if (self == nullptr || self->vtable == nullptr) return TSI_INVALID_ARGUMENT; + if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; + if (self->vtable->get_result == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->get_result(self); +} + +tsi_result tsi_handshaker_extract_peer(tsi_handshaker* self, tsi_peer* peer) { + if (self == nullptr || self->vtable == nullptr || peer == nullptr) { + return TSI_INVALID_ARGUMENT; + } + memset(peer, 0, sizeof(tsi_peer)); + if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; + if (tsi_handshaker_get_result(self) != TSI_OK) { + return TSI_FAILED_PRECONDITION; + } + if (self->vtable->extract_peer == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->extract_peer(self, peer); +} + +tsi_result tsi_handshaker_create_frame_protector( + tsi_handshaker* self, size_t* max_output_protected_frame_size, + tsi_frame_protector** protector) { + tsi_result result; + if (self == nullptr || self->vtable == nullptr || protector == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; + if (tsi_handshaker_get_result(self) != TSI_OK) return TSI_FAILED_PRECONDITION; + if (self->vtable->create_frame_protector == nullptr) return TSI_UNIMPLEMENTED; + result = self->vtable->create_frame_protector( + self, max_output_protected_frame_size, protector); + if (result == TSI_OK) { + self->frame_protector_created = true; + } + return result; +} + +tsi_result tsi_handshaker_next( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** bytes_to_send, + size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, + tsi_handshaker_on_next_done_cb cb, void* user_data) { + if (self == nullptr || self->vtable == nullptr) return TSI_INVALID_ARGUMENT; + if (self->handshaker_result_created) return TSI_FAILED_PRECONDITION; + if (self->handshake_shutdown) return TSI_HANDSHAKE_SHUTDOWN; + if (self->vtable->next == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->next(self, received_bytes, received_bytes_size, + bytes_to_send, bytes_to_send_size, + handshaker_result, 
cb, user_data); +} + +void tsi_handshaker_shutdown(tsi_handshaker* self) { + if (self == nullptr || self->vtable == nullptr) return; + if (self->vtable->shutdown != nullptr) { + self->vtable->shutdown(self); + } + self->handshake_shutdown = true; +} + +void tsi_handshaker_destroy(tsi_handshaker* self) { + if (self == nullptr) return; + self->vtable->destroy(self); +} + +/* --- tsi_handshaker_result implementation. --- */ + +tsi_result tsi_handshaker_result_extract_peer(const tsi_handshaker_result* self, + tsi_peer* peer) { + if (self == nullptr || self->vtable == nullptr || peer == nullptr) { + return TSI_INVALID_ARGUMENT; + } + memset(peer, 0, sizeof(tsi_peer)); + if (self->vtable->extract_peer == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->extract_peer(self, peer); +} + +tsi_result tsi_handshaker_result_get_frame_protector_type( + const tsi_handshaker_result* self, + tsi_frame_protector_type* frame_protector_type) { + if (self == nullptr || frame_protector_type == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->get_frame_protector_type == nullptr) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->get_frame_protector_type(self, frame_protector_type); +} + +tsi_result tsi_handshaker_result_create_frame_protector( + const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, + tsi_frame_protector** protector) { + if (self == nullptr || self->vtable == nullptr || protector == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->create_frame_protector == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->create_frame_protector( + self, max_output_protected_frame_size, protector); +} + +tsi_result tsi_handshaker_result_get_unused_bytes( + const tsi_handshaker_result* self, const unsigned char** bytes, + size_t* bytes_size) { + if (self == nullptr || self->vtable == nullptr || bytes == nullptr || + bytes_size == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->get_unused_bytes == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->get_unused_bytes(self, bytes, bytes_size); +} + +void tsi_handshaker_result_destroy(tsi_handshaker_result* self) { + if (self == nullptr) return; + self->vtable->destroy(self); +} + +/* --- tsi_peer implementation. --- */ + +tsi_peer_property tsi_init_peer_property(void) { + tsi_peer_property property; + memset(&property, 0, sizeof(tsi_peer_property)); + return property; +} + +static void tsi_peer_destroy_list_property(tsi_peer_property* children, + size_t child_count) { + size_t i; + for (i = 0; i < child_count; i++) { + tsi_peer_property_destruct(&children[i]); + } + gpr_free(children); +} + +void tsi_peer_property_destruct(tsi_peer_property* property) { + if (property->name != nullptr) { + gpr_free(property->name); + } + if (property->value.data != nullptr) { + gpr_free(property->value.data); + } + *property = tsi_init_peer_property(); /* Reset everything to 0. 
*/ +} + +void tsi_peer_destruct(tsi_peer* self) { + if (self == nullptr) return; + if (self->properties != nullptr) { + tsi_peer_destroy_list_property(self->properties, self->property_count); + self->properties = nullptr; + } + self->property_count = 0; +} + +tsi_result tsi_construct_allocated_string_peer_property( + const char* name, size_t value_length, tsi_peer_property* property) { + *property = tsi_init_peer_property(); + if (name != nullptr) property->name = gpr_strdup(name); + if (value_length > 0) { + property->value.data = static_cast(gpr_zalloc(value_length)); + property->value.length = value_length; + } + return TSI_OK; +} + +tsi_result tsi_construct_string_peer_property_from_cstring( + const char* name, const char* value, tsi_peer_property* property) { + return tsi_construct_string_peer_property(name, value, strlen(value), + property); +} + +tsi_result tsi_construct_string_peer_property(const char* name, + const char* value, + size_t value_length, + tsi_peer_property* property) { + tsi_result result = tsi_construct_allocated_string_peer_property( + name, value_length, property); + if (result != TSI_OK) return result; + if (value_length > 0) { + memcpy(property->value.data, value, value_length); + } + return TSI_OK; +} + +tsi_result tsi_construct_peer(size_t property_count, tsi_peer* peer) { + memset(peer, 0, sizeof(tsi_peer)); + if (property_count > 0) { + peer->properties = static_cast( + gpr_zalloc(property_count * sizeof(tsi_peer_property))); + peer->property_count = property_count; + } + return TSI_OK; +} + +const tsi_peer_property* tsi_peer_get_property_by_name(const tsi_peer* peer, + const char* name) { + size_t i; + if (peer == nullptr) return nullptr; + for (i = 0; i < peer->property_count; i++) { + const tsi_peer_property* property = &peer->properties[i]; + if (name == nullptr && property->name == nullptr) { + return property; + } + if (name != nullptr && property->name != nullptr && + strcmp(property->name, name) == 0) { + return property; + } + } + return nullptr; +} diff --git a/src/core/tsi/transport_security_grpc.cc b/src/core/tsi/transport_security_grpc.cc new file mode 100644 index 00000000..cec87269 --- /dev/null +++ b/src/core/tsi/transport_security_grpc.cc @@ -0,0 +1,73 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/tsi/transport_security_grpc.h" + +/* This method creates a tsi_zero_copy_grpc_protector object. 
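   The zero-copy variant works directly on grpc_slice_buffer instances so an
   implementation can move slices instead of copying through flat byte buffers;
   handshaker results that do not provide it (for example the SSL handshaker
   result earlier in this patch, whose create_zero_copy_grpc_protector slot is
   null) make this call return TSI_UNIMPLEMENTED.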
*/ +tsi_result tsi_handshaker_result_create_zero_copy_grpc_protector( + const tsi_handshaker_result* self, size_t* max_output_protected_frame_size, + tsi_zero_copy_grpc_protector** protector) { + if (self == nullptr || self->vtable == nullptr || protector == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->create_zero_copy_grpc_protector == nullptr) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->create_zero_copy_grpc_protector( + self, max_output_protected_frame_size, protector); +} + +/* --- tsi_zero_copy_grpc_protector common implementation. --- + + Calls specific implementation after state/input validation. */ + +tsi_result tsi_zero_copy_grpc_protector_protect( + tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices, + grpc_slice_buffer* protected_slices) { + if (self == nullptr || self->vtable == nullptr || + unprotected_slices == nullptr || protected_slices == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->protect == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->protect(self, unprotected_slices, protected_slices); +} + +tsi_result tsi_zero_copy_grpc_protector_unprotect( + tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices, + grpc_slice_buffer* unprotected_slices) { + if (self == nullptr || self->vtable == nullptr || + protected_slices == nullptr || unprotected_slices == nullptr) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->unprotect == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->unprotect(self, protected_slices, unprotected_slices); +} + +void tsi_zero_copy_grpc_protector_destroy(tsi_zero_copy_grpc_protector* self) { + if (self == nullptr) return; + self->vtable->destroy(self); +} + +tsi_result tsi_zero_copy_grpc_protector_max_frame_size( + tsi_zero_copy_grpc_protector* self, size_t* max_frame_size) { + if (self == nullptr || max_frame_size == nullptr) return TSI_INVALID_ARGUMENT; + if (self->vtable->max_frame_size == nullptr) return TSI_UNIMPLEMENTED; + return self->vtable->max_frame_size(self, max_frame_size); +} diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc new file mode 100644 index 00000000..1702dc56 --- /dev/null +++ b/src/cpp/client/channel_cc.cc @@ -0,0 +1,275 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/surface/completion_queue.h" + +namespace grpc { + +static ::grpc::internal::GrpcLibraryInitializer g_gli_initializer; +Channel::Channel(const std::string& host, grpc_channel* channel, + std::vector> + interceptor_creators) + : host_(host), c_channel_(channel) { + interceptor_creators_ = std::move(interceptor_creators); + g_gli_initializer.summon(); +} + +Channel::~Channel() { + grpc_channel_destroy(c_channel_); + CompletionQueue* callback_cq = callback_cq_.load(std::memory_order_relaxed); + if (callback_cq != nullptr) { + if (grpc_iomgr_run_in_background()) { + // gRPC-core provides the backing needed for the preferred CQ type + callback_cq->Shutdown(); + } else { + CompletionQueue::ReleaseCallbackAlternativeCQ(callback_cq); + } + } +} + +namespace { + +inline grpc_slice SliceFromArray(const char* arr, size_t len) { + return g_core_codegen_interface->grpc_slice_from_copied_buffer(arr, len); +} + +std::string GetChannelInfoField(grpc_channel* channel, + grpc_channel_info* channel_info, + char*** channel_info_field) { + char* value = nullptr; + memset(channel_info, 0, sizeof(*channel_info)); + *channel_info_field = &value; + grpc_channel_get_info(channel, channel_info); + if (value == nullptr) return ""; + std::string result = value; + gpr_free(value); + return result; +} + +} // namespace + +std::string Channel::GetLoadBalancingPolicyName() const { + grpc_channel_info channel_info; + return GetChannelInfoField(c_channel_, &channel_info, + &channel_info.lb_policy_name); +} + +std::string Channel::GetServiceConfigJSON() const { + grpc_channel_info channel_info; + return GetChannelInfoField(c_channel_, &channel_info, + &channel_info.service_config_json); +} + +namespace experimental { + +void ChannelResetConnectionBackoff(Channel* channel) { + grpc_channel_reset_connect_backoff(channel->c_channel_); +} + +} // namespace experimental + +::grpc::internal::Call Channel::CreateCallInternal( + const ::grpc::internal::RpcMethod& method, ::grpc::ClientContext* context, + ::grpc::CompletionQueue* cq, size_t interceptor_pos) { + const bool kRegistered = method.channel_tag() && context->authority().empty(); + grpc_call* c_call = nullptr; + if (kRegistered) { + c_call = grpc_channel_create_registered_call( + c_channel_, context->propagate_from_call_, + context->propagation_options_.c_bitmask(), cq->cq(), + method.channel_tag(), context->raw_deadline(), nullptr); + } else { + const ::std::string* host_str = nullptr; + if (!context->authority_.empty()) { + host_str = &context->authority_; + } else if (!host_.empty()) { + host_str = &host_; + } + grpc_slice method_slice = + SliceFromArray(method.name(), strlen(method.name())); + grpc_slice host_slice; + if (host_str != nullptr) { + host_slice = ::grpc::SliceFromCopiedString(*host_str); + } + c_call = grpc_channel_create_call( + c_channel_, context->propagate_from_call_, + context->propagation_options_.c_bitmask(), cq->cq(), method_slice, + host_str == nullptr ? 
nullptr : &host_slice, context->raw_deadline(), + nullptr); + grpc_slice_unref(method_slice); + if (host_str != nullptr) { + grpc_slice_unref(host_slice); + } + } + grpc_census_call_set_context(c_call, context->census_context()); + + // ClientRpcInfo should be set before call because set_call also checks + // whether the call has been cancelled, and if the call was cancelled, we + // should notify the interceptors too. + auto* info = context->set_client_rpc_info( + method.name(), method.suffix_for_stats(), method.method_type(), this, + interceptor_creators_, interceptor_pos); + context->set_call(c_call, shared_from_this()); + + return ::grpc::internal::Call(c_call, this, cq, info); +} + +::grpc::internal::Call Channel::CreateCall( + const ::grpc::internal::RpcMethod& method, ::grpc::ClientContext* context, + CompletionQueue* cq) { + return CreateCallInternal(method, context, cq, 0); +} + +void Channel::PerformOpsOnCall(::grpc::internal::CallOpSetInterface* ops, + ::grpc::internal::Call* call) { + ops->FillOps( + call); // Make a copy of call. It's fine since Call just has pointers +} + +void* Channel::RegisterMethod(const char* method) { + return grpc_channel_register_call( + c_channel_, method, host_.empty() ? nullptr : host_.c_str(), nullptr); +} + +grpc_connectivity_state Channel::GetState(bool try_to_connect) { + return grpc_channel_check_connectivity_state(c_channel_, try_to_connect); +} + +namespace { + +class TagSaver final : public ::grpc::internal::CompletionQueueTag { + public: + explicit TagSaver(void* tag) : tag_(tag) {} + ~TagSaver() override {} + bool FinalizeResult(void** tag, bool* /*status*/) override { + *tag = tag_; + delete this; + return true; + } + + private: + void* tag_; +}; + +} // namespace + +void Channel::NotifyOnStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline, + ::grpc::CompletionQueue* cq, void* tag) { + TagSaver* tag_saver = new TagSaver(tag); + grpc_channel_watch_connectivity_state(c_channel_, last_observed, deadline, + cq->cq(), tag_saver); +} + +bool Channel::WaitForStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline) { + ::grpc::CompletionQueue cq; + bool ok = false; + void* tag = nullptr; + NotifyOnStateChangeImpl(last_observed, deadline, &cq, nullptr); + cq.Next(&tag, &ok); + GPR_ASSERT(tag == nullptr); + return ok; +} + +namespace { +class ShutdownCallback : public grpc_completion_queue_functor { + public: + ShutdownCallback() { + functor_run = &ShutdownCallback::Run; + // Set inlineable to true since this callback is trivial and thus does not + // need to be run from the executor (triggering a thread hop). This should + // only be used by internal callbacks like this and not by user application + // code. 
+ inlineable = true; + } + // TakeCQ takes ownership of the cq into the shutdown callback + // so that the shutdown callback will be responsible for destroying it + void TakeCQ(::grpc::CompletionQueue* cq) { cq_ = cq; } + + // The Run function will get invoked by the completion queue library + // when the shutdown is actually complete + static void Run(grpc_completion_queue_functor* cb, int) { + auto* callback = static_cast(cb); + delete callback->cq_; + delete callback; + } + + private: + ::grpc::CompletionQueue* cq_ = nullptr; +}; +} // namespace + +::grpc::CompletionQueue* Channel::CallbackCQ() { + // TODO(vjpai): Consider using a single global CQ for the default CQ + // if there is no explicit per-channel CQ registered + CompletionQueue* callback_cq = callback_cq_.load(std::memory_order_acquire); + if (callback_cq != nullptr) { + return callback_cq; + } + // The callback_cq_ wasn't already set, so grab a lock and set it up exactly + // once for this channel. + grpc::internal::MutexLock l(&mu_); + callback_cq = callback_cq_.load(std::memory_order_relaxed); + if (callback_cq == nullptr) { + if (grpc_iomgr_run_in_background()) { + // gRPC-core provides the backing needed for the preferred CQ type + + auto* shutdown_callback = new ShutdownCallback; + callback_cq = + new ::grpc::CompletionQueue(grpc_completion_queue_attributes{ + GRPC_CQ_CURRENT_VERSION, GRPC_CQ_CALLBACK, + GRPC_CQ_DEFAULT_POLLING, shutdown_callback}); + + // Transfer ownership of the new cq to its own shutdown callback + shutdown_callback->TakeCQ(callback_cq); + } else { + // Otherwise we need to use the alternative CQ variant + callback_cq = CompletionQueue::CallbackAlternativeCQ(); + } + callback_cq_.store(callback_cq, std::memory_order_release); + } + return callback_cq; +} + +} // namespace grpc diff --git a/src/cpp/client/channel_test_peer.cc b/src/cpp/client/channel_test_peer.cc new file mode 100644 index 00000000..e55f5224 --- /dev/null +++ b/src/cpp/client/channel_test_peer.cc @@ -0,0 +1,39 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "src/core/lib/surface/channel.h" + +namespace grpc { +namespace testing { + +int ChannelTestPeer::registered_calls() const { + grpc_core::MutexLock lock(&channel_->c_channel_->registration_table->mu); + return static_cast(channel_->c_channel_->registration_table->map.size()); +} + +int ChannelTestPeer::registration_attempts() const { + grpc_core::MutexLock lock(&channel_->c_channel_->registration_table->mu); + return channel_->c_channel_->registration_table->method_registration_attempts; +} + +} // namespace testing +} // namespace grpc diff --git a/src/cpp/client/client_callback.cc b/src/cpp/client/client_callback.cc new file mode 100644 index 00000000..9e4ebbb3 --- /dev/null +++ b/src/cpp/client/client_callback.cc @@ -0,0 +1,57 @@ +/* + * Copyright 2019 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/surface/call.h" + +namespace grpc { +namespace internal { + +void ClientReactor::InternalScheduleOnDone(grpc::Status s) { + // Unlike other uses of closure, do not Ref or Unref here since the reactor + // object's lifetime is controlled by user code. + grpc_core::ExecCtx exec_ctx; + struct ClosureWithArg { + grpc_closure closure; + ClientReactor* const reactor; + const grpc::Status status; + ClosureWithArg(ClientReactor* reactor_arg, grpc::Status s) + : reactor(reactor_arg), status(std::move(s)) { + GRPC_CLOSURE_INIT( + &closure, + [](void* void_arg, grpc_error_handle) { + ClosureWithArg* arg = static_cast(void_arg); + arg->reactor->OnDone(arg->status); + delete arg; + }, + this, grpc_schedule_on_exec_ctx); + } + }; + ClosureWithArg* arg = new ClosureWithArg(this, std::move(s)); + grpc_core::Executor::Run(&arg->closure, GRPC_ERROR_NONE); +} + +bool ClientReactor::InternalTrailersOnly(const grpc_call* call) const { + return grpc_call_is_trailers_only(call); +} + +} // namespace internal +} // namespace grpc diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc new file mode 100644 index 00000000..d7a37f77 --- /dev/null +++ b/src/cpp/client/client_context.cc @@ -0,0 +1,179 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace grpc { + +class Channel; + +class DefaultGlobalClientCallbacks final + : public ClientContext::GlobalCallbacks { + public: + ~DefaultGlobalClientCallbacks() override {} + void DefaultConstructor(ClientContext* /*context*/) override {} + void Destructor(ClientContext* /*context*/) override {} +}; + +static internal::GrpcLibraryInitializer g_gli_initializer; +static DefaultGlobalClientCallbacks* g_default_client_callbacks = + new DefaultGlobalClientCallbacks(); +static ClientContext::GlobalCallbacks* g_client_callbacks = + g_default_client_callbacks; + +ClientContext::ClientContext() + : initial_metadata_received_(false), + wait_for_ready_(false), + wait_for_ready_explicitly_set_(false), + idempotent_(false), + cacheable_(false), + call_(nullptr), + call_canceled_(false), + deadline_(gpr_inf_future(GPR_CLOCK_REALTIME)), + census_context_(nullptr), + propagate_from_call_(nullptr), + compression_algorithm_(GRPC_COMPRESS_NONE), + initial_metadata_corked_(false) { + g_gli_initializer.summon(); + g_client_callbacks->DefaultConstructor(this); +} + +ClientContext::~ClientContext() { + if (call_) { + grpc_call_unref(call_); + } + g_client_callbacks->Destructor(this); +} + +void ClientContext::set_credentials( + const std::shared_ptr& creds) { + creds_ = creds; + // If call_ is set, we have already created the call, and set the call + // credentials. This should only be done before we have started the batch + // for sending initial metadata. + if (creds_ != nullptr && call_ != nullptr) { + if (!creds_->ApplyToCall(call_)) { + SendCancelToInterceptors(); + grpc_call_cancel_with_status(call_, GRPC_STATUS_CANCELLED, + "Failed to set credentials to rpc.", + nullptr); + } + } +} + +std::unique_ptr ClientContext::FromInternalServerContext( + const grpc::ServerContextBase& context, PropagationOptions options) { + std::unique_ptr ctx(new ClientContext); + ctx->propagate_from_call_ = context.call_.call; + ctx->propagation_options_ = options; + return ctx; +} + +std::unique_ptr ClientContext::FromServerContext( + const grpc::ServerContextBase& server_context, PropagationOptions options) { + return FromInternalServerContext(server_context, options); +} + +std::unique_ptr ClientContext::FromCallbackServerContext( + const grpc::CallbackServerContext& server_context, + PropagationOptions options) { + return FromInternalServerContext(server_context, options); +} + +void ClientContext::AddMetadata(const std::string& meta_key, + const std::string& meta_value) { + send_initial_metadata_.insert(std::make_pair(meta_key, meta_value)); +} + +void ClientContext::set_call(grpc_call* call, + const std::shared_ptr& channel) { + internal::MutexLock lock(&mu_); + GPR_ASSERT(call_ == nullptr); + call_ = call; + channel_ = channel; + if (creds_ && !creds_->ApplyToCall(call_)) { + // TODO(yashykt): should interceptors also see this status? 
+ SendCancelToInterceptors(); + grpc_call_cancel_with_status(call, GRPC_STATUS_CANCELLED, + "Failed to set credentials to rpc.", nullptr); + } + if (call_canceled_) { + SendCancelToInterceptors(); + grpc_call_cancel(call_, nullptr); + } +} + +void ClientContext::set_compression_algorithm( + grpc_compression_algorithm algorithm) { + compression_algorithm_ = algorithm; + const char* algorithm_name = nullptr; + if (!grpc_compression_algorithm_name(algorithm, &algorithm_name)) { + gpr_log(GPR_ERROR, "Name for compression algorithm '%d' unknown.", + algorithm); + abort(); + } + GPR_ASSERT(algorithm_name != nullptr); + AddMetadata(GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, algorithm_name); +} + +void ClientContext::TryCancel() { + internal::MutexLock lock(&mu_); + if (call_) { + SendCancelToInterceptors(); + grpc_call_cancel(call_, nullptr); + } else { + call_canceled_ = true; + } +} + +void ClientContext::SendCancelToInterceptors() { + internal::CancelInterceptorBatchMethods cancel_methods; + for (size_t i = 0; i < rpc_info_.interceptors_.size(); i++) { + rpc_info_.RunInterceptor(&cancel_methods, i); + } +} + +std::string ClientContext::peer() const { + std::string peer; + if (call_) { + char* c_peer = grpc_call_get_peer(call_); + peer = c_peer; + gpr_free(c_peer); + } + return peer; +} + +void ClientContext::SetGlobalCallbacks(GlobalCallbacks* client_callbacks) { + GPR_ASSERT(g_client_callbacks == g_default_client_callbacks); + GPR_ASSERT(client_callbacks != nullptr); + GPR_ASSERT(client_callbacks != g_default_client_callbacks); + g_client_callbacks = client_callbacks; +} + +} // namespace grpc diff --git a/src/cpp/client/client_interceptor.cc b/src/cpp/client/client_interceptor.cc new file mode 100644 index 00000000..a91950ca --- /dev/null +++ b/src/cpp/client/client_interceptor.cc @@ -0,0 +1,44 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +namespace grpc { + +namespace internal { +experimental::ClientInterceptorFactoryInterface* + g_global_client_interceptor_factory = nullptr; +} + +namespace experimental { +void RegisterGlobalClientInterceptorFactory( + ClientInterceptorFactoryInterface* factory) { + if (internal::g_global_client_interceptor_factory != nullptr) { + GPR_ASSERT(false && + "It is illegal to call RegisterGlobalClientInterceptorFactory " + "multiple times."); + } + internal::g_global_client_interceptor_factory = factory; +} + +// For testing purposes only. +void TestOnlyResetGlobalClientInterceptorFactory() { + internal::g_global_client_interceptor_factory = nullptr; +} +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc new file mode 100644 index 00000000..48831d0f --- /dev/null +++ b/src/cpp/client/create_channel.cc @@ -0,0 +1,85 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
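Usage sketch (not part of the patch): the ClientContext methods implemented above (AddMetadata, deadlines, compression, TryCancel, peer) are driven from call sites like the following; the metadata key, deadline, and request values are placeholders, and the helloworld example proto is assumed:

#include <chrono>

#include <grpcpp/grpcpp.h>

#include "examples/protos/helloworld.grpc.pb.h"

grpc::Status CallWithContext(helloworld::Greeter::Stub* stub) {
  grpc::ClientContext context;
  // Custom metadata keys must be lowercase ("-bin" suffix for binary values).
  context.AddMetadata("x-custom-trace-id", "abc123");
  // Per-call deadline; the RPC fails with DEADLINE_EXCEEDED once it elapses.
  context.set_deadline(std::chrono::system_clock::now() +
                       std::chrono::seconds(5));
  // Request gzip compression for this call (see set_compression_algorithm
  // above, which turns this into request metadata).
  context.set_compression_algorithm(GRPC_COMPRESS_GZIP);
  context.set_wait_for_ready(true);

  helloworld::HelloRequest request;
  helloworld::HelloReply reply;
  request.set_name("client-context-demo");
  grpc::Status status = stub->SayHello(&context, &request, &reply);
  // context.TryCancel() may be called from another thread to abort the RPC;
  // context.peer() reports the remote address once the call has started.
  return status;
}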
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include + +#include "src/cpp/client/create_channel_internal.h" + +namespace grpc { +std::shared_ptr CreateChannel( + const grpc::string& target, + const std::shared_ptr& creds) { + return CreateCustomChannel(target, creds, grpc::ChannelArguments()); +} + +std::shared_ptr CreateCustomChannel( + const grpc::string& target, + const std::shared_ptr& creds, + const grpc::ChannelArguments& args) { + grpc::GrpcLibraryCodegen + init_lib; // We need to call init in case of bad creds. + return creds ? creds->CreateChannelImpl(target, args) + : grpc::CreateChannelInternal( + "", + grpc_lame_client_channel_create( + nullptr, GRPC_STATUS_INVALID_ARGUMENT, + "Invalid credentials."), + std::vector>()); +} + +namespace experimental { +/// Create a new \em custom \a Channel pointing to \a target with \a +/// interceptors being invoked per call. +/// +/// \warning For advanced use and testing ONLY. Override default channel +/// arguments only if necessary. +/// +/// \param target The URI of the endpoint to connect to. +/// \param creds Credentials to use for the created channel. If it does not +/// hold an object or is invalid, a lame channel (one on which all operations +/// fail) is returned. +/// \param args Options for channel creation. +std::shared_ptr CreateCustomChannelWithInterceptors( + const std::string& target, + const std::shared_ptr& creds, + const grpc::ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { + grpc::GrpcLibraryCodegen + init_lib; // We need to call init in case of bad creds. + return creds ? creds->CreateChannelWithInterceptors( + target, args, std::move(interceptor_creators)) + : grpc::CreateChannelInternal( + "", + grpc_lame_client_channel_create( + nullptr, GRPC_STATUS_INVALID_ARGUMENT, + "Invalid credentials."), + std::move(interceptor_creators)); +} +} // namespace experimental + +} // namespace grpc diff --git a/src/cpp/client/create_channel_internal.cc b/src/cpp/client/create_channel_internal.cc new file mode 100644 index 00000000..f90836a6 --- /dev/null +++ b/src/cpp/client/create_channel_internal.cc @@ -0,0 +1,41 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
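Usage sketch (not part of the patch): CreateChannel/CreateCustomChannel above are the usual entry points for building a channel; a minimal example with a placeholder target and explicit ChannelArguments, again assuming the helloworld example proto:

#include <memory>

#include <grpcpp/grpcpp.h>

#include "examples/protos/helloworld.grpc.pb.h"

std::unique_ptr<helloworld::Greeter::Stub> MakeStub() {
  grpc::ChannelArguments args;
  args.SetMaxReceiveMessageSize(16 * 1024 * 1024);  // raise the 4 MiB default
  // Passing null credentials would produce the "lame" channel built in
  // CreateCustomChannel above, on which every call fails.
  std::shared_ptr<grpc::Channel> channel = grpc::CreateCustomChannel(
      "localhost:50051", grpc::InsecureChannelCredentials(), args);
  return helloworld::Greeter::NewStub(channel);
}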
+ * + */ + +#include "src/cpp/client/create_channel_internal.h" + +#include +#include +#include +#include + +#include + +struct grpc_channel; + +namespace grpc { + +std::shared_ptr CreateChannelInternal( + const std::string& host, grpc_channel* c_channel, + std::vector> + interceptor_creators) { + return std::shared_ptr( + new Channel(host, c_channel, std::move(interceptor_creators))); +} + +} // namespace grpc diff --git a/src/cpp/client/create_channel_posix.cc b/src/cpp/client/create_channel_posix.cc new file mode 100644 index 00000000..a89df0a3 --- /dev/null +++ b/src/cpp/client/create_channel_posix.cc @@ -0,0 +1,77 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include + +#include "src/cpp/client/create_channel_internal.h" + +namespace grpc { + +class ChannelArguments; + +#ifdef GPR_SUPPORT_CHANNELS_FROM_FD + +std::shared_ptr CreateInsecureChannelFromFd(const std::string& target, + int fd) { + grpc::internal::GrpcLibrary init_lib; + init_lib.init(); + return CreateChannelInternal( + "", grpc_insecure_channel_create_from_fd(target.c_str(), fd, nullptr), + std::vector< + std::unique_ptr>()); +} + +std::shared_ptr CreateCustomInsecureChannelFromFd( + const std::string& target, int fd, const grpc::ChannelArguments& args) { + internal::GrpcLibrary init_lib; + init_lib.init(); + grpc_channel_args channel_args; + args.SetChannelArgs(&channel_args); + return CreateChannelInternal( + "", + grpc_insecure_channel_create_from_fd(target.c_str(), fd, &channel_args), + std::vector< + std::unique_ptr>()); +} + +namespace experimental { + +std::shared_ptr CreateCustomInsecureChannelWithInterceptorsFromFd( + const std::string& target, int fd, const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { + grpc::internal::GrpcLibrary init_lib; + init_lib.init(); + grpc_channel_args channel_args; + args.SetChannelArgs(&channel_args); + return CreateChannelInternal( + "", + grpc_insecure_channel_create_from_fd(target.c_str(), fd, &channel_args), + std::move(interceptor_creators)); +} + +} // namespace experimental + +#endif // GPR_SUPPORT_CHANNELS_FROM_FD + +} // namespace grpc diff --git a/src/cpp/client/credentials_cc.cc b/src/cpp/client/credentials_cc.cc new file mode 100644 index 00000000..9dfb2f49 --- /dev/null +++ b/src/cpp/client/credentials_cc.cc @@ -0,0 +1,33 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
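Illustrative sketch (not part of the patch): CreateInsecureChannelFromFd above, guarded by GPR_SUPPORT_CHANNELS_FROM_FD, wraps an already-connected file descriptor. A POSIX-only example over a socketpair, assuming the companion grpc::AddInsecureChannelFromFd() from <grpcpp/server_posix.h> on the server side:

#include <grpcpp/create_channel_posix.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/server_posix.h>

#ifdef GPR_SUPPORT_CHANNELS_FROM_FD
#include <sys/socket.h>

#include <grpc/support/log.h>

std::shared_ptr<grpc::Channel> ChannelOverSocketpair(grpc::Server* server) {
  int fds[2];
  GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, fds) == 0);
  // One end is handed to an in-process server...
  grpc::AddInsecureChannelFromFd(server, fds[1]);
  // ...and the other end becomes a client channel. The target string is only
  // a label here; no name resolution or connection setup takes place.
  return grpc::CreateInsecureChannelFromFd("socketpair_target", fds[0]);
}
#endif  // GPR_SUPPORT_CHANNELS_FROM_FD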
+ * + */ + +#include +#include + +namespace grpc { + +static grpc::internal::GrpcLibraryInitializer g_gli_initializer; +ChannelCredentials::ChannelCredentials() { g_gli_initializer.summon(); } + +ChannelCredentials::~ChannelCredentials() {} + +CallCredentials::CallCredentials() { g_gli_initializer.summon(); } + +CallCredentials::~CallCredentials() {} + +} // namespace grpc diff --git a/src/cpp/client/cronet_credentials.cc b/src/cpp/client/cronet_credentials.cc new file mode 100644 index 00000000..4ef81139 --- /dev/null +++ b/src/cpp/client/cronet_credentials.cc @@ -0,0 +1,63 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include "src/cpp/client/create_channel_internal.h" + +namespace grpc { + +class CronetChannelCredentialsImpl final : public ChannelCredentials { + public: + CronetChannelCredentialsImpl(void* engine) : engine_(engine) {} + + std::shared_ptr CreateChannelImpl( + const string& target, const grpc::ChannelArguments& args) override { + return CreateChannelWithInterceptors( + target, args, + std::vector>()); + } + + SecureChannelCredentials* AsSecureCredentials() override { return nullptr; } + + private: + std::shared_ptr CreateChannelWithInterceptors( + const string& target, const grpc::ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) override { + grpc_channel_args channel_args; + args.SetChannelArgs(&channel_args); + return CreateChannelInternal( + "", + grpc_cronet_secure_channel_create(engine_, target.c_str(), + &channel_args, nullptr), + std::move(interceptor_creators)); + } + void* engine_; +}; + +std::shared_ptr CronetChannelCredentials(void* engine) { + return std::shared_ptr( + new grpc::CronetChannelCredentialsImpl(engine)); +} +} // namespace grpc diff --git a/src/cpp/client/insecure_credentials.cc b/src/cpp/client/insecure_credentials.cc new file mode 100644 index 00000000..57c0e77f --- /dev/null +++ b/src/cpp/client/insecure_credentials.cc @@ -0,0 +1,65 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +#include +#include +#include +#include +#include +#include + +#include "src/cpp/client/create_channel_internal.h" + +namespace grpc { + +namespace { +class InsecureChannelCredentialsImpl final : public ChannelCredentials { + public: + std::shared_ptr CreateChannelImpl( + const std::string& target, const ChannelArguments& args) override { + return CreateChannelWithInterceptors( + target, args, + std::vector>()); + } + + std::shared_ptr CreateChannelWithInterceptors( + const std::string& target, const ChannelArguments& args, + std::vector> + interceptor_creators) override { + grpc_channel_args channel_args; + args.SetChannelArgs(&channel_args); + return ::grpc::CreateChannelInternal( + "", + grpc_insecure_channel_create(target.c_str(), &channel_args, nullptr), + std::move(interceptor_creators)); + } + + SecureChannelCredentials* AsSecureCredentials() override { return nullptr; } + + private: + bool IsInsecure() const override { return true; } +}; +} // namespace + +std::shared_ptr InsecureChannelCredentials() { + return std::shared_ptr( + new InsecureChannelCredentialsImpl()); +} + +} // namespace grpc diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc new file mode 100644 index 00000000..d5bdd310 --- /dev/null +++ b/src/cpp/client/secure_credentials.cc @@ -0,0 +1,533 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/cpp/client/secure_credentials.h" + +#include "absl/strings/str_join.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO(yashykt): We shouldn't be including "src/core" headers. 
+#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/security/transport/auth_filters.h" +#include "src/core/lib/security/util/json_util.h" +#include "src/cpp/client/create_channel_internal.h" +#include "src/cpp/common/secure_auth_context.h" + +namespace grpc { + +static grpc::internal::GrpcLibraryInitializer g_gli_initializer; +SecureChannelCredentials::SecureChannelCredentials( + grpc_channel_credentials* c_creds) + : c_creds_(c_creds) { + g_gli_initializer.summon(); +} + +std::shared_ptr SecureChannelCredentials::CreateChannelImpl( + const std::string& target, const ChannelArguments& args) { + return CreateChannelWithInterceptors( + target, args, + std::vector>()); +} + +std::shared_ptr +SecureChannelCredentials::CreateChannelWithInterceptors( + const std::string& target, const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { + grpc_channel_args channel_args; + args.SetChannelArgs(&channel_args); + return ::grpc::CreateChannelInternal( + args.GetSslTargetNameOverride(), + grpc_secure_channel_create(c_creds_, target.c_str(), &channel_args, + nullptr), + std::move(interceptor_creators)); +} + +SecureCallCredentials::SecureCallCredentials(grpc_call_credentials* c_creds) + : c_creds_(c_creds) { + g_gli_initializer.summon(); +} + +bool SecureCallCredentials::ApplyToCall(grpc_call* call) { + return grpc_call_set_credentials(call, c_creds_) == GRPC_CALL_OK; +} + +namespace internal { + +std::shared_ptr WrapChannelCredentials( + grpc_channel_credentials* creds) { + return creds == nullptr ? nullptr + : std::shared_ptr( + new SecureChannelCredentials(creds)); +} + +} // namespace internal + +namespace { + +std::shared_ptr WrapCallCredentials( + grpc_call_credentials* creds) { + return creds == nullptr ? nullptr + : std::shared_ptr( + new SecureCallCredentials(creds)); +} +} // namespace + +std::shared_ptr GoogleDefaultCredentials() { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + return internal::WrapChannelCredentials( + grpc_google_default_credentials_create(nullptr)); +} + +std::shared_ptr ExternalAccountCredentials( + const grpc::string& json_string, const std::vector& scopes) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + return WrapCallCredentials(grpc_external_account_credentials_create( + json_string.c_str(), absl::StrJoin(scopes, ",").c_str())); +} + +// Builds SSL Credentials given SSL specific options +std::shared_ptr SslCredentials( + const SslCredentialsOptions& options) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = { + options.pem_private_key.c_str(), options.pem_cert_chain.c_str()}; + + grpc_channel_credentials* c_creds = grpc_ssl_credentials_create( + options.pem_root_certs.empty() ? nullptr : options.pem_root_certs.c_str(), + options.pem_private_key.empty() ? 
nullptr : &pem_key_cert_pair, nullptr, + nullptr); + return internal::WrapChannelCredentials(c_creds); +} + +namespace experimental { + +namespace { + +void ClearStsCredentialsOptions(StsCredentialsOptions* options) { + if (options == nullptr) return; + options->token_exchange_service_uri.clear(); + options->resource.clear(); + options->audience.clear(); + options->scope.clear(); + options->requested_token_type.clear(); + options->subject_token_path.clear(); + options->subject_token_type.clear(); + options->actor_token_path.clear(); + options->actor_token_type.clear(); +} + +} // namespace + +// Builds STS credentials options from JSON. +grpc::Status StsCredentialsOptionsFromJson(const std::string& json_string, + StsCredentialsOptions* options) { + if (options == nullptr) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "options cannot be nullptr."); + } + ClearStsCredentialsOptions(options); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json json = grpc_core::Json::Parse(json_string.c_str(), &error); + if (error != GRPC_ERROR_NONE || + json.type() != grpc_core::Json::Type::OBJECT) { + GRPC_ERROR_UNREF(error); + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Invalid json."); + } + + // Required fields. + const char* value = grpc_json_get_string_property( + json, "token_exchange_service_uri", nullptr); + if (value == nullptr) { + ClearStsCredentialsOptions(options); + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "token_exchange_service_uri must be specified."); + } + options->token_exchange_service_uri.assign(value); + value = grpc_json_get_string_property(json, "subject_token_path", nullptr); + if (value == nullptr) { + ClearStsCredentialsOptions(options); + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "subject_token_path must be specified."); + } + options->subject_token_path.assign(value); + value = grpc_json_get_string_property(json, "subject_token_type", nullptr); + if (value == nullptr) { + ClearStsCredentialsOptions(options); + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "subject_token_type must be specified."); + } + options->subject_token_type.assign(value); + + // Optional fields. + value = grpc_json_get_string_property(json, "resource", nullptr); + if (value != nullptr) options->resource.assign(value); + value = grpc_json_get_string_property(json, "audience", nullptr); + if (value != nullptr) options->audience.assign(value); + value = grpc_json_get_string_property(json, "scope", nullptr); + if (value != nullptr) options->scope.assign(value); + value = grpc_json_get_string_property(json, "requested_token_type", nullptr); + if (value != nullptr) options->requested_token_type.assign(value); + value = grpc_json_get_string_property(json, "actor_token_path", nullptr); + if (value != nullptr) options->actor_token_path.assign(value); + value = grpc_json_get_string_property(json, "actor_token_type", nullptr); + if (value != nullptr) options->actor_token_type.assign(value); + + return grpc::Status(); +} + +// Builds STS credentials Options from the $STS_CREDENTIALS env var. 
+grpc::Status StsCredentialsOptionsFromEnv(StsCredentialsOptions* options) { + if (options == nullptr) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "options cannot be nullptr."); + } + ClearStsCredentialsOptions(options); + grpc_slice json_string = grpc_empty_slice(); + char* sts_creds_path = gpr_getenv("STS_CREDENTIALS"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc::Status status; + // NOLINTNEXTLINE(clang-diagnostic-unused-lambda-capture) + auto cleanup = [&json_string, &sts_creds_path, &error, &status]() { + grpc_slice_unref_internal(json_string); + gpr_free(sts_creds_path); + GRPC_ERROR_UNREF(error); + return status; + }; + + if (sts_creds_path == nullptr) { + status = grpc::Status(grpc::StatusCode::NOT_FOUND, + "STS_CREDENTIALS environment variable not set."); + return cleanup(); + } + error = grpc_load_file(sts_creds_path, 1, &json_string); + if (error != GRPC_ERROR_NONE) { + status = + grpc::Status(grpc::StatusCode::NOT_FOUND, grpc_error_std_string(error)); + return cleanup(); + } + status = StsCredentialsOptionsFromJson( + reinterpret_cast(GRPC_SLICE_START_PTR(json_string)), + options); + return cleanup(); +} + +// C++ to Core STS Credentials options. +grpc_sts_credentials_options StsCredentialsCppToCoreOptions( + const StsCredentialsOptions& options) { + grpc_sts_credentials_options opts; + memset(&opts, 0, sizeof(opts)); + opts.token_exchange_service_uri = options.token_exchange_service_uri.c_str(); + opts.resource = options.resource.c_str(); + opts.audience = options.audience.c_str(); + opts.scope = options.scope.c_str(); + opts.requested_token_type = options.requested_token_type.c_str(); + opts.subject_token_path = options.subject_token_path.c_str(); + opts.subject_token_type = options.subject_token_type.c_str(); + opts.actor_token_path = options.actor_token_path.c_str(); + opts.actor_token_type = options.actor_token_type.c_str(); + return opts; +} + +// Builds STS credentials. +std::shared_ptr StsCredentials( + const StsCredentialsOptions& options) { + auto opts = StsCredentialsCppToCoreOptions(options); + return WrapCallCredentials(grpc_sts_credentials_create(&opts, nullptr)); +} + +std::shared_ptr MetadataCredentialsFromPlugin( + std::unique_ptr plugin, + grpc_security_level min_security_level) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + const char* type = plugin->GetType(); + grpc::MetadataCredentialsPluginWrapper* wrapper = + new grpc::MetadataCredentialsPluginWrapper(std::move(plugin)); + grpc_metadata_credentials_plugin c_plugin = { + grpc::MetadataCredentialsPluginWrapper::GetMetadata, + grpc::MetadataCredentialsPluginWrapper::DebugString, + grpc::MetadataCredentialsPluginWrapper::Destroy, wrapper, type}; + return WrapCallCredentials(grpc_metadata_credentials_create_from_plugin( + c_plugin, min_security_level, nullptr)); +} + +// Builds ALTS Credentials given ALTS specific options +std::shared_ptr AltsCredentials( + const AltsCredentialsOptions& options) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). 
+ grpc_alts_credentials_options* c_options = + grpc_alts_credentials_client_options_create(); + for (const auto& service_account : options.target_service_accounts) { + grpc_alts_credentials_client_options_add_target_service_account( + c_options, service_account.c_str()); + } + grpc_channel_credentials* c_creds = grpc_alts_credentials_create(c_options); + grpc_alts_credentials_options_destroy(c_options); + return internal::WrapChannelCredentials(c_creds); +} + +// Builds Local Credentials +std::shared_ptr LocalCredentials( + grpc_local_connect_type type) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + return internal::WrapChannelCredentials(grpc_local_credentials_create(type)); +} + +// Builds TLS Credentials given TLS options. +std::shared_ptr TlsCredentials( + const TlsChannelCredentialsOptions& options) { + return internal::WrapChannelCredentials( + grpc_tls_credentials_create(options.c_credentials_options())); +} + +} // namespace experimental + +// Builds credentials for use when running in GCE +std::shared_ptr GoogleComputeEngineCredentials() { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + return WrapCallCredentials( + grpc_google_compute_engine_credentials_create(nullptr)); +} + +// Builds JWT credentials. +std::shared_ptr ServiceAccountJWTAccessCredentials( + const std::string& json_key, long token_lifetime_seconds) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + if (token_lifetime_seconds <= 0) { + gpr_log(GPR_ERROR, + "Trying to create JWTCredentials with non-positive lifetime"); + return WrapCallCredentials(nullptr); + } + gpr_timespec lifetime = + gpr_time_from_seconds(token_lifetime_seconds, GPR_TIMESPAN); + return WrapCallCredentials(grpc_service_account_jwt_access_credentials_create( + json_key.c_str(), lifetime, nullptr)); +} + +// Builds refresh token credentials. +std::shared_ptr GoogleRefreshTokenCredentials( + const std::string& json_refresh_token) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + return WrapCallCredentials(grpc_google_refresh_token_credentials_create( + json_refresh_token.c_str(), nullptr)); +} + +// Builds access token credentials. +std::shared_ptr AccessTokenCredentials( + const std::string& access_token) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + return WrapCallCredentials( + grpc_access_token_credentials_create(access_token.c_str(), nullptr)); +} + +// Builds IAM credentials. +std::shared_ptr GoogleIAMCredentials( + const std::string& authorization_token, + const std::string& authority_selector) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + return WrapCallCredentials(grpc_google_iam_credentials_create( + authorization_token.c_str(), authority_selector.c_str(), nullptr)); +} + +// Combines one channel credentials and one call credentials into a channel +// composite credentials. +std::shared_ptr CompositeChannelCredentials( + const std::shared_ptr& channel_creds, + const std::shared_ptr& call_creds) { + // Note that we are not saving shared_ptrs to the two credentials passed in + // here. This is OK because the underlying C objects (i.e., channel_creds and + // call_creds) into grpc_composite_credentials_create will see their refcounts + // incremented. 
+ SecureChannelCredentials* s_channel_creds = + channel_creds->AsSecureCredentials(); + SecureCallCredentials* s_call_creds = call_creds->AsSecureCredentials(); + if (s_channel_creds && s_call_creds) { + return internal::WrapChannelCredentials( + grpc_composite_channel_credentials_create( + s_channel_creds->GetRawCreds(), s_call_creds->GetRawCreds(), + nullptr)); + } + return nullptr; +} + +std::shared_ptr CompositeCallCredentials( + const std::shared_ptr& creds1, + const std::shared_ptr& creds2) { + SecureCallCredentials* s_creds1 = creds1->AsSecureCredentials(); + SecureCallCredentials* s_creds2 = creds2->AsSecureCredentials(); + if (s_creds1 != nullptr && s_creds2 != nullptr) { + return WrapCallCredentials(grpc_composite_call_credentials_create( + s_creds1->GetRawCreds(), s_creds2->GetRawCreds(), nullptr)); + } + return nullptr; +} + +std::shared_ptr MetadataCredentialsFromPlugin( + std::unique_ptr plugin) { + grpc::GrpcLibraryCodegen init; // To call grpc_init(). + const char* type = plugin->GetType(); + grpc::MetadataCredentialsPluginWrapper* wrapper = + new grpc::MetadataCredentialsPluginWrapper(std::move(plugin)); + grpc_metadata_credentials_plugin c_plugin = { + grpc::MetadataCredentialsPluginWrapper::GetMetadata, + grpc::MetadataCredentialsPluginWrapper::DebugString, + grpc::MetadataCredentialsPluginWrapper::Destroy, wrapper, type}; + return WrapCallCredentials(grpc_metadata_credentials_create_from_plugin( + c_plugin, GRPC_PRIVACY_AND_INTEGRITY, nullptr)); +} + +namespace { +void DeleteWrapper(void* wrapper, grpc_error_handle /*ignored*/) { + MetadataCredentialsPluginWrapper* w = + static_cast(wrapper); + delete w; +} +} // namespace + +char* MetadataCredentialsPluginWrapper::DebugString(void* wrapper) { + GPR_ASSERT(wrapper); + MetadataCredentialsPluginWrapper* w = + static_cast(wrapper); + return gpr_strdup(w->plugin_->DebugString().c_str()); +} + +void MetadataCredentialsPluginWrapper::Destroy(void* wrapper) { + if (wrapper == nullptr) return; + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_core::Executor::Run(GRPC_CLOSURE_CREATE(DeleteWrapper, wrapper, nullptr), + GRPC_ERROR_NONE); +} + +int MetadataCredentialsPluginWrapper::GetMetadata( + void* wrapper, grpc_auth_metadata_context context, + grpc_credentials_plugin_metadata_cb cb, void* user_data, + grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], + size_t* num_creds_md, grpc_status_code* status, + const char** error_details) { + GPR_ASSERT(wrapper); + MetadataCredentialsPluginWrapper* w = + static_cast(wrapper); + if (!w->plugin_) { + *num_creds_md = 0; + *status = GRPC_STATUS_OK; + *error_details = nullptr; + return 1; + } + if (w->plugin_->IsBlocking()) { + // The internals of context may be destroyed if GetMetadata is cancelled. + // Make a copy for InvokePlugin. + grpc_auth_metadata_context context_copy = grpc_auth_metadata_context(); + grpc_auth_metadata_context_copy(&context, &context_copy); + // Asynchronous return. + w->thread_pool_->Add([w, context_copy, cb, user_data]() mutable { + w->MetadataCredentialsPluginWrapper::InvokePlugin( + context_copy, cb, user_data, nullptr, nullptr, nullptr, nullptr); + grpc_auth_metadata_context_reset(&context_copy); + }); + return 0; + } else { + // Synchronous return. 
+ w->InvokePlugin(context, cb, user_data, creds_md, num_creds_md, status, + error_details); + return 1; + } +} + +namespace { + +void UnrefMetadata(const std::vector& md) { + for (const auto& metadatum : md) { + grpc_slice_unref(metadatum.key); + grpc_slice_unref(metadatum.value); + } +} + +} // namespace + +void MetadataCredentialsPluginWrapper::InvokePlugin( + grpc_auth_metadata_context context, grpc_credentials_plugin_metadata_cb cb, + void* user_data, grpc_metadata creds_md[4], size_t* num_creds_md, + grpc_status_code* status_code, const char** error_details) { + std::multimap metadata; + + // const_cast is safe since the SecureAuthContext only inc/dec the refcount + // and the object is passed as a const ref to plugin_->GetMetadata. + SecureAuthContext cpp_channel_auth_context( + const_cast(context.channel_auth_context)); + + Status status = plugin_->GetMetadata(context.service_url, context.method_name, + cpp_channel_auth_context, &metadata); + std::vector md; + for (auto& metadatum : metadata) { + grpc_metadata md_entry; + md_entry.key = SliceFromCopiedString(metadatum.first); + md_entry.value = SliceFromCopiedString(metadatum.second); + md.push_back(md_entry); + } + if (creds_md != nullptr) { + // Synchronous return. + if (md.size() > GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX) { + *num_creds_md = 0; + *status_code = GRPC_STATUS_INTERNAL; + *error_details = gpr_strdup( + "blocking plugin credentials returned too many metadata keys"); + UnrefMetadata(md); + } else { + for (const auto& elem : md) { + creds_md[*num_creds_md].key = elem.key; + creds_md[*num_creds_md].value = elem.value; + ++(*num_creds_md); + } + *status_code = static_cast(status.error_code()); + *error_details = + status.ok() ? nullptr : gpr_strdup(status.error_message().c_str()); + } + } else { + // Asynchronous return. + cb(user_data, md.empty() ? nullptr : &md[0], md.size(), + static_cast(status.error_code()), + status.error_message().c_str()); + UnrefMetadata(md); + } +} + +MetadataCredentialsPluginWrapper::MetadataCredentialsPluginWrapper( + std::unique_ptr plugin) + : plugin_(std::move(plugin)) { + if (plugin_->IsBlocking()) { + thread_pool_.reset(CreateDefaultThreadPool()); + } +} + +} // namespace grpc diff --git a/src/cpp/client/xds_credentials.cc b/src/cpp/client/xds_credentials.cc new file mode 100644 index 00000000..d5446a02 --- /dev/null +++ b/src/cpp/client/xds_credentials.cc @@ -0,0 +1,47 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
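Usage sketch (not part of the patch): MetadataCredentialsPluginWrapper above bridges user-defined plugins into core credentials. A non-blocking plugin that attaches a static token, composed with SSL channel credentials; the header value and endpoint are placeholders:

#include <map>
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>
#include <grpcpp/security/credentials.h>

class StaticTokenPlugin : public grpc::MetadataCredentialsPlugin {
 public:
  bool IsBlocking() const override { return false; }  // runs inline, no thread pool
  const char* GetType() const override { return "static-token"; }
  grpc::Status GetMetadata(
      grpc::string_ref /*service_url*/, grpc::string_ref /*method_name*/,
      const grpc::AuthContext& /*channel_auth_context*/,
      std::multimap<std::string, std::string>* metadata) override {
    metadata->insert({"authorization", "Bearer placeholder-token"});
    return grpc::Status::OK;
  }
};

std::shared_ptr<grpc::Channel> MakeAuthedChannel() {
  auto call_creds = grpc::MetadataCredentialsFromPlugin(
      std::unique_ptr<grpc::MetadataCredentialsPlugin>(new StaticTokenPlugin));
  auto channel_creds = grpc::CompositeChannelCredentials(
      grpc::SslCredentials(grpc::SslCredentialsOptions()), call_creds);
  return grpc::CreateChannel("myservice.example.com:443", channel_creds);
}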
+// +// + +#include "src/cpp/client/secure_credentials.h" + +namespace grpc { + +std::shared_ptr XdsCredentials( + const std::shared_ptr& fallback_creds) { + GPR_ASSERT(fallback_creds != nullptr); + if (fallback_creds->IsInsecure()) { + grpc_channel_credentials* insecure_creds = + grpc_insecure_credentials_create(); + auto xds_creds = internal::WrapChannelCredentials( + grpc_xds_credentials_create(insecure_creds)); + grpc_channel_credentials_release(insecure_creds); + return xds_creds; + } else { + return internal::WrapChannelCredentials(grpc_xds_credentials_create( + fallback_creds->AsSecureCredentials()->GetRawCreds())); + } +} + +namespace experimental { + +std::shared_ptr XdsCredentials( + const std::shared_ptr& fallback_creds) { + return grpc::XdsCredentials(fallback_creds); +} + +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/codegen/codegen_init.cc b/src/cpp/codegen/codegen_init.cc new file mode 100644 index 00000000..e1e47cbb --- /dev/null +++ b/src/cpp/codegen/codegen_init.cc @@ -0,0 +1,30 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +/// Null-initializes the global gRPC variables for the codegen library. These +/// stay null in the absence of grpc++ library. In this case, no gRPC +/// features such as the ability to perform calls will be available. Trying to +/// perform them would result in a segmentation fault when trying to deference +/// the following nulled globals. These should be associated with actual +/// as part of the instantiation of a \a grpc::GrpcLibraryInitializer variable. + +grpc::CoreCodegenInterface* grpc::g_core_codegen_interface; +grpc::GrpcLibraryInterface* grpc::g_glip; diff --git a/src/cpp/common/alarm.cc b/src/cpp/common/alarm.cc new file mode 100644 index 00000000..a367b53d --- /dev/null +++ b/src/cpp/common/alarm.cc @@ -0,0 +1,160 @@ +/* + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
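For reference (not part of the patch): XdsCredentials above wraps fallback credentials for xDS-managed channels; a minimal sketch with a placeholder xds:/// target, assuming XdsCredentials is exposed through <grpcpp/security/credentials.h> in this version:

#include <grpcpp/grpcpp.h>
#include <grpcpp/security/credentials.h>

std::shared_ptr<grpc::Channel> MakeXdsChannel() {
  // The fallback credentials are used when the xDS control plane does not
  // supply security configuration for this target.
  auto creds = grpc::XdsCredentials(grpc::InsecureChannelCredentials());
  return grpc::CreateChannel("xds:///my-backend", creds);
}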
+ * + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/surface/completion_queue.h" + +namespace grpc { + +namespace internal { +class AlarmImpl : public ::grpc::internal::CompletionQueueTag { + public: + AlarmImpl() : cq_(nullptr), tag_(nullptr) { + gpr_ref_init(&refs_, 1); + grpc_timer_init_unset(&timer_); + } + ~AlarmImpl() override {} + bool FinalizeResult(void** tag, bool* /*status*/) override { + *tag = tag_; + Unref(); + return true; + } + void Set(::grpc::CompletionQueue* cq, gpr_timespec deadline, void* tag) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + GRPC_CQ_INTERNAL_REF(cq->cq(), "alarm"); + cq_ = cq->cq(); + tag_ = tag; + GPR_ASSERT(grpc_cq_begin_op(cq_, this)); + GRPC_CLOSURE_INIT( + &on_alarm_, + [](void* arg, grpc_error_handle error) { + // queue the op on the completion queue + AlarmImpl* alarm = static_cast(arg); + alarm->Ref(); + // Preserve the cq and reset the cq_ so that the alarm + // can be reset when the alarm tag is delivered. + grpc_completion_queue* cq = alarm->cq_; + alarm->cq_ = nullptr; + grpc_cq_end_op( + cq, alarm, error, + [](void* /*arg*/, grpc_cq_completion* /*completion*/) {}, arg, + &alarm->completion_); + GRPC_CQ_INTERNAL_UNREF(cq, "alarm"); + }, + this, grpc_schedule_on_exec_ctx); + grpc_timer_init(&timer_, grpc_timespec_to_millis_round_up(deadline), + &on_alarm_); + } + void Set(gpr_timespec deadline, std::function f) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + // Don't use any CQ at all. Instead just use the timer to fire the function + callback_ = std::move(f); + Ref(); + GRPC_CLOSURE_INIT( + &on_alarm_, + [](void* arg, grpc_error_handle error) { + grpc_core::Executor::Run( + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle error) { + AlarmImpl* alarm = static_cast(arg); + alarm->callback_(error == GRPC_ERROR_NONE); + alarm->Unref(); + }, + arg, nullptr), + error); + }, + this, grpc_schedule_on_exec_ctx); + grpc_timer_init(&timer_, grpc_timespec_to_millis_round_up(deadline), + &on_alarm_); + } + void Cancel() { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_timer_cancel(&timer_); + } + void Destroy() { + Cancel(); + Unref(); + } + + private: + void Ref() { gpr_ref(&refs_); } + void Unref() { + if (gpr_unref(&refs_)) { + delete this; + } + } + + grpc_timer timer_; + gpr_refcount refs_; + grpc_closure on_alarm_; + grpc_cq_completion completion_; + // completion queue where events about this alarm will be posted + grpc_completion_queue* cq_; + void* tag_; + std::function callback_; +}; +} // namespace internal + +static ::grpc::internal::GrpcLibraryInitializer g_gli_initializer; + +Alarm::Alarm() : alarm_(new internal::AlarmImpl()) { + g_gli_initializer.summon(); +} + +void Alarm::SetInternal(::grpc::CompletionQueue* cq, gpr_timespec deadline, + void* tag) { + // Note that we know that alarm_ is actually an internal::AlarmImpl + // but we declared it as the base pointer to avoid a forward declaration + // or exposing core data structures in the C++ public headers. 
+ // Thus it is safe to use a static_cast to the subclass here, and the + // C++ style guide allows us to do so in this case + static_cast(alarm_)->Set(cq, deadline, tag); +} + +void Alarm::SetInternal(gpr_timespec deadline, std::function f) { + // Note that we know that alarm_ is actually an internal::AlarmImpl + // but we declared it as the base pointer to avoid a forward declaration + // or exposing core data structures in the C++ public headers. + // Thus it is safe to use a static_cast to the subclass here, and the + // C++ style guide allows us to do so in this case + static_cast(alarm_)->Set(deadline, std::move(f)); +} + +Alarm::~Alarm() { + if (alarm_ != nullptr) { + static_cast(alarm_)->Destroy(); + } +} + +void Alarm::Cancel() { static_cast(alarm_)->Cancel(); } +} // namespace grpc diff --git a/src/cpp/common/alts_context.cc b/src/cpp/common/alts_context.cc new file mode 100644 index 00000000..0b7e7307 --- /dev/null +++ b/src/cpp/common/alts_context.cc @@ -0,0 +1,127 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/proto/grpc/gcp/altscontext.upb.h" + +namespace grpc { +namespace experimental { + +// A upb-generated grpc_gcp_AltsContext is passed in to construct an +// AltsContext. Normal users should use GetAltsContextFromAuthContext to get +// AltsContext, instead of constructing their own. 
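Illustrative sketch (not part of the patch): the Alarm/AlarmImpl machinery above posts a tag to a completion queue when a deadline expires; a minimal CQ-based example with arbitrary deadline and tag values:

#include <chrono>

#include <grpcpp/alarm.h>
#include <grpcpp/grpcpp.h>

void WaitForAlarm() {
  grpc::CompletionQueue cq;
  grpc::Alarm alarm;
  void* tag = reinterpret_cast<void*>(0x1);
  alarm.Set(&cq,
            std::chrono::system_clock::now() + std::chrono::milliseconds(100),
            tag);
  void* got_tag = nullptr;
  bool ok = false;
  // Blocks until the deadline expires (ok == true) or the alarm was cancelled
  // before firing (ok == false).
  cq.Next(&got_tag, &ok);
  cq.Shutdown();
  while (cq.Next(&got_tag, &ok)) {
    // Drain any remaining events after Shutdown().
  }
}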
+AltsContext::AltsContext(const grpc_gcp_AltsContext* ctx) { + upb_strview application_protocol = + grpc_gcp_AltsContext_application_protocol(ctx); + if (application_protocol.data != nullptr && application_protocol.size > 0) { + application_protocol_ = + std::string(application_protocol.data, application_protocol.size); + } + upb_strview record_protocol = grpc_gcp_AltsContext_record_protocol(ctx); + if (record_protocol.data != nullptr && record_protocol.size > 0) { + record_protocol_ = std::string(record_protocol.data, record_protocol.size); + } + upb_strview peer_service_account = + grpc_gcp_AltsContext_peer_service_account(ctx); + if (peer_service_account.data != nullptr && peer_service_account.size > 0) { + peer_service_account_ = + std::string(peer_service_account.data, peer_service_account.size); + } + upb_strview local_service_account = + grpc_gcp_AltsContext_local_service_account(ctx); + if (local_service_account.data != nullptr && local_service_account.size > 0) { + local_service_account_ = + std::string(local_service_account.data, local_service_account.size); + } + const grpc_gcp_RpcProtocolVersions* versions = + grpc_gcp_AltsContext_peer_rpc_versions(ctx); + if (versions != nullptr) { + const grpc_gcp_RpcProtocolVersions_Version* max_version = + grpc_gcp_RpcProtocolVersions_max_rpc_version(versions); + if (max_version != nullptr) { + int max_version_major = + grpc_gcp_RpcProtocolVersions_Version_major(max_version); + int max_version_minor = + grpc_gcp_RpcProtocolVersions_Version_minor(max_version); + peer_rpc_versions_.max_rpc_version.major_version = max_version_major; + peer_rpc_versions_.max_rpc_version.minor_version = max_version_minor; + } + const grpc_gcp_RpcProtocolVersions_Version* min_version = + grpc_gcp_RpcProtocolVersions_min_rpc_version(versions); + if (min_version != nullptr) { + int min_version_major = + grpc_gcp_RpcProtocolVersions_Version_major(min_version); + int min_version_minor = + grpc_gcp_RpcProtocolVersions_Version_minor(min_version); + peer_rpc_versions_.min_rpc_version.major_version = min_version_major; + peer_rpc_versions_.min_rpc_version.minor_version = min_version_minor; + } + } + if (grpc_gcp_AltsContext_security_level(ctx) >= GRPC_SECURITY_MIN || + grpc_gcp_AltsContext_security_level(ctx) <= GRPC_SECURITY_MAX) { + security_level_ = static_cast( + grpc_gcp_AltsContext_security_level(ctx)); + } + if (grpc_gcp_AltsContext_has_peer_attributes(ctx)) { + size_t iter = UPB_MAP_BEGIN; + const grpc_gcp_AltsContext_PeerAttributesEntry* peer_attributes_entry = + grpc_gcp_AltsContext_peer_attributes_next(ctx, &iter); + while (peer_attributes_entry != nullptr) { + upb_strview key = + grpc_gcp_AltsContext_PeerAttributesEntry_key(peer_attributes_entry); + upb_strview val = + grpc_gcp_AltsContext_PeerAttributesEntry_value(peer_attributes_entry); + peer_attributes_map_[std::string(key.data, key.size)] = + std::string(val.data, val.size); + peer_attributes_entry = + grpc_gcp_AltsContext_peer_attributes_next(ctx, &iter); + } + } +} + +std::string AltsContext::application_protocol() const { + return application_protocol_; +} + +std::string AltsContext::record_protocol() const { return record_protocol_; } + +std::string AltsContext::peer_service_account() const { + return peer_service_account_; +} + +std::string AltsContext::local_service_account() const { + return local_service_account_; +} + +grpc_security_level AltsContext::security_level() const { + return security_level_; +} + +AltsContext::RpcProtocolVersions AltsContext::peer_rpc_versions() const { + return 
peer_rpc_versions_; +} + +const std::map& AltsContext::peer_attributes() const { + return peer_attributes_map_; +} + +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/common/alts_util.cc b/src/cpp/common/alts_util.cc new file mode 100644 index 00000000..b1dc38d0 --- /dev/null +++ b/src/cpp/common/alts_util.cc @@ -0,0 +1,82 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "upb/upb.hpp" + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/cpp/common/secure_auth_context.h" +#include "src/proto/grpc/gcp/altscontext.upb.h" + +namespace grpc { +namespace experimental { + +std::unique_ptr GetAltsContextFromAuthContext( + const std::shared_ptr& auth_context) { + if (auth_context == nullptr) { + gpr_log(GPR_ERROR, "auth_context is nullptr."); + return nullptr; + } + std::vector ctx_vector = + auth_context->FindPropertyValues(TSI_ALTS_CONTEXT); + if (ctx_vector.size() != 1) { + gpr_log(GPR_ERROR, "contains zero or more than one ALTS context."); + return nullptr; + } + upb::Arena context_arena; + grpc_gcp_AltsContext* ctx = grpc_gcp_AltsContext_parse( + ctx_vector[0].data(), ctx_vector[0].size(), context_arena.ptr()); + if (ctx == nullptr) { + gpr_log(GPR_ERROR, "fails to parse ALTS context."); + return nullptr; + } + if (grpc_gcp_AltsContext_security_level(ctx) < GRPC_SECURITY_MIN || + grpc_gcp_AltsContext_security_level(ctx) > GRPC_SECURITY_MAX) { + gpr_log(GPR_ERROR, "security_level is invalid."); + return nullptr; + } + return absl::make_unique(AltsContext(ctx)); +} + +grpc::Status AltsClientAuthzCheck( + const std::shared_ptr& auth_context, + const std::vector& expected_service_accounts) { + std::unique_ptr alts_ctx = + GetAltsContextFromAuthContext(auth_context); + if (alts_ctx == nullptr) { + return grpc::Status(grpc::StatusCode::PERMISSION_DENIED, + "fails to parse ALTS context."); + } + if (std::find(expected_service_accounts.begin(), + expected_service_accounts.end(), + alts_ctx->peer_service_account()) != + expected_service_accounts.end()) { + return grpc::Status::OK; + } + return grpc::Status( + grpc::StatusCode::PERMISSION_DENIED, + "client " + alts_ctx->peer_service_account() + " is not authorized."); +} + +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/common/auth_property_iterator.cc b/src/cpp/common/auth_property_iterator.cc new file mode 100644 index 00000000..1334ea99 --- /dev/null +++ b/src/cpp/common/auth_property_iterator.cc @@ -0,0 +1,69 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
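Usage sketch (not part of the patch): GetAltsContextFromAuthContext and AltsClientAuthzCheck above are intended to be called from a server handler; an example with a placeholder expected service account, assuming the public alts_util.h header and the helloworld example proto:

#include <grpcpp/grpcpp.h>
#include <grpcpp/security/alts_util.h>

#include "examples/protos/helloworld.grpc.pb.h"

class AuthzGreeter final : public helloworld::Greeter::Service {
  grpc::Status SayHello(grpc::ServerContext* ctx,
                        const helloworld::HelloRequest* request,
                        helloworld::HelloReply* reply) override {
    // Reject peers that did not authenticate as the expected ALTS identity.
    grpc::Status authz = grpc::experimental::AltsClientAuthzCheck(
        ctx->auth_context(),
        {"expected-client@my-project.iam.gserviceaccount.com"});
    if (!authz.ok()) return authz;
    reply->set_message("Hi, " + request->name());
    return grpc::Status::OK;
  }
};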
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +namespace grpc { + +AuthPropertyIterator::AuthPropertyIterator() + : property_(nullptr), ctx_(nullptr), index_(0), name_(nullptr) {} + +AuthPropertyIterator::AuthPropertyIterator( + const grpc_auth_property* property, const grpc_auth_property_iterator* iter) + : property_(property), + ctx_(iter->ctx), + index_(iter->index), + name_(iter->name) {} + +AuthPropertyIterator::~AuthPropertyIterator() {} + +AuthPropertyIterator& AuthPropertyIterator::operator++() { + grpc_auth_property_iterator iter = {ctx_, index_, name_}; + property_ = grpc_auth_property_iterator_next(&iter); + ctx_ = iter.ctx; + index_ = iter.index; + name_ = iter.name; + return *this; +} + +AuthPropertyIterator AuthPropertyIterator::operator++(int) { + AuthPropertyIterator tmp(*this); + operator++(); + return tmp; +} + +bool AuthPropertyIterator::operator==(const AuthPropertyIterator& rhs) const { + if (property_ == nullptr || rhs.property_ == nullptr) { + return property_ == rhs.property_; + } else { + return index_ == rhs.index_; + } +} + +bool AuthPropertyIterator::operator!=(const AuthPropertyIterator& rhs) const { + return !operator==(rhs); +} + +AuthProperty AuthPropertyIterator::operator*() { + return std::pair( + property_->name, + grpc::string_ref(property_->value, property_->value_length)); +} + +} // namespace grpc diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc new file mode 100644 index 00000000..6c45dda2 --- /dev/null +++ b/src/cpp/common/channel_arguments.cc @@ -0,0 +1,217 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/socket_mutator.h" + +namespace grpc { + +ChannelArguments::ChannelArguments() { + // This will be ignored if used on the server side. 
+ SetString(GRPC_ARG_PRIMARY_USER_AGENT_STRING, "grpc-c++/" + grpc::Version()); +} + +ChannelArguments::ChannelArguments(const ChannelArguments& other) + : strings_(other.strings_) { + args_.reserve(other.args_.size()); + auto list_it_dst = strings_.begin(); + auto list_it_src = other.strings_.begin(); + for (const auto& a : other.args_) { + grpc_arg ap; + ap.type = a.type; + GPR_ASSERT(list_it_src->c_str() == a.key); + ap.key = const_cast(list_it_dst->c_str()); + ++list_it_src; + ++list_it_dst; + switch (a.type) { + case GRPC_ARG_INTEGER: + ap.value.integer = a.value.integer; + break; + case GRPC_ARG_STRING: + GPR_ASSERT(list_it_src->c_str() == a.value.string); + ap.value.string = const_cast(list_it_dst->c_str()); + ++list_it_src; + ++list_it_dst; + break; + case GRPC_ARG_POINTER: + ap.value.pointer = a.value.pointer; + ap.value.pointer.p = a.value.pointer.vtable->copy(ap.value.pointer.p); + break; + } + args_.push_back(ap); + } +} + +ChannelArguments::~ChannelArguments() { + for (auto& arg : args_) { + if (arg.type == GRPC_ARG_POINTER) { + grpc_core::ExecCtx exec_ctx; + arg.value.pointer.vtable->destroy(arg.value.pointer.p); + } + } +} + +void ChannelArguments::Swap(ChannelArguments& other) { + args_.swap(other.args_); + strings_.swap(other.strings_); +} + +void ChannelArguments::SetCompressionAlgorithm( + grpc_compression_algorithm algorithm) { + SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, algorithm); +} + +void ChannelArguments::SetGrpclbFallbackTimeout(int fallback_timeout) { + SetInt(GRPC_ARG_GRPCLB_FALLBACK_TIMEOUT_MS, fallback_timeout); +} + +void ChannelArguments::SetSocketMutator(grpc_socket_mutator* mutator) { + if (!mutator) { + return; + } + grpc_arg mutator_arg = grpc_socket_mutator_to_arg(mutator); + bool replaced = false; + grpc_core::ExecCtx exec_ctx; + for (auto& arg : args_) { + if (arg.type == mutator_arg.type && + std::string(arg.key) == std::string(mutator_arg.key)) { + GPR_ASSERT(!replaced); + arg.value.pointer.vtable->destroy(arg.value.pointer.p); + arg.value.pointer = mutator_arg.value.pointer; + replaced = true; + } + } + + if (!replaced) { + strings_.push_back(std::string(mutator_arg.key)); + args_.push_back(mutator_arg); + args_.back().key = const_cast(strings_.back().c_str()); + } +} + +// Note: a second call to this will add in front the result of the first call. +// An example is calling this on a copy of ChannelArguments which already has a +// prefix. The user can build up a prefix string by calling this multiple times, +// each with more significant identifier. 
+void ChannelArguments::SetUserAgentPrefix( + const std::string& user_agent_prefix) { + if (user_agent_prefix.empty()) { + return; + } + bool replaced = false; + auto strings_it = strings_.begin(); + for (auto& arg : args_) { + ++strings_it; + if (arg.type == GRPC_ARG_STRING) { + if (std::string(arg.key) == GRPC_ARG_PRIMARY_USER_AGENT_STRING) { + GPR_ASSERT(arg.value.string == strings_it->c_str()); + *(strings_it) = user_agent_prefix + " " + arg.value.string; + arg.value.string = const_cast(strings_it->c_str()); + replaced = true; + break; + } + ++strings_it; + } + } + if (!replaced) { + SetString(GRPC_ARG_PRIMARY_USER_AGENT_STRING, user_agent_prefix); + } +} + +void ChannelArguments::SetResourceQuota( + const grpc::ResourceQuota& resource_quota) { + SetPointerWithVtable(GRPC_ARG_RESOURCE_QUOTA, + resource_quota.c_resource_quota(), + grpc_resource_quota_arg_vtable()); +} + +void ChannelArguments::SetMaxReceiveMessageSize(int size) { + SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, size); +} + +void ChannelArguments::SetMaxSendMessageSize(int size) { + SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, size); +} + +void ChannelArguments::SetLoadBalancingPolicyName( + const std::string& lb_policy_name) { + SetString(GRPC_ARG_LB_POLICY_NAME, lb_policy_name); +} + +void ChannelArguments::SetServiceConfigJSON( + const std::string& service_config_json) { + SetString(GRPC_ARG_SERVICE_CONFIG, service_config_json); +} + +void ChannelArguments::SetInt(const std::string& key, int value) { + grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + strings_.push_back(key); + arg.key = const_cast(strings_.back().c_str()); + arg.value.integer = value; + + args_.push_back(arg); +} + +void ChannelArguments::SetPointer(const std::string& key, void* value) { + static const grpc_arg_pointer_vtable vtable = { + &PointerVtableMembers::Copy, &PointerVtableMembers::Destroy, + &PointerVtableMembers::Compare}; + SetPointerWithVtable(key, value, &vtable); +} + +void ChannelArguments::SetPointerWithVtable( + const std::string& key, void* value, + const grpc_arg_pointer_vtable* vtable) { + grpc_arg arg; + arg.type = GRPC_ARG_POINTER; + strings_.push_back(key); + arg.key = const_cast(strings_.back().c_str()); + arg.value.pointer.p = vtable->copy(value); + arg.value.pointer.vtable = vtable; + args_.push_back(arg); +} + +void ChannelArguments::SetString(const std::string& key, + const std::string& value) { + grpc_arg arg; + arg.type = GRPC_ARG_STRING; + strings_.push_back(key); + arg.key = const_cast(strings_.back().c_str()); + strings_.push_back(value); + arg.value.string = const_cast(strings_.back().c_str()); + + args_.push_back(arg); +} + +void ChannelArguments::SetChannelArgs(grpc_channel_args* channel_args) const { + channel_args->num_args = args_.size(); + if (channel_args->num_args > 0) { + channel_args->args = const_cast(&args_[0]); + } +} + +} // namespace grpc diff --git a/src/cpp/common/channel_filter.cc b/src/cpp/common/channel_filter.cc new file mode 100644 index 00000000..373aebf6 --- /dev/null +++ b/src/cpp/common/channel_filter.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/cpp/common/channel_filter.h" + +#include + +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" + +namespace grpc { + +// MetadataBatch + +grpc_linked_mdelem* MetadataBatch::AddMetadata(const string& key, + const string& value) { + grpc_linked_mdelem* storage = new grpc_linked_mdelem; + storage->md = grpc_mdelem_from_slices(SliceFromCopiedString(key), + SliceFromCopiedString(value)); + GRPC_LOG_IF_ERROR("MetadataBatch::AddMetadata", batch_->LinkHead(storage)); + return storage; +} + +// ChannelData + +void ChannelData::StartTransportOp(grpc_channel_element* elem, + TransportOp* op) { + grpc_channel_next_op(elem, op->op()); +} + +void ChannelData::GetInfo(grpc_channel_element* elem, + const grpc_channel_info* channel_info) { + grpc_channel_next_get_info(elem, channel_info); +} + +// CallData + +void CallData::StartTransportStreamOpBatch(grpc_call_element* elem, + TransportStreamOpBatch* op) { + grpc_call_next_op(elem, op->op()); +} + +void CallData::SetPollsetOrPollsetSet(grpc_call_element* elem, + grpc_polling_entity* pollent) { + grpc_call_stack_ignore_set_pollset_or_pollset_set(elem, pollent); +} + +namespace internal { + +void RegisterChannelFilter( + grpc_channel_stack_type stack_type, int priority, + std::function include_filter, + const grpc_channel_filter* filter) { + auto maybe_add_filter = [include_filter, + filter](grpc_channel_stack_builder* builder) { + if (include_filter != nullptr) { + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (!include_filter(*args)) return true; + } + return grpc_channel_stack_builder_prepend_filter(builder, filter, nullptr, + nullptr); + }; + grpc_core::CoreConfiguration::RegisterBuilder( + [stack_type, priority, + maybe_add_filter](grpc_core::CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage(stack_type, priority, + maybe_add_filter); + }); +} + +} // namespace internal + +} // namespace grpc diff --git a/src/cpp/common/completion_queue_cc.cc b/src/cpp/common/completion_queue_cc.cc new file mode 100644 index 00000000..b962c527 --- /dev/null +++ b/src/cpp/common/completion_queue_cc.cc @@ -0,0 +1,206 @@ +/* + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/thd.h" + +namespace grpc { +namespace { + +internal::GrpcLibraryInitializer g_gli_initializer; + +gpr_once g_once_init_callback_alternative = GPR_ONCE_INIT; +grpc_core::Mutex* g_callback_alternative_mu; + +// Implement a ref-counted callback CQ for global use in the alternative +// implementation so that its threads are only created once. Do this using +// explicit ref-counts and raw pointers rather than a shared-ptr since that +// has a non-trivial destructor and thus can't be used for global variables. +struct CallbackAlternativeCQ { + int refs ABSL_GUARDED_BY(g_callback_alternative_mu) = 0; + CompletionQueue* cq ABSL_GUARDED_BY(g_callback_alternative_mu); + std::vector* nexting_threads + ABSL_GUARDED_BY(g_callback_alternative_mu); + + CompletionQueue* Ref() { + grpc_core::MutexLock lock(&*g_callback_alternative_mu); + refs++; + if (refs == 1) { + cq = new CompletionQueue; + int num_nexting_threads = + grpc_core::Clamp(gpr_cpu_num_cores() / 2, 2u, 16u); + nexting_threads = new std::vector; + for (int i = 0; i < num_nexting_threads; i++) { + nexting_threads->emplace_back( + "nexting_thread", + [](void* arg) { + grpc_completion_queue* cq = + static_cast(arg)->cq(); + while (true) { + // Use the raw Core next function rather than the C++ Next since + // Next incorporates FinalizeResult and we actually want that + // called from the callback functor itself. + // TODO(vjpai): Migrate below to next without a timeout or idle + // phase. That's currently starving out some other polling, + // though. + auto ev = grpc_completion_queue_next( + cq, + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(1000, GPR_TIMESPAN)), + nullptr); + if (ev.type == GRPC_QUEUE_SHUTDOWN) { + return; + } + if (ev.type == GRPC_QUEUE_TIMEOUT) { + gpr_sleep_until( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(100, GPR_TIMESPAN))); + continue; + } + GPR_DEBUG_ASSERT(ev.type == GRPC_OP_COMPLETE); + // We can always execute the callback inline rather than + // pushing it to another Executor thread because this + // thread is definitely running on a background thread, does not + // hold any application locks before executing the callback, + // and cannot be entered recursively. + auto* functor = + static_cast(ev.tag); + functor->functor_run(functor, ev.success); + } + }, + cq); + } + for (auto& th : *nexting_threads) { + th.Start(); + } + } + return cq; + } + + void Unref() { + grpc_core::MutexLock lock(g_callback_alternative_mu); + refs--; + if (refs == 0) { + cq->Shutdown(); + for (auto& th : *nexting_threads) { + th.Join(); + } + delete nexting_threads; + delete cq; + } + } +}; + +CallbackAlternativeCQ g_callback_alternative_cq; + +} // namespace + +// 'CompletionQueue' constructor can safely call GrpcLibraryCodegen(false) here +// i.e not have GrpcLibraryCodegen call grpc_init(). This is because, to create +// a 'grpc_completion_queue' instance (which is being passed as the input to +// this constructor), one must have already called grpc_init(). 
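A short sketch of how the Next/AsyncNext machinery implemented below is typically consumed by application code; the PendingOp tag type and the 100 ms poll deadline are illustrative assumptions, not part of the gRPC API:

    #include <chrono>

    #include <grpcpp/grpcpp.h>

    // Illustrative tag type: real code attaches per-operation state like this to
    // each async call and resumes it when the completion queue reports the tag.
    struct PendingOp {
      void OnComplete(bool ok) { /* resume the RPC state machine here */ }
    };

    // Polls the queue once with a short deadline, handling the three NextStatus
    // values produced by AsyncNextInternal(). Returns false only after the queue
    // has been shut down and fully drained.
    bool PollCompletionQueueOnce(grpc::CompletionQueue* cq) {
      void* tag = nullptr;
      bool ok = false;
      auto deadline =
          std::chrono::system_clock::now() + std::chrono::milliseconds(100);
      switch (cq->AsyncNext(&tag, &ok, deadline)) {
        case grpc::CompletionQueue::GOT_EVENT:
          static_cast<PendingOp*>(tag)->OnComplete(ok);
          return true;
        case grpc::CompletionQueue::TIMEOUT:
          return true;  // nothing ready yet; the caller may poll again
        case grpc::CompletionQueue::SHUTDOWN:
          return false;
      }
      return false;
    }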
+CompletionQueue::CompletionQueue(grpc_completion_queue* take) + : GrpcLibraryCodegen(false), cq_(take) { + InitialAvalanching(); +} + +void CompletionQueue::Shutdown() { + g_gli_initializer.summon(); +#ifndef NDEBUG + if (!ServerListEmpty()) { + gpr_log(GPR_ERROR, + "CompletionQueue shutdown being shutdown before its server."); + } +#endif + CompleteAvalanching(); +} + +CompletionQueue::NextStatus CompletionQueue::AsyncNextInternal( + void** tag, bool* ok, gpr_timespec deadline) { + for (;;) { + auto ev = grpc_completion_queue_next(cq_, deadline, nullptr); + switch (ev.type) { + case GRPC_QUEUE_TIMEOUT: + return TIMEOUT; + case GRPC_QUEUE_SHUTDOWN: + return SHUTDOWN; + case GRPC_OP_COMPLETE: + auto core_cq_tag = + static_cast<::grpc::internal::CompletionQueueTag*>(ev.tag); + *ok = ev.success != 0; + *tag = core_cq_tag; + if (core_cq_tag->FinalizeResult(tag, ok)) { + return GOT_EVENT; + } + break; + } + } +} + +CompletionQueue::CompletionQueueTLSCache::CompletionQueueTLSCache( + CompletionQueue* cq) + : cq_(cq), flushed_(false) { + grpc_completion_queue_thread_local_cache_init(cq_->cq_); +} + +CompletionQueue::CompletionQueueTLSCache::~CompletionQueueTLSCache() { + GPR_ASSERT(flushed_); +} + +bool CompletionQueue::CompletionQueueTLSCache::Flush(void** tag, bool* ok) { + int res = 0; + void* res_tag; + flushed_ = true; + if (grpc_completion_queue_thread_local_cache_flush(cq_->cq_, &res_tag, + &res)) { + auto core_cq_tag = + static_cast<::grpc::internal::CompletionQueueTag*>(res_tag); + *ok = res == 1; + if (core_cq_tag->FinalizeResult(tag, ok)) { + return true; + } + } + return false; +} + +CompletionQueue* CompletionQueue::CallbackAlternativeCQ() { + gpr_once_init(&g_once_init_callback_alternative, + [] { g_callback_alternative_mu = new grpc_core::Mutex(); }); + return g_callback_alternative_cq.Ref(); +} + +void CompletionQueue::ReleaseCallbackAlternativeCQ(CompletionQueue* cq) + ABSL_NO_THREAD_SAFETY_ANALYSIS { + (void)cq; + // This accesses g_callback_alternative_cq without acquiring the mutex + // but it's considered safe because it just reads the pointer address. + GPR_DEBUG_ASSERT(cq == g_callback_alternative_cq.cq); + g_callback_alternative_cq.Unref(); +} + +} // namespace grpc diff --git a/src/cpp/common/core_codegen.cc b/src/cpp/common/core_codegen.cc new file mode 100644 index 00000000..da964f66 --- /dev/null +++ b/src/cpp/common/core_codegen.cc @@ -0,0 +1,244 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/profiling/timers.h" + +struct grpc_byte_buffer; + +namespace grpc { + +const grpc_completion_queue_factory* +CoreCodegen::grpc_completion_queue_factory_lookup( + const grpc_completion_queue_attributes* attributes) { + return ::grpc_completion_queue_factory_lookup(attributes); +} + +grpc_completion_queue* CoreCodegen::grpc_completion_queue_create( + const grpc_completion_queue_factory* factory, + const grpc_completion_queue_attributes* attributes, void* reserved) { + return ::grpc_completion_queue_create(factory, attributes, reserved); +} + +grpc_completion_queue* CoreCodegen::grpc_completion_queue_create_for_next( + void* reserved) { + return ::grpc_completion_queue_create_for_next(reserved); +} + +grpc_completion_queue* CoreCodegen::grpc_completion_queue_create_for_pluck( + void* reserved) { + return ::grpc_completion_queue_create_for_pluck(reserved); +} + +void CoreCodegen::grpc_completion_queue_shutdown(grpc_completion_queue* cq) { + ::grpc_completion_queue_shutdown(cq); +} + +void CoreCodegen::grpc_completion_queue_destroy(grpc_completion_queue* cq) { + ::grpc_completion_queue_destroy(cq); +} + +grpc_event CoreCodegen::grpc_completion_queue_pluck(grpc_completion_queue* cq, + void* tag, + gpr_timespec deadline, + void* reserved) { + return ::grpc_completion_queue_pluck(cq, tag, deadline, reserved); +} + +void* CoreCodegen::gpr_malloc(size_t size) { return ::gpr_malloc(size); } + +void CoreCodegen::gpr_free(void* p) { return ::gpr_free(p); } + +void CoreCodegen::grpc_init() { ::grpc_init(); } +void CoreCodegen::grpc_shutdown() { ::grpc_shutdown(); } + +void CoreCodegen::gpr_mu_init(gpr_mu* mu) { ::gpr_mu_init(mu); } +void CoreCodegen::gpr_mu_destroy(gpr_mu* mu) { ::gpr_mu_destroy(mu); } +void CoreCodegen::gpr_mu_lock(gpr_mu* mu) { ::gpr_mu_lock(mu); } +void CoreCodegen::gpr_mu_unlock(gpr_mu* mu) { ::gpr_mu_unlock(mu); } +void CoreCodegen::gpr_cv_init(gpr_cv* cv) { ::gpr_cv_init(cv); } +void CoreCodegen::gpr_cv_destroy(gpr_cv* cv) { ::gpr_cv_destroy(cv); } +int CoreCodegen::gpr_cv_wait(gpr_cv* cv, gpr_mu* mu, + gpr_timespec abs_deadline) { + return ::gpr_cv_wait(cv, mu, abs_deadline); +} +void CoreCodegen::gpr_cv_signal(gpr_cv* cv) { ::gpr_cv_signal(cv); } +void CoreCodegen::gpr_cv_broadcast(gpr_cv* cv) { ::gpr_cv_broadcast(cv); } + +grpc_byte_buffer* CoreCodegen::grpc_byte_buffer_copy(grpc_byte_buffer* bb) { + return ::grpc_byte_buffer_copy(bb); +} + +void CoreCodegen::grpc_byte_buffer_destroy(grpc_byte_buffer* bb) { + ::grpc_byte_buffer_destroy(bb); +} + +size_t CoreCodegen::grpc_byte_buffer_length(grpc_byte_buffer* bb) { + return ::grpc_byte_buffer_length(bb); +} + +grpc_call_error CoreCodegen::grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, + size_t nops, void* tag, + void* reserved) { + return ::grpc_call_start_batch(call, ops, nops, tag, reserved); +} + +grpc_call_error CoreCodegen::grpc_call_cancel_with_status( + grpc_call* call, grpc_status_code status, const char* description, + void* reserved) { + return ::grpc_call_cancel_with_status(call, status, description, reserved); +} + +int CoreCodegen::grpc_call_failed_before_recv_message(const grpc_call* c) { + return ::grpc_call_failed_before_recv_message(c); +} +void CoreCodegen::grpc_call_ref(grpc_call* call) { ::grpc_call_ref(call); } +void CoreCodegen::grpc_call_unref(grpc_call* call) { ::grpc_call_unref(call); } +void* 
CoreCodegen::grpc_call_arena_alloc(grpc_call* call, size_t length) { + return ::grpc_call_arena_alloc(call, length); +} +const char* CoreCodegen::grpc_call_error_to_string(grpc_call_error error) { + return ::grpc_call_error_to_string(error); +} + +int CoreCodegen::grpc_byte_buffer_reader_init(grpc_byte_buffer_reader* reader, + grpc_byte_buffer* buffer) { + return ::grpc_byte_buffer_reader_init(reader, buffer); +} + +void CoreCodegen::grpc_byte_buffer_reader_destroy( + grpc_byte_buffer_reader* reader) { + ::grpc_byte_buffer_reader_destroy(reader); +} + +int CoreCodegen::grpc_byte_buffer_reader_next(grpc_byte_buffer_reader* reader, + grpc_slice* slice) { + return ::grpc_byte_buffer_reader_next(reader, slice); +} + +int CoreCodegen::grpc_byte_buffer_reader_peek(grpc_byte_buffer_reader* reader, + grpc_slice** slice) { + return ::grpc_byte_buffer_reader_peek(reader, slice); +} + +grpc_byte_buffer* CoreCodegen::grpc_raw_byte_buffer_create(grpc_slice* slice, + size_t nslices) { + return ::grpc_raw_byte_buffer_create(slice, nslices); +} + +grpc_slice CoreCodegen::grpc_slice_new_with_user_data(void* p, size_t len, + void (*destroy)(void*), + void* user_data) { + return ::grpc_slice_new_with_user_data(p, len, destroy, user_data); +} + +grpc_slice CoreCodegen::grpc_slice_new_with_len(void* p, size_t len, + void (*destroy)(void*, + size_t)) { + return ::grpc_slice_new_with_len(p, len, destroy); +} + +grpc_slice CoreCodegen::grpc_empty_slice() { return ::grpc_empty_slice(); } + +grpc_slice CoreCodegen::grpc_slice_malloc(size_t length) { + return ::grpc_slice_malloc(length); +} + +void CoreCodegen::grpc_slice_unref(grpc_slice slice) { + ::grpc_slice_unref(slice); +} + +grpc_slice CoreCodegen::grpc_slice_ref(grpc_slice slice) { + return ::grpc_slice_ref(slice); +} + +grpc_slice CoreCodegen::grpc_slice_split_tail(grpc_slice* s, size_t split) { + return ::grpc_slice_split_tail(s, split); +} + +grpc_slice CoreCodegen::grpc_slice_split_head(grpc_slice* s, size_t split) { + return ::grpc_slice_split_head(s, split); +} + +grpc_slice CoreCodegen::grpc_slice_sub(grpc_slice s, size_t begin, size_t end) { + return ::grpc_slice_sub(s, begin, end); +} + +grpc_slice CoreCodegen::grpc_slice_from_static_buffer(const void* buffer, + size_t length) { + return ::grpc_slice_from_static_buffer(buffer, length); +} + +grpc_slice CoreCodegen::grpc_slice_from_copied_buffer(const void* buffer, + size_t length) { + return ::grpc_slice_from_copied_buffer(static_cast(buffer), + length); +} + +void CoreCodegen::grpc_slice_buffer_add(grpc_slice_buffer* sb, + grpc_slice slice) { + ::grpc_slice_buffer_add(sb, slice); +} + +void CoreCodegen::grpc_slice_buffer_pop(grpc_slice_buffer* sb) { + ::grpc_slice_buffer_pop(sb); +} + +void CoreCodegen::grpc_metadata_array_init(grpc_metadata_array* array) { + ::grpc_metadata_array_init(array); +} + +void CoreCodegen::grpc_metadata_array_destroy(grpc_metadata_array* array) { + ::grpc_metadata_array_destroy(array); +} + +const Status& CoreCodegen::ok() { return grpc::Status::OK; } + +const Status& CoreCodegen::cancelled() { return grpc::Status::CANCELLED; } + +gpr_timespec CoreCodegen::gpr_inf_future(gpr_clock_type type) { + return ::gpr_inf_future(type); +} + +gpr_timespec CoreCodegen::gpr_time_0(gpr_clock_type type) { + return ::gpr_time_0(type); +} + +void CoreCodegen::assert_fail(const char* failed_assertion, const char* file, + int line) { + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, "assertion failed: %s", + failed_assertion); + abort(); +} + +} // namespace grpc diff --git 
a/src/cpp/common/insecure_create_auth_context.cc b/src/cpp/common/insecure_create_auth_context.cc new file mode 100644 index 00000000..4e5cbd03 --- /dev/null +++ b/src/cpp/common/insecure_create_auth_context.cc @@ -0,0 +1,30 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include +#include + +namespace grpc { + +std::shared_ptr CreateAuthContext(grpc_call* call) { + (void)call; + return std::shared_ptr(); +} + +} // namespace grpc diff --git a/src/cpp/common/resource_quota_cc.cc b/src/cpp/common/resource_quota_cc.cc new file mode 100644 index 00000000..25aa01ed --- /dev/null +++ b/src/cpp/common/resource_quota_cc.cc @@ -0,0 +1,40 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +namespace grpc { + +ResourceQuota::ResourceQuota() : impl_(grpc_resource_quota_create(nullptr)) {} + +ResourceQuota::ResourceQuota(const std::string& name) + : impl_(grpc_resource_quota_create(name.c_str())) {} + +ResourceQuota::~ResourceQuota() { grpc_resource_quota_unref(impl_); } + +ResourceQuota& ResourceQuota::Resize(size_t new_size) { + grpc_resource_quota_resize(impl_, new_size); + return *this; +} + +ResourceQuota& ResourceQuota::SetMaxThreads(int new_max_threads) { + grpc_resource_quota_set_max_threads(impl_, new_max_threads); + return *this; +} +} // namespace grpc diff --git a/src/cpp/common/rpc_method.cc b/src/cpp/common/rpc_method.cc new file mode 100644 index 00000000..a47dd3e4 --- /dev/null +++ b/src/cpp/common/rpc_method.cc @@ -0,0 +1,21 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +namespace grpc {} // namespace grpc diff --git a/src/cpp/common/secure_auth_context.cc b/src/cpp/common/secure_auth_context.cc new file mode 100644 index 00000000..0ba01e55 --- /dev/null +++ b/src/cpp/common/secure_auth_context.cc @@ -0,0 +1,97 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/cpp/common/secure_auth_context.h" + +#include + +namespace grpc { + +std::vector SecureAuthContext::GetPeerIdentity() const { + if (ctx_ == nullptr) { + return std::vector(); + } + grpc_auth_property_iterator iter = + grpc_auth_context_peer_identity(ctx_.get()); + std::vector identity; + const grpc_auth_property* property = nullptr; + while ((property = grpc_auth_property_iterator_next(&iter))) { + identity.push_back( + grpc::string_ref(property->value, property->value_length)); + } + return identity; +} + +std::string SecureAuthContext::GetPeerIdentityPropertyName() const { + if (ctx_ == nullptr) { + return ""; + } + const char* name = grpc_auth_context_peer_identity_property_name(ctx_.get()); + return name == nullptr ? "" : name; +} + +std::vector SecureAuthContext::FindPropertyValues( + const std::string& name) const { + if (ctx_ == nullptr) { + return std::vector(); + } + grpc_auth_property_iterator iter = + grpc_auth_context_find_properties_by_name(ctx_.get(), name.c_str()); + const grpc_auth_property* property = nullptr; + std::vector values; + while ((property = grpc_auth_property_iterator_next(&iter))) { + values.push_back(grpc::string_ref(property->value, property->value_length)); + } + return values; +} + +AuthPropertyIterator SecureAuthContext::begin() const { + if (ctx_ != nullptr) { + grpc_auth_property_iterator iter = + grpc_auth_context_property_iterator(ctx_.get()); + const grpc_auth_property* property = + grpc_auth_property_iterator_next(&iter); + return AuthPropertyIterator(property, &iter); + } else { + return end(); + } +} + +AuthPropertyIterator SecureAuthContext::end() const { + return AuthPropertyIterator(); +} + +void SecureAuthContext::AddProperty(const std::string& key, + const grpc::string_ref& value) { + if (ctx_ == nullptr) return; + grpc_auth_context_add_property(ctx_.get(), key.c_str(), value.data(), + value.size()); +} + +bool SecureAuthContext::SetPeerIdentityPropertyName(const std::string& name) { + if (ctx_ == nullptr) return false; + return grpc_auth_context_set_peer_identity_property_name(ctx_.get(), + name.c_str()) != 0; +} + +bool SecureAuthContext::IsPeerAuthenticated() const { + if (ctx_ == nullptr) return false; + return grpc_auth_context_peer_is_authenticated(ctx_.get()) != 0; +} + +} // namespace grpc diff --git a/src/cpp/common/secure_channel_arguments.cc b/src/cpp/common/secure_channel_arguments.cc new file mode 100644 index 00000000..339d94b1 --- /dev/null +++ b/src/cpp/common/secure_channel_arguments.cc @@ -0,0 +1,39 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "src/core/lib/channel/channel_args.h" + +namespace grpc { + +void ChannelArguments::SetSslTargetNameOverride(const std::string& name) { + SetString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG, name); +} + +std::string ChannelArguments::GetSslTargetNameOverride() const { + for (unsigned int i = 0; i < args_.size(); i++) { + if (std::string(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG) == args_[i].key) { + return args_[i].value.string; + } + } + return ""; +} + +} // namespace grpc diff --git a/src/cpp/common/secure_create_auth_context.cc b/src/cpp/common/secure_create_auth_context.cc new file mode 100644 index 00000000..6633374b --- /dev/null +++ b/src/cpp/common/secure_create_auth_context.cc @@ -0,0 +1,37 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include + +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/cpp/common/secure_auth_context.h" + +namespace grpc { + +std::shared_ptr CreateAuthContext(grpc_call* call) { + if (call == nullptr) { + return std::shared_ptr(); + } + grpc_core::RefCountedPtr ctx(grpc_call_auth_context(call)); + return std::make_shared(ctx.get()); +} + +} // namespace grpc diff --git a/src/cpp/common/tls_certificate_provider.cc b/src/cpp/common/tls_certificate_provider.cc new file mode 100644 index 00000000..62a2a5cf --- /dev/null +++ b/src/cpp/common/tls_certificate_provider.cc @@ -0,0 +1,59 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include "absl/container/inlined_vector.h" + +#include +#include +#include + +namespace grpc { +namespace experimental { + +StaticDataCertificateProvider::StaticDataCertificateProvider( + const std::string& root_certificate, + const std::vector& identity_key_cert_pairs) { + GPR_ASSERT(!root_certificate.empty() || !identity_key_cert_pairs.empty()); + grpc_tls_identity_pairs* pairs_core = grpc_tls_identity_pairs_create(); + for (const IdentityKeyCertPair& pair : identity_key_cert_pairs) { + grpc_tls_identity_pairs_add_pair(pairs_core, pair.private_key.c_str(), + pair.certificate_chain.c_str()); + } + c_provider_ = grpc_tls_certificate_provider_static_data_create( + root_certificate.c_str(), pairs_core); + GPR_ASSERT(c_provider_ != nullptr); +}; + +StaticDataCertificateProvider::~StaticDataCertificateProvider() { + grpc_tls_certificate_provider_release(c_provider_); +}; + +FileWatcherCertificateProvider::FileWatcherCertificateProvider( + const std::string& private_key_path, + const std::string& identity_certificate_path, + const std::string& root_cert_path, unsigned int refresh_interval_sec) { + c_provider_ = grpc_tls_certificate_provider_file_watcher_create( + private_key_path.c_str(), identity_certificate_path.c_str(), + root_cert_path.c_str(), refresh_interval_sec); + GPR_ASSERT(c_provider_ != nullptr); +}; + +FileWatcherCertificateProvider::~FileWatcherCertificateProvider() { + grpc_tls_certificate_provider_release(c_provider_); +}; + +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/common/tls_credentials_options.cc b/src/cpp/common/tls_credentials_options.cc new file mode 100644 index 00000000..8e672308 --- /dev/null +++ b/src/cpp/common/tls_credentials_options.cc @@ -0,0 +1,189 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "absl/container/inlined_vector.h" + +#include +#include +#include + +#include "src/cpp/common/tls_credentials_options_util.h" + +namespace grpc { +namespace experimental { + +/** gRPC TLS server authorization check arg API implementation **/ +TlsServerAuthorizationCheckArg::TlsServerAuthorizationCheckArg( + grpc_tls_server_authorization_check_arg* arg) + : c_arg_(arg) { + GPR_ASSERT(c_arg_ != nullptr); + if (c_arg_->context != nullptr) { + gpr_log(GPR_ERROR, "c_arg context has already been set"); + } + c_arg_->context = static_cast(this); + c_arg_->destroy_context = &TlsServerAuthorizationCheckArgDestroyContext; +} + +TlsServerAuthorizationCheckArg::~TlsServerAuthorizationCheckArg() {} + +void* TlsServerAuthorizationCheckArg::cb_user_data() const { + return c_arg_->cb_user_data; +} + +int TlsServerAuthorizationCheckArg::success() const { return c_arg_->success; } + +std::string TlsServerAuthorizationCheckArg::target_name() const { + std::string cpp_target_name(c_arg_->target_name); + return cpp_target_name; +} + +std::string TlsServerAuthorizationCheckArg::peer_cert() const { + std::string cpp_peer_cert(c_arg_->peer_cert); + return cpp_peer_cert; +} + +std::string TlsServerAuthorizationCheckArg::peer_cert_full_chain() const { + std::string cpp_peer_cert_full_chain(c_arg_->peer_cert_full_chain); + return cpp_peer_cert_full_chain; +} + +grpc_status_code TlsServerAuthorizationCheckArg::status() const { + return c_arg_->status; +} + +std::string TlsServerAuthorizationCheckArg::error_details() const { + return c_arg_->error_details->error_details(); +} + +void TlsServerAuthorizationCheckArg::set_cb_user_data(void* cb_user_data) { + c_arg_->cb_user_data = cb_user_data; +} + +void TlsServerAuthorizationCheckArg::set_success(int success) { + c_arg_->success = success; +} + +void TlsServerAuthorizationCheckArg::set_target_name( + const std::string& target_name) { + c_arg_->target_name = gpr_strdup(target_name.c_str()); +} + +void TlsServerAuthorizationCheckArg::set_peer_cert( + const std::string& peer_cert) { + c_arg_->peer_cert = gpr_strdup(peer_cert.c_str()); +} + +void TlsServerAuthorizationCheckArg::set_peer_cert_full_chain( + const std::string& peer_cert_full_chain) { + c_arg_->peer_cert_full_chain = gpr_strdup(peer_cert_full_chain.c_str()); +} + +void TlsServerAuthorizationCheckArg::set_status(grpc_status_code status) { + c_arg_->status = status; +} + +void TlsServerAuthorizationCheckArg::set_error_details( + const std::string& error_details) { + c_arg_->error_details->set_error_details(error_details.c_str()); +} + +void TlsServerAuthorizationCheckArg::OnServerAuthorizationCheckDoneCallback() { + if (c_arg_->cb == nullptr) { + gpr_log(GPR_ERROR, "server authorizaton check arg callback API is nullptr"); + return; + } + c_arg_->cb(c_arg_); +} + +TlsServerAuthorizationCheckConfig::TlsServerAuthorizationCheckConfig( + std::shared_ptr + server_authorization_check_interface) + : server_authorization_check_interface_( + std::move(server_authorization_check_interface)) { + c_config_ = grpc_tls_server_authorization_check_config_create( + nullptr, &TlsServerAuthorizationCheckConfigCSchedule, + &TlsServerAuthorizationCheckConfigCCancel, nullptr); + c_config_->set_context(static_cast(this)); +} + +TlsServerAuthorizationCheckConfig::~TlsServerAuthorizationCheckConfig() { + grpc_tls_server_authorization_check_config_release(c_config_); +} + +TlsCredentialsOptions::TlsCredentialsOptions() { + c_credentials_options_ = grpc_tls_credentials_options_create(); +} + +void 
TlsCredentialsOptions::set_certificate_provider( + std::shared_ptr certificate_provider) { + certificate_provider_ = std::move(certificate_provider); + if (certificate_provider_ != nullptr) { + grpc_tls_credentials_options_set_certificate_provider( + c_credentials_options_, certificate_provider_->c_provider()); + } +} + +void TlsCredentialsOptions::watch_root_certs() { + grpc_tls_credentials_options_watch_root_certs(c_credentials_options_); +} + +void TlsCredentialsOptions::set_root_cert_name( + const std::string& root_cert_name) { + grpc_tls_credentials_options_set_root_cert_name(c_credentials_options_, + root_cert_name.c_str()); +} + +void TlsCredentialsOptions::watch_identity_key_cert_pairs() { + grpc_tls_credentials_options_watch_identity_key_cert_pairs( + c_credentials_options_); +} + +void TlsCredentialsOptions::set_identity_cert_name( + const std::string& identity_cert_name) { + grpc_tls_credentials_options_set_identity_cert_name( + c_credentials_options_, identity_cert_name.c_str()); +} + +void TlsChannelCredentialsOptions::set_server_verification_option( + grpc_tls_server_verification_option server_verification_option) { + grpc_tls_credentials_options* options = c_credentials_options(); + GPR_ASSERT(options != nullptr); + grpc_tls_credentials_options_set_server_verification_option( + options, server_verification_option); +} + +void TlsChannelCredentialsOptions::set_server_authorization_check_config( + std::shared_ptr config) { + grpc_tls_credentials_options* options = c_credentials_options(); + GPR_ASSERT(options != nullptr); + if (config != nullptr) { + grpc_tls_credentials_options_set_server_authorization_check_config( + options, config->c_config()); + } +} + +void TlsServerCredentialsOptions::set_cert_request_type( + grpc_ssl_client_certificate_request_type cert_request_type) { + grpc_tls_credentials_options* options = c_credentials_options(); + GPR_ASSERT(options != nullptr); + grpc_tls_credentials_options_set_cert_request_type(options, + cert_request_type); +} + +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/common/tls_credentials_options_util.cc b/src/cpp/common/tls_credentials_options_util.cc new file mode 100644 index 00000000..ebcab262 --- /dev/null +++ b/src/cpp/common/tls_credentials_options_util.cc @@ -0,0 +1,76 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/cpp/common/tls_credentials_options_util.h" + +#include "absl/container/inlined_vector.h" + +#include + +namespace grpc { +namespace experimental { + +/** The C schedule and cancel functions for the server authorization check + * config. They populate a C server authorization check arg with the result + * of a C++ server authorization check schedule/cancel API. 
**/ +int TlsServerAuthorizationCheckConfigCSchedule( + void* /*config_user_data*/, grpc_tls_server_authorization_check_arg* arg) { + if (arg == nullptr || arg->config == nullptr || + arg->config->context() == nullptr) { + gpr_log(GPR_ERROR, + "server authorization check arg was not properly initialized"); + return 1; + } + TlsServerAuthorizationCheckConfig* cpp_config = + static_cast(arg->config->context()); + TlsServerAuthorizationCheckArg* cpp_arg = + new TlsServerAuthorizationCheckArg(arg); + int schedule_result = cpp_config->Schedule(cpp_arg); + return schedule_result; +} + +void TlsServerAuthorizationCheckConfigCCancel( + void* /*config_user_data*/, grpc_tls_server_authorization_check_arg* arg) { + if (arg == nullptr || arg->config == nullptr || + arg->config->context() == nullptr) { + gpr_log(GPR_ERROR, + "server authorization check arg was not properly initialized"); + return; + } + if (arg->context == nullptr) { + gpr_log(GPR_ERROR, + "server authorization check arg schedule has already completed"); + return; + } + TlsServerAuthorizationCheckConfig* cpp_config = + static_cast(arg->config->context()); + TlsServerAuthorizationCheckArg* cpp_arg = + static_cast(arg->context); + cpp_config->Cancel(cpp_arg); +} + +void TlsServerAuthorizationCheckArgDestroyContext(void* context) { + if (context != nullptr) { + TlsServerAuthorizationCheckArg* cpp_arg = + static_cast(context); + delete cpp_arg; + } +} + +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/common/validate_service_config.cc b/src/cpp/common/validate_service_config.cc new file mode 100644 index 00000000..86172d9d --- /dev/null +++ b/src/cpp/common/validate_service_config.cc @@ -0,0 +1,40 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "src/core/ext/service_config/service_config.h" + +namespace grpc { +namespace experimental { +std::string ValidateServiceConfigJSON(const std::string& service_config_json) { + grpc_init(); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::ServiceConfig::Create(/*args=*/nullptr, + service_config_json.c_str(), &error); + std::string return_value; + if (error != GRPC_ERROR_NONE) { + return_value = grpc_error_std_string(error); + GRPC_ERROR_UNREF(error); + } + grpc_shutdown(); + return return_value; +} +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/common/version_cc.cc b/src/cpp/common/version_cc.cc new file mode 100644 index 00000000..99e060b3 --- /dev/null +++ b/src/cpp/common/version_cc.cc @@ -0,0 +1,26 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This file is autogenerated from: + templates/src/core/surface/version.c.template */ + +#include + +namespace grpc { +std::string Version() { return "1.42.0-dev"; } +} // namespace grpc diff --git a/src/cpp/ext/filters/census/channel_filter.cc b/src/cpp/ext/filters/census/channel_filter.cc new file mode 100644 index 00000000..ea3bce9f --- /dev/null +++ b/src/cpp/ext/filters/census/channel_filter.cc @@ -0,0 +1,30 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/ext/filters/census/channel_filter.h" + +namespace grpc { + +grpc_error_handle CensusChannelData::Init(grpc_channel_element* /*elem*/, + grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +} // namespace grpc diff --git a/src/cpp/ext/filters/census/client_filter.cc b/src/cpp/ext/filters/census/client_filter.cc new file mode 100644 index 00000000..e1ae787f --- /dev/null +++ b/src/cpp/ext/filters/census/client_filter.cc @@ -0,0 +1,267 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/cpp/ext/filters/census/client_filter.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "opencensus/stats/stats.h" +#include "opencensus/tags/context_util.h" +#include "opencensus/tags/tag_key.h" +#include "opencensus/tags/tag_map.h" + +#include "src/core/lib/surface/call.h" +#include "src/cpp/ext/filters/census/grpc_plugin.h" +#include "src/cpp/ext/filters/census/measures.h" + +namespace grpc { + +constexpr uint32_t + OpenCensusCallTracer::OpenCensusCallAttemptTracer::kMaxTraceContextLen; +constexpr uint32_t + OpenCensusCallTracer::OpenCensusCallAttemptTracer::kMaxTagsLen; + +grpc_error_handle CensusClientCallData::Init( + grpc_call_element* /* elem */, const grpc_call_element_args* args) { + tracer_ = args->arena->New(args); + GPR_DEBUG_ASSERT(args->context[GRPC_CONTEXT_CALL_TRACER].value == nullptr); + args->context[GRPC_CONTEXT_CALL_TRACER].value = tracer_; + args->context[GRPC_CONTEXT_CALL_TRACER].destroy = [](void* tracer) { + (static_cast(tracer))->~OpenCensusCallTracer(); + }; + return GRPC_ERROR_NONE; +} + +void CensusClientCallData::StartTransportStreamOpBatch( + grpc_call_element* elem, TransportStreamOpBatch* op) { + // Note that we are generating the overall call context here instead of in + // the constructor of `OpenCensusCallTracer` due to the semantics of + // `grpc_census_call_set_context` which allows the application to set the + // census context for a call anytime before the first call to + // `grpc_call_start_batch`. + if (op->op()->send_initial_metadata) { + tracer_->GenerateContext(); + } + grpc_call_next_op(elem, op->op()); +} + +// +// OpenCensusCallTracer::OpenCensusCallAttemptTracer +// + +namespace { + +CensusContext CreateCensusContextForCallAttempt( + absl::string_view method, const CensusContext& parent_context) { + GPR_DEBUG_ASSERT(parent_context.Context().IsValid()); + return CensusContext(absl::StrCat("Attempt.", method), &parent_context.Span(), + parent_context.tags()); +} + +} // namespace + +OpenCensusCallTracer::OpenCensusCallAttemptTracer::OpenCensusCallAttemptTracer( + OpenCensusCallTracer* parent, uint64_t attempt_num, + bool is_transparent_retry, bool arena_allocated) + : parent_(parent), + arena_allocated_(arena_allocated), + context_(CreateCensusContextForCallAttempt(parent_->method_, + parent_->context_)), + start_time_(absl::Now()) { + context_.AddSpanAttribute("previous-rpc-attempts", attempt_num); + context_.AddSpanAttribute("transparent-retry", is_transparent_retry); + memset(&stats_bin_, 0, sizeof(grpc_linked_mdelem)); + memset(&tracing_bin_, 0, sizeof(grpc_linked_mdelem)); +} + +void OpenCensusCallTracer::OpenCensusCallAttemptTracer:: + RecordSendInitialMetadata(grpc_metadata_batch* send_initial_metadata, + uint32_t /* flags */) { + size_t tracing_len = TraceContextSerialize(context_.Context(), tracing_buf_, + kMaxTraceContextLen); + if (tracing_len > 0) { + GRPC_LOG_IF_ERROR( + "census grpc_filter", + grpc_metadata_batch_add_tail( + send_initial_metadata, &tracing_bin_, + grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_TRACE_BIN, + grpc_core::UnmanagedMemorySlice(tracing_buf_, tracing_len)), + GRPC_BATCH_GRPC_TRACE_BIN)); + } + grpc_slice tags = grpc_empty_slice(); + // TODO(unknown): Add in tagging serialization. 
+ size_t encoded_tags_len = StatsContextSerialize(kMaxTagsLen, &tags); + if (encoded_tags_len > 0) { + GRPC_LOG_IF_ERROR( + "census grpc_filter", + grpc_metadata_batch_add_tail( + send_initial_metadata, &stats_bin_, + grpc_mdelem_from_slices(GRPC_MDSTR_GRPC_TAGS_BIN, tags), + GRPC_BATCH_GRPC_TAGS_BIN)); + } +} + +void OpenCensusCallTracer::OpenCensusCallAttemptTracer::RecordSendMessage( + const grpc_core::ByteStream& /* send_message */) { + ++sent_message_count_; +} + +void OpenCensusCallTracer::OpenCensusCallAttemptTracer::RecordReceivedMessage( + const grpc_core::ByteStream& /* recv_message */) { + ++recv_message_count_; +} + +namespace { + +void FilterTrailingMetadata(grpc_metadata_batch* b, uint64_t* elapsed_time) { + if (b->legacy_index()->named.grpc_server_stats_bin != nullptr) { + ServerStatsDeserialize( + reinterpret_cast(GRPC_SLICE_START_PTR( + GRPC_MDVALUE(b->legacy_index()->named.grpc_server_stats_bin->md))), + GRPC_SLICE_LENGTH( + GRPC_MDVALUE(b->legacy_index()->named.grpc_server_stats_bin->md)), + elapsed_time); + b->Remove(b->legacy_index()->named.grpc_server_stats_bin); + } +} + +} // namespace + +void OpenCensusCallTracer::OpenCensusCallAttemptTracer:: + RecordReceivedTrailingMetadata( + absl::Status status, grpc_metadata_batch* recv_trailing_metadata, + const grpc_transport_stream_stats& transport_stream_stats) { + FilterTrailingMetadata(recv_trailing_metadata, &elapsed_time_); + const uint64_t request_size = transport_stream_stats.outgoing.data_bytes; + const uint64_t response_size = transport_stream_stats.incoming.data_bytes; + std::vector> tags = + context_.tags().tags(); + tags.emplace_back(ClientMethodTagKey(), std::string(parent_->method_)); + status_code_ = status.code(); + std::string final_status = absl::StatusCodeToString(status_code_); + tags.emplace_back(ClientStatusTagKey(), final_status); + ::opencensus::stats::Record( + {{RpcClientSentBytesPerRpc(), static_cast(request_size)}, + {RpcClientReceivedBytesPerRpc(), static_cast(response_size)}, + {RpcClientServerLatency(), + ToDoubleMilliseconds(absl::Nanoseconds(elapsed_time_))}}, + tags); +} + +void OpenCensusCallTracer::OpenCensusCallAttemptTracer::RecordCancel( + grpc_error_handle cancel_error) { + status_code_ = absl::StatusCode::kCancelled; + GRPC_ERROR_UNREF(cancel_error); +} + +void OpenCensusCallTracer::OpenCensusCallAttemptTracer::RecordEnd( + const gpr_timespec& /* latency */) { + double latency_ms = absl::ToDoubleMilliseconds(absl::Now() - start_time_); + std::vector> tags = + context_.tags().tags(); + tags.emplace_back(ClientMethodTagKey(), std::string(parent_->method_)); + tags.emplace_back(ClientStatusTagKey(), StatusCodeToString(status_code_)); + ::opencensus::stats::Record( + {{RpcClientRoundtripLatency(), latency_ms}, + {RpcClientSentMessagesPerRpc(), sent_message_count_}, + {RpcClientReceivedMessagesPerRpc(), recv_message_count_}}, + tags); + if (status_code_ != absl::StatusCode::kOk) { + context_.Span().SetStatus(opencensus::trace::StatusCode(status_code_), + StatusCodeToString(status_code_)); + } + context_.EndSpan(); + grpc_core::MutexLock lock(&parent_->mu_); + if (--parent_->num_active_rpcs_ == 0) { + parent_->time_at_last_attempt_end_ = absl::Now(); + } + if (arena_allocated_) { + this->~OpenCensusCallAttemptTracer(); + } else { + delete this; + } +} + +// +// OpenCensusCallTracer +// + +OpenCensusCallTracer::OpenCensusCallTracer(const grpc_call_element_args* args) + : call_context_(args->context), + path_(grpc_slice_ref_internal(args->path)), + method_(GetMethod(&path_)), + 
arena_(args->arena) {} + +OpenCensusCallTracer::~OpenCensusCallTracer() { + std::vector> tags = + context_.tags().tags(); + tags.emplace_back(ClientMethodTagKey(), std::string(method_)); + ::opencensus::stats::Record( + {{RpcClientRetriesPerCall(), retries_ - 1}, // exclude first attempt + {RpcClientTransparentRetriesPerCall(), transparent_retries_}, + {RpcClientRetryDelayPerCall(), ToDoubleMilliseconds(retry_delay_)}}, + tags); + grpc_slice_unref_internal(path_); +} + +void OpenCensusCallTracer::GenerateContext() { + auto* parent_context = reinterpret_cast( + call_context_[GRPC_CONTEXT_TRACING].value); + GenerateClientContext(absl::StrCat("Sent.", method_), &context_, + (parent_context == nullptr) ? nullptr : parent_context); +} + +OpenCensusCallTracer::OpenCensusCallAttemptTracer* +OpenCensusCallTracer::StartNewAttempt(bool is_transparent_retry) { + // We allocate the first attempt on the arena and all subsequent attempts on + // the heap, so that in the common case we don't require a heap allocation, + // nor do we unnecessarily grow the arena. + bool is_first_attempt = true; + uint64_t attempt_num; + { + grpc_core::MutexLock lock(&mu_); + if (transparent_retries_ != 0 || retries_ != 0) { + is_first_attempt = false; + if (num_active_rpcs_ == 0) { + retry_delay_ += absl::Now() - time_at_last_attempt_end_; + } + } + attempt_num = retries_; + if (is_transparent_retry) { + ++transparent_retries_; + } else { + ++retries_; + } + ++num_active_rpcs_; + } + if (is_first_attempt) { + return arena_->New( + this, attempt_num, is_transparent_retry, true /* arena_allocated */); + } + return new OpenCensusCallAttemptTracer( + this, attempt_num, is_transparent_retry, false /* arena_allocated */); +} + +} // namespace grpc diff --git a/src/cpp/ext/filters/census/context.cc b/src/cpp/ext/filters/census/context.cc new file mode 100644 index 00000000..672228a8 --- /dev/null +++ b/src/cpp/ext/filters/census/context.cc @@ -0,0 +1,157 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/ext/filters/census/context.h" + +#include "opencensus/tags/context_util.h" +#include "opencensus/trace/context_util.h" +#include "opencensus/trace/propagation/grpc_trace_bin.h" + +namespace grpc { + +using ::opencensus::tags::TagMap; +using ::opencensus::trace::Span; +using ::opencensus::trace::SpanContext; + +void GenerateServerContext(absl::string_view tracing, absl::string_view method, + CensusContext* context) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. 
+ context->~CensusContext(); + SpanContext parent_ctx = + opencensus::trace::propagation::FromGrpcTraceBinHeader(tracing); + if (parent_ctx.IsValid()) { + new (context) CensusContext(method, parent_ctx); + return; + } + new (context) CensusContext(method, TagMap{}); +} + +void GenerateClientContext(absl::string_view method, CensusContext* ctxt, + CensusContext* parent_ctxt) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + ctxt->~CensusContext(); + if (parent_ctxt != nullptr) { + SpanContext span_ctxt = parent_ctxt->Context(); + Span span = parent_ctxt->Span(); + if (span_ctxt.IsValid()) { + new (ctxt) CensusContext(method, &span, TagMap{}); + return; + } + } + const Span& span = opencensus::trace::GetCurrentSpan(); + const TagMap& tags = opencensus::tags::GetCurrentTagMap(); + if (span.context().IsValid()) { + // Create span with parent. + new (ctxt) CensusContext(method, &span, tags); + return; + } + // Create span without parent. + new (ctxt) CensusContext(method, tags); +} + +size_t TraceContextSerialize(const ::opencensus::trace::SpanContext& context, + char* tracing_buf, size_t tracing_buf_size) { + if (tracing_buf_size < + opencensus::trace::propagation::kGrpcTraceBinHeaderLen) { + return 0; + } + opencensus::trace::propagation::ToGrpcTraceBinHeader( + context, reinterpret_cast(tracing_buf)); + return opencensus::trace::propagation::kGrpcTraceBinHeaderLen; +} + +size_t StatsContextSerialize(size_t /*max_tags_len*/, grpc_slice* /*tags*/) { + // TODO(unknown): Add implementation. Waiting on stats tagging to be added. + return 0; +} + +size_t ServerStatsSerialize(uint64_t server_elapsed_time, char* buf, + size_t buf_size) { + return RpcServerStatsEncoding::Encode(server_elapsed_time, buf, buf_size); +} + +size_t ServerStatsDeserialize(const char* buf, size_t buf_size, + uint64_t* server_elapsed_time) { + return RpcServerStatsEncoding::Decode(absl::string_view(buf, buf_size), + server_elapsed_time); +} + +uint64_t GetIncomingDataSize(const grpc_call_final_info* final_info) { + return final_info->stats.transport_stream_stats.incoming.data_bytes; +} + +uint64_t GetOutgoingDataSize(const grpc_call_final_info* final_info) { + return final_info->stats.transport_stream_stats.outgoing.data_bytes; +} + +SpanContext SpanContextFromCensusContext(const census_context* ctxt) { + return reinterpret_cast(ctxt)->Context(); +} + +Span SpanFromCensusContext(const census_context* ctxt) { + return reinterpret_cast(ctxt)->Span(); +} + +absl::string_view StatusCodeToString(grpc_status_code code) { + switch (code) { + case GRPC_STATUS_OK: + return "OK"; + case GRPC_STATUS_CANCELLED: + return "CANCELLED"; + case GRPC_STATUS_UNKNOWN: + return "UNKNOWN"; + case GRPC_STATUS_INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case GRPC_STATUS_DEADLINE_EXCEEDED: + return "DEADLINE_EXCEEDED"; + case GRPC_STATUS_NOT_FOUND: + return "NOT_FOUND"; + case GRPC_STATUS_ALREADY_EXISTS: + return "ALREADY_EXISTS"; + case GRPC_STATUS_PERMISSION_DENIED: + return "PERMISSION_DENIED"; + case GRPC_STATUS_UNAUTHENTICATED: + return "UNAUTHENTICATED"; + case GRPC_STATUS_RESOURCE_EXHAUSTED: + return "RESOURCE_EXHAUSTED"; + case GRPC_STATUS_FAILED_PRECONDITION: + return "FAILED_PRECONDITION"; + case GRPC_STATUS_ABORTED: + return "ABORTED"; + case GRPC_STATUS_OUT_OF_RANGE: + return "OUT_OF_RANGE"; + case GRPC_STATUS_UNIMPLEMENTED: + return "UNIMPLEMENTED"; + case GRPC_STATUS_INTERNAL: + return "INTERNAL"; + case GRPC_STATUS_UNAVAILABLE: + return "UNAVAILABLE"; + case 
GRPC_STATUS_DATA_LOSS: + return "DATA_LOSS"; + default: + // gRPC wants users of this enum to include a default branch so that + // adding values is not a breaking change. + return "UNKNOWN_STATUS"; + } +} + +} // namespace grpc diff --git a/src/cpp/ext/filters/census/grpc_plugin.cc b/src/cpp/ext/filters/census/grpc_plugin.cc new file mode 100644 index 00000000..41ded5cb --- /dev/null +++ b/src/cpp/ext/filters/census/grpc_plugin.cc @@ -0,0 +1,147 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/ext/filters/census/grpc_plugin.h" + +#include "opencensus/tags/tag_key.h" +#include "opencensus/trace/span.h" + +#include + +#include "src/cpp/ext/filters/census/channel_filter.h" +#include "src/cpp/ext/filters/census/client_filter.h" +#include "src/cpp/ext/filters/census/measures.h" +#include "src/cpp/ext/filters/census/server_filter.h" + +namespace grpc { + +void RegisterOpenCensusPlugin() { + RegisterChannelFilter( + "opencensus_client", GRPC_CLIENT_CHANNEL, INT_MAX /* priority */, + nullptr /* condition function */); + RegisterChannelFilter( + "opencensus_server", GRPC_SERVER_CHANNEL, INT_MAX /* priority */, + nullptr /* condition function */); + + // Access measures to ensure they are initialized. Otherwise, creating a view + // before the first RPC would cause an error. + RpcClientSentBytesPerRpc(); + RpcClientReceivedBytesPerRpc(); + RpcClientRoundtripLatency(); + RpcClientServerLatency(); + RpcClientSentMessagesPerRpc(); + RpcClientReceivedMessagesPerRpc(); + RpcClientRetriesPerCall(); + RpcClientTransparentRetriesPerCall(); + RpcClientRetryDelayPerCall(); + + RpcServerSentBytesPerRpc(); + RpcServerReceivedBytesPerRpc(); + RpcServerServerLatency(); + RpcServerSentMessagesPerRpc(); + RpcServerReceivedMessagesPerRpc(); +} + +::opencensus::trace::Span GetSpanFromServerContext( + grpc::ServerContext* context) { + if (context == nullptr) return opencensus::trace::Span::BlankSpan(); + + return reinterpret_cast(context->census_context()) + ->Span(); +} + +// These measure definitions should be kept in sync across opencensus +// implementations--see +// https://github.com/census-instrumentation/opencensus-java/blob/master/contrib/grpc_metrics/src/main/java/io/opencensus/contrib/grpc/metrics/RpcMeasureConstants.java. 
+::opencensus::tags::TagKey ClientMethodTagKey() { + static const auto method_tag_key = + ::opencensus::tags::TagKey::Register("grpc_client_method"); + return method_tag_key; +} + +::opencensus::tags::TagKey ClientStatusTagKey() { + static const auto status_tag_key = + ::opencensus::tags::TagKey::Register("grpc_client_status"); + return status_tag_key; +} + +::opencensus::tags::TagKey ServerMethodTagKey() { + static const auto method_tag_key = + ::opencensus::tags::TagKey::Register("grpc_server_method"); + return method_tag_key; +} + +::opencensus::tags::TagKey ServerStatusTagKey() { + static const auto status_tag_key = + ::opencensus::tags::TagKey::Register("grpc_server_status"); + return status_tag_key; +} + +// Client +ABSL_CONST_INIT const absl::string_view + kRpcClientSentMessagesPerRpcMeasureName = + "grpc.io/client/sent_messages_per_rpc"; + +ABSL_CONST_INIT const absl::string_view kRpcClientSentBytesPerRpcMeasureName = + "grpc.io/client/sent_bytes_per_rpc"; + +ABSL_CONST_INIT const absl::string_view + kRpcClientReceivedMessagesPerRpcMeasureName = + "grpc.io/client/received_messages_per_rpc"; + +ABSL_CONST_INIT const absl::string_view + kRpcClientReceivedBytesPerRpcMeasureName = + "grpc.io/client/received_bytes_per_rpc"; + +ABSL_CONST_INIT const absl::string_view kRpcClientRoundtripLatencyMeasureName = + "grpc.io/client/roundtrip_latency"; + +ABSL_CONST_INIT const absl::string_view kRpcClientServerLatencyMeasureName = + "grpc.io/client/server_latency"; + +ABSL_CONST_INIT const absl::string_view kRpcClientRetriesPerCallMeasureName = + "grpc.io/client/retries_per_call"; + +ABSL_CONST_INIT const absl::string_view + kRpcClientTransparentRetriesPerCallMeasureName = + "grpc.io/client/transparent_retries_per_call"; + +ABSL_CONST_INIT const absl::string_view kRpcClientRetryDelayPerCallMeasureName = + "grpc.io/client/retry_delay_per_call"; + +// Server +ABSL_CONST_INIT const absl::string_view + kRpcServerSentMessagesPerRpcMeasureName = + "grpc.io/server/sent_messages_per_rpc"; + +ABSL_CONST_INIT const absl::string_view kRpcServerSentBytesPerRpcMeasureName = + "grpc.io/server/sent_bytes_per_rpc"; + +ABSL_CONST_INIT const absl::string_view + kRpcServerReceivedMessagesPerRpcMeasureName = + "grpc.io/server/received_messages_per_rpc"; + +ABSL_CONST_INIT const absl::string_view + kRpcServerReceivedBytesPerRpcMeasureName = + "grpc.io/server/received_bytes_per_rpc"; + +ABSL_CONST_INIT const absl::string_view kRpcServerServerLatencyMeasureName = + "grpc.io/server/server_latency"; +} // namespace grpc diff --git a/src/cpp/ext/filters/census/measures.cc b/src/cpp/ext/filters/census/measures.cc new file mode 100644 index 00000000..c02fdce6 --- /dev/null +++ b/src/cpp/ext/filters/census/measures.cc @@ -0,0 +1,156 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/cpp/ext/filters/census/measures.h" + +#include "opencensus/stats/stats.h" + +#include "src/cpp/ext/filters/census/grpc_plugin.h" + +namespace grpc { + +using ::opencensus::stats::MeasureDouble; +using ::opencensus::stats::MeasureInt64; + +// These measure definitions should be kept in sync across opencensus +// implementations--see +// https://github.com/census-instrumentation/opencensus-java/blob/master/contrib/grpc_metrics/src/main/java/io/opencensus/contrib/grpc/metrics/RpcMeasureConstants.java. + +namespace { + +// Unit constants +constexpr char kUnitBytes[] = "By"; +constexpr char kUnitMilliseconds[] = "ms"; +constexpr char kCount[] = "1"; + +} // namespace + +// Client +MeasureDouble RpcClientSentBytesPerRpc() { + static const auto measure = MeasureDouble::Register( + kRpcClientSentBytesPerRpcMeasureName, + "Total bytes sent across all request messages per RPC", kUnitBytes); + return measure; +} + +MeasureDouble RpcClientReceivedBytesPerRpc() { + static const auto measure = MeasureDouble::Register( + kRpcClientReceivedBytesPerRpcMeasureName, + "Total bytes received across all response messages per RPC", kUnitBytes); + return measure; +} + +MeasureDouble RpcClientRoundtripLatency() { + static const auto measure = MeasureDouble::Register( + kRpcClientRoundtripLatencyMeasureName, + "Time between first byte of request sent to last byte of response " + "received, or terminal error", + kUnitMilliseconds); + return measure; +} + +MeasureDouble RpcClientServerLatency() { + static const auto measure = MeasureDouble::Register( + kRpcClientServerLatencyMeasureName, + "Time between first byte of request received to last byte of response " + "sent, or terminal error (propagated from the server)", + kUnitMilliseconds); + return measure; +} + +MeasureInt64 RpcClientSentMessagesPerRpc() { + static const auto measure = + MeasureInt64::Register(kRpcClientSentMessagesPerRpcMeasureName, + "Number of messages sent per RPC", kCount); + return measure; +} + +MeasureInt64 RpcClientReceivedMessagesPerRpc() { + static const auto measure = + MeasureInt64::Register(kRpcClientReceivedMessagesPerRpcMeasureName, + "Number of messages received per RPC", kCount); + return measure; +} + +// Client per-overall-client-call measures +MeasureInt64 RpcClientRetriesPerCall() { + static const auto measure = + MeasureInt64::Register(kRpcClientRetriesPerCallMeasureName, + "Number of retry or hedging attempts excluding " + "transparent retries made during the client call", + kCount); + return measure; +} + +MeasureInt64 RpcClientTransparentRetriesPerCall() { + static const auto measure = MeasureInt64::Register( + kRpcClientTransparentRetriesPerCallMeasureName, + "Number of transparent retries made during the client call", kCount); + return measure; +} + +MeasureDouble RpcClientRetryDelayPerCall() { + static const auto measure = + MeasureDouble::Register(kRpcClientRetryDelayPerCallMeasureName, + "Total time of delay while there is no active " + "attempt during the client call", + kUnitMilliseconds); + return measure; +} + +// Server +MeasureDouble RpcServerSentBytesPerRpc() { + static const auto measure = MeasureDouble::Register( + kRpcServerSentBytesPerRpcMeasureName, + "Total bytes sent across all messages per RPC", kUnitBytes); + return measure; +} + +MeasureDouble RpcServerReceivedBytesPerRpc() { + static const auto measure = MeasureDouble::Register( + kRpcServerReceivedBytesPerRpcMeasureName, + "Total bytes received across all messages per RPC", kUnitBytes); + return measure; +} + 
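+// Editorial sketch (usage illustration only, not part of the original change):
+// the accessors above are what the census filters record against.  Assuming an
+// already-computed value and a fully qualified method name (both illustrative
+// here), a recording looks roughly like:
+//
+//   ::opencensus::stats::Record(
+//       {{RpcServerSentBytesPerRpc(), 1024.0}},
+//       {{ServerMethodTagKey(), "helloworld.Greeter/SayHello"}});
+//
+// Data only becomes visible to exporters once a view over the measure (see
+// views.cc in this change) has been registered.
+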
+MeasureDouble RpcServerServerLatency() { + static const auto measure = MeasureDouble::Register( + kRpcServerServerLatencyMeasureName, + "Time between first byte of request received to last byte of response " + "sent, or terminal error", + kUnitMilliseconds); + return measure; +} + +MeasureInt64 RpcServerSentMessagesPerRpc() { + static const auto measure = + MeasureInt64::Register(kRpcServerSentMessagesPerRpcMeasureName, + "Number of messages sent per RPC", kCount); + return measure; +} + +MeasureInt64 RpcServerReceivedMessagesPerRpc() { + static const auto measure = + MeasureInt64::Register(kRpcServerReceivedMessagesPerRpcMeasureName, + "Number of messages received per RPC", kCount); + return measure; +} + +} // namespace grpc diff --git a/src/cpp/ext/filters/census/rpc_encoding.cc b/src/cpp/ext/filters/census/rpc_encoding.cc new file mode 100644 index 00000000..7ce3e940 --- /dev/null +++ b/src/cpp/ext/filters/census/rpc_encoding.cc @@ -0,0 +1,32 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/ext/filters/census/rpc_encoding.h" + +namespace grpc { + +constexpr size_t RpcServerStatsEncoding::kRpcServerStatsSize; +constexpr size_t RpcServerStatsEncoding::kEncodeDecodeFailure; +constexpr size_t RpcServerStatsEncoding::kVersionIdSize; +constexpr size_t RpcServerStatsEncoding::kFieldIdSize; +constexpr size_t RpcServerStatsEncoding::kVersionIdOffset; +constexpr size_t RpcServerStatsEncoding::kVersionId; + +} // namespace grpc diff --git a/src/cpp/ext/filters/census/server_filter.cc b/src/cpp/ext/filters/census/server_filter.cc new file mode 100644 index 00000000..0a84680b --- /dev/null +++ b/src/cpp/ext/filters/census/server_filter.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include
+
+#include "src/cpp/ext/filters/census/server_filter.h"
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "opencensus/stats/stats.h"
+
+#include "src/core/lib/surface/call.h"
+#include "src/cpp/ext/filters/census/grpc_plugin.h"
+#include "src/cpp/ext/filters/census/measures.h"
+
+namespace grpc {
+
+constexpr uint32_t CensusServerCallData::kMaxServerStatsLen;
+
+namespace {
+
+// server metadata elements
+struct ServerMetadataElements {
+  grpc_slice path;
+  grpc_slice tracing_slice;
+  grpc_slice census_proto;
+};
+
+void FilterInitialMetadata(grpc_metadata_batch* b,
+                           ServerMetadataElements* sml) {
+  if (b->legacy_index()->named.path != nullptr) {
+    sml->path = grpc_slice_ref_internal(
+        GRPC_MDVALUE(b->legacy_index()->named.path->md));
+  }
+  if (b->legacy_index()->named.grpc_trace_bin != nullptr) {
+    sml->tracing_slice = grpc_slice_ref_internal(
+        GRPC_MDVALUE(b->legacy_index()->named.grpc_trace_bin->md));
+    b->Remove(GRPC_BATCH_GRPC_TRACE_BIN);
+  }
+  if (b->legacy_index()->named.grpc_tags_bin != nullptr) {
+    sml->census_proto = grpc_slice_ref_internal(
+        GRPC_MDVALUE(b->legacy_index()->named.grpc_tags_bin->md));
+    b->Remove(GRPC_BATCH_GRPC_TAGS_BIN);
+  }
+}
+
+} // namespace
+
+void CensusServerCallData::OnDoneRecvMessageCb(void* user_data,
+                                               grpc_error_handle error) {
+  grpc_call_element* elem = reinterpret_cast<grpc_call_element*>(user_data);
+  CensusServerCallData* calld =
+      reinterpret_cast<CensusServerCallData*>(elem->call_data);
+  CensusChannelData* channeld =
+      reinterpret_cast<CensusChannelData*>(elem->channel_data);
+  GPR_ASSERT(calld != nullptr);
+  GPR_ASSERT(channeld != nullptr);
+  // Stream messages are no longer valid after receiving trailing metadata.
+  if ((*calld->recv_message_) != nullptr) {
+    ++calld->recv_message_count_;
+  }
+  grpc_core::Closure::Run(DEBUG_LOCATION, calld->initial_on_done_recv_message_,
+                          GRPC_ERROR_REF(error));
+}
+
+void CensusServerCallData::OnDoneRecvInitialMetadataCb(
+    void* user_data, grpc_error_handle error) {
+  grpc_call_element* elem = reinterpret_cast<grpc_call_element*>(user_data);
+  CensusServerCallData* calld =
+      reinterpret_cast<CensusServerCallData*>(elem->call_data);
+  GPR_ASSERT(calld != nullptr);
+  if (error == GRPC_ERROR_NONE) {
+    grpc_metadata_batch* initial_metadata = calld->recv_initial_metadata_;
+    GPR_ASSERT(initial_metadata != nullptr);
+    ServerMetadataElements sml;
+    sml.path = grpc_empty_slice();
+    sml.tracing_slice = grpc_empty_slice();
+    sml.census_proto = grpc_empty_slice();
+    FilterInitialMetadata(initial_metadata, &sml);
+    calld->path_ = grpc_slice_ref_internal(sml.path);
+    calld->method_ = GetMethod(&calld->path_);
+    calld->qualified_method_ = absl::StrCat("Recv.", calld->method_);
+    const char* tracing_str =
+        GRPC_SLICE_IS_EMPTY(sml.tracing_slice)
+            ? ""
+            : reinterpret_cast<const char*>(
+                  GRPC_SLICE_START_PTR(sml.tracing_slice));
+    size_t tracing_str_len = GRPC_SLICE_IS_EMPTY(sml.tracing_slice) ?
0 + : GRPC_SLICE_LENGTH(sml.tracing_slice); + GenerateServerContext(absl::string_view(tracing_str, tracing_str_len), + calld->qualified_method_, &calld->context_); + grpc_slice_unref_internal(sml.tracing_slice); + grpc_slice_unref_internal(sml.census_proto); + grpc_slice_unref_internal(sml.path); + grpc_census_call_set_context( + calld->gc_, reinterpret_cast(&calld->context_)); + } + grpc_core::Closure::Run(DEBUG_LOCATION, + calld->initial_on_done_recv_initial_metadata_, + GRPC_ERROR_REF(error)); +} + +void CensusServerCallData::StartTransportStreamOpBatch( + grpc_call_element* elem, TransportStreamOpBatch* op) { + if (op->recv_initial_metadata() != nullptr) { + // substitute our callback for the op callback + recv_initial_metadata_ = op->recv_initial_metadata()->batch(); + initial_on_done_recv_initial_metadata_ = op->recv_initial_metadata_ready(); + op->set_recv_initial_metadata_ready(&on_done_recv_initial_metadata_); + } + if (op->send_message() != nullptr) { + ++sent_message_count_; + } + if (op->recv_message() != nullptr) { + recv_message_ = op->op()->payload->recv_message.recv_message; + initial_on_done_recv_message_ = + op->op()->payload->recv_message.recv_message_ready; + op->op()->payload->recv_message.recv_message_ready = &on_done_recv_message_; + } + // We need to record the time when the trailing metadata was sent to mark the + // completeness of the request. + if (op->send_trailing_metadata() != nullptr) { + elapsed_time_ = absl::Now() - start_time_; + size_t len = ServerStatsSerialize(absl::ToInt64Nanoseconds(elapsed_time_), + stats_buf_, kMaxServerStatsLen); + if (len > 0) { + GRPC_LOG_IF_ERROR( + "census grpc_filter", + grpc_metadata_batch_add_tail( + op->send_trailing_metadata()->batch(), &census_bin_, + grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_SERVER_STATS_BIN, + grpc_core::UnmanagedMemorySlice(stats_buf_, len)), + GRPC_BATCH_GRPC_SERVER_STATS_BIN)); + } + } + // Call next op. 
+ grpc_call_next_op(elem, op->op()); +} + +grpc_error_handle CensusServerCallData::Init( + grpc_call_element* elem, const grpc_call_element_args* args) { + start_time_ = absl::Now(); + gc_ = + grpc_call_from_top_element(grpc_call_stack_element(args->call_stack, 0)); + GRPC_CLOSURE_INIT(&on_done_recv_initial_metadata_, + OnDoneRecvInitialMetadataCb, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_done_recv_message_, OnDoneRecvMessageCb, elem, + grpc_schedule_on_exec_ctx); + auth_context_ = grpc_call_auth_context(gc_); + return GRPC_ERROR_NONE; +} + +void CensusServerCallData::Destroy(grpc_call_element* /*elem*/, + const grpc_call_final_info* final_info, + grpc_closure* /*then_call_closure*/) { + const uint64_t request_size = GetOutgoingDataSize(final_info); + const uint64_t response_size = GetIncomingDataSize(final_info); + double elapsed_time_ms = absl::ToDoubleMilliseconds(elapsed_time_); + grpc_auth_context_release(auth_context_); + ::opencensus::stats::Record( + {{RpcServerSentBytesPerRpc(), static_cast(response_size)}, + {RpcServerReceivedBytesPerRpc(), static_cast(request_size)}, + {RpcServerServerLatency(), elapsed_time_ms}, + {RpcServerSentMessagesPerRpc(), sent_message_count_}, + {RpcServerReceivedMessagesPerRpc(), recv_message_count_}}, + {{ServerMethodTagKey(), method_}, + {ServerStatusTagKey(), StatusCodeToString(final_info->final_status)}}); + grpc_slice_unref_internal(path_); + context_.EndSpan(); +} + +} // namespace grpc diff --git a/src/cpp/ext/filters/census/views.cc b/src/cpp/ext/filters/census/views.cc new file mode 100644 index 00000000..926e8425 --- /dev/null +++ b/src/cpp/ext/filters/census/views.cc @@ -0,0 +1,641 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "absl/time/time.h" +#include "opencensus/stats/internal/aggregation_window.h" +#include "opencensus/stats/internal/set_aggregation_window.h" +#include "opencensus/stats/stats.h" + +#include "src/cpp/ext/filters/census/grpc_plugin.h" + +namespace grpc { + +using ::opencensus::stats::Aggregation; +using ::opencensus::stats::AggregationWindow; +using ::opencensus::stats::BucketBoundaries; +using ::opencensus::stats::ViewDescriptor; + +// These measure definitions should be kept in sync across opencensus +// implementations. 
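+
+// Editorial sketch (usage illustration only, not part of the original change):
+// an application typically enables the plugin and these views once at startup,
+// e.g.
+//
+//   grpc::RegisterOpenCensusPlugin();
+//   grpc::RegisterOpenCensusViewsForExport();
+//
+// before building any channels or servers, so that the very first RPCs are
+// already measured.  Configuring a stats/trace exporter is assumed to happen
+// separately and is not shown here.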
+ +namespace { + +Aggregation BytesDistributionAggregation() { + return Aggregation::Distribution(BucketBoundaries::Explicit( + {0, 1024, 2048, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, + 67108864, 268435456, 1073741824, 4294967296})); +} + +Aggregation MillisDistributionAggregation() { + return Aggregation::Distribution(BucketBoundaries::Explicit( + {0, 0.01, 0.05, 0.1, 0.3, 0.6, 0.8, 1, 2, 3, 4, + 5, 6, 8, 10, 13, 16, 20, 25, 30, 40, 50, + 65, 80, 100, 130, 160, 200, 250, 300, 400, 500, 650, + 800, 1000, 2000, 5000, 10000, 20000, 50000, 100000})); +} + +Aggregation CountDistributionAggregation() { + return Aggregation::Distribution(BucketBoundaries::Exponential(17, 1.0, 2.0)); +} + +ViewDescriptor MinuteDescriptor() { + auto descriptor = ViewDescriptor(); + SetAggregationWindow(AggregationWindow::Interval(absl::Minutes(1)), + &descriptor); + return descriptor; +} + +ViewDescriptor HourDescriptor() { + auto descriptor = ViewDescriptor(); + SetAggregationWindow(AggregationWindow::Interval(absl::Hours(1)), + &descriptor); + return descriptor; +} + +} // namespace + +void RegisterOpenCensusViewsForExport() { + ClientSentMessagesPerRpcCumulative().RegisterForExport(); + ClientSentBytesPerRpcCumulative().RegisterForExport(); + ClientReceivedMessagesPerRpcCumulative().RegisterForExport(); + ClientReceivedBytesPerRpcCumulative().RegisterForExport(); + ClientRoundtripLatencyCumulative().RegisterForExport(); + ClientServerLatencyCumulative().RegisterForExport(); + + ServerSentMessagesPerRpcCumulative().RegisterForExport(); + ServerSentBytesPerRpcCumulative().RegisterForExport(); + ServerReceivedMessagesPerRpcCumulative().RegisterForExport(); + ServerReceivedBytesPerRpcCumulative().RegisterForExport(); + ServerServerLatencyCumulative().RegisterForExport(); +} + +// client cumulative +const ViewDescriptor& ClientSentBytesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/sent_bytes_per_rpc/cumulative") + .set_measure(kRpcClientSentBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientReceivedBytesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/received_bytes_per_rpc/cumulative") + .set_measure(kRpcClientReceivedBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRoundtripLatencyCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/roundtrip_latency/cumulative") + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientServerLatencyCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/server_latency/cumulative") + .set_measure(kRpcClientServerLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientCompletedRpcsCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/completed_rpcs/cumulative") + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_aggregation(Aggregation::Count()) + .add_column(ClientMethodTagKey()) + 
.add_column(ClientStatusTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientSentMessagesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/received_messages_per_rpc/cumulative") + .set_measure(kRpcClientSentMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientReceivedMessagesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/sent_messages_per_rpc/cumulative") + .set_measure(kRpcClientReceivedMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetriesPerCallCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/retries_per_call/cumulative") + .set_measure(kRpcClientRetriesPerCallMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetriesCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/retries/cumulative") + .set_measure(kRpcClientRetriesPerCallMeasureName) + .set_aggregation(Aggregation::Sum()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientTransparentRetriesPerCallCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/transparent_retries_per_call/cumulative") + .set_measure(kRpcClientTransparentRetriesPerCallMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientTransparentRetriesCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/transparent_retries/cumulative") + .set_measure(kRpcClientTransparentRetriesPerCallMeasureName) + .set_aggregation(Aggregation::Sum()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetryDelayPerCallCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/client/retry_delay_per_call/cumulative") + .set_measure(kRpcClientRetryDelayPerCallMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +// server cumulative +const ViewDescriptor& ServerSentBytesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/server/received_bytes_per_rpc/cumulative") + .set_measure(kRpcServerSentBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerReceivedBytesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/server/sent_bytes_per_rpc/cumulative") + .set_measure(kRpcServerReceivedBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerServerLatencyCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/server/elapsed_time/cumulative") + .set_measure(kRpcServerServerLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + 
.add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerCompletedRpcsCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/server/completed_rpcs/cumulative") + .set_measure(kRpcServerServerLatencyMeasureName) + .set_aggregation(Aggregation::Count()) + .add_column(ServerMethodTagKey()) + .add_column(ServerStatusTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerSentMessagesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/server/received_messages_per_rpc/cumulative") + .set_measure(kRpcServerSentMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerReceivedMessagesPerRpcCumulative() { + const static ViewDescriptor descriptor = + ViewDescriptor() + .set_name("grpc.io/server/sent_messages_per_rpc/cumulative") + .set_measure(kRpcServerReceivedMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +// client minute +const ViewDescriptor& ClientSentBytesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/sent_bytes_per_rpc/minute") + .set_measure(kRpcClientSentBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientReceivedBytesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/received_bytes_per_rpc/minute") + .set_measure(kRpcClientReceivedBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRoundtripLatencyMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/roundtrip_latency/minute") + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientServerLatencyMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/server_latency/minute") + .set_measure(kRpcClientServerLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientCompletedRpcsMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/completed_rpcs/minute") + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_aggregation(Aggregation::Count()) + .add_column(ClientMethodTagKey()) + .add_column(ClientStatusTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientSentMessagesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/sent_messages_per_rpc/minute") + .set_measure(kRpcClientSentMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientReceivedMessagesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/received_messages_per_rpc/minute") + .set_measure(kRpcClientReceivedMessagesPerRpcMeasureName) + 
.set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetriesPerCallMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/retries_per_call/minute") + .set_measure(kRpcClientRetriesPerCallMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetriesMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/retries/minute") + .set_measure(kRpcClientRetriesPerCallMeasureName) + .set_aggregation(Aggregation::Sum()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientTransparentRetriesPerCallMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/transparent_retries_per_call/minute") + .set_measure(kRpcClientTransparentRetriesPerCallMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientTransparentRetriesMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/transparent_retries/minute") + .set_measure(kRpcClientTransparentRetriesPerCallMeasureName) + .set_aggregation(Aggregation::Sum()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetryDelayPerCallMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/client/retry_delay_per_call/minute") + .set_measure(kRpcClientRetryDelayPerCallMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +// server minute +const ViewDescriptor& ServerSentBytesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/server/sent_bytes_per_rpc/minute") + .set_measure(kRpcServerSentBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerReceivedBytesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/server/received_bytes_per_rpc/minute") + .set_measure(kRpcServerReceivedBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerServerLatencyMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/server/server_latency/minute") + .set_measure(kRpcServerServerLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerCompletedRpcsMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/server/completed_rpcs/minute") + .set_measure(kRpcServerServerLatencyMeasureName) + .set_aggregation(Aggregation::Count()) + .add_column(ServerMethodTagKey()) + .add_column(ServerStatusTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerSentMessagesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/server/sent_messages_per_rpc/minute") + .set_measure(kRpcServerSentMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + 
.add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerReceivedMessagesPerRpcMinute() { + const static ViewDescriptor descriptor = + MinuteDescriptor() + .set_name("grpc.io/server/received_messages_per_rpc/minute") + .set_measure(kRpcServerReceivedMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +// client hour +const ViewDescriptor& ClientSentBytesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/sent_bytes_per_rpc/hour") + .set_measure(kRpcClientSentBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientReceivedBytesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/received_bytes_per_rpc/hour") + .set_measure(kRpcClientReceivedBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRoundtripLatencyHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/roundtrip_latency/hour") + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientServerLatencyHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/server_latency/hour") + .set_measure(kRpcClientServerLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientCompletedRpcsHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/completed_rpcs/hour") + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_aggregation(Aggregation::Count()) + .add_column(ClientMethodTagKey()) + .add_column(ClientStatusTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientSentMessagesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/sent_messages_per_rpc/hour") + .set_measure(kRpcClientSentMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientReceivedMessagesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/received_messages_per_rpc/hour") + .set_measure(kRpcClientReceivedMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetriesPerCallHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/retries_per_call/hour") + .set_measure(kRpcClientRetriesPerCallMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetriesHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/retries/hour") + .set_measure(kRpcClientRetriesPerCallMeasureName) + .set_aggregation(Aggregation::Sum()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientTransparentRetriesPerCallHour() { + 
const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/transparent_retries_per_call/hour") + .set_measure(kRpcClientTransparentRetriesPerCallMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientTransparentRetriesHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/transparent_retries/hour") + .set_measure(kRpcClientTransparentRetriesPerCallMeasureName) + .set_aggregation(Aggregation::Sum()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ClientRetryDelayPerCallHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/client/retry_delay_per_call/hour") + .set_measure(kRpcClientRetryDelayPerCallMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ClientMethodTagKey()); + return descriptor; +} + +// server hour +const ViewDescriptor& ServerSentBytesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/server/sent_bytes_per_rpc/hour") + .set_measure(kRpcServerSentBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerReceivedBytesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/server/received_bytes_per_rpc/hour") + .set_measure(kRpcServerReceivedBytesPerRpcMeasureName) + .set_aggregation(BytesDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerServerLatencyHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/server/server_latency/hour") + .set_measure(kRpcServerServerLatencyMeasureName) + .set_aggregation(MillisDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerCompletedRpcsHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/server/completed_rpcs/hour") + .set_measure(kRpcServerServerLatencyMeasureName) + .set_aggregation(Aggregation::Count()) + .add_column(ServerMethodTagKey()) + .add_column(ServerStatusTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerSentMessagesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/server/sent_messages_per_rpc/hour") + .set_measure(kRpcServerSentMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +const ViewDescriptor& ServerReceivedMessagesPerRpcHour() { + const static ViewDescriptor descriptor = + HourDescriptor() + .set_name("grpc.io/server/received_messages_per_rpc/hour") + .set_measure(kRpcServerReceivedMessagesPerRpcMeasureName) + .set_aggregation(CountDistributionAggregation()) + .add_column(ServerMethodTagKey()); + return descriptor; +} + +} // namespace grpc diff --git a/src/cpp/ext/proto_server_reflection.cc b/src/cpp/ext/proto_server_reflection.cc new file mode 100644 index 00000000..af35a285 --- /dev/null +++ b/src/cpp/ext/proto_server_reflection.cc @@ -0,0 +1,211 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/cpp/ext/proto_server_reflection.h" + +#include +#include + +#include + +using grpc::Status; +using grpc::StatusCode; +using grpc::reflection::v1alpha::ErrorResponse; +using grpc::reflection::v1alpha::ExtensionNumberResponse; +using grpc::reflection::v1alpha::ExtensionRequest; +using grpc::reflection::v1alpha::ListServiceResponse; +using grpc::reflection::v1alpha::ServerReflectionRequest; +using grpc::reflection::v1alpha::ServerReflectionResponse; +using grpc::reflection::v1alpha::ServiceResponse; + +namespace grpc { + +ProtoServerReflection::ProtoServerReflection() + : descriptor_pool_(protobuf::DescriptorPool::generated_pool()) {} + +void ProtoServerReflection::SetServiceList( + const std::vector* services) { + services_ = services; +} + +Status ProtoServerReflection::ServerReflectionInfo( + ServerContext* context, + ServerReaderWriter* + stream) { + ServerReflectionRequest request; + ServerReflectionResponse response; + Status status; + while (stream->Read(&request)) { + switch (request.message_request_case()) { + case ServerReflectionRequest::MessageRequestCase::kFileByFilename: + status = GetFileByName(context, request.file_by_filename(), &response); + break; + case ServerReflectionRequest::MessageRequestCase::kFileContainingSymbol: + status = GetFileContainingSymbol( + context, request.file_containing_symbol(), &response); + break; + case ServerReflectionRequest::MessageRequestCase:: + kFileContainingExtension: + status = GetFileContainingExtension( + context, &request.file_containing_extension(), &response); + break; + case ServerReflectionRequest::MessageRequestCase:: + kAllExtensionNumbersOfType: + status = GetAllExtensionNumbers( + context, request.all_extension_numbers_of_type(), + response.mutable_all_extension_numbers_response()); + break; + case ServerReflectionRequest::MessageRequestCase::kListServices: + status = + ListService(context, response.mutable_list_services_response()); + break; + default: + status = Status(StatusCode::UNIMPLEMENTED, ""); + } + + if (!status.ok()) { + FillErrorResponse(status, response.mutable_error_response()); + } + response.set_valid_host(request.host()); + response.set_allocated_original_request( + new ServerReflectionRequest(request)); + stream->Write(response); + } + + return Status::OK; +} + +void ProtoServerReflection::FillErrorResponse(const Status& status, + ErrorResponse* error_response) { + error_response->set_error_code(status.error_code()); + error_response->set_error_message(status.error_message()); +} + +Status ProtoServerReflection::ListService(ServerContext* /*context*/, + ListServiceResponse* response) { + if (services_ == nullptr) { + return Status(StatusCode::NOT_FOUND, "Services not found."); + } + for (const auto& value : *services_) { + ServiceResponse* service_response = response->add_service(); + service_response->set_name(value); + } + return Status::OK; +} + +Status ProtoServerReflection::GetFileByName( + ServerContext* /*context*/, const std::string& file_name, + ServerReflectionResponse* response) { + if (descriptor_pool_ == nullptr) { + return Status::CANCELLED; + } + + const 
protobuf::FileDescriptor* file_desc = + descriptor_pool_->FindFileByName(file_name); + if (file_desc == nullptr) { + return Status(StatusCode::NOT_FOUND, "File not found."); + } + std::unordered_set seen_files; + FillFileDescriptorResponse(file_desc, response, &seen_files); + return Status::OK; +} + +Status ProtoServerReflection::GetFileContainingSymbol( + ServerContext* /*context*/, const std::string& symbol, + ServerReflectionResponse* response) { + if (descriptor_pool_ == nullptr) { + return Status::CANCELLED; + } + + const protobuf::FileDescriptor* file_desc = + descriptor_pool_->FindFileContainingSymbol(symbol); + if (file_desc == nullptr) { + return Status(StatusCode::NOT_FOUND, "Symbol not found."); + } + std::unordered_set seen_files; + FillFileDescriptorResponse(file_desc, response, &seen_files); + return Status::OK; +} + +Status ProtoServerReflection::GetFileContainingExtension( + ServerContext* /*context*/, const ExtensionRequest* request, + ServerReflectionResponse* response) { + if (descriptor_pool_ == nullptr) { + return Status::CANCELLED; + } + + const protobuf::Descriptor* desc = + descriptor_pool_->FindMessageTypeByName(request->containing_type()); + if (desc == nullptr) { + return Status(StatusCode::NOT_FOUND, "Type not found."); + } + + const protobuf::FieldDescriptor* field_desc = + descriptor_pool_->FindExtensionByNumber(desc, + request->extension_number()); + if (field_desc == nullptr) { + return Status(StatusCode::NOT_FOUND, "Extension not found."); + } + std::unordered_set seen_files; + FillFileDescriptorResponse(field_desc->file(), response, &seen_files); + return Status::OK; +} + +Status ProtoServerReflection::GetAllExtensionNumbers( + ServerContext* /*context*/, const std::string& type, + ExtensionNumberResponse* response) { + if (descriptor_pool_ == nullptr) { + return Status::CANCELLED; + } + + const protobuf::Descriptor* desc = + descriptor_pool_->FindMessageTypeByName(type); + if (desc == nullptr) { + return Status(StatusCode::NOT_FOUND, "Type not found."); + } + + std::vector extensions; + descriptor_pool_->FindAllExtensions(desc, &extensions); + for (const auto& value : extensions) { + response->add_extension_number(value->number()); + } + response->set_base_type_name(type); + return Status::OK; +} + +void ProtoServerReflection::FillFileDescriptorResponse( + const protobuf::FileDescriptor* file_desc, + ServerReflectionResponse* response, + std::unordered_set* seen_files) { + if (seen_files->find(file_desc->name()) != seen_files->end()) { + return; + } + seen_files->insert(file_desc->name()); + + protobuf::FileDescriptorProto file_desc_proto; + std::string data; + file_desc->CopyTo(&file_desc_proto); + file_desc_proto.SerializeToString(&data); + response->mutable_file_descriptor_response()->add_file_descriptor_proto(data); + + for (int i = 0; i < file_desc->dependency_count(); ++i) { + FillFileDescriptorResponse(file_desc->dependency(i), response, seen_files); + } +} + +} // namespace grpc diff --git a/src/cpp/ext/proto_server_reflection_plugin.cc b/src/cpp/ext/proto_server_reflection_plugin.cc new file mode 100644 index 00000000..0a892d84 --- /dev/null +++ b/src/cpp/ext/proto_server_reflection_plugin.cc @@ -0,0 +1,83 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include "src/cpp/ext/proto_server_reflection.h" + +namespace grpc { +namespace reflection { + +ProtoServerReflectionPlugin::ProtoServerReflectionPlugin() + : reflection_service_(new grpc::ProtoServerReflection()) {} + +std::string ProtoServerReflectionPlugin::name() { + return "proto_server_reflection"; +} + +void ProtoServerReflectionPlugin::InitServer(grpc::ServerInitializer* si) { + si->RegisterService(reflection_service_); +} + +void ProtoServerReflectionPlugin::Finish(grpc::ServerInitializer* si) { + reflection_service_->SetServiceList(si->GetServiceList()); +} + +void ProtoServerReflectionPlugin::ChangeArguments(const std::string& /*name*/, + void* /*value*/) {} + +bool ProtoServerReflectionPlugin::has_sync_methods() const { + if (reflection_service_) { + return reflection_service_->has_synchronous_methods(); + } + return false; +} + +bool ProtoServerReflectionPlugin::has_async_methods() const { + if (reflection_service_) { + return reflection_service_->has_async_methods(); + } + return false; +} + +static std::unique_ptr< ::grpc::ServerBuilderPlugin> CreateProtoReflection() { + return std::unique_ptr< ::grpc::ServerBuilderPlugin>( + new ProtoServerReflectionPlugin()); +} + +void InitProtoReflectionServerBuilderPlugin() { + static struct Initialize { + Initialize() { + ::grpc::ServerBuilder::InternalAddPluginFactory(&CreateProtoReflection); + } + } initializer; +} + +// Force InitProtoReflectionServerBuilderPlugin() to be called at static +// initialization time. +struct StaticProtoReflectionPluginInitializer { + StaticProtoReflectionPluginInitializer() { + InitProtoReflectionServerBuilderPlugin(); + } +} static_proto_reflection_plugin_initializer; + +} // namespace reflection +} // namespace grpc diff --git a/src/cpp/server/admin/admin_services.cc b/src/cpp/server/admin/admin_services.cc new file mode 100644 index 00000000..7e8841e7 --- /dev/null +++ b/src/cpp/server/admin/admin_services.cc @@ -0,0 +1,51 @@ +// +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include + +// TODO(lidiz) build a real registration system that can pull in services +// automatically with minimum amount of code. 
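+
+// Editorial sketch (usage illustration only, not part of the original change):
+// the expected call site for the AddAdminServices() helper defined below is
+// roughly
+//
+//   grpc::ServerBuilder builder;
+//   grpc::AddAdminServices(&builder);
+//   builder.AddListeningPort("0.0.0.0:50051",
+//                            grpc::InsecureServerCredentials());
+//   auto server = builder.BuildAndStart();
+//
+// (address and credentials are illustrative only); this registers the channelz
+// service, plus CSDS when xDS support is compiled in.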
+#include "src/cpp/server/channelz/channelz_service.h" +#if !defined(GRPC_NO_XDS) && !defined(DISABLED_XDS_PROTO_IN_CC) +#include "src/cpp/server/csds/csds.h" +#endif // GRPC_NO_XDS or DISABLED_XDS_PROTO_IN_CC +namespace grpc { + +namespace { + +static auto* g_channelz_service = new ChannelzService(); +#if !defined(GRPC_NO_XDS) && !defined(DISABLED_XDS_PROTO_IN_CC) +static auto* g_csds = new xds::experimental::ClientStatusDiscoveryService(); +#endif // GRPC_NO_XDS or DISABLED_XDS_PROTO_IN_CC + +} // namespace + +void AddAdminServices(ServerBuilder* builder) { + builder->RegisterService(g_channelz_service); +#if !defined(GRPC_NO_XDS) && !defined(DISABLED_XDS_PROTO_IN_CC) + builder->RegisterService(g_csds); +#endif // GRPC_NO_XDS or DISABLED_XDS_PROTO_IN_CC +} + +} // namespace grpc diff --git a/src/cpp/server/async_generic_service.cc b/src/cpp/server/async_generic_service.cc new file mode 100644 index 00000000..fdb3da83 --- /dev/null +++ b/src/cpp/server/async_generic_service.cc @@ -0,0 +1,32 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +namespace grpc { + +void AsyncGenericService::RequestCall( + GenericServerContext* ctx, GenericServerAsyncReaderWriter* reader_writer, + ::grpc::CompletionQueue* call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + server_->RequestAsyncGenericCall(ctx, reader_writer, call_cq, notification_cq, + tag); +} + +} // namespace grpc diff --git a/src/cpp/server/authorization_policy_provider.cc b/src/cpp/server/authorization_policy_provider.cc new file mode 100644 index 00000000..8dab33dd --- /dev/null +++ b/src/cpp/server/authorization_policy_provider.cc @@ -0,0 +1,69 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include
+#include
+#include
+#include
+
+namespace grpc {
+namespace experimental {
+
+std::shared_ptr<StaticDataAuthorizationPolicyProvider>
+StaticDataAuthorizationPolicyProvider::Create(const std::string& authz_policy,
+                                              grpc::Status* status) {
+  grpc_status_code code = GRPC_STATUS_OK;
+  const char* error_details;
+  grpc_authorization_policy_provider* provider =
+      grpc_authorization_policy_provider_static_data_create(
+          authz_policy.c_str(), &code, &error_details);
+  if (code != GRPC_STATUS_OK) {
+    *status = grpc::Status(static_cast<grpc::StatusCode>(code), error_details);
+    gpr_free(const_cast<char*>(error_details));
+    return nullptr;
+  }
+  *status = grpc::Status();
+  return std::make_shared<StaticDataAuthorizationPolicyProvider>(provider);
+}
+
+StaticDataAuthorizationPolicyProvider::
+    ~StaticDataAuthorizationPolicyProvider() {
+  grpc_authorization_policy_provider_release(c_provider_);
+}
+
+std::shared_ptr<FileWatcherAuthorizationPolicyProvider>
+FileWatcherAuthorizationPolicyProvider::Create(
+    const std::string& authz_policy_path, unsigned int refresh_interval_sec,
+    grpc::Status* status) {
+  grpc_status_code code = GRPC_STATUS_OK;
+  const char* error_details;
+  grpc_authorization_policy_provider* provider =
+      grpc_authorization_policy_provider_file_watcher_create(
+          authz_policy_path.c_str(), refresh_interval_sec, &code,
+          &error_details);
+  if (code != GRPC_STATUS_OK) {
+    *status = grpc::Status(static_cast<grpc::StatusCode>(code), error_details);
+    gpr_free(const_cast<char*>(error_details));
+    return nullptr;
+  }
+  return std::make_shared<FileWatcherAuthorizationPolicyProvider>(provider);
+}
+
+FileWatcherAuthorizationPolicyProvider::
+    ~FileWatcherAuthorizationPolicyProvider() {
+  grpc_authorization_policy_provider_release(c_provider_);
+}
+
+} // namespace experimental
+} // namespace grpc
diff --git a/src/cpp/server/channel_argument_option.cc b/src/cpp/server/channel_argument_option.cc
new file mode 100644
index 00000000..680b0ee0
--- /dev/null
+++ b/src/cpp/server/channel_argument_option.cc
@@ -0,0 +1,65 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +namespace grpc { + +std::unique_ptr MakeChannelArgumentOption( + const std::string& name, const std::string& value) { + class StringOption final : public ServerBuilderOption { + public: + StringOption(const std::string& name, const std::string& value) + : name_(name), value_(value) {} + + void UpdateArguments(ChannelArguments* args) override { + args->SetString(name_, value_); + } + void UpdatePlugins( + std::vector>* /*plugins*/) + override {} + + private: + const std::string name_; + const std::string value_; + }; + return std::unique_ptr(new StringOption(name, value)); +} + +std::unique_ptr MakeChannelArgumentOption( + const std::string& name, int value) { + class IntOption final : public ServerBuilderOption { + public: + IntOption(const std::string& name, int value) + : name_(name), value_(value) {} + + void UpdateArguments(ChannelArguments* args) override { + args->SetInt(name_, value_); + } + void UpdatePlugins( + std::vector>* /*plugins*/) + override {} + + private: + const std::string name_; + const int value_; + }; + return std::unique_ptr(new IntOption(name, value)); +} + +} // namespace grpc diff --git a/src/cpp/server/channelz/channelz_service.cc b/src/cpp/server/channelz/channelz_service.cc new file mode 100644 index 00000000..6dcf84bf --- /dev/null +++ b/src/cpp/server/channelz/channelz_service.cc @@ -0,0 +1,153 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
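Annotation (not part of the patch): MakeChannelArgumentOption() returns a ServerBuilderOption, so the usual consumer is ServerBuilder::SetOption(). A brief sketch; the header path and the channel argument names are chosen for illustration:

    // Sketch: attach a string and an integer channel argument to a server.
    #include <grpcpp/grpcpp.h>
    #include <grpcpp/impl/channel_argument_option.h>  // assumed header path

    void ConfigureBuilder(grpc::ServerBuilder& builder) {
      builder.SetOption(grpc::MakeChannelArgumentOption("grpc.primary_user_agent",
                                                        "my-server/1.0"));
      builder.SetOption(
          grpc::MakeChannelArgumentOption("grpc.max_concurrent_streams", 100));
    }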
+ * + */ + +#include + +#include "src/cpp/server/channelz/channelz_service.h" + +#include +#include + +namespace grpc { + +namespace { + +grpc::protobuf::util::Status ParseJson(const char* json_str, + grpc::protobuf::Message* message) { + grpc::protobuf::json::JsonParseOptions options; + options.case_insensitive_enum_parsing = true; + return grpc::protobuf::json::JsonStringToMessage(json_str, message, options); +} + +} // namespace + +Status ChannelzService::GetTopChannels( + ServerContext* /*unused*/, + const channelz::v1::GetTopChannelsRequest* request, + channelz::v1::GetTopChannelsResponse* response) { + char* json_str = grpc_channelz_get_top_channels(request->start_channel_id()); + if (json_str == nullptr) { + return Status(StatusCode::INTERNAL, + "grpc_channelz_get_top_channels returned null"); + } + grpc::protobuf::util::Status s = ParseJson(json_str, response); + gpr_free(json_str); + if (!s.ok()) { + return Status(StatusCode::INTERNAL, s.ToString()); + } + return Status::OK; +} + +Status ChannelzService::GetServers( + ServerContext* /*unused*/, const channelz::v1::GetServersRequest* request, + channelz::v1::GetServersResponse* response) { + char* json_str = grpc_channelz_get_servers(request->start_server_id()); + if (json_str == nullptr) { + return Status(StatusCode::INTERNAL, + "grpc_channelz_get_servers returned null"); + } + grpc::protobuf::util::Status s = ParseJson(json_str, response); + gpr_free(json_str); + if (!s.ok()) { + return Status(StatusCode::INTERNAL, s.ToString()); + } + return Status::OK; +} + +Status ChannelzService::GetServer(ServerContext* /*unused*/, + const channelz::v1::GetServerRequest* request, + channelz::v1::GetServerResponse* response) { + char* json_str = grpc_channelz_get_server(request->server_id()); + if (json_str == nullptr) { + return Status(StatusCode::INTERNAL, + "grpc_channelz_get_server returned null"); + } + grpc::protobuf::util::Status s = ParseJson(json_str, response); + gpr_free(json_str); + if (!s.ok()) { + return Status(StatusCode::INTERNAL, s.ToString()); + } + return Status::OK; +} + +Status ChannelzService::GetServerSockets( + ServerContext* /*unused*/, + const channelz::v1::GetServerSocketsRequest* request, + channelz::v1::GetServerSocketsResponse* response) { + char* json_str = grpc_channelz_get_server_sockets( + request->server_id(), request->start_socket_id(), request->max_results()); + if (json_str == nullptr) { + return Status(StatusCode::INTERNAL, + "grpc_channelz_get_server_sockets returned null"); + } + grpc::protobuf::util::Status s = ParseJson(json_str, response); + gpr_free(json_str); + if (!s.ok()) { + return Status(StatusCode::INTERNAL, s.ToString()); + } + return Status::OK; +} + +Status ChannelzService::GetChannel( + ServerContext* /*unused*/, const channelz::v1::GetChannelRequest* request, + channelz::v1::GetChannelResponse* response) { + char* json_str = grpc_channelz_get_channel(request->channel_id()); + if (json_str == nullptr) { + return Status(StatusCode::NOT_FOUND, "No object found for that ChannelId"); + } + grpc::protobuf::util::Status s = ParseJson(json_str, response); + gpr_free(json_str); + if (!s.ok()) { + return Status(StatusCode::INTERNAL, s.ToString()); + } + return Status::OK; +} + +Status ChannelzService::GetSubchannel( + ServerContext* /*unused*/, + const channelz::v1::GetSubchannelRequest* request, + channelz::v1::GetSubchannelResponse* response) { + char* json_str = grpc_channelz_get_subchannel(request->subchannel_id()); + if (json_str == nullptr) { + return Status(StatusCode::NOT_FOUND, + "No 
object found for that SubchannelId"); + } + grpc::protobuf::util::Status s = ParseJson(json_str, response); + gpr_free(json_str); + if (!s.ok()) { + return Status(StatusCode::INTERNAL, s.ToString()); + } + return Status::OK; +} + +Status ChannelzService::GetSocket(ServerContext* /*unused*/, + const channelz::v1::GetSocketRequest* request, + channelz::v1::GetSocketResponse* response) { + char* json_str = grpc_channelz_get_socket(request->socket_id()); + if (json_str == nullptr) { + return Status(StatusCode::NOT_FOUND, "No object found for that SocketId"); + } + grpc::protobuf::util::Status s = ParseJson(json_str, response); + gpr_free(json_str); + if (!s.ok()) { + return Status(StatusCode::INTERNAL, s.ToString()); + } + return Status::OK; +} + +} // namespace grpc diff --git a/src/cpp/server/channelz/channelz_service_plugin.cc b/src/cpp/server/channelz/channelz_service_plugin.cc new file mode 100644 index 00000000..e2e9495e --- /dev/null +++ b/src/cpp/server/channelz/channelz_service_plugin.cc @@ -0,0 +1,81 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include + +#include "src/cpp/server/channelz/channelz_service.h" + +namespace grpc { +namespace channelz { +namespace experimental { + +class ChannelzServicePlugin : public ::grpc::ServerBuilderPlugin { + public: + ChannelzServicePlugin() : channelz_service_(new grpc::ChannelzService()) {} + + std::string name() override { return "channelz_service"; } + + void InitServer(grpc::ServerInitializer* si) override { + si->RegisterService(channelz_service_); + } + + void Finish(grpc::ServerInitializer* /*si*/) override {} + + void ChangeArguments(const std::string& /*name*/, void* /*value*/) override {} + + bool has_sync_methods() const override { + if (channelz_service_) { + return channelz_service_->has_synchronous_methods(); + } + return false; + } + + bool has_async_methods() const override { + if (channelz_service_) { + return channelz_service_->has_async_methods(); + } + return false; + } + + private: + std::shared_ptr channelz_service_; +}; + +static std::unique_ptr< ::grpc::ServerBuilderPlugin> +CreateChannelzServicePlugin() { + return std::unique_ptr< ::grpc::ServerBuilderPlugin>( + new ChannelzServicePlugin()); +} + +void InitChannelzService() { + static struct Initializer { + Initializer() { + ::grpc::ServerBuilder::InternalAddPluginFactory( + &grpc::channelz::experimental::CreateChannelzServicePlugin); + } + } initialize; +} + +} // namespace experimental +} // namespace channelz +} // namespace grpc diff --git a/src/cpp/server/create_default_thread_pool.cc b/src/cpp/server/create_default_thread_pool.cc new file mode 100644 index 00000000..8ca3e32c --- /dev/null +++ b/src/cpp/server/create_default_thread_pool.cc @@ -0,0 +1,44 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/server/dynamic_thread_pool.h" + +#ifndef GRPC_CUSTOM_DEFAULT_THREAD_POOL + +namespace grpc { +namespace { + +ThreadPoolInterface* CreateDefaultThreadPoolImpl() { + int cores = gpr_cpu_num_cores(); + if (!cores) cores = 4; + return new DynamicThreadPool(cores); +} + +CreateThreadPoolFunc g_ctp_impl = CreateDefaultThreadPoolImpl; + +} // namespace + +ThreadPoolInterface* CreateDefaultThreadPool() { return g_ctp_impl(); } + +void SetCreateThreadPool(CreateThreadPoolFunc func) { g_ctp_impl = func; } + +} // namespace grpc + +#endif // !GRPC_CUSTOM_DEFAULT_THREAD_POOL diff --git a/src/cpp/server/csds/csds.cc b/src/cpp/server/csds/csds.cc new file mode 100644 index 00000000..61de3760 --- /dev/null +++ b/src/cpp/server/csds/csds.cc @@ -0,0 +1,94 @@ +// +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include "src/cpp/server/csds/csds.h" + +#include + +#include "absl/status/statusor.h" + +#include +#include +#include + +#include "src/proto/grpc/testing/xds/v3/csds.grpc.pb.h" + +namespace grpc { +namespace xds { +namespace experimental { + +using envoy::service::status::v3::ClientConfig; +using envoy::service::status::v3::ClientStatusRequest; +using envoy::service::status::v3::ClientStatusResponse; + +namespace { + +absl::StatusOr DumpClientConfig() { + ClientConfig client_config; + grpc_slice serialized_client_config = grpc_dump_xds_configs(); + std::string bytes = StringFromCopiedSlice(serialized_client_config); + grpc_slice_unref(serialized_client_config); + if (!client_config.ParseFromString(bytes)) { + return absl::InternalError("Failed to parse ClientConfig."); + } + return client_config; +} + +} // namespace + +Status ClientStatusDiscoveryService::StreamClientStatus( + ServerContext* /*context*/, + ServerReaderWriter* stream) { + ClientStatusRequest request; + while (stream->Read(&request)) { + ClientStatusResponse response; + absl::StatusOr s = DumpClientConfig(); + if (!s.ok()) { + if (s.status().code() == absl::StatusCode::kUnavailable) { + // If the xDS client is not initialized, return empty response + stream->Write(response); + continue; + } + return Status(StatusCode(s.status().raw_code()), s.status().ToString()); + } + *response.add_config() = std::move(s.value()); + stream->Write(response); + } + return Status::OK; +} + +Status ClientStatusDiscoveryService::FetchClientStatus( + ServerContext* /*context*/, const ClientStatusRequest* /*request*/, + ClientStatusResponse* response) { + absl::StatusOr s = DumpClientConfig(); + if (!s.ok()) { + if (s.status().code() == absl::StatusCode::kUnavailable) { + // If the xDS client is not initialized, return empty response + return Status::OK; + } + return Status(StatusCode(s.status().raw_code()), s.status().ToString()); + } + *response->add_config() = std::move(s.value()); + return Status::OK; +} + +} // namespace experimental +} // namespace xds +} // namespace grpc diff --git a/src/cpp/server/dynamic_thread_pool.cc b/src/cpp/server/dynamic_thread_pool.cc new file mode 100644 index 00000000..e96dc4c4 --- /dev/null +++ b/src/cpp/server/dynamic_thread_pool.cc @@ -0,0 +1,125 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/cpp/server/dynamic_thread_pool.h" + +#include +#include + +#include "src/core/lib/gprpp/thd.h" + +namespace grpc { + +DynamicThreadPool::DynamicThread::DynamicThread(DynamicThreadPool* pool) + : pool_(pool), + thd_( + "grpcpp_dynamic_pool", + [](void* th) { + static_cast(th)->ThreadFunc(); + }, + this) { + thd_.Start(); +} +DynamicThreadPool::DynamicThread::~DynamicThread() { thd_.Join(); } + +void DynamicThreadPool::DynamicThread::ThreadFunc() { + pool_->ThreadFunc(); + // Now that we have killed ourselves, we should reduce the thread count + grpc_core::MutexLock lock(&pool_->mu_); + pool_->nthreads_--; + // Move ourselves to dead list + pool_->dead_threads_.push_back(this); + + if ((pool_->shutdown_) && (pool_->nthreads_ == 0)) { + pool_->shutdown_cv_.Signal(); + } +} + +void DynamicThreadPool::ThreadFunc() { + for (;;) { + // Wait until work is available or we are shutting down. + grpc_core::ReleasableMutexLock lock(&mu_); + if (!shutdown_ && callbacks_.empty()) { + // If there are too many threads waiting, then quit this thread + if (threads_waiting_ >= reserve_threads_) { + break; + } + threads_waiting_++; + cv_.Wait(&mu_); + threads_waiting_--; + } + // Drain callbacks before considering shutdown to ensure all work + // gets completed. + if (!callbacks_.empty()) { + auto cb = callbacks_.front(); + callbacks_.pop(); + lock.Release(); + cb(); + } else if (shutdown_) { + break; + } + } +} + +DynamicThreadPool::DynamicThreadPool(int reserve_threads) + : shutdown_(false), + reserve_threads_(reserve_threads), + nthreads_(0), + threads_waiting_(0) { + for (int i = 0; i < reserve_threads_; i++) { + grpc_core::MutexLock lock(&mu_); + nthreads_++; + new DynamicThread(this); + } +} + +void DynamicThreadPool::ReapThreads(std::list* tlist) { + for (auto t = tlist->begin(); t != tlist->end(); t = tlist->erase(t)) { + delete *t; + } +} + +DynamicThreadPool::~DynamicThreadPool() { + grpc_core::MutexLock lock(&mu_); + shutdown_ = true; + cv_.SignalAll(); + while (nthreads_ != 0) { + shutdown_cv_.Wait(&mu_); + } + ReapThreads(&dead_threads_); +} + +void DynamicThreadPool::Add(const std::function& callback) { + grpc_core::MutexLock lock(&mu_); + // Add works to the callbacks list + callbacks_.push(callback); + // Increase pool size or notify as needed + if (threads_waiting_ == 0) { + // Kick off a new thread + nthreads_++; + new DynamicThread(this); + } else { + cv_.Signal(); + } + // Also use this chance to harvest dead threads + if (!dead_threads_.empty()) { + ReapThreads(&dead_threads_); + } +} + +} // namespace grpc diff --git a/src/cpp/server/external_connection_acceptor_impl.cc b/src/cpp/server/external_connection_acceptor_impl.cc new file mode 100644 index 00000000..1b0c68ff --- /dev/null +++ b/src/cpp/server/external_connection_acceptor_impl.cc @@ -0,0 +1,96 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/cpp/server/external_connection_acceptor_impl.h" + +#include + +#include +#include + +namespace grpc { +namespace internal { +namespace { +// The actual type to return to user. It co-owns the internal impl object with +// the server. +class AcceptorWrapper : public experimental::ExternalConnectionAcceptor { + public: + explicit AcceptorWrapper(std::shared_ptr impl) + : impl_(std::move(impl)) {} + void HandleNewConnection(NewConnectionParameters* p) override { + impl_->HandleNewConnection(p); + } + + private: + std::shared_ptr impl_; +}; +} // namespace + +ExternalConnectionAcceptorImpl::ExternalConnectionAcceptorImpl( + const std::string& name, + ServerBuilder::experimental_type::ExternalConnectionType type, + std::shared_ptr creds) + : name_(name), creds_(std::move(creds)) { + GPR_ASSERT(type == + ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD); +} + +std::unique_ptr +ExternalConnectionAcceptorImpl::GetAcceptor() { + grpc_core::MutexLock lock(&mu_); + GPR_ASSERT(!has_acceptor_); + has_acceptor_ = true; + return std::unique_ptr( + new AcceptorWrapper(shared_from_this())); +} + +void ExternalConnectionAcceptorImpl::HandleNewConnection( + experimental::ExternalConnectionAcceptor::NewConnectionParameters* p) { + grpc_core::MutexLock lock(&mu_); + if (shutdown_ || !started_) { + // TODO(yangg) clean up. + gpr_log( + GPR_ERROR, + "NOT handling external connection with fd %d, started %d, shutdown %d", + p->fd, started_, shutdown_); + return; + } + if (handler_) { + handler_->Handle(p->listener_fd, p->fd, p->read_buffer.c_buffer()); + } +} + +void ExternalConnectionAcceptorImpl::Shutdown() { + grpc_core::MutexLock lock(&mu_); + shutdown_ = true; +} + +void ExternalConnectionAcceptorImpl::Start() { + grpc_core::MutexLock lock(&mu_); + GPR_ASSERT(!started_); + GPR_ASSERT(has_acceptor_); + GPR_ASSERT(!shutdown_); + started_ = true; +} + +void ExternalConnectionAcceptorImpl::SetToChannelArgs(ChannelArguments* args) { + args->SetPointer(name_.c_str(), &handler_); +} + +} // namespace internal +} // namespace grpc diff --git a/src/cpp/server/health/default_health_check_service.cc b/src/cpp/server/health/default_health_check_service.cc new file mode 100644 index 00000000..dd638e6e --- /dev/null +++ b/src/cpp/server/health/default_health_check_service.cc @@ -0,0 +1,504 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/cpp/server/health/default_health_check_service.h" + +#include + +#include "absl/memory/memory.h" +#include "upb/upb.hpp" + +#include +#include +#include +#include + +#include "src/proto/grpc/health/v1/health.upb.h" + +#define MAX_SERVICE_NAME_LENGTH 200 + +namespace grpc { + +// +// DefaultHealthCheckService +// + +DefaultHealthCheckService::DefaultHealthCheckService() { + services_map_[""].SetServingStatus(SERVING); +} + +void DefaultHealthCheckService::SetServingStatus( + const std::string& service_name, bool serving) { + grpc_core::MutexLock lock(&mu_); + if (shutdown_) { + // Set to NOT_SERVING in case service_name is not in the map. + serving = false; + } + services_map_[service_name].SetServingStatus(serving ? SERVING : NOT_SERVING); +} + +void DefaultHealthCheckService::SetServingStatus(bool serving) { + const ServingStatus status = serving ? SERVING : NOT_SERVING; + grpc_core::MutexLock lock(&mu_); + if (shutdown_) { + return; + } + for (auto& p : services_map_) { + ServiceData& service_data = p.second; + service_data.SetServingStatus(status); + } +} + +void DefaultHealthCheckService::Shutdown() { + grpc_core::MutexLock lock(&mu_); + if (shutdown_) { + return; + } + shutdown_ = true; + for (auto& p : services_map_) { + ServiceData& service_data = p.second; + service_data.SetServingStatus(NOT_SERVING); + } +} + +DefaultHealthCheckService::ServingStatus +DefaultHealthCheckService::GetServingStatus( + const std::string& service_name) const { + grpc_core::MutexLock lock(&mu_); + auto it = services_map_.find(service_name); + if (it == services_map_.end()) { + return NOT_FOUND; + } + const ServiceData& service_data = it->second; + return service_data.GetServingStatus(); +} + +void DefaultHealthCheckService::RegisterCallHandler( + const std::string& service_name, + std::shared_ptr handler) { + grpc_core::MutexLock lock(&mu_); + ServiceData& service_data = services_map_[service_name]; + service_data.AddCallHandler(handler /* copies ref */); + HealthCheckServiceImpl::CallHandler* h = handler.get(); + h->SendHealth(std::move(handler), service_data.GetServingStatus()); +} + +void DefaultHealthCheckService::UnregisterCallHandler( + const std::string& service_name, + const std::shared_ptr& handler) { + grpc_core::MutexLock lock(&mu_); + auto it = services_map_.find(service_name); + if (it == services_map_.end()) return; + ServiceData& service_data = it->second; + service_data.RemoveCallHandler(handler); + if (service_data.Unused()) { + services_map_.erase(it); + } +} + +DefaultHealthCheckService::HealthCheckServiceImpl* +DefaultHealthCheckService::GetHealthCheckService( + std::unique_ptr cq) { + GPR_ASSERT(impl_ == nullptr); + impl_ = absl::make_unique(this, std::move(cq)); + return impl_.get(); +} + +// +// DefaultHealthCheckService::ServiceData +// + +void DefaultHealthCheckService::ServiceData::SetServingStatus( + ServingStatus status) { + status_ = status; + for (auto& call_handler : call_handlers_) { + call_handler->SendHealth(call_handler /* copies ref */, status); + } +} + +void DefaultHealthCheckService::ServiceData::AddCallHandler( + std::shared_ptr handler) { + call_handlers_.insert(std::move(handler)); +} + +void DefaultHealthCheckService::ServiceData::RemoveCallHandler( + const std::shared_ptr& handler) { + call_handlers_.erase(handler); +} + +// +// DefaultHealthCheckService::HealthCheckServiceImpl +// + +namespace { +const char kHealthCheckMethodName[] = "/grpc.health.v1.Health/Check"; +const char kHealthWatchMethodName[] = "/grpc.health.v1.Health/Watch"; 
+} // namespace + +DefaultHealthCheckService::HealthCheckServiceImpl::HealthCheckServiceImpl( + DefaultHealthCheckService* database, + std::unique_ptr cq) + : database_(database), cq_(std::move(cq)) { + // Add Check() method. + AddMethod(new internal::RpcServiceMethod( + kHealthCheckMethodName, internal::RpcMethod::NORMAL_RPC, nullptr)); + // Add Watch() method. + AddMethod(new internal::RpcServiceMethod( + kHealthWatchMethodName, internal::RpcMethod::SERVER_STREAMING, nullptr)); + // Create serving thread. + thread_ = absl::make_unique<::grpc_core::Thread>("grpc_health_check_service", + Serve, this); +} + +DefaultHealthCheckService::HealthCheckServiceImpl::~HealthCheckServiceImpl() { + // We will reach here after the server starts shutting down. + shutdown_ = true; + { + grpc_core::MutexLock lock(&cq_shutdown_mu_); + cq_->Shutdown(); + } + thread_->Join(); +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::StartServingThread() { + // Request the calls we're interested in. + // We do this before starting the serving thread, so that we know it's + // done before server startup is complete. + CheckCallHandler::CreateAndStart(cq_.get(), database_, this); + WatchCallHandler::CreateAndStart(cq_.get(), database_, this); + // Start serving thread. + thread_->Start(); +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::Serve(void* arg) { + HealthCheckServiceImpl* service = static_cast(arg); + void* tag; + bool ok; + while (true) { + if (!service->cq_->Next(&tag, &ok)) { + // The completion queue is shutting down. + GPR_ASSERT(service->shutdown_); + break; + } + auto* next_step = static_cast(tag); + next_step->Run(ok); + } +} + +bool DefaultHealthCheckService::HealthCheckServiceImpl::DecodeRequest( + const ByteBuffer& request, std::string* service_name) { + std::vector slices; + if (!request.Dump(&slices).ok()) return false; + uint8_t* request_bytes = nullptr; + size_t request_size = 0; + if (slices.size() == 1) { + request_bytes = const_cast(slices[0].begin()); + request_size = slices[0].size(); + } else if (slices.size() > 1) { + request_bytes = static_cast(gpr_malloc(request.Length())); + uint8_t* copy_to = request_bytes; + for (size_t i = 0; i < slices.size(); i++) { + memcpy(copy_to, slices[i].begin(), slices[i].size()); + copy_to += slices[i].size(); + } + } + upb::Arena arena; + grpc_health_v1_HealthCheckRequest* request_struct = + grpc_health_v1_HealthCheckRequest_parse( + reinterpret_cast(request_bytes), request_size, arena.ptr()); + if (slices.size() > 1) { + gpr_free(request_bytes); + } + if (request_struct == nullptr) { + return false; + } + upb_strview service = + grpc_health_v1_HealthCheckRequest_service(request_struct); + if (service.size > MAX_SERVICE_NAME_LENGTH) { + return false; + } + service_name->assign(service.data, service.size); + return true; +} + +bool DefaultHealthCheckService::HealthCheckServiceImpl::EncodeResponse( + ServingStatus status, ByteBuffer* response) { + upb::Arena arena; + grpc_health_v1_HealthCheckResponse* response_struct = + grpc_health_v1_HealthCheckResponse_new(arena.ptr()); + grpc_health_v1_HealthCheckResponse_set_status( + response_struct, + status == NOT_FOUND ? grpc_health_v1_HealthCheckResponse_SERVICE_UNKNOWN + : status == SERVING ? 
grpc_health_v1_HealthCheckResponse_SERVING + : grpc_health_v1_HealthCheckResponse_NOT_SERVING); + size_t buf_length; + char* buf = grpc_health_v1_HealthCheckResponse_serialize( + response_struct, arena.ptr(), &buf_length); + if (buf == nullptr) { + return false; + } + grpc_slice response_slice = grpc_slice_from_copied_buffer(buf, buf_length); + Slice encoded_response(response_slice, Slice::STEAL_REF); + ByteBuffer response_buffer(&encoded_response, 1); + response->Swap(&response_buffer); + return true; +} + +// +// DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler +// + +void DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler:: + CreateAndStart(ServerCompletionQueue* cq, + DefaultHealthCheckService* database, + HealthCheckServiceImpl* service) { + std::shared_ptr self = + std::make_shared(cq, database, service); + CheckCallHandler* handler = static_cast(self.get()); + { + grpc_core::MutexLock lock(&service->cq_shutdown_mu_); + if (service->shutdown_) return; + // Request a Check() call. + handler->next_ = + CallableTag(std::bind(&CheckCallHandler::OnCallReceived, handler, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + service->RequestAsyncUnary(0, &handler->ctx_, &handler->request_, + &handler->writer_, cq, cq, &handler->next_); + } +} + +DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler:: + CheckCallHandler(ServerCompletionQueue* cq, + DefaultHealthCheckService* database, + HealthCheckServiceImpl* service) + : cq_(cq), database_(database), service_(service), writer_(&ctx_) {} + +void DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler:: + OnCallReceived(std::shared_ptr self, bool ok) { + if (!ok) { + // The value of ok being false means that the server is shutting down. + return; + } + // Spawn a new handler instance to serve the next new client. Every handler + // instance will deallocate itself when it's done. + CreateAndStart(cq_, database_, service_); + // Process request. + gpr_log(GPR_DEBUG, "[HCS %p] Health check started for handler %p", service_, + this); + std::string service_name; + grpc::Status status = Status::OK; + ByteBuffer response; + if (!service_->DecodeRequest(request_, &service_name)) { + status = Status(StatusCode::INVALID_ARGUMENT, "could not parse request"); + } else { + ServingStatus serving_status = database_->GetServingStatus(service_name); + if (serving_status == NOT_FOUND) { + status = Status(StatusCode::NOT_FOUND, "service name unknown"); + } else if (!service_->EncodeResponse(serving_status, &response)) { + status = Status(StatusCode::INTERNAL, "could not encode response"); + } + } + // Send response. + { + grpc_core::MutexLock lock(&service_->cq_shutdown_mu_); + if (!service_->shutdown_) { + next_ = + CallableTag(std::bind(&CheckCallHandler::OnFinishDone, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + if (status.ok()) { + writer_.Finish(response, status, &next_); + } else { + writer_.FinishWithError(status, &next_); + } + } + } +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler:: + OnFinishDone(std::shared_ptr self, bool ok) { + if (ok) { + gpr_log(GPR_DEBUG, "[HCS %p] Health check call finished for handler %p", + service_, this); + } + self.reset(); // To appease clang-tidy. 
+} + +// +// DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler +// + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + CreateAndStart(ServerCompletionQueue* cq, + DefaultHealthCheckService* database, + HealthCheckServiceImpl* service) { + std::shared_ptr self = + std::make_shared(cq, database, service); + WatchCallHandler* handler = static_cast(self.get()); + { + grpc_core::MutexLock lock(&service->cq_shutdown_mu_); + if (service->shutdown_) return; + // Request AsyncNotifyWhenDone(). + handler->on_done_notified_ = + CallableTag(std::bind(&WatchCallHandler::OnDoneNotified, handler, + std::placeholders::_1, std::placeholders::_2), + self /* copies ref */); + handler->ctx_.AsyncNotifyWhenDone(&handler->on_done_notified_); + // Request a Watch() call. + handler->next_ = + CallableTag(std::bind(&WatchCallHandler::OnCallReceived, handler, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + service->RequestAsyncServerStreaming(1, &handler->ctx_, &handler->request_, + &handler->stream_, cq, cq, + &handler->next_); + } +} + +DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + WatchCallHandler(ServerCompletionQueue* cq, + DefaultHealthCheckService* database, + HealthCheckServiceImpl* service) + : cq_(cq), database_(database), service_(service), stream_(&ctx_) {} + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + OnCallReceived(std::shared_ptr self, bool ok) { + if (!ok) { + // Server shutting down. + // + // AsyncNotifyWhenDone() needs to be called before the call starts, but the + // tag will not pop out if the call never starts ( + // https://github.com/grpc/grpc/issues/10136). So we need to manually + // release the ownership of the handler in this case. + GPR_ASSERT(on_done_notified_.ReleaseHandler() != nullptr); + return; + } + // Spawn a new handler instance to serve the next new client. Every handler + // instance will deallocate itself when it's done. + CreateAndStart(cq_, database_, service_); + // Parse request. + if (!service_->DecodeRequest(request_, &service_name_)) { + SendFinish(std::move(self), + Status(StatusCode::INVALID_ARGUMENT, "could not parse request")); + return; + } + // Register the call for updates to the service. + gpr_log(GPR_DEBUG, + "[HCS %p] Health watch started for service \"%s\" (handler: %p)", + service_, service_name_.c_str(), this); + database_->RegisterCallHandler(service_name_, std::move(self)); +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + SendHealth(std::shared_ptr self, ServingStatus status) { + grpc_core::MutexLock lock(&send_mu_); + // If there's already a send in flight, cache the new status, and + // we'll start a new send for it when the one in flight completes. + if (send_in_flight_) { + pending_status_ = status; + return; + } + // Start a send. + SendHealthLocked(std::move(self), status); +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + SendHealthLocked(std::shared_ptr self, ServingStatus status) { + send_in_flight_ = true; + // Construct response. + ByteBuffer response; + bool success = service_->EncodeResponse(status, &response); + // Grab shutdown lock and send response. 
+ grpc_core::MutexLock cq_lock(&service_->cq_shutdown_mu_); + if (service_->shutdown_) { + SendFinishLocked(std::move(self), Status::CANCELLED); + return; + } + if (!success) { + SendFinishLocked(std::move(self), + Status(StatusCode::INTERNAL, "could not encode response")); + return; + } + next_ = CallableTag(std::bind(&WatchCallHandler::OnSendHealthDone, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + stream_.Write(response, &next_); +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + OnSendHealthDone(std::shared_ptr self, bool ok) { + if (!ok) { + SendFinish(std::move(self), Status::CANCELLED); + return; + } + grpc_core::MutexLock lock(&send_mu_); + send_in_flight_ = false; + // If we got a new status since we started the last send, start a + // new send for it. + if (pending_status_ != NOT_FOUND) { + auto status = pending_status_; + pending_status_ = NOT_FOUND; + SendHealthLocked(std::move(self), status); + } +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + SendFinish(std::shared_ptr self, const Status& status) { + if (finish_called_) return; + grpc_core::MutexLock cq_lock(&service_->cq_shutdown_mu_); + if (service_->shutdown_) return; + SendFinishLocked(std::move(self), status); +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + SendFinishLocked(std::shared_ptr self, const Status& status) { + on_finish_done_ = + CallableTag(std::bind(&WatchCallHandler::OnFinishDone, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + stream_.Finish(status, &on_finish_done_); + finish_called_ = true; +} + +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + OnFinishDone(std::shared_ptr self, bool ok) { + if (ok) { + gpr_log(GPR_DEBUG, + "[HCS %p] Health watch call finished (service_name: \"%s\", " + "handler: %p).", + service_, service_name_.c_str(), this); + } + self.reset(); // To appease clang-tidy. +} + +// TODO(roth): This method currently assumes that there will be only one +// thread polling the cq and invoking the corresponding callbacks. If +// that changes, we will need to add synchronization here. +void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler:: + OnDoneNotified(std::shared_ptr self, bool ok) { + GPR_ASSERT(ok); + gpr_log(GPR_DEBUG, + "[HCS %p] Health watch call is notified done (handler: %p, " + "is_cancelled: %d).", + service_, this, static_cast(ctx_.IsCancelled())); + database_->UnregisterCallHandler(service_name_, self); + SendFinish(std::move(self), Status::CANCELLED); +} + +} // namespace grpc diff --git a/src/cpp/server/health/health_check_service.cc b/src/cpp/server/health/health_check_service.cc new file mode 100644 index 00000000..a0fa2d62 --- /dev/null +++ b/src/cpp/server/health/health_check_service.cc @@ -0,0 +1,34 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +namespace grpc { +namespace { +bool g_grpc_default_health_check_service_enabled = false; +} // namespace + +bool DefaultHealthCheckServiceEnabled() { + return g_grpc_default_health_check_service_enabled; +} + +void EnableDefaultHealthCheckService(bool enable) { + g_grpc_default_health_check_service_enabled = enable; +} + +} // namespace grpc diff --git a/src/cpp/server/health/health_check_service_server_builder_option.cc b/src/cpp/server/health/health_check_service_server_builder_option.cc new file mode 100644 index 00000000..3fa384ac --- /dev/null +++ b/src/cpp/server/health/health_check_service_server_builder_option.cc @@ -0,0 +1,35 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +namespace grpc { + +HealthCheckServiceServerBuilderOption::HealthCheckServiceServerBuilderOption( + std::unique_ptr hc) + : hc_(std::move(hc)) {} +// Hand over hc_ to the server. +void HealthCheckServiceServerBuilderOption::UpdateArguments( + ChannelArguments* args) { + args->SetPointer(kHealthCheckServiceInterfaceArg, hc_.release()); +} + +void HealthCheckServiceServerBuilderOption::UpdatePlugins( + std::vector>* /*plugins*/) {} + +} // namespace grpc diff --git a/src/cpp/server/insecure_server_credentials.cc b/src/cpp/server/insecure_server_credentials.cc new file mode 100644 index 00000000..10da4b9e --- /dev/null +++ b/src/cpp/server/insecure_server_credentials.cc @@ -0,0 +1,46 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +namespace grpc { +namespace { +class InsecureServerCredentialsImpl final : public ServerCredentials { + public: + int AddPortToServer(const std::string& addr, grpc_server* server) override { + return grpc_server_add_insecure_http2_port(server, addr.c_str()); + } + void SetAuthMetadataProcessor( + const std::shared_ptr& processor) override { + (void)processor; + GPR_ASSERT(0); // Should not be called on InsecureServerCredentials. 
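Annotation (not part of the patch): the default health check service is opt-in via EnableDefaultHealthCheckService() above, and serving status is driven through the server's HealthCheckServiceInterface. A sketch with an illustrative service name:

    // Sketch: opt in to the default health check service and flip status.
    #include <memory>
    #include <grpcpp/grpcpp.h>
    #include <grpcpp/health_check_service_interface.h>

    void RunServer() {
      grpc::EnableDefaultHealthCheckService(true);
      grpc::ServerBuilder builder;
      // ... register application services and listening ports ...
      std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
      grpc::HealthCheckServiceInterface* health = server->GetHealthCheckService();
      health->SetServingStatus("my.package.MyService", true);  // illustrative name
      // During draining, mark every registered service NOT_SERVING:
      // health->SetServingStatus(false);
      server->Wait();
    }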
+ } + + private: + bool IsInsecure() const override { return true; } +}; +} // namespace + +std::shared_ptr InsecureServerCredentials() { + return std::shared_ptr( + new InsecureServerCredentialsImpl()); +} + +} // namespace grpc diff --git a/src/cpp/server/load_reporter/get_cpu_stats_linux.cc b/src/cpp/server/load_reporter/get_cpu_stats_linux.cc new file mode 100644 index 00000000..f778b137 --- /dev/null +++ b/src/cpp/server/load_reporter/get_cpu_stats_linux.cc @@ -0,0 +1,51 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#ifdef GPR_LINUX + +#include + +#include + +#include "src/cpp/server/load_reporter/get_cpu_stats.h" + +namespace grpc { +namespace load_reporter { + +std::pair GetCpuStatsImpl() { + uint64_t busy = 0, total = 0; + FILE* fp; + fp = fopen("/proc/stat", "r"); + uint64_t user, nice, system, idle; + if (fscanf(fp, "cpu %" PRIu64 " %" PRIu64 " %" PRIu64 " %" PRIu64, &user, + &nice, &system, &idle) != 4) { + // Something bad happened with the information, so assume it's all invalid + user = nice = system = idle = 0; + } + fclose(fp); + busy = user + nice + system; + total = busy + idle; + return std::make_pair(busy, total); +} + +} // namespace load_reporter +} // namespace grpc + +#endif // GPR_LINUX diff --git a/src/cpp/server/load_reporter/get_cpu_stats_macos.cc b/src/cpp/server/load_reporter/get_cpu_stats_macos.cc new file mode 100644 index 00000000..dbdde304 --- /dev/null +++ b/src/cpp/server/load_reporter/get_cpu_stats_macos.cc @@ -0,0 +1,45 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#ifdef GPR_APPLE + +#include + +#include "src/cpp/server/load_reporter/get_cpu_stats.h" + +namespace grpc { +namespace load_reporter { + +std::pair GetCpuStatsImpl() { + uint64_t busy = 0, total = 0; + host_cpu_load_info_data_t cpuinfo; + mach_msg_type_number_t count = HOST_CPU_LOAD_INFO_COUNT; + if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, + (host_info_t)&cpuinfo, &count) == KERN_SUCCESS) { + for (int i = 0; i < CPU_STATE_MAX; i++) total += cpuinfo.cpu_ticks[i]; + busy = total - cpuinfo.cpu_ticks[CPU_STATE_IDLE]; + } + return std::make_pair(busy, total); +} + +} // namespace load_reporter +} // namespace grpc + +#endif // GPR_APPLE diff --git a/src/cpp/server/load_reporter/get_cpu_stats_unsupported.cc b/src/cpp/server/load_reporter/get_cpu_stats_unsupported.cc new file mode 100644 index 00000000..80fb8b6d --- /dev/null +++ b/src/cpp/server/load_reporter/get_cpu_stats_unsupported.cc @@ -0,0 +1,40 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#if !defined(GPR_LINUX) && !defined(GPR_WINDOWS) && !defined(GPR_APPLE) + +#include + +#include "src/cpp/server/load_reporter/get_cpu_stats.h" + +namespace grpc { +namespace load_reporter { + +std::pair GetCpuStatsImpl() { + uint64_t busy = 0, total = 0; + gpr_log(GPR_ERROR, + "Platforms other than Linux, Windows, and MacOS are not supported."); + return std::make_pair(busy, total); +} + +} // namespace load_reporter +} // namespace grpc + +#endif // !defined(GPR_LINUX) && !defined(GPR_WINDOWS) && !defined(GPR_APPLE) diff --git a/src/cpp/server/load_reporter/get_cpu_stats_windows.cc b/src/cpp/server/load_reporter/get_cpu_stats_windows.cc new file mode 100644 index 00000000..c03daddb --- /dev/null +++ b/src/cpp/server/load_reporter/get_cpu_stats_windows.cc @@ -0,0 +1,56 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
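Annotation (not part of the patch): each platform-specific GetCpuStatsImpl() returns cumulative (busy, total) tick counters, so a consumer derives utilization from the delta between two samples. A small sketch under that assumption:

    // Sketch: CPU utilization between two GetCpuStatsImpl() samples.
    #include <cstdint>
    #include <utility>
    #include "src/cpp/server/load_reporter/get_cpu_stats.h"

    double CpuUtilization(const std::pair<uint64_t, uint64_t>& prev,
                          const std::pair<uint64_t, uint64_t>& curr) {
      const uint64_t busy = curr.first - prev.first;
      const uint64_t total = curr.second - prev.second;
      return total == 0 ? 0.0
                        : static_cast<double>(busy) / static_cast<double>(total);
    }
    // e.g.: auto s1 = grpc::load_reporter::GetCpuStatsImpl();
    //       /* ... some time later ... */
    //       auto s2 = grpc::load_reporter::GetCpuStatsImpl();
    //       double util = CpuUtilization(s1, s2);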
+ * + */ + +#include + +#ifdef GPR_WINDOWS + +#include + +#include + +#include "src/cpp/server/load_reporter/get_cpu_stats.h" + +namespace grpc { +namespace load_reporter { + +namespace { + +uint64_t FiletimeToInt(const FILETIME& ft) { + ULARGE_INTEGER i; + i.LowPart = ft.dwLowDateTime; + i.HighPart = ft.dwHighDateTime; + return i.QuadPart; +} + +} // namespace + +std::pair GetCpuStatsImpl() { + uint64_t busy = 0, total = 0; + FILETIME idle, kernel, user; + if (GetSystemTimes(&idle, &kernel, &user) != 0) { + total = FiletimeToInt(kernel) + FiletimeToInt(user); + busy = total - FiletimeToInt(idle); + } + return std::make_pair(busy, total); +} + +} // namespace load_reporter +} // namespace grpc + +#endif // GPR_WINDOWS diff --git a/src/cpp/server/load_reporter/load_data_store.cc b/src/cpp/server/load_reporter/load_data_store.cc new file mode 100644 index 00000000..9a65c3e7 --- /dev/null +++ b/src/cpp/server/load_reporter/load_data_store.cc @@ -0,0 +1,340 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/server/load_reporter/load_data_store.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/iomgr/socket_utils.h" + +namespace grpc { +namespace load_reporter { + +// Some helper functions. +namespace { + +// Given a map from type K to a set of value type V, finds the set associated +// with the given key and erases the value from the set. If the set becomes +// empty, also erases the key-set pair. Returns true if the value is erased +// successfully. +template +bool UnorderedMapOfSetEraseKeyValue(std::unordered_map>& map, + const K& key, const V& value) { + auto it = map.find(key); + if (it != map.end()) { + size_t erased = it->second.erase(value); + if (it->second.empty()) { + map.erase(it); + } + return erased; + } + return false; +}; + +// Given a map from type K to a set of value type V, removes the given key and +// the associated set, and returns the set. Returns an empty set if the key is +// not found. +template +std::set UnorderedMapOfSetExtract(std::unordered_map>& map, + const K& key) { + auto it = map.find(key); + if (it != map.end()) { + auto set = std::move(it->second); + map.erase(it); + return set; + } + return {}; +}; + +// From a non-empty container, returns a pointer to a random element. 
+template +const typename C::value_type* RandomElement(const C& container) { + GPR_ASSERT(!container.empty()); + auto it = container.begin(); + std::advance(it, std::rand() % container.size()); + return &(*it); +} + +} // namespace + +LoadRecordKey::LoadRecordKey(const std::string& client_ip_and_token, + std::string user_id) + : user_id_(std::move(user_id)) { + GPR_ASSERT(client_ip_and_token.size() >= 2); + int ip_hex_size; + GPR_ASSERT(sscanf(client_ip_and_token.substr(0, 2).c_str(), "%d", + &ip_hex_size) == 1); + GPR_ASSERT(ip_hex_size == 0 || ip_hex_size == kIpv4AddressLength || + ip_hex_size == kIpv6AddressLength); + size_t cur_pos = 2; + client_ip_hex_ = client_ip_and_token.substr(cur_pos, ip_hex_size); + cur_pos += ip_hex_size; + if (client_ip_and_token.size() - cur_pos < kLbIdLength) { + lb_id_ = kInvalidLbId; + lb_tag_ = ""; + } else { + lb_id_ = client_ip_and_token.substr(cur_pos, kLbIdLength); + lb_tag_ = client_ip_and_token.substr(cur_pos + kLbIdLength); + } +} + +std::string LoadRecordKey::GetClientIpBytes() const { + if (client_ip_hex_.empty()) { + return ""; + } else if (client_ip_hex_.size() == kIpv4AddressLength) { + uint32_t ip_bytes; + if (sscanf(client_ip_hex_.c_str(), "%x", &ip_bytes) != 1) { + gpr_log(GPR_ERROR, + "Can't parse client IP (%s) from a hex string to an integer.", + client_ip_hex_.c_str()); + return ""; + } + ip_bytes = grpc_htonl(ip_bytes); + return std::string(reinterpret_cast(&ip_bytes), + sizeof(ip_bytes)); + } else if (client_ip_hex_.size() == kIpv6AddressLength) { + uint32_t ip_bytes[4]; + for (size_t i = 0; i < 4; ++i) { + if (sscanf(client_ip_hex_.substr(i * 8, (i + 1) * 8).c_str(), "%x", + ip_bytes + i) != 1) { + gpr_log( + GPR_ERROR, + "Can't parse client IP part (%s) from a hex string to an integer.", + client_ip_hex_.substr(i * 8, (i + 1) * 8).c_str()); + return ""; + } + ip_bytes[i] = grpc_htonl(ip_bytes[i]); + } + return std::string(reinterpret_cast(ip_bytes), + sizeof(ip_bytes)); + } else { + GPR_UNREACHABLE_CODE(return ""); + } +} + +LoadRecordValue::LoadRecordValue(std::string metric_name, uint64_t num_calls, + double total_metric_value) { + call_metrics_.emplace(std::move(metric_name), + CallMetricValue(num_calls, total_metric_value)); +} + +void PerBalancerStore::MergeRow(const LoadRecordKey& key, + const LoadRecordValue& value) { + // During suspension, the load data received will be dropped. + if (!suspended_) { + load_record_map_[key].MergeFrom(value); + gpr_log(GPR_DEBUG, + "[PerBalancerStore %p] Load data merged (Key: %s, Value: %s).", + this, key.ToString().c_str(), value.ToString().c_str()); + } else { + gpr_log(GPR_DEBUG, + "[PerBalancerStore %p] Load data dropped (Key: %s, Value: %s).", + this, key.ToString().c_str(), value.ToString().c_str()); + } + // We always keep track of num_calls_in_progress_, so that when this + // store is resumed, we still have a correct value of + // num_calls_in_progress_. 
+ GPR_ASSERT(static_cast(num_calls_in_progress_) + + value.GetNumCallsInProgressDelta() >= + 0); + num_calls_in_progress_ += value.GetNumCallsInProgressDelta(); +} + +void PerBalancerStore::Suspend() { + suspended_ = true; + load_record_map_.clear(); + gpr_log(GPR_DEBUG, "[PerBalancerStore %p] Suspended.", this); +} + +void PerBalancerStore::Resume() { + suspended_ = false; + gpr_log(GPR_DEBUG, "[PerBalancerStore %p] Resumed.", this); +} + +uint64_t PerBalancerStore::GetNumCallsInProgressForReport() { + GPR_ASSERT(!suspended_); + last_reported_num_calls_in_progress_ = num_calls_in_progress_; + return num_calls_in_progress_; +} + +void PerHostStore::ReportStreamCreated(const std::string& lb_id, + const std::string& load_key) { + GPR_ASSERT(lb_id != kInvalidLbId); + SetUpForNewLbId(lb_id, load_key); + // Prior to this one, there was no load balancer receiving report, so we may + // have unassigned orphaned stores to assign to this new balancer. + // TODO(juanlishen): If the load key of this new stream is the same with + // some previously adopted orphan store, we may want to take the orphan to + // this stream. Need to discuss with LB team. + if (assigned_stores_.size() == 1) { + for (const auto& p : per_balancer_stores_) { + const std::string& other_lb_id = p.first; + const std::unique_ptr& orphaned_store = p.second; + if (other_lb_id != lb_id) { + orphaned_store->Resume(); + AssignOrphanedStore(orphaned_store.get(), lb_id); + } + } + } + // The first connected balancer will adopt the kInvalidLbId. + if (per_balancer_stores_.size() == 1) { + SetUpForNewLbId(kInvalidLbId, ""); + ReportStreamClosed(kInvalidLbId); + } +} + +void PerHostStore::ReportStreamClosed(const std::string& lb_id) { + auto it_store_for_gone_lb = per_balancer_stores_.find(lb_id); + GPR_ASSERT(it_store_for_gone_lb != per_balancer_stores_.end()); + // Remove this closed stream from our records. + GPR_ASSERT(UnorderedMapOfSetEraseKeyValue( + load_key_to_receiving_lb_ids_, it_store_for_gone_lb->second->load_key(), + lb_id)); + std::set orphaned_stores = + UnorderedMapOfSetExtract(assigned_stores_, lb_id); + // The stores that were assigned to this balancer are orphaned now. They + // should be re-assigned to other balancers which are still receiving reports. + for (PerBalancerStore* orphaned_store : orphaned_stores) { + const std::string* new_receiver = nullptr; + auto it = load_key_to_receiving_lb_ids_.find(orphaned_store->load_key()); + if (it != load_key_to_receiving_lb_ids_.end()) { + // First, try to pick from the active balancers with the same load key. + new_receiver = RandomElement(it->second); + } else if (!assigned_stores_.empty()) { + // If failed, pick from all the remaining active balancers. + new_receiver = &(RandomElement(assigned_stores_)->first); + } + if (new_receiver != nullptr) { + AssignOrphanedStore(orphaned_store, *new_receiver); + } else { + // Load data for an LB ID that can't be assigned to any stream should + // be dropped. + orphaned_store->Suspend(); + } + } +} + +PerBalancerStore* PerHostStore::FindPerBalancerStore( + const std::string& lb_id) const { + return per_balancer_stores_.find(lb_id) != per_balancer_stores_.end() + ? 
per_balancer_stores_.find(lb_id)->second.get() + : nullptr; +} + +const std::set* PerHostStore::GetAssignedStores( + const std::string& lb_id) const { + auto it = assigned_stores_.find(lb_id); + if (it == assigned_stores_.end()) return nullptr; + return &(it->second); +} + +void PerHostStore::AssignOrphanedStore(PerBalancerStore* orphaned_store, + const std::string& new_receiver) { + auto it = assigned_stores_.find(new_receiver); + GPR_ASSERT(it != assigned_stores_.end()); + it->second.insert(orphaned_store); + gpr_log(GPR_INFO, + "[PerHostStore %p] Re-assigned orphaned store (%p) with original LB" + " ID of %s to new receiver %s", + this, orphaned_store, orphaned_store->lb_id().c_str(), + new_receiver.c_str()); +} + +void PerHostStore::SetUpForNewLbId(const std::string& lb_id, + const std::string& load_key) { + // The top-level caller (i.e., LoadReportService) should guarantee the + // lb_id is unique for each reporting stream. + GPR_ASSERT(per_balancer_stores_.find(lb_id) == per_balancer_stores_.end()); + GPR_ASSERT(assigned_stores_.find(lb_id) == assigned_stores_.end()); + load_key_to_receiving_lb_ids_[load_key].insert(lb_id); + std::unique_ptr per_balancer_store( + new PerBalancerStore(lb_id, load_key)); + assigned_stores_[lb_id] = {per_balancer_store.get()}; + per_balancer_stores_[lb_id] = std::move(per_balancer_store); +} + +PerBalancerStore* LoadDataStore::FindPerBalancerStore( + const string& hostname, const string& lb_id) const { + auto it = per_host_stores_.find(hostname); + if (it != per_host_stores_.end()) { + const PerHostStore& per_host_store = it->second; + return per_host_store.FindPerBalancerStore(lb_id); + } else { + return nullptr; + } +} + +void LoadDataStore::MergeRow(const std::string& hostname, + const LoadRecordKey& key, + const LoadRecordValue& value) { + PerBalancerStore* per_balancer_store = + FindPerBalancerStore(hostname, key.lb_id()); + if (per_balancer_store != nullptr) { + per_balancer_store->MergeRow(key, value); + return; + } + // Unknown LB ID. Track it until its number of in-progress calls drops to + // zero. 
+ int64_t in_progress_delta = value.GetNumCallsInProgressDelta(); + if (in_progress_delta != 0) { + auto it_tracker = unknown_balancer_id_trackers_.find(key.lb_id()); + if (it_tracker == unknown_balancer_id_trackers_.end()) { + gpr_log( + GPR_DEBUG, + "[LoadDataStore %p] Start tracking unknown balancer (lb_id_: %s).", + this, key.lb_id().c_str()); + unknown_balancer_id_trackers_.insert( + {key.lb_id(), static_cast(in_progress_delta)}); + } else if ((it_tracker->second += in_progress_delta) == 0) { + unknown_balancer_id_trackers_.erase(it_tracker); + gpr_log(GPR_DEBUG, + "[LoadDataStore %p] Stop tracking unknown balancer (lb_id_: %s).", + this, key.lb_id().c_str()); + } + } +} + +const std::set* LoadDataStore::GetAssignedStores( + const std::string& hostname, const std::string& lb_id) { + auto it = per_host_stores_.find(hostname); + if (it == per_host_stores_.end()) return nullptr; + return it->second.GetAssignedStores(lb_id); +} + +void LoadDataStore::ReportStreamCreated(const std::string& hostname, + const std::string& lb_id, + const std::string& load_key) { + per_host_stores_[hostname].ReportStreamCreated(lb_id, load_key); +} + +void LoadDataStore::ReportStreamClosed(const std::string& hostname, + const std::string& lb_id) { + auto it_per_host_store = per_host_stores_.find(hostname); + GPR_ASSERT(it_per_host_store != per_host_stores_.end()); + it_per_host_store->second.ReportStreamClosed(lb_id); +} + +} // namespace load_reporter +} // namespace grpc diff --git a/src/cpp/server/load_reporter/load_reporter.cc b/src/cpp/server/load_reporter/load_reporter.cc new file mode 100644 index 00000000..df0ac096 --- /dev/null +++ b/src/cpp/server/load_reporter/load_reporter.cc @@ -0,0 +1,513 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/server/load_reporter/load_reporter.h" + +#include +#include +#include + +#include +#include +#include + +#include "opencensus/stats/internal/set_aggregation_window.h" +#include "opencensus/tags/tag_key.h" + +#include "src/cpp/server/load_reporter/constants.h" +#include "src/cpp/server/load_reporter/get_cpu_stats.h" + +namespace grpc { +namespace load_reporter { + +CpuStatsProvider::CpuStatsSample CpuStatsProviderDefaultImpl::GetCpuStats() { + return GetCpuStatsImpl(); +} + +CensusViewProvider::CensusViewProvider() + : tag_key_token_(::opencensus::tags::TagKey::Register(kTagKeyToken)), + tag_key_host_(::opencensus::tags::TagKey::Register(kTagKeyHost)), + tag_key_user_id_(::opencensus::tags::TagKey::Register(kTagKeyUserId)), + tag_key_status_(::opencensus::tags::TagKey::Register(kTagKeyStatus)), + tag_key_metric_name_( + ::opencensus::tags::TagKey::Register(kTagKeyMetricName)) { + // One view related to starting a call. 
+ auto vd_start_count = + ::opencensus::stats::ViewDescriptor() + .set_name(kViewStartCount) + .set_measure(kMeasureStartCount) + .set_aggregation(::opencensus::stats::Aggregation::Sum()) + .add_column(tag_key_token_) + .add_column(tag_key_host_) + .add_column(tag_key_user_id_) + .set_description( + "Delta count of calls started broken down by ."); + ::opencensus::stats::SetAggregationWindow( + ::opencensus::stats::AggregationWindow::Delta(), &vd_start_count); + view_descriptor_map_.emplace(kViewStartCount, vd_start_count); + // Four views related to ending a call. + // If this view is set as Count of kMeasureEndBytesSent (in hope of saving one + // measure), it's infeasible to prepare fake data for testing. That's because + // the OpenCensus API to make up view data will add the input data as separate + // measurements instead of setting the data values directly. + auto vd_end_count = + ::opencensus::stats::ViewDescriptor() + .set_name(kViewEndCount) + .set_measure(kMeasureEndCount) + .set_aggregation(::opencensus::stats::Aggregation::Sum()) + .add_column(tag_key_token_) + .add_column(tag_key_host_) + .add_column(tag_key_user_id_) + .add_column(tag_key_status_) + .set_description( + "Delta count of calls ended broken down by ."); + ::opencensus::stats::SetAggregationWindow( + ::opencensus::stats::AggregationWindow::Delta(), &vd_end_count); + view_descriptor_map_.emplace(kViewEndCount, vd_end_count); + auto vd_end_bytes_sent = + ::opencensus::stats::ViewDescriptor() + .set_name(kViewEndBytesSent) + .set_measure(kMeasureEndBytesSent) + .set_aggregation(::opencensus::stats::Aggregation::Sum()) + .add_column(tag_key_token_) + .add_column(tag_key_host_) + .add_column(tag_key_user_id_) + .add_column(tag_key_status_) + .set_description( + "Delta sum of bytes sent broken down by ."); + ::opencensus::stats::SetAggregationWindow( + ::opencensus::stats::AggregationWindow::Delta(), &vd_end_bytes_sent); + view_descriptor_map_.emplace(kViewEndBytesSent, vd_end_bytes_sent); + auto vd_end_bytes_received = + ::opencensus::stats::ViewDescriptor() + .set_name(kViewEndBytesReceived) + .set_measure(kMeasureEndBytesReceived) + .set_aggregation(::opencensus::stats::Aggregation::Sum()) + .add_column(tag_key_token_) + .add_column(tag_key_host_) + .add_column(tag_key_user_id_) + .add_column(tag_key_status_) + .set_description( + "Delta sum of bytes received broken down by ."); + ::opencensus::stats::SetAggregationWindow( + ::opencensus::stats::AggregationWindow::Delta(), &vd_end_bytes_received); + view_descriptor_map_.emplace(kViewEndBytesReceived, vd_end_bytes_received); + auto vd_end_latency_ms = + ::opencensus::stats::ViewDescriptor() + .set_name(kViewEndLatencyMs) + .set_measure(kMeasureEndLatencyMs) + .set_aggregation(::opencensus::stats::Aggregation::Sum()) + .add_column(tag_key_token_) + .add_column(tag_key_host_) + .add_column(tag_key_user_id_) + .add_column(tag_key_status_) + .set_description( + "Delta sum of latency in ms broken down by ."); + ::opencensus::stats::SetAggregationWindow( + ::opencensus::stats::AggregationWindow::Delta(), &vd_end_latency_ms); + view_descriptor_map_.emplace(kViewEndLatencyMs, vd_end_latency_ms); + // Two views related to other call metrics. 
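// (Editorial sketch, not part of this patch) The descriptors above and the two
// below repeat the same construction; purely as an illustration, the shared
// pattern could be factored into a helper like the following, which reuses
// only the OpenCensus calls already used in this constructor (the helper name
// is hypothetical):
// --- illustrative sketch ---
::opencensus::stats::ViewDescriptor MakeDeltaViewDescriptor(
    const std::string& view_name, const std::string& measure_name,
    const ::opencensus::stats::Aggregation& aggregation,
    ::opencensus::tags::TagKey token, ::opencensus::tags::TagKey host,
    ::opencensus::tags::TagKey user_id, const std::string& description) {
  // Columns common to the load-reporting views; callers append any extra
  // columns (e.g. status or metric name) before registering the view.
  auto vd = ::opencensus::stats::ViewDescriptor()
                .set_name(view_name)
                .set_measure(measure_name)
                .set_aggregation(aggregation)
                .add_column(token)
                .add_column(host)
                .add_column(user_id)
                .set_description(description);
  ::opencensus::stats::SetAggregationWindow(
      ::opencensus::stats::AggregationWindow::Delta(), &vd);
  return vd;
}
// --- end illustrative sketch ---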
+ auto vd_metric_call_count = + ::opencensus::stats::ViewDescriptor() + .set_name(kViewOtherCallMetricCount) + .set_measure(kMeasureOtherCallMetric) + .set_aggregation(::opencensus::stats::Aggregation::Count()) + .add_column(tag_key_token_) + .add_column(tag_key_host_) + .add_column(tag_key_user_id_) + .add_column(tag_key_metric_name_) + .set_description( + "Delta count of calls broken down by ."); + ::opencensus::stats::SetAggregationWindow( + ::opencensus::stats::AggregationWindow::Delta(), &vd_metric_call_count); + view_descriptor_map_.emplace(kViewOtherCallMetricCount, vd_metric_call_count); + auto vd_metric_value = + ::opencensus::stats::ViewDescriptor() + .set_name(kViewOtherCallMetricValue) + .set_measure(kMeasureOtherCallMetric) + .set_aggregation(::opencensus::stats::Aggregation::Sum()) + .add_column(tag_key_token_) + .add_column(tag_key_host_) + .add_column(tag_key_user_id_) + .add_column(tag_key_metric_name_) + .set_description( + "Delta sum of call metric value broken down " + "by ."); + ::opencensus::stats::SetAggregationWindow( + ::opencensus::stats::AggregationWindow::Delta(), &vd_metric_value); + view_descriptor_map_.emplace(kViewOtherCallMetricValue, vd_metric_value); +} + +double CensusViewProvider::GetRelatedViewDataRowDouble( + const ViewDataMap& view_data_map, const char* view_name, + size_t view_name_len, const std::vector& tag_values) { + auto it_vd = view_data_map.find(std::string(view_name, view_name_len)); + GPR_ASSERT(it_vd != view_data_map.end()); + GPR_ASSERT(it_vd->second.type() == + ::opencensus::stats::ViewData::Type::kDouble); + auto it_row = it_vd->second.double_data().find(tag_values); + GPR_ASSERT(it_row != it_vd->second.double_data().end()); + return it_row->second; +} + +uint64_t CensusViewProvider::GetRelatedViewDataRowInt( + const ViewDataMap& view_data_map, const char* view_name, + size_t view_name_len, const std::vector& tag_values) { + auto it_vd = view_data_map.find(std::string(view_name, view_name_len)); + GPR_ASSERT(it_vd != view_data_map.end()); + GPR_ASSERT(it_vd->second.type() == + ::opencensus::stats::ViewData::Type::kInt64); + auto it_row = it_vd->second.int_data().find(tag_values); + GPR_ASSERT(it_row != it_vd->second.int_data().end()); + GPR_ASSERT(it_row->second >= 0); + return it_row->second; +} + +CensusViewProviderDefaultImpl::CensusViewProviderDefaultImpl() { + for (const auto& p : view_descriptor_map()) { + const std::string& view_name = p.first; + const ::opencensus::stats::ViewDescriptor& vd = p.second; + // We need to use pair's piecewise ctor here, otherwise the deleted copy + // ctor of View will be called. 
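// (Editorial sketch, not part of this patch) std::piecewise_construct makes
// map::emplace build the mapped object in place from the ViewDescriptor, which
// is what avoids the deleted copy constructor of View. The same idiom on a
// self-contained, non-copyable type (names are hypothetical):
// --- illustrative sketch ---
#include <map>
#include <string>
#include <tuple>
#include <utility>

class NonCopyable {
 public:
  explicit NonCopyable(int value) : value_(value) {}
  NonCopyable(const NonCopyable&) = delete;
  NonCopyable& operator=(const NonCopyable&) = delete;
  int value() const { return value_; }

 private:
  int value_;
};

int main() {
  std::map<std::string, NonCopyable> m;
  // Key constructed from ("answer"), value constructed from (42), both in
  // place inside the map node; no NonCopyable temporary is ever copied.
  m.emplace(std::piecewise_construct, std::forward_as_tuple("answer"),
            std::forward_as_tuple(42));
  return m.at("answer").value() == 42 ? 0 : 1;
}
// --- end illustrative sketch ---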
+ view_map_.emplace(std::piecewise_construct, + std::forward_as_tuple(view_name), + std::forward_as_tuple(vd)); + } +} + +CensusViewProvider::ViewDataMap CensusViewProviderDefaultImpl::FetchViewData() { + gpr_log(GPR_DEBUG, "[CVP %p] Starts fetching Census view data.", this); + ViewDataMap view_data_map; + for (auto& p : view_map_) { + const std::string& view_name = p.first; + ::opencensus::stats::View& view = p.second; + if (view.IsValid()) { + view_data_map.emplace(view_name, view.GetData()); + gpr_log(GPR_DEBUG, "[CVP %p] Fetched view data (view: %s).", this, + view_name.c_str()); + } else { + gpr_log( + GPR_DEBUG, + "[CVP %p] Can't fetch view data because view is invalid (view: %s).", + this, view_name.c_str()); + } + } + return view_data_map; +} + +std::string LoadReporter::GenerateLbId() { + while (true) { + if (next_lb_id_ > UINT32_MAX) { + gpr_log(GPR_ERROR, "[LR %p] The LB ID exceeds the max valid value!", + this); + return ""; + } + int64_t lb_id = next_lb_id_++; + // Overflow should never happen. + GPR_ASSERT(lb_id >= 0); + // Convert to padded hex string for a 32-bit LB ID. E.g, "0000ca5b". + char buf[kLbIdLength + 1]; + snprintf(buf, sizeof(buf), "%08" PRIx64, lb_id); + std::string lb_id_str(buf, kLbIdLength); + // The client may send requests with LB ID that has never been allocated + // by this load reporter. Those IDs are tracked and will be skipped when + // we generate a new ID. + if (!load_data_store_.IsTrackedUnknownBalancerId(lb_id_str)) { + return lb_id_str; + } + } +} + +::grpc::lb::v1::LoadBalancingFeedback +LoadReporter::GenerateLoadBalancingFeedback() { + grpc_core::ReleasableMutexLock lock(&feedback_mu_); + auto now = std::chrono::system_clock::now(); + // Discard records outside the window until there is only one record + // outside the window, which is used as the base for difference. + while (feedback_records_.size() > 1 && + !IsRecordInWindow(feedback_records_[1], now)) { + feedback_records_.pop_front(); + } + if (feedback_records_.size() < 2) { + return ::grpc::lb::v1::LoadBalancingFeedback::default_instance(); + } + // Find the longest range with valid ends. + auto oldest = feedback_records_.begin(); + auto newest = feedback_records_.end() - 1; + while (std::distance(oldest, newest) > 0 && + (newest->cpu_limit == 0 || oldest->cpu_limit == 0)) { + // A zero limit means that the system info reading was failed, so these + // records can't be used to calculate CPU utilization. + if (newest->cpu_limit == 0) --newest; + if (oldest->cpu_limit == 0) ++oldest; + } + if (std::distance(oldest, newest) < 1 || + oldest->end_time == newest->end_time || + newest->cpu_limit == oldest->cpu_limit) { + return ::grpc::lb::v1::LoadBalancingFeedback::default_instance(); + } + uint64_t rpcs = 0; + uint64_t errors = 0; + for (auto p = newest; p != oldest; --p) { + // Because these two numbers are counters, the oldest record shouldn't be + // included. 
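// (Editorial note) Each record's rpcs/errors hold the totals observed during
// the sampling interval that ends at that record's end_time (they are appended
// once per fetch in AppendNewFeedbackRecord). Summing over (oldest, newest],
// that is, excluding `oldest` itself as this loop does, therefore covers
// exactly the span (oldest->end_time, newest->end_time], which is the same
// duration_seconds used below to turn the sums into per-second rates.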
+ rpcs += p->rpcs; + errors += p->errors; + } + double cpu_usage = newest->cpu_usage - oldest->cpu_usage; + double cpu_limit = newest->cpu_limit - oldest->cpu_limit; + std::chrono::duration duration_seconds = + newest->end_time - oldest->end_time; + lock.Release(); + ::grpc::lb::v1::LoadBalancingFeedback feedback; + feedback.set_server_utilization(static_cast(cpu_usage / cpu_limit)); + feedback.set_calls_per_second( + static_cast(rpcs / duration_seconds.count())); + feedback.set_errors_per_second( + static_cast(errors / duration_seconds.count())); + return feedback; +} + +::google::protobuf::RepeatedPtrField<::grpc::lb::v1::Load> +LoadReporter::GenerateLoads(const std::string& hostname, + const std::string& lb_id) { + grpc_core::MutexLock lock(&store_mu_); + auto assigned_stores = load_data_store_.GetAssignedStores(hostname, lb_id); + GPR_ASSERT(assigned_stores != nullptr); + GPR_ASSERT(!assigned_stores->empty()); + ::google::protobuf::RepeatedPtrField<::grpc::lb::v1::Load> loads; + for (PerBalancerStore* per_balancer_store : *assigned_stores) { + GPR_ASSERT(!per_balancer_store->IsSuspended()); + if (!per_balancer_store->load_record_map().empty()) { + for (const auto& p : per_balancer_store->load_record_map()) { + const auto& key = p.first; + const auto& value = p.second; + auto load = loads.Add(); + load->set_load_balance_tag(key.lb_tag()); + load->set_user_id(key.user_id()); + load->set_client_ip_address(key.GetClientIpBytes()); + load->set_num_calls_started(static_cast(value.start_count())); + load->set_num_calls_finished_without_error( + static_cast(value.ok_count())); + load->set_num_calls_finished_with_error( + static_cast(value.error_count())); + load->set_total_bytes_sent(static_cast(value.bytes_sent())); + load->set_total_bytes_received( + static_cast(value.bytes_recv())); + load->mutable_total_latency()->set_seconds( + static_cast(value.latency_ms() / 1000)); + load->mutable_total_latency()->set_nanos( + (static_cast(value.latency_ms()) % 1000) * 1000000); + for (const auto& p : value.call_metrics()) { + const std::string& metric_name = p.first; + const CallMetricValue& metric_value = p.second; + auto call_metric_data = load->add_metric_data(); + call_metric_data->set_metric_name(metric_name); + call_metric_data->set_num_calls_finished_with_metric( + metric_value.num_calls()); + call_metric_data->set_total_metric_value( + metric_value.total_metric_value()); + } + if (per_balancer_store->lb_id() != lb_id) { + // This per-balancer store is an orphan assigned to this receiving + // balancer. + AttachOrphanLoadId(load, *per_balancer_store); + } + } + per_balancer_store->ClearLoadRecordMap(); + } + if (per_balancer_store->IsNumCallsInProgressChangedSinceLastReport()) { + auto load = loads.Add(); + load->set_num_calls_in_progress( + per_balancer_store->GetNumCallsInProgressForReport()); + if (per_balancer_store->lb_id() != lb_id) { + // This per-balancer store is an orphan assigned to this receiving + // balancer. + AttachOrphanLoadId(load, *per_balancer_store); + } + } + } + return loads; +} + +void LoadReporter::AttachOrphanLoadId( + ::grpc::lb::v1::Load* load, const PerBalancerStore& per_balancer_store) { + if (per_balancer_store.lb_id() == kInvalidLbId) { + load->set_load_key_unknown(true); + } else { + // We shouldn't set load_key_unknown to any value in this case because + // load_key_unknown and orphaned_load_identifier are under an oneof struct. 
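// (Editorial note) In the generated proto3 API, setting either member of a
// oneof selects that member and clears the other. Calling
// set_load_key_unknown(...) here would deselect orphaned_load_identifier, and
// the mutable_orphaned_load_identifier() calls below implicitly clear
// load_key_unknown, so only the orphaned_load_identifier branch is written in
// this case.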
+ load->mutable_orphaned_load_identifier()->set_load_key( + per_balancer_store.load_key()); + load->mutable_orphaned_load_identifier()->set_load_balancer_id( + per_balancer_store.lb_id()); + } +} + +void LoadReporter::AppendNewFeedbackRecord(uint64_t rpcs, uint64_t errors) { + CpuStatsProvider::CpuStatsSample cpu_stats; + if (cpu_stats_provider_ != nullptr) { + cpu_stats = cpu_stats_provider_->GetCpuStats(); + } else { + // This will make the load balancing feedback generation a no-op. + cpu_stats = {0, 0}; + } + grpc_core::MutexLock lock(&feedback_mu_); + feedback_records_.emplace_back(std::chrono::system_clock::now(), rpcs, errors, + cpu_stats.first, cpu_stats.second); +} + +void LoadReporter::ReportStreamCreated(const std::string& hostname, + const std::string& lb_id, + const std::string& load_key) { + grpc_core::MutexLock lock(&store_mu_); + load_data_store_.ReportStreamCreated(hostname, lb_id, load_key); + gpr_log(GPR_INFO, + "[LR %p] Report stream created (host: %s, LB ID: %s, load key: %s).", + this, hostname.c_str(), lb_id.c_str(), load_key.c_str()); +} + +void LoadReporter::ReportStreamClosed(const std::string& hostname, + const std::string& lb_id) { + grpc_core::MutexLock lock(&store_mu_); + load_data_store_.ReportStreamClosed(hostname, lb_id); + gpr_log(GPR_INFO, "[LR %p] Report stream closed (host: %s, LB ID: %s).", this, + hostname.c_str(), lb_id.c_str()); +} + +void LoadReporter::ProcessViewDataCallStart( + const CensusViewProvider::ViewDataMap& view_data_map) { + auto it = view_data_map.find(kViewStartCount); + if (it != view_data_map.end()) { + for (const auto& p : it->second.int_data()) { + const std::vector& tag_values = p.first; + const uint64_t start_count = static_cast(p.second); + const std::string& client_ip_and_token = tag_values[0]; + const std::string& host = tag_values[1]; + const std::string& user_id = tag_values[2]; + LoadRecordKey key(client_ip_and_token, user_id); + LoadRecordValue value = LoadRecordValue(start_count); + { + grpc_core::MutexLock lock(&store_mu_); + load_data_store_.MergeRow(host, key, value); + } + } + } +} + +void LoadReporter::ProcessViewDataCallEnd( + const CensusViewProvider::ViewDataMap& view_data_map) { + uint64_t total_end_count = 0; + uint64_t total_error_count = 0; + auto it = view_data_map.find(kViewEndCount); + if (it != view_data_map.end()) { + for (const auto& p : it->second.int_data()) { + const std::vector& tag_values = p.first; + const uint64_t end_count = static_cast(p.second); + const std::string& client_ip_and_token = tag_values[0]; + const std::string& host = tag_values[1]; + const std::string& user_id = tag_values[2]; + const std::string& status = tag_values[3]; + // This is due to a bug reported internally of Java server load reporting + // implementation. + // TODO(juanlishen): Check whether this situation happens in OSS C++. 
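// (Editorial note) The guard below drops such rows before a LoadRecordKey is
// built from them, and since it continues before total_end_count is
// incremented, the skipped calls are also left out of the feedback record
// appended at the end of this function.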
+ if (client_ip_and_token.empty()) { + gpr_log(GPR_DEBUG, + "Skipping processing Opencensus record with empty " + "client_ip_and_token tag."); + continue; + } + LoadRecordKey key(client_ip_and_token, user_id); + const uint64_t bytes_sent = CensusViewProvider::GetRelatedViewDataRowInt( + view_data_map, kViewEndBytesSent, sizeof(kViewEndBytesSent) - 1, + tag_values); + const uint64_t bytes_received = + CensusViewProvider::GetRelatedViewDataRowInt( + view_data_map, kViewEndBytesReceived, + sizeof(kViewEndBytesReceived) - 1, tag_values); + const uint64_t latency_ms = CensusViewProvider::GetRelatedViewDataRowInt( + view_data_map, kViewEndLatencyMs, sizeof(kViewEndLatencyMs) - 1, + tag_values); + uint64_t ok_count = 0; + uint64_t error_count = 0; + total_end_count += end_count; + if (std::strcmp(status.c_str(), kCallStatusOk) == 0) { + ok_count = end_count; + } else { + error_count = end_count; + total_error_count += end_count; + } + LoadRecordValue value = LoadRecordValue( + 0, ok_count, error_count, bytes_sent, bytes_received, latency_ms); + { + grpc_core::MutexLock lock(&store_mu_); + load_data_store_.MergeRow(host, key, value); + } + } + } + AppendNewFeedbackRecord(total_end_count, total_error_count); +} + +void LoadReporter::ProcessViewDataOtherCallMetrics( + const CensusViewProvider::ViewDataMap& view_data_map) { + auto it = view_data_map.find(kViewOtherCallMetricCount); + if (it != view_data_map.end()) { + for (const auto& p : it->second.int_data()) { + const std::vector& tag_values = p.first; + const int64_t num_calls = p.second; + const std::string& client_ip_and_token = tag_values[0]; + const std::string& host = tag_values[1]; + const std::string& user_id = tag_values[2]; + const std::string& metric_name = tag_values[3]; + LoadRecordKey key(client_ip_and_token, user_id); + const double total_metric_value = + CensusViewProvider::GetRelatedViewDataRowDouble( + view_data_map, kViewOtherCallMetricValue, + sizeof(kViewOtherCallMetricValue) - 1, tag_values); + LoadRecordValue value = LoadRecordValue( + metric_name, static_cast(num_calls), total_metric_value); + { + grpc_core::MutexLock lock(&store_mu_); + load_data_store_.MergeRow(host, key, value); + } + } + } +} + +void LoadReporter::FetchAndSample() { + gpr_log(GPR_DEBUG, + "[LR %p] Starts fetching Census view data and sampling LB feedback " + "record.", + this); + CensusViewProvider::ViewDataMap view_data_map = + census_view_provider_->FetchViewData(); + ProcessViewDataCallStart(view_data_map); + ProcessViewDataCallEnd(view_data_map); + ProcessViewDataOtherCallMetrics(view_data_map); +} + +} // namespace load_reporter +} // namespace grpc diff --git a/src/cpp/server/load_reporter/load_reporter_async_service_impl.cc b/src/cpp/server/load_reporter/load_reporter_async_service_impl.cc new file mode 100644 index 00000000..7c9465b8 --- /dev/null +++ b/src/cpp/server/load_reporter/load_reporter_async_service_impl.cc @@ -0,0 +1,376 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/cpp/server/load_reporter/load_reporter_async_service_impl.h" + +#include + +#include "absl/memory/memory.h" + +namespace grpc { +namespace load_reporter { + +void LoadReporterAsyncServiceImpl::CallableTag::Run(bool ok) { + GPR_ASSERT(handler_function_ != nullptr); + GPR_ASSERT(handler_ != nullptr); + handler_function_(std::move(handler_), ok); +} + +LoadReporterAsyncServiceImpl::LoadReporterAsyncServiceImpl( + std::unique_ptr cq) + : cq_(std::move(cq)) { + thread_ = absl::make_unique<::grpc_core::Thread>("server_load_reporting", + Work, this); + std::unique_ptr cpu_stats_provider = nullptr; +#if defined(GPR_LINUX) || defined(GPR_WINDOWS) || defined(GPR_APPLE) + cpu_stats_provider = absl::make_unique(); +#endif + load_reporter_ = absl::make_unique( + kFeedbackSampleWindowSeconds, + std::unique_ptr(new CensusViewProviderDefaultImpl()), + std::move(cpu_stats_provider)); +} + +LoadReporterAsyncServiceImpl::~LoadReporterAsyncServiceImpl() { + // We will reach here after the server starts shutting down. + shutdown_ = true; + { + grpc_core::MutexLock lock(&cq_shutdown_mu_); + cq_->Shutdown(); + } + if (next_fetch_and_sample_alarm_ != nullptr) { + next_fetch_and_sample_alarm_->Cancel(); + } + thread_->Join(); +} + +void LoadReporterAsyncServiceImpl::ScheduleNextFetchAndSample() { + auto next_fetch_and_sample_time = + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(kFetchAndSampleIntervalSeconds * 1000, + GPR_TIMESPAN)); + { + grpc_core::MutexLock lock(&cq_shutdown_mu_); + if (shutdown_) return; + // TODO(juanlishen): Improve the Alarm implementation to reuse a single + // instance for multiple events. + next_fetch_and_sample_alarm_ = absl::make_unique(); + next_fetch_and_sample_alarm_->Set(cq_.get(), next_fetch_and_sample_time, + this); + } + gpr_log(GPR_DEBUG, "[LRS %p] Next fetch-and-sample scheduled.", this); +} + +void LoadReporterAsyncServiceImpl::FetchAndSample(bool ok) { + if (!ok) { + gpr_log(GPR_INFO, "[LRS %p] Fetch-and-sample is stopped.", this); + return; + } + gpr_log(GPR_DEBUG, "[LRS %p] Starting a fetch-and-sample...", this); + load_reporter_->FetchAndSample(); + ScheduleNextFetchAndSample(); +} + +void LoadReporterAsyncServiceImpl::Work(void* arg) { + LoadReporterAsyncServiceImpl* service = + static_cast(arg); + service->FetchAndSample(true /* ok */); + // TODO(juanlishen): This is a workaround to wait for the cq to be ready. Need + // to figure out why cq is not ready after service starts. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(1, GPR_TIMESPAN))); + ReportLoadHandler::CreateAndStart(service->cq_.get(), service, + service->load_reporter_.get()); + void* tag; + bool ok; + while (true) { + if (!service->cq_->Next(&tag, &ok)) { + // The completion queue is shutting down. 
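// (Editorial note) Next() returns false only once the queue has been shut down
// and fully drained, and ~LoadReporterAsyncServiceImpl() sets shutdown_ before
// calling cq_->Shutdown(), so a false return here can only happen while the
// destructor is tearing the service down; hence the assertion below.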
+ GPR_ASSERT(service->shutdown_); + break; + } + if (tag == service) { + service->FetchAndSample(ok); + } else { + auto* next_step = static_cast(tag); + next_step->Run(ok); + } + } +} + +void LoadReporterAsyncServiceImpl::StartThread() { thread_->Start(); } + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::CreateAndStart( + ServerCompletionQueue* cq, LoadReporterAsyncServiceImpl* service, + LoadReporter* load_reporter) { + std::shared_ptr handler = + std::make_shared(cq, service, load_reporter); + ReportLoadHandler* p = handler.get(); + { + grpc_core::MutexLock lock(&service->cq_shutdown_mu_); + if (service->shutdown_) return; + p->on_done_notified_ = + CallableTag(std::bind(&ReportLoadHandler::OnDoneNotified, p, + std::placeholders::_1, std::placeholders::_2), + handler); + p->next_inbound_ = + CallableTag(std::bind(&ReportLoadHandler::OnRequestDelivered, p, + std::placeholders::_1, std::placeholders::_2), + std::move(handler)); + p->ctx_.AsyncNotifyWhenDone(&p->on_done_notified_); + service->RequestReportLoad(&p->ctx_, &p->stream_, cq, cq, + &p->next_inbound_); + } +} + +LoadReporterAsyncServiceImpl::ReportLoadHandler::ReportLoadHandler( + ServerCompletionQueue* cq, LoadReporterAsyncServiceImpl* service, + LoadReporter* load_reporter) + : cq_(cq), + service_(service), + load_reporter_(load_reporter), + stream_(&ctx_), + call_status_(WAITING_FOR_DELIVERY) {} + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::OnRequestDelivered( + std::shared_ptr self, bool ok) { + if (ok) { + call_status_ = DELIVERED; + } else { + // AsyncNotifyWhenDone() needs to be called before the call starts, but the + // tag will not pop out if the call never starts ( + // https://github.com/grpc/grpc/issues/10136). So we need to manually + // release the ownership of the handler in this case. + GPR_ASSERT(on_done_notified_.ReleaseHandler() != nullptr); + } + if (!ok || shutdown_) { + // The value of ok being false means that the server is shutting down. + Shutdown(std::move(self), "OnRequestDelivered"); + return; + } + // Spawn a new handler instance to serve the next new client. Every handler + // instance will deallocate itself when it's done. + CreateAndStart(cq_, service_, load_reporter_); + { + grpc_core::ReleasableMutexLock lock(&service_->cq_shutdown_mu_); + if (service_->shutdown_) { + lock.Release(); + Shutdown(std::move(self), "OnRequestDelivered"); + return; + } + next_inbound_ = + CallableTag(std::bind(&ReportLoadHandler::OnReadDone, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + stream_.Read(&request_, &next_inbound_); + } + // LB ID is unique for each load reporting stream. + lb_id_ = load_reporter_->GenerateLbId(); + gpr_log(GPR_INFO, + "[LRS %p] Call request delivered (lb_id_: %s, handler: %p). " + "Start reading the initial request...", + service_, lb_id_.c_str(), this); +} + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::OnReadDone( + std::shared_ptr self, bool ok) { + if (!ok || shutdown_) { + if (!ok && call_status_ < INITIAL_REQUEST_RECEIVED) { + // The client may have half-closed the stream or the stream is broken. + gpr_log(GPR_INFO, + "[LRS %p] Failed reading the initial request from the stream " + "(lb_id_: %s, handler: %p, done_notified: %d, is_cancelled: %d).", + service_, lb_id_.c_str(), this, static_cast(done_notified_), + static_cast(is_cancelled_)); + } + Shutdown(std::move(self), "OnReadDone"); + return; + } + // We only receive one request, which is the initial request. 
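// (Editorial note) The stream is expected to carry exactly one client message:
// the initial request naming the load-balanced hostname, the load key, and the
// reporting interval. After that, traffic is server-to-client only, which is
// why the follow-up Read() issued below is expected to fail until the client
// closes the stream, and why a second inbound message is rejected further down
// as a spec violation.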
+ if (call_status_ < INITIAL_REQUEST_RECEIVED) { + if (!request_.has_initial_request()) { + Shutdown(std::move(self), "OnReadDone+initial_request_not_found"); + } else { + call_status_ = INITIAL_REQUEST_RECEIVED; + const auto& initial_request = request_.initial_request(); + load_balanced_hostname_ = initial_request.load_balanced_hostname(); + load_key_ = initial_request.load_key(); + load_reporter_->ReportStreamCreated(load_balanced_hostname_, lb_id_, + load_key_); + const auto& load_report_interval = initial_request.load_report_interval(); + load_report_interval_ms_ = + static_cast(load_report_interval.seconds() * 1000 + + load_report_interval.nanos() / 1000); + gpr_log(GPR_INFO, + "[LRS %p] Initial request received. Start load reporting (load " + "balanced host: %s, interval: %" PRIu64 + " ms, lb_id_: %s, handler: %p)...", + service_, load_balanced_hostname_.c_str(), + load_report_interval_ms_, lb_id_.c_str(), this); + SendReport(self, true /* ok */); + // Expect this read to fail. + { + grpc_core::ReleasableMutexLock lock(&service_->cq_shutdown_mu_); + if (service_->shutdown_) { + lock.Release(); + Shutdown(std::move(self), "OnReadDone"); + return; + } + next_inbound_ = + CallableTag(std::bind(&ReportLoadHandler::OnReadDone, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + stream_.Read(&request_, &next_inbound_); + } + } + } else { + // Another request received! This violates the spec. + gpr_log(GPR_ERROR, + "[LRS %p] Another request received (lb_id_: %s, handler: %p).", + service_, lb_id_.c_str(), this); + Shutdown(std::move(self), "OnReadDone+second_request"); + } +} + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::ScheduleNextReport( + std::shared_ptr self, bool ok) { + if (!ok || shutdown_) { + Shutdown(std::move(self), "ScheduleNextReport"); + return; + } + auto next_report_time = gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(load_report_interval_ms_, GPR_TIMESPAN)); + { + grpc_core::ReleasableMutexLock lock(&service_->cq_shutdown_mu_); + if (service_->shutdown_) { + lock.Release(); + Shutdown(std::move(self), "ScheduleNextReport"); + return; + } + next_outbound_ = + CallableTag(std::bind(&ReportLoadHandler::SendReport, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + // TODO(juanlishen): Improve the Alarm implementation to reuse a single + // instance for multiple events. 
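// (Editorial note) A fresh grpc::Alarm is currently allocated per report
// cycle; when the deadline fires, the alarm pushes &next_outbound_ onto the
// service completion queue, and the dispatch loop in Work() runs that
// CallableTag, which invokes SendReport() for the next report.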
+ next_report_alarm_ = absl::make_unique(); + next_report_alarm_->Set(cq_, next_report_time, &next_outbound_); + } + gpr_log(GPR_DEBUG, + "[LRS %p] Next load report scheduled (lb_id_: %s, handler: %p).", + service_, lb_id_.c_str(), this); +} + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::SendReport( + std::shared_ptr self, bool ok) { + if (!ok || shutdown_) { + Shutdown(std::move(self), "SendReport"); + return; + } + ::grpc::lb::v1::LoadReportResponse response; + auto loads = load_reporter_->GenerateLoads(load_balanced_hostname_, lb_id_); + response.mutable_load()->Swap(&loads); + auto feedback = load_reporter_->GenerateLoadBalancingFeedback(); + response.mutable_load_balancing_feedback()->Swap(&feedback); + if (call_status_ < INITIAL_RESPONSE_SENT) { + auto initial_response = response.mutable_initial_response(); + initial_response->set_load_balancer_id(lb_id_); + initial_response->set_implementation_id( + ::grpc::lb::v1::InitialLoadReportResponse::CPP); + initial_response->set_server_version(kVersion); + call_status_ = INITIAL_RESPONSE_SENT; + } + { + grpc_core::ReleasableMutexLock lock(&service_->cq_shutdown_mu_); + if (service_->shutdown_) { + lock.Release(); + Shutdown(std::move(self), "SendReport"); + return; + } + next_outbound_ = + CallableTag(std::bind(&ReportLoadHandler::ScheduleNextReport, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + stream_.Write(response, &next_outbound_); + gpr_log(GPR_INFO, + "[LRS %p] Sending load report (lb_id_: %s, handler: %p, loads " + "count: %d)...", + service_, lb_id_.c_str(), this, response.load().size()); + } +} + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::OnDoneNotified( + std::shared_ptr self, bool ok) { + GPR_ASSERT(ok); + done_notified_ = true; + if (ctx_.IsCancelled()) { + is_cancelled_ = true; + } + gpr_log(GPR_INFO, + "[LRS %p] Load reporting call is notified done (handler: %p, " + "is_cancelled: %d).", + service_, this, static_cast(is_cancelled_)); + Shutdown(std::move(self), "OnDoneNotified"); +} + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::Shutdown( + std::shared_ptr self, const char* reason) { + if (!shutdown_) { + gpr_log(GPR_INFO, + "[LRS %p] Shutting down the handler (lb_id_: %s, handler: %p, " + "reason: %s).", + service_, lb_id_.c_str(), this, reason); + shutdown_ = true; + if (call_status_ >= INITIAL_REQUEST_RECEIVED) { + load_reporter_->ReportStreamClosed(load_balanced_hostname_, lb_id_); + next_report_alarm_->Cancel(); + } + } + // OnRequestDelivered() may be called after OnDoneNotified(), so we need to + // try to Finish() every time we are in Shutdown(). + if (call_status_ >= DELIVERED && call_status_ < FINISH_CALLED) { + grpc_core::MutexLock lock(&service_->cq_shutdown_mu_); + if (!service_->shutdown_) { + on_finish_done_ = + CallableTag(std::bind(&ReportLoadHandler::OnFinishDone, this, + std::placeholders::_1, std::placeholders::_2), + std::move(self)); + // TODO(juanlishen): Maybe add a message proto for the client to + // explicitly cancel the stream so that we can return OK status in such + // cases. 
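// (Editorial note) Because the protocol has no message for the client to end
// the stream gracefully, every path that reaches this point finishes the call
// with CANCELLED, even after a clean client half-close; that is what the TODO
// above is about.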
+ stream_.Finish(Status::CANCELLED, &on_finish_done_); + call_status_ = FINISH_CALLED; + } + } +} + +void LoadReporterAsyncServiceImpl::ReportLoadHandler::OnFinishDone( + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::shared_ptr /*self*/, bool ok) { + if (ok) { + gpr_log(GPR_INFO, + "[LRS %p] Load reporting finished (lb_id_: %s, handler: %p).", + service_, lb_id_.c_str(), this); + } +} + +} // namespace load_reporter +} // namespace grpc diff --git a/src/cpp/server/load_reporter/load_reporting_service_server_builder_option.cc b/src/cpp/server/load_reporter/load_reporting_service_server_builder_option.cc new file mode 100644 index 00000000..43ad7021 --- /dev/null +++ b/src/cpp/server/load_reporter/load_reporting_service_server_builder_option.cc @@ -0,0 +1,42 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/cpp/server/load_reporter/load_reporting_service_server_builder_plugin.h" + +namespace grpc { +namespace load_reporter { +namespace experimental { + +void LoadReportingServiceServerBuilderOption::UpdateArguments( + ::grpc::ChannelArguments* args) { + args->SetInt(GRPC_ARG_ENABLE_LOAD_REPORTING, true); +} + +void LoadReportingServiceServerBuilderOption::UpdatePlugins( + std::vector>* plugins) { + plugins->emplace_back( + new grpc::load_reporter::LoadReportingServiceServerBuilderPlugin()); +} + +} // namespace experimental +} // namespace load_reporter +} // namespace grpc diff --git a/src/cpp/server/load_reporter/load_reporting_service_server_builder_plugin.cc b/src/cpp/server/load_reporter/load_reporting_service_server_builder_plugin.cc new file mode 100644 index 00000000..aa0ac19f --- /dev/null +++ b/src/cpp/server/load_reporter/load_reporting_service_server_builder_plugin.cc @@ -0,0 +1,60 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/cpp/server/load_reporter/load_reporting_service_server_builder_plugin.h" + +#include + +namespace grpc { +namespace load_reporter { + +bool LoadReportingServiceServerBuilderPlugin::has_sync_methods() const { + if (service_ != nullptr) { + return service_->has_synchronous_methods(); + } + return false; +} + +bool LoadReportingServiceServerBuilderPlugin::has_async_methods() const { + if (service_ != nullptr) { + return service_->has_async_methods(); + } + return false; +} + +void LoadReportingServiceServerBuilderPlugin::UpdateServerBuilder( + grpc::ServerBuilder* builder) { + auto cq = builder->AddCompletionQueue(); + service_ = std::make_shared(std::move(cq)); +} + +void LoadReportingServiceServerBuilderPlugin::InitServer( + grpc::ServerInitializer* si) { + si->RegisterService(service_); +} + +void LoadReportingServiceServerBuilderPlugin::Finish( + grpc::ServerInitializer* /*si*/) { + service_->StartThread(); + service_.reset(); +} + +} // namespace load_reporter +} // namespace grpc diff --git a/src/cpp/server/load_reporter/util.cc b/src/cpp/server/load_reporter/util.cc new file mode 100644 index 00000000..75be7f16 --- /dev/null +++ b/src/cpp/server/load_reporter/util.cc @@ -0,0 +1,46 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include + +namespace grpc { +namespace load_reporter { +namespace experimental { + +void AddLoadReportingCost(grpc::ServerContext* ctx, + const std::string& cost_name, double cost_value) { + if (std::isnormal(cost_value)) { + std::string buf; + buf.resize(sizeof(cost_value) + cost_name.size()); + memcpy(&(*buf.begin()), &cost_value, sizeof(cost_value)); + memcpy(&(*buf.begin()) + sizeof(cost_value), cost_name.data(), + cost_name.size()); + ctx->AddTrailingMetadata(GRPC_LB_COST_MD_KEY, buf); + } else { + gpr_log(GPR_ERROR, "Call metric value is not normal."); + } +} + +} // namespace experimental +} // namespace load_reporter +} // namespace grpc diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc new file mode 100644 index 00000000..b69ee522 --- /dev/null +++ b/src/cpp/server/secure_server_credentials.cc @@ -0,0 +1,153 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/cpp/server/secure_server_credentials.h" + +#include +#include +#include + +#include +#include +#include + +#include "src/cpp/common/secure_auth_context.h" + +namespace grpc { + +void AuthMetadataProcessorAyncWrapper::Destroy(void* wrapper) { + auto* w = static_cast(wrapper); + delete w; +} + +void AuthMetadataProcessorAyncWrapper::Process( + void* wrapper, grpc_auth_context* context, const grpc_metadata* md, + size_t num_md, grpc_process_auth_metadata_done_cb cb, void* user_data) { + auto* w = static_cast(wrapper); + if (!w->processor_) { + // Early exit. + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_OK, nullptr); + return; + } + if (w->processor_->IsBlocking()) { + w->thread_pool_->Add([w, context, md, num_md, cb, user_data] { + w->AuthMetadataProcessorAyncWrapper::InvokeProcessor(context, md, num_md, + cb, user_data); + }); + } else { + // invoke directly. + w->InvokeProcessor(context, md, num_md, cb, user_data); + } +} + +void AuthMetadataProcessorAyncWrapper::InvokeProcessor( + grpc_auth_context* context, const grpc_metadata* md, size_t num_md, + grpc_process_auth_metadata_done_cb cb, void* user_data) { + AuthMetadataProcessor::InputMetadata metadata; + for (size_t i = 0; i < num_md; i++) { + metadata.insert(std::make_pair(StringRefFromSlice(&md[i].key), + StringRefFromSlice(&md[i].value))); + } + SecureAuthContext ctx(context); + AuthMetadataProcessor::OutputMetadata consumed_metadata; + AuthMetadataProcessor::OutputMetadata response_metadata; + + Status status = processor_->Process(metadata, &ctx, &consumed_metadata, + &response_metadata); + + std::vector consumed_md; + for (const auto& consumed : consumed_metadata) { + grpc_metadata md_entry; + md_entry.key = SliceReferencingString(consumed.first); + md_entry.value = SliceReferencingString(consumed.second); + consumed_md.push_back(md_entry); + } + std::vector response_md; + for (const auto& response : response_metadata) { + grpc_metadata md_entry; + md_entry.key = SliceReferencingString(response.first); + md_entry.value = SliceReferencingString(response.second); + response_md.push_back(md_entry); + } + auto consumed_md_data = consumed_md.empty() ? nullptr : &consumed_md[0]; + auto response_md_data = response_md.empty() ? nullptr : &response_md[0]; + cb(user_data, consumed_md_data, consumed_md.size(), response_md_data, + response_md.size(), static_cast(status.error_code()), + status.error_message().c_str()); +} + +int SecureServerCredentials::AddPortToServer(const std::string& addr, + grpc_server* server) { + return grpc_server_add_secure_http2_port(server, addr.c_str(), creds_); +} + +void SecureServerCredentials::SetAuthMetadataProcessor( + const std::shared_ptr& processor) { + auto* wrapper = new grpc::AuthMetadataProcessorAyncWrapper(processor); + grpc_server_credentials_set_auth_metadata_processor( + creds_, {grpc::AuthMetadataProcessorAyncWrapper::Process, + grpc::AuthMetadataProcessorAyncWrapper::Destroy, wrapper}); +} + +std::shared_ptr SslServerCredentials( + const grpc::SslServerCredentialsOptions& options) { + std::vector pem_key_cert_pairs; + for (const auto& key_cert_pair : options.pem_key_cert_pairs) { + grpc_ssl_pem_key_cert_pair p = {key_cert_pair.private_key.c_str(), + key_cert_pair.cert_chain.c_str()}; + pem_key_cert_pairs.push_back(p); + } + grpc_server_credentials* c_creds = grpc_ssl_server_credentials_create_ex( + options.pem_root_certs.empty() ? nullptr : options.pem_root_certs.c_str(), + pem_key_cert_pairs.empty() ? 
nullptr : &pem_key_cert_pairs[0], + pem_key_cert_pairs.size(), + options.force_client_auth + ? GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY + : options.client_certificate_request, + nullptr); + return std::shared_ptr( + new SecureServerCredentials(c_creds)); +} + +namespace experimental { + +std::shared_ptr AltsServerCredentials( + const AltsServerCredentialsOptions& /* options */) { + grpc_alts_credentials_options* c_options = + grpc_alts_credentials_server_options_create(); + grpc_server_credentials* c_creds = + grpc_alts_server_credentials_create(c_options); + grpc_alts_credentials_options_destroy(c_options); + return std::shared_ptr( + new SecureServerCredentials(c_creds)); +} + +std::shared_ptr LocalServerCredentials( + grpc_local_connect_type type) { + return std::shared_ptr( + new SecureServerCredentials(grpc_local_server_credentials_create(type))); +} + +std::shared_ptr TlsServerCredentials( + const grpc::experimental::TlsServerCredentialsOptions& options) { + return std::shared_ptr(new SecureServerCredentials( + grpc_tls_server_credentials_create(options.c_credentials_options()))); +} + +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc new file mode 100644 index 00000000..f0fec748 --- /dev/null +++ b/src/cpp/server/server_builder.cc @@ -0,0 +1,446 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/cpp/server/external_connection_acceptor_impl.h" +#include "src/cpp/server/thread_pool_interface.h" + +namespace grpc { + +static std::vector (*)()>* + g_plugin_factory_list; +static gpr_once once_init_plugin_list = GPR_ONCE_INIT; + +static void do_plugin_list_init(void) { + g_plugin_factory_list = + new std::vector (*)()>(); +} + +ServerBuilder::ServerBuilder() + : max_receive_message_size_(INT_MIN), + max_send_message_size_(INT_MIN), + sync_server_settings_(SyncServerSettings()), + resource_quota_(nullptr) { + gpr_once_init(&once_init_plugin_list, do_plugin_list_init); + for (const auto& value : *g_plugin_factory_list) { + plugins_.emplace_back(value()); + } + + // all compression algorithms enabled by default. + enabled_compression_algorithms_bitset_ = + (1u << GRPC_COMPRESS_ALGORITHMS_COUNT) - 1; + memset(&maybe_default_compression_level_, 0, + sizeof(maybe_default_compression_level_)); + memset(&maybe_default_compression_algorithm_, 0, + sizeof(maybe_default_compression_algorithm_)); +} + +ServerBuilder::~ServerBuilder() { + if (resource_quota_ != nullptr) { + grpc_resource_quota_unref(resource_quota_); + } +} + +std::unique_ptr ServerBuilder::AddCompletionQueue( + bool is_frequently_polled) { + grpc::ServerCompletionQueue* cq = new grpc::ServerCompletionQueue( + GRPC_CQ_NEXT, + is_frequently_polled ? 
GRPC_CQ_DEFAULT_POLLING : GRPC_CQ_NON_LISTENING, + nullptr); + cqs_.push_back(cq); + return std::unique_ptr(cq); +} + +ServerBuilder& ServerBuilder::RegisterService(Service* service) { + services_.emplace_back(new NamedService(service)); + return *this; +} + +ServerBuilder& ServerBuilder::RegisterService(const std::string& host, + Service* service) { + services_.emplace_back(new NamedService(host, service)); + return *this; +} + +ServerBuilder& ServerBuilder::RegisterAsyncGenericService( + AsyncGenericService* service) { + if (generic_service_ || callback_generic_service_) { + gpr_log(GPR_ERROR, + "Adding multiple generic services is unsupported for now. " + "Dropping the service %p", + service); + } else { + generic_service_ = service; + } + return *this; +} + +ServerBuilder& ServerBuilder::RegisterCallbackGenericService( + CallbackGenericService* service) { + if (generic_service_ || callback_generic_service_) { + gpr_log(GPR_ERROR, + "Adding multiple generic services is unsupported for now. " + "Dropping the service %p", + service); + } else { + callback_generic_service_ = service; + } + return *this; +} + +ServerBuilder& ServerBuilder::SetContextAllocator( + std::unique_ptr context_allocator) { + context_allocator_ = std::move(context_allocator); + return *this; +} + +std::unique_ptr +ServerBuilder::experimental_type::AddExternalConnectionAcceptor( + experimental_type::ExternalConnectionType type, + std::shared_ptr creds) { + std::string name_prefix("external:"); + char count_str[GPR_LTOA_MIN_BUFSIZE]; + gpr_ltoa(static_cast(builder_->acceptors_.size()), count_str); + builder_->acceptors_.emplace_back( + std::make_shared( + name_prefix.append(count_str), type, creds)); + return builder_->acceptors_.back()->GetAcceptor(); +} + +void ServerBuilder::experimental_type::SetAuthorizationPolicyProvider( + std::shared_ptr + provider) { + builder_->authorization_provider_ = std::move(provider); +} + +ServerBuilder& ServerBuilder::SetOption( + std::unique_ptr option) { + options_.push_back(std::move(option)); + return *this; +} + +ServerBuilder& ServerBuilder::SetSyncServerOption( + ServerBuilder::SyncServerOption option, int val) { + switch (option) { + case NUM_CQS: + sync_server_settings_.num_cqs = val; + break; + case MIN_POLLERS: + sync_server_settings_.min_pollers = val; + break; + case MAX_POLLERS: + sync_server_settings_.max_pollers = val; + break; + case CQ_TIMEOUT_MSEC: + sync_server_settings_.cq_timeout_msec = val; + break; + } + return *this; +} + +ServerBuilder& ServerBuilder::SetCompressionAlgorithmSupportStatus( + grpc_compression_algorithm algorithm, bool enabled) { + if (enabled) { + grpc_core::SetBit(&enabled_compression_algorithms_bitset_, algorithm); + } else { + grpc_core::ClearBit(&enabled_compression_algorithms_bitset_, algorithm); + } + return *this; +} + +ServerBuilder& ServerBuilder::SetDefaultCompressionLevel( + grpc_compression_level level) { + maybe_default_compression_level_.is_set = true; + maybe_default_compression_level_.level = level; + return *this; +} + +ServerBuilder& ServerBuilder::SetDefaultCompressionAlgorithm( + grpc_compression_algorithm algorithm) { + maybe_default_compression_algorithm_.is_set = true; + maybe_default_compression_algorithm_.algorithm = algorithm; + return *this; +} + +ServerBuilder& ServerBuilder::SetResourceQuota( + const grpc::ResourceQuota& resource_quota) { + if (resource_quota_ != nullptr) { + grpc_resource_quota_unref(resource_quota_); + } + resource_quota_ = resource_quota.c_resource_quota(); + 
grpc_resource_quota_ref(resource_quota_); + return *this; +} + +ServerBuilder& ServerBuilder::AddListeningPort( + const std::string& addr_uri, std::shared_ptr creds, + int* selected_port) { + const std::string uri_scheme = "dns:"; + std::string addr = addr_uri; + if (addr_uri.compare(0, uri_scheme.size(), uri_scheme) == 0) { + size_t pos = uri_scheme.size(); + while (addr_uri[pos] == '/') ++pos; // Skip slashes. + addr = addr_uri.substr(pos); + } + Port port = {addr, std::move(creds), selected_port}; + ports_.push_back(port); + return *this; +} + +ChannelArguments ServerBuilder::BuildChannelArgs() { + ChannelArguments args; + if (max_receive_message_size_ >= -1) { + args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, max_receive_message_size_); + } + if (max_send_message_size_ >= -1) { + args.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, max_send_message_size_); + } + for (const auto& option : options_) { + option->UpdateArguments(&args); + option->UpdatePlugins(&plugins_); + } + args.SetInt(GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET, + enabled_compression_algorithms_bitset_); + if (maybe_default_compression_level_.is_set) { + args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL, + maybe_default_compression_level_.level); + } + if (maybe_default_compression_algorithm_.is_set) { + args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, + maybe_default_compression_algorithm_.algorithm); + } + if (resource_quota_ != nullptr) { + args.SetPointerWithVtable(GRPC_ARG_RESOURCE_QUOTA, resource_quota_, + grpc_resource_quota_arg_vtable()); + } + for (const auto& plugin : plugins_) { + plugin->UpdateServerBuilder(this); + plugin->UpdateChannelArguments(&args); + } + if (authorization_provider_ != nullptr) { + args.SetPointerWithVtable(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER, + authorization_provider_->c_provider(), + grpc_authorization_policy_provider_arg_vtable()); + } + return args; +} + +std::unique_ptr ServerBuilder::BuildAndStart() { + ChannelArguments args = BuildChannelArgs(); + + // == Determine if the server has any syncrhonous methods == + bool has_sync_methods = false; + for (const auto& value : services_) { + if (value->service->has_synchronous_methods()) { + has_sync_methods = true; + break; + } + } + + if (!has_sync_methods) { + for (const auto& value : plugins_) { + if (value->has_sync_methods()) { + has_sync_methods = true; + break; + } + } + } + + // If this is a Sync server, i.e a server expositing sync API, then the server + // needs to create some completion queues to listen for incoming requests. + // 'sync_server_cqs' are those internal completion queues. 
+ // + // This is different from the completion queues added to the server via + // ServerBuilder's AddCompletionQueue() method (those completion queues + // are in 'cqs_' member variable of ServerBuilder object) + std::shared_ptr>> + sync_server_cqs( + std::make_shared< + std::vector>>()); + + bool has_frequently_polled_cqs = false; + for (const auto& cq : cqs_) { + if (cq->IsFrequentlyPolled()) { + has_frequently_polled_cqs = true; + break; + } + } + + // == Determine if the server has any callback methods == + bool has_callback_methods = false; + for (const auto& service : services_) { + if (service->service->has_callback_methods()) { + has_callback_methods = true; + has_frequently_polled_cqs = true; + break; + } + } + + if (callback_generic_service_ != nullptr) { + has_frequently_polled_cqs = true; + } + + const bool is_hybrid_server = has_sync_methods && has_frequently_polled_cqs; + + if (has_sync_methods) { + grpc_cq_polling_type polling_type = + is_hybrid_server ? GRPC_CQ_NON_POLLING : GRPC_CQ_DEFAULT_POLLING; + + // Create completion queues to listen to incoming rpc requests + for (int i = 0; i < sync_server_settings_.num_cqs; i++) { + sync_server_cqs->emplace_back( + new grpc::ServerCompletionQueue(GRPC_CQ_NEXT, polling_type, nullptr)); + } + } + + // TODO(vjpai): Add a section here for plugins once they can support callback + // methods + + if (has_sync_methods) { + // This is a Sync server + gpr_log(GPR_INFO, + "Synchronous server. Num CQs: %d, Min pollers: %d, Max Pollers: " + "%d, CQ timeout (msec): %d", + sync_server_settings_.num_cqs, sync_server_settings_.min_pollers, + sync_server_settings_.max_pollers, + sync_server_settings_.cq_timeout_msec); + } + + if (has_callback_methods) { + gpr_log(GPR_INFO, "Callback server."); + } + + std::unique_ptr server(new grpc::Server( + &args, sync_server_cqs, sync_server_settings_.min_pollers, + sync_server_settings_.max_pollers, sync_server_settings_.cq_timeout_msec, + std::move(acceptors_), server_config_fetcher_, resource_quota_, + std::move(interceptor_creators_))); + + ServerInitializer* initializer = server->initializer(); + + // Register all the completion queues with the server. i.e + // 1. sync_server_cqs: internal completion queues created IF this is a sync + // server + // 2. cqs_: Completion queues added via AddCompletionQueue() call + + for (const auto& cq : *sync_server_cqs) { + grpc_server_register_completion_queue(server->server_, cq->cq(), nullptr); + has_frequently_polled_cqs = true; + } + + if (has_callback_methods || callback_generic_service_ != nullptr) { + auto* cq = server->CallbackCQ(); + grpc_server_register_completion_queue(server->server_, cq->cq(), nullptr); + } + + // cqs_ contains the completion queue added by calling the ServerBuilder's + // AddCompletionQueue() API. Some of them may not be frequently polled (i.e by + // calling Next() or AsyncNext()) and hence are not safe to be used for + // listening to incoming channels. Such completion queues must be registered + // as non-listening queues. In debug mode, these should have their server list + // tracked since these are provided the user and must be Shutdown by the user + // after the server is shutdown. 
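// (Editorial sketch, not part of this patch) From the application's side, the
// contract described above looks roughly like this for an async server; the
// address and service are hypothetical:
// --- illustrative sketch ---
#include <memory>

#include <grpcpp/grpcpp.h>

void RunAndShutDownAsyncServer(grpc::Service* async_service /* hypothetical */) {
  grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:50051", grpc::InsecureServerCredentials());
  builder.RegisterService(async_service);
  // The queue is registered with the server during BuildAndStart(), but its
  // polling and shutdown remain the caller's responsibility.
  std::unique_ptr<grpc::ServerCompletionQueue> cq = builder.AddCompletionQueue();
  std::unique_ptr<grpc::Server> server = builder.BuildAndStart();

  // ... request calls and poll cq->Next(&tag, &ok) on worker threads until it
  // is time to stop ...

  // Shutdown order per the comment above: the server first, then the queue,
  // then drain it so every pending tag is delivered before the queue goes away.
  server->Shutdown();
  cq->Shutdown();
  void* tag;
  bool ok;
  while (cq->Next(&tag, &ok)) {
  }
}
// --- end illustrative sketch ---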
+ for (const auto& cq : cqs_) { + grpc_server_register_completion_queue(server->server_, cq->cq(), nullptr); + cq->RegisterServer(server.get()); + } + + if (!has_frequently_polled_cqs) { + gpr_log(GPR_ERROR, + "At least one of the completion queues must be frequently polled"); + return nullptr; + } + + server->RegisterContextAllocator(std::move(context_allocator_)); + + for (const auto& value : services_) { + if (!server->RegisterService(value->host.get(), value->service)) { + return nullptr; + } + } + + for (const auto& value : plugins_) { + value->InitServer(initializer); + } + + if (generic_service_) { + server->RegisterAsyncGenericService(generic_service_); + } else if (callback_generic_service_) { + server->RegisterCallbackGenericService(callback_generic_service_); + } else { + for (const auto& value : services_) { + if (value->service->has_generic_methods()) { + gpr_log(GPR_ERROR, + "Some methods were marked generic but there is no " + "generic service registered."); + return nullptr; + } + } + } + + bool added_port = false; + for (auto& port : ports_) { + int r = server->AddListeningPort(port.addr, port.creds.get()); + if (!r) { + if (added_port) server->Shutdown(); + return nullptr; + } + added_port = true; + if (port.selected_port != nullptr) { + *port.selected_port = r; + } + } + + auto cqs_data = cqs_.empty() ? nullptr : &cqs_[0]; + server->Start(cqs_data, cqs_.size()); + + for (const auto& value : plugins_) { + value->Finish(initializer); + } + + return server; +} + +void ServerBuilder::InternalAddPluginFactory( + std::unique_ptr (*CreatePlugin)()) { + gpr_once_init(&once_init_plugin_list, do_plugin_list_init); + (*g_plugin_factory_list).push_back(CreatePlugin); +} + +ServerBuilder& ServerBuilder::EnableWorkaround(grpc_workaround_list id) { + switch (id) { + case GRPC_WORKAROUND_ID_CRONET_COMPRESSION: + return AddChannelArgument(GRPC_ARG_WORKAROUND_CRONET_COMPRESSION, 1); + default: + gpr_log(GPR_ERROR, "Workaround %u does not exist or is obsolete.", id); + return *this; + } +} + +} // namespace grpc diff --git a/src/cpp/server/server_callback.cc b/src/cpp/server/server_callback.cc new file mode 100644 index 00000000..5b2d328b --- /dev/null +++ b/src/cpp/server/server_callback.cc @@ -0,0 +1,84 @@ +/* + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/executor.h" + +namespace grpc { +namespace internal { + +void ServerCallbackCall::ScheduleOnDone(bool inline_ondone) { + if (inline_ondone) { + CallOnDone(); + } else { + // Unlike other uses of closure, do not Ref or Unref here since at this + // point, all the Ref'fing and Unref'fing is done for this call. 
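// (Editorial note) Contrast this with CallOnCancel() below: there the closure
// may run while the call is still in flight, so an extra Ref() is taken up
// front and released via MaybeDone() inside the closure. OnDone is the final
// event for the call, so this closure only defers CallOnDone() to the executor
// without touching the refcount.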
+ grpc_core::ExecCtx exec_ctx; + struct ClosureWithArg { + grpc_closure closure; + ServerCallbackCall* call; + explicit ClosureWithArg(ServerCallbackCall* call_arg) : call(call_arg) { + GRPC_CLOSURE_INIT( + &closure, + [](void* void_arg, grpc_error_handle) { + ClosureWithArg* arg = static_cast(void_arg); + arg->call->CallOnDone(); + delete arg; + }, + this, grpc_schedule_on_exec_ctx); + } + }; + ClosureWithArg* arg = new ClosureWithArg(this); + grpc_core::Executor::Run(&arg->closure, GRPC_ERROR_NONE); + } +} + +void ServerCallbackCall::CallOnCancel(ServerReactor* reactor) { + if (reactor->InternalInlineable()) { + reactor->OnCancel(); + } else { + // Ref to make sure that the closure executes before the whole call gets + // destructed, and Unref within the closure. + Ref(); + grpc_core::ExecCtx exec_ctx; + struct ClosureWithArg { + grpc_closure closure; + ServerCallbackCall* call; + ServerReactor* reactor; + ClosureWithArg(ServerCallbackCall* call_arg, ServerReactor* reactor_arg) + : call(call_arg), reactor(reactor_arg) { + GRPC_CLOSURE_INIT( + &closure, + [](void* void_arg, grpc_error_handle) { + ClosureWithArg* arg = static_cast(void_arg); + arg->reactor->OnCancel(); + arg->call->MaybeDone(); + delete arg; + }, + this, grpc_schedule_on_exec_ctx); + } + }; + ClosureWithArg* arg = new ClosureWithArg(this, reactor); + grpc_core::Executor::Run(&arg->closure, GRPC_ERROR_NONE); + } +} + +} // namespace internal +} // namespace grpc diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc new file mode 100644 index 00000000..ebfbc9e8 --- /dev/null +++ b/src/cpp/server/server_cc.cc @@ -0,0 +1,1370 @@ +/* + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/transport/inproc/inproc_transport.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "src/cpp/client/create_channel_internal.h" +#include "src/cpp/server/external_connection_acceptor_impl.h" +#include "src/cpp/server/health/default_health_check_service.h" +#include "src/cpp/thread_manager/thread_manager.h" + +namespace grpc { +namespace { + +// The default value for maximum number of threads that can be created in the +// sync server. This value of INT_MAX is chosen to match the default behavior if +// no ResourceQuota is set. To modify the max number of threads in a sync +// server, pass a custom ResourceQuota object (with the desired number of +// max-threads set) to the server builder. 
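// (Editorial sketch, not part of this patch) What the comment above describes,
// from the application's side; the quota name and thread limit are arbitrary
// example values:
// --- illustrative sketch ---
#include <grpcpp/grpcpp.h>
#include <grpcpp/resource_quota.h>

void CapSyncServerThreads(grpc::ServerBuilder* builder) {
  grpc::ResourceQuota quota("demo_server_quota");  // hypothetical quota name
  quota.SetMaxThreads(64);                         // hypothetical thread cap
  builder->SetResourceQuota(quota);
}
// --- end illustrative sketch ---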
+#define DEFAULT_MAX_SYNC_SERVER_THREADS INT_MAX + +// Give a useful status error message if the resource is exhausted specifically +// because the server threadpool is full. +const char* kServerThreadpoolExhausted = "Server Threadpool Exhausted"; + +// Although we might like to give a useful status error message on unimplemented +// RPCs, it's not always possible since that also would need to be added across +// languages and isn't actually required by the spec. +const char* kUnknownRpcMethod = ""; + +class DefaultGlobalCallbacks final : public Server::GlobalCallbacks { + public: + ~DefaultGlobalCallbacks() override {} + void PreSynchronousRequest(ServerContext* /*context*/) override {} + void PostSynchronousRequest(ServerContext* /*context*/) override {} +}; + +std::shared_ptr g_callbacks = nullptr; +gpr_once g_once_init_callbacks = GPR_ONCE_INIT; + +void InitGlobalCallbacks() { + if (!g_callbacks) { + g_callbacks.reset(new DefaultGlobalCallbacks()); + } +} + +class ShutdownTag : public internal::CompletionQueueTag { + public: + bool FinalizeResult(void** /*tag*/, bool* /*status*/) override { + return false; + } +}; + +class PhonyTag : public internal::CompletionQueueTag { + public: + bool FinalizeResult(void** /*tag*/, bool* /*status*/) override { + return true; + } +}; + +class UnimplementedAsyncRequestContext { + protected: + UnimplementedAsyncRequestContext() : generic_stream_(&server_context_) {} + + GenericServerContext server_context_; + GenericServerAsyncReaderWriter generic_stream_; +}; + +} // namespace + +ServerInterface::BaseAsyncRequest::BaseAsyncRequest( + ServerInterface* server, ServerContext* context, + internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) + : server_(server), + context_(context), + stream_(stream), + call_cq_(call_cq), + notification_cq_(notification_cq), + tag_(tag), + delete_on_finalize_(delete_on_finalize), + call_(nullptr), + done_intercepting_(false) { + /* Set up interception state partially for the receive ops. call_wrapper_ is + * not filled at this point, but it will be filled before the interceptors are + * run. */ + interceptor_methods_.SetCall(&call_wrapper_); + interceptor_methods_.SetReverse(); + call_cq_->RegisterAvalanching(); // This op will trigger more ops +} + +ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() { + call_cq_->CompleteAvalanching(); +} + +bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, + bool* status) { + if (done_intercepting_) { + *tag = tag_; + if (delete_on_finalize_) { + delete this; + } + return true; + } + context_->set_call(call_); + context_->cq_ = call_cq_; + if (call_wrapper_.call() == nullptr) { + // Fill it since it is empty. + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), nullptr); + } + + // just the pointers inside call are copied here + stream_->BindCall(&call_wrapper_); + + if (*status && call_ && call_wrapper_.server_rpc_info()) { + done_intercepting_ = true; + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_); + if (interceptor_methods_.RunInterceptors( + [this]() { ContinueFinalizeResultAfterInterception(); })) { + // There are no interceptors to run. 
Continue + } else { + // There were interceptors to be run, so + // ContinueFinalizeResultAfterInterception will be run when interceptors + // are done. + return false; + } + } + if (*status && call_) { + context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr); + } + *tag = tag_; + if (delete_on_finalize_) { + delete this; + } + return true; +} + +void ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception() { + context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr); + // Queue a tag which will be returned immediately + grpc_core::ExecCtx exec_ctx; + grpc_cq_begin_op(notification_cq_->cq(), this); + grpc_cq_end_op( + notification_cq_->cq(), this, GRPC_ERROR_NONE, + [](void* /*arg*/, grpc_cq_completion* completion) { delete completion; }, + nullptr, new grpc_cq_completion()); +} + +ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( + ServerInterface* server, ServerContext* context, + internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, const char* name, + internal::RpcMethod::RpcType type) + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, + true), + name_(name), + type_(type) {} + +void ServerInterface::RegisteredAsyncRequest::IssueRequest( + void* registered_method, grpc_byte_buffer** payload, + ServerCompletionQueue* notification_cq) { + // The following call_start_batch is internally-generated so no need for an + // explanatory log on failure. + GPR_ASSERT(grpc_server_request_registered_call( + server_->server(), registered_method, &call_, + &context_->deadline_, context_->client_metadata_.arr(), + payload, call_cq_->cq(), notification_cq->cq(), + this) == GRPC_CALL_OK); +} + +ServerInterface::GenericAsyncRequest::GenericAsyncRequest( + ServerInterface* server, GenericServerContext* context, + internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, + delete_on_finalize) { + grpc_call_details_init(&call_details_); + GPR_ASSERT(notification_cq); + GPR_ASSERT(call_cq); + // The following call_start_batch is internally-generated so no need for an + // explanatory log on failure. + GPR_ASSERT(grpc_server_request_call(server->server(), &call_, &call_details_, + context->client_metadata_.arr(), + call_cq->cq(), notification_cq->cq(), + this) == GRPC_CALL_OK); +} + +bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, + bool* status) { + // If we are done intercepting, there is nothing more for us to do + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } + // TODO(yangg) remove the copy here. 
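+ // The method and host slices are copied out of call_details_ here and then
+ // unconditionally unref'd below, so they must not be referenced afterwards.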
+ if (*status) { + static_cast(context_)->method_ = + StringFromCopiedSlice(call_details_.method); + static_cast(context_)->host_ = + StringFromCopiedSlice(call_details_.host); + context_->deadline_ = call_details_.deadline; + } + grpc_slice_unref(call_details_.method); + grpc_slice_unref(call_details_.host); + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info( + static_cast(context_)->method_.c_str(), + internal::RpcMethod::BIDI_STREAMING, + *server_->interceptor_creators())); + return BaseAsyncRequest::FinalizeResult(tag, status); +} + +namespace { +class ShutdownCallback : public grpc_completion_queue_functor { + public: + ShutdownCallback() { + functor_run = &ShutdownCallback::Run; + // Set inlineable to true since this callback is trivial and thus does not + // need to be run from the executor (triggering a thread hop). This should + // only be used by internal callbacks like this and not by user application + // code. + inlineable = true; + } + // TakeCQ takes ownership of the cq into the shutdown callback + // so that the shutdown callback will be responsible for destroying it + void TakeCQ(CompletionQueue* cq) { cq_ = cq; } + + // The Run function will get invoked by the completion queue library + // when the shutdown is actually complete + static void Run(grpc_completion_queue_functor* cb, int) { + auto* callback = static_cast(cb); + delete callback->cq_; + delete callback; + } + + private: + CompletionQueue* cq_ = nullptr; +}; +} // namespace + +/// Use private inheritance rather than composition only to establish order +/// of construction, since the public base class should be constructed after the +/// elements belonging to the private base class are constructed. This is not +/// possible using true composition. +class Server::UnimplementedAsyncRequest final + : private grpc::UnimplementedAsyncRequestContext, + public GenericAsyncRequest { + public: + UnimplementedAsyncRequest(ServerInterface* server, + grpc::ServerCompletionQueue* cq) + : GenericAsyncRequest(server, &server_context_, &generic_stream_, cq, cq, + nullptr, false) {} + + bool FinalizeResult(void** tag, bool* status) override; + + grpc::ServerContext* context() { return &server_context_; } + grpc::GenericServerAsyncReaderWriter* stream() { return &generic_stream_; } +}; + +/// UnimplementedAsyncResponse should not post user-visible completions to the +/// C++ completion queue, but is generated as a CQ event by the core +class Server::UnimplementedAsyncResponse final + : public grpc::internal::CallOpSet< + grpc::internal::CallOpSendInitialMetadata, + grpc::internal::CallOpServerSendStatus> { + public: + explicit UnimplementedAsyncResponse(UnimplementedAsyncRequest* request); + ~UnimplementedAsyncResponse() override { delete request_; } + + bool FinalizeResult(void** tag, bool* status) override { + if (grpc::internal::CallOpSet< + grpc::internal::CallOpSendInitialMetadata, + grpc::internal::CallOpServerSendStatus>::FinalizeResult(tag, + status)) { + delete this; + } else { + // The tag was swallowed due to interception. We will see it again. 
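+ // Either way, false is returned below so this internally generated event
+ // never reaches the application's completion queue.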
+ } + return false; + } + + private: + UnimplementedAsyncRequest* const request_; +}; + +class Server::SyncRequest final : public grpc::internal::CompletionQueueTag { + public: + SyncRequest(Server* server, grpc::internal::RpcServiceMethod* method, + grpc_core::Server::RegisteredCallAllocation* data) + : SyncRequest(server, method) { + CommonSetup(data); + data->deadline = &deadline_; + data->optional_payload = has_request_payload_ ? &request_payload_ : nullptr; + } + + SyncRequest(Server* server, grpc::internal::RpcServiceMethod* method, + grpc_core::Server::BatchCallAllocation* data) + : SyncRequest(server, method) { + CommonSetup(data); + call_details_ = new grpc_call_details; + grpc_call_details_init(call_details_); + data->details = call_details_; + } + + ~SyncRequest() override { + // The destructor should only cleanup those objects created in the + // constructor, since some paths may or may not actually go through the + // Run stage where other objects are allocated. + if (has_request_payload_ && request_payload_) { + grpc_byte_buffer_destroy(request_payload_); + } + if (call_details_ != nullptr) { + grpc_call_details_destroy(call_details_); + delete call_details_; + } + grpc_metadata_array_destroy(&request_metadata_); + server_->UnrefWithPossibleNotify(); + } + + bool FinalizeResult(void** /*tag*/, bool* status) override { + if (!*status) { + delete this; + return false; + } + if (call_details_) { + deadline_ = call_details_->deadline; + } + return true; + } + + void Run(const std::shared_ptr& global_callbacks, + bool resources) { + ctx_.Init(deadline_, &request_metadata_); + wrapped_call_.Init( + call_, server_, &cq_, server_->max_receive_message_size(), + ctx_->ctx.set_server_rpc_info(method_->name(), method_->method_type(), + server_->interceptor_creators_)); + ctx_->ctx.set_call(call_); + ctx_->ctx.cq_ = &cq_; + request_metadata_.count = 0; + + global_callbacks_ = global_callbacks; + resources_ = resources; + + interceptor_methods_.SetCall(&*wrapped_call_); + interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + grpc::experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&ctx_->ctx.client_metadata_); + + if (has_request_payload_) { + // Set interception point for RECV MESSAGE + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + deserialized_request_ = handler->Deserialize(call_, request_payload_, + &request_status_, nullptr); + if (!request_status_.ok()) { + gpr_log(GPR_DEBUG, "Failed to deserialize message."); + } + request_payload_ = nullptr; + interceptor_methods_.AddInterceptionHookPoint( + grpc::experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(deserialized_request_, nullptr); + } + + if (interceptor_methods_.RunInterceptors( + [this]() { ContinueRunAfterInterception(); })) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } + + void ContinueRunAfterInterception() { + ctx_->ctx.BeginCompletionOp(&*wrapped_call_, nullptr, nullptr); + global_callbacks_->PreSynchronousRequest(&ctx_->ctx); + auto* handler = resources_ ? 
method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler(grpc::internal::MethodHandler::HandlerParameter( + &*wrapped_call_, &ctx_->ctx, deserialized_request_, request_status_, + nullptr, nullptr)); + global_callbacks_->PostSynchronousRequest(&ctx_->ctx); + + cq_.Shutdown(); + + grpc::internal::CompletionQueueTag* op_tag = ctx_->ctx.GetCompletionOpTag(); + cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + + // Ensure the cq_ is shutdown + grpc::PhonyTag ignored_tag; + GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + + // Cleanup structures allocated during Run/ContinueRunAfterInterception + wrapped_call_.Destroy(); + ctx_.Destroy(); + + delete this; + } + + // For requests that must be only cleaned up but not actually Run + void Cleanup() { + cq_.Shutdown(); + grpc_call_unref(call_); + delete this; + } + + private: + SyncRequest(Server* server, grpc::internal::RpcServiceMethod* method) + : server_(server), + method_(method), + has_request_payload_(method->method_type() == + grpc::internal::RpcMethod::NORMAL_RPC || + method->method_type() == + grpc::internal::RpcMethod::SERVER_STREAMING), + cq_(grpc_completion_queue_create_for_pluck(nullptr)) {} + + template + void CommonSetup(CallAllocation* data) { + server_->Ref(); + grpc_metadata_array_init(&request_metadata_); + data->tag = static_cast(this); + data->call = &call_; + data->initial_metadata = &request_metadata_; + data->cq = cq_.cq(); + } + + Server* const server_; + grpc::internal::RpcServiceMethod* const method_; + const bool has_request_payload_; + grpc_call* call_; + grpc_call_details* call_details_ = nullptr; + gpr_timespec deadline_; + grpc_metadata_array request_metadata_; + grpc_byte_buffer* request_payload_ = nullptr; + grpc::CompletionQueue cq_; + grpc::Status request_status_; + std::shared_ptr global_callbacks_; + bool resources_; + void* deserialized_request_ = nullptr; + grpc::internal::InterceptorBatchMethodsImpl interceptor_methods_; + + // ServerContextWrapper allows ManualConstructor while using a private + // contructor of ServerContext via this friend class. + struct ServerContextWrapper { + ServerContext ctx; + + ServerContextWrapper(gpr_timespec deadline, grpc_metadata_array* arr) + : ctx(deadline, arr) {} + }; + + grpc_core::ManualConstructor ctx_; + grpc_core::ManualConstructor wrapped_call_; +}; + +template +class Server::CallbackRequest final + : public grpc::internal::CompletionQueueTag { + public: + static_assert( + std::is_base_of::value, + "ServerContextType must be derived from CallbackServerContext"); + + // For codegen services, the value of method represents the defined + // characteristics of the method being requested. For generic services, method + // is nullptr since these services don't have pre-defined methods. + CallbackRequest(Server* server, grpc::internal::RpcServiceMethod* method, + grpc::CompletionQueue* cq, + grpc_core::Server::RegisteredCallAllocation* data) + : server_(server), + method_(method), + has_request_payload_(method->method_type() == + grpc::internal::RpcMethod::NORMAL_RPC || + method->method_type() == + grpc::internal::RpcMethod::SERVER_STREAMING), + cq_(cq), + tag_(this), + ctx_(server_->context_allocator() != nullptr + ? server_->context_allocator()->NewCallbackServerContext() + : nullptr) { + CommonSetup(server, data); + data->deadline = &deadline_; + data->optional_payload = has_request_payload_ ? &request_payload_ : nullptr; + } + + // For generic services, method is nullptr since these services don't have + // pre-defined methods. 
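+ // A generic callback service receives every method that no registered
+ // service matched; its handler dispatches on ctx->method() at run time, and
+ // it is installed through RegisterCallbackGenericService() later in this
+ // file. Illustrative shape only:
+ //   class CatchAllService final : public grpc::CallbackGenericService {
+ //     grpc::ServerGenericBidiReactor* CreateReactor(
+ //         grpc::GenericCallbackServerContext* ctx) override { ... }
+ //   };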
+ CallbackRequest(Server* server, grpc::CompletionQueue* cq, + grpc_core::Server::BatchCallAllocation* data) + : server_(server), + method_(nullptr), + has_request_payload_(false), + call_details_(new grpc_call_details), + cq_(cq), + tag_(this), + ctx_(server_->context_allocator() != nullptr + ? server_->context_allocator() + ->NewGenericCallbackServerContext() + : nullptr) { + CommonSetup(server, data); + grpc_call_details_init(call_details_); + data->details = call_details_; + } + + ~CallbackRequest() override { + delete call_details_; + grpc_metadata_array_destroy(&request_metadata_); + if (has_request_payload_ && request_payload_) { + grpc_byte_buffer_destroy(request_payload_); + } + if (ctx_alloc_by_default_ || server_->context_allocator() == nullptr) { + default_ctx_.Destroy(); + } + server_->UnrefWithPossibleNotify(); + } + + // Needs specialization to account for different processing of metadata + // in generic API + bool FinalizeResult(void** tag, bool* status) override; + + private: + // method_name needs to be specialized between named method and generic + const char* method_name() const; + + class CallbackCallTag : public grpc_completion_queue_functor { + public: + explicit CallbackCallTag(Server::CallbackRequest* req) + : req_(req) { + functor_run = &CallbackCallTag::StaticRun; + // Set inlineable to true since this callback is internally-controlled + // without taking any locks, and thus does not need to be run from the + // executor (which triggers a thread hop). This should only be used by + // internal callbacks like this and not by user application code. The work + // here is actually non-trivial, but there is no chance of having user + // locks conflict with each other so it's ok to run inlined. + inlineable = true; + } + + // force_run can not be performed on a tag if operations using this tag + // have been sent to PerformOpsOnCall. It is intended for error conditions + // that are detected before the operations are internally processed. + void force_run(bool ok) { Run(ok); } + + private: + Server::CallbackRequest* req_; + grpc::internal::Call* call_; + + static void StaticRun(grpc_completion_queue_functor* cb, int ok) { + static_cast(cb)->Run(static_cast(ok)); + } + void Run(bool ok) { + void* ignored = req_; + bool new_ok = ok; + GPR_ASSERT(!req_->FinalizeResult(&ignored, &new_ok)); + GPR_ASSERT(ignored == req_); + + if (!ok) { + // The call has been shutdown. + // Delete its contents to free up the request. + delete req_; + return; + } + + // Bind the call, deadline, and metadata from what we got + req_->ctx_->set_call(req_->call_); + req_->ctx_->cq_ = req_->cq_; + req_->ctx_->BindDeadlineAndMetadata(req_->deadline_, + &req_->request_metadata_); + req_->request_metadata_.count = 0; + + // Create a C++ Call to control the underlying core call + call_ = + new (grpc_call_arena_alloc(req_->call_, sizeof(grpc::internal::Call))) + grpc::internal::Call( + req_->call_, req_->server_, req_->cq_, + req_->server_->max_receive_message_size(), + req_->ctx_->set_server_rpc_info( + req_->method_name(), + (req_->method_ != nullptr) + ? 
req_->method_->method_type() + : grpc::internal::RpcMethod::BIDI_STREAMING, + req_->server_->interceptor_creators_)); + + req_->interceptor_methods_.SetCall(call_); + req_->interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + req_->interceptor_methods_.AddInterceptionHookPoint( + grpc::experimental::InterceptionHookPoints:: + POST_RECV_INITIAL_METADATA); + req_->interceptor_methods_.SetRecvInitialMetadata( + &req_->ctx_->client_metadata_); + + if (req_->has_request_payload_) { + // Set interception point for RECV MESSAGE + req_->request_ = req_->method_->handler()->Deserialize( + req_->call_, req_->request_payload_, &req_->request_status_, + &req_->handler_data_); + if (!(req_->request_status_.ok())) { + gpr_log(GPR_DEBUG, "Failed to deserialize message."); + } + req_->request_payload_ = nullptr; + req_->interceptor_methods_.AddInterceptionHookPoint( + grpc::experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + req_->interceptor_methods_.SetRecvMessage(req_->request_, nullptr); + } + + if (req_->interceptor_methods_.RunInterceptors( + [this] { ContinueRunAfterInterception(); })) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } + void ContinueRunAfterInterception() { + auto* handler = (req_->method_ != nullptr) + ? req_->method_->handler() + : req_->server_->generic_handler_.get(); + handler->RunHandler(grpc::internal::MethodHandler::HandlerParameter( + call_, req_->ctx_, req_->request_, req_->request_status_, + req_->handler_data_, [this] { delete req_; })); + } + }; + + template + void CommonSetup(Server* server, CallAllocation* data) { + server->Ref(); + grpc_metadata_array_init(&request_metadata_); + data->tag = static_cast(&tag_); + data->call = &call_; + data->initial_metadata = &request_metadata_; + if (ctx_ == nullptr) { + default_ctx_.Init(); + ctx_ = &*default_ctx_; + ctx_alloc_by_default_ = true; + } + ctx_->set_context_allocator(server->context_allocator()); + data->cq = cq_->cq(); + } + + Server* const server_; + grpc::internal::RpcServiceMethod* const method_; + const bool has_request_payload_; + grpc_byte_buffer* request_payload_ = nullptr; + void* request_ = nullptr; + void* handler_data_ = nullptr; + grpc::Status request_status_; + grpc_call_details* const call_details_ = nullptr; + grpc_call* call_; + gpr_timespec deadline_; + grpc_metadata_array request_metadata_; + grpc::CompletionQueue* const cq_; + bool ctx_alloc_by_default_ = false; + CallbackCallTag tag_; + ServerContextType* ctx_ = nullptr; + grpc_core::ManualConstructor default_ctx_; + grpc::internal::InterceptorBatchMethodsImpl interceptor_methods_; +}; + +template <> +bool Server::CallbackRequest::FinalizeResult( + void** /*tag*/, bool* /*status*/) { + return false; +} + +template <> +bool Server::CallbackRequest< + grpc::GenericCallbackServerContext>::FinalizeResult(void** /*tag*/, + bool* status) { + if (*status) { + deadline_ = call_details_->deadline; + // TODO(yangg) remove the copy here + ctx_->method_ = grpc::StringFromCopiedSlice(call_details_->method); + ctx_->host_ = grpc::StringFromCopiedSlice(call_details_->host); + } + grpc_slice_unref(call_details_->method); + grpc_slice_unref(call_details_->host); + return false; +} + +template <> +const char* Server::CallbackRequest::method_name() + const { + return method_->name(); +} + +template <> +const char* Server::CallbackRequest< + grpc::GenericCallbackServerContext>::method_name() const { + 
return ctx_->method().c_str(); +} + +// Implementation of ThreadManager. Each instance of SyncRequestThreadManager +// manages a pool of threads that poll for incoming Sync RPCs and call the +// appropriate RPC handlers +class Server::SyncRequestThreadManager : public grpc::ThreadManager { + public: + SyncRequestThreadManager(Server* server, grpc::CompletionQueue* server_cq, + std::shared_ptr global_callbacks, + grpc_resource_quota* rq, int min_pollers, + int max_pollers, int cq_timeout_msec) + : ThreadManager("SyncServer", rq, min_pollers, max_pollers), + server_(server), + server_cq_(server_cq), + cq_timeout_msec_(cq_timeout_msec), + global_callbacks_(std::move(global_callbacks)) {} + + WorkStatus PollForWork(void** tag, bool* ok) override { + *tag = nullptr; + // TODO(ctiller): workaround for GPR_TIMESPAN based deadlines not working + // right now + gpr_timespec deadline = + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(cq_timeout_msec_, GPR_TIMESPAN)); + + switch (server_cq_->AsyncNext(tag, ok, deadline)) { + case grpc::CompletionQueue::TIMEOUT: + return TIMEOUT; + case grpc::CompletionQueue::SHUTDOWN: + return SHUTDOWN; + case grpc::CompletionQueue::GOT_EVENT: + return WORK_FOUND; + } + + GPR_UNREACHABLE_CODE(return TIMEOUT); + } + + void DoWork(void* tag, bool ok, bool resources) override { + (void)ok; + SyncRequest* sync_req = static_cast(tag); + + // Under the AllocatingRequestMatcher model we will never see an invalid tag + // here. + GPR_DEBUG_ASSERT(sync_req != nullptr); + GPR_DEBUG_ASSERT(ok); + + GPR_TIMER_SCOPE("sync_req->Run()", 0); + sync_req->Run(global_callbacks_, resources); + } + + void AddSyncMethod(grpc::internal::RpcServiceMethod* method, void* tag) { + server_->server()->core_server->SetRegisteredMethodAllocator( + server_cq_->cq(), tag, [this, method] { + grpc_core::Server::RegisteredCallAllocation result; + new SyncRequest(server_, method, &result); + return result; + }); + has_sync_method_ = true; + } + + void AddUnknownSyncMethod() { + if (has_sync_method_) { + unknown_method_ = absl::make_unique( + "unknown", grpc::internal::RpcMethod::BIDI_STREAMING, + new grpc::internal::UnknownMethodHandler(kUnknownRpcMethod)); + server_->server()->core_server->SetBatchMethodAllocator( + server_cq_->cq(), [this] { + grpc_core::Server::BatchCallAllocation result; + new SyncRequest(server_, unknown_method_.get(), &result); + return result; + }); + } + } + + void Shutdown() override { + ThreadManager::Shutdown(); + server_cq_->Shutdown(); + } + + void Wait() override { + ThreadManager::Wait(); + // Drain any pending items from the queue + void* tag; + bool ok; + while (server_cq_->Next(&tag, &ok)) { + // This problem can arise if the server CQ gets a request queued to it + // before it gets shutdown but then pulls it after shutdown. 
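+ // Each tag drained here is a SyncRequest that was matched but never Run, so
+ // Cleanup() only shuts down its per-request CQ, unrefs the call, and deletes
+ // the request.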
+ static_cast(tag)->Cleanup(); + } + } + + void Start() { + if (has_sync_method_) { + Initialize(); // ThreadManager's Initialize() + } + } + + private: + Server* server_; + grpc::CompletionQueue* server_cq_; + int cq_timeout_msec_; + bool has_sync_method_ = false; + std::unique_ptr unknown_method_; + std::shared_ptr global_callbacks_; +}; + +static grpc::internal::GrpcLibraryInitializer g_gli_initializer; +Server::Server( + grpc::ChannelArguments* args, + std::shared_ptr>> + sync_server_cqs, + int min_pollers, int max_pollers, int sync_cq_timeout_msec, + std::vector> + acceptors, + grpc_server_config_fetcher* server_config_fetcher, + grpc_resource_quota* server_rq, + std::vector< + std::unique_ptr> + interceptor_creators) + : acceptors_(std::move(acceptors)), + interceptor_creators_(std::move(interceptor_creators)), + max_receive_message_size_(INT_MIN), + sync_server_cqs_(std::move(sync_server_cqs)), + started_(false), + shutdown_(false), + shutdown_notified_(false), + server_(nullptr), + server_initializer_(new ServerInitializer(this)), + health_check_service_disabled_(false) { + g_gli_initializer.summon(); + gpr_once_init(&grpc::g_once_init_callbacks, grpc::InitGlobalCallbacks); + global_callbacks_ = grpc::g_callbacks; + global_callbacks_->UpdateArguments(args); + + if (sync_server_cqs_ != nullptr) { + bool default_rq_created = false; + if (server_rq == nullptr) { + server_rq = grpc_resource_quota_create("SyncServer-default-rq"); + grpc_resource_quota_set_max_threads(server_rq, + DEFAULT_MAX_SYNC_SERVER_THREADS); + default_rq_created = true; + } + + for (const auto& it : *sync_server_cqs_) { + sync_req_mgrs_.emplace_back(new SyncRequestThreadManager( + this, it.get(), global_callbacks_, server_rq, min_pollers, + max_pollers, sync_cq_timeout_msec)); + } + + if (default_rq_created) { + grpc_resource_quota_unref(server_rq); + } + } + + for (auto& acceptor : acceptors_) { + acceptor->SetToChannelArgs(args); + } + + grpc_channel_args channel_args; + args->SetChannelArgs(&channel_args); + + for (size_t i = 0; i < channel_args.num_args; i++) { + if (0 == strcmp(channel_args.args[i].key, + grpc::kHealthCheckServiceInterfaceArg)) { + if (channel_args.args[i].value.pointer.p == nullptr) { + health_check_service_disabled_ = true; + } else { + health_check_service_.reset( + static_cast( + channel_args.args[i].value.pointer.p)); + } + } + if (0 == + strcmp(channel_args.args[i].key, GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH)) { + max_receive_message_size_ = channel_args.args[i].value.integer; + } + } + server_ = grpc_server_create(&channel_args, nullptr); + grpc_server_set_config_fetcher(server_, server_config_fetcher); +} + +Server::~Server() { + { + grpc::internal::ReleasableMutexLock lock(&mu_); + if (started_ && !shutdown_) { + lock.Release(); + Shutdown(); + } else if (!started_) { + // Shutdown the completion queues + for (const auto& value : sync_req_mgrs_) { + value->Shutdown(); + } + CompletionQueue* callback_cq = + callback_cq_.load(std::memory_order_relaxed); + if (callback_cq != nullptr) { + if (grpc_iomgr_run_in_background()) { + // gRPC-core provides the backing needed for the preferred CQ type + callback_cq->Shutdown(); + } else { + CompletionQueue::ReleaseCallbackAlternativeCQ(callback_cq); + } + callback_cq_.store(nullptr, std::memory_order_release); + } + } + } + // Destroy health check service before we destroy the C server so that + // it does not call grpc_server_request_registered_call() after the C + // server has been destroyed. 
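+ // reset() runs the health check service's destructor while server_ is still
+ // valid; only after that is the core server torn down.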
+ health_check_service_.reset(); + grpc_server_destroy(server_); +} + +void Server::SetGlobalCallbacks(GlobalCallbacks* callbacks) { + GPR_ASSERT(!grpc::g_callbacks); + GPR_ASSERT(callbacks); + grpc::g_callbacks.reset(callbacks); +} + +grpc_server* Server::c_server() { return server_; } + +std::shared_ptr Server::InProcessChannel( + const grpc::ChannelArguments& args) { + grpc_channel_args channel_args = args.c_channel_args(); + return grpc::CreateChannelInternal( + "inproc", grpc_inproc_channel_create(server_, &channel_args, nullptr), + std::vector>()); +} + +std::shared_ptr +Server::experimental_type::InProcessChannelWithInterceptors( + const grpc::ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { + grpc_channel_args channel_args = args.c_channel_args(); + return grpc::CreateChannelInternal( + "inproc", + grpc_inproc_channel_create(server_->server_, &channel_args, nullptr), + std::move(interceptor_creators)); +} + +static grpc_server_register_method_payload_handling PayloadHandlingForMethod( + grpc::internal::RpcServiceMethod* method) { + switch (method->method_type()) { + case grpc::internal::RpcMethod::NORMAL_RPC: + case grpc::internal::RpcMethod::SERVER_STREAMING: + return GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER; + case grpc::internal::RpcMethod::CLIENT_STREAMING: + case grpc::internal::RpcMethod::BIDI_STREAMING: + return GRPC_SRM_PAYLOAD_NONE; + } + GPR_UNREACHABLE_CODE(return GRPC_SRM_PAYLOAD_NONE;); +} + +bool Server::RegisterService(const std::string* addr, grpc::Service* service) { + bool has_async_methods = service->has_async_methods(); + if (has_async_methods) { + GPR_ASSERT(service->server_ == nullptr && + "Can only register an asynchronous service against one server."); + service->server_ = this; + } + + const char* method_name = nullptr; + + for (const auto& method : service->methods_) { + if (method == nullptr) { // Handled by generic service if any. + continue; + } + + void* method_registration_tag = grpc_server_register_method( + server_, method->name(), addr ? addr->c_str() : nullptr, + PayloadHandlingForMethod(method.get()), 0); + if (method_registration_tag == nullptr) { + gpr_log(GPR_DEBUG, "Attempt to register %s multiple times", + method->name()); + return false; + } + + if (method->handler() == nullptr) { // Async method without handler + method->set_server_tag(method_registration_tag); + } else if (method->api_type() == + grpc::internal::RpcServiceMethod::ApiType::SYNC) { + for (const auto& value : sync_req_mgrs_) { + value->AddSyncMethod(method.get(), method_registration_tag); + } + } else { + has_callback_methods_ = true; + grpc::internal::RpcServiceMethod* method_value = method.get(); + grpc::CompletionQueue* cq = CallbackCQ(); + server_->core_server->SetRegisteredMethodAllocator( + cq->cq(), method_registration_tag, [this, cq, method_value] { + grpc_core::Server::RegisteredCallAllocation result; + new CallbackRequest(this, method_value, + cq, &result); + return result; + }); + } + + method_name = method->name(); + } + + // Parse service name. 
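+ // Registered method names look like "/package.Service/Method"; the two
+ // getline() calls below skip the empty segment before the leading '/' and
+ // then capture "package.Service".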
+ if (method_name != nullptr) { + std::stringstream ss(method_name); + std::string service_name; + if (std::getline(ss, service_name, '/') && + std::getline(ss, service_name, '/')) { + services_.push_back(service_name); + } + } + return true; +} + +void Server::RegisterAsyncGenericService(grpc::AsyncGenericService* service) { + GPR_ASSERT(service->server_ == nullptr && + "Can only register an async generic service against one server."); + service->server_ = this; + has_async_generic_service_ = true; +} + +void Server::RegisterCallbackGenericService( + grpc::CallbackGenericService* service) { + GPR_ASSERT( + service->server_ == nullptr && + "Can only register a callback generic service against one server."); + service->server_ = this; + has_callback_generic_service_ = true; + generic_handler_.reset(service->Handler()); + + grpc::CompletionQueue* cq = CallbackCQ(); + server_->core_server->SetBatchMethodAllocator(cq->cq(), [this, cq] { + grpc_core::Server::BatchCallAllocation result; + new CallbackRequest(this, cq, &result); + return result; + }); +} + +int Server::AddListeningPort(const std::string& addr, + grpc::ServerCredentials* creds) { + GPR_ASSERT(!started_); + int port = creds->AddPortToServer(addr, server_); + global_callbacks_->AddPort(this, addr, creds, port); + return port; +} + +void Server::Ref() { + shutdown_refs_outstanding_.fetch_add(1, std::memory_order_relaxed); +} + +void Server::UnrefWithPossibleNotify() { + if (GPR_UNLIKELY(shutdown_refs_outstanding_.fetch_sub( + 1, std::memory_order_acq_rel) == 1)) { + // No refs outstanding means that shutdown has been initiated and no more + // callback requests are outstanding. + grpc::internal::MutexLock lock(&mu_); + GPR_ASSERT(shutdown_); + shutdown_done_ = true; + shutdown_done_cv_.Signal(); + } +} + +void Server::UnrefAndWaitLocked() { + if (GPR_UNLIKELY(shutdown_refs_outstanding_.fetch_sub( + 1, std::memory_order_acq_rel) == 1)) { + shutdown_done_ = true; + return; // no need to wait on CV since done condition already set + } + while (!shutdown_done_) { + shutdown_done_cv_.Wait(&mu_); + } +} + +void Server::Start(grpc::ServerCompletionQueue** cqs, size_t num_cqs) { + GPR_ASSERT(!started_); + global_callbacks_->PreServerStart(this); + started_ = true; + + // Only create default health check service when user did not provide an + // explicit one. + grpc::ServerCompletionQueue* health_check_cq = nullptr; + grpc::DefaultHealthCheckService::HealthCheckServiceImpl* + default_health_check_service_impl = nullptr; + if (health_check_service_ == nullptr && !health_check_service_disabled_ && + grpc::DefaultHealthCheckServiceEnabled()) { + auto* default_hc_service = new grpc::DefaultHealthCheckService; + health_check_service_.reset(default_hc_service); + // We create a non-polling CQ to avoid impacting application + // performance. This ensures that we don't introduce thread hops + // for application requests that wind up on this CQ, which is polled + // in its own thread. 
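+ // For reference (illustrative names), applications opt in before building
+ // the server and can then flip per-service status at run time:
+ //   grpc::EnableDefaultHealthCheckService(true);
+ //   auto server = builder.BuildAndStart();
+ //   server->GetHealthCheckService()->SetServingStatus("my.Service", true);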
+ health_check_cq = new grpc::ServerCompletionQueue( + GRPC_CQ_NEXT, GRPC_CQ_NON_POLLING, nullptr); + grpc_server_register_completion_queue(server_, health_check_cq->cq(), + nullptr); + default_health_check_service_impl = + default_hc_service->GetHealthCheckService( + std::unique_ptr(health_check_cq)); + RegisterService(nullptr, default_health_check_service_impl); + } + + for (auto& acceptor : acceptors_) { + acceptor->GetCredentials()->AddPortToServer(acceptor->name(), server_); + } + + // If this server uses callback methods, then create a callback generic + // service to handle any unimplemented methods using the default reactor + // creator + if (has_callback_methods_ && !has_callback_generic_service_) { + unimplemented_service_ = absl::make_unique(); + RegisterCallbackGenericService(unimplemented_service_.get()); + } + +#ifndef NDEBUG + for (size_t i = 0; i < num_cqs; i++) { + cq_list_.push_back(cqs[i]); + } +#endif + + // If we have a generic service, all unmatched method names go there. + // Otherwise, we must provide at least one RPC request for an "unimplemented" + // RPC, which covers any RPC for a method name that isn't matched. If we + // have a sync service, let it be a sync unimplemented RPC, which must be + // registered before server start (to initialize an AllocatingRequestMatcher). + // If we have an AllocatingRequestMatcher, we can't also specify other + // unimplemented RPCs via explicit async requests, so we won't do so. If we + // only have async services, we can specify unimplemented RPCs on each async + // CQ so that some user polling thread will move them along as long as some + // progress is being made on any RPCs in the system. + bool unknown_rpc_needed = + !has_async_generic_service_ && !has_callback_generic_service_; + + if (unknown_rpc_needed && !sync_req_mgrs_.empty()) { + sync_req_mgrs_[0]->AddUnknownSyncMethod(); + unknown_rpc_needed = false; + } + + grpc_server_start(server_); + + if (unknown_rpc_needed) { + for (size_t i = 0; i < num_cqs; i++) { + if (cqs[i]->IsFrequentlyPolled()) { + new UnimplementedAsyncRequest(this, cqs[i]); + } + } + if (health_check_cq != nullptr) { + new UnimplementedAsyncRequest(this, health_check_cq); + } + unknown_rpc_needed = false; + } + + // If this server has any support for synchronous methods (has any sync + // server CQs), make sure that we have a ResourceExhausted handler + // to deal with the case of thread exhaustion + if (sync_server_cqs_ != nullptr && !sync_server_cqs_->empty()) { + resource_exhausted_handler_ = + absl::make_unique( + kServerThreadpoolExhausted); + } + + for (const auto& value : sync_req_mgrs_) { + value->Start(); + } + + if (default_health_check_service_impl != nullptr) { + default_health_check_service_impl->StartServingThread(); + } + + for (auto& acceptor : acceptors_) { + acceptor->Start(); + } +} + +void Server::ShutdownInternal(gpr_timespec deadline) { + grpc::internal::MutexLock lock(&mu_); + if (shutdown_) { + return; + } + + shutdown_ = true; + + for (auto& acceptor : acceptors_) { + acceptor->Shutdown(); + } + + /// The completion queue to use for server shutdown completion notification + grpc::CompletionQueue shutdown_cq; + grpc::ShutdownTag shutdown_tag; // Phony shutdown tag + grpc_server_shutdown_and_notify(server_, shutdown_cq.cq(), &shutdown_tag); + + shutdown_cq.Shutdown(); + + void* tag; + bool ok; + grpc::CompletionQueue::NextStatus status = + shutdown_cq.AsyncNext(&tag, &ok, deadline); + + // If this timed out, it means we are done with the grace period for a clean + // shutdown. 
We should force a shutdown now by cancelling all inflight calls + if (status == grpc::CompletionQueue::NextStatus::TIMEOUT) { + grpc_server_cancel_all_calls(server_); + } + // Else in case of SHUTDOWN or GOT_EVENT, it means that the server has + // successfully shutdown + + // Drop the shutdown ref and wait for all other refs to drop as well. + UnrefAndWaitLocked(); + + // Shutdown all ThreadManagers. This will try to gracefully stop all the + // threads in the ThreadManagers (once they process any inflight requests) + for (const auto& value : sync_req_mgrs_) { + value->Shutdown(); // ThreadManager's Shutdown() + } + + // Wait for threads in all ThreadManagers to terminate + for (const auto& value : sync_req_mgrs_) { + value->Wait(); + } + + // Shutdown the callback CQ. The CQ is owned by its own shutdown tag, so it + // will delete itself at true shutdown. + CompletionQueue* callback_cq = callback_cq_.load(std::memory_order_relaxed); + if (callback_cq != nullptr) { + if (grpc_iomgr_run_in_background()) { + // gRPC-core provides the backing needed for the preferred CQ type + callback_cq->Shutdown(); + } else { + CompletionQueue::ReleaseCallbackAlternativeCQ(callback_cq); + } + callback_cq_.store(nullptr, std::memory_order_release); + } + + // Drain the shutdown queue (if the previous call to AsyncNext() timed out + // and we didn't remove the tag from the queue yet) + while (shutdown_cq.Next(&tag, &ok)) { + // Nothing to be done here. Just ignore ok and tag values + } + + shutdown_notified_ = true; + shutdown_cv_.SignalAll(); + +#ifndef NDEBUG + // Unregister this server with the CQs passed into it by the user so that + // those can be checked for properly-ordered shutdown. + for (auto* cq : cq_list_) { + cq->UnregisterServer(this); + } + cq_list_.clear(); +#endif +} + +void Server::Wait() { + grpc::internal::MutexLock lock(&mu_); + while (started_ && !shutdown_notified_) { + shutdown_cv_.Wait(&mu_); + } +} + +void Server::PerformOpsOnCall(grpc::internal::CallOpSetInterface* ops, + grpc::internal::Call* call) { + ops->FillOps(call); +} + +bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, + bool* status) { + if (GenericAsyncRequest::FinalizeResult(tag, status)) { + // We either had no interceptors run or we are done intercepting + if (*status) { + // Create a new request/response pair using the server and CQ values + // stored in this object's base class. + new UnimplementedAsyncRequest(server_, notification_cq_); + new UnimplementedAsyncResponse(this); + } else { + delete this; + } + } else { + // The tag was swallowed due to interception. We will see it again. + } + return false; +} + +Server::UnimplementedAsyncResponse::UnimplementedAsyncResponse( + UnimplementedAsyncRequest* request) + : request_(request) { + grpc::Status status(grpc::StatusCode::UNIMPLEMENTED, kUnknownRpcMethod); + grpc::internal::UnknownMethodHandler::FillOps(request_->context(), + kUnknownRpcMethod, this); + request_->stream()->call_.PerformOps(this); +} + +grpc::ServerInitializer* Server::initializer() { + return server_initializer_.get(); +} + +grpc::CompletionQueue* Server::CallbackCQ() { + // TODO(vjpai): Consider using a single global CQ for the default CQ + // if there is no explicit per-server CQ registered + CompletionQueue* callback_cq = callback_cq_.load(std::memory_order_acquire); + if (callback_cq != nullptr) { + return callback_cq; + } + // The callback_cq_ wasn't already set, so grab a lock and set it up exactly + // once for this server. 
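+ // Classic double-checked locking: the relaxed re-load below is safe because
+ // it happens under mu_, and the release store at the end pairs with the
+ // acquire load above.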
+ grpc::internal::MutexLock l(&mu_); + callback_cq = callback_cq_.load(std::memory_order_relaxed); + if (callback_cq != nullptr) { + return callback_cq; + } + if (grpc_iomgr_run_in_background()) { + // gRPC-core provides the backing needed for the preferred CQ type + auto* shutdown_callback = new grpc::ShutdownCallback; + callback_cq = new grpc::CompletionQueue(grpc_completion_queue_attributes{ + GRPC_CQ_CURRENT_VERSION, GRPC_CQ_CALLBACK, GRPC_CQ_DEFAULT_POLLING, + shutdown_callback}); + + // Transfer ownership of the new cq to its own shutdown callback + shutdown_callback->TakeCQ(callback_cq); + } else { + // Otherwise we need to use the alternative CQ variant + callback_cq = CompletionQueue::CallbackAlternativeCQ(); + } + + callback_cq_.store(callback_cq, std::memory_order_release); + return callback_cq; +} + +} // namespace grpc diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc new file mode 100644 index 00000000..b6c593c7 --- /dev/null +++ b/src/cpp/server/server_context.cc @@ -0,0 +1,379 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/surface/call.h" + +namespace grpc { + +static internal::GrpcLibraryInitializer g_gli_initializer; + +// CompletionOp + +class ServerContextBase::CompletionOp final + : public internal::CallOpSetInterface { + public: + // initial refs: one in the server context, one in the cq + // must ref the call before calling constructor and after deleting this + CompletionOp(internal::Call* call, + ::grpc::internal::ServerCallbackCall* callback_controller) + : call_(*call), + callback_controller_(callback_controller), + has_tag_(false), + tag_(nullptr), + core_cq_tag_(this), + refs_(2), + finalized_(false), + cancelled_(0), + done_intercepting_(false) {} + + // CompletionOp isn't copyable or movable + CompletionOp(const CompletionOp&) = delete; + CompletionOp& operator=(const CompletionOp&) = delete; + CompletionOp(CompletionOp&&) = delete; + CompletionOp& operator=(CompletionOp&&) = delete; + + ~CompletionOp() override { + if (call_.server_rpc_info()) { + call_.server_rpc_info()->Unref(); + } + } + + void FillOps(internal::Call* call) override; + + // This should always be arena allocated in the call, so override delete. + // But this class is not trivially destructible, so must actually call delete + // before allowing the arena to be freed + static void operator delete(void* /*ptr*/, std::size_t size) { + // Use size to avoid unused-parameter warning since assert seems to be + // compiled out and treated as unused in some gcc optimized versions. + (void)size; + assert(size == sizeof(CompletionOp)); + } + + // This operator should never be called as the memory should be freed as part + // of the arena destruction. 
It only exists to provide a matching operator + // delete to the operator new so that some compilers will not complain (see + // https://github.com/grpc/grpc/issues/11301) Note at the time of adding this + // there are no tests catching the compiler warning. + static void operator delete(void*, void*) { assert(0); } + + bool FinalizeResult(void** tag, bool* status) override; + + bool CheckCancelled(CompletionQueue* cq) { + cq->TryPluck(this); + return CheckCancelledNoPluck(); + } + bool CheckCancelledAsync() { return CheckCancelledNoPluck(); } + + void set_tag(void* tag) { + has_tag_ = true; + tag_ = tag; + } + + void set_core_cq_tag(void* core_cq_tag) { core_cq_tag_ = core_cq_tag; } + + void* core_cq_tag() override { return core_cq_tag_; } + + void Unref(); + + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + /* Servers don't allow hijacking */ + GPR_ASSERT(false); + } + + /* Should be called after interceptors are done running */ + void ContinueFillOpsAfterInterception() override {} + + /* Should be called after interceptors are done running on the finalize result + * path */ + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + if (!has_tag_) { + // We don't have a tag to return. + Unref(); + // Unref can delete this, so do not access anything from this afterward. + return; + } + /* Start a phony op so that we can return the tag */ + GPR_ASSERT(grpc_call_start_batch(call_.call(), nullptr, 0, core_cq_tag_, + nullptr) == GRPC_CALL_OK); + } + + private: + bool CheckCancelledNoPluck() { + grpc_core::MutexLock lock(&mu_); + return finalized_ ? (cancelled_ != 0) : false; + } + + internal::Call call_; + ::grpc::internal::ServerCallbackCall* const callback_controller_; + bool has_tag_; + void* tag_; + void* core_cq_tag_; + grpc_core::RefCount refs_; + grpc_core::Mutex mu_; + bool finalized_; + int cancelled_; // This is an int (not bool) because it is passed to core + bool done_intercepting_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; +}; + +void ServerContextBase::CompletionOp::Unref() { + if (refs_.Unref()) { + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } +} + +void ServerContextBase::CompletionOp::FillOps(internal::Call* call) { + grpc_op ops; + ops.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + ops.data.recv_close_on_server.cancelled = &cancelled_; + ops.flags = 0; + ops.reserved = nullptr; + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + interceptor_methods_.SetCallOpSetInterface(this); + // The following call_start_batch is internally-generated so no need for an + // explanatory log on failure. + GPR_ASSERT(grpc_call_start_batch(call->call(), &ops, 1, core_cq_tag_, + nullptr) == GRPC_CALL_OK); + /* No interceptors to run here */ +} + +bool ServerContextBase::CompletionOp::FinalizeResult(void** tag, bool* status) { + // Decide whether to do the unref or call the cancel callback within the lock + bool do_unref = false; + bool has_tag = false; + bool call_cancel = false; + + { + grpc_core::MutexLock lock(&mu_); + if (done_intercepting_) { + // We are done intercepting. + has_tag = has_tag_; + if (has_tag) { + *tag = tag_; + } + // Release the lock before unreffing as Unref may delete this object + do_unref = true; + } else { + finalized_ = true; + + // If for some reason the incoming status is false, mark that as a + // cancellation. 
+ // TODO(vjpai): does this ever happen? + if (!*status) { + cancelled_ = 1; + } + + call_cancel = (cancelled_ != 0); + // Release the lock since we may call a callback and interceptors. + } + } + + if (do_unref) { + Unref(); + // Unref can delete this, so do not access anything from this afterward. + return has_tag; + } + if (call_cancel && callback_controller_ != nullptr) { + callback_controller_->MaybeCallOnCancel(); + } + /* Add interception point and run through interceptors */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE); + if (interceptor_methods_.RunInterceptors()) { + // No interceptors were run + bool has_tag = has_tag_; + if (has_tag) { + *tag = tag_; + } + Unref(); + // Unref can delete this, so do not access anything from this afterward. + return has_tag; + } + // There are interceptors to be run. Return false for now. + return false; +} + +// ServerContextBase body + +ServerContextBase::ServerContextBase() + : deadline_(gpr_inf_future(GPR_CLOCK_REALTIME)) { + g_gli_initializer.summon(); +} + +ServerContextBase::ServerContextBase(gpr_timespec deadline, + grpc_metadata_array* arr) + : deadline_(deadline) { + std::swap(*client_metadata_.arr(), *arr); +} + +void ServerContextBase::BindDeadlineAndMetadata(gpr_timespec deadline, + grpc_metadata_array* arr) { + deadline_ = deadline; + std::swap(*client_metadata_.arr(), *arr); +} + +ServerContextBase::~ServerContextBase() { + if (completion_op_) { + completion_op_->Unref(); + // Unref can delete completion_op_, so do not access it afterward. + } + if (rpc_info_) { + rpc_info_->Unref(); + } + if (default_reactor_used_.load(std::memory_order_relaxed)) { + reinterpret_cast(&default_reactor_)->~Reactor(); + } +} + +ServerContextBase::CallWrapper::~CallWrapper() { + if (call) { + // If the ServerContext is part of the call's arena, this could free the + // object itself. 
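+ // Consequently the unref below must be the last thing done here, since it
+ // may free the enclosing ServerContextBase along with this wrapper.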
+ grpc_call_unref(call); + } +} + +void ServerContextBase::BeginCompletionOp( + internal::Call* call, std::function callback, + ::grpc::internal::ServerCallbackCall* callback_controller) { + GPR_ASSERT(!completion_op_); + if (rpc_info_) { + rpc_info_->Ref(); + } + grpc_call_ref(call->call()); + completion_op_ = + new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp))) + CompletionOp(call, callback_controller); + if (callback_controller != nullptr) { + completion_tag_.Set(call->call(), std::move(callback), completion_op_, + true); + completion_op_->set_core_cq_tag(&completion_tag_); + completion_op_->set_tag(completion_op_); + } else if (has_notify_when_done_tag_) { + completion_op_->set_tag(async_notify_when_done_tag_); + } + call->PerformOps(completion_op_); +} + +internal::CompletionQueueTag* ServerContextBase::GetCompletionOpTag() { + return static_cast(completion_op_); +} + +void ServerContextBase::AddInitialMetadata(const std::string& key, + const std::string& value) { + initial_metadata_.insert(std::make_pair(key, value)); +} + +void ServerContextBase::AddTrailingMetadata(const std::string& key, + const std::string& value) { + trailing_metadata_.insert(std::make_pair(key, value)); +} + +void ServerContextBase::TryCancel() const { + internal::CancelInterceptorBatchMethods cancel_methods; + if (rpc_info_) { + for (size_t i = 0; i < rpc_info_->interceptors_.size(); i++) { + rpc_info_->RunInterceptor(&cancel_methods, i); + } + } + grpc_call_error err = + grpc_call_cancel_with_status(call_.call, GRPC_STATUS_CANCELLED, + "Cancelled on the server side", nullptr); + if (err != GRPC_CALL_OK) { + gpr_log(GPR_ERROR, "TryCancel failed with: %d", err); + } +} + +bool ServerContextBase::IsCancelled() const { + if (completion_tag_) { + // When using callback API, this result is always valid. + return marked_cancelled_.load(std::memory_order_acquire) || + completion_op_->CheckCancelledAsync(); + } else if (has_notify_when_done_tag_) { + // When using async API, the result is only valid + // if the tag has already been delivered at the completion queue + return completion_op_ && completion_op_->CheckCancelledAsync(); + } else { + // when using sync API, the result is always valid + return marked_cancelled_.load(std::memory_order_acquire) || + (completion_op_ && completion_op_->CheckCancelled(cq_)); + } +} + +void ServerContextBase::set_compression_algorithm( + grpc_compression_algorithm algorithm) { + compression_algorithm_ = algorithm; + const char* algorithm_name = nullptr; + if (!grpc_compression_algorithm_name(algorithm, &algorithm_name)) { + gpr_log(GPR_ERROR, "Name for compression algorithm '%d' unknown.", + algorithm); + abort(); + } + GPR_ASSERT(algorithm_name != nullptr); + AddInitialMetadata(GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, algorithm_name); +} + +std::string ServerContextBase::peer() const { + std::string peer; + if (call_.call) { + char* c_peer = grpc_call_get_peer(call_.call); + peer = c_peer; + gpr_free(c_peer); + } + return peer; +} + +const struct census_context* ServerContextBase::census_context() const { + return call_.call == nullptr ? 
nullptr + : grpc_census_call_get_context(call_.call); +} + +void ServerContextBase::SetLoadReportingCosts( + const std::vector& cost_data) { + if (call_.call == nullptr) return; + for (const auto& cost_datum : cost_data) { + AddTrailingMetadata(GRPC_LB_COST_MD_KEY, cost_datum); + } +} + +} // namespace grpc diff --git a/src/cpp/server/server_credentials.cc b/src/cpp/server/server_credentials.cc new file mode 100644 index 00000000..454e8b4e --- /dev/null +++ b/src/cpp/server/server_credentials.cc @@ -0,0 +1,29 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +namespace grpc { + +static internal::GrpcLibraryInitializer g_gli_initializer; +ServerCredentials::ServerCredentials() { g_gli_initializer.summon(); } + +ServerCredentials::~ServerCredentials() {} + +} // namespace grpc diff --git a/src/cpp/server/server_posix.cc b/src/cpp/server/server_posix.cc new file mode 100644 index 00000000..f2452cc3 --- /dev/null +++ b/src/cpp/server/server_posix.cc @@ -0,0 +1,32 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +namespace grpc { + +#ifdef GPR_SUPPORT_CHANNELS_FROM_FD + +void AddInsecureChannelFromFd(grpc::Server* server, int fd) { + grpc_server_add_insecure_channel_from_fd(server->c_server(), nullptr, fd); +} + +#endif // GPR_SUPPORT_CHANNELS_FROM_FD + +} // namespace grpc diff --git a/src/cpp/server/xds_server_credentials.cc b/src/cpp/server/xds_server_credentials.cc new file mode 100644 index 00000000..f1842389 --- /dev/null +++ b/src/cpp/server/xds_server_credentials.cc @@ -0,0 +1,47 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+//
+
+#include "src/cpp/server/secure_server_credentials.h"
+
+namespace grpc {
+
+std::shared_ptr<ServerCredentials> XdsServerCredentials(
+    const std::shared_ptr<ServerCredentials>& fallback_credentials) {
+  GPR_ASSERT(fallback_credentials != nullptr);
+  if (fallback_credentials->IsInsecure()) {
+    grpc_server_credentials* insecure_creds =
+        grpc_insecure_server_credentials_create();
+    auto xds_creds = std::make_shared<SecureServerCredentials>(
+        grpc_xds_server_credentials_create(insecure_creds));
+    grpc_server_credentials_release(insecure_creds);
+    return xds_creds;
+  }
+  return std::make_shared<SecureServerCredentials>(
+      grpc_xds_server_credentials_create(
+          fallback_credentials->AsSecureServerCredentials()->c_creds()));
+}
+
+namespace experimental {
+
+std::shared_ptr<ServerCredentials> XdsServerCredentials(
+    const std::shared_ptr<ServerCredentials>& fallback_credentials) {
+  return grpc::XdsServerCredentials(fallback_credentials);
+}
+
+} // namespace experimental
+} // namespace grpc
diff --git a/src/cpp/thread_manager/thread_manager.cc b/src/cpp/thread_manager/thread_manager.cc
new file mode 100644
index 00000000..d257ae8d
--- /dev/null
+++ b/src/cpp/thread_manager/thread_manager.cc
@@ -0,0 +1,267 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "src/cpp/thread_manager/thread_manager.h"
+
+#include <climits>
+
+#include <grpc/support/log.h>
+
+#include "src/core/lib/gprpp/thd.h"
+#include "src/core/lib/iomgr/exec_ctx.h"
+
+namespace grpc {
+
+ThreadManager::WorkerThread::WorkerThread(ThreadManager* thd_mgr)
+    : thd_mgr_(thd_mgr) {
+  // Make thread creation exclusive with respect to its join happening in
+  // ~WorkerThread().
+  thd_ = grpc_core::Thread(
+      "grpcpp_sync_server",
+      [](void* th) { static_cast<ThreadManager::WorkerThread*>(th)->Run(); },
+      this, &created_);
+  if (!created_) {
+    gpr_log(GPR_ERROR, "Could not create grpc_sync_server worker-thread");
+  }
+}
+
+void ThreadManager::WorkerThread::Run() {
+  thd_mgr_->MainWorkLoop();
+  thd_mgr_->MarkAsCompleted(this);
+}
+
+ThreadManager::WorkerThread::~WorkerThread() {
+  // Don't join until the thread is fully constructed.
+  thd_.Join();
+}
+
+ThreadManager::ThreadManager(const char* name,
+                             grpc_resource_quota* resource_quota,
+                             int min_pollers, int max_pollers)
+    : shutdown_(false),
+      num_pollers_(0),
+      min_pollers_(min_pollers),
+      max_pollers_(max_pollers == -1 ? INT_MAX : max_pollers),
+      num_threads_(0),
+      max_active_threads_sofar_(0) {
+  resource_user_ =
+      grpc_resource_user_create(resource_quota, name != nullptr ?
name : ""); +} + +ThreadManager::~ThreadManager() { + { + grpc_core::MutexLock lock(&mu_); + GPR_ASSERT(num_threads_ == 0); + } + + grpc_core::ExecCtx exec_ctx; // grpc_resource_user_unref needs an exec_ctx + grpc_resource_user_unref(resource_user_); + CleanupCompletedThreads(); +} + +void ThreadManager::Wait() { + grpc_core::MutexLock lock(&mu_); + while (num_threads_ != 0) { + shutdown_cv_.Wait(&mu_); + } +} + +void ThreadManager::Shutdown() { + grpc_core::MutexLock lock(&mu_); + shutdown_ = true; +} + +bool ThreadManager::IsShutdown() { + grpc_core::MutexLock lock(&mu_); + return shutdown_; +} + +int ThreadManager::GetMaxActiveThreadsSoFar() { + grpc_core::MutexLock list_lock(&list_mu_); + return max_active_threads_sofar_; +} + +void ThreadManager::MarkAsCompleted(WorkerThread* thd) { + { + grpc_core::MutexLock list_lock(&list_mu_); + completed_threads_.push_back(thd); + } + + { + grpc_core::MutexLock lock(&mu_); + num_threads_--; + if (num_threads_ == 0) { + shutdown_cv_.Signal(); + } + } + + // Give a thread back to the resource quota + grpc_resource_user_free_threads(resource_user_, 1); +} + +void ThreadManager::CleanupCompletedThreads() { + std::list completed_threads; + { + // swap out the completed threads list: allows other threads to clean up + // more quickly + grpc_core::MutexLock lock(&list_mu_); + completed_threads.swap(completed_threads_); + } + for (auto thd : completed_threads) delete thd; +} + +void ThreadManager::Initialize() { + if (!grpc_resource_user_allocate_threads(resource_user_, min_pollers_)) { + gpr_log(GPR_ERROR, + "No thread quota available to even create the minimum required " + "polling threads (i.e %d). Unable to start the thread manager", + min_pollers_); + abort(); + } + + { + grpc_core::MutexLock lock(&mu_); + num_pollers_ = min_pollers_; + num_threads_ = min_pollers_; + max_active_threads_sofar_ = min_pollers_; + } + + for (int i = 0; i < min_pollers_; i++) { + WorkerThread* worker = new WorkerThread(this); + GPR_ASSERT(worker->created()); // Must be able to create the minimum + worker->Start(); + } +} + +void ThreadManager::MainWorkLoop() { + while (true) { + void* tag; + bool ok; + WorkStatus work_status = PollForWork(&tag, &ok); + + grpc_core::LockableAndReleasableMutexLock lock(&mu_); + // Reduce the number of pollers by 1 and check what happened with the poll + num_pollers_--; + bool done = false; + switch (work_status) { + case TIMEOUT: + // If we timed out and we have more pollers than we need (or we are + // shutdown), finish this thread + if (shutdown_ || num_pollers_ > max_pollers_) done = true; + break; + case SHUTDOWN: + // If the thread manager is shutdown, finish this thread + done = true; + break; + case WORK_FOUND: + // If we got work and there are now insufficient pollers and there is + // quota available to create a new thread, start a new poller thread + bool resource_exhausted = false; + if (!shutdown_ && num_pollers_ < min_pollers_) { + if (grpc_resource_user_allocate_threads(resource_user_, 1)) { + // We can allocate a new poller thread + num_pollers_++; + num_threads_++; + if (num_threads_ > max_active_threads_sofar_) { + max_active_threads_sofar_ = num_threads_; + } + // Drop lock before spawning thread to avoid contention + lock.Release(); + WorkerThread* worker = new WorkerThread(this); + if (worker->created()) { + worker->Start(); + } else { + // Get lock again to undo changes to poller/thread counters. 
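+                // The WorkerThread object exists but its underlying thread never
+                // started, so re-acquire mu_, undo the num_pollers_/num_threads_
+                // increments made above, and record resource exhaustion so that
+                // DoWork() below is told no thread could be spared for this work.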
+ grpc_core::MutexLock failure_lock(&mu_); + num_pollers_--; + num_threads_--; + resource_exhausted = true; + delete worker; + } + } else if (num_pollers_ > 0) { + // There is still at least some thread polling, so we can go on + // even though we are below the number of pollers that we would + // like to have (min_pollers_) + lock.Release(); + } else { + // There are no pollers to spare and we couldn't allocate + // a new thread, so resources are exhausted! + lock.Release(); + resource_exhausted = true; + } + } else { + // There are a sufficient number of pollers available so we can do + // the work and continue polling with our existing poller threads + lock.Release(); + } + // Lock is always released at this point - do the application work + // or return resource exhausted if there is new work but we couldn't + // get a thread in which to do it. + DoWork(tag, ok, !resource_exhausted); + // Take the lock again to check post conditions + lock.Lock(); + // If we're shutdown, we should finish at this point. + if (shutdown_) done = true; + break; + } + // If we decided to finish the thread, break out of the while loop + if (done) break; + + // Otherwise go back to polling as long as it doesn't exceed max_pollers_ + // + // **WARNING**: + // There is a possibility of threads thrashing here (i.e excessive thread + // shutdowns and creations than the ideal case). This happens if max_poller_ + // count is small and the rate of incoming requests is also small. In such + // scenarios we can possibly configure max_pollers_ to a higher value and/or + // increase the cq timeout. + // + // However, not doing this check here and unconditionally incrementing + // num_pollers (and hoping that the system will eventually settle down) has + // far worse consequences i.e huge number of threads getting created to the + // point of thread-exhaustion. For example: if the incoming request rate is + // very high, all the polling threads will return very quickly from + // PollForWork() with WORK_FOUND. They all briefly decrement num_pollers_ + // counter thereby possibly - and briefly - making it go below min_pollers; + // This will most likely result in the creation of a new poller since + // num_pollers_ dipped below min_pollers_. + // + // Now, If we didn't do the max_poller_ check here, all these threads will + // go back to doing PollForWork() and the whole cycle repeats (with a new + // thread being added in each cycle). Once the total number of threads in + // the system crosses a certain threshold (around ~1500), there is heavy + // contention on mutexes (the mu_ here or the mutexes in gRPC core like the + // pollset mutex) that makes DoWork() take longer to finish thereby causing + // new poller threads to be created even faster. This results in a thread + // avalanche. + if (num_pollers_ < max_pollers_) { + num_pollers_++; + } else { + break; + } + }; + + // This thread is exiting. Do some cleanup work i.e delete already completed + // worker threads + CleanupCompletedThreads(); + + // If we are here, either ThreadManager is shutting down or it already has + // enough threads. +} + +} // namespace grpc diff --git a/src/cpp/util/byte_buffer_cc.cc b/src/cpp/util/byte_buffer_cc.cc new file mode 100644 index 00000000..5c6f22bd --- /dev/null +++ b/src/cpp/util/byte_buffer_cc.cc @@ -0,0 +1,77 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +namespace grpc { + +static internal::GrpcLibraryInitializer g_gli_initializer; + +Status ByteBuffer::TrySingleSlice(Slice* slice) const { + if (!buffer_) { + return Status(StatusCode::FAILED_PRECONDITION, "Buffer not initialized"); + } + if ((buffer_->type == GRPC_BB_RAW) && + (buffer_->data.raw.compression == GRPC_COMPRESS_NONE) && + (buffer_->data.raw.slice_buffer.count == 1)) { + grpc_slice internal_slice = buffer_->data.raw.slice_buffer.slices[0]; + *slice = Slice(internal_slice, Slice::ADD_REF); + return Status::OK; + } else { + return Status(StatusCode::FAILED_PRECONDITION, + "Buffer isn't made up of a single uncompressed slice."); + } +} + +Status ByteBuffer::DumpToSingleSlice(Slice* slice) const { + if (!buffer_) { + return Status(StatusCode::FAILED_PRECONDITION, "Buffer not initialized"); + } + grpc_byte_buffer_reader reader; + if (!grpc_byte_buffer_reader_init(&reader, buffer_)) { + return Status(StatusCode::INTERNAL, + "Couldn't initialize byte buffer reader"); + } + grpc_slice s = grpc_byte_buffer_reader_readall(&reader); + *slice = Slice(s, Slice::STEAL_REF); + grpc_byte_buffer_reader_destroy(&reader); + return Status::OK; +} + +Status ByteBuffer::Dump(std::vector* slices) const { + slices->clear(); + if (!buffer_) { + return Status(StatusCode::FAILED_PRECONDITION, "Buffer not initialized"); + } + grpc_byte_buffer_reader reader; + if (!grpc_byte_buffer_reader_init(&reader, buffer_)) { + return Status(StatusCode::INTERNAL, + "Couldn't initialize byte buffer reader"); + } + grpc_slice s; + while (grpc_byte_buffer_reader_next(&reader, &s)) { + slices->push_back(Slice(s, Slice::STEAL_REF)); + } + grpc_byte_buffer_reader_destroy(&reader); + return Status::OK; +} + +} // namespace grpc diff --git a/src/cpp/util/core_stats.cc b/src/cpp/util/core_stats.cc new file mode 100644 index 00000000..edf0b1bb --- /dev/null +++ b/src/cpp/util/core_stats.cc @@ -0,0 +1,90 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/cpp/util/core_stats.h" + +#include + +using grpc::core::Bucket; +using grpc::core::Histogram; +using grpc::core::Metric; +using grpc::core::Stats; + +namespace grpc { + +void CoreStatsToProto(const grpc_stats_data& core, Stats* proto) { + for (int i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + Metric* m = proto->add_metrics(); + m->set_name(grpc_stats_counter_name[i]); + m->set_count(core.counters[i]); + } + for (int i = 0; i < GRPC_STATS_HISTOGRAM_COUNT; i++) { + Metric* m = proto->add_metrics(); + m->set_name(grpc_stats_histogram_name[i]); + Histogram* h = m->mutable_histogram(); + for (int j = 0; j < grpc_stats_histo_buckets[i]; j++) { + Bucket* b = h->add_buckets(); + b->set_start(grpc_stats_histo_bucket_boundaries[i][j]); + b->set_count(core.histograms[grpc_stats_histo_start[i] + j]); + } + } +} + +void ProtoToCoreStats(const grpc::core::Stats& proto, grpc_stats_data* core) { + memset(core, 0, sizeof(*core)); + for (const auto& m : proto.metrics()) { + switch (m.value_case()) { + case Metric::VALUE_NOT_SET: + break; + case Metric::kCount: + for (int i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + if (m.name() == grpc_stats_counter_name[i]) { + core->counters[i] = m.count(); + break; + } + } + break; + case Metric::kHistogram: + for (int i = 0; i < GRPC_STATS_HISTOGRAM_COUNT; i++) { + if (m.name() == grpc_stats_histogram_name[i]) { + const auto& h = m.histogram(); + bool valid = true; + if (grpc_stats_histo_buckets[i] != h.buckets_size()) valid = false; + for (int j = 0; valid && j < h.buckets_size(); j++) { + if (grpc_stats_histo_bucket_boundaries[i][j] != + h.buckets(j).start()) { + valid = false; + } + } + if (!valid) { + gpr_log(GPR_ERROR, + "Found histogram %s but shape is different from proto", + m.name().c_str()); + } + for (int j = 0; valid && j < h.buckets_size(); j++) { + core->histograms[grpc_stats_histo_start[i] + j] = + h.buckets(j).count(); + } + } + } + break; + } + } +} + +} // namespace grpc diff --git a/src/cpp/util/error_details.cc b/src/cpp/util/error_details.cc new file mode 100644 index 00000000..0330f012 --- /dev/null +++ b/src/cpp/util/error_details.cc @@ -0,0 +1,19 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include diff --git a/src/cpp/util/status.cc b/src/cpp/util/status.cc new file mode 100644 index 00000000..93696d81 --- /dev/null +++ b/src/cpp/util/status.cc @@ -0,0 +1,26 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +namespace grpc { + +const Status& Status::OK = Status(); +const Status& Status::CANCELLED = Status(StatusCode::CANCELLED, ""); + +} // namespace grpc diff --git a/src/cpp/util/string_ref.cc b/src/cpp/util/string_ref.cc new file mode 100644 index 00000000..8b09a82a --- /dev/null +++ b/src/cpp/util/string_ref.cc @@ -0,0 +1,25 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +namespace grpc { + +const size_t string_ref::npos = size_t(-1); + +} // namespace grpc diff --git a/src/cpp/util/time_cc.cc b/src/cpp/util/time_cc.cc new file mode 100644 index 00000000..6c9c228d --- /dev/null +++ b/src/cpp/util/time_cc.cc @@ -0,0 +1,75 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +using std::chrono::duration_cast; +using std::chrono::high_resolution_clock; +using std::chrono::nanoseconds; +using std::chrono::seconds; +using std::chrono::system_clock; + +namespace grpc { + +void Timepoint2Timespec(const system_clock::time_point& from, + gpr_timespec* to) { + system_clock::duration deadline = from.time_since_epoch(); + seconds secs = duration_cast(deadline); + if (from == system_clock::time_point::max() || + secs.count() >= gpr_inf_future(GPR_CLOCK_REALTIME).tv_sec || + secs.count() < 0) { + *to = gpr_inf_future(GPR_CLOCK_REALTIME); + return; + } + nanoseconds nsecs = duration_cast(deadline - secs); + to->tv_sec = static_cast(secs.count()); + to->tv_nsec = static_cast(nsecs.count()); + to->clock_type = GPR_CLOCK_REALTIME; +} + +void TimepointHR2Timespec(const high_resolution_clock::time_point& from, + gpr_timespec* to) { + high_resolution_clock::duration deadline = from.time_since_epoch(); + seconds secs = duration_cast(deadline); + if (from == high_resolution_clock::time_point::max() || + secs.count() >= gpr_inf_future(GPR_CLOCK_REALTIME).tv_sec || + secs.count() < 0) { + *to = gpr_inf_future(GPR_CLOCK_REALTIME); + return; + } + nanoseconds nsecs = duration_cast(deadline - secs); + to->tv_sec = static_cast(secs.count()); + to->tv_nsec = static_cast(nsecs.count()); + to->clock_type = GPR_CLOCK_REALTIME; +} + +system_clock::time_point Timespec2Timepoint(gpr_timespec t) { + if (gpr_time_cmp(t, gpr_inf_future(t.clock_type)) == 0) { + return system_clock::time_point::max(); + } + t = gpr_convert_clock_type(t, GPR_CLOCK_REALTIME); + system_clock::time_point tp; + tp += duration_cast(seconds(t.tv_sec)); + tp += + duration_cast(nanoseconds(t.tv_nsec)); + return tp; +} + +} // namespace grpc diff --git a/src/ruby/ext/grpc/rb_enable_cpp.cc b/src/ruby/ext/grpc/rb_enable_cpp.cc new file mode 100644 index 00000000..701b2cf8 --- /dev/null +++ b/src/ruby/ext/grpc/rb_enable_cpp.cc @@ -0,0 +1,22 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +// This is a phony C++ source file to trigger ruby extension builder to +// pick C++ rather than C linker to link with c++ library properly. diff --git a/test/build/no-c++14-compat.cc b/test/build/no-c++14-compat.cc new file mode 100644 index 00000000..0c1771c7 --- /dev/null +++ b/test/build/no-c++14-compat.cc @@ -0,0 +1,19 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +int main(void) {} diff --git a/test/build/protobuf.cc b/test/build/protobuf.cc new file mode 100644 index 00000000..47520c23 --- /dev/null +++ b/test/build/protobuf.cc @@ -0,0 +1,26 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +bool protobuf_test(const google::protobuf::MethodDescriptor *method) { + return method->client_streaming() || method->server_streaming(); +} + +int main() { return 0; } diff --git a/test/core/address_utils/parse_address_test.cc b/test/core/address_utils/parse_address_test.cc new file mode 100644 index 00000000..f92a97bc --- /dev/null +++ b/test/core/address_utils/parse_address_test.cc @@ -0,0 +1,147 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/address_utils/parse_address.h" + +#include +#ifdef GRPC_HAVE_UNIX_SOCKET +#include +#endif + +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "test/core/util/test_config.h" + +#ifdef GRPC_HAVE_UNIX_SOCKET + +static void test_grpc_parse_unix(const char* uri_text, const char* pathname) { + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(uri_text); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address addr; + + GPR_ASSERT(1 == grpc_parse_uri(*uri, &addr)); + struct sockaddr_un* addr_un = + reinterpret_cast(addr.addr); + GPR_ASSERT(AF_UNIX == addr_un->sun_family); + GPR_ASSERT(0 == strcmp(addr_un->sun_path, pathname)); +} + +static void test_grpc_parse_unix_abstract(const char* uri_text, + const char* pathname) { + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(uri_text); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address addr; + + GPR_ASSERT(1 == grpc_parse_uri(*uri, &addr)); + struct sockaddr_un* addr_un = + reinterpret_cast(addr.addr); + GPR_ASSERT(AF_UNIX == addr_un->sun_family); + GPR_ASSERT('\0' == addr_un->sun_path[0]); + GPR_ASSERT(0 == strncmp(addr_un->sun_path + 1, pathname, strlen(pathname))); +} + +#else /* GRPC_HAVE_UNIX_SOCKET */ + +static void test_grpc_parse_unix(const char* uri_text, const char* pathname) {} +static void test_grpc_parse_unix_abstract(const char* uri_text, + const char* pathname) {} + +#endif /* GRPC_HAVE_UNIX_SOCKET */ + +static void test_grpc_parse_ipv4(const char* uri_text, const char* host, + unsigned short port) { + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(uri_text); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address addr; + char ntop_buf[GRPC_INET_ADDRSTRLEN]; + + GPR_ASSERT(1 == grpc_parse_ipv4(*uri, &addr)); + grpc_sockaddr_in* addr_in = reinterpret_cast(addr.addr); + GPR_ASSERT(GRPC_AF_INET == addr_in->sin_family); + GPR_ASSERT(nullptr != grpc_inet_ntop(GRPC_AF_INET, &addr_in->sin_addr, + ntop_buf, sizeof(ntop_buf))); + GPR_ASSERT(0 == strcmp(ntop_buf, host)); + GPR_ASSERT(grpc_ntohs(addr_in->sin_port) == port); +} + +static void test_grpc_parse_ipv6(const char* uri_text, const char* host, + unsigned short port, uint32_t scope_id) { + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(uri_text); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address addr; + char ntop_buf[GRPC_INET6_ADDRSTRLEN]; + GPR_ASSERT(1 == grpc_parse_ipv6(*uri, &addr)); + grpc_sockaddr_in6* addr_in6 = reinterpret_cast(addr.addr); + GPR_ASSERT(GRPC_AF_INET6 == addr_in6->sin6_family); + GPR_ASSERT(nullptr != grpc_inet_ntop(GRPC_AF_INET6, &addr_in6->sin6_addr, + ntop_buf, sizeof(ntop_buf))); + GPR_ASSERT(0 == strcmp(ntop_buf, host)); + GPR_ASSERT(grpc_ntohs(addr_in6->sin6_port) == port); + GPR_ASSERT(addr_in6->sin6_scope_id == scope_id); +} + +/* Test parsing invalid ipv6 addresses (valid uri_text but invalid ipv6 addr) */ +static void test_grpc_parse_ipv6_invalid(const char* uri_text) { + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = 
grpc_core::URI::Parse(uri_text); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address addr; + GPR_ASSERT(!grpc_parse_ipv6(*uri, &addr)); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + test_grpc_parse_unix("unix:/path/name", "/path/name"); + test_grpc_parse_unix_abstract("unix-abstract:foobar", "foobar"); + test_grpc_parse_ipv4("ipv4:192.0.2.1:12345", "192.0.2.1", 12345); + test_grpc_parse_ipv6("ipv6:[2001:db8::1]:12345", "2001:db8::1", 12345, 0); + test_grpc_parse_ipv6("ipv6:[2001:db8::1%252]:12345", "2001:db8::1", 12345, 2); + + /* Address length greater than GRPC_INET6_ADDRSTRLEN */ + test_grpc_parse_ipv6_invalid( + "ipv6:WWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWW45%" + "25v6:45%25x$1*"); + + grpc_shutdown(); +} diff --git a/test/core/address_utils/parse_address_with_named_scope_id_test.cc b/test/core/address_utils/parse_address_with_named_scope_id_test.cc new file mode 100644 index 00000000..caec8bfc --- /dev/null +++ b/test/core/address_utils/parse_address_with_named_scope_id_test.cc @@ -0,0 +1,130 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#ifdef GRPC_HAVE_UNIX_SOCKET +#include +#endif + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "test/core/util/test_config.h" + +static void test_grpc_parse_ipv6_parity_with_getaddrinfo( + const char* target, const struct sockaddr_in6 result_from_getaddrinfo) { + // Get the sockaddr that gRPC's ipv6 resolver resolves this too. + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(target); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address addr; + GPR_ASSERT(1 == grpc_parse_ipv6(*uri, &addr)); + grpc_sockaddr_in6* result_from_grpc_parser = + reinterpret_cast(addr.addr); + // Compare the sockaddr returned from gRPC's ipv6 resolver with that returned + // from getaddrinfo. + GPR_ASSERT(result_from_grpc_parser->sin6_family == AF_INET6); + GPR_ASSERT(result_from_getaddrinfo.sin6_family == AF_INET6); + GPR_ASSERT(memcmp(&result_from_grpc_parser->sin6_addr, + &result_from_getaddrinfo.sin6_addr, sizeof(in6_addr)) == 0); + GPR_ASSERT(result_from_grpc_parser->sin6_scope_id == + result_from_getaddrinfo.sin6_scope_id); + GPR_ASSERT(result_from_grpc_parser->sin6_scope_id != 0); + // TODO(unknown): compare sin6_flow_info fields? parse_ipv6 zero's this field + // as is. 
Cleanup +} + +struct sockaddr_in6 resolve_with_gettaddrinfo(const char* uri_text) { + absl::StatusOr uri = grpc_core::URI::Parse(uri_text); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + std::string host; + std::string port; + grpc_core::SplitHostPort(uri->path(), &host, &port); + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET6; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_NUMERICHOST; + struct addrinfo* result; + int res = getaddrinfo(host.c_str(), port.c_str(), &hints, &result); + if (res != 0) { + gpr_log(GPR_ERROR, + "getaddrinfo failed to resolve host:%s port:%s. Error: %d.", + host.c_str(), port.c_str(), res); + abort(); + } + size_t num_addrs_from_getaddrinfo = 0; + for (struct addrinfo* resp = result; resp != nullptr; resp = resp->ai_next) { + num_addrs_from_getaddrinfo++; + } + GPR_ASSERT(num_addrs_from_getaddrinfo == 1); + GPR_ASSERT(result->ai_family == AF_INET6); + struct sockaddr_in6 out = + *reinterpret_cast(result->ai_addr); + // Cleanup + freeaddrinfo(result); + return out; +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + char* arbitrary_interface_name = static_cast(gpr_zalloc(IF_NAMESIZE)); + // Per RFC 3493, an interface index is a "small positive integer starts at 1". + // Probe candidate interface index numbers until we find one that the + // system recognizes, and then use that for the test. + for (size_t i = 1; i < 65536; i++) { + if (if_indextoname(i, arbitrary_interface_name) != nullptr) { + gpr_log(GPR_DEBUG, + "Found interface at index %" PRIuPTR + " named %s. Will use this for the test", + i, arbitrary_interface_name); + break; + } + } + GPR_ASSERT(strlen(arbitrary_interface_name) > 0); + std::string target = + absl::StrFormat("ipv6:[fe80::1234%%%s]:12345", arbitrary_interface_name); + struct sockaddr_in6 result_from_getaddrinfo = + resolve_with_gettaddrinfo(target.c_str()); + // Run the test + gpr_log(GPR_DEBUG, + "Run test_grpc_parse_ipv6_parity_with_getaddrinfo with target: %s", + target.c_str()); + test_grpc_parse_ipv6_parity_with_getaddrinfo(target.c_str(), + result_from_getaddrinfo); + // Cleanup + gpr_free(arbitrary_interface_name); + grpc_shutdown(); +} diff --git a/test/core/address_utils/sockaddr_utils_test.cc b/test/core/address_utils/sockaddr_utils_test.cc new file mode 100644 index 00000000..2c615f14 --- /dev/null +++ b/test/core/address_utils/sockaddr_utils_test.cc @@ -0,0 +1,287 @@ +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" + +#include +#include + +#include + +#include +#include + +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "test/core/util/test_config.h" + +namespace { + +grpc_resolved_address MakeAddr4(const uint8_t* data, size_t data_len) { + grpc_resolved_address resolved_addr4; + grpc_sockaddr_in* addr4 = + reinterpret_cast(resolved_addr4.addr); + memset(&resolved_addr4, 0, sizeof(resolved_addr4)); + addr4->sin_family = GRPC_AF_INET; + GPR_ASSERT(data_len == sizeof(addr4->sin_addr.s_addr)); + memcpy(&addr4->sin_addr.s_addr, data, data_len); + addr4->sin_port = grpc_htons(12345); + resolved_addr4.len = static_cast(sizeof(grpc_sockaddr_in)); + return resolved_addr4; +} + +grpc_resolved_address MakeAddr6(const uint8_t* data, size_t data_len) { + grpc_resolved_address resolved_addr6; + grpc_sockaddr_in6* addr6 = + reinterpret_cast(resolved_addr6.addr); + memset(&resolved_addr6, 0, sizeof(resolved_addr6)); + addr6->sin6_family = GRPC_AF_INET6; + GPR_ASSERT(data_len == sizeof(addr6->sin6_addr.s6_addr)); + memcpy(&addr6->sin6_addr.s6_addr, data, data_len); + addr6->sin6_port = grpc_htons(12345); + resolved_addr6.len = static_cast(sizeof(grpc_sockaddr_in6)); + return resolved_addr6; +} + +void SetIPv6ScopeId(grpc_resolved_address* addr, uint32_t scope_id) { + grpc_sockaddr_in6* addr6 = reinterpret_cast(addr->addr); + ASSERT_EQ(addr6->sin6_family, GRPC_AF_INET6); + addr6->sin6_scope_id = scope_id; +} + +const uint8_t kMapped[] = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0xff, 0xff, 192, 0, 2, 1}; + +const uint8_t kNotQuiteMapped[] = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0xff, 0xfe, 192, 0, 2, 99}; +const uint8_t kIPv4[] = {192, 0, 2, 1}; + +const uint8_t kIPv6[] = {0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1}; + +TEST(SockAddrUtilsTest, SockAddrIsV4Mapped) { + // v4mapped input should succeed. + grpc_resolved_address input6 = MakeAddr6(kMapped, sizeof(kMapped)); + ASSERT_TRUE(grpc_sockaddr_is_v4mapped(&input6, nullptr)); + grpc_resolved_address output4; + ASSERT_TRUE(grpc_sockaddr_is_v4mapped(&input6, &output4)); + grpc_resolved_address expect4 = MakeAddr4(kIPv4, sizeof(kIPv4)); + ASSERT_EQ(memcmp(&expect4, &output4, sizeof(expect4)), 0); + + // Non-v4mapped input should fail. + input6 = MakeAddr6(kNotQuiteMapped, sizeof(kNotQuiteMapped)); + ASSERT_FALSE(grpc_sockaddr_is_v4mapped(&input6, nullptr)); + ASSERT_FALSE(grpc_sockaddr_is_v4mapped(&input6, &output4)); + // Output is unchanged. + ASSERT_EQ(memcmp(&expect4, &output4, sizeof(expect4)), 0); + + // Plain IPv4 input should also fail. + grpc_resolved_address input4 = MakeAddr4(kIPv4, sizeof(kIPv4)); + ASSERT_FALSE(grpc_sockaddr_is_v4mapped(&input4, nullptr)); +} + +TEST(SockAddrUtilsTest, SockAddrToV4Mapped) { + // IPv4 input should succeed. + grpc_resolved_address input4 = MakeAddr4(kIPv4, sizeof(kIPv4)); + grpc_resolved_address output6; + ASSERT_TRUE(grpc_sockaddr_to_v4mapped(&input4, &output6)); + grpc_resolved_address expect6 = MakeAddr6(kMapped, sizeof(kMapped)); + ASSERT_EQ(memcmp(&expect6, &output6, sizeof(output6)), 0); + + // IPv6 input should fail. + grpc_resolved_address input6 = MakeAddr6(kIPv6, sizeof(kIPv6)); + ASSERT_TRUE(!grpc_sockaddr_to_v4mapped(&input6, &output6)); + // Output is unchanged. + ASSERT_EQ(memcmp(&expect6, &output6, sizeof(output6)), 0); + + // Already-v4mapped input should also fail. 
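+  // grpc_sockaddr_to_v4mapped() only converts genuine AF_INET inputs, so an
+  // address that is already in ::ffff:a.b.c.d form is rejected here just like
+  // the plain IPv6 address above.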
+ input6 = MakeAddr6(kMapped, sizeof(kMapped)); + ASSERT_TRUE(!grpc_sockaddr_to_v4mapped(&input6, &output6)); +} + +TEST(SockAddrUtilsTest, SockAddrIsWildCard) { + // Generate wildcards. + grpc_resolved_address wild4; + grpc_resolved_address wild6; + grpc_sockaddr_make_wildcards(555, &wild4, &wild6); + grpc_resolved_address wild_mapped; + ASSERT_TRUE(grpc_sockaddr_to_v4mapped(&wild4, &wild_mapped)); + + // Test 0.0.0.0:555 + int port = -1; + ASSERT_TRUE(grpc_sockaddr_is_wildcard(&wild4, &port)); + ASSERT_TRUE(port == 555); + grpc_sockaddr_in* wild4_addr = + reinterpret_cast(&wild4.addr); + memset(&wild4_addr->sin_addr.s_addr, 0xbd, 1); + ASSERT_FALSE(grpc_sockaddr_is_wildcard(&wild4, &port)); + + // Test [::]:555 + port = -1; + ASSERT_TRUE(grpc_sockaddr_is_wildcard(&wild6, &port)); + ASSERT_EQ(port, 555); + grpc_sockaddr_in6* wild6_addr = + reinterpret_cast(&wild6.addr); + memset(&wild6_addr->sin6_addr.s6_addr, 0xbd, 1); + ASSERT_FALSE(grpc_sockaddr_is_wildcard(&wild6, &port)); + + // Test [::ffff:0.0.0.0]:555 + port = -1; + ASSERT_TRUE(grpc_sockaddr_is_wildcard(&wild_mapped, &port)); + ASSERT_EQ(port, 555); + grpc_sockaddr_in6* wild_mapped_addr = + reinterpret_cast(&wild_mapped.addr); + memset(&wild_mapped_addr->sin6_addr.s6_addr, 0xbd, 1); + ASSERT_FALSE(grpc_sockaddr_is_wildcard(&wild_mapped, &port)); + + // Test AF_UNSPEC. + port = -1; + grpc_resolved_address phony; + memset(&phony, 0, sizeof(phony)); + ASSERT_FALSE(grpc_sockaddr_is_wildcard(&phony, &port)); + ASSERT_EQ(port, -1); +} + +TEST(SockAddrUtilsTest, SockAddrToString) { + errno = 0x7EADBEEF; + + grpc_resolved_address input4 = MakeAddr4(kIPv4, sizeof(kIPv4)); + EXPECT_EQ(grpc_sockaddr_to_string(&input4, false), "192.0.2.1:12345"); + EXPECT_EQ(grpc_sockaddr_to_string(&input4, true), "192.0.2.1:12345"); + EXPECT_EQ(grpc_sockaddr_to_uri(&input4), "ipv4:192.0.2.1:12345"); + + grpc_resolved_address input6 = MakeAddr6(kIPv6, sizeof(kIPv6)); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, false), "[2001:db8::1]:12345"); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, true), "[2001:db8::1]:12345"); + EXPECT_EQ(grpc_sockaddr_to_uri(&input6), "ipv6:[2001:db8::1]:12345"); + + SetIPv6ScopeId(&input6, 2); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, false), "[2001:db8::1%252]:12345"); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, true), "[2001:db8::1%252]:12345"); + EXPECT_EQ(grpc_sockaddr_to_uri(&input6), "ipv6:[2001:db8::1%252]:12345"); + + SetIPv6ScopeId(&input6, 101); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, false), + "[2001:db8::1%25101]:12345"); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, true), + "[2001:db8::1%25101]:12345"); + EXPECT_EQ(grpc_sockaddr_to_uri(&input6), "ipv6:[2001:db8::1%25101]:12345"); + + input6 = MakeAddr6(kMapped, sizeof(kMapped)); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, false), + "[::ffff:192.0.2.1]:12345"); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, true), "192.0.2.1:12345"); + EXPECT_EQ(grpc_sockaddr_to_uri(&input6), "ipv4:192.0.2.1:12345"); + + input6 = MakeAddr6(kNotQuiteMapped, sizeof(kNotQuiteMapped)); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, false), "[::fffe:c000:263]:12345"); + EXPECT_EQ(grpc_sockaddr_to_string(&input6, true), "[::fffe:c000:263]:12345"); + EXPECT_EQ(grpc_sockaddr_to_uri(&input6), "ipv6:[::fffe:c000:263]:12345"); + + grpc_resolved_address phony; + memset(&phony, 0, sizeof(phony)); + grpc_sockaddr* phony_addr = reinterpret_cast(phony.addr); + phony_addr->sa_family = 123; + EXPECT_EQ(grpc_sockaddr_to_string(&phony, false), "(sockaddr family=123)"); + 
EXPECT_EQ(grpc_sockaddr_to_string(&phony, true), "(sockaddr family=123)"); + EXPECT_TRUE(grpc_sockaddr_to_uri(&phony).empty()); +} + +TEST(SockAddrUtilsTest, SockAddrSetGetPort) { + grpc_resolved_address input4 = MakeAddr4(kIPv4, sizeof(kIPv4)); + ASSERT_EQ(grpc_sockaddr_get_port(&input4), 12345); + ASSERT_TRUE(grpc_sockaddr_set_port(&input4, 54321)); + ASSERT_EQ(grpc_sockaddr_get_port(&input4), 54321); + + grpc_resolved_address input6 = MakeAddr6(kIPv6, sizeof(kIPv6)); + ASSERT_EQ(grpc_sockaddr_get_port(&input6), 12345); + ASSERT_TRUE(grpc_sockaddr_set_port(&input6, 54321)); + ASSERT_EQ(grpc_sockaddr_get_port(&input6), 54321); + + grpc_resolved_address phony; + memset(&phony, 0, sizeof(phony)); + grpc_sockaddr* phony_addr = reinterpret_cast(phony.addr); + phony_addr->sa_family = 123; + ASSERT_EQ(grpc_sockaddr_get_port(&phony), false); + ASSERT_EQ(grpc_sockaddr_set_port(&phony, 1234), false); +} + +void VerifySocketAddressMatch(const std::string& ip_address, + const std::string& subnet, uint32_t mask_bits, + bool success) { + grpc_resolved_address addr; + ASSERT_EQ(grpc_string_to_sockaddr(&addr, ip_address.c_str(), false), + GRPC_ERROR_NONE); + // Setting the port has no effect on the match. + grpc_sockaddr_set_port(&addr, 12345); + grpc_resolved_address subnet_addr; + ASSERT_EQ(grpc_string_to_sockaddr(&subnet_addr, subnet.c_str(), false), + GRPC_ERROR_NONE); + grpc_sockaddr_mask_bits(&subnet_addr, mask_bits); + EXPECT_EQ(grpc_sockaddr_match_subnet(&addr, &subnet_addr, mask_bits), success) + << "IP=" << ip_address << " Subnet=" << subnet << " Mask=" << mask_bits; +} + +void VerifySocketAddressMatchSuccess(const std::string& ip_address, + const std::string& subnet, + uint32_t mask_bits) { + // If the IP address matches the subnet for a particular length, then it would + // match for all lengths [0, mask_bits] + for (uint32_t i = 0; i <= mask_bits; i++) { + VerifySocketAddressMatch(ip_address, subnet, i, true); + } +} + +void VerifySocketAddressMatchFailure(const std::string& ip_address, + const std::string& subnet, + uint32_t mask_bits) { + // If the IP address fails matching the subnet for a particular length, then + // it would also fail for all lengths [mask_bits, 128] + for (auto i = mask_bits; i <= 128; i++) { + VerifySocketAddressMatch(ip_address, subnet, i, false); + } +} + +TEST(SockAddrUtilsTest, SockAddrMatchSubnet) { + // IPv4 Tests + VerifySocketAddressMatchSuccess("192.168.1.1", "192.168.1.1", 32); + VerifySocketAddressMatchSuccess("255.255.255.255", "255.255.255.255", 32); + VerifySocketAddressMatchFailure("192.168.1.1", "192.168.1.2", 31); + VerifySocketAddressMatchFailure("192.168.1.1", "191.0.0.0", 8); + VerifySocketAddressMatchFailure("192.168.1.1", "0.0.0.0", 1); + // IPv6 Tests + VerifySocketAddressMatchSuccess("2001:db8::", "2001::", 16); + VerifySocketAddressMatchSuccess("2001:db8:cfe:134:3ab:3456:78:9", + "2001:db8:cfe:134:3ab:3456:78:9", 128); + VerifySocketAddressMatchSuccess("FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF", + "FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF", + 128); + VerifySocketAddressMatchFailure("2001:db8:cfe:134:3ab:3456:78:9", + "3001:2:3:4:5:6:7:8", 4); + VerifySocketAddressMatchFailure("FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF", + "::", 1); +} + +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int retval = RUN_ALL_TESTS(); + return retval; +} diff --git a/test/core/avl/avl_test.cc b/test/core/avl/avl_test.cc new file mode 100644 index 00000000..9af85781 --- 
/dev/null +++ b/test/core/avl/avl_test.cc @@ -0,0 +1,300 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/avl/avl.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/util/test_config.h" + +static int* box(int x) { + int* b = static_cast(gpr_malloc(sizeof(*b))); + *b = x; + return b; +} + +static long int_compare(void* int1, void* int2, void* /*unused*/) { + return (*static_cast(int1)) - (*static_cast(int2)); +} +static void* int_copy(void* p, void* /*unused*/) { + return box(*static_cast(p)); +} + +static void destroy(void* p, void* /*unused*/) { gpr_free(p); } + +static const grpc_avl_vtable int_int_vtable = {destroy, int_copy, int_compare, + destroy, int_copy}; + +static void check_get(grpc_avl avl, int key, int value) { + int* k = box(key); + GPR_ASSERT(*(int*)grpc_avl_get(avl, k, nullptr) == value); + gpr_free(k); +} + +static void check_negget(grpc_avl avl, int key) { + int* k = box(key); + GPR_ASSERT(grpc_avl_get(avl, k, nullptr) == nullptr); + gpr_free(k); +} + +static grpc_avl remove_int(grpc_avl avl, int key) { + int* k = box(key); + avl = grpc_avl_remove(avl, k, nullptr); + gpr_free(k); + return avl; +} + +static void test_get(void) { + grpc_avl avl; + gpr_log(GPR_DEBUG, "test_get"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(1), box(11), nullptr); + avl = grpc_avl_add(avl, box(2), box(22), nullptr); + avl = grpc_avl_add(avl, box(3), box(33), nullptr); + check_get(avl, 1, 11); + check_get(avl, 2, 22); + check_get(avl, 3, 33); + check_negget(avl, 4); + grpc_avl_unref(avl, nullptr); +} + +static void test_ll(void) { + grpc_avl avl; + gpr_log(GPR_DEBUG, "test_ll"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(5), box(1), nullptr); + avl = grpc_avl_add(avl, box(4), box(2), nullptr); + avl = grpc_avl_add(avl, box(3), box(3), nullptr); + GPR_ASSERT(*(int*)avl.root->key == 4); + GPR_ASSERT(*(int*)avl.root->left->key == 3); + GPR_ASSERT(*(int*)avl.root->right->key == 5); + grpc_avl_unref(avl, nullptr); +} + +static void test_lr(void) { + grpc_avl avl; + gpr_log(GPR_DEBUG, "test_lr"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(5), box(1), nullptr); + avl = grpc_avl_add(avl, box(3), box(2), nullptr); + avl = grpc_avl_add(avl, box(4), box(3), nullptr); + GPR_ASSERT(*(int*)avl.root->key == 4); + GPR_ASSERT(*(int*)avl.root->left->key == 3); + GPR_ASSERT(*(int*)avl.root->right->key == 5); + grpc_avl_unref(avl, nullptr); +} + +static void test_rr(void) { + grpc_avl avl; + gpr_log(GPR_DEBUG, "test_rr"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(3), box(1), nullptr); + avl = grpc_avl_add(avl, box(4), box(2), nullptr); + avl = grpc_avl_add(avl, box(5), box(3), nullptr); + GPR_ASSERT(*(int*)avl.root->key == 4); + GPR_ASSERT(*(int*)avl.root->left->key == 3); + GPR_ASSERT(*(int*)avl.root->right->key == 5); + grpc_avl_unref(avl, nullptr); +} + 
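+// test_ll, test_lr, test_rr and test_rl insert the keys 3, 4 and 5 in the four
+// orders that trigger each AVL rebalancing case (left-left, left-right,
+// right-right and right-left). Every ordering must settle into the same
+// balanced shape: 4 at the root with 3 and 5 as its children, which is what
+// the GPR_ASSERTs on avl.root check. A shared helper could express that check
+// once (sketch only, not part of the original test):
+//
+//   static void check_balanced_345(const grpc_avl& avl) {
+//     GPR_ASSERT(*(int*)avl.root->key == 4);
+//     GPR_ASSERT(*(int*)avl.root->left->key == 3);
+//     GPR_ASSERT(*(int*)avl.root->right->key == 5);
+//   }
+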
+static void test_rl(void) { + grpc_avl avl; + gpr_log(GPR_DEBUG, "test_rl"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(3), box(1), nullptr); + avl = grpc_avl_add(avl, box(5), box(2), nullptr); + avl = grpc_avl_add(avl, box(4), box(3), nullptr); + GPR_ASSERT(*(int*)avl.root->key == 4); + GPR_ASSERT(*(int*)avl.root->left->key == 3); + GPR_ASSERT(*(int*)avl.root->right->key == 5); + grpc_avl_unref(avl, nullptr); +} + +static void test_unbalanced(void) { + grpc_avl avl; + gpr_log(GPR_DEBUG, "test_unbalanced"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(5), box(1), nullptr); + avl = grpc_avl_add(avl, box(4), box(2), nullptr); + avl = grpc_avl_add(avl, box(3), box(3), nullptr); + avl = grpc_avl_add(avl, box(2), box(4), nullptr); + avl = grpc_avl_add(avl, box(1), box(5), nullptr); + GPR_ASSERT(*(int*)avl.root->key == 4); + GPR_ASSERT(*(int*)avl.root->left->key == 2); + GPR_ASSERT(*(int*)avl.root->left->left->key == 1); + GPR_ASSERT(*(int*)avl.root->left->right->key == 3); + GPR_ASSERT(*(int*)avl.root->right->key == 5); + grpc_avl_unref(avl, nullptr); +} + +static void test_replace(void) { + grpc_avl avl; + gpr_log(GPR_DEBUG, "test_replace"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(1), box(1), nullptr); + avl = grpc_avl_add(avl, box(1), box(2), nullptr); + check_get(avl, 1, 2); + check_negget(avl, 2); + grpc_avl_unref(avl, nullptr); +} + +static void test_remove(void) { + grpc_avl avl; + grpc_avl avl3, avl4, avl5, avln; + gpr_log(GPR_DEBUG, "test_remove"); + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(3), box(1), nullptr); + avl = grpc_avl_add(avl, box(4), box(2), nullptr); + avl = grpc_avl_add(avl, box(5), box(3), nullptr); + + avl3 = remove_int(grpc_avl_ref(avl, nullptr), 3); + avl4 = remove_int(grpc_avl_ref(avl, nullptr), 4); + avl5 = remove_int(grpc_avl_ref(avl, nullptr), 5); + avln = remove_int(grpc_avl_ref(avl, nullptr), 1); + + grpc_avl_unref(avl, nullptr); + + check_negget(avl3, 3); + check_get(avl3, 4, 2); + check_get(avl3, 5, 3); + grpc_avl_unref(avl3, nullptr); + + check_get(avl4, 3, 1); + check_negget(avl4, 4); + check_get(avl4, 5, 3); + grpc_avl_unref(avl4, nullptr); + + check_get(avl5, 3, 1); + check_get(avl5, 4, 2); + check_negget(avl5, 5); + grpc_avl_unref(avl5, nullptr); + + check_get(avln, 3, 1); + check_get(avln, 4, 2); + check_get(avln, 5, 3); + grpc_avl_unref(avln, nullptr); +} + +static void test_badcase1(void) { + grpc_avl avl; + + gpr_log(GPR_DEBUG, "test_badcase1"); + + avl = grpc_avl_create(&int_int_vtable); + avl = grpc_avl_add(avl, box(88), box(1), nullptr); + avl = remove_int(avl, 643); + avl = remove_int(avl, 983); + avl = grpc_avl_add(avl, box(985), box(4), nullptr); + avl = grpc_avl_add(avl, box(640), box(5), nullptr); + avl = grpc_avl_add(avl, box(41), box(6), nullptr); + avl = grpc_avl_add(avl, box(112), box(7), nullptr); + avl = grpc_avl_add(avl, box(342), box(8), nullptr); + avl = remove_int(avl, 1013); + avl = grpc_avl_add(avl, box(434), box(10), nullptr); + avl = grpc_avl_add(avl, box(520), box(11), nullptr); + avl = grpc_avl_add(avl, box(231), box(12), nullptr); + avl = grpc_avl_add(avl, box(852), box(13), nullptr); + avl = remove_int(avl, 461); + avl = grpc_avl_add(avl, box(108), box(15), nullptr); + avl = grpc_avl_add(avl, box(806), box(16), nullptr); + avl = grpc_avl_add(avl, box(827), box(17), nullptr); + avl = remove_int(avl, 796); + avl = grpc_avl_add(avl, box(340), box(19), nullptr); + avl = grpc_avl_add(avl, box(498), box(20), 
nullptr); + avl = grpc_avl_add(avl, box(203), box(21), nullptr); + avl = grpc_avl_add(avl, box(751), box(22), nullptr); + avl = grpc_avl_add(avl, box(150), box(23), nullptr); + avl = remove_int(avl, 237); + avl = grpc_avl_add(avl, box(830), box(25), nullptr); + avl = remove_int(avl, 1007); + avl = remove_int(avl, 394); + avl = grpc_avl_add(avl, box(65), box(28), nullptr); + avl = remove_int(avl, 904); + avl = remove_int(avl, 123); + avl = grpc_avl_add(avl, box(238), box(31), nullptr); + avl = grpc_avl_add(avl, box(184), box(32), nullptr); + avl = remove_int(avl, 331); + avl = grpc_avl_add(avl, box(827), box(34), nullptr); + + check_get(avl, 830, 25); + + grpc_avl_unref(avl, nullptr); +} + +static void test_stress(int amount_of_stress) { + int added[1024]; + int i, j; + int deletions = 0; + grpc_avl avl; + + unsigned seed = static_cast(time(nullptr)); + + gpr_log(GPR_DEBUG, "test_stress amount=%d seed=%u", amount_of_stress, seed); + + srand(static_cast(time(nullptr))); + avl = grpc_avl_create(&int_int_vtable); + + memset(added, 0, sizeof(added)); + + for (i = 1; deletions < amount_of_stress; i++) { + int idx = rand() % static_cast GPR_ARRAY_SIZE(added); + GPR_ASSERT(i); + if (rand() < RAND_MAX / 2) { + added[idx] = i; + printf("avl = grpc_avl_add(avl, box(%d), box(%d), NULL); /* d=%d */\n", + idx, i, deletions); + avl = grpc_avl_add(avl, box(idx), box(i), nullptr); + } else { + deletions += (added[idx] != 0); + added[idx] = 0; + printf("avl = remove_int(avl, %d); /* d=%d */\n", idx, deletions); + avl = remove_int(avl, idx); + } + for (j = 0; j < static_cast GPR_ARRAY_SIZE(added); j++) { + if (added[j] != 0) { + check_get(avl, j, added[j]); + } else { + check_negget(avl, j); + } + } + } + + grpc_avl_unref(avl, nullptr); +} + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + + test_get(); + test_ll(); + test_lr(); + test_rr(); + test_rl(); + test_unbalanced(); + test_replace(); + test_remove(); + test_badcase1(); + test_stress(10); + + return 0; +} diff --git a/test/core/backoff/backoff_test.cc b/test/core/backoff/backoff_test.cc new file mode 100644 index 00000000..0a443743 --- /dev/null +++ b/test/core/backoff/backoff_test.cc @@ -0,0 +1,182 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/backoff/backoff.h" + +#include + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +using grpc_core::BackOff; + +TEST(BackOffTest, ConstantBackOff) { + const grpc_millis initial_backoff = 200; + const double multiplier = 1.0; + const double jitter = 0.0; + const grpc_millis max_backoff = 1000; + grpc_core::ExecCtx exec_ctx; + BackOff::Options options; + options.set_initial_backoff(initial_backoff) + .set_multiplier(multiplier) + .set_jitter(jitter) + .set_max_backoff(max_backoff); + BackOff backoff(options); + + grpc_millis next_attempt_start_time = backoff.NextAttemptTime(); + EXPECT_EQ(next_attempt_start_time - grpc_core::ExecCtx::Get()->Now(), + initial_backoff); + for (int i = 0; i < 10000; i++) { + next_attempt_start_time = backoff.NextAttemptTime(); + EXPECT_EQ(next_attempt_start_time - grpc_core::ExecCtx::Get()->Now(), + initial_backoff); + } +} + +TEST(BackOffTest, MinConnect) { + const grpc_millis initial_backoff = 100; + const double multiplier = 1.0; + const double jitter = 0.0; + const grpc_millis max_backoff = 1000; + grpc_core::ExecCtx exec_ctx; + BackOff::Options options; + options.set_initial_backoff(initial_backoff) + .set_multiplier(multiplier) + .set_jitter(jitter) + .set_max_backoff(max_backoff); + BackOff backoff(options); + grpc_millis next = backoff.NextAttemptTime(); + EXPECT_EQ(next - grpc_core::ExecCtx::Get()->Now(), initial_backoff); +} + +TEST(BackOffTest, NoJitterBackOff) { + const grpc_millis initial_backoff = 2; + const double multiplier = 2.0; + const double jitter = 0.0; + const grpc_millis max_backoff = 513; + BackOff::Options options; + options.set_initial_backoff(initial_backoff) + .set_multiplier(multiplier) + .set_jitter(jitter) + .set_max_backoff(max_backoff); + BackOff backoff(options); + // x_1 = 2 + // x_n = 2**i + x_{i-1} ( = 2**(n+1) - 2 ) + grpc_core::ExecCtx exec_ctx; + grpc_core::ExecCtx::Get()->TestOnlySetNow(0); + grpc_millis next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 2); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 6); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 14); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 30); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 62); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 126); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 254); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 510); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + EXPECT_EQ(next, 1022); + grpc_core::ExecCtx::Get()->TestOnlySetNow(next); + next = backoff.NextAttemptTime(); + // Hit the maximum timeout. From this point onwards, retries will increase + // only by max timeout. 
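+  // With initial_backoff = 2 and multiplier = 2.0 the deadlines follow
+  // 2^(n+1) - 2 (2, 6, 14, ..., 510, 1022); once the next doubling would
+  // exceed max_backoff = 513, each deadline advances by exactly 513:
+  // 1022 + 513 = 1535, 1535 + 513 = 2048, 2048 + 513 = 2561.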
+  EXPECT_EQ(next, 1535);
+  grpc_core::ExecCtx::Get()->TestOnlySetNow(next);
+  next = backoff.NextAttemptTime();
+  EXPECT_EQ(next, 2048);
+  grpc_core::ExecCtx::Get()->TestOnlySetNow(next);
+  next = backoff.NextAttemptTime();
+  EXPECT_EQ(next, 2561);
+}
+
+TEST(BackOffTest, JitterBackOff) {
+  const grpc_millis initial_backoff = 500;
+  grpc_millis current_backoff = initial_backoff;
+  const grpc_millis max_backoff = 1000;
+  const double multiplier = 1.0;
+  const double jitter = 0.1;
+  BackOff::Options options;
+  options.set_initial_backoff(initial_backoff)
+      .set_multiplier(multiplier)
+      .set_jitter(jitter)
+      .set_max_backoff(max_backoff);
+  BackOff backoff(options);
+
+  backoff.SetRandomSeed(0);  // force consistent PRNG
+
+  grpc_core::ExecCtx exec_ctx;
+  grpc_millis next = backoff.NextAttemptTime();
+  EXPECT_EQ(next - grpc_core::ExecCtx::Get()->Now(), initial_backoff);
+
+  grpc_millis expected_next_lower_bound = static_cast<grpc_millis>(
+      static_cast<double>(current_backoff) * (1 - jitter));
+  grpc_millis expected_next_upper_bound = static_cast<grpc_millis>(
+      static_cast<double>(current_backoff) * (1 + jitter));
+
+  for (int i = 0; i < 10000; i++) {
+    next = backoff.NextAttemptTime();
+    // next-now must be within (jitter*100)% of the current backoff (which is
+    // multiplied by multiplier, up to max_backoff, on each iteration).
+    const grpc_millis timeout_millis = next - grpc_core::ExecCtx::Get()->Now();
+    EXPECT_GE(timeout_millis, expected_next_lower_bound);
+    EXPECT_LE(timeout_millis, expected_next_upper_bound);
+    current_backoff = std::min(
+        static_cast<grpc_millis>(static_cast<double>(current_backoff) *
+                                 multiplier),
+        max_backoff);
+    expected_next_lower_bound = static_cast<grpc_millis>(
+        static_cast<double>(current_backoff) * (1 - jitter));
+    expected_next_upper_bound = static_cast<grpc_millis>(
+        static_cast<double>(current_backoff) * (1 + jitter));
+  }
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  grpc::testing::TestEnvironment env(argc, argv);
+  grpc_init();
+  int ret = RUN_ALL_TESTS();
+  grpc_shutdown();
+  return ret;
+}
diff --git a/test/core/bad_client/bad_client.cc b/test/core/bad_client/bad_client.cc
new file mode 100644
index 00000000..85a70b33
--- /dev/null
+++ b/test/core/bad_client/bad_client.cc
@@ -0,0 +1,336 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/core/bad_client/bad_client.h"
+
+#include <stdio.h>
+
+#include <grpc/support/alloc.h>
+#include <grpc/support/string_util.h>
+#include <grpc/support/sync.h>
+
+#include "src/core/ext/filters/http/server/http_server_filter.h"
+#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h"
+#include "src/core/lib/channel/channel_stack.h"
+#include "src/core/lib/gpr/murmur_hash.h"
+#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/thd.h"
+#include "src/core/lib/iomgr/endpoint_pair.h"
+#include "src/core/lib/slice/slice_internal.h"
+#include "src/core/lib/surface/completion_queue.h"
+#include "src/core/lib/surface/server.h"
+#include "test/core/end2end/cq_verifier.h"
+#include "test/core/util/resource_user_util.h"
+
+#define MIN_HTTP2_FRAME_SIZE 9
+
+/* Args to provide to thread running server side validator */
+typedef struct {
+  grpc_server* server;
+  grpc_completion_queue* cq;
+  grpc_bad_client_server_side_validator validator;
+  void* registered_method;
+  gpr_event done_thd;
+} thd_args;
+
+/* Run the server side validator and set done_thd once done */
+static void thd_func(void* arg) {
+  thd_args* a = static_cast<thd_args*>(arg);
+  if (a->validator != nullptr) {
+    a->validator(a->server, a->cq, a->registered_method);
+  }
+  gpr_event_set(&a->done_thd, reinterpret_cast<void*>(1));
+}
+
+/* Sets the done_write event */
+static void set_done_write(void* arg, grpc_error_handle /*error*/) {
+  gpr_event* done_write = static_cast<gpr_event*>(arg);
+  gpr_event_set(done_write, reinterpret_cast<void*>(1));
+}
+
+static void server_setup_transport(void* ts, grpc_transport* transport) {
+  thd_args* a = static_cast<thd_args*>(ts);
+  grpc_core::ExecCtx exec_ctx;
+  GPR_ASSERT(GRPC_LOG_IF_ERROR(
+      "SetupTransport",
+      a->server->core_server->SetupTransport(
+          transport,
+          /*accepting_pollset=*/nullptr, a->server->core_server->channel_args(),
+          /*socket_node=*/nullptr)));
+}
+
+/* Sets the read_done event */
+static void set_read_done(void* arg, grpc_error_handle /*error*/) {
+  gpr_event* read_done = static_cast<gpr_event*>(arg);
+  gpr_event_set(read_done, reinterpret_cast<void*>(1));
+}
+
+/* shutdown client */
+static void shutdown_client(grpc_endpoint** client_fd) {
+  if (*client_fd != nullptr) {
+    grpc_endpoint_shutdown(
+        *client_fd, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Forced Disconnect"));
+    grpc_endpoint_destroy(*client_fd);
+    grpc_core::ExecCtx::Get()->Flush();
+    *client_fd = nullptr;
+  }
+}
+
+/* Runs client side validator */
+void grpc_run_client_side_validator(grpc_bad_client_arg* arg, uint32_t flags,
+                                    grpc_endpoint_pair* sfd,
+                                    grpc_completion_queue* client_cq) {
+  char* hex;
+  gpr_event done_write;
+  if (arg->client_payload_length < 4 * 1024) {
+    hex = gpr_dump(arg->client_payload, arg->client_payload_length,
+                   GPR_DUMP_HEX | GPR_DUMP_ASCII);
+    /* Add a debug log */
+    gpr_log(GPR_INFO, "TEST: %s", hex);
+    gpr_free(hex);
+  } else {
+    gpr_log(GPR_INFO, "TEST: (%" PRIdPTR " byte long string)",
+            arg->client_payload_length);
+  }
+
+  grpc_slice slice = grpc_slice_from_copied_buffer(arg->client_payload,
+                                                   arg->client_payload_length);
+  grpc_slice_buffer outgoing;
+  grpc_closure done_write_closure;
+  gpr_event_init(&done_write);
+
+  grpc_slice_buffer_init(&outgoing);
+  grpc_slice_buffer_add(&outgoing, slice);
+  GRPC_CLOSURE_INIT(&done_write_closure, set_done_write, &done_write,
+                    grpc_schedule_on_exec_ctx);
+
+  /* Write data */
+  grpc_endpoint_write(sfd->client, &outgoing, &done_write_closure, nullptr);
+  grpc_core::ExecCtx::Get()->Flush();
+
+  /* Await completion, unless the request is large and write may not finish
+   * before the peer shuts down.
*/ + if (!(flags & GRPC_BAD_CLIENT_LARGE_REQUEST)) { + GPR_ASSERT( + gpr_event_wait(&done_write, grpc_timeout_seconds_to_deadline(5))); + } + + if (flags & GRPC_BAD_CLIENT_DISCONNECT) { + shutdown_client(&sfd->client); + } + + if (sfd->client != nullptr) { + /* Validate client stream, if requested. */ + if (arg->client_validator != nullptr) { + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5); + grpc_slice_buffer incoming; + grpc_slice_buffer_init(&incoming); + /* We may need to do multiple reads to read the complete server + * response. */ + while (true) { + gpr_event read_done_event; + gpr_event_init(&read_done_event); + grpc_closure read_done_closure; + GRPC_CLOSURE_INIT(&read_done_closure, set_read_done, &read_done_event, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(sfd->client, &incoming, &read_done_closure, + /*urgent=*/true); + grpc_core::ExecCtx::Get()->Flush(); + do { + GPR_ASSERT(gpr_time_cmp(deadline, gpr_now(deadline.clock_type)) > 0); + /* Perform a cq next just to provide a thread that can read incoming + bytes on the client fd */ + GPR_ASSERT(grpc_completion_queue_next( + client_cq, grpc_timeout_milliseconds_to_deadline(100), + nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } while (!gpr_event_get(&read_done_event)); + if (arg->client_validator(&incoming, arg->client_validator_arg)) break; + gpr_log(GPR_INFO, + "client validator failed; trying additional read " + "in case we didn't get all the data"); + } + grpc_slice_buffer_destroy_internal(&incoming); + } + grpc_core::ExecCtx::Get()->Flush(); + } + + /* If the request was too large, then we need to forcefully shut down the + * client, so that the write can be considered completed */ + if (flags & GRPC_BAD_CLIENT_LARGE_REQUEST) { + shutdown_client(&sfd->client); + } + + /* Make sure that the client is done writing */ + while (!gpr_event_get(&done_write)) { + GPR_ASSERT( + grpc_completion_queue_next( + client_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } + + grpc_slice_buffer_destroy_internal(&outgoing); + grpc_core::ExecCtx::Get()->Flush(); +} + +void grpc_run_bad_client_test( + grpc_bad_client_server_side_validator server_validator, + grpc_bad_client_arg args[], int num_args, uint32_t flags) { + grpc_endpoint_pair sfd; + thd_args a; + grpc_transport* transport; + grpc_core::ExecCtx exec_ctx; + grpc_completion_queue* shutdown_cq; + grpc_completion_queue* client_cq; + + /* Init grpc */ + grpc_init(); + + sfd = grpc_iomgr_create_endpoint_pair("fixture", nullptr); + /* Create server, completion events */ + a.server = grpc_server_create(nullptr, nullptr); + a.cq = grpc_completion_queue_create_for_next(nullptr); + client_cq = grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(a.server, a.cq, nullptr); + a.registered_method = + grpc_server_register_method(a.server, GRPC_BAD_CLIENT_REGISTERED_METHOD, + GRPC_BAD_CLIENT_REGISTERED_HOST, + GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER, 0); + grpc_server_start(a.server); + transport = grpc_create_chttp2_transport( + nullptr, sfd.server, false, grpc_resource_user_create_unlimited()); + server_setup_transport(&a, transport); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + + /* Bind fds to pollsets */ + grpc_endpoint_add_to_pollset(sfd.client, grpc_cq_pollset(client_cq)); + grpc_endpoint_add_to_pollset(sfd.server, grpc_cq_pollset(a.cq)); + + /* Check a ground truth */ + GPR_ASSERT(a.server->core_server->HasOpenConnections()); + + gpr_event_init(&a.done_thd); + 
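+  /* The server-side validator runs on its own thread (thd_func), which
+     signals done_thd via gpr_event_set() when it returns; the main thread
+     waits on that event below before shutting the server down. */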
a.validator = server_validator;
+  /* Start validator */
+
+  grpc_core::Thread server_validator_thd("grpc_bad_client", thd_func, &a);
+  server_validator_thd.Start();
+  for (int i = 0; i < num_args; i++) {
+    grpc_run_client_side_validator(&args[i], i == (num_args - 1) ? flags : 0,
+                                   &sfd, client_cq);
+  }
+  /* Wait for server thread to finish */
+  GPR_ASSERT(gpr_event_wait(&a.done_thd, grpc_timeout_seconds_to_deadline(1)));
+
+  /* Shutdown. */
+  shutdown_client(&sfd.client);
+  server_validator_thd.Join();
+  shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr);
+  grpc_server_shutdown_and_notify(a.server, shutdown_cq, nullptr);
+  GPR_ASSERT(grpc_completion_queue_pluck(shutdown_cq, nullptr,
+                                         grpc_timeout_seconds_to_deadline(1),
+                                         nullptr)
+                 .type == GRPC_OP_COMPLETE);
+  grpc_completion_queue_destroy(shutdown_cq);
+  grpc_server_destroy(a.server);
+  grpc_completion_queue_destroy(a.cq);
+  grpc_completion_queue_destroy(client_cq);
+  grpc_shutdown();
+}
+
+bool client_connection_preface_validator(grpc_slice_buffer* incoming,
+                                         void* /*arg*/) {
+  if (incoming->count < 1) {
+    return false;
+  }
+  grpc_slice slice = incoming->slices[0];
+  /* There should be at least one settings frame present */
+  if (GRPC_SLICE_LENGTH(slice) < MIN_HTTP2_FRAME_SIZE) {
+    return false;
+  }
+  const uint8_t* p = GRPC_SLICE_START_PTR(slice);
+  /* Check the frame type (SETTINGS) */
+  return *(p + 3) == 4;
+}
+
+/* connection preface and settings frame to be sent by the client */
+#define CONNECTION_PREFACE_FROM_CLIENT \
+  "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"   \
+  "\x00\x00\x00\x04\x00\x00\x00\x00\x00"
+
+grpc_bad_client_arg connection_preface_arg = {
+    client_connection_preface_validator, nullptr,
+    CONNECTION_PREFACE_FROM_CLIENT, sizeof(CONNECTION_PREFACE_FROM_CLIENT) - 1};
+
+bool rst_stream_client_validator(grpc_slice_buffer* incoming, void* /*arg*/) {
+  // Get last frame from incoming slice buffer.
+  grpc_slice_buffer last_frame_buffer;
+  grpc_slice_buffer_init(&last_frame_buffer);
+  grpc_slice_buffer_trim_end(incoming, 13, &last_frame_buffer);
+  GPR_ASSERT(last_frame_buffer.count == 1);
+  grpc_slice last_frame = last_frame_buffer.slices[0];
+
+  const uint8_t* p = GRPC_SLICE_START_PTR(last_frame);
+  bool success =
+      // Length == 4
+      *p++ == 0 && *p++ == 0 && *p++ == 4 &&
+      // Frame type (RST_STREAM)
+      *p++ == 3 &&
+      // Flags
+      *p++ == 0 &&
+      // Stream ID.
+      *p++ == 0 && *p++ == 0 && *p++ == 0 && *p++ == 1 &&
+      // Payload (error code): NO_ERROR (0) or ENHANCE_YOUR_CALM (11)
+      *p++ == 0 && *p++ == 0 && *p++ == 0 && (*p == 0 || *p == 11);
+
+  if (!success) {
+    gpr_log(GPR_INFO, "client expected RST_STREAM frame, not found");
+  }
+
+  grpc_slice_buffer_destroy(&last_frame_buffer);
+  return success;
+}
+
+static void* tag(intptr_t t) { return reinterpret_cast<void*>(t); }
+
+void server_verifier_request_call(grpc_server* server,
+                                  grpc_completion_queue* cq,
+                                  void* /*registered_method*/) {
+  grpc_call_error error;
+  grpc_call* s;
+  grpc_call_details call_details;
+  cq_verifier* cqv = cq_verifier_create(cq);
+  grpc_metadata_array request_metadata_recv;
+
+  grpc_call_details_init(&call_details);
+  grpc_metadata_array_init(&request_metadata_recv);
+
+  error = grpc_server_request_call(server, &s, &call_details,
+                                   &request_metadata_recv, cq, cq, tag(101));
+  GPR_ASSERT(GRPC_CALL_OK == error);
+  CQ_EXPECT_COMPLETION(cqv, tag(101), 1);
+  cq_verify(cqv);
+
+  GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.host, "localhost"));
+  GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo/bar"));
+
+  grpc_metadata_array_destroy(&request_metadata_recv);
+  grpc_call_details_destroy(&call_details);
+  grpc_call_unref(s);
+  cq_verifier_destroy(cqv);
+}
diff --git a/test/core/bad_client/tests/bad_streaming_id.cc b/test/core/bad_client/tests/bad_streaming_id.cc
new file mode 100644
index 00000000..85b5c94f
--- /dev/null
+++ b/test/core/bad_client/tests/bad_streaming_id.cc
@@ -0,0 +1,133 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +#include + +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" + +#define HEADER_FRAME_ID_1 \ + "\x00\x00\xc9\x01\x05\x00\x00\x00\x01" /* headers: generated from \ + simple_request.headers in this \ + directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15" \ + "deflate,identity,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +#define HEADER_FRAME_ID_2 \ + "\x00\x00\xc9\x01\x05\x00\x00\x00\x02" /* headers: generated from \ + simple_request.headers in this \ + directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15" \ + "deflate,identity,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +#define HEADER_FRAME_ID_3 \ + "\x00\x00\xc9\x01\x05\x00\x00\x00\x03" /* headers: generated from \ + simple_request.headers in this \ + directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15" \ + "deflate,identity,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +namespace { + +void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +TEST(BadStreamingId, RegularHeader) { + grpc_bad_client_arg args[2]; + args[0] = connection_preface_arg; + args[1].client_validator = nullptr; + args[1].client_payload = HEADER_FRAME_ID_1; + args[1].client_payload_length = sizeof(HEADER_FRAME_ID_1) - 1; + grpc_run_bad_client_test(verifier, args, 2, GRPC_BAD_CLIENT_DISCONNECT); +} + +TEST(BadStreamingId, NonClientStreamId) { + grpc_bad_client_arg args[2]; + args[0] = connection_preface_arg; + // send a header frame with non-client stream id 2 + args[1].client_validator = nullptr; + args[1].client_payload = HEADER_FRAME_ID_2; + args[1].client_payload_length = sizeof(HEADER_FRAME_ID_2) - 1; + grpc_run_bad_client_test(verifier, args, 2, GRPC_BAD_CLIENT_DISCONNECT); +} + +TEST(BadStreamingId, ClosedStreamId) { + grpc_bad_client_arg args[4]; + args[0] = connection_preface_arg; + // send a header frame with stream id 1 + args[1].client_validator = nullptr; + args[1].client_payload = HEADER_FRAME_ID_1; + args[1].client_payload_length = sizeof(HEADER_FRAME_ID_1) - 1; + // send a header frame with stream id 3 + args[2].client_validator = nullptr; + args[2].client_payload = HEADER_FRAME_ID_3; + args[2].client_payload_length = sizeof(HEADER_FRAME_ID_3) - 1; + // send a header frame with closed stream id 1 again + args[3].client_validator = nullptr; + args[3].client_payload = HEADER_FRAME_ID_1; + args[3].client_payload_length = sizeof(HEADER_FRAME_ID_1) - 1; + grpc_run_bad_client_test(verifier, args, 4, GRPC_BAD_CLIENT_DISCONNECT); +} + +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + 
::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int retval = RUN_ALL_TESTS(); + grpc_shutdown(); + return retval; +} diff --git a/test/core/bad_client/tests/badreq.cc b/test/core/bad_client/tests/badreq.cc new file mode 100644 index 00000000..8c3c5a39 --- /dev/null +++ b/test/core/bad_client/tests/badreq.cc @@ -0,0 +1,133 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" +#include "test/core/end2end/cq_verifier.h" + +#define PFX_STR \ + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" \ + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" /* settings frame */ + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* invalid content type */ + GRPC_RUN_BAD_CLIENT_TEST( + verifier, nullptr, + PFX_STR + "\x00\x00\xc2\x01\x04\x00\x00\x00\x01" + "\x10\x05:path\x08/foo/bar" + "\x10\x07:scheme\x04http" + "\x10\x07:method\x04POST" + "\x10\x0a:authority\x09localhost" + "\x10\x0c" + "content-type\x09text/html" + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" + "\x10\x02te\x08trailers" + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)", + GRPC_BAD_CLIENT_DISCONNECT); + + /* invalid te */ + GRPC_RUN_BAD_CLIENT_TEST( + verifier, nullptr, + PFX_STR + "\x00\x00\xcb\x01\x04\x00\x00\x00\x01" + "\x10\x05:path\x08/foo/bar" + "\x10\x07:scheme\x04http" + "\x10\x07:method\x04POST" + "\x10\x0a:authority\x09localhost" + "\x10\x0c" + "content-type\x10" + "application/grpc" + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" + "\x10\x02te\x0a" + "frobnicate" + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)", + GRPC_BAD_CLIENT_DISCONNECT); + + /* two path headers */ + GRPC_RUN_BAD_CLIENT_TEST( + verifier, nullptr, + PFX_STR + "\x00\x00\xd9\x01\x04\x00\x00\x00\x01" + "\x10\x05:path\x08/foo/bar" + "\x10\x05:path\x08/foo/bah" + "\x10\x07:scheme\x04http" + "\x10\x07:method\x04POST" + "\x10\x0a:authority\x09localhost" + "\x10\x0c" + "content-type\x10" + "application/grpc" + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" + "\x10\x02te\x08trailers" + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)", + GRPC_BAD_CLIENT_DISCONNECT); + + /* bad accept-encoding algorithm */ + GRPC_RUN_BAD_CLIENT_TEST( + verifier, nullptr, + PFX_STR + "\x00\x00\xd2\x01\x04\x00\x00\x00\x01" + "\x10\x05:path\x08/foo/bar" + "\x10\x07:scheme\x04http" + "\x10\x07:method\x04POST" + "\x10\x0a:authority\x09localhost" + "\x10\x0c" + "content-type\x10" + "application/grpc" + "\x10\x14grpc-accept-encoding\x1enobody-knows-the-trouble-i-see" + "\x10\x02te\x08trailers" + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 
(linux)", + GRPC_BAD_CLIENT_DISCONNECT); + + /* bad grpc-encoding algorithm */ + GRPC_RUN_BAD_CLIENT_TEST( + verifier, nullptr, + PFX_STR + "\x00\x00\xf5\x01\x04\x00\x00\x00\x01" + "\x10\x05:path\x08/foo/bar" + "\x10\x07:scheme\x04http" + "\x10\x07:method\x04POST" + "\x10\x0a:authority\x09localhost" + "\x10\x0c" + "content-type\x10" + "application/grpc" + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" + "\x10\x0dgrpc-encoding\x1cyou-dont-know-how-to-do-this" + "\x10\x02te\x08trailers" + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)", + GRPC_BAD_CLIENT_DISCONNECT); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/connection_prefix.cc b/test/core/bad_client/tests/connection_prefix.cc new file mode 100644 index 00000000..45473722 --- /dev/null +++ b/test/core/bad_client/tests/connection_prefix.cc @@ -0,0 +1,64 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRIX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI *X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTPX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0X", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\rX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\nX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\n\rX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\n\r\nX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\n\r\nSX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\n\r\nSMX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\n\r\nSM\rX", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\n\r\nSM\r\nX", + 0); + 
GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, "PRI * HTTP/2.0\r\n\r\nSM\r\n\rX", + 0); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/duplicate_header.cc b/test/core/bad_client/tests/duplicate_header.cc new file mode 100644 index 00000000..de276634 --- /dev/null +++ b/test/core/bad_client/tests/duplicate_header.cc @@ -0,0 +1,135 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" +#include "test/core/end2end/cq_verifier.h" + +#define PFX_STR \ + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" \ + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" /* settings frame */ + +#define HEADER_STR \ + "\x00\x00\xc9\x01\x04\x00\x00\x00\x01" /* headers: generated from \ + simple_request.headers in this \ + directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15" \ + "deflate,identity,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +#define PAYLOAD_STR \ + "\x00\x00\x20\x00\x00\x00\x00\x00\x01" \ + "\x00\x00\x00\x00" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + grpc_call_error error; + grpc_call* s; + grpc_call_details call_details; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_op* op; + grpc_op ops[6]; + cq_verifier* cqv = cq_verifier_create(cq); + grpc_metadata_array request_metadata_recv; + int was_cancelled = 2; + + grpc_call_details_init(&call_details); + grpc_metadata_array_init(&request_metadata_recv); + + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.host, "localhost")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo/bar")); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + 
op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(s); + cq_verifier_destroy(cqv); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* Verify that sending multiple headers doesn't segfault */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR HEADER_STR HEADER_STR PAYLOAD_STR, 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR HEADER_STR HEADER_STR HEADER_STR PAYLOAD_STR, + 0); + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/head_of_line_blocking.cc b/test/core/bad_client/tests/head_of_line_blocking.cc new file mode 100644 index 00000000..cb140eb6 --- /dev/null +++ b/test/core/bad_client/tests/head_of_line_blocking.cc @@ -0,0 +1,139 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" +#include "test/core/end2end/cq_verifier.h" + +static const char prefix[] = + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + // settings frame + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" + // stream 1 headers: generated from server_registered_method.headers in this + // directory + "\x00\x00\xd0\x01\x04\x00\x00\x00\x01" + "\x10\x05:path\x0f/registered/bar" + "\x10\x07:scheme\x04http" + "\x10\x07:method\x04POST" + "\x10\x0a:authority\x09localhost" + "\x10\x0c" + "content-type\x10" + "application/grpc" + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" + "\x10\x02te\x08trailers" + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + // data frame for stream 1: advertise a 10000 byte payload (that we won't + // fulfill) + "\x00\x00\x05\x00\x00\x00\x00\x00\x01" + "\x01\x00\x00\x27\x10" + // stream 3 headers: generated from server_registered_method.headers in this + // directory + "\x00\x00\xd0\x01\x04\x00\x00\x00\x03" + "\x10\x05:path\x0f/registered/bar" + "\x10\x07:scheme\x04http" + "\x10\x07:method\x04POST" + "\x10\x0a:authority\x09localhost" + "\x10\x0c" + "content-type\x10" + "application/grpc" + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" + "\x10\x02te\x08trailers" + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + // data frame for stream 3: advertise a 10000 byte payload (that we will + // fulfill) + "\x00\x00\x05\x00\x00\x00\x00\x00\x03" + "\x01\x00\x00\x27\x10" + ""; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* registered_method) { + grpc_call_error error; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(cq); + grpc_metadata_array request_metadata_recv; + gpr_timespec deadline; + grpc_byte_buffer* payload = nullptr; + + grpc_metadata_array_init(&request_metadata_recv); + + error = grpc_server_request_registered_call(server, registered_method, &s, + &deadline, &request_metadata_recv, + &payload, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + GPR_ASSERT(payload != nullptr); + + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_unref(s); + grpc_byte_buffer_destroy(payload); + cq_verifier_destroy(cqv); +} + +char* g_buffer; +size_t g_cap = 0; +size_t g_count = 0; + +static void addbuf(const void* data, size_t len) { + if (g_count + len > g_cap) { + g_cap = std::max(g_count + len, g_cap * 2); + g_buffer = static_cast(gpr_realloc(g_buffer, g_cap)); + } + memcpy(g_buffer + g_count, data, len); + g_count += len; +} + +int main(int argc, char** argv) { + int i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + +#define NUM_FRAMES 10 +#define FRAME_SIZE 1000 + + addbuf(prefix, sizeof(prefix) - 1); + for (i = 0; i < NUM_FRAMES; i++) { + uint8_t hdr[9] = {static_cast(FRAME_SIZE >> 16), + static_cast(FRAME_SIZE >> 8), + static_cast(FRAME_SIZE), + 0, + 0, + 0, + 0, + 0, + 3}; + uint8_t msg[FRAME_SIZE]; + memset(msg, 'a', sizeof(msg)); + addbuf(hdr, sizeof(hdr)); + addbuf(msg, FRAME_SIZE); + } + grpc_bad_client_arg bca = {nullptr, nullptr, g_buffer, g_count}; + grpc_run_bad_client_test(verifier, &bca, 1, 0); + gpr_free(g_buffer); + grpc_shutdown(); + + return 0; +} diff --git a/test/core/bad_client/tests/headers.cc b/test/core/bad_client/tests/headers.cc new file mode 100644 index 00000000..b0795dd0 --- /dev/null +++ 
b/test/core/bad_client/tests/headers.cc @@ -0,0 +1,341 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" + +#define PFX_STR \ + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" \ + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* partial http2 header prefixes */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x01", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x01\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x01\x04", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x01\x05", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x01\x04\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x01\x04\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x01\x04\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x01\x04\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x01\x04\x00\x00\x00\x01", + GRPC_BAD_CLIENT_DISCONNECT); + + /* test adding prioritization data */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x01\x01\x24\x00\x00\x00\x01" + "\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x02\x01\x24\x00\x00\x00\x01" + "\x00\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x03\x01\x24\x00\x00\x00\x01" + "\x00\x00\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x04\x01\x24\x00\x00\x00\x01" + "\x00\x00\x00\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x01\x24\x00\x00\x00\x01" + "", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x01\x24\x00\x00\x00\x01" + "\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x01\x24\x00\x00\x00\x01" + "\x00\x00", + 
GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x01\x24\x00\x00\x00\x01" + "\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x01\x24\x00\x00\x00\x01" + "\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x01\x24\x00\x00\x00\x01" + "\x00\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + + /* test looking up an invalid index */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x01\x01\x04\x00\x00\x00\x01" + "\xfe", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x04\x01\x04\x00\x00\x00\x01" + "\x7f\x7f\x01" + "a", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x04\x01\x04\x00\x00\x00\x01" + "\x0f\x7f\x01" + "a", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x04\x01\x04\x00\x00\x00\x01" + "\x1f\x7f\x01" + "a", + 0); + /* test nvr, not indexed in static table */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x03\x01\x04\x00\x00\x00\x01" + "\x01\x01" + "a", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x03\x01\x04\x00\x00\x00\x01" + "\x11\x01" + "a", + GRPC_BAD_CLIENT_DISCONNECT); + /* illegal op code */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x01\x01\x04\x00\x00\x00\x01" + "\x80", + 0); + /* parse some long indices */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x02\x01\x04\x00\x00\x00\x01" + "\xff\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x03\x01\x04\x00\x00\x00\x01" + "\xff\x80\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x04\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x06\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x07\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80\x80\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80\x80", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80\x80\x80", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80\x80\x80\x00", + 0); + /* overflow on byte 4 */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + 
PFX_STR + "\x00\x00\x06\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80\x7f", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x06\x01\x04\x00\x00\x00\x01" + "\xff\xff\xff\xff\xff\x0f", + GRPC_BAD_CLIENT_DISCONNECT); + /* overflow after byte 4 */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x08\x01\x04\x00\x00\x00\x01" + "\xff\x80\x80\x80\x80\x80\x80\x02", + 0); + /* end of headers mid-opcode */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x01\x01\x04\x00\x00\x00\x01" + "\x01", + GRPC_BAD_CLIENT_DISCONNECT); + + /* dynamic table size update: set to default */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x03\x01\x04\x00\x00\x00\x01" + "\x3f\xe1\x1f", + GRPC_BAD_CLIENT_DISCONNECT); + /* dynamic table size update: set too large */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x03\x01\x04\x00\x00\x00\x01" + "\x3f\xf1\x1f", + 0); + /* dynamic table size update: set twice */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x04\x01\x04\x00\x00\x00\x01" + "\x20\x3f\xe1\x1f", + GRPC_BAD_CLIENT_DISCONNECT); + /* dynamic table size update: set thrice */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x03\x01\x04\x00\x00\x00\x01" + "\x20\x20\x20", + 0); + + /* non-ending header followed by continuation frame */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x00\x01\x00\x00\x00\x00\x01" + "\x00\x00\x00\x09\x04\x00\x00\x00\x01", + GRPC_BAD_CLIENT_DISCONNECT); + /* non-ending header followed by non-continuation frame */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x00\x01\x00\x00\x00\x00\x01" + "\x00\x00\x00\x00\x04\x00\x00\x00\x01", + 0); + /* non-ending header followed by a continuation frame for a different stream + */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x00\x01\x04\x00\x00\x00\x01" + "\x00\x00\x00\x01\x00\x00\x00\x00\x03" + "\x00\x00\x00\x09\x04\x00\x00\x00\x01", + 0); + /* opening with a continuation frame */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x09\x04\x00\x00\x00\x01", 0); + /* three header frames */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x00\x01\x04\x00\x00\x00\x01" + "\x00\x00\x00\x01\x04\x00\x00\x00\x01" + "\x00\x00\x00\x01\x04\x00\x00\x00\x01", + GRPC_BAD_CLIENT_DISCONNECT); + + /* an invalid header found with fuzzing */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x01\x39\x67\xed\x1d\x64", + GRPC_BAD_CLIENT_DISCONNECT); + + /* a badly encoded timeout value */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x19\x01\x04\x00\x00\x00\x01" + "\x10\x0cgrpc-timeout\x0a" + "15 seconds", + GRPC_BAD_CLIENT_DISCONNECT); + /* a badly encoded timeout value: twice (catches caching) */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x19\x01\x04\x00\x00\x00\x01" + "\x10\x0cgrpc-timeout\x0a" + "15 seconds" + "\x00\x00\x19\x01\x04\x00\x00\x00\x03" + "\x10\x0cgrpc-timeout\x0a" + "15 seconds", + GRPC_BAD_CLIENT_DISCONNECT); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/initial_settings_frame.cc b/test/core/bad_client/tests/initial_settings_frame.cc new file mode 100644 index 00000000..0deda702 --- /dev/null +++ b/test/core/bad_client/tests/initial_settings_frame.cc @@ -0,0 +1,112 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" + +#define PFX_STR "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" +#define ONE_SETTING_HDR "\x00\x00\x06\x04\x00\x00\x00\x00\x00" + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* various partial prefixes */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x06", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x06", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x06", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x04", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x04\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x04\x01", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR "\x00\x00\x00\x04\xff", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x04\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x04\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x04\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + /* must not send frames with stream id != 0 */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x04\x00\x00\x00\x00\x01", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x04\x00\x40\x00\x00\x00", 0); + /* settings frame must be a multiple of six bytes long */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x01\x04\x00\x00\x00\x00\x00", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x02\x04\x00\x00\x00\x00\x00", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x03\x04\x00\x00\x00\x00\x00", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x04\x04\x00\x00\x00\x00\x00", 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x05\x04\x00\x00\x00\x00\x00", 0); + /* some settings values are illegal */ + /* max frame size = 0 */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR ONE_SETTING_HDR 
"\x00\x05\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR ONE_SETTING_HDR "\x00\x06\xff\xff\xff\xff", + GRPC_BAD_CLIENT_DISCONNECT); + /* update intiial window size */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR ONE_SETTING_HDR "\x00\x04\x00\x01\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + /* ack with data */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" + "\x00\x00\x01\x04\x01\x00\x00\x00\x00", + 0); + /* settings frame with invalid flags */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x04\x10\x00\x00\x00\x00", 0); + /* unknown settings should be ignored */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR ONE_SETTING_HDR "\x00\x99\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/large_metadata.cc b/test/core/bad_client/tests/large_metadata.cc new file mode 100644 index 00000000..5b456172 --- /dev/null +++ b/test/core/bad_client/tests/large_metadata.cc @@ -0,0 +1,99 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" +#include "test/core/end2end/cq_verifier.h" + +// The large-metadata headers that we're adding for this test are not +// actually appended to this in a single string, since the string would +// be longer than the C99 string literal limit. Instead, we dynamically +// construct it by adding the large headers one at a time. + +/* headers: generated from large_metadata.headers in this directory */ +#define PFX_TOO_MUCH_METADATA_FROM_CLIENT_REQUEST \ + "\x00\x00\x00\x04\x01\x00\x00\x00\x00" \ + "\x00" \ + "5{\x01\x05\x00\x00\x00\x01" \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +// Each large-metadata header is constructed from these start and end +// strings, with a two-digit number in between. +#define PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_START_STR "\x10\x0duser-header" +#define PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_END_STR \ + "~aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" \ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + +// The size of each large-metadata header string. 
+#define PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_SIZE \ + ((sizeof(PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_START_STR) - 1) + 2 + \ + (sizeof(PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_END_STR) - 1)) + +// The number of headers we're adding and the total size of the client +// payload. +#define NUM_HEADERS 46 +#define TOO_MUCH_METADATA_FROM_CLIENT_REQUEST_SIZE \ + ((sizeof(PFX_TOO_MUCH_METADATA_FROM_CLIENT_REQUEST) - 1) + \ + (NUM_HEADERS * PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_SIZE) + 1) + +int main(int argc, char** argv) { + int i; + grpc_init(); + grpc::testing::TestEnvironment env(argc, argv); + + // Test sending more metadata than the server will accept. + std::vector headers; + for (i = 0; i < NUM_HEADERS; ++i) { + headers.push_back(absl::StrFormat( + "%s%02d%s", PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_START_STR, i, + PFX_TOO_MUCH_METADATA_FROM_CLIENT_HEADER_END_STR)); + } + std::string client_headers = absl::StrJoin(headers, ""); + char client_payload[TOO_MUCH_METADATA_FROM_CLIENT_REQUEST_SIZE] = + PFX_TOO_MUCH_METADATA_FROM_CLIENT_REQUEST; + memcpy(client_payload + sizeof(PFX_TOO_MUCH_METADATA_FROM_CLIENT_REQUEST) - 1, + client_headers.data(), client_headers.size()); + grpc_bad_client_arg args[2]; + args[0] = connection_preface_arg; + args[1].client_validator = rst_stream_client_validator; + args[1].client_payload = client_payload; + args[1].client_payload_length = sizeof(client_payload) - 1; + + grpc_run_bad_client_test(server_verifier_request_call, args, 2, 0); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/out_of_bounds.cc b/test/core/bad_client/tests/out_of_bounds.cc new file mode 100644 index 00000000..7bf1a914 --- /dev/null +++ b/test/core/bad_client/tests/out_of_bounds.cc @@ -0,0 +1,113 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include <string.h>
+
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "src/core/lib/surface/server.h"
+#include "test/core/bad_client/bad_client.h"
+
+#define APPEND_BUFFER(string, to_append) \
+  ((string).append((to_append), sizeof(to_append) - 1))
+
+namespace {
+
+void verifier(grpc_server* server, grpc_completion_queue* cq,
+              void* /*registered_method*/) {
+  while (server->core_server->HasOpenConnections()) {
+    GPR_ASSERT(grpc_completion_queue_next(
+                   cq, grpc_timeout_milliseconds_to_deadline(20), nullptr)
+                   .type == GRPC_QUEUE_TIMEOUT);
+  }
+}
+
+void FrameVerifier(const std::string& attack_vector) {
+  grpc_bad_client_arg args[2];
+  args[0] = connection_preface_arg;
+  args[1].client_validator = nullptr;
+  args[1].client_payload = attack_vector.c_str();
+  args[1].client_payload_length = attack_vector.size();
+  grpc_run_bad_client_test(verifier, args, 2, GRPC_BAD_CLIENT_DISCONNECT);
+}
+
+TEST(OutOfBounds, MaxFrameSizeDataFrame) {
+  std::string out_of_bounds_data;
+  // Send a data frame larger than the default max frame size of 2^14
+  APPEND_BUFFER(out_of_bounds_data, "\x01\x00\x00\x00\x00\x00\x00\x00\x01");
+  out_of_bounds_data.append(1 << 16, 'a');
+  FrameVerifier(out_of_bounds_data);
+}
+
+TEST(OutOfBounds, BadSizePriorityFrame) {
+  std::string bad_size_priority_frame;
+  // Priority Frame should have a length of 5 octets
+  APPEND_BUFFER(bad_size_priority_frame,
+                "\x00\x00\x03\x02\x00\x00\x00\x00\x01"
+                "\x11\x11\x12");
+  FrameVerifier(bad_size_priority_frame);
+}
+
+TEST(OutOfBounds, BadSizeRstStream) {
+  std::string bad_size_rst_stream;
+  // Rst Stream Frame should have a length of 4 octets
+  APPEND_BUFFER(bad_size_rst_stream,
+                "\x00\x00\x02\x03\x00\x00\x00\x00\x01"
+                "\x11\x11");
+  FrameVerifier(bad_size_rst_stream);
+}
+
+TEST(OutOfBounds, BadSizeSettings) {
+  std::string bad_size_settings;
+  // Settings Frame should have a length which is a multiple of 6 octets
+  APPEND_BUFFER(bad_size_settings,
+                "\x00\x00\x05\x04\x00\x00\x00\x00\x00"
+                "\x11\x11\x11\x11\x11");
+  FrameVerifier(bad_size_settings);
+}
+
+TEST(OutOfBounds, BadSizePing) {
+  std::string bad_size_ping;
+  // Ping Frame should have a length of 8 octets
+  APPEND_BUFFER(bad_size_ping,
+                "\x00\x00\x05\x06\x00\x00\x00\x00\x00"
+                "\x11\x11\x11\x11\x11");
+  FrameVerifier(bad_size_ping);
+}
+
+TEST(OutOfBounds, WindowUpdate) {
+  std::string bad_size_window_update;
+  // Window Update Frame should have a length of 4 octets
+  APPEND_BUFFER(bad_size_window_update,
+                "\x00\x00\x01\x08\x00\x00\x00\x00\x00"
+                "\x11");
+  FrameVerifier(bad_size_window_update);
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  grpc_init();
+  int retval = RUN_ALL_TESTS();
+  grpc_shutdown();
+  return retval;
+}
diff --git a/test/core/bad_client/tests/server_registered_method.cc b/test/core/bad_client/tests/server_registered_method.cc
new file mode 100644
index 00000000..e176f611
--- /dev/null
+++ b/test/core/bad_client/tests/server_registered_method.cc
@@ -0,0 +1,128 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" +#include "test/core/end2end/cq_verifier.h" + +#define PFX_STR \ + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" \ + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" /* settings frame */ \ + "\x00\x00\xd0\x01\x04\x00\x00\x00\x01" \ + "\x10\x05:path\x0f/registered/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void verifier_succeeds(grpc_server* server, grpc_completion_queue* cq, + void* registered_method) { + grpc_call_error error; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(cq); + grpc_metadata_array request_metadata_recv; + gpr_timespec deadline; + grpc_byte_buffer* payload = nullptr; + + grpc_metadata_array_init(&request_metadata_recv); + + error = grpc_server_request_registered_call(server, registered_method, &s, + &deadline, &request_metadata_recv, + &payload, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + GPR_ASSERT(payload != nullptr); + + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_unref(s); + grpc_byte_buffer_destroy(payload); + cq_verifier_destroy(cqv); +} + +static void verifier_fails(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* body generated with + * tools/codegen/core/gen_server_registered_method_bad_client_test_body.py */ + GRPC_RUN_BAD_CLIENT_TEST(verifier_fails, nullptr, + PFX_STR "\x00\x00\x00\x00\x00\x00\x00\x00\x01", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier_fails, nullptr, + PFX_STR "\x00\x00\x01\x00\x00\x00\x00\x00\x01\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier_fails, nullptr, + PFX_STR + "\x00\x00\x02\x00\x00\x00\x00\x00\x01\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST(verifier_fails, nullptr, + PFX_STR + "\x00\x00\x03\x00\x00\x00\x00\x00\x01\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST( + verifier_fails, nullptr, + PFX_STR "\x00\x00\x04\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST( + verifier_succeeds, nullptr, + PFX_STR "\x00\x00\x05\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x00", 0); + GRPC_RUN_BAD_CLIENT_TEST( + verifier_fails, nullptr, + PFX_STR "\x00\x00\x05\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x01", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST( + verifier_succeeds, nullptr, + PFX_STR "\x00\x00\x06\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x01\x00", + 0); + GRPC_RUN_BAD_CLIENT_TEST( + verifier_fails, nullptr, + PFX_STR "\x00\x00\x05\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x02", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST( + verifier_fails, nullptr, + PFX_STR 
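+          /* A 6-byte DATA frame on stream 1 without END_STREAM: its gRPC
+             length prefix announces a 2-byte message but only one byte
+             follows, so the call can never complete. */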
"\x00\x00\x06\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x02\x00", + GRPC_BAD_CLIENT_DISCONNECT); + GRPC_RUN_BAD_CLIENT_TEST( + verifier_succeeds, nullptr, + PFX_STR + "\x00\x00\x07\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x02\x00\x00", + 0); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/simple_request.cc b/test/core/bad_client/tests/simple_request.cc new file mode 100644 index 00000000..3c4bd84d --- /dev/null +++ b/test/core/bad_client/tests/simple_request.cc @@ -0,0 +1,172 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" +#include "test/core/end2end/cq_verifier.h" + +#define PFX_STR \ + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" \ + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" /* settings frame */ \ + "\x00\x00\xc9\x01\x04\x00\x00\x00\x01" /* headers: generated from \ + simple_request.headers in this \ + directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15" \ + "deflate,identity,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +#define PFX_STR_UNUSUAL \ + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" \ + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" /* settings frame */ \ + "\x00\x00\xf4\x01\x04\x00\x00\x00\x01" /* headers: generated from \ + simple_request_unusual.headers \ + in this directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x04host\x09localhost" \ + "\x10\x0c" \ + "content-type\x1e" \ + "application/grpc+this-is-valid" \ + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" \ + "\x10\x0cgrpc-timeout\x03" \ + "10S" \ + "\x10\x0cgrpc-timeout\x02" \ + "5S" + +#define PFX_STR_UNUSUAL2 \ + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" \ + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" /* settings frame */ \ + "\x00\x00\xf4\x01\x04\x00\x00\x00\x01" /* headers: generated from \ + simple_request_unusual2.headers \ + in this directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x04host\x09localhost" \ + "\x10\x0c" \ + "content-type\x1e" \ + "application/grpc;this-is-valid" \ + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" \ + "\x10\x0cgrpc-timeout\x03" \ + "10S" \ + "\x10\x0cgrpc-timeout\x02" \ + "5S" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + grpc_call_error error; + grpc_call* s; + grpc_call_details call_details; + cq_verifier* cqv = cq_verifier_create(cq); + 
grpc_metadata_array request_metadata_recv; + + grpc_call_details_init(&call_details); + grpc_metadata_array_init(&request_metadata_recv); + + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.host, "localhost")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo/bar")); + + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(s); + cq_verifier_destroy(cqv); +} + +static void failure_verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* basic request: check that things are working */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR, 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR_UNUSUAL, 0); + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, PFX_STR_UNUSUAL2, 0); + + /* push an illegal data frame */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR + "\x00\x00\x05\x00\x00\x00\x00\x00\x01" + "\x34\x00\x00\x00\x00", + 0); + + /* push a data frame with bad flags */ + GRPC_RUN_BAD_CLIENT_TEST(verifier, nullptr, + PFX_STR "\x00\x00\x00\x00\x02\x00\x00\x00\x01", 0); + /* push a window update with a bad length */ + GRPC_RUN_BAD_CLIENT_TEST(failure_verifier, nullptr, + PFX_STR "\x00\x00\x01\x08\x00\x00\x00\x00\x01", 0); + /* push a window update with bad flags */ + GRPC_RUN_BAD_CLIENT_TEST(failure_verifier, nullptr, + PFX_STR "\x00\x00\x00\x08\x10\x00\x00\x00\x01", 0); + /* push a window update with bad data (0 is not legal window size increment) + */ + GRPC_RUN_BAD_CLIENT_TEST(failure_verifier, nullptr, + PFX_STR + "\x00\x00\x04\x08\x00\x00\x00\x00\x01" + "\x00\x00\x00\x00", + 0); + /* push a short goaway */ + GRPC_RUN_BAD_CLIENT_TEST(failure_verifier, nullptr, + PFX_STR "\x00\x00\x04\x07\x00\x00\x00\x00\x00", 0); + /* disconnect before sending goaway */ + GRPC_RUN_BAD_CLIENT_TEST(failure_verifier, nullptr, + PFX_STR "\x00\x01\x12\x07\x00\x00\x00\x00\x00", + GRPC_BAD_CLIENT_DISCONNECT); + /* push a rst_stream with a bad length */ + GRPC_RUN_BAD_CLIENT_TEST(failure_verifier, nullptr, + PFX_STR "\x00\x00\x01\x03\x00\x00\x00\x00\x01", 0); + /* push a rst_stream with bad flags */ + GRPC_RUN_BAD_CLIENT_TEST(failure_verifier, nullptr, + PFX_STR "\x00\x00\x00\x03\x10\x00\x00\x00\x01", 0); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/bad_client/tests/unknown_frame.cc b/test/core/bad_client/tests/unknown_frame.cc new file mode 100644 index 00000000..e11625bf --- /dev/null +++ b/test/core/bad_client/tests/unknown_frame.cc @@ -0,0 +1,66 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +#define APPEND_BUFFER(string, to_append) \ + ((string).append((to_append), sizeof(to_append) - 1)) + +namespace { +TEST(UnknownFrameType, Test) { + /* test that all invalid/unknown frame types are handled */ + for (int i = 10; i <= 255; i++) { + std::string unknown_frame_string; + APPEND_BUFFER(unknown_frame_string, "\x00\x00\x00"); + char frame_type = static_cast(i); + unknown_frame_string.append(&frame_type, 1); + APPEND_BUFFER(unknown_frame_string, "\x00\x00\x00\x00\x01"); + grpc_bad_client_arg args[2]; + args[0] = connection_preface_arg; + args[1].client_validator = nullptr; + args[1].client_payload = unknown_frame_string.c_str(); + args[1].client_payload_length = unknown_frame_string.size(); + grpc_run_bad_client_test(verifier, args, 2, GRPC_BAD_CLIENT_DISCONNECT); + } +} +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int retval = RUN_ALL_TESTS(); + grpc_shutdown(); + return retval; +} diff --git a/test/core/bad_client/tests/window_overflow.cc b/test/core/bad_client/tests/window_overflow.cc new file mode 100644 index 00000000..81fa27e2 --- /dev/null +++ b/test/core/bad_client/tests/window_overflow.cc @@ -0,0 +1,102 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include "src/core/lib/surface/server.h" +#include "test/core/bad_client/bad_client.h" + +#define PFX_STR \ + "\x00\x00\x00\x04\x01\x00\x00\x00\x00" \ + "\x00\x00\xc9\x01\x04\x00\x00\x00\x01" /* headers: generated from \ + simple_request.headers in this \ + directory */ \ + "\x10\x05:path\x08/foo/bar" \ + "\x10\x07:scheme\x04http" \ + "\x10\x07:method\x04POST" \ + "\x10\x0a:authority\x09localhost" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x14grpc-accept-encoding\x15" \ + "deflate,identity,gzip" \ + "\x10\x02te\x08trailers" \ + "\x10\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)" + +static void verifier(grpc_server* server, grpc_completion_queue* cq, + void* /*registered_method*/) { + while (server->core_server->HasOpenConnections()) { + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(20), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + } +} + +char* g_buffer; +size_t g_cap = 0; +size_t g_count = 0; + +static void addbuf(const void* data, size_t len) { + if (g_count + len > g_cap) { + g_cap = std::max(g_count + len, g_cap * 2); + g_buffer = static_cast(gpr_realloc(g_buffer, g_cap)); + } + memcpy(g_buffer + g_count, data, len); + g_count += len; +} + +int main(int argc, char** argv) { + int i, j; +#define MAX_FRAME_SIZE 16384 +#define MESSAGES_PER_FRAME (MAX_FRAME_SIZE / 5) +#define FRAME_SIZE (MESSAGES_PER_FRAME * 5) +#define SEND_SIZE (4 * 1024 * 1024) +#define NUM_FRAMES (SEND_SIZE / FRAME_SIZE + 1) + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + addbuf(PFX_STR, sizeof(PFX_STR) - 1); + for (i = 0; i < NUM_FRAMES; i++) { + uint8_t hdr[9] = {static_cast(FRAME_SIZE >> 16), + static_cast(FRAME_SIZE >> 8), + static_cast + FRAME_SIZE, + 0, + 0, + 0, + 0, + 0, + 1}; + addbuf(hdr, sizeof(hdr)); + for (j = 0; j < MESSAGES_PER_FRAME; j++) { + uint8_t message[5] = {0, 0, 0, 0, 0}; + addbuf(message, sizeof(message)); + } + } + grpc_bad_client_arg bca[2]; + bca[0] = connection_preface_arg; + bca[1] = {rst_stream_client_validator, nullptr, g_buffer, g_count}; + grpc_run_bad_client_test(verifier, bca, 2, GRPC_BAD_CLIENT_LARGE_REQUEST); + gpr_free(g_buffer); + grpc_shutdown(); + + return 0; +} diff --git a/test/core/bad_connection/close_fd_test.cc b/test/core/bad_connection/close_fd_test.cc new file mode 100644 index 00000000..5d8e19f5 --- /dev/null +++ b/test/core/bad_connection/close_fd_test.cc @@ -0,0 +1,764 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * close_fd_test tests the behavior of grpc core when the transport gets + * disconnected. + * The test creates an http2 transport over a socket pair and closes the + * client or server file descriptor to simulate connection breakage while + * an RPC call is in progress. 
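+ * Depending on the test, the client either observes GRPC_STATUS_UNAVAILABLE
+ * or times out waiting on its completion queue, and server-side batches may
+ * complete with or without success once the descriptor is gone.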
+ * + */ +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_TCP + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +typedef struct test_ctx test_ctx; + +struct test_ctx { + /* completion queue for call notifications on the server */ + grpc_completion_queue* cq; + /* completion queue registered to server for shutdown events */ + grpc_completion_queue* shutdown_cq; + /* client's completion queue */ + grpc_completion_queue* client_cq; + /* completion queue bound to call on the server */ + grpc_completion_queue* bound_cq; + /* Server responds to client calls */ + grpc_server* server; + /* Client calls are sent over the channel */ + grpc_channel* client; + /* encapsulates client, server endpoints */ + grpc_endpoint_pair* ep; +}; + +static test_ctx g_ctx; + +/* chttp2 transport that is immediately available (used for testing + connected_channel without a client_channel */ + +static void server_setup_transport(grpc_transport* transport) { + grpc_core::ExecCtx exec_ctx; + grpc_endpoint_add_to_pollset(g_ctx.ep->server, grpc_cq_pollset(g_ctx.cq)); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "SetupTransport", + g_ctx.server->core_server->SetupTransport( + transport, nullptr, g_ctx.server->core_server->channel_args(), + nullptr))); +} + +static void client_setup_transport(grpc_transport* transport) { + grpc_core::ExecCtx exec_ctx; + grpc_endpoint_add_to_pollset(g_ctx.ep->client, + grpc_cq_pollset(g_ctx.client_cq)); + grpc_arg authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test-authority")); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(nullptr, &authority_arg, 1); + /* TODO (pjaikumar): use GRPC_CLIENT_CHANNEL instead of + * GRPC_CLIENT_DIRECT_CHANNEL */ + g_ctx.client = + grpc_channel_create("socketpair-target", args, GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, nullptr); + grpc_channel_args_destroy(args); +} + +static void init_client() { + grpc_core::ExecCtx exec_ctx; + grpc_transport* transport; + transport = grpc_create_chttp2_transport( + nullptr, g_ctx.ep->client, true, grpc_resource_user_create_unlimited()); + client_setup_transport(transport); + GPR_ASSERT(g_ctx.client); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); +} + +static void init_server() { + grpc_core::ExecCtx exec_ctx; + grpc_transport* transport; + GPR_ASSERT(!g_ctx.server); + g_ctx.server = grpc_server_create(nullptr, nullptr); + grpc_server_register_completion_queue(g_ctx.server, g_ctx.cq, nullptr); + grpc_server_start(g_ctx.server); + transport = grpc_create_chttp2_transport( + nullptr, g_ctx.ep->server, false, grpc_resource_user_create_unlimited()); + server_setup_transport(transport); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); +} + +static void test_init() { + grpc_endpoint_pair* sfd = + static_cast(gpr_malloc(sizeof(grpc_endpoint_pair))); + memset(&g_ctx, 0, sizeof(g_ctx)); + g_ctx.ep = sfd; + g_ctx.cq = grpc_completion_queue_create_for_next(nullptr); + g_ctx.shutdown_cq = 
grpc_completion_queue_create_for_pluck(nullptr); + g_ctx.bound_cq = grpc_completion_queue_create_for_next(nullptr); + g_ctx.client_cq = grpc_completion_queue_create_for_next(nullptr); + + /* Create endpoints */ + *sfd = grpc_iomgr_create_endpoint_pair("fixture", nullptr); + /* Create client, server and setup transport over endpoint pair */ + init_server(); + init_client(); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event event; + do { + event = grpc_completion_queue_next(cq, grpc_timeout_seconds_to_deadline(1), + nullptr); + } while (event.type != GRPC_QUEUE_SHUTDOWN); +} + +static void drain_and_destroy_cq(grpc_completion_queue* cq) { + grpc_completion_queue_shutdown(cq); + drain_cq(cq); + grpc_completion_queue_destroy(cq); +} + +static void shutdown_server() { + if (!g_ctx.server) return; + grpc_server_shutdown_and_notify(g_ctx.server, g_ctx.shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(g_ctx.shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(1), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(g_ctx.server); + g_ctx.server = nullptr; +} + +static void shutdown_client() { + if (!g_ctx.client) return; + grpc_channel_destroy(g_ctx.client); + g_ctx.client = nullptr; +} + +static void end_test() { + shutdown_server(); + shutdown_client(); + + drain_and_destroy_cq(g_ctx.cq); + drain_and_destroy_cq(g_ctx.client_cq); + drain_and_destroy_cq(g_ctx.bound_cq); + grpc_completion_queue_destroy(g_ctx.shutdown_cq); + gpr_free(g_ctx.ep); +} + +typedef enum fd_type { CLIENT_FD, SERVER_FD } fd_type; + +static const char* fd_type_str(fd_type fdtype) { + if (fdtype == CLIENT_FD) { + return "client"; + } else if (fdtype == SERVER_FD) { + return "server"; + } else { + gpr_log(GPR_ERROR, "Unexpected fd_type %d", fdtype); + abort(); + } +} + +static void _test_close_before_server_recv(fd_type fdtype) { + grpc_core::ExecCtx exec_ctx; + grpc_call* call; + grpc_call* server_call; + grpc_event event; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + gpr_log(GPR_INFO, "Running test: test_close_%s_before_server_recv", + fd_type_str(fdtype)); + test_init(); + + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status = GRPC_STATUS__DO_NOT_USE; + grpc_call_error error; + grpc_slice details; + + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(1); + call = grpc_channel_create_call( + g_ctx.client, nullptr, GRPC_PROPAGATE_DEFAULTS, g_ctx.client_cq, + grpc_slice_from_static_string("/foo"), nullptr, deadline, nullptr); + GPR_ASSERT(call); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + 
op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(call, ops, static_cast(op - ops), + tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = grpc_server_request_call(g_ctx.server, &server_call, &call_details, + &request_metadata_recv, g_ctx.bound_cq, + g_ctx.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + event = grpc_completion_queue_next( + g_ctx.cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.success == 1); + GPR_ASSERT(event.tag == tag(101)); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + + grpc_endpoint_pair* sfd = g_ctx.ep; + int fd; + if (fdtype == SERVER_FD) { + fd = sfd->server->vtable->get_fd(sfd->server); + } else { + GPR_ASSERT(fdtype == CLIENT_FD); + fd = sfd->client->vtable->get_fd(sfd->client); + } + /* Connection is closed before the server receives the client's message. */ + close(fd); + + error = grpc_call_start_batch(server_call, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + event = grpc_completion_queue_next( + g_ctx.bound_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + + /* Batch operation completes on the server side. + * event.success will be true if the op completes successfully. + * event.success will be false if the op completes with an error. This can + * happen due to a race with closing the fd resulting in pending writes + * failing due to stream closure. + * */ + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.tag == tag(102)); + + event = grpc_completion_queue_next( + g_ctx.client_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + /* When the client fd is closed, the server gets EPIPE. + * When server fd is closed, server gets EBADF. + * In both cases server sends GRPC_STATUS_UNAVALABLE to the client. However, + * the client may not receive this grpc_status as it's socket is being closed. + * If the client didn't get grpc_status from the server it will time out + * waiting on the completion queue. So there 2 2 possibilities: + * 1. client times out waiting for server's response + * 2. 
client receives GRPC_STATUS_UNAVAILABLE from server + */ + if (event.type == GRPC_QUEUE_TIMEOUT) { + GPR_ASSERT(event.success == 0); + /* status is not initialized */ + GPR_ASSERT(status == GRPC_STATUS__DO_NOT_USE); + } else { + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success == 1); + GPR_ASSERT(event.tag == tag(1)); + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + } + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(call); + grpc_call_unref(server_call); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(); +} + +static void test_close_before_server_recv() { + /* Close client side of the connection before server receives message from + * client */ + _test_close_before_server_recv(CLIENT_FD); + /* Close server side of the connection before server receives message from + * client */ + _test_close_before_server_recv(SERVER_FD); +} + +static void _test_close_before_server_send(fd_type fdtype) { + grpc_core::ExecCtx exec_ctx; + grpc_call* call; + grpc_call* server_call; + grpc_event event; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + gpr_log(GPR_INFO, "Running test: test_close_%s_before_server_send", + fd_type_str(fdtype)); + test_init(); + + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status = GRPC_STATUS__DO_NOT_USE; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(1); + call = grpc_channel_create_call( + g_ctx.client, nullptr, GRPC_PROPAGATE_DEFAULTS, g_ctx.client_cq, + grpc_slice_from_static_string("/foo"), nullptr, deadline, nullptr); + GPR_ASSERT(call); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + 
op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(call, ops, static_cast(op - ops), + tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = grpc_server_request_call(g_ctx.server, &server_call, &call_details, + &request_metadata_recv, g_ctx.bound_cq, + g_ctx.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + event = grpc_completion_queue_next( + g_ctx.cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.success == 1); + GPR_ASSERT(event.tag == tag(101)); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(server_call, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + event = grpc_completion_queue_next( + g_ctx.bound_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success == 1); + GPR_ASSERT(event.tag == tag(102)); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + + grpc_endpoint_pair* sfd = g_ctx.ep; + int fd; + if (fdtype == SERVER_FD) { + fd = sfd->server->vtable->get_fd(sfd->server); + } else { + GPR_ASSERT(fdtype == CLIENT_FD); + fd = sfd->client->vtable->get_fd(sfd->client); + } + + /* Connection is closed before the server sends message and status to the + * client. */ + close(fd); + error = grpc_call_start_batch(server_call, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* Batch operation succeeds on the server side */ + event = grpc_completion_queue_next( + g_ctx.bound_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success == 1); + GPR_ASSERT(event.tag == tag(103)); + + event = grpc_completion_queue_next( + g_ctx.client_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + /* In both cases server sends GRPC_STATUS_UNAVALABLE to the client. However, + * the client may not receive this grpc_status as it's socket is being closed. 
+ * If the client didn't get grpc_status from the server it will time out + * waiting on the completion queue + */ + if (event.type == GRPC_OP_COMPLETE) { + GPR_ASSERT(event.success == 1); + GPR_ASSERT(event.tag == tag(1)); + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + } else { + GPR_ASSERT(event.type == GRPC_QUEUE_TIMEOUT); + GPR_ASSERT(event.success == 0); + /* status is not initialized */ + GPR_ASSERT(status == GRPC_STATUS__DO_NOT_USE); + } + GPR_ASSERT(was_cancelled == 0); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(call); + grpc_call_unref(server_call); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(); +} + +static void test_close_before_server_send() { + /* Close client side of the connection before server sends message to client + * */ + _test_close_before_server_send(CLIENT_FD); + /* Close server side of the connection before server sends message to client + * */ + _test_close_before_server_send(SERVER_FD); +} + +static void _test_close_before_client_send(fd_type fdtype) { + grpc_core::ExecCtx exec_ctx; + grpc_call* call; + grpc_event event; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + gpr_log(GPR_INFO, "Running test: test_close_%s_before_client_send", + fd_type_str(fdtype)); + test_init(); + + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(1); + call = grpc_channel_create_call( + g_ctx.client, nullptr, GRPC_PROPAGATE_DEFAULTS, g_ctx.client_cq, + grpc_slice_from_static_string("/foo"), nullptr, deadline, nullptr); + GPR_ASSERT(call); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + 
op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + + grpc_endpoint_pair* sfd = g_ctx.ep; + int fd; + if (fdtype == SERVER_FD) { + fd = sfd->server->vtable->get_fd(sfd->server); + } else { + GPR_ASSERT(fdtype == CLIENT_FD); + fd = sfd->client->vtable->get_fd(sfd->client); + } + /* Connection is closed before the client sends a batch to the server */ + close(fd); + + error = grpc_call_start_batch(call, ops, static_cast(op - ops), + tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* Status unavailable is returned to the client when client or server fd is + * closed */ + event = grpc_completion_queue_next( + g_ctx.client_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.success == 1); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.tag == tag(1)); + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + + /* No event is received on the server */ + event = grpc_completion_queue_next( + g_ctx.cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.success == 0); + GPR_ASSERT(event.type == GRPC_QUEUE_TIMEOUT); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(call); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(); +} +static void test_close_before_client_send() { + /* Close client side of the connection before client sends message to server + * */ + _test_close_before_client_send(CLIENT_FD); + /* Close server side of the connection before client sends message to server + * */ + _test_close_before_client_send(SERVER_FD); +} + +static void _test_close_before_call_create(fd_type fdtype) { + grpc_core::ExecCtx exec_ctx; + grpc_call* call; + grpc_event event; + test_init(); + + gpr_timespec deadline = grpc_timeout_milliseconds_to_deadline(100); + + grpc_endpoint_pair* sfd = g_ctx.ep; + int fd; + if (fdtype == SERVER_FD) { + fd = sfd->server->vtable->get_fd(sfd->server); + } else { + GPR_ASSERT(fdtype == CLIENT_FD); + fd = sfd->client->vtable->get_fd(sfd->client); + } + /* Connection is closed before the client creates a call */ + close(fd); + + call = grpc_channel_create_call( + g_ctx.client, nullptr, GRPC_PROPAGATE_DEFAULTS, g_ctx.client_cq, + grpc_slice_from_static_string("/foo"), nullptr, deadline, nullptr); + GPR_ASSERT(call); + + /* Client and server time out waiting on their completion queues and nothing + * is sent or received */ + event = grpc_completion_queue_next( + g_ctx.client_cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.type == GRPC_QUEUE_TIMEOUT); + GPR_ASSERT(event.success == 0); + + event = grpc_completion_queue_next( + g_ctx.cq, grpc_timeout_milliseconds_to_deadline(100), nullptr); + GPR_ASSERT(event.type == GRPC_QUEUE_TIMEOUT); + GPR_ASSERT(event.success == 0); + + grpc_call_unref(call); + end_test(); +} + +static void test_close_before_call_create() { + /* Close client side of the connection before client creates a call */ + _test_close_before_call_create(CLIENT_FD); + /* Close server side of the connection before client creates a 
call */ + _test_close_before_call_create(SERVER_FD); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + /* Init grpc */ + grpc_init(); + int iterations = 10; + + for (int i = 0; i < iterations; ++i) { + test_close_before_call_create(); + test_close_before_client_send(); + test_close_before_server_recv(); + test_close_before_server_send(); + } + + grpc_shutdown(); + + return 0; +} + +#else /* GRPC_POSIX_SOCKET_TCP */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_TCP */ diff --git a/test/core/bad_ssl/bad_ssl_test.cc b/test/core/bad_ssl/bad_ssl_test.cc new file mode 100644 index 00000000..8fbf52b5 --- /dev/null +++ b/test/core/bad_ssl/bad_ssl_test.cc @@ -0,0 +1,163 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/port.h" +#include "test/core/util/subprocess.h" +#include "test/core/util/test_config.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void run_test(const char* target, size_t nops) { + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); + grpc_channel* channel; + grpc_call* c; + + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_slice details; + grpc_status_code status; + grpc_call_error error; + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5); + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + cq_verifier* cqv = cq_verifier_create(cq); + + grpc_op ops[6]; + grpc_op* op; + + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args args; + + args.num_args = 1; + args.args = &ssl_name_override; + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + channel = grpc_secure_channel_create(ssl_creds, target, &args, nullptr); + grpc_slice host = grpc_slice_from_static_string("foo.test.google.fr:1234"); + c = grpc_channel_create_call(channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), &host, + deadline, nullptr); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + 
op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, nops, tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status != GRPC_STATUS_OK); + + grpc_call_unref(c); + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + + grpc_channel_destroy(channel); + grpc_completion_queue_destroy(cq); + cq_verifier_destroy(cqv); + grpc_channel_credentials_release(ssl_creds); +} + +int main(int argc, char** argv) { + char* me = argv[0]; + char* lslash = strrchr(me, '/'); + char* lunder = strrchr(me, '_'); + char* tmp; + char root[1024]; + char test[64]; + int port = grpc_pick_unused_port_or_die(); + char* args[10]; + int status; + size_t i; + gpr_subprocess* svr; + /* figure out where we are */ + if (lslash) { + memcpy(root, me, static_cast(lslash - me)); + root[lslash - me] = 0; + } else { + strcpy(root, "."); + } + if (argc == 2) { + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, argv[1]); + } + /* figure out our test name */ + tmp = lunder - 1; + while (*tmp != '_') tmp--; + tmp++; + memcpy(test, tmp, static_cast(lunder - tmp)); + /* start the server */ + gpr_asprintf(&args[0], "%s/bad_ssl_%s_server%s", root, test, + gpr_subprocess_binary_extension()); + args[1] = const_cast("--bind"); + std::string joined = grpc_core::JoinHostPort("::", port); + args[2] = const_cast(joined.c_str()); + svr = gpr_subprocess_create(4, const_cast(args)); + gpr_free(args[0]); + + for (i = 3; i <= 4; i++) { + grpc_init(); + run_test(args[2], i); + grpc_shutdown(); + } + + gpr_subprocess_interrupt(svr); + status = gpr_subprocess_join(svr); + gpr_subprocess_destroy(svr); + return status; +} diff --git a/test/core/bad_ssl/server_common.cc b/test/core/bad_ssl/server_common.cc new file mode 100644 index 00000000..d3e1a341 --- /dev/null +++ b/test/core/bad_ssl/server_common.cc @@ -0,0 +1,107 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/bad_ssl/server_common.h" + +#include + +#include + +#include "test/core/util/cmdline.h" +#include "test/core/util/test_config.h" + +/* Common server implementation details for all servers in servers/. + * There's nothing *wrong* with these servers per-se, but they are + * configured to cause some failure case in the SSL connection path. 
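+ * bad_ssl_run() below registers a single call request that is never expected
+ * to complete successfully (the handshake should fail first) and keeps
+ * serving until the process receives SIGINT.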
+ */ + +static int got_sigint = 0; + +static void sigint_handler(int /*x*/) { got_sigint = 1; } + +const char* bad_ssl_addr(int argc, char** argv) { + gpr_cmdline* cl; + const char* addr = nullptr; + cl = gpr_cmdline_create("test server"); + gpr_cmdline_add_string(cl, "bind", "Bind host:port", &addr); + gpr_cmdline_parse(cl, argc, argv); + gpr_cmdline_destroy(cl); + GPR_ASSERT(addr); + return addr; +} + +void bad_ssl_run(grpc_server* server) { + int shutdown_started = 0; + int shutdown_finished = 0; + grpc_event ev; + grpc_call_error error; + grpc_call* s = nullptr; + grpc_call_details call_details; + grpc_metadata_array request_metadata_recv; + + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_completion_queue* shutdown_cq; + + grpc_call_details_init(&call_details); + grpc_metadata_array_init(&request_metadata_recv); + + grpc_server_register_completion_queue(server, cq, nullptr); + grpc_server_start(server); + + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, + reinterpret_cast(1)); + GPR_ASSERT(GRPC_CALL_OK == error); + + signal(SIGINT, sigint_handler); + while (!shutdown_finished) { + if (got_sigint && !shutdown_started) { + gpr_log(GPR_INFO, "Shutting down due to SIGINT"); + shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + grpc_server_shutdown_and_notify(server, shutdown_cq, nullptr); + GPR_ASSERT(grpc_completion_queue_pluck( + shutdown_cq, nullptr, grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_completion_queue_destroy(shutdown_cq); + grpc_completion_queue_shutdown(cq); + shutdown_started = 1; + } + ev = grpc_completion_queue_next( + cq, + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1000000, GPR_TIMESPAN)), + nullptr); + switch (ev.type) { + case GRPC_OP_COMPLETE: + GPR_ASSERT(ev.tag == (void*)1); + GPR_ASSERT(ev.success == 0); + break; + case GRPC_QUEUE_SHUTDOWN: + GPR_ASSERT(shutdown_started); + shutdown_finished = 1; + break; + case GRPC_QUEUE_TIMEOUT: + break; + } + } + + GPR_ASSERT(s == nullptr); + grpc_call_details_destroy(&call_details); + grpc_metadata_array_destroy(&request_metadata_recv); +} diff --git a/test/core/bad_ssl/servers/alpn.cc b/test/core/bad_ssl/servers/alpn.cc new file mode 100644 index 00000000..27a9aa46 --- /dev/null +++ b/test/core/bad_ssl/servers/alpn.cc @@ -0,0 +1,86 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/ext/transport/chttp2/alpn/alpn.h" + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/bad_ssl/server_common.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +/* This test starts a server that is configured to advertise (via alpn and npn) + * a protocol that the connecting client does not support. It does this by + * overriding the functions declared in alpn.c from the core library. */ + +static const char* const fake_versions[] = {"not-h2"}; + +int grpc_chttp2_is_alpn_version_supported(const char* version, size_t size) { + size_t i; + for (i = 0; i < GPR_ARRAY_SIZE(fake_versions); i++) { + if (!strncmp(version, fake_versions[i], size)) return 1; + } + return 0; +} + +size_t grpc_chttp2_num_alpn_versions(void) { + return GPR_ARRAY_SIZE(fake_versions); +} + +const char* grpc_chttp2_get_alpn_version_index(size_t i) { + GPR_ASSERT(i < GPR_ARRAY_SIZE(fake_versions)); + return fake_versions[i]; +} + +int main(int argc, char** argv) { + const char* addr = bad_ssl_addr(argc, argv); + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + grpc_server_credentials* ssl_creds; + grpc_server* server; + + grpc_init(); + ssl_creds = grpc_ssl_server_credentials_create(nullptr, &pem_key_cert_pair, 1, + 0, nullptr); + server = grpc_server_create(nullptr, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port(server, addr, ssl_creds)); + grpc_server_credentials_release(ssl_creds); + + bad_ssl_run(server); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_shutdown(); + + return 0; +} diff --git a/test/core/bad_ssl/servers/cert.cc b/test/core/bad_ssl/servers/cert.cc new file mode 100644 index 00000000..7b95eb34 --- /dev/null +++ b/test/core/bad_ssl/servers/cert.cc @@ -0,0 +1,64 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/bad_ssl/server_common.h" + +/* This server will present an untrusted cert to the connecting client, + * causing the SSL handshake to fail */ + +int main(int argc, char** argv) { + const char* addr = bad_ssl_addr(argc, argv); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair; + grpc_server_credentials* ssl_creds; + grpc_server* server; + grpc_slice cert_slice, key_slice; + + grpc_init(); + + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", + grpc_load_file("src/core/tsi/test_creds/badserver.pem", 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", + grpc_load_file("src/core/tsi/test_creds/badserver.key", 1, &key_slice))); + pem_key_cert_pair.private_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + pem_key_cert_pair.cert_chain = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + + ssl_creds = grpc_ssl_server_credentials_create(nullptr, &pem_key_cert_pair, 1, + 0, nullptr); + server = grpc_server_create(nullptr, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port(server, addr, ssl_creds)); + grpc_server_credentials_release(ssl_creds); + + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + + bad_ssl_run(server); + grpc_shutdown(); + + return 0; +} diff --git a/test/core/channel/channel_args_test.cc b/test/core/channel/channel_args_test.cc new file mode 100644 index 00000000..d683f515 --- /dev/null +++ b/test/core/channel/channel_args_test.cc @@ -0,0 +1,216 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/channel/channel_args.h" + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/util/test_config.h" + +static void test_create(void) { + grpc_core::ExecCtx exec_ctx; + grpc_arg to_add[2]; + grpc_channel_args* ch_args; + + to_add[0] = + grpc_channel_arg_integer_create(const_cast("int_arg"), 123); + to_add[1] = grpc_channel_arg_string_create(const_cast("str key"), + const_cast("str value")); + ch_args = grpc_channel_args_copy_and_add(nullptr, to_add, 2); + + GPR_ASSERT(ch_args->num_args == 2); + GPR_ASSERT(strcmp(ch_args->args[0].key, to_add[0].key) == 0); + GPR_ASSERT(ch_args->args[0].type == to_add[0].type); + GPR_ASSERT(ch_args->args[0].value.integer == to_add[0].value.integer); + + GPR_ASSERT(strcmp(ch_args->args[1].key, to_add[1].key) == 0); + GPR_ASSERT(ch_args->args[1].type == to_add[1].type); + GPR_ASSERT(strcmp(ch_args->args[1].value.string, to_add[1].value.string) == + 0); + + grpc_channel_args_destroy(ch_args); +} + +struct fake_class { + int foo; +}; + +static void* fake_pointer_arg_copy(void* arg) { + gpr_log(GPR_DEBUG, "fake_pointer_arg_copy"); + fake_class* fc = static_cast(arg); + fake_class* new_fc = static_cast(gpr_malloc(sizeof(fake_class))); + new_fc->foo = fc->foo; + return new_fc; +} + +static void fake_pointer_arg_destroy(void* arg) { + gpr_log(GPR_DEBUG, "fake_pointer_arg_destroy"); + fake_class* fc = static_cast(arg); + gpr_free(fc); +} + +static int fake_pointer_cmp(void* a, void* b) { + return grpc_core::QsortCompare(a, b); +} + +static const grpc_arg_pointer_vtable fake_pointer_arg_vtable = { + fake_pointer_arg_copy, fake_pointer_arg_destroy, fake_pointer_cmp}; + +static void test_channel_create_with_args(void) { + grpc_arg client_a[3]; + + client_a[0] = + grpc_channel_arg_integer_create(const_cast("arg_int"), 0); + client_a[1] = grpc_channel_arg_string_create( + const_cast("arg_str"), const_cast("arg_str_val")); + // allocated and adds custom pointer arg + fake_class* fc = static_cast(gpr_malloc(sizeof(fake_class))); + fc->foo = 42; + client_a[2] = grpc_channel_arg_pointer_create( + const_cast("arg_pointer"), fc, &fake_pointer_arg_vtable); + + // creates channel + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + grpc_channel* c = + grpc_insecure_channel_create("fake_target", &client_args, nullptr); + // user is can free the memory they allocated here + gpr_free(fc); + grpc_channel_destroy(c); +} + +grpc_channel_args* mutate_channel_args(const char* target, + grpc_channel_args* old_args, + grpc_channel_stack_type /*type*/) { + GPR_ASSERT(old_args != nullptr); + GPR_ASSERT(grpc_channel_args_find(old_args, "arg_int")->value.integer == 0); + GPR_ASSERT(strcmp(grpc_channel_args_find(old_args, "arg_str")->value.string, + "arg_str_val") == 0); + GPR_ASSERT( + grpc_channel_args_find(old_args, "arg_pointer")->value.pointer.vtable == + &fake_pointer_arg_vtable); + + if (strcmp(target, "no_op_mutator") == 0) { + return old_args; + } + + GPR_ASSERT(strcmp(target, "minimal_stack_mutator") == 0); + const char* args_to_remove[] = {"arg_int", "arg_str", "arg_pointer"}; + + grpc_arg no_deadline_filter_arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MINIMAL_STACK), 1); + grpc_channel_args* new_args = nullptr; + new_args = grpc_channel_args_copy_and_add_and_remove( + old_args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), + 
&no_deadline_filter_arg, 1); + grpc_channel_args_destroy(old_args); + return new_args; +} + +// Minimal stack should not have client_idle filter +static bool channel_has_client_idle_filter(grpc_channel* c) { + grpc_channel_stack* stack = grpc_channel_get_channel_stack(c); + for (size_t i = 0; i < stack->count; i++) { + if (strcmp(grpc_channel_stack_element(stack, i)->filter->name, + "client_idle") == 0) { + return true; + } + } + return false; +} + +static void test_channel_create_with_global_mutator(void) { + grpc_channel_args_set_client_channel_creation_mutator(mutate_channel_args); + // We also add some custom args to make sure the ownership is correct. + grpc_arg client_a[3]; + + client_a[0] = + grpc_channel_arg_integer_create(const_cast("arg_int"), 0); + client_a[1] = grpc_channel_arg_string_create( + const_cast("arg_str"), const_cast("arg_str_val")); + // allocated and adds custom pointer arg + fake_class* fc = static_cast(gpr_malloc(sizeof(fake_class))); + fc->foo = 42; + client_a[2] = grpc_channel_arg_pointer_create( + const_cast("arg_pointer"), fc, &fake_pointer_arg_vtable); + + // creates channels + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + grpc_channel* c = + grpc_insecure_channel_create("no_op_mutator", &client_args, nullptr); + GPR_ASSERT(channel_has_client_idle_filter(c)); + grpc_channel_destroy(c); + + c = grpc_insecure_channel_create("minimal_stack_mutator", &client_args, + nullptr); + GPR_ASSERT(channel_has_client_idle_filter(c) == false); + grpc_channel_destroy(c); + + gpr_free(fc); + auto mutator = grpc_channel_args_get_client_channel_creation_mutator(); + GPR_ASSERT(mutator == &mutate_channel_args); +} + +static void test_server_create_with_args(void) { + grpc_arg server_a[3]; + + // adds integer arg + server_a[0].type = GRPC_ARG_INTEGER; + server_a[0].key = const_cast("arg_int"); + server_a[0].value.integer = 0; + + // adds const str arg + server_a[1].type = GRPC_ARG_STRING; + server_a[1].key = const_cast("arg_str"); + server_a[1].value.string = const_cast("arg_str_val"); + + // allocated and adds custom pointer arg + fake_class* fc = static_cast(gpr_malloc(sizeof(fake_class))); + fc->foo = 42; + server_a[2].type = GRPC_ARG_POINTER; + server_a[2].key = const_cast("arg_pointer"); + server_a[2].value.pointer.vtable = &fake_pointer_arg_vtable; + server_a[2].value.pointer.p = fc; + + // creates server + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + grpc_server* s = grpc_server_create(&server_args, nullptr); + // user is can free the memory they allocated here + gpr_free(fc); + grpc_server_destroy(s); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_create(); + test_channel_create_with_args(); + test_server_create_with_args(); + // This has to be the last test. + // TODO(markdroth): re-enable this test once client_idle is re-enabled + // test_channel_create_with_global_mutator(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/channel/channel_stack_builder_test.cc b/test/core/channel/channel_stack_builder_test.cc new file mode 100644 index 00000000..88564ce9 --- /dev/null +++ b/test/core/channel/channel_stack_builder_test.cc @@ -0,0 +1,123 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/channel/channel_stack_builder.h" + +#include +#include + +#include +#include +#include + +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel_init.h" +#include "test/core/util/test_config.h" + +static grpc_error_handle channel_init_func( + grpc_channel_element* /*elem*/, grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static grpc_error_handle call_init_func( + grpc_call_element* /*elem*/, const grpc_call_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void channel_destroy_func(grpc_channel_element* /*elem*/) {} + +static void call_destroy_func(grpc_call_element* /*elem*/, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) {} + +bool g_replacement_fn_called = false; +bool g_original_fn_called = false; +void set_arg_once_fn(grpc_channel_stack* /*channel_stack*/, + grpc_channel_element* /*elem*/, void* arg) { + bool* called = static_cast(arg); + // Make sure this function is only called once per arg. + GPR_ASSERT(*called == false); + *called = true; +} + +static void test_channel_stack_builder_filter_replace(void) { + grpc_channel* channel = + grpc_insecure_channel_create("target name isn't used", nullptr, nullptr); + GPR_ASSERT(channel != nullptr); + // Make sure the high priority filter has been created. + GPR_ASSERT(g_replacement_fn_called); + // ... and that the low priority one hasn't. + GPR_ASSERT(!g_original_fn_called); + grpc_channel_destroy(channel); +} + +const grpc_channel_filter replacement_filter = { + grpc_call_next_op, + grpc_channel_next_op, + 0, + call_init_func, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + call_destroy_func, + 0, + channel_init_func, + channel_destroy_func, + grpc_channel_next_get_info, + "filter_name"}; + +const grpc_channel_filter original_filter = { + grpc_call_next_op, + grpc_channel_next_op, + 0, + call_init_func, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + call_destroy_func, + 0, + channel_init_func, + channel_destroy_func, + grpc_channel_next_get_info, + "filter_name"}; + +static bool add_replacement_filter(grpc_channel_stack_builder* builder) { + // Get rid of any other version of the filter, as determined by having the + // same name. 
+ GPR_ASSERT(grpc_channel_stack_builder_remove_filter(builder, + replacement_filter.name)); + return grpc_channel_stack_builder_prepend_filter( + builder, &replacement_filter, set_arg_once_fn, &g_replacement_fn_called); +} + +static bool add_original_filter(grpc_channel_stack_builder* builder) { + return grpc_channel_stack_builder_prepend_filter( + builder, &original_filter, set_arg_once_fn, &g_original_fn_called); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_core::CoreConfiguration::RegisterBuilder( + [](grpc_core::CoreConfiguration::Builder* builder) { + builder->channel_init()->RegisterStage(GRPC_CLIENT_CHANNEL, INT_MAX, + add_original_filter); + builder->channel_init()->RegisterStage(GRPC_CLIENT_CHANNEL, INT_MAX, + add_replacement_filter); + }); + grpc_init(); + test_channel_stack_builder_filter_replace(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/channel/channel_stack_test.cc b/test/core/channel/channel_stack_test.cc new file mode 100644 index 00000000..0f2861b4 --- /dev/null +++ b/test/core/channel/channel_stack_test.cc @@ -0,0 +1,158 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/channel/channel_stack.h" + +#include + +#include +#include +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +static grpc_error_handle channel_init_func(grpc_channel_element* elem, + grpc_channel_element_args* args) { + GPR_ASSERT(args->channel_args->num_args == 1); + GPR_ASSERT(args->channel_args->args[0].type == GRPC_ARG_INTEGER); + GPR_ASSERT(0 == strcmp(args->channel_args->args[0].key, "test_key")); + GPR_ASSERT(args->channel_args->args[0].value.integer == 42); + GPR_ASSERT(args->is_first); + GPR_ASSERT(args->is_last); + *static_cast(elem->channel_data) = 0; + return GRPC_ERROR_NONE; +} + +static grpc_error_handle call_init_func( + grpc_call_element* elem, const grpc_call_element_args* /*args*/) { + ++*static_cast(elem->channel_data); + *static_cast(elem->call_data) = 0; + return GRPC_ERROR_NONE; +} + +static void channel_destroy_func(grpc_channel_element* /*elem*/) {} + +static void call_destroy_func(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + ++*static_cast(elem->channel_data); +} + +static void call_func(grpc_call_element* elem, + grpc_transport_stream_op_batch* /*op*/) { + ++*static_cast(elem->call_data); +} + +static void channel_func(grpc_channel_element* elem, + grpc_transport_op* /*op*/) { + ++*static_cast(elem->channel_data); +} + +static void free_channel(void* arg, grpc_error_handle /*error*/) { + grpc_channel_stack_destroy(static_cast(arg)); + gpr_free(arg); +} + +static void free_call(void* arg, grpc_error_handle /*error*/) { + grpc_call_stack_destroy(static_cast(arg), nullptr, nullptr); + gpr_free(arg); +} + +static void test_create_channel_stack(void) { + const grpc_channel_filter filter = { + call_func, + channel_func, + 
sizeof(int), + call_init_func, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + call_destroy_func, + sizeof(int), + channel_init_func, + channel_destroy_func, + grpc_channel_next_get_info, + "some_test_filter"}; + const grpc_channel_filter* filters = &filter; + grpc_channel_stack* channel_stack; + grpc_call_stack* call_stack; + grpc_channel_element* channel_elem; + grpc_call_element* call_elem; + grpc_arg arg; + grpc_channel_args chan_args; + int* channel_data; + int* call_data; + grpc_core::ExecCtx exec_ctx; + grpc_slice path = grpc_slice_from_static_string("/service/method"); + + arg.type = GRPC_ARG_INTEGER; + arg.key = const_cast("test_key"); + arg.value.integer = 42; + + chan_args.num_args = 1; + chan_args.args = &arg; + + channel_stack = static_cast( + gpr_malloc(grpc_channel_stack_size(&filters, 1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "grpc_channel_stack_init", + grpc_channel_stack_init(1, free_channel, channel_stack, &filters, 1, + &chan_args, nullptr, "test", channel_stack))); + GPR_ASSERT(channel_stack->count == 1); + channel_elem = grpc_channel_stack_element(channel_stack, 0); + channel_data = static_cast(channel_elem->channel_data); + GPR_ASSERT(*channel_data == 0); + + call_stack = + static_cast(gpr_malloc(channel_stack->call_stack_size)); + const grpc_call_element_args args = { + call_stack, /* call_stack */ + nullptr, /* server_transport_data */ + nullptr, /* context */ + path, /* path */ + gpr_get_cycle_counter(), /* start_time */ + GRPC_MILLIS_INF_FUTURE, /* deadline */ + nullptr, /* arena */ + nullptr, /* call_combiner */ + }; + grpc_error_handle error = + grpc_call_stack_init(channel_stack, 1, free_call, call_stack, &args); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(call_stack->count == 1); + call_elem = grpc_call_stack_element(call_stack, 0); + GPR_ASSERT(call_elem->filter == channel_elem->filter); + GPR_ASSERT(call_elem->channel_data == channel_elem->channel_data); + call_data = static_cast(call_elem->call_data); + GPR_ASSERT(*call_data == 0); + GPR_ASSERT(*channel_data == 1); + + GRPC_CALL_STACK_UNREF(call_stack, "done"); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(*channel_data == 2); + + GRPC_CHANNEL_STACK_UNREF(channel_stack, "done"); + + grpc_slice_unref_internal(path); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_create_channel_stack(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/channel/channel_trace_test.cc b/test/core/channel/channel_trace_test.cc new file mode 100644 index 00000000..824bd21e --- /dev/null +++ b/test/core/channel/channel_trace_test.cc @@ -0,0 +1,329 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/channel/channel_trace.h" + +#include +#include + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/channel/channelz_registry.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/channel_trace_proto_helper.h" + +namespace grpc_core { +namespace channelz { +namespace testing { + +// testing peer to access channel internals +class ChannelNodePeer { + public: + explicit ChannelNodePeer(ChannelNode* node) : node_(node) {} + ChannelTrace* trace() const { return &node_->trace_; } + + private: + ChannelNode* node_; +}; + +size_t GetSizeofTraceEvent() { return sizeof(ChannelTrace::TraceEvent); } + +namespace { + +void ValidateJsonArraySize(const Json& array, size_t expected) { + if (expected == 0) { + ASSERT_EQ(array.type(), Json::Type::JSON_NULL); + } else { + ASSERT_EQ(array.type(), Json::Type::ARRAY); + EXPECT_EQ(array.array_value().size(), expected); + } +} + +void ValidateChannelTraceData(const Json& json, + size_t num_events_logged_expected, + size_t actual_num_events_expected) { + ASSERT_EQ(json.type(), Json::Type::OBJECT); + Json::Object object = json.object_value(); + Json& num_events_logged_json = object["numEventsLogged"]; + ASSERT_EQ(num_events_logged_json.type(), Json::Type::STRING); + size_t num_events_logged = static_cast( + strtol(num_events_logged_json.string_value().c_str(), nullptr, 0)); + ASSERT_EQ(num_events_logged, num_events_logged_expected); + Json& start_time_json = object["creationTimestamp"]; + ASSERT_EQ(start_time_json.type(), Json::Type::STRING); + ValidateJsonArraySize(object["events"], actual_num_events_expected); +} + +void AddSimpleTrace(ChannelTrace* tracer) { + tracer->AddTraceEvent(ChannelTrace::Severity::Info, + grpc_slice_from_static_string("simple trace")); +} + +// checks for the existence of all the required members of the tracer. +void ValidateChannelTraceCustom(ChannelTrace* tracer, size_t num_events_logged, + size_t num_events_expected) { + Json json = tracer->RenderJson(); + ASSERT_EQ(json.type(), Json::Type::OBJECT); + std::string json_str = json.Dump(); + grpc::testing::ValidateChannelTraceProtoJsonTranslation(json_str.c_str()); + ValidateChannelTraceData(json, num_events_logged, num_events_expected); +} + +void ValidateChannelTrace(ChannelTrace* tracer, size_t num_events_logged) { + ValidateChannelTraceCustom(tracer, num_events_logged, num_events_logged); +} + +class ChannelFixture { + public: + explicit ChannelFixture(int max_tracer_event_memory) { + grpc_arg client_a = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + max_tracer_event_memory); + grpc_channel_args client_args = {1, &client_a}; + channel_ = + grpc_insecure_channel_create("fake_target", &client_args, nullptr); + } + + ~ChannelFixture() { grpc_channel_destroy(channel_); } + + grpc_channel* channel() { return channel_; } + + private: + grpc_channel* channel_; +}; + +} // anonymous namespace + +const int kEventListMemoryLimit = 1024 * 1024; + +// Tests basic ChannelTrace functionality like construction, adding trace, and +// lookups by uuid. 
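+//
+// The tests that follow all share the same basic shape; as a rough sketch,
+// using only APIs already exercised in this file (illustrative, not an
+// additional test):
+//
+//   ChannelTrace tracer(kEventListMemoryLimit);
+//   tracer.AddTraceEvent(ChannelTrace::Severity::Info,
+//                        grpc_slice_from_static_string("an event"));
+//   Json json = tracer.RenderJson();  // what the Validate* helpers parse
+//
+// ValidateChannelTrace(&tracer, n) then checks that the rendered JSON reports
+// n logged events and round-trips through the channelz proto translation.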
+TEST(ChannelTracerTest, BasicTest) {
+  grpc_core::ExecCtx exec_ctx;
+  ChannelTrace tracer(kEventListMemoryLimit);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  tracer.AddTraceEvent(ChannelTrace::Severity::Info,
+                       grpc_slice_from_static_string("trace three"));
+  tracer.AddTraceEvent(ChannelTrace::Severity::Error,
+                       grpc_slice_from_static_string("trace four error"));
+  ValidateChannelTrace(&tracer, 4);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTrace(&tracer, 6);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTrace(&tracer, 10);
+}
+
+// Tests more complex functionality, like a parent channel tracking
+// subchannels. This exercises the ref/unref patterns, since the parent tracer
+// and this function will both hold refs to the subchannel.
+TEST(ChannelTracerTest, ComplexTest) {
+  grpc_core::ExecCtx exec_ctx;
+  ChannelTrace tracer(kEventListMemoryLimit);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ChannelFixture channel1(kEventListMemoryLimit);
+  RefCountedPtr<ChannelNode> sc1 =
+      MakeRefCounted<ChannelNode>("fake_target", kEventListMemoryLimit, 0);
+  ChannelNodePeer sc1_peer(sc1.get());
+  tracer.AddTraceEventWithReference(
+      ChannelTrace::Severity::Info,
+      grpc_slice_from_static_string("subchannel one created"), sc1);
+  ValidateChannelTrace(&tracer, 3);
+  AddSimpleTrace(sc1_peer.trace());
+  AddSimpleTrace(sc1_peer.trace());
+  AddSimpleTrace(sc1_peer.trace());
+  ValidateChannelTrace(sc1_peer.trace(), 3);
+  AddSimpleTrace(sc1_peer.trace());
+  AddSimpleTrace(sc1_peer.trace());
+  AddSimpleTrace(sc1_peer.trace());
+  ValidateChannelTrace(sc1_peer.trace(), 6);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTrace(&tracer, 5);
+  ChannelFixture channel2(kEventListMemoryLimit);
+  RefCountedPtr<ChannelNode> sc2 =
+      MakeRefCounted<ChannelNode>("fake_target", kEventListMemoryLimit, 0);
+  tracer.AddTraceEventWithReference(
+      ChannelTrace::Severity::Info,
+      grpc_slice_from_static_string("LB channel two created"), sc2);
+  tracer.AddTraceEventWithReference(
+      ChannelTrace::Severity::Warning,
+      grpc_slice_from_static_string("subchannel one inactive"), sc1);
+  ValidateChannelTrace(&tracer, 7);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  sc1.reset();
+  sc2.reset();
+}
+
+// Test a case in which the parent channel has subchannels and the subchannels
+// have connections. Ensures that everything lives as long as it should, then
+// gets deleted.
+TEST(ChannelTracerTest, TestNesting) {
+  grpc_core::ExecCtx exec_ctx;
+  ChannelTrace tracer(kEventListMemoryLimit);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTrace(&tracer, 2);
+  ChannelFixture channel1(kEventListMemoryLimit);
+  RefCountedPtr<ChannelNode> sc1 =
+      MakeRefCounted<ChannelNode>("fake_target", kEventListMemoryLimit, 0);
+  ChannelNodePeer sc1_peer(sc1.get());
+  tracer.AddTraceEventWithReference(
+      ChannelTrace::Severity::Info,
+      grpc_slice_from_static_string("subchannel one created"), sc1);
+  ValidateChannelTrace(&tracer, 3);
+  AddSimpleTrace(sc1_peer.trace());
+  ChannelFixture channel2(kEventListMemoryLimit);
+  RefCountedPtr<ChannelNode> conn1 =
+      MakeRefCounted<ChannelNode>("fake_target", kEventListMemoryLimit, 0);
+  ChannelNodePeer conn1_peer(conn1.get());
+  // nesting one level deeper.
+  sc1_peer.trace()->AddTraceEventWithReference(
+      ChannelTrace::Severity::Info,
+      grpc_slice_from_static_string("connection one created"), conn1);
+  ValidateChannelTrace(&tracer, 3);
+  AddSimpleTrace(conn1_peer.trace());
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTrace(&tracer, 5);
+  ValidateChannelTrace(conn1_peer.trace(), 1);
+  ChannelFixture channel3(kEventListMemoryLimit);
+  RefCountedPtr<ChannelNode> sc2 =
+      MakeRefCounted<ChannelNode>("fake_target", kEventListMemoryLimit, 0);
+  tracer.AddTraceEventWithReference(
+      ChannelTrace::Severity::Info,
+      grpc_slice_from_static_string("subchannel two created"), sc2);
+  // this trace should not get added to the parent's children, since it is
+  // already present in the tracer.
+  tracer.AddTraceEventWithReference(
+      ChannelTrace::Severity::Warning,
+      grpc_slice_from_static_string("subchannel one inactive"), sc1);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTrace(&tracer, 8);
+  sc1.reset();
+  sc2.reset();
+  conn1.reset();
+}
+
+TEST(ChannelTracerTest, TestSmallMemoryLimit) {
+  grpc_core::ExecCtx exec_ctx;
+  // A limit this small doesn't make sense in practice, but it serves a testing
+  // purpose for the channel tracing bookkeeping: all tracing events added
+  // should get immediately garbage collected.
+  const int kSmallMemoryLimit = 1;
+  ChannelTrace tracer(kSmallMemoryLimit);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  tracer.AddTraceEvent(ChannelTrace::Severity::Info,
+                       grpc_slice_from_static_string("trace three"));
+  tracer.AddTraceEvent(ChannelTrace::Severity::Error,
+                       grpc_slice_from_static_string("trace four error"));
+  ValidateChannelTraceCustom(&tracer, 4, 0);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTraceCustom(&tracer, 6, 0);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  AddSimpleTrace(&tracer);
+  ValidateChannelTraceCustom(&tracer, 10, 0);
+}
+
+TEST(ChannelTracerTest, TestEviction) {
+  grpc_core::ExecCtx exec_ctx;
+  const int kTraceEventSize = GetSizeofTraceEvent();
+  const int kNumEvents = 5;
+  ChannelTrace tracer(kTraceEventSize * kNumEvents);
+  for (int i = 1; i <= kNumEvents; ++i) {
+    AddSimpleTrace(&tracer);
+    ValidateChannelTrace(&tracer, i);
+  }
+  // at this point the list is full, and each subsequent entry will cause an
+  // eviction.
+  for (int i = 1; i <= kNumEvents; ++i) {
+    AddSimpleTrace(&tracer);
+    ValidateChannelTraceCustom(&tracer, kNumEvents + i, kNumEvents);
+  }
+}
+
+TEST(ChannelTracerTest, TestMultipleEviction) {
+  grpc_core::ExecCtx exec_ctx;
+  const int kTraceEventSize = GetSizeofTraceEvent();
+  const int kNumEvents = 5;
+  ChannelTrace tracer(kTraceEventSize * kNumEvents);
+  for (int i = 1; i <= kNumEvents; ++i) {
+    AddSimpleTrace(&tracer);
+    ValidateChannelTrace(&tracer, i);
+  }
+  // at this point the list is full, and each subsequent entry will cause an
+  // eviction. We now add a trace event that has a copied string. This uses
+  // more memory, so it will cause a double eviction.
+  tracer.AddTraceEvent(
+      ChannelTrace::Severity::Info,
+      grpc_slice_from_copied_string(
+          "long enough string to trigger a multiple eviction"));
+  ValidateChannelTraceCustom(&tracer, kNumEvents + 1, kNumEvents - 1);
+}
+
+TEST(ChannelTracerTest, TestTotalEviction) {
+  grpc_core::ExecCtx exec_ctx;
+  const int kTraceEventSize = GetSizeofTraceEvent();
+  const int kNumEvents = 5;
+  ChannelTrace tracer(kTraceEventSize * kNumEvents);
+  for (int i = 1; i <= kNumEvents; ++i) {
+    AddSimpleTrace(&tracer);
+    ValidateChannelTrace(&tracer, i);
+  }
+  // at this point the list is full.
Now we add such a big slice that + // everything gets evicted. + grpc_slice huge_slice = grpc_slice_malloc(kTraceEventSize * (kNumEvents + 1)); + tracer.AddTraceEvent(ChannelTrace::Severity::Info, huge_slice); + ValidateChannelTraceCustom(&tracer, kNumEvents + 1, 0); +} + +} // namespace testing +} // namespace channelz +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/channel/channelz_registry_test.cc b/test/core/channel/channelz_registry_test.cc new file mode 100644 index 00000000..166dcdfd --- /dev/null +++ b/test/core/channel/channelz_registry_test.cc @@ -0,0 +1,151 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/channel/channelz_registry.h" + +#include +#include + +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_trace.h" +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace channelz { +namespace testing { + +class ChannelzRegistryTest : public ::testing::Test { + protected: + // ensure we always have a fresh registry for tests. + void SetUp() override { ChannelzRegistry::Init(); } + + void TearDown() override { ChannelzRegistry::Shutdown(); } +}; + +static RefCountedPtr CreateTestNode() { + return MakeRefCounted("test", "test"); +} + +TEST_F(ChannelzRegistryTest, UuidStartsAboveZeroTest) { + RefCountedPtr channelz_channel = CreateTestNode(); + intptr_t uuid = channelz_channel->uuid(); + EXPECT_GT(uuid, 0) << "First uuid chose must be greater than zero. 
Zero if " + "reserved according to " + "https://github.com/grpc/proposal/blob/master/" + "A14-channelz.md"; +} + +TEST_F(ChannelzRegistryTest, UuidsAreIncreasing) { + std::vector> channelz_channels; + channelz_channels.reserve(10); + for (int i = 0; i < 10; ++i) { + channelz_channels.push_back(CreateTestNode()); + } + for (size_t i = 1; i < channelz_channels.size(); ++i) { + EXPECT_LT(channelz_channels[i - 1]->uuid(), channelz_channels[i]->uuid()) + << "Uuids must always be increasing"; + } +} + +TEST_F(ChannelzRegistryTest, RegisterGetTest) { + RefCountedPtr channelz_channel = CreateTestNode(); + RefCountedPtr retrieved = + ChannelzRegistry::Get(channelz_channel->uuid()); + EXPECT_EQ(channelz_channel, retrieved); +} + +TEST_F(ChannelzRegistryTest, RegisterManyItems) { + std::vector> channelz_channels; + for (int i = 0; i < 100; i++) { + channelz_channels.push_back(CreateTestNode()); + RefCountedPtr retrieved = + ChannelzRegistry::Get(channelz_channels[i]->uuid()); + EXPECT_EQ(channelz_channels[i], retrieved); + } +} + +TEST_F(ChannelzRegistryTest, NullIfNotPresentTest) { + RefCountedPtr channelz_channel = CreateTestNode(); + // try to pull out a uuid that does not exist. + RefCountedPtr nonexistant = + ChannelzRegistry::Get(channelz_channel->uuid() + 1); + EXPECT_EQ(nonexistant, nullptr); + RefCountedPtr retrieved = + ChannelzRegistry::Get(channelz_channel->uuid()); + EXPECT_EQ(channelz_channel, retrieved); +} + +TEST_F(ChannelzRegistryTest, TestUnregistration) { + const int kLoopIterations = 100; + // These channels will stay in the registry for the duration of the test. + std::vector> even_channels; + even_channels.reserve(kLoopIterations); + std::vector odd_uuids; + odd_uuids.reserve(kLoopIterations); + { + // These channels will unregister themselves at the end of this block. + std::vector> odd_channels; + odd_channels.reserve(kLoopIterations); + for (int i = 0; i < kLoopIterations; i++) { + even_channels.push_back(CreateTestNode()); + odd_channels.push_back(CreateTestNode()); + odd_uuids.push_back(odd_channels[i]->uuid()); + } + } + // Check that the even channels are present and the odd channels are not. + for (int i = 0; i < kLoopIterations; i++) { + RefCountedPtr retrieved = + ChannelzRegistry::Get(even_channels[i]->uuid()); + EXPECT_EQ(even_channels[i], retrieved); + retrieved = ChannelzRegistry::Get(odd_uuids[i]); + EXPECT_EQ(retrieved, nullptr); + } + // Add more channels and verify that they get added correctly, to make + // sure that the unregistration didn't leave the registry in a weird state. + std::vector> more_channels; + more_channels.reserve(kLoopIterations); + for (int i = 0; i < kLoopIterations; i++) { + more_channels.push_back(CreateTestNode()); + RefCountedPtr retrieved = + ChannelzRegistry::Get(more_channels[i]->uuid()); + EXPECT_EQ(more_channels[i], retrieved); + } +} + +} // namespace testing +} // namespace channelz +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/channel/channelz_test.cc b/test/core/channel/channelz_test.cc new file mode 100644 index 00000000..bd59ea40 --- /dev/null +++ b/test/core/channel/channelz_test.cc @@ -0,0 +1,567 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/channel/channelz.h" + +#include +#include + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_trace.h" +#include "src/core/lib/channel/channelz_registry.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/channel_trace_proto_helper.h" + +namespace grpc_core { +namespace channelz { +namespace testing { + +// testing peer to access channel internals +class CallCountingHelperPeer { + public: + explicit CallCountingHelperPeer(CallCountingHelper* node) : node_(node) {} + + gpr_timespec last_call_started_time() const { + CallCountingHelper::CounterData data; + node_->CollectData(&data); + return gpr_cycle_counter_to_time(data.last_call_started_cycle); + } + + private: + CallCountingHelper* node_; +}; + +namespace { + +std::vector GetUuidListFromArray(const Json::Array& arr) { + std::vector uuids; + for (const Json& value : arr) { + EXPECT_EQ(value.type(), Json::Type::OBJECT); + if (value.type() != Json::Type::OBJECT) continue; + const Json::Object& object = value.object_value(); + auto it = object.find("ref"); + EXPECT_NE(it, object.end()); + if (it == object.end()) continue; + EXPECT_EQ(it->second.type(), Json::Type::OBJECT); + if (it->second.type() != Json::Type::OBJECT) continue; + const Json::Object& ref_object = it->second.object_value(); + it = ref_object.find("channelId"); + EXPECT_NE(it, ref_object.end()); + if (it != ref_object.end()) { + uuids.push_back(atoi(it->second.string_value().c_str())); + } + } + return uuids; +} + +void ValidateJsonArraySize(const Json& array, size_t expected) { + if (expected == 0) { + ASSERT_EQ(array.type(), Json::Type::JSON_NULL); + } else { + ASSERT_EQ(array.type(), Json::Type::ARRAY); + EXPECT_EQ(array.array_value().size(), expected); + } +} + +void ValidateJsonEnd(const Json& json, bool end) { + auto it = json.object_value().find("end"); + if (end) { + ASSERT_NE(it, json.object_value().end()); + EXPECT_EQ(it->second.type(), Json::Type::JSON_TRUE); + } else { + ASSERT_EQ(it, json.object_value().end()); + } +} + +void ValidateGetTopChannels(size_t expected_channels) { + std::string json_str = ChannelzRegistry::GetTopChannels(0); + grpc::testing::ValidateGetTopChannelsResponseProtoJsonTranslation( + json_str.c_str()); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + // This check will naturally have to change when we support pagination. + // tracked: https://github.com/grpc/grpc/issues/16019. + ValidateJsonArraySize((*parsed_json.mutable_object())["channel"], + expected_channels); + ValidateJsonEnd(parsed_json, true); + // Also check that the core API formats this correctly. 
+ char* core_api_json_str = grpc_channelz_get_top_channels(0); + grpc::testing::ValidateGetTopChannelsResponseProtoJsonTranslation( + core_api_json_str); + gpr_free(core_api_json_str); +} + +void ValidateGetServers(size_t expected_servers) { + std::string json_str = ChannelzRegistry::GetServers(0); + grpc::testing::ValidateGetServersResponseProtoJsonTranslation( + json_str.c_str()); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + // This check will naturally have to change when we support pagination. + // tracked: https://github.com/grpc/grpc/issues/16019. + ValidateJsonArraySize((*parsed_json.mutable_object())["server"], + expected_servers); + ValidateJsonEnd(parsed_json, true); + // Also check that the core API formats this correctly. + char* core_api_json_str = grpc_channelz_get_servers(0); + grpc::testing::ValidateGetServersResponseProtoJsonTranslation( + core_api_json_str); + gpr_free(core_api_json_str); +} + +class ChannelFixture { + public: + explicit ChannelFixture(int max_tracer_event_memory = 0) { + grpc_arg client_a[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + max_tracer_event_memory), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), true)}; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + channel_ = + grpc_insecure_channel_create("fake_target", &client_args, nullptr); + } + + ~ChannelFixture() { grpc_channel_destroy(channel_); } + + grpc_channel* channel() { return channel_; } + + private: + grpc_channel* channel_; +}; + +class ServerFixture { + public: + explicit ServerFixture(int max_tracer_event_memory = 0) { + grpc_arg server_a[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + max_tracer_event_memory), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), true), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + server_ = grpc_server_create(&server_args, nullptr); + } + + ~ServerFixture() { grpc_server_destroy(server_); } + + grpc_server* server() const { return server_; } + + private: + grpc_server* server_; +}; + +struct ValidateChannelDataArgs { + int64_t calls_started; + int64_t calls_failed; + int64_t calls_succeeded; +}; + +void ValidateChildInteger(const Json::Object& object, const std::string& key, + int64_t expected) { + auto it = object.find(key); + if (expected == 0) { + ASSERT_EQ(it, object.end()); + return; + } + ASSERT_NE(it, object.end()); + ASSERT_EQ(it->second.type(), Json::Type::STRING); + int64_t gotten_number = static_cast( + strtol(it->second.string_value().c_str(), nullptr, 0)); + EXPECT_EQ(gotten_number, expected); +} + +void ValidateCounters(const std::string& json_str, + const ValidateChannelDataArgs& args) { + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(json.type(), Json::Type::OBJECT); + Json::Object* object = json.mutable_object(); + Json& data = (*object)["data"]; + ASSERT_EQ(data.type(), Json::Type::OBJECT); + ValidateChildInteger(data.object_value(), "callsStarted", args.calls_started); + ValidateChildInteger(data.object_value(), "callsFailed", args.calls_failed); + ValidateChildInteger(data.object_value(), "callsSucceeded", + 
args.calls_succeeded); +} + +void ValidateChannel(ChannelNode* channel, + const ValidateChannelDataArgs& args) { + std::string json_str = channel->RenderJsonString(); + grpc::testing::ValidateChannelProtoJsonTranslation(json_str.c_str()); + ValidateCounters(json_str, args); + // also check that the core API formats this the correct way + char* core_api_json_str = grpc_channelz_get_channel(channel->uuid()); + grpc::testing::ValidateGetChannelResponseProtoJsonTranslation( + core_api_json_str); + gpr_free(core_api_json_str); +} + +void ValidateServer(ServerNode* server, const ValidateChannelDataArgs& args) { + std::string json_str = server->RenderJsonString(); + grpc::testing::ValidateServerProtoJsonTranslation(json_str.c_str()); + ValidateCounters(json_str, args); + // also check that the core API formats this the correct way + char* core_api_json_str = grpc_channelz_get_server(server->uuid()); + grpc::testing::ValidateGetServerResponseProtoJsonTranslation( + core_api_json_str); + gpr_free(core_api_json_str); +} + +gpr_timespec GetLastCallStartedTime(CallCountingHelper* channel) { + CallCountingHelperPeer peer(channel); + return peer.last_call_started_time(); +} + +void ChannelzSleep(int64_t sleep_us) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(sleep_us, GPR_TIMESPAN))); + grpc_core::ExecCtx::Get()->InvalidateNow(); +} + +} // anonymous namespace + +class ChannelzChannelTest : public ::testing::TestWithParam {}; + +TEST_P(ChannelzChannelTest, BasicChannel) { + grpc_core::ExecCtx exec_ctx; + ChannelFixture channel(GetParam()); + ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(channel.channel()); + ValidateChannel(channelz_channel, {0, 0, 0}); +} + +TEST(ChannelzChannelTest, ChannelzDisabled) { + grpc_core::ExecCtx exec_ctx; + // explicitly disable channelz + grpc_arg arg[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), false)}; + grpc_channel_args args = {GPR_ARRAY_SIZE(arg), arg}; + grpc_channel* channel = + grpc_insecure_channel_create("fake_target", &args, nullptr); + ChannelNode* channelz_channel = grpc_channel_get_channelz_node(channel); + ASSERT_EQ(channelz_channel, nullptr); + grpc_channel_destroy(channel); +} + +TEST_P(ChannelzChannelTest, BasicChannelAPIFunctionality) { + grpc_core::ExecCtx exec_ctx; + ChannelFixture channel(GetParam()); + ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(channel.channel()); + channelz_channel->RecordCallStarted(); + channelz_channel->RecordCallFailed(); + channelz_channel->RecordCallSucceeded(); + ValidateChannel(channelz_channel, {1, 1, 1}); + channelz_channel->RecordCallStarted(); + channelz_channel->RecordCallFailed(); + channelz_channel->RecordCallSucceeded(); + channelz_channel->RecordCallStarted(); + channelz_channel->RecordCallFailed(); + channelz_channel->RecordCallSucceeded(); + ValidateChannel(channelz_channel, {3, 3, 3}); +} + +TEST_P(ChannelzChannelTest, LastCallStartedTime) { + grpc_core::ExecCtx exec_ctx; + CallCountingHelper counter; + // start a call to set the last call started timestamp + counter.RecordCallStarted(); + gpr_timespec time1 = GetLastCallStartedTime(&counter); + // time gone by should not affect the timestamp + ChannelzSleep(100); + gpr_timespec time2 = GetLastCallStartedTime(&counter); + EXPECT_EQ(gpr_time_cmp(time1, time2), 0); + // calls succeeded or failed should not affect the timestamp + ChannelzSleep(100); 
+ counter.RecordCallFailed(); + counter.RecordCallSucceeded(); + gpr_timespec time3 = GetLastCallStartedTime(&counter); + EXPECT_EQ(gpr_time_cmp(time1, time3), 0); + // another call started should affect the timestamp + // sleep for extra long to avoid flakes (since we cache Now()) + ChannelzSleep(5000); + counter.RecordCallStarted(); + gpr_timespec time4 = GetLastCallStartedTime(&counter); + EXPECT_NE(gpr_time_cmp(time1, time4), 0); +} + +class ChannelzRegistryBasedTest : public ::testing::TestWithParam { + protected: + // ensure we always have a fresh registry for tests. + void SetUp() override { + ChannelzRegistry::Shutdown(); + ChannelzRegistry::Init(); + } + + void TearDown() override { + ChannelzRegistry::Shutdown(); + ChannelzRegistry::Init(); + } +}; + +TEST_F(ChannelzRegistryBasedTest, BasicGetTopChannelsTest) { + grpc_core::ExecCtx exec_ctx; + ChannelFixture channel; + ValidateGetTopChannels(1); +} + +TEST_F(ChannelzRegistryBasedTest, NoChannelsTest) { + grpc_core::ExecCtx exec_ctx; + ValidateGetTopChannels(0); +} + +TEST_F(ChannelzRegistryBasedTest, ManyChannelsTest) { + grpc_core::ExecCtx exec_ctx; + ChannelFixture channels[10]; + (void)channels; // suppress unused variable error + ValidateGetTopChannels(10); +} + +TEST_F(ChannelzRegistryBasedTest, GetTopChannelsPagination) { + grpc_core::ExecCtx exec_ctx; + // This is over the pagination limit. + ChannelFixture channels[150]; + (void)channels; // suppress unused variable error + std::string json_str = ChannelzRegistry::GetTopChannels(0); + grpc::testing::ValidateGetTopChannelsResponseProtoJsonTranslation( + json_str.c_str()); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + // 100 is the pagination limit. + ValidateJsonArraySize((*parsed_json.mutable_object())["channel"], 100); + ValidateJsonEnd(parsed_json, false); + // Now we get the rest. + json_str = ChannelzRegistry::GetTopChannels(101); + grpc::testing::ValidateGetTopChannelsResponseProtoJsonTranslation( + json_str.c_str()); + error = GRPC_ERROR_NONE; + parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + ValidateJsonArraySize((*parsed_json.mutable_object())["channel"], 50); + ValidateJsonEnd(parsed_json, true); +} + +TEST_F(ChannelzRegistryBasedTest, GetTopChannelsUuidCheck) { + const intptr_t kNumChannels = 50; + grpc_core::ExecCtx exec_ctx; + ChannelFixture channels[kNumChannels]; + (void)channels; // suppress unused variable error + std::string json_str = ChannelzRegistry::GetTopChannels(0); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + Json& array = (*parsed_json.mutable_object())["channel"]; + ValidateJsonArraySize(array, kNumChannels); + std::vector uuids = GetUuidListFromArray(array.array_value()); + for (int i = 0; i < kNumChannels; ++i) { + EXPECT_EQ(i + 1, uuids[i]); + } +} + +TEST_F(ChannelzRegistryBasedTest, GetTopChannelsMiddleUuidCheck) { + const intptr_t kNumChannels = 50; + const intptr_t kMidQuery = 40; + grpc_core::ExecCtx exec_ctx; + ChannelFixture channels[kNumChannels]; + (void)channels; // suppress unused variable error + // Only query for the end of the channels. 
+ std::string json_str = ChannelzRegistry::GetTopChannels(kMidQuery); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + Json& array = (*parsed_json.mutable_object())["channel"]; + ValidateJsonArraySize(array, kNumChannels - kMidQuery + 1); + std::vector uuids = GetUuidListFromArray(array.array_value()); + for (size_t i = 0; i < uuids.size(); ++i) { + EXPECT_EQ(static_cast(kMidQuery + i), uuids[i]); + } +} + +TEST_F(ChannelzRegistryBasedTest, GetTopChannelsNoHitUuid) { + grpc_core::ExecCtx exec_ctx; + ChannelFixture pre_channels[40]; // will take uuid[1, 40] + (void)pre_channels; // suppress unused variable error + ServerFixture servers[10]; // will take uuid[41, 50] + (void)servers; // suppress unused variable error + ChannelFixture channels[10]; // will take uuid[51, 60] + (void)channels; // suppress unused variable error + // Query in the middle of the server channels. + std::string json_str = ChannelzRegistry::GetTopChannels(45); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + Json& array = (*parsed_json.mutable_object())["channel"]; + ValidateJsonArraySize(array, 10); + std::vector uuids = GetUuidListFromArray(array.array_value()); + for (size_t i = 0; i < uuids.size(); ++i) { + EXPECT_EQ(static_cast(51 + i), uuids[i]); + } +} + +TEST_F(ChannelzRegistryBasedTest, GetTopChannelsMoreGaps) { + grpc_core::ExecCtx exec_ctx; + ChannelFixture channel_with_uuid1; + { ServerFixture channel_with_uuid2; } + ChannelFixture channel_with_uuid3; + { ServerFixture server_with_uuid4; } + ChannelFixture channel_with_uuid5; + // Current state of list: [1, NULL, 3, NULL, 5] + std::string json_str = ChannelzRegistry::GetTopChannels(2); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + Json array = (*parsed_json.mutable_object())["channel"]; + ValidateJsonArraySize(array, 2); + std::vector uuids = GetUuidListFromArray(array.array_value()); + EXPECT_EQ(static_cast(3), uuids[0]); + EXPECT_EQ(static_cast(5), uuids[1]); + json_str = ChannelzRegistry::GetTopChannels(4); + error = GRPC_ERROR_NONE; + parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + array = (*parsed_json.mutable_object())["channel"]; + ValidateJsonArraySize(array, 1); + uuids = GetUuidListFromArray(array.array_value()); + EXPECT_EQ(static_cast(5), uuids[0]); +} + +TEST_F(ChannelzRegistryBasedTest, GetTopChannelsUuidAfterCompaction) { + const intptr_t kLoopIterations = 50; + grpc_core::ExecCtx exec_ctx; + std::vector> even_channels; + { + // these will delete and unregister themselves after this block. 
+ std::vector> odd_channels; + for (int i = 0; i < kLoopIterations; i++) { + odd_channels.push_back(absl::make_unique()); + even_channels.push_back(absl::make_unique()); + } + } + std::string json_str = ChannelzRegistry::GetTopChannels(0); + grpc_error_handle error = GRPC_ERROR_NONE; + Json parsed_json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), Json::Type::OBJECT); + Json& array = (*parsed_json.mutable_object())["channel"]; + ValidateJsonArraySize(array, kLoopIterations); + std::vector uuids = GetUuidListFromArray(array.array_value()); + for (int i = 0; i < kLoopIterations; ++i) { + // only the even uuids will still be present. + EXPECT_EQ((i + 1) * 2, uuids[i]); + } +} + +TEST_F(ChannelzRegistryBasedTest, InternalChannelTest) { + grpc_core::ExecCtx exec_ctx; + ChannelFixture channels[10]; + (void)channels; // suppress unused variable error + // create an internal channel + grpc_arg client_a[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_CHANNELZ_IS_INTERNAL_CHANNEL), 1), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), true), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + grpc_channel* internal_channel = + grpc_insecure_channel_create("fake_target", &client_args, nullptr); + // The internal channel should not be returned from the request + ValidateGetTopChannels(10); + grpc_channel_destroy(internal_channel); +} + +TEST(ChannelzServerTest, BasicServerAPIFunctionality) { + grpc_core::ExecCtx exec_ctx; + ServerFixture server(10); + ServerNode* channelz_server = server.server()->core_server->channelz_node(); + channelz_server->RecordCallStarted(); + channelz_server->RecordCallFailed(); + channelz_server->RecordCallSucceeded(); + ValidateServer(channelz_server, {1, 1, 1}); + channelz_server->RecordCallStarted(); + channelz_server->RecordCallFailed(); + channelz_server->RecordCallSucceeded(); + channelz_server->RecordCallStarted(); + channelz_server->RecordCallFailed(); + channelz_server->RecordCallSucceeded(); + ValidateServer(channelz_server, {3, 3, 3}); +} + +TEST_F(ChannelzRegistryBasedTest, BasicGetServersTest) { + grpc_core::ExecCtx exec_ctx; + ServerFixture server; + ValidateGetServers(1); +} + +TEST_F(ChannelzRegistryBasedTest, NoServersTest) { + grpc_core::ExecCtx exec_ctx; + ValidateGetServers(0); +} + +TEST_F(ChannelzRegistryBasedTest, ManyServersTest) { + grpc_core::ExecCtx exec_ctx; + ServerFixture servers[10]; + (void)servers; // suppress unused variable error + ValidateGetServers(10); +} + +INSTANTIATE_TEST_SUITE_P(ChannelzChannelTestSweep, ChannelzChannelTest, + ::testing::Values(0, 8, 64, 1024, 1024 * 1024)); + +} // namespace testing +} // namespace channelz +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/channel/minimal_stack_is_minimal_test.cc b/test/core/channel/minimal_stack_is_minimal_test.cc new file mode 100644 index 00000000..4eda818b --- /dev/null +++ b/test/core/channel/minimal_stack_is_minimal_test.cc @@ -0,0 +1,213 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/******************************************************************************* + * This test verifies that various stack configurations result in the set of + * filters that we expect. + * + * This is akin to a golden-file test, and suffers the same disadvantages and + * advantages: it reflects that the code as written has not been modified - and + * valid code modifications WILL break this test and it will need updating. + * + * The intent therefore is to allow code reviewers to more easily catch changes + * that perturb the generated list of channel filters in different + * configurations and assess whether such a change is correct and desirable. + */ + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/surface/channel_init.h" +#include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/transport_impl.h" +#include "test/core/util/test_config.h" + +// use CHECK_STACK instead +static int check_stack(const char* file, int line, const char* transport_name, + grpc_channel_args* init_args, + unsigned channel_stack_type, ...); + +// arguments: const char *transport_name - the name of the transport type to +// simulate +// grpc_channel_args *init_args - channel args to pass down +// grpc_channel_stack_type channel_stack_type - the archetype of +// channel stack to create +// variadic arguments - the (in-order) expected list of channel +// filters to instantiate, terminated with NULL +#define CHECK_STACK(...) 
check_stack(__FILE__, __LINE__, __VA_ARGS__) + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int errors = 0; + + // tests with a minimal stack + grpc_arg minimal_stack_arg; + minimal_stack_arg.type = GRPC_ARG_INTEGER; + minimal_stack_arg.key = const_cast(GRPC_ARG_MINIMAL_STACK); + minimal_stack_arg.value.integer = 1; + grpc_channel_args minimal_stack_args = {1, &minimal_stack_arg}; + errors += + CHECK_STACK("unknown", &minimal_stack_args, GRPC_CLIENT_DIRECT_CHANNEL, + "authority", "connected", NULL); + errors += CHECK_STACK("unknown", &minimal_stack_args, GRPC_CLIENT_SUBCHANNEL, + "authority", "connected", NULL); + errors += CHECK_STACK("unknown", &minimal_stack_args, GRPC_SERVER_CHANNEL, + "server", "connected", NULL); + errors += CHECK_STACK("chttp2", &minimal_stack_args, + GRPC_CLIENT_DIRECT_CHANNEL, "authority", "http-client", + "message_decompress", "connected", NULL); + errors += CHECK_STACK("chttp2", &minimal_stack_args, GRPC_CLIENT_SUBCHANNEL, + "authority", "http-client", "message_decompress", + "connected", NULL); + errors += + CHECK_STACK("chttp2", &minimal_stack_args, GRPC_SERVER_CHANNEL, "server", + "http-server", "message_decompress", "connected", NULL); + errors += CHECK_STACK(nullptr, &minimal_stack_args, GRPC_CLIENT_CHANNEL, + "client-channel", NULL); + + // tests with a default stack + errors += + CHECK_STACK("unknown", nullptr, GRPC_CLIENT_DIRECT_CHANNEL, "authority", + "message_size", "deadline", "connected", NULL); + errors += CHECK_STACK("unknown", nullptr, GRPC_CLIENT_SUBCHANNEL, "authority", + "message_size", "connected", NULL); + errors += CHECK_STACK("unknown", nullptr, GRPC_SERVER_CHANNEL, "server", + "message_size", "deadline", "connected", NULL); + errors += + CHECK_STACK("chttp2", nullptr, GRPC_CLIENT_DIRECT_CHANNEL, "authority", + "message_size", "deadline", "http-client", + "message_decompress", "message_compress", "connected", NULL); + errors += CHECK_STACK("chttp2", nullptr, GRPC_CLIENT_SUBCHANNEL, "authority", + "message_size", "http-client", "message_decompress", + "message_compress", "connected", NULL); + errors += + CHECK_STACK("chttp2", nullptr, GRPC_SERVER_CHANNEL, "server", + "message_size", "deadline", "http-server", + "message_decompress", "message_compress", "connected", NULL); + errors += CHECK_STACK(nullptr, nullptr, GRPC_CLIENT_CHANNEL, "client-channel", + NULL); + + GPR_ASSERT(errors == 0); + grpc_shutdown(); + return 0; +} + +/******************************************************************************* + * End of tests definitions, start of test infrastructure + */ + +static int check_stack(const char* file, int line, const char* transport_name, + grpc_channel_args* init_args, + unsigned channel_stack_type, ...) 
{ + // create phony channel stack + grpc_channel_stack_builder* builder = grpc_channel_stack_builder_create(); + grpc_transport_vtable fake_transport_vtable; + memset(&fake_transport_vtable, 0, sizeof(grpc_transport_vtable)); + fake_transport_vtable.name = transport_name; + grpc_transport fake_transport = {&fake_transport_vtable}; + grpc_channel_stack_builder_set_target(builder, "foo.test.google.fr"); + grpc_channel_args* channel_args = grpc_channel_args_copy(init_args); + if (transport_name != nullptr) { + grpc_channel_stack_builder_set_transport(builder, &fake_transport); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_stack_builder_set_channel_arguments(builder, channel_args); + GPR_ASSERT(grpc_core::CoreConfiguration::Get().channel_init().CreateStack( + builder, (grpc_channel_stack_type)channel_stack_type)); + } + + // build up our expectation list + std::vector parts; + va_list args; + va_start(args, channel_stack_type); + for (;;) { + char* a = va_arg(args, char*); + if (a == nullptr) break; + parts.push_back(a); + } + va_end(args); + std::string expect = absl::StrJoin(parts, ", "); + + // build up our "got" list + parts.clear(); + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_first(builder); + while (grpc_channel_stack_builder_move_next(it)) { + const char* name = grpc_channel_stack_builder_iterator_filter_name(it); + if (name == nullptr) continue; + parts.push_back(name); + } + std::string got = absl::StrJoin(parts, ", "); + grpc_channel_stack_builder_iterator_destroy(it); + + // figure out result, log if there's an error + int result = 0; + if (got != expect) { + parts.clear(); + for (size_t i = 0; i < channel_args->num_args; i++) { + std::string value; + switch (channel_args->args[i].type) { + case GRPC_ARG_INTEGER: { + value = absl::StrCat(channel_args->args[i].value.integer); + break; + } + case GRPC_ARG_STRING: + value = channel_args->args[i].value.string; + break; + case GRPC_ARG_POINTER: { + value = absl::StrFormat("%p", channel_args->args[i].value.pointer.p); + break; + } + } + parts.push_back(absl::StrCat(channel_args->args[i].key, "=", value)); + } + std::string args_str = absl::StrCat("{", absl::StrJoin(parts, ", "), "}"); + + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, + "**************************************************"); + gpr_log( + file, line, GPR_LOG_SEVERITY_ERROR, + "FAILED transport=%s; stack_type=%s; channel_args=%s:", transport_name, + grpc_channel_stack_type_string( + static_cast(channel_stack_type)), + args_str.c_str()); + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, "EXPECTED: %s", expect.c_str()); + gpr_log(file, line, GPR_LOG_SEVERITY_ERROR, "GOT: %s", got.c_str()); + result = 1; + } + + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_stack_builder_destroy(builder); + grpc_channel_args_destroy(channel_args); + } + + return result; +} diff --git a/test/core/channel/status_util_test.cc b/test/core/channel/status_util_test.cc new file mode 100644 index 00000000..1d64bf19 --- /dev/null +++ b/test/core/channel/status_util_test.cc @@ -0,0 +1,49 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/channel/status_util.h" + +#include + +namespace grpc_core { +namespace internal { +namespace { + +TEST(StatusCodeSet, Basic) { + StatusCodeSet set; + EXPECT_TRUE(set.Empty()); + EXPECT_FALSE(set.Contains(GRPC_STATUS_OK)); + EXPECT_FALSE(set.Contains(GRPC_STATUS_UNAVAILABLE)); + set.Add(GRPC_STATUS_OK); + EXPECT_FALSE(set.Empty()); + EXPECT_TRUE(set.Contains(GRPC_STATUS_OK)); + EXPECT_FALSE(set.Contains(GRPC_STATUS_UNAVAILABLE)); + set.Add(GRPC_STATUS_UNAVAILABLE); + EXPECT_FALSE(set.Empty()); + EXPECT_TRUE(set.Contains(GRPC_STATUS_OK)); + EXPECT_TRUE(set.Contains(GRPC_STATUS_UNAVAILABLE)); +} + +} // namespace +} // namespace internal +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/client_channel/certificate_provider_registry_test.cc b/test/core/client_channel/certificate_provider_registry_test.cc new file mode 100644 index 00000000..3717d8b8 --- /dev/null +++ b/test/core/client_channel/certificate_provider_registry_test.cc @@ -0,0 +1,90 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include "src/core/ext/xds/certificate_provider_registry.h" + +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +class FakeCertificateProviderFactory1 : public CertificateProviderFactory { + public: + const char* name() const override { return "fake1"; } + + RefCountedPtr CreateCertificateProviderConfig( + const Json& /*config_json*/, grpc_error_handle* /*error*/) override { + return nullptr; + } + + RefCountedPtr CreateCertificateProvider( + RefCountedPtr /*config*/) override { + return nullptr; + } +}; + +class FakeCertificateProviderFactory2 : public CertificateProviderFactory { + public: + const char* name() const override { return "fake2"; } + + RefCountedPtr CreateCertificateProviderConfig( + const Json& /*config_json*/, grpc_error_handle* /*error*/) override { + return nullptr; + } + + RefCountedPtr CreateCertificateProvider( + RefCountedPtr /*config*/) override { + return nullptr; + } +}; + +TEST(CertificateProviderRegistryTest, Basic) { + CertificateProviderRegistry::InitRegistry(); + auto* fake_factory_1 = new FakeCertificateProviderFactory1; + auto* fake_factory_2 = new FakeCertificateProviderFactory2; + CertificateProviderRegistry::RegisterCertificateProviderFactory( + std::unique_ptr(fake_factory_1)); + CertificateProviderRegistry::RegisterCertificateProviderFactory( + std::unique_ptr(fake_factory_2)); + EXPECT_EQ( + CertificateProviderRegistry::LookupCertificateProviderFactory("fake1"), + fake_factory_1); + EXPECT_EQ( + CertificateProviderRegistry::LookupCertificateProviderFactory("fake2"), + fake_factory_2); + EXPECT_EQ( + CertificateProviderRegistry::LookupCertificateProviderFactory("fake3"), + nullptr); + CertificateProviderRegistry::ShutdownRegistry(); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/core/client_channel/resolvers/binder_resolver_test.cc b/test/core/client_channel/resolvers/binder_resolver_test.cc new file mode 100644 index 00000000..4065a393 --- /dev/null +++ b/test/core/client_channel/resolvers/binder_resolver_test.cc @@ -0,0 +1,181 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/iomgr/port.h" +#include "test/core/util/test_config.h" + +#ifdef GRPC_HAVE_UNIX_SOCKET + +#include + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/channel/channel_args.h" + +// Registers the factory with `grpc_core::ResolverRegistry`. 
Defined in +// binder_resolver.cc +void grpc_resolver_binder_init(void); + +namespace { + +class BinderResolverTest : public ::testing::Test { + public: + BinderResolverTest() { + factory_ = grpc_core::ResolverRegistry::LookupResolverFactory("binder"); + } + ~BinderResolverTest() override {} + static void SetUpTestSuite() { + grpc_init(); + if (grpc_core::ResolverRegistry::LookupResolverFactory("binder") == + nullptr) { + // Binder resolver will only be registered on platforms that support + // binder transport. If it is not registered on current platform, we + // manually register it here for testing purpose. + grpc_resolver_binder_init(); + ASSERT_TRUE(grpc_core::ResolverRegistry::LookupResolverFactory("binder")); + } + } + static void TearDownTestSuite() { grpc_shutdown(); } + + void SetUp() override { ASSERT_TRUE(factory_); } + + class ResultHandler : public grpc_core::Resolver::ResultHandler { + public: + ResultHandler() = default; + + explicit ResultHandler(const std::string& expected_binder_id) + : expect_result_(true), expected_binder_id_(expected_binder_id) {} + + void ReturnResult(grpc_core::Resolver::Result result) override { + EXPECT_TRUE(expect_result_); + ASSERT_TRUE(result.addresses.size() == 1); + grpc_core::ServerAddress addr = result.addresses[0]; + const struct sockaddr_un* un = + reinterpret_cast(addr.address().addr); + EXPECT_EQ(addr.address().len, + sizeof(un->sun_family) + expected_binder_id_.length() + 1); + EXPECT_EQ(un->sun_family, AF_MAX); + EXPECT_EQ(un->sun_path, expected_binder_id_); + } + + void ReturnError(grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } + + private: + // Whether we expect ReturnResult function to be invoked + bool expect_result_ = false; + + std::string expected_binder_id_; + }; + + void TestSucceeds(const char* string, const std::string& expected_path) { + gpr_log(GPR_DEBUG, "test: '%s' should be valid for '%s'", string, + factory_->scheme()); + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(string); + ASSERT_TRUE(uri.ok()) << uri.status().ToString(); + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.result_handler = + absl::make_unique(expected_path); + grpc_core::OrphanablePtr resolver = + factory_->CreateResolver(std::move(args)); + ASSERT_TRUE(resolver != nullptr); + resolver->StartLocked(); + } + + void TestFails(const char* string) { + gpr_log(GPR_DEBUG, "test: '%s' should be invalid for '%s'", string, + factory_->scheme()); + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(string); + ASSERT_TRUE(uri.ok()) << uri.status().ToString(); + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.result_handler = + absl::make_unique(); + grpc_core::OrphanablePtr resolver = + factory_->CreateResolver(std::move(args)); + EXPECT_TRUE(resolver == nullptr); + } + + private: + grpc_core::ResolverFactory* factory_; +}; + +} // namespace + +// Authority is not allowed +TEST_F(BinderResolverTest, AuthorityPresents) { + TestFails("binder://example"); + TestFails("binder://google.com"); + TestFails("binder://google.com/test"); +} + +// Path cannot be empty +TEST_F(BinderResolverTest, EmptyPath) { + TestFails("binder:"); + TestFails("binder:/"); + TestFails("binder://"); +} + +TEST_F(BinderResolverTest, PathLength) { + // Note that we have a static assert in binder_resolver.cc that checks + // sizeof(sockaddr_un::sun_path) is greater than 100 + + // 100 character path should be fine + TestSucceeds(("binder:l" + std::string(98, 'o') + "g").c_str(), + 
"l" + std::string(98, 'o') + "g"); + + // 200 character path most likely will fail + TestFails(("binder:l" + std::string(198, 'o') + "g").c_str()); +} + +TEST_F(BinderResolverTest, SlashPrefixes) { + TestSucceeds("binder:///test", "test"); + TestSucceeds("binder:////test", "/test"); +} + +TEST_F(BinderResolverTest, ValidCases) { + TestSucceeds("binder:[[", "[["); + TestSucceeds("binder:google!com", "google!com"); + TestSucceeds("binder:test/", "test/"); + TestSucceeds("binder:test:", "test:"); + + TestSucceeds("binder:e", "e"); + TestSucceeds("binder:example", "example"); + TestSucceeds("binder:google.com", "google.com"); + TestSucceeds("binder:~", "~"); + TestSucceeds("binder:12345", "12345"); + TestSucceeds( + "binder:abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._" + "~", + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~"); +} + +#endif + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/client_channel/resolvers/dns_resolver_connectivity_test.cc b/test/core/client_channel/resolvers/dns_resolver_connectivity_test.cc new file mode 100644 index 00000000..d82d92d4 --- /dev/null +++ b/test/core/client_channel/resolvers/dns_resolver_connectivity_test.cc @@ -0,0 +1,196 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include "src/core/ext/filters/client_channel/resolver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "test/core/util/test_config.h" + +static gpr_mu g_mu; +static bool g_fail_resolution = true; +static std::shared_ptr* g_work_serializer; + +static void my_resolve_address(const char* addr, const char* /*default_port*/, + grpc_pollset_set* /*interested_parties*/, + grpc_closure* on_done, + grpc_resolved_addresses** addrs) { + gpr_mu_lock(&g_mu); + GPR_ASSERT(0 == strcmp("test", addr)); + grpc_error_handle error = GRPC_ERROR_NONE; + if (g_fail_resolution) { + g_fail_resolution = false; + gpr_mu_unlock(&g_mu); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Forced Failure"); + } else { + gpr_mu_unlock(&g_mu); + *addrs = static_cast(gpr_malloc(sizeof(**addrs))); + (*addrs)->naddrs = 1; + (*addrs)->addrs = static_cast( + gpr_malloc(sizeof(*(*addrs)->addrs))); + (*addrs)->addrs[0].len = 123; + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, error); +} + +static grpc_address_resolver_vtable test_resolver = {my_resolve_address, + nullptr}; + +static grpc_ares_request* my_dns_lookup_ares_locked( + const char* /*dns_server*/, const char* addr, const char* /*default_port*/, + grpc_pollset_set* /*interested_parties*/, grpc_closure* on_done, + std::unique_ptr* addresses, + std::unique_ptr* /*balancer_addresses*/, + char** /*service_config_json*/, int /*query_timeout_ms*/, + std::shared_ptr /*combiner*/) { // NOLINT + gpr_mu_lock(&g_mu); + GPR_ASSERT(0 == strcmp("test", addr)); + grpc_error_handle error = GRPC_ERROR_NONE; + if (g_fail_resolution) { + g_fail_resolution = false; + gpr_mu_unlock(&g_mu); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Forced Failure"); + } else { + gpr_mu_unlock(&g_mu); + *addresses = absl::make_unique(); + grpc_resolved_address phony_resolved_address; + memset(&phony_resolved_address, 0, sizeof(phony_resolved_address)); + phony_resolved_address.len = 123; + (*addresses)->emplace_back(phony_resolved_address, nullptr); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, error); + return nullptr; +} + +static void my_cancel_ares_request_locked(grpc_ares_request* request) { + GPR_ASSERT(request == nullptr); +} + +static grpc_core::OrphanablePtr create_resolver( + const char* name, + std::unique_ptr result_handler) { + grpc_core::ResolverFactory* factory = + grpc_core::ResolverRegistry::LookupResolverFactory("dns"); + absl::StatusOr uri = grpc_core::URI::Parse(name); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.work_serializer = *g_work_serializer; + args.result_handler = std::move(result_handler); + grpc_core::OrphanablePtr resolver = + factory->CreateResolver(std::move(args)); + return resolver; +} + +class ResultHandler : public grpc_core::Resolver::ResultHandler { + public: + struct ResolverOutput { + grpc_core::Resolver::Result result; + grpc_error_handle error = GRPC_ERROR_NONE; + gpr_event ev; + + ResolverOutput() { gpr_event_init(&ev); } + ~ResolverOutput() { GRPC_ERROR_UNREF(error); } + }; + + void 
SetOutput(ResolverOutput* output) { + gpr_atm_rel_store(&output_, reinterpret_cast(output)); + } + + void ReturnResult(grpc_core::Resolver::Result result) override { + ResolverOutput* output = + reinterpret_cast(gpr_atm_acq_load(&output_)); + GPR_ASSERT(output != nullptr); + output->result = std::move(result); + output->error = GRPC_ERROR_NONE; + gpr_event_set(&output->ev, reinterpret_cast(1)); + } + + void ReturnError(grpc_error_handle error) override { + ResolverOutput* output = + reinterpret_cast(gpr_atm_acq_load(&output_)); + GPR_ASSERT(output != nullptr); + output->error = error; + gpr_event_set(&output->ev, reinterpret_cast(1)); + } + + private: + gpr_atm output_ = 0; // ResolverOutput* +}; + +// interleave waiting for an event with a timer check +static bool wait_loop(int deadline_seconds, gpr_event* ev) { + while (deadline_seconds) { + gpr_log(GPR_DEBUG, "Test: waiting for %d more seconds", deadline_seconds); + if (gpr_event_wait(ev, grpc_timeout_seconds_to_deadline(1))) return true; + deadline_seconds--; + + grpc_core::ExecCtx exec_ctx; + grpc_timer_check(nullptr); + } + return false; +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + + grpc_init(); + gpr_mu_init(&g_mu); + auto work_serializer = std::make_shared(); + g_work_serializer = &work_serializer; + grpc_set_resolver_impl(&test_resolver); + grpc_dns_lookup_ares_locked = my_dns_lookup_ares_locked; + grpc_cancel_ares_request_locked = my_cancel_ares_request_locked; + + { + grpc_core::ExecCtx exec_ctx; + ResultHandler* result_handler = new ResultHandler(); + grpc_core::OrphanablePtr resolver = create_resolver( + "dns:test", + std::unique_ptr(result_handler)); + ResultHandler::ResolverOutput output1; + result_handler->SetOutput(&output1); + resolver->StartLocked(); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(wait_loop(5, &output1.ev)); + GPR_ASSERT(output1.result.addresses.empty()); + GPR_ASSERT(output1.error != GRPC_ERROR_NONE); + + ResultHandler::ResolverOutput output2; + result_handler->SetOutput(&output2); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(wait_loop(30, &output2.ev)); + GPR_ASSERT(!output2.result.addresses.empty()); + GPR_ASSERT(output2.error == GRPC_ERROR_NONE); + } + + grpc_shutdown(); + gpr_mu_destroy(&g_mu); +} diff --git a/test/core/client_channel/resolvers/dns_resolver_cooldown_test.cc b/test/core/client_channel/resolvers/dns_resolver_cooldown_test.cc new file mode 100644 index 00000000..ed540221 --- /dev/null +++ b/test/core/client_channel/resolvers/dns_resolver_cooldown_test.cc @@ -0,0 +1,354 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "test/core/util/test_config.h" + +constexpr int kMinResolutionPeriodMs = 1000; + +extern grpc_address_resolver_vtable* grpc_resolve_address_impl; +static grpc_address_resolver_vtable* default_resolve_address; + +static std::shared_ptr* g_work_serializer; + +static grpc_ares_request* (*g_default_dns_lookup_ares_locked)( + const char* dns_server, const char* name, const char* default_port, + grpc_pollset_set* interested_parties, grpc_closure* on_done, + std::unique_ptr* addresses, + std::unique_ptr* balancer_addresses, + char** service_config_json, int query_timeout_ms, + std::shared_ptr work_serializer); + +// Counter incremented by test_resolve_address_impl indicating the number of +// times a system-level resolution has happened. +static int g_resolution_count; + +static struct iomgr_args { + gpr_event ev; + gpr_atm done_atm; + gpr_mu* mu; + grpc_pollset* pollset; + grpc_pollset_set* pollset_set; +} g_iomgr_args; + +// Wrapper around default resolve_address in order to count the number of +// times we incur in a system-level name resolution. +static void test_resolve_address_impl(const char* name, + const char* default_port, + grpc_pollset_set* /*interested_parties*/, + grpc_closure* on_done, + grpc_resolved_addresses** addrs) { + default_resolve_address->resolve_address( + name, default_port, g_iomgr_args.pollset_set, on_done, addrs); + ++g_resolution_count; + static grpc_millis last_resolution_time = 0; + if (last_resolution_time == 0) { + last_resolution_time = + grpc_timespec_to_millis_round_up(gpr_now(GPR_CLOCK_MONOTONIC)); + } else { + grpc_millis now = + grpc_timespec_to_millis_round_up(gpr_now(GPR_CLOCK_MONOTONIC)); + GPR_ASSERT(now - last_resolution_time >= kMinResolutionPeriodMs); + last_resolution_time = now; + } + // For correct time diff comparisons, make sure that any subsequent calls + // to grpc_core::ExecCtx::Get()->Now() on this thread don't return a time + // which is earlier than that returned by the call(s) to + // gpr_now(GPR_CLOCK_MONOTONIC) within this function. This is important + // because the resolver's last_resolution_timestamp_ will be taken from + // grpc_core::ExecCtx::Get()->Now() right after this returns. 
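  // Editorial sketch (assumption: grpc_core::ExecCtx caches the value returned
  // by Now() for the lifetime of the current ExecCtx). Without the
  // InvalidateNow() call below, a later Now() on this thread could return a
  // cached timestamp older than the gpr_now(GPR_CLOCK_MONOTONIC) readings
  // taken above, which would skew the cooldown comparison:
  //
  //   grpc_millis cached = grpc_core::ExecCtx::Get()->Now();  // may be stale
  //   grpc_core::ExecCtx::Get()->InvalidateNow();             // drop the cache
  //   grpc_millis fresh = grpc_core::ExecCtx::Get()->Now();   // re-reads clock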
+ grpc_core::ExecCtx::Get()->InvalidateNow(); +} + +static grpc_error_handle test_blocking_resolve_address_impl( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + return default_resolve_address->blocking_resolve_address(name, default_port, + addresses); +} + +static grpc_address_resolver_vtable test_resolver = { + test_resolve_address_impl, test_blocking_resolve_address_impl}; + +static grpc_ares_request* test_dns_lookup_ares_locked( + const char* dns_server, const char* name, const char* default_port, + grpc_pollset_set* /*interested_parties*/, grpc_closure* on_done, + std::unique_ptr* addresses, + std::unique_ptr* balancer_addresses, + char** service_config_json, int query_timeout_ms, + std::shared_ptr work_serializer) { + grpc_ares_request* result = g_default_dns_lookup_ares_locked( + dns_server, name, default_port, g_iomgr_args.pollset_set, on_done, + addresses, balancer_addresses, service_config_json, query_timeout_ms, + std::move(work_serializer)); + ++g_resolution_count; + static grpc_millis last_resolution_time = 0; + grpc_millis now = + grpc_timespec_to_millis_round_up(gpr_now(GPR_CLOCK_MONOTONIC)); + gpr_log(GPR_DEBUG, + "last_resolution_time:%" PRId64 " now:%" PRId64 + " min_time_between:%d", + last_resolution_time, now, kMinResolutionPeriodMs); + if (last_resolution_time == 0) { + last_resolution_time = + grpc_timespec_to_millis_round_up(gpr_now(GPR_CLOCK_MONOTONIC)); + } else { + GPR_ASSERT(now - last_resolution_time >= kMinResolutionPeriodMs); + last_resolution_time = now; + } + // For correct time diff comparisons, make sure that any subsequent calls + // to grpc_core::ExecCtx::Get()->Now() on this thread don't return a time + // which is earlier than that returned by the call(s) to + // gpr_now(GPR_CLOCK_MONOTONIC) within this function. This is important + // because the resolver's last_resolution_timestamp_ will be taken from + // grpc_core::ExecCtx::Get()->Now() right after this returns. 
+ grpc_core::ExecCtx::Get()->InvalidateNow(); + return result; +} + +static gpr_timespec test_deadline(void) { + return grpc_timeout_seconds_to_deadline(100); +} + +static void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +static void iomgr_args_init(iomgr_args* args) { + gpr_event_init(&args->ev); + args->pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(args->pollset, &args->mu); + args->pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(args->pollset_set, args->pollset); + gpr_atm_rel_store(&args->done_atm, 0); +} + +static void iomgr_args_finish(iomgr_args* args) { + GPR_ASSERT(gpr_event_wait(&args->ev, test_deadline())); + grpc_pollset_set_del_pollset(args->pollset_set, args->pollset); + grpc_pollset_set_destroy(args->pollset_set); + grpc_closure do_nothing_cb; + GRPC_CLOSURE_INIT(&do_nothing_cb, do_nothing, nullptr, + grpc_schedule_on_exec_ctx); + gpr_mu_lock(args->mu); + grpc_pollset_shutdown(args->pollset, &do_nothing_cb); + gpr_mu_unlock(args->mu); + // exec_ctx needs to be flushed before calling grpc_pollset_destroy() + grpc_core::ExecCtx::Get()->Flush(); + grpc_pollset_destroy(args->pollset); + gpr_free(args->pollset); +} + +static grpc_millis n_sec_deadline(int seconds) { + return grpc_timespec_to_millis_round_up( + grpc_timeout_seconds_to_deadline(seconds)); +} + +static void poll_pollset_until_request_done(iomgr_args* args) { + grpc_core::ExecCtx exec_ctx; + grpc_millis deadline = n_sec_deadline(10); + while (true) { + bool done = gpr_atm_acq_load(&args->done_atm) != 0; + if (done) { + break; + } + grpc_millis time_left = deadline - grpc_core::ExecCtx::Get()->Now(); + gpr_log(GPR_DEBUG, "done=%d, time_left=%" PRId64, done, time_left); + GPR_ASSERT(time_left >= 0); + grpc_pollset_worker* worker = nullptr; + gpr_mu_lock(args->mu); + GRPC_LOG_IF_ERROR("pollset_work", grpc_pollset_work(args->pollset, &worker, + n_sec_deadline(1))); + gpr_mu_unlock(args->mu); + grpc_core::ExecCtx::Get()->Flush(); + } + gpr_event_set(&args->ev, reinterpret_cast(1)); +} + +struct OnResolutionCallbackArg; + +class ResultHandler : public grpc_core::Resolver::ResultHandler { + public: + using ResultCallback = void (*)(OnResolutionCallbackArg* state); + + void SetCallback(ResultCallback result_cb, OnResolutionCallbackArg* state) { + GPR_ASSERT(result_cb_ == nullptr); + result_cb_ = result_cb; + GPR_ASSERT(state_ == nullptr); + state_ = state; + } + + void ReturnResult(grpc_core::Resolver::Result /*result*/) override { + GPR_ASSERT(result_cb_ != nullptr); + GPR_ASSERT(state_ != nullptr); + ResultCallback cb = result_cb_; + OnResolutionCallbackArg* state = state_; + result_cb_ = nullptr; + state_ = nullptr; + cb(state); + } + + void ReturnError(grpc_error_handle error) override { + gpr_log(GPR_ERROR, "resolver returned error: %s", + grpc_error_std_string(error).c_str()); + GPR_ASSERT(false); + } + + private: + ResultCallback result_cb_ = nullptr; + OnResolutionCallbackArg* state_ = nullptr; +}; + +struct OnResolutionCallbackArg { + const char* uri_str = nullptr; + grpc_core::OrphanablePtr resolver; + ResultHandler* result_handler; +}; + +// Set to true by the last callback in the resolution chain. +static bool g_all_callbacks_invoked; + +// It's interesting to run a few rounds of this test because as +// we run more rounds, the base starting time +// (i.e. ExecCtx g_start_time) gets further and further away +// from "Now()". Thus the more rounds ran, the more highlighted the +// difference is between absolute and relative times values. 
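// The test below drives four resolutions through a chain of callbacks:
// StartLocked() triggers the first lookup, and each on_*_resolution callback
// installs the next callback and calls RequestReresolutionLocked(), so every
// system-level resolution after the first must respect the cooldown. The
// timing guarantee is enforced inside the wrapped resolver functions above; a
// minimal sketch of that check (taken from test_dns_lookup_ares_locked):
//
//   grpc_millis now =
//       grpc_timespec_to_millis_round_up(gpr_now(GPR_CLOCK_MONOTONIC));
//   GPR_ASSERT(now - last_resolution_time >= kMinResolutionPeriodMs);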
+static void on_fourth_resolution(OnResolutionCallbackArg* cb_arg) { + gpr_log(GPR_INFO, "4th: g_resolution_count: %d", g_resolution_count); + GPR_ASSERT(g_resolution_count == 4); + cb_arg->resolver.reset(); + gpr_atm_rel_store(&g_iomgr_args.done_atm, 1); + gpr_mu_lock(g_iomgr_args.mu); + GRPC_LOG_IF_ERROR("pollset_kick", + grpc_pollset_kick(g_iomgr_args.pollset, nullptr)); + gpr_mu_unlock(g_iomgr_args.mu); + delete cb_arg; + g_all_callbacks_invoked = true; +} + +static void on_third_resolution(OnResolutionCallbackArg* cb_arg) { + gpr_log(GPR_INFO, "3rd: g_resolution_count: %d", g_resolution_count); + GPR_ASSERT(g_resolution_count == 3); + cb_arg->result_handler->SetCallback(on_fourth_resolution, cb_arg); + cb_arg->resolver->RequestReresolutionLocked(); + gpr_mu_lock(g_iomgr_args.mu); + GRPC_LOG_IF_ERROR("pollset_kick", + grpc_pollset_kick(g_iomgr_args.pollset, nullptr)); + gpr_mu_unlock(g_iomgr_args.mu); +} + +static void on_second_resolution(OnResolutionCallbackArg* cb_arg) { + gpr_log(GPR_INFO, "2nd: g_resolution_count: %d", g_resolution_count); + // The resolution callback was not invoked until new data was + // available, which was delayed until after the cooldown period. + GPR_ASSERT(g_resolution_count == 2); + cb_arg->result_handler->SetCallback(on_third_resolution, cb_arg); + cb_arg->resolver->RequestReresolutionLocked(); + gpr_mu_lock(g_iomgr_args.mu); + GRPC_LOG_IF_ERROR("pollset_kick", + grpc_pollset_kick(g_iomgr_args.pollset, nullptr)); + gpr_mu_unlock(g_iomgr_args.mu); +} + +static void on_first_resolution(OnResolutionCallbackArg* cb_arg) { + gpr_log(GPR_INFO, "1st: g_resolution_count: %d", g_resolution_count); + // There's one initial system-level resolution and one invocation of a + // notification callback (the current function). + GPR_ASSERT(g_resolution_count == 1); + cb_arg->result_handler->SetCallback(on_second_resolution, cb_arg); + cb_arg->resolver->RequestReresolutionLocked(); + gpr_mu_lock(g_iomgr_args.mu); + GRPC_LOG_IF_ERROR("pollset_kick", + grpc_pollset_kick(g_iomgr_args.pollset, nullptr)); + gpr_mu_unlock(g_iomgr_args.mu); +} + +static void start_test_under_work_serializer(void* arg) { + OnResolutionCallbackArg* res_cb_arg = + static_cast(arg); + res_cb_arg->result_handler = new ResultHandler(); + grpc_core::ResolverFactory* factory = + grpc_core::ResolverRegistry::LookupResolverFactory("dns"); + absl::StatusOr uri = + grpc_core::URI::Parse(res_cb_arg->uri_str); + gpr_log(GPR_DEBUG, "test: '%s' should be valid for '%s'", res_cb_arg->uri_str, + factory->scheme()); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.work_serializer = *g_work_serializer; + args.result_handler = std::unique_ptr( + res_cb_arg->result_handler); + g_resolution_count = 0; + + grpc_arg cooldown_arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_DNS_MIN_TIME_BETWEEN_RESOLUTIONS_MS), + kMinResolutionPeriodMs); + grpc_channel_args cooldown_args = {1, &cooldown_arg}; + args.args = &cooldown_args; + res_cb_arg->resolver = factory->CreateResolver(std::move(args)); + GPR_ASSERT(res_cb_arg->resolver != nullptr); + // First resolution, would incur in system-level resolution. 
+ res_cb_arg->result_handler->SetCallback(on_first_resolution, res_cb_arg); + res_cb_arg->resolver->StartLocked(); +} + +static void test_cooldown() { + grpc_core::ExecCtx exec_ctx; + iomgr_args_init(&g_iomgr_args); + OnResolutionCallbackArg* res_cb_arg = new OnResolutionCallbackArg(); + res_cb_arg->uri_str = "dns:127.0.0.1"; + + (*g_work_serializer) + ->Run([res_cb_arg]() { start_test_under_work_serializer(res_cb_arg); }, + DEBUG_LOCATION); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&g_iomgr_args); + iomgr_args_finish(&g_iomgr_args); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + auto work_serializer = std::make_shared(); + g_work_serializer = &work_serializer; + + g_default_dns_lookup_ares_locked = grpc_dns_lookup_ares_locked; + grpc_dns_lookup_ares_locked = test_dns_lookup_ares_locked; + default_resolve_address = grpc_resolve_address_impl; + grpc_set_resolver_impl(&test_resolver); + + test_cooldown(); + + grpc_shutdown(); + GPR_ASSERT(g_all_callbacks_invoked); + return 0; +} diff --git a/test/core/client_channel/resolvers/dns_resolver_test.cc b/test/core/client_channel/resolvers/dns_resolver_test.cc new file mode 100644 index 00000000..7bc867eb --- /dev/null +++ b/test/core/client_channel/resolvers/dns_resolver_test.cc @@ -0,0 +1,100 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "test/core/util/test_config.h" + +static std::shared_ptr* g_work_serializer; + +class TestResultHandler : public grpc_core::Resolver::ResultHandler { + void ReturnResult(grpc_core::Resolver::Result /*result*/) override {} + void ReturnError(grpc_error_handle /*error*/) override {} +}; + +static void test_succeeds(grpc_core::ResolverFactory* factory, + const char* string) { + gpr_log(GPR_DEBUG, "test: '%s' should be valid for '%s'", string, + factory->scheme()); + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(string); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.work_serializer = *g_work_serializer; + args.result_handler = absl::make_unique(); + grpc_core::OrphanablePtr resolver = + factory->CreateResolver(std::move(args)); + GPR_ASSERT(resolver != nullptr); +} + +static void test_fails(grpc_core::ResolverFactory* factory, + const char* string) { + gpr_log(GPR_DEBUG, "test: '%s' should be invalid for '%s'", string, + factory->scheme()); + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(string); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.work_serializer = *g_work_serializer; + args.result_handler = absl::make_unique(); + grpc_core::OrphanablePtr resolver = + factory->CreateResolver(std::move(args)); + GPR_ASSERT(resolver == nullptr); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + auto work_serializer = std::make_shared(); + g_work_serializer = &work_serializer; + + grpc_core::ResolverFactory* dns = + grpc_core::ResolverRegistry::LookupResolverFactory("dns"); + + test_succeeds(dns, "dns:10.2.1.1"); + test_succeeds(dns, "dns:10.2.1.1:1234"); + test_succeeds(dns, "dns:www.google.com"); + test_succeeds(dns, "dns:///www.google.com"); + grpc_core::UniquePtr resolver = + GPR_GLOBAL_CONFIG_GET(grpc_dns_resolver); + if (gpr_stricmp(resolver.get(), "native") == 0) { + test_fails(dns, "dns://8.8.8.8/8.8.8.8:8888"); + } else { + test_succeeds(dns, "dns://8.8.8.8/8.8.8.8:8888"); + } + grpc_shutdown(); + + return 0; +} diff --git a/test/core/client_channel/resolvers/fake_resolver_test.cc b/test/core/client_channel/resolvers/fake_resolver_test.cc new file mode 100644 index 00000000..010fb1cf --- /dev/null +++ b/test/core/client_channel/resolvers/fake_resolver_test.cc @@ -0,0 +1,210 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" + +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "test/core/util/test_config.h" + +class ResultHandler : public grpc_core::Resolver::ResultHandler { + public: + void SetExpectedAndEvent(grpc_core::Resolver::Result expected, + gpr_event* ev) { + GPR_ASSERT(ev_ == nullptr); + expected_ = std::move(expected); + ev_ = ev; + } + + void ReturnResult(grpc_core::Resolver::Result actual) override { + GPR_ASSERT(ev_ != nullptr); + // We only check the addresses, because that's the only thing + // explicitly set by the test via + // FakeResolverResponseGenerator::SetResponse(). + GPR_ASSERT(actual.addresses.size() == expected_.addresses.size()); + for (size_t i = 0; i < expected_.addresses.size(); ++i) { + GPR_ASSERT(actual.addresses[i] == expected_.addresses[i]); + } + gpr_event_set(ev_, reinterpret_cast(1)); + ev_ = nullptr; + } + + void ReturnError(grpc_error_handle /*error*/) override {} + + private: + grpc_core::Resolver::Result expected_; + gpr_event* ev_ = nullptr; +}; + +static grpc_core::OrphanablePtr build_fake_resolver( + std::shared_ptr work_serializer, + grpc_core::FakeResolverResponseGenerator* response_generator, + std::unique_ptr result_handler) { + grpc_core::ResolverFactory* factory = + grpc_core::ResolverRegistry::LookupResolverFactory("fake"); + grpc_arg generator_arg = + grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + response_generator); + grpc_channel_args channel_args = {1, &generator_arg}; + grpc_core::ResolverArgs args; + args.args = &channel_args; + args.work_serializer = std::move(work_serializer); + args.result_handler = std::move(result_handler); + grpc_core::OrphanablePtr resolver = + factory->CreateResolver(std::move(args)); + return resolver; +} + +// Create a new resolution containing 2 addresses. +static grpc_core::Resolver::Result create_new_resolver_result() { + static size_t test_counter = 0; + const size_t num_addresses = 2; + // Create address list. + grpc_core::Resolver::Result result; + for (size_t i = 0; i < num_addresses; ++i) { + std::string uri_string = absl::StrFormat("ipv4:127.0.0.1:100%" PRIuPTR, + test_counter * num_addresses + i); + absl::StatusOr uri = grpc_core::URI::Parse(uri_string); + GPR_ASSERT(uri.ok()); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*uri, &address)); + absl::InlinedVector args_to_add; + result.addresses.emplace_back( + address.addr, address.len, + grpc_channel_args_copy_and_add(nullptr, nullptr, 0)); + } + ++test_counter; + return result; +} + +static void test_fake_resolver() { + grpc_core::ExecCtx exec_ctx; + std::shared_ptr work_serializer = + std::make_shared(); + // Create resolver. 
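  // Each numbered test below follows the same pattern: register the expected
  // addresses (and a gpr_event) with the handler, push a result into the
  // FakeResolverResponseGenerator, flush the ExecCtx, and wait on the event.
  // A hedged sketch of one such round, using only calls that appear in this
  // test:
  //
  //   gpr_event ev;
  //   gpr_event_init(&ev);
  //   result_handler->SetExpectedAndEvent(result, &ev);
  //   response_generator->SetResponse(std::move(result));
  //   grpc_core::ExecCtx::Get()->Flush();
  //   GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)));
  //
  // Re-resolution rounds use SetReresolutionResponse() plus
  // RequestReresolutionLocked() on the resolver instead of SetResponse().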
+ ResultHandler* result_handler = new ResultHandler(); + grpc_core::RefCountedPtr + response_generator = + grpc_core::MakeRefCounted(); + grpc_core::OrphanablePtr resolver = build_fake_resolver( + work_serializer, response_generator.get(), + std::unique_ptr(result_handler)); + GPR_ASSERT(resolver.get() != nullptr); + resolver->StartLocked(); + // Test 1: normal resolution. + // next_results != NULL, reresolution_results == NULL. + // Expected response is next_results. + gpr_log(GPR_INFO, "TEST 1"); + grpc_core::Resolver::Result result = create_new_resolver_result(); + gpr_event ev1; + gpr_event_init(&ev1); + result_handler->SetExpectedAndEvent(result, &ev1); + response_generator->SetResponse(std::move(result)); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev1, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + // Test 2: update resolution. + // next_results != NULL, reresolution_results == NULL. + // Expected response is next_results. + gpr_log(GPR_INFO, "TEST 2"); + result = create_new_resolver_result(); + gpr_event ev2; + gpr_event_init(&ev2); + result_handler->SetExpectedAndEvent(result, &ev2); + response_generator->SetResponse(std::move(result)); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev2, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + // Test 3: normal re-resolution. + // next_results == NULL, reresolution_results != NULL. + // Expected response is reresolution_results. + gpr_log(GPR_INFO, "TEST 3"); + grpc_core::Resolver::Result reresolution_result = + create_new_resolver_result(); + gpr_event ev3; + gpr_event_init(&ev3); + result_handler->SetExpectedAndEvent(reresolution_result, &ev3); + // Set reresolution_results. + // No result will be returned until re-resolution is requested. + response_generator->SetReresolutionResponse(reresolution_result); + grpc_core::ExecCtx::Get()->Flush(); + // Trigger a re-resolution. + resolver->RequestReresolutionLocked(); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev3, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + // Test 4: repeat re-resolution. + // next_results == NULL, reresolution_results != NULL. + // Expected response is reresolution_results. + gpr_log(GPR_INFO, "TEST 4"); + gpr_event ev4; + gpr_event_init(&ev4); + result_handler->SetExpectedAndEvent(std::move(reresolution_result), &ev4); + // Trigger a re-resolution. + resolver->RequestReresolutionLocked(); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev4, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + // Test 5: normal resolution. + // next_results != NULL, reresolution_results != NULL. + // Expected response is next_results. + gpr_log(GPR_INFO, "TEST 5"); + result = create_new_resolver_result(); + gpr_event ev5; + gpr_event_init(&ev5); + result_handler->SetExpectedAndEvent(result, &ev5); + response_generator->SetResponse(std::move(result)); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev5, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + // Test 6: no-op. + // Requesting a new resolution without setting the response shouldn't trigger + // the resolution callback. + gpr_log(GPR_INFO, "TEST 6"); + gpr_event ev6; + gpr_event_init(&ev6); + result_handler->SetExpectedAndEvent(grpc_core::Resolver::Result(), &ev6); + GPR_ASSERT(gpr_event_wait(&ev6, grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + // Clean up. 
+ resolver.reset(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + test_fake_resolver(); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/client_channel/resolvers/sockaddr_resolver_test.cc b/test/core/client_channel/resolvers/sockaddr_resolver_test.cc new file mode 100644 index 00000000..5c77d311 --- /dev/null +++ b/test/core/client_channel/resolvers/sockaddr_resolver_test.cc @@ -0,0 +1,122 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "test/core/util/test_config.h" + +static std::shared_ptr* g_work_serializer; + +class ResultHandler : public grpc_core::Resolver::ResultHandler { + public: + void ReturnResult(grpc_core::Resolver::Result /*result*/) override {} + + void ReturnError(grpc_error_handle error) override { + GRPC_ERROR_UNREF(error); + } +}; + +static void test_succeeds(grpc_core::ResolverFactory* factory, + const char* string) { + gpr_log(GPR_DEBUG, "test: '%s' should be valid for '%s'", string, + factory->scheme()); + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(string); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.work_serializer = *g_work_serializer; + args.result_handler = absl::make_unique(); + grpc_core::OrphanablePtr resolver = + factory->CreateResolver(std::move(args)); + GPR_ASSERT(resolver != nullptr); + resolver->StartLocked(); + /* Flush ExecCtx to avoid stack-use-after-scope on on_res_arg which is + * accessed in the closure on_resolution_cb */ + grpc_core::ExecCtx::Get()->Flush(); +} + +static void test_fails(grpc_core::ResolverFactory* factory, + const char* string) { + gpr_log(GPR_DEBUG, "test: '%s' should be invalid for '%s'", string, + factory->scheme()); + grpc_core::ExecCtx exec_ctx; + absl::StatusOr uri = grpc_core::URI::Parse(string); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_core::ResolverArgs args; + args.uri = std::move(*uri); + args.work_serializer = *g_work_serializer; + args.result_handler = absl::make_unique(); + grpc_core::OrphanablePtr resolver = + factory->CreateResolver(std::move(args)); + GPR_ASSERT(resolver == nullptr); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + auto work_serializer = std::make_shared(); + g_work_serializer = &work_serializer; + + grpc_core::ResolverFactory* ipv4 = + grpc_core::ResolverRegistry::LookupResolverFactory("ipv4"); + grpc_core::ResolverFactory* ipv6 = + grpc_core::ResolverRegistry::LookupResolverFactory("ipv6"); + + test_fails(ipv4, "ipv4:10.2.1.1"); + test_succeeds(ipv4, 
"ipv4:10.2.1.1:1234"); + test_succeeds(ipv4, "ipv4:10.2.1.1:1234,127.0.0.1:4321"); + test_fails(ipv4, "ipv4:10.2.1.1:123456"); + test_fails(ipv4, "ipv4:www.google.com"); + test_fails(ipv4, "ipv4:["); + test_fails(ipv4, "ipv4://8.8.8.8/8.8.8.8:8888"); + + test_fails(ipv6, "ipv6:["); + test_fails(ipv6, "ipv6:[::]"); + test_succeeds(ipv6, "ipv6:[::]:1234"); + test_fails(ipv6, "ipv6:[::]:123456"); + test_fails(ipv6, "ipv6:www.google.com"); + +#ifdef GRPC_HAVE_UNIX_SOCKET + grpc_core::ResolverFactory* uds = + grpc_core::ResolverRegistry::LookupResolverFactory("unix"); + grpc_core::ResolverFactory* uds_abstract = + grpc_core::ResolverRegistry::LookupResolverFactory("unix-abstract"); + + test_succeeds(uds, "unix:///tmp/sockaddr_resolver_test"); + test_succeeds(uds_abstract, "unix-abstract:sockaddr_resolver_test"); +#endif // GRPC_HAVE_UNIX_SOCKET + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/client_channel/retry_throttle_test.cc b/test/core/client_channel/retry_throttle_test.cc new file mode 100644 index 00000000..1ae86843 --- /dev/null +++ b/test/core/client_channel/retry_throttle_test.cc @@ -0,0 +1,142 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/filters/client_channel/retry_throttle.h" + +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace internal { +namespace { + +TEST(ServerRetryThrottleData, Basic) { + // Max token count is 4, so threshold for retrying is 2. + // Token count starts at 4. + // Each failure decrements by 1. Each success increments by 1.6. + auto throttle_data = + MakeRefCounted(4000, 1600, nullptr); + // Failure: token_count=3. Above threshold. + EXPECT_TRUE(throttle_data->RecordFailure()); + // Success: token_count=4. Not incremented beyond max. + throttle_data->RecordSuccess(); + // Failure: token_count=3. Above threshold. + EXPECT_TRUE(throttle_data->RecordFailure()); + // Failure: token_count=2. At threshold, so no retries. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Failure: token_count=1. Below threshold, so no retries. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Failure: token_count=0. Below threshold, so no retries. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Failure: token_count=0. Below threshold, so no retries. Not + // decremented below min. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Success: token_count=1.6. + throttle_data->RecordSuccess(); + // Success: token_count=3.2. + throttle_data->RecordSuccess(); + // Failure: token_count=2.2. Above threshold. + EXPECT_TRUE(throttle_data->RecordFailure()); + // Failure: token_count=1.2. Below threshold, so no retries. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Success: token_count=2.8. + throttle_data->RecordSuccess(); + // Failure: token_count=1.8. Below threshold, so no retries. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Success: token_count=3.4. + throttle_data->RecordSuccess(); + // Failure: token_count=2.4. 
Above threshold. + EXPECT_TRUE(throttle_data->RecordFailure()); +} + +TEST(ServerRetryThrottleData, Replacement) { + // Create old throttle data. + // Max token count is 4, so threshold for retrying is 2. + // Token count starts at 4. + // Each failure decrements by 1. Each success increments by 1. + auto old_throttle_data = + MakeRefCounted(4000, 1000, nullptr); + // Failure: token_count=3. Above threshold. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Create new throttle data. + // Max token count is 10, so threshold for retrying is 5. + // Token count starts at 7.5 (ratio inherited from old_throttle_data). + // Each failure decrements by 1. Each success increments by 3. + auto throttle_data = MakeRefCounted( + 10000, 3000, old_throttle_data.get()); + // Failure via old_throttle_data: token_count=6.5. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Failure: token_count=5.5. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Failure via old_throttle_data: token_count=4.5. Below threshold. + EXPECT_FALSE(old_throttle_data->RecordFailure()); + // Failure: token_count=3.5. Below threshold. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Success: token_count=6.5. + throttle_data->RecordSuccess(); + // Failure via old_throttle_data: token_count=5.5. Above threshold. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Failure: token_count=4.5. Below threshold. + EXPECT_FALSE(throttle_data->RecordFailure()); +} + +TEST(ServerRetryThrottleMap, Replacement) { + ServerRetryThrottleMap::Init(); + const std::string kServerName = "server_name"; + // Create old throttle data. + // Max token count is 4, so threshold for retrying is 2. + // Token count starts at 4. + // Each failure decrements by 1. Each success increments by 1. + auto old_throttle_data = + ServerRetryThrottleMap::GetDataForServer(kServerName, 4000, 1000); + // Failure: token_count=3. Above threshold. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Create new throttle data. + // Max token count is 10, so threshold for retrying is 5. + // Token count starts at 7.5 (ratio inherited from old_throttle_data). + // Each failure decrements by 1. Each success increments by 3. + auto throttle_data = + ServerRetryThrottleMap::GetDataForServer(kServerName, 10000, 3000); + // Failure via old_throttle_data: token_count=6.5. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Failure: token_count=5.5. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Failure via old_throttle_data: token_count=4.5. Below threshold. + EXPECT_FALSE(old_throttle_data->RecordFailure()); + // Failure: token_count=3.5. Below threshold. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Success: token_count=6.5. + throttle_data->RecordSuccess(); + // Failure via old_throttle_data: token_count=5.5. Above threshold. + EXPECT_TRUE(old_throttle_data->RecordFailure()); + // Failure: token_count=4.5. Below threshold. + EXPECT_FALSE(throttle_data->RecordFailure()); + // Clean up. 
+ ServerRetryThrottleMap::Shutdown(); +} + +} // namespace +} // namespace internal +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/client_channel/rls_lb_config_parser_test.cc b/test/core/client_channel/rls_lb_config_parser_test.cc new file mode 100644 index 00000000..2b69bff8 --- /dev/null +++ b/test/core/client_channel/rls_lb_config_parser_test.cc @@ -0,0 +1,550 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include + +#include + +#include "src/core/ext/service_config/service_config.h" +#include "src/core/lib/gpr/env.h" +#include "test/core/util/test_config.h" + +// A regular expression to enter referenced or child errors. +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS +#define CHILD_ERROR_TAG ".*children.*" +#else +#define CHILD_ERROR_TAG ".*referenced_errors.*" +#endif + +namespace grpc_core { +namespace { + +class RlsConfigParsingTest : public ::testing::Test { + public: + static void SetUpTestSuite() { + gpr_setenv("GRPC_EXPERIMENTAL_ENABLE_RLS_LB_POLICY", "true"); + grpc_init(); + } + + static void TearDownTestSuite() { + grpc_shutdown_blocking(); + gpr_unsetenv("GRPC_EXPERIMENTAL_ENABLE_RLS_LB_POLICY"); + } +}; + +TEST_F(RlsConfigParsingTest, ValidConfig) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"lookupService\":\"rls.example.com:80\",\n" + " \"cacheSizeBytes\":1,\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " \"names\":[\n" + " {\"service\":\"foo\"}\n" + " ]\n" + " }\n" + " ]\n" + " },\n" + " \"childPolicy\":[\n" + " {\"unknown\":{}},\n" // Okay, since the next one exists. 
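  // As the inline comment above notes, an unrecognized entry in a
  // loadBalancingConfig/childPolicy list is tolerated as long as a later entry
  // names a policy this client knows; the first recognized policy is the one
  // selected. A list containing only unknown names is rejected (see the
  // "No known policies in list: unknown" expectation in
  // TopLevelFieldsInvalidValues below). Illustrative fragment:
  //
  //   "childPolicy": [ {"unknown":{}}, {"grpclb":{}} ]   // -> grpclb selected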
+ " {\"grpclb\":{}}\n" + " ],\n" + " \"childPolicyConfigTargetFieldName\":\"target\"\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_NE(service_config, nullptr); +} + +// +// top-level fields +// + +TEST_F(RlsConfigParsingTest, TopLevelRequiredFieldsMissing) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig error:does not exist.*" + "field:childPolicyConfigTargetFieldName error:does not exist.*" + "field:childPolicy error:does not exist")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, TopLevelFieldsWrongTypes) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":1,\n" + " \"childPolicy\":1,\n" + " \"childPolicyConfigTargetFieldName\":1\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig error:type should be OBJECT.*" + "field:childPolicyConfigTargetFieldName error:type should be STRING.*" + "field:childPolicy error:type should be ARRAY")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, TopLevelFieldsInvalidValues) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"childPolicy\":[\n" + " {\"unknown\":{}}\n" + " ],\n" + " \"childPolicyConfigTargetFieldName\":\"\"\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:childPolicyConfigTargetFieldName error:must be non-empty.*" + "field:childPolicy" CHILD_ERROR_TAG + "No known policies in list: unknown")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, InvalidChildPolicyConfig) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"childPolicy\":[\n" + " {\"grpclb\":{\"childPolicy\":1}}\n" + " ],\n" + " \"childPolicyConfigTargetFieldName\":\"serviceName\"\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:childPolicy" CHILD_ERROR_TAG "GrpcLb Parser" CHILD_ERROR_TAG + "field:childPolicy" CHILD_ERROR_TAG "type should be array")); + GRPC_ERROR_UNREF(error); +} + +// +// routeLookupConfig fields +// + +TEST_F(RlsConfigParsingTest, RouteLookupConfigRequiredFieldsMissing) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " 
\"routeLookupConfig\":{\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders error:does not exist.*" + "field:lookupService error:does not exist")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, RouteLookupConfigFieldsWrongTypes) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":1,\n" + " \"name\":1,\n" + " \"lookupService\":1,\n" + " \"lookupServiceTimeout\":{},\n" + " \"maxAge\":{},\n" + " \"staleAge\":{},\n" + " \"cacheSizeBytes\":\"xxx\",\n" + " \"defaultTarget\":1\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders error:type should be ARRAY.*" + "field:lookupService error:type should be STRING.*" + "field:maxAge error:type should be STRING.*" + "field:staleAge error:type should be STRING.*" + "field:cacheSizeBytes error:type should be NUMBER.*" + "field:defaultTarget error:type should be STRING")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, RouteLookupConfigFieldsInvalidValues) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"lookupService\":\"\",\n" + " \"cacheSizeBytes\":0\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:lookupService error:must be valid gRPC target URI.*" + "field:cacheSizeBytes error:must be greater than 0")); + GRPC_ERROR_UNREF(error); +} + +// +// grpcKeybuilder fields +// + +TEST_F(RlsConfigParsingTest, GrpcKeybuilderRequiredFieldsMissing) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " }\n" + " ]\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders" CHILD_ERROR_TAG "index:0" CHILD_ERROR_TAG + "field:names error:does not exist")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, GrpcKeybuilderWrongFieldTypes) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " \"names\":1,\n" + " \"headers\":1,\n" + " \"extraKeys\":1,\n" + " \"constantKeys\":1\n" + " }\n" + " ]\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle 
error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders" CHILD_ERROR_TAG "index:0" CHILD_ERROR_TAG + "field:names error:type should be ARRAY.*" + "field:headers error:type should be ARRAY.*" + "field:extraKeys error:type should be OBJECT.*" + "field:constantKeys error:type should be OBJECT")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, GrpcKeybuilderInvalidValues) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " \"names\":[],\n" + " \"extraKeys\":{\n" + " \"host\":1,\n" + " \"service\":1,\n" + " \"method\":1\n" + " },\n" + " \"constantKeys\":{\n" + " \"key\":1\n" + " }\n" + " }\n" + " ]\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders" CHILD_ERROR_TAG + "index:0" CHILD_ERROR_TAG "field:names error:list is empty.*" + "field:extraKeys" CHILD_ERROR_TAG + "field:host error:type should be STRING.*" + "field:service error:type should be STRING.*" + "field:method error:type should be STRING.*" + "field:constantKeys" CHILD_ERROR_TAG + "field:key error:type should be STRING")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, GrpcKeybuilderInvalidHeaders) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " \"headers\":[\n" + " 1,\n" + " {\n" + " \"key\":1,\n" + " \"names\":1\n" + " },\n" + " {\n" + " \"names\":[]\n" + " },\n" + " {\n" + " \"key\":\"\",\n" + " \"names\":[1, \"\"]\n" + " }\n" + " ],\n" + " \"extraKeys\":{\n" + " \"host\": \"\"\n" + " },\n" + " \"constantKeys\":{\n" + " \"\":\"foo\"\n" + " }\n" + " }\n" + " ]\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders" CHILD_ERROR_TAG "index:0" CHILD_ERROR_TAG + "field:headers index:0 error:type should be OBJECT.*" + "field:headers index:1" CHILD_ERROR_TAG + "field:key error:type should be STRING.*" + "field:names error:type should be ARRAY.*" + "field:headers index:2" CHILD_ERROR_TAG + "field:key error:does not exist.*" + "field:names error:list is empty.*" + "field:headers index:3" CHILD_ERROR_TAG + "field:key error:must be non-empty.*" + "field:names index:0 error:type should be STRING.*" + "field:names index:1 error:header name must be non-empty.*" + "field:extraKeys" CHILD_ERROR_TAG + "field:host error:must be non-empty.*" + "field:constantKeys" CHILD_ERROR_TAG "error:keys must be non-empty")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, GrpcKeybuilderNameWrongFieldTypes) { + const char* service_config_json = + "{\n" + " 
\"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " \"names\":[\n" + " 1,\n" + " {\n" + " \"service\":1,\n" + " \"method\":1\n" + " }\n" + " ]\n" + " }\n" + " ]\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders" CHILD_ERROR_TAG "index:0" CHILD_ERROR_TAG + "field:names index:0 error:type should be OBJECT.*" + "field:names index:1" CHILD_ERROR_TAG + "field:service error:type should be STRING.*" + "field:method error:type should be STRING")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, DuplicateMethodNamesInSameKeyBuilder) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " \"names\":[\n" + " {\n" + " \"service\":\"foo\",\n" + " \"method\":\"bar\"\n" + " },\n" + " {\n" + " \"service\":\"foo\",\n" + " \"method\":\"bar\"\n" + " }\n" + " ]\n" + " }\n" + " ]\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders" CHILD_ERROR_TAG "index:0" CHILD_ERROR_TAG + "field:names error:duplicate entry for /foo/bar")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RlsConfigParsingTest, DuplicateMethodNamesInDifferentKeyBuilders) { + const char* service_config_json = + "{\n" + " \"loadBalancingConfig\":[{\n" + " \"rls\":{\n" + " \"routeLookupConfig\":{\n" + " \"grpcKeybuilders\":[\n" + " {\n" + " \"names\":[\n" + " {\n" + " \"service\":\"foo\",\n" + " \"method\":\"bar\"\n" + " }\n" + " ]\n" + " },\n" + " {\n" + " \"names\":[\n" + " {\n" + " \"service\":\"foo\",\n" + " \"method\":\"bar\"\n" + " }\n" + " ]\n" + " }\n" + " ]\n" + " }\n" + " }\n" + " }]\n" + "}\n"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto service_config = ServiceConfig::Create( + /*args=*/nullptr, service_config_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing RLS LB policy config" CHILD_ERROR_TAG + "field:routeLookupConfig" CHILD_ERROR_TAG + "field:grpcKeybuilders" CHILD_ERROR_TAG "index:1" CHILD_ERROR_TAG + "field:names error:duplicate entry for /foo/bar")); + GRPC_ERROR_UNREF(error); +} + +} // namespace +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/client_channel/service_config_test.cc b/test/core/client_channel/service_config_test.cc new file mode 100644 index 00000000..b555fc47 --- /dev/null +++ b/test/core/client_channel/service_config_test.cc @@ -0,0 +1,1541 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/service_config/service_config.h" + +#include +#include + +#include "absl/strings/str_cat.h" + +#include + +#include "src/core/ext/filters/client_channel/resolver_result_parsing.h" +#include "src/core/ext/filters/client_channel/retry_service_config.h" +#include "src/core/ext/filters/message_size/message_size_filter.h" +#include "src/core/ext/service_config/service_config_parser.h" +#include "src/core/lib/gpr/string.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { + +// +// ServiceConfig tests +// + +// Set this channel arg to true to disable parsing. +#define GRPC_ARG_DISABLE_PARSING "disable_parsing" + +// A regular expression to enter referenced or child errors. +#ifdef GRPC_ERROR_IS_ABSEIL_STATUS +#define CHILD_ERROR_TAG ".*children.*" +#else +#define CHILD_ERROR_TAG ".*referenced_errors.*" +#endif + +class TestParsedConfig1 : public ServiceConfigParser::ParsedConfig { + public: + explicit TestParsedConfig1(int value) : value_(value) {} + + int value() const { return value_; } + + private: + int value_; +}; + +class TestParser1 : public ServiceConfigParser::Parser { + public: + std::unique_ptr ParseGlobalParams( + const grpc_channel_args* args, const Json& json, + grpc_error_handle* error) override { + GPR_DEBUG_ASSERT(error != nullptr); + if (grpc_channel_args_find_bool(args, GRPC_ARG_DISABLE_PARSING, false)) { + return nullptr; + } + auto it = json.object_value().find("global_param"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::NUMBER) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING(InvalidTypeErrorMessage()); + return nullptr; + } + int value = gpr_parse_nonnegative_int(it->second.string_value().c_str()); + if (value == -1) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING(InvalidValueErrorMessage()); + return nullptr; + } + return absl::make_unique(value); + } + return nullptr; + } + + static const char* InvalidTypeErrorMessage() { + return "global_param value type should be a number"; + } + + static const char* InvalidValueErrorMessage() { + return "global_param value type should be non-negative"; + } +}; + +class TestParser2 : public ServiceConfigParser::Parser { + public: + std::unique_ptr ParsePerMethodParams( + const grpc_channel_args* args, const Json& json, + grpc_error_handle* error) override { + GPR_DEBUG_ASSERT(error != nullptr); + if (grpc_channel_args_find_bool(args, GRPC_ARG_DISABLE_PARSING, false)) { + return nullptr; + } + auto it = json.object_value().find("method_param"); + if (it != json.object_value().end()) { + if (it->second.type() != Json::Type::NUMBER) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING(InvalidTypeErrorMessage()); + return nullptr; + } + int value = gpr_parse_nonnegative_int(it->second.string_value().c_str()); + if (value == -1) { + *error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING(InvalidValueErrorMessage()); + return nullptr; + } + return absl::make_unique(value); + } + return nullptr; + } + + static const char* InvalidTypeErrorMessage() { + return "method_param value type 
should be a number"; + } + + static const char* InvalidValueErrorMessage() { + return "method_param value type should be non-negative"; + } +}; + +// This parser always adds errors +class ErrorParser : public ServiceConfigParser::Parser { + public: + std::unique_ptr ParsePerMethodParams( + const grpc_channel_args* /*arg*/, const Json& /*json*/, + grpc_error_handle* error) override { + GPR_DEBUG_ASSERT(error != nullptr); + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(MethodError()); + return nullptr; + } + + std::unique_ptr ParseGlobalParams( + const grpc_channel_args* /*arg*/, const Json& /*json*/, + grpc_error_handle* error) override { + GPR_DEBUG_ASSERT(error != nullptr); + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(GlobalError()); + return nullptr; + } + + static const char* MethodError() { return "ErrorParser : methodError"; } + + static const char* GlobalError() { return "ErrorParser : globalError"; } +}; + +class ServiceConfigTest : public ::testing::Test { + protected: + void SetUp() override { + ServiceConfigParserShutdown(); + ServiceConfigParserInit(); + EXPECT_EQ( + ServiceConfigParser::RegisterParser(absl::make_unique()), + 0); + EXPECT_EQ( + ServiceConfigParser::RegisterParser(absl::make_unique()), + 1); + } +}; + +TEST_F(ServiceConfigTest, ErrorCheck1) { + const char* test_json = ""; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex("JSON parse error")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, BasicTest1) { + const char* test_json = "{}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); +} + +TEST_F(ServiceConfigTest, SkipMethodConfigWithNoNameOrEmptyName) { + const char* test_json = + "{\"methodConfig\": [" + " {\"method_param\":1}," + " {\"name\":[], \"method_param\":1}," + " {\"name\":[{\"service\":\"TestServ\"}], \"method_param\":2}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + auto parsed_config = ((*vector_ptr)[1]).get(); + EXPECT_EQ(static_cast(parsed_config)->value(), 2); +} + +TEST_F(ServiceConfigTest, ErrorDuplicateMethodConfigNames) { + const char* test_json = + "{\"methodConfig\": [" + " {\"name\":[{\"service\":\"TestServ\"}]}," + " {\"name\":[{\"service\":\"TestServ\"}]}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "multiple method configs with same name")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, ErrorDuplicateMethodConfigNamesWithNullMethod) { + const char* test_json = + "{\"methodConfig\": [" + " {\"name\":[{\"service\":\"TestServ\",\"method\":null}]}," + " {\"name\":[{\"service\":\"TestServ\"}]}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service 
config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "multiple method configs with same name")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, ErrorDuplicateMethodConfigNamesWithEmptyMethod) { + const char* test_json = + "{\"methodConfig\": [" + " {\"name\":[{\"service\":\"TestServ\",\"method\":\"\"}]}," + " {\"name\":[{\"service\":\"TestServ\"}]}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "multiple method configs with same name")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, ErrorDuplicateDefaultMethodConfigs) { + const char* test_json = + "{\"methodConfig\": [" + " {\"name\":[{}]}," + " {\"name\":[{}]}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "multiple default method configs")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, ErrorDuplicateDefaultMethodConfigsWithNullService) { + const char* test_json = + "{\"methodConfig\": [" + " {\"name\":[{\"service\":null}]}," + " {\"name\":[{}]}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "multiple default method configs")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, ErrorDuplicateDefaultMethodConfigsWithEmptyService) { + const char* test_json = + "{\"methodConfig\": [" + " {\"name\":[{\"service\":\"\"}]}," + " {\"name\":[{}]}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "multiple default method configs")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, ValidMethodConfig) { + const char* test_json = + "{\"methodConfig\": [{\"name\":[{\"service\":\"TestServ\"}]}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); +} + +TEST_F(ServiceConfigTest, Parser1BasicTest1) { + const char* test_json = "{\"global_param\":5}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ((static_cast(svc_cfg->GetGlobalParsedConfig(0))) + ->value(), + 5); + EXPECT_EQ(svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")), + nullptr); +} + +TEST_F(ServiceConfigTest, Parser1BasicTest2) { + const char* test_json = "{\"global_param\":1000}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << 
grpc_error_std_string(error); + EXPECT_EQ((static_cast(svc_cfg->GetGlobalParsedConfig(0))) + ->value(), + 1000); +} + +TEST_F(ServiceConfigTest, Parser1DisabledViaChannelArg) { + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_DISABLE_PARSING), 1); + grpc_channel_args args = {1, &arg}; + const char* test_json = "{\"global_param\":5}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(&args, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(svc_cfg->GetGlobalParsedConfig(0), nullptr); +} + +TEST_F(ServiceConfigTest, Parser1ErrorInvalidType) { + const char* test_json = "{\"global_param\":\"5\"}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + absl::StrCat("Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG, + TestParser1::InvalidTypeErrorMessage()))); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, Parser1ErrorInvalidValue) { + const char* test_json = "{\"global_param\":-5}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + absl::StrCat("Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG, + TestParser1::InvalidValueErrorMessage()))); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, Parser2BasicTest) { + const char* test_json = + "{\"methodConfig\": [{\"name\":[{\"service\":\"TestServ\"}], " + "\"method_param\":5}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + auto parsed_config = ((*vector_ptr)[1]).get(); + EXPECT_EQ(static_cast(parsed_config)->value(), 5); +} + +TEST_F(ServiceConfigTest, Parser2DisabledViaChannelArg) { + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_DISABLE_PARSING), 1); + grpc_channel_args args = {1, &arg}; + const char* test_json = + "{\"methodConfig\": [{\"name\":[{\"service\":\"TestServ\"}], " + "\"method_param\":5}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(&args, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + auto parsed_config = ((*vector_ptr)[1]).get(); + EXPECT_EQ(parsed_config, nullptr); +} + +TEST_F(ServiceConfigTest, Parser2ErrorInvalidType) { + const char* test_json = + "{\"methodConfig\": [{\"name\":[{\"service\":\"TestServ\"}], " + "\"method_param\":\"5\"}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex(absl::StrCat( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG, + TestParser2::InvalidTypeErrorMessage()))); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ServiceConfigTest, Parser2ErrorInvalidValue) { + const char* test_json = + "{\"methodConfig\": 
[{\"name\":[{\"service\":\"TestServ\"}], " + "\"method_param\":-5}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex(absl::StrCat( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG, + TestParser2::InvalidValueErrorMessage()))); + GRPC_ERROR_UNREF(error); +} + +// Test parsing with ErrorParsers which always add errors +class ErroredParsersScopingTest : public ::testing::Test { + protected: + void SetUp() override { + ServiceConfigParserShutdown(); + ServiceConfigParserInit(); + EXPECT_EQ( + ServiceConfigParser::RegisterParser(absl::make_unique()), + 0); + EXPECT_EQ( + ServiceConfigParser::RegisterParser(absl::make_unique()), + 1); + } +}; + +TEST_F(ErroredParsersScopingTest, GlobalParams) { + const char* test_json = "{}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex(absl::StrCat( + "Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG, + ErrorParser::GlobalError(), ".*", ErrorParser::GlobalError()))); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ErroredParsersScopingTest, MethodParams) { + const char* test_json = "{\"methodConfig\": [{}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex(absl::StrCat( + "Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG, + ErrorParser::GlobalError(), ".*", ErrorParser::GlobalError(), + ".*Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG, + ErrorParser::MethodError(), ".*", ErrorParser::MethodError()))); + GRPC_ERROR_UNREF(error); +} + +// +// client_channel parser tests +// + +class ClientChannelParserTest : public ::testing::Test { + protected: + void SetUp() override { + ServiceConfigParserShutdown(); + ServiceConfigParserInit(); + EXPECT_EQ( + ServiceConfigParser::RegisterParser( + absl::make_unique()), + 0); + } +}; + +TEST_F(ClientChannelParserTest, ValidLoadBalancingConfigPickFirst) { + const char* test_json = "{\"loadBalancingConfig\": [{\"pick_first\":{}}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + auto lb_config = parsed_config->parsed_lb_config(); + EXPECT_STREQ(lb_config->name(), "pick_first"); +} + +TEST_F(ClientChannelParserTest, ValidLoadBalancingConfigRoundRobin) { + const char* test_json = + "{\"loadBalancingConfig\": [{\"round_robin\":{}}, {}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + auto lb_config = parsed_config->parsed_lb_config(); + EXPECT_STREQ(lb_config->name(), "round_robin"); +} + +TEST_F(ClientChannelParserTest, ValidLoadBalancingConfigGrpclb) { + const char* test_json = + "{\"loadBalancingConfig\": " + "[{\"grpclb\":{\"childPolicy\":[{\"pick_first\":{}}]}}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = 
ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + auto lb_config = parsed_config->parsed_lb_config(); + EXPECT_STREQ(lb_config->name(), "grpclb"); +} + +TEST_F(ClientChannelParserTest, ValidLoadBalancingConfigXds) { + const char* test_json = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"xds_cluster_resolver_experimental\":{\n" + " \"discoveryMechanisms\": [\n" + " { \"clusterName\": \"foo\",\n" + " \"type\": \"EDS\"\n" + " } ]\n" + " } }\n" + " ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + auto lb_config = parsed_config->parsed_lb_config(); + EXPECT_STREQ(lb_config->name(), "xds_cluster_resolver_experimental"); +} + +TEST_F(ClientChannelParserTest, UnknownLoadBalancingConfig) { + const char* test_json = "{\"loadBalancingConfig\": [{\"unknown\":{}}]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex("Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG + "Client channel global parser" CHILD_ERROR_TAG + "field:loadBalancingConfig" CHILD_ERROR_TAG + "No known policies in list: unknown")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ClientChannelParserTest, InvalidGrpclbLoadBalancingConfig) { + const char* test_json = + "{\"loadBalancingConfig\": [" + " {\"grpclb\":{\"childPolicy\":1}}," + " {\"round_robin\":{}}" + "]}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG + "Client channel global parser" CHILD_ERROR_TAG + "field:loadBalancingConfig" CHILD_ERROR_TAG + "GrpcLb Parser" CHILD_ERROR_TAG + "field:childPolicy" CHILD_ERROR_TAG "type should be array")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ClientChannelParserTest, ValidLoadBalancingPolicy) { + const char* test_json = "{\"loadBalancingPolicy\":\"pick_first\"}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + EXPECT_EQ(parsed_config->parsed_deprecated_lb_policy(), "pick_first"); +} + +TEST_F(ClientChannelParserTest, ValidLoadBalancingPolicyAllCaps) { + const char* test_json = "{\"loadBalancingPolicy\":\"PICK_FIRST\"}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + EXPECT_EQ(parsed_config->parsed_deprecated_lb_policy(), "pick_first"); +} + +TEST_F(ClientChannelParserTest, UnknownLoadBalancingPolicy) { + const char* test_json = "{\"loadBalancingPolicy\":\"unknown\"}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + 
EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG + "Client channel global parser" CHILD_ERROR_TAG + "field:loadBalancingPolicy error:Unknown lb policy")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ClientChannelParserTest, LoadBalancingPolicyXdsNotAllowed) { + const char* test_json = + "{\"loadBalancingPolicy\":\"xds_cluster_resolver_experimental\"}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG + "Client channel global parser" CHILD_ERROR_TAG + "field:loadBalancingPolicy " + "error:xds_cluster_resolver_experimental requires " + "a config. Please use loadBalancingConfig instead.")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ClientChannelParserTest, ValidTimeout) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"timeout\": \"5s\"\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + auto parsed_config = ((*vector_ptr)[0]).get(); + EXPECT_EQ((static_cast( + parsed_config)) + ->timeout(), + 5000); +} + +TEST_F(ClientChannelParserTest, InvalidTimeout) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"timeout\": \"5sec\"\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "Client channel parser" CHILD_ERROR_TAG + "field:timeout error:type should be STRING of the form given " + "by google.proto.Duration")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ClientChannelParserTest, ValidWaitForReady) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"waitForReady\": true\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + auto parsed_config = ((*vector_ptr)[0]).get(); + ASSERT_TRUE( + (static_cast( + parsed_config)) + ->wait_for_ready() + .has_value()); + EXPECT_TRUE( + (static_cast( + parsed_config)) + ->wait_for_ready() + .value()); +} + +TEST_F(ClientChannelParserTest, InvalidWaitForReady) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"waitForReady\": \"true\"\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, 
test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "Client channel parser" CHILD_ERROR_TAG + "field:waitForReady error:Type should be true/false")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(ClientChannelParserTest, ValidHealthCheck) { + const char* test_json = + "{\n" + " \"healthCheckConfig\": {\n" + " \"serviceName\": \"health_check_service_name\"\n" + " }\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + ASSERT_NE(parsed_config, nullptr); + EXPECT_EQ(parsed_config->health_check_service_name(), + "health_check_service_name"); +} + +TEST_F(ClientChannelParserTest, InvalidHealthCheckMultipleEntries) { + const char* test_json = + "{\n" + " \"healthCheckConfig\": {\n" + " \"serviceName\": \"health_check_service_name\"\n" + " },\n" + " \"healthCheckConfig\": {\n" + " \"serviceName\": \"health_check_service_name1\"\n" + " }\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "JSON parsing failed" CHILD_ERROR_TAG + "duplicate key \"healthCheckConfig\" at index 104")); + GRPC_ERROR_UNREF(error); +} + +// +// retry parser tests +// + +class RetryParserTest : public ::testing::Test { + protected: + void SetUp() override { + ServiceConfigParserShutdown(); + ServiceConfigParserInit(); + EXPECT_EQ(ServiceConfigParser::RegisterParser( + absl::make_unique()), + 0); + } +}; + +TEST_F(RetryParserTest, ValidRetryThrottling) { + const char* test_json = + "{\n" + " \"retryThrottling\": {\n" + " \"maxTokens\": 2,\n" + " \"tokenRatio\": 1.0\n" + " }\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* parsed_config = + static_cast( + svc_cfg->GetGlobalParsedConfig(0)); + ASSERT_NE(parsed_config, nullptr); + EXPECT_EQ(parsed_config->max_milli_tokens(), 2000); + EXPECT_EQ(parsed_config->milli_token_ratio(), 1000); +} + +TEST_F(RetryParserTest, RetryThrottlingMissingFields) { + const char* test_json = + "{\n" + " \"retryThrottling\": {\n" + " }\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG "retryThrottling" CHILD_ERROR_TAG + "field:retryThrottling field:maxTokens error:Not found" + ".*field:retryThrottling field:tokenRatio error:Not found")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryThrottlingNegativeMaxTokens) { + const char* test_json = + "{\n" + " \"retryThrottling\": {\n" + " \"maxTokens\": -2,\n" + " \"tokenRatio\": 1.0\n" + " }\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG "retryThrottling" CHILD_ERROR_TAG + "field:retryThrottling field:maxTokens 
error:should " + "be greater than zero")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryThrottlingInvalidTokenRatio) { + const char* test_json = + "{\n" + " \"retryThrottling\": {\n" + " \"maxTokens\": 2,\n" + " \"tokenRatio\": -1\n" + " }\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex("Service config parsing error" CHILD_ERROR_TAG + "Global Params" CHILD_ERROR_TAG + "retryThrottling" CHILD_ERROR_TAG + "field:retryThrottling field:tokenRatio " + "error:Failed parsing")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, ValidRetryPolicy) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + const auto* parsed_config = + static_cast( + ((*vector_ptr)[0]).get()); + ASSERT_NE(parsed_config, nullptr); + EXPECT_EQ(parsed_config->max_attempts(), 3); + EXPECT_EQ(parsed_config->initial_backoff(), 1000); + EXPECT_EQ(parsed_config->max_backoff(), 120000); + EXPECT_EQ(parsed_config->backoff_multiplier(), 1.6f); + EXPECT_EQ(parsed_config->per_attempt_recv_timeout(), absl::nullopt); + EXPECT_TRUE( + parsed_config->retryable_status_codes().Contains(GRPC_STATUS_ABORTED)); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyWrongType) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": 5\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "field:retryPolicy error:should be of type object")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyRequiredFieldsMissing) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + ".*field:maxAttempts error:required field missing" + ".*field:initialBackoff error:does not exist" + ".*field:maxBackoff error:does not exist" + ".*field:backoffMultiplier error:required field missing")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyMaxAttemptsWrongType) { + const 
char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": \"FOO\",\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:maxAttempts error:should be of type number")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyMaxAttemptsBadValue) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 1,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:maxAttempts error:should be at least 2")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyInitialBackoffWrongType) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1sec\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:initialBackoff error:type should be STRING of the " + "form given by google.proto.Duration")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyInitialBackoffBadValue) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"0s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:initialBackoff error:must be greater than 0")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyMaxBackoffWrongType) { + 
const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120sec\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:maxBackoff error:type should be STRING of the form " + "given by google.proto.Duration")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyMaxBackoffBadValue) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"0s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:maxBackoff error:must be greater than 0")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyBackoffMultiplierWrongType) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": \"1.6\",\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:backoffMultiplier error:should be of type number")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyBackoffMultiplierBadValue) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 0,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:backoffMultiplier error:must be greater than 0")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, 
InvalidRetryPolicyEmptyRetryableStatusCodes) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": \"1.6\",\n" + " \"retryableStatusCodes\": []\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:retryableStatusCodes error:must be non-empty")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyRetryableStatusCodesWrongType) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": \"1.6\",\n" + " \"retryableStatusCodes\": 0\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:retryableStatusCodes error:must be of type array")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyUnparseableRetryableStatusCodes) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": \"1.6\",\n" + " \"retryableStatusCodes\": [\"FOO\", 2]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG "field:retryableStatusCodes " + "error:failed to parse status code" + ".*field:retryableStatusCodes " + "error:status codes should be of type string")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, ValidRetryPolicyWithPerAttemptRecvTimeout) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"perAttemptRecvTimeout\": \"1s\",\n" + " \"retryableStatusCodes\": [\"ABORTED\"]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING), 1); + grpc_channel_args args = {1, &arg}; + auto svc_cfg = ServiceConfig::Create(&args, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* 
vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + const auto* parsed_config = + static_cast( + ((*vector_ptr)[0]).get()); + ASSERT_NE(parsed_config, nullptr); + EXPECT_EQ(parsed_config->max_attempts(), 2); + EXPECT_EQ(parsed_config->initial_backoff(), 1000); + EXPECT_EQ(parsed_config->max_backoff(), 120000); + EXPECT_EQ(parsed_config->backoff_multiplier(), 1.6f); + EXPECT_EQ(parsed_config->per_attempt_recv_timeout(), 1000); + EXPECT_TRUE( + parsed_config->retryable_status_codes().Contains(GRPC_STATUS_ABORTED)); +} + +TEST_F(RetryParserTest, + ValidRetryPolicyWithPerAttemptRecvTimeoutIgnoredWhenHedgingDisabled) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"perAttemptRecvTimeout\": \"1s\",\n" + " \"retryableStatusCodes\": [\"ABORTED\"]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + const auto* parsed_config = + static_cast( + ((*vector_ptr)[0]).get()); + ASSERT_NE(parsed_config, nullptr); + EXPECT_EQ(parsed_config->max_attempts(), 2); + EXPECT_EQ(parsed_config->initial_backoff(), 1000); + EXPECT_EQ(parsed_config->max_backoff(), 120000); + EXPECT_EQ(parsed_config->backoff_multiplier(), 1.6f); + EXPECT_EQ(parsed_config->per_attempt_recv_timeout(), absl::nullopt); + EXPECT_TRUE( + parsed_config->retryable_status_codes().Contains(GRPC_STATUS_ABORTED)); +} + +TEST_F(RetryParserTest, + ValidRetryPolicyWithPerAttemptRecvTimeoutAndUnsetRetryableStatusCodes) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"perAttemptRecvTimeout\": \"1s\"\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING), 1); + grpc_channel_args args = {1, &arg}; + auto svc_cfg = ServiceConfig::Create(&args, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + const auto* parsed_config = + static_cast( + ((*vector_ptr)[0]).get()); + ASSERT_NE(parsed_config, nullptr); + EXPECT_EQ(parsed_config->max_attempts(), 2); + EXPECT_EQ(parsed_config->initial_backoff(), 1000); + EXPECT_EQ(parsed_config->max_backoff(), 120000); + EXPECT_EQ(parsed_config->backoff_multiplier(), 1.6f); + EXPECT_EQ(parsed_config->per_attempt_recv_timeout(), 1000); + EXPECT_TRUE(parsed_config->retryable_status_codes().Empty()); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyPerAttemptRecvTimeoutUnparseable) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { 
\"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": \"1.6\",\n" + " \"perAttemptRecvTimeout\": \"1sec\",\n" + " \"retryableStatusCodes\": [\"ABORTED\"]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING), 1); + grpc_channel_args args = {1, &arg}; + auto svc_cfg = ServiceConfig::Create(&args, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:perAttemptRecvTimeout error:type must be STRING " + "of the form given by google.proto.Duration.")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyPerAttemptRecvTimeoutWrongType) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": \"1.6\",\n" + " \"perAttemptRecvTimeout\": 1,\n" + " \"retryableStatusCodes\": [\"ABORTED\"]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING), 1); + grpc_channel_args args = {1, &arg}; + auto svc_cfg = ServiceConfig::Create(&args, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:perAttemptRecvTimeout error:type must be STRING " + "of the form given by google.proto.Duration.")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(RetryParserTest, InvalidRetryPolicyPerAttemptRecvTimeoutBadValue) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": \"1.6\",\n" + " \"perAttemptRecvTimeout\": \"0s\",\n" + " \"retryableStatusCodes\": [\"ABORTED\"]\n" + " }\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING), 1); + grpc_channel_args args = {1, &arg}; + auto svc_cfg = ServiceConfig::Create(&args, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "retryPolicy" CHILD_ERROR_TAG + "field:perAttemptRecvTimeout error:must be greater than 0")); + GRPC_ERROR_UNREF(error); +} + +// +// message_size parser tests +// + +class MessageSizeParserTest : public ::testing::Test { + protected: + void SetUp() override { + ServiceConfigParserShutdown(); + ServiceConfigParserInit(); + EXPECT_EQ(ServiceConfigParser::RegisterParser( + absl::make_unique()), + 0); + } +}; + +TEST_F(MessageSizeParserTest, Valid) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " 
\"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"maxRequestMessageBytes\": 1024,\n" + " \"maxResponseMessageBytes\": 1024\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + const auto* vector_ptr = svc_cfg->GetMethodParsedConfigVector( + grpc_slice_from_static_string("/TestServ/TestMethod")); + ASSERT_NE(vector_ptr, nullptr); + auto parsed_config = + static_cast(((*vector_ptr)[0]).get()); + ASSERT_NE(parsed_config, nullptr); + EXPECT_EQ(parsed_config->limits().max_send_size, 1024); + EXPECT_EQ(parsed_config->limits().max_recv_size, 1024); +} + +TEST_F(MessageSizeParserTest, InvalidMaxRequestMessageBytes) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"maxRequestMessageBytes\": -1024\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "Message size parser" CHILD_ERROR_TAG + "field:maxRequestMessageBytes error:should be non-negative")); + GRPC_ERROR_UNREF(error); +} + +TEST_F(MessageSizeParserTest, InvalidMaxResponseMessageBytes) { + const char* test_json = + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"TestServ\", \"method\": \"TestMethod\" }\n" + " ],\n" + " \"maxResponseMessageBytes\": {}\n" + " } ]\n" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + auto svc_cfg = ServiceConfig::Create(nullptr, test_json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "Service config parsing error" CHILD_ERROR_TAG + "Method Params" CHILD_ERROR_TAG "methodConfig" CHILD_ERROR_TAG + "Message size parser" CHILD_ERROR_TAG + "field:maxResponseMessageBytes error:should be of type " + "number")); + GRPC_ERROR_UNREF(error); +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/client_idle/idle_filter_state_test.cc b/test/core/client_idle/idle_filter_state_test.cc new file mode 100644 index 00000000..2911b683 --- /dev/null +++ b/test/core/client_idle/idle_filter_state_test.cc @@ -0,0 +1,109 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
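+
+// A rough sketch of the idle-tracking protocol exercised below, assuming the
+// constructor flag means "the idle timer is already running" (the flag name
+// used here is illustrative, not the real parameter name):
+//
+//   IdleFilterState s(/*timer_running=*/false);
+//   s.IncreaseCallCount();          // a call starts
+//   if (s.DecreaseCallCount()) {    // last call finished: caller starts timer
+//     while (s.CheckTimer()) {      // re-arm only while calls keep arriving
+//       // wait one idle period, then poll again
+//     }
+//   }
+//
+// Only the first transition to idle should request a timer, and CheckTimer()
+// should eventually return false once the channel stays idle.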
+ +#include "src/core/ext/filters/client_idle/idle_filter_state.h" + +#include + +#include +#include +#include + +#include + +namespace grpc_core { +namespace testing { + +TEST(IdleFilterStateTest, IdlenessStartsTimer) { + IdleFilterState s(false); + s.IncreaseCallCount(); + // First idle should start the timer + EXPECT_TRUE(s.DecreaseCallCount()); + for (int i = 0; i < 10; i++) { + // Next idle should not! + s.IncreaseCallCount(); + EXPECT_FALSE(s.DecreaseCallCount()); + } +} + +TEST(IdleFilterStateTest, TimerStopsAfterIdle) { + IdleFilterState s(true); + EXPECT_FALSE(s.CheckTimer()); +} + +TEST(IdleFilterStateTest, TimerKeepsGoingWithActivity) { + IdleFilterState s(true); + for (int i = 0; i < 10; i++) { + s.IncreaseCallCount(); + (void)s.DecreaseCallCount(); + EXPECT_TRUE(s.CheckTimer()); + } + EXPECT_FALSE(s.CheckTimer()); +} + +TEST(IdleFilterStateTest, StressTest) { + IdleFilterState s(false); + std::atomic done{false}; + int idle_polls = 0; + int thread_jumps = 0; + std::vector threads; + for (int idx = 0; idx < 100; idx++) { + std::thread t([&] { + int ctr = 0; + auto increase = [&] { + s.IncreaseCallCount(); + ctr++; + }; + auto decrease = [&] { + ctr--; + if (s.DecreaseCallCount()) { + thread_jumps++; + if (thread_jumps == 10) done.store(true, std::memory_order_relaxed); + EXPECT_EQ(ctr, 0); + do { + idle_polls++; + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while (s.CheckTimer()); + } + }; + std::mt19937 g{std::random_device()()}; + while (!done.load(std::memory_order_relaxed)) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + for (int i = 0; i < 100; i++) { + if (g() & 1) { + increase(); + } else if (ctr > 0) { + decrease(); + } + } + while (ctr > 0) { + decrease(); + } + } + while (ctr > 0) { + decrease(); + } + }); + threads.emplace_back(std::move(t)); + } + for (auto& thread : threads) thread.join(); +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/compiler_bugs/miscompile_with_no_unique_address_test.cc b/test/core/compiler_bugs/miscompile_with_no_unique_address_test.cc new file mode 100644 index 00000000..73e39f5c --- /dev/null +++ b/test/core/compiler_bugs/miscompile_with_no_unique_address_test.cc @@ -0,0 +1,59 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
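+
+// Background: GPR_NO_UNIQUE_ADDRESS expands, where the compiler supports it,
+// to the C++20 [[no_unique_address]] attribute, which allows an empty member
+// to share storage instead of occupying its own byte. A minimal illustration
+// (not part of the test; the size reduction is common practice but not
+// guaranteed by the standard):
+//
+//   struct Empty {};
+//   struct S {
+//     [[no_unique_address]] Empty e;  // may occupy no extra space
+//     int x;
+//   };
+//   // sizeof(S) == sizeof(int) on common ABIs when the attribute is honored.
+//
+// The tests below build such a layout with a captureless lambda as the empty
+// member and rely on MSAN to flag the clang 11 miscompile described in the
+// comment on P::b_.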
+
+#include <grpc/impl/codegen/port_platform.h>
+
+#include <gtest/gtest.h>
+
+// Make a template argument to test which bit pattern remains in A's destructor
+// to try and detect similar bugs in non-MSAN builds (none have been detected
+// yet, thankfully).
+template <int kInit>
+class A {
+ public:
+  ~A() { EXPECT_EQ(a_, kInit); }
+  int a_ = kInit;
+};
+template <typename T, int kInit>
+class P : A<kInit> {
+ public:
+  explicit P(T b) : b_(b) {}
+  // clang 11 with MSAN miscompiles this and marks A::a_ as uninitialized
+  // during P::~P() if GPR_NO_UNIQUE_ADDRESS is [[no_unique_address]] - so this
+  // test stands to ensure that we have a working definition for this compiler,
+  // so that we don't flag false negatives elsewhere in the codebase.
+  GPR_NO_UNIQUE_ADDRESS T b_;
+};
+
+template <int kInit, typename T>
+void c(T a) {
+  P<T, kInit> _(a);
+}
+
+TEST(Miscompile, Zero) {
+  c<0>([] {});
+}
+
+TEST(Miscompile, One) {
+  c<1>([] {});
+}
+
+TEST(Miscompile, MinusOne) {
+  c<-1>([] {});
+}
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/test/core/compression/algorithm_test.cc b/test/core/compression/algorithm_test.cc
new file mode 100644
index 00000000..b28aaf57
--- /dev/null
+++ b/test/core/compression/algorithm_test.cc
@@ -0,0 +1,115 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include +#include + +#include +#include + +#include "src/core/lib/compression/algorithm_metadata.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/util/test_config.h" + +const uint32_t message_prefix_length = 0; +const uint32_t stream_prefix_length = 7; +static void test_algorithm_mesh(void) { + int i; + + gpr_log(GPR_DEBUG, "test_algorithm_mesh"); + + for (i = 0; i < GRPC_COMPRESS_ALGORITHMS_COUNT; i++) { + const char* name; + grpc_compression_algorithm parsed; + grpc_slice mdstr; + grpc_mdelem mdelem; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT( + grpc_compression_algorithm_name((grpc_compression_algorithm)i, &name)); + GPR_ASSERT(grpc_compression_algorithm_parse( + grpc_slice_from_static_string(name), &parsed)); + GPR_ASSERT((int)parsed == i); + mdstr = grpc_slice_from_copied_string(name); + GPR_ASSERT(grpc_slice_eq(mdstr, grpc_compression_algorithm_slice(parsed))); + GPR_ASSERT(parsed == grpc_compression_algorithm_from_slice(mdstr)); + if (parsed == 0) { + continue; + } else if (grpc_compression_algorithm_is_message(parsed)) { + mdelem = grpc_message_compression_encoding_mdelem( + grpc_compression_algorithm_to_message_compression_algorithm(parsed)); + grpc_slice value = GRPC_MDVALUE(mdelem); + GPR_ASSERT(0 == memcmp(&name[message_prefix_length], + GRPC_SLICE_START_PTR(value), + GRPC_SLICE_LENGTH(value))); + GPR_ASSERT(grpc_slice_eq(GRPC_MDKEY(mdelem), GRPC_MDSTR_GRPC_ENCODING)); + } else { + mdelem = grpc_stream_compression_encoding_mdelem( + grpc_compression_algorithm_to_stream_compression_algorithm(parsed)); + grpc_slice value = GRPC_MDVALUE(mdelem); + GPR_ASSERT(0 == memcmp(&name[stream_prefix_length], + GRPC_SLICE_START_PTR(value), + GRPC_SLICE_LENGTH(value))); + GPR_ASSERT( + grpc_slice_eq(GRPC_MDKEY(mdelem), GRPC_MDSTR_CONTENT_ENCODING)); + } + grpc_slice_unref_internal(mdstr); + GRPC_MDELEM_UNREF(mdelem); + } + + /* test failure */ + GPR_ASSERT(GRPC_MDISNULL( + grpc_compression_encoding_mdelem(GRPC_COMPRESS_ALGORITHMS_COUNT))); +} + +static void test_algorithm_failure(void) { + gpr_log(GPR_DEBUG, "test_algorithm_failure"); + // Test invalid algorithm name + grpc_slice mdstr = + grpc_slice_from_static_string("this-is-an-invalid-algorithm"); + GPR_ASSERT(grpc_compression_algorithm_from_slice(mdstr) == + GRPC_COMPRESS_ALGORITHMS_COUNT); + grpc_slice_unref_internal(mdstr); + // Test invalid algorithm enum entry. + GPR_ASSERT(grpc_compression_algorithm_name(GRPC_COMPRESS_ALGORITHMS_COUNT, + nullptr) == 0); + GPR_ASSERT( + grpc_compression_algorithm_name(static_cast( + GRPC_COMPRESS_ALGORITHMS_COUNT + 1), + nullptr) == 0); + GPR_ASSERT(grpc_slice_eq( + grpc_compression_algorithm_slice(GRPC_COMPRESS_ALGORITHMS_COUNT), + grpc_empty_slice())); + GPR_ASSERT(grpc_slice_eq( + grpc_compression_algorithm_slice(static_cast( + static_cast(GRPC_COMPRESS_ALGORITHMS_COUNT) + 1)), + grpc_empty_slice())); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + test_algorithm_mesh(); + test_algorithm_failure(); + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/compression/compression_test.cc b/test/core/compression/compression_test.cc new file mode 100644 index 00000000..2bc5ad20 --- /dev/null +++ b/test/core/compression/compression_test.cc @@ -0,0 +1,349 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/util/test_config.h" + +static void test_compression_algorithm_parse(void) { + size_t i; + const char* valid_names[] = {"identity", "gzip", "deflate", "stream/gzip"}; + const grpc_compression_algorithm valid_algorithms[] = { + GRPC_COMPRESS_NONE, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_DEFLATE, + GRPC_COMPRESS_STREAM_GZIP}; + const char* invalid_names[] = {"gzip2", "foo", "", "2gzip"}; + + gpr_log(GPR_DEBUG, "test_compression_algorithm_parse"); + + for (i = 0; i < GPR_ARRAY_SIZE(valid_names); i++) { + const char* valid_name = valid_names[i]; + grpc_compression_algorithm algorithm; + const int success = grpc_compression_algorithm_parse( + grpc_slice_from_static_string(valid_name), &algorithm); + GPR_ASSERT(success != 0); + GPR_ASSERT(algorithm == valid_algorithms[i]); + } + + for (i = 0; i < GPR_ARRAY_SIZE(invalid_names); i++) { + const char* invalid_name = invalid_names[i]; + grpc_compression_algorithm algorithm; + int success; + success = grpc_compression_algorithm_parse( + grpc_slice_from_static_string(invalid_name), &algorithm); + GPR_ASSERT(success == 0); + /* the value of "algorithm" is undefined upon failure */ + } +} + +static void test_compression_algorithm_name(void) { + int success; + const char* name; + size_t i; + const char* valid_names[] = {"identity", "gzip", "deflate", "stream/gzip"}; + const grpc_compression_algorithm valid_algorithms[] = { + GRPC_COMPRESS_NONE, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_DEFLATE, + GRPC_COMPRESS_STREAM_GZIP}; + + gpr_log(GPR_DEBUG, "test_compression_algorithm_name"); + + for (i = 0; i < GPR_ARRAY_SIZE(valid_algorithms); i++) { + success = grpc_compression_algorithm_name(valid_algorithms[i], &name); + GPR_ASSERT(success != 0); + GPR_ASSERT(strcmp(name, valid_names[i]) == 0); + } + + success = + grpc_compression_algorithm_name(GRPC_COMPRESS_ALGORITHMS_COUNT, &name); + GPR_ASSERT(success == 0); + /* the value of "name" is undefined upon failure */ +} + +static void test_compression_algorithm_for_level(void) { + gpr_log(GPR_DEBUG, "test_compression_algorithm_for_level"); + + { + /* accept only identity (aka none) */ + uint32_t accepted_encodings = 0; + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_NONE); /* always */ + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_NONE, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_LOW, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_MED, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_HIGH, + accepted_encodings)); + } + + { + /* accept 
only gzip */ + uint32_t accepted_encodings = 0; + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_NONE); /* always */ + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_GZIP); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_NONE, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_GZIP == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_LOW, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_GZIP == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_MED, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_GZIP == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_HIGH, + accepted_encodings)); + } + + { + /* accept only deflate */ + uint32_t accepted_encodings = 0; + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_NONE); /* always */ + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_DEFLATE); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_NONE, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_DEFLATE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_LOW, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_DEFLATE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_MED, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_DEFLATE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_HIGH, + accepted_encodings)); + } + + { + /* accept gzip and deflate */ + uint32_t accepted_encodings = 0; + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_NONE); /* always */ + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_GZIP); + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_DEFLATE); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_NONE, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_GZIP == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_LOW, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_DEFLATE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_MED, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_DEFLATE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_HIGH, + accepted_encodings)); + } + + { + /* accept stream gzip */ + uint32_t accepted_encodings = 0; + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_NONE); /* always */ + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_STREAM_GZIP); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_NONE, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_LOW, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_MED, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_HIGH, + accepted_encodings)); + } + + { + /* accept all algorithms */ + uint32_t accepted_encodings = 0; + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_NONE); /* always */ + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_GZIP); + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_DEFLATE); + grpc_core::SetBit(&accepted_encodings, GRPC_COMPRESS_STREAM_GZIP); + + GPR_ASSERT(GRPC_COMPRESS_NONE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_NONE, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_GZIP == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_LOW, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_DEFLATE == + 
grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_MED, + accepted_encodings)); + + GPR_ASSERT(GRPC_COMPRESS_DEFLATE == + grpc_compression_algorithm_for_level(GRPC_COMPRESS_LEVEL_HIGH, + accepted_encodings)); + } +} + +static void test_compression_enable_disable_algorithm(void) { + grpc_compression_options options; + grpc_compression_algorithm algorithm; + + gpr_log(GPR_DEBUG, "test_compression_enable_disable_algorithm"); + + grpc_compression_options_init(&options); + for (algorithm = GRPC_COMPRESS_NONE; + algorithm < GRPC_COMPRESS_ALGORITHMS_COUNT; + algorithm = static_cast( + static_cast(algorithm) + 1)) { + /* all algorithms are enabled by default */ + GPR_ASSERT(grpc_compression_options_is_algorithm_enabled(&options, + algorithm) != 0); + } + /* disable one by one */ + for (algorithm = GRPC_COMPRESS_NONE; + algorithm < GRPC_COMPRESS_ALGORITHMS_COUNT; + algorithm = static_cast( + static_cast(algorithm) + 1)) { + grpc_compression_options_disable_algorithm(&options, algorithm); + GPR_ASSERT(grpc_compression_options_is_algorithm_enabled(&options, + algorithm) == 0); + } + /* re-enable one by one */ + for (algorithm = GRPC_COMPRESS_NONE; + algorithm < GRPC_COMPRESS_ALGORITHMS_COUNT; + algorithm = static_cast( + static_cast(algorithm) + 1)) { + grpc_compression_options_enable_algorithm(&options, algorithm); + GPR_ASSERT(grpc_compression_options_is_algorithm_enabled(&options, + algorithm) != 0); + } +} + +static void test_channel_args_set_compression_algorithm(void) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args* ch_args; + + ch_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, GRPC_COMPRESS_GZIP); + GPR_ASSERT(ch_args->num_args == 1); + GPR_ASSERT(strcmp(ch_args->args[0].key, + GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM) == 0); + GPR_ASSERT(ch_args->args[0].type == GRPC_ARG_INTEGER); + + grpc_channel_args_destroy(ch_args); +} + +static void test_channel_args_compression_algorithm_states(void) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args *ch_args, *ch_args_wo_gzip, *ch_args_wo_gzip_deflate, + *ch_args_wo_gzip_deflate_gzip; + unsigned states_bitset; + size_t i; + + ch_args = grpc_channel_args_copy_and_add(nullptr, nullptr, 0); + /* by default, all enabled */ + states_bitset = static_cast( + grpc_channel_args_compression_algorithm_get_states(ch_args)); + + for (i = 0; i < GRPC_COMPRESS_ALGORITHMS_COUNT; i++) { + GPR_ASSERT(grpc_core::GetBit(states_bitset, i)); + } + + /* disable gzip and deflate and stream/gzip */ + ch_args_wo_gzip = grpc_channel_args_compression_algorithm_set_state( + &ch_args, GRPC_COMPRESS_GZIP, 0); + GPR_ASSERT(ch_args == ch_args_wo_gzip); + ch_args_wo_gzip_deflate = grpc_channel_args_compression_algorithm_set_state( + &ch_args_wo_gzip, GRPC_COMPRESS_DEFLATE, 0); + GPR_ASSERT(ch_args_wo_gzip == ch_args_wo_gzip_deflate); + ch_args_wo_gzip_deflate_gzip = + grpc_channel_args_compression_algorithm_set_state( + &ch_args_wo_gzip_deflate, GRPC_COMPRESS_STREAM_GZIP, 0); + GPR_ASSERT(ch_args_wo_gzip_deflate == ch_args_wo_gzip_deflate_gzip); + + states_bitset = + static_cast(grpc_channel_args_compression_algorithm_get_states( + ch_args_wo_gzip_deflate)); + for (i = 0; i < GRPC_COMPRESS_ALGORITHMS_COUNT; i++) { + if (i == GRPC_COMPRESS_GZIP || i == GRPC_COMPRESS_DEFLATE || + i == GRPC_COMPRESS_STREAM_GZIP) { + GPR_ASSERT(grpc_core::GetBit(states_bitset, i) == 0); + } else { + GPR_ASSERT(grpc_core::GetBit(states_bitset, i) != 0); + } + } + + /* re-enabled gzip and stream/gzip only */ + ch_args_wo_gzip = 
grpc_channel_args_compression_algorithm_set_state( + &ch_args_wo_gzip_deflate_gzip, GRPC_COMPRESS_GZIP, 1); + ch_args_wo_gzip = grpc_channel_args_compression_algorithm_set_state( + &ch_args_wo_gzip, GRPC_COMPRESS_STREAM_GZIP, 1); + GPR_ASSERT(ch_args_wo_gzip == ch_args_wo_gzip_deflate_gzip); + + states_bitset = static_cast( + grpc_channel_args_compression_algorithm_get_states(ch_args_wo_gzip)); + for (i = 0; i < GRPC_COMPRESS_ALGORITHMS_COUNT; i++) { + if (i == GRPC_COMPRESS_DEFLATE) { + GPR_ASSERT(grpc_core::GetBit(states_bitset, i) == 0); + } else { + GPR_ASSERT(grpc_core::GetBit(states_bitset, i) != 0); + } + } + + grpc_channel_args_destroy(ch_args); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_compression_algorithm_parse(); + test_compression_algorithm_name(); + test_compression_algorithm_for_level(); + test_compression_enable_disable_algorithm(); + test_channel_args_set_compression_algorithm(); + test_channel_args_compression_algorithm_states(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/compression/message_compress_fuzzer.cc b/test/core/compression/message_compress_fuzzer.cc new file mode 100644 index 00000000..c698e901 --- /dev/null +++ b/test/core/compression/message_compress_fuzzer.cc @@ -0,0 +1,57 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#include "src/core/lib/compression/message_compress.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "test/core/util/memory_counters.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (size < 1) return 0; + + // Instead of rolling something complicated to convert a uint8_t to the enum, + // just bail out if it isn't trivially convertible. + if (data[0] >= GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) return 0; + const auto compression_algorithm = + static_cast(data[0]); + + grpc_core::testing::LeakDetector leak_detector(true); + grpc_init(); + grpc_slice_buffer input_buffer; + grpc_slice_buffer_init(&input_buffer); + grpc_slice_buffer_add(&input_buffer, + grpc_slice_from_copied_buffer( + reinterpret_cast(data + 1), size - 1)); + grpc_slice_buffer output_buffer; + grpc_slice_buffer_init(&output_buffer); + + grpc_msg_compress(compression_algorithm, &input_buffer, &output_buffer); + + grpc_slice_buffer_destroy(&input_buffer); + grpc_slice_buffer_destroy(&output_buffer); + grpc_shutdown(); + return 0; +} diff --git a/test/core/compression/message_compress_test.cc b/test/core/compression/message_compress_test.cc new file mode 100644 index 00000000..d11e8e42 --- /dev/null +++ b/test/core/compression/message_compress_test.cc @@ -0,0 +1,352 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/compression/message_compress.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/murmur_hash.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/util/slice_splitter.h" +#include "test/core/util/test_config.h" + +typedef enum { ONE_A = 0, ONE_KB_A, ONE_MB_A, TEST_VALUE_COUNT } test_value; + +typedef enum { + SHOULD_NOT_COMPRESS, + SHOULD_COMPRESS, + MAYBE_COMPRESSES +} compressability; + +static void assert_passthrough(grpc_slice value, + grpc_message_compression_algorithm algorithm, + grpc_slice_split_mode uncompressed_split_mode, + grpc_slice_split_mode compressed_split_mode, + compressability compress_result_check) { + grpc_slice_buffer input; + grpc_slice_buffer compressed_raw; + grpc_slice_buffer compressed; + grpc_slice_buffer output; + grpc_slice final; + int was_compressed; + const char* algorithm_name; + + GPR_ASSERT( + grpc_message_compression_algorithm_name(algorithm, &algorithm_name) != 0); + gpr_log(GPR_INFO, + "assert_passthrough: value_length=%" PRIuPTR + " value_hash=0x%08x " + "algorithm='%s' uncompressed_split='%s' compressed_split='%s'", + GRPC_SLICE_LENGTH(value), + gpr_murmur_hash3(GRPC_SLICE_START_PTR(value), + GRPC_SLICE_LENGTH(value), 0), + algorithm_name, grpc_slice_split_mode_name(uncompressed_split_mode), + grpc_slice_split_mode_name(compressed_split_mode)); + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&compressed_raw); + grpc_slice_buffer_init(&compressed); + grpc_slice_buffer_init(&output); + + grpc_split_slices_to_buffer(uncompressed_split_mode, &value, 1, &input); + + { + grpc_core::ExecCtx exec_ctx; + was_compressed = grpc_msg_compress(algorithm, &input, &compressed_raw); + } + GPR_ASSERT(input.count > 0); + + switch (compress_result_check) { + case SHOULD_NOT_COMPRESS: + GPR_ASSERT(was_compressed == 0); + break; + case SHOULD_COMPRESS: + GPR_ASSERT(was_compressed == 1); + break; + case MAYBE_COMPRESSES: + /* no check */ + break; + } + + grpc_split_slice_buffer(compressed_split_mode, &compressed_raw, &compressed); + + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(grpc_msg_decompress( + was_compressed ? 
algorithm : GRPC_MESSAGE_COMPRESS_NONE, &compressed, + &output)); + } + + final = grpc_slice_merge(output.slices, output.count); + GPR_ASSERT(grpc_slice_eq(value, final)); + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&compressed); + grpc_slice_buffer_destroy(&compressed_raw); + grpc_slice_buffer_destroy(&output); + grpc_slice_unref(final); +} + +static grpc_slice repeated(char c, size_t length) { + grpc_slice out = grpc_slice_malloc(length); + memset(GRPC_SLICE_START_PTR(out), c, length); + return out; +} + +static compressability get_compressability( + test_value id, grpc_message_compression_algorithm algorithm) { + if (algorithm == GRPC_MESSAGE_COMPRESS_NONE) return SHOULD_NOT_COMPRESS; + switch (id) { + case ONE_A: + return SHOULD_NOT_COMPRESS; + case ONE_KB_A: + case ONE_MB_A: + return SHOULD_COMPRESS; + case TEST_VALUE_COUNT: + abort(); + } + return MAYBE_COMPRESSES; +} + +static grpc_slice create_test_value(test_value id) { + switch (id) { + case ONE_A: + return grpc_slice_from_copied_string("a"); + case ONE_KB_A: + return repeated('a', 1024); + case ONE_MB_A: + return repeated('a', 1024 * 1024); + case TEST_VALUE_COUNT: + abort(); + } + return grpc_slice_from_copied_string("bad value"); +} + +static void test_tiny_data_compress(void) { + grpc_slice_buffer input; + grpc_slice_buffer output; + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&output); + grpc_slice_buffer_add(&input, create_test_value(ONE_A)); + + for (int i = 0; i < GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT; i++) { + if (i == GRPC_MESSAGE_COMPRESS_NONE) continue; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(0 == grpc_msg_compress( + + static_cast(i), + &input, &output)); + GPR_ASSERT(1 == output.count); + } + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&output); +} + +static void test_bad_decompression_data_crc(void) { + grpc_slice_buffer input; + grpc_slice_buffer corrupted; + grpc_slice_buffer output; + size_t idx; + const uint32_t bad = 0xdeadbeef; + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&corrupted); + grpc_slice_buffer_init(&output); + grpc_slice_buffer_add(&input, create_test_value(ONE_MB_A)); + + grpc_core::ExecCtx exec_ctx; + /* compress it */ + grpc_msg_compress(GRPC_MESSAGE_COMPRESS_GZIP, &input, &corrupted); + /* corrupt the output by smashing the CRC */ + GPR_ASSERT(corrupted.count > 1); + GPR_ASSERT(GRPC_SLICE_LENGTH(corrupted.slices[1]) > 8); + idx = GRPC_SLICE_LENGTH(corrupted.slices[1]) - 8; + memcpy(GRPC_SLICE_START_PTR(corrupted.slices[1]) + idx, &bad, 4); + + /* try (and fail) to decompress the corrupted compresed buffer */ + GPR_ASSERT(0 == grpc_msg_decompress(GRPC_MESSAGE_COMPRESS_GZIP, &corrupted, + &output)); + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&corrupted); + grpc_slice_buffer_destroy(&output); +} + +static void test_bad_decompression_data_missing_trailer(void) { + grpc_slice_buffer input; + grpc_slice_buffer decompressed; + grpc_slice_buffer garbage; + grpc_slice_buffer output; + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&decompressed); + grpc_slice_buffer_init(&garbage); + grpc_slice_buffer_init(&output); + grpc_slice_buffer_add(&input, create_test_value(ONE_MB_A)); + + grpc_core::ExecCtx exec_ctx; + /* compress it */ + grpc_msg_compress(GRPC_MESSAGE_COMPRESS_GZIP, &input, &decompressed); + GPR_ASSERT(decompressed.length > 8); + /* Remove the footer from the decompressed message */ + grpc_slice_buffer_trim_end(&decompressed, 8, &garbage); + /* try (and fail) to decompress the 
compressed buffer without the footer */ + GPR_ASSERT(0 == grpc_msg_decompress(GRPC_MESSAGE_COMPRESS_GZIP, &decompressed, + &output)); + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&decompressed); + grpc_slice_buffer_destroy(&garbage); + grpc_slice_buffer_destroy(&output); +} + +static void test_bad_decompression_data_trailing_garbage(void) { + grpc_slice_buffer input; + grpc_slice_buffer output; + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&output); + /* append 0x99 to the end of an otherwise valid stream */ + grpc_slice_buffer_add( + &input, grpc_slice_from_copied_buffer( + "\x78\xda\x63\x60\x60\x60\x00\x00\x00\x04\x00\x01\x99", 13)); + + /* try (and fail) to decompress the invalid compresed buffer */ + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT( + 0 == grpc_msg_decompress(GRPC_MESSAGE_COMPRESS_DEFLATE, &input, &output)); + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&output); +} + +static void test_bad_decompression_data_stream(void) { + grpc_slice_buffer input; + grpc_slice_buffer output; + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&output); + grpc_slice_buffer_add(&input, + grpc_slice_from_copied_buffer("\x78\xda\xff\xff", 4)); + + /* try (and fail) to decompress the invalid compresed buffer */ + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT( + 0 == grpc_msg_decompress(GRPC_MESSAGE_COMPRESS_DEFLATE, &input, &output)); + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&output); +} + +static void test_bad_compression_algorithm(void) { + grpc_slice_buffer input; + grpc_slice_buffer output; + int was_compressed; + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&output); + grpc_slice_buffer_add( + &input, grpc_slice_from_copied_string("Never gonna give you up")); + + grpc_core::ExecCtx exec_ctx; + was_compressed = grpc_msg_compress(GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT, + &input, &output); + GPR_ASSERT(0 == was_compressed); + + was_compressed = + grpc_msg_compress(static_cast( + GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT + 123), + &input, &output); + GPR_ASSERT(0 == was_compressed); + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&output); +} + +static void test_bad_decompression_algorithm(void) { + grpc_slice_buffer input; + grpc_slice_buffer output; + int was_decompressed; + + grpc_slice_buffer_init(&input); + grpc_slice_buffer_init(&output); + grpc_slice_buffer_add(&input, + grpc_slice_from_copied_string( + "I'm not really compressed but it doesn't matter")); + grpc_core::ExecCtx exec_ctx; + was_decompressed = grpc_msg_decompress(GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT, + &input, &output); + GPR_ASSERT(0 == was_decompressed); + + was_decompressed = + grpc_msg_decompress(static_cast( + GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT + 123), + &input, &output); + GPR_ASSERT(0 == was_decompressed); + + grpc_slice_buffer_destroy(&input); + grpc_slice_buffer_destroy(&output); +} + +int main(int argc, char** argv) { + unsigned i, j, k, m; + grpc_slice_split_mode uncompressed_split_modes[] = { + GRPC_SLICE_SPLIT_IDENTITY, GRPC_SLICE_SPLIT_ONE_BYTE}; + grpc_slice_split_mode compressed_split_modes[] = {GRPC_SLICE_SPLIT_MERGE_ALL, + GRPC_SLICE_SPLIT_IDENTITY, + GRPC_SLICE_SPLIT_ONE_BYTE}; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + for (i = 0; i < GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT; i++) { + for (j = 0; j < GPR_ARRAY_SIZE(uncompressed_split_modes); j++) { + for (k = 0; k < GPR_ARRAY_SIZE(compressed_split_modes); k++) { + for (m = 0; m < TEST_VALUE_COUNT; m++) { + 
grpc_slice slice = create_test_value(static_cast(m)); + assert_passthrough( + slice, static_cast(i), + static_cast(j), + static_cast(k), + get_compressability( + static_cast(m), + static_cast(i))); + grpc_slice_unref(slice); + } + } + } + } + + test_tiny_data_compress(); + test_bad_decompression_data_crc(); + test_bad_decompression_data_missing_trailer(); + test_bad_decompression_data_stream(); + test_bad_decompression_data_trailing_garbage(); + test_bad_compression_algorithm(); + test_bad_decompression_algorithm(); + grpc_shutdown(); + + return 0; +} diff --git a/test/core/compression/message_decompress_fuzzer.cc b/test/core/compression/message_decompress_fuzzer.cc new file mode 100644 index 00000000..1d6eb45e --- /dev/null +++ b/test/core/compression/message_decompress_fuzzer.cc @@ -0,0 +1,57 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#include "src/core/lib/compression/message_compress.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "test/core/util/memory_counters.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (size < 1) return 0; + + // Instead of rolling something complicated to convert a uint8_t to the enum, + // just bail out if it isn't trivially convertible. + if (data[0] >= GRPC_MESSAGE_COMPRESS_ALGORITHMS_COUNT) return 0; + const auto compression_algorithm = + static_cast(data[0]); + + grpc_core::testing::LeakDetector leak_detector(true); + grpc_init(); + grpc_slice_buffer input_buffer; + grpc_slice_buffer_init(&input_buffer); + grpc_slice_buffer_add(&input_buffer, + grpc_slice_from_copied_buffer( + reinterpret_cast(data + 1), size - 1)); + grpc_slice_buffer output_buffer; + grpc_slice_buffer_init(&output_buffer); + + grpc_msg_decompress(compression_algorithm, &input_buffer, &output_buffer); + + grpc_slice_buffer_destroy(&input_buffer); + grpc_slice_buffer_destroy(&output_buffer); + grpc_shutdown(); + return 0; +} diff --git a/test/core/compression/stream_compression_fuzzer.cc b/test/core/compression/stream_compression_fuzzer.cc new file mode 100644 index 00000000..55bc4d6a --- /dev/null +++ b/test/core/compression/stream_compression_fuzzer.cc @@ -0,0 +1,53 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include + +#include "src/core/lib/compression/stream_compression.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "test/core/util/memory_counters.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_core::testing::LeakDetector leak_detector(true); + grpc_init(); + auto* context = grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_COMPRESS); + grpc_slice_buffer input_buffer; + grpc_slice_buffer_init(&input_buffer); + grpc_slice_buffer_add( + &input_buffer, + grpc_slice_from_copied_buffer(reinterpret_cast(data), size)); + grpc_slice_buffer output_buffer; + grpc_slice_buffer_init(&output_buffer); + + grpc_stream_compress(context, &input_buffer, &output_buffer, nullptr, + SIZE_MAX, GRPC_STREAM_COMPRESSION_FLUSH_SYNC); + + grpc_stream_compression_context_destroy(context); + grpc_slice_buffer_destroy(&input_buffer); + grpc_slice_buffer_destroy(&output_buffer); + grpc_shutdown(); + return 0; +} diff --git a/test/core/compression/stream_compression_test.cc b/test/core/compression/stream_compression_test.cc new file mode 100644 index 00000000..252652f0 --- /dev/null +++ b/test/core/compression/stream_compression_test.cc @@ -0,0 +1,302 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/compression/stream_compression.h" + +#include + +#include +#include +#include +#include + +#include "test/core/util/test_config.h" + +static void generate_random_payload(char* payload, size_t size) { + size_t i; + static const char chars[] = "abcdefghijklmnopqrstuvwxyz1234567890"; + for (i = 0; i < size - 1; ++i) { + payload[i] = chars[rand() % static_cast(sizeof(chars) - 1)]; + } + payload[size - 1] = '\0'; +} + +static bool slice_buffer_equals_string(grpc_slice_buffer* buf, + const char* str) { + size_t i; + if (buf->length != strlen(str)) { + return false; + } + size_t pointer = 0; + for (i = 0; i < buf->count; i++) { + size_t slice_len = GRPC_SLICE_LENGTH(buf->slices[i]); + if (0 != + strncmp(str + pointer, + reinterpret_cast GRPC_SLICE_START_PTR(buf->slices[i]), + slice_len)) { + return false; + } + pointer += slice_len; + } + return true; +} + +static void test_stream_compression_simple_compress_decompress() { + const char test_str[] = "aaaaaaabbbbbbbccccccctesttesttest"; + grpc_slice_buffer source, relay, sink; + grpc_slice_buffer_init(&source); + grpc_slice_buffer_init(&relay); + grpc_slice_buffer_init(&sink); + grpc_stream_compression_context* compress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_COMPRESS); + grpc_stream_compression_context* decompress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + grpc_slice slice = grpc_slice_from_static_string(test_str); + grpc_slice_buffer_add(&source, slice); + GPR_ASSERT(grpc_stream_compress(compress_ctx, &source, &relay, nullptr, + ~(size_t)0, + GRPC_STREAM_COMPRESSION_FLUSH_FINISH)); + bool end_of_context; + size_t output_size; + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + ~(size_t)0, &end_of_context)); + GPR_ASSERT(output_size == sizeof(test_str) - 1); + grpc_stream_compression_context_destroy(compress_ctx); + grpc_stream_compression_context_destroy(decompress_ctx); + + GPR_ASSERT(slice_buffer_equals_string(&sink, test_str)); + + grpc_slice_buffer_destroy(&source); + grpc_slice_buffer_destroy(&relay); + grpc_slice_buffer_destroy(&sink); +} + +static void +test_stream_compression_simple_compress_decompress_with_output_size_constraint() { + const char test_str[] = "aaaaaaabbbbbbbccccccctesttesttest"; + grpc_slice_buffer source, relay, sink; + grpc_slice_buffer_init(&source); + grpc_slice_buffer_init(&relay); + grpc_slice_buffer_init(&sink); + grpc_stream_compression_context* compress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_COMPRESS); + grpc_stream_compression_context* decompress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + grpc_slice slice = grpc_slice_from_static_string(test_str); + grpc_slice_buffer_add(&source, slice); + GPR_ASSERT(grpc_stream_compress(compress_ctx, &source, &relay, nullptr, + ~(size_t)0, + GRPC_STREAM_COMPRESSION_FLUSH_FINISH)); + grpc_stream_compression_context_destroy(compress_ctx); + + bool end_of_context; + size_t output_size; + size_t max_output_size = 2; + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + max_output_size, &end_of_context)); + GPR_ASSERT(output_size == max_output_size); + GPR_ASSERT(end_of_context == false); + grpc_slice slice_recv = grpc_slice_buffer_take_first(&sink); + char* str_recv = reinterpret_cast GRPC_SLICE_START_PTR(slice_recv); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice_recv) == max_output_size); + GPR_ASSERT(0 == 
strncmp(test_str, str_recv, max_output_size)); + grpc_slice_unref(slice_recv); + + size_t remaining_size = sizeof(test_str) - 1 - max_output_size; + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + remaining_size, &end_of_context)); + GPR_ASSERT(output_size == remaining_size); + GPR_ASSERT(end_of_context == true); + + GPR_ASSERT(slice_buffer_equals_string(&sink, test_str + max_output_size)); + + grpc_stream_compression_context_destroy(decompress_ctx); + grpc_slice_buffer_destroy(&source); + grpc_slice_buffer_destroy(&relay); + grpc_slice_buffer_destroy(&sink); +} + +#define LARGE_DATA_SIZE (1024 * 1024) +static void +test_stream_compression_simple_compress_decompress_with_large_data() { + char* test_str = + static_cast(gpr_malloc(LARGE_DATA_SIZE * sizeof(char))); + generate_random_payload(test_str, LARGE_DATA_SIZE); + grpc_slice_buffer source, relay, sink; + grpc_slice_buffer_init(&source); + grpc_slice_buffer_init(&relay); + grpc_slice_buffer_init(&sink); + grpc_stream_compression_context* compress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_COMPRESS); + grpc_stream_compression_context* decompress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + grpc_slice slice = grpc_slice_from_static_string(test_str); + grpc_slice_buffer_add(&source, slice); + GPR_ASSERT(grpc_stream_compress(compress_ctx, &source, &relay, nullptr, + ~(size_t)0, + GRPC_STREAM_COMPRESSION_FLUSH_FINISH)); + bool end_of_context; + size_t output_size; + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + ~(size_t)0, &end_of_context)); + GPR_ASSERT(output_size == LARGE_DATA_SIZE - 1); + grpc_stream_compression_context_destroy(compress_ctx); + grpc_stream_compression_context_destroy(decompress_ctx); + + GPR_ASSERT(slice_buffer_equals_string(&sink, test_str)); + + grpc_slice_buffer_destroy(&source); + grpc_slice_buffer_destroy(&relay); + grpc_slice_buffer_destroy(&sink); + gpr_free(test_str); +} + +static void test_stream_compression_drop_context() { + const char test_str[] = "aaaaaaabbbbbbbccccccc"; + const char test_str2[] = "dddddddeeeeeeefffffffggggg"; + grpc_slice_buffer source, relay, sink; + grpc_slice_buffer_init(&source); + grpc_slice_buffer_init(&relay); + grpc_slice_buffer_init(&sink); + grpc_stream_compression_context* compress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_COMPRESS); + grpc_slice slice = grpc_slice_from_static_string(test_str); + grpc_slice_buffer_add(&source, slice); + GPR_ASSERT(grpc_stream_compress(compress_ctx, &source, &relay, nullptr, + ~(size_t)0, + GRPC_STREAM_COMPRESSION_FLUSH_FINISH)); + grpc_stream_compression_context_destroy(compress_ctx); + + compress_ctx = grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_COMPRESS); + slice = grpc_slice_from_static_string(test_str2); + grpc_slice_buffer_add(&source, slice); + GPR_ASSERT(grpc_stream_compress(compress_ctx, &source, &relay, nullptr, + ~(size_t)0, + GRPC_STREAM_COMPRESSION_FLUSH_FINISH)); + grpc_stream_compression_context_destroy(compress_ctx); + + /* Concatenate the two compressed sliced into one to test decompressing two + * contexts */ + grpc_slice slice1 = grpc_slice_buffer_take_first(&relay); + grpc_slice slice2 = grpc_slice_buffer_take_first(&relay); + grpc_slice slice3 = + grpc_slice_malloc(GRPC_SLICE_LENGTH(slice1) + GRPC_SLICE_LENGTH(slice2)); + memcpy(GRPC_SLICE_START_PTR(slice3), GRPC_SLICE_START_PTR(slice1), + 
GRPC_SLICE_LENGTH(slice1)); + memcpy(GRPC_SLICE_START_PTR(slice3) + GRPC_SLICE_LENGTH(slice1), + GRPC_SLICE_START_PTR(slice2), GRPC_SLICE_LENGTH(slice2)); + grpc_slice_unref(slice1); + grpc_slice_unref(slice2); + grpc_slice_buffer_add(&relay, slice3); + + grpc_stream_compression_context* decompress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + bool end_of_context; + size_t output_size; + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + ~(size_t)0, &end_of_context)); + GPR_ASSERT(end_of_context == true); + GPR_ASSERT(output_size == sizeof(test_str) - 1); + + GPR_ASSERT(slice_buffer_equals_string(&sink, test_str)); + grpc_stream_compression_context_destroy(decompress_ctx); + grpc_slice_buffer_destroy(&sink); + + grpc_slice_buffer_init(&sink); + decompress_ctx = grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + ~(size_t)0, &end_of_context)); + GPR_ASSERT(end_of_context == true); + GPR_ASSERT(output_size == sizeof(test_str2) - 1); + GPR_ASSERT(slice_buffer_equals_string(&sink, test_str2)); + grpc_stream_compression_context_destroy(decompress_ctx); + + grpc_slice_buffer_destroy(&source); + grpc_slice_buffer_destroy(&relay); + grpc_slice_buffer_destroy(&sink); +} + +static void test_stream_compression_sync_flush() { + const char test_str[] = "aaaaaaabbbbbbbccccccc"; + const char test_str2[] = "dddddddeeeeeeefffffffggggg"; + grpc_slice_buffer source, relay, sink; + grpc_slice_buffer_init(&source); + grpc_slice_buffer_init(&relay); + grpc_slice_buffer_init(&sink); + grpc_stream_compression_context* compress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_COMPRESS); + grpc_slice slice = grpc_slice_from_static_string(test_str); + grpc_slice_buffer_add(&source, slice); + GPR_ASSERT(grpc_stream_compress(compress_ctx, &source, &relay, nullptr, + ~(size_t)0, + GRPC_STREAM_COMPRESSION_FLUSH_SYNC)); + + grpc_stream_compression_context* decompress_ctx = + grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + bool end_of_context; + size_t output_size; + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + ~(size_t)0, &end_of_context)); + GPR_ASSERT(end_of_context == false); + GPR_ASSERT(output_size == sizeof(test_str) - 1); + GPR_ASSERT(slice_buffer_equals_string(&sink, test_str)); + grpc_slice_buffer_destroy(&sink); + + grpc_slice_buffer_init(&sink); + slice = grpc_slice_from_static_string(test_str2); + grpc_slice_buffer_add(&source, slice); + GPR_ASSERT(grpc_stream_compress(compress_ctx, &source, &relay, nullptr, + ~(size_t)0, + GRPC_STREAM_COMPRESSION_FLUSH_FINISH)); + grpc_stream_compression_context_destroy(compress_ctx); + + GPR_ASSERT(grpc_stream_decompress(decompress_ctx, &relay, &sink, &output_size, + ~(size_t)0, &end_of_context)); + GPR_ASSERT(end_of_context == true); + GPR_ASSERT(output_size == sizeof(test_str2) - 1); + GPR_ASSERT(slice_buffer_equals_string(&sink, test_str2)); + grpc_stream_compression_context_destroy(decompress_ctx); + + grpc_slice_buffer_destroy(&source); + grpc_slice_buffer_destroy(&relay); + grpc_slice_buffer_destroy(&sink); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_stream_compression_simple_compress_decompress(); + test_stream_compression_simple_compress_decompress_with_output_size_constraint(); + 
test_stream_compression_simple_compress_decompress_with_large_data(); + test_stream_compression_sync_flush(); + test_stream_compression_drop_context(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/compression/stream_decompression_fuzzer.cc b/test/core/compression/stream_decompression_fuzzer.cc new file mode 100644 index 00000000..cbfbd5d4 --- /dev/null +++ b/test/core/compression/stream_decompression_fuzzer.cc @@ -0,0 +1,54 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#include "src/core/lib/compression/stream_compression.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "test/core/util/memory_counters.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_core::testing::LeakDetector leak_detector(true); + grpc_init(); + auto* context = grpc_stream_compression_context_create( + GRPC_STREAM_COMPRESSION_GZIP_DECOMPRESS); + grpc_slice_buffer input_buffer; + grpc_slice_buffer_init(&input_buffer); + grpc_slice_buffer_add( + &input_buffer, + grpc_slice_from_copied_buffer(reinterpret_cast(data), size)); + grpc_slice_buffer output_buffer; + grpc_slice_buffer_init(&output_buffer); + bool end_of_context; + + grpc_stream_decompress(context, &input_buffer, &output_buffer, nullptr, + SIZE_MAX, &end_of_context); + + grpc_stream_compression_context_destroy(context); + grpc_slice_buffer_destroy(&input_buffer); + grpc_slice_buffer_destroy(&output_buffer); + grpc_shutdown(); + return 0; +} diff --git a/test/core/config/core_configuration_test.cc b/test/core/config/core_configuration_test.cc new file mode 100644 index 00000000..814c383a --- /dev/null +++ b/test/core/config/core_configuration_test.cc @@ -0,0 +1,73 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
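+
+// These tests substitute a test-local builder for the real plugin registration
+// path: the BuildCoreConfiguration() override below simply forwards to
+// g_mock_builder. A minimal sketch of the pattern each test follows (names
+// match the helpers defined in this file):
+//
+//   CoreConfiguration::Reset();        // drop any previously built config
+//   g_mock_builder = [](CoreConfiguration::Builder* b) { /* register bits */ };
+//   CoreConfiguration::Get();          // builds via the mock builder
+//   g_mock_builder = nullptr;
+//
+// ThreadedInit races many concurrent Get() calls against a deliberately slow
+// builder to exercise thread-safe initialization.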
+
+#include "src/core/lib/config/core_configuration.h"
+
+#include <chrono>
+#include <functional>
+#include <thread>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace grpc_core {
+
+// Allow substitution of the config builder - in real code this would iterate
+// through all plugins.
+namespace testing {
+using ConfigBuilderFunction = std::function<void(CoreConfiguration::Builder*)>;
+static ConfigBuilderFunction g_mock_builder;
+}  // namespace testing
+
+void BuildCoreConfiguration(CoreConfiguration::Builder* builder) {
+  ::grpc_core::testing::g_mock_builder(builder);
+}
+
+namespace testing {
+// Helper for testing: clear out any existing state, then rebuild the
+// configuration with fn as the initializer.
+void InitConfigWithBuilder(ConfigBuilderFunction fn) {
+  CoreConfiguration::Reset();
+  g_mock_builder = fn;
+  CoreConfiguration::Get();
+  g_mock_builder = nullptr;
+}
+
+TEST(ConfigTest, NoopConfig) {
+  InitConfigWithBuilder([](CoreConfiguration::Builder*) {});
+  CoreConfiguration::Get();
+}
+
+TEST(ConfigTest, ThreadedInit) {
+  CoreConfiguration::Reset();
+  g_mock_builder = [](CoreConfiguration::Builder*) {
+    std::this_thread::sleep_for(std::chrono::seconds(1));
+  };
+  std::vector<std::thread> threads;
+  threads.reserve(64);
+  for (int i = 0; i < 64; i++) {
+    threads.push_back(std::thread([]() { CoreConfiguration::Get(); }));
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+  g_mock_builder = nullptr;
+  CoreConfiguration::Get();
+}
+}  // namespace testing
+
+}  // namespace grpc_core
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/test/core/debug/stats_test.cc b/test/core/debug/stats_test.cc
new file mode 100644
index 00000000..fe8fe495
--- /dev/null
+++ b/test/core/debug/stats_test.cc
@@ -0,0 +1,163 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include "src/core/lib/debug/stats.h" + +#include +#include + +#include + +#include +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { + +class Snapshot { + public: + Snapshot() { grpc_stats_collect(&begin_); } + + grpc_stats_data delta() { + grpc_stats_data now; + grpc_stats_collect(&now); + grpc_stats_data delta; + grpc_stats_diff(&now, &begin_, &delta); + return delta; + } + + private: + grpc_stats_data begin_; +}; + +TEST(StatsTest, IncCounters) { + for (int i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + std::unique_ptr snapshot(new Snapshot); + + grpc_core::ExecCtx exec_ctx; + GRPC_STATS_INC_COUNTER((grpc_stats_counters)i); + + EXPECT_EQ(snapshot->delta().counters[i], 1); + } +} + +TEST(StatsTest, IncSpecificCounter) { + std::unique_ptr snapshot(new Snapshot); + + grpc_core::ExecCtx exec_ctx; + GRPC_STATS_INC_SYSCALL_POLL(); + + EXPECT_EQ(snapshot->delta().counters[GRPC_STATS_COUNTER_SYSCALL_POLL], 1); +} + +static int FindExpectedBucket(int i, int j) { + if (j < 0) { + return 0; + } + if (j >= grpc_stats_histo_bucket_boundaries[i][grpc_stats_histo_buckets[i]]) { + return grpc_stats_histo_buckets[i] - 1; + } + return std::upper_bound(grpc_stats_histo_bucket_boundaries[i], + grpc_stats_histo_bucket_boundaries[i] + + grpc_stats_histo_buckets[i], + j) - + grpc_stats_histo_bucket_boundaries[i] - 1; +} + +class HistogramTest : public ::testing::TestWithParam {}; + +TEST_P(HistogramTest, IncHistogram) { + const int kHistogram = GetParam(); + std::vector threads; + int cur_bucket = 0; + auto run = [kHistogram](const std::vector& test_values, + int expected_bucket) { + gpr_log(GPR_DEBUG, "expected_bucket:%d nvalues=%" PRIdPTR, expected_bucket, + test_values.size()); + for (auto j : test_values) { + std::unique_ptr snapshot(new Snapshot); + + grpc_core::ExecCtx exec_ctx; + grpc_stats_inc_histogram[kHistogram](j); + + auto delta = snapshot->delta(); + + EXPECT_EQ( + delta + .histograms[grpc_stats_histo_start[kHistogram] + expected_bucket], + 1) + << "\nhistogram:" << kHistogram + << "\nexpected_bucket:" << expected_bucket << "\nj:" << j; + } + }; + std::vector test_values; + // largest bucket boundary for current histogram type. + int max_bucket_boundary = + grpc_stats_histo_bucket_boundaries[kHistogram] + [grpc_stats_histo_buckets[kHistogram] - + 1]; + for (int j = -1000; j < max_bucket_boundary + 1000;) { + int expected_bucket = FindExpectedBucket(kHistogram, j); + if (cur_bucket != expected_bucket) { + threads.emplace_back( + [test_values, run, cur_bucket]() { run(test_values, cur_bucket); }); + cur_bucket = expected_bucket; + test_values.clear(); + } + test_values.push_back(j); + if (j < max_bucket_boundary && + FindExpectedBucket(kHistogram, j + 1000) == expected_bucket && + FindExpectedBucket(kHistogram, j - 1000) == expected_bucket) { + // if we are far from bucket boundary, skip values to speed-up the tests + j += 500; + } else { + j++; + } + } + run(test_values, cur_bucket); + for (auto& t : threads) { + t.join(); + } +} + +INSTANTIATE_TEST_SUITE_P(HistogramTestCases, HistogramTest, + ::testing::Range(0, GRPC_STATS_HISTOGRAM_COUNT)); + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { +/* Only run this test if GRPC_COLLECT_STATS is defined or if it is a debug + * build. 
+ */ +#if defined(GRPC_COLLECT_STATS) || !defined(NDEBUG) + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +#else + // Avoid unused parameter warning for conditional parameters. + (void)argc; + (void)argv; +#endif +} diff --git a/test/core/end2end/bad_server_response_test.cc b/test/core/end2end/bad_server_response_test.cc new file mode 100644 index 00000000..c906ad5a --- /dev/null +++ b/test/core/end2end/bad_server_response_test.cc @@ -0,0 +1,389 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/core/util/test_tcp_server.h" + +#define HTTP1_RESP_400 \ + "HTTP/1.0 400 Bad Request\n" \ + "Content-Type: text/html; charset=UTF-8\n" \ + "Content-Length: 0\n" \ + "Date: Tue, 07 Jun 2016 17:43:20 GMT\n\n" + +#define HTTP2_SETTINGS_FRAME "\x00\x00\x00\x04\x00\x00\x00\x00\x00" + +#define HTTP2_RESP(STATUS_CODE) \ + "\x00\x00>\x01\x04\x00\x00\x00\x01" \ + "\x10\x0e" \ + "content-length\x01" \ + "0" \ + "\x10\x0c" \ + "content-type\x10" \ + "application/grpc" \ + "\x10\x07:status\x03" #STATUS_CODE + +#define UNPARSEABLE_RESP "Bad Request\n" + +#define HTTP2_DETAIL_MSG(STATUS_CODE) \ + "Received http2 header with status: " #STATUS_CODE + +/* TODO(zyc) Check the content of incoming data instead of using this length */ +/* The 'bad' server will start sending responses after reading this amount of + * data from the client. 
*/ +#define SERVER_INCOMING_DATA_LENGTH_LOWER_THRESHOLD (size_t)200 + +struct rpc_state { + std::string target; + grpc_completion_queue* cq; + grpc_channel* channel; + grpc_call* call; + size_t incoming_data_length; + grpc_slice_buffer temp_incoming_buffer; + grpc_slice_buffer outgoing_buffer; + grpc_endpoint* tcp; + gpr_atm done_atm; + bool http2_response; + bool send_settings; + const char* response_payload; + size_t response_payload_length; + bool connection_attempt_made; +}; + +static int server_port; +static struct rpc_state state; +static grpc_closure on_read; +static grpc_closure on_writing_settings_frame; +static grpc_closure on_write; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void done_write(void* /*arg*/, grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + gpr_atm_rel_store(&state.done_atm, 1); +} + +static void done_writing_settings_frame(void* /* arg */, + grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_endpoint_read(state.tcp, &state.temp_incoming_buffer, &on_read, + /*urgent=*/false); +} + +static void handle_write() { + grpc_slice slice = grpc_slice_from_copied_buffer( + state.response_payload, state.response_payload_length); + + grpc_slice_buffer_reset_and_unref(&state.outgoing_buffer); + grpc_slice_buffer_add(&state.outgoing_buffer, slice); + grpc_endpoint_write(state.tcp, &state.outgoing_buffer, &on_write, nullptr); +} + +static void handle_read(void* /*arg*/, grpc_error_handle error) { + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "handle_read error: %s", + grpc_error_std_string(error).c_str()); + return; + } + state.incoming_data_length += state.temp_incoming_buffer.length; + + size_t i; + for (i = 0; i < state.temp_incoming_buffer.count; i++) { + char* dump = grpc_dump_slice(state.temp_incoming_buffer.slices[i], + GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "Server received: %s", dump); + gpr_free(dump); + } + + gpr_log(GPR_DEBUG, + "got %" PRIuPTR " bytes, expected %" PRIuPTR + " bytes or a non-HTTP2 response to be sent", + state.incoming_data_length, + SERVER_INCOMING_DATA_LENGTH_LOWER_THRESHOLD); + if (state.incoming_data_length >= + SERVER_INCOMING_DATA_LENGTH_LOWER_THRESHOLD || + !state.http2_response) { + handle_write(); + } else { + grpc_endpoint_read(state.tcp, &state.temp_incoming_buffer, &on_read, + /*urgent=*/false); + } +} + +static void on_connect(void* arg, grpc_endpoint* tcp, + grpc_pollset* /*accepting_pollset*/, + grpc_tcp_server_acceptor* acceptor) { + gpr_free(acceptor); + test_tcp_server* server = static_cast(arg); + GRPC_CLOSURE_INIT(&on_read, handle_read, nullptr, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_writing_settings_frame, done_writing_settings_frame, + nullptr, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&on_write, done_write, nullptr, grpc_schedule_on_exec_ctx); + grpc_slice_buffer_init(&state.temp_incoming_buffer); + grpc_slice_buffer_init(&state.outgoing_buffer); + state.connection_attempt_made = true; + state.tcp = tcp; + state.incoming_data_length = 0; + grpc_endpoint_add_to_pollset(tcp, server->pollset[0]); + if (state.send_settings) { + // Send settings frame from server + grpc_slice slice = grpc_slice_from_static_buffer( + HTTP2_SETTINGS_FRAME, sizeof(HTTP2_SETTINGS_FRAME) - 1); + grpc_slice_buffer_add(&state.outgoing_buffer, slice); + grpc_endpoint_write(state.tcp, &state.outgoing_buffer, + &on_writing_settings_frame, nullptr); + } else { + grpc_endpoint_read(state.tcp, &state.temp_incoming_buffer, &on_read, + 
/*urgent=*/false); + } +} + +static gpr_timespec n_sec_deadline(int seconds) { + return gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(seconds, GPR_TIMESPAN)); +} + +static void start_rpc(int target_port, grpc_status_code expected_status, + const char* expected_detail) { + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + grpc_call_error error; + cq_verifier* cqv; + grpc_slice details; + + state.cq = grpc_completion_queue_create_for_next(nullptr); + cqv = cq_verifier_create(state.cq); + state.target = grpc_core::JoinHostPort("127.0.0.1", target_port); + + state.channel = + grpc_insecure_channel_create(state.target.c_str(), nullptr, nullptr); + grpc_slice host = grpc_slice_from_static_string("localhost"); + // The default connect deadline is 20 seconds, so reduce the RPC deadline to 1 + // second. This helps us verify - a) If the server responded with a non-HTTP2 + // response, the connect fails immediately resulting in + // GRPC_STATUS_UNAVAILABLE instead of GRPC_STATUS_DEADLINE_EXCEEDED. b) If the + // server does not send a HTTP2 SETTINGs frame, the RPC fails with a + // DEADLINE_EXCEEDED. + state.call = grpc_channel_create_call( + state.channel, nullptr, GRPC_PROPAGATE_DEFAULTS, state.cq, + grpc_slice_from_static_string("/Service/Method"), &host, + n_sec_deadline(5), nullptr); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(state.call, ops, static_cast(op - ops), + tag(1), nullptr); + + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == expected_status); + if (expected_detail != nullptr) { + GPR_ASSERT(-1 != grpc_slice_slice(details, grpc_slice_from_static_string( + expected_detail))); + } + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_slice_unref(details); + cq_verifier_destroy(cqv); +} + +static void cleanup_rpc() { + grpc_event ev; + grpc_slice_buffer_destroy_internal(&state.temp_incoming_buffer); + grpc_slice_buffer_destroy_internal(&state.outgoing_buffer); + grpc_call_unref(state.call); + grpc_completion_queue_shutdown(state.cq); + do { + ev = grpc_completion_queue_next(state.cq, n_sec_deadline(1), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(state.cq); + grpc_channel_destroy(state.channel); + state.target.clear(); +} + +typedef struct { + test_tcp_server* server; + gpr_event* signal_when_done; +} poll_args; + +static void actually_poll_server(void* arg) { + poll_args* pa = static_cast(arg); + gpr_timespec deadline = n_sec_deadline(5); + while (true) { + bool done = 
gpr_atm_acq_load(&state.done_atm) != 0; + gpr_timespec time_left = + gpr_time_sub(deadline, gpr_now(GPR_CLOCK_REALTIME)); + gpr_log(GPR_DEBUG, "done=%d, time_left=%" PRId64 ".%09d", done, + time_left.tv_sec, time_left.tv_nsec); + if (done || gpr_time_cmp(time_left, gpr_time_0(GPR_TIMESPAN)) < 0) { + break; + } + test_tcp_server_poll(pa->server, 1000); + } + gpr_event_set(pa->signal_when_done, reinterpret_cast(1)); + gpr_free(pa); +} + +static grpc_core::Thread* poll_server_until_read_done( + test_tcp_server* server, gpr_event* signal_when_done) { + gpr_atm_rel_store(&state.done_atm, 0); + state.connection_attempt_made = false; + poll_args* pa = static_cast(gpr_malloc(sizeof(*pa))); + pa->server = server; + pa->signal_when_done = signal_when_done; + auto* th = + new grpc_core::Thread("grpc_poll_server", actually_poll_server, pa); + th->Start(); + return th; +} + +static void run_test(bool http2_response, bool send_settings, + const char* response_payload, + size_t response_payload_length, + grpc_status_code expected_status, + const char* expected_detail) { + test_tcp_server test_server; + grpc_core::ExecCtx exec_ctx; + gpr_event ev; + + grpc_init(); + gpr_event_init(&ev); + server_port = grpc_pick_unused_port_or_die(); + test_tcp_server_init(&test_server, on_connect, &test_server); + test_tcp_server_start(&test_server, server_port); + state.http2_response = http2_response; + state.send_settings = send_settings; + state.response_payload = response_payload; + state.response_payload_length = response_payload_length; + + /* poll server until sending out the response */ + std::unique_ptr thdptr( + poll_server_until_read_done(&test_server, &ev)); + start_rpc(server_port, expected_status, expected_detail); + gpr_event_wait(&ev, gpr_inf_future(GPR_CLOCK_REALTIME)); + thdptr->Join(); + /* Proof that the server accepted the TCP connection. 
*/ + GPR_ASSERT(state.connection_attempt_made == true); + /* clean up */ + grpc_endpoint_shutdown(state.tcp, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test Shutdown")); + grpc_endpoint_destroy(state.tcp); + cleanup_rpc(); + grpc_core::ExecCtx::Get()->Flush(); + test_tcp_server_destroy(&test_server); + + grpc_shutdown(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* status defined in hpack static table */ + run_test(true, true, HTTP2_RESP(204), sizeof(HTTP2_RESP(204)) - 1, + GRPC_STATUS_UNKNOWN, HTTP2_DETAIL_MSG(204)); + run_test(true, true, HTTP2_RESP(206), sizeof(HTTP2_RESP(206)) - 1, + GRPC_STATUS_UNKNOWN, HTTP2_DETAIL_MSG(206)); + run_test(true, true, HTTP2_RESP(304), sizeof(HTTP2_RESP(304)) - 1, + GRPC_STATUS_UNKNOWN, HTTP2_DETAIL_MSG(304)); + run_test(true, true, HTTP2_RESP(400), sizeof(HTTP2_RESP(400)) - 1, + GRPC_STATUS_INTERNAL, HTTP2_DETAIL_MSG(400)); + run_test(true, true, HTTP2_RESP(404), sizeof(HTTP2_RESP(404)) - 1, + GRPC_STATUS_UNIMPLEMENTED, HTTP2_DETAIL_MSG(404)); + run_test(true, true, HTTP2_RESP(500), sizeof(HTTP2_RESP(500)) - 1, + GRPC_STATUS_UNKNOWN, HTTP2_DETAIL_MSG(500)); + + /* status not defined in hpack static table */ + run_test(true, true, HTTP2_RESP(401), sizeof(HTTP2_RESP(401)) - 1, + GRPC_STATUS_UNAUTHENTICATED, HTTP2_DETAIL_MSG(401)); + run_test(true, true, HTTP2_RESP(403), sizeof(HTTP2_RESP(403)) - 1, + GRPC_STATUS_PERMISSION_DENIED, HTTP2_DETAIL_MSG(403)); + run_test(true, true, HTTP2_RESP(429), sizeof(HTTP2_RESP(429)) - 1, + GRPC_STATUS_UNAVAILABLE, HTTP2_DETAIL_MSG(429)); + run_test(true, true, HTTP2_RESP(499), sizeof(HTTP2_RESP(499)) - 1, + GRPC_STATUS_UNKNOWN, HTTP2_DETAIL_MSG(499)); + run_test(true, true, HTTP2_RESP(502), sizeof(HTTP2_RESP(502)) - 1, + GRPC_STATUS_UNAVAILABLE, HTTP2_DETAIL_MSG(502)); + run_test(true, true, HTTP2_RESP(503), sizeof(HTTP2_RESP(503)) - 1, + GRPC_STATUS_UNAVAILABLE, HTTP2_DETAIL_MSG(503)); + run_test(true, true, HTTP2_RESP(504), sizeof(HTTP2_RESP(504)) - 1, + GRPC_STATUS_UNAVAILABLE, HTTP2_DETAIL_MSG(504)); + /* unparseable response. RPC should fail immediately due to a connect failure. + */ + run_test(false, false, UNPARSEABLE_RESP, sizeof(UNPARSEABLE_RESP) - 1, + GRPC_STATUS_UNAVAILABLE, nullptr); + + /* http1 response. RPC should fail immediately due to a connect failure. */ + run_test(false, false, HTTP1_RESP_400, sizeof(HTTP1_RESP_400) - 1, + GRPC_STATUS_UNAVAILABLE, nullptr); + + /* http2 response without sending a SETTINGs frame. RPC should fail with + * DEADLINE_EXCEEDED since the RPC deadline is lower than the connection + * attempt deadline. */ + run_test(true, false, HTTP2_RESP(404), sizeof(HTTP2_RESP(404)) - 1, + GRPC_STATUS_DEADLINE_EXCEEDED, nullptr); + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/connection_refused_test.cc b/test/core/end2end/connection_refused_test.cc new file mode 100644 index 00000000..6fd40fd5 --- /dev/null +++ b/test/core/end2end/connection_refused_test.cc @@ -0,0 +1,145 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void run_test(bool wait_for_ready, bool use_service_config) { + grpc_channel* chan; + grpc_call* call; + grpc_completion_queue* cq; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + grpc_slice details; + + gpr_log(GPR_INFO, "TEST: wait_for_ready=%d use_service_config=%d", + wait_for_ready, use_service_config); + + grpc_init(); + + grpc_metadata_array_init(&trailing_metadata_recv); + + cq = grpc_completion_queue_create_for_next(nullptr); + cqv = cq_verifier_create(cq); + + /* if using service config, create channel args */ + grpc_channel_args* args = nullptr; + if (use_service_config) { + GPR_ASSERT(wait_for_ready); + grpc_arg arg; + arg.type = GRPC_ARG_STRING; + arg.key = const_cast(GRPC_ARG_SERVICE_CONFIG); + arg.value.string = const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"waitForReady\": true\n" + " } ]\n" + "}"); + args = grpc_channel_args_copy_and_add(args, &arg, 1); + } + + /* create a call, channel to a port which will refuse connection */ + int port = grpc_pick_unused_port_or_die(); + std::string addr = grpc_core::JoinHostPort("127.0.0.1", port); + gpr_log(GPR_INFO, "server: %s", addr.c_str()); + chan = grpc_insecure_channel_create(addr.c_str(), args, nullptr); + grpc_slice host = grpc_slice_from_static_string("nonexistant"); + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(2); + call = + grpc_channel_create_call(chan, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/service/method"), + &host, deadline, nullptr); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = (wait_for_ready && !use_service_config) + ? 
GRPC_INITIAL_METADATA_WAIT_FOR_READY + : 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call, ops, + (size_t)(op - ops), tag(1), + nullptr)); + /* verify that all tags get completed */ + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + if (wait_for_ready) { + GPR_ASSERT(status == GRPC_STATUS_DEADLINE_EXCEEDED); + } else { + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + } + + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); + grpc_call_unref(call); + grpc_channel_destroy(chan); + cq_verifier_destroy(cqv); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&trailing_metadata_recv); + + { + grpc_core::ExecCtx exec_ctx; + if (args != nullptr) grpc_channel_args_destroy(args); + } + + grpc_shutdown(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + run_test(false /* wait_for_ready */, false /* use_service_config */); + run_test(true /* wait_for_ready */, false /* use_service_config */); + run_test(true /* wait_for_ready */, true /* use_service_config */); + return 0; +} diff --git a/test/core/end2end/cq_verifier.cc b/test/core/end2end/cq_verifier.cc new file mode 100644 index 00000000..7d847278 --- /dev/null +++ b/test/core/end2end/cq_verifier.cc @@ -0,0 +1,325 @@ +// +// +// Copyright 2015 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include "test/core/end2end/cq_verifier.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/compression/compression_internal.h" +#include "src/core/lib/compression/message_compress.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/surface/event_string.h" + +#define ROOT_EXPECTATION 1000 + +// a set of metadata we expect to find on an event +typedef struct metadata { + size_t count; + size_t cap; + char** keys; + char** values; +} metadata; + +// details what we expect to find on a single event +struct Expectation { + Expectation(const char* f, int l, grpc_completion_type t, void* tag_arg, + bool check_success_arg, int success_arg, bool* seen_arg) + : file(f), + line(l), + type(t), + tag(tag_arg), + check_success(check_success_arg), + success(success_arg), + seen(seen_arg) {} + const char* file; + int line; + grpc_completion_type type; + void* tag; + bool check_success; + int success; + bool* seen; +}; + +// the verifier itself +struct cq_verifier { + // bound completion queue + grpc_completion_queue* cq; + // expectation list + std::list expectations; + // maybe expectation list + std::list maybe_expectations; +}; + +// TODO(yashykt): Convert this to constructor/destructor pair +cq_verifier* cq_verifier_create(grpc_completion_queue* cq) { + cq_verifier* v = new cq_verifier; + v->cq = cq; + return v; +} + +void cq_verifier_destroy(cq_verifier* v) { + cq_verify(v); + delete v; +} + +static int has_metadata(const grpc_metadata* md, size_t count, const char* key, + const char* value) { + size_t i; + for (i = 0; i < count; i++) { + if (0 == grpc_slice_str_cmp(md[i].key, key) && + 0 == grpc_slice_str_cmp(md[i].value, value)) { + return 1; + } + } + return 0; +} + +int contains_metadata(grpc_metadata_array* array, const char* key, + const char* value) { + return has_metadata(array->metadata, array->count, key, value); +} + +static int has_metadata_slices(const grpc_metadata* md, size_t count, + grpc_slice key, grpc_slice value) { + size_t i; + for (i = 0; i < count; i++) { + if (grpc_slice_eq(md[i].key, key) && grpc_slice_eq(md[i].value, value)) { + return 1; + } + } + return 0; +} + +int contains_metadata_slices(grpc_metadata_array* array, grpc_slice key, + grpc_slice value) { + return has_metadata_slices(array->metadata, array->count, key, value); +} + +static grpc_slice merge_slices(grpc_slice* slices, size_t nslices) { + size_t i; + size_t len = 0; + uint8_t* cursor; + grpc_slice out; + + for (i = 0; i < nslices; i++) { + len += GRPC_SLICE_LENGTH(slices[i]); + } + + out = grpc_slice_malloc(len); + cursor = GRPC_SLICE_START_PTR(out); + + for (i = 0; i < nslices; i++) { + memcpy(cursor, GRPC_SLICE_START_PTR(slices[i]), + GRPC_SLICE_LENGTH(slices[i])); + cursor += GRPC_SLICE_LENGTH(slices[i]); + } + + return out; +} + +int raw_byte_buffer_eq_slice(grpc_byte_buffer* rbb, grpc_slice b) { + grpc_slice a; + int ok; + + if (!rbb) return 0; + + a = merge_slices(rbb->data.raw.slice_buffer.slices, + rbb->data.raw.slice_buffer.count); + ok = GRPC_SLICE_LENGTH(a) == GRPC_SLICE_LENGTH(b) && + 0 == memcmp(GRPC_SLICE_START_PTR(a), GRPC_SLICE_START_PTR(b), + GRPC_SLICE_LENGTH(a)); + grpc_slice_unref(a); + grpc_slice_unref(b); + return ok; +} + +int byte_buffer_eq_slice(grpc_byte_buffer* bb, grpc_slice b) { + if (bb->data.raw.compression > GRPC_COMPRESS_NONE) { + grpc_slice_buffer 
decompressed_buffer; + grpc_slice_buffer_init(&decompressed_buffer); + GPR_ASSERT(grpc_msg_decompress( + grpc_compression_algorithm_to_message_compression_algorithm( + bb->data.raw.compression), + &bb->data.raw.slice_buffer, &decompressed_buffer)); + grpc_byte_buffer* rbb = grpc_raw_byte_buffer_create( + decompressed_buffer.slices, decompressed_buffer.count); + int ret_val = raw_byte_buffer_eq_slice(rbb, b); + grpc_byte_buffer_destroy(rbb); + grpc_slice_buffer_destroy(&decompressed_buffer); + return ret_val; + } + return raw_byte_buffer_eq_slice(bb, b); +} + +int byte_buffer_eq_string(grpc_byte_buffer* bb, const char* str) { + return byte_buffer_eq_slice(bb, grpc_slice_from_copied_string(str)); +} + +static bool is_probably_integer(void* p) { + return reinterpret_cast(p) < 1000000; +} + +namespace { + +std::string ExpectationString(const Expectation& e) { + std::string out; + if (is_probably_integer(e.tag)) { + out = absl::StrFormat("tag(%" PRIdPTR ") ", + reinterpret_cast(e.tag)); + } else { + out = absl::StrFormat("%p ", e.tag); + } + switch (e.type) { + case GRPC_OP_COMPLETE: + absl::StrAppendFormat(&out, "GRPC_OP_COMPLETE success=%d %s:%d", + e.success, e.file, e.line); + break; + case GRPC_QUEUE_TIMEOUT: + case GRPC_QUEUE_SHUTDOWN: + gpr_log(GPR_ERROR, "not implemented"); + abort(); + } + return out; +} + +std::string ExpectationsString(const cq_verifier& v) { + std::vector expectations; + for (const auto& e : v.expectations) { + expectations.push_back(ExpectationString(e)); + } + return absl::StrJoin(expectations, "\n"); +} + +} // namespace + +static void fail_no_event_received(cq_verifier* v) { + gpr_log(GPR_ERROR, "no event received, but expected:%s", + ExpectationsString(*v).c_str()); + abort(); +} + +static void verify_matches(const Expectation& e, const grpc_event& ev) { + GPR_ASSERT(e.type == ev.type); + switch (e.type) { + case GRPC_OP_COMPLETE: + if (e.check_success && e.success != ev.success) { + gpr_log(GPR_ERROR, "actual success does not match expected: %s", + ExpectationString(e).c_str()); + abort(); + } + break; + case GRPC_QUEUE_SHUTDOWN: + gpr_log(GPR_ERROR, "premature queue shutdown"); + abort(); + case GRPC_QUEUE_TIMEOUT: + gpr_log(GPR_ERROR, "not implemented"); + abort(); + } +} + +// Try to find the event in the expectations list +bool FindExpectations(std::list* expectations, + const grpc_event& ev) { + for (auto e = expectations->begin(); e != expectations->end(); ++e) { + if (e->tag == ev.tag) { + verify_matches(*e, ev); + if (e->seen != nullptr) { + *(e->seen) = true; + } + expectations->erase(e); + return true; + } + } + return false; +} + +void cq_verify(cq_verifier* v, int timeout_sec) { + const gpr_timespec deadline = grpc_timeout_seconds_to_deadline(timeout_sec); + while (!v->expectations.empty()) { + grpc_event ev = grpc_completion_queue_next(v->cq, deadline, nullptr); + if (ev.type == GRPC_QUEUE_TIMEOUT) { + fail_no_event_received(v); + break; + } + if (FindExpectations(&v->expectations, ev)) continue; + if (FindExpectations(&v->maybe_expectations, ev)) continue; + gpr_log(GPR_ERROR, "cq returned unexpected event: %s", + grpc_event_string(&ev).c_str()); + gpr_log(GPR_ERROR, "expected tags:\n%s", ExpectationsString(*v).c_str()); + abort(); + } + v->maybe_expectations.clear(); +} + +void cq_verify_empty_timeout(cq_verifier* v, int timeout_sec) { + gpr_timespec deadline = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(timeout_sec, GPR_TIMESPAN)); + grpc_event ev; + + GPR_ASSERT(v->expectations.empty() && "expectation queue must be 
empty"); + + ev = grpc_completion_queue_next(v->cq, deadline, nullptr); + if (ev.type != GRPC_QUEUE_TIMEOUT) { + gpr_log(GPR_ERROR, "unexpected event (expected nothing): %s", + grpc_event_string(&ev).c_str()); + abort(); + } +} + +void cq_verify_empty(cq_verifier* v) { cq_verify_empty_timeout(v, 1); } + +void cq_maybe_expect_completion(cq_verifier* v, const char* file, int line, + void* tag, bool success, bool* seen) { + v->maybe_expectations.emplace_back(file, line, GRPC_OP_COMPLETE, tag, + true /* check_success */, success, seen); +} + +static void add(cq_verifier* v, const char* file, int line, + grpc_completion_type type, void* tag, bool check_success, + bool success) { + v->expectations.emplace_back(file, line, type, tag, check_success, success, + nullptr); +} + +void cq_expect_completion(cq_verifier* v, const char* file, int line, void* tag, + bool success) { + add(v, file, line, GRPC_OP_COMPLETE, tag, true, success); +} + +void cq_expect_completion_any_status(cq_verifier* v, const char* file, int line, + void* tag) { + add(v, file, line, GRPC_OP_COMPLETE, tag, false, false); +} diff --git a/test/core/end2end/cq_verifier_native.cc b/test/core/end2end/cq_verifier_native.cc new file mode 100644 index 00000000..d65dbcaa --- /dev/null +++ b/test/core/end2end/cq_verifier_native.cc @@ -0,0 +1,55 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "test/core/end2end/cq_verifier_internal.h" + +/* the verifier itself */ +struct cq_verifier { + /* bound completion queue */ + grpc_completion_queue* cq; + /* start of expectation list */ + expectation* first_expectation; +}; + +cq_verifier* cq_verifier_create(grpc_completion_queue* cq) { + cq_verifier* v = static_cast(gpr_malloc(sizeof(cq_verifier))); + v->cq = cq; + cq_verifier_set_first_expectation(v, nullptr); + return v; +} + +void cq_verifier_destroy(cq_verifier* v) { + cq_verify(v); + gpr_free(v); +} + +expectation* cq_verifier_get_first_expectation(cq_verifier* v) { + return v->first_expectation; +} + +void cq_verifier_set_first_expectation(cq_verifier* v, expectation* e) { + v->first_expectation = e; +} + +grpc_event cq_verifier_next_event(cq_verifier* v, int timeout_seconds) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + return grpc_completion_queue_next(v->cq, deadline, nullptr); +} diff --git a/test/core/end2end/data/client_certs.cc b/test/core/end2end/data/client_certs.cc new file mode 100644 index 00000000..8d9cb912 --- /dev/null +++ b/test/core/end2end/data/client_certs.cc @@ -0,0 +1,522 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +extern const char test_self_signed_client_cert[] = { + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x43, + 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, 0x49, 0x49, 0x44, 0x73, 0x7a, 0x43, 0x43, + 0x41, 0x70, 0x75, 0x67, 0x41, 0x77, 0x49, 0x42, 0x41, 0x67, 0x49, 0x55, + 0x4f, 0x4e, 0x57, 0x62, 0x6b, 0x55, 0x6e, 0x31, 0x6f, 0x62, 0x48, 0x43, + 0x77, 0x39, 0x4c, 0x37, 0x6c, 0x4d, 0x4e, 0x45, 0x45, 0x35, 0x52, 0x45, + 0x76, 0x62, 0x38, 0x77, 0x44, 0x51, 0x59, 0x4a, 0x4b, 0x6f, 0x5a, 0x49, + 0x68, 0x76, 0x63, 0x4e, 0x41, 0x51, 0x45, 0x4c, 0x0a, 0x42, 0x51, 0x41, + 0x77, 0x61, 0x54, 0x45, 0x4c, 0x4d, 0x41, 0x6b, 0x47, 0x41, 0x31, 0x55, + 0x45, 0x42, 0x68, 0x4d, 0x43, 0x51, 0x56, 0x55, 0x78, 0x45, 0x7a, 0x41, + 0x52, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x67, 0x4d, 0x43, 0x6c, 0x4e, + 0x76, 0x62, 0x57, 0x55, 0x74, 0x55, 0x33, 0x52, 0x68, 0x64, 0x47, 0x55, + 0x78, 0x49, 0x54, 0x41, 0x66, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x6f, + 0x4d, 0x0a, 0x47, 0x45, 0x6c, 0x75, 0x64, 0x47, 0x56, 0x79, 0x62, 0x6d, + 0x56, 0x30, 0x49, 0x46, 0x64, 0x70, 0x5a, 0x47, 0x64, 0x70, 0x64, 0x48, + 0x4d, 0x67, 0x55, 0x48, 0x52, 0x35, 0x49, 0x45, 0x78, 0x30, 0x5a, 0x44, + 0x45, 0x69, 0x4d, 0x43, 0x41, 0x47, 0x41, 0x31, 0x55, 0x45, 0x41, 0x77, + 0x77, 0x5a, 0x59, 0x6d, 0x46, 0x6b, 0x59, 0x32, 0x78, 0x70, 0x5a, 0x57, + 0x35, 0x30, 0x4c, 0x6e, 0x52, 0x6c, 0x0a, 0x63, 0x33, 0x51, 0x75, 0x5a, + 0x32, 0x39, 0x76, 0x5a, 0x32, 0x78, 0x6c, 0x4c, 0x6d, 0x4e, 0x76, 0x62, + 0x54, 0x41, 0x65, 0x46, 0x77, 0x30, 0x79, 0x4d, 0x44, 0x41, 0x7a, 0x4d, + 0x54, 0x63, 0x78, 0x4e, 0x7a, 0x51, 0x7a, 0x4d, 0x6a, 0x4e, 0x61, 0x46, + 0x77, 0x30, 0x7a, 0x4d, 0x44, 0x41, 0x7a, 0x4d, 0x54, 0x55, 0x78, 0x4e, + 0x7a, 0x51, 0x7a, 0x4d, 0x6a, 0x4e, 0x61, 0x4d, 0x47, 0x6b, 0x78, 0x0a, + 0x43, 0x7a, 0x41, 0x4a, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x59, 0x54, + 0x41, 0x6b, 0x46, 0x56, 0x4d, 0x52, 0x4d, 0x77, 0x45, 0x51, 0x59, 0x44, + 0x56, 0x51, 0x51, 0x49, 0x44, 0x41, 0x70, 0x54, 0x62, 0x32, 0x31, 0x6c, + 0x4c, 0x56, 0x4e, 0x30, 0x59, 0x58, 0x52, 0x6c, 0x4d, 0x53, 0x45, 0x77, + 0x48, 0x77, 0x59, 0x44, 0x56, 0x51, 0x51, 0x4b, 0x44, 0x42, 0x68, 0x4a, + 0x62, 0x6e, 0x52, 0x6c, 0x0a, 0x63, 0x6d, 0x35, 0x6c, 0x64, 0x43, 0x42, + 0x58, 0x61, 0x57, 0x52, 0x6e, 0x61, 0x58, 0x52, 0x7a, 0x49, 0x46, 0x42, + 0x30, 0x65, 0x53, 0x42, 0x4d, 0x64, 0x47, 0x51, 0x78, 0x49, 0x6a, 0x41, + 0x67, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x4d, 0x4d, 0x47, 0x57, 0x4a, + 0x68, 0x5a, 0x47, 0x4e, 0x73, 0x61, 0x57, 0x56, 0x75, 0x64, 0x43, 0x35, + 0x30, 0x5a, 0x58, 0x4e, 0x30, 0x4c, 0x6d, 0x64, 0x76, 0x0a, 0x62, 0x32, + 0x64, 0x73, 0x5a, 0x53, 0x35, 0x6a, 0x62, 0x32, 0x30, 0x77, 0x67, 0x67, + 0x45, 0x69, 0x4d, 0x41, 0x30, 0x47, 0x43, 0x53, 0x71, 0x47, 0x53, 0x49, + 0x62, 0x33, 0x44, 0x51, 0x45, 0x42, 0x41, 0x51, 0x55, 0x41, 0x41, 0x34, + 0x49, 0x42, 0x44, 0x77, 0x41, 0x77, 0x67, 0x67, 0x45, 0x4b, 0x41, 0x6f, + 0x49, 0x42, 0x41, 0x51, 0x44, 0x76, 0x64, 0x7a, 0x4b, 0x44, 0x54, 0x59, + 0x76, 0x52, 0x0a, 0x67, 0x6a, 0x42, 0x4f, 0x55, 0x4f, 0x72, 0x7a, 0x44, + 0x77, 0x6b, 0x41, 0x5a, 
0x47, 0x77, 0x4e, 0x46, 0x48, 0x48, 0x6c, 0x4d, + 0x59, 0x79, 0x4d, 0x47, 0x49, 0x35, 0x74, 0x49, 0x74, 0x6a, 0x33, 0x74, + 0x43, 0x7a, 0x58, 0x6b, 0x62, 0x70, 0x4d, 0x30, 0x75, 0x7a, 0x33, 0x5a, + 0x6a, 0x48, 0x56, 0x61, 0x68, 0x75, 0x2b, 0x65, 0x59, 0x63, 0x2b, 0x4b, + 0x76, 0x59, 0x41, 0x70, 0x4d, 0x36, 0x34, 0x0a, 0x46, 0x32, 0x64, 0x42, + 0x62, 0x31, 0x36, 0x68, 0x73, 0x37, 0x31, 0x33, 0x46, 0x43, 0x6b, 0x38, + 0x6d, 0x69, 0x68, 0x59, 0x41, 0x42, 0x6a, 0x6e, 0x53, 0x6e, 0x64, 0x72, + 0x51, 0x73, 0x6c, 0x2f, 0x55, 0x32, 0x76, 0x38, 0x59, 0x46, 0x54, 0x37, + 0x44, 0x69, 0x70, 0x66, 0x4c, 0x52, 0x65, 0x71, 0x71, 0x61, 0x4f, 0x47, + 0x75, 0x32, 0x6f, 0x39, 0x48, 0x64, 0x76, 0x57, 0x66, 0x69, 0x55, 0x6c, + 0x0a, 0x61, 0x69, 0x43, 0x2f, 0x55, 0x47, 0x47, 0x66, 0x52, 0x2b, 0x59, + 0x62, 0x6c, 0x70, 0x4b, 0x37, 0x43, 0x47, 0x2b, 0x37, 0x2f, 0x68, 0x76, + 0x54, 0x58, 0x74, 0x55, 0x73, 0x4d, 0x77, 0x2b, 0x4f, 0x70, 0x70, 0x6f, + 0x65, 0x48, 0x39, 0x7a, 0x38, 0x37, 0x72, 0x68, 0x4f, 0x4a, 0x4d, 0x78, + 0x74, 0x69, 0x43, 0x37, 0x58, 0x77, 0x55, 0x35, 0x72, 0x68, 0x45, 0x6d, + 0x61, 0x62, 0x2f, 0x31, 0x66, 0x0a, 0x31, 0x58, 0x4d, 0x2f, 0x6e, 0x4c, + 0x6f, 0x5a, 0x72, 0x66, 0x44, 0x41, 0x63, 0x54, 0x62, 0x44, 0x79, 0x77, + 0x6f, 0x65, 0x75, 0x38, 0x32, 0x36, 0x53, 0x4a, 0x33, 0x6d, 0x69, 0x66, + 0x61, 0x6a, 0x71, 0x37, 0x6f, 0x4b, 0x33, 0x4c, 0x44, 0x64, 0x4e, 0x4c, + 0x6a, 0x57, 0x5a, 0x77, 0x66, 0x45, 0x73, 0x43, 0x4f, 0x31, 0x71, 0x70, + 0x32, 0x43, 0x34, 0x67, 0x4c, 0x76, 0x42, 0x6c, 0x4f, 0x4f, 0x0a, 0x4b, + 0x73, 0x57, 0x4f, 0x4c, 0x4e, 0x62, 0x79, 0x36, 0x42, 0x79, 0x78, 0x43, + 0x4f, 0x50, 0x6c, 0x43, 0x54, 0x61, 0x30, 0x55, 0x43, 0x61, 0x56, 0x75, + 0x6f, 0x4e, 0x63, 0x6c, 0x59, 0x6f, 0x6c, 0x37, 0x31, 0x6a, 0x79, 0x69, + 0x31, 0x37, 0x4b, 0x57, 0x2b, 0x4e, 0x6b, 0x30, 0x6e, 0x4e, 0x65, 0x39, + 0x79, 0x61, 0x56, 0x63, 0x79, 0x72, 0x36, 0x48, 0x30, 0x7a, 0x33, 0x62, + 0x49, 0x6d, 0x66, 0x0a, 0x4a, 0x68, 0x62, 0x53, 0x75, 0x34, 0x72, 0x7a, + 0x49, 0x39, 0x33, 0x6e, 0x41, 0x67, 0x4d, 0x42, 0x41, 0x41, 0x47, 0x6a, + 0x55, 0x7a, 0x42, 0x52, 0x4d, 0x42, 0x30, 0x47, 0x41, 0x31, 0x55, 0x64, + 0x44, 0x67, 0x51, 0x57, 0x42, 0x42, 0x54, 0x4b, 0x4a, 0x73, 0x6b, 0x45, + 0x59, 0x64, 0x32, 0x6e, 0x64, 0x72, 0x77, 0x69, 0x68, 0x50, 0x54, 0x67, + 0x32, 0x50, 0x7a, 0x59, 0x46, 0x2f, 0x6b, 0x50, 0x0a, 0x67, 0x7a, 0x41, + 0x66, 0x42, 0x67, 0x4e, 0x56, 0x48, 0x53, 0x4d, 0x45, 0x47, 0x44, 0x41, + 0x57, 0x67, 0x42, 0x54, 0x4b, 0x4a, 0x73, 0x6b, 0x45, 0x59, 0x64, 0x32, + 0x6e, 0x64, 0x72, 0x77, 0x69, 0x68, 0x50, 0x54, 0x67, 0x32, 0x50, 0x7a, + 0x59, 0x46, 0x2f, 0x6b, 0x50, 0x67, 0x7a, 0x41, 0x50, 0x42, 0x67, 0x4e, + 0x56, 0x48, 0x52, 0x4d, 0x42, 0x41, 0x66, 0x38, 0x45, 0x42, 0x54, 0x41, + 0x44, 0x0a, 0x41, 0x51, 0x48, 0x2f, 0x4d, 0x41, 0x30, 0x47, 0x43, 0x53, + 0x71, 0x47, 0x53, 0x49, 0x62, 0x33, 0x44, 0x51, 0x45, 0x42, 0x43, 0x77, + 0x55, 0x41, 0x41, 0x34, 0x49, 0x42, 0x41, 0x51, 0x42, 0x6f, 0x47, 0x77, + 0x57, 0x52, 0x30, 0x70, 0x4c, 0x4d, 0x31, 0x69, 0x63, 0x58, 0x34, 0x62, + 0x49, 0x4a, 0x36, 0x79, 0x64, 0x75, 0x46, 0x55, 0x2f, 0x41, 0x34, 0x6a, + 0x53, 0x69, 0x71, 0x45, 0x54, 0x36, 0x0a, 0x67, 0x76, 0x4a, 0x68, 0x77, + 0x67, 0x45, 0x72, 0x69, 0x6c, 0x71, 0x54, 0x4b, 0x66, 0x48, 0x36, 0x59, + 0x38, 0x39, 0x72, 0x71, 0x74, 0x7a, 0x57, 0x38, 0x6b, 0x34, 0x55, 0x75, + 0x72, 0x41, 0x4f, 0x43, 0x73, 0x45, 0x34, 0x46, 0x41, 0x36, 0x77, 0x62, + 0x6b, 0x48, 0x57, 0x77, 0x72, 0x55, 0x4d, 0x6e, 0x43, 0x6c, 0x59, 0x34, + 0x6c, 0x6b, 0x48, 0x4a, 0x68, 0x2b, 0x4d, 0x75, 0x4e, 0x61, 0x4a, 0x0a, + 0x6e, 0x43, 0x47, 0x72, 
0x4b, 0x38, 0x77, 0x52, 0x4b, 0x47, 0x62, 0x2f, + 0x6d, 0x71, 0x57, 0x39, 0x64, 0x35, 0x70, 0x50, 0x37, 0x32, 0x45, 0x74, + 0x31, 0x51, 0x36, 0x4f, 0x57, 0x36, 0x44, 0x41, 0x4b, 0x71, 0x47, 0x66, + 0x6a, 0x44, 0x57, 0x68, 0x32, 0x4d, 0x7a, 0x53, 0x50, 0x48, 0x42, 0x78, + 0x63, 0x43, 0x4c, 0x65, 0x79, 0x69, 0x67, 0x4f, 0x31, 0x77, 0x71, 0x64, + 0x34, 0x57, 0x31, 0x54, 0x0a, 0x6e, 0x76, 0x76, 0x71, 0x6c, 0x36, 0x6c, + 0x34, 0x4c, 0x2b, 0x42, 0x35, 0x49, 0x54, 0x2f, 0x63, 0x2b, 0x2f, 0x45, + 0x48, 0x4f, 0x33, 0x50, 0x77, 0x62, 0x49, 0x39, 0x76, 0x36, 0x4d, 0x47, + 0x54, 0x74, 0x4c, 0x6a, 0x73, 0x5a, 0x67, 0x6b, 0x52, 0x4b, 0x49, 0x74, + 0x61, 0x50, 0x68, 0x2b, 0x59, 0x65, 0x4a, 0x64, 0x6d, 0x42, 0x59, 0x68, + 0x52, 0x44, 0x31, 0x42, 0x76, 0x57, 0x62, 0x36, 0x73, 0x0a, 0x56, 0x77, + 0x45, 0x62, 0x37, 0x61, 0x51, 0x31, 0x6f, 0x53, 0x46, 0x2b, 0x65, 0x73, + 0x55, 0x76, 0x4d, 0x6d, 0x6a, 0x47, 0x56, 0x75, 0x48, 0x58, 0x75, 0x51, + 0x76, 0x57, 0x4a, 0x61, 0x68, 0x6e, 0x6a, 0x59, 0x64, 0x59, 0x54, 0x32, + 0x44, 0x69, 0x6b, 0x79, 0x71, 0x52, 0x2b, 0x41, 0x77, 0x61, 0x4b, 0x7a, + 0x72, 0x65, 0x34, 0x47, 0x4a, 0x4d, 0x48, 0x73, 0x58, 0x33, 0x2f, 0x43, + 0x66, 0x38, 0x0a, 0x71, 0x64, 0x78, 0x79, 0x49, 0x2b, 0x42, 0x31, 0x6a, + 0x55, 0x77, 0x4e, 0x72, 0x37, 0x73, 0x4c, 0x41, 0x32, 0x45, 0x59, 0x44, + 0x6a, 0x6e, 0x55, 0x52, 0x30, 0x6a, 0x45, 0x48, 0x63, 0x72, 0x4f, 0x42, + 0x53, 0x70, 0x49, 0x51, 0x79, 0x52, 0x4d, 0x47, 0x57, 0x64, 0x75, 0x6a, + 0x30, 0x50, 0x31, 0x36, 0x79, 0x62, 0x39, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, + 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, + 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x00}; + +extern const char test_self_signed_client_key[] = { + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x50, + 0x52, 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, 0x49, 0x49, 0x45, 0x76, 0x67, 0x49, 0x42, + 0x41, 0x44, 0x41, 0x4e, 0x42, 0x67, 0x6b, 0x71, 0x68, 0x6b, 0x69, 0x47, + 0x39, 0x77, 0x30, 0x42, 0x41, 0x51, 0x45, 0x46, 0x41, 0x41, 0x53, 0x43, + 0x42, 0x4b, 0x67, 0x77, 0x67, 0x67, 0x53, 0x6b, 0x41, 0x67, 0x45, 0x41, + 0x41, 0x6f, 0x49, 0x42, 0x41, 0x51, 0x44, 0x76, 0x64, 0x7a, 0x4b, 0x44, + 0x54, 0x59, 0x76, 0x52, 0x67, 0x6a, 0x42, 0x4f, 0x0a, 0x55, 0x4f, 0x72, + 0x7a, 0x44, 0x77, 0x6b, 0x41, 0x5a, 0x47, 0x77, 0x4e, 0x46, 0x48, 0x48, + 0x6c, 0x4d, 0x59, 0x79, 0x4d, 0x47, 0x49, 0x35, 0x74, 0x49, 0x74, 0x6a, + 0x33, 0x74, 0x43, 0x7a, 0x58, 0x6b, 0x62, 0x70, 0x4d, 0x30, 0x75, 0x7a, + 0x33, 0x5a, 0x6a, 0x48, 0x56, 0x61, 0x68, 0x75, 0x2b, 0x65, 0x59, 0x63, + 0x2b, 0x4b, 0x76, 0x59, 0x41, 0x70, 0x4d, 0x36, 0x34, 0x46, 0x32, 0x64, + 0x42, 0x0a, 0x62, 0x31, 0x36, 0x68, 0x73, 0x37, 0x31, 0x33, 0x46, 0x43, + 0x6b, 0x38, 0x6d, 0x69, 0x68, 0x59, 0x41, 0x42, 0x6a, 0x6e, 0x53, 0x6e, + 0x64, 0x72, 0x51, 0x73, 0x6c, 0x2f, 0x55, 0x32, 0x76, 0x38, 0x59, 0x46, + 0x54, 0x37, 0x44, 0x69, 0x70, 0x66, 0x4c, 0x52, 0x65, 0x71, 0x71, 0x61, + 0x4f, 0x47, 0x75, 0x32, 0x6f, 0x39, 0x48, 0x64, 0x76, 0x57, 0x66, 0x69, + 0x55, 0x6c, 0x61, 0x69, 0x43, 0x2f, 0x0a, 0x55, 0x47, 0x47, 0x66, 0x52, + 0x2b, 0x59, 0x62, 0x6c, 0x70, 0x4b, 0x37, 0x43, 0x47, 0x2b, 0x37, 0x2f, + 0x68, 0x76, 0x54, 0x58, 0x74, 0x55, 0x73, 0x4d, 0x77, 0x2b, 0x4f, 0x70, + 0x70, 0x6f, 0x65, 0x48, 0x39, 0x7a, 0x38, 0x37, 0x72, 0x68, 0x4f, 0x4a, + 0x4d, 0x78, 0x74, 0x69, 0x43, 0x37, 0x58, 0x77, 0x55, 0x35, 0x72, 0x68, + 0x45, 0x6d, 0x61, 0x62, 0x2f, 0x31, 0x66, 0x31, 0x58, 0x4d, 0x2f, 0x0a, + 0x6e, 0x4c, 0x6f, 0x5a, 0x72, 0x66, 0x44, 0x41, 
0x63, 0x54, 0x62, 0x44, + 0x79, 0x77, 0x6f, 0x65, 0x75, 0x38, 0x32, 0x36, 0x53, 0x4a, 0x33, 0x6d, + 0x69, 0x66, 0x61, 0x6a, 0x71, 0x37, 0x6f, 0x4b, 0x33, 0x4c, 0x44, 0x64, + 0x4e, 0x4c, 0x6a, 0x57, 0x5a, 0x77, 0x66, 0x45, 0x73, 0x43, 0x4f, 0x31, + 0x71, 0x70, 0x32, 0x43, 0x34, 0x67, 0x4c, 0x76, 0x42, 0x6c, 0x4f, 0x4f, + 0x4b, 0x73, 0x57, 0x4f, 0x0a, 0x4c, 0x4e, 0x62, 0x79, 0x36, 0x42, 0x79, + 0x78, 0x43, 0x4f, 0x50, 0x6c, 0x43, 0x54, 0x61, 0x30, 0x55, 0x43, 0x61, + 0x56, 0x75, 0x6f, 0x4e, 0x63, 0x6c, 0x59, 0x6f, 0x6c, 0x37, 0x31, 0x6a, + 0x79, 0x69, 0x31, 0x37, 0x4b, 0x57, 0x2b, 0x4e, 0x6b, 0x30, 0x6e, 0x4e, + 0x65, 0x39, 0x79, 0x61, 0x56, 0x63, 0x79, 0x72, 0x36, 0x48, 0x30, 0x7a, + 0x33, 0x62, 0x49, 0x6d, 0x66, 0x4a, 0x68, 0x62, 0x53, 0x0a, 0x75, 0x34, + 0x72, 0x7a, 0x49, 0x39, 0x33, 0x6e, 0x41, 0x67, 0x4d, 0x42, 0x41, 0x41, + 0x45, 0x43, 0x67, 0x67, 0x45, 0x42, 0x41, 0x4f, 0x49, 0x50, 0x4f, 0x4a, + 0x52, 0x54, 0x70, 0x47, 0x61, 0x48, 0x37, 0x47, 0x70, 0x43, 0x59, 0x55, + 0x70, 0x4c, 0x4b, 0x30, 0x67, 0x2f, 0x68, 0x50, 0x46, 0x6b, 0x46, 0x35, + 0x45, 0x79, 0x45, 0x57, 0x67, 0x2f, 0x31, 0x6c, 0x53, 0x59, 0x7a, 0x52, + 0x49, 0x70, 0x0a, 0x2b, 0x52, 0x73, 0x58, 0x36, 0x7a, 0x4f, 0x53, 0x2b, + 0x7a, 0x6b, 0x69, 0x4e, 0x48, 0x45, 0x76, 0x31, 0x6a, 0x6b, 0x65, 0x4b, + 0x4e, 0x6f, 0x37, 0x58, 0x44, 0x69, 0x48, 0x58, 0x4d, 0x37, 0x55, 0x36, + 0x52, 0x6b, 0x51, 0x74, 0x64, 0x6b, 0x5a, 0x41, 0x51, 0x64, 0x6b, 0x39, + 0x50, 0x6a, 0x4d, 0x33, 0x73, 0x45, 0x55, 0x64, 0x6d, 0x34, 0x43, 0x45, + 0x6e, 0x49, 0x6a, 0x66, 0x6d, 0x7a, 0x41, 0x0a, 0x70, 0x2f, 0x52, 0x38, + 0x54, 0x44, 0x30, 0x6b, 0x78, 0x6b, 0x4e, 0x4c, 0x49, 0x6b, 0x68, 0x75, + 0x46, 0x48, 0x32, 0x67, 0x64, 0x30, 0x35, 0x79, 0x33, 0x5a, 0x48, 0x44, + 0x53, 0x2f, 0x58, 0x69, 0x46, 0x6b, 0x41, 0x45, 0x39, 0x65, 0x4f, 0x54, + 0x30, 0x46, 0x72, 0x43, 0x37, 0x6f, 0x6d, 0x36, 0x45, 0x53, 0x44, 0x37, + 0x5a, 0x66, 0x46, 0x49, 0x57, 0x52, 0x31, 0x38, 0x70, 0x6e, 0x63, 0x57, + 0x0a, 0x5a, 0x47, 0x71, 0x37, 0x74, 0x46, 0x41, 0x5a, 0x5a, 0x52, 0x6d, + 0x70, 0x6b, 0x75, 0x6d, 0x32, 0x44, 0x2b, 0x4d, 0x4a, 0x79, 0x31, 0x67, + 0x57, 0x78, 0x49, 0x58, 0x42, 0x78, 0x74, 0x35, 0x6d, 0x61, 0x64, 0x54, + 0x45, 0x70, 0x52, 0x78, 0x51, 0x64, 0x35, 0x36, 0x74, 0x6f, 0x45, 0x6e, + 0x66, 0x78, 0x33, 0x37, 0x32, 0x46, 0x30, 0x79, 0x34, 0x7a, 0x6b, 0x63, + 0x58, 0x33, 0x70, 0x6e, 0x45, 0x0a, 0x34, 0x48, 0x36, 0x46, 0x61, 0x4a, + 0x55, 0x42, 0x6a, 0x64, 0x76, 0x4b, 0x6c, 0x32, 0x51, 0x7a, 0x46, 0x35, + 0x63, 0x30, 0x6a, 0x42, 0x71, 0x67, 0x78, 0x4d, 0x52, 0x76, 0x57, 0x50, + 0x35, 0x59, 0x66, 0x4e, 0x75, 0x38, 0x2b, 0x64, 0x6d, 0x61, 0x51, 0x4f, + 0x52, 0x50, 0x6b, 0x70, 0x7a, 0x53, 0x70, 0x74, 0x4f, 0x50, 0x6d, 0x5a, + 0x4d, 0x39, 0x56, 0x4b, 0x56, 0x2b, 0x74, 0x4a, 0x56, 0x53, 0x0a, 0x31, + 0x78, 0x6e, 0x4f, 0x49, 0x36, 0x44, 0x74, 0x72, 0x6e, 0x4e, 0x5a, 0x52, + 0x6f, 0x6a, 0x65, 0x67, 0x52, 0x2f, 0x45, 0x36, 0x4b, 0x68, 0x4e, 0x79, + 0x69, 0x50, 0x54, 0x59, 0x79, 0x39, 0x37, 0x55, 0x67, 0x59, 0x7a, 0x64, + 0x4b, 0x53, 0x2b, 0x53, 0x53, 0x45, 0x43, 0x67, 0x59, 0x45, 0x41, 0x2b, + 0x77, 0x67, 0x53, 0x49, 0x71, 0x72, 0x66, 0x6b, 0x65, 0x71, 0x71, 0x6f, + 0x74, 0x4a, 0x78, 0x0a, 0x63, 0x47, 0x78, 0x46, 0x34, 0x78, 0x39, 0x76, + 0x2f, 0x6c, 0x64, 0x4b, 0x72, 0x35, 0x68, 0x6c, 0x68, 0x4a, 0x4e, 0x6f, + 0x4b, 0x58, 0x4c, 0x6b, 0x65, 0x70, 0x6b, 0x63, 0x72, 0x76, 0x68, 0x68, + 0x78, 0x66, 0x48, 0x4b, 0x67, 0x6a, 0x57, 0x7a, 0x31, 0x6e, 0x5a, 0x59, + 0x2f, 0x2b, 0x52, 0x70, 0x67, 0x34, 0x32, 0x47, 0x46, 0x4d, 0x76, 0x78, + 0x57, 0x52, 0x72, 0x47, 0x54, 0x4d, 0x49, 0x4a, 
0x0a, 0x64, 0x64, 0x69, + 0x4f, 0x72, 0x32, 0x34, 0x70, 0x30, 0x48, 0x43, 0x6b, 0x75, 0x73, 0x57, + 0x52, 0x4d, 0x4b, 0x51, 0x4c, 0x37, 0x58, 0x78, 0x76, 0x75, 0x48, 0x44, + 0x71, 0x30, 0x72, 0x6f, 0x38, 0x53, 0x47, 0x71, 0x58, 0x7a, 0x71, 0x57, + 0x47, 0x75, 0x48, 0x33, 0x31, 0x52, 0x2b, 0x59, 0x4e, 0x50, 0x38, 0x64, + 0x79, 0x32, 0x70, 0x71, 0x64, 0x33, 0x4f, 0x6c, 0x77, 0x7a, 0x54, 0x67, + 0x67, 0x0a, 0x38, 0x76, 0x30, 0x77, 0x77, 0x7a, 0x78, 0x38, 0x41, 0x75, + 0x79, 0x50, 0x35, 0x59, 0x73, 0x34, 0x4d, 0x32, 0x30, 0x45, 0x77, 0x76, + 0x37, 0x58, 0x75, 0x79, 0x30, 0x43, 0x67, 0x59, 0x45, 0x41, 0x39, 0x44, + 0x53, 0x47, 0x4d, 0x55, 0x38, 0x6a, 0x6d, 0x6a, 0x78, 0x4a, 0x2f, 0x75, + 0x50, 0x44, 0x43, 0x58, 0x57, 0x4f, 0x45, 0x41, 0x71, 0x74, 0x45, 0x37, + 0x38, 0x77, 0x54, 0x74, 0x49, 0x77, 0x0a, 0x75, 0x4d, 0x42, 0x76, 0x2b, + 0x67, 0x65, 0x30, 0x69, 0x6e, 0x63, 0x33, 0x37, 0x78, 0x66, 0x2b, 0x66, + 0x4e, 0x36, 0x44, 0x2f, 0x7a, 0x69, 0x54, 0x72, 0x4a, 0x76, 0x67, 0x77, + 0x2f, 0x58, 0x79, 0x54, 0x31, 0x35, 0x70, 0x6d, 0x51, 0x64, 0x4f, 0x6c, + 0x58, 0x78, 0x33, 0x53, 0x67, 0x31, 0x68, 0x39, 0x58, 0x42, 0x5a, 0x65, + 0x49, 0x6c, 0x61, 0x65, 0x43, 0x64, 0x46, 0x57, 0x72, 0x46, 0x42, 0x0a, + 0x6f, 0x59, 0x72, 0x56, 0x73, 0x69, 0x75, 0x6f, 0x58, 0x52, 0x73, 0x77, + 0x66, 0x6b, 0x46, 0x77, 0x41, 0x30, 0x79, 0x4f, 0x6b, 0x43, 0x73, 0x48, + 0x79, 0x47, 0x69, 0x49, 0x34, 0x54, 0x45, 0x30, 0x57, 0x31, 0x72, 0x47, + 0x62, 0x71, 0x50, 0x31, 0x35, 0x38, 0x49, 0x6a, 0x77, 0x58, 0x50, 0x63, + 0x7a, 0x42, 0x73, 0x77, 0x57, 0x49, 0x37, 0x69, 0x2f, 0x44, 0x36, 0x4c, + 0x70, 0x49, 0x4e, 0x4c, 0x0a, 0x42, 0x44, 0x37, 0x59, 0x59, 0x70, 0x66, + 0x48, 0x6d, 0x65, 0x4d, 0x43, 0x67, 0x59, 0x42, 0x30, 0x38, 0x41, 0x69, + 0x4b, 0x72, 0x37, 0x43, 0x66, 0x35, 0x34, 0x48, 0x2f, 0x67, 0x53, 0x71, + 0x6f, 0x35, 0x54, 0x63, 0x56, 0x47, 0x7a, 0x4c, 0x76, 0x64, 0x7a, 0x68, + 0x71, 0x58, 0x67, 0x4b, 0x45, 0x5a, 0x4b, 0x70, 0x30, 0x44, 0x48, 0x70, + 0x55, 0x68, 0x66, 0x69, 0x76, 0x70, 0x54, 0x4c, 0x65, 0x0a, 0x6f, 0x38, + 0x6a, 0x6a, 0x4b, 0x53, 0x4d, 0x53, 0x4e, 0x32, 0x55, 0x30, 0x4a, 0x76, + 0x48, 0x6a, 0x2f, 0x30, 0x78, 0x44, 0x61, 0x64, 0x47, 0x4f, 0x34, 0x59, + 0x4d, 0x59, 0x68, 0x4a, 0x63, 0x6c, 0x6c, 0x33, 0x43, 0x34, 0x56, 0x67, + 0x67, 0x53, 0x65, 0x6a, 0x61, 0x79, 0x62, 0x70, 0x41, 0x34, 0x36, 0x57, + 0x4a, 0x4a, 0x43, 0x64, 0x74, 0x39, 0x50, 0x74, 0x53, 0x55, 0x76, 0x33, + 0x36, 0x50, 0x0a, 0x65, 0x57, 0x41, 0x6f, 0x4f, 0x6b, 0x46, 0x73, 0x74, + 0x66, 0x68, 0x4a, 0x75, 0x75, 0x66, 0x58, 0x47, 0x78, 0x44, 0x73, 0x74, + 0x6e, 0x50, 0x74, 0x55, 0x61, 0x31, 0x6a, 0x57, 0x38, 0x38, 0x31, 0x67, + 0x69, 0x35, 0x78, 0x39, 0x44, 0x34, 0x4d, 0x6d, 0x71, 0x68, 0x5a, 0x6c, + 0x4b, 0x58, 0x6b, 0x68, 0x74, 0x64, 0x65, 0x41, 0x70, 0x72, 0x36, 0x4c, + 0x51, 0x4b, 0x42, 0x67, 0x51, 0x44, 0x64, 0x0a, 0x49, 0x74, 0x73, 0x4a, + 0x74, 0x39, 0x4a, 0x54, 0x6a, 0x70, 0x69, 0x72, 0x47, 0x66, 0x43, 0x35, + 0x6c, 0x68, 0x77, 0x49, 0x35, 0x73, 0x49, 0x49, 0x43, 0x61, 0x39, 0x6a, + 0x45, 0x4f, 0x39, 0x52, 0x76, 0x65, 0x45, 0x6f, 0x6c, 0x75, 0x57, 0x6b, + 0x4a, 0x59, 0x55, 0x66, 0x47, 0x36, 0x6b, 0x31, 0x78, 0x67, 0x48, 0x64, + 0x6b, 0x59, 0x77, 0x59, 0x57, 0x43, 0x64, 0x58, 0x44, 0x46, 0x5a, 0x61, + 0x0a, 0x44, 0x50, 0x4b, 0x75, 0x77, 0x6e, 0x45, 0x6b, 0x36, 0x4d, 0x72, + 0x55, 0x34, 0x66, 0x31, 0x38, 0x31, 0x6a, 0x6f, 0x4f, 0x37, 0x73, 0x4a, + 0x66, 0x33, 0x35, 0x2f, 0x73, 0x47, 0x6d, 0x75, 0x47, 0x4c, 0x30, 0x53, + 0x48, 0x7a, 0x51, 0x54, 0x76, 0x47, 0x76, 0x6e, 0x30, 0x75, 0x71, 0x6b, + 0x47, 0x4d, 0x38, 0x4d, 0x39, 0x52, 0x64, 0x6f, 
0x4d, 0x58, 0x71, 0x7a, + 0x6b, 0x7a, 0x7a, 0x76, 0x4d, 0x0a, 0x4a, 0x67, 0x31, 0x65, 0x6a, 0x31, + 0x62, 0x55, 0x67, 0x58, 0x63, 0x44, 0x62, 0x54, 0x6e, 0x61, 0x45, 0x68, + 0x7a, 0x62, 0x64, 0x4c, 0x69, 0x54, 0x46, 0x73, 0x67, 0x35, 0x4e, 0x7a, + 0x4d, 0x74, 0x4b, 0x77, 0x4f, 0x6a, 0x64, 0x44, 0x49, 0x70, 0x5a, 0x51, + 0x4b, 0x42, 0x67, 0x45, 0x49, 0x48, 0x65, 0x4a, 0x49, 0x71, 0x69, 0x47, + 0x6a, 0x59, 0x67, 0x66, 0x37, 0x6d, 0x55, 0x6c, 0x58, 0x32, 0x0a, 0x76, + 0x4e, 0x57, 0x67, 0x46, 0x4e, 0x6c, 0x7a, 0x41, 0x70, 0x6b, 0x46, 0x53, + 0x43, 0x51, 0x38, 0x54, 0x6b, 0x7a, 0x6b, 0x44, 0x4f, 0x6a, 0x74, 0x43, + 0x64, 0x53, 0x48, 0x66, 0x64, 0x52, 0x44, 0x4a, 0x36, 0x2b, 0x71, 0x38, + 0x63, 0x53, 0x32, 0x54, 0x53, 0x51, 0x37, 0x51, 0x50, 0x6f, 0x41, 0x6c, + 0x49, 0x31, 0x77, 0x6f, 0x53, 0x30, 0x47, 0x34, 0x38, 0x54, 0x4e, 0x62, + 0x56, 0x53, 0x6f, 0x0a, 0x77, 0x44, 0x30, 0x6a, 0x4e, 0x56, 0x52, 0x54, + 0x64, 0x70, 0x41, 0x36, 0x52, 0x35, 0x46, 0x50, 0x73, 0x67, 0x30, 0x39, + 0x6f, 0x68, 0x42, 0x2f, 0x63, 0x61, 0x53, 0x6e, 0x30, 0x7a, 0x6c, 0x47, + 0x56, 0x68, 0x61, 0x32, 0x47, 0x53, 0x30, 0x38, 0x63, 0x65, 0x59, 0x72, + 0x6e, 0x37, 0x6e, 0x6e, 0x34, 0x50, 0x53, 0x5a, 0x2f, 0x55, 0x49, 0x59, + 0x54, 0x6d, 0x33, 0x70, 0x6a, 0x55, 0x6c, 0x56, 0x0a, 0x48, 0x35, 0x74, + 0x76, 0x48, 0x76, 0x30, 0x67, 0x47, 0x32, 0x43, 0x35, 0x76, 0x79, 0x33, + 0x74, 0x49, 0x59, 0x51, 0x74, 0x53, 0x51, 0x43, 0x6b, 0x0a, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x50, 0x52, 0x49, 0x56, 0x41, + 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, + 0x00}; + +extern const char test_signed_client_cert[] = { + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x43, + 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, 0x49, 0x49, 0x44, 0x4e, 0x7a, 0x43, 0x43, + 0x41, 0x68, 0x38, 0x43, 0x46, 0x47, 0x79, 0x58, 0x30, 0x30, 0x52, 0x43, + 0x65, 0x70, 0x4f, 0x76, 0x2f, 0x71, 0x43, 0x4a, 0x31, 0x6f, 0x56, 0x64, + 0x54, 0x74, 0x59, 0x39, 0x32, 0x55, 0x38, 0x33, 0x4d, 0x41, 0x30, 0x47, + 0x43, 0x53, 0x71, 0x47, 0x53, 0x49, 0x62, 0x33, 0x44, 0x51, 0x45, 0x42, + 0x43, 0x77, 0x55, 0x41, 0x4d, 0x46, 0x59, 0x78, 0x0a, 0x43, 0x7a, 0x41, + 0x4a, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x59, 0x54, 0x41, 0x6b, 0x46, + 0x56, 0x4d, 0x52, 0x4d, 0x77, 0x45, 0x51, 0x59, 0x44, 0x56, 0x51, 0x51, + 0x49, 0x44, 0x41, 0x70, 0x54, 0x62, 0x32, 0x31, 0x6c, 0x4c, 0x56, 0x4e, + 0x30, 0x59, 0x58, 0x52, 0x6c, 0x4d, 0x53, 0x45, 0x77, 0x48, 0x77, 0x59, + 0x44, 0x56, 0x51, 0x51, 0x4b, 0x44, 0x42, 0x68, 0x4a, 0x62, 0x6e, 0x52, + 0x6c, 0x0a, 0x63, 0x6d, 0x35, 0x6c, 0x64, 0x43, 0x42, 0x58, 0x61, 0x57, + 0x52, 0x6e, 0x61, 0x58, 0x52, 0x7a, 0x49, 0x46, 0x42, 0x30, 0x65, 0x53, + 0x42, 0x4d, 0x64, 0x47, 0x51, 0x78, 0x44, 0x7a, 0x41, 0x4e, 0x42, 0x67, + 0x4e, 0x56, 0x42, 0x41, 0x4d, 0x4d, 0x42, 0x6e, 0x52, 0x6c, 0x63, 0x33, + 0x52, 0x6a, 0x59, 0x54, 0x41, 0x65, 0x46, 0x77, 0x30, 0x79, 0x4d, 0x44, + 0x41, 0x7a, 0x4d, 0x54, 0x67, 0x77, 0x0a, 0x4d, 0x54, 0x41, 0x32, 0x4d, + 0x54, 0x42, 0x61, 0x46, 0x77, 0x30, 0x7a, 0x4d, 0x44, 0x41, 0x7a, 0x4d, + 0x54, 0x59, 0x77, 0x4d, 0x54, 0x41, 0x32, 0x4d, 0x54, 0x42, 0x61, 0x4d, + 0x46, 0x6f, 0x78, 0x43, 0x7a, 0x41, 0x4a, 0x42, 0x67, 0x4e, 0x56, 0x42, + 0x41, 0x59, 0x54, 0x41, 0x6b, 0x46, 0x56, 0x4d, 0x52, 0x4d, 0x77, 0x45, + 0x51, 0x59, 0x44, 0x56, 0x51, 0x51, 0x49, 0x44, 0x41, 0x70, 0x54, 0x0a, + 0x62, 0x32, 0x31, 0x6c, 0x4c, 0x56, 0x4e, 0x30, 0x59, 0x58, 0x52, 0x6c, + 0x4d, 0x53, 0x45, 0x77, 0x48, 0x77, 0x59, 0x44, 0x56, 0x51, 
0x51, 0x4b, + 0x44, 0x42, 0x68, 0x4a, 0x62, 0x6e, 0x52, 0x6c, 0x63, 0x6d, 0x35, 0x6c, + 0x64, 0x43, 0x42, 0x58, 0x61, 0x57, 0x52, 0x6e, 0x61, 0x58, 0x52, 0x7a, + 0x49, 0x46, 0x42, 0x30, 0x65, 0x53, 0x42, 0x4d, 0x64, 0x47, 0x51, 0x78, + 0x45, 0x7a, 0x41, 0x52, 0x0a, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x4d, + 0x4d, 0x43, 0x6e, 0x52, 0x6c, 0x63, 0x33, 0x52, 0x6a, 0x62, 0x47, 0x6c, + 0x6c, 0x62, 0x6e, 0x51, 0x77, 0x67, 0x67, 0x45, 0x69, 0x4d, 0x41, 0x30, + 0x47, 0x43, 0x53, 0x71, 0x47, 0x53, 0x49, 0x62, 0x33, 0x44, 0x51, 0x45, + 0x42, 0x41, 0x51, 0x55, 0x41, 0x41, 0x34, 0x49, 0x42, 0x44, 0x77, 0x41, + 0x77, 0x67, 0x67, 0x45, 0x4b, 0x41, 0x6f, 0x49, 0x42, 0x0a, 0x41, 0x51, + 0x43, 0x79, 0x71, 0x59, 0x52, 0x70, 0x2b, 0x44, 0x58, 0x56, 0x70, 0x37, + 0x32, 0x4e, 0x46, 0x62, 0x51, 0x48, 0x38, 0x68, 0x64, 0x68, 0x54, 0x5a, + 0x4c, 0x79, 0x63, 0x5a, 0x58, 0x4f, 0x6c, 0x4a, 0x68, 0x6d, 0x4d, 0x73, + 0x72, 0x4a, 0x6d, 0x72, 0x6a, 0x6e, 0x32, 0x70, 0x37, 0x70, 0x49, 0x2f, + 0x38, 0x6d, 0x54, 0x5a, 0x2f, 0x30, 0x46, 0x43, 0x2b, 0x53, 0x47, 0x57, + 0x42, 0x47, 0x0a, 0x5a, 0x56, 0x2b, 0x45, 0x4c, 0x69, 0x48, 0x72, 0x6d, + 0x43, 0x58, 0x35, 0x7a, 0x66, 0x61, 0x49, 0x4c, 0x72, 0x39, 0x49, 0x75, + 0x77, 0x37, 0x47, 0x68, 0x72, 0x33, 0x56, 0x7a, 0x6f, 0x65, 0x66, 0x69, + 0x38, 0x72, 0x36, 0x32, 0x72, 0x4c, 0x75, 0x70, 0x56, 0x50, 0x4e, 0x69, + 0x2f, 0x71, 0x64, 0x71, 0x79, 0x6a, 0x57, 0x6b, 0x32, 0x64, 0x45, 0x43, + 0x48, 0x43, 0x39, 0x5a, 0x33, 0x2b, 0x41, 0x0a, 0x67, 0x33, 0x4b, 0x7a, + 0x4b, 0x54, 0x79, 0x65, 0x72, 0x58, 0x57, 0x6a, 0x4b, 0x63, 0x76, 0x79, + 0x4b, 0x56, 0x6d, 0x4d, 0x30, 0x5a, 0x78, 0x45, 0x30, 0x52, 0x58, 0x68, + 0x44, 0x57, 0x2f, 0x52, 0x6f, 0x51, 0x62, 0x71, 0x5a, 0x73, 0x55, 0x32, + 0x47, 0x4b, 0x67, 0x31, 0x42, 0x32, 0x72, 0x68, 0x55, 0x55, 0x38, 0x4b, + 0x4e, 0x30, 0x67, 0x56, 0x6d, 0x4b, 0x6e, 0x30, 0x72, 0x4a, 0x48, 0x4f, + 0x0a, 0x78, 0x7a, 0x52, 0x56, 0x53, 0x59, 0x65, 0x59, 0x4c, 0x59, 0x70, + 0x35, 0x59, 0x6e, 0x37, 0x4b, 0x72, 0x74, 0x50, 0x4a, 0x63, 0x4b, 0x79, + 0x6f, 0x39, 0x61, 0x56, 0x75, 0x45, 0x72, 0x37, 0x64, 0x47, 0x41, 0x4e, + 0x7a, 0x70, 0x79, 0x46, 0x36, 0x6c, 0x67, 0x2f, 0x6e, 0x59, 0x42, 0x57, + 0x63, 0x2b, 0x39, 0x53, 0x47, 0x77, 0x6b, 0x6f, 0x4c, 0x64, 0x46, 0x76, + 0x4b, 0x76, 0x41, 0x42, 0x59, 0x0a, 0x4a, 0x4d, 0x79, 0x72, 0x62, 0x4e, + 0x68, 0x48, 0x55, 0x51, 0x66, 0x76, 0x30, 0x66, 0x7a, 0x61, 0x5a, 0x30, + 0x50, 0x38, 0x36, 0x64, 0x66, 0x54, 0x45, 0x4e, 0x72, 0x44, 0x78, 0x7a, + 0x41, 0x4c, 0x72, 0x7a, 0x47, 0x6e, 0x71, 0x63, 0x78, 0x33, 0x4b, 0x54, + 0x72, 0x77, 0x4a, 0x6a, 0x6b, 0x5a, 0x2f, 0x61, 0x53, 0x72, 0x31, 0x74, + 0x79, 0x44, 0x30, 0x2f, 0x74, 0x58, 0x76, 0x75, 0x6b, 0x52, 0x0a, 0x46, + 0x69, 0x50, 0x78, 0x57, 0x42, 0x4a, 0x68, 0x6a, 0x48, 0x51, 0x37, 0x30, + 0x47, 0x71, 0x54, 0x46, 0x51, 0x59, 0x31, 0x39, 0x52, 0x62, 0x68, 0x41, + 0x67, 0x4d, 0x42, 0x41, 0x41, 0x45, 0x77, 0x44, 0x51, 0x59, 0x4a, 0x4b, + 0x6f, 0x5a, 0x49, 0x68, 0x76, 0x63, 0x4e, 0x41, 0x51, 0x45, 0x4c, 0x42, + 0x51, 0x41, 0x44, 0x67, 0x67, 0x45, 0x42, 0x41, 0x46, 0x58, 0x43, 0x65, + 0x77, 0x4b, 0x38, 0x0a, 0x63, 0x57, 0x54, 0x2b, 0x7a, 0x57, 0x78, 0x58, + 0x79, 0x47, 0x46, 0x6e, 0x6f, 0x75, 0x46, 0x53, 0x42, 0x7a, 0x54, 0x69, + 0x30, 0x42, 0x4d, 0x42, 0x4a, 0x52, 0x72, 0x68, 0x73, 0x69, 0x4e, 0x6f, + 0x69, 0x51, 0x78, 0x6b, 0x71, 0x69, 0x74, 0x79, 0x4a, 0x48, 0x57, 0x46, + 0x45, 0x78, 0x69, 0x51, 0x5a, 0x69, 0x65, 0x2b, 0x37, 0x43, 0x41, 0x2b, + 0x45, 0x61, 0x62, 0x58, 0x43, 0x51, 0x55, 0x42, 0x0a, 0x2b, 0x4a, 0x77, + 0x4d, 0x53, 0x57, 0x4d, 0x32, 0x39, 0x6a, 0x33, 0x6d, 0x53, 
0x77, 0x31, + 0x30, 0x44, 0x54, 0x66, 0x6d, 0x43, 0x33, 0x72, 0x68, 0x68, 0x65, 0x51, + 0x71, 0x47, 0x78, 0x79, 0x33, 0x30, 0x34, 0x42, 0x5a, 0x79, 0x55, 0x70, + 0x64, 0x70, 0x76, 0x49, 0x32, 0x64, 0x74, 0x33, 0x70, 0x2f, 0x6d, 0x63, + 0x73, 0x45, 0x37, 0x4f, 0x2b, 0x70, 0x34, 0x73, 0x51, 0x72, 0x53, 0x65, + 0x70, 0x0a, 0x67, 0x69, 0x6a, 0x69, 0x44, 0x73, 0x73, 0x4b, 0x41, 0x66, + 0x78, 0x54, 0x41, 0x6d, 0x55, 0x4d, 0x39, 0x33, 0x4e, 0x36, 0x2b, 0x51, + 0x38, 0x79, 0x4a, 0x4b, 0x35, 0x69, 0x6d, 0x6d, 0x78, 0x6c, 0x62, 0x65, + 0x59, 0x66, 0x69, 0x6a, 0x6f, 0x42, 0x76, 0x6d, 0x6b, 0x7a, 0x79, 0x42, + 0x2f, 0x42, 0x2b, 0x71, 0x4e, 0x52, 0x50, 0x73, 0x78, 0x30, 0x6e, 0x37, + 0x61, 0x46, 0x47, 0x6e, 0x66, 0x76, 0x0a, 0x6f, 0x57, 0x66, 0x6b, 0x57, + 0x32, 0x39, 0x36, 0x69, 0x50, 0x68, 0x57, 0x4c, 0x69, 0x77, 0x6b, 0x6e, + 0x70, 0x43, 0x33, 0x78, 0x42, 0x36, 0x6f, 0x4b, 0x33, 0x76, 0x52, 0x62, + 0x4b, 0x34, 0x5a, 0x6a, 0x31, 0x4f, 0x61, 0x47, 0x62, 0x30, 0x67, 0x72, + 0x4b, 0x37, 0x56, 0x4e, 0x38, 0x45, 0x79, 0x68, 0x42, 0x69, 0x78, 0x32, + 0x78, 0x56, 0x46, 0x36, 0x31, 0x69, 0x34, 0x64, 0x7a, 0x43, 0x4b, 0x0a, + 0x6b, 0x4d, 0x49, 0x70, 0x6c, 0x37, 0x43, 0x55, 0x70, 0x77, 0x31, 0x4d, + 0x62, 0x32, 0x7a, 0x38, 0x71, 0x33, 0x46, 0x32, 0x62, 0x48, 0x42, 0x53, + 0x37, 0x69, 0x46, 0x37, 0x67, 0x31, 0x43, 0x63, 0x6e, 0x35, 0x56, 0x47, + 0x63, 0x4f, 0x2b, 0x61, 0x4a, 0x2b, 0x36, 0x50, 0x57, 0x79, 0x64, 0x61, + 0x65, 0x71, 0x4a, 0x36, 0x56, 0x45, 0x42, 0x46, 0x30, 0x4e, 0x77, 0x76, + 0x39, 0x77, 0x6f, 0x65, 0x0a, 0x6d, 0x4c, 0x35, 0x41, 0x6c, 0x75, 0x4e, + 0x52, 0x4c, 0x61, 0x71, 0x6a, 0x5a, 0x76, 0x45, 0x3d, 0x0a, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x43, 0x45, 0x52, 0x54, 0x49, + 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, + 0x00}; + +extern const char test_signed_client_key[] = { + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x50, + 0x52, 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, 0x49, 0x49, 0x45, 0x76, 0x67, 0x49, 0x42, + 0x41, 0x44, 0x41, 0x4e, 0x42, 0x67, 0x6b, 0x71, 0x68, 0x6b, 0x69, 0x47, + 0x39, 0x77, 0x30, 0x42, 0x41, 0x51, 0x45, 0x46, 0x41, 0x41, 0x53, 0x43, + 0x42, 0x4b, 0x67, 0x77, 0x67, 0x67, 0x53, 0x6b, 0x41, 0x67, 0x45, 0x41, + 0x41, 0x6f, 0x49, 0x42, 0x41, 0x51, 0x43, 0x79, 0x71, 0x59, 0x52, 0x70, + 0x2b, 0x44, 0x58, 0x56, 0x70, 0x37, 0x32, 0x4e, 0x0a, 0x46, 0x62, 0x51, + 0x48, 0x38, 0x68, 0x64, 0x68, 0x54, 0x5a, 0x4c, 0x79, 0x63, 0x5a, 0x58, + 0x4f, 0x6c, 0x4a, 0x68, 0x6d, 0x4d, 0x73, 0x72, 0x4a, 0x6d, 0x72, 0x6a, + 0x6e, 0x32, 0x70, 0x37, 0x70, 0x49, 0x2f, 0x38, 0x6d, 0x54, 0x5a, 0x2f, + 0x30, 0x46, 0x43, 0x2b, 0x53, 0x47, 0x57, 0x42, 0x47, 0x5a, 0x56, 0x2b, + 0x45, 0x4c, 0x69, 0x48, 0x72, 0x6d, 0x43, 0x58, 0x35, 0x7a, 0x66, 0x61, + 0x49, 0x0a, 0x4c, 0x72, 0x39, 0x49, 0x75, 0x77, 0x37, 0x47, 0x68, 0x72, + 0x33, 0x56, 0x7a, 0x6f, 0x65, 0x66, 0x69, 0x38, 0x72, 0x36, 0x32, 0x72, + 0x4c, 0x75, 0x70, 0x56, 0x50, 0x4e, 0x69, 0x2f, 0x71, 0x64, 0x71, 0x79, + 0x6a, 0x57, 0x6b, 0x32, 0x64, 0x45, 0x43, 0x48, 0x43, 0x39, 0x5a, 0x33, + 0x2b, 0x41, 0x67, 0x33, 0x4b, 0x7a, 0x4b, 0x54, 0x79, 0x65, 0x72, 0x58, + 0x57, 0x6a, 0x4b, 0x63, 0x76, 0x79, 0x0a, 0x4b, 0x56, 0x6d, 0x4d, 0x30, + 0x5a, 0x78, 0x45, 0x30, 0x52, 0x58, 0x68, 0x44, 0x57, 0x2f, 0x52, 0x6f, + 0x51, 0x62, 0x71, 0x5a, 0x73, 0x55, 0x32, 0x47, 0x4b, 0x67, 0x31, 0x42, + 0x32, 0x72, 0x68, 0x55, 0x55, 0x38, 0x4b, 0x4e, 0x30, 0x67, 0x56, 0x6d, + 0x4b, 0x6e, 0x30, 0x72, 0x4a, 0x48, 0x4f, 0x78, 0x7a, 0x52, 0x56, 0x53, + 
0x59, 0x65, 0x59, 0x4c, 0x59, 0x70, 0x35, 0x59, 0x6e, 0x37, 0x4b, 0x0a, + 0x72, 0x74, 0x50, 0x4a, 0x63, 0x4b, 0x79, 0x6f, 0x39, 0x61, 0x56, 0x75, + 0x45, 0x72, 0x37, 0x64, 0x47, 0x41, 0x4e, 0x7a, 0x70, 0x79, 0x46, 0x36, + 0x6c, 0x67, 0x2f, 0x6e, 0x59, 0x42, 0x57, 0x63, 0x2b, 0x39, 0x53, 0x47, + 0x77, 0x6b, 0x6f, 0x4c, 0x64, 0x46, 0x76, 0x4b, 0x76, 0x41, 0x42, 0x59, + 0x4a, 0x4d, 0x79, 0x72, 0x62, 0x4e, 0x68, 0x48, 0x55, 0x51, 0x66, 0x76, + 0x30, 0x66, 0x7a, 0x61, 0x0a, 0x5a, 0x30, 0x50, 0x38, 0x36, 0x64, 0x66, + 0x54, 0x45, 0x4e, 0x72, 0x44, 0x78, 0x7a, 0x41, 0x4c, 0x72, 0x7a, 0x47, + 0x6e, 0x71, 0x63, 0x78, 0x33, 0x4b, 0x54, 0x72, 0x77, 0x4a, 0x6a, 0x6b, + 0x5a, 0x2f, 0x61, 0x53, 0x72, 0x31, 0x74, 0x79, 0x44, 0x30, 0x2f, 0x74, + 0x58, 0x76, 0x75, 0x6b, 0x52, 0x46, 0x69, 0x50, 0x78, 0x57, 0x42, 0x4a, + 0x68, 0x6a, 0x48, 0x51, 0x37, 0x30, 0x47, 0x71, 0x54, 0x0a, 0x46, 0x51, + 0x59, 0x31, 0x39, 0x52, 0x62, 0x68, 0x41, 0x67, 0x4d, 0x42, 0x41, 0x41, + 0x45, 0x43, 0x67, 0x67, 0x45, 0x41, 0x49, 0x4c, 0x38, 0x4a, 0x55, 0x68, + 0x4c, 0x34, 0x61, 0x77, 0x79, 0x76, 0x70, 0x57, 0x68, 0x51, 0x38, 0x78, + 0x50, 0x67, 0x54, 0x53, 0x6c, 0x57, 0x77, 0x62, 0x45, 0x6e, 0x38, 0x42, + 0x45, 0x30, 0x54, 0x61, 0x63, 0x4a, 0x6e, 0x43, 0x49, 0x4c, 0x75, 0x68, + 0x4e, 0x4d, 0x0a, 0x42, 0x52, 0x64, 0x66, 0x38, 0x4c, 0x6c, 0x52, 0x6b, + 0x2f, 0x38, 0x50, 0x4b, 0x51, 0x77, 0x56, 0x70, 0x56, 0x46, 0x33, 0x54, + 0x46, 0x62, 0x59, 0x53, 0x4d, 0x49, 0x2b, 0x55, 0x36, 0x62, 0x34, 0x68, + 0x4d, 0x56, 0x73, 0x73, 0x66, 0x76, 0x33, 0x48, 0x56, 0x51, 0x63, 0x2f, + 0x30, 0x38, 0x33, 0x64, 0x48, 0x71, 0x2b, 0x33, 0x58, 0x4f, 0x77, 0x55, + 0x43, 0x56, 0x6c, 0x55, 0x73, 0x74, 0x52, 0x0a, 0x53, 0x41, 0x7a, 0x54, + 0x45, 0x32, 0x45, 0x35, 0x45, 0x44, 0x4d, 0x72, 0x31, 0x73, 0x74, 0x64, + 0x68, 0x30, 0x53, 0x51, 0x68, 0x56, 0x34, 0x4e, 0x69, 0x6c, 0x66, 0x6f, + 0x73, 0x39, 0x73, 0x35, 0x55, 0x6b, 0x31, 0x5a, 0x36, 0x49, 0x47, 0x53, + 0x7a, 0x74, 0x6f, 0x7a, 0x31, 0x47, 0x67, 0x4f, 0x45, 0x72, 0x49, 0x63, + 0x2f, 0x6d, 0x47, 0x50, 0x79, 0x2f, 0x61, 0x41, 0x2f, 0x68, 0x62, 0x72, + 0x0a, 0x66, 0x52, 0x57, 0x48, 0x76, 0x54, 0x70, 0x33, 0x35, 0x2b, 0x4d, + 0x62, 0x43, 0x4a, 0x53, 0x76, 0x5a, 0x75, 0x4f, 0x65, 0x65, 0x76, 0x58, + 0x32, 0x69, 0x4c, 0x73, 0x30, 0x64, 0x4e, 0x7a, 0x71, 0x64, 0x6b, 0x36, + 0x44, 0x69, 0x4f, 0x57, 0x49, 0x48, 0x2f, 0x42, 0x56, 0x47, 0x69, 0x72, + 0x56, 0x50, 0x74, 0x4f, 0x36, 0x79, 0x6b, 0x72, 0x6b, 0x75, 0x54, 0x6a, + 0x31, 0x46, 0x57, 0x69, 0x4e, 0x0a, 0x68, 0x79, 0x5a, 0x33, 0x4d, 0x42, + 0x43, 0x68, 0x53, 0x68, 0x6c, 0x4e, 0x48, 0x32, 0x70, 0x6f, 0x4e, 0x58, + 0x34, 0x36, 0x6e, 0x74, 0x4f, 0x63, 0x37, 0x6e, 0x45, 0x75, 0x73, 0x30, + 0x71, 0x74, 0x65, 0x4f, 0x67, 0x78, 0x42, 0x4b, 0x38, 0x6c, 0x75, 0x6d, + 0x6d, 0x46, 0x45, 0x74, 0x6c, 0x65, 0x68, 0x43, 0x41, 0x37, 0x68, 0x64, + 0x2f, 0x38, 0x78, 0x75, 0x76, 0x59, 0x6c, 0x50, 0x30, 0x6b, 0x0a, 0x37, + 0x61, 0x4e, 0x36, 0x38, 0x34, 0x4c, 0x43, 0x52, 0x44, 0x61, 0x6a, 0x6d, + 0x41, 0x47, 0x70, 0x6f, 0x5a, 0x4f, 0x35, 0x37, 0x4e, 0x53, 0x44, 0x59, + 0x51, 0x68, 0x41, 0x46, 0x47, 0x5a, 0x65, 0x55, 0x5a, 0x39, 0x33, 0x53, + 0x4d, 0x46, 0x75, 0x63, 0x51, 0x4b, 0x42, 0x67, 0x51, 0x44, 0x65, 0x37, + 0x47, 0x47, 0x6b, 0x7a, 0x5a, 0x46, 0x45, 0x69, 0x76, 0x39, 0x31, 0x75, + 0x31, 0x71, 0x39, 0x0a, 0x6c, 0x67, 0x4d, 0x79, 0x31, 0x68, 0x35, 0x64, + 0x5a, 0x6a, 0x49, 0x5a, 0x4b, 0x67, 0x51, 0x61, 0x4f, 0x61, 0x72, 0x50, + 0x43, 0x36, 0x77, 0x43, 0x51, 0x4d, 0x55, 0x64, 0x71, 0x43, 0x66, 0x36, + 0x63, 0x53, 0x4c, 0x73, 0x41, 0x50, 0x72, 0x34, 0x54, 0x38, 0x45, 0x44, + 
0x6f, 0x57, 0x73, 0x6e, 0x59, 0x37, 0x64, 0x53, 0x6e, 0x72, 0x54, 0x5a, + 0x36, 0x59, 0x43, 0x49, 0x46, 0x4c, 0x31, 0x54, 0x0a, 0x69, 0x64, 0x67, + 0x38, 0x4d, 0x33, 0x42, 0x51, 0x58, 0x69, 0x70, 0x49, 0x43, 0x43, 0x4a, + 0x6b, 0x46, 0x4f, 0x52, 0x53, 0x37, 0x36, 0x70, 0x4b, 0x4b, 0x5a, 0x30, + 0x77, 0x4d, 0x6e, 0x33, 0x2f, 0x4e, 0x67, 0x6b, 0x53, 0x65, 0x70, 0x73, + 0x6d, 0x4e, 0x63, 0x74, 0x39, 0x31, 0x57, 0x48, 0x72, 0x36, 0x6f, 0x6b, + 0x76, 0x78, 0x34, 0x74, 0x4f, 0x61, 0x6f, 0x52, 0x43, 0x74, 0x64, 0x7a, + 0x55, 0x0a, 0x67, 0x37, 0x6a, 0x74, 0x34, 0x4d, 0x72, 0x33, 0x73, 0x66, + 0x4c, 0x43, 0x69, 0x5a, 0x74, 0x71, 0x54, 0x51, 0x79, 0x79, 0x53, 0x64, + 0x4d, 0x55, 0x45, 0x77, 0x4b, 0x42, 0x67, 0x51, 0x44, 0x4e, 0x4b, 0x2b, + 0x5a, 0x46, 0x4b, 0x4c, 0x30, 0x58, 0x68, 0x6b, 0x57, 0x5a, 0x50, 0x2b, + 0x50, 0x47, 0x4b, 0x6a, 0x57, 0x47, 0x38, 0x4c, 0x57, 0x70, 0x50, 0x69, + 0x4b, 0x33, 0x64, 0x37, 0x38, 0x2f, 0x0a, 0x77, 0x59, 0x42, 0x46, 0x58, + 0x7a, 0x53, 0x54, 0x47, 0x6c, 0x6b, 0x72, 0x36, 0x46, 0x76, 0x52, 0x6d, + 0x59, 0x74, 0x5a, 0x65, 0x4e, 0x77, 0x58, 0x57, 0x52, 0x59, 0x4c, 0x42, + 0x34, 0x55, 0x78, 0x5a, 0x39, 0x41, 0x74, 0x34, 0x68, 0x62, 0x4a, 0x56, + 0x45, 0x64, 0x69, 0x2f, 0x32, 0x64, 0x49, 0x54, 0x4f, 0x7a, 0x2f, 0x73, + 0x65, 0x68, 0x56, 0x44, 0x79, 0x43, 0x41, 0x6a, 0x6a, 0x73, 0x33, 0x0a, + 0x67, 0x79, 0x63, 0x73, 0x63, 0x33, 0x55, 0x4a, 0x71, 0x69, 0x5a, 0x62, + 0x63, 0x77, 0x35, 0x58, 0x4b, 0x68, 0x49, 0x35, 0x54, 0x57, 0x42, 0x75, + 0x57, 0x78, 0x6b, 0x4b, 0x45, 0x4e, 0x64, 0x62, 0x4d, 0x53, 0x61, 0x79, + 0x6f, 0x67, 0x56, 0x62, 0x70, 0x32, 0x61, 0x53, 0x59, 0x6f, 0x52, 0x62, + 0x6c, 0x48, 0x37, 0x36, 0x34, 0x2f, 0x2f, 0x74, 0x30, 0x41, 0x43, 0x6d, + 0x62, 0x66, 0x54, 0x57, 0x0a, 0x4b, 0x55, 0x51, 0x52, 0x51, 0x50, 0x42, + 0x2f, 0x75, 0x77, 0x4b, 0x42, 0x67, 0x51, 0x43, 0x35, 0x51, 0x6a, 0x6a, + 0x6a, 0x66, 0x50, 0x4c, 0x38, 0x77, 0x34, 0x63, 0x4a, 0x6b, 0x47, 0x6f, + 0x59, 0x70, 0x46, 0x4b, 0x45, 0x4c, 0x4f, 0x32, 0x50, 0x4d, 0x52, 0x37, + 0x78, 0x53, 0x72, 0x6d, 0x65, 0x45, 0x63, 0x36, 0x68, 0x77, 0x6c, 0x46, + 0x77, 0x6a, 0x65, 0x4e, 0x43, 0x67, 0x6a, 0x79, 0x33, 0x0a, 0x4a, 0x4d, + 0x36, 0x67, 0x30, 0x79, 0x2b, 0x2b, 0x72, 0x49, 0x6a, 0x37, 0x4f, 0x32, + 0x71, 0x52, 0x6b, 0x59, 0x30, 0x49, 0x58, 0x46, 0x78, 0x76, 0x76, 0x46, + 0x33, 0x55, 0x75, 0x57, 0x65, 0x64, 0x78, 0x54, 0x43, 0x75, 0x31, 0x78, + 0x43, 0x2f, 0x75, 0x59, 0x48, 0x70, 0x32, 0x74, 0x69, 0x35, 0x30, 0x36, + 0x4c, 0x73, 0x53, 0x63, 0x42, 0x37, 0x59, 0x5a, 0x6f, 0x41, 0x4d, 0x2f, + 0x59, 0x42, 0x0a, 0x34, 0x69, 0x59, 0x6e, 0x39, 0x54, 0x78, 0x36, 0x78, + 0x4c, 0x6f, 0x59, 0x47, 0x50, 0x30, 0x48, 0x30, 0x69, 0x47, 0x77, 0x55, + 0x32, 0x53, 0x79, 0x42, 0x6c, 0x4e, 0x6b, 0x48, 0x54, 0x38, 0x6f, 0x58, + 0x55, 0x2b, 0x53, 0x59, 0x50, 0x35, 0x4d, 0x57, 0x74, 0x59, 0x6b, 0x56, + 0x62, 0x65, 0x53, 0x33, 0x2f, 0x56, 0x74, 0x4e, 0x57, 0x7a, 0x31, 0x67, + 0x51, 0x4b, 0x42, 0x67, 0x51, 0x43, 0x41, 0x0a, 0x36, 0x4e, 0x6b, 0x34, + 0x6b, 0x4e, 0x30, 0x6d, 0x48, 0x37, 0x59, 0x78, 0x45, 0x4b, 0x52, 0x7a, + 0x53, 0x4f, 0x66, 0x79, 0x7a, 0x65, 0x44, 0x46, 0x34, 0x6f, 0x56, 0x37, + 0x6b, 0x75, 0x42, 0x32, 0x46, 0x59, 0x55, 0x62, 0x6b, 0x54, 0x4c, 0x2b, + 0x54, 0x69, 0x72, 0x43, 0x33, 0x4b, 0x35, 0x38, 0x4a, 0x69, 0x59, 0x59, + 0x35, 0x45, 0x67, 0x63, 0x33, 0x31, 0x74, 0x72, 0x4f, 0x4b, 0x46, 0x6d, + 0x0a, 0x4a, 0x6c, 0x7a, 0x31, 0x78, 0x7a, 0x30, 0x62, 0x36, 0x44, 0x6b, + 0x6d, 0x4b, 0x57, 0x54, 0x69, 0x56, 0x33, 0x72, 0x39, 0x4f, 0x50, 0x48, + 0x4b, 0x4a, 0x38, 0x50, 0x37, 0x49, 0x65, 0x4a, 0x78, 0x41, 0x5a, 0x57, + 
0x6d, 0x5a, 0x7a, 0x43, 0x64, 0x44, 0x75, 0x77, 0x6b, 0x76, 0x30, 0x69, + 0x2b, 0x57, 0x57, 0x2b, 0x7a, 0x30, 0x7a, 0x73, 0x49, 0x65, 0x33, 0x4a, + 0x6a, 0x45, 0x61, 0x76, 0x4e, 0x0a, 0x33, 0x7a, 0x62, 0x36, 0x4f, 0x37, + 0x52, 0x30, 0x48, 0x74, 0x7a, 0x69, 0x6b, 0x73, 0x57, 0x6f, 0x71, 0x4d, + 0x65, 0x54, 0x71, 0x5a, 0x65, 0x4f, 0x2b, 0x77, 0x61, 0x39, 0x69, 0x77, + 0x36, 0x76, 0x56, 0x4b, 0x51, 0x77, 0x31, 0x77, 0x57, 0x45, 0x71, 0x77, + 0x4b, 0x42, 0x67, 0x46, 0x48, 0x66, 0x61, 0x68, 0x46, 0x73, 0x30, 0x44, + 0x5a, 0x35, 0x63, 0x55, 0x54, 0x70, 0x47, 0x70, 0x42, 0x74, 0x0a, 0x46, + 0x2f, 0x41, 0x51, 0x47, 0x37, 0x75, 0x6b, 0x67, 0x69, 0x70, 0x42, 0x36, + 0x4e, 0x36, 0x41, 0x6b, 0x42, 0x39, 0x6b, 0x44, 0x62, 0x67, 0x43, 0x73, + 0x31, 0x46, 0x4c, 0x67, 0x64, 0x31, 0x39, 0x39, 0x4d, 0x51, 0x72, 0x45, + 0x6e, 0x63, 0x75, 0x67, 0x35, 0x68, 0x66, 0x70, 0x71, 0x38, 0x51, 0x65, + 0x72, 0x62, 0x79, 0x4d, 0x61, 0x74, 0x6d, 0x41, 0x2b, 0x47, 0x58, 0x6f, + 0x47, 0x4d, 0x62, 0x0a, 0x37, 0x76, 0x7a, 0x74, 0x4b, 0x45, 0x48, 0x38, + 0x35, 0x79, 0x7a, 0x70, 0x34, 0x6e, 0x30, 0x32, 0x46, 0x4e, 0x4c, 0x36, + 0x48, 0x37, 0x78, 0x4c, 0x34, 0x56, 0x56, 0x49, 0x4c, 0x76, 0x79, 0x5a, + 0x48, 0x64, 0x6f, 0x6c, 0x6d, 0x69, 0x4f, 0x52, 0x4a, 0x34, 0x71, 0x54, + 0x32, 0x68, 0x5a, 0x6e, 0x6c, 0x38, 0x70, 0x45, 0x51, 0x32, 0x54, 0x59, + 0x75, 0x46, 0x34, 0x52, 0x6c, 0x48, 0x55, 0x64, 0x0a, 0x6e, 0x53, 0x77, + 0x58, 0x58, 0x2b, 0x32, 0x6f, 0x30, 0x4a, 0x2f, 0x6e, 0x46, 0x38, 0x35, + 0x66, 0x6d, 0x34, 0x41, 0x77, 0x57, 0x4b, 0x41, 0x63, 0x0a, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x50, 0x52, 0x49, 0x56, 0x41, + 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, + 0x00}; diff --git a/test/core/end2end/data/server1_cert.cc b/test/core/end2end/data/server1_cert.cc new file mode 100644 index 00000000..b780a84d --- /dev/null +++ b/test/core/end2end/data/server1_cert.cc @@ -0,0 +1,132 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +extern const char test_server1_cert[] = { + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x43, + 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, 0x49, 0x49, 0x44, 0x74, 0x44, 0x43, 0x43, + 0x41, 0x70, 0x79, 0x67, 0x41, 0x77, 0x49, 0x42, 0x41, 0x67, 0x49, 0x55, + 0x62, 0x4a, 0x66, 0x54, 0x52, 0x45, 0x4a, 0x36, 0x6b, 0x36, 0x2f, 0x2b, + 0x6f, 0x49, 0x6e, 0x57, 0x68, 0x56, 0x31, 0x4f, 0x31, 0x6a, 0x33, 0x5a, + 0x54, 0x30, 0x49, 0x77, 0x44, 0x51, 0x59, 0x4a, 0x4b, 0x6f, 0x5a, 0x49, + 0x68, 0x76, 0x63, 0x4e, 0x41, 0x51, 0x45, 0x4c, 0x0a, 0x42, 0x51, 0x41, + 0x77, 0x56, 0x6a, 0x45, 0x4c, 0x4d, 0x41, 0x6b, 0x47, 0x41, 0x31, 0x55, + 0x45, 0x42, 0x68, 0x4d, 0x43, 0x51, 0x56, 0x55, 0x78, 0x45, 0x7a, 0x41, + 0x52, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x67, 0x4d, 0x43, 0x6c, 0x4e, + 0x76, 0x62, 0x57, 0x55, 0x74, 0x55, 0x33, 0x52, 0x68, 0x64, 0x47, 0x55, + 0x78, 0x49, 0x54, 0x41, 0x66, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x6f, + 0x4d, 0x0a, 0x47, 0x45, 0x6c, 0x75, 0x64, 0x47, 0x56, 0x79, 0x62, 0x6d, + 0x56, 0x30, 0x49, 0x46, 0x64, 0x70, 0x5a, 0x47, 0x64, 0x70, 0x64, 0x48, + 0x4d, 0x67, 0x55, 0x48, 0x52, 0x35, 0x49, 0x45, 0x78, 0x30, 0x5a, 0x44, + 0x45, 0x50, 0x4d, 0x41, 0x30, 0x47, 0x41, 0x31, 0x55, 0x45, 0x41, 0x77, + 0x77, 0x47, 0x64, 0x47, 0x56, 0x7a, 0x64, 0x47, 0x4e, 0x68, 0x4d, 0x42, + 0x34, 0x58, 0x44, 0x54, 0x49, 0x77, 0x0a, 0x4d, 0x44, 0x4d, 0x78, 0x4f, + 0x44, 0x41, 0x7a, 0x4d, 0x54, 0x41, 0x30, 0x4d, 0x6c, 0x6f, 0x58, 0x44, + 0x54, 0x4d, 0x77, 0x4d, 0x44, 0x4d, 0x78, 0x4e, 0x6a, 0x41, 0x7a, 0x4d, + 0x54, 0x41, 0x30, 0x4d, 0x6c, 0x6f, 0x77, 0x5a, 0x54, 0x45, 0x4c, 0x4d, + 0x41, 0x6b, 0x47, 0x41, 0x31, 0x55, 0x45, 0x42, 0x68, 0x4d, 0x43, 0x56, + 0x56, 0x4d, 0x78, 0x45, 0x54, 0x41, 0x50, 0x42, 0x67, 0x4e, 0x56, 0x0a, + 0x42, 0x41, 0x67, 0x4d, 0x43, 0x45, 0x6c, 0x73, 0x62, 0x47, 0x6c, 0x75, + 0x62, 0x32, 0x6c, 0x7a, 0x4d, 0x52, 0x41, 0x77, 0x44, 0x67, 0x59, 0x44, + 0x56, 0x51, 0x51, 0x48, 0x44, 0x41, 0x64, 0x44, 0x61, 0x47, 0x6c, 0x6a, + 0x59, 0x57, 0x64, 0x76, 0x4d, 0x52, 0x55, 0x77, 0x45, 0x77, 0x59, 0x44, + 0x56, 0x51, 0x51, 0x4b, 0x44, 0x41, 0x78, 0x46, 0x65, 0x47, 0x46, 0x74, + 0x63, 0x47, 0x78, 0x6c, 0x0a, 0x4c, 0x43, 0x42, 0x44, 0x62, 0x79, 0x34, + 0x78, 0x47, 0x6a, 0x41, 0x59, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x4d, + 0x4d, 0x45, 0x53, 0x6f, 0x75, 0x64, 0x47, 0x56, 0x7a, 0x64, 0x43, 0x35, + 0x6e, 0x62, 0x32, 0x39, 0x6e, 0x62, 0x47, 0x55, 0x75, 0x59, 0x32, 0x39, + 0x74, 0x4d, 0x49, 0x49, 0x42, 0x49, 0x6a, 0x41, 0x4e, 0x42, 0x67, 0x6b, + 0x71, 0x68, 0x6b, 0x69, 0x47, 0x39, 0x77, 0x30, 0x42, 0x0a, 0x41, 0x51, + 0x45, 0x46, 0x41, 0x41, 0x4f, 0x43, 0x41, 0x51, 0x38, 0x41, 0x4d, 0x49, + 0x49, 0x42, 0x43, 0x67, 0x4b, 0x43, 0x41, 0x51, 0x45, 0x41, 0x35, 0x78, + 0x4f, 0x4f, 0x4e, 0x78, 0x4a, 0x4a, 0x38, 0x62, 0x38, 0x51, 0x61, 0x75, + 0x76, 0x6f, 0x62, 0x35, 0x2f, 0x37, 0x64, 0x50, 0x59, 0x5a, 0x66, 0x49, + 0x63, 0x64, 0x2b, 0x75, 0x68, 0x41, 0x57, 0x4c, 0x32, 0x5a, 0x6c, 0x54, + 0x50, 0x7a, 0x0a, 0x51, 0x76, 0x75, 0x34, 0x6f, 0x46, 0x30, 0x51, 0x49, + 0x34, 0x69, 0x59, 0x67, 0x50, 0x35, 0x69, 0x47, 0x67, 0x72, 0x79, 0x39, + 0x7a, 0x45, 0x74, 0x43, 0x4d, 0x2b, 0x59, 0x51, 0x53, 0x38, 0x55, 0x68, + 0x69, 0x41, 0x6c, 0x50, 0x6c, 0x71, 0x61, 0x36, 0x41, 0x4e, 0x78, 0x67, + 0x69, 0x42, 0x53, 0x45, 0x79, 0x4d, 0x48, 0x48, 0x2f, 0x78, 0x45, 0x38, + 0x6c, 0x6f, 0x2f, 0x2b, 0x63, 0x61, 0x59, 0x0a, 0x47, 0x65, 0x41, 0x43, + 0x71, 0x79, 0x36, 0x34, 0x30, 0x4a, 0x70, 0x6c, 0x2f, 0x4a, 0x6f, 0x63, + 0x46, 0x47, 0x6f, 
0x33, 0x78, 0x64, 0x31, 0x4c, 0x38, 0x44, 0x43, 0x61, + 0x77, 0x6a, 0x6c, 0x61, 0x6a, 0x36, 0x65, 0x75, 0x37, 0x54, 0x37, 0x54, + 0x2f, 0x74, 0x70, 0x41, 0x56, 0x32, 0x71, 0x71, 0x31, 0x33, 0x62, 0x35, + 0x37, 0x31, 0x30, 0x65, 0x4e, 0x52, 0x62, 0x43, 0x41, 0x66, 0x46, 0x65, + 0x0a, 0x38, 0x79, 0x41, 0x4c, 0x69, 0x47, 0x51, 0x65, 0x6d, 0x78, 0x30, + 0x49, 0x59, 0x68, 0x6c, 0x5a, 0x58, 0x4e, 0x62, 0x49, 0x47, 0x57, 0x4c, + 0x42, 0x4e, 0x68, 0x42, 0x68, 0x76, 0x56, 0x6a, 0x4a, 0x68, 0x37, 0x55, + 0x76, 0x4f, 0x71, 0x70, 0x41, 0x44, 0x6b, 0x34, 0x78, 0x74, 0x6c, 0x38, + 0x6f, 0x35, 0x6a, 0x30, 0x78, 0x67, 0x4d, 0x49, 0x52, 0x67, 0x36, 0x57, + 0x4a, 0x47, 0x4b, 0x36, 0x63, 0x0a, 0x36, 0x66, 0x66, 0x53, 0x49, 0x67, + 0x34, 0x65, 0x50, 0x31, 0x58, 0x6d, 0x6f, 0x76, 0x4e, 0x59, 0x5a, 0x39, + 0x4c, 0x4c, 0x45, 0x4a, 0x47, 0x36, 0x38, 0x74, 0x46, 0x30, 0x51, 0x2f, + 0x79, 0x49, 0x4e, 0x34, 0x33, 0x42, 0x34, 0x64, 0x74, 0x31, 0x6f, 0x71, + 0x34, 0x6a, 0x7a, 0x53, 0x64, 0x43, 0x62, 0x47, 0x34, 0x46, 0x31, 0x45, + 0x69, 0x79, 0x6b, 0x54, 0x32, 0x54, 0x6d, 0x77, 0x50, 0x56, 0x0a, 0x59, + 0x44, 0x69, 0x38, 0x74, 0x6d, 0x6c, 0x36, 0x44, 0x66, 0x4f, 0x43, 0x44, + 0x47, 0x6e, 0x69, 0x74, 0x38, 0x73, 0x76, 0x6e, 0x4d, 0x45, 0x6d, 0x42, + 0x76, 0x2f, 0x66, 0x63, 0x50, 0x64, 0x33, 0x31, 0x47, 0x53, 0x62, 0x58, + 0x6a, 0x46, 0x38, 0x4d, 0x2b, 0x4b, 0x47, 0x47, 0x51, 0x49, 0x44, 0x41, + 0x51, 0x41, 0x42, 0x6f, 0x32, 0x73, 0x77, 0x61, 0x54, 0x41, 0x4a, 0x42, + 0x67, 0x4e, 0x56, 0x0a, 0x48, 0x52, 0x4d, 0x45, 0x41, 0x6a, 0x41, 0x41, + 0x4d, 0x41, 0x73, 0x47, 0x41, 0x31, 0x55, 0x64, 0x44, 0x77, 0x51, 0x45, + 0x41, 0x77, 0x49, 0x46, 0x34, 0x44, 0x42, 0x50, 0x42, 0x67, 0x4e, 0x56, + 0x48, 0x52, 0x45, 0x45, 0x53, 0x44, 0x42, 0x47, 0x67, 0x68, 0x41, 0x71, + 0x4c, 0x6e, 0x52, 0x6c, 0x63, 0x33, 0x51, 0x75, 0x5a, 0x32, 0x39, 0x76, + 0x5a, 0x32, 0x78, 0x6c, 0x4c, 0x6d, 0x5a, 0x79, 0x0a, 0x67, 0x68, 0x68, + 0x33, 0x59, 0x58, 0x52, 0x6c, 0x63, 0x6e, 0x70, 0x76, 0x62, 0x32, 0x6b, + 0x75, 0x64, 0x47, 0x56, 0x7a, 0x64, 0x43, 0x35, 0x6e, 0x62, 0x32, 0x39, + 0x6e, 0x62, 0x47, 0x55, 0x75, 0x59, 0x6d, 0x57, 0x43, 0x45, 0x69, 0x6f, + 0x75, 0x64, 0x47, 0x56, 0x7a, 0x64, 0x43, 0x35, 0x35, 0x62, 0x33, 0x56, + 0x30, 0x64, 0x57, 0x4a, 0x6c, 0x4c, 0x6d, 0x4e, 0x76, 0x62, 0x59, 0x63, + 0x45, 0x0a, 0x77, 0x4b, 0x67, 0x42, 0x41, 0x7a, 0x41, 0x4e, 0x42, 0x67, + 0x6b, 0x71, 0x68, 0x6b, 0x69, 0x47, 0x39, 0x77, 0x30, 0x42, 0x41, 0x51, + 0x73, 0x46, 0x41, 0x41, 0x4f, 0x43, 0x41, 0x51, 0x45, 0x41, 0x53, 0x38, + 0x68, 0x44, 0x51, 0x41, 0x38, 0x50, 0x53, 0x67, 0x69, 0x70, 0x67, 0x41, + 0x6d, 0x6c, 0x37, 0x51, 0x33, 0x2f, 0x64, 0x6a, 0x77, 0x51, 0x36, 0x34, + 0x34, 0x67, 0x68, 0x57, 0x51, 0x76, 0x0a, 0x43, 0x32, 0x4b, 0x62, 0x2b, + 0x72, 0x33, 0x30, 0x52, 0x43, 0x59, 0x31, 0x45, 0x79, 0x4b, 0x4e, 0x68, + 0x6e, 0x51, 0x6e, 0x49, 0x49, 0x68, 0x2f, 0x4f, 0x55, 0x62, 0x42, 0x5a, + 0x76, 0x68, 0x30, 0x4d, 0x30, 0x69, 0x59, 0x73, 0x79, 0x36, 0x78, 0x71, + 0x58, 0x67, 0x66, 0x44, 0x68, 0x43, 0x42, 0x39, 0x33, 0x41, 0x41, 0x36, + 0x6a, 0x30, 0x69, 0x35, 0x63, 0x53, 0x38, 0x66, 0x6b, 0x68, 0x48, 0x0a, + 0x4a, 0x6c, 0x34, 0x52, 0x4b, 0x30, 0x74, 0x53, 0x6b, 0x47, 0x51, 0x33, + 0x59, 0x4e, 0x59, 0x34, 0x4e, 0x7a, 0x58, 0x77, 0x51, 0x50, 0x2f, 0x76, + 0x6d, 0x55, 0x67, 0x66, 0x6b, 0x77, 0x38, 0x56, 0x42, 0x41, 0x5a, 0x34, + 0x59, 0x34, 0x47, 0x4b, 0x78, 0x70, 0x70, 0x64, 0x41, 0x54, 0x6a, 0x66, + 0x66, 0x49, 0x57, 0x2b, 0x73, 0x72, 0x62, 0x41, 0x6d, 0x64, 0x44, 0x72, + 0x75, 0x49, 0x52, 0x4d, 0x0a, 0x77, 0x50, 0x65, 0x69, 0x6b, 0x67, 0x4f, + 0x6f, 0x52, 0x72, 
0x58, 0x66, 0x30, 0x4c, 0x41, 0x31, 0x66, 0x69, 0x34, + 0x54, 0x71, 0x78, 0x41, 0x52, 0x7a, 0x65, 0x52, 0x77, 0x65, 0x6e, 0x51, + 0x70, 0x61, 0x79, 0x4e, 0x66, 0x47, 0x48, 0x54, 0x76, 0x56, 0x46, 0x39, + 0x61, 0x4a, 0x6b, 0x6c, 0x38, 0x48, 0x6f, 0x61, 0x4d, 0x75, 0x6e, 0x54, + 0x41, 0x64, 0x47, 0x35, 0x70, 0x49, 0x56, 0x63, 0x72, 0x0a, 0x39, 0x47, + 0x4b, 0x69, 0x2f, 0x67, 0x45, 0x4d, 0x70, 0x58, 0x55, 0x4a, 0x62, 0x62, + 0x56, 0x76, 0x33, 0x55, 0x35, 0x66, 0x72, 0x58, 0x31, 0x57, 0x6f, 0x34, + 0x43, 0x46, 0x6f, 0x2b, 0x72, 0x5a, 0x57, 0x4a, 0x2f, 0x4c, 0x79, 0x43, + 0x4d, 0x65, 0x62, 0x30, 0x6a, 0x63, 0x69, 0x4e, 0x4c, 0x78, 0x53, 0x64, + 0x4d, 0x77, 0x6a, 0x2f, 0x45, 0x2f, 0x5a, 0x75, 0x45, 0x78, 0x6c, 0x79, + 0x65, 0x5a, 0x0a, 0x67, 0x63, 0x39, 0x63, 0x74, 0x50, 0x6a, 0x53, 0x4d, + 0x76, 0x67, 0x53, 0x79, 0x58, 0x45, 0x4b, 0x76, 0x36, 0x56, 0x77, 0x6f, + 0x62, 0x6c, 0x65, 0x65, 0x67, 0x38, 0x38, 0x56, 0x32, 0x5a, 0x67, 0x7a, + 0x65, 0x6e, 0x7a, 0x69, 0x4f, 0x52, 0x6f, 0x57, 0x6a, 0x34, 0x4b, 0x73, + 0x7a, 0x47, 0x2f, 0x6c, 0x62, 0x51, 0x5a, 0x76, 0x67, 0x3d, 0x3d, 0x0a, + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x43, 0x45, 0x52, + 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, 0x2d, 0x2d, + 0x2d, 0x0a, 0x00}; diff --git a/test/core/end2end/data/server1_key.cc b/test/core/end2end/data/server1_key.cc new file mode 100644 index 00000000..c2b09d20 --- /dev/null +++ b/test/core/end2end/data/server1_key.cc @@ -0,0 +1,162 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +extern const char test_server1_key[] = { + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x50, + 0x52, 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, 0x49, 0x49, 0x45, 0x76, 0x77, 0x49, 0x42, + 0x41, 0x44, 0x41, 0x4e, 0x42, 0x67, 0x6b, 0x71, 0x68, 0x6b, 0x69, 0x47, + 0x39, 0x77, 0x30, 0x42, 0x41, 0x51, 0x45, 0x46, 0x41, 0x41, 0x53, 0x43, + 0x42, 0x4b, 0x6b, 0x77, 0x67, 0x67, 0x53, 0x6c, 0x41, 0x67, 0x45, 0x41, + 0x41, 0x6f, 0x49, 0x42, 0x41, 0x51, 0x44, 0x6e, 0x45, 0x34, 0x34, 0x33, + 0x45, 0x6b, 0x6e, 0x78, 0x76, 0x78, 0x42, 0x71, 0x0a, 0x36, 0x2b, 0x68, + 0x76, 0x6e, 0x2f, 0x74, 0x30, 0x39, 0x68, 0x6c, 0x38, 0x68, 0x78, 0x33, + 0x36, 0x36, 0x45, 0x42, 0x59, 0x76, 0x5a, 0x6d, 0x56, 0x4d, 0x2f, 0x4e, + 0x43, 0x2b, 0x37, 0x69, 0x67, 0x58, 0x52, 0x41, 0x6a, 0x69, 0x4a, 0x69, + 0x41, 0x2f, 0x6d, 0x49, 0x61, 0x43, 0x76, 0x4c, 0x33, 0x4d, 0x53, 0x30, + 0x49, 0x7a, 0x35, 0x68, 0x42, 0x4c, 0x78, 0x53, 0x47, 0x49, 0x43, 0x55, + 0x2b, 0x0a, 0x57, 0x70, 0x72, 0x6f, 0x41, 0x33, 0x47, 0x43, 0x49, 0x46, + 0x49, 0x54, 0x49, 0x77, 0x63, 0x66, 0x2f, 0x45, 0x54, 0x79, 0x57, 0x6a, + 0x2f, 0x35, 0x78, 0x70, 0x67, 0x5a, 0x34, 0x41, 0x4b, 0x72, 0x4c, 0x72, + 0x6a, 0x51, 0x6d, 0x6d, 0x58, 0x38, 0x6d, 0x68, 0x77, 0x55, 0x61, 0x6a, + 0x66, 0x46, 0x33, 0x55, 0x76, 0x77, 0x4d, 0x4a, 0x72, 0x43, 0x4f, 0x56, + 0x71, 0x50, 0x70, 0x36, 0x37, 0x74, 0x0a, 0x50, 0x74, 0x50, 0x2b, 0x32, + 0x6b, 0x42, 0x58, 0x61, 0x71, 0x72, 0x58, 0x64, 0x76, 0x6e, 0x76, 0x58, + 0x52, 0x34, 0x31, 0x46, 0x73, 0x49, 0x42, 0x38, 0x56, 0x37, 0x7a, 0x49, + 0x41, 0x75, 0x49, 0x5a, 0x42, 0x36, 0x62, 0x48, 0x51, 0x68, 0x69, 0x47, + 0x56, 0x6c, 0x63, 0x31, 0x73, 0x67, 0x5a, 0x59, 0x73, 0x45, 0x32, 0x45, + 0x47, 0x47, 0x39, 0x57, 0x4d, 0x6d, 0x48, 0x74, 0x53, 0x38, 0x36, 0x0a, + 0x71, 0x6b, 0x41, 0x4f, 0x54, 0x6a, 0x47, 0x32, 0x58, 0x79, 0x6a, 0x6d, + 0x50, 0x54, 0x47, 0x41, 0x77, 0x68, 0x47, 0x44, 0x70, 0x59, 0x6b, 0x59, + 0x72, 0x70, 0x7a, 0x70, 0x39, 0x39, 0x49, 0x69, 0x44, 0x68, 0x34, 0x2f, + 0x56, 0x65, 0x61, 0x69, 0x38, 0x31, 0x68, 0x6e, 0x30, 0x73, 0x73, 0x51, + 0x6b, 0x62, 0x72, 0x79, 0x30, 0x58, 0x52, 0x44, 0x2f, 0x49, 0x67, 0x33, + 0x6a, 0x63, 0x48, 0x68, 0x0a, 0x32, 0x33, 0x57, 0x69, 0x72, 0x69, 0x50, + 0x4e, 0x4a, 0x30, 0x4a, 0x73, 0x62, 0x67, 0x58, 0x55, 0x53, 0x4c, 0x4b, + 0x52, 0x50, 0x5a, 0x4f, 0x62, 0x41, 0x39, 0x56, 0x67, 0x4f, 0x4c, 0x79, + 0x32, 0x61, 0x58, 0x6f, 0x4e, 0x38, 0x34, 0x49, 0x4d, 0x61, 0x65, 0x4b, + 0x33, 0x79, 0x79, 0x2b, 0x63, 0x77, 0x53, 0x59, 0x47, 0x2f, 0x39, 0x39, + 0x77, 0x39, 0x33, 0x66, 0x55, 0x5a, 0x4a, 0x74, 0x65, 0x0a, 0x4d, 0x58, + 0x77, 0x7a, 0x34, 0x6f, 0x59, 0x5a, 0x41, 0x67, 0x4d, 0x42, 0x41, 0x41, + 0x45, 0x43, 0x67, 0x67, 0x45, 0x42, 0x41, 0x49, 0x56, 0x6e, 0x32, 0x4e, + 0x63, 0x61, 0x69, 0x2b, 0x34, 0x78, 0x62, 0x48, 0x30, 0x4f, 0x4c, 0x57, + 0x63, 0x6b, 0x61, 0x62, 0x77, 0x67, 0x79, 0x4a, 0x34, 0x49, 0x4d, 0x39, + 0x72, 0x44, 0x63, 0x30, 0x4c, 0x49, 0x55, 0x33, 0x36, 0x38, 0x4f, 0x31, + 0x6b, 0x55, 0x0a, 0x6b, 0x6f, 0x61, 0x69, 0x73, 0x38, 0x71, 0x50, 0x39, + 0x64, 0x75, 0x6a, 0x41, 0x57, 0x67, 0x66, 0x6f, 0x68, 0x33, 0x73, 0x47, + 0x68, 0x2f, 0x59, 0x47, 0x67, 0x4b, 0x6e, 0x39, 0x36, 0x56, 0x6e, 0x73, + 0x5a, 0x6a, 0x4b, 0x48, 0x6c, 0x79, 0x4d, 0x67, 0x46, 0x2b, 0x72, 0x34, + 0x54, 0x61, 0x44, 0x4a, 0x6e, 0x33, 0x6b, 0x32, 0x72, 0x6c, 0x41, 0x4f, + 0x57, 0x63, 0x75, 0x72, 0x47, 0x6c, 0x6a, 0x0a, 0x31, 0x71, 0x61, 0x56, + 0x6c, 0x73, 0x56, 0x34, 0x48, 0x69, 0x45, 0x7a, 0x70, 0x37, 0x70, 0x78, + 0x69, 0x44, 0x6d, 
0x48, 0x68, 0x57, 0x76, 0x70, 0x34, 0x36, 0x37, 0x32, + 0x42, 0x62, 0x36, 0x69, 0x42, 0x47, 0x2b, 0x62, 0x73, 0x6a, 0x43, 0x55, + 0x4f, 0x45, 0x6b, 0x2f, 0x6e, 0x39, 0x6f, 0x39, 0x4b, 0x68, 0x5a, 0x7a, + 0x49, 0x42, 0x6c, 0x75, 0x52, 0x68, 0x74, 0x78, 0x43, 0x6d, 0x77, 0x35, + 0x0a, 0x6e, 0x77, 0x34, 0x44, 0x6f, 0x37, 0x7a, 0x30, 0x30, 0x50, 0x54, + 0x76, 0x4e, 0x38, 0x31, 0x32, 0x36, 0x30, 0x75, 0x50, 0x57, 0x53, 0x63, + 0x30, 0x34, 0x49, 0x72, 0x79, 0x74, 0x76, 0x5a, 0x55, 0x69, 0x41, 0x49, + 0x78, 0x2f, 0x35, 0x71, 0x78, 0x44, 0x37, 0x32, 0x62, 0x69, 0x6a, 0x32, + 0x78, 0x4a, 0x38, 0x74, 0x2f, 0x49, 0x39, 0x47, 0x49, 0x38, 0x67, 0x34, + 0x46, 0x74, 0x6f, 0x56, 0x42, 0x0a, 0x38, 0x70, 0x42, 0x36, 0x53, 0x2f, + 0x68, 0x4a, 0x58, 0x31, 0x50, 0x5a, 0x68, 0x68, 0x39, 0x56, 0x6c, 0x55, + 0x36, 0x59, 0x6b, 0x2b, 0x54, 0x4f, 0x66, 0x4f, 0x56, 0x6e, 0x62, 0x65, + 0x62, 0x47, 0x34, 0x57, 0x35, 0x31, 0x33, 0x38, 0x4c, 0x6b, 0x42, 0x38, + 0x33, 0x35, 0x65, 0x71, 0x6b, 0x33, 0x5a, 0x7a, 0x30, 0x71, 0x73, 0x62, + 0x63, 0x32, 0x65, 0x75, 0x6f, 0x69, 0x38, 0x48, 0x78, 0x69, 0x0a, 0x79, + 0x31, 0x56, 0x47, 0x77, 0x51, 0x45, 0x6d, 0x4d, 0x51, 0x36, 0x33, 0x6a, + 0x58, 0x7a, 0x34, 0x63, 0x36, 0x67, 0x2b, 0x58, 0x35, 0x35, 0x69, 0x66, + 0x76, 0x55, 0x4b, 0x39, 0x4a, 0x70, 0x6e, 0x35, 0x45, 0x38, 0x70, 0x71, + 0x2b, 0x70, 0x4d, 0x64, 0x37, 0x45, 0x43, 0x67, 0x59, 0x45, 0x41, 0x39, + 0x33, 0x6c, 0x59, 0x71, 0x2b, 0x43, 0x72, 0x35, 0x34, 0x4b, 0x34, 0x65, + 0x79, 0x35, 0x74, 0x0a, 0x73, 0x57, 0x4d, 0x61, 0x2b, 0x79, 0x65, 0x35, + 0x52, 0x71, 0x78, 0x6a, 0x7a, 0x67, 0x58, 0x6a, 0x32, 0x4b, 0x71, 0x72, + 0x35, 0x35, 0x6a, 0x62, 0x35, 0x34, 0x56, 0x57, 0x47, 0x37, 0x77, 0x70, + 0x32, 0x69, 0x47, 0x62, 0x67, 0x38, 0x46, 0x4d, 0x6c, 0x6b, 0x51, 0x77, + 0x7a, 0x54, 0x4a, 0x77, 0x65, 0x62, 0x7a, 0x44, 0x79, 0x43, 0x53, 0x61, + 0x74, 0x67, 0x75, 0x45, 0x5a, 0x4c, 0x75, 0x42, 0x0a, 0x67, 0x52, 0x47, + 0x72, 0x6f, 0x52, 0x6e, 0x73, 0x55, 0x4f, 0x79, 0x39, 0x76, 0x42, 0x76, + 0x68, 0x4b, 0x50, 0x4f, 0x63, 0x68, 0x39, 0x62, 0x66, 0x4b, 0x49, 0x6c, + 0x36, 0x71, 0x4f, 0x67, 0x7a, 0x4d, 0x4a, 0x42, 0x32, 0x36, 0x37, 0x66, + 0x42, 0x56, 0x57, 0x78, 0x35, 0x79, 0x62, 0x6e, 0x52, 0x62, 0x57, 0x4e, + 0x2f, 0x49, 0x37, 0x52, 0x76, 0x4d, 0x51, 0x66, 0x33, 0x6b, 0x2b, 0x39, + 0x79, 0x0a, 0x62, 0x69, 0x43, 0x49, 0x56, 0x6e, 0x78, 0x44, 0x4c, 0x45, + 0x45, 0x59, 0x79, 0x78, 0x37, 0x7a, 0x38, 0x35, 0x2f, 0x35, 0x71, 0x78, + 0x73, 0x58, 0x67, 0x2f, 0x4d, 0x43, 0x67, 0x59, 0x45, 0x41, 0x37, 0x77, + 0x6d, 0x57, 0x4b, 0x74, 0x43, 0x54, 0x6e, 0x30, 0x33, 0x32, 0x48, 0x79, + 0x39, 0x50, 0x38, 0x4f, 0x4c, 0x34, 0x39, 0x54, 0x30, 0x58, 0x36, 0x5a, + 0x38, 0x46, 0x6c, 0x6b, 0x44, 0x43, 0x0a, 0x52, 0x6b, 0x34, 0x32, 0x79, + 0x67, 0x72, 0x63, 0x2f, 0x4d, 0x55, 0x62, 0x75, 0x67, 0x71, 0x39, 0x52, + 0x47, 0x55, 0x78, 0x63, 0x43, 0x78, 0x6f, 0x49, 0x6d, 0x4f, 0x47, 0x39, + 0x4a, 0x58, 0x55, 0x70, 0x45, 0x74, 0x55, 0x65, 0x33, 0x31, 0x59, 0x44, + 0x6d, 0x32, 0x6a, 0x2b, 0x2f, 0x6e, 0x62, 0x76, 0x72, 0x6a, 0x6c, 0x36, + 0x2f, 0x62, 0x50, 0x32, 0x71, 0x57, 0x73, 0x30, 0x56, 0x37, 0x6c, 0x0a, + 0x64, 0x54, 0x4a, 0x6c, 0x36, 0x64, 0x41, 0x42, 0x50, 0x35, 0x31, 0x70, + 0x43, 0x77, 0x38, 0x2b, 0x6c, 0x34, 0x63, 0x57, 0x67, 0x42, 0x42, 0x58, + 0x30, 0x38, 0x4c, 0x6b, 0x65, 0x65, 0x6e, 0x38, 0x31, 0x32, 0x41, 0x41, + 0x46, 0x4e, 0x72, 0x6a, 0x6d, 0x44, 0x43, 0x6a, 0x58, 0x36, 0x72, 0x48, + 0x6a, 0x57, 0x48, 0x4c, 0x4a, 0x63, 0x70, 0x53, 0x31, 0x38, 0x66, 0x6e, + 0x52, 0x52, 0x6b, 0x50, 0x0a, 0x56, 0x31, 0x64, 0x2f, 0x41, 0x48, 0x57, + 0x58, 0x37, 0x4d, 
0x4d, 0x43, 0x67, 0x59, 0x45, 0x41, 0x36, 0x47, 0x73, + 0x77, 0x32, 0x67, 0x75, 0x68, 0x70, 0x30, 0x5a, 0x66, 0x32, 0x47, 0x43, + 0x63, 0x61, 0x4e, 0x4b, 0x35, 0x44, 0x6c, 0x51, 0x61, 0x62, 0x38, 0x4f, + 0x4c, 0x34, 0x48, 0x77, 0x72, 0x70, 0x74, 0x74, 0x7a, 0x6f, 0x34, 0x6b, + 0x75, 0x54, 0x6c, 0x77, 0x74, 0x71, 0x4e, 0x4b, 0x70, 0x0a, 0x51, 0x39, + 0x48, 0x34, 0x61, 0x6c, 0x39, 0x71, 0x66, 0x46, 0x34, 0x43, 0x72, 0x31, + 0x54, 0x46, 0x79, 0x61, 0x39, 0x38, 0x2b, 0x45, 0x56, 0x59, 0x66, 0x38, + 0x79, 0x46, 0x52, 0x4d, 0x33, 0x4e, 0x4c, 0x4e, 0x6a, 0x5a, 0x70, 0x65, + 0x33, 0x67, 0x77, 0x59, 0x66, 0x32, 0x45, 0x65, 0x72, 0x6c, 0x4a, 0x6a, + 0x37, 0x56, 0x4c, 0x63, 0x61, 0x68, 0x77, 0x30, 0x4b, 0x4b, 0x7a, 0x6f, + 0x4e, 0x31, 0x0a, 0x51, 0x42, 0x45, 0x4e, 0x66, 0x77, 0x67, 0x50, 0x4c, + 0x52, 0x6b, 0x35, 0x73, 0x44, 0x6b, 0x78, 0x39, 0x56, 0x68, 0x53, 0x6d, + 0x63, 0x66, 0x6c, 0x2f, 0x64, 0x69, 0x4c, 0x72, 0x6f, 0x5a, 0x64, 0x70, + 0x41, 0x77, 0x74, 0x76, 0x33, 0x76, 0x6f, 0x34, 0x6e, 0x45, 0x6f, 0x78, + 0x65, 0x75, 0x47, 0x46, 0x62, 0x4b, 0x54, 0x47, 0x78, 0x33, 0x51, 0x6b, + 0x66, 0x30, 0x43, 0x67, 0x59, 0x45, 0x41, 0x0a, 0x78, 0x79, 0x52, 0x2b, + 0x64, 0x63, 0x62, 0x30, 0x35, 0x59, 0x67, 0x6d, 0x33, 0x77, 0x34, 0x6b, + 0x6c, 0x48, 0x51, 0x54, 0x6f, 0x77, 0x51, 0x31, 0x30, 0x73, 0x31, 0x48, + 0x38, 0x30, 0x69, 0x61, 0x55, 0x63, 0x5a, 0x42, 0x67, 0x51, 0x75, 0x52, + 0x31, 0x67, 0x68, 0x45, 0x74, 0x44, 0x62, 0x55, 0x50, 0x5a, 0x48, 0x73, + 0x6f, 0x52, 0x35, 0x74, 0x31, 0x78, 0x43, 0x42, 0x30, 0x32, 0x79, 0x73, + 0x0a, 0x44, 0x67, 0x41, 0x77, 0x4c, 0x76, 0x31, 0x62, 0x43, 0x68, 0x49, + 0x76, 0x78, 0x76, 0x48, 0x2f, 0x4c, 0x36, 0x4b, 0x4d, 0x38, 0x6f, 0x76, + 0x5a, 0x32, 0x4c, 0x65, 0x6b, 0x42, 0x58, 0x34, 0x41, 0x76, 0x69, 0x57, + 0x78, 0x6f, 0x42, 0x78, 0x4a, 0x6e, 0x66, 0x7a, 0x2f, 0x45, 0x56, 0x61, + 0x75, 0x39, 0x38, 0x42, 0x30, 0x62, 0x31, 0x61, 0x75, 0x52, 0x4e, 0x36, + 0x65, 0x53, 0x43, 0x38, 0x33, 0x0a, 0x46, 0x52, 0x75, 0x47, 0x6c, 0x64, + 0x6c, 0x53, 0x4f, 0x57, 0x31, 0x7a, 0x2f, 0x6e, 0x53, 0x68, 0x38, 0x56, + 0x69, 0x69, 0x7a, 0x53, 0x59, 0x45, 0x35, 0x48, 0x35, 0x48, 0x58, 0x31, + 0x71, 0x6b, 0x58, 0x45, 0x69, 0x70, 0x70, 0x76, 0x46, 0x52, 0x45, 0x38, + 0x38, 0x43, 0x67, 0x59, 0x42, 0x33, 0x42, 0x66, 0x75, 0x33, 0x59, 0x51, + 0x59, 0x36, 0x30, 0x49, 0x54, 0x57, 0x49, 0x53, 0x68, 0x76, 0x0a, 0x6e, + 0x4e, 0x6b, 0x64, 0x63, 0x62, 0x54, 0x54, 0x39, 0x65, 0x6f, 0x50, 0x39, + 0x73, 0x75, 0x61, 0x52, 0x4a, 0x6a, 0x77, 0x39, 0x32, 0x4c, 0x6e, 0x2b, + 0x37, 0x5a, 0x70, 0x41, 0x4c, 0x59, 0x6c, 0x51, 0x4d, 0x4b, 0x55, 0x5a, + 0x6d, 0x4a, 0x2f, 0x35, 0x75, 0x42, 0x6d, 0x4c, 0x73, 0x34, 0x52, 0x46, + 0x77, 0x55, 0x54, 0x51, 0x72, 0x75, 0x4c, 0x4f, 0x50, 0x4c, 0x34, 0x79, + 0x4c, 0x54, 0x48, 0x0a, 0x61, 0x77, 0x41, 0x44, 0x57, 0x55, 0x7a, 0x73, + 0x33, 0x49, 0x52, 0x72, 0x31, 0x66, 0x77, 0x6e, 0x39, 0x45, 0x2b, 0x7a, + 0x4d, 0x38, 0x4a, 0x56, 0x79, 0x4b, 0x43, 0x6e, 0x55, 0x45, 0x4d, 0x33, + 0x77, 0x34, 0x4e, 0x35, 0x55, 0x5a, 0x73, 0x6b, 0x47, 0x4f, 0x32, 0x6b, + 0x6c, 0x61, 0x73, 0x68, 0x41, 0x64, 0x33, 0x30, 0x68, 0x57, 0x4f, 0x2b, + 0x6b, 0x6e, 0x52, 0x76, 0x2f, 0x79, 0x30, 0x72, 0x0a, 0x75, 0x47, 0x49, + 0x59, 0x73, 0x39, 0x45, 0x6b, 0x37, 0x59, 0x58, 0x6c, 0x58, 0x49, 0x52, + 0x56, 0x72, 0x7a, 0x4d, 0x77, 0x63, 0x73, 0x72, 0x74, 0x31, 0x77, 0x3d, + 0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x50, + 0x52, 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x00}; diff --git a/test/core/end2end/data/test_root_cert.cc 
b/test/core/end2end/data/test_root_cert.cc new file mode 100644 index 00000000..941ebc82 --- /dev/null +++ b/test/core/end2end/data/test_root_cert.cc @@ -0,0 +1,122 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +extern const char test_root_cert[] = { + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x43, + 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, 0x49, 0x49, 0x44, 0x57, 0x6a, 0x43, 0x43, + 0x41, 0x6b, 0x4b, 0x67, 0x41, 0x77, 0x49, 0x42, 0x41, 0x67, 0x49, 0x55, + 0x57, 0x72, 0x50, 0x30, 0x56, 0x76, 0x48, 0x63, 0x79, 0x2b, 0x4c, 0x50, + 0x36, 0x55, 0x75, 0x59, 0x4e, 0x74, 0x69, 0x4c, 0x39, 0x67, 0x42, 0x68, + 0x44, 0x35, 0x6f, 0x77, 0x44, 0x51, 0x59, 0x4a, 0x4b, 0x6f, 0x5a, 0x49, + 0x68, 0x76, 0x63, 0x4e, 0x41, 0x51, 0x45, 0x4c, 0x0a, 0x42, 0x51, 0x41, + 0x77, 0x56, 0x6a, 0x45, 0x4c, 0x4d, 0x41, 0x6b, 0x47, 0x41, 0x31, 0x55, + 0x45, 0x42, 0x68, 0x4d, 0x43, 0x51, 0x56, 0x55, 0x78, 0x45, 0x7a, 0x41, + 0x52, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x67, 0x4d, 0x43, 0x6c, 0x4e, + 0x76, 0x62, 0x57, 0x55, 0x74, 0x55, 0x33, 0x52, 0x68, 0x64, 0x47, 0x55, + 0x78, 0x49, 0x54, 0x41, 0x66, 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x6f, + 0x4d, 0x0a, 0x47, 0x45, 0x6c, 0x75, 0x64, 0x47, 0x56, 0x79, 0x62, 0x6d, + 0x56, 0x30, 0x49, 0x46, 0x64, 0x70, 0x5a, 0x47, 0x64, 0x70, 0x64, 0x48, + 0x4d, 0x67, 0x55, 0x48, 0x52, 0x35, 0x49, 0x45, 0x78, 0x30, 0x5a, 0x44, + 0x45, 0x50, 0x4d, 0x41, 0x30, 0x47, 0x41, 0x31, 0x55, 0x45, 0x41, 0x77, + 0x77, 0x47, 0x64, 0x47, 0x56, 0x7a, 0x64, 0x47, 0x4e, 0x68, 0x4d, 0x42, + 0x34, 0x58, 0x44, 0x54, 0x49, 0x77, 0x0a, 0x4d, 0x44, 0x4d, 0x78, 0x4e, + 0x7a, 0x45, 0x34, 0x4e, 0x54, 0x6b, 0x31, 0x4d, 0x56, 0x6f, 0x58, 0x44, + 0x54, 0x4d, 0x77, 0x4d, 0x44, 0x4d, 0x78, 0x4e, 0x54, 0x45, 0x34, 0x4e, + 0x54, 0x6b, 0x31, 0x4d, 0x56, 0x6f, 0x77, 0x56, 0x6a, 0x45, 0x4c, 0x4d, + 0x41, 0x6b, 0x47, 0x41, 0x31, 0x55, 0x45, 0x42, 0x68, 0x4d, 0x43, 0x51, + 0x56, 0x55, 0x78, 0x45, 0x7a, 0x41, 0x52, 0x42, 0x67, 0x4e, 0x56, 0x0a, + 0x42, 0x41, 0x67, 0x4d, 0x43, 0x6c, 0x4e, 0x76, 0x62, 0x57, 0x55, 0x74, + 0x55, 0x33, 0x52, 0x68, 0x64, 0x47, 0x55, 0x78, 0x49, 0x54, 0x41, 0x66, + 0x42, 0x67, 0x4e, 0x56, 0x42, 0x41, 0x6f, 0x4d, 0x47, 0x45, 0x6c, 0x75, + 0x64, 0x47, 0x56, 0x79, 0x62, 0x6d, 0x56, 0x30, 0x49, 0x46, 0x64, 0x70, + 0x5a, 0x47, 0x64, 0x70, 0x64, 0x48, 0x4d, 0x67, 0x55, 0x48, 0x52, 0x35, + 0x49, 0x45, 0x78, 0x30, 0x0a, 0x5a, 0x44, 0x45, 0x50, 0x4d, 0x41, 0x30, + 0x47, 0x41, 0x31, 0x55, 0x45, 0x41, 0x77, 0x77, 0x47, 0x64, 0x47, 0x56, + 0x7a, 0x64, 0x47, 0x4e, 0x68, 0x4d, 0x49, 0x49, 0x42, 0x49, 0x6a, 0x41, + 0x4e, 0x42, 0x67, 0x6b, 0x71, 0x68, 0x6b, 0x69, 0x47, 0x39, 0x77, 0x30, + 0x42, 0x41, 0x51, 0x45, 0x46, 0x41, 0x41, 0x4f, 0x43, 0x41, 0x51, 0x38, + 0x41, 0x4d, 0x49, 0x49, 0x42, 0x43, 0x67, 0x4b, 0x43, 0x0a, 0x41, 0x51, + 0x45, 0x41, 0x73, 0x47, 0x4c, 0x30, 0x6f, 0x58, 0x66, 0x6c, 0x46, 0x30, + 0x4c, 0x7a, 0x6f, 0x4d, 0x2b, 0x42, 0x68, 0x2b, 0x71, 0x55, 
0x55, 0x39, + 0x79, 0x68, 0x71, 0x7a, 0x77, 0x32, 0x77, 0x38, 0x4f, 0x4f, 0x58, 0x35, + 0x6d, 0x75, 0x2f, 0x69, 0x4e, 0x43, 0x79, 0x55, 0x4f, 0x42, 0x72, 0x71, + 0x61, 0x48, 0x69, 0x37, 0x6d, 0x47, 0x48, 0x78, 0x37, 0x33, 0x47, 0x44, + 0x30, 0x31, 0x0a, 0x64, 0x69, 0x4e, 0x7a, 0x43, 0x7a, 0x76, 0x6c, 0x63, + 0x51, 0x71, 0x64, 0x4e, 0x49, 0x48, 0x36, 0x4e, 0x51, 0x53, 0x4c, 0x37, + 0x44, 0x54, 0x70, 0x42, 0x6a, 0x63, 0x61, 0x36, 0x36, 0x6a, 0x59, 0x54, + 0x39, 0x75, 0x37, 0x33, 0x76, 0x5a, 0x65, 0x32, 0x4d, 0x44, 0x72, 0x72, + 0x31, 0x6e, 0x56, 0x62, 0x75, 0x4c, 0x76, 0x66, 0x75, 0x39, 0x38, 0x35, + 0x30, 0x63, 0x64, 0x78, 0x69, 0x55, 0x4f, 0x0a, 0x49, 0x6e, 0x76, 0x35, + 0x78, 0x66, 0x38, 0x2b, 0x73, 0x54, 0x48, 0x47, 0x30, 0x43, 0x2b, 0x61, + 0x2b, 0x56, 0x41, 0x76, 0x4d, 0x68, 0x73, 0x4c, 0x69, 0x52, 0x6a, 0x73, + 0x71, 0x2b, 0x6c, 0x58, 0x4b, 0x52, 0x4a, 0x79, 0x6b, 0x35, 0x7a, 0x6b, + 0x62, 0x62, 0x73, 0x45, 0x54, 0x79, 0x62, 0x71, 0x70, 0x78, 0x6f, 0x4a, + 0x2b, 0x4b, 0x37, 0x43, 0x6f, 0x53, 0x79, 0x33, 0x79, 0x63, 0x2f, 0x6b, + 0x0a, 0x51, 0x49, 0x59, 0x33, 0x54, 0x69, 0x70, 0x77, 0x45, 0x74, 0x77, + 0x6b, 0x4b, 0x50, 0x34, 0x68, 0x7a, 0x79, 0x6f, 0x36, 0x4b, 0x69, 0x47, + 0x64, 0x2f, 0x44, 0x50, 0x65, 0x78, 0x69, 0x65, 0x34, 0x6e, 0x42, 0x55, + 0x49, 0x6e, 0x4e, 0x33, 0x62, 0x53, 0x31, 0x42, 0x55, 0x65, 0x4e, 0x5a, + 0x35, 0x7a, 0x65, 0x61, 0x49, 0x43, 0x32, 0x65, 0x67, 0x33, 0x62, 0x6b, + 0x65, 0x65, 0x57, 0x37, 0x63, 0x0a, 0x71, 0x54, 0x35, 0x35, 0x62, 0x2b, + 0x59, 0x65, 0x6e, 0x36, 0x43, 0x78, 0x59, 0x30, 0x54, 0x45, 0x6b, 0x7a, + 0x42, 0x4b, 0x36, 0x41, 0x4b, 0x74, 0x2f, 0x57, 0x55, 0x69, 0x61, 0x6c, + 0x4b, 0x4d, 0x67, 0x54, 0x30, 0x77, 0x62, 0x54, 0x78, 0x52, 0x5a, 0x4f, + 0x37, 0x6b, 0x55, 0x43, 0x48, 0x33, 0x53, 0x71, 0x36, 0x65, 0x2f, 0x77, + 0x58, 0x65, 0x46, 0x64, 0x4a, 0x2b, 0x48, 0x76, 0x64, 0x56, 0x0a, 0x4c, + 0x50, 0x6c, 0x41, 0x67, 0x35, 0x54, 0x6e, 0x4d, 0x61, 0x4e, 0x70, 0x52, + 0x64, 0x51, 0x69, 0x68, 0x2f, 0x38, 0x6e, 0x52, 0x46, 0x70, 0x73, 0x64, + 0x77, 0x49, 0x44, 0x41, 0x51, 0x41, 0x42, 0x6f, 0x79, 0x41, 0x77, 0x48, + 0x6a, 0x41, 0x4d, 0x42, 0x67, 0x4e, 0x56, 0x48, 0x52, 0x4d, 0x45, 0x42, + 0x54, 0x41, 0x44, 0x41, 0x51, 0x48, 0x2f, 0x4d, 0x41, 0x34, 0x47, 0x41, + 0x31, 0x55, 0x64, 0x0a, 0x44, 0x77, 0x45, 0x42, 0x2f, 0x77, 0x51, 0x45, + 0x41, 0x77, 0x49, 0x43, 0x42, 0x44, 0x41, 0x4e, 0x42, 0x67, 0x6b, 0x71, + 0x68, 0x6b, 0x69, 0x47, 0x39, 0x77, 0x30, 0x42, 0x41, 0x51, 0x73, 0x46, + 0x41, 0x41, 0x4f, 0x43, 0x41, 0x51, 0x45, 0x41, 0x6b, 0x54, 0x72, 0x4b, + 0x5a, 0x6a, 0x42, 0x72, 0x4a, 0x58, 0x48, 0x70, 0x73, 0x2f, 0x48, 0x72, + 0x6a, 0x4e, 0x43, 0x46, 0x50, 0x62, 0x35, 0x61, 0x0a, 0x54, 0x48, 0x75, + 0x47, 0x50, 0x43, 0x53, 0x73, 0x65, 0x70, 0x65, 0x31, 0x77, 0x6b, 0x4b, + 0x64, 0x53, 0x70, 0x31, 0x68, 0x34, 0x48, 0x47, 0x52, 0x70, 0x4c, 0x6f, + 0x43, 0x67, 0x63, 0x4c, 0x79, 0x73, 0x43, 0x4a, 0x35, 0x68, 0x5a, 0x68, + 0x52, 0x70, 0x48, 0x6b, 0x52, 0x69, 0x68, 0x68, 0x65, 0x66, 0x2b, 0x72, + 0x46, 0x48, 0x45, 0x65, 0x36, 0x30, 0x55, 0x65, 0x50, 0x51, 0x4f, 0x33, + 0x53, 0x0a, 0x43, 0x56, 0x54, 0x74, 0x64, 0x4a, 0x42, 0x34, 0x43, 0x59, + 0x57, 0x70, 0x63, 0x4e, 0x79, 0x58, 0x4f, 0x64, 0x71, 0x65, 0x66, 0x72, + 0x62, 0x4a, 0x57, 0x35, 0x51, 0x4e, 0x6c, 0x6a, 0x78, 0x67, 0x69, 0x36, + 0x46, 0x68, 0x76, 0x73, 0x37, 0x4a, 0x4a, 0x6b, 0x42, 0x71, 0x64, 0x58, + 0x49, 0x6b, 0x57, 0x58, 0x74, 0x46, 0x6b, 0x32, 0x65, 0x52, 0x67, 0x4f, + 0x49, 0x50, 0x32, 0x45, 0x6f, 0x39, 0x0a, 0x2f, 0x4f, 0x48, 0x51, 0x48, + 0x6c, 0x59, 0x6e, 0x77, 0x5a, 0x46, 0x72, 0x6b, 0x36, 0x73, 
0x70, 0x34, + 0x77, 0x50, 0x79, 0x52, 0x2b, 0x41, 0x39, 0x35, 0x53, 0x30, 0x74, 0x6f, + 0x5a, 0x42, 0x63, 0x79, 0x44, 0x56, 0x7a, 0x37, 0x75, 0x2b, 0x68, 0x4f, + 0x57, 0x30, 0x70, 0x47, 0x4b, 0x33, 0x77, 0x76, 0x69, 0x4f, 0x65, 0x39, + 0x6c, 0x76, 0x52, 0x67, 0x6a, 0x2f, 0x48, 0x33, 0x50, 0x77, 0x74, 0x0a, + 0x62, 0x65, 0x77, 0x62, 0x30, 0x6c, 0x2b, 0x4d, 0x68, 0x52, 0x69, 0x67, + 0x30, 0x2f, 0x44, 0x56, 0x48, 0x61, 0x6d, 0x79, 0x56, 0x78, 0x72, 0x44, + 0x52, 0x62, 0x71, 0x49, 0x6e, 0x55, 0x31, 0x2f, 0x47, 0x54, 0x4e, 0x43, + 0x77, 0x63, 0x5a, 0x6b, 0x58, 0x4b, 0x59, 0x46, 0x57, 0x53, 0x66, 0x39, + 0x32, 0x55, 0x2b, 0x6b, 0x49, 0x63, 0x54, 0x74, 0x68, 0x32, 0x34, 0x51, + 0x31, 0x67, 0x63, 0x77, 0x0a, 0x65, 0x5a, 0x69, 0x4c, 0x6c, 0x35, 0x46, + 0x66, 0x72, 0x57, 0x6f, 0x6b, 0x55, 0x4e, 0x79, 0x74, 0x46, 0x45, 0x6c, + 0x58, 0x6f, 0x62, 0x30, 0x56, 0x30, 0x61, 0x35, 0x2f, 0x6b, 0x62, 0x68, + 0x69, 0x4c, 0x63, 0x33, 0x79, 0x57, 0x6d, 0x76, 0x57, 0x71, 0x48, 0x54, + 0x70, 0x71, 0x43, 0x41, 0x4c, 0x62, 0x56, 0x79, 0x46, 0x2b, 0x72, 0x4b, + 0x4a, 0x6f, 0x32, 0x66, 0x35, 0x4b, 0x77, 0x3d, 0x3d, 0x0a, 0x2d, 0x2d, + 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x43, 0x45, 0x52, 0x54, 0x49, + 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, + 0x00}; diff --git a/test/core/end2end/dualstack_socket_test.cc b/test/core/end2end/dualstack_socket_test.cc new file mode 100644 index 00000000..23078435 --- /dev/null +++ b/test/core/end2end/dualstack_socket_test.cc @@ -0,0 +1,364 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_EV + +#include + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" + +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +/* This test exercises IPv4, IPv6, and dualstack sockets in various ways. 
 */
+
+static void* tag(intptr_t t) { return reinterpret_cast<void*>(t); }
+
+static void drain_cq(grpc_completion_queue* cq) {
+  grpc_event ev;
+  do {
+    ev = grpc_completion_queue_next(
+        cq, grpc_timeout_milliseconds_to_deadline(5000), nullptr);
+  } while (ev.type != GRPC_QUEUE_SHUTDOWN);
+}
+
+static void log_resolved_addrs(const char* label, const char* hostname) {
+  grpc_resolved_addresses* res = nullptr;
+  grpc_error_handle error = grpc_blocking_resolve_address(hostname, "80", &res);
+  if (error != GRPC_ERROR_NONE || res == nullptr) {
+    GRPC_LOG_IF_ERROR(hostname, error);
+    return;
+  }
+  for (size_t i = 0; i < res->naddrs; ++i) {
+    gpr_log(GPR_INFO, "%s: %s", label,
+            grpc_sockaddr_to_uri(&res->addrs[i]).c_str());
+  }
+  grpc_resolved_addresses_destroy(res);
+}
+
+void test_connect(const char* server_host, const char* client_host, int port,
+                  int expect_ok) {
+  grpc_channel* client;
+  grpc_server* server;
+  grpc_completion_queue* cq;
+  grpc_completion_queue* shutdown_cq;
+  grpc_call* c;
+  grpc_call* s;
+  cq_verifier* cqv;
+  gpr_timespec deadline;
+  int got_port;
+  grpc_op ops[6];
+  grpc_op* op;
+  grpc_metadata_array initial_metadata_recv;
+  grpc_metadata_array trailing_metadata_recv;
+  grpc_metadata_array request_metadata_recv;
+  grpc_status_code status;
+  grpc_call_error error;
+  grpc_slice details;
+  int was_cancelled = 2;
+  grpc_call_details call_details;
+  char* peer;
+  int picked_port = 0;
+
+  if (port == 0) {
+    port = grpc_pick_unused_port_or_die();
+    picked_port = 1;
+  }
+
+  std::string server_hostport = grpc_core::JoinHostPort(server_host, port);
+
+  grpc_metadata_array_init(&initial_metadata_recv);
+  grpc_metadata_array_init(&trailing_metadata_recv);
+  grpc_metadata_array_init(&request_metadata_recv);
+  grpc_call_details_init(&call_details);
+
+  /* Create server. */
+  cq = grpc_completion_queue_create_for_next(nullptr);
+  server = grpc_server_create(nullptr, nullptr);
+  grpc_server_register_completion_queue(server, cq, nullptr);
+  GPR_ASSERT((got_port = grpc_server_add_insecure_http2_port(
+                  server, server_hostport.c_str())) > 0);
+  if (port == 0) {
+    port = got_port;
+  } else {
+    GPR_ASSERT(port == got_port);
+  }
+  grpc_server_start(server);
+  cqv = cq_verifier_create(cq);
+
+  /* Create client. */
+  std::string client_hostport;
+  if (client_host[0] == 'i') {
+    /* for ipv4:/ipv6: addresses, concatenate the port to each of the parts */
+    std::vector<absl::string_view> uri_parts =
+        absl::StrSplit(client_host, ',', absl::SkipEmpty());
+    std::vector<std::string> hosts_with_port;
+    hosts_with_port.reserve(uri_parts.size());
+    for (const absl::string_view& uri_part : uri_parts) {
+      hosts_with_port.push_back(absl::StrFormat("%s:%d", uri_part, port));
+    }
+    client_hostport = absl::StrJoin(hosts_with_port, ",");
+  } else {
+    client_hostport = grpc_core::JoinHostPort(client_host, port);
+  }
+  client =
+      grpc_insecure_channel_create(client_hostport.c_str(), nullptr, nullptr);
+
+  gpr_log(GPR_INFO, "Testing with server=%s client=%s (expecting %s)",
+          server_hostport.c_str(), client_hostport.c_str(),
+          expect_ok ? "success" : "failure");
+  log_resolved_addrs("server resolved addr", server_host);
+  log_resolved_addrs("client resolved addr", client_host);
+
+  if (expect_ok) {
+    /* Normal deadline, shouldn't be reached. */
+    deadline = grpc_timeout_milliseconds_to_deadline(60000);
+  } else {
+    /* Give up faster when failure is expected.
+       BUG: Setting this to 1000 reveals a memory leak (b/18608927). */
+    deadline = grpc_timeout_milliseconds_to_deadline(8000);
+  }
+
+  /* Send a trivial request. */
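+  /* The client issues one batch of four ops below: send initial metadata
+     (with GRPC_INITIAL_METADATA_WAIT_FOR_READY when success is expected, so
+     the call waits for a usable connection instead of failing fast),
+     half-close the sending side, receive the server's initial metadata, and
+     receive the final status. The whole batch completes on the completion
+     queue as tag(1). */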
+  grpc_slice host = grpc_slice_from_static_string("foo.test.google.fr");
+  c = grpc_channel_create_call(client, nullptr, GRPC_PROPAGATE_DEFAULTS, cq,
+                               grpc_slice_from_static_string("/foo"), &host,
+                               deadline, nullptr);
+  GPR_ASSERT(c);
+
+  memset(ops, 0, sizeof(ops));
+  op = ops;
+  op->op = GRPC_OP_SEND_INITIAL_METADATA;
+  op->data.send_initial_metadata.count = 0;
+  op->flags = expect_ok ? GRPC_INITIAL_METADATA_WAIT_FOR_READY : 0;
+  op->reserved = nullptr;
+  op++;
+  op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT;
+  op->flags = 0;
+  op->reserved = nullptr;
+  op++;
+  op->op = GRPC_OP_RECV_INITIAL_METADATA;
+  op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv;
+  op->flags = 0;
+  op->reserved = nullptr;
+  op++;
+  op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
+  op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv;
+  op->data.recv_status_on_client.status = &status;
+  op->data.recv_status_on_client.status_details = &details;
+  op->flags = 0;
+  op->reserved = nullptr;
+  op++;
+  error = grpc_call_start_batch(c, ops, static_cast<size_t>(op - ops), tag(1),
+                                nullptr);
+  GPR_ASSERT(GRPC_CALL_OK == error);
+
+  if (expect_ok) {
+    /* Check for a successful request. */
+    error = grpc_server_request_call(server, &s, &call_details,
+                                     &request_metadata_recv, cq, cq, tag(101));
+    GPR_ASSERT(GRPC_CALL_OK == error);
+    CQ_EXPECT_COMPLETION(cqv, tag(101), 1);
+    cq_verify(cqv);
+
+    memset(ops, 0, sizeof(ops));
+    op = ops;
+    op->op = GRPC_OP_SEND_INITIAL_METADATA;
+    op->data.send_initial_metadata.count = 0;
+    op->flags = 0;
+    op++;
+    op->op = GRPC_OP_SEND_STATUS_FROM_SERVER;
+    op->data.send_status_from_server.trailing_metadata_count = 0;
+    op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED;
+    grpc_slice status_details = grpc_slice_from_static_string("xyz");
+    op->data.send_status_from_server.status_details = &status_details;
+    op->flags = 0;
+    op++;
+    op->op = GRPC_OP_RECV_CLOSE_ON_SERVER;
+    op->data.recv_close_on_server.cancelled = &was_cancelled;
+    op->flags = 0;
+    op++;
+    error = grpc_call_start_batch(s, ops, static_cast<size_t>(op - ops),
+                                  tag(102), nullptr);
+    GPR_ASSERT(GRPC_CALL_OK == error);
+
+    CQ_EXPECT_COMPLETION(cqv, tag(102), 1);
+    CQ_EXPECT_COMPLETION(cqv, tag(1), 1);
+    cq_verify(cqv);
+
+    peer = grpc_call_get_peer(c);
+    gpr_log(GPR_DEBUG, "got peer: '%s'", peer);
+    gpr_free(peer);
+
+    GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED);
+    GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz"));
+    GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo"));
+    GPR_ASSERT(0 ==
+               grpc_slice_str_cmp(call_details.host, "foo.test.google.fr"));
+    GPR_ASSERT(was_cancelled == 0);
+
+    grpc_call_unref(s);
+  } else {
+    /* Check for a failed connection. */
+    CQ_EXPECT_COMPLETION(cqv, tag(1), 1);
+    cq_verify(cqv);
+
+    gpr_log(GPR_INFO, "status: %d (expected: %d)", status,
+            GRPC_STATUS_UNAVAILABLE);
+    GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE);
+  }
+
+  grpc_call_unref(c);
+
+  cq_verifier_destroy(cqv);
+
+  /* Destroy client. */
+  grpc_channel_destroy(client);
+
+  /* Destroy server. 
*/ + shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + grpc_server_shutdown_and_notify(server, shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(server); + grpc_completion_queue_destroy(shutdown_cq); + grpc_completion_queue_shutdown(cq); + drain_cq(cq); + grpc_completion_queue_destroy(cq); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + + grpc_call_details_destroy(&call_details); + grpc_slice_unref(details); + if (picked_port) { + grpc_recycle_unused_port(port); + } +} + +int external_dns_works(const char* host) { + grpc_resolved_addresses* res = nullptr; + grpc_error_handle error = grpc_blocking_resolve_address(host, "80", &res); + GRPC_ERROR_UNREF(error); + if (res != nullptr) { + grpc_resolved_addresses_destroy(res); + return 1; + } + return 0; +} + +int main(int argc, char** argv) { + int do_ipv6 = 1; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + if (!grpc_ipv6_loopback_available()) { + gpr_log(GPR_INFO, "Can't bind to ::1. Skipping IPv6 tests."); + do_ipv6 = 0; + } + + /* For coverage, test with and without dualstack sockets. */ + for (grpc_forbid_dualstack_sockets_for_testing = 0; + grpc_forbid_dualstack_sockets_for_testing <= 1; + grpc_forbid_dualstack_sockets_for_testing++) { + /* :: and 0.0.0.0 are handled identically. */ + test_connect("::", "127.0.0.1", 0, 1); + test_connect("::", "::ffff:127.0.0.1", 0, 1); + test_connect("::", "ipv4:127.0.0.1", 0, 1); + test_connect("::", "ipv6:[::ffff:127.0.0.1]", 0, 1); + test_connect("::", "localhost", 0, 1); + test_connect("0.0.0.0", "127.0.0.1", 0, 1); + test_connect("0.0.0.0", "::ffff:127.0.0.1", 0, 1); + test_connect("0.0.0.0", "ipv4:127.0.0.1", 0, 1); + test_connect("0.0.0.0", "ipv4:127.0.0.1,127.0.0.2,127.0.0.3", 0, 1); + test_connect("0.0.0.0", "ipv6:[::ffff:127.0.0.1],[::ffff:127.0.0.2]", 0, 1); + test_connect("0.0.0.0", "localhost", 0, 1); + if (do_ipv6) { + test_connect("::", "::1", 0, 1); + test_connect("0.0.0.0", "::1", 0, 1); + test_connect("::", "ipv6:[::1]", 0, 1); + test_connect("0.0.0.0", "ipv6:[::1]", 0, 1); + } + + /* These only work when the families agree. 
*/ + test_connect("127.0.0.1", "127.0.0.1", 0, 1); + test_connect("127.0.0.1", "ipv4:127.0.0.1", 0, 1); + if (do_ipv6) { + test_connect("::1", "::1", 0, 1); + test_connect("::1", "127.0.0.1", 0, 0); + test_connect("127.0.0.1", "::1", 0, 0); + test_connect("::1", "ipv6:[::1]", 0, 1); + test_connect("::1", "ipv4:127.0.0.1", 0, 0); + test_connect("127.0.0.1", "ipv6:[::1]", 0, 0); + } + + if (!external_dns_works("loopback46.unittest.grpc.io")) { + gpr_log(GPR_INFO, "Skipping tests that depend on *.unittest.grpc.io."); + } else { + test_connect("loopback46.unittest.grpc.io", "loopback4.unittest.grpc.io", + 0, 1); + test_connect("loopback4.unittest.grpc.io", "loopback46.unittest.grpc.io", + 0, 1); + if (do_ipv6) { + test_connect("loopback46.unittest.grpc.io", + "loopback6.unittest.grpc.io", 0, 1); + test_connect("loopback6.unittest.grpc.io", + "loopback46.unittest.grpc.io", 0, 1); + test_connect("loopback4.unittest.grpc.io", "loopback6.unittest.grpc.io", + 0, 0); + test_connect("loopback6.unittest.grpc.io", "loopback4.unittest.grpc.io", + 0, 0); + } + } + } + + grpc_shutdown(); + + return 0; +} + +#else /* GRPC_POSIX_SOCKET_EV */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_EV */ diff --git a/test/core/end2end/end2end_nosec_tests.cc b/test/core/end2end/end2end_nosec_tests.cc new file mode 100644 index 00000000..c13f3ac2 --- /dev/null +++ b/test/core/end2end/end2end_nosec_tests.cc @@ -0,0 +1,757 @@ + +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +/* This file is auto-generated */ + +#include "test/core/end2end/end2end_tests.h" + +#include +#include + +#include + + +static bool g_pre_init_called = false; + +extern void authority_not_supported(grpc_end2end_test_config config); +extern void authority_not_supported_pre_init(void); +extern void bad_hostname(grpc_end2end_test_config config); +extern void bad_hostname_pre_init(void); +extern void bad_ping(grpc_end2end_test_config config); +extern void bad_ping_pre_init(void); +extern void binary_metadata(grpc_end2end_test_config config); +extern void binary_metadata_pre_init(void); +extern void call_host_override(grpc_end2end_test_config config); +extern void call_host_override_pre_init(void); +extern void cancel_after_accept(grpc_end2end_test_config config); +extern void cancel_after_accept_pre_init(void); +extern void cancel_after_client_done(grpc_end2end_test_config config); +extern void cancel_after_client_done_pre_init(void); +extern void cancel_after_invoke(grpc_end2end_test_config config); +extern void cancel_after_invoke_pre_init(void); +extern void cancel_after_round_trip(grpc_end2end_test_config config); +extern void cancel_after_round_trip_pre_init(void); +extern void cancel_before_invoke(grpc_end2end_test_config config); +extern void cancel_before_invoke_pre_init(void); +extern void cancel_in_a_vacuum(grpc_end2end_test_config config); +extern void cancel_in_a_vacuum_pre_init(void); +extern void cancel_with_status(grpc_end2end_test_config config); +extern void cancel_with_status_pre_init(void); +extern void channelz(grpc_end2end_test_config config); +extern void channelz_pre_init(void); +extern void client_streaming(grpc_end2end_test_config config); +extern void client_streaming_pre_init(void); +extern void compressed_payload(grpc_end2end_test_config config); +extern void compressed_payload_pre_init(void); +extern void connectivity(grpc_end2end_test_config config); +extern void connectivity_pre_init(void); +extern void default_host(grpc_end2end_test_config config); +extern void default_host_pre_init(void); +extern void disappearing_server(grpc_end2end_test_config config); +extern void disappearing_server_pre_init(void); +extern void empty_batch(grpc_end2end_test_config config); +extern void empty_batch_pre_init(void); +extern void filter_causes_close(grpc_end2end_test_config config); +extern void filter_causes_close_pre_init(void); +extern void filter_context(grpc_end2end_test_config config); +extern void filter_context_pre_init(void); +extern void filter_init_fails(grpc_end2end_test_config config); +extern void filter_init_fails_pre_init(void); +extern void filter_latency(grpc_end2end_test_config config); +extern void filter_latency_pre_init(void); +extern void filter_status_code(grpc_end2end_test_config config); +extern void filter_status_code_pre_init(void); +extern void graceful_server_shutdown(grpc_end2end_test_config config); +extern void graceful_server_shutdown_pre_init(void); +extern void high_initial_seqno(grpc_end2end_test_config config); +extern void high_initial_seqno_pre_init(void); +extern void hpack_size(grpc_end2end_test_config config); +extern void hpack_size_pre_init(void); +extern void idempotent_request(grpc_end2end_test_config config); +extern void idempotent_request_pre_init(void); +extern void invoke_large_request(grpc_end2end_test_config config); +extern void invoke_large_request_pre_init(void); +extern void keepalive_timeout(grpc_end2end_test_config config); +extern void keepalive_timeout_pre_init(void); +extern void 
large_metadata(grpc_end2end_test_config config); +extern void large_metadata_pre_init(void); +extern void max_concurrent_streams(grpc_end2end_test_config config); +extern void max_concurrent_streams_pre_init(void); +extern void max_connection_age(grpc_end2end_test_config config); +extern void max_connection_age_pre_init(void); +extern void max_connection_idle(grpc_end2end_test_config config); +extern void max_connection_idle_pre_init(void); +extern void max_message_length(grpc_end2end_test_config config); +extern void max_message_length_pre_init(void); +extern void negative_deadline(grpc_end2end_test_config config); +extern void negative_deadline_pre_init(void); +extern void no_error_on_hotpath(grpc_end2end_test_config config); +extern void no_error_on_hotpath_pre_init(void); +extern void no_logging(grpc_end2end_test_config config); +extern void no_logging_pre_init(void); +extern void no_op(grpc_end2end_test_config config); +extern void no_op_pre_init(void); +extern void payload(grpc_end2end_test_config config); +extern void payload_pre_init(void); +extern void ping(grpc_end2end_test_config config); +extern void ping_pre_init(void); +extern void ping_pong_streaming(grpc_end2end_test_config config); +extern void ping_pong_streaming_pre_init(void); +extern void proxy_auth(grpc_end2end_test_config config); +extern void proxy_auth_pre_init(void); +extern void registered_call(grpc_end2end_test_config config); +extern void registered_call_pre_init(void); +extern void request_with_flags(grpc_end2end_test_config config); +extern void request_with_flags_pre_init(void); +extern void request_with_payload(grpc_end2end_test_config config); +extern void request_with_payload_pre_init(void); +extern void resource_quota_server(grpc_end2end_test_config config); +extern void resource_quota_server_pre_init(void); +extern void retry(grpc_end2end_test_config config); +extern void retry_pre_init(void); +extern void retry_cancel_during_delay(grpc_end2end_test_config config); +extern void retry_cancel_during_delay_pre_init(void); +extern void retry_cancel_with_multiple_send_batches(grpc_end2end_test_config config); +extern void retry_cancel_with_multiple_send_batches_pre_init(void); +extern void retry_cancellation(grpc_end2end_test_config config); +extern void retry_cancellation_pre_init(void); +extern void retry_disabled(grpc_end2end_test_config config); +extern void retry_disabled_pre_init(void); +extern void retry_exceeds_buffer_size_in_delay(grpc_end2end_test_config config); +extern void retry_exceeds_buffer_size_in_delay_pre_init(void); +extern void retry_exceeds_buffer_size_in_initial_batch(grpc_end2end_test_config config); +extern void retry_exceeds_buffer_size_in_initial_batch_pre_init(void); +extern void retry_exceeds_buffer_size_in_subsequent_batch(grpc_end2end_test_config config); +extern void retry_exceeds_buffer_size_in_subsequent_batch_pre_init(void); +extern void retry_lb_drop(grpc_end2end_test_config config); +extern void retry_lb_drop_pre_init(void); +extern void retry_lb_fail(grpc_end2end_test_config config); +extern void retry_lb_fail_pre_init(void); +extern void retry_non_retriable_status(grpc_end2end_test_config config); +extern void retry_non_retriable_status_pre_init(void); +extern void retry_non_retriable_status_before_recv_trailing_metadata_started(grpc_end2end_test_config config); +extern void retry_non_retriable_status_before_recv_trailing_metadata_started_pre_init(void); +extern void retry_per_attempt_recv_timeout(grpc_end2end_test_config config); +extern void 
retry_per_attempt_recv_timeout_pre_init(void); +extern void retry_per_attempt_recv_timeout_on_last_attempt(grpc_end2end_test_config config); +extern void retry_per_attempt_recv_timeout_on_last_attempt_pre_init(void); +extern void retry_recv_initial_metadata(grpc_end2end_test_config config); +extern void retry_recv_initial_metadata_pre_init(void); +extern void retry_recv_message(grpc_end2end_test_config config); +extern void retry_recv_message_pre_init(void); +extern void retry_recv_trailing_metadata_error(grpc_end2end_test_config config); +extern void retry_recv_trailing_metadata_error_pre_init(void); +extern void retry_send_initial_metadata_refs(grpc_end2end_test_config config); +extern void retry_send_initial_metadata_refs_pre_init(void); +extern void retry_send_op_fails(grpc_end2end_test_config config); +extern void retry_send_op_fails_pre_init(void); +extern void retry_server_pushback_delay(grpc_end2end_test_config config); +extern void retry_server_pushback_delay_pre_init(void); +extern void retry_server_pushback_disabled(grpc_end2end_test_config config); +extern void retry_server_pushback_disabled_pre_init(void); +extern void retry_streaming(grpc_end2end_test_config config); +extern void retry_streaming_pre_init(void); +extern void retry_streaming_after_commit(grpc_end2end_test_config config); +extern void retry_streaming_after_commit_pre_init(void); +extern void retry_streaming_succeeds_before_replay_finished(grpc_end2end_test_config config); +extern void retry_streaming_succeeds_before_replay_finished_pre_init(void); +extern void retry_throttled(grpc_end2end_test_config config); +extern void retry_throttled_pre_init(void); +extern void retry_too_many_attempts(grpc_end2end_test_config config); +extern void retry_too_many_attempts_pre_init(void); +extern void server_finishes_request(grpc_end2end_test_config config); +extern void server_finishes_request_pre_init(void); +extern void server_streaming(grpc_end2end_test_config config); +extern void server_streaming_pre_init(void); +extern void shutdown_finishes_calls(grpc_end2end_test_config config); +extern void shutdown_finishes_calls_pre_init(void); +extern void shutdown_finishes_tags(grpc_end2end_test_config config); +extern void shutdown_finishes_tags_pre_init(void); +extern void simple_cacheable_request(grpc_end2end_test_config config); +extern void simple_cacheable_request_pre_init(void); +extern void simple_delayed_request(grpc_end2end_test_config config); +extern void simple_delayed_request_pre_init(void); +extern void simple_metadata(grpc_end2end_test_config config); +extern void simple_metadata_pre_init(void); +extern void simple_request(grpc_end2end_test_config config); +extern void simple_request_pre_init(void); +extern void stream_compression_compressed_payload(grpc_end2end_test_config config); +extern void stream_compression_compressed_payload_pre_init(void); +extern void stream_compression_payload(grpc_end2end_test_config config); +extern void stream_compression_payload_pre_init(void); +extern void stream_compression_ping_pong_streaming(grpc_end2end_test_config config); +extern void stream_compression_ping_pong_streaming_pre_init(void); +extern void streaming_error_response(grpc_end2end_test_config config); +extern void streaming_error_response_pre_init(void); +extern void trailing_metadata(grpc_end2end_test_config config); +extern void trailing_metadata_pre_init(void); +extern void write_buffering(grpc_end2end_test_config config); +extern void write_buffering_pre_init(void); +extern void 
write_buffering_at_end(grpc_end2end_test_config config); +extern void write_buffering_at_end_pre_init(void); + +void grpc_end2end_tests_pre_init(void) { + GPR_ASSERT(!g_pre_init_called); + g_pre_init_called = true; + authority_not_supported_pre_init(); + bad_hostname_pre_init(); + bad_ping_pre_init(); + binary_metadata_pre_init(); + call_host_override_pre_init(); + cancel_after_accept_pre_init(); + cancel_after_client_done_pre_init(); + cancel_after_invoke_pre_init(); + cancel_after_round_trip_pre_init(); + cancel_before_invoke_pre_init(); + cancel_in_a_vacuum_pre_init(); + cancel_with_status_pre_init(); + channelz_pre_init(); + client_streaming_pre_init(); + compressed_payload_pre_init(); + connectivity_pre_init(); + default_host_pre_init(); + disappearing_server_pre_init(); + empty_batch_pre_init(); + filter_causes_close_pre_init(); + filter_context_pre_init(); + filter_init_fails_pre_init(); + filter_latency_pre_init(); + filter_status_code_pre_init(); + graceful_server_shutdown_pre_init(); + high_initial_seqno_pre_init(); + hpack_size_pre_init(); + idempotent_request_pre_init(); + invoke_large_request_pre_init(); + keepalive_timeout_pre_init(); + large_metadata_pre_init(); + max_concurrent_streams_pre_init(); + max_connection_age_pre_init(); + max_connection_idle_pre_init(); + max_message_length_pre_init(); + negative_deadline_pre_init(); + no_error_on_hotpath_pre_init(); + no_logging_pre_init(); + no_op_pre_init(); + payload_pre_init(); + ping_pre_init(); + ping_pong_streaming_pre_init(); + proxy_auth_pre_init(); + registered_call_pre_init(); + request_with_flags_pre_init(); + request_with_payload_pre_init(); + resource_quota_server_pre_init(); + retry_pre_init(); + retry_cancel_during_delay_pre_init(); + retry_cancel_with_multiple_send_batches_pre_init(); + retry_cancellation_pre_init(); + retry_disabled_pre_init(); + retry_exceeds_buffer_size_in_delay_pre_init(); + retry_exceeds_buffer_size_in_initial_batch_pre_init(); + retry_exceeds_buffer_size_in_subsequent_batch_pre_init(); + retry_lb_drop_pre_init(); + retry_lb_fail_pre_init(); + retry_non_retriable_status_pre_init(); + retry_non_retriable_status_before_recv_trailing_metadata_started_pre_init(); + retry_per_attempt_recv_timeout_pre_init(); + retry_per_attempt_recv_timeout_on_last_attempt_pre_init(); + retry_recv_initial_metadata_pre_init(); + retry_recv_message_pre_init(); + retry_recv_trailing_metadata_error_pre_init(); + retry_send_initial_metadata_refs_pre_init(); + retry_send_op_fails_pre_init(); + retry_server_pushback_delay_pre_init(); + retry_server_pushback_disabled_pre_init(); + retry_streaming_pre_init(); + retry_streaming_after_commit_pre_init(); + retry_streaming_succeeds_before_replay_finished_pre_init(); + retry_throttled_pre_init(); + retry_too_many_attempts_pre_init(); + server_finishes_request_pre_init(); + server_streaming_pre_init(); + shutdown_finishes_calls_pre_init(); + shutdown_finishes_tags_pre_init(); + simple_cacheable_request_pre_init(); + simple_delayed_request_pre_init(); + simple_metadata_pre_init(); + simple_request_pre_init(); + stream_compression_compressed_payload_pre_init(); + stream_compression_payload_pre_init(); + stream_compression_ping_pong_streaming_pre_init(); + streaming_error_response_pre_init(); + trailing_metadata_pre_init(); + write_buffering_pre_init(); + write_buffering_at_end_pre_init(); +} + +// NOLINTNEXTLINE(readability-function-size) +void grpc_end2end_tests(int argc, char **argv, + grpc_end2end_test_config config) { + int i; + + GPR_ASSERT(g_pre_init_called); + + if (argc 
<= 1) { + authority_not_supported(config); + bad_hostname(config); + bad_ping(config); + binary_metadata(config); + call_host_override(config); + cancel_after_accept(config); + cancel_after_client_done(config); + cancel_after_invoke(config); + cancel_after_round_trip(config); + cancel_before_invoke(config); + cancel_in_a_vacuum(config); + cancel_with_status(config); + channelz(config); + client_streaming(config); + compressed_payload(config); + connectivity(config); + default_host(config); + disappearing_server(config); + empty_batch(config); + filter_causes_close(config); + filter_context(config); + filter_init_fails(config); + filter_latency(config); + filter_status_code(config); + graceful_server_shutdown(config); + high_initial_seqno(config); + hpack_size(config); + idempotent_request(config); + invoke_large_request(config); + keepalive_timeout(config); + large_metadata(config); + max_concurrent_streams(config); + max_connection_age(config); + max_connection_idle(config); + max_message_length(config); + negative_deadline(config); + no_error_on_hotpath(config); + no_logging(config); + no_op(config); + payload(config); + ping(config); + ping_pong_streaming(config); + proxy_auth(config); + registered_call(config); + request_with_flags(config); + request_with_payload(config); + resource_quota_server(config); + retry(config); + retry_cancel_during_delay(config); + retry_cancel_with_multiple_send_batches(config); + retry_cancellation(config); + retry_disabled(config); + retry_exceeds_buffer_size_in_delay(config); + retry_exceeds_buffer_size_in_initial_batch(config); + retry_exceeds_buffer_size_in_subsequent_batch(config); + retry_lb_drop(config); + retry_lb_fail(config); + retry_non_retriable_status(config); + retry_non_retriable_status_before_recv_trailing_metadata_started(config); + retry_per_attempt_recv_timeout(config); + retry_per_attempt_recv_timeout_on_last_attempt(config); + retry_recv_initial_metadata(config); + retry_recv_message(config); + retry_recv_trailing_metadata_error(config); + retry_send_initial_metadata_refs(config); + retry_send_op_fails(config); + retry_server_pushback_delay(config); + retry_server_pushback_disabled(config); + retry_streaming(config); + retry_streaming_after_commit(config); + retry_streaming_succeeds_before_replay_finished(config); + retry_throttled(config); + retry_too_many_attempts(config); + server_finishes_request(config); + server_streaming(config); + shutdown_finishes_calls(config); + shutdown_finishes_tags(config); + simple_cacheable_request(config); + simple_delayed_request(config); + simple_metadata(config); + simple_request(config); + stream_compression_compressed_payload(config); + stream_compression_payload(config); + stream_compression_ping_pong_streaming(config); + streaming_error_response(config); + trailing_metadata(config); + write_buffering(config); + write_buffering_at_end(config); + return; + } + + for (i = 1; i < argc; i++) { + if (0 == strcmp("authority_not_supported", argv[i])) { + authority_not_supported(config); + continue; + } + if (0 == strcmp("bad_hostname", argv[i])) { + bad_hostname(config); + continue; + } + if (0 == strcmp("bad_ping", argv[i])) { + bad_ping(config); + continue; + } + if (0 == strcmp("binary_metadata", argv[i])) { + binary_metadata(config); + continue; + } + if (0 == strcmp("call_host_override", argv[i])) { + call_host_override(config); + continue; + } + if (0 == strcmp("cancel_after_accept", argv[i])) { + cancel_after_accept(config); + continue; + } + if (0 == strcmp("cancel_after_client_done", argv[i])) 
{ + cancel_after_client_done(config); + continue; + } + if (0 == strcmp("cancel_after_invoke", argv[i])) { + cancel_after_invoke(config); + continue; + } + if (0 == strcmp("cancel_after_round_trip", argv[i])) { + cancel_after_round_trip(config); + continue; + } + if (0 == strcmp("cancel_before_invoke", argv[i])) { + cancel_before_invoke(config); + continue; + } + if (0 == strcmp("cancel_in_a_vacuum", argv[i])) { + cancel_in_a_vacuum(config); + continue; + } + if (0 == strcmp("cancel_with_status", argv[i])) { + cancel_with_status(config); + continue; + } + if (0 == strcmp("channelz", argv[i])) { + channelz(config); + continue; + } + if (0 == strcmp("client_streaming", argv[i])) { + client_streaming(config); + continue; + } + if (0 == strcmp("compressed_payload", argv[i])) { + compressed_payload(config); + continue; + } + if (0 == strcmp("connectivity", argv[i])) { + connectivity(config); + continue; + } + if (0 == strcmp("default_host", argv[i])) { + default_host(config); + continue; + } + if (0 == strcmp("disappearing_server", argv[i])) { + disappearing_server(config); + continue; + } + if (0 == strcmp("empty_batch", argv[i])) { + empty_batch(config); + continue; + } + if (0 == strcmp("filter_causes_close", argv[i])) { + filter_causes_close(config); + continue; + } + if (0 == strcmp("filter_context", argv[i])) { + filter_context(config); + continue; + } + if (0 == strcmp("filter_init_fails", argv[i])) { + filter_init_fails(config); + continue; + } + if (0 == strcmp("filter_latency", argv[i])) { + filter_latency(config); + continue; + } + if (0 == strcmp("filter_status_code", argv[i])) { + filter_status_code(config); + continue; + } + if (0 == strcmp("graceful_server_shutdown", argv[i])) { + graceful_server_shutdown(config); + continue; + } + if (0 == strcmp("high_initial_seqno", argv[i])) { + high_initial_seqno(config); + continue; + } + if (0 == strcmp("hpack_size", argv[i])) { + hpack_size(config); + continue; + } + if (0 == strcmp("idempotent_request", argv[i])) { + idempotent_request(config); + continue; + } + if (0 == strcmp("invoke_large_request", argv[i])) { + invoke_large_request(config); + continue; + } + if (0 == strcmp("keepalive_timeout", argv[i])) { + keepalive_timeout(config); + continue; + } + if (0 == strcmp("large_metadata", argv[i])) { + large_metadata(config); + continue; + } + if (0 == strcmp("max_concurrent_streams", argv[i])) { + max_concurrent_streams(config); + continue; + } + if (0 == strcmp("max_connection_age", argv[i])) { + max_connection_age(config); + continue; + } + if (0 == strcmp("max_connection_idle", argv[i])) { + max_connection_idle(config); + continue; + } + if (0 == strcmp("max_message_length", argv[i])) { + max_message_length(config); + continue; + } + if (0 == strcmp("negative_deadline", argv[i])) { + negative_deadline(config); + continue; + } + if (0 == strcmp("no_error_on_hotpath", argv[i])) { + no_error_on_hotpath(config); + continue; + } + if (0 == strcmp("no_logging", argv[i])) { + no_logging(config); + continue; + } + if (0 == strcmp("no_op", argv[i])) { + no_op(config); + continue; + } + if (0 == strcmp("payload", argv[i])) { + payload(config); + continue; + } + if (0 == strcmp("ping", argv[i])) { + ping(config); + continue; + } + if (0 == strcmp("ping_pong_streaming", argv[i])) { + ping_pong_streaming(config); + continue; + } + if (0 == strcmp("proxy_auth", argv[i])) { + proxy_auth(config); + continue; + } + if (0 == strcmp("registered_call", argv[i])) { + registered_call(config); + continue; + } + if (0 == strcmp("request_with_flags", 
argv[i])) { + request_with_flags(config); + continue; + } + if (0 == strcmp("request_with_payload", argv[i])) { + request_with_payload(config); + continue; + } + if (0 == strcmp("resource_quota_server", argv[i])) { + resource_quota_server(config); + continue; + } + if (0 == strcmp("retry", argv[i])) { + retry(config); + continue; + } + if (0 == strcmp("retry_cancel_during_delay", argv[i])) { + retry_cancel_during_delay(config); + continue; + } + if (0 == strcmp("retry_cancel_with_multiple_send_batches", argv[i])) { + retry_cancel_with_multiple_send_batches(config); + continue; + } + if (0 == strcmp("retry_cancellation", argv[i])) { + retry_cancellation(config); + continue; + } + if (0 == strcmp("retry_disabled", argv[i])) { + retry_disabled(config); + continue; + } + if (0 == strcmp("retry_exceeds_buffer_size_in_delay", argv[i])) { + retry_exceeds_buffer_size_in_delay(config); + continue; + } + if (0 == strcmp("retry_exceeds_buffer_size_in_initial_batch", argv[i])) { + retry_exceeds_buffer_size_in_initial_batch(config); + continue; + } + if (0 == strcmp("retry_exceeds_buffer_size_in_subsequent_batch", argv[i])) { + retry_exceeds_buffer_size_in_subsequent_batch(config); + continue; + } + if (0 == strcmp("retry_lb_drop", argv[i])) { + retry_lb_drop(config); + continue; + } + if (0 == strcmp("retry_lb_fail", argv[i])) { + retry_lb_fail(config); + continue; + } + if (0 == strcmp("retry_non_retriable_status", argv[i])) { + retry_non_retriable_status(config); + continue; + } + if (0 == strcmp("retry_non_retriable_status_before_recv_trailing_metadata_started", argv[i])) { + retry_non_retriable_status_before_recv_trailing_metadata_started(config); + continue; + } + if (0 == strcmp("retry_per_attempt_recv_timeout", argv[i])) { + retry_per_attempt_recv_timeout(config); + continue; + } + if (0 == strcmp("retry_per_attempt_recv_timeout_on_last_attempt", argv[i])) { + retry_per_attempt_recv_timeout_on_last_attempt(config); + continue; + } + if (0 == strcmp("retry_recv_initial_metadata", argv[i])) { + retry_recv_initial_metadata(config); + continue; + } + if (0 == strcmp("retry_recv_message", argv[i])) { + retry_recv_message(config); + continue; + } + if (0 == strcmp("retry_recv_trailing_metadata_error", argv[i])) { + retry_recv_trailing_metadata_error(config); + continue; + } + if (0 == strcmp("retry_send_initial_metadata_refs", argv[i])) { + retry_send_initial_metadata_refs(config); + continue; + } + if (0 == strcmp("retry_send_op_fails", argv[i])) { + retry_send_op_fails(config); + continue; + } + if (0 == strcmp("retry_server_pushback_delay", argv[i])) { + retry_server_pushback_delay(config); + continue; + } + if (0 == strcmp("retry_server_pushback_disabled", argv[i])) { + retry_server_pushback_disabled(config); + continue; + } + if (0 == strcmp("retry_streaming", argv[i])) { + retry_streaming(config); + continue; + } + if (0 == strcmp("retry_streaming_after_commit", argv[i])) { + retry_streaming_after_commit(config); + continue; + } + if (0 == strcmp("retry_streaming_succeeds_before_replay_finished", argv[i])) { + retry_streaming_succeeds_before_replay_finished(config); + continue; + } + if (0 == strcmp("retry_throttled", argv[i])) { + retry_throttled(config); + continue; + } + if (0 == strcmp("retry_too_many_attempts", argv[i])) { + retry_too_many_attempts(config); + continue; + } + if (0 == strcmp("server_finishes_request", argv[i])) { + server_finishes_request(config); + continue; + } + if (0 == strcmp("server_streaming", argv[i])) { + server_streaming(config); + continue; + } + if (0 == 
strcmp("shutdown_finishes_calls", argv[i])) { + shutdown_finishes_calls(config); + continue; + } + if (0 == strcmp("shutdown_finishes_tags", argv[i])) { + shutdown_finishes_tags(config); + continue; + } + if (0 == strcmp("simple_cacheable_request", argv[i])) { + simple_cacheable_request(config); + continue; + } + if (0 == strcmp("simple_delayed_request", argv[i])) { + simple_delayed_request(config); + continue; + } + if (0 == strcmp("simple_metadata", argv[i])) { + simple_metadata(config); + continue; + } + if (0 == strcmp("simple_request", argv[i])) { + simple_request(config); + continue; + } + if (0 == strcmp("stream_compression_compressed_payload", argv[i])) { + stream_compression_compressed_payload(config); + continue; + } + if (0 == strcmp("stream_compression_payload", argv[i])) { + stream_compression_payload(config); + continue; + } + if (0 == strcmp("stream_compression_ping_pong_streaming", argv[i])) { + stream_compression_ping_pong_streaming(config); + continue; + } + if (0 == strcmp("streaming_error_response", argv[i])) { + streaming_error_response(config); + continue; + } + if (0 == strcmp("trailing_metadata", argv[i])) { + trailing_metadata(config); + continue; + } + if (0 == strcmp("write_buffering", argv[i])) { + write_buffering(config); + continue; + } + if (0 == strcmp("write_buffering_at_end", argv[i])) { + write_buffering_at_end(config); + continue; + } + gpr_log(GPR_DEBUG, "not a test: '%s'", argv[i]); + abort(); + } +} diff --git a/test/core/end2end/end2end_test_utils.cc b/test/core/end2end/end2end_test_utils.cc new file mode 100644 index 00000000..c5927611 --- /dev/null +++ b/test/core/end2end/end2end_test_utils.cc @@ -0,0 +1,50 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "test/core/end2end/end2end_tests.h" + +const char* get_host_override_string(const char* str, + grpc_end2end_test_config config) { + if (config.feature_mask & FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER) { + return str; + } else { + return nullptr; + } +} + +const grpc_slice* get_host_override_slice(const char* str, + grpc_end2end_test_config config) { + const char* r = get_host_override_string(str, config); + if (r != nullptr) { + static grpc_slice ret; + ret = grpc_slice_from_static_string(r); + return &ret; + } + return nullptr; +} + +void validate_host_override_string(const char* pattern, grpc_slice str, + grpc_end2end_test_config config) { + if (config.feature_mask & FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER) { + GPR_ASSERT(0 == grpc_slice_str_cmp(str, pattern)); + } +} diff --git a/test/core/end2end/end2end_tests.cc b/test/core/end2end/end2end_tests.cc new file mode 100644 index 00000000..231e55c1 --- /dev/null +++ b/test/core/end2end/end2end_tests.cc @@ -0,0 +1,773 @@ + +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This file is auto-generated */ + +#include "test/core/end2end/end2end_tests.h" + +#include +#include + +#include + + +static bool g_pre_init_called = false; + +extern void authority_not_supported(grpc_end2end_test_config config); +extern void authority_not_supported_pre_init(void); +extern void bad_hostname(grpc_end2end_test_config config); +extern void bad_hostname_pre_init(void); +extern void bad_ping(grpc_end2end_test_config config); +extern void bad_ping_pre_init(void); +extern void binary_metadata(grpc_end2end_test_config config); +extern void binary_metadata_pre_init(void); +extern void call_creds(grpc_end2end_test_config config); +extern void call_creds_pre_init(void); +extern void call_host_override(grpc_end2end_test_config config); +extern void call_host_override_pre_init(void); +extern void cancel_after_accept(grpc_end2end_test_config config); +extern void cancel_after_accept_pre_init(void); +extern void cancel_after_client_done(grpc_end2end_test_config config); +extern void cancel_after_client_done_pre_init(void); +extern void cancel_after_invoke(grpc_end2end_test_config config); +extern void cancel_after_invoke_pre_init(void); +extern void cancel_after_round_trip(grpc_end2end_test_config config); +extern void cancel_after_round_trip_pre_init(void); +extern void cancel_before_invoke(grpc_end2end_test_config config); +extern void cancel_before_invoke_pre_init(void); +extern void cancel_in_a_vacuum(grpc_end2end_test_config config); +extern void cancel_in_a_vacuum_pre_init(void); +extern void cancel_with_status(grpc_end2end_test_config config); +extern void cancel_with_status_pre_init(void); +extern void channelz(grpc_end2end_test_config config); +extern void channelz_pre_init(void); +extern void client_streaming(grpc_end2end_test_config config); +extern void client_streaming_pre_init(void); +extern void compressed_payload(grpc_end2end_test_config config); +extern void compressed_payload_pre_init(void); +extern void connectivity(grpc_end2end_test_config config); +extern void connectivity_pre_init(void); +extern void default_host(grpc_end2end_test_config config); +extern void default_host_pre_init(void); +extern void disappearing_server(grpc_end2end_test_config config); +extern void disappearing_server_pre_init(void); +extern void empty_batch(grpc_end2end_test_config config); +extern void empty_batch_pre_init(void); +extern void filter_causes_close(grpc_end2end_test_config config); +extern void filter_causes_close_pre_init(void); +extern void filter_context(grpc_end2end_test_config config); +extern void filter_context_pre_init(void); +extern void filter_init_fails(grpc_end2end_test_config config); +extern void filter_init_fails_pre_init(void); +extern void filter_latency(grpc_end2end_test_config config); +extern void filter_latency_pre_init(void); +extern void filter_status_code(grpc_end2end_test_config config); +extern void filter_status_code_pre_init(void); +extern void graceful_server_shutdown(grpc_end2end_test_config config); +extern void graceful_server_shutdown_pre_init(void); +extern void high_initial_seqno(grpc_end2end_test_config config); 
+extern void high_initial_seqno_pre_init(void); +extern void hpack_size(grpc_end2end_test_config config); +extern void hpack_size_pre_init(void); +extern void idempotent_request(grpc_end2end_test_config config); +extern void idempotent_request_pre_init(void); +extern void invoke_large_request(grpc_end2end_test_config config); +extern void invoke_large_request_pre_init(void); +extern void keepalive_timeout(grpc_end2end_test_config config); +extern void keepalive_timeout_pre_init(void); +extern void large_metadata(grpc_end2end_test_config config); +extern void large_metadata_pre_init(void); +extern void max_concurrent_streams(grpc_end2end_test_config config); +extern void max_concurrent_streams_pre_init(void); +extern void max_connection_age(grpc_end2end_test_config config); +extern void max_connection_age_pre_init(void); +extern void max_connection_idle(grpc_end2end_test_config config); +extern void max_connection_idle_pre_init(void); +extern void max_message_length(grpc_end2end_test_config config); +extern void max_message_length_pre_init(void); +extern void negative_deadline(grpc_end2end_test_config config); +extern void negative_deadline_pre_init(void); +extern void no_error_on_hotpath(grpc_end2end_test_config config); +extern void no_error_on_hotpath_pre_init(void); +extern void no_logging(grpc_end2end_test_config config); +extern void no_logging_pre_init(void); +extern void no_op(grpc_end2end_test_config config); +extern void no_op_pre_init(void); +extern void payload(grpc_end2end_test_config config); +extern void payload_pre_init(void); +extern void ping(grpc_end2end_test_config config); +extern void ping_pre_init(void); +extern void ping_pong_streaming(grpc_end2end_test_config config); +extern void ping_pong_streaming_pre_init(void); +extern void proxy_auth(grpc_end2end_test_config config); +extern void proxy_auth_pre_init(void); +extern void registered_call(grpc_end2end_test_config config); +extern void registered_call_pre_init(void); +extern void request_with_flags(grpc_end2end_test_config config); +extern void request_with_flags_pre_init(void); +extern void request_with_payload(grpc_end2end_test_config config); +extern void request_with_payload_pre_init(void); +extern void resource_quota_server(grpc_end2end_test_config config); +extern void resource_quota_server_pre_init(void); +extern void retry(grpc_end2end_test_config config); +extern void retry_pre_init(void); +extern void retry_cancel_during_delay(grpc_end2end_test_config config); +extern void retry_cancel_during_delay_pre_init(void); +extern void retry_cancel_with_multiple_send_batches(grpc_end2end_test_config config); +extern void retry_cancel_with_multiple_send_batches_pre_init(void); +extern void retry_cancellation(grpc_end2end_test_config config); +extern void retry_cancellation_pre_init(void); +extern void retry_disabled(grpc_end2end_test_config config); +extern void retry_disabled_pre_init(void); +extern void retry_exceeds_buffer_size_in_delay(grpc_end2end_test_config config); +extern void retry_exceeds_buffer_size_in_delay_pre_init(void); +extern void retry_exceeds_buffer_size_in_initial_batch(grpc_end2end_test_config config); +extern void retry_exceeds_buffer_size_in_initial_batch_pre_init(void); +extern void retry_exceeds_buffer_size_in_subsequent_batch(grpc_end2end_test_config config); +extern void retry_exceeds_buffer_size_in_subsequent_batch_pre_init(void); +extern void retry_lb_drop(grpc_end2end_test_config config); +extern void retry_lb_drop_pre_init(void); +extern void retry_lb_fail(grpc_end2end_test_config 
config); +extern void retry_lb_fail_pre_init(void); +extern void retry_non_retriable_status(grpc_end2end_test_config config); +extern void retry_non_retriable_status_pre_init(void); +extern void retry_non_retriable_status_before_recv_trailing_metadata_started(grpc_end2end_test_config config); +extern void retry_non_retriable_status_before_recv_trailing_metadata_started_pre_init(void); +extern void retry_per_attempt_recv_timeout(grpc_end2end_test_config config); +extern void retry_per_attempt_recv_timeout_pre_init(void); +extern void retry_per_attempt_recv_timeout_on_last_attempt(grpc_end2end_test_config config); +extern void retry_per_attempt_recv_timeout_on_last_attempt_pre_init(void); +extern void retry_recv_initial_metadata(grpc_end2end_test_config config); +extern void retry_recv_initial_metadata_pre_init(void); +extern void retry_recv_message(grpc_end2end_test_config config); +extern void retry_recv_message_pre_init(void); +extern void retry_recv_trailing_metadata_error(grpc_end2end_test_config config); +extern void retry_recv_trailing_metadata_error_pre_init(void); +extern void retry_send_initial_metadata_refs(grpc_end2end_test_config config); +extern void retry_send_initial_metadata_refs_pre_init(void); +extern void retry_send_op_fails(grpc_end2end_test_config config); +extern void retry_send_op_fails_pre_init(void); +extern void retry_server_pushback_delay(grpc_end2end_test_config config); +extern void retry_server_pushback_delay_pre_init(void); +extern void retry_server_pushback_disabled(grpc_end2end_test_config config); +extern void retry_server_pushback_disabled_pre_init(void); +extern void retry_streaming(grpc_end2end_test_config config); +extern void retry_streaming_pre_init(void); +extern void retry_streaming_after_commit(grpc_end2end_test_config config); +extern void retry_streaming_after_commit_pre_init(void); +extern void retry_streaming_succeeds_before_replay_finished(grpc_end2end_test_config config); +extern void retry_streaming_succeeds_before_replay_finished_pre_init(void); +extern void retry_throttled(grpc_end2end_test_config config); +extern void retry_throttled_pre_init(void); +extern void retry_too_many_attempts(grpc_end2end_test_config config); +extern void retry_too_many_attempts_pre_init(void); +extern void sdk_authz(grpc_end2end_test_config config); +extern void sdk_authz_pre_init(void); +extern void server_finishes_request(grpc_end2end_test_config config); +extern void server_finishes_request_pre_init(void); +extern void server_streaming(grpc_end2end_test_config config); +extern void server_streaming_pre_init(void); +extern void shutdown_finishes_calls(grpc_end2end_test_config config); +extern void shutdown_finishes_calls_pre_init(void); +extern void shutdown_finishes_tags(grpc_end2end_test_config config); +extern void shutdown_finishes_tags_pre_init(void); +extern void simple_cacheable_request(grpc_end2end_test_config config); +extern void simple_cacheable_request_pre_init(void); +extern void simple_delayed_request(grpc_end2end_test_config config); +extern void simple_delayed_request_pre_init(void); +extern void simple_metadata(grpc_end2end_test_config config); +extern void simple_metadata_pre_init(void); +extern void simple_request(grpc_end2end_test_config config); +extern void simple_request_pre_init(void); +extern void stream_compression_compressed_payload(grpc_end2end_test_config config); +extern void stream_compression_compressed_payload_pre_init(void); +extern void stream_compression_payload(grpc_end2end_test_config config); +extern void 
stream_compression_payload_pre_init(void); +extern void stream_compression_ping_pong_streaming(grpc_end2end_test_config config); +extern void stream_compression_ping_pong_streaming_pre_init(void); +extern void streaming_error_response(grpc_end2end_test_config config); +extern void streaming_error_response_pre_init(void); +extern void trailing_metadata(grpc_end2end_test_config config); +extern void trailing_metadata_pre_init(void); +extern void write_buffering(grpc_end2end_test_config config); +extern void write_buffering_pre_init(void); +extern void write_buffering_at_end(grpc_end2end_test_config config); +extern void write_buffering_at_end_pre_init(void); + +void grpc_end2end_tests_pre_init(void) { + GPR_ASSERT(!g_pre_init_called); + g_pre_init_called = true; + authority_not_supported_pre_init(); + bad_hostname_pre_init(); + bad_ping_pre_init(); + binary_metadata_pre_init(); + call_creds_pre_init(); + call_host_override_pre_init(); + cancel_after_accept_pre_init(); + cancel_after_client_done_pre_init(); + cancel_after_invoke_pre_init(); + cancel_after_round_trip_pre_init(); + cancel_before_invoke_pre_init(); + cancel_in_a_vacuum_pre_init(); + cancel_with_status_pre_init(); + channelz_pre_init(); + client_streaming_pre_init(); + compressed_payload_pre_init(); + connectivity_pre_init(); + default_host_pre_init(); + disappearing_server_pre_init(); + empty_batch_pre_init(); + filter_causes_close_pre_init(); + filter_context_pre_init(); + filter_init_fails_pre_init(); + filter_latency_pre_init(); + filter_status_code_pre_init(); + graceful_server_shutdown_pre_init(); + high_initial_seqno_pre_init(); + hpack_size_pre_init(); + idempotent_request_pre_init(); + invoke_large_request_pre_init(); + keepalive_timeout_pre_init(); + large_metadata_pre_init(); + max_concurrent_streams_pre_init(); + max_connection_age_pre_init(); + max_connection_idle_pre_init(); + max_message_length_pre_init(); + negative_deadline_pre_init(); + no_error_on_hotpath_pre_init(); + no_logging_pre_init(); + no_op_pre_init(); + payload_pre_init(); + ping_pre_init(); + ping_pong_streaming_pre_init(); + proxy_auth_pre_init(); + registered_call_pre_init(); + request_with_flags_pre_init(); + request_with_payload_pre_init(); + resource_quota_server_pre_init(); + retry_pre_init(); + retry_cancel_during_delay_pre_init(); + retry_cancel_with_multiple_send_batches_pre_init(); + retry_cancellation_pre_init(); + retry_disabled_pre_init(); + retry_exceeds_buffer_size_in_delay_pre_init(); + retry_exceeds_buffer_size_in_initial_batch_pre_init(); + retry_exceeds_buffer_size_in_subsequent_batch_pre_init(); + retry_lb_drop_pre_init(); + retry_lb_fail_pre_init(); + retry_non_retriable_status_pre_init(); + retry_non_retriable_status_before_recv_trailing_metadata_started_pre_init(); + retry_per_attempt_recv_timeout_pre_init(); + retry_per_attempt_recv_timeout_on_last_attempt_pre_init(); + retry_recv_initial_metadata_pre_init(); + retry_recv_message_pre_init(); + retry_recv_trailing_metadata_error_pre_init(); + retry_send_initial_metadata_refs_pre_init(); + retry_send_op_fails_pre_init(); + retry_server_pushback_delay_pre_init(); + retry_server_pushback_disabled_pre_init(); + retry_streaming_pre_init(); + retry_streaming_after_commit_pre_init(); + retry_streaming_succeeds_before_replay_finished_pre_init(); + retry_throttled_pre_init(); + retry_too_many_attempts_pre_init(); + sdk_authz_pre_init(); + server_finishes_request_pre_init(); + server_streaming_pre_init(); + shutdown_finishes_calls_pre_init(); + shutdown_finishes_tags_pre_init(); + 
simple_cacheable_request_pre_init(); + simple_delayed_request_pre_init(); + simple_metadata_pre_init(); + simple_request_pre_init(); + stream_compression_compressed_payload_pre_init(); + stream_compression_payload_pre_init(); + stream_compression_ping_pong_streaming_pre_init(); + streaming_error_response_pre_init(); + trailing_metadata_pre_init(); + write_buffering_pre_init(); + write_buffering_at_end_pre_init(); +} + +// NOLINTNEXTLINE(readability-function-size) +void grpc_end2end_tests(int argc, char **argv, + grpc_end2end_test_config config) { + int i; + + GPR_ASSERT(g_pre_init_called); + + if (argc <= 1) { + authority_not_supported(config); + bad_hostname(config); + bad_ping(config); + binary_metadata(config); + call_creds(config); + call_host_override(config); + cancel_after_accept(config); + cancel_after_client_done(config); + cancel_after_invoke(config); + cancel_after_round_trip(config); + cancel_before_invoke(config); + cancel_in_a_vacuum(config); + cancel_with_status(config); + channelz(config); + client_streaming(config); + compressed_payload(config); + connectivity(config); + default_host(config); + disappearing_server(config); + empty_batch(config); + filter_causes_close(config); + filter_context(config); + filter_init_fails(config); + filter_latency(config); + filter_status_code(config); + graceful_server_shutdown(config); + high_initial_seqno(config); + hpack_size(config); + idempotent_request(config); + invoke_large_request(config); + keepalive_timeout(config); + large_metadata(config); + max_concurrent_streams(config); + max_connection_age(config); + max_connection_idle(config); + max_message_length(config); + negative_deadline(config); + no_error_on_hotpath(config); + no_logging(config); + no_op(config); + payload(config); + ping(config); + ping_pong_streaming(config); + proxy_auth(config); + registered_call(config); + request_with_flags(config); + request_with_payload(config); + resource_quota_server(config); + retry(config); + retry_cancel_during_delay(config); + retry_cancel_with_multiple_send_batches(config); + retry_cancellation(config); + retry_disabled(config); + retry_exceeds_buffer_size_in_delay(config); + retry_exceeds_buffer_size_in_initial_batch(config); + retry_exceeds_buffer_size_in_subsequent_batch(config); + retry_lb_drop(config); + retry_lb_fail(config); + retry_non_retriable_status(config); + retry_non_retriable_status_before_recv_trailing_metadata_started(config); + retry_per_attempt_recv_timeout(config); + retry_per_attempt_recv_timeout_on_last_attempt(config); + retry_recv_initial_metadata(config); + retry_recv_message(config); + retry_recv_trailing_metadata_error(config); + retry_send_initial_metadata_refs(config); + retry_send_op_fails(config); + retry_server_pushback_delay(config); + retry_server_pushback_disabled(config); + retry_streaming(config); + retry_streaming_after_commit(config); + retry_streaming_succeeds_before_replay_finished(config); + retry_throttled(config); + retry_too_many_attempts(config); + sdk_authz(config); + server_finishes_request(config); + server_streaming(config); + shutdown_finishes_calls(config); + shutdown_finishes_tags(config); + simple_cacheable_request(config); + simple_delayed_request(config); + simple_metadata(config); + simple_request(config); + stream_compression_compressed_payload(config); + stream_compression_payload(config); + stream_compression_ping_pong_streaming(config); + streaming_error_response(config); + trailing_metadata(config); + write_buffering(config); + write_buffering_at_end(config); + return; + 
} + + for (i = 1; i < argc; i++) { + if (0 == strcmp("authority_not_supported", argv[i])) { + authority_not_supported(config); + continue; + } + if (0 == strcmp("bad_hostname", argv[i])) { + bad_hostname(config); + continue; + } + if (0 == strcmp("bad_ping", argv[i])) { + bad_ping(config); + continue; + } + if (0 == strcmp("binary_metadata", argv[i])) { + binary_metadata(config); + continue; + } + if (0 == strcmp("call_creds", argv[i])) { + call_creds(config); + continue; + } + if (0 == strcmp("call_host_override", argv[i])) { + call_host_override(config); + continue; + } + if (0 == strcmp("cancel_after_accept", argv[i])) { + cancel_after_accept(config); + continue; + } + if (0 == strcmp("cancel_after_client_done", argv[i])) { + cancel_after_client_done(config); + continue; + } + if (0 == strcmp("cancel_after_invoke", argv[i])) { + cancel_after_invoke(config); + continue; + } + if (0 == strcmp("cancel_after_round_trip", argv[i])) { + cancel_after_round_trip(config); + continue; + } + if (0 == strcmp("cancel_before_invoke", argv[i])) { + cancel_before_invoke(config); + continue; + } + if (0 == strcmp("cancel_in_a_vacuum", argv[i])) { + cancel_in_a_vacuum(config); + continue; + } + if (0 == strcmp("cancel_with_status", argv[i])) { + cancel_with_status(config); + continue; + } + if (0 == strcmp("channelz", argv[i])) { + channelz(config); + continue; + } + if (0 == strcmp("client_streaming", argv[i])) { + client_streaming(config); + continue; + } + if (0 == strcmp("compressed_payload", argv[i])) { + compressed_payload(config); + continue; + } + if (0 == strcmp("connectivity", argv[i])) { + connectivity(config); + continue; + } + if (0 == strcmp("default_host", argv[i])) { + default_host(config); + continue; + } + if (0 == strcmp("disappearing_server", argv[i])) { + disappearing_server(config); + continue; + } + if (0 == strcmp("empty_batch", argv[i])) { + empty_batch(config); + continue; + } + if (0 == strcmp("filter_causes_close", argv[i])) { + filter_causes_close(config); + continue; + } + if (0 == strcmp("filter_context", argv[i])) { + filter_context(config); + continue; + } + if (0 == strcmp("filter_init_fails", argv[i])) { + filter_init_fails(config); + continue; + } + if (0 == strcmp("filter_latency", argv[i])) { + filter_latency(config); + continue; + } + if (0 == strcmp("filter_status_code", argv[i])) { + filter_status_code(config); + continue; + } + if (0 == strcmp("graceful_server_shutdown", argv[i])) { + graceful_server_shutdown(config); + continue; + } + if (0 == strcmp("high_initial_seqno", argv[i])) { + high_initial_seqno(config); + continue; + } + if (0 == strcmp("hpack_size", argv[i])) { + hpack_size(config); + continue; + } + if (0 == strcmp("idempotent_request", argv[i])) { + idempotent_request(config); + continue; + } + if (0 == strcmp("invoke_large_request", argv[i])) { + invoke_large_request(config); + continue; + } + if (0 == strcmp("keepalive_timeout", argv[i])) { + keepalive_timeout(config); + continue; + } + if (0 == strcmp("large_metadata", argv[i])) { + large_metadata(config); + continue; + } + if (0 == strcmp("max_concurrent_streams", argv[i])) { + max_concurrent_streams(config); + continue; + } + if (0 == strcmp("max_connection_age", argv[i])) { + max_connection_age(config); + continue; + } + if (0 == strcmp("max_connection_idle", argv[i])) { + max_connection_idle(config); + continue; + } + if (0 == strcmp("max_message_length", argv[i])) { + max_message_length(config); + continue; + } + if (0 == strcmp("negative_deadline", argv[i])) { + negative_deadline(config); + 
continue; + } + if (0 == strcmp("no_error_on_hotpath", argv[i])) { + no_error_on_hotpath(config); + continue; + } + if (0 == strcmp("no_logging", argv[i])) { + no_logging(config); + continue; + } + if (0 == strcmp("no_op", argv[i])) { + no_op(config); + continue; + } + if (0 == strcmp("payload", argv[i])) { + payload(config); + continue; + } + if (0 == strcmp("ping", argv[i])) { + ping(config); + continue; + } + if (0 == strcmp("ping_pong_streaming", argv[i])) { + ping_pong_streaming(config); + continue; + } + if (0 == strcmp("proxy_auth", argv[i])) { + proxy_auth(config); + continue; + } + if (0 == strcmp("registered_call", argv[i])) { + registered_call(config); + continue; + } + if (0 == strcmp("request_with_flags", argv[i])) { + request_with_flags(config); + continue; + } + if (0 == strcmp("request_with_payload", argv[i])) { + request_with_payload(config); + continue; + } + if (0 == strcmp("resource_quota_server", argv[i])) { + resource_quota_server(config); + continue; + } + if (0 == strcmp("retry", argv[i])) { + retry(config); + continue; + } + if (0 == strcmp("retry_cancel_during_delay", argv[i])) { + retry_cancel_during_delay(config); + continue; + } + if (0 == strcmp("retry_cancel_with_multiple_send_batches", argv[i])) { + retry_cancel_with_multiple_send_batches(config); + continue; + } + if (0 == strcmp("retry_cancellation", argv[i])) { + retry_cancellation(config); + continue; + } + if (0 == strcmp("retry_disabled", argv[i])) { + retry_disabled(config); + continue; + } + if (0 == strcmp("retry_exceeds_buffer_size_in_delay", argv[i])) { + retry_exceeds_buffer_size_in_delay(config); + continue; + } + if (0 == strcmp("retry_exceeds_buffer_size_in_initial_batch", argv[i])) { + retry_exceeds_buffer_size_in_initial_batch(config); + continue; + } + if (0 == strcmp("retry_exceeds_buffer_size_in_subsequent_batch", argv[i])) { + retry_exceeds_buffer_size_in_subsequent_batch(config); + continue; + } + if (0 == strcmp("retry_lb_drop", argv[i])) { + retry_lb_drop(config); + continue; + } + if (0 == strcmp("retry_lb_fail", argv[i])) { + retry_lb_fail(config); + continue; + } + if (0 == strcmp("retry_non_retriable_status", argv[i])) { + retry_non_retriable_status(config); + continue; + } + if (0 == strcmp("retry_non_retriable_status_before_recv_trailing_metadata_started", argv[i])) { + retry_non_retriable_status_before_recv_trailing_metadata_started(config); + continue; + } + if (0 == strcmp("retry_per_attempt_recv_timeout", argv[i])) { + retry_per_attempt_recv_timeout(config); + continue; + } + if (0 == strcmp("retry_per_attempt_recv_timeout_on_last_attempt", argv[i])) { + retry_per_attempt_recv_timeout_on_last_attempt(config); + continue; + } + if (0 == strcmp("retry_recv_initial_metadata", argv[i])) { + retry_recv_initial_metadata(config); + continue; + } + if (0 == strcmp("retry_recv_message", argv[i])) { + retry_recv_message(config); + continue; + } + if (0 == strcmp("retry_recv_trailing_metadata_error", argv[i])) { + retry_recv_trailing_metadata_error(config); + continue; + } + if (0 == strcmp("retry_send_initial_metadata_refs", argv[i])) { + retry_send_initial_metadata_refs(config); + continue; + } + if (0 == strcmp("retry_send_op_fails", argv[i])) { + retry_send_op_fails(config); + continue; + } + if (0 == strcmp("retry_server_pushback_delay", argv[i])) { + retry_server_pushback_delay(config); + continue; + } + if (0 == strcmp("retry_server_pushback_disabled", argv[i])) { + retry_server_pushback_disabled(config); + continue; + } + if (0 == strcmp("retry_streaming", argv[i])) { + 
retry_streaming(config); + continue; + } + if (0 == strcmp("retry_streaming_after_commit", argv[i])) { + retry_streaming_after_commit(config); + continue; + } + if (0 == strcmp("retry_streaming_succeeds_before_replay_finished", argv[i])) { + retry_streaming_succeeds_before_replay_finished(config); + continue; + } + if (0 == strcmp("retry_throttled", argv[i])) { + retry_throttled(config); + continue; + } + if (0 == strcmp("retry_too_many_attempts", argv[i])) { + retry_too_many_attempts(config); + continue; + } + if (0 == strcmp("sdk_authz", argv[i])) { + sdk_authz(config); + continue; + } + if (0 == strcmp("server_finishes_request", argv[i])) { + server_finishes_request(config); + continue; + } + if (0 == strcmp("server_streaming", argv[i])) { + server_streaming(config); + continue; + } + if (0 == strcmp("shutdown_finishes_calls", argv[i])) { + shutdown_finishes_calls(config); + continue; + } + if (0 == strcmp("shutdown_finishes_tags", argv[i])) { + shutdown_finishes_tags(config); + continue; + } + if (0 == strcmp("simple_cacheable_request", argv[i])) { + simple_cacheable_request(config); + continue; + } + if (0 == strcmp("simple_delayed_request", argv[i])) { + simple_delayed_request(config); + continue; + } + if (0 == strcmp("simple_metadata", argv[i])) { + simple_metadata(config); + continue; + } + if (0 == strcmp("simple_request", argv[i])) { + simple_request(config); + continue; + } + if (0 == strcmp("stream_compression_compressed_payload", argv[i])) { + stream_compression_compressed_payload(config); + continue; + } + if (0 == strcmp("stream_compression_payload", argv[i])) { + stream_compression_payload(config); + continue; + } + if (0 == strcmp("stream_compression_ping_pong_streaming", argv[i])) { + stream_compression_ping_pong_streaming(config); + continue; + } + if (0 == strcmp("streaming_error_response", argv[i])) { + streaming_error_response(config); + continue; + } + if (0 == strcmp("trailing_metadata", argv[i])) { + trailing_metadata(config); + continue; + } + if (0 == strcmp("write_buffering", argv[i])) { + write_buffering(config); + continue; + } + if (0 == strcmp("write_buffering_at_end", argv[i])) { + write_buffering_at_end(config); + continue; + } + gpr_log(GPR_DEBUG, "not a test: '%s'", argv[i]); + abort(); + } +} diff --git a/test/core/end2end/engine_passthrough.cc b/test/core/end2end/engine_passthrough.cc new file mode 100644 index 00000000..79c30883 --- /dev/null +++ b/test/core/end2end/engine_passthrough.cc @@ -0,0 +1,73 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This is a sample openSSL engine which tests the openSSL +// engine plugability with gRPC. +// This sample engine expects KeyId to be actual PEM encoded +// key itself and just calls standard openSSL functions. 
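For context on how a test might exercise an engine like this, here is a minimal sketch, assuming gRPC is built against OpenSSL (the engine path is compiled out under BoringSSL) and that the TLS stack accepts engine-prefixed private keys of the form engine:<engine_id>:<key_id>, where this passthrough engine treats the key_id as the PEM key itself; the helper name, target, and PEM placeholder parameters below are illustrative assumptions, not part of this patch.

// Illustrative only: routes a PEM private key through the passthrough engine
// by prefixing it with "engine:libengine_passthrough:" (assumed key format).
#include <string>

#include <grpc/grpc.h>
#include <grpc/grpc_security.h>

grpc_channel* create_channel_with_engine_key(const char* target,
                                             const char* pem_root_certs,
                                             const char* pem_cert_chain,
                                             const char* pem_private_key) {
  // Hand the key to the SSL stack via the engine rather than as raw PEM.
  std::string engine_key =
      std::string("engine:libengine_passthrough:") + pem_private_key;
  grpc_ssl_pem_key_cert_pair pair = {engine_key.c_str(), pem_cert_chain};
  grpc_channel_credentials* creds =
      grpc_ssl_credentials_create(pem_root_certs, &pair, nullptr, nullptr);
  grpc_channel* channel =
      grpc_secure_channel_create(creds, target, nullptr, nullptr);
  grpc_channel_credentials_release(creds);
  return channel;
}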
+ +#include +#include +#include + +#ifndef OPENSSL_IS_BORINGSSL + +#include +#include + +extern "C" { +static const char engine_id[] = "libengine_passthrough"; +static const char engine_name[] = "A passthrough engine for private keys"; +static int e_passthrough_idx = -1; + +static int e_passthrough_init(ENGINE* e) { + if (e_passthrough_idx < 0) { + e_passthrough_idx = ENGINE_get_ex_new_index(0, NULL, NULL, NULL, 0); + if (e_passthrough_idx < 0) return 0; + } + return 1; +} + +EVP_PKEY* e_passthrough_load_privkey(ENGINE* eng, const char* key_id, + UI_METHOD* ui_method, + void* callback_data) { + EVP_PKEY* pkey = NULL; + BIO* pem = BIO_new_mem_buf((void*)key_id, (int)(strlen(key_id))); + if (pem == NULL) return NULL; + pkey = PEM_read_bio_PrivateKey(pem, NULL, NULL, (void*)""); + BIO_free(pem); + return pkey; +} + +int passthrough_bind_helper(ENGINE* e, const char* id) { + if (id && strcmp(id, engine_id)) { + return 0; + } + if (!ENGINE_set_id(e, engine_id) || !ENGINE_set_name(e, engine_name) || + !ENGINE_set_flags(e, ENGINE_FLAGS_NO_REGISTER_ALL) || + !ENGINE_set_init_function(e, e_passthrough_init) || + !ENGINE_set_load_privkey_function(e, e_passthrough_load_privkey)) { + return 0; + } + return 1; +} + +IMPLEMENT_DYNAMIC_BIND_FN(passthrough_bind_helper) +IMPLEMENT_DYNAMIC_CHECK_FN() +} +#endif // OPENSSL_IS_BORINGSSL diff --git a/test/core/end2end/fixtures/h2_census.cc b/test/core/end2end/fixtures/h2_census.cc new file mode 100644 index 00000000..6ab374e7 --- /dev/null +++ b/test/core/end2end/fixtures/h2_census.cc @@ -0,0 +1,129 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include 
+
+#include 
+#include 
+#include 
+
+#include "src/core/ext/filters/client_channel/client_channel.h"
+#include "src/core/ext/filters/http/server/http_server_filter.h"
+#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h"
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/channel/connected_channel.h"
+#include "src/core/lib/gprpp/host_port.h"
+#include "src/core/lib/surface/channel.h"
+#include "src/core/lib/surface/server.h"
+#include "test/core/end2end/end2end_tests.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+
+struct fullstack_fixture_data {
+  std::string localaddr;
+};
+
+static grpc_end2end_test_fixture chttp2_create_fixture_fullstack(
+    grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) {
+  grpc_end2end_test_fixture f;
+  int port = grpc_pick_unused_port_or_die();
+  fullstack_fixture_data* ffd = new fullstack_fixture_data();
+  memset(&f, 0, sizeof(f));
+  ffd->localaddr = grpc_core::JoinHostPort("localhost", port);
+
+  f.fixture_data = ffd;
+  f.cq = grpc_completion_queue_create_for_next(nullptr);
+  f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr);
+
+  return f;
+}
+
+static grpc_arg make_census_enable_arg(void) {
+  grpc_arg arg;
+  arg.type = GRPC_ARG_INTEGER;
+  arg.key = const_cast<char*>(GRPC_ARG_ENABLE_CENSUS);
+  arg.value.integer = 1;
+  return arg;
+}
+
+void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f,
+                                  grpc_channel_args* client_args) {
+  fullstack_fixture_data* ffd =
+      static_cast<fullstack_fixture_data*>(f->fixture_data);
+  grpc_arg arg = make_census_enable_arg();
+  client_args = grpc_channel_args_copy_and_add(client_args, &arg, 1);
+  f->client = grpc_insecure_channel_create(ffd->localaddr.c_str(), client_args,
+                                           nullptr);
+  GPR_ASSERT(f->client);
+  {
+    grpc_core::ExecCtx exec_ctx;
+    grpc_channel_args_destroy(client_args);
+  }
+}
+
+void chttp2_init_server_fullstack(grpc_end2end_test_fixture* f,
+                                  grpc_channel_args* server_args) {
+  fullstack_fixture_data* ffd =
+      static_cast<fullstack_fixture_data*>(f->fixture_data);
+  grpc_arg arg = make_census_enable_arg();
+  if (f->server) {
+    grpc_server_destroy(f->server);
+  }
+  server_args = grpc_channel_args_copy_and_add(server_args, &arg, 1);
+  f->server = grpc_server_create(server_args, nullptr);
+  {
+    grpc_core::ExecCtx exec_ctx;
+    grpc_channel_args_destroy(server_args);
+  }
+  grpc_server_register_completion_queue(f->server, f->cq, nullptr);
+  GPR_ASSERT(
+      grpc_server_add_insecure_http2_port(f->server, ffd->localaddr.c_str()));
+  grpc_server_start(f->server);
+}
+
+void chttp2_tear_down_fullstack(grpc_end2end_test_fixture* f) {
+  fullstack_fixture_data* ffd =
+      static_cast<fullstack_fixture_data*>(f->fixture_data);
+  delete ffd;
+}
+
+/* All test configurations */
+static grpc_end2end_test_config configs[] = {
+    {"chttp2/fullstack+census",
+     FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION |
+         FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL |
+         FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER,
+     nullptr, chttp2_create_fixture_fullstack, chttp2_init_client_fullstack,
+     chttp2_init_server_fullstack, chttp2_tear_down_fullstack},
+};
+
+int main(int argc, char** argv) {
+  size_t i;
+
+  grpc::testing::TestEnvironment env(argc, argv);
+  grpc_end2end_tests_pre_init();
+  grpc_init();
+
+  for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) {
+    grpc_end2end_tests(argc, argv, configs[i]);
+  }
+
+  grpc_shutdown();
+
+  return 0;
+}
diff --git a/test/core/end2end/fixtures/h2_compress.cc b/test/core/end2end/fixtures/h2_compress.cc
new file mode 100644
index 00000000..f0de5c6b
--- /dev/null
+++ 
b/test/core/end2end/fixtures/h2_compress.cc @@ -0,0 +1,133 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +struct fullstack_compression_fixture_data { + ~fullstack_compression_fixture_data() { + grpc_channel_args_destroy(client_args_compression); + grpc_channel_args_destroy(server_args_compression); + } + std::string localaddr; + grpc_channel_args* client_args_compression = nullptr; + grpc_channel_args* server_args_compression = nullptr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack_compression( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_compression_fixture_data* ffd = + new fullstack_compression_fixture_data(); + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + + memset(&f, 0, sizeof(f)); + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void chttp2_init_client_fullstack_compression(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + fullstack_compression_fixture_data* ffd = + static_cast(f->fixture_data); + if (ffd->client_args_compression != nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(ffd->client_args_compression); + } + ffd->client_args_compression = + grpc_channel_args_set_channel_default_compression_algorithm( + client_args, GRPC_COMPRESS_GZIP); + f->client = grpc_insecure_channel_create( + ffd->localaddr.c_str(), ffd->client_args_compression, nullptr); +} + +void chttp2_init_server_fullstack_compression(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + fullstack_compression_fixture_data* ffd = + static_cast(f->fixture_data); + if (ffd->server_args_compression != nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(ffd->server_args_compression); + } + ffd->server_args_compression = + grpc_channel_args_set_channel_default_compression_algorithm( + server_args, GRPC_COMPRESS_GZIP); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(ffd->server_args_compression, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT( + 
grpc_server_add_insecure_http2_port(f->server, ffd->localaddr.c_str())); + grpc_server_start(f->server); +} + +void chttp2_tear_down_fullstack_compression(grpc_end2end_test_fixture* f) { + grpc_core::ExecCtx exec_ctx; + fullstack_compression_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack_compression", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, chttp2_create_fixture_fullstack_compression, + chttp2_init_client_fullstack_compression, + chttp2_init_server_fullstack_compression, + chttp2_tear_down_fullstack_compression}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_fakesec.cc b/test/core/end2end/fixtures/h2_fakesec.cc new file mode 100644 index 00000000..156185aa --- /dev/null +++ b/test/core/end2end/fixtures/h2_fakesec.cc @@ -0,0 +1,152 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +struct fullstack_secure_fixture_data { + std::string localaddr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_secure_fixture_data* ffd = new fullstack_secure_fixture_data(); + memset(&f, 0, sizeof(f)); + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, + size_t /*md_count*/, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +static void chttp2_init_client_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args, + grpc_channel_credentials* creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +static void chttp2_init_server_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_server_credentials* server_creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, ffd->localaddr.c_str(), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void chttp2_tear_down_secure_fullstack(grpc_end2end_test_fixture* f) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +static void chttp2_init_client_fake_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args) { + grpc_channel_credentials* fake_ts_creds = + grpc_fake_transport_security_credentials_create(); + chttp2_init_client_secure_fullstack(f, client_args, fake_ts_creds); +} + +static int fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return 0; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return 1; + } + } + return 0; +} + +static void chttp2_init_server_fake_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args) { + grpc_server_credentials* fake_ts_creds = + grpc_fake_transport_security_server_credentials_create(); + if (fail_server_auth_check(server_args)) { + grpc_auth_metadata_processor processor = {process_auth_failure, nullptr, + nullptr}; + grpc_server_credentials_set_auth_metadata_processor(fake_ts_creds, + processor); + } + chttp2_init_server_secure_fullstack(f, server_args, fake_ts_creds); +} + +/* All test 
configurations */ + +static grpc_end2end_test_config configs[] = { + {"chttp2/fake_secure_fullstack", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS_LEVEL_INSECURE, + nullptr, chttp2_create_fixture_secure_fullstack, + chttp2_init_client_fake_secure_fullstack, + chttp2_init_server_fake_secure_fullstack, + chttp2_tear_down_secure_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_fd.cc b/test/core/end2end/fixtures/h2_fd.cc new file mode 100644 index 00000000..bd977c9f --- /dev/null +++ b/test/core/end2end/fixtures/h2_fd.cc @@ -0,0 +1,124 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/test_config.h" + +typedef struct { + int fd_pair[2]; +} sp_fixture_data; + +static void create_sockets(int sv[2]) { + int flags; + grpc_create_socketpair_if_unix(sv); + flags = fcntl(sv[0], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[0], F_SETFL, flags | O_NONBLOCK) == 0); + flags = fcntl(sv[1], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[1], F_SETFL, flags | O_NONBLOCK) == 0); + GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv[0]) == GRPC_ERROR_NONE); + GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv[1]) == GRPC_ERROR_NONE); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_socketpair( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + sp_fixture_data* fixture_data = + static_cast(gpr_malloc(sizeof(*fixture_data))); + + grpc_end2end_test_fixture f; + memset(&f, 0, sizeof(f)); + f.fixture_data = fixture_data; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + create_sockets(fixture_data->fd_pair); + + return f; +} + +static void chttp2_init_client_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_core::ExecCtx exec_ctx; + sp_fixture_data* sfd = static_cast(f->fixture_data); + + GPR_ASSERT(!f->client); + f->client = grpc_insecure_channel_create_from_fd( + "fixture_client", sfd->fd_pair[0], client_args); + GPR_ASSERT(f->client); +} + +static void chttp2_init_server_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + grpc_core::ExecCtx exec_ctx; + 
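+ // Note: this fixture never binds a port. The server below is created and
+ // started first, and the already-connected socketpair fd is then handed to
+ // it via grpc_server_add_insecure_channel_from_fd; the stack ExecCtx above
+ // is presumably there so that closures scheduled while the endpoint is
+ // wired up are flushed before returning.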
sp_fixture_data* sfd = static_cast<sp_fixture_data*>(f->fixture_data); + GPR_ASSERT(!f->server); + f->server = grpc_server_create(server_args, nullptr); + GPR_ASSERT(f->server); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + grpc_server_start(f->server); + + grpc_server_add_insecure_channel_from_fd(f->server, nullptr, sfd->fd_pair[1]); +} + +static void chttp2_tear_down_socketpair(grpc_end2end_test_fixture* f) { + gpr_free(f->fixture_data); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fd", FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, nullptr, + chttp2_create_fixture_socketpair, chttp2_init_client_socketpair, + chttp2_init_server_socketpair, chttp2_tear_down_socketpair}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} + +#else /* GRPC_POSIX_SOCKET */ + +int main(int /* argc */, char** /* argv */) { return 1; } + +#endif /* GRPC_POSIX_SOCKET */ diff --git a/test/core/end2end/fixtures/h2_full+pipe.cc b/test/core/end2end/fixtures/h2_full+pipe.cc new file mode 100644 index 00000000..f26c390d --- /dev/null +++ b/test/core/end2end/fixtures/h2_full+pipe.cc @@ -0,0 +1,123 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test requires posix wakeup fds +#ifdef GRPC_POSIX_WAKEUP_FD + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +struct fullstack_fixture_data { + std::string localaddr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_fixture_data* ffd = new fullstack_fixture_data(); + memset(&f, 0, sizeof(f)); + + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_insecure_channel_create(ffd->localaddr.c_str(), client_args, + nullptr); + GPR_ASSERT(f->client); +} + +void chttp2_init_server_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(f->server, ffd->localaddr.c_str())); + grpc_server_start(f->server); +} + +void chttp2_tear_down_fullstack(grpc_end2end_test_fixture* f) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, chttp2_create_fixture_fullstack, chttp2_init_client_fullstack, + chttp2_init_server_fullstack, chttp2_tear_down_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc_allow_specialized_wakeup_fd = 0; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} + +#else /* GRPC_POSIX_WAKEUP_FD */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_WAKEUP_FD */ diff --git a/test/core/end2end/fixtures/h2_full+trace.cc b/test/core/end2end/fixtures/h2_full+trace.cc new file mode 100644 index 00000000..48085e93 --- /dev/null +++ b/test/core/end2end/fixtures/h2_full+trace.cc @@ -0,0 +1,137 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#ifdef GRPC_POSIX_SOCKET +#include +#endif + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +struct fullstack_fixture_data { + std::string localaddr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_fixture_data* ffd = new fullstack_fixture_data(); + memset(&f, 0, sizeof(f)); + + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_insecure_channel_create(ffd->localaddr.c_str(), client_args, + nullptr); + GPR_ASSERT(f->client); +} + +void chttp2_init_server_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(f->server, ffd->localaddr.c_str())); + grpc_server_start(f->server); +} + +void chttp2_tear_down_fullstack(grpc_end2end_test_fixture* f) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, chttp2_create_fixture_fullstack, chttp2_init_client_fullstack, + chttp2_init_server_fullstack, chttp2_tear_down_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + + /* force tracing on, with a value to force many + code paths in trace.c to be taken */ + GPR_GLOBAL_CONFIG_SET(grpc_trace, "doesnt-exist,http,all"); + +#ifdef GRPC_POSIX_SOCKET + g_fixture_slowdown_factor = isatty(STDOUT_FILENO) ? 10 : 1; +#else + g_fixture_slowdown_factor = 10; +#endif + +#ifdef GPR_WINDOWS + /* on Windows, writing logs to stderr is very slow + when stderr is redirected to a disk file. 
+ The "trace" tests fixtures generates large amount + of logs, so setting a buffer for stderr prevents certain + test cases from timing out. */ + setvbuf(stderr, NULL, _IOLBF, 1024); +#endif + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + GPR_ASSERT(0 == grpc_tracer_set_enabled("also-doesnt-exist", 0)); + GPR_ASSERT(1 == grpc_tracer_set_enabled("http", 1)); + GPR_ASSERT(1 == grpc_tracer_set_enabled("all", 1)); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_full.cc b/test/core/end2end/fixtures/h2_full.cc new file mode 100644 index 00000000..074ec7f8 --- /dev/null +++ b/test/core/end2end/fixtures/h2_full.cc @@ -0,0 +1,109 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +struct fullstack_fixture_data { + std::string localaddr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_fixture_data* ffd = new fullstack_fixture_data(); + memset(&f, 0, sizeof(f)); + + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_insecure_channel_create(ffd->localaddr.c_str(), client_args, + nullptr); + GPR_ASSERT(f->client); +} + +void chttp2_init_server_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(f->server, ffd->localaddr.c_str())); + grpc_server_start(f->server); +} + +void chttp2_tear_down_fullstack(grpc_end2end_test_fixture* f) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +/* All test configurations */ +static 
grpc_end2end_test_config configs[] = { + {"chttp2/fullstack", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, chttp2_create_fixture_fullstack, chttp2_init_client_fullstack, + chttp2_init_server_fullstack, chttp2_tear_down_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_http_proxy.cc b/test/core/end2end/fixtures/h2_http_proxy.cc new file mode 100644 index 00000000..a4b389c8 --- /dev/null +++ b/test/core/end2end/fixtures/h2_http_proxy.cc @@ -0,0 +1,136 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/fixtures/http_proxy_fixture.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +struct fullstack_fixture_data { + ~fullstack_fixture_data() { grpc_end2end_http_proxy_destroy(proxy); } + std::string server_addr; + grpc_end2end_http_proxy* proxy = nullptr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack( + grpc_channel_args* client_args, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + memset(&f, 0, sizeof(f)); + fullstack_fixture_data* ffd = new fullstack_fixture_data(); + const int server_port = grpc_pick_unused_port_or_die(); + ffd->server_addr = grpc_core::JoinHostPort("localhost", server_port); + + /* Passing client_args to proxy_create for the case of checking for proxy auth + */ + ffd->proxy = grpc_end2end_http_proxy_create(client_args); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + /* If testing for proxy auth, add credentials to proxy uri */ + const char* proxy_auth_str = grpc_channel_args_find_string( + client_args, GRPC_ARG_HTTP_PROXY_AUTH_CREDS); + std::string proxy_uri; + if (proxy_auth_str == nullptr) { + proxy_uri = absl::StrFormat( + "http://%s", 
grpc_end2end_http_proxy_get_proxy_name(ffd->proxy)); + } else { + proxy_uri = + absl::StrFormat("http://%s@%s", proxy_auth_str, + grpc_end2end_http_proxy_get_proxy_name(ffd->proxy)); + } + gpr_setenv("http_proxy", proxy_uri.c_str()); + grpc_channel_credentials* creds = grpc_insecure_credentials_create(); + f->client = grpc_secure_channel_create(creds, ffd->server_addr.c_str(), + client_args, nullptr); + grpc_channel_credentials_release(creds); + GPR_ASSERT(f->client); +} + +void chttp2_init_server_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(f->server, ffd->server_addr.c_str())); + grpc_server_start(f->server); +} + +void chttp2_tear_down_fullstack(grpc_end2end_test_fixture* f) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, chttp2_create_fixture_fullstack, chttp2_init_client_fullstack, + chttp2_init_server_fullstack, chttp2_tear_down_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_insecure.cc b/test/core/end2end/fixtures/h2_insecure.cc new file mode 100644 index 00000000..13ac3e8f --- /dev/null +++ b/test/core/end2end/fixtures/h2_insecure.cc @@ -0,0 +1,126 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace { + +struct Chttp2InsecureFullstackFixtureData { + std::string localaddr; +}; + +grpc_end2end_test_fixture Chttp2CreateFixtureInsecureFullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + Chttp2InsecureFullstackFixtureData* ffd = + new Chttp2InsecureFullstackFixtureData(); + memset(&f, 0, sizeof(f)); + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void Chttp2InitClientInsecureFullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + Chttp2InsecureFullstackFixtureData* ffd = + static_cast(f->fixture_data); + grpc_channel_credentials* creds = grpc_insecure_credentials_create(); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + grpc_channel_credentials_release(creds); + GPR_ASSERT(f->client); +} + +void ProcessAuthFailure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, size_t /*md_count*/, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +void Chttp2InitServerInsecureFullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + Chttp2InsecureFullstackFixtureData* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + grpc_server_credentials* server_creds = + grpc_insecure_server_credentials_create(); + if (grpc_channel_args_find(server_args, FAIL_AUTH_CHECK_SERVER_ARG_NAME) != + nullptr) { + grpc_auth_metadata_processor processor = {ProcessAuthFailure, nullptr, + nullptr}; + grpc_server_credentials_set_auth_metadata_processor(server_creds, + processor); + } + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, ffd->localaddr.c_str(), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void Chttp2TearDownInsecureFullstack(grpc_end2end_test_fixture* f) { + Chttp2InsecureFullstackFixtureData* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +/* All test configurations */ +grpc_end2end_test_config configs[] = { + {"chttp2/insecure_fullstack", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS_LEVEL_INSECURE, + nullptr, Chttp2CreateFixtureInsecureFullstack, + Chttp2InitClientInsecureFullstack, Chttp2InitServerInsecureFullstack, + Chttp2TearDownInsecureFullstack}, +}; + +} // namespace + +int main(int argc, char** argv) { + size_t i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_local_ipv4.cc 
b/test/core/end2end/fixtures/h2_local_ipv4.cc new file mode 100644 index 00000000..f1748078 --- /dev/null +++ b/test/core/end2end/fixtures/h2_local_ipv4.cc @@ -0,0 +1,70 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/fixtures/local_util.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack_ipv4( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f = + grpc_end2end_local_chttp2_create_fixture_fullstack(); + int port = grpc_pick_unused_port_or_die(); + static_cast(f.fixture_data) + ->localaddr = grpc_core::JoinHostPort("127.0.0.1", port); + return f; +} + +static void chttp2_init_client_fullstack_ipv4(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_end2end_local_chttp2_init_client_fullstack(f, client_args, LOCAL_TCP); +} + +static void chttp2_init_server_fullstack_ipv4(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_end2end_local_chttp2_init_server_fullstack(f, client_args, LOCAL_TCP); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack_local_ipv4", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS, + nullptr, chttp2_create_fixture_fullstack_ipv4, + chttp2_init_client_fullstack_ipv4, chttp2_init_server_fullstack_ipv4, + grpc_end2end_local_chttp2_tear_down_fullstack}}; + +int main(int argc, char** argv) { + size_t i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/fixtures/h2_local_ipv6.cc b/test/core/end2end/fixtures/h2_local_ipv6.cc new file mode 100644 index 00000000..b074659d --- /dev/null +++ b/test/core/end2end/fixtures/h2_local_ipv6.cc @@ -0,0 +1,70 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/fixtures/local_util.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack_ipv6( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f = + grpc_end2end_local_chttp2_create_fixture_fullstack(); + int port = grpc_pick_unused_port_or_die(); + static_cast(f.fixture_data) + ->localaddr = grpc_core::JoinHostPort("[::1]", port); + return f; +} + +static void chttp2_init_client_fullstack_ipv6(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_end2end_local_chttp2_init_client_fullstack(f, client_args, LOCAL_TCP); +} + +static void chttp2_init_server_fullstack_ipv6(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_end2end_local_chttp2_init_server_fullstack(f, client_args, LOCAL_TCP); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack_local_ipv6", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS, + nullptr, chttp2_create_fixture_fullstack_ipv6, + chttp2_init_client_fullstack_ipv6, chttp2_init_server_fullstack_ipv6, + grpc_end2end_local_chttp2_tear_down_fullstack}}; + +int main(int argc, char** argv) { + size_t i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/fixtures/h2_local_uds.cc b/test/core/end2end/fixtures/h2_local_uds.cc new file mode 100644 index 00000000..f612cc24 --- /dev/null +++ b/test/core/end2end/fixtures/h2_local_uds.cc @@ -0,0 +1,75 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include "absl/strings/str_format.h" + +#include + +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/fixtures/local_util.h" +#include "test/core/util/test_config.h" + +static int unique = 1; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack_uds( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f = + grpc_end2end_local_chttp2_create_fixture_fullstack(); + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + static_cast(f.fixture_data) + ->localaddr = absl::StrFormat( + "unix:/tmp/grpc_fullstack_test.%d.%" PRId64 ".%" PRId32 ".%d", getpid(), + now.tv_sec, now.tv_nsec, unique++); + return f; +} + +static void chttp2_init_client_fullstack_uds(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_end2end_local_chttp2_init_client_fullstack(f, client_args, UDS); +} + +static void chttp2_init_server_fullstack_uds(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_end2end_local_chttp2_init_server_fullstack(f, client_args, UDS); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack_local_uds", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS, + nullptr, chttp2_create_fixture_fullstack_uds, + chttp2_init_client_fullstack_uds, chttp2_init_server_fullstack_uds, + grpc_end2end_local_chttp2_tear_down_fullstack}}; + +int main(int argc, char** argv) { + size_t i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/fixtures/h2_oauth2.cc b/test/core/end2end/fixtures/h2_oauth2.cc new file mode 100644 index 00000000..1f71fda3 --- /dev/null +++ b/test/core/end2end/fixtures/h2_oauth2.cc @@ -0,0 +1,295 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/ssl/ssl_credentials.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +static const char oauth2_md[] = "Bearer aaslkfjs424535asdf"; +static const char* client_identity_property_name = "smurf_name"; +static const char* client_identity = "Brainy Smurf"; + +struct fullstack_secure_fixture_data { + std::string localaddr; + grpc_tls_version tls_version; +}; + +static const grpc_metadata* find_metadata(const grpc_metadata* md, + size_t md_count, const char* key, + const char* value) { + size_t i; + for (i = 0; i < md_count; i++) { + if (grpc_slice_str_cmp(md[i].key, key) == 0 && + grpc_slice_str_cmp(md[i].value, value) == 0) { + return &md[i]; + } + } + return nullptr; +} + +typedef struct { + size_t pseudo_refcount; +} test_processor_state; + +static void process_oauth2_success(void* state, grpc_auth_context* ctx, + const grpc_metadata* md, size_t md_count, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + const grpc_metadata* oauth2 = + find_metadata(md, md_count, "authorization", oauth2_md); + test_processor_state* s; + + GPR_ASSERT(state != nullptr); + s = static_cast(state); + GPR_ASSERT(s->pseudo_refcount == 1); + GPR_ASSERT(oauth2 != nullptr); + grpc_auth_context_add_cstring_property(ctx, client_identity_property_name, + client_identity); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( + ctx, client_identity_property_name) == 1); + cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_OK, nullptr); +} + +static void process_oauth2_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* md, size_t md_count, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + const grpc_metadata* oauth2 = + find_metadata(md, md_count, "authorization", oauth2_md); + test_processor_state* s; + GPR_ASSERT(state != nullptr); + s = static_cast(state); + GPR_ASSERT(s->pseudo_refcount == 1); + GPR_ASSERT(oauth2 != nullptr); + cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/, + grpc_tls_version tls_version) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_secure_fixture_data* ffd = new fullstack_secure_fixture_data(); + memset(&f, 0, sizeof(f)); + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + ffd->tls_version = tls_version; + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + return f; +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack_tls1_2( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_secure_fullstack(client_args, server_args, + grpc_tls_version::TLS1_2); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack_tls1_3( + grpc_channel_args* 
client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_secure_fullstack(client_args, server_args, + grpc_tls_version::TLS1_3); +} + +static void chttp2_init_client_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args, + grpc_channel_credentials* creds) { + fullstack_secure_fixture_data* ffd = + static_cast<fullstack_secure_fixture_data*>(f->fixture_data); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +static void chttp2_init_server_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_server_credentials* server_creds) { + fullstack_secure_fixture_data* ffd = + static_cast<fullstack_secure_fixture_data*>(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, ffd->localaddr.c_str(), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void chttp2_tear_down_secure_fullstack(grpc_end2end_test_fixture* f) { + fullstack_secure_fixture_data* ffd = + static_cast<fullstack_secure_fixture_data*>(f->fixture_data); + delete ffd; +} + +static void chttp2_init_client_simple_ssl_with_oauth2_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args) { + grpc_core::ExecCtx exec_ctx; + grpc_slice ca_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + const char* test_root_cert = + reinterpret_cast<const char*> GRPC_SLICE_START_PTR(ca_slice); + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(test_root_cert, nullptr, nullptr, nullptr); + if (f != nullptr && ssl_creds != nullptr) { + // Set the min and max TLS version. 
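+ // Pinning min and max to the same value means each of the *_tls1_2 and
+ // *_tls1_3 configurations below exercises exactly one TLS protocol version.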
+ grpc_ssl_credentials* creds = + reinterpret_cast<grpc_ssl_credentials*>(ssl_creds); + fullstack_secure_fixture_data* ffd = + static_cast<fullstack_secure_fixture_data*>(f->fixture_data); + creds->set_min_tls_version(ffd->tls_version); + creds->set_max_tls_version(ffd->tls_version); + } + grpc_call_credentials* oauth2_creds = grpc_md_only_test_credentials_create( + "authorization", oauth2_md, true /* is_async */); + grpc_channel_credentials* ssl_oauth2_creds = + grpc_composite_channel_credentials_create(ssl_creds, oauth2_creds, + nullptr); + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast<char*>(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast<char*>("foo.test.google.fr")}}; + grpc_channel_args* new_client_args = + grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1); + chttp2_init_client_secure_fullstack(f, new_client_args, ssl_oauth2_creds); + grpc_channel_args_destroy(new_client_args); + grpc_channel_credentials_release(ssl_creds); + grpc_call_credentials_release(oauth2_creds); + grpc_slice_unref(ca_slice); +} + +static int fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return 0; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return 1; + } + } + return 0; +} + +static void processor_destroy(void* state) { + test_processor_state* s = static_cast<test_processor_state*>(state); + GPR_ASSERT((s->pseudo_refcount--) == 1); + gpr_free(s); +} + +static grpc_auth_metadata_processor test_processor_create(int failing) { + test_processor_state* s = + static_cast<test_processor_state*>(gpr_malloc(sizeof(*s))); + grpc_auth_metadata_processor result; + s->pseudo_refcount = 1; + result.state = s; + result.destroy = processor_destroy; + if (failing) { + result.process = process_oauth2_failure; + } else { + result.process = process_oauth2_success; + } + return result; +} + +static void chttp2_init_server_simple_ssl_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args) { + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast<const char*> GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast<const char*> GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + nullptr, &pem_key_cert_pair, 1, 0, nullptr); + if (f != nullptr && ssl_creds != nullptr) { + // Set the min and max TLS version. 
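+ // Server side mirrors the client-side pinning above so both peers only
+ // offer the TLS version under test.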
+ grpc_ssl_server_credentials* creds = + reinterpret_cast(ssl_creds); + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + creds->set_min_tls_version(ffd->tls_version); + creds->set_max_tls_version(ffd->tls_version); + } + grpc_server_credentials_set_auth_metadata_processor( + ssl_creds, test_processor_create(fail_server_auth_check(server_args))); + chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); +} + +/* All test configurations */ + +static grpc_end2end_test_config configs[] = { + {"chttp2/simple_ssl_with_oauth2_fullstack_tls1_2", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_secure_fullstack_tls1_2, + chttp2_init_client_simple_ssl_with_oauth2_secure_fullstack, + chttp2_init_server_simple_ssl_secure_fullstack, + chttp2_tear_down_secure_fullstack}, + {"chttp2/simple_ssl_with_oauth2_fullstack_tls1_3", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_DOES_NOT_SUPPORT_CLIENT_HANDSHAKE_COMPLETE_FIRST, + "foo.test.google.fr", chttp2_create_fixture_secure_fullstack_tls1_3, + chttp2_init_client_simple_ssl_with_oauth2_secure_fullstack, + chttp2_init_server_simple_ssl_secure_fullstack, + chttp2_tear_down_secure_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_proxy.cc b/test/core/end2end/fixtures/h2_proxy.cc new file mode 100644 index 00000000..134a65cd --- /dev/null +++ b/test/core/end2end/fixtures/h2_proxy.cc @@ -0,0 +1,126 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/fixtures/proxy.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +typedef struct fullstack_fixture_data { + grpc_end2end_proxy* proxy; +} fullstack_fixture_data; + +static grpc_server* create_proxy_server(const char* port, + grpc_channel_args* server_args) { + grpc_server* s = grpc_server_create(server_args, nullptr); + GPR_ASSERT(grpc_server_add_insecure_http2_port(s, port)); + return s; +} + +static grpc_channel* create_proxy_client(const char* target, + grpc_channel_args* client_args) { + return grpc_insecure_channel_create(target, client_args, nullptr); +} + +static const grpc_end2end_proxy_def proxy_def = {create_proxy_server, + create_proxy_client}; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + fullstack_fixture_data* ffd = static_cast( + gpr_malloc(sizeof(fullstack_fixture_data))); + memset(&f, 0, sizeof(f)); + + ffd->proxy = grpc_end2end_proxy_create(&proxy_def, client_args, server_args); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_insecure_channel_create( + grpc_end2end_proxy_get_client_target(ffd->proxy), client_args, nullptr); + GPR_ASSERT(f->client); +} + +void chttp2_init_server_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT(grpc_server_add_insecure_http2_port( + f->server, grpc_end2end_proxy_get_server_port(ffd->proxy))); + grpc_server_start(f->server); +} + +void chttp2_tear_down_fullstack(grpc_end2end_test_fixture* f) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + grpc_end2end_proxy_destroy(ffd->proxy); + gpr_free(ffd); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack+proxy", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_REQUEST_PROXYING | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, chttp2_create_fixture_fullstack, chttp2_init_client_fullstack, + chttp2_init_server_fullstack, chttp2_tear_down_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_sockpair+trace.cc b/test/core/end2end/fixtures/h2_sockpair+trace.cc new file mode 100644 
index 00000000..c3e33553 --- /dev/null +++ b/test/core/end2end/fixtures/h2_sockpair+trace.cc @@ -0,0 +1,199 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#ifdef GRPC_POSIX_SOCKET +#include +#endif + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/client/http_client_filter.h" +#include "src/core/ext/filters/http/message_compress/message_compress_filter.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +/* chttp2 transport that is immediately available (used for testing + connected_channel without a client_channel */ + +struct custom_fixture_data { + grpc_endpoint_pair ep; + grpc_resource_quota* resource_quota; +}; + +static void server_setup_transport(void* ts, grpc_transport* transport) { + grpc_end2end_test_fixture* f = static_cast(ts); + grpc_core::ExecCtx exec_ctx; + custom_fixture_data* fixture_data = + static_cast(f->fixture_data); + grpc_endpoint_add_to_pollset(fixture_data->ep.server, grpc_cq_pollset(f->cq)); + grpc_error_handle error = f->server->core_server->SetupTransport( + transport, nullptr, f->server->core_server->channel_args(), nullptr); + if (error == GRPC_ERROR_NONE) { + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } else { + GRPC_ERROR_UNREF(error); + grpc_transport_destroy(transport); + } +} + +typedef struct { + grpc_end2end_test_fixture* f; + grpc_channel_args* client_args; +} sp_client_setup; + +static void client_setup_transport(void* ts, grpc_transport* transport) { + sp_client_setup* cs = static_cast(ts); + grpc_arg authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test-authority")); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(cs->client_args, &authority_arg, 1); + grpc_error_handle error = GRPC_ERROR_NONE; + cs->f->client = + grpc_channel_create("socketpair-target", args, GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, &error); + grpc_channel_args_destroy(args); + if (cs->f->client != nullptr) { + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } else { + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + cs->f->client = + 
grpc_lame_client_channel_create(nullptr, status, "lame channel"); + grpc_transport_destroy(transport); + } +} + +static grpc_end2end_test_fixture chttp2_create_fixture_socketpair( + grpc_channel_args* client_args, grpc_channel_args* /*server_args*/) { + custom_fixture_data* fixture_data = static_cast( + gpr_malloc(sizeof(custom_fixture_data))); + grpc_end2end_test_fixture f; + memset(&f, 0, sizeof(f)); + f.fixture_data = fixture_data; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + fixture_data->resource_quota = + grpc_resource_quota_from_channel_args(client_args, true); + fixture_data->ep = grpc_iomgr_create_endpoint_pair("fixture", nullptr); + return f; +} + +static void chttp2_init_client_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_transport* transport; + sp_client_setup cs; + cs.client_args = client_args; + cs.f = f; + transport = grpc_create_chttp2_transport( + client_args, fixture_data->ep.client, true, + grpc_resource_user_create(fixture_data->resource_quota, + "client_transport")); + client_setup_transport(&cs, transport); + GPR_ASSERT(f->client); +} + +static void chttp2_init_server_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_transport* transport; + GPR_ASSERT(!f->server); + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + grpc_server_start(f->server); + transport = grpc_create_chttp2_transport( + server_args, fixture_data->ep.server, false, + grpc_resource_user_create(fixture_data->resource_quota, + "server_transport")); + server_setup_transport(f, transport); +} + +static void chttp2_tear_down_socketpair(grpc_end2end_test_fixture* f) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_resource_quota_unref(fixture_data->resource_quota); + gpr_free(f->fixture_data); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/socketpair", FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, nullptr, + chttp2_create_fixture_socketpair, chttp2_init_client_socketpair, + chttp2_init_server_socketpair, chttp2_tear_down_socketpair}, +}; + +int main(int argc, char** argv) { + size_t i; + + /* force tracing on, with a value to force many + code paths in trace.c to be taken */ + GPR_GLOBAL_CONFIG_SET(grpc_trace, "doesnt-exist,http,all"); + +#ifdef GRPC_POSIX_SOCKET + g_fixture_slowdown_factor = isatty(STDOUT_FILENO) ? 10 : 1; +#else + g_fixture_slowdown_factor = 10; +#endif + +#ifdef GPR_WINDOWS + /* on Windows, writing logs to stderr is very slow + when stderr is redirected to a disk file. + The "trace" tests fixtures generates large amount + of logs, so setting a buffer for stderr prevents certain + test cases from timing out. 
*/ + setvbuf(stderr, NULL, _IOLBF, 1024); +#endif + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + GPR_ASSERT(0 == grpc_tracer_set_enabled("also-doesnt-exist", 0)); + GPR_ASSERT(1 == grpc_tracer_set_enabled("http", 1)); + GPR_ASSERT(1 == grpc_tracer_set_enabled("all", 1)); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_sockpair.cc b/test/core/end2end/fixtures/h2_sockpair.cc new file mode 100644 index 00000000..a8b61152 --- /dev/null +++ b/test/core/end2end/fixtures/h2_sockpair.cc @@ -0,0 +1,172 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/client/http_client_filter.h" +#include "src/core/ext/filters/http/message_compress/message_compress_filter.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +/* chttp2 transport that is immediately available (used for testing + connected_channel without a client_channel */ + +struct custom_fixture_data { + grpc_endpoint_pair ep; + grpc_resource_quota* resource_quota; +}; + +static void server_setup_transport(void* ts, grpc_transport* transport) { + grpc_end2end_test_fixture* f = static_cast(ts); + grpc_core::ExecCtx exec_ctx; + custom_fixture_data* fixture_data = + static_cast(f->fixture_data); + grpc_endpoint_add_to_pollset(fixture_data->ep.server, grpc_cq_pollset(f->cq)); + grpc_error_handle error = f->server->core_server->SetupTransport( + transport, nullptr, f->server->core_server->channel_args(), nullptr); + if (error == GRPC_ERROR_NONE) { + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } else { + GRPC_ERROR_UNREF(error); + grpc_transport_destroy(transport); + } +} + +typedef struct { + grpc_end2end_test_fixture* f; + grpc_channel_args* client_args; +} sp_client_setup; + +static void client_setup_transport(void* ts, grpc_transport* transport) { + sp_client_setup* cs = static_cast(ts); + + grpc_arg authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test-authority")); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(cs->client_args, &authority_arg, 1); + grpc_error_handle error = GRPC_ERROR_NONE; + cs->f->client = + grpc_channel_create("socketpair-target", args, 
GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, &error); + grpc_channel_args_destroy(args); + if (cs->f->client != nullptr) { + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } else { + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + cs->f->client = + grpc_lame_client_channel_create(nullptr, status, "lame channel"); + grpc_transport_destroy(transport); + } +} + +static grpc_end2end_test_fixture chttp2_create_fixture_socketpair( + grpc_channel_args* client_args, grpc_channel_args* /*server_args*/) { + custom_fixture_data* fixture_data = static_cast( + gpr_malloc(sizeof(custom_fixture_data))); + grpc_end2end_test_fixture f; + memset(&f, 0, sizeof(f)); + f.fixture_data = fixture_data; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + fixture_data->resource_quota = + grpc_resource_quota_from_channel_args(client_args, true); + fixture_data->ep = grpc_iomgr_create_endpoint_pair("fixture", nullptr); + return f; +} + +static void chttp2_init_client_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_transport* transport; + sp_client_setup cs; + cs.client_args = client_args; + cs.f = f; + transport = grpc_create_chttp2_transport( + client_args, fixture_data->ep.client, true, + grpc_resource_user_create(fixture_data->resource_quota, + "client_transport")); + client_setup_transport(&cs, transport); + GPR_ASSERT(f->client); +} + +static void chttp2_init_server_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_transport* transport; + GPR_ASSERT(!f->server); + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + grpc_server_start(f->server); + transport = grpc_create_chttp2_transport( + server_args, fixture_data->ep.server, false, + grpc_resource_user_create(fixture_data->resource_quota, + "server_transport")); + server_setup_transport(f, transport); +} + +static void chttp2_tear_down_socketpair(grpc_end2end_test_fixture* f) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_resource_quota_unref(fixture_data->resource_quota); + gpr_free(f->fixture_data); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/socketpair", FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, nullptr, + chttp2_create_fixture_socketpair, chttp2_init_client_socketpair, + chttp2_init_server_socketpair, chttp2_tear_down_socketpair}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_sockpair_1byte.cc b/test/core/end2end/fixtures/h2_sockpair_1byte.cc new file mode 100644 index 00000000..35586267 --- /dev/null +++ b/test/core/end2end/fixtures/h2_sockpair_1byte.cc @@ -0,0 +1,186 @@ +/* + * + * Copyright 2015 gRPC authors. 
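h2_sockpair.cc above (and the +trace variant before it) skips the resolver and client_channel entirely: one in-memory endpoint pair is created up front, each half gets its own chttp2 transport, and the client side is wrapped directly in a GRPC_CLIENT_DIRECT_CHANNEL. The essential wiring, compressed into one illustrative function that reuses only the calls already made above (a real fixture keeps the transports alive for the whole test rather than tearing them straight back down):

// Sketch only: the core of chttp2_create_fixture_socketpair /
// chttp2_init_client_socketpair / chttp2_init_server_socketpair above.
static void sockpair_wiring_sketch() {
  grpc_core::ExecCtx exec_ctx;  // the fixture callbacks run under an ExecCtx too
  grpc_resource_quota* quota = grpc_resource_quota_create("sockpair_sketch");
  // One in-memory pipe; ep.client and ep.server are its two ends.
  grpc_endpoint_pair ep =
      grpc_iomgr_create_endpoint_pair("sockpair_sketch", nullptr);
  // Client half: chttp2 transport over ep.client. client_setup_transport()
  // above wraps it via grpc_channel_create(..., GRPC_CLIENT_DIRECT_CHANNEL, ...)
  // and only then calls grpc_chttp2_transport_start_reading().
  grpc_transport* client_transport = grpc_create_chttp2_transport(
      /*channel_args=*/nullptr, ep.client, /*is_client=*/true,
      grpc_resource_user_create(quota, "client_transport"));
  // Server half: chttp2 transport over ep.server; server_setup_transport()
  // above hands it to core_server->SetupTransport() before reading starts.
  grpc_transport* server_transport = grpc_create_chttp2_transport(
      /*channel_args=*/nullptr, ep.server, /*is_client=*/false,
      grpc_resource_user_create(quota, "server_transport"));
  // Sketch only: clean up instead of running a test over the pair.
  grpc_transport_destroy(client_transport);
  grpc_transport_destroy(server_transport);
  grpc_resource_quota_unref(quota);
}

If grpc_channel_create() fails, both sockpair fixtures degrade to grpc_lame_client_channel_create(nullptr, status, "lame channel"), with status pulled out of the error via grpc_error_get_int(..., GRPC_ERROR_INT_GRPC_STATUS, ...), so the test still receives a channel; its calls simply all fail with that status.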
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/client/http_client_filter.h" +#include "src/core/ext/filters/http/message_compress/message_compress_filter.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +/* chttp2 transport that is immediately available (used for testing + connected_channel without a client_channel */ + +struct custom_fixture_data { + grpc_endpoint_pair ep; + grpc_resource_quota* resource_quota; +}; + +static void server_setup_transport(void* ts, grpc_transport* transport) { + grpc_end2end_test_fixture* f = static_cast(ts); + grpc_core::ExecCtx exec_ctx; + custom_fixture_data* fixture_data = + static_cast(f->fixture_data); + grpc_endpoint_add_to_pollset(fixture_data->ep.server, grpc_cq_pollset(f->cq)); + grpc_error_handle error = f->server->core_server->SetupTransport( + transport, nullptr, f->server->core_server->channel_args(), nullptr); + if (error == GRPC_ERROR_NONE) { + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } else { + GRPC_ERROR_UNREF(error); + grpc_transport_destroy(transport); + } +} + +typedef struct { + grpc_end2end_test_fixture* f; + grpc_channel_args* client_args; +} sp_client_setup; + +static void client_setup_transport(void* ts, grpc_transport* transport) { + sp_client_setup* cs = static_cast(ts); + + grpc_arg authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test-authority")); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(cs->client_args, &authority_arg, 1); + grpc_error_handle error = GRPC_ERROR_NONE; + cs->f->client = + grpc_channel_create("socketpair-target", args, GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, &error); + grpc_channel_args_destroy(args); + if (cs->f->client != nullptr) { + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } else { + intptr_t integer; + grpc_status_code status = GRPC_STATUS_INTERNAL; + if (grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &integer)) { + status = static_cast(integer); + } + GRPC_ERROR_UNREF(error); + cs->f->client = + grpc_lame_client_channel_create(nullptr, status, "lame channel"); + grpc_transport_destroy(transport); + } +} + +static grpc_end2end_test_fixture chttp2_create_fixture_socketpair( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + custom_fixture_data* fixture_data = static_cast( + 
gpr_malloc(sizeof(custom_fixture_data))); + grpc_end2end_test_fixture f; + memset(&f, 0, sizeof(f)); + f.fixture_data = fixture_data; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + grpc_arg a[3]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER; + a[0].value.integer = 1; + a[1].key = const_cast(GRPC_ARG_TCP_MIN_READ_CHUNK_SIZE); + a[1].type = GRPC_ARG_INTEGER; + a[1].value.integer = 1; + a[2].key = const_cast(GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE); + a[2].type = GRPC_ARG_INTEGER; + a[2].value.integer = 1; + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + fixture_data->resource_quota = + grpc_resource_quota_from_channel_args(&args, true); + fixture_data->ep = grpc_iomgr_create_endpoint_pair("fixture", &args); + return f; +} + +static void chttp2_init_client_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_transport* transport; + sp_client_setup cs; + cs.client_args = client_args; + cs.f = f; + transport = grpc_create_chttp2_transport( + client_args, fixture_data->ep.client, true, + grpc_resource_user_create(fixture_data->resource_quota, + "client_transport")); + client_setup_transport(&cs, transport); + GPR_ASSERT(f->client); +} + +static void chttp2_init_server_socketpair(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_transport* transport; + GPR_ASSERT(!f->server); + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + grpc_server_start(f->server); + transport = grpc_create_chttp2_transport( + server_args, fixture_data->ep.server, false, + grpc_resource_user_create(fixture_data->resource_quota, + "server_transport")); + server_setup_transport(f, transport); +} + +static void chttp2_tear_down_socketpair(grpc_end2end_test_fixture* f) { + grpc_core::ExecCtx exec_ctx; + auto* fixture_data = static_cast(f->fixture_data); + grpc_resource_quota_unref(fixture_data->resource_quota); + gpr_free(f->fixture_data); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/socketpair_one_byte_at_a_time", + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, nullptr, + chttp2_create_fixture_socketpair, chttp2_init_client_socketpair, + chttp2_init_server_socketpair, chttp2_tear_down_socketpair}, +}; + +int main(int argc, char** argv) { + size_t i; + + g_fixture_slowdown_factor = 2; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_ssl.cc b/test/core/end2end/fixtures/h2_ssl.cc new file mode 100644 index 00000000..e66565f1 --- /dev/null +++ b/test/core/end2end/fixtures/h2_ssl.cc @@ -0,0 +1,224 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
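The distinguishing feature of the h2_sockpair_1byte fixture being started here is its three integer channel args: pinning GRPC_ARG_TCP_READ_CHUNK_SIZE and its min/max companions to 1 forces the endpoint to deliver one byte per read, which stresses the transport's incremental parsing; the slowdown factor of 2 compensates for how much slower that makes each test. A small hypothetical helper showing the same construction for an arbitrary chunk size (the helper name is ours, not part of the diff; only the arg keys and the grpc_channel_args layout used above are reused):

// Hypothetical helper (illustration only): build channel args that clamp the
// TCP read chunk size, the way chttp2_create_fixture_socketpair does with 1.
static grpc_channel_args make_chunked_read_args(grpc_arg (&storage)[3],
                                                int chunk_size) {
  const char* keys[3] = {GRPC_ARG_TCP_READ_CHUNK_SIZE,
                         GRPC_ARG_TCP_MIN_READ_CHUNK_SIZE,
                         GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE};
  for (size_t i = 0; i < GPR_ARRAY_SIZE(storage); i++) {
    storage[i].key = const_cast<char*>(keys[i]);
    storage[i].type = GRPC_ARG_INTEGER;
    storage[i].value.integer = chunk_size;
  }
  return {GPR_ARRAY_SIZE(storage), storage};  // caller owns `storage`
}

Passing 1 reproduces the args handed to grpc_resource_quota_from_channel_args() and grpc_iomgr_create_endpoint_pair() above; larger values make the fixture behave like plain h2_sockpair.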
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/ssl/ssl_credentials.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +struct fullstack_secure_fixture_data { + std::string localaddr; + grpc_tls_version tls_version; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/, + grpc_tls_version tls_version) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_secure_fixture_data* ffd = new fullstack_secure_fixture_data(); + memset(&f, 0, sizeof(f)); + + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + ffd->tls_version = tls_version; + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack_tls1_2( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_secure_fullstack(client_args, server_args, + grpc_tls_version::TLS1_2); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack_tls1_3( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_secure_fullstack(client_args, server_args, + grpc_tls_version::TLS1_3); +} + +static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, + size_t /*md_count*/, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +static void chttp2_init_client_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args, + grpc_channel_credentials* creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +static void chttp2_init_server_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_server_credentials* server_creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, 
nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, ffd->localaddr.c_str(), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void chttp2_tear_down_secure_fullstack(grpc_end2end_test_fixture* f) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +static void chttp2_init_client_simple_ssl_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args) { + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); + if (f != nullptr && ssl_creds != nullptr) { + // Set the min and max TLS version. + grpc_ssl_credentials* creds = + reinterpret_cast(ssl_creds); + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + creds->set_min_tls_version(ffd->tls_version); + creds->set_max_tls_version(ffd->tls_version); + } + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args* new_client_args = + grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1); + chttp2_init_client_secure_fullstack(f, new_client_args, ssl_creds); + grpc_channel_args_destroy(new_client_args); +} + +static int fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return 0; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return 1; + } + } + return 0; +} + +static void chttp2_init_server_simple_ssl_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args) { + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + nullptr, &pem_key_cert_pair, 1, 0, nullptr); + if (f != nullptr && ssl_creds != nullptr) { + // Set the min and max TLS version. 
+ grpc_ssl_server_credentials* creds = + reinterpret_cast(ssl_creds); + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + creds->set_min_tls_version(ffd->tls_version); + creds->set_max_tls_version(ffd->tls_version); + } + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + if (fail_server_auth_check(server_args)) { + grpc_auth_metadata_processor processor = {process_auth_failure, nullptr, + nullptr}; + grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); + } + chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); +} + +/* All test configurations */ + +static grpc_end2end_test_config configs[] = { + {"chttp2/simple_ssl_fullstack_tls1_2", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_secure_fullstack_tls1_2, + chttp2_init_client_simple_ssl_secure_fullstack, + chttp2_init_server_simple_ssl_secure_fullstack, + chttp2_tear_down_secure_fullstack}, + {"chttp2/simple_ssl_fullstack_tls1_3", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_DOES_NOT_SUPPORT_CLIENT_HANDSHAKE_COMPLETE_FIRST, + "foo.test.google.fr", chttp2_create_fixture_secure_fullstack_tls1_3, + chttp2_init_client_simple_ssl_secure_fullstack, + chttp2_init_server_simple_ssl_secure_fullstack, + chttp2_tear_down_secure_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, CA_CERT_PATH); + + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/fixtures/h2_ssl_cred_reload.cc b/test/core/end2end/fixtures/h2_ssl_cred_reload.cc new file mode 100644 index 00000000..433e4433 --- /dev/null +++ b/test/core/end2end/fixtures/h2_ssl_cred_reload.cc @@ -0,0 +1,254 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
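h2_ssl.cc above, and the secure fixtures that follow, all lean on the same idiom: the test certificates under src/core/tsi/test_creds are issued for *.test.google.fr, while the channel actually dials a localhost address, so the client injects GRPC_SSL_TARGET_NAME_OVERRIDE_ARG before creating the secure channel. The client-side pattern in isolation, a condensed sketch of chttp2_init_client_simple_ssl_secure_fullstack() above that uses only calls it already makes (the min/max TLS version pinning is omitted here):

// Sketch: create an SSL channel whose certificate check expects
// "foo.test.google.fr" even though `target` is a localhost address.
static grpc_channel* create_test_ssl_channel(const char* target,
                                             grpc_channel_args* client_args) {
  grpc_channel_credentials* ssl_creds =
      grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr);
  grpc_arg ssl_name_override = {
      GRPC_ARG_STRING,
      const_cast<char*>(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG),
      {const_cast<char*>("foo.test.google.fr")}};
  grpc_channel_args* new_args =
      grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1);
  grpc_channel* channel =
      grpc_secure_channel_create(ssl_creds, target, new_args, nullptr);
  grpc_channel_args_destroy(new_args);
  grpc_channel_credentials_release(ssl_creds);
  return channel;
}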
+ * + */ + +#include +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/ssl/ssl_credentials.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +struct fullstack_secure_fixture_data { + std::string localaddr; + grpc_tls_version tls_version; + bool server_credential_reloaded = false; +}; + +static grpc_ssl_certificate_config_reload_status +ssl_server_certificate_config_callback( + void* user_data, grpc_ssl_server_certificate_config** config) { + if (config == nullptr) { + return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL; + } + fullstack_secure_fixture_data* ffd = + static_cast(user_data); + if (!ffd->server_credential_reloaded) { + grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + *config = grpc_ssl_server_certificate_config_create(ca_cert, + &pem_key_cert_pair, 1); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); + ffd->server_credential_reloaded = true; + return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW; + } else { + return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED; + } +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/, + grpc_tls_version tls_version) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_secure_fixture_data* ffd = new fullstack_secure_fixture_data(); + memset(&f, 0, sizeof(f)); + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + ffd->tls_version = tls_version; + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack_tls1_2( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_secure_fullstack(client_args, server_args, + grpc_tls_version::TLS1_2); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack_tls1_3( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_secure_fullstack(client_args, server_args, + grpc_tls_version::TLS1_3); +} + +static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, + size_t /*md_count*/, + 
grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +static void chttp2_init_client_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args, + grpc_channel_credentials* creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +static void chttp2_init_server_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_server_credentials* server_creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + ffd->server_credential_reloaded = false; + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, ffd->localaddr.c_str(), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void chttp2_tear_down_secure_fullstack(grpc_end2end_test_fixture* f) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +static void chttp2_init_client_simple_ssl_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args) { + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); + if (f != nullptr && ssl_creds != nullptr) { + // Set the min and max TLS version. + grpc_ssl_credentials* creds = + reinterpret_cast(ssl_creds); + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + creds->set_min_tls_version(ffd->tls_version); + creds->set_max_tls_version(ffd->tls_version); + } + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args* new_client_args = + grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1); + chttp2_init_client_secure_fullstack(f, new_client_args, ssl_creds); + grpc_channel_args_destroy(new_client_args); +} + +static int fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return 0; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return 1; + } + } + return 0; +} + +static void chttp2_init_server_simple_ssl_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args) { + grpc_ssl_server_credentials_options* options = + grpc_ssl_server_credentials_create_options_using_config_fetcher( + GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, + ssl_server_certificate_config_callback, f->fixture_data); + grpc_server_credentials* ssl_creds = + grpc_ssl_server_credentials_create_with_options(options); + if (f != nullptr && ssl_creds != nullptr) { + // Set the min and max TLS version. 
+ grpc_ssl_server_credentials* creds = + reinterpret_cast(ssl_creds); + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + creds->set_min_tls_version(ffd->tls_version); + creds->set_max_tls_version(ffd->tls_version); + } + if (fail_server_auth_check(server_args)) { + grpc_auth_metadata_processor processor = {process_auth_failure, nullptr, + nullptr}; + grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); + } + chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); +} + +/* All test configurations */ + +static grpc_end2end_test_config configs[] = { + {"chttp2/simple_ssl_fullstack_tls1_2", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_secure_fullstack_tls1_2, + chttp2_init_client_simple_ssl_secure_fullstack, + chttp2_init_server_simple_ssl_secure_fullstack, + chttp2_tear_down_secure_fullstack}, + {"chttp2/simple_ssl_fullstack_tls1_3", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER | + FEATURE_MASK_DOES_NOT_SUPPORT_CLIENT_HANDSHAKE_COMPLETE_FIRST, + "foo.test.google.fr", chttp2_create_fixture_secure_fullstack_tls1_3, + chttp2_init_client_simple_ssl_secure_fullstack, + chttp2_init_server_simple_ssl_secure_fullstack, + chttp2_tear_down_secure_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, CA_CERT_PATH); + + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_ssl_proxy.cc b/test/core/end2end/fixtures/h2_ssl_proxy.cc new file mode 100644 index 00000000..9980c095 --- /dev/null +++ b/test/core/end2end/fixtures/h2_ssl_proxy.cc @@ -0,0 +1,235 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
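What distinguishes the h2_ssl_cred_reload fixture just ended from plain h2_ssl is the certificate-config fetcher: the handshake invokes ssl_server_certificate_config_callback(), which must return RELOAD_NEW with a freshly built config the first time, RELOAD_UNCHANGED afterwards, or RELOAD_FAIL on error. A minimal callback obeying the same contract, assuming the PEM strings are already in memory (illustrative; the real callback above loads them from the test_creds files with grpc_load_file()):

// Sketch of the reload contract used above. `user_data` is whatever pointer was
// passed to grpc_ssl_server_credentials_create_options_using_config_fetcher().
static grpc_ssl_certificate_config_reload_status minimal_config_fetcher(
    void* user_data, grpc_ssl_server_certificate_config** config) {
  (void)user_data;  // the real callback casts this to fullstack_secure_fixture_data*
  static bool reloaded = false;  // the fixture keeps this flag in fixture_data instead
  if (config == nullptr) return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL;
  if (reloaded) return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED;
  // Placeholder PEM contents; see the grpc_load_file() calls above.
  static const char* kRootPem = "-----BEGIN CERTIFICATE-----\n...";
  static const char* kCertPem = "-----BEGIN CERTIFICATE-----\n...";
  static const char* kKeyPem = "-----BEGIN PRIVATE KEY-----\n...";
  grpc_ssl_pem_key_cert_pair pair = {kKeyPem, kCertPem};
  *config = grpc_ssl_server_certificate_config_create(kRootPem, &pair, 1);
  reloaded = true;
  return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW;
}

Resetting server_credential_reloaded in chttp2_init_server_secure_fullstack() above is what makes each server restart exercise the RELOAD_NEW path again.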
+ * + */ + +#include +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/fixtures/proxy.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +typedef struct fullstack_secure_fixture_data { + grpc_end2end_proxy* proxy; +} fullstack_secure_fixture_data; + +static grpc_server* create_proxy_server(const char* port, + grpc_channel_args* server_args) { + grpc_server* s = grpc_server_create(server_args, nullptr); + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + nullptr, &pem_key_cert_pair, 1, 0, nullptr); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + GPR_ASSERT(grpc_server_add_secure_http2_port(s, port, ssl_creds)); + grpc_server_credentials_release(ssl_creds); + return s; +} + +static grpc_channel* create_proxy_client(const char* target, + grpc_channel_args* client_args) { + grpc_channel* channel; + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args* new_client_args = + grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1); + channel = + grpc_secure_channel_create(ssl_creds, target, new_client_args, nullptr); + grpc_channel_credentials_release(ssl_creds); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(new_client_args); + } + return channel; +} + +static const grpc_end2end_proxy_def proxy_def = {create_proxy_server, + create_proxy_client}; + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + fullstack_secure_fixture_data* ffd = + static_cast( + gpr_malloc(sizeof(fullstack_secure_fixture_data))); + memset(&f, 0, sizeof(f)); + + ffd->proxy = grpc_end2end_proxy_create(&proxy_def, client_args, server_args); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, + size_t /*md_count*/, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +static void chttp2_init_client_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* 
client_args, + grpc_channel_credentials* creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_secure_channel_create( + creds, grpc_end2end_proxy_get_client_target(ffd->proxy), client_args, + nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +static void chttp2_init_server_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_server_credentials* server_creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, grpc_end2end_proxy_get_server_port(ffd->proxy), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void chttp2_tear_down_secure_fullstack(grpc_end2end_test_fixture* f) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + grpc_end2end_proxy_destroy(ffd->proxy); + gpr_free(ffd); +} + +static void chttp2_init_client_simple_ssl_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args) { + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args* new_client_args = + grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1); + chttp2_init_client_secure_fullstack(f, new_client_args, ssl_creds); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(new_client_args); + } +} + +static int fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return 0; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return 1; + } + } + return 0; +} + +static void chttp2_init_server_simple_ssl_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args) { + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + nullptr, &pem_key_cert_pair, 1, 0, nullptr); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + if (fail_server_auth_check(server_args)) { + grpc_auth_metadata_processor processor = {process_auth_failure, nullptr, + nullptr}; + grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); + } + chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); +} + +/* All test configurations */ + +static grpc_end2end_test_config configs[] = { + {"chttp2/simple_ssl_fullstack", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_REQUEST_PROXYING | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_secure_fullstack, + 
chttp2_init_client_simple_ssl_secure_fullstack, + chttp2_init_server_simple_ssl_secure_fullstack, + chttp2_tear_down_secure_fullstack}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, CA_CERT_PATH); + + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/h2_tls.cc b/test/core/end2end/fixtures/h2_tls.cc new file mode 100644 index 00000000..2254e1cf --- /dev/null +++ b/test/core/end2end/fixtures/h2_tls.cc @@ -0,0 +1,351 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "absl/container/inlined_vector.h" + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +// For normal TLS connections. 
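Before the TLS fixture continues below, one note on the h2_ssl_proxy.cc file that just ended: unlike the insecure proxy fixture, it terminates TLS on both hops. The proxy's listener presents server1.pem to the test client, and the proxy's outgoing channel runs its own handshake (with the same foo.test.google.fr name override) against the real server. Both hops come from the two factory callbacks handed to the proxy driver; an annotated restatement of that pairing, illustrative only and reusing the symbols defined above:

// Data flow set up by h2_ssl_proxy.cc (sketch, not part of the diff):
//
//   test client --TLS hop 1 ("foo.test.google.fr")--> grpc_end2end_proxy
//   grpc_end2end_proxy --TLS hop 2 (same override)--> test server
//
static const grpc_end2end_proxy_def secure_proxy_def_sketch = {
    create_proxy_server,   // hop 1 listener: server1.pem / server1.key creds
    create_proxy_client};  // hop 2 client: grpc_ssl_credentials_create + override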
+#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +typedef absl::InlinedVector ThreadList; + +struct fullstack_secure_fixture_data { + ~fullstack_secure_fixture_data() { + for (size_t ind = 0; ind < thd_list.size(); ind++) { + thd_list[ind].Join(); + } + grpc_tls_certificate_provider_release(client_provider); + grpc_tls_certificate_provider_release(server_provider); + } + std::string localaddr; + grpc_tls_version tls_version; + ThreadList thd_list; + grpc_tls_certificate_provider* client_provider = nullptr; + grpc_tls_certificate_provider* server_provider = nullptr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_static_data( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/, + grpc_tls_version tls_version) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_secure_fixture_data* ffd = new fullstack_secure_fixture_data(); + memset(&f, 0, sizeof(f)); + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + ffd->tls_version = tls_version; + grpc_slice root_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &root_slice))); + std::string root_cert = + std::string(grpc_core::StringViewFromSlice(root_slice)); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + std::string identity_cert = + std::string(grpc_core::StringViewFromSlice(cert_slice)); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + std::string private_key = + std::string(grpc_core::StringViewFromSlice(key_slice)); + grpc_tls_identity_pairs* client_pairs = grpc_tls_identity_pairs_create(); + grpc_tls_identity_pairs_add_pair(client_pairs, private_key.c_str(), + identity_cert.c_str()); + ffd->client_provider = grpc_tls_certificate_provider_static_data_create( + root_cert.c_str(), client_pairs); + grpc_tls_identity_pairs* server_pairs = grpc_tls_identity_pairs_create(); + grpc_tls_identity_pairs_add_pair(server_pairs, private_key.c_str(), + identity_cert.c_str()); + ffd->server_provider = grpc_tls_certificate_provider_static_data_create( + root_cert.c_str(), server_pairs); + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + grpc_slice_unref(root_slice); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + return f; +} + +static grpc_end2end_test_fixture chttp2_create_fixture_cert_watcher( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/, + grpc_tls_version tls_version) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_secure_fixture_data* ffd = new fullstack_secure_fixture_data(); + memset(&f, 0, sizeof(f)); + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + ffd->tls_version = tls_version; + ffd->client_provider = grpc_tls_certificate_provider_file_watcher_create( + SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH, 1); + ffd->server_provider = grpc_tls_certificate_provider_file_watcher_create( + SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH, 1); + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + return f; +} + +static grpc_end2end_test_fixture chttp2_create_fixture_static_data_tls1_2( + 
grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_static_data(client_args, server_args, + grpc_tls_version::TLS1_2); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_static_data_tls1_3( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_static_data(client_args, server_args, + grpc_tls_version::TLS1_3); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_cert_watcher_tls1_2( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_cert_watcher(client_args, server_args, + grpc_tls_version::TLS1_2); +} + +static grpc_end2end_test_fixture chttp2_create_fixture_cert_watcher_tls1_3( + grpc_channel_args* client_args, grpc_channel_args* server_args) { + return chttp2_create_fixture_cert_watcher(client_args, server_args, + grpc_tls_version::TLS1_3); +} + +static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, + size_t /*md_count*/, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +static void chttp2_init_client_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args, + grpc_channel_credentials* creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +static void chttp2_init_server_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_server_credentials* server_creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, ffd->localaddr.c_str(), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void chttp2_tear_down_secure_fullstack(grpc_end2end_test_fixture* f) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +// Application-provided callback for server authorization check. +static void server_authz_check_cb(void* user_data) { + grpc_tls_server_authorization_check_arg* check_arg = + static_cast(user_data); + GPR_ASSERT(check_arg != nullptr); + // result = 1 indicates the server authorization check passes. + // Normally, the application code should resort to mapping information + // between server identity and target name to derive the result. + // For this test, we directly return 1 for simplicity. + check_arg->success = 1; + check_arg->status = GRPC_STATUS_OK; + check_arg->cb(check_arg); +} + +// Asynchronous implementation of schedule field in +// grpc_server_authorization_check_config. +static int server_authz_check_async( + void* config_user_data, grpc_tls_server_authorization_check_arg* arg) { + fullstack_secure_fixture_data* ffd = + static_cast(config_user_data); + ffd->thd_list.push_back( + grpc_core::Thread("h2_tls_test", &server_authz_check_cb, arg)); + ffd->thd_list[ffd->thd_list.size() - 1].Start(); + return 1; +} + +// Create a TLS channel credential. 
+static grpc_channel_credentials* create_tls_channel_credentials( + fullstack_secure_fixture_data* ffd) { + grpc_tls_credentials_options* options = grpc_tls_credentials_options_create(); + options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + options->set_min_tls_version(ffd->tls_version); + options->set_max_tls_version(ffd->tls_version); + // Set credential provider. + grpc_tls_credentials_options_set_certificate_provider(options, + ffd->client_provider); + grpc_tls_credentials_options_watch_root_certs(options); + grpc_tls_credentials_options_watch_identity_key_cert_pairs(options); + /* Set server authorization check config. */ + grpc_tls_server_authorization_check_config* check_config = + grpc_tls_server_authorization_check_config_create( + ffd, server_authz_check_async, nullptr, nullptr); + grpc_tls_credentials_options_set_server_authorization_check_config( + options, check_config); + /* Create TLS channel credentials. */ + grpc_channel_credentials* creds = grpc_tls_credentials_create(options); + grpc_tls_server_authorization_check_config_release(check_config); + return creds; +} + +// Create a TLS server credential. +static grpc_server_credentials* create_tls_server_credentials( + fullstack_secure_fixture_data* ffd) { + grpc_tls_credentials_options* options = grpc_tls_credentials_options_create(); + options->set_min_tls_version(ffd->tls_version); + options->set_max_tls_version(ffd->tls_version); + // Set credential provider. + grpc_tls_credentials_options_set_certificate_provider(options, + ffd->server_provider); + grpc_tls_credentials_options_watch_root_certs(options); + grpc_tls_credentials_options_watch_identity_key_cert_pairs(options); + /* Set client certificate request type. */ + grpc_tls_credentials_options_set_cert_request_type( + options, GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + grpc_server_credentials* creds = grpc_tls_server_credentials_create(options); + return creds; +} + +static void chttp2_init_client(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + grpc_channel_credentials* ssl_creds = create_tls_channel_credentials( + static_cast(f->fixture_data)); + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args* new_client_args = + grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1); + chttp2_init_client_secure_fullstack(f, new_client_args, ssl_creds); + grpc_channel_args_destroy(new_client_args); +} + +static int fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return 0; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return 1; + } + } + return 0; +} + +static void chttp2_init_server(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + grpc_server_credentials* ssl_creds = create_tls_server_credentials( + static_cast(f->fixture_data)); + if (fail_server_auth_check(server_args)) { + grpc_auth_metadata_processor processor = {process_auth_failure, nullptr, + nullptr}; + grpc_server_credentials_set_auth_metadata_processor(ssl_creds, processor); + } + chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); +} + +static grpc_end2end_test_config configs[] = { + // client: static data provider + async custom verification + // server: static data provider + // extra: TLS 1.2 + {"chttp2/simple_ssl_fullstack_tls1_2", + 
FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_static_data_tls1_2, + chttp2_init_client, chttp2_init_server, chttp2_tear_down_secure_fullstack}, + // client: static data provider + async custom verification + // server: static data provider + // extra: TLS 1.3 + {"chttp2/simple_ssl_fullstack_tls1_3", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_static_data_tls1_3, + chttp2_init_client, chttp2_init_server, chttp2_tear_down_secure_fullstack}, + // client: certificate watcher provider + async custom verification + // server: certificate watcher provider + // extra: TLS 1.2 + {"chttp2/reloading_from_files_ssl_fullstack_tls1_2", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_cert_watcher_tls1_2, + chttp2_init_client, chttp2_init_server, chttp2_tear_down_secure_fullstack}, + // client: certificate watcher provider + async custom verification + // server: certificate watcher provider + // extra: TLS 1.3 + {"chttp2/reloading_from_files_ssl_fullstack_tls1_3", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + "foo.test.google.fr", chttp2_create_fixture_cert_watcher_tls1_3, + chttp2_init_client, chttp2_init_server, chttp2_tear_down_secure_fullstack}, + +}; + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, CA_CERT_PATH); + grpc_init(); + for (size_t ind = 0; ind < sizeof(configs) / sizeof(*configs); ind++) { + grpc_end2end_tests(argc, argv, configs[ind]); + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/fixtures/h2_uds.cc b/test/core/end2end/fixtures/h2_uds.cc new file mode 100644 index 00000000..f42667e8 --- /dev/null +++ b/test/core/end2end/fixtures/h2_uds.cc @@ -0,0 +1,143 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
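The four h2_tls configs above vary along two axes: the certificate provider behind grpc_tls_credentials_options (static in-memory PEM data versus a file watcher that re-reads the key, cert and root files on a 1-second interval) and whether the handshake is pinned to TLS 1.2 or 1.3. The two provider constructions side by side, as a sketch wrapping the same calls made above (ownership of the identity pairs transfers to the static provider, as in the fixture):

// Sketch only: the two grpc_tls_certificate_provider flavours used above.
static void provider_sketch(const std::string& root_cert,
                            grpc_tls_identity_pairs* server_pairs) {
  // Static data: certificates are fixed for the provider's lifetime.
  grpc_tls_certificate_provider* static_provider =
      grpc_tls_certificate_provider_static_data_create(root_cert.c_str(),
                                                       server_pairs);
  // File watcher: the three paths are re-read every 1 second, so rotating the
  // files on disk is picked up without recreating the credentials.
  grpc_tls_certificate_provider* watcher_provider =
      grpc_tls_certificate_provider_file_watcher_create(
          SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH, 1);
  // The fixture stores these in fixture_data and releases them in its
  // destructor via grpc_tls_certificate_provider_release(), as above.
  grpc_tls_certificate_provider_release(static_provider);
  grpc_tls_certificate_provider_release(watcher_provider);
}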
+ * + */ + +#include +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +struct fullstack_fixture_data { + std::string localaddr; +}; + +static int unique = 1; + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack_base( + std::string addr) { + fullstack_fixture_data* ffd = new fullstack_fixture_data; + ffd->localaddr = std::move(addr); + + grpc_end2end_test_fixture f; + memset(&f, 0, sizeof(f)); + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +static grpc_end2end_test_fixture chttp2_create_fixture_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + const std::string localaddr = absl::StrFormat( + "unix:/tmp/grpc_fullstack_test.%d.%" PRId64 ".%" PRId32 ".%d", getpid(), + now.tv_sec, now.tv_nsec, unique++); + return chttp2_create_fixture_fullstack_base(localaddr); +} + +static grpc_end2end_test_fixture +chttp2_create_fixture_fullstack_abstract_namespace( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + const std::string localaddr = absl::StrFormat( + "unix-abstract:grpc_fullstack_test.%d.%" PRId64 ".%" PRId32 ".%d", + getpid(), now.tv_sec, now.tv_nsec, unique++); + return chttp2_create_fixture_fullstack_base(localaddr); +} + +void chttp2_init_client_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_insecure_channel_create(ffd->localaddr.c_str(), client_args, + nullptr); +} + +void chttp2_init_server_fullstack(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(f->server, ffd->localaddr.c_str())); + grpc_server_start(f->server); +} + +void chttp2_tear_down_fullstack(grpc_end2end_test_fixture* f) { + fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"chttp2/fullstack_uds", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, chttp2_create_fixture_fullstack, chttp2_init_client_fullstack, + chttp2_init_server_fullstack, chttp2_tear_down_fullstack}, +#ifndef GPR_APPLE // Apple doesn't support an abstract socket + {"chttp2/fullstack_uds_abstract_namespace", + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL | + FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, + nullptr, 
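+     // The abstract-namespace variant dials a "unix-abstract:" address: a
+     // Linux abstract socket that never appears on the filesystem and needs
+     // no cleanup. The #ifndef above drops this config on Apple platforms,
+     // which do not support abstract sockets.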
chttp2_create_fixture_fullstack_abstract_namespace, + chttp2_init_client_fullstack, chttp2_init_server_fullstack, + chttp2_tear_down_fullstack}, +#endif +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/http_proxy_fixture.cc b/test/core/end2end/fixtures/http_proxy_fixture.cc new file mode 100644 index 00000000..93f4f1bd --- /dev/null +++ b/test/core/end2end/fixtures/http_proxy_fixture.cc @@ -0,0 +1,666 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/end2end/fixtures/http_proxy_fixture.h" + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/http/parser.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/slice/b64.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/port.h" +#include "test/core/util/resource_user_util.h" + +struct grpc_end2end_http_proxy { + grpc_end2end_http_proxy() + : server(nullptr), channel_args(nullptr), mu(nullptr), combiner(nullptr) { + gpr_ref_init(&users, 1); + combiner = grpc_combiner_create(); + } + std::string proxy_name; + grpc_core::Thread thd; + grpc_tcp_server* server; + grpc_channel_args* channel_args; + gpr_mu* mu; + std::vector pollset; + gpr_refcount users; + + grpc_core::Combiner* combiner; +}; + +// +// Connection handling +// + +// proxy_connection structure is only accessed in the closures which are all +// scheduled under the same combiner lock. So there is no need for a mutex to +// protect this structure. +typedef struct proxy_connection { + grpc_end2end_http_proxy* proxy; + + grpc_endpoint* client_endpoint; + grpc_endpoint* server_endpoint; + + gpr_refcount refcount; + + grpc_pollset_set* pollset_set; + + // NOTE: All the closures execute under proxy->combiner lock. 
Which means + // there will not be any data-races between the closures + grpc_closure on_read_request_done; + grpc_closure on_server_connect_done; + grpc_closure on_write_response_done; + grpc_closure on_client_read_done; + grpc_closure on_client_write_done; + grpc_closure on_server_read_done; + grpc_closure on_server_write_done; + + bool client_read_failed : 1; + bool client_write_failed : 1; + bool client_shutdown : 1; + bool server_read_failed : 1; + bool server_write_failed : 1; + bool server_shutdown : 1; + + grpc_slice_buffer client_read_buffer; + grpc_slice_buffer client_deferred_write_buffer; + bool client_is_writing; + grpc_slice_buffer client_write_buffer; + grpc_slice_buffer server_read_buffer; + grpc_slice_buffer server_deferred_write_buffer; + bool server_is_writing; + grpc_slice_buffer server_write_buffer; + + grpc_http_parser http_parser; + grpc_http_request http_request; +} proxy_connection; + +static void proxy_connection_ref(proxy_connection* conn, + const char* /*reason*/) { + gpr_ref(&conn->refcount); +} + +// Helper function to destroy the proxy connection. +static void proxy_connection_unref(proxy_connection* conn, + const char* /*reason*/) { + if (gpr_unref(&conn->refcount)) { + gpr_log(GPR_DEBUG, "endpoints: %p %p", conn->client_endpoint, + conn->server_endpoint); + grpc_endpoint_destroy(conn->client_endpoint); + if (conn->server_endpoint != nullptr) { + grpc_endpoint_destroy(conn->server_endpoint); + } + grpc_pollset_set_destroy(conn->pollset_set); + grpc_slice_buffer_destroy_internal(&conn->client_read_buffer); + grpc_slice_buffer_destroy_internal(&conn->client_deferred_write_buffer); + grpc_slice_buffer_destroy_internal(&conn->client_write_buffer); + grpc_slice_buffer_destroy_internal(&conn->server_read_buffer); + grpc_slice_buffer_destroy_internal(&conn->server_deferred_write_buffer); + grpc_slice_buffer_destroy_internal(&conn->server_write_buffer); + grpc_http_parser_destroy(&conn->http_parser); + grpc_http_request_destroy(&conn->http_request); + gpr_unref(&conn->proxy->users); + gpr_free(conn); + } +} + +enum failure_type { + SETUP_FAILED, // To be used before we start proxying. + CLIENT_READ_FAILED, + CLIENT_WRITE_FAILED, + SERVER_READ_FAILED, + SERVER_WRITE_FAILED, +}; + +// Forward declarations +static void on_client_write_done(void* arg, grpc_error_handle error); +static void on_server_write_done(void* arg, grpc_error_handle error); +static void on_client_read_done(void* arg, grpc_error_handle error); +static void on_server_read_done(void* arg, grpc_error_handle error); +static void on_server_connect_done(void* arg, grpc_error_handle error); +static void on_read_request_done(void* arg, grpc_error_handle error); + +static void on_client_write_done_locked(void* arg, grpc_error_handle error); +static void on_server_write_done_locked(void* arg, grpc_error_handle error); +static void on_client_read_done_locked(void* arg, grpc_error_handle error); +static void on_server_read_done_locked(void* arg, grpc_error_handle error); +static void on_server_connect_done_locked(void* arg, grpc_error_handle error); +static void on_read_request_done_locked(void* arg, grpc_error_handle error); + +// Helper function to shut down the proxy connection. +static void proxy_connection_failed(proxy_connection* conn, + failure_type failure, const char* prefix, + grpc_error_handle error) { + gpr_log(GPR_INFO, "%s: %s", prefix, grpc_error_std_string(error).c_str()); + // Decide whether we should shut down the client and server. 
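+  // A failure during setup tears down both endpoints immediately. Once
+  // proxying has started, an endpoint is shut down only when both of its
+  // directions have already failed, or when the read from its peer fails
+  // while no write toward it is still in flight that could flush buffered
+  // data.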
+ bool shutdown_client = false; + bool shutdown_server = false; + if (failure == SETUP_FAILED) { + shutdown_client = true; + shutdown_server = true; + } else { + if ((failure == CLIENT_READ_FAILED && conn->client_write_failed) || + (failure == CLIENT_WRITE_FAILED && conn->client_read_failed) || + (failure == SERVER_READ_FAILED && !conn->client_is_writing)) { + shutdown_client = true; + } + if ((failure == SERVER_READ_FAILED && conn->server_write_failed) || + (failure == SERVER_WRITE_FAILED && conn->server_read_failed) || + (failure == CLIENT_READ_FAILED && !conn->server_is_writing)) { + shutdown_server = true; + } + } + // If we decided to shut down either one and have not yet done so, do so. + if (shutdown_client && !conn->client_shutdown) { + grpc_endpoint_shutdown(conn->client_endpoint, GRPC_ERROR_REF(error)); + conn->client_shutdown = true; + } + if (shutdown_server && !conn->server_shutdown && + (conn->server_endpoint != nullptr)) { + grpc_endpoint_shutdown(conn->server_endpoint, GRPC_ERROR_REF(error)); + conn->server_shutdown = true; + } + // Unref the connection. + proxy_connection_unref(conn, "conn_failed"); + GRPC_ERROR_UNREF(error); +} + +// Callback for writing proxy data to the client. +static void on_client_write_done_locked(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + conn->client_is_writing = false; + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, CLIENT_WRITE_FAILED, + "HTTP proxy client write", GRPC_ERROR_REF(error)); + return; + } + // Clear write buffer (the data we just wrote). + grpc_slice_buffer_reset_and_unref(&conn->client_write_buffer); + // If more data was read from the server since we started this write, + // write that data now. + if (conn->client_deferred_write_buffer.length > 0) { + grpc_slice_buffer_move_into(&conn->client_deferred_write_buffer, + &conn->client_write_buffer); + conn->client_is_writing = true; + GRPC_CLOSURE_INIT(&conn->on_client_write_done, on_client_write_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_write(conn->client_endpoint, &conn->client_write_buffer, + &conn->on_client_write_done, nullptr); + } else { + // No more writes. Unref the connection. + proxy_connection_unref(conn, "write_done"); + } +} + +static void on_client_write_done(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + GRPC_CLOSURE_INIT(&conn->on_client_write_done, on_client_write_done_locked, + conn, nullptr); + conn->proxy->combiner->Run(&conn->on_client_write_done, + GRPC_ERROR_REF(error)); +} + +// Callback for writing proxy data to the backend server. +static void on_server_write_done_locked(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + conn->server_is_writing = false; + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, SERVER_WRITE_FAILED, + "HTTP proxy server write", GRPC_ERROR_REF(error)); + return; + } + // Clear write buffer (the data we just wrote). + grpc_slice_buffer_reset_and_unref(&conn->server_write_buffer); + // If more data was read from the client since we started this write, + // write that data now. 
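+  // Only one write may be outstanding per endpoint, so reads that complete
+  // while a write is in flight are parked in the *_deferred_write_buffer.
+  // When the in-flight write finishes, the deferred data is promoted into
+  // the write buffer and a new write is issued; otherwise the ref taken for
+  // this write chain is released.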
+ if (conn->server_deferred_write_buffer.length > 0) { + grpc_slice_buffer_move_into(&conn->server_deferred_write_buffer, + &conn->server_write_buffer); + conn->server_is_writing = true; + GRPC_CLOSURE_INIT(&conn->on_server_write_done, on_server_write_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_write(conn->server_endpoint, &conn->server_write_buffer, + &conn->on_server_write_done, nullptr); + } else { + // No more writes. Unref the connection. + proxy_connection_unref(conn, "server_write"); + } +} + +static void on_server_write_done(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + GRPC_CLOSURE_INIT(&conn->on_server_write_done, on_server_write_done_locked, + conn, nullptr); + conn->proxy->combiner->Run(&conn->on_server_write_done, + GRPC_ERROR_REF(error)); +} + +// Callback for reading data from the client, which will be proxied to +// the backend server. +static void on_client_read_done_locked(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, CLIENT_READ_FAILED, "HTTP proxy client read", + GRPC_ERROR_REF(error)); + return; + } + // If there is already a pending write (i.e., server_write_buffer is + // not empty), then move the read data into server_deferred_write_buffer, + // and the next write will be requested in on_server_write_done(), when + // the current write is finished. + // + // Otherwise, move the read data into the write buffer and write it. + if (conn->server_is_writing) { + grpc_slice_buffer_move_into(&conn->client_read_buffer, + &conn->server_deferred_write_buffer); + } else { + grpc_slice_buffer_move_into(&conn->client_read_buffer, + &conn->server_write_buffer); + proxy_connection_ref(conn, "client_read"); + conn->server_is_writing = true; + GRPC_CLOSURE_INIT(&conn->on_server_write_done, on_server_write_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_write(conn->server_endpoint, &conn->server_write_buffer, + &conn->on_server_write_done, nullptr); + } + // Read more data. + GRPC_CLOSURE_INIT(&conn->on_client_read_done, on_client_read_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(conn->client_endpoint, &conn->client_read_buffer, + &conn->on_client_read_done, /*urgent=*/false); +} + +static void on_client_read_done(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + GRPC_CLOSURE_INIT(&conn->on_client_read_done, on_client_read_done_locked, + conn, nullptr); + conn->proxy->combiner->Run(&conn->on_client_read_done, GRPC_ERROR_REF(error)); +} + +// Callback for reading data from the backend server, which will be +// proxied to the client. +static void on_server_read_done_locked(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, SERVER_READ_FAILED, "HTTP proxy server read", + GRPC_ERROR_REF(error)); + return; + } + // If there is already a pending write (i.e., client_write_buffer is + // not empty), then move the read data into client_deferred_write_buffer, + // and the next write will be requested in on_client_write_done(), when + // the current write is finished. + // + // Otherwise, move the read data into the write buffer and write it. 
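+  // Mirror image of on_client_read_done_locked(): data read from the server
+  // is either deferred behind the write currently headed to the client or
+  // written out immediately (taking a new ref for that write chain), and a
+  // fresh read is re-armed on the server endpoint below.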
+ if (conn->client_is_writing) { + grpc_slice_buffer_move_into(&conn->server_read_buffer, + &conn->client_deferred_write_buffer); + } else { + grpc_slice_buffer_move_into(&conn->server_read_buffer, + &conn->client_write_buffer); + proxy_connection_ref(conn, "server_read"); + conn->client_is_writing = true; + GRPC_CLOSURE_INIT(&conn->on_client_write_done, on_client_write_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_write(conn->client_endpoint, &conn->client_write_buffer, + &conn->on_client_write_done, nullptr); + } + // Read more data. + GRPC_CLOSURE_INIT(&conn->on_server_read_done, on_server_read_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(conn->server_endpoint, &conn->server_read_buffer, + &conn->on_server_read_done, /*urgent=*/false); +} + +static void on_server_read_done(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + GRPC_CLOSURE_INIT(&conn->on_server_read_done, on_server_read_done_locked, + conn, nullptr); + conn->proxy->combiner->Run(&conn->on_server_read_done, GRPC_ERROR_REF(error)); +} + +// Callback to write the HTTP response for the CONNECT request. +static void on_write_response_done_locked(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + conn->client_is_writing = false; + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, SETUP_FAILED, "HTTP proxy write response", + GRPC_ERROR_REF(error)); + return; + } + // Clear write buffer. + grpc_slice_buffer_reset_and_unref(&conn->client_write_buffer); + // Start reading from both client and server. One of the read + // requests inherits our ref to conn, but we need to take a new ref + // for the other one. + proxy_connection_ref(conn, "client_read"); + proxy_connection_ref(conn, "server_read"); + proxy_connection_unref(conn, "write_response"); + GRPC_CLOSURE_INIT(&conn->on_client_read_done, on_client_read_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(conn->client_endpoint, &conn->client_read_buffer, + &conn->on_client_read_done, /*urgent=*/false); + GRPC_CLOSURE_INIT(&conn->on_server_read_done, on_server_read_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(conn->server_endpoint, &conn->server_read_buffer, + &conn->on_server_read_done, /*urgent=*/false); +} + +static void on_write_response_done(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + GRPC_CLOSURE_INIT(&conn->on_write_response_done, + on_write_response_done_locked, conn, nullptr); + conn->proxy->combiner->Run(&conn->on_write_response_done, + GRPC_ERROR_REF(error)); +} + +// Callback to connect to the backend server specified by the HTTP +// CONNECT request. +static void on_server_connect_done_locked(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + // TODO(roth): Technically, in this case, we should handle the error + // by returning an HTTP response to the client indicating that the + // connection failed. However, for the purposes of this test code, + // it's fine to pretend this is a client-side error, which will + // cause the client connection to be dropped. + proxy_connection_failed(conn, SETUP_FAILED, "HTTP proxy server connect", + GRPC_ERROR_REF(error)); + return; + } + // We've established a connection, so send back a 200 response code to + // the client. + // The write callback inherits our reference to conn. 
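+  // From the client's perspective the CONNECT handshake ends here: once the
+  // "HTTP/1.0 200 connected" status line and blank line go out, both
+  // endpoints are treated as an opaque byte tunnel and the read/write pumps
+  // above take over.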
+ grpc_slice slice = + grpc_slice_from_copied_string("HTTP/1.0 200 connected\r\n\r\n"); + grpc_slice_buffer_add(&conn->client_write_buffer, slice); + conn->client_is_writing = true; + GRPC_CLOSURE_INIT(&conn->on_write_response_done, on_write_response_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_write(conn->client_endpoint, &conn->client_write_buffer, + &conn->on_write_response_done, nullptr); +} + +static void on_server_connect_done(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + GRPC_CLOSURE_INIT(&conn->on_server_connect_done, + on_server_connect_done_locked, conn, nullptr); + conn->proxy->combiner->Run(&conn->on_server_connect_done, + GRPC_ERROR_REF(error)); +} + +/** + * Parses the proxy auth header value to check if it matches :- + * Basic + * Returns true if it matches, false otherwise + */ +static bool proxy_auth_header_matches(char* proxy_auth_header_val, + char* expected_cred) { + GPR_ASSERT(proxy_auth_header_val != nullptr); + GPR_ASSERT(expected_cred != nullptr); + if (strncmp(proxy_auth_header_val, "Basic ", 6) != 0) { + return false; + } + proxy_auth_header_val += 6; + grpc_slice decoded_slice = grpc_base64_decode(proxy_auth_header_val, 0); + const bool header_matches = + grpc_slice_str_cmp(decoded_slice, expected_cred) == 0; + grpc_slice_unref_internal(decoded_slice); + return header_matches; +} + +// Callback to read the HTTP CONNECT request. +// TODO(roth): Technically, for any of the failure modes handled by this +// function, we should handle the error by returning an HTTP response to +// the client indicating that the request failed. However, for the purposes +// of this test code, it's fine to pretend this is a client-side error, +// which will cause the client connection to be dropped. +static void on_read_request_done_locked(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + gpr_log(GPR_DEBUG, "on_read_request_done: %p %s", conn, + grpc_error_std_string(error).c_str()); + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, SETUP_FAILED, "HTTP proxy read request", + GRPC_ERROR_REF(error)); + return; + } + // Read request and feed it to the parser. + for (size_t i = 0; i < conn->client_read_buffer.count; ++i) { + if (GRPC_SLICE_LENGTH(conn->client_read_buffer.slices[i]) > 0) { + error = grpc_http_parser_parse( + &conn->http_parser, conn->client_read_buffer.slices[i], nullptr); + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, SETUP_FAILED, "HTTP proxy request parse", + GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); + return; + } + } + } + grpc_slice_buffer_reset_and_unref(&conn->client_read_buffer); + // If we're not done reading the request, read more data. + if (conn->http_parser.state != GRPC_HTTP_BODY) { + GRPC_CLOSURE_INIT(&conn->on_read_request_done, on_read_request_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(conn->client_endpoint, &conn->client_read_buffer, + &conn->on_read_request_done, /*urgent=*/false); + return; + } + // Make sure we got a CONNECT request. 
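+  // The parser now holds a complete request. This proxy only speaks HTTP
+  // CONNECT; a client reaching the backend through it sends something like
+  //
+  //   CONNECT localhost:12345 HTTP/1.0
+  //   Proxy-Authorization: Basic dXNlcjpwYXNz
+  //
+  // (the Proxy-Authorization header only when auth is configured; the value
+  // shown is just base64 of "user:pass" for illustration). The request path
+  // carries the target host:port, which is resolved and dialed below; any
+  // other method is rejected.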
+ if (strcmp(conn->http_request.method, "CONNECT") != 0) { + error = GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "HTTP proxy got request method ", conn->http_request.method)); + proxy_connection_failed(conn, SETUP_FAILED, "HTTP proxy read request", + GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); + return; + } + // If proxy auth is being used, check if the header is present and as expected + const grpc_arg* proxy_auth_arg = grpc_channel_args_find( + conn->proxy->channel_args, GRPC_ARG_HTTP_PROXY_AUTH_CREDS); + char* proxy_auth_str = grpc_channel_arg_get_string(proxy_auth_arg); + if (proxy_auth_str != nullptr) { + bool client_authenticated = false; + for (size_t i = 0; i < conn->http_request.hdr_count; i++) { + if (strcmp(conn->http_request.hdrs[i].key, "Proxy-Authorization") == 0) { + client_authenticated = proxy_auth_header_matches( + conn->http_request.hdrs[i].value, proxy_auth_str); + break; + } + } + if (!client_authenticated) { + const char* msg = "HTTP Connect could not verify authentication"; + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(msg); + proxy_connection_failed(conn, SETUP_FAILED, "HTTP proxy read request", + GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); + return; + } + } + // Resolve address. + grpc_resolved_addresses* resolved_addresses = nullptr; + error = grpc_blocking_resolve_address(conn->http_request.path, "80", + &resolved_addresses); + if (error != GRPC_ERROR_NONE) { + proxy_connection_failed(conn, SETUP_FAILED, "HTTP proxy DNS lookup", + GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); + return; + } + GPR_ASSERT(resolved_addresses->naddrs >= 1); + // Connect to requested address. + // The connection callback inherits our reference to conn. + const grpc_millis deadline = + grpc_core::ExecCtx::Get()->Now() + 10 * GPR_MS_PER_SEC; + GRPC_CLOSURE_INIT(&conn->on_server_connect_done, on_server_connect_done, conn, + grpc_schedule_on_exec_ctx); + grpc_tcp_client_connect(&conn->on_server_connect_done, &conn->server_endpoint, + grpc_slice_allocator_create_unlimited(), + conn->pollset_set, nullptr, + &resolved_addresses->addrs[0], deadline); + grpc_resolved_addresses_destroy(resolved_addresses); +} + +static void on_read_request_done(void* arg, grpc_error_handle error) { + proxy_connection* conn = static_cast(arg); + GRPC_CLOSURE_INIT(&conn->on_read_request_done, on_read_request_done_locked, + conn, nullptr); + conn->proxy->combiner->Run(&conn->on_read_request_done, + GRPC_ERROR_REF(error)); +} + +static void on_accept(void* arg, grpc_endpoint* endpoint, + grpc_pollset* /*accepting_pollset*/, + grpc_tcp_server_acceptor* acceptor) { + gpr_free(acceptor); + grpc_end2end_http_proxy* proxy = static_cast(arg); + // Instantiate proxy_connection. 
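+  // Every accepted client gets its own zero-initialized proxy_connection.
+  // The connection takes a ref on proxy->users (keeping the polling thread
+  // alive), joins the proxy's pollset through its own pollset_set, and
+  // immediately starts reading what should be the client's CONNECT request.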
+ proxy_connection* conn = grpc_core::Zalloc(); + gpr_ref(&proxy->users); + conn->client_endpoint = endpoint; + conn->proxy = proxy; + gpr_ref_init(&conn->refcount, 1); + conn->pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(conn->pollset_set, proxy->pollset[0]); + grpc_endpoint_add_to_pollset_set(endpoint, conn->pollset_set); + grpc_slice_buffer_init(&conn->client_read_buffer); + grpc_slice_buffer_init(&conn->client_deferred_write_buffer); + conn->client_is_writing = false; + grpc_slice_buffer_init(&conn->client_write_buffer); + grpc_slice_buffer_init(&conn->server_read_buffer); + grpc_slice_buffer_init(&conn->server_deferred_write_buffer); + conn->server_is_writing = false; + grpc_slice_buffer_init(&conn->server_write_buffer); + grpc_http_parser_init(&conn->http_parser, GRPC_HTTP_REQUEST, + &conn->http_request); + GRPC_CLOSURE_INIT(&conn->on_read_request_done, on_read_request_done, conn, + grpc_schedule_on_exec_ctx); + grpc_endpoint_read(conn->client_endpoint, &conn->client_read_buffer, + &conn->on_read_request_done, /*urgent=*/false); +} + +// +// Proxy class +// + +static void thread_main(void* arg) { + grpc_end2end_http_proxy* proxy = static_cast(arg); + grpc_core::ExecCtx exec_ctx; + do { + gpr_ref(&proxy->users); + grpc_pollset_worker* worker = nullptr; + gpr_mu_lock(proxy->mu); + GRPC_LOG_IF_ERROR( + "grpc_pollset_work", + grpc_pollset_work(proxy->pollset[0], &worker, + grpc_core::ExecCtx::Get()->Now() + GPR_MS_PER_SEC)); + gpr_mu_unlock(proxy->mu); + grpc_core::ExecCtx::Get()->Flush(); + } while (!gpr_unref(&proxy->users)); +} + +grpc_end2end_http_proxy* grpc_end2end_http_proxy_create( + grpc_channel_args* args) { + grpc_core::ExecCtx exec_ctx; + grpc_end2end_http_proxy* proxy = new grpc_end2end_http_proxy(); + // Construct proxy address. + const int proxy_port = grpc_pick_unused_port_or_die(); + proxy->proxy_name = grpc_core::JoinHostPort("localhost", proxy_port); + gpr_log(GPR_INFO, "Proxy address: %s", proxy->proxy_name.c_str()); + // Create TCP server. + proxy->channel_args = grpc_channel_args_copy(args); + grpc_error_handle error = grpc_tcp_server_create( + nullptr, proxy->channel_args, + grpc_slice_allocator_factory_create( + grpc_resource_quota_from_channel_args(args, true)), + &proxy->server); + GPR_ASSERT(error == GRPC_ERROR_NONE); + // Bind to port. + grpc_resolved_address resolved_addr; + grpc_sockaddr_in* addr = + reinterpret_cast(resolved_addr.addr); + memset(&resolved_addr, 0, sizeof(resolved_addr)); + addr->sin_family = GRPC_AF_INET; + grpc_sockaddr_set_port(&resolved_addr, proxy_port); + int port; + error = grpc_tcp_server_add_port(proxy->server, &resolved_addr, &port); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(port == proxy_port); + // Start server. + auto* pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &proxy->mu); + proxy->pollset.push_back(pollset); + grpc_tcp_server_start(proxy->server, &proxy->pollset, on_accept, proxy); + + // Start proxy thread. + proxy->thd = grpc_core::Thread("grpc_http_proxy", thread_main, proxy); + proxy->thd.Start(); + return proxy; +} + +static void destroy_pollset(void* arg, grpc_error_handle /*error*/) { + grpc_pollset* pollset = static_cast(arg); + grpc_pollset_destroy(pollset); + gpr_free(pollset); +} + +void grpc_end2end_http_proxy_destroy(grpc_end2end_http_proxy* proxy) { + gpr_unref(&proxy->users); // Signal proxy thread to shutdown. 
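+  // thread_main() takes and drops a users ref on every poll iteration, so
+  // releasing the constructor's initial ref here is what eventually makes
+  // that gpr_unref() hit zero and lets the worker loop exit; the listener,
+  // channel args, pollset and combiner are then torn down below.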
+ grpc_core::ExecCtx exec_ctx; + proxy->thd.Join(); + grpc_tcp_server_shutdown_listeners(proxy->server); + grpc_tcp_server_unref(proxy->server); + grpc_channel_args_destroy(proxy->channel_args); + grpc_pollset_shutdown(proxy->pollset[0], + GRPC_CLOSURE_CREATE(destroy_pollset, proxy->pollset[0], + grpc_schedule_on_exec_ctx)); + GRPC_COMBINER_UNREF(proxy->combiner, "test"); + delete proxy; +} + +const char* grpc_end2end_http_proxy_get_proxy_name( + grpc_end2end_http_proxy* proxy) { + return proxy->proxy_name.c_str(); +} diff --git a/test/core/end2end/fixtures/inproc.cc b/test/core/end2end/fixtures/inproc.cc new file mode 100644 index 00000000..71f8d352 --- /dev/null +++ b/test/core/end2end/fixtures/inproc.cc @@ -0,0 +1,95 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/inproc/inproc_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +typedef struct inproc_fixture_data { + bool phony; // reserved for future expansion. 
Struct can't be empty +} inproc_fixture_data; + +static grpc_end2end_test_fixture inproc_create_fixture( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + inproc_fixture_data* ffd = static_cast( + gpr_malloc(sizeof(inproc_fixture_data))); + memset(&f, 0, sizeof(f)); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void inproc_init_client(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + f->client = grpc_inproc_channel_create(f->server, client_args, nullptr); + GPR_ASSERT(f->client); +} + +void inproc_init_server(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + grpc_server_start(f->server); +} + +void inproc_tear_down(grpc_end2end_test_fixture* f) { + inproc_fixture_data* ffd = static_cast(f->fixture_data); + gpr_free(ffd); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"inproc", FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, nullptr, + inproc_create_fixture, inproc_init_client, inproc_init_server, + inproc_tear_down}, +}; + +int main(int argc, char** argv) { + size_t i; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_end2end_tests_pre_init(); + grpc_init(); + + for (i = 0; i < sizeof(configs) / sizeof(*configs); i++) { + grpc_end2end_tests(argc, argv, configs[i]); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/fixtures/local_util.cc b/test/core/end2end/fixtures/local_util.cc new file mode 100644 index 00000000..d7431f5d --- /dev/null +++ b/test/core/end2end/fixtures/local_util.cc @@ -0,0 +1,112 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/core/end2end/fixtures/local_util.h" + +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/surface/server.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +grpc_end2end_test_fixture grpc_end2end_local_chttp2_create_fixture_fullstack() { + grpc_end2end_test_fixture f; + grpc_end2end_local_fullstack_fixture_data* ffd = + new grpc_end2end_local_fullstack_fixture_data(); + memset(&f, 0, sizeof(f)); + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + return f; +} + +void grpc_end2end_local_chttp2_init_client_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args, + grpc_local_connect_type type) { + grpc_channel_credentials* creds = grpc_local_credentials_create(type); + grpc_end2end_local_fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +/* + * Check if server should fail auth check. If it is true, a different metadata + * processor will be installed that always fails in processing client's + * metadata. + */ +static bool fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return false; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return true; + } + } + return false; +} + +static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, + size_t /*md_count*/, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +void grpc_end2end_local_chttp2_init_server_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_local_connect_type type) { + grpc_server_credentials* creds = grpc_local_server_credentials_create(type); + grpc_end2end_local_fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + if (fail_server_auth_check(server_args)) { + grpc_auth_metadata_processor processor = {process_auth_failure, nullptr, + nullptr}; + grpc_server_credentials_set_auth_metadata_processor(creds, processor); + } + GPR_ASSERT(grpc_server_add_secure_http2_port(f->server, + ffd->localaddr.c_str(), creds)); + grpc_server_credentials_release(creds); + grpc_server_start(f->server); +} + +void grpc_end2end_local_chttp2_tear_down_fullstack( + grpc_end2end_test_fixture* f) { + grpc_end2end_local_fullstack_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} diff --git a/test/core/end2end/fixtures/proxy.cc b/test/core/end2end/fixtures/proxy.cc new file mode 100644 index 00000000..84f6bf4b --- /dev/null +++ b/test/core/end2end/fixtures/proxy.cc @@ -0,0 +1,468 @@ +/* + * + * Copyright 
2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/end2end/fixtures/proxy.h" + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/surface/call.h" +#include "test/core/util/port.h" + +struct grpc_end2end_proxy { + grpc_end2end_proxy() + : cq(nullptr), + server(nullptr), + client(nullptr), + shutdown(false), + new_call(nullptr) { + memset(&new_call_details, 0, sizeof(new_call_details)); + memset(&new_call_metadata, 0, sizeof(new_call_metadata)); + } + grpc_core::Thread thd; + std::string proxy_port; + std::string server_port; + grpc_completion_queue* cq; + grpc_server* server; + grpc_channel* client; + + int shutdown; + + /* requested call */ + grpc_call* new_call; + grpc_call_details new_call_details; + grpc_metadata_array new_call_metadata; +}; + +typedef struct { + void (*func)(void* arg, int success); + void* arg; +} closure; + +typedef struct { + gpr_refcount refs; + grpc_end2end_proxy* proxy; + + grpc_call* c2p; + grpc_call* p2s; + + grpc_metadata_array c2p_initial_metadata; + grpc_metadata_array p2s_initial_metadata; + + grpc_byte_buffer* c2p_msg; + grpc_byte_buffer* p2s_msg; + + grpc_metadata_array p2s_trailing_metadata; + grpc_status_code p2s_status; + grpc_slice p2s_status_details; + + int c2p_server_cancelled; +} proxy_call; + +static void thread_main(void* arg); +static void request_call(grpc_end2end_proxy* proxy); + +grpc_end2end_proxy* grpc_end2end_proxy_create(const grpc_end2end_proxy_def* def, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + int proxy_port = grpc_pick_unused_port_or_die(); + int server_port = grpc_pick_unused_port_or_die(); + + grpc_end2end_proxy* proxy = new grpc_end2end_proxy(); + + proxy->proxy_port = grpc_core::JoinHostPort("localhost", proxy_port); + proxy->server_port = grpc_core::JoinHostPort("localhost", server_port); + + gpr_log(GPR_DEBUG, "PROXY ADDR:%s BACKEND:%s", proxy->proxy_port.c_str(), + proxy->server_port.c_str()); + + proxy->cq = grpc_completion_queue_create_for_next(nullptr); + proxy->server = def->create_server(proxy->proxy_port.c_str(), server_args); + + const char* arg_to_remove = GRPC_ARG_ENABLE_RETRIES; + grpc_arg arg_to_add = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_RETRIES), 0); + grpc_channel_args* proxy_client_args = + grpc_channel_args_copy_and_add_and_remove(client_args, &arg_to_remove, 1, + &arg_to_add, 1); + proxy->client = + def->create_client(proxy->server_port.c_str(), proxy_client_args); + grpc_channel_args_destroy(proxy_client_args); + + grpc_server_register_completion_queue(proxy->server, proxy->cq, nullptr); + grpc_server_start(proxy->server); + + grpc_call_details_init(&proxy->new_call_details); + proxy->thd = grpc_core::Thread("grpc_end2end_proxy", thread_main, proxy); + proxy->thd.Start(); + + request_call(proxy); + + return proxy; +} + +static 
closure* new_closure(void (*func)(void* arg, int success), void* arg) { + closure* cl = static_cast(gpr_malloc(sizeof(*cl))); + cl->func = func; + cl->arg = arg; + return cl; +} + +static void shutdown_complete(void* arg, int /*success*/) { + grpc_end2end_proxy* proxy = static_cast(arg); + proxy->shutdown = 1; + grpc_completion_queue_shutdown(proxy->cq); +} + +void grpc_end2end_proxy_destroy(grpc_end2end_proxy* proxy) { + grpc_server_shutdown_and_notify(proxy->server, proxy->cq, + new_closure(shutdown_complete, proxy)); + proxy->thd.Join(); + grpc_server_destroy(proxy->server); + grpc_channel_destroy(proxy->client); + grpc_completion_queue_destroy(proxy->cq); + grpc_call_details_destroy(&proxy->new_call_details); + delete proxy; +} + +static void unrefpc(proxy_call* pc, const char* /*reason*/) { + if (gpr_unref(&pc->refs)) { + grpc_call_unref(pc->c2p); + grpc_call_unref(pc->p2s); + grpc_metadata_array_destroy(&pc->c2p_initial_metadata); + grpc_metadata_array_destroy(&pc->p2s_initial_metadata); + grpc_metadata_array_destroy(&pc->p2s_trailing_metadata); + grpc_slice_unref(pc->p2s_status_details); + gpr_free(pc); + } +} + +static void refpc(proxy_call* pc, const char* /*reason*/) { + gpr_ref(&pc->refs); +} + +static void on_c2p_sent_initial_metadata(void* arg, int /*success*/) { + proxy_call* pc = static_cast(arg); + unrefpc(pc, "on_c2p_sent_initial_metadata"); +} + +static void on_p2s_recv_initial_metadata(void* arg, int /*success*/) { + proxy_call* pc = static_cast(arg); + grpc_op op; + grpc_call_error err; + + memset(&op, 0, sizeof(op)); + if (!pc->proxy->shutdown && !grpc_call_is_trailers_only(pc->p2s)) { + op.op = GRPC_OP_SEND_INITIAL_METADATA; + op.flags = 0; + op.reserved = nullptr; + op.data.send_initial_metadata.count = pc->p2s_initial_metadata.count; + op.data.send_initial_metadata.metadata = pc->p2s_initial_metadata.metadata; + refpc(pc, "on_c2p_sent_initial_metadata"); + err = grpc_call_start_batch(pc->c2p, &op, 1, + new_closure(on_c2p_sent_initial_metadata, pc), + nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + } + + unrefpc(pc, "on_p2s_recv_initial_metadata"); +} + +static void on_p2s_sent_initial_metadata(void* arg, int /*success*/) { + proxy_call* pc = static_cast(arg); + unrefpc(pc, "on_p2s_sent_initial_metadata"); +} + +static void on_c2p_recv_msg(void* arg, int success); + +static void on_p2s_sent_message(void* arg, int success) { + proxy_call* pc = static_cast(arg); + grpc_op op; + grpc_call_error err; + + grpc_byte_buffer_destroy(pc->c2p_msg); + if (!pc->proxy->shutdown && success) { + op.op = GRPC_OP_RECV_MESSAGE; + op.flags = 0; + op.reserved = nullptr; + op.data.recv_message.recv_message = &pc->c2p_msg; + refpc(pc, "on_c2p_recv_msg"); + err = grpc_call_start_batch(pc->c2p, &op, 1, + new_closure(on_c2p_recv_msg, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + } + + unrefpc(pc, "on_p2s_sent_message"); +} + +static void on_p2s_sent_close(void* arg, int /*success*/) { + proxy_call* pc = static_cast(arg); + unrefpc(pc, "on_p2s_sent_close"); +} + +static void on_c2p_recv_msg(void* arg, int success) { + proxy_call* pc = static_cast(arg); + grpc_op op; + grpc_call_error err; + + if (!pc->proxy->shutdown && success) { + if (pc->c2p_msg != nullptr) { + op.op = GRPC_OP_SEND_MESSAGE; + op.flags = 0; + op.reserved = nullptr; + op.data.send_message.send_message = pc->c2p_msg; + refpc(pc, "on_p2s_sent_message"); + err = grpc_call_start_batch( + pc->p2s, &op, 1, new_closure(on_p2s_sent_message, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + } else { + op.op = 
GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op.flags = 0; + op.reserved = nullptr; + refpc(pc, "on_p2s_sent_close"); + err = grpc_call_start_batch(pc->p2s, &op, 1, + new_closure(on_p2s_sent_close, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + } + } else { + if (pc->c2p_msg != nullptr) { + grpc_byte_buffer_destroy(pc->c2p_msg); + } + } + + unrefpc(pc, "on_c2p_recv_msg"); +} + +static void on_p2s_recv_msg(void* arg, int success); + +static void on_c2p_sent_message(void* arg, int success) { + proxy_call* pc = static_cast(arg); + grpc_op op; + grpc_call_error err; + + grpc_byte_buffer_destroy(pc->p2s_msg); + if (!pc->proxy->shutdown && success) { + op.op = GRPC_OP_RECV_MESSAGE; + op.flags = 0; + op.reserved = nullptr; + op.data.recv_message.recv_message = &pc->p2s_msg; + refpc(pc, "on_p2s_recv_msg"); + err = grpc_call_start_batch(pc->p2s, &op, 1, + new_closure(on_p2s_recv_msg, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + } + + unrefpc(pc, "on_c2p_sent_message"); +} + +static void on_p2s_recv_msg(void* arg, int success) { + proxy_call* pc = static_cast(arg); + grpc_op op; + grpc_call_error err; + + if (!pc->proxy->shutdown && success && pc->p2s_msg) { + op.op = GRPC_OP_SEND_MESSAGE; + op.flags = 0; + op.reserved = nullptr; + op.data.send_message.send_message = pc->p2s_msg; + refpc(pc, "on_c2p_sent_message"); + err = grpc_call_start_batch(pc->c2p, &op, 1, + new_closure(on_c2p_sent_message, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + } else { + grpc_byte_buffer_destroy(pc->p2s_msg); + } + unrefpc(pc, "on_p2s_recv_msg"); +} + +static void on_c2p_sent_status(void* arg, int /*success*/) { + proxy_call* pc = static_cast(arg); + unrefpc(pc, "on_c2p_sent_status"); +} + +static void on_p2s_status(void* arg, int success) { + proxy_call* pc = static_cast(arg); + grpc_op op[2]; // Possibly send empty initial metadata also if trailers-only + grpc_call_error err; + + memset(op, 0, sizeof(op)); + + if (!pc->proxy->shutdown) { + GPR_ASSERT(success); + + int op_count = 0; + if (grpc_call_is_trailers_only(pc->p2s)) { + op[op_count].op = GRPC_OP_SEND_INITIAL_METADATA; + op_count++; + } + + op[op_count].op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op[op_count].flags = 0; + op[op_count].reserved = nullptr; + op[op_count].data.send_status_from_server.trailing_metadata_count = + pc->p2s_trailing_metadata.count; + op[op_count].data.send_status_from_server.trailing_metadata = + pc->p2s_trailing_metadata.metadata; + op[op_count].data.send_status_from_server.status = pc->p2s_status; + op[op_count].data.send_status_from_server.status_details = + &pc->p2s_status_details; + op_count++; + refpc(pc, "on_c2p_sent_status"); + err = grpc_call_start_batch(pc->c2p, op, op_count, + new_closure(on_c2p_sent_status, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + } + + unrefpc(pc, "on_p2s_status"); +} + +static void on_c2p_closed(void* arg, int /*success*/) { + proxy_call* pc = static_cast(arg); + unrefpc(pc, "on_c2p_closed"); +} + +static void on_new_call(void* arg, int success) { + grpc_end2end_proxy* proxy = static_cast(arg); + grpc_call_error err; + + if (success) { + grpc_op op; + memset(&op, 0, sizeof(op)); + proxy_call* pc = static_cast(gpr_malloc(sizeof(*pc))); + memset(pc, 0, sizeof(*pc)); + pc->proxy = proxy; + std::swap(pc->c2p_initial_metadata, proxy->new_call_metadata); + pc->c2p = proxy->new_call; + pc->p2s = grpc_channel_create_call( + proxy->client, pc->c2p, GRPC_PROPAGATE_DEFAULTS, proxy->cq, + proxy->new_call_details.method, &proxy->new_call_details.host, + proxy->new_call_details.deadline, nullptr); + 
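+      // The proxy now holds two calls: c2p, accepted from the client, and
+      // p2s, just created toward the backend with the same method, host and
+      // deadline. The single-op batches started below shuttle initial
+      // metadata, messages, status and close events between the two; each
+      // batch takes a ref on pc that its completion closure releases.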
gpr_ref_init(&pc->refs, 1); + + op.reserved = nullptr; + + op.op = GRPC_OP_RECV_INITIAL_METADATA; + op.flags = 0; + op.data.recv_initial_metadata.recv_initial_metadata = + &pc->p2s_initial_metadata; + refpc(pc, "on_p2s_recv_initial_metadata"); + err = grpc_call_start_batch(pc->p2s, &op, 1, + new_closure(on_p2s_recv_initial_metadata, pc), + nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + + op.op = GRPC_OP_SEND_INITIAL_METADATA; + op.flags = proxy->new_call_details.flags; + op.data.send_initial_metadata.count = pc->c2p_initial_metadata.count; + op.data.send_initial_metadata.metadata = pc->c2p_initial_metadata.metadata; + refpc(pc, "on_p2s_sent_initial_metadata"); + err = grpc_call_start_batch(pc->p2s, &op, 1, + new_closure(on_p2s_sent_initial_metadata, pc), + nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + + op.op = GRPC_OP_RECV_MESSAGE; + op.flags = 0; + op.data.recv_message.recv_message = &pc->c2p_msg; + refpc(pc, "on_c2p_recv_msg"); + err = grpc_call_start_batch(pc->c2p, &op, 1, + new_closure(on_c2p_recv_msg, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + + op.op = GRPC_OP_RECV_MESSAGE; + op.flags = 0; + op.data.recv_message.recv_message = &pc->p2s_msg; + refpc(pc, "on_p2s_recv_msg"); + err = grpc_call_start_batch(pc->p2s, &op, 1, + new_closure(on_p2s_recv_msg, pc), nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + + op.op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op.flags = 0; + op.data.recv_status_on_client.trailing_metadata = + &pc->p2s_trailing_metadata; + op.data.recv_status_on_client.status = &pc->p2s_status; + op.data.recv_status_on_client.status_details = &pc->p2s_status_details; + refpc(pc, "on_p2s_status"); + err = grpc_call_start_batch(pc->p2s, &op, 1, new_closure(on_p2s_status, pc), + nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + + op.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op.flags = 0; + op.data.recv_close_on_server.cancelled = &pc->c2p_server_cancelled; + refpc(pc, "on_c2p_closed"); + err = grpc_call_start_batch(pc->c2p, &op, 1, new_closure(on_c2p_closed, pc), + nullptr); + GPR_ASSERT(err == GRPC_CALL_OK); + + request_call(proxy); + + grpc_call_details_destroy(&proxy->new_call_details); + grpc_call_details_init(&proxy->new_call_details); + + unrefpc(pc, "init"); + } else { + GPR_ASSERT(proxy->new_call == nullptr); + } +} + +static void request_call(grpc_end2end_proxy* proxy) { + proxy->new_call = nullptr; + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + proxy->server, &proxy->new_call, + &proxy->new_call_details, + &proxy->new_call_metadata, proxy->cq, + proxy->cq, new_closure(on_new_call, proxy))); +} + +static void thread_main(void* arg) { + grpc_end2end_proxy* proxy = static_cast(arg); + closure* cl; + for (;;) { + grpc_event ev = grpc_completion_queue_next( + proxy->cq, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + switch (ev.type) { + case GRPC_QUEUE_TIMEOUT: + gpr_log(GPR_ERROR, "Should never reach here"); + abort(); + case GRPC_QUEUE_SHUTDOWN: + return; + case GRPC_OP_COMPLETE: + cl = static_cast(ev.tag); + cl->func(cl->arg, ev.success); + gpr_free(cl); + break; + } + } +} + +const char* grpc_end2end_proxy_get_client_target(grpc_end2end_proxy* proxy) { + return proxy->proxy_port.c_str(); +} + +const char* grpc_end2end_proxy_get_server_port(grpc_end2end_proxy* proxy) { + return proxy->server_port.c_str(); +} diff --git a/test/core/end2end/fuzzers/api_fuzzer.cc b/test/core/end2end/fuzzers/api_fuzzer.cc new file mode 100644 index 00000000..f62c1de2 --- /dev/null +++ b/test/core/end2end/fuzzers/api_fuzzer.cc @@ -0,0 +1,1071 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/transport/metadata.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "test/core/end2end/data/ssl_test_data.h" +#include "test/core/end2end/fuzzers/api_fuzzer.pb.h" +#include "test/core/util/passthru_endpoint.h" + +//////////////////////////////////////////////////////////////////////////////// +// logging + +bool squelch = true; +bool leak_check = true; + +static void dont_log(gpr_log_func_args* /*args*/) {} + +//////////////////////////////////////////////////////////////////////////////// +// global state + +static gpr_timespec g_now; +static grpc_server* g_server; +static grpc_channel* g_channel; +static grpc_resource_quota* g_resource_quota; + +extern gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type); + +static gpr_timespec now_impl(gpr_clock_type clock_type) { + GPR_ASSERT(clock_type != GPR_TIMESPAN); + gpr_timespec ts = g_now; + ts.clock_type = clock_type; + return ts; +} + +//////////////////////////////////////////////////////////////////////////////// +// dns resolution + +typedef struct addr_req { + grpc_timer timer; + char* addr; + grpc_closure* on_done; + grpc_resolved_addresses** addrs; + std::unique_ptr* addresses; +} addr_req; + +static void finish_resolve(void* arg, grpc_error_handle error) { + addr_req* r = static_cast(arg); + + if (error == GRPC_ERROR_NONE && 0 == strcmp(r->addr, "server")) { + if (r->addrs != nullptr) { + grpc_resolved_addresses* addrs = + static_cast(gpr_malloc(sizeof(*addrs))); + addrs->naddrs = 1; + addrs->addrs = static_cast( + gpr_malloc(sizeof(*addrs->addrs))); + addrs->addrs[0].len = 0; + *r->addrs = addrs; + } else if (r->addresses != nullptr) { + *r->addresses = absl::make_unique(); + grpc_resolved_address fake_resolved_address; + memset(&fake_resolved_address, 0, sizeof(fake_resolved_address)); + fake_resolved_address.len = 0; + (*r->addresses)->emplace_back(fake_resolved_address, nullptr); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_done, GRPC_ERROR_NONE); + } else { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, r->on_done, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Resolution failed", &error, 1)); + } + + gpr_free(r->addr); + delete r; +} + +void my_resolve_address(const char* addr, const char* /*default_port*/, + 
grpc_pollset_set* /*interested_parties*/, + grpc_closure* on_done, + grpc_resolved_addresses** addrs) { + addr_req* r = new addr_req(); + r->addr = gpr_strdup(addr); + r->on_done = on_done; + r->addrs = addrs; + grpc_timer_init( + &r->timer, GPR_MS_PER_SEC + grpc_core::ExecCtx::Get()->Now(), + GRPC_CLOSURE_CREATE(finish_resolve, r, grpc_schedule_on_exec_ctx)); +} + +static grpc_address_resolver_vtable fuzzer_resolver = {my_resolve_address, + nullptr}; + +grpc_ares_request* my_dns_lookup_ares_locked( + const char* /*dns_server*/, const char* addr, const char* /*default_port*/, + grpc_pollset_set* /*interested_parties*/, grpc_closure* on_done, + std::unique_ptr* addresses, + std::unique_ptr* /*balancer_addresses*/, + char** /*service_config_json*/, int /*query_timeout*/, + std::shared_ptr /*combiner*/) { + addr_req* r = new addr_req(); + r->addr = gpr_strdup(addr); + r->on_done = on_done; + r->addrs = nullptr; + r->addresses = addresses; + grpc_timer_init( + &r->timer, GPR_MS_PER_SEC + grpc_core::ExecCtx::Get()->Now(), + GRPC_CLOSURE_CREATE(finish_resolve, r, grpc_schedule_on_exec_ctx)); + return nullptr; +} + +static void my_cancel_ares_request_locked(grpc_ares_request* request) { + GPR_ASSERT(request == nullptr); +} + +//////////////////////////////////////////////////////////////////////////////// +// client connection + +static void sched_connect(grpc_closure* closure, + grpc_slice_allocator* slice_allocator, + grpc_endpoint** ep, gpr_timespec deadline); + +typedef struct { + grpc_timer timer; + grpc_closure* closure; + grpc_endpoint** ep; + gpr_timespec deadline; + grpc_slice_allocator* slice_allocator; +} future_connect; + +static void do_connect(void* arg, grpc_error_handle error) { + future_connect* fc = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + grpc_slice_allocator_destroy(fc->slice_allocator); + *fc->ep = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, fc->closure, GRPC_ERROR_REF(error)); + } else if (g_server != nullptr) { + grpc_slice_allocator_destroy(fc->slice_allocator); + grpc_endpoint* client; + grpc_endpoint* server; + grpc_passthru_endpoint_create(&client, &server, nullptr); + *fc->ep = client; + + grpc_transport* transport = grpc_create_chttp2_transport( + nullptr, server, false, + grpc_resource_user_create(g_resource_quota, "transport-user")); + GPR_ASSERT(GRPC_LOG_IF_ERROR("SetupTransport", + g_server->core_server->SetupTransport( + transport, nullptr, nullptr, nullptr))); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + + grpc_core::ExecCtx::Run(DEBUG_LOCATION, fc->closure, GRPC_ERROR_NONE); + } else { + sched_connect(fc->closure, fc->slice_allocator, fc->ep, fc->deadline); + } + gpr_free(fc); +} + +static void sched_connect(grpc_closure* closure, + grpc_slice_allocator* slice_allocator, + grpc_endpoint** ep, gpr_timespec deadline) { + if (gpr_time_cmp(deadline, gpr_now(deadline.clock_type)) < 0) { + *ep = nullptr; + grpc_slice_allocator_destroy(slice_allocator); + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, closure, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Connect deadline exceeded")); + return; + } + + future_connect* fc = static_cast(gpr_malloc(sizeof(*fc))); + fc->closure = closure; + fc->ep = ep; + fc->deadline = deadline; + fc->slice_allocator = slice_allocator; + grpc_timer_init( + &fc->timer, GPR_MS_PER_SEC + grpc_core::ExecCtx::Get()->Now(), + GRPC_CLOSURE_CREATE(do_connect, fc, grpc_schedule_on_exec_ctx)); +} + +static void my_tcp_client_connect(grpc_closure* closure, grpc_endpoint** ep, + grpc_slice_allocator* 
slice_allocator, + grpc_pollset_set* /*interested_parties*/, + const grpc_channel_args* /*channel_args*/, + const grpc_resolved_address* /*addr*/, + grpc_millis deadline) { + sched_connect(closure, slice_allocator, ep, + grpc_millis_to_timespec(deadline, GPR_CLOCK_MONOTONIC)); +} + +grpc_tcp_client_vtable fuzz_tcp_client_vtable = {my_tcp_client_connect}; + +//////////////////////////////////////////////////////////////////////////////// +// test driver + +class Validator { + public: + explicit Validator(std::function impl) : impl_(impl) {} + + virtual ~Validator() {} + void Run(bool success) { + impl_(success); + delete this; + } + + private: + std::function impl_; +}; + +Validator* MakeValidator(std::function impl) { + return new Validator(std::move(impl)); +} + +static Validator* AssertSuccessAndDecrement(int* counter) { + return MakeValidator([counter](bool success) { + GPR_ASSERT(success); + --*counter; + }); +} + +static Validator* Decrement(int* counter) { + return MakeValidator([counter](bool) { --*counter; }); +} + +static Validator* ValidateConnectivityWatch(gpr_timespec deadline, + int* counter) { + return MakeValidator([deadline, counter](bool success) { + if (!success) { + GPR_ASSERT(gpr_time_cmp(gpr_now(deadline.clock_type), deadline) >= 0); + } + --*counter; + }); +} + +static void free_non_null(void* p) { + GPR_ASSERT(p != nullptr); + gpr_free(p); +} + +enum class CallType { CLIENT, SERVER, PENDING_SERVER, TOMBSTONED }; + +class Call : public std::enable_shared_from_this { + public: + explicit Call(CallType type) : type_(type) {} + ~Call(); + + CallType type() const { return type_; } + + bool done() const { + if ((type_ == CallType::TOMBSTONED || call_closed_) && pending_ops_ == 0) { + return true; + } + if (call_ == nullptr && type() != CallType::PENDING_SERVER) return true; + return false; + } + + void Shutdown() { + if (call_ != nullptr) { + grpc_call_cancel(call_, nullptr); + type_ = CallType::TOMBSTONED; + } + } + + void SetCall(grpc_call* call) { + GPR_ASSERT(call_ == nullptr); + call_ = call; + } + + grpc_call* call() const { return call_; } + + void RequestCall(grpc_server* server, grpc_completion_queue* cq) { + auto* v = FinishedRequestCall(); + grpc_call_error error = grpc_server_request_call( + server, &call_, &call_details_, &recv_initial_metadata_, cq, cq, v); + if (error != GRPC_CALL_OK) { + v->Run(false); + } + } + + void* Allocate(size_t size) { + void* p = gpr_malloc(size); + free_pointers_.push_back(p); + return p; + } + + template + T* AllocArray(size_t elems) { + return static_cast(Allocate(sizeof(T) * elems)); + } + + template + T* NewCopy(T value) { + T* p = AllocArray(1); + new (p) T(value); + return p; + } + + template + grpc_slice ReadSlice(const T& s) { + grpc_slice slice = grpc_slice_from_cpp_string(s.value()); + if (s.intern()) { + auto interned_slice = grpc_slice_intern(slice); + grpc_slice_unref(slice); + slice = interned_slice; + } + unref_slices_.push_back(slice); + return slice; + } + + template + grpc_metadata_array ReadMetadata(const M& metadata) { + grpc_metadata* m = AllocArray(metadata.size()); + for (int i = 0; i < metadata.size(); ++i) { + m[i].key = ReadSlice(metadata[i].key()); + m[i].value = ReadSlice(metadata[i].value()); + } + return grpc_metadata_array{static_cast(metadata.size()), + static_cast(metadata.size()), m}; + } + + absl::optional ReadOp( + const api_fuzzer::BatchOp& batch_op, bool* batch_is_ok, + uint8_t* batch_ops, std::vector>* unwinders) { + grpc_op op; + memset(&op, 0, sizeof(op)); + switch (batch_op.op_case()) { + 
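+      // Map the fuzzer-chosen BatchOp onto a grpc_op. Ops that cannot legally
+      // be repeated within this batch (a second send-initial-metadata, a send
+      // while a previous send_message_ is still outstanding, or a receive on
+      // an already-closed call) mark the whole batch as not ok instead of
+      // being queued.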
case api_fuzzer::BatchOp::OP_NOT_SET: + /* invalid value */ + return {}; + case api_fuzzer::BatchOp::kSendInitialMetadata: + if (sent_initial_metadata_) { + *batch_is_ok = false; + } else { + sent_initial_metadata_ = true; + op.op = GRPC_OP_SEND_INITIAL_METADATA; + *batch_ops |= 1 << GRPC_OP_SEND_INITIAL_METADATA; + auto ary = ReadMetadata(batch_op.send_initial_metadata().metadata()); + op.data.send_initial_metadata.count = ary.count; + op.data.send_initial_metadata.metadata = ary.metadata; + } + break; + case api_fuzzer::BatchOp::kSendMessage: + op.op = GRPC_OP_SEND_MESSAGE; + if (send_message_ != nullptr) { + *batch_is_ok = false; + } else { + *batch_ops |= 1 << GRPC_OP_SEND_MESSAGE; + auto send = ReadSlice(batch_op.send_message().message()); + send_message_ = op.data.send_message.send_message = + grpc_raw_byte_buffer_create(&send, 1); + unwinders->push_back([this]() { + grpc_byte_buffer_destroy(send_message_); + send_message_ = nullptr; + }); + } + break; + case api_fuzzer::BatchOp::kSendCloseFromClient: + op.op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + *batch_ops |= 1 << GRPC_OP_SEND_CLOSE_FROM_CLIENT; + break; + case api_fuzzer::BatchOp::kSendStatusFromServer: { + op.op = GRPC_OP_SEND_STATUS_FROM_SERVER; + *batch_ops |= 1 << GRPC_OP_SEND_STATUS_FROM_SERVER; + auto ary = ReadMetadata(batch_op.send_status_from_server().metadata()); + op.data.send_status_from_server.trailing_metadata_count = ary.count; + op.data.send_status_from_server.trailing_metadata = ary.metadata; + op.data.send_status_from_server.status = static_cast( + batch_op.send_status_from_server().status_code()); + op.data.send_status_from_server.status_details = + batch_op.send_status_from_server().has_status_details() + ? NewCopy(ReadSlice( + batch_op.send_status_from_server().status_details())) + : nullptr; + } break; + case api_fuzzer::BatchOp::kReceiveInitialMetadata: + op.op = GRPC_OP_RECV_INITIAL_METADATA; + *batch_ops |= 1 << GRPC_OP_RECV_INITIAL_METADATA; + op.data.recv_initial_metadata.recv_initial_metadata = + &recv_initial_metadata_; + break; + case api_fuzzer::BatchOp::kReceiveMessage: + if (call_closed_) { + *batch_is_ok = false; + } else { + op.op = GRPC_OP_RECV_MESSAGE; + *batch_ops |= 1 << GRPC_OP_RECV_MESSAGE; + op.data.recv_message.recv_message = &recv_message_; + } + break; + case api_fuzzer::BatchOp::kReceiveStatusOnClient: + op.op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op.data.recv_status_on_client.status = &status_; + op.data.recv_status_on_client.trailing_metadata = + &recv_trailing_metadata_; + op.data.recv_status_on_client.status_details = &recv_status_details_; + break; + case api_fuzzer::BatchOp::kReceiveCloseOnServer: + op.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + *batch_ops |= 1 << GRPC_OP_RECV_CLOSE_ON_SERVER; + op.data.recv_close_on_server.cancelled = &cancelled_; + break; + } + op.reserved = nullptr; + op.flags = batch_op.flags(); + return op; + } + + Validator* FinishedBatchValidator(uint8_t has_ops) { + ++pending_ops_; + auto self = shared_from_this(); + return MakeValidator([self, has_ops](bool) { + --self->pending_ops_; + if ((has_ops & (1u << GRPC_OP_RECV_MESSAGE)) && self->call_closed_) { + GPR_ASSERT(self->recv_message_ == nullptr); + } + if ((has_ops & (1u << GRPC_OP_RECV_MESSAGE) && + self->recv_message_ != nullptr)) { + grpc_byte_buffer_destroy(self->recv_message_); + self->recv_message_ = nullptr; + } + if ((has_ops & (1u << GRPC_OP_SEND_MESSAGE))) { + grpc_byte_buffer_destroy(self->send_message_); + self->send_message_ = nullptr; + } + if ((has_ops & (1u << GRPC_OP_RECV_STATUS_ON_CLIENT)) || + 
(has_ops & (1u << GRPC_OP_RECV_CLOSE_ON_SERVER))) { + self->call_closed_ = true; + } + }); + } + + Validator* FinishedRequestCall() { + ++pending_ops_; + auto self = shared_from_this(); + return MakeValidator([self](bool success) { + GPR_ASSERT(self->pending_ops_ > 0); + --self->pending_ops_; + if (success) { + GPR_ASSERT(self->call_ != nullptr); + self->type_ = CallType::SERVER; + } else { + self->type_ = CallType::TOMBSTONED; + } + }); + } + + private: + CallType type_; + grpc_call* call_ = nullptr; + grpc_byte_buffer* recv_message_ = nullptr; + grpc_status_code status_; + grpc_metadata_array recv_initial_metadata_{0, 0, nullptr}; + grpc_metadata_array recv_trailing_metadata_{0, 0, nullptr}; + grpc_slice recv_status_details_ = grpc_empty_slice(); + // set by receive close on server, unset here to trigger + // msan if misused + int cancelled_; + int pending_ops_ = 0; + bool sent_initial_metadata_ = false; + grpc_call_details call_details_{}; + grpc_byte_buffer* send_message_ = nullptr; + bool call_closed_ = false; + + std::vector free_pointers_; + std::vector unref_slices_; +}; + +static std::vector> g_calls; +static size_t g_active_call = 0; + +static Call* ActiveCall() { + while (!g_calls.empty()) { + if (g_active_call >= g_calls.size()) { + g_active_call = 0; + } + if (g_calls[g_active_call] != nullptr && !g_calls[g_active_call]->done()) { + return g_calls[g_active_call].get(); + } + g_calls.erase(g_calls.begin() + g_active_call); + } + return nullptr; +} + +Call::~Call() { + if (call_ != nullptr) { + grpc_call_unref(call_); + } + + grpc_slice_unref(recv_status_details_); + grpc_call_details_destroy(&call_details_); + + for (auto* p : free_pointers_) { + gpr_free(p); + } + for (auto s : unref_slices_) { + grpc_slice_unref(s); + } + + grpc_metadata_array_destroy(&recv_initial_metadata_); + grpc_metadata_array_destroy(&recv_trailing_metadata_); +} + +template +grpc_channel_args* ReadArgs(const ChannelArgContainer& args) { + grpc_channel_args* res = + static_cast(gpr_malloc(sizeof(grpc_channel_args))); + res->num_args = args.size(); + res->args = + static_cast(gpr_malloc(sizeof(grpc_arg) * args.size())); + for (int i = 0; i < args.size(); i++) { + res->args[i].key = gpr_strdup(args[i].key().c_str()); + switch (args[i].value_case()) { + case api_fuzzer::ChannelArg::kStr: + res->args[i].type = GRPC_ARG_STRING; + res->args[i].value.string = gpr_strdup(args[i].str().c_str()); + break; + case api_fuzzer::ChannelArg::kI: + res->args[i].type = GRPC_ARG_INTEGER; + res->args[i].value.integer = args[i].i(); + break; + case api_fuzzer::ChannelArg::kResourceQuota: + grpc_resource_quota_ref(g_resource_quota); + res->args[i].type = GRPC_ARG_POINTER; + res->args[i].value.pointer.p = g_resource_quota; + res->args[i].value.pointer.vtable = grpc_resource_quota_arg_vtable(); + break; + case api_fuzzer::ChannelArg::VALUE_NOT_SET: + res->args[i].type = GRPC_ARG_INTEGER; + res->args[i].value.integer = 0; + break; + } + } + return res; +} + +static const char* ReadCredArtifact( + const api_fuzzer::CredArtifact& artifact, + std::initializer_list builtins) { + switch (artifact.type_case()) { + case api_fuzzer::CredArtifact::kCustom: + return artifact.custom().c_str(); + case api_fuzzer::CredArtifact::kBuiltin: + if (artifact.builtin() < 0) return nullptr; + if (artifact.builtin() < static_cast(builtins.size())) { + return *(builtins.begin() + artifact.builtin()); + } + return nullptr; + case api_fuzzer::CredArtifact::TYPE_NOT_SET: + return nullptr; + } +} + +static grpc_channel_credentials* ReadSslChannelCreds( + 
const api_fuzzer::SslChannelCreds& creds) { + const char* root_certs = + creds.has_root_certs() + ? ReadCredArtifact(creds.root_certs(), {test_root_cert}) + : nullptr; + const char* private_key = + creds.has_private_key() + ? ReadCredArtifact(creds.private_key(), + {test_server1_key, test_self_signed_client_key, + test_signed_client_key}) + : nullptr; + const char* certs = + creds.has_certs() + ? ReadCredArtifact(creds.certs(), + {test_server1_cert, test_self_signed_client_cert, + test_signed_client_cert}) + : nullptr; + grpc_ssl_pem_key_cert_pair key_cert_pair = {private_key, certs}; + return grpc_ssl_credentials_create( + root_certs, + private_key != nullptr && certs != nullptr ? &key_cert_pair : nullptr, + nullptr, nullptr); +} + +static grpc_call_credentials* ReadCallCreds( + const api_fuzzer::CallCreds& creds) { + switch (creds.type_case()) { + case api_fuzzer::CallCreds::TYPE_NOT_SET: + return nullptr; + case api_fuzzer::CallCreds::kNull: + return nullptr; + case api_fuzzer::CallCreds::kCompositeCallCreds: { + grpc_call_credentials* out = nullptr; + for (const auto& child_creds : + creds.composite_call_creds().call_creds()) { + grpc_call_credentials* child = ReadCallCreds(child_creds); + if (child != nullptr) { + if (out == nullptr) { + out = child; + } else { + auto* composed = + grpc_composite_call_credentials_create(out, child, nullptr); + grpc_call_credentials_release(child); + grpc_call_credentials_release(out); + out = composed; + } + } + } + return out; + } + case api_fuzzer::CallCreds::kAccessToken: + return grpc_access_token_credentials_create(creds.access_token().c_str(), + nullptr); + case api_fuzzer::CallCreds::kIam: + return grpc_google_iam_credentials_create( + creds.iam().auth_token().c_str(), creds.iam().auth_selector().c_str(), + nullptr); + /* TODO(ctiller): more cred types here */ + } +} + +static grpc_channel_credentials* ReadChannelCreds( + const api_fuzzer::ChannelCreds& creds) { + switch (creds.type_case()) { + case api_fuzzer::ChannelCreds::TYPE_NOT_SET: + return nullptr; + case api_fuzzer::ChannelCreds::kSslChannelCreds: + return ReadSslChannelCreds(creds.ssl_channel_creds()); + case api_fuzzer::ChannelCreds::kCompositeChannelCreds: { + const auto& comp = creds.composite_channel_creds(); + grpc_channel_credentials* c1 = + comp.has_channel_creds() ? ReadChannelCreds(comp.channel_creds()) + : nullptr; + grpc_call_credentials* c2 = + comp.has_call_creds() ? 
ReadCallCreds(comp.call_creds()) : nullptr; + if (c1 != nullptr && c2 != nullptr) { + grpc_channel_credentials* out = + grpc_composite_channel_credentials_create(c1, c2, nullptr); + grpc_channel_credentials_release(c1); + grpc_call_credentials_release(c2); + return out; + } else if (c1 != nullptr) { + return c1; + } else if (c2 != nullptr) { + grpc_call_credentials_release(c2); + return nullptr; + } else { + return nullptr; + } + GPR_UNREACHABLE_CODE(return nullptr); + } + case api_fuzzer::ChannelCreds::kNull: + return nullptr; + } +} + +DEFINE_PROTO_FUZZER(const api_fuzzer::Msg& msg) { + grpc_test_only_set_slice_hash_seed(0); + char* grpc_trace_fuzzer = gpr_getenv("GRPC_TRACE_FUZZER"); + if (squelch && grpc_trace_fuzzer == nullptr) gpr_set_log_function(dont_log); + gpr_free(grpc_trace_fuzzer); + grpc_set_tcp_client_impl(&fuzz_tcp_client_vtable); + gpr_now_impl = now_impl; + grpc_init(); + grpc_timer_manager_set_threading(false); + { + grpc_core::ExecCtx exec_ctx; + grpc_core::Executor::SetThreadingAll(false); + } + grpc_set_resolver_impl(&fuzzer_resolver); + grpc_dns_lookup_ares_locked = my_dns_lookup_ares_locked; + grpc_cancel_ares_request_locked = my_cancel_ares_request_locked; + + GPR_ASSERT(g_channel == nullptr); + GPR_ASSERT(g_server == nullptr); + + bool server_shutdown = false; + int pending_server_shutdowns = 0; + int pending_channel_watches = 0; + int pending_pings = 0; + + g_resource_quota = grpc_resource_quota_create("api_fuzzer"); + + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + + int action_index = 0; + auto no_more_actions = [&]() { action_index = msg.actions_size(); }; + + auto poll_cq = [&]() -> bool { + grpc_event ev = grpc_completion_queue_next( + cq, gpr_inf_past(GPR_CLOCK_REALTIME), nullptr); + switch (ev.type) { + case GRPC_OP_COMPLETE: { + static_cast(ev.tag)->Run(ev.success); + break; + } + case GRPC_QUEUE_TIMEOUT: + break; + case GRPC_QUEUE_SHUTDOWN: + return true; + } + return false; + }; + + while (action_index < msg.actions_size() || g_channel != nullptr || + g_server != nullptr || pending_channel_watches > 0 || + pending_pings > 0 || ActiveCall() != nullptr) { + if (action_index == msg.actions_size()) { + if (g_channel != nullptr) { + grpc_channel_destroy(g_channel); + g_channel = nullptr; + } + if (g_server != nullptr) { + if (!server_shutdown) { + grpc_server_shutdown_and_notify( + g_server, cq, + AssertSuccessAndDecrement(&pending_server_shutdowns)); + server_shutdown = true; + pending_server_shutdowns++; + } else if (pending_server_shutdowns == 0) { + grpc_server_destroy(g_server); + g_server = nullptr; + } + } + for (auto& call : g_calls) { + if (call == nullptr) continue; + if (call->type() == CallType::PENDING_SERVER) continue; + call->Shutdown(); + } + + g_now = gpr_time_add(g_now, gpr_time_from_seconds(1, GPR_TIMESPAN)); + grpc_timer_manager_tick(); + GPR_ASSERT(!poll_cq()); + continue; + } + + grpc_timer_manager_tick(); + + const api_fuzzer::Action& action = msg.actions(action_index); + action_index++; + switch (action.type_case()) { + case api_fuzzer::Action::TYPE_NOT_SET: + no_more_actions(); + break; + // tickle completion queue + case api_fuzzer::Action::kPollCq: { + GPR_ASSERT(!poll_cq()); + break; + } + // increment global time + case api_fuzzer::Action::kAdvanceTime: { + g_now = gpr_time_add( + g_now, gpr_time_from_micros(action.advance_time(), GPR_TIMESPAN)); + break; + } + // create an insecure channel + case api_fuzzer::Action::kCreateChannel: { + if (g_channel == nullptr) { + grpc_channel_args* args = + 
ReadArgs(action.create_channel().channel_args()); + if (action.create_channel().has_channel_creds()) { + grpc_channel_credentials* creds = + ReadChannelCreds(action.create_channel().channel_creds()); + g_channel = grpc_secure_channel_create( + creds, action.create_channel().target().c_str(), args, nullptr); + grpc_channel_credentials_release(creds); + } else { + g_channel = grpc_insecure_channel_create( + action.create_channel().target().c_str(), args, nullptr); + } + GPR_ASSERT(g_channel != nullptr); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(args); + } + } else { + no_more_actions(); + } + break; + } + // destroy a channel + case api_fuzzer::Action::kCloseChannel: { + if (g_channel != nullptr) { + grpc_channel_destroy(g_channel); + g_channel = nullptr; + } else { + no_more_actions(); + } + break; + } + // bring up a server + case api_fuzzer::Action::kCreateServer: { + if (g_server == nullptr) { + grpc_channel_args* args = + ReadArgs(action.create_server().channel_args()); + g_server = grpc_server_create(args, nullptr); + GPR_ASSERT(g_server != nullptr); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(args); + } + grpc_server_register_completion_queue(g_server, cq, nullptr); + grpc_server_start(g_server); + server_shutdown = false; + GPR_ASSERT(pending_server_shutdowns == 0); + } else { + no_more_actions(); + } + break; + } + // begin server shutdown + case api_fuzzer::Action::kShutdownServer: { + if (g_server != nullptr) { + grpc_server_shutdown_and_notify( + g_server, cq, + AssertSuccessAndDecrement(&pending_server_shutdowns)); + pending_server_shutdowns++; + server_shutdown = true; + } else { + no_more_actions(); + } + break; + } + // cancel all calls if shutdown + case api_fuzzer::Action::kCancelAllCallsIfShutdown: { + if (g_server != nullptr && server_shutdown) { + grpc_server_cancel_all_calls(g_server); + } else { + no_more_actions(); + } + break; + } + // destroy server + case api_fuzzer::Action::kDestroyServerIfReady: { + if (g_server != nullptr && server_shutdown && + pending_server_shutdowns == 0) { + grpc_server_destroy(g_server); + g_server = nullptr; + } else { + no_more_actions(); + } + break; + } + // check connectivity + case api_fuzzer::Action::kCheckConnectivity: { + if (g_channel != nullptr) { + grpc_channel_check_connectivity_state(g_channel, + action.check_connectivity()); + } else { + no_more_actions(); + } + break; + } + // watch connectivity + case api_fuzzer::Action::kWatchConnectivity: { + if (g_channel != nullptr) { + grpc_connectivity_state st = + grpc_channel_check_connectivity_state(g_channel, 0); + if (st != GRPC_CHANNEL_SHUTDOWN) { + gpr_timespec deadline = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(action.watch_connectivity(), + GPR_TIMESPAN)); + grpc_channel_watch_connectivity_state( + g_channel, st, deadline, cq, + ValidateConnectivityWatch(deadline, &pending_channel_watches)); + pending_channel_watches++; + } + } else { + no_more_actions(); + } + break; + } + // create a call + case api_fuzzer::Action::kCreateCall: { + bool ok = true; + if (g_channel == nullptr) ok = false; + // If the active call is a server call, then use it as the parent call + // to exercise the propagation logic. 
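+        // (grpc_channel_create_call below propagates deadline, census context
+        // and/or cancellation from that parent according to the action's
+        // propagation_mask.)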
+ Call* parent_call = ActiveCall(); + if (parent_call != nullptr && parent_call->type() != CallType::SERVER) { + parent_call = nullptr; + } + g_calls.emplace_back(new Call(CallType::CLIENT)); + grpc_slice method = + g_calls.back()->ReadSlice(action.create_call().method()); + if (GRPC_SLICE_LENGTH(method) == 0) { + ok = false; + } + grpc_slice host = + g_calls.back()->ReadSlice(action.create_call().host()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(action.create_call().timeout(), GPR_TIMESPAN)); + + if (ok) { + g_calls.back()->SetCall(grpc_channel_create_call( + g_channel, parent_call == nullptr ? nullptr : parent_call->call(), + action.create_call().propagation_mask(), cq, method, &host, + deadline, nullptr)); + } else { + g_calls.pop_back(); + no_more_actions(); + } + break; + } + // switch the 'current' call + case api_fuzzer::Action::kChangeActiveCall: { + g_active_call++; + ActiveCall(); + break; + } + // queue some ops on a call + case api_fuzzer::Action::kQueueBatch: { + auto* active_call = ActiveCall(); + if (active_call == nullptr || + active_call->type() == CallType::PENDING_SERVER || + active_call->call() == nullptr) { + no_more_actions(); + break; + } + const auto& batch = action.queue_batch().operations(); + if (batch.size() > 6) { + no_more_actions(); + break; + } + std::vector ops; + bool ok = true; + uint8_t has_ops = 0; + std::vector> unwinders; + for (const auto& batch_op : batch) { + auto op = active_call->ReadOp(batch_op, &ok, &has_ops, &unwinders); + if (!op.has_value()) continue; + ops.push_back(*op); + } + if (g_channel == nullptr) ok = false; + if (ok) { + auto* v = active_call->FinishedBatchValidator(has_ops); + grpc_call_error error = grpc_call_start_batch( + active_call->call(), ops.data(), ops.size(), v, nullptr); + if (error != GRPC_CALL_OK) { + v->Run(false); + } + } else { + no_more_actions(); + for (auto& unwind : unwinders) { + unwind(); + } + } + break; + } + // cancel current call + case api_fuzzer::Action::kCancelCall: { + auto* active_call = ActiveCall(); + if (active_call != nullptr && active_call->call() != nullptr) { + grpc_call_cancel(active_call->call(), nullptr); + } else { + no_more_actions(); + } + break; + } + // get a calls peer + case api_fuzzer::Action::kGetPeer: { + auto* active_call = ActiveCall(); + if (active_call != nullptr && active_call->call() != nullptr) { + free_non_null(grpc_call_get_peer(active_call->call())); + } else { + no_more_actions(); + } + break; + } + // get a channels target + case api_fuzzer::Action::kGetTarget: { + if (g_channel != nullptr) { + free_non_null(grpc_channel_get_target(g_channel)); + } else { + no_more_actions(); + } + break; + } + // send a ping on a channel + case api_fuzzer::Action::kPing: { + if (g_channel != nullptr) { + pending_pings++; + grpc_channel_ping(g_channel, cq, Decrement(&pending_pings), nullptr); + } else { + no_more_actions(); + } + break; + } + // enable a tracer + case api_fuzzer::Action::kEnableTracer: { + grpc_tracer_set_enabled(action.enable_tracer().c_str(), 1); + break; + } + // disable a tracer + case api_fuzzer::Action::kDisableTracer: { + grpc_tracer_set_enabled(action.disable_tracer().c_str(), 0); + break; + } + // request a server call + case api_fuzzer::Action::kRequestCall: { + if (g_server == nullptr) { + no_more_actions(); + break; + } + g_calls.emplace_back(new Call(CallType::PENDING_SERVER)); + g_calls.back()->RequestCall(g_server, cq); + break; + } + // destroy a call + case api_fuzzer::Action::kDestroyCall: { + auto* 
active_call = ActiveCall(); + if (active_call != nullptr && + active_call->type() != CallType::PENDING_SERVER && + active_call->call() != nullptr) { + g_calls[g_active_call]->Shutdown(); + } else { + no_more_actions(); + } + break; + } + // resize the buffer pool + case api_fuzzer::Action::kResizeResourceQuota: { + grpc_resource_quota_resize(g_resource_quota, + action.resize_resource_quota()); + break; + } + } + } + + GPR_ASSERT(g_channel == nullptr); + GPR_ASSERT(g_server == nullptr); + GPR_ASSERT(ActiveCall() == nullptr); + GPR_ASSERT(g_calls.empty()); + + grpc_completion_queue_shutdown(cq); + GPR_ASSERT(poll_cq()); + grpc_completion_queue_destroy(cq); + + grpc_resource_quota_unref(g_resource_quota); + + grpc_shutdown_blocking(); +} diff --git a/test/core/end2end/fuzzers/client_fuzzer.cc b/test/core/end2end/fuzzers/client_fuzzer.cc new file mode 100644 index 00000000..bfc1c182 --- /dev/null +++ b/test/core/end2end/fuzzers/client_fuzzer.cc @@ -0,0 +1,162 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/util/mock_endpoint.h" + +bool squelch = true; +bool leak_check = true; + +static void discard_write(grpc_slice /*slice*/) {} + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void dont_log(gpr_log_func_args* /*args*/) {} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_test_only_set_slice_hash_seed(0); + if (squelch) gpr_set_log_function(dont_log); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + grpc_core::Executor::SetThreadingAll(false); + + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("context_list_test"); + grpc_endpoint* mock_endpoint = grpc_mock_endpoint_create( + discard_write, + grpc_slice_allocator_create(resource_quota, "mock_endpoint")); + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_transport* transport = grpc_create_chttp2_transport( + nullptr, mock_endpoint, true, + grpc_resource_user_create(resource_quota, "mock_transport")); + grpc_resource_quota_unref(resource_quota); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + grpc_arg authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test-authority")); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(nullptr, &authority_arg, 1); + grpc_channel* channel = + grpc_channel_create("test-target", args, GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, nullptr); + grpc_channel_args_destroy(args); + grpc_slice host = grpc_slice_from_static_string("localhost"); + grpc_call* call = grpc_channel_create_call( + channel, nullptr, 0, cq, grpc_slice_from_static_string("/foo"), &host, + 
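+        // an unreachable deadline: the call completes from the fuzzed wire
+        // bytes, or is cancelled explicitly below once the queue goes idle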
gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array_init(&initial_metadata_recv); + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_status_code status; + grpc_slice details = grpc_empty_slice(); + + grpc_op ops[6]; + memset(ops, 0, sizeof(ops)); + grpc_op* op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + grpc_call_error error = + grpc_call_start_batch(call, ops, (size_t)(op - ops), tag(1), nullptr); + int requested_calls = 1; + GPR_ASSERT(GRPC_CALL_OK == error); + + grpc_mock_endpoint_put_read( + mock_endpoint, grpc_slice_from_copied_buffer((const char*)data, size)); + + grpc_event ev; + while (true) { + grpc_core::ExecCtx::Get()->Flush(); + ev = grpc_completion_queue_next(cq, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + switch (ev.type) { + case GRPC_QUEUE_TIMEOUT: + goto done; + case GRPC_QUEUE_SHUTDOWN: + break; + case GRPC_OP_COMPLETE: + requested_calls--; + break; + } + } + + done: + if (requested_calls) { + grpc_call_cancel(call, nullptr); + } + for (int i = 0; i < requested_calls; i++) { + ev = grpc_completion_queue_next(cq, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + } + grpc_completion_queue_shutdown(cq); + for (int i = 0; i < requested_calls; i++) { + ev = grpc_completion_queue_next(cq, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_SHUTDOWN); + } + grpc_call_unref(call); + grpc_completion_queue_destroy(cq); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_slice_unref(details); + grpc_channel_destroy(channel); + if (response_payload_recv != nullptr) { + grpc_byte_buffer_destroy(response_payload_recv); + } + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/fuzzers/server_fuzzer.cc b/test/core/end2end/fuzzers/server_fuzzer.cc new file mode 100644 index 00000000..418676a5 --- /dev/null +++ b/test/core/end2end/fuzzers/server_fuzzer.cc @@ -0,0 +1,136 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/server.h" +#include "test/core/util/mock_endpoint.h" + +bool squelch = true; +bool leak_check = true; + +static void discard_write(grpc_slice /*slice*/) {} + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void dont_log(gpr_log_func_args* /*args*/) {} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_test_only_set_slice_hash_seed(0); + if (squelch) gpr_set_log_function(dont_log); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + grpc_core::Executor::SetThreadingAll(false); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("context_list_test"); + grpc_endpoint* mock_endpoint = grpc_mock_endpoint_create( + discard_write, + grpc_slice_allocator_create(resource_quota, "mock_endpoint")); + grpc_mock_endpoint_put_read( + mock_endpoint, grpc_slice_from_copied_buffer((const char*)data, size)); + grpc_server* server = grpc_server_create(nullptr, nullptr); + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(server, cq, nullptr); + // TODO(ctiller): add more registered methods (one for POST, one for PUT) + grpc_server_register_method(server, "/reg", nullptr, {}, 0); + grpc_server_start(server); + grpc_transport* transport = grpc_create_chttp2_transport( + nullptr, mock_endpoint, false, + grpc_resource_user_create(resource_quota, "mock_transport")); + grpc_resource_quota_unref(resource_quota); + GPR_ASSERT(GRPC_LOG_IF_ERROR("SetupTransport", + server->core_server->SetupTransport( + transport, nullptr, nullptr, nullptr))); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + + grpc_call* call1 = nullptr; + grpc_call_details call_details1; + grpc_metadata_array request_metadata1; + grpc_call_details_init(&call_details1); + grpc_metadata_array_init(&request_metadata1); + int requested_calls = 0; + + GPR_ASSERT(GRPC_CALL_OK == + grpc_server_request_call(server, &call1, &call_details1, + &request_metadata1, cq, cq, tag(1))); + requested_calls++; + + grpc_event ev; + while (true) { + grpc_core::ExecCtx::Get()->Flush(); + ev = grpc_completion_queue_next(cq, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + switch (ev.type) { + case GRPC_QUEUE_TIMEOUT: + goto done; + case GRPC_QUEUE_SHUTDOWN: + break; + case GRPC_OP_COMPLETE: + if (ev.tag == tag(1)) { + requested_calls--; + // TODO(ctiller): keep reading that call! + } + break; + } + } + + done: + if (call1 != nullptr) grpc_call_unref(call1); + grpc_call_details_destroy(&call_details1); + grpc_metadata_array_destroy(&request_metadata1); + grpc_server_shutdown_and_notify(server, cq, tag(0xdead)); + grpc_server_cancel_all_calls(server); + grpc_millis deadline = grpc_core::ExecCtx::Get()->Now() + 5000; + for (int i = 0; i <= requested_calls; i++) { + // A single grpc_completion_queue_next might not be sufficient for getting + // the tag from shutdown, because we might potentially get blocked by + // an operation happening on the timer thread. + // For example, the deadline timer might expire, leading to the timer + // thread trying to cancel the RPC and thereby acquiring a few references + // to the call. This will prevent the shutdown to complete till the timer + // thread releases those references. 
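+      // (Each grpc_completion_queue_next below uses a deadline in the past,
+      // so it is a non-blocking poll; InvalidateNow refreshes the cached
+      // clock between polls.)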
+ // As a solution, we are going to keep performing a cq_next for a + // liberal period of 5 seconds for the timer thread to complete its work. + do { + ev = grpc_completion_queue_next(cq, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + grpc_core::ExecCtx::Get()->InvalidateNow(); + } while (ev.type != GRPC_OP_COMPLETE && + grpc_core::ExecCtx::Get()->Now() < deadline); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + } + grpc_completion_queue_shutdown(cq); + for (int i = 0; i <= requested_calls; i++) { + do { + ev = grpc_completion_queue_next(cq, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + grpc_core::ExecCtx::Get()->InvalidateNow(); + } while (ev.type != GRPC_QUEUE_SHUTDOWN && + grpc_core::ExecCtx::Get()->Now() < deadline); + GPR_ASSERT(ev.type == GRPC_QUEUE_SHUTDOWN); + } + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq); + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/goaway_server_test.cc b/test/core/end2end/goaway_server_test.cc new file mode 100644 index 00000000..4d4509bc --- /dev/null +++ b/test/core/end2end/goaway_server_test.cc @@ -0,0 +1,388 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +extern grpc_address_resolver_vtable* grpc_resolve_address_impl; +static grpc_address_resolver_vtable* default_resolver; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static gpr_mu g_mu; +static int g_resolve_port = -1; + +static grpc_ares_request* (*iomgr_dns_lookup_ares_locked)( + const char* dns_server, const char* addr, const char* default_port, + grpc_pollset_set* interested_parties, grpc_closure* on_done, + std::unique_ptr* addresses, + std::unique_ptr* balancer_addresses, + char** service_config_json, int query_timeout_ms, + std::shared_ptr combiner); + +static void (*iomgr_cancel_ares_request_locked)(grpc_ares_request* request); + +static void set_resolve_port(int port) { + gpr_mu_lock(&g_mu); + g_resolve_port = port; + gpr_mu_unlock(&g_mu); +} + +static void my_resolve_address(const char* addr, const char* default_port, + grpc_pollset_set* interested_parties, + grpc_closure* on_done, + grpc_resolved_addresses** addrs) { + if (0 != strcmp(addr, "test")) { + default_resolver->resolve_address(addr, default_port, interested_parties, + on_done, addrs); + return; + } + + grpc_error_handle error = GRPC_ERROR_NONE; + gpr_mu_lock(&g_mu); + if (g_resolve_port < 0) { + gpr_mu_unlock(&g_mu); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Forced Failure"); + } else { + *addrs = 
static_cast(gpr_malloc(sizeof(**addrs))); + (*addrs)->naddrs = 1; + (*addrs)->addrs = static_cast( + gpr_malloc(sizeof(*(*addrs)->addrs))); + memset((*addrs)->addrs, 0, sizeof(*(*addrs)->addrs)); + grpc_sockaddr_in* sa = + reinterpret_cast((*addrs)->addrs[0].addr); + sa->sin_family = GRPC_AF_INET; + sa->sin_addr.s_addr = 0x100007f; + sa->sin_port = grpc_htons(static_cast(g_resolve_port)); + (*addrs)->addrs[0].len = static_cast(sizeof(*sa)); + gpr_mu_unlock(&g_mu); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, error); +} + +static grpc_error_handle my_blocking_resolve_address( + const char* name, const char* default_port, + grpc_resolved_addresses** addresses) { + return default_resolver->blocking_resolve_address(name, default_port, + addresses); +} + +static grpc_address_resolver_vtable test_resolver = { + my_resolve_address, my_blocking_resolve_address}; + +static grpc_ares_request* my_dns_lookup_ares_locked( + const char* dns_server, const char* addr, const char* default_port, + grpc_pollset_set* interested_parties, grpc_closure* on_done, + std::unique_ptr* addresses, + std::unique_ptr* balancer_addresses, + char** service_config_json, int query_timeout_ms, + std::shared_ptr work_serializer) { + if (0 != strcmp(addr, "test")) { + return iomgr_dns_lookup_ares_locked( + dns_server, addr, default_port, interested_parties, on_done, addresses, + balancer_addresses, service_config_json, query_timeout_ms, + std::move(work_serializer)); + } + + grpc_error_handle error = GRPC_ERROR_NONE; + gpr_mu_lock(&g_mu); + if (g_resolve_port < 0) { + gpr_mu_unlock(&g_mu); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Forced Failure"); + } else { + *addresses = absl::make_unique(); + grpc_sockaddr_in sa; + memset(&sa, 0, sizeof(sa)); + sa.sin_family = GRPC_AF_INET; + sa.sin_addr.s_addr = 0x100007f; + sa.sin_port = grpc_htons(static_cast(g_resolve_port)); + (*addresses)->emplace_back(&sa, sizeof(sa), nullptr); + gpr_mu_unlock(&g_mu); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, error); + return nullptr; +} + +static void my_cancel_ares_request_locked(grpc_ares_request* request) { + if (request != nullptr) { + iomgr_cancel_ares_request_locked(request); + } +} + +int main(int argc, char** argv) { + grpc_completion_queue* cq; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + + grpc::testing::TestEnvironment env(argc, argv); + + gpr_mu_init(&g_mu); + grpc_init(); + default_resolver = grpc_resolve_address_impl; + grpc_set_resolver_impl(&test_resolver); + iomgr_dns_lookup_ares_locked = grpc_dns_lookup_ares_locked; + iomgr_cancel_ares_request_locked = grpc_cancel_ares_request_locked; + grpc_dns_lookup_ares_locked = my_dns_lookup_ares_locked; + grpc_cancel_ares_request_locked = my_cancel_ares_request_locked; + + int was_cancelled1; + int was_cancelled2; + + grpc_metadata_array trailing_metadata_recv1; + grpc_metadata_array request_metadata1; + grpc_call_details request_details1; + grpc_status_code status1; + grpc_slice details1; + grpc_metadata_array_init(&trailing_metadata_recv1); + grpc_metadata_array_init(&request_metadata1); + grpc_call_details_init(&request_details1); + + grpc_metadata_array trailing_metadata_recv2; + grpc_metadata_array request_metadata2; + grpc_call_details request_details2; + grpc_status_code status2; + grpc_slice details2; + grpc_metadata_array_init(&trailing_metadata_recv2); + grpc_metadata_array_init(&request_metadata2); + grpc_call_details_init(&request_details2); + + cq = grpc_completion_queue_create_for_next(nullptr); + cqv = cq_verifier_create(cq); + + /* 
reserve two ports */ + int port1 = grpc_pick_unused_port_or_die(); + int port2 = grpc_pick_unused_port_or_die(); + + std::string addr; + + grpc_channel_args client_args; + grpc_arg arg_array[2]; + arg_array[0].type = GRPC_ARG_INTEGER; + arg_array[0].key = + const_cast("grpc.testing.fixed_reconnect_backoff_ms"); + arg_array[0].value.integer = 1000; + /* When this test brings down server1 and then brings up server2, + * the targetted server port number changes, and the client channel + * needs to re-resolve to pick this up. This test requires that + * happen within 10 seconds, but gRPC's DNS resolvers rate limit + * resolution attempts to at most once every 30 seconds by default. + * So we tweak it for this test. */ + arg_array[1].type = GRPC_ARG_INTEGER; + arg_array[1].key = + const_cast(GRPC_ARG_DNS_MIN_TIME_BETWEEN_RESOLUTIONS_MS); + arg_array[1].value.integer = 1000; + client_args.args = arg_array; + client_args.num_args = 2; + + /* create a channel that picks first amongst the servers */ + grpc_channel* chan = + grpc_insecure_channel_create("test", &client_args, nullptr); + /* and an initial call to them */ + grpc_slice host = grpc_slice_from_static_string("127.0.0.1"); + grpc_call* call1 = + grpc_channel_create_call(chan, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), &host, + grpc_timeout_seconds_to_deadline(20), nullptr); + /* send initial metadata to probe connectivity */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call1, ops, + (size_t)(op - ops), + tag(0x101), nullptr)); + /* and receive status to probe termination */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv1; + op->data.recv_status_on_client.status = &status1; + op->data.recv_status_on_client.status_details = &details1; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call1, ops, + (size_t)(op - ops), + tag(0x102), nullptr)); + + /* bring a server up on the first port */ + grpc_server* server1 = grpc_server_create(nullptr, nullptr); + addr = absl::StrCat("127.0.0.1:", port1); + grpc_server_add_insecure_http2_port(server1, addr.c_str()); + grpc_server_register_completion_queue(server1, cq, nullptr); + grpc_server_start(server1); + + /* request a call to the server */ + grpc_call* server_call1; + GPR_ASSERT(GRPC_CALL_OK == + grpc_server_request_call(server1, &server_call1, &request_details1, + &request_metadata1, cq, cq, tag(0x301))); + + set_resolve_port(port1); + + /* first call should now start */ + CQ_EXPECT_COMPLETION(cqv, tag(0x101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0x301), 1); + cq_verify(cqv); + + GPR_ASSERT(GRPC_CHANNEL_READY == + grpc_channel_check_connectivity_state(chan, 0)); + grpc_channel_watch_connectivity_state(chan, GRPC_CHANNEL_READY, + gpr_inf_future(GPR_CLOCK_REALTIME), cq, + tag(0x9999)); + + /* listen for close on the server call to probe for finishing */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled1; + op->flags = 0; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(server_call1, ops, + (size_t)(op - ops), + tag(0x302), nullptr)); + + /* shutdown first server: + * we should see a 
connectivity change and then nothing */ + set_resolve_port(-1); + grpc_server_shutdown_and_notify(server1, cq, tag(0xdead1)); + CQ_EXPECT_COMPLETION(cqv, tag(0x9999), 1); + cq_verify(cqv); + cq_verify_empty(cqv); + + /* and a new call: should go through to server2 when we start it */ + grpc_call* call2 = + grpc_channel_create_call(chan, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), &host, + grpc_timeout_seconds_to_deadline(20), nullptr); + /* send initial metadata to probe connectivity */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call2, ops, + (size_t)(op - ops), + tag(0x201), nullptr)); + /* and receive status to probe termination */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv2; + op->data.recv_status_on_client.status = &status2; + op->data.recv_status_on_client.status_details = &details2; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call2, ops, + (size_t)(op - ops), + tag(0x202), nullptr)); + + /* and bring up second server */ + set_resolve_port(port2); + grpc_server* server2 = grpc_server_create(nullptr, nullptr); + addr = absl::StrCat("127.0.0.1:", port2); + grpc_server_add_insecure_http2_port(server2, addr.c_str()); + grpc_server_register_completion_queue(server2, cq, nullptr); + grpc_server_start(server2); + + /* request a call to the server */ + grpc_call* server_call2; + GPR_ASSERT(GRPC_CALL_OK == + grpc_server_request_call(server2, &server_call2, &request_details2, + &request_metadata2, cq, cq, tag(0x401))); + + /* second call should now start */ + CQ_EXPECT_COMPLETION(cqv, tag(0x201), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0x401), 1); + cq_verify(cqv); + + /* listen for close on the server call to probe for finishing */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled2; + op->flags = 0; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(server_call2, ops, + (size_t)(op - ops), + tag(0x402), nullptr)); + + /* shutdown second server: we should see nothing */ + grpc_server_shutdown_and_notify(server2, cq, tag(0xdead2)); + cq_verify_empty(cqv); + + grpc_call_cancel(call1, nullptr); + grpc_call_cancel(call2, nullptr); + + /* now everything else should finish */ + CQ_EXPECT_COMPLETION(cqv, tag(0x102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0x202), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0x302), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0x402), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead1), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead2), 1); + cq_verify(cqv); + + grpc_call_unref(call1); + grpc_call_unref(call2); + grpc_call_unref(server_call1); + grpc_call_unref(server_call2); + grpc_server_destroy(server1); + grpc_server_destroy(server2); + grpc_channel_destroy(chan); + + grpc_metadata_array_destroy(&trailing_metadata_recv1); + grpc_metadata_array_destroy(&request_metadata1); + grpc_call_details_destroy(&request_details1); + grpc_slice_unref(details1); + grpc_metadata_array_destroy(&trailing_metadata_recv2); + grpc_metadata_array_destroy(&request_metadata2); + grpc_call_details_destroy(&request_details2); + grpc_slice_unref(details2); + + cq_verifier_destroy(cqv); + 
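+  /* the verifier only borrows the completion queue; destroying it does not
+     shut the queue itself down */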
grpc_completion_queue_destroy(cq); + + grpc_shutdown(); + gpr_mu_destroy(&g_mu); + + return 0; +} diff --git a/test/core/end2end/h2_ssl_cert_test.cc b/test/core/end2end/h2_ssl_cert_test.cc new file mode 100644 index 00000000..77026721 --- /dev/null +++ b/test/core/end2end/h2_ssl_cert_test.cc @@ -0,0 +1,403 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/data/ssl_test_data.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +extern "C" { +#include +} + +static std::string test_server1_key_id; + +namespace grpc { +namespace testing { + +struct fullstack_secure_fixture_data { + std::string localaddr; +}; + +static grpc_end2end_test_fixture chttp2_create_fixture_secure_fullstack( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + int port = grpc_pick_unused_port_or_die(); + fullstack_secure_fixture_data* ffd = new fullstack_secure_fixture_data(); + memset(&f, 0, sizeof(f)); + + ffd->localaddr = grpc_core::JoinHostPort("localhost", port); + + f.fixture_data = ffd; + f.cq = grpc_completion_queue_create_for_next(nullptr); + return f; +} + +static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/, + const grpc_metadata* /*md*/, + size_t /*md_count*/, + grpc_process_auth_metadata_done_cb cb, + void* user_data) { + GPR_ASSERT(state == nullptr); + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr); +} + +static void chttp2_init_client_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* client_args, + grpc_channel_credentials* creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + f->client = grpc_secure_channel_create(creds, ffd->localaddr.c_str(), + client_args, nullptr); + GPR_ASSERT(f->client != nullptr); + grpc_channel_credentials_release(creds); +} + +static void chttp2_init_server_secure_fullstack( + grpc_end2end_test_fixture* f, grpc_channel_args* server_args, + grpc_server_credentials* server_creds) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port( + f->server, ffd->localaddr.c_str(), server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(f->server); +} + +void 
chttp2_tear_down_secure_fullstack(grpc_end2end_test_fixture* f) { + fullstack_secure_fixture_data* ffd = + static_cast(f->fixture_data); + delete ffd; +} + +static int fail_server_auth_check(grpc_channel_args* server_args) { + size_t i; + if (server_args == nullptr) return 0; + for (i = 0; i < server_args->num_args; i++) { + if (strcmp(server_args->args[i].key, FAIL_AUTH_CHECK_SERVER_ARG_NAME) == + 0) { + return 1; + } + } + return 0; +} + +#define SERVER_INIT_NAME(REQUEST_TYPE) \ + chttp2_init_server_simple_ssl_secure_fullstack_##REQUEST_TYPE + +#define SERVER_INIT(REQUEST_TYPE) \ + static void SERVER_INIT_NAME(REQUEST_TYPE)( \ + grpc_end2end_test_fixture * f, grpc_channel_args * server_args) { \ + grpc_ssl_pem_key_cert_pair pem_cert_key_pair; \ + if (!test_server1_key_id.empty()) { \ + pem_cert_key_pair.private_key = test_server1_key_id.c_str(); \ + pem_cert_key_pair.cert_chain = test_server1_cert; \ + } else { \ + pem_cert_key_pair.private_key = test_server1_key; \ + pem_cert_key_pair.cert_chain = test_server1_cert; \ + } \ + grpc_server_credentials* ssl_creds = \ + grpc_ssl_server_credentials_create_ex( \ + test_root_cert, &pem_cert_key_pair, 1, REQUEST_TYPE, NULL); \ + if (fail_server_auth_check(server_args)) { \ + grpc_auth_metadata_processor processor = {process_auth_failure, NULL, \ + NULL}; \ + grpc_server_credentials_set_auth_metadata_processor(ssl_creds, \ + processor); \ + } \ + chttp2_init_server_secure_fullstack(f, server_args, ssl_creds); \ + } + +SERVER_INIT(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE) +SERVER_INIT(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY) +SERVER_INIT(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY) +SERVER_INIT(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY) +SERVER_INIT(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY) + +#define CLIENT_INIT_NAME(cert_type) \ + chttp2_init_client_simple_ssl_secure_fullstack_##cert_type + +typedef enum { NONE, SELF_SIGNED, SIGNED, BAD_CERT_PAIR } certtype; + +#define CLIENT_INIT(cert_type) \ + static void CLIENT_INIT_NAME(cert_type)(grpc_end2end_test_fixture * f, \ + grpc_channel_args * client_args) { \ + grpc_channel_credentials* ssl_creds = NULL; \ + grpc_ssl_pem_key_cert_pair self_signed_client_key_cert_pair = { \ + test_self_signed_client_key, test_self_signed_client_cert}; \ + grpc_ssl_pem_key_cert_pair signed_client_key_cert_pair = { \ + test_signed_client_key, test_signed_client_cert}; \ + grpc_ssl_pem_key_cert_pair bad_client_key_cert_pair = { \ + test_self_signed_client_key, test_signed_client_cert}; \ + grpc_ssl_pem_key_cert_pair* key_cert_pair = NULL; \ + switch (cert_type) { \ + case SELF_SIGNED: \ + key_cert_pair = &self_signed_client_key_cert_pair; \ + break; \ + case SIGNED: \ + key_cert_pair = &signed_client_key_cert_pair; \ + break; \ + case BAD_CERT_PAIR: \ + key_cert_pair = &bad_client_key_cert_pair; \ + break; \ + default: \ + break; \ + } \ + ssl_creds = grpc_ssl_credentials_create(test_root_cert, key_cert_pair, \ + NULL, NULL); \ + grpc_arg ssl_name_override = { \ + GRPC_ARG_STRING, \ + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), \ + {const_cast("foo.test.google.fr")}}; \ + grpc_channel_args* new_client_args = \ + grpc_channel_args_copy_and_add(client_args, &ssl_name_override, 1); \ + chttp2_init_client_secure_fullstack(f, new_client_args, ssl_creds); \ + { \ + grpc_core::ExecCtx exec_ctx; \ + grpc_channel_args_destroy(new_client_args); \ + } \ + } + +CLIENT_INIT(NONE) +CLIENT_INIT(SELF_SIGNED) +CLIENT_INIT(SIGNED) +CLIENT_INIT(BAD_CERT_PAIR) + +#define 
TEST_NAME(enum_name, cert_type, result) \ + "chttp2/ssl_" #enum_name "_" #cert_type "_" #result "_" + +typedef enum { SUCCESS, FAIL } test_result; + +#define SSL_TEST(request_type, cert_type, result) \ + { \ + {TEST_NAME(request_type, cert_type, result), \ + FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION | \ + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS | \ + FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL, \ + "foo.test.google.fr", \ + chttp2_create_fixture_secure_fullstack, \ + CLIENT_INIT_NAME(cert_type), \ + SERVER_INIT_NAME(request_type), \ + chttp2_tear_down_secure_fullstack}, \ + result \ + } + +/* All test configurations */ +typedef struct grpc_end2end_test_config_wrapper { + grpc_end2end_test_config config; + test_result result; +} grpc_end2end_test_config_wrapper; + +static grpc_end2end_test_config_wrapper configs[] = { + SSL_TEST(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, NONE, SUCCESS), + SSL_TEST(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, SELF_SIGNED, SUCCESS), + SSL_TEST(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, SIGNED, SUCCESS), + SSL_TEST(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, BAD_CERT_PAIR, FAIL), + + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, NONE, + SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, SELF_SIGNED, + SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, SIGNED, + SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, BAD_CERT_PAIR, + FAIL), + + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY, NONE, SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY, SELF_SIGNED, FAIL), + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY, SIGNED, SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY, BAD_CERT_PAIR, + FAIL), + + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, + NONE, FAIL), + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, + SELF_SIGNED, SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, + SIGNED, SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY, + BAD_CERT_PAIR, FAIL), + + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY, NONE, + FAIL), + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY, + SELF_SIGNED, FAIL), + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY, SIGNED, + SUCCESS), + SSL_TEST(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY, + BAD_CERT_PAIR, FAIL), +}; + +static void* tag(intptr_t t) { return (void*)t; } + +static gpr_timespec n_seconds_time(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_time(void) { return n_seconds_time(5); } + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_time(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +// Shuts down the server. +// Side effect - Also shuts down and drains the completion queue. 
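+// Draining guarantees the shutdown-and-notify tag has been consumed before
+// grpc_server_destroy is called.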
+static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->cq, tag(1000)); + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_client(f); + shutdown_server(f); + grpc_completion_queue_destroy(f->cq); +} + +static void simple_request_body(grpc_end2end_test_fixture f, + test_result expected_result) { + grpc_call* c; + gpr_timespec deadline = five_seconds_time(); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_call_error error; + + grpc_slice host = grpc_slice_from_static_string("foo.test.google.fr:1234"); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), &host, + deadline, nullptr); + GPR_ASSERT(c); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), expected_result == SUCCESS); + cq_verify(cqv); + + grpc_call_unref(c); + cq_verifier_destroy(cqv); +} + +class H2SslCertTest + : public ::testing::TestWithParam { + protected: + H2SslCertTest() { + gpr_log(GPR_INFO, "SSL_CERT_tests/%s", GetParam().config.name); + } + void SetUp() override { + fixture_ = GetParam().config.create_fixture(nullptr, nullptr); + GetParam().config.init_server(&fixture_, nullptr); + GetParam().config.init_client(&fixture_, nullptr); + } + void TearDown() override { + end_test(&fixture_); + GetParam().config.tear_down_data(&fixture_); + } + + grpc_end2end_test_fixture fixture_; +}; + +TEST_P(H2SslCertTest, SimpleRequestBody) { + simple_request_body(fixture_, GetParam().result); +} + +#ifndef OPENSSL_IS_BORINGSSL +#if GPR_LINUX +TEST_P(H2SslCertTest, SimpleRequestBodyUseEngine) { + test_server1_key_id.clear(); + test_server1_key_id.append("engine:libengine_passthrough:"); + test_server1_key_id.append(test_server1_key); + simple_request_body(fixture_, GetParam().result); +} +#endif +#endif + +INSTANTIATE_TEST_SUITE_P(H2SslCert, H2SslCertTest, + ::testing::ValuesIn(configs)); + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + FILE* roots_file; + size_t roots_size = strlen(test_root_cert); + char* roots_filename; + + grpc::testing::TestEnvironment env(argc, argv); + /* Set the SSL roots env var. */ + roots_file = + gpr_tmpfile("chttp2_simple_ssl_cert_fullstack_test", &roots_filename); + GPR_ASSERT(roots_filename != nullptr); + GPR_ASSERT(roots_file != nullptr); + GPR_ASSERT(fwrite(test_root_cert, 1, roots_size, roots_file) == roots_size); + fclose(roots_file); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, roots_filename); + + grpc_init(); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + + /* Cleanup. 
*/ + remove(roots_filename); + gpr_free(roots_filename); + + return ret; +} diff --git a/test/core/end2end/h2_ssl_session_reuse_test.cc b/test/core/end2end/h2_ssl_session_reuse_test.cc new file mode 100644 index 00000000..6b88a1d8 --- /dev/null +++ b/test/core/end2end/h2_ssl_session_reuse_test.cc @@ -0,0 +1,300 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define CLIENT_CERT_PATH "src/core/tsi/test_creds/client.pem" +#define CLIENT_KEY_PATH "src/core/tsi/test_creds/client.key" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +namespace grpc { +namespace testing { +namespace { + +void* tag(intptr_t t) { return reinterpret_cast(t); } + +gpr_timespec five_seconds_time() { return grpc_timeout_seconds_to_deadline(5); } + +grpc_server* server_create(grpc_completion_queue* cq, const char* server_addr) { + grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_cert_key_pair = {server_key, server_cert}; + grpc_server_credentials* server_creds = grpc_ssl_server_credentials_create_ex( + ca_cert, &pem_cert_key_pair, 1, + GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY, nullptr); + + grpc_server* server = grpc_server_create(nullptr, nullptr); + grpc_server_register_completion_queue(server, cq, nullptr); + GPR_ASSERT( + grpc_server_add_secure_http2_port(server, server_addr, server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(server); + + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); + return server; +} + +grpc_channel* client_create(const char* server_addr, + grpc_ssl_session_cache* cache) { + grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + 
GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(CLIENT_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CLIENT_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + const char* client_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* client_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair signed_client_key_cert_pair = {client_key, + client_cert}; + grpc_channel_credentials* client_creds = grpc_ssl_credentials_create( + ca_cert, &signed_client_key_cert_pair, nullptr, nullptr); + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + const_cast("waterzooi.test.google.be")), + grpc_ssl_session_cache_create_channel_arg(cache), + }; + + grpc_channel_args* client_args = + grpc_channel_args_copy_and_add(nullptr, args, GPR_ARRAY_SIZE(args)); + + grpc_channel* client = grpc_secure_channel_create(client_creds, server_addr, + client_args, nullptr); + GPR_ASSERT(client != nullptr); + grpc_channel_credentials_release(client_creds); + + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(client_args); + } + + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); + return client; +} + +void do_round_trip(grpc_completion_queue* cq, grpc_server* server, + const char* server_addr, grpc_ssl_session_cache* cache, + bool expect_session_reuse) { + grpc_channel* client = client_create(server_addr, cache); + + cq_verifier* cqv = cq_verifier_create(cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(60); + grpc_call* c = grpc_channel_create_call( + client, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + grpc_call* s; + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + grpc_auth_context* auth = grpc_call_auth_context(s); + grpc_auth_property_iterator it = 
grpc_auth_context_find_properties_by_name( + auth, GRPC_SSL_SESSION_REUSED_PROPERTY); + const grpc_auth_property* property = grpc_auth_property_iterator_next(&it); + GPR_ASSERT(property != nullptr); + if (expect_session_reuse) { + GPR_ASSERT(strcmp(property->value, "true") == 0); + } else { + GPR_ASSERT(strcmp(property->value, "false") == 0); + } + grpc_auth_context_release(auth); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_channel_destroy(client); +} + +void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_time(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +TEST(H2SessionReuseTest, SingleReuse) { + int port = grpc_pick_unused_port_or_die(); + + std::string server_addr = grpc_core::JoinHostPort("localhost", port); + + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_ssl_session_cache* cache = grpc_ssl_session_cache_create_lru(16); + + grpc_server* server = server_create(cq, server_addr.c_str()); + + do_round_trip(cq, server, server_addr.c_str(), cache, false); + do_round_trip(cq, server, server_addr.c_str(), cache, true); + do_round_trip(cq, server, server_addr.c_str(), cache, true); + + grpc_ssl_session_cache_destroy(cache); + + GPR_ASSERT(grpc_completion_queue_next( + cq, grpc_timeout_milliseconds_to_deadline(100), nullptr) + .type == GRPC_QUEUE_TIMEOUT); + + grpc_completion_queue* shutdown_cq = + grpc_completion_queue_create_for_pluck(nullptr); + grpc_server_shutdown_and_notify(server, shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(server); + grpc_completion_queue_destroy(shutdown_cq); + + grpc_completion_queue_shutdown(cq); + drain_cq(cq); + grpc_completion_queue_destroy(cq); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, CA_CERT_PATH); + + grpc_init(); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + + return ret; +} diff --git a/test/core/end2end/inproc_callback_test.cc b/test/core/end2end/inproc_callback_test.cc new file mode 100644 index 00000000..be36993c --- /dev/null +++ b/test/core/end2end/inproc_callback_test.cc @@ -0,0 +1,509 @@ +/* + * + * Copyright 2018 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/transport/inproc/inproc_transport.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +typedef struct inproc_fixture_data { + bool phony; // reserved for future expansion. Struct can't be empty +} inproc_fixture_data; + +namespace { +template +class CQDeletingCallback : public grpc_completion_queue_functor { + public: + explicit CQDeletingCallback(F f) : func_(f) { + functor_run = &CQDeletingCallback::Run; + inlineable = false; + } + ~CQDeletingCallback() {} + static void Run(grpc_completion_queue_functor* cb, int ok) { + auto* callback = static_cast(cb); + callback->func_(static_cast(ok)); + delete callback; + } + + private: + F func_; +}; + +template +grpc_completion_queue_functor* NewDeletingCallback(F f) { + return new CQDeletingCallback(f); +} + +class ShutdownCallback : public grpc_completion_queue_functor { + public: + ShutdownCallback() : done_(false) { + functor_run = &ShutdownCallback::StaticRun; + inlineable = false; + gpr_mu_init(&mu_); + gpr_cv_init(&cv_); + } + ~ShutdownCallback() { + gpr_mu_destroy(&mu_); + gpr_cv_destroy(&cv_); + } + static void StaticRun(grpc_completion_queue_functor* cb, int ok) { + auto* callback = static_cast(cb); + callback->Run(static_cast(ok)); + } + void Run(bool /*ok*/) { + gpr_log(GPR_DEBUG, "CQ shutdown notification invoked"); + gpr_mu_lock(&mu_); + done_ = true; + gpr_cv_broadcast(&cv_); + gpr_mu_unlock(&mu_); + } + // The Wait function waits for a specified amount of + // time for the completion of the shutdown and returns + // whether it was successfully shut down + bool Wait(gpr_timespec deadline) { + gpr_mu_lock(&mu_); + while (!done_ && !gpr_cv_wait(&cv_, &mu_, deadline)) { + } + bool ret = done_; + gpr_mu_unlock(&mu_); + return ret; + } + + private: + bool done_; + gpr_mu mu_; + gpr_cv cv_; +}; + +ShutdownCallback* g_shutdown_callback; +} // namespace + +// The following global structure is the tag collection. It holds +// all information related to tags expected and tags received +// during the execution, with each callback setting a tag. +// The tag sets are implemented and checked using arrays and +// linear lookups (rather than maps) so that this test doesn't +// need the C++ standard library. 
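+//
+// As a concrete sketch of the flow (this is the pattern simple_request_body
+// uses below): the main thread records what it expects and then blocks until
+// the callbacks have reported it, e.g.
+//
+//   expect_tag(1, true);                  // batch tag 1 should complete ok
+//   verify_tags(five_seconds_from_now()); // wait on tags_cv for the report
+//
+// while the functor created by tag(1) fills in tags[1] / tags_valid[1] and
+// signals tags_cv when the corresponding batch completes.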
+static gpr_mu tags_mu; +static gpr_cv tags_cv; +const size_t kAvailableTags = 4; +bool tags[kAvailableTags]; +bool tags_valid[kAvailableTags]; +bool tags_expected[kAvailableTags]; +bool tags_needed[kAvailableTags]; + +// Mark that a tag is expected; this function must be executed in the +// main thread only while there are no other threads altering the +// expectation set (e.g., by calling expect_tag or verify_tags) +static void expect_tag(intptr_t tag, bool ok) { + size_t idx = static_cast(tag); + GPR_ASSERT(idx < kAvailableTags); + tags_needed[idx] = true; + tags_expected[idx] = ok; +} + +// Check that the expected tags have reached, within a certain +// deadline. This must also be executed only on the main thread while +// there are no other threads altering the expectation set (e.g., by +// calling expect_tag or verify_tags). The tag verifier doesn't have +// to drive the CQ at all (unlike the next-based end2end tests) +// because the tags will get set when the callbacks are executed, +// which happens when a particular batch related to a callback is +// complete. +static void verify_tags(gpr_timespec deadline) { + bool done = false; + + gpr_mu_lock(&tags_mu); + while (!done) { + done = gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), deadline) > 0; + for (size_t i = 0; i < kAvailableTags; i++) { + if (tags_needed[i]) { + if (tags_valid[i]) { + gpr_log(GPR_DEBUG, "Verifying tag %d", static_cast(i)); + if (tags[i] != tags_expected[i]) { + gpr_log(GPR_ERROR, "Got wrong result (%d instead of %d) for tag %d", + tags[i], tags_expected[i], static_cast(i)); + GPR_ASSERT(false); + } + tags_valid[i] = false; + tags_needed[i] = false; + } else if (done) { + gpr_log(GPR_ERROR, "Didn't get tag %d", static_cast(i)); + GPR_ASSERT(false); + } + } + } + bool empty = true; + for (size_t i = 0; i < kAvailableTags; i++) { + if (tags_needed[i]) { + empty = false; + } + } + done = done || empty; + if (done) { + for (size_t i = 0; i < kAvailableTags; i++) { + if (tags_valid[i]) { + gpr_log(GPR_ERROR, "Got unexpected tag %d and result %d", + static_cast(i), tags[i]); + GPR_ASSERT(false); + } + tags_valid[i] = false; + } + } else { + gpr_cv_wait(&tags_cv, &tags_mu, deadline); + } + } + gpr_mu_unlock(&tags_mu); +} + +// This function creates a callback functor that emits the +// desired tag into the global tag set +static grpc_completion_queue_functor* tag(intptr_t t) { + auto func = [t](bool ok) { + gpr_mu_lock(&tags_mu); + gpr_log(GPR_DEBUG, "Completing operation %" PRIdPTR, t); + bool was_empty = true; + for (size_t i = 0; i < kAvailableTags; i++) { + if (tags_valid[i]) { + was_empty = false; + } + } + size_t idx = static_cast(t); + tags[idx] = ok; + tags_valid[idx] = true; + if (was_empty) { + gpr_cv_signal(&tags_cv); + } + gpr_mu_unlock(&tags_mu); + }; + auto cb = NewDeletingCallback(func); + return cb; +} + +static grpc_end2end_test_fixture inproc_create_fixture( + grpc_channel_args* /*client_args*/, grpc_channel_args* /*server_args*/) { + grpc_end2end_test_fixture f; + inproc_fixture_data* ffd = static_cast( + gpr_malloc(sizeof(inproc_fixture_data))); + memset(&f, 0, sizeof(f)); + + f.fixture_data = ffd; + g_shutdown_callback = new ShutdownCallback(); + f.cq = + grpc_completion_queue_create_for_callback(g_shutdown_callback, nullptr); + f.shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + + return f; +} + +void inproc_init_client(grpc_end2end_test_fixture* f, + grpc_channel_args* client_args) { + f->client = grpc_inproc_channel_create(f->server, client_args, nullptr); + GPR_ASSERT(f->client); +} 
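+// Note: inproc_init_client() above builds the channel directly from f->server,
+// so begin_test() below has to run init_server before init_client; the
+// in-process transport involves no address or port at all.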
+ +void inproc_init_server(grpc_end2end_test_fixture* f, + grpc_channel_args* server_args) { + if (f->server) { + grpc_server_destroy(f->server); + } + f->server = grpc_server_create(server_args, nullptr); + grpc_server_register_completion_queue(f->server, f->cq, nullptr); + grpc_server_start(f->server); +} + +void inproc_tear_down(grpc_end2end_test_fixture* f) { + inproc_fixture_data* ffd = static_cast(f->fixture_data); + gpr_free(ffd); +} + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now() { return n_seconds_from_now(5); } + +static void drain_cq(grpc_completion_queue* /*cq*/) { + // Wait for the shutdown callback to arrive, or fail the test + GPR_ASSERT(g_shutdown_callback->Wait(five_seconds_from_now())); + gpr_log(GPR_DEBUG, "CQ shutdown wait complete"); + delete g_shutdown_callback; +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify( + f->server, f->shutdown_cq, + reinterpret_cast(static_cast(1000))); + GPR_ASSERT( + grpc_completion_queue_pluck(f->shutdown_cq, (void*)((intptr_t)1000), + grpc_timeout_seconds_to_deadline(5), nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config config, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + const char* error_string; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + gpr_timespec deadline = five_seconds_from_now(); + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + // Create a basic client unary request batch (no payload) + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = 
GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->data.recv_status_on_client.error_string = &error_string; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Register a call at the server-side to match the incoming client call + error = grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(2)); + GPR_ASSERT(GRPC_CALL_OK == error); + + // We expect that the server call creation callback (and no others) will + // execute now since no other batch should be complete. + expect_tag(2, true); + verify_tags(deadline); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Create the server response batch (no payload) + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Both the client request and server response batches should get complete + // now and we should see that their callbacks have been executed + expect_tag(3, true); + expect_tag(1, true); + verify_tags(deadline); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + // the following sanity check makes sure that the requested error string is + // correctly populated by the core. It looks for certain substrings that are + // not likely to change much. Some parts of the error, like time created, + // obviously are not checked. 
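+  // (Concretely, the asserts that follow only require that the "xyz" detail,
+  // the "Error received from peer" text, and the "grpc_message" /
+  // "grpc_status" keys appear somewhere in error_string; nothing about the
+  // exact layout of the string is assumed.)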
+ GPR_ASSERT(nullptr != strstr(error_string, "xyz")); + GPR_ASSERT(nullptr != strstr(error_string, "Error received from peer")); + GPR_ASSERT(nullptr != strstr(error_string, "grpc_message")); + GPR_ASSERT(nullptr != strstr(error_string, "grpc_status")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + gpr_free(static_cast(const_cast(error_string))); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + int expected_calls = 1; + if (config.feature_mask & FEATURE_MASK_SUPPORTS_REQUEST_PROXYING) { + expected_calls *= 2; + } +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + simple_request_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_10_simple_requests(grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_10_simple_requests", nullptr, nullptr); + for (i = 0; i < 10; i++) { + simple_request_body(config, f); + gpr_log(GPR_INFO, "Running test: Passed simple request %d", i); + } + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_many_simple_requests(grpc_end2end_test_config config) { + int i; + const int many = 1000; + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_many_simple_requests", nullptr, nullptr); + gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC); + for (i = 0; i < many; i++) { + simple_request_body(config, f); + } + double us = + gpr_timespec_to_micros(gpr_time_sub(gpr_now(GPR_CLOCK_MONOTONIC), t1)) / + many; + gpr_log(GPR_INFO, "Time per ping %f us", us); + end_test(&f); + config.tear_down_data(&f); +} + +static void simple_request(grpc_end2end_test_config config) { + int i; + for (i = 0; i < 10; i++) { + test_invoke_simple_request(config); + } + test_invoke_10_simple_requests(config); + test_invoke_many_simple_requests(config); +} + +static void simple_request_pre_init() { + gpr_mu_init(&tags_mu); + gpr_cv_init(&tags_cv); +} + +/* All test configurations */ +static grpc_end2end_test_config configs[] = { + {"inproc-callback", FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, nullptr, + inproc_create_fixture, inproc_init_client, inproc_init_server, + inproc_tear_down}, +}; + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + simple_request_pre_init(); + simple_request(configs[0]); + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/invalid_call_argument_test.cc b/test/core/end2end/invalid_call_argument_test.cc new file mode 100644 index 00000000..10ca9fbe --- /dev/null +++ b/test/core/end2end/invalid_call_argument_test.cc @@ -0,0 +1,636 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +struct test_state { + int is_client; + grpc_channel* chan; + grpc_call* call; + gpr_timespec deadline; + grpc_completion_queue* cq; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + grpc_slice details; + grpc_call* server_call; + grpc_server* server; + grpc_metadata_array server_initial_metadata_recv; + grpc_call_details call_details; +}; + +static struct test_state g_state; + +static void prepare_test(int is_client) { + int port = grpc_pick_unused_port_or_die(); + grpc_op* op; + g_state.is_client = is_client; + grpc_metadata_array_init(&g_state.initial_metadata_recv); + grpc_metadata_array_init(&g_state.trailing_metadata_recv); + g_state.deadline = grpc_timeout_seconds_to_deadline(5); + g_state.cq = grpc_completion_queue_create_for_next(nullptr); + g_state.cqv = cq_verifier_create(g_state.cq); + g_state.details = grpc_empty_slice(); + memset(g_state.ops, 0, sizeof(g_state.ops)); + + if (is_client) { + /* create a call, channel to a non existant server */ + g_state.chan = + grpc_insecure_channel_create("nonexistant:54321", nullptr, nullptr); + grpc_slice host = grpc_slice_from_static_string("nonexistant"); + g_state.call = grpc_channel_create_call( + g_state.chan, nullptr, GRPC_PROPAGATE_DEFAULTS, g_state.cq, + grpc_slice_from_static_string("/Foo"), &host, g_state.deadline, + nullptr); + } else { + g_state.server = grpc_server_create(nullptr, nullptr); + grpc_server_register_completion_queue(g_state.server, g_state.cq, nullptr); + std::string server_hostport = grpc_core::JoinHostPort("0.0.0.0", port); + grpc_server_add_insecure_http2_port(g_state.server, + server_hostport.c_str()); + grpc_server_start(g_state.server); + server_hostport = grpc_core::JoinHostPort("localhost", port); + g_state.chan = + grpc_insecure_channel_create(server_hostport.c_str(), nullptr, nullptr); + grpc_slice host = grpc_slice_from_static_string("bar"); + g_state.call = grpc_channel_create_call( + g_state.chan, nullptr, GRPC_PROPAGATE_DEFAULTS, g_state.cq, + grpc_slice_from_static_string("/Foo"), &host, g_state.deadline, + nullptr); + grpc_metadata_array_init(&g_state.server_initial_metadata_recv); + grpc_call_details_init(&g_state.call_details); + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), + tag(1), nullptr)); + GPR_ASSERT(GRPC_CALL_OK == + grpc_server_request_call(g_state.server, &g_state.server_call, + &g_state.call_details, + &g_state.server_initial_metadata_recv, + g_state.cq, g_state.cq, 
tag(101))); + CQ_EXPECT_COMPLETION(g_state.cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(g_state.cqv, tag(1), 1); + cq_verify(g_state.cqv); + } +} + +static void cleanup_test() { + grpc_completion_queue* shutdown_cq; + grpc_call_unref(g_state.call); + cq_verifier_destroy(g_state.cqv); + grpc_channel_destroy(g_state.chan); + grpc_slice_unref(g_state.details); + grpc_metadata_array_destroy(&g_state.initial_metadata_recv); + grpc_metadata_array_destroy(&g_state.trailing_metadata_recv); + + if (!g_state.is_client) { + shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + grpc_call_unref(g_state.server_call); + grpc_server_shutdown_and_notify(g_state.server, shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_completion_queue_destroy(shutdown_cq); + grpc_server_destroy(g_state.server); + grpc_call_details_destroy(&g_state.call_details); + grpc_metadata_array_destroy(&g_state.server_initial_metadata_recv); + } + grpc_completion_queue_shutdown(g_state.cq); + while (grpc_completion_queue_next(g_state.cq, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(g_state.cq); +} + +static void test_non_null_reserved_on_start_batch() { + gpr_log(GPR_INFO, "test_non_null_reserved_on_start_batch"); + + prepare_test(1); + GPR_ASSERT(GRPC_CALL_ERROR == + grpc_call_start_batch(g_state.call, nullptr, 0, nullptr, tag(1))); + cleanup_test(); +} + +static void test_non_null_reserved_on_op() { + gpr_log(GPR_INFO, "test_non_null_reserved_on_op"); + + grpc_op* op; + prepare_test(1); + + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = tag(2); + op++; + GPR_ASSERT(GRPC_CALL_ERROR == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_send_initial_metadata_more_than_once() { + gpr_log(GPR_INFO, "test_send_initial_metadata_more_than_once"); + + grpc_op* op; + prepare_test(1); + + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), + tag(1), nullptr)); + CQ_EXPECT_COMPLETION(g_state.cqv, tag(1), 0); + cq_verify(g_state.cqv); + + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_TOO_MANY_OPERATIONS == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_too_many_metadata() { + gpr_log(GPR_INFO, "test_too_many_metadata"); + + grpc_op* op; + prepare_test(1); + + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = static_cast(INT_MAX) + 1; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_INVALID_METADATA == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_send_null_message() { + gpr_log(GPR_INFO, "test_send_null_message"); + + grpc_op* op; + prepare_test(1); + + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count 
= 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_INVALID_MESSAGE == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_send_messages_at_the_same_time() { + gpr_log(GPR_INFO, "test_send_messages_at_the_same_time"); + + grpc_op* op; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + prepare_test(1); + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = static_cast(tag(2)); + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_TOO_MANY_OPERATIONS == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + grpc_byte_buffer_destroy(request_payload); + cleanup_test(); +} + +static void test_send_server_status_from_client() { + gpr_log(GPR_INFO, "test_send_server_status_from_client"); + + grpc_op* op; + prepare_test(1); + + op = g_state.ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_NOT_ON_CLIENT == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_receive_initial_metadata_twice_at_client() { + gpr_log(GPR_INFO, "test_receive_initial_metadata_twice_at_client"); + + grpc_op* op; + prepare_test(1); + op = g_state.ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &g_state.initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), + tag(1), nullptr)); + CQ_EXPECT_COMPLETION(g_state.cqv, tag(1), 0); + cq_verify(g_state.cqv); + op = g_state.ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &g_state.initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_TOO_MANY_OPERATIONS == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_receive_message_with_invalid_flags() { + gpr_log(GPR_INFO, "test_receive_message_with_invalid_flags"); + + grpc_op* op; + grpc_byte_buffer* payload = nullptr; + prepare_test(1); + op = g_state.ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &payload; + op->flags = 1; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_INVALID_FLAGS == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void 
test_receive_two_messages_at_the_same_time() { + gpr_log(GPR_INFO, "test_receive_two_messages_at_the_same_time"); + + grpc_op* op; + grpc_byte_buffer* payload = nullptr; + prepare_test(1); + op = g_state.ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &payload; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_TOO_MANY_OPERATIONS == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_recv_close_on_server_from_client() { + gpr_log(GPR_INFO, "test_recv_close_on_server_from_client"); + + grpc_op* op; + prepare_test(1); + + op = g_state.ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_NOT_ON_CLIENT == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_recv_status_on_client_twice() { + gpr_log(GPR_INFO, "test_recv_status_on_client_twice"); + + grpc_op* op; + prepare_test(1); + + op = g_state.ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = + &g_state.trailing_metadata_recv; + op->data.recv_status_on_client.status = &g_state.status; + op->data.recv_status_on_client.status_details = &g_state.details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), + tag(1), nullptr)); + CQ_EXPECT_COMPLETION(g_state.cqv, tag(1), 1); + cq_verify(g_state.cqv); + + op = g_state.ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = nullptr; + op->data.recv_status_on_client.status = nullptr; + op->data.recv_status_on_client.status_details = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_TOO_MANY_OPERATIONS == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +static void test_send_close_from_client_on_server() { + gpr_log(GPR_INFO, "test_send_close_from_client_on_server"); + + grpc_op* op; + prepare_test(0); + + op = g_state.ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_NOT_ON_SERVER == + grpc_call_start_batch(g_state.server_call, g_state.ops, + (size_t)(op - g_state.ops), tag(2), + nullptr)); + cleanup_test(); +} + +static void test_recv_status_on_client_from_server() { + gpr_log(GPR_INFO, "test_recv_status_on_client_from_server"); + + grpc_op* op; + prepare_test(0); + + op = g_state.ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = + &g_state.trailing_metadata_recv; + op->data.recv_status_on_client.status = &g_state.status; + op->data.recv_status_on_client.status_details = &g_state.details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_NOT_ON_SERVER == + grpc_call_start_batch(g_state.server_call, g_state.ops, + (size_t)(op - g_state.ops), tag(2), + nullptr)); + cleanup_test(); +} + +static void test_send_status_from_server_with_invalid_flags() { + gpr_log(GPR_INFO, "test_send_status_from_server_with_invalid_flags"); + + grpc_op* op; + 
prepare_test(0); + + op = g_state.ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 1; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_INVALID_FLAGS == + grpc_call_start_batch(g_state.server_call, g_state.ops, + (size_t)(op - g_state.ops), tag(2), + nullptr)); + cleanup_test(); +} + +static void test_too_many_trailing_metadata() { + gpr_log(GPR_INFO, "test_too_many_trailing_metadata"); + + grpc_op* op; + prepare_test(0); + + op = g_state.ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = + static_cast(INT_MAX) + 1; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_INVALID_METADATA == + grpc_call_start_batch(g_state.server_call, g_state.ops, + (size_t)(op - g_state.ops), tag(2), + nullptr)); + cleanup_test(); +} + +static void test_send_server_status_twice() { + gpr_log(GPR_INFO, "test_send_server_status_twice"); + + grpc_op* op; + prepare_test(0); + + op = g_state.ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_TOO_MANY_OPERATIONS == + grpc_call_start_batch(g_state.server_call, g_state.ops, + (size_t)(op - g_state.ops), tag(2), + nullptr)); + cleanup_test(); +} + +static void test_recv_close_on_server_with_invalid_flags() { + gpr_log(GPR_INFO, "test_recv_close_on_server_with_invalid_flags"); + + grpc_op* op; + prepare_test(0); + + op = g_state.ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = nullptr; + op->flags = 1; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_INVALID_FLAGS == + grpc_call_start_batch(g_state.server_call, g_state.ops, + (size_t)(op - g_state.ops), tag(2), + nullptr)); + cleanup_test(); +} + +static void test_recv_close_on_server_twice() { + gpr_log(GPR_INFO, "test_recv_close_on_server_twice"); + + grpc_op* op; + prepare_test(0); + + op = g_state.ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_TOO_MANY_OPERATIONS == + grpc_call_start_batch(g_state.server_call, g_state.ops, + (size_t)(op - g_state.ops), tag(2), + nullptr)); + cleanup_test(); +} + +static void test_invalid_initial_metadata_reserved_key() { + 
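+  // Metadata keys that start with ':' collide with HTTP/2 pseudo-header
+  // names, so the batch below is expected to be rejected up front with
+  // GRPC_CALL_ERROR_INVALID_METADATA instead of being sent.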
gpr_log(GPR_INFO, "test_invalid_initial_metadata_reserved_key"); + + grpc_metadata metadata; + metadata.key = grpc_slice_from_static_string(":start_with_colon"); + metadata.value = grpc_slice_from_static_string("value"); + + grpc_op* op; + prepare_test(1); + op = g_state.ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = &metadata; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_ERROR_INVALID_METADATA == + grpc_call_start_batch(g_state.call, g_state.ops, + (size_t)(op - g_state.ops), tag(1), + nullptr)); + cleanup_test(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_invalid_initial_metadata_reserved_key(); + test_non_null_reserved_on_start_batch(); + test_non_null_reserved_on_op(); + test_send_initial_metadata_more_than_once(); + test_too_many_metadata(); + test_send_null_message(); + test_send_messages_at_the_same_time(); + test_send_server_status_from_client(); + test_receive_initial_metadata_twice_at_client(); + test_receive_message_with_invalid_flags(); + test_receive_two_messages_at_the_same_time(); + test_recv_close_on_server_from_client(); + test_recv_status_on_client_twice(); + test_send_close_from_client_on_server(); + test_recv_status_on_client_from_server(); + test_send_status_from_server_with_invalid_flags(); + test_too_many_trailing_metadata(); + test_send_server_status_twice(); + test_recv_close_on_server_with_invalid_flags(); + test_recv_close_on_server_twice(); + grpc_shutdown(); + + return 0; +} diff --git a/test/core/end2end/multiple_server_queues_test.cc b/test/core/end2end/multiple_server_queues_test.cc new file mode 100644 index 00000000..98738376 --- /dev/null +++ b/test/core/end2end/multiple_server_queues_test.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "test/core/util/test_config.h" + +int main(int argc, char** argv) { + grpc_completion_queue* cq1; + grpc_completion_queue* cq2; + grpc_completion_queue* cq3; + grpc_completion_queue_attributes attr; + + grpc_server* server; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_NEXT; + attr.cq_polling_type = GRPC_CQ_DEFAULT_POLLING; + cq1 = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + attr.cq_polling_type = GRPC_CQ_NON_LISTENING; + cq2 = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + attr.cq_polling_type = GRPC_CQ_NON_POLLING; + cq3 = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + server = grpc_server_create(nullptr, nullptr); + grpc_server_register_completion_queue(server, cq1, nullptr); + grpc_server_add_insecure_http2_port(server, "[::]:0"); + grpc_server_register_completion_queue(server, cq2, nullptr); + grpc_server_register_completion_queue(server, cq3, nullptr); + + grpc_server_start(server); + grpc_server_shutdown_and_notify(server, cq2, nullptr); + grpc_completion_queue_next(cq2, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); /* cue queue freeze */ + grpc_completion_queue_shutdown(cq1); + grpc_completion_queue_shutdown(cq2); + grpc_completion_queue_shutdown(cq3); + + grpc_completion_queue_next(cq1, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + grpc_completion_queue_next(cq2, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + grpc_completion_queue_next(cq3, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq1); + grpc_completion_queue_destroy(cq2); + grpc_completion_queue_destroy(cq3); + grpc_shutdown(); + return 0; +} diff --git a/test/core/end2end/no_server_test.cc b/test/core/end2end/no_server_test.cc new file mode 100644 index 00000000..e69d92c6 --- /dev/null +++ b/test/core/end2end/no_server_test.cc @@ -0,0 +1,115 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/test_config.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +void run_test(bool wait_for_ready) { + gpr_log(GPR_INFO, "TEST: wait_for_ready=%d", wait_for_ready); + + grpc_init(); + + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + cq_verifier* cqv = cq_verifier_create(cq); + + grpc_core::RefCountedPtr + response_generator = + grpc_core::MakeRefCounted(); + grpc_arg arg = grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + response_generator.get()); + grpc_channel_args args = {1, &arg}; + + /* create a call, channel to a non existant server */ + grpc_channel* chan = + grpc_insecure_channel_create("fake:nonexistant", &args, nullptr); + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(2); + grpc_call* call = grpc_channel_create_call( + chan, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/Foo"), nullptr, deadline, nullptr); + + grpc_op ops[6]; + memset(ops, 0, sizeof(ops)); + grpc_op* op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = wait_for_ready ? GRPC_INITIAL_METADATA_WAIT_FOR_READY : 0; + op->reserved = nullptr; + op++; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_status_code status; + grpc_slice details; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call, ops, + (size_t)(op - ops), tag(1), + nullptr)); + + { + grpc_core::ExecCtx exec_ctx; + response_generator->SetFailure(); + } + + /* verify that all tags get completed */ + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + gpr_log(GPR_INFO, "call status: %d", status); + if (wait_for_ready) { + GPR_ASSERT(status == GRPC_STATUS_DEADLINE_EXCEEDED); + } else { + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + } + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&trailing_metadata_recv); + + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); + grpc_call_unref(call); + grpc_channel_destroy(chan); + cq_verifier_destroy(cqv); + + grpc_shutdown(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + run_test(true /* wait_for_ready */); + run_test(false /* wait_for_ready */); + return 0; +} diff --git a/test/core/end2end/tests/authority_not_supported.cc b/test/core/end2end/tests/authority_not_supported.cc new file mode 100644 index 00000000..806215e1 --- /dev/null +++ b/test/core/end2end/tests/authority_not_supported.cc @@ -0,0 +1,186 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Request/response with metadata and payload.*/ +static void test_with_authority_header(grpc_end2end_test_config config) { + grpc_call* c; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_metadata meta_c[2] = {{grpc_slice_from_static_string("key1"), + grpc_slice_from_static_string("val1"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key2"), + grpc_slice_from_static_string("val2"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_end2end_test_fixture f = + begin_test(config, "test_with_authority_header", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + grpc_slice host = grpc_slice_from_static_string("foo.test.google.fr"); + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), &host, + deadline, nullptr); + GPR_ASSERT(c); + + 
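+  // The call above carries an explicit authority ("foo.test.google.fr") via
+  // &host. authority_not_supported() at the bottom of this file only runs the
+  // test for fixtures without FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER, so the
+  // client is expected to see the call end with GRPC_STATUS_CANCELLED.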
grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_c; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_CANCELLED); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void authority_not_supported(grpc_end2end_test_config config) { + if (config.feature_mask & FEATURE_MASK_SUPPORTS_AUTHORITY_HEADER) { + return; + } + test_with_authority_header(config); +} + +void authority_not_supported_pre_init(void) {} diff --git a/test/core/end2end/tests/bad_hostname.cc b/test/core/end2end/tests/bad_hostname.cc new file mode 100644 index 00000000..a2caf88e --- /dev/null +++ b/test/core/end2end/tests/bad_hostname.cc @@ -0,0 +1,171 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_client(&f, client_args); + config.init_server(&f, server_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_fixture f) { + grpc_call* c; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + grpc_slice host = grpc_slice_from_static_string("slartibartfast.local"); + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), &host, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), 
tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_INTERNAL); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + simple_request_body(f); + end_test(&f); + config.tear_down_data(&f); +} + +void bad_hostname(grpc_end2end_test_config config) { + if (config.feature_mask & FEATURE_MASK_SUPPORTS_HOSTNAME_VERIFICATION) { + test_invoke_simple_request(config); + } +} + +void bad_hostname_pre_init(void) {} diff --git a/test/core/end2end/tests/bad_ping.cc b/test/core/end2end/tests/bad_ping.cc new file mode 100644 index 00000000..d957f583 --- /dev/null +++ b/test/core/end2end/tests/bad_ping.cc @@ -0,0 +1,374 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +#define MAX_PING_STRIKES 2 + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, grpc_timeout_seconds_to_deadline(5), + nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Send more pings than server allows to trigger server's GOAWAY. 
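+// The client disables BDP probing and removes its own limit on pings without
+// data, while the server requires a long minimum interval between received
+// pings and tolerates only MAX_PING_STRIKES strikes before responding with a
+// GOAWAY and dropping the connection.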
+static void test_bad_ping(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_arg client_a[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_BDP_PROBE), 0)}; + grpc_arg server_a[] = { + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), + 300000 /* 5 minutes */), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PING_STRIKES), MAX_PING_STRIKES), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_BDP_PROBE), 0)}; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + + config.init_client(&f, &client_args); + config.init_server(&f, &server_args); + + grpc_call* c; + grpc_call* s; + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(10); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + // Send too many pings to the server to trigger the punishment: + // The first ping will let server mark its last_recv time. Afterwards, each + // ping will trigger a ping strike, and we need at least MAX_PING_STRIKES + // strikes to trigger the punishment. So (MAX_PING_STRIKES + 2) pings are + // needed here. 
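+  // Each completed ping gets its own tag (200 + i). On the final iteration
+  // the server has already sent its GOAWAY, so the client call batch (tag(1))
+  // is expected to complete as well.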
+ int i; + for (i = 1; i <= MAX_PING_STRIKES + 2; i++) { + grpc_channel_ping(f.client, f.cq, tag(200 + i), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(200 + i), 1); + if (i == MAX_PING_STRIKES + 2) { + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + } + cq_verify(cqv); + } + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), 1); + cq_verify(cqv); + + grpc_call_unref(s); + + // The connection should be closed immediately after the misbehaved pings, + // the in-progress RPC should fail. + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 1); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(c); + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +// Try sending more pings than server allows, but server should be fine because +// max_pings_without_data should limit pings sent out on wire. +static void test_pings_without_data(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + // Only allow MAX_PING_STRIKES pings without data (DATA/HEADERS/WINDOW_UPDATE) + // so that the transport will throttle the excess pings. 
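+  // Unlike test_bad_ping, the client itself is capped at MAX_PING_STRIKES
+  // pings without data, so the excess pings are held back locally and the
+  // server never accumulates enough strikes to close the connection.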
+ grpc_arg client_a[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), + MAX_PING_STRIKES), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_BDP_PROBE), 0)}; + grpc_arg server_a[] = { + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), + 300000 /* 5 minutes */), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PING_STRIKES), MAX_PING_STRIKES), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_BDP_PROBE), 0)}; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + + config.init_client(&f, &client_args); + config.init_server(&f, &server_args); + + grpc_call* c; + grpc_call* s; + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(10); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + // Send too many pings to the server similar to the previous test case. + // However, since we set the MAX_PINGS_WITHOUT_DATA at the client side, only + // MAX_PING_STRIKES will actually be sent and the rpc will still succeed. 
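+  // Only the first MAX_PING_STRIKES pings make it onto the wire and complete
+  // here; the remaining ping tags stay pending until server shutdown, where
+  // they are completed with an error (checked below).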
+ int i; + for (i = 1; i <= MAX_PING_STRIKES + 2; i++) { + grpc_channel_ping(f.client, f.cq, tag(200 + i), nullptr); + if (i <= MAX_PING_STRIKES) { + CQ_EXPECT_COMPLETION(cqv, tag(200 + i), 1); + } + cq_verify(cqv); + } + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + // Client call should return. + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), 1); + + // Also expect the previously blocked pings to complete with an error + CQ_EXPECT_COMPLETION(cqv, tag(200 + MAX_PING_STRIKES + 1), 0); + CQ_EXPECT_COMPLETION(cqv, tag(200 + MAX_PING_STRIKES + 2), 0); + + cq_verify(cqv); + + grpc_call_unref(s); + + // The rpc should be successful. + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(c); + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void bad_ping(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION); + test_bad_ping(config); + test_pings_without_data(config); +} + +void bad_ping_pre_init(void) {} diff --git a/test/core/end2end/tests/binary_metadata.cc b/test/core/end2end/tests/binary_metadata.cc new file mode 100644 index 00000000..ebc64a94 --- /dev/null +++ b/test/core/end2end/tests/binary_metadata.cc @@ -0,0 +1,315 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Request/response with metadata and payload.*/ +static void test_request_response_with_metadata_and_payload( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_metadata meta_c[2] = { + {grpc_slice_from_static_string("key1-bin"), + grpc_slice_from_static_string( + "\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key2-bin"), + grpc_slice_from_static_string( + "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_metadata meta_s[2] = { + {grpc_slice_from_static_string("key3-bin"), + grpc_slice_from_static_string( + "\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key4-bin"), + grpc_slice_from_static_string( + "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_end2end_test_fixture f = + begin_test(config, "test_request_response_with_metadata_and_payload", + nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = 
nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_c; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_s; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_string = grpc_slice_from_static_string( + "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12" + "\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24" + "\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36" + "\x37\x38\x39\x3a\x3b\x3c\x3d\x3e\x3f\x40\x41\x42\x43\x44\x45\x46\x47\x48" + "\x49\x4a\x4b\x4c\x4d\x4e\x4f\x50\x51\x52\x53\x54\x55\x56\x57\x58\x59\x5a" + 
"\x5b\x5c\x5d\x5e\x5f\x60\x61\x62\x63\x64\x65\x66\x67\x68\x69\x6a\x6b\x6c" + "\x6d\x6e\x6f\x70\x71\x72\x73\x74\x75\x76\x77\x78\x79\x7a\x7b\x7c\x7d\x7e" + "\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90" + "\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2" + "\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4" + "\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6" + "\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8" + "\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea" + "\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc" + "\xfd\xfe\xff"); + op->data.send_status_from_server.status_details = &status_string; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT( + 0 == + grpc_slice_str_cmp( + details, + "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10" + "\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20" + "\x21\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30" + "\x31\x32\x33\x34\x35\x36\x37\x38\x39\x3a\x3b\x3c\x3d\x3e\x3f\x40" + "\x41\x42\x43\x44\x45\x46\x47\x48\x49\x4a\x4b\x4c\x4d\x4e\x4f\x50" + "\x51\x52\x53\x54\x55\x56\x57\x58\x59\x5a\x5b\x5c\x5d\x5e\x5f\x60" + "\x61\x62\x63\x64\x65\x66\x67\x68\x69\x6a\x6b\x6c\x6d\x6e\x6f\x70" + "\x71\x72\x73\x74\x75\x76\x77\x78\x79\x7a\x7b\x7c\x7d\x7e\x7f\x80" + "\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90" + "\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0" + "\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0" + "\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0" + "\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0" + "\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0" + "\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0" + "\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + GPR_ASSERT(byte_buffer_eq_string(response_payload_recv, "hello you")); + GPR_ASSERT(contains_metadata( + &request_metadata_recv, "key1-bin", + "\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc")); + GPR_ASSERT(contains_metadata( + &request_metadata_recv, "key2-bin", + "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d")); + GPR_ASSERT(contains_metadata( + &initial_metadata_recv, "key3-bin", + "\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee")); + GPR_ASSERT(contains_metadata( + &initial_metadata_recv, "key4-bin", + "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + 
grpc_byte_buffer_destroy(response_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void binary_metadata(grpc_end2end_test_config config) { + test_request_response_with_metadata_and_payload(config); +} + +void binary_metadata_pre_init(void) {} diff --git a/test/core/end2end/tests/call_creds.cc b/test/core/end2end/tests/call_creds.cc new file mode 100644 index 00000000..8d64b550 --- /dev/null +++ b/test/core/end2end/tests/call_creds.cc @@ -0,0 +1,552 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static const char iam_token[] = "token"; +static const char iam_selector[] = "selector"; +static const char overridden_iam_token[] = "overridden_token"; +static const char overridden_iam_selector[] = "overridden_selector"; +static const char fake_md_key[] = "fake_key"; +static const char fake_md_value[] = "fake_value"; +static const char overridden_fake_md_key[] = "overridden_fake_key"; +static const char overridden_fake_md_value[] = "overridden_fake_value"; + +typedef enum { NONE, OVERRIDE, DESTROY, FAIL } override_mode; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + bool use_secure_call_creds, + int fail_server_auth_check) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s%s/%s", test_name, + use_secure_call_creds ? 
"_with_secure_call_creds" + : "_with_insecure_call_creds", + config.name); + f = config.create_fixture(nullptr, nullptr); + config.init_client(&f, nullptr); + if (fail_server_auth_check) { + grpc_arg fail_auth_arg = { + GRPC_ARG_STRING, + const_cast(FAIL_AUTH_CHECK_SERVER_ARG_NAME), + {nullptr}}; + grpc_channel_args args; + args.num_args = 1; + args.args = &fail_auth_arg; + config.init_server(&f, &args); + } else { + config.init_server(&f, nullptr); + } + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void print_auth_context(int is_client, const grpc_auth_context* ctx) { + const grpc_auth_property* p; + grpc_auth_property_iterator it; + gpr_log(GPR_INFO, "%s peer:", is_client ? "client" : "server"); + gpr_log(GPR_INFO, "\tauthenticated: %s", + grpc_auth_context_peer_is_authenticated(ctx) ? 
"YES" : "NO"); + it = grpc_auth_context_peer_identity(ctx); + while ((p = grpc_auth_property_iterator_next(&it)) != nullptr) { + gpr_log(GPR_INFO, "\t\t%s: %s", p->name, p->value); + } + gpr_log(GPR_INFO, "\tall properties:"); + it = grpc_auth_context_property_iterator(ctx); + while ((p = grpc_auth_property_iterator_next(&it)) != nullptr) { + gpr_log(GPR_INFO, "\t\t%s: %s", p->name, p->value); + } +} + +static void request_response_with_payload_and_call_creds( + const char* test_name, grpc_end2end_test_config config, override_mode mode, + bool use_secure_call_creds) { + grpc_call* c = nullptr; + grpc_call* s = nullptr; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_end2end_test_fixture f; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + grpc_call_credentials* creds = nullptr; + grpc_auth_context* s_auth_context = nullptr; + grpc_auth_context* c_auth_context = nullptr; + + f = begin_test(config, test_name, use_secure_call_creds, 0); + cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + if (use_secure_call_creds) { + creds = + grpc_google_iam_credentials_create(iam_token, iam_selector, nullptr); + } else { + creds = + grpc_md_only_test_credentials_create(fake_md_key, fake_md_value, false); + } + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(grpc_call_set_credentials(c, creds) == GRPC_CALL_OK); + switch (mode) { + case NONE: + break; + case OVERRIDE: + grpc_call_credentials_release(creds); + if (use_secure_call_creds) { + creds = grpc_google_iam_credentials_create( + overridden_iam_token, overridden_iam_selector, nullptr); + } else { + creds = grpc_md_only_test_credentials_create( + overridden_fake_md_key, overridden_fake_md_value, false); + } + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(grpc_call_set_credentials(c, creds) == GRPC_CALL_OK); + break; + case DESTROY: + GPR_ASSERT(grpc_call_set_credentials(c, nullptr) == GRPC_CALL_OK); + break; + case FAIL: + // Do nothing + break; + } + grpc_call_credentials_release(creds); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + 
op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + if (mode == FAIL) { + // Expect the call to fail since the channel credentials did not satisfy the + // minimum security level requirements. + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + GPR_ASSERT(status == GRPC_STATUS_UNAUTHENTICATED); + } else { + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + s_auth_context = grpc_call_auth_context(s); + GPR_ASSERT(s_auth_context != nullptr); + print_auth_context(0, s_auth_context); + grpc_auth_context_release(s_auth_context); + + c_auth_context = grpc_call_auth_context(c); + GPR_ASSERT(c_auth_context != nullptr); + print_auth_context(1, c_auth_context); + grpc_auth_context_release(c_auth_context); + + /* Cannot set creds on the server call object. */ + GPR_ASSERT(grpc_call_set_credentials(s, nullptr) != GRPC_CALL_OK); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + GPR_ASSERT(byte_buffer_eq_string(response_payload_recv, "hello you")); + + switch (mode) { + case NONE: + if (use_secure_call_creds) { + GPR_ASSERT(contains_metadata( + &request_metadata_recv, 
GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + iam_token)); + GPR_ASSERT(contains_metadata(&request_metadata_recv, + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + iam_selector)); + } else { + GPR_ASSERT(contains_metadata(&request_metadata_recv, fake_md_key, + fake_md_value)); + } + break; + case OVERRIDE: + if (use_secure_call_creds) { + GPR_ASSERT(contains_metadata( + &request_metadata_recv, GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + overridden_iam_token)); + GPR_ASSERT(contains_metadata(&request_metadata_recv, + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + overridden_iam_selector)); + } else { + GPR_ASSERT(contains_metadata(&request_metadata_recv, + overridden_fake_md_key, + overridden_fake_md_value)); + } + break; + case DESTROY: + GPR_ASSERT(!contains_metadata(&request_metadata_recv, + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + iam_token)); + GPR_ASSERT(!contains_metadata(&request_metadata_recv, + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + iam_selector)); + GPR_ASSERT(!contains_metadata(&request_metadata_recv, + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + overridden_iam_token)); + GPR_ASSERT(!contains_metadata(&request_metadata_recv, + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + overridden_iam_selector)); + GPR_ASSERT(!contains_metadata(&request_metadata_recv, fake_md_key, + fake_md_value)); + GPR_ASSERT(!contains_metadata(&request_metadata_recv, + overridden_fake_md_key, + overridden_fake_md_value)); + break; + case FAIL: + GPR_ASSERT(0); + } + grpc_call_unref(s); + } + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_request_response_with_payload_and_call_creds( + grpc_end2end_test_config config, bool use_secure_call_creds) { + request_response_with_payload_and_call_creds( + "test_request_response_with_payload_and_call_creds", config, NONE, + use_secure_call_creds); +} + +static void test_request_response_with_payload_and_overridden_call_creds( + grpc_end2end_test_config config, bool use_secure_call_creds) { + request_response_with_payload_and_call_creds( + "test_request_response_with_payload_and_overridden_call_creds", config, + OVERRIDE, use_secure_call_creds); +} + +static void test_request_response_with_payload_and_deleted_call_creds( + grpc_end2end_test_config config, bool use_secure_call_creds) { + request_response_with_payload_and_call_creds( + "test_request_response_with_payload_and_deleted_call_creds", config, + DESTROY, use_secure_call_creds); +} + +static void test_request_response_with_payload_fail_to_send_call_creds( + grpc_end2end_test_config config, bool use_secure_call_creds) { + request_response_with_payload_and_call_creds( + "test_request_response_with_payload_fail_to_send_call_creds", config, + FAIL, use_secure_call_creds); +} + +static void test_request_with_server_rejecting_client_creds( + grpc_end2end_test_config config) { + grpc_op ops[6]; + grpc_op* op; + grpc_call* c; + grpc_end2end_test_fixture f; + gpr_timespec deadline = five_seconds_from_now(); + cq_verifier* cqv; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + 
grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_call_credentials* creds; + + f = begin_test(config, "test_request_with_server_rejecting_client_creds", + false, 1); + cqv = cq_verifier_create(f.cq); + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + creds = + grpc_md_only_test_credentials_create(fake_md_key, fake_md_value, false); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(grpc_call_set_credentials(c, creds) == GRPC_CALL_OK); + grpc_call_credentials_release(creds); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(error == GRPC_CALL_OK); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNAUTHENTICATED); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void call_creds(grpc_end2end_test_config config) { + // Test fixtures that support call credentials with a minimum security level + // of GRPC_PRIVACY_AND_INTEGRITY + if (config.feature_mask & FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS) { + test_request_response_with_payload_and_call_creds(config, true); + test_request_response_with_payload_and_overridden_call_creds(config, true); + test_request_response_with_payload_and_deleted_call_creds(config, true); + } + // Test that fixtures that support call credentials with a minimum security + // level of GRPC_SECURITY_NONE cannot send call credentials that require + // higher security level + if (config.feature_mask & + 
FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS_LEVEL_INSECURE) { + test_request_response_with_payload_fail_to_send_call_creds(config, true); + } + // Fixtures that support sending call credentials should be able to send call + // credentials of security level GRPC_SECURITY_NONE. + if (config.feature_mask & FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS || + config.feature_mask & + FEATURE_MASK_SUPPORTS_PER_CALL_CREDENTIALS_LEVEL_INSECURE) { + test_request_response_with_payload_and_call_creds(config, false); + test_request_response_with_payload_and_overridden_call_creds(config, false); + test_request_response_with_payload_and_deleted_call_creds(config, false); + test_request_with_server_rejecting_client_creds(config); + } +} + +void call_creds_pre_init(void) {} diff --git a/test/core/end2end/tests/call_host_override.cc b/test/core/end2end/tests/call_host_override.cc new file mode 100644 index 00000000..04d718d4 --- /dev/null +++ b/test/core/end2end/tests/call_host_override.cc @@ -0,0 +1,229 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + grpc_arg fake_security_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr:1234")}}; + grpc_channel_args* new_client_args = grpc_channel_args_copy_and_add( + client_args, &fake_security_name_override, 1); + config.init_client(&f, new_client_args); + grpc_channel_args_destroy(new_client_args); + config.init_server(&f, server_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + 
grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call( + f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), + get_host_override_slice("foo.test.google.fr:1234", config), deadline, + nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(error == GRPC_CALL_OK); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(error == GRPC_CALL_OK); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = 
grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(error == GRPC_CALL_OK); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + validate_host_override_string("foo.test.google.fr:1234", call_details.host, + config); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void call_host_override(grpc_end2end_test_config config) { + test_invoke_simple_request(config); +} + +void call_host_override_pre_init(void) {} diff --git a/test/core/end2end/tests/cancel_after_accept.cc b/test/core/end2end/tests/cancel_after_accept.cc new file mode 100644 index 00000000..459cdc80 --- /dev/null +++ b/test/core/end2end/tests/cancel_after_accept.cc @@ -0,0 +1,268 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + cancellation_mode mode, + bool use_service_config, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s/%s/%s", test_name, config.name, + mode.name, use_service_config ? 
"service_config" : "client_api"); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Cancel after accept, no payload */ +static void test_cancel_after_accept(grpc_end2end_test_config config, + cancellation_mode mode, + bool use_service_config) { + grpc_op ops[6]; + grpc_op* op; + grpc_call* c; + grpc_call* s; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + int was_cancelled = 2; + + grpc_channel_args* args = nullptr; + if (use_service_config) { + grpc_arg arg; + arg.type = GRPC_ARG_STRING; + arg.key = const_cast(GRPC_ARG_SERVICE_CONFIG); + arg.value.string = const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" },\n" + " { \"service\": \"unused\" }\n" + " ],\n" + " \"timeout\": \"5s\"\n" + " } ]\n" + "}"); + args = grpc_channel_args_copy_and_add(args, &arg, 1); + } + + grpc_end2end_test_fixture f = begin_test(config, "cancel_after_accept", mode, + use_service_config, args, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = use_service_config + ? 
gpr_inf_future(GPR_CLOCK_MONOTONIC) + : five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(2)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == mode.expect_status || status == GRPC_STATUS_INTERNAL); + GPR_ASSERT(was_cancelled == 1); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(details); + + grpc_call_unref(c); + grpc_call_unref(s); + + if (args != nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(args); + } + + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void cancel_after_accept(grpc_end2end_test_config config) { + unsigned i; + + for (i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); i++) { + test_cancel_after_accept(config, 
cancellation_modes[i], + false /* use_service_config */); + if (config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL && + cancellation_modes[i].expect_status == GRPC_STATUS_DEADLINE_EXCEEDED) { + test_cancel_after_accept(config, cancellation_modes[i], + true /* use_service_config */); + } + } +} + +void cancel_after_accept_pre_init(void) {} diff --git a/test/core/end2end/tests/cancel_after_client_done.cc b/test/core/end2end/tests/cancel_after_client_done.cc new file mode 100644 index 00000000..1039d9b3 --- /dev/null +++ b/test/core/end2end/tests/cancel_after_client_done.cc @@ -0,0 +1,236 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + cancellation_mode mode, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s/%s", test_name, config.name, + mode.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Cancel after accept with a writes closed, no payload */ +static void test_cancel_after_accept_and_writes_closed( + grpc_end2end_test_config config, cancellation_mode mode) { + grpc_op ops[6]; + grpc_op* op; + grpc_call* c; + grpc_call* s; + grpc_end2end_test_fixture f = + begin_test(config, "test_cancel_after_accept_and_writes_closed", mode, + nullptr, nullptr); + cq_verifier* cqv = 
cq_verifier_create(f.cq); + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(2)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == mode.expect_status || status == GRPC_STATUS_INTERNAL); + GPR_ASSERT(was_cancelled == 1); 
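+ // Note: was_cancelled is initialized to 2, a value GRPC_OP_RECV_CLOSE_ON_SERVER
+ // never writes, so the assertion above also confirms the op actually completed;
+ // a value of 1 means the server-side call observed the client's cancellation.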
+ + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void cancel_after_client_done(grpc_end2end_test_config config) { + unsigned i; + + for (i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); i++) { + test_cancel_after_accept_and_writes_closed(config, cancellation_modes[i]); + } +} + +void cancel_after_client_done_pre_init(void) {} diff --git a/test/core/end2end/tests/cancel_after_invoke.cc b/test/core/end2end/tests/cancel_after_invoke.cc new file mode 100644 index 00000000..0195dd1c --- /dev/null +++ b/test/core/end2end/tests/cancel_after_invoke.cc @@ -0,0 +1,194 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + cancellation_mode mode, + size_t test_ops, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s/%s [%" PRIdPTR " ops]", test_name, + config.name, mode.name, test_ops); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->cq, tag(1000)); + grpc_event ev = grpc_completion_queue_next( + f->cq, grpc_timeout_seconds_to_deadline(5), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == tag(1000)); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + 
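+ // The client channel is destroyed only after the server has shut down; the
+ // completion queue is then shut down, drained, and destroyed last so that no
+ // pending events are left outstanding.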
shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Cancel after invoke, no payload */ +static void test_cancel_after_invoke(grpc_end2end_test_config config, + cancellation_mode mode, size_t test_ops) { + grpc_op ops[6]; + grpc_op* op; + grpc_call* c; + grpc_end2end_test_fixture f = begin_test(config, "test_cancel_after_invoke", + mode, test_ops, nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, test_ops, tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == mode.expect_status || status == GRPC_STATUS_INTERNAL); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void cancel_after_invoke(grpc_end2end_test_config config) { + unsigned i, j; + + for (j = 3; j < 6; j++) { + for (i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); i++) { + test_cancel_after_invoke(config, cancellation_modes[i], j); + } + } +} + +void cancel_after_invoke_pre_init(void) {} diff --git 
a/test/core/end2end/tests/cancel_after_round_trip.cc b/test/core/end2end/tests/cancel_after_round_trip.cc new file mode 100644 index 00000000..bb579d0d --- /dev/null +++ b/test/core/end2end/tests/cancel_after_round_trip.cc @@ -0,0 +1,303 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + cancellation_mode mode, + bool use_service_config, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s/%s/%s", test_name, config.name, + mode.name, use_service_config ? "service_config" : "client_api"); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Cancel after accept, no payload */ +static void test_cancel_after_round_trip(grpc_end2end_test_config config, + cancellation_mode mode, + bool use_service_config) { + grpc_op ops[6]; + grpc_op* op; + grpc_call* c; + grpc_call* s; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_byte_buffer* request_payload_recv = 
nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload1 = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* response_payload2 = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + int was_cancelled = 2; + + grpc_channel_args* args = nullptr; + if (use_service_config) { + grpc_arg arg; + arg.type = GRPC_ARG_STRING; + arg.key = const_cast(GRPC_ARG_SERVICE_CONFIG); + arg.value.string = const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"timeout\": \"5s\"\n" + " } ]\n" + "}"); + args = grpc_channel_args_copy_and_add(args, &arg, 1); + } + + grpc_end2end_test_fixture f = + begin_test(config, "cancel_after_round_trip", mode, use_service_config, + args, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = use_service_config + ? gpr_inf_future(GPR_CLOCK_MONOTONIC) + : five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload1; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + request_payload_recv = nullptr; + response_payload_recv = nullptr; + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = 
GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload2; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + cq_verify(cqv); + + GPR_ASSERT(status == mode.expect_status || status == GRPC_STATUS_INTERNAL); + GPR_ASSERT(was_cancelled == 1); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload1); + grpc_byte_buffer_destroy(response_payload2); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(details); + + grpc_call_unref(c); + grpc_call_unref(s); + + if (args != nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(args); + } + + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void cancel_after_round_trip(grpc_end2end_test_config config) { + unsigned i; + + for (i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); i++) { + test_cancel_after_round_trip(config, cancellation_modes[i], + false /* use_service_config */); + if (config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL && + cancellation_modes[i].expect_status == GRPC_STATUS_DEADLINE_EXCEEDED) { + test_cancel_after_round_trip(config, cancellation_modes[i], + true /* use_service_config */); + } + } +} + +void cancel_after_round_trip_pre_init(void) {} diff --git a/test/core/end2end/tests/cancel_before_invoke.cc b/test/core/end2end/tests/cancel_before_invoke.cc new file mode 100644 index 00000000..74821932 --- /dev/null +++ b/test/core/end2end/tests/cancel_before_invoke.cc @@ -0,0 +1,188 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + size_t num_ops, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s [%" PRIdPTR " ops]", test_name, + config.name, num_ops); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Cancel before invoke */ +static void test_cancel_before_invoke(grpc_end2end_test_config config, + size_t test_ops) { + grpc_op ops[6]; + grpc_op* op; + grpc_call* c; + grpc_end2end_test_fixture f = + begin_test(config, "cancel_before_invoke", test_ops, nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + GPR_ASSERT(GRPC_CALL_OK == grpc_call_cancel(c, nullptr)); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + 
op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, test_ops, tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_CANCELLED); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void cancel_before_invoke(grpc_end2end_test_config config) { + size_t i; + for (i = 1; i <= 6; i++) { + test_cancel_before_invoke(config, i); + } +} + +void cancel_before_invoke_pre_init(void) {} diff --git a/test/core/end2end/tests/cancel_in_a_vacuum.cc b/test/core/end2end/tests/cancel_in_a_vacuum.cc new file mode 100644 index 00000000..e004cee0 --- /dev/null +++ b/test/core/end2end/tests/cancel_in_a_vacuum.cc @@ -0,0 +1,121 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + cancellation_mode mode, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s/%s", test_name, config.name, + mode.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Cancel and do nothing */ +static void test_cancel_in_a_vacuum(grpc_end2end_test_config config, + cancellation_mode mode) { + grpc_call* c; + grpc_end2end_test_fixture f = + begin_test(config, "test_cancel_in_a_vacuum", mode, nullptr, nullptr); + cq_verifier* v_client = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + grpc_call_unref(c); + + cq_verifier_destroy(v_client); + end_test(&f); + config.tear_down_data(&f); +} + +void cancel_in_a_vacuum(grpc_end2end_test_config config) { + unsigned i; + + for (i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); i++) { + test_cancel_in_a_vacuum(config, cancellation_modes[i]); + } +} + +void cancel_in_a_vacuum_pre_init(void) {} diff --git a/test/core/end2end/tests/cancel_with_status.cc b/test/core/end2end/tests/cancel_with_status.cc new file mode 100644 index 00000000..98c83c5f --- /dev/null +++ b/test/core/end2end/tests/cancel_with_status.cc @@ -0,0 +1,182 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + size_t num_ops, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s [%" PRIdPTR " ops]", test_name, + config.name, num_ops); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->cq, tag(1000)); + grpc_event ev = grpc_completion_queue_next( + f->cq, grpc_timeout_seconds_to_deadline(5), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == tag(1000)); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f, size_t num_ops) { + grpc_call* c; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + gpr_log(GPR_DEBUG, "test with %" PRIuPTR " ops", num_ops); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + 
op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(num_ops <= (size_t)(op - ops)); + error = grpc_call_start_batch(c, ops, num_ops, tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + char* dynamic_string = gpr_strdup("xyz"); + grpc_call_cancel_with_status(c, GRPC_STATUS_UNIMPLEMENTED, dynamic_string, + nullptr); + // The API of \a description allows for it to be a dynamic/non-const + // string, test this guarantee. + gpr_free(dynamic_string); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config, + size_t num_ops) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request", num_ops, nullptr, + nullptr); + simple_request_body(config, f, num_ops); + end_test(&f); + config.tear_down_data(&f); +} + +void cancel_with_status(grpc_end2end_test_config config) { + size_t i; + for (i = 1; i <= 4; i++) { + test_invoke_simple_request(config, i); + } +} + +void cancel_with_status_pre_init(void) {} diff --git a/test/core/end2end/tests/channelz.cc b/test/core/end2end/tests/channelz.cc new file mode 100644 index 00000000..d508fc12 --- /dev/null +++ b/test/core/end2end/tests/channelz.cc @@ -0,0 +1,324 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channelz_registry.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void run_one_request(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f, + bool request_is_success) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + 
op->data.recv_status_on_client.status_details = &details; + op->data.recv_status_on_client.error_string = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = + request_is_success ? GRPC_STATUS_OK : GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(0 == call_details.flags); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_channelz(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + grpc_arg arg[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), true)}; + grpc_channel_args args = {GPR_ARRAY_SIZE(arg), arg}; + + f = begin_test(config, "test_channelz", &args, &args); + grpc_core::channelz::ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(f.client); + GPR_ASSERT(channelz_channel != nullptr); + + grpc_core::channelz::ServerNode* channelz_server = + f.server->core_server->channelz_node(); + GPR_ASSERT(channelz_server != nullptr); + + std::string json = channelz_channel->RenderJsonString(); + // nothing is present yet + GPR_ASSERT(json.find("\"callsStarted\"") == json.npos); + GPR_ASSERT(json.find("\"callsFailed\"") == json.npos); + GPR_ASSERT(json.find("\"callsSucceeded\"") == json.npos); + + // one successful request + run_one_request(config, f, true); + + json = channelz_channel->RenderJsonString(); + GPR_ASSERT(json.find("\"callsStarted\":\"1\"") != json.npos); + GPR_ASSERT(json.find("\"callsSucceeded\":\"1\"") != json.npos); + + // one failed request + run_one_request(config, f, false); + + json = channelz_channel->RenderJsonString(); + GPR_ASSERT(json.find("\"callsStarted\":\"2\"") != json.npos); + GPR_ASSERT(json.find("\"callsFailed\":\"1\"") != json.npos); + GPR_ASSERT(json.find("\"callsSucceeded\":\"1\"") != json.npos); + // channel tracing is not enabled, so these should not be preset. 
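+ // (GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE is set to 0 for this
+ // test, so no "trace" node should appear in the rendered channelz JSON.)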
+ GPR_ASSERT(json.find("\"trace\"") == json.npos); + GPR_ASSERT(json.find("\"description\":\"Channel created\"") == json.npos); + GPR_ASSERT(json.find("\"severity\":\"CT_INFO\"") == json.npos); + + json = channelz_server->RenderJsonString(); + GPR_ASSERT(json.find("\"callsStarted\":\"2\"") != json.npos); + GPR_ASSERT(json.find("\"callsFailed\":\"1\"") != json.npos); + GPR_ASSERT(json.find("\"callsSucceeded\":\"1\"") != json.npos); + // channel tracing is not enabled, so these should not be preset. + GPR_ASSERT(json.find("\"trace\"") == json.npos); + GPR_ASSERT(json.find("\"description\":\"Channel created\"") == json.npos); + GPR_ASSERT(json.find("\"severity\":\"CT_INFO\"") == json.npos); + + json = channelz_server->RenderServerSockets(0, 100); + GPR_ASSERT(json.find("\"end\":true") != json.npos); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_channelz_with_channel_trace(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + grpc_arg arg[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + 1024 * 1024), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), true)}; + grpc_channel_args args = {GPR_ARRAY_SIZE(arg), arg}; + + f = begin_test(config, "test_channelz_with_channel_trace", &args, &args); + grpc_core::channelz::ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(f.client); + GPR_ASSERT(channelz_channel != nullptr); + + grpc_core::channelz::ServerNode* channelz_server = + f.server->core_server->channelz_node(); + GPR_ASSERT(channelz_server != nullptr); + + run_one_request(config, f, true); + + std::string json = channelz_channel->RenderJsonString(); + GPR_ASSERT(json.find("\"trace\"") != json.npos); + GPR_ASSERT(json.find("\"description\":\"Channel created\"") != json.npos); + GPR_ASSERT(json.find("\"severity\":\"CT_INFO\"") != json.npos); + + json = channelz_server->RenderJsonString(); + GPR_ASSERT(json.find("\"trace\"") != json.npos); + GPR_ASSERT(json.find("\"description\":\"Server created\"") != json.npos); + GPR_ASSERT(json.find("\"severity\":\"CT_INFO\"") != json.npos); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_channelz_disabled(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + grpc_arg arg[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), false)}; + grpc_channel_args args = {GPR_ARRAY_SIZE(arg), arg}; + + f = begin_test(config, "test_channelz_disabled", &args, &args); + grpc_core::channelz::ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(f.client); + GPR_ASSERT(channelz_channel == nullptr); + // one successful request + run_one_request(config, f, true); + GPR_ASSERT(channelz_channel == nullptr); + end_test(&f); + config.tear_down_data(&f); +} + +void channelz(grpc_end2end_test_config config) { + test_channelz(config); + test_channelz_with_channel_trace(config); + test_channelz_disabled(config); +} + +void channelz_pre_init(void) {} diff --git a/test/core/end2end/tests/client_streaming.cc b/test/core/end2end/tests/client_streaming.cc new file mode 100644 index 00000000..6b62a010 --- /dev/null +++ b/test/core/end2end/tests/client_streaming.cc @@ -0,0 +1,273 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Client streaming test where the client sends a bunch of messages and the +// server reads them. After reading some messages, the server sends the status. +// Client writes fail after that due to the end of stream and the client +// subsequently requests and receives the status. 
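Editorial aside, not part of this patch: the comment above describes the same flow that the C++ wrapper layer exposes through grpc::ClientWriter. A minimal sketch follows, assuming a hypothetical client-streaming service Uploader with rpc Upload(stream Chunk) returns (UploadSummary); the service, message types, and fields are made up for illustration only. The test function after this sketch drives the identical sequence by hand with grpc_op batches, grpc_call_start_batch, and the cq_verifier helpers.

// Illustrative only; assumes hypothetical generated code for:
//   service Uploader { rpc Upload(stream Chunk) returns (UploadSummary); }
#include <memory>

#include <grpcpp/grpcpp.h>

#include "uploader.grpc.pb.h"  // hypothetical generated header

bool UploadChunks(Uploader::Stub* stub, int messages) {
  grpc::ClientContext context;
  UploadSummary summary;
  // Client-streaming stubs return a writer; the response arrives via Finish().
  std::unique_ptr<grpc::ClientWriter<Chunk>> writer(
      stub->Upload(&context, &summary));
  for (int i = 0; i < messages; ++i) {
    Chunk chunk;
    chunk.set_data("hello world");
    // Once the server has sent its status, further writes fail and Write()
    // returns false, the C++ analogue of the failed tag(103) batch below.
    if (!writer->Write(chunk)) break;
  }
  writer->WritesDone();
  grpc::Status status = writer->Finish();
  return status.ok();
}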
+static void test_client_streaming(grpc_end2end_test_config config, + int messages) { + grpc_end2end_test_fixture f = + begin_test(config, "test_client_streaming", nullptr, nullptr); + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* request_payload = nullptr; + int i; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(100)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(100), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + // Client writes bunch of messages and server reads them + for (i = 0; i < messages; i++) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_byte_buffer_destroy(request_payload); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + cq_verify(cqv); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + grpc_byte_buffer_destroy(request_payload_recv); + } + + // Server sends status denoting end of stream + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + 
op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + cq_verify(cqv); + // Do an empty verify to make sure that the client receives the status + cq_verify_empty(cqv); + + // Client tries sending another message which should fail + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_byte_buffer_destroy(request_payload); + CQ_EXPECT_COMPLETION(cqv, tag(103), 0); + cq_verify(cqv); + + // Client sends close and requests status + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv); + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + + grpc_slice_unref(request_payload_slice); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_slice_unref(details); + + end_test(&f); + config.tear_down_data(&f); +} + +void client_streaming(grpc_end2end_test_config config) { + for (int i = 0; i < 10; i++) { + test_client_streaming(config, i); + } +} + +void client_streaming_pre_init(void) {} diff --git a/test/core/end2end/tests/compressed_payload.cc b/test/core/end2end/tests/compressed_payload.cc new file mode 100644 index 00000000..cd9be9b4 --- /dev/null +++ b/test/core/end2end/tests/compressed_payload.cc @@ -0,0 +1,701 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/call_test_only.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args, + bool decompress_in_core) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s%s/%s", test_name, + decompress_in_core ? "" : "_with_decompression_disabled", + config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void request_for_disabled_algorithm( + grpc_end2end_test_config config, const char* test_name, + uint32_t send_flags_bitmask, + grpc_compression_algorithm algorithm_to_disable, + grpc_compression_algorithm requested_client_compression_algorithm, + grpc_status_code expected_error, grpc_metadata* client_metadata, + bool decompress_in_core) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice; + grpc_byte_buffer* request_payload; + grpc_channel_args* client_args; + grpc_channel_args* server_args; + grpc_end2end_test_fixture f; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + cq_verifier* cqv; + char str[1024]; + + memset(str, 'x', 1023); + str[1023] = '\0'; + request_payload_slice = grpc_slice_from_copied_string(str); + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + + client_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, requested_client_compression_algorithm); + server_args = 
grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, GRPC_COMPRESS_NONE); + server_args = grpc_channel_args_compression_algorithm_set_state( + &server_args, algorithm_to_disable, false); + if (!decompress_in_core) { + grpc_arg disable_decompression_in_core_arg = + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION), 0); + grpc_channel_args* old_client_args = client_args; + grpc_channel_args* old_server_args = server_args; + client_args = grpc_channel_args_copy_and_add( + client_args, &disable_decompression_in_core_arg, 1); + server_args = grpc_channel_args_copy_and_add( + server_args, &disable_decompression_in_core_arg, 1); + grpc_channel_args_destroy(old_client_args); + grpc_channel_args_destroy(old_server_args); + } + + f = begin_test(config, test_name, client_args, server_args, + decompress_in_core); + cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + if (client_metadata != nullptr) { + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = client_metadata; + } else { + op->data.send_initial_metadata.count = 0; + } + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = send_flags_bitmask; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), false); + + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 
true); + cq_verify(cqv); + + /* call was cancelled (closed) ... */ + GPR_ASSERT(was_cancelled != 0); + /* with a certain error */ + GPR_ASSERT(status == expected_error); + + const char* algo_name = nullptr; + GPR_ASSERT(grpc_compression_algorithm_name(algorithm_to_disable, &algo_name)); + std::string expected_details = + absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); + /* and we expect a specific reason for it */ + GPR_ASSERT(0 == grpc_slice_str_cmp(details, expected_details.c_str())); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_slice_unref(request_payload_slice); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_channel_args_destroy(client_args); + grpc_channel_args_destroy(server_args); + end_test(&f); + config.tear_down_data(&f); +} + +static void request_with_payload_template_inner( + grpc_end2end_test_config config, const char* test_name, + uint32_t client_send_flags_bitmask, + grpc_compression_algorithm default_client_channel_compression_algorithm, + grpc_compression_algorithm default_server_channel_compression_algorithm, + grpc_compression_algorithm expected_algorithm_from_client, + grpc_compression_algorithm expected_algorithm_from_server, + grpc_metadata* client_init_metadata, bool set_server_level, + grpc_compression_level server_compression_level, + bool send_message_before_initial_metadata, bool decompress_in_core) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice; + grpc_byte_buffer* request_payload = nullptr; + grpc_channel_args* client_args; + grpc_channel_args* server_args; + grpc_end2end_test_fixture f; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload; + grpc_byte_buffer* response_payload_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + cq_verifier* cqv; + char request_str[1024]; + char response_str[1024]; + + memset(request_str, 'x', 1023); + request_str[1023] = '\0'; + + memset(response_str, 'y', 1023); + response_str[1023] = '\0'; + + request_payload_slice = grpc_slice_from_copied_string(request_str); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string(response_str); + + client_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, default_client_channel_compression_algorithm); + server_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, default_server_channel_compression_algorithm); + if (!decompress_in_core) { + grpc_arg disable_decompression_in_core_arg = + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION), 0); + grpc_channel_args* old_client_args = client_args; + grpc_channel_args* old_server_args = server_args; + client_args = grpc_channel_args_copy_and_add( + client_args, &disable_decompression_in_core_arg, 1); + server_args = grpc_channel_args_copy_and_add( + server_args, &disable_decompression_in_core_arg, 1); + 
grpc_channel_args_destroy(old_client_args); + grpc_channel_args_destroy(old_server_args); + } + f = begin_test(config, test_name, client_args, server_args, + decompress_in_core); + cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + if (send_message_before_initial_metadata) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = client_send_flags_bitmask; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + } + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + if (client_init_metadata != nullptr) { + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = client_init_metadata; + } else { + op->data.send_initial_metadata.count = 0; + } + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(100)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(100), true); + cq_verify(cqv); + + GPR_ASSERT(grpc_core::BitCount( + grpc_call_test_only_get_encodings_accepted_by_peer(s)) == + GRPC_COMPRESS_ALGORITHMS_COUNT); + GPR_ASSERT( + grpc_core::GetBit(grpc_call_test_only_get_encodings_accepted_by_peer(s), + GRPC_COMPRESS_NONE) != 0); + GPR_ASSERT( + grpc_core::GetBit(grpc_call_test_only_get_encodings_accepted_by_peer(s), + GRPC_COMPRESS_DEFLATE) != 0); + GPR_ASSERT( + grpc_core::GetBit(grpc_call_test_only_get_encodings_accepted_by_peer(s), + GRPC_COMPRESS_GZIP) != 0); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + if (set_server_level) { + op->data.send_initial_metadata.maybe_compression_level.is_set = true; + op->data.send_initial_metadata.maybe_compression_level.level = + server_compression_level; + } + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + for (int i = 0; i < 2; i++) { + response_payload = grpc_raw_byte_buffer_create(&response_payload_slice, 1); + + if (i > 0 || 
!send_message_before_initial_metadata) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = client_send_flags_bitmask; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), + tag(2), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + } + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + GPR_ASSERT(request_payload_recv->type == GRPC_BB_RAW); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, request_str)); + GPR_ASSERT(request_payload_recv->data.raw.compression == + (decompress_in_core ? GRPC_COMPRESS_NONE + : expected_algorithm_from_client)); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv); + + GPR_ASSERT(response_payload_recv->type == GRPC_BB_RAW); + GPR_ASSERT(byte_buffer_eq_string(response_payload_recv, response_str)); + if (server_compression_level > GRPC_COMPRESS_LEVEL_NONE) { + const grpc_compression_algorithm algo_for_server_level = + grpc_call_compression_for_level(s, server_compression_level); + GPR_ASSERT( + response_payload_recv->data.raw.compression == + (decompress_in_core ? GRPC_COMPRESS_NONE : algo_for_server_level)); + } else { + GPR_ASSERT(response_payload_recv->data.raw.compression == + (decompress_in_core ? 
GRPC_COMPRESS_NONE + : expected_algorithm_from_server)); + } + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + } + grpc_slice_unref(request_payload_slice); + grpc_slice_unref(response_payload_slice); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + CQ_EXPECT_COMPLETION(cqv, tag(4), 1); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + grpc_channel_args_destroy(client_args); + grpc_channel_args_destroy(server_args); + end_test(&f); + config.tear_down_data(&f); +} + +static void request_with_payload_template( + grpc_end2end_test_config config, const char* test_name, + uint32_t client_send_flags_bitmask, + grpc_compression_algorithm default_client_channel_compression_algorithm, + grpc_compression_algorithm default_server_channel_compression_algorithm, + grpc_compression_algorithm expected_algorithm_from_client, + grpc_compression_algorithm expected_algorithm_from_server, + grpc_metadata* client_init_metadata, bool set_server_level, + grpc_compression_level server_compression_level, + bool send_message_before_initial_metadata) { + request_with_payload_template_inner( + config, test_name, client_send_flags_bitmask, + default_client_channel_compression_algorithm, + default_server_channel_compression_algorithm, + expected_algorithm_from_client, expected_algorithm_from_server, + client_init_metadata, set_server_level, server_compression_level, + send_message_before_initial_metadata, false); + request_with_payload_template_inner( + config, test_name, client_send_flags_bitmask, + default_client_channel_compression_algorithm, + default_server_channel_compression_algorithm, + expected_algorithm_from_client, expected_algorithm_from_server, + client_init_metadata, set_server_level, server_compression_level, + send_message_before_initial_metadata, true); +} + +static void test_invoke_request_with_exceptionally_uncompressed_payload( + grpc_end2end_test_config config) { + request_with_payload_template( + config, "test_invoke_request_with_exceptionally_uncompressed_payload", + GRPC_WRITE_NO_COMPRESS, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, + 
GRPC_COMPRESS_NONE, GRPC_COMPRESS_GZIP, nullptr, false, + /* ignored */ GRPC_COMPRESS_LEVEL_NONE, false); +} + +static void test_invoke_request_with_uncompressed_payload( + grpc_end2end_test_config config) { + request_with_payload_template( + config, "test_invoke_request_with_uncompressed_payload", 0, + GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, + GRPC_COMPRESS_NONE, nullptr, false, + /* ignored */ GRPC_COMPRESS_LEVEL_NONE, false); +} + +static void test_invoke_request_with_compressed_payload( + grpc_end2end_test_config config) { + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload", 0, + GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, + GRPC_COMPRESS_GZIP, nullptr, false, + /* ignored */ GRPC_COMPRESS_LEVEL_NONE, false); +} + +static void test_invoke_request_with_send_message_before_initial_metadata( + grpc_end2end_test_config config) { + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload", 0, + GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, + GRPC_COMPRESS_GZIP, nullptr, false, + /* ignored */ GRPC_COMPRESS_LEVEL_NONE, true); +} + +static void test_invoke_request_with_server_level( + grpc_end2end_test_config config) { + request_with_payload_template( + config, "test_invoke_request_with_server_level", 0, GRPC_COMPRESS_NONE, + GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE /* ignored */, + nullptr, true, GRPC_COMPRESS_LEVEL_HIGH, false); +} + +static void test_invoke_request_with_compressed_payload_md_override( + grpc_end2end_test_config config) { + grpc_metadata gzip_compression_override; + grpc_metadata identity_compression_override; + + gzip_compression_override.key = GRPC_MDSTR_GRPC_INTERNAL_ENCODING_REQUEST; + gzip_compression_override.value = grpc_slice_from_static_string("gzip"); + memset(&gzip_compression_override.internal_data, 0, + sizeof(gzip_compression_override.internal_data)); + + identity_compression_override.key = GRPC_MDSTR_GRPC_INTERNAL_ENCODING_REQUEST; + identity_compression_override.value = + grpc_slice_from_static_string("identity"); + memset(&identity_compression_override.internal_data, 0, + sizeof(identity_compression_override.internal_data)); + + /* Channel default NONE (aka IDENTITY), call override to GZIP */ + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload_md_override_1", 0, + GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, GRPC_COMPRESS_GZIP, + GRPC_COMPRESS_NONE, &gzip_compression_override, false, + /*ignored*/ GRPC_COMPRESS_LEVEL_NONE, false); + + /* Channel default DEFLATE, call override to GZIP */ + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload_md_override_2", 0, + GRPC_COMPRESS_DEFLATE, GRPC_COMPRESS_NONE, GRPC_COMPRESS_GZIP, + GRPC_COMPRESS_NONE, &gzip_compression_override, false, + /*ignored*/ GRPC_COMPRESS_LEVEL_NONE, false); + + /* Channel default DEFLATE, call override to NONE (aka IDENTITY) */ + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload_md_override_3", 0, + GRPC_COMPRESS_DEFLATE, GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, + GRPC_COMPRESS_NONE, &identity_compression_override, false, + /*ignored*/ GRPC_COMPRESS_LEVEL_NONE, false); +} + +static void test_invoke_request_with_disabled_algorithm( + grpc_end2end_test_config config) { + request_for_disabled_algorithm(config, + "test_invoke_request_with_disabled_algorithm", + 0, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, + GRPC_STATUS_UNIMPLEMENTED, nullptr, false); + 
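+  // Repeat the scenario with in-core decompression enabled so both code
+  // paths are exercised.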
request_for_disabled_algorithm(config, + "test_invoke_request_with_disabled_algorithm", + 0, GRPC_COMPRESS_GZIP, GRPC_COMPRESS_GZIP, + GRPC_STATUS_UNIMPLEMENTED, nullptr, true); +} + +void compressed_payload(grpc_end2end_test_config config) { + test_invoke_request_with_exceptionally_uncompressed_payload(config); + test_invoke_request_with_uncompressed_payload(config); + test_invoke_request_with_compressed_payload(config); + test_invoke_request_with_send_message_before_initial_metadata(config); + test_invoke_request_with_server_level(config); + test_invoke_request_with_compressed_payload_md_override(config); + test_invoke_request_with_disabled_algorithm(config); +} + +void compressed_payload_pre_init(void) {} diff --git a/test/core/end2end/tests/connectivity.cc b/test/core/end2end/tests/connectivity.cc new file mode 100644 index 00000000..326b7bf8 --- /dev/null +++ b/test/core/end2end/tests/connectivity.cc @@ -0,0 +1,246 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +typedef struct { + gpr_event started; + grpc_channel* channel; + grpc_completion_queue* cq; +} child_events; + +struct CallbackContext { + grpc_completion_queue_functor functor; + gpr_event finished; + explicit CallbackContext(void (*cb)(grpc_completion_queue_functor* functor, + int success)) { + functor.functor_run = cb; + functor.inlineable = false; + gpr_event_init(&finished); + } +}; + +static void child_thread(void* arg) { + child_events* ce = static_cast(arg); + grpc_event ev; + gpr_event_set(&ce->started, reinterpret_cast(1)); + gpr_log(GPR_DEBUG, "verifying"); + ev = grpc_completion_queue_next(ce->cq, gpr_inf_future(GPR_CLOCK_MONOTONIC), + nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == tag(1)); + GPR_ASSERT(ev.success == 0); +} + +static void test_connectivity(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + grpc_connectivity_state state; + cq_verifier* cqv = cq_verifier_create(f.cq); + child_events ce; + + grpc_channel_args client_args; + grpc_arg arg_array[1]; + arg_array[0].type = GRPC_ARG_INTEGER; + arg_array[0].key = + const_cast("grpc.testing.fixed_reconnect_backoff_ms"); + arg_array[0].value.integer = 1000; + client_args.args = arg_array; + client_args.num_args = 1; + + config.init_client(&f, &client_args); + + ce.channel = f.client; + ce.cq = f.cq; + gpr_event_init(&ce.started); + grpc_core::Thread thd("grpc_connectivity", child_thread, &ce); + thd.Start(); + + gpr_event_wait(&ce.started, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + + /* channels should start life in IDLE, and stay there */ + GPR_ASSERT(grpc_channel_check_connectivity_state(f.client, 0) == + GRPC_CHANNEL_IDLE); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); + 
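+  /* try_to_connect is 0 here, so polling the state must not nudge the
+     channel out of IDLE even after the wait */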
GPR_ASSERT(grpc_channel_check_connectivity_state(f.client, 0) == + GRPC_CHANNEL_IDLE); + + /* start watching for a change */ + gpr_log(GPR_DEBUG, "watching"); + grpc_channel_watch_connectivity_state( + f.client, GRPC_CHANNEL_IDLE, gpr_now(GPR_CLOCK_MONOTONIC), f.cq, tag(1)); + + /* eventually the child thread completion should trigger */ + thd.Join(); + + /* check that we're still in idle, and start connecting */ + GPR_ASSERT(grpc_channel_check_connectivity_state(f.client, 1) == + GRPC_CHANNEL_IDLE); + /* start watching for a change */ + grpc_channel_watch_connectivity_state(f.client, GRPC_CHANNEL_IDLE, + grpc_timeout_seconds_to_deadline(3), + f.cq, tag(2)); + + /* and now the watch should trigger */ + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + state = grpc_channel_check_connectivity_state(f.client, 0); + GPR_ASSERT(state == GRPC_CHANNEL_TRANSIENT_FAILURE || + state == GRPC_CHANNEL_CONNECTING); + + /* quickly followed by a transition to TRANSIENT_FAILURE */ + grpc_channel_watch_connectivity_state(f.client, GRPC_CHANNEL_CONNECTING, + grpc_timeout_seconds_to_deadline(3), + f.cq, tag(3)); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv); + state = grpc_channel_check_connectivity_state(f.client, 0); + GPR_ASSERT(state == GRPC_CHANNEL_TRANSIENT_FAILURE || + state == GRPC_CHANNEL_CONNECTING); + + gpr_log(GPR_DEBUG, "*** STARTING SERVER ***"); + + /* now let's bring up a server to connect to */ + config.init_server(&f, nullptr); + + gpr_log(GPR_DEBUG, "*** STARTED SERVER ***"); + + /* we'll go through some set of transitions (some might be missed), until + READY is reached */ + while (state != GRPC_CHANNEL_READY) { + grpc_channel_watch_connectivity_state( + f.client, state, grpc_timeout_seconds_to_deadline(3), f.cq, tag(4)); + CQ_EXPECT_COMPLETION(cqv, tag(4), 1); + cq_verify(cqv); + state = grpc_channel_check_connectivity_state(f.client, 0); + GPR_ASSERT(state == GRPC_CHANNEL_READY || + state == GRPC_CHANNEL_CONNECTING || + state == GRPC_CHANNEL_TRANSIENT_FAILURE); + } + + /* bring down the server again */ + /* we should go immediately to TRANSIENT_FAILURE */ + gpr_log(GPR_DEBUG, "*** SHUTTING DOWN SERVER ***"); + + grpc_channel_watch_connectivity_state(f.client, GRPC_CHANNEL_READY, + grpc_timeout_seconds_to_deadline(3), + f.cq, tag(5)); + + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + + CQ_EXPECT_COMPLETION(cqv, tag(5), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), 1); + cq_verify(cqv); + state = grpc_channel_check_connectivity_state(f.client, 0); + GPR_ASSERT(state == GRPC_CHANNEL_TRANSIENT_FAILURE || + state == GRPC_CHANNEL_CONNECTING || state == GRPC_CHANNEL_IDLE); + + /* cleanup server */ + grpc_server_destroy(f.server); + + gpr_log(GPR_DEBUG, "*** SHUTDOWN SERVER ***"); + + grpc_channel_destroy(f.client); + grpc_completion_queue_shutdown(f.cq); + grpc_completion_queue_destroy(f.cq); + + /* shutdown_cq is not used in this test */ + grpc_completion_queue_destroy(f.shutdown_cq); + config.tear_down_data(&f); + + cq_verifier_destroy(cqv); +} + +static void cb_watch_connectivity(grpc_completion_queue_functor* functor, + int success) { + CallbackContext* cb_ctx = reinterpret_cast(functor); + + gpr_log(GPR_DEBUG, "cb_watch_connectivity called, verifying"); + + /* callback must not have errors */ + GPR_ASSERT(success != 0); + + gpr_event_set(&cb_ctx->finished, reinterpret_cast(1)); +} + +static void cb_shutdown(grpc_completion_queue_functor* functor, + int /*success*/) { + CallbackContext* cb_ctx = reinterpret_cast(functor); + + gpr_log(GPR_DEBUG, 
"cb_shutdown called, nothing to do"); + gpr_event_set(&cb_ctx->finished, reinterpret_cast(1)); +} + +static void test_watch_connectivity_cq_callback( + grpc_end2end_test_config config) { + CallbackContext cb_ctx(cb_watch_connectivity); + CallbackContext cb_shutdown_ctx(cb_shutdown); + grpc_completion_queue* cq; + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + + config.init_client(&f, nullptr); + + /* start connecting */ + grpc_channel_check_connectivity_state(f.client, 1); + + /* create the cq callback */ + cq = grpc_completion_queue_create_for_callback(&cb_shutdown_ctx.functor, + nullptr); + + /* start watching for any change, cb is immediately called + * and no dead lock should be raised */ + grpc_channel_watch_connectivity_state(f.client, GRPC_CHANNEL_IDLE, + grpc_timeout_seconds_to_deadline(3), cq, + &cb_ctx.functor); + + /* we just check that the callback was executed once notifying a connection + * transition */ + GPR_ASSERT(gpr_event_wait(&cb_ctx.finished, + gpr_inf_future(GPR_CLOCK_MONOTONIC)) != nullptr); + + /* shutdown, since shutdown cb might be executed in a background thread + * we actively wait till is executed. */ + grpc_completion_queue_shutdown(cq); + gpr_event_wait(&cb_shutdown_ctx.finished, + gpr_inf_future(GPR_CLOCK_MONOTONIC)); + + /* cleanup */ + grpc_channel_destroy(f.client); + grpc_completion_queue_destroy(cq); + + /* shutdown_cq and cq are not used in this test */ + grpc_completion_queue_destroy(f.cq); + grpc_completion_queue_destroy(f.shutdown_cq); + + config.tear_down_data(&f); +} + +void connectivity(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION); + test_connectivity(config); + test_watch_connectivity_cq_callback(config); +} + +void connectivity_pre_init(void) {} diff --git a/test/core/end2end/tests/default_host.cc b/test/core/end2end/tests/default_host.cc new file mode 100644 index 00000000..efd4fea9 --- /dev/null +++ b/test/core/end2end/tests/default_host.cc @@ -0,0 +1,225 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_client(&f, client_args); + config.init_server(&f, server_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = 
&trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(error == GRPC_CALL_OK); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(error == GRPC_CALL_OK); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(error == GRPC_CALL_OK); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + + if (config.overridden_call_host != nullptr) { + validate_host_override_string(config.overridden_call_host, + call_details.host, config); + } else { + GPR_ASSERT(grpc_slice_buf_start_eq(call_details.host, "localhost", 9) || + grpc_slice_buf_start_eq(call_details.host, "127.0.0.1", 9)); + } + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void default_host(grpc_end2end_test_config config) { + test_invoke_simple_request(config); +} + +void default_host_pre_init(void) {} diff --git a/test/core/end2end/tests/disappearing_server.cc b/test/core/end2end/tests/disappearing_server.cc new file mode 100644 index 00000000..6f4493cc --- /dev/null +++ b/test/core/end2end/tests/disappearing_server.cc @@ -0,0 +1,218 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + + /* Note: shutdown_cq was unused in this test */ + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void do_request_and_shutdown_server(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture* f, + cq_verifier* cqv) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f->client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f->cq, grpc_slice_from_static_string("/foo"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f->server, &s, &call_details, + &request_metadata_recv, f->cq, f->cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + /* should be able to shut down the server early + - and still complete the request */ + grpc_server_shutdown_and_notify(f->server, f->cq, tag(1000)); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + 
op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1000), 1); + cq_verify(cqv); + /* Please refer https://github.com/grpc/grpc/issues/21221 for additional + * details. + * TODO(yashykt@) - The following line should be removeable after C-Core + * correctly handles GOAWAY frames. Internal Reference b/135458602. If this + * test remains flaky even after this, an alternative fix would be to send a + * request when the server is in the shut down state. + */ + cq_verify_empty(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); +} + +static void disappearing_server_test(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_log(GPR_INFO, "Running test: %s/%s", "disappearing_server_test", + config.name); + + config.init_client(&f, nullptr); + config.init_server(&f, nullptr); + + do_request_and_shutdown_server(config, &f, cqv); + + /* now destroy and recreate the server */ + config.init_server(&f, nullptr); + + do_request_and_shutdown_server(config, &f, cqv); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void disappearing_server(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION); +#ifndef GPR_WINDOWS /* b/148110727 for more details */ + disappearing_server_test(config); +#endif /* GPR_WINDOWS */ +} + +void disappearing_server_pre_init(void) {} diff --git a/test/core/end2end/tests/empty_batch.cc b/test/core/end2end/tests/empty_batch.cc new file mode 100644 index 00000000..3104c5d0 --- /dev/null +++ b/test/core/end2end/tests/empty_batch.cc @@ -0,0 +1,124 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void empty_batch_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_call_error error; + grpc_op* op = nullptr; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + error = grpc_call_start_batch(c, op, 0, tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_empty_body(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_empty_body", nullptr, nullptr); + empty_batch_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +void empty_batch(grpc_end2end_test_config config) { + test_invoke_empty_body(config); +} + +void empty_batch_pre_init(void) {} diff --git a/test/core/end2end/tests/filter_causes_close.cc b/test/core/end2end/tests/filter_causes_close.cc new file mode 100644 index 00000000..8944c1a5 --- /dev/null +++ b/test/core/end2end/tests/filter_causes_close.cc @@ -0,0 +1,266 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/channel_init.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Simple request via a server filter that always closes the stream.*/ +static void test_request(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "filter_causes_close", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + 
grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_PERMISSION_DENIED); + GPR_ASSERT(0 == + grpc_slice_str_cmp(details, "Failure that's not preventable.")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +/******************************************************************************* + * Test filter - always closes incoming requests + */ + +typedef struct { + grpc_closure* recv_im_ready; +} call_data; + +typedef struct { + uint8_t unused; +} channel_data; + +static void recv_im_ready(void* arg, grpc_error_handle error) { + grpc_call_element* elem = static_cast(arg); + call_data* calld = static_cast(elem->call_data); + grpc_core::Closure::Run( + DEBUG_LOCATION, calld->recv_im_ready, + grpc_error_set_int(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Failure that's not preventable.", &error, 1), + GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_PERMISSION_DENIED)); +} + +static void start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + call_data* calld = static_cast(elem->call_data); + if (op->recv_initial_metadata) { + calld->recv_im_ready = + op->payload->recv_initial_metadata.recv_initial_metadata_ready; + op->payload->recv_initial_metadata.recv_initial_metadata_ready = + GRPC_CLOSURE_CREATE(recv_im_ready, elem, grpc_schedule_on_exec_ctx); + } + grpc_call_next_op(elem, op); +} + +static grpc_error_handle init_call_elem( + grpc_call_element* /*elem*/, const grpc_call_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void destroy_call_elem(grpc_call_element* /*elem*/, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) {} + +static grpc_error_handle init_channel_elem( + grpc_channel_element* /*elem*/, grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void 
destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +static const grpc_channel_filter test_filter = { + start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + destroy_call_elem, + sizeof(channel_data), + init_channel_elem, + destroy_channel_elem, + grpc_channel_next_get_info, + "filter_causes_close"}; + +/******************************************************************************* + * Registration + */ + +void filter_causes_close(grpc_end2end_test_config config) { + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + builder->channel_init()->RegisterStage( + GRPC_SERVER_CHANNEL, 0, [](grpc_channel_stack_builder* builder) { + return grpc_channel_stack_builder_prepend_filter( + builder, &test_filter, nullptr, nullptr); + }); + }, + [config] { test_request(config); }); +} + +void filter_causes_close_pre_init(void) {} diff --git a/test/core/end2end/tests/filter_context.cc b/test/core/end2end/tests/filter_context.cc new file mode 100644 index 00000000..a071a2ff --- /dev/null +++ b/test/core/end2end/tests/filter_context.cc @@ -0,0 +1,303 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/channel_init.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +enum { TIMEOUT = 200000 }; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Simple request to test that filters see a consistent view of the +// call context. 
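+// The "consistent view" is enforced by the test filter defined later in this
+// file; in essence (mirroring that filter's code):
+//
+//   calld->context = args->context;            // saved in init_call_elem()
+//   if (batch->payload->context != nullptr) {
+//     GPR_ASSERT(calld->context == batch->payload->context);
+//   }                                          // checked on every op batch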
+static void test_request(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "filter_context", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_string = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_string; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + + grpc_slice_unref(details); + 
grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(s); + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +/******************************************************************************* + * Test context filter + */ + +struct call_data { + grpc_call_context_element* context; +}; + +static grpc_error_handle init_call_elem(grpc_call_element* elem, + const grpc_call_element_args* args) { + call_data* calld = static_cast(elem->call_data); + calld->context = args->context; + gpr_log(GPR_INFO, "init_call_elem(): context=%p", args->context); + return GRPC_ERROR_NONE; +} + +static void start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + call_data* calld = static_cast(elem->call_data); + // If batch payload context is not null (which will happen in some + // cancellation cases), make sure we get the same context here that we + // saw in init_call_elem(). + gpr_log(GPR_INFO, "start_transport_stream_op_batch(): context=%p", + batch->payload->context); + if (batch->payload->context != nullptr) { + GPR_ASSERT(calld->context == batch->payload->context); + } + grpc_call_next_op(elem, batch); +} + +static void destroy_call_elem(grpc_call_element* /*elem*/, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) {} + +static grpc_error_handle init_channel_elem( + grpc_channel_element* /*elem*/, grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +static const grpc_channel_filter test_filter = { + start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(call_data), + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + destroy_call_elem, + 0, + init_channel_elem, + destroy_channel_elem, + grpc_channel_next_get_info, + "filter_context"}; + +/******************************************************************************* + * Registration + */ + +void filter_context(grpc_end2end_test_config config) { + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + for (auto type : {GRPC_CLIENT_CHANNEL, GRPC_CLIENT_SUBCHANNEL, + GRPC_CLIENT_DIRECT_CHANNEL, GRPC_SERVER_CHANNEL}) { + builder->channel_init()->RegisterStage( + type, INT_MAX, [](grpc_channel_stack_builder* builder) { + // Want to add the filter as close to the end as possible, to + // make sure that all of the filters work well together. + // However, we can't add it at the very end, because the + // connected channel filter must be the last one. So we add it + // right before the last one. 
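+                // Concretely: create an iterator at the last filter, step
+                // back one slot with move_prev(), and add the test filter
+                // before that position, as the code below does.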
+ grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_last(builder); + GPR_ASSERT(grpc_channel_stack_builder_move_prev(it)); + const bool retval = + grpc_channel_stack_builder_add_filter_before( + it, &test_filter, nullptr, nullptr); + grpc_channel_stack_builder_iterator_destroy(it); + return retval; + }); + } + }, + [config] { test_request(config); }); +} + +void filter_context_pre_init(void) {} diff --git a/test/core/end2end/tests/filter_init_fails.cc b/test/core/end2end/tests/filter_init_fails.cc new file mode 100644 index 00000000..d252c2d2 --- /dev/null +++ b/test/core/end2end/tests/filter_init_fails.cc @@ -0,0 +1,519 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/channel_init.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +enum { TIMEOUT = 200000 }; + +static bool g_enable_server_channel_filter = false; +static bool g_enable_client_channel_filter = false; +static bool g_enable_client_subchannel_filter = false; +static bool g_channel_filter_init_failure = false; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Simple request via a SERVER_CHANNEL 
filter that always fails to +// initialize the call. +static void test_server_channel_filter(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "filter_init_fails", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + if (g_channel_filter_init_failure) { + // Inproc channel returns invalid_argument and other clients return + // unavailable. + // Windows with sockpair returns unknown. + GPR_ASSERT(status == GRPC_STATUS_UNKNOWN || + status == GRPC_STATUS_UNAVAILABLE || + status == GRPC_STATUS_INVALID_ARGUMENT); + } else { + GPR_ASSERT(status == GRPC_STATUS_PERMISSION_DENIED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "access denied")); + } + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +// Simple request via a CLIENT_CHANNEL or CLIENT_DIRECT_CHANNEL filter +// that always fails to initialize the call. 
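+// Unlike the SERVER_CHANNEL variant above, this test never calls
+// grpc_server_request_call(): the batch is expected to fail inside the client
+// channel stack. With g_channel_filter_init_failure set, the expected status
+// is INVALID_ARGUMENT (channel-level init failure); otherwise it is
+// PERMISSION_DENIED with details "access denied" from the test filter's
+// init_call_elem().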
+static void test_client_channel_filter(grpc_end2end_test_config config) { + grpc_call* c; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + gpr_timespec deadline = five_seconds_from_now(); + grpc_end2end_test_fixture f = + begin_test(config, "filter_init_fails", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + if (g_channel_filter_init_failure) { + GPR_ASSERT(status == GRPC_STATUS_INVALID_ARGUMENT); + } else { + GPR_ASSERT(status == GRPC_STATUS_PERMISSION_DENIED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "access denied")); + } + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +// Simple request via a CLIENT_SUBCHANNEL filter that always fails to +// initialize the call. 
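+// This variant issues the same client batch twice (tags 1 and 2); see the
+// comment at the reset point below for why both the first and a subsequent
+// call on the channel need to be exercised. Both attempts are expected to
+// fail with UNAVAILABLE when channel init fails, or with PERMISSION_DENIED /
+// "access denied" when only call init fails.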
+static void test_client_subchannel_filter(grpc_end2end_test_config config) { + grpc_call* c; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + gpr_timespec deadline = five_seconds_from_now(); + grpc_end2end_test_fixture f = + begin_test(config, "filter_init_fails", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + if (g_channel_filter_init_failure) { + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + } else { + GPR_ASSERT(status == GRPC_STATUS_PERMISSION_DENIED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "access denied")); + } + + // Reset and create a new call. (The first call uses a different code + // path in client_channel.c than subsequent calls on the same channel, + // and we need to test both.) 
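+  // (The ops array built above is reused verbatim for the second batch; only
+  // the call object, the status code, and the details slice are reset.)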
+ grpc_call_unref(c); + status = GRPC_STATUS_OK; + grpc_slice_unref(details); + details = grpc_empty_slice(); + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + + if (g_channel_filter_init_failure) { + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + } else { + GPR_ASSERT(status == GRPC_STATUS_PERMISSION_DENIED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "access denied")); + } + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +/******************************************************************************* + * Test filter - always fails to initialize a call + */ + +static grpc_error_handle init_call_elem( + grpc_call_element* /*elem*/, const grpc_call_element_args* /*args*/) { + return grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("access denied"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_PERMISSION_DENIED); +} + +static void destroy_call_elem(grpc_call_element* /*elem*/, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) {} + +static grpc_error_handle init_channel_elem( + grpc_channel_element* /*elem*/, grpc_channel_element_args* /*args*/) { + if (g_channel_filter_init_failure) { + return grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test channel filter init error"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_INVALID_ARGUMENT); + } + return GRPC_ERROR_NONE; +} + +static void destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +static const grpc_channel_filter test_filter = { + grpc_call_next_op, + grpc_channel_next_op, + 0, + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + destroy_call_elem, + 0, + init_channel_elem, + destroy_channel_elem, + grpc_channel_next_get_info, + "filter_init_fails"}; + +/******************************************************************************* + * Registration + */ + +static void filter_init_fails_internal(grpc_end2end_test_config config) { + gpr_log(GPR_INFO, "Testing SERVER_CHANNEL filter."); + g_enable_server_channel_filter = true; + test_server_channel_filter(config); + g_enable_server_channel_filter = false; + gpr_log(GPR_INFO, "Testing CLIENT_CHANNEL / CLIENT_DIRECT_CHANNEL filter."); + g_enable_client_channel_filter = true; + test_client_channel_filter(config); + g_enable_client_channel_filter = false; + // If the client handshake completes before the server handshake and the + // client is able to send application data before the server handshake + // completes, then testing the CLIENT_SUBCHANNEL filter will cause the server + // to freeze waiting for the final handshake message from the client. This + // handshake message will never arrive because it would have been sent with + // the first application data message, which failed because of the filter. 
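+  // Hence the guard below: the CLIENT_SUBCHANNEL case only runs when the
+  // fixture supports a client channel and is not marked as unable to let the
+  // client handshake complete first (the two feature-mask checks).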
+ if ((config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL) && + !(config.feature_mask & + FEATURE_MASK_DOES_NOT_SUPPORT_CLIENT_HANDSHAKE_COMPLETE_FIRST)) { + gpr_log(GPR_INFO, "Testing CLIENT_SUBCHANNEL filter."); + g_enable_client_subchannel_filter = true; + test_client_subchannel_filter(config); + g_enable_client_subchannel_filter = false; + } +} + +void filter_init_fails(grpc_end2end_test_config config) { + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + auto register_stage = [builder](grpc_channel_stack_type type, + bool* enable) { + builder->channel_init()->RegisterStage( + type, INT_MAX, [enable](grpc_channel_stack_builder* builder) { + if (!*enable) return true; + // Want to add the filter as close to the end as possible, + // to make sure that all of the filters work well together. + // However, we can't add it at the very end, because either the + // client_channel filter or connected_channel filter must be the + // last one. So we add it right before the last one. + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_last(builder); + GPR_ASSERT(grpc_channel_stack_builder_move_prev(it)); + const bool retval = + grpc_channel_stack_builder_add_filter_before( + it, &test_filter, nullptr, nullptr); + grpc_channel_stack_builder_iterator_destroy(it); + return retval; + }); + }; + register_stage(GRPC_SERVER_CHANNEL, &g_enable_server_channel_filter); + register_stage(GRPC_CLIENT_CHANNEL, &g_enable_client_channel_filter); + register_stage(GRPC_CLIENT_SUBCHANNEL, + &g_enable_client_subchannel_filter); + register_stage(GRPC_CLIENT_DIRECT_CHANNEL, + &g_enable_client_channel_filter); + }, + [config] { filter_init_fails_internal(config); }); +} + +void filter_init_fails_pre_init(void) {} diff --git a/test/core/end2end/tests/filter_latency.cc b/test/core/end2end/tests/filter_latency.cc new file mode 100644 index 00000000..d87db6f4 --- /dev/null +++ b/test/core/end2end/tests/filter_latency.cc @@ -0,0 +1,337 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/channel_init.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +enum { TIMEOUT = 200000 }; + +static gpr_mu g_mu; +static gpr_timespec g_client_latency; +static gpr_timespec g_server_latency; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Simple request via a server filter that saves the reported latency value. 
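+// The client and server filters registered at the bottom of this file copy
+// final_info->stats.latency into g_client_latency / g_server_latency (under
+// g_mu) from their destroy_call_elem callbacks. After tear-down the test only
+// checks that each recorded latency lies between zero and the wall-clock
+// duration of the whole test; it deliberately does not compare client and
+// server latency to each other (see the comment near the assertions).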
+static void test_request(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "filter_latency", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_mu_lock(&g_mu); + g_client_latency = gpr_time_0(GPR_TIMESPAN); + g_server_latency = gpr_time_0(GPR_TIMESPAN); + gpr_mu_unlock(&g_mu); + const gpr_timespec start_time = gpr_now(GPR_CLOCK_REALTIME); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_string = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_string; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 
1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(s); + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); + + const gpr_timespec end_time = gpr_now(GPR_CLOCK_REALTIME); + const gpr_timespec max_latency = gpr_time_sub(end_time, start_time); + + // Perform checks after test tear-down + // Guards against the case that there's outstanding channel-related work on a + // call prior to verification + gpr_mu_lock(&g_mu); + GPR_ASSERT(gpr_time_cmp(max_latency, g_client_latency) >= 0); + GPR_ASSERT(gpr_time_cmp(gpr_time_0(GPR_TIMESPAN), g_client_latency) <= 0); + GPR_ASSERT(gpr_time_cmp(max_latency, g_server_latency) >= 0); + GPR_ASSERT(gpr_time_cmp(gpr_time_0(GPR_TIMESPAN), g_server_latency) <= 0); + // Server latency should always be smaller than client latency, however since + // we only calculate latency at destruction time, and that might mean that we + // need to wait for outstanding channel-related work, this isn't verifiable + // right now (the server MAY hold on to the call for longer than the client). + // GPR_ASSERT(gpr_time_cmp(g_server_latency, g_client_latency) < 0); + gpr_mu_unlock(&g_mu); +} + +/******************************************************************************* + * Test latency filter + */ + +static grpc_error_handle init_call_elem( + grpc_call_element* /*elem*/, const grpc_call_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void client_destroy_call_elem(grpc_call_element* /*elem*/, + const grpc_call_final_info* final_info, + grpc_closure* /*ignored*/) { + gpr_mu_lock(&g_mu); + g_client_latency = final_info->stats.latency; + gpr_mu_unlock(&g_mu); +} + +static void server_destroy_call_elem(grpc_call_element* /*elem*/, + const grpc_call_final_info* final_info, + grpc_closure* /*ignored*/) { + gpr_mu_lock(&g_mu); + g_server_latency = final_info->stats.latency; + gpr_mu_unlock(&g_mu); +} + +static grpc_error_handle init_channel_elem( + grpc_channel_element* /*elem*/, grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +static const grpc_channel_filter test_client_filter = { + grpc_call_next_op, + grpc_channel_next_op, + 0, + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + client_destroy_call_elem, + 0, + init_channel_elem, + destroy_channel_elem, + grpc_channel_next_get_info, + "client_filter_latency"}; + +static const grpc_channel_filter test_server_filter = { + grpc_call_next_op, + grpc_channel_next_op, + 0, + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + server_destroy_call_elem, + 0, + init_channel_elem, + destroy_channel_elem, + grpc_channel_next_get_info, + "server_filter_latency"}; + +/******************************************************************************* + * Registration + */ + +void filter_latency(grpc_end2end_test_config config) { + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + auto register_stage = 
[builder](grpc_channel_stack_type type, + const grpc_channel_filter* filter) { + builder->channel_init()->RegisterStage( + type, INT_MAX, [filter](grpc_channel_stack_builder* builder) { + // Want to add the filter as close to the end as possible, to + // make sure that all of the filters work well together. + // However, we can't add it at the very end, because the + // connected channel filter must be the last one. So we add it + // right before the last one. + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_last(builder); + GPR_ASSERT(grpc_channel_stack_builder_move_prev(it)); + const bool retval = + grpc_channel_stack_builder_add_filter_before( + it, filter, nullptr, nullptr); + grpc_channel_stack_builder_iterator_destroy(it); + return retval; + }); + }; + register_stage(GRPC_CLIENT_CHANNEL, &test_client_filter); + register_stage(GRPC_CLIENT_DIRECT_CHANNEL, &test_client_filter); + register_stage(GRPC_SERVER_CHANNEL, &test_server_filter); + }, + [config] { test_request(config); }); +} + +void filter_latency_pre_init(void) { gpr_mu_init(&g_mu); } diff --git a/test/core/end2end/tests/filter_status_code.cc b/test/core/end2end/tests/filter_status_code.cc new file mode 100644 index 00000000..aac56815 --- /dev/null +++ b/test/core/end2end/tests/filter_status_code.cc @@ -0,0 +1,389 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This test verifies - + * 1) grpc_call_final_info passed to the filters on destroying a call contains + * the proper status. 
+ * 2) If the response has both an HTTP status code and a gRPC status code, then + * we should prefer the gRPC status code as mentioned in + * https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/channel_init.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static gpr_mu g_mu; +static grpc_call_stack* g_client_call_stack; +static grpc_call_stack* g_server_call_stack; +static bool g_client_code_recv; +static bool g_server_code_recv; +static gpr_cv g_client_code_cv; +static gpr_cv g_server_code_cv; +static grpc_status_code g_client_status_code; +static grpc_status_code g_server_status_code; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Simple request via a server filter that saves the reported status code. 
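+// The server-side filter below rewrites the HTTP ":status" header to 404 in
+// send_initial_metadata while the handler still sends
+// GRPC_STATUS_UNIMPLEMENTED as the gRPC status. Both filters record
+// final_info->final_status from their destroy_call_elem callbacks, and the
+// assertions at the end require UNIMPLEMENTED on both sides, i.e. the gRPC
+// status code takes precedence over the HTTP status code (point 2 in the file
+// comment above).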
+static void test_request(grpc_end2end_test_config config) { + g_client_code_recv = false; + g_server_code_recv = false; + + grpc_call* c; + grpc_call* s; + grpc_end2end_test_fixture f = + begin_test(config, "filter_status_code", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_mu_lock(&g_mu); + g_client_call_stack = nullptr; + g_server_call_stack = nullptr; + g_client_status_code = GRPC_STATUS_OK; + g_server_status_code = GRPC_STATUS_OK; + gpr_mu_unlock(&g_mu); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + gpr_mu_lock(&g_mu); + g_client_call_stack = grpc_call_get_call_stack(c); + gpr_mu_unlock(&g_mu); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + gpr_mu_lock(&g_mu); + g_server_call_stack = grpc_call_get_call_stack(s); + gpr_mu_unlock(&g_mu); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_string = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_string; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == 
grpc_slice_str_cmp(details, "xyz")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(s); + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); + + // Perform checks after test tear-down + // Guards against the case that there's outstanding channel-related work on a + // call prior to verification + gpr_mu_lock(&g_mu); + if (!g_client_code_recv) { + GPR_ASSERT(gpr_cv_wait(&g_client_code_cv, &g_mu, + grpc_timeout_seconds_to_deadline(3)) == 0); + } + if (!g_server_code_recv) { + GPR_ASSERT(gpr_cv_wait(&g_server_code_cv, &g_mu, + grpc_timeout_seconds_to_deadline(3)) == 0); + } + GPR_ASSERT(g_client_status_code == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(g_server_status_code == GRPC_STATUS_UNIMPLEMENTED); + gpr_mu_unlock(&g_mu); +} + +/******************************************************************************* + * Test status_code filter + */ + +typedef struct final_status_data { + grpc_call_stack* call; +} final_status_data; + +static void server_start_transport_stream_op_batch( + grpc_call_element* elem, grpc_transport_stream_op_batch* op) { + auto* data = static_cast<final_status_data*>(elem->call_data); + gpr_mu_lock(&g_mu); + if (data->call == g_server_call_stack) { + if (op->send_initial_metadata) { + auto* batch = op->payload->send_initial_metadata.send_initial_metadata; + if (batch->legacy_index()->named.status != nullptr) { + /* Replace the HTTP status with 404 */ + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "Substitute", batch->Substitute(batch->legacy_index()->named.status, + GRPC_MDELEM_STATUS_404))); + } + } + } + gpr_mu_unlock(&g_mu); + grpc_call_next_op(elem, op); +} + +static grpc_error_handle init_call_elem(grpc_call_element* elem, + const grpc_call_element_args* args) { + final_status_data* data = static_cast<final_status_data*>(elem->call_data); + data->call = args->call_stack; + return GRPC_ERROR_NONE; +} + +static void client_destroy_call_elem(grpc_call_element* elem, + const grpc_call_final_info* final_info, + grpc_closure* /*ignored*/) { + final_status_data* data = static_cast<final_status_data*>(elem->call_data); + gpr_mu_lock(&g_mu); + // Some fixtures, like proxies, will spawn intermediate calls + // We only want the results from our explicit calls + if (data->call == g_client_call_stack) { + g_client_status_code = final_info->final_status; + g_client_code_recv = true; + gpr_cv_signal(&g_client_code_cv); + } + gpr_mu_unlock(&g_mu); +} + +static void server_destroy_call_elem(grpc_call_element* elem, + const grpc_call_final_info* final_info, + grpc_closure* /*ignored*/) { + final_status_data* data = static_cast<final_status_data*>(elem->call_data); + gpr_mu_lock(&g_mu); + // Some fixtures, like proxies, will spawn intermediate calls + // We only want the results from our explicit calls + if (data->call == g_server_call_stack) { + g_server_status_code = final_info->final_status; + g_server_code_recv = true; + gpr_cv_signal(&g_server_code_cv); + } + gpr_mu_unlock(&g_mu); +} + +static grpc_error_handle init_channel_elem( + grpc_channel_element* /*elem*/, grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void destroy_channel_elem(grpc_channel_element* /*elem*/) {} + +static const grpc_channel_filter test_client_filter = { + grpc_call_next_op, + grpc_channel_next_op, + sizeof(final_status_data), + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, +
client_destroy_call_elem, + 0, + init_channel_elem, + destroy_channel_elem, + grpc_channel_next_get_info, + "client_filter_status_code"}; + +static const grpc_channel_filter test_server_filter = { + server_start_transport_stream_op_batch, + grpc_channel_next_op, + sizeof(final_status_data), + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + server_destroy_call_elem, + 0, + init_channel_elem, + destroy_channel_elem, + grpc_channel_next_get_info, + "server_filter_status_code"}; + +/******************************************************************************* + * Registration + */ + +void filter_status_code(grpc_end2end_test_config config) { + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + auto register_stage = [builder](grpc_channel_stack_type type, + const grpc_channel_filter* filter) { + builder->channel_init()->RegisterStage( + type, INT_MAX, [filter](grpc_channel_stack_builder* builder) { + // Want to add the filter as close to the end as possible, to + // make sure that all of the filters work well together. + // However, we can't add it at the very end, because the + // connected_channel/client_channel filter must be the last one. + // So we add it right before the last one. + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_last(builder); + GPR_ASSERT(grpc_channel_stack_builder_move_prev(it)); + const bool retval = + grpc_channel_stack_builder_add_filter_before( + it, filter, nullptr, nullptr); + grpc_channel_stack_builder_iterator_destroy(it); + return retval; + }); + }; + register_stage(GRPC_CLIENT_CHANNEL, &test_client_filter); + register_stage(GRPC_CLIENT_DIRECT_CHANNEL, &test_client_filter); + register_stage(GRPC_SERVER_CHANNEL, &test_server_filter); + }, + [config] { test_request(config); }); +} + +void filter_status_code_pre_init(void) { + gpr_mu_init(&g_mu); + gpr_cv_init(&g_client_code_cv); + gpr_cv_init(&g_server_code_cv); +} diff --git a/test/core/end2end/tests/graceful_server_shutdown.cc b/test/core/end2end/tests/graceful_server_shutdown.cc new file mode 100644 index 00000000..8dc56ff0 --- /dev/null +++ b/test/core/end2end/tests/graceful_server_shutdown.cc @@ -0,0 +1,204 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
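The post-teardown verification in filter_status_code relies on a small gpr synchronization pattern: hold the mutex, and if the flag another thread sets is not yet true, block on the condition variable with a deadline (gpr_cv_wait returns 0 when signalled and non-zero if the deadline expires first). A minimal standalone sketch of that pattern follows; the helper name and the flag/lock parameters are illustrative, not taken from this patch.

#include <grpc/support/sync.h>
#include <grpc/support/time.h>

// Wait (with a deadline) for a flag that another thread sets under the same
// mutex and then signals, as the destroy_call_elem hooks above do.
// gpr_cv_wait returns 0 when the cv was signalled, non-zero on timeout.
static bool wait_for_flag(bool* flag, gpr_mu* mu, gpr_cv* cv, int seconds) {
  gpr_timespec deadline = gpr_time_add(
      gpr_now(GPR_CLOCK_MONOTONIC),
      gpr_time_from_seconds(seconds, GPR_TIMESPAN));
  gpr_mu_lock(mu);
  while (!*flag) {
    if (gpr_cv_wait(cv, mu, deadline) != 0) break;  // deadline expired
  }
  bool result = *flag;
  gpr_mu_unlock(mu);
  return result;
}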
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + /* Note: shutdown_cq is not used in this test */ + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_early_server_shutdown_finishes_inflight_calls( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_end2end_test_fixture f = + begin_test(config, "test_early_server_shutdown_finishes_inflight_calls", + nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = n_seconds_from_now(10); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + 
GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + /* shutdown and destroy the server */ + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + cq_verify_empty(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + grpc_call_unref(s); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void graceful_server_shutdown(grpc_end2end_test_config config) { + test_early_server_shutdown_finishes_inflight_calls(config); +} + +void graceful_server_shutdown_pre_init(void) {} diff --git a/test/core/end2end/tests/high_initial_seqno.cc b/test/core/end2end/tests/high_initial_seqno.cc new file mode 100644 index 00000000..9772601e --- /dev/null +++ b/test/core/end2end/tests/high_initial_seqno.cc @@ -0,0 +1,236 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
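Outside the fixture, the same graceful-shutdown ordering applies to any C-core server: request shutdown with a tag, keep draining the completion queue so in-flight calls (like the one above) can still finish, and destroy the server only after that tag completes. A minimal sketch, assuming `server` was registered with `cq` at startup; the helper name is illustrative and not part of this patch.

#include <grpc/grpc.h>
#include <grpc/support/time.h>

// Gracefully stop a server: new calls are refused, in-flight calls may still
// complete, and the sentinel tag is delivered once shutdown has finished.
static void graceful_stop(grpc_server* server, grpc_completion_queue* cq) {
  int sentinel;                      // its address serves as a unique cq tag
  void* shutdown_tag = &sentinel;
  grpc_server_shutdown_and_notify(server, cq, shutdown_tag);
  for (;;) {
    grpc_event ev = grpc_completion_queue_next(
        cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr);
    if (ev.type == GRPC_OP_COMPLETE && ev.tag == shutdown_tag) break;
    // Completions for still-running calls also surface here and would be
    // handed back to their owners in a real application loop.
  }
  grpc_server_destroy(server);
}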
+ * + */ + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + 
error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + /* TODO(ctiller): this rate limits the test, and it should be removed when + retry has been implemented; until then cross-thread chatter + may result in some requests needing to be cancelled due to + seqno exhaustion. */ + cq_verify_empty(cqv); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_10_simple_requests(grpc_end2end_test_config config, + int initial_sequence_number) { + int i; + grpc_end2end_test_fixture f; + grpc_arg client_arg; + grpc_channel_args client_args; + + client_arg.type = GRPC_ARG_INTEGER; + client_arg.key = const_cast(GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER); + client_arg.value.integer = initial_sequence_number; + + client_args.num_args = 1; + client_args.args = &client_arg; + + std::string name = absl::StrCat("test_invoke_requests first_seqno=", + initial_sequence_number); + f = begin_test(config, name.c_str(), &client_args, nullptr); + for (i = 0; i < 10; i++) { + simple_request_body(config, f); + gpr_log(GPR_INFO, "Running test: Passed simple request %d", i); + } + end_test(&f); + config.tear_down_data(&f); +} + +void high_initial_seqno(grpc_end2end_test_config config) { + test_invoke_10_simple_requests(config, 16777213); + if (config.feature_mask & FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION) { + test_invoke_10_simple_requests(config, 2147483645); + } +} + +void high_initial_seqno_pre_init(void) {} diff --git a/test/core/end2end/tests/hpack_size.cc b/test/core/end2end/tests/hpack_size.cc new file mode 100644 index 00000000..ca141022 --- /dev/null +++ b/test/core/end2end/tests/hpack_size.cc @@ -0,0 +1,397 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
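For an application, GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER is just another integer channel argument supplied at channel-creation time; the test only differs in routing it through the fixture. A minimal sketch, assuming the insecure C-core creation entry point of this era (grpc_insecure_channel_create); the target string and helper name are placeholders.

#include <grpc/grpc.h>

// Start the client's HTTP/2 stream ids near the 2^31 - 1 limit, as the
// second high_initial_seqno configuration does, to exercise handling of
// stream-id exhaustion.
static grpc_channel* channel_with_high_seqno(const char* target) {
  grpc_arg arg;
  arg.type = GRPC_ARG_INTEGER;
  arg.key = const_cast<char*>(GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER);
  arg.value.integer = 2147483645;
  grpc_channel_args args = {1, &arg};
  return grpc_insecure_channel_create(target, &args, nullptr);
}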
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +const char* hobbits[][2] = { + {"Adaldrida", "Brandybuck"}, {"Adamanta", "Took"}, + {"Adalgrim", "Took"}, {"Adelard", "Took"}, + {"Amaranth", "Brandybuck"}, {"Andwise", "Roper"}, + {"Angelica", "Baggins"}, {"Asphodel", "Burrows"}, + {"Balbo", "Baggins"}, {"Bandobras", "Took"}, + {"Belba", "Bolger"}, {"Bell", "Gamgee"}, + {"Belladonna", "Baggins"}, {"Berylla", "Baggins"}, + {"Bilbo", "Baggins"}, {"Bilbo", "Gardner"}, + {"Bill", "Butcher"}, {"Bingo", "Baggins"}, + {"Bodo", "Proudfoot"}, {"Bowman", "Cotton"}, + {"Bungo", "Baggins"}, {"Camellia", "Sackville"}, + {"Carl", "Cotton"}, {"Celandine", "Brandybuck"}, + {"Chica", "Baggins"}, {"Daddy", "Twofoot"}, + {"Daisy", "Boffin"}, {"Diamond", "Took"}, + {"Dinodas", "Brandybuck"}, {"Doderic", "Brandybuck"}, + {"Dodinas", "Brandybuck"}, {"Donnamira", "Boffin"}, + {"Dora", "Baggins"}, {"Drogo", "Baggins"}, + {"Dudo", "Baggins"}, {"Eglantine", "Took"}, + {"Elanor", "Fairbairn"}, {"Elfstan", "Fairbairn"}, + {"Esmeralda", "Brandybuck"}, {"Estella", "Brandybuck"}, + {"Everard", "Took"}, {"Falco", "Chubb-Baggins"}, + {"Faramir", "Took"}, {"Farmer", "Maggot"}, + {"Fastolph", "Bolger"}, {"Ferdibrand", "Took"}, + {"Ferdinand", "Took"}, {"Ferumbras", "Took"}, + {"Ferumbras", "Took"}, {"Filibert", "Bolger"}, + {"Firiel", "Fairbairn"}, {"Flambard", "Took"}, + {"Folco", "Boffin"}, {"Fortinbras", "Took"}, + {"Fortinbras", "Took"}, {"Fosco", "Baggins"}, + {"Fredegar", "Bolger"}, {"Frodo", "Baggins"}, + {"Frodo", "Gardner"}, {"Gerontius", "Took"}, + {"Gilly", "Baggins"}, {"Goldilocks", "Took"}, + {"Gorbadoc", "Brandybuck"}, {"Gorbulas", "Brandybuck"}, + {"Gorhendad", "Brandybuck"}, {"Gormadoc", "Brandybuck"}, + {"Griffo", "Boffin"}, {"Halfast", "Gamgee"}, + {"Halfred", "Gamgee"}, {"Halfred", "Greenhand"}, + {"Hanna", "Brandybuck"}, {"Hamfast", "Gamgee"}, + {"Hamfast", "Gardner"}, {"Hamson", "Gamgee"}, + {"Harding", "Gardner"}, {"Hilda", "Brandybuck"}, + {"Hildibrand", "Took"}, {"Hildifons", "Took"}, + {"Hildigard", "Took"}, {"Hildigrim", "Took"}, + {"Hob", "Gammidge"}, {"Hob", "Hayward"}, + {"Hobson", "Gamgee"}, {"Holfast", "Gardner"}, + {"Holman", "Cotton"}, {"Holman", "Greenhand"}, + {"Hugo", "Boffin"}, {"Hugo", "Bracegirdle"}, + {"Ilberic", "Brandybuck"}, {"Isembard", "Took"}, + {"Isembold", "Took"}, {"Isengar", "Took"}, + {"Isengrim", "Took"}, {"Isengrim", "Took"}, + {"Isumbras", "Took"}, {"Isumbras", "Took"}, + {"Jolly", "Cotton"}, + /* + {"Lalia", "Took"}, + {"Largo", "Baggins"}, + {"Laura", "Baggins"}, + {"Lily", "Goodbody"}, + {"Lily", "Cotton"}, + {"Linda", "Proudfoot"}, + {"Lobelia", "Sackville-Baggins"}, + {"Longo", "Baggins"}, + {"Lotho", "Sackville-Baggins"}, + {"Madoc", "Brandybuck"}, + {"Malva", "Brandybuck"}, + {"Marigold", "Cotton"}, + {"Marmadas", "Brandybuck"}, + 
{"Marmadoc", "Brandybuck"}, + {"Marroc", "Brandybuck"}, + {"May", "Gamgee"}, + {"Melilot", "Brandybuck"}, + {"Menegilda", "Brandybuck"}, + {"Mentha", "Brandybuck"}, + {"Meriadoc", "Brandybuck"}, + {"Merimac", "Brandybuck"}, + {"Merimas", "Brandybuck"}, + {"Merry", "Gardner"}, + {"Milo", "Burrows"}, + {"Mimosa", "Baggins"}, + {"Minto", "Burrows"}, + {"Mirabella", "Brandybuck"}, + {"Moro", "Burrows"}, + {"Mosco", "Burrows"}, + {"Mungo", "Baggins"}, + {"Myrtle", "Burrows"}, + {"Odo", "Proudfoot"}, + {"Odovacar", "Bolger"}, + {"Olo", "Proudfoot"}, + {"Orgulas", "Brandybuck"}, + {"Otho", "Sackville-Baggins"}, + {"Paladin", "Took"}, + {"Pansy", "Bolger"}, + {"Pearl", "Took"}, + {"Peony", "Burrows"}, + {"Peregrin", "Took"}, + {"Pervinca", "Took"}, + {"Pimpernel", "Took"}, + {"Pippin", "Gardner"}, + {"Polo", "Baggins"}, + {"Ponto", "Baggins"}, + {"Porto", "Baggins"}, + {"Posco", "Baggins"}, + {"Poppy", "Bolger"}, + {"Primrose", "Gardner"}, + {"Primula", "Baggins"}, + {"Prisca", "Bolger"}, + {"Reginard", "Took"}, + {"Robin", "Smallburrow"}, + {"Robin", "Gardner"}, + {"Rorimac", "Brandybuck"}, + {"Rosa", "Took"}, + {"Rosamunda", "Bolger"}, + {"Rose", "Gardner"}, + {"Ruby", "Baggins"}, + {"Ruby", "Gardner"}, + {"Rudigar", "Bolger"}, + {"Rufus", "Burrows"}, + {"Sadoc", "Brandybuck"}, + {"Salvia", "Bolger"}, + {"Samwise", "Gamgee"}, + {"Sancho", "Proudfoot"}, + {"Saradas", "Brandybuck"}, + {"Saradoc", "Brandybuck"}, + {"Seredic", "Brandybuck"}, + {"Sigismond", "Took"}, + {"Smeagol", "Gollum"}, + {"Tanta", "Baggins"}, + {"Ted", "Sandyman"}, + {"Tobold", "Hornblower"}, + {"Togo", "Goodbody"}, + {"Tolman", "Cotton"}, + {"Tolman", "Gardner"}, + {"Widow", "Rumble"}, + {"Wilcome", "Cotton"}, + {"Wilcome", "Cotton"}, + {"Wilibald", "Bolger"}, + {"Will", "Whitfoot"}, + {"Wiseman", "Gamwich"}*/ +}; + +const char* dragons[] = {"Ancalagon", "Glaurung", "Scatha", + "Smaug the Magnificent"}; + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + 
grpc_end2end_test_fixture f, size_t index) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_metadata extra_metadata[3]; + grpc_slice details; + int was_cancelled = 2; + + memset(extra_metadata, 0, sizeof(extra_metadata)); + extra_metadata[0].key = grpc_slice_from_static_string("hobbit-first-name"); + extra_metadata[0].value = grpc_slice_from_static_string( + hobbits[index % GPR_ARRAY_SIZE(hobbits)][0]); + extra_metadata[1].key = grpc_slice_from_static_string("hobbit-second-name"); + extra_metadata[1].value = grpc_slice_from_static_string( + hobbits[index % GPR_ARRAY_SIZE(hobbits)][1]); + extra_metadata[2].key = grpc_slice_from_static_string("dragon"); + extra_metadata[2].value = + grpc_slice_from_static_string(dragons[index % GPR_ARRAY_SIZE(dragons)]); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = GPR_ARRAY_SIZE(extra_metadata); + op->data.send_initial_metadata.metadata = extra_metadata; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + 
cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_size(grpc_end2end_test_config config, int encode_size, + int decode_size) { + size_t i; + grpc_end2end_test_fixture f; + grpc_arg server_arg; + grpc_channel_args server_args; + grpc_arg client_arg; + grpc_channel_args client_args; + + server_arg.type = GRPC_ARG_INTEGER; + server_arg.key = const_cast(GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER); + server_arg.value.integer = decode_size; + server_args.num_args = 1; + server_args.args = &server_arg; + + client_arg.type = GRPC_ARG_INTEGER; + client_arg.key = const_cast(GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER); + client_arg.value.integer = encode_size; + client_args.num_args = 1; + client_args.args = &client_arg; + + std::string name = + absl::StrFormat("test_size:e=%d:d=%d", encode_size, decode_size); + f = begin_test(config, name.c_str(), + encode_size != 4096 ? &client_args : nullptr, + decode_size != 4096 ? &server_args : nullptr); + for (i = 0; i < 4 * GPR_ARRAY_SIZE(hobbits); i++) { + simple_request_body(config, f, i); + } + end_test(&f); + config.tear_down_data(&f); +} + +void hpack_size(grpc_end2end_test_config config) { + static const int interesting_sizes[] = {4096, 0, 100, + 1000, 32768, 4 * 1024 * 1024}; + size_t i, j; + + for (i = 0; i < GPR_ARRAY_SIZE(interesting_sizes); i++) { + for (j = 0; j < GPR_ARRAY_SIZE(interesting_sizes); j++) { + test_size(config, interesting_sizes[i], interesting_sizes[j]); + } + } +} + +void hpack_size_pre_init(void) {} diff --git a/test/core/end2end/tests/idempotent_request.cc b/test/core/end2end/tests/idempotent_request.cc new file mode 100644 index 00000000..a271c96c --- /dev/null +++ b/test/core/end2end/tests/idempotent_request.cc @@ -0,0 +1,239 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
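A note on the two arguments hpack_size sweeps: GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER caps the dynamic table the local HPACK encoder will use when writing headers, while GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER sizes the local decoder's table, which is what the peer's encoder is told (via HTTP/2 SETTINGS) it may rely on. A small illustrative helper pairing them the way test_size does; the function name is not from this patch.

#include <grpc/grpc.h>

// Fill one client-side and one server-side channel arg so that the client
// encodes against at most encode_bytes of HPACK state and the server
// advertises a decode_bytes dynamic table to its peers.
static void fill_hpack_args(grpc_arg* client_arg, grpc_arg* server_arg,
                            int encode_bytes, int decode_bytes) {
  client_arg->type = GRPC_ARG_INTEGER;
  client_arg->key = const_cast<char*>(GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER);
  client_arg->value.integer = encode_bytes;
  server_arg->type = GRPC_ARG_INTEGER;
  server_arg->key = const_cast<char*>(GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER);
  server_arg->value.integer = decode_bytes;
}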
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + 
op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + simple_request_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_10_simple_requests(grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_10_simple_requests", nullptr, nullptr); + for (i = 0; i < 10; i++) { + simple_request_body(config, f); + gpr_log(GPR_INFO, "Passed simple request %d", i); + } + end_test(&f); + config.tear_down_data(&f); +} + +void idempotent_request(grpc_end2end_test_config config) { + int i; + for (i = 0; i < 10; i++) { + test_invoke_simple_request(config); + } + test_invoke_10_simple_requests(config); +} + +void idempotent_request_pre_init(void) {} diff --git a/test/core/end2end/tests/invoke_large_request.cc b/test/core/end2end/tests/invoke_large_request.cc new file mode 100644 index 00000000..2fff0cc0 --- /dev/null +++ b/test/core/end2end/tests/invoke_large_request.cc @@ -0,0 +1,285 @@ +/* + * + * Copyright 2015 gRPC authors. 
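What distinguishes idempotent_request from a plain unary call is a single bit: GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST on the client's send-initial-metadata op, which the server later reads back from call_details.flags. A minimal sketch of just that op against an existing grpc_call; the helper name is illustrative.

#include <string.h>

#include <grpc/grpc.h>

// Send empty initial metadata but mark the request as idempotent; the
// server observes the flag in the grpc_call_details it receives.
static grpc_call_error send_idempotent_initial_metadata(grpc_call* call,
                                                        void* tag) {
  grpc_op op;
  memset(&op, 0, sizeof(op));
  op.op = GRPC_OP_SEND_INITIAL_METADATA;
  op.data.send_initial_metadata.count = 0;
  op.flags = GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST;
  return grpc_call_start_batch(call, &op, 1, tag, nullptr);
}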
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, n_seconds_from_now(5), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static grpc_slice large_slice(void) { + grpc_slice slice = grpc_slice_malloc(1000000); + memset(GRPC_SLICE_START_PTR(slice), 'x', GRPC_SLICE_LENGTH(slice)); + return slice; +} + +static void test_invoke_large_request(grpc_end2end_test_config config, + int max_frame_size, int lookahead_bytes) { + std::string name = absl::StrFormat( + "test_invoke_large_request:max_frame_size=%d:lookahead_bytes=%d", + max_frame_size, lookahead_bytes); + + grpc_arg args[2]; + args[0].type = GRPC_ARG_INTEGER; + args[0].key = const_cast(GRPC_ARG_HTTP2_MAX_FRAME_SIZE); + args[0].value.integer = max_frame_size; + args[1].type = GRPC_ARG_INTEGER; + args[1].key = const_cast(GRPC_ARG_HTTP2_STREAM_LOOKAHEAD_BYTES); + args[1].value.integer = lookahead_bytes; + grpc_channel_args channel_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = + begin_test(config, name.c_str(), &channel_args, &channel_args); + + grpc_slice request_payload_slice = large_slice(); + grpc_slice response_payload_slice = large_slice(); + grpc_call* c; + grpc_call* s; + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* 
response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = n_seconds_from_now(30); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = 
grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(request_payload_slice); + grpc_slice_unref(response_payload_slice); + + end_test(&f); + config.tear_down_data(&f); +} + +void invoke_large_request(grpc_end2end_test_config config) { + test_invoke_large_request(config, 16384, 65536); + test_invoke_large_request(config, 32768, 65536); + + test_invoke_large_request(config, 1000000 - 1, 65536); + test_invoke_large_request(config, 1000000, 65536); + test_invoke_large_request(config, 1000000 + 1, 65536); + test_invoke_large_request(config, 1000000 + 2, 65536); + test_invoke_large_request(config, 1000000 + 3, 65536); + test_invoke_large_request(config, 1000000 + 4, 65536); + test_invoke_large_request(config, 1000000 + 5, 65536); + test_invoke_large_request(config, 1000000 + 6, 65536); + + test_invoke_large_request(config, 1000000 - 1, 2000000); + test_invoke_large_request(config, 1000000, 2000000); + test_invoke_large_request(config, 1000000 + 1, 2000000); + test_invoke_large_request(config, 1000000 + 2, 2000000); + test_invoke_large_request(config, 1000000 + 3, 2000000); + test_invoke_large_request(config, 1000000 + 4, 2000000); + test_invoke_large_request(config, 1000000 + 5, 2000000); + test_invoke_large_request(config, 1000000 + 6, 2000000); +} + +void invoke_large_request_pre_init(void) {} diff --git a/test/core/end2end/tests/keepalive_timeout.cc b/test/core/end2end/tests/keepalive_timeout.cc new file mode 100644 index 00000000..6db92f2c --- /dev/null +++ b/test/core/end2end/tests/keepalive_timeout.cc @@ -0,0 +1,443 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
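invoke_large_request mostly exercises two transport knobs, GRPC_ARG_HTTP2_MAX_FRAME_SIZE (how much of the roughly 1 MB message fits in one DATA frame) and GRPC_ARG_HTTP2_STREAM_LOOKAHEAD_BYTES (per-stream flow-control lookahead); the payload itself is built with the slice and byte-buffer APIs shown above. A self-contained sketch of the payload construction, with the helper name being illustrative.

#include <stddef.h>
#include <string.h>

#include <grpc/byte_buffer.h>
#include <grpc/slice.h>

// Build a byte buffer of `length` identical bytes, as large_slice() does,
// ready to be attached to a GRPC_OP_SEND_MESSAGE op.
static grpc_byte_buffer* make_large_payload(size_t length) {
  grpc_slice slice = grpc_slice_malloc(length);
  memset(GRPC_SLICE_START_PTR(slice), 'x', GRPC_SLICE_LENGTH(slice));
  grpc_byte_buffer* payload = grpc_raw_byte_buffer_create(&slice, 1);
  grpc_slice_unref(slice);  // the byte buffer holds its own reference
  return payload;
}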
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/frame_ping.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +#ifdef GRPC_POSIX_SOCKET +#include "src/core/lib/iomgr/ev_posix.h" +#endif // GRPC_POSIX_SOCKET + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "%s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + five_seconds_from_now(), nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Client sends a request, server replies with a payload, then waits for the + keepalive watchdog timeouts before returning status. 
*/ +static void test_keepalive_timeout(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + + grpc_arg keepalive_arg_elems[3]; + keepalive_arg_elems[0].type = GRPC_ARG_INTEGER; + keepalive_arg_elems[0].key = const_cast(GRPC_ARG_KEEPALIVE_TIME_MS); + keepalive_arg_elems[0].value.integer = 3500; + keepalive_arg_elems[1].type = GRPC_ARG_INTEGER; + keepalive_arg_elems[1].key = const_cast(GRPC_ARG_KEEPALIVE_TIMEOUT_MS); + keepalive_arg_elems[1].value.integer = 0; + keepalive_arg_elems[2].type = GRPC_ARG_INTEGER; + keepalive_arg_elems[2].key = const_cast(GRPC_ARG_HTTP2_BDP_PROBE); + keepalive_arg_elems[2].value.integer = 0; + grpc_channel_args keepalive_args = {GPR_ARRAY_SIZE(keepalive_arg_elems), + keepalive_arg_elems}; + + grpc_end2end_test_fixture f = + begin_test(config, "keepalive_timeout", &keepalive_args, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + /* Disable ping ack to trigger the keepalive timeout */ + grpc_set_disable_ping_ack(true); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + 
CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv); + + char* details_str = grpc_slice_to_c_string(details); + char* method_str = grpc_slice_to_c_string(call_details.method); + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "keepalive watchdog timeout")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + + gpr_free(details_str); + gpr_free(method_str); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +/* Verify that reads reset the keepalive ping timer. The client sends 30 pings + * with a sleep of 10ms in between. It has a configured keepalive timer of + * 200ms. In the success case, each ping ack should reset the keepalive timer so + * that the keepalive ping is never sent. */ +static void test_read_delays_keepalive(grpc_end2end_test_config config) { +#ifdef GRPC_POSIX_SOCKET + grpc_core::UniquePtr poller = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy); + /* It is hard to get the timing right for the polling engine poll. */ + if ((0 == strcmp(poller.get(), "poll"))) { + return; + } +#endif // GRPC_POSIX_SOCKET + const int kPingIntervalMS = 100; + grpc_arg keepalive_arg_elems[3]; + keepalive_arg_elems[0].type = GRPC_ARG_INTEGER; + keepalive_arg_elems[0].key = const_cast(GRPC_ARG_KEEPALIVE_TIME_MS); + keepalive_arg_elems[0].value.integer = 20 * kPingIntervalMS; + keepalive_arg_elems[1].type = GRPC_ARG_INTEGER; + keepalive_arg_elems[1].key = const_cast(GRPC_ARG_KEEPALIVE_TIMEOUT_MS); + keepalive_arg_elems[1].value.integer = 0; + keepalive_arg_elems[2].type = GRPC_ARG_INTEGER; + keepalive_arg_elems[2].key = const_cast(GRPC_ARG_HTTP2_BDP_PROBE); + keepalive_arg_elems[2].value.integer = 0; + grpc_channel_args keepalive_args = {GPR_ARRAY_SIZE(keepalive_arg_elems), + keepalive_arg_elems}; + grpc_end2end_test_fixture f = begin_test(config, "test_read_delays_keepalive", + &keepalive_args, nullptr); + /* Disable ping ack to trigger the keepalive timeout */ + grpc_set_disable_ping_ack(true); + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + grpc_byte_buffer* request_payload; + grpc_byte_buffer* request_payload_recv; + grpc_byte_buffer* response_payload; + grpc_byte_buffer* response_payload_recv; + int i; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + 
memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(100)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(100), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + for (i = 0; i < 30; i++) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + response_payload = grpc_raw_byte_buffer_create(&response_payload_slice, 1); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + /* Sleep for a short interval to check if the client sends any pings */ + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(kPingIntervalMS)); + } + + grpc_slice_unref(request_payload_slice); + grpc_slice_unref(response_payload_slice); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 
0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + cq_verify(cqv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_slice_unref(details); + + end_test(&f); + config.tear_down_data(&f); +} + +void keepalive_timeout(grpc_end2end_test_config config) { + test_keepalive_timeout(config); + test_read_delays_keepalive(config); +} + +void keepalive_timeout_pre_init(void) {} diff --git a/test/core/end2end/tests/large_metadata.cc b/test/core/end2end/tests/large_metadata.cc new file mode 100644 index 00000000..19ac61a2 --- /dev/null +++ b/test/core/end2end/tests/large_metadata.cc @@ -0,0 +1,383 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Request with a large amount of metadata. +static void test_request_with_large_metadata(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_metadata meta; + const size_t large_size = 64 * 1024; + grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + arg.key = const_cast(GRPC_ARG_MAX_METADATA_SIZE); + arg.value.integer = static_cast(large_size) + 1024; + grpc_channel_args args = {1, &arg}; + grpc_end2end_test_fixture f = + begin_test(config, "test_request_with_large_metadata", &args, &args); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + meta.key = grpc_slice_from_static_string("key"); + meta.value = grpc_slice_malloc(large_size); + memset(GRPC_SLICE_START_PTR(meta.value), 'a', large_size); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + // Client: send request. 
+ op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = &meta; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + // Server: send initial metadata and receive request. + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + // Server: receive close and send status. This should trigger + // completion of request on client. 
+ op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + GPR_ASSERT(contains_metadata_slices(&request_metadata_recv, + grpc_slice_from_static_string("key"), + meta.value)); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + grpc_slice_unref(meta.value); + + end_test(&f); + config.tear_down_data(&f); +} + +// Server responds with metadata larger than what the client accepts. +static void test_request_with_bad_large_metadata_response( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_metadata meta; + const size_t large_size = 64 * 1024; + grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + arg.key = const_cast(GRPC_ARG_MAX_METADATA_SIZE); + arg.value.integer = 1024; + grpc_channel_args args = {1, &arg}; + grpc_end2end_test_fixture f = begin_test( + config, "test_request_with_bad_large_metadata_response", &args, &args); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + meta.key = grpc_slice_from_static_string("key"); + meta.value = grpc_slice_malloc(large_size); + memset(GRPC_SLICE_START_PTR(meta.value), 'a', large_size); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + // Client: send request. 
+ op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + // Server: send large initial metadata + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = &meta; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); + GPR_ASSERT(0 == grpc_slice_str_cmp( + details, "received initial metadata size exceeds limit")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_slice_unref(meta.value); + + end_test(&f); + config.tear_down_data(&f); +} + +void large_metadata(grpc_end2end_test_config config) { + test_request_with_large_metadata(config); + // TODO(yashykt): Maybe add checks for metadata size in inproc transport too. + if (strcmp(config.name, "inproc") != 0) { + test_request_with_bad_large_metadata_response(config); + } +} + +void large_metadata_pre_init(void) {} diff --git a/test/core/end2end/tests/load_reporting_hook.cc b/test/core/end2end/tests/load_reporting_hook.cc new file mode 100644 index 00000000..f96cda34 --- /dev/null +++ b/test/core/end2end/tests/load_reporting_hook.cc @@ -0,0 +1,312 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/load_reporting/server_load_reporting_filter.h" +#include "src/core/ext/filters/load_reporting/server_load_reporting_plugin.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +enum { TIMEOUT = 200000 }; + +static void* tag(intptr_t t) { return (void*)t; } + +typedef struct { + gpr_mu mu; + intptr_t channel_id; + intptr_t call_id; + + char* initial_md_str; + char* trailing_md_str; + char* method_name; + + uint64_t incoming_bytes; + uint64_t outgoing_bytes; + + grpc_status_code call_final_status; + + bool fully_processed; +} load_reporting_data; + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void request_response_with_payload( + grpc_end2end_test_config config, grpc_end2end_test_fixture f, + const char* method_name, const char* request_msg, const char* response_msg, + grpc_metadata* initial_lr_metadata, grpc_metadata* trailing_lr_metadata) { + grpc_slice request_payload_slice = grpc_slice_from_static_string(request_msg); + grpc_slice response_payload_slice = + grpc_slice_from_static_string(response_msg); + grpc_call* c; + grpc_call* s; + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* 
op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string(method_name), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + GPR_ASSERT(initial_lr_metadata != nullptr); + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = initial_lr_metadata; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + GPR_ASSERT(trailing_lr_metadata != nullptr); + op->data.send_status_from_server.trailing_metadata_count = 1; + op->data.send_status_from_server.trailing_metadata = trailing_lr_metadata; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + 
op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); +} + +/* override the default for testing purposes */ +extern void (*g_load_reporting_fn)( + const grpc_load_reporting_call_data* call_data); + +static void test_load_reporting_hook(grpc_end2end_test_config config) { + /* TODO(dgq): this test is currently a noop until LR is fully defined. + * Leaving the rest here, as it'll likely be reusable. */ + + /* Introduce load reporting for the server through its arguments */ + grpc_arg arg = grpc_load_reporting_enable_arg(); + grpc_channel_args* lr_server_args = + grpc_channel_args_copy_and_add(nullptr, &arg, 1); + + grpc_end2end_test_fixture f = + begin_test(config, "test_load_reporting_hook", nullptr, lr_server_args); + + const char* method_name = "/gRPCFTW"; + const char* request_msg = "the msg from the client"; + const char* response_msg = "... and the response from the server"; + + grpc_metadata initial_lr_metadata; + grpc_metadata trailing_lr_metadata; + + initial_lr_metadata.key = GRPC_MDSTR_LB_TOKEN; + initial_lr_metadata.value = grpc_slice_from_static_string("client-token"); + memset(&initial_lr_metadata.internal_data, 0, + sizeof(initial_lr_metadata.internal_data)); + + trailing_lr_metadata.key = GRPC_MDSTR_LB_COST_BIN; + trailing_lr_metadata.value = grpc_slice_from_static_string("server-token"); + memset(&trailing_lr_metadata.internal_data, 0, + sizeof(trailing_lr_metadata.internal_data)); + + request_response_with_payload(config, f, method_name, request_msg, + response_msg, &initial_lr_metadata, + &trailing_lr_metadata); + end_test(&f); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(lr_server_args); + } + config.tear_down_data(&f); +} + +void load_reporting_hook(grpc_end2end_test_config config) { + test_load_reporting_hook(config); +} + +void load_reporting_hook_pre_init(void) {} diff --git a/test/core/end2end/tests/max_concurrent_streams.cc b/test/core/end2end/tests/max_concurrent_streams.cc new file mode 100644 index 00000000..3d3e4063 --- /dev/null +++ b/test/core/end2end/tests/max_concurrent_streams.cc @@ -0,0 +1,832 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + 
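/* Note: grpc_call_start_batch() only queues the batch; its grpc_call_error return value (checked just below) says whether the ops were accepted, not whether they completed. Completion is reported later on the completion queue under tag(1), which CQ_EXPECT_COMPLETION/cq_verify wait for. */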
GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_max_concurrent_streams(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + grpc_arg server_arg; + grpc_channel_args server_args; + grpc_call* c1; + grpc_call* c2; + grpc_call* s1; + grpc_call* s2; + int live_call; + gpr_timespec deadline; + cq_verifier* cqv; + grpc_event ev; + grpc_call_details call_details; + grpc_metadata_array request_metadata_recv; + grpc_metadata_array initial_metadata_recv1; + grpc_metadata_array trailing_metadata_recv1; + grpc_metadata_array initial_metadata_recv2; + grpc_metadata_array trailing_metadata_recv2; + grpc_status_code status1; + grpc_call_error error; + grpc_slice details1; + grpc_status_code status2; + grpc_slice details2; + grpc_op ops[6]; + grpc_op* op; + int was_cancelled; + int got_client_start; + int got_server_start; + + server_arg.key = const_cast(GRPC_ARG_MAX_CONCURRENT_STREAMS); + server_arg.type = GRPC_ARG_INTEGER; + server_arg.value.integer = 1; + + server_args.num_args = 1; + server_args.args = &server_arg; + + f = begin_test(config, "test_max_concurrent_streams", nullptr, &server_args); + cqv = cq_verifier_create(f.cq); + + grpc_metadata_array_init(&request_metadata_recv); + grpc_metadata_array_init(&initial_metadata_recv1); + grpc_metadata_array_init(&trailing_metadata_recv1); + grpc_metadata_array_init(&initial_metadata_recv2); + grpc_metadata_array_init(&trailing_metadata_recv2); + grpc_call_details_init(&call_details); + + /* perform a ping-pong to ensure that settings have had a chance to round + trip */ + simple_request_body(config, f); + /* perform another one to make sure that the one stream case still works */ + simple_request_body(config, f); + + /* start two requests - ensuring that the second is not accepted until + the first completes */ + deadline = n_seconds_from_now(1000); + c1 = grpc_channel_create_call(f.client, 
nullptr, GRPC_PROPAGATE_DEFAULTS, + f.cq, grpc_slice_from_static_string("/alpha"), + nullptr, deadline, nullptr); + GPR_ASSERT(c1); + c2 = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f.cq, grpc_slice_from_static_string("/beta"), + nullptr, deadline, nullptr); + GPR_ASSERT(c2); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s1, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c1, ops, static_cast(op - ops), + tag(301), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv1; + op->data.recv_status_on_client.status = &status1; + op->data.recv_status_on_client.status_details = &details1; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv1; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c1, ops, static_cast(op - ops), + tag(302), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c2, ops, static_cast(op - ops), + tag(401), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv2; + op->data.recv_status_on_client.status = &status2; + op->data.recv_status_on_client.status_details = &details2; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv1; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c2, ops, static_cast(op - ops), + tag(402), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + got_client_start = 0; + got_server_start = 0; + live_call = -1; + while (!got_client_start || !got_server_start) { + ev = grpc_completion_queue_next(f.cq, grpc_timeout_seconds_to_deadline(3), + nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.success); + if (ev.tag == tag(101)) { + GPR_ASSERT(!got_server_start); + got_server_start = 1; + } else { + GPR_ASSERT(!got_client_start); + GPR_ASSERT(ev.tag == tag(301) || ev.tag == tag(401)); + /* The /alpha or /beta calls started above could be invoked (but NOT + * both); + * check this here */ + /* We'll get tag 303 or 403, we want 300, 400 */ + live_call = (static_cast(reinterpret_cast(ev.tag))) - 1; + got_client_start = 1; + } + } + GPR_ASSERT(live_call == 300 || live_call == 400); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + 
op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s1, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(live_call + 2), 1); + /* first request is finished, we should be able to start the second */ + live_call = (live_call == 300) ? 400 : 300; + CQ_EXPECT_COMPLETION(cqv, tag(live_call + 1), 1); + cq_verify(cqv); + + grpc_call_details_destroy(&call_details); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s2, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201))); + CQ_EXPECT_COMPLETION(cqv, tag(201), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s2, ops, static_cast(op - ops), + tag(202), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(live_call + 2), 1); + CQ_EXPECT_COMPLETION(cqv, tag(202), 1); + cq_verify(cqv); + + cq_verifier_destroy(cqv); + + grpc_call_unref(c1); + grpc_call_unref(s1); + grpc_call_unref(c2); + grpc_call_unref(s2); + + grpc_slice_unref(details1); + grpc_slice_unref(details2); + grpc_metadata_array_destroy(&initial_metadata_recv1); + grpc_metadata_array_destroy(&trailing_metadata_recv1); + grpc_metadata_array_destroy(&initial_metadata_recv2); + grpc_metadata_array_destroy(&trailing_metadata_recv2); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_max_concurrent_streams_with_timeout_on_first( + grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + grpc_arg server_arg; + grpc_channel_args server_args; + grpc_call* c1; + grpc_call* c2; + grpc_call* s1; + grpc_call* s2; + cq_verifier* cqv; + grpc_call_details call_details; + grpc_metadata_array request_metadata_recv; + grpc_metadata_array initial_metadata_recv1; + grpc_metadata_array trailing_metadata_recv1; + grpc_metadata_array initial_metadata_recv2; + grpc_metadata_array trailing_metadata_recv2; + grpc_status_code status1; + grpc_call_error error; + grpc_slice details1 = grpc_empty_slice(); + grpc_status_code status2; + grpc_slice details2 = grpc_empty_slice(); + grpc_op ops[6]; + grpc_op* op; + int was_cancelled; + + server_arg.key = const_cast(GRPC_ARG_MAX_CONCURRENT_STREAMS); + server_arg.type = GRPC_ARG_INTEGER; + server_arg.value.integer = 1; + + server_args.num_args = 1; + server_args.args = &server_arg; + + f = begin_test(config, "test_max_concurrent_streams_with_timeout_on_first", + 
nullptr, &server_args); + cqv = cq_verifier_create(f.cq); + + grpc_metadata_array_init(&request_metadata_recv); + grpc_metadata_array_init(&initial_metadata_recv1); + grpc_metadata_array_init(&trailing_metadata_recv1); + grpc_metadata_array_init(&initial_metadata_recv2); + grpc_metadata_array_init(&trailing_metadata_recv2); + grpc_call_details_init(&call_details); + + /* perform a ping-pong to ensure that settings have had a chance to round + trip */ + simple_request_body(config, f); + /* perform another one to make sure that the one stream case still works */ + simple_request_body(config, f); + + /* start two requests - ensuring that the second is not accepted until + the first completes */ + c1 = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f.cq, grpc_slice_from_static_string("/alpha"), + nullptr, n_seconds_from_now(3), nullptr); + GPR_ASSERT(c1); + c2 = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f.cq, grpc_slice_from_static_string("/beta"), + nullptr, n_seconds_from_now(1000), nullptr); + GPR_ASSERT(c2); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s1, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c1, ops, static_cast(op - ops), + tag(301), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv1; + op->data.recv_status_on_client.status = &status1; + op->data.recv_status_on_client.status_details = &details1; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv1; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c1, ops, static_cast(op - ops), + tag(302), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(301), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c2, ops, static_cast(op - ops), + tag(401), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv2; + op->data.recv_status_on_client.status = &status2; + op->data.recv_status_on_client.status_details = &details2; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv2; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c2, ops, static_cast(op - ops), + tag(402), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s2, &call_details, + 
&request_metadata_recv, f.cq, f.cq, tag(201))); + + CQ_EXPECT_COMPLETION(cqv, tag(302), 1); + /* first request is finished, we should be able to start the second */ + CQ_EXPECT_COMPLETION(cqv, tag(401), 1); + CQ_EXPECT_COMPLETION(cqv, tag(201), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s2, ops, static_cast(op - ops), + tag(202), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(402), 1); + CQ_EXPECT_COMPLETION(cqv, tag(202), 1); + cq_verify(cqv); + + cq_verifier_destroy(cqv); + + grpc_call_unref(c1); + grpc_call_unref(s1); + grpc_call_unref(c2); + grpc_call_unref(s2); + + grpc_slice_unref(details1); + grpc_slice_unref(details2); + grpc_metadata_array_destroy(&initial_metadata_recv1); + grpc_metadata_array_destroy(&trailing_metadata_recv1); + grpc_metadata_array_destroy(&initial_metadata_recv2); + grpc_metadata_array_destroy(&trailing_metadata_recv2); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_max_concurrent_streams_with_timeout_on_second( + grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + grpc_arg server_arg; + grpc_channel_args server_args; + grpc_call* c1; + grpc_call* c2; + grpc_call* s1; + cq_verifier* cqv; + grpc_call_details call_details; + grpc_metadata_array request_metadata_recv; + grpc_metadata_array initial_metadata_recv1; + grpc_metadata_array trailing_metadata_recv1; + grpc_metadata_array initial_metadata_recv2; + grpc_metadata_array trailing_metadata_recv2; + grpc_status_code status1; + grpc_call_error error; + grpc_slice details1 = grpc_empty_slice(); + grpc_status_code status2; + grpc_slice details2 = grpc_empty_slice(); + grpc_op ops[6]; + grpc_op* op; + int was_cancelled; + + server_arg.key = const_cast(GRPC_ARG_MAX_CONCURRENT_STREAMS); + server_arg.type = GRPC_ARG_INTEGER; + server_arg.value.integer = 1; + + server_args.num_args = 1; + server_args.args = &server_arg; + + f = begin_test(config, "test_max_concurrent_streams_with_timeout_on_second", + nullptr, &server_args); + cqv = cq_verifier_create(f.cq); + + grpc_metadata_array_init(&request_metadata_recv); + grpc_metadata_array_init(&initial_metadata_recv1); + grpc_metadata_array_init(&trailing_metadata_recv1); + grpc_metadata_array_init(&initial_metadata_recv2); + grpc_metadata_array_init(&trailing_metadata_recv2); + grpc_call_details_init(&call_details); + + /* perform a ping-pong to ensure that settings have had a chance to round + trip */ + simple_request_body(config, f); + /* perform another one to make sure that the one stream case still works */ + simple_request_body(config, f); + + /* start two requests - ensuring that the second is not accepted until + the first completes , and the second request will timeout in the + concurrent_list */ + 
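Before the two calls below are started, note how the one-stream cap that drives this test would be applied outside the fixture: an application server typically sets GRPC_ARG_MAX_CONCURRENT_STREAMS through the C++ ServerBuilder wrapper rather than a raw grpc_channel_args struct. A minimal sketch, with a placeholder address and service (neither taken from this test):

#include <memory>

#include <grpcpp/grpcpp.h>

// Minimal sketch: a server that admits at most one concurrent HTTP/2 stream
// per connection. `service` and the listening address are placeholders.
std::unique_ptr<grpc::Server> BuildSingleStreamServer(grpc::Service* service) {
  grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:50051",
                           grpc::InsecureServerCredentials());
  builder.RegisterService(service);
  // Escape hatch into core channel args: cap concurrent streams at 1, the
  // same setting the tests pass via grpc_channel_args.
  builder.AddChannelArgument(GRPC_ARG_MAX_CONCURRENT_STREAMS, 1);
  return builder.BuildAndStart();
}

With that cap in place, the second of the two calls started next cannot be dispatched until the first finishes, which is exactly what the completion-queue expectations in this test verify.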
c1 = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f.cq, grpc_slice_from_static_string("/alpha"), + nullptr, n_seconds_from_now(1000), nullptr); + GPR_ASSERT(c1); + c2 = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f.cq, grpc_slice_from_static_string("/beta"), + nullptr, n_seconds_from_now(3), nullptr); + GPR_ASSERT(c2); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s1, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c1, ops, static_cast(op - ops), + tag(301), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv1; + op->data.recv_status_on_client.status = &status1; + op->data.recv_status_on_client.status_details = &details1; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv1; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c1, ops, static_cast(op - ops), + tag(302), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(301), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c2, ops, static_cast(op - ops), + tag(401), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv2; + op->data.recv_status_on_client.status = &status2; + op->data.recv_status_on_client.status_details = &details2; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv2; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c2, ops, static_cast(op - ops), + tag(402), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* the second request is time out*/ + CQ_EXPECT_COMPLETION(cqv, tag(401), 0); + CQ_EXPECT_COMPLETION(cqv, tag(402), 1); + cq_verify(cqv); + + /* second request is finished because of time out, so destroy the second call + */ + grpc_call_unref(c2); + + /* now reply the first call */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + 
op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s1, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(302), 1); + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + cq_verifier_destroy(cqv); + + grpc_call_unref(c1); + grpc_call_unref(s1); + + grpc_slice_unref(details1); + grpc_slice_unref(details2); + grpc_metadata_array_destroy(&initial_metadata_recv1); + grpc_metadata_array_destroy(&trailing_metadata_recv1); + grpc_metadata_array_destroy(&initial_metadata_recv2); + grpc_metadata_array_destroy(&trailing_metadata_recv2); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + end_test(&f); + config.tear_down_data(&f); +} + +void max_concurrent_streams(grpc_end2end_test_config config) { + test_max_concurrent_streams_with_timeout_on_first(config); + test_max_concurrent_streams_with_timeout_on_second(config); + test_max_concurrent_streams(config); +} + +void max_concurrent_streams_pre_init(void) {} diff --git a/test/core/end2end/tests/max_connection_age.cc b/test/core/end2end/tests/max_connection_age.cc new file mode 100644 index 00000000..3219b60e --- /dev/null +++ b/test/core/end2end/tests/max_connection_age.cc @@ -0,0 +1,362 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +#define MAX_CONNECTION_AGE_MS 500 +#define MAX_CONNECTION_AGE_GRACE_MS 1000 +#define MAX_CONNECTION_IDLE_MS 9999 + +#define MAX_CONNECTION_AGE_JITTER_MULTIPLIER 1.1 +#define CALL_DEADLINE_S 10 +/* The amount of time we wait for the connection to time out, but after it the + connection should not use up its grace period. 
It should be a number between + MAX_CONNECTION_AGE_MS and MAX_CONNECTION_AGE_MS + + MAX_CONNECTION_AGE_GRACE_MS */ +#define CQ_MAX_CONNECTION_AGE_WAIT_TIME_S 1 +/* The amount of time we wait after the connection reaches its max age, it + should be shorter than CALL_DEADLINE_S - CQ_MAX_CONNECTION_AGE_WAIT_TIME_S */ +#define CQ_MAX_CONNECTION_AGE_GRACE_WAIT_TIME_S 2 +/* The grace period for the test to observe the channel shutdown process */ +#define IMMEDIATE_SHUTDOWN_GRACE_TIME_MS 3000 + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, grpc_timeout_seconds_to_deadline(5), + nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_max_age_forcibly_close(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_arg server_a[3]; + server_a[0].type = GRPC_ARG_INTEGER; + server_a[0].key = const_cast(GRPC_ARG_MAX_CONNECTION_AGE_MS); + server_a[0].value.integer = MAX_CONNECTION_AGE_MS; + server_a[1].type = GRPC_ARG_INTEGER; + server_a[1].key = const_cast(GRPC_ARG_MAX_CONNECTION_AGE_GRACE_MS); + server_a[1].value.integer = MAX_CONNECTION_AGE_GRACE_MS; + server_a[2].type = GRPC_ARG_INTEGER; + server_a[2].key = const_cast(GRPC_ARG_MAX_CONNECTION_IDLE_MS); + server_a[2].value.integer = MAX_CONNECTION_IDLE_MS; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + + config.init_client(&f, nullptr); + config.init_server(&f, &server_args); + + grpc_call* c; + grpc_call* s; + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(CALL_DEADLINE_S); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = 
&trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + gpr_timespec expect_shutdown_time = grpc_timeout_milliseconds_to_deadline( + static_cast(MAX_CONNECTION_AGE_MS * + MAX_CONNECTION_AGE_JITTER_MULTIPLIER) + + MAX_CONNECTION_AGE_GRACE_MS + IMMEDIATE_SHUTDOWN_GRACE_TIME_MS); + + /* Wait for the channel to reach its max age */ + cq_verify_empty_timeout(cqv, CQ_MAX_CONNECTION_AGE_WAIT_TIME_S); + + /* After the channel reaches its max age, we still do nothing here. And wait + for it to use up its max age grace period. */ + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + gpr_timespec channel_shutdown_time = gpr_now(GPR_CLOCK_MONOTONIC); + GPR_ASSERT(gpr_time_cmp(channel_shutdown_time, expect_shutdown_time) < 0); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), true); + cq_verify(cqv); + + grpc_call_unref(s); + + /* The connection should be closed immediately after the max age grace period, + the in-progress RPC should fail. 
*/ + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 1); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(c); + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_max_age_gracefully_close(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_arg server_a[3]; + server_a[0].type = GRPC_ARG_INTEGER; + server_a[0].key = const_cast(GRPC_ARG_MAX_CONNECTION_AGE_MS); + server_a[0].value.integer = MAX_CONNECTION_AGE_MS; + server_a[1].type = GRPC_ARG_INTEGER; + server_a[1].key = const_cast(GRPC_ARG_MAX_CONNECTION_AGE_GRACE_MS); + server_a[1].value.integer = INT_MAX; + server_a[2].type = GRPC_ARG_INTEGER; + server_a[2].key = const_cast(GRPC_ARG_MAX_CONNECTION_IDLE_MS); + server_a[2].value.integer = MAX_CONNECTION_IDLE_MS; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + + config.init_client(&f, nullptr); + config.init_server(&f, &server_args); + + grpc_call* c; + grpc_call* s; + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(CALL_DEADLINE_S); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + /* Wait for the channel to reach its max age */ + cq_verify_empty_timeout(cqv, CQ_MAX_CONNECTION_AGE_WAIT_TIME_S); + + /* The connection is shutting down gracefully. In-progress rpc should not be + closed, hence the completion queue should see nothing here. 
*/ + cq_verify_empty_timeout(cqv, CQ_MAX_CONNECTION_AGE_GRACE_WAIT_TIME_S); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), true); + cq_verify(cqv); + + grpc_call_unref(s); + + /* The connection is closed gracefully with goaway, the rpc should still be + completed. */ + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(c); + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +void max_connection_age(grpc_end2end_test_config config) { + test_max_age_forcibly_close(config); + test_max_age_gracefully_close(config); +} + +void max_connection_age_pre_init(void) {} diff --git a/test/core/end2end/tests/max_connection_idle.cc b/test/core/end2end/tests/max_connection_idle.cc new file mode 100644 index 00000000..ecc51b54 --- /dev/null +++ b/test/core/end2end/tests/max_connection_idle.cc @@ -0,0 +1,238 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +#define MAX_CONNECTION_IDLE_MS 500 +#define MAX_CONNECTION_AGE_MS 9999 + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, grpc_timeout_seconds_to_deadline(5), + nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture* f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f->cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5); + c = grpc_channel_create_call(f->client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f->cq, grpc_slice_from_static_string("/foo"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f->server, &s, &call_details, + &request_metadata_recv, f->cq, f->cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + 
op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_max_connection_idle(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + cq_verifier* cqv = cq_verifier_create(f.cq); + + grpc_arg client_a[1]; + client_a[0].type = GRPC_ARG_INTEGER; + client_a[0].key = + const_cast("grpc.testing.fixed_reconnect_backoff_ms"); + client_a[0].value.integer = 1000; + grpc_arg server_a[2]; + server_a[0].type = GRPC_ARG_INTEGER; + server_a[0].key = const_cast(GRPC_ARG_MAX_CONNECTION_IDLE_MS); + server_a[0].value.integer = MAX_CONNECTION_IDLE_MS; + server_a[1].type = GRPC_ARG_INTEGER; + server_a[1].key = const_cast(GRPC_ARG_MAX_CONNECTION_AGE_MS); + server_a[1].value.integer = MAX_CONNECTION_AGE_MS; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + + config.init_client(&f, &client_args); + config.init_server(&f, &server_args); + + /* check that we're still in idle, and start connecting */ + GPR_ASSERT(grpc_channel_check_connectivity_state(f.client, 1) == + GRPC_CHANNEL_IDLE); + /* we'll go through some set of transitions (some might be missed), until + READY is reached */ + while (state != GRPC_CHANNEL_READY) { + grpc_channel_watch_connectivity_state( + f.client, state, grpc_timeout_seconds_to_deadline(3), f.cq, tag(99)); + CQ_EXPECT_COMPLETION(cqv, tag(99), 1); + cq_verify(cqv); + state = grpc_channel_check_connectivity_state(f.client, 0); + GPR_ASSERT(state == GRPC_CHANNEL_READY || + state == GRPC_CHANNEL_CONNECTING || + state == GRPC_CHANNEL_TRANSIENT_FAILURE); + } + + /* Use a simple request to cancel and reset the max idle timer */ + simple_request_body(config, &f); + + /* wait for the channel to reach its maximum idle time */ + grpc_channel_watch_connectivity_state( + f.client, GRPC_CHANNEL_READY, + grpc_timeout_milliseconds_to_deadline(MAX_CONNECTION_IDLE_MS + 3000), + f.cq, tag(99)); + CQ_EXPECT_COMPLETION(cqv, tag(99), 1); + cq_verify(cqv); + state = grpc_channel_check_connectivity_state(f.client, 0); + GPR_ASSERT(state == GRPC_CHANNEL_TRANSIENT_FAILURE || + state == GRPC_CHANNEL_CONNECTING || state == GRPC_CHANNEL_IDLE); + + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), 1); + cq_verify(cqv); + + grpc_server_destroy(f.server); + grpc_channel_destroy(f.client); + grpc_completion_queue_shutdown(f.cq); + drain_cq(f.cq); + grpc_completion_queue_destroy(f.cq); + grpc_completion_queue_destroy(f.shutdown_cq); + config.tear_down_data(&f); + + cq_verifier_destroy(cqv); +} + +void 
max_connection_idle(grpc_end2end_test_config config) { + test_max_connection_idle(config); +} + +void max_connection_idle_pre_init(void) {} diff --git a/test/core/end2end/tests/max_message_length.cc b/test/core/end2end/tests/max_message_length.cc new file mode 100644 index 00000000..c36d370f --- /dev/null +++ b/test/core/end2end/tests/max_message_length.cc @@ -0,0 +1,835 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + // We intentionally do not pass the client and server args to + // create_fixture(), since we don't want the limit enforced on the + // proxy, only on the backend server. + f = config.create_fixture(nullptr, nullptr); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->cq, tag(1000)); + grpc_event ev = grpc_completion_queue_next( + f->cq, grpc_timeout_seconds_to_deadline(5), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == tag(1000)); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Test with request larger than the limit. +// If send_limit is true, applies send limit on client; otherwise, applies +// recv limit on server. 
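+// In both cases the client should observe GRPC_STATUS_RESOURCE_EXHAUSTED.
+// When the limit is enforced on the client send path, the call never reaches
+// the server, so the test skips the server-side ops (see the `goto done`
+// below); with the recv limit on the server, the server does see the call and
+// its RECV_CLOSE op reports it as cancelled.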
+static void test_max_message_length_on_request(grpc_end2end_test_config config, + bool send_limit, + bool use_service_config, + bool use_string_json_value) { + gpr_log(GPR_INFO, + "testing request with send_limit=%d use_service_config=%d " + "use_string_json_value=%d", + send_limit, use_service_config, use_string_json_value); + + grpc_end2end_test_fixture f; + grpc_call* c = nullptr; + grpc_call* s = nullptr; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* recv_payload = nullptr; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + grpc_channel_args* client_args = nullptr; + grpc_channel_args* server_args = nullptr; + if (use_service_config) { + // We don't currently support service configs on the server side. + GPR_ASSERT(send_limit); + grpc_arg arg; + arg.type = GRPC_ARG_STRING; + arg.key = const_cast(GRPC_ARG_SERVICE_CONFIG); + arg.value.string = + use_string_json_value + ? const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"maxRequestMessageBytes\": \"5\"\n" + " } ]\n" + "}") + : const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"maxRequestMessageBytes\": 5\n" + " } ]\n" + "}"); + client_args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + } else { + // Set limit via channel args. + grpc_arg arg; + arg.key = send_limit + ? 
const_cast(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH) + : const_cast(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH); + arg.type = GRPC_ARG_INTEGER; + arg.value.integer = 5; + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + if (send_limit) { + client_args = args; + } else { + server_args = args; + } + } + + f = begin_test(config, "test_max_request_message_length", client_args, + server_args); + { + grpc_core::ExecCtx exec_ctx; + if (client_args != nullptr) grpc_channel_args_destroy(client_args); + if (server_args != nullptr) grpc_channel_args_destroy(server_args); + } + + cqv = cq_verifier_create(f.cq); + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + if (send_limit) { + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + goto done; + } + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(was_cancelled == 1); + +done: + GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); + GPR_ASSERT( + grpc_slice_str_cmp( + details, send_limit + ? "Sent message larger than max (11 vs. 5)" + : "Received message larger than max (11 vs. 
5)") == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(recv_payload); + + grpc_call_unref(c); + if (s != nullptr) grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +// Test with response larger than the limit. +// If send_limit is true, applies send limit on server; otherwise, applies +// recv limit on client. +static void test_max_message_length_on_response(grpc_end2end_test_config config, + bool send_limit, + bool use_service_config, + bool use_string_json_value) { + gpr_log(GPR_INFO, + "testing response with send_limit=%d use_service_config=%d " + "use_string_json_value=%d", + send_limit, use_service_config, use_string_json_value); + + grpc_end2end_test_fixture f; + grpc_call* c = nullptr; + grpc_call* s = nullptr; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* recv_payload = nullptr; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + grpc_channel_args* client_args = nullptr; + grpc_channel_args* server_args = nullptr; + if (use_service_config) { + // We don't currently support service configs on the server side. + GPR_ASSERT(!send_limit); + grpc_arg arg; + arg.type = GRPC_ARG_STRING; + arg.key = const_cast(GRPC_ARG_SERVICE_CONFIG); + arg.value.string = const_cast( + use_string_json_value + ? "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"maxResponseMessageBytes\": \"5\"\n" + " } ]\n" + "}" + : "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"maxResponseMessageBytes\": 5\n" + " } ]\n" + "}"); + client_args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + } else { + // Set limit via channel args. + grpc_arg arg; + arg.key = send_limit + ? 
const_cast(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH) + : const_cast(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH); + arg.type = GRPC_ARG_INTEGER; + arg.value.integer = 5; + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + if (send_limit) { + server_args = args; + } else { + client_args = args; + } + } + + f = begin_test(config, "test_max_response_message_length", client_args, + server_args); + { + grpc_core::ExecCtx exec_ctx; + if (client_args != nullptr) grpc_channel_args_destroy(client_args); + if (server_args != nullptr) grpc_channel_args_destroy(server_args); + } + cqv = cq_verifier_create(f.cq); + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); + GPR_ASSERT( + grpc_slice_str_cmp( + details, send_limit + ? "Sent message larger than max (11 vs. 
5)" + : "Received message larger than max (11 vs. 5)") == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(recv_payload); + + grpc_call_unref(c); + if (s != nullptr) grpc_call_unref(s); + cq_verifier_destroy(cqv); + end_test(&f); + config.tear_down_data(&f); +} + +static grpc_metadata gzip_compression_override() { + grpc_metadata gzip_compression_override; + gzip_compression_override.key = GRPC_MDSTR_GRPC_INTERNAL_ENCODING_REQUEST; + gzip_compression_override.value = grpc_slice_from_static_string("gzip"); + memset(&gzip_compression_override.internal_data, 0, + sizeof(gzip_compression_override.internal_data)); + return gzip_compression_override; +} + +// Test receive message limit with compressed request larger than the limit +static void test_max_receive_message_length_on_compressed_request( + grpc_end2end_test_config config, bool minimal_stack) { + gpr_log(GPR_INFO, + "test max receive message length on compressed request with " + "minimal_stack=%d", + minimal_stack); + grpc_end2end_test_fixture f; + grpc_call* c = nullptr; + grpc_call* s = nullptr; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + grpc_slice request_payload_slice = grpc_slice_malloc(1024); + memset(GRPC_SLICE_START_PTR(request_payload_slice), 'a', 1024); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* recv_payload = nullptr; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details, status_details; + int was_cancelled = 2; + + // Set limit via channel args. 
+ grpc_arg arg[2]; + arg[0] = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH), 5); + arg[1] = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MINIMAL_STACK), minimal_stack); + grpc_channel_args* server_args = + grpc_channel_args_copy_and_add(nullptr, arg, 2); + + f = begin_test(config, "test_max_request_message_length", nullptr, + server_args); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(server_args); + } + cqv = cq_verifier_create(f.cq); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + grpc_metadata compression_md = gzip_compression_override(); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = &compression_md; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + if (minimal_stack) { + /* Expect the RPC to proceed normally for a minimal stack */ + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + } + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + if (minimal_stack) { + /* We do not perform message size checks for minimal stack. 
*/ + GPR_ASSERT(status == GRPC_STATUS_OK); + } else { + GPR_ASSERT(was_cancelled == 1); + GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); + GPR_ASSERT(grpc_slice_str_cmp( + details, "Received message larger than max (29 vs. 5)") == + 0); + } + grpc_slice_unref(details); + grpc_slice_unref(request_payload_slice); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(recv_payload); + grpc_call_unref(c); + if (s != nullptr) grpc_call_unref(s); + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +// Test receive message limit with compressed response larger than the limit. +static void test_max_receive_message_length_on_compressed_response( + grpc_end2end_test_config config, bool minimal_stack) { + gpr_log(GPR_INFO, + "testing max receive message length on compressed response with " + "minimal_stack=%d", + minimal_stack); + grpc_end2end_test_fixture f; + grpc_call* c = nullptr; + grpc_call* s = nullptr; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + grpc_slice response_payload_slice = grpc_slice_malloc(1024); + memset(GRPC_SLICE_START_PTR(response_payload_slice), 'a', 1024); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* recv_payload = nullptr; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + // Set limit via channel args. 
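+  // Mirror of the compressed-request test above: the 5-byte receive limit is
+  // set on the client, which should reject the gzip-compressed 1024-byte
+  // response with RESOURCE_EXHAUSTED unless GRPC_ARG_MINIMAL_STACK disables
+  // the message-size check.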
+ grpc_arg arg[2]; + arg[0] = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH), 5); + arg[1] = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MINIMAL_STACK), minimal_stack); + grpc_channel_args* client_args = + grpc_channel_args_copy_and_add(nullptr, arg, 2); + + f = begin_test(config, "test_max_response_message_length", client_args, + nullptr); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(client_args); + } + cqv = cq_verifier_create(f.cq); + + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + grpc_metadata compression_md = gzip_compression_override(); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = &compression_md; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + if (minimal_stack) { + /* We do not perform message size checks for minimal stack. 
*/ + GPR_ASSERT(status == GRPC_STATUS_OK); + } else { + GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); + GPR_ASSERT(grpc_slice_str_cmp( + details, "Received message larger than max (29 vs. 5)") == + 0); + } + grpc_slice_unref(details); + grpc_slice_unref(response_payload_slice); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(recv_payload); + + grpc_call_unref(c); + if (s != nullptr) grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void max_message_length(grpc_end2end_test_config config) { + test_max_message_length_on_request(config, false /* send_limit */, + false /* use_service_config */, + false /* use_string_json_value */); + test_max_message_length_on_request(config, true /* send_limit */, + false /* use_service_config */, + false /* use_string_json_value */); + test_max_message_length_on_response(config, false /* send_limit */, + false /* use_service_config */, + false /* use_string_json_value */); + test_max_message_length_on_response(config, true /* send_limit */, + false /* use_service_config */, + false /* use_string_json_value */); + test_max_message_length_on_request(config, true /* send_limit */, + true /* use_service_config */, + false /* use_string_json_value */); + test_max_message_length_on_request(config, true /* send_limit */, + true /* use_service_config */, + true /* use_string_json_value */); + test_max_message_length_on_response(config, false /* send_limit */, + true /* use_service_config */, + false /* use_string_json_value */); + test_max_message_length_on_response(config, false /* send_limit */, + true /* use_service_config */, + true /* use_string_json_value */); + /* The following tests are not useful for inproc transport and do not work + * with our simple proxy. */ + if (strcmp(config.name, "inproc") != 0 && + (config.feature_mask & FEATURE_MASK_SUPPORTS_REQUEST_PROXYING) == 0) { + test_max_receive_message_length_on_compressed_request(config, false); + test_max_receive_message_length_on_compressed_request(config, true); + test_max_receive_message_length_on_compressed_response(config, false); + test_max_receive_message_length_on_compressed_response(config, true); + } +} + +void max_message_length_pre_init(void) {} diff --git a/test/core/end2end/tests/negative_deadline.cc b/test/core/end2end/tests/negative_deadline.cc new file mode 100644 index 00000000..a2d1528e --- /dev/null +++ b/test/core/end2end/tests/negative_deadline.cc @@ -0,0 +1,170 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f, size_t num_ops) { + grpc_call* c; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + gpr_log(GPR_DEBUG, "test with %" PRIuPTR " ops", num_ops); + + gpr_timespec deadline = gpr_inf_past(GPR_CLOCK_REALTIME); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + GPR_ASSERT(num_ops <= (size_t)(op - ops)); + error = grpc_call_start_batch(c, ops, num_ops, tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + 
CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_DEADLINE_EXCEEDED); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config, + size_t num_ops) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + simple_request_body(config, f, num_ops); + end_test(&f); + config.tear_down_data(&f); +} + +void negative_deadline(grpc_end2end_test_config config) { + size_t i; + for (i = 1; i <= 4; i++) { + test_invoke_simple_request(config, i); + } +} + +void negative_deadline_pre_init(void) {} diff --git a/test/core/end2end/tests/no_error_on_hotpath.cc b/test/core/end2end/tests/no_error_on_hotpath.cc new file mode 100644 index 00000000..59ea65bc --- /dev/null +++ b/test/core/end2end/tests/no_error_on_hotpath.cc @@ -0,0 +1,246 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/error.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + 
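+// Runs one complete unary RPC end to end. The hotpath tests below call this
+// with grpc_error creation disabled, so the steady-state request path is
+// expected to complete without constructing any grpc_error objects.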
+static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + op->data.send_status_from_server.status_details = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(GRPC_SLICE_LENGTH(details) == 0); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + 
grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_no_error_on_hotpath_one_request( + grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request_with_no_error_logging", + nullptr, nullptr); + // First RPC is not considered the hotpath, since there are lots of things to + // set up. + simple_request_body(config, f); + grpc_disable_error_creation(); + simple_request_body(config, f); + grpc_enable_error_creation(); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_no_error_on_hotpath_10_requests( + grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = begin_test( + config, "test_no_error_on_hotpath_in_one_request", nullptr, nullptr); + // First RPC is not considered the hotpath, since there are lots of things to + // set up. + simple_request_body(config, f); + grpc_disable_error_creation(); + for (i = 0; i < 10; i++) { + simple_request_body(config, f); + } + grpc_enable_error_creation(); + end_test(&f); + config.tear_down_data(&f); +} + +void no_error_on_hotpath(grpc_end2end_test_config config) { + test_no_error_on_hotpath_one_request(config); + test_no_error_on_hotpath_10_requests(config); +} + +void no_error_on_hotpath_pre_init(void) {} diff --git a/test/core/end2end/tests/no_logging.cc b/test/core/end2end/tests/no_logging.cc new file mode 100644 index 00000000..7ca364c2 --- /dev/null +++ b/test/core/end2end/tests/no_logging.cc @@ -0,0 +1,295 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/error.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +enum { TIMEOUT = 200000 }; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +void gpr_default_log(gpr_log_func_args* args); + +static void test_no_log(gpr_log_func_args* args) { + std::string message = absl::StrCat("Unwanted log: ", args->message); + args->message = message.c_str(); + gpr_default_log(args); + abort(); +} + +static void test_no_error_log(gpr_log_func_args* args) { + if (args->severity == GPR_LOG_SEVERITY_ERROR) { + test_no_log(args); + } +} + +static gpr_atm g_log_func = reinterpret_cast(gpr_default_log); + +static void log_dispatcher_func(gpr_log_func_args* args) { + gpr_log_func log_func = + reinterpret_cast(gpr_atm_no_barrier_load(&g_log_func)); + log_func(args); +} + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + 
grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request_with_no_error_logging", + nullptr, nullptr); + simple_request_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_10_simple_requests(grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_10_simple_requests_with_no_error_logging", + nullptr, nullptr); + for (i = 0; i < 10; i++) { + simple_request_body(config, f); + gpr_log(GPR_INFO, "Passed simple request %d", i); + } + simple_request_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +static void 
test_no_error_logging_in_entire_process( + grpc_end2end_test_config config) { + int i; + gpr_atm_no_barrier_store(&g_log_func, (gpr_atm)test_no_error_log); + for (i = 0; i < 10; i++) { + test_invoke_simple_request(config); + } + test_invoke_10_simple_requests(config); + gpr_atm_no_barrier_store(&g_log_func, (gpr_atm)gpr_default_log); +} + +static void test_no_logging_in_one_request(grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = + begin_test(config, "test_no_logging_in_last_request", nullptr, nullptr); + for (i = 0; i < 10; i++) { + simple_request_body(config, f); + } + gpr_atm_no_barrier_store(&g_log_func, (gpr_atm)test_no_log); + simple_request_body(config, f); + gpr_atm_no_barrier_store(&g_log_func, (gpr_atm)gpr_default_log); + end_test(&f); + config.tear_down_data(&f); +} + +void no_logging(grpc_end2end_test_config config) { + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); + grpc_tracer_set_enabled("all", 0); + gpr_set_log_function(log_dispatcher_func); + test_no_logging_in_one_request(config); + test_no_error_logging_in_entire_process(config); + gpr_set_log_function(gpr_default_log); +} + +void no_logging_pre_init(void) {} diff --git a/test/core/end2end/tests/no_op.cc b/test/core/end2end/tests/no_op.cc new file mode 100644 index 00000000..1f3343cf --- /dev/null +++ b/test/core/end2end/tests/no_op.cc @@ -0,0 +1,94 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_no_op(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = begin_test(config, "no-op", nullptr, nullptr); + end_test(&f); + config.tear_down_data(&f); +} + +void no_op(grpc_end2end_test_config config) { test_no_op(config); } + +void no_op_pre_init(void) {} diff --git a/test/core/end2end/tests/payload.cc b/test/core/end2end/tests/payload.cc new file mode 100644 index 00000000..fd4565c1 --- /dev/null +++ b/test/core/end2end/tests/payload.cc @@ -0,0 +1,286 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Creates and returns a grpc_slice containing random alphanumeric characters. + */ +static grpc_slice generate_random_slice() { + size_t i; + static const char chars[] = "abcdefghijklmnopqrstuvwxyz1234567890"; + char* output; + const size_t output_size = 1024 * 1024; + output = static_cast(gpr_malloc(output_size)); + for (i = 0; i < output_size - 1; ++i) { + output[i] = chars[rand() % static_cast(sizeof(chars) - 1)]; + } + output[output_size - 1] = '\0'; + grpc_slice out = grpc_slice_from_copied_string(output); + gpr_free(output); + return out; +} + +static void request_response_with_payload(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + /* Create large request and response bodies. These are big enough to require + * multiple round trips to deliver to the peer, and their exact contents of + * will be verified on completion. 
*/ + grpc_slice request_payload_slice = generate_random_slice(); + grpc_slice response_payload_slice = generate_random_slice(); + + grpc_call* c; + grpc_call* s; + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = n_seconds_from_now(60); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + 
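/* Once the tag(102) and tag(1) batches complete, request_payload_recv and
 * response_payload_recv hold the transferred bodies. A hedged sketch of how
 * such a buffer could be walked with the public reader API from
 * grpc/byte_buffer_reader.h; the assertions below rely on the
 * byte_buffer_eq_slice() helper from cq_verifier.h instead:
 *
 *   grpc_byte_buffer_reader reader;
 *   GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, request_payload_recv));
 *   grpc_slice piece;
 *   size_t total = 0;
 *   while (grpc_byte_buffer_reader_next(&reader, &piece)) {
 *     total += GRPC_SLICE_LENGTH(piece);  // or memcmp against the expected bytes
 *     grpc_slice_unref(piece);
 *   }
 *   grpc_byte_buffer_reader_destroy(&reader);
 *   GPR_ASSERT(total == grpc_byte_buffer_length(request_payload_recv));
 */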
op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_slice(request_payload_recv, request_payload_slice)); + GPR_ASSERT( + byte_buffer_eq_slice(response_payload_recv, response_payload_slice)); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); +} + +/* Client sends a request with payload, server reads then returns a response + payload and status. */ +static void test_invoke_request_response_with_payload( + grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = begin_test( + config, "test_invoke_request_response_with_payload", nullptr, nullptr); + request_response_with_payload(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_10_request_response_with_payload( + grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = begin_test( + config, "test_invoke_10_request_response_with_payload", nullptr, nullptr); + for (i = 0; i < 10; i++) { + request_response_with_payload(config, f); + } + end_test(&f); + config.tear_down_data(&f); +} + +void payload(grpc_end2end_test_config config) { + test_invoke_request_response_with_payload(config); + test_invoke_10_request_response_with_payload(config); +} + +void payload_pre_init(void) {} diff --git a/test/core/end2end/tests/ping.cc b/test/core/end2end/tests/ping.cc new file mode 100644 index 00000000..db4d5452 --- /dev/null +++ b/test/core/end2end/tests/ping.cc @@ -0,0 +1,111 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +#define PING_NUM 5 + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static void test_ping(grpc_end2end_test_config config, + int min_time_between_pings_ms) { + grpc_end2end_test_fixture f = config.create_fixture(nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + int i; + + grpc_arg client_a[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), 1)}; + grpc_arg server_a[] = { + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), + 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), 1)}; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(client_a), client_a}; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(server_a), server_a}; + + config.init_client(&f, &client_args); + config.init_server(&f, &server_args); + + grpc_channel_ping(f.client, f.cq, tag(0), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(0), 0); + + /* check that we're still in idle, and start connecting */ + GPR_ASSERT(grpc_channel_check_connectivity_state(f.client, 1) == + GRPC_CHANNEL_IDLE); + /* we'll go through some set of transitions (some might be missed), until + READY is reached */ + while (state != GRPC_CHANNEL_READY) { + grpc_channel_watch_connectivity_state( + f.client, state, + gpr_time_add(grpc_timeout_seconds_to_deadline(3), + gpr_time_from_millis(min_time_between_pings_ms * PING_NUM, + GPR_TIMESPAN)), + f.cq, tag(99)); + CQ_EXPECT_COMPLETION(cqv, tag(99), 1); + cq_verify(cqv); + state = grpc_channel_check_connectivity_state(f.client, 0); + GPR_ASSERT(state == GRPC_CHANNEL_READY || + state == GRPC_CHANNEL_CONNECTING || + state == GRPC_CHANNEL_TRANSIENT_FAILURE); + } + + for (i = 1; i <= PING_NUM; i++) { + grpc_channel_ping(f.client, f.cq, tag(i), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(i), 1); + cq_verify(cqv); + } + + grpc_server_shutdown_and_notify(f.server, f.cq, tag(0xdead)); + CQ_EXPECT_COMPLETION(cqv, tag(0xdead), 1); + cq_verify(cqv); + + /* cleanup server */ + grpc_server_destroy(f.server); + + grpc_channel_destroy(f.client); + grpc_completion_queue_shutdown(f.cq); + grpc_completion_queue_destroy(f.cq); + + /* f.shutdown_cq is not used in this test */ + grpc_completion_queue_destroy(f.shutdown_cq); + config.tear_down_data(&f); + + cq_verifier_destroy(cqv); +} + +void ping(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION); + test_ping(config, 0); + test_ping(config, 100); +} + +void ping_pre_init(void) {} diff --git a/test/core/end2end/tests/ping_pong_streaming.cc b/test/core/end2end/tests/ping_pong_streaming.cc new file mode 100644 index 00000000..66c60c3e --- /dev/null +++ b/test/core/end2end/tests/ping_pong_streaming.cc @@ -0,0 +1,280 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Client pings and server pongs. Repeat messages rounds before finishing. 
*/ +static void test_pingpong_streaming(grpc_end2end_test_config config, + int messages) { + grpc_end2end_test_fixture f = + begin_test(config, "test_pingpong_streaming", nullptr, nullptr); + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + grpc_byte_buffer* request_payload; + grpc_byte_buffer* request_payload_recv; + grpc_byte_buffer* response_payload; + grpc_byte_buffer* response_payload_recv; + int i; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(100)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(100), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + for (i = 0; i < messages; i++) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + response_payload = grpc_raw_byte_buffer_create(&response_payload_slice, 1); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags 
= 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + } + + grpc_slice_unref(request_payload_slice); + grpc_slice_unref(response_payload_slice); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + cq_verify(cqv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_slice_unref(details); + + end_test(&f); + config.tear_down_data(&f); +} + +void ping_pong_streaming(grpc_end2end_test_config config) { + int i; + + for (i = 1; i < 10; i++) { + test_pingpong_streaming(config, i); + } +} + +void ping_pong_streaming_pre_init(void) {} diff --git a/test/core/end2end/tests/proxy_auth.cc b/test/core/end2end/tests/proxy_auth.cc new file mode 100644 index 00000000..738387aa --- /dev/null +++ b/test/core/end2end/tests/proxy_auth.cc @@ -0,0 +1,233 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/** + * This test is for checking whether proxy authentication is working with HTTP + * Connect. 
+ */ +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/fixtures/http_proxy_fixture.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + 
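/* test_invoke_proxy_auth() below fills a grpc_arg by hand to pass the
 * user:password pair expected by the HTTP proxy fixture. An equivalent sketch
 * using the core convenience helper grpc_channel_arg_string_create() (declared
 * in the internal channel_args.h header, if that header is pulled in;
 * hand-filling the struct as done below behaves the same):
 *
 *   grpc_arg client_arg = grpc_channel_arg_string_create(
 *       const_cast<char*>(GRPC_ARG_HTTP_PROXY_AUTH_CREDS),
 *       const_cast<char*>(GRPC_TEST_HTTP_PROXY_AUTH_CREDS));
 *   grpc_channel_args client_args = {1, &client_arg};
 */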
op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_proxy_auth(grpc_end2end_test_config config) { + /* Indicate that the proxy requires user auth */ + grpc_arg client_arg; + client_arg.type = GRPC_ARG_STRING; + client_arg.key = const_cast(GRPC_ARG_HTTP_PROXY_AUTH_CREDS); + client_arg.value.string = const_cast(GRPC_TEST_HTTP_PROXY_AUTH_CREDS); + grpc_channel_args client_args = {1, &client_arg}; + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_proxy_auth", &client_args, nullptr); + simple_request_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +void proxy_auth(grpc_end2end_test_config config) { + test_invoke_proxy_auth(config); +} + +void proxy_auth_pre_init(void) {} diff --git a/test/core/end2end/tests/registered_call.cc b/test/core/end2end/tests/registered_call.cc new file mode 100644 index 00000000..b7c0258f --- /dev/null +++ b/test/core/end2end/tests/registered_call.cc @@ -0,0 +1,222 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f, void* rc) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_registered_call( + f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, rc, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = 
&initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + void* rc = grpc_channel_register_call(f.client, "/foo", nullptr, nullptr); + + simple_request_body(config, f, rc); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_10_simple_requests(grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_10_simple_requests", nullptr, nullptr); + void* rc = grpc_channel_register_call(f.client, "/foo", nullptr, nullptr); + + for (i = 0; i < 10; i++) { + simple_request_body(config, f, rc); + gpr_log(GPR_INFO, "Passed simple request %d", i); + } + end_test(&f); + config.tear_down_data(&f); +} + +void registered_call(grpc_end2end_test_config config) { + test_invoke_simple_request(config); + test_invoke_10_simple_requests(config); +} + +void registered_call_pre_init(void) {} diff --git a/test/core/end2end/tests/request_with_flags.cc b/test/core/end2end/tests/request_with_flags.cc new file mode 100644 index 00000000..3f42b609 --- /dev/null +++ b/test/core/end2end/tests/request_with_flags.cc @@ -0,0 +1,205 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/transport/byte_stream.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec one_second_from_now(void) { return n_seconds_from_now(1); } + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, one_second_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_invoke_request_with_flags( + grpc_end2end_test_config config, uint32_t* flags_for_op, + grpc_call_error call_start_batch_expected_result) { + grpc_call* c; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_request_with_flags", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + grpc_call_error expectation; + + gpr_timespec deadline = one_second_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + 
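/* The flags_for_op[] array used below assigns one flag word per op type. For
 * GRPC_OP_SEND_MESSAGE the write flags defined in grpc_types.h are valid, for
 * example (sketch):
 *
 *   op->op = GRPC_OP_SEND_MESSAGE;
 *   op->data.send_message.send_message = request_payload;
 *   op->flags = GRPC_WRITE_NO_COMPRESS;  // or GRPC_WRITE_BUFFER_HINT
 *
 * Any bit outside the documented write-flag set makes grpc_call_start_batch()
 * return GRPC_CALL_ERROR_INVALID_FLAGS, which is what the 0xDEADBEEF cases in
 * request_with_flags() below exercise.
 */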
grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = flags_for_op[op->op]; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = flags_for_op[op->op]; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = flags_for_op[op->op]; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = flags_for_op[op->op]; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = flags_for_op[op->op]; + op->reserved = nullptr; + op++; + expectation = call_start_batch_expected_result; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(expectation == error); + + if (expectation == GRPC_CALL_OK) { + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + grpc_slice_unref(details); + } + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void request_with_flags(grpc_end2end_test_config config) { + size_t i; + uint32_t flags_for_op[GRPC_OP_RECV_CLOSE_ON_SERVER + 1]; + + { + /* check that all grpc_op_types fail when their flag value is set to an + * invalid value */ + int indices[] = {GRPC_OP_SEND_INITIAL_METADATA, GRPC_OP_SEND_MESSAGE, + GRPC_OP_SEND_CLOSE_FROM_CLIENT, + GRPC_OP_RECV_INITIAL_METADATA, + GRPC_OP_RECV_STATUS_ON_CLIENT}; + for (i = 0; i < GPR_ARRAY_SIZE(indices); ++i) { + memset(flags_for_op, 0, sizeof(flags_for_op)); + flags_for_op[indices[i]] = 0xDEADBEEF; + test_invoke_request_with_flags(config, flags_for_op, + GRPC_CALL_ERROR_INVALID_FLAGS); + } + } + { + /* check valid operation with allowed flags for GRPC_OP_SEND_BUFFER */ + uint32_t flags[] = {GRPC_WRITE_BUFFER_HINT, GRPC_WRITE_NO_COMPRESS, + GRPC_WRITE_INTERNAL_COMPRESS}; + for (i = 0; i < GPR_ARRAY_SIZE(flags); ++i) { + memset(flags_for_op, 0, sizeof(flags_for_op)); + flags_for_op[GRPC_OP_SEND_MESSAGE] = flags[i]; + test_invoke_request_with_flags(config, flags_for_op, GRPC_CALL_OK); + } + } +} + +void request_with_flags_pre_init(void) {} diff --git a/test/core/end2end/tests/request_with_payload.cc b/test/core/end2end/tests/request_with_payload.cc new file mode 100644 index 00000000..1a655c1b --- /dev/null +++ b/test/core/end2end/tests/request_with_payload.cc @@ -0,0 +1,228 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Client sends a request with payload, server reads then returns status. 
*/ +static void test_invoke_request_with_payload(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_request_with_payload", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), 
tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void request_with_payload(grpc_end2end_test_config config) { + test_invoke_request_with_payload(config); +} + +void request_with_payload_pre_init(void) {} diff --git a/test/core/end2end/tests/resource_quota_server.cc b/test/core/end2end/tests/resource_quota_server.cc new file mode 100644 index 00000000..0e5e50d6 --- /dev/null +++ b/test/core/end2end/tests/resource_quota_server.cc @@ -0,0 +1,380 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Creates and returns a grpc_slice containing random alphanumeric characters. + */ +static grpc_slice generate_random_slice() { + size_t i; + static const char chars[] = "abcdefghijklmnopqrstuvwxyz1234567890"; + char* output; + const size_t output_size = 1024 * 1024; + output = static_cast(gpr_malloc(output_size)); + for (i = 0; i < output_size - 1; ++i) { + output[i] = chars[rand() % static_cast(sizeof(chars) - 1)]; + } + output[output_size - 1] = '\0'; + grpc_slice out = grpc_slice_from_copied_string(output); + gpr_free(output); + return out; +} + +void resource_quota_server(grpc_end2end_test_config config) { + if (config.feature_mask & + FEATURE_MASK_DOES_NOT_SUPPORT_RESOURCE_QUOTA_SERVER) { + return; + } + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("test_server"); + grpc_resource_quota_resize(resource_quota, 5 * 1024 * 1024); + +#define NUM_CALLS 100 +#define CLIENT_BASE_TAG 0x1000 +#define SERVER_START_BASE_TAG 0x2000 +#define SERVER_RECV_BASE_TAG 0x3000 +#define SERVER_END_BASE_TAG 0x4000 + + grpc_arg arg; + arg.key = const_cast(GRPC_ARG_RESOURCE_QUOTA); + arg.type = GRPC_ARG_POINTER; + arg.value.pointer.p = resource_quota; + arg.value.pointer.vtable = grpc_resource_quota_arg_vtable(); + grpc_channel_args args = {1, &arg}; + + grpc_end2end_test_fixture f = + begin_test(config, "resource_quota_server", nullptr, &args); + + /* Create large request and response bodies. These are big enough to require + * multiple round trips to deliver to the peer, and their exact contents of + * will be verified on completion. 
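+ * (Each payload is 1 MB of random data while the server's resource quota
+ * is capped at 5 MB, so not all of the 100 concurrent calls can be buffered
+ * at once; calls that get rejected finish with RESOURCE_EXHAUSTED,
+ * DEADLINE_EXCEEDED, or UNAVAILABLE and are tallied in the loop below.)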
*/ + grpc_slice request_payload_slice = generate_random_slice(); + + grpc_call** client_calls = + static_cast(malloc(sizeof(grpc_call*) * NUM_CALLS)); + grpc_call** server_calls = + static_cast(malloc(sizeof(grpc_call*) * NUM_CALLS)); + grpc_metadata_array* initial_metadata_recv = + static_cast( + malloc(sizeof(grpc_metadata_array) * NUM_CALLS)); + grpc_metadata_array* trailing_metadata_recv = + static_cast( + malloc(sizeof(grpc_metadata_array) * NUM_CALLS)); + grpc_metadata_array* request_metadata_recv = + static_cast( + malloc(sizeof(grpc_metadata_array) * NUM_CALLS)); + grpc_call_details* call_details = static_cast( + malloc(sizeof(grpc_call_details) * NUM_CALLS)); + grpc_status_code* status = static_cast( + malloc(sizeof(grpc_status_code) * NUM_CALLS)); + grpc_slice* details = + static_cast(malloc(sizeof(grpc_slice) * NUM_CALLS)); + grpc_byte_buffer** request_payload = static_cast( + malloc(sizeof(grpc_byte_buffer*) * NUM_CALLS)); + grpc_byte_buffer** request_payload_recv = static_cast( + malloc(sizeof(grpc_byte_buffer*) * NUM_CALLS)); + int* was_cancelled = static_cast(malloc(sizeof(int) * NUM_CALLS)); + grpc_call_error error; + int pending_client_calls = 0; + int pending_server_start_calls = 0; + int pending_server_recv_calls = 0; + int pending_server_end_calls = 0; + int cancelled_calls_on_client = 0; + int cancelled_calls_on_server = 0; + int deadline_exceeded = 0; + int unavailable = 0; + + grpc_op ops[6]; + grpc_op* op; + + for (int i = 0; i < NUM_CALLS; i++) { + grpc_metadata_array_init(&initial_metadata_recv[i]); + grpc_metadata_array_init(&trailing_metadata_recv[i]); + grpc_metadata_array_init(&request_metadata_recv[i]); + grpc_call_details_init(&call_details[i]); + request_payload[i] = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + request_payload_recv[i] = nullptr; + was_cancelled[i] = 0; + } + + for (int i = 0; i < NUM_CALLS; i++) { + error = grpc_server_request_call( + f.server, &server_calls[i], &call_details[i], &request_metadata_recv[i], + f.cq, f.cq, tag(SERVER_START_BASE_TAG + i)); + GPR_ASSERT(GRPC_CALL_OK == error); + + pending_server_start_calls++; + } + + for (int i = 0; i < NUM_CALLS; i++) { + client_calls[i] = + grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f.cq, grpc_slice_from_static_string("/foo"), + nullptr, n_seconds_from_now(60), nullptr); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload[i]; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv[i]; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = + &trailing_metadata_recv[i]; + op->data.recv_status_on_client.status = &status[i]; + op->data.recv_status_on_client.status_details = &details[i]; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(client_calls[i], ops, + static_cast(op - ops), + tag(CLIENT_BASE_TAG + i), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + pending_client_calls++; + } + + while (pending_client_calls + pending_server_recv_calls + + pending_server_end_calls > + 
0) { + grpc_event ev = + grpc_completion_queue_next(f.cq, n_seconds_from_now(60), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + + int ev_tag = static_cast(reinterpret_cast(ev.tag)); + if (ev_tag < CLIENT_BASE_TAG) { + abort(); /* illegal tag */ + } else if (ev_tag < SERVER_START_BASE_TAG) { + /* client call finished */ + int call_id = ev_tag - CLIENT_BASE_TAG; + GPR_ASSERT(call_id >= 0); + GPR_ASSERT(call_id < NUM_CALLS); + switch (status[call_id]) { + case GRPC_STATUS_RESOURCE_EXHAUSTED: + cancelled_calls_on_client++; + break; + case GRPC_STATUS_DEADLINE_EXCEEDED: + deadline_exceeded++; + break; + case GRPC_STATUS_UNAVAILABLE: + unavailable++; + break; + case GRPC_STATUS_OK: + break; + default: + gpr_log(GPR_ERROR, "Unexpected status code: %d", status[call_id]); + abort(); + } + GPR_ASSERT(pending_client_calls > 0); + + grpc_metadata_array_destroy(&initial_metadata_recv[call_id]); + grpc_metadata_array_destroy(&trailing_metadata_recv[call_id]); + grpc_call_unref(client_calls[call_id]); + grpc_slice_unref(details[call_id]); + grpc_byte_buffer_destroy(request_payload[call_id]); + + pending_client_calls--; + } else if (ev_tag < SERVER_RECV_BASE_TAG) { + /* new incoming call to the server */ + int call_id = ev_tag - SERVER_START_BASE_TAG; + GPR_ASSERT(call_id >= 0); + GPR_ASSERT(call_id < NUM_CALLS); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv[call_id]; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch( + server_calls[call_id], ops, static_cast(op - ops), + tag(SERVER_RECV_BASE_TAG + call_id), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(pending_server_start_calls > 0); + pending_server_start_calls--; + pending_server_recv_calls++; + + grpc_call_details_destroy(&call_details[call_id]); + grpc_metadata_array_destroy(&request_metadata_recv[call_id]); + } else if (ev_tag < SERVER_END_BASE_TAG) { + /* finished read on the server */ + int call_id = ev_tag - SERVER_RECV_BASE_TAG; + GPR_ASSERT(call_id >= 0); + GPR_ASSERT(call_id < NUM_CALLS); + + if (ev.success) { + if (request_payload_recv[call_id] != nullptr) { + grpc_byte_buffer_destroy(request_payload_recv[call_id]); + request_payload_recv[call_id] = nullptr; + } + } else { + GPR_ASSERT(request_payload_recv[call_id] == nullptr); + } + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled[call_id]; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch( + server_calls[call_id], ops, static_cast(op - ops), + tag(SERVER_END_BASE_TAG + call_id), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(pending_server_recv_calls > 0); + pending_server_recv_calls--; + pending_server_end_calls++; + } else { + int call_id = ev_tag - SERVER_END_BASE_TAG; + GPR_ASSERT(call_id >= 0); + GPR_ASSERT(call_id < NUM_CALLS); + + if (was_cancelled[call_id]) { + cancelled_calls_on_server++; + } + 
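+ // This completion is for the status-sending batch (SERVER_END tag range),
+ // so the server side of this call is fully done and can be released.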
GPR_ASSERT(pending_server_end_calls > 0); + pending_server_end_calls--; + + grpc_call_unref(server_calls[call_id]); + } + } + + gpr_log(GPR_INFO, + "Done. %d total calls: %d cancelled at server, %d cancelled at " + "client, %d timed out, %d unavailable.", + NUM_CALLS, cancelled_calls_on_server, cancelled_calls_on_client, + deadline_exceeded, unavailable); + + grpc_slice_unref(request_payload_slice); + grpc_resource_quota_unref(resource_quota); + + end_test(&f); + config.tear_down_data(&f); + + free(client_calls); + free(server_calls); + free(initial_metadata_recv); + free(trailing_metadata_recv); + free(request_metadata_recv); + free(call_details); + free(status); + free(details); + free(request_payload); + free(request_payload_recv); + free(was_cancelled); +} + +void resource_quota_server_pre_init(void) {} diff --git a/test/core/end2end/tests/retry.cc b/test/core/end2end/tests/retry.cc new file mode 100644 index 00000000..68b22ac4 --- /dev/null +++ b/test/core/end2end/tests/retry.cc @@ -0,0 +1,323 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); 
+ shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests a basic retry scenario: +// - 2 retries allowed for ABORTED status +// - first attempt returns ABORTED +// - second attempt returns OK +static void test_retry(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + 
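+ // Wait for the server to pick up the first attempt.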
CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + // Make sure the "grpc-previous-rpc-attempts" header was not sent in the + // initial attempt. + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + GPR_ASSERT(!grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)); + } + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + // Make sure the "grpc-previous-rpc-attempts" header was sent in the retry. 
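+ // Each retry attempt carries a "grpc-previous-rpc-attempts" request header
+ // whose value is the number of preceding attempts, so this first retry is
+ // expected to carry the value "1" (hence the GRPC_MDSTR_1 comparison
+ // below).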
+ bool found_retry_header = false; + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + if (grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)) { + GPR_ASSERT( + grpc_slice_eq(request_metadata_recv.metadata[i].value, GRPC_MDSTR_1)); + found_retry_header = true; + break; + } + } + GPR_ASSERT(found_retry_header); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry(config); +} + +void retry_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_cancel_during_delay.cc b/test/core/end2end/tests/retry_cancel_during_delay.cc new file mode 100644 index 00000000..b0cafab0 --- /dev/null +++ b/test/core/end2end/tests/retry_cancel_during_delay.cc @@ -0,0 +1,284 @@ +// +// Copyright 2017 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests retry cancellation during backoff. 
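+// - 3 attempts allowed for ABORTED status, with a 10s initial backoff and a
+//   5s call timeout
+// - first attempt returns ABORTED
+// - the retry delay outlives the deadline, so the client cancels while the
+//   call is still sitting in backoff and no second attempt is ever started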
+static void test_retry_cancel_during_delay(grpc_end2end_test_config config, + cancellation_mode mode) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"10s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " },\n" + " \"timeout\": \"5s\"\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + std::string name = absl::StrCat("retry_cancel_during_delay/", mode.name); + grpc_end2end_test_fixture f = + begin_test(config, name.c_str(), &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec expect_finish_before = n_seconds_from_now(10); + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + // Client starts a batch with all 6 ops. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Server gets a call and fails with retryable status. 
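+ // (ABORTED is listed in retryableStatusCodes above, but with the 10s
+ // initial backoff the next attempt never starts before the cancellation
+ // below.)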
+ error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + // Server should never get a second call, because the initial retry + // delay is longer than the call's deadline. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + cq_verify_empty(cqv); + + // Initiate cancellation. + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + // Make sure we didn't wait the full deadline before failing. + gpr_log( + GPR_INFO, "Expect completion before: %s", + absl::FormatTime(grpc_core::ToAbslTime(expect_finish_before)).c_str()); + GPR_ASSERT(gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), expect_finish_before) < + 0); + + gpr_log(GPR_INFO, "status=%d expected=%d", status, mode.expect_status); + GPR_ASSERT(status == mode.expect_status); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_cancel_during_delay(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + for (size_t i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); ++i) { + test_retry_cancel_during_delay(config, cancellation_modes[i]); + } +} + +void retry_cancel_during_delay_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_cancel_with_multiple_send_batches.cc b/test/core/end2end/tests/retry_cancel_with_multiple_send_batches.cc new file mode 100644 index 00000000..bcbe7929 --- /dev/null +++ b/test/core/end2end/tests/retry_cancel_with_multiple_send_batches.cc @@ -0,0 +1,339 @@ +/* + * + * Copyright 2017 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/surface/channel_init.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests cancellation with multiple send op batches. 
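+// - the client issues send_initial_metadata, send_message, and
+//   send_close_from_client as three separate batches, plus one recv batch
+// - a filter installed at the subchannel level fails every send batch with
+//   ABORTED, and the call is cancelled while those batches are pending
+// - the three send batches are expected to fail and the recv batch to
+//   complete with the cancellation status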
+static void test_retry_cancel_with_multiple_send_batches( + grpc_end2end_test_config config, cancellation_mode mode) { + grpc_call* c; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + char* peer; + + std::string service_config_string = absl::StrFormat( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"%ds\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}", + 5 * grpc_test_slowdown_factor()); + + grpc_arg args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_RETRIES), 1), + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast(service_config_string.c_str())), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + std::string name = + absl::StrCat("retry_cancel_with_multiple_send_batches/", mode.name); + grpc_end2end_test_fixture f = + begin_test(config, name.c_str(), &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = n_seconds_from_now(3); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + // Start a batch containing send_initial_metadata. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Start a batch containing send_message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Start a batch containing send_trailing_metadata. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Start a batch containing recv ops. 
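+ // (recv_initial_metadata, recv_message, and recv_status_on_client are
+ // grouped into a single batch, tag(4), which completes once the call is
+ // cancelled below.)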
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Initiate cancellation. + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + // Client ops should now complete. + CQ_EXPECT_COMPLETION(cqv, tag(1), false); + CQ_EXPECT_COMPLETION(cqv, tag(2), false); + CQ_EXPECT_COMPLETION(cqv, tag(3), false); + CQ_EXPECT_COMPLETION(cqv, tag(4), true); + cq_verify(cqv); + + gpr_log(GPR_INFO, "status=%d expected=%d", status, mode.expect_status); + GPR_ASSERT(status == mode.expect_status); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +namespace { + +// A filter that fails all batches with send ops. +class FailSendOpsFilter { + public: + static grpc_channel_filter kFilterVtable; + + public: + class CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args) { + new (elem->call_data) CallData(args); + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + auto* calld = static_cast(elem->call_data); + calld->~CallData(); + } + + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + auto* calld = static_cast(elem->call_data); + if (batch->send_initial_metadata || batch->send_message || + batch->send_trailing_metadata) { + grpc_transport_stream_op_batch_finish_with_failure( + batch, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "FailSendOpsFilter failing batch"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_ABORTED), + calld->call_combiner_); + return; + } + grpc_call_next_op(elem, batch); + } + + private: + explicit CallData(const grpc_call_element_args* args) + : call_combiner_(args->call_combiner) {} + + grpc_core::CallCombiner* call_combiner_; + }; + + static grpc_error_handle Init(grpc_channel_element* elem, + grpc_channel_element_args* /*args*/) { + new (elem->channel_data) FailSendOpsFilter(); + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_channel_element* elem) { + auto* chand = static_cast(elem->channel_data); + chand->~FailSendOpsFilter(); + } +}; + +grpc_channel_filter FailSendOpsFilter::kFilterVtable = { + CallData::StartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(CallData), + CallData::Init, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + CallData::Destroy, + sizeof(FailSendOpsFilter), + Init, + Destroy, + grpc_channel_next_get_info, + "FailSendOpsFilter", +}; + +bool MaybeAddFilter(grpc_channel_stack_builder* builder) { + // Skip on proxy (which explicitly disables retries). 
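+ // (If GRPC_ARG_ENABLE_RETRIES is false on the channel being built, there is
+ // nothing to exercise, so the stage succeeds without installing the filter.)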
+ const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (!grpc_channel_args_find_bool(args, GRPC_ARG_ENABLE_RETRIES, true)) { + return true; + } + // Install filter. + return grpc_channel_stack_builder_prepend_filter( + builder, &FailSendOpsFilter::kFilterVtable, nullptr, nullptr); +} + +} // namespace + +void retry_cancel_with_multiple_send_batches(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, 0, + MaybeAddFilter); + }, + [config]() { + for (size_t i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); ++i) { + test_retry_cancel_with_multiple_send_batches(config, + cancellation_modes[i]); + } + }); +} + +void retry_cancel_with_multiple_send_batches_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_cancellation.cc b/test/core/end2end/tests/retry_cancellation.cc new file mode 100644 index 00000000..64108c8b --- /dev/null +++ b/test/core/end2end/tests/retry_cancellation.cc @@ -0,0 +1,278 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests retry cancellation. 
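+// - 3 attempts allowed for ABORTED status with a 1s initial backoff
+// - first attempt returns ABORTED and is transparently retried
+// - the second attempt reaches the server and the client then cancels it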
+static void test_retry_cancellation(grpc_end2end_test_config config, + cancellation_mode mode) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " },\n" + " \"timeout\": \"5s\"\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + std::string name = absl::StrCat("retry_cancellation/", mode.name); + grpc_end2end_test_fixture f = + begin_test(config, name.c_str(), &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + // Client starts a batch with all 6 ops. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Server gets a call and fails with retryable status. 
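+ // (ABORTED is retryable under the service config above, so the client
+ // transparently starts a second attempt, which the server picks up as
+ // tag(201) below.)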
+ error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + // Server gets a second call (the retry). + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + // Initiate cancellation. + GPR_ASSERT(GRPC_CALL_OK == mode.initiate_cancel(c, nullptr)); + + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == mode.expect_status); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_cancellation(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + for (size_t i = 0; i < GPR_ARRAY_SIZE(cancellation_modes); ++i) { + test_retry_cancellation(config, cancellation_modes[i]); + } +} + +void retry_cancellation_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_disabled.cc b/test/core/end2end/tests/retry_disabled.cc new file mode 100644 index 00000000..dc1d99f2 --- /dev/null +++ b/test/core/end2end/tests/retry_disabled.cc @@ -0,0 +1,258 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we don't retry when retries are disabled via the +// GRPC_ARG_ENABLE_RETRIES channel arg, even when there is retry +// configuration in the service config. 
+// - 1 retry allowed for ABORTED status +// - first attempt returns ABORTED but does not retry +static void test_retry_disabled(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_RETRIES), 0), + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_disabled", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, 
"server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_disabled(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_disabled(config); +} + +void retry_disabled_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_exceeds_buffer_size_in_delay.cc b/test/core/end2end/tests/retry_exceeds_buffer_size_in_delay.cc new file mode 100644 index 00000000..74671136 --- /dev/null +++ b/test/core/end2end/tests/retry_exceeds_buffer_size_in_delay.cc @@ -0,0 +1,315 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests the case where the retry buffer size is exceeded during backoff. 
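+// To support transparent replay, a retryable call buffers the ops it has
+// sent; once the buffered data exceeds GRPC_ARG_PER_RPC_RETRY_BUFFER_SIZE,
+// the call is committed and no further retries are attempted. Here the limit
+// is crossed while the client is waiting out the retry backoff delay.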
+// - 1 retry allowed for ABORTED status +// - buffer size set to 100 KiB (larger than initial metadata) +// - client initially sends initial metadata (smaller than buffer size) +// - server sends ABORTED, client goes into backoff delay +// - client sends a 100 KiB message, thus exceeding the buffer size limit +// - retry attempt gets ABORTED but is not retried +static void test_retry_exceeds_buffer_size_in_delay( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + const size_t buf_size = 102401; + char* buf = static_cast(gpr_malloc(buf_size * sizeof(*buf))); + memset(buf, 'a', buf_size - 1); + buf[buf_size - 1] = '\0'; + grpc_slice request_payload_slice = grpc_slice_from_static_string(buf); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"2s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_PER_RPC_RETRY_BUFFER_SIZE), 102400), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = begin_test( + config, "retry_exceeds_buffer_size_in_delay", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = grpc_timeout_milliseconds_to_deadline(15000); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + // Client sends initial metadata and starts the recv ops. 
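+  // The send_message op is deliberately left out of this first batch so that
+  // the data buffered for retries stays under the 100 KiB limit until after
+  // the first attempt has failed.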
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Server gets a call. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server sends ABORTED. This tells the client to retry. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + // Do a bit more polling, to make sure the client sees status from the + // first attempt. (Note: This polls for 1s, which is less than the + // retry initial backoff time of 2s from the service config above.) + cq_verify_empty(cqv); + + // Client sends a message that puts it over the buffer size limit. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + // Server gets another call. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + // Server again sends ABORTED. But this time, the client won't retry, + // since the call has been committed by exceeding the buffer size. 
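+  // Even though the service config would allow another attempt, the committed
+  // call surfaces this ABORTED status to the client as the final status of
+  // the RPC.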
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); + gpr_free(buf); +} + +void retry_exceeds_buffer_size_in_delay(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_exceeds_buffer_size_in_delay(config); +} + +void retry_exceeds_buffer_size_in_delay_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_exceeds_buffer_size_in_initial_batch.cc b/test/core/end2end/tests/retry_exceeds_buffer_size_in_initial_batch.cc new file mode 100644 index 00000000..f20bc693 --- /dev/null +++ b/test/core/end2end/tests/retry_exceeds_buffer_size_in_initial_batch.cc @@ -0,0 +1,262 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we don't make any further attempts after we exceed the +// max buffer size. 
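+// With the buffer limit set to only 2 bytes, the very first batch commits the
+// call, so even a retryable ABORTED status cannot trigger a retry.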
+// - 1 retry allowed for ABORTED status +// - buffer size set to 2 bytes +// - client sends a 3-byte message +// - first attempt gets ABORTED but is not retried +static void test_retry_exceeds_buffer_size_in_initial_batch( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_PER_RPC_RETRY_BUFFER_SIZE), 2), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_exceeds_buffer_size_in_initial_batch", + &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + 
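+  // The call was already committed by the first batch (its 3-byte payload
+  // exceeds the 2-byte buffer), so the server sees only this single attempt.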
CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_exceeds_buffer_size_in_initial_batch( + grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_exceeds_buffer_size_in_initial_batch(config); +} + +void retry_exceeds_buffer_size_in_initial_batch_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_exceeds_buffer_size_in_subsequent_batch.cc b/test/core/end2end/tests/retry_exceeds_buffer_size_in_subsequent_batch.cc new file mode 100644 index 00000000..f4630fdf --- /dev/null +++ b/test/core/end2end/tests/retry_exceeds_buffer_size_in_subsequent_batch.cc @@ -0,0 +1,276 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Similar to the retry_exceeds_buffer_size_in_initial_batch test, but we +// don't exceed the buffer size until the second batch. 
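+// The initial metadata batch fits within the 100 KiB buffer, so the call is
+// not committed until the second batch carrying the large message is started;
+// when the first attempt then fails with ABORTED, it can no longer be
+// retried.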
+// - 1 retry allowed for ABORTED status +// - buffer size set to 100 KiB (larger than initial metadata) +// - client sends a 100 KiB message +// - first attempt gets ABORTED but is not retried +static void test_retry_exceeds_buffer_size_in_subsequent_batch( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + const size_t buf_size = 102401; + char* buf = static_cast(gpr_malloc(buf_size * sizeof(*buf))); + memset(buf, 'a', buf_size - 1); + buf[buf_size - 1] = '\0'; + // TODO(markdroth): buf is not a static string, so fix the next line + grpc_slice request_payload_slice = grpc_slice_from_static_string(buf); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_PER_RPC_RETRY_BUFFER_SIZE), 102400), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_exceeds_buffer_size_in_subsequent_batch", + &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = 
GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); + gpr_free(buf); +} + +void retry_exceeds_buffer_size_in_subsequent_batch( + grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_exceeds_buffer_size_in_subsequent_batch(config); +} + +void retry_exceeds_buffer_size_in_subsequent_batch_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_lb_drop.cc b/test/core/end2end/tests/retry_lb_drop.cc new file mode 100644 index 00000000..4e0dc323 --- /dev/null +++ b/test/core/end2end/tests/retry_lb_drop.cc @@ -0,0 +1,273 @@ +// +// Copyright 2017 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" +#include "test/core/util/test_lb_policies.h" + +namespace grpc_core { +namespace { + +const char* kDropPolicyName = "drop_lb"; + +class DropPolicy : public LoadBalancingPolicy { + public: + explicit DropPolicy(Args args) : LoadBalancingPolicy(std::move(args)) {} + + const char* name() const override { return kDropPolicyName; } + + void UpdateLocked(UpdateArgs) override { + channel_control_helper()->UpdateState(GRPC_CHANNEL_READY, absl::Status(), + absl::make_unique()); + } + + void ResetBackoffLocked() override {} + void ShutdownLocked() override {} + + private: + class DropPicker : public SubchannelPicker { + public: + PickResult Pick(PickArgs /*args*/) override { + return PickResult::Drop( + absl::UnavailableError("Call dropped by drop LB policy")); + } + }; +}; + +class DropLbConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kDropPolicyName; } +}; + +class DropPolicyFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kDropPolicyName; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error_handle* /*error*/) const override { + return MakeRefCounted(); + } +}; + +std::vector* g_pick_args_vector = nullptr; + +void RegisterDropPolicy() { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique()); + RegisterTestPickArgsLoadBalancingPolicy( + [](const PickArgsSeen& pick_args) { + GPR_ASSERT(g_pick_args_vector != nullptr); + g_pick_args_vector->push_back(pick_args); + }, + kDropPolicyName); +} + +} // namespace +} // namespace grpc_core + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void 
shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we don't retry when the LB policy drops a call, +// even when there is retry configuration in the service config. +// - 1 retry allowed for UNAVAILABLE status +// - first attempt returns UNAVAILABLE due to LB drop but does not retry +static void test_retry_lb_drop(grpc_end2end_test_config config) { + grpc_call* c; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + std::vector pick_args_seen; + grpc_core::g_pick_args_vector = &pick_args_seen; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"loadBalancingConfig\": [ {\n" + " \"test_pick_args_lb\": {}\n" + " } ],\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"UNAVAILABLE\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_lb_drop", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + GPR_ASSERT(0 == + grpc_slice_str_cmp(details, "Call dropped by drop LB policy")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + 
grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + gpr_log(GPR_INFO, "NUMBER OF LB PICKS: %" PRIuPTR, pick_args_seen.size()); + GPR_ASSERT(pick_args_seen.size() == 1); + + grpc_core::g_pick_args_vector = nullptr; + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_lb_drop(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_lb_drop(config); +} + +void retry_lb_drop_pre_init(void) { grpc_core::RegisterDropPolicy(); } diff --git a/test/core/end2end/tests/retry_lb_fail.cc b/test/core/end2end/tests/retry_lb_fail.cc new file mode 100644 index 00000000..e97567d5 --- /dev/null +++ b/test/core/end2end/tests/retry_lb_fail.cc @@ -0,0 +1,273 @@ +// +// Copyright 2017 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +namespace grpc_core { +namespace { + +const char* kFailPolicyName = "fail_lb"; + +std::atomic g_num_lb_picks; + +class FailPolicy : public LoadBalancingPolicy { + public: + explicit FailPolicy(Args args) : LoadBalancingPolicy(std::move(args)) {} + + const char* name() const override { return kFailPolicyName; } + + void UpdateLocked(UpdateArgs) override { + absl::Status status = absl::AbortedError("LB pick failed"); + channel_control_helper()->UpdateState( + GRPC_CHANNEL_TRANSIENT_FAILURE, status, + absl::make_unique(status)); + } + + void ResetBackoffLocked() override {} + void ShutdownLocked() override {} + + private: + class FailPicker : public SubchannelPicker { + public: + explicit FailPicker(absl::Status status) : status_(status) {} + + PickResult Pick(PickArgs /*args*/) override { + g_num_lb_picks.fetch_add(1); + return PickResult::Fail(status_); + } + + private: + absl::Status status_; + }; +}; + +class FailLbConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kFailPolicyName; } +}; + +class FailPolicyFactory : public LoadBalancingPolicyFactory { + public: + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kFailPolicyName; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error_handle* /*error*/) const override { + return MakeRefCounted(); + } +}; + +void 
RegisterFailPolicy() { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +} // namespace +} // namespace grpc_core + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we retry properly when the LB policy fails the call before +// it ever gets to the transport, even if recv_trailing_metadata isn't +// started by the application until after the LB pick fails. 
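+// The fail_lb policy registered above fails every pick with ABORTED, which is
+// retryable under the service config below, so the client performs exactly
+// one retry and the test expects two LB picks in total.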
+// - 1 retry allowed for ABORTED status +// - on first attempt, LB policy fails with ABORTED before application +// starts recv_trailing_metadata op +static void test_retry_lb_fail(grpc_end2end_test_config config) { + grpc_call* c; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + grpc_core::g_num_lb_picks.store(0, std::memory_order_relaxed); + + grpc_arg args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_RETRIES), 1), + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"loadBalancingConfig\": [ {\n" + " \"fail_lb\": {}\n" + " } ],\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_lb_fail", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), false); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "LB pick failed")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + + cq_verifier_destroy(cqv); + + int num_picks = grpc_core::g_num_lb_picks.load(std::memory_order_relaxed); + gpr_log(GPR_INFO, "NUM LB PICKS: %d", num_picks); + GPR_ASSERT(num_picks == 2); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_lb_fail(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_lb_fail(config); +} + +void retry_lb_fail_pre_init(void) { grpc_core::RegisterFailPolicy(); } diff --git 
a/test/core/end2end/tests/retry_non_retriable_status.cc b/test/core/end2end/tests/retry_non_retriable_status.cc new file mode 100644 index 00000000..77710736 --- /dev/null +++ b/test/core/end2end/tests/retry_non_retriable_status.cc @@ -0,0 +1,254 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we don't retry for non-retryable status codes. 
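+// Only the status codes listed in retryableStatusCodes trigger a retry; the
+// server returns INVALID_ARGUMENT, which is not in that list, so the status
+// from the first attempt is passed straight through to the application.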
+// - 1 retry allowed for ABORTED status +// - first attempt gets INVALID_ARGUMENT, so no retry is done +static void test_retry_non_retriable_status(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_non_retriable_status", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer 
= grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_INVALID_ARGUMENT; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_INVALID_ARGUMENT); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_non_retriable_status(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_non_retriable_status(config); +} + +void retry_non_retriable_status_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_non_retriable_status_before_recv_trailing_metadata_started.cc b/test/core/end2end/tests/retry_non_retriable_status_before_recv_trailing_metadata_started.cc new file mode 100644 index 00000000..657c206a --- /dev/null +++ b/test/core/end2end/tests/retry_non_retriable_status_before_recv_trailing_metadata_started.cc @@ -0,0 +1,270 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we don't retry for non-retryable status codes, even if +// status is received before the recv_trailing_metadata op is started. 
+// - 1 retry allowed for ABORTED status +// - first attempt gets INVALID_ARGUMENT, so no retry is done +static void +test_retry_non_retriable_status_before_recv_trailing_metadata_started( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = begin_test( + config, + "retry_non_retriable_status_before_recv_trailing_metadata_started", + &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + 
op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_INVALID_ARGUMENT; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_INVALID_ARGUMENT); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_non_retriable_status_before_recv_trailing_metadata_started( + grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_non_retriable_status_before_recv_trailing_metadata_started(config); +} + +void retry_non_retriable_status_before_recv_trailing_metadata_started_pre_init() { +} diff --git a/test/core/end2end/tests/retry_per_attempt_recv_timeout.cc b/test/core/end2end/tests/retry_per_attempt_recv_timeout.cc new file mode 100644 index 00000000..5dc1a7db --- /dev/null +++ b/test/core/end2end/tests/retry_per_attempt_recv_timeout.cc @@ -0,0 +1,345 @@ +// +// Copyright 2017 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests perAttemptRecvTimeout: +// - 2 retries allowed for ABORTED status +// - first attempt does not receive a response until after perAttemptRecvTimeout +// - second attempt returns ABORTED +// - third attempt returns OK +static void test_retry_per_attempt_recv_timeout( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_call* s0; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING), 1), + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " 
],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"perAttemptRecvTimeout\": \"2s\",\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Server gets a call but does not respond to the call. + error = + grpc_server_request_call(f.server, &s0, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + // Make sure the "grpc-previous-rpc-attempts" header was not sent in the + // initial attempt. + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + GPR_ASSERT(!grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)); + } + + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + // Server gets a second call. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + // Now we can unref the first call. + grpc_call_unref(s0); + + // Make sure the "grpc-previous-rpc-attempts" header was sent in the retry. 
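+  // (Note: per the retry design, the client attaches "grpc-previous-rpc-attempts"
+  // with the number of prior attempts, so this first retry is expected to carry
+  // the value "1" -- GRPC_MDSTR_1 below is the static-metadata slice for "1".)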
+ bool found_retry_header = false; + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + if (grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)) { + GPR_ASSERT( + grpc_slice_eq(request_metadata_recv.metadata[i].value, GRPC_MDSTR_1)); + found_retry_header = true; + break; + } + } + GPR_ASSERT(found_retry_header); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server sends status ABORTED. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + // Server gets a third call. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(301)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(301), true); + cq_verify(cqv); + + // Make sure the "grpc-previous-rpc-attempts" header was sent in the retry. + found_retry_header = false; + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + if (grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)) { + GPR_ASSERT( + grpc_slice_eq(request_metadata_recv.metadata[i].value, GRPC_MDSTR_2)); + found_retry_header = true; + break; + } + } + GPR_ASSERT(found_retry_header); + + // Server sends OK status. 
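+  // (Note: this third and final allowed attempt responds within the 2s
+  // perAttemptRecvTimeout and returns OK, so the retry code commits the
+  // attempt and the client's original batch, tag(1), completes with
+  // GRPC_STATUS_OK, as asserted below.)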
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(302), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(302), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_per_attempt_recv_timeout(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_per_attempt_recv_timeout(config); +} + +void retry_per_attempt_recv_timeout_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_per_attempt_recv_timeout_on_last_attempt.cc b/test/core/end2end/tests/retry_per_attempt_recv_timeout_on_last_attempt.cc new file mode 100644 index 00000000..572727b8 --- /dev/null +++ b/test/core/end2end/tests/retry_per_attempt_recv_timeout_on_last_attempt.cc @@ -0,0 +1,261 @@ +// +// Copyright 2017 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests perAttemptRecvTimeout: +// - 1 retry allowed for ABORTED status +// - both attempts do not receive a response until after perAttemptRecvTimeout +static void test_retry_per_attempt_recv_timeout_on_last_attempt( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_call* s0; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + + grpc_arg args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_EXPERIMENTAL_ENABLE_HEDGING), 1), + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + 
" \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"perAttemptRecvTimeout\": \"2s\",\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = n_seconds_from_now(10); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Server gets a call but does not respond to the call. + error = + grpc_server_request_call(f.server, &s0, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + // Make sure the "grpc-previous-rpc-attempts" header was not sent in the + // initial attempt. + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + GPR_ASSERT(!grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)); + } + + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + // Server gets a second call, which it also does not respond to. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + // Now we can unref the first call. + grpc_call_unref(s0); + + // Make sure the "grpc-previous-rpc-attempts" header was sent in the retry. + bool found_retry_header = false; + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + if (grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)) { + GPR_ASSERT( + grpc_slice_eq(request_metadata_recv.metadata[i].value, GRPC_MDSTR_1)); + found_retry_header = true; + break; + } + } + GPR_ASSERT(found_retry_header); + + // Client sees call completion. 
+ CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_CANCELLED); + GPR_ASSERT( + 0 == grpc_slice_str_cmp(details, "retry perAttemptRecvTimeout exceeded")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_per_attempt_recv_timeout_on_last_attempt( + grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_per_attempt_recv_timeout_on_last_attempt(config); +} + +void retry_per_attempt_recv_timeout_on_last_attempt_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_recv_initial_metadata.cc b/test/core/end2end/tests/retry_recv_initial_metadata.cc new file mode 100644 index 00000000..24f99931 --- /dev/null +++ b/test/core/end2end/tests/retry_recv_initial_metadata.cc @@ -0,0 +1,276 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that receiving initial metadata commits the call. 
+// - 1 retry allowed for ABORTED status +// - first attempt receives initial metadata before trailing metadata, +// so no retry is done even though status was ABORTED +static void test_retry_recv_initial_metadata(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + grpc_metadata initial_metadata_from_server = { + grpc_slice_from_static_string("key1"), + grpc_slice_from_static_string("val1"), + {{nullptr, nullptr, nullptr, nullptr}}}; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_recv_initial_metadata", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, 
tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server sends initial metadata in its own batch, before sending + // trailing metadata. + // Ideally, this would not require actually sending any metadata + // entries, but we do so to avoid sporadic failures in the proxy + // tests, where the proxy may wind up combining the batches, depending + // on timing. Sending a metadata entry ensures that the transport + // won't send a Trailers-Only response, even if the batches are combined. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = &initial_metadata_from_server; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_recv_initial_metadata(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_recv_initial_metadata(config); +} + +void retry_recv_initial_metadata_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_recv_message.cc b/test/core/end2end/tests/retry_recv_message.cc new file mode 100644 index 00000000..347c1485 --- /dev/null +++ b/test/core/end2end/tests/retry_recv_message.cc @@ -0,0 +1,258 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that receiving a message commits the call. 
+// - 1 retry allowed for ABORTED status +// - first attempt receives a message and therefore does not retry even +// though the final status is ABORTED +static void test_retry_recv_message(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_recv_message", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, 
"server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_recv_message(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_recv_message(config); +} + +void retry_recv_message_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_recv_trailing_metadata_error.cc b/test/core/end2end/tests/retry_recv_trailing_metadata_error.cc new file mode 100644 index 00000000..ecd49795 --- /dev/null +++ b/test/core/end2end/tests/retry_recv_trailing_metadata_error.cc @@ -0,0 +1,366 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/surface/channel_init.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we honor the error passed to recv_trailing_metadata_ready +// when determining the call's status, even if the op completion runs before +// the recv_trailing_metadata op is started from the surface. 
+// - 1 retry allowed for ABORTED status +// - server returns ABORTED, but filter overwrites to INVALID_ARGUMENT, +// so no retry is done +static void test_retry_recv_trailing_metadata_error( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = begin_test( + config, "retry_recv_trailing_metadata_error", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = 
GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_INVALID_ARGUMENT); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "injected error")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +namespace { + +// A filter that returns recv_trailing_metadata_ready with an error. 
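+// It intercepts recv_trailing_metadata on the way down, substituting its own
+// closure; when the transport completes the op, that closure forwards to the
+// original recv_trailing_metadata_ready with an error carrying
+// GRPC_STATUS_INVALID_ARGUMENT and the message "injected error". Because the
+// filter is installed in the subchannel stack (below the retry code in the
+// client channel), the retry code sees INVALID_ARGUMENT rather than the
+// server's ABORTED status and therefore does not retry.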
+class InjectStatusFilter { + public: + static grpc_channel_filter kFilterVtable; + + public: + class CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* /*args*/) { + new (elem->call_data) CallData(); + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + auto* calld = static_cast(elem->call_data); + calld->~CallData(); + } + + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + auto* calld = static_cast(elem->call_data); + if (batch->recv_trailing_metadata) { + calld->original_recv_trailing_metadata_ready_ = + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready; + batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + &calld->recv_trailing_metadata_ready_; + } + grpc_call_next_op(elem, batch); + } + + private: + CallData() { + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, + RecvTrailingMetadataReady, this, nullptr); + } + + static void RecvTrailingMetadataReady(void* arg, + grpc_error_handle /*error*/) { + auto* calld = static_cast(arg); + grpc_core::Closure::Run( + DEBUG_LOCATION, calld->original_recv_trailing_metadata_ready_, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("injected error"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_INVALID_ARGUMENT)); + } + + grpc_closure recv_trailing_metadata_ready_; + grpc_closure* original_recv_trailing_metadata_ready_ = nullptr; + }; + + static grpc_error_handle Init(grpc_channel_element* /*elem*/, + grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_channel_element* /*elem*/) {} +}; + +grpc_channel_filter InjectStatusFilter::kFilterVtable = { + CallData::StartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(CallData), + CallData::Init, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + CallData::Destroy, + 0, + Init, + Destroy, + grpc_channel_next_get_info, + "InjectStatusFilter", +}; + +bool AddFilter(grpc_channel_stack_builder* builder) { + // Skip on proxy (which explicitly disables retries). + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (!grpc_channel_args_find_bool(args, GRPC_ARG_ENABLE_RETRIES, true)) { + return true; + } + // Install filter. + return grpc_channel_stack_builder_prepend_filter( + builder, &InjectStatusFilter::kFilterVtable, nullptr, nullptr); +} + +} // namespace + +void retry_recv_trailing_metadata_error(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, 0, + AddFilter); + }, + [config] { test_retry_recv_trailing_metadata_error(config); }); +} + +void retry_recv_trailing_metadata_error_pre_init() {} diff --git a/test/core/end2end/tests/retry_send_initial_metadata_refs.cc b/test/core/end2end/tests/retry_send_initial_metadata_refs.cc new file mode 100644 index 00000000..b9eeafe1 --- /dev/null +++ b/test/core/end2end/tests/retry_send_initial_metadata_refs.cc @@ -0,0 +1,357 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we hold refs to send_initial_metadata payload while +// cached, even after the caller has released its refs: +// - 2 retries allowed for ABORTED status +// - first attempt returns ABORTED +// - second attempt returns OK +static void test_retry_send_initial_metadata_refs( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array client_send_initial_metadata; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + 
grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = begin_test( + config, "retry_send_initial_metadata_refs", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&client_send_initial_metadata); + client_send_initial_metadata.count = 2; + client_send_initial_metadata.metadata = static_cast( + gpr_malloc(client_send_initial_metadata.count * sizeof(grpc_metadata))); + // First element is short enough for slices to be inlined. + client_send_initial_metadata.metadata[0].key = + grpc_slice_from_copied_string(std::string("foo").c_str()); + client_send_initial_metadata.metadata[0].value = + grpc_slice_from_copied_string(std::string("bar").c_str()); + // Second element requires slice allocation. 
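+  // A key/value longer than GRPC_SLICE_INLINED_SIZE cannot be stored inline
+  // in the grpc_slice itself, so this element exercises the allocated,
+  // refcounted slice path: the cached send_initial_metadata op must keep its
+  // own refs alive after the test unrefs these slices below.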
+ client_send_initial_metadata.metadata[1].key = grpc_slice_from_copied_string( + std::string(GRPC_SLICE_INLINED_SIZE + 1, 'x').c_str()); + client_send_initial_metadata.metadata[1].value = + grpc_slice_from_copied_string( + std::string(GRPC_SLICE_INLINED_SIZE + 1, 'y').c_str()); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = client_send_initial_metadata.count; + op->data.send_initial_metadata.metadata = + client_send_initial_metadata.metadata; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + for (size_t i = 0; i < client_send_initial_metadata.count; ++i) { + grpc_slice_unref(client_send_initial_metadata.metadata[i].key); + grpc_slice_unref(client_send_initial_metadata.metadata[i].value); + } + grpc_metadata_array_destroy(&client_send_initial_metadata); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + // Make sure the "grpc-previous-rpc-attempts" header was not sent in the + // initial attempt. 
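+  // (That header is only expected on retry attempts; the retried call below
+  // is checked for it with a value of "1".)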
+ for (size_t i = 0; i < request_metadata_recv.count; ++i) { + GPR_ASSERT(!grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)); + } + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + // Make sure the "grpc-previous-rpc-attempts" header was sent in the retry. + GPR_ASSERT(contains_metadata_slices(&request_metadata_recv, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS, + GRPC_MDSTR_1)); + // It should also contain the initial metadata, even though the client + // freed it already. 
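+  // This is the point of the test: the client unreffed those slices right
+  // after tag(1) completed, so the only refs keeping them alive for the
+  // replayed send_initial_metadata op are the ones the retry code took when
+  // it cached the op.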
+ GPR_ASSERT(contains_metadata(&request_metadata_recv, "foo", "bar")); + GPR_ASSERT( + contains_metadata(&request_metadata_recv, + std::string(GRPC_SLICE_INLINED_SIZE + 1, 'x').c_str(), + std::string(GRPC_SLICE_INLINED_SIZE + 1, 'y').c_str())); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_send_initial_metadata_refs(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_send_initial_metadata_refs(config); +} + +void retry_send_initial_metadata_refs_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_send_op_fails.cc b/test/core/end2end/tests/retry_send_op_fails.cc new file mode 100644 index 00000000..c090240b --- /dev/null +++ b/test/core/end2end/tests/retry_send_op_fails.cc @@ -0,0 +1,384 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/channel_stack_builder.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/surface/channel_init.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests failure on a send op batch: +// - 2 retries allowed for ABORTED status +// - on the first call attempt, the batch containing the +// send_initial_metadata op fails, and then the call returns ABORTED, +// all without ever going out on the wire +// - second attempt returns ABORTED but does not retry, because only 2 +// attempts are allowed +static void test_retry_send_op_fails(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg 
args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_send_op_fails", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + // Start a batch containing send ops. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Start a batch containing recv ops. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Client send ops should now complete. + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + // Server should get a call. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + // Server fails with status ABORTED. 
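+  // Note that this is already the second (and final) attempt: the first
+  // attempt's send batch was failed locally by FailFirstSendOpFilter without
+  // reaching the server, and maxAttempts is 2, so this ABORTED is surfaced
+  // to the client unretried. The assertions below confirm the status and the
+  // "grpc-previous-rpc-attempts: 1" request header.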
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // In principle, the server batch should complete before the client + // recv ops batch, but in the proxy fixtures, there are multiple threads + // involved, so the completion order tends to be a little racy. + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + // Make sure the "grpc-previous-rpc-attempts" header was sent in the retry. + bool found_retry_header = false; + for (size_t i = 0; i < request_metadata_recv.count; ++i) { + if (grpc_slice_eq(request_metadata_recv.metadata[i].key, + GRPC_MDSTR_GRPC_PREVIOUS_RPC_ATTEMPTS)) { + GPR_ASSERT( + grpc_slice_eq(request_metadata_recv.metadata[i].value, GRPC_MDSTR_1)); + found_retry_header = true; + break; + } + } + GPR_ASSERT(found_retry_header); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +namespace { + +// A filter that, for the first call it sees, will fail the batch +// containing send_initial_metadata and then fail the call with status +// ABORTED. All subsequent calls are allowed through without failures. 
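+// How it works: the filter is registered on the GRPC_CLIENT_SUBCHANNEL stack
+// (see retry_send_op_fails() at the bottom of this file), so the initial
+// attempt's batch is failed via
+// grpc_transport_stream_op_batch_finish_with_failure() with an ABORTED
+// error before anything reaches the wire, while later calls pass straight
+// through via grpc_call_next_op().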
+class FailFirstSendOpFilter { + public: + static grpc_channel_filter kFilterVtable; + + public: + class CallData { + public: + static grpc_error_handle Init(grpc_call_element* elem, + const grpc_call_element_args* args) { + new (elem->call_data) CallData(args); + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_call_element* elem, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*ignored*/) { + auto* calld = static_cast(elem->call_data); + calld->~CallData(); + } + + static void StartTransportStreamOpBatch( + grpc_call_element* elem, grpc_transport_stream_op_batch* batch) { + auto* chand = static_cast(elem->channel_data); + auto* calld = static_cast(elem->call_data); + if (!chand->seen_first_) { + chand->seen_first_ = true; + calld->fail_ = true; + } + if (calld->fail_ && !batch->cancel_stream) { + grpc_transport_stream_op_batch_finish_with_failure( + batch, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "FailFirstSendOpFilter failing batch"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_ABORTED), + calld->call_combiner_); + return; + } + grpc_call_next_op(elem, batch); + } + + private: + explicit CallData(const grpc_call_element_args* args) + : call_combiner_(args->call_combiner) {} + + grpc_core::CallCombiner* call_combiner_; + bool fail_ = false; + }; + + static grpc_error_handle Init(grpc_channel_element* elem, + grpc_channel_element_args* /*args*/) { + new (elem->channel_data) FailFirstSendOpFilter(); + return GRPC_ERROR_NONE; + } + + static void Destroy(grpc_channel_element* elem) { + auto* chand = static_cast(elem->channel_data); + chand->~FailFirstSendOpFilter(); + } + + bool seen_first_ = false; +}; + +grpc_channel_filter FailFirstSendOpFilter::kFilterVtable = { + CallData::StartTransportStreamOpBatch, + grpc_channel_next_op, + sizeof(CallData), + CallData::Init, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + CallData::Destroy, + sizeof(FailFirstSendOpFilter), + Init, + Destroy, + grpc_channel_next_get_info, + "FailFirstSendOpFilter", +}; + +} // namespace + +void retry_send_op_fails(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + grpc_core::CoreConfiguration::RunWithSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + grpc_core::BuildCoreConfiguration(builder); + builder->channel_init()->RegisterStage( + GRPC_CLIENT_SUBCHANNEL, 0, [](grpc_channel_stack_builder* builder) { + // Skip on proxy (which explicitly disables retries). + const grpc_channel_args* args = + grpc_channel_stack_builder_get_channel_arguments(builder); + if (!grpc_channel_args_find_bool(args, GRPC_ARG_ENABLE_RETRIES, + true)) { + return true; + } + // Install filter. + return grpc_channel_stack_builder_prepend_filter( + builder, &FailFirstSendOpFilter::kFilterVtable, nullptr, + nullptr); + }); + }, + [config] { test_retry_send_op_fails(config); }); +} + +void retry_send_op_fails_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_server_pushback_delay.cc b/test/core/end2end/tests/retry_server_pushback_delay.cc new file mode 100644 index 00000000..b5716746 --- /dev/null +++ b/test/core/end2end/tests/retry_server_pushback_delay.cc @@ -0,0 +1,316 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we honor server push-back delay. 
+// - 2 retries allowed for ABORTED status +// - first attempt gets ABORTED with a long delay +// - second attempt succeeds +static void test_retry_server_pushback_delay(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_metadata pushback_md; + memset(&pushback_md, 0, sizeof(pushback_md)); + pushback_md.key = GRPC_MDSTR_GRPC_RETRY_PUSHBACK_MS; + pushback_md.value = grpc_slice_from_static_string("2000"); + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_server_pushback_delay", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + 
GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 1; + op->data.send_status_from_server.trailing_metadata = &pushback_md; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + gpr_timespec before_retry = gpr_now(GPR_CLOCK_MONOTONIC); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + gpr_timespec after_retry = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec retry_delay = gpr_time_sub(after_retry, before_retry); + // Configured back-off was 1 second, server push-back said 2 seconds. + // To avoid flakiness, we allow some fudge factor here. 
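+  // The trailing metadata sent above carried "grpc-retry-pushback-ms: 2000",
+  // which should take precedence over the 1s initialBackoff in the service
+  // config, so the gap between the two attempts is expected to be roughly
+  // 2 seconds (at least ~1.8s once the fudge factor below is applied).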
+ gpr_log(GPR_INFO, "retry delay was {.tv_sec=%" PRId64 ", .tv_nsec=%d}", + retry_delay.tv_sec, retry_delay.tv_nsec); + GPR_ASSERT(retry_delay.tv_sec >= 1); + if (retry_delay.tv_sec == 1) { + GPR_ASSERT(retry_delay.tv_nsec >= 800000000); + } + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_server_pushback_delay(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_server_pushback_delay(config); +} + +void retry_server_pushback_delay_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_server_pushback_disabled.cc b/test/core/end2end/tests/retry_server_pushback_disabled.cc new file mode 100644 index 00000000..cc805108 --- /dev/null +++ b/test/core/end2end/tests/retry_server_pushback_disabled.cc @@ -0,0 +1,304 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we don't retry when disabled by server push-back. 
+// - 2 retries allowed for ABORTED status +// - first attempt gets ABORTED +// - second attempt gets ABORTED but server push back disables retrying +static void test_retry_server_pushback_disabled( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_metadata pushback_md; + memset(&pushback_md, 0, sizeof(pushback_md)); + pushback_md.key = GRPC_MDSTR_GRPC_RETRY_PUSHBACK_MS; + pushback_md.value = grpc_slice_from_static_string("-1"); + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = begin_test( + config, "retry_server_pushback_disabled", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, 
f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 1; + op->data.send_status_from_server.trailing_metadata = &pushback_md; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_server_pushback_disabled(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_server_pushback_disabled(config); +} 
+ +void retry_server_pushback_disabled_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_streaming.cc b/test/core/end2end/tests/retry_streaming.cc new file mode 100644 index 00000000..fdd8e592 --- /dev/null +++ b/test/core/end2end/tests/retry_streaming.cc @@ -0,0 +1,452 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/server.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests retrying a streaming RPC. This is the same as +// the basic retry test, except that the client sends two messages on the +// call before the initial attempt fails. +// FIXME: We should also test the case where the retry is committed after +// replaying 1 of 2 previously-completed send_message ops. 
However, +// there's no way to trigger that from an end2end test, because the +// replayed ops happen under the hood -- they are not surfaced to the +// C-core API, and therefore we have no way to inject the commit at the +// right point. +static void test_retry_streaming(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice request2_payload_slice = grpc_slice_from_static_string("bar"); + grpc_slice request3_payload_slice = grpc_slice_from_static_string("baz"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("quux"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* request2_payload = + grpc_raw_byte_buffer_create(&request2_payload_slice, 1); + grpc_byte_buffer* request3_payload = + grpc_raw_byte_buffer_create(&request3_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* request2_payload_recv = nullptr; + grpc_byte_buffer* request3_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE), + 1024 * 8), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ENABLE_CHANNELZ), true), + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"))}; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_streaming", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + grpc_core::channelz::ChannelNode* channelz_channel = + grpc_channel_get_channelz_node(f.client); + + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + // Client starts a batch for receiving initial metadata, a message, + // and trailing metadata. 
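+  // This batch (tag(1)) stays pending across both attempts; it only
+  // completes at the end of the test, after the call has been committed and
+  // the second attempt's status arrives.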
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Client sends initial metadata and a message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + // Server gets a call with received initial metadata. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server receives a message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + // Client sends a second message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request2_payload; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(3), true); + cq_verify(cqv); + + // Server receives the second message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request2_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + cq_verify(cqv); + + // Server sends both initial and trailing metadata. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(104), true); + cq_verify(cqv); + + // Clean up from first attempt. 
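+  // The payloads received on the first attempt are compared against the
+  // original slices before being destroyed. The retry code keeps its own
+  // cached copies of the two send_message ops and replays them on the second
+  // attempt, which the tag(202)/tag(203) recvs below verify.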
+ grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + GPR_ASSERT(byte_buffer_eq_slice(request_payload_recv, request_payload_slice)); + grpc_byte_buffer_destroy(request_payload_recv); + request_payload_recv = nullptr; + GPR_ASSERT( + byte_buffer_eq_slice(request2_payload_recv, request2_payload_slice)); + grpc_byte_buffer_destroy(request2_payload_recv); + request2_payload_recv = nullptr; + + // Server gets a second call (the retry). + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server receives a message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + cq_verify(cqv); + + // Server receives a second message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request2_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(203), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(203), true); + cq_verify(cqv); + + // Client sends a third message and a close. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request3_payload; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(4), true); + cq_verify(cqv); + + // Server receives a third message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request3_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(204), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(204), true); + cq_verify(cqv); + + // Server receives a close and sends initial metadata, a message, and + // trailing metadata. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + // Returning a retriable code, but because we are also sending a + // message, the client will commit instead of retrying again. 
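+  // (Having received a message from the server, the client commits the call,
+  // so this ABORTED is delivered to the application as-is; the status
+  // assertions after cq_verify() check exactly that.)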
+ op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(205), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(205), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + GPR_ASSERT(channelz_channel != nullptr); + std::string json = channelz_channel->RenderJsonString(); + gpr_log(GPR_INFO, "%s", json.c_str()); + GPR_ASSERT(json.find("\"trace\"") != json.npos); + GPR_ASSERT(json.find("\"description\":\"Channel created\"") != json.npos); + GPR_ASSERT(json.find("\"severity\":\"CT_INFO\"") != json.npos); + GPR_ASSERT(json.find("Resolution event") != json.npos); + GPR_ASSERT(json.find("Created new LB policy") != json.npos); + GPR_ASSERT(json.find("Service config changed") != json.npos); + GPR_ASSERT(json.find("Address list became non-empty") != json.npos); + GPR_ASSERT(json.find("Channel state change to CONNECTING") != json.npos); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request2_payload); + grpc_byte_buffer_destroy(request3_payload); + grpc_byte_buffer_destroy(response_payload); + GPR_ASSERT(byte_buffer_eq_slice(request_payload_recv, request_payload_slice)); + grpc_byte_buffer_destroy(request_payload_recv); + GPR_ASSERT( + byte_buffer_eq_slice(request2_payload_recv, request2_payload_slice)); + grpc_byte_buffer_destroy(request2_payload_recv); + GPR_ASSERT( + byte_buffer_eq_slice(request3_payload_recv, request3_payload_slice)); + grpc_byte_buffer_destroy(request3_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_streaming(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + + test_retry_streaming(config); +} + +void retry_streaming_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_streaming_after_commit.cc b/test/core/end2end/tests/retry_streaming_after_commit.cc new file mode 100644 index 00000000..5829edb2 --- /dev/null +++ b/test/core/end2end/tests/retry_streaming_after_commit.cc @@ -0,0 +1,358 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we can continue to send/recv messages on a streaming call +// after retries are committed. 
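All of the streaming retry tests in this patch enable retries the same way: a retryPolicy block in the service config JSON, attached to the client channel through the GRPC_ARG_SERVICE_CONFIG channel arg. For orientation only, here is a hedged sketch of the equivalent wiring in C++ application code; the MakeRetryingChannel helper, the target string, and the insecure credentials are illustrative assumptions, not part of this patch.

// Sketch only: the same retry policy as the tests, set via the C++ channel API.
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

std::shared_ptr<grpc::Channel> MakeRetryingChannel(const std::string& target) {
  grpc::ChannelArguments args;
  // Up to 3 attempts total, retrying only on ABORTED, with exponential backoff.
  args.SetServiceConfigJSON(
      "{\"methodConfig\": [{"
      "\"name\": [{\"service\": \"service\", \"method\": \"method\"}],"
      "\"retryPolicy\": {\"maxAttempts\": 3, \"initialBackoff\": \"1s\","
      "\"maxBackoff\": \"120s\", \"backoffMultiplier\": 1.6,"
      "\"retryableStatusCodes\": [\"ABORTED\"]}}]}");
  return grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(),
                                   args);
}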
+static void test_retry_streaming_after_commit(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice request2_payload_slice = grpc_slice_from_static_string("bar"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("baz"); + grpc_slice response2_payload_slice = grpc_slice_from_static_string("quux"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* request2_payload = + grpc_raw_byte_buffer_create(&request2_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* response2_payload = + grpc_raw_byte_buffer_create(&response2_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* request2_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_byte_buffer* response2_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_streaming_after_commit", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + // Client starts a batch for receiving initial metadata and a message. + // This will commit retries. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Client sends initial metadata and a message. 
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(3), true); + cq_verify(cqv); + + // Server gets a call with received initial metadata. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server receives a message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + // Server sends initial metadata and a message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + // Client receives initial metadata and a message. + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + // Client sends a second message and a close. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request2_payload; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(4), true); + cq_verify(cqv); + + // Server receives a second message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request2_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(104), true); + cq_verify(cqv); + + // Server receives a close, sends a second message, and sends status. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response2_payload; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + // Returning a retriable code, but because retries are already + // committed, the client will not retry. 
+  op->data.send_status_from_server.status = GRPC_STATUS_ABORTED;
+  op->data.send_status_from_server.status_details = &status_details;
+  op++;
+  error = grpc_call_start_batch(s, ops, static_cast<size_t>(op - ops),
+                                tag(105), nullptr);
+  GPR_ASSERT(GRPC_CALL_OK == error);
+  CQ_EXPECT_COMPLETION(cqv, tag(105), true);
+  cq_verify(cqv);
+
+  // Client receives a second message.
+  memset(ops, 0, sizeof(ops));
+  op = ops;
+  op->op = GRPC_OP_RECV_MESSAGE;
+  op->data.recv_message.recv_message = &response2_payload_recv;
+  op++;
+  error = grpc_call_start_batch(c, ops, static_cast<size_t>(op - ops), tag(5),
+                                nullptr);
+  GPR_ASSERT(GRPC_CALL_OK == error);
+  CQ_EXPECT_COMPLETION(cqv, tag(5), true);
+  cq_verify(cqv);
+
+  // Client receives status.
+  memset(ops, 0, sizeof(ops));
+  op = ops;
+  op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
+  op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv;
+  op->data.recv_status_on_client.status = &status;
+  op->data.recv_status_on_client.status_details = &details;
+  op++;
+  error = grpc_call_start_batch(c, ops, static_cast<size_t>(op - ops), tag(1),
+                                nullptr);
+  GPR_ASSERT(GRPC_CALL_OK == error);
+  CQ_EXPECT_COMPLETION(cqv, tag(1), true);
+  cq_verify(cqv);
+
+  GPR_ASSERT(status == GRPC_STATUS_ABORTED);
+  GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz"));
+  GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method"));
+  GPR_ASSERT(0 == call_details.flags);
+  GPR_ASSERT(was_cancelled == 0);
+
+  grpc_slice_unref(details);
+  grpc_metadata_array_destroy(&initial_metadata_recv);
+  grpc_metadata_array_destroy(&trailing_metadata_recv);
+  grpc_metadata_array_destroy(&request_metadata_recv);
+  grpc_call_details_destroy(&call_details);
+  grpc_byte_buffer_destroy(request_payload);
+  grpc_byte_buffer_destroy(request2_payload);
+  grpc_byte_buffer_destroy(response_payload);
+  grpc_byte_buffer_destroy(response2_payload);
+  GPR_ASSERT(byte_buffer_eq_slice(request_payload_recv, request_payload_slice));
+  grpc_byte_buffer_destroy(request_payload_recv);
+  GPR_ASSERT(
+      byte_buffer_eq_slice(request2_payload_recv, request2_payload_slice));
+  grpc_byte_buffer_destroy(request2_payload_recv);
+  GPR_ASSERT(
+      byte_buffer_eq_slice(response_payload_recv, response_payload_slice));
+  grpc_byte_buffer_destroy(response_payload_recv);
+  GPR_ASSERT(
+      byte_buffer_eq_slice(response2_payload_recv, response2_payload_slice));
+  grpc_byte_buffer_destroy(response2_payload_recv);
+  grpc_call_unref(c);
+  grpc_call_unref(s);
+
+  cq_verifier_destroy(cqv);
+
+  end_test(&f);
+  config.tear_down_data(&f);
+}
+
+void retry_streaming_after_commit(grpc_end2end_test_config config) {
+  GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL);
+  test_retry_streaming_after_commit(config);
+}
+
+void retry_streaming_after_commit_pre_init(void) {}
diff --git a/test/core/end2end/tests/retry_streaming_succeeds_before_replay_finished.cc b/test/core/end2end/tests/retry_streaming_succeeds_before_replay_finished.cc
new file mode 100644
index 00000000..4c2faca2
--- /dev/null
+++ b/test/core/end2end/tests/retry_streaming_succeeds_before_replay_finished.cc
@@ -0,0 +1,405 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we correctly clean up if the second attempt finishes +// before we have finished replaying all of the send ops. 
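Every batch in these tests is started with grpc_call_start_batch() under a distinct tag, and completion is observed by polling the completion queue for that tag, which is roughly what CQ_EXPECT_COMPLETION and cq_verify do. A minimal sketch of that tag-and-poll pattern, using only core calls already present in this file; the expect_tag helper name is an illustrative assumption.

// Sketch only: the tag-and-poll pattern behind CQ_EXPECT_COMPLETION/cq_verify.
#include <grpc/grpc.h>
#include <grpc/support/log.h>

static void expect_tag(grpc_completion_queue* cq, void* expected_tag,
                       gpr_timespec deadline) {
  grpc_event ev = grpc_completion_queue_next(cq, deadline, nullptr);
  GPR_ASSERT(ev.type == GRPC_OP_COMPLETE);  // not a timeout or queue shutdown
  GPR_ASSERT(ev.tag == expected_tag);       // the batch we started
  GPR_ASSERT(ev.success);                   // and it completed successfully
}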
+static void test_retry_streaming_succeeds_before_replay_finished( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice request2_payload_slice = grpc_slice_from_static_string("bar"); + grpc_slice request3_payload_slice = grpc_slice_from_static_string("baz"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("quux"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* request2_payload = + grpc_raw_byte_buffer_create(&request2_payload_slice, 1); + grpc_byte_buffer* request3_payload = + grpc_raw_byte_buffer_create(&request3_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* request2_payload_recv = nullptr; + grpc_byte_buffer* request3_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_streaming", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + // Client starts a batch for receiving initial metadata, a message, + // and trailing metadata. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Client sends initial metadata and a message. 
+ memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + cq_verify(cqv); + + // Server gets a call with received initial metadata. + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server receives a message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + // Client sends a second message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request2_payload; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(3), true); + cq_verify(cqv); + + // Server receives the second message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request2_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + cq_verify(cqv); + + // Client sends a third message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request3_payload; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(4), true); + cq_verify(cqv); + + // Server receives the third message. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request3_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(104), true); + cq_verify(cqv); + + // Server sends both initial and trailing metadata. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(105), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(105), true); + cq_verify(cqv); + + // Clean up from first attempt. 
+ grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + GPR_ASSERT(byte_buffer_eq_slice(request_payload_recv, request_payload_slice)); + grpc_byte_buffer_destroy(request_payload_recv); + request_payload_recv = nullptr; + GPR_ASSERT( + byte_buffer_eq_slice(request2_payload_recv, request2_payload_slice)); + grpc_byte_buffer_destroy(request2_payload_recv); + request2_payload_recv = nullptr; + GPR_ASSERT( + byte_buffer_eq_slice(request3_payload_recv, request3_payload_slice)); + grpc_byte_buffer_destroy(request3_payload_recv); + request3_payload_recv = nullptr; + + // Server gets a second call (the retry). + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + // Server receives the first message (and does not receive any others). + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + cq_verify(cqv); + + // Server sends initial metadata, a message, and trailing metadata. + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + // Returning a retriable code, but because we are also sending a + // message, the client will commit instead of retrying again. 
+  op->data.send_status_from_server.status = GRPC_STATUS_ABORTED;
+  op->data.send_status_from_server.status_details = &status_details;
+  op++;
+  error = grpc_call_start_batch(s, ops, static_cast<size_t>(op - ops),
+                                tag(205), nullptr);
+  GPR_ASSERT(GRPC_CALL_OK == error);
+  CQ_EXPECT_COMPLETION(cqv, tag(205), true);
+  CQ_EXPECT_COMPLETION(cqv, tag(1), true);
+  cq_verify(cqv);
+
+  GPR_ASSERT(status == GRPC_STATUS_ABORTED);
+  GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz"));
+  GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method"));
+  GPR_ASSERT(0 == call_details.flags);
+  GPR_ASSERT(was_cancelled == 0);
+
+  grpc_slice_unref(details);
+  grpc_metadata_array_destroy(&initial_metadata_recv);
+  grpc_metadata_array_destroy(&trailing_metadata_recv);
+  grpc_metadata_array_destroy(&request_metadata_recv);
+  grpc_call_details_destroy(&call_details);
+  grpc_byte_buffer_destroy(request_payload);
+  grpc_byte_buffer_destroy(request2_payload);
+  grpc_byte_buffer_destroy(request3_payload);
+  grpc_byte_buffer_destroy(response_payload);
+  GPR_ASSERT(byte_buffer_eq_slice(request_payload_recv, request_payload_slice));
+  grpc_byte_buffer_destroy(request_payload_recv);
+  grpc_byte_buffer_destroy(response_payload_recv);
+
+  grpc_call_unref(c);
+  grpc_call_unref(s);
+
+  cq_verifier_destroy(cqv);
+
+  end_test(&f);
+  config.tear_down_data(&f);
+}
+
+void retry_streaming_succeeds_before_replay_finished(
+    grpc_end2end_test_config config) {
+  GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL);
+  test_retry_streaming_succeeds_before_replay_finished(config);
+}
+
+void retry_streaming_succeeds_before_replay_finished_pre_init(void) {}
diff --git a/test/core/end2end/tests/retry_throttled.cc b/test/core/end2end/tests/retry_throttled.cc
new file mode 100644
index 00000000..a9051ba6
--- /dev/null
+++ b/test/core/end2end/tests/retry_throttled.cc
@@ -0,0 +1,261 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we don't retry when throttled. 
+// - 1 retry allowed for ABORTED status +// - first attempt gets ABORTED but is over limit, so no retry is done +static void test_retry_throttled(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ],\n" + // A single failure will cause us to be throttled. + // (This is not a very realistic config, but it works for the + // purposes of this test.) + " \"retryThrottling\": {\n" + " \"maxTokens\": 2,\n" + " \"tokenRatio\": 1.0\n" + " }\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_throttled", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + 
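The retryThrottling block in the config above is what makes a single failure enough to block the retry. The following is an illustrative sketch of the documented token-bucket arithmetic; the RetryThrottle struct is an assumption for exposition, not gRPC's implementation.

// Sketch only: token-bucket arithmetic implied by the retryThrottling config.
struct RetryThrottle {
  double max_tokens;   // "maxTokens": 2
  double token_ratio;  // "tokenRatio": 1.0
  double tokens;       // starts at max_tokens
  void OnFailure() { tokens = tokens >= 1.0 ? tokens - 1.0 : 0.0; }
  void OnSuccess() {
    tokens = (tokens + token_ratio > max_tokens) ? max_tokens
                                                 : tokens + token_ratio;
  }
  bool RetriesAllowed() const { return tokens > max_tokens / 2; }
};
// The bucket starts at 2 (greater than 1, so retries are allowed). One
// ABORTED failure drops it to 1, which is not greater than maxTokens / 2 == 1,
// so the retry in this test is throttled and the client surfaces ABORTED
// directly.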
GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_throttled(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_throttled(config); +} + +void retry_throttled_pre_init(void) {} diff --git a/test/core/end2end/tests/retry_too_many_attempts.cc b/test/core/end2end/tests/retry_too_many_attempts.cc new file mode 100644 index 00000000..1048334e --- /dev/null +++ b/test/core/end2end/tests/retry_too_many_attempts.cc @@ -0,0 +1,297 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/end2end/tests/cancel_test_helpers.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Tests that we stop retrying after the configured number of attempts. 
+// - 1 retry allowed for ABORTED status +// - first attempt gets ABORTED +// - second attempt gets ABORTED but does not retry +static void test_retry_too_many_attempts(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_slice request_payload_slice = grpc_slice_from_static_string("foo"); + grpc_slice response_payload_slice = grpc_slice_from_static_string("bar"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + + grpc_arg args[] = { + grpc_channel_arg_string_create( + const_cast(GRPC_ARG_SERVICE_CONFIG), + const_cast( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"service\", \"method\": \"method\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 2,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}")), + }; + grpc_channel_args client_args = {GPR_ARRAY_SIZE(args), args}; + grpc_end2end_test_fixture f = + begin_test(config, "retry_too_many_attempts", &client_args, nullptr); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/service/method"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + 
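Between the first and second ABORTED attempts above, the client waits according to the exponential backoff in the retryPolicy. As a rough illustration (the RetryBackoffCapSeconds helper is an assumption for exposition; actual delays are randomized by the client), the cap on the nth retry delay works out as follows.

// Sketch only: backoff cap for the nth retry under the policy above
// (initialBackoff 1s, backoffMultiplier 1.6, maxBackoff 120s); the actual
// delay is drawn uniformly from [0, cap].
#include <algorithm>
#include <cmath>

double RetryBackoffCapSeconds(int retry_number) {  // 1-based
  const double initial = 1.0, multiplier = 1.6, max_backoff = 120.0;
  return std::min(initial * std::pow(multiplier, retry_number - 1),
                  max_backoff);
}
// RetryBackoffCapSeconds(1) == 1.0, RetryBackoffCapSeconds(2) == 1.6, and so
// on, capped at 120 seconds.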
gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + grpc_call_unref(s); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_details_init(&call_details); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(201)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(201), true); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_ABORTED; + op->data.send_status_from_server.status_details = &status_details; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(202), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(202), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_ABORTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void retry_too_many_attempts(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_CLIENT_CHANNEL); + test_retry_too_many_attempts(config); +} + +void retry_too_many_attempts_pre_init(void) {} diff --git a/test/core/end2end/tests/sdk_authz.cc b/test/core/end2end/tests/sdk_authz.cc new file mode 100644 index 00000000..19e228bd --- /dev/null +++ b/test/core/end2end/tests/sdk_authz.cc @@ -0,0 +1,722 @@ +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/security/authorization/grpc_authorization_policy_provider.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" +#include "test/core/util/tls_utils.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_allow_authorized_request(grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + const char* error_string = nullptr; + grpc_call_error error; + grpc_slice details = grpc_empty_slice(); + int was_cancelled = 2; + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + 
grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->data.recv_status_on_client.error_string = &error_string; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + GPR_ASSERT(GRPC_STATUS_OK == status); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + + grpc_slice_unref(details); + gpr_free(const_cast(error_string)); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + cq_verifier_destroy(cqv); +} + +static void test_deny_unauthorized_request(grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + const char* error_string = nullptr; + grpc_call_error error; + grpc_slice details = grpc_empty_slice(); + + cq_verifier* cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = 
GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->data.recv_status_on_client.error_string = &error_string; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(GRPC_STATUS_PERMISSION_DENIED == status); + GPR_ASSERT(0 == + grpc_slice_str_cmp(details, "Unauthorized RPC request rejected.")); + + grpc_slice_unref(details); + gpr_free(const_cast(error_string)); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + + grpc_call_unref(c); + cq_verifier_destroy(cqv); +} + +static void test_static_init_allow_authorized_request( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_static_data_create(authz_policy, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = + begin_test(config, "test_static_init_allow_authorized_request", nullptr, + &server_args); + grpc_authorization_policy_provider_release(provider); + test_allow_authorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_static_init_deny_unauthorized_request( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_bar\"," + " \"request\": {" + " \"paths\": [" + " \"*/bar\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_static_data_create(authz_policy, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = + begin_test(config, "test_static_init_deny_unauthorized_request", nullptr, + &server_args); + grpc_authorization_policy_provider_release(provider); + test_deny_unauthorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_static_init_deny_request_no_match_in_policy( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " 
\"name\": \"allow_bar\"," + " \"request\": {" + " \"paths\": [" + " \"*/bar\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_static_data_create(authz_policy, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = + begin_test(config, "test_static_init_deny_request_no_match_in_policy", + nullptr, &server_args); + grpc_authorization_policy_provider_release(provider); + test_deny_unauthorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_file_watcher_init_allow_authorized_request( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(authz_policy); + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_file_watcher_create( + tmp_policy.name().c_str(), /*refresh_interval_sec=*/1, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = + begin_test(config, "test_file_watcher_init_allow_authorized_request", + nullptr, &server_args); + grpc_authorization_policy_provider_release(provider); + test_allow_authorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_file_watcher_init_deny_unauthorized_request( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_bar\"," + " \"request\": {" + " \"paths\": [" + " \"*/bar\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(authz_policy); + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_file_watcher_create( + tmp_policy.name().c_str(), /*refresh_interval_sec=*/1, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = + begin_test(config, "test_file_watcher_init_deny_unauthorized_request", + nullptr, &server_args); + grpc_authorization_policy_provider_release(provider); + test_deny_unauthorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_file_watcher_init_deny_request_no_match_in_policy( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": 
\"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_bar\"," + " \"request\": {" + " \"paths\": [" + " \"*/bar\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(authz_policy); + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_file_watcher_create( + tmp_policy.name().c_str(), /*refresh_interval_sec=*/1, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = begin_test( + config, "test_file_watcher_init_deny_request_no_match_in_policy", nullptr, + &server_args); + grpc_authorization_policy_provider_release(provider); + test_deny_unauthorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_file_watcher_valid_policy_reload( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(authz_policy); + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_file_watcher_create( + tmp_policy.name().c_str(), /*refresh_interval_sec=*/1, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = begin_test( + config, "test_file_watcher_valid_policy_reload", nullptr, &server_args); + grpc_authorization_policy_provider_release(provider); + test_allow_authorized_request(f); + // Replace existing policy in file with a different authorization policy. + authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_bar\"," + " \"request\": {" + " \"paths\": [" + " \"*/bar\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + tmp_policy.RewriteFile(authz_policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. 
+ gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + test_deny_unauthorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_file_watcher_invalid_policy_skip_reload( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(authz_policy); + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_file_watcher_create( + tmp_policy.name().c_str(), /*refresh_interval_sec=*/1, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = + begin_test(config, "test_file_watcher_invalid_policy_skip_reload", + nullptr, &server_args); + grpc_authorization_policy_provider_release(provider); + test_allow_authorized_request(f); + // Replace existing policy in file with an invalid policy. + authz_policy = "{}"; + tmp_policy.RewriteFile(authz_policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + test_allow_authorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_file_watcher_recovers_from_failure( + grpc_end2end_test_config config) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(authz_policy); + grpc_status_code code = GRPC_STATUS_OK; + const char* error_details; + grpc_authorization_policy_provider* provider = + grpc_authorization_policy_provider_file_watcher_create( + tmp_policy.name().c_str(), /*refresh_interval_sec=*/1, &code, + &error_details); + GPR_ASSERT(GRPC_STATUS_OK == code); + grpc_arg args[] = { + grpc_channel_arg_pointer_create( + const_cast(GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER), provider, + grpc_authorization_policy_provider_arg_vtable()), + }; + grpc_channel_args server_args = {GPR_ARRAY_SIZE(args), args}; + + grpc_end2end_test_fixture f = begin_test( + config, "test_file_watcher_recovers_from_failure", nullptr, &server_args); + grpc_authorization_policy_provider_release(provider); + test_allow_authorized_request(f); + // Replace existing policy in file with an invalid policy. + authz_policy = "{}"; + tmp_policy.RewriteFile(authz_policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + test_allow_authorized_request(f); + // Recover from reload errors by replacing invalid policy in file with a + // valid policy.
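Taken together, the sdk_authz cases above pin down the policy semantics these tests rely on: a request matching a deny rule is rejected, a request matching only an allow rule is admitted, a request matching neither is rejected, and a file that fails to parse (such as "{}") is skipped so the previously loaded policy stays in effect. For a server that does not need hot reloading, the same policy JSON can be installed once via the static-data provider used by the first three tests; a minimal sketch (error handling elided, attachment via channel args as in the file-watcher sketch earlier):

    const char* authz_policy =
        "{"
        "  \"name\": \"authz\","
        "  \"allow_rules\": ["
        "    {"
        "      \"name\": \"allow_foo\","
        "      \"request\": { \"paths\": [ \"*/foo\" ] }"
        "    }"
        "  ]"
        "}";
    grpc_status_code code = GRPC_STATUS_OK;
    const char* error_details = nullptr;
    grpc_authorization_policy_provider* provider =
        grpc_authorization_policy_provider_static_data_create(authz_policy,
                                                              &code,
                                                              &error_details);
    GPR_ASSERT(GRPC_STATUS_OK == code);
    // Attach through GRPC_ARG_AUTHORIZATION_POLICY_PROVIDER, then release
    // the local reference once the server owns the provider.
    grpc_authorization_policy_provider_release(provider);

With that in mind, the recovery step below simply writes a valid policy back into the watched file and waits out one refresh interval before expecting the deny rule to take effect.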
+ authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_bar\"," + " \"request\": {" + " \"paths\": [" + " \"*/bar\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]" + "}"; + tmp_policy.RewriteFile(authz_policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + test_deny_unauthorized_request(f); + + end_test(&f); + config.tear_down_data(&f); +} + +void sdk_authz(grpc_end2end_test_config config) { + test_static_init_allow_authorized_request(config); + test_static_init_deny_unauthorized_request(config); + test_static_init_deny_request_no_match_in_policy(config); + test_file_watcher_init_allow_authorized_request(config); + test_file_watcher_init_deny_unauthorized_request(config); + test_file_watcher_init_deny_request_no_match_in_policy(config); + test_file_watcher_valid_policy_reload(config); + test_file_watcher_invalid_policy_skip_reload(config); + test_file_watcher_recovers_from_failure(config); +} + +void sdk_authz_pre_init(void) {} diff --git a/test/core/end2end/tests/server_finishes_request.cc b/test/core/end2end/tests/server_finishes_request.cc new file mode 100644 index 00000000..3c16b579 --- /dev/null +++ b/test/core/end2end/tests/server_finishes_request.cc @@ -0,0 +1,203 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_request_body(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + 
grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + simple_request_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +void server_finishes_request(grpc_end2end_test_config config) { + test_invoke_simple_request(config); +} + +void server_finishes_request_pre_init(void) {} diff --git a/test/core/end2end/tests/server_streaming.cc b/test/core/end2end/tests/server_streaming.cc new file mode 100644 index 00000000..f7dcf0fe --- /dev/null +++ b/test/core/end2end/tests/server_streaming.cc @@ -0,0 +1,280 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Client requests status along with the initial metadata. Server streams + * messages and ends with a non-OK status. Client reads after server is done + * writing, and expects to get the status after the messages. 
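 *
 * Concretely, the client's GRPC_OP_RECV_STATUS_ON_CLIENT batch (tag(1)) stays
 * pending while messages are still in flight, and a GRPC_OP_RECV_MESSAGE batch
 * that completes with a null byte buffer is the signal that the trailing
 * metadata (and therefore the status) has arrived. A condensed, illustrative
 * sketch of that client-side read loop, using the same C-core calls as the
 * test body (the test additionally tracks tag(1) with
 * CQ_MAYBE_EXPECT_COMPLETION, since the status may surface during any
 * iteration):
 *
 *   grpc_byte_buffer* msg = nullptr;
 *   for (;;) {
 *     grpc_op op;
 *     memset(&op, 0, sizeof(op));
 *     op.op = GRPC_OP_RECV_MESSAGE;
 *     op.data.recv_message.recv_message = &msg;
 *     GPR_ASSERT(GRPC_CALL_OK ==
 *                grpc_call_start_batch(c, &op, 1, tag(102), nullptr));
 *     CQ_EXPECT_COMPLETION(cqv, tag(102), true);
 *     cq_verify(cqv);
 *     if (msg == nullptr) break;  // trailing metadata/status received
 *     grpc_byte_buffer_destroy(msg);
 *   }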
*/ +static void test_server_streaming(grpc_end2end_test_config config, + int num_messages) { + grpc_end2end_test_fixture f = + begin_test(config, "test_server_streaming", nullptr, nullptr); + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + grpc_byte_buffer* request_payload_recv; + grpc_byte_buffer* response_payload; + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello world"); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + // Client requests status early but should not receive status till all the + // messages are received. + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + // Client sends close early + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(100)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(100), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + // Server writes bunch of messages + for (int i = 0; i < num_messages; i++) { + response_payload = grpc_raw_byte_buffer_create(&response_payload_slice, 1); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + cq_verify(cqv); + + grpc_byte_buffer_destroy(response_payload); + } + + // Server sends 
status + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + bool seen_status = false; + CQ_MAYBE_EXPECT_COMPLETION(cqv, tag(1), true, &seen_status); + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + cq_verify(cqv); + + // Client keeps reading messages till it gets the status + int num_messages_received = 0; + while (true) { + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_MAYBE_EXPECT_COMPLETION(cqv, tag(1), true, &seen_status); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + if (request_payload_recv == nullptr) { + // The transport has received the trailing metadata. + break; + } + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + grpc_byte_buffer_destroy(request_payload_recv); + num_messages_received++; + } + GPR_ASSERT(num_messages_received == num_messages); + if (!seen_status) { + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + } + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + + grpc_slice_unref(response_payload_slice); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_slice_unref(details); + + end_test(&f); + config.tear_down_data(&f); +} + +void server_streaming(grpc_end2end_test_config config) { + test_server_streaming(config, 0); + test_server_streaming(config, 1); + test_server_streaming(config, 10); +} + +void server_streaming_pre_init(void) {} diff --git a/test/core/end2end/tests/shutdown_finishes_calls.cc b/test/core/end2end/tests/shutdown_finishes_calls.cc new file mode 100644 index 00000000..63f66a72 --- /dev/null +++ b/test/core/end2end/tests/shutdown_finishes_calls.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + /* f->shutdown_cq is not used in this test */ + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_early_server_shutdown_finishes_inflight_calls( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_end2end_test_fixture f = + begin_test(config, "test_early_server_shutdown_finishes_inflight_calls", + nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->data.send_initial_metadata.metadata = nullptr; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == 
error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* Make sure we don't shutdown the server while HTTP/2 PING frames are still + * being exchanged on the newly established connection. It can lead to + * failures when testing with HTTP proxy. See + * https://github.com/grpc/grpc/issues/14471 + */ + gpr_sleep_until(n_seconds_from_now(1)); + + /* shutdown and destroy the server */ + grpc_server_shutdown_and_notify(f.server, f.cq, tag(1000)); + grpc_server_cancel_all_calls(f.server); + + CQ_EXPECT_COMPLETION(cqv, tag(1000), 1); + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + grpc_server_destroy(f.server); + + GPR_ASSERT(status == GRPC_STATUS_UNAVAILABLE); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 1); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + end_test(&f); + config.tear_down_data(&f); +} + +void shutdown_finishes_calls(grpc_end2end_test_config config) { + test_early_server_shutdown_finishes_inflight_calls(config); +} + +void shutdown_finishes_calls_pre_init(void) {} diff --git a/test/core/end2end/tests/shutdown_finishes_tags.cc b/test/core/end2end/tests/shutdown_finishes_tags.cc new file mode 100644 index 00000000..4c4128ef --- /dev/null +++ b/test/core/end2end/tests/shutdown_finishes_tags.cc @@ -0,0 +1,109 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
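The sequence exercised above is the usual way to tear down a server that still has calls in flight: request shutdown, cancel whatever is still running, drain the resulting completions, and only then destroy the server. In-flight calls observe this as GRPC_STATUS_UNAVAILABLE on the client and cancelled == 1 on the server's GRPC_OP_RECV_CLOSE_ON_SERVER, as asserted above. A minimal sketch of that teardown order, reusing the fixture names from the test:

    grpc_server_shutdown_and_notify(f.server, f.cq, tag(1000));
    grpc_server_cancel_all_calls(f.server);   // abort anything still in flight
    CQ_EXPECT_COMPLETION(cqv, tag(1000), 1);  // shutdown notification
    cq_verify(cqv);
    grpc_server_destroy(f.server);            // only after the shutdown tag is seen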
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + /* f->shutdown_cq is not used in this test */ + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void test_early_server_shutdown_finishes_tags( + grpc_end2end_test_config config) { + grpc_end2end_test_fixture f = begin_test( + config, "test_early_server_shutdown_finishes_tags", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_call* s = reinterpret_cast(1); + grpc_call_details call_details; + grpc_metadata_array request_metadata_recv; + + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + /* upon shutdown, the server should finish all requested calls indicating + no new call */ + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + grpc_server_shutdown_and_notify(f.server, f.cq, tag(1000)); + CQ_EXPECT_COMPLETION(cqv, tag(101), 0); + CQ_EXPECT_COMPLETION(cqv, tag(1000), 1); + cq_verify(cqv); + GPR_ASSERT(s == nullptr); + + grpc_server_destroy(f.server); + + end_test(&f); + config.tear_down_data(&f); + cq_verifier_destroy(cqv); +} + +void shutdown_finishes_tags(grpc_end2end_test_config config) { + test_early_server_shutdown_finishes_tags(config); +} + +void shutdown_finishes_tags_pre_init(void) {} diff --git a/test/core/end2end/tests/simple_cacheable_request.cc b/test/core/end2end/tests/simple_cacheable_request.cc new file mode 100644 index 00000000..f14df66a --- /dev/null +++ b/test/core/end2end/tests/simple_cacheable_request.cc @@ -0,0 +1,274 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
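The companion case above covers the other half of shutdown: a grpc_server_request_call that has not yet been matched to an incoming RPC completes with success == 0 once shutdown begins, and no call object is produced, so a server's accept loop can distinguish "new call" from "server is draining". A short sketch of that handling, mirroring the test above:

    grpc_call* s = reinterpret_cast<grpc_call*>(1);  // sentinel, not a real call
    GPR_ASSERT(GRPC_CALL_OK ==
               grpc_server_request_call(f.server, &s, &call_details,
                                        &request_metadata_recv, f.cq, f.cq,
                                        tag(101)));
    grpc_server_shutdown_and_notify(f.server, f.cq, tag(1000));
    CQ_EXPECT_COMPLETION(cqv, tag(101), 0);   // pending request_call fails
    CQ_EXPECT_COMPLETION(cqv, tag(1000), 1);  // shutdown done
    cq_verify(cqv);
    GPR_ASSERT(s == nullptr);                 // sentinel cleared: no new call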
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +enum { TIMEOUT = 200000 }; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Request/response with metadata and payload.*/ +static void test_cacheable_request_response_with_metadata_and_payload( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_metadata meta_c[2] = {{grpc_slice_from_static_string("key1"), + grpc_slice_from_static_string("val1"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key2"), + grpc_slice_from_static_string("val2"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_metadata meta_s[2] = {{grpc_slice_from_static_string("key3"), + grpc_slice_from_static_string("val3"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key4"), + grpc_slice_from_static_string("val4"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_end2end_test_fixture f = begin_test( + config, "test_cacheable_request_response_with_metadata_and_payload", + nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + 
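The meta_c/meta_s initializers above deserve a note: each grpc_metadata entry is a key slice, a value slice, and a trailing block that is reserved for gRPC-internal bookkeeping and only needs to be zero-initialized, which is what the {{nullptr, nullptr, nullptr, nullptr}} braces accomplish. An equivalent, field-name-agnostic way to build one entry:

    grpc_metadata md;
    memset(&md, 0, sizeof(md));  // zeroes the reserved/internal storage
    md.key = grpc_slice_from_static_string("key1");
    md.value = grpc_slice_from_static_string("val1");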
gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_c; + op->flags = GRPC_INITIAL_METADATA_CACHEABLE_REQUEST; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_s; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + if (config.feature_mask & 
FEATURE_MASK_SUPPORTS_REQUEST_PROXYING) { + // Our simple proxy does not support cacheable requests + } else { + GPR_ASSERT(GRPC_INITIAL_METADATA_CACHEABLE_REQUEST & call_details.flags); + } + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + GPR_ASSERT(byte_buffer_eq_string(response_payload_recv, "hello you")); + GPR_ASSERT(contains_metadata(&request_metadata_recv, "key1", "val1")); + GPR_ASSERT(contains_metadata(&request_metadata_recv, "key2", "val2")); + GPR_ASSERT(contains_metadata(&initial_metadata_recv, "key3", "val3")); + GPR_ASSERT(contains_metadata(&initial_metadata_recv, "key4", "val4")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void simple_cacheable_request(grpc_end2end_test_config config) { + test_cacheable_request_response_with_metadata_and_payload(config); +} + +void simple_cacheable_request_pre_init(void) {} diff --git a/test/core/end2end/tests/simple_delayed_request.cc b/test/core/end2end/tests/simple_delayed_request.cc new file mode 100644 index 00000000..10170122 --- /dev/null +++ b/test/core/end2end/tests/simple_delayed_request.cc @@ -0,0 +1,232 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
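The flag checked just above travels with the call: the client sets GRPC_INITIAL_METADATA_CACHEABLE_REQUEST on its GRPC_OP_SEND_INITIAL_METADATA op to mark the RPC as safe to cache, and, absent an intermediary that drops it (like the simple test proxy), the server sees the same bit in call_details.flags. A minimal sketch of both ends; the caching-layer branch is hypothetical:

    // Client side: mark the request cacheable when sending initial metadata.
    grpc_op op;
    memset(&op, 0, sizeof(op));
    op.op = GRPC_OP_SEND_INITIAL_METADATA;
    op.data.send_initial_metadata.count = 0;
    op.flags = GRPC_INITIAL_METADATA_CACHEABLE_REQUEST;
    // ... start the batch as usual ...

    // Server side, after grpc_server_request_call() completes:
    if (call_details.flags & GRPC_INITIAL_METADATA_CACHEABLE_REQUEST) {
      // The peer marked this request cacheable; a caching proxy (hypothetical)
      // could serve or store the response here.
    }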
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void simple_delayed_request_body(grpc_end2end_test_config config, + grpc_end2end_test_fixture* f, + grpc_channel_args* client_args, + grpc_channel_args* server_args, + long /*delay_us*/) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f->cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + config.init_client(f, client_args); + config.init_server(f, server_args); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f->client, nullptr, GRPC_PROPAGATE_DEFAULTS, + f->cq, grpc_slice_from_static_string("/foo"), + nullptr, deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f->server, &s, &call_details, + &request_metadata_recv, f->cq, f->cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 
1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); +} + +static void test_simple_delayed_request_short(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + grpc_channel_args client_args; + grpc_arg arg_array[1]; + arg_array[0].type = GRPC_ARG_INTEGER; + arg_array[0].key = + const_cast("grpc.testing.fixed_reconnect_backoff_ms"); + arg_array[0].value.integer = 1000; + client_args.args = arg_array; + client_args.num_args = 1; + + gpr_log(GPR_INFO, "Running test: %s/%s", "test_simple_delayed_request_short", + config.name); + f = config.create_fixture(nullptr, nullptr); + + simple_delayed_request_body(config, &f, &client_args, nullptr, 100000); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_simple_delayed_request_long(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + grpc_channel_args client_args; + grpc_arg arg_array[1]; + arg_array[0].type = GRPC_ARG_INTEGER; + arg_array[0].key = + const_cast("grpc.testing.fixed_reconnect_backoff_ms"); + arg_array[0].value.integer = 1000; + client_args.args = arg_array; + client_args.num_args = 1; + + gpr_log(GPR_INFO, "Running test: %s/%s", "test_simple_delayed_request_long", + config.name); + f = config.create_fixture(nullptr, nullptr); + /* This timeout should be longer than a single retry */ + simple_delayed_request_body(config, &f, &client_args, nullptr, 1500000); + end_test(&f); + config.tear_down_data(&f); +} + +void simple_delayed_request(grpc_end2end_test_config config) { + GPR_ASSERT(config.feature_mask & FEATURE_MASK_SUPPORTS_DELAYED_CONNECTION); + test_simple_delayed_request_short(config); + test_simple_delayed_request_long(config); +} + +void simple_delayed_request_pre_init(void) {} diff --git a/test/core/end2end/tests/simple_metadata.cc b/test/core/end2end/tests/simple_metadata.cc new file mode 100644 index 00000000..74c28d4f --- /dev/null +++ b/test/core/end2end/tests/simple_metadata.cc @@ -0,0 +1,267 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
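The delayed-request tests above hinge on GRPC_INITIAL_METADATA_WAIT_FOR_READY: rather than failing fast while the channel is still connecting, the RPC is queued until the transport becomes ready, and the grpc.testing.fixed_reconnect_backoff_ms channel arg pins the reconnect backoff so the test's timing is deterministic. At the C++ API level the same behavior is opted into per call; a minimal sketch (the stub, request, and reply names are placeholders):

    #include <grpcpp/grpcpp.h>

    grpc::ClientContext ctx;
    ctx.set_wait_for_ready(true);  // queue the RPC until the channel is READY
    // Then issue the call as usual, e.g. (hypothetical stub and method):
    // grpc::Status status = stub->SomeCall(&ctx, request, &reply);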
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Request/response with metadata and payload.*/ +static void test_request_response_with_metadata_and_payload( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_metadata meta_c[2] = {{grpc_slice_from_static_string("key1"), + grpc_slice_from_static_string("val1"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key2"), + grpc_slice_from_static_string("val2"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_metadata meta_s[2] = {{grpc_slice_from_static_string("key3"), + grpc_slice_from_static_string("val3"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key4"), + grpc_slice_from_static_string("val4"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_end2end_test_fixture f = + begin_test(config, "test_request_response_with_metadata_and_payload", + nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + 
grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_c; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_s; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + 
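This is the completion-queue pattern used throughout these end2end tests: each grpc_call_start_batch() is handed an opaque tag, the batch runs asynchronously, and the cq_verifier records the expected (tag, success) pairs before cq_verify() drains the queue and checks that those completions (and nothing unexpected) arrive. A minimal sketch of one round of that bookkeeping; `call`, `cqv`, and the tag value are placeholders standing in for the fixture objects above:

    grpc_op op;
    memset(&op, 0, sizeof(op));
    op.op = GRPC_OP_SEND_INITIAL_METADATA;
    op.data.send_initial_metadata.count = 0;
    GPR_ASSERT(GRPC_CALL_OK ==
               grpc_call_start_batch(call, &op, 1, tag(42), nullptr));
    CQ_EXPECT_COMPLETION(cqv, tag(42), 1);  // expect this batch to succeed
    cq_verify(cqv);                         // waits until the expectation is met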
GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + GPR_ASSERT(byte_buffer_eq_string(response_payload_recv, "hello you")); + GPR_ASSERT(contains_metadata(&request_metadata_recv, "key1", "val1")); + GPR_ASSERT(contains_metadata(&request_metadata_recv, "key2", "val2")); + GPR_ASSERT(contains_metadata(&initial_metadata_recv, "key3", "val3")); + GPR_ASSERT(contains_metadata(&initial_metadata_recv, "key4", "val4")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void simple_metadata(grpc_end2end_test_config config) { + test_request_response_with_metadata_and_payload(config); +} + +void simple_metadata_pre_init(void) {} diff --git a/test/core/end2end/tests/simple_request.cc b/test/core/end2end/tests/simple_request.cc new file mode 100644 index 00000000..3a0b75f5 --- /dev/null +++ b/test/core/end2end/tests/simple_request.cc @@ -0,0 +1,291 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/string.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void check_peer(char* peer_name) { + // If the peer name is a uds path, then check if it is filled + if (strncmp(peer_name, "unix:/", strlen("unix:/")) == 0) { + GPR_ASSERT(strncmp(peer_name, "unix:/tmp/grpc_fullstack_test.", + strlen("unix:/tmp/grpc_fullstack_test.")) == 0); + } +} + +static void simple_request_body(grpc_end2end_test_config config, + grpc_end2end_test_fixture f) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + const char* error_string; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + char* peer; + grpc_stats_data* before = + static_cast(gpr_malloc(sizeof(grpc_stats_data))); + grpc_stats_data* after = + static_cast(gpr_malloc(sizeof(grpc_stats_data))); + +#if defined(GRPC_COLLECT_STATS) || !defined(NDEBUG) + grpc_stats_collect(before); +#endif /* defined(GRPC_COLLECT_STATS) || !defined(NDEBUG) */ + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer_before_call=%s", peer); + gpr_free(peer); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + 
grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->data.recv_status_on_client.error_string = &error_string; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(s); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "server_peer=%s", peer); + check_peer(peer); + gpr_free(peer); + peer = grpc_call_get_peer(c); + GPR_ASSERT(peer != nullptr); + gpr_log(GPR_DEBUG, "client_peer=%s", peer); + check_peer(peer); + gpr_free(peer); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + // the following sanity check makes sure that the requested error string is + // correctly populated by the core. It looks for certain substrings that are + // not likely to change much. Some parts of the error, like time created, + // obviously are not checked. 
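// The strings checked next are owned by the test once the batch completes:
// `details` was filled in through recv_status_on_client.status_details and is
// released with grpc_slice_unref(), while `error_string` was requested via
// recv_status_on_client.error_string and must be freed by the caller.  A
// minimal commented sketch of that cleanup, mirroring the real cleanup a few
// lines further down:
//
//   grpc_slice_unref(details);
//   gpr_free(const_cast<char*>(error_string));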
+ GPR_ASSERT(nullptr != strstr(error_string, "xyz")); + GPR_ASSERT(nullptr != strstr(error_string, "Error received from peer")); + GPR_ASSERT(nullptr != strstr(error_string, "grpc_message")); + GPR_ASSERT(nullptr != strstr(error_string, "grpc_status")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(0 == call_details.flags); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + gpr_free(const_cast(error_string)); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + int expected_calls = 1; + if (config.feature_mask & FEATURE_MASK_SUPPORTS_REQUEST_PROXYING) { + expected_calls *= 2; + } +#if defined(GRPC_COLLECT_STATS) || !defined(NDEBUG) + + grpc_stats_collect(after); + + gpr_log(GPR_DEBUG, "%s", grpc_stats_data_as_json(after).c_str()); + + GPR_ASSERT(after->counters[GRPC_STATS_COUNTER_CLIENT_CALLS_CREATED] - + before->counters[GRPC_STATS_COUNTER_CLIENT_CALLS_CREATED] == + expected_calls); + GPR_ASSERT(after->counters[GRPC_STATS_COUNTER_SERVER_CALLS_CREATED] - + before->counters[GRPC_STATS_COUNTER_SERVER_CALLS_CREATED] == + expected_calls); +#endif /* defined(GRPC_COLLECT_STATS) || !defined(NDEBUG) */ + gpr_free(before); + gpr_free(after); +} + +static void test_invoke_simple_request(grpc_end2end_test_config config) { + grpc_end2end_test_fixture f; + + f = begin_test(config, "test_invoke_simple_request", nullptr, nullptr); + simple_request_body(config, f); + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_10_simple_requests(grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_10_simple_requests", nullptr, nullptr); + for (i = 0; i < 10; i++) { + simple_request_body(config, f); + gpr_log(GPR_INFO, "Running test: Passed simple request %d", i); + } + end_test(&f); + config.tear_down_data(&f); +} + +void simple_request(grpc_end2end_test_config config) { + int i; + for (i = 0; i < 10; i++) { + test_invoke_simple_request(config); + } + test_invoke_10_simple_requests(config); +} + +void simple_request_pre_init(void) {} diff --git a/test/core/end2end/tests/stream_compression_compressed_payload.cc b/test/core/end2end/tests/stream_compression_compressed_payload.cc new file mode 100644 index 00000000..35aba3fd --- /dev/null +++ b/test/core/end2end/tests/stream_compression_compressed_payload.cc @@ -0,0 +1,625 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
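// simple_request.cc above verifies per-call accounting by snapshotting the
// debug-only stats before and after the calls and comparing counter deltas
// (doubling the expectation for proxy fixtures).  A minimal sketch of that
// pattern under the same GRPC_COLLECT_STATS/!NDEBUG guard; the helper name is
// illustrative only and is not part of the test suite.
#if defined(GRPC_COLLECT_STATS) || !defined(NDEBUG)
static void sketch_check_calls_created(const grpc_stats_data* before,
                                       const grpc_stats_data* after,
                                       int expected_calls) {
  GPR_ASSERT(after->counters[GRPC_STATS_COUNTER_CLIENT_CALLS_CREATED] -
                 before->counters[GRPC_STATS_COUNTER_CLIENT_CALLS_CREATED] ==
             expected_calls);
  GPR_ASSERT(after->counters[GRPC_STATS_COUNTER_SERVER_CALLS_CREATED] -
                 before->counters[GRPC_STATS_COUNTER_SERVER_CALLS_CREATED] ==
             expected_calls);
}
#endif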
+ * + */ + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/call_test_only.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +static void request_for_disabled_algorithm( + grpc_end2end_test_config config, const char* test_name, + uint32_t send_flags_bitmask, + grpc_compression_algorithm algorithm_to_disable, + grpc_compression_algorithm requested_client_compression_algorithm, + grpc_status_code expected_error, grpc_metadata* client_metadata) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice; + grpc_byte_buffer* request_payload; + grpc_channel_args* client_args; + grpc_channel_args* server_args; + grpc_end2end_test_fixture f; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + cq_verifier* cqv; + char str[1024]; + + memset(str, 'x', 1023); + str[1023] = '\0'; + request_payload_slice = grpc_slice_from_copied_string(str); + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + + client_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, requested_client_compression_algorithm); + server_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, GRPC_COMPRESS_NONE); + { + grpc_core::ExecCtx exec_ctx; + server_args = 
grpc_channel_args_compression_algorithm_set_state( + &server_args, algorithm_to_disable, false); + } + + f = begin_test(config, test_name, client_args, server_args); + cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + if (client_metadata != nullptr) { + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = client_metadata; + } else { + op->data.send_initial_metadata.count = 0; + } + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = send_flags_bitmask; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); + cq_verify(cqv); + + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), false); + + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + cq_verify(cqv); + + /* call was cancelled (closed) ... 
*/ + GPR_ASSERT(was_cancelled != 0); + /* with a certain error */ + GPR_ASSERT(status == expected_error); + + const char* algo_name = nullptr; + GPR_ASSERT(grpc_compression_algorithm_name(algorithm_to_disable, &algo_name)); + std::string expected_details = + absl::StrCat("Compression algorithm '", algo_name, "' is disabled."); + /* and we expect a specific reason for it */ + GPR_ASSERT(0 == grpc_slice_str_cmp(details, expected_details.c_str())); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_slice_unref(request_payload_slice); + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv); + + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(client_args); + grpc_channel_args_destroy(server_args); + } + + end_test(&f); + config.tear_down_data(&f); +} + +static void request_with_payload_template( + grpc_end2end_test_config config, const char* test_name, + uint32_t client_send_flags_bitmask, + grpc_compression_algorithm default_client_channel_compression_algorithm, + grpc_compression_algorithm default_server_channel_compression_algorithm, + grpc_compression_algorithm /*expected_client_compression_algorithm*/, + grpc_compression_algorithm /*expected_server_compression_algorithm*/, + grpc_metadata* client_init_metadata, bool set_server_level, + grpc_compression_level server_compression_level, + bool send_message_before_initial_metadata, + bool set_default_server_message_compression_algorithm, + grpc_compression_algorithm default_server_message_compression_algorithm) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice; + grpc_byte_buffer* request_payload = nullptr; + grpc_channel_args* client_args; + grpc_channel_args* server_args; + grpc_end2end_test_fixture f; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload; + grpc_byte_buffer* response_payload_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + cq_verifier* cqv; + char request_str[1024]; + char response_str[1024]; + + memset(request_str, 'x', 1023); + request_str[1023] = '\0'; + + memset(response_str, 'y', 1023); + response_str[1023] = '\0'; + + request_payload_slice = grpc_slice_from_copied_string(request_str); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string(response_str); + + client_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, default_client_channel_compression_algorithm); + if (set_default_server_message_compression_algorithm) { + server_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, default_server_message_compression_algorithm); + } else { + server_args = grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, default_server_channel_compression_algorithm); + } + + f = begin_test(config, test_name, client_args, server_args); + cqv = cq_verifier_create(f.cq); + + gpr_timespec deadline = five_seconds_from_now(); + c = 
grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + if (send_message_before_initial_metadata) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = client_send_flags_bitmask; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + } + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + if (client_init_metadata != nullptr) { + op->data.send_initial_metadata.count = 1; + op->data.send_initial_metadata.metadata = client_init_metadata; + } else { + op->data.send_initial_metadata.count = 0; + } + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(100)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(100), true); + cq_verify(cqv); + + GPR_ASSERT(grpc_core::BitCount( + grpc_call_test_only_get_encodings_accepted_by_peer(s)) == + GRPC_COMPRESS_ALGORITHMS_COUNT); + GPR_ASSERT( + grpc_core::GetBit(grpc_call_test_only_get_encodings_accepted_by_peer(s), + GRPC_COMPRESS_NONE) != 0); + GPR_ASSERT( + grpc_core::GetBit(grpc_call_test_only_get_encodings_accepted_by_peer(s), + GRPC_COMPRESS_DEFLATE) != 0); + GPR_ASSERT( + grpc_core::GetBit(grpc_call_test_only_get_encodings_accepted_by_peer(s), + GRPC_COMPRESS_GZIP) != 0); + GPR_ASSERT( + grpc_core::GetBit(grpc_call_test_only_get_encodings_accepted_by_peer(s), + GRPC_COMPRESS_STREAM_GZIP) != 0); + GPR_ASSERT(grpc_core::BitCount( + grpc_call_test_only_get_encodings_accepted_by_peer(s)) == + GRPC_COMPRESS_ALGORITHMS_COUNT); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + if (set_server_level) { + op->data.send_initial_metadata.maybe_compression_level.is_set = true; + op->data.send_initial_metadata.maybe_compression_level.level = + server_compression_level; + } + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + for (int i = 0; i < 2; i++) { + response_payload = grpc_raw_byte_buffer_create(&response_payload_slice, 1); + + if (i > 0 || 
!send_message_before_initial_metadata) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = client_send_flags_bitmask; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), + tag(2), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + } + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + GPR_ASSERT(request_payload_recv->type == GRPC_BB_RAW); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, request_str)); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv); + + GPR_ASSERT(response_payload_recv->type == GRPC_BB_RAW); + GPR_ASSERT(byte_buffer_eq_string(response_payload_recv, response_str)); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + } + + grpc_slice_unref(request_payload_slice); + grpc_slice_unref(response_payload_slice); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + CQ_EXPECT_COMPLETION(cqv, tag(4), 1); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + 
grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(client_args); + grpc_channel_args_destroy(server_args); + } + + end_test(&f); + config.tear_down_data(&f); +} + +static void test_invoke_request_with_compressed_payload( + grpc_end2end_test_config config) { + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload", 0, + GRPC_COMPRESS_STREAM_GZIP, GRPC_COMPRESS_STREAM_GZIP, + GRPC_COMPRESS_STREAM_GZIP, GRPC_COMPRESS_STREAM_GZIP, nullptr, + false, /* ignored */ + GRPC_COMPRESS_LEVEL_NONE, false, false, GRPC_COMPRESS_NONE); +} + +static void test_invoke_request_with_send_message_before_initial_metadata( + grpc_end2end_test_config config) { + request_with_payload_template( + config, "test_invoke_request_with_send_message_before_initial_metadata", + 0, GRPC_COMPRESS_STREAM_GZIP, GRPC_COMPRESS_STREAM_GZIP, + GRPC_COMPRESS_STREAM_GZIP, GRPC_COMPRESS_STREAM_GZIP, nullptr, + false, /* ignored */ + GRPC_COMPRESS_LEVEL_NONE, true, false, GRPC_COMPRESS_NONE); +} + +static void test_invoke_request_with_compressed_payload_md_override( + grpc_end2end_test_config config) { + grpc_metadata gzip_compression_override; + grpc_metadata identity_compression_override; + + gzip_compression_override.key = + GRPC_MDSTR_GRPC_INTERNAL_STREAM_ENCODING_REQUEST; + gzip_compression_override.value = + grpc_slice_from_static_string("stream/gzip"); + memset(&gzip_compression_override.internal_data, 0, + sizeof(gzip_compression_override.internal_data)); + + identity_compression_override.key = + GRPC_MDSTR_GRPC_INTERNAL_STREAM_ENCODING_REQUEST; + identity_compression_override.value = + grpc_slice_from_static_string("identity"); + memset(&identity_compression_override.internal_data, 0, + sizeof(identity_compression_override.internal_data)); + + /* Channel default NONE (aka IDENTITY), call override to stream GZIP */ + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload_md_override_1", 0, + GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, GRPC_COMPRESS_STREAM_GZIP, + GRPC_COMPRESS_NONE, &gzip_compression_override, false, + /*ignored*/ GRPC_COMPRESS_LEVEL_NONE, false, false, GRPC_COMPRESS_NONE); + + /* Channel default stream GZIP, call override to NONE (aka IDENTITY) */ + request_with_payload_template( + config, "test_invoke_request_with_compressed_payload_md_override_3", 0, + GRPC_COMPRESS_STREAM_GZIP, GRPC_COMPRESS_NONE, GRPC_COMPRESS_NONE, + GRPC_COMPRESS_NONE, &identity_compression_override, false, + /*ignored*/ GRPC_COMPRESS_LEVEL_NONE, false, false, GRPC_COMPRESS_NONE); +} + +static void test_invoke_request_with_disabled_algorithm( + grpc_end2end_test_config config) { + request_for_disabled_algorithm( + config, "test_invoke_request_with_disabled_algorithm", 0, + GRPC_COMPRESS_STREAM_GZIP, GRPC_COMPRESS_STREAM_GZIP, + GRPC_STATUS_UNIMPLEMENTED, nullptr); +} + +void stream_compression_compressed_payload(grpc_end2end_test_config config) { + test_invoke_request_with_compressed_payload(config); + test_invoke_request_with_send_message_before_initial_metadata(config); + test_invoke_request_with_compressed_payload_md_override(config); + test_invoke_request_with_disabled_algorithm(config); +} + +void stream_compression_compressed_payload_pre_init(void) {} diff --git a/test/core/end2end/tests/stream_compression_payload.cc b/test/core/end2end/tests/stream_compression_payload.cc new file mode 100644 index 00000000..f5522949 --- /dev/null +++ b/test/core/end2end/tests/stream_compression_payload.cc 
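// The md_override cases above build the per-call override out of an ordinary
// grpc_metadata element keyed by the reserved
// GRPC_MDSTR_GRPC_INTERNAL_STREAM_ENCODING_REQUEST slice; the value is the
// encoding name ("stream/gzip" or "identity") and the element is sent as
// normal client initial metadata.  A minimal sketch of that construction; the
// helper name is illustrative only, and the value must be a string with
// static storage duration.
static grpc_metadata sketch_stream_encoding_override(const char* encoding) {
  grpc_metadata md;
  md.key = GRPC_MDSTR_GRPC_INTERNAL_STREAM_ENCODING_REQUEST;
  md.value = grpc_slice_from_static_string(encoding);  // e.g. "stream/gzip"
  memset(&md.internal_data, 0, sizeof(md.internal_data));
  return md;
}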
@@ -0,0 +1,302 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/surface/call.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Creates and returns a grpc_slice containing random alphanumeric characters. + */ +static grpc_slice generate_random_slice() { + size_t i; + static const char chars[] = "abcdefghijklmnopqrstuvwxyz1234567890"; + char* output; + const size_t output_size = 1024 * 1024; + output = static_cast(gpr_malloc(output_size)); + for (i = 0; i < output_size - 1; ++i) { + output[i] = chars[rand() % static_cast(sizeof(chars) - 1)]; + } + output[output_size - 1] = '\0'; + grpc_slice out = grpc_slice_from_copied_string(output); + gpr_free(output); + return out; +} + +static void request_response_with_payload(grpc_end2end_test_config /*config*/, + grpc_end2end_test_fixture f) { + /* Create large request and response bodies. These are big enough to require + * multiple round trips to deliver to the peer, and their exact contents of + * will be verified on completion. 
*/ + grpc_slice request_payload_slice = generate_random_slice(); + grpc_slice response_payload_slice = generate_random_slice(); + + grpc_call* c; + grpc_call* s; + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = n_seconds_from_now(60); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + 
op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_slice(request_payload_recv, request_payload_slice)); + GPR_ASSERT( + byte_buffer_eq_slice(response_payload_recv, response_payload_slice)); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); +} + +/* Client sends a request with payload, server reads then returns a response + payload and status. */ +static void test_invoke_request_response_with_payload( + grpc_end2end_test_config config) { + grpc_channel_args* client_args = + grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, GRPC_COMPRESS_STREAM_GZIP); + grpc_channel_args* server_args = + grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, GRPC_COMPRESS_STREAM_GZIP); + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_request_response_with_payload", + client_args, server_args); + request_response_with_payload(config, f); + end_test(&f); + config.tear_down_data(&f); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(client_args); + grpc_channel_args_destroy(server_args); + } +} + +static void test_invoke_10_request_response_with_payload( + grpc_end2end_test_config config) { + int i; + grpc_end2end_test_fixture f = begin_test( + config, "test_invoke_10_request_response_with_payload", nullptr, nullptr); + for (i = 0; i < 10; i++) { + request_response_with_payload(config, f); + } + end_test(&f); + config.tear_down_data(&f); +} + +void stream_compression_payload(grpc_end2end_test_config config) { + test_invoke_request_response_with_payload(config); + test_invoke_10_request_response_with_payload(config); +} + +void stream_compression_payload_pre_init(void) {} diff --git a/test/core/end2end/tests/stream_compression_ping_pong_streaming.cc b/test/core/end2end/tests/stream_compression_ping_pong_streaming.cc new file mode 100644 index 00000000..c7e131af --- /dev/null +++ b/test/core/end2end/tests/stream_compression_ping_pong_streaming.cc @@ -0,0 +1,295 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
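// Unlike the md_override tests, stream_compression_payload.cc above turns
// stream compression on for the whole channel through channel arguments.  A
// minimal sketch of that setup and of the ExecCtx-scoped teardown these tests
// use; the helper name is illustrative only.
static grpc_channel_args* sketch_stream_gzip_channel_args() {
  return grpc_channel_args_set_channel_default_compression_algorithm(
      nullptr, GRPC_COMPRESS_STREAM_GZIP);
}
//
//   grpc_channel_args* args = sketch_stream_gzip_channel_args();
//   /* ... pass `args` to begin_test() as client and/or server args ... */
//   {
//     grpc_core::ExecCtx exec_ctx;
//     grpc_channel_args_destroy(args);
//   }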
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/compression/compression_args.h" +#include "src/core/lib/surface/call.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Client pings and server pongs. Repeat messages rounds before finishing. 
*/ +static void test_pingpong_streaming(grpc_end2end_test_config config, + int messages) { + grpc_channel_args* client_args = + grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, GRPC_COMPRESS_STREAM_GZIP); + grpc_channel_args* server_args = + grpc_channel_args_set_channel_default_compression_algorithm( + nullptr, GRPC_COMPRESS_STREAM_GZIP); + grpc_end2end_test_fixture f = + begin_test(config, "test_pingpong_streaming", client_args, server_args); + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + grpc_byte_buffer* request_payload; + grpc_byte_buffer* request_payload_recv; + grpc_byte_buffer* response_payload; + grpc_byte_buffer* response_payload_recv; + int i; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(100)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(100), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(101), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + for (i = 0; i < messages; i++) { + request_payload = grpc_raw_byte_buffer_create(&request_payload_slice, 1); + response_payload = grpc_raw_byte_buffer_create(&response_payload_slice, 1); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = 
nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(102), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), + tag(103), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + } + + grpc_slice_unref(request_payload_slice); + grpc_slice_unref(response_payload_slice); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + cq_verify(cqv); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_slice_unref(details); + + end_test(&f); + config.tear_down_data(&f); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(client_args); + grpc_channel_args_destroy(server_args); + } +} + +void stream_compression_ping_pong_streaming(grpc_end2end_test_config config) { + int i; + + for (i = 1; i < 10; i++) { + test_pingpong_streaming(config, i); + } +} + +void stream_compression_ping_pong_streaming_pre_init(void) {} diff --git a/test/core/end2end/tests/streaming_error_response.cc b/test/core/end2end/tests/streaming_error_response.cc new file mode 100644 index 00000000..178eb68f --- /dev/null +++ b/test/core/end2end/tests/streaming_error_response.cc @@ -0,0 +1,292 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
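// In the ping-pong loop above the same request/response slices are reused for
// every round: grpc_raw_byte_buffer_create() takes its own references on the
// slices it wraps, so the per-round byte buffers can be destroyed as soon as
// a round finishes while the slices themselves are unreffed only once, after
// the loop.  A minimal sketch of one round's payload lifecycle; the helper
// name is illustrative only.
static void sketch_one_round_payload(grpc_slice payload_slice) {
  grpc_byte_buffer* payload =
      grpc_raw_byte_buffer_create(&payload_slice, 1);  // adds a slice ref
  // ... hand `payload` to a GRPC_OP_SEND_MESSAGE op and start the batch ...
  grpc_byte_buffer_destroy(payload);  // drops only the buffer's references
}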
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/** \file Verify that status ordering rules are obeyed. + \ref doc/status_ordering.md */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args, + bool request_status_early) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s/request_status_early=%s", test_name, + config.name, request_status_early ? "true" : "false"); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +// Client sends a request with payload, potentially requesting status early. The +// server reads and streams responses. The client cancels the RPC to get an +// error status. (Server sending a non-OK status is not considered an error +// status.) 
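// In the test below, `request_status_early` decides whether
// GRPC_OP_RECV_STATUS_ON_CLIENT is bundled into the client's very first batch
// (tag(1)), in which case tag(1) only completes once the call has ended, or
// is issued afterwards in a batch of its own (tag(3)).  The op itself is
// wired up the same way in both cases; a minimal sketch of that wiring, with
// an illustrative helper name:
static void sketch_add_recv_status_op(grpc_op** op_cursor,
                                      grpc_metadata_array* trailing_md,
                                      grpc_status_code* status,
                                      grpc_slice* details) {
  grpc_op* op = *op_cursor;
  op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
  op->data.recv_status_on_client.trailing_metadata = trailing_md;
  op->data.recv_status_on_client.status = status;
  op->data.recv_status_on_client.status_details = details;
  op->flags = 0;
  op->reserved = nullptr;
  *op_cursor = op + 1;
}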
+static void test(grpc_end2end_test_config config, bool request_status_early, + bool recv_message_separately) { + grpc_call* c; + grpc_call* s; + grpc_slice response_payload1_slice = grpc_slice_from_copied_string("hello"); + grpc_byte_buffer* response_payload1 = + grpc_raw_byte_buffer_create(&response_payload1_slice, 1); + grpc_slice response_payload2_slice = grpc_slice_from_copied_string("world"); + grpc_byte_buffer* response_payload2 = + grpc_raw_byte_buffer_create(&response_payload2_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "streaming_error_response", nullptr, nullptr, + request_status_early); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* response_payload1_recv = nullptr; + grpc_byte_buffer* response_payload2_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status = GRPC_STATUS_OK; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + GPR_ASSERT(!recv_message_separately || request_status_early); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + if (!recv_message_separately) { + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload1_recv; + op++; + } + if (request_status_early) { + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + } + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload1; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + if (recv_message_separately) { + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload1_recv; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + } + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + if (!request_status_early) { + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + } + if (recv_message_separately) { + CQ_EXPECT_COMPLETION(cqv, tag(4), 1); + } + cq_verify(cqv); + + memset(ops, 0, 
sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload2; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + // The success of the op depends on whether the payload is written before the + // transport sees the end of stream. If the stream has been write closed + // before the write completes, it would fail, otherwise it would succeed. + // Since this behavior is dependent on the transport implementation, we allow + // any success status with this op. + CQ_EXPECT_COMPLETION_ANY_STATUS(cqv, tag(103)); + + if (!request_status_early) { + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload2_recv; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + } + + // Cancel the call so that the client sets up an error status. + grpc_call_cancel(c, nullptr); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(104), 1); + if (request_status_early) { + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + } + cq_verify(cqv); + + if (!request_status_early) { + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv); + + GPR_ASSERT(response_payload1_recv != nullptr); + GPR_ASSERT(response_payload2_recv != nullptr); + } + + GPR_ASSERT(status == GRPC_STATUS_CANCELLED); + GPR_ASSERT(was_cancelled == 1); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(response_payload1); + grpc_byte_buffer_destroy(response_payload2); + grpc_byte_buffer_destroy(response_payload1_recv); + grpc_byte_buffer_destroy(response_payload2_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void streaming_error_response(grpc_end2end_test_config config) { + test(config, false, false); + test(config, true, false); + test(config, true, true); +} + +void streaming_error_response_pre_init(void) {} diff --git a/test/core/end2end/tests/trailing_metadata.cc b/test/core/end2end/tests/trailing_metadata.cc new file mode 100644 index 00000000..1d1b883e --- /dev/null +++ b/test/core/end2end/tests/trailing_metadata.cc @@ -0,0 +1,275 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Request/response with metadata and payload.*/ +static void test_request_response_with_metadata_and_payload( + grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_slice response_payload_slice = + grpc_slice_from_copied_string("hello you"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + grpc_metadata meta_c[2] = {{grpc_slice_from_static_string("key1"), + grpc_slice_from_static_string("val1"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key2"), + grpc_slice_from_static_string("val2"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_metadata meta_s[2] = {{grpc_slice_from_static_string("key3"), + grpc_slice_from_static_string("val3"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key4"), + grpc_slice_from_static_string("val4"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_metadata meta_t[2] = {{grpc_slice_from_static_string("key5"), + grpc_slice_from_static_string("val5"), + {{nullptr, nullptr, nullptr, nullptr}}}, + {grpc_slice_from_static_string("key6"), + 
grpc_slice_from_static_string("val6"), + {{nullptr, nullptr, nullptr, nullptr}}}}; + grpc_end2end_test_fixture f = + begin_test(config, "test_request_response_with_metadata_and_payload", + nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_c; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = + grpc_server_request_call(f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 2; + op->data.send_initial_metadata.metadata = meta_s; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 2; + op->data.send_status_from_server.trailing_metadata = 
meta_t; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv, "hello world")); + GPR_ASSERT(byte_buffer_eq_string(response_payload_recv, "hello you")); + GPR_ASSERT(contains_metadata(&request_metadata_recv, "key1", "val1")); + GPR_ASSERT(contains_metadata(&request_metadata_recv, "key2", "val2")); + GPR_ASSERT(contains_metadata(&initial_metadata_recv, "key3", "val3")); + GPR_ASSERT(contains_metadata(&initial_metadata_recv, "key4", "val4")); + GPR_ASSERT(contains_metadata(&trailing_metadata_recv, "key5", "val5")); + GPR_ASSERT(contains_metadata(&trailing_metadata_recv, "key6", "val6")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + + end_test(&f); + config.tear_down_data(&f); +} + +void trailing_metadata(grpc_end2end_test_config config) { + test_request_response_with_metadata_and_payload(config); +} + +void trailing_metadata_pre_init(void) {} diff --git a/test/core/end2end/tests/write_buffering.cc b/test/core/end2end/tests/write_buffering.cc new file mode 100644 index 00000000..d3536a91 --- /dev/null +++ b/test/core/end2end/tests/write_buffering.cc @@ -0,0 +1,284 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Client sends a request with payload, server reads then returns status. 
*/ +static void test_invoke_request_with_payload(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice1 = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload1 = + grpc_raw_byte_buffer_create(&request_payload_slice1, 1); + grpc_slice request_payload_slice2 = grpc_slice_from_copied_string("abc123"); + grpc_byte_buffer* request_payload2 = + grpc_raw_byte_buffer_create(&request_payload_slice2, 1); + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_request_with_payload", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv1 = nullptr; + grpc_byte_buffer* request_payload_recv2 = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details = grpc_empty_slice(); + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); /* send message is buffered */ + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload1; + op->flags = GRPC_WRITE_BUFFER_HINT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* recv message should not succeed yet - it's buffered at the client still */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv1; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + CQ_EXPECT_COMPLETION(cqv, tag(3), true); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + /* send another message, this time not buffered */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + 
op->data.send_message.send_message = request_payload2; + op->flags = 0; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* now the first send should match up with the first recv */ + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + CQ_EXPECT_COMPLETION(cqv, tag(4), true); + cq_verify(cqv); + + /* and the next recv should be ready immediately also */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv2; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(104), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(105), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(105), 1); + CQ_EXPECT_COMPLETION(cqv, tag(4), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv1, "hello world")); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv2, "abc123")); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload1); + grpc_byte_buffer_destroy(request_payload_recv1); + grpc_byte_buffer_destroy(request_payload2); + grpc_byte_buffer_destroy(request_payload_recv2); + + end_test(&f); + config.tear_down_data(&f); +} + +void write_buffering(grpc_end2end_test_config config) { + test_invoke_request_with_payload(config); +} + +void write_buffering_pre_init(void) {} diff --git a/test/core/end2end/tests/write_buffering_at_end.cc b/test/core/end2end/tests/write_buffering_at_end.cc new file mode 100644 index 00000000..c0114d9b --- /dev/null +++ b/test/core/end2end/tests/write_buffering_at_end.cc @@ -0,0 +1,273 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include + +#include "test/core/end2end/cq_verifier.h" +#include "test/core/end2end/end2end_tests.h" + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_end2end_test_fixture begin_test(grpc_end2end_test_config config, + const char* test_name, + grpc_channel_args* client_args, + grpc_channel_args* server_args) { + grpc_end2end_test_fixture f; + gpr_log(GPR_INFO, "Running test: %s/%s", test_name, config.name); + f = config.create_fixture(client_args, server_args); + config.init_server(&f, server_args); + config.init_client(&f, client_args); + return f; +} + +static gpr_timespec n_seconds_from_now(int n) { + return grpc_timeout_seconds_to_deadline(n); +} + +static gpr_timespec five_seconds_from_now(void) { + return n_seconds_from_now(5); +} + +static void drain_cq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, five_seconds_from_now(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +static void shutdown_server(grpc_end2end_test_fixture* f) { + if (!f->server) return; + grpc_server_shutdown_and_notify(f->server, f->shutdown_cq, tag(1000)); + GPR_ASSERT(grpc_completion_queue_pluck(f->shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_server_destroy(f->server); + f->server = nullptr; +} + +static void shutdown_client(grpc_end2end_test_fixture* f) { + if (!f->client) return; + grpc_channel_destroy(f->client); + f->client = nullptr; +} + +static void end_test(grpc_end2end_test_fixture* f) { + shutdown_server(f); + shutdown_client(f); + + grpc_completion_queue_shutdown(f->cq); + drain_cq(f->cq); + grpc_completion_queue_destroy(f->cq); + grpc_completion_queue_destroy(f->shutdown_cq); +} + +/* Client sends a request with payload, server reads then returns status. 
*/ +static void test_invoke_request_with_payload(grpc_end2end_test_config config) { + grpc_call* c; + grpc_call* s; + grpc_slice request_payload_slice = + grpc_slice_from_copied_string("hello world"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_end2end_test_fixture f = + begin_test(config, "test_invoke_request_with_payload", nullptr, nullptr); + cq_verifier* cqv = cq_verifier_create(f.cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv1 = nullptr; + grpc_byte_buffer* request_payload_recv2 = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details = grpc_empty_slice(); + int was_cancelled = 2; + + gpr_timespec deadline = five_seconds_from_now(); + c = grpc_channel_create_call(f.client, nullptr, GRPC_PROPAGATE_DEFAULTS, f.cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(2), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + GPR_ASSERT(GRPC_CALL_OK == grpc_server_request_call( + f.server, &s, &call_details, + &request_metadata_recv, f.cq, f.cq, tag(101))); + CQ_EXPECT_COMPLETION(cqv, tag(1), true); /* send message is buffered */ + CQ_EXPECT_COMPLETION(cqv, tag(101), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = GRPC_WRITE_BUFFER_HINT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(3), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* recv message should not succeed yet - it's buffered at the client still */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv1; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(2), true); + CQ_EXPECT_COMPLETION(cqv, tag(3), true); + CQ_EXPECT_COMPLETION(cqv, tag(102), true); + cq_verify(cqv); + + /* send end of stream: should release the buffering */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* now the first send should match up 
with the first recv */ + CQ_EXPECT_COMPLETION(cqv, tag(103), true); + CQ_EXPECT_COMPLETION(cqv, tag(4), true); + cq_verify(cqv); + + /* and the next recv should be ready immediately also (and empty) */ + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv2; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(104), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(104), true); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(4), + nullptr); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(105), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(105), 1); + CQ_EXPECT_COMPLETION(cqv, tag(4), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT(byte_buffer_eq_string(request_payload_recv1, "hello world")); + GPR_ASSERT(request_payload_recv2 == nullptr); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(request_payload_recv1); + + end_test(&f); + config.tear_down_data(&f); +} + +void write_buffering_at_end(grpc_end2end_test_config config) { + test_invoke_request_with_payload(config); +} + +void write_buffering_at_end_pre_init(void) {} diff --git a/test/core/event_engine/endpoint_config_test.cc b/test/core/event_engine/endpoint_config_test.cc new file mode 100644 index 00000000..f6e58147 --- /dev/null +++ b/test/core/event_engine/endpoint_config_test.cc @@ -0,0 +1,45 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include + +#include +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/event_engine/endpoint_config_internal.h" +#include "test/core/util/test_config.h" + +using ::grpc_event_engine::experimental::ChannelArgsEndpointConfig; + +TEST(EndpointConfigTest, CanSRetrieveValuesFromChannelArgs) { + grpc_arg arg = grpc_channel_arg_integer_create(const_cast("arst"), 3); + const grpc_channel_args args = {1, &arg}; + ChannelArgsEndpointConfig config(&args); + EXPECT_EQ(absl::get(config.Get("arst")), 3); +} + +TEST(EndpointConfigTest, ReturnsMonostateForMissingKeys) { + ChannelArgsEndpointConfig config(nullptr); + EXPECT_TRUE( + absl::holds_alternative(config.Get("nonexistent"))); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/core/event_engine/test_suite/event_engine_test.cc b/test/core/event_engine/test_suite/event_engine_test.cc new file mode 100644 index 00000000..5ba038f2 --- /dev/null +++ b/test/core/event_engine/test_suite/event_engine_test.cc @@ -0,0 +1,28 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "test/core/event_engine/test_suite/event_engine_test.h" + +#include + +#include + +std::function()>* + g_ee_factory = nullptr; + +void SetEventEngineFactory( + std::function< + std::unique_ptr()> + factory) { + testing::AddGlobalTestEnvironment(new EventEngineTestEnvironment(factory)); +} diff --git a/test/core/event_engine/test_suite/timer_test.cc b/test/core/event_engine/test_suite/timer_test.cc new file mode 100644 index 00000000..088dc2ce --- /dev/null +++ b/test/core/event_engine/test_suite/timer_test.cc @@ -0,0 +1,173 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
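// Illustrative sketch (editor's note, not part of this diff): the timer tests below
// synchronize with their callbacks through a grpc_core::Mutex/CondVar pair instead
// of sleeping. Assuming an EventEngine `engine` is available, the basic
// wait-for-callback shape is:
//
//   grpc_core::Mutex mu;
//   grpc_core::CondVar cv;
//   bool fired = false;
//   engine->RunAt(absl::Now(), [&]() {
//     grpc_core::MutexLock lock(&mu);
//     fired = true;
//     cv.Signal();
//   });
//   grpc_core::MutexLock lock(&mu);
//   while (!fired) {  // loop guards against spurious wakeups
//     cv.WaitWithTimeout(&mu, absl::Milliseconds(100));
//   }
//
// The cancellation tests rely on the same idea plus EventEngine::Cancel(), which
// returns false once the callback has already run.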
+ +#include +#include + +#include +#include + +#include "absl/functional/bind_front.h" +#include "absl/time/time.h" + +#include +#include + +#include "src/core/lib/gprpp/sync.h" +#include "test/core/event_engine/test_suite/event_engine_test.h" + +using ::testing::ElementsAre; + +class EventEngineTimerTest : public EventEngineTest { + public: + void ScheduleCheckCB(absl::Time when, std::atomic* call_count, + std::atomic* fail_count, int total_expected); + + protected: + grpc_core::Mutex mu_; + grpc_core::CondVar cv_; + bool signaled_ ABSL_GUARDED_BY(mu_) = false; +}; + +TEST_F(EventEngineTimerTest, ImmediateCallbackIsExecutedQuickly) { + auto engine = this->NewEventEngine(); + grpc_core::MutexLock lock(&mu_); + engine->RunAt(absl::Now(), [this]() { + grpc_core::MutexLock lock(&mu_); + signaled_ = true; + cv_.Signal(); + }); + cv_.WaitWithTimeout(&mu_, absl::Seconds(5)); + ASSERT_TRUE(signaled_); +} + +TEST_F(EventEngineTimerTest, SupportsCancellation) { + auto engine = this->NewEventEngine(); + auto handle = engine->RunAt(absl::InfiniteFuture(), []() {}); + ASSERT_TRUE(engine->Cancel(handle)); +} + +TEST_F(EventEngineTimerTest, CancelledCallbackIsNotExecuted) { + { + auto engine = this->NewEventEngine(); + auto handle = engine->RunAt(absl::InfiniteFuture(), [this]() { + grpc_core::MutexLock lock(&mu_); + signaled_ = true; + }); + ASSERT_TRUE(engine->Cancel(handle)); + } + // The engine is deleted, and all closures should have been flushed + grpc_core::MutexLock lock(&mu_); + ASSERT_FALSE(signaled_); +} + +TEST_F(EventEngineTimerTest, TimersRespectScheduleOrdering) { + // Note: this is a brittle test if the first call to `RunAt` takes longer than + // the second callback's wait time. + std::vector ordered; + uint8_t count = 0; + grpc_core::MutexLock lock(&mu_); + { + auto engine = this->NewEventEngine(); + engine->RunAt(absl::Now() + absl::Seconds(1), [&]() { + grpc_core::MutexLock lock(&mu_); + ordered.push_back(2); + ++count; + cv_.Signal(); + }); + engine->RunAt(absl::Now(), [&]() { + grpc_core::MutexLock lock(&mu_); + ordered.push_back(1); + ++count; + cv_.Signal(); + }); + // Ensure both callbacks have run. Simpler than a mutex. + while (count != 2) { + cv_.WaitWithTimeout(&mu_, absl::Microseconds(100)); + } + } + // The engine is deleted, and all closures should have been flushed beforehand + ASSERT_THAT(ordered, ElementsAre(1, 2)); +} + +TEST_F(EventEngineTimerTest, CancellingExecutedCallbackIsNoopAndReturnsFalse) { + auto engine = this->NewEventEngine(); + grpc_core::MutexLock lock(&mu_); + auto handle = engine->RunAt(absl::Now(), [this]() { + grpc_core::MutexLock lock(&mu_); + signaled_ = true; + cv_.Signal(); + }); + cv_.WaitWithTimeout(&mu_, absl::Seconds(10)); + ASSERT_TRUE(signaled_); + // The callback has run, and now we'll try to cancel it. + ASSERT_FALSE(engine->Cancel(handle)); +} + +void EventEngineTimerTest::ScheduleCheckCB(absl::Time when, + std::atomic* call_count, + std::atomic* fail_count, + int total_expected) { + // TODO(hork): make the EventEngine the time source of truth! libuv supports + // millis, absl::Time reports in nanos. This generic test will be hard-coded + // to the lowest common denominator until EventEngines can compare relative + // times with supported resolution. 
+ int64_t now_millis = absl::ToUnixMillis(absl::Now()); + int64_t when_millis = absl::ToUnixMillis(when); + EXPECT_LE(when_millis, now_millis); + if (when_millis > now_millis) ++(*fail_count); + if (++(*call_count) == total_expected) { + grpc_core::MutexLock lock(&mu_); + signaled_ = true; + cv_.Signal(); + } +} + +TEST_F(EventEngineTimerTest, StressTestTimersNotCalledBeforeScheduled) { + auto engine = this->NewEventEngine(); + constexpr int thread_count = 100; + constexpr int call_count_per_thread = 100; + constexpr float timeout_min_seconds = 1; + constexpr float timeout_max_seconds = 10; + std::atomic call_count{0}; + std::atomic failed_call_count{0}; + std::vector threads; + threads.reserve(thread_count); + for (int thread_n = 0; thread_n < thread_count; ++thread_n) { + threads.emplace_back([&]() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(timeout_min_seconds, + timeout_max_seconds); + for (int call_n = 0; call_n < call_count_per_thread; ++call_n) { + absl::Time when = absl::Now() + absl::Seconds(dis(gen)); + engine->RunAt( + when, absl::bind_front(&EventEngineTimerTest::ScheduleCheckCB, this, + when, &call_count, &failed_call_count, + thread_count * call_count_per_thread)); + } + }); + } + for (auto& t : threads) { + t.join(); + } + grpc_core::MutexLock lock(&mu_); + // to protect against spurious wakeups. + while (!signaled_) { + cv_.Wait(&mu_); + } + gpr_log(GPR_DEBUG, "failed timer count: %d of %d", failed_call_count.load(), + thread_count * call_count); + ASSERT_EQ(0, failed_call_count.load()); +} diff --git a/test/core/fling/client.cc b/test/core/fling/client.cc new file mode 100644 index 00000000..5fa5ecc3 --- /dev/null +++ b/test/core/fling/client.cc @@ -0,0 +1,247 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/profiling/timers.h" +#include "test/core/util/cmdline.h" +#include "test/core/util/grpc_profiler.h" +#include "test/core/util/histogram.h" +#include "test/core/util/test_config.h" + +static grpc_histogram* histogram; +static grpc_byte_buffer* the_buffer; +static grpc_channel* channel; +static grpc_completion_queue* cq; +static grpc_call* call; +static grpc_op ops[6]; +static grpc_op stream_init_ops[2]; +static grpc_op stream_step_ops[2]; +static grpc_metadata_array initial_metadata_recv; +static grpc_metadata_array trailing_metadata_recv; +static grpc_byte_buffer* response_payload_recv = nullptr; +static grpc_status_code status; +static grpc_slice details; +static grpc_op* op; + +static void init_ping_pong_request(void) { + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + memset(ops, 0, sizeof(ops)); + op = ops; + + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = the_buffer; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; +} + +static void step_ping_pong_request(void) { + GPR_TIMER_SCOPE("ping_pong", 1); + grpc_slice host = grpc_slice_from_static_string("localhost"); + call = grpc_channel_create_call( + channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/Reflector/reflectUnary"), &host, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call, ops, + (size_t)(op - ops), (void*)1, + nullptr)); + grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + grpc_call_unref(call); + grpc_byte_buffer_destroy(response_payload_recv); + call = nullptr; +} + +static void init_ping_pong_stream(void) { + grpc_metadata_array_init(&initial_metadata_recv); + + grpc_call_error error; + grpc_slice host = grpc_slice_from_static_string("localhost"); + call = grpc_channel_create_call( + channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/Reflector/reflectStream"), &host, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + stream_init_ops[0].op = GRPC_OP_SEND_INITIAL_METADATA; + stream_init_ops[0].data.send_initial_metadata.count = 0; + stream_init_ops[1].op = GRPC_OP_RECV_INITIAL_METADATA; + stream_init_ops[1].data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv; + error = grpc_call_start_batch(call, stream_init_ops, 2, + reinterpret_cast(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + + grpc_metadata_array_init(&initial_metadata_recv); + + stream_step_ops[0].op = GRPC_OP_SEND_MESSAGE; + stream_step_ops[0].data.send_message.send_message = the_buffer; + stream_step_ops[1].op = GRPC_OP_RECV_MESSAGE; + stream_step_ops[1].data.recv_message.recv_message = &response_payload_recv; +} + +static void 
step_ping_pong_stream(void) { + GPR_TIMER_SCOPE("ping_pong", 1); + grpc_call_error error; + error = grpc_call_start_batch(call, stream_step_ops, 2, + reinterpret_cast(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + grpc_byte_buffer_destroy(response_payload_recv); +} + +static double now(void) { + gpr_timespec tv = gpr_now(GPR_CLOCK_REALTIME); + return 1e9 * static_cast(tv.tv_sec) + tv.tv_nsec; +} + +typedef struct { + const char* name; + void (*init)(); + void (*do_one_step)(); +} scenario; + +static const scenario scenarios[] = { + {"ping-pong-request", init_ping_pong_request, step_ping_pong_request}, + {"ping-pong-stream", init_ping_pong_stream, step_ping_pong_stream}, +}; + +int main(int argc, char** argv) { + grpc_slice slice = grpc_slice_from_copied_string("x"); + double start, stop; + unsigned i; + + char* fake_argv[1]; + + int payload_size = 1; + int secure = 0; + const char* target = "localhost:443"; + gpr_cmdline* cl; + grpc_event event; + const char* scenario_name = "ping-pong-request"; + scenario sc = {nullptr, nullptr, nullptr}; + + gpr_timers_set_log_filename("latency_trace.fling_client.txt"); + + GPR_ASSERT(argc >= 1); + fake_argv[0] = argv[0]; + grpc::testing::TestEnvironment env(1, fake_argv); + + grpc_init(); + + int warmup_seconds = 1; + int benchmark_seconds = 5; + + cl = gpr_cmdline_create("fling client"); + gpr_cmdline_add_int(cl, "payload_size", "Size of the payload to send", + &payload_size); + gpr_cmdline_add_string(cl, "target", "Target host:port", &target); + gpr_cmdline_add_flag(cl, "secure", "Run with security?", &secure); + gpr_cmdline_add_string(cl, "scenario", "Scenario", &scenario_name); + gpr_cmdline_add_int(cl, "warmup", "Warmup seconds", &warmup_seconds); + gpr_cmdline_add_int(cl, "benchmark", "Benchmark seconds", &benchmark_seconds); + gpr_cmdline_parse(cl, argc, argv); + gpr_cmdline_destroy(cl); + + for (i = 0; i < GPR_ARRAY_SIZE(scenarios); i++) { + if (0 == strcmp(scenarios[i].name, scenario_name)) { + sc = scenarios[i]; + } + } + if (!sc.name) { + fprintf(stderr, "unsupported scenario '%s'. 
Valid are:", scenario_name); + fflush(stderr); + for (i = 0; i < GPR_ARRAY_SIZE(scenarios); i++) { + fprintf(stderr, " %s", scenarios[i].name); + fflush(stderr); + } + return 1; + } + + channel = grpc_insecure_channel_create(target, nullptr, nullptr); + cq = grpc_completion_queue_create_for_next(nullptr); + the_buffer = + grpc_raw_byte_buffer_create(&slice, static_cast(payload_size)); + histogram = grpc_histogram_create(0.01, 60e9); + + sc.init(); + + gpr_timespec end_warmup = grpc_timeout_seconds_to_deadline(warmup_seconds); + gpr_timespec end_profiling = + grpc_timeout_seconds_to_deadline(warmup_seconds + benchmark_seconds); + + while (gpr_time_cmp(gpr_now(end_warmup.clock_type), end_warmup) < 0) { + sc.do_one_step(); + } + + gpr_log(GPR_INFO, "start profiling"); + grpc_profiler_start("client.prof"); + while (gpr_time_cmp(gpr_now(end_profiling.clock_type), end_profiling) < 0) { + start = now(); + sc.do_one_step(); + stop = now(); + grpc_histogram_add(histogram, stop - start); + } + grpc_profiler_stop(); + + if (call) { + grpc_call_unref(call); + } + + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + do { + event = grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + } while (event.type != GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(cq); + grpc_byte_buffer_destroy(the_buffer); + grpc_slice_unref(slice); + + gpr_log(GPR_INFO, "latency (50/95/99/99.9): %f/%f/%f/%f", + grpc_histogram_percentile(histogram, 50), + grpc_histogram_percentile(histogram, 95), + grpc_histogram_percentile(histogram, 99), + grpc_histogram_percentile(histogram, 99.9)); + grpc_histogram_destroy(histogram); + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/fling/fling_stream_test.cc b/test/core/fling/fling_stream_test.cc new file mode 100644 index 00000000..80b1ebda --- /dev/null +++ b/test/core/fling/fling_stream_test.cc @@ -0,0 +1,80 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include "src/core/lib/gprpp/host_port.h" +#include "test/core/util/port.h" +#include "test/core/util/subprocess.h" + +int main(int /*argc*/, char** argv) { + char* me = argv[0]; + char* lslash = strrchr(me, '/'); + char root[1024]; + int port = grpc_pick_unused_port_or_die(); + char* args[10]; + int status; + gpr_subprocess *svr, *cli; + /* figure out where we are */ + if (lslash) { + memcpy(root, me, static_cast(lslash - me)); + root[lslash - me] = 0; + } else { + strcpy(root, "."); + } + /* start the server */ + std::string command = + absl::StrCat(root, "/fling_server", gpr_subprocess_binary_extension()); + args[0] = const_cast(command.c_str()); + args[1] = const_cast("--bind"); + std::string joined = grpc_core::JoinHostPort("::", port); + args[2] = const_cast(joined.c_str()); + args[3] = const_cast("--no-secure"); + svr = gpr_subprocess_create(4, const_cast(args)); + + /* start the client */ + command = + absl::StrCat(root, "/fling_client", gpr_subprocess_binary_extension()); + args[0] = const_cast(command.c_str()); + args[1] = const_cast("--target"); + joined = grpc_core::JoinHostPort("127.0.0.1", port); + args[2] = const_cast(joined.c_str()); + args[3] = const_cast("--scenario=ping-pong-stream"); + args[4] = const_cast("--no-secure"); + args[5] = nullptr; + cli = gpr_subprocess_create(6, const_cast(args)); + + /* wait for completion */ + printf("waiting for client\n"); + if ((status = gpr_subprocess_join(cli))) { + gpr_subprocess_destroy(cli); + gpr_subprocess_destroy(svr); + return status; + } + gpr_subprocess_destroy(cli); + + gpr_subprocess_interrupt(svr); + status = gpr_subprocess_join(svr); + gpr_subprocess_destroy(svr); + return status; +} diff --git a/test/core/fling/fling_test.cc b/test/core/fling/fling_test.cc new file mode 100644 index 00000000..0f0d8fe4 --- /dev/null +++ b/test/core/fling/fling_test.cc @@ -0,0 +1,82 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "test/core/util/port.h" +#include "test/core/util/subprocess.h" + +int main(int /*argc*/, const char** argv) { + const char* me = argv[0]; + const char* lslash = strrchr(me, '/'); + char root[1024]; + int port = grpc_pick_unused_port_or_die(); + char* args[10]; + int status; + gpr_subprocess *svr, *cli; + /* figure out where we are */ + if (lslash) { + memcpy(root, me, static_cast(lslash - me)); + root[lslash - me] = 0; + } else { + strcpy(root, "."); + } + /* start the server */ + std::string command = + absl::StrCat(root, "/fling_server", gpr_subprocess_binary_extension()); + args[0] = const_cast(command.c_str()); + args[1] = const_cast("--bind"); + std::string joined = grpc_core::JoinHostPort("::", port); + args[2] = const_cast(joined.c_str()); + args[3] = const_cast("--no-secure"); + svr = gpr_subprocess_create(4, const_cast(args)); + + /* start the client */ + command = + absl::StrCat(root, "/fling_client", gpr_subprocess_binary_extension()); + args[0] = const_cast(command.c_str()); + args[1] = const_cast("--target"); + joined = grpc_core::JoinHostPort("127.0.0.1", port); + args[2] = const_cast(joined.c_str()); + args[3] = const_cast("--scenario=ping-pong-request"); + args[4] = const_cast("--no-secure"); + args[5] = nullptr; + cli = gpr_subprocess_create(6, const_cast(args)); + + /* wait for completion */ + printf("waiting for client\n"); + if ((status = gpr_subprocess_join(cli))) { + gpr_subprocess_destroy(cli); + gpr_subprocess_destroy(svr); + return status; + } + gpr_subprocess_destroy(cli); + + gpr_subprocess_interrupt(svr); + status = gpr_subprocess_join(svr); + gpr_subprocess_destroy(svr); + return status; +} diff --git a/test/core/fling/server.cc b/test/core/fling/server.cc new file mode 100644 index 00000000..441cf002 --- /dev/null +++ b/test/core/fling/server.cc @@ -0,0 +1,327 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include + +#include +#include +#ifndef _WIN32 +/* This is for _exit() below, which is temporary. */ +#include +#endif + +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/profiling/timers.h" +#include "test/core/end2end/data/ssl_test_data.h" +#include "test/core/util/cmdline.h" +#include "test/core/util/grpc_profiler.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +static grpc_completion_queue* cq; +static grpc_server* server; +static grpc_call* call; +static grpc_call_details call_details; +static grpc_metadata_array request_metadata_recv; +static grpc_metadata_array initial_metadata_send; +static grpc_byte_buffer* payload_buffer = nullptr; +/* Used to drain the terminal read in unary calls. 
*/ +static grpc_byte_buffer* terminal_buffer = nullptr; + +static grpc_op read_op; +static grpc_op metadata_send_op; +static grpc_op write_op; +static grpc_op status_op[2]; +static int was_cancelled = 2; +static grpc_op unary_ops[6]; +static int got_sigint = 0; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +typedef enum { + FLING_SERVER_NEW_REQUEST = 1, + FLING_SERVER_READ_FOR_UNARY, + FLING_SERVER_BATCH_OPS_FOR_UNARY, + FLING_SERVER_SEND_INIT_METADATA_FOR_STREAMING, + FLING_SERVER_READ_FOR_STREAMING, + FLING_SERVER_WRITE_FOR_STREAMING, + FLING_SERVER_SEND_STATUS_FOR_STREAMING +} fling_server_tags; + +typedef struct { + gpr_refcount pending_ops; + uint32_t flags; +} call_state; + +static void request_call(void) { + grpc_metadata_array_init(&request_metadata_recv); + GPR_ASSERT(GRPC_CALL_OK == + grpc_server_request_call(server, &call, &call_details, + &request_metadata_recv, cq, cq, + tag(FLING_SERVER_NEW_REQUEST))); +} + +static void handle_unary_method(void) { + grpc_op* op; + grpc_call_error error; + + grpc_metadata_array_init(&initial_metadata_send); + + op = unary_ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &terminal_buffer; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + if (payload_buffer == nullptr) { + gpr_log(GPR_INFO, "NULL payload buffer !!!"); + } + op->data.send_message.send_message = payload_buffer; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status_details = nullptr; + op++; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op++; + + error = grpc_call_start_batch(call, unary_ops, + static_cast(op - unary_ops), + tag(FLING_SERVER_BATCH_OPS_FOR_UNARY), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); +} + +static void send_initial_metadata(void) { + grpc_call_error error; + void* tagarg = tag(FLING_SERVER_SEND_INIT_METADATA_FOR_STREAMING); + grpc_metadata_array_init(&initial_metadata_send); + metadata_send_op.op = GRPC_OP_SEND_INITIAL_METADATA; + metadata_send_op.data.send_initial_metadata.count = 0; + error = grpc_call_start_batch(call, &metadata_send_op, 1, tagarg, nullptr); + + GPR_ASSERT(GRPC_CALL_OK == error); +} + +static void start_read_op(int t) { + grpc_call_error error; + /* Starting read at server */ + read_op.op = GRPC_OP_RECV_MESSAGE; + read_op.data.recv_message.recv_message = &payload_buffer; + error = grpc_call_start_batch(call, &read_op, 1, tag(t), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); +} + +static void start_write_op(void) { + grpc_call_error error; + void* tagarg = tag(FLING_SERVER_WRITE_FOR_STREAMING); + /* Starting write at server */ + write_op.op = GRPC_OP_SEND_MESSAGE; + if (payload_buffer == nullptr) { + gpr_log(GPR_INFO, "NULL payload buffer !!!"); + } + write_op.data.send_message.send_message = payload_buffer; + error = grpc_call_start_batch(call, &write_op, 1, tagarg, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); +} + +static void start_send_status(void) { + grpc_call_error error; + void* tagarg = tag(FLING_SERVER_SEND_STATUS_FOR_STREAMING); + status_op[0].op = GRPC_OP_SEND_STATUS_FROM_SERVER; + status_op[0].data.send_status_from_server.status = GRPC_STATUS_OK; + status_op[0].data.send_status_from_server.trailing_metadata_count = 0; + 
status_op[0].data.send_status_from_server.status_details = nullptr; + status_op[1].op = GRPC_OP_RECV_CLOSE_ON_SERVER; + status_op[1].data.recv_close_on_server.cancelled = &was_cancelled; + + error = grpc_call_start_batch(call, status_op, 2, tagarg, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); +} + +/* We have some sort of deadlock, so let's not exit gracefully for now. + When that is resolved, please remove the #include above. */ +static void sigint_handler(int /*x*/) { _exit(0); } + +int main(int argc, char** argv) { + grpc_event ev; + call_state* s; + std::string addr_buf; + gpr_cmdline* cl; + grpc_completion_queue* shutdown_cq; + int shutdown_started = 0; + int shutdown_finished = 0; + + int secure = 0; + const char* addr = nullptr; + + char* fake_argv[1]; + + gpr_timers_set_log_filename("latency_trace.fling_server.txt"); + + GPR_ASSERT(argc >= 1); + fake_argv[0] = argv[0]; + grpc_test_init(1, fake_argv); + + grpc_init(); + srand(static_cast(clock())); + + cl = gpr_cmdline_create("fling server"); + gpr_cmdline_add_string(cl, "bind", "Bind host:port", &addr); + gpr_cmdline_add_flag(cl, "secure", "Run with security?", &secure); + gpr_cmdline_parse(cl, argc, argv); + gpr_cmdline_destroy(cl); + + if (addr == nullptr) { + addr_buf = grpc_core::JoinHostPort("::", grpc_pick_unused_port_or_die()); + addr = addr_buf.c_str(); + } + gpr_log(GPR_INFO, "creating server on: %s", addr); + + cq = grpc_completion_queue_create_for_next(nullptr); + if (secure) { + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {test_server1_key, + test_server1_cert}; + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + nullptr, &pem_key_cert_pair, 1, 0, nullptr); + server = grpc_server_create(nullptr, nullptr); + GPR_ASSERT(grpc_server_add_secure_http2_port(server, addr, ssl_creds)); + grpc_server_credentials_release(ssl_creds); + } else { + server = grpc_server_create(nullptr, nullptr); + GPR_ASSERT(grpc_server_add_insecure_http2_port(server, addr)); + } + grpc_server_register_completion_queue(server, cq, nullptr); + grpc_server_start(server); + + addr = nullptr; + addr_buf.clear(); + + grpc_call_details_init(&call_details); + + request_call(); + + grpc_profiler_start("server.prof"); + signal(SIGINT, sigint_handler); + while (!shutdown_finished) { + if (got_sigint && !shutdown_started) { + gpr_log(GPR_INFO, "Shutting down due to SIGINT"); + + shutdown_cq = grpc_completion_queue_create_for_pluck(nullptr); + grpc_server_shutdown_and_notify(server, shutdown_cq, tag(1000)); + + GPR_ASSERT(grpc_completion_queue_pluck( + shutdown_cq, tag(1000), + grpc_timeout_seconds_to_deadline(5), nullptr) + .type == GRPC_OP_COMPLETE); + grpc_completion_queue_destroy(shutdown_cq); + + grpc_completion_queue_shutdown(cq); + shutdown_started = 1; + } + ev = grpc_completion_queue_next( + cq, + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1000000, GPR_TIMESPAN)), + nullptr); + s = static_cast(ev.tag); + switch (ev.type) { + case GRPC_OP_COMPLETE: + switch (reinterpret_cast(s)) { + case FLING_SERVER_NEW_REQUEST: + if (call != nullptr) { + if (0 == grpc_slice_str_cmp(call_details.method, + "/Reflector/reflectStream")) { + /* Received streaming call. Send metadata here. */ + start_read_op(FLING_SERVER_READ_FOR_STREAMING); + send_initial_metadata(); + } else { + /* Received unary call. Can do all ops in one batch. 
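                 The request payload is read first; once
                 FLING_SERVER_READ_FOR_UNARY completes, handle_unary_method()
                 starts the remaining five ops (initial metadata, terminal
                 read, response, status, close-on-server) as one batch.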
*/ + start_read_op(FLING_SERVER_READ_FOR_UNARY); + } + } else { + GPR_ASSERT(shutdown_started); + } + /* request_call(); + */ + break; + case FLING_SERVER_READ_FOR_STREAMING: + if (payload_buffer != nullptr) { + /* Received payload from client. */ + start_write_op(); + } else { + /* Received end of stream from client. */ + start_send_status(); + } + break; + case FLING_SERVER_WRITE_FOR_STREAMING: + /* Write completed at server */ + grpc_byte_buffer_destroy(payload_buffer); + payload_buffer = nullptr; + start_read_op(FLING_SERVER_READ_FOR_STREAMING); + break; + case FLING_SERVER_SEND_INIT_METADATA_FOR_STREAMING: + /* Metadata send completed at server */ + break; + case FLING_SERVER_SEND_STATUS_FOR_STREAMING: + /* Send status and close completed at server */ + grpc_call_unref(call); + if (!shutdown_started) request_call(); + break; + case FLING_SERVER_READ_FOR_UNARY: + /* Finished payload read for unary. Start all reamaining + * unary ops in a batch. + */ + handle_unary_method(); + break; + case FLING_SERVER_BATCH_OPS_FOR_UNARY: + /* Finished unary call. */ + grpc_byte_buffer_destroy(payload_buffer); + payload_buffer = nullptr; + grpc_call_unref(call); + if (!shutdown_started) request_call(); + break; + } + break; + case GRPC_QUEUE_SHUTDOWN: + GPR_ASSERT(shutdown_started); + shutdown_finished = 1; + break; + case GRPC_QUEUE_TIMEOUT: + break; + } + } + grpc_profiler_stop(); + grpc_call_details_destroy(&call_details); + + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq); + grpc_shutdown(); + return 0; +} diff --git a/test/core/gpr/alloc_test.cc b/test/core/gpr/alloc_test.cc new file mode 100644 index 00000000..afa4dccd --- /dev/null +++ b/test/core/gpr/alloc_test.cc @@ -0,0 +1,40 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +static void test_malloc_aligned() { + for (size_t size = 1; size <= 256; ++size) { + void* ptr = gpr_malloc_aligned(size, 16); + GPR_ASSERT(ptr != nullptr); + GPR_ASSERT(((intptr_t)ptr & 0xf) == 0); + memset(ptr, 0, size); + gpr_free_aligned(ptr); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_malloc_aligned(); + return 0; +} diff --git a/test/core/gpr/arena_test.cc b/test/core/gpr/arena_test.cc new file mode 100644 index 00000000..64bd0589 --- /dev/null +++ b/test/core/gpr/arena_test.cc @@ -0,0 +1,131 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gprpp/arena.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +using grpc_core::Arena; + +static void test_noop(void) { Arena::Create(1)->Destroy(); } + +static void test(const char* name, size_t init_size, const size_t* allocs, + size_t nallocs) { + std::vector parts; + parts.push_back( + absl::StrFormat("test '%s': %" PRIdPTR " <- {", name, init_size)); + for (size_t i = 0; i < nallocs; i++) { + parts.push_back(absl::StrFormat("%" PRIdPTR ",", allocs[i])); + } + parts.push_back("}"); + std::string s = absl::StrJoin(parts, ""); + gpr_log(GPR_INFO, "%s", s.c_str()); + + Arena* a = Arena::Create(init_size); + void** ps = static_cast(gpr_zalloc(sizeof(*ps) * nallocs)); + for (size_t i = 0; i < nallocs; i++) { + ps[i] = a->Alloc(allocs[i]); + // ensure the returned address is aligned + GPR_ASSERT(((intptr_t)ps[i] & 0xf) == 0); + // ensure no duplicate results + for (size_t j = 0; j < i; j++) { + GPR_ASSERT(ps[i] != ps[j]); + } + // ensure writable + memset(ps[i], 1, allocs[i]); + } + a->Destroy(); + gpr_free(ps); +} + +#define TEST(name, init_size, ...) \ + static const size_t allocs_##name[] = {__VA_ARGS__}; \ + test(#name, init_size, allocs_##name, GPR_ARRAY_SIZE(allocs_##name)) + +#define CONCURRENT_TEST_THREADS 10 + +size_t concurrent_test_iterations() { + if (sizeof(void*) < 8) return 1000; + return 100000; +} + +typedef struct { + gpr_event ev_start; + Arena* arena; +} concurrent_test_args; + +static void concurrent_test_body(void* arg) { + concurrent_test_args* a = static_cast(arg); + gpr_event_wait(&a->ev_start, gpr_inf_future(GPR_CLOCK_REALTIME)); + for (size_t i = 0; i < concurrent_test_iterations(); i++) { + *static_cast(a->arena->Alloc(1)) = static_cast(i); + } +} + +static void concurrent_test(void) { + gpr_log(GPR_DEBUG, "concurrent_test"); + + concurrent_test_args args; + gpr_event_init(&args.ev_start); + args.arena = Arena::Create(1024); + + grpc_core::Thread thds[CONCURRENT_TEST_THREADS]; + + for (int i = 0; i < CONCURRENT_TEST_THREADS; i++) { + thds[i] = + grpc_core::Thread("grpc_concurrent_test", concurrent_test_body, &args); + thds[i].Start(); + } + + gpr_event_set(&args.ev_start, reinterpret_cast(1)); + + for (auto& th : thds) { + th.Join(); + } + + args.arena->Destroy(); +} + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + + test_noop(); + TEST(0_1, 0, 1); + TEST(1_1, 1, 1); + TEST(1_2, 1, 2); + TEST(1_3, 1, 3); + TEST(1_inc, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + TEST(6_123, 6, 1, 2, 3); + concurrent_test(); + + return 0; +} diff --git a/test/core/gpr/cpu_test.cc b/test/core/gpr/cpu_test.cc new file mode 100644 index 00000000..90c72bad --- /dev/null +++ b/test/core/gpr/cpu_test.cc @@ -0,0 +1,151 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test gpr per-cpu support: + gpr_cpu_num_cores() + gpr_cpu_current_cpu() +*/ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +/* Test structure is essentially: + 1) Figure out how many cores are present on the test system + 2) Create 3 times that many threads + 3) Have each thread do some amount of work (basically want to + gaurantee that all threads are running at once, and enough of them + to run on all cores). + 4) Each thread checks what core it is running on, and marks that core + as "used" in the test. + 5) Count number of "used" cores. + + The test will fail if: + 1) gpr_cpu_num_cores() == 0 + 2) Any result from gpr_cpu_current_cpu() >= gpr_cpu_num_cores() + 3) Ideally, we would fail if not all cores were seen as used. Unfortunately, + this is only probabilistically true, and depends on the OS, it's + scheduler, etc. So we just print out an indication of how many were seen; + hopefully developers can use this to sanity check their system. +*/ + +/* Status shared across threads */ +struct cpu_test { + gpr_mu mu; + int nthreads; + uint32_t ncores; + int is_done; + gpr_cv done_cv; + int* used; /* is this core used? */ + unsigned r; /* random number */ +}; + +static void worker_thread(void* arg) { + struct cpu_test* ct = static_cast(arg); + uint32_t cpu; + unsigned r = 12345678; + unsigned i, j; + /* Avoid repetitive division calculations */ + int64_t max_i = 1000 / grpc_test_slowdown_factor(); + int64_t max_j = 1000 / grpc_test_slowdown_factor(); + for (i = 0; i < max_i; i++) { + /* run for a bit - just calculate something random. */ + for (j = 0; j < max_j; j++) { + r = (r * 17) & ((r - i) | (r * i)); + } + cpu = gpr_cpu_current_cpu(); + GPR_ASSERT(cpu < ct->ncores); + gpr_mu_lock(&ct->mu); + ct->used[cpu] = 1; + for (j = 0; j < ct->ncores; j++) { + if (!ct->used[j]) break; + } + gpr_mu_unlock(&ct->mu); + if (j == ct->ncores) { + break; /* all cpus have been used - no further use in running this test */ + } + } + gpr_mu_lock(&ct->mu); + ct->r = r; /* make it look like we care about r's value... 
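                  (storing r keeps an optimizing compiler from discarding the
                  busy-work loop above, which exists only to keep this thread
                  running long enough to be observed on some core)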
*/ + ct->nthreads--; + if (ct->nthreads == 0) { + ct->is_done = 1; + gpr_cv_signal(&ct->done_cv); + } + gpr_mu_unlock(&ct->mu); +} + +static void cpu_test(void) { + uint32_t i; + int cores_seen = 0; + struct cpu_test ct; + ct.ncores = gpr_cpu_num_cores(); + GPR_ASSERT(ct.ncores > 0); + ct.nthreads = static_cast(ct.ncores) * 3; + ct.used = static_cast(gpr_malloc(ct.ncores * sizeof(int))); + memset(ct.used, 0, ct.ncores * sizeof(int)); + gpr_mu_init(&ct.mu); + gpr_cv_init(&ct.done_cv); + ct.is_done = 0; + + uint32_t nthreads = ct.ncores * 3; + grpc_core::Thread* thd = + static_cast(gpr_malloc(sizeof(*thd) * nthreads)); + + for (i = 0; i < nthreads; i++) { + thd[i] = grpc_core::Thread("grpc_cpu_test", &worker_thread, &ct); + thd[i].Start(); + } + gpr_mu_lock(&ct.mu); + while (!ct.is_done) { + gpr_cv_wait(&ct.done_cv, &ct.mu, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&ct.mu); + for (i = 0; i < nthreads; i++) { + thd[i].Join(); + } + gpr_free(thd); + fprintf(stderr, "Saw cores ["); + fflush(stderr); + for (i = 0; i < ct.ncores; i++) { + if (ct.used[i]) { + fprintf(stderr, "%d,", i); + fflush(stderr); + cores_seen++; + } + } + fprintf(stderr, "] (%d/%d)\n", cores_seen, ct.ncores); + fflush(stderr); + gpr_mu_destroy(&ct.mu); + gpr_cv_destroy(&ct.done_cv); + gpr_free(ct.used); +} + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + cpu_test(); + return 0; +} diff --git a/test/core/gpr/env_test.cc b/test/core/gpr/env_test.cc new file mode 100644 index 00000000..89237cdf --- /dev/null +++ b/test/core/gpr/env_test.cc @@ -0,0 +1,64 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gpr/env.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST_NAME(x) gpr_log(GPR_INFO, "%s", x) + +static void test_setenv_getenv(void) { + const char* name = "FOO"; + const char* value = "BAR"; + char* retrieved_value; + + LOG_TEST_NAME("test_setenv_getenv"); + + gpr_setenv(name, value); + retrieved_value = gpr_getenv(name); + GPR_ASSERT(retrieved_value != nullptr); + GPR_ASSERT(strcmp(value, retrieved_value) == 0); + gpr_free(retrieved_value); +} + +static void test_unsetenv(void) { + const char* name = "FOO"; + const char* value = "BAR"; + char* retrieved_value; + + LOG_TEST_NAME("test_unsetenv"); + + gpr_setenv(name, value); + gpr_unsetenv(name); + retrieved_value = gpr_getenv(name); + GPR_ASSERT(retrieved_value == nullptr); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_setenv_getenv(); + test_unsetenv(); + return 0; +} diff --git a/test/core/gpr/log_test.cc b/test/core/gpr/log_test.cc new file mode 100644 index 00000000..2acefdaf --- /dev/null +++ b/test/core/gpr/log_test.cc @@ -0,0 +1,96 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include + +#include "src/core/lib/gprpp/global_config.h" +#include "test/core/util/test_config.h" + +static bool log_func_reached = false; + +static void test_callback(gpr_log_func_args* args) { + GPR_ASSERT(0 == strcmp(__FILE__, args->file)); + GPR_ASSERT(args->severity == GPR_LOG_SEVERITY_INFO); + GPR_ASSERT(0 == strcmp(args->message, "hello 1 2 3")); +} + +static void test_should_log(gpr_log_func_args* /*args*/) { + log_func_reached = true; +} + +static void test_should_not_log(gpr_log_func_args* /*args*/) { + GPR_ASSERT(false); +} + +#define test_log_function_reached(SEVERITY) \ + gpr_set_log_function(test_should_log); \ + log_func_reached = false; \ + gpr_log_message(SEVERITY, "hello 1 2 3"); \ + GPR_ASSERT(log_func_reached); \ + log_func_reached = false; \ + gpr_log(SEVERITY, "hello %d %d %d", 1, 2, 3); \ + GPR_ASSERT(log_func_reached); \ + gpr_set_log_function(nullptr); + +#define test_log_function_unreached(SEVERITY) \ + gpr_set_log_function(test_should_not_log); \ + gpr_log_message(SEVERITY, "hello 1 2 3"); \ + gpr_log(SEVERITY, "hello %d %d %d", 1, 2, 3); \ + gpr_set_log_function(nullptr); + +TEST(LogTest, Basic) { + /* test logging at various verbosity levels */ + gpr_log(GPR_DEBUG, "%s", "hello world"); + gpr_log(GPR_INFO, "%s", "hello world"); + gpr_log(GPR_ERROR, "%s", "hello world"); + /* should succeed */ + GPR_ASSERT(1); + gpr_set_log_function(test_callback); + gpr_log_message(GPR_INFO, "hello 1 2 3"); + gpr_log(GPR_INFO, "hello %d %d %d", 1, 2, 3); + gpr_set_log_function(nullptr); +} + +TEST(LogTest, LogVerbosity) { + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); + test_log_function_reached(GPR_ERROR); + test_log_function_reached(GPR_INFO); + test_log_function_reached(GPR_DEBUG); + + gpr_set_log_verbosity(GPR_LOG_SEVERITY_INFO); + test_log_function_reached(GPR_ERROR); + test_log_function_reached(GPR_INFO); + test_log_function_unreached(GPR_DEBUG); + + gpr_set_log_verbosity(GPR_LOG_SEVERITY_ERROR); + test_log_function_reached(GPR_ERROR); + test_log_function_unreached(GPR_INFO); + test_log_function_unreached(GPR_DEBUG); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/gpr/murmur_hash_test.cc b/test/core/gpr/murmur_hash_test.cc new file mode 100644 index 00000000..1756496e --- /dev/null +++ b/test/core/gpr/murmur_hash_test.cc @@ -0,0 +1,75 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gpr/murmur_hash.h" + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +typedef uint32_t (*hash_func)(const void* key, size_t len, uint32_t seed); + +/* From smhasher: + This should hopefully be a thorough and uambiguous test of whether a hash + is correctly implemented on a given platform */ + +static void verification_test(hash_func hash, uint32_t expected) { + uint8_t key[256]; + uint32_t hashes[256]; + uint32_t final = 0; + size_t i; + + memset(key, 0, sizeof(key)); + memset(hashes, 0, sizeof(hashes)); + + /* Hash keys of the form {0}, {0,1}, {0,1,2}... up to N=255,using 256-N as + the seed */ + + for (i = 0; i < 256; i++) { + key[i] = static_cast(i); + hashes[i] = hash(key, i, static_cast(256u - i)); + } + + /* Then hash the result array */ + + final = hash(hashes, sizeof(hashes), 0); + + /* The first four bytes of that hash, interpreted as a little-endian integer, + is our + verification value */ + + if (expected != final) { + gpr_log(GPR_INFO, "Verification value 0x%08X : Failed! (Expected 0x%08x)", + final, expected); + abort(); + } else { + gpr_log(GPR_INFO, "Verification value 0x%08X : Passed!", final); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + /* basic tests to verify that things don't crash */ + gpr_murmur_hash3("", 0, 0); + gpr_murmur_hash3("xyz", 3, 0); + verification_test(gpr_murmur_hash3, 0xB0F57EE3); + return 0; +} diff --git a/test/core/gpr/spinlock_test.cc b/test/core/gpr/spinlock_test.cc new file mode 100644 index 00000000..8bd9396d --- /dev/null +++ b/test/core/gpr/spinlock_test.cc @@ -0,0 +1,156 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test of gpr spin-lock support. */ + +#include "src/core/lib/gpr/spinlock.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +/* ------------------------------------------------- */ +/* Tests for gpr_spinlock. */ +struct test { + int thread_count; /* number of threads */ + grpc_core::Thread* threads; + + int64_t iterations; /* number of iterations per thread */ + int64_t counter; + int incr_step; /* how much to increment/decrement refcount each time */ + + gpr_spinlock mu; /* protects iterations, counter */ +}; + +/* Return pointer to a new struct test. 
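   (see test_new just after the note below). */

/* The spinlock API this test exercises is tiny; in sketch form (illustrative
   only, not part of the test):

     gpr_spinlock lock = GPR_SPINLOCK_INITIALIZER;
     gpr_spinlock_lock(&lock);           // spins until the lock is acquired
     gpr_spinlock_unlock(&lock);
     if (gpr_spinlock_trylock(&lock)) {  // non-blocking attempt
       gpr_spinlock_unlock(&lock);
     }

   inc() and inctry() below hammer exactly these calls from several threads at
   once. */

/* test_new: allocate and initialize the state shared by those threads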
*/ +static struct test* test_new(int threads, int64_t iterations, int incr_step) { + struct test* m = static_cast(gpr_malloc(sizeof(*m))); + m->thread_count = threads; + m->threads = static_cast( + gpr_malloc(sizeof(*m->threads) * static_cast(threads))); + m->iterations = iterations; + m->counter = 0; + m->thread_count = 0; + m->incr_step = incr_step; + m->mu = GPR_SPINLOCK_INITIALIZER; + return m; +} + +/* Return pointer to a new struct test. */ +static void test_destroy(struct test* m) { + gpr_free(m->threads); + gpr_free(m); +} + +/* Create m->threads threads, each running (*body)(m) */ +static void test_create_threads(struct test* m, void (*body)(void* arg)) { + int i; + for (i = 0; i != m->thread_count; i++) { + m->threads[i] = grpc_core::Thread("grpc_create_threads", body, m); + m->threads[i].Start(); + } +} + +/* Wait until all threads report done. */ +static void test_wait(struct test* m) { + int i; + for (i = 0; i != m->thread_count; i++) { + m->threads[i].Join(); + } +} + +/* Test several threads running (*body)(struct test *m) for increasing settings + of m->iterations, until about timeout_s to 2*timeout_s seconds have elapsed. + If extra!=NULL, run (*extra)(m) in an additional thread. + incr_step controls by how much m->refcount should be incremented/decremented + (if at all) each time in the tests. + */ +static void test(const char* name, void (*body)(void* m), int timeout_s, + int incr_step) { + int64_t iterations = 1024; + struct test* m; + gpr_timespec start = gpr_now(GPR_CLOCK_REALTIME); + gpr_timespec time_taken; + gpr_timespec deadline = gpr_time_add( + start, gpr_time_from_micros(static_cast(timeout_s) * 1000000, + GPR_TIMESPAN)); + fprintf(stderr, "%s:", name); + fflush(stderr); + while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0) { + if (iterations < INT64_MAX / 2) iterations <<= 1; + fprintf(stderr, " %ld", static_cast(iterations)); + fflush(stderr); + m = test_new(10, iterations, incr_step); + test_create_threads(m, body); + test_wait(m); + if (m->counter != m->thread_count * m->iterations * m->incr_step) { + fprintf(stderr, "counter %ld threads %d iterations %ld\n", + static_cast(m->counter), m->thread_count, + static_cast(m->iterations)); + fflush(stderr); + GPR_ASSERT(0); + } + test_destroy(m); + } + time_taken = gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), start); + fprintf(stderr, " done %lld.%09d s\n", + static_cast(time_taken.tv_sec), + static_cast(time_taken.tv_nsec)); + fflush(stderr); +} + +/* Increment m->counter on each iteration; then mark thread as done. */ +static void inc(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations; i++) { + gpr_spinlock_lock(&m->mu); + m->counter++; + gpr_spinlock_unlock(&m->mu); + } +} + +/* Increment m->counter under lock acquired with trylock, m->iterations times; + then mark thread as done. 
*/ +static void inctry(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations;) { + if (gpr_spinlock_trylock(&m->mu)) { + m->counter++; + gpr_spinlock_unlock(&m->mu); + i++; + } + } +} + +/* ------------------------------------------------- */ + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + test("spinlock", &inc, 1, 1); + test("spinlock try", &inctry, 1, 1); + return 0; +} diff --git a/test/core/gpr/string_test.cc b/test/core/gpr/string_test.cc new file mode 100644 index 00000000..5e3ed9d5 --- /dev/null +++ b/test/core/gpr/string_test.cc @@ -0,0 +1,313 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gpr/string.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "test/core/util/test_config.h" + +#define LOG_TEST_NAME(x) gpr_log(GPR_INFO, "%s", x) + +static void test_strdup(void) { + static const char* src1 = "hello world"; + char* dst1; + + LOG_TEST_NAME("test_strdup"); + + dst1 = gpr_strdup(src1); + GPR_ASSERT(0 == strcmp(src1, dst1)); + gpr_free(dst1); + + GPR_ASSERT(nullptr == gpr_strdup(nullptr)); +} + +static void expect_dump(const char* buf, size_t len, uint32_t flags, + const char* result) { + char* got = gpr_dump(buf, len, flags); + GPR_ASSERT(0 == strcmp(got, result)); + gpr_free(got); +} + +static void test_dump(void) { + LOG_TEST_NAME("test_dump"); + expect_dump("\x01", 1, GPR_DUMP_HEX, "01"); + expect_dump("\x01", 1, GPR_DUMP_HEX | GPR_DUMP_ASCII, "01 '.'"); + expect_dump("\x01\x02", 2, GPR_DUMP_HEX, "01 02"); + expect_dump("\x01\x23\x45\x67\x89\xab\xcd\xef", 8, GPR_DUMP_HEX, + "01 23 45 67 89 ab cd ef"); + expect_dump("ab", 2, GPR_DUMP_HEX | GPR_DUMP_ASCII, "61 62 'ab'"); +} + +static void test_pu32_fail(const char* s) { + uint32_t out; + GPR_ASSERT(!gpr_parse_bytes_to_uint32(s, strlen(s), &out)); +} + +static void test_pu32_succeed(const char* s, uint32_t want) { + uint32_t out; + GPR_ASSERT(gpr_parse_bytes_to_uint32(s, strlen(s), &out)); + GPR_ASSERT(out == want); +} + +static void test_parse_uint32(void) { + LOG_TEST_NAME("test_parse_uint32"); + + test_pu32_fail("-1"); + test_pu32_fail("a"); + test_pu32_fail(""); + test_pu32_succeed("0", 0); + test_pu32_succeed("1", 1); + test_pu32_succeed("2", 2); + test_pu32_succeed("3", 3); + test_pu32_succeed("4", 4); + test_pu32_succeed("5", 5); + test_pu32_succeed("6", 6); + test_pu32_succeed("7", 7); + test_pu32_succeed("8", 8); + test_pu32_succeed("9", 9); + test_pu32_succeed("10", 10); + test_pu32_succeed("11", 11); + test_pu32_succeed("12", 12); + test_pu32_succeed("13", 13); + test_pu32_succeed("14", 14); + test_pu32_succeed("15", 15); + test_pu32_succeed("16", 16); + test_pu32_succeed("17", 17); + test_pu32_succeed("18", 18); + test_pu32_succeed("19", 19); + test_pu32_succeed("1234567890", 1234567890); + test_pu32_succeed("4294967295", 4294967295u); + test_pu32_fail("4294967296"); + test_pu32_fail("4294967297"); + 
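  /* 4294967295 is UINT32_MAX; every larger value must be rejected */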
test_pu32_fail("4294967298"); + test_pu32_fail("4294967299"); +} + +static void test_asprintf(void) { + char* buf; + int i, j; + + LOG_TEST_NAME("test_asprintf"); + + /* Print an empty string. */ + GPR_ASSERT(gpr_asprintf(&buf, "%s", "") == 0); + GPR_ASSERT(buf[0] == '\0'); + gpr_free(buf); + + /* Print strings of various lengths. */ + for (i = 1; i < 100; i++) { + GPR_ASSERT(gpr_asprintf(&buf, "%0*d", i, 1) == i); + + /* The buffer should resemble "000001\0". */ + for (j = 0; j < i - 2; j++) { + GPR_ASSERT(buf[j] == '0'); + } + GPR_ASSERT(buf[i - 1] == '1'); + GPR_ASSERT(buf[i] == '\0'); + gpr_free(buf); + } +} + +static void test_strjoin(void) { + const char* parts[4] = {"one", "two", "three", "four"}; + size_t joined_len; + char* joined; + + LOG_TEST_NAME("test_strjoin"); + + joined = gpr_strjoin(parts, 4, &joined_len); + GPR_ASSERT(0 == strcmp("onetwothreefour", joined)); + gpr_free(joined); + + joined = gpr_strjoin(parts, 0, &joined_len); + GPR_ASSERT(0 == strcmp("", joined)); + gpr_free(joined); + + joined = gpr_strjoin(parts, 1, &joined_len); + GPR_ASSERT(0 == strcmp("one", joined)); + gpr_free(joined); +} + +static void test_strjoin_sep(void) { + const char* parts[4] = {"one", "two", "three", "four"}; + size_t joined_len; + char* joined; + + LOG_TEST_NAME("test_strjoin_sep"); + + joined = gpr_strjoin_sep(parts, 4, ", ", &joined_len); + GPR_ASSERT(0 == strcmp("one, two, three, four", joined)); + gpr_free(joined); + + /* empty separator */ + joined = gpr_strjoin_sep(parts, 4, "", &joined_len); + GPR_ASSERT(0 == strcmp("onetwothreefour", joined)); + gpr_free(joined); + + /* degenerated case specifying zero input parts */ + joined = gpr_strjoin_sep(parts, 0, ", ", &joined_len); + GPR_ASSERT(0 == strcmp("", joined)); + gpr_free(joined); + + /* single part should have no separator */ + joined = gpr_strjoin_sep(parts, 1, ", ", &joined_len); + GPR_ASSERT(0 == strcmp("one", joined)); + gpr_free(joined); +} + +static void test_ltoa() { + char* str; + char buf[GPR_LTOA_MIN_BUFSIZE]; + + LOG_TEST_NAME("test_ltoa"); + + /* zero */ + GPR_ASSERT(1 == gpr_ltoa(0, buf)); + GPR_ASSERT(0 == strcmp("0", buf)); + + /* positive number */ + GPR_ASSERT(3 == gpr_ltoa(123, buf)); + GPR_ASSERT(0 == strcmp("123", buf)); + + /* negative number */ + GPR_ASSERT(6 == gpr_ltoa(-12345, buf)); + GPR_ASSERT(0 == strcmp("-12345", buf)); + + /* large negative - we don't know the size of long in advance */ + GPR_ASSERT(gpr_asprintf(&str, "%lld", (long long)LONG_MIN)); + GPR_ASSERT(strlen(str) == (size_t)gpr_ltoa(LONG_MIN, buf)); + GPR_ASSERT(0 == strcmp(str, buf)); + gpr_free(str); +} + +static void test_int64toa() { + char buf[GPR_INT64TOA_MIN_BUFSIZE]; + + LOG_TEST_NAME("test_int64toa"); + + /* zero */ + GPR_ASSERT(1 == int64_ttoa(0, buf)); + GPR_ASSERT(0 == strcmp("0", buf)); + + /* positive */ + GPR_ASSERT(3 == int64_ttoa(123, buf)); + GPR_ASSERT(0 == strcmp("123", buf)); + + /* large positive */ + GPR_ASSERT(19 == int64_ttoa(9223372036854775807LL, buf)); + GPR_ASSERT(0 == strcmp("9223372036854775807", buf)); + + /* large negative */ + GPR_ASSERT(20 == int64_ttoa(-9223372036854775807LL - 1, buf)); + GPR_ASSERT(0 == strcmp("-9223372036854775808", buf)); +} + +static void test_leftpad() { + char* padded; + + LOG_TEST_NAME("test_leftpad"); + + padded = gpr_leftpad("foo", ' ', 5); + GPR_ASSERT(0 == strcmp(" foo", padded)); + gpr_free(padded); + + padded = gpr_leftpad("foo", ' ', 4); + GPR_ASSERT(0 == strcmp(" foo", padded)); + gpr_free(padded); + + padded = gpr_leftpad("foo", ' ', 3); + GPR_ASSERT(0 == strcmp("foo", 
padded)); + gpr_free(padded); + + padded = gpr_leftpad("foo", ' ', 2); + GPR_ASSERT(0 == strcmp("foo", padded)); + gpr_free(padded); + + padded = gpr_leftpad("foo", ' ', 1); + GPR_ASSERT(0 == strcmp("foo", padded)); + gpr_free(padded); + + padded = gpr_leftpad("foo", ' ', 0); + GPR_ASSERT(0 == strcmp("foo", padded)); + gpr_free(padded); + + padded = gpr_leftpad("foo", '0', 5); + GPR_ASSERT(0 == strcmp("00foo", padded)); + gpr_free(padded); +} + +static void test_stricmp(void) { + LOG_TEST_NAME("test_stricmp"); + + GPR_ASSERT(0 == gpr_stricmp("hello", "hello")); + GPR_ASSERT(0 == gpr_stricmp("HELLO", "hello")); + GPR_ASSERT(gpr_stricmp("a", "b") < 0); + GPR_ASSERT(gpr_stricmp("b", "a") > 0); +} + +static void test_memrchr(void) { + LOG_TEST_NAME("test_memrchr"); + + GPR_ASSERT(nullptr == gpr_memrchr(nullptr, 'a', 0)); + GPR_ASSERT(nullptr == gpr_memrchr("", 'a', 0)); + GPR_ASSERT(nullptr == gpr_memrchr("hello", 'b', 5)); + GPR_ASSERT(0 == strcmp((const char*)gpr_memrchr("hello", 'h', 5), "hello")); + GPR_ASSERT(0 == strcmp((const char*)gpr_memrchr("hello", 'o', 5), "o")); + GPR_ASSERT(0 == strcmp((const char*)gpr_memrchr("hello", 'l', 5), "lo")); +} + +static void test_parse_bool_value(void) { + LOG_TEST_NAME("test_parse_bool_value"); + + bool ret; + GPR_ASSERT(true == gpr_parse_bool_value("truE", &ret) && true == ret); + GPR_ASSERT(true == gpr_parse_bool_value("falsE", &ret) && false == ret); + GPR_ASSERT(true == gpr_parse_bool_value("1", &ret) && true == ret); + GPR_ASSERT(true == gpr_parse_bool_value("0", &ret) && false == ret); + GPR_ASSERT(true == gpr_parse_bool_value("Yes", &ret) && true == ret); + GPR_ASSERT(true == gpr_parse_bool_value("No", &ret) && false == ret); + GPR_ASSERT(true == gpr_parse_bool_value("Y", &ret) && true == ret); + GPR_ASSERT(true == gpr_parse_bool_value("N", &ret) && false == ret); + GPR_ASSERT(false == gpr_parse_bool_value(nullptr, &ret)); + GPR_ASSERT(false == gpr_parse_bool_value("", &ret)); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_strdup(); + test_dump(); + test_parse_uint32(); + test_asprintf(); + test_strjoin(); + test_strjoin_sep(); + test_ltoa(); + test_int64toa(); + test_leftpad(); + test_stricmp(); + test_memrchr(); + test_parse_bool_value(); + return 0; +} diff --git a/test/core/gpr/sync_test.cc b/test/core/gpr/sync_test.cc new file mode 100644 index 00000000..b38180d0 --- /dev/null +++ b/test/core/gpr/sync_test.cc @@ -0,0 +1,473 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test of gpr synchronization support. */ + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +/* ==================Example use of interface=================== + + A producer-consumer queue of up to N integers, + illustrating the use of the calls in this interface. 
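*/

/* From a caller's point of view the queue below is used like this (sketch
   only; the blocking/timeout behaviour is what the later tests rely on):

     queue q;
     queue_init(&q);
     queue_append(&q, 42);   // blocks while the queue is full
     int value;
     if (queue_remove(&q, &value, gpr_inf_future(GPR_CLOCK_MONOTONIC))) {
       // value == 42
     }
     queue_destroy(&q);
*/

/* Queue capacity.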
*/ + +#define N 4 + +typedef struct queue { + gpr_cv non_empty; /* Signalled when length becomes non-zero. */ + gpr_cv non_full; /* Signalled when length becomes non-N. */ + gpr_mu mu; /* Protects all fields below. + (That is, except during initialization or + destruction, the fields below should be accessed + only by a thread that holds mu.) */ + int head; /* Index of head of queue 0..N-1. */ + int length; /* Number of valid elements in queue 0..N. */ + int elem[N]; /* elem[head .. head+length-1] are queue elements. */ +} queue; + +/* Initialize *q. */ +void queue_init(queue* q) { + gpr_mu_init(&q->mu); + gpr_cv_init(&q->non_empty); + gpr_cv_init(&q->non_full); + q->head = 0; + q->length = 0; +} + +/* Free storage associated with *q. */ +void queue_destroy(queue* q) { + gpr_mu_destroy(&q->mu); + gpr_cv_destroy(&q->non_empty); + gpr_cv_destroy(&q->non_full); +} + +/* Wait until there is room in *q, then append x to *q. */ +void queue_append(queue* q, int x) { + gpr_mu_lock(&q->mu); + /* To wait for a predicate without a deadline, loop on the negation of the + predicate, and use gpr_cv_wait(..., gpr_inf_future(GPR_CLOCK_REALTIME)) + inside the loop + to release the lock, wait, and reacquire on each iteration. Code that + makes the condition true should use gpr_cv_broadcast() on the + corresponding condition variable. The predicate must be on state + protected by the lock. */ + while (q->length == N) { + gpr_cv_wait(&q->non_full, &q->mu, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + if (q->length == 0) { /* Wake threads blocked in queue_remove(). */ + /* It's normal to use gpr_cv_broadcast() or gpr_signal() while + holding the lock. */ + gpr_cv_broadcast(&q->non_empty); + } + q->elem[(q->head + q->length) % N] = x; + q->length++; + gpr_mu_unlock(&q->mu); +} + +/* If it can be done without blocking, append x to *q and return non-zero. + Otherwise return 0. */ +int queue_try_append(queue* q, int x) { + int result = 0; + if (gpr_mu_trylock(&q->mu)) { + if (q->length != N) { + if (q->length == 0) { /* Wake threads blocked in queue_remove(). */ + gpr_cv_broadcast(&q->non_empty); + } + q->elem[(q->head + q->length) % N] = x; + q->length++; + result = 1; + } + gpr_mu_unlock(&q->mu); + } + return result; +} + +/* Wait until the *q is non-empty or deadline abs_deadline passes. If the + queue is non-empty, remove its head entry, place it in *head, and return + non-zero. Otherwise return 0. */ +int queue_remove(queue* q, int* head, gpr_timespec abs_deadline) { + int result = 0; + gpr_mu_lock(&q->mu); + /* To wait for a predicate with a deadline, loop on the negation of the + predicate or until gpr_cv_wait() returns true. Code that makes + the condition true should use gpr_cv_broadcast() on the corresponding + condition variable. The predicate must be on state protected by the + lock. */ + while (q->length == 0 && !gpr_cv_wait(&q->non_empty, &q->mu, abs_deadline)) { + } + if (q->length != 0) { /* Queue is non-empty. */ + result = 1; + if (q->length == N) { /* Wake threads blocked in queue_append(). */ + gpr_cv_broadcast(&q->non_full); + } + *head = q->elem[q->head]; + q->head = (q->head + 1) % N; + q->length--; + } /* else deadline exceeded */ + gpr_mu_unlock(&q->mu); + return result; +} + +/* ------------------------------------------------- */ +/* Tests for gpr_mu and gpr_cv, and the queue example. 
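*/

/* Besides gpr_mu/gpr_cv, the tests below lean on two more small primitives;
   in sketch form (illustrative only):

     gpr_event ev;                      // one-shot: intended to be set at most once
     gpr_event_init(&ev);
     gpr_event_set(&ev, (void*)1);
     void* v = gpr_event_wait(&ev, gpr_inf_future(GPR_CLOCK_REALTIME));

     gpr_stats_counter stats;           // simple accumulating counter
     gpr_stats_init(&stats, 0);
     gpr_stats_inc(&stats, 1);
     intptr_t n = gpr_stats_read(&stats);

   gpr_refcount and gpr_refn()/gpr_unref() are exercised by refinc() and
   refcheck() further down. */

/* Shared state for every test in this file.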
*/ +struct test { + int nthreads; /* number of threads */ + grpc_core::Thread* threads; + + int64_t iterations; /* number of iterations per thread */ + int64_t counter; + int thread_count; /* used to allocate thread ids */ + int done; /* threads not yet completed */ + int incr_step; /* how much to increment/decrement refcount each time */ + + gpr_mu mu; /* protects iterations, counter, thread_count, done */ + + gpr_cv cv; /* signalling depends on test */ + + gpr_cv done_cv; /* signalled when done == 0 */ + + queue q; + + gpr_stats_counter stats_counter; + + gpr_refcount refcount; + gpr_refcount thread_refcount; + gpr_event event; +}; + +/* Return pointer to a new struct test. */ +static struct test* test_new(int nthreads, int64_t iterations, int incr_step) { + struct test* m = static_cast(gpr_malloc(sizeof(*m))); + m->nthreads = nthreads; + m->threads = static_cast( + gpr_malloc(sizeof(*m->threads) * nthreads)); + m->iterations = iterations; + m->counter = 0; + m->thread_count = 0; + m->done = nthreads; + m->incr_step = incr_step; + gpr_mu_init(&m->mu); + gpr_cv_init(&m->cv); + gpr_cv_init(&m->done_cv); + queue_init(&m->q); + gpr_stats_init(&m->stats_counter, 0); + gpr_ref_init(&m->refcount, 0); + gpr_ref_init(&m->thread_refcount, nthreads); + gpr_event_init(&m->event); + return m; +} + +/* Return pointer to a new struct test. */ +static void test_destroy(struct test* m) { + gpr_mu_destroy(&m->mu); + gpr_cv_destroy(&m->cv); + gpr_cv_destroy(&m->done_cv); + queue_destroy(&m->q); + gpr_free(m->threads); + gpr_free(m); +} + +/* Create m->nthreads threads, each running (*body)(m) */ +static void test_create_threads(struct test* m, void (*body)(void* arg)) { + int i; + for (i = 0; i != m->nthreads; i++) { + m->threads[i] = grpc_core::Thread("grpc_create_threads", body, m); + m->threads[i].Start(); + } +} + +/* Wait until all threads report done. */ +static void test_wait(struct test* m) { + gpr_mu_lock(&m->mu); + while (m->done != 0) { + gpr_cv_wait(&m->done_cv, &m->mu, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&m->mu); + for (int i = 0; i != m->nthreads; i++) { + m->threads[i].Join(); + } +} + +/* Get an integer thread id in the raneg 0..nthreads-1 */ +static int thread_id(struct test* m) { + int id; + gpr_mu_lock(&m->mu); + id = m->thread_count++; + gpr_mu_unlock(&m->mu); + return id; +} + +/* Indicate that a thread is done, by decrementing m->done + and signalling done_cv if m->done==0. */ +static void mark_thread_done(struct test* m) { + gpr_mu_lock(&m->mu); + GPR_ASSERT(m->done != 0); + m->done--; + if (m->done == 0) { + gpr_cv_signal(&m->done_cv); + } + gpr_mu_unlock(&m->mu); +} + +/* Test several threads running (*body)(struct test *m) for increasing settings + of m->iterations, until about timeout_s to 2*timeout_s seconds have elapsed. + If extra!=NULL, run (*extra)(m) in an additional thread. + incr_step controls by how much m->refcount should be incremented/decremented + (if at all) each time in the tests. 
+ */ +static void test(const char* name, void (*body)(void* m), + void (*extra)(void* m), int timeout_s, int incr_step) { + int64_t iterations = 256; + struct test* m; + gpr_timespec start = gpr_now(GPR_CLOCK_REALTIME); + gpr_timespec time_taken; + gpr_timespec deadline = gpr_time_add( + start, gpr_time_from_micros(static_cast(timeout_s) * 1000000, + GPR_TIMESPAN)); + fprintf(stderr, "%s:", name); + fflush(stderr); + while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0) { + fprintf(stderr, " %ld", static_cast(iterations)); + fflush(stderr); + m = test_new(10, iterations, incr_step); + grpc_core::Thread extra_thd; + if (extra != nullptr) { + extra_thd = grpc_core::Thread(name, extra, m); + extra_thd.Start(); + m->done++; /* one more thread to wait for */ + } + test_create_threads(m, body); + test_wait(m); + if (extra != nullptr) { + extra_thd.Join(); + } + if (m->counter != m->nthreads * m->iterations * m->incr_step) { + fprintf(stderr, "counter %ld threads %d iterations %ld\n", + static_cast(m->counter), m->nthreads, + static_cast(m->iterations)); + fflush(stderr); + GPR_ASSERT(0); + } + test_destroy(m); + iterations <<= 1; + } + time_taken = gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), start); + fprintf(stderr, " done %lld.%09d s\n", + static_cast(time_taken.tv_sec), + static_cast(time_taken.tv_nsec)); + fflush(stderr); +} + +/* Increment m->counter on each iteration; then mark thread as done. */ +static void inc(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations; i++) { + gpr_mu_lock(&m->mu); + m->counter++; + gpr_mu_unlock(&m->mu); + } + mark_thread_done(m); +} + +/* Increment m->counter under lock acquired with trylock, m->iterations times; + then mark thread as done. */ +static void inctry(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations;) { + if (gpr_mu_trylock(&m->mu)) { + m->counter++; + gpr_mu_unlock(&m->mu); + i++; + } + } + mark_thread_done(m); +} + +/* Increment counter only when (m->counter%m->nthreads)==m->thread_id; then mark + thread as done. */ +static void inc_by_turns(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + int id = thread_id(m); + for (i = 0; i != m->iterations; i++) { + gpr_mu_lock(&m->mu); + while ((m->counter % m->nthreads) != id) { + gpr_cv_wait(&m->cv, &m->mu, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + m->counter++; + gpr_cv_broadcast(&m->cv); + gpr_mu_unlock(&m->mu); + } + mark_thread_done(m); +} + +/* Wait a millisecond and increment counter on each iteration; + then mark thread as done. */ +static void inc_with_1ms_delay(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations; i++) { + gpr_timespec deadline; + gpr_mu_lock(&m->mu); + deadline = gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(1000, GPR_TIMESPAN)); + while (!gpr_cv_wait(&m->cv, &m->mu, deadline)) { + } + m->counter++; + gpr_mu_unlock(&m->mu); + } + mark_thread_done(m); +} + +/* Wait a millisecond and increment counter on each iteration, using an event + for timing; then mark thread as done. 
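   Unlike inc_with_1ms_delay above, the pause here comes from gpr_event_wait()
   timing out: in the "timedevent" test nothing ever sets m->event, so each
   wait is expected to return nullptr at the ~1ms deadline.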
*/ +static void inc_with_1ms_delay_event(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations; i++) { + gpr_timespec deadline; + deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1000, GPR_TIMESPAN)); + GPR_ASSERT(gpr_event_wait(&m->event, deadline) == nullptr); + gpr_mu_lock(&m->mu); + m->counter++; + gpr_mu_unlock(&m->mu); + } + mark_thread_done(m); +} + +/* Produce m->iterations elements on queue m->q, then mark thread as done. + Even threads use queue_append(), and odd threads use queue_try_append() + until it succeeds. */ +static void many_producers(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + int x = thread_id(m); + if ((x & 1) == 0) { + for (i = 0; i != m->iterations; i++) { + queue_append(&m->q, 1); + } + } else { + for (i = 0; i != m->iterations; i++) { + while (!queue_try_append(&m->q, 1)) { + } + } + } + mark_thread_done(m); +} + +/* Consume elements from m->q until m->nthreads*m->iterations are seen, + wait an extra second to confirm that no more elements are arriving, + then mark thread as done. */ +static void consumer(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t n = m->iterations * m->nthreads; + int64_t i; + int value; + for (i = 0; i != n; i++) { + queue_remove(&m->q, &value, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_lock(&m->mu); + m->counter = n; + gpr_mu_unlock(&m->mu); + GPR_ASSERT( + !queue_remove(&m->q, &value, + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(1000000, GPR_TIMESPAN)))); + mark_thread_done(m); +} + +/* Increment m->stats_counter m->iterations times, transfer counter value to + m->counter, then mark thread as done. */ +static void statsinc(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations; i++) { + gpr_stats_inc(&m->stats_counter, 1); + } + gpr_mu_lock(&m->mu); + m->counter = gpr_stats_read(&m->stats_counter); + gpr_mu_unlock(&m->mu); + mark_thread_done(m); +} + +/* Increment m->refcount by m->incr_step for m->iterations times. Decrement + m->thread_refcount once, and if it reaches zero, set m->event to (void*)1; + then mark thread as done. */ +static void refinc(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t i; + for (i = 0; i != m->iterations; i++) { + if (m->incr_step == 1) { + gpr_ref(&m->refcount); + } else { + gpr_refn(&m->refcount, m->incr_step); + } + } + if (gpr_unref(&m->thread_refcount)) { + gpr_event_set(&m->event, reinterpret_cast(1)); + } + mark_thread_done(m); +} + +/* Wait until m->event is set to (void *)1, then decrement m->refcount by 1 + (m->nthreads * m->iterations * m->incr_step) times, and ensure that the last + decrement caused the counter to reach zero, then mark thread as done. 
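   (gpr_unref() returns non-zero exactly when its decrement took the count to
   zero, e.g.:

     gpr_refcount r;
     gpr_ref_init(&r, 2);
     GPR_ASSERT(!gpr_unref(&r));   // 2 -> 1
     GPR_ASSERT(gpr_unref(&r));    // 1 -> 0: the zero crossing is reported

   which is the property this checker relies on.)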
*/ +static void refcheck(void* v /*=m*/) { + struct test* m = static_cast(v); + int64_t n = m->iterations * m->nthreads * m->incr_step; + int64_t i; + GPR_ASSERT(gpr_event_wait(&m->event, gpr_inf_future(GPR_CLOCK_REALTIME)) == + (void*)1); + GPR_ASSERT(gpr_event_get(&m->event) == (void*)1); + for (i = 1; i != n; i++) { + GPR_ASSERT(!gpr_unref(&m->refcount)); + m->counter++; + } + GPR_ASSERT(gpr_unref(&m->refcount)); + m->counter++; + mark_thread_done(m); +} + +/* ------------------------------------------------- */ + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + test("mutex", &inc, nullptr, 1, 1); + test("mutex try", &inctry, nullptr, 1, 1); + test("cv", &inc_by_turns, nullptr, 1, 1); + test("timedcv", &inc_with_1ms_delay, nullptr, 1, 1); + test("queue", &many_producers, &consumer, 10, 1); + test("stats_counter", &statsinc, nullptr, 1, 1); + test("refcount by 1", &refinc, &refcheck, 1, 1); + test("refcount by 3", &refinc, &refcheck, 1, 3); /* incr_step of 3 is an + arbitrary choice. Any + number > 1 is okay here */ + test("timedevent", &inc_with_1ms_delay_event, nullptr, 1, 1); + return 0; +} diff --git a/test/core/gpr/time_test.cc b/test/core/gpr/time_test.cc new file mode 100644 index 00000000..4631ae8f --- /dev/null +++ b/test/core/gpr/time_test.cc @@ -0,0 +1,268 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test of gpr time support. */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "test/core/util/test_config.h" + +static void to_fp(void* arg, const char* buf, size_t len) { + fwrite(buf, 1, len, static_cast(arg)); +} + +/* Convert gpr_intmax x to ascii base b (2..16), and write with + (*writer)(arg, ...), zero padding to "chars" digits). */ +static void i_to_s(intmax_t x, int base, int chars, + void (*writer)(void* arg, const char* buf, size_t len), + void* arg) { + char buf[64]; + char fmt[32]; + GPR_ASSERT(base == 16 || base == 10); + sprintf(fmt, "%%0%d%s", chars, base == 16 ? PRIxMAX : PRIdMAX); + sprintf(buf, fmt, x); + (*writer)(arg, buf, strlen(buf)); +} + +/* Convert ts to ascii, and write with (*writer)(arg, ...). 
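   (see ts_to_s just after the note below). */

/* The tests in this file revolve around a handful of gpr_timespec helpers; a
   minimal sketch of the arithmetic they check (illustrative only):

     gpr_timespec a = gpr_time_from_micros(1500, GPR_TIMESPAN);  // 1.5 ms
     gpr_timespec b = gpr_time_from_millis(1, GPR_TIMESPAN);     // 1.0 ms
     gpr_timespec sum = gpr_time_add(a, b);                      // 2.5 ms
     GPR_ASSERT(
         gpr_time_cmp(sum, gpr_time_from_micros(2500, GPR_TIMESPAN)) == 0);

   GPR_TIMESPAN marks a duration rather than a point on any clock. */

/* ts_to_s: render t as seconds.nanoseconds through (*writer)(arg, ...)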
*/ +static void ts_to_s(gpr_timespec t, + void (*writer)(void* arg, const char* buf, size_t len), + void* arg) { + if (t.tv_sec < 0 && t.tv_nsec != 0) { + t.tv_sec++; + t.tv_nsec = GPR_NS_PER_SEC - t.tv_nsec; + } + i_to_s(t.tv_sec, 10, 0, writer, arg); + (*writer)(arg, ".", 1); + i_to_s(t.tv_nsec, 10, 9, writer, arg); +} + +static void test_values(void) { + int i; + + gpr_timespec x = gpr_time_0(GPR_CLOCK_REALTIME); + GPR_ASSERT(x.tv_sec == 0 && x.tv_nsec == 0); + + x = gpr_inf_future(GPR_CLOCK_REALTIME); + fprintf(stderr, "far future "); + fflush(stderr); + i_to_s(x.tv_sec, 16, 16, &to_fp, stderr); + fprintf(stderr, "\n"); + GPR_ASSERT(x.tv_sec == INT64_MAX); + fprintf(stderr, "far future "); + fflush(stderr); + ts_to_s(x, &to_fp, stderr); + fprintf(stderr, "\n"); + fflush(stderr); + + x = gpr_inf_past(GPR_CLOCK_REALTIME); + fprintf(stderr, "far past "); + fflush(stderr); + i_to_s(x.tv_sec, 16, 16, &to_fp, stderr); + fprintf(stderr, "\n"); + fflush(stderr); + GPR_ASSERT(x.tv_sec == INT64_MIN); + fprintf(stderr, "far past "); + fflush(stderr); + ts_to_s(x, &to_fp, stderr); + fprintf(stderr, "\n"); + fflush(stderr); + + for (i = 1; i != 1000 * 1000 * 1000; i *= 10) { + x = gpr_time_from_micros(i, GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec == i / GPR_US_PER_SEC && + x.tv_nsec == (i % GPR_US_PER_SEC) * GPR_NS_PER_US); + x = gpr_time_from_nanos(i, GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec == i / GPR_NS_PER_SEC && + x.tv_nsec == (i % GPR_NS_PER_SEC)); + x = gpr_time_from_millis(i, GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec == i / GPR_MS_PER_SEC && + x.tv_nsec == (i % GPR_MS_PER_SEC) * GPR_NS_PER_MS); + } + + /* Test possible overflow in conversion of -ve values. */ + x = gpr_time_from_micros(-(INT64_MAX - 999997), GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec < 0); + GPR_ASSERT(x.tv_nsec >= 0 && x.tv_nsec < GPR_NS_PER_SEC); + + x = gpr_time_from_nanos(-(INT64_MAX - 999999997), GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec < 0); + GPR_ASSERT(x.tv_nsec >= 0 && x.tv_nsec < GPR_NS_PER_SEC); + + x = gpr_time_from_millis(-(INT64_MAX - 997), GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec < 0); + GPR_ASSERT(x.tv_nsec >= 0 && x.tv_nsec < GPR_NS_PER_SEC); + + /* Test general -ve values. */ + for (i = -1; i > -1000 * 1000 * 1000; i *= 7) { + x = gpr_time_from_micros(i, GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec * GPR_US_PER_SEC + x.tv_nsec / GPR_NS_PER_US == i); + x = gpr_time_from_nanos(i, GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec * GPR_NS_PER_SEC + x.tv_nsec == i); + x = gpr_time_from_millis(i, GPR_TIMESPAN); + GPR_ASSERT(x.tv_sec * GPR_MS_PER_SEC + x.tv_nsec / GPR_NS_PER_MS == i); + } +} + +static void test_add_sub(void) { + int i; + int j; + int k; + /* Basic addition and subtraction. 
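     For every pair (i, j) in [-100, 100], scaled by k up to 10^7 microseconds,
     gpr_time_add() and gpr_time_sub() must agree exactly with the integer sum
     and difference converted through gpr_time_from_micros().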
*/ + for (i = -100; i <= 100; i++) { + for (j = -100; j <= 100; j++) { + for (k = 1; k <= 10000000; k *= 10) { + int sum = i + j; + int diff = i - j; + gpr_timespec it = gpr_time_from_micros(i * k, GPR_TIMESPAN); + gpr_timespec jt = gpr_time_from_micros(j * k, GPR_TIMESPAN); + gpr_timespec sumt = gpr_time_add(it, jt); + gpr_timespec difft = gpr_time_sub(it, jt); + if (gpr_time_cmp(gpr_time_from_micros(sum * k, GPR_TIMESPAN), sumt) != + 0) { + fprintf(stderr, "i %d j %d sum %d sumt ", i, j, sum); + fflush(stderr); + ts_to_s(sumt, &to_fp, stderr); + fprintf(stderr, "\n"); + fflush(stderr); + GPR_ASSERT(0); + } + if (gpr_time_cmp(gpr_time_from_micros(diff * k, GPR_TIMESPAN), difft) != + 0) { + fprintf(stderr, "i %d j %d diff %d diff ", i, j, diff); + fflush(stderr); + ts_to_s(sumt, &to_fp, stderr); + fprintf(stderr, "\n"); + fflush(stderr); + GPR_ASSERT(0); + } + } + } + } +} + +static void test_overflow(void) { + /* overflow */ + gpr_timespec x = gpr_time_from_micros(1, GPR_TIMESPAN); + do { + x = gpr_time_add(x, x); + } while (gpr_time_cmp(x, gpr_inf_future(GPR_TIMESPAN)) < 0); + GPR_ASSERT(gpr_time_cmp(x, gpr_inf_future(GPR_TIMESPAN)) == 0); + x = gpr_time_from_micros(-1, GPR_TIMESPAN); + do { + x = gpr_time_add(x, x); + } while (gpr_time_cmp(x, gpr_inf_past(GPR_TIMESPAN)) > 0); + GPR_ASSERT(gpr_time_cmp(x, gpr_inf_past(GPR_TIMESPAN)) == 0); +} + +static void test_sticky_infinities(void) { + int i; + int j; + int k; + gpr_timespec infinity[2]; + gpr_timespec addend[3]; + infinity[0] = gpr_inf_future(GPR_TIMESPAN); + infinity[1] = gpr_inf_past(GPR_TIMESPAN); + addend[0] = gpr_inf_future(GPR_TIMESPAN); + addend[1] = gpr_inf_past(GPR_TIMESPAN); + addend[2] = gpr_time_0(GPR_TIMESPAN); + + /* Infinities are sticky */ + for (i = 0; i != sizeof(infinity) / sizeof(infinity[0]); i++) { + for (j = 0; j != sizeof(addend) / sizeof(addend[0]); j++) { + gpr_timespec x = gpr_time_add(infinity[i], addend[j]); + GPR_ASSERT(gpr_time_cmp(x, infinity[i]) == 0); + x = gpr_time_sub(infinity[i], addend[j]); + GPR_ASSERT(gpr_time_cmp(x, infinity[i]) == 0); + } + for (k = -200; k <= 200; k++) { + gpr_timespec y = gpr_time_from_micros(k * 100000, GPR_TIMESPAN); + gpr_timespec x = gpr_time_add(infinity[i], y); + GPR_ASSERT(gpr_time_cmp(x, infinity[i]) == 0); + x = gpr_time_sub(infinity[i], y); + GPR_ASSERT(gpr_time_cmp(x, infinity[i]) == 0); + } + } +} + +static void test_similar(void) { + GPR_ASSERT(1 == gpr_time_similar(gpr_inf_future(GPR_TIMESPAN), + gpr_inf_future(GPR_TIMESPAN), + gpr_time_0(GPR_TIMESPAN))); + GPR_ASSERT(1 == gpr_time_similar(gpr_inf_past(GPR_TIMESPAN), + gpr_inf_past(GPR_TIMESPAN), + gpr_time_0(GPR_TIMESPAN))); + GPR_ASSERT(0 == gpr_time_similar(gpr_inf_past(GPR_TIMESPAN), + gpr_inf_future(GPR_TIMESPAN), + gpr_time_0(GPR_TIMESPAN))); + GPR_ASSERT(0 == gpr_time_similar(gpr_inf_future(GPR_TIMESPAN), + gpr_inf_past(GPR_TIMESPAN), + gpr_time_0(GPR_TIMESPAN))); + GPR_ASSERT(1 == gpr_time_similar(gpr_time_from_micros(10, GPR_TIMESPAN), + gpr_time_from_micros(10, GPR_TIMESPAN), + gpr_time_0(GPR_TIMESPAN))); + GPR_ASSERT(1 == gpr_time_similar(gpr_time_from_micros(10, GPR_TIMESPAN), + gpr_time_from_micros(15, GPR_TIMESPAN), + gpr_time_from_micros(10, GPR_TIMESPAN))); + GPR_ASSERT(1 == gpr_time_similar(gpr_time_from_micros(15, GPR_TIMESPAN), + gpr_time_from_micros(10, GPR_TIMESPAN), + gpr_time_from_micros(10, GPR_TIMESPAN))); + GPR_ASSERT(0 == gpr_time_similar(gpr_time_from_micros(10, GPR_TIMESPAN), + gpr_time_from_micros(25, GPR_TIMESPAN), + gpr_time_from_micros(10, GPR_TIMESPAN))); + GPR_ASSERT(0 == 
gpr_time_similar(gpr_time_from_micros(25, GPR_TIMESPAN), + gpr_time_from_micros(10, GPR_TIMESPAN), + gpr_time_from_micros(10, GPR_TIMESPAN))); +} + +static void test_convert_extreme(void) { + gpr_timespec realtime = {INT64_MAX, 1, GPR_CLOCK_REALTIME}; + gpr_timespec monotime = gpr_convert_clock_type(realtime, GPR_CLOCK_MONOTONIC); + GPR_ASSERT(monotime.tv_sec == realtime.tv_sec); + GPR_ASSERT(monotime.clock_type == GPR_CLOCK_MONOTONIC); +} + +static void test_cmp_extreme(void) { + gpr_timespec t1 = {INT64_MAX, 1, GPR_CLOCK_REALTIME}; + gpr_timespec t2 = {INT64_MAX, 2, GPR_CLOCK_REALTIME}; + GPR_ASSERT(gpr_time_cmp(t1, t2) == 0); + t1.tv_sec = INT64_MIN; + t2.tv_sec = INT64_MIN; + GPR_ASSERT(gpr_time_cmp(t1, t2) == 0); +} + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + + test_values(); + test_add_sub(); + test_overflow(); + test_sticky_infinities(); + test_similar(); + test_convert_extreme(); + test_cmp_extreme(); + return 0; +} diff --git a/test/core/gpr/tls_test.cc b/test/core/gpr/tls_test.cc new file mode 100644 index 00000000..1642580f --- /dev/null +++ b/test/core/gpr/tls_test.cc @@ -0,0 +1,67 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test of gpr thread local storage support. */ + +#include "src/core/lib/gpr/tls.h" + +#include + +#include + +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +struct BiggerThanMachineWord { + size_t a, b; + uint8_t c; +}; + +static GPR_THREAD_LOCAL(BiggerThanMachineWord) test_var; +// Fails to compile: static GPR_THREAD_LOCAL(std::unique_ptr) non_trivial; + +namespace { +void thd_body(void*) { + for (size_t i = 0; i < 100000; i++) { + BiggerThanMachineWord next = {i, i, uint8_t(i)}; + test_var = next; + BiggerThanMachineWord read = test_var; + ASSERT_EQ(read.a, i); + ASSERT_EQ(read.b, i); + ASSERT_EQ(read.c, uint8_t(i)) << i; + } +} + +TEST(ThreadLocal, ReadWrite) { + std::array threads; + for (grpc_core::Thread& th : threads) { + th = grpc_core::Thread("grpc_tls_test", thd_body, nullptr); + th.Start(); + } + for (grpc_core::Thread& th : threads) { + th.Join(); + } +} + +} // namespace + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gpr/useful_test.cc b/test/core/gpr/useful_test.cc new file mode 100644 index 00000000..2d7d6705 --- /dev/null +++ b/test/core/gpr/useful_test.cc @@ -0,0 +1,70 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/gpr/useful.h" + +#include + +namespace grpc_core { + +TEST(UsefulTest, ClampWorks) { + EXPECT_EQ(grpc_core::Clamp(1, 0, 2), 1); + EXPECT_EQ(grpc_core::Clamp(0, 0, 2), 0); + EXPECT_EQ(grpc_core::Clamp(2, 0, 2), 2); + EXPECT_EQ(grpc_core::Clamp(-1, 0, 2), 0); + EXPECT_EQ(grpc_core::Clamp(3, 0, 2), 2); +} + +TEST(UsefulTest, Rotate) { + EXPECT_EQ(grpc_core::RotateLeft(0x80000001u, 1u), 3); + EXPECT_EQ(grpc_core::RotateRight(0x80000001u, 1u), 0xc0000000); +} + +TEST(UsefulTest, ArraySize) { + int four[4]; + int five[5]; + + EXPECT_EQ(GPR_ARRAY_SIZE(four), 4); + EXPECT_EQ(GPR_ARRAY_SIZE(five), 5); +} + +TEST(UsefulTest, BitOps) { + uint32_t bitset = 0; + + EXPECT_EQ(grpc_core::BitCount((1u << 31) - 1), 31); + EXPECT_EQ(grpc_core::BitCount(1u << 3), 1); + EXPECT_EQ(grpc_core::BitCount(0), 0); + EXPECT_EQ(grpc_core::SetBit(&bitset, 3), 8); + EXPECT_EQ(grpc_core::BitCount(bitset), 1); + EXPECT_EQ(grpc_core::GetBit(bitset, 3), 1); + EXPECT_EQ(grpc_core::SetBit(&bitset, 1), 10); + EXPECT_EQ(grpc_core::BitCount(bitset), 2); + EXPECT_EQ(grpc_core::ClearBit(&bitset, 3), 2); + EXPECT_EQ(grpc_core::BitCount(bitset), 1); + EXPECT_EQ(grpc_core::GetBit(bitset, 3), 0); + EXPECT_EQ(grpc_core::BitCount(std::numeric_limits::max()), 64); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/bitset_test.cc b/test/core/gprpp/bitset_test.cc new file mode 100644 index 00000000..dd6204ec --- /dev/null +++ b/test/core/gprpp/bitset_test.cc @@ -0,0 +1,106 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
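As a quick reference for the useful.h helpers exercised in useful_test.cc above: Clamp() pins a value into a closed range, RotateLeft()/RotateRight() rotate a word, and SetBit()/ClearBit()/GetBit()/BitCount() manipulate a bitmask through a pointer and return the updated value. A minimal sketch based only on the assertions above; the function name is illustrative and not part of the change:

#include <cstdint>

#include "src/core/lib/gpr/useful.h"

void UsefulSketch() {
  auto clamped = grpc_core::Clamp(3, 0, 2);                    // 2: pinned to the upper bound
  auto rotated = grpc_core::RotateLeft(0x80000001u, 1u);       // 3

  uint32_t mask = 0;
  grpc_core::SetBit(&mask, 3);                                 // mask == 8, and 8 is returned
  grpc_core::SetBit(&mask, 1);                                 // mask == 10
  grpc_core::ClearBit(&mask, 3);                               // mask == 2
  auto ones = grpc_core::BitCount(mask);                       // 1 bit still set
  auto bit3 = grpc_core::GetBit(mask, 3);                      // 0
  (void)clamped; (void)rotated; (void)ones; (void)bit3;
}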
+ +#include "src/core/lib/gprpp/bitset.h" + +#include + +#include + +namespace grpc_core { +namespace testing { + +// Stand in type to make the size to test a type +template +struct Size { + static constexpr size_t kBits = K; +}; + +using TestSizes = ::testing::Types< + // All sizes up to 17 bits + Size<1>, Size<2>, Size<3>, Size<4>, Size<5>, Size<6>, Size<7>, Size<8>, + Size<9>, Size<10>, Size<11>, Size<12>, Size<13>, Size<14>, Size<15>, + Size<16>, Size<17>, + // Values around 32 bits + Size<24>, Size<25>, Size<26>, Size<27>, Size<28>, Size<29>, Size<30>, + Size<31>, Size<32>, Size<33>, + // Values around 48 bits + Size<47>, Size<48>, Size<49>, + // Values around 64 bits + Size<62>, Size<63>, Size<64>, Size<65>, Size<66>, + // Values around 96 bits + Size<95>, Size<96>, Size<97>, + // Silly numbers of bits + Size<1024>, Size<4000>, Size<4321> >; + +template +struct BitSetTest : public ::testing::Test {}; + +TYPED_TEST_SUITE(BitSetTest, TestSizes); + +TYPED_TEST(BitSetTest, NoneAtInit) { + BitSet b; + EXPECT_TRUE(b.none()); +} + +TYPED_TEST(BitSetTest, OneBit) { + constexpr size_t kBits = TypeParam::kBits; + for (size_t i = 0; i < kBits; i++) { + BitSet b; + b.set(i); + EXPECT_FALSE(b.none()); + for (size_t j = 0; j < kBits; j++) { + EXPECT_EQ(b.is_set(j), i == j); + } + } +} + +TYPED_TEST(BitSetTest, AllSet) { + constexpr size_t kBits = TypeParam::kBits; + BitSet b; + for (size_t i = 0; i < kBits; i++) { + EXPECT_FALSE(b.all()); + b.set(i); + } + EXPECT_TRUE(b.all()); +} + +TYPED_TEST(BitSetTest, Count) { + constexpr size_t kBits = TypeParam::kBits; + BitSet b; + std::set bits_set; + std::random_device rd; + std::uniform_int_distribution dist(0, kBits - 1); + for (size_t i = 0; i < 4 * kBits; i++) { + size_t bit = dist(rd); + bits_set.insert(bit); + b.set(bit); + EXPECT_EQ(b.count(), bits_set.size()); + } +} + +TEST(EmptyBitSet, Empty) { + BitSet<0> b; + EXPECT_TRUE(b.all()); + EXPECT_TRUE(b.none()); + EXPECT_EQ(b.count(), 0); +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/capture_test.cc b/test/core/gprpp/capture_test.cc new file mode 100644 index 00000000..0cb68ef2 --- /dev/null +++ b/test/core/gprpp/capture_test.cc @@ -0,0 +1,39 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
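For readers coming to grpc_core::BitSet cold: it is a fixed-width bit set whose width is a compile-time template parameter, and the typed tests in bitset_test.cc above cover its whole surface (set, is_set, all, none, count). A minimal usage sketch mirroring that surface; the width and function name are arbitrary illustrations:

#include "src/core/lib/gprpp/bitset.h"

void BitSetSketch() {
  grpc_core::BitSet<17> bits;      // width fixed at compile time
  bool empty = bits.none();        // true: freshly constructed, nothing set
  bits.set(3);
  bool has3 = bits.is_set(3);      // true; every other index still reads false
  auto ones = bits.count();        // 1
  bool full = bits.all();          // false until all 17 bits are set
  (void)empty; (void)has3; (void)ones; (void)full;
}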
+ +#include "src/core/lib/gprpp/capture.h" + +#include + +namespace grpc_core { + +TEST(CaptureTest, Capture) { + auto f = Capture([](int* p) { EXPECT_EQ(*p, 42); }, 42); + f(); +} + +TEST(CaptureTest, WithArgsAndReturn) { + int captured = 1; + auto f = + Capture([captured](int* p, int arg) { return (captured + *p) * arg; }, 2); + EXPECT_EQ(f(2), 6); + EXPECT_EQ(f(3), 9); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/chunked_vector_fuzzer.cc b/test/core/gprpp/chunked_vector_fuzzer.cc new file mode 100644 index 00000000..1704e3e7 --- /dev/null +++ b/test/core/gprpp/chunked_vector_fuzzer.cc @@ -0,0 +1,152 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/gprpp/chunked_vector.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "test/core/gprpp/chunked_vector_fuzzer.pb.h" + +bool squelch = true; +bool leak_check = true; + +static constexpr size_t kChunkSize = 17; +using IntHdl = std::shared_ptr; + +namespace grpc_core { +struct Comparison { + explicit Comparison(Arena* arena) : chunked(arena) {} + + ChunkedVector chunked; + std::vector std; + + // Check that both chunked and std are equivalent. + void AssertOk() const { + GPR_ASSERT(std.size() == chunked.size()); + auto it_chunked = chunked.cbegin(); + auto it_std = std.cbegin(); + while (it_std != std.cend()) { + GPR_ASSERT(**it_std == **it_chunked); + ++it_chunked; + ++it_std; + } + GPR_ASSERT(it_chunked == chunked.cend()); + } +}; + +class Fuzzer { + public: + Fuzzer() = default; + ~Fuzzer() = default; + + void Act(const chunked_vector_fuzzer::Action& action) { + switch (action.action_type_case()) { + case chunked_vector_fuzzer::Action::kEmplaceBack: { + // Add some value to the back of a comparison, assert that both vectors + // are equivalent. + auto* c = Mutate(action.emplace_back().vector()); + c->chunked.EmplaceBack( + std::make_shared(action.emplace_back().value())); + c->std.emplace_back( + std::make_shared(action.emplace_back().value())); + c->AssertOk(); + } break; + case chunked_vector_fuzzer::Action::kPopBack: { + // Remove some value to the back of a comparison, assert that both + // vectors are equivalent. + auto* c = Mutate(action.pop_back().vector()); + if (c->chunked.size() > 0) { + c->chunked.PopBack(); + c->std.pop_back(); + c->AssertOk(); + } + } break; + case chunked_vector_fuzzer::Action::kCopy: { + // Copy one vector into another, assert both everything stays + // equivalent. 
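+ // If either key is new it is created on demand: the source as a fresh
+ // empty Comparison, the destination as a copy of the source. An existing
+ // destination is copy-assigned instead, and both sides are then
+ // re-checked against their std::vector mirrors.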
+ auto it_from = vectors_.find(action.copy().from()); + if (it_from == vectors_.end()) { + it_from = + vectors_.emplace(action.copy().from(), Comparison(arena_.get())) + .first; + } + auto it_to = vectors_.find(action.copy().to()); + if (it_to == vectors_.end()) { + it_to = vectors_.emplace(action.copy().to(), it_from->second).first; + } else { + it_to->second = it_from->second; + } + it_from->second.AssertOk(); + it_to->second.AssertOk(); + } break; + case chunked_vector_fuzzer::Action::kMove: { + // Move one vector into another, assert both everything stays + // equivalent. + auto it_from = vectors_.find(action.move().from()); + if (it_from == vectors_.end()) { + it_from = + vectors_.emplace(action.move().from(), Comparison(arena_.get())) + .first; + } + auto it_to = vectors_.find(action.move().to()); + if (it_to == vectors_.end()) { + it_to = + vectors_.emplace(action.move().to(), std::move(it_from->second)) + .first; + } else { + it_to->second = it_from->second; + } + it_from->second.AssertOk(); + it_to->second.AssertOk(); + } break; + case chunked_vector_fuzzer::Action::kClear: { + // Clear a vector, assert that both underlying vectors are equivalent. + auto* c = Mutate(action.clear().vector()); + c->chunked.Clear(); + c->std.clear(); + c->AssertOk(); + } break; + case chunked_vector_fuzzer::Action::kSwap: { + // Swap two vectors, assert that both underlying vectors are equivalent. + auto* from = Mutate(action.swap().from()); + auto* to = Mutate(action.swap().to()); + from->chunked.Swap(&to->chunked); + from->std.swap(to->std); + from->AssertOk(); + } break; + case chunked_vector_fuzzer::Action::ACTION_TYPE_NOT_SET: + break; + } + } + + private: + Comparison* Mutate(int index) { + auto it = vectors_.find(index); + if (it != vectors_.end()) { + return &it->second; + } + return &vectors_.emplace(index, Comparison(arena_.get())).first->second; + } + + ScopedArenaPtr arena_ = MakeScopedArena(128); + std::map vectors_; +}; +} // namespace grpc_core + +DEFINE_PROTO_FUZZER(const chunked_vector_fuzzer::Msg& msg) { + grpc_core::Fuzzer fuzzer; + for (int i = 0; i < msg.actions_size(); i++) { + fuzzer.Act(msg.actions(i)); + } +} diff --git a/test/core/gprpp/chunked_vector_test.cc b/test/core/gprpp/chunked_vector_test.cc new file mode 100644 index 00000000..274bc0f3 --- /dev/null +++ b/test/core/gprpp/chunked_vector_test.cc @@ -0,0 +1,154 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/gprpp/chunked_vector.h" + +#include + +namespace grpc_core { +namespace testing { + +static constexpr size_t kInitialArenaSize = 1024; +static constexpr size_t kChunkSize = 3; + +TEST(ChunkedVector, Noop) { + auto arena = MakeScopedArena(kInitialArenaSize); + ChunkedVector v(arena.get()); + EXPECT_EQ(0, v.size()); +} + +TEST(ChunkedVector, Stack) { + auto arena = MakeScopedArena(kInitialArenaSize); + ChunkedVector v(arena.get()); + + // Populate 2 chunks of memory, and 2/3 of a final chunk. 
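+ // With kChunkSize == 3, the eight elements pushed below fill two complete
+ // chunks ({1,2,3} and {4,5,6}) plus two of the three slots in a third chunk.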
+ EXPECT_EQ(0, v.size()); + v.EmplaceBack(1); + EXPECT_EQ(1, v.size()); + v.EmplaceBack(2); + EXPECT_EQ(2, v.size()); + v.EmplaceBack(3); + EXPECT_EQ(3, v.size()); + v.EmplaceBack(4); + EXPECT_EQ(4, v.size()); + v.EmplaceBack(5); + EXPECT_EQ(5, v.size()); + v.EmplaceBack(6); + EXPECT_EQ(6, v.size()); + v.EmplaceBack(7); + EXPECT_EQ(7, v.size()); + v.EmplaceBack(8); + EXPECT_EQ(8, v.size()); + + // Now pop all of them out and check the expected ordering. + EXPECT_EQ(8, v.PopBack()); + EXPECT_EQ(7, v.size()); + EXPECT_EQ(7, v.PopBack()); + EXPECT_EQ(6, v.size()); + EXPECT_EQ(6, v.PopBack()); + EXPECT_EQ(5, v.size()); + EXPECT_EQ(5, v.PopBack()); + EXPECT_EQ(4, v.size()); + EXPECT_EQ(4, v.PopBack()); + EXPECT_EQ(3, v.size()); + EXPECT_EQ(3, v.PopBack()); + EXPECT_EQ(2, v.size()); + EXPECT_EQ(2, v.PopBack()); + EXPECT_EQ(1, v.size()); + EXPECT_EQ(1, v.PopBack()); + EXPECT_EQ(0, v.size()); +} + +TEST(ChunkedVector, Iterate) { + auto arena = MakeScopedArena(kInitialArenaSize); + ChunkedVector v(arena.get()); + v.EmplaceBack(1); + v.EmplaceBack(2); + v.EmplaceBack(3); + v.EmplaceBack(4); + v.EmplaceBack(5); + v.EmplaceBack(6); + v.EmplaceBack(7); + v.EmplaceBack(8); + + auto it = v.begin(); + EXPECT_EQ(1, *it); + ++it; + EXPECT_EQ(2, *it); + ++it; + EXPECT_EQ(3, *it); + ++it; + EXPECT_EQ(4, *it); + ++it; + EXPECT_EQ(5, *it); + ++it; + EXPECT_EQ(6, *it); + ++it; + EXPECT_EQ(7, *it); + ++it; + EXPECT_EQ(8, *it); + ++it; + EXPECT_EQ(v.end(), it); +} + +TEST(ChunkedVector, ConstIterate) { + auto arena = MakeScopedArena(kInitialArenaSize); + ChunkedVector v(arena.get()); + v.EmplaceBack(1); + v.EmplaceBack(2); + v.EmplaceBack(3); + v.EmplaceBack(4); + v.EmplaceBack(5); + v.EmplaceBack(6); + v.EmplaceBack(7); + v.EmplaceBack(8); + + auto it = v.cbegin(); + EXPECT_EQ(1, *it); + ++it; + EXPECT_EQ(2, *it); + ++it; + EXPECT_EQ(3, *it); + ++it; + EXPECT_EQ(4, *it); + ++it; + EXPECT_EQ(5, *it); + ++it; + EXPECT_EQ(6, *it); + ++it; + EXPECT_EQ(7, *it); + ++it; + EXPECT_EQ(8, *it); + ++it; + EXPECT_EQ(v.cend(), it); +} + +TEST(ChunkedVector, Clear) { + auto arena = MakeScopedArena(kInitialArenaSize); + ChunkedVector v(arena.get()); + v.EmplaceBack(1); + EXPECT_EQ(v.size(), 1); + v.Clear(); + EXPECT_EQ(v.size(), 0); + EXPECT_EQ(v.begin(), v.end()); +} + +} // namespace testing + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/dual_ref_counted_test.cc b/test/core/gprpp/dual_ref_counted_test.cc new file mode 100644 index 00000000..53fe0882 --- /dev/null +++ b/test/core/gprpp/dual_ref_counted_test.cc @@ -0,0 +1,108 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
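Summarizing the container exercised in chunked_vector_test.cc above: ChunkedVector<T, N> draws its storage in N-element chunks from an Arena, grows and shrinks like a stack via EmplaceBack()/PopBack(), and iterates in insertion order. A minimal sketch reusing the same helpers the tests use; the function name and sizes are illustrative:

#include "src/core/lib/gprpp/chunked_vector.h"

void ChunkedVectorSketch() {
  auto arena = grpc_core::MakeScopedArena(1024);     // backing storage
  grpc_core::ChunkedVector<int, 3> v(arena.get());   // 3 elements per chunk
  v.EmplaceBack(1);
  v.EmplaceBack(2);
  v.EmplaceBack(3);
  int popped = v.PopBack();                          // 3: LIFO order
  int sum = 0;
  for (auto it = v.begin(); it != v.end(); ++it) {
    sum += *it;                                      // visits 1, then 2
  }
  v.Clear();                                         // size() is 0 again
  (void)popped; (void)sum;
}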
+// + +#include "src/core/lib/gprpp/dual_ref_counted.h" + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +class Foo : public DualRefCounted { + public: + Foo() = default; + ~Foo() override { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +TEST(DualRefCounted, Basic) { + Foo* foo = new Foo(); + foo->Unref(); +} + +TEST(DualRefCounted, ExtraRef) { + Foo* foo = new Foo(); + foo->Ref().release(); + foo->Unref(); + foo->Unref(); +} + +TEST(DualRefCounted, ExtraWeakRef) { + Foo* foo = new Foo(); + foo->WeakRef().release(); + foo->Unref(); + foo->WeakUnref(); +} + +TEST(DualRefCounted, RefIfNonZero) { + Foo* foo = new Foo(); + foo->WeakRef().release(); + { + RefCountedPtr foop = foo->RefIfNonZero(); + EXPECT_NE(foop.get(), nullptr); + } + foo->Unref(); + { + RefCountedPtr foop = foo->RefIfNonZero(); + EXPECT_EQ(foop.get(), nullptr); + } + foo->WeakUnref(); +} + +class FooWithTracing : public DualRefCounted { + public: + FooWithTracing() : DualRefCounted("FooWithTracing") {} + ~FooWithTracing() override { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +TEST(DualRefCountedWithTracing, Basic) { + FooWithTracing* foo = new FooWithTracing(); + foo->Ref(DEBUG_LOCATION, "extra_ref").release(); + foo->Unref(DEBUG_LOCATION, "extra_ref"); + foo->WeakRef(DEBUG_LOCATION, "extra_ref").release(); + foo->WeakUnref(DEBUG_LOCATION, "extra_ref"); + // Can use the no-argument methods, too. + foo->Ref().release(); + foo->Unref(); + foo->WeakRef().release(); + foo->WeakUnref(); + foo->Unref(DEBUG_LOCATION, "original_ref"); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/examine_stack_test.cc b/test/core/gprpp/examine_stack_test.cc new file mode 100644 index 00000000..2bffa08c --- /dev/null +++ b/test/core/gprpp/examine_stack_test.cc @@ -0,0 +1,84 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/gprpp/examine_stack.h" + +#include +#include + +#include + +#include "absl/debugging/stacktrace.h" +#include "absl/debugging/symbolize.h" + +#include + +namespace { + +std::string SimpleCurrentStackTraceProvider() { return "stacktrace"; } + +std::string AbseilCurrentStackTraceProvider() { + std::string result = "Stack trace:\n"; + constexpr int kNumStackFrames = 10; + void* stack[kNumStackFrames]; + int frame_sizes[kNumStackFrames]; + int depth = absl::GetStackFrames(stack, frame_sizes, kNumStackFrames, 1); + for (int i = 0; i < depth; i++) { + char tmp[1024]; + const char* symbol = "(unknown)"; + if (absl::Symbolize(stack[i], tmp, sizeof(tmp))) { + symbol = tmp; + } + result += symbol; + result += +"\n"; + } + return result; +} + +} // namespace + +TEST(ExamineStackTest, NullStackProvider) { + grpc_core::SetCurrentStackTraceProvider(nullptr); + EXPECT_EQ(grpc_core::GetCurrentStackTraceProvider(), nullptr); + EXPECT_EQ(grpc_core::GetCurrentStackTrace(), absl::nullopt); +} + +TEST(ExamineStackTest, SimpleStackProvider) { + grpc_core::SetCurrentStackTraceProvider(&SimpleCurrentStackTraceProvider); + EXPECT_NE(grpc_core::GetCurrentStackTraceProvider(), nullptr); + EXPECT_EQ(grpc_core::GetCurrentStackTrace(), "stacktrace"); +} + +TEST(ExamineStackTest, AbseilStackProvider) { + grpc_core::SetCurrentStackTraceProvider(&AbseilCurrentStackTraceProvider); + EXPECT_NE(grpc_core::GetCurrentStackTraceProvider(), nullptr); + const absl::optional stack_trace = + grpc_core::GetCurrentStackTrace(); + EXPECT_NE(stack_trace, absl::nullopt); + gpr_log(GPR_INFO, "stack_trace=%s", stack_trace->c_str()); +#if !defined(NDEBUG) && !defined(GPR_MUSL_LIBC_COMPAT) + EXPECT_TRUE(stack_trace->find("GetCurrentStackTrace") != std::string::npos); +#endif +} + +int main(int argc, char** argv) { + absl::InitializeSymbolizer(argv[0]); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/gprpp/fork_test.cc b/test/core/gprpp/fork_test.cc new file mode 100644 index 00000000..bc69f77c --- /dev/null +++ b/test/core/gprpp/fork_test.cc @@ -0,0 +1,139 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/gprpp/fork.h" + +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +static void test_init() { + GPR_ASSERT(!grpc_core::Fork::Enabled()); + + // Default fork support (disabled) + grpc_core::Fork::GlobalInit(); + GPR_ASSERT(!grpc_core::Fork::Enabled()); + grpc_core::Fork::GlobalShutdown(); + + // Explicitly disabled fork support + grpc_core::Fork::Enable(false); + grpc_core::Fork::GlobalInit(); + GPR_ASSERT(!grpc_core::Fork::Enabled()); + grpc_core::Fork::GlobalShutdown(); + + // Explicitly enabled fork support + grpc_core::Fork::Enable(true); + grpc_core::Fork::GlobalInit(); + GPR_ASSERT(grpc_core::Fork::Enabled()); + grpc_core::Fork::GlobalShutdown(); +} + +// This spawns CONCURRENT_TEST_THREADS that last up to +// THREAD_DELAY_MS, and checks that the Fork::AwaitThreads() +// returns roughly after THREAD_DELAY_MS. The epsilon is high +// because tsan threads can take a while to spawn/join. +#define THREAD_DELAY_MS 6000 +#define THREAD_DELAY_EPSILON 1500 +#define CONCURRENT_TEST_THREADS 100 + +static void sleeping_thd(void* arg) { + int64_t sleep_ms = reinterpret_cast(arg); + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(sleep_ms, GPR_TIMESPAN))); +} + +static void test_thd_count() { + // Test no active threads + grpc_core::Fork::Enable(true); + grpc_core::Fork::GlobalInit(); + grpc_core::Fork::AwaitThreads(); + grpc_core::Fork::GlobalShutdown(); + + grpc_core::Fork::Enable(true); + grpc_core::Fork::GlobalInit(); + grpc_core::Thread thds[CONCURRENT_TEST_THREADS]; + gpr_timespec est_end_time = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(THREAD_DELAY_MS, GPR_TIMESPAN)); + gpr_timespec tolerance = + gpr_time_from_millis(THREAD_DELAY_EPSILON, GPR_TIMESPAN); + for (int i = 0; i < CONCURRENT_TEST_THREADS; i++) { + intptr_t sleep_time_ms = + (i * THREAD_DELAY_MS) / (CONCURRENT_TEST_THREADS - 1); + thds[i] = grpc_core::Thread("grpc_fork_test", sleeping_thd, + reinterpret_cast(sleep_time_ms)); + thds[i].Start(); + } + grpc_core::Fork::AwaitThreads(); + gpr_timespec end_time = gpr_now(GPR_CLOCK_REALTIME); + for (auto& thd : thds) { + thd.Join(); + } + GPR_ASSERT(gpr_time_similar(end_time, est_end_time, tolerance)); + grpc_core::Fork::GlobalShutdown(); +} + +static void exec_ctx_thread(void* arg) { + bool* exec_ctx_created = static_cast(arg); + grpc_core::Fork::IncExecCtxCount(); + *exec_ctx_created = true; +} + +static void test_exec_count() { + grpc_core::Fork::Enable(true); + grpc_core::Fork::GlobalInit(); + + grpc_core::Fork::IncExecCtxCount(); + GPR_ASSERT(grpc_core::Fork::BlockExecCtx()); + grpc_core::Fork::DecExecCtxCount(); + grpc_core::Fork::AllowExecCtx(); + + grpc_core::Fork::IncExecCtxCount(); + grpc_core::Fork::IncExecCtxCount(); + GPR_ASSERT(!grpc_core::Fork::BlockExecCtx()); + grpc_core::Fork::DecExecCtxCount(); + grpc_core::Fork::DecExecCtxCount(); + + grpc_core::Fork::IncExecCtxCount(); + GPR_ASSERT(grpc_core::Fork::BlockExecCtx()); + grpc_core::Fork::DecExecCtxCount(); + grpc_core::Fork::AllowExecCtx(); + + // Test that block_exec_ctx() blocks grpc_core::Fork::IncExecCtxCount + bool exec_ctx_created = false; + grpc_core::Thread thd = + grpc_core::Thread("grpc_fork_test", exec_ctx_thread, &exec_ctx_created); + grpc_core::Fork::IncExecCtxCount(); + GPR_ASSERT(grpc_core::Fork::BlockExecCtx()); + grpc_core::Fork::DecExecCtxCount(); + thd.Start(); + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(1, GPR_TIMESPAN))); + 
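+ // The one-second sleep gives exec_ctx_thread time to run. Because
+ // BlockExecCtx() succeeded above, the thread's IncExecCtxCount() call
+ // should still be blocked here, which the assertion below checks before
+ // AllowExecCtx() releases it.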
GPR_ASSERT(!exec_ctx_created); + grpc_core::Fork::AllowExecCtx(); + thd.Join(); // This ensure that the call got un-blocked + grpc_core::Fork::GlobalShutdown(); +} + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + test_init(); + test_thd_count(); + test_exec_count(); + + return 0; +} diff --git a/test/core/gprpp/global_config_env_test.cc b/test/core/gprpp/global_config_env_test.cc new file mode 100644 index 00000000..519c242e --- /dev/null +++ b/test/core/gprpp/global_config_env_test.cc @@ -0,0 +1,131 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gprpp/global_config_env.h" + +#include +#include + +#include + +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/memory.h" + +namespace { + +bool g_config_error_function_called; + +void ClearConfigErrorCalled() { g_config_error_function_called = false; } + +bool IsConfigErrorCalled() { return g_config_error_function_called; } + +// This function is for preventing the program from invoking +// an error handler due to configuration error and +// make test routines know whether there is error. +void FakeConfigErrorFunction(const char* /*error_message*/) { + g_config_error_function_called = true; +} + +class GlobalConfigEnvTest : public ::testing::Test { + protected: + void SetUp() override { ClearConfigErrorCalled(); } + void TearDown() override { EXPECT_FALSE(IsConfigErrorCalled()); } +}; + +} // namespace + +GPR_GLOBAL_CONFIG_DEFINE_BOOL(bool_var, true, ""); +GPR_GLOBAL_CONFIG_DEFINE_INT32(int32_var, 1234, ""); +GPR_GLOBAL_CONFIG_DEFINE_STRING(string_var, "Apple", ""); + +TEST_F(GlobalConfigEnvTest, BoolWithEnvTest) { + const char* bool_var_name = "BOOL_VAR"; + + gpr_unsetenv(bool_var_name); + EXPECT_TRUE(GPR_GLOBAL_CONFIG_GET(bool_var)); + + gpr_setenv(bool_var_name, "true"); + EXPECT_TRUE(GPR_GLOBAL_CONFIG_GET(bool_var)); + + gpr_setenv(bool_var_name, "false"); + EXPECT_FALSE(GPR_GLOBAL_CONFIG_GET(bool_var)); + + EXPECT_FALSE(IsConfigErrorCalled()); + + gpr_setenv(bool_var_name, ""); + GPR_GLOBAL_CONFIG_GET(bool_var); + EXPECT_TRUE(IsConfigErrorCalled()); + ClearConfigErrorCalled(); + + gpr_setenv(bool_var_name, "!"); + GPR_GLOBAL_CONFIG_GET(bool_var); + EXPECT_TRUE(IsConfigErrorCalled()); + ClearConfigErrorCalled(); +} + +TEST_F(GlobalConfigEnvTest, Int32WithEnvTest) { + const char* int32_var_name = "INT32_VAR"; + + gpr_unsetenv(int32_var_name); + EXPECT_EQ(1234, GPR_GLOBAL_CONFIG_GET(int32_var)); + + gpr_setenv(int32_var_name, "0"); + EXPECT_EQ(0, GPR_GLOBAL_CONFIG_GET(int32_var)); + + gpr_setenv(int32_var_name, "-123456789"); + EXPECT_EQ(-123456789, GPR_GLOBAL_CONFIG_GET(int32_var)); + + gpr_setenv(int32_var_name, "123456789"); + EXPECT_EQ(123456789, GPR_GLOBAL_CONFIG_GET(int32_var)); + + EXPECT_FALSE(IsConfigErrorCalled()); + + gpr_setenv(int32_var_name, "-1AB"); + GPR_GLOBAL_CONFIG_GET(int32_var); + EXPECT_TRUE(IsConfigErrorCalled()); + ClearConfigErrorCalled(); +} + 
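The pattern every case in this file follows is: define a variable with GPR_GLOBAL_CONFIG_DEFINE_*, point the matching upper-cased environment variable at a value with gpr_setenv, and read it back with GPR_GLOBAL_CONFIG_GET, which reflects the current environment on each call. A minimal sketch of that flow, reusing the bool_var definition and includes already present in this file; the function name is illustrative:

void GlobalConfigEnvSketch() {
  gpr_setenv("BOOL_VAR", "false");                // env name derived from bool_var
  bool value = GPR_GLOBAL_CONFIG_GET(bool_var);   // false, parsed from the env
  gpr_unsetenv("BOOL_VAR");
  value = GPR_GLOBAL_CONFIG_GET(bool_var);        // true: falls back to the default
  (void)value;
}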
+TEST_F(GlobalConfigEnvTest, StringWithEnvTest) { + const char* string_var_name = "STRING_VAR"; + grpc_core::UniquePtr value; + + gpr_unsetenv(string_var_name); + value = GPR_GLOBAL_CONFIG_GET(string_var); + EXPECT_EQ(0, strcmp(value.get(), "Apple")); + + gpr_setenv(string_var_name, "Banana"); + value = GPR_GLOBAL_CONFIG_GET(string_var); + EXPECT_EQ(0, strcmp(value.get(), "Banana")); + + gpr_setenv(string_var_name, ""); + value = GPR_GLOBAL_CONFIG_GET(string_var); + EXPECT_EQ(0, strcmp(value.get(), "")); +} + +int main(int argc, char** argv) { + // Not to abort the test when parsing error happens. + grpc_core::SetGlobalConfigEnvErrorFunction(&FakeConfigErrorFunction); + + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/gprpp/global_config_test.cc b/test/core/gprpp/global_config_test.cc new file mode 100644 index 00000000..31ab94bf --- /dev/null +++ b/test/core/gprpp/global_config_test.cc @@ -0,0 +1,66 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gprpp/global_config.h" + +#include +#include + +#include + +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/memory.h" + +GPR_GLOBAL_CONFIG_DECLARE_BOOL(bool_var); + +GPR_GLOBAL_CONFIG_DEFINE_BOOL(bool_var, false, ""); +GPR_GLOBAL_CONFIG_DEFINE_INT32(int32_var, 0, ""); +GPR_GLOBAL_CONFIG_DEFINE_STRING(string_var, "", ""); + +TEST(GlobalConfigTest, BoolTest) { + EXPECT_FALSE(GPR_GLOBAL_CONFIG_GET(bool_var)); + GPR_GLOBAL_CONFIG_SET(bool_var, true); + EXPECT_TRUE(GPR_GLOBAL_CONFIG_GET(bool_var)); +} + +TEST(GlobalConfigTest, Int32Test) { + EXPECT_EQ(0, GPR_GLOBAL_CONFIG_GET(int32_var)); + GPR_GLOBAL_CONFIG_SET(int32_var, 1024); + EXPECT_EQ(1024, GPR_GLOBAL_CONFIG_GET(int32_var)); +} + +TEST(GlobalConfigTest, StringTest) { + grpc_core::UniquePtr value; + + value = GPR_GLOBAL_CONFIG_GET(string_var); + EXPECT_EQ(0, strcmp(value.get(), "")); + + GPR_GLOBAL_CONFIG_SET(string_var, "Test"); + + value = GPR_GLOBAL_CONFIG_GET(string_var); + EXPECT_EQ(0, strcmp(value.get(), "Test")); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/gprpp/host_port_test.cc b/test/core/gprpp/host_port_test.cc new file mode 100644 index 00000000..a63065c5 --- /dev/null +++ b/test/core/gprpp/host_port_test.cc @@ -0,0 +1,85 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gprpp/host_port.h" + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +static void join_host_port_expect(const char* host, int port, + const char* expected) { + std::string actual = grpc_core::JoinHostPort(host, port); + GPR_ASSERT(actual == expected); +} + +static void test_join_host_port(void) { + join_host_port_expect("foo", 101, "foo:101"); + join_host_port_expect("", 102, ":102"); + join_host_port_expect("1::2", 103, "[1::2]:103"); + join_host_port_expect("[::1]", 104, "[::1]:104"); +} + +/* Garbage in, garbage out. */ +static void test_join_host_port_garbage(void) { + join_host_port_expect("[foo]", 105, "[foo]:105"); + join_host_port_expect("[::", 106, "[:::106"); + join_host_port_expect("::]", 107, "[::]]:107"); +} + +static void split_host_port_expect(const char* name, const char* host, + const char* port, bool ret) { + std::string actual_host; + std::string actual_port; + const bool actual_ret = + grpc_core::SplitHostPort(name, &actual_host, &actual_port); + GPR_ASSERT(actual_ret == ret); + GPR_ASSERT(actual_host == (host == nullptr ? "" : host)); + GPR_ASSERT(actual_port == (port == nullptr ? "" : port)); +} + +static void test_split_host_port() { + split_host_port_expect("", "", nullptr, true); + split_host_port_expect("[a:b]", "a:b", nullptr, true); + split_host_port_expect("1.2.3.4", "1.2.3.4", nullptr, true); + split_host_port_expect("0.0.0.0:", "0.0.0.0", "", true); + split_host_port_expect("a:b:c::", "a:b:c::", nullptr, true); + split_host_port_expect("[a:b:c::]:", "a:b:c::", "", true); + split_host_port_expect("[a:b]:30", "a:b", "30", true); + split_host_port_expect("1.2.3.4:30", "1.2.3.4", "30", true); + split_host_port_expect(":30", "", "30", true); +} + +static void test_split_host_port_invalid() { + split_host_port_expect("[a:b", nullptr, nullptr, false); + split_host_port_expect("[a:b]30", nullptr, nullptr, false); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + + test_join_host_port(); + test_join_host_port_garbage(); + test_split_host_port(); + test_split_host_port_invalid(); + + return 0; +} diff --git a/test/core/gprpp/manual_constructor_test.cc b/test/core/gprpp/manual_constructor_test.cc new file mode 100644 index 00000000..04ad5926 --- /dev/null +++ b/test/core/gprpp/manual_constructor_test.cc @@ -0,0 +1,100 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test of gpr synchronization support. 
*/ + +#include "src/core/lib/gprpp/manual_constructor.h" + +#include +#include + +#include + +#include +#include +#include + +#include "test/core/util/test_config.h" + +class A { + public: + A() {} + virtual ~A() {} + virtual const char* foo() { return "A_foo"; } + virtual const char* bar() { return "A_bar"; } +}; + +class B : public A { + public: + B() {} + ~B() override {} + const char* foo() override { return "B_foo"; } + char get_junk() { return junk[0]; } + + private: + char junk[1000]; +}; + +class C : public B { + public: + C() {} + ~C() override {} + const char* bar() override { return "C_bar"; } + char get_more_junk() { return more_junk[0]; } + + private: + char more_junk[1000]; +}; + +class D : public A { + public: + const char* bar() override { return "D_bar"; } +}; + +static void basic_test() { + grpc_core::PolymorphicManualConstructor poly; + poly.Init(); + GPR_ASSERT(!strcmp(poly->foo(), "B_foo")); + GPR_ASSERT(!strcmp(poly->bar(), "A_bar")); +} + +static void complex_test() { + grpc_core::PolymorphicManualConstructor polyB; + polyB.Init(); + GPR_ASSERT(!strcmp(polyB->foo(), "B_foo")); + GPR_ASSERT(!strcmp(polyB->bar(), "A_bar")); + + grpc_core::PolymorphicManualConstructor polyC; + polyC.Init(); + GPR_ASSERT(!strcmp(polyC->foo(), "B_foo")); + GPR_ASSERT(!strcmp(polyC->bar(), "C_bar")); + + grpc_core::PolymorphicManualConstructor polyD; + polyD.Init(); + GPR_ASSERT(!strcmp(polyD->foo(), "A_foo")); + GPR_ASSERT(!strcmp(polyD->bar(), "D_bar")); +} + +/* ------------------------------------------------- */ + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + basic_test(); + complex_test(); + return 0; +} diff --git a/test/core/gprpp/match_test.cc b/test/core/gprpp/match_test.cc new file mode 100644 index 00000000..f0ea8a0c --- /dev/null +++ b/test/core/gprpp/match_test.cc @@ -0,0 +1,76 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
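The class hierarchy above exists to exercise PolymorphicManualConstructor: a block of storage that can hold any one of the listed types, constructed in place only when Init<T>() is called and accessed polymorphically through operator->. A minimal sketch reusing the same A/C classes from manual_constructor_test.cc; the function name is illustrative:

void ManualConstructorSketch() {
  // Nothing is constructed until Init<>() runs.
  grpc_core::PolymorphicManualConstructor<A, C> poly;
  poly.Init<C>();                    // placement-construct a C in the storage
  const char* f = poly->foo();       // "B_foo": C inherits B's override
  const char* b = poly->bar();       // "C_bar": dispatched virtually
  (void)f; (void)b;
}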
+ +#include "src/core/lib/gprpp/match.h" + +#include + +namespace grpc_core { +namespace testing { + +TEST(MatchTest, Test) { + EXPECT_EQ(Match( + absl::variant(1.9), [](int) -> int { abort(); }, + [](double x) -> int { + EXPECT_EQ(x, 1.9); + return 42; + }), + 42); + EXPECT_EQ(Match( + absl::variant(3), + [](int x) -> int { + EXPECT_EQ(x, 3); + return 42; + }, + [](double) -> int { abort(); }), + 42); +} + +TEST(MatchTest, TestVoidReturn) { + bool triggered = false; + Match( + absl::variant(1.9), [](int) { abort(); }, + [&triggered](double x) { + EXPECT_EQ(x, 1.9); + triggered = true; + }); + EXPECT_TRUE(triggered); +} + +TEST(MatchTest, TestMutable) { + absl::variant v = 1.9; + MatchMutable( + &v, [](int*) { abort(); }, [](double* x) { *x = 0.0; }); + EXPECT_EQ(v, (absl::variant(0.0))); +} + +TEST(MatchTest, TestMutableWithReturn) { + absl::variant v = 1.9; + EXPECT_EQ(MatchMutable( + &v, [](int*) -> int { abort(); }, + [](double* x) -> int { + *x = 0.0; + return 1; + }), + 1); + EXPECT_EQ(v, (absl::variant(0.0))); +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/mpscq_test.cc b/test/core/gprpp/mpscq_test.cc new file mode 100644 index 00000000..40667942 --- /dev/null +++ b/test/core/gprpp/mpscq_test.cc @@ -0,0 +1,188 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/gprpp/mpscq.h" + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +using grpc_core::MultiProducerSingleConsumerQueue; + +typedef struct test_node { + MultiProducerSingleConsumerQueue::Node node; + size_t i; + size_t* ctr; +} test_node; + +static test_node* new_node(size_t i, size_t* ctr) { + test_node* n = new test_node(); + n->i = i; + n->ctr = ctr; + return n; +} + +static void test_serial(void) { + gpr_log(GPR_DEBUG, "test_serial"); + MultiProducerSingleConsumerQueue q; + for (size_t i = 0; i < 10000000; i++) { + q.Push(&new_node(i, nullptr)->node); + } + for (size_t i = 0; i < 10000000; i++) { + test_node* n = reinterpret_cast(q.Pop()); + GPR_ASSERT(n); + GPR_ASSERT(n->i == i); + delete n; + } +} + +typedef struct { + size_t ctr; + MultiProducerSingleConsumerQueue* q; + gpr_event* start; +} thd_args; + +#define THREAD_ITERATIONS 10000 + +static void test_thread(void* args) { + thd_args* a = static_cast(args); + gpr_event_wait(a->start, gpr_inf_future(GPR_CLOCK_REALTIME)); + for (size_t i = 1; i <= THREAD_ITERATIONS; i++) { + a->q->Push(&new_node(i, &a->ctr)->node); + } +} + +static void test_mt(void) { + gpr_log(GPR_DEBUG, "test_mt"); + gpr_event start; + gpr_event_init(&start); + grpc_core::Thread thds[100]; + thd_args ta[GPR_ARRAY_SIZE(thds)]; + MultiProducerSingleConsumerQueue q; + for (size_t i = 0; i < GPR_ARRAY_SIZE(thds); i++) { + ta[i].ctr = 0; + ta[i].q = &q; + ta[i].start = &start; + thds[i] = grpc_core::Thread("grpc_mt_test", test_thread, &ta[i]); + thds[i].Start(); + } + size_t num_done = 0; + size_t spins = 0; + gpr_event_set(&start, reinterpret_cast(1)); + while (num_done != GPR_ARRAY_SIZE(thds)) { + MultiProducerSingleConsumerQueue::Node* n; + while ((n = q.Pop()) == nullptr) { + spins++; + } + test_node* tn = reinterpret_cast(n); + GPR_ASSERT(*tn->ctr == tn->i - 1); + *tn->ctr = tn->i; + if (tn->i == THREAD_ITERATIONS) num_done++; + delete tn; + } + gpr_log(GPR_DEBUG, "spins: %" PRIdPTR, spins); + for (auto& th : thds) { + th.Join(); + } +} + +typedef struct { + thd_args* ta; + size_t num_thds; + gpr_mu mu; + size_t num_done; + size_t spins; + MultiProducerSingleConsumerQueue* q; + gpr_event* start; +} pull_args; + +static void pull_thread(void* arg) { + pull_args* pa = static_cast(arg); + gpr_event_wait(pa->start, gpr_inf_future(GPR_CLOCK_REALTIME)); + + for (;;) { + gpr_mu_lock(&pa->mu); + if (pa->num_done == pa->num_thds) { + gpr_mu_unlock(&pa->mu); + return; + } + MultiProducerSingleConsumerQueue::Node* n; + while ((n = pa->q->Pop()) == nullptr) { + pa->spins++; + } + test_node* tn = reinterpret_cast(n); + GPR_ASSERT(*tn->ctr == tn->i - 1); + *tn->ctr = tn->i; + if (tn->i == THREAD_ITERATIONS) pa->num_done++; + delete tn; + gpr_mu_unlock(&pa->mu); + } +} + +static void test_mt_multipop(void) { + gpr_log(GPR_DEBUG, "test_mt_multipop"); + gpr_event start; + gpr_event_init(&start); + grpc_core::Thread thds[50]; + grpc_core::Thread pull_thds[50]; + thd_args ta[GPR_ARRAY_SIZE(thds)]; + MultiProducerSingleConsumerQueue q; + for (size_t i = 0; i < GPR_ARRAY_SIZE(thds); i++) { + ta[i].ctr = 0; + ta[i].q = &q; + ta[i].start = &start; + thds[i] = grpc_core::Thread("grpc_multipop_test", test_thread, &ta[i]); + thds[i].Start(); + } + pull_args pa; + pa.ta = ta; + pa.num_thds = GPR_ARRAY_SIZE(thds); + pa.spins = 0; + pa.num_done = 0; + pa.q = &q; + pa.start = &start; + gpr_mu_init(&pa.mu); + for (size_t i = 0; i < 
GPR_ARRAY_SIZE(pull_thds); i++) { + pull_thds[i] = grpc_core::Thread("grpc_multipop_pull", pull_thread, &pa); + pull_thds[i].Start(); + } + gpr_event_set(&start, reinterpret_cast(1)); + for (auto& pth : pull_thds) { + pth.Join(); + } + gpr_log(GPR_DEBUG, "spins: %" PRIdPTR, pa.spins); + for (auto& th : thds) { + th.Join(); + } + gpr_mu_destroy(&pa.mu); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_serial(); + test_mt(); + test_mt_multipop(); + return 0; +} diff --git a/test/core/gprpp/orphanable_test.cc b/test/core/gprpp/orphanable_test.cc new file mode 100644 index 00000000..9fd4af4a --- /dev/null +++ b/test/core/gprpp/orphanable_test.cc @@ -0,0 +1,115 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gprpp/orphanable.h" + +#include + +#include "src/core/lib/gprpp/memory.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +class Foo : public Orphanable { + public: + Foo() : Foo(0) {} + explicit Foo(int value) : value_(value) {} + void Orphan() override { delete this; } + int value() const { return value_; } + + private: + int value_; +}; + +TEST(Orphanable, Basic) { + Foo* foo = new Foo(); + foo->Orphan(); +} + +TEST(OrphanablePtr, Basic) { + OrphanablePtr foo(new Foo()); + EXPECT_EQ(0, foo->value()); +} + +TEST(MakeOrphanable, DefaultConstructor) { + auto foo = MakeOrphanable(); + EXPECT_EQ(0, foo->value()); +} + +TEST(MakeOrphanable, WithParameters) { + auto foo = MakeOrphanable(5); + EXPECT_EQ(5, foo->value()); +} + +class Bar : public InternallyRefCounted { + public: + Bar() : Bar(0) {} + explicit Bar(int value) : value_(value) {} + void Orphan() override { Unref(); } + int value() const { return value_; } + + void StartWork() { self_ref_ = Ref(); } + void FinishWork() { self_ref_.reset(); } + + private: + int value_; + RefCountedPtr self_ref_; +}; + +TEST(OrphanablePtr, InternallyRefCounted) { + auto bar = MakeOrphanable(); + bar->StartWork(); + bar->FinishWork(); +} + +class Baz : public InternallyRefCounted { + public: + Baz() : Baz(0) {} + explicit Baz(int value) : InternallyRefCounted("Baz"), value_(value) {} + void Orphan() override { Unref(); } + int value() const { return value_; } + + void StartWork() { self_ref_ = Ref(DEBUG_LOCATION, "work"); } + void FinishWork() { + // This is a little ugly, but it makes the logged ref and unref match up. 
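+ // release() hands back the raw pointer without dropping a ref, so the
+ // explicit Unref(DEBUG_LOCATION, "work") below is what balances the
+ // Ref(DEBUG_LOCATION, "work") taken in StartWork(), keeping the traced
+ // ref/unref labels paired.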
+ self_ref_.release(); + Unref(DEBUG_LOCATION, "work"); + } + + private: + int value_; + RefCountedPtr self_ref_; +}; + +TEST(OrphanablePtr, InternallyRefCountedWithTracing) { + auto baz = MakeOrphanable(); + baz->StartWork(); + baz->FinishWork(); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/overload_test.cc b/test/core/gprpp/overload_test.cc new file mode 100644 index 00000000..2bbc3318 --- /dev/null +++ b/test/core/gprpp/overload_test.cc @@ -0,0 +1,37 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/gprpp/overload.h" + +#include + +namespace grpc_core { +namespace testing { + +TEST(Overload, Test) { + auto a = [](int x) { return x; }; + auto b = [](std::string x) -> int { return x.length(); }; + auto overload = Overload(a, b); + EXPECT_EQ(overload(1), 1); + EXPECT_EQ(overload("1"), 1); + EXPECT_EQ(overload("abc"), 3); +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/ref_counted_ptr_test.cc b/test/core/gprpp/ref_counted_ptr_test.cc new file mode 100644 index 00000000..bf202970 --- /dev/null +++ b/test/core/gprpp/ref_counted_ptr_test.cc @@ -0,0 +1,523 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/gprpp/ref_counted_ptr.h" + +#include + +#include + +#include "src/core/lib/gprpp/dual_ref_counted.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +// +// RefCountedPtr<> tests +// + +class Foo : public RefCounted { + public: + Foo() : value_(0) {} + + explicit Foo(int value) : value_(value) {} + + int value() const { return value_; } + + private: + int value_; +}; + +TEST(RefCountedPtr, DefaultConstructor) { RefCountedPtr foo; } + +TEST(RefCountedPtr, ExplicitConstructorEmpty) { + RefCountedPtr foo(nullptr); +} + +TEST(RefCountedPtr, ExplicitConstructor) { RefCountedPtr foo(new Foo()); } + +TEST(RefCountedPtr, MoveConstructor) { + RefCountedPtr foo(new Foo()); + RefCountedPtr foo2(std::move(foo)); + // NOLINTNEXTLINE(bugprone-use-after-move) + EXPECT_EQ(nullptr, foo.get()); + EXPECT_NE(nullptr, foo2.get()); +} + +TEST(RefCountedPtr, MoveAssignment) { + RefCountedPtr foo(new Foo()); + RefCountedPtr foo2 = std::move(foo); + // NOLINTNEXTLINE(bugprone-use-after-move) + EXPECT_EQ(nullptr, foo.get()); + EXPECT_NE(nullptr, foo2.get()); +} + +TEST(RefCountedPtr, CopyConstructor) { + RefCountedPtr foo(new Foo()); + RefCountedPtr foo2(foo); + EXPECT_NE(nullptr, foo.get()); + EXPECT_EQ(foo.get(), foo2.get()); +} + +TEST(RefCountedPtr, CopyAssignment) { + RefCountedPtr foo(new Foo()); + RefCountedPtr foo2 = foo; + EXPECT_NE(nullptr, foo.get()); + EXPECT_EQ(foo.get(), foo2.get()); +} + +TEST(RefCountedPtr, CopyAssignmentWhenEmpty) { + RefCountedPtr foo; + RefCountedPtr foo2; + foo2 = foo; + EXPECT_EQ(nullptr, foo.get()); + EXPECT_EQ(nullptr, foo2.get()); +} + +TEST(RefCountedPtr, CopyAssignmentToSelf) { + RefCountedPtr foo(new Foo()); + foo = *&foo; // The "*&" avoids warnings from LLVM -Wself-assign. +} + +TEST(RefCountedPtr, EnclosedScope) { + RefCountedPtr foo(new Foo()); + { + RefCountedPtr foo2(std::move(foo)); + // NOLINTNEXTLINE(bugprone-use-after-move) + EXPECT_EQ(nullptr, foo.get()); + EXPECT_NE(nullptr, foo2.get()); + } + EXPECT_EQ(nullptr, foo.get()); +} + +TEST(RefCountedPtr, ResetFromNullToNonNull) { + RefCountedPtr foo; + EXPECT_EQ(nullptr, foo.get()); + foo.reset(new Foo()); + EXPECT_NE(nullptr, foo.get()); +} + +TEST(RefCountedPtr, ResetFromNonNullToNonNull) { + RefCountedPtr foo(new Foo()); + EXPECT_NE(nullptr, foo.get()); + Foo* original = foo.get(); + foo.reset(new Foo()); + EXPECT_NE(nullptr, foo.get()); + EXPECT_NE(original, foo.get()); +} + +TEST(RefCountedPtr, ResetFromNonNullToNull) { + RefCountedPtr foo(new Foo()); + EXPECT_NE(nullptr, foo.get()); + foo.reset(); + EXPECT_EQ(nullptr, foo.get()); +} + +TEST(RefCountedPtr, ResetFromNullToNull) { + RefCountedPtr foo; + EXPECT_EQ(nullptr, foo.get()); + foo.reset(); + EXPECT_EQ(nullptr, foo.get()); +} + +TEST(RefCountedPtr, DerefernceOperators) { + RefCountedPtr foo(new Foo()); + foo->value(); + Foo& foo_ref = *foo; + foo_ref.value(); +} + +TEST(RefCountedPtr, EqualityOperators) { + RefCountedPtr foo(new Foo()); + RefCountedPtr bar = foo; + RefCountedPtr empty; + // Test equality between RefCountedPtrs. + EXPECT_EQ(foo, bar); + EXPECT_NE(foo, empty); + // Test equality with bare pointers. 
+ EXPECT_EQ(foo, foo.get()); + EXPECT_EQ(empty, nullptr); + EXPECT_NE(foo, nullptr); +} + +TEST(RefCountedPtr, Swap) { + Foo* foo = new Foo(); + Foo* bar = new Foo(); + RefCountedPtr ptr1(foo); + RefCountedPtr ptr2(bar); + ptr1.swap(ptr2); + EXPECT_EQ(foo, ptr2.get()); + EXPECT_EQ(bar, ptr1.get()); + RefCountedPtr ptr3; + ptr3.swap(ptr2); + EXPECT_EQ(nullptr, ptr2.get()); + EXPECT_EQ(foo, ptr3.get()); +} + +TEST(MakeRefCounted, NoArgs) { + RefCountedPtr foo = MakeRefCounted(); + EXPECT_EQ(0, foo->value()); +} + +TEST(MakeRefCounted, Args) { + RefCountedPtr foo = MakeRefCounted(3); + EXPECT_EQ(3, foo->value()); +} + +class FooWithTracing : public RefCounted { + public: + FooWithTracing() : RefCounted("FooWithTracing") {} +}; + +TEST(RefCountedPtr, RefCountedWithTracing) { + RefCountedPtr foo(new FooWithTracing()); + RefCountedPtr foo2 = foo->Ref(DEBUG_LOCATION, "foo"); + foo2.release(); + foo->Unref(DEBUG_LOCATION, "foo"); +} + +class BaseClass : public RefCounted { + public: + BaseClass() {} +}; + +class Subclass : public BaseClass { + public: + Subclass() {} +}; + +TEST(RefCountedPtr, ConstructFromSubclass) { + RefCountedPtr p(new Subclass()); +} + +TEST(RefCountedPtr, CopyAssignFromSubclass) { + RefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + RefCountedPtr s = MakeRefCounted(); + b = s; + EXPECT_NE(nullptr, b.get()); +} + +TEST(RefCountedPtr, MoveAssignFromSubclass) { + RefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + RefCountedPtr s = MakeRefCounted(); + b = std::move(s); + EXPECT_NE(nullptr, b.get()); +} + +TEST(RefCountedPtr, ResetFromSubclass) { + RefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + b.reset(new Subclass()); + EXPECT_NE(nullptr, b.get()); +} + +TEST(RefCountedPtr, EqualityWithSubclass) { + Subclass* s = new Subclass(); + RefCountedPtr b(s); + EXPECT_EQ(b, s); +} + +void FunctionTakingBaseClass(RefCountedPtr p) { + p.reset(); // To appease clang-tidy. +} + +TEST(RefCountedPtr, CanPassSubclassToFunctionExpectingBaseClass) { + RefCountedPtr p = MakeRefCounted(); + FunctionTakingBaseClass(p); +} + +void FunctionTakingSubclass(RefCountedPtr p) { + p.reset(); // To appease clang-tidy. 
+} + +TEST(RefCountedPtr, CanPassSubclassToFunctionExpectingSubclass) { + RefCountedPtr p = MakeRefCounted(); + FunctionTakingSubclass(p); +} + +// +// WeakRefCountedPtr<> tests +// + +class Bar : public DualRefCounted { + public: + Bar() : value_(0) {} + + explicit Bar(int value) : value_(value) {} + + ~Bar() override { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + int value() const { return value_; } + + private: + int value_; + bool shutting_down_ = false; +}; + +TEST(WeakRefCountedPtr, DefaultConstructor) { WeakRefCountedPtr bar; } + +TEST(WeakRefCountedPtr, ExplicitConstructorEmpty) { + WeakRefCountedPtr bar(nullptr); +} + +TEST(WeakRefCountedPtr, ExplicitConstructor) { + RefCountedPtr bar_strong(new Bar()); + bar_strong->WeakRef().release(); + WeakRefCountedPtr bar(bar_strong.get()); +} + +TEST(WeakRefCountedPtr, MoveConstructor) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2(std::move(bar)); + EXPECT_EQ(nullptr, bar.get()); // NOLINT + EXPECT_NE(nullptr, bar2.get()); +} + +TEST(WeakRefCountedPtr, MoveAssignment) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = std::move(bar); + EXPECT_EQ(nullptr, bar.get()); // NOLINT + EXPECT_NE(nullptr, bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyConstructor) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2(bar); + EXPECT_NE(nullptr, bar.get()); + EXPECT_EQ(bar.get(), bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyAssignment) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar; + EXPECT_NE(nullptr, bar.get()); + EXPECT_EQ(bar.get(), bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyAssignmentWhenEmpty) { + WeakRefCountedPtr bar; + WeakRefCountedPtr bar2; + bar2 = bar; + EXPECT_EQ(nullptr, bar.get()); + EXPECT_EQ(nullptr, bar2.get()); +} + +TEST(WeakRefCountedPtr, CopyAssignmentToSelf) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + bar = *&bar; // The "*&" avoids warnings from LLVM -Wself-assign. 
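+  // Self-assignment must be a no-op; in particular it must not drop the only
+  // weak ref before re-taking it.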
+} + +TEST(WeakRefCountedPtr, EnclosedScope) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + { + WeakRefCountedPtr bar2(std::move(bar)); + // NOLINTNEXTLINE(bugprone-use-after-move) + EXPECT_EQ(nullptr, bar.get()); + EXPECT_NE(nullptr, bar2.get()); + } + EXPECT_EQ(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNullToNonNull) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar; + EXPECT_EQ(nullptr, bar.get()); + bar_strong->WeakRef().release(); + bar.reset(bar_strong.get()); + EXPECT_NE(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNonNullToNonNull) { + RefCountedPtr bar_strong(new Bar()); + RefCountedPtr bar2_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + EXPECT_NE(nullptr, bar.get()); + bar2_strong->WeakRef().release(); + bar.reset(bar2_strong.get()); + EXPECT_NE(nullptr, bar.get()); + EXPECT_NE(bar_strong.get(), bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNonNullToNull) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + EXPECT_NE(nullptr, bar.get()); + bar.reset(); + EXPECT_EQ(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, ResetFromNullToNull) { + WeakRefCountedPtr bar; + EXPECT_EQ(nullptr, bar.get()); + bar.reset(); + EXPECT_EQ(nullptr, bar.get()); +} + +TEST(WeakRefCountedPtr, DerefernceOperators) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + bar->value(); + Bar& bar_ref = *bar; + bar_ref.value(); +} + +TEST(WeakRefCountedPtr, EqualityOperators) { + RefCountedPtr bar_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar; + WeakRefCountedPtr empty; + // Test equality between RefCountedPtrs. + EXPECT_EQ(bar, bar2); + EXPECT_NE(bar, empty); + // Test equality with bare pointers. 
+ EXPECT_EQ(bar, bar.get()); + EXPECT_EQ(empty, nullptr); + EXPECT_NE(bar, nullptr); +} + +TEST(WeakRefCountedPtr, Swap) { + RefCountedPtr bar_strong(new Bar()); + RefCountedPtr bar2_strong(new Bar()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar2_strong->WeakRef(); + bar.swap(bar2); + EXPECT_EQ(bar_strong.get(), bar2.get()); + EXPECT_EQ(bar2_strong.get(), bar.get()); + WeakRefCountedPtr bar3; + bar3.swap(bar2); + EXPECT_EQ(nullptr, bar2.get()); + EXPECT_EQ(bar_strong.get(), bar3.get()); +} + +class BarWithTracing : public DualRefCounted { + public: + BarWithTracing() : DualRefCounted("BarWithTracing") {} + + ~BarWithTracing() override { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +TEST(WeakRefCountedPtr, RefCountedWithTracing) { + RefCountedPtr bar_strong(new BarWithTracing()); + WeakRefCountedPtr bar = bar_strong->WeakRef(); + WeakRefCountedPtr bar2 = bar->WeakRef(DEBUG_LOCATION, "bar"); + bar2.release(); + bar->WeakUnref(DEBUG_LOCATION, "bar"); +} + +class WeakBaseClass : public DualRefCounted { + public: + WeakBaseClass() {} + + ~WeakBaseClass() override { GPR_ASSERT(shutting_down_); } + + void Orphan() override { shutting_down_ = true; } + + private: + bool shutting_down_ = false; +}; + +class WeakSubclass : public WeakBaseClass { + public: + WeakSubclass() {} +}; + +TEST(WeakRefCountedPtr, ConstructFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr p(strong->WeakRef().release()); +} + +TEST(WeakRefCountedPtr, CopyAssignFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + WeakRefCountedPtr s = strong->WeakRef(); + b = s; + EXPECT_NE(nullptr, b.get()); +} + +TEST(WeakRefCountedPtr, MoveAssignFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + WeakRefCountedPtr s = strong->WeakRef(); + b = std::move(s); + EXPECT_NE(nullptr, b.get()); +} + +TEST(WeakRefCountedPtr, ResetFromWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b; + EXPECT_EQ(nullptr, b.get()); + b.reset(strong->WeakRef().release()); + EXPECT_NE(nullptr, b.get()); +} + +TEST(WeakRefCountedPtr, EqualityWithWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr b = strong->WeakRef(); + EXPECT_EQ(b, strong.get()); +} + +void FunctionTakingWeakBaseClass(WeakRefCountedPtr p) { + p.reset(); // To appease clang-tidy. +} + +TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakBaseClass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr p = strong->WeakRef(); + FunctionTakingWeakBaseClass(p); +} + +void FunctionTakingWeakSubclass(WeakRefCountedPtr p) { + p.reset(); // To appease clang-tidy. +} + +TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakSubclass) { + RefCountedPtr strong(new WeakSubclass()); + WeakRefCountedPtr p = strong->WeakRef(); + FunctionTakingWeakSubclass(p); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/ref_counted_test.cc b/test/core/gprpp/ref_counted_test.cc new file mode 100644 index 00000000..6e95aa28 --- /dev/null +++ b/test/core/gprpp/ref_counted_test.cc @@ -0,0 +1,193 @@ +/* + * + * Copyright 2017 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/gprpp/ref_counted.h" + +#include + +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +class Foo : public RefCounted { + public: + Foo() { + static_assert(std::has_virtual_destructor::value, + "PolymorphicRefCount doesn't have a virtual dtor"); + } +}; + +TEST(RefCounted, Basic) { + Foo* foo = new Foo(); + foo->Unref(); +} + +TEST(RefCounted, ExtraRef) { + Foo* foo = new Foo(); + RefCountedPtr foop = foo->Ref(); + foop.release(); + foo->Unref(); + foo->Unref(); +} + +class Value : public RefCounted { + public: + Value(int value, std::set>* registry) : value_(value) { + registry->emplace(this); + } + + int value() const { return value_; } + + private: + int value_; +}; + +void GarbageCollectRegistry(std::set>* registry) { + for (auto it = registry->begin(); it != registry->end();) { + RefCountedPtr v = (*it)->RefIfNonZero(); + // Check if the object has any refs remaining. + if (v != nullptr) { + // It has refs remaining, so we do not delete it. + ++it; + } else { + // No refs remaining, so remove it from the registry. + it = registry->erase(it); + } + } +} + +TEST(RefCounted, NoDeleteUponUnref) { + std::set> registry; + // Add two objects to the registry. + auto v1 = MakeRefCounted(1, ®istry); + auto v2 = MakeRefCounted(2, ®istry); + EXPECT_THAT(registry, + ::testing::UnorderedElementsAre( + ::testing::Pointee(::testing::Property(&Value::value, 1)), + ::testing::Pointee(::testing::Property(&Value::value, 2)))); + // Running garbage collection should not delete anything, since both + // entries still have refs. + GarbageCollectRegistry(®istry); + EXPECT_THAT(registry, + ::testing::UnorderedElementsAre( + ::testing::Pointee(::testing::Property(&Value::value, 1)), + ::testing::Pointee(::testing::Property(&Value::value, 2)))); + // Unref v2 and run GC to remove it. + v2.reset(); + GarbageCollectRegistry(®istry); + EXPECT_THAT(registry, ::testing::UnorderedElementsAre(::testing::Pointee( + ::testing::Property(&Value::value, 1)))); + // Now unref v1 and run GC again. 
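+  // RefIfNonZero() only takes a new ref while the strong count is still
+  // non-zero and returns a null pointer otherwise, so this final sweep sees
+  // v1's entry as dead and erases it, leaving the registry empty.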
+ v1.reset(); + GarbageCollectRegistry(®istry); + EXPECT_THAT(registry, ::testing::UnorderedElementsAre()); +} + +class ValueInExternalAllocation + : public RefCounted { + public: + explicit ValueInExternalAllocation(int value) : value_(value) {} + + int value() const { return value_; } + + private: + int value_; +}; + +TEST(RefCounted, CallDtorUponUnref) { + std::aligned_storage::type storage; + RefCountedPtr value( + new (&storage) ValueInExternalAllocation(5)); + EXPECT_EQ(value->value(), 5); +} + +class FooNonPolymorphic + : public RefCounted { + public: + FooNonPolymorphic() { + static_assert(!std::has_virtual_destructor::value, + "NonPolymorphicRefCount has a virtual dtor"); + } +}; + +TEST(RefCountedNonPolymorphic, Basic) { + FooNonPolymorphic* foo = new FooNonPolymorphic(); + foo->Unref(); +} + +TEST(RefCountedNonPolymorphic, ExtraRef) { + FooNonPolymorphic* foo = new FooNonPolymorphic(); + RefCountedPtr foop = foo->Ref(); + foop.release(); + foo->Unref(); + foo->Unref(); +} + +class FooWithTracing : public RefCounted { + public: + FooWithTracing() : RefCounted("Foo") {} +}; + +TEST(RefCountedWithTracing, Basic) { + FooWithTracing* foo = new FooWithTracing(); + RefCountedPtr foop = foo->Ref(DEBUG_LOCATION, "extra_ref"); + foop.release(); + foo->Unref(DEBUG_LOCATION, "extra_ref"); + // Can use the no-argument methods, too. + foop = foo->Ref(); + foop.release(); + foo->Unref(); + foo->Unref(DEBUG_LOCATION, "original_ref"); +} + +class FooNonPolymorphicWithTracing + : public RefCounted { + public: + FooNonPolymorphicWithTracing() : RefCounted("FooNonPolymorphicWithTracing") {} +}; + +TEST(RefCountedNonPolymorphicWithTracing, Basic) { + FooNonPolymorphicWithTracing* foo = new FooNonPolymorphicWithTracing(); + RefCountedPtr foop = + foo->Ref(DEBUG_LOCATION, "extra_ref"); + foop.release(); + foo->Unref(DEBUG_LOCATION, "extra_ref"); + // Can use the no-argument methods, too. + foop = foo->Ref(); + foop.release(); + foo->Unref(); + foo->Unref(DEBUG_LOCATION, "original_ref"); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/stat_test.cc b/test/core/gprpp/stat_test.cc new file mode 100644 index 00000000..7f3a0073 --- /dev/null +++ b/test/core/gprpp/stat_test.cc @@ -0,0 +1,75 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "src/core/lib/gprpp/stat.h" + +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +TEST(STAT, GetTimestampOnTmpFile) { + // Create a temporary empty file. 
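+  // gpr_tmpfile() returns an open FILE* and stores the file's path in
+  // tmp_name as a newly allocated string; the caller owns both, hence the
+  // fclose()/remove()/gpr_free() below.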
+ FILE* tmp = nullptr; + char* tmp_name; + tmp = gpr_tmpfile("prefix", &tmp_name); + ASSERT_NE(tmp_name, nullptr); + ASSERT_NE(tmp, nullptr); + fclose(tmp); + // Check the last modified date is correctly set. + time_t timestamp = 0; + absl::Status status = + grpc_core::GetFileModificationTime(tmp_name, ×tamp); + EXPECT_EQ(status.code(), absl::StatusCode::kOk); + EXPECT_GT(timestamp, 0); + // Clean up. + remove(tmp_name); + gpr_free(tmp_name); +} + +TEST(STAT, GetTimestampOnFailure) { + time_t timestamp = 0; + absl::Status status = + grpc_core::GetFileModificationTime("/DOES_NOT_EXIST", ×tamp); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + // Check the last modified date is not set. + EXPECT_EQ(timestamp, 0); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/status_helper_test.cc b/test/core/gprpp/status_helper_test.cc new file mode 100644 index 00000000..9e153867 --- /dev/null +++ b/test/core/gprpp/status_helper_test.cc @@ -0,0 +1,180 @@ +// +// Copyright 2021 the gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include "src/core/lib/gprpp/status_helper.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/clock.h" +#include "google/rpc/status.upb.h" +#include "upb/upb.hpp" + +namespace grpc_core { +namespace { + +TEST(StatusUtilTest, CreateStatus) { + absl::Status s = + StatusCreate(absl::StatusCode::kUnknown, "Test", DEBUG_LOCATION, + {absl::OkStatus(), absl::CancelledError()}); + EXPECT_EQ(absl::StatusCode::kUnknown, s.code()); + EXPECT_EQ("Test", s.message()); +#ifndef NDEBUG + EXPECT_EQ(true, StatusGetStr(s, StatusStrProperty::kFile).has_value()); + EXPECT_EQ(true, StatusGetInt(s, StatusIntProperty::kFileLine).has_value()); +#endif + EXPECT_EQ(true, StatusGetTime(s, StatusTimeProperty::kCreated).has_value()); + EXPECT_THAT(StatusGetChildren(s), + ::testing::ElementsAre(absl::CancelledError())); +} + +TEST(StatusUtilTest, SetAndGetInt) { + absl::Status s = absl::CancelledError(); + StatusSetInt(&s, StatusIntProperty::kErrorNo, 2021); + EXPECT_EQ(2021, StatusGetInt(s, StatusIntProperty::kErrorNo)); +} + +TEST(StatusUtilTest, GetIntNotExistent) { + absl::Status s = absl::CancelledError(); + EXPECT_EQ(absl::optional(), + StatusGetInt(s, StatusIntProperty::kErrorNo)); +} + +TEST(StatusUtilTest, SetAndGetStr) { + absl::Status s = absl::CancelledError(); + StatusSetStr(&s, StatusStrProperty::kOsError, "value"); + EXPECT_EQ("value", StatusGetStr(s, StatusStrProperty::kOsError)); +} + +TEST(StatusUtilTest, GetStrNotExistent) { + absl::Status s = absl::CancelledError(); + EXPECT_EQ(absl::optional(), + StatusGetStr(s, StatusStrProperty::kOsError)); +} + +TEST(StatusUtilTest, SetAndGetTime) { + absl::Status s = absl::CancelledError(); + absl::Time t = absl::Now(); + StatusSetTime(&s, StatusTimeProperty::kCreated, t); + EXPECT_EQ(t, StatusGetTime(s, StatusTimeProperty::kCreated)); +} + +TEST(StatusUtilTest, GetTimeNotExistent) { + absl::Status s = absl::CancelledError(); + EXPECT_EQ(absl::optional(), + StatusGetTime(s, StatusTimeProperty::kCreated)); +} + +TEST(StatusUtilTest, AddAndGetChildren) { + absl::Status s = absl::CancelledError(); + absl::Status child1 = absl::AbortedError("Message1"); + absl::Status child2 = absl::DeadlineExceededError("Message2"); + StatusAddChild(&s, child1); + StatusAddChild(&s, child2); + EXPECT_THAT(StatusGetChildren(s), ::testing::ElementsAre(child1, child2)); +} + +TEST(StatusUtilTest, ToAndFromProto) { + absl::Status s = absl::CancelledError("Message"); + StatusSetInt(&s, StatusIntProperty::kErrorNo, 2021); + StatusSetStr(&s, StatusStrProperty::kOsError, "value"); + upb::Arena arena; + google_rpc_Status* msg = internal::StatusToProto(s, arena.ptr()); + absl::Status s2 = internal::StatusFromProto(msg); + EXPECT_EQ(s, s2); +} + +TEST(StatusUtilTest, OkToString) { + absl::Status s = absl::OkStatus(); + std::string t = StatusToString(s); + EXPECT_EQ("OK", t); +} + +TEST(StatusUtilTest, CancelledErrorToString) { + absl::Status s = absl::CancelledError(); + std::string t = StatusToString(s); + EXPECT_EQ("CANCELLED", t); +} + +TEST(StatusUtilTest, ErrorWithIntPropertyToString) { + absl::Status s = absl::CancelledError("Message"); + StatusSetInt(&s, StatusIntProperty::kErrorNo, 2021); + std::string t = StatusToString(s); + EXPECT_EQ("CANCELLED:Message {errno:2021}", t); +} + +TEST(StatusUtilTest, ErrorWithStrPropertyToString) { + absl::Status s = absl::CancelledError("Message"); + StatusSetStr(&s, StatusStrProperty::kDescription, "Hey"); + std::string t = StatusToString(s); + EXPECT_EQ("CANCELLED:Message 
{description:\"Hey\"}", t); +} + +TEST(StatusUtilTest, ErrorWithTimePropertyToString) { + absl::Status s = absl::CancelledError("Message"); + absl::Time t = absl::FromCivil(absl::CivilSecond(2021, 4, 29, 8, 56, 30), + absl::LocalTimeZone()); + StatusSetTime(&s, StatusTimeProperty::kCreated, t); + EXPECT_EQ(StatusToString(s), + absl::StrCat("CANCELLED:Message {created_time:\"", + absl::FormatTime(t), "\"}")); +} + +TEST(StatusUtilTest, ComplexErrorWithChildrenToString) { + absl::Status s = absl::CancelledError("Message"); + StatusSetInt(&s, StatusIntProperty::kErrorNo, 2021); + absl::Status s1 = absl::AbortedError("Message1"); + StatusAddChild(&s, s1); + absl::Status s2 = absl::AlreadyExistsError("Message2"); + StatusSetStr(&s2, StatusStrProperty::kOsError, "value"); + StatusAddChild(&s, s2); + std::string t = StatusToString(s); + EXPECT_EQ( + "CANCELLED:Message {errno:2021, children:[" + "ABORTED:Message1, ALREADY_EXISTS:Message2 {os_error:\"value\"}]}", + t); +} + +TEST(StatusUtilTest, AllocPtr) { + absl::Status statuses[] = {absl::OkStatus(), absl::CancelledError(), + absl::AbortedError("Message")}; + for (const auto& s : statuses) { + uintptr_t p = internal::StatusAllocPtr(s); + EXPECT_EQ(s, internal::StatusGetFromPtr(p)); + internal::StatusFreePtr(p); + } +} + +TEST(StatusUtilTest, AllocHeapPtr) { + absl::Status statuses[] = {absl::OkStatus(), absl::CancelledError(), + absl::AbortedError("Message")}; + for (const auto& s : statuses) { + uintptr_t p = internal::StatusAllocHeapPtr(s); + EXPECT_EQ(s, internal::StatusGetFromHeapPtr(p)); + internal::StatusFreeHeapPtr(p); + } +} + +} // namespace +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/gprpp/table_test.cc b/test/core/gprpp/table_test.cc new file mode 100644 index 00000000..a355a9cc --- /dev/null +++ b/test/core/gprpp/table_test.cc @@ -0,0 +1,168 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
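+
+// Table<Ts...> keeps at most one value per element type (equivalently, per
+// index), accessible via get()/set()/get_or_create() either by type or by
+// index; the tests below exercise both access forms plus copy, move, and
+// ForEach.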
+ +#include "src/core/lib/gprpp/table.h" + +#include +#include + +#include + +#include "absl/types/optional.h" + +namespace grpc_core { +namespace testing { + +TEST(Table, InstantiateEmpty) { Table<>(); } + +TEST(Table, NoOp) { + Table t; + EXPECT_EQ(t.get(), nullptr); + EXPECT_EQ(t.get(), nullptr); + EXPECT_EQ(t.get(), nullptr); + EXPECT_EQ(t.get<0>(), nullptr); + EXPECT_EQ(t.get<1>(), nullptr); + EXPECT_EQ(t.get<2>(), nullptr); +} + +TEST(Table, SetTheThings) { + Table t; + t.set(3); + t.set(2.9); + t.set("Hello world!"); + EXPECT_EQ(*t.get(), 3); + EXPECT_EQ(*t.get(), 2.9); + EXPECT_EQ(*t.get(), "Hello world!"); + EXPECT_EQ(*t.get<0>(), 3); + EXPECT_EQ(*t.get<1>(), 2.9); + EXPECT_EQ(*t.get<2>(), "Hello world!"); +} + +TEST(Table, GetDefault) { + Table t; + EXPECT_EQ(*t.get_or_create(), ""); + EXPECT_EQ(*t.get_or_create(), 0.0); + EXPECT_EQ(*t.get_or_create(), 0); +} + +TEST(Table, GetDefaultIndexed) { + Table t; + EXPECT_EQ(*t.get_or_create<2>(), ""); + EXPECT_EQ(*t.get_or_create<1>(), 0.0); + EXPECT_EQ(*t.get_or_create<0>(), 0); +} + +TEST(Table, Copy) { + Table t; + t.set("abcdefghijklmnopqrstuvwxyz"); + EXPECT_EQ(*t.get(), "abcdefghijklmnopqrstuvwxyz"); + EXPECT_EQ(t.get(), nullptr); + Table u(t); + EXPECT_EQ(*u.get(), "abcdefghijklmnopqrstuvwxyz"); + EXPECT_EQ(*t.get(), "abcdefghijklmnopqrstuvwxyz"); + EXPECT_EQ(t.get(), nullptr); + EXPECT_EQ(u.get(), nullptr); + u.set("hello"); + EXPECT_EQ(*u.get<1>(), "hello"); + EXPECT_EQ(*t.get<1>(), "abcdefghijklmnopqrstuvwxyz"); + t = u; + EXPECT_EQ(*u.get(), "hello"); + EXPECT_EQ(*t.get(), "hello"); +} + +TEST(Table, Move) { + Table t; + t.set("abcdefghijklmnopqrstuvwxyz"); + EXPECT_EQ(*t.get(), "abcdefghijklmnopqrstuvwxyz"); + EXPECT_EQ(t.get(), nullptr); + Table u(std::move(t)); + EXPECT_NE(t.get(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(*u.get(), "abcdefghijklmnopqrstuvwxyz"); + EXPECT_EQ(t.get(), nullptr); + EXPECT_EQ(u.get(), nullptr); + u.set("hello"); + EXPECT_EQ(*u.get<1>(), "hello"); + t = std::move(u); + EXPECT_NE(u.get(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(*t.get(), "hello"); +} + +TEST(Table, SameTypes) { + Table t; + // The following lines should not compile: + // t.get(); + // t.has<4>(); + // t.get<4>(); + // t.clear<4>(); + EXPECT_EQ(t.get<0>(), nullptr); + EXPECT_EQ(t.get<1>(), nullptr); + EXPECT_EQ(t.get<2>(), nullptr); + t.set<1>("Hello!"); + EXPECT_EQ(t.get<0>(), nullptr); + EXPECT_EQ(*t.get<1>(), "Hello!"); + EXPECT_EQ(t.get<2>(), nullptr); +} + +TEST(Table, ForEach) { + Table t; + t.set<0>(1); + t.set<1>(2); + t.set<2>(3); + int i = 1; + t.ForEach([&i](int x) { + EXPECT_EQ(x, i); + i++; + }); +} + +#if !defined(_MSC_VER) +// Test suite proving this is memory efficient compared to +// tuple...> +// TODO(ctiller): determine why this test doesn't compile under MSVC. +// For now whether it passes or not in that one environment is probably +// immaterial. 
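+// Rough intuition: each absl::optional<T> member carries its own engaged
+// flag (typically padded out), whereas Table packs the presence state for
+// all of its slots together, so the Table should never come out larger.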
+ +template +struct TableSizeTest : public ::testing::Test {}; + +using SizeTests = ::testing::Types< + std::tuple, std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TableSizeTest, SizeTests); + +template +int sizeof_tuple_of_optionals(std::tuple*) { + return sizeof(std::tuple...>); +} + +template +int sizeof_table(std::tuple*) { + return sizeof(Table); +} + +TYPED_TEST(TableSizeTest, SmallerThanTupleOfOptionals) { + EXPECT_GE(sizeof_tuple_of_optionals(static_cast(nullptr)), + sizeof_table(static_cast(nullptr))); +} +#endif + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/gprpp/thd_test.cc b/test/core/gprpp/thd_test.cc new file mode 100644 index 00000000..5cf5e443 --- /dev/null +++ b/test/core/gprpp/thd_test.cc @@ -0,0 +1,101 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test of gpr thread support. */ + +#include "src/core/lib/gprpp/thd.h" + +#include +#include + +#include +#include +#include + +#include "test/core/util/test_config.h" + +#define NUM_THREADS 100 + +struct test { + gpr_mu mu; + int n; + int is_done; + gpr_cv done_cv; +}; + +/* A Thread body. Decrement t->n, and if is becomes zero, set t->done. */ +static void thd_body1(void* v) { + struct test* t = static_cast(v); + gpr_mu_lock(&t->mu); + t->n--; + if (t->n == 0) { + t->is_done = 1; + gpr_cv_signal(&t->done_cv); + } + gpr_mu_unlock(&t->mu); +} + +/* Test that we can create a number of threads, wait for them, and join them. */ +static void test1(void) { + grpc_core::Thread thds[NUM_THREADS]; + struct test t; + gpr_mu_init(&t.mu); + gpr_cv_init(&t.done_cv); + t.n = NUM_THREADS; + t.is_done = 0; + for (auto& th : thds) { + th = grpc_core::Thread("grpc_thread_body1_test", &thd_body1, &t); + th.Start(); + } + gpr_mu_lock(&t.mu); + while (!t.is_done) { + gpr_cv_wait(&t.done_cv, &t.mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&t.mu); + for (auto& th : thds) { + th.Join(); + } + GPR_ASSERT(t.n == 0); + gpr_mu_destroy(&t.mu); + gpr_cv_destroy(&t.done_cv); +} + +static void thd_body2(void* /*v*/) {} + +/* Test that we can create a number of threads and join them. */ +static void test2(void) { + grpc_core::Thread thds[NUM_THREADS]; + for (auto& th : thds) { + bool ok; + th = grpc_core::Thread("grpc_thread_body2_test", &thd_body2, nullptr, &ok); + GPR_ASSERT(ok); + th.Start(); + } + for (auto& th : thds) { + th.Join(); + } +} + +/* ------------------------------------------------- */ + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + test1(); + test2(); + return 0; +} diff --git a/test/core/gprpp/time_util_test.cc b/test/core/gprpp/time_util_test.cc new file mode 100644 index 00000000..0fbbce5b --- /dev/null +++ b/test/core/gprpp/time_util_test.cc @@ -0,0 +1,134 @@ +// +// Copyright 2021 the gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "src/core/lib/gprpp/time_util.h" + +#include +#include + +#include + +#include "absl/time/time.h" + +#include + +TEST(TimeUtilTest, ToGprTimeSpecFromAbslDurationWithRegularValues) { + std::vector times = {-10, -1, 0, 1, 10}; + for (int t : times) { + EXPECT_EQ(0, gpr_time_cmp(gpr_time_from_nanos(t, GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(absl::Nanoseconds(t)))); + EXPECT_EQ(0, gpr_time_cmp(gpr_time_from_micros(t, GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(absl::Microseconds(t)))); + EXPECT_EQ(0, gpr_time_cmp(gpr_time_from_millis(t, GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(absl::Milliseconds(t)))); + EXPECT_EQ(0, gpr_time_cmp(gpr_time_from_seconds(t, GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(absl::Seconds(t)))); + EXPECT_EQ(0, gpr_time_cmp(gpr_time_from_minutes(t, GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(absl::Minutes(t)))); + EXPECT_EQ(0, gpr_time_cmp(gpr_time_from_hours(t, GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(absl::Hours(t)))); + } +} + +TEST(TimeUtilTest, ToGprTimeSpecFromAbslDurationWithInfinites) { + EXPECT_EQ(0, + gpr_time_cmp(gpr_inf_past(GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(-absl::InfiniteDuration()))); + EXPECT_EQ(0, gpr_time_cmp(gpr_time_0(GPR_TIMESPAN), + grpc_core::ToGprTimeSpec(absl::ZeroDuration()))); +} + +TEST(TimeUtilTest, ToGprTimeSpecFromAbslTimeWithRegularValues) { + std::vector times = {0, 10, 100000000}; + for (int t : times) { + EXPECT_EQ(0, + gpr_time_cmp(gpr_time_from_nanos(t, GPR_CLOCK_REALTIME), + grpc_core::ToGprTimeSpec(absl::FromUnixNanos(t)))); + EXPECT_EQ(0, + gpr_time_cmp(gpr_time_from_micros(t, GPR_CLOCK_REALTIME), + grpc_core::ToGprTimeSpec(absl::FromUnixMicros(t)))); + EXPECT_EQ(0, + gpr_time_cmp(gpr_time_from_millis(t, GPR_CLOCK_REALTIME), + grpc_core::ToGprTimeSpec(absl::FromUnixMillis(t)))); + EXPECT_EQ(0, + gpr_time_cmp(gpr_time_from_seconds(t, GPR_CLOCK_REALTIME), + grpc_core::ToGprTimeSpec(absl::FromUnixSeconds(t)))); + } +} + +TEST(TimeUtilTest, ToGprTimeSpecFromAbslTimeWithInfinites) { + EXPECT_EQ(0, gpr_time_cmp(gpr_inf_future(GPR_CLOCK_REALTIME), + grpc_core::ToGprTimeSpec(absl::InfiniteFuture()))); + EXPECT_EQ(0, gpr_time_cmp(gpr_inf_past(GPR_CLOCK_REALTIME), + grpc_core::ToGprTimeSpec(absl::InfinitePast()))); +} + +TEST(TimeUtilTest, ToAbslDurationWithRegularValues) { + std::vector times = {-10, -1, 0, 1, 10}; + for (int t : times) { + EXPECT_EQ(absl::Nanoseconds(t), + grpc_core::ToAbslDuration(gpr_time_from_nanos(t, GPR_TIMESPAN))); + EXPECT_EQ(absl::Microseconds(t), + grpc_core::ToAbslDuration(gpr_time_from_micros(t, GPR_TIMESPAN))); + EXPECT_EQ(absl::Milliseconds(t), + grpc_core::ToAbslDuration(gpr_time_from_millis(t, GPR_TIMESPAN))); + EXPECT_EQ(absl::Seconds(t), grpc_core::ToAbslDuration( + gpr_time_from_seconds(t, GPR_TIMESPAN))); + EXPECT_EQ(absl::Minutes(t), grpc_core::ToAbslDuration( + gpr_time_from_minutes(t, GPR_TIMESPAN))); + EXPECT_EQ(absl::Hours(t), + grpc_core::ToAbslDuration(gpr_time_from_hours(t, GPR_TIMESPAN))); + } +} + +TEST(TimeUtilTest, 
ToAbslDurationWithInfinites) { + EXPECT_EQ(absl::InfiniteDuration(), + grpc_core::ToAbslDuration(gpr_inf_future(GPR_TIMESPAN))); + EXPECT_EQ(-absl::InfiniteDuration(), + grpc_core::ToAbslDuration(gpr_inf_past(GPR_TIMESPAN))); +} + +TEST(TimeUtilTest, ToAbslTimeWithRegularValues) { + std::vector times = {0, 10, 100000000}; + for (int t : times) { + EXPECT_EQ(absl::FromUnixNanos(t), grpc_core::ToAbslTime(gpr_time_from_nanos( + t, GPR_CLOCK_REALTIME))); + EXPECT_EQ( + absl::FromUnixMicros(t), + grpc_core::ToAbslTime(gpr_time_from_micros(t, GPR_CLOCK_REALTIME))); + EXPECT_EQ( + absl::FromUnixMillis(t), + grpc_core::ToAbslTime(gpr_time_from_millis(t, GPR_CLOCK_REALTIME))); + EXPECT_EQ( + absl::FromUnixSeconds(t), + grpc_core::ToAbslTime(gpr_time_from_seconds(t, GPR_CLOCK_REALTIME))); + } +} + +TEST(TimeUtilTest, ToAbslTimeWithInfinites) { + EXPECT_EQ(absl::InfiniteFuture(), + grpc_core::ToAbslTime(gpr_inf_future(GPR_CLOCK_REALTIME))); + EXPECT_EQ(absl::InfinitePast(), + grpc_core::ToAbslTime(gpr_inf_past(GPR_CLOCK_REALTIME))); + EXPECT_EQ(absl::UnixEpoch(), + grpc_core::ToAbslTime(gpr_time_0(GPR_CLOCK_REALTIME))); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/handshake/client_ssl.cc b/test/core/handshake/client_ssl.cc new file mode 100644 index 00000000..a18e971e --- /dev/null +++ b/test/core/handshake/client_ssl.cc @@ -0,0 +1,407 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_TCP + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define SSL_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SSL_KEY_PATH "src/core/tsi/test_creds/server1.key" +#define SSL_CA_PATH "src/core/tsi/test_creds/ca.pem" + +grpc_core::TraceFlag client_ssl_tsi_tracing_enabled(false, "tsi"); + +class SslLibraryInfo { + public: + SslLibraryInfo() {} + + void Notify() { + grpc_core::MutexLock lock(&mu_); + ready_ = true; + cv_.Signal(); + } + + void Await() { + grpc_core::MutexLock lock(&mu_); + while (!ready_) { + cv_.Wait(&mu_); + } + } + + private: + grpc_core::Mutex mu_; + grpc_core::CondVar cv_; + bool ready_ ABSL_GUARDED_BY(mu_) = false; +}; + +// Arguments for TLS server thread. +typedef struct { + int socket; + char* alpn_preferred; + SslLibraryInfo* ssl_library_info; +} server_args; + +// Based on https://wiki.openssl.org/index.php/Simple_TLS_Server. +// Pick an arbitrary unused port and return it in *out_port. Return +// an fd>=0 on success. 
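+// Binding with sin_port == 0 asks the kernel to pick a free port;
+// getsockname() afterwards reports which port was actually chosen.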
+static int create_socket(int* out_port) { + int s; + struct sockaddr_in addr; + socklen_t addr_len; + *out_port = -1; + + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.s_addr = htonl(INADDR_ANY); + + s = socket(AF_INET, SOCK_STREAM, 0); + if (s < 0) { + perror("Unable to create socket"); + return -1; + } + + if (bind(s, reinterpret_cast(&addr), sizeof(addr)) < 0) { + perror("Unable to bind"); + gpr_log(GPR_ERROR, "%s", "Unable to bind to any port"); + close(s); + return -1; + } + + if (listen(s, 1) < 0) { + perror("Unable to listen"); + close(s); + return -1; + } + + addr_len = sizeof(addr); + if (getsockname(s, reinterpret_cast(&addr), &addr_len) != + 0 || + addr_len > sizeof(addr)) { + perror("getsockname"); + gpr_log(GPR_ERROR, "%s", "Unable to get socket local address"); + close(s); + return -1; + } + + *out_port = ntohs(addr.sin_port); + return s; +} + +// Server callback during ALPN negotiation. See man page for +// SSL_CTX_set_alpn_select_cb. +static int alpn_select_cb(SSL* /*ssl*/, const uint8_t** out, uint8_t* out_len, + const uint8_t* in, unsigned in_len, void* arg) { + const uint8_t* alpn_preferred = static_cast(arg); + + *out = alpn_preferred; + *out_len = static_cast( + strlen(reinterpret_cast(alpn_preferred))); + + // Validate that the ALPN list includes "h2" and "grpc-exp", that "grpc-exp" + // precedes "h2". + bool grpc_exp_seen = false; + bool h2_seen = false; + const char* inp = reinterpret_cast(in); + const char* in_end = inp + in_len; + while (inp < in_end) { + const size_t length = static_cast(*inp++); + if (length == strlen("grpc-exp") && strncmp(inp, "grpc-exp", length) == 0) { + grpc_exp_seen = true; + GPR_ASSERT(!h2_seen); + } + if (length == strlen("h2") && strncmp(inp, "h2", length) == 0) { + h2_seen = true; + GPR_ASSERT(grpc_exp_seen); + } + inp += length; + } + + GPR_ASSERT(inp == in_end); + GPR_ASSERT(grpc_exp_seen); + GPR_ASSERT(h2_seen); + + return SSL_TLSEXT_ERR_OK; +} + +static void ssl_log_where_info(const SSL* ssl, int where, int flag, + const char* msg) { + if ((where & flag) && + GRPC_TRACE_FLAG_ENABLED(client_ssl_tsi_tracing_enabled)) { + gpr_log(GPR_INFO, "%20.20s - %30.30s - %5.10s", msg, + SSL_state_string_long(ssl), SSL_state_string(ssl)); + } +} + +static void ssl_server_info_callback(const SSL* ssl, int where, int ret) { + if (ret == 0) { + gpr_log(GPR_ERROR, "ssl_server_info_callback: error occurred.\n"); + return; + } + + ssl_log_where_info(ssl, where, SSL_CB_LOOP, "Server: LOOP"); + ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_START, + "Server: HANDSHAKE START"); + ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_DONE, + "Server: HANDSHAKE DONE"); +} + +// Minimal TLS server. This is largely based on the example at +// https://wiki.openssl.org/index.php/Simple_TLS_Server and the gRPC core +// internals in src/core/tsi/ssl_transport_security.c. +static void server_thread(void* arg) { + const server_args* args = static_cast(arg); + + SSL_load_error_strings(); + OpenSSL_add_ssl_algorithms(); + args->ssl_library_info->Notify(); + + const SSL_METHOD* method = TLSv1_2_server_method(); + SSL_CTX* ctx = SSL_CTX_new(method); + if (!ctx) { + perror("Unable to create SSL context"); + ERR_print_errors_fp(stderr); + abort(); + } + + // Load key pair. 
+ if (SSL_CTX_use_certificate_file(ctx, SSL_CERT_PATH, SSL_FILETYPE_PEM) < 0) { + perror("Unable to use certificate file."); + ERR_print_errors_fp(stderr); + abort(); + } + if (SSL_CTX_use_PrivateKey_file(ctx, SSL_KEY_PATH, SSL_FILETYPE_PEM) < 0) { + perror("Unable to use private key file."); + ERR_print_errors_fp(stderr); + abort(); + } + if (SSL_CTX_check_private_key(ctx) != 1) { + perror("Check private key failed."); + ERR_print_errors_fp(stderr); + abort(); + } + + // Set the cipher list to match the one expressed in + // src/core/tsi/ssl_transport_security.cc. + const char* cipher_list = + "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-" + "SHA384:ECDHE-RSA-AES256-GCM-SHA384"; + if (!SSL_CTX_set_cipher_list(ctx, cipher_list)) { + ERR_print_errors_fp(stderr); + gpr_log(GPR_ERROR, "Couldn't set server cipher list."); + abort(); + } + + // Enable automatic curve selection. This is a NO-OP when using OpenSSL + // versions > 1.0.2. + if (!SSL_CTX_set_ecdh_auto(ctx, /*onoff=*/1)) { + ERR_print_errors_fp(stderr); + gpr_log(GPR_ERROR, "Couldn't set automatic curve selection."); + abort(); + } + + // Register the ALPN selection callback. + SSL_CTX_set_alpn_select_cb(ctx, alpn_select_cb, args->alpn_preferred); + + // bind/listen/accept at TCP layer. + const int sock = args->socket; + gpr_log(GPR_INFO, "Server listening"); + struct sockaddr_in addr; + socklen_t len = sizeof(addr); + const int client = + accept(sock, reinterpret_cast(&addr), &len); + if (client < 0) { + perror("Unable to accept"); + abort(); + } + + // Establish a SSL* and accept at SSL layer. + SSL* ssl = SSL_new(ctx); + SSL_set_info_callback(ssl, ssl_server_info_callback); + GPR_ASSERT(ssl); + SSL_set_fd(ssl, client); + if (SSL_accept(ssl) <= 0) { + ERR_print_errors_fp(stderr); + gpr_log(GPR_ERROR, "Handshake failed."); + } else { + gpr_log(GPR_INFO, "Handshake successful."); + } + + // Send out the settings frame. + const char settings_frame[] = "\x00\x00\x00\x04\x00\x00\x00\x00\x00"; + SSL_write(ssl, settings_frame, sizeof(settings_frame) - 1); + + // Wait until the client drops its connection. + char buf; + while (SSL_read(ssl, &buf, sizeof(buf)) > 0) { + } + + SSL_free(ssl); + close(client); + close(sock); + SSL_CTX_free(ctx); +} + +// This test launches a minimal TLS server on a separate thread and then +// establishes a TLS handshake via the core library to the server. The TLS +// server validates ALPN aspects of the handshake and supplies the protocol +// specified in the server_alpn_preferred argument to the client. +static bool client_ssl_test(char* server_alpn_preferred) { + bool success = true; + + grpc_init(); + + // Find a port we can bind to. Retries added to handle flakes in port server + // and port picking. + int port = -1; + int server_socket = -1; + int socket_retries = 30; + while (server_socket == -1 && socket_retries-- > 0) { + server_socket = create_socket(&port); + if (server_socket == -1) { + sleep(1); + } + } + GPR_ASSERT(server_socket > 0 && port > 0); + + // Launch the TLS server thread. + SslLibraryInfo ssl_library_info; + server_args args = {server_socket, server_alpn_preferred, &ssl_library_info}; + bool ok; + grpc_core::Thread thd("grpc_client_ssl_test", server_thread, &args, &ok); + GPR_ASSERT(ok); + thd.Start(); + ssl_library_info.Await(); + + // Load key pair and establish client SSL credentials. 
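+  // The '1' passed to grpc_load_file below requests a trailing NUL byte so
+  // the slice contents can be handed to the credentials API as C strings.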
+ grpc_ssl_pem_key_cert_pair pem_key_cert_pair; + grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CA_PATH, 1, &ca_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + pem_key_cert_pair.private_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + pem_key_cert_pair.cert_chain = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + grpc_channel_credentials* ssl_creds = grpc_ssl_credentials_create( + ca_cert, &pem_key_cert_pair, nullptr, nullptr); + + // Establish a channel pointing at the TLS server. Since the gRPC runtime is + // lazy, this won't necessarily establish a connection yet. + std::string target = absl::StrCat("127.0.0.1:", port); + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args grpc_args; + grpc_args.num_args = 1; + grpc_args.args = &ssl_name_override; + grpc_channel* channel = grpc_secure_channel_create(ssl_creds, target.c_str(), + &grpc_args, nullptr); + GPR_ASSERT(channel); + + // Initially the channel will be idle, the + // grpc_channel_check_connectivity_state triggers an attempt to connect. + GPR_ASSERT(grpc_channel_check_connectivity_state( + channel, 1 /* try_to_connect */) == GRPC_CHANNEL_IDLE); + + // Wait a bounded number of times for the channel to be ready. When the + // channel is ready, the initial TLS handshake will have successfully + // completed and we know that the client's ALPN list satisfied the server. + int retries = 10; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + + while (state != GRPC_CHANNEL_READY && retries-- > 0) { + grpc_channel_watch_connectivity_state( + channel, state, grpc_timeout_seconds_to_deadline(3), cq, nullptr); + gpr_timespec cq_deadline = grpc_timeout_seconds_to_deadline(5); + grpc_event ev = grpc_completion_queue_next(cq, cq_deadline, nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + state = + grpc_channel_check_connectivity_state(channel, 0 /* try_to_connect */); + } + grpc_completion_queue_destroy(cq); + if (retries < 0) { + success = false; + } + + grpc_channel_destroy(channel); + grpc_channel_credentials_release(ssl_creds); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); + + thd.Join(); + + grpc_shutdown(); + + return success; +} + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + // Handshake succeeeds when the server has grpc-exp as the ALPN preference. + GPR_ASSERT(client_ssl_test(const_cast("grpc-exp"))); + // Handshake succeeeds when the server has h2 as the ALPN preference. This + // covers legacy gRPC servers which don't support grpc-exp. + GPR_ASSERT(client_ssl_test(const_cast("h2"))); + // Handshake fails when the server uses a fake protocol as its ALPN + // preference. This validates the client is correctly validating ALPN returns + // and sanity checks the client_ssl_test. + GPR_ASSERT(!client_ssl_test(const_cast("foo"))); + // Clean up the SSL libraries. 
+ EVP_cleanup(); + return 0; +} + +#else /* GRPC_POSIX_SOCKET_TCP */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_TCP */ diff --git a/test/core/handshake/readahead_handshaker_server_ssl.cc b/test/core/handshake/readahead_handshaker_server_ssl.cc new file mode 100644 index 00000000..120490e7 --- /dev/null +++ b/test/core/handshake/readahead_handshaker_server_ssl.cc @@ -0,0 +1,94 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/handshaker_factory.h" +#include "src/core/lib/channel/handshaker_registry.h" +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/transport/security_handshaker.h" +#include "test/core/handshake/server_ssl_common.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +/* The purpose of this test is to exercise the case when a + * grpc *security_handshaker* begins its handshake with data already + * in the read buffer of the handshaker arg. This scenario is created by + * adding a fake "readahead" handshaker at the beginning of the server's + * handshaker list, which just reads from the connection and then places + * read bytes into the read buffer of the handshake arg (to be passed down + * to the security_handshaker). This test is meant to protect code relying on + * this functionality that lives outside of this repo. 
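+ * Concretely, the ReadAheadHandshaker below issues a single
+ * grpc_endpoint_read() into args->read_buffer with on_handshake_done as the
+ * read closure, so by the time the security handshaker takes over, the
+ * client's first bytes are already sitting in its input buffer.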
*/ + +namespace grpc_core { + +class ReadAheadHandshaker : public Handshaker { + public: + ~ReadAheadHandshaker() override {} + const char* name() const override { return "read_ahead"; } + void Shutdown(grpc_error_handle /*why*/) override {} + void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/, + grpc_closure* on_handshake_done, + HandshakerArgs* args) override { + grpc_endpoint_read(args->endpoint, args->read_buffer, on_handshake_done, + /*urgent=*/false); + } +}; + +class ReadAheadHandshakerFactory : public HandshakerFactory { + public: + void AddHandshakers(const grpc_channel_args* /*args*/, + grpc_pollset_set* /*interested_parties*/, + HandshakeManager* handshake_mgr) override { + handshake_mgr->Add(MakeRefCounted()); + } + ~ReadAheadHandshakerFactory() override = default; +}; + +} // namespace grpc_core + +int main(int /*argc*/, char* /*argv*/[]) { + grpc_core::CoreConfiguration::BuildSpecialConfiguration( + [](grpc_core::CoreConfiguration::Builder* builder) { + BuildCoreConfiguration(builder); + builder->handshaker_registry()->RegisterHandshakerFactory( + true /* at_start */, grpc_core::HANDSHAKER_SERVER, + absl::make_unique()); + }); + + grpc_init(); + const char* full_alpn_list[] = {"grpc-exp", "h2"}; + GPR_ASSERT(server_ssl_test(full_alpn_list, 2, "grpc-exp")); + CleanupSslLibrary(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/handshake/server_ssl.cc b/test/core/handshake/server_ssl.cc new file mode 100644 index 00000000..f759b70a --- /dev/null +++ b/test/core/handshake/server_ssl.cc @@ -0,0 +1,59 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/handshake/server_ssl_common.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + // Handshake succeeeds when the client supplies the standard ALPN list. + const char* full_alpn_list[] = {"grpc-exp", "h2"}; + GPR_ASSERT(server_ssl_test(full_alpn_list, 2, "grpc-exp")); + // Handshake succeeeds when the client supplies only h2 as the ALPN list. This + // covers legacy gRPC clients which don't support grpc-exp. + const char* h2_only_alpn_list[] = {"h2"}; + GPR_ASSERT(server_ssl_test(h2_only_alpn_list, 1, "h2")); + // Handshake succeeds when the client supplies superfluous ALPN entries and + // also when h2 precedes gprc-exp. + const char* extra_alpn_list[] = {"foo", "h2", "bar", "grpc-exp"}; + GPR_ASSERT(server_ssl_test(extra_alpn_list, 4, "h2")); + // Handshake fails when the client uses a fake protocol as its only ALPN + // preference. This validates the server is correctly validating ALPN + // and sanity checks the server_ssl_test. 
+ const char* fake_alpn_list[] = {"foo"}; + GPR_ASSERT(!server_ssl_test(fake_alpn_list, 1, "foo")); + CleanupSslLibrary(); + return 0; +} diff --git a/test/core/handshake/server_ssl_common.cc b/test/core/handshake/server_ssl_common.cc new file mode 100644 index 00000000..7b891467 --- /dev/null +++ b/test/core/handshake/server_ssl_common.cc @@ -0,0 +1,284 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/handshake/server_ssl_common.h" + +#include +#include +#include +#include + +#include + +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define SSL_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SSL_KEY_PATH "src/core/tsi/test_creds/server1.key" +#define SSL_CA_PATH "src/core/tsi/test_creds/ca.pem" + +namespace { + +// Handshake completed signal to server thread. +gpr_event client_handshake_complete; + +int create_socket(int port) { + int s; + struct sockaddr_in addr; + + addr.sin_family = AF_INET; + addr.sin_port = htons(static_cast(port)); + addr.sin_addr.s_addr = htonl(INADDR_ANY); + + s = socket(AF_INET, SOCK_STREAM, 0); + if (s < 0) { + perror("Unable to create socket"); + return -1; + } + + if (connect(s, reinterpret_cast(&addr), sizeof(addr)) < 0) { + perror("Unable to connect"); + return -1; + } + + return s; +} + +class ServerInfo { + public: + explicit ServerInfo(int p) : port_(p) {} + + int port() const { return port_; } + + void Activate() { + grpc_core::MutexLock lock(&mu_); + ready_ = true; + cv_.Signal(); + } + + void Await() { + grpc_core::MutexLock lock(&mu_); + while (!ready_) { + cv_.Wait(&mu_); + } + } + + private: + const int port_; + grpc_core::Mutex mu_; + grpc_core::CondVar cv_; + bool ready_ ABSL_GUARDED_BY(mu_) = false; +}; + +// Simple gRPC server. This listens until client_handshake_complete occurs. +void server_thread(void* arg) { + ServerInfo* s = static_cast(arg); + const int port = s->port(); + + // Load key pair and establish server SSL credentials. 
+ grpc_ssl_pem_key_cert_pair pem_key_cert_pair; + grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CA_PATH, 1, &ca_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + pem_key_cert_pair.private_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + pem_key_cert_pair.cert_chain = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + ca_cert, &pem_key_cert_pair, 1, 0, nullptr); + + // Start server listening on local port. + std::string addr = absl::StrCat("127.0.0.1:", port); + grpc_server* server = grpc_server_create(nullptr, nullptr); + GPR_ASSERT( + grpc_server_add_secure_http2_port(server, addr.c_str(), ssl_creds)); + + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + + grpc_server_register_completion_queue(server, cq, nullptr); + grpc_server_start(server); + + // Notify the other side that it is now ok to start working since SSL is + // definitely already started. + s->Activate(); + + // Wait a bounded number of time until client_handshake_complete is set, + // sleeping between polls. + int retries = 10; + while (!gpr_event_get(&client_handshake_complete) && retries-- > 0) { + const gpr_timespec cq_deadline = grpc_timeout_seconds_to_deadline(1); + grpc_event ev = grpc_completion_queue_next(cq, cq_deadline, nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_TIMEOUT); + } + + gpr_log(GPR_INFO, "Shutting down server"); + grpc_server_shutdown_and_notify(server, cq, nullptr); + grpc_completion_queue_shutdown(cq); + + const gpr_timespec cq_deadline = grpc_timeout_seconds_to_deadline(5); + grpc_event ev = grpc_completion_queue_next(cq, cq_deadline, nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq); + grpc_server_credentials_release(ssl_creds); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); +} + +} // namespace + +// This test launches a gRPC server on a separate thread and then establishes a +// TLS handshake via a minimal TLS client. The TLS client has configurable (via +// alpn_list) ALPN settings and can probe at the supported ALPN preferences +// using this (via alpn_expected). +bool server_ssl_test(const char* alpn_list[], unsigned int alpn_list_len, + const char* alpn_expected) { + bool success = true; + + grpc_init(); + ServerInfo s(grpc_pick_unused_port_or_die()); + gpr_event_init(&client_handshake_complete); + + // Launch the gRPC server thread. + bool ok; + grpc_core::Thread thd("grpc_ssl_test", server_thread, &s, &ok); + GPR_ASSERT(ok); + thd.Start(); + + // The work in server_thread will cause the SSL initialization to take place + // so long as we wait for it to reach beyond the point of adding a secure + // server port. + s.Await(); + + const SSL_METHOD* method = TLSv1_2_client_method(); + SSL_CTX* ctx = SSL_CTX_new(method); + if (!ctx) { + perror("Unable to create SSL context"); + ERR_print_errors_fp(stderr); + abort(); + } + + // Load key pair. 
+ if (SSL_CTX_use_certificate_file(ctx, SSL_CERT_PATH, SSL_FILETYPE_PEM) < 0) { + ERR_print_errors_fp(stderr); + abort(); + } + if (SSL_CTX_use_PrivateKey_file(ctx, SSL_KEY_PATH, SSL_FILETYPE_PEM) < 0) { + ERR_print_errors_fp(stderr); + abort(); + } + + // Set the cipher list to match the one expressed in + // src/core/tsi/ssl_transport_security.c. + const char* cipher_list = + "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-" + "SHA384:ECDHE-RSA-AES256-GCM-SHA384"; + if (!SSL_CTX_set_cipher_list(ctx, cipher_list)) { + ERR_print_errors_fp(stderr); + gpr_log(GPR_ERROR, "Couldn't set server cipher list."); + abort(); + } + + // Configure ALPN list the client will send to the server. This must match the + // wire format, see documentation for SSL_CTX_set_alpn_protos. + unsigned int alpn_protos_len = alpn_list_len; + for (unsigned int i = 0; i < alpn_list_len; ++i) { + alpn_protos_len += static_cast(strlen(alpn_list[i])); + } + unsigned char* alpn_protos = + static_cast(gpr_malloc(alpn_protos_len)); + unsigned char* p = alpn_protos; + for (unsigned int i = 0; i < alpn_list_len; ++i) { + const uint8_t len = static_cast(strlen(alpn_list[i])); + *p++ = len; + memcpy(p, alpn_list[i], len); + p += len; + } + GPR_ASSERT(SSL_CTX_set_alpn_protos(ctx, alpn_protos, alpn_protos_len) == 0); + + // Try and connect to server. We allow a bounded number of retries as we might + // be racing with the server setup on its separate thread. + int retries = 10; + int sock = -1; + while (sock == -1 && retries-- > 0) { + sock = create_socket(s.port()); + if (sock < 0) { + sleep(1); + } + } + GPR_ASSERT(sock > 0); + gpr_log(GPR_INFO, "Connected to server on port %d", s.port()); + + // Establish a SSL* and connect at SSL layer. + SSL* ssl = SSL_new(ctx); + GPR_ASSERT(ssl); + SSL_set_fd(ssl, sock); + if (SSL_connect(ssl) <= 0) { + ERR_print_errors_fp(stderr); + gpr_log(GPR_ERROR, "Handshake failed."); + success = false; + } else { + gpr_log(GPR_INFO, "Handshake successful."); + // Validate ALPN preferred by server matches alpn_expected. + const unsigned char* alpn_selected; + unsigned int alpn_selected_len; + SSL_get0_alpn_selected(ssl, &alpn_selected, &alpn_selected_len); + if (strlen(alpn_expected) != alpn_selected_len || + strncmp(reinterpret_cast(alpn_selected), alpn_expected, + alpn_selected_len) != 0) { + gpr_log(GPR_ERROR, "Unexpected ALPN protocol preference"); + success = false; + } + } + gpr_event_set(&client_handshake_complete, &client_handshake_complete); + + SSL_free(ssl); + gpr_free(alpn_protos); + SSL_CTX_free(ctx); + close(sock); + + thd.Join(); + + grpc_shutdown(); + + return success; +} + +void CleanupSslLibrary() { EVP_cleanup(); } diff --git a/test/core/handshake/verify_peer_options.cc b/test/core/handshake/verify_peer_options.cc new file mode 100644 index 00000000..606fb2cc --- /dev/null +++ b/test/core/handshake/verify_peer_options.cc @@ -0,0 +1,285 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_TCP + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define SSL_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SSL_KEY_PATH "src/core/tsi/test_creds/server1.key" +#define SSL_CA_PATH "src/core/tsi/test_creds/ca.pem" + +// Simple gRPC server. This listens until client_handshake_complete occurs. +static gpr_event client_handshake_complete; + +static void server_thread(void* arg) { + const int port = *static_cast(arg); + + // Load key pair and establish server SSL credentials. + grpc_ssl_pem_key_cert_pair pem_key_cert_pair; + grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CA_PATH, 1, &ca_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + pem_key_cert_pair.private_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + pem_key_cert_pair.cert_chain = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + ca_cert, &pem_key_cert_pair, 1, 0, nullptr); + + // Start server listening on local port. + std::string addr = absl::StrCat("127.0.0.1:", port); + grpc_server* server = grpc_server_create(nullptr, nullptr); + GPR_ASSERT( + grpc_server_add_secure_http2_port(server, addr.c_str(), ssl_creds)); + + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + + grpc_server_register_completion_queue(server, cq, nullptr); + grpc_server_start(server); + + // Wait a bounded number of time until client_handshake_complete is set, + // sleeping between polls. The total time spent (deadline * retries) + // should be strictly greater than the client retry limit so that the + // client will always timeout first. + int retries = 60; + while (!gpr_event_get(&client_handshake_complete) && retries-- > 0) { + const gpr_timespec cq_deadline = grpc_timeout_seconds_to_deadline(1); + grpc_event ev = grpc_completion_queue_next(cq, cq_deadline, nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_TIMEOUT); + } + + gpr_log(GPR_INFO, "Shutting down server"); + grpc_server_shutdown_and_notify(server, cq, nullptr); + grpc_server_cancel_all_calls(server); + grpc_completion_queue_shutdown(cq); + + const gpr_timespec cq_deadline = grpc_timeout_seconds_to_deadline(60); + grpc_event ev = grpc_completion_queue_next(cq, cq_deadline, nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq); + grpc_server_credentials_release(ssl_creds); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); +} + +// This test launches a minimal TLS grpc server on a separate thread and then +// establishes a TLS handshake via the core library to the server. The client +// uses the supplied verify options. 
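+// A minimal sketch of how a caller fills in the options consumed here (the
+// names my_callback, my_destruct and my_state are illustrative only; main()
+// below does the same thing with verify_callback and verify_destruct):
+//
+//   verify_peer_options opts;
+//   opts.verify_peer_callback = my_callback;   // int (const char* target_host,
+//                                              //      const char* target_pem,
+//                                              //      void* userdata)
+//   opts.verify_peer_callback_userdata = &my_state;
+//   opts.verify_peer_destruct = my_destruct;   // void (void* userdata)
+//   bool ok = verify_peer_options_test(&opts);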
+static bool verify_peer_options_test(verify_peer_options* verify_options) { + bool success = true; + + grpc_init(); + int port = grpc_pick_unused_port_or_die(); + gpr_event_init(&client_handshake_complete); + + // Load key pair and establish client SSL credentials. + // NOTE: we intentionally load the credential files before starting + // the server thread because grpc_load_file can experience trouble + // when two threads attempt to load the same file concurrently + // and server thread also reads the same files as soon as it starts. + // See https://github.com/grpc/grpc/issues/23503 for details. + grpc_ssl_pem_key_cert_pair pem_key_cert_pair; + grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CA_PATH, 1, &ca_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + pem_key_cert_pair.private_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + pem_key_cert_pair.cert_chain = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + grpc_channel_credentials* ssl_creds = grpc_ssl_credentials_create( + ca_cert, &pem_key_cert_pair, verify_options, nullptr); + + // Launch the gRPC server thread. + bool ok; + grpc_core::Thread thd("grpc_client_ssl_test", server_thread, &port, &ok); + GPR_ASSERT(ok); + thd.Start(); + + // Establish a channel pointing at the TLS server. Since the gRPC runtime is + // lazy, this won't necessarily establish a connection yet. + std::string target = absl::StrCat("127.0.0.1:", port); + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args grpc_args; + grpc_args.num_args = 1; + grpc_args.args = &ssl_name_override; + grpc_channel* channel = grpc_secure_channel_create(ssl_creds, target.c_str(), + &grpc_args, nullptr); + GPR_ASSERT(channel); + + // Initially the channel will be idle, the + // grpc_channel_check_connectivity_state triggers an attempt to connect. + GPR_ASSERT(grpc_channel_check_connectivity_state( + channel, 1 /* try_to_connect */) == GRPC_CHANNEL_IDLE); + + // Wait a bounded number of times for the channel to be ready. When the + // channel is ready, the initial TLS handshake will have successfully + // completed. The total time spent on the client side (retries * deadline) + // should be greater than the server side time limit. 
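+  // Concretely, the loop below makes at most 10 attempts, each bounded by a
+  // 3-second connectivity watch and a 5-second completion-queue poll, so a
+  // handshake that never succeeds is reported as a failure after roughly
+  // 10 * 5 = 50 seconds at most.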
+ int retries = 10; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + + while (state != GRPC_CHANNEL_READY && retries-- > 0) { + grpc_channel_watch_connectivity_state( + channel, state, grpc_timeout_seconds_to_deadline(3), cq, nullptr); + gpr_timespec cq_deadline = grpc_timeout_seconds_to_deadline(5); + grpc_event ev = grpc_completion_queue_next(cq, cq_deadline, nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + state = + grpc_channel_check_connectivity_state(channel, 0 /* try_to_connect */); + } + grpc_completion_queue_destroy(cq); + if (retries < 0) { + success = false; + } + + grpc_channel_destroy(channel); + grpc_channel_credentials_release(ssl_creds); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); + + // Now that the client is completely cleaned up, trigger the server to + // shutdown + gpr_event_set(&client_handshake_complete, &client_handshake_complete); + // Wait for the server to completely shutdown + thd.Join(); + + grpc_shutdown(); + + return success; +} + +static int callback_return_value = 0; +static char callback_target_host[4096]; +static char callback_target_pem[4096]; +static void* callback_userdata = nullptr; +static void* destruct_userdata = nullptr; + +static int verify_callback(const char* target_host, const char* target_pem, + void* userdata) { + if (target_host != nullptr) { + snprintf(callback_target_host, sizeof(callback_target_host), "%s", + target_host); + } else { + callback_target_host[0] = '\0'; + } + if (target_pem != nullptr) { + snprintf(callback_target_pem, sizeof(callback_target_pem), "%s", + target_pem); + } else { + callback_target_pem[0] = '\0'; + } + callback_userdata = userdata; + return callback_return_value; +} + +static void verify_destruct(void* userdata) { destruct_userdata = userdata; } + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + int userdata = 42; + verify_peer_options verify_options; + + // Load the server's cert so that we can assert it gets passed to the callback + grpc_slice cert_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SSL_CERT_PATH, 1, &cert_slice))); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + + // Running with all-null values should have no effect + verify_options.verify_peer_callback = nullptr; + verify_options.verify_peer_callback_userdata = nullptr; + verify_options.verify_peer_destruct = nullptr; + GPR_ASSERT(verify_peer_options_test(&verify_options)); + GPR_ASSERT(strlen(callback_target_host) == 0); + GPR_ASSERT(strlen(callback_target_pem) == 0); + GPR_ASSERT(callback_userdata == nullptr); + GPR_ASSERT(destruct_userdata == nullptr); + + // Running with the callbacks and verify we get the expected values + verify_options.verify_peer_callback = verify_callback; + verify_options.verify_peer_callback_userdata = static_cast(&userdata); + verify_options.verify_peer_destruct = verify_destruct; + GPR_ASSERT(verify_peer_options_test(&verify_options)); + GPR_ASSERT(strcmp(callback_target_host, "foo.test.google.fr") == 0); + GPR_ASSERT(strcmp(callback_target_pem, server_cert) == 0); + GPR_ASSERT(callback_userdata == static_cast(&userdata)); + GPR_ASSERT(destruct_userdata == static_cast(&userdata)); + + // If the callback returns non-zero, initializing the channel should fail. 
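+  // "Fail" here means the handshake is rejected: the channel never reaches
+  // GRPC_CHANNEL_READY within the bounded wait in verify_peer_options_test,
+  // so the helper returns false, which is what the assertion below checks.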
+ callback_return_value = 1; + GPR_ASSERT(!verify_peer_options_test(&verify_options)); + + grpc_slice_unref(cert_slice); + + grpc_shutdown(); + return 0; +} + +#else /* GRPC_POSIX_SOCKET_TCP */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_TCP */ diff --git a/test/core/http/format_request_test.cc b/test/core/http/format_request_test.cc new file mode 100644 index 00000000..d4d1be80 --- /dev/null +++ b/test/core/http/format_request_test.cc @@ -0,0 +1,152 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/http/format_request.h" + +#include + +#include + +#include "test/core/util/test_config.h" + +static void test_format_get_request(void) { + grpc_http_header hdr = {const_cast("x-yz"), const_cast("abc")}; + grpc_httpcli_request req; + grpc_slice slice; + + memset(&req, 0, sizeof(req)); + req.host = const_cast("example.com"); + req.http.path = const_cast("/index.html"); + req.http.hdr_count = 1; + req.http.hdrs = &hdr; + + slice = grpc_httpcli_format_get_request(&req); + + GPR_ASSERT(0 == grpc_slice_str_cmp(slice, + "GET /index.html HTTP/1.0\r\n" + "Host: example.com\r\n" + "Connection: close\r\n" + "User-Agent: " GRPC_HTTPCLI_USER_AGENT + "\r\n" + "x-yz: abc\r\n" + "\r\n")); + + grpc_slice_unref(slice); +} + +static void test_format_post_request(void) { + grpc_http_header hdr = {const_cast("x-yz"), const_cast("abc")}; + grpc_httpcli_request req; + grpc_slice slice; + char body_bytes[] = "fake body"; + size_t body_len = 9; + + memset(&req, 0, sizeof(req)); + req.host = const_cast("example.com"); + req.http.path = const_cast("/index.html"); + req.http.hdr_count = 1; + req.http.hdrs = &hdr; + + slice = grpc_httpcli_format_post_request(&req, body_bytes, body_len); + + GPR_ASSERT(0 == grpc_slice_str_cmp(slice, + "POST /index.html HTTP/1.0\r\n" + "Host: example.com\r\n" + "Connection: close\r\n" + "User-Agent: " GRPC_HTTPCLI_USER_AGENT + "\r\n" + "x-yz: abc\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: 9\r\n" + "\r\n" + "fake body")); + + grpc_slice_unref(slice); +} + +static void test_format_post_request_no_body(void) { + grpc_http_header hdr = {const_cast("x-yz"), const_cast("abc")}; + grpc_httpcli_request req; + grpc_slice slice; + + memset(&req, 0, sizeof(req)); + req.host = const_cast("example.com"); + req.http.path = const_cast("/index.html"); + req.http.hdr_count = 1; + req.http.hdrs = &hdr; + + slice = grpc_httpcli_format_post_request(&req, nullptr, 0); + + GPR_ASSERT(0 == grpc_slice_str_cmp(slice, + "POST /index.html HTTP/1.0\r\n" + "Host: example.com\r\n" + "Connection: close\r\n" + "User-Agent: " GRPC_HTTPCLI_USER_AGENT + "\r\n" + "x-yz: abc\r\n" + "\r\n")); + + grpc_slice_unref(slice); +} + +static void test_format_post_request_content_type_override(void) { + grpc_http_header hdrs[2]; + grpc_httpcli_request req; + grpc_slice slice; + char body_bytes[] = "fake%20body"; + size_t body_len = 11; + + hdrs[0].key = const_cast("x-yz"); + hdrs[0].value = 
const_cast("abc"); + hdrs[1].key = const_cast("Content-Type"); + hdrs[1].value = const_cast("application/x-www-form-urlencoded"); + memset(&req, 0, sizeof(req)); + req.host = const_cast("example.com"); + req.http.path = const_cast("/index.html"); + req.http.hdr_count = 2; + req.http.hdrs = hdrs; + + slice = grpc_httpcli_format_post_request(&req, body_bytes, body_len); + + GPR_ASSERT(0 == grpc_slice_str_cmp( + slice, + "POST /index.html HTTP/1.0\r\n" + "Host: example.com\r\n" + "Connection: close\r\n" + "User-Agent: " GRPC_HTTPCLI_USER_AGENT "\r\n" + "x-yz: abc\r\n" + "Content-Type: application/x-www-form-urlencoded\r\n" + "Content-Length: 11\r\n" + "\r\n" + "fake%20body")); + + grpc_slice_unref(slice); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + test_format_get_request(); + test_format_post_request(); + test_format_post_request_no_body(); + test_format_post_request_content_type_override(); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/http/httpcli_test.cc b/test/core/http/httpcli_test.cc new file mode 100644 index 00000000..dbc04716 --- /dev/null +++ b/test/core/http/httpcli_test.cc @@ -0,0 +1,217 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/http/httpcli.h" + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/iomgr.h" +#include "test/core/util/port.h" +#include "test/core/util/subprocess.h" +#include "test/core/util/test_config.h" + +static int g_done = 0; +static grpc_httpcli_context g_context; +static gpr_mu* g_mu; +static grpc_polling_entity g_pops; + +static grpc_millis n_seconds_time(int seconds) { + return grpc_timespec_to_millis_round_up( + grpc_timeout_seconds_to_deadline(seconds)); +} + +static void on_finish(void* arg, grpc_error_handle error) { + const char* expect = + "Hello world!" + "
This is a test
"; + grpc_http_response* response = static_cast(arg); + GPR_ASSERT(response); + gpr_log(GPR_INFO, "response status=%d error=%s", response->status, + grpc_error_std_string(error).c_str()); + GPR_ASSERT(response->status == 200); + GPR_ASSERT(response->body_length == strlen(expect)); + GPR_ASSERT(0 == memcmp(expect, response->body, response->body_length)); + gpr_mu_lock(g_mu); + g_done = 1; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&g_pops), nullptr))); + gpr_mu_unlock(g_mu); +} + +static void test_get(int port) { + grpc_httpcli_request req; + char* host; + grpc_core::ExecCtx exec_ctx; + + g_done = 0; + gpr_log(GPR_INFO, "test_get"); + + gpr_asprintf(&host, "localhost:%d", port); + gpr_log(GPR_INFO, "requesting from %s", host); + + memset(&req, 0, sizeof(req)); + req.host = host; + req.http.path = const_cast("/get"); + req.handshaker = &grpc_httpcli_plaintext; + + grpc_http_response response; + response = {}; + grpc_resource_quota* resource_quota = grpc_resource_quota_create("test_get"); + grpc_httpcli_get( + &g_context, &g_pops, resource_quota, &req, n_seconds_time(15), + GRPC_CLOSURE_CREATE(on_finish, &response, grpc_schedule_on_exec_ctx), + &response); + gpr_mu_lock(g_mu); + while (!g_done) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(grpc_polling_entity_pollset(&g_pops), + &worker, n_seconds_time(1)))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); + gpr_free(host); + grpc_http_response_destroy(&response); +} + +static void test_post(int port) { + grpc_httpcli_request req; + char* host; + grpc_core::ExecCtx exec_ctx; + + g_done = 0; + gpr_log(GPR_INFO, "test_post"); + + gpr_asprintf(&host, "localhost:%d", port); + gpr_log(GPR_INFO, "posting to %s", host); + + memset(&req, 0, sizeof(req)); + req.host = host; + req.http.path = const_cast("/post"); + req.handshaker = &grpc_httpcli_plaintext; + + grpc_http_response response; + response = {}; + grpc_resource_quota* resource_quota = grpc_resource_quota_create("test_post"); + grpc_httpcli_post( + &g_context, &g_pops, resource_quota, &req, "hello", 5, n_seconds_time(15), + GRPC_CLOSURE_CREATE(on_finish, &response, grpc_schedule_on_exec_ctx), + &response); + gpr_mu_lock(g_mu); + while (!g_done) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(grpc_polling_entity_pollset(&g_pops), + &worker, n_seconds_time(1)))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); + gpr_free(host); + grpc_http_response_destroy(&response); +} + +static void destroy_pops(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy( + grpc_polling_entity_pollset(static_cast(p))); +} + +int main(int argc, char** argv) { + gpr_subprocess* server; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + { + grpc_closure destroyed; + grpc_core::ExecCtx exec_ctx; + char* me = argv[0]; + char* lslash = strrchr(me, '/'); + char* args[4]; + int port = grpc_pick_unused_port_or_die(); + int arg_shift = 0; + /* figure out where we are */ + char* root; + if (lslash != nullptr) { + /* Hack for bazel target */ + if (static_cast(lslash - me) >= (sizeof("http") - 1) && + strncmp(me + (lslash - me) - sizeof("http") + 1, "http", + sizeof("http") - 1) == 0) { + lslash = me + (lslash - me) - sizeof("http"); + } + root = static_cast( + gpr_malloc(static_cast(lslash - me + sizeof("/../..")))); + memcpy(root, me, static_cast(lslash - me)); + 
memcpy(root + (lslash - me), "/../..", sizeof("/../..")); + } else { + root = gpr_strdup("."); + } + + GPR_ASSERT(argc <= 2); + if (argc == 2) { + args[0] = gpr_strdup(argv[1]); + } else { + arg_shift = 1; + gpr_asprintf(&args[0], "%s/test/core/http/python_wrapper.sh", root); + gpr_asprintf(&args[1], "%s/test/core/http/test_server.py", root); + } + + /* start the server */ + args[1 + arg_shift] = const_cast("--port"); + gpr_asprintf(&args[2 + arg_shift], "%d", port); + server = + gpr_subprocess_create(3 + arg_shift, const_cast(args)); + GPR_ASSERT(server); + gpr_free(args[0]); + if (arg_shift) gpr_free(args[1]); + gpr_free(args[2 + arg_shift]); + gpr_free(root); + + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(5, GPR_TIMESPAN))); + + grpc_httpcli_context_init(&g_context); + grpc_pollset* pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &g_mu); + g_pops = grpc_polling_entity_create_from_pollset(pollset); + + test_get(port); + test_post(port); + + grpc_httpcli_context_destroy(&g_context); + GRPC_CLOSURE_INIT(&destroyed, destroy_pops, &g_pops, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(grpc_polling_entity_pollset(&g_pops), &destroyed); + } + grpc_shutdown(); + + gpr_free(grpc_polling_entity_pollset(&g_pops)); + + gpr_subprocess_destroy(server); + + return 0; +} diff --git a/test/core/http/httpscli_test.cc b/test/core/http/httpscli_test.cc new file mode 100644 index 00000000..07e46ffa --- /dev/null +++ b/test/core/http/httpscli_test.cc @@ -0,0 +1,227 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/security/security_connector/ssl_utils_config.h" +#include "test/core/util/port.h" +#include "test/core/util/subprocess.h" +#include "test/core/util/test_config.h" + +static int g_done = 0; +static grpc_httpcli_context g_context; +static gpr_mu* g_mu; +static grpc_polling_entity g_pops; + +static grpc_millis n_seconds_time(int seconds) { + return grpc_timespec_to_millis_round_up( + grpc_timeout_seconds_to_deadline(seconds)); +} + +static void on_finish(void* arg, grpc_error_handle error) { + const char* expect = + "Hello world!" + "
This is a test
"; + grpc_http_response* response = static_cast(arg); + GPR_ASSERT(response); + gpr_log(GPR_INFO, "response status=%d error=%s", response->status, + grpc_error_std_string(error).c_str()); + GPR_ASSERT(response->status == 200); + GPR_ASSERT(response->body_length == strlen(expect)); + GPR_ASSERT(0 == memcmp(expect, response->body, response->body_length)); + gpr_mu_lock(g_mu); + g_done = 1; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&g_pops), nullptr))); + gpr_mu_unlock(g_mu); +} + +static void test_get(int port) { + grpc_httpcli_request req; + char* host; + grpc_core::ExecCtx exec_ctx; + + g_done = 0; + gpr_log(GPR_INFO, "test_get"); + + gpr_asprintf(&host, "localhost:%d", port); + gpr_log(GPR_INFO, "requesting from %s", host); + + memset(&req, 0, sizeof(req)); + req.host = host; + req.ssl_host_override = const_cast("foo.test.google.fr"); + req.http.path = const_cast("/get"); + req.handshaker = &grpc_httpcli_ssl; + + grpc_http_response response; + response = {}; + grpc_resource_quota* resource_quota = grpc_resource_quota_create("test_get"); + grpc_httpcli_get( + &g_context, &g_pops, resource_quota, &req, n_seconds_time(15), + GRPC_CLOSURE_CREATE(on_finish, &response, grpc_schedule_on_exec_ctx), + &response); + gpr_mu_lock(g_mu); + while (!g_done) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(grpc_polling_entity_pollset(&g_pops), + &worker, n_seconds_time(1)))); + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); + gpr_free(host); + grpc_http_response_destroy(&response); +} + +static void test_post(int port) { + grpc_httpcli_request req; + char* host; + grpc_core::ExecCtx exec_ctx; + + g_done = 0; + gpr_log(GPR_INFO, "test_post"); + + gpr_asprintf(&host, "localhost:%d", port); + gpr_log(GPR_INFO, "posting to %s", host); + + memset(&req, 0, sizeof(req)); + req.host = host; + req.ssl_host_override = const_cast("foo.test.google.fr"); + req.http.path = const_cast("/post"); + req.handshaker = &grpc_httpcli_ssl; + + grpc_http_response response; + response = {}; + grpc_resource_quota* resource_quota = grpc_resource_quota_create("test_post"); + grpc_httpcli_post( + &g_context, &g_pops, resource_quota, &req, "hello", 5, n_seconds_time(15), + GRPC_CLOSURE_CREATE(on_finish, &response, grpc_schedule_on_exec_ctx), + &response); + gpr_mu_lock(g_mu); + while (!g_done) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(grpc_polling_entity_pollset(&g_pops), + &worker, n_seconds_time(1)))); + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); + gpr_free(host); + grpc_http_response_destroy(&response); +} + +static void destroy_pops(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy( + grpc_polling_entity_pollset(static_cast(p))); +} + +int main(int argc, char** argv) { + grpc_closure destroyed; + gpr_subprocess* server; + char* me = argv[0]; + char* lslash = strrchr(me, '/'); + char* args[5]; + int port = grpc_pick_unused_port_or_die(); + int arg_shift = 0; + /* figure out where we are */ + char* root; + if (lslash != nullptr) { + /* Hack for bazel target */ + if (static_cast(lslash - me) >= (sizeof("http") - 1) && + strncmp(me + (lslash - me) - sizeof("http") + 1, "http", + sizeof("http") - 1) == 0) { + lslash = me + (lslash - me) - sizeof("http"); + } + root = static_cast( + gpr_malloc(static_cast(lslash - me 
+ sizeof("/../..")))); + memcpy(root, me, static_cast(lslash - me)); + memcpy(root + (lslash - me), "/../..", sizeof("/../..")); + } else { + root = gpr_strdup("."); + } + + GPR_ASSERT(argc <= 2); + if (argc == 2) { + args[0] = gpr_strdup(argv[1]); + } else { + arg_shift = 1; + gpr_asprintf(&args[0], "%s/test/core/http/python_wrapper.sh", root); + gpr_asprintf(&args[1], "%s/test/core/http/test_server.py", root); + } + + /* Set the environment variable for the SSL certificate file */ + char* pem_file; + gpr_asprintf(&pem_file, "%s/src/core/tsi/test_creds/ca.pem", root); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, pem_file); + gpr_free(pem_file); + + /* start the server */ + args[1 + arg_shift] = const_cast("--port"); + gpr_asprintf(&args[2 + arg_shift], "%d", port); + args[3 + arg_shift] = const_cast("--ssl"); + server = gpr_subprocess_create(4 + arg_shift, const_cast(args)); + GPR_ASSERT(server); + gpr_free(args[0]); + if (arg_shift) gpr_free(args[1]); + gpr_free(args[2 + arg_shift]); + gpr_free(root); + + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(5, GPR_TIMESPAN))); + + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + grpc_httpcli_context_init(&g_context); + grpc_pollset* pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &g_mu); + g_pops = grpc_polling_entity_create_from_pollset(pollset); + + test_get(port); + test_post(port); + + { + grpc_core::ExecCtx exec_ctx; + grpc_httpcli_context_destroy(&g_context); + GRPC_CLOSURE_INIT(&destroyed, destroy_pops, &g_pops, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(grpc_polling_entity_pollset(&g_pops), &destroyed); + } + grpc_shutdown(); + + gpr_free(grpc_polling_entity_pollset(&g_pops)); + + gpr_subprocess_destroy(server); + + return 0; +} diff --git a/test/core/http/parser_test.cc b/test/core/http/parser_test.cc new file mode 100644 index 00000000..c08df1d4 --- /dev/null +++ b/test/core/http/parser_test.cc @@ -0,0 +1,310 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/http/parser.h" + +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/util/slice_splitter.h" +#include "test/core/util/test_config.h" + +static void test_request_succeeds(grpc_slice_split_mode split_mode, + const char* request_text, + const char* expect_method, + grpc_http_version expect_version, + const char* expect_path, + const char* expect_body, ...) 
{ + grpc_http_parser parser; + grpc_slice input_slice = grpc_slice_from_copied_string(request_text); + size_t num_slices; + size_t i; + grpc_slice* slices; + va_list args; + grpc_http_request request; + memset(&request, 0, sizeof(request)); + + grpc_split_slices(split_mode, &input_slice, 1, &slices, &num_slices); + grpc_slice_unref(input_slice); + + grpc_http_parser_init(&parser, GRPC_HTTP_REQUEST, &request); + + for (i = 0; i < num_slices; i++) { + GPR_ASSERT(grpc_http_parser_parse(&parser, slices[i], nullptr) == + GRPC_ERROR_NONE); + grpc_slice_unref(slices[i]); + } + GPR_ASSERT(grpc_http_parser_eof(&parser) == GRPC_ERROR_NONE); + + GPR_ASSERT(GRPC_HTTP_REQUEST == parser.type); + GPR_ASSERT(0 == strcmp(expect_method, request.method)); + GPR_ASSERT(0 == strcmp(expect_path, request.path)); + GPR_ASSERT(expect_version == request.version); + + if (expect_body != nullptr) { + GPR_ASSERT(strlen(expect_body) == request.body_length); + GPR_ASSERT(0 == memcmp(expect_body, request.body, request.body_length)); + } else { + GPR_ASSERT(request.body_length == 0); + } + + va_start(args, expect_body); + i = 0; + for (;;) { + char* expect_key; + char* expect_value; + expect_key = va_arg(args, char*); + if (!expect_key) break; + GPR_ASSERT(i < request.hdr_count); + expect_value = va_arg(args, char*); + GPR_ASSERT(expect_value); + GPR_ASSERT(0 == strcmp(expect_key, request.hdrs[i].key)); + GPR_ASSERT(0 == strcmp(expect_value, request.hdrs[i].value)); + i++; + } + va_end(args); + GPR_ASSERT(i == request.hdr_count); + + grpc_http_request_destroy(&request); + grpc_http_parser_destroy(&parser); + gpr_free(slices); +} + +static void test_succeeds(grpc_slice_split_mode split_mode, + const char* response_text, int expect_status, + const char* expect_body, ...) { + grpc_http_parser parser; + grpc_slice input_slice = grpc_slice_from_copied_string(response_text); + size_t num_slices; + size_t i; + grpc_slice* slices; + va_list args; + grpc_http_response response; + response = {}; + + grpc_split_slices(split_mode, &input_slice, 1, &slices, &num_slices); + grpc_slice_unref(input_slice); + + grpc_http_parser_init(&parser, GRPC_HTTP_RESPONSE, &response); + + for (i = 0; i < num_slices; i++) { + GPR_ASSERT(grpc_http_parser_parse(&parser, slices[i], nullptr) == + GRPC_ERROR_NONE); + grpc_slice_unref(slices[i]); + } + GPR_ASSERT(grpc_http_parser_eof(&parser) == GRPC_ERROR_NONE); + + GPR_ASSERT(GRPC_HTTP_RESPONSE == parser.type); + GPR_ASSERT(expect_status == response.status); + if (expect_body != nullptr) { + GPR_ASSERT(strlen(expect_body) == response.body_length); + GPR_ASSERT(0 == memcmp(expect_body, response.body, response.body_length)); + } else { + GPR_ASSERT(response.body_length == 0); + } + + va_start(args, expect_body); + i = 0; + for (;;) { + char* expect_key; + char* expect_value; + expect_key = va_arg(args, char*); + if (!expect_key) break; + GPR_ASSERT(i < response.hdr_count); + expect_value = va_arg(args, char*); + GPR_ASSERT(expect_value); + GPR_ASSERT(0 == strcmp(expect_key, response.hdrs[i].key)); + GPR_ASSERT(0 == strcmp(expect_value, response.hdrs[i].value)); + i++; + } + va_end(args); + GPR_ASSERT(i == response.hdr_count); + + grpc_http_response_destroy(&response); + grpc_http_parser_destroy(&parser); + gpr_free(slices); +} + +static void test_fails(grpc_slice_split_mode split_mode, + const char* response_text) { + grpc_http_parser parser; + grpc_slice input_slice = grpc_slice_from_copied_string(response_text); + size_t num_slices; + size_t i; + grpc_slice* slices; + grpc_error_handle error = 
GRPC_ERROR_NONE; + grpc_http_response response; + response = {}; + + grpc_split_slices(split_mode, &input_slice, 1, &slices, &num_slices); + grpc_slice_unref(input_slice); + + grpc_http_parser_init(&parser, GRPC_HTTP_RESPONSE, &response); + + for (i = 0; i < num_slices; i++) { + if (GRPC_ERROR_NONE == error) { + error = grpc_http_parser_parse(&parser, slices[i], nullptr); + } + grpc_slice_unref(slices[i]); + } + if (GRPC_ERROR_NONE == error) { + error = grpc_http_parser_eof(&parser); + } + GPR_ASSERT(error != GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(error); + + grpc_http_response_destroy(&response); + grpc_http_parser_destroy(&parser); + gpr_free(slices); +} + +static void test_request_fails(grpc_slice_split_mode split_mode, + const char* request_text) { + grpc_http_parser parser; + grpc_slice input_slice = grpc_slice_from_copied_string(request_text); + size_t num_slices; + size_t i; + grpc_slice* slices; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_http_request request; + memset(&request, 0, sizeof(request)); + + grpc_split_slices(split_mode, &input_slice, 1, &slices, &num_slices); + grpc_slice_unref(input_slice); + + grpc_http_parser_init(&parser, GRPC_HTTP_REQUEST, &request); + + for (i = 0; i < num_slices; i++) { + if (error == GRPC_ERROR_NONE) { + error = grpc_http_parser_parse(&parser, slices[i], nullptr); + } + grpc_slice_unref(slices[i]); + } + if (error == GRPC_ERROR_NONE) { + error = grpc_http_parser_eof(&parser); + } + GPR_ASSERT(error != GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(error); + + grpc_http_request_destroy(&request); + grpc_http_parser_destroy(&parser); + gpr_free(slices); +} + +int main(int argc, char** argv) { + size_t i; + const grpc_slice_split_mode split_modes[] = {GRPC_SLICE_SPLIT_IDENTITY, + GRPC_SLICE_SPLIT_ONE_BYTE}; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + for (i = 0; i < GPR_ARRAY_SIZE(split_modes); i++) { + test_succeeds(split_modes[i], + "HTTP/1.0 200 OK\r\n" + "xyz: abc\r\n" + "\r\n" + "hello world!", + 200, "hello world!", "xyz", "abc", NULL); + test_succeeds(split_modes[i], + "HTTP/1.0 404 Not Found\r\n" + "\r\n", + 404, nullptr, NULL); + test_succeeds(split_modes[i], + "HTTP/1.1 200 OK\r\n" + "xyz: abc\r\n" + "\r\n" + "hello world!", + 200, "hello world!", "xyz", "abc", NULL); + test_succeeds(split_modes[i], + "HTTP/1.1 200 OK\n" + "\n" + "abc", + 200, "abc", NULL); + test_request_succeeds(split_modes[i], + "GET / HTTP/1.0\r\n" + "\r\n", + "GET", GRPC_HTTP_HTTP10, "/", nullptr, NULL); + test_request_succeeds(split_modes[i], + "GET / HTTP/1.0\r\n" + "\r\n" + "xyz", + "GET", GRPC_HTTP_HTTP10, "/", "xyz", NULL); + test_request_succeeds(split_modes[i], + "GET / HTTP/1.1\r\n" + "\r\n" + "xyz", + "GET", GRPC_HTTP_HTTP11, "/", "xyz", NULL); + test_request_succeeds(split_modes[i], + "GET / HTTP/2.0\r\n" + "\r\n" + "xyz", + "GET", GRPC_HTTP_HTTP20, "/", "xyz", NULL); + test_request_succeeds(split_modes[i], + "GET / HTTP/1.0\r\n" + "xyz: abc\r\n" + "\r\n" + "xyz", + "GET", GRPC_HTTP_HTTP10, "/", "xyz", "xyz", "abc", + NULL); + test_request_succeeds(split_modes[i], + "GET / HTTP/1.0\n" + "\n" + "xyz", + "GET", GRPC_HTTP_HTTP10, "/", "xyz", NULL); + test_fails(split_modes[i], "HTTP/1.0\r\n"); + test_fails(split_modes[i], "HTTP/1.2\r\n"); + test_fails(split_modes[i], "HTTP/1.0 000 XYX\r\n"); + test_fails(split_modes[i], "HTTP/1.0 200 OK\n"); + test_fails(split_modes[i], "HTTP/1.0 200 OK\r\n"); + test_fails(split_modes[i], "HTTP/1.0 200 OK\r\nFoo x\r\n"); + test_fails(split_modes[i], + "HTTP/1.0 200 OK\r\n" + "xyz: abc\r\n" + " def\r\n" 
+ "\r\n" + "hello world!"); + test_request_fails(split_modes[i], "GET\r\n"); + test_request_fails(split_modes[i], "GET /\r\n"); + test_request_fails(split_modes[i], "GET / HTTP/0.0\r\n"); + test_request_fails(split_modes[i], "GET / ____/1.0\r\n"); + test_request_fails(split_modes[i], "GET / HTTP/1.2\r\n"); + test_request_fails(split_modes[i], "GET / HTTP/1.0\n"); + + char* tmp1 = + static_cast(gpr_malloc(2 * GRPC_HTTP_PARSER_MAX_HEADER_LENGTH)); + memset(tmp1, 'a', 2 * GRPC_HTTP_PARSER_MAX_HEADER_LENGTH - 1); + tmp1[2 * GRPC_HTTP_PARSER_MAX_HEADER_LENGTH - 1] = 0; + std::string tmp2 = + absl::StrFormat("HTTP/1.0 200 OK\r\nxyz: %s\r\n\r\n", tmp1); + gpr_free(tmp1); + test_fails(split_modes[i], tmp2.c_str()); + } + + grpc_shutdown(); + return 0; +} diff --git a/test/core/http/request_fuzzer.cc b/test/core/http/request_fuzzer.cc new file mode 100644 index 00000000..9798cfb3 --- /dev/null +++ b/test/core/http/request_fuzzer.cc @@ -0,0 +1,45 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include + +#include "src/core/lib/http/parser.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_http_parser parser; + grpc_http_request request; + grpc_init(); + memset(&request, 0, sizeof(request)); + grpc_http_parser_init(&parser, GRPC_HTTP_REQUEST, &request); + grpc_slice slice = grpc_slice_from_copied_buffer((const char*)data, size); + GRPC_ERROR_UNREF(grpc_http_parser_parse(&parser, slice, nullptr)); + GRPC_ERROR_UNREF(grpc_http_parser_eof(&parser)); + grpc_slice_unref(slice); + grpc_http_parser_destroy(&parser); + grpc_http_request_destroy(&request); + grpc_shutdown(); + return 0; +} diff --git a/test/core/http/response_fuzzer.cc b/test/core/http/response_fuzzer.cc new file mode 100644 index 00000000..cf82ccfe --- /dev/null +++ b/test/core/http/response_fuzzer.cc @@ -0,0 +1,44 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include + +#include "src/core/lib/http/parser.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_http_parser parser; + grpc_http_response response; + grpc_init(); + response = {}; + grpc_http_parser_init(&parser, GRPC_HTTP_RESPONSE, &response); + grpc_slice slice = grpc_slice_from_copied_buffer((const char*)data, size); + GRPC_ERROR_UNREF(grpc_http_parser_parse(&parser, slice, nullptr)); + GRPC_ERROR_UNREF(grpc_http_parser_eof(&parser)); + grpc_slice_unref(slice); + grpc_http_parser_destroy(&parser); + grpc_http_response_destroy(&response); + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/buffer_list_test.cc b/test/core/iomgr/buffer_list_test.cc new file mode 100644 index 00000000..c7ac823c --- /dev/null +++ b/test/core/iomgr/buffer_list_test.cc @@ -0,0 +1,135 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/buffer_list.h" + +#include + +#include "src/core/lib/iomgr/port.h" +#include "test/core/util/test_config.h" + +#ifdef GRPC_LINUX_ERRQUEUE + +static void TestShutdownFlushesListVerifier(void* arg, + grpc_core::Timestamps* /*ts*/, + grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(arg != nullptr); + gpr_atm* done = reinterpret_cast(arg); + gpr_atm_rel_store(done, static_cast(1)); +} + +/** Tests that all TracedBuffer elements in the list are flushed out on + * shutdown. + * Also tests that arg is passed correctly. + */ +static void TestShutdownFlushesList() { + grpc_core::grpc_tcp_set_write_timestamps_callback( + TestShutdownFlushesListVerifier); + grpc_core::TracedBuffer* list = nullptr; +#define NUM_ELEM 5 + gpr_atm verifier_called[NUM_ELEM]; + for (auto i = 0; i < NUM_ELEM; i++) { + gpr_atm_rel_store(&verifier_called[i], static_cast(0)); + grpc_core::TracedBuffer::AddNewEntry( + &list, i, 0, static_cast(&verifier_called[i])); + } + grpc_core::TracedBuffer::Shutdown(&list, nullptr, GRPC_ERROR_NONE); + GPR_ASSERT(list == nullptr); + for (auto i = 0; i < NUM_ELEM; i++) { + GPR_ASSERT(gpr_atm_acq_load(&verifier_called[i]) == + static_cast(1)); + } +} + +static void TestVerifierCalledOnAckVerifier(void* arg, + grpc_core::Timestamps* ts, + grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(arg != nullptr); + GPR_ASSERT(ts->acked_time.time.clock_type == GPR_CLOCK_REALTIME); + GPR_ASSERT(ts->acked_time.time.tv_sec == 123); + GPR_ASSERT(ts->acked_time.time.tv_nsec == 456); + GPR_ASSERT(ts->info.length > 0); + gpr_atm* done = reinterpret_cast(arg); + gpr_atm_rel_store(done, static_cast(1)); +} + +/** Tests that the timestamp verifier is called on an ACK timestamp. 
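+ * The test fabricates the kernel-side inputs: a sock_extended_err with
+ * ee_info == SCM_TSTAMP_ACK and a scm_timestamping whose first entry is
+ * {tv_sec = 123, tv_nsec = 456}. ProcessTimestamp() should hand exactly those
+ * values to the registered verifier via Timestamps::acked_time, and the
+ * verifier flags an atomic so the test can assert that it actually ran.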
+ */ +static void TestVerifierCalledOnAck() { + struct sock_extended_err serr; + serr.ee_data = 213; + serr.ee_info = grpc_core::SCM_TSTAMP_ACK; + struct grpc_core::scm_timestamping tss; + tss.ts[0].tv_sec = 123; + tss.ts[0].tv_nsec = 456; + grpc_core::grpc_tcp_set_write_timestamps_callback( + TestVerifierCalledOnAckVerifier); + grpc_core::TracedBuffer* list = nullptr; + gpr_atm verifier_called; + gpr_atm_rel_store(&verifier_called, static_cast(0)); + grpc_core::TracedBuffer::AddNewEntry(&list, 213, 0, &verifier_called); + grpc_core::TracedBuffer::ProcessTimestamp(&list, &serr, nullptr, &tss); + GPR_ASSERT(gpr_atm_acq_load(&verifier_called) == static_cast(1)); + GPR_ASSERT(list == nullptr); + grpc_core::TracedBuffer::Shutdown(&list, nullptr, GRPC_ERROR_NONE); +} + +/** Tests that shutdown can be called repeatedly. + */ +static void TestRepeatedShutdown() { + struct sock_extended_err serr; + serr.ee_data = 213; + serr.ee_info = grpc_core::SCM_TSTAMP_ACK; + struct grpc_core::scm_timestamping tss; + tss.ts[0].tv_sec = 123; + tss.ts[0].tv_nsec = 456; + grpc_core::grpc_tcp_set_write_timestamps_callback( + TestVerifierCalledOnAckVerifier); + grpc_core::TracedBuffer* list = nullptr; + gpr_atm verifier_called; + gpr_atm_rel_store(&verifier_called, static_cast(0)); + grpc_core::TracedBuffer::AddNewEntry(&list, 213, 0, &verifier_called); + grpc_core::TracedBuffer::ProcessTimestamp(&list, &serr, nullptr, &tss); + GPR_ASSERT(gpr_atm_acq_load(&verifier_called) == static_cast(1)); + GPR_ASSERT(list == nullptr); + grpc_core::TracedBuffer::Shutdown(&list, nullptr, GRPC_ERROR_NONE); + grpc_core::TracedBuffer::Shutdown(&list, nullptr, GRPC_ERROR_NONE); + grpc_core::TracedBuffer::Shutdown(&list, nullptr, GRPC_ERROR_NONE); +} + +static void TestTcpBufferList() { + TestVerifierCalledOnAck(); + TestShutdownFlushesList(); + TestRepeatedShutdown(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + TestTcpBufferList(); + grpc_shutdown(); + return 0; +} + +#else /* GRPC_LINUX_ERRQUEUE */ + +int main(int /*argc*/, char** /*argv*/) { return 0; } + +#endif /* GRPC_LINUX_ERRQUEUE */ diff --git a/test/core/iomgr/combiner_test.cc b/test/core/iomgr/combiner_test.cc new file mode 100644 index 00000000..7aa0e1c0 --- /dev/null +++ b/test/core/iomgr/combiner_test.cc @@ -0,0 +1,149 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/combiner.h" + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +static void test_no_op(void) { + gpr_log(GPR_DEBUG, "test_no_op"); + grpc_core::ExecCtx exec_ctx; + GRPC_COMBINER_UNREF(grpc_combiner_create(), "test_no_op"); +} + +static void set_event_to_true(void* value, grpc_error_handle /*error*/) { + gpr_event_set(static_cast(value), reinterpret_cast(1)); +} + +static void test_execute_one(void) { + gpr_log(GPR_DEBUG, "test_execute_one"); + + grpc_core::Combiner* lock = grpc_combiner_create(); + gpr_event done; + gpr_event_init(&done); + grpc_core::ExecCtx exec_ctx; + lock->Run(GRPC_CLOSURE_CREATE(set_event_to_true, &done, nullptr), + GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&done, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + GRPC_COMBINER_UNREF(lock, "test_execute_one"); +} + +typedef struct { + size_t ctr; + grpc_core::Combiner* lock; + gpr_event done; +} thd_args; + +typedef struct { + size_t* ctr; + size_t value; +} ex_args; + +static void check_one(void* a, grpc_error_handle /*error*/) { + ex_args* args = static_cast(a); + GPR_ASSERT(*args->ctr == args->value - 1); + *args->ctr = args->value; + gpr_free(a); +} + +static void execute_many_loop(void* a) { + thd_args* args = static_cast(a); + grpc_core::ExecCtx exec_ctx; + size_t n = 1; + for (size_t i = 0; i < 10; i++) { + for (size_t j = 0; j < 10000; j++) { + ex_args* c = static_cast(gpr_malloc(sizeof(*c))); + c->ctr = &args->ctr; + c->value = n++; + args->lock->Run(GRPC_CLOSURE_CREATE(check_one, c, nullptr), + GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + // sleep for a little bit, to test a combiner draining and another thread + // picking it up + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); + } + args->lock->Run(GRPC_CLOSURE_CREATE(set_event_to_true, &args->done, nullptr), + GRPC_ERROR_NONE); +} + +static void test_execute_many(void) { + gpr_log(GPR_DEBUG, "test_execute_many"); + + grpc_core::Combiner* lock = grpc_combiner_create(); + grpc_core::Thread thds[100]; + thd_args ta[GPR_ARRAY_SIZE(thds)]; + for (size_t i = 0; i < GPR_ARRAY_SIZE(thds); i++) { + ta[i].ctr = 0; + ta[i].lock = lock; + gpr_event_init(&ta[i].done); + thds[i] = grpc_core::Thread("grpc_execute_many", execute_many_loop, &ta[i]); + thds[i].Start(); + } + for (size_t i = 0; i < GPR_ARRAY_SIZE(thds); i++) { + GPR_ASSERT(gpr_event_wait(&ta[i].done, + gpr_inf_future(GPR_CLOCK_REALTIME)) != nullptr); + thds[i].Join(); + } + grpc_core::ExecCtx exec_ctx; + GRPC_COMBINER_UNREF(lock, "test_execute_many"); +} + +static gpr_event got_in_finally; + +static void in_finally(void* /*arg*/, grpc_error_handle /*error*/) { + gpr_event_set(&got_in_finally, reinterpret_cast(1)); +} + +static void add_finally(void* arg, grpc_error_handle /*error*/) { + static_cast(arg)->Run( + GRPC_CLOSURE_CREATE(in_finally, arg, nullptr), GRPC_ERROR_NONE); +} + +static void test_execute_finally(void) { + gpr_log(GPR_DEBUG, "test_execute_finally"); + + grpc_core::Combiner* lock = grpc_combiner_create(); + grpc_core::ExecCtx exec_ctx; + gpr_event_init(&got_in_finally); + lock->Run(GRPC_CLOSURE_CREATE(add_finally, lock, nullptr), GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&got_in_finally, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GRPC_COMBINER_UNREF(lock, "test_execute_finally"); +} + +int main(int argc, char** argv) { + 
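+  // The tests above all follow the same basic Combiner pattern; a rough
+  // sketch mirroring test_execute_one (cb and arg are placeholders):
+  //
+  //   grpc_core::ExecCtx exec_ctx;
+  //   grpc_core::Combiner* lock = grpc_combiner_create();
+  //   lock->Run(GRPC_CLOSURE_CREATE(cb, arg, nullptr), GRPC_ERROR_NONE);
+  //   grpc_core::ExecCtx::Get()->Flush();
+  //   GRPC_COMBINER_UNREF(lock, "done");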
grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_no_op(); + test_execute_one(); + test_execute_finally(); + test_execute_many(); + grpc_shutdown(); + + return 0; +} diff --git a/test/core/iomgr/endpoint_pair_test.cc b/test/core/iomgr/endpoint_pair_test.cc new file mode 100644 index 00000000..d97a3c78 --- /dev/null +++ b/test/core/iomgr/endpoint_pair_test.cc @@ -0,0 +1,78 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/endpoint_pair.h" + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/iomgr/endpoint_tests.h" +#include "test/core/util/test_config.h" + +static gpr_mu* g_mu; +static grpc_pollset* g_pollset; + +static void clean_up(void) {} + +static grpc_endpoint_test_fixture create_fixture_endpoint_pair( + size_t slice_size) { + grpc_core::ExecCtx exec_ctx; + grpc_endpoint_test_fixture f; + grpc_arg a[1]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER; + a[0].value.integer = static_cast(slice_size); + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + grpc_endpoint_pair p = grpc_iomgr_create_endpoint_pair("test", &args); + f.client_ep = p.client; + f.server_ep = p.server; + grpc_endpoint_add_to_pollset(f.client_ep, g_pollset); + grpc_endpoint_add_to_pollset(f.server_ep, g_pollset); + + return f; +} + +static grpc_endpoint_test_config configs[] = { + {"tcp/tcp_socketpair", create_fixture_endpoint_pair, clean_up}, +}; + +static void destroy_pollset(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(p)); +} + +int main(int argc, char** argv) { + grpc_closure destroyed; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + g_pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(g_pollset, &g_mu); + grpc_endpoint_tests(configs[0], g_pollset, g_mu); + GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(g_pollset, &destroyed); + } + grpc_shutdown(); + gpr_free(g_pollset); + + return 0; +} diff --git a/test/core/iomgr/endpoint_tests.cc b/test/core/iomgr/endpoint_tests.cc new file mode 100644 index 00000000..2a0e5a18 --- /dev/null +++ b/test/core/iomgr/endpoint_tests.cc @@ -0,0 +1,355 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/core/iomgr/endpoint_tests.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +/* + General test notes: + + All tests which write data into an endpoint write i%256 into byte i, which + is verified by readers. + + In general there are a few interesting things to vary which may lead to + exercising different codepaths in an implementation: + 1. Total amount of data written to the endpoint + 2. Size of slice allocations + 3. Amount of data we read from or write to the endpoint at once + + The tests here tend to parameterize these where applicable. + +*/ + +static gpr_mu* g_mu; +static grpc_pollset* g_pollset; + +size_t count_slices(grpc_slice* slices, size_t nslices, int* current_data) { + size_t num_bytes = 0; + size_t i; + size_t j; + unsigned char* buf; + for (i = 0; i < nslices; ++i) { + buf = GRPC_SLICE_START_PTR(slices[i]); + for (j = 0; j < GRPC_SLICE_LENGTH(slices[i]); ++j) { + GPR_ASSERT(buf[j] == *current_data); + *current_data = (*current_data + 1) % 256; + } + num_bytes += GRPC_SLICE_LENGTH(slices[i]); + } + return num_bytes; +} + +static grpc_endpoint_test_fixture begin_test(grpc_endpoint_test_config config, + const char* test_name, + size_t slice_size) { + gpr_log(GPR_INFO, "%s/%s", test_name, config.name); + return config.create_fixture(slice_size); +} + +static void end_test(grpc_endpoint_test_config config) { config.clean_up(); } + +static grpc_slice* allocate_blocks(size_t num_bytes, size_t slice_size, + size_t* num_blocks, uint8_t* current_data) { + size_t nslices = num_bytes / slice_size + (num_bytes % slice_size ? 1 : 0); + grpc_slice* slices = + static_cast(gpr_malloc(sizeof(grpc_slice) * nslices)); + size_t num_bytes_left = num_bytes; + size_t i; + size_t j; + unsigned char* buf; + *num_blocks = nslices; + + for (i = 0; i < nslices; ++i) { + slices[i] = grpc_slice_malloc(slice_size > num_bytes_left ? 
num_bytes_left + : slice_size); + num_bytes_left -= GRPC_SLICE_LENGTH(slices[i]); + buf = GRPC_SLICE_START_PTR(slices[i]); + for (j = 0; j < GRPC_SLICE_LENGTH(slices[i]); ++j) { + buf[j] = *current_data; + (*current_data)++; + } + } + GPR_ASSERT(num_bytes_left == 0); + return slices; +} + +struct read_and_write_test_state { + grpc_endpoint* read_ep; + grpc_endpoint* write_ep; + size_t target_bytes; + size_t bytes_read; + size_t current_write_size; + size_t bytes_written; + int current_read_data; + uint8_t current_write_data; + int read_done; + int write_done; + grpc_slice_buffer incoming; + grpc_slice_buffer outgoing; + grpc_closure done_read; + grpc_closure done_write; + grpc_closure read_scheduler; + grpc_closure write_scheduler; +}; + +static void read_scheduler(void* data, grpc_error_handle /* error */) { + struct read_and_write_test_state* state = + static_cast(data); + grpc_endpoint_read(state->read_ep, &state->incoming, &state->done_read, + /*urgent=*/false); +} + +static void read_and_write_test_read_handler(void* data, + grpc_error_handle error) { + struct read_and_write_test_state* state = + static_cast(data); + + state->bytes_read += count_slices( + state->incoming.slices, state->incoming.count, &state->current_read_data); + if (state->bytes_read == state->target_bytes || error != GRPC_ERROR_NONE) { + gpr_log(GPR_INFO, "Read handler done"); + gpr_mu_lock(g_mu); + state->read_done = 1 + (error == GRPC_ERROR_NONE); + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr)); + gpr_mu_unlock(g_mu); + } else if (error == GRPC_ERROR_NONE) { + /* We perform many reads one after another. If grpc_endpoint_read and the + * read_handler are both run inline, we might end up growing the stack + * beyond the limit. Schedule the read on ExecCtx to avoid this. */ + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &state->read_scheduler, + GRPC_ERROR_NONE); + } +} + +static void write_scheduler(void* data, grpc_error_handle /* error */) { + struct read_and_write_test_state* state = + static_cast(data); + grpc_endpoint_write(state->write_ep, &state->outgoing, &state->done_write, + nullptr); +} + +static void read_and_write_test_write_handler(void* data, + grpc_error_handle error) { + struct read_and_write_test_state* state = + static_cast(data); + grpc_slice* slices = nullptr; + size_t nslices; + + if (error == GRPC_ERROR_NONE) { + state->bytes_written += state->current_write_size; + if (state->target_bytes - state->bytes_written < + state->current_write_size) { + state->current_write_size = state->target_bytes - state->bytes_written; + } + if (state->current_write_size != 0) { + slices = allocate_blocks(state->current_write_size, 8192, &nslices, + &state->current_write_data); + grpc_slice_buffer_reset_and_unref(&state->outgoing); + grpc_slice_buffer_addn(&state->outgoing, slices, nslices); + /* We perform many writes one after another. If grpc_endpoint_write and + * the write_handler are both run inline, we might end up growing the + * stack beyond the limit. Schedule the write on ExecCtx to avoid this. */ + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &state->write_scheduler, + GRPC_ERROR_NONE); + gpr_free(slices); + return; + } + } + + gpr_log(GPR_INFO, "Write handler done"); + gpr_mu_lock(g_mu); + state->write_done = 1 + (error == GRPC_ERROR_NONE); + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr)); + gpr_mu_unlock(g_mu); +} + +/* Do both reading and writing using the grpc_endpoint API. + + This also includes a test of the shutdown behavior. 
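+
+   Parameters: num_bytes is the total payload moved from write_ep to read_ep,
+   write_size is how much is handed to grpc_endpoint_write per iteration,
+   slice_size is the slice allocation the fixture is created with, and
+   shutdown asks the test to shut both endpoints down while the transfer is
+   still in flight.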
+ */ +static void read_and_write_test(grpc_endpoint_test_config config, + size_t num_bytes, size_t write_size, + size_t slice_size, bool shutdown) { + struct read_and_write_test_state state; + grpc_endpoint_test_fixture f = + begin_test(config, "read_and_write_test", slice_size); + grpc_core::ExecCtx exec_ctx; + grpc_millis deadline = + grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(20)); + gpr_log(GPR_DEBUG, + "num_bytes=%" PRIuPTR " write_size=%" PRIuPTR " slice_size=%" PRIuPTR + " shutdown=%d", + num_bytes, write_size, slice_size, shutdown); + + if (shutdown) { + gpr_log(GPR_INFO, "Start read and write shutdown test"); + } else { + gpr_log(GPR_INFO, + "Start read and write test with %" PRIuPTR + " bytes, slice size %" PRIuPTR, + num_bytes, slice_size); + } + + state.read_ep = f.client_ep; + state.write_ep = f.server_ep; + state.target_bytes = num_bytes; + state.bytes_read = 0; + state.current_write_size = write_size; + state.bytes_written = 0; + state.read_done = 0; + state.write_done = 0; + state.current_read_data = 0; + state.current_write_data = 0; + GRPC_CLOSURE_INIT(&state.read_scheduler, read_scheduler, &state, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&state.done_read, read_and_write_test_read_handler, &state, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&state.write_scheduler, write_scheduler, &state, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&state.done_write, read_and_write_test_write_handler, + &state, grpc_schedule_on_exec_ctx); + grpc_slice_buffer_init(&state.outgoing); + grpc_slice_buffer_init(&state.incoming); + + /* Get started by pretending an initial write completed */ + /* NOTE: Sets up initial conditions so we can have the same write handler + for the first iteration as for later iterations. It does the right thing + even when bytes_written is unsigned. 
*/ + state.bytes_written -= state.current_write_size; + read_and_write_test_write_handler(&state, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + + grpc_endpoint_read(state.read_ep, &state.incoming, &state.done_read, + /*urgent=*/false); + if (shutdown) { + gpr_log(GPR_DEBUG, "shutdown read"); + grpc_endpoint_shutdown( + state.read_ep, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test Shutdown")); + gpr_log(GPR_DEBUG, "shutdown write"); + grpc_endpoint_shutdown( + state.write_ep, GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test Shutdown")); + } + grpc_core::ExecCtx::Get()->Flush(); + + gpr_mu_lock(g_mu); + while (!state.read_done || !state.write_done) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(grpc_core::ExecCtx::Get()->Now() < deadline); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(g_pollset, &worker, deadline))); + } + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + + end_test(config); + grpc_slice_buffer_destroy_internal(&state.outgoing); + grpc_slice_buffer_destroy_internal(&state.incoming); + grpc_endpoint_destroy(state.read_ep); + grpc_endpoint_destroy(state.write_ep); +} + +static void inc_on_failure(void* arg, grpc_error_handle error) { + gpr_mu_lock(g_mu); + *static_cast(arg) += (error != GRPC_ERROR_NONE); + GPR_ASSERT(GRPC_LOG_IF_ERROR("kick", grpc_pollset_kick(g_pollset, nullptr))); + gpr_mu_unlock(g_mu); +} + +static void wait_for_fail_count(int* fail_count, int want_fail_count) { + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + grpc_millis deadline = + grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(10)); + while (grpc_core::ExecCtx::Get()->Now() < deadline && + *fail_count < want_fail_count) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(g_pollset, &worker, deadline))); + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + } + GPR_ASSERT(*fail_count == want_fail_count); + gpr_mu_unlock(g_mu); +} + +static void multiple_shutdown_test(grpc_endpoint_test_config config) { + grpc_endpoint_test_fixture f = + begin_test(config, "multiple_shutdown_test", 128); + int fail_count = 0; + + grpc_slice_buffer slice_buffer; + grpc_slice_buffer_init(&slice_buffer); + + grpc_core::ExecCtx exec_ctx; + grpc_endpoint_add_to_pollset(f.client_ep, g_pollset); + grpc_endpoint_read(f.client_ep, &slice_buffer, + GRPC_CLOSURE_CREATE(inc_on_failure, &fail_count, + grpc_schedule_on_exec_ctx), + /*urgent=*/false); + wait_for_fail_count(&fail_count, 0); + grpc_endpoint_shutdown(f.client_ep, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test Shutdown")); + wait_for_fail_count(&fail_count, 1); + grpc_endpoint_read(f.client_ep, &slice_buffer, + GRPC_CLOSURE_CREATE(inc_on_failure, &fail_count, + grpc_schedule_on_exec_ctx), + /*urgent=*/false); + wait_for_fail_count(&fail_count, 2); + grpc_slice_buffer_add(&slice_buffer, grpc_slice_from_copied_string("a")); + grpc_endpoint_write(f.client_ep, &slice_buffer, + GRPC_CLOSURE_CREATE(inc_on_failure, &fail_count, + grpc_schedule_on_exec_ctx), + nullptr); + wait_for_fail_count(&fail_count, 3); + grpc_endpoint_shutdown(f.client_ep, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test Shutdown")); + wait_for_fail_count(&fail_count, 3); + + grpc_slice_buffer_destroy_internal(&slice_buffer); + + grpc_endpoint_destroy(f.client_ep); + grpc_endpoint_destroy(f.server_ep); +} + +void grpc_endpoint_tests(grpc_endpoint_test_config config, + grpc_pollset* pollset, gpr_mu* mu) { + size_t i; + g_pollset = pollset; + g_mu = mu; + 
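Editor's note: every read/write test in this file leans on the byte-pattern convention from the general test notes above, namely that byte i of the stream is i % 256, produced by allocate_blocks() and checked by count_slices(). The standalone sketch below is illustrative only (fill_pattern/verify_pattern are hypothetical helper names, not part of the test) and shows the same fill-and-verify round trip on plain buffers.

// Illustrative sketch only -- mirrors the (i % 256) pattern used by
// allocate_blocks()/count_slices(); not part of the original test.
#include <cassert>
#include <cstdint>
#include <vector>

// Fill `buf`, continuing the rolling byte pattern tracked by *counter.
static void fill_pattern(std::vector<uint8_t>* buf, int* counter) {
  for (auto& b : *buf) {
    b = static_cast<uint8_t>(*counter);
    *counter = (*counter + 1) % 256;
  }
}

// Verify `buf` against the same rolling pattern; returns bytes checked.
static size_t verify_pattern(const std::vector<uint8_t>& buf, int* counter) {
  for (auto b : buf) {
    assert(b == static_cast<uint8_t>(*counter));
    *counter = (*counter + 1) % 256;
  }
  return buf.size();
}

int main() {
  int write_counter = 0;
  int read_counter = 0;
  std::vector<uint8_t> chunk(1000);
  size_t verified = 0;
  for (int i = 0; i < 3; ++i) {  // three "writes" of 1000 bytes each
    fill_pattern(&chunk, &write_counter);
    verified += verify_pattern(chunk, &read_counter);
  }
  assert(verified == 3000);
  return 0;
}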
multiple_shutdown_test(config); + read_and_write_test(config, 10000000, 100000, 8192, false); + read_and_write_test(config, 1000000, 100000, 1, false); + read_and_write_test(config, 100000000, 100000, 1, true); + for (i = 1; i < 1000; i = std::max(i + 1, i * 5 / 4)) { + read_and_write_test(config, 40320, i, i, false); + } + g_pollset = nullptr; + g_mu = nullptr; +} diff --git a/test/core/iomgr/error_test.cc b/test/core/iomgr/error_test.cc new file mode 100644 index 00000000..cbe80f29 --- /dev/null +++ b/test/core/iomgr/error_test.cc @@ -0,0 +1,222 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/error.h" + +#include + +#include + +#include +#include +#include + +#include "test/core/util/test_config.h" + +TEST(ErrorTest, SetGetInt) { + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test"); + EXPECT_NE(error, GRPC_ERROR_NONE); + intptr_t i = 0; +#ifndef NDEBUG + // GRPC_ERROR_INT_FILE_LINE is for debug only + EXPECT_TRUE(grpc_error_get_int(error, GRPC_ERROR_INT_FILE_LINE, &i)); + EXPECT_TRUE(i); // line set will never be 0 +#endif + EXPECT_TRUE(!grpc_error_get_int(error, GRPC_ERROR_INT_ERRNO, &i)); + EXPECT_TRUE(!grpc_error_get_int(error, GRPC_ERROR_INT_SIZE, &i)); + + intptr_t errnumber = 314; + error = grpc_error_set_int(error, GRPC_ERROR_INT_ERRNO, errnumber); + EXPECT_TRUE(grpc_error_get_int(error, GRPC_ERROR_INT_ERRNO, &i)); + EXPECT_EQ(i, errnumber); + + intptr_t http = 2; + error = grpc_error_set_int(error, GRPC_ERROR_INT_HTTP2_ERROR, http); + EXPECT_TRUE(grpc_error_get_int(error, GRPC_ERROR_INT_HTTP2_ERROR, &i)); + EXPECT_EQ(i, http); + + GRPC_ERROR_UNREF(error); +} + +TEST(ErrorTest, SetGetStr) { + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test"); + + std::string str; + EXPECT_TRUE(!grpc_error_get_str(error, GRPC_ERROR_STR_SYSCALL, &str)); + EXPECT_TRUE(!grpc_error_get_str(error, GRPC_ERROR_STR_TSI_ERROR, &str)); +#ifndef NDEBUG + // GRPC_ERROR_STR_FILE is for debug only + EXPECT_TRUE(grpc_error_get_str(error, GRPC_ERROR_STR_FILE, &str)); + EXPECT_THAT(str, testing::HasSubstr("error_test.c")); + // __FILE__ expands differently on + // Windows. 
All should at least + // contain error_test.c +#endif + EXPECT_TRUE(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, &str)); + EXPECT_EQ(str, "Test"); + + error = + grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, "longer message"); + EXPECT_TRUE(grpc_error_get_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, &str)); + EXPECT_EQ(str, "longer message"); + + GRPC_ERROR_UNREF(error); +} + +TEST(ErrorTest, CopyAndUnRef) { + // error1 has one ref + grpc_error_handle error1 = + grpc_error_set_str(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Test"), + GRPC_ERROR_STR_GRPC_MESSAGE, "message"); + std::string str; + EXPECT_TRUE(grpc_error_get_str(error1, GRPC_ERROR_STR_GRPC_MESSAGE, &str)); + EXPECT_EQ(str, "message"); + + // error 1 has two refs + (void)GRPC_ERROR_REF(error1); + // this gives error3 a ref to the new error, and decrements error1 to one ref + grpc_error_handle error3 = + grpc_error_set_str(error1, GRPC_ERROR_STR_SYSCALL, "syscall"); + EXPECT_NE(error3, error1); // should not be the same because of extra ref + EXPECT_TRUE(grpc_error_get_str(error3, GRPC_ERROR_STR_GRPC_MESSAGE, &str)); + EXPECT_EQ(str, "message"); + + // error 1 should not have a syscall but 3 should + EXPECT_TRUE(!grpc_error_get_str(error1, GRPC_ERROR_STR_SYSCALL, &str)); + EXPECT_TRUE(grpc_error_get_str(error3, GRPC_ERROR_STR_SYSCALL, &str)); + EXPECT_EQ(str, "syscall"); + + GRPC_ERROR_UNREF(error1); + GRPC_ERROR_UNREF(error3); +} + +TEST(ErrorTest, CreateReferencing) { + grpc_error_handle child = + grpc_error_set_str(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Child"), + GRPC_ERROR_STR_GRPC_MESSAGE, "message"); + grpc_error_handle parent = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING("Parent", &child, 1); + EXPECT_NE(parent, GRPC_ERROR_NONE); + + GRPC_ERROR_UNREF(child); + GRPC_ERROR_UNREF(parent); +} + +TEST(ErrorTest, CreateReferencingMany) { + grpc_error_handle children[3]; + children[0] = + grpc_error_set_str(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Child1"), + GRPC_ERROR_STR_GRPC_MESSAGE, "message"); + children[1] = + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Child2"), + GRPC_ERROR_INT_HTTP2_ERROR, 5); + children[2] = + grpc_error_set_str(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Child3"), + GRPC_ERROR_STR_GRPC_MESSAGE, "message 3"); + + grpc_error_handle parent = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING("Parent", children, 3); + EXPECT_NE(parent, GRPC_ERROR_NONE); + + for (size_t i = 0; i < 3; ++i) { + GRPC_ERROR_UNREF(children[i]); + } + GRPC_ERROR_UNREF(parent); +} + +TEST(ErrorTest, PrintErrorString) { + grpc_error_handle error = + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNIMPLEMENTED); + error = grpc_error_set_int(error, GRPC_ERROR_INT_SIZE, 666); + error = grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, "message"); + // gpr_log(GPR_DEBUG, "%s", grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); +} + +TEST(ErrorTest, PrintErrorStringReference) { + grpc_error_handle children[2]; + children[0] = grpc_error_set_str( + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("1"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNIMPLEMENTED), + GRPC_ERROR_STR_GRPC_MESSAGE, "message for child 1"); + children[1] = grpc_error_set_str( + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("2sd"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_INTERNAL), + GRPC_ERROR_STR_GRPC_MESSAGE, "message for child 2"); + + grpc_error_handle parent = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING("Parent", children, 2); + + for 
(size_t i = 0; i < 2; ++i) { + GRPC_ERROR_UNREF(children[i]); + } + GRPC_ERROR_UNREF(parent); +} + +TEST(ErrorTest, TestOsError) { + int fake_errno = 5; + const char* syscall = "syscall name"; + grpc_error_handle error = GRPC_OS_ERROR(fake_errno, syscall); + + intptr_t i = 0; + EXPECT_TRUE(grpc_error_get_int(error, GRPC_ERROR_INT_ERRNO, &i)); + EXPECT_EQ(i, fake_errno); + + std::string str; + EXPECT_TRUE(grpc_error_get_str(error, GRPC_ERROR_STR_SYSCALL, &str)); + EXPECT_EQ(str, syscall); + GRPC_ERROR_UNREF(error); +} + +TEST(ErrorTest, Overflow) { + // absl::Status doesn't have a limit so there is no overflow +#ifndef GRPC_ERROR_IS_ABSEIL_STATUS + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Overflow"); + + for (size_t i = 0; i < 150; ++i) { + error = grpc_error_add_child(error, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Child")); + } + + error = grpc_error_set_int(error, GRPC_ERROR_INT_HTTP2_ERROR, 5); + error = grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, + "message for child 2"); + error = grpc_error_set_int(error, GRPC_ERROR_INT_GRPC_STATUS, 5); + + intptr_t i; + EXPECT_TRUE(grpc_error_get_int(error, GRPC_ERROR_INT_HTTP2_ERROR, &i)); + EXPECT_EQ(i, 5); + EXPECT_TRUE(!grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &i)); + + error = grpc_error_set_int(error, GRPC_ERROR_INT_HTTP2_ERROR, 10); + EXPECT_TRUE(grpc_error_get_int(error, GRPC_ERROR_INT_HTTP2_ERROR, &i)); + EXPECT_EQ(i, 10); + + GRPC_ERROR_UNREF(error); +#endif +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int retval = RUN_ALL_TESTS(); + grpc_shutdown(); + return retval; +} diff --git a/test/core/iomgr/ev_epollex_linux_test.cc b/test/core/iomgr/ev_epollex_linux_test.cc new file mode 100644 index 00000000..5656efa7 --- /dev/null +++ b/test/core/iomgr/ev_epollex_linux_test.cc @@ -0,0 +1,115 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include "src/core/lib/iomgr/port.h" + +/* This test only relevant on linux systems where epoll() is available */ +#if defined(GRPC_LINUX_EPOLL_CREATE1) && defined(GRPC_LINUX_EVENTFD) +#include +#include + +#include + +#include "src/core/lib/iomgr/ev_epollex_linux.h" +#include "test/core/util/test_config.h" + +static void pollset_destroy(void* ps, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(ps)); + gpr_free(ps); +} + +// This test is added to cover the case found in bug: +// https://github.com/grpc/grpc/issues/15760 +static void test_pollable_owner_fd() { + grpc_core::ExecCtx exec_ctx; + int ev_fd1; + int ev_fd2; + grpc_fd* grpc_fd1; + grpc_fd* grpc_fd2; + grpc_pollset* ps; + gpr_mu* mu; + + // == Create two grpc_fds == + // All we need is two file descriptors. Doesn't matter what type. 
We use + // eventfd type here for the purpose of this test + ev_fd1 = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + ev_fd2 = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + if (ev_fd1 < 0 || ev_fd2 < 0) { + gpr_log(GPR_ERROR, "Error in creating event fds for the test"); + return; + } + grpc_fd1 = grpc_fd_create(ev_fd1, "epollex-test-fd1", false); + grpc_fd2 = grpc_fd_create(ev_fd2, "epollex-test-fd2", false); + grpc_core::ExecCtx::Get()->Flush(); + + // == Create a pollset == + ps = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(ps, &mu); + grpc_core::ExecCtx::Get()->Flush(); + + // == Add fd1 to pollset == + grpc_pollset_add_fd(ps, grpc_fd1); + grpc_core::ExecCtx::Get()->Flush(); + + // == Destroy fd1 == + grpc_fd_orphan(grpc_fd1, nullptr, nullptr, "test fd1 orphan"); + grpc_core::ExecCtx::Get()->Flush(); + + // = Add fd2 to pollset == + // + // Before https://github.com/grpc/grpc/issues/15760, the following line caused + // unexpected behavior (The previous grpc_pollset_add_fd(ps, grpc_fd1) created + // an underlying structure in epollex that held a reference to grpc_fd1 which + // was being accessed here even after grpc_fd_orphan(grpc_fd1) was called + grpc_pollset_add_fd(ps, grpc_fd2); + grpc_core::ExecCtx::Get()->Flush(); + + // == Destroy fd2 == + grpc_fd_orphan(grpc_fd2, nullptr, nullptr, "test fd2 orphan"); + grpc_core::ExecCtx::Get()->Flush(); + + // == Destroy pollset + grpc_closure ps_destroy_closure; + GRPC_CLOSURE_INIT(&ps_destroy_closure, pollset_destroy, ps, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(ps, &ps_destroy_closure); + grpc_core::ExecCtx::Get()->Flush(); +} + +int main(int argc, char** argv) { + const char* poll_strategy = nullptr; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + poll_strategy = grpc_get_poll_strategy_name(); + if (poll_strategy != nullptr && strcmp(poll_strategy, "epollex") == 0) { + test_pollable_owner_fd(); + } else { + gpr_log(GPR_INFO, + "Skipping the test. The test is only relevant for 'epollex' " + "strategy. and the current strategy is: '%s'", + poll_strategy); + } + } + + grpc_shutdown(); + return 0; +} +#else /* defined(GRPC_LINUX_EPOLL_CREATE1) && defined(GRPC_LINUX_EVENTFD) */ +int main(int /*argc*/, char** /*argv*/) { return 0; } +#endif diff --git a/test/core/iomgr/fd_conservation_posix_test.cc b/test/core/iomgr/fd_conservation_posix_test.cc new file mode 100644 index 00000000..8af8949d --- /dev/null +++ b/test/core/iomgr/fd_conservation_posix_test.cc @@ -0,0 +1,52 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "test/core/util/test_config.h" + +int main(int argc, char** argv) { + int i; + struct rlimit rlim; + grpc_endpoint_pair p; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + + /* set max # of file descriptors to a low value, and + verify we can create and destroy many more than this number + of descriptors */ + rlim.rlim_cur = rlim.rlim_max = 10; + GPR_ASSERT(0 == setrlimit(RLIMIT_NOFILE, &rlim)); + for (i = 0; i < 100; i++) { + p = grpc_iomgr_create_endpoint_pair("test", nullptr); + grpc_endpoint_destroy(p.client); + grpc_endpoint_destroy(p.server); + grpc_core::ExecCtx::Get()->Flush(); + } + } + + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/fd_posix_test.cc b/test/core/iomgr/fd_posix_test.cc new file mode 100644 index 00000000..7c2ec737 --- /dev/null +++ b/test/core/iomgr/fd_posix_test.cc @@ -0,0 +1,538 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_EV + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "test/core/util/test_config.h" + +static gpr_mu* g_mu; +static grpc_pollset* g_pollset; + +/* buffer size used to send and receive data. + 1024 is the minimal value to set TCP send and receive buffer. */ +#define BUF_SIZE 1024 + +/* Create a test socket with the right properties for testing. + port is the TCP port to listen or connect to. + Return a socket FD and sockaddr_in. */ +static void create_test_socket(int port, int* socket_fd, + struct sockaddr_in* sin) { + int fd; + int one = 1; + int buffer_size_bytes = BUF_SIZE; + int flags; + + fd = socket(AF_INET, SOCK_STREAM, 0); + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + /* Reset the size of socket send buffer to the minimal value to facilitate + buffer filling up and triggering notify_on_write */ + GPR_ASSERT(grpc_set_socket_sndbuf(fd, buffer_size_bytes) == GRPC_ERROR_NONE); + GPR_ASSERT(grpc_set_socket_rcvbuf(fd, buffer_size_bytes) == GRPC_ERROR_NONE); + /* Make fd non-blocking */ + flags = fcntl(fd, F_GETFL, 0); + GPR_ASSERT(fcntl(fd, F_SETFL, flags | O_NONBLOCK) == 0); + *socket_fd = fd; + + /* Use local address for test */ + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(0x7f000001); + GPR_ASSERT(port >= 0 && port < 65536); + sin->sin_port = htons(static_cast(port)); +} + +/* Phony gRPC callback */ +void no_op_cb(void* /*arg*/, int /*success*/) {} + +/* =======An upload server to test notify_on_read=========== + The server simply reads and counts a stream of bytes. 
*/ + +/* An upload server. */ +typedef struct { + grpc_fd* em_fd; /* listening fd */ + ssize_t read_bytes_total; /* total number of received bytes */ + int done; /* set to 1 when a server finishes serving */ + grpc_closure listen_closure; +} server; + +static void server_init(server* sv) { + sv->read_bytes_total = 0; + sv->done = 0; +} + +/* An upload session. + Created when a new upload request arrives in the server. */ +typedef struct { + server* sv; /* not owned by a single session */ + grpc_fd* em_fd; /* fd to read upload bytes */ + char read_buf[BUF_SIZE]; /* buffer to store upload bytes */ + grpc_closure session_read_closure; +} session; + +/* Called when an upload session can be safely shutdown. + Close session FD and start to shutdown listen FD. */ +static void session_shutdown_cb(void* arg, /*session */ + bool /*success*/) { + session* se = static_cast(arg); + server* sv = se->sv; + grpc_fd_orphan(se->em_fd, nullptr, nullptr, "a"); + gpr_free(se); + /* Start to shutdown listen fd. */ + grpc_fd_shutdown(sv->em_fd, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("session_shutdown_cb")); +} + +/* Called when data become readable in a session. */ +static void session_read_cb(void* arg, /*session */ + grpc_error_handle error) { + session* se = static_cast(arg); + int fd = grpc_fd_wrapped_fd(se->em_fd); + + ssize_t read_once = 0; + ssize_t read_total = 0; + + if (error != GRPC_ERROR_NONE) { + session_shutdown_cb(arg, true); + return; + } + + do { + read_once = read(fd, se->read_buf, BUF_SIZE); + if (read_once > 0) read_total += read_once; + } while (read_once > 0); + se->sv->read_bytes_total += read_total; + + /* read() returns 0 to indicate the TCP connection was closed by the client. + read(fd, read_buf, 0) also returns 0 which should never be called as such. + It is possible to read nothing due to spurious edge event or data has + been drained, In such a case, read() returns -1 and set errno to EAGAIN. */ + if (read_once == 0) { + session_shutdown_cb(arg, true); + } else if (read_once == -1) { + if (errno == EAGAIN) { + /* An edge triggered event is cached in the kernel until next poll. + In the current single thread implementation, session_read_cb is called + in the polling thread, such that polling only happens after this + callback, and will catch read edge event if data is available again + before notify_on_read. + TODO(chenw): in multi-threaded version, callback and polling can be + run in different threads. polling may catch a persist read edge event + before notify_on_read is called. */ + grpc_fd_notify_on_read(se->em_fd, &se->session_read_closure); + } else { + gpr_log(GPR_ERROR, "Unhandled read error %s", strerror(errno)); + abort(); + } + } +} + +/* Called when the listen FD can be safely shutdown. + Close listen FD and signal that server can be shutdown. */ +static void listen_shutdown_cb(void* arg /*server*/, int /*success*/) { + server* sv = static_cast(arg); + + grpc_fd_orphan(sv->em_fd, nullptr, nullptr, "b"); + + gpr_mu_lock(g_mu); + sv->done = 1; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); + gpr_mu_unlock(g_mu); +} + +/* Called when a new TCP connection request arrives in the listening port. 
*/ +static void listen_cb(void* arg, /*=sv_arg*/ + grpc_error_handle error) { + server* sv = static_cast(arg); + int fd; + int flags; + session* se; + struct sockaddr_storage ss; + socklen_t slen = sizeof(ss); + grpc_fd* listen_em_fd = sv->em_fd; + + if (error != GRPC_ERROR_NONE) { + listen_shutdown_cb(arg, 1); + return; + } + + fd = accept(grpc_fd_wrapped_fd(listen_em_fd), + reinterpret_cast(&ss), &slen); + GPR_ASSERT(fd >= 0); + GPR_ASSERT(fd < FD_SETSIZE); + flags = fcntl(fd, F_GETFL, 0); + fcntl(fd, F_SETFL, flags | O_NONBLOCK); + se = static_cast(gpr_malloc(sizeof(*se))); + se->sv = sv; + se->em_fd = grpc_fd_create(fd, "listener", false); + grpc_pollset_add_fd(g_pollset, se->em_fd); + GRPC_CLOSURE_INIT(&se->session_read_closure, session_read_cb, se, + grpc_schedule_on_exec_ctx); + grpc_fd_notify_on_read(se->em_fd, &se->session_read_closure); + + grpc_fd_notify_on_read(listen_em_fd, &sv->listen_closure); +} + +/* Max number of connections pending to be accepted by listen(). */ +#define MAX_NUM_FD 1024 + +/* Start a test server, return the TCP listening port bound to listen_fd. + listen_cb() is registered to be interested in reading from listen_fd. + When connection request arrives, listen_cb() is called to accept the + connection request. */ +static int server_start(server* sv) { + int port = 0; + int fd; + struct sockaddr_in sin; + socklen_t addr_len; + + create_test_socket(port, &fd, &sin); + addr_len = sizeof(sin); + GPR_ASSERT(bind(fd, (struct sockaddr*)&sin, addr_len) == 0); + GPR_ASSERT(getsockname(fd, (struct sockaddr*)&sin, &addr_len) == 0); + port = ntohs(sin.sin_port); + GPR_ASSERT(listen(fd, MAX_NUM_FD) == 0); + + sv->em_fd = grpc_fd_create(fd, "server", false); + grpc_pollset_add_fd(g_pollset, sv->em_fd); + /* Register to be interested in reading from listen_fd. */ + GRPC_CLOSURE_INIT(&sv->listen_closure, listen_cb, sv, + grpc_schedule_on_exec_ctx); + grpc_fd_notify_on_read(sv->em_fd, &sv->listen_closure); + + return port; +} + +/* Wait and shutdown a sever. */ +static void server_wait_and_shutdown(server* sv) { + gpr_mu_lock(g_mu); + while (!sv->done) { + grpc_core::ExecCtx exec_ctx; + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, GRPC_MILLIS_INF_FUTURE))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); +} + +/* ===An upload client to test notify_on_write=== */ + +/* Client write buffer size */ +#define CLIENT_WRITE_BUF_SIZE 10 +/* Total number of times that the client fills up the write buffer */ +#define CLIENT_TOTAL_WRITE_CNT 3 + +/* An upload client. */ +typedef struct { + grpc_fd* em_fd; + char write_buf[CLIENT_WRITE_BUF_SIZE]; + ssize_t write_bytes_total; + /* Number of times that the client fills up the write buffer and calls + notify_on_write to schedule another write. */ + int client_write_cnt; + + int done; /* set to 1 when a client finishes sending */ + grpc_closure write_closure; +} client; + +static void client_init(client* cl) { + memset(cl->write_buf, 0, sizeof(cl->write_buf)); + cl->write_bytes_total = 0; + cl->client_write_cnt = 0; + cl->done = 0; +} + +/* Called when a client upload session is ready to shutdown. */ +static void client_session_shutdown_cb(void* arg /*client*/, int /*success*/) { + client* cl = static_cast(arg); + grpc_fd_orphan(cl->em_fd, nullptr, nullptr, "c"); + cl->done = 1; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); +} + +/* Write as much as possible, then register notify_on_write. 
*/ +static void client_session_write(void* arg, /*client */ + grpc_error_handle error) { + client* cl = static_cast(arg); + int fd = grpc_fd_wrapped_fd(cl->em_fd); + ssize_t write_once = 0; + + if (error != GRPC_ERROR_NONE) { + gpr_mu_lock(g_mu); + client_session_shutdown_cb(arg, 1); + gpr_mu_unlock(g_mu); + return; + } + + do { + write_once = write(fd, cl->write_buf, CLIENT_WRITE_BUF_SIZE); + if (write_once > 0) cl->write_bytes_total += write_once; + } while (write_once > 0); + + if (errno == EAGAIN) { + gpr_mu_lock(g_mu); + if (cl->client_write_cnt < CLIENT_TOTAL_WRITE_CNT) { + GRPC_CLOSURE_INIT(&cl->write_closure, client_session_write, cl, + grpc_schedule_on_exec_ctx); + grpc_fd_notify_on_write(cl->em_fd, &cl->write_closure); + cl->client_write_cnt++; + } else { + client_session_shutdown_cb(arg, 1); + } + gpr_mu_unlock(g_mu); + } else { + gpr_log(GPR_ERROR, "unknown errno %s", strerror(errno)); + abort(); + } +} + +/* Start a client to send a stream of bytes. */ +static void client_start(client* cl, int port) { + int fd; + struct sockaddr_in sin; + create_test_socket(port, &fd, &sin); + if (connect(fd, reinterpret_cast(&sin), sizeof(sin)) == + -1) { + if (errno == EINPROGRESS) { + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLOUT; + pfd.revents = 0; + if (poll(&pfd, 1, -1) == -1) { + gpr_log(GPR_ERROR, "poll() failed during connect; errno=%d", errno); + abort(); + } + } else { + gpr_log(GPR_ERROR, "Failed to connect to the server (errno=%d)", errno); + abort(); + } + } + + cl->em_fd = grpc_fd_create(fd, "client", false); + grpc_pollset_add_fd(g_pollset, cl->em_fd); + + client_session_write(cl, GRPC_ERROR_NONE); +} + +/* Wait for the signal to shutdown a client. */ +static void client_wait_and_shutdown(client* cl) { + gpr_mu_lock(g_mu); + while (!cl->done) { + grpc_pollset_worker* worker = nullptr; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, GRPC_MILLIS_INF_FUTURE))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); +} + +/* Test grpc_fd. Start an upload server and client, upload a stream of + bytes from the client to the server, and verify that the total number of + sent bytes is equal to the total number of received bytes. 
*/ +static void test_grpc_fd(void) { + server sv; + client cl; + int port; + grpc_core::ExecCtx exec_ctx; + + server_init(&sv); + port = server_start(&sv); + client_init(&cl); + client_start(&cl, port); + + client_wait_and_shutdown(&cl); + server_wait_and_shutdown(&sv); + GPR_ASSERT(sv.read_bytes_total == cl.write_bytes_total); + gpr_log(GPR_INFO, "Total read bytes %" PRIdPTR, sv.read_bytes_total); +} + +typedef struct fd_change_data { + grpc_iomgr_cb_func cb_that_ran; +} fd_change_data; + +void init_change_data(fd_change_data* fdc) { fdc->cb_that_ran = nullptr; } + +void destroy_change_data(fd_change_data* /*fdc*/) {} + +static void first_read_callback(void* arg /* fd_change_data */, + grpc_error_handle /*error*/) { + fd_change_data* fdc = static_cast(arg); + + gpr_mu_lock(g_mu); + fdc->cb_that_ran = first_read_callback; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); + gpr_mu_unlock(g_mu); +} + +static void second_read_callback(void* arg /* fd_change_data */, + grpc_error_handle /*error*/) { + fd_change_data* fdc = static_cast(arg); + + gpr_mu_lock(g_mu); + fdc->cb_that_ran = second_read_callback; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); + gpr_mu_unlock(g_mu); +} + +/* Test that changing the callback we use for notify_on_read actually works. + Note that we have two different but almost identical callbacks above -- the + point is to have two different function pointers and two different data + pointers and make sure that changing both really works. */ +static void test_grpc_fd_change(void) { + grpc_fd* em_fd; + fd_change_data a, b; + int flags; + int sv[2]; + char data; + ssize_t result; + grpc_closure first_closure; + grpc_closure second_closure; + grpc_core::ExecCtx exec_ctx; + + GRPC_CLOSURE_INIT(&first_closure, first_read_callback, &a, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&second_closure, second_read_callback, &b, + grpc_schedule_on_exec_ctx); + + init_change_data(&a); + init_change_data(&b); + + GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv) == 0); + flags = fcntl(sv[0], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[0], F_SETFL, flags | O_NONBLOCK) == 0); + flags = fcntl(sv[1], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[1], F_SETFL, flags | O_NONBLOCK) == 0); + + em_fd = grpc_fd_create(sv[0], "test_grpc_fd_change", false); + grpc_pollset_add_fd(g_pollset, em_fd); + + /* Register the first callback, then make its FD readable */ + grpc_fd_notify_on_read(em_fd, &first_closure); + data = 0; + result = write(sv[1], &data, 1); + GPR_ASSERT(result == 1); + + /* And now wait for it to run. */ + gpr_mu_lock(g_mu); + while (a.cb_that_ran == nullptr) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, GRPC_MILLIS_INF_FUTURE))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + GPR_ASSERT(a.cb_that_ran == first_read_callback); + gpr_mu_unlock(g_mu); + + /* And drain the socket so we can generate a new read edge */ + result = read(sv[0], &data, 1); + GPR_ASSERT(result == 1); + + /* Now register a second callback with distinct change data, and do the same + thing again. 
*/ + grpc_fd_notify_on_read(em_fd, &second_closure); + data = 0; + result = write(sv[1], &data, 1); + GPR_ASSERT(result == 1); + + gpr_mu_lock(g_mu); + while (b.cb_that_ran == nullptr) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, GRPC_MILLIS_INF_FUTURE))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + /* Except now we verify that second_read_callback ran instead */ + GPR_ASSERT(b.cb_that_ran == second_read_callback); + gpr_mu_unlock(g_mu); + + grpc_fd_orphan(em_fd, nullptr, nullptr, "d"); + + destroy_change_data(&a); + destroy_change_data(&b); + close(sv[1]); +} + +static void destroy_pollset(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(p)); +} + +int main(int argc, char** argv) { + grpc_closure destroyed; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + g_pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(g_pollset, &g_mu); + test_grpc_fd(); + test_grpc_fd_change(); + GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(g_pollset, &destroyed); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(g_pollset); + } + grpc_shutdown(); + return 0; +} + +#else /* GRPC_POSIX_SOCKET_EV */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_EV */ diff --git a/test/core/iomgr/grpc_ipv6_loopback_available_test.cc b/test/core/iomgr/grpc_ipv6_loopback_available_test.cc new file mode 100644 index 00000000..efda72e8 --- /dev/null +++ b/test/core/iomgr/grpc_ipv6_loopback_available_test.cc @@ -0,0 +1,39 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "src/core/lib/iomgr/port.h" +#include "test/core/util/test_config.h" + +#ifdef GPR_WINDOWS +#include "src/core/lib/iomgr/socket_windows.h" +#else +#include "src/core/lib/iomgr/socket_utils_posix.h" +#endif + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + // This test assumes that the ipv6 loopback is available + // in all environments in which grpc tests run in. + GPR_ASSERT(grpc_ipv6_loopback_available()); + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/load_file_test.cc b/test/core/iomgr/load_file_test.cc new file mode 100644 index 00000000..1e46ab44 --- /dev/null +++ b/test/core/iomgr/load_file_test.cc @@ -0,0 +1,164 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/load_file.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST_NAME(x) gpr_log(GPR_INFO, "%s", x) + +static const char prefix[] = "file_test"; + +static void test_load_empty_file(void) { + FILE* tmp = nullptr; + grpc_slice slice; + grpc_slice slice_with_null_term; + grpc_error_handle error; + char* tmp_name; + + LOG_TEST_NAME("test_load_empty_file"); + + tmp = gpr_tmpfile(prefix, &tmp_name); + GPR_ASSERT(tmp_name != nullptr); + GPR_ASSERT(tmp != nullptr); + fclose(tmp); + + error = grpc_load_file(tmp_name, 0, &slice); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == 0); + + error = grpc_load_file(tmp_name, 1, &slice_with_null_term); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice_with_null_term) == 1); + GPR_ASSERT(GRPC_SLICE_START_PTR(slice_with_null_term)[0] == 0); + + remove(tmp_name); + gpr_free(tmp_name); + grpc_slice_unref(slice); + grpc_slice_unref(slice_with_null_term); +} + +static void test_load_failure(void) { + FILE* tmp = nullptr; + grpc_slice slice; + grpc_error_handle error; + char* tmp_name; + + LOG_TEST_NAME("test_load_failure"); + + tmp = gpr_tmpfile(prefix, &tmp_name); + GPR_ASSERT(tmp_name != nullptr); + GPR_ASSERT(tmp != nullptr); + fclose(tmp); + remove(tmp_name); + + error = grpc_load_file(tmp_name, 0, &slice); + GPR_ASSERT(error != GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(error); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == 0); + gpr_free(tmp_name); + grpc_slice_unref(slice); +} + +static void test_load_small_file(void) { + FILE* tmp = nullptr; + grpc_slice slice; + grpc_slice slice_with_null_term; + grpc_error_handle error; + char* tmp_name; + const char* blah = "blah"; + + LOG_TEST_NAME("test_load_small_file"); + + tmp = gpr_tmpfile(prefix, &tmp_name); + GPR_ASSERT(tmp_name != nullptr); + GPR_ASSERT(tmp != nullptr); + GPR_ASSERT(fwrite(blah, 1, strlen(blah), tmp) == strlen(blah)); + fclose(tmp); + + error = grpc_load_file(tmp_name, 0, &slice); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == strlen(blah)); + GPR_ASSERT(!memcmp(GRPC_SLICE_START_PTR(slice), blah, strlen(blah))); + + error = grpc_load_file(tmp_name, 1, &slice_with_null_term); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice_with_null_term) == (strlen(blah) + 1)); + GPR_ASSERT(strcmp((const char*)GRPC_SLICE_START_PTR(slice_with_null_term), + blah) == 0); + + remove(tmp_name); + gpr_free(tmp_name); + grpc_slice_unref(slice); + grpc_slice_unref(slice_with_null_term); +} + +static void test_load_big_file(void) { + FILE* tmp = nullptr; + grpc_slice slice; + grpc_error_handle error; + char* tmp_name; + static const size_t buffer_size = 124631; + unsigned char* buffer = static_cast(gpr_malloc(buffer_size)); + unsigned char* current; + size_t i; + + LOG_TEST_NAME("test_load_big_file"); + + memset(buffer, 42, buffer_size); + + tmp = gpr_tmpfile(prefix, &tmp_name); + GPR_ASSERT(tmp != nullptr); + GPR_ASSERT(tmp_name != nullptr); + GPR_ASSERT(fwrite(buffer, 1, buffer_size, tmp) == buffer_size); + fclose(tmp); + + error = grpc_load_file(tmp_name, 0, &slice); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == buffer_size); + current = GRPC_SLICE_START_PTR(slice); + for (i = 0; 
i < buffer_size; i++) { + GPR_ASSERT(current[i] == 42); + } + + remove(tmp_name); + gpr_free(tmp_name); + grpc_slice_unref(slice); + gpr_free(buffer); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_load_empty_file(); + test_load_failure(); + test_load_small_file(); + test_load_big_file(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/mpmcqueue_test.cc b/test/core/iomgr/mpmcqueue_test.cc new file mode 100644 index 00000000..0344d86f --- /dev/null +++ b/test/core/iomgr/mpmcqueue_test.cc @@ -0,0 +1,227 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/executor/mpmcqueue.h" + +#include + +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +#define TEST_NUM_ITEMS 10000 + +// Testing items for queue +struct WorkItem { + int index; + bool done; + + explicit WorkItem(int i) : index(i) { done = false; } +}; + +// Thread to "produce" items and put items into queue +// It will also check that all items has been marked done and clean up all +// produced items on destructing. +class ProducerThread { + public: + ProducerThread(grpc_core::InfLenFIFOQueue* queue, int start_index, + int num_items) + : start_index_(start_index), num_items_(num_items), queue_(queue) { + items_ = nullptr; + thd_ = grpc_core::Thread( + "mpmcq_test_producer_thd", + [](void* th) { static_cast(th)->Run(); }, this); + } + ~ProducerThread() { + for (int i = 0; i < num_items_; ++i) { + GPR_ASSERT(items_[i]->done); + delete items_[i]; + } + delete[] items_; + } + + void Start() { thd_.Start(); } + void Join() { thd_.Join(); } + + private: + void Run() { + items_ = new WorkItem*[num_items_]; + for (int i = 0; i < num_items_; ++i) { + items_[i] = new WorkItem(start_index_ + i); + queue_->Put(items_[i]); + } + } + + int start_index_; + int num_items_; + grpc_core::InfLenFIFOQueue* queue_; + grpc_core::Thread thd_; + WorkItem** items_; +}; + +// Thread to pull out items from queue +class ConsumerThread { + public: + explicit ConsumerThread(grpc_core::InfLenFIFOQueue* queue) : queue_(queue) { + thd_ = grpc_core::Thread( + "mpmcq_test_consumer_thd", + [](void* th) { static_cast(th)->Run(); }, this); + } + ~ConsumerThread() {} + + void Start() { thd_.Start(); } + void Join() { thd_.Join(); } + + private: + void Run() { + // count number of Get() called in this thread + int count = 0; + + WorkItem* item; + while ((item = static_cast(queue_->Get(nullptr))) != nullptr) { + count++; + GPR_ASSERT(!item->done); + item->done = true; + } + + gpr_log(GPR_DEBUG, "ConsumerThread: %d times of Get() called.", count); + } + grpc_core::InfLenFIFOQueue* queue_; + grpc_core::Thread thd_; +}; + +static void test_FIFO(void) { + gpr_log(GPR_INFO, "test_FIFO"); + grpc_core::InfLenFIFOQueue large_queue; + for (int i = 0; i < TEST_NUM_ITEMS; ++i) { + large_queue.Put(static_cast(new WorkItem(i))); + } + GPR_ASSERT(large_queue.count() == TEST_NUM_ITEMS); + for (int i 
= 0; i < TEST_NUM_ITEMS; ++i) { + WorkItem* item = static_cast(large_queue.Get(nullptr)); + GPR_ASSERT(i == item->index); + delete item; + } +} + +// Test if queue's behavior of expanding is correct. (Only does expansion when +// it gets full, and each time expands to doubled size). +static void test_space_efficiency(void) { + gpr_log(GPR_INFO, "test_space_efficiency"); + grpc_core::InfLenFIFOQueue queue; + for (int i = 0; i < queue.init_num_nodes(); ++i) { + queue.Put(static_cast(new WorkItem(i))); + } + // Queue should not have been expanded at this time. + GPR_ASSERT(queue.num_nodes() == queue.init_num_nodes()); + for (int i = 0; i < queue.init_num_nodes(); ++i) { + WorkItem* item = static_cast(queue.Get(nullptr)); + queue.Put(item); + } + GPR_ASSERT(queue.num_nodes() == queue.init_num_nodes()); + for (int i = 0; i < queue.init_num_nodes(); ++i) { + WorkItem* item = static_cast(queue.Get(nullptr)); + delete item; + } + // Queue never shrinks even it is empty. + GPR_ASSERT(queue.num_nodes() == queue.init_num_nodes()); + GPR_ASSERT(queue.count() == 0); + // queue empty now + for (int i = 0; i < queue.init_num_nodes() * 2; ++i) { + queue.Put(static_cast(new WorkItem(i))); + } + GPR_ASSERT(queue.count() == queue.init_num_nodes() * 2); + // Queue should have been expanded once. + GPR_ASSERT(queue.num_nodes() == queue.init_num_nodes() * 2); + for (int i = 0; i < queue.init_num_nodes(); ++i) { + WorkItem* item = static_cast(queue.Get(nullptr)); + delete item; + } + GPR_ASSERT(queue.count() == queue.init_num_nodes()); + // Queue will never shrink, should keep same number of node as before. + GPR_ASSERT(queue.num_nodes() == queue.init_num_nodes() * 2); + for (int i = 0; i < queue.init_num_nodes() + 1; ++i) { + queue.Put(static_cast(new WorkItem(i))); + } + GPR_ASSERT(queue.count() == queue.init_num_nodes() * 2 + 1); + // Queue should have been expanded twice. 
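The assertions in test_space_efficiency pin down InfLenFIFOQueue's growth policy: it expands only when a Put arrives while the queue is completely full, it doubles its node count each time, and it never shrinks. The sketch below is a minimal single-threaded analogue of that grow-only doubling policy; it is illustrative only (GrowOnlyFifo is a hypothetical class) and is not gRPC's actual node-based, thread-safe implementation.

// Illustrative sketch of a grow-only, capacity-doubling FIFO policy.
// NOT gRPC's InfLenFIFOQueue; a minimal single-threaded analogue.
#include <cassert>
#include <cstddef>
#include <vector>

class GrowOnlyFifo {
 public:
  explicit GrowOnlyFifo(size_t initial_capacity) : buf_(initial_capacity) {}

  void Put(int v) {
    if (count_ == buf_.size()) Expand();  // only grow when completely full
    buf_[(head_ + count_) % buf_.size()] = v;
    ++count_;
  }

  int Get() {
    assert(count_ > 0);
    int v = buf_[head_];
    head_ = (head_ + 1) % buf_.size();
    --count_;
    return v;  // capacity never shrinks
  }

  size_t count() const { return count_; }
  size_t capacity() const { return buf_.size(); }

 private:
  void Expand() {
    std::vector<int> bigger(buf_.size() * 2);  // double each time
    for (size_t i = 0; i < count_; ++i) {
      bigger[i] = buf_[(head_ + i) % buf_.size()];
    }
    buf_.swap(bigger);
    head_ = 0;
  }

  std::vector<int> buf_;
  size_t head_ = 0;
  size_t count_ = 0;
};

int main() {
  GrowOnlyFifo q(/*initial_capacity=*/4);
  for (int i = 0; i < 4; ++i) q.Put(i);
  assert(q.capacity() == 4);  // full, but not yet expanded
  q.Put(4);                   // triggers one doubling
  assert(q.capacity() == 8);
  while (q.count() > 0) q.Get();
  assert(q.capacity() == 8);  // never shrinks, even when empty
  return 0;
}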
+ GPR_ASSERT(queue.num_nodes() == queue.init_num_nodes() * 4); + for (int i = 0; i < queue.init_num_nodes() * 2 + 1; ++i) { + WorkItem* item = static_cast(queue.Get(nullptr)); + delete item; + } + GPR_ASSERT(queue.count() == 0); + GPR_ASSERT(queue.num_nodes() == queue.init_num_nodes() * 4); + gpr_log(GPR_DEBUG, "Done."); +} + +static void test_many_thread(void) { + gpr_log(GPR_INFO, "test_many_thread"); + const int num_producer_threads = 10; + const int num_consumer_threads = 20; + grpc_core::InfLenFIFOQueue queue; + ProducerThread** producer_threads = new ProducerThread*[num_producer_threads]; + ConsumerThread** consumer_threads = new ConsumerThread*[num_consumer_threads]; + + gpr_log(GPR_DEBUG, "Fork ProducerThreads..."); + for (int i = 0; i < num_producer_threads; ++i) { + producer_threads[i] = + new ProducerThread(&queue, i * TEST_NUM_ITEMS, TEST_NUM_ITEMS); + producer_threads[i]->Start(); + } + gpr_log(GPR_DEBUG, "ProducerThreads Started."); + gpr_log(GPR_DEBUG, "Fork ConsumerThreads..."); + for (int i = 0; i < num_consumer_threads; ++i) { + consumer_threads[i] = new ConsumerThread(&queue); + consumer_threads[i]->Start(); + } + gpr_log(GPR_DEBUG, "ConsumerThreads Started."); + gpr_log(GPR_DEBUG, "Waiting ProducerThreads to finish..."); + for (int i = 0; i < num_producer_threads; ++i) { + producer_threads[i]->Join(); + } + gpr_log(GPR_DEBUG, "All ProducerThreads Terminated."); + gpr_log(GPR_DEBUG, "Terminating ConsumerThreads..."); + for (int i = 0; i < num_consumer_threads; ++i) { + queue.Put(nullptr); + } + for (int i = 0; i < num_consumer_threads; ++i) { + consumer_threads[i]->Join(); + } + gpr_log(GPR_DEBUG, "All ConsumerThreads Terminated."); + gpr_log(GPR_DEBUG, "Checking WorkItems and Cleaning Up..."); + for (int i = 0; i < num_producer_threads; ++i) { + // Destructor of ProducerThread will do the check of WorkItems + delete producer_threads[i]; + } + delete[] producer_threads; + for (int i = 0; i < num_consumer_threads; ++i) { + delete consumer_threads[i]; + } + delete[] consumer_threads; + gpr_log(GPR_DEBUG, "Done."); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_FIFO(); + test_space_efficiency(); + test_many_thread(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/pollset_windows_starvation_test.cc b/test/core/iomgr/pollset_windows_starvation_test.cc new file mode 100644 index 00000000..3357e3b5 --- /dev/null +++ b/test/core/iomgr/pollset_windows_starvation_test.cc @@ -0,0 +1,144 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +#include + +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/iocp_windows.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_windows.h" +#include "src/core/lib/surface/init.h" +#include "test/core/util/test_config.h" + +#if defined(GRPC_WINSOCK_SOCKET) + +// At least three threads are required to reproduce #18848 +const size_t THREADS = 3; + +struct ThreadParams { + gpr_cv cv; + gpr_mu mu; + int complete; + int queuing; + gpr_mu* pollset_mu[THREADS]; +}; + +int main(int argc, char** argv) { + grpc_init(); + + // Create the threads that all start queueing for work. + // + // The first one becomes the active poller for work and the two other + // threads go into the poller queue. + // + // When work arrives, the first one notifies the next queued poller, + // this wakes the second thread - however all this does is return from + // the grpc_pollset_work function. It's up to that thread to figure + // out if it still wants to queue for more work or if it should kick + // other pollers. + // + // Previously that kick only affected pollers in the same pollset, thus + // leaving the other threads stuck in the poller queue. Now the pollset- + // specific grpc_pollset_kick will also kick pollers from other pollsets + // if there are no pollers in the current pollset. This frees up the + // last threads and completes the test. + ThreadParams params = {}; + gpr_cv_init(¶ms.cv); + gpr_mu_init(¶ms.mu); + std::vector threads; + for (int i = 0; i < THREADS; i++) { + grpc_core::Thread thd( + "Poller", + [](void* params) { + ThreadParams* tparams = static_cast(params); + grpc_core::ExecCtx exec_ctx; + + gpr_mu* mu; + grpc_pollset pollset = {}; + grpc_pollset_init(&pollset, &mu); + + // Lock the pollset mutex before notifying the test runner thread that + // one more thread is queuing. This allows the test runner thread to + // wait for all threads to be queued before sending the first kick by + // waiting for the mutexes to be released, which happens in + // gpr_pollset_work when the poller is queued. + gpr_mu_lock(mu); + + gpr_mu_lock(&tparams->mu); + tparams->pollset_mu[tparams->queuing] = mu; + tparams->queuing++; + gpr_cv_signal(&tparams->cv); + gpr_mu_unlock(&tparams->mu); + + // Queue for work and once we're done, make sure to kick the remaining + // threads. + grpc_error_handle error; + error = grpc_pollset_work(&pollset, NULL, GRPC_MILLIS_INF_FUTURE); + error = grpc_pollset_kick(&pollset, NULL); + + gpr_mu_unlock(mu); + + gpr_mu_lock(&tparams->mu); + tparams->complete++; + gpr_cv_signal(&tparams->cv); + gpr_mu_unlock(&tparams->mu); + }, + ¶ms); + thd.Start(); + threads.push_back(std::move(thd)); + } + + // Wait for all three threads to be queuing. + gpr_mu_lock(¶ms.mu); + while ( + params.queuing != THREADS && + !gpr_cv_wait(¶ms.cv, ¶ms.mu, gpr_inf_future(GPR_CLOCK_REALTIME))) + ; + gpr_mu_unlock(¶ms.mu); + + // Wait for the mutexes to be released. This indicates that the threads have + // entered the work wait. + // + // At least currently these are essentially all references to the same global + // pollset mutex, but we are still waiting on them once for each thread in + // the case this ever changes. + for (int i = 0; i < THREADS; i++) { + gpr_mu_lock(params.pollset_mu[i]); + gpr_mu_unlock(params.pollset_mu[i]); + } + + grpc_iocp_kick(); + + // Wait for the threads to complete. 
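The coordination in the starvation test follows a standard condition-variable idiom: each poller thread increments a shared counter under the mutex and signals, while the main thread loops, re-checking the predicate after every wakeup, until the counter reaches THREADS (with an infinite deadline the gpr_cv_wait call should never report a timeout, so the predicate alone terminates the loop). The standard-library sketch below shows the same idiom; the names are hypothetical and it is not the gRPC API.

// Illustrative, standard-library version of the "wait until all worker
// threads have checked in" idiom; hypothetical names, not part of the test.
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  constexpr int kThreads = 3;
  std::mutex mu;
  std::condition_variable cv;
  int ready = 0;

  std::vector<std::thread> workers;
  for (int i = 0; i < kThreads; ++i) {
    workers.emplace_back([&] {
      // ... per-thread setup would go here ...
      {
        std::lock_guard<std::mutex> lock(mu);
        ++ready;          // "I am queued/ready"
      }
      cv.notify_one();    // wake the coordinating thread
    });
  }

  {
    // Re-check the predicate after every wakeup, analogous to the
    // while (params.queuing != THREADS && !gpr_cv_wait(...)) loop.
    std::unique_lock<std::mutex> lock(mu);
    cv.wait(lock, [&] { return ready == kThreads; });
  }

  for (auto& t : workers) t.join();
  return 0;
}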
+ gpr_mu_lock(¶ms.mu); + while ( + params.complete != THREADS && + !gpr_cv_wait(¶ms.cv, ¶ms.mu, gpr_inf_future(GPR_CLOCK_REALTIME))) + ; + gpr_mu_unlock(¶ms.mu); + + for (auto& t : threads) t.Join(); + return EXIT_SUCCESS; +} +#else /* defined(GRPC_WINSOCK_SOCKET) */ +int main(int /*argc*/, char** /*argv*/) { return 0; } +#endif diff --git a/test/core/iomgr/resolve_address_posix_test.cc b/test/core/iomgr/resolve_address_posix_test.cc new file mode 100644 index 00000000..7e34a027 --- /dev/null +++ b/test/core/iomgr/resolve_address_posix_test.cc @@ -0,0 +1,221 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "test/core/util/cmdline.h" +#include "test/core/util/test_config.h" + +static gpr_timespec test_deadline(void) { + return grpc_timeout_seconds_to_deadline(100); +} + +typedef struct args_struct { + grpc_core::Thread thd; + gpr_event ev; + grpc_resolved_addresses* addrs; + gpr_mu* mu; + bool done; // guarded by mu + grpc_pollset* pollset; // guarded by mu + grpc_pollset_set* pollset_set; +} args_struct; + +static void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +void args_init(args_struct* args) { + gpr_event_init(&args->ev); + args->pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(args->pollset, &args->mu); + args->pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(args->pollset_set, args->pollset); + args->addrs = nullptr; + args->done = false; +} + +void args_finish(args_struct* args) { + GPR_ASSERT(gpr_event_wait(&args->ev, test_deadline())); + args->thd.Join(); + // Don't need to explicitly destruct args->thd since + // args is actually going to be destructed, not just freed + grpc_resolved_addresses_destroy(args->addrs); + grpc_pollset_set_del_pollset(args->pollset_set, args->pollset); + grpc_pollset_set_destroy(args->pollset_set); + grpc_closure do_nothing_cb; + GRPC_CLOSURE_INIT(&do_nothing_cb, do_nothing, nullptr, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(args->pollset, &do_nothing_cb); + // exec_ctx needs to be flushed before calling grpc_pollset_destroy() + grpc_core::ExecCtx::Get()->Flush(); + grpc_pollset_destroy(args->pollset); + gpr_free(args->pollset); +} + +static grpc_millis n_sec_deadline(int seconds) { + return grpc_timespec_to_millis_round_up( + grpc_timeout_seconds_to_deadline(seconds)); +} + +static void actually_poll(void* argsp) { + args_struct* args = static_cast(argsp); + grpc_millis deadline = n_sec_deadline(10); + while 
(true) { + grpc_core::ExecCtx exec_ctx; + { + grpc_core::MutexLockForGprMu lock(args->mu); + if (args->done) { + break; + } + grpc_millis time_left = deadline - grpc_core::ExecCtx::Get()->Now(); + gpr_log(GPR_DEBUG, "done=%d, time_left=%" PRId64, args->done, time_left); + GPR_ASSERT(time_left >= 0); + grpc_pollset_worker* worker = nullptr; + GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(args->pollset, &worker, n_sec_deadline(1))); + } + } + gpr_event_set(&args->ev, reinterpret_cast(1)); +} + +static void poll_pollset_until_request_done(args_struct* args) { + args->thd = grpc_core::Thread("grpc_poll_pollset", actually_poll, args); + args->thd.Start(); +} + +static void must_succeed(void* argsp, grpc_error_handle err) { + args_struct* args = static_cast(argsp); + GPR_ASSERT(err == GRPC_ERROR_NONE); + GPR_ASSERT(args->addrs != nullptr); + GPR_ASSERT(args->addrs->naddrs > 0); + grpc_core::MutexLockForGprMu lock(args->mu); + args->done = true; + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr)); +} + +static void must_fail(void* argsp, grpc_error_handle err) { + args_struct* args = static_cast(argsp); + GPR_ASSERT(err != GRPC_ERROR_NONE); + grpc_core::MutexLockForGprMu lock(args->mu); + args->done = true; + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr)); +} + +static void resolve_address_must_succeed(const char* target) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + poll_pollset_until_request_done(&args); + grpc_resolve_address( + target, "1" /* port number */, args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + args_finish(&args); +} + +static void test_named_and_numeric_scope_ids(void) { + char* arbitrary_interface_name = static_cast(gpr_zalloc(IF_NAMESIZE)); + int interface_index = 0; + // Probe candidate interface index numbers until we find one that the + // system recognizes, and then use that for the test. + for (size_t i = 1; i < 65536; i++) { + if (if_indextoname(i, arbitrary_interface_name) != nullptr) { + gpr_log(GPR_DEBUG, + "Found interface at index %" PRIuPTR + " named %s. Will use this for the test", + i, arbitrary_interface_name); + interface_index = static_cast(i); + break; + } + } + GPR_ASSERT(strlen(arbitrary_interface_name) > 0); + // Test resolution of an ipv6 address with a named scope ID + gpr_log(GPR_DEBUG, "test resolution with a named scope ID"); + std::string target_with_named_scope_id = + absl::StrFormat("fe80::1234%%%s", arbitrary_interface_name); + resolve_address_must_succeed(target_with_named_scope_id.c_str()); + gpr_free(arbitrary_interface_name); + // Test resolution of an ipv6 address with a numeric scope ID + gpr_log(GPR_DEBUG, "test resolution with a numeric scope ID"); + std::string target_with_numeric_scope_id = + absl::StrFormat("fe80::1234%%%d", interface_index); + resolve_address_must_succeed(target_with_numeric_scope_id.c_str()); +} + +int main(int argc, char** argv) { + // First set the resolver type based off of --resolver + const char* resolver_type = nullptr; + gpr_cmdline* cl = gpr_cmdline_create("resolve address test"); + gpr_cmdline_add_string(cl, "resolver", "Resolver type (ares or native)", + &resolver_type); + // In case that there are more than one argument on the command line, + // --resolver will always be the first one, so only parse the first argument + // (other arguments may be unknown to cl) + gpr_cmdline_parse(cl, argc > 2 ? 
2 : argc, argv); + grpc_core::UniquePtr resolver = + GPR_GLOBAL_CONFIG_GET(grpc_dns_resolver); + if (strlen(resolver.get()) != 0) { + gpr_log(GPR_INFO, "Warning: overriding resolver setting of %s", + resolver.get()); + } + if (resolver_type != nullptr && gpr_stricmp(resolver_type, "native") == 0) { + GPR_GLOBAL_CONFIG_SET(grpc_dns_resolver, "native"); + } else if (resolver_type != nullptr && + gpr_stricmp(resolver_type, "ares") == 0) { + GPR_GLOBAL_CONFIG_SET(grpc_dns_resolver, "ares"); + } else { + gpr_log(GPR_ERROR, "--resolver_type was not set to ares or native"); + abort(); + } + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + { + grpc_core::ExecCtx exec_ctx; + test_named_and_numeric_scope_ids(); + // c-ares resolver doesn't support UDS (ability for native DNS resolver + // to handle this is only expected to be used by servers, which + // unconditionally use the native DNS resolver). + grpc_core::UniquePtr resolver = + GPR_GLOBAL_CONFIG_GET(grpc_dns_resolver); + } + gpr_cmdline_destroy(cl); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/resolve_address_test.cc b/test/core/iomgr/resolve_address_test.cc new file mode 100644 index 00000000..6e633f52 --- /dev/null +++ b/test/core/iomgr/resolve_address_test.cc @@ -0,0 +1,402 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/resolve_address.h" + +#include + +#include + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" +#include "src/core/lib/event_engine/sockaddr.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "test/core/util/cmdline.h" +#include "test/core/util/test_config.h" + +static gpr_timespec test_deadline(void) { + return grpc_timeout_seconds_to_deadline(100); +} + +typedef struct args_struct { + gpr_event ev; + grpc_resolved_addresses* addrs; + gpr_mu* mu; + bool done; // guarded by mu + grpc_pollset* pollset; // guarded by mu + grpc_pollset_set* pollset_set; +} args_struct; + +static void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +void args_init(args_struct* args) { + gpr_event_init(&args->ev); + args->pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(args->pollset, &args->mu); + args->pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(args->pollset_set, args->pollset); + args->addrs = nullptr; + args->done = false; +} + +void args_finish(args_struct* args) { + GPR_ASSERT(gpr_event_wait(&args->ev, test_deadline())); + grpc_resolved_addresses_destroy(args->addrs); + grpc_pollset_set_del_pollset(args->pollset_set, args->pollset); + grpc_pollset_set_destroy(args->pollset_set); + grpc_closure do_nothing_cb; + GRPC_CLOSURE_INIT(&do_nothing_cb, do_nothing, nullptr, + grpc_schedule_on_exec_ctx); + gpr_mu_lock(args->mu); + grpc_pollset_shutdown(args->pollset, &do_nothing_cb); + gpr_mu_unlock(args->mu); + // exec_ctx needs to be flushed before calling grpc_pollset_destroy() + grpc_core::ExecCtx::Get()->Flush(); + grpc_pollset_destroy(args->pollset); + gpr_free(args->pollset); +} + +static grpc_millis n_sec_deadline(int seconds) { + return grpc_timespec_to_millis_round_up( + grpc_timeout_seconds_to_deadline(seconds)); +} + +static void poll_pollset_until_request_done(args_struct* args) { + // Try to give enough time for c-ares to run through its retries + // a few times if needed. 
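+  // Each pass below runs under a fresh ExecCtx and polls in one-second slices; the loop exits once a resolver callback sets args->done and kicks the pollset.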
+ grpc_millis deadline = n_sec_deadline(90); + while (true) { + grpc_core::ExecCtx exec_ctx; + { + grpc_core::MutexLockForGprMu lock(args->mu); + if (args->done) { + break; + } + grpc_millis time_left = deadline - grpc_core::ExecCtx::Get()->Now(); + gpr_log(GPR_DEBUG, "done=%d, time_left=%" PRId64, args->done, time_left); + GPR_ASSERT(time_left >= 0); + grpc_pollset_worker* worker = nullptr; + GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(args->pollset, &worker, n_sec_deadline(1))); + } + } + gpr_event_set(&args->ev, reinterpret_cast(1)); +} + +static void must_succeed(void* argsp, grpc_error_handle err) { + args_struct* args = static_cast(argsp); + GPR_ASSERT(err == GRPC_ERROR_NONE); + GPR_ASSERT(args->addrs != nullptr); + GPR_ASSERT(args->addrs->naddrs > 0); + grpc_core::MutexLockForGprMu lock(args->mu); + args->done = true; + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr)); +} + +static void must_fail(void* argsp, grpc_error_handle err) { + args_struct* args = static_cast(argsp); + GPR_ASSERT(err != GRPC_ERROR_NONE); + grpc_core::MutexLockForGprMu lock(args->mu); + args->done = true; + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr)); +} + +// This test assumes the environment has an ipv6 loopback +static void must_succeed_with_ipv6_first(void* argsp, grpc_error_handle err) { + args_struct* args = static_cast(argsp); + GPR_ASSERT(err == GRPC_ERROR_NONE); + GPR_ASSERT(args->addrs != nullptr); + GPR_ASSERT(args->addrs->naddrs > 0); + const struct sockaddr* first_address = + reinterpret_cast(args->addrs->addrs[0].addr); + GPR_ASSERT(first_address->sa_family == AF_INET6); + grpc_core::MutexLockForGprMu lock(args->mu); + args->done = true; + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr)); +} + +static void must_succeed_with_ipv4_first(void* argsp, grpc_error_handle err) { + args_struct* args = static_cast(argsp); + GPR_ASSERT(err == GRPC_ERROR_NONE); + GPR_ASSERT(args->addrs != nullptr); + GPR_ASSERT(args->addrs->naddrs > 0); + const struct sockaddr* first_address = + reinterpret_cast(args->addrs->addrs[0].addr); + GPR_ASSERT(first_address->sa_family == AF_INET); + grpc_core::MutexLockForGprMu lock(args->mu); + args->done = true; + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr)); +} + +static void test_localhost(void) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + "localhost:1", nullptr, args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); +} + +static void test_default_port(void) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + "localhost", "1", args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); +} + +static void test_localhost_result_has_ipv6_first(void) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address("localhost:1", nullptr, args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed_with_ipv6_first, &args, + grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); +} + +static void 
test_localhost_result_has_ipv4_first_when_ipv6_isnt_available( + void) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address("localhost:1", nullptr, args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed_with_ipv4_first, &args, + grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); +} + +static void test_non_numeric_default_port(void) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + "localhost", "https", args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); +} + +static void test_missing_default_port(void) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + "localhost", nullptr, args.pollset_set, + GRPC_CLOSURE_CREATE(must_fail, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); +} + +static void test_ipv6_with_port(void) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + "[2001:db8::1]:1", nullptr, args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); +} + +static void test_ipv6_without_port(void) { + const char* const kCases[] = { + "2001:db8::1", + "2001:db8::1.2.3.4", + "[2001:db8::1]", + }; + unsigned i; + for (i = 0; i < sizeof(kCases) / sizeof(*kCases); i++) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + kCases[i], "80", args.pollset_set, + GRPC_CLOSURE_CREATE(must_succeed, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); + } +} + +static void test_invalid_ip_addresses(void) { + const char* const kCases[] = { + "293.283.1238.3:1", + "[2001:db8::11111]:1", + }; + unsigned i; + for (i = 0; i < sizeof(kCases) / sizeof(*kCases); i++) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + kCases[i], nullptr, args.pollset_set, + GRPC_CLOSURE_CREATE(must_fail, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); + } +} + +static void test_unparseable_hostports(void) { + const char* const kCases[] = { + "[", "[::1", "[::1]bad", "[1.2.3.4]", "[localhost]", "[localhost]:1", + }; + unsigned i; + for (i = 0; i < sizeof(kCases) / sizeof(*kCases); i++) { + grpc_core::ExecCtx exec_ctx; + args_struct args; + args_init(&args); + grpc_resolve_address( + kCases[i], "1", args.pollset_set, + GRPC_CLOSURE_CREATE(must_fail, &args, grpc_schedule_on_exec_ctx), + &args.addrs); + grpc_core::ExecCtx::Get()->Flush(); + poll_pollset_until_request_done(&args); + args_finish(&args); + } +} + +typedef struct mock_ipv6_disabled_source_addr_factory { + address_sorting_source_addr_factory base; +} mock_ipv6_disabled_source_addr_factory; + +static bool mock_ipv6_disabled_source_addr_factory_get_source_addr( + address_sorting_source_addr_factory* /*factory*/, + const address_sorting_address* dest_addr, + address_sorting_address* source_addr) { + // Mock 
lack of IPv6. For IPv4, set the source addr to be the same + // as the destination; tests won't actually connect on the result anyways. + if (address_sorting_abstract_get_family(dest_addr) == + ADDRESS_SORTING_AF_INET6) { + return false; + } + memcpy(source_addr->addr, &dest_addr->addr, dest_addr->len); + source_addr->len = dest_addr->len; + return true; +} + +void mock_ipv6_disabled_source_addr_factory_destroy( + address_sorting_source_addr_factory* factory) { + mock_ipv6_disabled_source_addr_factory* f = + reinterpret_cast(factory); + gpr_free(f); +} + +const address_sorting_source_addr_factory_vtable + kMockIpv6DisabledSourceAddrFactoryVtable = { + mock_ipv6_disabled_source_addr_factory_get_source_addr, + mock_ipv6_disabled_source_addr_factory_destroy, +}; + +int main(int argc, char** argv) { + // First set the resolver type based off of --resolver + const char* resolver_type = nullptr; + gpr_cmdline* cl = gpr_cmdline_create("resolve address test"); + gpr_cmdline_add_string(cl, "resolver", "Resolver type (ares or native)", + &resolver_type); + // In case that there are more than one argument on the command line, + // --resolver will always be the first one, so only parse the first argument + // (other arguments may be unknown to cl) + gpr_cmdline_parse(cl, argc > 2 ? 2 : argc, argv); + grpc_core::UniquePtr resolver = + GPR_GLOBAL_CONFIG_GET(grpc_dns_resolver); + if (strlen(resolver.get()) != 0) { + gpr_log(GPR_INFO, "Warning: overriding resolver setting of %s", + resolver.get()); + } + if (resolver_type != nullptr && gpr_stricmp(resolver_type, "native") == 0) { + GPR_GLOBAL_CONFIG_SET(grpc_dns_resolver, "native"); + } else if (resolver_type != nullptr && + gpr_stricmp(resolver_type, "ares") == 0) { + GPR_GLOBAL_CONFIG_SET(grpc_dns_resolver, "ares"); + } else { + gpr_log(GPR_ERROR, "--resolver_type was not set to ares or native"); + abort(); + } + // Run the test. + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + test_localhost(); + test_default_port(); + test_non_numeric_default_port(); + test_missing_default_port(); + test_ipv6_with_port(); + test_ipv6_without_port(); + test_invalid_ip_addresses(); + test_unparseable_hostports(); + if (gpr_stricmp(resolver_type, "ares") == 0) { + // This behavior expectation is specific to c-ares. + test_localhost_result_has_ipv6_first(); + } + grpc_core::Executor::ShutdownAll(); + } + gpr_cmdline_destroy(cl); + grpc_shutdown(); + // The following test uses + // "address_sorting_override_source_addr_factory_for_testing", which works + // on a per-grpc-init basis, and so it's simplest to run this next test + // within a standalone grpc_init/grpc_shutdown pair. + if (gpr_stricmp(resolver_type, "ares") == 0) { + // Run a test case in which c-ares's address sorter + // thinks that IPv4 is available and IPv6 isn't. + grpc_init(); + mock_ipv6_disabled_source_addr_factory* factory = + static_cast( + gpr_malloc(sizeof(mock_ipv6_disabled_source_addr_factory))); + factory->base.vtable = &kMockIpv6DisabledSourceAddrFactoryVtable; + address_sorting_override_source_addr_factory_for_testing(&factory->base); + test_localhost_result_has_ipv4_first_when_ipv6_isnt_available(); + grpc_shutdown(); + } + return 0; +} diff --git a/test/core/iomgr/resource_quota_test.cc b/test/core/iomgr/resource_quota_test.cc new file mode 100644 index 00000000..959542d3 --- /dev/null +++ b/test/core/iomgr/resource_quota_test.cc @@ -0,0 +1,1027 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/resource_quota.h" + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +gpr_mu g_mu; +gpr_cv g_cv; + +static void inc_int_cb(void* a, grpc_error_handle /*error*/) { + gpr_mu_lock(&g_mu); + ++*static_cast(a); + gpr_cv_signal(&g_cv); + gpr_mu_unlock(&g_mu); +} + +static void assert_counter_becomes(int* ctr, int value) { + gpr_mu_lock(&g_mu); + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5); + while (*ctr != value) { + GPR_ASSERT(!gpr_cv_wait(&g_cv, &g_mu, deadline)); + } + gpr_mu_unlock(&g_mu); +} + +static void set_event_cb(void* a, grpc_error_handle /*error*/) { + gpr_event_set(static_cast(a), reinterpret_cast(1)); +} +grpc_closure* set_event(gpr_event* ev) { + return GRPC_CLOSURE_CREATE(set_event_cb, ev, grpc_schedule_on_exec_ctx); +} + +typedef struct { + size_t size; + grpc_resource_user* resource_user; + grpc_closure* then; +} reclaimer_args; + +static void reclaimer_cb(void* args, grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + reclaimer_args* a = static_cast(args); + grpc_resource_user_free(a->resource_user, a->size); + grpc_resource_user_finish_reclamation(a->resource_user); + grpc_core::Closure::Run(DEBUG_LOCATION, a->then, GRPC_ERROR_NONE); + gpr_free(a); +} + +grpc_closure* make_reclaimer(grpc_resource_user* resource_user, size_t size, + grpc_closure* then) { + reclaimer_args* a = static_cast(gpr_malloc(sizeof(*a))); + a->size = size; + a->resource_user = resource_user; + a->then = then; + return GRPC_CLOSURE_CREATE(reclaimer_cb, a, grpc_schedule_on_exec_ctx); +} + +static void unused_reclaimer_cb(void* arg, grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_CANCELLED); + grpc_core::Closure::Run(DEBUG_LOCATION, static_cast(arg), + GRPC_ERROR_NONE); +} +grpc_closure* make_unused_reclaimer(grpc_closure* then) { + return GRPC_CLOSURE_CREATE(unused_reclaimer_cb, then, + grpc_schedule_on_exec_ctx); +} + +static void destroy_user(grpc_resource_user* usr) { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_unref(usr); +} + +static void test_no_op(void) { + gpr_log(GPR_INFO, "** test_no_op **"); + grpc_resource_quota_unref(grpc_resource_quota_create("test_no_op")); +} + +static void test_resize_then_destroy(void) { + gpr_log(GPR_INFO, "** test_resize_then_destroy **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_resize_then_destroy"); + grpc_resource_quota_resize(q, 1024 * 1024); + grpc_resource_quota_unref(q); +} + +static void test_resource_user_no_op(void) { + gpr_log(GPR_INFO, "** test_resource_user_no_op **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_resource_user_no_op"); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + grpc_resource_quota_unref(q); + destroy_user(usr); +} + +static void test_instant_alloc_then_free(void) { + 
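+  // Allocate 1024 bytes against a 1MB quota and free them again, with each step under its own ExecCtx so any scheduled closures are flushed in between.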
gpr_log(GPR_INFO, "** test_instant_alloc_then_free **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_instant_alloc_then_free"); + grpc_resource_quota_resize(q, 1024 * 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, nullptr)); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); +} + +static void test_instant_alloc_free_pair(void) { + gpr_log(GPR_INFO, "** test_instant_alloc_free_pair **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_instant_alloc_free_pair"); + grpc_resource_quota_resize(q, 1024 * 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, nullptr)); + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); +} + +static void test_simple_async_alloc(void) { + gpr_log(GPR_INFO, "** test_simple_async_alloc **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_simple_async_alloc"); + grpc_resource_quota_resize(q, 1024 * 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + { + // Now the allocation should be inline. + GPR_ASSERT(grpc_resource_user_alloc(usr, 1024, nullptr)); + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); +} + +static void test_async_alloc_blocked_by_size(void) { + gpr_log(GPR_INFO, "** test_async_alloc_blocked_by_size **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_async_alloc_blocked_by_size"); + grpc_resource_quota_resize(q, 1); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + gpr_event ev; + gpr_event_init(&ev); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait( + &ev, grpc_timeout_milliseconds_to_deadline(100)) == nullptr); + } + grpc_resource_quota_resize(q, 1024); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); +} + +static void test_scavenge(void) { + gpr_log(GPR_INFO, "** test_scavenge **"); + grpc_resource_quota* q = grpc_resource_quota_create("test_scavenge"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr1 = grpc_resource_user_create(q, "usr1"); + grpc_resource_user* usr2 = grpc_resource_user_create(q, "usr2"); + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr1, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr1, 1024); + } + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr2, 1024, set_event(&ev))); + 
grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr2, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr1); + destroy_user(usr2); +} + +static void test_scavenge_blocked(void) { + gpr_log(GPR_INFO, "** test_scavenge_blocked **"); + grpc_resource_quota* q = grpc_resource_quota_create("test_scavenge_blocked"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr1 = grpc_resource_user_create(q, "usr1"); + grpc_resource_user* usr2 = grpc_resource_user_create(q, "usr2"); + gpr_event ev; + { + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr1, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + { + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr2, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait( + &ev, grpc_timeout_milliseconds_to_deadline(100)) == nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr1, 1024); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr2, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr1); + destroy_user(usr2); +} + +static void test_blocked_until_scheduled_reclaim(void) { + gpr_log(GPR_INFO, "** test_blocked_until_scheduled_reclaim **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_blocked_until_scheduled_reclaim"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + gpr_event reclaim_done; + gpr_event_init(&reclaim_done); + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr, false, make_reclaimer(usr, 1024, set_event(&reclaim_done))); + } + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&reclaim_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); +} + +static void test_blocked_until_scheduled_reclaim_and_scavenge(void) { + gpr_log(GPR_INFO, "** test_blocked_until_scheduled_reclaim_and_scavenge **"); + grpc_resource_quota* q = grpc_resource_quota_create( + "test_blocked_until_scheduled_reclaim_and_scavenge"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr1 = grpc_resource_user_create(q, "usr1"); + grpc_resource_user* usr2 = grpc_resource_user_create(q, "usr2"); + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr1, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + gpr_event 
reclaim_done; + gpr_event_init(&reclaim_done); + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr1, false, make_reclaimer(usr1, 1024, set_event(&reclaim_done))); + } + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr2, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&reclaim_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr2, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr1); + destroy_user(usr2); +} + +static void test_blocked_until_scheduled_destructive_reclaim(void) { + gpr_log(GPR_INFO, "** test_blocked_until_scheduled_destructive_reclaim **"); + grpc_resource_quota* q = grpc_resource_quota_create( + "test_blocked_until_scheduled_destructive_reclaim"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + gpr_event reclaim_done; + gpr_event_init(&reclaim_done); + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr, true, make_reclaimer(usr, 1024, set_event(&reclaim_done))); + } + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&reclaim_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); +} + +static void test_unused_reclaim_is_cancelled(void) { + gpr_log(GPR_INFO, "** test_unused_reclaim_is_cancelled **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_unused_reclaim_is_cancelled"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + gpr_event benign_done; + gpr_event_init(&benign_done); + gpr_event destructive_done; + gpr_event_init(&destructive_done); + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr, false, make_unused_reclaimer(set_event(&benign_done))); + grpc_resource_user_post_reclaimer( + usr, true, make_unused_reclaimer(set_event(&destructive_done))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + } + grpc_resource_quota_unref(q); + destroy_user(usr); + GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); +} + +static void test_benign_reclaim_is_preferred(void) { + gpr_log(GPR_INFO, "** test_benign_reclaim_is_preferred **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_benign_reclaim_is_preferred"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, 
"usr"); + gpr_event benign_done; + gpr_event_init(&benign_done); + gpr_event destructive_done; + gpr_event_init(&destructive_done); + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr, false, make_reclaimer(usr, 1024, set_event(&benign_done))); + grpc_resource_user_post_reclaimer( + usr, true, make_unused_reclaimer(set_event(&destructive_done))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + } + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); + GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); +} + +static void test_multiple_reclaims_can_be_triggered(void) { + gpr_log(GPR_INFO, "** test_multiple_reclaims_can_be_triggered **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_multiple_reclaims_can_be_triggered"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + gpr_event benign_done; + gpr_event_init(&benign_done); + gpr_event destructive_done; + gpr_event_init(&destructive_done); + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr, false, make_reclaimer(usr, 512, set_event(&benign_done))); + grpc_resource_user_post_reclaimer( + usr, true, make_reclaimer(usr, 512, set_event(&destructive_done))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + } + { + gpr_event ev; + gpr_event_init(&ev); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&ev))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&ev, grpc_timeout_seconds_to_deadline(5)) != + nullptr); + ; + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + grpc_resource_quota_unref(q); + destroy_user(usr); + 
GPR_ASSERT(gpr_event_wait(&benign_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&destructive_done, + grpc_timeout_seconds_to_deadline(5)) != nullptr); +} + +static void test_resource_user_stays_allocated_until_memory_released(void) { + gpr_log(GPR_INFO, + "** test_resource_user_stays_allocated_until_memory_released **"); + grpc_resource_quota* q = grpc_resource_quota_create( + "test_resource_user_stays_allocated_until_memory_released"); + grpc_resource_quota_resize(q, 1024 * 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, nullptr)); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_quota_unref(q); + grpc_resource_user_unref(usr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } +} + +static void +test_resource_user_stays_allocated_and_reclaimers_unrun_until_memory_released( + void) { + gpr_log(GPR_INFO, + "** " + "test_resource_user_stays_allocated_and_reclaimers_unrun_until_" + "memory_released **"); + grpc_resource_quota* q = grpc_resource_quota_create( + "test_resource_user_stays_allocated_and_reclaimers_unrun_until_memory_" + "released"); + grpc_resource_quota_resize(q, 1024); + for (int i = 0; i < 10; i++) { + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + gpr_event reclaimer_cancelled; + gpr_event_init(&reclaimer_cancelled); + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr, false, make_unused_reclaimer(set_event(&reclaimer_cancelled))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&reclaimer_cancelled, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + } + { + gpr_event allocated; + gpr_event_init(&allocated); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&allocated))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&allocated, grpc_timeout_seconds_to_deadline( + 5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&reclaimer_cancelled, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_unref(usr); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&reclaimer_cancelled, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&reclaimer_cancelled, + grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + } + grpc_resource_quota_unref(q); +} + +static void test_reclaimers_can_be_posted_repeatedly(void) { + gpr_log(GPR_INFO, "** test_reclaimers_can_be_posted_repeatedly **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_reclaimers_can_be_posted_repeatedly"); + grpc_resource_quota_resize(q, 1024); + grpc_resource_user* usr = grpc_resource_user_create(q, "usr"); + { + gpr_event allocated; + gpr_event_init(&allocated); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&allocated))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&allocated, + grpc_timeout_seconds_to_deadline(5)) != nullptr); + } + for (int i = 0; i < 10; i++) { + gpr_event reclaimer_done; + gpr_event_init(&reclaimer_done); + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_post_reclaimer( + usr, false, make_reclaimer(usr, 1024, set_event(&reclaimer_done))); + grpc_core::ExecCtx::Get()->Flush(); + 
GPR_ASSERT(gpr_event_wait(&reclaimer_done, + grpc_timeout_milliseconds_to_deadline(100)) == + nullptr); + } + { + gpr_event allocated; + gpr_event_init(&allocated); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_resource_user_alloc(usr, 1024, set_event(&allocated))); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(gpr_event_wait(&allocated, grpc_timeout_seconds_to_deadline( + 5)) != nullptr); + GPR_ASSERT(gpr_event_wait(&reclaimer_done, + grpc_timeout_seconds_to_deadline(5)) != + nullptr); + } + } + { + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_free(usr, 1024); + } + destroy_user(usr); + grpc_resource_quota_unref(q); +} + +static void test_one_slice(void) { + gpr_log(GPR_INFO, "** test_one_slice **"); + grpc_resource_quota* q = grpc_resource_quota_create("test_one_slice"); + grpc_resource_quota_resize(q, 1024); + grpc_slice_allocator* alloc = grpc_slice_allocator_create(q, "usr"); + int num_allocs = 0; + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + { + const int start_allocs = num_allocs; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_slice_allocator_allocate( + alloc, 1024, 1, grpc_slice_allocator_intent::kDefault, &buffer, + inc_int_cb, &num_allocs)); + grpc_core::ExecCtx::Get()->Flush(); + assert_counter_becomes(&num_allocs, start_allocs + 1); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_slice_buffer_destroy_internal(&buffer); + grpc_slice_allocator_destroy(alloc); + } + grpc_resource_quota_unref(q); +} + +static void test_one_slice_through_slice_allocator_factory(void) { + gpr_log(GPR_INFO, "** test_one_slice_through_slice_allocator_factory **"); + grpc_resource_quota* resource_quota = grpc_resource_quota_create( + "test_one_slice_through_slice_allocator_factory"); + int num_allocs = 0; + grpc_resource_quota_resize(resource_quota, 1024); + grpc_slice_allocator_factory* slice_allocator_factory = + grpc_slice_allocator_factory_create(resource_quota); + grpc_slice_allocator* slice_allocator = + grpc_slice_allocator_factory_create_slice_allocator( + slice_allocator_factory, "usr"); + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + { + const int start_allocs = num_allocs; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_slice_allocator_allocate( + slice_allocator, 1024, 1, grpc_slice_allocator_intent::kDefault, + &buffer, inc_int_cb, &num_allocs)); + grpc_core::ExecCtx::Get()->Flush(); + assert_counter_becomes(&num_allocs, start_allocs + 1); + } + { + grpc_core::ExecCtx exec_ctx; + grpc_slice_buffer_destroy_internal(&buffer); + grpc_slice_allocator_destroy(slice_allocator); + grpc_slice_allocator_factory_destroy(slice_allocator_factory); + } +} + +static void test_slice_allocator_pressure_adjusted_allocation() { + gpr_log(GPR_INFO, "** test_slice_allocator_pressure_adjusted_allocation **"); + // Quota large enough to avoid the 1/16 maximum allocation limit. 
+ grpc_resource_quota* resource_quota = grpc_resource_quota_create( + "test_one_slice_through_slice_allocator_factory"); + grpc_resource_quota_resize(resource_quota, 32 * 1024); + grpc_resource_user* black_hole_resource_user = + grpc_resource_user_create(resource_quota, "black hole"); + { + // Consume ~95% of the quota + grpc_core::ExecCtx exec_ctx; + grpc_resource_user_safe_alloc(black_hole_resource_user, 31 * 1024); + } + GPR_ASSERT(grpc_resource_quota_get_memory_pressure(resource_quota) > 0.95); + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + grpc_slice_allocator* constrained_allocator = + grpc_slice_allocator_create(resource_quota, "constrained user"); + { + // Attempt to get 512 bytes + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_slice_allocator_allocate( + constrained_allocator, 2 * 1024, 1, + grpc_slice_allocator_intent::kReadBuffer, &buffer, + [](void*, grpc_error_handle) {}, nullptr)); + } + grpc_slice slice = grpc_slice_buffer_take_first(&buffer); + GPR_ASSERT(grpc_refcounted_slice_length(slice) < 2 * 1024); + GPR_ASSERT(grpc_refcounted_slice_length(slice) >= 256); + { + grpc_core::ExecCtx exec_ctx; + grpc_slice_unref(slice); + grpc_resource_user_free(black_hole_resource_user, 31 * 1024); + grpc_resource_user_unref(black_hole_resource_user); + grpc_slice_allocator_destroy(constrained_allocator); + grpc_resource_quota_unref(resource_quota); + grpc_slice_buffer_destroy_internal(&buffer); + } +} + +static void test_slice_allocator_capped_allocation() { + gpr_log(GPR_INFO, "** test_slice_allocator_pressure_adjusted_allocation **"); + grpc_resource_quota* resource_quota = grpc_resource_quota_create( + "test_one_slice_through_slice_allocator_factory"); + grpc_resource_quota_resize(resource_quota, 32 * 1024); + grpc_arg to_add[2]; + grpc_channel_args* ch_args; + to_add[0] = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_TCP_MIN_READ_CHUNK_SIZE), 1024); + to_add[1] = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_TCP_MAX_READ_CHUNK_SIZE), 2048); + ch_args = grpc_channel_args_copy_and_add(nullptr, to_add, 2); + grpc_slice_allocator* slice_allocator = + grpc_slice_allocator_create(resource_quota, "capped user", ch_args); + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + { + // Attempt to get more than the maximum + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_slice_allocator_allocate( + slice_allocator, 4 * 1024, 1, grpc_slice_allocator_intent::kReadBuffer, + &buffer, [](void*, grpc_error_handle) {}, nullptr)); + } + grpc_slice max_slice = grpc_slice_buffer_take_first(&buffer); + GPR_ASSERT(grpc_refcounted_slice_length(max_slice) == 2048); + { + // Attempt to get less than the minimum + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_slice_allocator_allocate( + slice_allocator, 512, 1, grpc_slice_allocator_intent::kReadBuffer, + &buffer, [](void*, grpc_error_handle) {}, nullptr)); + } + grpc_slice min_slice = grpc_slice_buffer_take_first(&buffer); + GPR_ASSERT(grpc_refcounted_slice_length(min_slice) == 1024); + { + grpc_core::ExecCtx exec_ctx; + grpc_slice_unref(max_slice); + grpc_slice_unref(min_slice); + grpc_slice_allocator_destroy(slice_allocator); + grpc_resource_quota_unref(resource_quota); + grpc_slice_buffer_destroy_internal(&buffer); + grpc_channel_args_destroy(ch_args); + } +} + +static void test_one_slice_deleted_late(void) { + gpr_log(GPR_INFO, "** test_one_slice_deleted_late **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_one_slice_deleted_late"); + grpc_resource_quota_resize(q, 1024); + 
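+  // "Deleted late": the slice buffer holding the allocated slice is destroyed only after the allocator is destroyed and the quota unreffed.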
grpc_slice_allocator* alloc = grpc_slice_allocator_create(q, "usr"); + int num_allocs = 0; + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + { + const int start_allocs = num_allocs; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_slice_allocator_allocate( + alloc, 1024, 1, grpc_slice_allocator_intent::kDefault, &buffer, + inc_int_cb, &num_allocs)); + grpc_core::ExecCtx::Get()->Flush(); + assert_counter_becomes(&num_allocs, start_allocs + 1); + } + + { + grpc_core::ExecCtx exec_ctx; + grpc_slice_allocator_destroy(alloc); + grpc_resource_quota_unref(q); + grpc_slice_buffer_destroy_internal(&buffer); + } +} + +static void test_resize_to_zero(void) { + gpr_log(GPR_INFO, "** test_resize_to_zero **"); + grpc_resource_quota* q = grpc_resource_quota_create("test_resize_to_zero"); + grpc_resource_quota_resize(q, 0); + grpc_resource_quota_unref(q); +} + +static void test_negative_rq_free_pool(void) { + gpr_log(GPR_INFO, "** test_negative_rq_free_pool **"); + grpc_resource_quota* q = + grpc_resource_quota_create("test_negative_rq_free_pool"); + grpc_resource_quota_resize(q, 1024); + grpc_slice_allocator* alloc = grpc_slice_allocator_create(q, "usr"); + int num_allocs = 0; + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + + { + const int start_allocs = num_allocs; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(!grpc_slice_allocator_allocate( + alloc, 1024, 1, grpc_slice_allocator_intent::kDefault, &buffer, + inc_int_cb, &num_allocs)); + grpc_core::ExecCtx::Get()->Flush(); + assert_counter_becomes(&num_allocs, start_allocs + 1); + } + + grpc_resource_quota_resize(q, 512); + + double eps = 0.0001; + GPR_ASSERT(grpc_resource_quota_get_memory_pressure(q) < 1 + eps); + GPR_ASSERT(grpc_resource_quota_get_memory_pressure(q) > 1 - eps); + + { + grpc_core::ExecCtx exec_ctx; + grpc_slice_allocator_destroy(alloc); + grpc_resource_quota_unref(q); + grpc_slice_buffer_destroy_internal(&buffer); + } +} + +// Simple test to check resource quota thread limits +static void test_thread_limit() { + grpc_core::ExecCtx exec_ctx; + + grpc_resource_quota* rq = grpc_resource_quota_create("test_thread_limit"); + grpc_resource_user* ru1 = grpc_resource_user_create(rq, "ru1"); + grpc_resource_user* ru2 = grpc_resource_user_create(rq, "ru2"); + + // Max threads = 100 + grpc_resource_quota_set_max_threads(rq, 100); + + // Request quota for 100 threads (50 for ru1, 50 for ru2) + GPR_ASSERT(grpc_resource_user_allocate_threads(ru1, 10)); + GPR_ASSERT(grpc_resource_user_allocate_threads(ru2, 10)); + GPR_ASSERT(grpc_resource_user_allocate_threads(ru1, 40)); + GPR_ASSERT(grpc_resource_user_allocate_threads(ru2, 40)); + + // Threads exhausted. 
Next request must fail + GPR_ASSERT(!grpc_resource_user_allocate_threads(ru2, 20)); + + // Free 20 threads from two different users + grpc_resource_user_free_threads(ru1, 10); + grpc_resource_user_free_threads(ru2, 10); + + // Next request to 20 threads must succeed + GPR_ASSERT(grpc_resource_user_allocate_threads(ru2, 20)); + + // No more thread quota again + GPR_ASSERT(!grpc_resource_user_allocate_threads(ru1, 20)); + + // Free 10 more + grpc_resource_user_free_threads(ru1, 10); + + GPR_ASSERT(grpc_resource_user_allocate_threads(ru1, 5)); + GPR_ASSERT( + !grpc_resource_user_allocate_threads(ru2, 10)); // Only 5 available + GPR_ASSERT(grpc_resource_user_allocate_threads(ru2, 5)); + + // Teardown (ru1 and ru2 release all the quota back to rq) + grpc_resource_user_unref(ru1); + grpc_resource_user_unref(ru2); + grpc_resource_quota_unref(rq); +} + +// Change max quota in either direction dynamically +static void test_thread_maxquota_change() { + grpc_core::ExecCtx exec_ctx; + + grpc_resource_quota* rq = + grpc_resource_quota_create("test_thread_maxquota_change"); + grpc_resource_user* ru1 = grpc_resource_user_create(rq, "ru1"); + grpc_resource_user* ru2 = grpc_resource_user_create(rq, "ru2"); + + // Max threads = 100 + grpc_resource_quota_set_max_threads(rq, 100); + + // Request quota for 100 threads (50 for ru1, 50 for ru2) + GPR_ASSERT(grpc_resource_user_allocate_threads(ru1, 50)); + GPR_ASSERT(grpc_resource_user_allocate_threads(ru2, 50)); + + // Threads exhausted. Next request must fail + GPR_ASSERT(!grpc_resource_user_allocate_threads(ru2, 20)); + + // Increase maxquota and retry + // Max threads = 150; + grpc_resource_quota_set_max_threads(rq, 150); + GPR_ASSERT(grpc_resource_user_allocate_threads(ru2, 20)); // ru2=70, ru1=50 + + // Decrease maxquota (Note: Quota already given to ru1 and ru2 is + // unaffected) Max threads = 10; + grpc_resource_quota_set_max_threads(rq, 10); + + // New requests will fail until quota is available + GPR_ASSERT(!grpc_resource_user_allocate_threads(ru1, 10)); + + // Make quota available + grpc_resource_user_free_threads(ru1, 50); // ru1 now has 0 + GPR_ASSERT(!grpc_resource_user_allocate_threads(ru1, 10)); // not enough + + grpc_resource_user_free_threads(ru2, 70); // ru2 now has 0 + + // Now we can get quota up-to 10, the current max + GPR_ASSERT(grpc_resource_user_allocate_threads(ru2, 10)); + // No more thread quota again + GPR_ASSERT(!grpc_resource_user_allocate_threads(ru1, 10)); + + // Teardown (ru1 and ru2 release all the quota back to rq) + grpc_resource_user_unref(ru1); + grpc_resource_user_unref(ru2); + grpc_resource_quota_unref(rq); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + gpr_mu_init(&g_mu); + gpr_cv_init(&g_cv); + test_no_op(); + test_resize_then_destroy(); + test_resource_user_no_op(); + test_instant_alloc_then_free(); + test_instant_alloc_free_pair(); + test_simple_async_alloc(); + test_async_alloc_blocked_by_size(); + test_scavenge(); + test_scavenge_blocked(); + test_blocked_until_scheduled_reclaim(); + test_blocked_until_scheduled_reclaim_and_scavenge(); + test_blocked_until_scheduled_destructive_reclaim(); + test_unused_reclaim_is_cancelled(); + test_benign_reclaim_is_preferred(); + test_multiple_reclaims_can_be_triggered(); + test_resource_user_stays_allocated_until_memory_released(); + test_resource_user_stays_allocated_and_reclaimers_unrun_until_memory_released(); + test_reclaimers_can_be_posted_repeatedly(); + test_one_slice(); + test_one_slice_deleted_late(); + 
test_resize_to_zero(); + test_negative_rq_free_pool(); + test_one_slice_through_slice_allocator_factory(); + test_slice_allocator_pressure_adjusted_allocation(); + test_slice_allocator_capped_allocation(); + gpr_mu_destroy(&g_mu); + gpr_cv_destroy(&g_cv); + + // Resource quota thread related + test_thread_limit(); + test_thread_maxquota_change(); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/socket_utils_test.cc b/test/core/iomgr/socket_utils_test.cc new file mode 100644 index 00000000..aa4b1800 --- /dev/null +++ b/test/core/iomgr/socket_utils_test.cc @@ -0,0 +1,172 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_UTILS_COMMON + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/socket_mutator.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "test/core/util/test_config.h" + +struct test_socket_mutator { + grpc_socket_mutator base; + int option_value; +}; + +static bool mutate_fd(int fd, grpc_socket_mutator* mutator) { + int newval; + socklen_t intlen = sizeof(newval); + struct test_socket_mutator* m = + reinterpret_cast(mutator); + + if (0 != setsockopt(fd, IPPROTO_IP, IP_TOS, &m->option_value, + sizeof(m->option_value))) { + return false; + } + if (0 != getsockopt(fd, IPPROTO_IP, IP_TOS, &newval, &intlen)) { + return false; + } + if (newval != m->option_value) { + return false; + } + return true; +} + +static bool mutate_fd_2(const grpc_mutate_socket_info* info, + grpc_socket_mutator* mutator) { + int newval; + socklen_t intlen = sizeof(newval); + struct test_socket_mutator* m = + reinterpret_cast(mutator); + + if (0 != setsockopt(info->fd, IPPROTO_IP, IP_TOS, &m->option_value, + sizeof(m->option_value))) { + return false; + } + if (0 != getsockopt(info->fd, IPPROTO_IP, IP_TOS, &newval, &intlen)) { + return false; + } + if (newval != m->option_value) { + return false; + } + return true; +} + +static void destroy_test_mutator(grpc_socket_mutator* mutator) { + struct test_socket_mutator* m = + reinterpret_cast(mutator); + gpr_free(m); +} + +static int compare_test_mutator(grpc_socket_mutator* a, + grpc_socket_mutator* b) { + struct test_socket_mutator* ma = + reinterpret_cast(a); + struct test_socket_mutator* mb = + reinterpret_cast(b); + return grpc_core::QsortCompare(ma->option_value, mb->option_value); +} + +static const grpc_socket_mutator_vtable mutator_vtable = { + mutate_fd, compare_test_mutator, destroy_test_mutator, nullptr}; + +static const grpc_socket_mutator_vtable mutator_vtable2 = { + nullptr, compare_test_mutator, destroy_test_mutator, mutate_fd_2}; + +static void test_with_vtable(const grpc_socket_mutator_vtable* vtable) { + int sock = socket(PF_INET, SOCK_STREAM, 0); + GPR_ASSERT(sock > 0); + + struct test_socket_mutator mutator; + 
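+  // Each grpc_set_socket_with_mutator call below pushes the mutator's IP_TOS value onto the socket; the mutator's mutate callback reads it back with getsockopt to confirm the kernel accepted it.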
grpc_socket_mutator_init(&mutator.base, vtable); + + mutator.option_value = IPTOS_LOWDELAY; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "set_socket_with_mutator", + grpc_set_socket_with_mutator(sock, GRPC_FD_CLIENT_CONNECTION_USAGE, + (grpc_socket_mutator*)&mutator))); + + mutator.option_value = IPTOS_THROUGHPUT; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "set_socket_with_mutator", + grpc_set_socket_with_mutator(sock, GRPC_FD_CLIENT_CONNECTION_USAGE, + (grpc_socket_mutator*)&mutator))); + + mutator.option_value = IPTOS_RELIABILITY; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "set_socket_with_mutator", + grpc_set_socket_with_mutator(sock, GRPC_FD_CLIENT_CONNECTION_USAGE, + (grpc_socket_mutator*)&mutator))); + + mutator.option_value = -1; + auto err = grpc_set_socket_with_mutator( + sock, GRPC_FD_CLIENT_CONNECTION_USAGE, + reinterpret_cast(&mutator)); + GPR_ASSERT(err != GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(err); +} + +int main(int argc, char** argv) { + int sock; + grpc::testing::TestEnvironment env(argc, argv); + + sock = socket(PF_INET, SOCK_STREAM, 0); + GPR_ASSERT(sock > 0); + + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_nonblocking", + grpc_set_socket_nonblocking(sock, 1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_nonblocking", + grpc_set_socket_nonblocking(sock, 0))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_cloexec", + grpc_set_socket_cloexec(sock, 1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_cloexec", + grpc_set_socket_cloexec(sock, 0))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_reuse_addr", + grpc_set_socket_reuse_addr(sock, 1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_reuse_addr", + grpc_set_socket_reuse_addr(sock, 0))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_low_latency", + grpc_set_socket_low_latency(sock, 1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("set_socket_low_latency", + grpc_set_socket_low_latency(sock, 0))); + + test_with_vtable(&mutator_vtable); + test_with_vtable(&mutator_vtable2); + + close(sock); + + return 0; +} + +#else /* GRPC_POSIX_SOCKET_UTILS_COMMON */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_UTILS_COMMON */ diff --git a/test/core/iomgr/stranded_event_test.cc b/test/core/iomgr/stranded_event_test.cc new file mode 100644 index 00000000..3634b2cc --- /dev/null +++ b/test/core/iomgr/stranded_event_test.cc @@ -0,0 +1,445 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/security/credentials/alts/alts_credentials.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/alts/alts_security_connector.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/uri/uri_parser.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/memory_counters.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace { + +const int kNumMessagePingPongsPerCall = 4000; + +struct TestCall { + explicit TestCall(grpc_channel* channel, grpc_call* call, + grpc_completion_queue* cq) + : channel(channel), call(call), cq(cq) {} + + TestCall(const TestCall& other) = delete; + TestCall& operator=(const TestCall& other) = delete; + + ~TestCall() { + grpc_call_cancel(call, nullptr); + grpc_call_unref(call); + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); + } + + grpc_channel* channel; + grpc_call* call; + grpc_completion_queue* cq; + absl::optional + status; // filled in when the call is finished +}; + +void StartCall(TestCall* test_call) { + grpc_op ops[6]; + grpc_op* op; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + void* tag = test_call; + grpc_call_error error = grpc_call_start_batch( + test_call->call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + cq_verifier* cqv = cq_verifier_create(test_call->cq); + CQ_EXPECT_COMPLETION(cqv, tag, 1); + cq_verify(cqv); + cq_verifier_destroy(cqv); +} + +void SendMessage(grpc_call* call, grpc_completion_queue* cq) { + grpc_slice request_payload_slice = grpc_slice_from_copied_string("a"); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_op ops[6]; + grpc_op* op; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->reserved = nullptr; + op++; + void* tag = call; + grpc_call_error error = grpc_call_start_batch( + call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + cq_verifier* cqv = cq_verifier_create(cq); + CQ_EXPECT_COMPLETION(cqv, tag, 1); + cq_verify(cqv); + cq_verifier_destroy(cqv); + grpc_byte_buffer_destroy(request_payload); +} + +void ReceiveMessage(grpc_call* call, grpc_completion_queue* cq) { + grpc_byte_buffer* request_payload = nullptr; + grpc_op ops[6]; + grpc_op* op; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload; + op->reserved = nullptr; + op++; + void* tag = call; + grpc_call_error error 
= grpc_call_start_batch( + call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + cq_verifier* cqv = cq_verifier_create(cq); + CQ_EXPECT_COMPLETION(cqv, tag, 1); + cq_verify(cqv); + cq_verifier_destroy(cqv); + grpc_byte_buffer_destroy(request_payload); +} + +void ReceiveInitialMetadata(TestCall* test_call, gpr_timespec deadline) { + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array_init(&initial_metadata_recv); + grpc_op ops[6]; + grpc_op* op; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->reserved = nullptr; + op++; + void* tag = test_call; + grpc_call_error error = grpc_call_start_batch( + test_call->call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_event event = + grpc_completion_queue_next(test_call->cq, deadline, nullptr); + if (event.type != GRPC_OP_COMPLETE || !event.success) { + gpr_log(GPR_ERROR, + "Wanted op complete with success, got op type:%d success:%d", + event.type, event.success); + GPR_ASSERT(0); + } + GPR_ASSERT(event.tag == tag); + grpc_metadata_array_destroy(&initial_metadata_recv); +} + +void FinishCall(TestCall* test_call) { + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + grpc_slice details; + grpc_metadata_array_init(&trailing_metadata_recv); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + void* tag = test_call; + grpc_call_error error = grpc_call_start_batch( + test_call->call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_event event = grpc_completion_queue_next( + test_call->cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success); + GPR_ASSERT(event.tag == tag); + test_call->status = status; + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_slice_unref(details); +} + +class TestServer { + public: + explicit TestServer() { + cq_ = grpc_completion_queue_create_for_next(nullptr); + server_ = grpc_server_create(nullptr, nullptr); + address_ = + grpc_core::JoinHostPort("127.0.0.1", grpc_pick_unused_port_or_die()); + grpc_server_register_completion_queue(server_, cq_, nullptr); + GPR_ASSERT(grpc_server_add_insecure_http2_port(server_, address_.c_str())); + grpc_server_start(server_); + thread_ = std::thread(std::bind(&TestServer::AcceptThread, this)); + } + + ~TestServer() { + thread_.join(); + void* shutdown_and_notify_tag = this; + grpc_server_shutdown_and_notify(server_, cq_, shutdown_and_notify_tag); + grpc_event event = grpc_completion_queue_next( + cq_, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.tag == shutdown_and_notify_tag); + GPR_ASSERT(event.success); + grpc_server_destroy(server_); + grpc_completion_queue_shutdown(cq_); + while (grpc_completion_queue_next(cq_, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq_); + } + + std::string address() const { return address_; } + + private: + void AcceptThread() { + grpc_call_details call_details; + 
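+    // Serve exactly one inbound call: request it, send initial metadata,
+    // echo kNumMessagePingPongsPerCall messages back to the client, and then
+    // cancel the call with GRPC_STATUS_PERMISSION_DENIED so that the
+    // client's FinishCall() observes a non-OK status.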
grpc_call_details_init(&call_details); + grpc_metadata_array request_metadata_recv; + grpc_metadata_array_init(&request_metadata_recv); + void* tag = &call_details; + grpc_call* call; + grpc_call_error error = grpc_server_request_call( + server_, &call, &call_details, &request_metadata_recv, cq_, cq_, tag); + GPR_ASSERT(error == GRPC_CALL_OK); + grpc_event event = grpc_completion_queue_next( + cq_, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success); + GPR_ASSERT(event.tag == tag); + grpc_op ops[6]; + grpc_op* op; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(call, ops, static_cast(op - ops), tag, + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + event = grpc_completion_queue_next(cq_, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success); + GPR_ASSERT(event.tag == tag); + for (int i = 0; i < kNumMessagePingPongsPerCall; i++) { + ReceiveMessage(call, cq_); + SendMessage(call, cq_); + } + grpc_call_cancel_with_status(call, GRPC_STATUS_PERMISSION_DENIED, + "test status", nullptr); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(call); + } + + grpc_server* server_; + grpc_completion_queue* cq_; + std::string address_; + std::thread thread_; +}; + +grpc_core::Resolver::Result BuildResolverResponse( + const std::vector& addresses) { + grpc_core::Resolver::Result result; + for (const auto& address_str : addresses) { + absl::StatusOr uri = grpc_core::URI::Parse(address_str); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "Failed to parse. Error: %s", + uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*uri, &address)); + result.addresses.emplace_back(address.addr, address.len, nullptr); + } + return result; +} + +// Perform a simple RPC where the server cancels the request with +// grpc_call_cancel_with_status +TEST(Pollers, TestReadabilityNotificationsDontGetStrandedOnOneCq) { + gpr_log(GPR_DEBUG, "test thread"); + /* 64 is a somewhat arbitary number, the important thing is that it + * exceeds the value of MAX_EPOLL_EVENTS_HANDLED_EACH_POLL_CALL (16), which + * is enough to repro a bug at time of writing. */ + const int kNumCalls = 64; + size_t ping_pong_round = 0; + size_t ping_pongs_done = 0; + grpc_core::Mutex ping_pong_round_mu; + grpc_core::CondVar ping_pong_round_cv; + const std::string kSharedUnconnectableAddress = + grpc_core::JoinHostPort("127.0.0.1", grpc_pick_unused_port_or_die()); + gpr_log(GPR_DEBUG, "created unconnectable address:%s", + kSharedUnconnectableAddress.c_str()); + std::vector threads; + threads.reserve(kNumCalls); + std::vector> test_servers; + // Instantiate servers inline here, so that we get port allocation out of the + // way and don't depend on it during the actual test. It can sometimes take + // time to allocate kNumCalls ports from the port server, and we don't want to + // hit test timeouts because of that. 
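+  // Synchronization scheme for the ping-pong phase below: each worker thread
+  // waits on ping_pong_round_cv until the main thread advances
+  // ping_pong_round, performs one SendMessage/ReceiveMessage exchange, and
+  // then bumps ping_pongs_done. The main thread only starts round N+1 once
+  // all kNumCalls workers have finished round N, so every call has a read
+  // pending on the shared pollers at roughly the same time.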
+ test_servers.reserve(kNumCalls); + for (int i = 0; i < kNumCalls; i++) { + test_servers.push_back(absl::make_unique()); + } + for (int i = 0; i < kNumCalls; i++) { + auto test_server = test_servers[i].get(); + threads.push_back(std::thread([kSharedUnconnectableAddress, + &ping_pong_round, &ping_pongs_done, + &ping_pong_round_mu, &ping_pong_round_cv, + test_server]() { + gpr_log(GPR_DEBUG, "using test_server with address:%s", + test_server->address().c_str()); + std::vector args; + grpc_arg service_config_arg; + service_config_arg.type = GRPC_ARG_STRING; + service_config_arg.key = const_cast(GRPC_ARG_SERVICE_CONFIG); + service_config_arg.value.string = + const_cast("{\"loadBalancingConfig\":[{\"round_robin\":{}}]}"); + args.push_back(service_config_arg); + auto fake_resolver_response_generator = + grpc_core::MakeRefCounted(); + { + grpc_core::ExecCtx exec_ctx; + fake_resolver_response_generator->SetResponse(BuildResolverResponse( + {absl::StrCat("ipv4:", kSharedUnconnectableAddress), + absl::StrCat("ipv4:", test_server->address())})); + } + args.push_back(grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + fake_resolver_response_generator.get())); + grpc_channel_args* channel_args = + grpc_channel_args_copy_and_add(nullptr, args.data(), args.size()); + grpc_channel* channel = grpc_insecure_channel_create( + "fake:///test.server.com", channel_args, nullptr); + grpc_channel_args_destroy(channel_args); + grpc_completion_queue* cq = + grpc_completion_queue_create_for_next(nullptr); + grpc_call* call = grpc_channel_create_call( + channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + auto test_call = absl::make_unique(channel, call, cq); + // Start a call, and ensure that round_robin load balancing is configured + StartCall(test_call.get()); + // Make sure the test is doing what it's meant to be doing + grpc_channel_info channel_info; + memset(&channel_info, 0, sizeof(channel_info)); + char* lb_policy_name = nullptr; + channel_info.lb_policy_name = &lb_policy_name; + grpc_channel_get_info(channel, &channel_info); + EXPECT_EQ(std::string(lb_policy_name), "round_robin") + << "not using round robin; this test has a low chance of hitting the " + "bug that it's meant to try to hit"; + gpr_free(lb_policy_name); + // Receive initial metadata + gpr_log(GPR_DEBUG, + "now receive initial metadata on call with server address:%s", + test_server->address().c_str()); + ReceiveInitialMetadata(test_call.get(), + grpc_timeout_seconds_to_deadline(30)); + for (int i = 1; i <= kNumMessagePingPongsPerCall; i++) { + { + grpc_core::MutexLock lock(&ping_pong_round_mu); + ping_pong_round_cv.SignalAll(); + while (int(ping_pong_round) != i) { + ping_pong_round_cv.Wait(&ping_pong_round_mu); + } + } + SendMessage(test_call->call, test_call->cq); + ReceiveMessage(test_call->call, test_call->cq); + { + grpc_core::MutexLock lock(&ping_pong_round_mu); + ping_pongs_done++; + ping_pong_round_cv.SignalAll(); + } + } + gpr_log(GPR_DEBUG, "now receive status on call with server address:%s", + test_server->address().c_str()); + FinishCall(test_call.get()); + GPR_ASSERT(test_call->status.has_value()); + GPR_ASSERT(test_call->status.value() == GRPC_STATUS_PERMISSION_DENIED); + { + grpc_core::ExecCtx exec_ctx; + fake_resolver_response_generator.reset(); + } + })); + } + for (size_t i = 1; i <= kNumMessagePingPongsPerCall; i++) { + { + grpc_core::MutexLock lock(&ping_pong_round_mu); + while (ping_pongs_done < ping_pong_round * kNumCalls) 
{ + ping_pong_round_cv.Wait(&ping_pong_round_mu); + } + ping_pong_round++; + ping_pong_round_cv.SignalAll(); + gpr_log(GPR_DEBUG, "initiate ping pong round: %ld", ping_pong_round); + } + } + for (auto& thread : threads) { + thread.join(); + } + gpr_log(GPR_DEBUG, "All RPCs completed!"); +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/iomgr/tcp_client_posix_test.cc b/test/core/iomgr/tcp_client_posix_test.cc new file mode 100644 index 00000000..e5ec7bc3 --- /dev/null +++ b/test/core/iomgr/tcp_client_posix_test.cc @@ -0,0 +1,266 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_TCP_CLIENT + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/iomgr/timer.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" + +static grpc_pollset_set* g_pollset_set; +static gpr_mu* g_mu; +static grpc_pollset* g_pollset; +static int g_connections_complete = 0; +static grpc_endpoint* g_connecting = nullptr; + +static grpc_millis test_deadline(void) { + return grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(10)); +} + +static void finish_connection() { + gpr_mu_lock(g_mu); + g_connections_complete++; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); + + gpr_mu_unlock(g_mu); +} + +static void must_succeed(void* /*arg*/, grpc_error_handle error) { + GPR_ASSERT(g_connecting != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_endpoint_shutdown(g_connecting, GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "must_succeed called")); + grpc_endpoint_destroy(g_connecting); + g_connecting = nullptr; + finish_connection(); +} + +static void must_fail(void* /*arg*/, grpc_error_handle error) { + GPR_ASSERT(g_connecting == nullptr); + GPR_ASSERT(error != GRPC_ERROR_NONE); + finish_connection(); +} + +void test_succeeds(void) { + gpr_log(GPR_ERROR, "---- starting test_succeeds() ----"); + grpc_resolved_address resolved_addr; + struct sockaddr_in* addr = + reinterpret_cast(resolved_addr.addr); + int svr_fd; + int r; + int connections_complete_before; + grpc_closure done; + grpc_core::ExecCtx exec_ctx; + + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = static_cast(sizeof(struct sockaddr_in)); + addr->sin_family = AF_INET; + + /* create a phony server */ + svr_fd = socket(AF_INET, SOCK_STREAM, 0); + GPR_ASSERT(svr_fd >= 0); + GPR_ASSERT( + 0 == 
bind(svr_fd, (struct sockaddr*)addr, (socklen_t)resolved_addr.len)); + GPR_ASSERT(0 == listen(svr_fd, 1)); + + gpr_mu_lock(g_mu); + connections_complete_before = g_connections_complete; + gpr_mu_unlock(g_mu); + + /* connect to it */ + GPR_ASSERT(getsockname(svr_fd, (struct sockaddr*)addr, + (socklen_t*)&resolved_addr.len) == 0); + GRPC_CLOSURE_INIT(&done, must_succeed, nullptr, grpc_schedule_on_exec_ctx); + grpc_tcp_client_connect( + &done, &g_connecting, grpc_slice_allocator_create_unlimited(), + g_pollset_set, nullptr, &resolved_addr, GRPC_MILLIS_INF_FUTURE); + /* await the connection */ + do { + resolved_addr.len = static_cast(sizeof(addr)); + r = accept(svr_fd, reinterpret_cast(addr), + reinterpret_cast(&resolved_addr.len)); + } while (r == -1 && errno == EINTR); + GPR_ASSERT(r >= 0); + close(r); + + gpr_mu_lock(g_mu); + + while (g_connections_complete == connections_complete_before) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, + grpc_timespec_to_millis_round_up( + grpc_timeout_seconds_to_deadline(5))))); + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + } + + gpr_mu_unlock(g_mu); + gpr_log(GPR_ERROR, "---- finished test_succeeds() ----"); +} + +void test_fails(void) { + gpr_log(GPR_ERROR, "---- starting test_fails() ----"); + grpc_resolved_address resolved_addr; + struct sockaddr_in* addr = + reinterpret_cast(resolved_addr.addr); + int connections_complete_before; + grpc_closure done; + grpc_core::ExecCtx exec_ctx; + + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = static_cast(sizeof(struct sockaddr_in)); + addr->sin_family = AF_INET; + + gpr_mu_lock(g_mu); + connections_complete_before = g_connections_complete; + gpr_mu_unlock(g_mu); + + /* connect to a broken address */ + GRPC_CLOSURE_INIT(&done, must_fail, nullptr, grpc_schedule_on_exec_ctx); + grpc_tcp_client_connect( + &done, &g_connecting, grpc_slice_allocator_create_unlimited(), + g_pollset_set, nullptr, &resolved_addr, GRPC_MILLIS_INF_FUTURE); + gpr_mu_lock(g_mu); + + /* wait for the connection callback to finish */ + while (g_connections_complete == connections_complete_before) { + grpc_pollset_worker* worker = nullptr; + grpc_millis polling_deadline = test_deadline(); + switch (grpc_timer_check(&polling_deadline)) { + case GRPC_TIMERS_FIRED: + break; + case GRPC_TIMERS_NOT_CHECKED: + polling_deadline = 0; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_TIMERS_CHECKED_AND_EMPTY: + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, polling_deadline))); + break; + } + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + } + + gpr_mu_unlock(g_mu); + gpr_log(GPR_ERROR, "---- finished test_fails() ----"); +} + +void test_fails_bad_addr_no_leak(void) { + gpr_log(GPR_ERROR, "---- starting test_fails_bad_addr_no_leak() ----"); + grpc_resolved_address resolved_addr; + struct sockaddr_in* addr = + reinterpret_cast(resolved_addr.addr); + int connections_complete_before; + grpc_closure done; + grpc_core::ExecCtx exec_ctx; + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = static_cast(sizeof(struct sockaddr_in)); + // force `grpc_tcp_client_prepare_fd` to fail. contrived, but effective. + addr->sin_family = AF_IPX; + gpr_mu_lock(g_mu); + connections_complete_before = g_connections_complete; + gpr_mu_unlock(g_mu); + // connect to an invalid address. 
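+  // Issue the connect with the must_fail() callback, then (as in
+  // test_fails() above) alternate between checking timers and polling until
+  // the callback bumps g_connections_complete.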
+ GRPC_CLOSURE_INIT(&done, must_fail, nullptr, grpc_schedule_on_exec_ctx); + grpc_tcp_client_connect( + &done, &g_connecting, grpc_slice_allocator_create_unlimited(), + g_pollset_set, nullptr, &resolved_addr, GRPC_MILLIS_INF_FUTURE); + gpr_mu_lock(g_mu); + while (g_connections_complete == connections_complete_before) { + grpc_pollset_worker* worker = nullptr; + grpc_millis polling_deadline = test_deadline(); + switch (grpc_timer_check(&polling_deadline)) { + case GRPC_TIMERS_FIRED: + break; + case GRPC_TIMERS_NOT_CHECKED: + polling_deadline = 0; + ABSL_FALLTHROUGH_INTENDED; + case GRPC_TIMERS_CHECKED_AND_EMPTY: + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, polling_deadline))); + break; + } + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); + gpr_log(GPR_ERROR, "---- finished test_fails_bad_addr_no_leak() ----"); +} + +static void destroy_pollset(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(p)); +} + +int main(int argc, char** argv) { + grpc_closure destroyed; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + { + grpc_core::ExecCtx exec_ctx; + g_pollset_set = grpc_pollset_set_create(); + g_pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(g_pollset, &g_mu); + grpc_pollset_set_add_pollset(g_pollset_set, g_pollset); + + test_succeeds(); + test_fails(); + test_fails_bad_addr_no_leak(); + grpc_pollset_set_destroy(g_pollset_set); + GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(g_pollset, &destroyed); + } + + grpc_shutdown(); + gpr_free(g_pollset); + return 0; +} + +#else /* GRPC_POSIX_SOCKET_TCP_CLIENT */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_CLIENT */ diff --git a/test/core/iomgr/tcp_posix_test.cc b/test/core/iomgr/tcp_posix_test.cc new file mode 100644 index 00000000..4d7e371e --- /dev/null +++ b/test/core/iomgr/tcp_posix_test.cc @@ -0,0 +1,652 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_TCP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/buffer_list.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/sockaddr_posix.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/iomgr/endpoint_tests.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" + +static gpr_mu* g_mu; +static grpc_pollset* g_pollset; + +/* + General test notes: + + All tests which write data into a socket write i%256 into byte i, which is + verified by readers. 
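+   For example, a 300-byte write carries bytes 0,1,...,255,0,1,...,43, and a
+   reader that has already consumed N bytes expects the next byte to be
+   N % 256; see fill_socket_partial() and read_cb() below.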
+ + In general there are a few interesting things to vary which may lead to + exercising different codepaths in an implementation: + 1. Total amount of data written to the socket + 2. Size of slice allocations + 3. Amount of data we read from or write to the socket at once + + The tests here tend to parameterize these where applicable. + + */ + +static void create_sockets(int sv[2]) { + int flags; + GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv) == 0); + flags = fcntl(sv[0], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[0], F_SETFL, flags | O_NONBLOCK) == 0); + flags = fcntl(sv[1], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[1], F_SETFL, flags | O_NONBLOCK) == 0); +} + +static void create_inet_sockets(int sv[2]) { + /* Prepare listening socket */ + struct sockaddr_in addr; + memset(&addr, 0, sizeof(struct sockaddr_in)); + addr.sin_family = AF_INET; + int sock = socket(AF_INET, SOCK_STREAM, 0); + GPR_ASSERT(sock); + GPR_ASSERT(bind(sock, (sockaddr*)&addr, sizeof(sockaddr_in)) == 0); + listen(sock, 1); + + /* Prepare client socket and connect to server */ + socklen_t len = sizeof(sockaddr_in); + GPR_ASSERT(getsockname(sock, (sockaddr*)&addr, &len) == 0); + + int client = socket(AF_INET, SOCK_STREAM, 0); + GPR_ASSERT(client); + int ret; + do { + ret = connect(client, reinterpret_cast(&addr), + sizeof(sockaddr_in)); + } while (ret == -1 && errno == EINTR); + + /* Accept client connection */ + len = sizeof(socklen_t); + int server; + do { + server = accept(sock, reinterpret_cast(&addr), &len); + } while (server == -1 && errno == EINTR); + GPR_ASSERT(server != -1); + + sv[0] = server; + sv[1] = client; + int flags = fcntl(sv[0], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[0], F_SETFL, flags | O_NONBLOCK) == 0); + flags = fcntl(sv[1], F_GETFL, 0); + GPR_ASSERT(fcntl(sv[1], F_SETFL, flags | O_NONBLOCK) == 0); +} + +static ssize_t fill_socket(int fd) { + ssize_t write_bytes; + ssize_t total_bytes = 0; + int i; + unsigned char buf[256]; + for (i = 0; i < 256; ++i) { + buf[i] = static_cast(i); + } + do { + write_bytes = write(fd, buf, 256); + if (write_bytes > 0) { + total_bytes += write_bytes; + } + } while (write_bytes >= 0 || errno == EINTR); + GPR_ASSERT(errno == EAGAIN); + return total_bytes; +} + +static size_t fill_socket_partial(int fd, size_t bytes) { + ssize_t write_bytes; + size_t total_bytes = 0; + unsigned char* buf = static_cast(gpr_malloc(bytes)); + unsigned i; + for (i = 0; i < bytes; ++i) { + buf[i] = static_cast(i % 256); + } + + do { + write_bytes = write(fd, buf, bytes - total_bytes); + if (write_bytes > 0) { + total_bytes += static_cast(write_bytes); + } + } while ((write_bytes >= 0 || errno == EINTR) && bytes > total_bytes); + + gpr_free(buf); + + return total_bytes; +} + +struct read_socket_state { + grpc_endpoint* ep; + size_t read_bytes; + size_t target_read_bytes; + grpc_slice_buffer incoming; + grpc_closure read_cb; +}; + +static size_t count_slices(grpc_slice* slices, size_t nslices, + int* current_data) { + size_t num_bytes = 0; + unsigned i, j; + unsigned char* buf; + for (i = 0; i < nslices; ++i) { + buf = GRPC_SLICE_START_PTR(slices[i]); + for (j = 0; j < GRPC_SLICE_LENGTH(slices[i]); ++j) { + GPR_ASSERT(buf[j] == *current_data); + *current_data = (*current_data + 1) % 256; + } + num_bytes += GRPC_SLICE_LENGTH(slices[i]); + } + return num_bytes; +} + +static void read_cb(void* user_data, grpc_error_handle error) { + struct read_socket_state* state = + static_cast(user_data); + size_t read_bytes; + int current_data; + + GPR_ASSERT(error == GRPC_ERROR_NONE); + + gpr_mu_lock(g_mu); + current_data 
= state->read_bytes % 256; + read_bytes = count_slices(state->incoming.slices, state->incoming.count, + ¤t_data); + state->read_bytes += read_bytes; + gpr_log(GPR_INFO, "Read %" PRIuPTR " bytes of %" PRIuPTR, read_bytes, + state->target_read_bytes); + if (state->read_bytes >= state->target_read_bytes) { + GPR_ASSERT( + GRPC_LOG_IF_ERROR("kick", grpc_pollset_kick(g_pollset, nullptr))); + gpr_mu_unlock(g_mu); + } else { + gpr_mu_unlock(g_mu); + grpc_endpoint_read(state->ep, &state->incoming, &state->read_cb, + /*urgent=*/false); + } +} + +/* Write to a socket, then read from it using the grpc_tcp API. */ +static void read_test(size_t num_bytes, size_t slice_size) { + int sv[2]; + grpc_endpoint* ep; + struct read_socket_state state; + size_t written_bytes; + grpc_millis deadline = + grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(20)); + grpc_core::ExecCtx exec_ctx; + + gpr_log(GPR_INFO, "Read test of size %" PRIuPTR ", slice size %" PRIuPTR, + num_bytes, slice_size); + + create_sockets(sv); + + grpc_arg a[1]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER, + a[0].value.integer = static_cast(slice_size); + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + ep = grpc_tcp_create(grpc_fd_create(sv[1], "read_test", false), &args, "test", + grpc_slice_allocator_create_unlimited()); + grpc_endpoint_add_to_pollset(ep, g_pollset); + + written_bytes = fill_socket_partial(sv[0], num_bytes); + gpr_log(GPR_INFO, "Wrote %" PRIuPTR " bytes", written_bytes); + + state.ep = ep; + state.read_bytes = 0; + state.target_read_bytes = written_bytes; + grpc_slice_buffer_init(&state.incoming); + GRPC_CLOSURE_INIT(&state.read_cb, read_cb, &state, grpc_schedule_on_exec_ctx); + + grpc_endpoint_read(ep, &state.incoming, &state.read_cb, /*urgent=*/false); + + gpr_mu_lock(g_mu); + while (state.read_bytes < state.target_read_bytes) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(g_pollset, &worker, deadline))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + GPR_ASSERT(state.read_bytes == state.target_read_bytes); + gpr_mu_unlock(g_mu); + + grpc_slice_buffer_destroy_internal(&state.incoming); + grpc_endpoint_destroy(ep); +} + +/* Write to a socket until it fills up, then read from it using the grpc_tcp + API. 
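+   fill_socket() keeps write()ing until the non-blocking socket returns
+   EAGAIN, so the amount of data in flight depends on the kernel's socket
+   buffer sizes rather than on a fixed byte count.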
*/ +static void large_read_test(size_t slice_size) { + int sv[2]; + grpc_endpoint* ep; + struct read_socket_state state; + ssize_t written_bytes; + grpc_millis deadline = + grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(20)); + grpc_core::ExecCtx exec_ctx; + + gpr_log(GPR_INFO, "Start large read test, slice size %" PRIuPTR, slice_size); + + create_sockets(sv); + + grpc_arg a[1]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER; + a[0].value.integer = static_cast(slice_size); + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + ep = grpc_tcp_create(grpc_fd_create(sv[1], "large_read_test", false), &args, + "test", grpc_slice_allocator_create_unlimited()); + grpc_endpoint_add_to_pollset(ep, g_pollset); + + written_bytes = fill_socket(sv[0]); + gpr_log(GPR_INFO, "Wrote %" PRIuPTR " bytes", written_bytes); + + state.ep = ep; + state.read_bytes = 0; + state.target_read_bytes = static_cast(written_bytes); + grpc_slice_buffer_init(&state.incoming); + GRPC_CLOSURE_INIT(&state.read_cb, read_cb, &state, grpc_schedule_on_exec_ctx); + + grpc_endpoint_read(ep, &state.incoming, &state.read_cb, /*urgent=*/false); + + gpr_mu_lock(g_mu); + while (state.read_bytes < state.target_read_bytes) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(g_pollset, &worker, deadline))); + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + GPR_ASSERT(state.read_bytes == state.target_read_bytes); + gpr_mu_unlock(g_mu); + + grpc_slice_buffer_destroy_internal(&state.incoming); + grpc_endpoint_destroy(ep); +} + +struct write_socket_state { + grpc_endpoint* ep; + int write_done; +}; + +static grpc_slice* allocate_blocks(size_t num_bytes, size_t slice_size, + size_t* num_blocks, uint8_t* current_data) { + size_t nslices = num_bytes / slice_size + (num_bytes % slice_size ? 1u : 0u); + grpc_slice* slices = + static_cast(gpr_malloc(sizeof(grpc_slice) * nslices)); + size_t num_bytes_left = num_bytes; + unsigned i, j; + unsigned char* buf; + *num_blocks = nslices; + + for (i = 0; i < nslices; ++i) { + slices[i] = grpc_slice_malloc(slice_size > num_bytes_left ? num_bytes_left + : slice_size); + num_bytes_left -= GRPC_SLICE_LENGTH(slices[i]); + buf = GRPC_SLICE_START_PTR(slices[i]); + for (j = 0; j < GRPC_SLICE_LENGTH(slices[i]); ++j) { + buf[j] = *current_data; + (*current_data)++; + } + } + GPR_ASSERT(num_bytes_left == 0); + return slices; +} + +static void write_done(void* user_data /* write_socket_state */, + grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + struct write_socket_state* state = + static_cast(user_data); + gpr_mu_lock(g_mu); + state->write_done = 1; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); + gpr_mu_unlock(g_mu); +} + +void drain_socket_blocking(int fd, size_t num_bytes, size_t read_size) { + unsigned char* buf = static_cast(gpr_malloc(read_size)); + ssize_t bytes_read; + size_t bytes_left = num_bytes; + int flags; + int current = 0; + int i; + grpc_core::ExecCtx exec_ctx; + + flags = fcntl(fd, F_GETFL, 0); + GPR_ASSERT(fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) == 0); + + for (;;) { + grpc_pollset_worker* worker = nullptr; + gpr_mu_lock(g_mu); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(g_pollset, &worker, + grpc_timespec_to_millis_round_up( + grpc_timeout_milliseconds_to_deadline(10))))); + gpr_mu_unlock(g_mu); + + do { + bytes_read = + read(fd, buf, bytes_left > read_size ? 
read_size : bytes_left); + } while (bytes_read < 0 && errno == EINTR); + GPR_ASSERT(bytes_read >= 0); + for (i = 0; i < bytes_read; ++i) { + GPR_ASSERT(buf[i] == current); + current = (current + 1) % 256; + } + bytes_left -= static_cast(bytes_read); + if (bytes_left == 0) break; + } + flags = fcntl(fd, F_GETFL, 0); + GPR_ASSERT(fcntl(fd, F_SETFL, flags | O_NONBLOCK) == 0); + + gpr_free(buf); +} + +/* Verifier for timestamps callback for write_test */ +void timestamps_verifier(void* arg, grpc_core::Timestamps* ts, + grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(arg != nullptr); + GPR_ASSERT(ts->sendmsg_time.time.clock_type == GPR_CLOCK_REALTIME); + GPR_ASSERT(ts->scheduled_time.time.clock_type == GPR_CLOCK_REALTIME); + GPR_ASSERT(ts->acked_time.time.clock_type == GPR_CLOCK_REALTIME); + gpr_atm* done_timestamps = static_cast(arg); + gpr_atm_rel_store(done_timestamps, static_cast(1)); +} + +/* Write to a socket using the grpc_tcp API, then drain it directly. + Note that if the write does not complete immediately we need to drain the + socket in parallel with the read. If collect_timestamps is true, it will + try to get timestamps for the write. */ +static void write_test(size_t num_bytes, size_t slice_size, + bool collect_timestamps) { + int sv[2]; + grpc_endpoint* ep; + struct write_socket_state state; + size_t num_blocks; + grpc_slice* slices; + uint8_t current_data = 0; + grpc_slice_buffer outgoing; + grpc_closure write_done_closure; + grpc_millis deadline = + grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(20)); + grpc_core::ExecCtx exec_ctx; + + if (collect_timestamps && !grpc_event_engine_can_track_errors()) { + return; + } + + gpr_log(GPR_INFO, + "Start write test with %" PRIuPTR " bytes, slice size %" PRIuPTR, + num_bytes, slice_size); + + if (collect_timestamps) { + create_inet_sockets(sv); + } else { + create_sockets(sv); + } + + grpc_arg a[1]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER, + a[0].value.integer = static_cast(slice_size); + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + ep = grpc_tcp_create(grpc_fd_create(sv[1], "write_test", collect_timestamps), + &args, "test", grpc_slice_allocator_create_unlimited()); + grpc_endpoint_add_to_pollset(ep, g_pollset); + + state.ep = ep; + state.write_done = 0; + + slices = allocate_blocks(num_bytes, slice_size, &num_blocks, ¤t_data); + + grpc_slice_buffer_init(&outgoing); + grpc_slice_buffer_addn(&outgoing, slices, num_blocks); + GRPC_CLOSURE_INIT(&write_done_closure, write_done, &state, + grpc_schedule_on_exec_ctx); + + gpr_atm done_timestamps; + gpr_atm_rel_store(&done_timestamps, static_cast(0)); + grpc_endpoint_write(ep, &outgoing, &write_done_closure, + grpc_event_engine_can_track_errors() && collect_timestamps + ? 
&done_timestamps + : nullptr); + drain_socket_blocking(sv[0], num_bytes, num_bytes); + exec_ctx.Flush(); + gpr_mu_lock(g_mu); + for (;;) { + grpc_pollset_worker* worker = nullptr; + if (state.write_done && + (!(grpc_event_engine_can_track_errors() && collect_timestamps) || + gpr_atm_acq_load(&done_timestamps) == static_cast(1))) { + break; + } + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(g_pollset, &worker, deadline))); + gpr_mu_unlock(g_mu); + exec_ctx.Flush(); + gpr_mu_lock(g_mu); + } + gpr_mu_unlock(g_mu); + + grpc_slice_buffer_destroy_internal(&outgoing); + grpc_endpoint_destroy(ep); + gpr_free(slices); +} + +void on_fd_released(void* arg, grpc_error_handle /*errors*/) { + int* done = static_cast(arg); + *done = 1; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); +} + +/* Do a read_test, then release fd and try to read/write again. Verify that + grpc_tcp_fd() is available before the fd is released. */ +static void release_fd_test(size_t num_bytes, size_t slice_size) { + int sv[2]; + grpc_endpoint* ep; + struct read_socket_state state; + size_t written_bytes; + int fd; + grpc_millis deadline = + grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(20)); + grpc_core::ExecCtx exec_ctx; + grpc_closure fd_released_cb; + int fd_released_done = 0; + GRPC_CLOSURE_INIT(&fd_released_cb, &on_fd_released, &fd_released_done, + grpc_schedule_on_exec_ctx); + + gpr_log(GPR_INFO, + "Release fd read_test of size %" PRIuPTR ", slice size %" PRIuPTR, + num_bytes, slice_size); + + create_sockets(sv); + + grpc_arg a[1]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER; + a[0].value.integer = static_cast(slice_size); + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + ep = grpc_tcp_create(grpc_fd_create(sv[1], "read_test", false), &args, "test", + grpc_slice_allocator_create_unlimited()); + GPR_ASSERT(grpc_tcp_fd(ep) == sv[1] && sv[1] >= 0); + grpc_endpoint_add_to_pollset(ep, g_pollset); + + written_bytes = fill_socket_partial(sv[0], num_bytes); + gpr_log(GPR_INFO, "Wrote %" PRIuPTR " bytes", written_bytes); + + state.ep = ep; + state.read_bytes = 0; + state.target_read_bytes = written_bytes; + grpc_slice_buffer_init(&state.incoming); + GRPC_CLOSURE_INIT(&state.read_cb, read_cb, &state, grpc_schedule_on_exec_ctx); + + grpc_endpoint_read(ep, &state.incoming, &state.read_cb, /*urgent=*/false); + + gpr_mu_lock(g_mu); + while (state.read_bytes < state.target_read_bytes) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(g_pollset, &worker, deadline))); + gpr_log(GPR_DEBUG, "wakeup: read=%" PRIdPTR " target=%" PRIdPTR, + state.read_bytes, state.target_read_bytes); + gpr_mu_unlock(g_mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + } + GPR_ASSERT(state.read_bytes == state.target_read_bytes); + gpr_mu_unlock(g_mu); + + grpc_slice_buffer_destroy_internal(&state.incoming); + grpc_tcp_destroy_and_release_fd(ep, &fd, &fd_released_cb); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(g_mu); + while (!fd_released_done) { + grpc_pollset_worker* worker = nullptr; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "pollset_work", grpc_pollset_work(g_pollset, &worker, deadline))); + gpr_log(GPR_DEBUG, "wakeup: fd_released_done=%d", fd_released_done); + } + gpr_mu_unlock(g_mu); + GPR_ASSERT(fd_released_done == 1); + GPR_ASSERT(fd == sv[1]); + + written_bytes = fill_socket_partial(sv[0], num_bytes); + drain_socket_blocking(fd, written_bytes, 
written_bytes); + written_bytes = fill_socket_partial(fd, num_bytes); + drain_socket_blocking(sv[0], written_bytes, written_bytes); + close(fd); +} + +void run_tests(void) { + size_t i = 0; + + read_test(100, 8192); + read_test(10000, 8192); + read_test(10000, 137); + read_test(10000, 1); + large_read_test(8192); + large_read_test(1); + + write_test(100, 8192, false); + write_test(100, 1, false); + write_test(100000, 8192, false); + write_test(100000, 1, false); + write_test(100000, 137, false); + + write_test(100, 8192, true); + write_test(100, 1, true); + write_test(100000, 8192, true); + write_test(100000, 1, true); + write_test(100, 137, true); + + for (i = 1; i < 1000; i = std::max(i + 1, i * 5 / 4)) { + write_test(40320, i, false); + write_test(40320, i, true); + } + + release_fd_test(100, 8192); +} + +static void clean_up(void) {} + +static grpc_endpoint_test_fixture create_fixture_tcp_socketpair( + size_t slice_size) { + int sv[2]; + grpc_endpoint_test_fixture f; + grpc_core::ExecCtx exec_ctx; + + create_sockets(sv); + grpc_arg a[1]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER; + a[0].value.integer = static_cast(slice_size); + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + f.client_ep = + grpc_tcp_create(grpc_fd_create(sv[0], "fixture:client", false), &args, + "test", grpc_slice_allocator_create_unlimited()); + f.server_ep = + grpc_tcp_create(grpc_fd_create(sv[1], "fixture:server", false), &args, + "test", grpc_slice_allocator_create_unlimited()); + grpc_endpoint_add_to_pollset(f.client_ep, g_pollset); + grpc_endpoint_add_to_pollset(f.server_ep, g_pollset); + + return f; +} + +static grpc_endpoint_test_config configs[] = { + {"tcp/tcp_socketpair", create_fixture_tcp_socketpair, clean_up}, +}; + +static void destroy_pollset(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(p)); +} + +int main(int argc, char** argv) { + grpc_closure destroyed; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + grpc_core::grpc_tcp_set_write_timestamps_callback(timestamps_verifier); + { + grpc_core::ExecCtx exec_ctx; + g_pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(g_pollset, &g_mu); + grpc_endpoint_tests(configs[0], g_pollset, g_mu); + run_tests(); + GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(g_pollset, &destroyed); + + grpc_core::ExecCtx::Get()->Flush(); + } + grpc_shutdown(); + gpr_free(g_pollset); + + return 0; +} + +#else /* GRPC_POSIX_SOCKET_TCP */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_TCP */ diff --git a/test/core/iomgr/tcp_server_posix_test.cc b/test/core/iomgr/tcp_server_posix_test.cc new file mode 100644 index 00000000..f83a3704 --- /dev/null +++ b/test/core/iomgr/tcp_server_posix_test.cc @@ -0,0 +1,533 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test won't work except with posix sockets enabled +#ifdef GRPC_POSIX_SOCKET_TCP_SERVER + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST(x) gpr_log(GPR_INFO, "%s", #x) + +static gpr_mu* g_mu; +static grpc_pollset* g_pollset; +static int g_nconnects = 0; + +typedef struct { + /* Owns a ref to server. */ + grpc_tcp_server* server; + unsigned port_index; + unsigned fd_index; + int server_fd; +} on_connect_result; + +typedef struct { + grpc_tcp_server* server; + + /* arg is this server_weak_ref. */ + grpc_closure server_shutdown; +} server_weak_ref; + +#define MAX_URI 1024 +typedef struct { + grpc_resolved_address addr; + char str[MAX_URI]; +} test_addr; + +#define MAX_ADDRS 100 +typedef struct { + size_t naddrs; + test_addr addrs[MAX_ADDRS]; +} test_addrs; + +static on_connect_result g_result = {nullptr, 0, 0, -1}; + +static char family_name_buf[1024]; +static const char* sock_family_name(int family) { + if (family == AF_INET) { + return "AF_INET"; + } else if (family == AF_INET6) { + return "AF_INET6"; + } else if (family == AF_UNSPEC) { + return "AF_UNSPEC"; + } else { + sprintf(family_name_buf, "%d", family); + return family_name_buf; + } +} + +static void on_connect_result_init(on_connect_result* result) { + result->server = nullptr; + result->port_index = 0; + result->fd_index = 0; + result->server_fd = -1; +} + +static void on_connect_result_set(on_connect_result* result, + const grpc_tcp_server_acceptor* acceptor) { + result->server = grpc_tcp_server_ref(acceptor->from_server); + result->port_index = acceptor->port_index; + result->fd_index = acceptor->fd_index; + result->server_fd = grpc_tcp_server_port_fd( + result->server, acceptor->port_index, acceptor->fd_index); +} + +static void server_weak_ref_shutdown(void* arg, grpc_error_handle /*error*/) { + server_weak_ref* weak_ref = static_cast(arg); + weak_ref->server = nullptr; +} + +static void server_weak_ref_init(server_weak_ref* weak_ref) { + weak_ref->server = nullptr; + GRPC_CLOSURE_INIT(&weak_ref->server_shutdown, server_weak_ref_shutdown, + weak_ref, grpc_schedule_on_exec_ctx); +} + +/* Make weak_ref->server_shutdown a shutdown_starting cb on server. + grpc_tcp_server promises that the server object will live until + weak_ref->server_shutdown has returned. A strong ref on grpc_tcp_server + should be held until server_weak_ref_set() returns to avoid a race where the + server is deleted before the shutdown_starting cb is added. 
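+   For illustration, the sequence this test relies on (a sketch of the usage
+   in test_connect() below, not additional API surface):
+
+     server_weak_ref weak_ref;
+     server_weak_ref_init(&weak_ref);
+     server_weak_ref_set(&weak_ref, s);  // while still holding a strong ref
+     ...
+     grpc_tcp_server_unref(s);           // drop the last strong ref
+     grpc_core::ExecCtx::Get()->Flush();
+     GPR_ASSERT(weak_ref.server == nullptr);  // shutdown_starting cb has run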
*/ +static void server_weak_ref_set(server_weak_ref* weak_ref, + grpc_tcp_server* server) { + grpc_tcp_server_shutdown_starting_add(server, &weak_ref->server_shutdown); + weak_ref->server = server; +} + +static void test_addr_init_str(test_addr* addr) { + std::string str = grpc_sockaddr_to_string(&addr->addr, false); + size_t str_len = std::min(str.size(), sizeof(addr->str) - 1); + memcpy(addr->str, str.c_str(), str_len); + addr->str[str_len] = '\0'; +} + +static void on_connect(void* /*arg*/, grpc_endpoint* tcp, + grpc_pollset* /*pollset*/, + grpc_tcp_server_acceptor* acceptor) { + grpc_endpoint_shutdown(tcp, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Connected")); + grpc_endpoint_destroy(tcp); + + on_connect_result temp_result; + on_connect_result_set(&temp_result, acceptor); + gpr_free(acceptor); + + gpr_mu_lock(g_mu); + g_result = temp_result; + g_nconnects++; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(g_pollset, nullptr))); + gpr_mu_unlock(g_mu); +} + +static void test_no_op(void) { + grpc_core::ExecCtx exec_ctx; + grpc_tcp_server* s; + GPR_ASSERT(GRPC_ERROR_NONE == + grpc_tcp_server_create(nullptr, nullptr, + grpc_slice_allocator_factory_create( + grpc_resource_quota_create(nullptr)), + &s)); + grpc_tcp_server_unref(s); +} + +static void test_no_op_with_start(void) { + grpc_core::ExecCtx exec_ctx; + grpc_tcp_server* s; + GPR_ASSERT(GRPC_ERROR_NONE == + grpc_tcp_server_create(nullptr, nullptr, + grpc_slice_allocator_factory_create( + grpc_resource_quota_create(nullptr)), + &s)); + LOG_TEST("test_no_op_with_start"); + std::vector empty_pollset; + grpc_tcp_server_start(s, &empty_pollset, on_connect, nullptr); + grpc_tcp_server_unref(s); +} + +static void test_no_op_with_port(void) { + grpc_core::ExecCtx exec_ctx; + grpc_resolved_address resolved_addr; + struct sockaddr_in* addr = + reinterpret_cast(resolved_addr.addr); + grpc_tcp_server* s; + GPR_ASSERT(GRPC_ERROR_NONE == + grpc_tcp_server_create(nullptr, nullptr, + grpc_slice_allocator_factory_create( + grpc_resource_quota_create(nullptr)), + &s)); + LOG_TEST("test_no_op_with_port"); + + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = static_cast(sizeof(struct sockaddr_in)); + addr->sin_family = AF_INET; + int port = -1; + GPR_ASSERT(grpc_tcp_server_add_port(s, &resolved_addr, &port) == + GRPC_ERROR_NONE && + port > 0); + + grpc_tcp_server_unref(s); +} + +static void test_no_op_with_port_and_start(void) { + grpc_core::ExecCtx exec_ctx; + grpc_resolved_address resolved_addr; + struct sockaddr_in* addr = + reinterpret_cast(resolved_addr.addr); + grpc_tcp_server* s; + GPR_ASSERT(GRPC_ERROR_NONE == + grpc_tcp_server_create(nullptr, nullptr, + grpc_slice_allocator_factory_create( + grpc_resource_quota_create(nullptr)), + &s)); + LOG_TEST("test_no_op_with_port_and_start"); + int port = -1; + + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = static_cast(sizeof(struct sockaddr_in)); + addr->sin_family = AF_INET; + GPR_ASSERT(grpc_tcp_server_add_port(s, &resolved_addr, &port) == + GRPC_ERROR_NONE && + port > 0); + + std::vector empty_pollset; + grpc_tcp_server_start(s, &empty_pollset, on_connect, nullptr); + + grpc_tcp_server_unref(s); +} + +static grpc_error_handle tcp_connect(const test_addr* remote, + on_connect_result* result) { + grpc_millis deadline = + grpc_timespec_to_millis_round_up(grpc_timeout_seconds_to_deadline(10)); + int clifd; + int nconnects_before; + const struct sockaddr* remote_addr = + reinterpret_cast(remote->addr.addr); + + gpr_log(GPR_INFO, "Connecting 
to %s", remote->str); + gpr_mu_lock(g_mu); + nconnects_before = g_nconnects; + on_connect_result_init(&g_result); + clifd = socket(remote_addr->sa_family, SOCK_STREAM, 0); + if (clifd < 0) { + gpr_mu_unlock(g_mu); + return GRPC_OS_ERROR(errno, "Failed to create socket"); + } + gpr_log(GPR_DEBUG, "start connect to %s", remote->str); + if (connect(clifd, remote_addr, static_cast(remote->addr.len)) != + 0) { + gpr_mu_unlock(g_mu); + close(clifd); + return GRPC_OS_ERROR(errno, "connect"); + } + gpr_log(GPR_DEBUG, "wait"); + while (g_nconnects == nconnects_before && + deadline > grpc_core::ExecCtx::Get()->Now()) { + grpc_pollset_worker* worker = nullptr; + grpc_error_handle err; + if ((err = grpc_pollset_work(g_pollset, &worker, deadline)) != + GRPC_ERROR_NONE) { + gpr_mu_unlock(g_mu); + close(clifd); + return err; + } + gpr_mu_unlock(g_mu); + + gpr_mu_lock(g_mu); + } + gpr_log(GPR_DEBUG, "wait done"); + if (g_nconnects != nconnects_before + 1) { + gpr_mu_unlock(g_mu); + close(clifd); + return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Didn't connect"); + } + close(clifd); + *result = g_result; + + gpr_mu_unlock(g_mu); + gpr_log(GPR_INFO, "Result (%d, %d) fd %d", result->port_index, + result->fd_index, result->server_fd); + grpc_tcp_server_unref(result->server); + return GRPC_ERROR_NONE; +} + +/* Tests a tcp server on "::" listeners with multiple ports. If channel_args is + non-NULL, pass them to the server. If dst_addrs is non-NULL, use valid addrs + as destination addrs (port is not set). If dst_addrs is NULL, use listener + addrs as destination addrs. If test_dst_addrs is true, test connectivity with + each destination address, set grpc_resolved_address::len=0 for failures, but + don't fail the overall unitest. */ +static void test_connect(size_t num_connects, + const grpc_channel_args* channel_args, + test_addrs* dst_addrs, bool test_dst_addrs) { + grpc_core::ExecCtx exec_ctx; + grpc_resolved_address resolved_addr; + grpc_resolved_address resolved_addr1; + struct sockaddr_storage* const addr = + reinterpret_cast(resolved_addr.addr); + struct sockaddr_storage* const addr1 = + reinterpret_cast(resolved_addr1.addr); + unsigned svr_fd_count; + int port; + int svr_port; + unsigned svr1_fd_count; + int svr1_port; + grpc_tcp_server* s; + const unsigned num_ports = 2; + GPR_ASSERT(GRPC_ERROR_NONE == + grpc_tcp_server_create( + nullptr, channel_args, + grpc_slice_allocator_factory_create( + grpc_resource_quota_from_channel_args(channel_args, true)), + &s)); + unsigned port_num; + server_weak_ref weak_ref; + server_weak_ref_init(&weak_ref); + server_weak_ref_set(&weak_ref, s); + LOG_TEST("test_connect"); + gpr_log(GPR_INFO, + "clients=%lu, num chan args=%lu, remote IP=%s, test_dst_addrs=%d", + static_cast(num_connects), + static_cast( + channel_args != nullptr ? channel_args->num_args : 0), + dst_addrs != nullptr ? "" : "::", test_dst_addrs); + memset(&resolved_addr, 0, sizeof(resolved_addr)); + memset(&resolved_addr1, 0, sizeof(resolved_addr1)); + resolved_addr.len = static_cast(sizeof(struct sockaddr_storage)); + resolved_addr1.len = static_cast(sizeof(struct sockaddr_storage)); + addr->ss_family = addr1->ss_family = AF_INET; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "grpc_tcp_server_add_port", + grpc_tcp_server_add_port(s, &resolved_addr, &svr_port))); + gpr_log(GPR_INFO, "Allocated port %d", svr_port); + GPR_ASSERT(svr_port > 0); + /* Cannot use wildcard (port==0), because add_port() will try to reuse the + same port as a previous add_port(). 
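+     Hence the second listener below binds an explicitly picked unused port
+     from grpc_pick_unused_port_or_die() instead of asking for port 0.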
*/ + svr1_port = grpc_pick_unused_port_or_die(); + GPR_ASSERT(svr1_port > 0); + gpr_log(GPR_INFO, "Picked unused port %d", svr1_port); + grpc_sockaddr_set_port(&resolved_addr1, svr1_port); + GPR_ASSERT(grpc_tcp_server_add_port(s, &resolved_addr1, &port) == + GRPC_ERROR_NONE && + port == svr1_port); + + /* Bad port_index. */ + GPR_ASSERT(grpc_tcp_server_port_fd_count(s, 2) == 0); + GPR_ASSERT(grpc_tcp_server_port_fd(s, 2, 0) < 0); + + /* Bad fd_index. */ + GPR_ASSERT(grpc_tcp_server_port_fd(s, 0, 100) < 0); + GPR_ASSERT(grpc_tcp_server_port_fd(s, 1, 100) < 0); + + /* Got at least one fd per port. */ + svr_fd_count = grpc_tcp_server_port_fd_count(s, 0); + GPR_ASSERT(svr_fd_count >= 1); + svr1_fd_count = grpc_tcp_server_port_fd_count(s, 1); + GPR_ASSERT(svr1_fd_count >= 1); + + std::vector test_pollset; + test_pollset.push_back(g_pollset); + grpc_tcp_server_start(s, &test_pollset, on_connect, nullptr); + + if (dst_addrs != nullptr) { + int ports[] = {svr_port, svr1_port}; + for (port_num = 0; port_num < num_ports; ++port_num) { + size_t dst_idx; + size_t num_tested = 0; + for (dst_idx = 0; dst_idx < dst_addrs->naddrs; ++dst_idx) { + test_addr dst = dst_addrs->addrs[dst_idx]; + on_connect_result result; + grpc_error_handle err; + if (dst.addr.len == 0) { + gpr_log(GPR_DEBUG, "Skipping test of non-functional local IP %s", + dst.str); + continue; + } + GPR_ASSERT(grpc_sockaddr_set_port(&dst.addr, ports[port_num])); + test_addr_init_str(&dst); + ++num_tested; + on_connect_result_init(&result); + if ((err = tcp_connect(&dst, &result)) == GRPC_ERROR_NONE && + result.server_fd >= 0 && result.server == s) { + continue; + } + gpr_log(GPR_ERROR, "Failed to connect to %s: %s", dst.str, + grpc_error_std_string(err).c_str()); + GPR_ASSERT(test_dst_addrs); + dst_addrs->addrs[dst_idx].addr.len = 0; + GRPC_ERROR_UNREF(err); + } + GPR_ASSERT(num_tested > 0); + } + } else { + for (port_num = 0; port_num < num_ports; ++port_num) { + const unsigned num_fds = grpc_tcp_server_port_fd_count(s, port_num); + unsigned fd_num; + for (fd_num = 0; fd_num < num_fds; ++fd_num) { + int fd = grpc_tcp_server_port_fd(s, port_num, fd_num); + size_t connect_num; + test_addr dst; + GPR_ASSERT(fd >= 0); + dst.addr.len = static_cast(sizeof(dst.addr.addr)); + GPR_ASSERT(getsockname(fd, (struct sockaddr*)dst.addr.addr, + (socklen_t*)&dst.addr.len) == 0); + GPR_ASSERT(dst.addr.len <= sizeof(dst.addr.addr)); + test_addr_init_str(&dst); + gpr_log(GPR_INFO, "(%d, %d) fd %d family %s listening on %s", port_num, + fd_num, fd, sock_family_name(addr->ss_family), dst.str); + for (connect_num = 0; connect_num < num_connects; ++connect_num) { + on_connect_result result; + on_connect_result_init(&result); + GPR_ASSERT( + GRPC_LOG_IF_ERROR("tcp_connect", tcp_connect(&dst, &result))); + GPR_ASSERT(result.server_fd == fd); + GPR_ASSERT(result.port_index == port_num); + GPR_ASSERT(result.fd_index == fd_num); + GPR_ASSERT(result.server == s); + GPR_ASSERT( + grpc_tcp_server_port_fd(s, result.port_index, result.fd_index) == + result.server_fd); + } + } + } + } + /* Weak ref to server valid until final unref. */ + GPR_ASSERT(weak_ref.server != nullptr); + GPR_ASSERT(grpc_tcp_server_port_fd(s, 0, 0) >= 0); + + grpc_tcp_server_unref(s); + grpc_core::ExecCtx::Get()->Flush(); + + /* Weak ref lost. 
*/ + GPR_ASSERT(weak_ref.server == nullptr); +} + +static void destroy_pollset(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(p)); +} + +int main(int argc, char** argv) { + grpc_closure destroyed; + grpc_arg chan_args[1]; + chan_args[0].type = GRPC_ARG_INTEGER; + chan_args[0].key = const_cast(GRPC_ARG_EXPAND_WILDCARD_ADDRS); + chan_args[0].value.integer = 1; + const grpc_channel_args channel_args = {1, chan_args}; + struct ifaddrs* ifa = nullptr; + struct ifaddrs* ifa_it; + // Zalloc dst_addrs to avoid oversized frames. + test_addrs* dst_addrs = grpc_core::Zalloc(); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + // wait a few seconds to make sure IPv6 link-local addresses can be bound + // if we are running under docker container that has just started. + // See https://github.com/moby/moby/issues/38491 + // See https://github.com/grpc/grpc/issues/15610 + gpr_sleep_until(grpc_timeout_seconds_to_deadline(4)); + { + grpc_core::ExecCtx exec_ctx; + g_pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(g_pollset, &g_mu); + + test_no_op(); + test_no_op_with_start(); + test_no_op_with_port(); + test_no_op_with_port_and_start(); + + if (getifaddrs(&ifa) != 0 || ifa == nullptr) { + gpr_log(GPR_ERROR, "getifaddrs: %s", strerror(errno)); + return EXIT_FAILURE; + } + dst_addrs->naddrs = 0; + for (ifa_it = ifa; ifa_it != nullptr && dst_addrs->naddrs < MAX_ADDRS; + ifa_it = ifa_it->ifa_next) { + if (ifa_it->ifa_addr == nullptr) { + continue; + } else if (ifa_it->ifa_addr->sa_family == AF_INET) { + dst_addrs->addrs[dst_addrs->naddrs].addr.len = + static_cast(sizeof(struct sockaddr_in)); + } else if (ifa_it->ifa_addr->sa_family == AF_INET6) { + dst_addrs->addrs[dst_addrs->naddrs].addr.len = + static_cast(sizeof(struct sockaddr_in6)); + } else { + continue; + } + memcpy(dst_addrs->addrs[dst_addrs->naddrs].addr.addr, ifa_it->ifa_addr, + dst_addrs->addrs[dst_addrs->naddrs].addr.len); + GPR_ASSERT( + grpc_sockaddr_set_port(&dst_addrs->addrs[dst_addrs->naddrs].addr, 0)); + test_addr_init_str(&dst_addrs->addrs[dst_addrs->naddrs]); + ++dst_addrs->naddrs; + } + freeifaddrs(ifa); + ifa = nullptr; + + /* Connect to same addresses as listeners. */ + test_connect(1, nullptr, nullptr, false); + test_connect(10, nullptr, nullptr, false); + + /* Set dst_addrs->addrs[i].len=0 for dst_addrs that are unreachable with a + "::" listener. */ + test_connect(1, nullptr, dst_addrs, true); + + /* Test connect(2) with dst_addrs. */ + test_connect(1, &channel_args, dst_addrs, false); + /* Test connect(2) with dst_addrs. */ + test_connect(10, &channel_args, dst_addrs, false); + + GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(g_pollset, &destroyed); + } + grpc_shutdown(); + gpr_free(dst_addrs); + gpr_free(g_pollset); + return EXIT_SUCCESS; +} + +#else /* GRPC_POSIX_SOCKET_SERVER */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_POSIX_SOCKET_SERVER */ diff --git a/test/core/iomgr/threadpool_test.cc b/test/core/iomgr/threadpool_test.cc new file mode 100644 index 00000000..1324b276 --- /dev/null +++ b/test/core/iomgr/threadpool_test.cc @@ -0,0 +1,189 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/executor/threadpool.h" + +#include "test/core/util/test_config.h" + +static const int kSmallThreadPoolSize = 20; +static const int kLargeThreadPoolSize = 100; +static const int kThreadSmallIter = 100; +static const int kThreadLargeIter = 10000; + +static void test_size_zero(void) { + gpr_log(GPR_INFO, "test_size_zero"); + grpc_core::ThreadPool* pool_size_zero = new grpc_core::ThreadPool(0); + GPR_ASSERT(pool_size_zero->pool_capacity() == 1); + delete pool_size_zero; +} + +static void test_constructor_option(void) { + gpr_log(GPR_INFO, "test_constructor_option"); + // Tests options + grpc_core::Thread::Options options; + options.set_stack_size(192 * 1024); // Random non-default value + grpc_core::ThreadPool* pool = + new grpc_core::ThreadPool(0, "test_constructor_option", options); + GPR_ASSERT(pool->thread_options().stack_size() == options.stack_size()); + delete pool; +} + +// Simple functor for testing. It will count how many times being called. +class SimpleFunctorForAdd : public grpc_completion_queue_functor { + public: + friend class SimpleFunctorCheckForAdd; + SimpleFunctorForAdd() { + functor_run = &SimpleFunctorForAdd::Run; + inlineable = true; + internal_next = this; + internal_success = 0; + } + ~SimpleFunctorForAdd() {} + static void Run(struct grpc_completion_queue_functor* cb, int /*ok*/) { + auto* callback = static_cast(cb); + callback->count_.fetch_add(1, std::memory_order_relaxed); + } + + int count() { return count_.load(std::memory_order_relaxed); } + + private: + std::atomic count_{0}; +}; + +static void test_add(void) { + gpr_log(GPR_INFO, "test_add"); + grpc_core::ThreadPool* pool = + new grpc_core::ThreadPool(kSmallThreadPoolSize, "test_add"); + + SimpleFunctorForAdd* functor = new SimpleFunctorForAdd(); + for (int i = 0; i < kThreadSmallIter; ++i) { + pool->Add(functor); + } + delete pool; + GPR_ASSERT(functor->count() == kThreadSmallIter); + delete functor; + gpr_log(GPR_DEBUG, "Done."); +} + +// Thread that adds closures to pool +class WorkThread { + public: + WorkThread(grpc_core::ThreadPool* pool, SimpleFunctorForAdd* cb, int num_add) + : num_add_(num_add), cb_(cb), pool_(pool) { + thd_ = grpc_core::Thread( + "thread_pool_test_add_thd", + [](void* th) { static_cast(th)->Run(); }, this); + } + ~WorkThread() {} + + void Start() { thd_.Start(); } + void Join() { thd_.Join(); } + + private: + void Run() { + for (int i = 0; i < num_add_; ++i) { + pool_->Add(cb_); + } + } + + int num_add_; + SimpleFunctorForAdd* cb_; + grpc_core::ThreadPool* pool_; + grpc_core::Thread thd_; +}; + +static void test_multi_add(void) { + gpr_log(GPR_INFO, "test_multi_add"); + const int num_work_thds = 10; + grpc_core::ThreadPool* pool = + new grpc_core::ThreadPool(kLargeThreadPoolSize, "test_multi_add"); + SimpleFunctorForAdd* functor = new SimpleFunctorForAdd(); + WorkThread** work_thds = static_cast( + gpr_zalloc(sizeof(WorkThread*) * num_work_thds)); + gpr_log(GPR_DEBUG, "Fork threads for adding..."); + for (int i = 0; i < num_work_thds; ++i) { + work_thds[i] = new WorkThread(pool, functor, kThreadLargeIter); + 
work_thds[i]->Start(); + } + // Wait for all threads finish + gpr_log(GPR_DEBUG, "Waiting for all work threads finish..."); + for (int i = 0; i < num_work_thds; ++i) { + work_thds[i]->Join(); + delete work_thds[i]; + } + gpr_free(work_thds); + gpr_log(GPR_DEBUG, "Done."); + gpr_log(GPR_DEBUG, "Waiting for all closures finish..."); + // Destructor of thread pool will wait for all closures to finish + delete pool; + GPR_ASSERT(functor->count() == kThreadLargeIter * num_work_thds); + delete functor; + gpr_log(GPR_DEBUG, "Done."); +} + +// Checks the current count with a given number. +class SimpleFunctorCheckForAdd : public grpc_completion_queue_functor { + public: + SimpleFunctorCheckForAdd(int ok, int* count) : count_(count) { + functor_run = &SimpleFunctorCheckForAdd::Run; + inlineable = true; + internal_success = ok; + } + ~SimpleFunctorCheckForAdd() {} + static void Run(struct grpc_completion_queue_functor* cb, int /*ok*/) { + auto* callback = static_cast(cb); + (*callback->count_)++; + GPR_ASSERT(*callback->count_ == callback->internal_success); + } + + private: + int* count_; +}; + +static void test_one_thread_FIFO(void) { + gpr_log(GPR_INFO, "test_one_thread_FIFO"); + int counter = 0; + grpc_core::ThreadPool* pool = + new grpc_core::ThreadPool(1, "test_one_thread_FIFO"); + SimpleFunctorCheckForAdd** check_functors = + static_cast( + gpr_zalloc(sizeof(SimpleFunctorCheckForAdd*) * kThreadSmallIter)); + for (int i = 0; i < kThreadSmallIter; ++i) { + check_functors[i] = new SimpleFunctorCheckForAdd(i + 1, &counter); + pool->Add(check_functors[i]); + } + // Destructor of pool will wait until all closures finished. + delete pool; + for (int i = 0; i < kThreadSmallIter; ++i) { + delete check_functors[i]; + } + gpr_free(check_functors); + gpr_log(GPR_DEBUG, "Done."); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_size_zero(); + test_constructor_option(); + test_add(); + test_multi_add(); + test_one_thread_FIFO(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/iomgr/time_averaged_stats_test.cc b/test/core/iomgr/time_averaged_stats_test.cc new file mode 100644 index 00000000..68513533 --- /dev/null +++ b/test/core/iomgr/time_averaged_stats_test.cc @@ -0,0 +1,194 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/time_averaged_stats.h" + +#include + +#include + +#include "test/core/util/test_config.h" + +#define EXPECT_EQ(a, b) GPR_ASSERT((a) == (b)) +#define EXPECT_DOUBLE_EQ(a, b) GPR_ASSERT(fabs((a) - (b)) < 1e-9) + +static void no_regress_no_persist_test_1(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 0, 0.0); + EXPECT_DOUBLE_EQ(1000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(0, tas.aggregate_total_weight); + + /* Should have no effect */ + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(1000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(0, tas.aggregate_total_weight); + + /* Should replace old average */ + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(2000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(1, tas.aggregate_total_weight); +} + +static void no_regress_no_persist_test_2(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 0, 0.0); + EXPECT_DOUBLE_EQ(1000, tas.aggregate_weighted_avg); + /* Should replace init value */ + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(2000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(1, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 3000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(3000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(1, tas.aggregate_total_weight); +} + +static void no_regress_no_persist_test_3(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 0, 0.0); + EXPECT_DOUBLE_EQ(1000, tas.aggregate_weighted_avg); + /* Should replace init value */ + grpc_time_averaged_stats_add_sample(&tas, 2500); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(2500, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(1, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 3500); + grpc_time_averaged_stats_add_sample(&tas, 4500); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(4000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2, tas.aggregate_total_weight); +} + +static void some_regress_no_persist_test(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 0.5, 0.0); + EXPECT_DOUBLE_EQ(1000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(0, tas.aggregate_total_weight); + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + /* (2 * 2000 + 0.5 * 1000) / 2.5 */ + EXPECT_DOUBLE_EQ(1800, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2.5, tas.aggregate_total_weight); +} + +static void some_decay_test(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 1, 0.0); + EXPECT_EQ(1000, tas.aggregate_weighted_avg); + /* Should avg with init value */ + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(1500, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(1500, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(1500, 
tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2, tas.aggregate_total_weight); +} + +static void no_regress_full_persist_test(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 0, 1.0); + EXPECT_DOUBLE_EQ(1000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(0, tas.aggregate_total_weight); + + /* Should replace init value */ + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_EQ(2000, tas.aggregate_weighted_avg); + EXPECT_EQ(1, tas.aggregate_total_weight); + + /* Will result in average of the 3 samples. */ + grpc_time_averaged_stats_add_sample(&tas, 2300); + grpc_time_averaged_stats_add_sample(&tas, 2300); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(2200, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(3, tas.aggregate_total_weight); +} + +static void no_regress_some_persist_test(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 0, 0.5); + /* Should replace init value */ + grpc_time_averaged_stats_add_sample(&tas, 2000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(2000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(1, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 2500); + grpc_time_averaged_stats_add_sample(&tas, 4000); + grpc_time_averaged_stats_update_average(&tas); + EXPECT_DOUBLE_EQ(3000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2.5, tas.aggregate_total_weight); +} + +static void some_regress_some_persist_test(void) { + grpc_time_averaged_stats tas; + grpc_time_averaged_stats_init(&tas, 1000, 0.4, 0.6); + /* Sample weight = 0 */ + EXPECT_EQ(1000, tas.aggregate_weighted_avg); + EXPECT_EQ(0, tas.aggregate_total_weight); + + grpc_time_averaged_stats_update_average(&tas); + /* (0.6 * 0 * 1000 + 0.4 * 1000 / 0.4) */ + EXPECT_DOUBLE_EQ(1000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(0.4, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 2640); + grpc_time_averaged_stats_update_average(&tas); + /* (1 * 2640 + 0.6 * 0.4 * 1000 + 0.4 * 1000 / (1 + 0.6 * 0.4 + 0.4) */ + EXPECT_DOUBLE_EQ(2000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(1.64, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 2876.8); + grpc_time_averaged_stats_update_average(&tas); + /* (1 * 2876.8 + 0.6 * 1.64 * 2000 + 0.4 * 1000 / (1 + 0.6 * 1.64 + 0.4) */ + EXPECT_DOUBLE_EQ(2200, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2.384, tas.aggregate_total_weight); + + grpc_time_averaged_stats_add_sample(&tas, 4944.32); + grpc_time_averaged_stats_update_average(&tas); + /* (1 * 4944.32 + 0.6 * 2.384 * 2200 + 0.4 * 1000) / + (1 + 0.6 * 2.384 + 0.4) */ + EXPECT_DOUBLE_EQ(3000, tas.aggregate_weighted_avg); + EXPECT_DOUBLE_EQ(2.8304, tas.aggregate_total_weight); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + no_regress_no_persist_test_1(); + no_regress_no_persist_test_2(); + no_regress_no_persist_test_3(); + some_regress_no_persist_test(); + some_decay_test(); + no_regress_full_persist_test(); + no_regress_some_persist_test(); + some_regress_some_persist_test(); + return 0; +} diff --git a/test/core/iomgr/timer_heap_test.cc b/test/core/iomgr/timer_heap_test.cc new file mode 100644 index 00000000..82da474d --- /dev/null +++ b/test/core/iomgr/timer_heap_test.cc @@ -0,0 +1,297 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/timer_heap.h" + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/port.h" +#include "test/core/util/test_config.h" + +static gpr_atm random_deadline(void) { return rand(); } + +static grpc_timer* create_test_elements(size_t num_elements) { + grpc_timer* elems = + static_cast(gpr_malloc(num_elements * sizeof(grpc_timer))); + size_t i; + for (i = 0; i < num_elements; i++) { + elems[i].deadline = random_deadline(); + } + return elems; +} + +static int contains(grpc_timer_heap* pq, grpc_timer* el) { + size_t i; + for (i = 0; i < pq->timer_count; i++) { + if (pq->timers[i] == el) return 1; + } + return 0; +} + +static void check_valid(grpc_timer_heap* pq) { + size_t i; + for (i = 0; i < pq->timer_count; ++i) { + size_t left_child = 1u + 2u * i; + size_t right_child = left_child + 1u; + if (left_child < pq->timer_count) { + GPR_ASSERT(pq->timers[i]->deadline <= pq->timers[left_child]->deadline); + } + if (right_child < pq->timer_count) { + GPR_ASSERT(pq->timers[i]->deadline <= pq->timers[right_child]->deadline); + } + } +} + +/******************************************************************************* + * test1 + */ + +static void test1(void) { + grpc_timer_heap pq; + const size_t num_test_elements = 200; + const size_t num_test_operations = 10000; + size_t i; + grpc_timer* test_elements = create_test_elements(num_test_elements); + uint8_t* inpq = static_cast(gpr_malloc(num_test_elements)); + + gpr_log(GPR_INFO, "test1"); + + grpc_timer_heap_init(&pq); + memset(inpq, 0, num_test_elements); + GPR_ASSERT(grpc_timer_heap_is_empty(&pq)); + check_valid(&pq); + for (i = 0; i < num_test_elements; ++i) { + GPR_ASSERT(!contains(&pq, &test_elements[i])); + grpc_timer_heap_add(&pq, &test_elements[i]); + check_valid(&pq); + GPR_ASSERT(contains(&pq, &test_elements[i])); + inpq[i] = 1; + } + for (i = 0; i < num_test_elements; ++i) { + /* Test that check still succeeds even for element that wasn't just + inserted. 
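+       contains() scans the whole pq->timers array, so this also verifies that
+       no earlier element was lost while later ones were being added.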
*/ + GPR_ASSERT(contains(&pq, &test_elements[i])); + } + + GPR_ASSERT(pq.timer_count == num_test_elements); + + check_valid(&pq); + + for (i = 0; i < num_test_operations; ++i) { + size_t elem_num = static_cast(rand()) % num_test_elements; + grpc_timer* el = &test_elements[elem_num]; + if (!inpq[elem_num]) { /* not in pq */ + GPR_ASSERT(!contains(&pq, el)); + el->deadline = random_deadline(); + grpc_timer_heap_add(&pq, el); + GPR_ASSERT(contains(&pq, el)); + inpq[elem_num] = 1; + check_valid(&pq); + } else { + GPR_ASSERT(contains(&pq, el)); + grpc_timer_heap_remove(&pq, el); + GPR_ASSERT(!contains(&pq, el)); + inpq[elem_num] = 0; + check_valid(&pq); + } + } + + grpc_timer_heap_destroy(&pq); + gpr_free(test_elements); + gpr_free(inpq); +} + +/******************************************************************************* + * test2 + */ + +typedef struct { + grpc_timer elem; + bool inserted; +} elem_struct; + +static elem_struct* search_elems(elem_struct* elems, size_t count, + bool inserted) { + size_t* search_order = + static_cast(gpr_malloc(count * sizeof(*search_order))); + for (size_t i = 0; i < count; i++) { + search_order[i] = i; + } + for (size_t i = 0; i < count * 2; i++) { + size_t a = static_cast(rand()) % count; + size_t b = static_cast(rand()) % count; + std::swap(search_order[a], search_order[b]); + } + elem_struct* out = nullptr; + for (size_t i = 0; out == nullptr && i < count; i++) { + if (elems[search_order[i]].inserted == inserted) { + out = &elems[search_order[i]]; + } + } + gpr_free(search_order); + return out; +} + +static void test2(void) { + gpr_log(GPR_INFO, "test2"); + + grpc_timer_heap pq; + + static const size_t elems_size = 1000; + elem_struct* elems = + static_cast(gpr_malloc(elems_size * sizeof(elem_struct))); + size_t num_inserted = 0; + + grpc_timer_heap_init(&pq); + memset(elems, 0, elems_size * sizeof(elems[0])); + + for (size_t round = 0; round < 10000; round++) { + int r = rand() % 1000; + if (r <= 550) { + /* 55% of the time we try to add something */ + elem_struct* el = search_elems(elems, elems_size, false); + if (el != nullptr) { + el->elem.deadline = random_deadline(); + grpc_timer_heap_add(&pq, &el->elem); + el->inserted = true; + num_inserted++; + check_valid(&pq); + } + } else if (r <= 650) { + /* 10% of the time we try to remove something */ + elem_struct* el = search_elems(elems, elems_size, true); + if (el != nullptr) { + grpc_timer_heap_remove(&pq, &el->elem); + el->inserted = false; + num_inserted--; + check_valid(&pq); + } + } else { + /* the remaining times we pop */ + if (num_inserted > 0) { + grpc_timer* top = grpc_timer_heap_top(&pq); + grpc_timer_heap_pop(&pq); + for (size_t i = 0; i < elems_size; i++) { + if (top == &elems[i].elem) { + GPR_ASSERT(elems[i].inserted); + elems[i].inserted = false; + } + } + num_inserted--; + check_valid(&pq); + } + } + + if (num_inserted) { + grpc_millis* min_deadline = nullptr; + for (size_t i = 0; i < elems_size; i++) { + if (elems[i].inserted) { + if (min_deadline == nullptr) { + min_deadline = &elems[i].elem.deadline; + } else { + if (elems[i].elem.deadline < *min_deadline) { + min_deadline = &elems[i].elem.deadline; + } + } + } + } + GPR_ASSERT(grpc_timer_heap_top(&pq)->deadline == *min_deadline); + } + } + + grpc_timer_heap_destroy(&pq); + gpr_free(elems); +} + +static void shrink_test(void) { + gpr_log(GPR_INFO, "shrink_test"); + + grpc_timer_heap pq; + size_t i; + size_t expected_size; + + /* A large random number to allow for multiple shrinkages, at least 512. 
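+     Starting well above the 16-element floor lets the removal loop below see
+     the capacity stay within the asserted [2 * size, 4 * size] window more
+     than once before bottoming out.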
*/ + const size_t num_elements = static_cast(rand()) % 2000 + 512; + + grpc_timer_heap_init(&pq); + + /* Create a priority queue with many elements. Make sure the Size() is + correct. */ + for (i = 0; i < num_elements; ++i) { + GPR_ASSERT(i == pq.timer_count); + grpc_timer_heap_add(&pq, create_test_elements(1)); + } + GPR_ASSERT(num_elements == pq.timer_count); + + /* Remove elements until the Size is 1/4 the original size. */ + while (pq.timer_count > num_elements / 4) { + grpc_timer* const te = pq.timers[pq.timer_count - 1]; + grpc_timer_heap_remove(&pq, te); + gpr_free(te); + } + GPR_ASSERT(num_elements / 4 == pq.timer_count); + + /* Expect that Capacity is in the right range: + Size * 2 <= Capacity <= Size * 4 */ + GPR_ASSERT(pq.timer_count * 2 <= pq.timer_capacity); + GPR_ASSERT(pq.timer_capacity <= pq.timer_count * 4); + check_valid(&pq); + + /* Remove the rest of the elements. Check that the Capacity is not more than + 4 times the Size and not less than 2 times, but never goes below 16. */ + expected_size = pq.timer_count; + while (pq.timer_count > 0) { + const size_t which = static_cast(rand()) % pq.timer_count; + grpc_timer* te = pq.timers[which]; + grpc_timer_heap_remove(&pq, te); + gpr_free(te); + expected_size--; + GPR_ASSERT(expected_size == pq.timer_count); + GPR_ASSERT(pq.timer_count * 2 <= pq.timer_capacity); + if (pq.timer_count >= 8) { + GPR_ASSERT(pq.timer_capacity <= pq.timer_count * 4); + } else { + GPR_ASSERT(16 <= pq.timer_capacity); + } + check_valid(&pq); + } + + GPR_ASSERT(0 == pq.timer_count); + GPR_ASSERT(pq.timer_capacity >= 16 && pq.timer_capacity < 32); + + grpc_timer_heap_destroy(&pq); +} + +int main(int argc, char** argv) { + int i; + + grpc::testing::TestEnvironment env(argc, argv); + + for (i = 0; i < 5; i++) { + test1(); + test2(); + shrink_test(); + } + + return 0; +} diff --git a/test/core/iomgr/timer_list_test.cc b/test/core/iomgr/timer_list_test.cc new file mode 100644 index 00000000..b7640ee6 --- /dev/null +++ b/test/core/iomgr/timer_list_test.cc @@ -0,0 +1,265 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/port.h" + +// This test only works with the generic timer implementation +#ifndef GRPC_CUSTOM_SOCKET + +#include + +#include +#include + +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/iomgr/iomgr_internal.h" +#include "src/core/lib/iomgr/timer.h" +#include "test/core/util/test_config.h" +#include "test/core/util/tracer_util.h" + +#define MAX_CB 30 + +extern grpc_core::TraceFlag grpc_timer_trace; +extern grpc_core::TraceFlag grpc_timer_check_trace; + +static int cb_called[MAX_CB][2]; +static const int64_t kMillisIn25Days = 2160000000; +static const int64_t kHoursIn25Days = 600; + +static void cb(void* arg, grpc_error_handle error) { + cb_called[reinterpret_cast(arg)][error == GRPC_ERROR_NONE]++; +} + +static void add_test(void) { + int i; + grpc_timer timers[20]; + grpc_core::ExecCtx exec_ctx; + + gpr_log(GPR_INFO, "add_test"); + + grpc_timer_list_init(); + grpc_core::testing::grpc_tracer_enable_flag(&grpc_timer_trace); + grpc_core::testing::grpc_tracer_enable_flag(&grpc_timer_check_trace); + memset(cb_called, 0, sizeof(cb_called)); + + grpc_millis start = grpc_core::ExecCtx::Get()->Now(); + + /* 10 ms timers. will expire in the current epoch */ + for (i = 0; i < 10; i++) { + grpc_timer_init( + &timers[i], start + 10, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)i, grpc_schedule_on_exec_ctx)); + } + + /* 1010 ms timers. will expire in the next epoch */ + for (i = 10; i < 20; i++) { + grpc_timer_init( + &timers[i], start + 1010, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)i, grpc_schedule_on_exec_ctx)); + } + + /* collect timers. Only the first batch should be ready. */ + grpc_core::ExecCtx::Get()->TestOnlySetNow(start + 500); + GPR_ASSERT(grpc_timer_check(nullptr) == GRPC_TIMERS_FIRED); + grpc_core::ExecCtx::Get()->Flush(); + for (i = 0; i < 20; i++) { + GPR_ASSERT(cb_called[i][1] == (i < 10)); + GPR_ASSERT(cb_called[i][0] == 0); + } + + grpc_core::ExecCtx::Get()->TestOnlySetNow(start + 600); + GPR_ASSERT(grpc_timer_check(nullptr) == GRPC_TIMERS_CHECKED_AND_EMPTY); + grpc_core::ExecCtx::Get()->Flush(); + for (i = 0; i < 30; i++) { + GPR_ASSERT(cb_called[i][1] == (i < 10)); + GPR_ASSERT(cb_called[i][0] == 0); + } + + /* collect the rest of the timers */ + grpc_core::ExecCtx::Get()->TestOnlySetNow(start + 1500); + GPR_ASSERT(grpc_timer_check(nullptr) == GRPC_TIMERS_FIRED); + grpc_core::ExecCtx::Get()->Flush(); + for (i = 0; i < 30; i++) { + GPR_ASSERT(cb_called[i][1] == (i < 20)); + GPR_ASSERT(cb_called[i][0] == 0); + } + + grpc_core::ExecCtx::Get()->TestOnlySetNow(start + 1600); + GPR_ASSERT(grpc_timer_check(nullptr) == GRPC_TIMERS_CHECKED_AND_EMPTY); + for (i = 0; i < 30; i++) { + GPR_ASSERT(cb_called[i][1] == (i < 20)); + GPR_ASSERT(cb_called[i][0] == 0); + } + + grpc_timer_list_shutdown(); +} + +/* Cleaning up a list with pending timers. 
*/ +void destruction_test(void) { + grpc_timer timers[5]; + grpc_core::ExecCtx exec_ctx; + + gpr_log(GPR_INFO, "destruction_test"); + + grpc_core::ExecCtx::Get()->TestOnlySetNow(0); + grpc_timer_list_init(); + grpc_core::testing::grpc_tracer_enable_flag(&grpc_timer_trace); + grpc_core::testing::grpc_tracer_enable_flag(&grpc_timer_check_trace); + memset(cb_called, 0, sizeof(cb_called)); + + grpc_timer_init( + &timers[0], 100, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)0, grpc_schedule_on_exec_ctx)); + grpc_timer_init( + &timers[1], 3, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)1, grpc_schedule_on_exec_ctx)); + grpc_timer_init( + &timers[2], 100, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)2, grpc_schedule_on_exec_ctx)); + grpc_timer_init( + &timers[3], 3, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)3, grpc_schedule_on_exec_ctx)); + grpc_timer_init( + &timers[4], 1, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)4, grpc_schedule_on_exec_ctx)); + grpc_core::ExecCtx::Get()->TestOnlySetNow(2); + GPR_ASSERT(grpc_timer_check(nullptr) == GRPC_TIMERS_FIRED); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(1 == cb_called[4][1]); + grpc_timer_cancel(&timers[0]); + grpc_timer_cancel(&timers[3]); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(1 == cb_called[0][0]); + GPR_ASSERT(1 == cb_called[3][0]); + + grpc_timer_list_shutdown(); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(1 == cb_called[1][0]); + GPR_ASSERT(1 == cb_called[2][0]); +} + +/* Cleans up a list with pending timers that simulate long-running-services. + This test does the following: + 1) Simulates grpc server start time to 25 days in the past (completed in + `main` using TestOnlyGlobalInit()) + 2) Creates 4 timers - one with a deadline 25 days in the future, one just + 3 milliseconds in future, one way out in the future, and one using the + grpc_timespec_to_millis_round_up function to compute a deadline of 25 + days in the future + 3) Simulates 4 milliseconds of elapsed time by changing `now` (cached at + step 1) to `now+4` + 4) Shuts down the timer list + https://github.com/grpc/grpc/issues/15904 */ +void long_running_service_cleanup_test(void) { + grpc_timer timers[4]; + grpc_core::ExecCtx exec_ctx; + + gpr_log(GPR_INFO, "long_running_service_cleanup_test"); + + grpc_millis now = grpc_core::ExecCtx::Get()->Now(); + GPR_ASSERT(now >= kMillisIn25Days); + grpc_timer_list_init(); + grpc_core::testing::grpc_tracer_enable_flag(&grpc_timer_trace); + grpc_core::testing::grpc_tracer_enable_flag(&grpc_timer_check_trace); + memset(cb_called, 0, sizeof(cb_called)); + + grpc_timer_init( + &timers[0], now + kMillisIn25Days, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)0, grpc_schedule_on_exec_ctx)); + grpc_timer_init( + &timers[1], now + 3, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)1, grpc_schedule_on_exec_ctx)); + grpc_timer_init( + &timers[2], GRPC_MILLIS_INF_FUTURE - 1, + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)2, grpc_schedule_on_exec_ctx)); + + gpr_timespec deadline_spec = grpc_millis_to_timespec( + now + kMillisIn25Days, gpr_clock_type::GPR_CLOCK_MONOTONIC); + + /* grpc_timespec_to_millis_round_up is how users usually compute a millisecond + input value into grpc_timer_init, so we mimic that behavior here */ + grpc_timer_init( + &timers[3], grpc_timespec_to_millis_round_up(deadline_spec), + GRPC_CLOSURE_CREATE(cb, (void*)(intptr_t)3, grpc_schedule_on_exec_ctx)); + + grpc_core::ExecCtx::Get()->TestOnlySetNow(now + 4); + GPR_ASSERT(grpc_timer_check(nullptr) == GRPC_TIMERS_FIRED); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(0 == 
cb_called[0][0]); // Timer 0 not called + GPR_ASSERT(0 == cb_called[0][1]); + GPR_ASSERT(0 == cb_called[1][0]); + GPR_ASSERT(1 == cb_called[1][1]); // Timer 1 fired + GPR_ASSERT(0 == cb_called[2][0]); // Timer 2 not called + GPR_ASSERT(0 == cb_called[2][1]); + GPR_ASSERT(0 == cb_called[3][0]); // Timer 3 not called + GPR_ASSERT(0 == cb_called[3][1]); + + grpc_timer_list_shutdown(); + grpc_core::ExecCtx::Get()->Flush(); + /* Timers 0, 2, and 3 were fired with an error during cleanup */ + GPR_ASSERT(1 == cb_called[0][0]); + GPR_ASSERT(0 == cb_called[1][0]); + GPR_ASSERT(1 == cb_called[2][0]); + GPR_ASSERT(1 == cb_called[3][0]); +} + +int main(int argc, char** argv) { + /* Tests with default g_start_time */ + { + grpc::testing::TestEnvironment env(argc, argv); + grpc_core::ExecCtx::GlobalInit(); + grpc_core::ExecCtx exec_ctx; + grpc_set_default_iomgr_platform(); + grpc_iomgr_platform_init(); + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); + add_test(); + destruction_test(); + grpc_iomgr_platform_shutdown(); + } + grpc_core::ExecCtx::GlobalShutdown(); + + /* Begin long running service tests */ + { + grpc::testing::TestEnvironment env(argc, argv); + /* Set g_start_time back 25 days. */ + /* We set g_start_time here in case there are any initialization + dependencies that use g_start_time. */ + gpr_timespec new_start = + gpr_time_sub(gpr_now(gpr_clock_type::GPR_CLOCK_MONOTONIC), + gpr_time_from_hours(kHoursIn25Days, + gpr_clock_type::GPR_CLOCK_MONOTONIC)); + grpc_core::ExecCtx::TestOnlyGlobalInit(new_start); + grpc_core::ExecCtx exec_ctx; + grpc_set_default_iomgr_platform(); + grpc_iomgr_platform_init(); + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); + long_running_service_cleanup_test(); + add_test(); + destruction_test(); + grpc_iomgr_platform_shutdown(); + } + grpc_core::ExecCtx::GlobalShutdown(); + + return 0; +} + +#else /* GRPC_CUSTOM_SOCKET */ + +int main(int argc, char** argv) { return 1; } + +#endif /* GRPC_CUSTOM_SOCKET */ diff --git a/test/core/iomgr/work_serializer_test.cc b/test/core/iomgr/work_serializer_test.cc new file mode 100644 index 00000000..09258747 --- /dev/null +++ b/test/core/iomgr/work_serializer_test.cc @@ -0,0 +1,115 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/work_serializer.h" + +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "test/core/util/test_config.h" + +namespace { +TEST(WorkSerializerTest, NoOp) { grpc_core::WorkSerializer lock; } + +TEST(WorkSerializerTest, ExecuteOne) { + grpc_core::WorkSerializer lock; + gpr_event done; + gpr_event_init(&done); + lock.Run([&done]() { gpr_event_set(&done, reinterpret_cast(1)); }, + DEBUG_LOCATION); + EXPECT_TRUE(gpr_event_wait(&done, grpc_timeout_seconds_to_deadline(5)) != + nullptr); +} + +class TestThread { + public: + explicit TestThread(grpc_core::WorkSerializer* lock) + : lock_(lock), thread_("grpc_execute_many", ExecuteManyLoop, this) { + gpr_event_init(&done_); + thread_.Start(); + } + + ~TestThread() { + EXPECT_NE(gpr_event_wait(&done_, gpr_inf_future(GPR_CLOCK_REALTIME)), + nullptr); + thread_.Join(); + } + + private: + static void ExecuteManyLoop(void* arg) { + TestThread* self = static_cast(arg); + size_t n = 1; + for (size_t i = 0; i < 10; i++) { + for (size_t j = 0; j < 10000; j++) { + struct ExecutionArgs { + size_t* counter; + size_t value; + }; + ExecutionArgs* c = new ExecutionArgs; + c->counter = &self->counter_; + c->value = n++; + self->lock_->Run( + [c]() { + EXPECT_TRUE(*c->counter == c->value - 1); + *c->counter = c->value; + delete c; + }, + DEBUG_LOCATION); + } + // sleep for a little bit, to test other threads picking up the load + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); + } + self->lock_->Run( + [self]() { gpr_event_set(&self->done_, reinterpret_cast(1)); }, + DEBUG_LOCATION); + } + + grpc_core::WorkSerializer* lock_ = nullptr; + grpc_core::Thread thread_; + size_t counter_ = 0; + gpr_event done_; +}; + +TEST(WorkSerializerTest, ExecuteMany) { + grpc_core::WorkSerializer lock; + { + std::vector> threads; + for (size_t i = 0; i < 100; ++i) { + threads.push_back(absl::make_unique(&lock)); + } + } +} +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int retval = RUN_ALL_TESTS(); + grpc_shutdown(); + return retval; +} diff --git a/test/core/json/fuzzer.cc b/test/core/json/fuzzer.cc new file mode 100644 index 00000000..2925611c --- /dev/null +++ b/test/core/json/fuzzer.cc @@ -0,0 +1,37 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include +#include + +#include "src/core/lib/json/json.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json::Parse( + absl::string_view(reinterpret_cast(data), size), &error); + GRPC_ERROR_UNREF(error); + return 0; +} diff --git a/test/core/json/json_test.cc b/test/core/json/json_test.cc new file mode 100644 index 00000000..6821648c --- /dev/null +++ b/test/core/json/json_test.cc @@ -0,0 +1,295 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/json/json.h" + +#include + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { + +void ValidateValue(const Json& actual, const Json& expected); + +void ValidateObject(const Json::Object& actual, const Json::Object& expected) { + ASSERT_EQ(actual.size(), expected.size()); + auto actual_it = actual.begin(); + for (const auto& p : expected) { + EXPECT_EQ(actual_it->first, p.first); + ValidateValue(actual_it->second, p.second); + ++actual_it; + } +} + +void ValidateArray(const Json::Array& actual, const Json::Array& expected) { + ASSERT_EQ(actual.size(), expected.size()); + for (size_t i = 0; i < expected.size(); ++i) { + ValidateValue(actual[i], expected[i]); + } +} + +void ValidateValue(const Json& actual, const Json& expected) { + ASSERT_EQ(actual.type(), expected.type()); + switch (expected.type()) { + case Json::Type::JSON_NULL: + case Json::Type::JSON_TRUE: + case Json::Type::JSON_FALSE: + break; + case Json::Type::STRING: + case Json::Type::NUMBER: + EXPECT_EQ(actual.string_value(), expected.string_value()); + break; + case Json::Type::OBJECT: + ValidateObject(actual.object_value(), expected.object_value()); + break; + case Json::Type::ARRAY: + ValidateArray(actual.array_value(), expected.array_value()); + break; + } +} + +void RunSuccessTest(const char* input, const Json& expected, + const char* expected_output) { + gpr_log(GPR_INFO, "parsing string \"%s\" - should succeed", input); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(input, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ValidateValue(json, expected); + std::string output = json.Dump(); + EXPECT_EQ(output, expected_output); +} + +TEST(Json, Whitespace) { + RunSuccessTest(" 0 ", 0, "0"); + RunSuccessTest(" 1 ", 1, "1"); + RunSuccessTest(" \" \" ", " ", "\" \""); + RunSuccessTest(" \"a\" ", "a", "\"a\""); + RunSuccessTest(" true ", true, "true"); +} + +TEST(Json, Utf16) { + RunSuccessTest("\"\\u0020\\\\\\u0010\\u000a\\u000D\"", " \\\u0010\n\r", + "\" \\\\\\u0010\\n\\r\""); +} + +TEST(Json, Utf8) { + RunSuccessTest("\"ßâñć௵⇒\"", "ßâñć௵⇒", + "\"\\u00df\\u00e2\\u00f1\\u0107\\u0bf5\\u21d2\""); + 
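+  // The same text parses from raw UTF-8 and from pre-escaped \uXXXX input;
+  // the writer always re-emits the escaped form.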
RunSuccessTest("\"\\u00df\\u00e2\\u00f1\\u0107\\u0bf5\\u21d2\"", "ßâñć௵⇒", + "\"\\u00df\\u00e2\\u00f1\\u0107\\u0bf5\\u21d2\""); + // Testing UTF-8 character "𝄞", U+11D1E. + RunSuccessTest("\"\xf0\x9d\x84\x9e\"", "\xf0\x9d\x84\x9e", + "\"\\ud834\\udd1e\""); + RunSuccessTest("\"\\ud834\\udd1e\"", "\xf0\x9d\x84\x9e", + "\"\\ud834\\udd1e\""); + RunSuccessTest("{\"\\ud834\\udd1e\":0}", + Json::Object{{"\xf0\x9d\x84\x9e", 0}}, + "{\"\\ud834\\udd1e\":0}"); +} + +TEST(Json, NestedEmptyContainers) { + RunSuccessTest(" [ [ ] , { } , [ ] ] ", + Json::Array{ + Json::Array(), + Json::Object(), + Json::Array(), + }, + "[[],{},[]]"); +} + +TEST(Json, EscapesAndControlCharactersInKeyStrings) { + RunSuccessTest(" { \"\\u007f\x7f\\n\\r\\\"\\f\\b\\\\a , b\": 1, \"\": 0 } ", + Json::Object{ + {"\u007f\u007f\n\r\"\f\b\\a , b", 1}, + {"", 0}, + }, + "{\"\":0,\"\\u007f\\u007f\\n\\r\\\"\\f\\b\\\\a , b\":1}"); +} + +TEST(Json, WriterCutsOffInvalidUtf8) { + RunSuccessTest("\"abc\xf0\x9d\x24\"", "abc\xf0\x9d\x24", "\"abc\""); + RunSuccessTest("\"\xff\"", "\xff", "\"\""); +} + +TEST(Json, ValidNumbers) { + RunSuccessTest("[0, 42 , 0.0123, 123.456]", + Json::Array{ + 0, + 42, + Json("0.0123", /*is_number=*/true), + Json("123.456", /*is_number=*/true), + }, + "[0,42,0.0123,123.456]"); + RunSuccessTest("[1e4,-53.235e-31, 0.3e+3]", + Json::Array{ + Json("1e4", /*is_number=*/true), + Json("-53.235e-31", /*is_number=*/true), + Json("0.3e+3", /*is_number=*/true), + }, + "[1e4,-53.235e-31,0.3e+3]"); +} + +TEST(Json, Keywords) { + RunSuccessTest("[true, false, null]", + Json::Array{ + Json(true), + Json(false), + Json(), + }, + "[true,false,null]"); +} + +void RunParseFailureTest(const char* input) { + gpr_log(GPR_INFO, "parsing string \"%s\" - should fail", input); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(input, &error); + gpr_log(GPR_INFO, "error: %s", grpc_error_std_string(error).c_str()); + EXPECT_NE(error, GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(error); +} + +TEST(Json, InvalidInput) { + RunParseFailureTest("\\"); + RunParseFailureTest("nu ll"); + RunParseFailureTest("{\"foo\": bar}"); + RunParseFailureTest("{\"foo\": bar\"x\"}"); + RunParseFailureTest("fals"); + RunParseFailureTest("0,0 "); + RunParseFailureTest("\"foo\",[]"); +} + +TEST(Json, UnterminatedString) { RunParseFailureTest("\"\\x"); } + +TEST(Json, InvalidUtf16) { + RunParseFailureTest("\"\\u123x"); + RunParseFailureTest("{\"\\u123x"); +} + +TEST(Json, ImbalancedSurrogatePairs) { + RunParseFailureTest("\"\\ud834f"); + RunParseFailureTest("{\"\\ud834f\":0}"); + RunParseFailureTest("\"\\ud834\\n"); + RunParseFailureTest("{\"\\ud834\\n\":0}"); + RunParseFailureTest("\"\\udd1ef"); + RunParseFailureTest("{\"\\udd1ef\":0}"); + RunParseFailureTest("\"\\ud834\\ud834\""); + RunParseFailureTest("{\"\\ud834\\ud834\"\":0}"); + RunParseFailureTest("\"\\ud834\\u1234\""); + RunParseFailureTest("{\"\\ud834\\u1234\"\":0}"); + RunParseFailureTest("\"\\ud834]\""); + RunParseFailureTest("{\"\\ud834]\"\":0}"); + RunParseFailureTest("\"\\ud834 \""); + RunParseFailureTest("{\"\\ud834 \"\":0}"); + RunParseFailureTest("\"\\ud834\\\\\""); + RunParseFailureTest("{\"\\ud834\\\\\"\":0}"); +} + +TEST(Json, EmbeddedInvalidWhitechars) { + RunParseFailureTest("\"\n\""); + RunParseFailureTest("\"\t\""); +} + +TEST(Json, EmptyString) { RunParseFailureTest(""); } + +TEST(Json, ExtraCharsAtEndOfParsing) { + RunParseFailureTest("{},"); + RunParseFailureTest("{}x"); +} + +TEST(Json, ImbalancedContainers) { + RunParseFailureTest("{}}"); + RunParseFailureTest("[]]"); + 
RunParseFailureTest("{{}"); + RunParseFailureTest("[[]"); + RunParseFailureTest("[}"); + RunParseFailureTest("{]"); +} + +TEST(Json, BadContainers) { + RunParseFailureTest("{x}"); + RunParseFailureTest("{x=0,y}"); +} + +TEST(Json, DuplicateObjectKeys) { RunParseFailureTest("{\"x\": 1, \"x\": 1}"); } + +TEST(Json, TrailingComma) { + RunParseFailureTest("{,}"); + RunParseFailureTest("[1,2,3,4,]"); + RunParseFailureTest("{\"a\": 1, }"); +} + +TEST(Json, KeySyntaxInArray) { RunParseFailureTest("[\"x\":0]"); } + +TEST(Json, InvalidNumbers) { + RunParseFailureTest("1."); + RunParseFailureTest("1e"); + RunParseFailureTest(".12"); + RunParseFailureTest("1.x"); + RunParseFailureTest("1.12x"); + RunParseFailureTest("1ex"); + RunParseFailureTest("1e12x"); + RunParseFailureTest(".12x"); + RunParseFailureTest("000"); +}; + +TEST(Json, Equality) { + // Null. + EXPECT_EQ(Json(), Json()); + // Numbers. + EXPECT_EQ(Json(1), Json(1)); + EXPECT_NE(Json(1), Json(2)); + EXPECT_EQ(Json(1), Json("1", /*is_number=*/true)); + EXPECT_EQ(Json("-5e5", /*is_number=*/true), Json("-5e5", /*is_number=*/true)); + // Booleans. + EXPECT_EQ(Json(true), Json(true)); + EXPECT_EQ(Json(false), Json(false)); + EXPECT_NE(Json(true), Json(false)); + // Strings. + EXPECT_EQ(Json("foo"), Json("foo")); + EXPECT_NE(Json("foo"), Json("bar")); + // Arrays. + EXPECT_EQ(Json(Json::Array{"foo"}), Json(Json::Array{"foo"})); + EXPECT_NE(Json(Json::Array{"foo"}), Json(Json::Array{"bar"})); + // Objects. + EXPECT_EQ(Json(Json::Object{{"foo", 1}}), Json(Json::Object{{"foo", 1}})); + EXPECT_NE(Json(Json::Object{{"foo", 1}}), Json(Json::Object{{"foo", 2}})); + EXPECT_NE(Json(Json::Object{{"foo", 1}}), Json(Json::Object{{"bar", 1}})); + // Differing types. + EXPECT_NE(Json(1), Json("foo")); + EXPECT_NE(Json(1), Json(true)); + EXPECT_NE(Json(1), Json(Json::Array{})); + EXPECT_NE(Json(1), Json(Json::Object{})); + EXPECT_NE(Json(1), Json()); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/nanopb/fuzzer_response.cc b/test/core/nanopb/fuzzer_response.cc new file mode 100644 index 00000000..c4e85eb3 --- /dev/null +++ b/test/core/nanopb/fuzzer_response.cc @@ -0,0 +1,47 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "upb/upb.hpp" + +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h" + +bool squelch = true; +bool leak_check = true; + +static void dont_log(gpr_log_func_args* /*args*/) {} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* /*data*/, + size_t /*size*/) { + grpc_init(); + if (squelch) gpr_set_log_function(dont_log); + // TODO(veblush): Convert this to upb. 
+ /* + grpc_slice slice = grpc_slice_from_copied_buffer((const char*)data, size); + upb::Arena arena; + grpc_core::grpc_grpclb_initial_response_parse(slice, arena.ptr()); + grpc_slice_unref(slice); + */ + grpc_shutdown(); + return 0; +} diff --git a/test/core/nanopb/fuzzer_serverlist.cc b/test/core/nanopb/fuzzer_serverlist.cc new file mode 100644 index 00000000..2c611ef9 --- /dev/null +++ b/test/core/nanopb/fuzzer_serverlist.cc @@ -0,0 +1,47 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h" + +bool squelch = true; +bool leak_check = true; + +static void dont_log(gpr_log_func_args* /*args*/) {} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* /*data*/, + size_t /*size*/) { + grpc_init(); + if (squelch) gpr_set_log_function(dont_log); + // TODO(veblush): Convert this to upb. + /* + grpc_slice slice = grpc_slice_from_copied_buffer((const char*)data, size); + grpc_core::grpc_grpclb_serverlist* serverlist; + if ((serverlist = grpc_core::grpc_grpclb_response_parse_serverlist(slice))) { + grpc_grpclb_destroy_serverlist(serverlist); + } + grpc_slice_unref(slice); + */ + grpc_shutdown(); + return 0; +} diff --git a/test/core/network_benchmarks/low_level_ping_pong.cc b/test/core/network_benchmarks/low_level_ping_pong.cc new file mode 100644 index 00000000..683c85e2 --- /dev/null +++ b/test/core/network_benchmarks/low_level_ping_pong.cc @@ -0,0 +1,696 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* + Basic I/O ping-pong benchmarks. + + The goal here is to establish lower bounds on how fast the stack could get by + measuring the cost of using various I/O strategies to do a basic + request-response loop. 
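+
+  Each run pairs one read strategy (blocking read, poll, epoll, or a spinning
+  variant) with one fd type (TCP connection, socketpair, or pipe) and reports
+  a latency histogram for fixed-size round trips.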
+ */ + +#include +#include +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "test/core/util/cmdline.h" +#include "test/core/util/histogram.h" + +typedef struct fd_pair { + int read_fd; + int write_fd; +} fd_pair; + +typedef struct thread_args { + fd_pair fds; + size_t msg_size; + int (*read_bytes)(struct thread_args* args, char* buf); + int (*write_bytes)(struct thread_args* args, char* buf); + int (*setup)(struct thread_args* args); + int epoll_fd; + const char* strategy_name; +} thread_args; + +/* + Read strategies + + There are a number of read strategies, each of which has a blocking and + non-blocking version. + */ + +/* Basic call to read() */ +static int read_bytes(int fd, char* buf, size_t read_size, int spin) { + size_t bytes_read = 0; + ssize_t err; + do { + err = read(fd, buf + bytes_read, read_size - bytes_read); + if (err < 0) { + if (errno == EINTR) { + continue; + } else { + if (errno == EAGAIN && spin == 1) { + continue; + } + gpr_log(GPR_ERROR, "Read failed: %s", strerror(errno)); + return -1; + } + } else { + bytes_read += static_cast(err); + } + } while (bytes_read < read_size); + return 0; +} + +static int blocking_read_bytes(thread_args* args, char* buf) { + return read_bytes(args->fds.read_fd, buf, args->msg_size, 0); +} + +static int spin_read_bytes(thread_args* args, char* buf) { + return read_bytes(args->fds.read_fd, buf, args->msg_size, 1); +} + +/* Call poll() to monitor a non-blocking fd */ +static int poll_read_bytes(int fd, char* buf, size_t read_size, int spin) { + struct pollfd pfd; + size_t bytes_read = 0; + int err; + ssize_t err2; + + pfd.fd = fd; + pfd.events = POLLIN; + do { + err = poll(&pfd, 1, spin ? 0 : -1); + if (err < 0) { + if (errno == EINTR) { + continue; + } else { + gpr_log(GPR_ERROR, "Poll failed: %s", strerror(errno)); + return -1; + } + } + if (err == 0 && spin) continue; + GPR_ASSERT(err == 1); + GPR_ASSERT(pfd.revents == POLLIN); + do { + err2 = read(fd, buf + bytes_read, read_size - bytes_read); + } while (err2 < 0 && errno == EINTR); + if (err2 < 0 && errno != EAGAIN) { + gpr_log(GPR_ERROR, "Read failed: %s", strerror(errno)); + return -1; + } + bytes_read += static_cast(err2); + } while (bytes_read < read_size); + return 0; +} + +static int poll_read_bytes_blocking(struct thread_args* args, char* buf) { + return poll_read_bytes(args->fds.read_fd, buf, args->msg_size, 0); +} + +static int poll_read_bytes_spin(struct thread_args* args, char* buf) { + return poll_read_bytes(args->fds.read_fd, buf, args->msg_size, 1); +} + +#ifdef __linux__ +/* Call epoll_wait() to monitor a non-blocking fd */ +static int epoll_read_bytes(struct thread_args* args, char* buf, int spin) { + struct epoll_event ev; + size_t bytes_read = 0; + int err; + ssize_t err2; + size_t read_size = args->msg_size; + + do { + err = epoll_wait(args->epoll_fd, &ev, 1, spin ? 
0 : -1); + if (err < 0) { + if (errno == EINTR) continue; + gpr_log(GPR_ERROR, "epoll_wait failed: %s", strerror(errno)); + return -1; + } + if (err == 0 && spin) continue; + GPR_ASSERT(err == 1); + GPR_ASSERT(ev.events & EPOLLIN); + GPR_ASSERT(ev.data.fd == args->fds.read_fd); + do { + do { + err2 = + read(args->fds.read_fd, buf + bytes_read, read_size - bytes_read); + } while (err2 < 0 && errno == EINTR); + if (errno == EAGAIN) break; + bytes_read += static_cast(err2); + /* TODO(klempner): This should really be doing an extra call after we are + done to ensure we see an EAGAIN */ + } while (bytes_read < read_size); + } while (bytes_read < read_size); + GPR_ASSERT(bytes_read == read_size); + return 0; +} + +static int epoll_read_bytes_blocking(struct thread_args* args, char* buf) { + return epoll_read_bytes(args, buf, 0); +} + +static int epoll_read_bytes_spin(struct thread_args* args, char* buf) { + return epoll_read_bytes(args, buf, 1); +} +#endif /* __linux__ */ + +/* Write out bytes. + At this point we only have one strategy, since in the common case these + writes go directly out to the kernel. + */ +static int blocking_write_bytes(struct thread_args* args, char* buf) { + size_t bytes_written = 0; + ssize_t err; + size_t write_size = args->msg_size; + do { + err = write(args->fds.write_fd, buf + bytes_written, + write_size - bytes_written); + if (err < 0) { + if (errno == EINTR) { + continue; + } else { + gpr_log(GPR_ERROR, "Read failed: %s", strerror(errno)); + return -1; + } + } else { + bytes_written += static_cast(err); + } + } while (bytes_written < write_size); + return 0; +} + +/* + Initialization code + + These are called at the beginning of the client and server thread, depending + on the scenario we're using. + */ +static int set_socket_nonblocking(thread_args* args) { + if (!GRPC_LOG_IF_ERROR("Unable to set read socket nonblocking", + grpc_set_socket_nonblocking(args->fds.read_fd, 1))) { + return -1; + } + if (!GRPC_LOG_IF_ERROR("Unable to set write socket nonblocking", + grpc_set_socket_nonblocking(args->fds.write_fd, 1))) { + return -1; + } + return 0; +} + +static int do_nothing(thread_args* /*args*/) { return 0; } + +#ifdef __linux__ +/* Special case for epoll, where we need to create the fd ahead of time. 
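+   The read fd is registered edge-triggered (EPOLLIN | EPOLLET), so readers
+   are expected to drain until EAGAIN before waiting in epoll_wait again.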
*/ +static int epoll_setup(thread_args* args) { + int epoll_fd; + struct epoll_event ev; + set_socket_nonblocking(args); + epoll_fd = epoll_create(1); + if (epoll_fd < 0) { + gpr_log(GPR_ERROR, "epoll_create: %s", strerror(errno)); + return -1; + } + + args->epoll_fd = epoll_fd; + + ev.events = EPOLLIN | EPOLLET; + ev.data.fd = args->fds.read_fd; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, args->fds.read_fd, &ev) < 0) { + gpr_log(GPR_ERROR, "epoll_ctl: %s", strerror(errno)); + } + return 0; +} +#endif + +static void server_thread(thread_args* args) { + char* buf = static_cast(gpr_malloc(args->msg_size)); + if (args->setup(args) < 0) { + gpr_log(GPR_ERROR, "Setup failed"); + } + for (;;) { + if (args->read_bytes(args, buf) < 0) { + gpr_log(GPR_ERROR, "Server read failed"); + gpr_free(buf); + return; + } + if (args->write_bytes(args, buf) < 0) { + gpr_log(GPR_ERROR, "Server write failed"); + gpr_free(buf); + return; + } + } +} + +static void server_thread_wrap(void* arg) { + thread_args* args = static_cast(arg); + server_thread(args); +} + +static void print_histogram(grpc_histogram* histogram) { + /* TODO(klempner): Print more detailed information, such as detailed histogram + buckets */ + gpr_log(GPR_INFO, "latency (50/95/99/99.9): %f/%f/%f/%f", + grpc_histogram_percentile(histogram, 50), + grpc_histogram_percentile(histogram, 95), + grpc_histogram_percentile(histogram, 99), + grpc_histogram_percentile(histogram, 99.9)); +} + +static double now(void) { + gpr_timespec tv = gpr_now(GPR_CLOCK_REALTIME); + return 1e9 * static_cast(tv.tv_sec) + static_cast(tv.tv_nsec); +} + +static void client_thread(thread_args* args) { + char* buf = static_cast(gpr_malloc(args->msg_size * sizeof(char))); + memset(buf, 0, args->msg_size * sizeof(char)); + grpc_histogram* histogram = grpc_histogram_create(0.01, 60e9); + double start_time; + double end_time; + double interval; + const int kNumIters = 100000; + int i; + + if (args->setup(args) < 0) { + gpr_log(GPR_ERROR, "Setup failed"); + } + for (i = 0; i < kNumIters; ++i) { + start_time = now(); + if (args->write_bytes(args, buf) < 0) { + gpr_log(GPR_ERROR, "Client write failed"); + goto error; + } + if (args->read_bytes(args, buf) < 0) { + gpr_log(GPR_ERROR, "Client read failed"); + goto error; + } + end_time = now(); + if (i > kNumIters / 2) { + interval = end_time - start_time; + grpc_histogram_add(histogram, interval); + } + } + print_histogram(histogram); +error: + gpr_free(buf); + grpc_histogram_destroy(histogram); +} + +/* This roughly matches tcp_server's create_listening_socket */ +static int create_listening_socket(struct sockaddr* port, socklen_t len) { + int fd = socket(port->sa_family, SOCK_STREAM, 0); + if (fd < 0) { + gpr_log(GPR_ERROR, "Unable to create socket: %s", strerror(errno)); + goto error; + } + + if (!GRPC_LOG_IF_ERROR("Failed to set listening socket cloexec", + grpc_set_socket_cloexec(fd, 1))) { + goto error; + } + if (!GRPC_LOG_IF_ERROR("Failed to set listening socket low latency", + grpc_set_socket_low_latency(fd, 1))) { + goto error; + } + if (!GRPC_LOG_IF_ERROR("Failed to set listening socket reuse addr", + grpc_set_socket_reuse_addr(fd, 1))) { + goto error; + } + + if (bind(fd, port, len) < 0) { + gpr_log(GPR_ERROR, "bind: %s", strerror(errno)); + goto error; + } + + if (listen(fd, 1) < 0) { + gpr_log(GPR_ERROR, "listen: %s", strerror(errno)); + goto error; + } + + if (getsockname(fd, port, &len) < 0) { + gpr_log(GPR_ERROR, "getsockname: %s", strerror(errno)); + goto error; + } + + return fd; + +error: + if (fd >= 0) { + close(fd); 
+ } + return -1; +} + +static int connect_client(struct sockaddr* addr, socklen_t len) { + int fd = socket(addr->sa_family, SOCK_STREAM, 0); + int err; + if (fd < 0) { + gpr_log(GPR_ERROR, "Unable to create socket: %s", strerror(errno)); + goto error; + } + + if (!GRPC_LOG_IF_ERROR("Failed to set connecting socket cloexec", + grpc_set_socket_cloexec(fd, 1))) { + goto error; + } + if (!GRPC_LOG_IF_ERROR("Failed to set connecting socket low latency", + grpc_set_socket_low_latency(fd, 1))) { + goto error; + } + + do { + err = connect(fd, addr, len); + } while (err < 0 && errno == EINTR); + + if (err < 0) { + gpr_log(GPR_ERROR, "connect error: %s", strerror(errno)); + goto error; + } + return fd; + +error: + if (fd >= 0) { + close(fd); + } + return -1; +} + +static int accept_server(int listen_fd) { + int fd = accept(listen_fd, nullptr, nullptr); + if (fd < 0) { + gpr_log(GPR_ERROR, "Accept failed: %s", strerror(errno)); + return -1; + } + return fd; +} + +static int create_sockets_tcp(fd_pair* client_fds, fd_pair* server_fds) { + int listen_fd = -1; + int client_fd = -1; + int server_fd = -1; + + struct sockaddr_in port; + struct sockaddr* sa_port = reinterpret_cast(&port); + + port.sin_family = AF_INET; + port.sin_port = 0; + port.sin_addr.s_addr = INADDR_ANY; + + listen_fd = create_listening_socket(sa_port, sizeof(port)); + if (listen_fd == -1) { + gpr_log(GPR_ERROR, "Listen failed"); + goto error; + } + + client_fd = connect_client(sa_port, sizeof(port)); + if (client_fd == -1) { + gpr_log(GPR_ERROR, "Connect failed"); + goto error; + } + + server_fd = accept_server(listen_fd); + if (server_fd == -1) { + gpr_log(GPR_ERROR, "Accept failed"); + goto error; + } + + client_fds->read_fd = client_fd; + client_fds->write_fd = client_fd; + server_fds->read_fd = server_fd; + server_fds->write_fd = server_fd; + close(listen_fd); + return 0; + +error: + if (listen_fd != -1) { + close(listen_fd); + } + if (client_fd != -1) { + close(client_fd); + } + if (server_fd != -1) { + close(server_fd); + } + return -1; +} + +static int create_sockets_socketpair(fd_pair* client_fds, fd_pair* server_fds) { + int fds[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) < 0) { + gpr_log(GPR_ERROR, "socketpair: %s", strerror(errno)); + return -1; + } + + client_fds->read_fd = fds[0]; + client_fds->write_fd = fds[0]; + server_fds->read_fd = fds[1]; + server_fds->write_fd = fds[1]; + return 0; +} + +static int create_sockets_pipe(fd_pair* client_fds, fd_pair* server_fds) { + int cfds[2]; + int sfds[2]; + if (pipe(cfds) < 0) { + gpr_log(GPR_ERROR, "pipe: %s", strerror(errno)); + return -1; + } + + if (pipe(sfds) < 0) { + gpr_log(GPR_ERROR, "pipe: %s", strerror(errno)); + return -1; + } + + client_fds->read_fd = cfds[0]; + client_fds->write_fd = cfds[1]; + server_fds->read_fd = sfds[0]; + server_fds->write_fd = sfds[1]; + return 0; +} + +static const char* read_strategy_usage = + "Strategy for doing reads, which is one of:\n" + " blocking: blocking read calls\n" + " same_thread_poll: poll() call on same thread \n" +#ifdef __linux__ + " same_thread_epoll: epoll_wait() on same thread \n" +#endif + " spin_read: spinning non-blocking read() calls \n" + " spin_poll: spinning 0 timeout poll() calls \n" +#ifdef __linux__ + " spin_epoll: spinning 0 timeout epoll_wait() calls \n" +#endif + ""; + +static const char* socket_type_usage = + "Type of socket used, one of:\n" + " tcp: fds are endpoints of a TCP connection\n" + " socketpair: fds come from socketpair()\n" + " pipe: fds come from pipe()\n"; + +void print_usage(char* argv0) 
{ + fprintf(stderr, "%s usage:\n\n", argv0); + fprintf(stderr, "%s read_strategy socket_type msg_size\n\n", argv0); + fprintf(stderr, "where read_strategy is one of:\n"); + fprintf(stderr, " blocking: blocking read calls\n"); + fprintf(stderr, " same_thread_poll: poll() call on same thread \n"); +#ifdef __linux__ + fprintf(stderr, " same_thread_epoll: epoll_wait() on same thread \n"); +#endif + fprintf(stderr, " spin_read: spinning non-blocking read() calls \n"); + fprintf(stderr, " spin_poll: spinning 0 timeout poll() calls \n"); +#ifdef __linux__ + fprintf(stderr, " spin_epoll: spinning 0 timeout epoll_wait() calls \n"); +#endif + fprintf(stderr, "and socket_type is one of:\n"); + fprintf(stderr, " tcp: fds are endpoints of a TCP connection\n"); + fprintf(stderr, " socketpair: fds come from socketpair()\n"); + fprintf(stderr, " pipe: fds come from pipe()\n"); + fflush(stderr); +} + +typedef struct test_strategy { + const char* name; + int (*read_strategy)(struct thread_args* args, char* buf); + int (*setup)(struct thread_args* args); +} test_strategy; + +static test_strategy test_strategies[] = { + {"blocking", blocking_read_bytes, do_nothing}, + {"same_thread_poll", poll_read_bytes_blocking, set_socket_nonblocking}, +#ifdef __linux__ + {"same_thread_epoll", epoll_read_bytes_blocking, epoll_setup}, + {"spin_epoll", epoll_read_bytes_spin, epoll_setup}, +#endif /* __linux__ */ + {"spin_read", spin_read_bytes, set_socket_nonblocking}, + {"spin_poll", poll_read_bytes_spin, set_socket_nonblocking}}; + +static const char* socket_types[] = {"tcp", "socketpair", "pipe"}; + +int create_socket(const char* socket_type, fd_pair* client_fds, + fd_pair* server_fds) { + if (strcmp(socket_type, "tcp") == 0) { + create_sockets_tcp(client_fds, server_fds); + } else if (strcmp(socket_type, "socketpair") == 0) { + create_sockets_socketpair(client_fds, server_fds); + } else if (strcmp(socket_type, "pipe") == 0) { + create_sockets_pipe(client_fds, server_fds); + } else { + fprintf(stderr, "Invalid socket type %s\n", socket_type); + fflush(stderr); + return -1; + } + return 0; +} + +static int run_benchmark(const char* socket_type, thread_args* client_args, + thread_args* server_args) { + int rv = 0; + + rv = create_socket(socket_type, &client_args->fds, &server_args->fds); + if (rv < 0) { + return rv; + } + + gpr_log(GPR_INFO, "Starting test %s %s %zu", client_args->strategy_name, + socket_type, client_args->msg_size); + + grpc_core::Thread server("server_thread", server_thread_wrap, server_args); + server.Start(); + client_thread(client_args); + server.Join(); + + return 0; +} + +static int run_all_benchmarks(size_t msg_size) { + int error = 0; + size_t i; + for (i = 0; i < GPR_ARRAY_SIZE(test_strategies); ++i) { + test_strategy* strategy = &test_strategies[i]; + size_t j; + for (j = 0; j < GPR_ARRAY_SIZE(socket_types); ++j) { + thread_args* client_args = + static_cast(gpr_malloc(sizeof(thread_args))); + thread_args* server_args = + static_cast(gpr_malloc(sizeof(thread_args))); + const char* socket_type = socket_types[j]; + + client_args->read_bytes = strategy->read_strategy; + client_args->write_bytes = blocking_write_bytes; + client_args->setup = strategy->setup; + client_args->msg_size = msg_size; + client_args->strategy_name = strategy->name; + server_args->read_bytes = strategy->read_strategy; + server_args->write_bytes = blocking_write_bytes; + server_args->setup = strategy->setup; + server_args->msg_size = msg_size; + server_args->strategy_name = strategy->name; + error = run_benchmark(socket_type, 
client_args, server_args); + if (error < 0) { + return error; + } + } + } + return error; +} + +int main(int argc, char** argv) { + thread_args* client_args = + static_cast(gpr_malloc(sizeof(thread_args))); + thread_args* server_args = + static_cast(gpr_malloc(sizeof(thread_args))); + int msg_size = -1; + const char* read_strategy = nullptr; + const char* socket_type = nullptr; + size_t i; + const test_strategy* strategy = nullptr; + int error = 0; + + gpr_cmdline* cmdline = + gpr_cmdline_create("low_level_ping_pong network benchmarking tool"); + + gpr_cmdline_add_int(cmdline, "msg_size", "Size of sent messages", &msg_size); + gpr_cmdline_add_string(cmdline, "read_strategy", read_strategy_usage, + &read_strategy); + gpr_cmdline_add_string(cmdline, "socket_type", socket_type_usage, + &socket_type); + + gpr_cmdline_parse(cmdline, argc, argv); + + if (msg_size == -1) { + msg_size = 50; + } + + if (read_strategy == nullptr) { + gpr_log(GPR_INFO, "No strategy specified, running all benchmarks"); + return run_all_benchmarks(static_cast(msg_size)); + } + + if (socket_type == nullptr) { + socket_type = "tcp"; + } + if (msg_size <= 0) { + fprintf(stderr, "msg_size must be > 0\n"); + fflush(stderr); + print_usage(argv[0]); + return -1; + } + + for (i = 0; i < GPR_ARRAY_SIZE(test_strategies); ++i) { + if (strcmp(test_strategies[i].name, read_strategy) == 0) { + strategy = &test_strategies[i]; + } + } + if (strategy == nullptr) { + fprintf(stderr, "Invalid read strategy %s\n", read_strategy); + fflush(stderr); + return -1; + } + + client_args->read_bytes = strategy->read_strategy; + client_args->write_bytes = blocking_write_bytes; + client_args->setup = strategy->setup; + client_args->msg_size = static_cast(msg_size); + client_args->strategy_name = read_strategy; + server_args->read_bytes = strategy->read_strategy; + server_args->write_bytes = blocking_write_bytes; + server_args->setup = strategy->setup; + server_args->msg_size = static_cast(msg_size); + server_args->strategy_name = read_strategy; + + error = run_benchmark(socket_type, client_args, server_args); + + gpr_cmdline_destroy(cmdline); + return error; +} diff --git a/test/core/promise/activity_test.cc b/test/core/promise/activity_test.cc new file mode 100644 index 00000000..017a789e --- /dev/null +++ b/test/core/promise/activity_test.cc @@ -0,0 +1,336 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/activity.h" + +#include +#include + +#include "src/core/lib/promise/join.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/seq.h" +#include "src/core/lib/promise/wait_set.h" +#include "test/core/promise/test_wakeup_schedulers.h" + +using testing::_; +using testing::Mock; +using testing::MockFunction; +using testing::SaveArg; +using testing::StrictMock; + +namespace grpc_core { + +// A simple Barrier type: stalls progress until it is 'cleared'. 
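+// Wait() parks the calling activity's waker in the WaitSet and returns Pending;
+// Clear() flips the flag and wakes every parked activity so its promise is re-polled.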
+class Barrier { + public: + struct Result {}; + + Promise Wait() { + return [this]() -> Poll { + absl::MutexLock lock(&mu_); + if (cleared_) { + return Result{}; + } else { + return wait_set_.AddPending(Activity::current()->MakeOwningWaker()); + } + }; + } + + void Clear() { + mu_.Lock(); + cleared_ = true; + auto wakeup = wait_set_.TakeWakeupSet(); + mu_.Unlock(); + wakeup.Wakeup(); + } + + private: + absl::Mutex mu_; + WaitSet wait_set_ ABSL_GUARDED_BY(mu_); + bool cleared_ ABSL_GUARDED_BY(mu_) = false; +}; + +// A simple Barrier type: stalls progress until it is 'cleared'. +// This variant supports only a single waiter. +class SingleBarrier { + public: + struct Result {}; + + Promise Wait() { + return [this]() -> Poll { + absl::MutexLock lock(&mu_); + if (cleared_) { + return Result{}; + } else { + waker_ = Activity::current()->MakeOwningWaker(); + return Pending(); + } + }; + } + + void Clear() { + mu_.Lock(); + cleared_ = true; + auto waker = std::move(waker_); + mu_.Unlock(); + waker.Wakeup(); + } + + private: + absl::Mutex mu_; + Waker waker_ ABSL_GUARDED_BY(mu_); + bool cleared_ ABSL_GUARDED_BY(mu_) = false; +}; + +TEST(ActivityTest, ImmediatelyCompleteWithSuccess) { + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [] { return [] { return absl::OkStatus(); }; }, NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TEST(ActivityTest, ImmediatelyCompleteWithFailure) { + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::CancelledError())); + MakeActivity( + [] { return [] { return absl::CancelledError(); }; }, NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TEST(ActivityTest, DropImmediately) { + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::CancelledError())); + MakeActivity( + [] { return []() -> Poll { return Pending(); }; }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TEST(ActivityTest, Cancel) { + StrictMock> on_done; + auto activity = MakeActivity( + [] { return []() -> Poll { return Pending(); }; }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + EXPECT_CALL(on_done, Call(absl::CancelledError())); + activity->Cancel(); + Mock::VerifyAndClearExpectations(&on_done); + activity.reset(); +} + +template +class BarrierTest : public testing::Test { + public: + using Type = B; +}; + +using BarrierTestTypes = testing::Types; +TYPED_TEST_SUITE(BarrierTest, BarrierTestTypes); + +TYPED_TEST(BarrierTest, Barrier) { + typename TestFixture::Type b; + StrictMock> on_done; + auto activity = MakeActivity( + [&b] { + return Seq(b.Wait(), [](typename TestFixture::Type::Result) { + return absl::OkStatus(); + }); + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + // Clearing the barrier should let the activity proceed to return a result. 
+ EXPECT_CALL(on_done, Call(absl::OkStatus())); + b.Clear(); +} + +TYPED_TEST(BarrierTest, BarrierPing) { + typename TestFixture::Type b1; + typename TestFixture::Type b2; + StrictMock> on_done1; + StrictMock> on_done2; + MockCallbackScheduler scheduler1; + MockCallbackScheduler scheduler2; + auto activity1 = MakeActivity( + [&b1, &b2] { + return Seq(b1.Wait(), [&b2](typename TestFixture::Type::Result) { + // Clear the barrier whilst executing an activity + b2.Clear(); + return absl::OkStatus(); + }); + }, + UseMockCallbackScheduler{&scheduler1}, + [&on_done1](absl::Status status) { on_done1.Call(std::move(status)); }); + auto activity2 = MakeActivity( + [&b2] { + return Seq(b2.Wait(), [](typename TestFixture::Type::Result) { + return absl::OkStatus(); + }); + }, + UseMockCallbackScheduler{&scheduler2}, + [&on_done2](absl::Status status) { on_done2.Call(std::move(status)); }); + // Since barrier triggers inside activity1 promise, activity2 wakeup will be + // scheduled from a callback. + std::function cb1; + std::function cb2; + EXPECT_CALL(scheduler1, Schedule(_)).WillOnce(SaveArg<0>(&cb1)); + b1.Clear(); + Mock::VerifyAndClearExpectations(&scheduler1); + EXPECT_CALL(on_done1, Call(absl::OkStatus())); + EXPECT_CALL(scheduler2, Schedule(_)).WillOnce(SaveArg<0>(&cb2)); + cb1(); + Mock::VerifyAndClearExpectations(&on_done1); + EXPECT_CALL(on_done2, Call(absl::OkStatus())); + cb2(); +} + +TYPED_TEST(BarrierTest, WakeSelf) { + typename TestFixture::Type b; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&b] { + return Seq(Join(b.Wait(), + [&b] { + b.Clear(); + return 1; + }), + [](std::tuple) { + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TYPED_TEST(BarrierTest, WakeAfterDestruction) { + typename TestFixture::Type b; + { + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::CancelledError())); + MakeActivity( + [&b] { + return Seq(b.Wait(), [](typename TestFixture::Type::Result) { + return absl::OkStatus(); + }); + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + } + b.Clear(); +} + +TEST(ActivityTest, ForceWakeup) { + StrictMock> on_done; + int run_count = 0; + auto activity = MakeActivity( + [&run_count]() -> Poll { + ++run_count; + switch (run_count) { + case 1: + return Pending{}; + case 2: + return absl::OkStatus(); + default: + abort(); + } + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + EXPECT_CALL(on_done, Call(absl::OkStatus())); + activity->ForceWakeup(); +} + +struct TestContext { + bool* done; +}; +template <> +struct ContextType {}; + +TEST(ActivityTest, WithContext) { + bool done = false; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [] { + *GetContext()->done = true; + return Immediate(absl::OkStatus()); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }, + TestContext{&done}); + EXPECT_TRUE(done); +} + +TEST(ActivityTest, CanCancelDuringExecution) { + ActivityPtr activity; + StrictMock> on_done; + int run_count = 0; + + activity = MakeActivity( + [&activity, &run_count]() -> Poll { + ++run_count; + switch (run_count) { + case 1: + return Pending{}; + case 2: + activity.reset(); + return Pending{}; + default: + abort(); + } + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + + 
EXPECT_CALL(on_done, Call(absl::CancelledError())); + activity->ForceWakeup(); +} + +TEST(ActivityTest, CanCancelDuringSuccessfulExecution) { + ActivityPtr activity; + StrictMock> on_done; + int run_count = 0; + + activity = MakeActivity( + [&activity, &run_count]() -> Poll { + ++run_count; + switch (run_count) { + case 1: + return Pending{}; + case 2: + activity.reset(); + return absl::OkStatus(); + default: + abort(); + } + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + + EXPECT_CALL(on_done, Call(absl::OkStatus())); + activity->ForceWakeup(); +} + +TEST(WakerTest, CanWakeupEmptyWaker) { + // Empty wakers should not do anything upon wakeup. + Waker().Wakeup(); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/benchmark/competition.cc b/test/core/promise/benchmark/competition.cc new file mode 100644 index 00000000..648c9070 --- /dev/null +++ b/test/core/promise/benchmark/competition.cc @@ -0,0 +1,492 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* This benchmark exists to ensure that immediately-firing alarms are fast */ + +#include + +#include "absl/synchronization/mutex.h" + +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/for_each.h" +#include "src/core/lib/promise/join.h" +#include "src/core/lib/promise/latch.h" +#include "src/core/lib/promise/pipe.h" +#include "src/core/lib/promise/seq.h" +#include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" +#include "test/core/promise/benchmark/filter_stack.h" +#include "test/core/promise/test_wakeup_schedulers.h" + +namespace filter_stack { + +Filter passthrough_filter = { + CallNextOp, NoCallData, NoCallData, NoChannelData, NoChannelData, 0, 0, +}; + +struct Interject { + Closure c; + Closure* next; + + static void Callback(void* p, absl::Status status) { + auto* i = static_cast(p); + i->next->Run(std::move(status)); + } + + static void Init(CallElem* elem) { + auto* i = static_cast(elem->call_data); + i->c.f = Callback; + i->c.p = i; + } + + static void Destroy(CallElem*) {} + + static void StartOp(CallElem* elem, Op* op) { + auto* i = static_cast(elem->call_data); + if (op->recv_initial_metadata) { + i->next = op->on_complete; + op->on_complete = &i->c; + } + CallNextOp(elem, op); + } +}; + +Filter interject_filter = { + Interject::StartOp, + Interject::Init, + Interject::Destroy, + NoChannelData, + NoChannelData, + sizeof(Interject), + 0, +}; + +struct InterjectPipe { + Closure c_init_metadata; + Closure* next_init_metadata; + Closure c_payload; + Closure* next_payload; + Closure c_trailing_metadata; + Closure* next_trailing_metadata; + + static void CallbackInitMetadata(void* p, absl::Status status) { + auto* i = static_cast(p); + i->next_init_metadata->Run(std::move(status)); + } + + static void CallbackPayload(void* 
p, absl::Status status) { + auto* i = static_cast(p); + i->next_payload->Run(std::move(status)); + } + + static void CallbackTrailingMetadata(void* p, absl::Status status) { + auto* i = static_cast(p); + i->next_trailing_metadata->Run(std::move(status)); + } + + static void Init(CallElem* elem) { + auto* i = static_cast(elem->call_data); + i->c_init_metadata.f = CallbackInitMetadata; + i->c_init_metadata.p = i; + i->c_payload.f = CallbackPayload; + i->c_payload.p = i; + i->c_trailing_metadata.f = CallbackTrailingMetadata; + i->c_trailing_metadata.p = i; + } + + static void Destroy(CallElem*) {} + + static void StartOp(CallElem* elem, Op* op) { + auto* i = static_cast(elem->call_data); + if (op->recv_trailing_metadata) { + i->next_trailing_metadata = op->on_complete; + op->on_complete = &i->c_trailing_metadata; + } + if (op->recv_message) { + i->next_payload = op->on_complete; + op->on_complete = &i->c_payload; + } + if (op->recv_initial_metadata) { + i->next_init_metadata = op->on_complete; + op->on_complete = &i->c_init_metadata; + } + CallNextOp(elem, op); + } +}; + +Filter interject_pipe = { + InterjectPipe::StartOp, + InterjectPipe::Init, + InterjectPipe::Destroy, + NoChannelData, + NoChannelData, + sizeof(InterjectPipe), + 0, +}; + +void EndOp(CallElem*, Op* op) { op->on_complete->Run(absl::OkStatus()); } + +Filter end_filter = {EndOp, NoCallData, NoCallData, NoChannelData, + NoChannelData, 0, 0}; + +static void unary(benchmark::State& state, + std::initializer_list filters) { + auto* channel = + MakeChannel(const_cast(&*filters.begin()), filters.size()); + for (auto _ : state) { + auto* call = MakeCall(channel); + Op op; + Op::Payload payload; + op.recv_initial_metadata = true; + op.recv_message = true; + op.recv_trailing_metadata = true; + op.payload = &payload; + Closure done = {call, +[](void* p, absl::Status status) { + if (!status.ok()) abort(); + FreeCall(static_cast(p)); + }}; + op.on_complete = &done; + RunOp(call, &op); + } + FreeChannel(channel); +} + +static void BM_FilterStack_Passthrough3_Unary(benchmark::State& state) { + unary(state, {&passthrough_filter, &passthrough_filter, &passthrough_filter, + &end_filter}); +} +BENCHMARK(BM_FilterStack_Passthrough3_Unary); + +static void BM_FilterStack_Passthrough10_Unary(benchmark::State& state) { + unary(state, {&passthrough_filter, &passthrough_filter, &passthrough_filter, + &passthrough_filter, &passthrough_filter, &passthrough_filter, + &passthrough_filter, &passthrough_filter, &passthrough_filter, + &passthrough_filter, &end_filter}); +} +BENCHMARK(BM_FilterStack_Passthrough10_Unary); + +static void BM_FilterStack_Interject3_Unary(benchmark::State& state) { + unary(state, + {&interject_filter, &interject_filter, &interject_filter, &end_filter}); +} +BENCHMARK(BM_FilterStack_Interject3_Unary); + +static void BM_FilterStack_Interject10_Unary(benchmark::State& state) { + unary(state, {&interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &end_filter}); +} +BENCHMARK(BM_FilterStack_Interject10_Unary); + +static void BM_FilterStack_Interject30_Unary(benchmark::State& state) { + unary(state, {&interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + 
&interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &interject_filter, &interject_filter, &interject_filter, + &end_filter}); +} +BENCHMARK(BM_FilterStack_Interject30_Unary); + +static void BM_FilterStack_Interject3Pipe_Unary(benchmark::State& state) { + unary(state, + {&interject_pipe, &interject_pipe, &interject_pipe, &end_filter}); +} +BENCHMARK(BM_FilterStack_Interject3Pipe_Unary); + +static void BM_FilterStack_Interject10Pipe_Unary(benchmark::State& state) { + unary(state, + {&interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &end_filter}); +} +BENCHMARK(BM_FilterStack_Interject10Pipe_Unary); + +static void BM_FilterStack_Interject30Pipe_Unary(benchmark::State& state) { + unary(state, + {&interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &interject_pipe, &interject_pipe, + &interject_pipe, &interject_pipe, &end_filter}); +} +BENCHMARK(BM_FilterStack_Interject30Pipe_Unary); + +} // namespace filter_stack + +namespace grpc_core { + +namespace activity_stack { +struct RPCIO { + Latch recv_initial_metadata; +}; + +struct RPCP { + Pipe pipe; +}; +} // namespace activity_stack + +template <> +struct ContextType {}; + +template <> +struct ContextType {}; + +namespace activity_stack { + +template +static void unary(benchmark::State& state, MakeCall make_call) { + printf("activity stack size: %d\n", static_cast(make_call()->Size())); + for (auto _ : state) { + make_call(); + } +} + +static void BM_ActivityStack_Passthrough3_Unary(benchmark::State& state) { + unary(state, []() { + return MakeActivity( + []() { + auto one = []() { return absl::OkStatus(); }; + return TrySeq(one, one, one); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }); + }); +} +BENCHMARK(BM_ActivityStack_Passthrough3_Unary); + +static void BM_ActivityStack_Passthrough10_Unary(benchmark::State& state) { + unary(state, []() { + return MakeActivity( + []() { + auto one = []() { return absl::OkStatus(); }; + return TrySeq(one, one, one, one, one, one, one, one, one, one); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }); + }); +} +BENCHMARK(BM_ActivityStack_Passthrough10_Unary); + +static void BM_ActivityStack_Interject3Latches_Unary(benchmark::State& state) { + unary(state, []() { + RPCIO rpcio; + return MakeActivity( + []() { + auto one = []() { + return GetContext()->recv_initial_metadata.Wait(); + }; + return Seq(Join(one(), one(), one(), + []() { + GetContext()->recv_initial_metadata.Set(42); + return true; + }), + []() { return absl::OkStatus(); }); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }, + std::move(rpcio)); + }); +} +BENCHMARK(BM_ActivityStack_Interject3Latches_Unary); + +static void BM_ActivityStack_Interject10Latches_Unary(benchmark::State& state) { + unary(state, []() { + RPCIO rpcio; + 
return MakeActivity( + []() { + auto one = []() { + return GetContext()->recv_initial_metadata.Wait(); + }; + return Seq(Join(one(), one(), one(), one(), one(), one(), one(), + one(), one(), one(), + []() { + GetContext()->recv_initial_metadata.Set(42); + return true; + }), + []() { return absl::OkStatus(); }); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }, + std::move(rpcio)); + }); +} +BENCHMARK(BM_ActivityStack_Interject10Latches_Unary); + +static void BM_ActivityStack_Interject30Latches_Unary(benchmark::State& state) { + unary(state, []() { + RPCIO rpcio; + return MakeActivity( + []() { + auto one = []() { + return GetContext()->recv_initial_metadata.Wait(); + }; + return Seq( + Join(one(), one(), one(), one(), one(), one(), one(), one(), + one(), one(), one(), one(), one(), one(), one(), one(), + one(), one(), one(), one(), one(), one(), one(), one(), + one(), one(), one(), one(), one(), one(), + []() { + GetContext()->recv_initial_metadata.Set(42); + return true; + }), + []() { return absl::OkStatus(); }); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }, + std::move(rpcio)); + }); +} +BENCHMARK(BM_ActivityStack_Interject30Latches_Unary); + +static void BM_ActivityStack_Interject3Filters_Unary(benchmark::State& state) { + unary(state, []() { + RPCP rpcio; + return MakeActivity( + []() { + auto one = []() { + return GetContext()->pipe.sender.Filter( + [](int i) { return absl::StatusOr(i); }); + }; + return TryJoin( + one(), one(), one(), + Seq( + GetContext()->pipe.sender.Push(42), + []() { return GetContext()->pipe.sender.Push(43); }, + []() { return GetContext()->pipe.sender.Push(44); }, + []() { + auto x = std::move(GetContext()->pipe.sender); + return absl::OkStatus(); + }), + Seq( + GetContext()->pipe.receiver.Next(), + []() { return GetContext()->pipe.receiver.Next(); }, + []() { return GetContext()->pipe.receiver.Next(); }, + []() { return absl::OkStatus(); })); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }, + std::move(rpcio)); + }); +} +BENCHMARK(BM_ActivityStack_Interject3Filters_Unary); + +static void BM_ActivityStack_Interject10Filters_Unary(benchmark::State& state) { + unary(state, []() { + RPCP rpcio; + return MakeActivity( + []() { + auto one = []() { + return GetContext()->pipe.sender.Filter( + [](int i) { return absl::StatusOr(i); }); + }; + return TryJoin( + one(), one(), one(), one(), one(), one(), one(), one(), one(), + one(), + Seq( + GetContext()->pipe.sender.Push(42), + []() { return GetContext()->pipe.sender.Push(43); }, + []() { return GetContext()->pipe.sender.Push(44); }, + []() { + auto x = std::move(GetContext()->pipe.sender); + return absl::OkStatus(); + }), + Seq( + GetContext()->pipe.receiver.Next(), + []() { return GetContext()->pipe.receiver.Next(); }, + []() { return GetContext()->pipe.receiver.Next(); }, + []() { return absl::OkStatus(); })); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }, + std::move(rpcio)); + }); +} +BENCHMARK(BM_ActivityStack_Interject10Filters_Unary); + +static void BM_ActivityStack_Interject30Filters_Unary(benchmark::State& state) { + unary(state, []() { + RPCP rpcio; + return MakeActivity( + []() { + auto one = []() { + return GetContext()->pipe.sender.Filter( + [](int i) { return absl::StatusOr(i); }); + }; + return TryJoin( + one(), one(), one(), one(), one(), one(), one(), one(), one(), + one(), one(), one(), one(), one(), one(), one(), one(), one(), + 
one(), one(), one(), one(), one(), one(), one(), one(), one(), + one(), one(), one(), + Seq( + GetContext()->pipe.sender.Push(42), + []() { return GetContext()->pipe.sender.Push(43); }, + []() { return GetContext()->pipe.sender.Push(44); }, + []() { + auto x = std::move(GetContext()->pipe.sender); + return absl::OkStatus(); + }), + Seq( + GetContext()->pipe.receiver.Next(), + []() { return GetContext()->pipe.receiver.Next(); }, + []() { return GetContext()->pipe.receiver.Next(); }, + []() { return absl::OkStatus(); })); + }, + NoWakeupScheduler(), + [](absl::Status status) { + if (!status.ok()) abort(); + }, + std::move(rpcio)); + }); +} +BENCHMARK(BM_ActivityStack_Interject30Filters_Unary); + +} // namespace activity_stack +} // namespace grpc_core + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + ::benchmark::Initialize(&argc, argv); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/core/promise/benchmark/filter_stack.cc b/test/core/promise/benchmark/filter_stack.cc new file mode 100644 index 00000000..fe438e40 --- /dev/null +++ b/test/core/promise/benchmark/filter_stack.cc @@ -0,0 +1,102 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
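+
+// filter_stack implements a minimal legacy-style channel/call filter stack;
+// the benchmarks in competition.cc use it as the baseline to compare against
+// promise-based activities.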
+ +#include "test/core/promise/benchmark/filter_stack.h" + +namespace filter_stack { + +ChannelStack* MakeChannel(Filter** filters, size_t num_filters) { + size_t size = sizeof(ChannelStack) + num_filters * sizeof(ChannelElem); + size_t call_size = sizeof(CallStack) + num_filters * sizeof(CallElem); + for (size_t i = 0; i < num_filters; i++) { + size += filters[i]->sizeof_channel_data; + call_size += filters[i]->sizeof_call_data; + } + char* data = new char[size]; + ChannelStack* stk = reinterpret_cast(data); + new (data) ChannelStack{0, num_filters, call_size}; + data += sizeof(ChannelStack); + char* user_data = data + num_filters * sizeof(ChannelElem); + for (size_t i = 0; i < num_filters; i++) { + new (data) ChannelElem{filters[i], user_data}; + filters[i]->init_channel_data(reinterpret_cast(data)); + data += sizeof(ChannelElem); + user_data += filters[i]->sizeof_channel_data; + } + printf("CALL STACK SIZE: %d\n", static_cast(call_size)); + return stk; +} + +void FreeChannel(ChannelStack* stk) { + ChannelElem* elems = reinterpret_cast(stk + 1); + for (size_t i = 0; i < stk->num_elems; i++) { + elems[i].filter->destroy_channel_data(&elems[i]); + } + stk->~ChannelStack(); + delete[] reinterpret_cast(stk); +} + +CallStack* MakeCall(ChannelStack* stk) { + char* data = new char[stk->call_stack_size]; + CallStack* call = reinterpret_cast(data); + new (data) CallStack{{1}, stk->num_elems, {}}; + data += sizeof(CallStack); + ChannelElem* channel_elems = reinterpret_cast(stk + 1); + char* user_data = data + stk->num_elems * sizeof(CallElem); + for (size_t i = 0; i < stk->num_elems; i++) { + new (data) CallElem{channel_elems[i].filter, channel_elems[i].channel_data, + user_data}; + channel_elems[i].filter->init_call_data(reinterpret_cast(data)); + data += sizeof(CallElem); + user_data += channel_elems[i].filter->sizeof_call_data; + } + return call; +} + +static void RefCall(CallStack* stk) { + stk->refcount.fetch_add(1, std::memory_order_relaxed); +} + +static void UnrefCall(CallStack* stk) { + if (stk->refcount.fetch_sub(1, std::memory_order_acq_rel) == 1) { + CallElem* elems = reinterpret_cast(stk + 1); + for (size_t i = 0; i < stk->num_elems; i++) { + elems[i].filter->destroy_call_data(&elems[i]); + } + stk->~CallStack(); + delete[] reinterpret_cast(stk); + } +} + +void FreeCall(CallStack* stk) { UnrefCall(stk); } + +void NoChannelData(ChannelElem*) {} +void NoCallData(CallElem*) {} + +static void StartOp(CallElem* elem, Op* op) { + elem->filter->start_transport_stream_op_batch(elem, op); +} + +void CallNextOp(CallElem* elem, Op* op) { StartOp(elem + 1, op); } + +void RunOp(CallStack* stk, Op* op) { + RefCall(stk); + { + absl::MutexLock lock(&stk->mutex); + StartOp(reinterpret_cast(stk + 1), op); + } + UnrefCall(stk); +} + +} // namespace filter_stack diff --git a/test/core/promise/context_test.cc b/test/core/promise/context_test.cc new file mode 100644 index 00000000..b4f4afe9 --- /dev/null +++ b/test/core/promise/context_test.cc @@ -0,0 +1,42 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/context.h" + +#include + +namespace grpc_core { + +struct TestContext { + bool done = false; +}; + +template <> +struct ContextType {}; + +TEST(Context, WithContext) { + EXPECT_EQ(GetContext(), nullptr); + TestContext test; + EXPECT_EQ(GetContext(), nullptr); + EXPECT_EQ(test.done, false); + WithContext([]() { GetContext()->done = true; }, &test)(); + EXPECT_EQ(test.done, true); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/exec_ctx_wakeup_scheduler_test.cc b/test/core/promise/exec_ctx_wakeup_scheduler_test.cc new file mode 100644 index 00000000..ea3568c3 --- /dev/null +++ b/test/core/promise/exec_ctx_wakeup_scheduler_test.cc @@ -0,0 +1,66 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/exec_ctx_wakeup_scheduler.h" + +#include + +#include "src/core/lib/promise/activity.h" + +namespace grpc_core { + +TEST(ExecCtxWakeupSchedulerTest, Works) { + int state = 0; + bool done = false; + auto activity = MakeActivity( + [&state]() mutable -> Poll { + ++state; + switch (state) { + case 1: + return Pending(); + case 2: + return absl::OkStatus(); + default: + abort(); + } + }, + ExecCtxWakeupScheduler(), + [&done](absl::Status status) { + EXPECT_EQ(status, absl::OkStatus()); + done = true; + }); + + EXPECT_EQ(state, 1); + EXPECT_FALSE(done); + { + ExecCtx exec_ctx; + EXPECT_FALSE(exec_ctx.HasWork()); + activity->ForceWakeup(); + EXPECT_TRUE(exec_ctx.HasWork()); + EXPECT_EQ(state, 1); + EXPECT_FALSE(done); + } + EXPECT_EQ(state, 2); + EXPECT_TRUE(done); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc_core::ExecCtx::GlobalInit(); + int r = RUN_ALL_TESTS(); + grpc_core::ExecCtx::GlobalShutdown(); + return r; +} diff --git a/test/core/promise/for_each_test.cc b/test/core/promise/for_each_test.cc new file mode 100644 index 00000000..d668c3d4 --- /dev/null +++ b/test/core/promise/for_each_test.cc @@ -0,0 +1,72 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
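+
+// ForEach repeatedly takes values from a receiver and runs the supplied
+// handler on each one, completing once the sender side is closed.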
+ +#include "src/core/lib/promise/for_each.h" + +#include +#include + +#include "src/core/lib/promise/join.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/observable.h" +#include "src/core/lib/promise/pipe.h" +#include "src/core/lib/promise/seq.h" +#include "test/core/promise/test_wakeup_schedulers.h" + +using testing::Mock; +using testing::MockFunction; +using testing::StrictMock; + +namespace grpc_core { + +TEST(ForEachTest, SendThriceWithPipe) { + Pipe pipe; + int num_received = 0; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&pipe, &num_received] { + return Map( + Join( + // Push 3 things into a pipe -- 1, 2, then 3 -- then close. + Seq( + pipe.sender.Push(1), + [&pipe] { return pipe.sender.Push(2); }, + [&pipe] { return pipe.sender.Push(3); }, + [&pipe] { + auto drop = std::move(pipe.sender); + return absl::OkStatus(); + }), + // Use a ForEach loop to read them out and verify all values are + // seen. + ForEach(std::move(pipe.receiver), + [&num_received](int i) { + num_received++; + EXPECT_EQ(num_received, i); + return absl::OkStatus(); + })), + JustElem<1>()); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + Mock::VerifyAndClearExpectations(&on_done); + EXPECT_EQ(num_received, 3); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/if_test.cc b/test/core/promise/if_test.cc new file mode 100644 index 00000000..f0ec4379 --- /dev/null +++ b/test/core/promise/if_test.cc @@ -0,0 +1,58 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/if.h" + +#include + +namespace grpc_core { + +TEST(IfTest, ChooseTrue) { + EXPECT_EQ(If([]() { return true; }, []() { return 1; }, []() { return 2; })(), + Poll(1)); +} + +TEST(IfTest, ChooseFalse) { + EXPECT_EQ( + If([]() { return false; }, []() { return 1; }, []() { return 2; })(), + Poll(2)); +} + +TEST(IfTest, ChooseSuccesfulTrue) { + EXPECT_EQ(If([]() { return absl::StatusOr(true); }, + []() { return absl::StatusOr(1); }, + []() { return absl::StatusOr(2); })(), + Poll>(absl::StatusOr(1))); +} + +TEST(IfTest, ChooseSuccesfulFalse) { + EXPECT_EQ(If([]() { return absl::StatusOr(false); }, + []() { return absl::StatusOr(1); }, + []() { return absl::StatusOr(2); })(), + Poll>(absl::StatusOr(2))); +} + +TEST(IfTest, ChooseFailure) { + EXPECT_EQ(If([]() { return absl::StatusOr(); }, + []() { return absl::StatusOr(1); }, + []() { return absl::StatusOr(2); })(), + Poll>(absl::StatusOr())); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/join_test.cc b/test/core/promise/join_test.cc new file mode 100644 index 00000000..df3ab83a --- /dev/null +++ b/test/core/promise/join_test.cc @@ -0,0 +1,41 @@ +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/join.h" + +#include + +namespace grpc_core { + +TEST(JoinTest, Join1) { + EXPECT_EQ(Join([] { return 3; })(), + (Poll>(std::make_tuple(3)))); +} + +TEST(JoinTest, Join2) { + EXPECT_EQ(Join([] { return 3; }, [] { return 4; })(), + (Poll>(std::make_tuple(3, 4)))); +} + +TEST(JoinTest, Join3) { + EXPECT_EQ(Join([] { return 3; }, [] { return 4; }, [] { return 5; })(), + (Poll>(std::make_tuple(3, 4, 5)))); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/latch_test.cc b/test/core/promise/latch_test.cc new file mode 100644 index 00000000..a3054a1e --- /dev/null +++ b/test/core/promise/latch_test.cc @@ -0,0 +1,54 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/latch.h" + +#include +#include + +#include "src/core/lib/promise/join.h" +#include "src/core/lib/promise/seq.h" +#include "test/core/promise/test_wakeup_schedulers.h" + +using testing::MockFunction; +using testing::StrictMock; + +namespace grpc_core { + +TEST(LatchTest, Works) { + Latch latch; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&latch] { + return Seq(Join(latch.Wait(), + [&latch]() { + latch.Set(42); + return true; + }), + [](std::tuple result) { + EXPECT_EQ(*std::get<0>(result), 42); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/loop_test.cc b/test/core/promise/loop_test.cc new file mode 100644 index 00000000..5e6cf449 --- /dev/null +++ b/test/core/promise/loop_test.cc @@ -0,0 +1,56 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/loop.h" + +#include + +#include "src/core/lib/promise/seq.h" + +namespace grpc_core { + +TEST(LoopTest, CountToFive) { + int i = 0; + Loop([&i]() -> LoopCtl { + i++; + if (i < 5) return Continue(); + return i; + })(); + EXPECT_EQ(i, 5); +} + +TEST(LoopTest, FactoryCountToFive) { + int i = 0; + Loop([&i]() { + return [&i]() -> LoopCtl { + i++; + if (i < 5) return Continue(); + return i; + }; + })(); + EXPECT_EQ(i, 5); +} + +TEST(LoopTest, LoopOfSeq) { + auto x = + Loop(Seq([]() { return 42; }, [](int i) -> LoopCtl { return i; }))(); + EXPECT_EQ(x, Poll(42)); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/map_test.cc b/test/core/promise/map_test.cc new file mode 100644 index 00000000..d60d8364 --- /dev/null +++ b/test/core/promise/map_test.cc @@ -0,0 +1,39 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/map.h" + +#include + +#include "src/core/lib/promise/promise.h" + +namespace grpc_core { + +TEST(MapTest, Works) { + Promise x = Map([]() { return 42; }, [](int i) { return i / 2; }); + EXPECT_EQ(x(), Poll(21)); +} + +TEST(MapTest, JustElem) { + std::tuple t(1, 3.2); + EXPECT_EQ(JustElem<1>()(t), 3.2); + EXPECT_EQ(JustElem<0>()(t), 1); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/observable_test.cc b/test/core/promise/observable_test.cc new file mode 100644 index 00000000..3eb120b7 --- /dev/null +++ b/test/core/promise/observable_test.cc @@ -0,0 +1,131 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/observable.h" + +#include +#include + +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/seq.h" +#include "test/core/promise/test_wakeup_schedulers.h" + +using testing::MockFunction; +using testing::StrictMock; + +namespace grpc_core { + +// A simple Barrier type: stalls progress until it is 'cleared'. 
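+// (Same helper as in activity_test.cc: waiters park wakers in a WaitSet until
+// Clear() releases them.)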
+class Barrier { + public: + struct Result {}; + + Promise Wait() { + return [this]() -> Poll { + absl::MutexLock lock(&mu_); + if (cleared_) { + return Result{}; + } else { + return wait_set_.AddPending(Activity::current()->MakeOwningWaker()); + } + }; + } + + void Clear() { + mu_.Lock(); + cleared_ = true; + auto wakeup = wait_set_.TakeWakeupSet(); + mu_.Unlock(); + wakeup.Wakeup(); + } + + private: + absl::Mutex mu_; + WaitSet wait_set_ ABSL_GUARDED_BY(mu_); + bool cleared_ ABSL_GUARDED_BY(mu_) = false; +}; + +TEST(ObservableTest, CanPushAndGet) { + StrictMock> on_done; + Observable observable; + auto observer = observable.MakeObserver(); + auto activity = MakeActivity( + [&observer]() { + return Seq(observer.Get(), [](absl::optional i) { + return i == 42 ? absl::OkStatus() : absl::UnknownError("expected 42"); + }); + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + EXPECT_CALL(on_done, Call(absl::OkStatus())); + observable.Push(42); +} + +TEST(ObservableTest, CanNext) { + StrictMock> on_done; + Observable observable; + auto observer = observable.MakeObserver(); + auto activity = MakeActivity( + [&observer]() { + return Seq( + observer.Get(), + [&observer](absl::optional i) { + EXPECT_EQ(i, 42); + return observer.Next(); + }, + [](absl::optional i) { + return i == 1 ? absl::OkStatus() + : absl::UnknownError("expected 1"); + }); + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + observable.Push(42); + EXPECT_CALL(on_done, Call(absl::OkStatus())); + observable.Push(1); +} + +TEST(ObservableTest, CanWatch) { + StrictMock> on_done; + Observable observable; + Barrier barrier; + auto activity = MakeActivity( + [&observable, &barrier]() { + return observable.Watch( + [&barrier](int x, + WatchCommitter* committer) -> Promise { + if (x == 3) { + committer->Commit(); + return Seq(barrier.Wait(), Immediate(absl::OkStatus())); + } else { + return Never(); + } + }); + }, + InlineWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + observable.Push(1); + observable.Push(2); + observable.Push(3); + observable.Push(4); + EXPECT_CALL(on_done, Call(absl::OkStatus())); + barrier.Clear(); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/pipe_test.cc b/test/core/promise/pipe_test.cc new file mode 100644 index 00000000..ba38fbef --- /dev/null +++ b/test/core/promise/pipe_test.cc @@ -0,0 +1,180 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
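+
+// These tests exercise Pipe: a small intra-activity channel with a one-deep
+// send buffer, covering send/receive ordering, closure, and filters.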
+ +#include "src/core/lib/promise/pipe.h" + +#include +#include + +#include "absl/memory/memory.h" + +#include "src/core/lib/promise/join.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/seq.h" +#include "test/core/promise/test_wakeup_schedulers.h" + +using testing::MockFunction; +using testing::StrictMock; + +namespace grpc_core { + +TEST(PipeTest, CanSendAndReceive) { + Pipe pipe; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&pipe] { + return Seq( + // Concurrently: send 42 into the pipe, and receive from the pipe. + Join(pipe.sender.Push(42), pipe.receiver.Next()), + // Once complete, verify successful sending and the received value + // is 42. + [](std::tuple> result) { + EXPECT_EQ(result, std::make_tuple(true, absl::optional(42))); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TEST(PipeTest, CanReceiveAndSend) { + Pipe pipe; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&pipe] { + return Seq( + // Concurrently: receive from the pipe, and send 42 into the pipe. + Join(pipe.receiver.Next(), pipe.sender.Push(42)), + // Once complete, verify the received value is 42 and successful + // sending. + [](std::tuple, bool> result) { + EXPECT_EQ(result, std::make_tuple(absl::optional(42), true)); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TEST(PipeTest, CanSeeClosedOnSend) { + Pipe pipe; + StrictMock> on_done; + auto sender = std::move(pipe.sender); + auto receiver = + absl::make_unique>(std::move(pipe.receiver)); + EXPECT_CALL(on_done, Call(absl::OkStatus())); + // Push 42 onto the pipe - this will the pipe's one-deep send buffer. + EXPECT_TRUE(NowOrNever(sender.Push(42)).has_value()); + MakeActivity( + [&sender, &receiver] { + return Seq( + // Concurrently: + // - push 43 into the sender, which will stall because the buffer is + // full + // - and close the receiver, which will fail the pending send. + Join(sender.Push(43), + [&receiver] { + receiver.reset(); + return absl::OkStatus(); + }), + // Verify both that the send failed and that we executed the close. + [](std::tuple result) { + EXPECT_EQ(result, std::make_tuple(false, absl::OkStatus())); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TEST(PipeTest, CanSeeClosedOnReceive) { + Pipe pipe; + StrictMock> on_done; + auto sender = absl::make_unique>(std::move(pipe.sender)); + auto receiver = std::move(pipe.receiver); + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&sender, &receiver] { + return Seq( + // Concurrently: + // - wait for a received value (will stall forever since we push + // nothing into the queue) + // - close the sender, which will signal the receiver to return an + // end-of-stream. + Join(receiver.Next(), + [&sender] { + sender.reset(); + return absl::OkStatus(); + }), + // Verify we received end-of-stream and closed the sender. 
+ [](std::tuple, absl::Status> result) { + EXPECT_EQ(result, std::make_tuple(absl::optional(), + absl::OkStatus())); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +TEST(PipeTest, CanFilter) { + Pipe pipe; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&pipe] { + // Setup some filters here, carefully getting ordering correct by doing + // so outside of the Join() since C++ does not define execution order + // between arguments. + // TODO(ctiller): A future change to Pipe will specify an ordering + // between filters added to sender and receiver, at which point these + // should move back. + auto doubler = pipe.receiver.Filter( + [](int p) { return absl::StatusOr(p * 2); }); + auto adder = pipe.sender.Filter( + [](int p) { return absl::StatusOr(p + 1); }); + return Seq( + // Concurrently: + // - push 42 into the pipe + // - wait for a value to be received, and filter it by doubling it + // - wait for a value to be received, and filter it by adding one to + // it + // - wait for a value to be received and close the pipe. + Join(pipe.sender.Push(42), std::move(doubler), std::move(adder), + Seq(pipe.receiver.Next(), + [&pipe](absl::optional i) { + auto x = std::move(pipe.receiver); + return i; + })), + // Verify all of the above happened correctly. + [](std::tuple> + result) { + EXPECT_EQ(result, std::make_tuple(true, absl::OkStatus(), + absl::OkStatus(), + absl::optional(85))); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/poll_test.cc b/test/core/promise/poll_test.cc new file mode 100644 index 00000000..3a44f503 --- /dev/null +++ b/test/core/promise/poll_test.cc @@ -0,0 +1,45 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/poll.h" + +#include + +namespace grpc_core { + +TEST(PollTest, IsItPoll) { + EXPECT_EQ(PollTraits>::is_poll(), true); + EXPECT_EQ(PollTraits>::is_poll(), true); + EXPECT_EQ(PollTraits>>::is_poll(), true); + EXPECT_EQ(PollTraits::is_poll(), false); + EXPECT_EQ(PollTraits::is_poll(), false); + EXPECT_EQ(PollTraits>::is_poll(), false); +} + +TEST(PollTest, Pending) { + Poll i = Pending(); + EXPECT_TRUE(absl::holds_alternative(i)); +} + +TEST(PollTest, Ready) { + Poll i = 1; + EXPECT_TRUE(absl::holds_alternative(i)); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/promise_factory_test.cc b/test/core/promise/promise_factory_test.cc new file mode 100644 index 00000000..b354dbe2 --- /dev/null +++ b/test/core/promise/promise_factory_test.cc @@ -0,0 +1,72 @@ +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/detail/promise_factory.h" + +#include + +#include "absl/functional/bind_front.h" + +#include "src/core/lib/gprpp/capture.h" +#include "src/core/lib/promise/promise.h" + +namespace grpc_core { +namespace promise_detail { +namespace testing { + +template +PromiseFactory MakeFactory(F f) { + return PromiseFactory(std::move(f)); +} + +TEST(AdaptorTest, FactoryFromPromise) { + EXPECT_EQ( + MakeFactory([]() { return Poll(Poll(42)); }).Once()(), + Poll(42)); + EXPECT_EQ( + MakeFactory([]() { return Poll(Poll(42)); }).Repeated()(), + Poll(42)); + EXPECT_EQ(MakeFactory(Promise([]() { + return Poll(Poll(42)); + })).Once()(), + Poll(42)); + EXPECT_EQ(MakeFactory(Promise([]() { + return Poll(Poll(42)); + })).Repeated()(), + Poll(42)); +} + +TEST(AdaptorTest, FactoryFromBindFrontPromise) { + EXPECT_EQ(MakeFactory( + absl::bind_front([](int i) { return Poll(i); }, 42)) + .Once()(), + Poll(42)); +} + +TEST(AdaptorTest, FactoryFromCapturePromise) { + EXPECT_EQ(MakeFactory( + grpc_core::Capture([](int* i) { return Poll(*i); }, 42)) + .Once()(), + Poll(42)); +} + +} // namespace testing +} // namespace promise_detail + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/promise_fuzzer.cc b/test/core/promise/promise_fuzzer.cc new file mode 100644 index 00000000..78a0da9e --- /dev/null +++ b/test/core/promise/promise_fuzzer.cc @@ -0,0 +1,317 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/join.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/race.h" +#include "src/core/lib/promise/seq.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "test/core/promise/promise_fuzzer.pb.h" + +bool squelch = true; +bool leak_check = true; + +namespace grpc_core { +// Return type for infallible promises. +// We choose this so that it's easy to construct, and will trigger asan failures +// if misused, and is copyable. +using IntHdl = std::shared_ptr; + +template +using PromiseFactory = std::function(T)>; + +namespace { +class Fuzzer { + public: + void Run(const promise_fuzzer::Msg& msg) { + // If there's no promise we can't construct and activity and... we're done. + if (!msg.has_promise()) { + return; + } + // Construct activity. 
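+    // The activity runs the fuzz-generated promise; on_done verifies the final
+    // status against expected_status_ when it is known, then marks done_.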
+    activity_ = MakeActivity(
+        [msg, this] {
+          return Seq(MakePromise(msg.promise()),
+                     [] { return absl::OkStatus(); });
+        },
+        Scheduler{this},
+        [this](absl::Status status) {
+          // Must only be called once
+          GPR_ASSERT(!done_);
+          // If we became certain of the eventual status, verify it.
+          if (expected_status_.has_value()) {
+            GPR_ASSERT(status == *expected_status_);
+          }
+          // Mark ourselves done.
+          done_ = true;
+        });
+    for (int i = 0; !done_ && activity_ != nullptr && i < msg.actions_size();
+         i++) {
+      // Do some things
+      const auto& action = msg.actions(i);
+      switch (action.action_type_case()) {
+        // Force a wakeup
+        case promise_fuzzer::Action::kForceWakeup:
+          activity_->ForceWakeup();
+          break;
+        // Cancel from the outside
+        case promise_fuzzer::Action::kCancel:
+          ExpectCancelled();
+          activity_.reset();
+          break;
+        // Flush any pending wakeups
+        case promise_fuzzer::Action::kFlushWakeup:
+          if (wakeup_ != nullptr) absl::exchange(wakeup_, nullptr)();
+          break;
+        // Drop some wakeups (external system closed?)
+        case promise_fuzzer::Action::kDropWaker: {
+          int n = action.drop_waker();
+          auto v = std::move(wakers_[n]);
+          wakers_.erase(n);
+          break;
+        }
+        // Wakeup some wakeups
+        case promise_fuzzer::Action::kAwakeWaker: {
+          int n = action.awake_waker();
+          auto v = std::move(wakers_[n]);
+          wakers_.erase(n);
+          for (auto& w : v) {
+            w.Wakeup();
+          }
+          break;
+        }
+        case promise_fuzzer::Action::ACTION_TYPE_NOT_SET:
+          break;
+      }
+    }
+    ExpectCancelled();
+    activity_.reset();
+    if (wakeup_ != nullptr) absl::exchange(wakeup_, nullptr)();
+    GPR_ASSERT(done_);
+  }
+
+ private:
+  // Schedule wakeups against the fuzzer
+  struct Scheduler {
+    Fuzzer* fuzzer;
+    // Schedule a wakeup
+    template <typename ActivityType>
+    void ScheduleWakeup(ActivityType* activity) {
+      GPR_ASSERT(activity == fuzzer->activity_.get());
+      GPR_ASSERT(fuzzer->wakeup_ == nullptr);
+      fuzzer->wakeup_ = [activity]() { activity->RunScheduledWakeup(); };
+    }
+  };
+
+  // We know that if not already finished, the status when finished will be
+  // cancelled.
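+  // Called immediately before any path that may tear the activity down early
+  // (explicit kCancel actions and the final cleanup at the end of Run()).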
+ void ExpectCancelled() { + if (!done_ && !expected_status_.has_value()) { + expected_status_ = absl::CancelledError(); + } + } + + // Construct a promise factory from a protobuf + PromiseFactory MakePromiseFactory( + const promise_fuzzer::PromiseFactory& p) { + switch (p.promise_factory_type_case()) { + case promise_fuzzer::PromiseFactory::kPromise: + return [p, this](IntHdl) { return MakePromise(p.promise()); }; + case promise_fuzzer::PromiseFactory::kLast: + return [](IntHdl h) { return [h]() { return h; }; }; + case promise_fuzzer::PromiseFactory::PROMISE_FACTORY_TYPE_NOT_SET: + break; + } + return [](IntHdl) { + return []() -> Poll { return std::make_shared(42); }; + }; + } + + // Construct a promise from a protobuf + Promise MakePromise(const promise_fuzzer::Promise& p) { + switch (p.promise_type_case()) { + case promise_fuzzer::Promise::kSeq: + switch (p.seq().promise_factories_size()) { + case 1: + return Seq(MakePromise(p.seq().first()), + MakePromiseFactory(p.seq().promise_factories(0))); + case 2: + return Seq(MakePromise(p.seq().first()), + MakePromiseFactory(p.seq().promise_factories(0)), + MakePromiseFactory(p.seq().promise_factories(1))); + case 3: + return Seq(MakePromise(p.seq().first()), + MakePromiseFactory(p.seq().promise_factories(0)), + MakePromiseFactory(p.seq().promise_factories(1)), + MakePromiseFactory(p.seq().promise_factories(2))); + case 4: + return Seq(MakePromise(p.seq().first()), + MakePromiseFactory(p.seq().promise_factories(0)), + MakePromiseFactory(p.seq().promise_factories(1)), + MakePromiseFactory(p.seq().promise_factories(2)), + MakePromiseFactory(p.seq().promise_factories(3))); + case 5: + return Seq(MakePromise(p.seq().first()), + MakePromiseFactory(p.seq().promise_factories(0)), + MakePromiseFactory(p.seq().promise_factories(1)), + MakePromiseFactory(p.seq().promise_factories(2)), + MakePromiseFactory(p.seq().promise_factories(3)), + MakePromiseFactory(p.seq().promise_factories(4))); + case 6: + return Seq(MakePromise(p.seq().first()), + MakePromiseFactory(p.seq().promise_factories(0)), + MakePromiseFactory(p.seq().promise_factories(1)), + MakePromiseFactory(p.seq().promise_factories(2)), + MakePromiseFactory(p.seq().promise_factories(3)), + MakePromiseFactory(p.seq().promise_factories(4)), + MakePromiseFactory(p.seq().promise_factories(5))); + } + break; + case promise_fuzzer::Promise::kJoin: + switch (p.join().promises_size()) { + case 1: + return Map(Join(MakePromise(p.join().promises(0))), + [](std::tuple t) { return std::get<0>(t); }); + case 2: + return Map( + Join(MakePromise(p.join().promises(0)), + MakePromise(p.join().promises(1))), + [](std::tuple t) { return std::get<0>(t); }); + case 3: + return Map(Join(MakePromise(p.join().promises(0)), + MakePromise(p.join().promises(1)), + MakePromise(p.join().promises(2))), + [](std::tuple t) { + return std::get<0>(t); + }); + case 4: + return Map(Join(MakePromise(p.join().promises(0)), + MakePromise(p.join().promises(1)), + MakePromise(p.join().promises(2)), + MakePromise(p.join().promises(3))), + [](std::tuple t) { + return std::get<0>(t); + }); + case 5: + return Map( + Join(MakePromise(p.join().promises(0)), + MakePromise(p.join().promises(1)), + MakePromise(p.join().promises(2)), + MakePromise(p.join().promises(3)), + MakePromise(p.join().promises(4))), + [](std::tuple t) { + return std::get<0>(t); + }); + case 6: + return Map( + Join(MakePromise(p.join().promises(0)), + MakePromise(p.join().promises(1)), + MakePromise(p.join().promises(2)), + MakePromise(p.join().promises(3)), + 
MakePromise(p.join().promises(4)), + MakePromise(p.join().promises(5))), + [](std::tuple + t) { return std::get<0>(t); }); + } + break; + case promise_fuzzer::Promise::kRace: + switch (p.race().promises_size()) { + case 1: + return Race(MakePromise(p.race().promises(0))); + case 2: + return Race(MakePromise(p.race().promises(0)), + MakePromise(p.race().promises(1))); + case 3: + return Race(MakePromise(p.race().promises(0)), + MakePromise(p.race().promises(1)), + MakePromise(p.race().promises(2))); + case 4: + return Race(MakePromise(p.race().promises(0)), + MakePromise(p.race().promises(1)), + MakePromise(p.race().promises(2)), + MakePromise(p.race().promises(3))); + case 5: + return Race(MakePromise(p.race().promises(0)), + MakePromise(p.race().promises(1)), + MakePromise(p.race().promises(2)), + MakePromise(p.race().promises(3)), + MakePromise(p.race().promises(4))); + case 6: + return Race(MakePromise(p.race().promises(0)), + MakePromise(p.race().promises(1)), + MakePromise(p.race().promises(2)), + MakePromise(p.race().promises(3)), + MakePromise(p.race().promises(4)), + MakePromise(p.race().promises(5))); + } + break; + case promise_fuzzer::Promise::kNever: + return Never(); + case promise_fuzzer::Promise::kSleepFirstN: { + int n = p.sleep_first_n(); + return [n]() mutable -> Poll { + if (n <= 0) return std::make_shared(0); + n--; + return Pending{}; + }; + } + case promise_fuzzer::Promise::kCancelFromInside: + return [this]() -> Poll { + this->activity_.reset(); + return Pending{}; + }; + case promise_fuzzer::Promise::kWaitOnceOnWaker: { + bool called = false; + auto config = p.wait_once_on_waker(); + return [this, config, called]() mutable -> Poll { + if (!called) { + if (config.owning()) { + wakers_[config.waker()].push_back( + Activity::current()->MakeOwningWaker()); + } else { + wakers_[config.waker()].push_back( + Activity::current()->MakeNonOwningWaker()); + } + return Pending(); + } + return std::make_shared(3); + }; + } + case promise_fuzzer::Promise::PromiseTypeCase::PROMISE_TYPE_NOT_SET: + break; + } + return [] { return std::make_shared(42); }; + } + + // Activity under test + ActivityPtr activity_; + // Scheduled wakeup (may be nullptr if no wakeup scheduled) + std::function wakeup_; + // If we are certain of the final status, then that. Otherwise, nullopt if we + // don't know. + absl::optional expected_status_; + // Has on_done been called? + bool done_ = false; + // Wakers that may be scheduled + std::map> wakers_; +}; +} // namespace + +} // namespace grpc_core + +DEFINE_PROTO_FUZZER(const promise_fuzzer::Msg& msg) { + grpc_core::Fuzzer().Run(msg); +} diff --git a/test/core/promise/promise_test.cc b/test/core/promise/promise_test.cc new file mode 100644 index 00000000..aa1a8144 --- /dev/null +++ b/test/core/promise/promise_test.cc @@ -0,0 +1,43 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
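An editorial aside, not part of the patch: the kSleepFirstN and kWaitOnceOnWaker promises above both lean on the same polling contract — a promise may answer Pending{} any number of times, and a later poll (triggered by some wakeup) eventually yields the ready value. The gtest-style sketch below restates that contract; it assumes the same internal headers and the variant-style Poll<> that the surrounding tests already inspect with absl::holds_alternative, so treat it as illustrative rather than a drop-in test.

#include <gtest/gtest.h>

#include "absl/types/variant.h"

#include "src/core/lib/promise/promise.h"

namespace grpc_core {

TEST(PollContractSketch, PendingThenReady) {
  int n = 2;
  auto p = [n]() mutable -> Poll<int> {
    if (n > 0) {
      --n;
      return Pending{};  // not ready yet; the caller must poll again later
    }
    return 42;  // ready
  };
  EXPECT_TRUE(absl::holds_alternative<Pending>(p()));  // 1st poll: pending
  EXPECT_TRUE(absl::holds_alternative<Pending>(p()));  // 2nd poll: pending
  EXPECT_EQ(absl::get<int>(p()), 42);                  // 3rd poll: resolved
}

}  // namespace grpc_core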
+ +#include "src/core/lib/promise/promise.h" + +#include + +namespace grpc_core { + +TEST(PromiseTest, Works) { + Promise x = []() { return 42; }; + EXPECT_EQ(x(), Poll(42)); +} + +TEST(PromiseTest, Immediate) { EXPECT_EQ(Immediate(42)(), Poll(42)); } + +TEST(PromiseTest, WithResult) { + EXPECT_EQ(WithResult(Immediate(42))(), Poll(42)); + // Fails to compile: WithResult(Immediate(std::string("hello"))); + // Fails to compile: WithResult(Immediate(42.9)); +} + +TEST(PromiseTest, NowOrNever) { + EXPECT_EQ(NowOrNever(Immediate(42)), absl::optional(42)); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/race_test.cc b/test/core/promise/race_test.cc new file mode 100644 index 00000000..7928b9f8 --- /dev/null +++ b/test/core/promise/race_test.cc @@ -0,0 +1,33 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/race.h" + +#include + +namespace grpc_core { + +Poll instant() { return 1; } +Poll never() { return Pending(); } + +TEST(RaceTest, Race1) { EXPECT_EQ(Race(instant)(), Poll(1)); } +TEST(RaceTest, Race2A) { EXPECT_EQ(Race(instant, never)(), Poll(1)); } +TEST(RaceTest, Race2B) { EXPECT_EQ(Race(never, instant)(), Poll(1)); } + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/seq_test.cc b/test/core/promise/seq_test.cc new file mode 100644 index 00000000..15b2fd32 --- /dev/null +++ b/test/core/promise/seq_test.cc @@ -0,0 +1,94 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "src/core/lib/promise/seq.h" + +#include + +namespace grpc_core { + +TEST(PromiseTest, Immediate) { + EXPECT_EQ(Seq([] { return 3; })(), 3); +} + +TEST(PromiseTest, OneThen) { + auto initial = [] { return 3; }; + auto then = [](int i) { return [i]() { return i + 4; }; }; + EXPECT_EQ(Seq(initial, then)(), Poll(7)); +} + +TEST(PromiseTest, TwoTypedThens) { + struct A {}; + struct B {}; + struct C {}; + auto initial = [] { return A{}; }; + auto next1 = [](A) { return []() { return B{}; }; }; + auto next2 = [](B) { return []() { return C{}; }; }; + EXPECT_FALSE(absl::holds_alternative(Seq(initial, next1, next2)())); +} + +/* This does not compile, but is useful for testing error messages generated +TEST(PromiseTest, MisTypedThen) { + struct A {}; + struct B {}; + auto initial = [] { return A{}; }; + auto next = [](B) { return []() { return B{}; }; }; + Seq(initial, next)().take(); +} +*/ + +TEST(PromiseTest, TwoThens) { + auto initial = [] { return std::string("a"); }; + auto next1 = [](std::string i) { return [i]() { return i + "b"; }; }; + auto next2 = [](std::string i) { return [i]() { return i + "c"; }; }; + EXPECT_EQ(Seq(initial, next1, next2)(), Poll("abc")); +} + +TEST(PromiseTest, ThreeThens) { + EXPECT_EQ(Seq([] { return std::string("a"); }, + [](std::string i) { return [i]() { return i + "b"; }; }, + [](std::string i) { return [i]() { return i + "c"; }; }, + [](std::string i) { return [i]() { return i + "d"; }; })(), + Poll("abcd")); +} + +struct Big { + int x[256]; + void YesItIsUnused() const {} +}; + +TEST(PromiseTest, SaneSizes) { + auto x = Big(); + auto p1 = Seq( + [x] { + x.YesItIsUnused(); + return 1; + }, + [](int) { + auto y = Big(); + return [y]() { + y.YesItIsUnused(); + return 2; + }; + }); + EXPECT_GE(sizeof(p1), sizeof(Big)); + EXPECT_LT(sizeof(p1), 2 * sizeof(Big)); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/try_join_test.cc b/test/core/promise/try_join_test.cc new file mode 100644 index 00000000..2f880ad7 --- /dev/null +++ b/test/core/promise/try_join_test.cc @@ -0,0 +1,80 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/promise/try_join.h" + +#include + +namespace grpc_core { + +template +using P = std::function>()>; + +template +P instant_ok(T x) { + return [x] { return absl::StatusOr(x); }; +} + +template +P instant_fail() { + return [] { return absl::StatusOr(); }; +} + +template +Poll>> ok(T... 
x) { + return absl::StatusOr>(absl::in_place, x...); +} + +template +Poll>> fail() { + return absl::StatusOr>(); +} + +template +P pending() { + return []() -> Poll> { return Pending(); }; +} + +TEST(TryJoinTest, Join1) { EXPECT_EQ(TryJoin(instant_ok(1))(), ok(1)); } + +TEST(TryJoinTest, Join1Fail) { + EXPECT_EQ(TryJoin(instant_fail())(), fail()); +} + +TEST(TryJoinTest, Join2Success) { + EXPECT_EQ(TryJoin(instant_ok(1), instant_ok(2))(), ok(1, 2)); +} + +TEST(TryJoinTest, Join2Fail1) { + EXPECT_EQ(TryJoin(instant_ok(1), instant_fail())(), (fail())); +} + +TEST(TryJoinTest, Join2Fail2) { + EXPECT_EQ(TryJoin(instant_fail(), instant_ok(2))(), (fail())); +} + +TEST(TryJoinTest, Join2Fail1P) { + EXPECT_EQ(TryJoin(pending(), instant_fail())(), (fail())); +} + +TEST(TryJoinTest, Join2Fail2P) { + EXPECT_EQ(TryJoin(instant_fail(), pending())(), (fail())); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/promise/try_seq_test.cc b/test/core/promise/try_seq_test.cc new file mode 100644 index 00000000..97034b52 --- /dev/null +++ b/test/core/promise/try_seq_test.cc @@ -0,0 +1,78 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
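An editorial aside, not part of the patch: the P<>/ok()/fail() helpers above encode TryJoin's contract — when every child succeeds the joined result is a ready absl::StatusOr holding the tuple of values, and any failing child makes the result ready but errored. The sketch below restates that in the same EXPECT_EQ style, assuming raw lambdas returning absl::StatusOr are adapted to promises the same way the TrySeq tests in the next file accept them.

#include <gtest/gtest.h>

#include <tuple>

#include "src/core/lib/promise/try_join.h"

namespace grpc_core {

TEST(TryJoinSketch, AllSucceedOrFirstFailure) {
  auto ok1 = [] { return absl::StatusOr<int>(1); };
  auto ok2 = [] { return absl::StatusOr<int>(2); };
  auto bad = [] { return absl::StatusOr<int>(); };  // errored, like instant_fail above
  using Joined = absl::StatusOr<std::tuple<int, int>>;
  // Both children succeed: ready and holding the tuple (1, 2).
  EXPECT_EQ(TryJoin(ok1, ok2)(), Poll<Joined>(Joined(absl::in_place, 1, 2)));
  // One child fails: ready, but the StatusOr carries the failure.
  EXPECT_EQ(TryJoin(ok1, bad)(), Poll<Joined>(Joined()));
}

}  // namespace grpc_core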
+ +#include "src/core/lib/promise/try_seq.h" + +#include + +namespace grpc_core { + +TEST(PromiseTest, SucceedAndThen) { + EXPECT_EQ(TrySeq([] { return absl::StatusOr(1); }, + [](int i) { + return [i]() { return absl::StatusOr(i + 1); }; + })(), + Poll>(absl::StatusOr(2))); +} + +TEST(PromiseTest, SucceedDirectlyAndThenDirectly) { + EXPECT_EQ( + TrySeq([] { return 1; }, [](int i) { return [i]() { return i + 1; }; })(), + Poll>(absl::StatusOr(2))); +} + +TEST(PromiseTest, SucceedAndThenChangeType) { + EXPECT_EQ( + TrySeq([] { return absl::StatusOr(42); }, + [](int i) { + return [i]() { + return absl::StatusOr(std::to_string(i)); + }; + })(), + Poll>(absl::StatusOr("42"))); +} + +TEST(PromiseTest, FailAndThen) { + EXPECT_EQ(TrySeq([]() { return absl::StatusOr(absl::CancelledError()); }, + [](int) { + return []() -> Poll> { abort(); }; + })(), + Poll>( + absl::StatusOr(absl::CancelledError()))); +} + +TEST(PromiseTest, RawSucceedAndThen) { + EXPECT_EQ(TrySeq([] { return absl::OkStatus(); }, + [] { return []() { return absl::OkStatus(); }; })(), + Poll(absl::OkStatus())); +} + +TEST(PromiseTest, RawFailAndThen) { + EXPECT_EQ(TrySeq([] { return absl::CancelledError(); }, + []() { return []() -> Poll { abort(); }; })(), + Poll(absl::CancelledError())); +} + +TEST(PromiseTest, RawSucceedAndThenValue) { + EXPECT_EQ(TrySeq([] { return absl::OkStatus(); }, + [] { return []() { return absl::StatusOr(42); }; })(), + Poll>(absl::StatusOr(42))); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/resource_quota/memory_quota_fuzzer.cc b/test/core/resource_quota/memory_quota_fuzzer.cc new file mode 100644 index 00000000..b3326f52 --- /dev/null +++ b/test/core/resource_quota/memory_quota_fuzzer.cc @@ -0,0 +1,181 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "test/core/resource_quota/call_checker.h" +#include "test/core/resource_quota/memory_quota_fuzzer.pb.h" + +bool squelch = true; +bool leak_check = true; + +namespace grpc_core { +namespace testing { +namespace { +ReclamationPass MapReclamationPass(memory_quota_fuzzer::Reclaimer::Pass pass) { + switch (pass) { + case memory_quota_fuzzer::Reclaimer::BENIGN: + return ReclamationPass::kBenign; + case memory_quota_fuzzer::Reclaimer::IDLE: + return ReclamationPass::kIdle; + case memory_quota_fuzzer::Reclaimer::DESTRUCTIVE: + return ReclamationPass::kDestructive; + default: + return ReclamationPass::kBenign; + } +} + +class Fuzzer { + public: + void Run(const memory_quota_fuzzer::Msg& msg) { + grpc_core::ExecCtx exec_ctx; + RunMsg(msg); + do { + memory_quotas_.clear(); + memory_allocators_.clear(); + allocations_.clear(); + exec_ctx.Flush(); + } while (!memory_quotas_.empty() || !memory_allocators_.empty() || + !allocations_.empty()); + } + + private: + void RunMsg(const memory_quota_fuzzer::Msg& msg) { + for (int i = 0; i < msg.actions_size(); ++i) { + const auto& action = msg.actions(i); + switch (action.action_type_case()) { + case memory_quota_fuzzer::Action::kFlushExecCtx: + ExecCtx::Get()->Flush(); + break; + case memory_quota_fuzzer::Action::kCreateQuota: + memory_quotas_.emplace(action.quota(), MemoryQuota()); + break; + case memory_quota_fuzzer::Action::kDeleteQuota: + memory_quotas_.erase(action.quota()); + break; + case memory_quota_fuzzer::Action::kCreateAllocator: + WithQuota(action.quota(), [this, action](MemoryQuota* q) { + memory_allocators_.emplace(action.allocator(), + q->CreateMemoryOwner()); + }); + break; + case memory_quota_fuzzer::Action::kDeleteAllocator: + memory_allocators_.erase(action.allocator()); + break; + case memory_quota_fuzzer::Action::kSetQuotaSize: + WithQuota(action.quota(), [action](MemoryQuota* q) { + q->SetSize(Clamp(action.set_quota_size(), uint64_t{0}, + uint64_t{std::numeric_limits::max()})); + }); + break; + case memory_quota_fuzzer::Action::kRebindQuota: + WithQuota(action.quota(), [this, action](MemoryQuota* q) { + WithAllocator(action.allocator(), + [q](MemoryOwner* a) { a->Rebind(q); }); + }); + break; + case memory_quota_fuzzer::Action::kCreateAllocation: { + auto min = action.create_allocation().min(); + auto max = action.create_allocation().max(); + if (min > max) break; + if (max > MemoryRequest::max_allowed_size()) break; + MemoryRequest req(min, max); + WithAllocator( + action.allocator(), [this, action, req](MemoryOwner* a) { + auto alloc = a->allocator()->MakeReservation(req); + allocations_.emplace(action.allocation(), std::move(alloc)); + }); + } break; + case memory_quota_fuzzer::Action::kDeleteAllocation: + allocations_.erase(action.allocation()); + break; + case memory_quota_fuzzer::Action::kPostReclaimer: { + std::function)> reclaimer; + auto cfg = action.post_reclaimer(); + if (cfg.synchronous()) { + reclaimer = [this, cfg](absl::optional) { + RunMsg(cfg.msg()); + }; + } else { + reclaimer = [cfg, this](absl::optional sweep) { + struct Args { + absl::optional sweep; + memory_quota_fuzzer::Msg msg; + Fuzzer* fuzzer; + }; + auto* args = new Args{std::move(sweep), cfg.msg(), this}; + auto* closure = GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle) { + auto* args = static_cast(arg); + args->fuzzer->RunMsg(args->msg); + delete args; 
+ }, + args, nullptr); + ExecCtx::Get()->Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); + }; + auto pass = MapReclamationPass(cfg.pass()); + WithAllocator( + action.allocator(), [pass, reclaimer](MemoryOwner* a) { + // ensure called exactly once + auto call_checker = CallChecker::Make(); + a->PostReclaimer(pass, + [reclaimer, call_checker]( + absl::optional sweep) { + call_checker->Called(); + reclaimer(std::move(sweep)); + }); + }); + } + } break; + case memory_quota_fuzzer::Action::ACTION_TYPE_NOT_SET: + break; + } + } + } + + template + void WithQuota(int quota, F f) { + auto it = memory_quotas_.find(quota); + if (it == memory_quotas_.end()) return; + f(&it->second); + } + + template + void WithAllocator(int allocator, F f) { + auto it = memory_allocators_.find(allocator); + if (it == memory_allocators_.end()) return; + f(&it->second); + } + + std::map memory_quotas_; + std::map memory_allocators_; + std::map allocations_; +}; + +} // namespace +} // namespace testing +} // namespace grpc_core + +static void dont_log(gpr_log_func_args* /*args*/) {} + +DEFINE_PROTO_FUZZER(const memory_quota_fuzzer::Msg& msg) { + if (squelch) gpr_set_log_function(dont_log); + gpr_log_verbosity_init(); + grpc_tracer_init(); + grpc_core::testing::Fuzzer().Run(msg); +} diff --git a/test/core/resource_quota/memory_quota_stress_test.cc b/test/core/resource_quota/memory_quota_stress_test.cc new file mode 100644 index 00000000..601c0662 --- /dev/null +++ b/test/core/resource_quota/memory_quota_stress_test.cc @@ -0,0 +1,213 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/resource_quota/memory_quota.h" + +namespace grpc_core { + +namespace { +class StressTest { + public: + // Create a stress test with some size. + StressTest(size_t num_quotas, size_t num_allocators) { + for (size_t i = 0; i < num_quotas; ++i) { + quotas_.emplace_back(); + } + std::random_device g; + std::uniform_int_distribution dist(0, num_quotas - 1); + for (size_t i = 0; i < num_allocators; ++i) { + allocators_.emplace_back(quotas_[dist(g)].CreateMemoryOwner()); + } + } + + // Run the thread for some period of time. + void Run(int seconds) { + std::vector threads; + + // A few threads constantly rebinding allocators to different quotas. + threads.reserve(2 + 2 + 3 * allocators_.size()); + for (int i = 0; i < 2; i++) threads.push_back(Run(Rebinder)); + // And another few threads constantly resizing quotas. + for (int i = 0; i < 2; i++) threads.push_back(Run(Resizer)); + + // For each (allocator, pass), start a thread continuously allocating from + // that allocator. Whenever the first allocation is made, schedule a + // reclaimer for that pass. 
+ for (size_t i = 0; i < allocators_.size(); i++) { + auto* allocator = &allocators_[i]; + for (ReclamationPass pass : + {ReclamationPass::kBenign, ReclamationPass::kIdle, + ReclamationPass::kDestructive}) { + threads.push_back(Run([allocator, pass](StatePtr st) mutable { + if (st->RememberReservation(allocator->allocator()->MakeReservation( + st->RandomRequest()))) { + allocator->PostReclaimer( + pass, [st](absl::optional sweep) { + if (!sweep.has_value()) return; + st->ForgetReservations(); + }); + } + })); + } + } + + // All threads started, wait for the alloted time. + std::this_thread::sleep_for(std::chrono::seconds(seconds)); + + // Toggle the completion bit, and then wait for the threads. + done_.store(true, std::memory_order_relaxed); + while (!threads.empty()) { + threads.back().join(); + threads.pop_back(); + } + } + + private: + // Per-thread state. + // Not everything is used on every thread, but it's not terrible having the + // extra state around and it does simplify things somewhat. + class State { + public: + explicit State(StressTest* test) + : test_(test), + quotas_distribution_(0, test_->quotas_.size() - 1), + allocators_distribution_(0, test_->allocators_.size() - 1), + size_distribution_(1, 4 * 1024 * 1024), + quota_size_distribution_(1024 * 1024, size_t(8) * 1024 * 1024 * 1024), + choose_variable_size_(1, 100) {} + + // Choose a random quota, and return an owned pointer to it. + // Not thread-safe, only callable from the owning thread. + MemoryQuota* RandomQuota() { + return &test_->quotas_[quotas_distribution_(g_)]; + } + + // Choose a random allocator, and return a borrowed pointer to it. + // Not thread-safe, only callable from the owning thread. + MemoryOwner* RandomAllocator() { + return &test_->allocators_[allocators_distribution_(g_)]; + } + + // Random memory request size - 1% of allocations are chosen to be variable + // sized - the rest are fixed (since variable sized create some contention + // problems between allocator threads of different passes on the same + // allocator). + // Not thread-safe, only callable from the owning thread. + MemoryRequest RandomRequest() { + size_t a = size_distribution_(g_); + if (choose_variable_size_(g_) == 1) { + size_t b = size_distribution_(g_); + return MemoryRequest(std::min(a, b), std::max(a, b)); + } + return MemoryRequest(a); + } + + // Choose a new size for a backing quota. + // Not thread-safe, only callable from the owning thread. + size_t RandomQuotaSize() { return quota_size_distribution_(g_); } + + // Remember a reservation, return true if it's the first remembered since + // the last reclamation. + // Thread-safe. + bool RememberReservation(MemoryAllocator::Reservation reservation) + ABSL_LOCKS_EXCLUDED(mu_) { + MutexLock lock(&mu_); + bool was_empty = reservations_.empty(); + reservations_.emplace_back(std::move(reservation)); + return was_empty; + } + + // Return all reservations made until this moment, so that they can be + // dropped. + std::vector ForgetReservations() + ABSL_LOCKS_EXCLUDED(mu_) { + MutexLock lock(&mu_); + return std::move(reservations_); + } + + private: + // Owning test. + StressTest* const test_; + // Random number generator. + std::mt19937 g_{std::random_device()()}; + // Distribution to choose a quota. + std::uniform_int_distribution quotas_distribution_; + // Distribution to choose an allocator. + std::uniform_int_distribution allocators_distribution_; + // Distribution to choose an allocation size. 
+ std::uniform_int_distribution size_distribution_; + // Distribution to choose a quota size. + std::uniform_int_distribution quota_size_distribution_; + // Distribution to choose whether to make a variable-sized allocation. + std::uniform_int_distribution choose_variable_size_; + + // Mutex to protect the reservation list. + Mutex mu_; + // Reservations remembered by this thread. + std::vector reservations_ + ABSL_GUARDED_BY(mu_); + }; + // Type alias since we always pass around these shared pointers. + using StatePtr = std::shared_ptr; + + // Choose one allocator, one quota, rebind the allocator to the quota. + static void Rebinder(StatePtr st) { + auto* allocator = st->RandomAllocator(); + auto* quota = st->RandomQuota(); + allocator->Rebind(quota); + } + + // Choose one allocator, resize it to a randomly chosen size. + static void Resizer(StatePtr st) { + auto* quota = st->RandomQuota(); + size_t size = st->RandomQuotaSize(); + quota->SetSize(size); + } + + // Create a thread that repeatedly runs a function until the test is done. + // We create one instance of State that we pass as a StatePtr to said + // function as the current overall state for this thread. + // Monitors done_ to see when we should stop. + // Ensures there's an ExecCtx for each iteration of the loop. + template + std::thread Run(Fn fn) { + return std::thread([this, fn]() mutable { + auto state = std::make_shared(this); + while (!done_.load(std::memory_order_relaxed)) { + ExecCtx exec_ctx; + fn(state); + } + }); + } + + // Flag for when the test is completed. + std::atomic done_{false}; + + // Memory quotas to test against. We build this up at construction time, but + // then don't resize, so we can load from it continuously from all of the + // threads. + std::vector quotas_; + // Memory allocators to test against. Similarly, built at construction time, + // and then the shape of this vector is not changed. + std::vector allocators_; +}; +} // namespace + +} // namespace grpc_core + +int main(int, char**) { grpc_core::StressTest(16, 64).Run(8); } diff --git a/test/core/resource_quota/memory_quota_test.cc b/test/core/resource_quota/memory_quota_test.cc new file mode 100644 index 00000000..c1e9985a --- /dev/null +++ b/test/core/resource_quota/memory_quota_test.cc @@ -0,0 +1,191 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
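An editorial aside, not part of the patch: stripped of the quota-specific details, StressTest::Run() above is the classic "spawn workers that loop a small step function until an atomic flag flips, then join them" pattern. The std-only reduction below keeps just that skeleton (the real test also enters a fresh grpc_core::ExecCtx on every iteration, which is omitted here).

#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

int main() {
  std::atomic<bool> done{false};
  std::atomic<long> iterations{0};

  // Worker: keep doing one small unit of work until told to stop.
  std::thread worker([&] {
    while (!done.load(std::memory_order_relaxed)) {
      iterations.fetch_add(1, std::memory_order_relaxed);  // stand-in for Rebinder/Resizer
    }
  });

  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  done.store(true, std::memory_order_relaxed);  // toggle the completion bit
  worker.join();
  std::printf("worker ran %ld iterations\n", iterations.load());
  return 0;
}

A relaxed load on the flag is enough here, as in the test above, because nothing is synchronized through the flag itself; the join() at the end is what orders the workers' side effects before the main thread reads them.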
+ +#include "src/core/lib/resource_quota/memory_quota.h" + +#include + +#include "absl/synchronization/notification.h" + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_refcount.h" +#include "test/core/resource_quota/call_checker.h" + +namespace grpc_core { +namespace testing { + +// +// Helpers +// + +template +struct Sized { + char blah[kSize]; + virtual ~Sized() {} +}; + +// +// MemoryRequestTest +// + +TEST(MemoryRequestTest, ConversionFromSize) { + MemoryRequest request = 3; + EXPECT_EQ(request.min(), 3); + EXPECT_EQ(request.max(), 3); +} + +TEST(MemoryRequestTest, MinMax) { + MemoryRequest request(3, 7); + EXPECT_EQ(request.min(), 3); + EXPECT_EQ(request.max(), 7); +} + +// +// MemoryQuotaTest +// + +TEST(MemoryQuotaTest, NoOp) { MemoryQuota(); } + +TEST(MemoryQuotaTest, CreateAllocatorNoOp) { + MemoryQuota memory_quota; + auto memory_allocator = memory_quota.CreateMemoryAllocator(); +} + +TEST(MemoryQuotaTest, CreateObjectFromAllocator) { + MemoryQuota memory_quota; + auto memory_allocator = memory_quota.CreateMemoryAllocator(); + auto object = memory_allocator.MakeUnique>(); +} + +TEST(MemoryQuotaTest, CreateSomeObjectsAndExpectReclamation) { + ExecCtx exec_ctx; + + MemoryQuota memory_quota; + memory_quota.SetSize(4096); + auto memory_allocator = memory_quota.CreateMemoryOwner(); + auto object = memory_allocator.allocator()->MakeUnique>(); + + auto checker1 = CallChecker::Make(); + memory_allocator.PostReclaimer( + ReclamationPass::kDestructive, + [&object, checker1](absl::optional sweep) { + checker1->Called(); + EXPECT_TRUE(sweep.has_value()); + object.reset(); + }); + auto object2 = memory_allocator.allocator()->MakeUnique>(); + exec_ctx.Flush(); + EXPECT_EQ(object.get(), nullptr); + + auto checker2 = CallChecker::Make(); + memory_allocator.PostReclaimer( + ReclamationPass::kDestructive, + [&object2, checker2](absl::optional sweep) { + checker2->Called(); + EXPECT_TRUE(sweep.has_value()); + object2.reset(); + }); + auto object3 = memory_allocator.allocator()->MakeUnique>(); + exec_ctx.Flush(); + EXPECT_EQ(object2.get(), nullptr); +} + +TEST(MemoryQuotaTest, BasicRebind) { + ExecCtx exec_ctx; + + MemoryQuota memory_quota; + memory_quota.SetSize(4096); + MemoryQuota memory_quota2; + memory_quota2.SetSize(4096); + + auto memory_allocator = memory_quota2.CreateMemoryOwner(); + auto object = memory_allocator.allocator()->MakeUnique>(); + + memory_allocator.Rebind(&memory_quota); + auto memory_allocator2 = memory_quota2.CreateMemoryOwner(); + + auto checker1 = CallChecker::Make(); + memory_allocator2.PostReclaimer( + ReclamationPass::kDestructive, + [checker1](absl::optional sweep) { + checker1->Called(); + // Taken memory should be reassigned to + // memory_quota, so this should be cancelled + EXPECT_FALSE(sweep.has_value()); + }); + + auto checker2 = CallChecker::Make(); + memory_allocator.PostReclaimer( + ReclamationPass::kDestructive, + [&object, checker2](absl::optional sweep) { + checker2->Called(); + EXPECT_TRUE(sweep.has_value()); + // The new memory allocator should reclaim + // the object allocated against the previous + // quota because that's now part of this + // quota. 
+ object.reset(); + }); + + auto object2 = memory_allocator.allocator()->MakeUnique>(); + exec_ctx.Flush(); + EXPECT_EQ(object.get(), nullptr); +} + +TEST(MemoryQuotaTest, ReserveRangeNoPressure) { + MemoryQuota memory_quota; + auto memory_allocator = memory_quota.CreateMemoryAllocator(); + size_t total = 0; + for (int i = 0; i < 10000; i++) { + auto n = memory_allocator.Reserve(MemoryRequest(100, 40000)); + EXPECT_EQ(n, 40000); + total += n; + } + memory_allocator.Release(total); +} + +TEST(MemoryQuotaTest, MakeSlice) { + MemoryQuota memory_quota; + auto memory_allocator = memory_quota.CreateMemoryAllocator(); + std::vector slices; + for (int i = 1; i < 1000; i++) { + int min = i; + int max = 10 * i - 9; + slices.push_back(memory_allocator.MakeSlice(MemoryRequest(min, max))); + } + for (grpc_slice slice : slices) { + grpc_slice_unref_internal(slice); + } +} + +TEST(MemoryQuotaTest, ContainerAllocator) { + MemoryQuota memory_quota; + auto memory_allocator = memory_quota.CreateMemoryAllocator(); + Vector vec(&memory_allocator); + for (int i = 0; i < 100000; i++) { + vec.push_back(i); + } +} + +} // namespace testing +} // namespace grpc_core + +// Hook needed to run ExecCtx outside of iomgr. +void grpc_set_default_iomgr_platform() {} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + gpr_log_verbosity_init(); + return RUN_ALL_TESTS(); +} diff --git a/test/core/resource_quota/resource_quota_test.cc b/test/core/resource_quota/resource_quota_test.cc new file mode 100644 index 00000000..07852a31 --- /dev/null +++ b/test/core/resource_quota/resource_quota_test.cc @@ -0,0 +1,37 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/resource_quota/resource_quota.h" + +#include + +namespace grpc_core { +namespace testing { + +TEST(ResourceQuotaTest, Works) { + auto q = MakeRefCounted(); + EXPECT_NE(q->thread_quota(), nullptr); + EXPECT_NE(q->memory_quota(), nullptr); +} + +} // namespace testing +} // namespace grpc_core + +// Hook needed to run ExecCtx outside of iomgr. +void grpc_set_default_iomgr_platform() {} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/resource_quota/thread_quota_test.cc b/test/core/resource_quota/thread_quota_test.cc new file mode 100644 index 00000000..a2014be6 --- /dev/null +++ b/test/core/resource_quota/thread_quota_test.cc @@ -0,0 +1,45 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/resource_quota/thread_quota.h" + +#include + +namespace grpc_core { +namespace testing { + +TEST(ThreadQuotaTest, Works) { + auto q = MakeRefCounted(); + EXPECT_TRUE(q->Reserve(128)); + q->SetMax(10); + EXPECT_FALSE(q->Reserve(128)); + EXPECT_FALSE(q->Reserve(1)); + q->Release(118); + EXPECT_FALSE(q->Reserve(1)); + q->Release(1); + EXPECT_TRUE(q->Reserve(1)); + EXPECT_FALSE(q->Reserve(1)); + q->Release(10); +} + +} // namespace testing +} // namespace grpc_core + +// Hook needed to run ExecCtx outside of iomgr. +void grpc_set_default_iomgr_platform() {} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/security/alts_credentials_fuzzer.cc b/test/core/security/alts_credentials_fuzzer.cc new file mode 100644 index 00000000..1f2aa605 --- /dev/null +++ b/test/core/security/alts_credentials_fuzzer.cc @@ -0,0 +1,110 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/security/credentials/alts/alts_credentials.h" +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" +#include "src/core/lib/security/credentials/alts/grpc_alts_credentials_options.h" +#include "test/core/util/fuzzer_util.h" + +using grpc_core::testing::grpc_fuzzer_get_next_byte; +using grpc_core::testing::grpc_fuzzer_get_next_string; +using grpc_core::testing::input_stream; + +// Logging +bool squelch = true; +bool leak_check = true; + +static void dont_log(gpr_log_func_args* /*args*/) {} + +// Add a random number of target service accounts to client options. +static void read_target_service_accounts( + input_stream* inp, grpc_alts_credentials_options* options) { + size_t n = grpc_fuzzer_get_next_byte(inp); + for (size_t i = 0; i < n; i++) { + char* service_account = grpc_fuzzer_get_next_string(inp, nullptr); + if (service_account != nullptr) { + grpc_alts_credentials_client_options_add_target_service_account( + options, service_account); + gpr_free(service_account); + } + } + // Added to improve code coverage. 
+ grpc_alts_credentials_client_options_add_target_service_account(options, + nullptr); + grpc_alts_credentials_client_options_add_target_service_account( + nullptr, "this is service account"); +} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + char* grpc_trace_fuzzer = gpr_getenv("GRPC_TRACE_FUZZER"); + if (squelch && grpc_trace_fuzzer == nullptr) { + gpr_set_log_function(dont_log); + } + gpr_free(grpc_trace_fuzzer); + input_stream inp = {data, data + size}; + grpc_init(); + bool is_on_gcp = grpc_alts_is_running_on_gcp(); + while (inp.cur != inp.end) { + bool enable_untrusted_alts = grpc_fuzzer_get_next_byte(&inp) & 0x01; + char* handshaker_service_url = + grpc_fuzzer_get_next_byte(&inp) & 0x01 + ? grpc_fuzzer_get_next_string(&inp, nullptr) + : nullptr; + if (grpc_fuzzer_get_next_byte(&inp) & 0x01) { + // Test ALTS channel credentials. + grpc_alts_credentials_options* options = + grpc_alts_credentials_client_options_create(); + read_target_service_accounts(&inp, options); + grpc_channel_credentials* cred = grpc_alts_credentials_create_customized( + options, handshaker_service_url, enable_untrusted_alts); + if (!enable_untrusted_alts && !is_on_gcp) { + GPR_ASSERT(cred == nullptr); + } else { + GPR_ASSERT(cred != nullptr); + } + grpc_channel_credentials_release(cred); + grpc_alts_credentials_options_destroy(options); + } else { + // Test ALTS server credentials. + grpc_alts_credentials_options* options = + grpc_alts_credentials_server_options_create(); + grpc_server_credentials* cred = + grpc_alts_server_credentials_create_customized( + options, handshaker_service_url, enable_untrusted_alts); + if (!enable_untrusted_alts && !is_on_gcp) { + GPR_ASSERT(cred == nullptr); + } else { + GPR_ASSERT(cred != nullptr); + } + grpc_server_credentials_release(cred); + grpc_alts_credentials_options_destroy(options); + } + gpr_free(handshaker_service_url); + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/alts_security_connector_test.cc b/test/core/security/alts_security_connector_test.cc new file mode 100644 index 00000000..becd4e63 --- /dev/null +++ b/test/core/security/alts_security_connector_test.cc @@ -0,0 +1,206 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/security/security_connector/alts/alts_security_connector.h" + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/transport/transport.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/core/tsi/transport_security.h" + +using grpc_core::internal::grpc_alts_auth_context_from_tsi_peer; + +/* This file contains unit tests of grpc_alts_auth_context_from_tsi_peer(). 
*/ +static void test_invalid_input_failure() { + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(nullptr); + GPR_ASSERT(ctx == nullptr); +} + +static void test_empty_certificate_type_failure() { + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(0, &peer) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); + GPR_ASSERT(ctx == nullptr); + tsi_peer_destruct(&peer); +} + +static void test_empty_peer_property_failure() { + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); + GPR_ASSERT(ctx == nullptr); + tsi_peer_destruct(&peer); +} + +static void test_missing_rpc_protocol_versions_property_failure() { + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, "alice", + &peer.properties[1]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); + GPR_ASSERT(ctx == nullptr); + tsi_peer_destruct(&peer); +} + +static void test_missing_security_level_property_failure() { + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, "alice", + &peer.properties[1]) == TSI_OK); + grpc_gcp_rpc_protocol_versions peer_versions; + grpc_gcp_rpc_protocol_versions_set_max(&peer_versions, + GRPC_PROTOCOL_VERSION_MAX_MAJOR, + GRPC_PROTOCOL_VERSION_MAX_MINOR); + grpc_gcp_rpc_protocol_versions_set_min(&peer_versions, + GRPC_PROTOCOL_VERSION_MIN_MAJOR, + GRPC_PROTOCOL_VERSION_MIN_MINOR); + grpc_slice serialized_peer_versions; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode(&peer_versions, + &serialized_peer_versions)); + + GPR_ASSERT(tsi_construct_string_peer_property( + TSI_ALTS_RPC_VERSIONS, + reinterpret_cast( + GRPC_SLICE_START_PTR(serialized_peer_versions)), + GRPC_SLICE_LENGTH(serialized_peer_versions), + &peer.properties[2]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); + GPR_ASSERT(ctx == nullptr); + grpc_slice_unref(serialized_peer_versions); + tsi_peer_destruct(&peer); +} + +static void test_unknown_peer_property_failure() { + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + "unknown", "alice", &peer.properties[1]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); + GPR_ASSERT(ctx == nullptr); + tsi_peer_destruct(&peer); +} + +static bool test_identity(const grpc_auth_context* ctx, + const char* expected_property_name, + const char* expected_identity) { + grpc_auth_property_iterator it; + const 
grpc_auth_property* prop; + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx)); + it = grpc_auth_context_peer_identity(ctx); + prop = grpc_auth_property_iterator_next(&it); + GPR_ASSERT(prop != nullptr); + if (strcmp(prop->name, expected_property_name) != 0) { + gpr_log(GPR_ERROR, "Expected peer identity property name %s and got %s.", + expected_property_name, prop->name); + return false; + } + if (strncmp(prop->value, expected_identity, prop->value_length) != 0) { + gpr_log(GPR_ERROR, "Expected peer identity %s and got got %s.", + expected_identity, prop->value); + return false; + } + return true; +} + +static void test_alts_peer_to_auth_context_success() { + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, "alice", + &peer.properties[1]) == TSI_OK); + grpc_gcp_rpc_protocol_versions peer_versions; + grpc_gcp_rpc_protocol_versions_set_max(&peer_versions, + GRPC_PROTOCOL_VERSION_MAX_MAJOR, + GRPC_PROTOCOL_VERSION_MAX_MINOR); + grpc_gcp_rpc_protocol_versions_set_min(&peer_versions, + GRPC_PROTOCOL_VERSION_MIN_MAJOR, + GRPC_PROTOCOL_VERSION_MIN_MINOR); + grpc_slice serialized_peer_versions; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode(&peer_versions, + &serialized_peer_versions)); + GPR_ASSERT(tsi_construct_string_peer_property( + TSI_ALTS_RPC_VERSIONS, + reinterpret_cast( + GRPC_SLICE_START_PTR(serialized_peer_versions)), + GRPC_SLICE_LENGTH(serialized_peer_versions), + &peer.properties[2]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer.properties[3]) == TSI_OK); + char test_ctx[] = "test serialized context"; + grpc_slice serialized_alts_ctx = grpc_slice_from_copied_string(test_ctx); + GPR_ASSERT( + tsi_construct_string_peer_property( + TSI_ALTS_CONTEXT, + reinterpret_cast(GRPC_SLICE_START_PTR(serialized_alts_ctx)), + GRPC_SLICE_LENGTH(serialized_alts_ctx), + &peer.properties[4]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(test_identity(ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, + "alice")); + ctx.reset(DEBUG_LOCATION, "test"); + grpc_slice_unref(serialized_peer_versions); + grpc_slice_unref(serialized_alts_ctx); + tsi_peer_destruct(&peer); +} + +int main(int /*argc*/, char** /*argv*/) { + /* Test. */ + test_invalid_input_failure(); + test_empty_certificate_type_failure(); + test_empty_peer_property_failure(); + test_unknown_peer_property_failure(); + test_missing_rpc_protocol_versions_property_failure(); + test_missing_security_level_property_failure(); + test_alts_peer_to_auth_context_success(); + + return 0; +} diff --git a/test/core/security/auth_context_test.cc b/test/core/security/auth_context_test.cc new file mode 100644 index 00000000..142cc60f --- /dev/null +++ b/test/core/security/auth_context_test.cc @@ -0,0 +1,146 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" +#include "test/core/util/test_config.h" + +static void test_empty_context(void) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + grpc_auth_property_iterator it; + + gpr_log(GPR_INFO, "test_empty_context"); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx.get()) == + nullptr); + it = grpc_auth_context_peer_identity(ctx.get()); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + it = grpc_auth_context_property_iterator(ctx.get()); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo"); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + GPR_ASSERT( + grpc_auth_context_set_peer_identity_property_name(ctx.get(), "bar") == 0); + GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx.get()) == + nullptr); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_simple_context(void) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + grpc_auth_property_iterator it; + size_t i; + + gpr_log(GPR_INFO, "test_simple_context"); + GPR_ASSERT(ctx != nullptr); + grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapi"); + grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapo"); + grpc_auth_context_add_cstring_property(ctx.get(), "foo", "bar"); + GPR_ASSERT(ctx->properties().count == 3); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx.get(), + "name") == 1); + + GPR_ASSERT(strcmp(grpc_auth_context_peer_identity_property_name(ctx.get()), + "name") == 0); + it = grpc_auth_context_property_iterator(ctx.get()); + for (i = 0; i < ctx->properties().count; i++) { + const grpc_auth_property* p = grpc_auth_property_iterator_next(&it); + GPR_ASSERT(p == &ctx->properties().array[i]); + } + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + + it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo"); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &ctx->properties().array[2]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + + it = grpc_auth_context_peer_identity(ctx.get()); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &ctx->properties().array[0]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &ctx->properties().array[1]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_chained_context(void) { + grpc_core::RefCountedPtr chained = + grpc_core::MakeRefCounted(nullptr); + grpc_auth_context* chained_ptr = chained.get(); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(std::move(chained)); + + grpc_auth_property_iterator it; + size_t i; + + gpr_log(GPR_INFO, "test_chained_context"); + grpc_auth_context_add_cstring_property(chained_ptr, "name", "padapo"); + grpc_auth_context_add_cstring_property(chained_ptr, "foo", "baz"); + 
grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapi"); + grpc_auth_context_add_cstring_property(ctx.get(), "name", "chap0"); + grpc_auth_context_add_cstring_property(ctx.get(), "foo", "bar"); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx.get(), + "name") == 1); + + GPR_ASSERT(strcmp(grpc_auth_context_peer_identity_property_name(ctx.get()), + "name") == 0); + it = grpc_auth_context_property_iterator(ctx.get()); + for (i = 0; i < ctx->properties().count; i++) { + const grpc_auth_property* p = grpc_auth_property_iterator_next(&it); + GPR_ASSERT(p == &ctx->properties().array[i]); + } + for (i = 0; i < chained_ptr->properties().count; i++) { + const grpc_auth_property* p = grpc_auth_property_iterator_next(&it); + GPR_ASSERT(p == &chained_ptr->properties().array[i]); + } + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + + it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo"); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &ctx->properties().array[2]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &chained_ptr->properties().array[1]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + + it = grpc_auth_context_peer_identity(ctx.get()); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &ctx->properties().array[0]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &ctx->properties().array[1]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == + &chained_ptr->properties().array[0]); + GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); + + ctx.reset(DEBUG_LOCATION, "test"); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_empty_context(); + test_simple_context(); + test_chained_context(); + return 0; +} diff --git a/test/core/security/authorization_matchers_test.cc b/test/core/security/authorization_matchers_test.cc new file mode 100644 index 00000000..6f4a3efc --- /dev/null +++ b/test/core/security/authorization_matchers_test.cc @@ -0,0 +1,457 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
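An editorial aside, not part of the patch: the auth_context tests above all walk properties through the public iterator API. The helper below packages that same walk for reference; it uses only calls that already appear in the tests (grpc_auth_context_peer_is_authenticated, grpc_auth_context_peer_identity, grpc_auth_property_iterator_next) and assumes the caller supplies an already-populated context.

#include <grpc/grpc_security.h>
#include <stdio.h>

// Print every property that makes up the peer identity of `ctx`.
static void print_peer_identity(const grpc_auth_context* ctx) {
  if (!grpc_auth_context_peer_is_authenticated(ctx)) {
    printf("peer is not authenticated\n");
    return;
  }
  grpc_auth_property_iterator it = grpc_auth_context_peer_identity(ctx);
  const grpc_auth_property* prop;
  while ((prop = grpc_auth_property_iterator_next(&it)) != nullptr) {
    // Property values carry an explicit length and need not be NUL-terminated.
    printf("%s = %.*s\n", prop->name, (int)prop->value_length, prop->value);
  }
}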
+ +#include + +#include + +#include +#include + +#include + +#include "src/core/lib/security/authorization/evaluate_args.h" +#include "src/core/lib/security/authorization/matchers.h" +#include "test/core/util/evaluate_args_test_util.h" + +namespace grpc_core { + +class AuthorizationMatchersTest : public ::testing::Test { + protected: + EvaluateArgsTestUtil args_; +}; + +TEST_F(AuthorizationMatchersTest, AlwaysAuthorizationMatcher) { + EvaluateArgs args = args_.MakeEvaluateArgs(); + AlwaysAuthorizationMatcher matcher; + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, AndAuthorizationMatcherSuccessfulMatch) { + args_.AddPairToMetadata("foo", "bar"); + args_.SetLocalEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> rules; + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"foo", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value())); + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kDestPort, /*port=*/123)); + auto matcher = AuthorizationMatcher::Create( + Rbac::Permission(Rbac::Permission::RuleType::kAnd, std::move(rules))); + EXPECT_TRUE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, AndAuthorizationMatcherFailedMatch) { + args_.AddPairToMetadata("foo", "not_bar"); + args_.SetLocalEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> rules; + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"foo", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value())); + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kDestPort, /*port=*/123)); + auto matcher = AuthorizationMatcher::Create( + Rbac::Permission(Rbac::Permission::RuleType::kAnd, std::move(rules))); + // Header rule fails. Expected value "bar", got "not_bar" for key "foo". + EXPECT_FALSE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, OrAuthorizationMatcherSuccessfulMatch) { + args_.AddPairToMetadata("foo", "bar"); + args_.SetLocalEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> rules; + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"foo", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value())); + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kDestPort, /*port=*/456)); + auto matcher = AuthorizationMatcher::Create( + Rbac::Permission(Rbac::Permission::RuleType::kOr, std::move(rules))); + // Matches as header rule matches even though port rule fails. + EXPECT_TRUE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, OrAuthorizationMatcherFailedMatch) { + args_.AddPairToMetadata("foo", "not_bar"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> rules; + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"foo", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value())); + auto matcher = AuthorizationMatcher::Create( + Rbac::Permission(Rbac::Permission::RuleType::kOr, std::move(rules))); + // Header rule fails. Expected value "bar", got "not_bar" for key "foo". 
+ EXPECT_FALSE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, NotAuthorizationMatcherSuccessfulMatch) { + args_.AddPairToMetadata(":path", "/different/foo"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + auto matcher = AuthorizationMatcher::Create(Rbac::Principal( + Rbac::Principal::RuleType::kNot, + Rbac::Principal(Rbac::Principal::RuleType::kPath, + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"/expected/foo", + /*case_sensitive=*/false) + .value()))); + EXPECT_TRUE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, NotAuthorizationMatcherFailedMatch) { + args_.AddPairToMetadata(":path", "/expected/foo"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + auto matcher = AuthorizationMatcher::Create(Rbac::Principal( + Rbac::Principal::RuleType::kNot, + Rbac::Principal(Rbac::Principal::RuleType::kPath, + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"/expected/foo", + /*case_sensitive=*/false) + .value()))); + EXPECT_FALSE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, HybridAuthorizationMatcherSuccessfulMatch) { + args_.AddPairToMetadata("foo", "bar"); + args_.SetLocalEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> sub_and_rules; + sub_and_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"foo", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value())); + std::vector> sub_or_rules; + sub_or_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kDestPort, /*port=*/123)); + std::vector> and_rules; + and_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kAnd, std::move(sub_and_rules))); + and_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kOr, std::move(std::move(sub_or_rules)))); + auto matcher = AuthorizationMatcher::Create( + Rbac::Permission(Rbac::Permission::RuleType::kAnd, std::move(and_rules))); + EXPECT_TRUE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, HybridAuthorizationMatcherFailedMatch) { + args_.AddPairToMetadata("foo", "bar"); + args_.SetLocalEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> sub_and_rules; + sub_and_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"foo", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value())); + sub_and_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"absent_key", HeaderMatcher::Type::kExact, + /*matcher=*/"some_value") + .value())); + std::vector> sub_or_rules; + sub_or_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kDestPort, /*port=*/123)); + std::vector> and_rules; + and_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kAnd, std::move(sub_and_rules))); + and_rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kOr, std::move(std::move(sub_or_rules)))); + auto matcher = AuthorizationMatcher::Create( + Rbac::Permission(Rbac::Permission::RuleType::kAnd, std::move(and_rules))); + // Fails as "absent_key" header was not present. 
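+  // The And matcher requires every sub-rule to match, so the one missing
+  // header fails the whole permission.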
+ EXPECT_FALSE(matcher->Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, PathAuthorizationMatcherSuccessfulMatch) { + args_.AddPairToMetadata(":path", "expected/path"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + PathAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"expected/path", + /*case_sensitive=*/false) + .value()); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, PathAuthorizationMatcherFailedMatch) { + args_.AddPairToMetadata(":path", "different/path"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + PathAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"expected/path", + /*case_sensitive=*/false) + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, + PathAuthorizationMatcherFailedMatchMissingPath) { + EvaluateArgs args = args_.MakeEvaluateArgs(); + PathAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"expected/path", + /*case_sensitive=*/false) + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, HeaderAuthorizationMatcherSuccessfulMatch) { + args_.AddPairToMetadata("key123", "foo_xxx"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + HeaderAuthorizationMatcher matcher( + HeaderMatcher::Create(/*name=*/"key123", HeaderMatcher::Type::kPrefix, + /*matcher=*/"foo") + .value()); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, HeaderAuthorizationMatcherFailedMatch) { + args_.AddPairToMetadata("key123", "foo"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + HeaderAuthorizationMatcher matcher( + HeaderMatcher::Create(/*name=*/"key123", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, + HeaderAuthorizationMatcherFailedMatchMultivaluedHeader) { + args_.AddPairToMetadata("key123", "foo"); + args_.AddPairToMetadata("key123", "bar"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + HeaderAuthorizationMatcher matcher( + HeaderMatcher::Create(/*name=*/"key123", HeaderMatcher::Type::kExact, + /*matcher=*/"foo") + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, + HeaderAuthorizationMatcherFailedMatchMissingHeader) { + EvaluateArgs args = args_.MakeEvaluateArgs(); + HeaderAuthorizationMatcher matcher( + HeaderMatcher::Create(/*name=*/"key123", HeaderMatcher::Type::kSuffix, + /*matcher=*/"foo") + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, + IpAuthorizationMatcherLocalIpSuccessfulMatch) { + args_.SetLocalEndpoint("ipv4:1.2.3.4:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + IpAuthorizationMatcher matcher( + IpAuthorizationMatcher::Type::kDestIp, + Rbac::CidrRange(/*address_prefix=*/"1.7.8.9", /*prefix_len=*/8)); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, IpAuthorizationMatcherLocalIpFailedMatch) { + args_.SetLocalEndpoint("ipv4:1.2.3.4:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + IpAuthorizationMatcher matcher( + IpAuthorizationMatcher::Type::kDestIp, + Rbac::CidrRange(/*address_prefix=*/"1.2.3.9", /*prefix_len=*/32)); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, IpAuthorizationMatcherPeerIpSuccessfulMatch) { + args_.SetPeerEndpoint("ipv6:[1:2:3::]:456"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + IpAuthorizationMatcher matcher( + 
IpAuthorizationMatcher::Type::kSourceIp, + Rbac::CidrRange(/*address_prefix=*/"1:2:4::", /*prefix_len=*/32)); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, IpAuthorizationMatcherPeerIpFailedMatch) { + args_.SetPeerEndpoint("ipv6:[1:2::]:456"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + IpAuthorizationMatcher matcher( + IpAuthorizationMatcher::Type::kSourceIp, + Rbac::CidrRange(/*address_prefix=*/"1:3::", /*prefix_len=*/32)); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, + IpAuthorizationMatcherUnsupportedIpFailedMatch) { + EvaluateArgs args = args_.MakeEvaluateArgs(); + IpAuthorizationMatcher matcher(IpAuthorizationMatcher::Type::kRemoteIp, {}); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, PortAuthorizationMatcherSuccessfulMatch) { + args_.SetLocalEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + PortAuthorizationMatcher matcher(/*port=*/123); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, PortAuthorizationMatcherFailedMatch) { + args_.SetLocalEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + PortAuthorizationMatcher matcher(/*port=*/456); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, + AuthenticatedMatcherUnAuthenticatedConnection) { + EvaluateArgs args = args_.MakeEvaluateArgs(); + AuthenticatedAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"foo.com", + /*case_sensitive=*/false) + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, + AuthenticatedMatcherAuthenticatedConnectionMatcherUnset) { + args_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_SSL_TRANSPORT_SECURITY_TYPE); + EvaluateArgs args = args_.MakeEvaluateArgs(); + AuthenticatedAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"", + /*case_sensitive=*/false) + .value()); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, AuthenticatedMatcherSuccessfulUriSanMatches) { + args_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_TLS_TRANSPORT_SECURITY_TYPE); + args_.AddPropertyToAuthContext(GRPC_PEER_URI_PROPERTY_NAME, + "spiffe://foo.abc"); + args_.AddPropertyToAuthContext(GRPC_PEER_URI_PROPERTY_NAME, + "https://foo.domain.com"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + AuthenticatedAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"spiffe://foo.abc", + /*case_sensitive=*/false) + .value()); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, AuthenticatedMatcherFailedUriSanMatches) { + args_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_TLS_TRANSPORT_SECURITY_TYPE); + args_.AddPropertyToAuthContext(GRPC_PEER_URI_PROPERTY_NAME, + "spiffe://bar.abc"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + AuthenticatedAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"spiffe://foo.abc", + /*case_sensitive=*/false) + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, AuthenticatedMatcherSuccessfulDnsSanMatches) { + args_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_SSL_TRANSPORT_SECURITY_TYPE); + args_.AddPropertyToAuthContext(GRPC_PEER_URI_PROPERTY_NAME, + 
"spiffe://bar.abc"); + args_.AddPropertyToAuthContext(GRPC_PEER_DNS_PROPERTY_NAME, + "foo.test.domain.com"); + args_.AddPropertyToAuthContext(GRPC_PEER_DNS_PROPERTY_NAME, + "bar.test.domain.com"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + // No match found in URI SANs, finds match in DNS SANs. + AuthenticatedAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"bar.test.domain.com", + /*case_sensitive=*/false) + .value()); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, AuthenticatedMatcherFailedDnsSanMatches) { + args_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_SSL_TRANSPORT_SECURITY_TYPE); + args_.AddPropertyToAuthContext(GRPC_PEER_DNS_PROPERTY_NAME, + "foo.test.domain.com"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + AuthenticatedAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"bar.test.domain.com", + /*case_sensitive=*/false) + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, AuthenticatedMatcherFailedNothingMatches) { + args_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + GRPC_SSL_TRANSPORT_SECURITY_TYPE); + EvaluateArgs args = args_.MakeEvaluateArgs(); + AuthenticatedAuthorizationMatcher matcher( + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"foo", + /*case_sensitive=*/false) + .value()); + EXPECT_FALSE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, PolicyAuthorizationMatcherSuccessfulMatch) { + args_.AddPairToMetadata("key123", "foo"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> rules; + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"key123", HeaderMatcher::Type::kExact, + /*matcher=*/"foo") + .value())); + PolicyAuthorizationMatcher matcher(Rbac::Policy( + Rbac::Permission(Rbac::Permission::RuleType::kOr, std::move(rules)), + Rbac::Principal(Rbac::Principal::RuleType::kAny))); + EXPECT_TRUE(matcher.Matches(args)); +} + +TEST_F(AuthorizationMatchersTest, PolicyAuthorizationMatcherFailedMatch) { + args_.AddPairToMetadata("key123", "foo"); + EvaluateArgs args = args_.MakeEvaluateArgs(); + std::vector> rules; + rules.push_back(absl::make_unique( + Rbac::Permission::RuleType::kHeader, + HeaderMatcher::Create(/*name=*/"key123", HeaderMatcher::Type::kExact, + /*matcher=*/"bar") + .value())); + PolicyAuthorizationMatcher matcher(Rbac::Policy( + Rbac::Permission(Rbac::Permission::RuleType::kOr, std::move(rules)), + Rbac::Principal(Rbac::Principal::RuleType::kAny))); + EXPECT_FALSE(matcher.Matches(args)); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/security/aws_request_signer_test.cc b/test/core/security/aws_request_signer_test.cc new file mode 100644 index 00000000..699f32f1 --- /dev/null +++ b/test/core/security/aws_request_signer_test.cc @@ -0,0 +1,279 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "src/core/lib/security/credentials/external/aws_request_signer.h" + +#include + +#include + +#include "test/core/util/test_config.h" + +namespace testing { + +namespace { +// Test cases of Aws endpoints that the aws-sourced credentials will depend +// on. +const char* kAmzTestAccessKeyId = "ASIARD4OQDT6A77FR3CL"; +const char* kAmzTestSecretAccessKey = + "Y8AfSaucF37G4PpvfguKZ3/l7Id4uocLXxX0+VTx"; +const char* kAmzTestToken = + "IQoJb3JpZ2luX2VjEIz//////////wEaCXVzLWVhc3QtMiJGMEQCIH7MHX/Oy/" + "OB8OlLQa9GrqU1B914+iMikqWQW7vPCKlgAiA/" + "Lsv8Jcafn14owfxXn95FURZNKaaphj0ykpmS+Ki+" + "CSq0AwhlEAAaDDA3NzA3MTM5MTk5NiIMx9sAeP1ovlMTMKLjKpEDwuJQg41/" + "QUKx0laTZYjPlQvjwSqS3OB9P1KAXPWSLkliVMMqaHqelvMF/WO/" + "glv3KwuTfQsavRNs3v5pcSEm4SPO3l7mCs7KrQUHwGP0neZhIKxEXy+Ls//1C/" + "Bqt53NL+LSbaGv6RPHaX82laz2qElphg95aVLdYgIFY6JWV5fzyjgnhz0DQmy62/" + "Vi8pNcM2/" + "VnxeCQ8CC8dRDSt52ry2v+nc77vstuI9xV5k8mPtnaPoJDRANh0bjwY5Sdwkbp+" + "mGRUJBAQRlNgHUJusefXQgVKBCiyJY4w3Csd8Bgj9IyDV+" + "Azuy1jQqfFZWgP68LSz5bURyIjlWDQunO82stZ0BgplKKAa/" + "KJHBPCp8Qi6i99uy7qh76FQAqgVTsnDuU6fGpHDcsDSGoCls2HgZjZFPeOj8mmRhFk1Xqvkb" + "juz8V1cJk54d3gIJvQt8gD2D6yJQZecnuGWd5K2e2HohvCc8Fc9kBl1300nUJPV+k4tr/" + "A5R/0QfEKOZL1/" + "k5lf1g9CREnrM8LVkGxCgdYMxLQow1uTL+QU67AHRRSp5PhhGX4Rek+" + "01vdYSnJCMaPhSEgcLqDlQkhk6MPsyT91QMXcWmyO+cAZwUPwnRamFepuP4K8k2KVXs/" + "LIJHLELwAZ0ekyaS7CptgOqS7uaSTFG3U+vzFZLEnGvWQ7y9IPNQZ+" + "Dffgh4p3vF4J68y9049sI6Sr5d5wbKkcbm8hdCDHZcv4lnqohquPirLiFQ3q7B17V9krMPu3" + "mz1cg4Ekgcrn/" + "E09NTsxAqD8NcZ7C7ECom9r+" + "X3zkDOxaajW6hu3Az8hGlyylDaMiFfRbBJpTIlxp7jfa7CxikNgNtEKLH9iCzvuSg2vhA=="; +const char* kAmzTestDate = "20200811T065522Z"; + +// Test cases derived from the Aws signature v4 test suite. +// https://github.com/boto/botocore/tree/master/tests/unit/auth/aws4_testsuite +const char* kBotoTestAccessKeyId = "AKIDEXAMPLE"; +const char* kBotoTestSecretAccessKey = + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; +const char* kBotoTestToken = ""; +const char* kBotoTestDate = "Mon, 09 Sep 2011 23:36:00 GMT"; +} // namespace + +// AWS official example from the developer doc. 
+// https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html +TEST(GrpcAwsRequestSignerTest, AWSOfficialExample) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + "AKIDEXAMPLE", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "", "GET", + "https://iam.amazonaws.com/?Action=ListUsers&Version=2010-05-08", + "us-east-1", "", + {{"content-type", "application/x-www-form-urlencoded; charset=utf-8"}, + {"x-amz-date", "20150830T123600Z"}}, + &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20150830/us-east-1/iam/aws4_request, " + "SignedHeaders=content-type;host;x-amz-date, " + "Signature=" + "5d672d79c15b13162d9279b0855cfba6789a8edb4c82c400e06b5924a6f2b5d7"); +} + +TEST(GrpcAwsRequestSignerTest, GetDescribeRegions) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kAmzTestAccessKeyId, kAmzTestSecretAccessKey, kAmzTestToken, "GET", + "https://" + "ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", + "us-east-2", "", {{"x-amz-date", kAmzTestDate}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ( + signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=ASIARD4OQDT6A77FR3CL/20200811/us-east-2/ec2/aws4_request, " + "SignedHeaders=host;x-amz-date;x-amz-security-token, " + "Signature=" + "631ea80cddfaa545fdadb120dc92c9f18166e38a5c47b50fab9fce476e022855"); +} + +TEST(GrpcAwsRequestSignerTest, PostGetCallerIdentity) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kAmzTestAccessKeyId, kAmzTestSecretAccessKey, kAmzTestToken, "POST", + "https://" + "sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", "", {{"x-amz-date", kAmzTestDate}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ( + signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=ASIARD4OQDT6A77FR3CL/20200811/us-east-2/sts/aws4_request, " + "SignedHeaders=host;x-amz-date;x-amz-security-token, " + "Signature=" + "73452984e4a880ffdc5c392355733ec3f5ba310d5e0609a89244440cadfe7a7a"); +} + +TEST(GrpcAwsRequestSignerTest, PostGetCallerIdentityNoToken) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kAmzTestAccessKeyId, kAmzTestSecretAccessKey, "", "POST", + "https://" + "sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", "", {{"x-amz-date", kAmzTestDate}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ( + signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=ASIARD4OQDT6A77FR3CL/20200811/us-east-2/sts/aws4_request, " + "SignedHeaders=host;x-amz-date, " + "Signature=" + "d095ba304919cd0d5570ba8a3787884ee78b860f268ed040ba23831d55536d56"); +} + +TEST(GrpcAwsRequestSignerTest, GetHost) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer(kBotoTestAccessKeyId, + kBotoTestSecretAccessKey, kBotoTestToken, + "GET", "https://host.foo.com", "us-east-1", + "", {{"date", kBotoTestDate}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, " + "SignedHeaders=date;host, " + "Signature=" + "b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470"); +} + +TEST(GrpcAwsRequestSignerTest, 
GetHostDuplicateQueryParam) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kBotoTestAccessKeyId, kBotoTestSecretAccessKey, kBotoTestToken, "GET", + "https://host.foo.com/?foo=Zoo&foo=aha", "us-east-1", "", + {{"date", kBotoTestDate}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, " + "SignedHeaders=date;host, " + "Signature=" + "be7148d34ebccdc6423b19085378aa0bee970bdc61d144bd1a8c48c33079ab09"); +} + +TEST(GrpcAwsRequestSignerTest, PostWithUpperCaseHeaderKey) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kBotoTestAccessKeyId, kBotoTestSecretAccessKey, kBotoTestToken, "POST", + "https://host.foo.com/", "us-east-1", "", + {{"date", kBotoTestDate}, {"ZOO", "zoobar"}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, " + "SignedHeaders=date;host;zoo, " + "Signature=" + "b7a95a52518abbca0964a999a880429ab734f35ebbf1235bd79a5de87756dc4a"); +} + +TEST(GrpcAwsRequestSignerTest, PostWithUpperCaseHeaderValue) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kBotoTestAccessKeyId, kBotoTestSecretAccessKey, kBotoTestToken, "POST", + "https://host.foo.com/", "us-east-1", "", + {{"date", kBotoTestDate}, {"zoo", "ZOOBAR"}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, " + "SignedHeaders=date;host;zoo, " + "Signature=" + "273313af9d0c265c531e11db70bbd653f3ba074c1009239e8559d3987039cad7"); +} + +TEST(GrpcAwsRequestSignerTest, SignPostWithHeader) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kBotoTestAccessKeyId, kBotoTestSecretAccessKey, kBotoTestToken, "POST", + "https://host.foo.com/", "us-east-1", "", + {{"date", kBotoTestDate}, {"p", "phfft"}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, " + "SignedHeaders=date;host;p, " + "Signature=" + "debf546796015d6f6ded8626f5ce98597c33b47b9164cf6b17b4642036fcb592"); +} + +TEST(GrpcAwsRequestSignerTest, PostWithBodyNoCustomHeaders) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kBotoTestAccessKeyId, kBotoTestSecretAccessKey, kBotoTestToken, "POST", + "https://host.foo.com/", "us-east-1", "foo=bar", + {{"date", kBotoTestDate}, + {"Content-Type", "application/x-www-form-urlencoded"}}, + &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, " + "SignedHeaders=content-type;date;host, " + "Signature=" + "5a15b22cf462f047318703b92e6f4f38884e4a7ab7b1d6426ca46a8bd1c26cbc"); +} + +TEST(GrpcAwsRequestSignerTest, SignPostWithQueryString) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + kBotoTestAccessKeyId, kBotoTestSecretAccessKey, kBotoTestToken, "POST", + "https://host.foo.com/?foo=bar", "us-east-1", "", + {{"date", kBotoTestDate}}, &error); + EXPECT_EQ(error, GRPC_ERROR_NONE); + 
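+  // The query parameters are part of the canonical request, so only the
+  // date and host headers appear in SignedHeaders below.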
EXPECT_EQ(signer.GetSignedRequestHeaders()["Authorization"], + "AWS4-HMAC-SHA256 " + "Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, " + "SignedHeaders=date;host, " + "Signature=" + "b6e3b79003ce0743a491606ba1035a804593b0efb1e20a11cba83f8c25a57a92"); +} + +TEST(GrpcAwsRequestSignerTest, InvalidUrl) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer("access_key_id", "secret_access_key", + "token", "POST", "invalid_url", + "us-east-1", "", {}, &error); + std::string actual_error_description; + GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, + &actual_error_description)); + EXPECT_EQ(actual_error_description, "Invalid Aws request url."); + GRPC_ERROR_UNREF(error); +} + +TEST(GrpcAwsRequestSignerTest, DuplicateRequestDate) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::AwsRequestSigner signer( + "access_key_id", "secret_access_key", "token", "POST", "invalid_url", + "us-east-1", "", {{"date", kBotoTestDate}, {"x-amz-date", kAmzTestDate}}, + &error); + std::string actual_error_description; + GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, + &actual_error_description)); + EXPECT_EQ(actual_error_description, + "Only one of {date, x-amz-date} can be specified, not both."); + GRPC_ERROR_UNREF(error); +} + +} // namespace testing + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/security/cel_authorization_engine_test.cc b/test/core/security/cel_authorization_engine_test.cc new file mode 100644 index 00000000..4c94a87d --- /dev/null +++ b/test/core/security/cel_authorization_engine_test.cc @@ -0,0 +1,80 @@ +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
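+//
+// These tests only exercise engine construction:
+// CelAuthorizationEngine::CreateCelAuthorizationEngine() accepts either a
+// single policy or a deny policy followed by an allow policy, and returns
+// nullptr for an empty list, for more than two policies, or for the
+// allow/deny pair in the wrong order. The expected usage, with the types
+// declared in the fixture below:
+//
+//   std::vector<envoy_config_rbac_v3_RBAC*> policies{deny_policy_,
+//                                                    allow_policy_};
+//   auto engine =
+//       CelAuthorizationEngine::CreateCelAuthorizationEngine(policies);
+//   // engine != nullptr on success.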
+ +#include "src/core/lib/security/authorization/cel_authorization_engine.h" + +#include + +namespace grpc_core { + +class CelAuthorizationEngineTest : public ::testing::Test { + protected: + void SetUp() override { + deny_policy_ = envoy_config_rbac_v3_RBAC_new(arena_.ptr()); + envoy_config_rbac_v3_RBAC_set_action(deny_policy_, 1); + allow_policy_ = envoy_config_rbac_v3_RBAC_new(arena_.ptr()); + envoy_config_rbac_v3_RBAC_set_action(allow_policy_, 0); + } + upb::Arena arena_; + envoy_config_rbac_v3_RBAC* deny_policy_; + envoy_config_rbac_v3_RBAC* allow_policy_; +}; + +TEST_F(CelAuthorizationEngineTest, CreateEngineSuccessOnePolicy) { + std::vector policies{allow_policy_}; + std::unique_ptr engine = + CelAuthorizationEngine::CreateCelAuthorizationEngine(policies); + EXPECT_NE(engine, nullptr) + << "Error: Failed to create CelAuthorizationEngine with one policy."; +} + +TEST_F(CelAuthorizationEngineTest, CreateEngineSuccessTwoPolicies) { + std::vector policies{deny_policy_, allow_policy_}; + std::unique_ptr engine = + CelAuthorizationEngine::CreateCelAuthorizationEngine(policies); + EXPECT_NE(engine, nullptr) + << "Error: Failed to create CelAuthorizationEngine with two policies."; +} + +TEST_F(CelAuthorizationEngineTest, CreateEngineFailNoPolicies) { + std::vector policies{}; + std::unique_ptr engine = + CelAuthorizationEngine::CreateCelAuthorizationEngine(policies); + EXPECT_EQ(engine, nullptr) + << "Error: Created CelAuthorizationEngine without policies."; +} + +TEST_F(CelAuthorizationEngineTest, CreateEngineFailTooManyPolicies) { + std::vector policies{deny_policy_, allow_policy_, + deny_policy_}; + std::unique_ptr engine = + CelAuthorizationEngine::CreateCelAuthorizationEngine(policies); + EXPECT_EQ(engine, nullptr) + << "Error: Created CelAuthorizationEngine with more than two policies."; +} + +TEST_F(CelAuthorizationEngineTest, CreateEngineFailWrongPolicyOrder) { + std::vector policies{allow_policy_, deny_policy_}; + std::unique_ptr engine = + CelAuthorizationEngine::CreateCelAuthorizationEngine(policies); + EXPECT_EQ(engine, nullptr) << "Error: Created CelAuthorizationEngine with " + "policies in the wrong order."; +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/security/check_gcp_environment_linux_test.cc b/test/core/security/check_gcp_environment_linux_test.cc new file mode 100644 index 00000000..2f266cb4 --- /dev/null +++ b/test/core/security/check_gcp_environment_linux_test.cc @@ -0,0 +1,86 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" + +#if GPR_LINUX + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/tmpfile.h" + +static bool check_bios_data_linux_test(const char* data) { + /* Create a file with contents data. 
*/ + char* filename = nullptr; + FILE* fp = gpr_tmpfile("check_gcp_environment_test", &filename); + GPR_ASSERT(filename != nullptr); + GPR_ASSERT(fp != nullptr); + GPR_ASSERT(fwrite(data, 1, strlen(data), fp) == strlen(data)); + fclose(fp); + bool result = grpc_core::internal::check_bios_data( + reinterpret_cast(filename)); + /* Cleanup. */ + remove(filename); + gpr_free(filename); + return result; +} + +static void test_gcp_environment_check_success() { + /* Exact match. */ + GPR_ASSERT(check_bios_data_linux_test("Google")); + GPR_ASSERT(check_bios_data_linux_test("Google Compute Engine")); + /* With leading and trailing whitespaces. */ + GPR_ASSERT(check_bios_data_linux_test(" Google ")); + GPR_ASSERT(check_bios_data_linux_test("Google ")); + GPR_ASSERT(check_bios_data_linux_test(" Google")); + GPR_ASSERT(check_bios_data_linux_test(" Google Compute Engine ")); + GPR_ASSERT(check_bios_data_linux_test("Google Compute Engine ")); + GPR_ASSERT(check_bios_data_linux_test(" Google Compute Engine")); + /* With leading and trailing \t and \n. */ + GPR_ASSERT(check_bios_data_linux_test("\t\tGoogle Compute Engine\t")); + GPR_ASSERT(check_bios_data_linux_test("Google Compute Engine\n")); + GPR_ASSERT(check_bios_data_linux_test("\n\n\tGoogle Compute Engine \n\t\t")); +} + +static void test_gcp_environment_check_failure() { + GPR_ASSERT(!check_bios_data_linux_test("non_existing-file")); + GPR_ASSERT(!check_bios_data_linux_test("Google-Chrome")); + GPR_ASSERT(!check_bios_data_linux_test("Amazon")); + GPR_ASSERT(!check_bios_data_linux_test("Google-Chrome\t\t")); + GPR_ASSERT(!check_bios_data_linux_test("Amazon")); + GPR_ASSERT(!check_bios_data_linux_test("\n")); +} + +int main(int /*argc*/, char** /*argv*/) { + /* Tests. */ + test_gcp_environment_check_success(); + test_gcp_environment_check_failure(); + return 0; +} + +#else // GPR_LINUX + +int main(int /*argc*/, char** /*argv*/) { return 0; } + +#endif // GPR_LINUX diff --git a/test/core/security/check_gcp_environment_windows_test.cc b/test/core/security/check_gcp_environment_windows_test.cc new file mode 100644 index 00000000..19341732 --- /dev/null +++ b/test/core/security/check_gcp_environment_windows_test.cc @@ -0,0 +1,93 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/lib/security/credentials/alts/check_gcp_environment.h" + +#ifdef GPR_WINDOWS + +#include +#include + +#include +#include + +#include "src/core/lib/gpr/tmpfile.h" + +namespace grpc_core { +namespace internal { + +bool check_windows_registry_product_name(HKEY root_key, + const char* reg_key_path, + const char* reg_key_name); + +} // namespace internal +} // namespace grpc_core + +static bool check_bios_data_windows_test(const char* data) { + char const reg_key_path[] = "SYSTEM\\HardwareConfig\\Current\\"; + char const reg_key_name[] = "grpcTestValueName"; + // Modify the registry for the current user to contain the + // test value. 
We cannot use the system registry because the + // user may not have privileges to change it. + auto rc = RegSetKeyValueA(HKEY_CURRENT_USER, reg_key_path, reg_key_name, + REG_SZ, reinterpret_cast(data), + static_cast(strlen(data))); + if (rc != 0) { + return false; + } + + auto result = grpc_core::internal::check_windows_registry_product_name( + HKEY_CURRENT_USER, reg_key_path, reg_key_name); + + (void)RegDeleteKeyValueA(HKEY_CURRENT_USER, reg_key_path, reg_key_name); + + return result; +} + +static void test_gcp_environment_check_success() { + // This is the only value observed in production. + GPR_ASSERT(check_bios_data_windows_test("Google Compute Engine")); + // Be generous and accept other values that were accepted by the previous + // implementation. + GPR_ASSERT(check_bios_data_windows_test("Google")); + GPR_ASSERT(check_bios_data_windows_test("Google\n")); + GPR_ASSERT(check_bios_data_windows_test("Google\r")); + GPR_ASSERT(check_bios_data_windows_test("Google\r\n")); + GPR_ASSERT(check_bios_data_windows_test(" Google \r\n")); + GPR_ASSERT(check_bios_data_windows_test(" \t\t Google\r\n")); + GPR_ASSERT(check_bios_data_windows_test(" \t\t Google\t\t \r\n")); +} + +static void test_gcp_environment_check_failure() { + GPR_ASSERT(!check_bios_data_windows_test("\t\tAmazon\n")); + GPR_ASSERT(!check_bios_data_windows_test(" Amazon\r\n")); +} + +int main(int argc, char** argv) { + /* Tests. */ + test_gcp_environment_check_success(); + test_gcp_environment_check_failure(); + return 0; +} +#else // GPR_WINDOWS + +int main(int /*argc*/, char** /*argv*/) { return 0; } + +#endif // GPR_WINDOWS diff --git a/test/core/security/create_jwt.cc b/test/core/security/create_jwt.cc new file mode 100644 index 00000000..a82e07a8 --- /dev/null +++ b/test/core/security/create_jwt.cc @@ -0,0 +1,98 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/jwt/jwt_credentials.h" +#include "test/core/util/cmdline.h" + +void create_jwt(const char* json_key_file_path, const char* service_url, + const char* scope) { + grpc_auth_json_key key; + char* jwt; + grpc_slice json_key_data; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(json_key_file_path, 1, &json_key_data))); + key = grpc_auth_json_key_create_from_string( + reinterpret_cast GRPC_SLICE_START_PTR(json_key_data)); + grpc_slice_unref(json_key_data); + if (!grpc_auth_json_key_is_valid(&key)) { + fprintf(stderr, "Could not parse json key.\n"); + fflush(stderr); + exit(1); + } + jwt = grpc_jwt_encode_and_sign( + &key, service_url == nullptr ? 
GRPC_JWT_OAUTH2_AUDIENCE : service_url, + grpc_max_auth_token_lifetime(), scope); + grpc_auth_json_key_destruct(&key); + if (jwt == nullptr) { + fprintf(stderr, "Could not create JWT.\n"); + fflush(stderr); + exit(1); + } + fprintf(stdout, "%s\n", jwt); + gpr_free(jwt); +} + +int main(int argc, char** argv) { + const char* scope = nullptr; + const char* json_key_file_path = nullptr; + const char* service_url = nullptr; + grpc_init(); + gpr_cmdline* cl = gpr_cmdline_create("create_jwt"); + gpr_cmdline_add_string(cl, "json_key", "File path of the json key.", + &json_key_file_path); + gpr_cmdline_add_string(cl, "scope", + "OPTIONAL Space delimited permissions. Mutually " + "exclusive with service_url", + &scope); + gpr_cmdline_add_string(cl, "service_url", + "OPTIONAL service URL. Mutually exclusive with scope.", + &service_url); + gpr_cmdline_parse(cl, argc, argv); + + if (json_key_file_path == nullptr) { + fprintf(stderr, "Missing --json_key option.\n"); + fflush(stderr); + exit(1); + } + if (scope != nullptr) { + if (service_url != nullptr) { + fprintf(stderr, + "Options --scope and --service_url are mutually exclusive.\n"); + fflush(stderr); + exit(1); + } + } else if (service_url == nullptr) { + fprintf(stderr, "Need one of --service_url or --scope options.\n"); + fflush(stderr); + exit(1); + } + + create_jwt(json_key_file_path, service_url, scope); + + gpr_cmdline_destroy(cl); + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/credentials_test.cc b/test/core/security/credentials_test.cc new file mode 100644 index 00000000..19267a2e --- /dev/null +++ b/test/core/security/credentials_test.cc @@ -0,0 +1,3659 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/lib/security/credentials/credentials.h" + +#include +#include + +#include + +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/security/credentials/composite/composite_credentials.h" +#include "src/core/lib/security/credentials/external/aws_external_account_credentials.h" +#include "src/core/lib/security/credentials/external/external_account_credentials.h" +#include "src/core/lib/security/credentials/external/file_external_account_credentials.h" +#include "src/core/lib/security/credentials/external/url_external_account_credentials.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/security/credentials/google_default/google_default_credentials.h" +#include "src/core/lib/security/credentials/jwt/jwt_credentials.h" +#include "src/core/lib/security/credentials/oauth2/oauth2_credentials.h" +#include "src/core/lib/security/transport/auth_filters.h" +#include "src/core/lib/uri/uri_parser.h" +#include "test/core/util/test_config.h" + +using grpc_core::internal::grpc_flush_cached_google_default_credentials; +using grpc_core::internal::set_gce_tenancy_checker_for_testing; + +/* -- Constants. -- */ + +static const char test_google_iam_authorization_token[] = "blahblahblhahb"; +static const char test_google_iam_authority_selector[] = "respectmyauthoritah"; +static const char test_oauth2_bearer_token[] = + "Bearer blaaslkdjfaslkdfasdsfasf"; + +/* This JSON key was generated with the GCE console and revoked immediately. + The identifiers have been changed as well. + Maximum size for a string literal is 509 chars in C89, yay! */ +static const char test_json_key_str_part1[] = + "{ \"private_key\": \"-----BEGIN PRIVATE KEY-----" + "\\nMIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAOEvJsnoHnyHkXcp\\n7mJE" + "qg" + "WGjiw71NfXByguekSKho65FxaGbsnSM9SMQAqVk7Q2rG+I0OpsT0LrWQtZ\\nyjSeg/" + "rWBQvS4hle4LfijkP3J5BG+" + "IXDMP8RfziNRQsenAXDNPkY4kJCvKux2xdD\\nOnVF6N7dL3nTYZg+" + "uQrNsMTz9UxVAgMBAAECgYEAzbLewe1xe9vy+2GoSsfib+28\\nDZgSE6Bu/" + "zuFoPrRc6qL9p2SsnV7txrunTyJkkOnPLND9ABAXybRTlcVKP/sGgza\\n/" + "8HpCqFYM9V8f34SBWfD4fRFT+n/" + "73cfRUtGXdXpseva2lh8RilIQfPhNZAncenU\\ngqXjDvpkypEusgXAykECQQD+"; +static const char test_json_key_str_part2[] = + "53XxNVnxBHsYb+AYEfklR96yVi8HywjVHP34+OQZ\\nCslxoHQM8s+" + "dBnjfScLu22JqkPv04xyxmt0QAKm9+vTdAkEA4ib7YvEAn2jXzcCI\\nEkoy2L/" + "XydR1GCHoacdfdAwiL2npOdnbvi4ZmdYRPY1LSTO058tQHKVXV7NLeCa3\\nAARh2QJBAMKeDA" + "G" + "W303SQv2cZTdbeaLKJbB5drz3eo3j7dDKjrTD9JupixFbzcGw\\n8FZi5c8idxiwC36kbAL6Hz" + "A" + "ZoX+ofI0CQE6KCzPJTtYNqyShgKAZdJ8hwOcvCZtf\\n6z8RJm0+" + "6YBd38lfh5j8mZd7aHFf6I17j5AQY7oPEc47TjJj/" + "5nZ68ECQQDvYuI3\\nLyK5fS8g0SYbmPOL9TlcHDOqwG0mrX9qpg5DC2fniXNSrrZ64GTDKdzZ" + "Y" + "Ap6LI9W\\nIqv4vr6y38N79TTC\\n-----END PRIVATE KEY-----\\n\", "; +static const char test_json_key_str_part3[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." + "com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." 
+ "com\", \"type\": \"service_account\" }"; + +/* Test refresh token. */ +static const char test_refresh_token_str[] = + "{ \"client_id\": \"32555999999.apps.googleusercontent.com\"," + " \"client_secret\": \"EmssLNjJy1332hD4KFsecret\"," + " \"refresh_token\": \"1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42\"," + " \"type\": \"authorized_user\"}"; + +/* Test external account credentials. */ +static const char test_external_account_credentials_str[] = + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"\",\"token_url\":\"https://" + "sts.googleapis.com:5555/" + "token\",\"token_info_url\":\"\",\"credential_source\":{\"file\":" + "\"credentials_file_path\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"; + +static const char test_external_account_credentials_multi_pattern_sts_str[] = + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"https://sts.test.googleapis.com:5555/" + "service_account_impersonation_url\",\"token_url\":\"https://" + "test.sts.googleapis.com:5555/token\",\"token_info_url\":\"https://" + "test-sts.googleapis.com:5555/" + "token_info\",\"credential_source\":{\"file\":\"credentials_file_path\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"; + +static const char test_external_account_credentials_multi_pattern_iam_str[] = + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"https://iamcredentials.test.googleapis.com:5555/" + "service_account_impersonation_url\",\"token_url\":\"https://" + "test.iamcredentials.googleapis.com:5555/" + "token\",\"token_info_url\":\"https://" + "test-iamcredentials.googleapis.com:5555/" + "token_info\",\"credential_source\":{\"file\":\"credentials_file_path\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"; + +static const char valid_oauth2_json_response[] = + "{\"access_token\":\"ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_\"," + " \"expires_in\":3599, " + " \"token_type\":\"Bearer\"}"; + +static const char valid_sts_json_response[] = + "{\"access_token\":\"ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_\"," + " \"expires_in\":3599, " + " \"issued_token_type\":\"urn:ietf:params:oauth:token-type:access_token\", " + " \"token_type\":\"Bearer\"}"; + +static const char test_scope[] = "perm1 perm2"; + +static const char test_signed_jwt[] = + "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImY0OTRkN2M1YWU2MGRmOTcyNmM4YW" + "U0MDcyZTViYTdmZDkwODg2YzcifQ"; +static const char test_signed_jwt_token_type[] = + "urn:ietf:params:oauth:token-type:id_token"; +static const char test_signed_jwt2[] = + "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImY0OTRkN2M1YWU2MGRmOTcyNmM5YW" + "U2MDcyZTViYTdnZDkwODg5YzcifQ"; +static const char test_signed_jwt_token_type2[] = + "urn:ietf:params:oauth:token-type:jwt"; +static const char test_signed_jwt_path_prefix[] = "test_sign_jwt"; + +static const char test_service_url[] = "https://foo.com/foo.v1"; +static const char test_service_url_no_service_name[] = "https://foo.com/"; +static const char other_test_service_url[] = "https://bar.com/bar.v1"; +static const char other_test_service_url_no_service_name[] = "https://bar.com/"; + 
+static const char test_sts_endpoint_url[] = + "https://foo.com:5555/v1/token-exchange"; + +static const char test_method[] = "ThisIsNotAMethod"; + +static const char valid_external_account_creds_token_exchange_response[] = + "{\"access_token\":\"token_exchange_access_token\"," + " \"expires_in\":3599," + " \"token_type\":\"Bearer\"}"; + +static const char + valid_external_account_creds_service_account_impersonation_response[] = + "{\"accessToken\":\"service_account_impersonation_access_token\"," + " \"expireTime\":\"2050-01-01T00:00:00Z\"}"; + +static const char + valid_url_external_account_creds_options_credential_source_format_text[] = + "{\"url\":\"https://foo.com:5555/generate_subject_token_format_text\"," + "\"headers\":{\"Metadata-Flavor\":\"Google\"}}"; + +static const char + valid_url_external_account_creds_options_credential_source_with_qurey_params_format_text + [] = "{\"url\":\"https://foo.com:5555/" + "path/to/url/creds?p1=v1&p2=v2\"," + "\"headers\":{\"Metadata-Flavor\":\"Google\"}}"; + +static const char + valid_url_external_account_creds_retrieve_subject_token_response_format_text + [] = "test_subject_token"; + +static const char + valid_url_external_account_creds_options_credential_source_format_json[] = + "{\"url\":\"https://foo.com:5555/generate_subject_token_format_json\"," + "\"headers\":{\"Metadata-Flavor\":\"Google\"}," + "\"format\":{\"type\":\"json\",\"subject_token_field_name\":\"access_" + "token\"}}"; + +static const char + valid_url_external_account_creds_retrieve_subject_token_response_format_json + [] = "{\"access_token\":\"test_subject_token\"}"; + +static const char + invalid_url_external_account_creds_options_credential_source[] = + "{\"url\":\"invalid_credential_source_url\"," + "\"headers\":{\"Metadata-Flavor\":\"Google\"}}"; + +static const char + valid_aws_external_account_creds_retrieve_signing_keys_response[] = + "{\"AccessKeyId\":\"test_access_key_id\",\"SecretAccessKey\":" + "\"test_secret_access_key\",\"Token\":\"test_token\"}"; + +static const char valid_aws_external_account_creds_options_credential_source[] = + "{\"environment_id\":\"aws1\"," + "\"region_url\":\"https://foo.com:5555/region_url\"," + "\"url\":\"https://foo.com:5555/url\"," + "\"regional_cred_verification_url\":\"https://foo.com:5555/" + "regional_cred_verification_url_{region}\"}"; + +static const char + invalid_aws_external_account_creds_options_credential_source_unmatched_environment_id + [] = "{\"environment_id\":\"unsupported_aws_version\"," + "\"region_url\":\"https://foo.com:5555/region_url\"," + "\"url\":\"https://foo.com:5555/url\"," + "\"regional_cred_verification_url\":\"https://foo.com:5555/" + "regional_cred_verification_url_{region}\"}"; + +static const char + invalid_aws_external_account_creds_options_credential_source_invalid_region_url + [] = "{\"environment_id\":\"aws1\"," + "\"region_url\":\"invalid_region_url\"," + "\"url\":\"https://foo.com:5555/url\"," + "\"regional_cred_verification_url\":\"https://foo.com:5555/" + "regional_cred_verification_url_{region}\"}"; + +static const char + invalid_aws_external_account_creds_options_credential_source_invalid_url[] = + "{\"environment_id\":\"aws1\"," + "\"region_url\":\"https://foo.com:5555/region_url\"," + "\"url\":\"invalid_url\"," + "\"regional_cred_verification_url\":\"https://foo.com:5555/" + "regional_cred_verification_url_{region}\"}"; + +static const char + invalid_aws_external_account_creds_options_credential_source_missing_role_name + [] = "{\"environment_id\":\"aws1\"," + 
"\"region_url\":\"https://foo.com:5555/region_url\"," + "\"url\":\"https://foo.com:5555/url_no_role_name\"," + "\"regional_cred_verification_url\":\"https://foo.com:5555/" + "regional_cred_verification_url_{region}\"}"; + +static const char + invalid_aws_external_account_creds_options_credential_source_invalid_regional_cred_verification_url + [] = "{\"environment_id\":\"aws1\"," + "\"region_url\":\"https://foo.com:5555/region_url\"," + "\"url\":\"https://foo.com:5555/url_no_role_name\"," + "\"regional_cred_verification_url\":\"invalid_regional_cred_" + "verification_url\"}"; + +/* -- Global state flags. -- */ + +static bool g_test_is_on_gce = false; + +static bool g_test_gce_tenancy_checker_called = false; + +/* -- Utils. -- */ + +static char* test_json_key_str(void) { + size_t result_len = strlen(test_json_key_str_part1) + + strlen(test_json_key_str_part2) + + strlen(test_json_key_str_part3); + char* result = static_cast(gpr_malloc(result_len + 1)); + char* current = result; + strcpy(result, test_json_key_str_part1); + current += strlen(test_json_key_str_part1); + strcpy(current, test_json_key_str_part2); + current += strlen(test_json_key_str_part2); + strcpy(current, test_json_key_str_part3); + return result; +} + +static grpc_httpcli_response http_response(int status, const char* body) { + grpc_httpcli_response response; + response = {}; + response.status = status; + response.body = gpr_strdup(const_cast(body)); + response.body_length = strlen(body); + return response; +} + +/* -- Tests. -- */ + +static void test_empty_md_array(void) { + grpc_core::ExecCtx exec_ctx; + grpc_credentials_mdelem_array md_array; + md_array = {}; + GPR_ASSERT(md_array.md == nullptr); + GPR_ASSERT(md_array.size == 0); + grpc_credentials_mdelem_array_destroy(&md_array); +} + +static void test_add_to_empty_md_array(void) { + grpc_core::ExecCtx exec_ctx; + grpc_credentials_mdelem_array md_array; + md_array = {}; + const char* key = "hello"; + const char* value = "there blah blah blah blah blah blah blah"; + grpc_mdelem md = grpc_mdelem_from_slices( + grpc_slice_from_copied_string(key), grpc_slice_from_copied_string(value)); + grpc_credentials_mdelem_array_add(&md_array, md); + GPR_ASSERT(md_array.size == 1); + GPR_ASSERT(grpc_mdelem_eq(md, md_array.md[0])); + GRPC_MDELEM_UNREF(md); + grpc_credentials_mdelem_array_destroy(&md_array); +} + +static void test_add_abunch_to_md_array(void) { + grpc_core::ExecCtx exec_ctx; + grpc_credentials_mdelem_array md_array; + md_array = {}; + const char* key = "hello"; + const char* value = "there blah blah blah blah blah blah blah"; + grpc_mdelem md = grpc_mdelem_from_slices( + grpc_slice_from_copied_string(key), grpc_slice_from_copied_string(value)); + size_t num_entries = 1000; + for (size_t i = 0; i < num_entries; ++i) { + grpc_credentials_mdelem_array_add(&md_array, md); + } + for (size_t i = 0; i < num_entries; ++i) { + GPR_ASSERT(grpc_mdelem_eq(md_array.md[i], md)); + } + GRPC_MDELEM_UNREF(md); + grpc_credentials_mdelem_array_destroy(&md_array); +} + +static void test_oauth2_token_fetcher_creds_parsing_ok(void) { + grpc_core::ExecCtx exec_ctx; + grpc_mdelem token_md = GRPC_MDNULL; + grpc_millis token_lifetime; + grpc_httpcli_response response = + http_response(200, valid_oauth2_json_response); + GPR_ASSERT(grpc_oauth2_token_fetcher_credentials_parse_server_response( + &response, &token_md, &token_lifetime) == GRPC_CREDENTIALS_OK); + GPR_ASSERT(token_lifetime == 3599 * GPR_MS_PER_SEC); + GPR_ASSERT(grpc_slice_str_cmp(GRPC_MDKEY(token_md), "authorization") == 0); + 
GPR_ASSERT(grpc_slice_str_cmp(GRPC_MDVALUE(token_md), + "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_") == + 0); + GRPC_MDELEM_UNREF(token_md); + grpc_http_response_destroy(&response); +} + +static void test_oauth2_token_fetcher_creds_parsing_bad_http_status(void) { + grpc_core::ExecCtx exec_ctx; + grpc_mdelem token_md = GRPC_MDNULL; + grpc_millis token_lifetime; + grpc_httpcli_response response = + http_response(401, valid_oauth2_json_response); + GPR_ASSERT(grpc_oauth2_token_fetcher_credentials_parse_server_response( + &response, &token_md, &token_lifetime) == + GRPC_CREDENTIALS_ERROR); + grpc_http_response_destroy(&response); +} + +static void test_oauth2_token_fetcher_creds_parsing_empty_http_body(void) { + grpc_core::ExecCtx exec_ctx; + grpc_mdelem token_md = GRPC_MDNULL; + grpc_millis token_lifetime; + grpc_httpcli_response response = http_response(200, ""); + GPR_ASSERT(grpc_oauth2_token_fetcher_credentials_parse_server_response( + &response, &token_md, &token_lifetime) == + GRPC_CREDENTIALS_ERROR); + grpc_http_response_destroy(&response); +} + +static void test_oauth2_token_fetcher_creds_parsing_invalid_json(void) { + grpc_core::ExecCtx exec_ctx; + grpc_mdelem token_md = GRPC_MDNULL; + grpc_millis token_lifetime; + grpc_httpcli_response response = + http_response(200, + "{\"access_token\":\"ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_\"," + " \"expires_in\":3599, " + " \"token_type\":\"Bearer\""); + GPR_ASSERT(grpc_oauth2_token_fetcher_credentials_parse_server_response( + &response, &token_md, &token_lifetime) == + GRPC_CREDENTIALS_ERROR); + grpc_http_response_destroy(&response); +} + +static void test_oauth2_token_fetcher_creds_parsing_missing_token(void) { + grpc_core::ExecCtx exec_ctx; + grpc_mdelem token_md = GRPC_MDNULL; + grpc_millis token_lifetime; + grpc_httpcli_response response = http_response(200, + "{" + " \"expires_in\":3599, " + " \"token_type\":\"Bearer\"}"); + GPR_ASSERT(grpc_oauth2_token_fetcher_credentials_parse_server_response( + &response, &token_md, &token_lifetime) == + GRPC_CREDENTIALS_ERROR); + grpc_http_response_destroy(&response); +} + +static void test_oauth2_token_fetcher_creds_parsing_missing_token_type(void) { + grpc_core::ExecCtx exec_ctx; + grpc_mdelem token_md = GRPC_MDNULL; + grpc_millis token_lifetime; + grpc_httpcli_response response = + http_response(200, + "{\"access_token\":\"ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_\"," + " \"expires_in\":3599, " + "}"); + GPR_ASSERT(grpc_oauth2_token_fetcher_credentials_parse_server_response( + &response, &token_md, &token_lifetime) == + GRPC_CREDENTIALS_ERROR); + grpc_http_response_destroy(&response); +} + +static void test_oauth2_token_fetcher_creds_parsing_missing_token_lifetime( + void) { + grpc_core::ExecCtx exec_ctx; + grpc_mdelem token_md = GRPC_MDNULL; + grpc_millis token_lifetime; + grpc_httpcli_response response = + http_response(200, + "{\"access_token\":\"ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_\"," + " \"token_type\":\"Bearer\"}"); + GPR_ASSERT(grpc_oauth2_token_fetcher_credentials_parse_server_response( + &response, &token_md, &token_lifetime) == + GRPC_CREDENTIALS_ERROR); + grpc_http_response_destroy(&response); +} + +namespace { + +class RequestMetadataState { + public: + static RequestMetadataState* NewInstance( + grpc_error_handle expected_error, + std::map expected) { + RequestMetadataState* state = new RequestMetadataState( + expected_error, std::move(expected), + grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create())); + return state; + } + + private: + RequestMetadataState(grpc_error_handle 
expected_error, + std::map expected, + grpc_polling_entity pollent) + : expected_error_(expected_error), + expected_(expected), + pollent_(pollent) { + GRPC_CLOSURE_INIT(&on_request_metadata_, OnRequestMetadata, this, + grpc_schedule_on_exec_ctx); + } + + public: + ~RequestMetadataState() { + grpc_credentials_mdelem_array_destroy(&md_array_); + grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_)); + } + + void RunRequestMetadataTest(grpc_call_credentials* creds, + grpc_auth_metadata_context auth_md_ctx) { + grpc_error_handle error = GRPC_ERROR_NONE; + if (creds->get_request_metadata(&pollent_, auth_md_ctx, &md_array_, + &on_request_metadata_, &error)) { + // Synchronous result. Invoke the callback directly. + CheckRequestMetadata(error); + GRPC_ERROR_UNREF(error); + } + } + + private: + static void OnRequestMetadata(void* arg, grpc_error_handle error) { + RequestMetadataState* state = static_cast(arg); + state->CheckRequestMetadata(error); + } + + void CheckRequestMetadata(grpc_error_handle error) { + gpr_log(GPR_INFO, "expected_error: %s", + grpc_error_std_string(expected_error_).c_str()); + gpr_log(GPR_INFO, "actual_error: %s", grpc_error_std_string(error).c_str()); + if (expected_error_ == GRPC_ERROR_NONE) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + } else { + std::string expected_error; + GPR_ASSERT(grpc_error_get_str(expected_error_, GRPC_ERROR_STR_DESCRIPTION, + &expected_error)); + std::string actual_error; + GPR_ASSERT( + grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, &actual_error)); + GPR_ASSERT(expected_error == actual_error); + GRPC_ERROR_UNREF(expected_error_); + } + gpr_log(GPR_INFO, "expected_size=%" PRIdPTR " actual_size=%" PRIdPTR, + expected_.size(), md_array_.size); + GPR_ASSERT(md_array_.size == expected_.size()); + CheckMetadata(expected_, &md_array_); + delete this; + } + + static void CheckMetadata(const std::map& expected, + grpc_credentials_mdelem_array* md_array) { + for (auto const& i : expected) { + size_t j; + for (j = 0; j < md_array->size; ++j) { + absl::string_view actual_key = + grpc_core::StringViewFromSlice(GRPC_MDKEY(md_array->md[j])); + if (actual_key == i.first) { + absl::string_view actual_value = + grpc_core::StringViewFromSlice(GRPC_MDVALUE(md_array->md[j])); + GPR_ASSERT(actual_value == i.second); + break; + } + } + if (j == md_array->size) { + gpr_log(GPR_ERROR, "key %s not found", i.first.c_str()); + GPR_ASSERT(0); + } + } + } + + private: + grpc_error_handle expected_error_; + std::map expected_; + grpc_credentials_mdelem_array md_array_; + grpc_closure on_request_metadata_; + grpc_polling_entity pollent_; +}; + +} // namespace + +static void test_google_iam_creds(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + test_google_iam_authorization_token}, + {GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + test_google_iam_authority_selector}}; + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_call_credentials* creds = grpc_google_iam_credentials_create( + test_google_iam_authorization_token, test_google_iam_authority_selector, + nullptr); + /* Check security level. 
*/ + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + state->RunRequestMetadataTest(creds, auth_md_ctx); + creds->Unref(); +} + +static void test_access_token_creds(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {GRPC_AUTHORIZATION_METADATA_KEY, "Bearer blah"}}; + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_call_credentials* creds = + grpc_access_token_credentials_create("blah", nullptr); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + GPR_ASSERT(strcmp(creds->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); + /* Check security level. */ + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + state->RunRequestMetadataTest(creds, auth_md_ctx); + creds->Unref(); +} + +namespace { +class check_channel_oauth2 final : public grpc_channel_credentials { + public: + check_channel_oauth2() : grpc_channel_credentials("mock") {} + ~check_channel_oauth2() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* /*target*/, const grpc_channel_args* /*args*/, + grpc_channel_args** /*new_args*/) override { + GPR_ASSERT(strcmp(type(), "mock") == 0); + GPR_ASSERT(call_creds != nullptr); + GPR_ASSERT(strcmp(call_creds->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == + 0); + return nullptr; + } +}; +} // namespace + +static void test_channel_oauth2_composite_creds(void) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args* new_args; + grpc_channel_credentials* channel_creds = new check_channel_oauth2(); + grpc_call_credentials* oauth2_creds = + grpc_access_token_credentials_create("blah", nullptr); + grpc_channel_credentials* channel_oauth2_creds = + grpc_composite_channel_credentials_create(channel_creds, oauth2_creds, + nullptr); + grpc_channel_credentials_release(channel_creds); + grpc_call_credentials_release(oauth2_creds); + channel_oauth2_creds->create_security_connector(nullptr, nullptr, nullptr, + &new_args); + grpc_channel_credentials_release(channel_oauth2_creds); +} + +static void test_oauth2_google_iam_composite_creds(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {GRPC_AUTHORIZATION_METADATA_KEY, test_oauth2_bearer_token}, + {GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + test_google_iam_authorization_token}, + {GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + test_google_iam_authority_selector}}; + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_call_credentials* oauth2_creds = grpc_md_only_test_credentials_create( + "authorization", test_oauth2_bearer_token, false); + + /* Check security level of fake credentials. */ + GPR_ASSERT(oauth2_creds->min_security_level() == GRPC_SECURITY_NONE); + + grpc_call_credentials* google_iam_creds = grpc_google_iam_credentials_create( + test_google_iam_authorization_token, test_google_iam_authority_selector, + nullptr); + grpc_call_credentials* composite_creds = + grpc_composite_call_credentials_create(oauth2_creds, google_iam_creds, + nullptr); + /* Check security level of composite credentials. 
*/ + GPR_ASSERT(composite_creds->min_security_level() == + GRPC_PRIVACY_AND_INTEGRITY); + + oauth2_creds->Unref(); + google_iam_creds->Unref(); + GPR_ASSERT(strcmp(composite_creds->type(), + GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); + const grpc_composite_call_credentials::CallCredentialsList& creds_list = + static_cast(composite_creds) + ->inner(); + GPR_ASSERT(creds_list.size() == 2); + GPR_ASSERT(strcmp(creds_list[0]->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == + 0); + GPR_ASSERT(strcmp(creds_list[1]->type(), GRPC_CALL_CREDENTIALS_TYPE_IAM) == + 0); + state->RunRequestMetadataTest(composite_creds, auth_md_ctx); + composite_creds->Unref(); +} + +namespace { +class check_channel_oauth2_google_iam final : public grpc_channel_credentials { + public: + check_channel_oauth2_google_iam() : grpc_channel_credentials("mock") {} + ~check_channel_oauth2_google_iam() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* /*target*/, const grpc_channel_args* /*args*/, + grpc_channel_args** /*new_args*/) override { + GPR_ASSERT(strcmp(type(), "mock") == 0); + GPR_ASSERT(call_creds != nullptr); + GPR_ASSERT( + strcmp(call_creds->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); + const grpc_composite_call_credentials::CallCredentialsList& creds_list = + static_cast(call_creds.get()) + ->inner(); + GPR_ASSERT( + strcmp(creds_list[0]->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); + GPR_ASSERT(strcmp(creds_list[1]->type(), GRPC_CALL_CREDENTIALS_TYPE_IAM) == + 0); + return nullptr; + } +}; +} // namespace + +static void test_channel_oauth2_google_iam_composite_creds(void) { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args* new_args; + grpc_channel_credentials* channel_creds = + new check_channel_oauth2_google_iam(); + grpc_call_credentials* oauth2_creds = + grpc_access_token_credentials_create("blah", nullptr); + grpc_channel_credentials* channel_oauth2_creds = + grpc_composite_channel_credentials_create(channel_creds, oauth2_creds, + nullptr); + grpc_call_credentials* google_iam_creds = grpc_google_iam_credentials_create( + test_google_iam_authorization_token, test_google_iam_authority_selector, + nullptr); + + grpc_channel_credentials* channel_oauth2_iam_creds = + grpc_composite_channel_credentials_create(channel_oauth2_creds, + google_iam_creds, nullptr); + grpc_channel_credentials_release(channel_creds); + grpc_call_credentials_release(oauth2_creds); + grpc_channel_credentials_release(channel_oauth2_creds); + grpc_call_credentials_release(google_iam_creds); + + channel_oauth2_iam_creds->create_security_connector(nullptr, nullptr, nullptr, + &new_args); + + grpc_channel_credentials_release(channel_oauth2_iam_creds); +} + +static void validate_compute_engine_http_request( + const grpc_httpcli_request* request) { + GPR_ASSERT(request->handshaker != &grpc_httpcli_ssl); + GPR_ASSERT(strcmp(request->host, "metadata.google.internal.") == 0); + GPR_ASSERT( + strcmp(request->http.path, + "/computeMetadata/v1/instance/service-accounts/default/token") == + 0); + GPR_ASSERT(request->http.hdr_count == 1); + GPR_ASSERT(strcmp(request->http.hdrs[0].key, "Metadata-Flavor") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[0].value, "Google") == 0); +} + +static int compute_engine_httpcli_get_success_override( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + validate_compute_engine_http_request(request); + *response = http_response(200, 
valid_oauth2_json_response); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int compute_engine_httpcli_get_failure_override( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + validate_compute_engine_http_request(request); + *response = http_response(403, "Not Authorized."); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int httpcli_post_should_not_be_called( + const grpc_httpcli_request* /*request*/, const char* /*body_bytes*/, + size_t /*body_size*/, grpc_millis /*deadline*/, grpc_closure* /*on_done*/, + grpc_httpcli_response* /*response*/) { + GPR_ASSERT("HTTP POST should not be called" == nullptr); + return 1; +} + +static int httpcli_get_should_not_be_called( + const grpc_httpcli_request* /*request*/, grpc_millis /*deadline*/, + grpc_closure* /*on_done*/, grpc_httpcli_response* /*response*/) { + GPR_ASSERT("HTTP GET should not be called" == nullptr); + return 1; +} + +static void test_compute_engine_creds_success() { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; + const char expected_creds_debug_string[] = + "GoogleComputeEngineTokenFetcherCredentials{" + "OAuth2TokenFetcherCredentials}"; + grpc_call_credentials* creds = + grpc_google_compute_engine_credentials_create(nullptr); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + /* Check security level. */ + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + + /* First request: http get should be called. */ + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(compute_engine_httpcli_get_success_override, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + /* Second request: the cached token should be served directly. 
*/ + state = RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_compute_engine_creds_failure(void) { + grpc_core::ExecCtx exec_ctx; + const char expected_creds_debug_string[] = + "GoogleComputeEngineTokenFetcherCredentials{" + "OAuth2TokenFetcherCredentials}"; + RequestMetadataState* state = RequestMetadataState::NewInstance( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token."), + {}); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_call_credentials* creds = + grpc_google_compute_engine_credentials_create(nullptr); + grpc_httpcli_set_override(compute_engine_httpcli_get_failure_override, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void validate_refresh_token_http_request( + const grpc_httpcli_request* request, const char* body, size_t body_size) { + /* The content of the assertion is tested extensively in json_token_test. */ + GPR_ASSERT(body != nullptr); + GPR_ASSERT(body_size != 0); + std::string expected_body = absl::StrFormat( + GRPC_REFRESH_TOKEN_POST_BODY_FORMAT_STRING, + "32555999999.apps.googleusercontent.com", "EmssLNjJy1332hD4KFsecret", + "1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42"); + GPR_ASSERT(expected_body.size() == body_size); + GPR_ASSERT(memcmp(expected_body.data(), body, body_size) == 0); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + GPR_ASSERT(strcmp(request->host, GRPC_GOOGLE_OAUTH2_SERVICE_HOST) == 0); + GPR_ASSERT( + strcmp(request->http.path, GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH) == 0); + GPR_ASSERT(request->http.hdr_count == 1); + GPR_ASSERT(strcmp(request->http.hdrs[0].key, "Content-Type") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[0].value, + "application/x-www-form-urlencoded") == 0); +} + +static int refresh_token_httpcli_post_success( + const grpc_httpcli_request* request, const char* body, size_t body_size, + grpc_millis /*deadline*/, grpc_closure* on_done, + grpc_httpcli_response* response) { + validate_refresh_token_http_request(request, body, body_size); + *response = http_response(200, valid_oauth2_json_response); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int token_httpcli_post_failure(const grpc_httpcli_request* /*request*/, + const char* /*body*/, + size_t /*body_size*/, + grpc_millis /*deadline*/, + grpc_closure* on_done, + grpc_httpcli_response* response) { + *response = http_response(403, "Not Authorized."); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static void test_refresh_token_creds_success(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; + const char expected_creds_debug_string[] = + "GoogleRefreshToken{ClientID:32555999999.apps.googleusercontent.com," + "OAuth2TokenFetcherCredentials}"; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + 
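+  // Illustrative sketch only: test_refresh_token_str (defined earlier in
+  // this file) is assumed to hold a user-credentials JSON blob whose fields
+  // match what validate_refresh_token_http_request() checks above, e.g.:
+  //   {
+  //     "client_id": "32555999999.apps.googleusercontent.com",
+  //     "client_secret": "EmssLNjJy1332hD4KFsecret",
+  //     "refresh_token": "1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42",
+  //     "type": "authorized_user"
+  //   }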
+  grpc_call_credentials* creds = grpc_google_refresh_token_credentials_create(
+      test_refresh_token_str, nullptr);
+
+  /* Check security level. */
+  GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY);
+
+  /* First request: http post should be called. */
+  RequestMetadataState* state =
+      RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd);
+  grpc_httpcli_set_override(httpcli_get_should_not_be_called,
+                            refresh_token_httpcli_post_success);
+  state->RunRequestMetadataTest(creds, auth_md_ctx);
+  grpc_core::ExecCtx::Get()->Flush();
+
+  /* Second request: the cached token should be served directly. */
+  state = RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd);
+  grpc_httpcli_set_override(httpcli_get_should_not_be_called,
+                            httpcli_post_should_not_be_called);
+  state->RunRequestMetadataTest(creds, auth_md_ctx);
+  grpc_core::ExecCtx::Get()->Flush();
+  GPR_ASSERT(
+      strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0);
+
+  creds->Unref();
+  grpc_httpcli_set_override(nullptr, nullptr);
+}
+
+static void test_refresh_token_creds_failure(void) {
+  grpc_core::ExecCtx exec_ctx;
+  const char expected_creds_debug_string[] =
+      "GoogleRefreshToken{ClientID:32555999999.apps.googleusercontent.com,"
+      "OAuth2TokenFetcherCredentials}";
+  RequestMetadataState* state = RequestMetadataState::NewInstance(
+      GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+          "Error occurred when fetching oauth2 token."),
+      {});
+  grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method,
+                                            nullptr, nullptr};
+  grpc_call_credentials* creds = grpc_google_refresh_token_credentials_create(
+      test_refresh_token_str, nullptr);
+  grpc_httpcli_set_override(httpcli_get_should_not_be_called,
+                            token_httpcli_post_failure);
+  state->RunRequestMetadataTest(creds, auth_md_ctx);
+  GPR_ASSERT(
+      strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0);
+
+  creds->Unref();
+  grpc_httpcli_set_override(nullptr, nullptr);
+}
+
+static void test_valid_sts_creds_options(void) {
+  grpc_sts_credentials_options valid_options = {
+      test_sts_endpoint_url,        // sts_endpoint_url
+      nullptr,                      // resource
+      nullptr,                      // audience
+      nullptr,                      // scope
+      nullptr,                      // requested_token_type
+      test_signed_jwt_path_prefix,  // subject_token_path
+      test_signed_jwt_token_type,   // subject_token_type
+      nullptr,                      // actor_token_path
+      nullptr                       // actor_token_type
+  };
+  absl::StatusOr<grpc_core::URI> sts_url =
+      grpc_core::ValidateStsCredentialsOptions(&valid_options);
+  GPR_ASSERT(sts_url.ok());
+  absl::string_view host;
+  absl::string_view port;
+  GPR_ASSERT(grpc_core::SplitHostPort(sts_url->authority(), &host, &port));
+  GPR_ASSERT(host == "foo.com");
+  GPR_ASSERT(port == "5555");
+}
+
+static void test_invalid_sts_creds_options(void) {
+  grpc_sts_credentials_options invalid_options = {
+      test_sts_endpoint_url,       // sts_endpoint_url
+      nullptr,                     // resource
+      nullptr,                     // audience
+      nullptr,                     // scope
+      nullptr,                     // requested_token_type
+      nullptr,                     // subject_token_path (Required)
+      test_signed_jwt_token_type,  // subject_token_type
+      nullptr,                     // actor_token_path
+      nullptr                      // actor_token_type
+  };
+  absl::StatusOr<grpc_core::URI> url_should_be_invalid =
+      grpc_core::ValidateStsCredentialsOptions(&invalid_options);
+  GPR_ASSERT(!url_should_be_invalid.ok());
+
+  invalid_options = {
+      test_sts_endpoint_url,        // sts_endpoint_url
+      nullptr,                      // resource
+      nullptr,                      // audience
+      nullptr,                      // scope
+      nullptr,                      // requested_token_type
+      test_signed_jwt_path_prefix,  // subject_token_path
+      nullptr,                      // subject_token_type (Required)
+      nullptr,                      // actor_token_path
+      nullptr
// actor_token_type + }; + url_should_be_invalid = + grpc_core::ValidateStsCredentialsOptions(&invalid_options); + GPR_ASSERT(!url_should_be_invalid.ok()); + + invalid_options = { + nullptr, // sts_endpoint_url (Required) + nullptr, // resource + nullptr, // audience + nullptr, // scope + nullptr, // requested_token_type + test_signed_jwt_path_prefix, // subject_token_path + test_signed_jwt_token_type, // subject_token_type (Required) + nullptr, // actor_token_path + nullptr // actor_token_type + }; + url_should_be_invalid = + grpc_core::ValidateStsCredentialsOptions(&invalid_options); + GPR_ASSERT(!url_should_be_invalid.ok()); + + invalid_options = { + "not_a_valid_uri", // sts_endpoint_url + nullptr, // resource + nullptr, // audience + nullptr, // scope + nullptr, // requested_token_type + test_signed_jwt_path_prefix, // subject_token_path + test_signed_jwt_token_type, // subject_token_type (Required) + nullptr, // actor_token_path + nullptr // actor_token_type + }; + url_should_be_invalid = + grpc_core::ValidateStsCredentialsOptions(&invalid_options); + GPR_ASSERT(!url_should_be_invalid.ok()); + + invalid_options = { + "ftp://ftp.is.not.a.valid.scheme/bar", // sts_endpoint_url + nullptr, // resource + nullptr, // audience + nullptr, // scope + nullptr, // requested_token_type + test_signed_jwt_path_prefix, // subject_token_path + test_signed_jwt_token_type, // subject_token_type (Required) + nullptr, // actor_token_path + nullptr // actor_token_type + }; + url_should_be_invalid = + grpc_core::ValidateStsCredentialsOptions(&invalid_options); + GPR_ASSERT(!url_should_be_invalid.ok()); +} + +static void assert_query_parameters(const grpc_core::URI& uri, + absl::string_view expected_key, + absl::string_view expected_val) { + const auto it = uri.query_parameter_map().find(expected_key); + GPR_ASSERT(it != uri.query_parameter_map().end()); + if (it->second != expected_val) { + gpr_log(GPR_ERROR, "%s!=%s", std::string(it->second).c_str(), + std::string(expected_val).c_str()); + } + GPR_ASSERT(it->second == expected_val); +} + +static void validate_sts_token_http_request(const grpc_httpcli_request* request, + const char* body, size_t body_size, + bool expect_actor_token) { + // Check that the body is constructed properly. + GPR_ASSERT(body != nullptr); + GPR_ASSERT(body_size != 0); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + std::string get_url_equivalent = + absl::StrFormat("%s?%s", test_sts_endpoint_url, body); + absl::StatusOr url = + grpc_core::URI::Parse(get_url_equivalent); + if (!url.ok()) { + gpr_log(GPR_ERROR, "%s", url.status().ToString().c_str()); + GPR_ASSERT(url.ok()); + } + assert_query_parameters(*url, "resource", "resource"); + assert_query_parameters(*url, "audience", "audience"); + assert_query_parameters(*url, "scope", "scope"); + assert_query_parameters(*url, "requested_token_type", "requested_token_type"); + assert_query_parameters(*url, "subject_token", test_signed_jwt); + assert_query_parameters(*url, "subject_token_type", + test_signed_jwt_token_type); + if (expect_actor_token) { + assert_query_parameters(*url, "actor_token", test_signed_jwt2); + assert_query_parameters(*url, "actor_token_type", + test_signed_jwt_token_type2); + } else { + GPR_ASSERT(url->query_parameter_map().find("actor_token") == + url->query_parameter_map().end()); + GPR_ASSERT(url->query_parameter_map().find("actor_token_type") == + url->query_parameter_map().end()); + } + + // Check the rest of the request. 
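+  // The host and path asserted below assume test_sts_endpoint_url (defined
+  // earlier in this file) has the form
+  // "https://foo.com:5555/v1/token-exchange", i.e. the same authority that
+  // test_valid_sts_creds_options() splits into host "foo.com" and port
+  // "5555".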
+ GPR_ASSERT(strcmp(request->host, "foo.com:5555") == 0); + GPR_ASSERT(strcmp(request->http.path, "/v1/token-exchange") == 0); + GPR_ASSERT(request->http.hdr_count == 1); + GPR_ASSERT(strcmp(request->http.hdrs[0].key, "Content-Type") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[0].value, + "application/x-www-form-urlencoded") == 0); +} + +static int sts_token_httpcli_post_success(const grpc_httpcli_request* request, + const char* body, size_t body_size, + grpc_millis /*deadline*/, + grpc_closure* on_done, + grpc_httpcli_response* response) { + validate_sts_token_http_request(request, body, body_size, true); + *response = http_response(200, valid_sts_json_response); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int sts_token_httpcli_post_success_no_actor_token( + const grpc_httpcli_request* request, const char* body, size_t body_size, + grpc_millis /*deadline*/, grpc_closure* on_done, + grpc_httpcli_response* response) { + validate_sts_token_http_request(request, body, body_size, false); + *response = http_response(200, valid_sts_json_response); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static char* write_tmp_jwt_file(const char* jwt_contents) { + char* path; + FILE* tmp = gpr_tmpfile(test_signed_jwt_path_prefix, &path); + GPR_ASSERT(path != nullptr); + GPR_ASSERT(tmp != nullptr); + size_t jwt_length = strlen(jwt_contents); + GPR_ASSERT(fwrite(jwt_contents, 1, jwt_length, tmp) == jwt_length); + fclose(tmp); + return path; +} + +static void test_sts_creds_success(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; + const char expected_creds_debug_string[] = + "StsTokenFetcherCredentials{Path:/v1/" + "token-exchange,Authority:foo.com:5555,OAuth2TokenFetcherCredentials}"; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + char* subject_token_path = write_tmp_jwt_file(test_signed_jwt); + char* actor_token_path = write_tmp_jwt_file(test_signed_jwt2); + grpc_sts_credentials_options valid_options = { + test_sts_endpoint_url, // sts_endpoint_url + "resource", // resource + "audience", // audience + "scope", // scope + "requested_token_type", // requested_token_type + subject_token_path, // subject_token_path + test_signed_jwt_token_type, // subject_token_type + actor_token_path, // actor_token_path + test_signed_jwt_token_type2 // actor_token_type + }; + grpc_call_credentials* creds = + grpc_sts_credentials_create(&valid_options, nullptr); + + /* Check security level. */ + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + + /* First request: http put should be called. */ + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + sts_token_httpcli_post_success); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + /* Second request: the cached token should be served directly. 
*/ + state = RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_free(subject_token_path); + gpr_free(actor_token_path); +} + +static void test_sts_creds_token_file_not_found(void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_sts_credentials_options valid_options = { + test_sts_endpoint_url, // sts_endpoint_url + "resource", // resource + "audience", // audience + "scope", // scope + "requested_token_type", // requested_token_type + "/some/completely/random/path", // subject_token_path + test_signed_jwt_token_type, // subject_token_type + "", // actor_token_path + "" // actor_token_type + }; + grpc_call_credentials* creds = + grpc_sts_credentials_create(&valid_options, nullptr); + + /* Check security level. */ + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + + RequestMetadataState* state = RequestMetadataState::NewInstance( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token."), + {}); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + /* Cleanup. */ + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_sts_creds_no_actor_token_success(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; + const char expected_creds_debug_string[] = + "StsTokenFetcherCredentials{Path:/v1/" + "token-exchange,Authority:foo.com:5555,OAuth2TokenFetcherCredentials}"; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + char* subject_token_path = write_tmp_jwt_file(test_signed_jwt); + grpc_sts_credentials_options valid_options = { + test_sts_endpoint_url, // sts_endpoint_url + "resource", // resource + "audience", // audience + "scope", // scope + "requested_token_type", // requested_token_type + subject_token_path, // subject_token_path + test_signed_jwt_token_type, // subject_token_type + "", // actor_token_path + "" // actor_token_type + }; + grpc_call_credentials* creds = + grpc_sts_credentials_create(&valid_options, nullptr); + + /* Check security level. */ + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + + /* First request: http put should be called. */ + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + sts_token_httpcli_post_success_no_actor_token); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + /* Second request: the cached token should be served directly. 
*/ + state = RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_free(subject_token_path); +} + +static void test_sts_creds_load_token_failure(void) { + const char expected_creds_debug_string[] = + "StsTokenFetcherCredentials{Path:/v1/" + "token-exchange,Authority:foo.com:5555,OAuth2TokenFetcherCredentials}"; + grpc_core::ExecCtx exec_ctx; + RequestMetadataState* state = RequestMetadataState::NewInstance( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token."), + {}); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + char* test_signed_jwt_path = write_tmp_jwt_file(test_signed_jwt); + grpc_sts_credentials_options options = { + test_sts_endpoint_url, // sts_endpoint_url + "resource", // resource + "audience", // audience + "scope", // scope + "requested_token_type", // requested_token_type + "invalid_path", // subject_token_path + test_signed_jwt_token_type, // subject_token_type + nullptr, // actor_token_path + nullptr // actor_token_type + }; + grpc_call_credentials* creds = grpc_sts_credentials_create(&options, nullptr); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_free(test_signed_jwt_path); +} + +static void test_sts_creds_http_failure(void) { + const char expected_creds_debug_string[] = + "StsTokenFetcherCredentials{Path:/v1/" + "token-exchange,Authority:foo.com:5555,OAuth2TokenFetcherCredentials}"; + grpc_core::ExecCtx exec_ctx; + RequestMetadataState* state = RequestMetadataState::NewInstance( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token."), + {}); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + char* test_signed_jwt_path = write_tmp_jwt_file(test_signed_jwt); + grpc_sts_credentials_options valid_options = { + test_sts_endpoint_url, // sts_endpoint_url + "resource", // resource + "audience", // audience + "scope", // scope + "requested_token_type", // requested_token_type + test_signed_jwt_path, // subject_token_path + test_signed_jwt_token_type, // subject_token_type + nullptr, // actor_token_path + nullptr // actor_token_type + }; + grpc_call_credentials* creds = + grpc_sts_credentials_create(&valid_options, nullptr); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + token_httpcli_post_failure); + state->RunRequestMetadataTest(creds, auth_md_ctx); + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_free(test_signed_jwt_path); +} + +static void validate_jwt_encode_and_sign_params( + const grpc_auth_json_key* json_key, const char* scope, + gpr_timespec token_lifetime) { + GPR_ASSERT(grpc_auth_json_key_is_valid(json_key)); + GPR_ASSERT(json_key->private_key != nullptr); + GPR_ASSERT(RSA_check_key(json_key->private_key)); + GPR_ASSERT(json_key->type != nullptr && + 
strcmp(json_key->type, "service_account") == 0); + GPR_ASSERT(json_key->private_key_id != nullptr && + strcmp(json_key->private_key_id, + "e6b5137873db8d2ef81e06a47289e6434ec8a165") == 0); + GPR_ASSERT(json_key->client_id != nullptr && + strcmp(json_key->client_id, + "777-abaslkan11hlb6nmim3bpspl31ud.apps." + "googleusercontent.com") == 0); + GPR_ASSERT(json_key->client_email != nullptr && + strcmp(json_key->client_email, + "777-abaslkan11hlb6nmim3bpspl31ud@developer." + "gserviceaccount.com") == 0); + if (scope != nullptr) GPR_ASSERT(strcmp(scope, test_scope) == 0); + GPR_ASSERT(gpr_time_cmp(token_lifetime, grpc_max_auth_token_lifetime()) == 0); +} + +static char* encode_and_sign_jwt_success(const grpc_auth_json_key* json_key, + const char* audience, + gpr_timespec token_lifetime, + const char* scope) { + if (strcmp(audience, test_service_url_no_service_name) != 0 && + strcmp(audience, other_test_service_url_no_service_name) != 0) { + return nullptr; + } + validate_jwt_encode_and_sign_params(json_key, scope, token_lifetime); + return gpr_strdup(test_signed_jwt); +} + +static char* encode_and_sign_jwt_failure(const grpc_auth_json_key* json_key, + const char* /*audience*/, + gpr_timespec token_lifetime, + const char* scope) { + validate_jwt_encode_and_sign_params(json_key, scope, token_lifetime); + return nullptr; +} + +static char* encode_and_sign_jwt_should_not_be_called( + const grpc_auth_json_key* /*json_key*/, const char* /*audience*/, + gpr_timespec /*token_lifetime*/, const char* /*scope*/) { + GPR_ASSERT("grpc_jwt_encode_and_sign should not be called" == nullptr); + return nullptr; +} + +static grpc_service_account_jwt_access_credentials* creds_as_jwt( + grpc_call_credentials* creds) { + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(strcmp(creds->type(), GRPC_CALL_CREDENTIALS_TYPE_JWT) == 0); + return reinterpret_cast(creds); +} + +static void test_jwt_creds_lifetime(void) { + char* json_key_string = test_json_key_str(); + const char expected_creds_debug_string_prefix[] = + "JWTAccessCredentials{ExpirationTime:"; + // Max lifetime. + grpc_call_credentials* jwt_creds = + grpc_service_account_jwt_access_credentials_create( + json_key_string, grpc_max_auth_token_lifetime(), nullptr); + GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(), + grpc_max_auth_token_lifetime()) == 0); + /* Check security level. */ + GPR_ASSERT(jwt_creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + GPR_ASSERT(strncmp(expected_creds_debug_string_prefix, + jwt_creds->debug_string().c_str(), + strlen(expected_creds_debug_string_prefix)) == 0); + grpc_call_credentials_release(jwt_creds); + + // Shorter lifetime. + gpr_timespec token_lifetime = {10, 0, GPR_TIMESPAN}; + GPR_ASSERT(gpr_time_cmp(grpc_max_auth_token_lifetime(), token_lifetime) > 0); + jwt_creds = grpc_service_account_jwt_access_credentials_create( + json_key_string, token_lifetime, nullptr); + GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(), + token_lifetime) == 0); + GPR_ASSERT(strncmp(expected_creds_debug_string_prefix, + jwt_creds->debug_string().c_str(), + strlen(expected_creds_debug_string_prefix)) == 0); + grpc_call_credentials_release(jwt_creds); + + // Cropped lifetime. 
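+  // A requested lifetime above grpc_max_auth_token_lifetime() is expected to
+  // be clamped ("cropped") back to the maximum, so the assertion below checks
+  // against grpc_max_auth_token_lifetime() rather than the oversized value.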
+ gpr_timespec add_to_max = {10, 0, GPR_TIMESPAN}; + token_lifetime = gpr_time_add(grpc_max_auth_token_lifetime(), add_to_max); + jwt_creds = grpc_service_account_jwt_access_credentials_create( + json_key_string, token_lifetime, nullptr); + GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(), + grpc_max_auth_token_lifetime()) == 0); + GPR_ASSERT(strncmp(expected_creds_debug_string_prefix, + jwt_creds->debug_string().c_str(), + strlen(expected_creds_debug_string_prefix)) == 0); + grpc_call_credentials_release(jwt_creds); + + gpr_free(json_key_string); +} + +static void test_remove_service_from_jwt_uri(void) { + const char wrong_uri[] = "hello world"; + GPR_ASSERT(!grpc_core::RemoveServiceNameFromJwtUri(wrong_uri).ok()); + const char valid_uri[] = "https://foo.com/get/"; + const char expected_uri[] = "https://foo.com/"; + auto output = grpc_core::RemoveServiceNameFromJwtUri(valid_uri); + GPR_ASSERT(output.ok()); + GPR_ASSERT(strcmp(output->c_str(), expected_uri) == 0); +} + +static void test_jwt_creds_success(void) { + const char expected_creds_debug_string_prefix[] = + "JWTAccessCredentials{ExpirationTime:"; + + char* json_key_string = test_json_key_str(); + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + std::string expected_md_value = absl::StrCat("Bearer ", test_signed_jwt); + std::map emd = { + {"authorization", expected_md_value.c_str()}}; + grpc_call_credentials* creds = + grpc_service_account_jwt_access_credentials_create( + json_key_string, grpc_max_auth_token_lifetime(), nullptr); + + /* First request: jwt_encode_and_sign should be called. */ + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_jwt_encode_and_sign_set_override(encode_and_sign_jwt_success); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + /* Second request: the cached token should be served directly. */ + state = RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_jwt_encode_and_sign_set_override( + encode_and_sign_jwt_should_not_be_called); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + /* Third request: Different service url so jwt_encode_and_sign should be + called again (no caching). 
*/ + state = RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + auth_md_ctx.service_url = other_test_service_url; + grpc_jwt_encode_and_sign_set_override(encode_and_sign_jwt_success); + state->RunRequestMetadataTest(creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(strncmp(expected_creds_debug_string_prefix, + creds->debug_string().c_str(), + strlen(expected_creds_debug_string_prefix)) == 0); + + creds->Unref(); + gpr_free(json_key_string); + grpc_jwt_encode_and_sign_set_override(nullptr); +} + +static void test_jwt_creds_signing_failure(void) { + const char expected_creds_debug_string_prefix[] = + "JWTAccessCredentials{ExpirationTime:"; + char* json_key_string = test_json_key_str(); + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + RequestMetadataState* state = RequestMetadataState::NewInstance( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Could not generate JWT."), {}); + grpc_call_credentials* creds = + grpc_service_account_jwt_access_credentials_create( + json_key_string, grpc_max_auth_token_lifetime(), nullptr); + + grpc_jwt_encode_and_sign_set_override(encode_and_sign_jwt_failure); + state->RunRequestMetadataTest(creds, auth_md_ctx); + + gpr_free(json_key_string); + GPR_ASSERT(strncmp(expected_creds_debug_string_prefix, + creds->debug_string().c_str(), + strlen(expected_creds_debug_string_prefix)) == 0); + + creds->Unref(); + grpc_jwt_encode_and_sign_set_override(nullptr); +} + +static void set_google_default_creds_env_var_with_file_contents( + const char* file_prefix, const char* contents) { + size_t contents_len = strlen(contents); + char* creds_file_name; + FILE* creds_file = gpr_tmpfile(file_prefix, &creds_file_name); + GPR_ASSERT(creds_file_name != nullptr); + GPR_ASSERT(creds_file != nullptr); + GPR_ASSERT(fwrite(contents, 1, contents_len, creds_file) == contents_len); + fclose(creds_file); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, creds_file_name); + gpr_free(creds_file_name); +} + +static bool test_gce_tenancy_checker(void) { + g_test_gce_tenancy_checker_called = true; + return g_test_is_on_gce; +} + +static void test_google_default_creds_auth_key(void) { + grpc_core::ExecCtx exec_ctx; + grpc_composite_channel_credentials* creds; + char* json_key = test_json_key_str(); + grpc_flush_cached_google_default_credentials(); + set_gce_tenancy_checker_for_testing(test_gce_tenancy_checker); + g_test_gce_tenancy_checker_called = false; + g_test_is_on_gce = true; + set_google_default_creds_env_var_with_file_contents( + "json_key_google_default_creds", json_key); + gpr_free(json_key); + creds = reinterpret_cast( + grpc_google_default_credentials_create(nullptr)); + auto* default_creds = + reinterpret_cast( + creds->inner_creds()); + GPR_ASSERT(default_creds->ssl_creds() != nullptr); + auto* jwt = + reinterpret_cast( + creds->call_creds()); + GPR_ASSERT( + strcmp(jwt->key().client_id, + "777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent.com") == + 0); + GPR_ASSERT(g_test_gce_tenancy_checker_called == false); + creds->Unref(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. 
*/ +} + +static void test_google_default_creds_refresh_token(void) { + grpc_core::ExecCtx exec_ctx; + grpc_composite_channel_credentials* creds; + grpc_flush_cached_google_default_credentials(); + set_google_default_creds_env_var_with_file_contents( + "refresh_token_google_default_creds", test_refresh_token_str); + creds = reinterpret_cast( + grpc_google_default_credentials_create(nullptr)); + auto* default_creds = + reinterpret_cast( + creds->inner_creds()); + GPR_ASSERT(default_creds->ssl_creds() != nullptr); + auto* refresh = + reinterpret_cast( + creds->call_creds()); + GPR_ASSERT(strcmp(refresh->refresh_token().client_id, + "32555999999.apps.googleusercontent.com") == 0); + creds->Unref(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ +} + +static void test_google_default_creds_external_account_credentials(void) { + grpc_core::ExecCtx exec_ctx; + grpc_composite_channel_credentials* creds; + grpc_flush_cached_google_default_credentials(); + set_google_default_creds_env_var_with_file_contents( + "google_default_creds_external_account_credentials", + test_external_account_credentials_str); + creds = reinterpret_cast( + grpc_google_default_credentials_create(nullptr)); + auto* default_creds = + reinterpret_cast( + creds->inner_creds()); + GPR_ASSERT(default_creds->ssl_creds() != nullptr); + auto* external = + reinterpret_cast( + creds->call_creds()); + GPR_ASSERT(external != nullptr); + creds->Unref(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ +} + +static void +test_google_default_creds_external_account_credentials_multi_pattern_sts(void) { + grpc_core::ExecCtx exec_ctx; + grpc_composite_channel_credentials* creds; + grpc_flush_cached_google_default_credentials(); + set_google_default_creds_env_var_with_file_contents( + "google_default_creds_external_account_credentials", + test_external_account_credentials_multi_pattern_sts_str); + creds = reinterpret_cast( + grpc_google_default_credentials_create(nullptr)); + auto* default_creds = + reinterpret_cast( + creds->inner_creds()); + GPR_ASSERT(default_creds->ssl_creds() != nullptr); + auto* external = + reinterpret_cast( + creds->call_creds()); + GPR_ASSERT(external != nullptr); + creds->Unref(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ +} + +static void +test_google_default_creds_external_account_credentials_multi_pattern_iam(void) { + grpc_core::ExecCtx exec_ctx; + grpc_composite_channel_credentials* creds; + grpc_flush_cached_google_default_credentials(); + set_google_default_creds_env_var_with_file_contents( + "google_default_creds_external_account_credentials", + test_external_account_credentials_multi_pattern_iam_str); + creds = reinterpret_cast( + grpc_google_default_credentials_create(nullptr)); + auto* default_creds = + reinterpret_cast( + creds->inner_creds()); + GPR_ASSERT(default_creds->ssl_creds() != nullptr); + auto* external = + reinterpret_cast( + creds->call_creds()); + GPR_ASSERT(external != nullptr); + creds->Unref(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. 
*/ +} + +static int default_creds_metadata_server_detection_httpcli_get_success_override( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + *response = http_response(200, ""); + grpc_http_header* headers = + static_cast(gpr_malloc(sizeof(*headers) * 1)); + headers[0].key = gpr_strdup("Metadata-Flavor"); + headers[0].value = gpr_strdup("Google"); + response->hdr_count = 1; + response->hdrs = headers; + GPR_ASSERT(strcmp(request->http.path, "/") == 0); + GPR_ASSERT(strcmp(request->host, "metadata.google.internal.") == 0); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static std::string null_well_known_creds_path_getter(void) { return ""; } + +static void test_google_default_creds_gce(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_flush_cached_google_default_credentials(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ + grpc_override_well_known_credentials_path_getter( + null_well_known_creds_path_getter); + set_gce_tenancy_checker_for_testing(test_gce_tenancy_checker); + g_test_gce_tenancy_checker_called = false; + g_test_is_on_gce = true; + + /* Simulate a successful detection of GCE. */ + grpc_composite_channel_credentials* creds = + reinterpret_cast( + grpc_google_default_credentials_create(nullptr)); + + /* Verify that the default creds actually embeds a GCE creds. */ + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(creds->call_creds() != nullptr); + grpc_httpcli_set_override(compute_engine_httpcli_get_success_override, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds->mutable_call_creds(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + + GPR_ASSERT(g_test_gce_tenancy_checker_called == true); + + /* Cleanup. */ + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); + grpc_override_well_known_credentials_path_getter(nullptr); +} + +static void test_google_default_creds_non_gce(void) { + grpc_core::ExecCtx exec_ctx; + std::map emd = { + {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_flush_cached_google_default_credentials(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ + grpc_override_well_known_credentials_path_getter( + null_well_known_creds_path_getter); + set_gce_tenancy_checker_for_testing(test_gce_tenancy_checker); + g_test_gce_tenancy_checker_called = false; + g_test_is_on_gce = false; + /* Simulate a successful detection of metadata server. */ + grpc_httpcli_set_override( + default_creds_metadata_server_detection_httpcli_get_success_override, + httpcli_post_should_not_be_called); + grpc_composite_channel_credentials* creds = + reinterpret_cast( + grpc_google_default_credentials_create(nullptr)); + /* Verify that the default creds actually embeds a GCE creds. 
*/ + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(creds->call_creds() != nullptr); + grpc_httpcli_set_override(compute_engine_httpcli_get_success_override, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(creds->mutable_call_creds(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(g_test_gce_tenancy_checker_called == true); + /* Cleanup. */ + creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); + grpc_override_well_known_credentials_path_getter(nullptr); +} + +static int default_creds_gce_detection_httpcli_get_failure_override( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + /* No magic header. */ + GPR_ASSERT(strcmp(request->http.path, "/") == 0); + GPR_ASSERT(strcmp(request->host, "metadata.google.internal.") == 0); + *response = http_response(200, ""); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static void test_no_google_default_creds(void) { + grpc_flush_cached_google_default_credentials(); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ + grpc_override_well_known_credentials_path_getter( + null_well_known_creds_path_getter); + set_gce_tenancy_checker_for_testing(test_gce_tenancy_checker); + g_test_gce_tenancy_checker_called = false; + g_test_is_on_gce = false; + grpc_httpcli_set_override( + default_creds_gce_detection_httpcli_get_failure_override, + httpcli_post_should_not_be_called); + /* Simulate a successful detection of GCE. */ + GPR_ASSERT(grpc_google_default_credentials_create(nullptr) == nullptr); + /* Try a second one. GCE detection should occur again. */ + g_test_gce_tenancy_checker_called = false; + GPR_ASSERT(grpc_google_default_credentials_create(nullptr) == nullptr); + GPR_ASSERT(g_test_gce_tenancy_checker_called == true); + /* Cleanup. 
*/ + grpc_override_well_known_credentials_path_getter(nullptr); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_google_default_creds_call_creds_specified(void) { + std::map emd = { + {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_core::ExecCtx exec_ctx; + grpc_flush_cached_google_default_credentials(); + grpc_call_credentials* call_creds = + grpc_google_compute_engine_credentials_create(nullptr); + set_gce_tenancy_checker_for_testing(test_gce_tenancy_checker); + g_test_gce_tenancy_checker_called = false; + g_test_is_on_gce = true; + grpc_httpcli_set_override( + default_creds_metadata_server_detection_httpcli_get_success_override, + httpcli_post_should_not_be_called); + grpc_composite_channel_credentials* channel_creds = + reinterpret_cast( + grpc_google_default_credentials_create(call_creds)); + GPR_ASSERT(g_test_gce_tenancy_checker_called == false); + GPR_ASSERT(channel_creds != nullptr); + GPR_ASSERT(channel_creds->call_creds() != nullptr); + grpc_httpcli_set_override(compute_engine_httpcli_get_success_override, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(channel_creds->mutable_call_creds(), + auth_md_ctx); + + grpc_core::ExecCtx::Get()->Flush(); + channel_creds->Unref(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +struct fake_call_creds : public grpc_call_credentials { + public: + explicit fake_call_creds() : grpc_call_credentials("fake") { + grpc_slice key = grpc_slice_from_static_string("foo"); + grpc_slice value = grpc_slice_from_static_string("oof"); + phony_md_ = grpc_mdelem_from_slices(key, value); + grpc_slice_unref(key); + grpc_slice_unref(value); + } + + ~fake_call_creds() override { GRPC_MDELEM_UNREF(phony_md_); } + + bool get_request_metadata(grpc_polling_entity* /*pollent*/, + grpc_auth_metadata_context /*context*/, + grpc_credentials_mdelem_array* md_array, + grpc_closure* /*on_request_metadata*/, + grpc_error_handle* /*error*/) override { + grpc_credentials_mdelem_array_add(md_array, phony_md_); + return true; + } + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* /*md_array*/, + grpc_error_handle /*error*/) override {} + + private: + grpc_mdelem phony_md_; +}; + +static void test_google_default_creds_not_default(void) { + std::map emd = {{"foo", "oof"}}; + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_core::ExecCtx exec_ctx; + grpc_flush_cached_google_default_credentials(); + grpc_core::RefCountedPtr call_creds = + grpc_core::MakeRefCounted(); + set_gce_tenancy_checker_for_testing(test_gce_tenancy_checker); + g_test_gce_tenancy_checker_called = false; + g_test_is_on_gce = true; + grpc_httpcli_set_override( + default_creds_metadata_server_detection_httpcli_get_success_override, + httpcli_post_should_not_be_called); + grpc_composite_channel_credentials* channel_creds = + reinterpret_cast( + grpc_google_default_credentials_create(call_creds.release())); + GPR_ASSERT(g_test_gce_tenancy_checker_called == false); + GPR_ASSERT(channel_creds != nullptr); + GPR_ASSERT(channel_creds->call_creds() != nullptr); + state->RunRequestMetadataTest(channel_creds->mutable_call_creds(), + auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + channel_creds->Unref(); + 
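+  // Clearing the HTTP override below restores the real grpc_httpcli behavior
+  // so subsequent tests are not affected by this test's fake metadata-server
+  // responses.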
grpc_httpcli_set_override(nullptr, nullptr); +} + +typedef enum { + PLUGIN_INITIAL_STATE, + PLUGIN_GET_METADATA_CALLED_STATE, + PLUGIN_DESTROY_CALLED_STATE +} plugin_state; + +static const std::map plugin_md = {{"foo", "bar"}, + {"hi", "there"}}; + +static int plugin_get_metadata_success( + void* state, grpc_auth_metadata_context context, + grpc_credentials_plugin_metadata_cb /*cb*/, void* /*user_data*/, + grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], + size_t* num_creds_md, grpc_status_code* /*status*/, + const char** /*error_details*/) { + GPR_ASSERT(strcmp(context.service_url, test_service_url) == 0); + GPR_ASSERT(strcmp(context.method_name, test_method) == 0); + GPR_ASSERT(context.channel_auth_context == nullptr); + GPR_ASSERT(context.reserved == nullptr); + GPR_ASSERT(plugin_md.size() < GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX); + plugin_state* s = static_cast(state); + *s = PLUGIN_GET_METADATA_CALLED_STATE; + size_t i = 0; + for (auto const& md : plugin_md) { + memset(&creds_md[i], 0, sizeof(grpc_metadata)); + creds_md[i].key = grpc_slice_from_copied_string(md.first.c_str()); + creds_md[i].value = grpc_slice_from_copied_string(md.second.c_str()); + i += 1; + } + *num_creds_md = plugin_md.size(); + return true; // Synchronous return. +} + +static const char* plugin_error_details = "Could not get metadata for plugin."; + +static int plugin_get_metadata_failure( + void* state, grpc_auth_metadata_context context, + grpc_credentials_plugin_metadata_cb /*cb*/, void* /*user_data*/, + grpc_metadata /*creds_md*/[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], + size_t* /*num_creds_md*/, grpc_status_code* status, + const char** error_details) { + GPR_ASSERT(strcmp(context.service_url, test_service_url) == 0); + GPR_ASSERT(strcmp(context.method_name, test_method) == 0); + GPR_ASSERT(context.channel_auth_context == nullptr); + GPR_ASSERT(context.reserved == nullptr); + plugin_state* s = static_cast(state); + *s = PLUGIN_GET_METADATA_CALLED_STATE; + *status = GRPC_STATUS_UNAUTHENTICATED; + *error_details = gpr_strdup(plugin_error_details); + return true; // Synchronous return. 
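+  // Note: returning true above signals a synchronous result; a plugin doing
+  // asynchronous work would instead return 0 here and later deliver the
+  // metadata (or error) through the `cb` callback parameter.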
+} + +static void plugin_destroy(void* state) { + plugin_state* s = static_cast(state); + *s = PLUGIN_DESTROY_CALLED_STATE; +} + +static char* plugin_debug_string(void* state) { + plugin_state* s = static_cast(state); + char* ret = nullptr; + switch (*s) { + case PLUGIN_INITIAL_STATE: + gpr_asprintf(&ret, "TestPluginCredentials{state:INITIAL}"); + break; + case PLUGIN_GET_METADATA_CALLED_STATE: + gpr_asprintf(&ret, "TestPluginCredentials{state:GET_METADATA_CALLED}"); + break; + case PLUGIN_DESTROY_CALLED_STATE: + gpr_asprintf(&ret, "TestPluginCredentials{state:DESTROY}"); + break; + default: + gpr_asprintf(&ret, "TestPluginCredentials{state:UNKNOWN}"); + break; + } + return ret; +} + +static void test_metadata_plugin_success(void) { + const char expected_creds_debug_string[] = + "TestPluginCredentials{state:GET_METADATA_CALLED}"; + plugin_state state = PLUGIN_INITIAL_STATE; + grpc_metadata_credentials_plugin plugin; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + RequestMetadataState* md_state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, plugin_md); + + plugin.state = &state; + plugin.get_metadata = plugin_get_metadata_success; + plugin.destroy = plugin_destroy; + plugin.debug_string = plugin_debug_string; + + grpc_call_credentials* creds = grpc_metadata_credentials_create_from_plugin( + plugin, GRPC_PRIVACY_AND_INTEGRITY, nullptr); + /* Check security level. */ + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + GPR_ASSERT(state == PLUGIN_INITIAL_STATE); + md_state->RunRequestMetadataTest(creds, auth_md_ctx); + GPR_ASSERT(state == PLUGIN_GET_METADATA_CALLED_STATE); + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + creds->Unref(); + + GPR_ASSERT(state == PLUGIN_DESTROY_CALLED_STATE); +} + +static void test_metadata_plugin_failure(void) { + const char expected_creds_debug_string[] = + "TestPluginCredentials{state:GET_METADATA_CALLED}"; + + plugin_state state = PLUGIN_INITIAL_STATE; + grpc_metadata_credentials_plugin plugin; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + RequestMetadataState* md_state = RequestMetadataState::NewInstance( + GRPC_ERROR_CREATE_FROM_CPP_STRING( + absl::StrCat("Getting metadata from plugin failed with error: ", + plugin_error_details)), + {}); + + plugin.state = &state; + plugin.get_metadata = plugin_get_metadata_failure; + plugin.destroy = plugin_destroy; + plugin.debug_string = plugin_debug_string; + + grpc_call_credentials* creds = grpc_metadata_credentials_create_from_plugin( + plugin, GRPC_PRIVACY_AND_INTEGRITY, nullptr); + GPR_ASSERT(state == PLUGIN_INITIAL_STATE); + md_state->RunRequestMetadataTest(creds, auth_md_ctx); + GPR_ASSERT(state == PLUGIN_GET_METADATA_CALLED_STATE); + GPR_ASSERT( + strcmp(creds->debug_string().c_str(), expected_creds_debug_string) == 0); + creds->Unref(); + + GPR_ASSERT(state == PLUGIN_DESTROY_CALLED_STATE); +} + +static void test_get_well_known_google_credentials_file_path(void) { + char* home = gpr_getenv("HOME"); + bool restore_home_env = false; +#if defined(GRPC_BAZEL_BUILD) && \ + (defined(GPR_POSIX_ENV) || defined(GPR_LINUX_ENV)) + // when running under bazel locally, the HOME variable is not set + // so we set it to some fake value + restore_home_env = true; + gpr_setenv("HOME", "/fake/home/for/bazel"); +#endif /* defined(GRPC_BAZEL_BUILD) && (defined(GPR_POSIX_ENV) || \ + 
defined(GPR_LINUX_ENV)) */ + std::string path = grpc_get_well_known_google_credentials_file_path(); + GPR_ASSERT(!path.empty()); +#if defined(GPR_POSIX_ENV) || defined(GPR_LINUX_ENV) + restore_home_env = true; + gpr_unsetenv("HOME"); + path = grpc_get_well_known_google_credentials_file_path(); + GPR_ASSERT(path.empty()); +#endif /* GPR_POSIX_ENV || GPR_LINUX_ENV */ + if (restore_home_env) { + if (home) { + gpr_setenv("HOME", home); + } else { + gpr_unsetenv("HOME"); + } + } + gpr_free(home); +} + +static void test_channel_creds_duplicate_without_call_creds(void) { + const char expected_creds_debug_string[] = + "AccessTokenCredentials{Token:present}"; + grpc_core::ExecCtx exec_ctx; + + grpc_channel_credentials* channel_creds = + grpc_fake_transport_security_credentials_create(); + + grpc_core::RefCountedPtr dup = + channel_creds->duplicate_without_call_credentials(); + GPR_ASSERT(dup == channel_creds); + dup.reset(); + + grpc_call_credentials* call_creds = + grpc_access_token_credentials_create("blah", nullptr); + grpc_channel_credentials* composite_creds = + grpc_composite_channel_credentials_create(channel_creds, call_creds, + nullptr); + GPR_ASSERT(strcmp(call_creds->debug_string().c_str(), + expected_creds_debug_string) == 0); + + call_creds->Unref(); + dup = composite_creds->duplicate_without_call_credentials(); + GPR_ASSERT(dup == channel_creds); + dup.reset(); + + channel_creds->Unref(); + composite_creds->Unref(); +} + +typedef struct { + const char* url_scheme; + const char* call_host; + const char* call_method; + const char* desired_service_url; + const char* desired_method_name; +} auth_metadata_context_test_case; + +static void test_auth_metadata_context(void) { + auth_metadata_context_test_case test_cases[] = { + // No service nor method. + {"https", "www.foo.com", "", "https://www.foo.com", ""}, + // No method. + {"https", "www.foo.com", "/Service", "https://www.foo.com/Service", ""}, + // Empty service and method. + {"https", "www.foo.com", "//", "https://www.foo.com/", ""}, + // Empty method. + {"https", "www.foo.com", "/Service/", "https://www.foo.com/Service", ""}, + // Malformed url. + {"https", "www.foo.com:", "/Service/", "https://www.foo.com:/Service", + ""}, + // https, default explicit port. + {"https", "www.foo.com:443", "/Service/FooMethod", + "https://www.foo.com/Service", "FooMethod"}, + // https, default implicit port. + {"https", "www.foo.com", "/Service/FooMethod", + "https://www.foo.com/Service", "FooMethod"}, + // https with ipv6 literal, default explicit port. + {"https", "[1080:0:0:0:8:800:200C:417A]:443", "/Service/FooMethod", + "https://[1080:0:0:0:8:800:200C:417A]/Service", "FooMethod"}, + // https with ipv6 literal, default implicit port. + {"https", "[1080:0:0:0:8:800:200C:443]", "/Service/FooMethod", + "https://[1080:0:0:0:8:800:200C:443]/Service", "FooMethod"}, + // https, custom port. + {"https", "www.foo.com:8888", "/Service/FooMethod", + "https://www.foo.com:8888/Service", "FooMethod"}, + // https with ipv6 literal, custom port. + {"https", "[1080:0:0:0:8:800:200C:417A]:8888", "/Service/FooMethod", + "https://[1080:0:0:0:8:800:200C:417A]:8888/Service", "FooMethod"}, + // custom url scheme, https default port. 
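+      // (For a non-https scheme the https default port is not treated as
+      //  implicit, so ":443" is kept in the expected service URL below.)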
+ {"blah", "www.foo.com:443", "/Service/FooMethod", + "blah://www.foo.com:443/Service", "FooMethod"}}; + for (uint32_t i = 0; i < GPR_ARRAY_SIZE(test_cases); i++) { + const char* url_scheme = test_cases[i].url_scheme; + grpc_slice call_host = + grpc_slice_from_copied_string(test_cases[i].call_host); + grpc_slice call_method = + grpc_slice_from_copied_string(test_cases[i].call_method); + grpc_auth_metadata_context auth_md_context; + memset(&auth_md_context, 0, sizeof(auth_md_context)); + grpc_auth_metadata_context_build(url_scheme, call_host, call_method, + nullptr, &auth_md_context); + if (strcmp(auth_md_context.service_url, + test_cases[i].desired_service_url) != 0) { + gpr_log(GPR_ERROR, "Invalid service url, want: %s, got %s.", + test_cases[i].desired_service_url, auth_md_context.service_url); + GPR_ASSERT(false); + } + if (strcmp(auth_md_context.method_name, + test_cases[i].desired_method_name) != 0) { + gpr_log(GPR_ERROR, "Invalid method name, want: %s, got %s.", + test_cases[i].desired_method_name, auth_md_context.method_name); + GPR_ASSERT(false); + } + GPR_ASSERT(auth_md_context.channel_auth_context == nullptr); + grpc_slice_unref(call_host); + grpc_slice_unref(call_method); + grpc_auth_metadata_context_reset(&auth_md_context); + } +} + +static void validate_external_account_creds_token_exchage_request( + const grpc_httpcli_request* request, const char* body, size_t body_size, + bool /*expect_actor_token*/) { + // Check that the body is constructed properly. + GPR_ASSERT(body != nullptr); + GPR_ASSERT(body_size != 0); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + std::string get_url_equivalent = + absl::StrFormat("%s?%s", "https://foo.com:5555/token", body); + absl::StatusOr uri = + grpc_core::URI::Parse(get_url_equivalent); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "%s", uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + assert_query_parameters(*uri, "audience", "audience"); + assert_query_parameters(*uri, "grant_type", + "urn:ietf:params:oauth:grant-type:token-exchange"); + assert_query_parameters(*uri, "requested_token_type", + "urn:ietf:params:oauth:token-type:access_token"); + assert_query_parameters(*uri, "subject_token", "test_subject_token"); + assert_query_parameters(*uri, "subject_token_type", "subject_token_type"); + assert_query_parameters(*uri, "scope", + "https://www.googleapis.com/auth/cloud-platform"); + + // Check the rest of the request. + GPR_ASSERT(strcmp(request->host, "foo.com:5555") == 0); + GPR_ASSERT(strcmp(request->http.path, "/token") == 0); + GPR_ASSERT(request->http.hdr_count == 2); + GPR_ASSERT(strcmp(request->http.hdrs[0].key, "Content-Type") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[0].value, + "application/x-www-form-urlencoded") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].key, "Authorization") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].value, + "Basic Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ=") == 0); +} + +static void +validate_external_account_creds_token_exchage_request_with_url_encode( + const grpc_httpcli_request* request, const char* body, size_t body_size, + bool /*expect_actor_token*/) { + // Check that the body is constructed properly. 
+ GPR_ASSERT(body != nullptr); + GPR_ASSERT(body_size != 0); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + GPR_ASSERT( + strcmp( + std::string(body, body_size).c_str(), + "audience=audience_!%40%23%24&grant_type=urn%3Aietf%3Aparams%3Aoauth%" + "3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%" + "3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&subject_token_type=" + "subject_token_type_!%40%23%24&subject_token=test_subject_token&" + "scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&" + "options=%7B%7D") == 0); + + // Check the rest of the request. + GPR_ASSERT(strcmp(request->host, "foo.com:5555") == 0); + GPR_ASSERT(strcmp(request->http.path, "/token_url_encode") == 0); + GPR_ASSERT(request->http.hdr_count == 2); + GPR_ASSERT(strcmp(request->http.hdrs[0].key, "Content-Type") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[0].value, + "application/x-www-form-urlencoded") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].key, "Authorization") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].value, + "Basic Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ=") == 0); +} + +static void +validate_external_account_creds_service_account_impersonation_request( + const grpc_httpcli_request* request, const char* body, size_t body_size, + bool /*expect_actor_token*/) { + // Check that the body is constructed properly. + GPR_ASSERT(body != nullptr); + GPR_ASSERT(body_size != 0); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + GPR_ASSERT(strcmp(body, "scope=scope_1 scope_2") == 0); + // Check the rest of the request. + GPR_ASSERT(strcmp(request->host, "foo.com:5555") == 0); + GPR_ASSERT(strcmp(request->http.path, "/service_account_impersonation") == 0); + GPR_ASSERT(request->http.hdr_count == 2); + GPR_ASSERT(strcmp(request->http.hdrs[0].key, "Content-Type") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[0].value, + "application/x-www-form-urlencoded") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].key, "Authorization") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].value, + "Bearer token_exchange_access_token") == 0); +} + +static int external_account_creds_httpcli_post_success( + const grpc_httpcli_request* request, const char* body, size_t body_size, + grpc_millis /*deadline*/, grpc_closure* on_done, + grpc_httpcli_response* response) { + if (strcmp(request->http.path, "/token") == 0) { + validate_external_account_creds_token_exchage_request(request, body, + body_size, true); + *response = http_response( + 200, valid_external_account_creds_token_exchange_response); + } else if (strcmp(request->http.path, "/service_account_impersonation") == + 0) { + validate_external_account_creds_service_account_impersonation_request( + request, body, body_size, true); + *response = http_response( + 200, + valid_external_account_creds_service_account_impersonation_response); + } else if (strcmp(request->http.path, "/token_url_encode") == 0) { + validate_external_account_creds_token_exchage_request_with_url_encode( + request, body, body_size, true); + *response = http_response( + 200, valid_external_account_creds_token_exchange_response); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int +external_account_creds_httpcli_post_failure_token_exchange_response_missing_access_token( + const grpc_httpcli_request* request, const char* /*body*/, + size_t /*body_size*/, grpc_millis /*deadline*/, grpc_closure* on_done, + grpc_httpcli_response* response) { + if (strcmp(request->http.path, "/token") == 0) { + *response = http_response(200, + 
"{\"not_access_token\":\"not_access_token\"," + "\"expires_in\":3599," + " \"token_type\":\"Bearer\"}"); + } else if (strcmp(request->http.path, "/service_account_impersonation") == + 0) { + *response = http_response( + 200, + valid_external_account_creds_service_account_impersonation_response); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int url_external_account_creds_httpcli_get_success( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + if (strcmp(request->http.path, "/generate_subject_token_format_text") == 0) { + *response = http_response( + 200, + valid_url_external_account_creds_retrieve_subject_token_response_format_text); + } else if (strcmp(request->http.path, "/path/to/url/creds?p1=v1&p2=v2") == + 0) { + *response = http_response( + 200, + valid_url_external_account_creds_retrieve_subject_token_response_format_text); + } else if (strcmp(request->http.path, + "/generate_subject_token_format_json") == 0) { + *response = http_response( + 200, + valid_url_external_account_creds_retrieve_subject_token_response_format_json); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static void validate_aws_external_account_creds_token_exchage_request( + const grpc_httpcli_request* request, const char* body, size_t body_size, + bool /*expect_actor_token*/) { + // Check that the body is constructed properly. + GPR_ASSERT(body != nullptr); + GPR_ASSERT(body_size != 0); + // Check that the regional_cred_verification_url got constructed + // with the correct AWS Region ("test_regionz" or "test_region"). + GPR_ASSERT(strstr(body, "regional_cred_verification_url_test_region")); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + std::string get_url_equivalent = + absl::StrFormat("%s?%s", "https://foo.com:5555/token", body); + absl::StatusOr uri = + grpc_core::URI::Parse(get_url_equivalent); + GPR_ASSERT(uri.ok()); + assert_query_parameters(*uri, "audience", "audience"); + assert_query_parameters(*uri, "grant_type", + "urn:ietf:params:oauth:grant-type:token-exchange"); + assert_query_parameters(*uri, "requested_token_type", + "urn:ietf:params:oauth:token-type:access_token"); + assert_query_parameters(*uri, "subject_token_type", "subject_token_type"); + assert_query_parameters(*uri, "scope", + "https://www.googleapis.com/auth/cloud-platform"); + // Check the rest of the request. 
+ GPR_ASSERT(strcmp(request->host, "foo.com:5555") == 0); + GPR_ASSERT(strcmp(request->http.path, "/token") == 0); + GPR_ASSERT(request->http.hdr_count == 2); + GPR_ASSERT(strcmp(request->http.hdrs[0].key, "Content-Type") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[0].value, + "application/x-www-form-urlencoded") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].key, "Authorization") == 0); + GPR_ASSERT(strcmp(request->http.hdrs[1].value, + "Basic Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ=") == 0); +} + +static int aws_external_account_creds_httpcli_get_success( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + if (strcmp(request->http.path, "/region_url") == 0) { + *response = http_response(200, "test_regionz"); + } else if (strcmp(request->http.path, "/url") == 0) { + *response = http_response(200, "test_role_name"); + } else if (strcmp(request->http.path, "/url_no_role_name") == 0) { + *response = http_response(200, ""); + } else if (strcmp(request->http.path, "/url/test_role_name") == 0) { + *response = http_response( + 200, valid_aws_external_account_creds_retrieve_signing_keys_response); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int aws_external_account_creds_httpcli_post_success( + const grpc_httpcli_request* request, const char* body, size_t body_size, + grpc_millis /*deadline*/, grpc_closure* on_done, + grpc_httpcli_response* response) { + if (strcmp(request->http.path, "/token") == 0) { + validate_aws_external_account_creds_token_exchage_request(request, body, + body_size, true); + *response = http_response( + 200, valid_external_account_creds_token_exchange_response); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +// The subclass of ExternalAccountCredentials for testing. +// ExternalAccountCredentials is an abstract class so we can't directly test +// against it. +class TestExternalAccountCredentials final + : public grpc_core::ExternalAccountCredentials { + public: + TestExternalAccountCredentials(Options options, + std::vector scopes) + : ExternalAccountCredentials(std::move(options), std::move(scopes)) {} + + protected: + void RetrieveSubjectToken( + HTTPRequestContext* /*ctx*/, const Options& /*options*/, + std::function cb) override { + cb("test_subject_token", GRPC_ERROR_NONE); + } +}; + +static void test_external_account_creds_success(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_core::Json credential_source(""); + TestExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + TestExternalAccountCredentials creds(options, {}); + /* Check security level. */ + GPR_ASSERT(creds.min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + /* First request: http put should be called. 
*/ + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(&creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + /* Second request: the cached token should be served directly. */ + state = RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + state->RunRequestMetadataTest(&creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_external_account_creds_success_with_url_encode(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_core::Json credential_source(""); + TestExternalAccountCredentials::Options options = { + "external_account", // type; + "audience_!@#$", // audience; + "subject_token_type_!@#$", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token_url_encode", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + TestExternalAccountCredentials creds(options, {}); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(&creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void +test_external_account_creds_success_with_service_account_impersonation(void) { + std::map emd = { + {"authorization", "Bearer service_account_impersonation_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_core::Json credential_source(""); + TestExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "https://foo.com:5555/service_account_impersonation", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + TestExternalAccountCredentials creds(options, {"scope_1", "scope_2"}); + /* Check security level. */ + GPR_ASSERT(creds.min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + /* First request: http put should be called. 
*/ + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(&creds, auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_external_account_creds_failure_invalid_token_url(void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_core::Json credential_source(""); + TestExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "https://foo.com:5555/service_account_impersonation", // service_account_impersonation_url; + "invalid_token_url", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + TestExternalAccountCredentials creds(options, {}); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid token url: invalid_token_url."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + state->RunRequestMetadataTest(&creds, auth_md_ctx); + GRPC_ERROR_UNREF(error); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void +test_external_account_creds_failure_invalid_service_account_impersonation_url( + void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_core::Json credential_source(""); + TestExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "invalid_service_account_impersonation_url", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + TestExternalAccountCredentials creds(options, {}); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + external_account_creds_httpcli_post_success); + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid service account impersonation url: " + "invalid_service_account_impersonation_url."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + state->RunRequestMetadataTest(&creds, auth_md_ctx); + GRPC_ERROR_UNREF(error); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void +test_external_account_creds_failure_token_exchange_response_missing_access_token( + void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, 
test_method, + nullptr, nullptr}; + grpc_core::Json credential_source(""); + TestExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "https://foo.com:5555/service_account_impersonation", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + TestExternalAccountCredentials creds(options, {}); + grpc_httpcli_set_override( + httpcli_get_should_not_be_called, + external_account_creds_httpcli_post_failure_token_exchange_response_missing_access_token); + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Missing or invalid access_token in " + "{\"not_access_token\":\"not_access_token\",\"expires_in\":3599,\"token_" + "type\":\"Bearer\"}."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + state->RunRequestMetadataTest(&creds, auth_md_ctx); + GRPC_ERROR_UNREF(error); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_url_external_account_creds_success_format_text(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_url_external_account_creds_options_credential_source_format_text, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::UrlExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(url_external_account_creds_httpcli_get_success, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void +test_url_external_account_creds_success_with_qurey_params_format_text(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + 
valid_url_external_account_creds_options_credential_source_with_qurey_params_format_text, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::UrlExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(url_external_account_creds_httpcli_get_success, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_url_external_account_creds_success_format_json(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_url_external_account_creds_options_credential_source_format_json, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::UrlExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(url_external_account_creds_httpcli_get_success, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void +test_url_external_account_creds_failure_invalid_credential_source_url(void) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + invalid_url_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // 
quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::UrlExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds == nullptr); + std::string actual_error; + GPR_ASSERT( + grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, &actual_error)); + GPR_ASSERT(absl::StartsWith(actual_error, "Invalid credential source url.")); + GRPC_ERROR_UNREF(error); +} + +static void test_file_external_account_creds_success_format_text(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + char* subject_token_path = write_tmp_jwt_file("test_subject_token"); + grpc_core::Json credential_source = grpc_core::Json::Parse( + absl::StrFormat( + "{\"file\":\"%s\"}", + absl::StrReplaceAll(subject_token_path, {{"\\", "\\\\"}})), + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::FileExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); + gpr_free(subject_token_path); +} + +static void test_file_external_account_creds_success_format_json(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + char* subject_token_path = + write_tmp_jwt_file("{\"access_token\":\"test_subject_token\"}"); + grpc_core::Json credential_source = grpc_core::Json::Parse( + absl::StrFormat( + "{\n" + "\"file\":\"%s\",\n" + "\"format\":\n" + "{\n" + "\"type\":\"json\",\n" + "\"subject_token_field_name\":\"access_token\"\n" + "}\n" + "}", + absl::StrReplaceAll(subject_token_path, {{"\\", "\\\\"}})), + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + 
grpc_core::FileExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); + gpr_free(subject_token_path); +} + +static void test_file_external_account_creds_failure_file_not_found(void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = + grpc_core::Json::Parse("{\"file\":\"non_exisiting_file\"}", &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::FileExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Failed to load file"); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); +} + +static void test_file_external_account_creds_failure_invalid_json_content( + void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + char* subject_token_path = write_tmp_jwt_file("not_a_valid_json_file"); + grpc_core::Json credential_source = grpc_core::Json::Parse( + absl::StrFormat( + "{\n" + "\"file\":\"%s\",\n" + "\"format\":\n" + "{\n" + "\"type\":\"json\",\n" + "\"subject_token_field_name\":\"access_token\"\n" + "}\n" + "}", + absl::StrReplaceAll(subject_token_path, {{"\\", "\\\\"}})), + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::FileExternalAccountCredentials::Create(options, {}, &error); + 
GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "The content of the file is not a valid json object."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); + gpr_free(subject_token_path); +} + +static void test_aws_external_account_creds_success(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_aws_external_account_creds_success_path_region_env_keys_url( + void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + gpr_setenv("AWS_REGION", "test_regionz"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + 
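
The surrounding AWS tests only vary where the region comes from: the region_url response, AWS_REGION, or AWS_DEFAULT_REGION. As a summary of the order they exercise (a hypothetical helper, not the implementation, which lives in AwsExternalAccountCredentials): environment variables are assumed to be consulted before the region URL, and the later "duplicate_region" tests pin down that AWS_REGION wins over AWS_DEFAULT_REGION.

std::string RegionForTest(const std::string& region_url_response) {
  char* region = gpr_getenv("AWS_REGION");
  if (region == nullptr) region = gpr_getenv("AWS_DEFAULT_REGION");
  if (region == nullptr) return region_url_response;  // e.g. "test_regionz"
  std::string result(region);
  gpr_free(region);
  return result;
}
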
GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_unsetenv("AWS_REGION"); +} + +static void +test_aws_external_account_creds_success_path_default_region_env_keys_url(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + gpr_setenv("AWS_DEFAULT_REGION", "test_regionz"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_unsetenv("AWS_DEFAULT_REGION"); +} + +static void +test_aws_external_account_creds_success_path_duplicate_region_env_keys_url( + void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + // Make sure that AWS_REGION gets used over AWS_DEFAULT_REGION + gpr_setenv("AWS_REGION", "test_regionz"); + gpr_setenv("AWS_DEFAULT_REGION", "ERROR_REGION"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + 
GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_unsetenv("AWS_REGION"); + gpr_unsetenv("AWS_DEFAULT_REGION"); +} + +static void test_aws_external_account_creds_success_path_region_url_keys_env( + void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + gpr_setenv("AWS_ACCESS_KEY_ID", "test_access_key_id"); + gpr_setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key"); + gpr_setenv("AWS_SESSION_TOKEN", "test_token"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_unsetenv("AWS_ACCESS_KEY_ID"); + gpr_unsetenv("AWS_SECRET_ACCESS_KEY"); + gpr_unsetenv("AWS_SESSION_TOKEN"); +} + +static void test_aws_external_account_creds_success_path_region_env_keys_env( + void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + gpr_setenv("AWS_REGION", "test_regionz"); + gpr_setenv("AWS_ACCESS_KEY_ID", "test_access_key_id"); + gpr_setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key"); + gpr_setenv("AWS_SESSION_TOKEN", "test_token"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", 
// client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_unsetenv("AWS_REGION"); + gpr_unsetenv("AWS_ACCESS_KEY_ID"); + gpr_unsetenv("AWS_SECRET_ACCESS_KEY"); + gpr_unsetenv("AWS_SESSION_TOKEN"); +} + +static void +test_aws_external_account_creds_success_path_default_region_env_keys_env(void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + gpr_setenv("AWS_DEFAULT_REGION", "test_regionz"); + gpr_setenv("AWS_ACCESS_KEY_ID", "test_access_key_id"); + gpr_setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key"); + gpr_setenv("AWS_SESSION_TOKEN", "test_token"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_unsetenv("AWS_DEFAULT_REGION"); + gpr_unsetenv("AWS_ACCESS_KEY_ID"); + gpr_unsetenv("AWS_SECRET_ACCESS_KEY"); + gpr_unsetenv("AWS_SESSION_TOKEN"); +} + +static void +test_aws_external_account_creds_success_path_duplicate_region_env_keys_env( + void) { + std::map emd = { + {"authorization", "Bearer token_exchange_access_token"}}; + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + // Make sure that AWS_REGION gets used over AWS_DEFAULT_REGION + gpr_setenv("AWS_REGION", "test_regionz"); + gpr_setenv("AWS_DEFAULT_REGION", "ERROR_REGION"); + gpr_setenv("AWS_ACCESS_KEY_ID", "test_access_key_id"); + gpr_setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key"); + gpr_setenv("AWS_SESSION_TOKEN", "test_token"); + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json 
credential_source = grpc_core::Json::Parse( + valid_aws_external_account_creds_options_credential_source, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + RequestMetadataState* state = + RequestMetadataState::NewInstance(GRPC_ERROR_NONE, emd); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + gpr_unsetenv("AWS_REGION"); + gpr_unsetenv("AWS_DEFAULT_REGION"); + gpr_unsetenv("AWS_ACCESS_KEY_ID"); + gpr_unsetenv("AWS_SECRET_ACCESS_KEY"); + gpr_unsetenv("AWS_SESSION_TOKEN"); +} + +static void test_aws_external_account_creds_failure_unmatched_environment_id( + void) { + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + invalid_aws_external_account_creds_options_credential_source_unmatched_environment_id, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds == nullptr); + std::string expected_error = "environment_id does not match."; + std::string actual_error; + GPR_ASSERT( + grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, &actual_error)); + GPR_ASSERT(expected_error == actual_error); + GRPC_ERROR_UNREF(error); +} + +static void test_aws_external_account_creds_failure_invalid_region_url(void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + invalid_aws_external_account_creds_options_credential_source_invalid_region_url, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + 
"client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Invalid region url: invalid_region_url."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); +} + +static void test_aws_external_account_creds_failure_invalid_url(void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + invalid_aws_external_account_creds_options_credential_source_invalid_url, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Invalid url: invalid_url."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); +} + +static void test_aws_external_account_creds_failure_missing_role_name(void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + invalid_aws_external_account_creds_options_credential_source_missing_role_name, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + 
"quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Missing role name when retrieving signing keys."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); +} + +static void +test_aws_external_account_creds_failure_invalid_regional_cred_verification_url( + void) { + grpc_core::ExecCtx exec_ctx; + grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, + nullptr, nullptr}; + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json credential_source = grpc_core::Json::Parse( + invalid_aws_external_account_creds_options_credential_source_invalid_regional_cred_verification_url, + &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ExternalAccountCredentials::Options options = { + "external_account", // type; + "audience", // audience; + "subject_token_type", // subject_token_type; + "", // service_account_impersonation_url; + "https://foo.com:5555/token", // token_url; + "https://foo.com:5555/token_info", // token_info_url; + credential_source, // credential_source; + "quota_project_id", // quota_project_id; + "client_id", // client_id; + "client_secret", // client_secret; + "", // workforce_pool_user_project; + }; + auto creds = + grpc_core::AwsExternalAccountCredentials::Create(options, {}, &error); + GPR_ASSERT(creds != nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(creds->min_security_level() == GRPC_PRIVACY_AND_INTEGRITY); + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Creating aws request signer failed."); + grpc_error_handle expected_error = + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error occurred when fetching oauth2 token.", &error, 1); + RequestMetadataState* state = + RequestMetadataState::NewInstance(expected_error, {}); + grpc_httpcli_set_override(aws_external_account_creds_httpcli_get_success, + aws_external_account_creds_httpcli_post_success); + state->RunRequestMetadataTest(creds.get(), auth_md_ctx); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); + GRPC_ERROR_UNREF(error); +} + +static void test_external_account_credentials_create_success(void) { + // url credentials + const char* url_options_string = + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"url\":\"https://foo.com:5555/" + "generate_subject_token_format_json\",\"headers\":{\"Metadata-Flavor\":" + "\"Google\"},\"format\":{\"type\":\"json\",\"subject_token_field_name\":" + 
"\"access_token\"}},\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"; + const char* url_scopes_string = "scope1,scope2"; + grpc_call_credentials* url_creds = grpc_external_account_credentials_create( + url_options_string, url_scopes_string); + GPR_ASSERT(url_creds != nullptr); + url_creds->Unref(); + // file credentials + const char* file_options_string = + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"file\":\"credentials_file_path\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"; + const char* file_scopes_string = "scope1,scope2"; + grpc_call_credentials* file_creds = grpc_external_account_credentials_create( + file_options_string, file_scopes_string); + GPR_ASSERT(file_creds != nullptr); + file_creds->Unref(); + // aws credentials + const char* aws_options_string = + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"environment_id\":\"aws1\"," + "\"region_url\":\"https://foo.com:5555/region_url\",\"url\":\"https://" + "foo.com:5555/url\",\"regional_cred_verification_url\":\"https://" + "foo.com:5555/regional_cred_verification_url_{region}\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"; + const char* aws_scopes_string = "scope1,scope2"; + grpc_call_credentials* aws_creds = grpc_external_account_credentials_create( + aws_options_string, aws_scopes_string); + GPR_ASSERT(aws_creds != nullptr); + aws_creds->Unref(); +} + +static void +test_external_account_credentials_create_failure_invalid_json_format(void) { + const char* options_string = "invalid_json"; + grpc_call_credentials* creds = + grpc_external_account_credentials_create(options_string, ""); + GPR_ASSERT(creds == nullptr); +} + +static void +test_external_account_credentials_create_failure_invalid_options_format(void) { + const char* options_string = "{\"random_key\":\"random_value\"}"; + grpc_call_credentials* creds = + grpc_external_account_credentials_create(options_string, ""); + GPR_ASSERT(creds == nullptr); +} + +static void +test_external_account_credentials_create_failure_invalid_options_credential_source( + void) { + const char* options_string = + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"random_key\":\"random_value\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"; + grpc_call_credentials* creds = + grpc_external_account_credentials_create(options_string, ""); + GPR_ASSERT(creds == nullptr); +} + +static void test_external_account_credentials_create_success_workforce_pool( + void) { + const char* 
url_options_string = + "{\"type\":\"external_account\",\"audience\":\"//iam.googleapis.com/" + "locations/location/workforcePools/pool/providers/provider\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"url\":\"https://foo.com:5555/" + "generate_subject_token_format_json\",\"headers\":{\"Metadata-Flavor\":" + "\"Google\"},\"format\":{\"type\":\"json\",\"subject_token_field_name\":" + "\"access_token\"}},\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\",\"workforce_pool_user_project\":\"workforce_pool_user_" + "project\"}"; + const char* url_scopes_string = "scope1,scope2"; + grpc_call_credentials* url_creds = grpc_external_account_credentials_create( + url_options_string, url_scopes_string); + GPR_ASSERT(url_creds != nullptr); + url_creds->Unref(); +} + +static void +test_external_account_credentials_create_failure_invalid_workforce_pool_audience( + void) { + const char* url_options_string = + "{\"type\":\"external_account\",\"audience\":\"invalid_workforce_pool_" + "audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"url\":\"https://foo.com:5555/" + "generate_subject_token_format_json\",\"headers\":{\"Metadata-Flavor\":" + "\"Google\"},\"format\":{\"type\":\"json\",\"subject_token_field_name\":" + "\"access_token\"}},\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\",\"workforce_pool_user_project\":\"workforce_pool_user_" + "project\"}"; + const char* url_scopes_string = "scope1,scope2"; + grpc_call_credentials* url_creds = grpc_external_account_credentials_create( + url_options_string, url_scopes_string); + GPR_ASSERT(url_creds == nullptr); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_empty_md_array(); + test_add_to_empty_md_array(); + test_add_abunch_to_md_array(); + test_oauth2_token_fetcher_creds_parsing_ok(); + test_oauth2_token_fetcher_creds_parsing_bad_http_status(); + test_oauth2_token_fetcher_creds_parsing_empty_http_body(); + test_oauth2_token_fetcher_creds_parsing_invalid_json(); + test_oauth2_token_fetcher_creds_parsing_missing_token(); + test_oauth2_token_fetcher_creds_parsing_missing_token_type(); + test_oauth2_token_fetcher_creds_parsing_missing_token_lifetime(); + test_google_iam_creds(); + test_access_token_creds(); + test_channel_oauth2_composite_creds(); + test_oauth2_google_iam_composite_creds(); + test_channel_oauth2_google_iam_composite_creds(); + test_compute_engine_creds_success(); + test_compute_engine_creds_failure(); + test_refresh_token_creds_success(); + test_refresh_token_creds_failure(); + test_valid_sts_creds_options(); + test_invalid_sts_creds_options(); + test_sts_creds_success(); + test_sts_creds_no_actor_token_success(); + test_sts_creds_load_token_failure(); + test_sts_creds_http_failure(); + test_sts_creds_token_file_not_found(); + test_jwt_creds_lifetime(); + test_jwt_creds_success(); + test_jwt_creds_signing_failure(); + test_remove_service_from_jwt_uri(); + test_google_default_creds_auth_key(); + 
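// Illustrative aside (not part of the change): the heavily escaped option
// strings passed to grpc_external_account_credentials_create() in the
// test_external_account_credentials_create_* tests defined above are easier
// to read as plain JSON. The AWS variant, using the same placeholder values
// as aws_options_string, corresponds roughly to:
//
//   {
//     "type": "external_account",
//     "audience": "audience",
//     "subject_token_type": "subject_token_type",
//     "service_account_impersonation_url": "service_account_impersonation_url",
//     "token_url": "https://foo.com:5555/token",
//     "token_info_url": "https://foo.com:5555/token_info",
//     "credential_source": {
//       "environment_id": "aws1",
//       "region_url": "https://foo.com:5555/region_url",
//       "url": "https://foo.com:5555/url",
//       "regional_cred_verification_url":
//           "https://foo.com:5555/regional_cred_verification_url_{region}"
//     },
//     "quota_project_id": "quota_project_id",
//     "client_id": "client_id",
//     "client_secret": "client_secret"
//   }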
test_google_default_creds_refresh_token(); + test_google_default_creds_external_account_credentials(); + test_google_default_creds_external_account_credentials_multi_pattern_sts(); + test_google_default_creds_external_account_credentials_multi_pattern_iam(); + test_google_default_creds_gce(); + test_google_default_creds_non_gce(); + test_no_google_default_creds(); + test_google_default_creds_call_creds_specified(); + test_google_default_creds_not_default(); + test_metadata_plugin_success(); + test_metadata_plugin_failure(); + test_get_well_known_google_credentials_file_path(); + test_channel_creds_duplicate_without_call_creds(); + test_auth_metadata_context(); + test_external_account_creds_success(); + test_external_account_creds_success_with_url_encode(); + test_external_account_creds_success_with_service_account_impersonation(); + test_external_account_creds_failure_invalid_token_url(); + test_external_account_creds_failure_invalid_service_account_impersonation_url(); + test_external_account_creds_failure_token_exchange_response_missing_access_token(); + test_url_external_account_creds_success_format_text(); + test_url_external_account_creds_success_format_json(); + test_url_external_account_creds_failure_invalid_credential_source_url(); + test_url_external_account_creds_success_with_qurey_params_format_text(); + test_file_external_account_creds_success_format_text(); + test_file_external_account_creds_success_format_json(); + test_file_external_account_creds_failure_file_not_found(); + test_file_external_account_creds_failure_invalid_json_content(); + test_aws_external_account_creds_success(); + test_aws_external_account_creds_success_path_region_env_keys_url(); + test_aws_external_account_creds_success_path_default_region_env_keys_url(); + test_aws_external_account_creds_success_path_duplicate_region_env_keys_url(); + test_aws_external_account_creds_success_path_region_url_keys_env(); + test_aws_external_account_creds_success_path_region_env_keys_env(); + test_aws_external_account_creds_success_path_default_region_env_keys_env(); + test_aws_external_account_creds_success_path_duplicate_region_env_keys_env(); + test_aws_external_account_creds_failure_unmatched_environment_id(); + test_aws_external_account_creds_failure_invalid_region_url(); + test_aws_external_account_creds_failure_invalid_url(); + test_aws_external_account_creds_failure_missing_role_name(); + test_aws_external_account_creds_failure_invalid_regional_cred_verification_url(); + test_external_account_credentials_create_success(); + test_external_account_credentials_create_failure_invalid_json_format(); + test_external_account_credentials_create_failure_invalid_options_format(); + test_external_account_credentials_create_failure_invalid_options_credential_source(); + test_external_account_credentials_create_success_workforce_pool(); + test_external_account_credentials_create_failure_invalid_workforce_pool_audience(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/evaluate_args_test.cc b/test/core/security/evaluate_args_test.cc new file mode 100644 index 00000000..a8df9b30 --- /dev/null +++ b/test/core/security/evaluate_args_test.cc @@ -0,0 +1,174 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/security/authorization/evaluate_args.h" + +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "test/core/util/evaluate_args_test_util.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { + +class EvaluateArgsTest : public ::testing::Test { + protected: + EvaluateArgsTestUtil util_; +}; + +TEST_F(EvaluateArgsTest, EmptyMetadata) { + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_EQ(args.GetPath(), nullptr); + EXPECT_EQ(args.GetMethod(), nullptr); + EXPECT_EQ(args.GetHost(), nullptr); + EXPECT_THAT(args.GetHeaders(), ::testing::ElementsAre()); + EXPECT_EQ(args.GetHeaderValue("some_key", nullptr), absl::nullopt); +} + +TEST_F(EvaluateArgsTest, GetPathSuccess) { + util_.AddPairToMetadata(":path", "/expected/path"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_EQ(args.GetPath(), "/expected/path"); +} + +TEST_F(EvaluateArgsTest, GetHostSuccess) { + util_.AddPairToMetadata("host", "host123"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_EQ(args.GetHost(), "host123"); +} + +TEST_F(EvaluateArgsTest, GetMethodSuccess) { + util_.AddPairToMetadata(":method", "GET"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_EQ(args.GetMethod(), "GET"); +} + +TEST_F(EvaluateArgsTest, GetHeadersSuccess) { + util_.AddPairToMetadata("host", "host123"); + util_.AddPairToMetadata(":path", "/expected/path"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_THAT(args.GetHeaders(), + ::testing::UnorderedElementsAre( + ::testing::Pair("host", "host123"), + ::testing::Pair(":path", "/expected/path"))); +} + +TEST_F(EvaluateArgsTest, GetHeaderValueSuccess) { + util_.AddPairToMetadata("key123", "value123"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + std::string concatenated_value; + absl::optional value = + args.GetHeaderValue("key123", &concatenated_value); + ASSERT_TRUE(value.has_value()); + EXPECT_EQ(value.value(), "value123"); +} + +TEST_F(EvaluateArgsTest, TestLocalAddressAndPort) { + util_.SetLocalEndpoint("ipv6:[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:456"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + grpc_resolved_address local_address = args.GetLocalAddress(); + EXPECT_EQ(grpc_sockaddr_to_uri(&local_address), + "ipv6:[2001:db8:85a3::8a2e:370:7334]:456"); + EXPECT_EQ(args.GetLocalAddressString(), + "2001:0db8:85a3:0000:0000:8a2e:0370:7334"); + EXPECT_EQ(args.GetLocalPort(), 456); +} + +TEST_F(EvaluateArgsTest, TestPeerAddressAndPort) { + util_.SetPeerEndpoint("ipv4:255.255.255.255:123"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + grpc_resolved_address peer_address = args.GetPeerAddress(); + EXPECT_EQ(grpc_sockaddr_to_uri(&peer_address), "ipv4:255.255.255.255:123"); + EXPECT_EQ(args.GetPeerAddressString(), "255.255.255.255"); + EXPECT_EQ(args.GetPeerPort(), 123); +} + +TEST_F(EvaluateArgsTest, EmptyAuthContext) { + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_TRUE(args.GetTransportSecurityType().empty()); + EXPECT_TRUE(args.GetSpiffeId().empty()); + EXPECT_TRUE(args.GetUriSans().empty()); + EXPECT_TRUE(args.GetDnsSans().empty()); + 
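// A sketch for orientation (not one of the tests in this change): every test
// in this file follows the same pattern -- stage request metadata and/or
// auth-context properties on the EvaluateArgsTestUtil member, build an
// EvaluateArgs, then query it. A hypothetical extra test written the same way
// would look like:
//
//   TEST_F(EvaluateArgsTest, CombinesMetadataAndAuthContext) {
//     util_.AddPairToMetadata(":path", "/pkg.Service/Method");
//     util_.AddPropertyToAuthContext(GRPC_PEER_SPIFFE_ID_PROPERTY_NAME,
//                                    "spiffe://example/backend");
//     EvaluateArgs args = util_.MakeEvaluateArgs();
//     EXPECT_EQ(args.GetPath(), "/pkg.Service/Method");
//     EXPECT_EQ(args.GetSpiffeId(), "spiffe://example/backend");
//   }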
EXPECT_TRUE(args.GetCommonName().empty()); +} + +TEST_F(EvaluateArgsTest, GetTransportSecurityTypeSuccessOneProperty) { + util_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + "ssl"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_EQ(args.GetTransportSecurityType(), "ssl"); +} + +TEST_F(EvaluateArgsTest, GetTransportSecurityTypeFailDuplicateProperty) { + util_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + "type1"); + util_.AddPropertyToAuthContext(GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + "type2"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_TRUE(args.GetTransportSecurityType().empty()); +} + +TEST_F(EvaluateArgsTest, GetSpiffeIdSuccessOneProperty) { + util_.AddPropertyToAuthContext(GRPC_PEER_SPIFFE_ID_PROPERTY_NAME, "id123"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_EQ(args.GetSpiffeId(), "id123"); +} + +TEST_F(EvaluateArgsTest, GetSpiffeIdFailDuplicateProperty) { + util_.AddPropertyToAuthContext(GRPC_PEER_SPIFFE_ID_PROPERTY_NAME, "id123"); + util_.AddPropertyToAuthContext(GRPC_PEER_SPIFFE_ID_PROPERTY_NAME, "id456"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_TRUE(args.GetSpiffeId().empty()); +} + +TEST_F(EvaluateArgsTest, GetUriSanSuccessMultipleProperties) { + util_.AddPropertyToAuthContext(GRPC_PEER_URI_PROPERTY_NAME, "foo"); + util_.AddPropertyToAuthContext(GRPC_PEER_URI_PROPERTY_NAME, "bar"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_THAT(args.GetUriSans(), ::testing::ElementsAre("foo", "bar")); +} + +TEST_F(EvaluateArgsTest, GetDnsSanSuccessMultipleProperties) { + util_.AddPropertyToAuthContext(GRPC_PEER_DNS_PROPERTY_NAME, "foo"); + util_.AddPropertyToAuthContext(GRPC_PEER_DNS_PROPERTY_NAME, "bar"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_THAT(args.GetDnsSans(), ::testing::ElementsAre("foo", "bar")); +} + +TEST_F(EvaluateArgsTest, GetCommonNameSuccessOneProperty) { + util_.AddPropertyToAuthContext(GRPC_X509_CN_PROPERTY_NAME, "server123"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_EQ(args.GetCommonName(), "server123"); +} + +TEST_F(EvaluateArgsTest, GetCommonNameFailDuplicateProperty) { + util_.AddPropertyToAuthContext(GRPC_X509_CN_PROPERTY_NAME, "server123"); + util_.AddPropertyToAuthContext(GRPC_X509_CN_PROPERTY_NAME, "server456"); + EvaluateArgs args = util_.MakeEvaluateArgs(); + EXPECT_TRUE(args.GetCommonName().empty()); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/security/fetch_oauth2.cc b/test/core/security/fetch_oauth2.cc new file mode 100644 index 00000000..b9f08dcd --- /dev/null +++ b/test/core/security/fetch_oauth2.cc @@ -0,0 +1,157 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "src/core/lib/iomgr/error.h"
+#include "src/core/lib/iomgr/load_file.h"
+#include "src/core/lib/security/credentials/credentials.h"
+#include "src/core/lib/security/util/json_util.h"
+#include "src/cpp/client/secure_credentials.h"
+#include "test/core/security/oauth2_utils.h"
+#include "test/core/util/cmdline.h"
+
+static grpc_call_credentials* create_sts_creds(const char* json_file_path) {
+  grpc::experimental::StsCredentialsOptions options;
+  if (strlen(json_file_path) == 0) {
+    auto status = grpc::experimental::StsCredentialsOptionsFromEnv(&options);
+    if (!status.ok()) {
+      gpr_log(GPR_ERROR, "%s", status.error_message().c_str());
+      return nullptr;
+    }
+  } else {
+    grpc_slice sts_options_slice;
+    GPR_ASSERT(GRPC_LOG_IF_ERROR(
+        "load_file", grpc_load_file(json_file_path, 1, &sts_options_slice)));
+    auto status = grpc::experimental::StsCredentialsOptionsFromJson(
+        reinterpret_cast<const char*>(GRPC_SLICE_START_PTR(sts_options_slice)),
+        &options);
+    gpr_slice_unref(sts_options_slice);
+    if (!status.ok()) {
+      gpr_log(GPR_ERROR, "%s", status.error_message().c_str());
+      return nullptr;
+    }
+  }
+  grpc_sts_credentials_options opts =
+      grpc::experimental::StsCredentialsCppToCoreOptions(options);
+  grpc_call_credentials* result = grpc_sts_credentials_create(&opts, nullptr);
+  return result;
+}
+
+static grpc_call_credentials* create_refresh_token_creds(
+    const char* json_refresh_token_file_path) {
+  grpc_slice refresh_token;
+  GPR_ASSERT(GRPC_LOG_IF_ERROR(
+      "load_file",
+      grpc_load_file(json_refresh_token_file_path, 1, &refresh_token)));
+  grpc_call_credentials* result = grpc_google_refresh_token_credentials_create(
+      reinterpret_cast<const char*>(GRPC_SLICE_START_PTR(refresh_token)),
+      nullptr);
+  gpr_slice_unref(refresh_token);
+  return result;
+}
+
+int main(int argc, char** argv) {
+  grpc_call_credentials* creds = nullptr;
+  const char* json_sts_options_file_path = nullptr;
+  const char* json_refresh_token_file_path = nullptr;
+  char* token = nullptr;
+  int use_gce = 0;
+  gpr_cmdline* cl = gpr_cmdline_create("fetch_oauth2");
+  gpr_cmdline_add_string(cl, "json_refresh_token",
+                         "File path of the json refresh token.",
+                         &json_refresh_token_file_path);
+  gpr_cmdline_add_string(
+      cl, "json_sts_options",
+      "File path of the json sts options. If the path is empty, the program "
+      "will attempt to use the $STS_CREDENTIALS environment variable to access "
+      "a file containing the options.",
+      &json_sts_options_file_path);
+  gpr_cmdline_add_flag(
+      cl, "gce",
+      "Get a token from the GCE metadata server (only works in GCE).",
+      &use_gce);
+  gpr_cmdline_parse(cl, argc, argv);
+
+  grpc_init();
+
+  if (json_sts_options_file_path != nullptr &&
+      json_refresh_token_file_path != nullptr) {
+    gpr_log(
+        GPR_ERROR,
+        "--json_sts_options and --json_refresh_token are mutually exclusive.");
+    exit(1);
+  }
+
+  if (use_gce) {
+    if (json_sts_options_file_path != nullptr ||
+        json_refresh_token_file_path != nullptr) {
+      gpr_log(GPR_INFO,
+              "Ignoring json refresh token or sts options to get a token from "
+              "the GCE metadata server.");
+    }
+    creds = grpc_google_compute_engine_credentials_create(nullptr);
+    if (creds == nullptr) {
+      gpr_log(GPR_ERROR, "Could not create gce credentials.");
+      exit(1);
+    }
+  } else if (json_refresh_token_file_path != nullptr) {
+    creds = create_refresh_token_creds(json_refresh_token_file_path);
+    if (creds == nullptr) {
+      gpr_log(GPR_ERROR,
+              "Could not create refresh token creds. 
%s does probably not " + "contain a valid json refresh token.", + json_refresh_token_file_path); + exit(1); + } + } else if (json_sts_options_file_path != nullptr) { + creds = create_sts_creds(json_sts_options_file_path); + if (creds == nullptr) { + gpr_log(GPR_ERROR, + "Could not create sts creds. %s does probably not contain a " + "valid json for sts options.", + json_sts_options_file_path); + exit(1); + } + } else { + gpr_log( + GPR_ERROR, + "Missing --gce, --json_sts_options, or --json_refresh_token option."); + exit(1); + } + GPR_ASSERT(creds != nullptr); + + token = grpc_test_fetch_oauth2_token_with_credentials(creds); + if (token != nullptr) { + printf("Got token: %s.\n", token); + gpr_free(token); + } + grpc_call_credentials_release(creds); + gpr_cmdline_destroy(cl); + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/grpc_alts_credentials_options_test.cc b/test/core/security/grpc_alts_credentials_options_test.cc new file mode 100644 index 00000000..0e4c2fe2 --- /dev/null +++ b/test/core/security/grpc_alts_credentials_options_test.cc @@ -0,0 +1,95 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/security/credentials/alts/grpc_alts_credentials_options.h" + +#include +#include +#include + +#include +#include + +#define ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_1 "abc@google.com" +#define ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_2 "def@google.com" + +const size_t kTargetServiceAccountNum = 2; + +static void test_copy_client_options_failure() { + /* Initialization. */ + grpc_alts_credentials_options* options = + grpc_alts_credentials_client_options_create(); + /* Test. */ + GPR_ASSERT(grpc_alts_credentials_options_copy(nullptr) == nullptr); + /* Cleanup. */ + grpc_alts_credentials_options_destroy(options); +} + +static size_t get_target_service_account_num( + grpc_alts_credentials_options* options) { + auto client_options = + reinterpret_cast(options); + size_t num = 0; + target_service_account* node = client_options->target_account_list_head; + while (node != nullptr) { + num++; + node = node->next; + } + return num; +} + +static void test_client_options_api_success() { + /* Initialization. */ + grpc_alts_credentials_options* options = + grpc_alts_credentials_client_options_create(); + /* Set client options fields. */ + grpc_alts_credentials_client_options_add_target_service_account( + options, ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_1); + grpc_alts_credentials_client_options_add_target_service_account( + options, ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_2); + /* Validate client option fields. 
*/ + GPR_ASSERT(get_target_service_account_num(options) == + kTargetServiceAccountNum); + auto client_options = + reinterpret_cast(options); + GPR_ASSERT(strcmp(client_options->target_account_list_head->data, + ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_2) == 0); + GPR_ASSERT(strcmp(client_options->target_account_list_head->next->data, + ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_1) == 0); + /* Perform a copy operation and validate its correctness. */ + grpc_alts_credentials_options* new_options = + grpc_alts_credentials_options_copy(options); + GPR_ASSERT(get_target_service_account_num(new_options) == + kTargetServiceAccountNum); + auto new_client_options = + reinterpret_cast(new_options); + GPR_ASSERT(strcmp(new_client_options->target_account_list_head->data, + ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_2) == 0); + GPR_ASSERT(strcmp(new_client_options->target_account_list_head->next->data, + ALTS_CLIENT_OPTIONS_TEST_TARGET_SERVICE_ACCOUNT_1) == 0); + /* Cleanup.*/ + grpc_alts_credentials_options_destroy(options); + grpc_alts_credentials_options_destroy(new_options); +} + +int main(int /*argc*/, char** /*argv*/) { + /* Test. */ + test_copy_client_options_failure(); + test_client_options_api_success(); + return 0; +} diff --git a/test/core/security/grpc_authorization_engine_test.cc b/test/core/security/grpc_authorization_engine_test.cc new file mode 100644 index 00000000..f85482e1 --- /dev/null +++ b/test/core/security/grpc_authorization_engine_test.cc @@ -0,0 +1,115 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
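// For orientation, a minimal sketch of the API exercised by the tests below
// (shapes inferred from this file rather than from the headers): build an
// RBAC with a single catch-all policy, hand it to the engine, and ask for a
// decision.
//
//   std::map<std::string, Rbac::Policy> policies;
//   policies["allow_all"] =
//       Rbac::Policy(Rbac::Permission(Rbac::Permission::RuleType::kAny),
//                    Rbac::Principal(Rbac::Principal::RuleType::kAny));
//   GrpcAuthorizationEngine engine(
//       Rbac(Rbac::Action::kAllow, std::move(policies)));
//   AuthorizationEngine::Decision decision =
//       engine.Evaluate(EvaluateArgs(nullptr, nullptr));
//   // decision.type == Decision::Type::kAllow and
//   // decision.matching_policy_name == "allow_all" when the catch-all
//   // policy matches.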
+ +#include + +#include "src/core/lib/security/authorization/grpc_authorization_engine.h" + +#include +#include + +namespace grpc_core { + +TEST(GrpcAuthorizationEngineTest, AllowEngineWithMatchingPolicy) { + Rbac::Policy policy1( + Rbac::Permission(Rbac::Permission::RuleType::kNot, + Rbac::Permission(Rbac::Permission::RuleType::kAny)), + Rbac::Principal(Rbac::Principal::RuleType::kNot, + Rbac::Principal(Rbac::Principal::RuleType::kAny))); + Rbac::Policy policy2((Rbac::Permission(Rbac::Permission::RuleType::kAny)), + (Rbac::Principal(Rbac::Principal::RuleType::kAny))); + std::map policies; + policies["policy1"] = std::move(policy1); + policies["policy2"] = std::move(policy2); + Rbac rbac(Rbac::Action::kAllow, std::move(policies)); + GrpcAuthorizationEngine engine(std::move(rbac)); + AuthorizationEngine::Decision decision = + engine.Evaluate(EvaluateArgs(nullptr, nullptr)); + EXPECT_EQ(decision.type, AuthorizationEngine::Decision::Type::kAllow); + EXPECT_EQ(decision.matching_policy_name, "policy2"); +} + +TEST(GrpcAuthorizationEngineTest, AllowEngineWithNoMatchingPolicy) { + Rbac::Policy policy1( + Rbac::Permission(Rbac::Permission::RuleType::kNot, + Rbac::Permission(Rbac::Permission::RuleType::kAny)), + Rbac::Principal(Rbac::Principal::RuleType::kNot, + Rbac::Principal(Rbac::Principal::RuleType::kAny))); + std::map policies; + policies["policy1"] = std::move(policy1); + Rbac rbac(Rbac::Action::kAllow, std::move(policies)); + GrpcAuthorizationEngine engine(std::move(rbac)); + AuthorizationEngine::Decision decision = + engine.Evaluate(EvaluateArgs(nullptr, nullptr)); + EXPECT_EQ(decision.type, AuthorizationEngine::Decision::Type::kDeny); + EXPECT_TRUE(decision.matching_policy_name.empty()); +} + +TEST(GrpcAuthorizationEngineTest, AllowEngineWithEmptyPolicies) { + GrpcAuthorizationEngine engine(Rbac::Action::kAllow); + AuthorizationEngine::Decision decision = + engine.Evaluate(EvaluateArgs(nullptr, nullptr)); + EXPECT_EQ(decision.type, AuthorizationEngine::Decision::Type::kDeny); + EXPECT_TRUE(decision.matching_policy_name.empty()); +} + +TEST(GrpcAuthorizationEngineTest, DenyEngineWithMatchingPolicy) { + Rbac::Policy policy1( + Rbac::Permission(Rbac::Permission::RuleType::kNot, + Rbac::Permission(Rbac::Permission::RuleType::kAny)), + Rbac::Principal(Rbac::Principal::RuleType::kNot, + Rbac::Principal(Rbac::Principal::RuleType::kAny))); + Rbac::Policy policy2((Rbac::Permission(Rbac::Permission::RuleType::kAny)), + (Rbac::Principal(Rbac::Principal::RuleType::kAny))); + std::map policies; + policies["policy1"] = std::move(policy1); + policies["policy2"] = std::move(policy2); + Rbac rbac(Rbac::Action::kDeny, std::move(policies)); + GrpcAuthorizationEngine engine(std::move(rbac)); + AuthorizationEngine::Decision decision = + engine.Evaluate(EvaluateArgs(nullptr, nullptr)); + EXPECT_EQ(decision.type, AuthorizationEngine::Decision::Type::kDeny); + EXPECT_EQ(decision.matching_policy_name, "policy2"); +} + +TEST(GrpcAuthorizationEngineTest, DenyEngineWithNoMatchingPolicy) { + Rbac::Policy policy1( + Rbac::Permission(Rbac::Permission::RuleType::kNot, + Rbac::Permission(Rbac::Permission::RuleType::kAny)), + Rbac::Principal(Rbac::Principal::RuleType::kNot, + Rbac::Principal(Rbac::Principal::RuleType::kAny))); + std::map policies; + policies["policy1"] = std::move(policy1); + Rbac rbac(Rbac::Action::kDeny, std::move(policies)); + GrpcAuthorizationEngine engine(std::move(rbac)); + AuthorizationEngine::Decision decision = + engine.Evaluate(EvaluateArgs(nullptr, nullptr)); + EXPECT_EQ(decision.type, 
AuthorizationEngine::Decision::Type::kAllow); + EXPECT_TRUE(decision.matching_policy_name.empty()); +} + +TEST(GrpcAuthorizationEngineTest, DenyEngineWithEmptyPolicies) { + GrpcAuthorizationEngine engine(Rbac::Action::kDeny); + AuthorizationEngine::Decision decision = + engine.Evaluate(EvaluateArgs(nullptr, nullptr)); + EXPECT_EQ(decision.type, AuthorizationEngine::Decision::Type::kAllow); + EXPECT_TRUE(decision.matching_policy_name.empty()); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/security/grpc_authorization_policy_provider_test.cc b/test/core/security/grpc_authorization_policy_provider_test.cc new file mode 100644 index 00000000..3f1c87c8 --- /dev/null +++ b/test/core/security/grpc_authorization_policy_provider_test.cc @@ -0,0 +1,219 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/security/authorization/grpc_authorization_policy_provider.h" + +#include +#include + +#include + +#include "src/core/lib/security/authorization/grpc_authorization_engine.h" +#include "test/core/util/test_config.h" +#include "test/core/util/tls_utils.h" + +#define VALID_POLICY_PATH_1 \ + "test/core/security/authorization/test_policies/valid_policy_1.json" +#define VALID_POLICY_PATH_2 \ + "test/core/security/authorization/test_policies/valid_policy_2.json" +#define INVALID_POLICY_PATH \ + "test/core/security/authorization/test_policies/invalid_policy.json" + +namespace grpc_core { + +TEST(AuthorizationPolicyProviderTest, StaticDataInitializationSuccessful) { + auto provider = StaticDataAuthorizationPolicyProvider::Create( + testing::GetFileContents(VALID_POLICY_PATH_1)); + ASSERT_TRUE(provider.ok()); + auto engines = (*provider)->engines(); + auto* allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 1); + auto* deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 1); +} + +TEST(AuthorizationPolicyProviderTest, + StaticDataInitializationFailedInvalidPolicy) { + auto provider = StaticDataAuthorizationPolicyProvider::Create( + testing::GetFileContents(INVALID_POLICY_PATH)); + EXPECT_EQ(provider.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(provider.status().message(), "\"name\" field is not present."); +} + +TEST(AuthorizationPolicyProviderTest, + FileWatcherInitializationSuccessValidPolicy) { + auto tmp_authz_policy = absl::make_unique( + testing::GetFileContents(VALID_POLICY_PATH_1)); + auto provider = FileWatcherAuthorizationPolicyProvider::Create( + tmp_authz_policy->name(), /*refresh_interval_sec=*/1); + ASSERT_TRUE(provider.ok()); + auto engines = (*provider)->engines(); + auto* allow_engine = + 
dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 1); + auto* deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 1); +} + +TEST(AuthorizationPolicyProviderTest, + FileWatcherInitializationFailedInvalidPolicy) { + auto tmp_authz_policy = absl::make_unique( + testing::GetFileContents(INVALID_POLICY_PATH)); + auto provider = FileWatcherAuthorizationPolicyProvider::Create( + tmp_authz_policy->name(), /*refresh_interval_sec=*/1); + EXPECT_EQ(provider.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(provider.status().message(), "\"name\" field is not present."); +} + +TEST(AuthorizationPolicyProviderTest, FileWatcherSuccessValidPolicyRefresh) { + auto tmp_authz_policy = absl::make_unique( + testing::GetFileContents(VALID_POLICY_PATH_1)); + auto provider = FileWatcherAuthorizationPolicyProvider::Create( + tmp_authz_policy->name(), /*refresh_interval_sec=*/1); + ASSERT_TRUE(provider.ok()); + auto engines = (*provider)->engines(); + auto* allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 1); + auto* deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 1); + // Rewrite the file with a different valid authorization policy. + tmp_authz_policy->RewriteFile(testing::GetFileContents(VALID_POLICY_PATH_2)); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + engines = (*provider)->engines(); + allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 2); + deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 0); +} + +TEST(AuthorizationPolicyProviderTest, + FileWatcherInvalidPolicyRefreshSkipReload) { + auto tmp_authz_policy = absl::make_unique( + testing::GetFileContents(VALID_POLICY_PATH_1)); + auto provider = FileWatcherAuthorizationPolicyProvider::Create( + tmp_authz_policy->name(), /*refresh_interval_sec=*/1); + ASSERT_TRUE(provider.ok()); + auto engines = (*provider)->engines(); + auto* allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 1); + auto* deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 1); + // Skips the following policy update, and continues to use the valid policy. + tmp_authz_policy->RewriteFile(testing::GetFileContents(INVALID_POLICY_PATH)); + // Wait 2 seconds for the provider's refresh thread to read the updated files. 
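// (Timing assumption: refresh_interval_sec is 1 above, so a 2-second sleep
// should cover at least one full refresh cycle; the extra second is
// presumably margin against flakiness on slow machines.)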
+ gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + engines = (*provider)->engines(); + allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 1); + deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 1); +} + +TEST(AuthorizationPolicyProviderTest, FileWatcherRecoversFromFailure) { + auto tmp_authz_policy = absl::make_unique( + testing::GetFileContents(VALID_POLICY_PATH_1)); + auto provider = FileWatcherAuthorizationPolicyProvider::Create( + tmp_authz_policy->name(), /*refresh_interval_sec=*/1); + ASSERT_TRUE(provider.ok()); + auto engines = (*provider)->engines(); + auto* allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 1); + auto* deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 1); + // Skips the following policy update, and continues to use the valid policy. + tmp_authz_policy->RewriteFile(testing::GetFileContents(INVALID_POLICY_PATH)); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + engines = (*provider)->engines(); + allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 1); + deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 1); + // Rewrite the file with a valid authorization policy. + tmp_authz_policy->RewriteFile(testing::GetFileContents(VALID_POLICY_PATH_2)); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + engines = (*provider)->engines(); + allow_engine = + dynamic_cast(engines.allow_engine.get()); + ASSERT_NE(allow_engine, nullptr); + EXPECT_EQ(allow_engine->action(), Rbac::Action::kAllow); + EXPECT_EQ(allow_engine->num_policies(), 2); + deny_engine = + dynamic_cast(engines.deny_engine.get()); + ASSERT_NE(deny_engine, nullptr); + EXPECT_EQ(deny_engine->action(), Rbac::Action::kDeny); + EXPECT_EQ(deny_engine->num_policies(), 0); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/security/grpc_tls_certificate_distributor_test.cc b/test/core/security/grpc_tls_certificate_distributor_test.cc new file mode 100644 index 00000000..24a25a1b --- /dev/null +++ b/test/core/security/grpc_tls_certificate_distributor_test.cc @@ -0,0 +1,950 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_distributor.h" + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" +#include "test/core/util/tls_utils.h" + +namespace grpc_core { + +namespace testing { + +constexpr const char* kCertName1 = "cert_1_name"; +constexpr const char* kCertName2 = "cert_2_name"; +constexpr const char* kRootCert1Name = "root_cert_1_name"; +constexpr const char* kRootCert1Contents = "root_cert_1_contents"; +constexpr const char* kRootCert2Name = "root_cert_2_name"; +constexpr const char* kRootCert2Contents = "root_cert_2_contents"; +constexpr const char* kIdentityCert1Name = "identity_cert_1_name"; +constexpr const char* kIdentityCert1PrivateKey = "identity_private_key_1"; +constexpr const char* kIdentityCert1Contents = "identity_cert_1_contents"; +constexpr const char* kIdentityCert2Name = "identity_cert_2_name"; +constexpr const char* kIdentityCert2PrivateKey = "identity_private_key_2"; +constexpr const char* kIdentityCert2Contents = "identity_cert_2_contents"; +constexpr const char* kErrorMessage = "error_message"; +constexpr const char* kRootErrorMessage = "root_error_message"; +constexpr const char* kIdentityErrorMessage = "identity_error_message"; + +class GrpcTlsCertificateDistributorTest : public ::testing::Test { + protected: + // Forward declaration. + class TlsCertificatesTestWatcher; + + // CredentialInfo contains the parameters when calling OnCertificatesChanged + // of a watcher. When OnCertificatesChanged is invoked, we will push a + // CredentialInfo to the cert_update_queue of state_, and check in each test + // if the status updates are correct. + struct CredentialInfo { + std::string root_certs; + PemKeyCertPairList key_cert_pairs; + CredentialInfo(std::string root, PemKeyCertPairList key_cert) + : root_certs(std::move(root)), key_cert_pairs(std::move(key_cert)) {} + bool operator==(const CredentialInfo& other) const { + return root_certs == other.root_certs && + key_cert_pairs == other.key_cert_pairs; + } + }; + + // ErrorInfo contains the parameters when calling OnError of a watcher. When + // OnError is invoked, we will push a ErrorInfo to the error_queue of state_, + // and check in each test if the status updates are correct. + struct ErrorInfo { + std::string root_cert_str; + std::string identity_cert_str; + ErrorInfo(std::string root, std::string identity) + : root_cert_str(std::move(root)), + identity_cert_str(std::move(identity)) {} + bool operator==(const ErrorInfo& other) const { + return root_cert_str == other.root_cert_str && + identity_cert_str == other.identity_cert_str; + } + }; + + struct WatcherState { + TlsCertificatesTestWatcher* watcher = nullptr; + std::deque cert_update_queue; + std::deque error_queue; + + std::deque GetCredentialQueue() { + // We move the data member value so the data member will be re-initiated + // with size 0, and ready for the next check. 
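// (Strictly speaking a moved-from std::deque is only guaranteed to be in a
// valid but unspecified state; the intent is simply to drain the queue
// between checks, and common implementations do leave it empty.)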
+ return std::move(cert_update_queue); + } + std::deque GetErrorQueue() { + // We move the data member value so the data member will be re-initiated + // with size 0, and ready for the next check. + return std::move(error_queue); + } + }; + + class TlsCertificatesTestWatcher : public grpc_tls_certificate_distributor:: + TlsCertificatesWatcherInterface { + public: + // ctor sets state->watcher to this. + explicit TlsCertificatesTestWatcher(WatcherState* state) : state_(state) { + state_->watcher = this; + } + + // dtor sets state->watcher to nullptr. + ~TlsCertificatesTestWatcher() override { state_->watcher = nullptr; } + + void OnCertificatesChanged( + absl::optional root_certs, + absl::optional key_cert_pairs) override { + std::string updated_root; + if (root_certs.has_value()) { + updated_root = std::string(*root_certs); + } + PemKeyCertPairList updated_identity; + if (key_cert_pairs.has_value()) { + updated_identity = std::move(*key_cert_pairs); + } + state_->cert_update_queue.emplace_back(std::move(updated_root), + std::move(updated_identity)); + } + + void OnError(grpc_error_handle root_cert_error, + grpc_error_handle identity_cert_error) override { + GPR_ASSERT(root_cert_error != GRPC_ERROR_NONE || + identity_cert_error != GRPC_ERROR_NONE); + std::string root_error_str; + std::string identity_error_str; + if (root_cert_error != GRPC_ERROR_NONE) { + GPR_ASSERT(grpc_error_get_str( + root_cert_error, GRPC_ERROR_STR_DESCRIPTION, &root_error_str)); + } + if (identity_cert_error != GRPC_ERROR_NONE) { + GPR_ASSERT(grpc_error_get_str(identity_cert_error, + GRPC_ERROR_STR_DESCRIPTION, + &identity_error_str)); + } + state_->error_queue.emplace_back(std::move(root_error_str), + std::move(identity_error_str)); + GRPC_ERROR_UNREF(root_cert_error); + GRPC_ERROR_UNREF(identity_cert_error); + } + + private: + WatcherState* state_; + }; + + // CallbackStatus contains the parameters when calling watch_status_callback_ + // of the distributor. When a particular callback is invoked, we will push a + // CallbackStatus to a callback_queue_, and check in each test if the status + // updates are correct. + struct CallbackStatus { + std::string cert_name; + bool root_being_watched; + bool identity_being_watched; + CallbackStatus(std::string name, bool root_watched, bool identity_watched) + : cert_name(std::move(name)), + root_being_watched(root_watched), + identity_being_watched(identity_watched) {} + bool operator==(const CallbackStatus& other) const { + return cert_name == other.cert_name && + root_being_watched == other.root_being_watched && + identity_being_watched == other.identity_being_watched; + } + }; + + void SetUp() override { + distributor_.SetWatchStatusCallback([this](std::string cert_name, + bool root_being_watched, + bool identity_being_watched) { + callback_queue_.emplace_back(std::move(cert_name), root_being_watched, + identity_being_watched); + }); + } + + WatcherState* MakeWatcher(absl::optional root_cert_name, + absl::optional identity_cert_name) { + MutexLock lock(&mu_); + watchers_.emplace_back(); + // TlsCertificatesTestWatcher ctor takes a pointer to the WatcherState. + // It sets WatcherState::watcher to point to itself. + // The TlsCertificatesTestWatcher dtor will set WatcherState::watcher back + // to nullptr to indicate that it's been destroyed. 
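// For orientation, a minimal sketch of the registration flow exercised here
// (LoggingWatcher and distributor are illustrative names; the method shapes
// follow their use in this file):
//
//   class LoggingWatcher : public grpc_tls_certificate_distributor::
//                              TlsCertificatesWatcherInterface {
//    public:
//     void OnCertificatesChanged(
//         absl::optional<absl::string_view> root_certs,
//         absl::optional<PemKeyCertPairList> key_cert_pairs) override {
//       // React to new root certs and/or identity key-cert pairs here.
//     }
//     void OnError(grpc_error_handle root_cert_error,
//                  grpc_error_handle identity_cert_error) override {
//       GRPC_ERROR_UNREF(root_cert_error);
//       GRPC_ERROR_UNREF(identity_cert_error);
//     }
//   };
//
//   // Ownership of the watcher passes to the distributor, which is why the
//   // tests keep only a raw pointer in WatcherState for later cancellation.
//   distributor.WatchTlsCertificates(absl::make_unique<LoggingWatcher>(),
//                                    /*root_cert_name=*/"root",
//                                    /*identity_cert_name=*/"identity");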
+ auto watcher = + absl::make_unique(&watchers_.back()); + distributor_.WatchTlsCertificates(std::move(watcher), + std::move(root_cert_name), + std::move(identity_cert_name)); + return &watchers_.back(); + } + + void CancelWatch(WatcherState* state) { + MutexLock lock(&mu_); + distributor_.CancelTlsCertificatesWatch(state->watcher); + EXPECT_EQ(state->watcher, nullptr); + } + + std::deque GetCallbackQueue() { + // We move the data member value so the data member will be re-initiated + // with size 0, and ready for the next check. + return std::move(callback_queue_); + } + + grpc_tls_certificate_distributor distributor_; + // Use a std::list<> here to avoid the address invalidation caused by internal + // reallocation of std::vector<>. + std::list watchers_; + std::deque callback_queue_; + // This is to make watchers_ and callback_queue_ thread-safe. + Mutex mu_; +}; + +TEST_F(GrpcTlsCertificateDistributorTest, BasicCredentialBehaviors) { + EXPECT_FALSE(distributor_.HasRootCerts(kRootCert1Name)); + EXPECT_FALSE(distributor_.HasKeyCertPairs(kIdentityCert1Name)); + // After setting the certificates to the corresponding cert names, the + // distributor should possess the corresponding certs. + distributor_.SetKeyMaterials(kRootCert1Name, kRootCert1Contents, + absl::nullopt); + EXPECT_TRUE(distributor_.HasRootCerts(kRootCert1Name)); + distributor_.SetKeyMaterials( + kIdentityCert1Name, absl::nullopt, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + EXPECT_TRUE(distributor_.HasKeyCertPairs(kIdentityCert1Name)); + // Querying a non-existing cert name should return false. + EXPECT_FALSE(distributor_.HasRootCerts(kRootCert2Name)); + EXPECT_FALSE(distributor_.HasKeyCertPairs(kIdentityCert2Name)); +} + +TEST_F(GrpcTlsCertificateDistributorTest, UpdateCredentialsOnAnySide) { + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, true))); + // SetKeyMaterials should trigger watcher's OnCertificatesChanged method. + distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + // Set root certs should trigger watcher's OnCertificatesChanged again. + distributor_.SetKeyMaterials(kCertName1, kRootCert2Contents, absl::nullopt); + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert2Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + // Set identity certs should trigger watcher's OnCertificatesChanged again. + distributor_.SetKeyMaterials( + kCertName1, absl::nullopt, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)); + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert2Contents, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)))); + CancelWatch(watcher_state_1); +} + +TEST_F(GrpcTlsCertificateDistributorTest, SameIdentityNameDiffRootName) { + // Register watcher 1. + WatcherState* watcher_state_1 = + MakeWatcher(kRootCert1Name, kIdentityCert1Name); + EXPECT_THAT( + GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kRootCert1Name, true, false), + CallbackStatus(kIdentityCert1Name, false, true))); + // Register watcher 2. 
+ WatcherState* watcher_state_2 = + MakeWatcher(kRootCert2Name, kIdentityCert1Name); + EXPECT_THAT(GetCallbackQueue(), ::testing::ElementsAre(CallbackStatus( + kRootCert2Name, true, false))); + // Push credential updates to kRootCert1Name and check if the status works as + // expected. + distributor_.SetKeyMaterials(kRootCert1Name, kRootCert1Contents, + absl::nullopt); + // Check the updates are delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert1Contents, {}))); + // Push credential updates to kRootCert2Name. + distributor_.SetKeyMaterials(kRootCert2Name, kRootCert2Contents, + absl::nullopt); + // Check the updates are delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert2Contents, {}))); + // Push credential updates to kIdentityCert1Name and check if the status works + // as expected. + distributor_.SetKeyMaterials( + kIdentityCert1Name, absl::nullopt, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Check the updates are delivered to watcher 1 and watcher 2. + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + EXPECT_THAT( + watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert2Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + // Cancel watcher 1. + CancelWatch(watcher_state_1); + EXPECT_THAT(GetCallbackQueue(), ::testing::ElementsAre(CallbackStatus( + kRootCert1Name, false, false))); + // Cancel watcher 2. + CancelWatch(watcher_state_2); + EXPECT_THAT( + GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kRootCert2Name, false, false), + CallbackStatus(kIdentityCert1Name, false, false))); +} + +TEST_F(GrpcTlsCertificateDistributorTest, SameRootNameDiffIdentityName) { + // Register watcher 1. + WatcherState* watcher_state_1 = + MakeWatcher(kRootCert1Name, kIdentityCert1Name); + EXPECT_THAT( + GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kRootCert1Name, true, false), + CallbackStatus(kIdentityCert1Name, false, true))); + // Register watcher 2. + WatcherState* watcher_state_2 = + MakeWatcher(kRootCert1Name, kIdentityCert2Name); + EXPECT_THAT(GetCallbackQueue(), ::testing::ElementsAre(CallbackStatus( + kIdentityCert2Name, false, true))); + // Push credential updates to kRootCert1Name and check if the status works as + // expected. + distributor_.SetKeyMaterials(kRootCert1Name, kRootCert1Contents, + absl::nullopt); + // Check the updates are delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert1Contents, {}))); + // Check the updates are delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert1Contents, {}))); + // Push credential updates to SetKeyMaterials. + distributor_.SetKeyMaterials( + kIdentityCert1Name, absl::nullopt, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Check the updates are delivered to watcher 1. + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + // Push credential updates to kIdentityCert2Name. 
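// (MakeCertKeyPairs comes from test/core/util/tls_utils.h; judging from its
// use throughout this file it builds a one-entry PemKeyCertPairList from the
// given private key and certificate strings.)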
+ distributor_.SetKeyMaterials( + kIdentityCert2Name, absl::nullopt, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)); + // Check the updates are delivered to watcher 2. + EXPECT_THAT( + watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)))); + // Cancel watcher 1. + CancelWatch(watcher_state_1); + EXPECT_THAT(GetCallbackQueue(), ::testing::ElementsAre(CallbackStatus( + kIdentityCert1Name, false, false))); + // Cancel watcher 2. + CancelWatch(watcher_state_2); + EXPECT_THAT( + GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kRootCert1Name, false, false), + CallbackStatus(kIdentityCert2Name, false, false))); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + AddAndCancelFirstWatcherForSameRootAndIdentityCertName) { + // Register watcher 1 watching kCertName1 for both root and identity certs. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, true))); + // Push credential updates to kCertName1 and check if the status works as + // expected. + distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Check the updates are delivered to watcher 1. + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + // Cancel watcher 1. + CancelWatch(watcher_state_1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, false, false))); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + AddAndCancelFirstWatcherForIdentityCertNameWithRootBeingWatched) { + // Register watcher 1 watching kCertName1 for root certs. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, absl::nullopt); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, false))); + // Register watcher 2 watching kCertName1 for identity certs. + WatcherState* watcher_state_2 = MakeWatcher(absl::nullopt, kCertName1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, true))); + // Push credential updates to kCertName1 and check if the status works as + // expected. + distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Check the updates are delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert1Contents, {}))); + // Check the updates are delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + "", MakeCertKeyPairs(kIdentityCert1PrivateKey, + kIdentityCert1Contents)))); + // Push root cert updates to kCertName1. + distributor_.SetKeyMaterials(kCertName1, kRootCert2Contents, absl::nullopt); + // Check the updates are delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert2Contents, {}))); + // Check the updates are not delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), ::testing::ElementsAre()); + // Push identity cert updates to kCertName1. 
+ distributor_.SetKeyMaterials( + kCertName1, absl::nullopt, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)); + // Check the updates are not delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), ::testing::ElementsAre()); + // Check the updates are delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + "", MakeCertKeyPairs(kIdentityCert2PrivateKey, + kIdentityCert2Contents)))); + watcher_state_2->cert_update_queue.clear(); + // Cancel watcher 2. + CancelWatch(watcher_state_2); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, false))); + // Cancel watcher 1. + CancelWatch(watcher_state_1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, false, false))); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + AddAndCancelFirstWatcherForRootCertNameWithIdentityBeingWatched) { + // Register watcher 1 watching kCertName1 for identity certs. + WatcherState* watcher_state_1 = MakeWatcher(absl::nullopt, kCertName1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, false, true))); + // Register watcher 2 watching kCertName1 for root certs. + WatcherState* watcher_state_2 = MakeWatcher(kCertName1, absl::nullopt); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, true))); + // Push credential updates to kCertName1 and check if the status works as + // expected. + distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Check the updates are delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + "", MakeCertKeyPairs(kIdentityCert1PrivateKey, + kIdentityCert1Contents)))); + // Check the updates are delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert1Contents, {}))); + // Push root cert updates to kCertName1. + distributor_.SetKeyMaterials(kCertName1, kRootCert2Contents, absl::nullopt); + // Check the updates are delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert2Contents, {}))); + // Check the updates are not delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), ::testing::ElementsAre()); + // Push identity cert updates to kCertName1. + distributor_.SetKeyMaterials( + kCertName1, absl::nullopt, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)); + // Check the updates are not delivered to watcher 2. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), ::testing::ElementsAre()); + // Check the updates are delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + "", MakeCertKeyPairs(kIdentityCert2PrivateKey, + kIdentityCert2Contents)))); + // Cancel watcher 2. + CancelWatch(watcher_state_2); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, false, true))); + // Cancel watcher 1. + CancelWatch(watcher_state_1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, false, false))); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + RemoveAllWatchersForCertNameAndAddAgain) { + // Register watcher 1 and watcher 2 watching kCertName1 for root and identity + // certs. 
+ WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, true))); + WatcherState* watcher_state_2 = MakeWatcher(kCertName1, kCertName1); + EXPECT_THAT(GetCallbackQueue(), ::testing::ElementsAre()); + // Push credential updates to kCertName1. + distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Cancel watcher 2. + CancelWatch(watcher_state_2); + EXPECT_THAT(GetCallbackQueue(), ::testing::ElementsAre()); + // Cancel watcher 1. + CancelWatch(watcher_state_1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, false, false))); + // Register watcher 3 watching kCertName for root and identity certs. + WatcherState* watcher_state_3 = MakeWatcher(kCertName1, kCertName1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, true))); + // Push credential updates to kCertName1. + distributor_.SetKeyMaterials( + kCertName1, kRootCert2Contents, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)); + // Check the updates are delivered to watcher 3. + EXPECT_THAT( + watcher_state_3->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert2Contents, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)))); + // Cancel watcher 3. + CancelWatch(watcher_state_3); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, false, false))); +} + +TEST_F(GrpcTlsCertificateDistributorTest, ResetCallbackToNull) { + // Register watcher 1 watching kCertName1 for root and identity certs. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + EXPECT_THAT(GetCallbackQueue(), + ::testing::ElementsAre(CallbackStatus(kCertName1, true, true))); + // Reset callback to nullptr. + distributor_.SetWatchStatusCallback(nullptr); + // Cancel watcher 1 shouldn't trigger any callback. + CancelWatch(watcher_state_1); + EXPECT_THAT(GetCallbackQueue(), ::testing::ElementsAre()); +} + +TEST_F(GrpcTlsCertificateDistributorTest, SetKeyMaterialsInCallback) { + distributor_.SetWatchStatusCallback([this](std::string cert_name, + bool /*root_being_watched*/, + bool /*identity_being_watched*/) { + distributor_.SetKeyMaterials( + cert_name, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + }); + auto verify_function = [this](std::string cert_name) { + WatcherState* watcher_state_1 = MakeWatcher(cert_name, cert_name); + // Check the updates are delivered to watcher 1. + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, MakeCertKeyPairs(kIdentityCert1PrivateKey, + kIdentityCert1Contents)))); + CancelWatch(watcher_state_1); + }; + // Start 1000 threads that will register a watcher to a new cert name, verify + // the key materials being set, and then cancel the watcher, to make sure the + // lock mechanism in the distributor is safe. + std::vector threads; + threads.reserve(1000); + for (int i = 0; i < 1000; ++i) { + threads.emplace_back(verify_function, std::to_string(i)); + } + for (auto& th : threads) { + th.join(); + } +} + +TEST_F(GrpcTlsCertificateDistributorTest, WatchACertInfoWithValidCredentials) { + // Push credential updates to kCertName1. 
+ distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Push root credential updates to kCertName2. + distributor_.SetKeyMaterials(kRootCert2Name, kRootCert2Contents, + absl::nullopt); + // Push identity credential updates to kCertName2. + distributor_.SetKeyMaterials( + kIdentityCert2Name, absl::nullopt, + MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2Contents)); + // Register watcher 1. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + // watcher 1 should receive the credentials right away. + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + CancelWatch(watcher_state_1); + // Register watcher 2. + WatcherState* watcher_state_2 = MakeWatcher(kRootCert2Name, absl::nullopt); + // watcher 2 should receive the root credentials right away. + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(kRootCert2Contents, {}))); + // Register watcher 3. + WatcherState* watcher_state_3 = + MakeWatcher(absl::nullopt, kIdentityCert2Name); + // watcher 3 should received the identity credentials right away. + EXPECT_THAT(watcher_state_3->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + "", MakeCertKeyPairs(kIdentityCert2PrivateKey, + kIdentityCert2Contents)))); + CancelWatch(watcher_state_2); + CancelWatch(watcher_state_3); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + SetErrorForCertForBothRootAndIdentity) { + // Register watcher 1. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + // Calling SetErrorForCert on both cert names should only call one OnError + // on watcher 1. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre( + ErrorInfo(kRootErrorMessage, kIdentityErrorMessage))); + // Calling SetErrorForCert on root cert name should call OnError + // on watcher 1 again. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage), + absl::nullopt); + EXPECT_THAT( + watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kErrorMessage, kIdentityErrorMessage))); + // Calling SetErrorForCert on identity cert name should call OnError + // on watcher 1 again. + distributor_.SetErrorForCert( + kCertName1, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kErrorMessage, kErrorMessage))); + distributor_.CancelTlsCertificatesWatch(watcher_state_1->watcher); + EXPECT_EQ(watcher_state_1->watcher, nullptr); +} + +TEST_F(GrpcTlsCertificateDistributorTest, SetErrorForCertForRootOrIdentity) { + // Register watcher 1. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, absl::nullopt); + // Calling SetErrorForCert on root name should only call one OnError + // on watcher 1. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kRootErrorMessage, ""))); + // Calling SetErrorForCert on identity name should do nothing. 
+ distributor_.SetErrorForCert( + kCertName1, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), ::testing::ElementsAre()); + // Calling SetErrorForCert on both names should still get one OnError call. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kRootErrorMessage, ""))); + CancelWatch(watcher_state_1); + // Register watcher 2. + WatcherState* watcher_state_2 = MakeWatcher(absl::nullopt, kCertName1); + // Calling SetErrorForCert on identity name should only call one OnError + // on watcher 2. + distributor_.SetErrorForCert( + kCertName1, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_2->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo("", kIdentityErrorMessage))); + // Calling SetErrorForCert on root name should do nothing. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + EXPECT_THAT(watcher_state_2->GetErrorQueue(), ::testing::ElementsAre()); + // Calling SetErrorForCert on both names should still get one OnError call. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_2->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo("", kIdentityErrorMessage))); + CancelWatch(watcher_state_2); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + SetErrorForIdentityNameWithPreexistingErrorForRootName) { + // SetErrorForCert for kCertName1. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + // Register watcher 1 for kCertName1 as root and kCertName2 as identity. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName2); + // Should trigger OnError call right away since kCertName1 has error. + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kRootErrorMessage, ""))); + // Calling SetErrorForCert on kCertName2 should trigger OnError with both + // errors, because kCertName1 also has error. + distributor_.SetErrorForCert( + kCertName2, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre( + ErrorInfo(kRootErrorMessage, kIdentityErrorMessage))); + CancelWatch(watcher_state_1); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + SetErrorForCertForRootNameWithSameNameForIdentityErrored) { + // SetErrorForCert for kCertName1. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + // Register watcher 1 for kCertName2 as root and kCertName1 as identity. + WatcherState* watcher_state_1 = MakeWatcher(kCertName2, kCertName1); + // Should trigger OnError call right away since kCertName2 has error. + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo("", kIdentityErrorMessage))); + // Calling SetErrorForCert on kCertName2 should trigger OnError with both + // errors, because kCertName1 also has error. 
+ distributor_.SetErrorForCert( + kCertName2, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre( + ErrorInfo(kRootErrorMessage, kIdentityErrorMessage))); + CancelWatch(watcher_state_1); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + SetErrorForIdentityNameWithoutErrorForRootName) { + // Register watcher 1 for kCertName1 as root and kCertName2 as identity. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName2); + // Should not trigger OnError. + EXPECT_THAT(watcher_state_1->GetErrorQueue(), ::testing::ElementsAre()); + // Calling SetErrorForCert on kCertName2 should trigger OnError. + distributor_.SetErrorForCert( + kCertName2, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo("", kIdentityErrorMessage))); + CancelWatch(watcher_state_1); + // Register watcher 2 for kCertName2 as identity and a non-existing name + // kRootCert1Name as root. + WatcherState* watcher_state_2 = MakeWatcher(kRootCert1Name, kCertName2); + // Should not trigger OnError. + EXPECT_THAT(watcher_state_2->GetErrorQueue(), ::testing::ElementsAre()); + // Calling SetErrorForCert on kCertName2 should trigger OnError. + distributor_.SetErrorForCert( + kCertName2, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_2->error_queue, + ::testing::ElementsAre(ErrorInfo("", kIdentityErrorMessage))); + CancelWatch(watcher_state_2); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + SetErrorForRootNameWithPreexistingErrorForIdentityName) { + WatcherState* watcher_state_1 = MakeWatcher(kCertName2, kCertName1); + // Should not trigger OnError. + EXPECT_THAT(watcher_state_1->GetErrorQueue(), ::testing::ElementsAre()); + // Calling SetErrorForCert on kCertName2 should trigger OnError. + distributor_.SetErrorForCert( + kCertName2, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kRootErrorMessage, ""))); + CancelWatch(watcher_state_1); + // Register watcher 2 for kCertName2 as root and a non-existing name + // kIdentityCert1Name as identity. + WatcherState* watcher_state_2 = MakeWatcher(kCertName2, kIdentityCert1Name); + // Should not trigger OnError. + EXPECT_THAT(watcher_state_2->GetErrorQueue(), ::testing::ElementsAre()); + // Calling SetErrorForCert on kCertName2 should trigger OnError. + distributor_.SetErrorForCert( + kCertName2, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + EXPECT_THAT(watcher_state_2->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kRootErrorMessage, ""))); + CancelWatch(watcher_state_2); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + CancelTheLastWatcherOnAnErroredCertInfo) { + // Register watcher 1. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + // Calling SetErrorForCert on both cert names should only call one OnError + // on watcher 1. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre( + ErrorInfo(kRootErrorMessage, kIdentityErrorMessage))); + // When watcher 1 is removed, the cert info entry should be removed. 
+ CancelWatch(watcher_state_1); + // Register watcher 2 on the same cert name. + WatcherState* watcher_state_2 = MakeWatcher(kCertName1, kCertName1); + // Should not trigger OnError call on watcher 2 right away. + EXPECT_THAT(watcher_state_2->GetErrorQueue(), ::testing::ElementsAre()); + CancelWatch(watcher_state_2); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + WatchErroredCertInfoWithValidCredentialData) { + // Push credential updates to kCertName1. + distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Calling SetErrorForCert on both cert names. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + // Register watcher 1. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + // watcher 1 should receive both the old credentials and the error right away. + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre( + ErrorInfo(kRootErrorMessage, kIdentityErrorMessage))); + CancelWatch(watcher_state_1); +} + +TEST_F(GrpcTlsCertificateDistributorTest, + SetErrorForCertThenSuccessfulCredentialUpdates) { + // Calling SetErrorForCert on both cert names. + distributor_.SetErrorForCert( + kCertName1, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + // Push credential updates to kCertName1. + distributor_.SetKeyMaterials( + kCertName1, kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)); + // Register watcher 1. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + // watcher 1 should only receive credential updates without any error, because + // the previous error is wiped out by a successful update. + EXPECT_THAT( + watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + kRootCert1Contents, + MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1Contents)))); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), ::testing::ElementsAre()); + CancelWatch(watcher_state_1); +} + +TEST_F(GrpcTlsCertificateDistributorTest, WatchCertInfoThenInvokeSetError) { + // Register watcher 1. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, kCertName1); + // Register watcher 2. + WatcherState* watcher_state_2 = MakeWatcher(kRootCert1Name, absl::nullopt); + // Register watcher 3. + WatcherState* watcher_state_3 = + MakeWatcher(absl::nullopt, kIdentityCert1Name); + distributor_.SetError(GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage)); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kErrorMessage, kErrorMessage))); + EXPECT_THAT(watcher_state_2->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kErrorMessage, ""))); + EXPECT_THAT(watcher_state_3->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo("", kErrorMessage))); + CancelWatch(watcher_state_1); + CancelWatch(watcher_state_2); + CancelWatch(watcher_state_3); +} + +TEST_F(GrpcTlsCertificateDistributorTest, WatchErroredCertInfoBySetError) { + // Register watcher 1 watching kCertName1 as root. + WatcherState* watcher_state_1 = MakeWatcher(kCertName1, absl::nullopt); + // Register watcher 2 watching kCertName2 as identity. 
+ WatcherState* watcher_state_2 = MakeWatcher(absl::nullopt, kCertName2); + // Call SetError and then cancel all watchers. + distributor_.SetError(GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage)); + CancelWatch(watcher_state_1); + CancelWatch(watcher_state_2); + // Register watcher 3 watching kCertName1 as root and kCertName2 as identity + // should not get the error updates. + WatcherState* watcher_state_3 = MakeWatcher(kCertName1, kCertName2); + EXPECT_THAT(watcher_state_3->GetErrorQueue(), ::testing::ElementsAre()); + CancelWatch(watcher_state_3); + // Register watcher 4 watching kCertName2 as root and kCertName1 as identity + // should not get the error updates. + WatcherState* watcher_state_4 = MakeWatcher(kCertName2, kCertName1); + EXPECT_THAT(watcher_state_4->GetErrorQueue(), ::testing::ElementsAre()); + CancelWatch(watcher_state_4); +} + +TEST_F(GrpcTlsCertificateDistributorTest, SetErrorForCertInCallback) { + distributor_.SetWatchStatusCallback([this](std::string cert_name, + bool /*root_being_watched*/, + bool /*identity_being_watched*/) { + this->distributor_.SetErrorForCert( + cert_name, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + }); + auto verify_function = [this](std::string cert_name) { + WatcherState* watcher_state_1 = MakeWatcher(cert_name, cert_name); + // Check the errors are delivered to watcher 1. + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre( + ErrorInfo(kRootErrorMessage, kIdentityErrorMessage))); + CancelWatch(watcher_state_1); + }; + // Start 1000 threads that will register a watcher to a new cert name, verify + // the key materials being set, and then cancel the watcher, to make sure the + // lock mechanism in the distributor is safe. + std::vector threads; + threads.reserve(1000); + for (int i = 0; i < 1000; ++i) { + threads.emplace_back(verify_function, std::to_string(i)); + } + for (auto& th : threads) { + th.join(); + } +} + +} // namespace testing + +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/security/grpc_tls_certificate_provider_test.cc b/test/core/security/grpc_tls_certificate_provider_test.cc new file mode 100644 index 00000000..a13e8078 --- /dev/null +++ b/test/core/security/grpc_tls_certificate_provider_test.cc @@ -0,0 +1,551 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+
+#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.h"
+
+#include <deque>
+#include <list>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+
+#include "src/core/lib/gpr/tmpfile.h"
+#include "src/core/lib/iomgr/load_file.h"
+#include "src/core/lib/slice/slice_internal.h"
+#include "test/core/util/test_config.h"
+#include "test/core/util/tls_utils.h"
+
+#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem"
+#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem"
+#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key"
+#define CA_CERT_PATH_2 "src/core/tsi/test_creds/multi-domain.pem"
+#define SERVER_CERT_PATH_2 "src/core/tsi/test_creds/server0.pem"
+#define SERVER_KEY_PATH_2 "src/core/tsi/test_creds/server0.key"
+#define INVALID_PATH "invalid/path"
+
+namespace grpc_core {
+
+namespace testing {
+
+constexpr const char* kCertName = "cert_name";
+constexpr const char* kRootError = "Unable to get latest root certificates.";
+constexpr const char* kIdentityError =
+    "Unable to get latest identity certificates.";
+
+class GrpcTlsCertificateProviderTest : public ::testing::Test {
+ protected:
+  // Forward declaration.
+  class TlsCertificatesTestWatcher;
+
+  // CredentialInfo contains the parameters when calling OnCertificatesChanged
+  // of a watcher. When OnCertificatesChanged is invoked, we will push a
+  // CredentialInfo to the cert_update_queue of state_, and check in each test
+  // if the status updates are correct.
+  struct CredentialInfo {
+    std::string root_certs;
+    PemKeyCertPairList key_cert_pairs;
+    CredentialInfo(std::string root, PemKeyCertPairList key_cert)
+        : root_certs(std::move(root)), key_cert_pairs(std::move(key_cert)) {}
+    bool operator==(const CredentialInfo& other) const {
+      return root_certs == other.root_certs &&
+             key_cert_pairs == other.key_cert_pairs;
+    }
+  };
+
+  // ErrorInfo contains the parameters when calling OnError of a watcher. When
+  // OnError is invoked, we will push an ErrorInfo to the error_queue of
+  // state_, and check in each test if the status updates are correct.
+  struct ErrorInfo {
+    std::string root_cert_str;
+    std::string identity_cert_str;
+    ErrorInfo(std::string root, std::string identity)
+        : root_cert_str(std::move(root)),
+          identity_cert_str(std::move(identity)) {}
+    bool operator==(const ErrorInfo& other) const {
+      return root_cert_str == other.root_cert_str &&
+             identity_cert_str == other.identity_cert_str;
+    }
+  };
+
+  struct WatcherState {
+    TlsCertificatesTestWatcher* watcher = nullptr;
+    std::deque<CredentialInfo> cert_update_queue;
+    std::deque<ErrorInfo> error_queue;
+    Mutex mu;
+
+    std::deque<CredentialInfo> GetCredentialQueue() {
+      // We move the data member out so it is left empty and ready for the
+      // next check.
+      MutexLock lock(&mu);
+      return std::move(cert_update_queue);
+    }
+    std::deque<ErrorInfo> GetErrorQueue() {
+      // We move the data member out so it is left empty and ready for the
+      // next check.
+      MutexLock lock(&mu);
+      return std::move(error_queue);
+    }
+  };
+
+  class TlsCertificatesTestWatcher : public grpc_tls_certificate_distributor::
+                                         TlsCertificatesWatcherInterface {
+   public:
+    // ctor sets state->watcher to this.
+    explicit TlsCertificatesTestWatcher(WatcherState* state) : state_(state) {
+      state_->watcher = this;
+    }
+
+    // dtor sets state->watcher to nullptr.
+    ~TlsCertificatesTestWatcher() override { state_->watcher = nullptr; }
+
+    void OnCertificatesChanged(
+        absl::optional<absl::string_view> root_certs,
+        absl::optional<PemKeyCertPairList> key_cert_pairs) override {
+      MutexLock lock(&state_->mu);
+      std::string updated_root;
+      if (root_certs.has_value()) {
+        updated_root = std::string(*root_certs);
+      }
+      PemKeyCertPairList updated_identity;
+      if (key_cert_pairs.has_value()) {
+        updated_identity = std::move(*key_cert_pairs);
+      }
+      state_->cert_update_queue.emplace_back(std::move(updated_root),
+                                             std::move(updated_identity));
+    }
+
+    void OnError(grpc_error_handle root_cert_error,
+                 grpc_error_handle identity_cert_error) override {
+      MutexLock lock(&state_->mu);
+      GPR_ASSERT(root_cert_error != GRPC_ERROR_NONE ||
+                 identity_cert_error != GRPC_ERROR_NONE);
+      std::string root_error_str;
+      std::string identity_error_str;
+      if (root_cert_error != GRPC_ERROR_NONE) {
+        GPR_ASSERT(grpc_error_get_str(
+            root_cert_error, GRPC_ERROR_STR_DESCRIPTION, &root_error_str));
+      }
+      if (identity_cert_error != GRPC_ERROR_NONE) {
+        GPR_ASSERT(grpc_error_get_str(identity_cert_error,
+                                      GRPC_ERROR_STR_DESCRIPTION,
+                                      &identity_error_str));
+      }
+      state_->error_queue.emplace_back(std::move(root_error_str),
+                                       std::move(identity_error_str));
+      GRPC_ERROR_UNREF(root_cert_error);
+      GRPC_ERROR_UNREF(identity_cert_error);
+    }
+
+   private:
+    WatcherState* state_;
+  };
+
+  void SetUp() override {
+    root_cert_ = GetFileContents(CA_CERT_PATH);
+    cert_chain_ = GetFileContents(SERVER_CERT_PATH);
+    private_key_ = GetFileContents(SERVER_KEY_PATH);
+    root_cert_2_ = GetFileContents(CA_CERT_PATH_2);
+    cert_chain_2_ = GetFileContents(SERVER_CERT_PATH_2);
+    private_key_2_ = GetFileContents(SERVER_KEY_PATH_2);
+  }
+
+  WatcherState* MakeWatcher(
+      RefCountedPtr<grpc_tls_certificate_distributor> distributor,
+      absl::optional<std::string> root_cert_name,
+      absl::optional<std::string> identity_cert_name) {
+    MutexLock lock(&mu_);
+    distributor_ = distributor;
+    watchers_.emplace_back();
+    // TlsCertificatesTestWatcher ctor takes a pointer to the WatcherState.
+    // It sets WatcherState::watcher to point to itself.
+    // The TlsCertificatesTestWatcher dtor will set WatcherState::watcher back
+    // to nullptr to indicate that it's been destroyed.
+    auto watcher =
+        absl::make_unique<TlsCertificatesTestWatcher>(&watchers_.back());
+    distributor_->WatchTlsCertificates(std::move(watcher),
+                                       std::move(root_cert_name),
+                                       std::move(identity_cert_name));
+    return &watchers_.back();
+  }
+
+  void CancelWatch(WatcherState* state) {
+    MutexLock lock(&mu_);
+    distributor_->CancelTlsCertificatesWatch(state->watcher);
+    EXPECT_EQ(state->watcher, nullptr);
+  }
+
+  std::string root_cert_;
+  std::string private_key_;
+  std::string cert_chain_;
+  std::string root_cert_2_;
+  std::string private_key_2_;
+  std::string cert_chain_2_;
+  RefCountedPtr<grpc_tls_certificate_distributor> distributor_;
+  // Use a std::list<> here to avoid the address invalidation caused by
+  // internal reallocation of std::vector<>.
+  std::list<WatcherState> watchers_;
+  // This is to make watchers_ thread-safe.
+  Mutex mu_;
+};
+
+TEST_F(GrpcTlsCertificateProviderTest, StaticDataCertificateProviderCreation) {
+  StaticDataCertificateProvider provider(
+      root_cert_, MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str()));
+  // Watcher watching both root and identity certs.
+  WatcherState* watcher_state_1 =
+      MakeWatcher(provider.distributor(), kCertName, kCertName);
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(),
+              ::testing::ElementsAre(CredentialInfo(
+                  root_cert_, MakeCertKeyPairs(private_key_.c_str(),
+                                               cert_chain_.c_str()))));
+  CancelWatch(watcher_state_1);
+  // Watcher watching only root certs.
+ WatcherState* watcher_state_2 = + MakeWatcher(provider.distributor(), kCertName, absl::nullopt); + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(root_cert_, {}))); + CancelWatch(watcher_state_2); + // Watcher watching only identity certs. + WatcherState* watcher_state_3 = + MakeWatcher(provider.distributor(), absl::nullopt, kCertName); + EXPECT_THAT( + watcher_state_3->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + "", MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str())))); + CancelWatch(watcher_state_3); +} + +TEST_F(GrpcTlsCertificateProviderTest, + FileWatcherCertificateProviderWithGoodPaths) { + FileWatcherCertificateProvider provider(SERVER_KEY_PATH, SERVER_CERT_PATH, + CA_CERT_PATH, 1); + // Watcher watching both root and identity certs. + WatcherState* watcher_state_1 = + MakeWatcher(provider.distributor(), kCertName, kCertName); + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + root_cert_, MakeCertKeyPairs(private_key_.c_str(), + cert_chain_.c_str())))); + CancelWatch(watcher_state_1); + // Watcher watching only root certs. + WatcherState* watcher_state_2 = + MakeWatcher(provider.distributor(), kCertName, absl::nullopt); + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo(root_cert_, {}))); + CancelWatch(watcher_state_2); + // Watcher watching only identity certs. + WatcherState* watcher_state_3 = + MakeWatcher(provider.distributor(), absl::nullopt, kCertName); + EXPECT_THAT( + watcher_state_3->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + "", MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str())))); + CancelWatch(watcher_state_3); +} + +TEST_F(GrpcTlsCertificateProviderTest, + FileWatcherCertificateProviderWithBadPaths) { + FileWatcherCertificateProvider provider(INVALID_PATH, INVALID_PATH, + INVALID_PATH, 1); + // Watcher watching both root and identity certs. + WatcherState* watcher_state_1 = + MakeWatcher(provider.distributor(), kCertName, kCertName); + EXPECT_THAT(watcher_state_1->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kRootError, kIdentityError))); + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), ::testing::ElementsAre()); + CancelWatch(watcher_state_1); + // Watcher watching only root certs. + WatcherState* watcher_state_2 = + MakeWatcher(provider.distributor(), kCertName, absl::nullopt); + EXPECT_THAT(watcher_state_2->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo(kRootError, ""))); + EXPECT_THAT(watcher_state_2->GetCredentialQueue(), ::testing::ElementsAre()); + CancelWatch(watcher_state_2); + // Watcher watching only identity certs. + WatcherState* watcher_state_3 = + MakeWatcher(provider.distributor(), absl::nullopt, kCertName); + EXPECT_THAT(watcher_state_3->GetErrorQueue(), + ::testing::ElementsAre(ErrorInfo("", kIdentityError))); + EXPECT_THAT(watcher_state_3->GetCredentialQueue(), ::testing::ElementsAre()); + CancelWatch(watcher_state_3); +} + +// The following tests write credential data to temporary files to test the +// transition behavior of the provider. +TEST_F(GrpcTlsCertificateProviderTest, + FileWatcherCertificateProviderOnBothCertsRefreshed) { + // Create temporary files and copy cert data into them. + TmpFile tmp_root_cert(root_cert_); + TmpFile tmp_identity_key(private_key_); + TmpFile tmp_identity_cert(cert_chain_); + // Create FileWatcherCertificateProvider. 
+ FileWatcherCertificateProvider provider(tmp_identity_key.name(), + tmp_identity_cert.name(), + tmp_root_cert.name(), 1); + WatcherState* watcher_state_1 = + MakeWatcher(provider.distributor(), kCertName, kCertName); + // Expect to see the credential data. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + root_cert_, MakeCertKeyPairs(private_key_.c_str(), + cert_chain_.c_str())))); + // Copy new data to files. + // TODO(ZhenLian): right now it is not completely atomic. Use the real atomic + // update when the directory renaming is added in gpr. + tmp_root_cert.RewriteFile(root_cert_2_); + tmp_identity_key.RewriteFile(private_key_2_); + tmp_identity_cert.RewriteFile(cert_chain_2_); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(2, GPR_TIMESPAN))); + // Expect to see the new credential data. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + root_cert_2_, MakeCertKeyPairs(private_key_2_.c_str(), + cert_chain_2_.c_str())))); + // Clean up. + CancelWatch(watcher_state_1); +} + +TEST_F(GrpcTlsCertificateProviderTest, + FileWatcherCertificateProviderOnRootCertsRefreshed) { + // Create temporary files and copy cert data into them. + TmpFile tmp_root_cert(root_cert_); + TmpFile tmp_identity_key(private_key_); + TmpFile tmp_identity_cert(cert_chain_); + // Create FileWatcherCertificateProvider. + FileWatcherCertificateProvider provider(tmp_identity_key.name(), + tmp_identity_cert.name(), + tmp_root_cert.name(), 1); + WatcherState* watcher_state_1 = + MakeWatcher(provider.distributor(), kCertName, kCertName); + // Expect to see the credential data. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + root_cert_, MakeCertKeyPairs(private_key_.c_str(), + cert_chain_.c_str())))); + // Copy new data to files. + // TODO(ZhenLian): right now it is not completely atomic. Use the real atomic + // update when the directory renaming is added in gpr. + tmp_root_cert.RewriteFile(root_cert_2_); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(2, GPR_TIMESPAN))); + // Expect to see the new credential data. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + root_cert_2_, MakeCertKeyPairs(private_key_.c_str(), + cert_chain_.c_str())))); + // Clean up. + CancelWatch(watcher_state_1); +} + +TEST_F(GrpcTlsCertificateProviderTest, + FileWatcherCertificateProviderOnIdentityCertsRefreshed) { + // Create temporary files and copy cert data into them. + TmpFile tmp_root_cert(root_cert_); + TmpFile tmp_identity_key(private_key_); + TmpFile tmp_identity_cert(cert_chain_); + // Create FileWatcherCertificateProvider. + FileWatcherCertificateProvider provider(tmp_identity_key.name(), + tmp_identity_cert.name(), + tmp_root_cert.name(), 1); + WatcherState* watcher_state_1 = + MakeWatcher(provider.distributor(), kCertName, kCertName); + // Expect to see the credential data. + EXPECT_THAT(watcher_state_1->GetCredentialQueue(), + ::testing::ElementsAre(CredentialInfo( + root_cert_, MakeCertKeyPairs(private_key_.c_str(), + cert_chain_.c_str())))); + // Copy new data to files. + // TODO(ZhenLian): right now it is not completely atomic. Use the real atomic + // update when the directory renaming is added in gpr. 
+  tmp_identity_key.RewriteFile(private_key_2_);
+  tmp_identity_cert.RewriteFile(cert_chain_2_);
+  // Wait 2 seconds for the provider's refresh thread to read the updated files.
+  gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+                               gpr_time_from_seconds(2, GPR_TIMESPAN)));
+  // Expect to see the new credential data.
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(),
+              ::testing::ElementsAre(CredentialInfo(
+                  root_cert_, MakeCertKeyPairs(private_key_2_.c_str(),
+                                               cert_chain_2_.c_str()))));
+  // Clean up.
+  CancelWatch(watcher_state_1);
+}
+
+TEST_F(GrpcTlsCertificateProviderTest,
+       FileWatcherCertificateProviderWithGoodAtFirstThenDeletedBothCerts) {
+  // Create temporary files and copy cert data into them.
+  auto tmp_root_cert = absl::make_unique<TmpFile>(root_cert_);
+  auto tmp_identity_key = absl::make_unique<TmpFile>(private_key_);
+  auto tmp_identity_cert = absl::make_unique<TmpFile>(cert_chain_);
+  // Create FileWatcherCertificateProvider.
+  FileWatcherCertificateProvider provider(tmp_identity_key->name(),
+                                          tmp_identity_cert->name(),
+                                          tmp_root_cert->name(), 1);
+  WatcherState* watcher_state_1 =
+      MakeWatcher(provider.distributor(), kCertName, kCertName);
+  // The initial data is all good, so we expect to have successful credential
+  // updates.
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(),
+              ::testing::ElementsAre(CredentialInfo(
+                  root_cert_, MakeCertKeyPairs(private_key_.c_str(),
+                                               cert_chain_.c_str()))));
+  // Delete TmpFile objects, which will remove the corresponding files.
+  tmp_root_cert.reset();
+  tmp_identity_key.reset();
+  tmp_identity_cert.reset();
+  // Wait 2 seconds for the provider's refresh thread to read the deleted files.
+  gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+                               gpr_time_from_seconds(2, GPR_TIMESPAN)));
+  // Expect to see errors sent to watchers, and no credential updates.
+  // We have no idea how many errors we will receive, so we only check once.
+  EXPECT_THAT(watcher_state_1->GetErrorQueue(),
+              ::testing::Contains(ErrorInfo(kRootError, kIdentityError)));
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(), ::testing::ElementsAre());
+  // Clean up.
+  CancelWatch(watcher_state_1);
+}
+
+TEST_F(GrpcTlsCertificateProviderTest,
+       FileWatcherCertificateProviderWithGoodAtFirstThenDeletedRootCerts) {
+  // Create temporary files and copy cert data into them.
+  auto tmp_root_cert = absl::make_unique<TmpFile>(root_cert_);
+  TmpFile tmp_identity_key(private_key_);
+  TmpFile tmp_identity_cert(cert_chain_);
+  // Create FileWatcherCertificateProvider.
+  FileWatcherCertificateProvider provider(tmp_identity_key.name(),
+                                          tmp_identity_cert.name(),
+                                          tmp_root_cert->name(), 1);
+  WatcherState* watcher_state_1 =
+      MakeWatcher(provider.distributor(), kCertName, kCertName);
+  // The initial data is all good, so we expect to have successful credential
+  // updates.
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(),
+              ::testing::ElementsAre(CredentialInfo(
+                  root_cert_, MakeCertKeyPairs(private_key_.c_str(),
+                                               cert_chain_.c_str()))));
+  // Delete root TmpFile object, which will remove the corresponding file.
+  tmp_root_cert.reset();
+  // Wait 2 seconds for the provider's refresh thread to read the deleted files.
+  gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+                               gpr_time_from_seconds(2, GPR_TIMESPAN)));
+  // Expect to see errors sent to watchers, and no credential updates.
+  // We have no idea how many errors we will receive, so we only check once.
+  EXPECT_THAT(watcher_state_1->GetErrorQueue(),
+              ::testing::Contains(ErrorInfo(kRootError, "")));
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(), ::testing::ElementsAre());
+  // Clean up.
+  CancelWatch(watcher_state_1);
+}
+
+TEST_F(GrpcTlsCertificateProviderTest,
+       FileWatcherCertificateProviderWithGoodAtFirstThenDeletedIdentityCerts) {
+  // Create temporary files and copy cert data into them.
+  TmpFile tmp_root_cert(root_cert_);
+  auto tmp_identity_key = absl::make_unique<TmpFile>(private_key_);
+  auto tmp_identity_cert = absl::make_unique<TmpFile>(cert_chain_);
+  // Create FileWatcherCertificateProvider.
+  FileWatcherCertificateProvider provider(tmp_identity_key->name(),
+                                          tmp_identity_cert->name(),
+                                          tmp_root_cert.name(), 1);
+  WatcherState* watcher_state_1 =
+      MakeWatcher(provider.distributor(), kCertName, kCertName);
+  // The initial data is all good, so we expect to have successful credential
+  // updates.
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(),
+              ::testing::ElementsAre(CredentialInfo(
+                  root_cert_, MakeCertKeyPairs(private_key_.c_str(),
+                                               cert_chain_.c_str()))));
+  // Delete identity TmpFile objects, which will remove the corresponding files.
+  tmp_identity_key.reset();
+  tmp_identity_cert.reset();
+  // Wait 2 seconds for the provider's refresh thread to read the deleted files.
+  gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+                               gpr_time_from_seconds(2, GPR_TIMESPAN)));
+  // Expect to see errors sent to watchers, and no credential updates.
+  // We have no idea how many errors we will receive, so we only check once.
+  EXPECT_THAT(watcher_state_1->GetErrorQueue(),
+              ::testing::Contains(ErrorInfo("", kIdentityError)));
+  EXPECT_THAT(watcher_state_1->GetCredentialQueue(), ::testing::ElementsAre());
+  // Clean up.
+  CancelWatch(watcher_state_1);
+}
+
+TEST_F(GrpcTlsCertificateProviderTest, FailedKeyCertMatchOnEmptyPrivateKey) {
+  absl::StatusOr<bool> status =
+      PrivateKeyAndCertificateMatch(/*private_key=*/"", cert_chain_);
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(status.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_EQ(status.status().message(), "Private key string is empty.");
+}
+
+TEST_F(GrpcTlsCertificateProviderTest, FailedKeyCertMatchOnEmptyCertificate) {
+  absl::StatusOr<bool> status =
+      PrivateKeyAndCertificateMatch(private_key_2_, /*cert_chain=*/"");
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(status.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_EQ(status.status().message(), "Certificate string is empty.");
+}
+
+TEST_F(GrpcTlsCertificateProviderTest, FailedKeyCertMatchOnInvalidCertFormat) {
+  absl::StatusOr<bool> status =
+      PrivateKeyAndCertificateMatch(private_key_2_, "invalid_certificate");
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(status.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_EQ(status.status().message(),
+            "Conversion from PEM string to X509 failed.");
+}
+
+TEST_F(GrpcTlsCertificateProviderTest,
+       FailedKeyCertMatchOnInvalidPrivateKeyFormat) {
+  absl::StatusOr<bool> status =
+      PrivateKeyAndCertificateMatch("invalid_private_key", cert_chain_2_);
+  EXPECT_EQ(status.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_EQ(status.status().message(),
+            "Conversion from PEM string to EVP_PKEY failed.");
+}
+
+TEST_F(GrpcTlsCertificateProviderTest, SuccessfulKeyCertMatch) {
+  absl::StatusOr<bool> status =
+      PrivateKeyAndCertificateMatch(private_key_2_, cert_chain_2_);
+  EXPECT_TRUE(status.ok());
+  EXPECT_TRUE(*status);
+}
+
+TEST_F(GrpcTlsCertificateProviderTest, FailedKeyCertMatchOnInvalidPair) {
+  absl::StatusOr<bool> status =
+      PrivateKeyAndCertificateMatch(private_key_2_, cert_chain_);
+  EXPECT_TRUE(status.ok());
+  EXPECT_FALSE(*status);
+}
+
+}  // namespace testing
+}  // namespace grpc_core
+
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  grpc_init();
+  int ret = RUN_ALL_TESTS();
+  grpc_shutdown();
+  return ret;
+}
diff --git a/test/core/security/grpc_tls_credentials_options_test.cc b/test/core/security/grpc_tls_credentials_options_test.cc
new file mode 100644
index 00000000..761effd9
--- /dev/null
+++ b/test/core/security/grpc_tls_credentials_options_test.cc
@@ -0,0 +1,505 @@
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <grpc/grpc.h>
+#include <grpc/grpc_security.h>
+#include <grpc/support/time.h>
+
+#include "src/core/lib/gpr/tmpfile.h"
+#include "src/core/lib/iomgr/load_file.h"
+#include "src/core/lib/security/credentials/tls/tls_credentials.h"
+#include "src/core/lib/security/security_connector/tls/tls_security_connector.h"
+#include "test/core/util/test_config.h"
+#include "test/core/util/tls_utils.h"
+
+#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem"
+#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem"
+#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key"
+#define CA_CERT_PATH_2 "src/core/tsi/test_creds/multi-domain.pem"
+#define SERVER_CERT_PATH_2 "src/core/tsi/test_creds/server0.pem"
+#define SERVER_KEY_PATH_2 "src/core/tsi/test_creds/server0.key"
+#define INVALID_PATH "invalid/path"
+
+namespace grpc_core {
+
+namespace testing {
+
+class GrpcTlsCredentialsOptionsTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    root_cert_ = GetFileContents(CA_CERT_PATH);
+    cert_chain_ = GetFileContents(SERVER_CERT_PATH);
+    private_key_ = GetFileContents(SERVER_KEY_PATH);
+    root_cert_2_ = GetFileContents(CA_CERT_PATH_2);
+    cert_chain_2_ = GetFileContents(SERVER_CERT_PATH_2);
+    private_key_2_ = GetFileContents(SERVER_KEY_PATH_2);
+  }
+
+  std::string root_cert_;
+  std::string private_key_;
+  std::string cert_chain_;
+  std::string root_cert_2_;
+  std::string private_key_2_;
+  std::string cert_chain_2_;
+};
+
+TEST_F(GrpcTlsCredentialsOptionsTest, ErrorDetails) {
+  grpc_tls_error_details error_details;
+  EXPECT_STREQ(error_details.error_details().c_str(), "");
+  error_details.set_error_details("test error details");
+  EXPECT_STREQ(error_details.error_details().c_str(), "test error details");
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest, ClientOptionsOnDefaultRootCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION);
+  auto credentials = MakeRefCounted<TlsCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  grpc_channel_args* new_args = nullptr;
+  auto connector = credentials->create_security_connector(
+      nullptr, "random targets", nullptr, &new_args);
+  grpc_channel_args_destroy(new_args);
+  ASSERT_NE(connector, nullptr);
+  TlsChannelSecurityConnector* tls_connector =
+      static_cast<TlsChannelSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+}
+
+// Tests for StaticDataCertificateProvider.
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ClientOptionsWithStaticDataProviderOnBothCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<StaticDataCertificateProvider>(
+      root_cert_, MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str()));
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_root_cert(true);
+  options->set_watch_identity_pair(true);
+  options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION);
+  auto credentials = MakeRefCounted<TlsCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  grpc_channel_args* new_args = nullptr;
+  auto connector = credentials->create_security_connector(
+      nullptr, "random targets", nullptr, &new_args);
+  grpc_channel_args_destroy(new_args);
+  ASSERT_NE(connector, nullptr);
+  TlsChannelSecurityConnector* tls_connector =
+      static_cast<TlsChannelSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+  EXPECT_TRUE(tls_connector->RootCertsForTesting().has_value());
+  EXPECT_TRUE(tls_connector->KeyCertPairListForTesting().has_value());
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ClientOptionsWithStaticDataProviderOnRootCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<StaticDataCertificateProvider>(
+      root_cert_, PemKeyCertPairList());
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_root_cert(true);
+  options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION);
+  auto credentials = MakeRefCounted<TlsCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  grpc_channel_args* new_args = nullptr;
+  auto connector = credentials->create_security_connector(
+      nullptr, "random targets", nullptr, &new_args);
+  grpc_channel_args_destroy(new_args);
+  ASSERT_NE(connector, nullptr);
+  TlsChannelSecurityConnector* tls_connector =
+      static_cast<TlsChannelSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+  EXPECT_TRUE(tls_connector->RootCertsForTesting().has_value());
+  EXPECT_FALSE(tls_connector->KeyCertPairListForTesting().has_value());
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ClientOptionsWithStaticDataProviderOnNotProvidedCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<StaticDataCertificateProvider>(
+      "", MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str()));
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_root_cert(true);
+  options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION);
+  auto credentials = MakeRefCounted<TlsCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  grpc_channel_args* new_args = nullptr;
+  auto connector = credentials->create_security_connector(
+      nullptr, "random targets", nullptr, &new_args);
+  grpc_channel_args_destroy(new_args);
+  ASSERT_NE(connector, nullptr);
+  TlsChannelSecurityConnector* tls_connector =
+      static_cast<TlsChannelSecurityConnector*>(connector.get());
+  EXPECT_EQ(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ClientOptionsWithDefaultRootAndStaticDataProviderOnIdentityCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<StaticDataCertificateProvider>(
+      "", MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str()));
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_identity_pair(true);
+  options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION);
+  auto credentials = MakeRefCounted<TlsCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  grpc_channel_args* new_args = nullptr;
+  auto connector = credentials->create_security_connector(
+      nullptr, "random targets", nullptr, &new_args);
+  grpc_channel_args_destroy(new_args);
+  ASSERT_NE(connector, nullptr);
+  TlsChannelSecurityConnector* tls_connector =
+      static_cast<TlsChannelSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ServerOptionsWithStaticDataProviderOnBothCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<StaticDataCertificateProvider>(
+      root_cert_, MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str()));
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_root_cert(true);
+  options->set_watch_identity_pair(true);
+  options->set_cert_request_type(
+      GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
+  auto credentials = MakeRefCounted<TlsServerCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  auto connector = credentials->create_security_connector(nullptr);
+  ASSERT_NE(connector, nullptr);
+  TlsServerSecurityConnector* tls_connector =
+      static_cast<TlsServerSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+  EXPECT_TRUE(tls_connector->RootCertsForTesting().has_value());
+  EXPECT_TRUE(tls_connector->KeyCertPairListForTesting().has_value());
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ServerOptionsWithStaticDataProviderOnIdentityCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<StaticDataCertificateProvider>(
+      "", MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str()));
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_identity_pair(true);
+  options->set_cert_request_type(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
+  auto credentials = MakeRefCounted<TlsServerCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  auto connector = credentials->create_security_connector(nullptr);
+  ASSERT_NE(connector, nullptr);
+  TlsServerSecurityConnector* tls_connector =
+      static_cast<TlsServerSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+  EXPECT_FALSE(tls_connector->RootCertsForTesting().has_value());
+  EXPECT_TRUE(tls_connector->KeyCertPairListForTesting().has_value());
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ServerOptionsWithStaticDataProviderOnNotProvidedCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<StaticDataCertificateProvider>(
+      root_cert_, PemKeyCertPairList());
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_identity_pair(true);
+  options->set_cert_request_type(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
+  auto credentials = MakeRefCounted<TlsServerCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  auto connector = credentials->create_security_connector(nullptr);
+  ASSERT_NE(connector, nullptr);
+  TlsServerSecurityConnector* tls_connector =
+      static_cast<TlsServerSecurityConnector*>(connector.get());
+  EXPECT_EQ(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+}
+
+// Tests for FileWatcherCertificateProvider.
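The client- and server-side tests in this file all repeat the same wiring: construct a provider, attach it to a grpc_tls_credentials_options object, wrap the options in TlsCredentials (or TlsServerCredentials), create a security connector, and downcast it to reach the *ForTesting() accessors. The following is a minimal sketch of the shared client-side pattern, for orientation only; the helper name MakeClientConnectorForTest is hypothetical rather than part of this change, and it assumes the headers already included at the top of this test file.

// Sketch only: a hypothetical helper mirroring the client-side wiring that the
// tests in this file repeat inline.
RefCountedPtr<grpc_channel_security_connector> MakeClientConnectorForTest(
    RefCountedPtr<grpc_tls_certificate_provider> provider) {
  auto options = MakeRefCounted<grpc_tls_credentials_options>();
  options->set_certificate_provider(std::move(provider));
  options->set_watch_root_cert(true);
  options->set_watch_identity_pair(true);
  options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION);
  auto credentials = MakeRefCounted<TlsCredentials>(options);
  grpc_channel_args* new_args = nullptr;
  auto connector = credentials->create_security_connector(
      nullptr, "random targets", nullptr, &new_args);
  grpc_channel_args_destroy(new_args);
  // Callers downcast to TlsChannelSecurityConnector* to inspect the
  // handshaker factory and the watched certificate data.
  return connector;
}

The individual tests differ only in which provider they build (file paths versus in-memory data), which watch flags and certificate-request options they set, and which expectations they place on the resulting connector.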
+TEST_F(GrpcTlsCredentialsOptionsTest, + ClientOptionsWithCertWatcherProviderOnBothCerts) { + auto options = MakeRefCounted(); + auto provider = MakeRefCounted( + SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH, 1); + options->set_certificate_provider(std::move(provider)); + options->set_watch_root_cert(true); + options->set_watch_identity_pair(true); + options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto credentials = MakeRefCounted(options); + ASSERT_NE(credentials, nullptr); + grpc_channel_args* new_args = nullptr; + auto connector = credentials->create_security_connector( + nullptr, "random targets", nullptr, &new_args); + grpc_channel_args_destroy(new_args); + ASSERT_NE(connector, nullptr); + TlsChannelSecurityConnector* tls_connector = + static_cast(connector.get()); + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + EXPECT_TRUE(tls_connector->RootCertsForTesting().has_value()); + EXPECT_TRUE(tls_connector->KeyCertPairListForTesting().has_value()); +} + +TEST_F(GrpcTlsCredentialsOptionsTest, + ClientOptionsWithCertWatcherProviderOnRootCerts) { + auto options = MakeRefCounted(); + auto provider = + MakeRefCounted("", "", CA_CERT_PATH, 1); + options->set_certificate_provider(std::move(provider)); + options->set_watch_root_cert(true); + options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto credentials = MakeRefCounted(options); + ASSERT_NE(credentials, nullptr); + grpc_channel_args* new_args = nullptr; + auto connector = credentials->create_security_connector( + nullptr, "random targets", nullptr, &new_args); + grpc_channel_args_destroy(new_args); + ASSERT_NE(connector, nullptr); + TlsChannelSecurityConnector* tls_connector = + static_cast(connector.get()); + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + EXPECT_TRUE(tls_connector->RootCertsForTesting().has_value()); + EXPECT_FALSE(tls_connector->KeyCertPairListForTesting().has_value()); +} + +TEST_F(GrpcTlsCredentialsOptionsTest, + ClientOptionsWithCertWatcherProviderOnNotProvidedCerts) { + auto options = MakeRefCounted(); + auto provider = MakeRefCounted( + SERVER_KEY_PATH, SERVER_CERT_PATH, "", 1); + options->set_certificate_provider(std::move(provider)); + options->set_watch_root_cert(true); + options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto credentials = MakeRefCounted(options); + ASSERT_NE(credentials, nullptr); + grpc_channel_args* new_args = nullptr; + auto connector = credentials->create_security_connector( + nullptr, "random targets", nullptr, &new_args); + grpc_channel_args_destroy(new_args); + ASSERT_NE(connector, nullptr); + TlsChannelSecurityConnector* tls_connector = + static_cast(connector.get()); + EXPECT_EQ(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); +} + +TEST_F(GrpcTlsCredentialsOptionsTest, + ClientOptionsWithCertWatcherProviderOnBadTrustCerts) { + auto options = MakeRefCounted(); + auto provider = + MakeRefCounted("", "", INVALID_PATH, 1); + options->set_certificate_provider(std::move(provider)); + options->set_watch_root_cert(true); + options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto credentials = MakeRefCounted(options); + ASSERT_NE(credentials, nullptr); + grpc_channel_args* new_args = nullptr; + auto connector = credentials->create_security_connector( + nullptr, "random targets", nullptr, &new_args); + grpc_channel_args_destroy(new_args); + ASSERT_NE(connector, nullptr); + TlsChannelSecurityConnector* tls_connector = + 
static_cast<TlsChannelSecurityConnector*>(connector.get());
+  EXPECT_EQ(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ServerOptionsWithCertWatcherProviderOnBothCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<FileWatcherCertificateProvider>(
+      SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH, 1);
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_root_cert(true);
+  options->set_watch_identity_pair(true);
+  options->set_cert_request_type(
+      GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
+  auto credentials = MakeRefCounted<TlsServerCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  auto connector = credentials->create_security_connector(nullptr);
+  ASSERT_NE(connector, nullptr);
+  TlsServerSecurityConnector* tls_connector =
+      static_cast<TlsServerSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+  EXPECT_TRUE(tls_connector->RootCertsForTesting().has_value());
+  EXPECT_TRUE(tls_connector->KeyCertPairListForTesting().has_value());
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ServerOptionsWithCertWatcherProviderOnIdentityCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<FileWatcherCertificateProvider>(
+      SERVER_KEY_PATH, SERVER_CERT_PATH, "", 1);
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_identity_pair(true);
+  options->set_cert_request_type(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
+  auto credentials = MakeRefCounted<TlsServerCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  auto connector = credentials->create_security_connector(nullptr);
+  ASSERT_NE(connector, nullptr);
+  TlsServerSecurityConnector* tls_connector =
+      static_cast<TlsServerSecurityConnector*>(connector.get());
+  EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+  EXPECT_FALSE(tls_connector->RootCertsForTesting().has_value());
+  EXPECT_TRUE(tls_connector->KeyCertPairListForTesting().has_value());
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ServerOptionsWithCertWatcherProviderOnNotProvidedCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider =
+      MakeRefCounted<FileWatcherCertificateProvider>("", "", CA_CERT_PATH, 1);
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_identity_pair(true);
+  options->set_cert_request_type(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
+  auto credentials = MakeRefCounted<TlsServerCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  auto connector = credentials->create_security_connector(nullptr);
+  ASSERT_NE(connector, nullptr);
+  TlsServerSecurityConnector* tls_connector =
+      static_cast<TlsServerSecurityConnector*>(connector.get());
+  EXPECT_EQ(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+}
+
+TEST_F(GrpcTlsCredentialsOptionsTest,
+       ServerOptionsWithCertWatcherProviderOnBadIdentityCerts) {
+  auto options = MakeRefCounted<grpc_tls_credentials_options>();
+  auto provider = MakeRefCounted<FileWatcherCertificateProvider>(
+      INVALID_PATH, INVALID_PATH, "", 1);
+  options->set_certificate_provider(std::move(provider));
+  options->set_watch_identity_pair(true);
+  options->set_cert_request_type(GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
+  auto credentials = MakeRefCounted<TlsServerCredentials>(options);
+  ASSERT_NE(credentials, nullptr);
+  auto connector = credentials->create_security_connector(nullptr);
+  ASSERT_NE(connector, nullptr);
+  TlsServerSecurityConnector* tls_connector =
+      static_cast<TlsServerSecurityConnector*>(connector.get());
+  EXPECT_EQ(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+}
+
+// The following tests write credential data to temporary files to test the
+// transition behavior of the provider.
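The remaining client tests drive the provider through an actual file transition: they write PEM data to temporary files, build the provider with a 1-second refresh interval, then rewrite or delete the files and sleep for two seconds so the refresh thread can pick up the change. They lean on a small TmpFile helper whose definition is outside this hunk; a rough sketch of its assumed shape, built on gpr_tmpfile, follows so the RewriteFile() and reset() steps are easy to follow.

// Rough sketch of the assumed TmpFile helper (the real one lives in the shared
// TLS test utilities). Needs <cstdio>, <string>, "absl/strings/string_view.h",
// <grpc/support/alloc.h> and "src/core/lib/gpr/tmpfile.h".
class TmpFile {
 public:
  // Creates a temporary file and writes `data` into it.
  explicit TmpFile(absl::string_view data) : name_(CreateFileWithData(data)) {}
  // Destroying the object removes the file from disk.
  ~TmpFile() { remove(name_.c_str()); }
  const std::string& name() const { return name_; }
  // Replaces the file contents; per the TODO in the test, the update is not
  // yet fully atomic.
  void RewriteFile(absl::string_view data) {
    std::string new_name = CreateFileWithData(data);
    remove(name_.c_str());
    rename(new_name.c_str(), name_.c_str());
  }

 private:
  static std::string CreateFileWithData(absl::string_view data) {
    char* name = nullptr;
    FILE* f = gpr_tmpfile("tls_test", &name);  // creates and opens the file
    fwrite(data.data(), 1, data.size(), f);
    fclose(f);
    std::string result(name);
    gpr_free(name);
    return result;
  }

  std::string name_;
};

The only guarantees the tests need are that name() stays stable across RewriteFile() and that destroying the object removes the file, which is exactly what the deleted-files test exploits.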
+TEST_F(GrpcTlsCredentialsOptionsTest, + ClientOptionsWithCertWatcherProviderOnCertificateRefreshed) { + // Create temporary files and copy cert data into them. + TmpFile tmp_root_cert(root_cert_); + TmpFile tmp_identity_key(private_key_); + TmpFile tmp_identity_cert(cert_chain_); + // Create ClientOptions using FileWatcherCertificateProvider. + auto options = MakeRefCounted(); + auto provider = MakeRefCounted( + tmp_identity_key.name(), tmp_identity_cert.name(), tmp_root_cert.name(), + 1); + options->set_certificate_provider(std::move(provider)); + options->set_watch_root_cert(true); + options->set_watch_identity_pair(true); + options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto credentials = MakeRefCounted(options); + ASSERT_NE(credentials, nullptr); + grpc_channel_args* new_args = nullptr; + auto connector = credentials->create_security_connector( + nullptr, "random targets", nullptr, &new_args); + grpc_channel_args_destroy(new_args); + ASSERT_NE(connector, nullptr); + TlsChannelSecurityConnector* tls_connector = + static_cast(connector.get()); + // Expect to see the credential data. + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + ASSERT_TRUE(tls_connector->RootCertsForTesting().has_value()); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_); + ASSERT_TRUE(tls_connector->KeyCertPairListForTesting().has_value()); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), + MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str())); + // Copy new data to files. + // TODO(ZhenLian): right now it is not completely atomic. Use the real atomic + // update when the directory renaming is added in gpr. + tmp_root_cert.RewriteFile(root_cert_2_); + tmp_identity_key.RewriteFile(private_key_2_); + tmp_identity_cert.RewriteFile(cert_chain_2_); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(2, GPR_TIMESPAN))); + // Expect to see new credential data loaded by the security connector. + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + ASSERT_TRUE(tls_connector->RootCertsForTesting().has_value()); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_2_); + ASSERT_TRUE(tls_connector->KeyCertPairListForTesting().has_value()); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), + MakeCertKeyPairs(private_key_2_.c_str(), cert_chain_2_.c_str())); +} + +TEST_F(GrpcTlsCredentialsOptionsTest, + ClientOptionsWithCertWatcherProviderOnDeletedFiles) { + // Create temporary files and copy cert data into it. + auto tmp_root_cert = absl::make_unique(root_cert_); + auto tmp_identity_key = absl::make_unique(private_key_); + auto tmp_identity_cert = absl::make_unique(cert_chain_); + // Create ClientOptions using FileWatcherCertificateProvider. 
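+  // (As in the refresh test above, the provider watches the three temporary
+  // files with a 1-second refresh interval, so deleting the files and then
+  // sleeping ~2 seconds gives the refresh thread time to hit the error path
+  // while the connector keeps the previously loaded credentials.)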
+ auto options = MakeRefCounted(); + auto provider = MakeRefCounted( + tmp_identity_key->name(), tmp_identity_cert->name(), + tmp_root_cert->name(), 1); + options->set_certificate_provider(std::move(provider)); + options->set_watch_root_cert(true); + options->set_watch_identity_pair(true); + options->set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto credentials = MakeRefCounted(options); + ASSERT_NE(credentials, nullptr); + grpc_channel_args* new_args = nullptr; + auto connector = credentials->create_security_connector( + nullptr, "random targets", nullptr, &new_args); + grpc_channel_args_destroy(new_args); + ASSERT_NE(connector, nullptr); + TlsChannelSecurityConnector* tls_connector = + static_cast(connector.get()); + // The initial data is all good, so we expect to have successful credential + // updates. + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + ASSERT_TRUE(tls_connector->RootCertsForTesting().has_value()); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_); + ASSERT_TRUE(tls_connector->KeyCertPairListForTesting().has_value()); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), + MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str())); + // Delete TmpFile objects, which will remove the corresponding files. + tmp_root_cert.reset(); + tmp_identity_key.reset(); + tmp_identity_cert.reset(); + // Wait 2 seconds for the provider's refresh thread to read the deleted files. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(2, GPR_TIMESPAN))); + // It's a bit hard to test if errors are sent to the security connector, + // because the security connector simply logs the error. We will see the err + // messages if we open the log. + // The old certs should still being used. + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + ASSERT_TRUE(tls_connector->RootCertsForTesting().has_value()); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_); + ASSERT_TRUE(tls_connector->KeyCertPairListForTesting().has_value()); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), + MakeCertKeyPairs(private_key_.c_str(), cert_chain_.c_str())); +} + +} // namespace testing + +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, CA_CERT_PATH); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/security/insecure_security_connector_test.cc b/test/core/security/insecure_security_connector_test.cc new file mode 100644 index 00000000..0955eeb8 --- /dev/null +++ b/test/core/security/insecure_security_connector_test.cc @@ -0,0 +1,63 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include "src/core/lib/security/security_connector/insecure/insecure_security_connector.h" + +#include +#include + +#include + +#include "src/core/lib/security/security_connector/ssl_utils.h" +#include "src/core/tsi/transport_security.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +TEST(InsecureSecurityConnector, MakeAuthContextTest) { + auto auth_context = TestOnlyMakeInsecureAuthContext(); + // Verify that peer is not authenticated + EXPECT_EQ(auth_context->is_authenticated(), false); + // Verify that GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME is set + auto it = grpc_auth_context_find_properties_by_name( + auth_context.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + ASSERT_NE(prop, nullptr); + EXPECT_STREQ(prop->name, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME); + EXPECT_STREQ(prop->value, kInsecureTransportSecurityType); + // Verify that security level is set to none + it = grpc_auth_context_find_properties_by_name( + auth_context.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME); + prop = grpc_auth_property_iterator_next(&it); + ASSERT_NE(prop, nullptr); + EXPECT_EQ(grpc_tsi_security_level_string_to_enum(prop->value), + GRPC_SECURITY_NONE); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/core/security/json_token_test.cc b/test/core/security/json_token_test.cc new file mode 100644 index 00000000..9f855a8d --- /dev/null +++ b/test/core/security/json_token_test.cc @@ -0,0 +1,454 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/security/credentials/jwt/json_token.h" + +#include + +#include + +#include +#include +#include +#include + +#include "src/core/lib/json/json.h" +#include "src/core/lib/security/credentials/oauth2/oauth2_credentials.h" +#include "src/core/lib/slice/b64.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +using grpc_core::Json; + +/* This JSON key was generated with the GCE console and revoked immediately. + The identifiers have been changed as well. + Maximum size for a string literal is 509 chars in C89, yay! 
*/ +static const char test_json_key_str_part1[] = + "{ \"private_key\": \"-----BEGIN PRIVATE KEY-----" + "\\nMIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAOEvJsnoHnyHkXcp\\n7mJE" + "qg" + "WGjiw71NfXByguekSKho65FxaGbsnSM9SMQAqVk7Q2rG+I0OpsT0LrWQtZ\\nyjSeg/" + "rWBQvS4hle4LfijkP3J5BG+" + "IXDMP8RfziNRQsenAXDNPkY4kJCvKux2xdD\\nOnVF6N7dL3nTYZg+" + "uQrNsMTz9UxVAgMBAAECgYEAzbLewe1xe9vy+2GoSsfib+28\\nDZgSE6Bu/" + "zuFoPrRc6qL9p2SsnV7txrunTyJkkOnPLND9ABAXybRTlcVKP/sGgza\\n/" + "8HpCqFYM9V8f34SBWfD4fRFT+n/" + "73cfRUtGXdXpseva2lh8RilIQfPhNZAncenU\\ngqXjDvpkypEusgXAykECQQD+"; +static const char test_json_key_str_part2[] = + "53XxNVnxBHsYb+AYEfklR96yVi8HywjVHP34+OQZ\\nCslxoHQM8s+" + "dBnjfScLu22JqkPv04xyxmt0QAKm9+vTdAkEA4ib7YvEAn2jXzcCI\\nEkoy2L/" + "XydR1GCHoacdfdAwiL2npOdnbvi4ZmdYRPY1LSTO058tQHKVXV7NLeCa3\\nAARh2QJBAMKeDA" + "G" + "W303SQv2cZTdbeaLKJbB5drz3eo3j7dDKjrTD9JupixFbzcGw\\n8FZi5c8idxiwC36kbAL6Hz" + "A" + "ZoX+ofI0CQE6KCzPJTtYNqyShgKAZdJ8hwOcvCZtf\\n6z8RJm0+" + "6YBd38lfh5j8mZd7aHFf6I17j5AQY7oPEc47TjJj/" + "5nZ68ECQQDvYuI3\\nLyK5fS8g0SYbmPOL9TlcHDOqwG0mrX9qpg5DC2fniXNSrrZ64GTDKdzZ" + "Y" + "Ap6LI9W\\nIqv4vr6y38N79TTC\\n-----END PRIVATE KEY-----\\n\", "; +static const char test_json_key_str_part3[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." + "com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" }"; + +/* Test refresh token. */ +static const char test_refresh_token_str[] = + "{ \"client_id\": \"32555999999.apps.googleusercontent.com\"," + " \"client_secret\": \"EmssLNjJy1332hD4KFsecret\"," + " \"refresh_token\": \"1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42\"," + " \"type\": \"authorized_user\"}"; + +static const char test_scope[] = "myperm1 myperm2"; + +static const char test_service_url[] = "https://foo.com/foo.v1"; + +static char* test_json_key_str(const char* bad_part3) { + const char* part3 = + bad_part3 != nullptr ? bad_part3 : test_json_key_str_part3; + size_t result_len = strlen(test_json_key_str_part1) + + strlen(test_json_key_str_part2) + strlen(part3); + char* result = static_cast(gpr_malloc(result_len + 1)); + char* current = result; + strcpy(result, test_json_key_str_part1); + current += strlen(test_json_key_str_part1); + strcpy(current, test_json_key_str_part2); + current += strlen(test_json_key_str_part2); + strcpy(current, part3); + return result; +} + +static void test_parse_json_key_success(void) { + char* json_string = test_json_key_str(nullptr); + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(json_string); + GPR_ASSERT(grpc_auth_json_key_is_valid(&json_key)); + GPR_ASSERT(json_key.type != nullptr && + strcmp(json_key.type, "service_account") == 0); + GPR_ASSERT(json_key.private_key_id != nullptr && + strcmp(json_key.private_key_id, + "e6b5137873db8d2ef81e06a47289e6434ec8a165") == 0); + GPR_ASSERT(json_key.client_id != nullptr && + strcmp(json_key.client_id, + "777-abaslkan11hlb6nmim3bpspl31ud.apps." + "googleusercontent.com") == 0); + GPR_ASSERT(json_key.client_email != nullptr && + strcmp(json_key.client_email, + "777-abaslkan11hlb6nmim3bpspl31ud@developer." 
+ "gserviceaccount.com") == 0); + GPR_ASSERT(json_key.private_key != nullptr); + gpr_free(json_string); + grpc_auth_json_key_destruct(&json_key); +} + +static void test_parse_json_key_failure_bad_json(void) { + const char non_closing_part3[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." + "com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" "; + char* json_string = test_json_key_str(non_closing_part3); + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(json_string); + GPR_ASSERT(!grpc_auth_json_key_is_valid(&json_key)); + gpr_free(json_string); + grpc_auth_json_key_destruct(&json_key); +} + +static void test_parse_json_key_failure_no_type(void) { + const char no_type_part3[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." + "com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\" }"; + char* json_string = test_json_key_str(no_type_part3); + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(json_string); + GPR_ASSERT(!grpc_auth_json_key_is_valid(&json_key)); + gpr_free(json_string); + grpc_auth_json_key_destruct(&json_key); +} + +static void test_parse_json_key_failure_no_client_id(void) { + const char no_client_id_part3[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." + "com\", " + "\"type\": \"service_account\" }"; + char* json_string = test_json_key_str(no_client_id_part3); + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(json_string); + GPR_ASSERT(!grpc_auth_json_key_is_valid(&json_key)); + gpr_free(json_string); + grpc_auth_json_key_destruct(&json_key); +} + +static void test_parse_json_key_failure_no_client_email(void) { + const char no_client_email_part3[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" }"; + char* json_string = test_json_key_str(no_client_email_part3); + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(json_string); + GPR_ASSERT(!grpc_auth_json_key_is_valid(&json_key)); + gpr_free(json_string); + grpc_auth_json_key_destruct(&json_key); +} + +static void test_parse_json_key_failure_no_private_key_id(void) { + const char no_private_key_id_part3[] = + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." + "com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" }"; + char* json_string = test_json_key_str(no_private_key_id_part3); + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(json_string); + GPR_ASSERT(!grpc_auth_json_key_is_valid(&json_key)); + gpr_free(json_string); + grpc_auth_json_key_destruct(&json_key); +} + +static void test_parse_json_key_failure_no_private_key(void) { + const char no_private_key_json_string[] = + "{ \"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." 
+ "com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" }"; + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(no_private_key_json_string); + GPR_ASSERT(!grpc_auth_json_key_is_valid(&json_key)); + grpc_auth_json_key_destruct(&json_key); +} + +static Json parse_json_part_from_jwt(const char* str, size_t len) { + grpc_core::ExecCtx exec_ctx; + char* b64 = static_cast(gpr_malloc(len + 1)); + strncpy(b64, str, len); + b64[len] = '\0'; + grpc_slice slice = grpc_base64_decode(b64, 1); + gpr_free(b64); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(slice)); + grpc_error_handle error = GRPC_ERROR_NONE; + absl::string_view string = grpc_core::StringViewFromSlice(slice); + Json json = Json::Parse(string, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parse error: %s", + grpc_error_std_string(error).c_str()); + GRPC_ERROR_UNREF(error); + } + grpc_slice_unref(slice); + return json; +} + +static void check_jwt_header(const Json& header) { + Json::Object object = header.object_value(); + Json value = object["alg"]; + GPR_ASSERT(value.type() == Json::Type::STRING); + GPR_ASSERT(strcmp(value.string_value().c_str(), "RS256") == 0); + value = object["typ"]; + GPR_ASSERT(value.type() == Json::Type::STRING); + GPR_ASSERT(strcmp(value.string_value().c_str(), "JWT") == 0); + value = object["kid"]; + GPR_ASSERT(value.type() == Json::Type::STRING); + GPR_ASSERT(strcmp(value.string_value().c_str(), + "e6b5137873db8d2ef81e06a47289e6434ec8a165") == 0); +} + +static void check_jwt_claim(const Json& claim, const char* expected_audience, + const char* expected_scope) { + Json::Object object = claim.object_value(); + + Json value = object["iss"]; + GPR_ASSERT(value.type() == Json::Type::STRING); + GPR_ASSERT(value.string_value() == + "777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount.com"); + + if (expected_scope != nullptr) { + GPR_ASSERT(object.find("sub") == object.end()); + value = object["scope"]; + GPR_ASSERT(value.type() == Json::Type::STRING); + GPR_ASSERT(value.string_value() == expected_scope); + } else { + /* Claims without scope must have a sub. 
*/ + GPR_ASSERT(object.find("scope") == object.end()); + value = object["sub"]; + GPR_ASSERT(value.type() == Json::Type::STRING); + GPR_ASSERT(value.string_value() == object["iss"].string_value()); + } + + value = object["aud"]; + GPR_ASSERT(value.type() == Json::Type::STRING); + GPR_ASSERT(value.string_value() == expected_audience); + + gpr_timespec expiration = gpr_time_0(GPR_CLOCK_REALTIME); + value = object["exp"]; + GPR_ASSERT(value.type() == Json::Type::NUMBER); + expiration.tv_sec = strtol(value.string_value().c_str(), nullptr, 10); + + gpr_timespec issue_time = gpr_time_0(GPR_CLOCK_REALTIME); + value = object["iat"]; + GPR_ASSERT(value.type() == Json::Type::NUMBER); + issue_time.tv_sec = strtol(value.string_value().c_str(), nullptr, 10); + + gpr_timespec parsed_lifetime = gpr_time_sub(expiration, issue_time); + GPR_ASSERT(parsed_lifetime.tv_sec == grpc_max_auth_token_lifetime().tv_sec); +} + +static void check_jwt_signature(const char* b64_signature, RSA* rsa_key, + const char* signed_data, + size_t signed_data_size) { + grpc_core::ExecCtx exec_ctx; + + EVP_MD_CTX* md_ctx = EVP_MD_CTX_create(); + EVP_PKEY* key = EVP_PKEY_new(); + + grpc_slice sig = grpc_base64_decode(b64_signature, 1); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(sig)); + GPR_ASSERT(GRPC_SLICE_LENGTH(sig) == 128); + + GPR_ASSERT(md_ctx != nullptr); + GPR_ASSERT(key != nullptr); + EVP_PKEY_set1_RSA(key, rsa_key); + + GPR_ASSERT( + EVP_DigestVerifyInit(md_ctx, nullptr, EVP_sha256(), nullptr, key) == 1); + GPR_ASSERT(EVP_DigestVerifyUpdate(md_ctx, signed_data, signed_data_size) == + 1); + GPR_ASSERT(EVP_DigestVerifyFinal(md_ctx, GRPC_SLICE_START_PTR(sig), + GRPC_SLICE_LENGTH(sig)) == 1); + + grpc_slice_unref_internal(sig); + if (key != nullptr) EVP_PKEY_free(key); + if (md_ctx != nullptr) EVP_MD_CTX_destroy(md_ctx); +} + +static char* service_account_creds_jwt_encode_and_sign( + const grpc_auth_json_key* key) { + return grpc_jwt_encode_and_sign(key, GRPC_JWT_OAUTH2_AUDIENCE, + grpc_max_auth_token_lifetime(), test_scope); +} + +static char* jwt_creds_jwt_encode_and_sign(const grpc_auth_json_key* key) { + return grpc_jwt_encode_and_sign(key, test_service_url, + grpc_max_auth_token_lifetime(), nullptr); +} + +static void service_account_creds_check_jwt_claim(const Json& claim) { + check_jwt_claim(claim, GRPC_JWT_OAUTH2_AUDIENCE, test_scope); +} + +static void jwt_creds_check_jwt_claim(const Json& claim) { + check_jwt_claim(claim, test_service_url, nullptr); +} + +static void test_jwt_encode_and_sign( + char* (*jwt_encode_and_sign_func)(const grpc_auth_json_key*), + void (*check_jwt_claim_func)(const Json&)) { + char* json_string = test_json_key_str(nullptr); + grpc_auth_json_key json_key = + grpc_auth_json_key_create_from_string(json_string); + const char* b64_signature; + size_t offset = 0; + char* jwt = jwt_encode_and_sign_func(&json_key); + const char* dot = strchr(jwt, '.'); + GPR_ASSERT(dot != nullptr); + Json parsed_header = + parse_json_part_from_jwt(jwt, static_cast(dot - jwt)); + GPR_ASSERT(parsed_header.type() == Json::Type::OBJECT); + check_jwt_header(parsed_header); + offset = static_cast(dot - jwt) + 1; + + dot = strchr(jwt + offset, '.'); + GPR_ASSERT(dot != nullptr); + Json parsed_claim = parse_json_part_from_jwt( + jwt + offset, static_cast(dot - (jwt + offset))); + GPR_ASSERT(parsed_claim.type() == Json::Type::OBJECT); + check_jwt_claim_func(parsed_claim); + offset = static_cast(dot - jwt) + 1; + + dot = strchr(jwt + offset, '.'); + GPR_ASSERT(dot == nullptr); /* no more part. 
*/ + b64_signature = jwt + offset; + check_jwt_signature(b64_signature, json_key.private_key, jwt, offset - 1); + + gpr_free(json_string); + grpc_auth_json_key_destruct(&json_key); + gpr_free(jwt); +} + +static void test_service_account_creds_jwt_encode_and_sign(void) { + test_jwt_encode_and_sign(service_account_creds_jwt_encode_and_sign, + service_account_creds_check_jwt_claim); +} + +static void test_jwt_creds_jwt_encode_and_sign(void) { + test_jwt_encode_and_sign(jwt_creds_jwt_encode_and_sign, + jwt_creds_check_jwt_claim); +} + +static void test_parse_refresh_token_success(void) { + grpc_auth_refresh_token refresh_token = + grpc_auth_refresh_token_create_from_string(test_refresh_token_str); + GPR_ASSERT(grpc_auth_refresh_token_is_valid(&refresh_token)); + GPR_ASSERT(refresh_token.type != nullptr && + (strcmp(refresh_token.type, "authorized_user") == 0)); + GPR_ASSERT(refresh_token.client_id != nullptr && + (strcmp(refresh_token.client_id, + "32555999999.apps.googleusercontent.com") == 0)); + GPR_ASSERT( + refresh_token.client_secret != nullptr && + (strcmp(refresh_token.client_secret, "EmssLNjJy1332hD4KFsecret") == 0)); + GPR_ASSERT(refresh_token.refresh_token != nullptr && + (strcmp(refresh_token.refresh_token, + "1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42") == 0)); + grpc_auth_refresh_token_destruct(&refresh_token); +} + +static void test_parse_refresh_token_failure_no_type(void) { + const char refresh_token_str[] = + "{ \"client_id\": \"32555999999.apps.googleusercontent.com\"," + " \"client_secret\": \"EmssLNjJy1332hD4KFsecret\"," + " \"refresh_token\": \"1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42\"}"; + grpc_auth_refresh_token refresh_token = + grpc_auth_refresh_token_create_from_string(refresh_token_str); + GPR_ASSERT(!grpc_auth_refresh_token_is_valid(&refresh_token)); +} + +static void test_parse_refresh_token_failure_no_client_id(void) { + const char refresh_token_str[] = + "{ \"client_secret\": \"EmssLNjJy1332hD4KFsecret\"," + " \"refresh_token\": \"1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42\"," + " \"type\": \"authorized_user\"}"; + grpc_auth_refresh_token refresh_token = + grpc_auth_refresh_token_create_from_string(refresh_token_str); + GPR_ASSERT(!grpc_auth_refresh_token_is_valid(&refresh_token)); +} + +static void test_parse_refresh_token_failure_no_client_secret(void) { + const char refresh_token_str[] = + "{ \"client_id\": \"32555999999.apps.googleusercontent.com\"," + " \"refresh_token\": \"1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42\"," + " \"type\": \"authorized_user\"}"; + grpc_auth_refresh_token refresh_token = + grpc_auth_refresh_token_create_from_string(refresh_token_str); + GPR_ASSERT(!grpc_auth_refresh_token_is_valid(&refresh_token)); +} + +static void test_parse_refresh_token_failure_no_refresh_token(void) { + const char refresh_token_str[] = + "{ \"client_id\": \"32555999999.apps.googleusercontent.com\"," + " \"client_secret\": \"EmssLNjJy1332hD4KFsecret\"," + " \"type\": \"authorized_user\"}"; + grpc_auth_refresh_token refresh_token = + grpc_auth_refresh_token_create_from_string(refresh_token_str); + GPR_ASSERT(!grpc_auth_refresh_token_is_valid(&refresh_token)); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_parse_json_key_success(); + test_parse_json_key_failure_bad_json(); + test_parse_json_key_failure_no_type(); + test_parse_json_key_failure_no_client_id(); + test_parse_json_key_failure_no_client_email(); + test_parse_json_key_failure_no_private_key_id(); + 
test_parse_json_key_failure_no_private_key(); + test_service_account_creds_jwt_encode_and_sign(); + test_jwt_creds_jwt_encode_and_sign(); + test_parse_refresh_token_success(); + test_parse_refresh_token_failure_no_type(); + test_parse_refresh_token_failure_no_client_id(); + test_parse_refresh_token_failure_no_client_secret(); + test_parse_refresh_token_failure_no_refresh_token(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/jwt_verifier_test.cc b/test/core/security/jwt_verifier_test.cc new file mode 100644 index 00000000..933c413a --- /dev/null +++ b/test/core/security/jwt_verifier_test.cc @@ -0,0 +1,644 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/security/credentials/jwt/jwt_verifier.h" + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/security/credentials/jwt/json_token.h" +#include "src/core/lib/slice/b64.h" +#include "test/core/util/test_config.h" + +using grpc_core::Json; + +/* This JSON key was generated with the GCE console and revoked immediately. + The identifiers have been changed as well. + Maximum size for a string literal is 509 chars in C89, yay! */ +static const char json_key_str_part1[] = + "{ \"private_key\": \"-----BEGIN PRIVATE KEY-----" + "\\nMIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAOEvJsnoHnyHkXcp\\n7mJE" + "qg" + "WGjiw71NfXByguekSKho65FxaGbsnSM9SMQAqVk7Q2rG+I0OpsT0LrWQtZ\\nyjSeg/" + "rWBQvS4hle4LfijkP3J5BG+" + "IXDMP8RfziNRQsenAXDNPkY4kJCvKux2xdD\\nOnVF6N7dL3nTYZg+" + "uQrNsMTz9UxVAgMBAAECgYEAzbLewe1xe9vy+2GoSsfib+28\\nDZgSE6Bu/" + "zuFoPrRc6qL9p2SsnV7txrunTyJkkOnPLND9ABAXybRTlcVKP/sGgza\\n/" + "8HpCqFYM9V8f34SBWfD4fRFT+n/" + "73cfRUtGXdXpseva2lh8RilIQfPhNZAncenU\\ngqXjDvpkypEusgXAykECQQD+"; +static const char json_key_str_part2[] = + "53XxNVnxBHsYb+AYEfklR96yVi8HywjVHP34+OQZ\\nCslxoHQM8s+" + "dBnjfScLu22JqkPv04xyxmt0QAKm9+vTdAkEA4ib7YvEAn2jXzcCI\\nEkoy2L/" + "XydR1GCHoacdfdAwiL2npOdnbvi4ZmdYRPY1LSTO058tQHKVXV7NLeCa3\\nAARh2QJBAMKeDA" + "G" + "W303SQv2cZTdbeaLKJbB5drz3eo3j7dDKjrTD9JupixFbzcGw\\n8FZi5c8idxiwC36kbAL6Hz" + "A" + "ZoX+ofI0CQE6KCzPJTtYNqyShgKAZdJ8hwOcvCZtf\\n6z8RJm0+" + "6YBd38lfh5j8mZd7aHFf6I17j5AQY7oPEc47TjJj/" + "5nZ68ECQQDvYuI3\\nLyK5fS8g0SYbmPOL9TlcHDOqwG0mrX9qpg5DC2fniXNSrrZ64GTDKdzZ" + "Y" + "Ap6LI9W\\nIqv4vr6y38N79TTC\\n-----END PRIVATE KEY-----\\n\", "; +static const char json_key_str_part3_for_google_email_issuer[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud@developer.gserviceaccount." + "com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" }"; +/* Trick our JWT library into issuing a JWT with iss=accounts.google.com. 
*/ +static const char json_key_str_part3_for_url_issuer[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": \"accounts.google.com\", " + "\"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" }"; +static const char json_key_str_part3_for_custom_email_issuer[] = + "\"private_key_id\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\", " + "\"client_email\": " + "\"foo@bar.com\", \"client_id\": " + "\"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent." + "com\", \"type\": \"service_account\" }"; + +static grpc_jwt_verifier_email_domain_key_url_mapping custom_mapping = { + "bar.com", "keys.bar.com/jwk"}; + +static const char expected_user_data[] = "user data"; + +static const char good_jwk_set[] = + "{" + " \"keys\": [" + " {" + " \"kty\": \"RSA\"," + " \"alg\": \"RS256\"," + " \"use\": \"sig\"," + " \"kid\": \"e6b5137873db8d2ef81e06a47289e6434ec8a165\"," + " \"n\": " + "\"4S8myegefIeRdynuYkSqBYaOLDvU19cHKC56RIqGjrkXFoZuydIz1IxACpWTtDasb4jQ6mxP" + "QutZC1nKNJ6D-tYFC9LiGV7gt-KOQ_cnkEb4hcMw_xF_OI1FCx6cBcM0-" + "RjiQkK8q7HbF0M6dUXo3t0vedNhmD65Cs2wxPP1TFU=\"," + " \"e\": \"AQAB\"" + " }" + " ]" + "}"; + +static gpr_timespec expected_lifetime = {3600, 0, GPR_TIMESPAN}; + +static const char good_google_email_keys_part1[] = + "{\"e6b5137873db8d2ef81e06a47289e6434ec8a165\": \"-----BEGIN " + "CERTIFICATE-----" + "\\nMIICATCCAWoCCQDEywLhxvHjnDANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB\\nVTET" + "MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0\\ncyBQdHkgTHR" + "kMB4XDTE1MDYyOTA4Mzk1MFoXDTI1MDYyNjA4Mzk1MFowRTELMAkG\\nA1UEBhMCQVUxEzARBg" + "NVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0\\nIFdpZGdpdHMgUHR5IEx0ZDCBn" + "zANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA4S8m\\nyegefIeRdynuYkSqBYaOLDvU19cHKC56" + "RIqGjrkXFoZuydIz1IxACpWTtDasb4jQ\\n6mxPQutZC1nKNJ6D+tYFC9LiGV7gt+KOQ/"; + +static const char good_google_email_keys_part2[] = + "cnkEb4hcMw/xF/OI1FCx6cBcM0+" + "Rji\\nQkK8q7HbF0M6dUXo3t0vedNhmD65Cs2wxPP1TFUCAwEAATANBgkqhkiG9w0BAQsF\\nA" + "AOBgQBfu69FkPmBknbKNFgurPz78kbs3VNN+k/" + "PUgO5DHKskJmgK2TbtvX2VMpx\\nkftmHGzgzMzUlOtigCaGMgHWjfqjpP9uuDbahXrZBJzB8c" + "Oq7MrQF8r17qVvo3Ue\\nPjTKQMAsU8uxTEMmeuz9L6yExs0rfd6bPOrQkAoVfFfiYB3/" + "pA==\\n-----END CERTIFICATE-----\\n\"}"; + +static const char expected_audience[] = "https://foo.com"; + +static const char good_openid_config[] = + "{" + " \"issuer\": \"https://accounts.google.com\"," + " \"authorization_endpoint\": " + "\"https://accounts.google.com/o/oauth2/v2/auth\"," + " \"token_endpoint\": \"https://oauth2.googleapis.com/token\"," + " \"userinfo_endpoint\": \"https://www.googleapis.com/oauth2/v3/userinfo\"," + " \"revocation_endpoint\": \"https://oauth2.googleapis.com/revoke\"," + " \"jwks_uri\": \"https://www.googleapis.com/oauth2/v3/certs\"" + "}"; + +static const char expired_claims[] = + "{ \"aud\": \"https://foo.com\"," + " \"iss\": \"blah.foo.com\"," + " \"sub\": \"juju@blah.foo.com\"," + " \"jti\": \"jwtuniqueid\"," + " \"iat\": 100," /* Way back in the past... 
*/ + " \"exp\": 120," + " \"nbf\": 60," + " \"foo\": \"bar\"}"; + +static const char claims_without_time_constraint[] = + "{ \"aud\": \"https://foo.com\"," + " \"iss\": \"blah.foo.com\"," + " \"sub\": \"juju@blah.foo.com\"," + " \"jti\": \"jwtuniqueid\"," + " \"foo\": \"bar\"}"; + +static const char claims_with_bad_subject[] = + "{ \"aud\": \"https://foo.com\"," + " \"iss\": \"evil@blah.foo.com\"," + " \"sub\": \"juju@blah.foo.com\"," + " \"jti\": \"jwtuniqueid\"," + " \"foo\": \"bar\"}"; + +static const char invalid_claims[] = + "{ \"aud\": \"https://foo.com\"," + " \"iss\": 46," /* Issuer cannot be a number. */ + " \"sub\": \"juju@blah.foo.com\"," + " \"jti\": \"jwtuniqueid\"," + " \"foo\": \"bar\"}"; + +typedef struct { + grpc_jwt_verifier_status expected_status; + const char* expected_issuer; + const char* expected_subject; +} verifier_test_config; + +static void test_jwt_issuer_email_domain(void) { + const char* d = grpc_jwt_issuer_email_domain("https://foo.com"); + GPR_ASSERT(d == nullptr); + d = grpc_jwt_issuer_email_domain("foo.com"); + GPR_ASSERT(d == nullptr); + d = grpc_jwt_issuer_email_domain(""); + GPR_ASSERT(d == nullptr); + d = grpc_jwt_issuer_email_domain("@"); + GPR_ASSERT(d == nullptr); + d = grpc_jwt_issuer_email_domain("bar@foo"); + GPR_ASSERT(strcmp(d, "foo") == 0); + d = grpc_jwt_issuer_email_domain("bar@foo.com"); + GPR_ASSERT(strcmp(d, "foo.com") == 0); + d = grpc_jwt_issuer_email_domain("bar@blah.foo.com"); + GPR_ASSERT(strcmp(d, "foo.com") == 0); + d = grpc_jwt_issuer_email_domain("bar.blah@blah.foo.com"); + GPR_ASSERT(strcmp(d, "foo.com") == 0); + d = grpc_jwt_issuer_email_domain("bar.blah@baz.blah.foo.com"); + GPR_ASSERT(strcmp(d, "foo.com") == 0); + + /* This is not a very good parser but make sure we do not crash on these weird + inputs. 
*/ + d = grpc_jwt_issuer_email_domain("@foo"); + GPR_ASSERT(strcmp(d, "foo") == 0); + d = grpc_jwt_issuer_email_domain("bar@."); + GPR_ASSERT(d != nullptr); + d = grpc_jwt_issuer_email_domain("bar@.."); + GPR_ASSERT(d != nullptr); + d = grpc_jwt_issuer_email_domain("bar@..."); + GPR_ASSERT(d != nullptr); +} + +static void test_claims_success(void) { + grpc_jwt_claims* claims; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(claims_without_time_constraint, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parse error: %s", + grpc_error_std_string(error).c_str()); + } + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(json.type() == Json::Type::OBJECT); + grpc_core::ExecCtx exec_ctx; + claims = grpc_jwt_claims_from_json(json); + GPR_ASSERT(claims != nullptr); + GPR_ASSERT(*grpc_jwt_claims_json(claims) == json); + GPR_ASSERT(strcmp(grpc_jwt_claims_audience(claims), "https://foo.com") == 0); + GPR_ASSERT(strcmp(grpc_jwt_claims_issuer(claims), "blah.foo.com") == 0); + GPR_ASSERT(strcmp(grpc_jwt_claims_subject(claims), "juju@blah.foo.com") == 0); + GPR_ASSERT(strcmp(grpc_jwt_claims_id(claims), "jwtuniqueid") == 0); + GPR_ASSERT(grpc_jwt_claims_check(claims, "https://foo.com") == + GRPC_JWT_VERIFIER_OK); + grpc_jwt_claims_destroy(claims); +} + +static void test_expired_claims_failure(void) { + grpc_jwt_claims* claims; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(expired_claims, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parse error: %s", + grpc_error_std_string(error).c_str()); + } + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(json.type() == Json::Type::OBJECT); + gpr_timespec exp_iat = {100, 0, GPR_CLOCK_REALTIME}; + gpr_timespec exp_exp = {120, 0, GPR_CLOCK_REALTIME}; + gpr_timespec exp_nbf = {60, 0, GPR_CLOCK_REALTIME}; + grpc_core::ExecCtx exec_ctx; + claims = grpc_jwt_claims_from_json(json); + GPR_ASSERT(claims != nullptr); + GPR_ASSERT(*grpc_jwt_claims_json(claims) == json); + GPR_ASSERT(strcmp(grpc_jwt_claims_audience(claims), "https://foo.com") == 0); + GPR_ASSERT(strcmp(grpc_jwt_claims_issuer(claims), "blah.foo.com") == 0); + GPR_ASSERT(strcmp(grpc_jwt_claims_subject(claims), "juju@blah.foo.com") == 0); + GPR_ASSERT(strcmp(grpc_jwt_claims_id(claims), "jwtuniqueid") == 0); + GPR_ASSERT(gpr_time_cmp(grpc_jwt_claims_issued_at(claims), exp_iat) == 0); + GPR_ASSERT(gpr_time_cmp(grpc_jwt_claims_expires_at(claims), exp_exp) == 0); + GPR_ASSERT(gpr_time_cmp(grpc_jwt_claims_not_before(claims), exp_nbf) == 0); + + GPR_ASSERT(grpc_jwt_claims_check(claims, "https://foo.com") == + GRPC_JWT_VERIFIER_TIME_CONSTRAINT_FAILURE); + grpc_jwt_claims_destroy(claims); +} + +static void test_invalid_claims_failure(void) { + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(invalid_claims, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parse error: %s", + grpc_error_std_string(error).c_str()); + } + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(json.type() == Json::Type::OBJECT); + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(grpc_jwt_claims_from_json(json) == nullptr); +} + +static void test_bad_audience_claims_failure(void) { + grpc_jwt_claims* claims; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(claims_without_time_constraint, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parse error: %s", + grpc_error_std_string(error).c_str()); + } + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(json.type() == 
Json::Type::OBJECT); + grpc_core::ExecCtx exec_ctx; + claims = grpc_jwt_claims_from_json(json); + GPR_ASSERT(claims != nullptr); + GPR_ASSERT(grpc_jwt_claims_check(claims, "https://bar.com") == + GRPC_JWT_VERIFIER_BAD_AUDIENCE); + grpc_jwt_claims_destroy(claims); +} + +static void test_bad_subject_claims_failure(void) { + grpc_jwt_claims* claims; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(claims_with_bad_subject, &error); + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "JSON parse error: %s", + grpc_error_std_string(error).c_str()); + } + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(json.type() == Json::Type::OBJECT); + grpc_core::ExecCtx exec_ctx; + claims = grpc_jwt_claims_from_json(json); + GPR_ASSERT(claims != nullptr); + GPR_ASSERT(grpc_jwt_claims_check(claims, "https://foo.com") == + GRPC_JWT_VERIFIER_BAD_SUBJECT); + grpc_jwt_claims_destroy(claims); +} + +static char* json_key_str(const char* last_part) { + size_t result_len = strlen(json_key_str_part1) + strlen(json_key_str_part2) + + strlen(last_part); + char* result = static_cast(gpr_malloc(result_len + 1)); + char* current = result; + strcpy(result, json_key_str_part1); + current += strlen(json_key_str_part1); + strcpy(current, json_key_str_part2); + current += strlen(json_key_str_part2); + strcpy(current, last_part); + return result; +} + +static char* good_google_email_keys(void) { + size_t result_len = strlen(good_google_email_keys_part1) + + strlen(good_google_email_keys_part2); + char* result = static_cast(gpr_malloc(result_len + 1)); + char* current = result; + strcpy(result, good_google_email_keys_part1); + current += strlen(good_google_email_keys_part1); + strcpy(current, good_google_email_keys_part2); + return result; +} + +static grpc_httpcli_response http_response(int status, char* body) { + grpc_httpcli_response response; + response = {}; + response.status = status; + response.body = body; + response.body_length = strlen(body); + return response; +} + +static int httpcli_post_should_not_be_called( + const grpc_httpcli_request* /*request*/, const char* /*body_bytes*/, + size_t /*body_size*/, grpc_millis /*deadline*/, grpc_closure* /*on_done*/, + grpc_httpcli_response* /*response*/) { + GPR_ASSERT("HTTP POST should not be called" == nullptr); + return 1; +} + +static int httpcli_get_google_keys_for_email( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + *response = http_response(200, good_google_email_keys()); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + GPR_ASSERT(strcmp(request->host, "www.googleapis.com") == 0); + GPR_ASSERT(strcmp(request->http.path, + "/robot/v1/metadata/x509/" + "777-abaslkan11hlb6nmim3bpspl31ud@developer." 
+ "gserviceaccount.com") == 0); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static void on_verification_success(void* user_data, + grpc_jwt_verifier_status status, + grpc_jwt_claims* claims) { + GPR_ASSERT(status == GRPC_JWT_VERIFIER_OK); + GPR_ASSERT(claims != nullptr); + GPR_ASSERT(user_data == (void*)expected_user_data); + GPR_ASSERT(strcmp(grpc_jwt_claims_audience(claims), expected_audience) == 0); + grpc_jwt_claims_destroy(claims); +} + +static void test_jwt_verifier_google_email_issuer_success(void) { + grpc_core::ExecCtx exec_ctx; + grpc_jwt_verifier* verifier = grpc_jwt_verifier_create(nullptr, 0); + char* jwt = nullptr; + char* key_str = json_key_str(json_key_str_part3_for_google_email_issuer); + grpc_auth_json_key key = grpc_auth_json_key_create_from_string(key_str); + gpr_free(key_str); + GPR_ASSERT(grpc_auth_json_key_is_valid(&key)); + grpc_httpcli_set_override(httpcli_get_google_keys_for_email, + httpcli_post_should_not_be_called); + jwt = grpc_jwt_encode_and_sign(&key, expected_audience, expected_lifetime, + nullptr); + grpc_auth_json_key_destruct(&key); + GPR_ASSERT(jwt != nullptr); + grpc_jwt_verifier_verify(verifier, nullptr, jwt, expected_audience, + on_verification_success, + const_cast(expected_user_data)); + grpc_jwt_verifier_destroy(verifier); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(jwt); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static int httpcli_get_custom_keys_for_email( + const grpc_httpcli_request* request, grpc_millis /*deadline*/, + grpc_closure* on_done, grpc_httpcli_response* response) { + *response = http_response(200, gpr_strdup(good_jwk_set)); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + GPR_ASSERT(strcmp(request->host, "keys.bar.com") == 0); + GPR_ASSERT(strcmp(request->http.path, "/jwk/foo@bar.com") == 0); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static void test_jwt_verifier_custom_email_issuer_success(void) { + grpc_core::ExecCtx exec_ctx; + grpc_jwt_verifier* verifier = grpc_jwt_verifier_create(&custom_mapping, 1); + char* jwt = nullptr; + char* key_str = json_key_str(json_key_str_part3_for_custom_email_issuer); + grpc_auth_json_key key = grpc_auth_json_key_create_from_string(key_str); + gpr_free(key_str); + GPR_ASSERT(grpc_auth_json_key_is_valid(&key)); + grpc_httpcli_set_override(httpcli_get_custom_keys_for_email, + httpcli_post_should_not_be_called); + jwt = grpc_jwt_encode_and_sign(&key, expected_audience, expected_lifetime, + nullptr); + grpc_auth_json_key_destruct(&key); + GPR_ASSERT(jwt != nullptr); + grpc_jwt_verifier_verify(verifier, nullptr, jwt, expected_audience, + on_verification_success, + const_cast(expected_user_data)); + grpc_jwt_verifier_destroy(verifier); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(jwt); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static int httpcli_get_jwk_set(const grpc_httpcli_request* request, + grpc_millis /*deadline*/, grpc_closure* on_done, + grpc_httpcli_response* response) { + *response = http_response(200, gpr_strdup(good_jwk_set)); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + GPR_ASSERT(strcmp(request->host, "www.googleapis.com") == 0); + GPR_ASSERT(strcmp(request->http.path, "/oauth2/v3/certs") == 0); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static int httpcli_get_openid_config(const grpc_httpcli_request* request, + grpc_millis /*deadline*/, + grpc_closure* on_done, + grpc_httpcli_response* response) { + 
*response = http_response(200, gpr_strdup(good_openid_config)); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + GPR_ASSERT(strcmp(request->host, "accounts.google.com") == 0); + GPR_ASSERT(strcmp(request->http.path, GRPC_OPENID_CONFIG_URL_SUFFIX) == 0); + grpc_httpcli_set_override(httpcli_get_jwk_set, + httpcli_post_should_not_be_called); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static void test_jwt_verifier_url_issuer_success(void) { + grpc_core::ExecCtx exec_ctx; + grpc_jwt_verifier* verifier = grpc_jwt_verifier_create(nullptr, 0); + char* jwt = nullptr; + char* key_str = json_key_str(json_key_str_part3_for_url_issuer); + grpc_auth_json_key key = grpc_auth_json_key_create_from_string(key_str); + gpr_free(key_str); + GPR_ASSERT(grpc_auth_json_key_is_valid(&key)); + grpc_httpcli_set_override(httpcli_get_openid_config, + httpcli_post_should_not_be_called); + jwt = grpc_jwt_encode_and_sign(&key, expected_audience, expected_lifetime, + nullptr); + grpc_auth_json_key_destruct(&key); + GPR_ASSERT(jwt != nullptr); + grpc_jwt_verifier_verify(verifier, nullptr, jwt, expected_audience, + on_verification_success, + const_cast(expected_user_data)); + grpc_jwt_verifier_destroy(verifier); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(jwt); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void on_verification_key_retrieval_error(void* user_data, + grpc_jwt_verifier_status status, + grpc_jwt_claims* claims) { + GPR_ASSERT(status == GRPC_JWT_VERIFIER_KEY_RETRIEVAL_ERROR); + GPR_ASSERT(claims == nullptr); + GPR_ASSERT(user_data == (void*)expected_user_data); +} + +static int httpcli_get_bad_json(const grpc_httpcli_request* request, + grpc_millis /*deadline*/, grpc_closure* on_done, + grpc_httpcli_response* response) { + *response = http_response(200, gpr_strdup("{\"bad\": \"stuff\"}")); + GPR_ASSERT(request->handshaker == &grpc_httpcli_ssl); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, on_done, GRPC_ERROR_NONE); + return 1; +} + +static void test_jwt_verifier_url_issuer_bad_config(void) { + grpc_core::ExecCtx exec_ctx; + grpc_jwt_verifier* verifier = grpc_jwt_verifier_create(nullptr, 0); + char* jwt = nullptr; + char* key_str = json_key_str(json_key_str_part3_for_url_issuer); + grpc_auth_json_key key = grpc_auth_json_key_create_from_string(key_str); + gpr_free(key_str); + GPR_ASSERT(grpc_auth_json_key_is_valid(&key)); + grpc_httpcli_set_override(httpcli_get_bad_json, + httpcli_post_should_not_be_called); + jwt = grpc_jwt_encode_and_sign(&key, expected_audience, expected_lifetime, + nullptr); + grpc_auth_json_key_destruct(&key); + GPR_ASSERT(jwt != nullptr); + grpc_jwt_verifier_verify(verifier, nullptr, jwt, expected_audience, + on_verification_key_retrieval_error, + const_cast(expected_user_data)); + grpc_jwt_verifier_destroy(verifier); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(jwt); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void test_jwt_verifier_bad_json_key(void) { + grpc_core::ExecCtx exec_ctx; + grpc_jwt_verifier* verifier = grpc_jwt_verifier_create(nullptr, 0); + char* jwt = nullptr; + char* key_str = json_key_str(json_key_str_part3_for_google_email_issuer); + grpc_auth_json_key key = grpc_auth_json_key_create_from_string(key_str); + gpr_free(key_str); + GPR_ASSERT(grpc_auth_json_key_is_valid(&key)); + grpc_httpcli_set_override(httpcli_get_bad_json, + httpcli_post_should_not_be_called); + jwt = grpc_jwt_encode_and_sign(&key, expected_audience, expected_lifetime, + nullptr); + grpc_auth_json_key_destruct(&key); 
+ GPR_ASSERT(jwt != nullptr); + grpc_jwt_verifier_verify(verifier, nullptr, jwt, expected_audience, + on_verification_key_retrieval_error, + const_cast(expected_user_data)); + grpc_jwt_verifier_destroy(verifier); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(jwt); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static void corrupt_jwt_sig(char* jwt) { + grpc_slice sig; + char* bad_b64_sig; + uint8_t* sig_bytes; + char* last_dot = strrchr(jwt, '.'); + GPR_ASSERT(last_dot != nullptr); + { + grpc_core::ExecCtx exec_ctx; + sig = grpc_base64_decode(last_dot + 1, 1); + } + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(sig)); + sig_bytes = GRPC_SLICE_START_PTR(sig); + (*sig_bytes)++; /* Corrupt first byte. */ + bad_b64_sig = grpc_base64_encode(GRPC_SLICE_START_PTR(sig), + GRPC_SLICE_LENGTH(sig), 1, 0); + memcpy(last_dot + 1, bad_b64_sig, strlen(bad_b64_sig)); + gpr_free(bad_b64_sig); + grpc_slice_unref(sig); +} + +static void on_verification_bad_signature(void* user_data, + grpc_jwt_verifier_status status, + grpc_jwt_claims* claims) { + GPR_ASSERT(status == GRPC_JWT_VERIFIER_BAD_SIGNATURE); + GPR_ASSERT(claims == nullptr); + GPR_ASSERT(user_data == (void*)expected_user_data); +} + +static void test_jwt_verifier_bad_signature(void) { + grpc_core::ExecCtx exec_ctx; + grpc_jwt_verifier* verifier = grpc_jwt_verifier_create(nullptr, 0); + char* jwt = nullptr; + char* key_str = json_key_str(json_key_str_part3_for_url_issuer); + grpc_auth_json_key key = grpc_auth_json_key_create_from_string(key_str); + gpr_free(key_str); + GPR_ASSERT(grpc_auth_json_key_is_valid(&key)); + grpc_httpcli_set_override(httpcli_get_openid_config, + httpcli_post_should_not_be_called); + jwt = grpc_jwt_encode_and_sign(&key, expected_audience, expected_lifetime, + nullptr); + grpc_auth_json_key_destruct(&key); + corrupt_jwt_sig(jwt); + GPR_ASSERT(jwt != nullptr); + grpc_jwt_verifier_verify(verifier, nullptr, jwt, expected_audience, + on_verification_bad_signature, + const_cast(expected_user_data)); + gpr_free(jwt); + grpc_jwt_verifier_destroy(verifier); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +static int httpcli_get_should_not_be_called( + const grpc_httpcli_request* /*request*/, grpc_millis /*deadline*/, + grpc_closure* /*on_done*/, grpc_httpcli_response* /*response*/) { + GPR_ASSERT(0); + return 1; +} + +static void on_verification_bad_format(void* user_data, + grpc_jwt_verifier_status status, + grpc_jwt_claims* claims) { + GPR_ASSERT(status == GRPC_JWT_VERIFIER_BAD_FORMAT); + GPR_ASSERT(claims == nullptr); + GPR_ASSERT(user_data == (void*)expected_user_data); +} + +static void test_jwt_verifier_bad_format(void) { + grpc_core::ExecCtx exec_ctx; + grpc_jwt_verifier* verifier = grpc_jwt_verifier_create(nullptr, 0); + grpc_httpcli_set_override(httpcli_get_should_not_be_called, + httpcli_post_should_not_be_called); + grpc_jwt_verifier_verify(verifier, nullptr, "bad jwt", expected_audience, + on_verification_bad_format, + const_cast(expected_user_data)); + grpc_jwt_verifier_destroy(verifier); + grpc_core::ExecCtx::Get()->Flush(); + grpc_httpcli_set_override(nullptr, nullptr); +} + +/* find verification key: bad jks, cannot find key in jks */ +/* bad signature custom provided email*/ +/* bad key */ + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_jwt_issuer_email_domain(); + test_claims_success(); + test_expired_claims_failure(); + test_invalid_claims_failure(); + test_bad_audience_claims_failure(); + 
test_bad_subject_claims_failure(); + test_jwt_verifier_google_email_issuer_success(); + test_jwt_verifier_custom_email_issuer_success(); + test_jwt_verifier_url_issuer_success(); + test_jwt_verifier_url_issuer_bad_config(); + test_jwt_verifier_bad_json_key(); + test_jwt_verifier_bad_signature(); + test_jwt_verifier_bad_format(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/linux_system_roots_test.cc b/test/core/security/linux_system_roots_test.cc new file mode 100644 index 00000000..318f27de --- /dev/null +++ b/test/core/security/linux_system_roots_test.cc @@ -0,0 +1,100 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#ifdef GPR_LINUX +#include +#include + +#include "gtest/gtest.h" + +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/security_connector/load_system_roots.h" +#include "src/core/lib/security/security_connector/load_system_roots_linux.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/tsi/ssl_transport_security.h" +#include "src/core/tsi/transport_security.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +TEST(AbsoluteFilePathTest, ConcatenatesCorrectly) { + const char* directory = "nonexistent/test/directory"; + const char* filename = "doesnotexist.txt"; + char result_path[MAXPATHLEN]; + grpc_core::GetAbsoluteFilePath(directory, filename, result_path); + EXPECT_STREQ(result_path, "nonexistent/test/directory/doesnotexist.txt"); +} + +TEST(CreateRootCertsBundleTest, ReturnsEmpty) { + // Test that CreateRootCertsBundle returns an empty slice for null or + // nonexistent cert directories. + grpc_slice result_slice = grpc_core::CreateRootCertsBundle(nullptr); + EXPECT_TRUE(GRPC_SLICE_IS_EMPTY(result_slice)); + grpc_slice_unref(result_slice); + result_slice = grpc_core::CreateRootCertsBundle("does/not/exist"); + EXPECT_TRUE(GRPC_SLICE_IS_EMPTY(result_slice)); + grpc_slice_unref(result_slice); +} + +TEST(CreateRootCertsBundleTest, BundlesCorrectly) { + // Test that CreateRootCertsBundle returns a correct slice. + grpc_slice roots_bundle = grpc_empty_slice(); + GRPC_LOG_IF_ERROR( + "load_file", + grpc_load_file("test/core/security/etc/bundle.pem", 1, &roots_bundle)); + // result_slice should have the same content as roots_bundle. + grpc_slice result_slice = + grpc_core::CreateRootCertsBundle("test/core/security/etc/test_roots"); + char* result_str = grpc_slice_to_c_string(result_slice); + char* bundle_str = grpc_slice_to_c_string(roots_bundle); + EXPECT_STREQ(result_str, bundle_str); + // Clean up. 
+ gpr_free(result_str); + gpr_free(bundle_str); + grpc_slice_unref(roots_bundle); + grpc_slice_unref(result_slice); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +#else +int main() { + printf("*** WARNING: this test is only supported on Linux systems ***\n"); + return 0; +} +#endif // GPR_LINUX diff --git a/test/core/security/matchers_test.cc b/test/core/security/matchers_test.cc new file mode 100644 index 00000000..d569b09f --- /dev/null +++ b/test/core/security/matchers_test.cc @@ -0,0 +1,206 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/matchers/matchers.h" + +#include + +namespace grpc_core { + +TEST(StringMatcherTest, ExactMatchCaseSensitive) { + auto string_matcher = + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"exact", /*case_sensitive=*/true); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("exact")); + EXPECT_FALSE(string_matcher->Match("Exact")); + EXPECT_FALSE(string_matcher->Match("exacz")); +} + +TEST(StringMatcherTest, ExactMatchCaseInsensitive) { + auto string_matcher = + StringMatcher::Create(StringMatcher::Type::kExact, + /*matcher=*/"exact", /*case_sensitive=*/false); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("Exact")); + EXPECT_FALSE(string_matcher->Match("Exacz")); +} + +TEST(StringMatcherTest, PrefixMatchCaseSensitive) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kPrefix, + /*matcher=*/"prefix", + /*case_sensitive=*/true); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("prefix-test")); + EXPECT_FALSE(string_matcher->Match("xx-prefix-test")); + EXPECT_FALSE(string_matcher->Match("Prefix-test")); + EXPECT_FALSE(string_matcher->Match("pre-test")); +} + +TEST(StringMatcherTest, PrefixMatchCaseInsensitive) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kPrefix, + /*matcher=*/"prefix", + /*case_sensitive=*/false); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("PREfix-test")); + EXPECT_FALSE(string_matcher->Match("xx-PREfix-test")); + EXPECT_FALSE(string_matcher->Match("PRE-test")); +} + +TEST(StringMatcherTest, SuffixMatchCaseSensitive) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kSuffix, + /*matcher=*/"suffix", + /*case_sensitive=*/true); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("test-suffix")); + EXPECT_FALSE(string_matcher->Match("test-Suffix")); + EXPECT_FALSE(string_matcher->Match("test-suffix-xx")); + EXPECT_FALSE(string_matcher->Match("test-suffiz")); +} + +TEST(StringMatcherTest, SuffixMatchCaseInSensitive) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kSuffix, + /*matcher=*/"suffix", + /*case_sensitive=*/false); + ASSERT_TRUE(string_matcher.ok()); + 
EXPECT_TRUE(string_matcher->Match("Test-SUFFIX")); + EXPECT_FALSE(string_matcher->Match("Test-SUFFIX-xx")); + EXPECT_FALSE(string_matcher->Match("Test-SUFFIZ")); +} + +TEST(StringMatcherTest, InvalidRegex) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kSafeRegex, + /*matcher=*/"a[b-a]", + /*case_sensitive=*/true); + EXPECT_FALSE(string_matcher.ok()); + EXPECT_EQ(string_matcher.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(string_matcher.status().message(), + "Invalid regex string specified in matcher."); +} + +TEST(StringMatcherTest, SafeRegexMatchCaseSensitive) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kSafeRegex, + /*matcher=*/"regex.*"); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("regex-test")); + EXPECT_FALSE(string_matcher->Match("xx-regex-test")); + EXPECT_FALSE(string_matcher->Match("Regex-test")); + EXPECT_FALSE(string_matcher->Match("test-regex")); +} + +TEST(StringMatcherTest, ContainsMatchCaseSensitive) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kContains, + /*matcher=*/"contains", + /*case_sensitive=*/true); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("test-contains")); + EXPECT_TRUE(string_matcher->Match("test-contains-test")); + EXPECT_FALSE(string_matcher->Match("test-Contains")); + EXPECT_FALSE(string_matcher->Match("test-containz")); +} + +TEST(StringMatcherTest, ContainsMatchCaseInSensitive) { + auto string_matcher = StringMatcher::Create(StringMatcher::Type::kContains, + /*matcher=*/"contains", + /*case_sensitive=*/false); + ASSERT_TRUE(string_matcher.ok()); + EXPECT_TRUE(string_matcher->Match("Test-Contains")); + EXPECT_TRUE(string_matcher->Match("Test-Contains-Test")); + EXPECT_FALSE(string_matcher->Match("Test-Containz")); +} + +TEST(HeaderMatcherTest, StringMatcher) { + auto header_matcher = + HeaderMatcher::Create(/*name=*/"key", HeaderMatcher::Type::kExact, + /*matcher=*/"exact"); + ASSERT_TRUE(header_matcher.ok()); + EXPECT_TRUE(header_matcher->Match("exact")); + EXPECT_FALSE(header_matcher->Match("Exact")); + EXPECT_FALSE(header_matcher->Match("exacz")); +} + +TEST(HeaderMatcherTest, StringMatcherWithInvertMatch) { + auto header_matcher = + HeaderMatcher::Create(/*name=*/"key", HeaderMatcher::Type::kExact, + /*matcher=*/"exact", + /*range_start=*/0, /*range_end=*/0, + /*present_match=*/false, /*invert_match=*/true); + ASSERT_TRUE(header_matcher.ok()); + EXPECT_FALSE(header_matcher->Match("exact")); + EXPECT_TRUE(header_matcher->Match("Exact")); + EXPECT_TRUE(header_matcher->Match("exacz")); +} + +TEST(HeaderMatcherTest, InvalidRegex) { + auto header_matcher = + HeaderMatcher::Create(/*name=*/"key", HeaderMatcher::Type::kSafeRegex, + /*matcher=*/"a[b-a]", + /*range_start=*/0, /*range_end=*/0, + /*present_match=*/false, /*invert_match=*/true); + EXPECT_FALSE(header_matcher.ok()); + EXPECT_EQ(header_matcher.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(header_matcher.status().message(), + "Invalid regex string specified in matcher."); +} + +TEST(HeaderMatcherTest, RangeMatcherValidRange) { + auto header_matcher = + HeaderMatcher::Create(/*name=*/"key", HeaderMatcher::Type::kRange, + /*matcher=*/"", /*range_start=*/10, + /*range_end*/ 20); + ASSERT_TRUE(header_matcher.ok()); + EXPECT_TRUE(header_matcher->Match("16")); + EXPECT_TRUE(header_matcher->Match("10")); + EXPECT_FALSE(header_matcher->Match("3")); + EXPECT_FALSE(header_matcher->Match("20")); +} + +TEST(HeaderMatcherTest, 
RangeMatcherInvalidRange) { + auto header_matcher = + HeaderMatcher::Create(/*name=*/"key", HeaderMatcher::Type::kRange, + /*matcher=*/"", /*range_start=*/20, + /*range_end*/ 10); + EXPECT_FALSE(header_matcher.ok()); + EXPECT_EQ(header_matcher.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ( + header_matcher.status().message(), + "Invalid range specifier specified: end cannot be smaller than start."); +} + +TEST(HeaderMatcherTest, PresentMatcherTrue) { + auto header_matcher = + HeaderMatcher::Create(/*name=*/"key", HeaderMatcher::Type::kPresent, + /*matcher=*/"", /*range_start=*/0, + /*range_end=*/0, /*present_match=*/true); + ASSERT_TRUE(header_matcher.ok()); + EXPECT_TRUE(header_matcher->Match("any_value")); + EXPECT_FALSE(header_matcher->Match(absl::nullopt)); +} + +TEST(HeaderMatcherTest, PresentMatcherFalse) { + auto header_matcher = + HeaderMatcher::Create(/*name=*/"key", HeaderMatcher::Type::kPresent, + /*matcher=*/"", /*range_start=*/0, + /*range_end=*/0, /*present_match=*/false); + ASSERT_TRUE(header_matcher.ok()); + EXPECT_FALSE(header_matcher->Match("any_value")); + EXPECT_TRUE(header_matcher->Match(absl::nullopt)); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/security/oauth2_utils.cc b/test/core/security/oauth2_utils.cc new file mode 100644 index 00000000..536696b4 --- /dev/null +++ b/test/core/security/oauth2_utils.cc @@ -0,0 +1,117 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/core/security/oauth2_utils.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/security/credentials/credentials.h" + +typedef struct { + gpr_mu* mu; + grpc_polling_entity pops; + bool is_done; + char* token; + + grpc_credentials_mdelem_array md_array; + grpc_closure closure; +} oauth2_request; + +static void on_oauth2_response(void* arg, grpc_error_handle error) { + oauth2_request* request = static_cast(arg); + char* token = nullptr; + grpc_slice token_slice; + if (error != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Fetching token failed: %s", + grpc_error_std_string(error).c_str()); + } else { + GPR_ASSERT(request->md_array.size == 1); + token_slice = GRPC_MDVALUE(request->md_array.md[0]); + token = static_cast(gpr_malloc(GRPC_SLICE_LENGTH(token_slice) + 1)); + memcpy(token, GRPC_SLICE_START_PTR(token_slice), + GRPC_SLICE_LENGTH(token_slice)); + token[GRPC_SLICE_LENGTH(token_slice)] = '\0'; + } + grpc_credentials_mdelem_array_destroy(&request->md_array); + gpr_mu_lock(request->mu); + request->is_done = true; + request->token = token; + GRPC_LOG_IF_ERROR( + "pollset_kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&request->pops), nullptr)); + gpr_mu_unlock(request->mu); +} + +static void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +char* grpc_test_fetch_oauth2_token_with_credentials( + grpc_call_credentials* creds) { + oauth2_request request; + memset(&request, 0, sizeof(request)); + grpc_core::ExecCtx exec_ctx; + grpc_closure do_nothing_closure; + grpc_auth_metadata_context null_ctx = {"", "", nullptr, nullptr}; + + grpc_pollset* pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &request.mu); + request.pops = grpc_polling_entity_create_from_pollset(pollset); + request.is_done = false; + + GRPC_CLOSURE_INIT(&do_nothing_closure, do_nothing, nullptr, + grpc_schedule_on_exec_ctx); + + GRPC_CLOSURE_INIT(&request.closure, on_oauth2_response, &request, + grpc_schedule_on_exec_ctx); + + grpc_error_handle error = GRPC_ERROR_NONE; + if (creds->get_request_metadata(&request.pops, null_ctx, &request.md_array, + &request.closure, &error)) { + // Synchronous result; invoke callback directly. + on_oauth2_response(&request, error); + GRPC_ERROR_UNREF(error); + } + grpc_core::ExecCtx::Get()->Flush(); + + gpr_mu_lock(request.mu); + while (!request.is_done) { + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(grpc_polling_entity_pollset(&request.pops), + &worker, GRPC_MILLIS_INF_FUTURE))) { + request.is_done = true; + } + } + gpr_mu_unlock(request.mu); + + grpc_pollset_shutdown(grpc_polling_entity_pollset(&request.pops), + &do_nothing_closure); + grpc_core::ExecCtx::Get()->Flush(); + grpc_pollset_destroy(grpc_polling_entity_pollset(&request.pops)); + gpr_free(pollset); + return request.token; +} diff --git a/test/core/security/print_google_default_creds_token.cc b/test/core/security/print_google_default_creds_token.cc new file mode 100644 index 00000000..0d6c7b79 --- /dev/null +++ b/test/core/security/print_google_default_creds_token.cc @@ -0,0 +1,130 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/security/credentials/composite/composite_credentials.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "test/core/util/cmdline.h" + +typedef struct { + gpr_mu* mu; + grpc_polling_entity pops; + bool is_done; + + grpc_credentials_mdelem_array md_array; + grpc_closure on_request_metadata; +} synchronizer; + +static void on_metadata_response(void* arg, grpc_error_handle error) { + synchronizer* sync = static_cast(arg); + if (error != GRPC_ERROR_NONE) { + fprintf(stderr, "Fetching token failed: %s\n", + grpc_error_std_string(error).c_str()); + fflush(stderr); + } else { + char* token; + GPR_ASSERT(sync->md_array.size == 1); + token = grpc_slice_to_c_string(GRPC_MDVALUE(sync->md_array.md[0])); + printf("\nGot token: %s\n\n", token); + gpr_free(token); + } + gpr_mu_lock(sync->mu); + sync->is_done = true; + GRPC_LOG_IF_ERROR( + "pollset_kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&sync->pops), nullptr)); + gpr_mu_unlock(sync->mu); +} + +int main(int argc, char** argv) { + int result = 0; + grpc_core::ExecCtx exec_ctx; + synchronizer sync; + grpc_channel_credentials* creds = nullptr; + const char* service_url = "https://test.foo.google.com/Foo"; + grpc_auth_metadata_context context; + gpr_cmdline* cl = gpr_cmdline_create("print_google_default_creds_token"); + grpc_pollset* pollset = nullptr; + grpc_error_handle error = GRPC_ERROR_NONE; + gpr_cmdline_add_string(cl, "service_url", + "Service URL for the token request.", &service_url); + gpr_cmdline_parse(cl, argc, argv); + memset(&context, 0, sizeof(context)); + context.service_url = service_url; + + grpc_init(); + + creds = grpc_google_default_credentials_create(); + if (creds == nullptr) { + fprintf(stderr, "\nCould not find default credentials.\n\n"); + fflush(stderr); + result = 1; + goto end; + } + + memset(&sync, 0, sizeof(sync)); + pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &sync.mu); + sync.pops = grpc_polling_entity_create_from_pollset(pollset); + sync.is_done = false; + GRPC_CLOSURE_INIT(&sync.on_request_metadata, on_metadata_response, &sync, + grpc_schedule_on_exec_ctx); + + error = GRPC_ERROR_NONE; + if (reinterpret_cast(creds) + ->mutable_call_creds() + ->get_request_metadata(&sync.pops, context, &sync.md_array, + &sync.on_request_metadata, &error)) { + // Synchronous response. Invoke callback directly. 
+ on_metadata_response(&sync, error); + GRPC_ERROR_UNREF(error); + } + + gpr_mu_lock(sync.mu); + while (!sync.is_done) { + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(grpc_polling_entity_pollset(&sync.pops), &worker, + GRPC_MILLIS_INF_FUTURE))) + sync.is_done = true; + gpr_mu_unlock(sync.mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(sync.mu); + } + gpr_mu_unlock(sync.mu); + + grpc_channel_credentials_release(creds); + gpr_free(grpc_polling_entity_pollset(&sync.pops)); + +end: + gpr_cmdline_destroy(cl); + grpc_shutdown(); + return result; +} diff --git a/test/core/security/rbac_translator_test.cc b/test/core/security/rbac_translator_test.cc new file mode 100644 index 00000000..ecc3f8eb --- /dev/null +++ b/test/core/security/rbac_translator_test.cc @@ -0,0 +1,783 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/lib/security/authorization/rbac_translator.h" + +#include +#include + +namespace grpc_core { + +namespace { + +MATCHER_P2(EqualsPrincipalName, expected_matcher_type, expected_matcher_value, + "") { + return arg->type == Rbac::Principal::RuleType::kPrincipalName && + arg->string_matcher.type() == expected_matcher_type && + arg->string_matcher.string_matcher() == expected_matcher_value; +} + +MATCHER_P2(EqualsPath, expected_matcher_type, expected_matcher_value, "") { + return arg->type == Rbac::Permission::RuleType::kPath && + arg->string_matcher.type() == expected_matcher_type && + arg->string_matcher.string_matcher() == expected_matcher_value; +} + +MATCHER_P3(EqualsHeader, expected_name, expected_matcher_type, + expected_matcher_value, "") { + return arg->type == Rbac::Permission::RuleType::kHeader && + arg->header_matcher.name() == expected_name && + arg->header_matcher.type() == expected_matcher_type && + arg->header_matcher.string_matcher() == expected_matcher_value; +} + +} // namespace + +TEST(GenerateRbacPoliciesTest, InvalidPolicy) { + const char* authz_policy = + "{" + " \"name\": \"authz-policy\",," + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + std::string(rbac_policies.status().message()), + ::testing::StartsWith("Failed to parse SDK authorization policy.")); +} + +TEST(GenerateRbacPoliciesTest, MissingAuthorizationPolicyName) { + const char* authz_policy = "{}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), "\"name\" field is not present."); +} + +TEST(GenerateRbacPoliciesTest, IncorrectAuthorizationPolicyNameType) { + const char* authz_policy = + "{" + " \"name\": [\"authz_policy\"]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), "\"name\" is not 
a string."); +} + +TEST(GenerateRbacPoliciesTest, MissingAllowRules) { + const char* authz_policy = + "{" + " \"name\": \"authz_policy\"" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "\"allow_rules\" is not present."); +} + +TEST(GenerateRbacPoliciesTest, MissingDenyRules) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + ASSERT_TRUE(rbac_policies.ok()); + EXPECT_EQ(rbac_policies.value().deny_policy.action, Rbac::Action::kDeny); + EXPECT_TRUE(rbac_policies.value().deny_policy.policies.empty()); +} + +TEST(GenerateRbacPoliciesTest, IncorrectAllowRulesType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": {}" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "\"allow_rules\" is not an array."); +} + +TEST(GenerateRbacPoliciesTest, IncorrectDenyRulesType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"deny_rules\": 123" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "\"deny_rules\" is not an array."); +} + +TEST(GenerateRbacPoliciesTest, IncorrectRuleType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [\"rule-a\"]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: is not an object."); +} + +TEST(GenerateRbacPoliciesTest, MissingRuleNameField) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [{}]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: \"name\" is not present."); +} + +TEST(GenerateRbacPoliciesTest, IncorrectRuleNameType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": 123" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: \"name\" is not a string."); +} + +TEST(GenerateRbacPoliciesTest, MissingSourceAndRequest) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + ASSERT_TRUE(rbac_policies.ok()); + EXPECT_EQ(rbac_policies.value().allow_policy.action, Rbac::Action::kAllow); + EXPECT_THAT( + rbac_policies.value().allow_policy.policies, + ::testing::ElementsAre(::testing::Pair( + "authz_allow_policy", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::principals, + ::testing::Field(&Rbac::Principal::type, + 
Rbac::Principal::RuleType::kAny)))))); +} + +TEST(GenerateRbacPoliciesTest, EmptySourceAndRequest) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"source\": {}," + " \"request\": {}" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + ASSERT_TRUE(rbac_policies.ok()); + EXPECT_EQ(rbac_policies.value().allow_policy.action, Rbac::Action::kAllow); + EXPECT_THAT( + rbac_policies.value().allow_policy.policies, + ::testing::ElementsAre(::testing::Pair( + "authz_allow_policy", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::principals, + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAny)))))); +} + +TEST(GenerateRbacPoliciesTest, IncorrectSourceType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"source\": 111" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: \"source\" is not an object."); +} + +TEST(GenerateRbacPoliciesTest, IncorrectPrincipalsType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"source\": {" + " \"principals\": [" + " \"*\"," + " 123" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: \"principals\" 1: is not a string."); +} + +TEST(GenerateRbacPoliciesTest, ParseSourceSuccess) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"source\": {" + " \"principals\": [" + " \"spiffe://foo.abc\"," + " \"spiffe://bar*\"," + " \"*baz\"," + " \"spiffe://abc.*.com\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_policy\"," + " \"source\": {" + " \"principals\": [" + " \"*\"" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + ASSERT_TRUE(rbac_policies.ok()); + EXPECT_EQ(rbac_policies.value().allow_policy.action, Rbac::Action::kAllow); + EXPECT_THAT( + rbac_policies.value().allow_policy.policies, + ::testing::ElementsAre(::testing::Pair( + "authz_allow_policy", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::principals, + ::testing::AllOf( + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAnd), + ::testing::Field( + &Rbac::Principal::principals, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Principal::type, + Rbac::Principal::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Principal::principals, + ::testing::ElementsAre( + EqualsPrincipalName( + StringMatcher::Type::kExact, + "spiffe://foo.abc"), + EqualsPrincipalName( + StringMatcher::Type::kPrefix, + "spiffe://bar"), + EqualsPrincipalName( + StringMatcher::Type::kSuffix, "baz"), + EqualsPrincipalName( + StringMatcher::Type::kExact, + 
"spiffe://abc.*.com"))))))))))))); + EXPECT_EQ(rbac_policies.value().deny_policy.action, Rbac::Action::kDeny); + EXPECT_THAT( + rbac_policies.value().deny_policy.policies, + ::testing::ElementsAre(::testing::Pair( + "authz_deny_policy", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::principals, + ::testing::AllOf( + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAnd), + ::testing::Field( + &Rbac::Principal::principals, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Principal::type, + Rbac::Principal::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Principal::principals, + ::testing::ElementsAre(EqualsPrincipalName( + StringMatcher::Type::kPrefix, + ""))))))))))))); +} + +TEST(GenerateRbacPoliciesTest, IncorrectRequestType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_policy\"," + " \"request\": 111" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "deny_rules 0: \"request\" is not an object."); +} + +TEST(GenerateRbacPoliciesTest, IncorrectPathType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"deny_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"request\": {" + " \"paths\": [" + " \"path-a\"," + " 123" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "deny_rules 0: \"paths\" 1: is not a string."); +} + +TEST(GenerateRbacPoliciesTest, ParseRequestPathsSuccess) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"request\": {" + " \"paths\": [" + " \"*\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_policy\"," + " \"request\": {" + " \"paths\": [" + " \"path-foo\"," + " \"path-bar*\"," + " \"*baz\"" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + ASSERT_TRUE(rbac_policies.ok()); + EXPECT_EQ(rbac_policies.value().deny_policy.action, Rbac::Action::kDeny); + EXPECT_THAT( + rbac_policies.value().deny_policy.policies, + ::testing::ElementsAre(::testing::Pair( + "authz_deny_policy", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::principals, + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::AllOf( + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAnd), + ::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Permission::type, + Rbac::Permission::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre( + EqualsPath(StringMatcher::Type::kExact, + "path-foo"), + EqualsPath(StringMatcher::Type::kPrefix, + "path-bar"), + EqualsPath(StringMatcher::Type::kSuffix, + "baz"))))))))))))); + EXPECT_EQ(rbac_policies.value().allow_policy.action, Rbac::Action::kAllow); + EXPECT_THAT( + rbac_policies.value().allow_policy.policies, + 
::testing::ElementsAre(::testing::Pair( + "authz_allow_policy", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::principals, + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::AllOf( + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAnd), + ::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Permission::type, + Rbac::Permission::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre( + EqualsPath(StringMatcher::Type::kPrefix, + ""))))))))))))); +} + +TEST(GenerateRbacPoliciesTest, IncorrectHeaderType) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"deny_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"request\": {" + " \"headers\": [" + " \"header-a\"" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "deny_rules 0: \"headers\" 0: is not an object."); +} + +TEST(GenerateRbacPoliciesTest, UnsupportedGrpcHeaders) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"deny_rules\": [" + " {" + " \"name\": \"policy\"," + " \"request\": {" + " \"headers\": [" + " {" + " \"key\": \"grpc-xxx\"," + " \"values\": [" + " \"*\"" + " ]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "deny_rules 0: \"headers\" 0: Unsupported \"key\" grpc-xxx."); +} + +TEST(GenerateRbacPoliciesTest, UnsupportedPseudoHeaders) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"policy\"," + " \"request\": {" + " \"headers\": [" + " {" + " \"key\": \":method\"," + " \"values\": [" + " \"*\"" + " ]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: \"headers\" 0: Unsupported \"key\" :method."); +} + +TEST(GenerateRbacPoliciesTest, UnsupportedHostHeader) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"policy\"," + " \"request\": {" + " \"headers\": [" + " {" + " \"key\": \"Host\"," + " \"values\": [" + " \"*\"" + " ]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: \"headers\" 0: Unsupported \"key\" Host."); +} + +TEST(GenerateRbacPoliciesTest, EmptyHeaderValuesList) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy_1\"," + " \"request\": {" + " \"headers\": [" + " {" + " \"key\": \"key-a\"," + " \"values\": [" + " ]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + EXPECT_EQ(rbac_policies.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(rbac_policies.status().message(), + "allow_rules 0: \"headers\" 0: \"values\" list 
is empty."); +} + +TEST(GenerateRbacPoliciesTest, ParseRequestHeadersSuccess) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy\"," + " \"request\": {" + " \"headers\": [" + " {" + " \"key\": \"key-1\"," + " \"values\": [" + " \"*\"" + " ]" + " }," + " {" + " \"key\": \"key-2\"," + " \"values\": [" + " \"foo\"," + " \"bar*\"," + " \"*baz\"" + " ]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + ASSERT_TRUE(rbac_policies.ok()); + EXPECT_EQ(rbac_policies.value().deny_policy.action, Rbac::Action::kDeny); + EXPECT_TRUE(rbac_policies.value().deny_policy.policies.empty()); + EXPECT_EQ(rbac_policies.value().allow_policy.action, Rbac::Action::kAllow); + EXPECT_THAT( + rbac_policies.value().allow_policy.policies, + ::testing::ElementsAre(::testing::Pair( + "authz_allow_policy", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::principals, + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::AllOf( + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAnd), + ::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Permission::type, + Rbac::Permission::RuleType::kAnd)), + ::testing::Pointee(::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre( + ::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Permission::type, + Rbac::Permission::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre( + EqualsHeader( + "key-1", + HeaderMatcher::Type:: + kPrefix, + ""))))), + ::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Permission::type, + Rbac::Permission::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre( + EqualsHeader("key-2", + HeaderMatcher:: + Type::kExact, + "foo"), + EqualsHeader( + "key-2", + HeaderMatcher::Type:: + kPrefix, + "bar"), + EqualsHeader( + "key-2", + HeaderMatcher::Type:: + kSuffix, + "baz"))))))))))))))))); +} + +TEST(GenerateRbacPoliciesTest, ParseRulesArraySuccess) { + const char* authz_policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_policy_1\"," + " \"source\": {" + " \"principals\": [" + " \"spiffe://foo.abc\"" + " ]" + " }," + " \"request\": {" + " \"paths\": [" + " \"foo\"" + " ]" + " }" + " }," + " {" + " \"name\": \"allow_policy_2\"" + " }" + " ]" + "}"; + auto rbac_policies = GenerateRbacPolicies(authz_policy); + ASSERT_TRUE(rbac_policies.ok()); + EXPECT_EQ(rbac_policies.value().deny_policy.action, Rbac::Action::kDeny); + EXPECT_TRUE(rbac_policies.value().deny_policy.policies.empty()); + EXPECT_EQ(rbac_policies.value().allow_policy.action, Rbac::Action::kAllow); + EXPECT_THAT( + rbac_policies.value().allow_policy.policies, + ::testing::ElementsAre( + ::testing::Pair( + "authz_allow_policy_1", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::AllOf( + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAnd), + ::testing::Field( + &Rbac::Permission::permissions, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Permission::type, + Rbac::Permission::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Permission::permissions, + 
::testing::ElementsAre(EqualsPath( + StringMatcher::Type::kExact, + "foo"))))))))), + ::testing::Field( + &Rbac::Policy::principals, + ::testing::AllOf( + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAnd), + ::testing::Field( + &Rbac::Principal::principals, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Pointee(::testing::Field( + &Rbac::Principal::type, + Rbac::Principal::RuleType::kOr)), + ::testing::Pointee(::testing::Field( + &Rbac::Principal::principals, + ::testing::ElementsAre( + EqualsPrincipalName( + StringMatcher::Type::kExact, + "spiffe://foo.abc"))))))))))), + ::testing::Pair( + "authz_allow_policy_2", + ::testing::AllOf( + ::testing::Field( + &Rbac::Policy::permissions, + ::testing::Field(&Rbac::Permission::type, + Rbac::Permission::RuleType::kAny)), + ::testing::Field( + &Rbac::Policy::principals, + ::testing::Field(&Rbac::Principal::type, + Rbac::Principal::RuleType::kAny)))))); +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/security/secure_endpoint_test.cc b/test/core/security/secure_endpoint_test.cc new file mode 100644 index 00000000..0b3e98c0 --- /dev/null +++ b/test/core/security/secure_endpoint_test.cc @@ -0,0 +1,234 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/security/transport/secure_endpoint.h" + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/fake_transport_security.h" +#include "test/core/iomgr/endpoint_tests.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" + +static gpr_mu* g_mu; +static grpc_pollset* g_pollset; + +static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair( + size_t slice_size, grpc_slice* leftover_slices, size_t leftover_nslices, + bool use_zero_copy_protector) { + grpc_core::ExecCtx exec_ctx; + tsi_frame_protector* fake_read_protector = + tsi_create_fake_frame_protector(nullptr); + tsi_frame_protector* fake_write_protector = + tsi_create_fake_frame_protector(nullptr); + tsi_zero_copy_grpc_protector* fake_read_zero_copy_protector = + use_zero_copy_protector + ? tsi_create_fake_zero_copy_grpc_protector(nullptr) + : nullptr; + tsi_zero_copy_grpc_protector* fake_write_zero_copy_protector = + use_zero_copy_protector + ? 
tsi_create_fake_zero_copy_grpc_protector(nullptr) + : nullptr; + grpc_endpoint_test_fixture f; + grpc_endpoint_pair tcp; + + grpc_arg a[1]; + a[0].key = const_cast(GRPC_ARG_TCP_READ_CHUNK_SIZE); + a[0].type = GRPC_ARG_INTEGER; + a[0].value.integer = static_cast(slice_size); + grpc_channel_args args = {GPR_ARRAY_SIZE(a), a}; + tcp = grpc_iomgr_create_endpoint_pair("fixture", &args); + grpc_endpoint_add_to_pollset(tcp.client, g_pollset); + grpc_endpoint_add_to_pollset(tcp.server, g_pollset); + + if (leftover_nslices == 0) { + f.client_ep = grpc_secure_endpoint_create(fake_read_protector, + fake_read_zero_copy_protector, + tcp.client, nullptr, 0); + } else { + unsigned i; + tsi_result result; + size_t still_pending_size; + size_t total_buffer_size = 8192; + size_t buffer_size = total_buffer_size; + uint8_t* encrypted_buffer = static_cast(gpr_malloc(buffer_size)); + uint8_t* cur = encrypted_buffer; + grpc_slice encrypted_leftover; + for (i = 0; i < leftover_nslices; i++) { + grpc_slice plain = leftover_slices[i]; + uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain); + size_t message_size = GRPC_SLICE_LENGTH(plain); + while (message_size > 0) { + size_t protected_buffer_size_to_send = buffer_size; + size_t processed_message_size = message_size; + result = tsi_frame_protector_protect( + fake_write_protector, message_bytes, &processed_message_size, cur, + &protected_buffer_size_to_send); + GPR_ASSERT(result == TSI_OK); + message_bytes += processed_message_size; + message_size -= processed_message_size; + cur += protected_buffer_size_to_send; + GPR_ASSERT(buffer_size >= protected_buffer_size_to_send); + buffer_size -= protected_buffer_size_to_send; + } + grpc_slice_unref(plain); + } + do { + size_t protected_buffer_size_to_send = buffer_size; + result = tsi_frame_protector_protect_flush(fake_write_protector, cur, + &protected_buffer_size_to_send, + &still_pending_size); + GPR_ASSERT(result == TSI_OK); + cur += protected_buffer_size_to_send; + GPR_ASSERT(buffer_size >= protected_buffer_size_to_send); + buffer_size -= protected_buffer_size_to_send; + } while (still_pending_size > 0); + encrypted_leftover = grpc_slice_from_copied_buffer( + reinterpret_cast(encrypted_buffer), + total_buffer_size - buffer_size); + f.client_ep = grpc_secure_endpoint_create( + fake_read_protector, fake_read_zero_copy_protector, tcp.client, + &encrypted_leftover, 1); + grpc_slice_unref(encrypted_leftover); + gpr_free(encrypted_buffer); + } + + f.server_ep = grpc_secure_endpoint_create(fake_write_protector, + fake_write_zero_copy_protector, + tcp.server, nullptr, 0); + + return f; +} + +static grpc_endpoint_test_fixture +secure_endpoint_create_fixture_tcp_socketpair_noleftover(size_t slice_size) { + return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0, + false); +} + +static grpc_endpoint_test_fixture +secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy( + size_t slice_size) { + return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0, + true); +} + +static grpc_endpoint_test_fixture +secure_endpoint_create_fixture_tcp_socketpair_leftover(size_t slice_size) { + grpc_slice s = + grpc_slice_from_copied_string("hello world 12345678900987654321"); + return secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1, + false); +} + +static grpc_endpoint_test_fixture +secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy( + size_t slice_size) { + grpc_slice s = + grpc_slice_from_copied_string("hello world 12345678900987654321"); + return 
secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1, true); +} + +static void clean_up(void) {} + +static grpc_endpoint_test_config configs[] = { + {"secure_ep/tcp_socketpair", + secure_endpoint_create_fixture_tcp_socketpair_noleftover, clean_up}, + {"secure_ep/tcp_socketpair_zero_copy", + secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy, + clean_up}, + {"secure_ep/tcp_socketpair_leftover", + secure_endpoint_create_fixture_tcp_socketpair_leftover, clean_up}, + {"secure_ep/tcp_socketpair_leftover_zero_copy", + secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy, + clean_up}, +}; + +static void inc_call_ctr(void* arg, grpc_error_handle /*error*/) { + ++*static_cast(arg); +} + +static void test_leftover(grpc_endpoint_test_config config, size_t slice_size) { + grpc_endpoint_test_fixture f = config.create_fixture(slice_size); + grpc_slice_buffer incoming; + grpc_slice s = + grpc_slice_from_copied_string("hello world 12345678900987654321"); + grpc_core::ExecCtx exec_ctx; + int n = 0; + grpc_closure done_closure; + gpr_log(GPR_INFO, "Start test left over"); + + grpc_slice_buffer_init(&incoming); + GRPC_CLOSURE_INIT(&done_closure, inc_call_ctr, &n, grpc_schedule_on_exec_ctx); + grpc_endpoint_read(f.client_ep, &incoming, &done_closure, /*urgent=*/false); + + grpc_core::ExecCtx::Get()->Flush(); + GPR_ASSERT(n == 1); + GPR_ASSERT(incoming.count == 1); + GPR_ASSERT(grpc_slice_eq(s, incoming.slices[0])); + + grpc_endpoint_shutdown( + f.client_ep, GRPC_ERROR_CREATE_FROM_STATIC_STRING("test_leftover end")); + grpc_endpoint_shutdown( + f.server_ep, GRPC_ERROR_CREATE_FROM_STATIC_STRING("test_leftover end")); + grpc_endpoint_destroy(f.client_ep); + grpc_endpoint_destroy(f.server_ep); + + grpc_slice_unref_internal(s); + grpc_slice_buffer_destroy_internal(&incoming); + + clean_up(); +} + +static void destroy_pollset(void* p, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(p)); +} + +int main(int argc, char** argv) { + grpc_closure destroyed; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + { + grpc_core::ExecCtx exec_ctx; + g_pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(g_pollset, &g_mu); + grpc_endpoint_tests(configs[0], g_pollset, g_mu); + grpc_endpoint_tests(configs[1], g_pollset, g_mu); + test_leftover(configs[2], 1); + test_leftover(configs[3], 1); + GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(g_pollset, &destroyed); + } + + grpc_shutdown(); + + gpr_free(g_pollset); + + return 0; +} diff --git a/test/core/security/security_connector_test.cc b/test/core/security/security_connector_test.cc new file mode 100644 index 00000000..fd0555ba --- /dev/null +++ b/test/core/security/security_connector_test.cc @@ -0,0 +1,784 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/security/security_connector/security_connector.h" + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" +#include "src/core/lib/security/security_connector/ssl_utils.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/tsi/ssl_transport_security.h" +#include "src/core/tsi/transport_security.h" +#include "test/core/util/test_config.h" + +#ifndef TSI_OPENSSL_ALPN_SUPPORT +#define TSI_OPENSSL_ALPN_SUPPORT 1 +#endif + +static int check_transport_security_type(const grpc_auth_context* ctx) { + grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( + ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr) return 0; + if (strncmp(prop->value, GRPC_SSL_TRANSPORT_SECURITY_TYPE, + prop->value_length) != 0) { + return 0; + } + /* Check that we have only one property with this name. */ + if (grpc_auth_property_iterator_next(&it) != nullptr) return 0; + return 1; +} + +static int check_peer_property(const tsi_peer* peer, + const tsi_peer_property* expected) { + size_t i; + for (i = 0; i < peer->property_count; i++) { + const tsi_peer_property* prop = &peer->properties[i]; + if ((strcmp(prop->name, expected->name) == 0) && + (prop->value.length == expected->value.length) && + (memcmp(prop->value.data, expected->value.data, + expected->value.length) == 0)) { + return 1; + } + } + return 0; /* Not found... */ +} + +static int check_ssl_peer_equivalence(const tsi_peer* original, + const tsi_peer* reconstructed) { + /* The reconstructed peer only has CN, SAN and pem cert properties. 
*/ + size_t i; + for (i = 0; i < original->property_count; i++) { + const tsi_peer_property* prop = &original->properties[i]; + if ((strcmp(prop->name, TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) || + (strcmp(prop->name, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == + 0) || + (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0)) { + if (!check_peer_property(reconstructed, prop)) return 0; + } + } + return 1; +} + +static void test_check_security_level() { + GPR_ASSERT(grpc_check_security_level(GRPC_PRIVACY_AND_INTEGRITY, + GRPC_PRIVACY_AND_INTEGRITY) == true); + GPR_ASSERT(grpc_check_security_level(GRPC_PRIVACY_AND_INTEGRITY, + GRPC_INTEGRITY_ONLY) == true); + GPR_ASSERT(grpc_check_security_level(GRPC_PRIVACY_AND_INTEGRITY, + GRPC_SECURITY_NONE) == true); + GPR_ASSERT(grpc_check_security_level(GRPC_INTEGRITY_ONLY, + GRPC_PRIVACY_AND_INTEGRITY) == false); + GPR_ASSERT(grpc_check_security_level(GRPC_INTEGRITY_ONLY, + GRPC_INTEGRITY_ONLY) == true); + GPR_ASSERT(grpc_check_security_level(GRPC_INTEGRITY_ONLY, + GRPC_SECURITY_NONE) == true); + GPR_ASSERT(grpc_check_security_level(GRPC_SECURITY_NONE, + GRPC_PRIVACY_AND_INTEGRITY) == false); + GPR_ASSERT(grpc_check_security_level(GRPC_SECURITY_NONE, + GRPC_INTEGRITY_ONLY) == false); + GPR_ASSERT(grpc_check_security_level(GRPC_SECURITY_NONE, + GRPC_SECURITY_NONE) == true); +} + +static void test_unauthenticated_ssl_peer(void) { + tsi_peer peer; + tsi_peer rpeer; + GPR_ASSERT(tsi_construct_peer(2, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer.properties[1]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(!grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT(check_transport_security_type(ctx.get())); + + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); + GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); + + grpc_shallow_peer_destruct(&rpeer); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static int check_identity(const grpc_auth_context* ctx, + const char* expected_property_name, + const char** expected_identities, + size_t num_identities) { + grpc_auth_property_iterator it; + const grpc_auth_property* prop; + size_t i; + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx)); + it = grpc_auth_context_peer_identity(ctx); + for (i = 0; i < num_identities; i++) { + prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr) { + gpr_log(GPR_ERROR, "Expected identity value %s not found.", + expected_identities[i]); + return 0; + } + if (strcmp(prop->name, expected_property_name) != 0) { + gpr_log(GPR_ERROR, "Expected peer identity property name %s and got %s.", + expected_property_name, prop->name); + return 0; + } + if (strncmp(prop->value, expected_identities[i], prop->value_length) != 0) { + gpr_log(GPR_ERROR, "Expected peer identity %s and got %s.", + expected_identities[i], prop->value); + return 0; + } + } + return 1; +} + +static int check_x509_cn(const grpc_auth_context* ctx, + const char* expected_cn) { + grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( + ctx, GRPC_X509_CN_PROPERTY_NAME); + const grpc_auth_property* prop = 
grpc_auth_property_iterator_next(&it); + if (prop == nullptr) { + gpr_log(GPR_ERROR, "CN property not found."); + return 0; + } + if (strncmp(prop->value, expected_cn, prop->value_length) != 0) { + gpr_log(GPR_ERROR, "Expected CN %s and got %s", expected_cn, prop->value); + return 0; + } + if (grpc_auth_property_iterator_next(&it) != nullptr) { + gpr_log(GPR_ERROR, "Expected only one property for CN."); + return 0; + } + return 1; +} + +static int check_x509_pem_cert(const grpc_auth_context* ctx, + const char* expected_pem_cert) { + grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( + ctx, GRPC_X509_PEM_CERT_PROPERTY_NAME); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr) { + gpr_log(GPR_ERROR, "Pem certificate property not found."); + return 0; + } + if (strncmp(prop->value, expected_pem_cert, prop->value_length) != 0) { + gpr_log(GPR_ERROR, "Expected pem cert %s and got %s", expected_pem_cert, + prop->value); + return 0; + } + if (grpc_auth_property_iterator_next(&it) != nullptr) { + gpr_log(GPR_ERROR, "Expected only one property for pem cert."); + return 0; + } + return 1; +} + +static int check_x509_pem_cert_chain(const grpc_auth_context* ctx, + const char* expected_pem_cert_chain) { + grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( + ctx, GRPC_X509_PEM_CERT_CHAIN_PROPERTY_NAME); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr) { + gpr_log(GPR_ERROR, "Pem certificate chain property not found."); + return 0; + } + if (strncmp(prop->value, expected_pem_cert_chain, prop->value_length) != 0) { + gpr_log(GPR_ERROR, "Expected pem cert chain %s and got %s", + expected_pem_cert_chain, prop->value); + return 0; + } + if (grpc_auth_property_iterator_next(&it) != nullptr) { + gpr_log(GPR_ERROR, "Expected only one property for pem cert chain."); + return 0; + } + return 1; +} + +static int check_sans( + const grpc_auth_context* ctx, const char* expected_property_name, + const std::vector& expected_property_values) { + grpc_auth_property_iterator it = + grpc_auth_context_find_properties_by_name(ctx, expected_property_name); + for (const auto& property_value : expected_property_values) { + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr) { + gpr_log(GPR_ERROR, "Expected value %s not found.", + property_value.c_str()); + return 0; + } + if (strncmp(prop->value, property_value.c_str(), prop->value_length) != 0) { + gpr_log(GPR_ERROR, "Expected peer %s and got %s.", property_value.c_str(), + prop->value); + return 0; + } + } + if (grpc_auth_property_iterator_next(&it) != nullptr) { + gpr_log(GPR_ERROR, "Expected only %zu property values.", + expected_property_values.size()); + return 0; + } + return 1; +} + +static int check_spiffe_id(const grpc_auth_context* ctx, + const char* expected_spiffe_id, + bool expect_spiffe_id) { + grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name( + ctx, GRPC_PEER_SPIFFE_ID_PROPERTY_NAME); + const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it); + if (prop == nullptr && !expect_spiffe_id) { + return 1; + } + if (prop != nullptr && !expect_spiffe_id) { + gpr_log(GPR_ERROR, "SPIFFE ID not expected, but got %s.", prop->value); + return 0; + } + if (prop == nullptr && expect_spiffe_id) { + gpr_log(GPR_ERROR, "SPIFFE ID expected, but got nullptr."); + return 0; + } + if (strncmp(prop->value, expected_spiffe_id, prop->value_length) != 0) { + 
gpr_log(GPR_ERROR, "Expected SPIFFE ID %s but got %s.", expected_spiffe_id, + prop->value); + return 0; + } + if (grpc_auth_property_iterator_next(&it) != nullptr) { + gpr_log(GPR_ERROR, "Expected only one property for SPIFFE ID."); + return 0; + } + return 1; +} + +static void test_cn_only_ssl_peer_to_auth_context(void) { + tsi_peer peer; + tsi_peer rpeer; + const char* expected_cn = "cn1"; + const char* expected_pem_cert = "pem_cert1"; + const char* expected_pem_cert_chain = "pem_cert1_chain"; + GPR_ASSERT(tsi_construct_peer(5, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY, expected_cn, + &peer.properties[1]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert, + &peer.properties[2]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer.properties[3]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_CHAIN_PROPERTY, expected_pem_cert_chain, + &peer.properties[4]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT( + check_identity(ctx.get(), GRPC_X509_CN_PROPERTY_NAME, &expected_cn, 1)); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); + GPR_ASSERT(check_x509_pem_cert_chain(ctx.get(), expected_pem_cert_chain)); + + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); + GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); + + grpc_shallow_peer_destruct(&rpeer); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_cn_and_one_san_ssl_peer_to_auth_context(void) { + tsi_peer peer; + tsi_peer rpeer; + const char* expected_cn = "cn1"; + const char* expected_san = "san1"; + const char* expected_pem_cert = "pem_cert1"; + const char* expected_pem_cert_chain = "pem_cert1_chain"; + GPR_ASSERT(tsi_construct_peer(6, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY, expected_cn, + &peer.properties[1]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, expected_san, + &peer.properties[2]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert, + &peer.properties[3]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer.properties[4]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_CHAIN_PROPERTY, expected_pem_cert_chain, + &peer.properties[5]) == TSI_OK); + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, 
GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT( + check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, &expected_san, 1)); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); + GPR_ASSERT(check_x509_pem_cert_chain(ctx.get(), expected_pem_cert_chain)); + + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); + GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); + + grpc_shallow_peer_destruct(&rpeer); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_cn_and_multiple_sans_ssl_peer_to_auth_context(void) { + tsi_peer peer; + tsi_peer rpeer; + const char* expected_cn = "cn1"; + const char* expected_sans[] = {"san1", "san2", "san3"}; + const char* expected_pem_cert = "pem_cert1"; + const char* expected_pem_cert_chain = "pem_cert1_chain"; + size_t i; + GPR_ASSERT(tsi_construct_peer(5 + GPR_ARRAY_SIZE(expected_sans), &peer) == + TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY, expected_cn, + &peer.properties[1]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert, + &peer.properties[2]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer.properties[3]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_CHAIN_PROPERTY, expected_pem_cert_chain, + &peer.properties[4]) == TSI_OK); + for (i = 0; i < GPR_ARRAY_SIZE(expected_sans); i++) { + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, + expected_sans[i], &peer.properties[5 + i]) == TSI_OK); + } + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT(check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, + expected_sans, GPR_ARRAY_SIZE(expected_sans))); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); + GPR_ASSERT(check_x509_pem_cert_chain(ctx.get(), expected_pem_cert_chain)); + + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); + GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); + + grpc_shallow_peer_destruct(&rpeer); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_cn_and_multiple_sans_and_others_ssl_peer_to_auth_context( + void) { + tsi_peer peer; + tsi_peer rpeer; + const char* expected_cn = "cn1"; + const char* expected_pem_cert = "pem_cert1"; + const char* expected_pem_cert_chain = "pem_cert1_chain"; + const char* expected_sans[] = {"san1", "san2", "san3"}; + size_t i; + GPR_ASSERT(tsi_construct_peer(7 + GPR_ARRAY_SIZE(expected_sans), &peer) == + TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE, + &peer.properties[0]) == TSI_OK); + 
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + "foo", "bar", &peer.properties[1]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY, expected_cn, + &peer.properties[2]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + "chapi", "chapo", &peer.properties[3]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert, + &peer.properties[4]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_SECURITY_LEVEL_PEER_PROPERTY, + tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY), + &peer.properties[5]) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_PEM_CERT_CHAIN_PROPERTY, expected_pem_cert_chain, + &peer.properties[6]) == TSI_OK); + for (i = 0; i < GPR_ARRAY_SIZE(expected_sans); i++) { + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, + expected_sans[i], &peer.properties[7 + i]) == TSI_OK); + } + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT(check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, + expected_sans, GPR_ARRAY_SIZE(expected_sans))); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); + GPR_ASSERT(check_x509_pem_cert_chain(ctx.get(), expected_pem_cert_chain)); + + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); + GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); + + grpc_shallow_peer_destruct(&rpeer); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_dns_peer_to_auth_context(void) { + tsi_peer peer; + const std::vector expected_dns = {"dns1", "dns2", "dns3"}; + GPR_ASSERT(tsi_construct_peer(expected_dns.size(), &peer) == TSI_OK); + for (size_t i = 0; i < expected_dns.size(); ++i) { + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_DNS_PEER_PROPERTY, expected_dns[i].c_str(), + &peer.properties[i]) == TSI_OK); + } + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(check_sans(ctx.get(), GRPC_PEER_DNS_PROPERTY_NAME, expected_dns)); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_uri_peer_to_auth_context(void) { + tsi_peer peer; + const std::vector expected_uri = {"uri1", "uri2", "uri3"}; + GPR_ASSERT(tsi_construct_peer(expected_uri.size(), &peer) == TSI_OK); + for (size_t i = 0; i < expected_uri.size(); ++i) { + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_URI_PEER_PROPERTY, expected_uri[i].c_str(), + &peer.properties[i]) == TSI_OK); + } + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(check_sans(ctx.get(), GRPC_PEER_URI_PROPERTY_NAME, expected_uri)); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_email_peer_to_auth_context(void) { + tsi_peer peer; + const std::vector expected_emails = {"email1", "email2"}; + GPR_ASSERT(tsi_construct_peer(expected_emails.size(), &peer) == TSI_OK); + for (size_t i = 0; i < 
expected_emails.size(); ++i) { + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_EMAIL_PEER_PROPERTY, expected_emails[i].c_str(), + &peer.properties[i]) == TSI_OK); + } + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT( + check_sans(ctx.get(), GRPC_PEER_EMAIL_PROPERTY_NAME, expected_emails)); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_ip_peer_to_auth_context(void) { + tsi_peer peer; + const std::vector expected_ips = {"128.128.128.128", + "255.255.255.255"}; + GPR_ASSERT(tsi_construct_peer(expected_ips.size(), &peer) == TSI_OK); + for (size_t i = 0; i < expected_ips.size(); ++i) { + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_IP_PEER_PROPERTY, expected_ips[i].c_str(), + &peer.properties[i]) == TSI_OK); + } + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer, GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(check_sans(ctx.get(), GRPC_PEER_IP_PROPERTY_NAME, expected_ips)); + tsi_peer_destruct(&peer); + ctx.reset(DEBUG_LOCATION, "test"); +} + +static void test_spiffe_id_peer_to_auth_context(void) { + // Invalid SPIFFE IDs should not be plumbed. + std::string long_id(2050, 'x'); + std::string long_domain(256, 'x'); + tsi_peer invalid_peer; + std::vector invalid_spiffe_id = { + "", + "spi://", + "sfiffe://domain/wl", + "spiffe://domain", + "spiffe://domain/", + long_id, + "spiffe://" + long_domain + "/wl"}; + size_t i; + GPR_ASSERT(tsi_construct_peer(invalid_spiffe_id.size(), &invalid_peer) == + TSI_OK); + for (i = 0; i < invalid_spiffe_id.size(); i++) { + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_URI_PEER_PROPERTY, invalid_spiffe_id[i].c_str(), + &invalid_peer.properties[i]) == TSI_OK); + } + grpc_core::RefCountedPtr invalid_ctx = + grpc_ssl_peer_to_auth_context(&invalid_peer, + GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(invalid_ctx != nullptr); + GPR_ASSERT(check_spiffe_id(invalid_ctx.get(), nullptr, false)); + tsi_peer_destruct(&invalid_peer); + invalid_ctx.reset(DEBUG_LOCATION, "test"); + // A valid SPIFFE ID should be plumbed. + tsi_peer valid_peer; + std::string valid_spiffe_id = "spiffe://foo.bar.com/wl"; + GPR_ASSERT(tsi_construct_peer(1, &valid_peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_URI_PEER_PROPERTY, valid_spiffe_id.c_str(), + &valid_peer.properties[0]) == TSI_OK); + grpc_core::RefCountedPtr valid_ctx = + grpc_ssl_peer_to_auth_context(&valid_peer, + GRPC_SSL_TRANSPORT_SECURITY_TYPE); + GPR_ASSERT(valid_ctx != nullptr); + GPR_ASSERT(check_spiffe_id(valid_ctx.get(), "spiffe://foo.bar.com/wl", true)); + tsi_peer_destruct(&valid_peer); + valid_ctx.reset(DEBUG_LOCATION, "test"); + // Multiple SPIFFE IDs should not be plumbed. 
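+ // (The peer constructed below carries two spiffe:// URIs and one https:// URI
+ // among its URI SANs; the test expects no SPIFFE ID property on the
+ // resulting auth context.)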
+ tsi_peer multiple_peer;
+ std::vector<std::string> multiple_spiffe_id = {
+ "spiffe://foo.bar.com/wl", "https://xyz", "spiffe://foo.bar.com/wl2"};
+ GPR_ASSERT(tsi_construct_peer(multiple_spiffe_id.size(), &multiple_peer) ==
+ TSI_OK);
+ for (i = 0; i < multiple_spiffe_id.size(); i++) {
+ GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
+ TSI_X509_URI_PEER_PROPERTY, multiple_spiffe_id[i].c_str(),
+ &multiple_peer.properties[i]) == TSI_OK);
+ }
+ grpc_core::RefCountedPtr<grpc_auth_context> multiple_ctx =
+ grpc_ssl_peer_to_auth_context(&multiple_peer,
+ GRPC_SSL_TRANSPORT_SECURITY_TYPE);
+ GPR_ASSERT(multiple_ctx != nullptr);
+ GPR_ASSERT(check_spiffe_id(multiple_ctx.get(), nullptr, false));
+ tsi_peer_destruct(&multiple_peer);
+ multiple_ctx.reset(DEBUG_LOCATION, "test");
+ // A valid SPIFFE certificate should have only one URI SAN field.
+ // The SPIFFE ID should not be plumbed if there are multiple URIs.
+ tsi_peer multiple_uri_peer;
+ std::vector<std::string> multiple_uri = {"spiffe://foo.bar.com/wl",
+ "https://xyz", "ssh://foo.bar.com/"};
+ GPR_ASSERT(tsi_construct_peer(multiple_uri.size(), &multiple_uri_peer) ==
+ TSI_OK);
+ for (i = 0; i < multiple_uri.size(); i++) {
+ GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
+ TSI_X509_URI_PEER_PROPERTY, multiple_uri[i].c_str(),
+ &multiple_uri_peer.properties[i]) == TSI_OK);
+ }
+ grpc_core::RefCountedPtr<grpc_auth_context> multiple_uri_ctx =
+ grpc_ssl_peer_to_auth_context(&multiple_uri_peer,
+ GRPC_SSL_TRANSPORT_SECURITY_TYPE);
+ GPR_ASSERT(multiple_uri_ctx != nullptr);
+ GPR_ASSERT(check_spiffe_id(multiple_uri_ctx.get(), nullptr, false));
+ tsi_peer_destruct(&multiple_uri_peer);
+ multiple_uri_ctx.reset(DEBUG_LOCATION, "test");
+}
+
+static const char* roots_for_override_api = "roots for override api";
+
+static grpc_ssl_roots_override_result override_roots_success(
+ char** pem_root_certs) {
+ *pem_root_certs = gpr_strdup(roots_for_override_api);
+ return GRPC_SSL_ROOTS_OVERRIDE_OK;
+}
+
+static grpc_ssl_roots_override_result override_roots_permanent_failure(
+ char** /*pem_root_certs*/) {
+ return GRPC_SSL_ROOTS_OVERRIDE_FAIL_PERMANENTLY;
+}
+
+static void test_ipv6_address_san(void) {
+ const char* addresses[] = {
+ "2001:db8::1", "fe80::abcd:ef65:4321%em0", "fd11:feed:beef:0:cafe::4",
+ "128.10.0.1:8888", "[2001:db8::1]:8080", "[2001:db8::1%em1]:8080",
+ };
+ const char* san_ips[] = {
+ "2001:db8::1", "fe80::abcd:ef65:4321", "fd11:feed:beef:0:cafe::4",
+ "128.10.0.1", "2001:db8::1", "2001:db8::1",
+ };
+ tsi_peer peer;
+ GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK);
+ for (size_t i = 0; i < GPR_ARRAY_SIZE(addresses); i++) {
+ GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
+ TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, san_ips[i],
+ &peer.properties[0]) == TSI_OK);
+ GPR_ASSERT(grpc_ssl_host_matches_name(&peer, addresses[i]));
+ tsi_peer_property_destruct(&peer.properties[0]);
+ }
+ tsi_peer_destruct(&peer);
+}
+
+namespace grpc_core {
+namespace {
+
+class TestDefaultSslRootStore : public DefaultSslRootStore {
+ public:
+ static grpc_slice ComputePemRootCertsForTesting() {
+ return ComputePemRootCerts();
+ }
+};
+
+} // namespace
+} // namespace grpc_core
+
+// TODO(unknown): Convert this test to a C++ test when the security_connector
+// implementation is converted to C++.
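For reference, the SPIFFE ID cases above reduce to a small set of shape checks. The hypothetical helper below is illustrative only, not the gRPC implementation; the 2048-byte and 255-byte caps are inferred from the 2050-character ID and 256-character trust domain that the test rejects.

// Standalone sketch of the per-URI validity rules the SPIFFE ID test exercises.
#include <string>

bool LooksLikeValidSpiffeId(const std::string& uri) {
  constexpr char kPrefix[] = "spiffe://";
  constexpr size_t kPrefixLen = sizeof(kPrefix) - 1;
  if (uri.size() > 2048) return false;                         // overall length cap (assumed)
  if (uri.compare(0, kPrefixLen, kPrefix) != 0) return false;  // scheme must be spiffe://
  const size_t slash = uri.find('/', kPrefixLen);
  if (slash == std::string::npos) return false;                // "spiffe://domain" has no path
  const std::string trust_domain = uri.substr(kPrefixLen, slash - kPrefixLen);
  if (trust_domain.empty() || trust_domain.size() > 255) return false;
  return slash + 1 < uri.size();                               // "spiffe://domain/" has an empty path
}

Beyond the per-URI shape, the tests also require exactly one spiffe:// URI SAN on the peer; with zero or several, no SPIFFE ID property is attached to the auth context.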
+static void test_default_ssl_roots(void) { + const char* roots_for_env_var = "roots for env var"; + + char* roots_env_var_file_path; + FILE* roots_env_var_file = + gpr_tmpfile("test_roots_for_env_var", &roots_env_var_file_path); + fwrite(roots_for_env_var, 1, strlen(roots_for_env_var), roots_env_var_file); + fclose(roots_env_var_file); + + /* First let's get the root through the override: set the env to an invalid + value. */ + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, ""); + grpc_set_ssl_roots_override_callback(override_roots_success); + grpc_slice roots = + grpc_core::TestDefaultSslRootStore::ComputePemRootCertsForTesting(); + char* roots_contents = grpc_slice_to_c_string(roots); + grpc_slice_unref(roots); + GPR_ASSERT(strcmp(roots_contents, roots_for_override_api) == 0); + gpr_free(roots_contents); + + /* Now let's set the env var: We should get the contents pointed value + instead. */ + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, + roots_env_var_file_path); + roots = grpc_core::TestDefaultSslRootStore::ComputePemRootCertsForTesting(); + roots_contents = grpc_slice_to_c_string(roots); + grpc_slice_unref(roots); + GPR_ASSERT(strcmp(roots_contents, roots_for_env_var) == 0); + gpr_free(roots_contents); + + /* Now reset the env var. We should fall back to the value overridden using + the api. */ + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, ""); + roots = grpc_core::TestDefaultSslRootStore::ComputePemRootCertsForTesting(); + roots_contents = grpc_slice_to_c_string(roots); + grpc_slice_unref(roots); + GPR_ASSERT(strcmp(roots_contents, roots_for_override_api) == 0); + gpr_free(roots_contents); + + /* Now setup a permanent failure for the overridden roots and we should get + an empty slice. */ + GPR_GLOBAL_CONFIG_SET(grpc_not_use_system_ssl_roots, true); + grpc_set_ssl_roots_override_callback(override_roots_permanent_failure); + roots = grpc_core::TestDefaultSslRootStore::ComputePemRootCertsForTesting(); + GPR_ASSERT(GRPC_SLICE_IS_EMPTY(roots)); + const tsi_ssl_root_certs_store* root_store = + grpc_core::TestDefaultSslRootStore::GetRootStore(); + GPR_ASSERT(root_store == nullptr); + + /* Cleanup. */ + remove(roots_env_var_file_path); + gpr_free(roots_env_var_file_path); +} + +static void test_peer_alpn_check(void) { +#if TSI_OPENSSL_ALPN_SUPPORT + tsi_peer peer; + const char* alpn = "grpc"; + const char* wrong_alpn = "wrong"; + // peer does not have a TSI_SSL_ALPN_SELECTED_PROTOCOL property. + GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property("wrong peer property name", + alpn, strlen(alpn), + &peer.properties[0]) == TSI_OK); + grpc_error_handle error = grpc_ssl_check_alpn(&peer); + GPR_ASSERT(error != GRPC_ERROR_NONE); + tsi_peer_destruct(&peer); + GRPC_ERROR_UNREF(error); + // peer has a TSI_SSL_ALPN_SELECTED_PROTOCOL property but with an incorrect + // property value. + GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property(TSI_SSL_ALPN_SELECTED_PROTOCOL, + wrong_alpn, strlen(wrong_alpn), + &peer.properties[0]) == TSI_OK); + error = grpc_ssl_check_alpn(&peer); + GPR_ASSERT(error != GRPC_ERROR_NONE); + tsi_peer_destruct(&peer); + GRPC_ERROR_UNREF(error); + // peer has a TSI_SSL_ALPN_SELECTED_PROTOCOL property with a correct property + // value. 
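+ // grpc_ssl_check_alpn is expected to accept this peer and return
+ // GRPC_ERROR_NONE.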
+ GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property(TSI_SSL_ALPN_SELECTED_PROTOCOL, + alpn, strlen(alpn), + &peer.properties[0]) == TSI_OK); + GPR_ASSERT(grpc_ssl_check_alpn(&peer) == GRPC_ERROR_NONE); + tsi_peer_destruct(&peer); +#else + GPR_ASSERT(grpc_ssl_check_alpn(nullptr) == GRPC_ERROR_NONE); +#endif +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_unauthenticated_ssl_peer(); + test_cn_only_ssl_peer_to_auth_context(); + test_cn_and_one_san_ssl_peer_to_auth_context(); + test_cn_and_multiple_sans_ssl_peer_to_auth_context(); + test_cn_and_multiple_sans_and_others_ssl_peer_to_auth_context(); + test_dns_peer_to_auth_context(); + test_uri_peer_to_auth_context(); + test_email_peer_to_auth_context(); + test_ip_peer_to_auth_context(); + test_spiffe_id_peer_to_auth_context(); + test_ipv6_address_san(); + test_default_ssl_roots(); + test_peer_alpn_check(); + test_check_security_level(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/ssl_credentials_test.cc b/test/core/security/ssl_credentials_test.cc new file mode 100644 index 00000000..a9833013 --- /dev/null +++ b/test/core/security/ssl_credentials_test.cc @@ -0,0 +1,67 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/security/credentials/ssl/ssl_credentials.h" + +#include +#include + +#include +#include +#include + +#include "src/core/tsi/ssl_transport_security.h" +#include "test/core/util/test_config.h" + +static void test_convert_grpc_to_tsi_cert_pairs() { + grpc_ssl_pem_key_cert_pair grpc_pairs[] = {{"private_key1", "cert_chain1"}, + {"private_key2", "cert_chain2"}, + {"private_key3", "cert_chain3"}}; + const size_t num_pairs = 3; + + { + tsi_ssl_pem_key_cert_pair* tsi_pairs = + grpc_convert_grpc_to_tsi_cert_pairs(grpc_pairs, 0); + GPR_ASSERT(tsi_pairs == nullptr); + } + + { + tsi_ssl_pem_key_cert_pair* tsi_pairs = + grpc_convert_grpc_to_tsi_cert_pairs(grpc_pairs, num_pairs); + + GPR_ASSERT(tsi_pairs != nullptr); + for (size_t i = 0; i < num_pairs; i++) { + GPR_ASSERT(strncmp(grpc_pairs[i].private_key, tsi_pairs[i].private_key, + strlen(grpc_pairs[i].private_key)) == 0); + GPR_ASSERT(strncmp(grpc_pairs[i].cert_chain, tsi_pairs[i].cert_chain, + strlen(grpc_pairs[i].cert_chain)) == 0); + } + + grpc_tsi_ssl_pem_key_cert_pairs_destroy(tsi_pairs, num_pairs); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + test_convert_grpc_to_tsi_cert_pairs(); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/ssl_server_fuzzer.cc b/test/core/security/ssl_server_fuzzer.cc new file mode 100644 index 00000000..08a696cf --- /dev/null +++ b/test/core/security/ssl_server_fuzzer.cc @@ -0,0 +1,126 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "test/core/util/mock_endpoint.h" +#include "test/core/util/resource_user_util.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +bool squelch = true; +// ssl has an array of global gpr_mu's that are never released. +// Turning this on will fail the leak check. +bool leak_check = false; + +static void discard_write(grpc_slice /*slice*/) {} + +static void dont_log(gpr_log_func_args* /*args*/) {} + +struct handshake_state { + bool done_callback_called; +}; + +static void on_handshake_done(void* arg, grpc_error_handle error) { + grpc_core::HandshakerArgs* args = + static_cast(arg); + struct handshake_state* state = + static_cast(args->user_data); + GPR_ASSERT(state->done_callback_called == false); + state->done_callback_called = true; + // The fuzzer should not pass the handshake. + GPR_ASSERT(error != GRPC_ERROR_NONE); +} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (squelch) gpr_set_log_function(dont_log); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + + grpc_slice_allocator* slice_allocator = + grpc_slice_allocator_create_unlimited(); + grpc_endpoint* mock_endpoint = + grpc_mock_endpoint_create(discard_write, slice_allocator); + + grpc_mock_endpoint_put_read( + mock_endpoint, grpc_slice_from_copied_buffer((const char*)data, size)); + + // Load key pair and establish server SSL credentials. 
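+ // (CA_CERT_PATH, SERVER_CERT_PATH and SERVER_KEY_PATH are the checked-in
+ // test credentials defined above; since the fuzzer input stands in for the
+ // raw client bytes, the handshake is never expected to succeed.)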
+ grpc_slice ca_slice, cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* ca_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + grpc_server_credentials* creds = grpc_ssl_server_credentials_create( + ca_cert, &pem_key_cert_pair, 1, 0, nullptr); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_slice_unref(ca_slice); + + // Create security connector + grpc_core::RefCountedPtr sc = + creds->create_security_connector(nullptr); + GPR_ASSERT(sc != nullptr); + grpc_millis deadline = GPR_MS_PER_SEC + grpc_core::ExecCtx::Get()->Now(); + + struct handshake_state state; + state.done_callback_called = false; + auto handshake_mgr = + grpc_core::MakeRefCounted(); + sc->add_handshakers(nullptr, nullptr, handshake_mgr.get()); + handshake_mgr->DoHandshake(mock_endpoint, nullptr /* channel_args */, + deadline, nullptr /* acceptor */, + on_handshake_done, &state); + grpc_core::ExecCtx::Get()->Flush(); + + // If the given string happens to be part of the correct client hello, the + // server will wait for more data. Explicitly fail the server by shutting + // down the endpoint. + if (!state.done_callback_called) { + grpc_endpoint_shutdown( + mock_endpoint, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Explicit close")); + grpc_core::ExecCtx::Get()->Flush(); + } + GPR_ASSERT(state.done_callback_called); + + sc.reset(DEBUG_LOCATION, "test"); + grpc_server_credentials_release(creds); + grpc_core::ExecCtx::Get()->Flush(); + } + + grpc_shutdown(); + return 0; +} diff --git a/test/core/security/tls_security_connector_test.cc b/test/core/security/tls_security_connector_test.cc new file mode 100644 index 00000000..ef09559f --- /dev/null +++ b/test/core/security/tls_security_connector_test.cc @@ -0,0 +1,705 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/security/security_connector/tls/tls_security_connector.h" + +#include +#include + +#include +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.h" +#include "src/core/lib/security/credentials/tls/tls_credentials.h" +#include "src/core/tsi/transport_security.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define CLIENT_CERT_PATH "src/core/tsi/test_creds/multi-domain.pem" +#define SERVER_CERT_PATH_0 "src/core/tsi/test_creds/server0.pem" +#define SERVER_KEY_PATH_0 "src/core/tsi/test_creds/server0.key" +#define SERVER_CERT_PATH_1 "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH_1 "src/core/tsi/test_creds/server1.key" + +namespace grpc { +namespace testing { + +constexpr const char* kRootCertName = "root_cert_name"; +constexpr const char* kIdentityCertName = "identity_cert_name"; +constexpr const char* kErrorMessage = "error_message"; +constexpr const char* kTargetName = "some_target"; + +class TlsSecurityConnectorTest : public ::testing::Test { + protected: + TlsSecurityConnectorTest() {} + void SetUp() override { + grpc_slice ca_slice_1, ca_slice_0, cert_slice_1, key_slice_1, cert_slice_0, + key_slice_0; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice_1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(CLIENT_CERT_PATH, 1, &ca_slice_0))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH_1, 1, &cert_slice_1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_KEY_PATH_1, 1, &key_slice_1))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH_0, 1, &cert_slice_0))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_KEY_PATH_0, 1, &key_slice_0))); + root_cert_1_ = std::string(grpc_core::StringViewFromSlice(ca_slice_1)); + root_cert_0_ = std::string(grpc_core::StringViewFromSlice(ca_slice_0)); + std::string identity_key_1 = + std::string(grpc_core::StringViewFromSlice(key_slice_1)); + std::string identity_key_0 = + std::string(grpc_core::StringViewFromSlice(key_slice_0)); + std::string identity_cert_1 = + std::string(grpc_core::StringViewFromSlice(cert_slice_1)); + std::string identity_cert_0 = + std::string(grpc_core::StringViewFromSlice(cert_slice_0)); + identity_pairs_1_.emplace_back(identity_key_1, identity_cert_1); + identity_pairs_0_.emplace_back(identity_key_0, identity_cert_0); + grpc_slice_unref(ca_slice_1); + grpc_slice_unref(ca_slice_0); + grpc_slice_unref(cert_slice_1); + grpc_slice_unref(key_slice_1); + grpc_slice_unref(cert_slice_0); + grpc_slice_unref(key_slice_0); + } + + void TearDown() override {} + + std::string root_cert_1_; + std::string root_cert_0_; + grpc_core::PemKeyCertPairList identity_pairs_1_; + grpc_core::PemKeyCertPairList identity_pairs_0_; +}; + +class TlsTestCertificateProvider : public ::grpc_tls_certificate_provider { + public: + explicit TlsTestCertificateProvider( + grpc_core::RefCountedPtr distributor) + : distributor_(std::move(distributor)) {} + ~TlsTestCertificateProvider() override {} + grpc_core::RefCountedPtr distributor() + const override { + return distributor_; + } + + private: + grpc_core::RefCountedPtr distributor_; +}; + +// Tests for ChannelSecurityConnector. 
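The channel-side tests below all share the same setup: key material is pushed into a grpc_tls_certificate_distributor, exposed through the static TlsTestCertificateProvider above, and selected via grpc_tls_credentials_options. A condensed sketch of that wiring, using only the constants and calls that appear in the tests themselves (root_cert_0_ and identity_pairs_0_ are the fixture members loaded in SetUp):

// Illustrative wiring only; each test repeats some variation of this.
auto distributor =
    grpc_core::MakeRefCounted<grpc_tls_certificate_distributor>();
distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt);
distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
                             identity_pairs_0_);
auto provider =
    grpc_core::MakeRefCounted<TlsTestCertificateProvider>(distributor);
auto options = grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
options->set_certificate_provider(provider);
options->set_watch_root_cert(true);      // leave unset to fall back to system roots
options->set_watch_identity_pair(true);  // leave unset for a root-only client
options->set_root_cert_name(kRootCertName);
options->set_identity_cert_name(kIdentityCertName);

Each test then builds a TlsCredentials (or TlsServerCredentials) from these options and asserts what the resulting security connector observed.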
+TEST_F(TlsSecurityConnectorTest, + RootAndIdentityCertsObtainedWhenCreateChannelSecurityConnector) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt); + distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt, + identity_pairs_0_); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + grpc_core::RefCountedPtr options = + grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_root_cert(true); + options->set_watch_identity_pair(true); + options->set_root_cert_name(kRootCertName); + options->set_identity_cert_name(kIdentityCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_channel_args* new_args = nullptr; + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr, kTargetName, nullptr, + &new_args); + EXPECT_NE(connector, nullptr); + grpc_core::TlsChannelSecurityConnector* tls_connector = + static_cast(connector.get()); + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_); + distributor->SetKeyMaterials(kRootCertName, root_cert_1_, absl::nullopt); + distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt, + identity_pairs_1_); + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_1_); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_1_); + grpc_channel_args_destroy(new_args); +} + +TEST_F(TlsSecurityConnectorTest, + SystemRootsWhenCreateChannelSecurityConnector) { + // Create options watching for no certificates. + grpc_core::RefCountedPtr root_options = + grpc_core::MakeRefCounted(); + grpc_core::RefCountedPtr root_credential = + grpc_core::MakeRefCounted(root_options); + grpc_channel_args* root_new_args = nullptr; + grpc_core::RefCountedPtr root_connector = + root_credential->create_security_connector(nullptr, "some_target", + nullptr, &root_new_args); + EXPECT_NE(root_connector, nullptr); + grpc_core::TlsChannelSecurityConnector* tls_root_connector = + static_cast( + root_connector.get()); + EXPECT_NE(tls_root_connector->ClientHandshakerFactoryForTesting(), nullptr); + grpc_channel_args_destroy(root_new_args); +} + +TEST_F(TlsSecurityConnectorTest, + SystemRootsAndIdentityCertsObtainedWhenCreateChannelSecurityConnector) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt, + identity_pairs_0_); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + // Create options only watching for identity certificates. 
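+ // (watch_root_cert is deliberately left unset here, so the connector is
+ // expected to fall back to the default system roots.)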
+ grpc_core::RefCountedPtr<grpc_tls_credentials_options> root_options =
+ grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
+ root_options->set_certificate_provider(provider);
+ root_options->set_watch_identity_pair(true);
+ root_options->set_identity_cert_name(kIdentityCertName);
+ grpc_core::RefCountedPtr<grpc_channel_credentials> root_credential =
+ grpc_core::MakeRefCounted<TlsCredentials>(root_options);
+ grpc_channel_args* root_new_args = nullptr;
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> root_connector =
+ root_credential->create_security_connector(nullptr, "some_target",
+ nullptr, &root_new_args);
+ EXPECT_NE(root_connector, nullptr);
+ grpc_core::TlsChannelSecurityConnector* tls_root_connector =
+ static_cast<grpc_core::TlsChannelSecurityConnector*>(
+ root_connector.get());
+ EXPECT_NE(tls_root_connector->ClientHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_root_connector->KeyCertPairListForTesting(), identity_pairs_0_);
+ // If there is a root cert update, the security connector should not receive
+ // it, since we claimed to use the default system roots.
+ distributor->SetKeyMaterials(kRootCertName, root_cert_1_, absl::nullopt);
+ EXPECT_NE(tls_root_connector->ClientHandshakerFactoryForTesting(), nullptr);
+ EXPECT_NE(tls_root_connector->RootCertsForTesting(), root_cert_1_);
+ grpc_channel_args_destroy(root_new_args);
+}
+
+TEST_F(TlsSecurityConnectorTest,
+ RootCertsObtainedWhenCreateChannelSecurityConnector) {
+ grpc_core::RefCountedPtr<grpc_tls_certificate_distributor> distributor =
+ grpc_core::MakeRefCounted<grpc_tls_certificate_distributor>();
+ distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt);
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_0_);
+ grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider =
+ grpc_core::MakeRefCounted<TlsTestCertificateProvider>(distributor);
+ // Create options only watching for root certificates.
+ grpc_core::RefCountedPtr<grpc_tls_credentials_options> root_options =
+ grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
+ root_options->set_certificate_provider(provider);
+ root_options->set_watch_root_cert(true);
+ root_options->set_root_cert_name(kRootCertName);
+ grpc_core::RefCountedPtr<grpc_channel_credentials> root_credential =
+ grpc_core::MakeRefCounted<TlsCredentials>(root_options);
+ grpc_channel_args* root_new_args = nullptr;
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> root_connector =
+ root_credential->create_security_connector(nullptr, "some_target",
+ nullptr, &root_new_args);
+ EXPECT_NE(root_connector, nullptr);
+ grpc_core::TlsChannelSecurityConnector* tls_root_connector =
+ static_cast<grpc_core::TlsChannelSecurityConnector*>(
+ root_connector.get());
+ EXPECT_NE(tls_root_connector->ClientHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_root_connector->RootCertsForTesting(), root_cert_0_);
+ distributor->SetKeyMaterials(kRootCertName, root_cert_1_, absl::nullopt);
+ EXPECT_NE(tls_root_connector->ClientHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_root_connector->RootCertsForTesting(), root_cert_1_);
+ grpc_channel_args_destroy(root_new_args);
+}
+
+TEST_F(TlsSecurityConnectorTest,
+ CertPartiallyObtainedWhenCreateChannelSecurityConnector) {
+ grpc_core::RefCountedPtr<grpc_tls_certificate_distributor> distributor =
+ grpc_core::MakeRefCounted<grpc_tls_certificate_distributor>();
+ distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt);
+ grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider =
+ grpc_core::MakeRefCounted<TlsTestCertificateProvider>(distributor);
+ // Register options watching both certs, but only the root certs are
+ // available at the distributor right now.
+ grpc_core::RefCountedPtr<grpc_tls_credentials_options> options =
+ grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
+ options->set_certificate_provider(provider);
+ options->set_watch_root_cert(true);
+ options->set_watch_identity_pair(true);
+ options->set_root_cert_name(kRootCertName);
+ options->set_identity_cert_name(kIdentityCertName);
+ grpc_core::RefCountedPtr<grpc_channel_credentials> credential =
+ grpc_core::MakeRefCounted<TlsCredentials>(options);
+ grpc_channel_args* new_args = nullptr;
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> connector =
+ credential->create_security_connector(nullptr, kTargetName, nullptr,
+ &new_args);
+ EXPECT_NE(connector, nullptr);
+ grpc_core::TlsChannelSecurityConnector* tls_connector =
+ static_cast<grpc_core::TlsChannelSecurityConnector*>(connector.get());
+ // The client_handshaker_factory_ shouldn't be updated.
+ EXPECT_EQ(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_);
+ // After the identity certs are provided as well, the
+ // client_handshaker_factory_ should be updated.
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_0_);
+ EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_);
+ EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_);
+ grpc_channel_args_destroy(new_args);
+}
+
+TEST_F(TlsSecurityConnectorTest,
+ DistributorHasErrorForChannelSecurityConnector) {
+ grpc_core::RefCountedPtr<grpc_tls_certificate_distributor> distributor =
+ grpc_core::MakeRefCounted<grpc_tls_certificate_distributor>();
+ distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt);
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_0_);
+ grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider =
+ grpc_core::MakeRefCounted<TlsTestCertificateProvider>(distributor);
+ grpc_core::RefCountedPtr<grpc_tls_credentials_options> options =
+ grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
+ options->set_certificate_provider(provider);
+ options->set_watch_root_cert(true);
+ options->set_watch_identity_pair(true);
+ options->set_root_cert_name(kRootCertName);
+ options->set_identity_cert_name(kIdentityCertName);
+ grpc_core::RefCountedPtr<grpc_channel_credentials> credential =
+ grpc_core::MakeRefCounted<TlsCredentials>(options);
+ grpc_channel_args* new_args = nullptr;
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> connector =
+ credential->create_security_connector(nullptr, kTargetName, nullptr,
+ &new_args);
+ EXPECT_NE(connector, nullptr);
+ grpc_core::TlsChannelSecurityConnector* tls_connector =
+ static_cast<grpc_core::TlsChannelSecurityConnector*>(connector.get());
+ EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_);
+ EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_);
+ // Calling SetErrorForCert on the distributor shouldn't invalidate the
+ // previous valid credentials.
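+ // (The connector is expected to keep serving the last good root and
+ // identity material delivered before the errors are reported below.)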
+ distributor->SetErrorForCert( + kRootCertName, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage), + absl::nullopt); + distributor->SetErrorForCert( + kIdentityCertName, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage)); + EXPECT_NE(tls_connector->ClientHandshakerFactoryForTesting(), nullptr); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_); + grpc_channel_args_destroy(new_args); +} + +TEST_F(TlsSecurityConnectorTest, + CreateChannelSecurityConnectorFailNoTargetName) { + grpc_core::RefCountedPtr options = + grpc_core::MakeRefCounted(); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_channel_args* new_args = nullptr; + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr, nullptr, nullptr, + &new_args); + EXPECT_EQ(connector, nullptr); +} + +TEST_F(TlsSecurityConnectorTest, + CreateChannelSecurityConnectorFailNoCredentials) { + auto connector = + grpc_core::TlsChannelSecurityConnector::CreateTlsChannelSecurityConnector( + nullptr, grpc_core::MakeRefCounted(), + nullptr, kTargetName, nullptr, nullptr); + EXPECT_EQ(connector, nullptr); +} + +TEST_F(TlsSecurityConnectorTest, CreateChannelSecurityConnectorFailNoOptions) { + grpc_core::RefCountedPtr options = + grpc_core::MakeRefCounted(); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + auto connector = + grpc_core::TlsChannelSecurityConnector::CreateTlsChannelSecurityConnector( + credential, nullptr, nullptr, kTargetName, nullptr, nullptr); + EXPECT_EQ(connector, nullptr); +} + +TEST_F(TlsSecurityConnectorTest, TlsCheckHostNameSuccess) { + const char* target_name = "foo.test.google.fr"; + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, target_name, + &peer.properties[0]) == TSI_OK); + grpc_error_handle error = + grpc_core::internal::TlsCheckHostName(target_name, &peer); + tsi_peer_destruct(&peer); + EXPECT_EQ(error, GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(error); +} + +TEST_F(TlsSecurityConnectorTest, TlsCheckHostNameFail) { + const char* target_name = "foo.test.google.fr"; + const char* another_name = "bar.test.google.fr"; + tsi_peer peer; + GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, another_name, + &peer.properties[0]) == TSI_OK); + grpc_error_handle error = + grpc_core::internal::TlsCheckHostName(target_name, &peer); + tsi_peer_destruct(&peer); + EXPECT_NE(error, GRPC_ERROR_NONE); + GRPC_ERROR_UNREF(error); +} + +TEST_F(TlsSecurityConnectorTest, + CompareChannelSecurityConnectorSucceedsOnSameCredentials) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + auto options = grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_root_cert(true); + options->set_root_cert_name(kRootCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr, kTargetName, nullptr, + nullptr); + grpc_core::RefCountedPtr other_connector = + 
credential->create_security_connector(nullptr, kTargetName, nullptr, + nullptr); + // Comparing the equality of security connectors generated from the same + // channel credentials with same settings should succeed. + EXPECT_EQ(connector->cmp(other_connector.get()), 0); +} + +TEST_F(TlsSecurityConnectorTest, + CompareChannelSecurityConnectorFailsOnDifferentChannelCredentials) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + auto options = grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_root_cert(true); + options->set_root_cert_name(kRootCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr, kTargetName, nullptr, + nullptr); + grpc_core::RefCountedPtr other_credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr other_connector = + other_credential->create_security_connector(nullptr, kTargetName, nullptr, + nullptr); + // Comparing the equality of security connectors generated from different + // channel credentials should fail. + EXPECT_NE(connector->cmp(other_connector.get()), 0); +} + +TEST_F(TlsSecurityConnectorTest, + CompareChannelSecurityConnectorFailsOnDifferentCallCredentials) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + auto options = grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_root_cert(true); + options->set_root_cert_name(kRootCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr, kTargetName, nullptr, + nullptr); + grpc_call_credentials* call_creds = + grpc_md_only_test_credentials_create("", "", true); + grpc_core::RefCountedPtr other_connector = + credential->create_security_connector( + grpc_core::RefCountedPtr(call_creds), + kTargetName, nullptr, nullptr); + // Comparing the equality of security connectors generated with different call + // credentials should fail. + EXPECT_NE(connector->cmp(other_connector.get()), 0); +} + +TEST_F(TlsSecurityConnectorTest, + CompareChannelSecurityConnectorFailsOnDifferentTargetNames) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + auto options = grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_root_cert(true); + options->set_root_cert_name(kRootCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr, kTargetName, nullptr, + nullptr); + grpc_core::RefCountedPtr other_connector = + credential->create_security_connector(nullptr, "", nullptr, nullptr); + // Comparing the equality of security connectors generated with different + // target names should fail. 
+ EXPECT_NE(connector->cmp(other_connector.get()), 0);
+}
+
+// Tests for ServerSecurityConnector.
+TEST_F(TlsSecurityConnectorTest,
+ RootAndIdentityCertsObtainedWhenCreateServerSecurityConnector) {
+ grpc_core::RefCountedPtr<grpc_tls_certificate_distributor> distributor =
+ grpc_core::MakeRefCounted<grpc_tls_certificate_distributor>();
+ distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt);
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_0_);
+ grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider =
+ grpc_core::MakeRefCounted<TlsTestCertificateProvider>(distributor);
+ grpc_core::RefCountedPtr<grpc_tls_credentials_options> options =
+ grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
+ options->set_certificate_provider(provider);
+ options->set_watch_root_cert(true);
+ options->set_watch_identity_pair(true);
+ options->set_root_cert_name(kRootCertName);
+ options->set_identity_cert_name(kIdentityCertName);
+ grpc_core::RefCountedPtr<grpc_server_credentials> credential =
+ grpc_core::MakeRefCounted<TlsServerCredentials>(options);
+ grpc_core::RefCountedPtr<grpc_server_security_connector> connector =
+ credential->create_security_connector(nullptr);
+ EXPECT_NE(connector, nullptr);
+ grpc_core::TlsServerSecurityConnector* tls_connector =
+ static_cast<grpc_core::TlsServerSecurityConnector*>(connector.get());
+ EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_);
+ EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_);
+ distributor->SetKeyMaterials(kRootCertName, root_cert_1_, absl::nullopt);
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_1_);
+ EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_1_);
+ EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_1_);
+}
+
+// Note that on the server side, we don't have tests watching root certs only,
+// because in TLS the identity certs must always be presented. If we don't
+// provide them, the stack will try to load certs from some default system
+// locations and will hence fail on some systems.
+TEST_F(TlsSecurityConnectorTest,
+ IdentityCertsObtainedWhenCreateServerSecurityConnector) {
+ grpc_core::RefCountedPtr<grpc_tls_certificate_distributor> distributor =
+ grpc_core::MakeRefCounted<grpc_tls_certificate_distributor>();
+ distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt);
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_0_);
+ grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider =
+ grpc_core::MakeRefCounted<TlsTestCertificateProvider>(distributor);
+ // Create options only watching for identity certificates.
+ grpc_core::RefCountedPtr<grpc_tls_credentials_options> identity_options =
+ grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
+ identity_options->set_certificate_provider(provider);
+ identity_options->set_watch_identity_pair(true);
+ identity_options->set_identity_cert_name(kIdentityCertName);
+ grpc_core::RefCountedPtr<grpc_server_credentials> identity_credential =
+ grpc_core::MakeRefCounted<TlsServerCredentials>(identity_options);
+ grpc_core::RefCountedPtr<grpc_server_security_connector> identity_connector =
+ identity_credential->create_security_connector(nullptr);
+ EXPECT_NE(identity_connector, nullptr);
+ grpc_core::TlsServerSecurityConnector* tls_identity_connector =
+ static_cast<grpc_core::TlsServerSecurityConnector*>(
+ identity_connector.get());
+ EXPECT_NE(tls_identity_connector->ServerHandshakerFactoryForTesting(),
+ nullptr);
+ EXPECT_EQ(tls_identity_connector->KeyCertPairListForTesting(),
+ identity_pairs_0_);
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_1_);
+ EXPECT_NE(tls_identity_connector->ServerHandshakerFactoryForTesting(),
+ nullptr);
+ EXPECT_EQ(tls_identity_connector->KeyCertPairListForTesting(),
+ identity_pairs_1_);
+}
+
+TEST_F(TlsSecurityConnectorTest,
+ CertPartiallyObtainedWhenCreateServerSecurityConnector) {
+ grpc_core::RefCountedPtr<grpc_tls_certificate_distributor> distributor =
+ grpc_core::MakeRefCounted<grpc_tls_certificate_distributor>();
+ distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt,
+ identity_pairs_0_);
+ grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider =
+ grpc_core::MakeRefCounted<TlsTestCertificateProvider>(distributor);
+ // Register options watching both certs, but only the identity certs are
+ // available at the distributor right now.
+ grpc_core::RefCountedPtr<grpc_tls_credentials_options> options =
+ grpc_core::MakeRefCounted<grpc_tls_credentials_options>();
+ options->set_certificate_provider(provider);
+ options->set_watch_root_cert(true);
+ options->set_watch_identity_pair(true);
+ options->set_root_cert_name(kRootCertName);
+ options->set_identity_cert_name(kIdentityCertName);
+ grpc_core::RefCountedPtr<grpc_server_credentials> credential =
+ grpc_core::MakeRefCounted<TlsServerCredentials>(options);
+ grpc_core::RefCountedPtr<grpc_server_security_connector> connector =
+ credential->create_security_connector(nullptr);
+ EXPECT_NE(connector, nullptr);
+ grpc_core::TlsServerSecurityConnector* tls_connector =
+ static_cast<grpc_core::TlsServerSecurityConnector*>(connector.get());
+ // The server_handshaker_factory_ shouldn't be updated.
+ EXPECT_EQ(tls_connector->ServerHandshakerFactoryForTesting(), nullptr);
+ EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_);
+ // After the root certs are provided as well, the server_handshaker_factory_
+ // should be updated.
+ distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt); + EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_); +} + +TEST_F(TlsSecurityConnectorTest, + DistributorHasErrorForServerSecurityConnector) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kRootCertName, root_cert_0_, absl::nullopt); + distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt, + identity_pairs_0_); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + grpc_core::RefCountedPtr options = + grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_root_cert(true); + options->set_watch_identity_pair(true); + options->set_root_cert_name(kRootCertName); + options->set_identity_cert_name(kIdentityCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr); + EXPECT_NE(connector, nullptr); + grpc_core::TlsServerSecurityConnector* tls_connector = + static_cast(connector.get()); + EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_); + // Calling SetErrorForCert on distributor shouldn't invalidate the previous + // valid credentials. + distributor->SetErrorForCert( + kRootCertName, GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage), + absl::nullopt); + distributor->SetErrorForCert( + kIdentityCertName, absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kErrorMessage)); + EXPECT_NE(tls_connector->ServerHandshakerFactoryForTesting(), nullptr); + EXPECT_EQ(tls_connector->RootCertsForTesting(), root_cert_0_); + EXPECT_EQ(tls_connector->KeyCertPairListForTesting(), identity_pairs_0_); +} + +TEST_F(TlsSecurityConnectorTest, + CreateServerSecurityConnectorFailNoCredentials) { + auto connector = + grpc_core::TlsServerSecurityConnector::CreateTlsServerSecurityConnector( + nullptr, grpc_core::MakeRefCounted()); + EXPECT_EQ(connector, nullptr); +} + +TEST_F(TlsSecurityConnectorTest, CreateServerSecurityConnectorFailNoOptions) { + grpc_core::RefCountedPtr options = + grpc_core::MakeRefCounted(); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + auto connector = + grpc_core::TlsServerSecurityConnector::CreateTlsServerSecurityConnector( + credential, nullptr); + EXPECT_EQ(connector, nullptr); +} + +TEST_F(TlsSecurityConnectorTest, + CompareServerSecurityConnectorSucceedsOnSameCredentials) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt, + identity_pairs_0_); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + auto options = grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_identity_pair(true); + options->set_identity_cert_name(kIdentityCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr); + grpc_core::RefCountedPtr other_connector = + credential->create_security_connector(nullptr); + 
// Comparing the equality of security connectors generated from the same + // server credentials with same settings should succeed. + EXPECT_EQ(connector->cmp(other_connector.get()), 0); +} + +TEST_F(TlsSecurityConnectorTest, + CompareServerSecurityConnectorFailsOnDifferentServerCredentials) { + grpc_core::RefCountedPtr distributor = + grpc_core::MakeRefCounted(); + distributor->SetKeyMaterials(kIdentityCertName, absl::nullopt, + identity_pairs_0_); + grpc_core::RefCountedPtr<::grpc_tls_certificate_provider> provider = + grpc_core::MakeRefCounted(distributor); + auto options = grpc_core::MakeRefCounted(); + options->set_certificate_provider(provider); + options->set_watch_identity_pair(true); + options->set_identity_cert_name(kIdentityCertName); + grpc_core::RefCountedPtr credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr connector = + credential->create_security_connector(nullptr); + grpc_core::RefCountedPtr other_credential = + grpc_core::MakeRefCounted(options); + grpc_core::RefCountedPtr other_connector = + other_credential->create_security_connector(nullptr); + // Comparing the equality of security connectors generated from different + // server credentials should fail. + EXPECT_NE(connector->cmp(other_connector.get()), 0); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + GPR_GLOBAL_CONFIG_SET(grpc_default_ssl_roots_file_path, CA_CERT_PATH); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/security/verify_jwt.cc b/test/core/security/verify_jwt.cc new file mode 100644 index 00000000..bf951b47 --- /dev/null +++ b/test/core/security/verify_jwt.cc @@ -0,0 +1,120 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/security/credentials/jwt/jwt_verifier.h" +#include "test/core/util/cmdline.h" + +typedef struct { + grpc_pollset* pollset; + gpr_mu* mu; + int is_done; + int success; +} synchronizer; + +static void print_usage_and_exit(gpr_cmdline* cl, const char* argv0) { + std::string usage = gpr_cmdline_usage_string(cl, argv0); + fprintf(stderr, "%s", usage.c_str()); + fflush(stderr); + gpr_cmdline_destroy(cl); + exit(1); +} + +static void on_jwt_verification_done(void* user_data, + grpc_jwt_verifier_status status, + grpc_jwt_claims* claims) { + synchronizer* sync = static_cast(user_data); + + sync->success = (status == GRPC_JWT_VERIFIER_OK); + if (sync->success) { + GPR_ASSERT(claims != nullptr); + std::string claims_str = grpc_jwt_claims_json(claims)->Dump(/*indent=*/2); + printf("Claims: \n\n%s\n", claims_str.c_str()); + grpc_jwt_claims_destroy(claims); + } else { + GPR_ASSERT(claims == nullptr); + fprintf(stderr, "Verification failed with error %s\n", + grpc_jwt_verifier_status_to_string(status)); + fflush(stderr); + } + + gpr_mu_lock(sync->mu); + sync->is_done = 1; + GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(sync->pollset, nullptr)); + gpr_mu_unlock(sync->mu); +} + +int main(int argc, char** argv) { + synchronizer sync; + grpc_jwt_verifier* verifier; + gpr_cmdline* cl; + const char* jwt = nullptr; + const char* aud = nullptr; + grpc_core::ExecCtx exec_ctx; + + grpc_init(); + cl = gpr_cmdline_create("JWT verifier tool"); + gpr_cmdline_add_string(cl, "jwt", "JSON web token to verify", &jwt); + gpr_cmdline_add_string(cl, "aud", "Audience for the JWT", &aud); + gpr_cmdline_parse(cl, argc, argv); + if (jwt == nullptr || aud == nullptr) { + print_usage_and_exit(cl, argv[0]); + } + + verifier = grpc_jwt_verifier_create(nullptr, 0); + + grpc_init(); + + sync.pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(sync.pollset, &sync.mu); + sync.is_done = 0; + + grpc_jwt_verifier_verify(verifier, sync.pollset, jwt, aud, + on_jwt_verification_done, &sync); + + gpr_mu_lock(sync.mu); + while (!sync.is_done) { + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(sync.pollset, &worker, GRPC_MILLIS_INF_FUTURE))) { + sync.is_done = true; + } + gpr_mu_unlock(sync.mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(sync.mu); + } + gpr_mu_unlock(sync.mu); + + gpr_free(sync.pollset); + + grpc_jwt_verifier_destroy(verifier); + + gpr_cmdline_destroy(cl); + grpc_shutdown(); + return !sync.success; +} diff --git a/test/core/security/xds_credentials_test.cc b/test/core/security/xds_credentials_test.cc new file mode 100644 index 00000000..c055972a --- /dev/null +++ b/test/core/security/xds_credentials_test.cc @@ -0,0 +1,304 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include "src/core/lib/security/credentials/xds/xds_credentials.h" + +#include + +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { + +namespace { + +StringMatcher ExactMatcher(const char* string) { + return StringMatcher::Create(StringMatcher::Type::kExact, string).value(); +} + +StringMatcher PrefixMatcher(const char* string, bool case_sensitive = true) { + return StringMatcher::Create(StringMatcher::Type::kPrefix, string, + case_sensitive) + .value(); +} + +StringMatcher SuffixMatcher(const char* string, bool case_sensitive = true) { + return StringMatcher::Create(StringMatcher::Type::kSuffix, string, + case_sensitive) + .value(); +} + +StringMatcher ContainsMatcher(const char* string, bool case_sensitive = true) { + return StringMatcher::Create(StringMatcher::Type::kContains, string, + case_sensitive) + .value(); +} + +StringMatcher SafeRegexMatcher(const char* string) { + return StringMatcher::Create(StringMatcher::Type::kSafeRegex, string).value(); +} + +TEST(XdsSanMatchingTest, EmptySansList) { + std::vector sans = {}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ExactMatcher("a.example.com"), ExactMatcher("b.example.com")})); +} + +TEST(XdsSanMatchingTest, EmptyMatchersList) { + std::vector sans = {"a.example.com", "foo.example.com"}; + EXPECT_TRUE( + TestOnlyXdsVerifySubjectAlternativeNames(sans.data(), sans.size(), {})); +} + +TEST(XdsSanMatchingTest, ExactMatchIllegalValues) { + std::vector sans = {".a.example.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ExactMatcher(""), ExactMatcher("a.example.com"), + ExactMatcher(".a.example.com")})); + sans = {""}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ExactMatcher(""), ExactMatcher("a.example.com"), + ExactMatcher(".a.example.com")})); + sans = {"a.example.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ExactMatcher(""), ExactMatcher("a.example.com"), + ExactMatcher(".a.example.com")})); +} + +TEST(XdsSanMatchingTest, ExactMatchDns) { + std::vector sans = {"a.example.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("a.example.com")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("b.example.com")})); + sans = {"b.example.com."}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("a.example.com.")})); + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("b.example.com.")})); +} + +TEST(XdsSanMatchingTest, ExactMatchWithFullyQualifiedSan) { + std::vector sans = {"a.example.com."}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("a.example.com")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("b.example.com")})); +} + +TEST(XdsSanMatchingTest, ExactMatchWithFullyQualifiedMatcher) { + std::vector sans = {"a.example.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("a.example.com.")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("b.example.com.")})); +} + +TEST(XdsSanMatchingTest, ExactMatchDnsCaseInsensitive) { + std::vector sans = {"A.eXaMpLe.CoM"}; + 
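+  // DNS-style exact matching is case-insensitive, so a mixed-case SAN should
+  // match both a lower-case and a mixed-case exact matcher.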
EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("a.example.com")})); + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("a.ExAmPlE.cOm")})); +} + +TEST(XdsSanMatchingTest, ExactMatchMultipleSansMultipleMatchers) { + std::vector sans = {"a.example.com", "foo.example.com", + "b.example.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ExactMatcher("abc.example.com"), ExactMatcher("foo.example.com"), + ExactMatcher("xyz.example.com")})); +} + +TEST(XdsSanMatchingTest, ExactMatchWildCard) { + std::vector sans = {"*.example.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("a.example.com")})); + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("fOo.ExAmPlE.cOm")})); + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("BaR.eXaMpLe.CoM")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher(".example.com")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("example.com")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("foo.bar.com")})); +} + +TEST(XdsSanMatchingTest, ExactMatchWildCardDoesNotMatchSingleLabelDomain) { + std::vector sans = {"*"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc.com.")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("bar.baz.com")})); + sans = {"*."}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc.com.")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("bar.baz.com")})); +} + +TEST(XdsSanMatchingTest, ExactMatchAsteriskOnlyPermittedInLeftMostDomainName) { + std::vector sans = {"*.example.*.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc.example.xyz.com")})); + sans = {"*.exam*ple.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc.example.com")})); +} + +TEST(XdsSanMatchingTest, + ExactMatchAsteriskMustBeOnlyCharacterInLeftMostDomainName) { + std::vector sans = {"*c.example.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc.example.com")})); +} + +TEST(XdsSanMatchingTest, + ExactMatchAsteriskMatchingAcrossDomainLabelsNotPermitted) { + std::vector sans = {"*.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc.example.com")})); + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("foo.bar.baz.com")})); + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ExactMatcher("abc.com")})); +} + +TEST(XdsSanMatchingTest, PrefixMatch) { + std::vector sans = {"abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames(sans.data(), sans.size(), + {PrefixMatcher("abc")})); + sans = {"AbC.CoM"}; 
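+  // Prefix matching defaults to case-sensitive, so "AbC.CoM" must not match
+  // the lower-case prefix "abc"; the IgnoreCase variant below covers the
+  // case-insensitive path.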
+ EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {PrefixMatcher("abc")})); + sans = {"xyz.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {PrefixMatcher("abc")})); +} + +TEST(XdsSanMatchingTest, PrefixMatchIgnoreCase) { + std::vector sans = {"aBc.cOm"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {PrefixMatcher("AbC", false /* case_sensitive */)})); + sans = {"abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {PrefixMatcher("AbC", false /* case_sensitive */)})); + sans = {"xyz.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {PrefixMatcher("AbC", false /* case_sensitive */)})); +} + +TEST(XdsSanMatchingTest, SuffixMatch) { + std::vector sans = {"abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {SuffixMatcher(".com")})); + sans = {"AbC.CoM"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {SuffixMatcher(".com")})); + sans = {"abc.xyz"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {SuffixMatcher(".com")})); +} + +TEST(XdsSanMatchingTest, SuffixMatchIgnoreCase) { + std::vector sans = {"abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {SuffixMatcher(".CoM", false /* case_sensitive */)})); + sans = {"AbC.cOm"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {SuffixMatcher(".CoM", false /* case_sensitive */)})); + sans = {"abc.xyz"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {SuffixMatcher(".CoM", false /* case_sensitive */)})); +} + +TEST(XdsSanMatchingTest, ContainsMatch) { + std::vector sans = {"abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ContainsMatcher("abc")})); + sans = {"xyz.abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ContainsMatcher("abc")})); + sans = {"foo.AbC.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {ContainsMatcher("abc")})); +} + +TEST(XdsSanMatchingTest, ContainsMatchIgnoresCase) { + std::vector sans = {"abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ContainsMatcher("AbC", false /* case_sensitive */)})); + sans = {"xyz.abc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ContainsMatcher("AbC", false /* case_sensitive */)})); + sans = {"foo.aBc.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ContainsMatcher("AbC", false /* case_sensitive */)})); + sans = {"foo.Ab.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), + {ContainsMatcher("AbC", false /* case_sensitive */)})); +} + +TEST(XdsSanMatchingTest, RegexMatch) { + std::vector sans = {"abc.example.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {SafeRegexMatcher("(abc|xyz).example.com")})); + sans = {"xyz.example.com"}; + EXPECT_TRUE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), {SafeRegexMatcher("(abc|xyz).example.com")})); + sans = {"foo.example.com"}; + EXPECT_FALSE(TestOnlyXdsVerifySubjectAlternativeNames( + sans.data(), sans.size(), 
{SafeRegexMatcher("(abc|xyz).example.com")})); +} + +} // namespace + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/server_config_selector/server_config_selector_test.cc b/test/core/server_config_selector/server_config_selector_test.cc new file mode 100644 index 00000000..74c7659c --- /dev/null +++ b/test/core/server_config_selector/server_config_selector_test.cc @@ -0,0 +1,81 @@ +// +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include "src/core/ext/filters/server_config_selector/server_config_selector.h" + +#include + +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +using grpc_core::ServerConfigSelector; +using grpc_core::ServerConfigSelectorProvider; + +class TestServerConfigSelectorProvider : public ServerConfigSelectorProvider { + absl::StatusOr> Watch( + std::unique_ptr /* watcher */) override { + return absl::UnavailableError("Test ServerConfigSelector"); + } + + void CancelWatch() override {} +}; + +// Test that ServerConfigSelectorProvider can be safely copied to channel args +// and destroyed +TEST(ServerConfigSelectorProviderTest, CopyChannelArgs) { + auto server_config_selector_provider = + grpc_core::MakeRefCounted(); + grpc_arg arg = server_config_selector_provider->MakeChannelArg(); + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + EXPECT_EQ(server_config_selector_provider, + ServerConfigSelectorProvider::GetFromChannelArgs(*args)); + grpc_channel_args_destroy(args); +} + +// Test compare on channel args with the same ServerConfigSelectorProvider +TEST(ServerConfigSelectorProviderTest, ChannelArgsCompare) { + auto server_config_selector_provider = + grpc_core::MakeRefCounted(); + grpc_arg arg = server_config_selector_provider->MakeChannelArg(); + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + grpc_channel_args* new_args = grpc_channel_args_copy(args); + EXPECT_EQ(ServerConfigSelectorProvider::GetFromChannelArgs(*new_args), + ServerConfigSelectorProvider::GetFromChannelArgs(*args)); + grpc_channel_args_destroy(args); + grpc_channel_args_destroy(new_args); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/slice/b64_decode_fuzzer.cc b/test/core/slice/b64_decode_fuzzer.cc new file mode 100644 index 00000000..cb3439af --- /dev/null +++ b/test/core/slice/b64_decode_fuzzer.cc @@ -0,0 +1,38 @@ +/* + * + * Copyright 2019 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include "src/core/lib/slice/b64.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (size < 1) return 0; + grpc_init(); + const bool url_safe = static_cast(0x100) < data[0]; + grpc_slice res = grpc_base64_decode_with_len( + reinterpret_cast(data + 1), size - 1, url_safe); + grpc_slice_unref(res); + grpc_shutdown(); + return 0; +} diff --git a/test/core/slice/b64_encode_fuzzer.cc b/test/core/slice/b64_encode_fuzzer.cc new file mode 100644 index 00000000..5b62fde5 --- /dev/null +++ b/test/core/slice/b64_encode_fuzzer.cc @@ -0,0 +1,37 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include "src/core/lib/slice/b64.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (size < 2) return 0; + const bool url_safe = static_cast(0x100) < data[0]; + const bool multiline = static_cast(0x100) < data[1]; + char* res = grpc_base64_encode(reinterpret_cast(data + 2), + size - 2, url_safe, multiline); + gpr_free(res); + return 0; +} diff --git a/test/core/slice/b64_test.cc b/test/core/slice/b64_test.cc new file mode 100644 index 00000000..28c299f4 --- /dev/null +++ b/test/core/slice/b64_test.cc @@ -0,0 +1,220 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/slice/b64.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +static int buffers_are_equal(const unsigned char* buf1, + const unsigned char* buf2, size_t size) { + size_t i; + for (i = 0; i < size; i++) { + if (buf1[i] != buf2[i]) { + gpr_log(GPR_ERROR, "buf1 and buf2 differ: buf1[%d] = %x vs buf2[%d] = %x", + static_cast(i), buf1[i], static_cast(i), buf2[i]); + return 0; + } + } + return 1; +} + +static void test_simple_encode_decode_b64(int url_safe, int multiline) { + const char* hello = "hello"; + char* hello_b64 = + grpc_base64_encode(hello, strlen(hello), url_safe, multiline); + grpc_core::ExecCtx exec_ctx; + grpc_slice hello_slice = grpc_base64_decode(hello_b64, url_safe); + GPR_ASSERT(GRPC_SLICE_LENGTH(hello_slice) == strlen(hello)); + GPR_ASSERT(strncmp((const char*)GRPC_SLICE_START_PTR(hello_slice), hello, + GRPC_SLICE_LENGTH(hello_slice)) == 0); + + grpc_slice_unref_internal(hello_slice); + + gpr_free(hello_b64); +} + +static void test_full_range_encode_decode_b64(int url_safe, int multiline) { + unsigned char orig[256]; + size_t i; + char* b64; + grpc_slice orig_decoded; + for (i = 0; i < sizeof(orig); i++) orig[i] = static_cast(i); + + /* Try all the different paddings. */ + for (i = 0; i < 3; i++) { + grpc_core::ExecCtx exec_ctx; + b64 = grpc_base64_encode(orig, sizeof(orig) - i, url_safe, multiline); + orig_decoded = grpc_base64_decode(b64, url_safe); + GPR_ASSERT(GRPC_SLICE_LENGTH(orig_decoded) == (sizeof(orig) - i)); + GPR_ASSERT(buffers_are_equal(orig, GRPC_SLICE_START_PTR(orig_decoded), + sizeof(orig) - i)); + grpc_slice_unref_internal(orig_decoded); + gpr_free(b64); + } +} + +static void test_simple_encode_decode_b64_no_multiline(void) { + test_simple_encode_decode_b64(0, 0); +} + +static void test_simple_encode_decode_b64_multiline(void) { + test_simple_encode_decode_b64(0, 1); +} + +static void test_simple_encode_decode_b64_urlsafe_no_multiline(void) { + test_simple_encode_decode_b64(1, 0); +} + +static void test_simple_encode_decode_b64_urlsafe_multiline(void) { + test_simple_encode_decode_b64(1, 1); +} + +static void test_full_range_encode_decode_b64_no_multiline(void) { + test_full_range_encode_decode_b64(0, 0); +} + +static void test_full_range_encode_decode_b64_multiline(void) { + test_full_range_encode_decode_b64(0, 1); +} + +static void test_full_range_encode_decode_b64_urlsafe_no_multiline(void) { + test_full_range_encode_decode_b64(1, 0); +} + +static void test_full_range_encode_decode_b64_urlsafe_multiline(void) { + test_full_range_encode_decode_b64(1, 1); +} + +static void test_url_safe_unsafe_mismatch_failure(void) { + unsigned char orig[256]; + size_t i; + char* b64; + grpc_slice orig_decoded; + int url_safe = 1; + for (i = 0; i < sizeof(orig); i++) orig[i] = static_cast(i); + + grpc_core::ExecCtx exec_ctx; + b64 = grpc_base64_encode(orig, sizeof(orig), url_safe, 0); + orig_decoded = grpc_base64_decode(b64, !url_safe); + GPR_ASSERT(GRPC_SLICE_IS_EMPTY(orig_decoded)); + gpr_free(b64); + grpc_slice_unref_internal(orig_decoded); + + b64 = grpc_base64_encode(orig, sizeof(orig), !url_safe, 0); + orig_decoded = grpc_base64_decode(b64, url_safe); + GPR_ASSERT(GRPC_SLICE_IS_EMPTY(orig_decoded)); + gpr_free(b64); + grpc_slice_unref_internal(orig_decoded); +} + +static void test_rfc4648_test_vectors(void) { + char* b64; + + b64 = grpc_base64_encode("", 0, 0, 0); + GPR_ASSERT(strcmp("", b64) 
== 0); + gpr_free(b64); + + b64 = grpc_base64_encode("f", 1, 0, 0); + GPR_ASSERT(strcmp("Zg==", b64) == 0); + gpr_free(b64); + + b64 = grpc_base64_encode("fo", 2, 0, 0); + GPR_ASSERT(strcmp("Zm8=", b64) == 0); + gpr_free(b64); + + b64 = grpc_base64_encode("foo", 3, 0, 0); + GPR_ASSERT(strcmp("Zm9v", b64) == 0); + gpr_free(b64); + + b64 = grpc_base64_encode("foob", 4, 0, 0); + GPR_ASSERT(strcmp("Zm9vYg==", b64) == 0); + gpr_free(b64); + + b64 = grpc_base64_encode("fooba", 5, 0, 0); + GPR_ASSERT(strcmp("Zm9vYmE=", b64) == 0); + gpr_free(b64); + + b64 = grpc_base64_encode("foobar", 6, 0, 0); + GPR_ASSERT(strcmp("Zm9vYmFy", b64) == 0); + gpr_free(b64); +} + +static void test_unpadded_decode(void) { + grpc_slice decoded; + + grpc_core::ExecCtx exec_ctx; + decoded = grpc_base64_decode("Zm9vYmFy", 0); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(decoded)); + GPR_ASSERT(grpc_slice_str_cmp(decoded, "foobar") == 0); + grpc_slice_unref(decoded); + + decoded = grpc_base64_decode("Zm9vYmE", 0); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(decoded)); + GPR_ASSERT(grpc_slice_str_cmp(decoded, "fooba") == 0); + grpc_slice_unref(decoded); + + decoded = grpc_base64_decode("Zm9vYg", 0); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(decoded)); + GPR_ASSERT(grpc_slice_str_cmp(decoded, "foob") == 0); + grpc_slice_unref(decoded); + + decoded = grpc_base64_decode("Zm9v", 0); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(decoded)); + GPR_ASSERT(grpc_slice_str_cmp(decoded, "foo") == 0); + grpc_slice_unref(decoded); + + decoded = grpc_base64_decode("Zm8", 0); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(decoded)); + GPR_ASSERT(grpc_slice_str_cmp(decoded, "fo") == 0); + grpc_slice_unref(decoded); + + decoded = grpc_base64_decode("Zg", 0); + GPR_ASSERT(!GRPC_SLICE_IS_EMPTY(decoded)); + GPR_ASSERT(grpc_slice_str_cmp(decoded, "f") == 0); + grpc_slice_unref(decoded); + + decoded = grpc_base64_decode("", 0); + GPR_ASSERT(GRPC_SLICE_IS_EMPTY(decoded)); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_simple_encode_decode_b64_no_multiline(); + test_simple_encode_decode_b64_multiline(); + test_simple_encode_decode_b64_urlsafe_no_multiline(); + test_simple_encode_decode_b64_urlsafe_multiline(); + test_full_range_encode_decode_b64_no_multiline(); + test_full_range_encode_decode_b64_multiline(); + test_full_range_encode_decode_b64_urlsafe_no_multiline(); + test_full_range_encode_decode_b64_urlsafe_multiline(); + test_url_safe_unsafe_mismatch_failure(); + test_rfc4648_test_vectors(); + test_unpadded_decode(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/slice/percent_decode_fuzzer.cc b/test/core/slice/percent_decode_fuzzer.cc new file mode 100644 index 00000000..25d90a34 --- /dev/null +++ b/test/core/slice/percent_decode_fuzzer.cc @@ -0,0 +1,50 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/slice/percent_encoding.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + grpc_init(); + grpc_slice input = grpc_slice_from_copied_buffer((const char*)data, size); + absl::optional output; + output = + grpc_core::PercentDecodeSlice(input, grpc_core::PercentEncodingType::URL); + if (output.has_value()) { + grpc_slice_unref(*output); + } + output = grpc_core::PercentDecodeSlice( + input, grpc_core::PercentEncodingType::Compatible); + if (output.has_value()) { + grpc_slice_unref(*output); + } + grpc_slice_unref(grpc_core::PermissivePercentDecodeSlice(input)); + grpc_slice_unref(input); + grpc_shutdown(); + return 0; +} diff --git a/test/core/slice/percent_encode_fuzzer.cc b/test/core/slice/percent_encode_fuzzer.cc new file mode 100644 index 00000000..782d1f97 --- /dev/null +++ b/test/core/slice/percent_encode_fuzzer.cc @@ -0,0 +1,58 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/slice/percent_encoding.h" + +bool squelch = true; +bool leak_check = true; + +static void test(const uint8_t* data, size_t size, + grpc_core::PercentEncodingType type) { + grpc_init(); + grpc_slice input = + grpc_slice_from_copied_buffer(reinterpret_cast(data), size); + grpc_slice output = grpc_core::PercentEncodeSlice(input, type); + absl::optional decoded_output = + grpc_core::PercentDecodeSlice(output, type); + // encoder must always produce decodable output + GPR_ASSERT(decoded_output.has_value()); + grpc_slice permissive_decoded_output = + grpc_core::PermissivePercentDecodeSlice(output); + // and decoded output must always match the input + GPR_ASSERT(grpc_slice_eq(input, *decoded_output)); + GPR_ASSERT(grpc_slice_eq(input, permissive_decoded_output)); + grpc_slice_unref(input); + grpc_slice_unref(output); + grpc_slice_unref(*decoded_output); + grpc_slice_unref(permissive_decoded_output); + grpc_shutdown(); +} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + test(data, size, grpc_core::PercentEncodingType::URL); + test(data, size, grpc_core::PercentEncodingType::Compatible); + return 0; +} diff --git a/test/core/slice/percent_encoding_test.cc b/test/core/slice/percent_encoding_test.cc new file mode 100644 index 00000000..d756637e --- /dev/null +++ b/test/core/slice/percent_encoding_test.cc @@ -0,0 +1,144 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/slice/percent_encoding.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "test/core/util/test_config.h" + +#define TEST_VECTOR(raw, encoded, dict) \ + test_vector(raw, sizeof(raw) - 1, encoded, sizeof(encoded) - 1, dict) + +#define TEST_NONCONFORMANT_VECTOR(encoded, permissive_unencoded, dict) \ + test_nonconformant_vector(encoded, sizeof(encoded) - 1, \ + permissive_unencoded, \ + sizeof(permissive_unencoded) - 1, dict) + +static void test_vector(const char* raw, size_t raw_length, const char* encoded, + size_t encoded_length, + grpc_core::PercentEncodingType type) { + char* raw_msg = gpr_dump(raw, raw_length, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* encoded_msg = + gpr_dump(encoded, encoded_length, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "Trial:\nraw = %s\nencoded = %s", raw_msg, encoded_msg); + gpr_free(raw_msg); + gpr_free(encoded_msg); + + grpc_slice raw_slice = grpc_slice_from_copied_buffer(raw, raw_length); + grpc_slice encoded_slice = + grpc_slice_from_copied_buffer(encoded, encoded_length); + grpc_slice raw2encoded_slice = grpc_core::PercentEncodeSlice(raw_slice, type); + absl::optional encoded2raw_slice = + grpc_core::PercentDecodeSlice(encoded_slice, type); + GPR_ASSERT(encoded2raw_slice.has_value()); + grpc_slice encoded2raw_permissive_slice = + grpc_core::PermissivePercentDecodeSlice(encoded_slice); + + char* raw2encoded_msg = + grpc_dump_slice(raw2encoded_slice, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* encoded2raw_msg = + grpc_dump_slice(*encoded2raw_slice, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* encoded2raw_permissive_msg = grpc_dump_slice( + encoded2raw_permissive_slice, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, + "Result:\nraw2encoded = %s\nencoded2raw = %s\nencoded2raw_permissive " + "= %s", + raw2encoded_msg, encoded2raw_msg, encoded2raw_permissive_msg); + gpr_free(raw2encoded_msg); + gpr_free(encoded2raw_msg); + gpr_free(encoded2raw_permissive_msg); + + GPR_ASSERT(grpc_slice_eq(raw_slice, *encoded2raw_slice)); + GPR_ASSERT(grpc_slice_eq(raw_slice, encoded2raw_permissive_slice)); + GPR_ASSERT(grpc_slice_eq(encoded_slice, raw2encoded_slice)); + + grpc_slice_unref(*encoded2raw_slice); + grpc_slice_unref(encoded2raw_permissive_slice); + grpc_slice_unref(raw2encoded_slice); + grpc_slice_unref(raw_slice); + grpc_slice_unref(encoded_slice); +} + +static void test_nonconformant_vector(const char* encoded, + size_t encoded_length, + const char* permissive_unencoded, + size_t permissive_unencoded_length, + grpc_core::PercentEncodingType type) { + char* permissive_unencoded_msg = + gpr_dump(permissive_unencoded, permissive_unencoded_length, + GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* encoded_msg = + gpr_dump(encoded, encoded_length, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "Trial:\nraw = %s\nencoded = %s", permissive_unencoded_msg, + encoded_msg); + gpr_free(permissive_unencoded_msg); + gpr_free(encoded_msg); + + grpc_slice permissive_unencoded_slice = grpc_slice_from_copied_buffer( + permissive_unencoded, 
permissive_unencoded_length); + grpc_slice encoded_slice = + grpc_slice_from_copied_buffer(encoded, encoded_length); + absl::optional encoded2raw_slice = + grpc_core::PercentDecodeSlice(encoded_slice, type); + GPR_ASSERT(!encoded2raw_slice.has_value()); + grpc_slice encoded2raw_permissive_slice = + grpc_core::PermissivePercentDecodeSlice(encoded_slice); + + char* encoded2raw_permissive_msg = grpc_dump_slice( + encoded2raw_permissive_slice, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_DEBUG, "Result:\nencoded2raw_permissive = %s", + encoded2raw_permissive_msg); + gpr_free(encoded2raw_permissive_msg); + + GPR_ASSERT( + grpc_slice_eq(permissive_unencoded_slice, encoded2raw_permissive_slice)); + + grpc_slice_unref(permissive_unencoded_slice); + grpc_slice_unref(encoded2raw_permissive_slice); + grpc_slice_unref(encoded_slice); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + TEST_VECTOR( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~", + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~", + grpc_core::PercentEncodingType::URL); + TEST_VECTOR("\x00", "%00", grpc_core::PercentEncodingType::URL); + TEST_VECTOR("\x01", "%01", grpc_core::PercentEncodingType::URL); + TEST_VECTOR("a b", "a%20b", grpc_core::PercentEncodingType::URL); + TEST_VECTOR(" b", "%20b", grpc_core::PercentEncodingType::URL); + TEST_VECTOR("a b", "a b", grpc_core::PercentEncodingType::Compatible); + TEST_VECTOR(" b", " b", grpc_core::PercentEncodingType::Compatible); + TEST_VECTOR("\x0f", "%0F", grpc_core::PercentEncodingType::URL); + TEST_VECTOR("\xff", "%FF", grpc_core::PercentEncodingType::URL); + TEST_VECTOR("\xee", "%EE", grpc_core::PercentEncodingType::URL); + TEST_VECTOR("%2", "%252", grpc_core::PercentEncodingType::URL); + TEST_NONCONFORMANT_VECTOR("%", "%", grpc_core::PercentEncodingType::URL); + TEST_NONCONFORMANT_VECTOR("%A", "%A", grpc_core::PercentEncodingType::URL); + TEST_NONCONFORMANT_VECTOR("%AG", "%AG", grpc_core::PercentEncodingType::URL); + TEST_NONCONFORMANT_VECTOR("\0", "\0", grpc_core::PercentEncodingType::URL); + grpc_shutdown(); + return 0; +} diff --git a/test/core/slice/slice_buffer_test.cc b/test/core/slice/slice_buffer_test.cc new file mode 100644 index 00000000..ffa2c7fd --- /dev/null +++ b/test/core/slice/slice_buffer_test.cc @@ -0,0 +1,161 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +void test_slice_buffer_add() { + grpc_slice_buffer buf; + grpc_slice aaa = grpc_slice_from_copied_string("aaa"); + grpc_slice bb = grpc_slice_from_copied_string("bb"); + size_t i; + + grpc_slice_buffer_init(&buf); + for (i = 0; i < 10; i++) { + grpc_slice_ref(aaa); + grpc_slice_ref(bb); + grpc_slice_buffer_add(&buf, aaa); + grpc_slice_buffer_add(&buf, bb); + } + GPR_ASSERT(buf.count > 0); + GPR_ASSERT(buf.length == 50); + grpc_slice_buffer_reset_and_unref(&buf); + GPR_ASSERT(buf.count == 0); + GPR_ASSERT(buf.length == 0); + for (i = 0; i < 10; i++) { + grpc_slice_ref(aaa); + grpc_slice_ref(bb); + grpc_slice_buffer_add(&buf, aaa); + grpc_slice_buffer_add(&buf, bb); + } + GPR_ASSERT(buf.count > 0); + GPR_ASSERT(buf.length == 50); + for (i = 0; i < 10; i++) { + grpc_slice_buffer_pop(&buf); + grpc_slice_unref(aaa); + grpc_slice_unref(bb); + } + GPR_ASSERT(buf.count == 0); + GPR_ASSERT(buf.length == 0); + grpc_slice_buffer_destroy(&buf); +} + +void test_slice_buffer_move_first() { + grpc_slice slices[3]; + grpc_slice_buffer src; + grpc_slice_buffer dst; + int idx = 0; + size_t src_len = 0; + size_t dst_len = 0; + + slices[0] = grpc_slice_from_copied_string("aaa"); + slices[1] = grpc_slice_from_copied_string("bbbb"); + slices[2] = grpc_slice_from_copied_string("ccc"); + + grpc_slice_buffer_init(&src); + grpc_slice_buffer_init(&dst); + for (idx = 0; idx < 3; idx++) { + grpc_slice_ref(slices[idx]); + /* For this test, it is important that we add each slice at a new + slice index */ + grpc_slice_buffer_add_indexed(&src, slices[idx]); + grpc_slice_buffer_add_indexed(&dst, slices[idx]); + } + + /* Case 1: Move more than the first slice's length from src to dst */ + src_len = src.length; + dst_len = dst.length; + grpc_slice_buffer_move_first(&src, 4, &dst); + src_len -= 4; + dst_len += 4; + GPR_ASSERT(src.length == src_len); + GPR_ASSERT(dst.length == dst_len); + + /* src now has two slices ["bbb"] and ["ccc"] */ + /* Case 2: Move the first slice from src to dst */ + grpc_slice_buffer_move_first(&src, 3, &dst); + src_len -= 3; + dst_len += 3; + GPR_ASSERT(src.length == src_len); + GPR_ASSERT(dst.length == dst_len); + + /* src now has one slice ["ccc"] */ + /* Case 3: Move less than the first slice's length from src to dst*/ + grpc_slice_buffer_move_first(&src, 2, &dst); + src_len -= 2; + dst_len += 2; + GPR_ASSERT(src.length == src_len); + GPR_ASSERT(dst.length == dst_len); +} + +void test_slice_buffer_first() { + grpc_slice slices[3]; + slices[0] = grpc_slice_from_copied_string("aaa"); + slices[1] = grpc_slice_from_copied_string("bbbb"); + slices[2] = grpc_slice_from_copied_string("ccccc"); + + grpc_slice_buffer buf; + grpc_slice_buffer_init(&buf); + for (int idx = 0; idx < 3; ++idx) { + grpc_slice_ref(slices[idx]); + grpc_slice_buffer_add_indexed(&buf, slices[idx]); + } + + grpc_slice* first = grpc_slice_buffer_peek_first(&buf); + GPR_ASSERT(GPR_SLICE_LENGTH(*first) == GPR_SLICE_LENGTH(slices[0])); + GPR_ASSERT(buf.count == 3); + GPR_ASSERT(buf.length == 12); + + grpc_slice_buffer_sub_first(&buf, 1, 2); + first = grpc_slice_buffer_peek_first(&buf); + GPR_ASSERT(GPR_SLICE_LENGTH(*first) == 1); + GPR_ASSERT(buf.count == 3); + GPR_ASSERT(buf.length == 10); + + grpc_slice_buffer_remove_first(&buf); + first = grpc_slice_buffer_peek_first(&buf); + GPR_ASSERT(GPR_SLICE_LENGTH(*first) == GPR_SLICE_LENGTH(slices[1])); + GPR_ASSERT(buf.count == 2); + GPR_ASSERT(buf.length == 
9); + + grpc_slice_buffer_remove_first(&buf); + first = grpc_slice_buffer_peek_first(&buf); + GPR_ASSERT(GPR_SLICE_LENGTH(*first) == GPR_SLICE_LENGTH(slices[2])); + GPR_ASSERT(buf.count == 1); + GPR_ASSERT(buf.length == 5); + + grpc_slice_buffer_remove_first(&buf); + GPR_ASSERT(buf.count == 0); + GPR_ASSERT(buf.length == 0); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + test_slice_buffer_add(); + test_slice_buffer_move_first(); + test_slice_buffer_first(); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/slice/slice_intern_test.cc b/test/core/slice/slice_intern_test.cc new file mode 100644 index 00000000..277351cd --- /dev/null +++ b/test/core/slice/slice_intern_test.cc @@ -0,0 +1,101 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST_NAME(x) gpr_log(GPR_INFO, "%s", x); + +static void test_slice_interning(void) { + LOG_TEST_NAME("test_slice_interning"); + + grpc_init(); + grpc_slice src1 = grpc_slice_from_copied_string("hello123456789123456789"); + grpc_slice src2 = grpc_slice_from_copied_string("hello123456789123456789"); + + // Explicitly checking that the slices are at different addresses prevents + // failure with windows opt 64bit build. 
+ // See https://github.com/grpc/grpc/issues/20519 + GPR_ASSERT(&src1 != &src2); + GPR_ASSERT(GRPC_SLICE_START_PTR(src1) != GRPC_SLICE_START_PTR(src2)); + + grpc_slice interned1 = grpc_slice_intern(src1); + grpc_slice interned2 = grpc_slice_intern(src2); + GPR_ASSERT(GRPC_SLICE_START_PTR(interned1) == + GRPC_SLICE_START_PTR(interned2)); + GPR_ASSERT(GRPC_SLICE_START_PTR(interned1) != GRPC_SLICE_START_PTR(src1)); + GPR_ASSERT(GRPC_SLICE_START_PTR(interned2) != GRPC_SLICE_START_PTR(src2)); + grpc_slice_unref(src1); + grpc_slice_unref(src2); + grpc_slice_unref(interned1); + grpc_slice_unref(interned2); + grpc_shutdown(); +} + +static void test_static_slice_interning(void) { + LOG_TEST_NAME("test_static_slice_interning"); + + // grpc_init/grpc_shutdown deliberately omitted: they should not be necessary + // to intern a static slice + + for (size_t i = 0; i < GRPC_STATIC_MDSTR_COUNT; i++) { + GPR_ASSERT(grpc_slice_is_equivalent( + grpc_core::g_static_metadata_slice_table[i], + grpc_slice_intern(grpc_core::g_static_metadata_slice_table[i]))); + } +} + +static void test_static_slice_copy_interning(void) { + LOG_TEST_NAME("test_static_slice_copy_interning"); + + grpc_init(); + + for (size_t i = 0; i < GRPC_STATIC_MDSTR_COUNT; i++) { + grpc_slice copy = + grpc_slice_dup(grpc_core::g_static_metadata_slice_table[i]); + GPR_ASSERT(grpc_core::g_static_metadata_slice_table[i].refcount != + copy.refcount); + GPR_ASSERT(grpc_core::g_static_metadata_slice_table[i].refcount == + grpc_slice_intern(copy).refcount); + grpc_slice_unref(copy); + } + + grpc_shutdown(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_slice_interning(); + test_static_slice_interning(); + test_static_slice_copy_interning(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/slice/slice_split_test.cc b/test/core/slice/slice_split_test.cc new file mode 100644 index 00000000..6e349ac3 --- /dev/null +++ b/test/core/slice/slice_split_test.cc @@ -0,0 +1,174 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/slice/slice_split.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST_NAME(x) gpr_log(GPR_INFO, "%s", x) + +static void test_strsplit(void) { + grpc_slice_buffer* parts; + grpc_slice str; + + LOG_TEST_NAME("test_strsplit"); + + parts = + static_cast(gpr_malloc(sizeof(grpc_slice_buffer))); + grpc_slice_buffer_init(parts); + + str = grpc_slice_from_copied_string("one, two, three, four"); + grpc_slice_split(str, ", ", parts); + GPR_ASSERT(4 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "one")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "two")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[2], "three")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[3], "four")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* separator not present in string */ + str = grpc_slice_from_copied_string("one two three four"); + grpc_slice_split(str, ", ", parts); + GPR_ASSERT(1 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "one two three four")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* separator at the end */ + str = grpc_slice_from_copied_string("foo,"); + grpc_slice_split(str, ",", parts); + GPR_ASSERT(2 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "foo")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* separator at the beginning */ + str = grpc_slice_from_copied_string(",foo"); + grpc_slice_split(str, ",", parts); + GPR_ASSERT(2 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "foo")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* standalone separator */ + str = grpc_slice_from_copied_string(","); + grpc_slice_split(str, ",", parts); + GPR_ASSERT(2 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* empty input */ + str = grpc_slice_from_copied_string(""); + grpc_slice_split(str, ", ", parts); + GPR_ASSERT(1 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + grpc_slice_buffer_destroy(parts); + gpr_free(parts); +} + +static void test_strsplit_nospace(void) { + grpc_slice_buffer* parts; + grpc_slice str; + + LOG_TEST_NAME("test_strsplit_nospace"); + + parts = + static_cast(gpr_malloc(sizeof(grpc_slice_buffer))); + grpc_slice_buffer_init(parts); + + str = grpc_slice_from_copied_string("one ,two, three , four"); + grpc_slice_split_without_space(str, ",", parts); + GPR_ASSERT(4 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "one")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "two")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[2], "three")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[3], "four")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* separator not present in string */ + str = grpc_slice_from_copied_string("one two three four "); + grpc_slice_split_without_space(str, ",", parts); + GPR_ASSERT(1 == parts->count); + GPR_ASSERT(0 == 
grpc_slice_str_cmp(parts->slices[0], "one two three four")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* separator at the end */ + str = grpc_slice_from_copied_string("foo,"); + grpc_slice_split_without_space(str, ",", parts); + GPR_ASSERT(2 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "foo")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* separator at the beginning */ + str = grpc_slice_from_copied_string(" , foo"); + grpc_slice_split_without_space(str, ",", parts); + GPR_ASSERT(2 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "foo")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* standalone separator */ + str = grpc_slice_from_copied_string(", "); + grpc_slice_split_without_space(str, ", ", parts); + GPR_ASSERT(2 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "")); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[1], "")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + /* empty input */ + str = grpc_slice_from_copied_string(""); + grpc_slice_split_without_space(str, ",", parts); + GPR_ASSERT(1 == parts->count); + GPR_ASSERT(0 == grpc_slice_str_cmp(parts->slices[0], "")); + grpc_slice_buffer_reset_and_unref(parts); + grpc_slice_unref(str); + + grpc_slice_buffer_destroy(parts); + gpr_free(parts); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_strsplit(); + test_strsplit_nospace(); + return 0; +} diff --git a/test/core/slice/slice_string_helpers_test.cc b/test/core/slice/slice_string_helpers_test.cc new file mode 100644 index 00000000..53b229c0 --- /dev/null +++ b/test/core/slice/slice_string_helpers_test.cc @@ -0,0 +1,66 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/slice/slice_string_helpers.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_internal.h" + +#define LOG_TEST_NAME(x) gpr_log(GPR_INFO, "%s", x) + +static void expect_slice_dump(grpc_slice slice, uint32_t flags, + const char* result) { + char* got = grpc_dump_slice(slice, flags); + GPR_ASSERT(0 == strcmp(got, result)); + gpr_free(got); + grpc_slice_unref_internal(slice); +} + +static void test_dump_slice(void) { + static const char* text = "HELLO WORLD!"; + static const char* long_text = + "It was a bright cold day in April, and the clocks were striking " + "thirteen. 
Winston Smith, his chin nuzzled into his breast in an effort " + "to escape the vile wind, slipped quickly through the glass doors of " + "Victory Mansions, though not quickly enough to prevent a swirl of " + "gritty dust from entering along with him."; + + LOG_TEST_NAME("test_dump_slice"); + + expect_slice_dump(grpc_slice_from_copied_string(text), GPR_DUMP_ASCII, text); + expect_slice_dump(grpc_slice_from_copied_string(long_text), GPR_DUMP_ASCII, + long_text); + expect_slice_dump(grpc_slice_from_copied_buffer("\x01", 1), GPR_DUMP_HEX, + "01"); + expect_slice_dump(grpc_slice_from_copied_buffer("\x01", 1), + GPR_DUMP_HEX | GPR_DUMP_ASCII, "01 '.'"); +} + +int main(int, char**) { + test_dump_slice(); + return 0; +} diff --git a/test/core/slice/slice_test.cc b/test/core/slice/slice_test.cc new file mode 100644 index 00000000..3da56f16 --- /dev/null +++ b/test/core/slice/slice_test.cc @@ -0,0 +1,293 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/slice/slice_internal.h" + +#define LOG_TEST_NAME(x) gpr_log(GPR_INFO, "%s", x); + +static void test_slice_malloc_returns_something_sensible(void) { + /* Calls grpc_slice_create for various lengths and verifies the internals for + consistency. */ + size_t length; + size_t i; + grpc_slice slice; + + LOG_TEST_NAME("test_slice_malloc_returns_something_sensible"); + + for (length = 0; length <= 1024; length++) { + slice = grpc_slice_malloc(length); + /* If there is a length, slice.data must be non-NULL. If length is zero + we don't care. */ + if (length > GRPC_SLICE_INLINED_SIZE) { + GPR_ASSERT(slice.data.refcounted.bytes); + } + /* Returned slice length must be what was requested. */ + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == length); + /* We must be able to write to every byte of the data */ + for (i = 0; i < length; i++) { + GRPC_SLICE_START_PTR(slice)[i] = static_cast(i); + } + /* And finally we must succeed in destroying the slice */ + grpc_slice_unref_internal(slice); + } +} + +static void do_nothing(void* /*ignored*/) {} + +static void test_slice_new_returns_something_sensible(void) { + uint8_t x; + + grpc_slice slice = grpc_slice_new(&x, 1, do_nothing); + GPR_ASSERT(slice.refcount); + GPR_ASSERT(slice.data.refcounted.bytes == &x); + GPR_ASSERT(slice.data.refcounted.length == 1); + grpc_slice_unref_internal(slice); +} + +/* destroy function that sets a mark to indicate it was called. 
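+   The mark lets test_slice_new_with_user_data below verify that the
+   user-supplied destroy function runs when the slice is unreffed.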
*/ +static void set_mark(void* p) { *(static_cast(p)) = 1; } + +static void test_slice_new_with_user_data(void) { + int marker = 0; + uint8_t buf[2]; + grpc_slice slice; + + buf[0] = 0; + buf[1] = 1; + slice = grpc_slice_new_with_user_data(buf, 2, set_mark, &marker); + GPR_ASSERT(marker == 0); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == 2); + GPR_ASSERT(GRPC_SLICE_START_PTR(slice)[0] == 0); + GPR_ASSERT(GRPC_SLICE_START_PTR(slice)[1] == 1); + + /* unref should cause destroy function to run. */ + grpc_slice_unref_internal(slice); + GPR_ASSERT(marker == 1); +} + +static int do_nothing_with_len_1_calls = 0; + +static void do_nothing_with_len_1(void* /*ignored*/, size_t len) { + GPR_ASSERT(len == 1); + do_nothing_with_len_1_calls++; +} + +static void test_slice_new_with_len_returns_something_sensible(void) { + uint8_t x; + int num_refs = 5; /* To test adding/removing an arbitrary number of refs */ + int i; + + grpc_slice slice = grpc_slice_new_with_len(&x, 1, do_nothing_with_len_1); + GPR_ASSERT(slice.refcount); /* ref count is initialized to 1 at this point */ + GPR_ASSERT(slice.data.refcounted.bytes == &x); + GPR_ASSERT(slice.data.refcounted.length == 1); + GPR_ASSERT(do_nothing_with_len_1_calls == 0); + + /* Add an arbitrary number of refs to the slice and remoe the refs. This is to + make sure that that the destroy callback (i.e do_nothing_with_len_1()) is + not called until the last unref operation */ + for (i = 0; i < num_refs; i++) { + grpc_slice_ref_internal(slice); + } + for (i = 0; i < num_refs; i++) { + grpc_slice_unref_internal(slice); + } + GPR_ASSERT(do_nothing_with_len_1_calls == 0); /* Shouldn't be called yet */ + + /* last unref */ + grpc_slice_unref_internal(slice); + GPR_ASSERT(do_nothing_with_len_1_calls == 1); +} + +static void test_slice_sub_works(unsigned length) { + grpc_slice slice; + grpc_slice sub; + unsigned i, j, k; + + LOG_TEST_NAME("test_slice_sub_works"); + gpr_log(GPR_INFO, "length=%d", length); + + /* Create a slice in which each byte is equal to the distance from it to the + beginning of the slice. */ + slice = grpc_slice_malloc(length); + for (i = 0; i < length; i++) { + GRPC_SLICE_START_PTR(slice)[i] = static_cast(i); + } + + /* Ensure that for all subsets length is correct and that we start on the + correct byte. Additionally check that no copies were made. */ + for (i = 0; i < length; i++) { + for (j = i; j < length; j++) { + sub = grpc_slice_sub(slice, i, j); + GPR_ASSERT(GRPC_SLICE_LENGTH(sub) == j - i); + for (k = 0; k < j - i; k++) { + GPR_ASSERT(GRPC_SLICE_START_PTR(sub)[k] == (uint8_t)(i + k)); + } + grpc_slice_unref_internal(sub); + } + } + grpc_slice_unref_internal(slice); +} + +static void check_head_tail(grpc_slice slice, grpc_slice head, + grpc_slice tail) { + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == + GRPC_SLICE_LENGTH(head) + GRPC_SLICE_LENGTH(tail)); + GPR_ASSERT(0 == memcmp(GRPC_SLICE_START_PTR(slice), + GRPC_SLICE_START_PTR(head), GRPC_SLICE_LENGTH(head))); + GPR_ASSERT(0 == memcmp(GRPC_SLICE_START_PTR(slice) + GRPC_SLICE_LENGTH(head), + GRPC_SLICE_START_PTR(tail), GRPC_SLICE_LENGTH(tail))); +} + +static void test_slice_split_head_works(size_t length) { + grpc_slice slice; + grpc_slice head, tail; + size_t i; + + LOG_TEST_NAME("test_slice_split_head_works"); + gpr_log(GPR_INFO, "length=%" PRIuPTR, length); + + /* Create a slice in which each byte is equal to the distance from it to the + beginning of the slice. 
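+     Splitting at every possible index below and re-checking the bytes via
+     check_head_tail verifies that head and tail together reproduce the
+     original contents.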
*/ + slice = grpc_slice_malloc(length); + for (i = 0; i < length; i++) { + GRPC_SLICE_START_PTR(slice)[i] = static_cast(i); + } + + /* Ensure that for all subsets length is correct and that we start on the + correct byte. Additionally check that no copies were made. */ + for (i = 0; i < length; i++) { + tail = grpc_slice_ref_internal(slice); + head = grpc_slice_split_head(&tail, i); + check_head_tail(slice, head, tail); + grpc_slice_unref_internal(tail); + grpc_slice_unref_internal(head); + } + + grpc_slice_unref_internal(slice); +} + +static void test_slice_split_tail_works(size_t length) { + grpc_slice slice; + grpc_slice head, tail; + size_t i; + + LOG_TEST_NAME("test_slice_split_tail_works"); + gpr_log(GPR_INFO, "length=%" PRIuPTR, length); + + /* Create a slice in which each byte is equal to the distance from it to the + beginning of the slice. */ + slice = grpc_slice_malloc(length); + for (i = 0; i < length; i++) { + GRPC_SLICE_START_PTR(slice)[i] = static_cast(i); + } + + /* Ensure that for all subsets length is correct and that we start on the + correct byte. Additionally check that no copies were made. */ + for (i = 0; i < length; i++) { + head = grpc_slice_ref_internal(slice); + tail = grpc_slice_split_tail(&head, i); + check_head_tail(slice, head, tail); + grpc_slice_unref_internal(tail); + grpc_slice_unref_internal(head); + } + + grpc_slice_unref_internal(slice); +} + +static void test_slice_from_copied_string_works(void) { + static const char* text = "HELLO WORLD!"; + grpc_slice slice; + + LOG_TEST_NAME("test_slice_from_copied_string_works"); + + slice = grpc_slice_from_copied_string(text); + GPR_ASSERT(strlen(text) == GRPC_SLICE_LENGTH(slice)); + GPR_ASSERT( + 0 == memcmp(text, GRPC_SLICE_START_PTR(slice), GRPC_SLICE_LENGTH(slice))); + grpc_slice_unref_internal(slice); +} + +static void test_moved_string_slice(void) { + LOG_TEST_NAME("test_moved_string_slice"); + + // Small string should be inlined. + constexpr char kSmallStr[] = "hello12345"; + char* small_ptr = strdup(kSmallStr); + grpc_slice small = + grpc_slice_from_moved_string(grpc_core::UniquePtr(small_ptr)); + GPR_ASSERT(GRPC_SLICE_LENGTH(small) == strlen(kSmallStr)); + GPR_ASSERT(GRPC_SLICE_START_PTR(small) != + reinterpret_cast(small_ptr)); + grpc_slice_unref_internal(small); + + // Large string should be move the reference. + constexpr char kSLargeStr[] = "hello123456789123456789123456789"; + char* large_ptr = strdup(kSLargeStr); + grpc_slice large = + grpc_slice_from_moved_string(grpc_core::UniquePtr(large_ptr)); + GPR_ASSERT(GRPC_SLICE_LENGTH(large) == strlen(kSLargeStr)); + GPR_ASSERT(GRPC_SLICE_START_PTR(large) == + reinterpret_cast(large_ptr)); + grpc_slice_unref_internal(large); + + // Moved buffer must respect the provided length not the actual length of the + // string. 
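+  // Since only strlen(kSmallStr) bytes of the long buffer are kept, the
+  // resulting slice is short enough to be inlined and must not point into
+  // the moved buffer.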
+  large_ptr = strdup(kSLargeStr);
+  small = grpc_slice_from_moved_buffer(grpc_core::UniquePtr<char>(large_ptr),
+                                       strlen(kSmallStr));
+  GPR_ASSERT(GRPC_SLICE_LENGTH(small) == strlen(kSmallStr));
+  GPR_ASSERT(GRPC_SLICE_START_PTR(small) !=
+             reinterpret_cast<uint8_t*>(large_ptr));
+  grpc_slice_unref_internal(small);
+}
+
+void test_string_view_from_slice() {
+  constexpr char kStr[] = "foo";
+  absl::string_view sv(
+      grpc_core::StringViewFromSlice(grpc_slice_from_static_string(kStr)));
+  GPR_ASSERT(std::string(sv) == kStr);
+}
+
+int main(int, char**) {
+  unsigned length;
+  test_slice_malloc_returns_something_sensible();
+  test_slice_new_returns_something_sensible();
+  test_slice_new_with_user_data();
+  test_slice_new_with_len_returns_something_sensible();
+  for (length = 0; length < 128; length++) {
+    test_slice_sub_works(length);
+    test_slice_split_head_works(length);
+    test_slice_split_tail_works(length);
+  }
+  test_slice_from_copied_string_works();
+  test_moved_string_slice();
+  test_string_view_from_slice();
+  return 0;
+}
diff --git a/test/core/surface/byte_buffer_reader_test.cc b/test/core/surface/byte_buffer_reader_test.cc
new file mode 100644
index 00000000..89f50824
--- /dev/null
+++ b/test/core/surface/byte_buffer_reader_test.cc
@@ -0,0 +1,279 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST(x) gpr_log(GPR_INFO, "%s", x) + +static void test_read_one_slice(void) { + grpc_slice slice; + grpc_byte_buffer* buffer; + grpc_byte_buffer_reader reader; + grpc_slice first_slice, second_slice; + int first_code, second_code; + + LOG_TEST("test_read_one_slice"); + slice = grpc_slice_from_copied_string("test"); + buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref(slice); + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + first_code = grpc_byte_buffer_reader_next(&reader, &first_slice); + GPR_ASSERT(first_code != 0); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(first_slice), "test", 4) == 0); + grpc_slice_unref(first_slice); + second_code = grpc_byte_buffer_reader_next(&reader, &second_slice); + GPR_ASSERT(second_code == 0); + grpc_byte_buffer_destroy(buffer); +} + +static void test_read_one_slice_malloc(void) { + grpc_slice slice; + grpc_byte_buffer* buffer; + grpc_byte_buffer_reader reader; + grpc_slice first_slice, second_slice; + int first_code, second_code; + + LOG_TEST("test_read_one_slice_malloc"); + slice = grpc_slice_malloc(4); + memcpy(GRPC_SLICE_START_PTR(slice), "test", 4); + buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref(slice); + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + first_code = grpc_byte_buffer_reader_next(&reader, &first_slice); + GPR_ASSERT(first_code != 0); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(first_slice), "test", 4) == 0); + grpc_slice_unref(first_slice); + second_code = grpc_byte_buffer_reader_next(&reader, &second_slice); + GPR_ASSERT(second_code == 0); + grpc_byte_buffer_destroy(buffer); +} + +static void test_read_none_compressed_slice(void) { + grpc_slice slice; + grpc_byte_buffer* buffer; + grpc_byte_buffer_reader reader; + grpc_slice first_slice, second_slice; + int first_code, second_code; + + LOG_TEST("test_read_none_compressed_slice"); + slice = grpc_slice_from_copied_string("test"); + buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref(slice); + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + first_code = grpc_byte_buffer_reader_next(&reader, &first_slice); + GPR_ASSERT(first_code != 0); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(first_slice), "test", 4) == 0); + grpc_slice_unref(first_slice); + second_code = grpc_byte_buffer_reader_next(&reader, &second_slice); + GPR_ASSERT(second_code == 0); + grpc_byte_buffer_destroy(buffer); +} + +static void test_peek_one_slice(void) { + grpc_slice slice; + grpc_byte_buffer* buffer; + grpc_byte_buffer_reader reader; + grpc_slice* first_slice; + grpc_slice* second_slice; + int first_code, second_code; + + LOG_TEST("test_peek_one_slice"); + slice = grpc_slice_from_copied_string("test"); + buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref(slice); + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + first_code = grpc_byte_buffer_reader_peek(&reader, &first_slice); + GPR_ASSERT(first_code != 0); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*first_slice), "test", 4) == 0); + second_code = grpc_byte_buffer_reader_peek(&reader, &second_slice); + GPR_ASSERT(second_code == 0); + grpc_byte_buffer_destroy(buffer); +} + 
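+// Editorial note on the peek tests above and below: unlike
+// grpc_byte_buffer_reader_next(), which hands out a slice reference that each
+// test must grpc_slice_unref(), grpc_byte_buffer_reader_peek() only exposes a
+// pointer to the slice held inside the byte buffer itself, so these tests
+// never unref the peeked slices; grpc_byte_buffer_destroy() releases them
+// along with the buffer.
+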
+static void test_peek_one_slice_malloc(void) { + grpc_slice slice; + grpc_byte_buffer* buffer; + grpc_byte_buffer_reader reader; + grpc_slice* first_slice; + grpc_slice* second_slice; + int first_code, second_code; + + LOG_TEST("test_peek_one_slice_malloc"); + slice = grpc_slice_malloc(4); + memcpy(GRPC_SLICE_START_PTR(slice), "test", 4); + buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref(slice); + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + first_code = grpc_byte_buffer_reader_peek(&reader, &first_slice); + GPR_ASSERT(first_code != 0); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*first_slice), "test", 4) == 0); + second_code = grpc_byte_buffer_reader_peek(&reader, &second_slice); + GPR_ASSERT(second_code == 0); + grpc_byte_buffer_destroy(buffer); +} + +static void test_peek_none_compressed_slice(void) { + grpc_slice slice; + grpc_byte_buffer* buffer; + grpc_byte_buffer_reader reader; + grpc_slice* first_slice; + grpc_slice* second_slice; + int first_code, second_code; + + LOG_TEST("test_peek_none_compressed_slice"); + slice = grpc_slice_from_copied_string("test"); + buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref(slice); + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + first_code = grpc_byte_buffer_reader_peek(&reader, &first_slice); + GPR_ASSERT(first_code != 0); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*first_slice), "test", 4) == 0); + second_code = grpc_byte_buffer_reader_peek(&reader, &second_slice); + GPR_ASSERT(second_code == 0); + grpc_byte_buffer_destroy(buffer); +} + +static void test_byte_buffer_from_reader(void) { + grpc_slice slice; + grpc_byte_buffer *buffer, *buffer_from_reader; + grpc_byte_buffer_reader reader; + + LOG_TEST("test_byte_buffer_from_reader"); + slice = grpc_slice_malloc(4); + memcpy(GRPC_SLICE_START_PTR(slice), "test", 4); + buffer = grpc_raw_byte_buffer_create(&slice, 1); + grpc_slice_unref(slice); + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + + buffer_from_reader = grpc_raw_byte_buffer_from_reader(&reader); + GPR_ASSERT(buffer->type == buffer_from_reader->type); + GPR_ASSERT(buffer_from_reader->data.raw.compression == GRPC_COMPRESS_NONE); + GPR_ASSERT(buffer_from_reader->data.raw.slice_buffer.count == 1); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR( + buffer_from_reader->data.raw.slice_buffer.slices[0]), + "test", 4) == 0); + + grpc_byte_buffer_destroy(buffer); + grpc_byte_buffer_destroy(buffer_from_reader); +} + +static void test_readall(void) { + char* lotsa_as[512]; + char* lotsa_bs[1024]; + grpc_slice slices[2]; + grpc_byte_buffer* buffer; + grpc_byte_buffer_reader reader; + grpc_slice slice_out; + + LOG_TEST("test_readall"); + + memset(lotsa_as, 'a', 512 * sizeof(lotsa_as[0])); + memset(lotsa_bs, 'b', 1024 * sizeof(lotsa_bs[0])); + /* use slices large enough to overflow inlining */ + slices[0] = grpc_slice_malloc(512); + memcpy(GRPC_SLICE_START_PTR(slices[0]), lotsa_as, 512); + slices[1] = grpc_slice_malloc(1024); + memcpy(GRPC_SLICE_START_PTR(slices[1]), lotsa_bs, 1024); + + buffer = grpc_raw_byte_buffer_create(slices, 2); + grpc_slice_unref(slices[0]); + grpc_slice_unref(slices[1]); + + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + slice_out = grpc_byte_buffer_reader_readall(&reader); + + GPR_ASSERT(GRPC_SLICE_LENGTH(slice_out) == 512 + 1024); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(slice_out), 
lotsa_as, 512) == 0); + GPR_ASSERT(memcmp(&(GRPC_SLICE_START_PTR(slice_out)[512]), lotsa_bs, 1024) == + 0); + grpc_slice_unref(slice_out); + grpc_byte_buffer_destroy(buffer); +} + +static void test_byte_buffer_copy(void) { + char* lotsa_as[512]; + char* lotsa_bs[1024]; + grpc_slice slices[2]; + grpc_byte_buffer* buffer; + grpc_byte_buffer* copied_buffer; + grpc_byte_buffer_reader reader; + grpc_slice slice_out; + + LOG_TEST("test_byte_buffer_copy"); + + memset(lotsa_as, 'a', 512 * sizeof(lotsa_as[0])); + memset(lotsa_bs, 'b', 1024 * sizeof(lotsa_bs[0])); + /* use slices large enough to overflow inlining */ + slices[0] = grpc_slice_malloc(512); + memcpy(GRPC_SLICE_START_PTR(slices[0]), lotsa_as, 512); + slices[1] = grpc_slice_malloc(1024); + memcpy(GRPC_SLICE_START_PTR(slices[1]), lotsa_bs, 1024); + + buffer = grpc_raw_byte_buffer_create(slices, 2); + grpc_slice_unref(slices[0]); + grpc_slice_unref(slices[1]); + copied_buffer = grpc_byte_buffer_copy(buffer); + + GPR_ASSERT(grpc_byte_buffer_reader_init(&reader, buffer) && + "Couldn't init byte buffer reader"); + slice_out = grpc_byte_buffer_reader_readall(&reader); + + GPR_ASSERT(GRPC_SLICE_LENGTH(slice_out) == 512 + 1024); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(slice_out), lotsa_as, 512) == 0); + GPR_ASSERT(memcmp(&(GRPC_SLICE_START_PTR(slice_out)[512]), lotsa_bs, 1024) == + 0); + grpc_slice_unref(slice_out); + grpc_byte_buffer_destroy(buffer); + grpc_byte_buffer_destroy(copied_buffer); +} + +int main(int argc, char** argv) { + grpc_init(); + grpc::testing::TestEnvironment env(argc, argv); + test_read_one_slice(); + test_read_one_slice_malloc(); + test_read_none_compressed_slice(); + test_peek_one_slice(); + test_peek_one_slice_malloc(); + test_peek_none_compressed_slice(); + test_byte_buffer_from_reader(); + test_byte_buffer_copy(); + test_readall(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/surface/channel_create_test.cc b/test/core/surface/channel_create_test.cc new file mode 100644 index 00000000..5f109c0f --- /dev/null +++ b/test/core/surface/channel_create_test.cc @@ -0,0 +1,52 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/util/test_config.h" + +void test_unknown_scheme_target(void) { + grpc_channel* chan; + /* avoid default prefix */ + grpc_core::ResolverRegistry::Builder::ShutdownRegistry(); + grpc_core::ResolverRegistry::Builder::InitRegistry(); + + chan = grpc_insecure_channel_create("blah://blah", nullptr, nullptr); + GPR_ASSERT(chan != nullptr); + + grpc_core::ExecCtx exec_ctx; + grpc_channel_element* elem = + grpc_channel_stack_element(grpc_channel_get_channel_stack(chan), 0); + GPR_ASSERT(0 == strcmp(elem->filter->name, "lame-client")); + + grpc_channel_destroy(chan); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_unknown_scheme_target(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/surface/completion_queue_test.cc b/test/core/surface/completion_queue_test.cc new file mode 100644 index 00000000..0a8bee78 --- /dev/null +++ b/test/core/surface/completion_queue_test.cc @@ -0,0 +1,505 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/surface/completion_queue.h" + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST(x) gpr_log(GPR_INFO, "%s", x) + +static void* create_test_tag(void) { + static intptr_t i = 0; + return reinterpret_cast(++i); +} + +/* helper for tests to shutdown correctly and tersely */ +static void shutdown_and_destroy(grpc_completion_queue* cc) { + grpc_event ev; + grpc_completion_queue_shutdown(cc); + + switch (grpc_get_cq_completion_type(cc)) { + case GRPC_CQ_NEXT: { + ev = grpc_completion_queue_next(cc, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_SHUTDOWN); + break; + } + case GRPC_CQ_PLUCK: { + ev = grpc_completion_queue_pluck( + cc, create_test_tag(), gpr_inf_past(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_SHUTDOWN); + break; + } + case GRPC_CQ_CALLBACK: { + // Nothing to do here. The shutdown callback will be invoked when + // possible. 
+ break; + } + default: { + gpr_log(GPR_ERROR, "Unknown completion type"); + break; + } + } + + grpc_completion_queue_destroy(cc); +} + +/* ensure we can create and destroy a completion channel */ +static void test_no_op(void) { + grpc_cq_completion_type completion_types[] = {GRPC_CQ_NEXT, GRPC_CQ_PLUCK}; + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue_attributes attr; + LOG_TEST("test_no_op"); + + attr.version = 1; + for (size_t i = 0; i < GPR_ARRAY_SIZE(completion_types); i++) { + for (size_t j = 0; j < GPR_ARRAY_SIZE(polling_types); j++) { + attr.cq_completion_type = completion_types[i]; + attr.cq_polling_type = polling_types[j]; + shutdown_and_destroy(grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr)); + } + } +} + +static void test_pollset_conversion(void) { + grpc_cq_completion_type completion_types[] = {GRPC_CQ_NEXT, GRPC_CQ_PLUCK}; + grpc_cq_polling_type polling_types[] = {GRPC_CQ_DEFAULT_POLLING, + GRPC_CQ_NON_LISTENING}; + grpc_completion_queue* cq; + grpc_completion_queue_attributes attr; + + LOG_TEST("test_pollset_conversion"); + + attr.version = 1; + for (size_t i = 0; i < GPR_ARRAY_SIZE(completion_types); i++) { + for (size_t j = 0; j < GPR_ARRAY_SIZE(polling_types); j++) { + attr.cq_completion_type = completion_types[i]; + attr.cq_polling_type = polling_types[j]; + cq = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + GPR_ASSERT(grpc_cq_pollset(cq) != nullptr); + shutdown_and_destroy(cq); + } + } +} + +static void test_wait_empty(void) { + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue* cc; + grpc_completion_queue_attributes attr; + grpc_event event; + + LOG_TEST("test_wait_empty"); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_NEXT; + for (size_t i = 0; i < GPR_ARRAY_SIZE(polling_types); i++) { + attr.cq_polling_type = polling_types[i]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + event = + grpc_completion_queue_next(cc, gpr_now(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type == GRPC_QUEUE_TIMEOUT); + shutdown_and_destroy(cc); + } +} + +static void do_nothing_end_completion(void* /*arg*/, + grpc_cq_completion* /*c*/) {} + +static void test_cq_end_op(void) { + grpc_event ev; + grpc_completion_queue* cc; + grpc_cq_completion completion; + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue_attributes attr; + void* tag = create_test_tag(); + + LOG_TEST("test_cq_end_op"); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_NEXT; + for (size_t i = 0; i < GPR_ARRAY_SIZE(polling_types); i++) { + grpc_core::ExecCtx exec_ctx; + attr.cq_polling_type = polling_types[i]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + GPR_ASSERT(grpc_cq_begin_op(cc, tag)); + grpc_cq_end_op(cc, tag, GRPC_ERROR_NONE, do_nothing_end_completion, nullptr, + &completion); + + ev = grpc_completion_queue_next(cc, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == tag); + GPR_ASSERT(ev.success); + + shutdown_and_destroy(cc); + } +} + +static void test_cq_tls_cache_full(void) { + grpc_event ev; + grpc_completion_queue* cc; + grpc_cq_completion completion; + 
grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue_attributes attr; + void* tag = create_test_tag(); + void* res_tag; + int ok; + + LOG_TEST("test_cq_tls_cache_full"); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_NEXT; + for (size_t i = 0; i < GPR_ARRAY_SIZE(polling_types); i++) { + grpc_core::ExecCtx exec_ctx; // Reset exec_ctx + attr.cq_polling_type = polling_types[i]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + grpc_completion_queue_thread_local_cache_init(cc); + GPR_ASSERT(grpc_cq_begin_op(cc, tag)); + grpc_cq_end_op(cc, tag, GRPC_ERROR_NONE, do_nothing_end_completion, nullptr, + &completion); + + ev = grpc_completion_queue_next(cc, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_TIMEOUT); + + GPR_ASSERT( + grpc_completion_queue_thread_local_cache_flush(cc, &res_tag, &ok) == 1); + GPR_ASSERT(res_tag == tag); + GPR_ASSERT(ok); + + ev = grpc_completion_queue_next(cc, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_TIMEOUT); + + shutdown_and_destroy(cc); + } +} + +static void test_cq_tls_cache_empty(void) { + grpc_completion_queue* cc; + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue_attributes attr; + void* res_tag; + int ok; + + LOG_TEST("test_cq_tls_cache_empty"); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_NEXT; + for (size_t i = 0; i < GPR_ARRAY_SIZE(polling_types); i++) { + grpc_core::ExecCtx exec_ctx; // Reset exec_ctx + attr.cq_polling_type = polling_types[i]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + GPR_ASSERT( + grpc_completion_queue_thread_local_cache_flush(cc, &res_tag, &ok) == 0); + grpc_completion_queue_thread_local_cache_init(cc); + GPR_ASSERT( + grpc_completion_queue_thread_local_cache_flush(cc, &res_tag, &ok) == 0); + shutdown_and_destroy(cc); + } +} + +static void test_shutdown_then_next_polling(void) { + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue* cc; + grpc_completion_queue_attributes attr; + grpc_event event; + LOG_TEST("test_shutdown_then_next_polling"); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_NEXT; + for (size_t i = 0; i < GPR_ARRAY_SIZE(polling_types); i++) { + attr.cq_polling_type = polling_types[i]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + grpc_completion_queue_shutdown(cc); + event = grpc_completion_queue_next(cc, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(event.type == GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(cc); + } +} + +static void test_shutdown_then_next_with_timeout(void) { + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue* cc; + grpc_completion_queue_attributes attr; + grpc_event event; + LOG_TEST("test_shutdown_then_next_with_timeout"); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_NEXT; + for (size_t i = 0; i < GPR_ARRAY_SIZE(polling_types); i++) { + attr.cq_polling_type = polling_types[i]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + grpc_completion_queue_shutdown(cc); + event = 
grpc_completion_queue_next(cc, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(event.type == GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(cc); + } +} + +static void test_pluck(void) { + grpc_event ev; + grpc_completion_queue* cc; + void* tags[128]; + grpc_cq_completion completions[GPR_ARRAY_SIZE(tags)]; + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue_attributes attr; + unsigned i, j; + + LOG_TEST("test_pluck"); + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + tags[i] = create_test_tag(); + for (j = 0; j < i; j++) { + GPR_ASSERT(tags[i] != tags[j]); + } + } + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_PLUCK; + for (size_t pidx = 0; pidx < GPR_ARRAY_SIZE(polling_types); pidx++) { + grpc_core::ExecCtx exec_ctx; // reset exec_ctx + attr.cq_polling_type = polling_types[pidx]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + GPR_ASSERT(grpc_cq_begin_op(cc, tags[i])); + grpc_cq_end_op(cc, tags[i], GRPC_ERROR_NONE, do_nothing_end_completion, + nullptr, &completions[i]); + } + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + ev = grpc_completion_queue_pluck( + cc, tags[i], gpr_inf_past(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.tag == tags[i]); + } + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + GPR_ASSERT(grpc_cq_begin_op(cc, tags[i])); + grpc_cq_end_op(cc, tags[i], GRPC_ERROR_NONE, do_nothing_end_completion, + nullptr, &completions[i]); + } + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + ev = grpc_completion_queue_pluck(cc, tags[GPR_ARRAY_SIZE(tags) - i - 1], + gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.tag == tags[GPR_ARRAY_SIZE(tags) - i - 1]); + } + + shutdown_and_destroy(cc); + } +} + +static void test_pluck_after_shutdown(void) { + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_event ev; + grpc_completion_queue* cc; + grpc_completion_queue_attributes attr; + + LOG_TEST("test_pluck_after_shutdown"); + + attr.version = 1; + attr.cq_completion_type = GRPC_CQ_PLUCK; + for (size_t i = 0; i < GPR_ARRAY_SIZE(polling_types); i++) { + attr.cq_polling_type = polling_types[i]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + grpc_completion_queue_shutdown(cc); + ev = grpc_completion_queue_pluck( + cc, nullptr, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(cc); + } +} + +static void test_callback(void) { + grpc_completion_queue* cc; + static void* tags[128]; + grpc_cq_completion completions[GPR_ARRAY_SIZE(tags)]; + grpc_cq_polling_type polling_types[] = { + GRPC_CQ_DEFAULT_POLLING, GRPC_CQ_NON_LISTENING, GRPC_CQ_NON_POLLING}; + grpc_completion_queue_attributes attr; + unsigned i; + static gpr_mu mu, shutdown_mu; + static gpr_cv cv, shutdown_cv; + static int cb_counter; + gpr_mu_init(&mu); + gpr_mu_init(&shutdown_mu); + gpr_cv_init(&cv); + gpr_cv_init(&shutdown_cv); + + LOG_TEST("test_callback"); + + bool got_shutdown = false; + class ShutdownCallback : public grpc_completion_queue_functor { + public: + explicit ShutdownCallback(bool* done) : done_(done) { + functor_run = &ShutdownCallback::Run; + inlineable = false; + } + ~ShutdownCallback() {} + static void Run(grpc_completion_queue_functor* cb, int ok) { + gpr_mu_lock(&shutdown_mu); + 
*static_cast(cb)->done_ = static_cast(ok); + // Signal when the shutdown callback is completed. + gpr_cv_signal(&shutdown_cv); + gpr_mu_unlock(&shutdown_mu); + } + + private: + bool* done_; + }; + ShutdownCallback shutdown_cb(&got_shutdown); + + attr.version = 2; + attr.cq_completion_type = GRPC_CQ_CALLBACK; + attr.cq_shutdown_cb = &shutdown_cb; + + for (size_t pidx = 0; pidx < GPR_ARRAY_SIZE(polling_types); pidx++) { + int sumtags = 0; + int counter = 0; + cb_counter = 0; + { + // reset exec_ctx types + grpc_core::ExecCtx exec_ctx; + attr.cq_polling_type = polling_types[pidx]; + cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + + class TagCallback : public grpc_completion_queue_functor { + public: + TagCallback(int* counter, int tag) : counter_(counter), tag_(tag) { + functor_run = &TagCallback::Run; + // Inlineable should be false since this callback takes locks. + inlineable = false; + } + ~TagCallback() {} + static void Run(grpc_completion_queue_functor* cb, int ok) { + GPR_ASSERT(static_cast(ok)); + auto* callback = static_cast(cb); + gpr_mu_lock(&mu); + cb_counter++; + *callback->counter_ += callback->tag_; + if (cb_counter == GPR_ARRAY_SIZE(tags)) { + gpr_cv_signal(&cv); + } + gpr_mu_unlock(&mu); + delete callback; + }; + + private: + int* counter_; + int tag_; + }; + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + tags[i] = static_cast(new TagCallback(&counter, i)); + sumtags += i; + } + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + GPR_ASSERT(grpc_cq_begin_op(cc, tags[i])); + grpc_cq_end_op(cc, tags[i], GRPC_ERROR_NONE, do_nothing_end_completion, + nullptr, &completions[i]); + } + + gpr_mu_lock(&mu); + while (cb_counter != GPR_ARRAY_SIZE(tags)) { + // Wait for all the callbacks to complete. + gpr_cv_wait(&cv, &mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&mu); + + shutdown_and_destroy(cc); + + gpr_mu_lock(&shutdown_mu); + while (!got_shutdown) { + // Wait for the shutdown callback to complete. + gpr_cv_wait(&shutdown_cv, &shutdown_mu, + gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&shutdown_mu); + } + + // Run the assertions to check if the test ran successfully. + GPR_ASSERT(sumtags == counter); + GPR_ASSERT(got_shutdown); + got_shutdown = false; + } + + gpr_cv_destroy(&cv); + gpr_cv_destroy(&shutdown_cv); + gpr_mu_destroy(&mu); + gpr_mu_destroy(&shutdown_mu); +} + +struct thread_state { + grpc_completion_queue* cc; + void* tag; +}; + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_no_op(); + test_pollset_conversion(); + test_wait_empty(); + test_shutdown_then_next_polling(); + test_shutdown_then_next_with_timeout(); + test_cq_end_op(); + test_pluck(); + test_pluck_after_shutdown(); + test_cq_tls_cache_full(); + test_cq_tls_cache_empty(); + test_callback(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/surface/completion_queue_threading_test.cc b/test/core/surface/completion_queue_threading_test.cc new file mode 100644 index 00000000..c7379dcc --- /dev/null +++ b/test/core/surface/completion_queue_threading_test.cc @@ -0,0 +1,300 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/surface/completion_queue.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST(x) gpr_log(GPR_INFO, "%s", x) + +static void* create_test_tag(void) { + static intptr_t i = 0; + return reinterpret_cast(++i); +} + +/* helper for tests to shutdown correctly and tersely */ +static void shutdown_and_destroy(grpc_completion_queue* cc) { + grpc_event ev; + grpc_completion_queue_shutdown(cc); + + switch (grpc_get_cq_completion_type(cc)) { + case GRPC_CQ_NEXT: { + ev = grpc_completion_queue_next(cc, gpr_inf_past(GPR_CLOCK_REALTIME), + nullptr); + break; + } + case GRPC_CQ_PLUCK: { + ev = grpc_completion_queue_pluck( + cc, create_test_tag(), gpr_inf_past(GPR_CLOCK_REALTIME), nullptr); + break; + } + default: { + gpr_log(GPR_ERROR, "Unknown completion type"); + break; + } + } + + GPR_ASSERT(ev.type == GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(cc); +} + +static void do_nothing_end_completion(void* /*arg*/, + grpc_cq_completion* /*c*/) {} + +struct thread_state { + grpc_completion_queue* cc; + void* tag; +}; + +static void pluck_one(void* arg) { + struct thread_state* state = static_cast(arg); + grpc_completion_queue_pluck(state->cc, state->tag, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); +} + +static void test_too_many_plucks(void) { + grpc_event ev; + grpc_completion_queue* cc; + void* tags[GRPC_MAX_COMPLETION_QUEUE_PLUCKERS]; + grpc_cq_completion completions[GPR_ARRAY_SIZE(tags)]; + grpc_core::Thread threads[GPR_ARRAY_SIZE(tags)]; + struct thread_state thread_states[GPR_ARRAY_SIZE(tags)]; + grpc_core::ExecCtx exec_ctx; + unsigned i, j; + + LOG_TEST("test_too_many_plucks"); + + cc = grpc_completion_queue_create_for_pluck(nullptr); + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + tags[i] = create_test_tag(); + for (j = 0; j < i; j++) { + GPR_ASSERT(tags[i] != tags[j]); + } + thread_states[i].cc = cc; + thread_states[i].tag = tags[i]; + threads[i] = + grpc_core::Thread("grpc_pluck_test", pluck_one, thread_states + i); + threads[i].Start(); + } + + /* wait until all other threads are plucking */ + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1000)); + + ev = grpc_completion_queue_pluck(cc, create_test_tag(), + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type == GRPC_QUEUE_TIMEOUT); + + for (i = 0; i < GPR_ARRAY_SIZE(tags); i++) { + GPR_ASSERT(grpc_cq_begin_op(cc, tags[i])); + grpc_cq_end_op(cc, tags[i], GRPC_ERROR_NONE, do_nothing_end_completion, + nullptr, &completions[i]); + } + + for (auto& th : threads) { + th.Join(); + } + + shutdown_and_destroy(cc); +} + +#define TEST_THREAD_EVENTS 10000 + +typedef struct test_thread_options { + gpr_event on_started; + gpr_event* phase1; + gpr_event on_phase1_done; + gpr_event* phase2; + gpr_event on_finished; + size_t events_triggered; + int id; + grpc_completion_queue* cc; +} test_thread_options; + +gpr_timespec ten_seconds_time(void) { + return grpc_timeout_seconds_to_deadline(10); +} + +static void 
free_completion(void* /*arg*/, grpc_cq_completion* completion) { + gpr_free(completion); +} + +static void producer_thread(void* arg) { + test_thread_options* opt = static_cast(arg); + int i; + + gpr_log(GPR_INFO, "producer %d started", opt->id); + gpr_event_set(&opt->on_started, reinterpret_cast(1)); + GPR_ASSERT(gpr_event_wait(opt->phase1, ten_seconds_time())); + + gpr_log(GPR_INFO, "producer %d phase 1", opt->id); + for (i = 0; i < TEST_THREAD_EVENTS; i++) { + GPR_ASSERT(grpc_cq_begin_op(opt->cc, (void*)(intptr_t)1)); + } + + gpr_log(GPR_INFO, "producer %d phase 1 done", opt->id); + gpr_event_set(&opt->on_phase1_done, reinterpret_cast(1)); + GPR_ASSERT(gpr_event_wait(opt->phase2, ten_seconds_time())); + + gpr_log(GPR_INFO, "producer %d phase 2", opt->id); + for (i = 0; i < TEST_THREAD_EVENTS; i++) { + grpc_core::ExecCtx exec_ctx; + grpc_cq_end_op(opt->cc, reinterpret_cast(1), GRPC_ERROR_NONE, + free_completion, nullptr, + static_cast( + gpr_malloc(sizeof(grpc_cq_completion)))); + opt->events_triggered++; + } + + gpr_log(GPR_INFO, "producer %d phase 2 done", opt->id); + gpr_event_set(&opt->on_finished, reinterpret_cast(1)); +} + +static void consumer_thread(void* arg) { + test_thread_options* opt = static_cast(arg); + grpc_event ev; + + gpr_log(GPR_INFO, "consumer %d started", opt->id); + gpr_event_set(&opt->on_started, reinterpret_cast(1)); + GPR_ASSERT(gpr_event_wait(opt->phase1, ten_seconds_time())); + + gpr_log(GPR_INFO, "consumer %d phase 1", opt->id); + + gpr_log(GPR_INFO, "consumer %d phase 1 done", opt->id); + gpr_event_set(&opt->on_phase1_done, reinterpret_cast(1)); + GPR_ASSERT(gpr_event_wait(opt->phase2, ten_seconds_time())); + + gpr_log(GPR_INFO, "consumer %d phase 2", opt->id); + for (;;) { + ev = grpc_completion_queue_next( + opt->cc, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + switch (ev.type) { + case GRPC_OP_COMPLETE: + GPR_ASSERT(ev.success); + opt->events_triggered++; + break; + case GRPC_QUEUE_SHUTDOWN: + gpr_log(GPR_INFO, "consumer %d phase 2 done", opt->id); + gpr_event_set(&opt->on_finished, reinterpret_cast(1)); + return; + case GRPC_QUEUE_TIMEOUT: + gpr_log(GPR_ERROR, "Invalid timeout received"); + abort(); + } + } +} + +static void test_threading(size_t producers, size_t consumers) { + test_thread_options* options = static_cast( + gpr_malloc((producers + consumers) * sizeof(test_thread_options))); + gpr_event phase1 = GPR_EVENT_INIT; + gpr_event phase2 = GPR_EVENT_INIT; + grpc_completion_queue* cc = grpc_completion_queue_create_for_next(nullptr); + size_t i; + size_t total_consumed = 0; + static int optid = 101; + + gpr_log(GPR_INFO, "%s: %" PRIuPTR " producers, %" PRIuPTR " consumers", + "test_threading", producers, consumers); + + /* start all threads: they will wait for phase1 */ + grpc_core::Thread* threads = static_cast( + gpr_malloc(sizeof(*threads) * (producers + consumers))); + for (i = 0; i < producers + consumers; i++) { + gpr_event_init(&options[i].on_started); + gpr_event_init(&options[i].on_phase1_done); + gpr_event_init(&options[i].on_finished); + options[i].phase1 = &phase1; + options[i].phase2 = &phase2; + options[i].events_triggered = 0; + options[i].cc = cc; + options[i].id = optid++; + + bool ok; + threads[i] = grpc_core::Thread( + i < producers ? "grpc_producer" : "grpc_consumer", + i < producers ? 
producer_thread : consumer_thread, options + i, &ok); + GPR_ASSERT(ok); + threads[i].Start(); + gpr_event_wait(&options[i].on_started, ten_seconds_time()); + } + + /* start phase1: producers will pre-declare all operations they will + complete */ + gpr_log(GPR_INFO, "start phase 1"); + gpr_event_set(&phase1, reinterpret_cast(1)); + + gpr_log(GPR_INFO, "wait phase 1"); + for (i = 0; i < producers + consumers; i++) { + GPR_ASSERT(gpr_event_wait(&options[i].on_phase1_done, ten_seconds_time())); + } + gpr_log(GPR_INFO, "done phase 1"); + + /* start phase2: operations will complete, and consumers will consume them */ + gpr_log(GPR_INFO, "start phase 2"); + gpr_event_set(&phase2, reinterpret_cast(1)); + + /* in parallel, we shutdown the completion channel - all events should still + be consumed */ + grpc_completion_queue_shutdown(cc); + + /* join all threads */ + gpr_log(GPR_INFO, "wait phase 2"); + for (i = 0; i < producers + consumers; i++) { + GPR_ASSERT(gpr_event_wait(&options[i].on_finished, ten_seconds_time())); + } + gpr_log(GPR_INFO, "done phase 2"); + + /* destroy the completion channel */ + grpc_completion_queue_destroy(cc); + + for (i = 0; i < producers + consumers; i++) { + threads[i].Join(); + } + gpr_free(threads); + + /* verify that everything was produced and consumed */ + for (i = 0; i < producers + consumers; i++) { + if (i < producers) { + GPR_ASSERT(options[i].events_triggered == TEST_THREAD_EVENTS); + } else { + total_consumed += options[i].events_triggered; + } + } + GPR_ASSERT(total_consumed == producers * TEST_THREAD_EVENTS); + + gpr_free(options); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_too_many_plucks(); + test_threading(1, 1); + test_threading(1, 10); + test_threading(10, 1); + test_threading(10, 10); + grpc_shutdown(); + return 0; +} diff --git a/test/core/surface/concurrent_connectivity_test.cc b/test/core/surface/concurrent_connectivity_test.cc new file mode 100644 index 00000000..f63ee879 --- /dev/null +++ b/test/core/surface/concurrent_connectivity_test.cc @@ -0,0 +1,311 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +/* TODO(yashykt): When our macos testing infrastructure becomes good enough, we + * wouldn't need to reduce the number of threads on MacOS */ +#ifdef __APPLE__ +#define NUM_THREADS 10 +#else +#define NUM_THREADS 100 +#endif /* __APPLE */ + +#define NUM_OUTER_LOOPS 10 +#define NUM_INNER_LOOPS 10 +#define DELAY_MILLIS 10 +#define POLL_MILLIS 15000 + +#define NUM_OUTER_LOOPS_SHORT_TIMEOUTS 10 +#define NUM_INNER_LOOPS_SHORT_TIMEOUTS 100 +#define DELAY_MILLIS_SHORT_TIMEOUTS 1 +// in a successful test run, POLL_MILLIS should never be reached because all +// runs should end after the shorter delay_millis +#define POLL_MILLIS_SHORT_TIMEOUTS 30000 +// it should never take longer that this to shutdown the server +#define SERVER_SHUTDOWN_TIMEOUT 30000 + +static void* tag(int n) { return reinterpret_cast(n); } + +void create_loop_destroy(void* addr) { + for (int i = 0; i < NUM_OUTER_LOOPS; ++i) { + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_channel* chan = grpc_insecure_channel_create(static_cast(addr), + nullptr, nullptr); + + for (int j = 0; j < NUM_INNER_LOOPS; ++j) { + gpr_timespec later_time = + grpc_timeout_milliseconds_to_deadline(DELAY_MILLIS); + grpc_connectivity_state state = + grpc_channel_check_connectivity_state(chan, 1); + grpc_channel_watch_connectivity_state(chan, state, later_time, cq, + nullptr); + gpr_timespec poll_time = + grpc_timeout_milliseconds_to_deadline(POLL_MILLIS); + GPR_ASSERT(grpc_completion_queue_next(cq, poll_time, nullptr).type == + GRPC_OP_COMPLETE); + /* check that the watcher from "watch state" was free'd */ + GPR_ASSERT(grpc_channel_num_external_connectivity_watchers(chan) == 0); + } + grpc_channel_destroy(chan); + grpc_completion_queue_destroy(cq); + } +} + +// Always stack-allocate or new ServerThreadArgs; never use gpr_malloc since +// this contains C++ objects. 
+struct ServerThreadArgs { + std::string addr; + grpc_server* server = nullptr; + grpc_completion_queue* cq = nullptr; + std::vector pollset; + gpr_mu* mu = nullptr; + gpr_event ready; + std::atomic_bool stop{false}; +}; + +void server_thread(void* vargs) { + struct ServerThreadArgs* args = static_cast(vargs); + grpc_event ev; + gpr_timespec deadline = + grpc_timeout_milliseconds_to_deadline(SERVER_SHUTDOWN_TIMEOUT); + ev = grpc_completion_queue_next(args->cq, deadline, nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == tag(0xd1e)); +} + +static void on_connect(void* vargs, grpc_endpoint* tcp, + grpc_pollset* /*accepting_pollset*/, + grpc_tcp_server_acceptor* acceptor) { + gpr_free(acceptor); + struct ServerThreadArgs* args = static_cast(vargs); + grpc_endpoint_shutdown(tcp, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Connected")); + grpc_endpoint_destroy(tcp); + gpr_mu_lock(args->mu); + GRPC_LOG_IF_ERROR("pollset_kick", + grpc_pollset_kick(args->pollset[0], nullptr)); + gpr_mu_unlock(args->mu); +} + +void bad_server_thread(void* vargs) { + struct ServerThreadArgs* args = static_cast(vargs); + + grpc_core::ExecCtx exec_ctx; + grpc_resolved_address resolved_addr; + grpc_sockaddr* addr = reinterpret_cast(resolved_addr.addr); + int port; + grpc_tcp_server* s; + grpc_error_handle error = grpc_tcp_server_create( + nullptr, nullptr, + grpc_slice_allocator_factory_create(grpc_resource_quota_create(nullptr)), + &s); + GPR_ASSERT(error == GRPC_ERROR_NONE); + memset(&resolved_addr, 0, sizeof(resolved_addr)); + addr->sa_family = GRPC_AF_INET; + error = grpc_tcp_server_add_port(s, &resolved_addr, &port); + GPR_ASSERT(GRPC_LOG_IF_ERROR("grpc_tcp_server_add_port", error)); + GPR_ASSERT(port > 0); + args->addr = absl::StrCat("localhost:", port); + + grpc_tcp_server_start(s, &args->pollset, on_connect, args); + gpr_event_set(&args->ready, reinterpret_cast(1)); + + gpr_mu_lock(args->mu); + while (!args->stop.load(std::memory_order_acquire)) { + grpc_millis deadline = grpc_core::ExecCtx::Get()->Now() + 100; + + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(args->pollset[0], &worker, deadline))) { + args->stop.store(true, std::memory_order_release); + } + gpr_mu_unlock(args->mu); + + gpr_mu_lock(args->mu); + } + gpr_mu_unlock(args->mu); + + grpc_tcp_server_unref(s); +} + +static void done_pollset_shutdown(void* pollset, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(pollset)); + gpr_free(pollset); +} + +int run_concurrent_connectivity_test() { + struct ServerThreadArgs args; + + grpc_init(); + + /* First round, no server */ + { + gpr_log(GPR_DEBUG, "Wave 1"); + grpc_core::Thread threads[NUM_THREADS]; + args.addr = "localhost:54321"; + for (auto& th : threads) { + th = grpc_core::Thread("grpc_wave_1", create_loop_destroy, + const_cast(args.addr.c_str())); + th.Start(); + } + for (auto& th : threads) { + th.Join(); + } + } + + { + /* Second round, actual grpc server */ + gpr_log(GPR_DEBUG, "Wave 2"); + int port = grpc_pick_unused_port_or_die(); + args.addr = absl::StrCat("localhost:", port); + args.server = grpc_server_create(nullptr, nullptr); + grpc_server_add_insecure_http2_port(args.server, args.addr.c_str()); + args.cq = grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(args.server, args.cq, nullptr); + grpc_server_start(args.server); + grpc_core::Thread server2("grpc_wave_2_server", server_thread, &args); + server2.Start(); + + grpc_core::Thread threads[NUM_THREADS]; + 
for (auto& th : threads) { + th = grpc_core::Thread("grpc_wave_2", create_loop_destroy, + const_cast(args.addr.c_str())); + th.Start(); + } + for (auto& th : threads) { + th.Join(); + } + grpc_server_shutdown_and_notify(args.server, args.cq, tag(0xd1e)); + + server2.Join(); + grpc_server_destroy(args.server); + grpc_completion_queue_destroy(args.cq); + } + + { + /* Third round, bogus tcp server */ + gpr_log(GPR_DEBUG, "Wave 3"); + auto* pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &args.mu); + args.pollset.push_back(pollset); + gpr_event_init(&args.ready); + grpc_core::Thread server3("grpc_wave_3_server", bad_server_thread, &args); + server3.Start(); + gpr_event_wait(&args.ready, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + + grpc_core::Thread threads[NUM_THREADS]; + for (auto& th : threads) { + th = grpc_core::Thread("grpc_wave_3", create_loop_destroy, + const_cast(args.addr.c_str())); + th.Start(); + } + for (auto& th : threads) { + th.Join(); + } + + args.stop.store(true, std::memory_order_release); + server3.Join(); + { + grpc_core::ExecCtx exec_ctx; + grpc_pollset_shutdown( + args.pollset[0], + GRPC_CLOSURE_CREATE(done_pollset_shutdown, args.pollset[0], + grpc_schedule_on_exec_ctx)); + } + } + + grpc_shutdown(); + return 0; +} + +void watches_with_short_timeouts(void* addr) { + for (int i = 0; i < NUM_OUTER_LOOPS_SHORT_TIMEOUTS; ++i) { + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_channel* chan = grpc_insecure_channel_create(static_cast(addr), + nullptr, nullptr); + + for (int j = 0; j < NUM_INNER_LOOPS_SHORT_TIMEOUTS; ++j) { + gpr_timespec later_time = + grpc_timeout_milliseconds_to_deadline(DELAY_MILLIS_SHORT_TIMEOUTS); + grpc_connectivity_state state = + grpc_channel_check_connectivity_state(chan, 0); + GPR_ASSERT(state == GRPC_CHANNEL_IDLE); + grpc_channel_watch_connectivity_state(chan, state, later_time, cq, + nullptr); + gpr_timespec poll_time = + grpc_timeout_milliseconds_to_deadline(POLL_MILLIS_SHORT_TIMEOUTS); + grpc_event ev = grpc_completion_queue_next(cq, poll_time, nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.success == false); + /* check that the watcher from "watch state" was free'd */ + GPR_ASSERT(grpc_channel_num_external_connectivity_watchers(chan) == 0); + } + grpc_channel_destroy(chan); + grpc_completion_queue_destroy(cq); + } +} + +// This test tries to catch deadlock situations. +// With short timeouts on "watches" and long timeouts on cq next calls, +// so that a QUEUE_TIMEOUT likely means that something is stuck. +int run_concurrent_watches_with_short_timeouts_test() { + grpc_init(); + + grpc_core::Thread threads[NUM_THREADS]; + + for (auto& th : threads) { + th = grpc_core::Thread("grpc_short_watches", watches_with_short_timeouts, + const_cast("localhost:54321")); + th.Start(); + } + for (auto& th : threads) { + th.Join(); + } + + grpc_shutdown(); + return 0; +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + + run_concurrent_connectivity_test(); + run_concurrent_watches_with_short_timeouts_test(); +} diff --git a/test/core/surface/init_test.cc b/test/core/surface/init_test.cc new file mode 100644 index 00000000..f99ba1e4 --- /dev/null +++ b/test/core/surface/init_test.cc @@ -0,0 +1,145 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/surface/init.h" + +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/util/test_config.h" + +static int g_plugin_state; + +static void plugin_init(void) { g_plugin_state = 1; } +static void plugin_destroy(void) { g_plugin_state = 2; } +static bool plugin_is_intialized(void) { return g_plugin_state == 1; } +static bool plugin_is_destroyed(void) { return g_plugin_state == 2; } + +static void test(int rounds) { + int i; + for (i = 0; i < rounds; i++) { + grpc_init(); + } + for (i = 0; i < rounds; i++) { + grpc_shutdown(); + } + EXPECT_FALSE(grpc_is_initialized()); +} + +TEST(Init, test) { + test(1); + test(2); + test(3); +} + +static void test_blocking(int rounds) { + int i; + for (i = 0; i < rounds; i++) { + grpc_init(); + } + for (i = 0; i < rounds; i++) { + grpc_shutdown_blocking(); + } + EXPECT_FALSE(grpc_is_initialized()); +} + +TEST(Init, blocking) { + test_blocking(1); + test_blocking(2); + test_blocking(3); +} + +TEST(Init, shutdown_with_thread) { + grpc_init(); + { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx( + GRPC_APP_CALLBACK_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + grpc_shutdown(); + } + grpc_maybe_wait_for_async_shutdown(); + EXPECT_FALSE(grpc_is_initialized()); +} + +TEST(Init, mixed) { + grpc_init(); + grpc_init(); + grpc_shutdown(); + grpc_init(); + grpc_shutdown(); + grpc_shutdown(); + EXPECT_FALSE(grpc_is_initialized()); +} + +TEST(Init, mixed_with_thread) { + grpc_init(); + { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx( + GRPC_APP_CALLBACK_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + grpc_init(); + grpc_shutdown(); + grpc_init(); + grpc_shutdown(); + grpc_shutdown(); + } + grpc_maybe_wait_for_async_shutdown(); + EXPECT_FALSE(grpc_is_initialized()); +} + +TEST(Init, plugin) { + grpc_init(); + EXPECT_TRUE(plugin_is_intialized()); + grpc_shutdown_blocking(); + EXPECT_TRUE(plugin_is_destroyed()); + EXPECT_FALSE(grpc_is_initialized()); +} + +TEST(Init, repeatedly) { + for (int i = 0; i < 10; i++) { + grpc_init(); + { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx( + GRPC_APP_CALLBACK_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + grpc_shutdown(); + } + } + grpc_maybe_wait_for_async_shutdown(); + EXPECT_FALSE(grpc_is_initialized()); +} + +TEST(Init, repeatedly_blocking) { + for (int i = 0; i < 10; i++) { + grpc_init(); + { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx( + GRPC_APP_CALLBACK_EXEC_CTX_FLAG_IS_INTERNAL_THREAD); + grpc_shutdown_blocking(); + } + } + EXPECT_FALSE(grpc_is_initialized()); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_register_plugin(plugin_init, plugin_destroy); + return RUN_ALL_TESTS(); +} diff --git a/test/core/surface/lame_client_test.cc b/test/core/surface/lame_client_test.cc new file mode 100644 index 00000000..24bf65d4 --- /dev/null +++ b/test/core/surface/lame_client_test.cc @@ -0,0 +1,157 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/transport.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/test_config.h" + +class Watcher : public grpc_core::ConnectivityStateWatcherInterface { + public: + void Notify(grpc_connectivity_state new_state, + const absl::Status& /* status */) override { + GPR_ASSERT(new_state == GRPC_CHANNEL_SHUTDOWN); + } +}; + +static void* tag(intptr_t t) { return reinterpret_cast(t); } + +static grpc_closure transport_op_cb; + +static void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +void test_transport_op(grpc_channel* channel) { + grpc_core::ExecCtx exec_ctx; + grpc_transport_op* op = grpc_make_transport_op(nullptr); + op->start_connectivity_watch = grpc_core::MakeOrphanable(); + grpc_channel_element* elem = + grpc_channel_stack_element(grpc_channel_get_channel_stack(channel), 0); + elem->filter->start_transport_op(elem, op); + + GRPC_CLOSURE_INIT(&transport_op_cb, do_nothing, nullptr, + grpc_schedule_on_exec_ctx); + op = grpc_make_transport_op(&transport_op_cb); + elem->filter->start_transport_op(elem, op); +} + +int main(int argc, char** argv) { + grpc_channel* chan; + grpc_call* call; + grpc_completion_queue* cq; + cq_verifier* cqv; + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + char* peer; + + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + + const char* error_message = "Rpc sent on a lame channel."; + grpc_status_code error_code = GRPC_STATUS_ABORTED; + chan = grpc_lame_client_channel_create("lampoon:national", error_code, + error_message); + GPR_ASSERT(chan); + + test_transport_op(chan); + + GPR_ASSERT(GRPC_CHANNEL_TRANSIENT_FAILURE == + grpc_channel_check_connectivity_state(chan, 0)); + + cq = grpc_completion_queue_create_for_next(nullptr); + + grpc_slice host = grpc_slice_from_static_string("anywhere"); + call = + grpc_channel_create_call(chan, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/Foo"), &host, + grpc_timeout_seconds_to_deadline(100), nullptr); + GPR_ASSERT(call); + cqv = cq_verifier_create(cq); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(call, ops, static_cast(op - ops), + tag(1), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* the 
call should immediately fail */ + CQ_EXPECT_COMPLETION(cqv, tag(1), 0); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(call, ops, static_cast(op - ops), + tag(2), nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + /* the call should immediately fail */ + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv); + + peer = grpc_call_get_peer(call); + GPR_ASSERT(strcmp(peer, "lampoon:national") == 0); + gpr_free(peer); + + GPR_ASSERT(status == error_code); + GPR_ASSERT(grpc_slice_str_cmp(details, error_message) == 0); + + grpc_call_unref(call); + grpc_channel_destroy(chan); + cq_verifier_destroy(cqv); + grpc_completion_queue_destroy(cq); + + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_slice_unref(details); + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/surface/num_external_connectivity_watchers_test.cc b/test/core/surface/num_external_connectivity_watchers_test.cc new file mode 100644 index 00000000..2a348dae --- /dev/null +++ b/test/core/surface/num_external_connectivity_watchers_test.cc @@ -0,0 +1,206 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" + +typedef struct test_fixture { + const char* name; + grpc_channel* (*create_channel)(const char* addr); +} test_fixture; + +static size_t next_tag = 1; + +static void channel_idle_start_watch(grpc_channel* channel, + grpc_completion_queue* cq) { + gpr_timespec connect_deadline = grpc_timeout_milliseconds_to_deadline(1); + GPR_ASSERT(grpc_channel_check_connectivity_state(channel, 0) == + GRPC_CHANNEL_IDLE); + + grpc_channel_watch_connectivity_state(channel, GRPC_CHANNEL_IDLE, + connect_deadline, cq, + reinterpret_cast(next_tag++)); + gpr_log(GPR_DEBUG, "number of active connect watchers: %d", + grpc_channel_num_external_connectivity_watchers(channel)); +} + +static void channel_idle_poll_for_timeout(grpc_channel* channel, + grpc_completion_queue* cq) { + grpc_event ev = grpc_completion_queue_next( + cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + + /* expect watch_connectivity_state to end with a timeout */ + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.success == false); + GPR_ASSERT(grpc_channel_check_connectivity_state(channel, 0) == + GRPC_CHANNEL_IDLE); +} + +/* Test and use the "num_external_watchers" call to make sure + * that "connectivity watcher" structs are free'd just after, if + * their corresponding timeouts occur. */ +static void run_timeouts_test(const test_fixture* fixture) { + gpr_log(GPR_INFO, "TEST: %s", fixture->name); + + grpc_init(); + std::string addr = + grpc_core::JoinHostPort("localhost", grpc_pick_unused_port_or_die()); + + grpc_channel* channel = fixture->create_channel(addr.c_str()); + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + + /* start 1 watcher and then let it time out */ + channel_idle_start_watch(channel, cq); + channel_idle_poll_for_timeout(channel, cq); + GPR_ASSERT(grpc_channel_num_external_connectivity_watchers(channel) == 0); + + /* start 3 watchers and then let them all time out */ + for (size_t i = 1; i <= 3; i++) { + channel_idle_start_watch(channel, cq); + } + for (size_t i = 1; i <= 3; i++) { + channel_idle_poll_for_timeout(channel, cq); + } + GPR_ASSERT(grpc_channel_num_external_connectivity_watchers(channel) == 0); + + /* start 3 watchers, see one time out, start another 3, and then see them all + * time out */ + for (size_t i = 1; i <= 3; i++) { + channel_idle_start_watch(channel, cq); + } + channel_idle_poll_for_timeout(channel, cq); + for (size_t i = 3; i <= 5; i++) { + channel_idle_start_watch(channel, cq); + } + for (size_t i = 1; i <= 5; i++) { + channel_idle_poll_for_timeout(channel, cq); + } + GPR_ASSERT(grpc_channel_num_external_connectivity_watchers(channel) == 0); + + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + GPR_ASSERT(grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type == GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(cq); + + grpc_shutdown(); +} + +/* An edge scenario; sets channel state to explicitly, and outside + * of a polling call. 
*/ +static void run_channel_shutdown_before_timeout_test( + const test_fixture* fixture) { + gpr_log(GPR_INFO, "TEST: %s", fixture->name); + + grpc_init(); + std::string addr = + grpc_core::JoinHostPort("localhost", grpc_pick_unused_port_or_die()); + + grpc_channel* channel = fixture->create_channel(addr.c_str()); + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + + /* start 1 watcher and then shut down the channel before the timer goes off */ + GPR_ASSERT(grpc_channel_num_external_connectivity_watchers(channel) == 0); + + /* expecting a 30 second timeout to go off much later than the shutdown. */ + gpr_timespec connect_deadline = grpc_timeout_seconds_to_deadline(30); + GPR_ASSERT(grpc_channel_check_connectivity_state(channel, 0) == + GRPC_CHANNEL_IDLE); + + grpc_channel_watch_connectivity_state(channel, GRPC_CHANNEL_IDLE, + connect_deadline, cq, + reinterpret_cast(1)); + grpc_channel_destroy(channel); + + grpc_event ev = grpc_completion_queue_next( + cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + /* expect success with a state transition to CHANNEL_SHUTDOWN */ + GPR_ASSERT(ev.success == true); + + grpc_completion_queue_shutdown(cq); + GPR_ASSERT(grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type == GRPC_QUEUE_SHUTDOWN); + grpc_completion_queue_destroy(cq); + + grpc_shutdown(); +} + +static grpc_channel* insecure_test_create_channel(const char* addr) { + return grpc_insecure_channel_create(addr, nullptr, nullptr); +} + +static const test_fixture insecure_test = { + "insecure", + insecure_test_create_channel, +}; + +static grpc_channel* secure_test_create_channel(const char* addr) { + grpc_slice ca_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + const char* test_root_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(test_root_cert, nullptr, nullptr, nullptr); + grpc_slice_unref(ca_slice); + grpc_arg ssl_name_override = { + GRPC_ARG_STRING, + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + {const_cast("foo.test.google.fr")}}; + grpc_channel_args* new_client_args = + grpc_channel_args_copy_and_add(nullptr, &ssl_name_override, 1); + grpc_channel* channel = + grpc_secure_channel_create(ssl_creds, addr, new_client_args, nullptr); + { + grpc_core::ExecCtx exec_ctx; + grpc_channel_args_destroy(new_client_args); + } + grpc_channel_credentials_release(ssl_creds); + return channel; +} + +static const test_fixture secure_test = { + "secure", + secure_test_create_channel, +}; + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + + run_timeouts_test(&insecure_test); + run_timeouts_test(&secure_test); + + run_channel_shutdown_before_timeout_test(&insecure_test); + run_channel_shutdown_before_timeout_test(&secure_test); +} diff --git a/test/core/surface/secure_channel_create_test.cc b/test/core/surface/secure_channel_create_test.cc new file mode 100644 index 00000000..97a7337b --- /dev/null +++ b/test/core/surface/secure_channel_create_test.cc @@ -0,0 +1,78 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/util/test_config.h" + +void test_unknown_scheme_target(void) { + grpc_core::ResolverRegistry::Builder::ShutdownRegistry(); + grpc_core::ResolverRegistry::Builder::InitRegistry(); + grpc_channel_credentials* creds = + grpc_fake_transport_security_credentials_create(); + grpc_channel* chan = + grpc_secure_channel_create(creds, "blah://blah", nullptr, nullptr); + grpc_channel_element* elem = + grpc_channel_stack_element(grpc_channel_get_channel_stack(chan), 0); + GPR_ASSERT(0 == strcmp(elem->filter->name, "lame-client")); + grpc_core::ExecCtx exec_ctx; + GRPC_CHANNEL_INTERNAL_UNREF(chan, "test"); + creds->Unref(); +} + +void test_security_connector_already_in_arg(void) { + grpc_arg arg = grpc_security_connector_to_arg(nullptr); + grpc_channel_args args; + args.num_args = 1; + args.args = &arg; + grpc_channel* chan = + grpc_secure_channel_create(nullptr, nullptr, &args, nullptr); + grpc_channel_element* elem = + grpc_channel_stack_element(grpc_channel_get_channel_stack(chan), 0); + GPR_ASSERT(0 == strcmp(elem->filter->name, "lame-client")); + grpc_core::ExecCtx exec_ctx; + GRPC_CHANNEL_INTERNAL_UNREF(chan, "test"); +} + +void test_null_creds(void) { + grpc_channel* chan = + grpc_secure_channel_create(nullptr, nullptr, nullptr, nullptr); + grpc_channel_element* elem = + grpc_channel_stack_element(grpc_channel_get_channel_stack(chan), 0); + GPR_ASSERT(0 == strcmp(elem->filter->name, "lame-client")); + grpc_core::ExecCtx exec_ctx; + GRPC_CHANNEL_INTERNAL_UNREF(chan, "test"); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_security_connector_already_in_arg(); + test_null_creds(); + test_unknown_scheme_target(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/surface/sequential_connectivity_test.cc b/test/core/surface/sequential_connectivity_test.cc new file mode 100644 index 00000000..01d9a85c --- /dev/null +++ b/test/core/surface/sequential_connectivity_test.cc @@ -0,0 +1,203 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +typedef struct test_fixture { + const char* name; + void (*add_server_port)(grpc_server* server, const char* addr); + // Have the creds here so all the channels will share the same one to enabled + // subchannel sharing if needed. + grpc_channel_credentials* creds; +} test_fixture; + +#define NUM_CONNECTIONS 100 + +typedef struct { + grpc_server* server; + grpc_completion_queue* cq; +} server_thread_args; + +static void server_thread_func(void* args) { + server_thread_args* a = static_cast(args); + grpc_event ev = grpc_completion_queue_next( + a->cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == nullptr); + GPR_ASSERT(ev.success == true); +} + +static grpc_channel* create_test_channel(const char* addr, + grpc_channel_credentials* creds, + bool share_subchannel) { + grpc_channel* channel = nullptr; + std::vector args; + args.push_back(grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL), + !share_subchannel)); + if (creds != nullptr) { + args.push_back(grpc_channel_arg_string_create( + const_cast(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG), + const_cast("foo.test.google.fr"))); + } + grpc_channel_args channel_args = {args.size(), args.data()}; + if (creds != nullptr) { + channel = grpc_secure_channel_create(creds, addr, &channel_args, nullptr); + } else { + channel = grpc_insecure_channel_create(addr, &channel_args, nullptr); + } + return channel; +} + +static void run_test(const test_fixture* fixture, bool share_subchannel) { + gpr_log(GPR_INFO, "TEST: %s sharing subchannel: %d", fixture->name, + share_subchannel); + + std::string addr = + grpc_core::JoinHostPort("localhost", grpc_pick_unused_port_or_die()); + + grpc_server* server = grpc_server_create(nullptr, nullptr); + fixture->add_server_port(server, addr.c_str()); + grpc_completion_queue* server_cq = + grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(server, server_cq, nullptr); + grpc_server_start(server); + + server_thread_args sta = {server, server_cq}; + grpc_core::Thread server_thread("grpc_server", server_thread_func, &sta); + server_thread.Start(); + + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_channel* channels[NUM_CONNECTIONS]; + for (size_t i = 0; i < NUM_CONNECTIONS; i++) { + channels[i] = + create_test_channel(addr.c_str(), fixture->creds, share_subchannel); + + gpr_timespec connect_deadline = grpc_timeout_seconds_to_deadline(30); + grpc_connectivity_state state; + while ((state = grpc_channel_check_connectivity_state(channels[i], 1)) != + GRPC_CHANNEL_READY) { + grpc_channel_watch_connectivity_state(channels[i], state, + connect_deadline, cq, nullptr); + grpc_event ev = grpc_completion_queue_next( + cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + /* check that the watcher from "watch state" was free'd */ + GPR_ASSERT(grpc_channel_num_external_connectivity_watchers(channels[i]) == + 0); + GPR_ASSERT(ev.type == 
GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == nullptr); + GPR_ASSERT(ev.success == true); + } + } + + grpc_server_shutdown_and_notify(server, server_cq, nullptr); + server_thread.Join(); + + grpc_completion_queue_shutdown(server_cq); + grpc_completion_queue_shutdown(cq); + + while (grpc_completion_queue_next(server_cq, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + + for (size_t i = 0; i < NUM_CONNECTIONS; i++) { + grpc_channel_destroy(channels[i]); + } + + grpc_server_destroy(server); + grpc_completion_queue_destroy(server_cq); + grpc_completion_queue_destroy(cq); +} + +static void insecure_test_add_port(grpc_server* server, const char* addr) { + grpc_server_add_insecure_http2_port(server, addr); +} + +static void secure_test_add_port(grpc_server* server, const char* addr) { + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key, server_cert}; + grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create( + nullptr, &pem_key_cert_pair, 1, 0, nullptr); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + grpc_server_add_secure_http2_port(server, addr, ssl_creds); + grpc_server_credentials_release(ssl_creds); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + const test_fixture insecure_test = { + "insecure", + insecure_test_add_port, + nullptr, + }; + run_test(&insecure_test, /*share_subchannel=*/true); + run_test(&insecure_test, /*share_subchannel=*/false); + + grpc_slice ca_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + const char* test_root_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + grpc_channel_credentials* ssl_creds = + grpc_ssl_credentials_create(test_root_cert, nullptr, nullptr, nullptr); + grpc_slice_unref(ca_slice); + const test_fixture secure_test = { + "secure", + secure_test_add_port, + ssl_creds, + }; + run_test(&secure_test, /*share_subchannel=*/true); + run_test(&secure_test, /*share_subchannel=*/false); + grpc_channel_credentials_release(ssl_creds); + + grpc_shutdown(); +} diff --git a/test/core/surface/server_chttp2_test.cc b/test/core/surface/server_chttp2_test.cc new file mode 100644 index 00000000..e85c1d58 --- /dev/null +++ b/test/core/surface/server_chttp2_test.cc @@ -0,0 +1,73 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/tsi/fake_transport_security.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +TEST(ServerChttp2, UnparseableTarget) { + grpc_channel_args args = {0, nullptr}; + grpc_server* server = grpc_server_create(&args, nullptr); + int port = grpc_server_add_insecure_http2_port(server, "["); + EXPECT_EQ(port, 0); + grpc_server_destroy(server); +} + +TEST(ServerChttp2, AddSamePortTwice) { + grpc_arg a = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ALLOW_REUSEPORT), 0); + grpc_channel_args args = {1, &a}; + + int port = grpc_pick_unused_port_or_die(); + grpc_completion_queue* cq = grpc_completion_queue_create_for_pluck(nullptr); + grpc_server* server = grpc_server_create(&args, nullptr); + grpc_server_credentials* fake_creds = + grpc_fake_transport_security_server_credentials_create(); + std::string addr = grpc_core::JoinHostPort("localhost", port); + EXPECT_EQ(grpc_server_add_secure_http2_port(server, addr.c_str(), fake_creds), + port); + EXPECT_EQ(grpc_server_add_secure_http2_port(server, addr.c_str(), fake_creds), + 0); + + grpc_server_credentials_release(fake_creds); + grpc_server_shutdown_and_notify(server, cq, nullptr); + grpc_completion_queue_pluck(cq, nullptr, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/surface/server_test.cc b/test/core/surface/server_test.cc new file mode 100644 index 00000000..630d9f82 --- /dev/null +++ b/test/core/surface/server_test.cc @@ -0,0 +1,170 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +void test_register_method_fail(void) { + grpc_server* server = grpc_server_create(nullptr, nullptr); + void* method; + void* method_old; + method = grpc_server_register_method(server, nullptr, nullptr, + GRPC_SRM_PAYLOAD_NONE, 0); + GPR_ASSERT(method == nullptr); + method_old = + grpc_server_register_method(server, "m", "h", GRPC_SRM_PAYLOAD_NONE, 0); + GPR_ASSERT(method_old != nullptr); + method = grpc_server_register_method( + server, "m", "h", GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER, 0); + GPR_ASSERT(method == nullptr); + method_old = + grpc_server_register_method(server, "m2", "h2", GRPC_SRM_PAYLOAD_NONE, + GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST); + GPR_ASSERT(method_old != nullptr); + method = + grpc_server_register_method(server, "m2", "h2", GRPC_SRM_PAYLOAD_NONE, 0); + GPR_ASSERT(method == nullptr); + method = grpc_server_register_method( + server, "m2", "h2", GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER, + GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST); + GPR_ASSERT(method == nullptr); + grpc_server_destroy(server); +} + +void test_request_call_on_no_server_cq(void) { + grpc_completion_queue* cc = grpc_completion_queue_create_for_next(nullptr); + grpc_server* server = grpc_server_create(nullptr, nullptr); + GPR_ASSERT(GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE == + grpc_server_request_call(server, nullptr, nullptr, nullptr, cc, cc, + nullptr)); + GPR_ASSERT(GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE == + grpc_server_request_registered_call(server, nullptr, nullptr, + nullptr, nullptr, nullptr, cc, + cc, nullptr)); + grpc_completion_queue_destroy(cc); + grpc_server_destroy(server); +} + +void test_bind_server_twice(void) { + grpc_arg a = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_ALLOW_REUSEPORT), 0); + grpc_channel_args args = {1, &a}; + + grpc_server* server1 = grpc_server_create(&args, nullptr); + grpc_server* server2 = grpc_server_create(&args, nullptr); + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + int port = grpc_pick_unused_port_or_die(); + std::string addr = absl::StrCat("[::]:", port); + grpc_server_register_completion_queue(server1, cq, nullptr); + grpc_server_register_completion_queue(server2, cq, nullptr); + GPR_ASSERT(0 == + grpc_server_add_secure_http2_port(server2, addr.c_str(), nullptr)); + GPR_ASSERT(port == + grpc_server_add_insecure_http2_port(server1, addr.c_str())); + GPR_ASSERT(0 == grpc_server_add_insecure_http2_port(server2, addr.c_str())); + grpc_server_credentials* fake_creds = + grpc_fake_transport_security_server_credentials_create(); + GPR_ASSERT(0 == grpc_server_add_secure_http2_port(server2, addr.c_str(), + fake_creds)); + grpc_server_credentials_release(fake_creds); + grpc_server_shutdown_and_notify(server1, cq, nullptr); + grpc_server_shutdown_and_notify(server2, cq, nullptr); + grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + grpc_server_destroy(server1); + grpc_server_destroy(server2); + grpc_completion_queue_destroy(cq); +} + +void test_bind_server_to_addr(const char* host, bool secure) { + int port = grpc_pick_unused_port_or_die(); + std::string addr = 
grpc_core::JoinHostPort(host, port); + gpr_log(GPR_INFO, "Test bind to %s", addr.c_str()); + + grpc_server* server = grpc_server_create(nullptr, nullptr); + if (secure) { + grpc_server_credentials* fake_creds = + grpc_fake_transport_security_server_credentials_create(); + GPR_ASSERT( + grpc_server_add_secure_http2_port(server, addr.c_str(), fake_creds)); + grpc_server_credentials_release(fake_creds); + } else { + GPR_ASSERT(grpc_server_add_insecure_http2_port(server, addr.c_str())); + } + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(server, cq, nullptr); + grpc_server_start(server); + grpc_server_shutdown_and_notify(server, cq, nullptr); + grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq); +} + +static int external_dns_works(const char* host) { + grpc_resolved_addresses* res = nullptr; + grpc_error_handle error = grpc_blocking_resolve_address(host, "80", &res); + GRPC_ERROR_UNREF(error); + if (res != nullptr) { + grpc_resolved_addresses_destroy(res); + return 1; + } + return 0; +} + +static void test_bind_server_to_addrs(const char** addrs, size_t n) { + for (size_t i = 0; i < n; i++) { + test_bind_server_to_addr(addrs[i], false); + test_bind_server_to_addr(addrs[i], true); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + test_register_method_fail(); + test_request_call_on_no_server_cq(); + test_bind_server_twice(); + + static const char* addrs[] = { + "::1", "127.0.0.1", "::ffff:127.0.0.1", "localhost", "0.0.0.0", "::", + }; + test_bind_server_to_addrs(addrs, GPR_ARRAY_SIZE(addrs)); + + if (external_dns_works("loopback46.unittest.grpc.io")) { + static const char* dns_addrs[] = { + "loopback46.unittest.grpc.io", + "loopback4.unittest.grpc.io", + }; + test_bind_server_to_addrs(dns_addrs, GPR_ARRAY_SIZE(dns_addrs)); + } + + grpc_shutdown(); + return 0; +} diff --git a/test/core/transport/bdp_estimator_test.cc b/test/core/transport/bdp_estimator_test.cc new file mode 100644 index 00000000..6d5cc9a1 --- /dev/null +++ b/test/core/transport/bdp_estimator_test.cc @@ -0,0 +1,151 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/transport/bdp_estimator.h" + +#include + +#include + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "test/core/util/test_config.h" + +extern gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type); + +namespace grpc_core { +namespace testing { +namespace { +int g_clock = 0; + +gpr_timespec fake_gpr_now(gpr_clock_type clock_type) { + gpr_timespec ts; + ts.tv_sec = g_clock; + ts.tv_nsec = 0; + ts.clock_type = clock_type; + return ts; +} + +void inc_time(void) { g_clock += 30; } +} // namespace + +TEST(BdpEstimatorTest, NoOp) { BdpEstimator est("test"); } + +TEST(BdpEstimatorTest, EstimateBdpNoSamples) { + BdpEstimator est("test"); + est.EstimateBdp(); +} + +namespace { +void AddSamples(BdpEstimator* estimator, int64_t* samples, size_t n) { + estimator->AddIncomingBytes(1234567); + inc_time(); + grpc_core::ExecCtx exec_ctx; + estimator->SchedulePing(); + estimator->StartPing(); + for (size_t i = 0; i < n; i++) { + estimator->AddIncomingBytes(samples[i]); + } + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(1, GPR_TIMESPAN))); + grpc_core::ExecCtx::Get()->InvalidateNow(); + estimator->CompletePing(); +} + +void AddSample(BdpEstimator* estimator, int64_t sample) { + AddSamples(estimator, &sample, 1); +} +} // namespace + +TEST(BdpEstimatorTest, GetEstimate1Sample) { + BdpEstimator est("test"); + AddSample(&est, 100); + est.EstimateBdp(); +} + +TEST(BdpEstimatorTest, GetEstimate2Samples) { + BdpEstimator est("test"); + AddSample(&est, 100); + AddSample(&est, 100); + est.EstimateBdp(); +} + +TEST(BdpEstimatorTest, GetEstimate3Samples) { + BdpEstimator est("test"); + AddSample(&est, 100); + AddSample(&est, 100); + AddSample(&est, 100); + est.EstimateBdp(); +} + +namespace { +int64_t NextPow2(int64_t v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} +} // namespace + +class BdpEstimatorRandomTest : public ::testing::TestWithParam {}; + +TEST_P(BdpEstimatorRandomTest, GetEstimateRandomValues) { + BdpEstimator est("test"); + const int kMaxSample = 65535; + int min = kMaxSample; + int max = 0; + for (size_t i = 0; i < GetParam(); i++) { + int sample = rand() % (kMaxSample + 1); + if (sample < min) min = sample; + if (sample > max) max = sample; + AddSample(&est, sample); + if (i >= 3) { + EXPECT_LE(est.EstimateBdp(), std::max(int64_t(65536), 2 * NextPow2(max))) + << " min:" << min << " max:" << max << " sample:" << sample; + } + } +} + +INSTANTIATE_TEST_SUITE_P(TooManyNames, BdpEstimatorRandomTest, + ::testing::Values(3, 4, 6, 9, 13, 19, 28, 42, 63, 94, + 141, 211, 316, 474, 711)); + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + gpr_now_impl = grpc_core::testing::fake_gpr_now; + grpc_init(); + grpc_timer_manager_set_threading(false); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/transport/binder/binder_transport_test.cc b/test/core/transport/binder/binder_transport_test.cc new file mode 100644 index 00000000..74c6b090 --- /dev/null +++ b/test/core/transport/binder/binder_transport_test.cc @@ -0,0 +1,711 @@ +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Unit-tests for grpc_binder_transport +// +// Verify that a calls to the perform_stream_op of grpc_binder_transport +// transform into the correct sequence of binder transactions. +#include "src/core/ext/transport/binder/transport/binder_transport.h" + +#include +#include +#include + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/synchronization/notification.h" + +#include + +#include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h" +#include "src/core/ext/transport/binder/transport/binder_stream.h" +#include "test/core/transport/binder/mock_objects.h" +#include "test/core/util/test_config.h" + +namespace grpc_binder { +namespace { + +using ::testing::Expectation; +using ::testing::NiceMock; +using ::testing::Return; + +class BinderTransportTest : public ::testing::Test { + public: + BinderTransportTest() + : arena_(grpc_core::Arena::Create(/* initial_size = */ 1)), + transport_(grpc_create_binder_transport_client( + absl::make_unique>(), + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>())) { + auto* gbt = reinterpret_cast(transport_); + gbt->wire_writer = absl::make_unique(); + GRPC_STREAM_REF_INIT(&ref_, 1, nullptr, nullptr, "phony ref"); + } + + ~BinderTransportTest() override { + grpc_core::ExecCtx exec_ctx; + grpc_transport_destroy(transport_); + grpc_core::ExecCtx::Get()->Flush(); + for (grpc_binder_stream* gbs : stream_buffer_) { + gbs->~grpc_binder_stream(); + gpr_free(gbs); + } + arena_->Destroy(); + } + + void PerformStreamOp(grpc_binder_stream* gbs, + grpc_transport_stream_op_batch* op) { + grpc_transport_perform_stream_op(transport_, + reinterpret_cast(gbs), op); + } + + grpc_binder_transport* GetBinderTransport() { + return reinterpret_cast(transport_); + } + + grpc_binder_stream* InitNewBinderStream() { + grpc_binder_stream* gbs = static_cast( + gpr_malloc(grpc_transport_stream_size(transport_))); + grpc_transport_init_stream(transport_, reinterpret_cast(gbs), + &ref_, nullptr, arena_); + stream_buffer_.push_back(gbs); + return gbs; + } + + MockWireWriter& GetWireWriter() { + return *reinterpret_cast( + GetBinderTransport()->wire_writer.get()); + } + + static void SetUpTestSuite() { grpc_init(); } + static void TearDownTestSuite() { grpc_shutdown(); } + + protected: + grpc_core::Arena* arena_; + grpc_transport* transport_; + grpc_stream_refcount ref_; + std::vector stream_buffer_; +}; + +void MockCallback(void* arg, grpc_error_handle error); + +class MockGrpcClosure { + public: + explicit MockGrpcClosure(absl::Notification* notification = nullptr) + : notification_(notification) { + GRPC_CLOSURE_INIT(&closure_, MockCallback, this, nullptr); + } + + grpc_closure* GetGrpcClosure() { return &closure_; } + MOCK_METHOD(void, Callback, (grpc_error_handle), ()); + + absl::Notification* notification_; + + private: + grpc_closure closure_; +}; + +void MockCallback(void* arg, grpc_error_handle error) { + MockGrpcClosure* 
mock_closure = static_cast(arg); + mock_closure->Callback(error); + if (mock_closure->notification_) { + mock_closure->notification_->Notify(); + } +} + +// Matches with transactions having the desired flag, method_ref, +// initial_metadata, and message_data. +MATCHER_P4(TransactionMatches, flag, method_ref, initial_metadata, message_data, + "") { + if (arg.GetFlags() != flag) return false; + if (flag & kFlagPrefix) { + if (arg.GetMethodRef() != method_ref) return false; + if (arg.GetPrefixMetadata() != initial_metadata) return false; + } + if (flag & kFlagMessageData) { + if (arg.GetMessageData() != message_data) return false; + } + return true; +} + +// Matches with grpc_error having error message containing |msg|. +MATCHER_P(GrpcErrorMessageContains, msg, "") { + return absl::StrContains(grpc_error_std_string(arg), msg); +} + +// Verify that the lower-level metadata has the same content as the gRPC +// metadata. +void VerifyMetadataEqual(const Metadata& md, + const grpc_metadata_batch& grpc_md) { + size_t i = 0; + grpc_md.ForEach([&](grpc_mdelem mdelm) { + EXPECT_EQ(grpc_core::StringViewFromSlice(GRPC_MDKEY(mdelm)), md[i].first); + EXPECT_EQ(grpc_core::StringViewFromSlice(GRPC_MDVALUE(mdelm)), + md[i].second); + i++; + }); + EXPECT_EQ(md.size(), i); +} + +// RAII helper classes for constructing gRPC metadata and receiving callbacks. +struct MakeSendInitialMetadata { + MakeSendInitialMetadata(const Metadata& initial_metadata, + const std::string& method_ref, + grpc_transport_stream_op_batch* op) + : storage(initial_metadata.size()) { + size_t i = 0; + for (const auto& md : initial_metadata) { + const std::string& key = md.first; + const std::string& value = md.second; + EXPECT_EQ(grpc_metadata_batch_add_tail( + &grpc_initial_metadata, &storage[i], + grpc_mdelem_from_slices(grpc_slice_from_cpp_string(key), + grpc_slice_from_cpp_string(value))), + GRPC_ERROR_NONE); + i++; + } + if (!method_ref.empty()) { + EXPECT_EQ( + grpc_metadata_batch_add_tail( + &grpc_initial_metadata, &method_ref_storage, + grpc_mdelem_from_slices(GRPC_MDSTR_PATH, + grpc_slice_from_cpp_string(method_ref))), + GRPC_ERROR_NONE); + } + op->send_initial_metadata = true; + op->payload->send_initial_metadata.send_initial_metadata = + &grpc_initial_metadata; + } + ~MakeSendInitialMetadata() {} + + std::vector storage; + grpc_linked_mdelem method_ref_storage; + grpc_core::ScopedArenaPtr arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch grpc_initial_metadata{arena.get()}; +}; + +struct MakeSendMessage { + MakeSendMessage(const std::string& message, + grpc_transport_stream_op_batch* op) { + grpc_slice_buffer send_buffer; + grpc_slice_buffer_init(&send_buffer); + grpc_slice send_slice = grpc_slice_from_cpp_string(message); + grpc_slice_buffer_add(&send_buffer, send_slice); + + send_stream.Init(&send_buffer, 0); + grpc_slice_buffer_destroy(&send_buffer); + + op->send_message = true; + op->payload->send_message.send_message.reset(send_stream.get()); + } + + grpc_core::ManualConstructor send_stream; +}; + +struct MakeSendTrailingMetadata { + explicit MakeSendTrailingMetadata(const Metadata& trailing_metadata, + grpc_transport_stream_op_batch* op) { + EXPECT_TRUE(trailing_metadata.empty()); + + op->send_trailing_metadata = true; + op->payload->send_trailing_metadata.send_trailing_metadata = + &grpc_trailing_metadata; + } + + grpc_core::ScopedArenaPtr arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch grpc_trailing_metadata{arena.get()}; +}; + +struct MakeRecvInitialMetadata { + explicit 
MakeRecvInitialMetadata(grpc_transport_stream_op_batch* op, + Expectation* call_before = nullptr) + : ready(¬ification) { + op->recv_initial_metadata = true; + op->payload->recv_initial_metadata.recv_initial_metadata = + &grpc_initial_metadata; + op->payload->recv_initial_metadata.recv_initial_metadata_ready = + ready.GetGrpcClosure(); + if (call_before) { + EXPECT_CALL(ready, Callback).After(*call_before); + } else { + EXPECT_CALL(ready, Callback); + } + } + + ~MakeRecvInitialMetadata() {} + + MockGrpcClosure ready; + grpc_core::ScopedArenaPtr arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch grpc_initial_metadata{arena.get()}; + absl::Notification notification; +}; + +struct MakeRecvMessage { + explicit MakeRecvMessage(grpc_transport_stream_op_batch* op, + Expectation* call_before = nullptr) + : ready(¬ification) { + op->recv_message = true; + op->payload->recv_message.recv_message = &grpc_message; + op->payload->recv_message.recv_message_ready = ready.GetGrpcClosure(); + if (call_before) { + EXPECT_CALL(ready, Callback).After(*call_before); + } else { + EXPECT_CALL(ready, Callback); + } + } + + MockGrpcClosure ready; + absl::Notification notification; + grpc_core::OrphanablePtr grpc_message; +}; + +struct MakeRecvTrailingMetadata { + explicit MakeRecvTrailingMetadata(grpc_transport_stream_op_batch* op, + Expectation* call_before = nullptr) + : ready(¬ification) { + op->recv_trailing_metadata = true; + op->payload->recv_trailing_metadata.recv_trailing_metadata = + &grpc_trailing_metadata; + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = + ready.GetGrpcClosure(); + if (call_before) { + EXPECT_CALL(ready, Callback).After(*call_before); + } else { + EXPECT_CALL(ready, Callback); + } + } + + ~MakeRecvTrailingMetadata() {} + + MockGrpcClosure ready; + grpc_core::ScopedArenaPtr arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch grpc_trailing_metadata{arena.get()}; + absl::Notification notification; +}; + +const Metadata kDefaultMetadata = { + {"", ""}, + {"", "value"}, + {"key", ""}, + {"key", "value"}, +}; + +constexpr char kDefaultMethodRef[] = "/some/path"; +constexpr char kDefaultMessage[] = "binder transport message"; +constexpr int kDefaultStatus = 0x1234; + +Metadata AppendMethodRef(const Metadata& md, const std::string& method_ref) { + Metadata result = md; + result.emplace_back(":path", method_ref); + return result; +} + +Metadata AppendStatus(const Metadata& md, int status) { + Metadata result = md; + result.emplace_back("grpc-status", std::to_string(status)); + return result; +} + +} // namespace + +TEST_F(BinderTransportTest, CreateBinderTransport) { + EXPECT_NE(transport_, nullptr); +} + +TEST_F(BinderTransportTest, TransactionIdIncrement) { + grpc_binder_stream* gbs0 = InitNewBinderStream(); + EXPECT_EQ(gbs0->t, GetBinderTransport()); + EXPECT_EQ(gbs0->tx_code, kFirstCallId); + grpc_binder_stream* gbs1 = InitNewBinderStream(); + EXPECT_EQ(gbs1->t, GetBinderTransport()); + EXPECT_EQ(gbs1->tx_code, kFirstCallId + 1); + grpc_binder_stream* gbs2 = InitNewBinderStream(); + EXPECT_EQ(gbs2->t, GetBinderTransport()); + EXPECT_EQ(gbs2->tx_code, kFirstCallId + 2); +} + +TEST_F(BinderTransportTest, PerformSendInitialMetadata) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + const Metadata kInitialMetadata = kDefaultMetadata; + MakeSendInitialMetadata 
send_initial_metadata(kInitialMetadata, "", &op); + MockGrpcClosure mock_on_complete; + op.on_complete = mock_on_complete.GetGrpcClosure(); + + ::testing::InSequence sequence; + EXPECT_CALL(GetWireWriter(), RpcCall(TransactionMatches( + kFlagPrefix, "", kInitialMetadata, ""))); + EXPECT_CALL(mock_on_complete, Callback); + + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); +} + +TEST_F(BinderTransportTest, PerformSendInitialMetadataMethodRef) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + const Metadata kInitialMetadata = kDefaultMetadata; + const std::string kMethodRef = kDefaultMethodRef; + MakeSendInitialMetadata send_initial_metadata(kInitialMetadata, kMethodRef, + &op); + MockGrpcClosure mock_on_complete; + op.on_complete = mock_on_complete.GetGrpcClosure(); + + ::testing::InSequence sequence; + EXPECT_CALL(GetWireWriter(), + RpcCall(TransactionMatches(kFlagPrefix, kMethodRef.substr(1), + kInitialMetadata, ""))); + EXPECT_CALL(mock_on_complete, Callback); + + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); +} + +TEST_F(BinderTransportTest, PerformSendMessage) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + const std::string kMessage = kDefaultMessage; + MakeSendMessage send_message(kMessage, &op); + MockGrpcClosure mock_on_complete; + op.on_complete = mock_on_complete.GetGrpcClosure(); + + ::testing::InSequence sequence; + EXPECT_CALL( + GetWireWriter(), + RpcCall(TransactionMatches(kFlagMessageData, "", Metadata{}, kMessage))); + EXPECT_CALL(mock_on_complete, Callback); + + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); +} + +TEST_F(BinderTransportTest, PerformSendTrailingMetadata) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + // The wireformat guarantees that suffix metadata will always be empty. + // TODO(waynetu): Check whether gRPC can internally add extra trailing + // metadata. + const Metadata kTrailingMetadata = {}; + MakeSendTrailingMetadata send_trailing_metadata(kTrailingMetadata, &op); + MockGrpcClosure mock_on_complete; + op.on_complete = mock_on_complete.GetGrpcClosure(); + + ::testing::InSequence sequence; + EXPECT_CALL(GetWireWriter(), RpcCall(TransactionMatches( + kFlagSuffix, "", kTrailingMetadata, ""))); + EXPECT_CALL(mock_on_complete, Callback); + + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); +} + +TEST_F(BinderTransportTest, PerformSendAll) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + const Metadata kInitialMetadata = kDefaultMetadata; + const std::string kMethodRef = kDefaultMethodRef; + MakeSendInitialMetadata send_initial_metadata(kInitialMetadata, kMethodRef, + &op); + + const std::string kMessage = kDefaultMessage; + MakeSendMessage send_message(kMessage, &op); + + // The wireformat guarantees that suffix metadata will always be empty. + // TODO(waynetu): Check whether gRPC can internally add extra trailing + // metadata. 
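+ // Since this batch carries initial metadata, a message, and the (empty) + // trailing metadata at once, the expectation below verifies that the + // transport coalesces them into a single transaction with the prefix, + // message-data, and suffix flags all set.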
+ const Metadata kTrailingMetadata = {}; + MakeSendTrailingMetadata send_trailing_metadata(kTrailingMetadata, &op); + + MockGrpcClosure mock_on_complete; + op.on_complete = mock_on_complete.GetGrpcClosure(); + + ::testing::InSequence sequence; + EXPECT_CALL(GetWireWriter(), + RpcCall(TransactionMatches( + kFlagPrefix | kFlagMessageData | kFlagSuffix, + kMethodRef.substr(1), kInitialMetadata, kMessage))); + EXPECT_CALL(mock_on_complete, Callback); + + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); +} + +TEST_F(BinderTransportTest, PerformRecvInitialMetadata) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + MakeRecvInitialMetadata recv_initial_metadata(&op); + + const Metadata kInitialMetadata = kDefaultMetadata; + auto* gbt = reinterpret_cast(transport_); + gbt->transport_stream_receiver->NotifyRecvInitialMetadata(gbs->tx_code, + kInitialMetadata); + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); + recv_initial_metadata.notification.WaitForNotification(); + + VerifyMetadataEqual(kInitialMetadata, + recv_initial_metadata.grpc_initial_metadata); +} + +TEST_F(BinderTransportTest, PerformRecvInitialMetadataWithMethodRef) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + MakeRecvInitialMetadata recv_initial_metadata(&op); + + auto* gbt = reinterpret_cast(transport_); + const Metadata kInitialMetadataWithMethodRef = + AppendMethodRef(kDefaultMetadata, kDefaultMethodRef); + gbt->transport_stream_receiver->NotifyRecvInitialMetadata( + gbs->tx_code, kInitialMetadataWithMethodRef); + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); + recv_initial_metadata.notification.WaitForNotification(); + + VerifyMetadataEqual(kInitialMetadataWithMethodRef, + recv_initial_metadata.grpc_initial_metadata); +} + +TEST_F(BinderTransportTest, PerformRecvMessage) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + MakeRecvMessage recv_message(&op); + + auto* gbt = reinterpret_cast(transport_); + const std::string kMessage = kDefaultMessage; + gbt->transport_stream_receiver->NotifyRecvMessage(gbs->tx_code, kMessage); + + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); + recv_message.notification.WaitForNotification(); + + EXPECT_TRUE(recv_message.grpc_message->Next(SIZE_MAX, nullptr)); + grpc_slice slice; + EXPECT_EQ(recv_message.grpc_message->Pull(&slice), GRPC_ERROR_NONE); + EXPECT_EQ(kMessage, + std::string(reinterpret_cast(GRPC_SLICE_START_PTR(slice)), + GRPC_SLICE_LENGTH(slice))); + grpc_slice_unref_internal(slice); +} + +TEST_F(BinderTransportTest, PerformRecvTrailingMetadata) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + MakeRecvTrailingMetadata recv_trailing_metadata(&op); + + const Metadata kTrailingMetadata = kDefaultMetadata; + auto* gbt = reinterpret_cast(transport_); + constexpr int kStatus = kDefaultStatus; + gbt->transport_stream_receiver->NotifyRecvTrailingMetadata( + gbs->tx_code, kTrailingMetadata, 
kStatus); + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); + recv_trailing_metadata.notification.WaitForNotification(); + + VerifyMetadataEqual(AppendStatus(kTrailingMetadata, kStatus), + recv_trailing_metadata.grpc_trailing_metadata); +} + +TEST_F(BinderTransportTest, PerformRecvAll) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + MakeRecvInitialMetadata recv_initial_metadata(&op); + MakeRecvMessage recv_message(&op); + MakeRecvTrailingMetadata recv_trailing_metadata(&op); + + auto* gbt = reinterpret_cast(transport_); + const Metadata kInitialMetadataWithMethodRef = + AppendMethodRef(kDefaultMetadata, kDefaultMethodRef); + gbt->transport_stream_receiver->NotifyRecvInitialMetadata( + gbs->tx_code, kInitialMetadataWithMethodRef); + + const std::string kMessage = kDefaultMessage; + gbt->transport_stream_receiver->NotifyRecvMessage(gbs->tx_code, kMessage); + + Metadata trailing_metadata = kDefaultMetadata; + constexpr int kStatus = kDefaultStatus; + gbt->transport_stream_receiver->NotifyRecvTrailingMetadata( + gbs->tx_code, trailing_metadata, kStatus); + PerformStreamOp(gbs, &op); + grpc_core::ExecCtx::Get()->Flush(); + recv_trailing_metadata.notification.WaitForNotification(); + + VerifyMetadataEqual(kInitialMetadataWithMethodRef, + recv_initial_metadata.grpc_initial_metadata); + trailing_metadata.emplace_back("grpc-status", std::to_string(kStatus)); + VerifyMetadataEqual(trailing_metadata, + recv_trailing_metadata.grpc_trailing_metadata); + EXPECT_TRUE(recv_message.grpc_message->Next(SIZE_MAX, nullptr)); + grpc_slice slice; + EXPECT_EQ(recv_message.grpc_message->Pull(&slice), GRPC_ERROR_NONE); + EXPECT_EQ(kMessage, + std::string(reinterpret_cast(GRPC_SLICE_START_PTR(slice)), + GRPC_SLICE_LENGTH(slice))); + grpc_slice_unref_internal(slice); +} + +TEST_F(BinderTransportTest, PerformAllOps) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + grpc_transport_stream_op_batch op{}; + grpc_transport_stream_op_batch_payload payload(nullptr); + op.payload = &payload; + + const Metadata kSendInitialMetadata = kDefaultMetadata; + const std::string kMethodRef = kDefaultMethodRef; + MakeSendInitialMetadata send_initial_metadata(kSendInitialMetadata, + kMethodRef, &op); + + const std::string kSendMessage = kDefaultMessage; + MakeSendMessage send_message(kSendMessage, &op); + + // The wireformat guarantees that suffix metadata will always be empty. + // TODO(waynetu): Check whether gRPC can internally add extra trailing + // metadata. + const Metadata kSendTrailingMetadata = {}; + MakeSendTrailingMetadata send_trailing_metadata(kSendTrailingMetadata, &op); + + MockGrpcClosure mock_on_complete; + op.on_complete = mock_on_complete.GetGrpcClosure(); + + // TODO(waynetu): Currently, we simply drop the prefix '/' from the :path + // argument to obtain the method name. Update the test if this turns out to be + // incorrect. + EXPECT_CALL(GetWireWriter(), + RpcCall(TransactionMatches( + kFlagPrefix | kFlagMessageData | kFlagSuffix, + kMethodRef.substr(1), kSendInitialMetadata, kSendMessage))); + Expectation on_complete = EXPECT_CALL(mock_on_complete, Callback); + + // Recv callbacks can happen after the on_complete callback. 
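+ // Each recv helper below is constructed with /* call_before = */ &on_complete, + // so the gMock .After() expectation enforces that its callback only runs once + // on_complete has fired.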
+ MakeRecvInitialMetadata recv_initial_metadata( + &op, /* call_before = */ &on_complete); + MakeRecvMessage recv_message(&op, /* call_before = */ &on_complete); + MakeRecvTrailingMetadata recv_trailing_metadata( + &op, /* call_before = */ &on_complete); + + PerformStreamOp(gbs, &op); + + // Flush the execution context to force on_complete to run before recv + // callbacks get scheduled. + grpc_core::ExecCtx::Get()->Flush(); + + auto* gbt = reinterpret_cast(transport_); + const Metadata kRecvInitialMetadata = + AppendMethodRef(kDefaultMetadata, kDefaultMethodRef); + gbt->transport_stream_receiver->NotifyRecvInitialMetadata( + gbs->tx_code, kRecvInitialMetadata); + const std::string kRecvMessage = kDefaultMessage; + gbt->transport_stream_receiver->NotifyRecvMessage(gbs->tx_code, kRecvMessage); + const Metadata kRecvTrailingMetadata = kDefaultMetadata; + constexpr int kStatus = 0x1234; + gbt->transport_stream_receiver->NotifyRecvTrailingMetadata( + gbs->tx_code, kRecvTrailingMetadata, kStatus); + + grpc_core::ExecCtx::Get()->Flush(); + recv_initial_metadata.notification.WaitForNotification(); + recv_message.notification.WaitForNotification(); + recv_trailing_metadata.notification.WaitForNotification(); + + VerifyMetadataEqual(kRecvInitialMetadata, + recv_initial_metadata.grpc_initial_metadata); + VerifyMetadataEqual(AppendStatus(kRecvTrailingMetadata, kStatus), + recv_trailing_metadata.grpc_trailing_metadata); + + EXPECT_TRUE(recv_message.grpc_message->Next(SIZE_MAX, nullptr)); + grpc_slice slice; + EXPECT_EQ(recv_message.grpc_message->Pull(&slice), GRPC_ERROR_NONE); + EXPECT_EQ(kRecvMessage, + std::string(reinterpret_cast(GRPC_SLICE_START_PTR(slice)), + GRPC_SLICE_LENGTH(slice))); + grpc_slice_unref_internal(slice); +} + +TEST_F(BinderTransportTest, WireWriterRpcCallErrorPropagates) { + grpc_core::ExecCtx exec_ctx; + grpc_binder_stream* gbs = InitNewBinderStream(); + + MockGrpcClosure mock_on_complete1; + MockGrpcClosure mock_on_complete2; + + EXPECT_CALL(GetWireWriter(), RpcCall) + .WillOnce(Return(absl::OkStatus())) + .WillOnce(Return(absl::InternalError("WireWriter::RpcCall failed"))); + EXPECT_CALL(mock_on_complete1, Callback(GRPC_ERROR_NONE)); + EXPECT_CALL(mock_on_complete2, + Callback(GrpcErrorMessageContains("WireWriter::RpcCall failed"))); + + const Metadata kInitialMetadata = {}; + grpc_transport_stream_op_batch op1{}; + grpc_transport_stream_op_batch_payload payload1(nullptr); + op1.payload = &payload1; + MakeSendInitialMetadata send_initial_metadata1(kInitialMetadata, "", &op1); + op1.on_complete = mock_on_complete1.GetGrpcClosure(); + + grpc_transport_stream_op_batch op2{}; + grpc_transport_stream_op_batch_payload payload2(nullptr); + op2.payload = &payload2; + MakeSendInitialMetadata send_initial_metadata2(kInitialMetadata, "", &op2); + op2.on_complete = mock_on_complete2.GetGrpcClosure(); + + PerformStreamOp(gbs, &op1); + PerformStreamOp(gbs, &op2); + grpc_core::ExecCtx::Get()->Flush(); +} + +} // namespace grpc_binder + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/binder/end2end/binder_server_test.cc b/test/core/transport/binder/end2end/binder_server_test.cc new file mode 100644 index 00000000..67fc635b --- /dev/null +++ b/test/core/transport/binder/end2end/binder_server_test.cc @@ -0,0 +1,217 @@ +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/ext/transport/binder/server/binder_server.h" + +#include +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include + +#include "src/core/ext/transport/binder/client/channel_create_impl.h" +#include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h" +#include "src/core/ext/transport/binder/server/binder_server.h" +#include "src/core/ext/transport/binder/server/binder_server_credentials.h" +#include "test/core/transport/binder/end2end/fake_binder.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +namespace grpc { +namespace testing { + +namespace { + +class BinderServerCredentialsImpl final : public ServerCredentials { + public: + int AddPortToServer(const std::string& addr, grpc_server* server) override { + return grpc_core::AddBinderPort( + addr, server, + [](grpc_binder::TransactionReceiver::OnTransactCb transact_cb) { + return absl::make_unique< + grpc_binder::end2end_testing::FakeTransactionReceiver>( + nullptr, std::move(transact_cb)); + }, + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>()); + } + + void SetAuthMetadataProcessor( + const std::shared_ptr& /*processor*/) override { + GPR_ASSERT(false); + } + + private: + bool IsInsecure() const override { return true; } +}; + +} // namespace + +std::shared_ptr BinderServerCredentials() { + return std::shared_ptr(new BinderServerCredentialsImpl()); +} + +std::shared_ptr CreateBinderChannel( + std::unique_ptr endpoint_binder) { + grpc::internal::GrpcLibrary init_lib; + init_lib.init(); + + return grpc::CreateChannelInternal( + "", + grpc::internal::CreateChannelFromBinderImpl( + std::move(endpoint_binder), + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>(), + nullptr), + std::vector>()); +} + +} // namespace testing +} // namespace grpc + +namespace { + +class BinderServerTest : public ::testing::Test { + public: + BinderServerTest() { + grpc_binder::end2end_testing::g_transaction_processor = + new grpc_binder::end2end_testing::TransactionProcessor(); + } + ~BinderServerTest() override { + delete grpc_binder::end2end_testing::g_transaction_processor; + } + static void SetUpTestSuite() { grpc_init(); } + static void TearDownTestSuite() { grpc_shutdown(); } +}; + +#ifndef GPR_SUPPORT_BINDER_TRANSPORT +TEST(BinderServerCredentialsTest, + FailedInEnvironmentsNotSupportingBinderTransport) { + grpc::ServerBuilder server_builder; + grpc::testing::TestServiceImpl service; + server_builder.RegisterService(&service); + server_builder.AddListeningPort( + "binder:fail", + grpc::experimental::BinderServerCredentials( + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>())); + EXPECT_EQ(server_builder.BuildAndStart(), nullptr); +} +#endif // !GPR_SUPPORT_BINDER_TRANSPORT + +TEST_F(BinderServerTest, BuildAndStart) { + grpc::ServerBuilder server_builder; + grpc::testing::TestServiceImpl service; + 
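// Adding a "binder:" listening port should publish an endpoint binder under + // the name "example.service" while the server is running; the checks after + // BuildAndStart() and Shutdown() below verify that it appears and then goes away. +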
server_builder.RegisterService(&service); + server_builder.AddListeningPort("binder:example.service", + grpc::testing::BinderServerCredentials()); + std::unique_ptr server = server_builder.BuildAndStart(); + EXPECT_NE(grpc::experimental::binder::GetEndpointBinder("example.service"), + nullptr); + server->Shutdown(); + EXPECT_EQ(grpc::experimental::binder::GetEndpointBinder("example.service"), + nullptr); +} + +TEST_F(BinderServerTest, BuildAndStartFailed) { + grpc::ServerBuilder server_builder; + grpc::testing::TestServiceImpl service; + server_builder.RegisterService(&service); + // Error: binder address should begin with binder: + server_builder.AddListeningPort("localhost:12345", + grpc::testing::BinderServerCredentials()); + std::unique_ptr server = server_builder.BuildAndStart(); + EXPECT_EQ(server, nullptr); +} + +TEST_F(BinderServerTest, CreateChannelWithEndpointBinder) { + grpc::ServerBuilder server_builder; + grpc::testing::TestServiceImpl service; + server_builder.RegisterService(&service); + server_builder.AddListeningPort("binder:example.service", + grpc::testing::BinderServerCredentials()); + std::unique_ptr server = server_builder.BuildAndStart(); + void* raw_endpoint_binder = + grpc::experimental::binder::GetEndpointBinder("example.service"); + std::unique_ptr endpoint_binder = + absl::make_unique( + static_cast( + raw_endpoint_binder)); + std::shared_ptr channel = + grpc::testing::CreateBinderChannel(std::move(endpoint_binder)); + std::unique_ptr stub = + grpc::testing::EchoTestService::NewStub(channel); + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + grpc::ClientContext context; + request.set_message("BinderServerBuilder"); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.message(), "BinderServerBuilder"); + server->Shutdown(); +} + +TEST_F(BinderServerTest, CreateChannelWithEndpointBinderMultipleConnections) { + grpc::ServerBuilder server_builder; + grpc::testing::TestServiceImpl service; + server_builder.RegisterService(&service); + server_builder.AddListeningPort("binder:example.service.multiple.connections", + grpc::testing::BinderServerCredentials()); + std::unique_ptr server = server_builder.BuildAndStart(); + void* raw_endpoint_binder = grpc::experimental::binder::GetEndpointBinder( + "example.service.multiple.connections"); + constexpr size_t kNumThreads = 128; + + auto thread_fn = [&](size_t id) { + std::unique_ptr endpoint_binder = + absl::make_unique( + static_cast( + raw_endpoint_binder)); + std::shared_ptr channel = + grpc::testing::CreateBinderChannel(std::move(endpoint_binder)); + std::unique_ptr stub = + grpc::testing::EchoTestService::NewStub(channel); + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + grpc::ClientContext context; + request.set_message(absl::StrFormat("BinderServerBuilder-%d", id)); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.message(), + absl::StrFormat("BinderServerBuilder-%d", id)); + }; + + std::vector threads(kNumThreads); + for (size_t i = 0; i < kNumThreads; ++i) { + threads[i] = std::thread(thread_fn, i); + } + for (auto& thr : threads) { + thr.join(); + } + server->Shutdown(); +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/binder/end2end/end2end_binder_transport_test.cc 
b/test/core/transport/binder/end2end/end2end_binder_transport_test.cc new file mode 100644 index 00000000..da9a7d09 --- /dev/null +++ b/test/core/transport/binder/end2end/end2end_binder_transport_test.cc @@ -0,0 +1,560 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/time/time.h" + +#include + +#include "src/core/ext/transport/binder/transport/binder_transport.h" +#include "src/core/ext/transport/binder/wire_format/wire_reader_impl.h" +#include "test/core/transport/binder/end2end/fake_binder.h" +#include "test/core/transport/binder/end2end/testing_channel_create.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +namespace grpc_binder { + +namespace { + +class End2EndBinderTransportTest + : public ::testing::TestWithParam { + public: + End2EndBinderTransportTest() { + end2end_testing::g_transaction_processor = + new end2end_testing::TransactionProcessor(GetParam()); + service_ = absl::make_unique(); + grpc::ServerBuilder builder; + builder.RegisterService(service_.get()); + server_ = builder.BuildAndStart(); + } + + ~End2EndBinderTransportTest() override { + server_->Shutdown(); + service_.reset(); + delete end2end_testing::g_transaction_processor; + } + + std::unique_ptr NewStub() { + grpc::ChannelArguments args; + std::shared_ptr channel = BinderChannel(server_.get(), args); + return grpc::testing::EchoTestService::NewStub(channel); + } + + static void SetUpTestSuite() { grpc_init(); } + static void TearDownTestSuite() { grpc_shutdown(); } + + std::shared_ptr BinderChannel( + grpc::Server* server, const grpc::ChannelArguments& args) { + return end2end_testing::BinderChannelForTesting(server, args); + } + + protected: + std::unique_ptr service_; + std::unique_ptr server_; +}; + +} // namespace + +TEST_P(End2EndBinderTransportTest, SetupTransport) { + grpc_core::ExecCtx exec_ctx; + grpc_transport *client_transport, *server_transport; + std::tie(client_transport, server_transport) = + end2end_testing::CreateClientServerBindersPairForTesting(); + EXPECT_NE(client_transport, nullptr); + EXPECT_NE(server_transport, nullptr); + + grpc_transport_destroy(client_transport); + grpc_transport_destroy(server_transport); +} + +TEST_P(End2EndBinderTransportTest, UnaryCall) { + std::unique_ptr stub = NewStub(); + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCall"); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.message(), "UnaryCall"); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallWithNonOkStatus) { + std::unique_ptr stub = NewStub(); + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallWithNonOkStatus"); + request.mutable_param()->mutable_expected_error()->set_code( + 
grpc::StatusCode::INTERNAL); + request.mutable_param()->mutable_expected_error()->set_error_message( + "expected to fail"); + // Server will not response the client with message data, however, since all + // callbacks after the trailing metadata are cancelled, we shall not be + // blocked here. + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expected to fail")); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallServerTimeout) { + std::unique_ptr stub = NewStub(); + grpc::ClientContext context; + context.set_deadline(absl::ToChronoTime(absl::Now() + absl::Seconds(1))); + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallServerTimeout"); + // Server will sleep for 2 seconds before responding us. + request.mutable_param()->set_server_sleep_us(2000000); + // Disable cancellation check because the request will time out. + request.mutable_param()->set_skip_cancelled_check(true); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::DEADLINE_EXCEEDED); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallClientTimeout) { + std::unique_ptr stub = NewStub(); + + // Set transaction delay to a large number. This happens after the channel + // creation so that we don't need to wait that long for client and server to + // be connected. + end2end_testing::g_transaction_processor->SetDelay(absl::Seconds(5)); + + grpc::ClientContext context; + context.set_deadline(absl::ToChronoTime(absl::Now() + absl::Seconds(1))); + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallClientTimeout"); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::DEADLINE_EXCEEDED); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallUnimplemented) { + std::unique_ptr stub = NewStub(); + + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallUnimplemented"); + grpc::Status status = stub->Unimplemented(&context, request, &response); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallClientCancel) { + std::unique_ptr stub = NewStub(); + + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallClientCancel"); + context.TryCancel(); + grpc::Status status = stub->Unimplemented(&context, request, &response); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallEchoMetadataInitially) { + std::unique_ptr stub = NewStub(); + + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallEchoMetadataInitially"); + request.mutable_param()->set_echo_metadata_initially(true); + context.AddMetadata("key1", "value1"); + context.AddMetadata("key2", "value2"); + grpc::Status status = stub->Echo(&context, request, &response); + const auto& initial_metadata = context.GetServerInitialMetadata(); + EXPECT_EQ(initial_metadata.find("key1")->second, "value1"); + 
EXPECT_EQ(initial_metadata.find("key2")->second, "value2"); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallEchoMetadata) { + std::unique_ptr stub = NewStub(); + + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallEchoMetadata"); + request.mutable_param()->set_echo_metadata(true); + context.AddMetadata("key1", "value1"); + context.AddMetadata("key2", "value2"); + grpc::Status status = stub->Echo(&context, request, &response); + const auto& initial_metadata = context.GetServerTrailingMetadata(); + EXPECT_EQ(initial_metadata.find("key1")->second, "value1"); + EXPECT_EQ(initial_metadata.find("key2")->second, "value2"); +} + +TEST_P(End2EndBinderTransportTest, UnaryCallResponseMessageLength) { + std::unique_ptr stub = NewStub(); + + for (size_t response_length : {1, 2, 5, 10, 100, 1000000}) { + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallResponseMessageLength"); + request.mutable_param()->set_response_message_length(response_length); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_EQ(response.message().length(), response_length); + } +} + +TEST_P(End2EndBinderTransportTest, UnaryCallTryCancel) { + std::unique_ptr stub = NewStub(); + + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerTryCancelRequest, + std::to_string(grpc::testing::CANCEL_BEFORE_PROCESSING)); + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message("UnaryCallTryCancel"); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, ServerStreamingCall) { + std::unique_ptr stub = NewStub(); + constexpr size_t kServerResponseStreamsToSend = 100; + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerResponseStreamsToSend, + std::to_string(kServerResponseStreamsToSend)); + grpc::testing::EchoRequest request; + request.set_message("ServerStreamingCall"); + std::unique_ptr> reader = + stub->ResponseStream(&context, request); + grpc::testing::EchoResponse response; + size_t cnt = 0; + while (reader->Read(&response)) { + EXPECT_EQ(response.message(), "ServerStreamingCall" + std::to_string(cnt)); + cnt++; + } + EXPECT_EQ(cnt, kServerResponseStreamsToSend); + grpc::Status status = reader->Finish(); + EXPECT_TRUE(status.ok()); +} + +TEST_P(End2EndBinderTransportTest, ServerStreamingCallCoalescingApi) { + std::unique_ptr stub = NewStub(); + constexpr size_t kServerResponseStreamsToSend = 100; + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerResponseStreamsToSend, + std::to_string(kServerResponseStreamsToSend)); + context.AddMetadata(grpc::testing::kServerUseCoalescingApi, "1"); + grpc::testing::EchoRequest request; + request.set_message("ServerStreamingCallCoalescingApi"); + std::unique_ptr> reader = + stub->ResponseStream(&context, request); + grpc::testing::EchoResponse response; + size_t cnt = 0; + while (reader->Read(&response)) { + EXPECT_EQ(response.message(), + "ServerStreamingCallCoalescingApi" + std::to_string(cnt)); + cnt++; + } + EXPECT_EQ(cnt, kServerResponseStreamsToSend); + grpc::Status status = reader->Finish(); + EXPECT_TRUE(status.ok()); +} + +TEST_P(End2EndBinderTransportTest, + ServerStreamingCallTryCancelBeforeProcessing) { + std::unique_ptr stub = NewStub(); + 
constexpr size_t kServerResponseStreamsToSend = 100; + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerResponseStreamsToSend, + std::to_string(kServerResponseStreamsToSend)); + context.AddMetadata(grpc::testing::kServerTryCancelRequest, + std::to_string(grpc::testing::CANCEL_BEFORE_PROCESSING)); + grpc::testing::EchoRequest request; + request.set_message("ServerStreamingCallTryCancelBeforeProcessing"); + std::unique_ptr> reader = + stub->ResponseStream(&context, request); + grpc::testing::EchoResponse response; + EXPECT_FALSE(reader->Read(&response)); + grpc::Status status = reader->Finish(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, + ServerSteramingCallTryCancelDuringProcessing) { + std::unique_ptr stub = NewStub(); + constexpr size_t kServerResponseStreamsToSend = 2; + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerResponseStreamsToSend, + std::to_string(kServerResponseStreamsToSend)); + context.AddMetadata(grpc::testing::kServerTryCancelRequest, + std::to_string(grpc::testing::CANCEL_DURING_PROCESSING)); + grpc::testing::EchoRequest request; + request.set_message("ServerStreamingCallTryCancelDuringProcessing"); + std::unique_ptr> reader = + stub->ResponseStream(&context, request); + grpc::testing::EchoResponse response; + size_t cnt = 0; + while (reader->Read(&response)) { + EXPECT_EQ( + response.message(), + "ServerStreamingCallTryCancelDuringProcessing" + std::to_string(cnt)); + cnt++; + } + grpc::Status status = reader->Finish(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, + ServerSteramingCallTryCancelAfterProcessing) { + std::unique_ptr stub = NewStub(); + constexpr size_t kServerResponseStreamsToSend = 100; + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerResponseStreamsToSend, + std::to_string(kServerResponseStreamsToSend)); + context.AddMetadata(grpc::testing::kServerTryCancelRequest, + std::to_string(grpc::testing::CANCEL_AFTER_PROCESSING)); + grpc::testing::EchoRequest request; + request.set_message("ServerStreamingCallTryCancelAfterProcessing"); + std::unique_ptr> reader = + stub->ResponseStream(&context, request); + grpc::testing::EchoResponse response; + size_t cnt = 0; + while (reader->Read(&response)) { + EXPECT_EQ( + response.message(), + "ServerStreamingCallTryCancelAfterProcessing" + std::to_string(cnt)); + cnt++; + } + EXPECT_EQ(cnt, kServerResponseStreamsToSend); + grpc::Status status = reader->Finish(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, ClientStreamingCall) { + std::unique_ptr stub = NewStub(); + grpc::ClientContext context; + grpc::testing::EchoResponse response; + std::unique_ptr> writer = + stub->RequestStream(&context, &response); + constexpr size_t kClientStreamingCounts = 100; + std::string expected = ""; + for (size_t i = 0; i < kClientStreamingCounts; ++i) { + grpc::testing::EchoRequest request; + request.set_message("ClientStreamingCall" + std::to_string(i)); + EXPECT_TRUE(writer->Write(request)); + expected += "ClientStreamingCall" + std::to_string(i); + } + writer->WritesDone(); + grpc::Status status = writer->Finish(); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.message(), expected); +} + +TEST_P(End2EndBinderTransportTest, + ClientStreamingCallTryCancelBeforeProcessing) { + std::unique_ptr stub = 
NewStub(); + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerTryCancelRequest, + std::to_string(grpc::testing::CANCEL_BEFORE_PROCESSING)); + grpc::testing::EchoResponse response; + std::unique_ptr> writer = + stub->RequestStream(&context, &response); + constexpr size_t kClientStreamingCounts = 100; + for (size_t i = 0; i < kClientStreamingCounts; ++i) { + grpc::testing::EchoRequest request; + request.set_message("ClientStreamingCallBeforeProcessing" + + std::to_string(i)); + writer->Write(request); + } + writer->WritesDone(); + grpc::Status status = writer->Finish(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, + ClientStreamingCallTryCancelDuringProcessing) { + std::unique_ptr stub = NewStub(); + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerTryCancelRequest, + std::to_string(grpc::testing::CANCEL_DURING_PROCESSING)); + grpc::testing::EchoResponse response; + std::unique_ptr> writer = + stub->RequestStream(&context, &response); + constexpr size_t kClientStreamingCounts = 100; + for (size_t i = 0; i < kClientStreamingCounts; ++i) { + grpc::testing::EchoRequest request; + request.set_message("ClientStreamingCallDuringProcessing" + + std::to_string(i)); + writer->Write(request); + } + writer->WritesDone(); + grpc::Status status = writer->Finish(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, + ClientStreamingCallTryCancelAfterProcessing) { + std::unique_ptr stub = NewStub(); + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerTryCancelRequest, + std::to_string(grpc::testing::CANCEL_AFTER_PROCESSING)); + grpc::testing::EchoResponse response; + std::unique_ptr> writer = + stub->RequestStream(&context, &response); + constexpr size_t kClientStreamingCounts = 100; + for (size_t i = 0; i < kClientStreamingCounts; ++i) { + grpc::testing::EchoRequest request; + request.set_message("ClientStreamingCallAfterProcessing" + + std::to_string(i)); + writer->Write(request); + } + writer->WritesDone(); + grpc::Status status = writer->Finish(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), grpc::StatusCode::CANCELLED); +} + +TEST_P(End2EndBinderTransportTest, BiDirStreamingCall) { + std::unique_ptr stub = NewStub(); + grpc::ClientContext context; + std::shared_ptr> + stream = stub->BidiStream(&context); + constexpr size_t kBiDirStreamingCounts = 100; + + struct WriterArgs { + std::shared_ptr> + stream; + size_t bi_dir_streaming_counts; + } writer_args; + + writer_args.stream = stream; + writer_args.bi_dir_streaming_counts = kBiDirStreamingCounts; + + auto writer_fn = [](void* arg) { + const WriterArgs& args = *static_cast(arg); + for (size_t i = 0; i < args.bi_dir_streaming_counts; ++i) { + grpc::testing::EchoRequest request; + request.set_message("BiDirStreamingCall" + std::to_string(i)); + args.stream->Write(request); + } + args.stream->WritesDone(); + }; + + grpc_core::Thread writer_thread("writer-thread", writer_fn, + static_cast(&writer_args)); + writer_thread.Start(); + for (size_t i = 0; i < kBiDirStreamingCounts; ++i) { + grpc::testing::EchoResponse response; + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), "BiDirStreamingCall" + std::to_string(i)); + } + grpc::Status status = stream->Finish(); + EXPECT_TRUE(status.ok()); + writer_thread.Join(); +} + +TEST_P(End2EndBinderTransportTest, 
BiDirStreamingCallServerFinishesHalfway) { + std::unique_ptr stub = NewStub(); + constexpr size_t kBiDirStreamingCounts = 100; + grpc::ClientContext context; + context.AddMetadata(grpc::testing::kServerFinishAfterNReads, + std::to_string(kBiDirStreamingCounts / 2)); + std::shared_ptr> + stream = stub->BidiStream(&context); + + struct WriterArgs { + std::shared_ptr> + stream; + size_t bi_dir_streaming_counts; + } writer_args; + + writer_args.stream = stream; + writer_args.bi_dir_streaming_counts = kBiDirStreamingCounts; + + auto writer_fn = [](void* arg) { + const WriterArgs& args = *static_cast(arg); + for (size_t i = 0; i < args.bi_dir_streaming_counts; ++i) { + grpc::testing::EchoRequest request; + request.set_message("BiDirStreamingCallServerFinishesHalfway" + + std::to_string(i)); + if (!args.stream->Write(request)) { + return; + } + } + args.stream->WritesDone(); + }; + + grpc_core::Thread writer_thread("writer-thread", writer_fn, + static_cast(&writer_args)); + writer_thread.Start(); + for (size_t i = 0; i < kBiDirStreamingCounts / 2; ++i) { + grpc::testing::EchoResponse response; + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), + "BiDirStreamingCallServerFinishesHalfway" + std::to_string(i)); + } + grpc::testing::EchoResponse response; + EXPECT_FALSE(stream->Read(&response)); + writer_thread.Join(); + grpc::Status status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +TEST_P(End2EndBinderTransportTest, LargeMessages) { + std::unique_ptr stub = NewStub(); + for (size_t size = 1; size <= 1024 * 1024; size *= 4) { + grpc::ClientContext context; + grpc::testing::EchoRequest request; + grpc::testing::EchoResponse response; + request.set_message(std::string(size, 'a')); + grpc::Status status = stub->Echo(&context, request, &response); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.message().size(), size); + EXPECT_TRUE(std::all_of(response.message().begin(), + response.message().end(), + [](char c) { return c == 'a'; })); + } +} + +INSTANTIATE_TEST_SUITE_P( + End2EndBinderTransportTestWithDifferentDelayTimes, + End2EndBinderTransportTest, + testing::Values(absl::ZeroDuration(), absl::Nanoseconds(10), + absl::Microseconds(10), absl::Microseconds(100), + absl::Milliseconds(1), absl::Milliseconds(20))); + +} // namespace grpc_binder + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/binder/end2end/fake_binder.cc b/test/core/transport/binder/end2end/fake_binder.cc new file mode 100644 index 00000000..6d0d1cd3 --- /dev/null +++ b/test/core/transport/binder/end2end/fake_binder.cc @@ -0,0 +1,273 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
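+
+// A typical test drives this fake roughly as follows (a sketch distilled from
+// fake_binder_test.cc further below; the delay, payload, and tx-code values
+// are arbitrary, and `auto` hides the exact smart-pointer types returned by
+// NewBinderPair):
+//
+//   g_transaction_processor = new TransactionProcessor(absl::Milliseconds(1));
+//   auto pair = NewBinderPair(
+//       [](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) {
+//         int32_t value = 0;
+//         return parcel->ReadInt32(&value);  // invoked on the processor thread
+//       });
+//   auto& sender = pair.first;
+//   sender->PrepareTransaction().IgnoreError();
+//   sender->GetWritableParcel()->WriteInt32(42).IgnoreError();
+//   sender->Transact(BinderTransportTxCode(0x1234)).IgnoreError();
+//   g_transaction_processor->Terminate();  // drains the queue and joins
+//   delete g_transaction_processor;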
+ +#include "test/core/transport/binder/end2end/fake_binder.h" + +#include +#include + +#include + +namespace grpc_binder { +namespace end2end_testing { + +TransactionProcessor* g_transaction_processor = nullptr; + +int32_t FakeWritableParcel::GetDataSize() const { return data_size_; } + +absl::Status FakeWritableParcel::WriteInt32(int32_t data) { + data_.push_back(data); + data_size_ += sizeof(int32_t); + return absl::OkStatus(); +} + +absl::Status FakeWritableParcel::WriteInt64(int64_t data) { + data_.push_back(data); + data_size_ += sizeof(int64_t); + return absl::OkStatus(); +} + +absl::Status FakeWritableParcel::WriteBinder(HasRawBinder* binder) { + data_.push_back(binder->GetRawBinder()); + data_size_ += sizeof(void*); + return absl::OkStatus(); +} + +absl::Status FakeWritableParcel::WriteString(absl::string_view s) { + data_.push_back(std::string(s)); + data_size_ += s.size(); + return absl::OkStatus(); +} + +absl::Status FakeWritableParcel::WriteByteArray(const int8_t* buffer, + int32_t length) { + data_.push_back(std::vector(buffer, buffer + length)); + data_size_ += length; + return absl::OkStatus(); +} + +int32_t FakeReadableParcel::GetDataSize() const { return data_size_; } + +absl::Status FakeReadableParcel::ReadInt32(int32_t* data) { + if (data_position_ >= data_.size() || + !absl::holds_alternative(data_[data_position_])) { + return absl::InternalError("ReadInt32 failed"); + } + *data = absl::get(data_[data_position_++]); + return absl::OkStatus(); +} + +absl::Status FakeReadableParcel::ReadInt64(int64_t* data) { + if (data_position_ >= data_.size() || + !absl::holds_alternative(data_[data_position_])) { + return absl::InternalError("ReadInt64 failed"); + } + *data = absl::get(data_[data_position_++]); + return absl::OkStatus(); +} + +absl::Status FakeReadableParcel::ReadBinder(std::unique_ptr* data) { + if (data_position_ >= data_.size() || + !absl::holds_alternative(data_[data_position_])) { + return absl::InternalError("ReadBinder failed"); + } + void* endpoint = absl::get(data_[data_position_++]); + if (!endpoint) return absl::InternalError("ReadBinder failed"); + *data = absl::make_unique(static_cast(endpoint)); + return absl::OkStatus(); +} + +absl::Status FakeReadableParcel::ReadString(std::string* str) { + if (data_position_ >= data_.size() || + !absl::holds_alternative(data_[data_position_])) { + return absl::InternalError("ReadString failed"); + } + *str = absl::get(data_[data_position_++]); + return absl::OkStatus(); +} + +absl::Status FakeReadableParcel::ReadByteArray(std::string* data) { + if (data_position_ >= data_.size() || + !absl::holds_alternative>(data_[data_position_])) { + return absl::InternalError("ReadByteArray failed"); + } + const std::vector& byte_array = + absl::get>(data_[data_position_++]); + data->resize(byte_array.size()); + for (size_t i = 0; i < byte_array.size(); ++i) { + (*data)[i] = byte_array[i]; + } + return absl::OkStatus(); +} + +absl::Status FakeBinder::Transact(BinderTransportTxCode tx_code) { + endpoint_->tunnel->EnQueueTransaction(endpoint_->other_end, tx_code, + input_->MoveData()); + return absl::OkStatus(); +} + +FakeTransactionReceiver::FakeTransactionReceiver( + grpc_core::RefCountedPtr wire_reader_ref, + TransactionReceiver::OnTransactCb transact_cb) { + persistent_tx_receiver_ = &g_transaction_processor->NewPersistentTxReceiver( + std::move(wire_reader_ref), std::move(transact_cb), + absl::make_unique()); +} + +std::unique_ptr FakeBinder::ConstructTxReceiver( + grpc_core::RefCountedPtr wire_reader_ref, + 
TransactionReceiver::OnTransactCb cb) const { + return absl::make_unique(wire_reader_ref, cb); +} + +void* FakeTransactionReceiver::GetRawBinder() { + return persistent_tx_receiver_->tunnel_->GetSendEndpoint(); +} + +std::unique_ptr FakeTransactionReceiver::GetSender() const { + return absl::make_unique( + persistent_tx_receiver_->tunnel_->GetSendEndpoint()); +} + +PersistentFakeTransactionReceiver::PersistentFakeTransactionReceiver( + grpc_core::RefCountedPtr wire_reader_ref, + TransactionReceiver::OnTransactCb cb, + std::unique_ptr tunnel) + : wire_reader_ref_(std::move(wire_reader_ref)), + callback_(std::move(cb)), + tunnel_(std::move(tunnel)) { + FakeEndpoint* recv_endpoint = tunnel_->GetRecvEndpoint(); + recv_endpoint->owner = this; +} + +TransactionProcessor::TransactionProcessor(absl::Duration delay) + : delay_nsec_(absl::ToInt64Nanoseconds(delay)), + tx_thread_( + "process-thread", + [](void* arg) { + auto* self = static_cast(arg); + self->ProcessLoop(); + }, + this), + terminated_(false) { + tx_thread_.Start(); +} + +void TransactionProcessor::SetDelay(absl::Duration delay) { + delay_nsec_ = absl::ToInt64Nanoseconds(delay); +} + +void TransactionProcessor::Terminate() { + if (!terminated_.load(std::memory_order_seq_cst)) { + gpr_log(GPR_INFO, "Terminating the processor"); + terminated_.store(true, std::memory_order_seq_cst); + tx_thread_.Join(); + gpr_log(GPR_INFO, "Processor terminated"); + } +} + +void TransactionProcessor::WaitForNextTransaction() { + absl::Time now = absl::Now(); + if (now < deliver_time_) { + absl::Duration diff = deliver_time_ - now; + // Release the lock before going to sleep. + mu_.Unlock(); + absl::SleepFor(diff); + mu_.Lock(); + } +} + +void TransactionProcessor::Flush() { + while (true) { + FakeEndpoint* target = nullptr; + BinderTransportTxCode tx_code{}; + FakeData data; + mu_.Lock(); + if (tx_queue_.empty()) { + mu_.Unlock(); + break; + } + WaitForNextTransaction(); + std::tie(target, tx_code, data) = std::move(tx_queue_.front()); + tx_queue_.pop(); + if (!tx_queue_.empty()) { + deliver_time_ = absl::Now() + GetRandomDelay(); + } + mu_.Unlock(); + auto* tx_receiver = + static_cast(target->owner); + auto parcel = absl::make_unique(std::move(data)); + tx_receiver->Receive(tx_code, parcel.get()).IgnoreError(); + } +} + +void TransactionProcessor::ProcessLoop() { + while (!terminated_.load(std::memory_order_seq_cst)) { + FakeEndpoint* target = nullptr; + BinderTransportTxCode tx_code{}; + FakeData data; + mu_.Lock(); + if (tx_queue_.empty()) { + mu_.Unlock(); + continue; + } + WaitForNextTransaction(); + std::tie(target, tx_code, data) = std::move(tx_queue_.front()); + tx_queue_.pop(); + if (!tx_queue_.empty()) { + deliver_time_ = absl::Now() + GetRandomDelay(); + } + mu_.Unlock(); + auto* tx_receiver = + static_cast(target->owner); + auto parcel = absl::make_unique(std::move(data)); + tx_receiver->Receive(tx_code, parcel.get()).IgnoreError(); + } + Flush(); +} + +absl::Duration TransactionProcessor::GetRandomDelay() { + int64_t delay = + absl::Uniform(bit_gen_, delay_nsec_ / 2, delay_nsec_); + return absl::Nanoseconds(delay); +} + +void TransactionProcessor::EnQueueTransaction(FakeEndpoint* target, + BinderTransportTxCode tx_code, + FakeData data) { + grpc_core::MutexLock lock(&mu_); + if (tx_queue_.empty()) { + // This is the first transaction in the queue. Compute its deliver time. 
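+    // GetRandomDelay() draws the delay uniformly from
+    // [delay_nsec_ / 2, delay_nsec_).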
+ deliver_time_ = absl::Now() + GetRandomDelay(); + } + tx_queue_.emplace(target, tx_code, std::move(data)); +} + +FakeBinderTunnel::FakeBinderTunnel() + : send_endpoint_(absl::make_unique(this)), + recv_endpoint_(absl::make_unique(this)) { + send_endpoint_->other_end = recv_endpoint_.get(); + recv_endpoint_->other_end = send_endpoint_.get(); +} + +std::pair, std::unique_ptr> +NewBinderPair(TransactionReceiver::OnTransactCb transact_cb) { + auto tx_receiver = absl::make_unique( + nullptr, std::move(transact_cb)); + std::unique_ptr sender = tx_receiver->GetSender(); + return std::make_pair(std::move(sender), std::move(tx_receiver)); +} + +} // namespace end2end_testing +} // namespace grpc_binder diff --git a/test/core/transport/binder/end2end/fake_binder_test.cc b/test/core/transport/binder/end2end/fake_binder_test.cc new file mode 100644 index 00000000..1c32a043 --- /dev/null +++ b/test/core/transport/binder/end2end/fake_binder_test.cc @@ -0,0 +1,350 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/core/transport/binder/end2end/fake_binder.h" + +#include +#include +#include +#include + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/time/time.h" + +#include "test/core/util/test_config.h" + +namespace grpc_binder { +namespace end2end_testing { +namespace { + +class FakeBinderTest : public ::testing::TestWithParam { + public: + FakeBinderTest() { + g_transaction_processor = new TransactionProcessor(GetParam()); + } + ~FakeBinderTest() override { delete g_transaction_processor; } +}; + +} // namespace + +TEST_P(FakeBinderTest, SendInt32) { + constexpr int kValue = 0x1234; + constexpr int kTxCode = 0x4321; + int called = 0; + std::unique_ptr sender; + std::unique_ptr tx_receiver; + std::tie(sender, tx_receiver) = NewBinderPair( + [&](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) { + EXPECT_EQ(tx_code, kTxCode); + int value = 0; + EXPECT_TRUE(parcel->ReadInt32(&value).ok()); + EXPECT_EQ(value, kValue); + called++; + return absl::OkStatus(); + }); + + EXPECT_TRUE(sender->PrepareTransaction().ok()); + WritableParcel* parcel = sender->GetWritableParcel(); + EXPECT_TRUE(parcel->WriteInt32(kValue).ok()); + EXPECT_TRUE(sender->Transact(BinderTransportTxCode(kTxCode)).ok()); + + g_transaction_processor->Terminate(); + EXPECT_EQ(called, 1); +} + +TEST_P(FakeBinderTest, SendString) { + constexpr char kValue[] = "example-string"; + constexpr int kTxCode = 0x4321; + int called = 0; + std::unique_ptr sender; + std::unique_ptr tx_receiver; + std::tie(sender, tx_receiver) = NewBinderPair( + [&](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) { + EXPECT_EQ(tx_code, kTxCode); + std::string value; + EXPECT_TRUE(parcel->ReadString(&value).ok()); + EXPECT_STREQ(value.c_str(), kValue); + called++; + return absl::OkStatus(); + }); + + EXPECT_TRUE(sender->PrepareTransaction().ok()); + WritableParcel* parcel = sender->GetWritableParcel(); + EXPECT_TRUE(parcel->WriteString(kValue).ok()); + 
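+  // Transact() only enqueues the parcel; the callback fires on the processor
+  // thread, and Terminate() below guarantees it has run before `called` is
+  // checked.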
EXPECT_TRUE(sender->Transact(BinderTransportTxCode(kTxCode)).ok()); + + g_transaction_processor->Terminate(); + EXPECT_EQ(called, 1); +} + +TEST_P(FakeBinderTest, SendByteArray) { + constexpr char kValue[] = "example-byte-array"; + constexpr int kTxCode = 0x4321; + int called = 0; + std::unique_ptr sender; + std::unique_ptr tx_receiver; + std::tie(sender, tx_receiver) = NewBinderPair( + [&](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) { + EXPECT_EQ(tx_code, kTxCode); + std::string value; + EXPECT_TRUE(parcel->ReadByteArray(&value).ok()); + EXPECT_EQ(value, kValue); + called++; + return absl::OkStatus(); + }); + + EXPECT_TRUE(sender->PrepareTransaction().ok()); + WritableParcel* parcel = sender->GetWritableParcel(); + EXPECT_TRUE(parcel + ->WriteByteArray(reinterpret_cast(kValue), + strlen(kValue)) + .ok()); + EXPECT_TRUE(sender->Transact(BinderTransportTxCode(kTxCode)).ok()); + + g_transaction_processor->Terminate(); + EXPECT_EQ(called, 1); +} + +TEST_P(FakeBinderTest, SendMultipleItems) { + constexpr char kByteArray[] = "example-byte-array"; + constexpr char kString[] = "example-string"; + constexpr int kValue = 0x1234; + constexpr int kTxCode = 0x4321; + int called = 0; + std::unique_ptr sender; + std::unique_ptr tx_receiver; + std::tie(sender, tx_receiver) = NewBinderPair( + [&](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) { + int value_result; + EXPECT_EQ(tx_code, kTxCode); + EXPECT_TRUE(parcel->ReadInt32(&value_result).ok()); + EXPECT_EQ(value_result, kValue); + std::string byte_array_result; + EXPECT_TRUE(parcel->ReadByteArray(&byte_array_result).ok()); + EXPECT_EQ(byte_array_result, kByteArray); + std::string string_result; + EXPECT_TRUE(parcel->ReadString(&string_result).ok()); + EXPECT_STREQ(string_result.c_str(), kString); + called++; + return absl::OkStatus(); + }); + + EXPECT_TRUE(sender->PrepareTransaction().ok()); + WritableParcel* parcel = sender->GetWritableParcel(); + EXPECT_TRUE(parcel->WriteInt32(kValue).ok()); + EXPECT_TRUE(parcel + ->WriteByteArray(reinterpret_cast(kByteArray), + strlen(kByteArray)) + .ok()); + EXPECT_TRUE(parcel->WriteString(kString).ok()); + EXPECT_TRUE(sender->Transact(BinderTransportTxCode(kTxCode)).ok()); + + g_transaction_processor->Terminate(); + EXPECT_EQ(called, 1); +} + +TEST_P(FakeBinderTest, SendBinder) { + constexpr int kValue = 0x1234; + constexpr int kTxCode = 0x4321; + int called = 0; + std::unique_ptr sender; + std::unique_ptr tx_receiver; + std::tie(sender, tx_receiver) = NewBinderPair( + [&](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) { + EXPECT_EQ(tx_code, kTxCode); + std::unique_ptr binder; + EXPECT_TRUE(parcel->ReadBinder(&binder).ok()); + EXPECT_TRUE(binder->PrepareTransaction().ok()); + WritableParcel* writable_parcel = binder->GetWritableParcel(); + EXPECT_TRUE(writable_parcel->WriteInt32(kValue).ok()); + EXPECT_TRUE(binder->Transact(BinderTransportTxCode(kTxCode + 1)).ok()); + called++; + return absl::OkStatus(); + }); + + int called2 = 0; + std::unique_ptr tx_receiver2 = + absl::make_unique( + nullptr, + [&](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) { + int value; + EXPECT_TRUE(parcel->ReadInt32(&value).ok()); + EXPECT_EQ(value, kValue); + EXPECT_EQ(tx_code, kTxCode + 1); + called2++; + return absl::OkStatus(); + }); + EXPECT_TRUE(sender->PrepareTransaction().ok()); + WritableParcel* parcel = sender->GetWritableParcel(); + EXPECT_TRUE(parcel->WriteBinder(tx_receiver2.get()).ok()); + 
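+  // WriteBinder() records the receiver's send endpoint, so the first callback
+  // can ReadBinder() it and transact back to tx_receiver2.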
EXPECT_TRUE(sender->Transact(BinderTransportTxCode(kTxCode)).ok()); + + g_transaction_processor->Terminate(); + EXPECT_EQ(called, 1); + EXPECT_EQ(called2, 1); +} + +TEST_P(FakeBinderTest, SendTransactionAfterDestruction) { + constexpr int kValue = 0x1234; + constexpr int kTxCode = 0x4321; + std::unique_ptr sender; + int called = 0; + { + std::unique_ptr tx_receiver; + std::tie(sender, tx_receiver) = NewBinderPair( + [&](transaction_code_t tx_code, ReadableParcel* parcel, int /*uid*/) { + EXPECT_EQ(tx_code, kTxCode); + int value; + EXPECT_TRUE(parcel->ReadInt32(&value).ok()); + EXPECT_EQ(value, kValue + called); + called++; + return absl::OkStatus(); + }); + EXPECT_TRUE(sender->PrepareTransaction().ok()); + WritableParcel* parcel = sender->GetWritableParcel(); + EXPECT_TRUE(parcel->WriteInt32(kValue).ok()); + EXPECT_TRUE(sender->Transact(BinderTransportTxCode(kTxCode)).ok()); + } + // tx_receiver gets destructed here. This additional transaction should + // *still* be received. + EXPECT_TRUE(sender->PrepareTransaction().ok()); + WritableParcel* parcel = sender->GetWritableParcel(); + EXPECT_TRUE(parcel->WriteInt32(kValue + 1).ok()); + EXPECT_TRUE(sender->Transact(BinderTransportTxCode(kTxCode)).ok()); + + g_transaction_processor->Terminate(); + EXPECT_EQ(called, 2); +} + +namespace { + +struct ThreadArgument { + int tid; + std::vector, + std::unique_ptr>>>* + global_binder_pairs; + std::vector>* global_cnts; + int tx_code; + int num_pairs_per_thread; + int num_transactions_per_pair; + grpc_core::Mutex* mu; +}; + +} // namespace + +// Verify that this system works correctly in a concurrent environment. +// +// In end-to-end tests, there will be at least two threads, one from client to +// server and vice versa. Thus, it's important for us to make sure that the +// simulation is correct in such setup. 
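+//
+// Each of the kNumThreads threads below creates kNumPairsPerThread binder
+// pairs and issues kNumTransactionsPerPair transactions per pair in a shuffled
+// order; the receive callbacks then check that transactions on the same pair
+// still arrive in FIFO order.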
+TEST_P(FakeBinderTest, StressTest) { + constexpr int kTxCode = 0x4321; + constexpr int kNumThreads = 16; + constexpr int kNumPairsPerThread = 128; + constexpr int kNumTransactionsPerPair = 128; + std::vector args(kNumThreads); + + grpc_core::Mutex mu; + std::vector, std::unique_ptr>>> + global_binder_pairs(kNumThreads); + std::vector> global_cnts( + kNumThreads, std::vector(kNumPairsPerThread, 0)); + + auto th_function = [](void* arg) { + ThreadArgument* th_arg = static_cast(arg); + int tid = th_arg->tid; + std::vector, + std::unique_ptr>> + binder_pairs; + for (int p = 0; p < th_arg->num_pairs_per_thread; ++p) { + std::unique_ptr binder; + std::unique_ptr tx_receiver; + int expected_tx_code = th_arg->tx_code; + std::vector>* cnt = th_arg->global_cnts; + std::tie(binder, tx_receiver) = + NewBinderPair([tid, p, cnt, expected_tx_code]( + transaction_code_t tx_code, ReadableParcel* parcel, + int /*uid*/) mutable { + EXPECT_EQ(tx_code, expected_tx_code); + int value; + EXPECT_TRUE(parcel->ReadInt32(&value).ok()); + EXPECT_EQ(tid, value); + EXPECT_TRUE(parcel->ReadInt32(&value).ok()); + EXPECT_EQ(p, value); + EXPECT_TRUE(parcel->ReadInt32(&value).ok()); + EXPECT_EQ((*cnt)[tid][p], value); + (*cnt)[tid][p]++; + return absl::OkStatus(); + }); + binder_pairs.emplace_back(std::move(binder), std::move(tx_receiver)); + } + std::vector order; + for (int i = 0; i < th_arg->num_pairs_per_thread; ++i) { + for (int j = 0; j < th_arg->num_transactions_per_pair; ++j) { + order.emplace_back(i); + } + } + std::mt19937 rng(tid); + std::shuffle(order.begin(), order.end(), rng); + std::vector tx_cnt(th_arg->num_pairs_per_thread); + for (int p : order) { + EXPECT_TRUE(binder_pairs[p].first->PrepareTransaction().ok()); + WritableParcel* parcel = binder_pairs[p].first->GetWritableParcel(); + EXPECT_TRUE(parcel->WriteInt32(th_arg->tid).ok()); + EXPECT_TRUE(parcel->WriteInt32(p).ok()); + EXPECT_TRUE(parcel->WriteInt32(tx_cnt[p]++).ok()); + EXPECT_TRUE(binder_pairs[p] + .first->Transact(BinderTransportTxCode(th_arg->tx_code)) + .ok()); + } + th_arg->mu->Lock(); + (*th_arg->global_binder_pairs)[tid] = std::move(binder_pairs); + th_arg->mu->Unlock(); + }; + + std::vector thrs(kNumThreads); + std::vector thr_names(kNumThreads); + for (int i = 0; i < kNumThreads; ++i) { + args[i].tid = i; + args[i].global_binder_pairs = &global_binder_pairs; + args[i].global_cnts = &global_cnts; + args[i].tx_code = kTxCode; + args[i].num_pairs_per_thread = kNumPairsPerThread; + args[i].num_transactions_per_pair = kNumTransactionsPerPair; + args[i].mu = μ + thr_names[i] = absl::StrFormat("thread-%d", i); + thrs[i] = grpc_core::Thread(thr_names[i].c_str(), th_function, &args[i]); + } + for (auto& th : thrs) th.Start(); + for (auto& th : thrs) th.Join(); + g_transaction_processor->Terminate(); +} + +INSTANTIATE_TEST_SUITE_P(FakeBinderTestWithDifferentDelayTimes, FakeBinderTest, + testing::Values(absl::ZeroDuration(), + absl::Nanoseconds(10), + absl::Microseconds(10))); + +} // namespace end2end_testing +} // namespace grpc_binder + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/binder/end2end/testing_channel_create.cc b/test/core/transport/binder/end2end/testing_channel_create.cc new file mode 100644 index 00000000..14ab6ab9 --- /dev/null +++ b/test/core/transport/binder/end2end/testing_channel_create.cc @@ -0,0 +1,132 @@ +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/core/transport/binder/end2end/testing_channel_create.h" + +#include + +#include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h" +#include "src/core/ext/transport/binder/transport/binder_transport.h" +#include "src/core/ext/transport/binder/wire_format/wire_reader_impl.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/error_utils.h" + +namespace grpc_binder { +namespace end2end_testing { + +namespace { +// Since we assume the first half of the transport setup is completed before the +// server side enters WireReader::SetupTransport, we need this helper to wait +// and finish that part of the negotiation for us. +class ServerSetupTransportHelper { + public: + ServerSetupTransportHelper() + : wire_reader_(absl::make_unique( + /*transport_stream_receiver=*/nullptr, /*is_client=*/false, + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>())) { + std::tie(endpoint_binder_, tx_receiver_) = NewBinderPair( + [this](transaction_code_t tx_code, ReadableParcel* parcel, int uid) { + return this->wire_reader_->ProcessTransaction(tx_code, parcel, uid); + }); + } + std::unique_ptr WaitForClientBinder() { + return wire_reader_->RecvSetupTransport(); + } + + std::unique_ptr GetEndpointBinderForClient() { + return std::move(endpoint_binder_); + } + + private: + std::unique_ptr wire_reader_; + // The endpoint binder for client. 
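+  // Handed out once via GetEndpointBinderForClient().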
+ std::unique_ptr endpoint_binder_; + std::unique_ptr tx_receiver_; +}; +} // namespace + +std::pair +CreateClientServerBindersPairForTesting() { + ServerSetupTransportHelper helper; + std::unique_ptr endpoint_binder = helper.GetEndpointBinderForClient(); + grpc_transport* client_transport = nullptr; + + struct ThreadArgs { + std::unique_ptr endpoint_binder; + grpc_transport** client_transport; + } args; + + args.endpoint_binder = std::move(endpoint_binder); + args.client_transport = &client_transport; + + grpc_core::Thread client_thread( + "client-thread", + [](void* arg) { + ThreadArgs* args = static_cast(arg); + std::unique_ptr endpoint_binder = + std::move(args->endpoint_binder); + *args->client_transport = grpc_create_binder_transport_client( + std::move(endpoint_binder), + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>()); + }, + &args); + client_thread.Start(); + grpc_transport* server_transport = grpc_create_binder_transport_server( + helper.WaitForClientBinder(), + std::make_shared()); + client_thread.Join(); + return std::make_pair(client_transport, server_transport); +} + +std::shared_ptr BinderChannelForTesting( + grpc::Server* server, const grpc::ChannelArguments& args) { + grpc_channel_args channel_args = args.c_channel_args(); + return grpc::CreateChannelInternal( + "", + grpc_binder_channel_create_for_testing(server->c_server(), &channel_args, + nullptr), + std::vector>()); +} + +} // namespace end2end_testing +} // namespace grpc_binder + +grpc_channel* grpc_binder_channel_create_for_testing(grpc_server* server, + grpc_channel_args* args, + void* /*reserved*/) { + grpc_core::ExecCtx exec_ctx; + + grpc_arg default_authority_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast("test.authority")); + grpc_channel_args* client_args = + grpc_channel_args_copy_and_add(args, &default_authority_arg, 1); + + grpc_transport *client_transport, *server_transport; + std::tie(client_transport, server_transport) = + grpc_binder::end2end_testing::CreateClientServerBindersPairForTesting(); + grpc_error_handle error = server->core_server->SetupTransport( + server_transport, nullptr, args, nullptr); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_channel* channel = + grpc_channel_create("binder", client_args, GRPC_CLIENT_DIRECT_CHANNEL, + client_transport, nullptr, 0, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_channel_args_destroy(client_args); + return channel; +} diff --git a/test/core/transport/binder/endpoint_binder_pool_test.cc b/test/core/transport/binder/endpoint_binder_pool_test.cc new file mode 100644 index 00000000..e0f671a3 --- /dev/null +++ b/test/core/transport/binder/endpoint_binder_pool_test.cc @@ -0,0 +1,76 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
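+
+// EndpointBinderPool pairs AddEndpointBinder()/GetEndpointBinder() calls by
+// their string key and runs the Get callback once both sides have arrived, in
+// either order (and never, if the Add never happens). A rough sketch of the
+// behaviour exercised by the tests below, where `endpoint_binder` stands for
+// whatever std::unique_ptr binder the caller owns and the callback parameter
+// type is elided with `auto`:
+//
+//   grpc_binder::EndpointBinderPool pool;
+//   pool.GetEndpointBinder(
+//       "conn", [](auto binder) { /* runs only after the Add below */ });
+//   pool.AddEndpointBinder("conn", std::move(endpoint_binder));  // callback runs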
+ +#include "src/core/ext/transport/binder/client/endpoint_binder_pool.h" + +#include +#include +#include +#include + +#include +#include + +#include "absl/memory/memory.h" + +#include "src/core/ext/transport/binder/client/endpoint_binder_pool.h" +#include "test/core/transport/binder/mock_objects.h" +#include "test/core/util/test_config.h" + +namespace grpc_binder { + +class CallbackChecker { + public: + MOCK_METHOD(void, Cb, (std::unique_ptr), ()); +}; + +TEST(EndpointBinderPoolTest, AddBeforeGet) { + EndpointBinderPool pool; + auto b = absl::make_unique(); + CallbackChecker cc; + pool.AddEndpointBinder("test", std::move(b)); + // TODO(mingcl): Use pointer matcher to verify it is `b` being passed back + // here. It is only available in newer gtest version + EXPECT_CALL(cc, Cb(testing::_)); + pool.GetEndpointBinder( + "test", std::bind(&CallbackChecker::Cb, &cc, std::placeholders::_1)); +} + +TEST(EndpointBinderPoolTest, GetBeforeAdd) { + EndpointBinderPool pool; + auto b = absl::make_unique(); + CallbackChecker cc; + EXPECT_CALL(cc, Cb(testing::_)).Times(0); + pool.GetEndpointBinder( + "test", std::bind(&CallbackChecker::Cb, &cc, std::placeholders::_1)); + EXPECT_CALL(cc, Cb(testing::_)).Times(1); + pool.AddEndpointBinder("test", std::move(b)); +} + +TEST(EndpointBinderPoolTest, ExpectNotCalled) { + EndpointBinderPool pool; + auto b = absl::make_unique(); + CallbackChecker cc; + EXPECT_CALL(cc, Cb(testing::_)).Times(0); + pool.GetEndpointBinder( + "test", std::bind(&CallbackChecker::Cb, &cc, std::placeholders::_1)); +} + +} // namespace grpc_binder + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/binder/mock_objects.cc b/test/core/transport/binder/mock_objects.cc new file mode 100644 index 00000000..4fb03555 --- /dev/null +++ b/test/core/transport/binder/mock_objects.cc @@ -0,0 +1,55 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "test/core/transport/binder/mock_objects.h" + +#include + +#include "absl/memory/memory.h" + +namespace grpc_binder { + +using ::testing::Return; + +MockReadableParcel::MockReadableParcel() { + ON_CALL(*this, ReadBinder).WillByDefault([](std::unique_ptr* binder) { + *binder = absl::make_unique(); + return absl::OkStatus(); + }); + ON_CALL(*this, ReadInt32).WillByDefault(Return(absl::OkStatus())); + ON_CALL(*this, ReadByteArray).WillByDefault(Return(absl::OkStatus())); + ON_CALL(*this, ReadString).WillByDefault(Return(absl::OkStatus())); +} + +MockWritableParcel::MockWritableParcel() { + ON_CALL(*this, WriteInt32).WillByDefault(Return(absl::OkStatus())); + ON_CALL(*this, WriteBinder).WillByDefault(Return(absl::OkStatus())); + ON_CALL(*this, WriteString).WillByDefault(Return(absl::OkStatus())); + ON_CALL(*this, WriteByteArray).WillByDefault(Return(absl::OkStatus())); +} + +MockBinder::MockBinder() { + ON_CALL(*this, PrepareTransaction).WillByDefault(Return(absl::OkStatus())); + ON_CALL(*this, Transact).WillByDefault(Return(absl::OkStatus())); + ON_CALL(*this, GetWritableParcel).WillByDefault(Return(&mock_input_)); + ON_CALL(*this, ConstructTxReceiver) + .WillByDefault( + [this](grpc_core::RefCountedPtr /*wire_reader_ref*/, + TransactionReceiver::OnTransactCb cb) { + return absl::make_unique( + cb, BinderTransportTxCode::SETUP_TRANSPORT, &mock_output_); + }); +} + +} // namespace grpc_binder diff --git a/test/core/transport/binder/transport_stream_receiver_test.cc b/test/core/transport/binder/transport_stream_receiver_test.cc new file mode 100644 index 00000000..134a27a9 --- /dev/null +++ b/test/core/transport/binder/transport_stream_receiver_test.cc @@ -0,0 +1,288 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include +#include + +#include "absl/memory/memory.h" + +#include "src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h" +#include "test/core/util/test_config.h" + +namespace grpc_binder { +namespace { + +// TODO(waynetu): These are hacks to make callbacks aware of their stream IDs +// and sequence numbers. Remove/Refactor these hacks when possible. 
+template +std::pair Decode(const T& /*data*/) { + assert(false && "This should not be called"); + return {}; +} + +template <> +std::pair Decode(const std::string& data) { + assert(data.size() == sizeof(StreamIdentifier) + sizeof(int)); + StreamIdentifier id{}; + int seq_num{}; + std::memcpy(&id, data.data(), sizeof(StreamIdentifier)); + std::memcpy(&seq_num, data.data() + sizeof(StreamIdentifier), sizeof(int)); + return std::make_pair(id, seq_num); +} + +template <> +std::pair Decode(const Metadata& data) { + assert(data.size() == 1); + const std::string& encoding = data[0].first; + return Decode(encoding); +} + +template +T Encode(StreamIdentifier /*id*/, int /*seq_num*/) { + assert(false && "This should not be called"); + return {}; +} + +template <> +std::string Encode(StreamIdentifier id, int seq_num) { + char result[sizeof(StreamIdentifier) + sizeof(int)]; + std::memcpy(result, &id, sizeof(StreamIdentifier)); + std::memcpy(result + sizeof(StreamIdentifier), &seq_num, sizeof(int)); + return std::string(result, sizeof(StreamIdentifier) + sizeof(int)); +} + +template <> +Metadata Encode(StreamIdentifier id, int seq_num) { + return {{Encode(id, seq_num), ""}}; +} + +MATCHER_P2(StreamIdAndSeqNumMatch, id, seq_num, "") { + auto p = Decode(arg.value()); + return p.first == id && p.second == seq_num; +} + +// MockCallback is used to verify the every callback passed to transaction +// receiver will eventually be invoked with the artifact of its corresponding +// binder transaction. +template +class MockCallback { + public: + explicit MockCallback(StreamIdentifier id, int seq_num) + : id_(id), seq_num_(seq_num) {} + + MOCK_METHOD(void, ActualCallback, (FirstArg), ()); + + std::function GetHandle() { + return [this](FirstArg first_arg, TrailingArgs...) { + this->ActualCallback(first_arg); + }; + } + + void ExpectCallbackInvocation() { + EXPECT_CALL(*this, ActualCallback(StreamIdAndSeqNumMatch(id_, seq_num_))); + } + + private: + StreamIdentifier id_; + int seq_num_; +}; + +using MockInitialMetadataCallback = MockCallback>; +using MockMessageCallback = MockCallback>; +using MockTrailingMetadataCallback = + MockCallback, int>; + +class MockOpBatch { + public: + MockOpBatch(StreamIdentifier id, int flag, int seq_num) + : id_(id), flag_(flag), seq_num_(seq_num) { + if (flag_ & kFlagPrefix) { + initial_metadata_callback_ = + absl::make_unique(id_, seq_num_); + } + if (flag_ & kFlagMessageData) { + message_callback_ = absl::make_unique(id_, seq_num_); + } + if (flag_ & kFlagSuffix) { + trailing_metadata_callback_ = + absl::make_unique(id_, seq_num_); + } + } + + void Complete(TransportStreamReceiver& receiver) { + if (flag_ & kFlagPrefix) { + initial_metadata_callback_->ExpectCallbackInvocation(); + receiver.NotifyRecvInitialMetadata(id_, Encode(id_, seq_num_)); + } + if (flag_ & kFlagMessageData) { + message_callback_->ExpectCallbackInvocation(); + receiver.NotifyRecvMessage(id_, Encode(id_, seq_num_)); + } + if (flag_ & kFlagSuffix) { + trailing_metadata_callback_->ExpectCallbackInvocation(); + receiver.NotifyRecvTrailingMetadata(id_, Encode(id_, seq_num_), + 0); + } + } + + void RequestRecv(TransportStreamReceiver& receiver) { + if (flag_ & kFlagPrefix) { + receiver.RegisterRecvInitialMetadata( + id_, initial_metadata_callback_->GetHandle()); + } + if (flag_ & kFlagMessageData) { + receiver.RegisterRecvMessage(id_, message_callback_->GetHandle()); + } + if (flag_ & kFlagSuffix) { + receiver.RegisterRecvTrailingMetadata( + id_, trailing_metadata_callback_->GetHandle()); + } + } + + MockOpBatch 
NextBatch(int flag) const { + return MockOpBatch(id_, flag, seq_num_ + 1); + } + + private: + std::unique_ptr initial_metadata_callback_; + std::unique_ptr message_callback_; + std::unique_ptr trailing_metadata_callback_; + int id_, flag_, seq_num_; +}; + +class TransportStreamReceiverTest : public ::testing::Test { + protected: + MockOpBatch NewGrpcStream(int flag) { + return MockOpBatch(current_id_++, flag, 0); + } + + StreamIdentifier current_id_ = 0; +}; + +const int kFlagAll = kFlagPrefix | kFlagMessageData | kFlagSuffix; + +} // namespace + +TEST_F(TransportStreamReceiverTest, MultipleStreamRequestThenComplete) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagAll); + t0.RequestRecv(receiver); + t0.Complete(receiver); +} + +TEST_F(TransportStreamReceiverTest, MultipleStreamCompleteThenRequest) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagAll); + t0.Complete(receiver); + t0.RequestRecv(receiver); +} + +TEST_F(TransportStreamReceiverTest, MultipleStreamInterleaved) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagAll); + MockOpBatch t1 = NewGrpcStream(kFlagAll); + t1.Complete(receiver); + t0.Complete(receiver); + t0.RequestRecv(receiver); + t1.RequestRecv(receiver); +} + +TEST_F(TransportStreamReceiverTest, MultipleStreamInterleavedReversed) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagAll); + MockOpBatch t1 = NewGrpcStream(kFlagAll); + t0.RequestRecv(receiver); + t1.RequestRecv(receiver); + t1.Complete(receiver); + t0.Complete(receiver); +} + +TEST_F(TransportStreamReceiverTest, MultipleStreamMoreInterleaved) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagAll); + MockOpBatch t1 = NewGrpcStream(kFlagAll); + t0.RequestRecv(receiver); + t1.Complete(receiver); + MockOpBatch t2 = NewGrpcStream(kFlagAll); + t2.RequestRecv(receiver); + t0.Complete(receiver); + t1.RequestRecv(receiver); + t2.Complete(receiver); +} + +TEST_F(TransportStreamReceiverTest, SingleStreamUnaryCall) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagPrefix); + MockOpBatch t1 = t0.NextBatch(kFlagMessageData); + MockOpBatch t2 = t1.NextBatch(kFlagSuffix); + t0.RequestRecv(receiver); + t1.RequestRecv(receiver); + t2.RequestRecv(receiver); + t0.Complete(receiver); + t1.Complete(receiver); + t2.Complete(receiver); +} + +TEST_F(TransportStreamReceiverTest, SingleStreamStreamingCall) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagPrefix); + t0.RequestRecv(receiver); + t0.Complete(receiver); + MockOpBatch t1 = t0.NextBatch(kFlagMessageData); + t1.Complete(receiver); + t1.RequestRecv(receiver); + MockOpBatch t2 = t1.NextBatch(kFlagMessageData); + t2.RequestRecv(receiver); + t2.Complete(receiver); + MockOpBatch t3 = t2.NextBatch(kFlagMessageData); + MockOpBatch t4 = t3.NextBatch(kFlagMessageData); + t3.Complete(receiver); + t4.Complete(receiver); + t3.RequestRecv(receiver); + t4.RequestRecv(receiver); +} + +TEST_F(TransportStreamReceiverTest, DISABLED_SingleStreamBufferedCallbacks) { + TransportStreamReceiverImpl receiver(/*is_client=*/true); + MockOpBatch t0 = NewGrpcStream(kFlagPrefix); + MockOpBatch t1 = t0.NextBatch(kFlagMessageData); + MockOpBatch t2 = t1.NextBatch(kFlagMessageData); + MockOpBatch t3 = t2.NextBatch(kFlagSuffix); + t0.RequestRecv(receiver); + // 
TODO(waynetu): Can gRPC issues recv_message before it actually receives the + // previous one? + t1.RequestRecv(receiver); + t2.RequestRecv(receiver); + t3.RequestRecv(receiver); + t0.Complete(receiver); + t1.Complete(receiver); + t2.Complete(receiver); + t3.Complete(receiver); +} + +// TODO(waynetu): Should we have some concurrent stress tests to make sure that +// thread safety is well taken care of? + +} // namespace grpc_binder + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/binder/wire_reader_test.cc b/test/core/transport/binder/wire_reader_test.cc new file mode 100644 index 00000000..5d48d036 --- /dev/null +++ b/test/core/transport/binder/wire_reader_test.cc @@ -0,0 +1,326 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Unit tests for WireReaderImpl. +// +// WireReaderImpl is responsible for turning incoming transactions into +// top-level metadata. The following tests verify that the interactions between +// WireReaderImpl and both the output (readable) parcel and the transport stream +// receiver are correct in all possible situations. +#include +#include +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include "src/core/ext/transport/binder/security_policy/untrusted_security_policy.h" +#include "src/core/ext/transport/binder/wire_format/wire_reader_impl.h" +#include "test/core/transport/binder/mock_objects.h" +#include "test/core/util/test_config.h" + +namespace grpc_binder { + +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SetArgPointee; +using ::testing::StrictMock; + +namespace { + +class WireReaderTest : public ::testing::Test { + public: + WireReaderTest() + : transport_stream_receiver_( + std::make_shared>()), + wire_reader_( + transport_stream_receiver_, /*is_client=*/true, + std::make_shared< + grpc::experimental::binder::UntrustedSecurityPolicy>()) {} + + protected: + void ExpectReadInt32(int result) { + EXPECT_CALL(mock_readable_parcel_, ReadInt32) + .WillOnce(DoAll(SetArgPointee<0>(result), Return(absl::OkStatus()))); + } + + void ExpectReadByteArray(const std::string& buffer) { + ExpectReadInt32(buffer.length()); + if (!buffer.empty()) { + EXPECT_CALL(mock_readable_parcel_, ReadByteArray) + .WillOnce([buffer](std::string* data) { + *data = buffer; + return absl::OkStatus(); + }); + } + } + + void UnblockSetupTransport() { + // SETUP_TRANSPORT should finish before we can proceed with any other + // requests and streaming calls. The MockBinder will construct a + // MockTransactionReceiver, which will then sends SETUP_TRANSPORT request + // back to us. 
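+    // Each test below therefore calls UnblockSetupTransport() once up-front,
+    // so that the subsequent ProcessTransaction() calls run against a wire
+    // reader whose transport setup has already completed.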
+ wire_reader_.SetupTransport(absl::make_unique()); + } + + template + absl::Status CallProcessTransaction(T tx_code) { + return wire_reader_.ProcessTransaction( + static_cast(tx_code), &mock_readable_parcel_, + /*uid=*/0); + } + + std::shared_ptr> + transport_stream_receiver_; + WireReaderImpl wire_reader_; + MockReadableParcel mock_readable_parcel_; +}; + +MATCHER_P(StatusOrStrEq, target, "") { + if (!arg.ok()) return false; + return arg.value() == target; +} + +MATCHER_P(StatusOrContainerEq, target, "") { + if (!arg.ok()) return false; + return arg.value() == target; +} + +} // namespace + +TEST_F(WireReaderTest, SetupTransport) { + auto mock_binder = absl::make_unique(); + MockBinder& mock_binder_ref = *mock_binder; + + ::testing::InSequence sequence; + EXPECT_CALL(mock_binder_ref, Initialize); + EXPECT_CALL(mock_binder_ref, PrepareTransaction); + const MockReadableParcel mock_readable_parcel; + EXPECT_CALL(mock_binder_ref, GetWritableParcel); + + // Write version. + EXPECT_CALL(mock_binder_ref.GetWriter(), WriteInt32(77)); + + // The transaction receiver immediately informs the wire writer that the + // transport has been successfully set up. + EXPECT_CALL(mock_binder_ref, ConstructTxReceiver); + + EXPECT_CALL(mock_binder_ref.GetReader(), ReadInt32); + EXPECT_CALL(mock_binder_ref.GetReader(), ReadBinder); + + // Write transaction receiver. + EXPECT_CALL(mock_binder_ref.GetWriter(), WriteBinder); + // Perform transaction. + EXPECT_CALL(mock_binder_ref, Transact); + + wire_reader_.SetupTransport(std::move(mock_binder)); +} + +TEST_F(WireReaderTest, ProcessTransactionControlMessageSetupTransport) { + ::testing::InSequence sequence; + UnblockSetupTransport(); +} + +TEST_F(WireReaderTest, ProcessTransactionControlMessagePingResponse) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + EXPECT_CALL(mock_readable_parcel_, ReadInt32); + EXPECT_TRUE( + CallProcessTransaction(BinderTransportTxCode::PING_RESPONSE).ok()); +} + +TEST_F(WireReaderTest, ProcessTransactionServerRpcDataEmptyFlagIgnored) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + // first transaction: empty flag + ExpectReadInt32(0); + // Won't further read sequence number. + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +TEST_F(WireReaderTest, + ProcessTransactionServerRpcDataFlagPrefixWithoutMetadata) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + // flag + ExpectReadInt32(kFlagPrefix); + // sequence number + ExpectReadInt32(0); + + // count + ExpectReadInt32(0); + EXPECT_CALL( + *transport_stream_receiver_, + NotifyRecvInitialMetadata(kFirstCallId, StatusOrContainerEq(Metadata{}))); + + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +TEST_F(WireReaderTest, ProcessTransactionServerRpcDataFlagPrefixWithMetadata) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + // flag + ExpectReadInt32(kFlagPrefix); + // sequence number + ExpectReadInt32(0); + + const std::vector> kMetadata = { + {"", ""}, + {"", "value"}, + {"key", ""}, + {"key", "value"}, + {"another-key", "another-value"}, + }; + + // count + ExpectReadInt32(kMetadata.size()); + for (const auto& md : kMetadata) { + // metadata key + ExpectReadByteArray(md.first); + // metadata val + // TODO(waynetu): metadata value can also be "parcelable". 
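+    // As exercised here, each metadata entry arrives as two length-prefixed
+    // byte arrays: the key first, then the value.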
+ ExpectReadByteArray(md.second); + } + EXPECT_CALL( + *transport_stream_receiver_, + NotifyRecvInitialMetadata(kFirstCallId, StatusOrContainerEq(kMetadata))); + + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +TEST_F(WireReaderTest, ProcessTransactionServerRpcDataFlagMessageDataNonEmpty) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + // flag + ExpectReadInt32(kFlagMessageData); + // sequence number + ExpectReadInt32(0); + + // message data + // TODO(waynetu): message data can also be "parcelable". + const std::string kMessageData = "message data"; + ExpectReadByteArray(kMessageData); + EXPECT_CALL(*transport_stream_receiver_, + NotifyRecvMessage(kFirstCallId, StatusOrStrEq(kMessageData))); + + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +TEST_F(WireReaderTest, ProcessTransactionServerRpcDataFlagMessageDataEmpty) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + // flag + ExpectReadInt32(kFlagMessageData); + // sequence number + ExpectReadInt32(0); + + // message data + // TODO(waynetu): message data can also be "parcelable". + const std::string kMessageData = ""; + ExpectReadByteArray(kMessageData); + EXPECT_CALL(*transport_stream_receiver_, + NotifyRecvMessage(kFirstCallId, StatusOrStrEq(kMessageData))); + + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +TEST_F(WireReaderTest, ProcessTransactionServerRpcDataFlagSuffixWithStatus) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + constexpr int kStatus = 0x1234; + // flag + ExpectReadInt32(kFlagSuffix | kFlagStatusDescription | (kStatus << 16)); + // sequence number + ExpectReadInt32(0); + // status description + EXPECT_CALL(mock_readable_parcel_, ReadString); + // metadata count + ExpectReadInt32(0); + EXPECT_CALL(*transport_stream_receiver_, + NotifyRecvTrailingMetadata( + kFirstCallId, StatusOrContainerEq(Metadata{}), kStatus)); + + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +TEST_F(WireReaderTest, ProcessTransactionServerRpcDataFlagSuffixWithoutStatus) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + // flag + ExpectReadInt32(kFlagSuffix); + // sequence number + ExpectReadInt32(0); + // No status description + // metadata count + ExpectReadInt32(0); + EXPECT_CALL(*transport_stream_receiver_, + NotifyRecvTrailingMetadata(kFirstCallId, + StatusOrContainerEq(Metadata{}), 0)); + + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +TEST_F(WireReaderTest, InBoundFlowControl) { + ::testing::InSequence sequence; + UnblockSetupTransport(); + + // data size + EXPECT_CALL(mock_readable_parcel_, GetDataSize).WillOnce(Return(1000)); + // flag + ExpectReadInt32(kFlagMessageData | kFlagMessageDataIsPartial); + // sequence number + ExpectReadInt32(0); + // message size + ExpectReadInt32(1000); + EXPECT_CALL(mock_readable_parcel_, ReadByteArray) + .WillOnce(DoAll(SetArgPointee<0>(std::string(1000, 'a')), + Return(absl::OkStatus()))); + + // Data is not completed. No callback will be triggered. 
+ EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); + + EXPECT_CALL(mock_readable_parcel_, GetDataSize).WillOnce(Return(1000)); + // flag + ExpectReadInt32(kFlagMessageData); + // sequence number + ExpectReadInt32(1); + // message size + ExpectReadInt32(1000); + EXPECT_CALL(mock_readable_parcel_, ReadByteArray) + .WillOnce(DoAll(SetArgPointee<0>(std::string(1000, 'b')), + Return(absl::OkStatus()))); + + EXPECT_CALL(*transport_stream_receiver_, + NotifyRecvMessage(kFirstCallId, + StatusOrContainerEq(std::string(1000, 'a') + + std::string(1000, 'b')))); + EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok()); +} + +} // namespace grpc_binder + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/binder/wire_writer_test.cc b/test/core/transport/binder/wire_writer_test.cc new file mode 100644 index 00000000..7b320695 --- /dev/null +++ b/test/core/transport/binder/wire_writer_test.cc @@ -0,0 +1,251 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/ext/transport/binder/wire_format/wire_writer.h" + +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include "test/core/transport/binder/mock_objects.h" +#include "test/core/util/test_config.h" + +namespace grpc_binder { + +using ::testing::Return; + +MATCHER_P(StrEqInt8Ptr, target, "") { + return std::string(reinterpret_cast(arg), target.size()) == + target; +} + +TEST(WireWriterTest, RpcCall) { + auto mock_binder = absl::make_unique(); + MockBinder& mock_binder_ref = *mock_binder; + MockWritableParcel mock_writable_parcel; + ON_CALL(mock_binder_ref, GetWritableParcel) + .WillByDefault(Return(&mock_writable_parcel)); + WireWriterImpl wire_writer(std::move(mock_binder)); + + auto ExpectWriteByteArray = [&](const std::string& target) { + // length + EXPECT_CALL(mock_writable_parcel, WriteInt32(target.size())); + if (!target.empty()) { + // content + EXPECT_CALL(mock_writable_parcel, + WriteByteArray(StrEqInt8Ptr(target), target.size())); + } + }; + + ::testing::InSequence sequence; + int sequence_number = 0; + + { + // flag + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + // sequence number + EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); + + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); + + Transaction tx(kFirstCallId, /*is_client=*/true); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + sequence_number++; + } + { + // flag + EXPECT_CALL(mock_writable_parcel, WriteInt32(kFlagPrefix)); + // sequence number. This is another stream so the sequence number starts + // with 0. 
+ EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + + EXPECT_CALL(mock_writable_parcel, + WriteString(absl::string_view("/example/method/ref"))); + + const std::vector> kMetadata = { + {"", ""}, + {"", "value"}, + {"key", ""}, + {"key", "value"}, + {"another-key", "another-value"}, + }; + + // Number of metadata + EXPECT_CALL(mock_writable_parcel, WriteInt32(kMetadata.size())); + + for (const auto& md : kMetadata) { + ExpectWriteByteArray(md.first); + ExpectWriteByteArray(md.second); + } + + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 1))); + + Transaction tx(kFirstCallId + 1, /*is_client=*/true); + tx.SetPrefix(kMetadata); + tx.SetMethodRef("/example/method/ref"); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + } + { + // flag + EXPECT_CALL(mock_writable_parcel, WriteInt32(kFlagMessageData)); + // sequence number + EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); + + ExpectWriteByteArray("data"); + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); + + Transaction tx(kFirstCallId, /*is_client=*/true); + tx.SetData("data"); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + sequence_number++; + } + { + // flag + EXPECT_CALL(mock_writable_parcel, WriteInt32(kFlagSuffix)); + // sequence number + EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); + + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); + + Transaction tx(kFirstCallId, /*is_client=*/true); + tx.SetSuffix({}); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + sequence_number++; + } + { + // flag + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagPrefix | kFlagMessageData | kFlagSuffix)); + // sequence number + EXPECT_CALL(mock_writable_parcel, WriteInt32(sequence_number)); + + EXPECT_CALL(mock_writable_parcel, + WriteString(absl::string_view("/example/method/ref"))); + + const std::vector> kMetadata = { + {"", ""}, + {"", "value"}, + {"key", ""}, + {"key", "value"}, + {"another-key", "another-value"}, + }; + + // Number of metadata + EXPECT_CALL(mock_writable_parcel, WriteInt32(kMetadata.size())); + + for (const auto& md : kMetadata) { + ExpectWriteByteArray(md.first); + ExpectWriteByteArray(md.second); + } + + // Empty message data + ExpectWriteByteArray(""); + + EXPECT_CALL(mock_binder_ref, Transact(BinderTransportTxCode(kFirstCallId))); + + Transaction tx(kFirstCallId, /*is_client=*/true); + // TODO(waynetu): Implement a helper function that automatically creates + // EXPECT_CALL based on the tx object. 
+ tx.SetPrefix(kMetadata); + tx.SetMethodRef("/example/method/ref"); + tx.SetData(""); + tx.SetSuffix({}); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + sequence_number++; + } + + // Really large message + { + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_writable_parcel, GetDataSize) + .WillOnce(Return(WireWriterImpl::kBlockSize)); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 2))); + + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(1)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_writable_parcel, GetDataSize) + .WillOnce(Return(WireWriterImpl::kBlockSize)); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 2))); + + EXPECT_CALL(mock_writable_parcel, WriteInt32(kFlagMessageData)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(2)); + ExpectWriteByteArray("a"); + EXPECT_CALL(mock_writable_parcel, GetDataSize).WillOnce(Return(1)); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 2))); + + // Use a new stream. + Transaction tx(kFirstCallId + 2, /*is_client=*/true); + tx.SetData(std::string(2 * WireWriterImpl::kBlockSize + 1, 'a')); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + } + // Really large message with metadata + { + EXPECT_CALL( + mock_writable_parcel, + WriteInt32(kFlagPrefix | kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + EXPECT_CALL(mock_writable_parcel, WriteString(absl::string_view("123"))); + EXPECT_CALL(mock_writable_parcel, WriteInt32(0)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_writable_parcel, GetDataSize) + .WillOnce(Return(WireWriterImpl::kBlockSize)); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 3))); + + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(1)); + ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a')); + EXPECT_CALL(mock_writable_parcel, GetDataSize) + .WillOnce(Return(WireWriterImpl::kBlockSize)); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 3))); + + EXPECT_CALL(mock_writable_parcel, + WriteInt32(kFlagMessageData | kFlagSuffix)); + EXPECT_CALL(mock_writable_parcel, WriteInt32(2)); + ExpectWriteByteArray("a"); + EXPECT_CALL(mock_writable_parcel, GetDataSize).WillOnce(Return(1)); + EXPECT_CALL(mock_binder_ref, + Transact(BinderTransportTxCode(kFirstCallId + 3))); + + // Use a new stream. + Transaction tx(kFirstCallId + 3, /*is_client=*/true); + tx.SetPrefix({}); + tx.SetMethodRef("123"); + tx.SetData(std::string(2 * WireWriterImpl::kBlockSize + 1, 'a')); + tx.SetSuffix({}); + EXPECT_TRUE(wire_writer.RpcCall(tx).ok()); + } +} + +} // namespace grpc_binder + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/byte_stream_test.cc b/test/core/transport/byte_stream_test.cc new file mode 100644 index 00000000..014bb1d0 --- /dev/null +++ b/test/core/transport/byte_stream_test.cc @@ -0,0 +1,254 @@ +/* + * + * Copyright 2017 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/transport/byte_stream.h" + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace { + +// +// SliceBufferByteStream tests +// + +void NotCalledClosure(void* /*arg*/, grpc_error_handle /*error*/) { + GPR_ASSERT(false); +} + +TEST(SliceBufferByteStream, Basic) { + grpc_core::ExecCtx exec_ctx; + // Create and populate slice buffer. + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + grpc_slice input[] = { + grpc_slice_from_static_string("foo"), + grpc_slice_from_static_string("bar"), + }; + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + grpc_slice_buffer_add(&buffer, input[i]); + } + // Create byte stream. + SliceBufferByteStream stream(&buffer, 0); + grpc_slice_buffer_destroy_internal(&buffer); + EXPECT_EQ(6U, stream.length()); + grpc_closure closure; + GRPC_CLOSURE_INIT(&closure, NotCalledClosure, nullptr, + grpc_schedule_on_exec_ctx); + // Read each slice. Note that Next() always returns synchronously. + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + ASSERT_TRUE(stream.Next(~(size_t)0, &closure)); + grpc_slice output; + grpc_error_handle error = stream.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[i], output)); + grpc_slice_unref_internal(output); + } + // Clean up. + stream.Orphan(); +} + +TEST(SliceBufferByteStream, Shutdown) { + grpc_core::ExecCtx exec_ctx; + // Create and populate slice buffer. + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + grpc_slice input[] = { + grpc_slice_from_static_string("foo"), + grpc_slice_from_static_string("bar"), + }; + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + grpc_slice_buffer_add(&buffer, input[i]); + } + // Create byte stream. + SliceBufferByteStream stream(&buffer, 0); + grpc_slice_buffer_destroy_internal(&buffer); + EXPECT_EQ(6U, stream.length()); + grpc_closure closure; + GRPC_CLOSURE_INIT(&closure, NotCalledClosure, nullptr, + grpc_schedule_on_exec_ctx); + // Read the first slice. + ASSERT_TRUE(stream.Next(~(size_t)0, &closure)); + grpc_slice output; + grpc_error_handle error = stream.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[0], output)); + grpc_slice_unref_internal(output); + // Now shutdown. + grpc_error_handle shutdown_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("shutdown error"); + stream.Shutdown(GRPC_ERROR_REF(shutdown_error)); + // After shutdown, the next pull() should return the error. + ASSERT_TRUE(stream.Next(~(size_t)0, &closure)); + error = stream.Pull(&output); + EXPECT_TRUE(error == shutdown_error); + GRPC_ERROR_UNREF(error); + GRPC_ERROR_UNREF(shutdown_error); + // Clean up. 
+ stream.Orphan(); +} + +// +// CachingByteStream tests +// + +TEST(CachingByteStream, Basic) { + grpc_core::ExecCtx exec_ctx; + // Create and populate slice buffer byte stream. + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + grpc_slice input[] = { + grpc_slice_from_static_string("foo"), + grpc_slice_from_static_string("bar"), + }; + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + grpc_slice_buffer_add(&buffer, input[i]); + } + SliceBufferByteStream underlying_stream(&buffer, 0); + grpc_slice_buffer_destroy_internal(&buffer); + // Create cache and caching stream. + ByteStreamCache cache((OrphanablePtr(&underlying_stream))); + ByteStreamCache::CachingByteStream stream(&cache); + grpc_closure closure; + GRPC_CLOSURE_INIT(&closure, NotCalledClosure, nullptr, + grpc_schedule_on_exec_ctx); + // Read each slice. Note that next() always returns synchronously, + // because the underlying byte stream always does. + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + ASSERT_TRUE(stream.Next(~(size_t)0, &closure)); + grpc_slice output; + grpc_error_handle error = stream.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[i], output)); + grpc_slice_unref_internal(output); + } + // Clean up. + stream.Orphan(); + cache.Destroy(); +} + +TEST(CachingByteStream, Reset) { + grpc_core::ExecCtx exec_ctx; + // Create and populate slice buffer byte stream. + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + grpc_slice input[] = { + grpc_slice_from_static_string("foo"), + grpc_slice_from_static_string("bar"), + }; + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + grpc_slice_buffer_add(&buffer, input[i]); + } + SliceBufferByteStream underlying_stream(&buffer, 0); + grpc_slice_buffer_destroy_internal(&buffer); + // Create cache and caching stream. + ByteStreamCache cache((OrphanablePtr(&underlying_stream))); + ByteStreamCache::CachingByteStream stream(&cache); + grpc_closure closure; + GRPC_CLOSURE_INIT(&closure, NotCalledClosure, nullptr, + grpc_schedule_on_exec_ctx); + // Read one slice. + ASSERT_TRUE(stream.Next(~(size_t)0, &closure)); + grpc_slice output; + grpc_error_handle error = stream.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[0], output)); + grpc_slice_unref_internal(output); + // Reset the caching stream. The reads should start over from the + // first slice. + stream.Reset(); + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + ASSERT_TRUE(stream.Next(~(size_t)0, &closure)); + error = stream.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[i], output)); + grpc_slice_unref_internal(output); + } + // Clean up. + stream.Orphan(); + cache.Destroy(); +} + +TEST(CachingByteStream, SharedCache) { + grpc_core::ExecCtx exec_ctx; + // Create and populate slice buffer byte stream. + grpc_slice_buffer buffer; + grpc_slice_buffer_init(&buffer); + grpc_slice input[] = { + grpc_slice_from_static_string("foo"), + grpc_slice_from_static_string("bar"), + }; + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + grpc_slice_buffer_add(&buffer, input[i]); + } + SliceBufferByteStream underlying_stream(&buffer, 0); + grpc_slice_buffer_destroy_internal(&buffer); + // Create cache and two caching streams. 
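+  // Both streams read through the same ByteStreamCache, so slices already
+  // pulled via one stream are replayed from the cache when the other stream
+  // reads them, instead of being pulled from the underlying stream again.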
+ ByteStreamCache cache((OrphanablePtr(&underlying_stream))); + ByteStreamCache::CachingByteStream stream1(&cache); + ByteStreamCache::CachingByteStream stream2(&cache); + grpc_closure closure; + GRPC_CLOSURE_INIT(&closure, NotCalledClosure, nullptr, + grpc_schedule_on_exec_ctx); + // Read one slice from stream1. + EXPECT_TRUE(stream1.Next(~(size_t)0, &closure)); + grpc_slice output; + grpc_error_handle error = stream1.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[0], output)); + grpc_slice_unref_internal(output); + // Read all slices from stream2. + for (size_t i = 0; i < GPR_ARRAY_SIZE(input); ++i) { + EXPECT_TRUE(stream2.Next(~(size_t)0, &closure)); + error = stream2.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[i], output)); + grpc_slice_unref_internal(output); + } + // Now read the second slice from stream1. + EXPECT_TRUE(stream1.Next(~(size_t)0, &closure)); + error = stream1.Pull(&output); + EXPECT_TRUE(error == GRPC_ERROR_NONE); + EXPECT_TRUE(grpc_slice_eq(input[1], output)); + grpc_slice_unref_internal(output); + // Clean up. + stream1.Orphan(); + stream2.Orphan(); + cache.Destroy(); +} + +} // namespace +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int retval = RUN_ALL_TESTS(); + grpc_shutdown(); + return retval; +} diff --git a/test/core/transport/chttp2/alpn_test.cc b/test/core/transport/chttp2/alpn_test.cc new file mode 100644 index 00000000..ffd1f664 --- /dev/null +++ b/test/core/transport/chttp2/alpn_test.cc @@ -0,0 +1,58 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/transport/chttp2/alpn/alpn.h" + +#include + +#include "test/core/util/test_config.h" + +static void test_alpn_success(void) { + GPR_ASSERT(grpc_chttp2_is_alpn_version_supported("h2", 2)); + GPR_ASSERT(grpc_chttp2_is_alpn_version_supported("grpc-exp", 8)); +} + +static void test_alpn_failure(void) { + GPR_ASSERT(!grpc_chttp2_is_alpn_version_supported("h2-155", 6)); + GPR_ASSERT(!grpc_chttp2_is_alpn_version_supported("h1-15", 5)); +} + +// First index in ALPN supported version list of a given protocol. Returns a +// value one beyond the last valid element index if not found. +static size_t alpn_version_index(const char* version, size_t size) { + size_t i; + for (i = 0; i < grpc_chttp2_num_alpn_versions(); ++i) { + if (!strncmp(version, grpc_chttp2_get_alpn_version_index(i), size)) { + return i; + } + } + return i; +} + +static void test_alpn_grpc_before_h2(void) { + // grpc-exp is preferred over h2. 
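+  // i.e. grpc-exp must appear earlier than h2 in the supported ALPN version
+  // list.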
+ GPR_ASSERT(alpn_version_index("grpc-exp", 8) < alpn_version_index("h2", 2)); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_alpn_success(); + test_alpn_failure(); + test_alpn_grpc_before_h2(); + return 0; +} diff --git a/test/core/transport/chttp2/bin_decoder_test.cc b/test/core/transport/chttp2/bin_decoder_test.cc new file mode 100644 index 00000000..3db0fff7 --- /dev/null +++ b/test/core/transport/chttp2/bin_decoder_test.cc @@ -0,0 +1,171 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/transport/chttp2/transport/bin_decoder.h" + +#include + +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/bin_encoder.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "test/core/util/test_config.h" + +static int all_ok = 1; + +static void expect_slice_eq(grpc_slice expected, grpc_slice slice, + const char* debug, int line) { + if (!grpc_slice_eq(slice, expected)) { + char* hs = grpc_dump_slice(slice, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* he = grpc_dump_slice(expected, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_ERROR, "FAILED:%d: %s\ngot: %s\nwant: %s", line, debug, hs, + he); + gpr_free(hs); + gpr_free(he); + all_ok = 0; + } + grpc_slice_unref_internal(expected); + grpc_slice_unref_internal(slice); +} + +static grpc_slice base64_encode(const char* s) { + grpc_slice ss = grpc_slice_from_copied_string(s); + grpc_slice out = grpc_chttp2_base64_encode(ss); + grpc_slice_unref_internal(ss); + return out; +} + +static grpc_slice base64_decode(const char* s) { + grpc_slice ss = grpc_slice_from_copied_string(s); + grpc_slice out = grpc_chttp2_base64_decode(ss); + grpc_slice_unref_internal(ss); + return out; +} + +static grpc_slice base64_decode_with_length(const char* s, + size_t output_length) { + grpc_slice ss = grpc_slice_from_copied_string(s); + grpc_slice out = grpc_chttp2_base64_decode_with_length(ss, output_length); + grpc_slice_unref_internal(ss); + return out; +} + +static size_t base64_infer_length(const char* s) { + grpc_slice ss = grpc_slice_from_copied_string(s); + size_t out = grpc_chttp2_base64_infer_length_after_decode(ss); + grpc_slice_unref_internal(ss); + return out; +} + +#define EXPECT_DECODED_LENGTH(s, expected) \ + GPR_ASSERT((expected) == base64_infer_length((s))); + +#define EXPECT_SLICE_EQ(expected, slice) \ + expect_slice_eq( \ + grpc_slice_from_copied_buffer(expected, sizeof(expected) - 1), slice, \ + #slice, __LINE__); + +#define ENCODE_AND_DECODE(s) \ + EXPECT_SLICE_EQ( \ + s, grpc_chttp2_base64_decode_with_length(base64_encode(s), strlen(s))); + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + + /* ENCODE_AND_DECODE tests grpc_chttp2_base64_decode_with_length(), which + takes 
encoded base64 strings without pad chars, but output length is + required. */ + /* Base64 test vectors from RFC 4648 */ + ENCODE_AND_DECODE(""); + ENCODE_AND_DECODE("f"); + ENCODE_AND_DECODE("foo"); + ENCODE_AND_DECODE("fo"); + ENCODE_AND_DECODE("foob"); + ENCODE_AND_DECODE("fooba"); + ENCODE_AND_DECODE("foobar"); + + ENCODE_AND_DECODE("\xc0\xc1\xc2\xc3\xc4\xc5"); + + /* Base64 test vectors from RFC 4648, with pad chars */ + /* BASE64("") = "" */ + EXPECT_SLICE_EQ("", base64_decode("")); + /* BASE64("f") = "Zg==" */ + EXPECT_SLICE_EQ("f", base64_decode("Zg==")); + /* BASE64("fo") = "Zm8=" */ + EXPECT_SLICE_EQ("fo", base64_decode("Zm8=")); + /* BASE64("foo") = "Zm9v" */ + EXPECT_SLICE_EQ("foo", base64_decode("Zm9v")); + /* BASE64("foob") = "Zm9vYg==" */ + EXPECT_SLICE_EQ("foob", base64_decode("Zm9vYg==")); + /* BASE64("fooba") = "Zm9vYmE=" */ + EXPECT_SLICE_EQ("fooba", base64_decode("Zm9vYmE=")); + /* BASE64("foobar") = "Zm9vYmFy" */ + EXPECT_SLICE_EQ("foobar", base64_decode("Zm9vYmFy")); + + EXPECT_SLICE_EQ("\xc0\xc1\xc2\xc3\xc4\xc5", base64_decode("wMHCw8TF")); + + // Test illegal input length in grpc_chttp2_base64_decode + EXPECT_SLICE_EQ("", base64_decode("a")); + EXPECT_SLICE_EQ("", base64_decode("ab")); + EXPECT_SLICE_EQ("", base64_decode("abc")); + + // Test illegal charactors in grpc_chttp2_base64_decode + EXPECT_SLICE_EQ("", base64_decode("Zm:v")); + EXPECT_SLICE_EQ("", base64_decode("Zm=v")); + + // Test output_length longer than max possible output length in + // grpc_chttp2_base64_decode_with_length + EXPECT_SLICE_EQ("", base64_decode_with_length("Zg", 2)); + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm8", 3)); + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm9v", 4)); + + // Test illegal charactors in grpc_chttp2_base64_decode_with_length + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm:v", 3)); + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm=v", 3)); + + EXPECT_DECODED_LENGTH("", 0); + EXPECT_DECODED_LENGTH("ab", 1); + EXPECT_DECODED_LENGTH("abc", 2); + EXPECT_DECODED_LENGTH("abcd", 3); + EXPECT_DECODED_LENGTH("abcdef", 4); + EXPECT_DECODED_LENGTH("abcdefg", 5); + EXPECT_DECODED_LENGTH("abcdefgh", 6); + + EXPECT_DECODED_LENGTH("ab==", 1); + EXPECT_DECODED_LENGTH("abc=", 2); + EXPECT_DECODED_LENGTH("abcd", 3); + EXPECT_DECODED_LENGTH("abcdef==", 4); + EXPECT_DECODED_LENGTH("abcdefg=", 5); + EXPECT_DECODED_LENGTH("abcdefgh", 6); + + EXPECT_DECODED_LENGTH("a", 0); + EXPECT_DECODED_LENGTH("a===", 0); + EXPECT_DECODED_LENGTH("abcde", 0); + EXPECT_DECODED_LENGTH("abcde===", 0); + } + grpc_shutdown(); + return all_ok ? 0 : 1; +} diff --git a/test/core/transport/chttp2/bin_encoder_test.cc b/test/core/transport/chttp2/bin_encoder_test.cc new file mode 100644 index 00000000..6cf53be7 --- /dev/null +++ b/test/core/transport/chttp2/bin_encoder_test.cc @@ -0,0 +1,179 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/ext/transport/chttp2/transport/bin_encoder.h" + +#include + +/* This is here for grpc_is_binary_header + * TODO(murgatroid99): Remove this + */ +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "test/core/util/test_config.h" + +static int all_ok = 1; + +static void expect_slice_eq(grpc_slice expected, grpc_slice slice, + const char* debug, int line) { + if (!grpc_slice_eq(slice, expected)) { + char* hs = grpc_dump_slice(slice, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* he = grpc_dump_slice(expected, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_ERROR, "FAILED:%d: %s\ngot: %s\nwant: %s", line, debug, hs, + he); + gpr_free(hs); + gpr_free(he); + all_ok = 0; + } + grpc_slice_unref(expected); + grpc_slice_unref(slice); +} + +static grpc_slice B64(const char* s) { + grpc_slice ss = grpc_slice_from_copied_string(s); + grpc_slice out = grpc_chttp2_base64_encode(ss); + grpc_slice_unref(ss); + return out; +} + +static grpc_slice HUFF(const char* s) { + grpc_slice ss = grpc_slice_from_copied_string(s); + grpc_slice out = grpc_chttp2_huffman_compress(ss); + grpc_slice_unref(ss); + return out; +} + +#define EXPECT_SLICE_EQ(expected, slice) \ + expect_slice_eq( \ + grpc_slice_from_copied_buffer(expected, sizeof(expected) - 1), slice, \ + #slice, __LINE__); + +static void expect_combined_equiv(const char* s, size_t len, int line) { + grpc_slice input = grpc_slice_from_copied_buffer(s, len); + grpc_slice base64 = grpc_chttp2_base64_encode(input); + grpc_slice expect = grpc_chttp2_huffman_compress(base64); + grpc_slice got = grpc_chttp2_base64_encode_and_huffman_compress(input); + if (!grpc_slice_eq(expect, got)) { + char* t = grpc_dump_slice(input, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* e = grpc_dump_slice(expect, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* g = grpc_dump_slice(got, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_ERROR, "FAILED:%d:\ntest: %s\ngot: %s\nwant: %s", line, t, g, + e); + gpr_free(t); + gpr_free(e); + gpr_free(g); + all_ok = 0; + } + grpc_slice_unref(input); + grpc_slice_unref(base64); + grpc_slice_unref(expect); + grpc_slice_unref(got); +} + +#define EXPECT_COMBINED_EQUIV(x) \ + expect_combined_equiv(x, sizeof(x) - 1, __LINE__) + +static void expect_binary_header(const char* hdr, int binary) { + if (grpc_is_binary_header(grpc_slice_from_static_string(hdr)) != binary) { + gpr_log(GPR_ERROR, "FAILED: expected header '%s' to be %s", hdr, + binary ? 
"binary" : "not binary"); + all_ok = 0; + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + /* Base64 test vectors from RFC 4648, with padding removed */ + /* BASE64("") = "" */ + EXPECT_SLICE_EQ("", B64("")); + /* BASE64("f") = "Zg" */ + EXPECT_SLICE_EQ("Zg", B64("f")); + /* BASE64("fo") = "Zm8" */ + EXPECT_SLICE_EQ("Zm8", B64("fo")); + /* BASE64("foo") = "Zm9v" */ + EXPECT_SLICE_EQ("Zm9v", B64("foo")); + /* BASE64("foob") = "Zm9vYg" */ + EXPECT_SLICE_EQ("Zm9vYg", B64("foob")); + /* BASE64("fooba") = "Zm9vYmE" */ + EXPECT_SLICE_EQ("Zm9vYmE", B64("fooba")); + /* BASE64("foobar") = "Zm9vYmFy" */ + EXPECT_SLICE_EQ("Zm9vYmFy", B64("foobar")); + + EXPECT_SLICE_EQ("wMHCw8TF", B64("\xc0\xc1\xc2\xc3\xc4\xc5")); + + /* Huffman encoding tests */ + EXPECT_SLICE_EQ("\xf1\xe3\xc2\xe5\xf2\x3a\x6b\xa0\xab\x90\xf4\xff", + HUFF("www.example.com")); + EXPECT_SLICE_EQ("\xa8\xeb\x10\x64\x9c\xbf", HUFF("no-cache")); + EXPECT_SLICE_EQ("\x25\xa8\x49\xe9\x5b\xa9\x7d\x7f", HUFF("custom-key")); + EXPECT_SLICE_EQ("\x25\xa8\x49\xe9\x5b\xb8\xe8\xb4\xbf", HUFF("custom-value")); + EXPECT_SLICE_EQ("\xae\xc3\x77\x1a\x4b", HUFF("private")); + EXPECT_SLICE_EQ( + "\xd0\x7a\xbe\x94\x10\x54\xd4\x44\xa8\x20\x05\x95\x04\x0b\x81\x66\xe0\x82" + "\xa6\x2d\x1b\xff", + HUFF("Mon, 21 Oct 2013 20:13:21 GMT")); + EXPECT_SLICE_EQ( + "\x9d\x29\xad\x17\x18\x63\xc7\x8f\x0b\x97\xc8\xe9\xae\x82\xae\x43\xd3", + HUFF("https://www.example.com")); + + /* Various test vectors for combined encoding */ + EXPECT_COMBINED_EQUIV(""); + EXPECT_COMBINED_EQUIV("f"); + EXPECT_COMBINED_EQUIV("fo"); + EXPECT_COMBINED_EQUIV("foo"); + EXPECT_COMBINED_EQUIV("foob"); + EXPECT_COMBINED_EQUIV("fooba"); + EXPECT_COMBINED_EQUIV("foobar"); + EXPECT_COMBINED_EQUIV("www.example.com"); + EXPECT_COMBINED_EQUIV("no-cache"); + EXPECT_COMBINED_EQUIV("custom-key"); + EXPECT_COMBINED_EQUIV("custom-value"); + EXPECT_COMBINED_EQUIV("private"); + EXPECT_COMBINED_EQUIV("Mon, 21 Oct 2013 20:13:21 GMT"); + EXPECT_COMBINED_EQUIV("https://www.example.com"); + EXPECT_COMBINED_EQUIV( + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f" + "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f" + "\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f" + "\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x3a\x3b\x3c\x3d\x3e\x3f" + "\x40\x41\x42\x43\x44\x45\x46\x47\x48\x49\x4a\x4b\x4c\x4d\x4e\x4f" + "\x50\x51\x52\x53\x54\x55\x56\x57\x58\x59\x5a\x5b\x5c\x5d\x5e\x5f" + "\x60\x61\x62\x63\x64\x65\x66\x67\x68\x69\x6a\x6b\x6c\x6d\x6e\x6f" + "\x70\x71\x72\x73\x74\x75\x76\x77\x78\x79\x7a\x7b\x7c\x7d\x7e\x7f" + "\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f" + "\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f" + "\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf" + "\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf" + "\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf" + "\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf" + "\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef" + "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"); + + expect_binary_header("foo-bin", 1); + expect_binary_header("foo-bar", 0); + expect_binary_header("-bin", 0); + + grpc_shutdown(); + return all_ok ? 
0 : 1; +} diff --git a/test/core/transport/chttp2/context_list_test.cc b/test/core/transport/chttp2/context_list_test.cc new file mode 100644 index 00000000..89ac6ac3 --- /dev/null +++ b/test/core/transport/chttp2/context_list_test.cc @@ -0,0 +1,178 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/transport/chttp2/transport/context_list.h" + +#include +#include + +#include + +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/transport/transport.h" +#include "test/core/util/mock_endpoint.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +const uint32_t kByteOffset = 123; + +void* PhonyArgsCopier(void* arg) { return arg; } + +void TestExecuteFlushesListVerifier(void* arg, grpc_core::Timestamps* ts, + grpc_error_handle error) { + ASSERT_NE(arg, nullptr); + EXPECT_EQ(error, GRPC_ERROR_NONE); + if (ts) { + EXPECT_EQ(ts->byte_offset, kByteOffset); + } + gpr_atm* done = reinterpret_cast(arg); + gpr_atm_rel_store(done, static_cast(1)); +} + +void discard_write(grpc_slice /*slice*/) {} + +class ContextListTest : public ::testing::Test { + protected: + void SetUp() override { + grpc_http2_set_write_timestamps_callback(TestExecuteFlushesListVerifier); + grpc_http2_set_fn_get_copied_context(PhonyArgsCopier); + } +}; + +/** Tests that all ContextList elements in the list are flushed out on + * execute. + * Also tests that arg and byte_counter are passed correctly. 
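+ * The verifier is installed via grpc_http2_set_write_timestamps_callback() in
+ * SetUp(), and each stream's context points at an atomic flag that the
+ * verifier sets when it runs.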
+ */ +TEST_F(ContextListTest, ExecuteFlushesList) { + grpc_core::ContextList* list = nullptr; + const int kNumElems = 5; + grpc_core::ExecCtx exec_ctx; + grpc_stream_refcount ref; + GRPC_STREAM_REF_INIT(&ref, 1, nullptr, nullptr, "phony ref"); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("context_list_test"); + grpc_endpoint* mock_endpoint = grpc_mock_endpoint_create( + discard_write, + grpc_slice_allocator_create(resource_quota, "mock_endpoint")); + grpc_transport* t = grpc_create_chttp2_transport( + nullptr, mock_endpoint, true, + grpc_resource_user_create(resource_quota, "mock_transport")); + grpc_resource_quota_unref(resource_quota); + std::vector s; + s.reserve(kNumElems); + gpr_atm verifier_called[kNumElems]; + for (auto i = 0; i < kNumElems; i++) { + s.push_back(static_cast( + gpr_malloc(grpc_transport_stream_size(t)))); + grpc_transport_init_stream(reinterpret_cast(t), + reinterpret_cast(s[i]), &ref, + nullptr, nullptr); + s[i]->context = &verifier_called[i]; + s[i]->byte_counter = kByteOffset; + gpr_atm_rel_store(&verifier_called[i], static_cast(0)); + grpc_core::ContextList::Append(&list, s[i]); + } + grpc_core::Timestamps ts; + grpc_core::ContextList::Execute(list, &ts, GRPC_ERROR_NONE); + for (auto i = 0; i < kNumElems; i++) { + EXPECT_EQ(gpr_atm_acq_load(&verifier_called[i]), static_cast(1)); + grpc_transport_destroy_stream(reinterpret_cast(t), + reinterpret_cast(s[i]), + nullptr); + exec_ctx.Flush(); + gpr_free(s[i]); + } + grpc_transport_destroy(t); + exec_ctx.Flush(); +} + +TEST_F(ContextListTest, EmptyList) { + grpc_core::ContextList* list = nullptr; + grpc_core::ExecCtx exec_ctx; + grpc_core::Timestamps ts; + grpc_core::ContextList::Execute(list, &ts, GRPC_ERROR_NONE); + exec_ctx.Flush(); +} + +TEST_F(ContextListTest, EmptyListEmptyTimestamp) { + grpc_core::ContextList* list = nullptr; + grpc_core::ExecCtx exec_ctx; + grpc_core::ContextList::Execute(list, nullptr, GRPC_ERROR_NONE); + exec_ctx.Flush(); +} + +TEST_F(ContextListTest, NonEmptyListEmptyTimestamp) { + grpc_core::ContextList* list = nullptr; + const int kNumElems = 5; + grpc_core::ExecCtx exec_ctx; + grpc_stream_refcount ref; + GRPC_STREAM_REF_INIT(&ref, 1, nullptr, nullptr, "phony ref"); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("context_list_test"); + grpc_endpoint* mock_endpoint = grpc_mock_endpoint_create( + discard_write, + grpc_slice_allocator_create(resource_quota, "mock_endpoint")); + grpc_transport* t = grpc_create_chttp2_transport( + nullptr, mock_endpoint, true, + grpc_resource_user_create(resource_quota, "mock_transport")); + grpc_resource_quota_unref(resource_quota); + std::vector s; + s.reserve(kNumElems); + gpr_atm verifier_called[kNumElems]; + for (auto i = 0; i < kNumElems; i++) { + s.push_back(static_cast( + gpr_malloc(grpc_transport_stream_size(t)))); + grpc_transport_init_stream(reinterpret_cast(t), + reinterpret_cast(s[i]), &ref, + nullptr, nullptr); + s[i]->context = &verifier_called[i]; + s[i]->byte_counter = kByteOffset; + gpr_atm_rel_store(&verifier_called[i], static_cast(0)); + grpc_core::ContextList::Append(&list, s[i]); + } + grpc_core::ContextList::Execute(list, nullptr, GRPC_ERROR_NONE); + for (auto i = 0; i < kNumElems; i++) { + EXPECT_EQ(gpr_atm_acq_load(&verifier_called[i]), static_cast(1)); + grpc_transport_destroy_stream(reinterpret_cast(t), + reinterpret_cast(s[i]), + nullptr); + exec_ctx.Flush(); + gpr_free(s[i]); + } + grpc_transport_destroy(t); + exec_ctx.Flush(); +} + +} // namespace +} // namespace testing +} // namespace 
grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/transport/chttp2/flow_control_test.cc b/test/core/transport/chttp2/flow_control_test.cc new file mode 100644 index 00000000..54ef1d8f --- /dev/null +++ b/test/core/transport/chttp2/flow_control_test.cc @@ -0,0 +1,366 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/flow_control.h" + +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace { + +class TransportTargetWindowSizeMocker + : public grpc_core::chttp2::TestOnlyTransportTargetWindowEstimatesMocker { + public: + static constexpr uint32_t kLargeInitialWindowSize = 1u << 31; + static constexpr uint32_t kSmallInitialWindowSize = 0; + + double ComputeNextTargetInitialWindowSizeFromPeriodicUpdate( + double /* current_target */) override { + if (alternating_initial_window_sizes_) { + window_size_ = (window_size_ == kLargeInitialWindowSize) + ? kSmallInitialWindowSize + : kLargeInitialWindowSize; + } + return window_size_; + } + + // Alternates the initial window size targets. Computes a low values if it was + // previously high, or a high value if it was previously low. + void AlternateTargetInitialWindowSizes() { + alternating_initial_window_sizes_ = true; + } + + void Reset() { + alternating_initial_window_sizes_ = false; + window_size_ = kLargeInitialWindowSize; + } + + private: + bool alternating_initial_window_sizes_ = false; + double window_size_ = kLargeInitialWindowSize; +}; + +TransportTargetWindowSizeMocker* g_target_initial_window_size_mocker; + +void* tag(intptr_t t) { return reinterpret_cast(t); } + +void VerifyChannelReady(grpc_channel* channel, grpc_completion_queue* cq) { + grpc_connectivity_state state = + grpc_channel_check_connectivity_state(channel, 1 /* try_to_connect */); + while (state != GRPC_CHANNEL_READY) { + grpc_channel_watch_connectivity_state( + channel, state, grpc_timeout_seconds_to_deadline(5), cq, nullptr); + grpc_completion_queue_next(cq, grpc_timeout_seconds_to_deadline(5), + nullptr); + state = grpc_channel_check_connectivity_state(channel, 0); + } +} + +void VerifyChannelConnected(grpc_channel* channel, grpc_completion_queue* cq) { + // Verify channel is connected. Use a ping to make sure that clients + // tries sending/receiving bytes if the channel is connected. 
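+  // The ping completion is tagged with 2000 so that the event pulled from the
+  // completion queue below can be matched unambiguously.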
+ grpc_channel_ping(channel, cq, reinterpret_cast(2000), nullptr); + grpc_event ev = grpc_completion_queue_next( + cq, grpc_timeout_seconds_to_deadline(5), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == reinterpret_cast(2000)); + GPR_ASSERT(ev.success == 1); + GPR_ASSERT(grpc_channel_check_connectivity_state(channel, 0) == + GRPC_CHANNEL_READY); +} + +// Shuts down and destroys the server. +void ServerShutdownAndDestroy(grpc_server* server, grpc_completion_queue* cq) { + // Shutdown and destroy server + grpc_server_shutdown_and_notify(server, cq, reinterpret_cast(1000)); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .tag != reinterpret_cast(1000)) { + } + grpc_server_destroy(server); +} + +grpc_slice LargeSlice(void) { + grpc_slice slice = grpc_slice_malloc(10000000); // ~10MB + memset(GRPC_SLICE_START_PTR(slice), 'x', GRPC_SLICE_LENGTH(slice)); + return slice; +} + +void PerformCallWithLargePayload(grpc_channel* channel, grpc_server* server, + grpc_completion_queue* cq) { + grpc_slice request_payload_slice = LargeSlice(); + grpc_slice response_payload_slice = LargeSlice(); + grpc_call* c; + grpc_call* s; + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + cq_verifier* cqv = cq_verifier_create(cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* request_payload_recv = nullptr; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(30); + c = grpc_channel_create_call(channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == 
error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &request_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_UNIMPLEMENTED; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_UNIMPLEMENTED); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(request_payload); + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(request_payload_recv); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_slice_unref(request_payload_slice); + grpc_slice_unref(response_payload_slice); +} + +class FlowControlTest : public ::testing::Test { + protected: + void SetUp() override { + cq_ = grpc_completion_queue_create_for_next(nullptr); + // create the server + std::string server_address = + grpc_core::JoinHostPort("localhost", grpc_pick_unused_port_or_die()); + grpc_arg server_args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PING_STRIKES), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH), -1), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH), -1)}; + grpc_channel_args server_channel_args = {GPR_ARRAY_SIZE(server_args), + server_args}; + server_ = grpc_server_create(&server_channel_args, nullptr); + grpc_server_register_completion_queue(server_, cq_, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(server_, server_address.c_str())); + grpc_server_start(server_); + // create the channel (bdp pings are enabled by default) + grpc_arg client_args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), 1), + grpc_channel_arg_integer_create( + 
const_cast(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH), -1), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH), -1)}; + grpc_channel_args client_channel_args = {GPR_ARRAY_SIZE(client_args), + client_args}; + channel_ = grpc_insecure_channel_create(server_address.c_str(), + &client_channel_args, nullptr); + VerifyChannelReady(channel_, cq_); + g_target_initial_window_size_mocker->Reset(); + } + + void TearDown() override { + // shutdown and destroy the client and server + grpc_channel_destroy(channel_); + ServerShutdownAndDestroy(server_, cq_); + grpc_completion_queue_shutdown(cq_); + while (grpc_completion_queue_next(cq_, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq_); + } + + grpc_server* server_ = nullptr; + grpc_channel* channel_ = nullptr; + grpc_completion_queue* cq_ = nullptr; +}; + +TEST_F(FlowControlTest, + TestLargeWindowSizeUpdatesDoNotCauseIllegalFlowControlWindows) { + for (int i = 0; i < 10; ++i) { + PerformCallWithLargePayload(channel_, server_, cq_); + VerifyChannelConnected(channel_, cq_); + } +} + +TEST_F(FlowControlTest, TestWindowSizeUpdatesDoNotCauseStalledStreams) { + g_target_initial_window_size_mocker->AlternateTargetInitialWindowSizes(); + for (int i = 0; i < 100; ++i) { + PerformCallWithLargePayload(channel_, server_, cq_); + VerifyChannelConnected(channel_, cq_); + } +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + // Make sure that we will have an active poller on all client-side fd's that + // are capable of sending and receiving even in the case that we don't have an + // active RPC operation on the fd. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); + ::grpc_core::chttp2::g_test_only_transport_flow_control_window_check = true; + g_target_initial_window_size_mocker = new TransportTargetWindowSizeMocker(); + grpc_core::chttp2::g_test_only_transport_target_window_estimates_mocker = + g_target_initial_window_size_mocker; + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/transport/chttp2/hpack_encoder_index_test.cc b/test/core/transport/chttp2/hpack_encoder_index_test.cc new file mode 100644 index 00000000..e7286fbd --- /dev/null +++ b/test/core/transport/chttp2/hpack_encoder_index_test.cc @@ -0,0 +1,63 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_encoder_index.h" + +#include +#include + +#include + +namespace grpc_core { +namespace testing { + +struct TestKey { + using Stored = uint32_t; + uint32_t value; + uint32_t stored() const { return value; } + uint32_t hash() const { return value; } + bool operator==(uint32_t other) const { return other == value; } +}; + +TEST(HPackEncoderIndexTest, SetAndGet) { + HPackEncoderIndex index; + std::default_random_engine rng; + std::unordered_map last_index; + for (uint32_t i = 0; i < 10000; i++) { + uint32_t key = rng(); + index.Insert({key}, i); + EXPECT_EQ(index.Lookup({key}), i); + last_index[key] = i; + } + for (auto p : last_index) { + auto r = index.Lookup({p.first}); + if (r.has_value()) { + EXPECT_EQ(*r, p.second); + } + } +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/chttp2/hpack_encoder_test.cc b/test/core/transport/chttp2/hpack_encoder_test.cc new file mode 100644 index 00000000..d59248cf --- /dev/null +++ b/test/core/transport/chttp2/hpack_encoder_test.cc @@ -0,0 +1,402 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" + +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_parser.h" +#include "src/core/ext/transport/chttp2/transport/hpack_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/metadata.h" +#include "test/core/util/parse_hexstring.h" +#include "test/core/util/slice_splitter.h" +#include "test/core/util/test_config.h" + +#define TEST(x) run_test(x, #x) + +grpc_core::HPackCompressor* g_compressor; +int g_failure = 0; + +void** to_delete = nullptr; +size_t num_to_delete = 0; +size_t cap_to_delete = 0; + +typedef struct { + bool eof; + bool use_true_binary_metadata; + bool only_intern_key; +} verify_params; + +/* verify that the output frames that are generated by encoding the stream + have sensible type and flags values */ +static void verify_frames(grpc_slice_buffer& output, bool header_is_eof) { + /* per the HTTP/2 spec: + All frames begin with a fixed 9-octet header followed by a + variable-length payload. + + +-----------------------------------------------+ + | Length (24) | + +---------------+---------------+---------------+ + | Type (8) | Flags (8) | + +-+-------------+---------------+-------------------------------+ + |R| Stream Identifier (31) | + +=+=============================================================+ + | Frame Payload (0...) ... 
+ +---------------------------------------------------------------+ + */ + uint8_t type = 0xff, flags = 0xff; + size_t i, merged_length, frame_size; + bool first_frame = false; + bool in_header = false; + bool end_header = false; + bool is_closed = false; + for (i = 0; i < output.count;) { + first_frame = i == 0; + grpc_slice* slice = &output.slices[i++]; + + // Read gRPC frame header + uint8_t* p = GRPC_SLICE_START_PTR(*slice); + frame_size = 0; + frame_size |= static_cast(p[0]) << 16; + frame_size |= static_cast(p[1]) << 8; + frame_size |= static_cast(p[2]); + type = p[3]; + flags = p[4]; + + // Read remainder of the gRPC frame + merged_length = GRPC_SLICE_LENGTH(*slice); + while (merged_length < frame_size + 9) { // including 9 byte frame header + grpc_slice* slice = &output.slices[i++]; + merged_length += GRPC_SLICE_LENGTH(*slice); + } + + // Verifications + if (first_frame && type != GRPC_CHTTP2_FRAME_HEADER) { + gpr_log(GPR_ERROR, "expected first frame to be of type header"); + gpr_log(GPR_ERROR, "EXPECT: 0x%x", GRPC_CHTTP2_FRAME_HEADER); + gpr_log(GPR_ERROR, "GOT: 0x%x", type); + g_failure = 1; + } else if (first_frame && header_is_eof && + !(flags & GRPC_CHTTP2_DATA_FLAG_END_STREAM)) { + gpr_log(GPR_ERROR, "missing END_STREAM flag in HEADER frame"); + g_failure = 1; + } + if (is_closed && + (type == GRPC_CHTTP2_FRAME_DATA || type == GRPC_CHTTP2_FRAME_HEADER)) { + gpr_log(GPR_ERROR, + "stream is closed; new frame headers and data are not allowed"); + g_failure = 1; + } + if (end_header && (type == GRPC_CHTTP2_FRAME_HEADER || + type == GRPC_CHTTP2_FRAME_CONTINUATION)) { + gpr_log(GPR_ERROR, + "frame header is ended; new headers and continuations are not " + "allowed"); + g_failure = 1; + } + if (in_header && + (type == GRPC_CHTTP2_FRAME_DATA || type == GRPC_CHTTP2_FRAME_HEADER)) { + gpr_log(GPR_ERROR, + "parsing frame header; new headers and data are not allowed"); + g_failure = 1; + } + if (flags & ~(GRPC_CHTTP2_DATA_FLAG_END_STREAM | + GRPC_CHTTP2_DATA_FLAG_END_HEADERS)) { + gpr_log(GPR_ERROR, "unexpected frame flags: 0x%x", flags); + g_failure = 1; + } + + // Update state + if (flags & GRPC_CHTTP2_DATA_FLAG_END_HEADERS) { + in_header = false; + end_header = true; + } else if (type == GRPC_CHTTP2_DATA_FLAG_END_HEADERS) { + in_header = true; + } + if (flags & GRPC_CHTTP2_DATA_FLAG_END_STREAM) { + is_closed = true; + if (type == GRPC_CHTTP2_FRAME_CONTINUATION) { + gpr_log(GPR_ERROR, "unexpected END_STREAM flag in CONTINUATION frame"); + g_failure = 1; + } + } + } +} + +/* verify that the output generated by encoding the stream matches the + hexstring passed in */ +static void verify(const verify_params params, const char* expected, + size_t nheaders, ...) 
{ + grpc_slice_buffer output; + grpc_slice merged; + grpc_slice expect = parse_hexstring(expected); + size_t i; + va_list l; + grpc_linked_mdelem* e = + static_cast(gpr_malloc(sizeof(*e) * nheaders)); + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch b(arena.get()); + + va_start(l, nheaders); + for (i = 0; i < nheaders; i++) { + char* key = va_arg(l, char*); + char* value = va_arg(l, char*); + grpc_slice value_slice = grpc_slice_from_static_string(value); + if (!params.only_intern_key) { + value_slice = grpc_slice_intern(value_slice); + } + e[i].md = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string(key)), value_slice); + GPR_ASSERT(GRPC_ERROR_NONE == b.LinkTail(&e[i])); + } + va_end(l); + + if (cap_to_delete == num_to_delete) { + cap_to_delete = std::max(2 * cap_to_delete, size_t(1000)); + to_delete = static_cast( + gpr_realloc(to_delete, sizeof(*to_delete) * cap_to_delete)); + } + to_delete[num_to_delete++] = e; + + grpc_slice_buffer_init(&output); + + grpc_transport_one_way_stats stats; + stats = {}; + grpc_core::HPackCompressor::EncodeHeaderOptions hopt{ + 0xdeadbeef, /* stream_id */ + params.eof, /* is_eof */ + params.use_true_binary_metadata, /* use_true_binary_metadata */ + 16384, /* max_frame_size */ + &stats /* stats */ + }; + g_compressor->EncodeHeaders(hopt, b, &output); + verify_frames(output, params.eof); + merged = grpc_slice_merge(output.slices, output.count); + grpc_slice_buffer_destroy_internal(&output); + + if (!grpc_slice_eq(merged, expect)) { + char* expect_str = grpc_dump_slice(expect, GPR_DUMP_HEX | GPR_DUMP_ASCII); + char* got_str = grpc_dump_slice(merged, GPR_DUMP_HEX | GPR_DUMP_ASCII); + gpr_log(GPR_ERROR, "mismatched output for %s", expected); + gpr_log(GPR_ERROR, "EXPECT: %s", expect_str); + gpr_log(GPR_ERROR, "GOT: %s", got_str); + gpr_free(expect_str); + gpr_free(got_str); + g_failure = 1; + } + + grpc_slice_unref_internal(merged); + grpc_slice_unref_internal(expect); +} + +static void test_basic_headers() { + int i; + + verify_params params = { + false, + false, + false, + }; + verify(params, "000005 0104 deadbeef 40 0161 0161", 1, "a", "a"); + verify(params, "000001 0104 deadbeef be", 1, "a", "a"); + verify(params, "000001 0104 deadbeef be", 1, "a", "a"); + verify(params, "000006 0104 deadbeef be 40 0162 0163", 2, "a", "a", "b", "c"); + verify(params, "000002 0104 deadbeef bf be", 2, "a", "a", "b", "c"); + verify(params, "000004 0104 deadbeef 7f 00 0164", 1, "a", "d"); + + /* flush out what's there to make a few values look very popular */ + for (i = 0; i < 350; i++) { + verify(params, "000003 0104 deadbeef c0 bf be", 3, "a", "a", "b", "c", "a", + "d"); + } + + verify(params, "000006 0104 deadbeef c0 00 016b 0176", 2, "a", "a", "k", "v"); + /* this could be 000004 0104 deadbeef 0f 30 0176 also */ + verify(params, "000004 0104 deadbeef 0f 2f 0176", 1, "a", "v"); +} + +static void verify_continuation_headers(const char* key, const char* value, + bool is_eof) { + auto arena = grpc_core::MakeScopedArena(1024); + grpc_slice_buffer output; + grpc_mdelem elem = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string(key)), + grpc_slice_intern(grpc_slice_from_static_string(value))); + grpc_linked_mdelem e; + e.md = elem; + e.prev = nullptr; + e.next = nullptr; + grpc_metadata_batch b(arena.get()); + GPR_ASSERT(GRPC_ERROR_NONE == b.LinkTail(&e)); + grpc_slice_buffer_init(&output); + + grpc_transport_one_way_stats stats; + stats = {}; + grpc_core::HPackCompressor::EncodeHeaderOptions hopt = { + 
0xdeadbeef, /* stream_id */ + is_eof, /* is_eof */ + false, /* use_true_binary_metadata */ + 150, /* max_frame_size */ + &stats /* stats */}; + g_compressor->EncodeHeaders(hopt, b, &output); + verify_frames(output, is_eof); + grpc_slice_buffer_destroy_internal(&output); +} + +static void test_continuation_headers() { + char value[200]; + memset(value, 'a', 200); + value[199] = 0; // null terminator + verify_continuation_headers("key", value, true); + + char value2[400]; + memset(value2, 'b', 400); + value2[399] = 0; // null terminator + verify_continuation_headers("key2", value2, true); +} + +static void encode_int_to_str(int i, char* p) { + p[0] = static_cast('a' + i % 26); + i /= 26; + GPR_ASSERT(i < 26); + p[1] = static_cast('a' + i); + p[2] = 0; +} + +static void test_decode_table_overflow() { + // Decrease the default table size to make decode table overflow easier. + g_compressor->SetMaxTableSize(1024); + int i; + char key[3], value[3]; + + verify_params params = { + false, + false, + false, + }; + + for (i = 0; i < 29; i++) { + encode_int_to_str(i, key); + encode_int_to_str(i + 1, value); + if (i == 0) { + // 3fe107 corresponds to the table size update. + std::string expect = absl::StrFormat( + "00000a 0104 deadbeef 3fe107 40 02%02x%02x 02%02x%02x", key[0], + key[1], value[0], value[1]); + verify(params, expect.c_str(), 1, key, value); + } else { + std::string expect = + absl::StrFormat("000008 0104 deadbeef %02x 40 02%02x%02x 02%02x%02x", + 0x80 + 61 + i, key[0], key[1], value[0], value[1]); + verify(params, expect.c_str(), 2, "aa", "ba", key, value); + } + } + + /* if the above passes, then we must have just knocked this pair out of the + decoder stack, and so we'll be forced to re-encode it */ + verify(params, "000007 0104 deadbeef 40 026161 026261", 1, "aa", "ba"); +} + +static void verify_table_size_change_match_elem_size(const char* key, + const char* value, + bool use_true_binary) { + auto arena = grpc_core::MakeScopedArena(1024); + grpc_slice_buffer output; + grpc_mdelem elem = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string(key)), + grpc_slice_intern(grpc_slice_from_static_string(value))); + size_t elem_size = grpc_core::MetadataSizeInHPackTable(elem, use_true_binary); + size_t initial_table_size = g_compressor->test_only_table_size(); + grpc_linked_mdelem e; + e.md = elem; + e.prev = nullptr; + e.next = nullptr; + grpc_metadata_batch b(arena.get()); + GPR_ASSERT(GRPC_ERROR_NONE == b.LinkTail(&e)); + grpc_slice_buffer_init(&output); + + grpc_transport_one_way_stats stats; + stats = {}; + grpc_core::HPackCompressor::EncodeHeaderOptions hopt = { + 0xdeadbeef, /* stream_id */ + false, /* is_eof */ + use_true_binary, /* use_true_binary_metadata */ + 16384, /* max_frame_size */ + &stats /* stats */}; + g_compressor->EncodeHeaders(hopt, b, &output); + verify_frames(output, false); + grpc_slice_buffer_destroy_internal(&output); + + GPR_ASSERT(g_compressor->test_only_table_size() == + elem_size + initial_table_size); +} + +static void test_encode_header_size() { + verify_table_size_change_match_elem_size("hello", "world", false); + verify_table_size_change_match_elem_size("hello-bin", "world", false); + verify_table_size_change_match_elem_size("true-binary-bin", + "I_am_true_binary_value", true); +} + +static void test_interned_key_indexed() { + int i; + verify_params params = {false, false, true}; + verify(params, "000009 0104 deadbeef 40 0161 0162 0f2f 0163", 2, "a", "b", + "a", "c"); + for (i = 0; i < 10; i++) { + verify(params, "000008 0104 deadbeef 
0f2f 0162 0f2f 0163", 2, "a", "b", "a", + "c"); + } +} + +static void run_test(void (*test)(), const char* name) { + gpr_log(GPR_INFO, "RUN TEST: %s", name); + grpc_core::ExecCtx exec_ctx; + g_compressor = new grpc_core::HPackCompressor(); + test(); + delete g_compressor; +} + +int main(int argc, char** argv) { + size_t i; + grpc_test_only_set_slice_hash_seed(0); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + TEST(test_basic_headers); + TEST(test_decode_table_overflow); + TEST(test_encode_header_size); + TEST(test_interned_key_indexed); + TEST(test_continuation_headers); + grpc_shutdown(); + for (i = 0; i < num_to_delete; i++) { + gpr_free(to_delete[i]); + } + return g_failure; +} diff --git a/test/core/transport/chttp2/hpack_parser_fuzzer_test.cc b/test/core/transport/chttp2/hpack_parser_fuzzer_test.cc new file mode 100644 index 00000000..99e0eeac --- /dev/null +++ b/test/core/transport/chttp2/hpack_parser_fuzzer_test.cc @@ -0,0 +1,85 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_parser.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "test/core/transport/chttp2/hpack_parser_fuzzer.pb.h" + +bool squelch = true; +bool leak_check = true; + +static void dont_log(gpr_log_func_args* /*args*/) {} + +DEFINE_PROTO_FUZZER(const hpack_parser_fuzzer::Msg& msg) { + grpc_test_only_set_slice_hash_seed(0); + if (squelch) gpr_set_log_function(dont_log); + grpc_init(); + { + std::unique_ptr parser(new grpc_core::HPackParser); + for (int i = 0; i < msg.frames_size(); i++) { + auto arena = grpc_core::MakeScopedArena(1024); + grpc_core::ExecCtx exec_ctx; + grpc_metadata_batch b(arena.get()); + + const auto& frame = msg.frames(i); + grpc_core::HPackParser::Boundary boundary = + grpc_core::HPackParser::Boundary::None; + if (frame.end_of_headers()) { + boundary = grpc_core::HPackParser::Boundary::EndOfHeaders; + } + if (frame.end_of_stream()) { + boundary = grpc_core::HPackParser::Boundary::EndOfStream; + } + grpc_core::HPackParser::Priority priority = + grpc_core::HPackParser::Priority::None; + if (frame.priority()) { + priority = grpc_core::HPackParser::Priority::Included; + } + int max_length = 1024; + if (frame.max_metadata_length() != 0) { + max_length = frame.max_metadata_length(); + } + + parser->BeginFrame( + &b, max_length, boundary, priority, + grpc_core::HPackParser::LogInfo{ + 1, grpc_core::HPackParser::LogInfo::kHeaders, false}); + int stop_buffering_ctr = + std::max(-1, frame.stop_buffering_after_segments()); + for (const auto& parse : frame.parse()) { + grpc_slice buffer = + grpc_slice_from_copied_buffer(parse.data(), parse.size()); + GRPC_ERROR_UNREF(parser->Parse(buffer, i == msg.frames_size() - 1)); + grpc_slice_unref(buffer); + stop_buffering_ctr--; + if (0 == stop_buffering_ctr) 
parser->StopBufferingFrame(); + } + parser->FinishFrame(); + } + } + grpc_shutdown(); +} diff --git a/test/core/transport/chttp2/hpack_parser_table_test.cc b/test/core/transport/chttp2/hpack_parser_table_test.cc new file mode 100644 index 00000000..ab4cb60b --- /dev/null +++ b/test/core/transport/chttp2/hpack_parser_table_test.cc @@ -0,0 +1,149 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/transport/chttp2/transport/hpack_parser_table.h" + +#include +#include + +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace { +void AssertIndex(const HPackTable* tbl, uint32_t idx, const char* key, + const char* value) { + const auto* md = tbl->Lookup(idx); + ASSERT_NE(md, nullptr); + EXPECT_EQ(md->DebugString(), absl::StrCat(key, ": ", value)); +} +} // namespace + +TEST(HpackParserTableTest, StaticTable) { + grpc_core::ExecCtx exec_ctx; + HPackTable tbl; + + AssertIndex(&tbl, 1, ":authority", ""); + AssertIndex(&tbl, 2, ":method", "GET"); + AssertIndex(&tbl, 3, ":method", "POST"); + AssertIndex(&tbl, 4, ":path", "/"); + AssertIndex(&tbl, 5, ":path", "/index.html"); + AssertIndex(&tbl, 6, ":scheme", "http"); + AssertIndex(&tbl, 7, ":scheme", "https"); + AssertIndex(&tbl, 8, ":status", "200"); + AssertIndex(&tbl, 9, ":status", "204"); + AssertIndex(&tbl, 10, ":status", "206"); + AssertIndex(&tbl, 11, ":status", "304"); + AssertIndex(&tbl, 12, ":status", "400"); + AssertIndex(&tbl, 13, ":status", "404"); + AssertIndex(&tbl, 14, ":status", "500"); + AssertIndex(&tbl, 15, "accept-charset", ""); + AssertIndex(&tbl, 16, "accept-encoding", "gzip, deflate"); + AssertIndex(&tbl, 17, "accept-language", ""); + AssertIndex(&tbl, 18, "accept-ranges", ""); + AssertIndex(&tbl, 19, "accept", ""); + AssertIndex(&tbl, 20, "access-control-allow-origin", ""); + AssertIndex(&tbl, 21, "age", ""); + AssertIndex(&tbl, 22, "allow", ""); + AssertIndex(&tbl, 23, "authorization", ""); + AssertIndex(&tbl, 24, "cache-control", ""); + AssertIndex(&tbl, 25, "content-disposition", ""); + AssertIndex(&tbl, 26, "content-encoding", ""); + AssertIndex(&tbl, 27, "content-language", ""); + AssertIndex(&tbl, 28, "content-length", ""); + AssertIndex(&tbl, 29, "content-location", ""); + AssertIndex(&tbl, 30, "content-range", ""); + AssertIndex(&tbl, 31, "content-type", ""); + AssertIndex(&tbl, 32, "cookie", ""); + AssertIndex(&tbl, 33, "date", ""); + AssertIndex(&tbl, 34, "etag", ""); + AssertIndex(&tbl, 35, "expect", ""); + AssertIndex(&tbl, 36, "expires", ""); + AssertIndex(&tbl, 37, "from", ""); + AssertIndex(&tbl, 38, "host", ""); + AssertIndex(&tbl, 39, "if-match", ""); + AssertIndex(&tbl, 40, "if-modified-since", ""); + AssertIndex(&tbl, 41, "if-none-match", ""); + AssertIndex(&tbl, 42, "if-range", ""); + 
AssertIndex(&tbl, 43, "if-unmodified-since", ""); + AssertIndex(&tbl, 44, "last-modified", ""); + AssertIndex(&tbl, 45, "link", ""); + AssertIndex(&tbl, 46, "location", ""); + AssertIndex(&tbl, 47, "max-forwards", ""); + AssertIndex(&tbl, 48, "proxy-authenticate", ""); + AssertIndex(&tbl, 49, "proxy-authorization", ""); + AssertIndex(&tbl, 50, "range", ""); + AssertIndex(&tbl, 51, "referer", ""); + AssertIndex(&tbl, 52, "refresh", ""); + AssertIndex(&tbl, 53, "retry-after", ""); + AssertIndex(&tbl, 54, "server", ""); + AssertIndex(&tbl, 55, "set-cookie", ""); + AssertIndex(&tbl, 56, "strict-transport-security", ""); + AssertIndex(&tbl, 57, "transfer-encoding", ""); + AssertIndex(&tbl, 58, "user-agent", ""); + AssertIndex(&tbl, 59, "vary", ""); + AssertIndex(&tbl, 60, "via", ""); + AssertIndex(&tbl, 61, "www-authenticate", ""); +} + +TEST(HpackParserTableTest, ManyAdditions) { + HPackTable tbl; + int i; + + grpc_core::ExecCtx exec_ctx; + + for (i = 0; i < 100000; i++) { + grpc_mdelem elem; + std::string key = absl::StrCat("K.", i); + std::string value = absl::StrCat("VALUE.", i); + elem = grpc_mdelem_from_slices(grpc_slice_from_cpp_string(key), + grpc_slice_from_cpp_string(value)); + ASSERT_EQ(tbl.Add(HPackTable::Memento(elem)), GRPC_ERROR_NONE); + AssertIndex(&tbl, 1 + grpc_core::hpack_constants::kLastStaticEntry, + key.c_str(), value.c_str()); + if (i) { + std::string key = absl::StrCat("K.", i - 1); + std::string value = absl::StrCat("VALUE.", i - 1); + AssertIndex(&tbl, 2 + grpc_core::hpack_constants::kLastStaticEntry, + key.c_str(), value.c_str()); + } + } +} + +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int r = RUN_ALL_TESTS(); + grpc_shutdown(); + return r; +} diff --git a/test/core/transport/chttp2/hpack_parser_test.cc b/test/core/transport/chttp2/hpack_parser_test.cc new file mode 100644 index 00000000..7327c6db --- /dev/null +++ b/test/core/transport/chttp2/hpack_parser_test.cc @@ -0,0 +1,292 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/ext/transport/chttp2/transport/hpack_parser.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/util/parse_hexstring.h" +#include "test/core/util/slice_splitter.h" +#include "test/core/util/test_config.h" + +struct TestInput { + const char* input; + const char* expected_parse; +}; + +struct Test { + absl::optional table_size; + std::vector inputs; +}; + +class ParseTest : public ::testing::TestWithParam { + public: + ParseTest() { + grpc_init(); + parser_ = absl::make_unique(); + } + + ~ParseTest() override { + { + grpc_core::ExecCtx exec_ctx; + parser_.reset(); + } + + grpc_shutdown(); + } + + void SetUp() override { + if (GetParam().table_size.has_value()) { + parser_->hpack_table()->SetMaxBytes(GetParam().table_size.value()); + EXPECT_EQ(parser_->hpack_table()->SetCurrentTableSize( + GetParam().table_size.value()), + GRPC_ERROR_NONE); + } + } + + void TestVector(grpc_slice_split_mode mode, const char* hexstring, + std::string expect) { + auto arena = grpc_core::MakeScopedArena(1024); + grpc_core::ExecCtx exec_ctx; + grpc_slice input = parse_hexstring(hexstring); + grpc_slice* slices; + size_t nslices; + size_t i; + + grpc_metadata_batch b(arena.get()); + + parser_->BeginFrame( + &b, 4096, grpc_core::HPackParser::Boundary::None, + grpc_core::HPackParser::Priority::None, + grpc_core::HPackParser::LogInfo{ + 1, grpc_core::HPackParser::LogInfo::kHeaders, false}); + + grpc_split_slices(mode, &input, 1, &slices, &nslices); + grpc_slice_unref(input); + + for (i = 0; i < nslices; i++) { + grpc_core::ExecCtx exec_ctx; + auto err = parser_->Parse(slices[i], i == nslices - 1); + if (err != GRPC_ERROR_NONE) { + gpr_log(GPR_ERROR, "Unexpected parse error: %s", + grpc_error_std_string(err).c_str()); + abort(); + } + } + + for (i = 0; i < nslices; i++) { + grpc_slice_unref(slices[i]); + } + gpr_free(slices); + + TestEncoder encoder; + b.Encode(&encoder); + EXPECT_EQ(encoder.result(), expect); + } + + private: + class TestEncoder { + public: + std::string result() { return out_; } + + void Encode(grpc_mdelem elem) { + out_.append(absl::StrCat( + grpc_core::StringViewFromSlice(GRPC_MDKEY(elem)), ": ", + grpc_core::StringViewFromSlice(GRPC_MDVALUE(elem)), "\n")); + } + + template + void Encode(T, V) { + abort(); // not implemented + } + + private: + std::string out_; + }; + + std::unique_ptr parser_; +}; + +TEST_P(ParseTest, WholeSlices) { + for (const auto& input : GetParam().inputs) { + TestVector(GRPC_SLICE_SPLIT_MERGE_ALL, input.input, input.expected_parse); + } +} + +TEST_P(ParseTest, OneByteAtATime) { + for (const auto& input : GetParam().inputs) { + TestVector(GRPC_SLICE_SPLIT_ONE_BYTE, input.input, input.expected_parse); + } +} + +INSTANTIATE_TEST_SUITE_P( + ParseTest, ParseTest, + ::testing::Values( + Test{ + {}, + { + /* D.2.1 */ + {"400a 6375 7374 6f6d 2d6b 6579 0d63 7573" + "746f 6d2d 6865 6164 6572", + "custom-key: custom-header\n"}, + /* D.2.2 */ + {"040c 2f73 616d 706c 652f 7061 7468", ":path: /sample/path\n"}, + /* D.2.3 */ + {"1008 7061 7373 776f 7264 0673 6563 7265" + "74", + "password: secret\n"}, + /* D.2.4 */ + {"82", ":method: GET\n"}, + }}, + Test{{}, + { + /* D.3.1 */ + {"8286 8441 0f77 7777 2e65 7861 6d70 6c65" + "2e63 6f6d", + ":method: GET\n" + ":scheme: http\n" + ":path: /\n" + ":authority: www.example.com\n"}, + /* D.3.2 */ + {"8286 84be 5808 6e6f 2d63 6163 6865", + ":method: GET\n" + ":scheme: http\n" + ":path: /\n" + ":authority: www.example.com\n" + "cache-control: 
no-cache\n"}, + /* D.3.3 */ + {"8287 85bf 400a 6375 7374 6f6d 2d6b 6579" + "0c63 7573 746f 6d2d 7661 6c75 65", + ":method: GET\n" + ":scheme: https\n" + ":path: /index.html\n" + ":authority: www.example.com\n" + "custom-key: custom-value\n"}, + }}, + Test{{}, + { + /* D.4.1 */ + {"8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4" + "ff", + ":method: GET\n" + ":scheme: http\n" + ":path: /\n" + ":authority: www.example.com\n"}, + /* D.4.2 */ + {"8286 84be 5886 a8eb 1064 9cbf", + ":method: GET\n" + ":scheme: http\n" + ":path: /\n" + ":authority: www.example.com\n" + "cache-control: no-cache\n"}, + /* D.4.3 */ + {"8287 85bf 4088 25a8 49e9 5ba9 7d7f 8925" + "a849 e95b b8e8 b4bf", + ":method: GET\n" + ":scheme: https\n" + ":path: /index.html\n" + ":authority: www.example.com\n" + "custom-key: custom-value\n"}, + }}, + Test{{256}, + { + /* D.5.1 */ + {"4803 3330 3258 0770 7269 7661 7465 611d" + "4d6f 6e2c 2032 3120 4f63 7420 3230 3133" + "2032 303a 3133 3a32 3120 474d 546e 1768" + "7474 7073 3a2f 2f77 7777 2e65 7861 6d70" + "6c65 2e63 6f6d", + ":status: 302\n" + "cache-control: private\n" + "date: Mon, 21 Oct 2013 20:13:21 GMT\n" + "location: https://www.example.com\n"}, + /* D.5.2 */ + {"4803 3330 37c1 c0bf", + ":status: 307\n" + "cache-control: private\n" + "date: Mon, 21 Oct 2013 20:13:21 GMT\n" + "location: https://www.example.com\n"}, + /* D.5.3 */ + {"88c1 611d 4d6f 6e2c 2032 3120 4f63 7420" + "3230 3133 2032 303a 3133 3a32 3220 474d" + "54c0 5a04 677a 6970 7738 666f 6f3d 4153" + "444a 4b48 514b 425a 584f 5157 454f 5049" + "5541 5851 5745 4f49 553b 206d 6178 2d61" + "6765 3d33 3630 303b 2076 6572 7369 6f6e" + "3d31", + ":status: 200\n" + "cache-control: private\n" + "date: Mon, 21 Oct 2013 20:13:22 GMT\n" + "location: https://www.example.com\n" + "content-encoding: gzip\n" + "set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; " + "version=1\n"}, + }}, + Test{{256}, + { + /* D.6.1 */ + {"4882 6402 5885 aec3 771a 4b61 96d0 7abe" + "9410 54d4 44a8 2005 9504 0b81 66e0 82a6" + "2d1b ff6e 919d 29ad 1718 63c7 8f0b 97c8" + "e9ae 82ae 43d3", + ":status: 302\n" + "cache-control: private\n" + "date: Mon, 21 Oct 2013 20:13:21 GMT\n" + "location: https://www.example.com\n"}, + /* D.6.2 */ + {"4883 640e ffc1 c0bf", + ":status: 307\n" + "cache-control: private\n" + "date: Mon, 21 Oct 2013 20:13:21 GMT\n" + "location: https://www.example.com\n"}, + /* D.6.3 */ + {"88c1 6196 d07a be94 1054 d444 a820 0595" + "040b 8166 e084 a62d 1bff c05a 839b d9ab" + "77ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b" + "3960 d5af 2708 7f36 72c1 ab27 0fb5 291f" + "9587 3160 65c0 03ed 4ee5 b106 3d50 07", + ":status: 200\n" + "cache-control: private\n" + "date: Mon, 21 Oct 2013 20:13:22 GMT\n" + "location: https://www.example.com\n" + "content-encoding: gzip\n" + "set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; " + "version=1\n"}, + }}, + Test{{}, + { + // Binary metadata: created using: + // tools/codegen/core/gen_header_frame.py + // --compression inc --no_framing --hex + // < test/core/transport/chttp2/binary-metadata.headers + {"40 09 61 2e 62 2e 63 2d 62 69 6e 0c 62 32 31 6e 4d 6a 41 79 " + "4d 51 3d 3d", + "a.b.c-bin: omg2021\n"}, + }})); + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/chttp2/hpack_utils_test.cc b/test/core/transport/chttp2/hpack_utils_test.cc new file mode 100644 index 00000000..5cb89488 --- /dev/null +++ b/test/core/transport/chttp2/hpack_utils_test.cc @@ -0,0 
+1,122 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_encoder_index.h" + +namespace grpc_core { +namespace testing { + +static void VerifyAsciiHeaderSize(const char* key, const char* value, + bool intern_key, bool intern_value) { + grpc_mdelem elem = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string(key), intern_key), + maybe_intern(grpc_slice_from_static_string(value), intern_value)); + size_t elem_size = grpc_core::MetadataSizeInHPackTable(elem, false); + size_t expected_size = 32 + strlen(key) + strlen(value); + GPR_ASSERT(expected_size == elem_size); + GRPC_MDELEM_UNREF(elem); +} + +static void VerifyBinaryHeaderSize(const char* key, const uint8_t* value, + size_t value_len, bool intern_key, + bool intern_value) { + grpc_mdelem elem = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string(key), intern_key), + maybe_intern(grpc_slice_from_static_buffer(value, value_len), + intern_value)); + GPR_ASSERT(grpc_is_binary_header(GRPC_MDKEY(elem))); + size_t elem_size = grpc_core::MetadataSizeInHPackTable(elem, false); + grpc_slice value_slice = grpc_slice_from_copied_buffer( + reinterpret_cast(value), value_len); + grpc_slice base64_encoded = grpc_chttp2_base64_encode(value_slice); + size_t expected_size = 32 + strlen(key) + GRPC_SLICE_LENGTH(base64_encoded); + GPR_ASSERT(expected_size == elem_size); + grpc_slice_unref_internal(value_slice); + grpc_slice_unref_internal(base64_encoded); + GRPC_MDELEM_UNREF(elem); +} + +struct Param { + bool intern_key; + bool intern_value; +} + +class MetadataTest : public ::testing::TestWithParam { +}; + +#define BUFFER_SIZE 64 +TEST_P(MetadataTest, MetadataSize) { + const bool intern_key = GetParam().intern_key; + const bool intern_value = GetParam().intern_value; + gpr_log(GPR_INFO, "test_mdelem_size: intern_key=%d intern_value=%d", + intern_key, intern_value); + grpc_init(); + grpc_core::ExecCtx exec_ctx; + + uint8_t binary_value[BUFFER_SIZE] = {0}; + for (uint8_t i = 0; i < BUFFER_SIZE; i++) { + binary_value[i] = i; + } + + verify_ascii_header_size("hello", "world", intern_key, intern_value); + verify_ascii_header_size("hello", "worldxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + intern_key, intern_value); + verify_ascii_header_size(":scheme", "http", intern_key, intern_value); + + for (uint8_t i = 0; i < BUFFER_SIZE; i++) { + verify_binary_header_size("hello-bin", binary_value, i, intern_key, + intern_value); + } + + grpc_shutdown(); +} + +INSTANTIATE_TEST_SUITE_P(MetadataTestSuite, MetadataTest, + ::testing::Values(Param{false, false}, + Param{false, true}, + Param{true, false}, + Param{true, true})); + +TEST(HPackEncoderIndexTest, SetAndGet) { + HPackEncoderIndex index; + std::default_random_engine rng; + std::unordered_map last_index; + for (uint32_t i = 0; i < 10000; i++) { + uint32_t key = rng(); + index.Insert({key}, i); + EXPECT_EQ(index.Lookup({key}), i); + last_index[key] = 
i; + } + for (auto p : last_index) { + auto r = index.Lookup({p.first}); + if (r.has_value()) { + EXPECT_EQ(*r, p.second); + } + } +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/chttp2/popularity_count_test.cc b/test/core/transport/chttp2/popularity_count_test.cc new file mode 100644 index 00000000..bcd03f11 --- /dev/null +++ b/test/core/transport/chttp2/popularity_count_test.cc @@ -0,0 +1,75 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/core/ext/transport/chttp2/transport/popularity_count.h" + +#include + +#include + +namespace grpc_core { +namespace testing { + +static constexpr uint8_t kTestSize = 4; + +struct Scenario { + std::array initial_values; + uint8_t final_add; + bool expectation; +}; + +std::ostream& operator<<(std::ostream& out, Scenario s) { + out << "init:"; + for (size_t i = 0; i < kTestSize; i++) { + if (i != 0) { + out << ","; + } + out << static_cast(s.initial_values[i]); + } + out << " final:" << static_cast(s.final_add); + out << " expect:" << (s.expectation ? "true" : "false"); + return out; +} + +struct PopularityCountTest : public ::testing::TestWithParam {}; + +TEST_P(PopularityCountTest, Test) { + Scenario s = GetParam(); + PopularityCount pop; + for (size_t i = 0; i < kTestSize; i++) { + for (size_t j = 0; j < s.initial_values[i]; j++) { + pop.AddElement(i); + } + } + EXPECT_EQ(pop.AddElement(s.final_add), s.expectation); +} + +INSTANTIATE_TEST_SUITE_P(InterestingTests, PopularityCountTest, + ::testing::Values(Scenario{{0, 0, 0, 0}, 0, true}, + Scenario{{64, 0, 0, 0}, 0, true}, + Scenario{{64, 0, 0, 0}, 1, false})); + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/chttp2/remove_stream_from_stalled_lists_test.cc b/test/core/transport/chttp2/remove_stream_from_stalled_lists_test.cc new file mode 100644 index 00000000..8de38613 --- /dev/null +++ b/test/core/transport/chttp2/remove_stream_from_stalled_lists_test.cc @@ -0,0 +1,358 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/transport/chttp2/transport/flow_control.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/host_port.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace { + +class TransportTargetWindowEstimatesMocker + : public grpc_core::chttp2::TestOnlyTransportTargetWindowEstimatesMocker { + public: + explicit TransportTargetWindowEstimatesMocker() {} + + double ComputeNextTargetInitialWindowSizeFromPeriodicUpdate( + double current_target) override { + const double kTinyWindow = 512; + const double kSmallWindow = 8192; + // The goal is to bounce back and forth between 512 and 8192 initial window + // sizes, in order to get the following to happen at the server (in order): + // + // 1) Stall the server-side RPC's outgoing message on stream window flow + // control. + // + // 2) Send another settings frame with a change in initial window + // size setting, which will make the server-side call go writable. + if (current_target > kTinyWindow) { + return kTinyWindow; + } else { + return kSmallWindow; + } + } +}; + +void StartCall(grpc_call* call, grpc_completion_queue* cq) { + grpc_op ops[1]; + grpc_op* op; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY; + op->reserved = nullptr; + op++; + void* tag = call; + grpc_call_error error = grpc_call_start_batch( + call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_event event = grpc_completion_queue_next( + cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success); + GPR_ASSERT(event.tag == tag); +} + +void FinishCall(grpc_call* call, grpc_completion_queue* cq) { + grpc_op ops[4]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_status_code status; + grpc_slice details; + grpc_byte_buffer* recv_payload = nullptr; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &recv_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + void* tag = call; + grpc_call_error error = grpc_call_start_batch( + call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + grpc_event event = grpc_completion_queue_next( + cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + GPR_ASSERT(event.success); + GPR_ASSERT(event.tag == tag); + grpc_metadata_array_destroy(&initial_metadata_recv); + 
grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_byte_buffer_destroy(recv_payload); + grpc_slice_unref(details); +} + +class TestServer { + public: + explicit TestServer() { + cq_ = grpc_completion_queue_create_for_next(nullptr); + server_ = grpc_server_create(nullptr, nullptr); + address_ = grpc_core::JoinHostPort("[::1]", grpc_pick_unused_port_or_die()); + grpc_server_register_completion_queue(server_, cq_, nullptr); + GPR_ASSERT(grpc_server_add_insecure_http2_port(server_, address_.c_str())); + grpc_server_start(server_); + accept_thread_ = std::thread(std::bind(&TestServer::AcceptThread, this)); + } + + int ShutdownAndGetNumCallsHandled() { + { + // prevent the server from requesting any more calls + grpc_core::MutexLock lock(&shutdown_mu_); + shutdown_ = true; + } + grpc_server_shutdown_and_notify(server_, cq_, this /* tag */); + accept_thread_.join(); + grpc_server_destroy(server_); + grpc_completion_queue_shutdown(cq_); + while (grpc_completion_queue_next(cq_, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq_); + return num_calls_handled_; + } + + std::string address() const { return address_; } + + private: + void AcceptThread() { + std::vector rpc_threads; + bool got_shutdown_and_notify_tag = false; + while (!got_shutdown_and_notify_tag) { + void* request_call_tag = &rpc_threads; + grpc_call_details call_details; + grpc_call_details_init(&call_details); + grpc_call* call = nullptr; + grpc_completion_queue* call_cq = nullptr; + grpc_metadata_array request_metadata_recv; + grpc_metadata_array_init(&request_metadata_recv); + { + grpc_core::MutexLock lock(&shutdown_mu_); + if (!shutdown_) { + call_cq = grpc_completion_queue_create_for_next(nullptr); + grpc_call_error error = grpc_server_request_call( + server_, &call, &call_details, &request_metadata_recv, call_cq, + cq_, request_call_tag); + GPR_ASSERT(error == GRPC_CALL_OK); + } + } + grpc_event event = grpc_completion_queue_next( + cq_, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(event.type == GRPC_OP_COMPLETE); + grpc_call_details_destroy(&call_details); + grpc_metadata_array_destroy(&request_metadata_recv); + if (event.success) { + if (event.tag == request_call_tag) { + // HandleOneRpc takes ownership of its parameters + num_calls_handled_++; + rpc_threads.push_back( + std::thread(std::bind(&TestServer::HandleOneRpc, call, call_cq))); + } else if (event.tag == this /* shutdown_and_notify tag */) { + grpc_core::MutexLock lock(&shutdown_mu_); + GPR_ASSERT(shutdown_); + GPR_ASSERT(call_cq == nullptr); + got_shutdown_and_notify_tag = true; + } else { + GPR_ASSERT(0); + } + } else { + grpc_core::MutexLock lock(&shutdown_mu_); + GPR_ASSERT(shutdown_); + grpc_completion_queue_destroy(call_cq); + } + } + gpr_log(GPR_INFO, "test server shutdown, joining RPC threads..."); + for (auto& t : rpc_threads) { + t.join(); + } + gpr_log(GPR_INFO, "test server threads all finished!"); + } + + static void HandleOneRpc(grpc_call* call, grpc_completion_queue* call_cq) { + // Send a large enough payload to get us stalled on outgoing flow control + std::string send_payload = ""; + for (int i = 0; i < 4 * 1e6; i++) { + send_payload += "a"; + } + grpc_slice request_payload_slice = + grpc_slice_from_copied_string(send_payload.c_str()); + grpc_byte_buffer* request_payload = + grpc_raw_byte_buffer_create(&request_payload_slice, 1); + void* tag = call_cq; + grpc_op ops[2]; + grpc_op* op; + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = 
GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload; + op->reserved = nullptr; + op++; + grpc_call_error error = grpc_call_start_batch( + call, ops, static_cast(op - ops), tag, nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + std::thread poller([call_cq]() { + // poll the connection so that we actively pick up bytes off the wire, + // including settings frames with window size increases + while (grpc_completion_queue_next( + call_cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + }); + grpc_call_cancel(call, nullptr); + grpc_call_unref(call); + grpc_completion_queue_shutdown(call_cq); + poller.join(); + grpc_completion_queue_destroy(call_cq); + grpc_byte_buffer_destroy(request_payload); + grpc_slice_unref(request_payload_slice); + } + + grpc_server* server_; + grpc_completion_queue* cq_; + std::string address_; + std::thread accept_thread_; + int num_calls_handled_ = 0; + grpc_core::Mutex shutdown_mu_; + bool shutdown_ = false; +}; + +// Perform a simple RPC where the server cancels the request with +// grpc_call_cancel_with_status +TEST(Pollers, TestDontCrashWhenTryingToReproIssueFixedBy23984) { + // 64 threads is arbitrary but chosen because, experimentally it's enough to + // repro the targetted crash crash (which is then fixed by + // https://github.com/grpc/grpc/pull/23984) at a very high rate. + const int kNumCalls = 64; + std::vector threads; + threads.reserve(kNumCalls); + std::unique_ptr test_server = absl::make_unique(); + const std::string server_address = test_server->address(); + for (int i = 0; i < kNumCalls; i++) { + threads.push_back(std::thread([server_address]() { + std::vector args; + // this test is meant to create one connection to the server for each + // of these threads + args.push_back(grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL), true)); + grpc_channel_args* channel_args = + grpc_channel_args_copy_and_add(nullptr, args.data(), args.size()); + grpc_channel* channel = grpc_insecure_channel_create( + std::string("ipv6:" + server_address).c_str(), channel_args, nullptr); + grpc_channel_args_destroy(channel_args); + grpc_completion_queue* cq = + grpc_completion_queue_create_for_next(nullptr); + grpc_call* call = grpc_channel_create_call( + channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + StartCall(call, cq); + // Explicitly avoid reading on this RPC for a period of time. The + // goal is to get the server side RPC to stall on it's outgoing stream + // flow control window, as the first step in trying to trigger a bug. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(1, GPR_TIMESPAN))); + // Note that this test doesn't really care what the status of the RPC was, + // because we're just trying to make sure that we don't crash. 
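+ // FinishCall() (defined above) blocks on the completion queue until the whole
+ // batch, including RECV_STATUS_ON_CLIENT, completes, so it only returns once
+ // the server-side grpc_call_cancel() (or any other termination) has surfaced
+ // as a status on this client call.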
+ FinishCall(call, cq); + grpc_call_unref(call); + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); + })); + } + for (auto& thread : threads) { + thread.join(); + } + gpr_log(GPR_DEBUG, "All RPCs completed!"); + int num_calls_seen_at_server = test_server->ShutdownAndGetNumCallsHandled(); + if (num_calls_seen_at_server != kNumCalls) { + gpr_log(GPR_ERROR, + "Expected server to handle %d calls, but instead it only handled " + "%d. This suggests some or all RPCs didn't make it to the server, " + "which means " + "that this test likely isn't doing what it's meant to be doing.", + kNumCalls, num_calls_seen_at_server); + GPR_ASSERT(0); + } +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + // Make sure that we will have an active poller on all client-side fd's that + // are capable of sending settings frames with window updates etc., even in + // the case that we don't have an active RPC operation on the fd. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); + grpc_core::chttp2::g_test_only_transport_target_window_estimates_mocker = + new TransportTargetWindowEstimatesMocker(); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/transport/chttp2/settings_timeout_test.cc b/test/core/transport/chttp2/settings_timeout_test.cc new file mode 100644 index 00000000..7696d041 --- /dev/null +++ b/test/core/transport/chttp2/settings_timeout_test.cc @@ -0,0 +1,261 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/port.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace test { +namespace { + +// A gRPC server, running in its own thread. +class ServerThread { + public: + explicit ServerThread(const char* address) : address_(address) {} + + void Start() { + // Start server with 1-second handshake timeout. 
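+ // GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS bounds how long the server waits for
+ // the incoming connection's handshake (HTTP/2 connection preface and SETTINGS
+ // exchange) to finish; the raw TCP client below never sends it, so the server
+ // should drop the connection after roughly one second.
+ // (A hedged sketch of an equivalent construction, using the helper seen
+ // elsewhere in these tests:
+ //   grpc_arg arg = grpc_channel_arg_integer_create(
+ //       const_cast<char*>(GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS), 1000);
+ // the explicit field assignments below build the same channel arg.)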
+ grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + arg.key = const_cast(GRPC_ARG_SERVER_HANDSHAKE_TIMEOUT_MS); + arg.value.integer = 1000; + grpc_channel_args args = {1, &arg}; + server_ = grpc_server_create(&args, nullptr); + ASSERT_TRUE(grpc_server_add_insecure_http2_port(server_, address_)); + cq_ = grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(server_, cq_, nullptr); + grpc_server_start(server_); + thread_ = + absl::make_unique(std::bind(&ServerThread::Serve, this)); + } + + void Shutdown() { + grpc_completion_queue* shutdown_cq = + grpc_completion_queue_create_for_pluck(nullptr); + grpc_server_shutdown_and_notify(server_, shutdown_cq, nullptr); + GPR_ASSERT(grpc_completion_queue_pluck(shutdown_cq, nullptr, + grpc_timeout_seconds_to_deadline(1), + nullptr) + .type == GRPC_OP_COMPLETE); + grpc_completion_queue_destroy(shutdown_cq); + grpc_server_destroy(server_); + grpc_completion_queue_destroy(cq_); + thread_->join(); + } + + private: + void Serve() { + // The completion queue should not return anything other than shutdown. + grpc_event ev = grpc_completion_queue_next( + cq_, gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + ASSERT_EQ(GRPC_QUEUE_SHUTDOWN, ev.type); + } + + const char* address_; // Do not own. + grpc_server* server_ = nullptr; + grpc_completion_queue* cq_ = nullptr; + std::unique_ptr thread_; +}; + +// A TCP client that connects to the server, reads data until the server +// closes, and then terminates. +class Client { + public: + explicit Client(const char* server_address) + : server_address_(server_address) {} + + void Connect() { + grpc_core::ExecCtx exec_ctx; + grpc_resolved_addresses* server_addresses = nullptr; + grpc_error_handle error = + grpc_blocking_resolve_address(server_address_, "80", &server_addresses); + ASSERT_EQ(GRPC_ERROR_NONE, error) << grpc_error_std_string(error); + ASSERT_GE(server_addresses->naddrs, 1UL); + pollset_ = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset_, &mu_); + grpc_pollset_set* pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(pollset_set, pollset_); + EventState state; + grpc_tcp_client_connect( + state.closure(), &endpoint_, grpc_slice_allocator_create_unlimited(), + pollset_set, nullptr /* channel_args */, server_addresses->addrs, + grpc_core::ExecCtx::Get()->Now() + 1000); + ASSERT_TRUE(PollUntilDone( + &state, + grpc_timespec_to_millis_round_up(gpr_inf_future(GPR_CLOCK_MONOTONIC)))); + ASSERT_EQ(GRPC_ERROR_NONE, state.error()); + grpc_pollset_set_destroy(pollset_set); + grpc_endpoint_add_to_pollset(endpoint_, pollset_); + grpc_resolved_addresses_destroy(server_addresses); + } + + // Reads until an error is returned. + // Returns true if an error was encountered before the deadline. + bool ReadUntilError() { + grpc_core::ExecCtx exec_ctx; + grpc_slice_buffer read_buffer; + grpc_slice_buffer_init(&read_buffer); + bool retval = true; + // Use a deadline of 3 seconds, which is a lot more than we should + // need for a 1-second timeout, but this helps avoid flakes. 
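+ // grpc_millis deadlines are absolute, so the 3-second budget is expressed as
+ // ExecCtx::Now() + 3000 and later compared against Now() in PollUntilDone().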
+ grpc_millis deadline = grpc_core::ExecCtx::Get()->Now() + 3000; + while (true) { + EventState state; + grpc_endpoint_read(endpoint_, &read_buffer, state.closure(), + /*urgent=*/true); + if (!PollUntilDone(&state, deadline)) { + retval = false; + break; + } + if (state.error() != GRPC_ERROR_NONE) break; + gpr_log(GPR_INFO, "client read %" PRIuPTR " bytes", read_buffer.length); + grpc_slice_buffer_reset_and_unref_internal(&read_buffer); + } + grpc_endpoint_shutdown(endpoint_, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("shutdown")); + grpc_slice_buffer_destroy_internal(&read_buffer); + return retval; + } + + void Shutdown() { + grpc_core::ExecCtx exec_ctx; + grpc_endpoint_destroy(endpoint_); + grpc_pollset_shutdown(pollset_, + GRPC_CLOSURE_CREATE(&Client::PollsetDestroy, pollset_, + grpc_schedule_on_exec_ctx)); + } + + private: + // State used to wait for an I/O event. + class EventState { + public: + EventState() { + GRPC_CLOSURE_INIT(&closure_, &EventState::OnEventDone, this, + grpc_schedule_on_exec_ctx); + } + + ~EventState() { GRPC_ERROR_UNREF(error_); } + + grpc_closure* closure() { return &closure_; } + + bool done() const { return gpr_atm_acq_load(&done_atm_) != 0; } + + // Caller does NOT take ownership of the error. + grpc_error_handle error() const { return error_; } + + private: + static void OnEventDone(void* arg, grpc_error_handle error) { + gpr_log(GPR_INFO, "OnEventDone(): %s", + grpc_error_std_string(error).c_str()); + EventState* state = static_cast(arg); + state->error_ = GRPC_ERROR_REF(error); + gpr_atm_rel_store(&state->done_atm_, 1); + } + + grpc_closure closure_; + gpr_atm done_atm_ = 0; + grpc_error_handle error_ = GRPC_ERROR_NONE; + }; + + // Returns true if done, or false if deadline exceeded. + bool PollUntilDone(EventState* state, grpc_millis deadline) { + while (true) { + grpc_pollset_worker* worker = nullptr; + gpr_mu_lock(mu_); + GRPC_LOG_IF_ERROR( + "grpc_pollset_work", + grpc_pollset_work(pollset_, &worker, + grpc_core::ExecCtx::Get()->Now() + 100)); + // Flushes any work scheduled before or during polling. + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_unlock(mu_); + if (state != nullptr && state->done()) return true; + if (grpc_core::ExecCtx::Get()->Now() >= deadline) return false; + } + } + + static void PollsetDestroy(void* arg, grpc_error_handle /*error*/) { + grpc_pollset* pollset = static_cast(arg); + grpc_pollset_destroy(pollset); + gpr_free(pollset); + } + + const char* server_address_; // Do not own. + grpc_endpoint* endpoint_; + gpr_mu* mu_; + grpc_pollset* pollset_; +}; + +TEST(SettingsTimeout, Basic) { + // Construct server address string. + const int server_port = grpc_pick_unused_port_or_die(); + std::string server_address_string = absl::StrCat("localhost:", server_port); + // Start server. + gpr_log(GPR_INFO, "starting server on %s", server_address_string.c_str()); + ServerThread server_thread(server_address_string.c_str()); + server_thread.Start(); + // Create client and connect to server. + gpr_log(GPR_INFO, "starting client connect"); + Client client(server_address_string.c_str()); + client.Connect(); + // Client read. Should fail due to server dropping connection. + gpr_log(GPR_INFO, "starting client read"); + EXPECT_TRUE(client.ReadUntilError()); + // Shut down client. + gpr_log(GPR_INFO, "shutting down client"); + client.Shutdown(); + // Shut down server. + gpr_log(GPR_INFO, "shutting down server"); + server_thread.Shutdown(); + // Clean up. 
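+ // (Nothing further to release here: ServerThread and Client own their
+ // resources and were shut down above, and the address string is a local.)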
+} + +} // namespace +} // namespace test +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/transport/chttp2/stream_map_test.cc b/test/core/transport/chttp2/stream_map_test.cc new file mode 100644 index 00000000..3f63709c --- /dev/null +++ b/test/core/transport/chttp2/stream_map_test.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/transport/chttp2/transport/stream_map.h" + +#include + +#include "test/core/util/test_config.h" + +#define LOG_TEST(x) gpr_log(GPR_INFO, "%s", x) + +/* test creation & destruction */ +static void test_no_op(void) { + grpc_chttp2_stream_map map; + + LOG_TEST("test_no_op"); + + grpc_chttp2_stream_map_init(&map, 8); + grpc_chttp2_stream_map_destroy(&map); +} + +/* test lookup on an empty map */ +static void test_empty_find(void) { + grpc_chttp2_stream_map map; + + LOG_TEST("test_empty_find"); + + grpc_chttp2_stream_map_init(&map, 8); + GPR_ASSERT(nullptr == grpc_chttp2_stream_map_find(&map, 39128)); + grpc_chttp2_stream_map_destroy(&map); +} + +/* test add & lookup */ +static void test_basic_add_find(uint32_t n) { + grpc_chttp2_stream_map map; + uint32_t i; + size_t got; + + LOG_TEST("test_basic_add_find"); + gpr_log(GPR_INFO, "n = %d", n); + + grpc_chttp2_stream_map_init(&map, 8); + GPR_ASSERT(0 == grpc_chttp2_stream_map_size(&map)); + for (i = 1; i <= n; i++) { + grpc_chttp2_stream_map_add(&map, i, reinterpret_cast(i)); + } + GPR_ASSERT(n == grpc_chttp2_stream_map_size(&map)); + GPR_ASSERT(nullptr == grpc_chttp2_stream_map_find(&map, 0)); + GPR_ASSERT(nullptr == grpc_chttp2_stream_map_find(&map, n + 1)); + for (i = 1; i <= n; i++) { + got = reinterpret_cast(grpc_chttp2_stream_map_find(&map, i)); + GPR_ASSERT(i == got); + } + grpc_chttp2_stream_map_destroy(&map); +} + +/* verify that for_each gets the right values during test_delete_evens_XXX */ +static void verify_for_each(void* user_data, uint32_t stream_id, void* ptr) { + uint32_t* for_each_check = static_cast(user_data); + GPR_ASSERT(ptr); + GPR_ASSERT(*for_each_check == stream_id); + *for_each_check += 2; +} + +static void check_delete_evens(grpc_chttp2_stream_map* map, uint32_t n) { + uint32_t for_each_check = 1; + uint32_t i; + size_t got; + + GPR_ASSERT(nullptr == grpc_chttp2_stream_map_find(map, 0)); + GPR_ASSERT(nullptr == grpc_chttp2_stream_map_find(map, n + 1)); + for (i = 1; i <= n; i++) { + if (i & 1) { + got = reinterpret_cast(grpc_chttp2_stream_map_find(map, i)); + GPR_ASSERT(i == got); + } else { + GPR_ASSERT(nullptr == grpc_chttp2_stream_map_find(map, i)); + } + } + + grpc_chttp2_stream_map_for_each(map, verify_for_each, &for_each_check); + if (n & 1) { + GPR_ASSERT(for_each_check == n + 2); + } else { + GPR_ASSERT(for_each_check == n + 1); + } +} + +/* add a bunch of keys, delete the 
even ones, and make sure the map is + consistent */ +static void test_delete_evens_sweep(uint32_t n) { + grpc_chttp2_stream_map map; + uint32_t i; + + LOG_TEST("test_delete_evens_sweep"); + gpr_log(GPR_INFO, "n = %d", n); + + grpc_chttp2_stream_map_init(&map, 8); + for (i = 1; i <= n; i++) { + grpc_chttp2_stream_map_add(&map, i, reinterpret_cast(i)); + } + for (i = 1; i <= n; i++) { + if ((i & 1) == 0) { + GPR_ASSERT((void*)(uintptr_t)i == grpc_chttp2_stream_map_delete(&map, i)); + } + } + check_delete_evens(&map, n); + grpc_chttp2_stream_map_destroy(&map); +} + +/* add a bunch of keys, delete the even ones immediately, and make sure the map + is consistent */ +static void test_delete_evens_incremental(uint32_t n) { + grpc_chttp2_stream_map map; + uint32_t i; + + LOG_TEST("test_delete_evens_incremental"); + gpr_log(GPR_INFO, "n = %d", n); + + grpc_chttp2_stream_map_init(&map, 8); + for (i = 1; i <= n; i++) { + grpc_chttp2_stream_map_add(&map, i, reinterpret_cast(i)); + if ((i & 1) == 0) { + grpc_chttp2_stream_map_delete(&map, i); + } + } + check_delete_evens(&map, n); + grpc_chttp2_stream_map_destroy(&map); +} + +/* add a bunch of keys, delete old ones after some time, ensure the + backing array does not grow */ +static void test_periodic_compaction(uint32_t n) { + grpc_chttp2_stream_map map; + uint32_t i; + uint32_t del; + + LOG_TEST("test_periodic_compaction"); + gpr_log(GPR_INFO, "n = %d", n); + + grpc_chttp2_stream_map_init(&map, 16); + GPR_ASSERT(map.capacity == 16); + for (i = 1; i <= n; i++) { + grpc_chttp2_stream_map_add(&map, i, reinterpret_cast(i)); + if (i > 8) { + del = i - 8; + GPR_ASSERT((void*)(uintptr_t)del == + grpc_chttp2_stream_map_delete(&map, del)); + } + } + GPR_ASSERT(map.capacity == 16); + grpc_chttp2_stream_map_destroy(&map); +} + +int main(int argc, char** argv) { + uint32_t n = 1; + uint32_t prev = 1; + uint32_t tmp; + + grpc::testing::TestEnvironment env(argc, argv); + + test_no_op(); + test_empty_find(); + + while (n < 100000) { + test_basic_add_find(n); + test_delete_evens_sweep(n); + test_delete_evens_incremental(n); + test_periodic_compaction(n); + + tmp = n; + n += prev; + prev = tmp; + } + + return 0; +} diff --git a/test/core/transport/chttp2/too_many_pings_test.cc b/test/core/transport/chttp2/too_many_pings_test.cc new file mode 100644 index 00000000..d1a6ec7c --- /dev/null +++ b/test/core/transport/chttp2/too_many_pings_test.cc @@ -0,0 +1,838 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/security/credentials/alts/alts_credentials.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/core/lib/security/security_connector/alts/alts_security_connector.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/channel.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/memory_counters.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace { + +class TransportCounter { + public: + static void CounterInitCallback() { + absl::MutexLock lock(&mu()); + ++count_; + } + + static void CounterDestructCallback() { + absl::MutexLock lock(&mu()); + if (--count_ == 0) { + cv().SignalAll(); + } + } + + static void WaitForTransportsToBeDestroyed() { + absl::MutexLock lock(&mu()); + while (count_ != 0) { + ASSERT_FALSE(cv().WaitWithTimeout(&mu(), absl::Seconds(10))); + } + } + + static int count() { + absl::MutexLock lock(&mu()); + return count_; + } + + static absl::Mutex& mu() { + static absl::Mutex* mu = new absl::Mutex(); + return *mu; + } + + static absl::CondVar& cv() { + static absl::CondVar* cv = new absl::CondVar(); + return *cv; + } + + private: + static int count_; +}; + +int TransportCounter::count_ = 0; + +void* tag(intptr_t t) { return reinterpret_cast(t); } + +// Perform a simple RPC where the server cancels the request with +// grpc_call_cancel_with_status +grpc_status_code PerformCall(grpc_channel* channel, grpc_server* server, + grpc_completion_queue* cq) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5); + // Start a call + c = grpc_channel_create_call(channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + // Request a call on the server + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, tag(101)); + 
GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + grpc_call_cancel_with_status(s, GRPC_STATUS_PERMISSION_DENIED, "test status", + nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + // cleanup + grpc_slice_unref(details); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(c); + grpc_call_unref(s); + cq_verifier_destroy(cqv); + return status; +} + +// Test that sending a lot of RPCs that are cancelled by the server doesn't +// result in too many pings due to the pings sent by BDP. +TEST(TooManyPings, TestLotsOfServerCancelledRpcsDoesntGiveTooManyPings) { + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + // create the server + grpc_server* server = grpc_server_create(nullptr, nullptr); + std::string server_address = + grpc_core::JoinHostPort("localhost", grpc_pick_unused_port_or_die()); + grpc_server_register_completion_queue(server, cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(server, server_address.c_str())); + grpc_server_start(server); + // create the channel (bdp pings are enabled by default) + grpc_channel* channel = grpc_insecure_channel_create( + server_address.c_str(), nullptr /* channel args */, nullptr); + std::map statuses_and_counts; + const int kNumTotalRpcs = 1e5; + // perform an RPC + gpr_log(GPR_INFO, + "Performing %d total RPCs and expecting them all to receive status " + "PERMISSION_DENIED (%d)", + kNumTotalRpcs, GRPC_STATUS_PERMISSION_DENIED); + for (int i = 0; i < kNumTotalRpcs; i++) { + grpc_status_code status = PerformCall(channel, server, cq); + statuses_and_counts[status] += 1; + } + int num_not_cancelled = 0; + for (auto itr = statuses_and_counts.begin(); itr != statuses_and_counts.end(); + itr++) { + if (itr->first != GRPC_STATUS_PERMISSION_DENIED) { + num_not_cancelled += itr->second; + } + gpr_log(GPR_INFO, "%d / %d RPCs received status code: %d", itr->second, + kNumTotalRpcs, itr->first); + } + if (num_not_cancelled > 0) { + gpr_log(GPR_ERROR, + "Expected all RPCs to receive status PERMISSION_DENIED (%d) but %d " + "received other status codes", + GRPC_STATUS_PERMISSION_DENIED, num_not_cancelled); + FAIL(); + } + // shutdown and destroy the client and server + grpc_channel_destroy(channel); + grpc_server_shutdown_and_notify(server, cq, nullptr); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_server_destroy(server); + grpc_completion_queue_destroy(cq); +} + +// Perform a simple RPC where the client makes a request, and both the client +// and server continue reading so that gRPC can send and receive keepalive +// pings. 
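+// Illustrative sketch of the timeline the keepalive-throttling tests below
+// drive through this helper (based on the channel/server args they configure;
+// not part of the original change):
+//   server: GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS = 5000,
+//           GRPC_ARG_HTTP2_MAX_PING_STRIKES = 1
+//   client: GRPC_ARG_KEEPALIVE_TIME_MS = 1000, doubled after each
+//           "too_many_pings" GOAWAY: 1s -> 2s -> 4s -> 8s
+// While the client keepalive time is below the server's 5s minimum, the call
+// is expected to end with GRPC_STATUS_UNAVAILABLE; once the keepalive time is
+// throttled past 5s, no bad ping goes out and the call instead runs into the
+// 15-second deadline below, returning GRPC_STATUS_DEADLINE_EXCEEDED.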
+grpc_status_code PerformWaitingCall(grpc_channel* channel, grpc_server* server, + grpc_completion_queue* cq) { + grpc_call* c; + grpc_call* s; + cq_verifier* cqv = cq_verifier_create(cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(15); + // Start a call + c = grpc_channel_create_call(channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + // Request a call on the server + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + // Since the server is configured to allow only a single ping strike, it would + // take 3 pings to trigger the GOAWAY frame with "too_many_pings" from the + // server. (The second ping from the client would be the first bad ping sent + // too quickly leading to a ping strike and the third ping would lead to the + // GOAWAY.) If the client settings match with the server's settings, there + // won't be a bad ping, and the call will end due to the deadline expiring + // instead. + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + // The call will end after this + cq_verify(cqv, 60); + // cleanup + grpc_slice_unref(details); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(c); + grpc_call_unref(s); + cq_verifier_destroy(cqv); + return status; +} + +// Shuts down and destroys the server. +void ServerShutdownAndDestroy(grpc_server* server, grpc_completion_queue* cq) { + // Shutdown and destroy server + grpc_server_shutdown_and_notify(server, cq, reinterpret_cast(1000)); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .tag != reinterpret_cast(1000)) { + } + grpc_server_destroy(server); +} + +void VerifyChannelReady(grpc_channel* channel, grpc_completion_queue* cq) { + grpc_connectivity_state state = + grpc_channel_check_connectivity_state(channel, 1 /* try_to_connect */); + while (state != GRPC_CHANNEL_READY) { + grpc_channel_watch_connectivity_state( + channel, state, grpc_timeout_seconds_to_deadline(5), cq, nullptr); + grpc_completion_queue_next(cq, grpc_timeout_seconds_to_deadline(5), + nullptr); + state = grpc_channel_check_connectivity_state(channel, 0); + } +} + +void VerifyChannelDisconnected(grpc_channel* channel, + grpc_completion_queue* cq) { + // Verify channel gets disconnected. 
Use a ping to make sure that clients + // tries sending/receiving bytes if the channel is connected. + grpc_channel_ping(channel, cq, reinterpret_cast(2000), nullptr); + grpc_event ev = grpc_completion_queue_next( + cq, grpc_timeout_seconds_to_deadline(5), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == reinterpret_cast(2000)); + GPR_ASSERT(ev.success == 0); + GPR_ASSERT(grpc_channel_check_connectivity_state(channel, 0) != + GRPC_CHANNEL_READY); +} + +class KeepaliveThrottlingTest : public ::testing::Test { + protected: + // Starts the server and makes sure that the channel is able to get connected. + grpc_server* ServerStart(const char* addr, grpc_completion_queue* cq) { + // Set up server channel args to expect pings at an interval of 5 seconds + // and use a single ping strike + grpc_arg server_args[] = { + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), + 5 * 1000), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PING_STRIKES), 1)}; + grpc_channel_args server_channel_args = {GPR_ARRAY_SIZE(server_args), + server_args}; + // Create server + grpc_server* server = grpc_server_create(&server_channel_args, nullptr); + grpc_server_register_completion_queue(server, cq, nullptr); + GPR_ASSERT(grpc_server_add_insecure_http2_port(server, addr)); + grpc_server_start(server); + return server; + } +}; + +TEST_F(KeepaliveThrottlingTest, KeepaliveThrottlingMultipleChannels) { + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + std::string server_address = + grpc_core::JoinHostPort("127.0.0.1", grpc_pick_unused_port_or_die()); + grpc_server* server = ServerStart(server_address.c_str(), cq); + // create two channel with a keepalive ping interval of 1 second. + grpc_arg client_args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_TIME_MS), 1 * 1000), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_BDP_PROBE), 0)}; + grpc_channel_args client_channel_args = {GPR_ARRAY_SIZE(client_args), + client_args}; + grpc_channel* channel = grpc_insecure_channel_create( + server_address.c_str(), &client_channel_args, nullptr); + grpc_channel* channel_dup = grpc_insecure_channel_create( + server_address.c_str(), &client_channel_args, nullptr); + int expected_keepalive_time_sec = 1; + // We need 3 GOAWAY frames to throttle the keepalive time from 1 second to 8 + // seconds (> 5sec). + for (int i = 0; i < 3; i++) { + gpr_log(GPR_INFO, "Expected keepalive time : %d", + expected_keepalive_time_sec); + EXPECT_EQ(PerformWaitingCall(channel, server, cq), GRPC_STATUS_UNAVAILABLE); + expected_keepalive_time_sec *= 2; + } + gpr_log( + GPR_INFO, + "Client keepalive time %d should now be in sync with the server settings", + expected_keepalive_time_sec); + EXPECT_EQ(PerformWaitingCall(channel, server, cq), + GRPC_STATUS_DEADLINE_EXCEEDED); + // Since the subchannel is shared, the second channel should also have + // keepalive settings in sync with the server. 
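+  // (Illustrative: channel and channel_dup were created with the same target
+  // and identical channel args, so they share a subchannel; the throttled
+  // keepalive time recorded there is what channel_dup's transport uses too,
+  // hence the DEADLINE_EXCEEDED expectation below rather than UNAVAILABLE.)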
+ gpr_log(GPR_INFO, "Now testing second channel sharing the same subchannel"); + EXPECT_EQ(PerformWaitingCall(channel_dup, server, cq), + GRPC_STATUS_DEADLINE_EXCEEDED); + // shutdown and destroy the client and server + grpc_channel_destroy(channel); + grpc_channel_destroy(channel_dup); + ServerShutdownAndDestroy(server, cq); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); +} + +grpc_core::Resolver::Result BuildResolverResult( + const std::vector& addresses) { + grpc_core::Resolver::Result result; + for (const auto& address_str : addresses) { + absl::StatusOr uri = grpc_core::URI::Parse(address_str); + if (!uri.ok()) { + gpr_log(GPR_ERROR, "Failed to parse uri. Error: %s", + uri.status().ToString().c_str()); + GPR_ASSERT(uri.ok()); + } + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*uri, &address)); + result.addresses.emplace_back(address.addr, address.len, nullptr); + } + return result; +} + +// Tests that when new subchannels are created due to a change in resolved +// addresses, the new subchannels use the updated keepalive time. +TEST_F(KeepaliveThrottlingTest, NewSubchannelsUseUpdatedKeepaliveTime) { + grpc_core::ExecCtx exec_ctx; + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + std::string server_address1 = + grpc_core::JoinHostPort("127.0.0.1", grpc_pick_unused_port_or_die()); + std::string server_address2 = + grpc_core::JoinHostPort("127.0.0.1", grpc_pick_unused_port_or_die()); + grpc_server* server1 = ServerStart(server_address1.c_str(), cq); + grpc_server* server2 = ServerStart(server_address2.c_str(), cq); + // create a single channel with multiple subchannels with a keepalive ping + // interval of 1 second. To get finer control on subchannel connection times, + // we are using pick_first instead of round_robin and using the fake resolver + // response generator to switch between the two. + auto response_generator = + grpc_core::MakeRefCounted(); + grpc_arg client_args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_TIME_MS), 1 * 1000), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_BDP_PROBE), 0), + grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + response_generator.get())}; + grpc_channel_args client_channel_args = {GPR_ARRAY_SIZE(client_args), + client_args}; + grpc_channel* channel = + grpc_insecure_channel_create("fake:///", &client_channel_args, nullptr); + // For a single subchannel 3 GOAWAYs would be sufficient to increase the + // keepalive time from 1 second to beyond 5 seconds. Even though we are + // alternating between two subchannels, 3 GOAWAYs should still be enough since + // the channel should start all new transports with the new keepalive value + // (even those from a different subchannel). + int expected_keepalive_time_sec = 1; + for (int i = 0; i < 3; i++) { + gpr_log(GPR_INFO, "Expected keepalive time : %d", + expected_keepalive_time_sec); + response_generator->SetResponse(BuildResolverResult({absl::StrCat( + "ipv4:", i % 2 == 0 ? server_address1 : server_address2)})); + // ExecCtx::Flush() might not be enough to make sure that the resolver + // result has been propagated, so sleep for a bit. 
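+    // (Illustrative: the intent of the Flush plus the 1-second sleep is only
+    // that pick_first has switched to the address returned above before the
+    // next PerformWaitingCall() is issued against that server.)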
+ grpc_core::ExecCtx::Get()->Flush(); + gpr_sleep_until(grpc_timeout_seconds_to_deadline(1)); + EXPECT_EQ(PerformWaitingCall(channel, i % 2 == 0 ? server1 : server2, cq), + GRPC_STATUS_UNAVAILABLE); + expected_keepalive_time_sec *= 2; + } + gpr_log( + GPR_INFO, + "Client keepalive time %d should now be in sync with the server settings", + expected_keepalive_time_sec); + response_generator->SetResponse( + BuildResolverResult({absl::StrCat("ipv4:", server_address2)})); + grpc_core::ExecCtx::Get()->Flush(); + gpr_sleep_until(grpc_timeout_seconds_to_deadline(1)); + EXPECT_EQ(PerformWaitingCall(channel, server2, cq), + GRPC_STATUS_DEADLINE_EXCEEDED); + // shutdown and destroy the client and server + grpc_channel_destroy(channel); + ServerShutdownAndDestroy(server1, cq); + ServerShutdownAndDestroy(server2, cq); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); +} + +// Tests that when a channel has multiple subchannels and receives a GOAWAY with +// "too_many_pings" on one of them, all subchannels start any new transports +// with an updated keepalive time. +TEST_F(KeepaliveThrottlingTest, + ExistingSubchannelsUseNewKeepaliveTimeWhenReconnecting) { + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + std::string server_address1 = + grpc_core::JoinHostPort("127.0.0.1", grpc_pick_unused_port_or_die()); + std::string server_address2 = + grpc_core::JoinHostPort("127.0.0.1", grpc_pick_unused_port_or_die()); + // create a single channel with round robin load balancing policy. + auto response_generator = + grpc_core::MakeRefCounted(); + grpc_arg client_args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_TIME_MS), 1 * 1000), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_BDP_PROBE), 0), + grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + response_generator.get())}; + grpc_channel_args client_channel_args = {GPR_ARRAY_SIZE(client_args), + client_args}; + grpc_channel* channel = + grpc_insecure_channel_create("fake:///", &client_channel_args, nullptr); + response_generator->SetResponse( + BuildResolverResult({absl::StrCat("ipv4:", server_address1), + absl::StrCat("ipv4:", server_address2)})); + // For a single subchannel 3 GOAWAYs would be sufficient to increase the + // keepalive time from 1 second to beyond 5 seconds. Even though we are + // alternating between two subchannels, 3 GOAWAYs should still be enough since + // the channel should start all new transports with the new keepalive value + // (even those from a different subchannel). + int expected_keepalive_time_sec = 1; + for (int i = 0; i < 3; i++) { + gpr_log(GPR_ERROR, "Expected keepalive time : %d", + expected_keepalive_time_sec); + grpc_server* server = ServerStart( + i % 2 == 0 ? 
server_address1.c_str() : server_address2.c_str(), cq); + VerifyChannelReady(channel, cq); + EXPECT_EQ(PerformWaitingCall(channel, server, cq), GRPC_STATUS_UNAVAILABLE); + ServerShutdownAndDestroy(server, cq); + VerifyChannelDisconnected(channel, cq); + expected_keepalive_time_sec *= 2; + } + gpr_log( + GPR_INFO, + "Client keepalive time %d should now be in sync with the server settings", + expected_keepalive_time_sec); + grpc_server* server = ServerStart(server_address1.c_str(), cq); + VerifyChannelReady(channel, cq); + EXPECT_EQ(PerformWaitingCall(channel, server, cq), + GRPC_STATUS_DEADLINE_EXCEEDED); + ServerShutdownAndDestroy(server, cq); + // shutdown and destroy the client + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); +} + +// Perform a simple RPC where the client makes a request expecting a response +// with payload. +void PerformCallWithResponsePayload(grpc_channel* channel, grpc_server* server, + grpc_completion_queue* cq) { + grpc_slice response_payload_slice = grpc_slice_from_static_string("hello"); + + grpc_call* c; + grpc_call* s; + grpc_byte_buffer* response_payload = + grpc_raw_byte_buffer_create(&response_payload_slice, 1); + cq_verifier* cqv = cq_verifier_create(cq); + grpc_op ops[6]; + grpc_op* op; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_call_details call_details; + grpc_status_code status; + grpc_call_error error; + grpc_slice details; + int was_cancelled = 2; + + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(60); + c = grpc_channel_create_call(channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, + deadline, nullptr); + GPR_ASSERT(c); + + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details_init(&call_details); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(c, ops, static_cast(op - ops), tag(1), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + error = grpc_server_request_call(server, &s, &call_details, + &request_metadata_recv, cq, cq, tag(101)); + GPR_ASSERT(GRPC_CALL_OK == error); + CQ_EXPECT_COMPLETION(cqv, tag(101), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = 
nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(102), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(102), 1); + cq_verify(cqv); + + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_CLOSE_ON_SERVER; + op->data.recv_close_on_server.cancelled = &was_cancelled; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = response_payload; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = 0; + op->data.send_status_from_server.status = GRPC_STATUS_OK; + grpc_slice status_details = grpc_slice_from_static_string("xyz"); + op->data.send_status_from_server.status_details = &status_details; + op->flags = 0; + op->reserved = nullptr; + op++; + error = grpc_call_start_batch(s, ops, static_cast(op - ops), tag(103), + nullptr); + GPR_ASSERT(GRPC_CALL_OK == error); + + CQ_EXPECT_COMPLETION(cqv, tag(103), 1); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(0 == grpc_slice_str_cmp(details, "xyz")); + GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/foo")); + GPR_ASSERT(was_cancelled == 0); + GPR_ASSERT( + byte_buffer_eq_slice(response_payload_recv, response_payload_slice)); + + grpc_slice_unref(details); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + + grpc_call_unref(c); + grpc_call_unref(s); + + cq_verifier_destroy(cqv); + + grpc_byte_buffer_destroy(response_payload); + grpc_byte_buffer_destroy(response_payload_recv); +} + +TEST(TooManyPings, BdpPingNotSentWithoutReceiveSideActivity) { + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + // create the server + std::string server_address = + grpc_core::JoinHostPort("localhost", grpc_pick_unused_port_or_die()); + grpc_arg server_args[] = { + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), + 60 * 1000), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PING_STRIKES), 1)}; + grpc_channel_args server_channel_args = {GPR_ARRAY_SIZE(server_args), + server_args}; + grpc_server* server = grpc_server_create(&server_channel_args, nullptr); + grpc_server_register_completion_queue(server, cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(server, server_address.c_str())); + grpc_server_start(server); + // create the channel (bdp pings are enabled by default) + grpc_arg client_args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), 1)}; + grpc_channel_args client_channel_args = {GPR_ARRAY_SIZE(client_args), + client_args}; + grpc_channel* channel = grpc_insecure_channel_create( + server_address.c_str(), &client_channel_args, nullptr); + VerifyChannelReady(channel, cq); + EXPECT_EQ(TransportCounter::count(), 2 /* one each for server and client */); + cq_verifier* cqv = cq_verifier_create(cq); + // Channel should be able to send two pings without disconnect if there was no + // BDP sent. 
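+  // Illustrative note (not part of the original change): with the server args
+  // above (60s minimum ping interval without data, max_ping_strikes = 1) it
+  // takes roughly three data-less pings in a row to draw a
+  // GOAWAY("too_many_pings"). The two pings here (tags 1 and 2) therefore
+  // only succeed if the transport has not already slipped in a BDP ping, and
+  // once the RPC below legitimately triggers one, the later pings (tags 3
+  // and 4) are what finally disconnect the channel.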
+ grpc_channel_ping(channel, cq, tag(1), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv, 5); + // Second ping + grpc_channel_ping(channel, cq, tag(2), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv, 5); + ASSERT_EQ(grpc_channel_check_connectivity_state(channel, 0), + GRPC_CHANNEL_READY); + PerformCallWithResponsePayload(channel, server, cq); + // Wait a bit to make sure that the BDP ping goes out. + cq_verify_empty_timeout(cqv, 1); + // The call with a response payload should have triggered a BDP ping. + // Send two more pings to verify. The second ping should cause a disconnect. + // If BDP was not sent, the second ping would not cause a disconnect. + grpc_channel_ping(channel, cq, tag(3), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(3), 1); + cq_verify(cqv, 5); + // Second ping + grpc_channel_ping(channel, cq, tag(4), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(4), 1); + cq_verify(cqv, 5); + // Make sure that the transports have been destroyed + VerifyChannelDisconnected(channel, cq); + TransportCounter::WaitForTransportsToBeDestroyed(); + cq_verifier_destroy(cqv); + // shutdown and destroy the client and server + ServerShutdownAndDestroy(server, cq); + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); +} + +TEST(TooManyPings, TransportsGetCleanedUpOnDisconnect) { + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + // create the client and server + std::string server_address = + grpc_core::JoinHostPort("localhost", grpc_pick_unused_port_or_die()); + grpc_arg server_args[] = { + grpc_channel_arg_integer_create( + const_cast( + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), + 60 * 1000), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PING_STRIKES), 1)}; + grpc_channel_args server_channel_args = {GPR_ARRAY_SIZE(server_args), + server_args}; + grpc_server* server = grpc_server_create(&server_channel_args, nullptr); + grpc_server_register_completion_queue(server, cq, nullptr); + GPR_ASSERT( + grpc_server_add_insecure_http2_port(server, server_address.c_str())); + grpc_server_start(server); + grpc_arg client_args[] = { + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), 0), + grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), 1)}; + grpc_channel_args client_channel_args = {GPR_ARRAY_SIZE(client_args), + client_args}; + grpc_channel* channel = grpc_insecure_channel_create( + server_address.c_str(), &client_channel_args, nullptr); + VerifyChannelReady(channel, cq); + EXPECT_EQ(TransportCounter::count(), 2 /* one each for server and client */); + cq_verifier* cqv = cq_verifier_create(cq); + // First ping + grpc_channel_ping(channel, cq, tag(1), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(1), 1); + cq_verify(cqv, 5); + // Second ping + grpc_channel_ping(channel, cq, tag(2), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv, 5); + // Third ping caused disconnect + grpc_channel_ping(channel, cq, tag(2), nullptr); + CQ_EXPECT_COMPLETION(cqv, tag(2), 1); + cq_verify(cqv, 5); + // Make sure that the transports have been destroyed + VerifyChannelDisconnected(channel, cq); + TransportCounter::WaitForTransportsToBeDestroyed(); + cq_verifier_destroy(cqv); + // shutdown and destroy the client and server + ServerShutdownAndDestroy(server, cq); + 
grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + while (grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr) + .type != GRPC_QUEUE_SHUTDOWN) { + } + grpc_completion_queue_destroy(cq); +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_core::TestOnlySetGlobalHttp2TransportInitCallback( + TransportCounter::CounterInitCallback); + grpc_core::TestOnlySetGlobalHttp2TransportDestructCallback( + TransportCounter::CounterDestructCallback); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/transport/chttp2/varint_test.cc b/test/core/transport/chttp2/varint_test.cc new file mode 100644 index 00000000..400529b4 --- /dev/null +++ b/test/core/transport/chttp2/varint_test.cc @@ -0,0 +1,57 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/ext/transport/chttp2/transport/varint.h" + +#include +#include +#include + +#include "test/core/util/test_config.h" + +template +static void test_varint(uint32_t value, uint8_t prefix_or, + const char* expect_bytes, size_t expect_length) { + grpc_core::VarintWriter w(value); + grpc_slice expect = + grpc_slice_from_copied_buffer(expect_bytes, expect_length); + grpc_slice slice; + gpr_log(GPR_DEBUG, "Test: 0x%08x", value); + GPR_ASSERT(w.length() == expect_length); + slice = grpc_slice_malloc(w.length()); + w.Write(prefix_or, GRPC_SLICE_START_PTR(slice)); + GPR_ASSERT(grpc_slice_eq(expect, slice)); + grpc_slice_unref(expect); + grpc_slice_unref(slice); +} + +#define TEST_VARINT(value, prefix_bits, prefix_or, expect) \ + test_varint(value, prefix_or, expect, sizeof(expect) - 1) + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + TEST_VARINT(0, 1, 0, "\x00"); + TEST_VARINT(128, 1, 0, "\x7f\x01"); + TEST_VARINT(16384, 1, 0, "\x7f\x81\x7f"); + TEST_VARINT(2097152, 1, 0, "\x7f\x81\xff\x7f"); + TEST_VARINT(268435456, 1, 0, "\x7f\x81\xff\xff\x7f"); + TEST_VARINT(0xffffffff, 1, 0, "\x7f\x80\xff\xff\xff\x0f"); + grpc_shutdown(); + return 0; +} diff --git a/test/core/transport/connectivity_state_test.cc b/test/core/transport/connectivity_state_test.cc new file mode 100644 index 00000000..21ea0114 --- /dev/null +++ b/test/core/transport/connectivity_state_test.cc @@ -0,0 +1,244 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/transport/connectivity_state.h" + +#include + +#include + +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/util/test_config.h" +#include "test/core/util/tracer_util.h" + +namespace grpc_core { +namespace { + +TEST(ConnectivityStateName, Basic) { + EXPECT_STREQ("IDLE", ConnectivityStateName(GRPC_CHANNEL_IDLE)); + EXPECT_STREQ("CONNECTING", ConnectivityStateName(GRPC_CHANNEL_CONNECTING)); + EXPECT_STREQ("READY", ConnectivityStateName(GRPC_CHANNEL_READY)); + EXPECT_STREQ("TRANSIENT_FAILURE", + ConnectivityStateName(GRPC_CHANNEL_TRANSIENT_FAILURE)); + EXPECT_STREQ("SHUTDOWN", ConnectivityStateName(GRPC_CHANNEL_SHUTDOWN)); +} + +class Watcher : public ConnectivityStateWatcherInterface { + public: + Watcher(int* count, grpc_connectivity_state* output, absl::Status* status, + bool* destroyed = nullptr) + : count_(count), + output_(output), + status_(status), + destroyed_(destroyed) {} + + ~Watcher() override { + if (destroyed_ != nullptr) *destroyed_ = true; + } + + void Notify(grpc_connectivity_state new_state, + const absl::Status& status) override { + ++*count_; + *output_ = new_state; + *status_ = status; + } + + private: + int* count_; + grpc_connectivity_state* output_; + absl::Status* status_; + bool* destroyed_; +}; + +TEST(StateTracker, SetAndGetState) { + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_CONNECTING, + absl::Status()); + EXPECT_EQ(tracker.state(), GRPC_CHANNEL_CONNECTING); + EXPECT_TRUE(tracker.status().ok()); + tracker.SetState(GRPC_CHANNEL_READY, absl::Status(), "whee"); + EXPECT_EQ(tracker.state(), GRPC_CHANNEL_READY); + EXPECT_TRUE(tracker.status().ok()); + absl::Status transient_failure_status(absl::StatusCode::kUnavailable, + "status for testing"); + tracker.SetState(GRPC_CHANNEL_TRANSIENT_FAILURE, transient_failure_status, + "reason"); + EXPECT_EQ(tracker.state(), GRPC_CHANNEL_TRANSIENT_FAILURE); + EXPECT_EQ(tracker.status(), transient_failure_status); +} + +TEST(StateTracker, NotificationUponAddingWatcher) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + absl::Status status; + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_CONNECTING); + tracker.AddWatcher(GRPC_CHANNEL_IDLE, + MakeOrphanable(&count, &state, &status)); + EXPECT_EQ(count, 1); + EXPECT_EQ(state, GRPC_CHANNEL_CONNECTING); + EXPECT_TRUE(status.ok()); +} + +TEST(StateTracker, NotificationUponAddingWatcherWithTransientFailure) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + absl::Status status; + absl::Status transient_failure_status(absl::StatusCode::kUnavailable, + "status for testing"); + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_TRANSIENT_FAILURE, + transient_failure_status); + tracker.AddWatcher(GRPC_CHANNEL_IDLE, + MakeOrphanable(&count, &state, &status)); + EXPECT_EQ(count, 1); + EXPECT_EQ(state, GRPC_CHANNEL_TRANSIENT_FAILURE); + EXPECT_EQ(status, transient_failure_status); +} + +TEST(StateTracker, NotificationUponStateChange) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + absl::Status status; + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_IDLE); + tracker.AddWatcher(GRPC_CHANNEL_IDLE, + MakeOrphanable(&count, &state, &status)); + EXPECT_EQ(count, 0); + EXPECT_EQ(state, GRPC_CHANNEL_IDLE); + EXPECT_TRUE(status.ok()); + absl::Status transient_failure_status(absl::StatusCode::kUnavailable, + "status for testing"); + 
tracker.SetState(GRPC_CHANNEL_TRANSIENT_FAILURE, transient_failure_status, + "whee"); + EXPECT_EQ(count, 1); + EXPECT_EQ(state, GRPC_CHANNEL_TRANSIENT_FAILURE); + EXPECT_EQ(status, transient_failure_status); +} + +TEST(StateTracker, SubscribeThenUnsubscribe) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + absl::Status status; + bool destroyed = false; + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_IDLE); + ConnectivityStateWatcherInterface* watcher = + new Watcher(&count, &state, &status, &destroyed); + tracker.AddWatcher(GRPC_CHANNEL_IDLE, + OrphanablePtr(watcher)); + // No initial notification, since we started the watch from the + // current state. + EXPECT_EQ(count, 0); + EXPECT_EQ(state, GRPC_CHANNEL_IDLE); + EXPECT_TRUE(status.ok()); + // Cancel watch. This should not generate another notification. + tracker.RemoveWatcher(watcher); + EXPECT_TRUE(destroyed); + EXPECT_EQ(count, 0); + EXPECT_EQ(state, GRPC_CHANNEL_IDLE); + EXPECT_TRUE(status.ok()); +} + +TEST(StateTracker, OrphanUponShutdown) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + absl::Status status; + bool destroyed = false; + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_IDLE); + ConnectivityStateWatcherInterface* watcher = + new Watcher(&count, &state, &status, &destroyed); + tracker.AddWatcher(GRPC_CHANNEL_IDLE, + OrphanablePtr(watcher)); + // No initial notification, since we started the watch from the + // current state. + EXPECT_EQ(count, 0); + EXPECT_EQ(state, GRPC_CHANNEL_IDLE); + EXPECT_TRUE(status.ok()); + // Set state to SHUTDOWN. + tracker.SetState(GRPC_CHANNEL_SHUTDOWN, absl::Status(), "shutting down"); + EXPECT_TRUE(destroyed); + EXPECT_EQ(count, 1); + EXPECT_EQ(state, GRPC_CHANNEL_SHUTDOWN); + EXPECT_TRUE(status.ok()); +} + +TEST(StateTracker, AddWhenAlreadyShutdown) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + absl::Status status; + bool destroyed = false; + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_SHUTDOWN, + absl::Status()); + ConnectivityStateWatcherInterface* watcher = + new Watcher(&count, &state, &status, &destroyed); + tracker.AddWatcher(GRPC_CHANNEL_IDLE, + OrphanablePtr(watcher)); + EXPECT_TRUE(destroyed); + EXPECT_EQ(count, 1); + EXPECT_EQ(state, GRPC_CHANNEL_SHUTDOWN); + EXPECT_TRUE(status.ok()); +} + +TEST(StateTracker, NotifyShutdownAtDestruction) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_IDLE; + absl::Status status; + { + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_IDLE); + tracker.AddWatcher(GRPC_CHANNEL_IDLE, + MakeOrphanable(&count, &state, &status)); + // No initial notification, since we started the watch from the + // current state. + EXPECT_EQ(count, 0); + EXPECT_EQ(state, GRPC_CHANNEL_IDLE); + } + // Upon tracker destruction, we get a notification for SHUTDOWN. + EXPECT_EQ(count, 1); + EXPECT_EQ(state, GRPC_CHANNEL_SHUTDOWN); +} + +TEST(StateTracker, DoNotNotifyShutdownAtDestructionIfAlreadyInShutdown) { + int count = 0; + grpc_connectivity_state state = GRPC_CHANNEL_SHUTDOWN; + absl::Status status; + { + ConnectivityStateTracker tracker("xxx", GRPC_CHANNEL_SHUTDOWN); + tracker.AddWatcher(GRPC_CHANNEL_SHUTDOWN, + MakeOrphanable(&count, &state, &status)); + // No initial notification, since we started the watch from the + // current state. + EXPECT_EQ(count, 0); + EXPECT_EQ(state, GRPC_CHANNEL_SHUTDOWN); + } + // No additional notification upon tracker destruction, since we were + // already in state SHUTDOWN. 
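+  // (Illustrative: contrast with NotifyShutdownAtDestruction above, where the
+  // tracker was still in IDLE and its destruction produced exactly one
+  // SHUTDOWN notification.)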
+ EXPECT_EQ(count, 0); + EXPECT_EQ(state, GRPC_CHANNEL_SHUTDOWN); +} + +} // namespace +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + grpc_core::testing::grpc_tracer_enable_flag( + &grpc_core::grpc_connectivity_state_trace); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/transport/error_utils_test.cc b/test/core/transport/error_utils_test.cc new file mode 100644 index 00000000..ebbb9f96 --- /dev/null +++ b/test/core/transport/error_utils_test.cc @@ -0,0 +1,94 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "src/core/lib/transport/error_utils.h" + +#include + +#include "absl/status/status.h" + +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +namespace { + +// ---- Ok Status ---- +TEST(ErrorUtilsTest, AbslOkToGrpcError) { + grpc_error_handle error = absl_status_to_grpc_error(absl::OkStatus()); + ASSERT_EQ(GRPC_ERROR_NONE, error); + GRPC_ERROR_UNREF(error); +} + +TEST(ErrorUtilsTest, GrpcSpecialErrorNoneToAbslStatus) { + absl::Status status = grpc_error_to_absl_status(GRPC_ERROR_NONE); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(status.message(), ""); +} + +// ---- Asymmetry of conversions of "Special" errors ---- +TEST(ErrorUtilsTest, AbslStatusToGrpcErrorDoesNotReturnSpecialVariables) { + grpc_error_handle error = + absl_status_to_grpc_error(absl::CancelledError("CANCELLED")); + ASSERT_NE(error, GRPC_ERROR_CANCELLED); + GRPC_ERROR_UNREF(error); +} + +TEST(ErrorUtilsTest, GrpcSpecialErrorCancelledToAbslStatus) { + absl::Status status = grpc_error_to_absl_status(GRPC_ERROR_CANCELLED); + ASSERT_TRUE(absl::IsCancelled(status)); + ASSERT_EQ(status.message(), "CANCELLED"); +} + +TEST(ErrorUtilsTest, GrpcSpecialErrorOOMToAbslStatus) { + absl::Status status = grpc_error_to_absl_status(GRPC_ERROR_OOM); + ASSERT_TRUE(absl::IsResourceExhausted(status)); + ASSERT_EQ(status.message(), "RESOURCE_EXHAUSTED"); +} + +// ---- Ordinary statuses ---- +TEST(ErrorUtilsTest, AbslUnavailableToGrpcError) { + grpc_error_handle error = + absl_status_to_grpc_error(absl::UnavailableError("Making tea")); + // Status code checks + intptr_t code; + ASSERT_TRUE(grpc_error_get_int(error, GRPC_ERROR_INT_GRPC_STATUS, &code)); + ASSERT_EQ(static_cast(code), GRPC_STATUS_UNAVAILABLE); + // Status message checks + std::string message; + ASSERT_TRUE(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, &message)); + ASSERT_EQ(message, "Making tea"); + GRPC_ERROR_UNREF(error); +} + +TEST(ErrorUtilsTest, GrpcErrorUnavailableToAbslStatus) { + grpc_error_handle error = grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "weighted_target: all children report state TRANSIENT_FAILURE"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE); + absl::Status status = grpc_error_to_absl_status(error); + ASSERT_TRUE(absl::IsUnavailable(status)); + 
ASSERT_EQ(status.message(), + "weighted_target: all children report state TRANSIENT_FAILURE"); + GRPC_ERROR_UNREF(error); +} + +} // namespace + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +}; diff --git a/test/core/transport/metadata_map_test.cc b/test/core/transport/metadata_map_test.cc new file mode 100644 index 00000000..91e58c6b --- /dev/null +++ b/test/core/transport/metadata_map_test.cc @@ -0,0 +1,95 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { + +TEST(MetadataMapTest, Noop) { + auto arena = MakeScopedArena(1024); + MetadataMap<>(arena.get()); +} + +TEST(MetadataMapTest, NoopWithDeadline) { + auto arena = MakeScopedArena(1024); + MetadataMap(arena.get()); +} + +TEST(MetadataMapTest, SimpleOps) { + auto arena = MakeScopedArena(1024); + MetadataMap map(arena.get()); + EXPECT_EQ(map.get_pointer(GrpcTimeoutMetadata()), nullptr); + EXPECT_EQ(map.get(GrpcTimeoutMetadata()), absl::nullopt); + map.Set(GrpcTimeoutMetadata(), 1234); + EXPECT_NE(map.get_pointer(GrpcTimeoutMetadata()), nullptr); + EXPECT_EQ(*map.get_pointer(GrpcTimeoutMetadata()), 1234); + EXPECT_EQ(map.get(GrpcTimeoutMetadata()), 1234); + map.Remove(GrpcTimeoutMetadata()); + EXPECT_EQ(map.get_pointer(GrpcTimeoutMetadata()), nullptr); + EXPECT_EQ(map.get(GrpcTimeoutMetadata()), absl::nullopt); +} + +// Target for MetadataMap::Encode. +// Writes down some string representation of what it receives, so we can +// EXPECT_EQ it later. 
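+// (Illustrative: typed metadata such as GrpcTimeoutMetadata is dispatched to
+// the matching Encode overload below, while anything still carried as a
+// legacy grpc_mdelem hits the "LEGACY CALL" overload; TimeoutEncodeTest
+// further down only exercises the typed path.)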
+class FakeEncoder { + public: + std::string output() { return output_; } + + void Encode(grpc_mdelem md) { + output_ += + absl::StrCat("LEGACY CALL: key=", StringViewFromSlice(GRPC_MDKEY(md)), + " value=", StringViewFromSlice(GRPC_MDVALUE(md)), "\n"); + } + + void Encode(GrpcTimeoutMetadata, grpc_millis deadline) { + output_ += absl::StrCat("grpc-timeout: deadline=", deadline, "\n"); + } + + private: + std::string output_; +}; + +TEST(MetadataMapTest, EmptyEncodeTest) { + FakeEncoder encoder; + auto arena = MakeScopedArena(1024); + MetadataMap map(arena.get()); + map.Encode(&encoder); + EXPECT_EQ(encoder.output(), ""); +} + +TEST(MetadataMapTest, TimeoutEncodeTest) { + FakeEncoder encoder; + auto arena = MakeScopedArena(1024); + MetadataMap map(arena.get()); + map.Set(GrpcTimeoutMetadata(), 1234); + map.Encode(&encoder); + EXPECT_EQ(encoder.output(), "grpc-timeout: deadline=1234\n"); +} + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +}; diff --git a/test/core/transport/metadata_test.cc b/test/core/transport/metadata_test.cc new file mode 100644 index 00000000..81569297 --- /dev/null +++ b/test/core/transport/metadata_test.cc @@ -0,0 +1,394 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/transport/metadata.h" + +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/bin_encoder.h" +#include "src/core/ext/transport/chttp2/transport/hpack_utils.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/util/test_config.h" + +/* a large number */ +#define MANY 10000 + +static grpc_slice maybe_intern(grpc_slice in, bool intern) { + grpc_slice out = intern ? grpc_slice_intern(in) : grpc_slice_ref(in); + grpc_slice_unref(in); + return out; +} + +static grpc_slice maybe_dup(grpc_slice in, bool dup) { + grpc_slice out = dup ? 
grpc_slice_dup(in) : grpc_slice_ref(in); + grpc_slice_unref(in); + return out; +} + +static void test_create_metadata(bool intern_keys, bool intern_values) { + grpc_mdelem m1, m2, m3; + + gpr_log(GPR_INFO, "test_create_metadata: intern_keys=%d intern_values=%d", + intern_keys, intern_values); + + grpc_core::ExecCtx exec_ctx; + m1 = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values)); + m2 = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values)); + m3 = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("c"), intern_values)); + GPR_ASSERT(grpc_mdelem_eq(m1, m2)); + GPR_ASSERT(!grpc_mdelem_eq(m3, m1)); + GPR_ASSERT(grpc_slice_eq(GRPC_MDKEY(m3), GRPC_MDKEY(m1))); + GPR_ASSERT(!grpc_slice_eq(GRPC_MDVALUE(m3), GRPC_MDVALUE(m1))); + GPR_ASSERT(grpc_slice_str_cmp(GRPC_MDKEY(m1), "a") == 0); + GPR_ASSERT(grpc_slice_str_cmp(GRPC_MDVALUE(m1), "b") == 0); + GPR_ASSERT(grpc_slice_str_cmp(GRPC_MDVALUE(m3), "c") == 0); + GRPC_MDELEM_UNREF(m1); + GRPC_MDELEM_UNREF(m2); + GRPC_MDELEM_UNREF(m3); +} + +static void test_create_many_ephemeral_metadata(bool intern_keys, + bool intern_values) { + char buffer[GPR_LTOA_MIN_BUFSIZE]; + long i; + + gpr_log( + GPR_INFO, + "test_create_many_ephemeral_metadata: intern_keys=%d intern_values=%d", + intern_keys, intern_values); + + grpc_core::ExecCtx exec_ctx; + /* add, and immediately delete a bunch of different elements */ + for (i = 0; i < MANY; i++) { + gpr_ltoa(i, buffer); + GRPC_MDELEM_UNREF(grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_copied_string(buffer), intern_values))); + } +} + +static void test_create_many_persistant_metadata(void) { + char buffer[GPR_LTOA_MIN_BUFSIZE]; + long i; + grpc_mdelem* created = + static_cast(gpr_malloc(sizeof(grpc_mdelem) * MANY)); + grpc_mdelem md; + + gpr_log(GPR_INFO, "test_create_many_persistant_metadata"); + + grpc_core::ExecCtx exec_ctx; + /* add phase */ + for (i = 0; i < MANY; i++) { + gpr_ltoa(i, buffer); + created[i] = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string("a")), + grpc_slice_intern(grpc_slice_from_static_string(buffer))); + } + /* verify phase */ + for (i = 0; i < MANY; i++) { + gpr_ltoa(i, buffer); + md = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string("a")), + grpc_slice_intern(grpc_slice_from_static_string(buffer))); + GPR_ASSERT(grpc_mdelem_eq(md, created[i])); + GRPC_MDELEM_UNREF(md); + } + /* cleanup phase */ + for (i = 0; i < MANY; i++) { + GRPC_MDELEM_UNREF(created[i]); + } + + gpr_free(created); +} + +static void test_spin_creating_the_same_thing(bool intern_keys, + bool intern_values) { + gpr_log(GPR_INFO, + "test_spin_creating_the_same_thing: intern_keys=%d intern_values=%d", + intern_keys, intern_values); + + grpc_core::ExecCtx exec_ctx; + grpc_mdelem a, b, c; + GRPC_MDELEM_UNREF( + a = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values))); + GRPC_MDELEM_UNREF( + b = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values))); + GRPC_MDELEM_UNREF( + c = grpc_mdelem_from_slices( + 
maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values))); + if (intern_keys && intern_values) { + GPR_ASSERT(a.payload == b.payload); + GPR_ASSERT(a.payload == c.payload); + } +} + +static void test_identity_laws(bool intern_keys, bool intern_values) { + gpr_log(GPR_INFO, "test_identity_laws: intern_keys=%d intern_values=%d", + intern_keys, intern_values); + + grpc_core::ExecCtx exec_ctx; + grpc_mdelem a, b, c; + a = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values)); + b = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values)); + c = grpc_mdelem_from_slices( + maybe_intern(grpc_slice_from_static_string("a"), intern_keys), + maybe_intern(grpc_slice_from_static_string("b"), intern_values)); + GPR_ASSERT(grpc_mdelem_eq(a, a)); + GPR_ASSERT(grpc_mdelem_eq(b, b)); + GPR_ASSERT(grpc_mdelem_eq(c, c)); + GPR_ASSERT(grpc_mdelem_eq(a, b)); + GPR_ASSERT(grpc_mdelem_eq(b, c)); + GPR_ASSERT(grpc_mdelem_eq(a, c)); + GPR_ASSERT(grpc_mdelem_eq(b, a)); + GPR_ASSERT(grpc_mdelem_eq(c, b)); + GPR_ASSERT(grpc_mdelem_eq(c, a)); + if (intern_keys && intern_values) { + GPR_ASSERT(a.payload == b.payload); + GPR_ASSERT(a.payload == c.payload); + } else { + GPR_ASSERT(a.payload != b.payload); + GPR_ASSERT(a.payload != c.payload); + GPR_ASSERT(b.payload != c.payload); + } + GRPC_MDELEM_UNREF(a); + GRPC_MDELEM_UNREF(b); + GRPC_MDELEM_UNREF(c); +} + +static void test_things_stick_around(void) { + size_t i, j; + size_t nstrs = 1000; + grpc_slice* strs = + static_cast(gpr_malloc(sizeof(grpc_slice) * nstrs)); + size_t* shuf = static_cast(gpr_malloc(sizeof(size_t) * nstrs)); + grpc_slice test; + + gpr_log(GPR_INFO, "test_things_stick_around"); + + grpc_core::ExecCtx exec_ctx; + + for (i = 0; i < nstrs; i++) { + std::string buffer = + absl::StrFormat("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx%" PRIuPTR "x", i); + strs[i] = grpc_slice_intern(grpc_slice_from_static_string(buffer.c_str())); + shuf[i] = i; + } + + for (i = 0; i < nstrs; i++) { + grpc_slice_ref_internal(strs[i]); + grpc_slice_unref_internal(strs[i]); + } + + for (i = 0; i < nstrs; i++) { + size_t p = static_cast(rand()) % nstrs; + size_t q = static_cast(rand()) % nstrs; + size_t temp = shuf[p]; + shuf[p] = shuf[q]; + shuf[q] = temp; + } + + for (i = 0; i < nstrs; i++) { + grpc_slice_unref_internal(strs[shuf[i]]); + for (j = i + 1; j < nstrs; j++) { + std::string buffer = absl::StrFormat( + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx%" PRIuPTR "x", shuf[j]); + test = grpc_slice_intern(grpc_slice_from_static_string(buffer.c_str())); + GPR_ASSERT(grpc_slice_is_equivalent(test, strs[shuf[j]])); + grpc_slice_unref_internal(test); + } + } + gpr_free(strs); + gpr_free(shuf); +} + +static void test_user_data_works(void) { + int* ud1; + int* ud2; + grpc_mdelem md; + gpr_log(GPR_INFO, "test_user_data_works"); + + grpc_core::ExecCtx exec_ctx; + ud1 = static_cast(gpr_malloc(sizeof(int))); + *ud1 = 1; + ud2 = static_cast(gpr_malloc(sizeof(int))); + *ud2 = 2; + md = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string("abc")), + grpc_slice_intern(grpc_slice_from_static_string("123"))); + grpc_mdelem_set_user_data(md, gpr_free, ud1); + grpc_mdelem_set_user_data(md, gpr_free, ud2); + GPR_ASSERT(grpc_mdelem_get_user_data(md, gpr_free) == ud1); + GRPC_MDELEM_UNREF(md); +} + +static void 
test_user_data_works_for_allocated_md(void) { + int* ud1; + int* ud2; + grpc_mdelem md; + gpr_log(GPR_INFO, "test_user_data_works"); + + grpc_core::ExecCtx exec_ctx; + ud1 = static_cast(gpr_malloc(sizeof(int))); + *ud1 = 1; + ud2 = static_cast(gpr_malloc(sizeof(int))); + *ud2 = 2; + md = grpc_mdelem_from_slices(grpc_slice_from_static_string("abc"), + grpc_slice_from_static_string("123")); + grpc_mdelem_set_user_data(md, gpr_free, ud1); + grpc_mdelem_set_user_data(md, gpr_free, ud2); + GPR_ASSERT(grpc_mdelem_get_user_data(md, gpr_free) == ud1); + GRPC_MDELEM_UNREF(md); +} + +static void test_copied_static_metadata(bool dup_key, bool dup_value) { + gpr_log(GPR_INFO, "test_static_metadata: dup_key=%d dup_value=%d", dup_key, + dup_value); + grpc_core::ExecCtx exec_ctx; + + for (size_t i = 0; i < GRPC_STATIC_MDELEM_COUNT; i++) { + grpc_mdelem p = GRPC_MAKE_MDELEM(&grpc_core::g_static_mdelem_table[i], + GRPC_MDELEM_STORAGE_STATIC); + grpc_mdelem q = + grpc_mdelem_from_slices(maybe_dup(GRPC_MDKEY(p), dup_key), + maybe_dup(GRPC_MDVALUE(p), dup_value)); + GPR_ASSERT(grpc_mdelem_eq(p, q)); + if (dup_key || dup_value) { + GPR_ASSERT(p.payload != q.payload); + } else { + GPR_ASSERT(p.payload == q.payload); + } + GRPC_MDELEM_UNREF(p); + GRPC_MDELEM_UNREF(q); + } +} + +static void test_grpc_metadata_batch_get_value_with_absent_key(void) { + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch metadata(arena.get()); + std::string concatenated_value; + absl::optional value = + metadata.GetValue("absent_key", &concatenated_value); + GPR_ASSERT(value == absl::nullopt); +} + +static void test_grpc_metadata_batch_get_value_returns_one_value(void) { + const char* kKey = "some_key"; + const char* kValue = "some_value"; + auto arena = grpc_core::MakeScopedArena(1024); + grpc_linked_mdelem storage; + grpc_metadata_batch metadata(arena.get()); + storage.md = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string(kKey)), + grpc_slice_intern(grpc_slice_from_static_string(kValue))); + GPR_ASSERT(metadata.LinkHead(&storage) == GRPC_ERROR_NONE); + std::string concatenated_value; + absl::optional value = + metadata.GetValue(kKey, &concatenated_value); + GPR_ASSERT(value.has_value()); + GPR_ASSERT(value.value() == kValue); +} + +static void test_grpc_metadata_batch_get_value_returns_multiple_values(void) { + const char* kKey = "some_key"; + const char* kValue1 = "value1"; + const char* kValue2 = "value2"; + auto arena = grpc_core::MakeScopedArena(1024); + grpc_linked_mdelem storage1; + grpc_linked_mdelem storage2; + grpc_metadata_batch metadata(arena.get()); + storage1.md = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string(kKey)), + grpc_slice_intern(grpc_slice_from_static_string(kValue1))); + GPR_ASSERT(metadata.LinkTail(&storage1) == GRPC_ERROR_NONE); + storage2.md = grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string(kKey)), + grpc_slice_intern(grpc_slice_from_static_string(kValue2))); + GPR_ASSERT(metadata.LinkTail(&storage2) == GRPC_ERROR_NONE); + std::string concatenated_value; + absl::optional value = + metadata.GetValue(kKey, &concatenated_value); + GPR_ASSERT(value.has_value()); + GPR_ASSERT(value.value() == absl::StrCat(kValue1, ",", kValue2)); +} + +static void test_grpc_chttp2_incoming_metadata_replace_or_add_works(void) { + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch buffer(arena.get()); + GRPC_LOG_IF_ERROR("incoming_buffer_add", + buffer.Append(grpc_mdelem_from_slices( + 
grpc_slice_from_static_string("a"), + grpc_slice_from_static_string("b")))); + GRPC_LOG_IF_ERROR( + "incoming_buffer_replace_or_add", + buffer.ReplaceOrAppend(grpc_slice_from_static_string("a"), + grpc_slice_malloc(1024 * 1024 * 1024))); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + for (int k = 0; k <= 1; k++) { + for (int v = 0; v <= 1; v++) { + test_create_metadata(k, v); + test_create_many_ephemeral_metadata(k, v); + test_identity_laws(k, v); + test_spin_creating_the_same_thing(k, v); + test_copied_static_metadata(k, v); + } + } + test_create_many_persistant_metadata(); + test_things_stick_around(); + test_user_data_works(); + test_user_data_works_for_allocated_md(); + test_grpc_metadata_batch_get_value_with_absent_key(); + test_grpc_metadata_batch_get_value_returns_one_value(); + test_grpc_metadata_batch_get_value_returns_multiple_values(); + test_grpc_chttp2_incoming_metadata_replace_or_add_works(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/transport/parsed_metadata_test.cc b/test/core/transport/parsed_metadata_test.cc new file mode 100644 index 00000000..d0a8c34d --- /dev/null +++ b/test/core/transport/parsed_metadata_test.cc @@ -0,0 +1,210 @@ +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include "src/core/lib/transport/parsed_metadata.h" + +#include +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { + +struct CharTrait { + using MementoType = char; + static const char* key() { return "key"; } + static char test_memento() { return 'a'; } + static char test_value() { return 'a'; } + static size_t test_memento_transport_size() { return 34; } + static char MementoToValue(char memento) { return memento; } + static char ParseMemento(const grpc_slice& slice) { + return *GRPC_SLICE_START_PTR(slice); + } + static std::string DisplayValue(char value) { return std::string(1, value); } +}; + +struct Int32Trait { + using MementoType = int32_t; + static const char* key() { return "key2"; } + static int32_t test_memento() { return -1; } + static int32_t test_value() { return -1; } + static size_t test_memento_transport_size() { return 478; } + static int32_t MementoToValue(int32_t memento) { return memento; } + static int32_t ParseMemento(const grpc_slice& slice) { + int32_t out; + GPR_ASSERT(absl::SimpleAtoi(StringViewFromSlice(slice), &out)); + return out; + } + static std::string DisplayValue(int32_t value) { + return std::to_string(value); + } +}; + +struct Int64Trait { + using MementoType = int64_t; + static const char* key() { return "key3"; } + static int64_t test_memento() { return 83481847284179298; } + static int64_t test_value() { return -83481847284179298; } + static size_t test_memento_transport_size() { return 87; } + static int64_t MementoToValue(int64_t memento) { return -memento; } + static int64_t ParseMemento(const grpc_slice& slice) { + int64_t out; + GPR_ASSERT(absl::SimpleAtoi(StringViewFromSlice(slice), &out)); + return out; + } + static std::string DisplayValue(int64_t value) { + return std::to_string(value); + } +}; + +struct IntptrTrait { + using MementoType = intptr_t; + static const char* key() { return "key4"; } + static intptr_t test_memento() { return 8374298; } + static intptr_t test_value() { return test_memento() / 2; } + static size_t test_memento_transport_size() { return 800; } + static intptr_t MementoToValue(intptr_t memento) { return memento / 2; } + static intptr_t ParseMemento(const grpc_slice& slice) { + intptr_t out; + GPR_ASSERT(absl::SimpleAtoi(StringViewFromSlice(slice), &out)); + return out; + } + static std::string DisplayValue(intptr_t value) { + return std::to_string(value); + } +}; + +struct StringTrait { + using MementoType = std::string; + static const char* key() { return "key5-bin"; } + static std::string test_memento() { return "hello"; } + static std::string test_value() { return "hi hello"; } + static size_t test_memento_transport_size() { return 599; } + static std::string MementoToValue(std::string memento) { + return "hi " + memento; + } + static std::string ParseMemento(const grpc_slice& slice) { + auto view = StringViewFromSlice(slice); + return std::string(view.begin(), view.end()); + } + static std::string DisplayValue(const std::string& value) { return value; } +}; + +class FakeContainer { + public: + void Set(CharTrait, char x) { SetChar(x); } + void Set(Int32Trait, int32_t x) { SetInt32(x); } + void Set(Int64Trait, int64_t x) { SetInt64(x); } + void Set(IntptrTrait, intptr_t x) { SetIntptr(x); } + void Set(StringTrait, std::string x) { SetString(x); } + + void Set(const ::grpc_core::ParsedMetadata& metadata) { + EXPECT_EQ(GRPC_ERROR_NONE, metadata.SetOnContainer(this)); + } + + MOCK_METHOD1(SetChar, void(char)); + 
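+  // The Set test below installs EXPECT_CALL expectations on these mocks to check that
+  // SetOnContainer converts each memento with MementoToValue and forwards the result
+  // to the typed setter matching the trait.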
MOCK_METHOD1(SetInt32, void(int32_t)); + MOCK_METHOD1(SetInt64, void(int64_t)); + MOCK_METHOD1(SetIntptr, void(intptr_t)); + MOCK_METHOD1(SetString, void(std::string)); +}; + +using ParsedMetadata = ::grpc_core::ParsedMetadata; + +TEST(ParsedMetadataTest, Noop) { ParsedMetadata(); } + +TEST(ParsedMetadataTest, DebugString) { + ParsedMetadata parsed(CharTrait(), 'x', 36); + EXPECT_EQ(parsed.DebugString(), "key: x"); +} + +TEST(ParsedMetadataTest, IsNotBinary) { + ParsedMetadata parsed(CharTrait(), 'x', 36); + EXPECT_FALSE(parsed.is_binary_header()); +} + +TEST(ParsedMetadataTest, IsBinary) { + ParsedMetadata parsed(StringTrait(), "s", 36); + EXPECT_TRUE(parsed.is_binary_header()); +} + +TEST(ParsedMetadataTest, Set) { + FakeContainer c; + ParsedMetadata p(CharTrait(), 'x', 36); + EXPECT_CALL(c, SetChar('x')).Times(1); + c.Set(p); + p = ParsedMetadata(Int32Trait(), -1, 478); + EXPECT_CALL(c, SetInt32(-1)).Times(1); + c.Set(p); + p = ParsedMetadata(Int64Trait(), 83481847284179298, 87); + EXPECT_CALL(c, SetInt64(-83481847284179298)).Times(1); + c.Set(p); + p = ParsedMetadata(IntptrTrait(), 8374298, 800); + EXPECT_CALL(c, SetIntptr(4187149)).Times(1); + c.Set(p); + p = ParsedMetadata(StringTrait(), "hello", 599); + EXPECT_CALL(c, SetString("hi hello")).Times(1); + c.Set(p); +} + +template +class TraitSpecializedTest : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(TraitSpecializedTest); + +TYPED_TEST_P(TraitSpecializedTest, Noop) { + ParsedMetadata(TypeParam(), TypeParam::test_memento(), + TypeParam::test_memento_transport_size()); +} + +TYPED_TEST_P(TraitSpecializedTest, CanMove) { + ParsedMetadata a(TypeParam(), TypeParam::test_memento(), + TypeParam::test_memento_transport_size()); + ParsedMetadata b = std::move(a); + a = std::move(b); +} + +TYPED_TEST_P(TraitSpecializedTest, DebugString) { + ParsedMetadata p(TypeParam(), TypeParam::test_memento(), + TypeParam::test_memento_transport_size()); + EXPECT_EQ(p.DebugString(), + absl::StrCat(TypeParam::key(), ": ", + TypeParam::DisplayValue(TypeParam::test_memento()))); +} + +TYPED_TEST_P(TraitSpecializedTest, TransportSize) { + ParsedMetadata p(TypeParam(), TypeParam::test_memento(), + TypeParam::test_memento_transport_size()); + EXPECT_EQ(p.transport_size(), TypeParam::test_memento_transport_size()); +} + +REGISTER_TYPED_TEST_SUITE_P(TraitSpecializedTest, Noop, CanMove, DebugString, + TransportSize); + +using InterestingTraits = ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(My, TraitSpecializedTest, InterestingTraits); + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +}; diff --git a/test/core/transport/pid_controller_test.cc b/test/core/transport/pid_controller_test.cc new file mode 100644 index 00000000..b41a9b02 --- /dev/null +++ b/test/core/transport/pid_controller_test.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/transport/pid_controller.h" + +#include +#include + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { + +TEST(PidController, NoOp) { + PidController pid(PidController::Args() + .set_gain_p(1) + .set_gain_i(1) + .set_gain_d(1) + .set_initial_control_value(1)); +} + +struct SimpleConvergenceTestArgs { + double gain_p; + double gain_i; + double gain_d; + double dt; + double set_point; + double start; +}; + +std::ostream& operator<<(std::ostream& out, SimpleConvergenceTestArgs args) { + return out << "gain_p:" << args.gain_p << " gain_i:" << args.gain_i + << " gain_d:" << args.gain_d << " dt:" << args.dt + << " set_point:" << args.set_point << " start:" << args.start; +} + +class SimpleConvergenceTest + : public ::testing::TestWithParam {}; + +TEST_P(SimpleConvergenceTest, Converges) { + PidController pid(PidController::Args() + .set_gain_p(GetParam().gain_p) + .set_gain_i(GetParam().gain_i) + .set_gain_d(GetParam().gain_d) + .set_initial_control_value(GetParam().start)); + + for (int i = 0; i < 100000; i++) { + pid.Update(GetParam().set_point - pid.last_control_value(), GetParam().dt); + } + + EXPECT_LT(fabs(GetParam().set_point - pid.last_control_value()), 0.1); + if (GetParam().gain_i > 0) { + EXPECT_LT(fabs(pid.error_integral()), 0.1); + } +} + +INSTANTIATE_TEST_SUITE_P( + X, SimpleConvergenceTest, + ::testing::Values(SimpleConvergenceTestArgs{0.2, 0, 0, 1, 100, 0}, + SimpleConvergenceTestArgs{0.2, 0.1, 0, 1, 100, 0}, + SimpleConvergenceTestArgs{0.2, 0.1, 0.1, 1, 100, 0})); + +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/core/transport/static_metadata_test.cc b/test/core/transport/static_metadata_test.cc new file mode 100644 index 00000000..06be96db --- /dev/null +++ b/test/core/transport/static_metadata_test.cc @@ -0,0 +1,52 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/transport/static_metadata.h" + +#include + +#include + +#include "src/core/lib/transport/metadata.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace { + +TEST(StaticMetadataTest, ReadAllStaticElements) { + // This makes sure that all static elements are returned when + // grpc_mdelem_from_slices is called with key pairs pregenerated. 
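+  // Static elements are pre-allocated singletons, so looking one up by its own
+  // key/value slices must yield the identical payload pointer, not a fresh element.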
+ for (int i = 0; i < GRPC_STATIC_MDELEM_COUNT; i++) { + const grpc_mdelem mdelem = g_static_mdelem_manifested[i]; + const grpc_mdelem mdelem2 = + grpc_mdelem_from_slices(GRPC_MDKEY(mdelem), GRPC_MDVALUE(mdelem)); + EXPECT_EQ(mdelem.payload, mdelem2.payload); + } +} + +} // namespace +} // namespace grpc_core + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int retval = RUN_ALL_TESTS(); + grpc_shutdown(); + return retval; +} diff --git a/test/core/transport/status_conversion_test.cc b/test/core/transport/status_conversion_test.cc new file mode 100644 index 00000000..ef7e6170 --- /dev/null +++ b/test/core/transport/status_conversion_test.cc @@ -0,0 +1,183 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/transport/status_conversion.h" + +#include + +#include "test/core/util/test_config.h" + +#define GRPC_STATUS_TO_HTTP2_ERROR(a, b) \ + GPR_ASSERT(grpc_status_to_http2_error(a) == (b)) +#define HTTP2_ERROR_TO_GRPC_STATUS(a, deadline, b) \ + do { \ + grpc_core::ExecCtx exec_ctx; \ + GPR_ASSERT(grpc_http2_error_to_grpc_status(a, deadline) == (b)); \ + \ + } while (0) +#define GRPC_STATUS_TO_HTTP2_STATUS(a, b) \ + GPR_ASSERT(grpc_status_to_http2_status(a) == (b)) +#define HTTP2_STATUS_TO_GRPC_STATUS(a, b) \ + GPR_ASSERT(grpc_http2_status_to_grpc_status(a) == (b)) + +static void test_grpc_status_to_http2_error() { + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_OK, GRPC_HTTP2_NO_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_CANCELLED, GRPC_HTTP2_CANCEL); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_UNKNOWN, GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_INVALID_ARGUMENT, + GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_DEADLINE_EXCEEDED, GRPC_HTTP2_CANCEL); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_NOT_FOUND, GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_ALREADY_EXISTS, + GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_PERMISSION_DENIED, + GRPC_HTTP2_INADEQUATE_SECURITY); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_UNAUTHENTICATED, + GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_RESOURCE_EXHAUSTED, + GRPC_HTTP2_ENHANCE_YOUR_CALM); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_FAILED_PRECONDITION, + GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_ABORTED, GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_OUT_OF_RANGE, + GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_UNIMPLEMENTED, + GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_INTERNAL, GRPC_HTTP2_INTERNAL_ERROR); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_UNAVAILABLE, + GRPC_HTTP2_REFUSED_STREAM); + GRPC_STATUS_TO_HTTP2_ERROR(GRPC_STATUS_DATA_LOSS, GRPC_HTTP2_INTERNAL_ERROR); +} + +static void test_grpc_status_to_http2_status() { + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_OK, 200); + 
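+  // gRPC-over-HTTP/2 always reports HTTP :status 200; the gRPC status code itself
+  // travels in the grpc-status trailer, so every mapping below is 200.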
GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_CANCELLED, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_UNKNOWN, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_INVALID_ARGUMENT, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_DEADLINE_EXCEEDED, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_NOT_FOUND, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_ALREADY_EXISTS, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_PERMISSION_DENIED, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_UNAUTHENTICATED, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_RESOURCE_EXHAUSTED, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_FAILED_PRECONDITION, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_ABORTED, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_OUT_OF_RANGE, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_UNIMPLEMENTED, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_INTERNAL, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_UNAVAILABLE, 200); + GRPC_STATUS_TO_HTTP2_STATUS(GRPC_STATUS_DATA_LOSS, 200); +} + +static void test_http2_error_to_grpc_status() { + const grpc_millis before_deadline = GRPC_MILLIS_INF_FUTURE; + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_NO_ERROR, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_PROTOCOL_ERROR, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_INTERNAL_ERROR, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_FLOW_CONTROL_ERROR, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_SETTINGS_TIMEOUT, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_STREAM_CLOSED, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_FRAME_SIZE_ERROR, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_REFUSED_STREAM, before_deadline, + GRPC_STATUS_UNAVAILABLE); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_CANCEL, before_deadline, + GRPC_STATUS_CANCELLED); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_COMPRESSION_ERROR, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_CONNECT_ERROR, before_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_ENHANCE_YOUR_CALM, before_deadline, + GRPC_STATUS_RESOURCE_EXHAUSTED); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_INADEQUATE_SECURITY, before_deadline, + GRPC_STATUS_PERMISSION_DENIED); + + const grpc_millis after_deadline = 0; + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_NO_ERROR, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_PROTOCOL_ERROR, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_INTERNAL_ERROR, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_FLOW_CONTROL_ERROR, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_SETTINGS_TIMEOUT, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_STREAM_CLOSED, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_FRAME_SIZE_ERROR, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_REFUSED_STREAM, after_deadline, + GRPC_STATUS_UNAVAILABLE); + // We only have millisecond granularity in our timing code. This sleeps for 5 + // millis to ensure that the status conversion code will pick up the fact + // that the deadline has expired. 
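+  // Once the deadline has passed, a peer CANCEL maps to DEADLINE_EXCEEDED instead of
+  // CANCELLED; the other HTTP/2 error mappings are unchanged.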
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(5, GPR_TIMESPAN))); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_CANCEL, after_deadline, + GRPC_STATUS_DEADLINE_EXCEEDED); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_COMPRESSION_ERROR, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_CONNECT_ERROR, after_deadline, + GRPC_STATUS_INTERNAL); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_ENHANCE_YOUR_CALM, after_deadline, + GRPC_STATUS_RESOURCE_EXHAUSTED); + HTTP2_ERROR_TO_GRPC_STATUS(GRPC_HTTP2_INADEQUATE_SECURITY, after_deadline, + GRPC_STATUS_PERMISSION_DENIED); +} + +static void test_http2_status_to_grpc_status() { + HTTP2_STATUS_TO_GRPC_STATUS(200, GRPC_STATUS_OK); + HTTP2_STATUS_TO_GRPC_STATUS(400, GRPC_STATUS_INTERNAL); + HTTP2_STATUS_TO_GRPC_STATUS(401, GRPC_STATUS_UNAUTHENTICATED); + HTTP2_STATUS_TO_GRPC_STATUS(403, GRPC_STATUS_PERMISSION_DENIED); + HTTP2_STATUS_TO_GRPC_STATUS(404, GRPC_STATUS_UNIMPLEMENTED); + HTTP2_STATUS_TO_GRPC_STATUS(409, GRPC_STATUS_UNKNOWN); + HTTP2_STATUS_TO_GRPC_STATUS(412, GRPC_STATUS_UNKNOWN); + HTTP2_STATUS_TO_GRPC_STATUS(429, GRPC_STATUS_UNAVAILABLE); + HTTP2_STATUS_TO_GRPC_STATUS(499, GRPC_STATUS_UNKNOWN); + HTTP2_STATUS_TO_GRPC_STATUS(500, GRPC_STATUS_UNKNOWN); + HTTP2_STATUS_TO_GRPC_STATUS(502, GRPC_STATUS_UNAVAILABLE); + HTTP2_STATUS_TO_GRPC_STATUS(503, GRPC_STATUS_UNAVAILABLE); + HTTP2_STATUS_TO_GRPC_STATUS(504, GRPC_STATUS_UNAVAILABLE); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + test_grpc_status_to_http2_error(); + test_grpc_status_to_http2_status(); + test_http2_error_to_grpc_status(); + test_http2_status_to_grpc_status(); + + /* check all status values can be converted */ + for (int i = 0; i <= 999; i++) { + grpc_http2_status_to_grpc_status(i); + } + + grpc_shutdown(); + + return 0; +} diff --git a/test/core/transport/status_metadata_test.cc b/test/core/transport/status_metadata_test.cc new file mode 100644 index 00000000..720c3eb7 --- /dev/null +++ b/test/core/transport/status_metadata_test.cc @@ -0,0 +1,66 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/transport/status_metadata.h" + +#include + +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/util/test_config.h" + +namespace { + +TEST(GetStatusCodeFromMetadata, OK) { + EXPECT_EQ(GRPC_STATUS_OK, + grpc_get_status_code_from_metadata(GRPC_MDELEM_GRPC_STATUS_0)); +} + +TEST(GetStatusCodeFromMetadata, CANCELLED) { + EXPECT_EQ(GRPC_STATUS_CANCELLED, + grpc_get_status_code_from_metadata(GRPC_MDELEM_GRPC_STATUS_1)); +} + +TEST(GetStatusCodeFromMetadata, UNKNOWN) { + EXPECT_EQ(GRPC_STATUS_UNKNOWN, + grpc_get_status_code_from_metadata(GRPC_MDELEM_GRPC_STATUS_2)); +} + +TEST(GetStatusCodeFromMetadata, Other) { + grpc_mdelem status_md = grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_STATUS, grpc_slice_from_static_string("10")); + EXPECT_EQ(GRPC_STATUS_ABORTED, grpc_get_status_code_from_metadata(status_md)); + GRPC_MDELEM_UNREF(status_md); +} + +TEST(GetStatusCodeFromMetadata, Unparseable) { + grpc_mdelem status_md = grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_STATUS, grpc_slice_from_static_string("NaN")); + EXPECT_EQ(GRPC_STATUS_UNKNOWN, grpc_get_status_code_from_metadata(status_md)); + GRPC_MDELEM_UNREF(status_md); +} + +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/transport/stream_owned_slice_test.cc b/test/core/transport/stream_owned_slice_test.cc new file mode 100644 index 00000000..e7b76042 --- /dev/null +++ b/test/core/transport/stream_owned_slice_test.cc @@ -0,0 +1,42 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "src/core/lib/transport/transport.h" +#include "test/core/util/test_config.h" + +static void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + + uint8_t buffer[] = "abc123"; + grpc_stream_refcount r; + GRPC_STREAM_REF_INIT(&r, 1, do_nothing, nullptr, "test"); + grpc_slice slice = + grpc_slice_from_stream_owned_buffer(&r, buffer, sizeof(buffer)); + GPR_ASSERT(GRPC_SLICE_START_PTR(slice) == buffer); + GPR_ASSERT(GRPC_SLICE_LENGTH(slice) == sizeof(buffer)); + grpc_slice_unref(slice); + + grpc_shutdown(); + return 0; +} diff --git a/test/core/transport/timeout_encoding_test.cc b/test/core/transport/timeout_encoding_test.cc new file mode 100644 index 00000000..5f6d38c7 --- /dev/null +++ b/test/core/transport/timeout_encoding_test.cc @@ -0,0 +1,164 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/transport/timeout_encoding.h" + +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/lib/gpr/murmur_hash.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST(x) gpr_log(GPR_INFO, "%s", x) + +static void assert_encodes_as(grpc_millis ts, const char* s) { + char buffer[GRPC_HTTP2_TIMEOUT_ENCODE_MIN_BUFSIZE]; + grpc_http2_encode_timeout(ts, buffer); + gpr_log(GPR_INFO, "check '%s' == '%s'", buffer, s); + GPR_ASSERT(0 == strcmp(buffer, s)); +} + +void test_encoding(void) { + LOG_TEST("test_encoding"); + assert_encodes_as(-1, "1n"); + assert_encodes_as(-10, "1n"); + assert_encodes_as(1, "1m"); + assert_encodes_as(10, "10m"); + assert_encodes_as(100, "100m"); + assert_encodes_as(890, "890m"); + assert_encodes_as(900, "900m"); + assert_encodes_as(901, "901m"); + assert_encodes_as(1000, "1S"); + assert_encodes_as(2000, "2S"); + assert_encodes_as(2500, "2500m"); + assert_encodes_as(59900, "59900m"); + assert_encodes_as(50000, "50S"); + assert_encodes_as(59000, "59S"); + assert_encodes_as(60000, "1M"); + assert_encodes_as(80000, "80S"); + assert_encodes_as(90000, "90S"); + assert_encodes_as(120000, "2M"); + assert_encodes_as(20 * 60 * GPR_MS_PER_SEC, "20M"); + assert_encodes_as(60 * 60 * GPR_MS_PER_SEC, "1H"); + assert_encodes_as(10 * 60 * 60 * GPR_MS_PER_SEC, "10H"); + assert_encodes_as(60 * 60 * GPR_MS_PER_SEC - 100, "1H"); + assert_encodes_as(100 * 60 * 60 * GPR_MS_PER_SEC, "100H"); + assert_encodes_as(100000000000, "99999999S"); +} + +static void assert_decodes_as(const char* buffer, grpc_millis expected) { + grpc_millis got; + uint32_t hash = gpr_murmur_hash3(buffer, strlen(buffer), 0); + gpr_log(GPR_INFO, "check decoding '%s' (hash=0x%x)", buffer, hash); + GPR_ASSERT(1 == grpc_http2_decode_timeout( + grpc_slice_from_static_string(buffer), &got)); + if (got != expected) { + gpr_log(GPR_ERROR, "got:'%" PRId64 "' != expected:'%" PRId64 "'", got, + expected); + abort(); + } +} + +void decode_suite(char ext, grpc_millis (*answer)(int64_t x)) { + long test_vals[] = {1, 12, 123, 1234, 12345, 123456, + 1234567, 12345678, 123456789, 98765432, 9876543, 987654, + 98765, 9876, 987, 98, 9}; + for (unsigned i = 0; i < GPR_ARRAY_SIZE(test_vals); i++) { + std::string input = absl::StrFormat("%ld%c", test_vals[i], ext); + assert_decodes_as(input.c_str(), answer(test_vals[i])); + + input = absl::StrFormat(" %ld%c", test_vals[i], ext); + assert_decodes_as(input.c_str(), answer(test_vals[i])); + + input = absl::StrFormat("%ld %c", test_vals[i], ext); + assert_decodes_as(input.c_str(), answer(test_vals[i])); + + input = absl::StrFormat("%ld %c ", test_vals[i], ext); + assert_decodes_as(input.c_str(), answer(test_vals[i])); + } +} + +static grpc_millis millis_from_nanos(int64_t x) { + return static_cast(x / GPR_NS_PER_MS + (x % GPR_NS_PER_MS != 0)); +} +static grpc_millis millis_from_micros(int64_t x) { + return static_cast(x / GPR_US_PER_MS + (x % GPR_US_PER_MS != 0)); +} +static grpc_millis 
millis_from_millis(int64_t x) { + return static_cast(x); +} +static grpc_millis millis_from_seconds(int64_t x) { + return static_cast(x * GPR_MS_PER_SEC); +} +static grpc_millis millis_from_minutes(int64_t x) { + return static_cast(x * 60 * GPR_MS_PER_SEC); +} +static grpc_millis millis_from_hours(int64_t x) { + return static_cast(x * 3600 * GPR_MS_PER_SEC); +} + +void test_decoding(void) { + LOG_TEST("test_decoding"); + decode_suite('n', millis_from_nanos); + decode_suite('u', millis_from_micros); + decode_suite('m', millis_from_millis); + decode_suite('S', millis_from_seconds); + decode_suite('M', millis_from_minutes); + decode_suite('H', millis_from_hours); + assert_decodes_as("1000000000S", millis_from_seconds(1000 * 1000 * 1000)); + assert_decodes_as("1000000000000000000000u", GRPC_MILLIS_INF_FUTURE); + assert_decodes_as("1000000001S", GRPC_MILLIS_INF_FUTURE); + assert_decodes_as("2000000001S", GRPC_MILLIS_INF_FUTURE); + assert_decodes_as("9999999999S", GRPC_MILLIS_INF_FUTURE); +} + +static void assert_decoding_fails(const char* s) { + grpc_millis x; + GPR_ASSERT(0 == + grpc_http2_decode_timeout(grpc_slice_from_static_string(s), &x)); +} + +void test_decoding_fails(void) { + LOG_TEST("test_decoding_fails"); + assert_decoding_fails(""); + assert_decoding_fails(" "); + assert_decoding_fails("x"); + assert_decoding_fails("1"); + assert_decoding_fails("1x"); + assert_decoding_fails("1ux"); + assert_decoding_fails("!"); + assert_decoding_fails("n1"); + assert_decoding_fails("-1u"); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_encoding(); + test_decoding(); + test_decoding_fails(); + return 0; +} diff --git a/test/core/tsi/alts/crypt/aes_gcm_test.cc b/test/core/tsi/alts/crypt/aes_gcm_test.cc new file mode 100644 index 00000000..db475376 --- /dev/null +++ b/test/core/tsi/alts/crypt/aes_gcm_test.cc @@ -0,0 +1,2123 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include "src/core/tsi/alts/crypt/gsec.h" +#include "test/core/tsi/alts/crypt/gsec_test_util.h" +#include "test/core/util/test_config.h" + +const size_t kTestMinTagLengthForCorruption = 8; +const size_t kTestNumCrypters = 3; +const size_t kTestMaxSlices = 5; +const size_t kTestMaxLength = 1024; +const size_t kTestNumEncryptions = 100; + +/* Struct for pre-generated test vector */ +typedef struct gsec_aead_test_vector { + uint8_t* nonce; + uint8_t* aad; + uint8_t* key; + uint8_t* plaintext; + uint8_t* ciphertext_and_tag; + size_t nonce_length; + size_t aad_length; + size_t key_length; + size_t plaintext_length; + size_t ciphertext_and_tag_length; +} gsec_aead_test_vector; + +static void gsec_randomly_slice(uint8_t* input, size_t input_length, + struct iovec** output, size_t* output_length) { + if (input_length == 0) { + *output = nullptr; + *output_length = 0; + return; + } + *output_length = gsec_test_bias_random_uint32(kTestMaxSlices) + 1; + *output = + static_cast(malloc(*output_length * sizeof(**output))); + size_t i; + for (i = 0; i < *output_length - 1; i++) { + size_t slice_length = + gsec_test_bias_random_uint32(static_cast(input_length)); + struct iovec slice = {input, slice_length}; + (*output)[i] = slice; + input += slice_length; + input_length -= slice_length; + } + struct iovec slice = {input, input_length}; + (*output)[*output_length - 1] = slice; +} + +static void gsec_assert_ok(grpc_status_code status, const char* error_detail) { + char empty_string[] = ""; + if (error_detail == nullptr) { + error_detail = empty_string; + } + if (status != GRPC_STATUS_OK) { + fprintf(stderr, "Status is not ok: %s\n", error_detail); + } + GPR_ASSERT(status == GRPC_STATUS_OK); +} + +static void gsec_test_random_encrypt_decrypt(gsec_aead_crypter* crypter, + size_t aad_length, + size_t message_length) { + GPR_ASSERT(crypter != nullptr); + size_t nonce_length, tag_length; + uint8_t *nonce, *aad, *message; + gsec_aead_crypter_nonce_length(crypter, &nonce_length, nullptr); + gsec_aead_crypter_tag_length(crypter, &tag_length, nullptr); + + gsec_test_random_array(&nonce, nonce_length); + gsec_test_random_array(&aad, aad_length); + gsec_test_random_array(&message, message_length); + + /* Test encryption */ + size_t ciphertext_and_tag_length, ciphertext_bytes_written = 0; + gsec_aead_crypter_max_ciphertext_and_tag_length( + crypter, message_length, &ciphertext_and_tag_length, nullptr); + + uint8_t* ciphertext_and_tag = + static_cast(gpr_malloc(ciphertext_and_tag_length)); + + char* error_buffer = nullptr; + gsec_assert_ok( + gsec_aead_crypter_encrypt(crypter, nonce, nonce_length, aad, aad_length, + message, message_length, ciphertext_and_tag, + ciphertext_and_tag_length, + &ciphertext_bytes_written, &error_buffer), + error_buffer); + GPR_ASSERT(message_length + tag_length == ciphertext_and_tag_length); + GPR_ASSERT(ciphertext_bytes_written == ciphertext_and_tag_length); + + /* Test decryption */ + size_t plaintext_length, plaintext_bytes_written = 0; + gsec_aead_crypter_max_plaintext_length(crypter, ciphertext_bytes_written, + &plaintext_length, nullptr); + uint8_t* plaintext = static_cast(gpr_malloc(plaintext_length)); + grpc_status_code status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, ciphertext_and_tag, + ciphertext_bytes_written, plaintext, plaintext_length, + &plaintext_bytes_written, nullptr); + + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(message_length == plaintext_bytes_written); + if (message_length != 0) { + 
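+    // Round-trip check: the decrypted output must be byte-identical to the original message.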
GPR_ASSERT(memcmp(message, plaintext, message_length) == 0); + } + + /** + * The returned plaintext will be zeroed if there was an authentication error. + */ + uint8_t* zero_message = static_cast(gpr_zalloc(plaintext_length)); + if (tag_length >= kTestMinTagLengthForCorruption) { + char* error_message; + /* Corrupt nonce */ + if (nonce_length > 0) { + plaintext_bytes_written = 0; + uint8_t* corrupt_nonce; + gsec_test_copy_and_alter_random_byte(nonce, &corrupt_nonce, nonce_length); + status = gsec_aead_crypter_decrypt( + crypter, corrupt_nonce, nonce_length, aad, aad_length, + ciphertext_and_tag, ciphertext_bytes_written, plaintext, + plaintext_length, &plaintext_bytes_written, &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, "Checking tag failed.", + error_message)); + GPR_ASSERT(plaintext_bytes_written == 0); + if (plaintext_length != 0) { + GPR_ASSERT(memcmp(zero_message, plaintext, plaintext_length) == 0); + } + gpr_free(corrupt_nonce); + gpr_free(error_message); + } + + /* Corrupt ciphertext_and_tag */ + plaintext_bytes_written = 0; + uint8_t* corrupt_ciphertext_and_tag; + gsec_test_copy_and_alter_random_byte(ciphertext_and_tag, + &corrupt_ciphertext_and_tag, + ciphertext_and_tag_length); + status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, + corrupt_ciphertext_and_tag, ciphertext_bytes_written, plaintext, + plaintext_length, &plaintext_bytes_written, &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + GPR_ASSERT(plaintext_bytes_written == 0); + if (plaintext_length != 0) { + GPR_ASSERT(memcmp(zero_message, plaintext, plaintext_length) == 0); + } + gpr_free(error_message); + gpr_free(corrupt_ciphertext_and_tag); + + /* Corrupt start of ciphertext_and_tag */ + plaintext_bytes_written = 0; + gsec_test_copy(ciphertext_and_tag, &corrupt_ciphertext_and_tag, + ciphertext_and_tag_length); + (*corrupt_ciphertext_and_tag)++; + status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, + corrupt_ciphertext_and_tag, ciphertext_bytes_written, plaintext, + plaintext_length, &plaintext_bytes_written, &error_message); + GPR_ASSERT(plaintext_bytes_written == 0); + if (plaintext_length != 0) { + GPR_ASSERT(memcmp(zero_message, plaintext, plaintext_length) == 0); + } + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + gpr_free(error_message); + gpr_free(corrupt_ciphertext_and_tag); + + /* Corrupt end of ciphertext_and_tag */ + plaintext_bytes_written = 0; + gsec_test_copy(ciphertext_and_tag, &corrupt_ciphertext_and_tag, + ciphertext_and_tag_length); + (*(corrupt_ciphertext_and_tag + ciphertext_and_tag_length - 1))++; + + status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, + corrupt_ciphertext_and_tag, ciphertext_bytes_written, plaintext, + plaintext_length, &plaintext_bytes_written, &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + GPR_ASSERT(plaintext_bytes_written == 0); + if (plaintext_length != 0) { + GPR_ASSERT(memcmp(zero_message, plaintext, plaintext_length) == 0); + } + gpr_free(error_message); + gpr_free(corrupt_ciphertext_and_tag); + } + + gpr_free(zero_message); + gpr_free(nonce); + gpr_free(aad); + gpr_free(message); + 
gpr_free(plaintext); + gpr_free(ciphertext_and_tag); +} + +static void gsec_test_encrypt_decrypt(gsec_aead_crypter* crypter) { + GPR_ASSERT(crypter != nullptr); + size_t aad_length, message_length; + aad_length = gsec_test_bias_random_uint32(kTestMaxLength); + message_length = gsec_test_bias_random_uint32(kTestMaxLength); + gsec_test_random_encrypt_decrypt(crypter, aad_length, message_length); + gsec_test_random_encrypt_decrypt(crypter, 0, message_length); + gsec_test_random_encrypt_decrypt(crypter, aad_length, 0); +} + +static void gsec_test_multiple_random_encrypt_decrypt( + gsec_aead_crypter* crypter, size_t* aad_lengths, size_t* message_lengths, + size_t count) { + GPR_ASSERT(crypter != nullptr); + size_t nonce_length, tag_length; + uint8_t **nonces, **aads, **messages; + nonces = static_cast(gpr_malloc(sizeof(uint8_t*) * count)); + aads = static_cast(gpr_malloc(sizeof(uint8_t*) * count)); + messages = static_cast(gpr_malloc(sizeof(uint8_t*) * count)); + + gsec_aead_crypter_nonce_length(crypter, &nonce_length, nullptr); + gsec_aead_crypter_tag_length(crypter, &tag_length, nullptr); + + size_t ind; + for (ind = 0; ind < count; ind++) { + size_t aad_length = (aad_lengths == nullptr) ? 0 : aad_lengths[ind]; + size_t message_length = + (message_lengths == nullptr) ? 0 : message_lengths[ind]; + gsec_test_random_array(&(nonces[ind]), nonce_length); + gsec_test_random_array(&(aads[ind]), aad_length); + gsec_test_random_array(&(messages[ind]), message_length); + } + + size_t* ciphertext_and_tag_lengths = + static_cast(gpr_malloc(sizeof(size_t) * count)); + size_t* ciphertext_bytes_writtens = + static_cast(gpr_malloc(sizeof(size_t) * count)); + size_t* plaintext_lengths = + static_cast(gpr_malloc(sizeof(size_t) * count)); + size_t* plaintext_bytes_writtens = + static_cast(gpr_malloc(sizeof(size_t) * count)); + uint8_t** ciphertext_and_tags = + static_cast(gpr_malloc(sizeof(uint8_t*) * count)); + uint8_t** plaintexts = + static_cast(gpr_malloc(sizeof(uint8_t*) * count)); + + /* Do encryption */ + for (ind = 0; ind < count; ind++) { + size_t aad_length = (aad_lengths == nullptr) ? 0 : aad_lengths[ind]; + size_t message_length = + (message_lengths == nullptr) ? 0 : message_lengths[ind]; + gsec_aead_crypter_max_ciphertext_and_tag_length( + crypter, message_length, &(ciphertext_and_tag_lengths[ind]), nullptr); + ciphertext_and_tags[ind] = + static_cast(gpr_malloc(ciphertext_and_tag_lengths[ind])); + grpc_status_code status = gsec_aead_crypter_encrypt( + crypter, nonces[ind], nonce_length, aads[ind], aad_length, + messages[ind], message_length, ciphertext_and_tags[ind], + ciphertext_and_tag_lengths[ind], &(ciphertext_bytes_writtens[ind]), + nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(message_length + tag_length == ciphertext_and_tag_lengths[ind]); + GPR_ASSERT(ciphertext_bytes_writtens[ind] == + ciphertext_and_tag_lengths[ind]); + } + /* Do Decryption */ + for (ind = 0; ind < count; ind++) { + size_t aad_length = (aad_lengths == nullptr) ? 0 : aad_lengths[ind]; + size_t message_length = + (message_lengths == nullptr) ? 
0 : message_lengths[ind]; + gsec_aead_crypter_max_plaintext_length(crypter, + ciphertext_bytes_writtens[ind], + &(plaintext_lengths[ind]), nullptr); + plaintexts[ind] = static_cast(gpr_malloc(plaintext_lengths[ind])); + grpc_status_code status = gsec_aead_crypter_decrypt( + crypter, nonces[ind], nonce_length, aads[ind], aad_length, + ciphertext_and_tags[ind], ciphertext_bytes_writtens[ind], + plaintexts[ind], plaintext_lengths[ind], + &(plaintext_bytes_writtens[ind]), nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(message_length == plaintext_bytes_writtens[ind]); + if (message_length != 0) { + GPR_ASSERT(memcmp(messages[ind], plaintexts[ind], message_length) == 0); + } + } + + /* Slice the plaintext and encrypt with iovecs */ + for (ind = 0; ind < count; ind++) { + size_t aad_length = (aad_lengths == nullptr) ? 0 : aad_lengths[ind]; + struct iovec* aad_vecs = nullptr; + size_t aad_vecs_length = 0; + gsec_randomly_slice(aads[ind], aad_length, &aad_vecs, &aad_vecs_length); + size_t message_length = + (message_lengths == nullptr) ? 0 : message_lengths[ind]; + struct iovec* message_vecs = nullptr; + size_t message_vecs_length = 0; + gsec_randomly_slice(messages[ind], message_length, &message_vecs, + &message_vecs_length); + + size_t ciphertext_length = ciphertext_and_tag_lengths[ind]; + uint8_t* another_ciphertext = + static_cast(malloc(ciphertext_length)); + struct iovec another_ciphertext_vec = {another_ciphertext, + ciphertext_length}; + + char* error_details = nullptr; + size_t ciphertext_bytes_written = 0; + gsec_assert_ok( + gsec_aead_crypter_encrypt_iovec( + crypter, nonces[ind], nonce_length, aad_vecs, aad_vecs_length, + message_vecs, message_vecs_length, another_ciphertext_vec, + &ciphertext_bytes_written, &error_details), + error_details); + GPR_ASSERT(memcmp(ciphertext_and_tags[ind], another_ciphertext_vec.iov_base, + ciphertext_length) == 0); + free(another_ciphertext); + free(aad_vecs); + free(message_vecs); + } + + /* Slice the ciphertext and decrypt with iovecs */ + for (ind = 0; ind < count; ind++) { + size_t message_length = + (message_lengths == nullptr) ? 0 : message_lengths[ind]; + message_length = message_length + 0; + + size_t aad_length = (aad_lengths == nullptr) ? 
0 : aad_lengths[ind]; + + struct iovec* aad_vecs = nullptr; + size_t aad_vecs_length = 0; + gsec_randomly_slice(aads[ind], aad_length, &aad_vecs, &aad_vecs_length); + + struct iovec* ciphertext_vecs = nullptr; + size_t ciphertext_vecs_length = 0; + gsec_randomly_slice(ciphertext_and_tags[ind], + ciphertext_bytes_writtens[ind], &ciphertext_vecs, + &ciphertext_vecs_length); + + size_t decrypted_length = plaintext_lengths[ind]; + uint8_t* decrypted = static_cast(malloc(decrypted_length)); + struct iovec decrypted_vec = {decrypted, decrypted_length}; + + char* error_details = nullptr; + gsec_assert_ok(gsec_aead_crypter_decrypt_iovec( + crypter, nonces[ind], nonce_length, aad_vecs, + aad_vecs_length, ciphertext_vecs, ciphertext_vecs_length, + decrypted_vec, &decrypted_length, &error_details), + error_details); + GPR_ASSERT(decrypted_vec.iov_len == message_length); + if (message_length != 0) { + GPR_ASSERT( + memcmp(decrypted_vec.iov_base, messages[ind], message_length) == 0); + } + free(decrypted); + free(aad_vecs); + free(ciphertext_vecs); + } + + for (ind = 0; ind < count; ind++) { + gpr_free(nonces[ind]); + gpr_free(aads[ind]); + gpr_free(messages[ind]); + gpr_free(ciphertext_and_tags[ind]); + gpr_free(plaintexts[ind]); + } + gpr_free(nonces); + gpr_free(aads); + gpr_free(messages); + gpr_free(ciphertext_and_tag_lengths); + gpr_free(ciphertext_bytes_writtens); + gpr_free(plaintext_lengths); + gpr_free(plaintext_bytes_writtens); + gpr_free(ciphertext_and_tags); + gpr_free(plaintexts); +} + +static void gsec_test_multiple_encrypt_decrypt(gsec_aead_crypter* crypter) { + GPR_ASSERT(crypter != nullptr); + size_t count = kTestNumEncryptions; + size_t* aad_lengths = + static_cast(gpr_malloc(sizeof(size_t) * count)); + size_t* message_lengths = + static_cast(gpr_malloc(sizeof(size_t) * count)); + size_t ind; + for (ind = 0; ind < count; ind++) { + aad_lengths[ind] = gsec_test_bias_random_uint32(kTestMaxLength); + message_lengths[ind] = gsec_test_bias_random_uint32(kTestMaxLength); + } + gsec_test_multiple_random_encrypt_decrypt(crypter, aad_lengths, + message_lengths, count); + gsec_test_multiple_random_encrypt_decrypt(crypter, aad_lengths, nullptr, + count); + gsec_test_multiple_random_encrypt_decrypt(crypter, nullptr, message_lengths, + count); + gpr_free(aad_lengths); + gpr_free(message_lengths); +} + +static void gsec_test_encryption_failure(gsec_aead_crypter* crypter) { + GPR_ASSERT(crypter != nullptr); + size_t aad_length = kTestMaxLength; + size_t message_length = kTestMaxLength; + size_t nonce_length; + + char* error_message; + uint8_t *nonce, *aad, *message; + + gsec_aead_crypter_nonce_length(crypter, &nonce_length, nullptr); + gsec_test_random_array(&nonce, nonce_length); + gsec_test_random_array(&aad, aad_length); + gsec_test_random_array(&message, message_length); + + size_t ciphertext_and_tag_length, ciphertext_bytes_written = 0; + gsec_aead_crypter_max_ciphertext_and_tag_length( + crypter, message_length, &ciphertext_and_tag_length, nullptr); + uint8_t* ciphertext_and_tag = + static_cast(gpr_malloc(ciphertext_and_tag_length)); + + /* nullptr nonce */ + grpc_status_code status = gsec_aead_crypter_encrypt( + crypter, nullptr, nonce_length, aad, aad_length, message, message_length, + ciphertext_and_tag, ciphertext_and_tag_length, &ciphertext_bytes_written, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Nonce buffer is nullptr.")); + gpr_free(error_message); + + /* Big nonce */ + status = 
gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length + 1, aad, aad_length, message, + message_length, ciphertext_and_tag, ciphertext_and_tag_length, + &ciphertext_bytes_written, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Nonce buffer has the wrong length.")); + gpr_free(error_message); + + /* Small nonce */ + status = gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length - 1, aad, aad_length, message, + message_length, ciphertext_and_tag, ciphertext_and_tag_length, + &ciphertext_bytes_written, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Nonce buffer has the wrong length.")); + gpr_free(error_message); + + /* nullptr aad */ + status = gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length, nullptr, aad_length, message, + message_length, ciphertext_and_tag, ciphertext_and_tag_length, + &ciphertext_bytes_written, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, "aad is nullptr.")); + gpr_free(error_message); + + /* nullptr aad with zero length */ + gsec_assert_ok( + gsec_aead_crypter_encrypt(crypter, nonce, nonce_length, nullptr, 0, + message, message_length, ciphertext_and_tag, + ciphertext_and_tag_length, + &ciphertext_bytes_written, &error_message), + error_message); + + /* nullptr plaintext */ + status = gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length, aad, aad_length, nullptr, message_length, + ciphertext_and_tag, ciphertext_and_tag_length, &ciphertext_bytes_written, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "plaintext is nullptr.")); + gpr_free(error_message); + + /* nullptr ciphertext */ + status = gsec_aead_crypter_encrypt(crypter, nonce, nonce_length, aad, + aad_length, message, message_length, + nullptr, ciphertext_and_tag_length, + &ciphertext_bytes_written, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "ciphertext is nullptr.")); + gpr_free(error_message); + + /* Short ciphertext */ + status = gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length, aad, aad_length, message, message_length, + ciphertext_and_tag, ciphertext_and_tag_length - 1, + &ciphertext_bytes_written, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "ciphertext is too small to hold a tag.")); + gpr_free(error_message); + + /* nullptr ciphertext_bytes_written */ + status = gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length, aad, aad_length, message, message_length, + ciphertext_and_tag, ciphertext_and_tag_length, nullptr, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "bytes_written is nullptr.")); + gpr_free(error_message); + + /* nullptr plaintext/ciphertext encrypt with zero length */ + gsec_assert_ok(gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length, aad, aad_length, nullptr, 0, + ciphertext_and_tag, ciphertext_and_tag_length, + &ciphertext_bytes_written, &error_message), + error_message); + + /* Success */ + status = gsec_aead_crypter_encrypt( + crypter, nonce, nonce_length, aad, aad_length, message, message_length, + ciphertext_and_tag, ciphertext_and_tag_length, &ciphertext_bytes_written, + 
      &error_message);
+  GPR_ASSERT(status == GRPC_STATUS_OK);
+
+  gpr_free(message);
+  gpr_free(aad);
+  gpr_free(nonce);
+  gpr_free(ciphertext_and_tag);
+}
+
+static void gsec_test_decryption_failure(gsec_aead_crypter* crypter) {
+  GPR_ASSERT(crypter != nullptr);
+  size_t aad_length = kTestMaxLength;
+  size_t message_length = kTestMaxLength;
+  size_t nonce_length, tag_length;
+  uint8_t *nonce, *aad, *message;
+
+  gsec_aead_crypter_nonce_length(crypter, &nonce_length, nullptr);
+  gsec_aead_crypter_tag_length(crypter, &tag_length, nullptr);
+  gsec_test_random_array(&nonce, nonce_length);
+  gsec_test_random_array(&aad, aad_length);
+  gsec_test_random_array(&message, message_length);
+
+  /* Test encryption */
+  size_t ciphertext_and_tag_length, ciphertext_bytes_written = 0;
+  gsec_aead_crypter_max_ciphertext_and_tag_length(
+      crypter, message_length, &ciphertext_and_tag_length, nullptr);
+  uint8_t* ciphertext_and_tag =
+      static_cast<uint8_t*>(gpr_malloc(ciphertext_and_tag_length));
+
+  grpc_status_code status = gsec_aead_crypter_encrypt(
+      crypter, nonce, nonce_length, aad, aad_length, message, message_length,
+      ciphertext_and_tag, ciphertext_and_tag_length, &ciphertext_bytes_written,
+      nullptr);
+  GPR_ASSERT(status == GRPC_STATUS_OK);
+  GPR_ASSERT(ciphertext_bytes_written == ciphertext_and_tag_length);
+
+  size_t plaintext_length, plaintext_bytes_written = 0;
+  gsec_aead_crypter_max_plaintext_length(crypter, ciphertext_bytes_written,
+                                         &plaintext_length, nullptr);
+  uint8_t* plaintext = static_cast<uint8_t*>(gpr_malloc(plaintext_length));
+
+  char* error_message;
+  /* nullptr nonce */
+  status = gsec_aead_crypter_decrypt(
+      crypter, nullptr, nonce_length, aad, aad_length, ciphertext_and_tag,
+      ciphertext_and_tag_length, plaintext, plaintext_length,
+      &plaintext_bytes_written, &error_message);
+  GPR_ASSERT(gsec_test_expect_compare_code_and_substr(
+      status, GRPC_STATUS_INVALID_ARGUMENT, error_message,
+      "Nonce buffer is nullptr."));
+  gpr_free(error_message);
+
+  /* Big nonce */
+  status = gsec_aead_crypter_decrypt(
+      crypter, nonce, nonce_length + 1, aad, aad_length, ciphertext_and_tag,
+      ciphertext_and_tag_length, plaintext, plaintext_length,
+      &plaintext_bytes_written, &error_message);
+  GPR_ASSERT(gsec_test_expect_compare_code_and_substr(
+      status, GRPC_STATUS_INVALID_ARGUMENT, error_message,
+      "Nonce buffer has the wrong length."));
+  gpr_free(error_message);
+
+  /* Small nonce */
+  status = gsec_aead_crypter_decrypt(
+      crypter, nonce, nonce_length - 1, aad, aad_length, ciphertext_and_tag,
+      ciphertext_and_tag_length, plaintext, plaintext_length,
+      &plaintext_bytes_written, &error_message);
+  GPR_ASSERT(gsec_test_expect_compare_code_and_substr(
+      status, GRPC_STATUS_INVALID_ARGUMENT, error_message,
+      "Nonce buffer has the wrong length."));
+  gpr_free(error_message);
+
+  /* nullptr aad */
+  status = gsec_aead_crypter_decrypt(
+      crypter, nonce, nonce_length, nullptr, aad_length, ciphertext_and_tag,
+      ciphertext_and_tag_length, plaintext, plaintext_length,
+      &plaintext_bytes_written, &error_message);
+  GPR_ASSERT(gsec_test_expect_compare_code_and_substr(
+      status, GRPC_STATUS_INVALID_ARGUMENT, error_message, "aad is nullptr."));
+  gpr_free(error_message);
+
+  /* nullptr aad with zero length */
+  status = gsec_aead_crypter_encrypt(
+      crypter, nonce, nonce_length, nullptr, 0, message, message_length,
+      ciphertext_and_tag, ciphertext_and_tag_length, &ciphertext_bytes_written,
+      &error_message);
+  GPR_ASSERT(status == GRPC_STATUS_OK);
+
+  status = gsec_aead_crypter_decrypt(
+      crypter, nonce, nonce_length, nullptr, 0,
ciphertext_and_tag, + ciphertext_and_tag_length, plaintext, plaintext_length, + &plaintext_bytes_written, &error_message); + GPR_ASSERT(status == GRPC_STATUS_OK); + + /* Small ciphertext */ + if (tag_length > 0) { + status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, ciphertext_and_tag, + tag_length - 1, plaintext, plaintext_length, &plaintext_bytes_written, + &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "ciphertext is too small to hold a tag.")); + gpr_free(error_message); + } + + /* nullptr ciphertext */ + status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, nullptr, + ciphertext_and_tag_length, plaintext, plaintext_length, + &plaintext_bytes_written, &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "ciphertext is nullptr.")); + gpr_free(error_message); + + /* nullptr plaintext */ + status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, ciphertext_and_tag, + ciphertext_and_tag_length, nullptr, plaintext_length, + &plaintext_bytes_written, &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "plaintext is nullptr, but plaintext_length is positive.")); + gpr_free(error_message); + + /* Short plaintext */ + status = gsec_aead_crypter_decrypt( + crypter, nonce, nonce_length, aad, aad_length, ciphertext_and_tag, + ciphertext_and_tag_length, plaintext, plaintext_length - 1, + &plaintext_bytes_written, &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Not enough plaintext buffer to hold encrypted ciphertext.")); + gpr_free(error_message); + + /* nullptr plaintext_bytes_written */ + status = gsec_aead_crypter_decrypt(crypter, nonce, nonce_length, aad, + aad_length, ciphertext_and_tag, + ciphertext_and_tag_length, plaintext, + plaintext_length, nullptr, &error_message); + + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "bytes_written is nullptr.")); + gpr_free(error_message); + + gpr_free(message); + gpr_free(plaintext); + gpr_free(ciphertext_and_tag); + gpr_free(aad); + gpr_free(nonce); +} + +static void gsec_test_encrypt_decrypt_test_vector( + gsec_aead_crypter* crypter, gsec_aead_test_vector* test_vector) { + GPR_ASSERT(crypter != nullptr); + /* Test byte-based encryption interface. 
*/
+  size_t ciphertext_and_tag_length, ciphertext_bytes_written = 0;
+  gsec_aead_crypter_max_ciphertext_and_tag_length(
+      crypter, test_vector->plaintext_length, &ciphertext_and_tag_length,
+      nullptr);
+  uint8_t* ciphertext_and_tag_bytes =
+      static_cast<uint8_t*>(gpr_malloc(ciphertext_and_tag_length));
+  grpc_status_code status = gsec_aead_crypter_encrypt(
+      crypter, test_vector->nonce, test_vector->nonce_length, test_vector->aad,
+      test_vector->aad_length, test_vector->plaintext,
+      test_vector->plaintext_length, ciphertext_and_tag_bytes,
+      ciphertext_and_tag_length, &ciphertext_bytes_written, nullptr);
+
+  GPR_ASSERT(status == GRPC_STATUS_OK);
+  GPR_ASSERT(ciphertext_bytes_written == ciphertext_and_tag_length);
+  GPR_ASSERT(memcmp(test_vector->ciphertext_and_tag, ciphertext_and_tag_bytes,
+                    ciphertext_and_tag_length) == 0);
+
+  /* Test byte-based decryption interface */
+  size_t plaintext_length, plaintext_bytes_written = 0;
+  gsec_aead_crypter_max_plaintext_length(crypter, ciphertext_and_tag_length,
+                                         &plaintext_length, nullptr);
+  uint8_t* plaintext_bytes =
+      static_cast<uint8_t*>(gpr_malloc(plaintext_length));
+  status = gsec_aead_crypter_decrypt(
+      crypter, test_vector->nonce, test_vector->nonce_length, test_vector->aad,
+      test_vector->aad_length, test_vector->ciphertext_and_tag,
+      test_vector->ciphertext_and_tag_length, plaintext_bytes, plaintext_length,
+      &plaintext_bytes_written, nullptr);
+  GPR_ASSERT(status == GRPC_STATUS_OK);
+  if (plaintext_bytes_written != 0) {
+    GPR_ASSERT(memcmp(test_vector->plaintext, plaintext_bytes,
+                      plaintext_bytes_written) == 0);
+  }
+
+  gpr_free(ciphertext_and_tag_bytes);
+  gpr_free(plaintext_bytes);
+}
+
+static void gsec_test_get_crypter_from_test_vector(
+    gsec_aead_crypter** crypter, gsec_aead_test_vector* test_vector,
+    bool rekey = false) {
+  size_t key_length = test_vector->key_length;
+  GPR_ASSERT(key_length == kAes128GcmKeyLength ||
+             key_length == kAes256GcmKeyLength ||
+             key_length == kAes128GcmRekeyKeyLength);
+  size_t nonce_length = test_vector->nonce_length;
+  GPR_ASSERT(nonce_length == kAesGcmNonceLength);
+  size_t plaintext_length = test_vector->plaintext_length;
+  size_t ciphertext_and_tag_length = test_vector->ciphertext_and_tag_length;
+  GPR_ASSERT(ciphertext_and_tag_length == plaintext_length + kAesGcmTagLength);
+  size_t tag_length = ciphertext_and_tag_length - plaintext_length;
+  gsec_aes_gcm_aead_crypter_create(test_vector->key, key_length, nonce_length,
+                                   tag_length, rekey, crypter, nullptr);
+}
+
+static void gsec_test_verify_crypter_on_test_vector(
+    gsec_aead_test_vector* test_vector, bool rekey = false) {
+  gsec_aead_crypter* crypter;
+  gsec_test_get_crypter_from_test_vector(&crypter, test_vector, rekey);
+  gsec_test_encrypt_decrypt_test_vector(crypter, test_vector);
+  gsec_aead_crypter_destroy(crypter);
+}
+
+static void gsec_aead_malloc_test_vector(
+    gsec_aead_test_vector** test_vector, const uint8_t* key, size_t key_length,
+    const uint8_t* nonce, size_t nonce_length, const uint8_t* aad,
+    size_t aad_length, const uint8_t* plaintext, size_t plaintext_length,
+    const uint8_t* ciphertext_and_tag, size_t ciphertext_and_tag_length) {
+  *test_vector = static_cast<gsec_aead_test_vector*>(
+      gpr_malloc(sizeof(gsec_aead_test_vector)));
+  (*test_vector)->key_length = key_length;
+  (*test_vector)->nonce_length = nonce_length;
+  (*test_vector)->aad_length = aad_length;
+  (*test_vector)->plaintext_length = plaintext_length;
+  (*test_vector)->ciphertext_and_tag_length = ciphertext_and_tag_length;
+  gsec_test_copy(key, &((*test_vector)->key), key_length);
+  gsec_test_copy(nonce, &((*test_vector)->nonce), nonce_length);
+  gsec_test_copy(aad, &((*test_vector)->aad), aad_length);
+  gsec_test_copy(plaintext, &((*test_vector)->plaintext), plaintext_length);
+  gsec_test_copy(ciphertext_and_tag, &((*test_vector)->ciphertext_and_tag),
+                 ciphertext_and_tag_length);
+}
+
+static void gsec_aead_free_test_vector(gsec_aead_test_vector* test_vector) {
+  gpr_free(test_vector->key);
+  gpr_free(test_vector->nonce);
+  gpr_free(test_vector->aad);
+  gpr_free(test_vector->plaintext);
+  gpr_free(test_vector->ciphertext_and_tag);
+  gpr_free(test_vector);
+}
+
+static void gsec_test_create_random_aes_gcm_crypter(gsec_aead_crypter** crypter,
+                                                    size_t key_length,
+                                                    size_t nonce_length,
+                                                    size_t tag_length,
+                                                    bool rekey) {
+  uint8_t* key;
+  gsec_test_random_array(&key, key_length);
+  gsec_aes_gcm_aead_crypter_create(key, key_length, nonce_length, tag_length,
+                                   rekey, crypter, nullptr);
+  gpr_free(key);
+}
+
+static void gsec_test_get_random_aes_gcm_crypters(
+    gsec_aead_crypter*** crypters) {
+  *crypters = static_cast<gsec_aead_crypter**>(
+      gpr_malloc(sizeof(gsec_aead_crypter*) * kTestNumCrypters));
+  gsec_test_create_random_aes_gcm_crypter(
+      &((*crypters)[0]), kAes128GcmKeyLength, kAesGcmNonceLength,
+      kAesGcmTagLength, /*rekey=*/false);
+  gsec_test_create_random_aes_gcm_crypter(
+      &((*crypters)[1]), kAes256GcmKeyLength, kAesGcmNonceLength,
+      kAesGcmTagLength, /*rekey=*/false);
+  gsec_test_create_random_aes_gcm_crypter(
+      &((*crypters)[2]), kAes128GcmRekeyKeyLength, kAesGcmNonceLength,
+      kAesGcmTagLength, /*rekey=*/true);
+}
+
+static void gsec_test_do_generic_crypter_tests() {
+  gsec_aead_crypter** crypters;
+  gsec_test_get_random_aes_gcm_crypters(&crypters);
+  size_t ind;
+  for (ind = 0; ind < kTestNumCrypters; ind++) {
+    gsec_test_encrypt_decrypt(crypters[ind]);
+    gsec_test_multiple_encrypt_decrypt(crypters[ind]);
+    gsec_test_encryption_failure(crypters[ind]);
+    gsec_test_decryption_failure(crypters[ind]);
+  }
+  for (ind = 0; ind < kTestNumCrypters; ind++) {
+    gsec_aead_crypter_destroy(crypters[ind]);
+  }
+  gpr_free(crypters);
+}
+
+static void gsec_test_do_vector_tests_rekey_nist() {
+  // NIST vectors from:
+  // http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
+  //
+  // IEEE vectors from:
+  // http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf
+  //
+  // Key expanded by setting expandedKey = (key||(key ^ {0x01, .., 0x01})||(key ^
+  // {0x02,..,0x02}))[0:44].
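[Editorial aside, not part of the patch.] The expansion rule in the comment above is what turns each original 16-byte AES-128 key from the NIST/IEEE documents into the 44-byte (kAes128GcmRekeyKeyLength) keys key_0 through key_23 used by the derived vectors below. As an illustrative sketch only (expand_rekey_key is a hypothetical helper name, not part of the gsec API), the derivation could be written as:

#include <cstdint>
#include <cstring>

// Concatenate key, key ^ {0x01,..,0x01}, and key ^ {0x02,..,0x02}, then keep
// only the first 44 bytes, following the comment above.
static void expand_rekey_key(const uint8_t key[16], uint8_t expanded[44]) {
  uint8_t concatenated[48];
  for (int block = 0; block < 3; block++) {
    for (int i = 0; i < 16; i++) {
      concatenated[block * 16 + i] = key[i] ^ static_cast<uint8_t>(block);
    }
  }
  std::memcpy(expanded, concatenated, 44);
}

For example, the all-zero key of NIST test vector 1 expands to 16 bytes of 0x00, 16 bytes of 0x01, and 12 bytes of 0x02, which is exactly key_0 below; the first 16 bytes of key_8 are the IEEE 2.1.1 key and the next 16 are that key XORed with 0x01.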
+ + gsec_aead_test_vector vec; + + // Derived from NIST test vector 1 + uint8_t nonce_0[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; + uint8_t aad_0[1] = {}; + uint8_t key_0[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x2, + 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2}; + uint8_t plaintext_0[1] = {}; + uint8_t ciphertext_0[] = {0x85, 0xE8, 0x73, 0xE0, 0x2, 0xF6, 0xEB, 0xDC, + 0x40, 0x60, 0x95, 0x4E, 0xB8, 0x67, 0x55, 0x8}; + vec = {nonce_0, aad_0, key_0, plaintext_0, ciphertext_0, 12, 0, 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from NIST test vector 2 + uint8_t nonce_1[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; + uint8_t aad_1[1] = {}; + uint8_t key_1[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x2, + 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2}; + uint8_t plaintext_1[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; + uint8_t ciphertext_1[] = {0x51, 0xE9, 0xA8, 0xCB, 0x23, 0xCA, 0x25, 0x12, + 0xC8, 0x25, 0x6A, 0xFF, 0xF8, 0xE7, 0x2D, 0x68, + 0x1A, 0xCA, 0x19, 0xA1, 0x14, 0x8A, 0xC1, 0x15, + 0xE8, 0x3D, 0xF4, 0x88, 0x8C, 0xC0, 0xD, 0x11}; + vec = {nonce_1, aad_1, key_1, plaintext_1, ciphertext_1, 12, 0, 44, 16, 32}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from NIST test vector 3 + uint8_t nonce_2[] = {0xCA, 0xFE, 0xBA, 0xBE, 0xFA, 0xCE, + 0xDB, 0xAD, 0xDE, 0xCA, 0xF8, 0x88}; + uint8_t aad_2[1] = {}; + uint8_t key_2[] = {0xFE, 0xFF, 0xE9, 0x92, 0x86, 0x65, 0x73, 0x1C, 0x6D, + 0x6A, 0x8F, 0x94, 0x67, 0x30, 0x83, 0x8, 0xFF, 0xFE, + 0xE8, 0x93, 0x87, 0x64, 0x72, 0x1D, 0x6C, 0x6B, 0x8E, + 0x95, 0x66, 0x31, 0x82, 0x9, 0xFC, 0xFD, 0xEB, 0x90, + 0x84, 0x67, 0x71, 0x1E, 0x6F, 0x68, 0x8D, 0x96}; + uint8_t plaintext_2[] = { + 0xD9, 0x31, 0x32, 0x25, 0xF8, 0x84, 0x6, 0xE5, 0xA5, 0x59, 0x9, + 0xC5, 0xAF, 0xF5, 0x26, 0x9A, 0x86, 0xA7, 0xA9, 0x53, 0x15, 0x34, + 0xF7, 0xDA, 0x2E, 0x4C, 0x30, 0x3D, 0x8A, 0x31, 0x8A, 0x72, 0x1C, + 0x3C, 0xC, 0x95, 0x95, 0x68, 0x9, 0x53, 0x2F, 0xCF, 0xE, 0x24, + 0x49, 0xA6, 0xB5, 0x25, 0xB1, 0x6A, 0xED, 0xF5, 0xAA, 0xD, 0xE6, + 0x57, 0xBA, 0x63, 0x7B, 0x39, 0x1A, 0xAF, 0xD2, 0x55}; + uint8_t ciphertext_2[] = { + 0x10, 0x18, 0xED, 0x5A, 0x14, 0x2, 0xA8, 0x65, 0x16, 0xD6, 0x57, 0x6D, + 0x70, 0xB2, 0xFF, 0xCC, 0xCA, 0x26, 0x1B, 0x94, 0xDF, 0x88, 0xB5, 0x8F, + 0x53, 0xB6, 0x4D, 0xFB, 0xA4, 0x35, 0xD1, 0x8B, 0x2F, 0x6E, 0x3B, 0x78, + 0x69, 0xF9, 0x35, 0x3D, 0x4A, 0xC8, 0xCF, 0x9, 0xAF, 0xB1, 0x66, 0x3D, + 0xAA, 0x7B, 0x40, 0x17, 0xE6, 0xFC, 0x2C, 0x17, 0x7C, 0xC, 0x8, 0x7C, + 0xD, 0xF1, 0x16, 0x21, 0x29, 0x95, 0x22, 0x13, 0xCE, 0xE1, 0xBC, 0x6E, + 0x9C, 0x84, 0x95, 0xDD, 0x70, 0x5E, 0x1F, 0x3D}; + vec = {nonce_2, aad_2, key_2, plaintext_2, ciphertext_2, 12, 0, 44, 64, 80}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from NIST test vector 4 + uint8_t nonce_3[] = {0xCA, 0xFE, 0xBA, 0xBE, 0xFA, 0xCE, + 0xDB, 0xAD, 0xDE, 0xCA, 0xF8, 0x88}; + uint8_t aad_3[] = {0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, 0xBE, + 0xEF, 0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, + 0xBE, 0xEF, 0xAB, 0xAD, 0xDA, 0xD2}; + uint8_t key_3[] = {0xFE, 0xFF, 0xE9, 0x92, 0x86, 0x65, 0x73, 0x1C, 0x6D, + 0x6A, 0x8F, 0x94, 0x67, 0x30, 0x83, 0x8, 0xFF, 0xFE, 
+ 0xE8, 0x93, 0x87, 0x64, 0x72, 0x1D, 0x6C, 0x6B, 0x8E, + 0x95, 0x66, 0x31, 0x82, 0x9, 0xFC, 0xFD, 0xEB, 0x90, + 0x84, 0x67, 0x71, 0x1E, 0x6F, 0x68, 0x8D, 0x96}; + uint8_t plaintext_3[] = { + 0xD9, 0x31, 0x32, 0x25, 0xF8, 0x84, 0x6, 0xE5, 0xA5, 0x59, 0x9, 0xC5, + 0xAF, 0xF5, 0x26, 0x9A, 0x86, 0xA7, 0xA9, 0x53, 0x15, 0x34, 0xF7, 0xDA, + 0x2E, 0x4C, 0x30, 0x3D, 0x8A, 0x31, 0x8A, 0x72, 0x1C, 0x3C, 0xC, 0x95, + 0x95, 0x68, 0x9, 0x53, 0x2F, 0xCF, 0xE, 0x24, 0x49, 0xA6, 0xB5, 0x25, + 0xB1, 0x6A, 0xED, 0xF5, 0xAA, 0xD, 0xE6, 0x57, 0xBA, 0x63, 0x7B, 0x39}; + uint8_t ciphertext_3[] = { + 0x10, 0x18, 0xED, 0x5A, 0x14, 0x2, 0xA8, 0x65, 0x16, 0xD6, 0x57, + 0x6D, 0x70, 0xB2, 0xFF, 0xCC, 0xCA, 0x26, 0x1B, 0x94, 0xDF, 0x88, + 0xB5, 0x8F, 0x53, 0xB6, 0x4D, 0xFB, 0xA4, 0x35, 0xD1, 0x8B, 0x2F, + 0x6E, 0x3B, 0x78, 0x69, 0xF9, 0x35, 0x3D, 0x4A, 0xC8, 0xCF, 0x9, + 0xAF, 0xB1, 0x66, 0x3D, 0xAA, 0x7B, 0x40, 0x17, 0xE6, 0xFC, 0x2C, + 0x17, 0x7C, 0xC, 0x8, 0x7C, 0x47, 0x64, 0x56, 0x5D, 0x7, 0x7E, + 0x91, 0x24, 0x0, 0x1D, 0xDB, 0x27, 0xFC, 0x8, 0x48, 0xC5}; + vec = {nonce_3, aad_3, key_3, plaintext_3, ciphertext_3, 12, 20, 44, 60, 76}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from adapted NIST test vector 4 for KDF counter boundary (flip + // nonce bit 15) + uint8_t nonce_4[] = {0xCA, 0x7E, 0xBA, 0xBE, 0xFA, 0xCE, + 0xDB, 0xAD, 0xDE, 0xCA, 0xF8, 0x88}; + uint8_t aad_4[] = {0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, 0xBE, + 0xEF, 0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, + 0xBE, 0xEF, 0xAB, 0xAD, 0xDA, 0xD2}; + uint8_t key_4[] = {0xFE, 0xFF, 0xE9, 0x92, 0x86, 0x65, 0x73, 0x1C, 0x6D, + 0x6A, 0x8F, 0x94, 0x67, 0x30, 0x83, 0x8, 0xFF, 0xFE, + 0xE8, 0x93, 0x87, 0x64, 0x72, 0x1D, 0x6C, 0x6B, 0x8E, + 0x95, 0x66, 0x31, 0x82, 0x9, 0xFC, 0xFD, 0xEB, 0x90, + 0x84, 0x67, 0x71, 0x1E, 0x6F, 0x68, 0x8D, 0x96}; + uint8_t plaintext_4[] = { + 0xD9, 0x31, 0x32, 0x25, 0xF8, 0x84, 0x6, 0xE5, 0xA5, 0x59, 0x9, 0xC5, + 0xAF, 0xF5, 0x26, 0x9A, 0x86, 0xA7, 0xA9, 0x53, 0x15, 0x34, 0xF7, 0xDA, + 0x2E, 0x4C, 0x30, 0x3D, 0x8A, 0x31, 0x8A, 0x72, 0x1C, 0x3C, 0xC, 0x95, + 0x95, 0x68, 0x9, 0x53, 0x2F, 0xCF, 0xE, 0x24, 0x49, 0xA6, 0xB5, 0x25, + 0xB1, 0x6A, 0xED, 0xF5, 0xAA, 0xD, 0xE6, 0x57, 0xBA, 0x63, 0x7B, 0x39}; + uint8_t ciphertext_4[] = { + 0xE6, 0x50, 0xD3, 0xC0, 0xFB, 0x87, 0x93, 0x27, 0xF2, 0xD0, 0x32, + 0x87, 0xFA, 0x93, 0xCD, 0x7, 0x34, 0x2B, 0x13, 0x62, 0x15, 0xAD, + 0xBC, 0xA0, 0xC, 0x3B, 0xD5, 0x9, 0x9E, 0xC4, 0x18, 0x32, 0xB1, + 0xD1, 0x8E, 0x4, 0x23, 0xED, 0x26, 0xBB, 0x12, 0xC6, 0xCD, 0x9, + 0xDE, 0xBB, 0x29, 0x23, 0xA, 0x94, 0xC0, 0xCE, 0xE1, 0x59, 0x3, + 0x65, 0x6F, 0x85, 0xED, 0xB6, 0xFC, 0x50, 0x9B, 0x1B, 0x28, 0x21, + 0x63, 0x82, 0x17, 0x2E, 0xCB, 0xCC, 0x31, 0xE1, 0xE9, 0xB1}; + vec = {nonce_4, aad_4, key_4, plaintext_4, ciphertext_4, 12, 20, 44, 60, 76}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from adapted NIST test vector 4 for KDF counter boundary (flip + // nonce bit 16) + uint8_t nonce_5[] = {0xCA, 0xFE, 0xBB, 0xBE, 0xFA, 0xCE, + 0xDB, 0xAD, 0xDE, 0xCA, 0xF8, 0x88}; + uint8_t aad_5[] = {0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, 0xBE, + 0xEF, 0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, + 0xBE, 0xEF, 0xAB, 0xAD, 0xDA, 0xD2}; + uint8_t key_5[] = {0xFE, 0xFF, 0xE9, 0x92, 0x86, 0x65, 0x73, 0x1C, 0x6D, + 0x6A, 0x8F, 0x94, 0x67, 0x30, 0x83, 0x8, 0xFF, 0xFE, + 0xE8, 0x93, 0x87, 0x64, 0x72, 0x1D, 0x6C, 0x6B, 0x8E, + 0x95, 0x66, 0x31, 0x82, 0x9, 0xFC, 0xFD, 0xEB, 0x90, + 0x84, 0x67, 0x71, 0x1E, 0x6F, 0x68, 0x8D, 0x96}; + uint8_t plaintext_5[] = { + 0xD9, 0x31, 0x32, 0x25, 0xF8, 
0x84, 0x6, 0xE5, 0xA5, 0x59, 0x9, 0xC5, + 0xAF, 0xF5, 0x26, 0x9A, 0x86, 0xA7, 0xA9, 0x53, 0x15, 0x34, 0xF7, 0xDA, + 0x2E, 0x4C, 0x30, 0x3D, 0x8A, 0x31, 0x8A, 0x72, 0x1C, 0x3C, 0xC, 0x95, + 0x95, 0x68, 0x9, 0x53, 0x2F, 0xCF, 0xE, 0x24, 0x49, 0xA6, 0xB5, 0x25, + 0xB1, 0x6A, 0xED, 0xF5, 0xAA, 0xD, 0xE6, 0x57, 0xBA, 0x63, 0x7B, 0x39}; + uint8_t ciphertext_5[] = { + 0xC0, 0x12, 0x1E, 0x6C, 0x95, 0x4D, 0x7, 0x67, 0xF9, 0x66, 0x30, + 0xC3, 0x34, 0x50, 0x99, 0x97, 0x91, 0xB2, 0xDA, 0x2A, 0xD0, 0x5C, + 0x41, 0x90, 0x16, 0x9C, 0xCA, 0xD9, 0xAC, 0x86, 0xFF, 0x1C, 0x72, + 0x1E, 0x3D, 0x82, 0xF2, 0xAD, 0x22, 0xAB, 0x46, 0x3B, 0xAB, 0x4A, + 0x7, 0x54, 0xB7, 0xDD, 0x68, 0xCA, 0x4D, 0xE7, 0xEA, 0x25, 0x31, + 0xB6, 0x25, 0xED, 0xA0, 0x1F, 0x89, 0x31, 0x2B, 0x2A, 0xB9, 0x57, + 0xD5, 0xC7, 0xF8, 0x56, 0x8D, 0xD9, 0x5F, 0xCD, 0xCD, 0x1F}; + vec = {nonce_5, aad_5, key_5, plaintext_5, ciphertext_5, 12, 20, 44, 60, 76}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from adapted NIST test vector 4 for KDF counter boundary (flip + // nonce bit 63) + uint8_t nonce_6[] = {0xCA, 0xFE, 0xBA, 0xBE, 0xFA, 0xCE, + 0xDB, 0x2D, 0xDE, 0xCA, 0xF8, 0x88}; + uint8_t aad_6[] = {0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, 0xBE, + 0xEF, 0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, + 0xBE, 0xEF, 0xAB, 0xAD, 0xDA, 0xD2}; + uint8_t key_6[] = {0xFE, 0xFF, 0xE9, 0x92, 0x86, 0x65, 0x73, 0x1C, 0x6D, + 0x6A, 0x8F, 0x94, 0x67, 0x30, 0x83, 0x8, 0xFF, 0xFE, + 0xE8, 0x93, 0x87, 0x64, 0x72, 0x1D, 0x6C, 0x6B, 0x8E, + 0x95, 0x66, 0x31, 0x82, 0x9, 0xFC, 0xFD, 0xEB, 0x90, + 0x84, 0x67, 0x71, 0x1E, 0x6F, 0x68, 0x8D, 0x96}; + uint8_t plaintext_6[] = { + 0xD9, 0x31, 0x32, 0x25, 0xF8, 0x84, 0x6, 0xE5, 0xA5, 0x59, 0x9, 0xC5, + 0xAF, 0xF5, 0x26, 0x9A, 0x86, 0xA7, 0xA9, 0x53, 0x15, 0x34, 0xF7, 0xDA, + 0x2E, 0x4C, 0x30, 0x3D, 0x8A, 0x31, 0x8A, 0x72, 0x1C, 0x3C, 0xC, 0x95, + 0x95, 0x68, 0x9, 0x53, 0x2F, 0xCF, 0xE, 0x24, 0x49, 0xA6, 0xB5, 0x25, + 0xB1, 0x6A, 0xED, 0xF5, 0xAA, 0xD, 0xE6, 0x57, 0xBA, 0x63, 0x7B, 0x39}; + uint8_t ciphertext_6[] = { + 0x8A, 0xF3, 0x7E, 0xA5, 0x68, 0x4A, 0x4D, 0x81, 0xD4, 0xFD, 0x81, + 0x72, 0x61, 0xFD, 0x97, 0x43, 0x9, 0x9E, 0x7E, 0x6A, 0x2, 0x5E, + 0xAA, 0xCF, 0x8E, 0x54, 0xB1, 0x24, 0xFB, 0x57, 0x43, 0x14, 0x9E, + 0x5, 0xCB, 0x89, 0xF4, 0xA4, 0x94, 0x67, 0xFE, 0x2E, 0x5E, 0x59, + 0x65, 0xF2, 0x9A, 0x19, 0xF9, 0x94, 0x16, 0xB0, 0x1, 0x6B, 0x54, + 0x58, 0x5D, 0x12, 0x55, 0x37, 0x83, 0xBA, 0x59, 0xE9, 0xF7, 0x82, + 0xE8, 0x2E, 0x9, 0x7C, 0x33, 0x6B, 0xF7, 0x98, 0x9F, 0x8}; + vec = {nonce_6, aad_6, key_6, plaintext_6, ciphertext_6, 12, 20, 44, 60, 76}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from adapted NIST test vector 4 for KDF counter boundary (flip + // nonce bit 64) + uint8_t nonce_7[] = {0xCA, 0xFE, 0xBA, 0xBE, 0xFA, 0xCE, + 0xDB, 0xAD, 0xDF, 0xCA, 0xF8, 0x88}; + uint8_t aad_7[] = {0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, 0xBE, + 0xEF, 0xFE, 0xED, 0xFA, 0xCE, 0xDE, 0xAD, + 0xBE, 0xEF, 0xAB, 0xAD, 0xDA, 0xD2}; + uint8_t key_7[] = {0xFE, 0xFF, 0xE9, 0x92, 0x86, 0x65, 0x73, 0x1C, 0x6D, + 0x6A, 0x8F, 0x94, 0x67, 0x30, 0x83, 0x8, 0xFF, 0xFE, + 0xE8, 0x93, 0x87, 0x64, 0x72, 0x1D, 0x6C, 0x6B, 0x8E, + 0x95, 0x66, 0x31, 0x82, 0x9, 0xFC, 0xFD, 0xEB, 0x90, + 0x84, 0x67, 0x71, 0x1E, 0x6F, 0x68, 0x8D, 0x96}; + uint8_t plaintext_7[] = { + 0xD9, 0x31, 0x32, 0x25, 0xF8, 0x84, 0x6, 0xE5, 0xA5, 0x59, 0x9, 0xC5, + 0xAF, 0xF5, 0x26, 0x9A, 0x86, 0xA7, 0xA9, 0x53, 0x15, 0x34, 0xF7, 0xDA, + 0x2E, 0x4C, 0x30, 0x3D, 0x8A, 0x31, 0x8A, 0x72, 0x1C, 0x3C, 0xC, 0x95, + 0x95, 0x68, 0x9, 0x53, 0x2F, 
0xCF, 0xE, 0x24, 0x49, 0xA6, 0xB5, 0x25, + 0xB1, 0x6A, 0xED, 0xF5, 0xAA, 0xD, 0xE6, 0x57, 0xBA, 0x63, 0x7B, 0x39}; + uint8_t ciphertext_7[] = { + 0xFB, 0xD5, 0x28, 0x44, 0x8D, 0x3, 0x46, 0xBF, 0xA8, 0x78, 0x63, + 0x48, 0x64, 0xD4, 0x7, 0xA3, 0x5A, 0x3, 0x9D, 0xE9, 0xDB, 0x2F, + 0x1F, 0xEB, 0x8E, 0x96, 0x5B, 0x3A, 0xE9, 0x35, 0x6C, 0xE6, 0x28, + 0x94, 0x41, 0xD7, 0x7F, 0x8F, 0xD, 0xF2, 0x94, 0x89, 0x1F, 0x37, + 0xEA, 0x43, 0x8B, 0x22, 0x3E, 0x3B, 0xF2, 0xBD, 0xC5, 0x3D, 0x4C, + 0x5A, 0x74, 0xFB, 0x68, 0xB, 0xB3, 0x12, 0xA8, 0xDE, 0xC6, 0xF7, + 0x25, 0x2C, 0xBC, 0xD7, 0xF5, 0x79, 0x97, 0x50, 0xAD, 0x78}; + vec = {nonce_7, aad_7, key_7, plaintext_7, ciphertext_7, 12, 20, 44, 60, 76}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); +} + +static void gsec_test_do_vector_tests_rekey_ieee() { + // IEEE vectors from: + // http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf + // + // Key expanded by setting expandedKey = (key||(key ^ {0x01, .., 0x01})||key ^ + // {0x02,..,0x02}))[0:44]. + + gsec_aead_test_vector vec; + + // Derived from IEEE 2.1.1 54-byte auth + uint8_t nonce_8[] = {0x12, 0x15, 0x35, 0x24, 0xC0, 0x89, + 0x5E, 0x81, 0xB2, 0xC2, 0x84, 0x65}; + uint8_t aad_8[] = {0xD6, 0x9, 0xB1, 0xF0, 0x56, 0x63, 0x7A, 0xD, 0x46, 0xDF, + 0x99, 0x8D, 0x88, 0xE5, 0x22, 0x2A, 0xB2, 0xC2, 0x84, 0x65, + 0x12, 0x15, 0x35, 0x24, 0xC0, 0x89, 0x5E, 0x81, 0x8, 0x0, + 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, + 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x0, 0x1}; + uint8_t key_8[] = {0xAD, 0x7A, 0x2B, 0xD0, 0x3E, 0xAC, 0x83, 0x5A, 0x6F, + 0x62, 0xF, 0xDC, 0xB5, 0x6, 0xB3, 0x45, 0xAC, 0x7B, + 0x2A, 0xD1, 0x3F, 0xAD, 0x82, 0x5B, 0x6E, 0x63, 0xE, + 0xDD, 0xB4, 0x7, 0xB2, 0x44, 0xAF, 0x78, 0x29, 0xD2, + 0x3C, 0xAE, 0x81, 0x58, 0x6D, 0x60, 0xD, 0xDE}; + uint8_t plaintext_8[1] = {}; + uint8_t ciphertext_8[] = {0x3E, 0xA0, 0xB5, 0x84, 0xF3, 0xC8, 0x5E, 0x93, + 0xF9, 0x32, 0xE, 0xA5, 0x91, 0x69, 0x9E, 0xFB}; + vec = {nonce_8, aad_8, key_8, plaintext_8, ciphertext_8, 12, 70, 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.1.2 54-byte auth + uint8_t nonce_9[] = {0x12, 0x15, 0x35, 0x24, 0xC0, 0x89, + 0x5E, 0x81, 0xB2, 0xC2, 0x84, 0x65}; + uint8_t aad_9[] = {0xD6, 0x9, 0xB1, 0xF0, 0x56, 0x63, 0x7A, 0xD, 0x46, 0xDF, + 0x99, 0x8D, 0x88, 0xE5, 0x22, 0x2A, 0xB2, 0xC2, 0x84, 0x65, + 0x12, 0x15, 0x35, 0x24, 0xC0, 0x89, 0x5E, 0x81, 0x8, 0x0, + 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, + 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x0, 0x1}; + uint8_t key_9[] = {0xE3, 0xC0, 0x8A, 0x8F, 0x6, 0xC6, 0xE3, 0xAD, 0x95, + 0xA7, 0x5, 0x57, 0xB2, 0x3F, 0x75, 0x48, 0x3C, 0xE3, + 0x30, 0x21, 0xA9, 0xC7, 0x2B, 0x70, 0x25, 0x66, 0x62, + 0x4, 0xC6, 0x9C, 0xB, 0x72, 0xE1, 0xC2, 0x88, 0x8D, + 0x4, 0xC4, 0xE1, 0xAF, 0x97, 0xA5, 0x7, 0x55}; + uint8_t plaintext_9[1] = {}; + uint8_t ciphertext_9[] = {0x29, 0x4E, 0x2, 0x8B, 0xF1, 0xFE, 0x6F, 0x14, + 0xC4, 0xE8, 0xF7, 0x30, 0x5C, 0x93, 0x3E, 0xB5}; + vec = {nonce_9, aad_9, key_9, plaintext_9, ciphertext_9, 12, 70, 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.2.1 60-byte crypt + uint8_t nonce_10[] = {0x12, 0x15, 0x35, 0x24, 0xC0, 0x89, + 0x5E, 0x81, 0xB2, 
0xC2, 0x84, 0x65}; + uint8_t aad_10[] = {0xD6, 0x9, 0xB1, 0xF0, 0x56, 0x63, 0x7A, + 0xD, 0x46, 0xDF, 0x99, 0x8D, 0x88, 0xE5, + 0x2E, 0x0, 0xB2, 0xC2, 0x84, 0x65, 0x12, + 0x15, 0x35, 0x24, 0xC0, 0x89, 0x5E, 0x81}; + uint8_t key_10[] = {0xAD, 0x7A, 0x2B, 0xD0, 0x3E, 0xAC, 0x83, 0x5A, 0x6F, + 0x62, 0xF, 0xDC, 0xB5, 0x6, 0xB3, 0x45, 0xAC, 0x7B, + 0x2A, 0xD1, 0x3F, 0xAD, 0x82, 0x5B, 0x6E, 0x63, 0xE, + 0xDD, 0xB4, 0x7, 0xB2, 0x44, 0xAF, 0x78, 0x29, 0xD2, + 0x3C, 0xAE, 0x81, 0x58, 0x6D, 0x60, 0xD, 0xDE}; + uint8_t plaintext_10[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x0, 0x2}; + uint8_t ciphertext_10[] = { + 0xDB, 0x3D, 0x25, 0x71, 0x9C, 0x6B, 0xA, 0x3C, 0xA6, 0x14, 0x5C, + 0x15, 0x9D, 0x5C, 0x6E, 0xD9, 0xAF, 0xF9, 0xC6, 0xE0, 0xB7, 0x9F, + 0x17, 0x1, 0x9E, 0xA9, 0x23, 0xB8, 0x66, 0x5D, 0xDF, 0x52, 0x13, + 0x7A, 0xD6, 0x11, 0xF0, 0xD1, 0xBF, 0x41, 0x7A, 0x7C, 0xA8, 0x5E, + 0x45, 0xAF, 0xE1, 0x6, 0xFF, 0x9C, 0x75, 0x69, 0xD3, 0x35, 0xD0, + 0x86, 0xAE, 0x6C, 0x3, 0xF0, 0x9, 0x87, 0xCC, 0xD6}; + vec = {nonce_10, aad_10, key_10, plaintext_10, ciphertext_10, + 12, 28, 44, 48, 64}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.2.2 60-byte crypt + uint8_t nonce_11[] = {0x12, 0x15, 0x35, 0x24, 0xC0, 0x89, + 0x5E, 0x81, 0xB2, 0xC2, 0x84, 0x65}; + uint8_t aad_11[] = {0xD6, 0x9, 0xB1, 0xF0, 0x56, 0x63, 0x7A, + 0xD, 0x46, 0xDF, 0x99, 0x8D, 0x88, 0xE5, + 0x2E, 0x0, 0xB2, 0xC2, 0x84, 0x65, 0x12, + 0x15, 0x35, 0x24, 0xC0, 0x89, 0x5E, 0x81}; + uint8_t key_11[] = {0xE3, 0xC0, 0x8A, 0x8F, 0x6, 0xC6, 0xE3, 0xAD, 0x95, + 0xA7, 0x5, 0x57, 0xB2, 0x3F, 0x75, 0x48, 0x3C, 0xE3, + 0x30, 0x21, 0xA9, 0xC7, 0x2B, 0x70, 0x25, 0x66, 0x62, + 0x4, 0xC6, 0x9C, 0xB, 0x72, 0xE1, 0xC2, 0x88, 0x8D, + 0x4, 0xC4, 0xE1, 0xAF, 0x97, 0xA5, 0x7, 0x55}; + uint8_t plaintext_11[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x0, 0x2}; + uint8_t ciphertext_11[] = { + 0x16, 0x41, 0xF2, 0x8E, 0xC1, 0x3A, 0xFC, 0xC8, 0xF7, 0x90, 0x33, + 0x89, 0x78, 0x72, 0x1, 0x5, 0x16, 0x44, 0x91, 0x49, 0x33, 0xE9, + 0x20, 0x2B, 0xB9, 0xD0, 0x6A, 0xA0, 0x20, 0xC2, 0xA6, 0x7E, 0xF5, + 0x1D, 0xFE, 0x7B, 0xC0, 0xA, 0x85, 0x6C, 0x55, 0xB8, 0xF8, 0x13, + 0x3E, 0x77, 0xF6, 0x59, 0x13, 0x25, 0x2, 0xBA, 0xD6, 0x3F, 0x57, + 0x13, 0xD5, 0x7D, 0xC, 0x11, 0xE0, 0xF8, 0x71, 0xED}; + vec = {nonce_11, aad_11, key_11, plaintext_11, ciphertext_11, + 12, 28, 44, 48, 64}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.3.1 60-byte auth + uint8_t nonce_12[] = {0xF0, 0x76, 0x1E, 0x8D, 0xCD, 0x3D, + 0x0, 0x1, 0x76, 0xD4, 0x57, 0xED}; + uint8_t aad_12[] = { + 0xE2, 0x1, 0x6, 0xD7, 0xCD, 0xD, 0xF0, 0x76, 0x1E, 0x8D, 0xCD, 0x3D, + 0x88, 0xE5, 0x40, 0x0, 0x76, 0xD4, 0x57, 0xED, 0x8, 0x0, 0xF, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, + 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x0, 0x3}; + uint8_t key_12[] = {0x7, 0x1B, 0x11, 0x3B, 0xC, 0xA7, 
0x43, 0xFE, 0xCC, + 0xCF, 0x3D, 0x5, 0x1F, 0x73, 0x73, 0x82, 0x6, 0x1A, + 0x10, 0x3A, 0xD, 0xA6, 0x42, 0xFF, 0xCD, 0xCE, 0x3C, + 0x4, 0x1E, 0x72, 0x72, 0x83, 0x5, 0x19, 0x13, 0x39, + 0xE, 0xA5, 0x41, 0xFC, 0xCE, 0xCD, 0x3F, 0x7}; + uint8_t plaintext_12[1] = {}; + uint8_t ciphertext_12[] = {0x58, 0x83, 0x7A, 0x10, 0x56, 0x2B, 0xF, 0x1F, + 0x8E, 0xDB, 0xE5, 0x8C, 0xA5, 0x58, 0x11, 0xD3}; + vec = {nonce_12, aad_12, key_12, plaintext_12, ciphertext_12, 12, 68, + 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.3.2 60-byte auth + uint8_t nonce_13[] = {0xF0, 0x76, 0x1E, 0x8D, 0xCD, 0x3D, + 0x0, 0x1, 0x76, 0xD4, 0x57, 0xED}; + uint8_t aad_13[] = { + 0xE2, 0x1, 0x6, 0xD7, 0xCD, 0xD, 0xF0, 0x76, 0x1E, 0x8D, 0xCD, 0x3D, + 0x88, 0xE5, 0x40, 0x0, 0x76, 0xD4, 0x57, 0xED, 0x8, 0x0, 0xF, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, + 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x0, 0x3}; + uint8_t key_13[] = {0x69, 0x1D, 0x3E, 0xE9, 0x9, 0xD7, 0xF5, 0x41, 0x67, + 0xFD, 0x1C, 0xA0, 0xB5, 0xD7, 0x69, 0x8, 0x1F, 0x2B, + 0xDE, 0x1A, 0xEE, 0x65, 0x5F, 0xDB, 0xAB, 0x80, 0xBD, + 0x52, 0x95, 0xAE, 0x6B, 0xE7, 0x6B, 0x1F, 0x3C, 0xEB, + 0xB, 0xD5, 0xF7, 0x43, 0x65, 0xFF, 0x1E, 0xA2}; + uint8_t plaintext_13[1] = {}; + uint8_t ciphertext_13[] = {0xC2, 0x72, 0x2F, 0xF6, 0xCA, 0x29, 0xA2, 0x57, + 0x71, 0x8A, 0x52, 0x9D, 0x1F, 0xC, 0x6A, 0x3B}; + vec = {nonce_13, aad_13, key_13, plaintext_13, ciphertext_13, 12, 68, + 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.4.1 54-byte crypt + uint8_t nonce_14[] = {0xF0, 0x76, 0x1E, 0x8D, 0xCD, 0x3D, + 0x0, 0x1, 0x76, 0xD4, 0x57, 0xED}; + uint8_t aad_14[] = {0xE2, 0x1, 0x6, 0xD7, 0xCD, 0xD, 0xF0, + 0x76, 0x1E, 0x8D, 0xCD, 0x3D, 0x88, 0xE5, + 0x4C, 0x2A, 0x76, 0xD4, 0x57, 0xED}; + uint8_t key_14[] = {0x7, 0x1B, 0x11, 0x3B, 0xC, 0xA7, 0x43, 0xFE, 0xCC, + 0xCF, 0x3D, 0x5, 0x1F, 0x73, 0x73, 0x82, 0x6, 0x1A, + 0x10, 0x3A, 0xD, 0xA6, 0x42, 0xFF, 0xCD, 0xCE, 0x3C, + 0x4, 0x1E, 0x72, 0x72, 0x83, 0x5, 0x19, 0x13, 0x39, + 0xE, 0xA5, 0x41, 0xFC, 0xCE, 0xCD, 0x3F, 0x7}; + uint8_t plaintext_14[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, + 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x0, 0x4}; + uint8_t ciphertext_14[] = { + 0xFD, 0x96, 0xB7, 0x15, 0xB9, 0x3A, 0x13, 0x34, 0x6A, 0xF5, 0x1E, 0x8A, + 0xCD, 0xF7, 0x92, 0xCD, 0xC7, 0xB2, 0x68, 0x6F, 0x85, 0x74, 0xC7, 0xE, + 0x6B, 0xC, 0xBF, 0x16, 0x29, 0x1D, 0xED, 0x42, 0x7A, 0xD7, 0x3F, 0xEC, + 0x48, 0xCD, 0x29, 0x8E, 0x5, 0x28, 0xA1, 0xF4, 0xC6, 0x44, 0xA9, 0x49, + 0xFC, 0x31, 0xDC, 0x92, 0x79, 0x70, 0x6D, 0xDB, 0xA3, 0x3F}; + vec = {nonce_14, aad_14, key_14, plaintext_14, ciphertext_14, + 12, 20, 44, 42, 58}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.4.2 54-byte crypt + uint8_t nonce_15[] = {0xF0, 0x76, 0x1E, 0x8D, 0xCD, 0x3D, + 0x0, 0x1, 0x76, 0xD4, 0x57, 0xED}; + uint8_t aad_15[] = {0xE2, 0x1, 0x6, 0xD7, 0xCD, 0xD, 0xF0, + 0x76, 0x1E, 0x8D, 0xCD, 0x3D, 0x88, 0xE5, + 0x4C, 0x2A, 0x76, 0xD4, 0x57, 0xED}; + uint8_t key_15[] = {0x69, 0x1D, 0x3E, 0xE9, 0x9, 0xD7, 0xF5, 0x41, 0x67, + 0xFD, 0x1C, 0xA0, 0xB5, 0xD7, 0x69, 0x8, 0x1F, 0x2B, + 0xDE, 0x1A, 0xEE, 0x65, 
0x5F, 0xDB, 0xAB, 0x80, 0xBD, + 0x52, 0x95, 0xAE, 0x6B, 0xE7, 0x6B, 0x1F, 0x3C, 0xEB, + 0xB, 0xD5, 0xF7, 0x43, 0x65, 0xFF, 0x1E, 0xA2}; + uint8_t plaintext_15[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, + 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x0, 0x4}; + uint8_t ciphertext_15[] = { + 0xB6, 0x8F, 0x63, 0x0, 0xC2, 0xE9, 0xAE, 0x83, 0x3B, 0xDC, 0x7, 0xE, + 0x24, 0x2, 0x1A, 0x34, 0x77, 0x11, 0x8E, 0x78, 0xCC, 0xF8, 0x4E, 0x11, + 0xA4, 0x85, 0xD8, 0x61, 0x47, 0x6C, 0x30, 0xF, 0x17, 0x53, 0x53, 0xD5, + 0xCD, 0xF9, 0x20, 0x8, 0xA4, 0xF8, 0x78, 0xE6, 0xCC, 0x35, 0x77, 0x76, + 0x80, 0x85, 0xC5, 0xA, 0xE, 0x98, 0xFD, 0xA6, 0xCB, 0xB8}; + vec = {nonce_15, aad_15, key_15, plaintext_15, ciphertext_15, + 12, 20, 44, 42, 58}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.5.1 65-byte auth + uint8_t nonce_16[] = {0x7C, 0xFD, 0xE9, 0xF9, 0xE3, 0x37, + 0x24, 0xC6, 0x89, 0x32, 0xD6, 0x12}; + uint8_t aad_16[] = { + 0x84, 0xC5, 0xD5, 0x13, 0xD2, 0xAA, 0xF6, 0xE5, 0xBB, 0xD2, 0x72, 0x77, + 0x88, 0xE5, 0x23, 0x0, 0x89, 0x32, 0xD6, 0x12, 0x7C, 0xFD, 0xE9, 0xF9, + 0xE3, 0x37, 0x24, 0xC6, 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, + 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, 0x0, 0x5}; + uint8_t key_16[] = {0x1, 0x3F, 0xE0, 0xB, 0x5F, 0x11, 0xBE, 0x7F, 0x86, + 0x6D, 0xC, 0xBB, 0xC5, 0x5A, 0x7A, 0x90, 0x0, 0x3E, + 0xE1, 0xA, 0x5E, 0x10, 0xBF, 0x7E, 0x87, 0x6C, 0xD, + 0xBA, 0xC4, 0x5B, 0x7B, 0x91, 0x3, 0x3D, 0xE2, 0x9, + 0x5D, 0x13, 0xBC, 0x7D, 0x84, 0x6F, 0xE, 0xB9}; + uint8_t plaintext_16[1] = {}; + uint8_t ciphertext_16[] = {0xCC, 0xA2, 0xE, 0xEC, 0xDA, 0x62, 0x83, 0xF0, + 0x9B, 0xB3, 0x54, 0x3D, 0xD9, 0x9E, 0xDB, 0x9B}; + vec = {nonce_16, aad_16, key_16, plaintext_16, ciphertext_16, 12, 81, + 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.5.2 65-byte auth + uint8_t nonce_17[] = {0x7C, 0xFD, 0xE9, 0xF9, 0xE3, 0x37, + 0x24, 0xC6, 0x89, 0x32, 0xD6, 0x12}; + uint8_t aad_17[] = { + 0x84, 0xC5, 0xD5, 0x13, 0xD2, 0xAA, 0xF6, 0xE5, 0xBB, 0xD2, 0x72, 0x77, + 0x88, 0xE5, 0x23, 0x0, 0x89, 0x32, 0xD6, 0x12, 0x7C, 0xFD, 0xE9, 0xF9, + 0xE3, 0x37, 0x24, 0xC6, 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, + 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, 0x0, 0x5}; + uint8_t key_17[] = {0x83, 0xC0, 0x93, 0xB5, 0x8D, 0xE7, 0xFF, 0xE1, 0xC0, + 0xDA, 0x92, 0x6A, 0xC4, 0x3F, 0xB3, 0x60, 0x9A, 0xC1, + 0xC8, 0xF, 0xEE, 0x1B, 0x62, 0x44, 0x97, 0xEF, 0x94, + 0x2E, 0x2F, 0x79, 0xA8, 0x23, 0x81, 0xC2, 0x91, 0xB7, + 0x8F, 0xE5, 0xFD, 0xE3, 0xC2, 0xD8, 0x90, 0x68}; + uint8_t plaintext_17[1] = {}; + uint8_t ciphertext_17[] = {0xB2, 0x32, 0xCC, 0x1D, 0xA5, 0x11, 0x7B, 0xF1, + 0x50, 0x3, 0x73, 0x4F, 0xA5, 0x99, 0xD2, 0x71}; + vec = {nonce_17, aad_17, key_17, plaintext_17, ciphertext_17, 12, 81, + 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.6.1 61-byte crypt + uint8_t nonce_18[] = {0x7C, 
0xFD, 0xE9, 0xF9, 0xE3, 0x37, + 0x24, 0xC6, 0x89, 0x32, 0xD6, 0x12}; + uint8_t aad_18[] = {0x84, 0xC5, 0xD5, 0x13, 0xD2, 0xAA, 0xF6, + 0xE5, 0xBB, 0xD2, 0x72, 0x77, 0x88, 0xE5, + 0x2F, 0x0, 0x89, 0x32, 0xD6, 0x12, 0x7C, + 0xFD, 0xE9, 0xF9, 0xE3, 0x37, 0x24, 0xC6}; + uint8_t key_18[] = {0x1, 0x3F, 0xE0, 0xB, 0x5F, 0x11, 0xBE, 0x7F, 0x86, + 0x6D, 0xC, 0xBB, 0xC5, 0x5A, 0x7A, 0x90, 0x0, 0x3E, + 0xE1, 0xA, 0x5E, 0x10, 0xBF, 0x7E, 0x87, 0x6C, 0xD, + 0xBA, 0xC4, 0x5B, 0x7B, 0x91, 0x3, 0x3D, 0xE2, 0x9, + 0x5D, 0x13, 0xBC, 0x7D, 0x84, 0x6F, 0xE, 0xB9}; + uint8_t plaintext_18[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, + 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x0, 0x6}; + uint8_t ciphertext_18[] = { + 0xFF, 0x19, 0x10, 0xD3, 0x5A, 0xD7, 0xE5, 0x65, 0x78, 0x90, 0xC7, + 0xC5, 0x60, 0x14, 0x6F, 0xD0, 0x38, 0x70, 0x7F, 0x20, 0x4B, 0x66, + 0xED, 0xBC, 0x3D, 0x16, 0x1F, 0x8A, 0xCE, 0x24, 0x4B, 0x98, 0x59, + 0x21, 0x2, 0x3C, 0x43, 0x6E, 0x3A, 0x1C, 0x35, 0x32, 0xEC, 0xD5, + 0xD0, 0x9A, 0x5, 0x6D, 0x70, 0xBE, 0x58, 0x3F, 0xD, 0x10, 0x82, + 0x9D, 0x93, 0x87, 0xD0, 0x7D, 0x33, 0xD8, 0x72, 0xE4, 0x90}; + vec = {nonce_18, aad_18, key_18, plaintext_18, ciphertext_18, + 12, 28, 44, 49, 65}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.6.2 61-byte crypt + uint8_t nonce_19[] = {0x7C, 0xFD, 0xE9, 0xF9, 0xE3, 0x37, + 0x24, 0xC6, 0x89, 0x32, 0xD6, 0x12}; + uint8_t aad_19[] = {0x84, 0xC5, 0xD5, 0x13, 0xD2, 0xAA, 0xF6, + 0xE5, 0xBB, 0xD2, 0x72, 0x77, 0x88, 0xE5, + 0x2F, 0x0, 0x89, 0x32, 0xD6, 0x12, 0x7C, + 0xFD, 0xE9, 0xF9, 0xE3, 0x37, 0x24, 0xC6}; + uint8_t key_19[] = {0x83, 0xC0, 0x93, 0xB5, 0x8D, 0xE7, 0xFF, 0xE1, 0xC0, + 0xDA, 0x92, 0x6A, 0xC4, 0x3F, 0xB3, 0x60, 0x9A, 0xC1, + 0xC8, 0xF, 0xEE, 0x1B, 0x62, 0x44, 0x97, 0xEF, 0x94, + 0x2E, 0x2F, 0x79, 0xA8, 0x23, 0x81, 0xC2, 0x91, 0xB7, + 0x8F, 0xE5, 0xFD, 0xE3, 0xC2, 0xD8, 0x90, 0x68}; + uint8_t plaintext_19[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, + 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x0, 0x6}; + uint8_t ciphertext_19[] = { + 0xD, 0xB4, 0xCF, 0x95, 0x6B, 0x5F, 0x97, 0xEC, 0xA4, 0xEA, 0xB8, + 0x2A, 0x69, 0x55, 0x30, 0x7F, 0x9A, 0xE0, 0x2A, 0x32, 0xDD, 0x7D, + 0x93, 0xF8, 0x3D, 0x66, 0xAD, 0x4, 0xE1, 0xCF, 0xDC, 0x51, 0x82, + 0xAD, 0x12, 0xAB, 0xDE, 0xA5, 0xBB, 0xB6, 0x19, 0xA1, 0xBD, 0x5F, + 0xB9, 0xA5, 0x73, 0x59, 0xF, 0xBA, 0x90, 0x8E, 0x9C, 0x7A, 0x46, + 0xC1, 0xF7, 0xBA, 0x9, 0x5, 0xD1, 0xB5, 0x5F, 0xFD, 0xA4}; + vec = {nonce_19, aad_19, key_19, plaintext_19, ciphertext_19, + 12, 28, 44, 49, 65}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.7.1 79-byte crypt + uint8_t nonce_20[] = {0x7A, 0xE8, 0xE2, 0xCA, 0x4E, 0xC5, + 0x0, 0x1, 0x2E, 0x58, 0x49, 0x5C}; + uint8_t aad_20[] = { + 0x68, 0xF2, 0xE7, 0x76, 0x96, 0xCE, 0x7A, 0xE8, 0xE2, 0xCA, 0x4E, + 0xC5, 0x88, 0xE5, 0x41, 0x0, 0x2E, 0x58, 0x49, 0x5C, 0x8, 0x0, + 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 
0x37, 0x38, 0x39, 0x3A, + 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, + 0x46, 0x47, 0x48, 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x0, 0x7}; + uint8_t key_20[] = {0x88, 0xEE, 0x8, 0x7F, 0xD9, 0x5D, 0xA9, 0xFB, 0xF6, + 0x72, 0x5A, 0xA9, 0xD7, 0x57, 0xB0, 0xCD, 0x89, 0xEF, + 0x9, 0x7E, 0xD8, 0x5C, 0xA8, 0xFA, 0xF7, 0x73, 0x5B, + 0xA8, 0xD6, 0x56, 0xB1, 0xCC, 0x8A, 0xEC, 0xA, 0x7D, + 0xDB, 0x5F, 0xAB, 0xF9, 0xF4, 0x70, 0x58, 0xAB}; + uint8_t plaintext_20[1] = {}; + uint8_t ciphertext_20[] = {0x81, 0x3F, 0xE, 0x63, 0xF, 0x96, 0xFB, 0x2D, + 0x3, 0xF, 0x58, 0xD8, 0x3F, 0x5C, 0xDF, 0xD0}; + vec = {nonce_20, aad_20, key_20, plaintext_20, ciphertext_20, 12, 87, + 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.7.2 79-byte crypt + uint8_t nonce_21[] = {0x7A, 0xE8, 0xE2, 0xCA, 0x4E, 0xC5, + 0x0, 0x1, 0x2E, 0x58, 0x49, 0x5C}; + uint8_t aad_21[] = { + 0x68, 0xF2, 0xE7, 0x76, 0x96, 0xCE, 0x7A, 0xE8, 0xE2, 0xCA, 0x4E, + 0xC5, 0x88, 0xE5, 0x41, 0x0, 0x2E, 0x58, 0x49, 0x5C, 0x8, 0x0, + 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, + 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, + 0x46, 0x47, 0x48, 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x0, 0x7}; + uint8_t key_21[] = {0x4C, 0x97, 0x3D, 0xBC, 0x73, 0x64, 0x62, 0x16, 0x74, + 0xF8, 0xB5, 0xB8, 0x9E, 0x5C, 0x15, 0x51, 0x1F, 0xCE, + 0xD9, 0x21, 0x64, 0x90, 0xFB, 0x1C, 0x1A, 0x2C, 0xAA, + 0xF, 0xFE, 0x4, 0x7, 0xE5, 0x4E, 0x95, 0x3F, 0xBE, + 0x71, 0x66, 0x60, 0x14, 0x76, 0xFA, 0xB7, 0xBA}; + uint8_t plaintext_21[1] = {}; + uint8_t ciphertext_21[] = {0x77, 0xE5, 0xA4, 0x4C, 0x21, 0xEB, 0x7, 0x18, + 0x8A, 0xAC, 0xBD, 0x74, 0xD1, 0x98, 0xE, 0x97}; + vec = {nonce_21, aad_21, key_21, plaintext_21, ciphertext_21, 12, 87, + 44, 0, 16}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.8.1 61-byte crypt + uint8_t nonce_22[] = {0x7A, 0xE8, 0xE2, 0xCA, 0x4E, 0xC5, + 0x0, 0x1, 0x2E, 0x58, 0x49, 0x5C}; + uint8_t aad_22[] = {0x68, 0xF2, 0xE7, 0x76, 0x96, 0xCE, 0x7A, + 0xE8, 0xE2, 0xCA, 0x4E, 0xC5, 0x88, 0xE5, + 0x4D, 0x0, 0x2E, 0x58, 0x49, 0x5C}; + uint8_t key_22[] = {0x88, 0xEE, 0x8, 0x7F, 0xD9, 0x5D, 0xA9, 0xFB, 0xF6, + 0x72, 0x5A, 0xA9, 0xD7, 0x57, 0xB0, 0xCD, 0x89, 0xEF, + 0x9, 0x7E, 0xD8, 0x5C, 0xA8, 0xFA, 0xF7, 0x73, 0x5B, + 0xA8, 0xD6, 0x56, 0xB1, 0xCC, 0x8A, 0xEC, 0xA, 0x7D, + 0xDB, 0x5F, 0xAB, 0xF9, 0xF4, 0x70, 0x58, 0xAB}; + uint8_t plaintext_22[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, + 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, 0x40, 0x41, 0x42, 0x43, + 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x0, 0x8}; + uint8_t ciphertext_22[] = { + 0x95, 0x8E, 0xC3, 0xF6, 0xD6, 0xA, 0xFE, 0xDA, 0x99, 0xEF, 0xD8, 0x88, + 0xF1, 0x75, 0xE5, 0xFC, 0xD4, 0xC8, 0x7B, 0x9B, 0xCC, 0x5C, 0x2F, 0x54, + 0x26, 0x25, 0x3A, 0x8B, 0x50, 0x62, 0x96, 0xC8, 0xC4, 0x33, 0x9, 0xAB, + 0x2A, 0xDB, 0x59, 0x39, 0x46, 0x25, 0x41, 0xD9, 0x5E, 0x80, 0x81, 0x1E, + 0x4, 0xE7, 0x6, 0xB1, 0x49, 0x8F, 0x2C, 0x40, 0x7C, 0x7F, 0xB2, 0x34, + 0xF8, 0xCC, 0x1, 0xA6, 0x47, 0x55, 0xE, 0xE6, 0xB5, 0x57, 0xB3, 0x5A, + 0x7E, 0x39, 0x45, 0x38, 0x18, 0x21, 
0xF4}; + vec = {nonce_22, aad_22, key_22, plaintext_22, ciphertext_22, + 12, 20, 44, 63, 79}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); + + // Derived from IEEE 2.8.2 61-byte crypt + uint8_t nonce_23[] = {0x7A, 0xE8, 0xE2, 0xCA, 0x4E, 0xC5, + 0x0, 0x1, 0x2E, 0x58, 0x49, 0x5C}; + uint8_t aad_23[] = {0x68, 0xF2, 0xE7, 0x76, 0x96, 0xCE, 0x7A, + 0xE8, 0xE2, 0xCA, 0x4E, 0xC5, 0x88, 0xE5, + 0x4D, 0x0, 0x2E, 0x58, 0x49, 0x5C}; + uint8_t key_23[] = {0x4C, 0x97, 0x3D, 0xBC, 0x73, 0x64, 0x62, 0x16, 0x74, + 0xF8, 0xB5, 0xB8, 0x9E, 0x5C, 0x15, 0x51, 0x1F, 0xCE, + 0xD9, 0x21, 0x64, 0x90, 0xFB, 0x1C, 0x1A, 0x2C, 0xAA, + 0xF, 0xFE, 0x4, 0x7, 0xE5, 0x4E, 0x95, 0x3F, 0xBE, + 0x71, 0x66, 0x60, 0x14, 0x76, 0xFA, 0xB7, 0xBA}; + uint8_t plaintext_23[] = { + 0x8, 0x0, 0xF, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, + 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, 0x40, 0x41, 0x42, 0x43, + 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x0, 0x8}; + uint8_t ciphertext_23[] = { + 0xB4, 0x4D, 0x7, 0x20, 0x11, 0xCD, 0x36, 0xD2, 0x72, 0xA9, 0xB7, 0xA9, + 0x8D, 0xB9, 0xAA, 0x90, 0xCB, 0xC5, 0xC6, 0x7B, 0x93, 0xDD, 0xCE, 0x67, + 0xC8, 0x54, 0x50, 0x32, 0x14, 0xE2, 0xE8, 0x96, 0xEC, 0x7E, 0x9D, 0xB6, + 0x49, 0xED, 0x4B, 0xCF, 0x6F, 0x85, 0xA, 0xAC, 0x2, 0x23, 0xD0, 0xCF, + 0x92, 0xC8, 0x3D, 0xB8, 0x7, 0x95, 0xC3, 0xA1, 0x7E, 0xCC, 0x12, 0x48, + 0xBB, 0x0, 0x59, 0x17, 0x12, 0xB1, 0xAE, 0x71, 0xE2, 0x68, 0x16, 0x41, + 0x96, 0x25, 0x21, 0x62, 0x81, 0xB, 0x0}; + vec = {nonce_23, aad_23, key_23, plaintext_23, ciphertext_23, + 12, 20, 44, 63, 79}; + gsec_test_verify_crypter_on_test_vector(&vec, /*rekey=*/true); +} + +static void gsec_test_do_vector_tests_nist() { + /** + * From: + * http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/ + * gcm-revised-spec.pdf + */ + + /* Test vector 1 */ + gsec_aead_test_vector* test_vector_1; + const uint8_t test_vector_1_key[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}; + const uint8_t test_vector_1_nonce[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + const uint8_t test_vector_1_aad[1] = {}; + const uint8_t test_vector_1_plaintext[1] = {}; + const uint8_t test_vector_1_ciphertext_and_tag[] = { + 0x58, 0xe2, 0xfc, 0xce, 0xfa, 0x7e, 0x30, 0x61, + 0x36, 0x7f, 0x1d, 0x57, 0xa4, 0xe7, 0x45, 0x5a}; + gsec_aead_malloc_test_vector( + &test_vector_1, test_vector_1_key, + sizeof(test_vector_1_key) / sizeof(uint8_t), test_vector_1_nonce, + sizeof(test_vector_1_nonce) / sizeof(uint8_t), test_vector_1_aad, 0, + test_vector_1_plaintext, 0, test_vector_1_ciphertext_and_tag, + sizeof(test_vector_1_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_1); + gsec_aead_free_test_vector(test_vector_1); + + /* Test vector 2 */ + gsec_aead_test_vector* test_vector_2; + const uint8_t test_vector_2_key[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}; + const uint8_t test_vector_2_nonce[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + const uint8_t test_vector_2_aad[1] = {}; + const uint8_t test_vector_2_plaintext[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}; + const uint8_t 
test_vector_2_ciphertext_and_tag[] = { + 0x03, 0x88, 0xda, 0xce, 0x60, 0xb6, 0xa3, 0x92, 0xf3, 0x28, 0xc2, + 0xb9, 0x71, 0xb2, 0xfe, 0x78, 0xab, 0x6e, 0x47, 0xd4, 0x2c, 0xec, + 0x13, 0xbd, 0xf5, 0x3a, 0x67, 0xb2, 0x12, 0x57, 0xbd, 0xdf}; + gsec_aead_malloc_test_vector( + &test_vector_2, test_vector_2_key, + sizeof(test_vector_2_key) / sizeof(uint8_t), test_vector_2_nonce, + sizeof(test_vector_2_nonce) / sizeof(uint8_t), test_vector_2_aad, 0, + test_vector_2_plaintext, + sizeof(test_vector_2_plaintext) / sizeof(uint8_t), + test_vector_2_ciphertext_and_tag, + sizeof(test_vector_2_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_2); + gsec_aead_free_test_vector(test_vector_2); + + /* Test vector 3 */ + gsec_aead_test_vector* test_vector_3; + const uint8_t test_vector_3_key[] = {0xfe, 0xff, 0xe9, 0x92, 0x86, 0x65, + 0x73, 0x1c, 0x6d, 0x6a, 0x8f, 0x94, + 0x67, 0x30, 0x83, 0x08}; + const uint8_t test_vector_3_nonce[] = {0xca, 0xfe, 0xba, 0xbe, 0xfa, 0xce, + 0xdb, 0xad, 0xde, 0xca, 0xf8, 0x88}; + const uint8_t test_vector_3_aad[1] = {}; + const uint8_t test_vector_3_plaintext[] = { + 0xd9, 0x31, 0x32, 0x25, 0xf8, 0x84, 0x06, 0xe5, 0xa5, 0x59, 0x09, + 0xc5, 0xaf, 0xf5, 0x26, 0x9a, 0x86, 0xa7, 0xa9, 0x53, 0x15, 0x34, + 0xf7, 0xda, 0x2e, 0x4c, 0x30, 0x3d, 0x8a, 0x31, 0x8a, 0x72, 0x1c, + 0x3c, 0x0c, 0x95, 0x95, 0x68, 0x09, 0x53, 0x2f, 0xcf, 0x0e, 0x24, + 0x49, 0xa6, 0xb5, 0x25, 0xb1, 0x6a, 0xed, 0xf5, 0xaa, 0x0d, 0xe6, + 0x57, 0xba, 0x63, 0x7b, 0x39, 0x1a, 0xaf, 0xd2, 0x55}; + const uint8_t test_vector_3_ciphertext_and_tag[] = { + 0x42, 0x83, 0x1e, 0xc2, 0x21, 0x77, 0x74, 0x24, 0x4b, 0x72, 0x21, 0xb7, + 0x84, 0xd0, 0xd4, 0x9c, 0xe3, 0xaa, 0x21, 0x2f, 0x2c, 0x02, 0xa4, 0xe0, + 0x35, 0xc1, 0x7e, 0x23, 0x29, 0xac, 0xa1, 0x2e, 0x21, 0xd5, 0x14, 0xb2, + 0x54, 0x66, 0x93, 0x1c, 0x7d, 0x8f, 0x6a, 0x5a, 0xac, 0x84, 0xaa, 0x05, + 0x1b, 0xa3, 0x0b, 0x39, 0x6a, 0x0a, 0xac, 0x97, 0x3d, 0x58, 0xe0, 0x91, + 0x47, 0x3f, 0x59, 0x85, 0x4d, 0x5c, 0x2a, 0xf3, 0x27, 0xcd, 0x64, 0xa6, + 0x2c, 0xf3, 0x5a, 0xbd, 0x2b, 0xa6, 0xfa, 0xb4}; + gsec_aead_malloc_test_vector( + &test_vector_3, test_vector_3_key, + sizeof(test_vector_3_key) / sizeof(uint8_t), test_vector_3_nonce, + sizeof(test_vector_3_nonce) / sizeof(uint8_t), test_vector_3_aad, 0, + test_vector_3_plaintext, + sizeof(test_vector_3_plaintext) / sizeof(uint8_t), + test_vector_3_ciphertext_and_tag, + sizeof(test_vector_3_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_3); + gsec_aead_free_test_vector(test_vector_3); + + /* Test vector 4 */ + gsec_aead_test_vector* test_vector_4; + const uint8_t test_vector_4_key[] = {0xfe, 0xff, 0xe9, 0x92, 0x86, 0x65, + 0x73, 0x1c, 0x6d, 0x6a, 0x8f, 0x94, + 0x67, 0x30, 0x83, 0x08}; + const uint8_t test_vector_4_nonce[] = {0xca, 0xfe, 0xba, 0xbe, 0xfa, 0xce, + 0xdb, 0xad, 0xde, 0xca, 0xf8, 0x88}; + const uint8_t test_vector_4_aad[] = {0xfe, 0xed, 0xfa, 0xce, 0xde, 0xad, 0xbe, + 0xef, 0xfe, 0xed, 0xfa, 0xce, 0xde, 0xad, + 0xbe, 0xef, 0xab, 0xad, 0xda, 0xd2}; + const uint8_t test_vector_4_plaintext[] = { + 0xd9, 0x31, 0x32, 0x25, 0xf8, 0x84, 0x06, 0xe5, 0xa5, 0x59, 0x09, 0xc5, + 0xaf, 0xf5, 0x26, 0x9a, 0x86, 0xa7, 0xa9, 0x53, 0x15, 0x34, 0xf7, 0xda, + 0x2e, 0x4c, 0x30, 0x3d, 0x8a, 0x31, 0x8a, 0x72, 0x1c, 0x3c, 0x0c, 0x95, + 0x95, 0x68, 0x09, 0x53, 0x2f, 0xcf, 0x0e, 0x24, 0x49, 0xa6, 0xb5, 0x25, + 0xb1, 0x6a, 0xed, 0xf5, 0xaa, 0x0d, 0xe6, 0x57, 0xba, 0x63, 0x7b, 0x39}; + const uint8_t test_vector_4_ciphertext_and_tag[] = { + 0x42, 0x83, 0x1e, 
0xc2, 0x21, 0x77, 0x74, 0x24, 0x4b, 0x72, 0x21, + 0xb7, 0x84, 0xd0, 0xd4, 0x9c, 0xe3, 0xaa, 0x21, 0x2f, 0x2c, 0x02, + 0xa4, 0xe0, 0x35, 0xc1, 0x7e, 0x23, 0x29, 0xac, 0xa1, 0x2e, 0x21, + 0xd5, 0x14, 0xb2, 0x54, 0x66, 0x93, 0x1c, 0x7d, 0x8f, 0x6a, 0x5a, + 0xac, 0x84, 0xaa, 0x05, 0x1b, 0xa3, 0x0b, 0x39, 0x6a, 0x0a, 0xac, + 0x97, 0x3d, 0x58, 0xe0, 0x91, 0x5b, 0xc9, 0x4f, 0xbc, 0x32, 0x21, + 0xa5, 0xdb, 0x94, 0xfa, 0xe9, 0x5a, 0xe7, 0x12, 0x1a, 0x47}; + gsec_aead_malloc_test_vector( + &test_vector_4, test_vector_4_key, + sizeof(test_vector_4_key) / sizeof(uint8_t), test_vector_4_nonce, + sizeof(test_vector_4_nonce) / sizeof(uint8_t), test_vector_4_aad, + sizeof(test_vector_4_aad) / sizeof(uint8_t), test_vector_4_plaintext, + sizeof(test_vector_4_plaintext) / sizeof(uint8_t), + test_vector_4_ciphertext_and_tag, + sizeof(test_vector_4_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_4); + gsec_aead_free_test_vector(test_vector_4); +} + +static void gsec_test_do_vector_tests_ieee() { + /** + * From: + * http://www.ieee802.org/1/files/public/docs2011/ + * bn-randall-test-vectors-0511-v1.pdf + */ + + /* 2.1.1 54-byte auth */ + gsec_aead_test_vector* test_vector_5; + const uint8_t test_vector_5_key[] = {0xad, 0x7a, 0x2b, 0xd0, 0x3e, 0xac, + 0x83, 0x5a, 0x6f, 0x62, 0x0f, 0xdc, + 0xb5, 0x06, 0xb3, 0x45}; + const uint8_t test_vector_5_nonce[] = {0x12, 0x15, 0x35, 0x24, 0xc0, 0x89, + 0x5e, 0x81, 0xb2, 0xc2, 0x84, 0x65}; + const uint8_t test_vector_5_aad[] = { + 0xd6, 0x09, 0xb1, 0xf0, 0x56, 0x63, 0x7a, 0x0d, 0x46, 0xdf, 0x99, 0x8d, + 0x88, 0xe5, 0x22, 0x2a, 0xb2, 0xc2, 0x84, 0x65, 0x12, 0x15, 0x35, 0x24, + 0xc0, 0x89, 0x5e, 0x81, 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, + 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x00, 0x01}; + const uint8_t test_vector_5_plaintext[1] = {}; + const uint8_t test_vector_5_ciphertext_and_tag[] = { + 0xf0, 0x94, 0x78, 0xa9, 0xb0, 0x90, 0x07, 0xd0, + 0x6f, 0x46, 0xe9, 0xb6, 0xa1, 0xda, 0x25, 0xdd}; + gsec_aead_malloc_test_vector( + &test_vector_5, test_vector_5_key, + sizeof(test_vector_5_key) / sizeof(uint8_t), test_vector_5_nonce, + sizeof(test_vector_5_nonce) / sizeof(uint8_t), test_vector_5_aad, + sizeof(test_vector_5_aad) / sizeof(uint8_t), test_vector_5_plaintext, 0, + test_vector_5_ciphertext_and_tag, + sizeof(test_vector_5_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_5); + gsec_aead_free_test_vector(test_vector_5); + + /* 2.1.2 54-byte auth */ + gsec_aead_test_vector* test_vector_6; + const uint8_t test_vector_6_key[] = { + 0xe3, 0xc0, 0x8a, 0x8f, 0x06, 0xc6, 0xe3, 0xad, 0x95, 0xa7, 0x05, + 0x57, 0xb2, 0x3f, 0x75, 0x48, 0x3c, 0xe3, 0x30, 0x21, 0xa9, 0xc7, + 0x2b, 0x70, 0x25, 0x66, 0x62, 0x04, 0xc6, 0x9c, 0x0b, 0x72}; + + const uint8_t test_vector_6_nonce[] = {0x12, 0x15, 0x35, 0x24, 0xc0, 0x89, + 0x5e, 0x81, 0xb2, 0xc2, 0x84, 0x65}; + const uint8_t test_vector_6_aad[] = { + 0xd6, 0x09, 0xb1, 0xf0, 0x56, 0x63, 0x7a, 0x0d, 0x46, 0xdf, 0x99, 0x8d, + 0x88, 0xe5, 0x22, 0x2a, 0xb2, 0xc2, 0x84, 0x65, 0x12, 0x15, 0x35, 0x24, + 0xc0, 0x89, 0x5e, 0x81, 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, + 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x00, 0x01}; + const uint8_t 
test_vector_6_plaintext[1] = {}; + const uint8_t test_vector_6_ciphertext_and_tag[] = { + 0x2f, 0x0b, 0xc5, 0xaf, 0x40, 0x9e, 0x06, 0xd6, + 0x09, 0xea, 0x8b, 0x7d, 0x0f, 0xa5, 0xea, 0x50}; + gsec_aead_malloc_test_vector( + &test_vector_6, test_vector_6_key, + sizeof(test_vector_6_key) / sizeof(uint8_t), test_vector_6_nonce, + sizeof(test_vector_6_nonce) / sizeof(uint8_t), test_vector_6_aad, + sizeof(test_vector_6_aad) / sizeof(uint8_t), test_vector_6_plaintext, 0, + test_vector_6_ciphertext_and_tag, + sizeof(test_vector_6_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_6); + gsec_aead_free_test_vector(test_vector_6); + + /* 2.2.1 60-byte crypt */ + gsec_aead_test_vector* test_vector_7; + const uint8_t test_vector_7_key[] = {0xad, 0x7a, 0x2b, 0xd0, 0x3e, 0xac, + 0x83, 0x5a, 0x6f, 0x62, 0x0f, 0xdc, + 0xb5, 0x06, 0xb3, 0x45}; + + const uint8_t test_vector_7_nonce[] = {0x12, 0x15, 0x35, 0x24, 0xc0, 0x89, + 0x5e, 0x81, 0xb2, 0xc2, 0x84, 0x65}; + const uint8_t test_vector_7_aad[] = { + 0xd6, 0x09, 0xb1, 0xf0, 0x56, 0x63, 0x7a, 0x0d, 0x46, 0xdf, + 0x99, 0x8d, 0x88, 0xe5, 0x2e, 0x00, 0xb2, 0xc2, 0x84, 0x65, + 0x12, 0x15, 0x35, 0x24, 0xc0, 0x89, 0x5e, 0x81}; + const uint8_t test_vector_7_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x00, 0x02}; + const uint8_t test_vector_7_ciphertext_and_tag[] = { + 0x70, 0x1a, 0xfa, 0x1c, 0xc0, 0x39, 0xc0, 0xd7, 0x65, 0x12, 0x8a, + 0x66, 0x5d, 0xab, 0x69, 0x24, 0x38, 0x99, 0xbf, 0x73, 0x18, 0xcc, + 0xdc, 0x81, 0xc9, 0x93, 0x1d, 0xa1, 0x7f, 0xbe, 0x8e, 0xdd, 0x7d, + 0x17, 0xcb, 0x8b, 0x4c, 0x26, 0xfc, 0x81, 0xe3, 0x28, 0x4f, 0x2b, + 0x7f, 0xba, 0x71, 0x3d, 0x4f, 0x8d, 0x55, 0xe7, 0xd3, 0xf0, 0x6f, + 0xd5, 0xa1, 0x3c, 0x0c, 0x29, 0xb9, 0xd5, 0xb8, 0x80}; + gsec_aead_malloc_test_vector( + &test_vector_7, test_vector_7_key, + sizeof(test_vector_7_key) / sizeof(uint8_t), test_vector_7_nonce, + sizeof(test_vector_7_nonce) / sizeof(uint8_t), test_vector_7_aad, + sizeof(test_vector_7_aad) / sizeof(uint8_t), test_vector_7_plaintext, + sizeof(test_vector_7_plaintext) / sizeof(uint8_t), + test_vector_7_ciphertext_and_tag, + sizeof(test_vector_7_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_7); + gsec_aead_free_test_vector(test_vector_7); + + /* 2.2.2 60-byte crypt */ + gsec_aead_test_vector* test_vector_8; + const uint8_t test_vector_8_key[] = { + 0xe3, 0xc0, 0x8a, 0x8f, 0x06, 0xc6, 0xe3, 0xad, 0x95, 0xa7, 0x05, + 0x57, 0xb2, 0x3f, 0x75, 0x48, 0x3c, 0xe3, 0x30, 0x21, 0xa9, 0xc7, + 0x2b, 0x70, 0x25, 0x66, 0x62, 0x04, 0xc6, 0x9c, 0x0b, 0x72}; + const uint8_t test_vector_8_nonce[] = {0x12, 0x15, 0x35, 0x24, 0xc0, 0x89, + 0x5e, 0x81, 0xb2, 0xc2, 0x84, 0x65}; + const uint8_t test_vector_8_aad[] = { + 0xd6, 0x09, 0xb1, 0xf0, 0x56, 0x63, 0x7a, 0x0d, 0x46, 0xdf, + 0x99, 0x8d, 0x88, 0xe5, 0x2e, 0x00, 0xb2, 0xc2, 0x84, 0x65, + 0x12, 0x15, 0x35, 0x24, 0xc0, 0x89, 0x5e, 0x81}; + const uint8_t test_vector_8_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x00, 0x02}; + const uint8_t 
test_vector_8_ciphertext_and_tag[] = { + 0xe2, 0x00, 0x6e, 0xb4, 0x2f, 0x52, 0x77, 0x02, 0x2d, 0x9b, 0x19, + 0x92, 0x5b, 0xc4, 0x19, 0xd7, 0xa5, 0x92, 0x66, 0x6c, 0x92, 0x5f, + 0xe2, 0xef, 0x71, 0x8e, 0xb4, 0xe3, 0x08, 0xef, 0xea, 0xa7, 0xc5, + 0x27, 0x3b, 0x39, 0x41, 0x18, 0x86, 0x0a, 0x5b, 0xe2, 0xa9, 0x7f, + 0x56, 0xab, 0x78, 0x36, 0x5c, 0xa5, 0x97, 0xcd, 0xbb, 0x3e, 0xdb, + 0x8d, 0x1a, 0x11, 0x51, 0xea, 0x0a, 0xf7, 0xb4, 0x36}; + gsec_aead_malloc_test_vector( + &test_vector_8, test_vector_8_key, + sizeof(test_vector_8_key) / sizeof(uint8_t), test_vector_8_nonce, + sizeof(test_vector_8_nonce) / sizeof(uint8_t), test_vector_8_aad, + sizeof(test_vector_8_aad) / sizeof(uint8_t), test_vector_8_plaintext, + sizeof(test_vector_8_plaintext) / sizeof(uint8_t), + test_vector_8_ciphertext_and_tag, + sizeof(test_vector_8_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_8); + gsec_aead_free_test_vector(test_vector_8); + + /* 2.3.1 60-byte auth */ + gsec_aead_test_vector* test_vector_9; + const uint8_t test_vector_9_key[] = {0x07, 0x1b, 0x11, 0x3b, 0x0c, 0xa7, + 0x43, 0xfe, 0xcc, 0xcf, 0x3d, 0x05, + 0x1f, 0x73, 0x73, 0x82}; + const uint8_t test_vector_9_nonce[] = {0xf0, 0x76, 0x1e, 0x8d, 0xcd, 0x3d, + 0x00, 0x01, 0x76, 0xd4, 0x57, 0xed}; + const uint8_t test_vector_9_aad[] = { + 0xe2, 0x01, 0x06, 0xd7, 0xcd, 0x0d, 0xf0, 0x76, 0x1e, 0x8d, 0xcd, 0x3d, + 0x88, 0xe5, 0x40, 0x00, 0x76, 0xd4, 0x57, 0xed, 0x08, 0x00, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, + 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x00, 0x03}; + const uint8_t test_vector_9_plaintext[1] = {}; + const uint8_t test_vector_9_ciphertext_and_tag[] = { + 0x0c, 0x01, 0x7b, 0xc7, 0x3b, 0x22, 0x7d, 0xfc, + 0xc9, 0xba, 0xfa, 0x1c, 0x41, 0xac, 0xc3, 0x53}; + gsec_aead_malloc_test_vector( + &test_vector_9, test_vector_9_key, + sizeof(test_vector_9_key) / sizeof(uint8_t), test_vector_9_nonce, + sizeof(test_vector_9_nonce) / sizeof(uint8_t), test_vector_9_aad, + sizeof(test_vector_9_aad) / sizeof(uint8_t), test_vector_9_plaintext, 0, + test_vector_9_ciphertext_and_tag, + sizeof(test_vector_9_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_9); + gsec_aead_free_test_vector(test_vector_9); + + /* 2.3.2 60-byte auth */ + gsec_aead_test_vector* test_vector_10; + const uint8_t test_vector_10_key[] = { + 0x69, 0x1d, 0x3e, 0xe9, 0x09, 0xd7, 0xf5, 0x41, 0x67, 0xfd, 0x1c, + 0xa0, 0xb5, 0xd7, 0x69, 0x08, 0x1f, 0x2b, 0xde, 0x1a, 0xee, 0x65, + 0x5f, 0xdb, 0xab, 0x80, 0xbd, 0x52, 0x95, 0xae, 0x6b, 0xe7}; + const uint8_t test_vector_10_nonce[] = {0xf0, 0x76, 0x1e, 0x8d, 0xcd, 0x3d, + 0x00, 0x01, 0x76, 0xd4, 0x57, 0xed}; + const uint8_t test_vector_10_aad[] = { + 0xe2, 0x01, 0x06, 0xd7, 0xcd, 0x0d, 0xf0, 0x76, 0x1e, 0x8d, 0xcd, 0x3d, + 0x88, 0xe5, 0x40, 0x00, 0x76, 0xd4, 0x57, 0xed, 0x08, 0x00, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, + 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x00, 0x03}; + const uint8_t test_vector_10_plaintext[1] = {}; + const uint8_t test_vector_10_ciphertext_and_tag[] = { + 0x35, 0x21, 0x7c, 0x77, 0x4b, 0xbc, 0x31, 0xb6, + 0x31, 0x66, 0xbc, 0xf9, 0xd4, 0xab, 0xed, 0x07}; + 
gsec_aead_malloc_test_vector( + &test_vector_10, test_vector_10_key, + sizeof(test_vector_10_key) / sizeof(uint8_t), test_vector_10_nonce, + sizeof(test_vector_10_nonce) / sizeof(uint8_t), test_vector_10_aad, + sizeof(test_vector_10_aad) / sizeof(uint8_t), test_vector_10_plaintext, 0, + test_vector_10_ciphertext_and_tag, + sizeof(test_vector_10_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_10); + gsec_aead_free_test_vector(test_vector_10); + + /* 2.4.1 54-byte crypt */ + gsec_aead_test_vector* test_vector_11; + const uint8_t test_vector_11_key[] = {0x07, 0x1b, 0x11, 0x3b, 0x0c, 0xa7, + 0x43, 0xfe, 0xcc, 0xcf, 0x3d, 0x05, + 0x1f, 0x73, 0x73, 0x82}; + const uint8_t test_vector_11_nonce[] = {0xf0, 0x76, 0x1e, 0x8d, 0xcd, 0x3d, + 0x00, 0x01, 0x76, 0xd4, 0x57, 0xed}; + const uint8_t test_vector_11_aad[] = { + 0xe2, 0x01, 0x06, 0xd7, 0xcd, 0x0d, 0xf0, 0x76, 0x1e, 0x8d, + 0xcd, 0x3d, 0x88, 0xe5, 0x4c, 0x2a, 0x76, 0xd4, 0x57, 0xed}; + const uint8_t test_vector_11_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, + 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x00, 0x04}; + const uint8_t test_vector_11_ciphertext_and_tag[] = { + 0x13, 0xb4, 0xc7, 0x2b, 0x38, 0x9d, 0xc5, 0x01, 0x8e, 0x72, 0xa1, 0x71, + 0xdd, 0x85, 0xa5, 0xd3, 0x75, 0x22, 0x74, 0xd3, 0xa0, 0x19, 0xfb, 0xca, + 0xed, 0x09, 0xa4, 0x25, 0xcd, 0x9b, 0x2e, 0x1c, 0x9b, 0x72, 0xee, 0xe7, + 0xc9, 0xde, 0x7d, 0x52, 0xb3, 0xf3, 0xd6, 0xa5, 0x28, 0x4f, 0x4a, 0x6d, + 0x3f, 0xe2, 0x2a, 0x5d, 0x6c, 0x2b, 0x96, 0x04, 0x94, 0xc3}; + gsec_aead_malloc_test_vector( + &test_vector_11, test_vector_11_key, + sizeof(test_vector_11_key) / sizeof(uint8_t), test_vector_11_nonce, + sizeof(test_vector_11_nonce) / sizeof(uint8_t), test_vector_11_aad, + sizeof(test_vector_11_aad) / sizeof(uint8_t), test_vector_11_plaintext, + sizeof(test_vector_11_plaintext) / sizeof(uint8_t), + test_vector_11_ciphertext_and_tag, + sizeof(test_vector_11_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_11); + gsec_aead_free_test_vector(test_vector_11); + + /* 2.4.2 54-byte crypt */ + gsec_aead_test_vector* test_vector_12; + const uint8_t test_vector_12_key[] = { + 0x69, 0x1d, 0x3e, 0xe9, 0x09, 0xd7, 0xf5, 0x41, 0x67, 0xfd, 0x1c, + 0xa0, 0xb5, 0xd7, 0x69, 0x08, 0x1f, 0x2b, 0xde, 0x1a, 0xee, 0x65, + 0x5f, 0xdb, 0xab, 0x80, 0xbd, 0x52, 0x95, 0xae, 0x6b, 0xe7}; + const uint8_t test_vector_12_nonce[] = {0xf0, 0x76, 0x1e, 0x8d, 0xcd, 0x3d, + 0x00, 0x01, 0x76, 0xd4, 0x57, 0xed}; + const uint8_t test_vector_12_aad[] = { + 0xe2, 0x01, 0x06, 0xd7, 0xcd, 0x0d, 0xf0, 0x76, 0x1e, 0x8d, + 0xcd, 0x3d, 0x88, 0xe5, 0x4c, 0x2a, 0x76, 0xd4, 0x57, 0xed}; + const uint8_t test_vector_12_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, + 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x00, 0x04}; + const uint8_t test_vector_12_ciphertext_and_tag[] = { + 0xc1, 0x62, 0x3f, 0x55, 0x73, 0x0c, 0x93, 0x53, 0x30, 0x97, 0xad, 0xda, + 0xd2, 0x56, 0x64, 0x96, 0x61, 0x25, 0x35, 0x2b, 0x43, 0xad, 0xac, 0xbd, + 0x61, 0xc5, 0xef, 0x3a, 0xc9, 0x0b, 0x5b, 0xee, 0x92, 0x9c, 0xe4, 0x63, + 0x0e, 0xa7, 0x9f, 0x6c, 0xe5, 0x19, 0x12, 0xaf, 0x39, 0xc2, 0xd1, 0xfd, + 0xc2, 0x05, 0x1f, 0x8b, 0x7b, 0x3c, 0x9d, 
0x39, 0x7e, 0xf2}; + gsec_aead_malloc_test_vector( + &test_vector_12, test_vector_12_key, + sizeof(test_vector_12_key) / sizeof(uint8_t), test_vector_12_nonce, + sizeof(test_vector_12_nonce) / sizeof(uint8_t), test_vector_12_aad, + sizeof(test_vector_12_aad) / sizeof(uint8_t), test_vector_12_plaintext, + sizeof(test_vector_12_plaintext) / sizeof(uint8_t), + test_vector_12_ciphertext_and_tag, + sizeof(test_vector_12_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_12); + gsec_aead_free_test_vector(test_vector_12); + + /* 2.5.1 65-byte auth */ + gsec_aead_test_vector* test_vector_13; + const uint8_t test_vector_13_key[] = {0x01, 0x3f, 0xe0, 0x0b, 0x5f, 0x11, + 0xbe, 0x7f, 0x86, 0x6d, 0x0c, 0xbb, + 0xc5, 0x5a, 0x7a, 0x90}; + const uint8_t test_vector_13_nonce[] = {0x7c, 0xfd, 0xe9, 0xf9, 0xe3, 0x37, + 0x24, 0xc6, 0x89, 0x32, 0xd6, 0x12}; + const uint8_t test_vector_13_aad[] = { + 0x84, 0xc5, 0xd5, 0x13, 0xd2, 0xaa, 0xf6, 0xe5, 0xbb, 0xd2, 0x72, 0x77, + 0x88, 0xe5, 0x23, 0x00, 0x89, 0x32, 0xd6, 0x12, 0x7c, 0xfd, 0xe9, 0xf9, + 0xe3, 0x37, 0x24, 0xc6, 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, + 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x00, 0x05}; + const uint8_t test_vector_13_plaintext[1] = {}; + const uint8_t test_vector_13_ciphertext_and_tag[] = { + 0x21, 0x78, 0x67, 0xe5, 0x0c, 0x2d, 0xad, 0x74, + 0xc2, 0x8c, 0x3b, 0x50, 0xab, 0xdf, 0x69, 0x5a}; + gsec_aead_malloc_test_vector( + &test_vector_13, test_vector_13_key, + sizeof(test_vector_13_key) / sizeof(uint8_t), test_vector_13_nonce, + sizeof(test_vector_13_nonce) / sizeof(uint8_t), test_vector_13_aad, + sizeof(test_vector_13_aad) / sizeof(uint8_t), test_vector_13_plaintext, 0, + test_vector_13_ciphertext_and_tag, + sizeof(test_vector_13_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_13); + gsec_aead_free_test_vector(test_vector_13); + + /* 2.5.2 65-byte auth */ + gsec_aead_test_vector* test_vector_14; + const uint8_t test_vector_14_key[] = { + 0x83, 0xc0, 0x93, 0xb5, 0x8d, 0xe7, 0xff, 0xe1, 0xc0, 0xda, 0x92, + 0x6a, 0xc4, 0x3f, 0xb3, 0x60, 0x9a, 0xc1, 0xc8, 0x0f, 0xee, 0x1b, + 0x62, 0x44, 0x97, 0xef, 0x94, 0x2e, 0x2f, 0x79, 0xa8, 0x23}; + const uint8_t test_vector_14_nonce[] = {0x7c, 0xfd, 0xe9, 0xf9, 0xe3, 0x37, + 0x24, 0xc6, 0x89, 0x32, 0xd6, 0x12}; + const uint8_t test_vector_14_aad[] = { + 0x84, 0xc5, 0xd5, 0x13, 0xd2, 0xaa, 0xf6, 0xe5, 0xbb, 0xd2, 0x72, 0x77, + 0x88, 0xe5, 0x23, 0x00, 0x89, 0x32, 0xd6, 0x12, 0x7c, 0xfd, 0xe9, 0xf9, + 0xe3, 0x37, 0x24, 0xc6, 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, + 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x00, 0x05}; + const uint8_t test_vector_14_plaintext[1] = {}; + const uint8_t test_vector_14_ciphertext_and_tag[] = { + 0x6e, 0xe1, 0x60, 0xe8, 0xfa, 0xec, 0xa4, 0xb3, + 0x6c, 0x86, 0xb2, 0x34, 0x92, 0x0c, 0xa9, 0x75}; + gsec_aead_malloc_test_vector( + &test_vector_14, test_vector_14_key, + sizeof(test_vector_14_key) / sizeof(uint8_t), test_vector_14_nonce, + sizeof(test_vector_14_nonce) / sizeof(uint8_t), test_vector_14_aad, + sizeof(test_vector_14_aad) / 
sizeof(uint8_t), test_vector_14_plaintext, 0, + test_vector_14_ciphertext_and_tag, + sizeof(test_vector_14_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_14); + gsec_aead_free_test_vector(test_vector_14); + + /* 2.6.1 61-byte crypt */ + gsec_aead_test_vector* test_vector_15; + const uint8_t test_vector_15_key[] = {0x01, 0x3f, 0xe0, 0x0b, 0x5f, 0x11, + 0xbe, 0x7f, 0x86, 0x6d, 0x0c, 0xbb, + 0xc5, 0x5a, 0x7a, 0x90}; + const uint8_t test_vector_15_nonce[] = {0x7c, 0xfd, 0xe9, 0xf9, 0xe3, 0x37, + 0x24, 0xc6, 0x89, 0x32, 0xd6, 0x12}; + const uint8_t test_vector_15_aad[] = { + 0x84, 0xc5, 0xd5, 0x13, 0xd2, 0xaa, 0xf6, 0xe5, 0xbb, 0xd2, + 0x72, 0x77, 0x88, 0xe5, 0x2f, 0x00, 0x89, 0x32, 0xd6, 0x12, + 0x7c, 0xfd, 0xe9, 0xf9, 0xe3, 0x37, 0x24, 0xc6}; + const uint8_t test_vector_15_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, + 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x00, 0x06}; + const uint8_t test_vector_15_ciphertext_and_tag[] = { + 0x3a, 0x4d, 0xe6, 0xfa, 0x32, 0x19, 0x10, 0x14, 0xdb, 0xb3, 0x03, + 0xd9, 0x2e, 0xe3, 0xa9, 0xe8, 0xa1, 0xb5, 0x99, 0xc1, 0x4d, 0x22, + 0xfb, 0x08, 0x00, 0x96, 0xe1, 0x38, 0x11, 0x81, 0x6a, 0x3c, 0x9c, + 0x9b, 0xcf, 0x7c, 0x1b, 0x9b, 0x96, 0xda, 0x80, 0x92, 0x04, 0xe2, + 0x9d, 0x0e, 0x2a, 0x76, 0x42, 0xbf, 0xd3, 0x10, 0xa4, 0x83, 0x7c, + 0x81, 0x6c, 0xcf, 0xa5, 0xac, 0x23, 0xab, 0x00, 0x39, 0x88}; + gsec_aead_malloc_test_vector( + &test_vector_15, test_vector_15_key, + sizeof(test_vector_15_key) / sizeof(uint8_t), test_vector_15_nonce, + sizeof(test_vector_15_nonce) / sizeof(uint8_t), test_vector_15_aad, + sizeof(test_vector_15_aad) / sizeof(uint8_t), test_vector_15_plaintext, + sizeof(test_vector_15_plaintext) / sizeof(uint8_t), + test_vector_15_ciphertext_and_tag, + sizeof(test_vector_15_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_15); + gsec_aead_free_test_vector(test_vector_15); + + /* 2.6.2 61-byte crypt */ + gsec_aead_test_vector* test_vector_16; + const uint8_t test_vector_16_key[] = { + 0x83, 0xc0, 0x93, 0xb5, 0x8d, 0xe7, 0xff, 0xe1, 0xc0, 0xda, 0x92, + 0x6a, 0xc4, 0x3f, 0xb3, 0x60, 0x9a, 0xc1, 0xc8, 0x0f, 0xee, 0x1b, + 0x62, 0x44, 0x97, 0xef, 0x94, 0x2e, 0x2f, 0x79, 0xa8, 0x23}; + const uint8_t test_vector_16_nonce[] = {0x7c, 0xfd, 0xe9, 0xf9, 0xe3, 0x37, + 0x24, 0xc6, 0x89, 0x32, 0xd6, 0x12}; + const uint8_t test_vector_16_aad[] = { + 0x84, 0xc5, 0xd5, 0x13, 0xd2, 0xaa, 0xf6, 0xe5, 0xbb, 0xd2, + 0x72, 0x77, 0x88, 0xe5, 0x2f, 0x00, 0x89, 0x32, 0xd6, 0x12, + 0x7c, 0xfd, 0xe9, 0xf9, 0xe3, 0x37, 0x24, 0xc6}; + const uint8_t test_vector_16_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, + 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x00, 0x06}; + const uint8_t test_vector_16_ciphertext_and_tag[] = { + 0x11, 0x02, 0x22, 0xff, 0x80, 0x50, 0xcb, 0xec, 0xe6, 0x6a, 0x81, + 0x3a, 0xd0, 0x9a, 0x73, 0xed, 0x7a, 0x9a, 0x08, 0x9c, 0x10, 0x6b, + 0x95, 0x93, 0x89, 0x16, 0x8e, 0xd6, 0xe8, 0x69, 0x8e, 0xa9, 0x02, + 0xeb, 0x12, 0x77, 0xdb, 0xec, 0x2e, 0x68, 0xe4, 0x73, 0x15, 0x5a, + 0x15, 0xa7, 0xda, 0xee, 0xd4, 0xa1, 0x0f, 0x4e, 0x05, 0x13, 0x9c, + 0x23, 
0xdf, 0x00, 0xb3, 0xaa, 0xdc, 0x71, 0xf0, 0x59, 0x6a}; + gsec_aead_malloc_test_vector( + &test_vector_16, test_vector_16_key, + sizeof(test_vector_16_key) / sizeof(uint8_t), test_vector_16_nonce, + sizeof(test_vector_16_nonce) / sizeof(uint8_t), test_vector_16_aad, + sizeof(test_vector_16_aad) / sizeof(uint8_t), test_vector_16_plaintext, + sizeof(test_vector_16_plaintext) / sizeof(uint8_t), + test_vector_16_ciphertext_and_tag, + sizeof(test_vector_16_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_16); + gsec_aead_free_test_vector(test_vector_16); + + /* 2.7.1 79-byte crypt */ + gsec_aead_test_vector* test_vector_17; + const uint8_t test_vector_17_key[] = {0x88, 0xee, 0x08, 0x7f, 0xd9, 0x5d, + 0xa9, 0xfb, 0xf6, 0x72, 0x5a, 0xa9, + 0xd7, 0x57, 0xb0, 0xcd}; + const uint8_t test_vector_17_nonce[] = {0x7a, 0xe8, 0xe2, 0xca, 0x4e, 0xc5, + 0x00, 0x01, 0x2e, 0x58, 0x49, 0x5c}; + const uint8_t test_vector_17_aad[] = { + 0x68, 0xf2, 0xe7, 0x76, 0x96, 0xce, 0x7a, 0xe8, 0xe2, 0xca, 0x4e, + 0xc5, 0x88, 0xe5, 0x41, 0x00, 0x2e, 0x58, 0x49, 0x5c, 0x08, 0x00, + 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, + 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, + 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x00, 0x07}; + const uint8_t test_vector_17_plaintext[1] = {}; + const uint8_t test_vector_17_ciphertext_and_tag[] = { + 0x07, 0x92, 0x2b, 0x8e, 0xbc, 0xf1, 0x0b, 0xb2, + 0x29, 0x75, 0x88, 0xca, 0x4c, 0x61, 0x45, 0x23}; + gsec_aead_malloc_test_vector( + &test_vector_17, test_vector_17_key, + sizeof(test_vector_17_key) / sizeof(uint8_t), test_vector_17_nonce, + sizeof(test_vector_17_nonce) / sizeof(uint8_t), test_vector_17_aad, + sizeof(test_vector_17_aad) / sizeof(uint8_t), test_vector_17_plaintext, 0, + test_vector_17_ciphertext_and_tag, + sizeof(test_vector_17_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_17); + gsec_aead_free_test_vector(test_vector_17); + + /* 2.7.2 79-byte crypt */ + gsec_aead_test_vector* test_vector_18; + const uint8_t test_vector_18_key[] = { + 0x4c, 0x97, 0x3d, 0xbc, 0x73, 0x64, 0x62, 0x16, 0x74, 0xf8, 0xb5, + 0xb8, 0x9e, 0x5c, 0x15, 0x51, 0x1f, 0xce, 0xd9, 0x21, 0x64, 0x90, + 0xfb, 0x1c, 0x1a, 0x2c, 0xaa, 0x0f, 0xfe, 0x04, 0x07, 0xe5}; + const uint8_t test_vector_18_nonce[] = {0x7a, 0xe8, 0xe2, 0xca, 0x4e, 0xc5, + 0x00, 0x01, 0x2e, 0x58, 0x49, 0x5c}; + const uint8_t test_vector_18_aad[] = { + 0x68, 0xf2, 0xe7, 0x76, 0x96, 0xce, 0x7a, 0xe8, 0xe2, 0xca, 0x4e, + 0xc5, 0x88, 0xe5, 0x41, 0x00, 0x2e, 0x58, 0x49, 0x5c, 0x08, 0x00, + 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, + 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, + 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x00, 0x07}; + const uint8_t test_vector_18_plaintext[1] = {}; + const uint8_t test_vector_18_ciphertext_and_tag[] = { + 0x00, 0xbd, 0xa1, 0xb7, 0xe8, 0x76, 0x08, 0xbc, + 0xbf, 0x47, 0x0f, 0x12, 0x15, 0x7f, 0x4c, 0x07}; + gsec_aead_malloc_test_vector( + &test_vector_18, test_vector_18_key, + sizeof(test_vector_18_key) / sizeof(uint8_t), 
test_vector_18_nonce, + sizeof(test_vector_18_nonce) / sizeof(uint8_t), test_vector_18_aad, + sizeof(test_vector_18_aad) / sizeof(uint8_t), test_vector_18_plaintext, 0, + test_vector_18_ciphertext_and_tag, + sizeof(test_vector_18_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_18); + gsec_aead_free_test_vector(test_vector_18); + + /* 2.8.1 61-byte crypt */ + gsec_aead_test_vector* test_vector_19; + const uint8_t test_vector_19_key[] = {0x88, 0xee, 0x08, 0x7f, 0xd9, 0x5d, + 0xa9, 0xfb, 0xf6, 0x72, 0x5a, 0xa9, + 0xd7, 0x57, 0xb0, 0xcd}; + const uint8_t test_vector_19_nonce[] = {0x7a, 0xe8, 0xe2, 0xca, 0x4e, 0xc5, + 0x00, 0x01, 0x2e, 0x58, 0x49, 0x5c}; + const uint8_t test_vector_19_aad[] = { + 0x68, 0xf2, 0xe7, 0x76, 0x96, 0xce, 0x7a, 0xe8, 0xe2, 0xca, + 0x4e, 0xc5, 0x88, 0xe5, 0x4d, 0x00, 0x2e, 0x58, 0x49, 0x5c}; + const uint8_t test_vector_19_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, + 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, + 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x00, 0x08}; + const uint8_t test_vector_19_ciphertext_and_tag[] = { + 0xc3, 0x1f, 0x53, 0xd9, 0x9e, 0x56, 0x87, 0xf7, 0x36, 0x51, 0x19, 0xb8, + 0x32, 0xd2, 0xaa, 0xe7, 0x07, 0x41, 0xd5, 0x93, 0xf1, 0xf9, 0xe2, 0xab, + 0x34, 0x55, 0x77, 0x9b, 0x07, 0x8e, 0xb8, 0xfe, 0xac, 0xdf, 0xec, 0x1f, + 0x8e, 0x3e, 0x52, 0x77, 0xf8, 0x18, 0x0b, 0x43, 0x36, 0x1f, 0x65, 0x12, + 0xad, 0xb1, 0x6d, 0x2e, 0x38, 0x54, 0x8a, 0x2c, 0x71, 0x9d, 0xba, 0x72, + 0x28, 0xd8, 0x40, 0x88, 0xf8, 0x75, 0x7a, 0xdb, 0x8a, 0xa7, 0x88, 0xd8, + 0xf6, 0x5a, 0xd6, 0x68, 0xbe, 0x70, 0xe7}; + gsec_aead_malloc_test_vector( + &test_vector_19, test_vector_19_key, + sizeof(test_vector_19_key) / sizeof(uint8_t), test_vector_19_nonce, + sizeof(test_vector_19_nonce) / sizeof(uint8_t), test_vector_19_aad, + sizeof(test_vector_19_aad) / sizeof(uint8_t), test_vector_19_plaintext, + sizeof(test_vector_19_plaintext) / sizeof(uint8_t), + test_vector_19_ciphertext_and_tag, + sizeof(test_vector_19_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_19); + gsec_aead_free_test_vector(test_vector_19); + + /* 2.8.2 61-byte crypt */ + gsec_aead_test_vector* test_vector_20; + const uint8_t test_vector_20_key[] = { + 0x4c, 0x97, 0x3d, 0xbc, 0x73, 0x64, 0x62, 0x16, 0x74, 0xf8, 0xb5, + 0xb8, 0x9e, 0x5c, 0x15, 0x51, 0x1f, 0xce, 0xd9, 0x21, 0x64, 0x90, + 0xfb, 0x1c, 0x1a, 0x2c, 0xaa, 0x0f, 0xfe, 0x04, 0x07, 0xe5}; + const uint8_t test_vector_20_nonce[] = {0x7a, 0xe8, 0xe2, 0xca, 0x4e, 0xc5, + 0x00, 0x01, 0x2e, 0x58, 0x49, 0x5c}; + const uint8_t test_vector_20_aad[] = { + 0x68, 0xf2, 0xe7, 0x76, 0x96, 0xce, 0x7a, 0xe8, 0xe2, 0xca, + 0x4e, 0xc5, 0x88, 0xe5, 0x4d, 0x00, 0x2e, 0x58, 0x49, 0x5c}; + const uint8_t test_vector_20_plaintext[] = { + 0x08, 0x00, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, + 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, + 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, + 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x00, 0x08}; + const uint8_t test_vector_20_ciphertext_and_tag[] = { + 0xba, 0x8a, 0xe3, 0x1b, 0xc5, 0x06, 0x48, 0x6d, 0x68, 0x73, 0xe4, 
0xfc, + 0xe4, 0x60, 0xe7, 0xdc, 0x57, 0x59, 0x1f, 0xf0, 0x06, 0x11, 0xf3, 0x1c, + 0x38, 0x34, 0xfe, 0x1c, 0x04, 0xad, 0x80, 0xb6, 0x68, 0x03, 0xaf, 0xcf, + 0x5b, 0x27, 0xe6, 0x33, 0x3f, 0xa6, 0x7c, 0x99, 0xda, 0x47, 0xc2, 0xf0, + 0xce, 0xd6, 0x8d, 0x53, 0x1b, 0xd7, 0x41, 0xa9, 0x43, 0xcf, 0xf7, 0xa6, + 0x71, 0x3b, 0xd0, 0x26, 0x11, 0xcd, 0x7d, 0xaa, 0x01, 0xd6, 0x1c, 0x5c, + 0x88, 0x6d, 0xc1, 0xa8, 0x17, 0x01, 0x07}; + gsec_aead_malloc_test_vector( + &test_vector_20, test_vector_20_key, + sizeof(test_vector_20_key) / sizeof(uint8_t), test_vector_20_nonce, + sizeof(test_vector_20_nonce) / sizeof(uint8_t), test_vector_20_aad, + sizeof(test_vector_20_aad) / sizeof(uint8_t), test_vector_20_plaintext, + sizeof(test_vector_20_plaintext) / sizeof(uint8_t), + test_vector_20_ciphertext_and_tag, + sizeof(test_vector_20_ciphertext_and_tag) / sizeof(uint8_t)); + gsec_test_verify_crypter_on_test_vector(test_vector_20); + gsec_aead_free_test_vector(test_vector_20); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + gsec_test_do_generic_crypter_tests(); + gsec_test_do_vector_tests_nist(); + gsec_test_do_vector_tests_ieee(); + gsec_test_do_vector_tests_rekey_nist(); + gsec_test_do_vector_tests_rekey_ieee(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/alts/crypt/gsec_test_util.cc b/test/core/tsi/alts/crypt/gsec_test_util.cc new file mode 100644 index 00000000..29533fa8 --- /dev/null +++ b/test/core/tsi/alts/crypt/gsec_test_util.cc @@ -0,0 +1,91 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include "test/core/tsi/alts/crypt/gsec_test_util.h"
+
+#include <time.h>
+
+#include <grpc/support/alloc.h>
+
+void gsec_test_random_bytes(uint8_t* bytes, size_t length) {
+  srand(time(nullptr));
+  size_t ind;
+  for (ind = 0; ind < length; ind++) {
+    bytes[ind] = static_cast<uint8_t>(rand() % 255 + 1);
+  }
+}
+
+void gsec_test_random_array(uint8_t** bytes, size_t length) {
+  if (bytes != nullptr) {
+    *bytes = static_cast<uint8_t*>(gpr_malloc(length));
+    gsec_test_random_bytes(*bytes, length);
+  } else {
+    fprintf(stderr, "bytes buffer is nullptr in gsec_test_random_array().");
+    abort();
+  }
+}
+
+uint32_t gsec_test_bias_random_uint32(uint32_t max_length) {
+  uint32_t value;
+  gsec_test_random_bytes(reinterpret_cast<uint8_t*>(&value), sizeof(value));
+  return value % max_length;
+}
+
+void gsec_test_copy(const uint8_t* src, uint8_t** des, size_t source_len) {
+  if (src != nullptr && des != nullptr) {
+    *des = static_cast<uint8_t*>(gpr_malloc(source_len));
+    if (*des != nullptr) {
+      memcpy(*des, src, source_len);
+    }
+  } else {
+    fprintf(stderr, "Either src or des buffer is nullptr in gsec_test_copy().");
+    abort();
+  }
+}
+
+void gsec_test_copy_and_alter_random_byte(const uint8_t* src, uint8_t** des,
+                                          size_t source_len) {
+  if (src != nullptr && des != nullptr) {
+    *des = static_cast<uint8_t*>(gpr_malloc(source_len));
+    memcpy(*des, src, source_len);
+    uint32_t offset;
+    offset = gsec_test_bias_random_uint32(static_cast<uint32_t>(source_len));
+    (*(*des + offset))++;
+  } else {
+    fprintf(stderr,
+            "Either src or des is nullptr in "
+            "gsec_test_copy_and_alter_random_byte().");
+    abort();
+  }
+}
+
+int gsec_test_expect_compare_code_and_substr(grpc_status_code status1,
+                                             grpc_status_code status2,
+                                             const char* msg1,
+                                             const char* msg2) {
+  int failure = 1;
+  if (status1 != status2) {
+    fprintf(stderr, "Status %d does not equal %d.\n", status1, status2);
+    failure = 0;
+  }
+  if (strstr(msg1, msg2) == nullptr) {
+    fprintf(stderr, "Status message <%s> does not contain <%s>.\n", msg1, msg2);
+    failure = 0;
+  }
+  return failure;
+}
diff --git a/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc b/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc
new file mode 100644
index 00000000..559709fa
--- /dev/null
+++ b/test/core/tsi/alts/fake_handshaker/fake_handshaker_server.cc
@@ -0,0 +1,290 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+#include "test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h"
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.h"
+#include "test/core/tsi/alts/fake_handshaker/handshaker.pb.h"
+#include "test/core/tsi/alts/fake_handshaker/transport_security_common.pb.h"
+
+// Fake handshake messages.
+constexpr char kClientInitFrame[] = "ClientInit";
+constexpr char kServerFrame[] = "ServerInitAndFinished";
+constexpr char kClientFinishFrame[] = "ClientFinished";
+// Error messages.
+constexpr char kInvalidFrameError[] = "Invalid input frame.";
+constexpr char kWrongStateError[] = "Wrong handshake state.";
+
+namespace grpc {
+namespace gcp {
+
+// FakeHandshakeService implements a fake handshaker service using a fake key
+// exchange protocol. The fake key exchange protocol is a 3-message protocol:
+// - Client first sends ClientInit message to Server.
+// - Server then sends ServerInitAndFinished message back to Client.
+// - Client finally sends ClientFinished message to Server.
+// This fake handshaker service is intended for ALTS integration testing,
+// without relying on the real ALTS handshaker service inside GCE.
+// It is thread-safe.
+class FakeHandshakerService : public HandshakerService::Service {
+ public:
+  explicit FakeHandshakerService(int expected_max_concurrent_rpcs)
+      : expected_max_concurrent_rpcs_(expected_max_concurrent_rpcs) {}
+
+  Status DoHandshake(
+      ServerContext* /*server_context*/,
+      ServerReaderWriter<HandshakerResp, HandshakerReq>* stream) override {
+    ConcurrentRpcsCheck concurrent_rpcs_check(this);
+    Status status;
+    HandshakerContext context;
+    HandshakerReq request;
+    HandshakerResp response;
+    gpr_log(GPR_DEBUG, "Start a new handshake.");
+    while (stream->Read(&request)) {
+      status = ProcessRequest(&context, request, &response);
+      if (!status.ok()) return WriteErrorResponse(stream, status);
+      stream->Write(response);
+      if (context.state == COMPLETED) return Status::OK;
+      request.Clear();
+    }
+    return Status::OK;
+  }
+
+ private:
+  // HandshakeState is used by the fake handshaker server to keep track of the
+  // client's handshake status. At the beginning of a handshake, the state is
+  // INITIAL. If a start_client or start_server request is received, the state
+  // becomes at least STARTED. When the handshaker server produces the first
+  // frame, the state becomes SENT. After the handshaker server processes the
+  // final frame from the peer, the state becomes COMPLETED.
+  enum HandshakeState { INITIAL, STARTED, SENT, COMPLETED };
+
+  struct HandshakerContext {
+    bool is_client = true;
+    HandshakeState state = INITIAL;
+  };
+
+  Status ProcessRequest(HandshakerContext* context,
+                        const HandshakerReq& request,
+                        HandshakerResp* response) {
+    GPR_ASSERT(context != nullptr && response != nullptr);
+    response->Clear();
+    if (request.has_client_start()) {
+      gpr_log(GPR_DEBUG, "Process client start request.");
+      return ProcessClientStart(context, request.client_start(), response);
+    } else if (request.has_server_start()) {
+      gpr_log(GPR_DEBUG, "Process server start request.");
+      return ProcessServerStart(context, request.server_start(), response);
+    } else if (request.has_next()) {
+      gpr_log(GPR_DEBUG, "Process next request.");
+      return ProcessNext(context, request.next(), response);
+    }
+    return Status(StatusCode::INVALID_ARGUMENT, "Request is empty.");
+  }
+
+  Status ProcessClientStart(HandshakerContext* context,
+                            const StartClientHandshakeReq& request,
+                            HandshakerResp* response) {
+    GPR_ASSERT(context != nullptr && response != nullptr);
+    // Checks request.
+    if (context->state != INITIAL) {
+      return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError);
+    }
+    if (request.application_protocols_size() == 0) {
+      return Status(StatusCode::INVALID_ARGUMENT,
+                    "At least one application protocol needed.");
+    }
+    if (request.record_protocols_size() == 0) {
+      return Status(StatusCode::INVALID_ARGUMENT,
+                    "At least one record protocol needed.");
+    }
+    // Sets response.
+ response->set_out_frames(kClientInitFrame); + response->set_bytes_consumed(0); + response->mutable_status()->set_code(StatusCode::OK); + // Updates handshaker context. + context->is_client = true; + context->state = SENT; + return Status::OK; + } + + Status ProcessServerStart(HandshakerContext* context, + const StartServerHandshakeReq& request, + HandshakerResp* response) { + GPR_ASSERT(context != nullptr && response != nullptr); + // Checks request. + if (context->state != INITIAL) { + return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError); + } + if (request.application_protocols_size() == 0) { + return Status(StatusCode::INVALID_ARGUMENT, + "At least one application protocol needed."); + } + if (request.handshake_parameters().empty()) { + return Status(StatusCode::INVALID_ARGUMENT, + "At least one set of handshake parameters needed."); + } + // Sets response. + if (request.in_bytes().empty()) { + // start_server request does not have in_bytes. + response->set_bytes_consumed(0); + context->state = STARTED; + } else { + // start_server request has in_bytes. + if (request.in_bytes() == kClientInitFrame) { + response->set_out_frames(kServerFrame); + response->set_bytes_consumed(strlen(kClientInitFrame)); + context->state = SENT; + } else { + return Status(StatusCode::UNKNOWN, kInvalidFrameError); + } + } + response->mutable_status()->set_code(StatusCode::OK); + context->is_client = false; + return Status::OK; + } + + Status ProcessNext(HandshakerContext* context, + const NextHandshakeMessageReq& request, + HandshakerResp* response) { + GPR_ASSERT(context != nullptr && response != nullptr); + if (context->is_client) { + // Processes next request on client side. + if (context->state != SENT) { + return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError); + } + if (request.in_bytes() != kServerFrame) { + return Status(StatusCode::UNKNOWN, kInvalidFrameError); + } + response->set_out_frames(kClientFinishFrame); + response->set_bytes_consumed(strlen(kServerFrame)); + context->state = COMPLETED; + } else { + // Processes next request on server side. + HandshakeState current_state = context->state; + if (current_state == STARTED) { + if (request.in_bytes() != kClientInitFrame) { + return Status(StatusCode::UNKNOWN, kInvalidFrameError); + } + response->set_out_frames(kServerFrame); + response->set_bytes_consumed(strlen(kClientInitFrame)); + context->state = SENT; + } else if (current_state == SENT) { + // Client finish frame may be sent along with the first payload from the + // client, handshaker only consumes the client finish frame. + if (request.in_bytes().substr(0, strlen(kClientFinishFrame)) != + kClientFinishFrame) { + return Status(StatusCode::UNKNOWN, kInvalidFrameError); + } + response->set_bytes_consumed(strlen(kClientFinishFrame)); + context->state = COMPLETED; + } else { + return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError); + } + } + // At this point, processing next request succeeded. 
+    response->mutable_status()->set_code(StatusCode::OK);
+    if (context->state == COMPLETED) {
+      *response->mutable_result() = GetHandshakerResult();
+    }
+    return Status::OK;
+  }
+
+  Status WriteErrorResponse(
+      ServerReaderWriter<HandshakerResp, HandshakerReq>* stream,
+      const Status& status) {
+    GPR_ASSERT(!status.ok());
+    HandshakerResp response;
+    response.mutable_status()->set_code(status.error_code());
+    response.mutable_status()->set_details(status.error_message());
+    stream->Write(response);
+    return status;
+  }
+
+  HandshakerResult GetHandshakerResult() {
+    HandshakerResult result;
+    result.set_application_protocol("grpc");
+    result.set_record_protocol("ALTSRP_GCM_AES128_REKEY");
+    result.mutable_peer_identity()->set_service_account("peer_identity");
+    result.mutable_local_identity()->set_service_account("local_identity");
+    string key(1024, '\0');
+    result.set_key_data(key);
+    result.set_max_frame_size(16384);
+    result.mutable_peer_rpc_versions()->mutable_max_rpc_version()->set_major(2);
+    result.mutable_peer_rpc_versions()->mutable_max_rpc_version()->set_minor(1);
+    result.mutable_peer_rpc_versions()->mutable_min_rpc_version()->set_major(2);
+    result.mutable_peer_rpc_versions()->mutable_min_rpc_version()->set_minor(1);
+    return result;
+  }
+
+  class ConcurrentRpcsCheck {
+   public:
+    explicit ConcurrentRpcsCheck(FakeHandshakerService* parent)
+        : parent_(parent) {
+      if (parent->expected_max_concurrent_rpcs_ > 0) {
+        grpc::internal::MutexLock lock(
+            &parent->expected_max_concurrent_rpcs_mu_);
+        if (++parent->concurrent_rpcs_ >
+            parent->expected_max_concurrent_rpcs_) {
+          gpr_log(GPR_ERROR,
+                  "FakeHandshakerService:%p concurrent_rpcs_:%d "
+                  "expected_max_concurrent_rpcs:%d",
+                  parent, parent->concurrent_rpcs_,
+                  parent->expected_max_concurrent_rpcs_);
+          abort();
+        }
+      }
+    }
+
+    ~ConcurrentRpcsCheck() {
+      if (parent_->expected_max_concurrent_rpcs_ > 0) {
+        grpc::internal::MutexLock lock(
+            &parent_->expected_max_concurrent_rpcs_mu_);
+        parent_->concurrent_rpcs_--;
+      }
+    }
+
+   private:
+    FakeHandshakerService* parent_;
+  };
+
+  grpc::internal::Mutex expected_max_concurrent_rpcs_mu_;
+  int concurrent_rpcs_ = 0;
+  const int expected_max_concurrent_rpcs_;
+};
+
+std::unique_ptr<grpc::Service> CreateFakeHandshakerService(
+    int expected_max_concurrent_rpcs) {
+  return std::unique_ptr<grpc::Service>{
+      new grpc::gcp::FakeHandshakerService(expected_max_concurrent_rpcs)};
+}
+
+} // namespace gcp
+} // namespace grpc
diff --git a/test/core/tsi/alts/fake_handshaker/fake_handshaker_server_main.cc b/test/core/tsi/alts/fake_handshaker/fake_handshaker_server_main.cc
new file mode 100644
index 00000000..5f24e9a2
--- /dev/null
+++ b/test/core/tsi/alts/fake_handshaker/fake_handshaker_server_main.cc
@@ -0,0 +1,56 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+#include
+
+#include "absl/flags/flag.h"
+
+#include
+#include
+#include
+
+#include "test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/test_config.h"
+
+ABSL_FLAG(int32_t, handshaker_port, 55056,
+          "TCP port on which the fake handshaker server listens.");
+
+static void RunFakeHandshakerServer(const std::string& server_address) {
+  std::unique_ptr<grpc::Service> service =
+      grpc::gcp::CreateFakeHandshakerService(
+          0 /* expected max concurrent rpcs unset */);
+  grpc::ServerBuilder builder;
+  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+  builder.RegisterService(service.get());
+  gpr_log(GPR_INFO, "Fake handshaker server listening on %s",
+          server_address.c_str());
+  std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
+  server->Wait();
+}
+
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  grpc::testing::InitTest(&argc, &argv, true);
+
+  GPR_ASSERT(absl::GetFlag(FLAGS_handshaker_port) != 0);
+  std::ostringstream server_address;
+  server_address << "[::1]:" << absl::GetFlag(FLAGS_handshaker_port);
+
+  RunFakeHandshakerServer(server_address.str());
+  return 0;
+}
diff --git a/test/core/tsi/alts/frame_protector/alts_counter_test.cc b/test/core/tsi/alts/frame_protector/alts_counter_test.cc
new file mode 100644
index 00000000..80fd95fb
--- /dev/null
+++ b/test/core/tsi/alts/frame_protector/alts_counter_test.cc
@@ -0,0 +1,181 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "src/core/tsi/alts/frame_protector/alts_counter.h"
+
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+
+#include "test/core/tsi/alts/crypt/gsec_test_util.h"
+
+const size_t kSmallCounterSize = 4;
+const size_t kSmallOverflowSize = 1;
+const size_t kGcmCounterSize = 12;
+const size_t kGcmOverflowSize = 5;
+
+static bool do_bytes_represent_client(alts_counter* ctr,
+                                      unsigned char* /*counter*/, size_t size) {
+  return (ctr->counter[size - 1] & 0x80) == 0x80;
+}
+
+static void alts_counter_test_input_sanity_check(size_t counter_size,
+                                                 size_t overflow_size) {
+  alts_counter* ctr = nullptr;
+  char* error_details = nullptr;
+
+  /* Input sanity check on alts_counter_create(). */
+  /* Invalid counter size. */
+  grpc_status_code status =
+      alts_counter_create(true, 0, overflow_size, &ctr, &error_details);
+  GPR_ASSERT(gsec_test_expect_compare_code_and_substr(
+      status, GRPC_STATUS_INVALID_ARGUMENT, error_details,
+      "counter_size is invalid."));
+  gpr_free(error_details);
+
+  /* Invalid overflow size. */
+  status = alts_counter_create(true, counter_size, 0, &ctr, &error_details);
+  GPR_ASSERT(gsec_test_expect_compare_code_and_substr(
+      status, GRPC_STATUS_INVALID_ARGUMENT, error_details,
+      "overflow_size is invalid."));
+  gpr_free(error_details);
+
+  /* alts_counter is nullptr.
*/ + status = alts_counter_create(true, counter_size, overflow_size, nullptr, + &error_details); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_details, + "crypter_counter is nullptr.")); + gpr_free(error_details); + + status = alts_counter_create(true, counter_size, overflow_size, &ctr, + &error_details); + GPR_ASSERT(status == GRPC_STATUS_OK); + + /* Input sanity check on alts_counter_increment(). */ + /* crypter_counter is nullptr. */ + bool is_overflow = false; + status = alts_counter_increment(nullptr, &is_overflow, &error_details); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_details, + "crypter_counter is nullptr.")); + gpr_free(error_details); + /* is_overflow is nullptr. */ + status = alts_counter_increment(ctr, nullptr, &error_details); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_details, + "is_overflow is nullptr.")); + gpr_free(error_details); + alts_counter_destroy(ctr); +} + +static void alts_counter_test_overflow_full_range(bool is_client, + size_t counter_size, + size_t overflow_size) { + alts_counter* ctr = nullptr; + char* error_details = nullptr; + grpc_status_code status = alts_counter_create( + is_client, counter_size, overflow_size, &ctr, &error_details); + GPR_ASSERT(status == GRPC_STATUS_OK); + unsigned char* expected = + static_cast(gpr_zalloc(counter_size)); + if (is_client) { + expected[counter_size - 1] = 0x80; + } + /* Do a single iteration to ensure the counter is initialized as expected. */ + GPR_ASSERT(do_bytes_represent_client(ctr, alts_counter_get_counter(ctr), + counter_size) == is_client); + GPR_ASSERT(memcmp(alts_counter_get_counter(ctr), expected, counter_size) == + 0); + bool is_overflow = false; + GPR_ASSERT(alts_counter_increment(ctr, &is_overflow, &error_details) == + GRPC_STATUS_OK); + GPR_ASSERT(!is_overflow); + /** + * The counter can return 2^{overflow_size * 8} counters. The + * high-order bit is fixed to the client/server. The last call will yield a + * useable counter, but overflow the counter object. + */ + int iterations = 1 << (overflow_size * 8); + int ind = 1; + for (ind = 1; ind < iterations - 1; ind++) { + GPR_ASSERT(do_bytes_represent_client(ctr, alts_counter_get_counter(ctr), + counter_size) == is_client); + GPR_ASSERT(alts_counter_increment(ctr, &is_overflow, &error_details) == + GRPC_STATUS_OK); + GPR_ASSERT(!is_overflow); + } + GPR_ASSERT(do_bytes_represent_client(ctr, alts_counter_get_counter(ctr), + counter_size) == is_client); + GPR_ASSERT(alts_counter_increment(ctr, &is_overflow, &error_details) == + GRPC_STATUS_FAILED_PRECONDITION); + GPR_ASSERT(is_overflow); + gpr_free(expected); + alts_counter_destroy(ctr); +} + +/* Set the counter manually and make sure it overflows as expected. 
*/ +static void alts_counter_test_overflow_single_increment(bool is_client, + size_t counter_size, + size_t overflow_size) { + alts_counter* ctr = nullptr; + char* error_details = nullptr; + grpc_status_code status = alts_counter_create( + is_client, counter_size, overflow_size, &ctr, &error_details); + GPR_ASSERT(status == GRPC_STATUS_OK); + unsigned char* expected = + static_cast(gpr_zalloc(counter_size)); + memset(expected, 0xFF, overflow_size); + expected[0] = 0xFE; + + if (is_client) { + expected[counter_size - 1] = 0x80; + } + memcpy(ctr->counter, expected, counter_size); + GPR_ASSERT(do_bytes_represent_client(ctr, alts_counter_get_counter(ctr), + counter_size) == is_client); + GPR_ASSERT(memcmp(expected, alts_counter_get_counter(ctr), counter_size) == + 0); + bool is_overflow = false; + GPR_ASSERT(alts_counter_increment(ctr, &is_overflow, &error_details) == + GRPC_STATUS_OK); + GPR_ASSERT(!is_overflow); + GPR_ASSERT(do_bytes_represent_client(ctr, alts_counter_get_counter(ctr), + counter_size) == is_client); + expected[0] = static_cast(expected[0] + 1); + GPR_ASSERT(memcmp(expected, alts_counter_get_counter(ctr), counter_size) == + 0); + GPR_ASSERT(alts_counter_increment(ctr, &is_overflow, &error_details) == + GRPC_STATUS_FAILED_PRECONDITION); + GPR_ASSERT(is_overflow); + gpr_free(expected); + alts_counter_destroy(ctr); +} + +int main(int /*argc*/, char** /*argv*/) { + alts_counter_test_input_sanity_check(kGcmCounterSize, kGcmOverflowSize); + alts_counter_test_overflow_full_range(true, kSmallCounterSize, + kSmallOverflowSize); + alts_counter_test_overflow_full_range(false, kSmallCounterSize, + kSmallOverflowSize); + alts_counter_test_overflow_single_increment(true, kGcmCounterSize, + kGcmOverflowSize); + alts_counter_test_overflow_single_increment(false, kGcmCounterSize, + kGcmOverflowSize); + + return 0; +} diff --git a/test/core/tsi/alts/frame_protector/alts_crypter_test.cc b/test/core/tsi/alts/frame_protector/alts_crypter_test.cc new file mode 100644 index 00000000..05285442 --- /dev/null +++ b/test/core/tsi/alts/frame_protector/alts_crypter_test.cc @@ -0,0 +1,493 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/tsi/alts/frame_protector/alts_crypter.h" + +#include +#include +#include + +#include +#include + +#include "test/core/tsi/alts/crypt/gsec_test_util.h" + +static void alts_crypter_test_random_seal_unseal(alts_crypter* server_seal, + alts_crypter* server_unseal, + alts_crypter* client_seal, + alts_crypter* client_unseal) { + size_t data_size = gsec_test_bias_random_uint32(1024) + 1; + size_t num_overhead_bytes = alts_crypter_num_overhead_bytes(server_seal); + size_t protected_data_size = data_size + num_overhead_bytes; + uint8_t* data_buffer = static_cast(gpr_malloc(protected_data_size)); + gsec_test_random_bytes(data_buffer, data_size); + uint8_t* duplicate_buffer = nullptr; + gsec_test_copy(data_buffer, &duplicate_buffer, data_size); + + /* Client seal and server unseal */ + size_t size = data_size; + grpc_status_code status = alts_crypter_process_in_place( + client_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + status = alts_crypter_process_in_place( + server_unseal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(data_buffer, duplicate_buffer, data_size) == 0); + GPR_ASSERT(size == data_size); + /* Server seal and client unseal */ + status = alts_crypter_process_in_place( + server_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + status = alts_crypter_process_in_place( + client_unseal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(data_buffer, duplicate_buffer, data_size) == 0); + GPR_ASSERT(size == data_size); + gpr_free(data_buffer); + gpr_free(duplicate_buffer); +} + +static void alts_crypter_test_multiple_random_seal_unseal( + alts_crypter* server_seal, alts_crypter* server_unseal, + alts_crypter* client_seal, alts_crypter* client_unseal) { + size_t data_size = gsec_test_bias_random_uint32(1024) + 1; + size_t num_overhead_bytes = alts_crypter_num_overhead_bytes(server_seal); + size_t protected_data_size = data_size + num_overhead_bytes; + + uint8_t* data_buffer1 = + static_cast(gpr_malloc(protected_data_size)); + uint8_t* data_buffer2 = + static_cast(gpr_malloc(protected_data_size)); + uint8_t* duplicate_buffer1 = nullptr; + uint8_t* duplicate_buffer2 = nullptr; + gsec_test_random_bytes(data_buffer1, data_size); + gsec_test_random_bytes(data_buffer2, data_size); + gsec_test_copy(data_buffer1, &duplicate_buffer1, data_size); + gsec_test_copy(data_buffer2, &duplicate_buffer2, data_size); + + /* Client seal and server unseal */ + size_t size1 = data_size, size2 = data_size; + grpc_status_code status = alts_crypter_process_in_place( + client_seal, data_buffer1, protected_data_size, size1, &size1, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size1 == protected_data_size); + status = alts_crypter_process_in_place( + client_seal, data_buffer2, protected_data_size, size2, &size2, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size2 == protected_data_size); + status = alts_crypter_process_in_place( + server_unseal, data_buffer1, protected_data_size, size1, &size1, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(data_buffer1, duplicate_buffer1, data_size) == 0); + GPR_ASSERT(size1 == data_size); + status = alts_crypter_process_in_place( + server_unseal, data_buffer2, 
protected_data_size, size2, &size2, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(data_buffer2, duplicate_buffer2, data_size) == 0); + GPR_ASSERT(size2 == data_size); + + /* Server seal and client unseal */ + status = alts_crypter_process_in_place( + server_seal, data_buffer1, protected_data_size, size1, &size1, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size1 == protected_data_size); + status = alts_crypter_process_in_place( + server_seal, data_buffer2, protected_data_size, size2, &size2, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size2 == protected_data_size); + status = alts_crypter_process_in_place( + client_unseal, data_buffer1, protected_data_size, size1, &size1, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(data_buffer1, duplicate_buffer1, data_size) == 0); + GPR_ASSERT(size1 == data_size); + status = alts_crypter_process_in_place( + client_unseal, data_buffer2, protected_data_size, size2, &size2, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(data_buffer2, duplicate_buffer2, data_size) == 0); + GPR_ASSERT(size2 == data_size); + + gpr_free(data_buffer1); + gpr_free(data_buffer2); + gpr_free(duplicate_buffer1); + gpr_free(duplicate_buffer2); +} + +static void alts_crypter_test_corrupted_unseal( + alts_crypter* server_seal, alts_crypter* server_unseal, + alts_crypter* client_seal, alts_crypter* /*client_unseal*/) { + size_t data_size = gsec_test_bias_random_uint32(1024) + 1; + size_t num_overhead_bytes = alts_crypter_num_overhead_bytes(server_seal); + size_t protected_data_size = data_size + num_overhead_bytes; + auto* data_buffer = static_cast(gpr_malloc(protected_data_size)); + auto* zero_buffer = static_cast(gpr_zalloc(data_size)); + + /* Corrupt a random byte in protected data. */ + size_t size = data_size; + gsec_test_random_bytes(data_buffer, data_size); + grpc_status_code status = alts_crypter_process_in_place( + client_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + uint8_t* corrupted_data_buffer; + char* error_message = nullptr; + gsec_test_copy_and_alter_random_byte(data_buffer, &corrupted_data_buffer, + protected_data_size); + status = alts_crypter_process_in_place(server_unseal, corrupted_data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + GPR_ASSERT(memcmp(corrupted_data_buffer, zero_buffer, data_size) == 0); + gpr_free(corrupted_data_buffer); + gpr_free(error_message); + + /* Corrupt the beginning of protected data. 
*/ + size = data_size; + gsec_test_random_bytes(data_buffer, data_size); + status = alts_crypter_process_in_place( + client_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + gsec_test_copy(data_buffer, &corrupted_data_buffer, protected_data_size); + (*corrupted_data_buffer)++; + status = alts_crypter_process_in_place(server_unseal, corrupted_data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + GPR_ASSERT(memcmp(corrupted_data_buffer, zero_buffer, data_size) == 0); + gpr_free(corrupted_data_buffer); + gpr_free(error_message); + + /* Corrupt the end of protected data. */ + size = data_size; + gsec_test_random_bytes(data_buffer, data_size); + status = alts_crypter_process_in_place( + client_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + gsec_test_copy(data_buffer, &corrupted_data_buffer, protected_data_size); + (*(corrupted_data_buffer + protected_data_size - 1))++; + status = alts_crypter_process_in_place(server_unseal, corrupted_data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + GPR_ASSERT(memcmp(corrupted_data_buffer, zero_buffer, data_size) == 0); + gpr_free(corrupted_data_buffer); + gpr_free(error_message); + + gpr_free(data_buffer); + gpr_free(zero_buffer); +} + +static void alts_crypter_test_unsync_seal_unseal(alts_crypter* server_seal, + alts_crypter* server_unseal, + alts_crypter* client_seal, + alts_crypter* client_unseal) { + size_t data_size = gsec_test_bias_random_uint32(1024) + 1; + size_t num_overhead_bytes = alts_crypter_num_overhead_bytes(server_seal); + size_t protected_data_size = data_size + num_overhead_bytes; + auto* data_buffer = static_cast(gpr_malloc(protected_data_size)); + auto* zero_buffer = static_cast(gpr_zalloc(data_size)); + + /* Perform two seals at client, one unseal at server. */ + size_t size = data_size; + gsec_test_random_bytes(data_buffer, data_size); + grpc_status_code status = alts_crypter_process_in_place( + client_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + + size = data_size; + gsec_test_random_bytes(data_buffer, data_size); + status = alts_crypter_process_in_place( + client_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + + char* error_message = nullptr; + status = alts_crypter_process_in_place(server_unseal, data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + GPR_ASSERT(memcmp(data_buffer, zero_buffer, data_size) == 0); + gpr_free(error_message); + + /* Perform two seals at server, one unseal at client. 
*/ + size = data_size; + gsec_test_random_bytes(data_buffer, data_size); + status = alts_crypter_process_in_place( + server_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + + size = data_size; + gsec_test_random_bytes(data_buffer, data_size); + status = alts_crypter_process_in_place( + server_seal, data_buffer, protected_data_size, size, &size, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(size == protected_data_size); + + status = alts_crypter_process_in_place(client_unseal, data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Checking tag failed")); + GPR_ASSERT(memcmp(data_buffer, zero_buffer, data_size) == 0); + gpr_free(error_message); + gpr_free(data_buffer); + gpr_free(zero_buffer); +} + +static void alts_crypter_test_input_sanity_check(alts_crypter* crypter_seal, + alts_crypter* crypter_unseal) { + size_t data_size = gsec_test_bias_random_uint32(1024) + 1; + size_t num_overhead_bytes = alts_crypter_num_overhead_bytes(crypter_seal); + size_t protected_data_size = data_size + num_overhead_bytes; + auto* data_buffer = static_cast(gpr_malloc(protected_data_size)); + gsec_test_random_bytes(data_buffer, data_size); + char* error_message = nullptr; + size_t size = data_size; + + /* Crypter is nullptr. */ + grpc_status_code status = alts_crypter_process_in_place( + nullptr, data_buffer, protected_data_size, size, &size, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "crypter or crypter->vtable has not been initialized properly.")); + gpr_free(error_message); + + /* Seal data is nullptr. */ + size = data_size; + status = alts_crypter_process_in_place( + crypter_seal, nullptr, protected_data_size, size, &size, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, "data is nullptr.")); + gpr_free(error_message); + + /* Seal data size is 0. */ + size = 0; + status = alts_crypter_process_in_place(crypter_seal, data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "data_size is zero.")); + gpr_free(error_message); + + /* Seal data buffer has a size smaller than the required. */ + size = data_size; + status = alts_crypter_process_in_place(crypter_seal, data_buffer, + protected_data_size - 1, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "data_allocated_size is smaller than sum of data_size and " + "num_overhead_bytes.")); + gpr_free(error_message); + + /* Unseal data is nullptr. */ + size = data_size; + status = alts_crypter_process_in_place(crypter_unseal, nullptr, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, "data is nullptr.")); + gpr_free(error_message); + + /* Unseal data size is 0. 
*/ + size = 0; + status = alts_crypter_process_in_place(crypter_unseal, data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "data_size is smaller than num_overhead_bytes.")); + gpr_free(error_message); + + /* Unseal data size is smaller than number of overhead bytes. */ + size = num_overhead_bytes - 1; + status = alts_crypter_process_in_place(crypter_unseal, data_buffer, + protected_data_size, size, &size, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "data_size is smaller than num_overhead_bytes.")); + gpr_free(error_message); + gpr_free(data_buffer); +} + +static void create_random_alts_seal_crypter( + alts_crypter** server_seal, alts_crypter** server_unseal, + alts_crypter** client_seal, alts_crypter** client_unseal, + gsec_aead_crypter** server_crypter_seal, + gsec_aead_crypter** server_crypter_unseal, + gsec_aead_crypter** client_crypter_seal, + gsec_aead_crypter** client_crypter_unseal, bool rekey) { + size_t key_length = rekey ? kAes128GcmRekeyKeyLength : kAes128GcmKeyLength; + uint8_t* key; + gsec_test_random_array(&key, key_length); + gsec_aes_gcm_aead_crypter_create(key, key_length, kAesGcmNonceLength, + kAesGcmTagLength, rekey, server_crypter_seal, + nullptr); + gsec_aes_gcm_aead_crypter_create(key, key_length, kAesGcmNonceLength, + kAesGcmTagLength, rekey, + server_crypter_unseal, nullptr); + gsec_aes_gcm_aead_crypter_create(key, key_length, kAesGcmNonceLength, + kAesGcmTagLength, rekey, client_crypter_seal, + nullptr); + gsec_aes_gcm_aead_crypter_create(key, key_length, kAesGcmNonceLength, + kAesGcmTagLength, rekey, + client_crypter_unseal, nullptr); + + size_t overflow_size = rekey ? 
8 : 5; + alts_seal_crypter_create(*client_crypter_seal, /*is_client=*/true, + overflow_size, client_seal, nullptr); + alts_unseal_crypter_create(*client_crypter_unseal, /*is_client=*/true, + overflow_size, client_unseal, nullptr); + alts_seal_crypter_create(*server_crypter_seal, /*is_client=*/false, + overflow_size, server_seal, nullptr); + alts_unseal_crypter_create(*server_crypter_unseal, /*is_client=*/false, + overflow_size, server_unseal, nullptr); + gpr_free(key); +} + +static void destroy_random_alts_seal_crypter(alts_crypter* server_seal, + alts_crypter* server_unseal, + alts_crypter* client_seal, + alts_crypter* client_unseal) { + alts_crypter_destroy(server_seal); + alts_crypter_destroy(server_unseal); + alts_crypter_destroy(client_seal); + alts_crypter_destroy(client_unseal); +} + +static void alts_crypter_do_generic_tests() { + alts_crypter *server_seal = nullptr, *server_unseal = nullptr, + *client_seal = nullptr, *client_unseal = nullptr; + gsec_aead_crypter *server_crypter_seal = nullptr, + *server_crypter_unseal = nullptr, + *client_crypter_seal = nullptr, + *client_crypter_unseal = nullptr; + /* Random seal and unseal tests */ + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/false); + alts_crypter_test_random_seal_unseal(server_seal, server_unseal, client_seal, + client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/true); + alts_crypter_test_random_seal_unseal(server_seal, server_unseal, client_seal, + client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + /* Multiple random seal and unseal tests */ + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/false); + alts_crypter_test_multiple_random_seal_unseal(server_seal, server_unseal, + client_seal, client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/true); + alts_crypter_test_multiple_random_seal_unseal(server_seal, server_unseal, + client_seal, client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + /* Corrupted unseal tests */ + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/false); + alts_crypter_test_corrupted_unseal(server_seal, server_unseal, client_seal, + client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/true); + alts_crypter_test_corrupted_unseal(server_seal, server_unseal, 
client_seal, + client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + /* Unsync seal and unseal tests */ + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/false); + alts_crypter_test_unsync_seal_unseal(server_seal, server_unseal, client_seal, + client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/true); + alts_crypter_test_unsync_seal_unseal(server_seal, server_unseal, client_seal, + client_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + /* Input sanity check tests */ + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/false); + alts_crypter_test_input_sanity_check(server_seal, server_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); + + create_random_alts_seal_crypter(&server_seal, &server_unseal, &client_seal, + &client_unseal, &server_crypter_seal, + &server_crypter_unseal, &client_crypter_seal, + &client_crypter_unseal, /*rekey=*/true); + alts_crypter_test_input_sanity_check(server_seal, server_unseal); + destroy_random_alts_seal_crypter(server_seal, server_unseal, client_seal, + client_unseal); +} + +int main(int /*argc*/, char** /*argv*/) { + alts_crypter_do_generic_tests(); + return 0; +} diff --git a/test/core/tsi/alts/frame_protector/alts_frame_protector_test.cc b/test/core/tsi/alts/frame_protector/alts_frame_protector_test.cc new file mode 100644 index 00000000..3b75ecf6 --- /dev/null +++ b/test/core/tsi/alts/frame_protector/alts_frame_protector_test.cc @@ -0,0 +1,395 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/tsi/alts/frame_protector/alts_frame_protector.h" + +#include + +#include +#include + +#include "src/core/tsi/alts/crypt/gsec.h" +#include "src/core/tsi/transport_security_interface.h" +#include "test/core/tsi/alts/crypt/gsec_test_util.h" +#include "test/core/tsi/transport_security_test_lib.h" + +const size_t kChannelSize = 32768; + +static void alts_test_do_round_trip_check_frames( + tsi_test_frame_protector_fixture* fixture, const uint8_t* key, + const size_t key_size, bool rekey, const uint8_t* client_message, + const size_t client_message_size, const uint8_t* client_expected_frames, + const size_t client_frame_size, const uint8_t* server_message, + const size_t server_message_size, const uint8_t* server_expected_frames, + const size_t server_frame_size) { + GPR_ASSERT(fixture != nullptr); + GPR_ASSERT(fixture->config != nullptr); + tsi_frame_protector* client_frame_protector = nullptr; + tsi_frame_protector* server_frame_protector = nullptr; + tsi_test_frame_protector_config* config = fixture->config; + tsi_test_channel* channel = fixture->channel; + /* Create a client frame protector. */ + size_t client_max_output_protected_frame_size = + config->client_max_output_protected_frame_size; + GPR_ASSERT( + alts_create_frame_protector(key, key_size, /*is_client=*/true, rekey, + client_max_output_protected_frame_size == 0 + ? nullptr + : &client_max_output_protected_frame_size, + &client_frame_protector) == TSI_OK); + /* Create a server frame protector. */ + size_t server_max_output_protected_frame_size = + config->server_max_output_protected_frame_size; + GPR_ASSERT( + alts_create_frame_protector(key, key_size, /*is_client=*/false, rekey, + server_max_output_protected_frame_size == 0 + ? nullptr + : &server_max_output_protected_frame_size, + &server_frame_protector) == TSI_OK); + tsi_test_frame_protector_fixture_init(fixture, client_frame_protector, + server_frame_protector); + /* Client sends a message to server. */ + uint8_t* saved_client_message = config->client_message; + config->client_message = const_cast(client_message); + config->client_message_size = client_message_size; + tsi_test_frame_protector_send_message_to_peer(config, channel, + client_frame_protector, + /*is_client=*/true); + /* Verify if the generated frame is the same as the expected. */ + GPR_ASSERT(channel->bytes_written_to_server_channel == client_frame_size); + GPR_ASSERT(memcmp(client_expected_frames, channel->server_channel, + client_frame_size) == 0); + unsigned char* server_received_message = + static_cast(gpr_malloc(kChannelSize)); + size_t server_received_message_size = 0; + tsi_test_frame_protector_receive_message_from_peer( + config, channel, server_frame_protector, server_received_message, + &server_received_message_size, /*is_client=*/false); + GPR_ASSERT(config->client_message_size == server_received_message_size); + GPR_ASSERT(memcmp(config->client_message, server_received_message, + server_received_message_size) == 0); + /* Server sends a message to client. */ + uint8_t* saved_server_message = config->server_message; + config->server_message = const_cast(server_message); + config->server_message_size = server_message_size; + tsi_test_frame_protector_send_message_to_peer(config, channel, + server_frame_protector, + /*is_client=*/false); + /* Verify if the generated frame is the same as the expected. 
*/ + GPR_ASSERT(channel->bytes_written_to_client_channel == server_frame_size); + GPR_ASSERT(memcmp(server_expected_frames, channel->client_channel, + server_frame_size) == 0); + unsigned char* client_received_message = + static_cast(gpr_malloc(kChannelSize)); + size_t client_received_message_size = 0; + tsi_test_frame_protector_receive_message_from_peer( + config, channel, client_frame_protector, client_received_message, + &client_received_message_size, + /*is_client=*/true); + GPR_ASSERT(config->server_message_size == client_received_message_size); + GPR_ASSERT(memcmp(config->server_message, client_received_message, + client_received_message_size) == 0); + config->client_message = saved_client_message; + config->server_message = saved_server_message; + /* Destroy server and client frame protectors. */ + gpr_free(server_received_message); + gpr_free(client_received_message); +} + +static void alts_test_do_round_trip_vector_tests() { + const uint8_t key[] = {0xfe, 0xff, 0xe9, 0x92, 0x86, 0x65, 0x73, 0x1c, + 0x6d, 0x6a, 0x8f, 0x94, 0x67, 0x30, 0x83, 0x08}; + const char small_message[] = {'C', 'h', 'a', 'p', 'i', ' ', + 'C', 'h', 'a', 'p', 'o'}; + const uint8_t large_message[] = { + 0xd9, 0x31, 0x32, 0x25, 0xf8, 0x84, 0x06, 0xe5, 0xa5, 0x59, 0x09, 0xc5, + 0xaf, 0xf5, 0x26, 0x9a, 0x86, 0xa7, 0xa9, 0x53, 0x15, 0x34, 0xf7, 0xda, + 0x2e, 0x4c, 0x30, 0x3d, 0x8a, 0x31, 0x8a, 0x72, 0x1c, 0x3c, 0x0c, 0x95, + 0x95, 0x68, 0x09, 0x53, 0x2f, 0xcf, 0x0e, 0x24, 0x49, 0xa6, 0xb5, 0x25, + 0xb1, 0x6a, 0xed, 0xf5, 0xaa, 0x0d, 0xe6, 0x57, 0xba, 0x63, 0x7b, 0x39, + 0x1a, 0xaf, 0xd2, 0x55, 0xd6, 0x09, 0xb1, 0xf0, 0x56, 0x63, 0x7a, 0x0d, + 0x46, 0xdf, 0x99, 0x8d, 0x88, 0xe5, 0x22, 0x2a, 0xb2, 0xc2, 0x84, 0x65, + 0x12, 0x15, 0x35, 0x24, 0xc0, 0x89, 0x5e, 0x81, 0x08, 0x06, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, + 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30}; + const size_t small_message_size = sizeof(small_message) / sizeof(uint8_t); + const size_t large_message_size = sizeof(large_message) / sizeof(uint8_t); + /* Test small client message and large server message. 
*/ + const uint8_t client_expected_frame1[] = { + 0x1f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x09, 0xd8, 0xd5, 0x92, + 0x4d, 0x50, 0x32, 0xb7, 0x1f, 0xb8, 0xf2, 0xbb, 0x43, 0xc7, 0xe2, 0x94, + 0x3d, 0x3e, 0x9a, 0x78, 0x76, 0xaa, 0x0a, 0x6b, 0xfa, 0x98, 0x3a}; + const uint8_t server_expected_frame1[] = { + 0x94, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0xa9, 0x4b, 0xf8, 0xc8, + 0xe7, 0x8f, 0x1a, 0x26, 0x37, 0x44, 0xa2, 0x5c, 0x55, 0x94, 0x30, 0x4e, + 0x3e, 0x16, 0xe7, 0x9e, 0x96, 0xe8, 0x1b, 0xc0, 0xdd, 0x52, 0x30, 0x06, + 0xc2, 0x72, 0x9a, 0xa1, 0x0b, 0xdb, 0xdc, 0x19, 0x8c, 0x93, 0x5e, 0x84, + 0x1f, 0x4b, 0x97, 0x26, 0xf0, 0x73, 0x85, 0x59, 0x00, 0x95, 0xc1, 0xc5, + 0x22, 0x2f, 0x70, 0x85, 0x68, 0x2c, 0x4f, 0xfe, 0x30, 0x26, 0x91, 0xde, + 0x62, 0x55, 0x1d, 0x35, 0x01, 0x96, 0x1c, 0xe7, 0xa2, 0x8b, 0x14, 0x8a, + 0x5e, 0x1b, 0x4a, 0x3b, 0x4f, 0x65, 0x0f, 0xca, 0x79, 0x10, 0xb4, 0xdd, + 0xf7, 0xa4, 0x8b, 0x64, 0x2f, 0x00, 0x39, 0x60, 0x03, 0xfc, 0xe1, 0x8b, + 0x5c, 0x19, 0xba, 0xcc, 0x46, 0xba, 0x88, 0xdd, 0x40, 0x42, 0x27, 0x4f, + 0xe4, 0x1a, 0x6a, 0x31, 0x6c, 0x1c, 0xb0, 0xb6, 0x5c, 0x3e, 0xca, 0x84, + 0x9b, 0x5f, 0x04, 0x84, 0x11, 0xa9, 0xf8, 0x39, 0xe7, 0xe7, 0xc5, 0xc4, + 0x33, 0x9f, 0x63, 0x21, 0x9a, 0x7c, 0x9c, 0x64}; + const size_t client_frame_size1 = + sizeof(client_expected_frame1) / sizeof(uint8_t); + const size_t server_frame_size1 = + sizeof(server_expected_frame1) / sizeof(uint8_t); + tsi_test_frame_protector_fixture* fixture = + tsi_test_frame_protector_fixture_create(); + alts_test_do_round_trip_check_frames( + fixture, key, kAes128GcmKeyLength, /*rekey=*/false, + reinterpret_cast(small_message), small_message_size, + client_expected_frame1, client_frame_size1, large_message, + large_message_size, server_expected_frame1, server_frame_size1); + tsi_test_frame_protector_fixture_destroy(fixture); + /** + * Test large client message, small server message, and small + * message_buffer_allocated_size. 
+ */ + const uint8_t client_expected_frame2[] = { + 0x94, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x93, 0x81, 0x86, 0xc7, + 0xdc, 0xf4, 0x77, 0x3a, 0xdb, 0x91, 0x94, 0x61, 0xba, 0xed, 0xd5, 0x37, + 0x47, 0x53, 0x0c, 0xe1, 0xbf, 0x59, 0x23, 0x20, 0xde, 0x8b, 0x25, 0x13, + 0x72, 0xe7, 0x8a, 0x4f, 0x32, 0x61, 0xc6, 0xda, 0xc3, 0xe9, 0xff, 0x31, + 0x33, 0x53, 0x4a, 0xf8, 0xc9, 0x98, 0xe4, 0x19, 0x71, 0x9c, 0x5e, 0x72, + 0xc7, 0x35, 0x97, 0x78, 0x30, 0xf2, 0xc4, 0xd1, 0x53, 0xd5, 0x6e, 0x8f, + 0x4f, 0xd9, 0x28, 0x5a, 0xfd, 0x22, 0x57, 0x7f, 0x95, 0xb4, 0x8a, 0x5e, + 0x7c, 0x47, 0xa8, 0xcf, 0x64, 0x3d, 0x83, 0xa5, 0xcf, 0xc3, 0xfe, 0x54, + 0xc2, 0x6a, 0x40, 0xc4, 0xfb, 0x8e, 0x07, 0x77, 0x70, 0x8f, 0x99, 0x94, + 0xb1, 0xd5, 0xa7, 0xf9, 0x0d, 0xc7, 0x11, 0xc5, 0x6f, 0x4a, 0x4f, 0x56, + 0xd5, 0xe2, 0x9c, 0xbb, 0x95, 0x7a, 0xd0, 0x9f, 0x30, 0x54, 0xca, 0x6d, + 0x5c, 0x8e, 0x83, 0xa0, 0x04, 0x5e, 0xd0, 0x22, 0x8c, 0x2a, 0x7f, 0xdb, + 0xfe, 0xb3, 0x2e, 0xae, 0x22, 0xe6, 0xf4, 0xb7}; + const uint8_t server_expected_frame2[] = { + 0x1f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x33, 0x12, 0xab, 0x9d, + 0x76, 0x2b, 0x5f, 0xab, 0xf3, 0x6d, 0xc4, 0xaa, 0xe5, 0x1e, 0x63, 0xc1, + 0x7b, 0x7b, 0x10, 0xd5, 0x63, 0x0f, 0x29, 0xad, 0x17, 0x33, 0x73}; + const size_t client_frame_size2 = + sizeof(client_expected_frame2) / sizeof(uint8_t); + const size_t server_frame_size2 = + sizeof(server_expected_frame2) / sizeof(uint8_t); + fixture = tsi_test_frame_protector_fixture_create(); + alts_test_do_round_trip_check_frames( + fixture, key, kAes128GcmKeyLength, /*rekey=*/false, large_message, + large_message_size, client_expected_frame2, client_frame_size2, + reinterpret_cast(small_message), small_message_size, + server_expected_frame2, server_frame_size2); + tsi_test_frame_protector_fixture_destroy(fixture); + /** + * Test large client message, small server message, and small + * protected_buffer_size. 
+ */ + const uint8_t client_expected_frame3[] = { + 0x94, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x93, 0x81, 0x86, 0xc7, + 0xdc, 0xf4, 0x77, 0x3a, 0xdb, 0x91, 0x94, 0x61, 0xba, 0xed, 0xd5, 0x37, + 0x47, 0x53, 0x0c, 0xe1, 0xbf, 0x59, 0x23, 0x20, 0xde, 0x8b, 0x25, 0x13, + 0x72, 0xe7, 0x8a, 0x4f, 0x32, 0x61, 0xc6, 0xda, 0xc3, 0xe9, 0xff, 0x31, + 0x33, 0x53, 0x4a, 0xf8, 0xc9, 0x98, 0xe4, 0x19, 0x71, 0x9c, 0x5e, 0x72, + 0xc7, 0x35, 0x97, 0x78, 0x30, 0xf2, 0xc4, 0xd1, 0x53, 0xd5, 0x6e, 0x8f, + 0x4f, 0xd9, 0x28, 0x5a, 0xfd, 0x22, 0x57, 0x7f, 0x95, 0xb4, 0x8a, 0x5e, + 0x7c, 0x47, 0xa8, 0xcf, 0x64, 0x3d, 0x83, 0xa5, 0xcf, 0xc3, 0xfe, 0x54, + 0xc2, 0x6a, 0x40, 0xc4, 0xfb, 0x8e, 0x07, 0x77, 0x70, 0x8f, 0x99, 0x94, + 0xb1, 0xd5, 0xa7, 0xf9, 0x0d, 0xc7, 0x11, 0xc5, 0x6f, 0x4a, 0x4f, 0x56, + 0xd5, 0xe2, 0x9c, 0xbb, 0x95, 0x7a, 0xd0, 0x9f, 0x30, 0x54, 0xca, 0x6d, + 0x5c, 0x8e, 0x83, 0xa0, 0x04, 0x5e, 0xd0, 0x22, 0x8c, 0x2a, 0x7f, 0xdb, + 0xfe, 0xb3, 0x2e, 0xae, 0x22, 0xe6, 0xf4, 0xb7}; + const uint8_t server_expected_frame3[] = { + 0x1f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x33, 0x12, 0xab, 0x9d, + 0x76, 0x2b, 0x5f, 0xab, 0xf3, 0x6d, 0xc4, 0xaa, 0xe5, 0x1e, 0x63, 0xc1, + 0x7b, 0x7b, 0x10, 0xd5, 0x63, 0x0f, 0x29, 0xad, 0x17, 0x33, 0x73}; + const size_t client_frame_size3 = + sizeof(client_expected_frame3) / sizeof(uint8_t); + const size_t server_frame_size3 = + sizeof(server_expected_frame3) / sizeof(uint8_t); + fixture = tsi_test_frame_protector_fixture_create(); + alts_test_do_round_trip_check_frames( + fixture, key, kAes128GcmKeyLength, /*rekey=*/false, large_message, + large_message_size, client_expected_frame3, client_frame_size3, + reinterpret_cast(small_message), small_message_size, + server_expected_frame3, server_frame_size3); + tsi_test_frame_protector_fixture_destroy(fixture); + /** + * Test large client message, small server message, and small + * read_buffer_allocated_size. 
+ */ + const uint8_t client_expected_frame4[] = { + 0x94, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x93, 0x81, 0x86, 0xc7, + 0xdc, 0xf4, 0x77, 0x3a, 0xdb, 0x91, 0x94, 0x61, 0xba, 0xed, 0xd5, 0x37, + 0x47, 0x53, 0x0c, 0xe1, 0xbf, 0x59, 0x23, 0x20, 0xde, 0x8b, 0x25, 0x13, + 0x72, 0xe7, 0x8a, 0x4f, 0x32, 0x61, 0xc6, 0xda, 0xc3, 0xe9, 0xff, 0x31, + 0x33, 0x53, 0x4a, 0xf8, 0xc9, 0x98, 0xe4, 0x19, 0x71, 0x9c, 0x5e, 0x72, + 0xc7, 0x35, 0x97, 0x78, 0x30, 0xf2, 0xc4, 0xd1, 0x53, 0xd5, 0x6e, 0x8f, + 0x4f, 0xd9, 0x28, 0x5a, 0xfd, 0x22, 0x57, 0x7f, 0x95, 0xb4, 0x8a, 0x5e, + 0x7c, 0x47, 0xa8, 0xcf, 0x64, 0x3d, 0x83, 0xa5, 0xcf, 0xc3, 0xfe, 0x54, + 0xc2, 0x6a, 0x40, 0xc4, 0xfb, 0x8e, 0x07, 0x77, 0x70, 0x8f, 0x99, 0x94, + 0xb1, 0xd5, 0xa7, 0xf9, 0x0d, 0xc7, 0x11, 0xc5, 0x6f, 0x4a, 0x4f, 0x56, + 0xd5, 0xe2, 0x9c, 0xbb, 0x95, 0x7a, 0xd0, 0x9f, 0x30, 0x54, 0xca, 0x6d, + 0x5c, 0x8e, 0x83, 0xa0, 0x04, 0x5e, 0xd0, 0x22, 0x8c, 0x2a, 0x7f, 0xdb, + 0xfe, 0xb3, 0x2e, 0xae, 0x22, 0xe6, 0xf4, 0xb7}; + const uint8_t server_expected_frame4[] = { + 0x1f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x33, 0x12, 0xab, 0x9d, + 0x76, 0x2b, 0x5f, 0xab, 0xf3, 0x6d, 0xc4, 0xaa, 0xe5, 0x1e, 0x63, 0xc1, + 0x7b, 0x7b, 0x10, 0xd5, 0x63, 0x0f, 0x29, 0xad, 0x17, 0x33, 0x73}; + const size_t client_frame_size4 = + sizeof(client_expected_frame4) / sizeof(uint8_t); + const size_t server_frame_size4 = + sizeof(server_expected_frame4) / sizeof(uint8_t); + fixture = tsi_test_frame_protector_fixture_create(); + alts_test_do_round_trip_check_frames( + fixture, key, kAes128GcmKeyLength, /*rekey=*/false, large_message, + large_message_size, client_expected_frame4, client_frame_size4, + reinterpret_cast(small_message), small_message_size, + server_expected_frame4, server_frame_size4); + tsi_test_frame_protector_fixture_destroy(fixture); + /** + * Test large client message, small server message, and small + * client_max_output_protected_frame_size. 
+ */ + const uint8_t client_expected_frame5[] = { + 0x94, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x93, 0x81, 0x86, 0xc7, + 0xdc, 0xf4, 0x77, 0x3a, 0xdb, 0x91, 0x94, 0x61, 0xba, 0xed, 0xd5, 0x37, + 0x47, 0x53, 0x0c, 0xe1, 0xbf, 0x59, 0x23, 0x20, 0xde, 0x8b, 0x25, 0x13, + 0x72, 0xe7, 0x8a, 0x4f, 0x32, 0x61, 0xc6, 0xda, 0xc3, 0xe9, 0xff, 0x31, + 0x33, 0x53, 0x4a, 0xf8, 0xc9, 0x98, 0xe4, 0x19, 0x71, 0x9c, 0x5e, 0x72, + 0xc7, 0x35, 0x97, 0x78, 0x30, 0xf2, 0xc4, 0xd1, 0x53, 0xd5, 0x6e, 0x8f, + 0x4f, 0xd9, 0x28, 0x5a, 0xfd, 0x22, 0x57, 0x7f, 0x95, 0xb4, 0x8a, 0x5e, + 0x7c, 0x47, 0xa8, 0xcf, 0x64, 0x3d, 0x83, 0xa5, 0xcf, 0xc3, 0xfe, 0x54, + 0xc2, 0x6a, 0x40, 0xc4, 0xfb, 0x8e, 0x07, 0x77, 0x70, 0x8f, 0x99, 0x94, + 0xb1, 0xd5, 0xa7, 0xf9, 0x0d, 0xc7, 0x11, 0xc5, 0x6f, 0x4a, 0x4f, 0x56, + 0xd5, 0xe2, 0x9c, 0xbb, 0x95, 0x7a, 0xd0, 0x9f, 0x30, 0x54, 0xca, 0x6d, + 0x5c, 0x8e, 0x83, 0xa0, 0x04, 0x5e, 0xd0, 0x22, 0x8c, 0x2a, 0x7f, 0xdb, + 0xfe, 0xb3, 0x2e, 0xae, 0x22, 0xe6, 0xf4, 0xb7}; + const uint8_t server_expected_frame5[] = { + 0x1f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x33, 0x12, 0xab, 0x9d, + 0x76, 0x2b, 0x5f, 0xab, 0xf3, 0x6d, 0xc4, 0xaa, 0xe5, 0x1e, 0x63, 0xc1, + 0x7b, 0x7b, 0x10, 0xd5, 0x63, 0x0f, 0x29, 0xad, 0x17, 0x33, 0x73}; + const size_t client_frame_size5 = + sizeof(client_expected_frame5) / sizeof(uint8_t); + const size_t server_frame_size5 = + sizeof(server_expected_frame5) / sizeof(uint8_t); + fixture = tsi_test_frame_protector_fixture_create(); + alts_test_do_round_trip_check_frames( + fixture, key, kAes128GcmKeyLength, /*rekey=*/false, large_message, + large_message_size, client_expected_frame5, client_frame_size5, + reinterpret_cast(small_message), small_message_size, + server_expected_frame5, server_frame_size5); + tsi_test_frame_protector_fixture_destroy(fixture); + /** + * Test small client message, large server message, and small + * server_max_output_protected_frame_size. 
+ */ + const uint8_t client_expected_frame6[] = { + 0x1f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x09, 0xd8, 0xd5, 0x92, + 0x4d, 0x50, 0x32, 0xb7, 0x1f, 0xb8, 0xf2, 0xbb, 0x43, 0xc7, 0xe2, 0x94, + 0x3d, 0x3e, 0x9a, 0x78, 0x76, 0xaa, 0x0a, 0x6b, 0xfa, 0x98, 0x3a}; + const uint8_t server_expected_frame6[] = { + 0x94, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0xa9, 0x4b, 0xf8, 0xc8, + 0xe7, 0x8f, 0x1a, 0x26, 0x37, 0x44, 0xa2, 0x5c, 0x55, 0x94, 0x30, 0x4e, + 0x3e, 0x16, 0xe7, 0x9e, 0x96, 0xe8, 0x1b, 0xc0, 0xdd, 0x52, 0x30, 0x06, + 0xc2, 0x72, 0x9a, 0xa1, 0x0b, 0xdb, 0xdc, 0x19, 0x8c, 0x93, 0x5e, 0x84, + 0x1f, 0x4b, 0x97, 0x26, 0xf0, 0x73, 0x85, 0x59, 0x00, 0x95, 0xc1, 0xc5, + 0x22, 0x2f, 0x70, 0x85, 0x68, 0x2c, 0x4f, 0xfe, 0x30, 0x26, 0x91, 0xde, + 0x62, 0x55, 0x1d, 0x35, 0x01, 0x96, 0x1c, 0xe7, 0xa2, 0x8b, 0x14, 0x8a, + 0x5e, 0x1b, 0x4a, 0x3b, 0x4f, 0x65, 0x0f, 0xca, 0x79, 0x10, 0xb4, 0xdd, + 0xf7, 0xa4, 0x8b, 0x64, 0x2f, 0x00, 0x39, 0x60, 0x03, 0xfc, 0xe1, 0x8b, + 0x5c, 0x19, 0xba, 0xcc, 0x46, 0xba, 0x88, 0xdd, 0x40, 0x42, 0x27, 0x4f, + 0xe4, 0x1a, 0x6a, 0x31, 0x6c, 0x1c, 0xb0, 0xb6, 0x5c, 0x3e, 0xca, 0x84, + 0x9b, 0x5f, 0x04, 0x84, 0x11, 0xa9, 0xf8, 0x39, 0xe7, 0xe7, 0xc5, 0xc4, + 0x33, 0x9f, 0x63, 0x21, 0x9a, 0x7c, 0x9c, 0x64}; + const size_t client_frame_size6 = + sizeof(client_expected_frame6) / sizeof(uint8_t); + const size_t server_frame_size6 = + sizeof(server_expected_frame6) / sizeof(uint8_t); + fixture = tsi_test_frame_protector_fixture_create(); + alts_test_do_round_trip_check_frames( + fixture, key, kAes128GcmKeyLength, /*rekey=*/false, + reinterpret_cast(small_message), small_message_size, + client_expected_frame6, client_frame_size6, large_message, + large_message_size, server_expected_frame6, server_frame_size6); + tsi_test_frame_protector_fixture_destroy(fixture); +} + +static void alts_test_do_round_trip(tsi_test_frame_protector_fixture* fixture, + bool rekey) { + GPR_ASSERT(fixture != nullptr); + GPR_ASSERT(fixture->config != nullptr); + tsi_frame_protector* client_frame_protector = nullptr; + tsi_frame_protector* server_frame_protector = nullptr; + tsi_test_frame_protector_config* config = fixture->config; + /* Create a key to be used by both client and server. */ + uint8_t* key = nullptr; + size_t key_length = rekey ? kAes128GcmRekeyKeyLength : kAes128GcmKeyLength; + gsec_test_random_array(&key, key_length); + /* Create a client frame protector. */ + size_t client_max_output_protected_frame_size = + config->client_max_output_protected_frame_size; + GPR_ASSERT( + alts_create_frame_protector(key, key_length, /*is_client=*/true, rekey, + client_max_output_protected_frame_size == 0 + ? nullptr + : &client_max_output_protected_frame_size, + &client_frame_protector) == TSI_OK); + /* Create a server frame protector. */ + size_t server_max_output_protected_frame_size = + config->server_max_output_protected_frame_size; + GPR_ASSERT( + alts_create_frame_protector(key, key_length, /*is_client=*/false, rekey, + server_max_output_protected_frame_size == 0 + ? nullptr + : &server_max_output_protected_frame_size, + &server_frame_protector) == TSI_OK); + tsi_test_frame_protector_fixture_init(fixture, client_frame_protector, + server_frame_protector); + tsi_test_frame_protector_do_round_trip_no_handshake(fixture); + gpr_free(key); +} + +/* Run all combinations of different arguments of test config. 
*/ +static void alts_test_do_round_trip_all(bool rekey) { + unsigned int* bit_array = static_cast( + gpr_malloc(sizeof(unsigned int) * TSI_TEST_NUM_OF_ARGUMENTS)); + unsigned int mask = 1U << (TSI_TEST_NUM_OF_ARGUMENTS - 1); + unsigned int val = 0, ind = 0; + for (val = 0; val < TSI_TEST_NUM_OF_COMBINATIONS; val++) { + unsigned int v = val; + for (ind = 0; ind < TSI_TEST_NUM_OF_ARGUMENTS; ind++) { + bit_array[ind] = (v & mask) ? 1 : 0; + v <<= 1; + } + tsi_test_frame_protector_fixture* fixture = + tsi_test_frame_protector_fixture_create(); + tsi_test_frame_protector_config_destroy(fixture->config); + fixture->config = tsi_test_frame_protector_config_create( + bit_array[0], bit_array[1], bit_array[2], bit_array[3], bit_array[4], + bit_array[5], bit_array[6]); + alts_test_do_round_trip(fixture, rekey); + tsi_test_frame_protector_fixture_destroy(fixture); + } + gpr_free(bit_array); +} + +int main(int /*argc*/, char** /*argv*/) { + alts_test_do_round_trip_vector_tests(); + alts_test_do_round_trip_all(/*rekey=*/false); + alts_test_do_round_trip_all(/*rekey=*/true); + return 0; +} diff --git a/test/core/tsi/alts/frame_protector/frame_handler_test.cc b/test/core/tsi/alts/frame_protector/frame_handler_test.cc new file mode 100644 index 00000000..a8324aba --- /dev/null +++ b/test/core/tsi/alts/frame_protector/frame_handler_test.cc @@ -0,0 +1,248 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/tsi/alts/frame_protector/frame_handler.h" + +#include +#include +#include + +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/tsi/alts/crypt/gsec_test_util.h" + +const size_t kFrameHandlerTestBufferSize = 1024; + +typedef struct frame_handler { + alts_frame_writer* writer; + alts_frame_reader* reader; + unsigned char* buffer; + size_t buffer_size; +} frame_handler; + +static size_t frame_length(size_t payload_length) { + return payload_length + kFrameHeaderSize; +} + +static frame_handler* create_frame_handler() { + frame_handler* handler = + static_cast(gpr_malloc(sizeof(frame_handler))); + handler->writer = alts_create_frame_writer(); + handler->reader = alts_create_frame_reader(); + handler->buffer = nullptr; + handler->buffer_size = 0; + return handler; +} + +static void destroy_frame_handler(frame_handler* handler) { + if (handler != nullptr) { + alts_destroy_frame_reader(handler->reader); + alts_destroy_frame_writer(handler->writer); + if (handler->buffer != nullptr) gpr_free(handler->buffer); + gpr_free(handler); + } +} + +static void frame(frame_handler* handler, unsigned char* payload, + size_t payload_length, size_t write_length) { + handler->buffer_size = frame_length(payload_length); + handler->buffer = + static_cast(gpr_malloc(handler->buffer_size)); + GPR_ASSERT(alts_reset_frame_writer(handler->writer, payload, payload_length)); + size_t offset = 0; + while (offset < handler->buffer_size && + !alts_is_frame_writer_done(handler->writer)) { + size_t bytes_written = + std::min(write_length, handler->buffer_size - offset); + GPR_ASSERT(alts_write_frame_bytes(handler->writer, handler->buffer + offset, + &bytes_written)); + offset += bytes_written; + } + GPR_ASSERT(alts_is_frame_writer_done(handler->writer)); + GPR_ASSERT(handler->buffer_size == offset); +} + +static size_t deframe(frame_handler* handler, unsigned char* bytes, + size_t read_length) { + GPR_ASSERT(alts_reset_frame_reader(handler->reader, bytes)); + size_t offset = 0; + while (offset < handler->buffer_size && + !alts_is_frame_reader_done(handler->reader)) { + size_t bytes_read = std::min(read_length, handler->buffer_size - offset); + GPR_ASSERT(alts_read_frame_bytes(handler->reader, handler->buffer + offset, + &bytes_read)); + offset += bytes_read; + } + GPR_ASSERT(alts_is_frame_reader_done(handler->reader)); + GPR_ASSERT(handler->buffer_size == offset); + return offset - handler->reader->header_bytes_read; +} + +static void frame_n_deframe(frame_handler* handler, unsigned char* payload, + size_t payload_length, size_t write_length, + size_t read_length) { + frame(handler, payload, payload_length, write_length); + unsigned char* bytes = + static_cast(gpr_malloc(kFrameHandlerTestBufferSize)); + size_t deframed_payload_length = deframe(handler, bytes, read_length); + GPR_ASSERT(payload_length == deframed_payload_length); + GPR_ASSERT(memcmp(payload, bytes, payload_length) == 0); + gpr_free(bytes); +} + +static void frame_handler_test_frame_deframe() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = create_frame_handler(); + frame_n_deframe(handler, payload, payload_length, + frame_length(payload_length), frame_length(payload_length)); + destroy_frame_handler(handler); +} + +static void frame_handler_test_small_buffer() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = 
create_frame_handler(); + frame_n_deframe(handler, payload, payload_length, 1, 1); + destroy_frame_handler(handler); +} + +static void frame_handler_test_null_input_stream() { + frame_handler* handler = create_frame_handler(); + GPR_ASSERT(!alts_reset_frame_writer(handler->writer, nullptr, 0)); + destroy_frame_handler(handler); +} + +static void frame_handler_test_bad_input_length() { + unsigned char payload[] = "hello world"; + frame_handler* handler = create_frame_handler(); + GPR_ASSERT(!alts_reset_frame_writer(handler->writer, payload, SIZE_MAX)); + destroy_frame_handler(handler); +} + +static void frame_handler_test_null_writer_byte_length() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = create_frame_handler(); + GPR_ASSERT(alts_reset_frame_writer(handler->writer, payload, payload_length)); + GPR_ASSERT( + !alts_write_frame_bytes(handler->writer, handler->buffer, nullptr)); + destroy_frame_handler(handler); +} + +static void frame_handler_test_null_writer_bytes() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = create_frame_handler(); + GPR_ASSERT(alts_reset_frame_writer(handler->writer, payload, payload_length)); + GPR_ASSERT( + !alts_write_frame_bytes(handler->writer, nullptr, &payload_length)); + destroy_frame_handler(handler); +} + +static void frame_handler_test_bad_frame_length() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = create_frame_handler(); + frame(handler, payload, payload_length, payload_length); + memset(handler->buffer, 0x00, kFrameLengthFieldSize); + unsigned char* bytes = + static_cast(gpr_malloc(kFrameHandlerTestBufferSize)); + GPR_ASSERT(alts_reset_frame_reader(handler->reader, bytes)); + size_t bytes_read = handler->buffer_size; + GPR_ASSERT( + !alts_read_frame_bytes(handler->reader, handler->buffer, &bytes_read)); + GPR_ASSERT(alts_is_frame_reader_done(handler->reader)); + GPR_ASSERT(bytes_read == 0); + gpr_free(bytes); + destroy_frame_handler(handler); +} + +static void frame_handler_test_unsupported_message_type() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = create_frame_handler(); + frame(handler, payload, payload_length, payload_length); + memset(handler->buffer + kFrameLengthFieldSize, 0x00, + kFrameMessageTypeFieldSize); + unsigned char* bytes = + static_cast(gpr_malloc(kFrameHandlerTestBufferSize)); + GPR_ASSERT(alts_reset_frame_reader(handler->reader, bytes)); + size_t bytes_read = handler->buffer_size; + GPR_ASSERT( + !alts_read_frame_bytes(handler->reader, handler->buffer, &bytes_read)); + GPR_ASSERT(alts_is_frame_reader_done(handler->reader)); + GPR_ASSERT(bytes_read == 0); + gpr_free(bytes); + destroy_frame_handler(handler); +} + +static void frame_handler_test_null_output_stream() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = create_frame_handler(); + frame(handler, payload, payload_length, payload_length); + GPR_ASSERT(!alts_reset_frame_reader(handler->reader, nullptr)); + destroy_frame_handler(handler); +} + +static void frame_handler_test_null_reader_byte_length() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = 
create_frame_handler(); + frame(handler, payload, payload_length, payload_length); + unsigned char* bytes = + static_cast(gpr_malloc(kFrameHandlerTestBufferSize)); + GPR_ASSERT(alts_reset_frame_reader(handler->reader, bytes)); + GPR_ASSERT(!alts_read_frame_bytes(handler->reader, handler->buffer, nullptr)); + gpr_free(bytes); + destroy_frame_handler(handler); +} + +static void frame_handler_test_null_reader_bytes() { + unsigned char payload[] = "hello world"; + size_t payload_length = strlen(reinterpret_cast(payload)) + 1; + frame_handler* handler = create_frame_handler(); + frame(handler, payload, payload_length, payload_length); + unsigned char* bytes = + static_cast(gpr_malloc(kFrameHandlerTestBufferSize)); + GPR_ASSERT(alts_reset_frame_reader(handler->reader, bytes)); + size_t bytes_read = handler->buffer_size; + GPR_ASSERT(!alts_read_frame_bytes(handler->reader, nullptr, &bytes_read)); + gpr_free(bytes); + destroy_frame_handler(handler); +} + +int main(int /*argc*/, char** /*argv*/) { + frame_handler_test_frame_deframe(); + frame_handler_test_small_buffer(); + frame_handler_test_null_input_stream(); + frame_handler_test_bad_input_length(); + frame_handler_test_null_writer_byte_length(); + frame_handler_test_null_writer_bytes(); + frame_handler_test_bad_frame_length(); + frame_handler_test_unsupported_message_type(); + frame_handler_test_null_output_stream(); + frame_handler_test_null_reader_byte_length(); + frame_handler_test_null_reader_bytes(); + return 0; +} diff --git a/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc b/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc new file mode 100644 index 00000000..a5a12698 --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_concurrent_connectivity_test.cc @@ -0,0 +1,660 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "src/core/lib/gpr/useful.h"
+#include "src/core/lib/gprpp/host_port.h"
+#include "src/core/lib/gprpp/thd.h"
+#include "src/core/lib/iomgr/error.h"
+#include "src/core/lib/security/credentials/alts/alts_credentials.h"
+#include "src/core/lib/security/credentials/credentials.h"
+#include "src/core/lib/security/security_connector/alts/alts_security_connector.h"
+#include "src/core/lib/slice/slice_string_helpers.h"
+#include "test/core/end2end/cq_verifier.h"
+#include "test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h"
+#include "test/core/util/memory_counters.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+
+namespace {
+
+const int kFakeHandshakeServerMaxConcurrentStreams = 40;
+
+void drain_cq(grpc_completion_queue* cq) {
+  grpc_event ev;
+  do {
+    ev = grpc_completion_queue_next(
+        cq, grpc_timeout_milliseconds_to_deadline(5000), nullptr);
+  } while (ev.type != GRPC_QUEUE_SHUTDOWN);
+}
+
+grpc_channel* create_secure_channel_for_test(
+    const char* server_addr, const char* fake_handshake_server_addr,
+    int reconnect_backoff_ms) {
+  grpc_alts_credentials_options* alts_options =
+      grpc_alts_credentials_client_options_create();
+  grpc_channel_credentials* channel_creds =
+      grpc_alts_credentials_create_customized(alts_options,
+                                              fake_handshake_server_addr,
+                                              true /* enable_untrusted_alts */);
+  grpc_alts_credentials_options_destroy(alts_options);
+  // The main goal of these tests is to stress concurrent ALTS handshakes,
+  // so we prevent subchannel sharing.
+  std::vector<grpc_arg> new_args;
+  new_args.push_back(grpc_channel_arg_integer_create(
+      const_cast<char*>(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL), true));
+  if (reconnect_backoff_ms != 0) {
+    new_args.push_back(grpc_channel_arg_integer_create(
+        const_cast<char*>("grpc.testing.fixed_reconnect_backoff_ms"),
+        reconnect_backoff_ms));
+  }
+  grpc_channel_args* channel_args =
+      grpc_channel_args_copy_and_add(nullptr, new_args.data(), new_args.size());
+  grpc_channel* channel = grpc_secure_channel_create(channel_creds, server_addr,
+                                                     channel_args, nullptr);
+  grpc_channel_args_destroy(channel_args);
+  grpc_channel_credentials_release(channel_creds);
+  return channel;
+}
+
+class FakeHandshakeServer {
+ public:
+  explicit FakeHandshakeServer(bool check_num_concurrent_rpcs) {
+    int port = grpc_pick_unused_port_or_die();
+    address_ = grpc_core::JoinHostPort("localhost", port);
+    if (check_num_concurrent_rpcs) {
+      service_ = grpc::gcp::
+          CreateFakeHandshakerService(kFakeHandshakeServerMaxConcurrentStreams /* expected max concurrent rpcs */);
+    } else {
+      service_ = grpc::gcp::CreateFakeHandshakerService(
+          0 /* expected max concurrent rpcs unset */);
+    }
+    grpc::ServerBuilder builder;
+    builder.AddListeningPort(address_.c_str(),
+                             grpc::InsecureServerCredentials());
+    builder.RegisterService(service_.get());
+    // TODO(apolcyn): when removing the global concurrent handshake limiting
+    // queue, set MAX_CONCURRENT_STREAMS on this server.
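+    // Illustrative sketch only (not part of this change): once that TODO is
+    // addressed, the cap could presumably be applied as a channel argument,
+    // e.g.
+    //   builder.AddChannelArgument(GRPC_ARG_MAX_CONCURRENT_STREAMS,
+    //                              kFakeHandshakeServerMaxConcurrentStreams);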
+ server_ = builder.BuildAndStart(); + gpr_log(GPR_INFO, "Fake handshaker server listening on %s", + address_.c_str()); + } + + ~FakeHandshakeServer() { + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + } + + const char* address() { return address_.c_str(); } + + private: + std::string address_; + std::unique_ptr service_; + std::unique_ptr server_; +}; + +class TestServer { + public: + explicit TestServer() + : fake_handshake_server_(true /* check num concurrent rpcs */) { + grpc_alts_credentials_options* alts_options = + grpc_alts_credentials_server_options_create(); + grpc_server_credentials* server_creds = + grpc_alts_server_credentials_create_customized( + alts_options, fake_handshake_server_.address(), + true /* enable_untrusted_alts */); + grpc_alts_credentials_options_destroy(alts_options); + server_ = grpc_server_create(nullptr, nullptr); + server_cq_ = grpc_completion_queue_create_for_next(nullptr); + grpc_server_register_completion_queue(server_, server_cq_, nullptr); + int port = grpc_pick_unused_port_or_die(); + server_addr_ = grpc_core::JoinHostPort("localhost", port); + GPR_ASSERT(grpc_server_add_secure_http2_port(server_, server_addr_.c_str(), + server_creds)); + grpc_server_credentials_release(server_creds); + grpc_server_start(server_); + gpr_log(GPR_DEBUG, "Start TestServer %p. listen on %s", this, + server_addr_.c_str()); + server_thd_ = absl::make_unique(PollUntilShutdown, this); + } + + ~TestServer() { + gpr_log(GPR_DEBUG, "Begin dtor of TestServer %p", this); + grpc_server_shutdown_and_notify(server_, server_cq_, this); + server_thd_->join(); + grpc_server_destroy(server_); + grpc_completion_queue_shutdown(server_cq_); + drain_cq(server_cq_); + grpc_completion_queue_destroy(server_cq_); + } + + const char* address() { return server_addr_.c_str(); } + + static void PollUntilShutdown(const TestServer* self) { + grpc_event ev = grpc_completion_queue_next( + self->server_cq_, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type == GRPC_OP_COMPLETE); + GPR_ASSERT(ev.tag == self); + gpr_log(GPR_DEBUG, "TestServer %p stop polling", self); + } + + private: + grpc_server* server_; + grpc_completion_queue* server_cq_; + std::unique_ptr server_thd_; + std::string server_addr_; + // Give this test server its own ALTS handshake server + // so that we avoid competing for ALTS handshake server resources (e.g. + // available HTTP2 streams on a globally shared handshaker subchannel) + // with clients that are trying to do mutual ALTS handshakes + // with this server (which could "deadlock" mutual handshakes). + // TODO(apolcyn): remove this workaround from this test and have + // clients/servers share a single fake handshake server if + // the underlying issue needs to be fixed. 
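+  // This dedicated instance only serves the ALTS handshakes of the gRPC
+  // server configured in the constructor above; the test clients construct
+  // their own FakeHandshakeServer instances.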
+ FakeHandshakeServer fake_handshake_server_; +}; + +class ConnectLoopRunner { + public: + explicit ConnectLoopRunner( + const char* server_address, const char* fake_handshake_server_addr, + int per_connect_deadline_seconds, size_t loops, + grpc_connectivity_state expected_connectivity_states, + int reconnect_backoff_ms) + : server_address_(grpc_core::UniquePtr(gpr_strdup(server_address))), + fake_handshake_server_addr_( + grpc_core::UniquePtr(gpr_strdup(fake_handshake_server_addr))), + per_connect_deadline_seconds_(per_connect_deadline_seconds), + loops_(loops), + expected_connectivity_states_(expected_connectivity_states), + reconnect_backoff_ms_(reconnect_backoff_ms) { + thd_ = absl::make_unique(ConnectLoop, this); + } + + ~ConnectLoopRunner() { thd_->join(); } + + static void ConnectLoop(const ConnectLoopRunner* self) { + for (size_t i = 0; i < self->loops_; i++) { + gpr_log(GPR_DEBUG, "runner:%p connect_loop begin loop %ld", self, i); + grpc_completion_queue* cq = + grpc_completion_queue_create_for_next(nullptr); + grpc_channel* channel = create_secure_channel_for_test( + self->server_address_.get(), self->fake_handshake_server_addr_.get(), + self->reconnect_backoff_ms_); + // Connect, forcing an ALTS handshake + gpr_timespec connect_deadline = + grpc_timeout_seconds_to_deadline(self->per_connect_deadline_seconds_); + grpc_connectivity_state state = + grpc_channel_check_connectivity_state(channel, 1); + ASSERT_EQ(state, GRPC_CHANNEL_IDLE); + while (state != self->expected_connectivity_states_) { + if (self->expected_connectivity_states_ == + GRPC_CHANNEL_TRANSIENT_FAILURE) { + ASSERT_NE(state, GRPC_CHANNEL_READY); // sanity check + } else { + ASSERT_EQ(self->expected_connectivity_states_, GRPC_CHANNEL_READY); + } + grpc_channel_watch_connectivity_state( + channel, state, gpr_inf_future(GPR_CLOCK_REALTIME), cq, nullptr); + grpc_event ev = + grpc_completion_queue_next(cq, connect_deadline, nullptr); + ASSERT_EQ(ev.type, GRPC_OP_COMPLETE) + << "connect_loop runner:" << std::hex << self + << " got ev.type:" << ev.type << " i:" << i; + ASSERT_TRUE(ev.success); + grpc_connectivity_state prev_state = state; + state = grpc_channel_check_connectivity_state(channel, 1); + if (self->expected_connectivity_states_ == + GRPC_CHANNEL_TRANSIENT_FAILURE && + prev_state == GRPC_CHANNEL_CONNECTING && + state == GRPC_CHANNEL_CONNECTING) { + // Detect a race in state checking: if the watch_connectivity_state + // completed from prior state "connecting", this could be because the + // channel momentarily entered state "transient failure", which is + // what we want. However, if the channel immediately re-enters + // "connecting" state, then the new state check might still result in + // "connecting". A continuous repeat of this can cause this loop to + // never terminate in time. So take this scenario to indicate that the + // channel momentarily entered transient failure. + break; + } + } + grpc_channel_destroy(channel); + grpc_completion_queue_shutdown(cq); + drain_cq(cq); + grpc_completion_queue_destroy(cq); + gpr_log(GPR_DEBUG, "runner:%p connect_loop finished loop %ld", self, i); + } + } + + private: + grpc_core::UniquePtr server_address_; + grpc_core::UniquePtr fake_handshake_server_addr_; + int per_connect_deadline_seconds_; + size_t loops_; + grpc_connectivity_state expected_connectivity_states_; + std::unique_ptr thd_; + int reconnect_backoff_ms_; +}; + +// Perform a few ALTS handshakes sequentially (using the fake, in-process ALTS +// handshake server). 
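+// Note that each test below creates its ConnectLoopRunner instances inside a
+// block scope: a runner's destructor joins its connect thread, so leaving the
+// scope (or clearing the vector of runners) waits for the loops to finish.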
+TEST(AltsConcurrentConnectivityTest, TestBasicClientServerHandshakes) { + FakeHandshakeServer fake_handshake_server( + true /* check num concurrent rpcs */); + TestServer test_server; + { + ConnectLoopRunner runner( + test_server.address(), fake_handshake_server.address(), + 5 /* per connect deadline seconds */, 10 /* loops */, + GRPC_CHANNEL_READY /* expected connectivity states */, + 0 /* reconnect_backoff_ms unset */); + } +} + +/* Run a bunch of concurrent ALTS handshakes on concurrent channels + * (using the fake, in-process handshake server). */ +TEST(AltsConcurrentConnectivityTest, TestConcurrentClientServerHandshakes) { + FakeHandshakeServer fake_handshake_server( + true /* check num concurrent rpcs */); + // Test + { + TestServer test_server; + gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20); + size_t num_concurrent_connects = 50; + std::vector> connect_loop_runners; + gpr_log(GPR_DEBUG, + "start performing concurrent expected-to-succeed connects"); + for (size_t i = 0; i < num_concurrent_connects; i++) { + connect_loop_runners.push_back(absl::make_unique( + test_server.address(), fake_handshake_server.address(), + 15 /* per connect deadline seconds */, 5 /* loops */, + GRPC_CHANNEL_READY /* expected connectivity states */, + 0 /* reconnect_backoff_ms unset */)); + } + connect_loop_runners.clear(); + gpr_log(GPR_DEBUG, + "done performing concurrent expected-to-succeed connects"); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) { + gpr_log(GPR_DEBUG, "Test took longer than expected."); + abort(); + } + } +} + +class FakeTcpServer { + public: + enum ProcessReadResult { + CONTINUE_READING, + CLOSE_SOCKET, + }; + + enum class AcceptMode { + kWaitForClientToSendFirstBytes, // useful for emulating ALTS based + // grpc servers + kEagerlySendSettings, // useful for emulating insecure grpc servers (e.g. 
+ // ALTS handshake servers) + }; + + explicit FakeTcpServer( + AcceptMode accept_mode, + const std::function& process_read_cb) + : accept_mode_(accept_mode), process_read_cb_(process_read_cb) { + port_ = grpc_pick_unused_port_or_die(); + accept_socket_ = socket(AF_INET6, SOCK_STREAM, 0); + address_ = absl::StrCat("[::]:", port_); + GPR_ASSERT(accept_socket_ != -1); + if (accept_socket_ == -1) { + gpr_log(GPR_ERROR, "Failed to create socket: %d", errno); + abort(); + } + int val = 1; + if (setsockopt(accept_socket_, SOL_SOCKET, SO_REUSEADDR, &val, + sizeof(val)) != 0) { + gpr_log(GPR_ERROR, + "Failed to set SO_REUSEADDR on socket bound to [::1]:%d : %d", + port_, errno); + abort(); + } + if (fcntl(accept_socket_, F_SETFL, O_NONBLOCK) != 0) { + gpr_log(GPR_ERROR, "Failed to set O_NONBLOCK on socket: %d", errno); + abort(); + } + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(port_); + (reinterpret_cast(&addr.sin6_addr))[15] = 1; + if (bind(accept_socket_, reinterpret_cast(&addr), + sizeof(addr)) != 0) { + gpr_log(GPR_ERROR, "Failed to bind socket to [::1]:%d : %d", port_, + errno); + abort(); + } + if (listen(accept_socket_, 100)) { + gpr_log(GPR_ERROR, "Failed to listen on socket bound to [::1]:%d : %d", + port_, errno); + abort(); + } + gpr_event_init(&stop_ev_); + run_server_loop_thd_ = absl::make_unique(RunServerLoop, this); + } + + ~FakeTcpServer() { + gpr_log(GPR_DEBUG, + "FakeTcpServer stop and " + "join server thread"); + gpr_event_set(&stop_ev_, reinterpret_cast(1)); + run_server_loop_thd_->join(); + gpr_log(GPR_DEBUG, + "FakeTcpServer join server " + "thread complete"); + } + + const char* address() { return address_.c_str(); } + + static ProcessReadResult CloseSocketUponReceivingBytesFromPeer( + int bytes_received_size, int read_error, int s) { + if (bytes_received_size < 0 && read_error != EAGAIN && + read_error != EWOULDBLOCK) { + gpr_log(GPR_ERROR, "Failed to receive from peer socket: %d. errno: %d", s, + errno); + abort(); + } + if (bytes_received_size >= 0) { + gpr_log(GPR_DEBUG, + "Fake TCP server received %d bytes from peer socket: %d. Close " + "the " + "connection.", + bytes_received_size, s); + return CLOSE_SOCKET; + } + return CONTINUE_READING; + } + + static ProcessReadResult CloseSocketUponCloseFromPeer(int bytes_received_size, + int read_error, int s) { + if (bytes_received_size < 0 && read_error != EAGAIN && + read_error != EWOULDBLOCK) { + gpr_log(GPR_ERROR, "Failed to receive from peer socket: %d. errno: %d", s, + errno); + abort(); + } + if (bytes_received_size == 0) { + // The peer has shut down the connection. + gpr_log(GPR_DEBUG, + "Fake TCP server received 0 bytes from peer socket: %d. 
Close " + "the " + "connection.", + s); + return CLOSE_SOCKET; + } + return CONTINUE_READING; + } + + class FakeTcpServerPeer { + public: + explicit FakeTcpServerPeer(int fd) : fd_(fd) {} + + ~FakeTcpServerPeer() { close(fd_); } + + void MaybeContinueSendingSettings() { + // https://tools.ietf.org/html/rfc7540#section-4.1 + const std::vector kEmptyHttp2SettingsFrame = { + 0x00, 0x00, 0x00, // length + 0x04, // settings type + 0x00, // flags + 0x00, 0x00, 0x00, 0x00 // stream identifier + }; + if (total_bytes_sent_ < int(kEmptyHttp2SettingsFrame.size())) { + int bytes_to_send = kEmptyHttp2SettingsFrame.size() - total_bytes_sent_; + int bytes_sent = + send(fd_, kEmptyHttp2SettingsFrame.data() + total_bytes_sent_, + bytes_to_send, 0); + if (bytes_sent < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + gpr_log(GPR_ERROR, + "Fake TCP server encountered unexpected error:%d |%s| " + "sending %d bytes on fd:%d", + errno, strerror(errno), bytes_to_send, fd_); + GPR_ASSERT(0); + } else if (bytes_sent > 0) { + total_bytes_sent_ += bytes_sent; + GPR_ASSERT(total_bytes_sent_ <= int(kEmptyHttp2SettingsFrame.size())); + } + } + } + + int fd() { return fd_; } + + private: + int fd_; + int total_bytes_sent_ = 0; + }; + + // Run a loop that periodically, every 10 ms: + // 1) Checks if there are any new TCP connections to accept. + // 2) Checks if any data has arrived yet on established connections, + // and reads from them if so, processing the sockets as configured. + static void RunServerLoop(FakeTcpServer* self) { + std::set> peers; + while (!gpr_event_get(&self->stop_ev_)) { + int p = accept(self->accept_socket_, nullptr, nullptr); + if (p == -1 && errno != EAGAIN && errno != EWOULDBLOCK) { + gpr_log(GPR_ERROR, "Failed to accept connection: %d", errno); + abort(); + } + if (p != -1) { + gpr_log(GPR_DEBUG, "accepted peer socket: %d", p); + if (fcntl(p, F_SETFL, O_NONBLOCK) != 0) { + gpr_log(GPR_ERROR, + "Failed to set O_NONBLOCK on peer socket:%d errno:%d", p, + errno); + abort(); + } + peers.insert(absl::make_unique(p)); + } + auto it = peers.begin(); + while (it != peers.end()) { + FakeTcpServerPeer* peer = (*it).get(); + if (self->accept_mode_ == AcceptMode::kEagerlySendSettings) { + peer->MaybeContinueSendingSettings(); + } + char buf[100]; + int bytes_received_size = recv(peer->fd(), buf, 100, 0); + ProcessReadResult r = + self->process_read_cb_(bytes_received_size, errno, peer->fd()); + if (r == CLOSE_SOCKET) { + it = peers.erase(it); + } else { + GPR_ASSERT(r == CONTINUE_READING); + it++; + } + } + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(10, GPR_TIMESPAN))); + } + close(self->accept_socket_); + } + + private: + int accept_socket_; + int port_; + gpr_event stop_ev_; + std::string address_; + std::unique_ptr run_server_loop_thd_; + const AcceptMode accept_mode_; + std::function process_read_cb_; +}; + +/* This test is intended to make sure that ALTS handshakes we correctly + * fail fast when the security handshaker gets an error while reading + * from the remote peer, after having earlier sent the first bytes of the + * ALTS handshake to the peer, i.e. after getting into the middle of a + * handshake. */ +TEST(AltsConcurrentConnectivityTest, + TestHandshakeFailsFastWhenPeerEndpointClosesConnectionAfterAccepting) { + // Don't enforce the number of concurrent rpcs for the fake handshake + // server in this test, because this test will involve handshake RPCs + // getting cancelled. 
Because there isn't explicit synchronization between + // an ALTS handshake client's RECV_STATUS op completing after call + // cancellation, and the corresponding fake handshake server's sync + // method handler returning, enforcing a limit on the number of active + // RPCs at the fake handshake server would be inherently racey. + FakeHandshakeServer fake_handshake_server( + false /* check num concurrent rpcs */); + // The fake_backend_server emulates a secure (ALTS based) gRPC backend. So + // it waits for the client to send the first bytes. + FakeTcpServer fake_backend_server( + FakeTcpServer::AcceptMode::kWaitForClientToSendFirstBytes, + FakeTcpServer::CloseSocketUponReceivingBytesFromPeer); + { + gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20); + std::vector> connect_loop_runners; + size_t num_concurrent_connects = 100; + gpr_log(GPR_DEBUG, "start performing concurrent expected-to-fail connects"); + for (size_t i = 0; i < num_concurrent_connects; i++) { + connect_loop_runners.push_back(absl::make_unique( + fake_backend_server.address(), fake_handshake_server.address(), + 10 /* per connect deadline seconds */, 3 /* loops */, + GRPC_CHANNEL_TRANSIENT_FAILURE /* expected connectivity states */, + 0 /* reconnect_backoff_ms unset */)); + } + connect_loop_runners.clear(); + gpr_log(GPR_DEBUG, "done performing concurrent expected-to-fail connects"); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) { + gpr_log(GPR_ERROR, + "Exceeded test deadline. ALTS handshakes might not be failing " + "fast when the peer endpoint closes the connection abruptly"); + abort(); + } + } +} + +/* This test is intended to make sure that ALTS handshakes correctly + * fail fast when the ALTS handshake server fails incoming handshakes fast. */ +TEST(AltsConcurrentConnectivityTest, + TestHandshakeFailsFastWhenHandshakeServerClosesConnectionAfterAccepting) { + // The fake_handshake_server emulates a broken ALTS handshaker, which + // is an insecure server. So send settings to the client eagerly. + FakeTcpServer fake_handshake_server( + FakeTcpServer::AcceptMode::kEagerlySendSettings, + FakeTcpServer::CloseSocketUponReceivingBytesFromPeer); + // The fake_backend_server emulates a secure (ALTS based) server, so wait + // for the client to send the first bytes. + FakeTcpServer fake_backend_server( + FakeTcpServer::AcceptMode::kWaitForClientToSendFirstBytes, + FakeTcpServer::CloseSocketUponCloseFromPeer); + { + gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20); + std::vector> connect_loop_runners; + size_t num_concurrent_connects = 100; + gpr_log(GPR_DEBUG, "start performing concurrent expected-to-fail connects"); + for (size_t i = 0; i < num_concurrent_connects; i++) { + connect_loop_runners.push_back(absl::make_unique( + fake_backend_server.address(), fake_handshake_server.address(), + 20 /* per connect deadline seconds */, 2 /* loops */, + GRPC_CHANNEL_TRANSIENT_FAILURE /* expected connectivity states */, + 0 /* reconnect_backoff_ms unset */)); + } + connect_loop_runners.clear(); + gpr_log(GPR_DEBUG, "done performing concurrent expected-to-fail connects"); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) { + gpr_log(GPR_ERROR, + "Exceeded test deadline. 
ALTS handshakes might not be failing " + "fast when the handshake server closes new connections"); + abort(); + } + } +} + +/* This test is intended to make sure that ALTS handshakes correctly + * fail fast when the ALTS handshake server is non-responsive, in which case + * the overall connection deadline kicks in. */ +TEST(AltsConcurrentConnectivityTest, + TestHandshakeFailsFastWhenHandshakeServerHangsAfterAccepting) { + // fake_handshake_server emulates an insecure server, so send settings first. + // It will be unresponsive for the rest of the connection, though. + FakeTcpServer fake_handshake_server( + FakeTcpServer::AcceptMode::kEagerlySendSettings, + FakeTcpServer::CloseSocketUponCloseFromPeer); + // fake_backend_server emulates an ALTS based server, so wait for the client + // to send the first bytes. + FakeTcpServer fake_backend_server( + FakeTcpServer::AcceptMode::kWaitForClientToSendFirstBytes, + FakeTcpServer::CloseSocketUponCloseFromPeer); + { + gpr_timespec test_deadline = grpc_timeout_seconds_to_deadline(20); + std::vector> connect_loop_runners; + size_t num_concurrent_connects = 100; + gpr_log(GPR_DEBUG, "start performing concurrent expected-to-fail connects"); + for (size_t i = 0; i < num_concurrent_connects; i++) { + connect_loop_runners.push_back(absl::make_unique( + fake_backend_server.address(), fake_handshake_server.address(), + 10 /* per connect deadline seconds */, 2 /* loops */, + GRPC_CHANNEL_TRANSIENT_FAILURE /* expected connectivity states */, + 100 /* reconnect_backoff_ms */)); + } + connect_loop_runners.clear(); + gpr_log(GPR_DEBUG, "done performing concurrent expected-to-fail connects"); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), test_deadline) > 0) { + gpr_log(GPR_ERROR, + "Exceeded test deadline. ALTS handshakes might not be failing " + "fast when the handshake server is non-response timeout occurs"); + abort(); + } + } +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc new file mode 100644 index 00000000..2a36f0aa --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc @@ -0,0 +1,511 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" + +#include "upb/upb.hpp" + +#include + +#include "src/core/tsi/alts/handshaker/alts_shared_resource.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h" +#include "src/core/tsi/transport_security.h" +#include "src/core/tsi/transport_security_interface.h" +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" +#include "test/core/util/test_config.h" + +#define ALTS_HANDSHAKER_CLIENT_TEST_OUT_FRAME "Hello Google" +#define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME "bigtable.google.api.com" +#define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1 "A@google.com" +#define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2 "B@google.com" +#define ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE (64 * 1024) + +const size_t kHandshakerClientOpNum = 4; +const size_t kMaxRpcVersionMajor = 3; +const size_t kMaxRpcVersionMinor = 2; +const size_t kMinRpcVersionMajor = 2; +const size_t kMinRpcVersionMinor = 1; + +using grpc_core::internal::alts_handshaker_client_get_closure_for_testing; +using grpc_core::internal:: + alts_handshaker_client_get_initial_metadata_for_testing; +using grpc_core::internal:: + alts_handshaker_client_get_recv_buffer_addr_for_testing; +using grpc_core::internal::alts_handshaker_client_get_send_buffer_for_testing; +using grpc_core::internal:: + alts_handshaker_client_on_status_received_for_testing; +using grpc_core::internal::alts_handshaker_client_set_cb_for_testing; +using grpc_core::internal::alts_handshaker_client_set_grpc_caller_for_testing; + +typedef struct alts_handshaker_client_test_config { + grpc_channel* channel; + grpc_completion_queue* cq; + alts_handshaker_client* client; + alts_handshaker_client* server; + grpc_slice out_frame; +} alts_handshaker_client_test_config; + +static void validate_rpc_protocol_versions( + const grpc_gcp_RpcProtocolVersions* versions) { + GPR_ASSERT(versions != nullptr); + const grpc_gcp_RpcProtocolVersions_Version* max_version = + grpc_gcp_RpcProtocolVersions_max_rpc_version(versions); + const grpc_gcp_RpcProtocolVersions_Version* min_version = + grpc_gcp_RpcProtocolVersions_min_rpc_version(versions); + GPR_ASSERT(grpc_gcp_RpcProtocolVersions_Version_major(max_version) == + kMaxRpcVersionMajor); + GPR_ASSERT(grpc_gcp_RpcProtocolVersions_Version_minor(max_version) == + kMaxRpcVersionMinor); + GPR_ASSERT(grpc_gcp_RpcProtocolVersions_Version_major(min_version) == + kMinRpcVersionMajor); + GPR_ASSERT(grpc_gcp_RpcProtocolVersions_Version_minor(min_version) == + kMinRpcVersionMinor); +} + +static void validate_target_identities( + const grpc_gcp_Identity* const* target_identities, + size_t target_identities_count) { + GPR_ASSERT(target_identities_count == 2); + const grpc_gcp_Identity* identity1 = target_identities[1]; + const grpc_gcp_Identity* identity2 = target_identities[0]; + GPR_ASSERT(upb_strview_eql( + grpc_gcp_Identity_service_account(identity1), + upb_strview_makez(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1))); + GPR_ASSERT(upb_strview_eql( + grpc_gcp_Identity_service_account(identity2), + upb_strview_makez(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2))); +} + +/** + * Validate if grpc operation data is correctly populated with the fields of + * ALTS handshaker client. 
+ */ +static bool validate_op(alts_handshaker_client* c, const grpc_op* op, + size_t nops, bool is_start) { + GPR_ASSERT(c != nullptr && op != nullptr && nops != 0); + bool ok = true; + grpc_op* start_op = const_cast(op); + if (is_start) { + ok &= (op->op == GRPC_OP_SEND_INITIAL_METADATA); + ok &= (op->data.send_initial_metadata.count == 0); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + ok &= (op->op == GRPC_OP_RECV_INITIAL_METADATA); + ok &= (op->data.recv_initial_metadata.recv_initial_metadata == + alts_handshaker_client_get_initial_metadata_for_testing(c)); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + } + ok &= (op->op == GRPC_OP_SEND_MESSAGE); + ok &= (op->data.send_message.send_message == + alts_handshaker_client_get_send_buffer_for_testing(c)); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + ok &= (op->op == GRPC_OP_RECV_MESSAGE); + ok &= (op->data.recv_message.recv_message == + alts_handshaker_client_get_recv_buffer_addr_for_testing(c)); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + return ok; +} + +static grpc_gcp_HandshakerReq* deserialize_handshaker_req( + grpc_byte_buffer* buffer, upb_arena* arena) { + GPR_ASSERT(buffer != nullptr); + grpc_byte_buffer_reader bbr; + GPR_ASSERT(grpc_byte_buffer_reader_init(&bbr, buffer)); + grpc_slice slice = grpc_byte_buffer_reader_readall(&bbr); + grpc_gcp_HandshakerReq* req = grpc_gcp_handshaker_req_decode(slice, arena); + GPR_ASSERT(req != nullptr); + grpc_slice_unref(slice); + grpc_byte_buffer_reader_destroy(&bbr); + return req; +} + +static bool is_recv_status_op(const grpc_op* op, size_t nops) { + return nops == 1 && op->op == GRPC_OP_RECV_STATUS_ON_CLIENT; +} + +/** + * A mock grpc_caller used to check if client_start, server_start, and next + * operations correctly handle invalid arguments. It should not be called. + */ +static grpc_call_error check_must_not_be_called(grpc_call* /*call*/, + const grpc_op* /*ops*/, + size_t /*nops*/, + grpc_closure* /*tag*/) { + GPR_ASSERT(0); +} + +/** + * A mock grpc_caller used to check correct execution of client_start operation. + * It checks if the client_start handshaker request is populated with correct + * handshake_security_protocol, application_protocol, record_protocol and + * max_frame_size, and op is correctly populated. 
+ */ +static grpc_call_error check_client_start_success(grpc_call* /*call*/, + const grpc_op* op, + size_t nops, + grpc_closure* closure) { + // RECV_STATUS ops are asserted to always succeed + if (is_recv_status_op(op, nops)) { + return GRPC_CALL_OK; + } + upb::Arena arena; + alts_handshaker_client* client = + static_cast(closure->cb_arg); + GPR_ASSERT(alts_handshaker_client_get_closure_for_testing(client) == closure); + grpc_gcp_HandshakerReq* req = deserialize_handshaker_req( + alts_handshaker_client_get_send_buffer_for_testing(client), arena.ptr()); + const grpc_gcp_StartClientHandshakeReq* client_start = + grpc_gcp_HandshakerReq_client_start(req); + GPR_ASSERT(grpc_gcp_StartClientHandshakeReq_handshake_security_protocol( + client_start) == grpc_gcp_ALTS); + upb_strview const* application_protocols = + grpc_gcp_StartClientHandshakeReq_application_protocols(client_start, + nullptr); + GPR_ASSERT(upb_strview_eql(application_protocols[0], + upb_strview_makez(ALTS_APPLICATION_PROTOCOL))); + upb_strview const* record_protocols = + grpc_gcp_StartClientHandshakeReq_record_protocols(client_start, nullptr); + GPR_ASSERT(upb_strview_eql(record_protocols[0], + upb_strview_makez(ALTS_RECORD_PROTOCOL))); + const grpc_gcp_RpcProtocolVersions* rpc_protocol_versions = + grpc_gcp_StartClientHandshakeReq_rpc_versions(client_start); + validate_rpc_protocol_versions(rpc_protocol_versions); + size_t target_identities_count; + const grpc_gcp_Identity* const* target_identities = + grpc_gcp_StartClientHandshakeReq_target_identities( + client_start, &target_identities_count); + validate_target_identities(target_identities, target_identities_count); + GPR_ASSERT(upb_strview_eql( + grpc_gcp_StartClientHandshakeReq_target_name(client_start), + upb_strview_makez(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME))); + GPR_ASSERT(grpc_gcp_StartClientHandshakeReq_max_frame_size(client_start) == + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); + GPR_ASSERT(validate_op(client, op, nops, true /* is_start */)); + return GRPC_CALL_OK; +} + +/** + * A mock grpc_caller used to check correct execution of server_start operation. + * It checks if the server_start handshaker request is populated with correct + * handshake_security_protocol, application_protocol, record_protocol and + * max_frame_size, and op is correctly populated. 
+ */ +static grpc_call_error check_server_start_success(grpc_call* /*call*/, + const grpc_op* op, + size_t nops, + grpc_closure* closure) { + // RECV_STATUS ops are asserted to always succeed + if (is_recv_status_op(op, nops)) { + return GRPC_CALL_OK; + } + upb::Arena arena; + alts_handshaker_client* client = + static_cast(closure->cb_arg); + GPR_ASSERT(alts_handshaker_client_get_closure_for_testing(client) == closure); + grpc_gcp_HandshakerReq* req = deserialize_handshaker_req( + alts_handshaker_client_get_send_buffer_for_testing(client), arena.ptr()); + const grpc_gcp_StartServerHandshakeReq* server_start = + grpc_gcp_HandshakerReq_server_start(req); + upb_strview const* application_protocols = + grpc_gcp_StartServerHandshakeReq_application_protocols(server_start, + nullptr); + GPR_ASSERT(upb_strview_eql(application_protocols[0], + upb_strview_makez(ALTS_APPLICATION_PROTOCOL))); + GPR_ASSERT(grpc_gcp_StartServerHandshakeReq_handshake_parameters_size( + server_start) == 1); + grpc_gcp_ServerHandshakeParameters* value; + GPR_ASSERT(grpc_gcp_StartServerHandshakeReq_handshake_parameters_get( + server_start, grpc_gcp_ALTS, &value)); + upb_strview const* record_protocols = + grpc_gcp_ServerHandshakeParameters_record_protocols(value, nullptr); + GPR_ASSERT(upb_strview_eql(record_protocols[0], + upb_strview_makez(ALTS_RECORD_PROTOCOL))); + validate_rpc_protocol_versions( + grpc_gcp_StartServerHandshakeReq_rpc_versions(server_start)); + GPR_ASSERT(grpc_gcp_StartServerHandshakeReq_max_frame_size(server_start) == + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); + GPR_ASSERT(validate_op(client, op, nops, true /* is_start */)); + return GRPC_CALL_OK; +} + +/** + * A mock grpc_caller used to check correct execution of next operation. It + * checks if the next handshaker request is populated with correct information, + * and op is correctly populated. + */ +static grpc_call_error check_next_success(grpc_call* /*call*/, + const grpc_op* op, size_t nops, + grpc_closure* closure) { + upb::Arena arena; + alts_handshaker_client* client = + static_cast(closure->cb_arg); + GPR_ASSERT(alts_handshaker_client_get_closure_for_testing(client) == closure); + grpc_gcp_HandshakerReq* req = deserialize_handshaker_req( + alts_handshaker_client_get_send_buffer_for_testing(client), arena.ptr()); + const grpc_gcp_NextHandshakeMessageReq* next = + grpc_gcp_HandshakerReq_next(req); + GPR_ASSERT(upb_strview_eql( + grpc_gcp_NextHandshakeMessageReq_in_bytes(next), + upb_strview_makez(ALTS_HANDSHAKER_CLIENT_TEST_OUT_FRAME))); + GPR_ASSERT(validate_op(client, op, nops, false /* is_start */)); + return GRPC_CALL_OK; +} + +/** + * A mock grpc_caller used to check if client_start, server_start, and next + * operations correctly handle the situation when the grpc call made to the + * handshaker service fails. + */ +static grpc_call_error check_grpc_call_failure(grpc_call* /*call*/, + const grpc_op* op, size_t nops, + grpc_closure* /*tag*/) { + // RECV_STATUS ops are asserted to always succeed + if (is_recv_status_op(op, nops)) { + return GRPC_CALL_OK; + } + return GRPC_CALL_ERROR; +} + +static grpc_alts_credentials_options* create_credentials_options( + bool is_client) { + grpc_alts_credentials_options* options = + is_client ? 
grpc_alts_credentials_client_options_create() + : grpc_alts_credentials_server_options_create(); + if (is_client) { + grpc_alts_credentials_client_options_add_target_service_account( + options, ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1); + grpc_alts_credentials_client_options_add_target_service_account( + options, ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2); + } + grpc_gcp_rpc_protocol_versions* versions = &options->rpc_versions; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max( + versions, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min( + versions, kMinRpcVersionMajor, kMinRpcVersionMinor)); + return options; +} + +static alts_handshaker_client_test_config* create_config() { + alts_handshaker_client_test_config* config = + static_cast( + gpr_zalloc(sizeof(*config))); + config->channel = grpc_insecure_channel_create( + ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, nullptr, nullptr); + config->cq = grpc_completion_queue_create_for_next(nullptr); + grpc_alts_credentials_options* client_options = + create_credentials_options(true /* is_client */); + grpc_alts_credentials_options* server_options = + create_credentials_options(false /* is_client */); + config->server = alts_grpc_handshaker_client_create( + nullptr, config->channel, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, + nullptr, server_options, + grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME), + nullptr, nullptr, nullptr, nullptr, false, + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); + config->client = alts_grpc_handshaker_client_create( + nullptr, config->channel, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, + nullptr, client_options, + grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME), + nullptr, nullptr, nullptr, nullptr, true, + ALTS_HANDSHAKER_CLIENT_TEST_MAX_FRAME_SIZE); + GPR_ASSERT(config->client != nullptr); + GPR_ASSERT(config->server != nullptr); + grpc_alts_credentials_options_destroy(client_options); + grpc_alts_credentials_options_destroy(server_options); + config->out_frame = + grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_OUT_FRAME); + return config; +} + +static void destroy_config(alts_handshaker_client_test_config* config) { + if (config == nullptr) { + return; + } + grpc_completion_queue_destroy(config->cq); + grpc_channel_destroy(config->channel); + alts_handshaker_client_destroy(config->client); + alts_handshaker_client_destroy(config->server); + grpc_slice_unref(config->out_frame); + gpr_free(config); +} + +static void schedule_request_invalid_arg_test() { + /* Initialization. */ + alts_handshaker_client_test_config* config = create_config(); + /* Tests. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->client, + check_must_not_be_called); + /* Check client_start. */ + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_start_client(nullptr) == + TSI_INVALID_ARGUMENT); + } + /* Check server_start. */ + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_start_server(config->server, nullptr) == + TSI_INVALID_ARGUMENT); + } + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_start_server( + nullptr, &config->out_frame) == TSI_INVALID_ARGUMENT); + } + /* Check next. 
*/ + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_next(config->client, nullptr) == + TSI_INVALID_ARGUMENT); + } + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_next(nullptr, &config->out_frame) == + TSI_INVALID_ARGUMENT); + } + /* Check shutdown. */ + alts_handshaker_client_shutdown(nullptr); + /* Cleanup. */ + destroy_config(config); +} + +static void schedule_request_success_test() { + /* Initialization. */ + alts_handshaker_client_test_config* config = create_config(); + /* Check client_start success. */ + alts_handshaker_client_set_grpc_caller_for_testing( + config->client, check_client_start_success); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_start_client(config->client) == TSI_OK); + } + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_next(nullptr, &config->out_frame) == + TSI_INVALID_ARGUMENT); + } + /* Check server_start success. */ + alts_handshaker_client_set_grpc_caller_for_testing( + config->server, check_server_start_success); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_start_server( + config->server, &config->out_frame) == TSI_OK); + } + /* Check client next success. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->client, + check_next_success); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_next(config->client, + &config->out_frame) == TSI_OK); + } + /* Check server next success. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->server, + check_next_success); + { + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(alts_handshaker_client_next(config->server, + &config->out_frame) == TSI_OK); + } + /* Cleanup. */ + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + config->client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + alts_handshaker_client_on_status_received_for_testing( + config->server, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + destroy_config(config); +} + +static void tsi_cb_assert_tsi_internal_error( + tsi_result status, void* /*user_data*/, + const unsigned char* /*bytes_to_send*/, size_t /*bytes_to_send_size*/, + tsi_handshaker_result* /*result*/) { + GPR_ASSERT(status == TSI_INTERNAL_ERROR); +} + +static void schedule_request_grpc_call_failure_test() { + /* Initialization. */ + alts_handshaker_client_test_config* config = create_config(); + /* Check client_start failure. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->client, + check_grpc_call_failure); + { + grpc_core::ExecCtx exec_ctx; + // TODO(apolcyn): go back to asserting TSI_INTERNAL_ERROR as return + // value instead of callback status, after removing the global + // queue in https://github.com/grpc/grpc/pull/20722 + alts_handshaker_client_set_cb_for_testing(config->client, + tsi_cb_assert_tsi_internal_error); + alts_handshaker_client_start_client(config->client); + } + /* Check server_start failure. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->server, + check_grpc_call_failure); + { + grpc_core::ExecCtx exec_ctx; + // TODO(apolcyn): go back to asserting TSI_INTERNAL_ERROR as return + // value instead of callback status, after removing the global + // queue in https://github.com/grpc/grpc/pull/20722 + alts_handshaker_client_set_cb_for_testing(config->server, + tsi_cb_assert_tsi_internal_error); + alts_handshaker_client_start_server(config->server, &config->out_frame); + } + { + grpc_core::ExecCtx exec_ctx; + /* Check client next failure. 
*/ + GPR_ASSERT(alts_handshaker_client_next( + config->client, &config->out_frame) == TSI_INTERNAL_ERROR); + } + { + grpc_core::ExecCtx exec_ctx; + /* Check server next failure. */ + GPR_ASSERT(alts_handshaker_client_next( + config->server, &config->out_frame) == TSI_INTERNAL_ERROR); + } + /* Cleanup. */ + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + config->client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + alts_handshaker_client_on_status_received_for_testing( + config->server, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + destroy_config(config); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + /* Initialization. */ + grpc_init(); + grpc_alts_shared_resource_dedicated_init(); + /* Tests. */ + schedule_request_invalid_arg_test(); + schedule_request_success_test(); + schedule_request_grpc_call_failure_test(); + /* Cleanup. */ + grpc_alts_shared_resource_dedicated_shutdown(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.cc b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.cc new file mode 100644 index 00000000..2ab73a26 --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.cc @@ -0,0 +1,163 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" + +bool grpc_gcp_handshaker_resp_set_peer_rpc_versions( + grpc_gcp_HandshakerResp* resp, upb_arena* arena, uint32_t max_major, + uint32_t max_minor, uint32_t min_major, uint32_t min_minor) { + if (resp == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr argument to " + "grpc_gcp_handshaker_resp_set_peer_rpc_versions()."); + return false; + } + grpc_gcp_rpc_protocol_versions versions; + versions.max_rpc_version.major = max_major; + versions.max_rpc_version.minor = max_minor; + versions.min_rpc_version.major = min_major; + versions.min_rpc_version.minor = min_minor; + grpc_gcp_HandshakerResult* result = + grpc_gcp_HandshakerResp_mutable_result(resp, arena); + grpc_gcp_RpcProtocolVersions* upb_versions = + grpc_gcp_HandshakerResult_mutable_peer_rpc_versions(result, arena); + grpc_gcp_RpcProtocolVersions_assign_from_struct(upb_versions, arena, + &versions); + return true; +} + +grpc_gcp_HandshakerReq* grpc_gcp_handshaker_req_decode(grpc_slice slice, + upb_arena* arena) { + size_t buf_size = GPR_SLICE_LENGTH(slice); + void* buf = upb_arena_malloc(arena, buf_size); + memcpy(buf, reinterpret_cast(GPR_SLICE_START_PTR(slice)), + buf_size); + grpc_gcp_HandshakerReq* resp = grpc_gcp_HandshakerReq_parse( + reinterpret_cast(buf), buf_size, arena); + if (!resp) { + gpr_log(GPR_ERROR, "grpc_gcp_HandshakerReq decode error"); + return nullptr; + } + return resp; +} + +/* Check equality of a pair of grpc_gcp_identity fields. 
*/ +static bool handshaker_identity_equals(const grpc_gcp_Identity* l_id, + const grpc_gcp_Identity* r_id) { + if ((grpc_gcp_Identity_has_service_account(l_id) != + grpc_gcp_Identity_has_service_account(r_id)) || + (grpc_gcp_Identity_has_hostname(l_id) != + grpc_gcp_Identity_has_hostname(r_id))) { + return false; + } + + if (grpc_gcp_Identity_has_service_account(l_id)) { + if (!upb_strview_eql(grpc_gcp_Identity_service_account(l_id), + grpc_gcp_Identity_service_account(r_id))) { + return false; + } + } else if (grpc_gcp_Identity_has_hostname(l_id)) { + if (!upb_strview_eql(grpc_gcp_Identity_hostname(l_id), + grpc_gcp_Identity_hostname(r_id))) { + return false; + } + } + return true; +} + +static bool handshaker_rpc_versions_equals( + const grpc_gcp_RpcProtocolVersions* l_version, + const grpc_gcp_RpcProtocolVersions* r_version) { + const grpc_gcp_RpcProtocolVersions_Version* l_maxver = + grpc_gcp_RpcProtocolVersions_max_rpc_version(l_version); + const grpc_gcp_RpcProtocolVersions_Version* r_maxver = + grpc_gcp_RpcProtocolVersions_max_rpc_version(r_version); + const grpc_gcp_RpcProtocolVersions_Version* l_minver = + grpc_gcp_RpcProtocolVersions_min_rpc_version(l_version); + const grpc_gcp_RpcProtocolVersions_Version* r_minver = + grpc_gcp_RpcProtocolVersions_min_rpc_version(r_version); + return (grpc_gcp_RpcProtocolVersions_Version_major(l_maxver) == + grpc_gcp_RpcProtocolVersions_Version_major(r_maxver)) && + (grpc_gcp_RpcProtocolVersions_Version_minor(l_maxver) == + grpc_gcp_RpcProtocolVersions_Version_minor(r_maxver)) && + (grpc_gcp_RpcProtocolVersions_Version_major(l_minver) == + grpc_gcp_RpcProtocolVersions_Version_major(r_minver)) && + (grpc_gcp_RpcProtocolVersions_Version_minor(l_minver) == + grpc_gcp_RpcProtocolVersions_Version_minor(r_minver)); +} + +/* Check equality of a pair of ALTS handshake responses. */ +bool grpc_gcp_handshaker_resp_equals(const grpc_gcp_HandshakerResp* l_resp, + const grpc_gcp_HandshakerResp* r_resp) { + return upb_strview_eql(grpc_gcp_HandshakerResp_out_frames(l_resp), + grpc_gcp_HandshakerResp_out_frames(r_resp)) && + (grpc_gcp_HandshakerResp_bytes_consumed(l_resp) == + grpc_gcp_HandshakerResp_bytes_consumed(r_resp)) && + grpc_gcp_handshaker_resp_result_equals( + grpc_gcp_HandshakerResp_result(l_resp), + grpc_gcp_HandshakerResp_result(r_resp)) && + grpc_gcp_handshaker_resp_status_equals( + grpc_gcp_HandshakerResp_status(l_resp), + grpc_gcp_HandshakerResp_status(r_resp)); +} + +/* This method checks equality of two handshaker response results.
*/ +bool grpc_gcp_handshaker_resp_result_equals( + const grpc_gcp_HandshakerResult* l_result, + const grpc_gcp_HandshakerResult* r_result) { + if (l_result == nullptr && r_result == nullptr) { + return true; + } else if ((l_result != nullptr && r_result == nullptr) || + (l_result == nullptr && r_result != nullptr)) { + return false; + } + return upb_strview_eql( + grpc_gcp_HandshakerResult_application_protocol(l_result), + grpc_gcp_HandshakerResult_application_protocol(r_result)) && + upb_strview_eql(grpc_gcp_HandshakerResult_record_protocol(l_result), + grpc_gcp_HandshakerResult_record_protocol(r_result)) && + upb_strview_eql(grpc_gcp_HandshakerResult_key_data(l_result), + grpc_gcp_HandshakerResult_key_data(r_result)) && + handshaker_identity_equals( + grpc_gcp_HandshakerResult_peer_identity(l_result), + grpc_gcp_HandshakerResult_peer_identity(r_result)) && + handshaker_identity_equals( + grpc_gcp_HandshakerResult_local_identity(l_result), + grpc_gcp_HandshakerResult_local_identity(r_result)) && + (grpc_gcp_HandshakerResult_keep_channel_open(l_result) == + grpc_gcp_HandshakerResult_keep_channel_open(r_result)) && + handshaker_rpc_versions_equals( + grpc_gcp_HandshakerResult_peer_rpc_versions(l_result), + grpc_gcp_HandshakerResult_peer_rpc_versions(r_result)); +} + +/* This method checks equality of two handshaker response statuses. */ +bool grpc_gcp_handshaker_resp_status_equals( + const grpc_gcp_HandshakerStatus* l_status, + const grpc_gcp_HandshakerStatus* r_status) { + if (l_status == nullptr && r_status == nullptr) { + return true; + } else if ((l_status != nullptr && r_status == nullptr) || + (l_status == nullptr && r_status != nullptr)) { + return false; + } + return (grpc_gcp_HandshakerStatus_code(l_status) == + grpc_gcp_HandshakerStatus_code(r_status)) && + upb_strview_eql(grpc_gcp_HandshakerStatus_details(l_status), + grpc_gcp_HandshakerStatus_details(r_status)); +} diff --git a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc new file mode 100644 index 00000000..0e373a0b --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc @@ -0,0 +1,1075 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" + +#include +#include + +#include "upb/upb.hpp" + +#include +#include + +#include "src/core/lib/gprpp/thd.h" +#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" +#include "src/core/tsi/alts/handshaker/alts_shared_resource.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h" +#include "src/core/tsi/transport_security_grpc.h" +#include "src/proto/grpc/gcp/altscontext.upb.h" +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" +#include "test/core/util/test_config.h" + +#define ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES "Hello World" +#define ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME "Hello Google" +#define ALTS_TSI_HANDSHAKER_TEST_CONSUMED_BYTES "Hello " +#define ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES "Google" +#define ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY "chapi@service.google.com" +#define ALTS_TSI_HANDSHAKER_TEST_SECURITY_LEVEL "TSI_PRIVACY_AND_INTEGRITY" +#define ALTS_TSI_HANDSHAKER_TEST_KEY_DATA \ + "ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKL" +#define ALTS_TSI_HANDSHAKER_TEST_BUFFER_SIZE 100 +#define ALTS_TSI_HANDSHAKER_TEST_SLEEP_TIME_IN_SECONDS 2 +#define ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR 3 +#define ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR 2 +#define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR 2 +#define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR 1 +#define ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY "chapilocal@service.google.com" +#define ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL \ + "test application protocol" +#define ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL "test record protocol" +#define ALTS_TSI_HANDSHAKER_TEST_MAX_FRAME_SIZE (2 * 1024 * 1024) +#define ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_KEY "peer" +#define ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_VALUE "attributes" + +using grpc_core::internal::alts_handshaker_client_check_fields_for_testing; +using grpc_core::internal::alts_handshaker_client_get_handshaker_for_testing; +using grpc_core::internal:: + alts_handshaker_client_get_recv_buffer_addr_for_testing; +using grpc_core::internal:: + alts_handshaker_client_on_status_received_for_testing; +using grpc_core::internal::alts_handshaker_client_ref_for_testing; +using grpc_core::internal::alts_handshaker_client_set_cb_for_testing; +using grpc_core::internal::alts_handshaker_client_set_fields_for_testing; +using grpc_core::internal::alts_handshaker_client_set_recv_bytes_for_testing; +using grpc_core::internal::alts_tsi_handshaker_get_client_for_testing; +using grpc_core::internal::alts_tsi_handshaker_get_is_client_for_testing; +using grpc_core::internal::alts_tsi_handshaker_set_client_vtable_for_testing; +static bool should_handshaker_client_api_succeed = true; + +/* ALTS mock notification. */ +typedef struct notification { + gpr_cv cv; + gpr_mu mu; + bool notified; +} notification; + +/* Type of ALTS handshaker response. 
*/ +typedef enum { + INVALID, + FAILED, + CLIENT_START, + SERVER_START, + CLIENT_NEXT, + SERVER_NEXT, +} alts_handshaker_response_type; + +static alts_handshaker_client* cb_event = nullptr; +static notification caller_to_tsi_notification; +static notification tsi_to_caller_notification; + +static void notification_init(notification* n) { + gpr_mu_init(&n->mu); + gpr_cv_init(&n->cv); + n->notified = false; +} + +static void notification_destroy(notification* n) { + gpr_mu_destroy(&n->mu); + gpr_cv_destroy(&n->cv); +} + +static void signal(notification* n) { + gpr_mu_lock(&n->mu); + n->notified = true; + gpr_cv_signal(&n->cv); + gpr_mu_unlock(&n->mu); +} + +static void wait(notification* n) { + gpr_mu_lock(&n->mu); + while (!n->notified) { + gpr_cv_wait(&n->cv, &n->mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + n->notified = false; + gpr_mu_unlock(&n->mu); +} + +/** + * This method mocks ALTS handshaker service to generate handshaker response + * for a specific request. + */ +static grpc_byte_buffer* generate_handshaker_response( + alts_handshaker_response_type type) { + upb::Arena arena; + grpc_gcp_HandshakerResult* result; + grpc_gcp_Identity* peer_identity; + grpc_gcp_HandshakerResp* resp = grpc_gcp_HandshakerResp_new(arena.ptr()); + grpc_gcp_HandshakerStatus* status = + grpc_gcp_HandshakerResp_mutable_status(resp, arena.ptr()); + grpc_gcp_HandshakerStatus_set_code(status, 0); + grpc_gcp_Identity* local_identity; + switch (type) { + case INVALID: + break; + case CLIENT_START: + case SERVER_START: + grpc_gcp_HandshakerResp_set_out_frames( + resp, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + break; + case CLIENT_NEXT: + grpc_gcp_HandshakerResp_set_out_frames( + resp, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + grpc_gcp_HandshakerResp_set_bytes_consumed( + resp, strlen(ALTS_TSI_HANDSHAKER_TEST_CONSUMED_BYTES)); + result = grpc_gcp_HandshakerResp_mutable_result(resp, arena.ptr()); + peer_identity = + grpc_gcp_HandshakerResult_mutable_peer_identity(result, arena.ptr()); + grpc_gcp_Identity_attributes_set( + peer_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_KEY), + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_VALUE), + arena.ptr()); + grpc_gcp_Identity_set_service_account( + peer_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY)); + grpc_gcp_HandshakerResult_set_key_data( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_KEY_DATA)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_rpc_versions( + resp, arena.ptr(), ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR)); + local_identity = + grpc_gcp_HandshakerResult_mutable_local_identity(result, arena.ptr()); + grpc_gcp_Identity_set_service_account( + local_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY)); + grpc_gcp_HandshakerResult_set_application_protocol( + result, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL)); + grpc_gcp_HandshakerResult_set_record_protocol( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL)); + grpc_gcp_HandshakerResult_set_max_frame_size( + result, ALTS_TSI_HANDSHAKER_TEST_MAX_FRAME_SIZE); + break; + case SERVER_NEXT: + grpc_gcp_HandshakerResp_set_bytes_consumed( + resp, strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + result = grpc_gcp_HandshakerResp_mutable_result(resp, arena.ptr()); + peer_identity = + 
grpc_gcp_HandshakerResult_mutable_peer_identity(result, arena.ptr()); + grpc_gcp_Identity_attributes_set( + peer_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_KEY), + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_VALUE), + arena.ptr()); + grpc_gcp_Identity_set_service_account( + peer_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY)); + grpc_gcp_HandshakerResult_set_key_data( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_KEY_DATA)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_rpc_versions( + resp, arena.ptr(), ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR)); + local_identity = + grpc_gcp_HandshakerResult_mutable_local_identity(result, arena.ptr()); + grpc_gcp_Identity_set_service_account( + local_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY)); + grpc_gcp_HandshakerResult_set_application_protocol( + result, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL)); + grpc_gcp_HandshakerResult_set_record_protocol( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL)); + break; + case FAILED: + grpc_gcp_HandshakerStatus_set_code(status, 3 /* INVALID ARGUMENT */); + break; + } + size_t buf_len; + char* buf = grpc_gcp_HandshakerResp_serialize(resp, arena.ptr(), &buf_len); + grpc_slice slice = gpr_slice_from_copied_buffer(buf, buf_len); + if (type == INVALID) { + grpc_slice bad_slice = + grpc_slice_split_head(&slice, GRPC_SLICE_LENGTH(slice) - 1); + grpc_slice_unref(slice); + slice = grpc_slice_ref(bad_slice); + grpc_slice_unref(bad_slice); + } + grpc_byte_buffer* buffer = + grpc_raw_byte_buffer_create(&slice, 1 /* number of slices */); + grpc_slice_unref(slice); + return buffer; +} + +static void check_must_not_be_called(tsi_result /*status*/, void* /*user_data*/, + const unsigned char* /*bytes_to_send*/, + size_t /*bytes_to_send_size*/, + tsi_handshaker_result* /*result*/) { + GPR_ASSERT(0); +} + +static void on_client_start_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + bytes_to_send_size) == 0); + GPR_ASSERT(result == nullptr); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == + TSI_INVALID_ARGUMENT); + /* Validate frame protector. */ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_INVALID_ARGUMENT); + /* Validate unused bytes. 
*/ + const unsigned char* unused_bytes = nullptr; + size_t unused_bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &unused_bytes, + &unused_bytes_size) == + TSI_INVALID_ARGUMENT); + signal(&tsi_to_caller_notification); +} + +static void on_server_start_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + bytes_to_send_size) == 0); + GPR_ASSERT(result == nullptr); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == + TSI_INVALID_ARGUMENT); + /* Validate frame protector. */ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_INVALID_ARGUMENT); + /* Validate unused bytes. */ + const unsigned char* unused_bytes = nullptr; + size_t unused_bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &unused_bytes, + &unused_bytes_size) == + TSI_INVALID_ARGUMENT); + signal(&tsi_to_caller_notification); +} + +static void on_client_next_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + bytes_to_send_size) == 0); + GPR_ASSERT(result != nullptr); + // Validate max frame size value after Frame Size Negotiation. Here peer max + // frame size is greater than default value, and user specified max frame size + // is absent. + tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_zero_copy_grpc_protector( + result, nullptr, &zero_copy_protector) == TSI_OK); + size_t actual_max_frame_size; + tsi_zero_copy_grpc_protector_max_frame_size(zero_copy_protector, + &actual_max_frame_size); + GPR_ASSERT(actual_max_frame_size == kTsiAltsMaxFrameSize); + tsi_zero_copy_grpc_protector_destroy(zero_copy_protector); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); + GPR_ASSERT(peer.property_count == kTsiAltsNumOfPeerProperties); + GPR_ASSERT(memcmp(TSI_ALTS_CERTIFICATE_TYPE, peer.properties[0].value.data, + peer.properties[0].value.length) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, + peer.properties[1].value.data, + peer.properties[1].value.length) == 0); + /* Validate alts context. 
*/ + upb::Arena context_arena; + grpc_gcp_AltsContext* ctx = grpc_gcp_AltsContext_parse( + peer.properties[3].value.data, peer.properties[3].value.length, + context_arena.ptr()); + GPR_ASSERT(ctx != nullptr); + upb_strview application_protocol = + grpc_gcp_AltsContext_application_protocol(ctx); + upb_strview record_protocol = grpc_gcp_AltsContext_record_protocol(ctx); + upb_strview peer_account = grpc_gcp_AltsContext_peer_service_account(ctx); + upb_strview local_account = grpc_gcp_AltsContext_local_service_account(ctx); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL, + application_protocol.data, application_protocol.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL, + record_protocol.data, record_protocol.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, peer_account.data, + peer_account.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY, local_account.data, + local_account.size) == 0); + size_t iter = UPB_MAP_BEGIN; + grpc_gcp_AltsContext_PeerAttributesEntry* peer_attributes_entry = + grpc_gcp_AltsContext_peer_attributes_nextmutable(ctx, &iter); + GPR_ASSERT(peer_attributes_entry != nullptr); + while (peer_attributes_entry != nullptr) { + upb_strview key = grpc_gcp_AltsContext_PeerAttributesEntry_key( + const_cast( + peer_attributes_entry)); + upb_strview val = grpc_gcp_AltsContext_PeerAttributesEntry_value( + const_cast( + peer_attributes_entry)); + GPR_ASSERT(upb_strview_eql( + key, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_KEY))); + GPR_ASSERT(upb_strview_eql( + val, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_VALUE))); + peer_attributes_entry = + grpc_gcp_AltsContext_peer_attributes_nextmutable(ctx, &iter); + } + /* Validate security level. */ + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_SECURITY_LEVEL, + peer.properties[4].value.data, + peer.properties[4].value.length) == 0); + tsi_peer_destruct(&peer); + /* Validate unused bytes. */ + const unsigned char* bytes = nullptr; + size_t bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &bytes, + &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == strlen(ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES)); + GPR_ASSERT(memcmp(bytes, ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES, bytes_size) == + 0); + /* Validate frame protector. */ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_OK); + GPR_ASSERT(protector != nullptr); + tsi_frame_protector_destroy(protector); + tsi_handshaker_result_destroy(result); + signal(&tsi_to_caller_notification); +} + +static void on_server_next_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(result != nullptr); + // Validate max frame size value after Frame Size Negotiation. The negotiated + // frame size value equals minimum send frame size, due to the absence of peer + // max frame size. 
+ tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; + size_t user_specified_max_frame_size = + ALTS_TSI_HANDSHAKER_TEST_MAX_FRAME_SIZE; + GPR_ASSERT(tsi_handshaker_result_create_zero_copy_grpc_protector( + result, &user_specified_max_frame_size, + &zero_copy_protector) == TSI_OK); + size_t actual_max_frame_size; + tsi_zero_copy_grpc_protector_max_frame_size(zero_copy_protector, + &actual_max_frame_size); + GPR_ASSERT(actual_max_frame_size == kTsiAltsMinFrameSize); + tsi_zero_copy_grpc_protector_destroy(zero_copy_protector); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); + GPR_ASSERT(peer.property_count == kTsiAltsNumOfPeerProperties); + GPR_ASSERT(memcmp(TSI_ALTS_CERTIFICATE_TYPE, peer.properties[0].value.data, + peer.properties[0].value.length) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, + peer.properties[1].value.data, + peer.properties[1].value.length) == 0); + /* Validate alts context. */ + upb::Arena context_arena; + grpc_gcp_AltsContext* ctx = grpc_gcp_AltsContext_parse( + peer.properties[3].value.data, peer.properties[3].value.length, + context_arena.ptr()); + GPR_ASSERT(ctx != nullptr); + upb_strview application_protocol = + grpc_gcp_AltsContext_application_protocol(ctx); + upb_strview record_protocol = grpc_gcp_AltsContext_record_protocol(ctx); + upb_strview peer_account = grpc_gcp_AltsContext_peer_service_account(ctx); + upb_strview local_account = grpc_gcp_AltsContext_local_service_account(ctx); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL, + application_protocol.data, application_protocol.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL, + record_protocol.data, record_protocol.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, peer_account.data, + peer_account.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY, local_account.data, + local_account.size) == 0); + size_t iter = UPB_MAP_BEGIN; + grpc_gcp_AltsContext_PeerAttributesEntry* peer_attributes_entry = + grpc_gcp_AltsContext_peer_attributes_nextmutable(ctx, &iter); + GPR_ASSERT(peer_attributes_entry != nullptr); + while (peer_attributes_entry != nullptr) { + upb_strview key = grpc_gcp_AltsContext_PeerAttributesEntry_key( + const_cast( + peer_attributes_entry)); + upb_strview val = grpc_gcp_AltsContext_PeerAttributesEntry_value( + const_cast( + peer_attributes_entry)); + GPR_ASSERT(upb_strview_eql( + key, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_KEY))); + GPR_ASSERT(upb_strview_eql( + val, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_PEER_ATTRIBUTES_VALUE))); + peer_attributes_entry = + grpc_gcp_AltsContext_peer_attributes_nextmutable(ctx, &iter); + } + /* Check security level. */ + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_SECURITY_LEVEL, + peer.properties[4].value.data, + peer.properties[4].value.length) == 0); + + tsi_peer_destruct(&peer); + /* Validate unused bytes. */ + const unsigned char* bytes = nullptr; + size_t bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &bytes, + &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == 0); + GPR_ASSERT(bytes == nullptr); + /* Validate frame protector. 
*/ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_OK); + GPR_ASSERT(protector != nullptr); + tsi_frame_protector_destroy(protector); + tsi_handshaker_result_destroy(result); + signal(&tsi_to_caller_notification); +} + +static tsi_result mock_client_start(alts_handshaker_client* client) { + if (!should_handshaker_client_api_succeed) { + return TSI_INTERNAL_ERROR; + } + /* Note that the alts_tsi_handshaker needs to set its + * has_sent_start_message field to true + * before the call to alts_handshaker_client_start is made, + * because it's unsafe to access it afterwards. */ + alts_handshaker_client_check_fields_for_testing( + client, on_client_start_success_cb, nullptr, true, nullptr); + /* Populate handshaker response for client_start request. */ + grpc_byte_buffer** recv_buffer_ptr = + alts_handshaker_client_get_recv_buffer_addr_for_testing(client); + *recv_buffer_ptr = generate_handshaker_response(CLIENT_START); + cb_event = client; + signal(&caller_to_tsi_notification); + return TSI_OK; +} + +static void mock_shutdown(alts_handshaker_client* /*self*/) {} + +static tsi_result mock_server_start(alts_handshaker_client* client, + grpc_slice* bytes_received) { + if (!should_handshaker_client_api_succeed) { + return TSI_INTERNAL_ERROR; + } + alts_handshaker_client_check_fields_for_testing( + client, on_server_start_success_cb, nullptr, true, nullptr); + grpc_slice slice = grpc_empty_slice(); + GPR_ASSERT(grpc_slice_cmp(*bytes_received, slice) == 0); + /* Populate handshaker response for server_start request. */ + grpc_byte_buffer** recv_buffer_ptr = + alts_handshaker_client_get_recv_buffer_addr_for_testing(client); + *recv_buffer_ptr = generate_handshaker_response(SERVER_START); + cb_event = client; + grpc_slice_unref(slice); + signal(&caller_to_tsi_notification); + return TSI_OK; +} + +static tsi_result mock_next(alts_handshaker_client* client, + grpc_slice* bytes_received) { + if (!should_handshaker_client_api_succeed) { + return TSI_INTERNAL_ERROR; + } + alts_tsi_handshaker* handshaker = + alts_handshaker_client_get_handshaker_for_testing(client); + bool is_client = alts_tsi_handshaker_get_is_client_for_testing(handshaker); + tsi_handshaker_on_next_done_cb cb = + is_client ? on_client_next_success_cb : on_server_next_success_cb; + alts_handshaker_client_set_cb_for_testing(client, cb); + alts_handshaker_client_set_recv_bytes_for_testing(client, bytes_received); + alts_handshaker_client_check_fields_for_testing(client, cb, nullptr, true, + bytes_received); + GPR_ASSERT(bytes_received != nullptr); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*bytes_received), + ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + GRPC_SLICE_LENGTH(*bytes_received)) == 0); + /* Populate handshaker response for next request. */ + grpc_slice out_frame = + grpc_slice_from_static_string(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME); + grpc_byte_buffer** recv_buffer_ptr = + alts_handshaker_client_get_recv_buffer_addr_for_testing(client); + *recv_buffer_ptr = is_client ?
generate_handshaker_response(CLIENT_NEXT) + : generate_handshaker_response(SERVER_NEXT); + alts_handshaker_client_set_recv_bytes_for_testing(client, &out_frame); + cb_event = client; + signal(&caller_to_tsi_notification); + grpc_slice_unref(out_frame); + return TSI_OK; +} + +static void mock_destruct(alts_handshaker_client* /*client*/) {} + +static alts_handshaker_client_vtable vtable = {mock_client_start, + mock_server_start, mock_next, + mock_shutdown, mock_destruct}; + +static tsi_handshaker* create_test_handshaker(bool is_client) { + tsi_handshaker* handshaker = nullptr; + grpc_alts_credentials_options* options = + grpc_alts_credentials_client_options_create(); + alts_tsi_handshaker_create(options, "target_name", + ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING, is_client, + nullptr, &handshaker, 0); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + alts_tsi_handshaker_set_client_vtable_for_testing(alts_handshaker, &vtable); + grpc_alts_credentials_options_destroy(options); + return handshaker; +} + +static void run_tsi_handshaker_destroy_with_exec_ctx( + tsi_handshaker* handshaker) { + grpc_core::ExecCtx exec_ctx; + tsi_handshaker_destroy(handshaker); +} + +static void check_handshaker_next_invalid_input() { + /* Initialization. */ + tsi_handshaker* handshaker = create_test_handshaker(true); + /* Check nullptr handshaker. */ + GPR_ASSERT(tsi_handshaker_next(nullptr, nullptr, 0, nullptr, nullptr, nullptr, + check_must_not_be_called, + nullptr) == TSI_INVALID_ARGUMENT); + /* Check nullptr callback. */ + GPR_ASSERT(tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, + nullptr, nullptr, + nullptr) == TSI_INVALID_ARGUMENT); + /* Cleanup. */ + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); +} + +static void check_handshaker_shutdown_invalid_input() { + /* Initialization. */ + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + /* Check nullptr handshaker. */ + tsi_handshaker_shutdown(nullptr); + /* Cleanup. */ + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); +} + +static void check_handshaker_next_success() { + /** + * Create handshakers for which internal mock client is going to do + * correctness check. + */ + tsi_handshaker* client_handshaker = + create_test_handshaker(true /* is_client */); + tsi_handshaker* server_handshaker = + create_test_handshaker(false /* is_client */); + /* Client start. */ + GPR_ASSERT(tsi_handshaker_next(client_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, on_client_start_success_cb, + nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Client next. */ + GPR_ASSERT(tsi_handshaker_next( + client_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, on_client_next_success_cb, nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Server start. */ + GPR_ASSERT(tsi_handshaker_next(server_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, on_server_start_success_cb, + nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Server next. */ + GPR_ASSERT(tsi_handshaker_next( + server_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, on_server_next_success_cb, nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Cleanup. 
*/ + run_tsi_handshaker_destroy_with_exec_ctx(server_handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(client_handshaker); +} + +static void check_handshaker_next_with_shutdown() { + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client*/); + /* next(success) -- shutdown(success) -- next (fail) */ + GPR_ASSERT(tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, + nullptr, on_client_start_success_cb, + nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + tsi_handshaker_shutdown(handshaker); + GPR_ASSERT(tsi_handshaker_next( + handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, on_client_next_success_cb, + nullptr) == TSI_HANDSHAKE_SHUTDOWN); + /* Cleanup. */ + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); +} + +static void check_handle_response_with_shutdown(void* /*unused*/) { + wait(&caller_to_tsi_notification); + alts_handshaker_client_handle_response(cb_event, true /* is_ok */); +} + +static void check_handshaker_next_failure() { + /** + * Create handshakers for which internal mock client is always going to fail. + */ + tsi_handshaker* client_handshaker = + create_test_handshaker(true /* is_client */); + tsi_handshaker* server_handshaker = + create_test_handshaker(false /* is_client */); + /* Client start. */ + GPR_ASSERT(tsi_handshaker_next(client_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Server start. */ + GPR_ASSERT(tsi_handshaker_next(server_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Server next. */ + GPR_ASSERT(tsi_handshaker_next( + server_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Client next. */ + GPR_ASSERT(tsi_handshaker_next( + client_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Cleanup. */ + run_tsi_handshaker_destroy_with_exec_ctx(server_handshaker); + run_tsi_handshaker_destroy_with_exec_ctx(client_handshaker); +} + +static void on_invalid_input_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_INTERNAL_ERROR); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void on_failed_grpc_call_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_INTERNAL_ERROR); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_nullptr_handshaker() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. 
+ */ + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, nullptr, + on_client_start_success_cb, nullptr); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + grpc_slice slice = grpc_empty_slice(); + grpc_byte_buffer* recv_buffer = grpc_raw_byte_buffer_create(&slice, 1); + alts_handshaker_client* client = + alts_tsi_handshaker_get_client_for_testing(alts_handshaker); + /* Check nullptr handshaker. */ + alts_handshaker_client_set_fields_for_testing(client, nullptr, + on_invalid_input_cb, nullptr, + recv_buffer, GRPC_STATUS_OK); + alts_handshaker_client_handle_response(client, true); + /* Note: here and elsewhere in this test, we first ref the handshaker in order + * to match the unref that on_status_received will do. This necessary + * because this test mocks out the grpc call in such a way that the code + * path that would usually take this ref is skipped. */ + alts_handshaker_client_ref_for_testing(client); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + /* Cleanup. */ + grpc_slice_unref(slice); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +static void check_handle_response_nullptr_recv_bytes() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, nullptr, + on_client_start_success_cb, nullptr); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + alts_handshaker_client* client = + alts_tsi_handshaker_get_client_for_testing(alts_handshaker); + /* Check nullptr recv_bytes. */ + alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, + on_invalid_input_cb, nullptr, + nullptr, GRPC_STATUS_OK); + alts_handshaker_client_handle_response(client, true); + alts_handshaker_client_ref_for_testing(client); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + /* Cleanup. */ + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +static void check_handle_response_failed_grpc_call_to_handshaker_service() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, nullptr, + on_client_start_success_cb, nullptr); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + grpc_slice slice = grpc_empty_slice(); + grpc_byte_buffer* recv_buffer = grpc_raw_byte_buffer_create(&slice, 1); + alts_handshaker_client* client = + alts_tsi_handshaker_get_client_for_testing(alts_handshaker); + /* Check failed grpc call made to handshaker service. 
*/ + alts_handshaker_client_set_fields_for_testing( + client, alts_handshaker, on_failed_grpc_call_cb, nullptr, recv_buffer, + GRPC_STATUS_UNKNOWN); + alts_handshaker_client_handle_response(client, true); + alts_handshaker_client_ref_for_testing(client); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + client, GRPC_STATUS_UNKNOWN, GRPC_ERROR_NONE); + } + /* Cleanup. */ + grpc_slice_unref(slice); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +static void +check_handle_response_failed_recv_message_from_handshaker_service() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, nullptr, + on_client_start_success_cb, nullptr); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + grpc_slice slice = grpc_empty_slice(); + grpc_byte_buffer* recv_buffer = grpc_raw_byte_buffer_create(&slice, 1); + alts_handshaker_client* client = + alts_tsi_handshaker_get_client_for_testing(alts_handshaker); + /* Check failed recv message op from handshaker service. */ + alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, + on_failed_grpc_call_cb, nullptr, + recv_buffer, GRPC_STATUS_OK); + alts_handshaker_client_handle_response(client, false); + alts_handshaker_client_ref_for_testing(client); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + /* Cleanup. */ + grpc_slice_unref(slice); + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +static void on_invalid_resp_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_DATA_CORRUPTED); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_invalid_resp() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, nullptr, + on_client_start_success_cb, nullptr); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + alts_handshaker_client* client = + alts_tsi_handshaker_get_client_for_testing(alts_handshaker); + /* Tests. 
*/ + grpc_byte_buffer* recv_buffer = generate_handshaker_response(INVALID); + alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, + on_invalid_resp_cb, nullptr, + recv_buffer, GRPC_STATUS_OK); + alts_handshaker_client_handle_response(client, true); + alts_handshaker_client_ref_for_testing(client); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + /* Cleanup. */ + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +static void check_handle_response_success(void* /*unused*/) { + /* Client start. */ + wait(&caller_to_tsi_notification); + alts_handshaker_client_handle_response(cb_event, true /* is_ok */); + /* Client next. */ + wait(&caller_to_tsi_notification); + alts_handshaker_client_handle_response(cb_event, true /* is_ok */); + alts_handshaker_client_ref_for_testing(cb_event); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + cb_event, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + /* Server start. */ + wait(&caller_to_tsi_notification); + alts_handshaker_client_handle_response(cb_event, true /* is_ok */); + /* Server next. */ + wait(&caller_to_tsi_notification); + alts_handshaker_client_handle_response(cb_event, true /* is_ok */); + alts_handshaker_client_ref_for_testing(cb_event); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + cb_event, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } +} + +static void on_failed_resp_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_INVALID_ARGUMENT); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_failure() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, nullptr, + on_client_start_success_cb, nullptr); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + alts_handshaker_client* client = + alts_tsi_handshaker_get_client_for_testing(alts_handshaker); + /* Tests. */ + grpc_byte_buffer* recv_buffer = generate_handshaker_response(FAILED); + alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, + on_failed_resp_cb, nullptr, + recv_buffer, GRPC_STATUS_OK); + alts_handshaker_client_handle_response(client, true /* is_ok*/); + alts_handshaker_client_ref_for_testing(client); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + /* Cleanup. 
*/ + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +static void on_shutdown_resp_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_HANDSHAKE_SHUTDOWN); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_after_shutdown() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + tsi_handshaker* handshaker = create_test_handshaker(true /* is_client */); + tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, nullptr, + on_client_start_success_cb, nullptr); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast(handshaker); + alts_handshaker_client* client = + alts_tsi_handshaker_get_client_for_testing(alts_handshaker); + grpc_byte_buffer** recv_buffer_ptr = + alts_handshaker_client_get_recv_buffer_addr_for_testing(client); + grpc_byte_buffer_destroy(*recv_buffer_ptr); + + /* Tests. */ + tsi_handshaker_shutdown(handshaker); + grpc_byte_buffer* recv_buffer = generate_handshaker_response(CLIENT_START); + alts_handshaker_client_set_fields_for_testing(client, alts_handshaker, + on_shutdown_resp_cb, nullptr, + recv_buffer, GRPC_STATUS_OK); + alts_handshaker_client_handle_response(client, true); + alts_handshaker_client_ref_for_testing(client); + { + grpc_core::ExecCtx exec_ctx; + alts_handshaker_client_on_status_received_for_testing( + client, GRPC_STATUS_OK, GRPC_ERROR_NONE); + } + /* Cleanup. */ + run_tsi_handshaker_destroy_with_exec_ctx(handshaker); + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +void check_handshaker_next_fails_after_shutdown() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + cb_event = nullptr; + /* Tests. */ + grpc_core::Thread thd("alts_tsi_handshaker_test", + &check_handle_response_with_shutdown, nullptr); + thd.Start(); + check_handshaker_next_with_shutdown(); + thd.Join(); + /* Cleanup. */ + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +void check_handshaker_success() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + /* Tests. */ + grpc_core::Thread thd("alts_tsi_handshaker_test", + &check_handle_response_success, nullptr); + thd.Start(); + check_handshaker_next_success(); + thd.Join(); + /* Cleanup. */ + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + /* Initialization. */ + grpc_init(); + grpc_alts_shared_resource_dedicated_init(); + /* Tests. 
*/ + should_handshaker_client_api_succeed = true; + check_handshaker_success(); + check_handshaker_next_invalid_input(); + check_handshaker_next_fails_after_shutdown(); + check_handle_response_after_shutdown(); + should_handshaker_client_api_succeed = false; + check_handshaker_shutdown_invalid_input(); + check_handshaker_next_failure(); + check_handle_response_nullptr_handshaker(); + check_handle_response_nullptr_recv_bytes(); + check_handle_response_failed_grpc_call_to_handshaker_service(); + check_handle_response_failed_recv_message_from_handshaker_service(); + check_handle_response_invalid_resp(); + check_handle_response_failure(); + /* Cleanup. */ + grpc_alts_shared_resource_dedicated_shutdown(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/alts/handshaker/alts_tsi_utils_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_utils_test.cc new file mode 100644 index 00000000..54fb0cd6 --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_tsi_utils_test.cc @@ -0,0 +1,81 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" + +#include "upb/upb.hpp" + +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" +#include "test/core/util/test_config.h" + +#define ALTS_TSI_UTILS_TEST_OUT_FRAME "Hello Google" + +static void convert_to_tsi_result_test() { + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_OK) == TSI_OK); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_UNKNOWN) == + TSI_UNKNOWN_ERROR); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result( + GRPC_STATUS_INVALID_ARGUMENT) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_OUT_OF_RANGE) == + TSI_UNKNOWN_ERROR); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_INTERNAL) == + TSI_INTERNAL_ERROR); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_NOT_FOUND) == + TSI_NOT_FOUND); +} + +static void deserialize_response_test() { + upb::Arena arena; + grpc_gcp_HandshakerResp* resp = grpc_gcp_HandshakerResp_new(arena.ptr()); + grpc_gcp_HandshakerResp_set_out_frames( + resp, upb_strview_makez(ALTS_TSI_UTILS_TEST_OUT_FRAME)); + size_t buf_len; + char* buf = grpc_gcp_HandshakerResp_serialize(resp, arena.ptr(), &buf_len); + grpc_slice slice = grpc_slice_from_copied_buffer(buf, buf_len); + + /* Valid serialization. */ + upb::Arena arena2; + grpc_byte_buffer* buffer = + grpc_raw_byte_buffer_create(&slice, 1 /* number of slices */); + grpc_gcp_HandshakerResp* decoded_resp = + alts_tsi_utils_deserialize_response(buffer, arena2.ptr()); + GPR_ASSERT(grpc_gcp_handshaker_resp_equals(resp, decoded_resp)); + grpc_byte_buffer_destroy(buffer); + + /* Invalid serialization. 
*/ + grpc_slice bad_slice = + grpc_slice_split_head(&slice, GRPC_SLICE_LENGTH(slice) - 1); + buffer = grpc_raw_byte_buffer_create(&bad_slice, 1 /* number of slices */); + GPR_ASSERT(alts_tsi_utils_deserialize_response(buffer, arena2.ptr()) == + nullptr); + + /* Clean up. */ + grpc_slice_unref(slice); + grpc_slice_unref(bad_slice); + grpc_byte_buffer_destroy(buffer); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + /* Tests. */ + grpc_init(); + deserialize_response_test(); + convert_to_tsi_result_test(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/alts/handshaker/transport_security_common_api_test.cc b/test/core/tsi/alts/handshaker/transport_security_common_api_test.cc new file mode 100644 index 00000000..7579ed10 --- /dev/null +++ b/test/core/tsi/alts/handshaker/transport_security_common_api_test.cc @@ -0,0 +1,166 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/alts/handshaker/transport_security_common_api.h" + +#include +#include +#include + +const size_t kMaxRpcVersionMajor = 3; +const size_t kMaxRpcVersionMinor = 2; +const size_t kMinRpcVersionMajor = 2; +const size_t kMinRpcVersionMinor = 1; + +static bool grpc_gcp_rpc_protocol_versions_equal( + grpc_gcp_rpc_protocol_versions* l_versions, + grpc_gcp_rpc_protocol_versions* r_versions) { + GPR_ASSERT(l_versions != nullptr && r_versions != nullptr); + if ((l_versions->max_rpc_version.major != + r_versions->max_rpc_version.major) || + (l_versions->max_rpc_version.minor != + r_versions->max_rpc_version.minor)) { + return false; + } + if ((l_versions->min_rpc_version.major != + r_versions->min_rpc_version.major) || + (l_versions->min_rpc_version.minor != + r_versions->min_rpc_version.minor)) { + return false; + } + return true; +} + +static void test_success() { + grpc_gcp_rpc_protocol_versions version; + grpc_gcp_rpc_protocol_versions decoded_version; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max( + &version, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min( + &version, kMinRpcVersionMajor, kMinRpcVersionMinor)); + /* Serializes to grpc slice. */ + grpc_slice encoded_slice; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode(&version, &encoded_slice)); + /* Deserializes and compares with the original version. */ + GPR_ASSERT( + grpc_gcp_rpc_protocol_versions_decode(encoded_slice, &decoded_version)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_equal(&version, &decoded_version)); + grpc_slice_unref(encoded_slice); +} + +static void test_failure() { + grpc_gcp_rpc_protocol_versions version, decoded_version; + grpc_slice encoded_slice; + /* Test for invalid arguments. 
*/ + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_set_max( + nullptr, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_set_min( + nullptr, kMinRpcVersionMajor, kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max( + &version, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min( + &version, kMinRpcVersionMajor, kMinRpcVersionMinor)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_encode(nullptr, &encoded_slice)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_encode(&version, nullptr)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_decode(encoded_slice, nullptr)); + /* Test for upb decode. */ + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode(&version, &encoded_slice)); + grpc_slice bad_slice = grpc_slice_split_head( + &encoded_slice, GRPC_SLICE_LENGTH(encoded_slice) - 1); + grpc_slice_unref(encoded_slice); + GPR_ASSERT( + !grpc_gcp_rpc_protocol_versions_decode(bad_slice, &decoded_version)); + grpc_slice_unref(bad_slice); +} + +static void test_copy() { + grpc_gcp_rpc_protocol_versions src; + grpc_gcp_rpc_protocol_versions des; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&src, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&src, kMinRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_copy(&src, &des)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_equal(&src, &des)); +} + +static void test_check_success() { + grpc_gcp_rpc_protocol_versions v1; + grpc_gcp_rpc_protocol_versions v2; + grpc_gcp_rpc_protocol_versions_version highest_common_version; + /* test equality. */ + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v1, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v1, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_check( + (const grpc_gcp_rpc_protocol_versions*)&v1, + (const grpc_gcp_rpc_protocol_versions*)&v2, + &highest_common_version) == 1); + GPR_ASSERT(grpc_core::internal::grpc_gcp_rpc_protocol_version_compare( + &highest_common_version, &v1.max_rpc_version) == 0); + + /* test inequality. 
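(What grpc_gcp_rpc_protocol_versions_check is relied on to do here: versions compare lexicographically by (major, minor); the check succeeds iff the two ranges overlap, i.e. max(l.min, r.min) <= min(l.max, r.max), and the highest common version it reports is min(l.max, r.max). Worked through for the values below: v1 = [(1,1), (3,2)] and v2 = [(2,2), (3,1)], so the common range is [(2,2), (3,1)] and the expected highest common version is (3,1), which equals v2.max_rpc_version; hence the final compare against &v2.max_rpc_version.)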
*/ + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v1, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v1, kMinRpcVersionMinor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v2, kMaxRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v2, kMinRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_check( + (const grpc_gcp_rpc_protocol_versions*)&v1, + (const grpc_gcp_rpc_protocol_versions*)&v2, + &highest_common_version) == 1); + GPR_ASSERT(grpc_core::internal::grpc_gcp_rpc_protocol_version_compare( + &highest_common_version, &v2.max_rpc_version) == 0); +} + +static void test_check_failure() { + grpc_gcp_rpc_protocol_versions v1; + grpc_gcp_rpc_protocol_versions v2; + grpc_gcp_rpc_protocol_versions_version highest_common_version; + + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v1, kMinRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v1, kMinRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_check( + (const grpc_gcp_rpc_protocol_versions*)&v1, + (const grpc_gcp_rpc_protocol_versions*)&v2, + &highest_common_version) == 0); +} + +int main(int /*argc*/, char** /*argv*/) { + /* Run tests. */ + test_success(); + test_failure(); + test_copy(); + test_check_success(); + test_check_failure(); + return 0; +} diff --git a/test/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_test.cc b/test/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_test.cc new file mode 100644 index 00000000..8b14ea0e --- /dev/null +++ b/test/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_test.cc @@ -0,0 +1,462 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol.h" + +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_integrity_only_record_protocol.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_privacy_integrity_record_protocol.h" +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.h" +#include "test/core/tsi/alts/crypt/gsec_test_util.h" +#include "test/core/util/test_config.h" + +constexpr size_t kMaxSliceLength = 256; +constexpr size_t kMaxSlices = 10; +constexpr size_t kSealRepeatTimes = 5; +constexpr size_t kTagLength = 16; + +/* Test fixtures for each test cases. 
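(defined after the short sketch below.) */ + +/* The seal/unseal tests in this file expect a protected frame to be exactly the payload plus a fixed header and tag overhead, for both the integrity-only and the privacy-integrity record protocols. A minimal sketch of that arithmetic; the helper name is hypothetical and the function is illustrative only, it is not used by the tests. */ +static size_t sketch_expected_protected_length(size_t data_length) { +  return alts_iovec_record_protocol_get_header_length() + data_length + +         kTagLength; +} + +/* Test fixtures for each test case.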
*/ +struct alts_grpc_record_protocol_test_fixture { + alts_grpc_record_protocol* client_protect; + alts_grpc_record_protocol* client_unprotect; + alts_grpc_record_protocol* server_protect; + alts_grpc_record_protocol* server_unprotect; +}; + +/* Test input variables for protect/unprotect operations. */ +struct alts_grpc_record_protocol_test_var { + size_t header_length; + size_t tag_length; + grpc_slice_buffer original_sb; + grpc_slice_buffer duplicate_sb; + grpc_slice_buffer protected_sb; + grpc_slice_buffer unprotected_sb; +}; + +/* --- Test utility functions. --- */ + +static void create_random_slice_buffer(grpc_slice_buffer* sb) { + GPR_ASSERT(sb != nullptr); + size_t slice_count = gsec_test_bias_random_uint32(kMaxSlices) + 1; + for (size_t i = 0; i < slice_count; i++) { + size_t slice_length = gsec_test_bias_random_uint32(kMaxSliceLength) + 1; + grpc_slice slice = GRPC_SLICE_MALLOC(slice_length); + gsec_test_random_bytes(GRPC_SLICE_START_PTR(slice), slice_length); + grpc_slice_buffer_add(sb, slice); + } +} + +static uint8_t* pointer_to_nth_byte(grpc_slice_buffer* sb, size_t index) { + GPR_ASSERT(sb != nullptr); + GPR_ASSERT(index < sb->length); + for (size_t i = 0; i < sb->count; i++) { + if (index < GRPC_SLICE_LENGTH(sb->slices[i])) { + return GRPC_SLICE_START_PTR(sb->slices[i]) + index; + } else { + index -= GRPC_SLICE_LENGTH(sb->slices[i]); + } + } + return nullptr; +} + +/* Checks if two slice buffer contents are the same. It is not super efficient, + * but OK for testing. */ +static bool are_slice_buffers_equal(grpc_slice_buffer* first, + grpc_slice_buffer* second) { + GPR_ASSERT(first != nullptr); + GPR_ASSERT(second != nullptr); + if (first->length != second->length) { + return false; + } + for (size_t i = 0; i < first->length; i++) { + uint8_t* first_ptr = pointer_to_nth_byte(first, i); + uint8_t* second_ptr = pointer_to_nth_byte(second, i); + GPR_ASSERT(first_ptr != nullptr); + GPR_ASSERT(second_ptr != nullptr); + if ((*first_ptr) != (*second_ptr)) { + return false; + } + } + return true; +} + +static void alter_random_byte(grpc_slice_buffer* sb) { + GPR_ASSERT(sb != nullptr); + if (sb->length == 0) { + return; + } + uint32_t offset = + gsec_test_bias_random_uint32(static_cast(sb->length)); + uint8_t* ptr = pointer_to_nth_byte(sb, offset); + (*ptr)++; +} + +static alts_grpc_record_protocol_test_fixture* +test_fixture_integrity_only_create(bool rekey, bool extra_copy) { + alts_grpc_record_protocol_test_fixture* fixture = + static_cast( + gpr_zalloc(sizeof(alts_grpc_record_protocol_test_fixture))); + size_t key_length = rekey ? kAes128GcmRekeyKeyLength : kAes128GcmKeyLength; + uint8_t* key; + gsec_test_random_array(&key, key_length); + gsec_aead_crypter* crypter = nullptr; + + /* Create client record protocol for protect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_integrity_only_record_protocol_create( + crypter, 8, /*is_client=*/true, /*is_protect=*/true, + extra_copy, &fixture->client_protect) == TSI_OK); + /* Create client record protocol for unprotect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_integrity_only_record_protocol_create( + crypter, 8, /*is_client=*/true, /*is_protect=*/false, + extra_copy, &fixture->client_unprotect) == TSI_OK); + /* Create server record protocol for protect. 
*/ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_integrity_only_record_protocol_create( + crypter, 8, /*is_client=*/false, /*is_protect=*/true, + extra_copy, &fixture->server_protect) == TSI_OK); + /* Create server record protocol for unprotect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_integrity_only_record_protocol_create( + crypter, 8, /*is_client=*/false, /*is_protect=*/false, + extra_copy, &fixture->server_unprotect) == TSI_OK); + + gpr_free(key); + return fixture; +} + +static alts_grpc_record_protocol_test_fixture* +test_fixture_integrity_only_no_rekey_no_extra_copy_create() { + return test_fixture_integrity_only_create(false, false); +} + +static alts_grpc_record_protocol_test_fixture* +test_fixture_integrity_only_rekey_create() { + return test_fixture_integrity_only_create(true, false); +} + +static alts_grpc_record_protocol_test_fixture* +test_fixture_integrity_only_extra_copy_create() { + return test_fixture_integrity_only_create(false, true); +} + +static alts_grpc_record_protocol_test_fixture* +test_fixture_privacy_integrity_create(bool rekey) { + alts_grpc_record_protocol_test_fixture* fixture = + static_cast( + gpr_zalloc(sizeof(alts_grpc_record_protocol_test_fixture))); + size_t key_length = rekey ? kAes128GcmRekeyKeyLength : kAes128GcmKeyLength; + uint8_t* key; + gsec_test_random_array(&key, key_length); + gsec_aead_crypter* crypter = nullptr; + + /* Create client record protocol for protect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_privacy_integrity_record_protocol_create( + crypter, 8, /*is_client=*/true, /*is_protect=*/true, + &fixture->client_protect) == TSI_OK); + /* Create client record protocol for unprotect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_privacy_integrity_record_protocol_create( + crypter, 8, /*is_client=*/true, /*is_protect=*/false, + &fixture->client_unprotect) == TSI_OK); + /* Create server record protocol for protect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_privacy_integrity_record_protocol_create( + crypter, 8, /*is_client=*/false, /*is_protect=*/true, + &fixture->server_protect) == TSI_OK); + /* Create server record protocol for unprotect. 
*/ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_grpc_privacy_integrity_record_protocol_create( + crypter, 8, /*is_client=*/false, /*is_protect=*/false, + &fixture->server_unprotect) == TSI_OK); + + gpr_free(key); + return fixture; +} + +static alts_grpc_record_protocol_test_fixture* +test_fixture_privacy_integrity_no_rekey_create() { + return test_fixture_privacy_integrity_create(false); +} + +static alts_grpc_record_protocol_test_fixture* +test_fixture_privacy_integrity_rekey_create() { + return test_fixture_privacy_integrity_create(true); +} + +static void alts_grpc_record_protocol_test_fixture_destroy( + alts_grpc_record_protocol_test_fixture* fixture) { + if (fixture == nullptr) { + return; + } + grpc_core::ExecCtx exec_ctx; + alts_grpc_record_protocol_destroy(fixture->client_protect); + alts_grpc_record_protocol_destroy(fixture->client_unprotect); + alts_grpc_record_protocol_destroy(fixture->server_protect); + alts_grpc_record_protocol_destroy(fixture->server_unprotect); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(fixture); +} + +static alts_grpc_record_protocol_test_var* +alts_grpc_record_protocol_test_var_create() { + alts_grpc_record_protocol_test_var* var = + static_cast( + gpr_zalloc(sizeof(alts_grpc_record_protocol_test_var))); + var->header_length = alts_iovec_record_protocol_get_header_length(); + var->tag_length = kTagLength; + /* Initialized slice buffers. */ + grpc_slice_buffer_init(&var->original_sb); + grpc_slice_buffer_init(&var->duplicate_sb); + grpc_slice_buffer_init(&var->protected_sb); + grpc_slice_buffer_init(&var->unprotected_sb); + /* Randomly sets content of original_sb, and copies into duplicate_sb. */ + create_random_slice_buffer(&var->original_sb); + for (size_t i = 0; i < var->original_sb.count; i++) { + grpc_slice_buffer_add(&var->duplicate_sb, + grpc_slice_ref(var->original_sb.slices[i])); + } + return var; +} + +static void alts_grpc_record_protocol_test_var_destroy( + alts_grpc_record_protocol_test_var* var) { + if (var == nullptr) { + return; + } + grpc_slice_buffer_destroy_internal(&var->original_sb); + grpc_slice_buffer_destroy_internal(&var->duplicate_sb); + grpc_slice_buffer_destroy_internal(&var->protected_sb); + grpc_slice_buffer_destroy_internal(&var->unprotected_sb); + gpr_free(var); +} + +/* --- alts grpc record protocol tests. --- */ + +static void random_seal_unseal(alts_grpc_record_protocol* sender, + alts_grpc_record_protocol* receiver) { + grpc_core::ExecCtx exec_ctx; + for (size_t i = 0; i < kSealRepeatTimes; i++) { + alts_grpc_record_protocol_test_var* var = + alts_grpc_record_protocol_test_var_create(); + /* Seals and then unseals. 
*/ + size_t data_length = var->original_sb.length; + tsi_result status = alts_grpc_record_protocol_protect( + sender, &var->original_sb, &var->protected_sb); + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(var->protected_sb.length == + data_length + var->header_length + var->tag_length); + status = alts_grpc_record_protocol_unprotect(receiver, &var->protected_sb, + &var->unprotected_sb); + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT( + are_slice_buffers_equal(&var->unprotected_sb, &var->duplicate_sb)); + alts_grpc_record_protocol_test_var_destroy(var); + } + grpc_core::ExecCtx::Get()->Flush(); +} + +static void empty_seal_unseal(alts_grpc_record_protocol* sender, + alts_grpc_record_protocol* receiver) { + grpc_core::ExecCtx exec_ctx; + for (size_t i = 0; i < kSealRepeatTimes; i++) { + alts_grpc_record_protocol_test_var* var = + alts_grpc_record_protocol_test_var_create(); + /* Seals and then unseals empty payload. */ + grpc_slice_buffer_reset_and_unref_internal(&var->original_sb); + grpc_slice_buffer_reset_and_unref_internal(&var->duplicate_sb); + tsi_result status = alts_grpc_record_protocol_protect( + sender, &var->original_sb, &var->protected_sb); + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(var->protected_sb.length == + var->header_length + var->tag_length); + status = alts_grpc_record_protocol_unprotect(receiver, &var->protected_sb, + &var->unprotected_sb); + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT( + are_slice_buffers_equal(&var->unprotected_sb, &var->duplicate_sb)); + alts_grpc_record_protocol_test_var_destroy(var); + } + grpc_core::ExecCtx::Get()->Flush(); +} + +static void unsync_seal_unseal(alts_grpc_record_protocol* sender, + alts_grpc_record_protocol* receiver) { + grpc_core::ExecCtx exec_ctx; + tsi_result status; + alts_grpc_record_protocol_test_var* var = + alts_grpc_record_protocol_test_var_create(); + /* Seals once. */ + status = alts_grpc_record_protocol_protect(sender, &var->original_sb, + &var->protected_sb); + GPR_ASSERT(status == TSI_OK); + grpc_slice_buffer_reset_and_unref_internal(&var->protected_sb); + /* Seals again. */ + status = alts_grpc_record_protocol_protect(sender, &var->duplicate_sb, + &var->protected_sb); + GPR_ASSERT(status == TSI_OK); + /* Unseals the second frame. */ + status = alts_grpc_record_protocol_unprotect(receiver, &var->protected_sb, + &var->unprotected_sb); + GPR_ASSERT(status == TSI_INTERNAL_ERROR); + alts_grpc_record_protocol_test_var_destroy(var); + grpc_core::ExecCtx::Get()->Flush(); +} + +static void corrupted_data(alts_grpc_record_protocol* sender, + alts_grpc_record_protocol* receiver) { + grpc_core::ExecCtx exec_ctx; + tsi_result status; + alts_grpc_record_protocol_test_var* var = + alts_grpc_record_protocol_test_var_create(); + /* Seals once. */ + status = alts_grpc_record_protocol_protect(sender, &var->original_sb, + &var->protected_sb); + GPR_ASSERT(status == TSI_OK); + /* Corrupts one byte in protected_sb and tries to unprotect. */ + alter_random_byte(&var->protected_sb); + status = alts_grpc_record_protocol_unprotect(receiver, &var->protected_sb, + &var->unprotected_sb); + GPR_ASSERT(status == TSI_INTERNAL_ERROR); + alts_grpc_record_protocol_test_var_destroy(var); + grpc_core::ExecCtx::Get()->Flush(); +} + +static void input_check(alts_grpc_record_protocol* rp) { + grpc_core::ExecCtx exec_ctx; + tsi_result status; + alts_grpc_record_protocol_test_var* var = + alts_grpc_record_protocol_test_var_create(); + /* Protects with nullptr input. 
*/ + status = alts_grpc_record_protocol_protect(rp, nullptr, &var->protected_sb); + GPR_ASSERT(status == TSI_INVALID_ARGUMENT); + status = alts_grpc_record_protocol_protect(rp, &var->original_sb, nullptr); + GPR_ASSERT(status == TSI_INVALID_ARGUMENT); + /* Unprotects with nullptr input. */ + status = alts_grpc_record_protocol_protect(rp, &var->original_sb, + &var->protected_sb); + GPR_ASSERT(status == TSI_OK); + status = + alts_grpc_record_protocol_unprotect(rp, nullptr, &var->unprotected_sb); + GPR_ASSERT(status == TSI_INVALID_ARGUMENT); + status = alts_grpc_record_protocol_unprotect(rp, &var->protected_sb, nullptr); + GPR_ASSERT(status == TSI_INVALID_ARGUMENT); + /* Unprotects on a temporary slice buffer which length is smaller than header + * length plus tag length. */ + grpc_slice_buffer temp_sb; + grpc_slice_buffer_init(&temp_sb); + grpc_slice_buffer_move_first( + &var->protected_sb, var->header_length + var->tag_length - 1, &temp_sb); + status = + alts_grpc_record_protocol_unprotect(rp, &temp_sb, &var->unprotected_sb); + GPR_ASSERT(status == TSI_INVALID_ARGUMENT); + grpc_slice_buffer_destroy_internal(&temp_sb); + alts_grpc_record_protocol_test_var_destroy(var); + grpc_core::ExecCtx::Get()->Flush(); +} + +/* --- Test cases. --- */ + +static void alts_grpc_record_protocol_random_seal_unseal_tests( + alts_grpc_record_protocol_test_fixture* fixture) { + random_seal_unseal(fixture->client_protect, fixture->server_unprotect); + random_seal_unseal(fixture->server_protect, fixture->client_unprotect); +} + +static void alts_grpc_record_protocol_empty_seal_unseal_tests( + alts_grpc_record_protocol_test_fixture* fixture) { + empty_seal_unseal(fixture->client_protect, fixture->server_unprotect); + empty_seal_unseal(fixture->server_protect, fixture->client_unprotect); +} + +static void alts_grpc_record_protocol_unsync_seal_unseal_tests( + alts_grpc_record_protocol_test_fixture* fixture) { + unsync_seal_unseal(fixture->client_protect, fixture->server_unprotect); + unsync_seal_unseal(fixture->server_protect, fixture->client_unprotect); +} + +static void alts_grpc_record_protocol_corrupted_data_tests( + alts_grpc_record_protocol_test_fixture* fixture) { + corrupted_data(fixture->client_protect, fixture->server_unprotect); + corrupted_data(fixture->server_protect, fixture->client_unprotect); +} + +static void alts_grpc_record_protocol_input_check_tests( + alts_grpc_record_protocol_test_fixture* fixture) { + input_check(fixture->client_protect); +} + +static void alts_grpc_record_protocol_tests( + alts_grpc_record_protocol_test_fixture* (*fixture_create)()) { + auto* fixture_1 = fixture_create(); + alts_grpc_record_protocol_random_seal_unseal_tests(fixture_1); + alts_grpc_record_protocol_test_fixture_destroy(fixture_1); + + auto* fixture_2 = fixture_create(); + alts_grpc_record_protocol_empty_seal_unseal_tests(fixture_2); + alts_grpc_record_protocol_test_fixture_destroy(fixture_2); + + auto* fixture_3 = fixture_create(); + alts_grpc_record_protocol_unsync_seal_unseal_tests(fixture_3); + alts_grpc_record_protocol_test_fixture_destroy(fixture_3); + + auto* fixture_4 = fixture_create(); + alts_grpc_record_protocol_corrupted_data_tests(fixture_4); + alts_grpc_record_protocol_test_fixture_destroy(fixture_4); + + auto* fixture_5 = fixture_create(); + alts_grpc_record_protocol_input_check_tests(fixture_5); + alts_grpc_record_protocol_test_fixture_destroy(fixture_5); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + alts_grpc_record_protocol_tests( + 
&test_fixture_integrity_only_no_rekey_no_extra_copy_create); + alts_grpc_record_protocol_tests(&test_fixture_integrity_only_rekey_create); + alts_grpc_record_protocol_tests( + &test_fixture_integrity_only_extra_copy_create); + alts_grpc_record_protocol_tests( + &test_fixture_privacy_integrity_no_rekey_create); + alts_grpc_record_protocol_tests(&test_fixture_privacy_integrity_rekey_create); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol_test.cc b/test/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol_test.cc new file mode 100644 index 00000000..9096b662 --- /dev/null +++ b/test/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol_test.cc @@ -0,0 +1,929 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_iovec_record_protocol.h" + +#include +#include + +#include "test/core/tsi/alts/crypt/gsec_test_util.h" + +constexpr size_t kMaxDataSize = 1024; +constexpr size_t kMaxSlices = 10; +constexpr size_t kSealRepeatTimes = 5; +constexpr size_t kTagLength = 16; + +/* Test fixtures for each test cases. */ +struct alts_iovec_record_protocol_test_fixture { + alts_iovec_record_protocol* client_protect; + alts_iovec_record_protocol* client_unprotect; + alts_iovec_record_protocol* server_protect; + alts_iovec_record_protocol* server_unprotect; +}; + +/* Test variables for protect/unprotect operations. */ +struct alts_iovec_record_protocol_test_var { + uint8_t* header_buf; + size_t header_length; + iovec_t header_iovec; + uint8_t* tag_buf; + size_t tag_length; + iovec_t tag_iovec; + uint8_t* data_buf; + uint8_t* dup_buf; + size_t data_length; + iovec_t* data_iovec; + size_t data_iovec_length; + uint8_t* protected_buf; + iovec_t protected_iovec; + iovec_t unprotected_iovec; +}; + +/* --- Test utility functions. 
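--- */ + +/* The corrupted-data tests below flip single bytes inside the frame header, which consists of a frame-length field (kZeroCopyFrameLengthFieldSize bytes) followed by a message-type field (kZeroCopyFrameMessageTypeFieldSize bytes). A minimal sketch of decoding the length field, assuming the little-endian byte order used by ALTS framing; the helper name is hypothetical and the function is illustrative only, it is not used by the tests. */ +static size_t sketch_read_frame_length(const uint8_t* header_buf) { +  size_t frame_length = 0; +  for (size_t i = 0; i < kZeroCopyFrameLengthFieldSize; i++) { +    frame_length |= static_cast<size_t>(header_buf[i]) << (8 * i); +  } +  return frame_length; +} + +/* --- Test utility functions (continued).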
--- */ + +static void randomly_slice(uint8_t* input, size_t input_length, + iovec_t** output, size_t* output_length) { + if (input_length == 0) { + *output = nullptr; + *output_length = 0; + return; + } + *output_length = gsec_test_bias_random_uint32(kMaxSlices) + 1; + *output = static_cast<iovec_t*>(gpr_malloc(*output_length * sizeof(iovec_t))); + for (size_t i = 0; i < *output_length - 1; i++) { + size_t slice_length = + gsec_test_bias_random_uint32(static_cast<uint32_t>(input_length)); + iovec_t slice = {input, slice_length}; + (*output)[i] = slice; + input += slice_length; + input_length -= slice_length; + } + iovec_t slice = {input, input_length}; + (*output)[*output_length - 1] = slice; +} + +static size_t alter_random_byte(uint8_t* buf, size_t buf_length) { + GPR_ASSERT(buf != nullptr); + uint32_t offset = + gsec_test_bias_random_uint32(static_cast<uint32_t>(buf_length)); + (*(buf + offset))++; + return offset; +} + +static void revert_back_alter(uint8_t* buf, size_t offset) { + GPR_ASSERT(buf != nullptr); + (*(buf + offset))--; +} + +static alts_iovec_record_protocol_test_fixture* +alts_iovec_record_protocol_test_fixture_create(bool rekey, + bool integrity_only) { + alts_iovec_record_protocol_test_fixture* fixture = + static_cast<alts_iovec_record_protocol_test_fixture*>( + gpr_malloc(sizeof(alts_iovec_record_protocol_test_fixture))); + size_t overflow_size = 8; + size_t key_length = rekey ? kAes128GcmRekeyKeyLength : kAes128GcmKeyLength; + uint8_t* key; + gsec_test_random_array(&key, key_length); + gsec_aead_crypter* crypter = nullptr; + /* Create client record protocol for protect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_iovec_record_protocol_create( + crypter, overflow_size, /*is_client=*/true, integrity_only, + /*is_protect=*/true, &fixture->client_protect, + nullptr) == GRPC_STATUS_OK); + /* Create client record protocol for unprotect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_iovec_record_protocol_create( + crypter, overflow_size, /*is_client=*/true, integrity_only, + /*is_protect=*/false, &fixture->client_unprotect, + nullptr) == GRPC_STATUS_OK); + /* Create server record protocol for protect. */ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_iovec_record_protocol_create( + crypter, overflow_size, /*is_client=*/false, integrity_only, + /*is_protect=*/true, &fixture->server_protect, + nullptr) == GRPC_STATUS_OK); + /* Create server record protocol for unprotect. 
*/ + GPR_ASSERT(gsec_aes_gcm_aead_crypter_create( + key, key_length, kAesGcmNonceLength, kAesGcmTagLength, rekey, + &crypter, nullptr) == GRPC_STATUS_OK); + GPR_ASSERT(alts_iovec_record_protocol_create( + crypter, overflow_size, /*is_client=*/false, integrity_only, + /*is_protect=*/false, &fixture->server_unprotect, + nullptr) == GRPC_STATUS_OK); + + gpr_free(key); + return fixture; +} + +static void alts_iovec_record_protocol_test_fixture_destroy( + alts_iovec_record_protocol_test_fixture* fixture) { + if (fixture == nullptr) { + return; + } + alts_iovec_record_protocol_destroy(fixture->client_protect); + alts_iovec_record_protocol_destroy(fixture->client_unprotect); + alts_iovec_record_protocol_destroy(fixture->server_protect); + alts_iovec_record_protocol_destroy(fixture->server_unprotect); + gpr_free(fixture); +} + +static alts_iovec_record_protocol_test_var* +alts_iovec_record_protocol_test_var_create() { + auto* var = static_cast<alts_iovec_record_protocol_test_var*>( + gpr_zalloc(sizeof(alts_iovec_record_protocol_test_var))); + /* Sets header buffer. */ + var->header_length = alts_iovec_record_protocol_get_header_length(); + var->header_buf = static_cast<uint8_t*>(gpr_malloc(var->header_length)); + var->header_iovec.iov_base = var->header_buf; + var->header_iovec.iov_len = var->header_length; + /* Sets tag buffer. */ + var->tag_length = kTagLength; + var->tag_buf = static_cast<uint8_t*>(gpr_malloc(var->tag_length)); + var->tag_iovec.iov_base = var->tag_buf; + var->tag_iovec.iov_len = var->tag_length; + /* Randomly sets data buffer and duplicates to dup_buf. */ + var->data_length = gsec_test_bias_random_uint32(kMaxDataSize) + 1; + var->data_buf = static_cast<uint8_t*>(gpr_malloc(var->data_length)); + gsec_test_random_bytes(var->data_buf, var->data_length); + gsec_test_copy(var->data_buf, &var->dup_buf, var->data_length); + var->data_iovec = nullptr; + var->data_iovec_length = 0; + randomly_slice(var->data_buf, var->data_length, &var->data_iovec, + &var->data_iovec_length); + /* Sets protected iovec. */ + size_t protected_buf_length = + var->header_length + var->data_length + var->tag_length; + var->protected_buf = static_cast<uint8_t*>(gpr_malloc(protected_buf_length)); + var->protected_iovec.iov_base = var->protected_buf; + var->protected_iovec.iov_len = protected_buf_length; + /* Unprotected iovec points to data_buf. */ + var->unprotected_iovec.iov_base = var->data_buf; + var->unprotected_iovec.iov_len = var->data_length; + return var; +} + +static void alts_iovec_record_protocol_test_var_destroy( + alts_iovec_record_protocol_test_var* var) { + if (var == nullptr) { + return; + } + gpr_free(var->header_buf); + gpr_free(var->tag_buf); + gpr_free(var->data_buf); + gpr_free(var->dup_buf); + gpr_free(var->data_iovec); + gpr_free(var->protected_buf); + gpr_free(var); +} + +/* --- Integrity-only protect/unprotect tests. --- */ + +static void integrity_only_random_seal_unseal( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + for (size_t i = 0; i < kSealRepeatTimes; i++) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + /* Seals and then unseals. */ + grpc_status_code status = alts_iovec_record_protocol_integrity_only_protect( + sender, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + gpr_free(var->data_iovec); + /* Randomly slices data buffer again. 
*/ + randomly_slice(var->data_buf, var->data_length, &var->data_iovec, + &var->data_iovec_length); + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + /* Makes sure data buffer has not been modified during + * seal/unseal. */ + GPR_ASSERT(memcmp(var->data_buf, var->dup_buf, var->data_length) == 0); + alts_iovec_record_protocol_test_var_destroy(var); + } +} + +static void integrity_only_empty_seal_unseal( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + for (size_t i = 0; i < kSealRepeatTimes; i++) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + /* Seals and then unseals empty payload. */ + grpc_status_code status = alts_iovec_record_protocol_integrity_only_protect( + sender, nullptr, 0, var->header_iovec, var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, nullptr, 0, var->header_iovec, var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + alts_iovec_record_protocol_test_var_destroy(var); + } +} + +static void integrity_only_unsync_seal_unseal( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + /* Seals once. */ + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + grpc_status_code status = alts_iovec_record_protocol_integrity_only_protect( + sender, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + alts_iovec_record_protocol_test_var_destroy(var); + /* Seals again. */ + var = alts_iovec_record_protocol_test_var_create(); + status = alts_iovec_record_protocol_integrity_only_protect( + sender, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + /* Unseals the second frame. */ + char* error_message = nullptr; + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, + "Frame tag verification failed.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void integrity_only_corrupted_data( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + /* Seals the data first. */ + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + grpc_status_code status = alts_iovec_record_protocol_integrity_only_protect( + sender, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + /* Alter frame length field. */ + char* error_message = nullptr; + size_t offset = + alter_random_byte(var->header_buf, kZeroCopyFrameLengthFieldSize); + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, "Bad frame length.")); + gpr_free(error_message); + revert_back_alter(var->header_buf, offset); + /* Alter message type field. 
*/ + offset = alter_random_byte(var->header_buf + kZeroCopyFrameLengthFieldSize, + kZeroCopyFrameMessageTypeFieldSize); + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, + "Unsupported message type.")); + gpr_free(error_message); + revert_back_alter(var->header_buf + kZeroCopyFrameLengthFieldSize, offset); + /* Alter data. */ + offset = alter_random_byte(var->data_buf, var->data_length); + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, + "Frame tag verification failed.")); + gpr_free(error_message); + revert_back_alter(var->data_buf, offset); + /* Alter tag. */ + offset = alter_random_byte(var->tag_buf, var->tag_length); + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, + "Frame tag verification failed.")); + gpr_free(error_message); + revert_back_alter(var->tag_buf, offset); + /* Reverted protected data should be verified correctly. */ + status = alts_iovec_record_protocol_integrity_only_unprotect( + receiver, var->data_iovec, var->data_iovec_length, var->header_iovec, + var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(var->data_buf, var->dup_buf, var->data_length) == 0); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void integrity_only_protect_input_check(alts_iovec_record_protocol* rp) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + char* error_message = nullptr; + /* Header buffer is nullptr. */ + iovec_t header_iovec = {nullptr, var->header_length}; + grpc_status_code status = alts_iovec_record_protocol_integrity_only_protect( + rp, var->data_iovec, var->data_iovec_length, header_iovec, var->tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Header is nullptr.")); + gpr_free(error_message); + /* Header buffer length is 0. */ + header_iovec.iov_base = var->header_buf; + header_iovec.iov_len = 0; + status = alts_iovec_record_protocol_integrity_only_protect( + rp, var->data_iovec, var->data_iovec_length, header_iovec, var->tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Header length is incorrect.")); + gpr_free(error_message); + /* Tag buffer is nullptr. */ + iovec_t tag_iovec = {nullptr, var->tag_length}; + status = alts_iovec_record_protocol_integrity_only_protect( + rp, var->data_iovec, var->data_iovec_length, var->header_iovec, tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, "Tag is nullptr.")); + gpr_free(error_message); + /* Tag buffer length is 0. 
*/ + tag_iovec.iov_base = var->tag_buf; + tag_iovec.iov_len = 0; + status = alts_iovec_record_protocol_integrity_only_protect( + rp, var->data_iovec, var->data_iovec_length, var->header_iovec, tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Tag length is incorrect.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void integrity_only_unprotect_input_check( + alts_iovec_record_protocol* rp) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + char* error_message = nullptr; + /* Header buffer is nullptr. */ + iovec_t header_iovec = {nullptr, var->header_length}; + grpc_status_code status = alts_iovec_record_protocol_integrity_only_unprotect( + rp, var->data_iovec, var->data_iovec_length, header_iovec, var->tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Header is nullptr.")); + gpr_free(error_message); + /* Header buffer length is 0. */ + header_iovec.iov_base = var->header_buf; + header_iovec.iov_len = 0; + status = alts_iovec_record_protocol_integrity_only_unprotect( + rp, var->data_iovec, var->data_iovec_length, header_iovec, var->tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Header length is incorrect.")); + gpr_free(error_message); + /* Tag buffer is nullptr. */ + iovec_t tag_iovec = {nullptr, var->tag_length}; + status = alts_iovec_record_protocol_integrity_only_unprotect( + rp, var->data_iovec, var->data_iovec_length, var->header_iovec, tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, "Tag is nullptr.")); + gpr_free(error_message); + /* Tag buffer length is 0. */ + tag_iovec.iov_base = var->tag_buf; + tag_iovec.iov_len = 0; + status = alts_iovec_record_protocol_integrity_only_unprotect( + rp, var->data_iovec, var->data_iovec_length, var->header_iovec, tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Tag length is incorrect.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +/* --- Privacy-integrity protect/unprotect tests. --- */ + +static void privacy_integrity_random_seal_unseal( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + for (size_t i = 0; i < kSealRepeatTimes; i++) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + /* Seals and then unseals. */ + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_protect( + sender, var->data_iovec, var->data_iovec_length, + var->protected_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + iovec_t header_iovec = {var->protected_buf, var->header_length}; + gpr_free(var->data_iovec); + /* Randomly slices protected buffer, excluding the header. 
*/ + randomly_slice(var->protected_buf + var->header_length, + var->data_length + var->tag_length, &var->data_iovec, + &var->data_iovec_length); + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + receiver, header_iovec, var->data_iovec, var->data_iovec_length, + var->unprotected_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + /* Makes sure unprotected data are the same as the original. */ + GPR_ASSERT(memcmp(var->data_buf, var->dup_buf, var->data_length) == 0); + alts_iovec_record_protocol_test_var_destroy(var); + } +} + +static void privacy_integrity_empty_seal_unseal( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + size_t empty_payload_frame_size = var->header_length + var->tag_length; + auto* protected_buf = + static_cast(gpr_malloc(empty_payload_frame_size)); + for (size_t i = 0; i < kSealRepeatTimes; i++) { + iovec_t protected_iovec = {protected_buf, empty_payload_frame_size}; + iovec_t unprotected_iovec = {nullptr, 0}; + iovec_t data_iovec = {protected_buf + var->header_length, var->tag_length}; + /* Seals and then unseals empty payload. */ + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_protect( + sender, nullptr, 0, protected_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + iovec_t header_iovec = {protected_buf, var->header_length}; + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + receiver, header_iovec, &data_iovec, 1, unprotected_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + } + gpr_free(protected_buf); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void privacy_integrity_unsync_seal_unseal( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + /* Seals once. */ + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_protect( + sender, var->data_iovec, var->data_iovec_length, var->protected_iovec, + nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + alts_iovec_record_protocol_test_var_destroy(var); + /* Seals again. */ + var = alts_iovec_record_protocol_test_var_create(); + status = alts_iovec_record_protocol_privacy_integrity_protect( + sender, var->data_iovec, var->data_iovec_length, var->protected_iovec, + nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + /* Unseals the second frame. */ + char* error_message = nullptr; + iovec_t header_iovec = {var->protected_buf, var->header_length}; + iovec_t protected_iovec = {var->protected_buf + var->header_length, + var->data_length + var->tag_length}; + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + receiver, header_iovec, &protected_iovec, 1, var->unprotected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, "Frame decryption failed.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void privacy_integrity_corrupted_data( + alts_iovec_record_protocol* sender, alts_iovec_record_protocol* receiver) { + /* Seals the data first. 
*/ + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_protect( + sender, var->data_iovec, var->data_iovec_length, var->protected_iovec, + nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + char* error_message = nullptr; + uint8_t* header_buf = var->protected_buf; + size_t header_length = var->header_length; + iovec_t header_iovec = {header_buf, header_length}; + /* The following protected_buf and protected_length excludes header. */ + uint8_t* protected_buf = var->protected_buf + var->header_length; + size_t protected_length = var->data_length + var->tag_length; + iovec_t protected_iovec = {protected_buf, protected_length}; + /* Alter frame length field. */ + size_t offset = alter_random_byte(header_buf, kZeroCopyFrameLengthFieldSize); + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + receiver, header_iovec, &protected_iovec, 1, var->unprotected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, "Bad frame length.")); + gpr_free(error_message); + revert_back_alter(header_buf, offset); + /* Alter message type field. */ + offset = alter_random_byte(header_buf + kZeroCopyFrameLengthFieldSize, + kZeroCopyFrameMessageTypeFieldSize); + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + receiver, header_iovec, &protected_iovec, 1, var->unprotected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, + "Unsupported message type.")); + gpr_free(error_message); + revert_back_alter(header_buf + kZeroCopyFrameLengthFieldSize, offset); + /* Alter protected data. */ + offset = alter_random_byte(protected_buf, protected_length); + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + receiver, header_iovec, &protected_iovec, 1, var->unprotected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, "Frame decryption failed.")); + gpr_free(error_message); + revert_back_alter(protected_buf, offset); + /* Reverted protected data should be verified correctly. */ + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + receiver, header_iovec, &protected_iovec, 1, var->unprotected_iovec, + nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + GPR_ASSERT(memcmp(var->data_buf, var->dup_buf, var->data_length) == 0); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void privacy_integrity_protect_input_check( + alts_iovec_record_protocol* rp) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + char* error_message = nullptr; + /* Protected output buffer is nullptr. */ + iovec_t protected_iovec = {nullptr, var->protected_iovec.iov_len}; + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_protect( + rp, var->data_iovec, var->data_iovec_length, protected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Protected frame is nullptr.")); + gpr_free(error_message); + /* Protected output buffer length incorrect. 
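+ * A complete frame needs header + data + tag bytes; the tag length is + * deliberately omitted here.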
*/ + protected_iovec.iov_base = var->protected_buf; + protected_iovec.iov_len = var->header_length + var->data_length; + status = alts_iovec_record_protocol_privacy_integrity_protect( + rp, var->data_iovec, var->data_iovec_length, protected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Protected frame size is incorrect.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void privacy_integrity_unprotect_input_check( + alts_iovec_record_protocol* rp) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + char* error_message = nullptr; + /* Header buffer is nullptr. */ + iovec_t header_iovec = {var->protected_buf, var->header_length}; + iovec_t protected_iovec = {var->protected_buf + var->header_length, + var->data_length + var->tag_length}; + header_iovec.iov_base = nullptr; + grpc_status_code status = + alts_iovec_record_protocol_privacy_integrity_unprotect( + rp, header_iovec, &protected_iovec, 1, var->unprotected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Header is nullptr.")); + gpr_free(error_message); + header_iovec.iov_base = var->protected_buf; + /* Header buffer length is 0. */ + header_iovec.iov_len = 0; + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + rp, header_iovec, &protected_iovec, 1, var->unprotected_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Header length is incorrect.")); + gpr_free(error_message); + header_iovec.iov_len = var->header_length; + /* Unprotected output buffer length is incorrect. */ + iovec_t unprotected_iovec = {var->data_buf, var->data_length - 1}; + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + rp, header_iovec, &protected_iovec, 1, unprotected_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INVALID_ARGUMENT, error_message, + "Unprotected data size is incorrect.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +/* --- Integrity-only and privacy-integrity mixed. --- */ + +static void record_protocol_wrong_mode( + alts_iovec_record_protocol* integrity_only_protect_rp, + alts_iovec_record_protocol* integrity_only_unprotect_rp, + alts_iovec_record_protocol* privacy_integrity_protect_rp, + alts_iovec_record_protocol* privacy_integrity_unprotect_rp) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + grpc_status_code status; + char* error_message = nullptr; + /* Call integrity-only protect on privacy-integrity record protocol. */ + status = alts_iovec_record_protocol_integrity_only_protect( + privacy_integrity_protect_rp, var->data_iovec, var->data_iovec_length, + var->header_iovec, var->tag_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Integrity-only operations are not allowed for this object.")); + gpr_free(error_message); + /* Call integrity-only unprotect on privacy-integrity record protocol. 
*/ + status = alts_iovec_record_protocol_integrity_only_unprotect( + privacy_integrity_unprotect_rp, var->data_iovec, var->data_iovec_length, + var->header_iovec, var->tag_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Integrity-only operations are not allowed for this object.")); + gpr_free(error_message); + /* Call privacy-integrity protect on integrity-only record protocol. */ + status = alts_iovec_record_protocol_privacy_integrity_protect( + integrity_only_protect_rp, var->data_iovec, var->data_iovec_length, + var->protected_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Privacy-integrity operations are not allowed for this object.")); + gpr_free(error_message); + /* Call privacy-integrity unprotect on integrity-only record protocol. */ + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + integrity_only_unprotect_rp, var->header_iovec, var->data_iovec, + var->data_iovec_length, var->unprotected_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_FAILED_PRECONDITION, error_message, + "Privacy-integrity operations are not allowed for this object.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void integrity_seal_privacy_unseal( + alts_iovec_record_protocol* integrity_only_sender, + alts_iovec_record_protocol* privacy_integrity_receiver) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + grpc_status_code status; + char* error_message = nullptr; + /* Seals with integrity-only protect. */ + status = alts_iovec_record_protocol_integrity_only_protect( + integrity_only_sender, var->data_iovec, var->data_iovec_length, + var->header_iovec, var->tag_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + /* Unseal with privacy-integrity unprotect. */ + memcpy(var->protected_buf, var->data_buf, var->data_length); + memcpy(var->protected_buf + var->data_length, var->tag_buf, var->tag_length); + iovec_t protected_iovec = {var->protected_buf, + var->data_length + var->tag_length}; + status = alts_iovec_record_protocol_privacy_integrity_unprotect( + privacy_integrity_receiver, var->header_iovec, &protected_iovec, 1, + var->unprotected_iovec, &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, "Frame decryption failed.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +static void privacy_seal_integrity_unseal( + alts_iovec_record_protocol* privacy_integrity_sender, + alts_iovec_record_protocol* integrity_only_receiver) { + alts_iovec_record_protocol_test_var* var = + alts_iovec_record_protocol_test_var_create(); + grpc_status_code status; + char* error_message = nullptr; + /* Seals with privacy-integrity protect. */ + status = alts_iovec_record_protocol_privacy_integrity_protect( + privacy_integrity_sender, var->data_iovec, var->data_iovec_length, + var->protected_iovec, nullptr); + GPR_ASSERT(status == GRPC_STATUS_OK); + /* Unseal with integrity-only unprotect. 
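+ * The privacy-integrity frame carries ciphertext, so interpreting it as + * plaintext plus tag must fail tag verification.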
*/ + iovec_t header_iovec = {var->protected_buf, var->header_length}; + iovec_t data_iovec = {var->protected_buf + var->header_length, + var->data_length}; + iovec_t tag_iovec = { + var->protected_buf + var->header_length + var->data_length, + var->tag_length}; + status = alts_iovec_record_protocol_integrity_only_unprotect( + integrity_only_receiver, &data_iovec, 1, header_iovec, tag_iovec, + &error_message); + GPR_ASSERT(gsec_test_expect_compare_code_and_substr( + status, GRPC_STATUS_INTERNAL, error_message, + "Frame tag verification failed.")); + gpr_free(error_message); + alts_iovec_record_protocol_test_var_destroy(var); +} + +/* --- Test cases. --- */ + +static void alts_iovec_record_protocol_random_seal_unseal_tests() { + alts_iovec_record_protocol_test_fixture* fixture = + alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true); + integrity_only_random_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + integrity_only_random_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/true); + integrity_only_random_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + integrity_only_random_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false); + privacy_integrity_random_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + privacy_integrity_random_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/false); + privacy_integrity_random_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + privacy_integrity_random_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); +} + +static void alts_iovec_record_protocol_empty_seal_unseal_tests() { + alts_iovec_record_protocol_test_fixture* fixture = + alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true); + integrity_only_empty_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + integrity_only_empty_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/true); + integrity_only_empty_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + integrity_only_empty_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false); + privacy_integrity_empty_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + privacy_integrity_empty_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/false); + privacy_integrity_empty_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + 
privacy_integrity_empty_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); +} + +static void alts_iovec_record_protocol_unsync_seal_unseal_tests() { + alts_iovec_record_protocol_test_fixture* fixture = + alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true); + integrity_only_unsync_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + integrity_only_unsync_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/true); + integrity_only_unsync_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + integrity_only_unsync_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false); + privacy_integrity_unsync_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + privacy_integrity_unsync_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/false); + privacy_integrity_unsync_seal_unseal(fixture->client_protect, + fixture->server_unprotect); + privacy_integrity_unsync_seal_unseal(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); +} + +static void alts_iovec_record_protocol_corrupted_data_tests() { + alts_iovec_record_protocol_test_fixture* fixture = + alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true); + integrity_only_corrupted_data(fixture->client_protect, + fixture->server_unprotect); + integrity_only_corrupted_data(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/true); + integrity_only_corrupted_data(fixture->client_protect, + fixture->server_unprotect); + integrity_only_corrupted_data(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false); + privacy_integrity_corrupted_data(fixture->client_protect, + fixture->server_unprotect); + privacy_integrity_corrupted_data(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/false); + privacy_integrity_corrupted_data(fixture->client_protect, + fixture->server_unprotect); + privacy_integrity_corrupted_data(fixture->server_protect, + fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); +} + +static void alts_iovec_record_protocol_input_check_tests() { + alts_iovec_record_protocol_test_fixture* fixture = + alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true); + integrity_only_protect_input_check(fixture->client_protect); + integrity_only_unprotect_input_check(fixture->client_unprotect); + 
alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/true); + integrity_only_protect_input_check(fixture->client_protect); + integrity_only_unprotect_input_check(fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false); + privacy_integrity_protect_input_check(fixture->client_protect); + privacy_integrity_unprotect_input_check(fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); + + fixture = alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/false); + privacy_integrity_protect_input_check(fixture->client_protect); + privacy_integrity_unprotect_input_check(fixture->client_unprotect); + alts_iovec_record_protocol_test_fixture_destroy(fixture); +} + +static void alts_iovec_record_protocol_mix_operations_tests() { + alts_iovec_record_protocol_test_fixture* fixture_1 = + alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true); + alts_iovec_record_protocol_test_fixture* fixture_2 = + alts_iovec_record_protocol_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false); + + record_protocol_wrong_mode( + fixture_1->client_protect, fixture_1->client_unprotect, + fixture_2->client_protect, fixture_2->client_unprotect); + integrity_seal_privacy_unseal(fixture_1->client_protect, + fixture_2->server_unprotect); + privacy_seal_integrity_unseal(fixture_2->client_protect, + fixture_1->server_unprotect); + + alts_iovec_record_protocol_test_fixture_destroy(fixture_1); + alts_iovec_record_protocol_test_fixture_destroy(fixture_2); +} + +int main(int /*argc*/, char** /*argv*/) { + alts_iovec_record_protocol_random_seal_unseal_tests(); + alts_iovec_record_protocol_empty_seal_unseal_tests(); + alts_iovec_record_protocol_unsync_seal_unseal_tests(); + alts_iovec_record_protocol_corrupted_data_tests(); + alts_iovec_record_protocol_input_check_tests(); + alts_iovec_record_protocol_mix_operations_tests(); + return 0; +} diff --git a/test/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector_test.cc b/test/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector_test.cc new file mode 100644 index 00000000..62b3a4b7 --- /dev/null +++ b/test/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector_test.cc @@ -0,0 +1,313 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h" + +#include +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/tsi/alts/crypt/gsec.h" +#include "src/core/tsi/transport_security_grpc.h" +#include "test/core/tsi/alts/crypt/gsec_test_util.h" +#include "test/core/util/test_config.h" + +/* TODO: tests zero_copy_grpc_protector under TSI test library, which + * has more comprehensive tests. */ + +constexpr size_t kSealRepeatTimes = 50; +constexpr size_t kSmallBufferSize = 16; +constexpr size_t kLargeBufferSize = 16384; +constexpr size_t kChannelMaxSize = 2048; +constexpr size_t kChannelMinSize = 128; + +/* Test fixtures for each test cases. */ +struct alts_zero_copy_grpc_protector_test_fixture { + tsi_zero_copy_grpc_protector* client; + tsi_zero_copy_grpc_protector* server; +}; + +/* Test input variables for protect/unprotect operations. */ +struct alts_zero_copy_grpc_protector_test_var { + grpc_slice_buffer original_sb; + grpc_slice_buffer duplicate_sb; + grpc_slice_buffer staging_sb; + grpc_slice_buffer protected_sb; + grpc_slice_buffer unprotected_sb; +}; + +/* --- Test utility functions. --- */ + +static void create_random_slice_buffer(grpc_slice_buffer* sb, + grpc_slice_buffer* dup_sb, + size_t length) { + GPR_ASSERT(sb != nullptr); + GPR_ASSERT(dup_sb != nullptr); + GPR_ASSERT(length > 0); + grpc_slice slice = GRPC_SLICE_MALLOC(length); + gsec_test_random_bytes(GRPC_SLICE_START_PTR(slice), length); + grpc_slice_buffer_add(sb, grpc_slice_ref(slice)); + grpc_slice_buffer_add(dup_sb, slice); +} + +static uint8_t* pointer_to_nth_byte(grpc_slice_buffer* sb, size_t index) { + GPR_ASSERT(sb != nullptr); + GPR_ASSERT(index < sb->length); + for (size_t i = 0; i < sb->count; i++) { + if (index < GRPC_SLICE_LENGTH(sb->slices[i])) { + return GRPC_SLICE_START_PTR(sb->slices[i]) + index; + } else { + index -= GRPC_SLICE_LENGTH(sb->slices[i]); + } + } + return nullptr; +} + +/* Checks if two slice buffer contents are the same. It is not super efficient, + * but OK for testing. */ +static bool are_slice_buffers_equal(grpc_slice_buffer* first, + grpc_slice_buffer* second) { + GPR_ASSERT(first != nullptr); + GPR_ASSERT(second != nullptr); + if (first->length != second->length) { + return false; + } + for (size_t i = 0; i < first->length; i++) { + uint8_t* first_ptr = pointer_to_nth_byte(first, i); + uint8_t* second_ptr = pointer_to_nth_byte(second, i); + GPR_ASSERT(first_ptr != nullptr && second_ptr != nullptr); + if ((*first_ptr) != (*second_ptr)) { + return false; + } + } + return true; +} + +static alts_zero_copy_grpc_protector_test_fixture* +alts_zero_copy_grpc_protector_test_fixture_create(bool rekey, + bool integrity_only, + bool enable_extra_copy) { + alts_zero_copy_grpc_protector_test_fixture* fixture = + static_cast( + gpr_zalloc(sizeof(alts_zero_copy_grpc_protector_test_fixture))); + grpc_core::ExecCtx exec_ctx; + size_t key_length = rekey ? 
kAes128GcmRekeyKeyLength : kAes128GcmKeyLength; + uint8_t* key; + size_t max_protected_frame_size = 1024; + size_t actual_max_protected_frame_size; + gsec_test_random_array(&key, key_length); + GPR_ASSERT(alts_zero_copy_grpc_protector_create( + key, key_length, rekey, /*is_client=*/true, integrity_only, + enable_extra_copy, &max_protected_frame_size, + &fixture->client) == TSI_OK); + GPR_ASSERT(tsi_zero_copy_grpc_protector_max_frame_size( + fixture->client, &actual_max_protected_frame_size) == TSI_OK); + GPR_ASSERT(actual_max_protected_frame_size == max_protected_frame_size); + GPR_ASSERT(alts_zero_copy_grpc_protector_create( + key, key_length, rekey, /*is_client=*/false, integrity_only, + enable_extra_copy, &max_protected_frame_size, + &fixture->server) == TSI_OK); + GPR_ASSERT(tsi_zero_copy_grpc_protector_max_frame_size( + fixture->server, &actual_max_protected_frame_size) == TSI_OK); + GPR_ASSERT(actual_max_protected_frame_size == max_protected_frame_size); + gpr_free(key); + grpc_core::ExecCtx::Get()->Flush(); + return fixture; +} + +static void alts_zero_copy_grpc_protector_test_fixture_destroy( + alts_zero_copy_grpc_protector_test_fixture* fixture) { + if (fixture == nullptr) { + return; + } + grpc_core::ExecCtx exec_ctx; + tsi_zero_copy_grpc_protector_destroy(fixture->client); + tsi_zero_copy_grpc_protector_destroy(fixture->server); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(fixture); +} + +static alts_zero_copy_grpc_protector_test_var* +alts_zero_copy_grpc_protector_test_var_create() { + alts_zero_copy_grpc_protector_test_var* var = + static_cast( + gpr_zalloc(sizeof(alts_zero_copy_grpc_protector_test_var))); + grpc_slice_buffer_init(&var->original_sb); + grpc_slice_buffer_init(&var->duplicate_sb); + grpc_slice_buffer_init(&var->staging_sb); + grpc_slice_buffer_init(&var->protected_sb); + grpc_slice_buffer_init(&var->unprotected_sb); + return var; +} + +static void alts_zero_copy_grpc_protector_test_var_destroy( + alts_zero_copy_grpc_protector_test_var* var) { + if (var == nullptr) { + return; + } + grpc_slice_buffer_destroy_internal(&var->original_sb); + grpc_slice_buffer_destroy_internal(&var->duplicate_sb); + grpc_slice_buffer_destroy_internal(&var->staging_sb); + grpc_slice_buffer_destroy_internal(&var->protected_sb); + grpc_slice_buffer_destroy_internal(&var->unprotected_sb); + gpr_free(var); +} + +/* --- ALTS zero-copy protector tests. --- */ + +static void seal_unseal_small_buffer(tsi_zero_copy_grpc_protector* sender, + tsi_zero_copy_grpc_protector* receiver) { + grpc_core::ExecCtx exec_ctx; + for (size_t i = 0; i < kSealRepeatTimes; i++) { + alts_zero_copy_grpc_protector_test_var* var = + alts_zero_copy_grpc_protector_test_var_create(); + /* Creates a random small slice buffer and calls protect(). */ + create_random_slice_buffer(&var->original_sb, &var->duplicate_sb, + kSmallBufferSize); + GPR_ASSERT(tsi_zero_copy_grpc_protector_protect( + sender, &var->original_sb, &var->protected_sb) == TSI_OK); + /* Splits protected slice buffer into two: first one is staging_sb, and + * second one is protected_sb. */ + uint32_t staging_sb_size = + gsec_test_bias_random_uint32( + static_cast(var->protected_sb.length - 1)) + + 1; + grpc_slice_buffer_move_first(&var->protected_sb, staging_sb_size, + &var->staging_sb); + /* Unprotects one by one. 
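+ * staging_sb holds a strict prefix of the single protected frame, so the + * first unprotect call yields no plaintext; the remaining bytes complete + * the frame.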
*/ + GPR_ASSERT(tsi_zero_copy_grpc_protector_unprotect( + receiver, &var->staging_sb, &var->unprotected_sb) == TSI_OK); + GPR_ASSERT(var->unprotected_sb.length == 0); + GPR_ASSERT(tsi_zero_copy_grpc_protector_unprotect( + receiver, &var->protected_sb, &var->unprotected_sb) == + TSI_OK); + GPR_ASSERT( + are_slice_buffers_equal(&var->unprotected_sb, &var->duplicate_sb)); + alts_zero_copy_grpc_protector_test_var_destroy(var); + } + grpc_core::ExecCtx::Get()->Flush(); +} + +static void seal_unseal_large_buffer(tsi_zero_copy_grpc_protector* sender, + tsi_zero_copy_grpc_protector* receiver) { + grpc_core::ExecCtx exec_ctx; + for (size_t i = 0; i < kSealRepeatTimes; i++) { + alts_zero_copy_grpc_protector_test_var* var = + alts_zero_copy_grpc_protector_test_var_create(); + /* Creates a random large slice buffer and calls protect(). */ + create_random_slice_buffer(&var->original_sb, &var->duplicate_sb, + kLargeBufferSize); + GPR_ASSERT(tsi_zero_copy_grpc_protector_protect( + sender, &var->original_sb, &var->protected_sb) == TSI_OK); + /* Splits protected slice buffer into multiple pieces. Receiver unprotects + * each slice buffer one by one. */ + uint32_t channel_size = gsec_test_bias_random_uint32(static_cast( + kChannelMaxSize + 1 - kChannelMinSize)) + + static_cast(kChannelMinSize); + while (var->protected_sb.length > channel_size) { + grpc_slice_buffer_reset_and_unref_internal(&var->staging_sb); + grpc_slice_buffer_move_first(&var->protected_sb, channel_size, + &var->staging_sb); + GPR_ASSERT(tsi_zero_copy_grpc_protector_unprotect( + receiver, &var->staging_sb, &var->unprotected_sb) == + TSI_OK); + } + GPR_ASSERT(tsi_zero_copy_grpc_protector_unprotect( + receiver, &var->protected_sb, &var->unprotected_sb) == + TSI_OK); + GPR_ASSERT( + are_slice_buffers_equal(&var->unprotected_sb, &var->duplicate_sb)); + alts_zero_copy_grpc_protector_test_var_destroy(var); + } + grpc_core::ExecCtx::Get()->Flush(); +} + +/* --- Test cases. 
--- */ + +static void alts_zero_copy_protector_seal_unseal_small_buffer_tests( + bool enable_extra_copy) { + alts_zero_copy_grpc_protector_test_fixture* fixture = + alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true, enable_extra_copy); + seal_unseal_small_buffer(fixture->client, fixture->server); + seal_unseal_small_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); + + fixture = alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false, enable_extra_copy); + seal_unseal_small_buffer(fixture->client, fixture->server); + seal_unseal_small_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); + + fixture = alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/true, enable_extra_copy); + seal_unseal_small_buffer(fixture->client, fixture->server); + seal_unseal_small_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); + + fixture = alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/false, enable_extra_copy); + seal_unseal_small_buffer(fixture->client, fixture->server); + seal_unseal_small_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); +} + +static void alts_zero_copy_protector_seal_unseal_large_buffer_tests( + bool enable_extra_copy) { + alts_zero_copy_grpc_protector_test_fixture* fixture = + alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/true, enable_extra_copy); + seal_unseal_large_buffer(fixture->client, fixture->server); + seal_unseal_large_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); + + fixture = alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/false, /*integrity_only=*/false, enable_extra_copy); + seal_unseal_large_buffer(fixture->client, fixture->server); + seal_unseal_large_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); + + fixture = alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/true, enable_extra_copy); + seal_unseal_large_buffer(fixture->client, fixture->server); + seal_unseal_large_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); + + fixture = alts_zero_copy_grpc_protector_test_fixture_create( + /*rekey=*/true, /*integrity_only=*/false, enable_extra_copy); + seal_unseal_large_buffer(fixture->client, fixture->server); + seal_unseal_large_buffer(fixture->server, fixture->client); + alts_zero_copy_grpc_protector_test_fixture_destroy(fixture); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + alts_zero_copy_protector_seal_unseal_small_buffer_tests( + /*enable_extra_copy=*/false); + alts_zero_copy_protector_seal_unseal_small_buffer_tests( + /*enable_extra_copy=*/true); + alts_zero_copy_protector_seal_unseal_large_buffer_tests( + /*enable_extra_copy=*/false); + alts_zero_copy_protector_seal_unseal_large_buffer_tests( + /*enable_extra_copy=*/true); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/fake_transport_security_test.cc b/test/core/tsi/fake_transport_security_test.cc new file mode 100644 index 00000000..9e1aaf20 --- /dev/null +++ b/test/core/tsi/fake_transport_security_test.cc @@ -0,0 
+1,157 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/fake_transport_security.h" + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/tsi/transport_security.h" +#include "test/core/tsi/transport_security_test_lib.h" +#include "test/core/util/test_config.h" + +typedef struct fake_tsi_test_fixture { + tsi_test_fixture base; +} fake_tsi_test_fixture; + +static void fake_test_setup_handshakers(tsi_test_fixture* fixture) { + fixture->client_handshaker = + tsi_create_fake_handshaker(true /* is_client. */); + fixture->server_handshaker = + tsi_create_fake_handshaker(false /* is_client. */); +} + +static void validate_handshaker_peers(tsi_handshaker_result* result) { + GPR_ASSERT(result != nullptr); + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); + const tsi_peer_property* property = + tsi_peer_get_property_by_name(&peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); + GPR_ASSERT(property != nullptr); + GPR_ASSERT(memcmp(property->value.data, TSI_FAKE_CERTIFICATE_TYPE, + property->value.length) == 0); + property = + tsi_peer_get_property_by_name(&peer, TSI_SECURITY_LEVEL_PEER_PROPERTY); + GPR_ASSERT(property != nullptr); + GPR_ASSERT(memcmp(property->value.data, TSI_FAKE_SECURITY_LEVEL, + property->value.length) == 0); + tsi_peer_destruct(&peer); +} + +static void fake_test_check_handshaker_peers(tsi_test_fixture* fixture) { + validate_handshaker_peers(fixture->client_result); + validate_handshaker_peers(fixture->server_result); +} + +static void fake_test_destruct(tsi_test_fixture* /*fixture*/) {} + +static const struct tsi_test_fixture_vtable vtable = { + fake_test_setup_handshakers, fake_test_check_handshaker_peers, + fake_test_destruct}; + +static tsi_test_fixture* fake_tsi_test_fixture_create() { + fake_tsi_test_fixture* fake_fixture = + static_cast(gpr_zalloc(sizeof(*fake_fixture))); + tsi_test_fixture_init(&fake_fixture->base); + fake_fixture->base.vtable = &vtable; + return &fake_fixture->base; +} + +void fake_tsi_test_do_handshake_tiny_handshake_buffer() { + tsi_test_fixture* fixture = fake_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_TINY_HANDSHAKE_BUFFER_SIZE; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void fake_tsi_test_do_handshake_small_handshake_buffer() { + tsi_test_fixture* fixture = fake_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_SMALL_HANDSHAKE_BUFFER_SIZE; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void fake_tsi_test_do_handshake() { + tsi_test_fixture* fixture = fake_tsi_test_fixture_create(); + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void fake_tsi_test_do_round_trip_for_all_configs() { + unsigned int* bit_array = static_cast( + gpr_zalloc(sizeof(unsigned int) * TSI_TEST_NUM_OF_ARGUMENTS)); + 
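/* Each value of val encodes one combination of the boolean frame-protector config arguments; the inner loop extracts its bits MSB-first into bit_array. */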
const unsigned int mask = 1U << (TSI_TEST_NUM_OF_ARGUMENTS - 1); + for (unsigned int val = 0; val < TSI_TEST_NUM_OF_COMBINATIONS; val++) { + unsigned int v = val; + for (unsigned int ind = 0; ind < TSI_TEST_NUM_OF_ARGUMENTS; ind++) { + bit_array[ind] = (v & mask) ? 1 : 0; + v <<= 1; + } + tsi_test_fixture* fixture = fake_tsi_test_fixture_create(); + fake_tsi_test_fixture* fake_fixture = + reinterpret_cast(fixture); + tsi_test_frame_protector_config_destroy(fake_fixture->base.config); + fake_fixture->base.config = tsi_test_frame_protector_config_create( + bit_array[0], bit_array[1], bit_array[2], bit_array[3], bit_array[4], + bit_array[5], bit_array[6]); + tsi_test_do_round_trip(&fake_fixture->base); + tsi_test_fixture_destroy(fixture); + } + gpr_free(bit_array); +} + +void fake_tsi_test_do_round_trip_odd_buffer_size() { + const size_t odd_sizes[] = {1025, 2051, 4103, 8207, 16409}; + const size_t size = sizeof(odd_sizes) / sizeof(size_t); + for (size_t ind1 = 0; ind1 < size; ind1++) { + for (size_t ind2 = 0; ind2 < size; ind2++) { + for (size_t ind3 = 0; ind3 < size; ind3++) { + for (size_t ind4 = 0; ind4 < size; ind4++) { + for (size_t ind5 = 0; ind5 < size; ind5++) { + tsi_test_fixture* fixture = fake_tsi_test_fixture_create(); + fake_tsi_test_fixture* fake_fixture = + reinterpret_cast(fixture); + tsi_test_frame_protector_config_set_buffer_size( + fake_fixture->base.config, odd_sizes[ind1], odd_sizes[ind2], + odd_sizes[ind3], odd_sizes[ind4], odd_sizes[ind5]); + tsi_test_do_round_trip(&fake_fixture->base); + tsi_test_fixture_destroy(fixture); + } + } + } + } + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + fake_tsi_test_do_handshake_tiny_handshake_buffer(); + fake_tsi_test_do_handshake_small_handshake_buffer(); + fake_tsi_test_do_handshake(); + fake_tsi_test_do_round_trip_for_all_configs(); + fake_tsi_test_do_round_trip_odd_buffer_size(); + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/ssl_session_cache_test.cc b/test/core/tsi/ssl_session_cache_test.cc new file mode 100644 index 00000000..353d75f8 --- /dev/null +++ b/test/core/tsi/ssl_session_cache_test.cc @@ -0,0 +1,155 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/ssl/session_cache/ssl_session_cache.h" + +#include +#include + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { + +namespace { + +class SessionTracker; + +struct SessionExDataId { + SessionTracker* tracker; + long id; +}; + +class SessionTracker { + public: + SessionTracker() { ssl_context_ = SSL_CTX_new(TLSv1_2_method()); } + + ~SessionTracker() { SSL_CTX_free(ssl_context_); } + + tsi::SslSessionPtr NewSession(long id) { + static int ex_data_id = SSL_SESSION_get_ex_new_index( + 0, nullptr, nullptr, nullptr, DestroyExData); + GPR_ASSERT(ex_data_id != -1); + // OpenSSL and different version of BoringSSL don't agree on API + // so try both. 
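+ // NewSessionInternal is overloaded below for both function-pointer + // signatures, so overload resolution picks whichever one matches + // SSL_SESSION_new at compile time.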
+ tsi::SslSessionPtr session = NewSessionInternal(SSL_SESSION_new); + SessionExDataId* data = new SessionExDataId{this, id}; + int result = SSL_SESSION_set_ex_data(session.get(), ex_data_id, data); + EXPECT_EQ(result, 1); + alive_sessions_.insert(id); + return session; + } + + bool IsAlive(long id) const { + return alive_sessions_.find(id) != alive_sessions_.end(); + } + + size_t AliveCount() const { return alive_sessions_.size(); } + + private: + tsi::SslSessionPtr NewSessionInternal(SSL_SESSION* (*cb)()) { + return tsi::SslSessionPtr(cb()); + } + + tsi::SslSessionPtr NewSessionInternal(SSL_SESSION* (*cb)(const SSL_CTX*)) { + return tsi::SslSessionPtr(cb(ssl_context_)); + } + + static void DestroyExData(void* /*parent*/, void* ptr, CRYPTO_EX_DATA* /*ad*/, + int /*index*/, long /*argl*/, void* /*argp*/) { + SessionExDataId* data = static_cast<SessionExDataId*>(ptr); + data->tracker->alive_sessions_.erase(data->id); + delete data; + } + + SSL_CTX* ssl_context_; + std::unordered_set<long> alive_sessions_; +}; + +TEST(SslSessionCacheTest, InitialState) { + SessionTracker tracker; + // Verify session initial state. + { + tsi::SslSessionPtr tmp_sess = tracker.NewSession(1); + EXPECT_TRUE(tracker.IsAlive(1)); + EXPECT_EQ(tracker.AliveCount(), 1); + } + EXPECT_FALSE(tracker.IsAlive(1)); + EXPECT_EQ(tracker.AliveCount(), 0); +} + +TEST(SslSessionCacheTest, LruCache) { + SessionTracker tracker; + { + RefCountedPtr<tsi::SslSessionLRUCache> cache = + tsi::SslSessionLRUCache::Create(3); + tsi::SslSessionPtr sess2 = tracker.NewSession(2); + SSL_SESSION* sess2_ptr = sess2.get(); + cache->Put("first.dropbox.com", std::move(sess2)); + EXPECT_EQ(cache->Get("first.dropbox.com").get(), sess2_ptr); + EXPECT_TRUE(tracker.IsAlive(2)); + EXPECT_EQ(tracker.AliveCount(), 1); + // Putting an element with the same key destroys the old session. + tsi::SslSessionPtr sess3 = tracker.NewSession(3); + SSL_SESSION* sess3_ptr = sess3.get(); + cache->Put("first.dropbox.com", std::move(sess3)); + EXPECT_FALSE(tracker.IsAlive(2)); + EXPECT_EQ(cache->Get("first.dropbox.com").get(), sess3_ptr); + EXPECT_TRUE(tracker.IsAlive(3)); + EXPECT_EQ(tracker.AliveCount(), 1); + // Putting three more elements discards the current one. + for (long id = 4; id < 7; id++) { + EXPECT_TRUE(tracker.IsAlive(3)); + std::string domain = std::to_string(id) + ".random.domain"; + cache->Put(domain.c_str(), tracker.NewSession(id)); + } + EXPECT_EQ(cache->Size(), 3); + EXPECT_FALSE(tracker.IsAlive(3)); + EXPECT_EQ(tracker.AliveCount(), 3); + // Accessing an element moves it to the front of the queue. + EXPECT_TRUE(cache->Get("4.random.domain")); + EXPECT_TRUE(tracker.IsAlive(4)); + EXPECT_TRUE(tracker.IsAlive(5)); + EXPECT_TRUE(tracker.IsAlive(6)); + // One element has to be evicted from the cache. + cache->Put("7.random.domain", tracker.NewSession(7)); + EXPECT_TRUE(tracker.IsAlive(4)); + EXPECT_FALSE(tracker.IsAlive(5)); + EXPECT_TRUE(tracker.IsAlive(6)); + EXPECT_TRUE(tracker.IsAlive(7)); + EXPECT_EQ(tracker.AliveCount(), 3); + } + // Cache destructor destroys all sessions.
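+ // DestroyExData runs as each SSL_SESSION is freed and erases its id from + // alive_sessions_, so an AliveCount() of zero means every cached session + // was released.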
+ EXPECT_EQ(tracker.AliveCount(), 0); +} + +} // namespace +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/tsi/ssl_transport_security_test.cc b/test/core/tsi/ssl_transport_security_test.cc new file mode 100644 index 00000000..c560276b --- /dev/null +++ b/test/core/tsi/ssl_transport_security_test.cc @@ -0,0 +1,1058 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/ssl_transport_security.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/tsi/transport_security.h" +#include "src/core/tsi/transport_security_interface.h" +#include "test/core/tsi/transport_security_test_lib.h" +#include "test/core/util/test_config.h" + +extern "C" { +#include +#include +#include +} + +#define SSL_TSI_TEST_ALPN1 "foo" +#define SSL_TSI_TEST_ALPN2 "toto" +#define SSL_TSI_TEST_ALPN3 "baz" +#define SSL_TSI_TEST_ALPN_NUM 2 +#define SSL_TSI_TEST_SERVER_KEY_CERT_PAIRS_NUM 2 +#define SSL_TSI_TEST_BAD_SERVER_KEY_CERT_PAIRS_NUM 1 +#define SSL_TSI_TEST_CREDENTIALS_DIR "src/core/tsi/test_creds/" +#define SSL_TSI_TEST_WRONG_SNI "test.google.cn" + +// OpenSSL 1.1 uses AES256 for encryption session ticket by default so specify +// different STEK size. +#if OPENSSL_VERSION_NUMBER >= 0x10100000 && !defined(OPENSSL_IS_BORINGSSL) +const size_t kSessionTicketEncryptionKeySize = 80; +#else +const size_t kSessionTicketEncryptionKeySize = 48; +#endif + +// Indicates the TLS version used for the test. 
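+// Both the client and server handshaker options pin min_tls_version and +// max_tls_version to this value, so each run exercises exactly one TLS +// version.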
+static tsi_tls_version test_tls_version = tsi_tls_version::TSI_TLS1_3; + +typedef enum AlpnMode { + NO_ALPN, + ALPN_CLIENT_NO_SERVER, + ALPN_SERVER_NO_CLIENT, + ALPN_CLIENT_SERVER_OK, + ALPN_CLIENT_SERVER_MISMATCH +} AlpnMode; + +typedef struct ssl_alpn_lib { + AlpnMode alpn_mode; + const char** server_alpn_protocols; + const char** client_alpn_protocols; + uint16_t num_server_alpn_protocols; + uint16_t num_client_alpn_protocols; +} ssl_alpn_lib; + +typedef struct ssl_key_cert_lib { + bool use_bad_server_cert; + bool use_bad_client_cert; + bool use_root_store; + char* root_cert; + tsi_ssl_root_certs_store* root_store; + tsi_ssl_pem_key_cert_pair* server_pem_key_cert_pairs; + tsi_ssl_pem_key_cert_pair* bad_server_pem_key_cert_pairs; + tsi_ssl_pem_key_cert_pair client_pem_key_cert_pair; + tsi_ssl_pem_key_cert_pair bad_client_pem_key_cert_pair; + uint16_t server_num_key_cert_pairs; + uint16_t bad_server_num_key_cert_pairs; +} ssl_key_cert_lib; + +typedef struct ssl_tsi_test_fixture { + tsi_test_fixture base; + ssl_key_cert_lib* key_cert_lib; + ssl_alpn_lib* alpn_lib; + bool force_client_auth; + char* server_name_indication; + tsi_ssl_session_cache* session_cache; + bool session_reused; + const char* session_ticket_key; + size_t session_ticket_key_size; + tsi_ssl_server_handshaker_factory* server_handshaker_factory; + tsi_ssl_client_handshaker_factory* client_handshaker_factory; +} ssl_tsi_test_fixture; + +static void ssl_test_setup_handshakers(tsi_test_fixture* fixture) { + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + GPR_ASSERT(ssl_fixture != nullptr); + GPR_ASSERT(ssl_fixture->key_cert_lib != nullptr); + GPR_ASSERT(ssl_fixture->alpn_lib != nullptr); + ssl_key_cert_lib* key_cert_lib = ssl_fixture->key_cert_lib; + ssl_alpn_lib* alpn_lib = ssl_fixture->alpn_lib; + /* Create client handshaker factory. */ + tsi_ssl_client_handshaker_options client_options; + client_options.pem_root_certs = key_cert_lib->root_cert; + if (ssl_fixture->force_client_auth) { + client_options.pem_key_cert_pair = + key_cert_lib->use_bad_client_cert + ? &key_cert_lib->bad_client_pem_key_cert_pair + : &key_cert_lib->client_pem_key_cert_pair; + } + if (alpn_lib->alpn_mode == ALPN_CLIENT_NO_SERVER || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_OK || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_MISMATCH) { + client_options.alpn_protocols = alpn_lib->client_alpn_protocols; + client_options.num_alpn_protocols = alpn_lib->num_client_alpn_protocols; + } + client_options.root_store = + key_cert_lib->use_root_store ? key_cert_lib->root_store : nullptr; + if (ssl_fixture->session_cache != nullptr) { + client_options.session_cache = ssl_fixture->session_cache; + } + client_options.min_tls_version = test_tls_version; + client_options.max_tls_version = test_tls_version; + GPR_ASSERT(tsi_create_ssl_client_handshaker_factory_with_options( + &client_options, &ssl_fixture->client_handshaker_factory) == + TSI_OK); + /* Create server handshaker factory. */ + tsi_ssl_server_handshaker_options server_options; + if (alpn_lib->alpn_mode == ALPN_SERVER_NO_CLIENT || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_OK || + alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_MISMATCH) { + server_options.alpn_protocols = alpn_lib->server_alpn_protocols; + server_options.num_alpn_protocols = alpn_lib->num_server_alpn_protocols; + if (alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_MISMATCH) { + server_options.num_alpn_protocols--; + } + } + server_options.pem_key_cert_pairs = + key_cert_lib->use_bad_server_cert + ? 
key_cert_lib->bad_server_pem_key_cert_pairs + : key_cert_lib->server_pem_key_cert_pairs; + server_options.num_key_cert_pairs = + key_cert_lib->use_bad_server_cert + ? key_cert_lib->bad_server_num_key_cert_pairs + : key_cert_lib->server_num_key_cert_pairs; + server_options.pem_client_root_certs = key_cert_lib->root_cert; + if (ssl_fixture->force_client_auth) { + server_options.client_certificate_request = + TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; + } else { + server_options.client_certificate_request = + TSI_DONT_REQUEST_CLIENT_CERTIFICATE; + } + server_options.session_ticket_key = ssl_fixture->session_ticket_key; + server_options.session_ticket_key_size = ssl_fixture->session_ticket_key_size; + server_options.min_tls_version = test_tls_version; + server_options.max_tls_version = test_tls_version; + GPR_ASSERT(tsi_create_ssl_server_handshaker_factory_with_options( + &server_options, &ssl_fixture->server_handshaker_factory) == + TSI_OK); + /* Create server and client handshakers. */ + GPR_ASSERT(tsi_ssl_client_handshaker_factory_create_handshaker( + ssl_fixture->client_handshaker_factory, + ssl_fixture->server_name_indication, + &ssl_fixture->base.client_handshaker) == TSI_OK); + GPR_ASSERT(tsi_ssl_server_handshaker_factory_create_handshaker( + ssl_fixture->server_handshaker_factory, + &ssl_fixture->base.server_handshaker) == TSI_OK); +} + +static void check_alpn(ssl_tsi_test_fixture* ssl_fixture, + const tsi_peer* peer) { + GPR_ASSERT(ssl_fixture != nullptr); + GPR_ASSERT(ssl_fixture->alpn_lib != nullptr); + ssl_alpn_lib* alpn_lib = ssl_fixture->alpn_lib; + const tsi_peer_property* alpn_property = + tsi_peer_get_property_by_name(peer, TSI_SSL_ALPN_SELECTED_PROTOCOL); + if (alpn_lib->alpn_mode != ALPN_CLIENT_SERVER_OK) { + GPR_ASSERT(alpn_property == nullptr); + } else { + GPR_ASSERT(alpn_property != nullptr); + const char* expected_match = "baz"; + GPR_ASSERT(memcmp(alpn_property->value.data, expected_match, + alpn_property->value.length) == 0); + } +} + +static void check_security_level(const tsi_peer* peer) { + const tsi_peer_property* security_level = + tsi_peer_get_property_by_name(peer, TSI_SECURITY_LEVEL_PEER_PROPERTY); + GPR_ASSERT(security_level != nullptr); + const char* expected_match = "TSI_PRIVACY_AND_INTEGRITY"; + GPR_ASSERT(memcmp(security_level->value.data, expected_match, + security_level->value.length) == 0); +} + +static const tsi_peer_property* +check_basic_authenticated_peer_and_get_common_name(const tsi_peer* peer) { + const tsi_peer_property* cert_type_property = + tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); + GPR_ASSERT(cert_type_property != nullptr); + GPR_ASSERT(memcmp(cert_type_property->value.data, TSI_X509_CERTIFICATE_TYPE, + cert_type_property->value.length) == 0); + const tsi_peer_property* property = tsi_peer_get_property_by_name( + peer, TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY); + GPR_ASSERT(property != nullptr); + return property; +} + +static void check_session_reusage(ssl_tsi_test_fixture* ssl_fixture, + tsi_peer* peer) { + const tsi_peer_property* session_reused = + tsi_peer_get_property_by_name(peer, TSI_SSL_SESSION_REUSED_PEER_PROPERTY); + GPR_ASSERT(session_reused != nullptr); + if (ssl_fixture->session_reused) { + GPR_ASSERT(strncmp(session_reused->value.data, "true", + session_reused->value.length) == 0); + } else { + GPR_ASSERT(strncmp(session_reused->value.data, "false", + session_reused->value.length) == 0); + } +} + +void check_server0_peer(tsi_peer* peer) { + const tsi_peer_property* property = + 
check_basic_authenticated_peer_and_get_common_name(peer); + const char* expected_match = "*.test.google.com.au"; + GPR_ASSERT(memcmp(property->value.data, expected_match, + property->value.length) == 0); + GPR_ASSERT(tsi_peer_get_property_by_name( + peer, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == + nullptr); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.test.google.com.au") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.test.google.com.au") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "BAR.TEST.GOOGLE.COM.AU") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "Bar.Test.Google.Com.Au") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bAr.TeST.gOOgle.cOm.AU") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.test.google.blah") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.bar.test.google.com.au") == + 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "test.google.com.au") == 0); + tsi_peer_destruct(peer); +} + +static bool check_property(tsi_peer* peer, const char* property_name, + const char* property_value) { + for (size_t i = 0; i < peer->property_count; i++) { + const tsi_peer_property* prop = &peer->properties[i]; + if (strcmp(prop->name, property_name) == 0) { + if (strlen(property_value) == prop->value.length && + memcmp(prop->value.data, property_value, prop->value.length) == 0) { + return true; + } + } + } + return false; +} + +static bool check_subject_alt_name(tsi_peer* peer, const char* name) { + return check_property(peer, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, + name); +} + +static bool check_dns(tsi_peer* peer, const char* name) { + return check_property(peer, TSI_X509_DNS_PEER_PROPERTY, name); +} + +static bool check_uri(tsi_peer* peer, const char* name) { + return check_property(peer, TSI_X509_URI_PEER_PROPERTY, name); +} + +static bool check_email(tsi_peer* peer, const char* name) { + return check_property(peer, TSI_X509_EMAIL_PEER_PROPERTY, name); +} + +static bool check_ip(tsi_peer* peer, const char* name) { + return check_property(peer, TSI_X509_IP_PEER_PROPERTY, name); +} + +void check_server1_peer(tsi_peer* peer) { + const tsi_peer_property* property = + check_basic_authenticated_peer_and_get_common_name(peer); + const char* expected_match = "*.test.google.com"; + GPR_ASSERT(memcmp(property->value.data, expected_match, + property->value.length) == 0); + GPR_ASSERT(check_subject_alt_name(peer, "*.test.google.fr") == 1); + GPR_ASSERT(check_subject_alt_name(peer, "waterzooi.test.google.be") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.test.google.fr") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.test.google.fr") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "waterzooi.test.google.be") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "foo.test.youtube.com") == 1); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "bar.foo.test.google.com") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "test.google.fr") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "tartines.test.google.be") == 0); + GPR_ASSERT(tsi_ssl_peer_matches_name(peer, "tartines.youtube.com") == 0); + tsi_peer_destruct(peer); +} + +static void check_client_peer(ssl_tsi_test_fixture* ssl_fixture, + tsi_peer* peer) { + GPR_ASSERT(ssl_fixture != nullptr); + GPR_ASSERT(ssl_fixture->alpn_lib != nullptr); + ssl_alpn_lib* alpn_lib = ssl_fixture->alpn_lib; + if (!ssl_fixture->force_client_auth) { + GPR_ASSERT(peer->property_count == + (alpn_lib->alpn_mode == ALPN_CLIENT_SERVER_OK ? 
3 : 2)); + } else { + const tsi_peer_property* property = + check_basic_authenticated_peer_and_get_common_name(peer); + const char* expected_match = "testclient"; + GPR_ASSERT(memcmp(property->value.data, expected_match, + property->value.length) == 0); + } + tsi_peer_destruct(peer); +} + +static void ssl_test_check_handshaker_peers(tsi_test_fixture* fixture) { + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + GPR_ASSERT(ssl_fixture != nullptr); + GPR_ASSERT(ssl_fixture->key_cert_lib != nullptr); + ssl_key_cert_lib* key_cert_lib = ssl_fixture->key_cert_lib; + tsi_peer peer; + // In TLS 1.3, the client-side handshake succeeds even if the client sends a + // bad certificate. In such a case, the server would fail the TLS handshake + // and send an alert to the client as the first application data message. In + // TLS 1.2, the client-side handshake will fail if the client sends a bad + // certificate. + // + // For OpenSSL versions < 1.1, TLS 1.3 is not supported, so the client-side + // handshake should succeed precisely when the server-side handshake + // succeeds. + bool expect_server_success = + !(key_cert_lib->use_bad_server_cert || + (key_cert_lib->use_bad_client_cert && ssl_fixture->force_client_auth)); +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + bool expect_client_success = test_tls_version == tsi_tls_version::TSI_TLS1_2 + ? expect_server_success + : !key_cert_lib->use_bad_server_cert; +#else + bool expect_client_success = expect_server_success; +#endif + if (expect_client_success) { + GPR_ASSERT(tsi_handshaker_result_extract_peer( + ssl_fixture->base.client_result, &peer) == TSI_OK); + check_session_reusage(ssl_fixture, &peer); + check_alpn(ssl_fixture, &peer); + check_security_level(&peer); + if (ssl_fixture->server_name_indication == nullptr || + strcmp(ssl_fixture->server_name_indication, SSL_TSI_TEST_WRONG_SNI) == + 0) { + // Expect server to use default server0.pem. + check_server0_peer(&peer); + } else { + // Expect server to use server1.pem. + check_server1_peer(&peer); + } + } else { + GPR_ASSERT(ssl_fixture->base.client_result == nullptr); + } + if (expect_server_success) { + GPR_ASSERT(tsi_handshaker_result_extract_peer( + ssl_fixture->base.server_result, &peer) == TSI_OK); + check_session_reusage(ssl_fixture, &peer); + check_alpn(ssl_fixture, &peer); + check_security_level(&peer); + check_client_peer(ssl_fixture, &peer); + } else { + GPR_ASSERT(ssl_fixture->base.server_result == nullptr); + } +} + +static void ssl_test_pem_key_cert_pair_destroy(tsi_ssl_pem_key_cert_pair kp) { + gpr_free(const_cast(kp.private_key)); + gpr_free(const_cast(kp.cert_chain)); +} + +static void ssl_test_destruct(tsi_test_fixture* fixture) { + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + if (ssl_fixture == nullptr) { + return; + } + /* Destroy ssl_alpn_lib. */ + ssl_alpn_lib* alpn_lib = ssl_fixture->alpn_lib; + for (size_t i = 0; i < alpn_lib->num_server_alpn_protocols; i++) { + gpr_free(const_cast(alpn_lib->server_alpn_protocols[i])); + } + gpr_free(alpn_lib->server_alpn_protocols); + for (size_t i = 0; i < alpn_lib->num_client_alpn_protocols; i++) { + gpr_free(const_cast(alpn_lib->client_alpn_protocols[i])); + } + gpr_free(alpn_lib->client_alpn_protocols); + gpr_free(alpn_lib); + /* Destroy ssl_key_cert_lib. 
*/ + ssl_key_cert_lib* key_cert_lib = ssl_fixture->key_cert_lib; + for (size_t i = 0; i < key_cert_lib->server_num_key_cert_pairs; i++) { + ssl_test_pem_key_cert_pair_destroy( + key_cert_lib->server_pem_key_cert_pairs[i]); + } + gpr_free(key_cert_lib->server_pem_key_cert_pairs); + for (size_t i = 0; i < key_cert_lib->bad_server_num_key_cert_pairs; i++) { + ssl_test_pem_key_cert_pair_destroy( + key_cert_lib->bad_server_pem_key_cert_pairs[i]); + } + gpr_free(key_cert_lib->bad_server_pem_key_cert_pairs); + ssl_test_pem_key_cert_pair_destroy(key_cert_lib->client_pem_key_cert_pair); + ssl_test_pem_key_cert_pair_destroy( + key_cert_lib->bad_client_pem_key_cert_pair); + gpr_free(key_cert_lib->root_cert); + tsi_ssl_root_certs_store_destroy(key_cert_lib->root_store); + gpr_free(key_cert_lib); + if (ssl_fixture->session_cache != nullptr) { + tsi_ssl_session_cache_unref(ssl_fixture->session_cache); + } + /* Unreference others. */ + tsi_ssl_server_handshaker_factory_unref( + ssl_fixture->server_handshaker_factory); + tsi_ssl_client_handshaker_factory_unref( + ssl_fixture->client_handshaker_factory); +} + +static const struct tsi_test_fixture_vtable vtable = { + ssl_test_setup_handshakers, ssl_test_check_handshaker_peers, + ssl_test_destruct}; + +static char* load_file(const char* dir_path, const char* file_name) { + char* file_path = static_cast( + gpr_zalloc(sizeof(char) * (strlen(dir_path) + strlen(file_name) + 1))); + memcpy(file_path, dir_path, strlen(dir_path)); + memcpy(file_path + strlen(dir_path), file_name, strlen(file_name)); + grpc_slice slice; + GPR_ASSERT(grpc_load_file(file_path, 1, &slice) == GRPC_ERROR_NONE); + char* data = grpc_slice_to_c_string(slice); + grpc_slice_unref(slice); + gpr_free(file_path); + return data; +} + +static tsi_test_fixture* ssl_tsi_test_fixture_create() { + ssl_tsi_test_fixture* ssl_fixture = grpc_core::Zalloc(); + tsi_test_fixture_init(&ssl_fixture->base); + ssl_fixture->base.test_unused_bytes = true; + ssl_fixture->base.vtable = &vtable; + /* Create ssl_key_cert_lib. 
*/ + ssl_key_cert_lib* key_cert_lib = grpc_core::Zalloc(); + key_cert_lib->use_bad_server_cert = false; + key_cert_lib->use_bad_client_cert = false; + key_cert_lib->use_root_store = false; + key_cert_lib->server_num_key_cert_pairs = + SSL_TSI_TEST_SERVER_KEY_CERT_PAIRS_NUM; + key_cert_lib->bad_server_num_key_cert_pairs = + SSL_TSI_TEST_BAD_SERVER_KEY_CERT_PAIRS_NUM; + key_cert_lib->server_pem_key_cert_pairs = + static_cast( + gpr_malloc(sizeof(tsi_ssl_pem_key_cert_pair) * + key_cert_lib->server_num_key_cert_pairs)); + key_cert_lib->bad_server_pem_key_cert_pairs = + static_cast( + gpr_malloc(sizeof(tsi_ssl_pem_key_cert_pair) * + key_cert_lib->bad_server_num_key_cert_pairs)); + key_cert_lib->server_pem_key_cert_pairs[0].private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.key"); + key_cert_lib->server_pem_key_cert_pairs[0].cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.pem"); + key_cert_lib->server_pem_key_cert_pairs[1].private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server1.key"); + key_cert_lib->server_pem_key_cert_pairs[1].cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server1.pem"); + key_cert_lib->bad_server_pem_key_cert_pairs[0].private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badserver.key"); + key_cert_lib->bad_server_pem_key_cert_pairs[0].cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badserver.pem"); + key_cert_lib->client_pem_key_cert_pair.private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.key"); + key_cert_lib->client_pem_key_cert_pair.cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.pem"); + key_cert_lib->bad_client_pem_key_cert_pair.private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badclient.key"); + key_cert_lib->bad_client_pem_key_cert_pair.cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "badclient.pem"); + key_cert_lib->root_cert = load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "ca.pem"); + key_cert_lib->root_store = + tsi_ssl_root_certs_store_create(key_cert_lib->root_cert); + GPR_ASSERT(key_cert_lib->root_store != nullptr); + ssl_fixture->key_cert_lib = key_cert_lib; + /* Create ssl_alpn_lib. */ + ssl_alpn_lib* alpn_lib = grpc_core::Zalloc(); + alpn_lib->server_alpn_protocols = static_cast( + gpr_zalloc(sizeof(char*) * SSL_TSI_TEST_ALPN_NUM)); + alpn_lib->client_alpn_protocols = static_cast( + gpr_zalloc(sizeof(char*) * SSL_TSI_TEST_ALPN_NUM)); + alpn_lib->server_alpn_protocols[0] = gpr_strdup(SSL_TSI_TEST_ALPN1); + alpn_lib->server_alpn_protocols[1] = gpr_strdup(SSL_TSI_TEST_ALPN3); + alpn_lib->client_alpn_protocols[0] = gpr_strdup(SSL_TSI_TEST_ALPN2); + alpn_lib->client_alpn_protocols[1] = gpr_strdup(SSL_TSI_TEST_ALPN3); + alpn_lib->num_server_alpn_protocols = SSL_TSI_TEST_ALPN_NUM; + alpn_lib->num_client_alpn_protocols = SSL_TSI_TEST_ALPN_NUM; + alpn_lib->alpn_mode = NO_ALPN; + ssl_fixture->alpn_lib = alpn_lib; + ssl_fixture->base.vtable = &vtable; + ssl_fixture->server_name_indication = nullptr; + ssl_fixture->session_reused = false; + ssl_fixture->session_ticket_key = nullptr; + ssl_fixture->session_ticket_key_size = 0; + ssl_fixture->force_client_auth = false; + return &ssl_fixture->base; +} + +void ssl_tsi_test_do_handshake_tiny_handshake_buffer() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_tiny_handshake_buffer"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_TINY_HANDSHAKE_BUFFER_SIZE; + // Handshake buffer is too small to hold both handshake messages and the + // unused bytes. 
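Editorial note, not part of this patch: the ssl_tsi_test_do_handshake_* cases in this part of the file all repeat the same create-fixture, tweak-one-knob, run-handshake, destroy sequence. A small helper like the sketch below captures that shape; the helper name is hypothetical, while the fixture type and the tsi_test_* driver functions are the ones already used in this file.

static void run_ssl_handshake_case(void (*tweak)(ssl_tsi_test_fixture*)) {
  tsi_test_fixture* fixture = ssl_tsi_test_fixture_create();
  if (tweak != nullptr) {
    // The concrete cases tweak fields such as force_client_auth,
    // server_name_indication, or key_cert_lib->use_bad_server_cert.
    tweak(reinterpret_cast<ssl_tsi_test_fixture*>(fixture));
  }
  tsi_test_do_handshake(fixture);
  tsi_test_fixture_destroy(fixture);
}

Each individual case then reduces to one call that passes a small captureless tweak function (or nullptr for the plain handshake case).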
+ fixture->test_unused_bytes = false; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_small_handshake_buffer() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_small_handshake_buffer"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + fixture->handshake_buffer_size = TSI_TEST_SMALL_HANDSHAKE_BUFFER_SIZE; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_root_store() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_with_root_store"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->key_cert_lib->use_root_store = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_client_authentication() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_with_client_authentication"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->force_client_auth = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_client_authentication_and_root_store() { + gpr_log( + GPR_INFO, + "ssl_tsi_test_do_handshake_with_client_authentication_and_root_store"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->force_client_auth = true; + ssl_fixture->key_cert_lib->use_root_store = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_server_name_indication_exact_domain() { + gpr_log(GPR_INFO, + "ssl_tsi_test_do_handshake_with_server_name_indication_exact_domain"); + /* server1 cert contains "waterzooi.test.google.be" in SAN. */ + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->server_name_indication = + const_cast("waterzooi.test.google.be"); + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_server_name_indication_wild_star_domain() { + gpr_log( + GPR_INFO, + "ssl_tsi_test_do_handshake_with_server_name_indication_wild_star_domain"); + /* server1 cert contains "*.test.google.fr" in SAN. */ + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->server_name_indication = + const_cast("juju.test.google.fr"); + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_wrong_server_name_indication() { + gpr_log(GPR_INFO, + "ssl_tsi_test_do_handshake_with_wrong_server_name_indication"); + /* server certs do not contain "test.google.cn". 
*/ + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->server_name_indication = + const_cast(SSL_TSI_TEST_WRONG_SNI); + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_bad_server_cert() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_with_bad_server_cert"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->key_cert_lib->use_bad_server_cert = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_with_bad_client_cert() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_with_bad_client_cert"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->key_cert_lib->use_bad_client_cert = true; + ssl_fixture->force_client_auth = true; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_client_no_server() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_alpn_client_no_server"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->alpn_lib->alpn_mode = ALPN_CLIENT_NO_SERVER; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_server_no_client() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_alpn_server_no_client"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->alpn_lib->alpn_mode = ALPN_SERVER_NO_CLIENT; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_client_server_mismatch() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_alpn_server_no_client"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->alpn_lib->alpn_mode = ALPN_CLIENT_SERVER_MISMATCH; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_handshake_alpn_client_server_ok() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_alpn_client_server_ok"); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->alpn_lib->alpn_mode = ALPN_CLIENT_SERVER_OK; + tsi_test_do_handshake(fixture); + tsi_test_fixture_destroy(fixture); +} + +void ssl_tsi_test_do_round_trip_for_all_configs() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_round_trip_for_all_configs"); + unsigned int* bit_array = static_cast( + gpr_zalloc(sizeof(unsigned int) * TSI_TEST_NUM_OF_ARGUMENTS)); + const unsigned int mask = 1U << (TSI_TEST_NUM_OF_ARGUMENTS - 1); + for (unsigned int val = 0; val < TSI_TEST_NUM_OF_COMBINATIONS; val++) { + unsigned int v = val; + for (unsigned int ind = 0; ind < TSI_TEST_NUM_OF_ARGUMENTS; ind++) { + bit_array[ind] = (v & mask) ? 
1 : 0; + v <<= 1; + } + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + tsi_test_frame_protector_config_destroy(ssl_fixture->base.config); + ssl_fixture->base.config = tsi_test_frame_protector_config_create( + bit_array[0], bit_array[1], bit_array[2], bit_array[3], bit_array[4], + bit_array[5], bit_array[6]); + tsi_test_do_round_trip(&ssl_fixture->base); + tsi_test_fixture_destroy(fixture); + } + gpr_free(bit_array); +} + +void ssl_tsi_test_do_round_trip_with_error_on_stack() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_round_trip_with_error_on_stack"); + // Invoke an SSL function that causes an error, and ensure the error + // makes it to the stack. + GPR_ASSERT(!EC_KEY_new_by_curve_name(NID_rsa)); + GPR_ASSERT(ERR_peek_error() != 0); + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + tsi_test_do_round_trip(fixture); + tsi_test_fixture_destroy(fixture); +} + +static bool is_slow_build() { +#if defined(GPR_ARCH_32) || defined(__APPLE__) + return true; +#else + return BuiltUnderMsan() || BuiltUnderTsan(); +#endif +} + +void ssl_tsi_test_do_round_trip_odd_buffer_size() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_round_trip_odd_buffer_size"); + const size_t odd_sizes[] = {1025, 2051, 4103, 8207, 16409}; + size_t size = sizeof(odd_sizes) / sizeof(size_t); + // 1. This test is extremely slow under MSAN and TSAN. + // 2. On 32-bit, the test is much slower (probably due to lack of boringssl + // asm optimizations) so we only run a subset of tests to avoid timeout. + // 3. On Mac OS, we have slower testing machines so we only run a subset + // of tests to avoid timeout. + if (is_slow_build()) { + size = 1; + } + for (size_t ind1 = 0; ind1 < size; ind1++) { + for (size_t ind2 = 0; ind2 < size; ind2++) { + for (size_t ind3 = 0; ind3 < size; ind3++) { + for (size_t ind4 = 0; ind4 < size; ind4++) { + for (size_t ind5 = 0; ind5 < size; ind5++) { + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + tsi_test_frame_protector_config_set_buffer_size( + ssl_fixture->base.config, odd_sizes[ind1], odd_sizes[ind2], + odd_sizes[ind3], odd_sizes[ind4], odd_sizes[ind5]); + tsi_test_do_round_trip(&ssl_fixture->base); + tsi_test_fixture_destroy(fixture); + } + } + } + } + } +} + +void ssl_tsi_test_do_handshake_session_cache() { + gpr_log(GPR_INFO, "ssl_tsi_test_do_handshake_session_cache"); + tsi_ssl_session_cache* session_cache = tsi_ssl_session_cache_create_lru(16); + char session_ticket_key[kSessionTicketEncryptionKeySize]; + auto do_handshake = [&session_ticket_key, + &session_cache](bool session_reused) { + tsi_test_fixture* fixture = ssl_tsi_test_fixture_create(); + ssl_tsi_test_fixture* ssl_fixture = + reinterpret_cast(fixture); + ssl_fixture->server_name_indication = + const_cast("waterzooi.test.google.be"); + ssl_fixture->session_ticket_key = session_ticket_key; + ssl_fixture->session_ticket_key_size = sizeof(session_ticket_key); + tsi_ssl_session_cache_ref(session_cache); + ssl_fixture->session_cache = session_cache; + ssl_fixture->session_reused = session_reused; + tsi_test_do_round_trip(&ssl_fixture->base); + tsi_test_fixture_destroy(fixture); + }; + memset(session_ticket_key, 'a', sizeof(session_ticket_key)); + do_handshake(false); + do_handshake(true); + do_handshake(true); + // Changing session_ticket_key on server invalidates ticket. 
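Editorial note, not part of this patch: the bit-extraction loop in ssl_tsi_test_do_round_trip_for_all_configs() above walks the TSI_TEST_NUM_OF_ARGUMENTS configuration bits most-significant-first, so bit_array[0] holds the top bit of val. An equivalent shift-based formulation, assuming the same constants, is:

// Equivalent to: mask = 1U << (TSI_TEST_NUM_OF_ARGUMENTS - 1); v = val;
// bit_array[ind] = (v & mask) ? 1 : 0; v <<= 1;
for (unsigned int ind = 0; ind < TSI_TEST_NUM_OF_ARGUMENTS; ind++) {
  bit_array[ind] = (val >> (TSI_TEST_NUM_OF_ARGUMENTS - 1 - ind)) & 1U;
}

For example, with seven arguments and val == 0b0000101, bit_array becomes {0, 0, 0, 0, 1, 0, 1}.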
+ memset(session_ticket_key, 'b', sizeof(session_ticket_key)); + do_handshake(false); + do_handshake(true); + memset(session_ticket_key, 'c', sizeof(session_ticket_key)); + do_handshake(false); + do_handshake(true); + tsi_ssl_session_cache_unref(session_cache); +} + +static const tsi_ssl_handshaker_factory_vtable* original_vtable; +static bool handshaker_factory_destructor_called; + +static void ssl_tsi_test_handshaker_factory_destructor( + tsi_ssl_handshaker_factory* factory) { + GPR_ASSERT(factory != nullptr); + handshaker_factory_destructor_called = true; + if (original_vtable != nullptr && original_vtable->destroy != nullptr) { + original_vtable->destroy(factory); + } +} + +static tsi_ssl_handshaker_factory_vtable test_handshaker_factory_vtable = { + ssl_tsi_test_handshaker_factory_destructor}; + +void test_tsi_ssl_client_handshaker_factory_refcounting() { + int i; + char* cert_chain = load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "client.pem"); + + tsi_ssl_client_handshaker_options options; + options.pem_root_certs = cert_chain; + tsi_ssl_client_handshaker_factory* client_handshaker_factory; + GPR_ASSERT(tsi_create_ssl_client_handshaker_factory_with_options( + &options, &client_handshaker_factory) == TSI_OK); + + handshaker_factory_destructor_called = false; + original_vtable = tsi_ssl_handshaker_factory_swap_vtable( + reinterpret_cast(client_handshaker_factory), + &test_handshaker_factory_vtable); + + tsi_handshaker* handshaker[3]; + + for (i = 0; i < 3; ++i) { + GPR_ASSERT(tsi_ssl_client_handshaker_factory_create_handshaker( + client_handshaker_factory, "google.com", &handshaker[i]) == + TSI_OK); + } + + tsi_handshaker_destroy(handshaker[1]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[0]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[2]); + GPR_ASSERT(handshaker_factory_destructor_called); + + gpr_free(cert_chain); +} + +void test_tsi_ssl_server_handshaker_factory_refcounting() { + int i; + tsi_ssl_server_handshaker_factory* server_handshaker_factory; + tsi_handshaker* handshaker[3]; + const char* cert_chain = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.pem"); + tsi_ssl_pem_key_cert_pair cert_pair; + + cert_pair.cert_chain = cert_chain; + cert_pair.private_key = + load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server0.key"); + tsi_ssl_server_handshaker_options options; + options.pem_key_cert_pairs = &cert_pair; + options.num_key_cert_pairs = 1; + options.pem_client_root_certs = cert_chain; + + GPR_ASSERT(tsi_create_ssl_server_handshaker_factory_with_options( + &options, &server_handshaker_factory) == TSI_OK); + + handshaker_factory_destructor_called = false; + original_vtable = tsi_ssl_handshaker_factory_swap_vtable( + reinterpret_cast(server_handshaker_factory), + &test_handshaker_factory_vtable); + + for (i = 0; i < 3; ++i) { + GPR_ASSERT(tsi_ssl_server_handshaker_factory_create_handshaker( + server_handshaker_factory, &handshaker[i]) == TSI_OK); + } + + tsi_handshaker_destroy(handshaker[1]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[0]); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory); + GPR_ASSERT(!handshaker_factory_destructor_called); + + tsi_handshaker_destroy(handshaker[2]); + GPR_ASSERT(handshaker_factory_destructor_called); + + 
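Editorial note, not part of this patch: the two refcounting tests around this point establish that every handshaker created from a factory holds a reference to it, so the factory's destructor only runs after the caller's unref and the destruction of the last handshaker. The sketch below spells out what that guarantee means for an ordinary caller of the public API; the helper name and SNI string are hypothetical.

static void use_client_factory_then_release(
    tsi_ssl_client_handshaker_factory* factory) {
  tsi_handshaker* handshaker = nullptr;
  GPR_ASSERT(tsi_ssl_client_handshaker_factory_create_handshaker(
                 factory, "example.host", &handshaker) == TSI_OK);
  // Dropping the caller's reference early is safe: the handshaker keeps the
  // factory alive.
  tsi_ssl_client_handshaker_factory_unref(factory);
  // ... drive the handshake with tsi_handshaker_next() ...
  tsi_handshaker_destroy(handshaker);  // Last reference; factory freed here.
}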
ssl_test_pem_key_cert_pair_destroy(cert_pair); +} + +/* Attempting to create a handshaker factory with invalid parameters should fail + * but not crash. */ +void test_tsi_ssl_client_handshaker_factory_bad_params() { + const char* cert_chain = "This is not a valid PEM file."; + + tsi_ssl_client_handshaker_factory* client_handshaker_factory; + tsi_ssl_client_handshaker_options options; + options.pem_root_certs = cert_chain; + GPR_ASSERT(tsi_create_ssl_client_handshaker_factory_with_options( + &options, &client_handshaker_factory) == TSI_INVALID_ARGUMENT); + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory); +} + +void ssl_tsi_test_handshaker_factory_internals() { + gpr_log(GPR_INFO, "ssl_tsi_test_handshaker_factory_internals"); + test_tsi_ssl_client_handshaker_factory_refcounting(); + test_tsi_ssl_server_handshaker_factory_refcounting(); + test_tsi_ssl_client_handshaker_factory_bad_params(); +} + +void ssl_tsi_test_duplicate_root_certificates() { + gpr_log(GPR_INFO, "ssl_tsi_test_duplicate_root_certificates"); + char* root_cert = load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "ca.pem"); + char* dup_root_cert = static_cast( + gpr_zalloc(sizeof(char) * (strlen(root_cert) * 2 + 1))); + memcpy(dup_root_cert, root_cert, strlen(root_cert)); + memcpy(dup_root_cert + strlen(root_cert), root_cert, strlen(root_cert)); + tsi_ssl_root_certs_store* root_store = + tsi_ssl_root_certs_store_create(dup_root_cert); + GPR_ASSERT(root_store != nullptr); + // Free memory. + tsi_ssl_root_certs_store_destroy(root_store); + gpr_free(root_cert); + gpr_free(dup_root_cert); +} + +void ssl_tsi_test_extract_x509_subject_names() { + gpr_log(GPR_INFO, "ssl_tsi_test_extract_x509_subject_names"); + char* cert = load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "multi-domain.pem"); + tsi_peer peer; + GPR_ASSERT(tsi_ssl_extract_x509_subject_names_from_pem_cert(cert, &peer) == + TSI_OK); + // tsi_peer should include one common name, one certificate, one security + // level, ten SAN fields, two DNS SAN fields, three URI fields, two email + // addresses and two IP addresses. + size_t expected_property_count = 21; + GPR_ASSERT(peer.property_count == expected_property_count); + // Check common name + const char* expected_cn = "xpigors"; + const tsi_peer_property* property = tsi_peer_get_property_by_name( + &peer, TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY); + GPR_ASSERT(property != nullptr); + GPR_ASSERT( + memcmp(property->value.data, expected_cn, property->value.length) == 0); + // Check certificate data + property = tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); + GPR_ASSERT(property != nullptr); + GPR_ASSERT(memcmp(property->value.data, cert, property->value.length) == 0); + // Check DNS + GPR_ASSERT(check_subject_alt_name(&peer, "foo.test.domain.com") == 1); + GPR_ASSERT(check_subject_alt_name(&peer, "bar.test.domain.com") == 1); + GPR_ASSERT(check_dns(&peer, "foo.test.domain.com") == 1); + GPR_ASSERT(check_dns(&peer, "bar.test.domain.com") == 1); + // Check URI + // Note that a valid SPIFFE certificate should only have one URI. 
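Editorial debugging aid, not part of this patch: when the expected_property_count assertion in ssl_tsi_test_extract_x509_subject_names() fails, it helps to see which properties were actually extracted. The sketch below relies only on the tsi_peer_property fields already exercised by check_property().

static void dump_peer_properties(const tsi_peer* peer) {
  for (size_t i = 0; i < peer->property_count; i++) {
    const tsi_peer_property* prop = &peer->properties[i];
    gpr_log(GPR_INFO, "peer property %d: %s = %.*s", static_cast<int>(i),
            prop->name, static_cast<int>(prop->value.length),
            prop->value.data);
  }
}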
+ GPR_ASSERT(check_subject_alt_name(&peer, "spiffe://foo.com/bar/baz") == 1); + GPR_ASSERT( + check_subject_alt_name(&peer, "https://foo.test.domain.com/test") == 1); + GPR_ASSERT( + check_subject_alt_name(&peer, "https://bar.test.domain.com/test") == 1); + GPR_ASSERT(check_uri(&peer, "spiffe://foo.com/bar/baz") == 1); + GPR_ASSERT(check_uri(&peer, "https://foo.test.domain.com/test") == 1); + GPR_ASSERT(check_uri(&peer, "https://bar.test.domain.com/test") == 1); + // Check email address + GPR_ASSERT(check_subject_alt_name(&peer, "foo@test.domain.com") == 1); + GPR_ASSERT(check_subject_alt_name(&peer, "bar@test.domain.com") == 1); + GPR_ASSERT(check_email(&peer, "foo@test.domain.com") == 1); + GPR_ASSERT(check_email(&peer, "bar@test.domain.com") == 1); + // Check ip address + GPR_ASSERT(check_subject_alt_name(&peer, "192.168.7.1") == 1); + GPR_ASSERT(check_subject_alt_name(&peer, "13::17") == 1); + GPR_ASSERT(check_ip(&peer, "192.168.7.1") == 1); + GPR_ASSERT(check_ip(&peer, "13::17") == 1); + // Check other fields + GPR_ASSERT(check_subject_alt_name(&peer, "other types of SAN") == 1); + // Free memory + gpr_free(cert); + tsi_peer_destruct(&peer); +} + +void ssl_tsi_test_extract_cert_chain() { + gpr_log(GPR_INFO, "ssl_tsi_test_extract_cert_chain"); + char* cert = load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "server1.pem"); + char* ca = load_file(SSL_TSI_TEST_CREDENTIALS_DIR, "ca.pem"); + char* chain = static_cast( + gpr_zalloc(sizeof(char) * (strlen(cert) + strlen(ca) + 1))); + memcpy(chain, cert, strlen(cert)); + memcpy(chain + strlen(cert), ca, strlen(ca)); + STACK_OF(X509)* cert_chain = sk_X509_new_null(); + GPR_ASSERT(cert_chain != nullptr); + BIO* bio = BIO_new_mem_buf(chain, strlen(chain)); + GPR_ASSERT(bio != nullptr); + STACK_OF(X509_INFO)* certInfos = + PEM_X509_INFO_read_bio(bio, nullptr, nullptr, nullptr); + GPR_ASSERT(certInfos != nullptr); + for (size_t i = 0; i < sk_X509_INFO_num(certInfos); i++) { + X509_INFO* certInfo = sk_X509_INFO_value(certInfos, i); + if (certInfo->x509 != nullptr) { + GPR_ASSERT(sk_X509_push(cert_chain, certInfo->x509) != 0); +#if OPENSSL_VERSION_NUMBER >= 0x10100000 + X509_up_ref(certInfo->x509); +#else + certInfo->x509->references += 1; +#endif + } + } + tsi_peer_property chain_property; + GPR_ASSERT(tsi_ssl_get_cert_chain_contents(cert_chain, &chain_property) == + TSI_OK); + GPR_ASSERT(memcmp(chain, chain_property.value.data, + chain_property.value.length) == 0); + BIO_free(bio); + gpr_free(chain); + gpr_free(cert); + gpr_free(ca); + tsi_peer_property_destruct(&chain_property); + sk_X509_INFO_pop_free(certInfos, X509_INFO_free); + sk_X509_pop_free(cert_chain, X509_free); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + const size_t number_tls_versions = 2; + const tsi_tls_version tls_versions[] = {tsi_tls_version::TSI_TLS1_2, + tsi_tls_version::TSI_TLS1_3}; + for (size_t i = 0; i < number_tls_versions; i++) { + // Set the TLS version to be used in the tests. + test_tls_version = tls_versions[i]; + // Run all the tests using that TLS version for both the client and server. 
+ ssl_tsi_test_do_handshake_tiny_handshake_buffer(); + ssl_tsi_test_do_handshake_small_handshake_buffer(); + ssl_tsi_test_do_handshake(); + ssl_tsi_test_do_handshake_with_root_store(); + ssl_tsi_test_do_handshake_with_client_authentication(); + ssl_tsi_test_do_handshake_with_client_authentication_and_root_store(); + ssl_tsi_test_do_handshake_with_server_name_indication_exact_domain(); + ssl_tsi_test_do_handshake_with_server_name_indication_wild_star_domain(); + ssl_tsi_test_do_handshake_with_wrong_server_name_indication(); + ssl_tsi_test_do_handshake_with_bad_server_cert(); + ssl_tsi_test_do_handshake_with_bad_client_cert(); +#ifdef OPENSSL_IS_BORINGSSL + // BoringSSL and OpenSSL have different behaviors on mismatched ALPN. + ssl_tsi_test_do_handshake_alpn_client_no_server(); + ssl_tsi_test_do_handshake_alpn_client_server_mismatch(); +#endif + ssl_tsi_test_do_handshake_alpn_server_no_client(); + ssl_tsi_test_do_handshake_alpn_client_server_ok(); + ssl_tsi_test_do_handshake_session_cache(); + ssl_tsi_test_do_round_trip_for_all_configs(); + ssl_tsi_test_do_round_trip_with_error_on_stack(); + ssl_tsi_test_do_round_trip_odd_buffer_size(); + ssl_tsi_test_handshaker_factory_internals(); + ssl_tsi_test_duplicate_root_certificates(); + ssl_tsi_test_extract_x509_subject_names(); + ssl_tsi_test_extract_cert_chain(); + } + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/transport_security_test.cc b/test/core/tsi/transport_security_test.cc new file mode 100644 index 00000000..6000daed --- /dev/null +++ b/test/core/tsi/transport_security_test.cc @@ -0,0 +1,390 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/transport_security.h" + +#include + +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/tsi/fake_transport_security.h" +#include "src/core/tsi/ssl_transport_security.h" +#include "test/core/util/test_config.h" + +typedef struct { + /* 1 if success, 0 if failure. */ + int expected; + + /* Host name to match. */ + const char* host_name; + + /* Common name (CN). */ + const char* common_name; + + /* Comma separated list of certificate names to match against. Any occurrence + of '#' will be replaced with a null character before processing. */ + const char* dns_names; + + /* Comma separated list of IP SANs to match aggainst */ + const char* ip_names; +} cert_name_test_entry; + +/* Largely inspired from: + chromium/src/net/cert/x509_certificate_unittest.cc. + TODO(jboeuf) uncomment test cases as we fix tsi_ssl_peer_matches_name. 
*/ +const cert_name_test_entry cert_name_test_entries[] = { + {1, "foo.com", "foo.com", nullptr, nullptr}, + {1, "f", "f", nullptr, nullptr}, + {0, "h", "i", nullptr, nullptr}, + {1, "bar.foo.com", "*.foo.com", nullptr, nullptr}, + {1, "www.test.fr", "common.name", + "*.test.com,*.test.co.uk,*.test.de,*.test.fr", nullptr}, + /* + {1, "wwW.tESt.fr", "common.name", ",*.*,*.test.de,*.test.FR,www"}, + */ + {0, "f.uk", ".uk", nullptr, nullptr}, + {0, "w.bar.foo.com", "?.bar.foo.com", nullptr, nullptr}, + {0, "www.foo.com", "(www|ftp).foo.com", nullptr, nullptr}, + {0, "www.foo.com", "www.foo.com#", nullptr, nullptr}, /* # = null char. */ + {0, "www.foo.com", "", "www.foo.com#*.foo.com,#,#", nullptr}, + {0, "www.house.example", "ww.house.example", nullptr, nullptr}, + {0, "test.org", "", "www.test.org,*.test.org,*.org", nullptr}, + {0, "w.bar.foo.com", "w*.bar.foo.com", nullptr, nullptr}, + {0, "www.bar.foo.com", "ww*ww.bar.foo.com", nullptr, nullptr}, + {0, "wwww.bar.foo.com", "ww*ww.bar.foo.com", nullptr, nullptr}, + {0, "wwww.bar.foo.com", "w*w.bar.foo.com", nullptr, nullptr}, + {0, "wwww.bar.foo.com", "w*w.bar.foo.c0m", nullptr, nullptr}, + {0, "WALLY.bar.foo.com", "wa*.bar.foo.com", nullptr, nullptr}, + {0, "wally.bar.foo.com", "*Ly.bar.foo.com", nullptr, nullptr}, + /* + {1, "ww%57.foo.com", "", "www.foo.com"}, + {1, "www&.foo.com", "www%26.foo.com", NULL}, + */ + + /* Common name must not be used if subject alternative name was provided. */ + {0, "www.test.co.jp", "www.test.co.jp", + "*.test.de,*.jp,www.test.co.uk,www.*.co.jp", nullptr}, + {0, "www.bar.foo.com", "www.bar.foo.com", + "*.foo.com,*.*.foo.com,*.*.bar.foo.com,*..bar.foo.com,", nullptr}, + + /* IDN tests */ + {1, "xn--poema-9qae5a.com.br", "xn--poema-9qae5a.com.br", nullptr, nullptr}, + {1, "www.xn--poema-9qae5a.com.br", "*.xn--poema-9qae5a.com.br", nullptr, + nullptr}, + {0, "xn--poema-9qae5a.com.br", "", + "*.xn--poema-9qae5a.com.br," + "xn--poema-*.com.br," + "xn--*-9qae5a.com.br," + "*--poema-9qae5a.com.br", + nullptr}, + + /* The following are adapted from the examples quoted from + http://tools.ietf.org/html/rfc6125#section-6.4.3 + (e.g., *.example.com would match foo.example.com but + not bar.foo.example.com or example.com). */ + {1, "foo.example.com", "*.example.com", nullptr, nullptr}, + {0, "bar.foo.example.com", "*.example.com", nullptr, nullptr}, + {0, "example.com", "*.example.com", nullptr, nullptr}, + + /* Partial wildcards are disallowed, though RFC 2818 rules allow them. + That is, forms such as baz*.example.net, *baz.example.net, and + b*z.example.net should NOT match domains. Instead, the wildcard must + always be the left-most label, and only a single label. */ + {0, "baz1.example.net", "baz*.example.net", nullptr, nullptr}, + {0, "foobaz.example.net", "*baz.example.net", nullptr, nullptr}, + {0, "buzz.example.net", "b*z.example.net", nullptr, nullptr}, + {0, "www.test.example.net", "www.*.example.net", nullptr, nullptr}, + + /* Wildcards should not be valid for public registry controlled domains, + and unknown/unrecognized domains, at least three domain components must + be present. */ + {1, "www.test.example", "*.test.example", nullptr, nullptr}, + {1, "test.example.co.uk", "*.example.co.uk", nullptr, nullptr}, + {0, "test.example", "*.example", nullptr, nullptr}, + /* + {0, "example.co.uk", "*.co.uk", NULL}, + */ + {0, "foo.com", "*.com", nullptr, nullptr}, + {0, "foo.us", "*.us", nullptr, nullptr}, + {0, "foo", "*", nullptr, nullptr}, + + /* IDN variants of wildcards and registry controlled domains. 
*/ + {1, "www.xn--poema-9qae5a.com.br", "*.xn--poema-9qae5a.com.br", nullptr, + nullptr}, + {1, "test.example.xn--mgbaam7a8h", "*.example.xn--mgbaam7a8h", nullptr, + nullptr}, + /* + {0, "xn--poema-9qae5a.com.br", "*.com.br", NULL}, + */ + {0, "example.xn--mgbaam7a8h", "*.xn--mgbaam7a8h", nullptr, nullptr}, + + /* Wildcards should be permissible for 'private' registry controlled + domains. */ + {1, "www.appspot.com", "*.appspot.com", nullptr, nullptr}, + {1, "foo.s3.amazonaws.com", "*.s3.amazonaws.com", nullptr, nullptr}, + + /* Multiple wildcards are not valid. */ + {0, "foo.example.com", "*.*.com", nullptr, nullptr}, + {0, "foo.bar.example.com", "*.bar.*.com", nullptr, nullptr}, + + /* Absolute vs relative DNS name tests. Although not explicitly specified + in RFC 6125, absolute reference names (those ending in a .) should + match either absolute or relative presented names. */ + {1, "foo.com", "foo.com.", nullptr, nullptr}, + {1, "foo.com.", "foo.com", nullptr, nullptr}, + {1, "foo.com.", "foo.com.", nullptr, nullptr}, + {1, "f", "f.", nullptr, nullptr}, + {1, "f.", "f", nullptr, nullptr}, + {1, "f.", "f.", nullptr, nullptr}, + {1, "www-3.bar.foo.com", "*.bar.foo.com.", nullptr, nullptr}, + {1, "www-3.bar.foo.com.", "*.bar.foo.com", nullptr, nullptr}, + {1, "www-3.bar.foo.com.", "*.bar.foo.com.", nullptr, nullptr}, + {0, ".", ".", nullptr, nullptr}, + {0, "example.com", "*.com.", nullptr, nullptr}, + {0, "example.com.", "*.com", nullptr, nullptr}, + {0, "example.com.", "*.com.", nullptr, nullptr}, + {0, "foo.", "*.", nullptr, nullptr}, + {0, "foo", "*.", nullptr, nullptr}, + /* + {0, "foo.co.uk", "*.co.uk.", NULL}, + {0, "foo.co.uk.", "*.co.uk.", NULL}, + */ + + /* An empty CN is OK. */ + {1, "test.foo.com", "", "test.foo.com", nullptr}, + + /* An IP should not be used for the CN. 
*/ + {0, "173.194.195.139", "173.194.195.139", nullptr, nullptr}, + /* An IP can be used if the SAN IP is present */ + {1, "173.194.195.139", "foo.example.com", nullptr, "173.194.195.139"}, + {0, "173.194.195.139", "foo.example.com", nullptr, "8.8.8.8"}, + {0, "173.194.195.139", "foo.example.com", nullptr, "8.8.8.8,8.8.4.4"}, + {1, "173.194.195.139", "foo.example.com", nullptr, + "8.8.8.8,173.194.195.139"}, + {0, "173.194.195.139", "foo.example.com", nullptr, "173.194.195.13"}, + {0, "2001:db8:a0b:12f0::1", "foo.example.com", nullptr, "173.194.195.13"}, + {1, "2001:db8:a0b:12f0::1", "foo.example.com", nullptr, + "2001:db8:a0b:12f0::1"}, + {0, "2001:db8:a0b:12f0::1", "foo.example.com", nullptr, + "2001:db8:a0b:12f0::2"}, + {1, "2001:db8:a0b:12f0::1", "foo.example.com", nullptr, + "2001:db8:a0b:12f0::2,2001:db8:a0b:12f0::1,8.8.8.8"}, +}; + +typedef struct name_list { + const char* name; + struct name_list* next; +} name_list; + +typedef struct { + size_t name_count; + char* buffer; + name_list* names; +} parsed_names; + +name_list* name_list_add(const char* n) { + name_list* result = static_cast(gpr_malloc(sizeof(name_list))); + result->name = n; + result->next = nullptr; + return result; +} + +static parsed_names parse_names(const char* names_str) { + parsed_names result; + name_list* current_nl; + size_t i; + memset(&result, 0, sizeof(parsed_names)); + if (names_str == nullptr) return result; + result.name_count = 1; + result.buffer = gpr_strdup(names_str); + result.names = name_list_add(result.buffer); + current_nl = result.names; + for (i = 0; i < strlen(names_str); i++) { + if (names_str[i] == ',') { + result.buffer[i] = '\0'; + result.name_count++; + i++; + current_nl->next = name_list_add(result.buffer + i); + current_nl = current_nl->next; + } + } + return result; +} + +static void destruct_parsed_names(parsed_names* pdn) { + name_list* nl = pdn->names; + if (pdn->buffer != nullptr) gpr_free(pdn->buffer); + while (nl != nullptr) { + name_list* to_be_free = nl; + nl = nl->next; + gpr_free(to_be_free); + } +} + +static char* processed_name(const char* name) { + char* result = gpr_strdup(name); + size_t i; + for (i = 0; i < strlen(result); i++) { + if (result[i] == '#') { + result[i] = '\0'; + } + } + return result; +} + +static tsi_peer peer_from_cert_name_test_entry( + const cert_name_test_entry* entry) { + size_t i; + tsi_peer peer; + name_list* nl; + parsed_names dns_entries = parse_names(entry->dns_names); + parsed_names ip_entries = parse_names(entry->ip_names); + nl = dns_entries.names; + GPR_ASSERT(tsi_construct_peer( + 1 + dns_entries.name_count + ip_entries.name_count, &peer) == + TSI_OK); + GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( + TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY, entry->common_name, + &peer.properties[0]) == TSI_OK); + i = 1; + while (nl != nullptr) { + char* processed = processed_name(nl->name); + GPR_ASSERT(tsi_construct_string_peer_property( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, processed, + strlen(nl->name), &peer.properties[i++]) == TSI_OK); + nl = nl->next; + gpr_free(processed); + } + + nl = ip_entries.names; + while (nl != nullptr) { + char* processed = processed_name(nl->name); + GPR_ASSERT(tsi_construct_string_peer_property( + TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, processed, + strlen(nl->name), &peer.properties[i++]) == TSI_OK); + nl = nl->next; + gpr_free(processed); + } + destruct_parsed_names(&dns_entries); + destruct_parsed_names(&ip_entries); + return peer; +} + +std::string 
cert_name_test_entry_to_string(const cert_name_test_entry* entry) { + return absl::StrFormat( + "{ success = %s, host_name = %s, common_name = %s, dns_names = " + "%s, ip_names = %s}", + entry->expected ? "true" : "false", entry->host_name, entry->common_name, + entry->dns_names != nullptr ? entry->dns_names : "", + entry->ip_names != nullptr ? entry->ip_names : ""); +} + +static void test_peer_matches_name(void) { + size_t i = 0; + for (i = 0; i < GPR_ARRAY_SIZE(cert_name_test_entries); i++) { + const cert_name_test_entry* entry = &cert_name_test_entries[i]; + tsi_peer peer = peer_from_cert_name_test_entry(entry); + int result = tsi_ssl_peer_matches_name(&peer, entry->host_name); + if (result != entry->expected) { + gpr_log(GPR_ERROR, "%s", cert_name_test_entry_to_string(entry).c_str()); + GPR_ASSERT(0); /* Unexpected result. */ + } + tsi_peer_destruct(&peer); + } +} + +typedef struct { + tsi_result res; + const char* str; +} tsi_result_string_pair; + +static void test_result_strings(void) { + const tsi_result_string_pair results[] = { + {TSI_OK, "TSI_OK"}, + {TSI_UNKNOWN_ERROR, "TSI_UNKNOWN_ERROR"}, + {TSI_INVALID_ARGUMENT, "TSI_INVALID_ARGUMENT"}, + {TSI_PERMISSION_DENIED, "TSI_PERMISSION_DENIED"}, + {TSI_INCOMPLETE_DATA, "TSI_INCOMPLETE_DATA"}, + {TSI_FAILED_PRECONDITION, "TSI_FAILED_PRECONDITION"}, + {TSI_UNIMPLEMENTED, "TSI_UNIMPLEMENTED"}, + {TSI_INTERNAL_ERROR, "TSI_INTERNAL_ERROR"}, + {TSI_DATA_CORRUPTED, "TSI_DATA_CORRUPTED"}, + {TSI_NOT_FOUND, "TSI_NOT_FOUND"}, + {TSI_PROTOCOL_FAILURE, "TSI_PROTOCOL_FAILURE"}, + {TSI_HANDSHAKE_IN_PROGRESS, "TSI_HANDSHAKE_IN_PROGRESS"}, + {TSI_OUT_OF_RESOURCES, "TSI_OUT_OF_RESOURCES"}}; + size_t i; + for (i = 0; i < GPR_ARRAY_SIZE(results); i++) { + GPR_ASSERT(strcmp(results[i].str, tsi_result_to_string(results[i].res)) == + 0); + } + GPR_ASSERT(strcmp("UNKNOWN", tsi_result_to_string((tsi_result)42)) == 0); +} + +static void test_protector_invalid_args(void) { + GPR_ASSERT(tsi_frame_protector_protect(nullptr, nullptr, nullptr, nullptr, + nullptr) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_frame_protector_protect_flush( + nullptr, nullptr, nullptr, nullptr) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_frame_protector_unprotect(nullptr, nullptr, nullptr, nullptr, + nullptr) == TSI_INVALID_ARGUMENT); +} + +static void test_handshaker_invalid_args(void) { + GPR_ASSERT(tsi_handshaker_get_result(nullptr) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_extract_peer(nullptr, nullptr) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_create_frame_protector(nullptr, nullptr, nullptr) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_process_bytes_from_peer( + nullptr, nullptr, nullptr) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_get_bytes_to_send_to_peer( + nullptr, nullptr, nullptr) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_next(nullptr, nullptr, 0, nullptr, nullptr, nullptr, + nullptr, nullptr) == TSI_INVALID_ARGUMENT); +} + +static void test_handshaker_invalid_state(void) { + tsi_handshaker* h = tsi_create_fake_handshaker(0); + tsi_peer peer; + tsi_frame_protector* p; + GPR_ASSERT(tsi_handshaker_extract_peer(h, &peer) == TSI_FAILED_PRECONDITION); + GPR_ASSERT(tsi_handshaker_create_frame_protector(h, nullptr, &p) == + TSI_FAILED_PRECONDITION); + tsi_handshaker_destroy(h); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_peer_matches_name(); + test_result_strings(); + test_protector_invalid_args(); + test_handshaker_invalid_args(); + test_handshaker_invalid_state(); + 
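Editorial illustration, not part of this patch: extending the name-matching coverage only takes one more initializer in cert_name_test_entries; test_peer_matches_name() picks it up automatically through GPR_ARRAY_SIZE. The entry below is hypothetical. Recall that dns_names and ip_names are comma-separated lists and that any '#' is converted to an embedded NUL by processed_name() before matching.

// Hypothetical additional table entry: an exact DNS SAN match, with the CN
// ignored because SAN entries are present.
static const cert_name_test_entry extra_cert_name_case = {
    /*expected=*/1,
    /*host_name=*/"api.example.test",
    /*common_name=*/"ignored.when.san.present",
    /*dns_names=*/"*.example.test,api.example.test",
    /*ip_names=*/nullptr};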
return 0; +} diff --git a/test/core/tsi/transport_security_test_lib.cc b/test/core/tsi/transport_security_test_lib.cc new file mode 100644 index 00000000..6548af3d --- /dev/null +++ b/test/core/tsi/transport_security_test_lib.cc @@ -0,0 +1,666 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/tsi/transport_security_test_lib.h" + +#include +#include +#include + +#include +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/security/transport/tsi_error.h" + +static void notification_signal(tsi_test_fixture* fixture) { + gpr_mu_lock(&fixture->mu); + fixture->notified = true; + gpr_cv_signal(&fixture->cv); + gpr_mu_unlock(&fixture->mu); +} + +static void notification_wait(tsi_test_fixture* fixture) { + gpr_mu_lock(&fixture->mu); + while (!fixture->notified) { + gpr_cv_wait(&fixture->cv, &fixture->mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + fixture->notified = false; + gpr_mu_unlock(&fixture->mu); +} + +typedef struct handshaker_args { + tsi_test_fixture* fixture; + unsigned char* handshake_buffer; + size_t handshake_buffer_size; + bool is_client; + bool transferred_data; + bool appended_unused_bytes; + grpc_error_handle error; +} handshaker_args; + +static handshaker_args* handshaker_args_create(tsi_test_fixture* fixture, + bool is_client) { + GPR_ASSERT(fixture != nullptr); + GPR_ASSERT(fixture->config != nullptr); + handshaker_args* args = new handshaker_args(); + args->fixture = fixture; + args->handshake_buffer_size = fixture->handshake_buffer_size; + args->handshake_buffer = + static_cast(gpr_zalloc(args->handshake_buffer_size)); + args->is_client = is_client; + args->error = GRPC_ERROR_NONE; + return args; +} + +static void handshaker_args_destroy(handshaker_args* args) { + gpr_free(args->handshake_buffer); + GRPC_ERROR_UNREF(args->error); + delete args; +} + +static void do_handshaker_next(handshaker_args* args); + +static void setup_handshakers(tsi_test_fixture* fixture) { + GPR_ASSERT(fixture != nullptr); + GPR_ASSERT(fixture->vtable != nullptr); + GPR_ASSERT(fixture->vtable->setup_handshakers != nullptr); + fixture->vtable->setup_handshakers(fixture); +} + +static void check_unused_bytes(tsi_test_fixture* fixture) { + tsi_handshaker_result* result_with_unused_bytes = + fixture->has_client_finished_first ? fixture->server_result + : fixture->client_result; + tsi_handshaker_result* result_without_unused_bytes = + fixture->has_client_finished_first ? 
fixture->client_result + : fixture->server_result; + const unsigned char* bytes = nullptr; + size_t bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes( + result_with_unused_bytes, &bytes, &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == strlen(TSI_TEST_UNUSED_BYTES)); + GPR_ASSERT(memcmp(bytes, TSI_TEST_UNUSED_BYTES, bytes_size) == 0); + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes( + result_without_unused_bytes, &bytes, &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == 0); + GPR_ASSERT(bytes == nullptr); +} + +static void check_handshake_results(tsi_test_fixture* fixture) { + GPR_ASSERT(fixture != nullptr); + GPR_ASSERT(fixture->vtable != nullptr); + GPR_ASSERT(fixture->vtable->check_handshaker_peers != nullptr); + /* Check handshaker peers. */ + fixture->vtable->check_handshaker_peers(fixture); + /* Check unused bytes. */ + if (fixture->test_unused_bytes) { + tsi_test_channel* channel = fixture->channel; + if (fixture->server_result != nullptr && + fixture->client_result != nullptr) { + check_unused_bytes(fixture); + } + channel->bytes_written_to_server_channel = 0; + channel->bytes_written_to_client_channel = 0; + channel->bytes_read_from_client_channel = 0; + channel->bytes_read_from_server_channel = 0; + } +} + +static void send_bytes_to_peer(tsi_test_channel* test_channel, + const unsigned char* buf, size_t buf_size, + bool is_client) { + GPR_ASSERT(test_channel != nullptr); + GPR_ASSERT(buf != nullptr); + uint8_t* channel = + is_client ? test_channel->server_channel : test_channel->client_channel; + GPR_ASSERT(channel != nullptr); + size_t* bytes_written = is_client + ? &test_channel->bytes_written_to_server_channel + : &test_channel->bytes_written_to_client_channel; + GPR_ASSERT(bytes_written != nullptr); + GPR_ASSERT(*bytes_written + buf_size <= TSI_TEST_DEFAULT_CHANNEL_SIZE); + /* Write data to channel. */ + memcpy(channel + *bytes_written, buf, buf_size); + *bytes_written += buf_size; +} + +static void maybe_append_unused_bytes(handshaker_args* args) { + GPR_ASSERT(args != nullptr); + GPR_ASSERT(args->fixture != nullptr); + tsi_test_fixture* fixture = args->fixture; + if (fixture->test_unused_bytes && !args->appended_unused_bytes) { + args->appended_unused_bytes = true; + send_bytes_to_peer( + fixture->channel, + reinterpret_cast(TSI_TEST_UNUSED_BYTES), + strlen(TSI_TEST_UNUSED_BYTES), args->is_client); + if (fixture->client_result != nullptr && + fixture->server_result == nullptr) { + fixture->has_client_finished_first = true; + } + } +} + +static void receive_bytes_from_peer(tsi_test_channel* test_channel, + unsigned char** buf, size_t* buf_size, + bool is_client) { + GPR_ASSERT(test_channel != nullptr); + GPR_ASSERT(*buf != nullptr); + GPR_ASSERT(buf_size != nullptr); + uint8_t* channel = + is_client ? test_channel->client_channel : test_channel->server_channel; + GPR_ASSERT(channel != nullptr); + size_t* bytes_read = is_client + ? &test_channel->bytes_read_from_client_channel + : &test_channel->bytes_read_from_server_channel; + size_t* bytes_written = is_client + ? &test_channel->bytes_written_to_client_channel + : &test_channel->bytes_written_to_server_channel; + GPR_ASSERT(bytes_read != nullptr); + GPR_ASSERT(bytes_written != nullptr); + size_t to_read = *buf_size < *bytes_written - *bytes_read + ? *buf_size + : *bytes_written - *bytes_read; + /* Read data from channel. 
*/ + memcpy(*buf, channel + *bytes_read, to_read); + *buf_size = to_read; + *bytes_read += to_read; +} + +void tsi_test_frame_protector_send_message_to_peer( + tsi_test_frame_protector_config* config, tsi_test_channel* channel, + tsi_frame_protector* protector, bool is_client) { + /* Initialization. */ + GPR_ASSERT(config != nullptr); + GPR_ASSERT(channel != nullptr); + GPR_ASSERT(protector != nullptr); + unsigned char* protected_buffer = + static_cast(gpr_zalloc(config->protected_buffer_size)); + size_t message_size = + is_client ? config->client_message_size : config->server_message_size; + uint8_t* message = + is_client ? config->client_message : config->server_message; + GPR_ASSERT(message != nullptr); + const unsigned char* message_bytes = + reinterpret_cast(message); + tsi_result result = TSI_OK; + /* Do protect and send protected data to peer. */ + while (message_size > 0 && result == TSI_OK) { + size_t protected_buffer_size_to_send = config->protected_buffer_size; + size_t processed_message_size = message_size; + /* Do protect. */ + result = tsi_frame_protector_protect( + protector, message_bytes, &processed_message_size, protected_buffer, + &protected_buffer_size_to_send); + GPR_ASSERT(result == TSI_OK); + /* Send protected data to peer. */ + send_bytes_to_peer(channel, protected_buffer, protected_buffer_size_to_send, + is_client); + message_bytes += processed_message_size; + message_size -= processed_message_size; + /* Flush if we're done. */ + if (message_size == 0) { + size_t still_pending_size; + do { + protected_buffer_size_to_send = config->protected_buffer_size; + result = tsi_frame_protector_protect_flush( + protector, protected_buffer, &protected_buffer_size_to_send, + &still_pending_size); + GPR_ASSERT(result == TSI_OK); + send_bytes_to_peer(channel, protected_buffer, + protected_buffer_size_to_send, is_client); + } while (still_pending_size > 0 && result == TSI_OK); + GPR_ASSERT(result == TSI_OK); + } + } + GPR_ASSERT(result == TSI_OK); + gpr_free(protected_buffer); +} + +void tsi_test_frame_protector_receive_message_from_peer( + tsi_test_frame_protector_config* config, tsi_test_channel* channel, + tsi_frame_protector* protector, unsigned char* message, + size_t* bytes_received, bool is_client) { + /* Initialization. */ + GPR_ASSERT(config != nullptr); + GPR_ASSERT(channel != nullptr); + GPR_ASSERT(protector != nullptr); + GPR_ASSERT(message != nullptr); + GPR_ASSERT(bytes_received != nullptr); + size_t read_offset = 0; + size_t message_offset = 0; + size_t read_from_peer_size = 0; + tsi_result result = TSI_OK; + bool done = false; + unsigned char* read_buffer = static_cast( + gpr_zalloc(config->read_buffer_allocated_size)); + unsigned char* message_buffer = static_cast( + gpr_zalloc(config->message_buffer_allocated_size)); + /* Do unprotect on data received from peer. */ + while (!done && result == TSI_OK) { + /* Receive data from peer. */ + if (read_from_peer_size == 0) { + read_from_peer_size = config->read_buffer_allocated_size; + receive_bytes_from_peer(channel, &read_buffer, &read_from_peer_size, + is_client); + read_offset = 0; + } + if (read_from_peer_size == 0) { + done = true; + } + /* Do unprotect. 
*/ + size_t message_buffer_size; + do { + message_buffer_size = config->message_buffer_allocated_size; + size_t processed_size = read_from_peer_size; + result = tsi_frame_protector_unprotect( + protector, read_buffer + read_offset, &processed_size, message_buffer, + &message_buffer_size); + GPR_ASSERT(result == TSI_OK); + if (message_buffer_size > 0) { + memcpy(message + message_offset, message_buffer, message_buffer_size); + message_offset += message_buffer_size; + } + read_offset += processed_size; + read_from_peer_size -= processed_size; + } while ((read_from_peer_size > 0 || message_buffer_size > 0) && + result == TSI_OK); + GPR_ASSERT(result == TSI_OK); + } + GPR_ASSERT(result == TSI_OK); + *bytes_received = message_offset; + gpr_free(read_buffer); + gpr_free(message_buffer); +} + +grpc_error_handle on_handshake_next_done( + tsi_result result, void* user_data, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) { + handshaker_args* args = static_cast(user_data); + GPR_ASSERT(args != nullptr); + GPR_ASSERT(args->fixture != nullptr); + tsi_test_fixture* fixture = args->fixture; + grpc_error_handle error = GRPC_ERROR_NONE; + /* Read more data if we need to. */ + if (result == TSI_INCOMPLETE_DATA) { + GPR_ASSERT(bytes_to_send_size == 0); + notification_signal(fixture); + return error; + } + if (result != TSI_OK) { + notification_signal(fixture); + return grpc_set_tsi_error_result( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake failed"), result); + } + /* Update handshaker result. */ + if (handshaker_result != nullptr) { + tsi_handshaker_result** result_to_write = + args->is_client ? &fixture->client_result : &fixture->server_result; + GPR_ASSERT(*result_to_write == nullptr); + *result_to_write = handshaker_result; + } + /* Send data to peer, if needed. */ + if (bytes_to_send_size > 0) { + send_bytes_to_peer(fixture->channel, bytes_to_send, bytes_to_send_size, + args->is_client); + args->transferred_data = true; + } + if (handshaker_result != nullptr) { + maybe_append_unused_bytes(args); + } + notification_signal(fixture); + return error; +} + +static void on_handshake_next_done_wrapper( + tsi_result result, void* user_data, const unsigned char* bytes_to_send, + size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) { + handshaker_args* args = static_cast(user_data); + args->error = on_handshake_next_done(result, user_data, bytes_to_send, + bytes_to_send_size, handshaker_result); +} + +static bool is_handshake_finished_properly(handshaker_args* args) { + GPR_ASSERT(args != nullptr); + GPR_ASSERT(args->fixture != nullptr); + tsi_test_fixture* fixture = args->fixture; + return (args->is_client && fixture->client_result != nullptr) || + (!args->is_client && fixture->server_result != nullptr); +} + +static void do_handshaker_next(handshaker_args* args) { + /* Initialization. */ + GPR_ASSERT(args != nullptr); + GPR_ASSERT(args->fixture != nullptr); + tsi_test_fixture* fixture = args->fixture; + tsi_handshaker* handshaker = + args->is_client ? fixture->client_handshaker : fixture->server_handshaker; + if (is_handshake_finished_properly(args)) { + return; + } + tsi_handshaker_result* handshaker_result = nullptr; + unsigned char* bytes_to_send = nullptr; + size_t bytes_to_send_size = 0; + tsi_result result = TSI_OK; + /* Receive data from peer, if available. 
*/ + do { + size_t buf_size = args->handshake_buffer_size; + receive_bytes_from_peer(fixture->channel, &args->handshake_buffer, + &buf_size, args->is_client); + if (buf_size > 0) { + args->transferred_data = true; + } + /* Peform handshaker next. */ + result = tsi_handshaker_next( + handshaker, args->handshake_buffer, buf_size, + const_cast(&bytes_to_send), &bytes_to_send_size, + &handshaker_result, &on_handshake_next_done_wrapper, args); + if (result != TSI_ASYNC) { + args->error = on_handshake_next_done( + result, args, bytes_to_send, bytes_to_send_size, handshaker_result); + if (args->error != GRPC_ERROR_NONE) { + return; + } + } + } while (result == TSI_INCOMPLETE_DATA); + notification_wait(fixture); +} + +void tsi_test_do_handshake(tsi_test_fixture* fixture) { + /* Initializaiton. */ + setup_handshakers(fixture); + handshaker_args* client_args = + handshaker_args_create(fixture, true /* is_client */); + handshaker_args* server_args = + handshaker_args_create(fixture, false /* is_client */); + /* Do handshake. */ + do { + client_args->transferred_data = false; + server_args->transferred_data = false; + do_handshaker_next(client_args); + if (client_args->error != GRPC_ERROR_NONE) { + break; + } + do_handshaker_next(server_args); + if (server_args->error != GRPC_ERROR_NONE) { + break; + } + GPR_ASSERT(client_args->transferred_data || server_args->transferred_data); + } while (fixture->client_result == nullptr || + fixture->server_result == nullptr); + /* Verify handshake results. */ + check_handshake_results(fixture); + /* Cleanup. */ + handshaker_args_destroy(client_args); + handshaker_args_destroy(server_args); +} + +static void tsi_test_do_ping_pong(tsi_test_frame_protector_config* config, + tsi_test_channel* channel, + tsi_frame_protector* client_frame_protector, + tsi_frame_protector* server_frame_protector) { + GPR_ASSERT(config != nullptr); + GPR_ASSERT(channel != nullptr); + GPR_ASSERT(client_frame_protector != nullptr); + GPR_ASSERT(server_frame_protector != nullptr); + /* Client sends a message to server. */ + tsi_test_frame_protector_send_message_to_peer( + config, channel, client_frame_protector, true /* is_client */); + unsigned char* server_received_message = + static_cast(gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE)); + size_t server_received_message_size = 0; + tsi_test_frame_protector_receive_message_from_peer( + config, channel, server_frame_protector, server_received_message, + &server_received_message_size, false /* is_client */); + GPR_ASSERT(config->client_message_size == server_received_message_size); + GPR_ASSERT(memcmp(config->client_message, server_received_message, + server_received_message_size) == 0); + /* Server sends a message to client. 
*/ + tsi_test_frame_protector_send_message_to_peer( + config, channel, server_frame_protector, false /* is_client */); + unsigned char* client_received_message = + static_cast(gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE)); + size_t client_received_message_size = 0; + tsi_test_frame_protector_receive_message_from_peer( + config, channel, client_frame_protector, client_received_message, + &client_received_message_size, true /* is_client */); + GPR_ASSERT(config->server_message_size == client_received_message_size); + GPR_ASSERT(memcmp(config->server_message, client_received_message, + client_received_message_size) == 0); + gpr_free(server_received_message); + gpr_free(client_received_message); +} + +void tsi_test_frame_protector_do_round_trip_no_handshake( + tsi_test_frame_protector_fixture* fixture) { + GPR_ASSERT(fixture != nullptr); + tsi_test_do_ping_pong(fixture->config, fixture->channel, + fixture->client_frame_protector, + fixture->server_frame_protector); +} + +void tsi_test_do_round_trip(tsi_test_fixture* fixture) { + /* Initialization. */ + GPR_ASSERT(fixture != nullptr); + GPR_ASSERT(fixture->config != nullptr); + tsi_test_frame_protector_config* config = fixture->config; + tsi_frame_protector* client_frame_protector = nullptr; + tsi_frame_protector* server_frame_protector = nullptr; + /* Perform handshake. */ + tsi_test_do_handshake(fixture); + /* Create frame protectors.*/ + size_t client_max_output_protected_frame_size = + config->client_max_output_protected_frame_size; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + fixture->client_result, + client_max_output_protected_frame_size == 0 + ? nullptr + : &client_max_output_protected_frame_size, + &client_frame_protector) == TSI_OK); + size_t server_max_output_protected_frame_size = + config->server_max_output_protected_frame_size; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + fixture->server_result, + server_max_output_protected_frame_size == 0 + ? nullptr + : &server_max_output_protected_frame_size, + &server_frame_protector) == TSI_OK); + tsi_test_do_ping_pong(config, fixture->channel, client_frame_protector, + server_frame_protector); + /* Destroy server and client frame protectors. */ + tsi_frame_protector_destroy(client_frame_protector); + tsi_frame_protector_destroy(server_frame_protector); +} + +static unsigned char* generate_random_message(size_t size) { + size_t i; + unsigned char chars[] = "abcdefghijklmnopqrstuvwxyz1234567890"; + unsigned char* output = + static_cast(gpr_zalloc(sizeof(unsigned char) * size)); + for (i = 0; i < size - 1; ++i) { + output[i] = chars[rand() % static_cast(sizeof(chars) - 1)]; + } + return output; +} + +tsi_test_frame_protector_config* tsi_test_frame_protector_config_create( + bool use_default_read_buffer_allocated_size, + bool use_default_message_buffer_allocated_size, + bool use_default_protected_buffer_size, bool use_default_client_message, + bool use_default_server_message, + bool use_default_client_max_output_protected_frame_size, + bool use_default_server_max_output_protected_frame_size) { + tsi_test_frame_protector_config* config = + static_cast( + gpr_zalloc(sizeof(*config))); + /* Set the value for read_buffer_allocated_size. */ + config->read_buffer_allocated_size = + use_default_read_buffer_allocated_size + ? TSI_TEST_DEFAULT_BUFFER_SIZE + : TSI_TEST_SMALL_READ_BUFFER_ALLOCATED_SIZE; + /* Set the value for message_buffer_allocated_size. */ + config->message_buffer_allocated_size = + use_default_message_buffer_allocated_size + ? 
TSI_TEST_DEFAULT_BUFFER_SIZE + : TSI_TEST_SMALL_MESSAGE_BUFFER_ALLOCATED_SIZE; + /* Set the value for protected_buffer_size. */ + config->protected_buffer_size = use_default_protected_buffer_size + ? TSI_TEST_DEFAULT_PROTECTED_BUFFER_SIZE + : TSI_TEST_SMALL_PROTECTED_BUFFER_SIZE; + /* Set the value for client message. */ + config->client_message_size = use_default_client_message + ? TSI_TEST_BIG_MESSAGE_SIZE + : TSI_TEST_SMALL_MESSAGE_SIZE; + config->client_message = + use_default_client_message + ? generate_random_message(TSI_TEST_BIG_MESSAGE_SIZE) + : generate_random_message(TSI_TEST_SMALL_MESSAGE_SIZE); + /* Set the value for server message. */ + config->server_message_size = use_default_server_message + ? TSI_TEST_BIG_MESSAGE_SIZE + : TSI_TEST_SMALL_MESSAGE_SIZE; + config->server_message = + use_default_server_message + ? generate_random_message(TSI_TEST_BIG_MESSAGE_SIZE) + : generate_random_message(TSI_TEST_SMALL_MESSAGE_SIZE); + /* Set the value for client max_output_protected_frame_size. + If it is 0, we pass NULL to tsi_handshaker_result_create_frame_protector(), + which then uses default protected frame size for it. */ + config->client_max_output_protected_frame_size = + use_default_client_max_output_protected_frame_size + ? 0 + : TSI_TEST_SMALL_CLIENT_MAX_OUTPUT_PROTECTED_FRAME_SIZE; + /* Set the value for server max_output_protected_frame_size. + If it is 0, we pass NULL to tsi_handshaker_result_create_frame_protector(), + which then uses default protected frame size for it. */ + config->server_max_output_protected_frame_size = + use_default_server_max_output_protected_frame_size + ? 0 + : TSI_TEST_SMALL_SERVER_MAX_OUTPUT_PROTECTED_FRAME_SIZE; + return config; +} + +void tsi_test_frame_protector_config_set_buffer_size( + tsi_test_frame_protector_config* config, size_t read_buffer_allocated_size, + size_t message_buffer_allocated_size, size_t protected_buffer_size, + size_t client_max_output_protected_frame_size, + size_t server_max_output_protected_frame_size) { + GPR_ASSERT(config != nullptr); + config->read_buffer_allocated_size = read_buffer_allocated_size; + config->message_buffer_allocated_size = message_buffer_allocated_size; + config->protected_buffer_size = protected_buffer_size; + config->client_max_output_protected_frame_size = + client_max_output_protected_frame_size; + config->server_max_output_protected_frame_size = + server_max_output_protected_frame_size; +} + +void tsi_test_frame_protector_config_destroy( + tsi_test_frame_protector_config* config) { + if (config == nullptr) { + return; + } + gpr_free(config->client_message); + gpr_free(config->server_message); + gpr_free(config); +} + +static tsi_test_channel* tsi_test_channel_create() { + tsi_test_channel* channel = grpc_core::Zalloc(); + channel->client_channel = + static_cast(gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE)); + channel->server_channel = + static_cast(gpr_zalloc(TSI_TEST_DEFAULT_CHANNEL_SIZE)); + channel->bytes_written_to_client_channel = 0; + channel->bytes_written_to_server_channel = 0; + channel->bytes_read_from_client_channel = 0; + channel->bytes_read_from_server_channel = 0; + return channel; +} + +static void tsi_test_channel_destroy(tsi_test_channel* channel) { + if (channel == nullptr) { + return; + } + gpr_free(channel->client_channel); + gpr_free(channel->server_channel); + gpr_free(channel); +} + +void tsi_test_fixture_init(tsi_test_fixture* fixture) { + fixture->config = tsi_test_frame_protector_config_create( + true, true, true, true, true, true, true); + fixture->handshake_buffer_size = 
TSI_TEST_DEFAULT_BUFFER_SIZE; + fixture->channel = tsi_test_channel_create(); + fixture->test_unused_bytes = true; + fixture->has_client_finished_first = false; + gpr_mu_init(&fixture->mu); + gpr_cv_init(&fixture->cv); + fixture->notified = false; +} + +void tsi_test_fixture_destroy(tsi_test_fixture* fixture) { + if (fixture == nullptr) { + return; + } + tsi_test_frame_protector_config_destroy(fixture->config); + tsi_handshaker_destroy(fixture->client_handshaker); + tsi_handshaker_destroy(fixture->server_handshaker); + tsi_handshaker_result_destroy(fixture->client_result); + tsi_handshaker_result_destroy(fixture->server_result); + tsi_test_channel_destroy(fixture->channel); + GPR_ASSERT(fixture->vtable != nullptr); + GPR_ASSERT(fixture->vtable->destruct != nullptr); + fixture->vtable->destruct(fixture); + gpr_mu_destroy(&fixture->mu); + gpr_cv_destroy(&fixture->cv); + gpr_free(fixture); +} + +tsi_test_frame_protector_fixture* tsi_test_frame_protector_fixture_create() { + tsi_test_frame_protector_fixture* fixture = + static_cast( + gpr_zalloc(sizeof(*fixture))); + fixture->config = tsi_test_frame_protector_config_create( + true, true, true, true, true, true, true); + fixture->channel = tsi_test_channel_create(); + return fixture; +} + +void tsi_test_frame_protector_fixture_init( + tsi_test_frame_protector_fixture* fixture, + tsi_frame_protector* client_frame_protector, + tsi_frame_protector* server_frame_protector) { + GPR_ASSERT(fixture != nullptr); + fixture->client_frame_protector = client_frame_protector; + fixture->server_frame_protector = server_frame_protector; +} + +void tsi_test_frame_protector_fixture_destroy( + tsi_test_frame_protector_fixture* fixture) { + if (fixture == nullptr) { + return; + } + tsi_test_frame_protector_config_destroy(fixture->config); + tsi_test_channel_destroy(fixture->channel); + tsi_frame_protector_destroy(fixture->client_frame_protector); + tsi_frame_protector_destroy(fixture->server_frame_protector); + gpr_free(fixture); +} diff --git a/test/core/uri/uri_fuzzer_test.cc b/test/core/uri/uri_fuzzer_test.cc new file mode 100644 index 00000000..de6e8dfc --- /dev/null +++ b/test/core/uri/uri_fuzzer_test.cc @@ -0,0 +1,47 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/uri/uri_parser.h" + +bool squelch = true; +bool leak_check = true; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + char* s = static_cast(gpr_malloc(size + 1)); + memcpy(s, data, size); + s[size] = 0; + + grpc_init(); + + { + grpc_core::ExecCtx exec_ctx; + (void)grpc_core::URI::Parse(s); + gpr_free(s); + } + + grpc_shutdown(); + return 0; +} diff --git a/test/core/uri/uri_parser_test.cc b/test/core/uri/uri_parser_test.cc new file mode 100644 index 00000000..71b32b2b --- /dev/null +++ b/test/core/uri/uri_parser_test.cc @@ -0,0 +1,225 @@ +/* + * + * Copyright 2015 gRPC authors. 
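A minimal end-to-end sketch of the frame-protector-only fixture defined above. It assumes the fake TSI implementation's tsi_create_fake_frame_protector() helper and a test-lib header at test/core/tsi/transport_security_test_lib.h; neither appears in this diff, so adjust both to the actual tree. The fixture takes ownership of the two protectors, so only the fixture itself needs to be destroyed.

#include <grpc/grpc.h>

#include "src/core/tsi/fake_transport_security.h"
#include "test/core/tsi/transport_security_test_lib.h"

int main() {
  grpc_init();
  tsi_test_frame_protector_fixture* fixture =
      tsi_test_frame_protector_fixture_create();
  // nullptr asks the fake implementation for its default max frame size.
  tsi_frame_protector* client = tsi_create_fake_frame_protector(nullptr);
  tsi_frame_protector* server = tsi_create_fake_frame_protector(nullptr);
  // fixture_destroy() frees both protectors, the channel, and the randomly
  // generated test messages held by the config.
  tsi_test_frame_protector_fixture_init(fixture, client, server);
  tsi_test_frame_protector_do_round_trip_no_handshake(fixture);
  tsi_test_frame_protector_fixture_destroy(fixture);
  grpc_shutdown();
  return 0;
}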
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/uri/uri_parser.h" + +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" + +#include +#include + +#include "test/core/util/test_config.h" + +using ::testing::ContainerEq; +using ::testing::Contains; +using ::testing::ElementsAre; +using ::testing::Pair; + +static void TestSucceeds( + absl::string_view uri_text, absl::string_view scheme, + absl::string_view authority, absl::string_view path, + const std::map& query_param_map, + const std::vector& query_param_pairs, + absl::string_view fragment) { + absl::StatusOr uri = grpc_core::URI::Parse(uri_text); + ASSERT_TRUE(uri.ok()); + EXPECT_EQ(scheme, uri->scheme()); + EXPECT_EQ(authority, uri->authority()); + EXPECT_EQ(path, uri->path()); + EXPECT_THAT(uri->query_parameter_map(), ContainerEq(query_param_map)); + EXPECT_THAT(uri->query_parameter_pairs(), ContainerEq(query_param_pairs)); + EXPECT_EQ(fragment, uri->fragment()); +} + +static void TestFails(absl::string_view uri_text) { + absl::StatusOr uri = grpc_core::URI::Parse(uri_text); + ASSERT_FALSE(uri.ok()); +} + +TEST(URIParserTest, BasicExamplesAreParsedCorrectly) { + TestSucceeds("http://www.google.com", "http", "www.google.com", "", {}, {}, + ""); + TestSucceeds("dns:///foo", "dns", "", "/foo", {}, {}, ""); + TestSucceeds("http://www.google.com:90", "http", "www.google.com:90", "", {}, + {}, ""); + TestSucceeds("a192.4-df:foo.coom", "a192.4-df", "", "foo.coom", {}, {}, ""); + TestSucceeds("a+b:foo.coom", "a+b", "", "foo.coom", {}, {}, ""); + TestSucceeds("zookeeper://127.0.0.1:2181/foo/bar", "zookeeper", + "127.0.0.1:2181", "/foo/bar", {}, {}, ""); + TestSucceeds("dns:foo.com#fragment-all-the-things", "dns", "", "foo.com", {}, + {}, "fragment-all-the-things"); + TestSucceeds("http://localhost:8080/whatzit?mi_casa=su_casa", "http", + "localhost:8080", "/whatzit", {{"mi_casa", "su_casa"}}, + {{"mi_casa", "su_casa"}}, ""); + TestSucceeds("http://localhost:8080/whatzit?1=2#buckle/my/shoe", "http", + "localhost:8080", "/whatzit", {{"1", "2"}}, {{"1", "2"}}, + "buckle/my/shoe"); +} + +TEST(URIParserTest, UncommonValidExamplesAreParsedCorrectly) { + TestSucceeds("scheme:path//is/ok", "scheme", "", "path//is/ok", {}, {}, ""); + TestSucceeds("http:?legit", "http", "", "", {{"legit", ""}}, {{"legit", ""}}, + ""); + TestSucceeds("unix:#this-is-ok-too", "unix", "", "", {}, {}, + "this-is-ok-too"); + TestSucceeds("http:?legit#twice", "http", "", "", {{"legit", ""}}, + {{"legit", ""}}, "twice"); + TestSucceeds("fake:///", "fake", "", "/", {}, {}, ""); +} + +TEST(URIParserTest, VariousKeyValueAndNonKVQueryParamsAreParsedCorrectly) { + TestSucceeds("http://foo/path?a&b=B&c=&#frag", "http", "foo", "/path", + {{"c", ""}, {"a", ""}, {"b", "B"}}, + {{"a", ""}, {"b", "B"}, {"c", ""}}, "frag"); +} + +TEST(URIParserTest, ParserTreatsFirstEqualSignAsKVDelimiterInQueryString) { + TestSucceeds( + "http://localhost:8080/?too=many=equals&are=present=here#fragged", "http", + 
"localhost:8080", "/", {{"are", "present=here"}, {"too", "many=equals"}}, + {{"too", "many=equals"}, {"are", "present=here"}}, "fragged"); + TestSucceeds("http://auth/path?foo=bar=baz&foobar===", "http", "auth", + "/path", {{"foo", "bar=baz"}, {"foobar", "=="}}, + {{"foo", "bar=baz"}, {"foobar", "=="}}, ""); +} + +TEST(URIParserTest, + RepeatedQueryParamsAreSupportedInOrderedPairsButDeduplicatedInTheMap) { + absl::StatusOr uri = + grpc_core::URI::Parse("http://foo/path?a=2&a=1&a=3"); + ASSERT_TRUE(uri.ok()); + // The map stores the last found value. + ASSERT_THAT(uri->query_parameter_map(), ElementsAre(Pair("a", "3"))); + // Order matters for query parameter pairs + ASSERT_THAT(uri->query_parameter_pairs(), + ElementsAre(grpc_core::URI::QueryParam{"a", "2"}, + grpc_core::URI::QueryParam{"a", "1"}, + grpc_core::URI::QueryParam{"a", "3"})); +} + +TEST(URIParserTest, QueryParamMapRemainsValiditAfterMovingTheURI) { + grpc_core::URI uri_copy; + { + absl::StatusOr uri = + grpc_core::URI::Parse("http://foo/path?a=2&b=1&c=3"); + ASSERT_TRUE(uri.ok()); + uri_copy = std::move(*uri); + } + // ASSERT_EQ(uri_copy.query_parameter_map().find("a")->second, "2"); + ASSERT_THAT(uri_copy.query_parameter_map(), Contains(Pair("a", "2"))); +} + +TEST(URIParserTest, QueryParamMapRemainsValidAfterCopyingTheURI) { + // Since the query parameter map points to objects stored in the param pair + // vector, this test checks that the param map pointers remain valid after + // a copy. Ideally {a,m}san will catch this if there's a problem. + // testing copy operator=: + grpc_core::URI uri_copy; + { + absl::StatusOr del_uri = + grpc_core::URI::Parse("http://foo/path?a=2&b=1&c=3"); + ASSERT_TRUE(del_uri.ok()); + uri_copy = *del_uri; + } + ASSERT_THAT(uri_copy.query_parameter_map(), Contains(Pair("a", "2"))); + grpc_core::URI* del_uri2 = new grpc_core::URI(uri_copy); + grpc_core::URI uri_copy2(*del_uri2); + delete del_uri2; + ASSERT_THAT(uri_copy2.query_parameter_map(), Contains(Pair("a", "2"))); +} + +TEST(URIParserTest, AWSExternalAccountRegressionTest) { + TestSucceeds( + "https://foo.com:5555/v1/" + "token-exchange?subject_token=eyJhbGciO&subject_token_type=urn:ietf:" + "params:oauth:token-type:id_token", + "https", "foo.com:5555", "/v1/token-exchange", + {{"subject_token", "eyJhbGciO"}, + {"subject_token_type", "urn:ietf:params:oauth:token-type:id_token"}}, + {{"subject_token", "eyJhbGciO"}, + {"subject_token_type", "urn:ietf:params:oauth:token-type:id_token"}}, + ""); +} + +TEST(URIParserTest, NonKeyValueQueryStringsWork) { + TestSucceeds("http://www.google.com?yay-i'm-using-queries", "http", + "www.google.com", "", {{"yay-i'm-using-queries", ""}}, + {{"yay-i'm-using-queries", ""}}, ""); +} + +TEST(URIParserTest, IPV6StringsAreParsedCorrectly) { + TestSucceeds("ipv6:[2001:db8::1%252]:12345", "ipv6", "", + "[2001:db8::1%2]:12345", {}, {}, ""); + TestSucceeds("ipv6:[fe80::90%eth1.sky1]:6010", "ipv6", "", + "[fe80::90%eth1.sky1]:6010", {}, {}, ""); +} + +TEST(URIParserTest, PreviouslyReservedCharactersInUnrelatedURIPartsAreIgnored) { + // The '?' and '/' characters are not reserved delimiter characters in the + // fragment. 
See http://go/rfc/3986#section-3.5 + TestSucceeds("http://foo?bar#lol?", "http", "foo", "", {{"bar", ""}}, + {{"bar", ""}}, "lol?"); + TestSucceeds("http://foo?bar#lol?/", "http", "foo", "", {{"bar", ""}}, + {{"bar", ""}}, "lol?/"); +} + +TEST(URIParserTest, EncodedCharactersInQueryStringAreParsedCorrectly) { + TestSucceeds("https://www.google.com/?a=1%26b%3D2&c=3", "https", + "www.google.com", "/", {{"c", "3"}, {"a", "1&b=2"}}, + {{"a", "1&b=2"}, {"c", "3"}}, ""); +} + +TEST(URIParserTest, InvalidPercentEncodingsArePassedThrough) { + TestSucceeds("x:y?%xx", "x", "", "y", {{"%xx", ""}}, {{"%xx", ""}}, ""); + TestSucceeds("http:?dangling-pct-%0", "http", "", "", + {{"dangling-pct-%0", ""}}, {{"dangling-pct-%0", ""}}, ""); +} + +TEST(URIParserTest, NullCharactersInURIStringAreSupported) { + // Artificial examples to show that embedded nulls are supported. + TestSucceeds(std::string("unix-abstract:\0should-be-ok", 27), "unix-abstract", + "", std::string("\0should-be-ok", 13), {}, {}, ""); +} + +TEST(URIParserTest, EncodedNullsInURIStringAreSupported) { + TestSucceeds("unix-abstract:%00x", "unix-abstract", "", std::string("\0x", 2), + {}, {}, ""); +} + +TEST(URIParserTest, InvalidURIsResultInFailureStatuses) { + TestFails("xyz"); + TestFails("http://foo?[bar]"); + TestFails("http://foo?x[bar]"); + TestFails("http://foo?bar#lol#"); + TestFails(""); + TestFails(":no_scheme"); + TestFails("0invalid_scheme:must_start/with?alpha"); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/util/cmdline.cc b/test/core/util/cmdline.cc new file mode 100644 index 00000000..182370a2 --- /dev/null +++ b/test/core/util/cmdline.cc @@ -0,0 +1,321 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
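Tying the URI parser tests above together, a minimal standalone sketch of the parse API they exercise: repeated keys keep every occurrence, in order, in query_parameter_pairs(), while query_parameter_map() collapses them to the last value. The QueryParam members are assumed to be named key and value, as the expectations above suggest.

#include <iostream>
#include <string>

#include <grpc/grpc.h>

#include "src/core/lib/uri/uri_parser.h"

int main() {
  grpc_init();
  absl::StatusOr<grpc_core::URI> uri =
      grpc_core::URI::Parse("http://host:8080/path?a=1&a=2&b=3#frag");
  if (!uri.ok()) {
    std::cerr << uri.status().ToString() << "\n";
    grpc_shutdown();
    return 1;
  }
  std::cout << uri->scheme() << " | " << uri->authority() << " | "
            << uri->path() << " | #" << uri->fragment() << "\n";
  // Deduplicated view, last value wins: prints a=2 and b=3.
  for (const auto& kv : uri->query_parameter_map()) {
    std::cout << "map  " << std::string(kv.first) << "="
              << std::string(kv.second) << "\n";
  }
  // Ordered view, duplicates preserved: prints a=1, a=2, b=3.
  for (const auto& qp : uri->query_parameter_pairs()) {
    std::cout << "pair " << qp.key << "=" << qp.value << "\n";
  }
  grpc_shutdown();
  return 0;
}

A related note on the embedded-NUL cases above: they pass an explicit length to std::string because constructing one from a bare character pointer stops at the first NUL byte; std::string("unix-abstract:\0should-be-ok", 27) keeps all 27 bytes, while the plain literal form would keep only the leading 14.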
+ * + */ + +#include "test/core/util/cmdline.h" + +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" + +typedef enum { ARGTYPE_INT, ARGTYPE_BOOL, ARGTYPE_STRING } argtype; + +typedef struct arg { + const char* name; + const char* help; + argtype type; + void* value; + struct arg* next; +} arg; + +struct gpr_cmdline { + const char* description; + arg* args; + const char* argv0; + + const char* extra_arg_name; + const char* extra_arg_help; + void (*extra_arg)(void* user_data, const char* arg); + void* extra_arg_user_data; + + int (*state)(gpr_cmdline* cl, char* arg); + arg* cur_arg; + + int survive_failure; +}; + +static int normal_state(gpr_cmdline* cl, char* str); + +gpr_cmdline* gpr_cmdline_create(const char* description) { + gpr_cmdline* cl = grpc_core::Zalloc(); + + cl->description = description; + cl->state = normal_state; + + return cl; +} + +void gpr_cmdline_set_survive_failure(gpr_cmdline* cl) { + cl->survive_failure = 1; +} + +void gpr_cmdline_destroy(gpr_cmdline* cl) { + while (cl->args) { + arg* a = cl->args; + cl->args = a->next; + gpr_free(a); + } + gpr_free(cl); +} + +static void add_arg(gpr_cmdline* cl, const char* name, const char* help, + argtype type, void* value) { + arg* a; + + for (a = cl->args; a; a = a->next) { + GPR_ASSERT(0 != strcmp(a->name, name)); + } + + a = static_cast(gpr_zalloc(sizeof(arg))); + a->name = name; + a->help = help; + a->type = type; + a->value = value; + a->next = cl->args; + cl->args = a; +} + +void gpr_cmdline_add_int(gpr_cmdline* cl, const char* name, const char* help, + int* value) { + add_arg(cl, name, help, ARGTYPE_INT, value); +} + +void gpr_cmdline_add_flag(gpr_cmdline* cl, const char* name, const char* help, + int* value) { + add_arg(cl, name, help, ARGTYPE_BOOL, value); +} + +void gpr_cmdline_add_string(gpr_cmdline* cl, const char* name, const char* help, + const char** value) { + add_arg(cl, name, help, ARGTYPE_STRING, value); +} + +void gpr_cmdline_on_extra_arg( + gpr_cmdline* cl, const char* name, const char* help, + void (*on_extra_arg)(void* user_data, const char* arg), void* user_data) { + GPR_ASSERT(!cl->extra_arg); + GPR_ASSERT(on_extra_arg); + + cl->extra_arg = on_extra_arg; + cl->extra_arg_user_data = user_data; + cl->extra_arg_name = name; + cl->extra_arg_help = help; +} + +/* recursively descend argument list, adding the last element + to s first - so that arguments are added in the order they were + added to the list by api calls */ +static void add_args_to_usage(arg* a, std::vector* s) { + if (a == nullptr) return; + add_args_to_usage(a->next, s); + switch (a->type) { + case ARGTYPE_BOOL: + s->push_back(absl::StrFormat(" [--%s|--no-%s]", a->name, a->name)); + break; + case ARGTYPE_STRING: + s->push_back(absl::StrFormat(" [--%s=string]", a->name)); + break; + case ARGTYPE_INT: + s->push_back(absl::StrFormat(" [--%s=int]", a->name)); + break; + } +} + +std::string gpr_cmdline_usage_string(gpr_cmdline* cl, const char* argv0) { + const char* name = strrchr(argv0, '/'); + if (name != nullptr) { + name++; + } else { + name = argv0; + } + + std::vector s; + s.push_back(absl::StrCat("Usage: ", name)); + add_args_to_usage(cl->args, &s); + if (cl->extra_arg) { + s.push_back(absl::StrFormat(" [%s...]", cl->extra_arg_name)); + } + s.push_back("\n"); + return absl::StrJoin(s, ""); +} + +static int 
print_usage_and_die(gpr_cmdline* cl) { + fprintf(stderr, "%s", gpr_cmdline_usage_string(cl, cl->argv0).c_str()); + if (!cl->survive_failure) { + exit(1); + } + return 0; +} + +static int extra_state(gpr_cmdline* cl, char* str) { + if (!cl->extra_arg) { + return print_usage_and_die(cl); + } + cl->extra_arg(cl->extra_arg_user_data, str); + return 1; +} + +static arg* find_arg(gpr_cmdline* cl, char* name) { + arg* a; + + for (a = cl->args; a; a = a->next) { + if (0 == strcmp(a->name, name)) { + break; + } + } + + if (!a) { + fprintf(stderr, "Unknown argument: %s\n", name); + return nullptr; + } + + return a; +} + +static int value_state(gpr_cmdline* cl, char* str) { + long intval; + char* end; + + GPR_ASSERT(cl->cur_arg); + + switch (cl->cur_arg->type) { + case ARGTYPE_INT: + intval = strtol(str, &end, 0); + if (*end || intval < INT_MIN || intval > INT_MAX) { + fprintf(stderr, "expected integer, got '%s' for %s\n", str, + cl->cur_arg->name); + return print_usage_and_die(cl); + } + *static_cast(cl->cur_arg->value) = static_cast(intval); + break; + case ARGTYPE_BOOL: + if (0 == strcmp(str, "1") || 0 == strcmp(str, "true")) { + *static_cast(cl->cur_arg->value) = 1; + } else if (0 == strcmp(str, "0") || 0 == strcmp(str, "false")) { + *static_cast(cl->cur_arg->value) = 0; + } else { + fprintf(stderr, "expected boolean, got '%s' for %s\n", str, + cl->cur_arg->name); + return print_usage_and_die(cl); + } + break; + case ARGTYPE_STRING: + *static_cast(cl->cur_arg->value) = str; + break; + } + + cl->state = normal_state; + return 1; +} + +static int normal_state(gpr_cmdline* cl, char* str) { + char* eq = nullptr; + char* tmp = nullptr; + char* arg_name = nullptr; + int r = 1; + + if (0 == strcmp(str, "-help") || 0 == strcmp(str, "--help") || + 0 == strcmp(str, "-h")) { + return print_usage_and_die(cl); + } + + cl->cur_arg = nullptr; + + if (str[0] == '-') { + if (str[1] == '-') { + if (str[2] == 0) { + /* handle '--' to move to just extra args */ + cl->state = extra_state; + return 1; + } + str += 2; + } else { + str += 1; + } + /* first byte of str is now past the leading '-' or '--' */ + if (str[0] == 'n' && str[1] == 'o' && str[2] == '-') { + /* str is of the form '--no-foo' - it's a flag disable */ + str += 3; + cl->cur_arg = find_arg(cl, str); + if (cl->cur_arg == nullptr) { + return print_usage_and_die(cl); + } + if (cl->cur_arg->type != ARGTYPE_BOOL) { + fprintf(stderr, "%s is not a flag argument\n", str); + return print_usage_and_die(cl); + } + *static_cast(cl->cur_arg->value) = 0; + return 1; /* early out */ + } + eq = strchr(str, '='); + if (eq != nullptr) { + /* copy the string into a temp buffer and extract the name */ + tmp = arg_name = + static_cast(gpr_malloc(static_cast(eq - str + 1))); + memcpy(arg_name, str, static_cast(eq - str)); + arg_name[eq - str] = 0; + } else { + arg_name = str; + } + cl->cur_arg = find_arg(cl, arg_name); + if (cl->cur_arg == nullptr) { + return print_usage_and_die(cl); + } + if (eq != nullptr) { + /* str was of the type --foo=value, parse the value */ + r = value_state(cl, eq + 1); + } else if (cl->cur_arg->type != ARGTYPE_BOOL) { + /* flag types don't have a '--foo value' variant, other types do */ + cl->state = value_state; + } else { + /* flag parameter: just set the value */ + *static_cast(cl->cur_arg->value) = 1; + } + } else { + r = extra_state(cl, str); + } + + gpr_free(tmp); + return r; +} + +int gpr_cmdline_parse(gpr_cmdline* cl, int argc, char** argv) { + int i; + + GPR_ASSERT(argc >= 1); + cl->argv0 = argv[0]; + + for (i = 1; i < argc; i++) { + if 
(!cl->state(cl, argv[i])) { + return 0; + } + } + return 1; +} diff --git a/test/core/util/cmdline_test.cc b/test/core/util/cmdline_test.cc new file mode 100644 index 00000000..3a1e6e71 --- /dev/null +++ b/test/core/util/cmdline_test.cc @@ -0,0 +1,494 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/cmdline.h" + +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "test/core/util/test_config.h" + +#define LOG_TEST() gpr_log(GPR_INFO, "test at %s:%d", __FILE__, __LINE__) + +static void test_simple_int(void) { + int x = 1; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("-foo"), + const_cast("3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_int(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 1); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 3); + gpr_cmdline_destroy(cl); +} + +static void test_eq_int(void) { + int x = 1; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("-foo=3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_int(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 1); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 3); + gpr_cmdline_destroy(cl); +} + +static void test_2dash_int(void) { + int x = 1; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo"), + const_cast("3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_int(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 1); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 3); + gpr_cmdline_destroy(cl); +} + +static void test_2dash_eq_int(void) { + int x = 1; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo=3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_int(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 1); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 3); + gpr_cmdline_destroy(cl); +} + +static void test_simple_string(void) { + const char* x = nullptr; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("-foo"), + const_cast("3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_string(cl, "foo", nullptr, &x); + GPR_ASSERT(x == nullptr); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(0 == strcmp(x, "3")); + gpr_cmdline_destroy(cl); +} + +static void test_eq_string(void) { + const char* x = nullptr; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("-foo=3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_string(cl, "foo", nullptr, &x); + GPR_ASSERT(x == nullptr); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(0 == strcmp(x, "3")); + gpr_cmdline_destroy(cl); +} + +static void test_2dash_string(void) { + const char* x = nullptr; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo"), + 
const_cast("3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_string(cl, "foo", nullptr, &x); + GPR_ASSERT(x == nullptr); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(0 == strcmp(x, "3")); + gpr_cmdline_destroy(cl); +} + +static void test_2dash_eq_string(void) { + const char* x = nullptr; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo=3")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_string(cl, "foo", nullptr, &x); + GPR_ASSERT(x == nullptr); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(0 == strcmp(x, "3")); + gpr_cmdline_destroy(cl); +} + +static void test_flag_on(void) { + int x = 2; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_flag(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 2); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 1); + gpr_cmdline_destroy(cl); +} + +static void test_flag_no(void) { + int x = 2; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--no-foo")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_flag(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 2); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 0); + gpr_cmdline_destroy(cl); +} + +static void test_flag_val_1(void) { + int x = 2; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo=1")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_flag(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 2); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 1); + gpr_cmdline_destroy(cl); +} + +static void test_flag_val_0(void) { + int x = 2; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo=0")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_flag(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 2); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 0); + gpr_cmdline_destroy(cl); +} + +static void test_flag_val_true(void) { + int x = 2; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), const_cast("--foo=true")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_flag(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 2); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 1); + gpr_cmdline_destroy(cl); +} + +static void test_flag_val_false(void) { + int x = 2; + gpr_cmdline* cl; + char* args[] = {const_cast(__FILE__), + const_cast("--foo=false")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_flag(cl, "foo", nullptr, &x); + GPR_ASSERT(x == 2); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 0); + gpr_cmdline_destroy(cl); +} + +static void test_many(void) { + const char* str = nullptr; + int x = 0; + int flag = 2; + gpr_cmdline* cl; + + char* args[] = {const_cast(__FILE__), const_cast("--str"), + const_cast("hello"), const_cast("-x=4"), + const_cast("-no-flag")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_string(cl, "str", nullptr, &str); + gpr_cmdline_add_int(cl, "x", nullptr, &x); + gpr_cmdline_add_flag(cl, "flag", nullptr, &flag); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(x == 4); + GPR_ASSERT(0 == strcmp(str, "hello")); + GPR_ASSERT(flag == 0); + gpr_cmdline_destroy(cl); +} + +static void extra_arg_cb(void* user_data, const char* arg) { + int* 
count = static_cast(user_data); + GPR_ASSERT(arg != nullptr); + GPR_ASSERT(strlen(arg) == 1); + GPR_ASSERT(arg[0] == 'a' + *count); + ++*count; +} + +static void test_extra(void) { + gpr_cmdline* cl; + int count = 0; + char* args[] = {const_cast(__FILE__), const_cast("a"), + const_cast("b"), const_cast("c")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + &count); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(count == 3); + gpr_cmdline_destroy(cl); +} + +static void test_extra_dashdash(void) { + gpr_cmdline* cl; + int count = 0; + char* args[] = {const_cast(__FILE__), const_cast("--"), + const_cast("a"), const_cast("b"), + const_cast("c")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + &count); + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(args), args); + GPR_ASSERT(count == 3); + gpr_cmdline_destroy(cl); +} + +static void test_usage(void) { + gpr_cmdline* cl; + + const char* str = nullptr; + int x = 0; + int flag = 2; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_add_string(cl, "str", nullptr, &str); + gpr_cmdline_add_int(cl, "x", nullptr, &x); + gpr_cmdline_add_flag(cl, "flag", nullptr, &flag); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + nullptr); + + std::string usage = gpr_cmdline_usage_string(cl, "test"); + GPR_ASSERT(usage == + "Usage: test [--str=string] [--x=int] " + "[--flag|--no-flag] [file...]\n"); + + usage = gpr_cmdline_usage_string(cl, "/foo/test"); + GPR_ASSERT(usage == + "Usage: test [--str=string] [--x=int] " + "[--flag|--no-flag] [file...]\n"); + + gpr_cmdline_destroy(cl); +} + +static void test_help(void) { + gpr_cmdline* cl; + + const char* str = nullptr; + int x = 0; + int flag = 2; + + char* help[] = {const_cast(__FILE__), const_cast("-h")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_set_survive_failure(cl); + gpr_cmdline_add_string(cl, "str", nullptr, &str); + gpr_cmdline_add_int(cl, "x", nullptr, &x); + gpr_cmdline_add_flag(cl, "flag", nullptr, &flag); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + nullptr); + + GPR_ASSERT(0 == gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(help), help)); + + gpr_cmdline_destroy(cl); +} + +static void test_badargs1(void) { + gpr_cmdline* cl; + + const char* str = nullptr; + int x = 0; + int flag = 2; + + char* bad_arg_name[] = {const_cast(__FILE__), + const_cast("--y")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_set_survive_failure(cl); + gpr_cmdline_add_string(cl, "str", nullptr, &str); + gpr_cmdline_add_int(cl, "x", nullptr, &x); + gpr_cmdline_add_flag(cl, "flag", nullptr, &flag); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + nullptr); + + GPR_ASSERT(0 == + gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(bad_arg_name), bad_arg_name)); + + gpr_cmdline_destroy(cl); +} + +static void test_badargs2(void) { + gpr_cmdline* cl; + + const char* str = nullptr; + int x = 0; + int flag = 2; + + char* bad_int_value[] = {const_cast(__FILE__), + const_cast("--x"), + const_cast("henry")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_set_survive_failure(cl); + gpr_cmdline_add_string(cl, "str", nullptr, &str); + gpr_cmdline_add_int(cl, "x", nullptr, &x); + gpr_cmdline_add_flag(cl, "flag", nullptr, &flag); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + nullptr); 
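This bad-integer case (--x henry) is rejected inside value_state(), where strtol()'s end pointer catches trailing junk and the long result is range-checked before being narrowed to int. A standalone sketch of that idiom; the errno check is an extra addition for completeness, not something the original does.

#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>

// Reject empty input, trailing characters, and values that do not fit in int.
static bool parse_int_arg(const char* str, int* out) {
  char* end = nullptr;
  errno = 0;
  const long v = std::strtol(str, &end, 0);  // base 0: decimal, hex, or octal
  if (end == str || *end != '\0' || errno == ERANGE || v < INT_MIN ||
      v > INT_MAX) {
    return false;
  }
  *out = static_cast<int>(v);
  return true;
}

int main() {
  int x = 0;
  std::printf("\"42\"    ok=%d x=%d\n", parse_int_arg("42", &x), x);
  std::printf("\"0x10\"  ok=%d x=%d\n", parse_int_arg("0x10", &x), x);
  std::printf("\"henry\" ok=%d\n", parse_int_arg("henry", &x));
  return 0;
}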
+ + GPR_ASSERT( + 0 == gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(bad_int_value), bad_int_value)); + + gpr_cmdline_destroy(cl); +} + +static void test_badargs3(void) { + gpr_cmdline* cl; + + const char* str = nullptr; + int x = 0; + int flag = 2; + + char* bad_bool_value[] = {const_cast(__FILE__), + const_cast("--flag=henry")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_set_survive_failure(cl); + gpr_cmdline_add_string(cl, "str", nullptr, &str); + gpr_cmdline_add_int(cl, "x", nullptr, &x); + gpr_cmdline_add_flag(cl, "flag", nullptr, &flag); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + nullptr); + + GPR_ASSERT(0 == gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(bad_bool_value), + bad_bool_value)); + + gpr_cmdline_destroy(cl); +} + +static void test_badargs4(void) { + gpr_cmdline* cl; + + const char* str = nullptr; + int x = 0; + int flag = 2; + + char* bad_bool_value[] = {const_cast(__FILE__), + const_cast("--no-str")}; + + LOG_TEST(); + + cl = gpr_cmdline_create(nullptr); + gpr_cmdline_set_survive_failure(cl); + gpr_cmdline_add_string(cl, "str", nullptr, &str); + gpr_cmdline_add_int(cl, "x", nullptr, &x); + gpr_cmdline_add_flag(cl, "flag", nullptr, &flag); + gpr_cmdline_on_extra_arg(cl, "file", "filenames to process", extra_arg_cb, + nullptr); + + GPR_ASSERT(0 == gpr_cmdline_parse(cl, GPR_ARRAY_SIZE(bad_bool_value), + bad_bool_value)); + + gpr_cmdline_destroy(cl); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + test_simple_int(); + test_eq_int(); + test_2dash_int(); + test_2dash_eq_int(); + test_simple_string(); + test_eq_string(); + test_2dash_string(); + test_2dash_eq_string(); + test_flag_on(); + test_flag_no(); + test_flag_val_1(); + test_flag_val_0(); + test_flag_val_true(); + test_flag_val_false(); + test_many(); + test_extra(); + test_extra_dashdash(); + test_usage(); + test_help(); + test_badargs1(); + test_badargs2(); + test_badargs3(); + test_badargs4(); + return 0; +} diff --git a/test/core/util/fuzzer_corpus_test.cc b/test/core/util/fuzzer_corpus_test.cc new file mode 100644 index 00000000..0638baf2 --- /dev/null +++ b/test/core/util/fuzzer_corpus_test.cc @@ -0,0 +1,166 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
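Putting the cmdline tests above together, a minimal sketch of typical gpr_cmdline usage; the option names (port, host, tls) are invented for illustration.

#include <stdio.h>

#include "test/core/util/cmdline.h"

int main(int argc, char** argv) {
  int port = 50051;
  const char* host = "localhost";
  int use_tls = 0;

  gpr_cmdline* cl = gpr_cmdline_create("gpr_cmdline usage sketch");
  gpr_cmdline_add_int(cl, "port", "port to listen on", &port);
  gpr_cmdline_add_string(cl, "host", "host to bind", &host);
  gpr_cmdline_add_flag(cl, "tls", "enable TLS", &use_tls);
  // Make bad input (and -h/--help) return 0 from parse instead of exiting.
  gpr_cmdline_set_survive_failure(cl);
  if (!gpr_cmdline_parse(cl, argc, argv)) {
    fprintf(stderr, "%s", gpr_cmdline_usage_string(cl, argv[0]).c_str());
    gpr_cmdline_destroy(cl);
    return 1;
  }
  // Accepts --port 8080, --port=8080, -port=8080, --tls and --no-tls.
  printf("host=%s port=%d tls=%d\n", host, port, use_tls);
  gpr_cmdline_destroy(cl);
  return 0;
}

Without gpr_cmdline_set_survive_failure(), a parse error prints the usage string and calls exit(1), which is exactly the behavior the badargs tests above suppress.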
+ * + */ + +#include +#include +#include +#include + +#include + +#include "absl/flags/flag.h" + +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/load_file.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_config.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size); +extern bool squelch; +extern bool leak_check; + +ABSL_FLAG(std::string, file, "", "Use this file as test data"); +ABSL_FLAG(std::string, directory, "", "Use this directory as test data"); + +class FuzzerCorpusTest : public ::testing::TestWithParam {}; + +TEST_P(FuzzerCorpusTest, RunOneExample) { + // Need to call grpc_init() here to use a slice, but need to shut it + // down before calling LLVMFuzzerTestOneInput(), because most + // implementations of that function will initialize and shutdown gRPC + // internally. + grpc_init(); + gpr_log(GPR_INFO, "Example file: %s", GetParam().c_str()); + grpc_slice buffer; + squelch = false; + leak_check = false; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(GetParam().c_str(), 0, &buffer))); + size_t length = GRPC_SLICE_LENGTH(buffer); + void* data = gpr_malloc(length); + memcpy(data, GPR_SLICE_START_PTR(buffer), length); + grpc_slice_unref(buffer); + grpc_shutdown(); + LLVMFuzzerTestOneInput(static_cast(data), length); + gpr_free(data); +} + +class ExampleGenerator + : public ::testing::internal::ParamGeneratorInterface { + public: + ::testing::internal::ParamIteratorInterface* Begin() + const override; + ::testing::internal::ParamIteratorInterface* End() + const override; + + private: + void Materialize() const { + if (examples_.empty()) { + if (!absl::GetFlag(FLAGS_file).empty()) { + examples_.push_back(absl::GetFlag(FLAGS_file)); + } + if (!absl::GetFlag(FLAGS_directory).empty()) { + char* test_srcdir = gpr_getenv("TEST_SRCDIR"); + gpr_log(GPR_DEBUG, "test_srcdir=\"%s\"", test_srcdir); + std::string directory = absl::GetFlag(FLAGS_directory); + if (test_srcdir != nullptr) { + directory = + test_srcdir + std::string("/com_github_grpc_grpc/") + directory; + } + gpr_log(GPR_DEBUG, "Using corpus directory: %s", directory.c_str()); + DIR* dp; + struct dirent* ep; + dp = opendir(directory.c_str()); + + if (dp != nullptr) { + while ((ep = readdir(dp)) != nullptr) { + if (strcmp(ep->d_name, ".") != 0 && strcmp(ep->d_name, "..") != 0) { + examples_.push_back(directory + "/" + ep->d_name); + } + } + + (void)closedir(dp); + } else { + perror("Couldn't open the directory"); + abort(); + } + gpr_free(test_srcdir); + } + } + // Make sure we don't succeed without doing anything, which caused + // us to be blind to our fuzzers not running for 9 months. 
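RunOneExample above wraps each corpus file in a grpc_init()/grpc_shutdown() bracket only because it loads the file through gRPC's slice utilities; a single reproducer can also be replayed with a plain driver linked against the same fuzz target, as in this sketch (standard-library file reading, no gtest).

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <vector>

// Supplied by the fuzz target object this driver is linked with.
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size);
extern bool squelch;
extern bool leak_check;

int main(int argc, char** argv) {
  squelch = false;
  leak_check = false;
  for (int i = 1; i < argc; ++i) {
    std::ifstream in(argv[i], std::ios::binary);
    std::vector<uint8_t> bytes((std::istreambuf_iterator<char>(in)),
                               std::istreambuf_iterator<char>());
    // Each call initializes and shuts down gRPC internally, as the comment in
    // RunOneExample notes.
    LLVMFuzzerTestOneInput(bytes.data(), bytes.size());
  }
  return 0;
}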
+ GPR_ASSERT(!examples_.empty()); + } + + mutable std::vector examples_; +}; + +class ExampleIterator + : public ::testing::internal::ParamIteratorInterface { + public: + ExampleIterator(const ExampleGenerator& base_, + std::vector::const_iterator begin) + : base_(base_), begin_(begin), current_(begin) {} + + const ExampleGenerator* BaseGenerator() const override { return &base_; } + + void Advance() override { current_++; } + ExampleIterator* Clone() const override { return new ExampleIterator(*this); } + const std::string* Current() const override { return &*current_; } + + bool Equals(const ParamIteratorInterface& other) const override { + return &base_ == other.BaseGenerator() && + current_ == dynamic_cast(&other)->current_; + } + + private: + ExampleIterator(const ExampleIterator& other) + : base_(other.base_), begin_(other.begin_), current_(other.current_) {} + + const ExampleGenerator& base_; + const std::vector::const_iterator begin_; + std::vector::const_iterator current_; +}; + +::testing::internal::ParamIteratorInterface* +ExampleGenerator::Begin() const { + Materialize(); + return new ExampleIterator(*this, examples_.begin()); +} + +::testing::internal::ParamIteratorInterface* +ExampleGenerator::End() const { + Materialize(); + return new ExampleIterator(*this, examples_.end()); +} + +INSTANTIATE_TEST_SUITE_P( + CorpusExamples, FuzzerCorpusTest, + ::testing::internal::ParamGenerator(new ExampleGenerator)); + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/test/core/util/fuzzer_util.cc b/test/core/util/fuzzer_util.cc new file mode 100644 index 00000000..ffd8832a --- /dev/null +++ b/test/core/util/fuzzer_util.cc @@ -0,0 +1,84 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
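For comparison, when a parameter list is fixed and known up front, the stock ::testing::ValuesIn generator covers the same ground as the hand-written ExampleGenerator/ExampleIterator pair above, which instead builds its list from flags and the corpus directory. A minimal sketch of the stock form, with invented test and file names.

#include <string>
#include <vector>

#include <gtest/gtest.h>

class FixedCorpusTest : public ::testing::TestWithParam<std::string> {};

TEST_P(FixedCorpusTest, ExampleNameIsNotEmpty) {
  EXPECT_FALSE(GetParam().empty());
}

// Fine for a fixed list; the values are copied into the generator, so the
// temporary vector does not need to outlive this statement.
INSTANTIATE_TEST_SUITE_P(
    FixedExamples, FixedCorpusTest,
    ::testing::ValuesIn(std::vector<std::string>{"crash-1.bin", "crash-2.bin"}));

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}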
+ * + */ + +#include "test/core/util/fuzzer_util.h" + +#include + +#include + +#include "src/core/lib/gpr/useful.h" + +namespace grpc_core { +namespace testing { + +uint8_t grpc_fuzzer_get_next_byte(input_stream* inp) { + if (inp->cur == inp->end) { + return 0; + } + return *inp->cur++; +} + +char* grpc_fuzzer_get_next_string(input_stream* inp, bool* special) { + char* str = nullptr; + size_t cap = 0; + size_t sz = 0; + char c; + do { + if (cap == sz) { + cap = std::max(3 * cap / 2, cap + 8); + str = static_cast(gpr_realloc(str, cap)); + } + c = static_cast(grpc_fuzzer_get_next_byte(inp)); + str[sz++] = c; + } while (c != 0 && c != 1); + if (special != nullptr) { + *special = (c == 1); + } + if (c == 1) { + str[sz - 1] = 0; + } + return str; +} + +uint32_t grpc_fuzzer_get_next_uint32(input_stream* inp) { + uint8_t b = grpc_fuzzer_get_next_byte(inp); + uint32_t x = b & 0x7f; + if (b & 0x80) { + x <<= 7; + b = grpc_fuzzer_get_next_byte(inp); + x |= b & 0x7f; + if (b & 0x80) { + x <<= 7; + b = grpc_fuzzer_get_next_byte(inp); + x |= b & 0x7f; + if (b & 0x80) { + x <<= 7; + b = grpc_fuzzer_get_next_byte(inp); + x |= b & 0x7f; + if (b & 0x80) { + x = (x << 4) | (grpc_fuzzer_get_next_byte(inp) & 0x0f); + } + } + } + } + return x; +} + +} // namespace testing +} // namespace grpc_core diff --git a/test/core/util/grpc_profiler.cc b/test/core/util/grpc_profiler.cc new file mode 100644 index 00000000..88f23359 --- /dev/null +++ b/test/core/util/grpc_profiler.cc @@ -0,0 +1,45 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/grpc_profiler.h" + +#if GRPC_HAVE_PERFTOOLS +#include + +void grpc_profiler_start(const char* filename) { ProfilerStart(filename); } + +void grpc_profiler_stop() { ProfilerStop(); } +#else +#include + +void grpc_profiler_start(const char* filename) { + static int printed_warning = 0; + if (!printed_warning) { + gpr_log(GPR_DEBUG, + "You do not have google-perftools installed, profiling is disabled " + "[for %s]", + filename); + gpr_log(GPR_DEBUG, + "To install on ubuntu: sudo apt-get install google-perftools " + "libgoogle-perftools-dev"); + printed_warning = 1; + } +} + +void grpc_profiler_stop(void) {} +#endif diff --git a/test/core/util/histogram.cc b/test/core/util/histogram.cc new file mode 100644 index 00000000..fc3e21c5 --- /dev/null +++ b/test/core/util/histogram.cc @@ -0,0 +1,233 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
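grpc_fuzzer_get_next_uint32() above decodes big-endian 7-bit groups, treating the top bit of each byte as a "more bytes follow" flag, with an optional fifth byte contributing only its low 4 bits (4*7 + 4 = 32). A self-contained mirror of that logic with a couple of worked values, handy when hand-crafting fuzzer inputs.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone mirror of the decoder's wire format; returns 0 for missing bytes,
// just as grpc_fuzzer_get_next_byte() does at end of input.
static uint32_t decode_fuzzer_uint32(const std::vector<uint8_t>& in) {
  std::size_t pos = 0;
  auto next = [&]() -> uint8_t { return pos < in.size() ? in[pos++] : 0; };
  uint8_t b = next();
  uint32_t x = b & 0x7f;
  for (int group = 0; group < 3 && (b & 0x80) != 0; ++group) {
    b = next();
    x = (x << 7) | (b & 0x7f);
  }
  if ((b & 0x80) != 0) {
    x = (x << 4) | (next() & 0x0f);  // fifth byte: low nibble only
  }
  return x;
}

int main() {
  assert((decode_fuzzer_uint32({0x05}) == 5));
  assert((decode_fuzzer_uint32({0x81, 0x01}) == 129));     // (1 << 7) | 1
  assert((decode_fuzzer_uint32({0xff, 0x7f}) == 0x3fff));  // two full 7-bit groups
  return 0;
}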
+ * + */ + +#include + +#include "test/core/util/histogram.h" + +#include +#include +#include + +#include +#include + +#include "src/core/lib/gpr/useful.h" + +/* Histograms are stored with exponentially increasing bucket sizes. + The first bucket is [0, m) where m = 1 + resolution + Bucket n (n>=1) contains [m**n, m**(n+1)) + There are sufficient buckets to reach max_bucket_start */ + +struct grpc_histogram { + /* Sum of all values seen so far */ + double sum; + /* Sum of squares of all values seen so far */ + double sum_of_squares; + /* number of values seen so far */ + double count; + /* m in the description */ + double multiplier; + double one_on_log_multiplier; + /* minimum value seen */ + double min_seen; + /* maximum value seen */ + double max_seen; + /* maximum representable value */ + double max_possible; + /* number of buckets */ + size_t num_buckets; + /* the buckets themselves */ + uint32_t* buckets; +}; + +/* determine a bucket index given a value - does no bounds checking */ +static size_t bucket_for_unchecked(grpc_histogram* h, double x) { + return static_cast(log(x) * h->one_on_log_multiplier); +} + +/* bounds checked version of the above */ +static size_t bucket_for(grpc_histogram* h, double x) { + size_t bucket = + bucket_for_unchecked(h, grpc_core::Clamp(x, 1.0, h->max_possible)); + GPR_ASSERT(bucket < h->num_buckets); + return bucket; +} + +/* at what value does a bucket start? */ +static double bucket_start(grpc_histogram* h, double x) { + return pow(h->multiplier, x); +} + +grpc_histogram* grpc_histogram_create(double resolution, + double max_bucket_start) { + grpc_histogram* h = + static_cast(gpr_malloc(sizeof(grpc_histogram))); + GPR_ASSERT(resolution > 0.0); + GPR_ASSERT(max_bucket_start > resolution); + h->sum = 0.0; + h->sum_of_squares = 0.0; + h->multiplier = 1.0 + resolution; + h->one_on_log_multiplier = 1.0 / log(1.0 + resolution); + h->max_possible = max_bucket_start; + h->count = 0.0; + h->min_seen = max_bucket_start; + h->max_seen = 0.0; + h->num_buckets = bucket_for_unchecked(h, max_bucket_start) + 1; + GPR_ASSERT(h->num_buckets > 1); + GPR_ASSERT(h->num_buckets < 100000000); + h->buckets = + static_cast(gpr_zalloc(sizeof(uint32_t) * h->num_buckets)); + return h; +} + +void grpc_histogram_destroy(grpc_histogram* h) { + gpr_free(h->buckets); + gpr_free(h); +} + +void grpc_histogram_add(grpc_histogram* h, double x) { + h->sum += x; + h->sum_of_squares += x * x; + h->count++; + if (x < h->min_seen) { + h->min_seen = x; + } + if (x > h->max_seen) { + h->max_seen = x; + } + h->buckets[bucket_for(h, x)]++; +} + +int grpc_histogram_merge(grpc_histogram* dst, const grpc_histogram* src) { + if ((dst->num_buckets != src->num_buckets) || + (dst->multiplier != src->multiplier)) { + /* Fail because these histograms don't match */ + return 0; + } + grpc_histogram_merge_contents(dst, src->buckets, src->num_buckets, + src->min_seen, src->max_seen, src->sum, + src->sum_of_squares, src->count); + return 1; +} + +void grpc_histogram_merge_contents(grpc_histogram* histogram, + const uint32_t* data, size_t data_count, + double min_seen, double max_seen, double sum, + double sum_of_squares, double count) { + size_t i; + GPR_ASSERT(histogram->num_buckets == data_count); + histogram->sum += sum; + histogram->sum_of_squares += sum_of_squares; + histogram->count += count; + if (min_seen < histogram->min_seen) { + histogram->min_seen = min_seen; + } + if (max_seen > histogram->max_seen) { + histogram->max_seen = max_seen; + } + for (i = 0; i < histogram->num_buckets; i++) { + 
histogram->buckets[i] += data[i]; + } +} + +static double threshold_for_count_below(grpc_histogram* h, double count_below) { + double count_so_far; + double lower_bound; + double upper_bound; + size_t lower_idx; + size_t upper_idx; + + if (h->count == 0) { + return 0.0; + } + + if (count_below <= 0) { + return h->min_seen; + } + if (count_below >= h->count) { + return h->max_seen; + } + + /* find the lowest bucket that gets us above count_below */ + count_so_far = 0.0; + for (lower_idx = 0; lower_idx < h->num_buckets; lower_idx++) { + count_so_far += h->buckets[lower_idx]; + if (count_so_far >= count_below) { + break; + } + } + if (count_so_far == count_below) { + /* this bucket hits the threshold exactly... we should be midway through + any run of zero values following the bucket */ + for (upper_idx = lower_idx + 1; upper_idx < h->num_buckets; upper_idx++) { + if (h->buckets[upper_idx]) { + break; + } + } + return (bucket_start(h, static_cast(lower_idx)) + + bucket_start(h, static_cast(upper_idx))) / + 2.0; + } else { + /* treat values as uniform throughout the bucket, and find where this value + should lie */ + lower_bound = bucket_start(h, static_cast(lower_idx)); + upper_bound = bucket_start(h, static_cast(lower_idx + 1)); + return grpc_core::Clamp(upper_bound - (upper_bound - lower_bound) * + (count_so_far - count_below) / + h->buckets[lower_idx], + h->min_seen, h->max_seen); + } +} + +double grpc_histogram_percentile(grpc_histogram* h, double percentile) { + return threshold_for_count_below(h, h->count * percentile / 100.0); +} + +double grpc_histogram_mean(grpc_histogram* h) { + GPR_ASSERT(h->count != 0); + return h->sum / h->count; +} + +double grpc_histogram_stddev(grpc_histogram* h) { + return sqrt(grpc_histogram_variance(h)); +} + +double grpc_histogram_variance(grpc_histogram* h) { + if (h->count == 0) return 0.0; + return (h->sum_of_squares * h->count - h->sum * h->sum) / + (h->count * h->count); +} + +double grpc_histogram_maximum(grpc_histogram* h) { return h->max_seen; } + +double grpc_histogram_minimum(grpc_histogram* h) { return h->min_seen; } + +double grpc_histogram_count(grpc_histogram* h) { return h->count; } + +double grpc_histogram_sum(grpc_histogram* h) { return h->sum; } + +double grpc_histogram_sum_of_squares(grpc_histogram* h) { + return h->sum_of_squares; +} + +const uint32_t* grpc_histogram_get_contents(grpc_histogram* histogram, + size_t* count) { + *count = histogram->num_buckets; + return histogram->buckets; +} diff --git a/test/core/util/histogram_test.cc b/test/core/util/histogram_test.cc new file mode 100644 index 00000000..3ff92a31 --- /dev/null +++ b/test/core/util/histogram_test.cc @@ -0,0 +1,164 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
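The histogram above uses geometrically growing buckets: with resolution r the multiplier is m = 1 + r, a value x lands in bucket floor(log(x) / log(m)), and bucket n covers [m^n, m^(n+1)), so each bucket is about r wide in relative terms. A standalone sketch of that arithmetic:

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  const double resolution = 0.01;  // same as grpc_histogram_create(0.01, ...)
  const double m = 1.0 + resolution;
  const double one_on_log_m = 1.0 / std::log(m);

  const double x = 10000.0;
  const std::size_t bucket = static_cast<std::size_t>(std::log(x) * one_on_log_m);
  const double lo = std::pow(m, static_cast<double>(bucket));
  const double hi = std::pow(m, static_cast<double>(bucket + 1));

  // x always lands inside [lo, hi), and (hi - lo) / lo equals the resolution.
  std::printf("x=%g -> bucket %zu covering [%g, %g), relative width %.4f\n", x,
              bucket, lo, hi, (hi - lo) / lo);
  return (lo <= x && x < hi) ? 0 : 1;
}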
+ * + */ + +#include "test/core/util/histogram.h" + +#include + +#define LOG_TEST(x) gpr_log(GPR_INFO, "%s", x); + +static void test_no_op(void) { + grpc_histogram_destroy(grpc_histogram_create(0.01, 60e9)); +} + +static void expect_percentile(grpc_histogram* h, double percentile, + double min_expect, double max_expect) { + double got = grpc_histogram_percentile(h, percentile); + gpr_log(GPR_INFO, "@%f%%, expect %f <= %f <= %f", percentile, min_expect, got, + max_expect); + GPR_ASSERT(min_expect <= got); + GPR_ASSERT(got <= max_expect); +} + +static void test_simple(void) { + grpc_histogram* h; + + LOG_TEST("test_simple"); + + h = grpc_histogram_create(0.01, 60e9); + grpc_histogram_add(h, 10000); + grpc_histogram_add(h, 10000); + grpc_histogram_add(h, 11000); + grpc_histogram_add(h, 11000); + + expect_percentile(h, 50, 10001, 10999); + GPR_ASSERT(grpc_histogram_mean(h) == 10500); + + grpc_histogram_destroy(h); +} + +static void test_percentile(void) { + grpc_histogram* h; + double last; + double i; + double cur; + + LOG_TEST("test_percentile"); + + h = grpc_histogram_create(0.05, 1e9); + grpc_histogram_add(h, 2.5); + grpc_histogram_add(h, 2.5); + grpc_histogram_add(h, 8); + grpc_histogram_add(h, 4); + + GPR_ASSERT(grpc_histogram_count(h) == 4); + GPR_ASSERT(grpc_histogram_minimum(h) == 2.5); + GPR_ASSERT(grpc_histogram_maximum(h) == 8); + GPR_ASSERT(grpc_histogram_sum(h) == 17); + GPR_ASSERT(grpc_histogram_sum_of_squares(h) == 92.5); + GPR_ASSERT(grpc_histogram_mean(h) == 4.25); + GPR_ASSERT(grpc_histogram_variance(h) == 5.0625); + GPR_ASSERT(grpc_histogram_stddev(h) == 2.25); + + expect_percentile(h, -10, 2.5, 2.5); + expect_percentile(h, 0, 2.5, 2.5); + expect_percentile(h, 12.5, 2.5, 2.5); + expect_percentile(h, 25, 2.5, 2.5); + expect_percentile(h, 37.5, 2.5, 2.8); + expect_percentile(h, 50, 3.0, 3.5); + expect_percentile(h, 62.5, 3.5, 4.5); + expect_percentile(h, 75, 5, 7.9); + expect_percentile(h, 100, 8, 8); + expect_percentile(h, 110, 8, 8); + + /* test monotonicity */ + last = 0.0; + for (i = 0; i < 100.0; i += 0.01) { + cur = grpc_histogram_percentile(h, i); + GPR_ASSERT(cur >= last); + last = cur; + } + + grpc_histogram_destroy(h); +} + +static void test_merge(void) { + grpc_histogram *h1, *h2; + double last; + double i; + double cur; + + LOG_TEST("test_merge"); + + h1 = grpc_histogram_create(0.05, 1e9); + grpc_histogram_add(h1, 2.5); + grpc_histogram_add(h1, 2.5); + grpc_histogram_add(h1, 8); + grpc_histogram_add(h1, 4); + + h2 = grpc_histogram_create(0.01, 1e9); + GPR_ASSERT(grpc_histogram_merge(h1, h2) == 0); + grpc_histogram_destroy(h2); + + h2 = grpc_histogram_create(0.05, 1e10); + GPR_ASSERT(grpc_histogram_merge(h1, h2) == 0); + grpc_histogram_destroy(h2); + + h2 = grpc_histogram_create(0.05, 1e9); + GPR_ASSERT(grpc_histogram_merge(h1, h2) == 1); + GPR_ASSERT(grpc_histogram_count(h1) == 4); + GPR_ASSERT(grpc_histogram_minimum(h1) == 2.5); + GPR_ASSERT(grpc_histogram_maximum(h1) == 8); + GPR_ASSERT(grpc_histogram_sum(h1) == 17); + GPR_ASSERT(grpc_histogram_sum_of_squares(h1) == 92.5); + GPR_ASSERT(grpc_histogram_mean(h1) == 4.25); + GPR_ASSERT(grpc_histogram_variance(h1) == 5.0625); + GPR_ASSERT(grpc_histogram_stddev(h1) == 2.25); + grpc_histogram_destroy(h2); + + h2 = grpc_histogram_create(0.05, 1e9); + grpc_histogram_add(h2, 7.0); + grpc_histogram_add(h2, 17.0); + grpc_histogram_add(h2, 1.0); + GPR_ASSERT(grpc_histogram_merge(h1, h2) == 1); + GPR_ASSERT(grpc_histogram_count(h1) == 7); + GPR_ASSERT(grpc_histogram_minimum(h1) == 1.0); + GPR_ASSERT(grpc_histogram_maximum(h1) 
== 17.0);
+  GPR_ASSERT(grpc_histogram_sum(h1) == 42.0);
+  GPR_ASSERT(grpc_histogram_sum_of_squares(h1) == 431.5);
+  GPR_ASSERT(grpc_histogram_mean(h1) == 6.0);
+
+  /* test monotonicity */
+  last = 0.0;
+  for (i = 0; i < 100.0; i += 0.01) {
+    cur = grpc_histogram_percentile(h1, i);
+    GPR_ASSERT(cur >= last);
+    last = cur;
+  }
+
+  grpc_histogram_destroy(h1);
+  grpc_histogram_destroy(h2);
+}
+
+int main(void) {
+  test_no_op();
+  test_simple();
+  test_percentile();
+  test_merge();
+  return 0;
+}
diff --git a/test/core/util/memory_counters.cc b/test/core/util/memory_counters.cc
new file mode 100644
index 00000000..0deb1a4d
--- /dev/null
+++ b/test/core/util/memory_counters.cc
@@ -0,0 +1,169 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/core/util/memory_counters.h"
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "src/core/lib/gpr/alloc.h"
+#include "src/core/lib/surface/init.h"
+
+static struct grpc_memory_counters g_memory_counters;
+static bool g_memory_counter_enabled;
+
+#ifdef GPR_LOW_LEVEL_COUNTERS
+/* hide these from the microbenchmark atomic stats */
+#define NO_BARRIER_FETCH_ADD(x, sz) \
+  __atomic_fetch_add((x), (sz), __ATOMIC_RELAXED)
+#define NO_BARRIER_LOAD(x) __atomic_load_n((x), __ATOMIC_RELAXED)
+#else
+#define NO_BARRIER_FETCH_ADD(x, sz) gpr_atm_no_barrier_fetch_add(x, sz)
+#define NO_BARRIER_LOAD(x) gpr_atm_no_barrier_load(x)
+#endif
+
+// The memory counter uses the --wrap=symbol feature of ld. To use it,
+// `GPR_WRAP_MEMORY_COUNTER` needs to be defined, and the following options
+// should be passed to the compiler.
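+// (These are linker options: the compiler driver forwards each -Wl flag to
+// ld, and --wrap redirects every call to malloc/calloc/realloc/free to the
+// corresponding __wrap_* function defined below.)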
+// -Wl,--wrap=malloc -Wl,--wrap=calloc -Wl,--wrap=realloc -Wl,--wrap=free +// * Reference: https://linux.die.net/man/1/ld) +#if GPR_WRAP_MEMORY_COUNTER + +extern "C" { +void* __real_malloc(size_t size); +void* __real_calloc(size_t size); +void* __real_realloc(void* ptr, size_t size); +void __real_free(void* ptr); + +void* __wrap_malloc(size_t size); +void* __wrap_calloc(size_t size); +void* __wrap_realloc(void* ptr, size_t size); +void __wrap_free(void* ptr); +} + +void* __wrap_malloc(size_t size) { + if (!size) return nullptr; + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_absolute, (gpr_atm)size); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_relative, (gpr_atm)size); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_allocs_absolute, (gpr_atm)1); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_allocs_relative, (gpr_atm)1); + void* ptr = + __real_malloc(GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size)) + size); + *static_cast(ptr) = size; + return static_cast(ptr) + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size)); +} + +void* __wrap_calloc(size_t size) { + if (!size) return nullptr; + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_absolute, (gpr_atm)size); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_relative, (gpr_atm)size); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_allocs_absolute, (gpr_atm)1); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_allocs_relative, (gpr_atm)1); + void* ptr = + __real_calloc(GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size)) + size); + *static_cast(ptr) = size; + return static_cast(ptr) + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size)); +} + +void* __wrap_realloc(void* ptr, size_t size) { + if (ptr == nullptr) { + return __wrap_malloc(size); + } + if (size == 0) { + __wrap_free(ptr); + return nullptr; + } + void* rptr = + static_cast(ptr) - GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size)); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_absolute, (gpr_atm)size); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_relative, + -*static_cast(rptr)); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_relative, (gpr_atm)size); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_allocs_absolute, (gpr_atm)1); + void* new_ptr = + __real_realloc(rptr, GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size)) + size); + *static_cast(new_ptr) = size; + return static_cast(new_ptr) + + GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size)); +} + +void __wrap_free(void* ptr) { + if (ptr == nullptr) return; + void* rptr = + static_cast(ptr) - GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(size_t)); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_size_relative, + -*static_cast(rptr)); + NO_BARRIER_FETCH_ADD(&g_memory_counters.total_allocs_relative, -(gpr_atm)1); + __real_free(rptr); +} + +#endif // GPR_WRAP_MEMORY_COUNTER + +void grpc_memory_counters_init() { + memset(&g_memory_counters, 0, sizeof(g_memory_counters)); + g_memory_counter_enabled = true; +} + +void grpc_memory_counters_destroy() { g_memory_counter_enabled = false; } + +struct grpc_memory_counters grpc_memory_counters_snapshot() { + struct grpc_memory_counters counters; + counters.total_size_relative = + NO_BARRIER_LOAD(&g_memory_counters.total_size_relative); + counters.total_size_absolute = + NO_BARRIER_LOAD(&g_memory_counters.total_size_absolute); + counters.total_allocs_relative = + NO_BARRIER_LOAD(&g_memory_counters.total_allocs_relative); + counters.total_allocs_absolute = + NO_BARRIER_LOAD(&g_memory_counters.total_allocs_absolute); + return counters; +} + +namespace grpc_core { +namespace testing { + 
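+// LeakDetector is a scoped helper built on the counters above. A minimal
+// usage sketch (the test body here is illustrative, not part of this file):
+//
+//   {
+//     grpc_core::testing::LeakDetector leak_detector(true);
+//     grpc_init();
+//     /* ... exercise the code under test ... */
+//     grpc_shutdown();
+//   }   // ~LeakDetector() waits for async shutdown and then asserts that
+//       // total_size_relative is back to zero.
+//
+// Note that only allocations routed through the wrapped allocator (see
+// GPR_WRAP_MEMORY_COUNTER above) are reflected in the counters.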
+LeakDetector::LeakDetector(bool enable) : enabled_(enable) { + if (enabled_) { + grpc_memory_counters_init(); + } +} + +LeakDetector::~LeakDetector() { + // Wait for grpc_shutdown() to finish its async work. + grpc_maybe_wait_for_async_shutdown(); + if (enabled_) { + struct grpc_memory_counters counters = grpc_memory_counters_snapshot(); + if (counters.total_size_relative != 0) { + gpr_log(GPR_ERROR, "Leaking %" PRIuPTR " bytes", + static_cast(counters.total_size_relative)); + GPR_ASSERT(0); + } + grpc_memory_counters_destroy(); + } +} + +} // namespace testing +} // namespace grpc_core diff --git a/test/core/util/mock_endpoint.cc b/test/core/util/mock_endpoint.cc new file mode 100644 index 00000000..f3fa7a3e --- /dev/null +++ b/test/core/util/mock_endpoint.cc @@ -0,0 +1,142 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/mock_endpoint.h" + +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/lib/iomgr/sockaddr.h" + +typedef struct mock_endpoint { + grpc_endpoint base; + gpr_mu mu; + void (*on_write)(grpc_slice slice); + grpc_slice_buffer read_buffer; + grpc_slice_buffer* on_read_out; + grpc_closure* on_read; + grpc_slice_allocator* slice_allocator; +} mock_endpoint; + +static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool /*urgent*/) { + mock_endpoint* m = reinterpret_cast(ep); + gpr_mu_lock(&m->mu); + if (m->read_buffer.count > 0) { + grpc_slice_buffer_swap(&m->read_buffer, slices); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); + } else { + m->on_read = cb; + m->on_read_out = slices; + } + gpr_mu_unlock(&m->mu); +} + +static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* /*arg*/) { + mock_endpoint* m = reinterpret_cast(ep); + for (size_t i = 0; i < slices->count; i++) { + m->on_write(slices->slices[i]); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); +} + +static void me_add_to_pollset(grpc_endpoint* /*ep*/, + grpc_pollset* /*pollset*/) {} + +static void me_add_to_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + +static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + +static void me_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + mock_endpoint* m = reinterpret_cast(ep); + gpr_mu_lock(&m->mu); + if (m->on_read) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, m->on_read, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Endpoint Shutdown", &why, 1)); + m->on_read = nullptr; + } + gpr_mu_unlock(&m->mu); + GRPC_ERROR_UNREF(why); +} + +static void me_destroy(grpc_endpoint* ep) { + mock_endpoint* m = reinterpret_cast(ep); + grpc_slice_buffer_destroy(&m->read_buffer); + grpc_slice_allocator_destroy(m->slice_allocator); + gpr_mu_destroy(&m->mu); + gpr_free(m); +} + +static absl::string_view me_get_peer(grpc_endpoint* /*ep*/) { + return "fake:mock_endpoint"; +} + 
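+/* Usage sketch (illustrative only; the callback name is hypothetical): a test
+   supplies a write callback that observes outgoing bytes, and injects inbound
+   bytes with grpc_mock_endpoint_put_read(), defined below.
+
+     static void log_write(grpc_slice slice) {
+       // `slice` is borrowed from the caller's buffer; do not unref it here.
+       gpr_log(GPR_DEBUG, "wrote %d bytes", (int)GRPC_SLICE_LENGTH(slice));
+     }
+     ...
+     grpc_endpoint* ep = grpc_mock_endpoint_create(
+         log_write, grpc_slice_allocator_create_unlimited());
+     grpc_mock_endpoint_put_read(ep, grpc_slice_from_copied_string("hello"));
+*/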
+static absl::string_view me_get_local_address(grpc_endpoint* /*ep*/) { + return "fake:mock_endpoint"; +} + +static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; } + +static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; } + +static const grpc_endpoint_vtable vtable = {me_read, + me_write, + me_add_to_pollset, + me_add_to_pollset_set, + me_delete_from_pollset_set, + me_shutdown, + me_destroy, + me_get_peer, + me_get_local_address, + me_get_fd, + me_can_track_err}; + +grpc_endpoint* grpc_mock_endpoint_create( + void (*on_write)(grpc_slice slice), grpc_slice_allocator* slice_allocator) { + mock_endpoint* m = static_cast(gpr_malloc(sizeof(*m))); + m->base.vtable = &vtable; + m->slice_allocator = slice_allocator; + grpc_slice_buffer_init(&m->read_buffer); + gpr_mu_init(&m->mu); + m->on_write = on_write; + m->on_read = nullptr; + return &m->base; +} + +void grpc_mock_endpoint_put_read(grpc_endpoint* ep, grpc_slice slice) { + mock_endpoint* m = reinterpret_cast(ep); + gpr_mu_lock(&m->mu); + if (m->on_read != nullptr) { + grpc_slice_buffer_add(m->on_read_out, slice); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, m->on_read, GRPC_ERROR_NONE); + m->on_read = nullptr; + } else { + grpc_slice_buffer_add(&m->read_buffer, slice); + } + gpr_mu_unlock(&m->mu); +} diff --git a/test/core/util/one_corpus_entry_fuzzer.cc b/test/core/util/one_corpus_entry_fuzzer.cc new file mode 100644 index 00000000..fa0f04f0 --- /dev/null +++ b/test/core/util/one_corpus_entry_fuzzer.cc @@ -0,0 +1,48 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/load_file.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size); + +extern bool squelch; +extern bool leak_check; + +int main(int argc, char** argv) { + grpc_slice buffer; + squelch = false; + leak_check = false; + /* TODO(yashkt) Calling grpc_init breaks tests. Fix the tests and replace + * grpc_core::ExecCtx::GlobalInit with grpc_init and GlobalShutdown with + * grpc_shutdown */ + GPR_ASSERT(argc > 1); /* Make sure that we have a filename argument */ + GPR_ASSERT( + GRPC_LOG_IF_ERROR("load_file", grpc_load_file(argv[1], 0, &buffer))); + LLVMFuzzerTestOneInput(GRPC_SLICE_START_PTR(buffer), + GRPC_SLICE_LENGTH(buffer)); + grpc_core::ExecCtx::GlobalInit(); + grpc_slice_unref(buffer); + grpc_core::ExecCtx::GlobalShutdown(); + return 0; +} diff --git a/test/core/util/parse_hexstring.cc b/test/core/util/parse_hexstring.cc new file mode 100644 index 00000000..a65ef999 --- /dev/null +++ b/test/core/util/parse_hexstring.cc @@ -0,0 +1,57 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/parse_hexstring.h" + +#include + +grpc_slice parse_hexstring(const char* hexstring) { + size_t nibbles = 0; + const char* p = nullptr; + uint8_t* out; + uint8_t temp; + grpc_slice slice; + + for (p = hexstring; *p; p++) { + nibbles += (*p >= '0' && *p <= '9') || (*p >= 'a' && *p <= 'f'); + } + + GPR_ASSERT((nibbles & 1) == 0); + + slice = grpc_slice_malloc(nibbles / 2); + out = GRPC_SLICE_START_PTR(slice); + + nibbles = 0; + temp = 0; + for (p = hexstring; *p; p++) { + if (*p >= '0' && *p <= '9') { + temp = static_cast(temp << 4) | static_cast(*p - '0'); + nibbles++; + } else if (*p >= 'a' && *p <= 'f') { + temp = + static_cast(temp << 4) | static_cast(*p - 'a' + 10); + nibbles++; + } + if (nibbles == 2) { + *out++ = temp; + nibbles = 0; + } + } + + return slice; +} diff --git a/test/core/util/passthru_endpoint.cc b/test/core/util/passthru_endpoint.cc new file mode 100644 index 00000000..50534f7a --- /dev/null +++ b/test/core/util/passthru_endpoint.cc @@ -0,0 +1,226 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/core/util/passthru_endpoint.h" + +#include +#include + +#include + +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/resource_user_util.h" + +typedef struct passthru_endpoint passthru_endpoint; + +typedef struct { + grpc_endpoint base; + passthru_endpoint* parent; + grpc_slice_buffer read_buffer; + grpc_slice_buffer* on_read_out; + grpc_closure* on_read; + grpc_slice_allocator* slice_allocator; +} half; + +struct passthru_endpoint { + gpr_mu mu; + int halves; + grpc_passthru_endpoint_stats* stats; + bool shutdown; + half client; + half server; +}; + +static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool /*urgent*/) { + half* m = reinterpret_cast(ep); + gpr_mu_lock(&m->parent->mu); + if (m->parent->shutdown) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, cb, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Already shutdown")); + } else if (m->read_buffer.count > 0) { + grpc_slice_buffer_swap(&m->read_buffer, slices); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); + } else { + m->on_read = cb; + m->on_read_out = slices; + } + gpr_mu_unlock(&m->parent->mu); +} + +static half* other_half(half* h) { + if (h == &h->parent->client) return &h->parent->server; + return &h->parent->client; +} + +static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* /*arg*/) { + half* m = other_half(reinterpret_cast(ep)); + gpr_mu_lock(&m->parent->mu); + grpc_error_handle error = GRPC_ERROR_NONE; + gpr_atm_no_barrier_fetch_add(&m->parent->stats->num_writes, (gpr_atm)1); + if (m->parent->shutdown) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Endpoint already shutdown"); + } else if (m->on_read != nullptr) { + for (size_t i = 0; i < slices->count; i++) { + grpc_slice_buffer_add(m->on_read_out, grpc_slice_copy(slices->slices[i])); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, m->on_read, GRPC_ERROR_NONE); + m->on_read = nullptr; + } else { + for (size_t i = 0; i < slices->count; i++) { + grpc_slice_buffer_add(&m->read_buffer, + grpc_slice_copy(slices->slices[i])); + } + } + gpr_mu_unlock(&m->parent->mu); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); +} + +static void me_add_to_pollset(grpc_endpoint* /*ep*/, + grpc_pollset* /*pollset*/) {} + +static void me_add_to_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + +static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + +static void me_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + half* m = reinterpret_cast(ep); + gpr_mu_lock(&m->parent->mu); + m->parent->shutdown = true; + if (m->on_read) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, m->on_read, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING("Shutdown", &why, 1)); + m->on_read = nullptr; + } + m = other_half(m); + if (m->on_read) { + grpc_core::ExecCtx::Run( + DEBUG_LOCATION, m->on_read, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING("Shutdown", &why, 1)); + m->on_read = nullptr; + } + gpr_mu_unlock(&m->parent->mu); + GRPC_ERROR_UNREF(why); +} + +static void me_destroy(grpc_endpoint* ep) { + passthru_endpoint* p = (reinterpret_cast(ep))->parent; + gpr_mu_lock(&p->mu); + if (0 == --p->halves) { + gpr_mu_unlock(&p->mu); + gpr_mu_destroy(&p->mu); + grpc_passthru_endpoint_stats_destroy(p->stats); + grpc_slice_buffer_destroy_internal(&p->client.read_buffer); + 
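+    /* The half that is destroyed last also releases the state shared with its
+       peer: the mutex and stats above, and both halves' read buffers and
+       slice allocators. */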
grpc_slice_buffer_destroy_internal(&p->server.read_buffer); + grpc_slice_allocator_destroy(p->client.slice_allocator); + grpc_slice_allocator_destroy(p->server.slice_allocator); + gpr_free(p); + } else { + gpr_mu_unlock(&p->mu); + } +} + +static absl::string_view me_get_peer(grpc_endpoint* ep) { + passthru_endpoint* p = (reinterpret_cast(ep))->parent; + return (reinterpret_cast(ep)) == &p->client + ? "fake:mock_client_endpoint" + : "fake:mock_server_endpoint"; +} + +static absl::string_view me_get_local_address(grpc_endpoint* ep) { + passthru_endpoint* p = (reinterpret_cast(ep))->parent; + return (reinterpret_cast(ep)) == &p->client + ? "fake:mock_client_endpoint" + : "fake:mock_server_endpoint"; +} + +static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; } + +static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; } + +static const grpc_endpoint_vtable vtable = { + me_read, + me_write, + me_add_to_pollset, + me_add_to_pollset_set, + me_delete_from_pollset_set, + me_shutdown, + me_destroy, + me_get_peer, + me_get_local_address, + me_get_fd, + me_can_track_err, +}; + +static void half_init(half* m, passthru_endpoint* parent, + grpc_slice_allocator* slice_allocator, + const char* half_name) { + m->base.vtable = &vtable; + m->parent = parent; + grpc_slice_buffer_init(&m->read_buffer); + m->on_read = nullptr; + std::string name = + absl::StrFormat("passthru_endpoint_%s_%p", half_name, parent); + m->slice_allocator = slice_allocator; +} + +void grpc_passthru_endpoint_create(grpc_endpoint** client, + grpc_endpoint** server, + grpc_passthru_endpoint_stats* stats) { + passthru_endpoint* m = + static_cast(gpr_malloc(sizeof(*m))); + m->halves = 2; + m->shutdown = false; + if (stats == nullptr) { + m->stats = grpc_passthru_endpoint_stats_create(); + } else { + gpr_ref(&stats->refs); + m->stats = stats; + } + half_init(&m->client, m, grpc_slice_allocator_create_unlimited(), "client"); + half_init(&m->server, m, grpc_slice_allocator_create_unlimited(), "server"); + gpr_mu_init(&m->mu); + *client = &m->client.base; + *server = &m->server.base; +} + +grpc_passthru_endpoint_stats* grpc_passthru_endpoint_stats_create() { + grpc_passthru_endpoint_stats* stats = + static_cast( + gpr_malloc(sizeof(grpc_passthru_endpoint_stats))); + memset(stats, 0, sizeof(*stats)); + gpr_ref_init(&stats->refs, 1); + return stats; +} + +void grpc_passthru_endpoint_stats_destroy(grpc_passthru_endpoint_stats* stats) { + if (gpr_unref(&stats->refs)) { + gpr_free(stats); + } +} diff --git a/test/core/util/port.cc b/test/core/util/port.cc new file mode 100644 index 00000000..5a2e0626 --- /dev/null +++ b/test/core/util/port.cc @@ -0,0 +1,143 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "src/core/lib/iomgr/port.h" + +#include "test/core/util/test_config.h" +#if defined(GRPC_TEST_PICK_PORT) + +#include +#include +#include + +#include +#include +#include +#include + +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/http/httpcli.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "test/core/util/port.h" +#include "test/core/util/port_server_client.h" + +static int* chosen_ports = nullptr; +static size_t num_chosen_ports = 0; +static grpc_core::Mutex* g_default_port_picker_mu; +static gpr_once g_default_port_picker_init = GPR_ONCE_INIT; + +static void init_default_port_picker() { + g_default_port_picker_mu = new grpc_core::Mutex(); +} + +static int free_chosen_port_locked(int port) { + size_t i; + int found = 0; + size_t found_at = 0; + /* Find the port and erase it from the list, then tell the server it can be + freed. */ + for (i = 0; i < num_chosen_ports; i++) { + if (chosen_ports[i] == port) { + GPR_ASSERT(found == 0); + found = 1; + found_at = i; + } + } + if (found) { + chosen_ports[found_at] = chosen_ports[num_chosen_ports - 1]; + num_chosen_ports--; + grpc_free_port_using_server(port); + } + return found; +} + +static void free_chosen_ports(void) { + grpc_core::MutexLock lock(g_default_port_picker_mu); + size_t i; + grpc_init(); + for (i = 0; i < num_chosen_ports; i++) { + grpc_free_port_using_server(chosen_ports[i]); + } + grpc_shutdown(); + gpr_free(chosen_ports); +} + +static void chose_port_locked(int port) { + if (chosen_ports == nullptr) { + atexit(free_chosen_ports); + } + num_chosen_ports++; + chosen_ports = static_cast( + gpr_realloc(chosen_ports, sizeof(int) * num_chosen_ports)); + chosen_ports[num_chosen_ports - 1] = port; +} + +static int grpc_pick_unused_port_impl(void) { + gpr_once_init(&g_default_port_picker_init, init_default_port_picker); + grpc_core::MutexLock lock(g_default_port_picker_mu); + int port = grpc_pick_port_using_server(); + if (port != 0) { + chose_port_locked(port); + } + + return port; +} + +static int grpc_pick_unused_port_or_die_impl(void) { + int port = grpc_pick_unused_port(); + if (port == 0) { + fprintf(stderr, + "gRPC tests require a helper port server to allocate ports used \n" + "during the test.\n\n" + "This server is not currently running.\n\n" + "To start it, run tools/run_tests/start_port_server.py\n\n"); + exit(1); + } + return port; +} + +static void grpc_recycle_unused_port_impl(int port) { + gpr_once_init(&g_default_port_picker_init, init_default_port_picker); + grpc_core::MutexLock lock(g_default_port_picker_mu); + GPR_ASSERT(free_chosen_port_locked(port)); +} + +static grpc_pick_port_functions g_pick_port_functions = { + grpc_pick_unused_port_impl, grpc_pick_unused_port_or_die_impl, + grpc_recycle_unused_port_impl}; + +int grpc_pick_unused_port(void) { + return g_pick_port_functions.pick_unused_port_fn(); +} + +int grpc_pick_unused_port_or_die(void) { + return g_pick_port_functions.pick_unused_port_or_die_fn(); +} + +void grpc_recycle_unused_port(int port) { + g_pick_port_functions.recycle_unused_port_fn(port); +} + +void grpc_set_pick_port_functions(grpc_pick_port_functions functions) { + GPR_ASSERT(functions.pick_unused_port_fn != nullptr); + GPR_ASSERT(functions.pick_unused_port_or_die_fn != nullptr); + GPR_ASSERT(functions.recycle_unused_port_fn != nullptr); + g_pick_port_functions = functions; +} + +#endif /* GRPC_TEST_PICK_PORT */ diff --git a/test/core/util/port_isolated_runtime_environment.cc b/test/core/util/port_isolated_runtime_environment.cc new 
file mode 100644 index 00000000..6d32aeb4 --- /dev/null +++ b/test/core/util/port_isolated_runtime_environment.cc @@ -0,0 +1,71 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* When individual tests run in an isolated runtime environment (e.g. each test + * runs in a separate container) the framework takes a round-robin pick of a + * port within certain range. There is no need to recycle ports. + */ +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/port.h" +#include "test/core/util/test_config.h" +#if defined(GRPC_PORT_ISOLATED_RUNTIME) + +#include "test/core/util/port.h" + +#define MIN_PORT 1025 +#define MAX_PORT 32766 + +static int get_random_port_offset() { + srand(gpr_now(GPR_CLOCK_REALTIME).tv_nsec); + double rnd = static_cast(rand()) / + (static_cast(RAND_MAX) + 1.0); // values from [0,1) + return static_cast(rnd * (MAX_PORT - MIN_PORT + 1)); +} + +static int s_initial_offset = get_random_port_offset(); +static gpr_atm s_pick_counter = 0; + +static int grpc_pick_unused_port_or_die_impl(void) { + int orig_counter_val = + static_cast(gpr_atm_full_fetch_add(&s_pick_counter, 1)); + GPR_ASSERT(orig_counter_val < (MAX_PORT - MIN_PORT + 1)); + return MIN_PORT + + (s_initial_offset + orig_counter_val) % (MAX_PORT - MIN_PORT + 1); +} + +int grpc_pick_unused_port_or_die(void) { + while (true) { + int port = grpc_pick_unused_port_or_die_impl(); + // 5985 cannot be bound on Windows RBE and results in + // WSA_ERROR 10013: "An attempt was made to access a socket in a way + // forbidden by its access permissions." + if (port == 5985) { + continue; + } + return port; + } +} + +void grpc_recycle_unused_port(int port) { (void)port; } + +#endif /* GRPC_PORT_ISOLATED_RUNTIME */ diff --git a/test/core/util/port_server_client.cc b/test/core/util/port_server_client.cc new file mode 100644 index 00000000..70a253f1 --- /dev/null +++ b/test/core/util/port_server_client.cc @@ -0,0 +1,252 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "test/core/util/test_config.h" + +#ifdef GRPC_TEST_PICK_PORT +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/http/httpcli.h" +#include "test/core/util/port_server_client.h" + +typedef struct freereq { + gpr_mu* mu = nullptr; + grpc_polling_entity pops = {}; + int done = 0; +} freereq; + +static void destroy_pops_and_shutdown(void* p, grpc_error_handle /*error*/) { + grpc_pollset* pollset = + grpc_polling_entity_pollset(static_cast(p)); + grpc_pollset_destroy(pollset); + gpr_free(pollset); +} + +static void freed_port_from_server(void* arg, grpc_error_handle /*error*/) { + freereq* pr = static_cast(arg); + gpr_mu_lock(pr->mu); + pr->done = 1; + GRPC_LOG_IF_ERROR( + "pollset_kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&pr->pops), nullptr)); + gpr_mu_unlock(pr->mu); +} + +void grpc_free_port_using_server(int port) { + grpc_httpcli_context context; + grpc_httpcli_request req; + grpc_httpcli_response rsp; + freereq pr; + char* path; + grpc_closure* shutdown_closure; + + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + + pr = {}; + memset(&req, 0, sizeof(req)); + rsp = {}; + + grpc_pollset* pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &pr.mu); + pr.pops = grpc_polling_entity_create_from_pollset(pollset); + shutdown_closure = GRPC_CLOSURE_CREATE(destroy_pops_and_shutdown, &pr.pops, + grpc_schedule_on_exec_ctx); + + req.host = const_cast(GRPC_PORT_SERVER_ADDRESS); + gpr_asprintf(&path, "/drop/%d", port); + req.http.path = path; + + grpc_httpcli_context_init(&context); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("port_server_client/free"); + grpc_httpcli_get(&context, &pr.pops, resource_quota, &req, + grpc_core::ExecCtx::Get()->Now() + 30 * GPR_MS_PER_SEC, + GRPC_CLOSURE_CREATE(freed_port_from_server, &pr, + grpc_schedule_on_exec_ctx), + &rsp); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(pr.mu); + while (!pr.done) { + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work( + grpc_polling_entity_pollset(&pr.pops), &worker, + grpc_core::ExecCtx::Get()->Now() + GPR_MS_PER_SEC))) { + pr.done = 1; + } + } + gpr_mu_unlock(pr.mu); + + grpc_httpcli_context_destroy(&context); + grpc_pollset_shutdown(grpc_polling_entity_pollset(&pr.pops), + shutdown_closure); + + gpr_free(path); + grpc_http_response_destroy(&rsp); + } + grpc_shutdown(); +} + +typedef struct portreq { + gpr_mu* mu = nullptr; + grpc_polling_entity pops = {}; + int port = 0; + int retries = 0; + char* server = nullptr; + grpc_httpcli_context* ctx = nullptr; + grpc_httpcli_response response = {}; +} portreq; + +static void got_port_from_server(void* arg, grpc_error_handle error) { + size_t i; + int port = 0; + portreq* pr = static_cast(arg); + int failed = 0; + grpc_httpcli_response* response = &pr->response; + + if (error != GRPC_ERROR_NONE) { + failed = 1; + gpr_log(GPR_DEBUG, "failed port pick from server: retrying [%s]", + grpc_error_std_string(error).c_str()); + } else if (response->status != 200) { + failed = 1; + gpr_log(GPR_DEBUG, "failed port pick from server: status=%d", + response->status); + } + + if (failed) { + grpc_httpcli_request req; + memset(&req, 0, sizeof(req)); + if (pr->retries >= 5) { + gpr_mu_lock(pr->mu); + pr->port = 0; + GRPC_LOG_IF_ERROR( + "pollset_kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&pr->pops), nullptr)); + gpr_mu_unlock(pr->mu); + return; + } + GPR_ASSERT(pr->retries < 
10); + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis( + static_cast( + 1000.0 * (1 + pow(1.3, pr->retries) * rand() / RAND_MAX)), + GPR_TIMESPAN))); + pr->retries++; + req.host = pr->server; + req.http.path = const_cast("/get"); + grpc_http_response_destroy(&pr->response); + pr->response = {}; + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("port_server_client/pick_retry"); + grpc_httpcli_get(pr->ctx, &pr->pops, resource_quota, &req, + grpc_core::ExecCtx::Get()->Now() + 30 * GPR_MS_PER_SEC, + GRPC_CLOSURE_CREATE(got_port_from_server, pr, + grpc_schedule_on_exec_ctx), + &pr->response); + return; + } + GPR_ASSERT(response); + GPR_ASSERT(response->status == 200); + for (i = 0; i < response->body_length; i++) { + GPR_ASSERT(response->body[i] >= '0' && response->body[i] <= '9'); + port = port * 10 + response->body[i] - '0'; + } + GPR_ASSERT(port > 1024); + gpr_mu_lock(pr->mu); + pr->port = port; + GRPC_LOG_IF_ERROR( + "pollset_kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&pr->pops), nullptr)); + gpr_mu_unlock(pr->mu); +} + +int grpc_pick_port_using_server(void) { + grpc_httpcli_context context; + grpc_httpcli_request req; + portreq pr; + grpc_closure* shutdown_closure; + + grpc_init(); + { + grpc_core::ExecCtx exec_ctx; + pr = {}; + memset(&req, 0, sizeof(req)); + grpc_pollset* pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &pr.mu); + pr.pops = grpc_polling_entity_create_from_pollset(pollset); + shutdown_closure = GRPC_CLOSURE_CREATE(destroy_pops_and_shutdown, &pr.pops, + grpc_schedule_on_exec_ctx); + pr.port = -1; + pr.server = const_cast(GRPC_PORT_SERVER_ADDRESS); + pr.ctx = &context; + + req.host = const_cast(GRPC_PORT_SERVER_ADDRESS); + req.http.path = const_cast("/get"); + + grpc_httpcli_context_init(&context); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("port_server_client/pick"); + grpc_httpcli_get(&context, &pr.pops, resource_quota, &req, + grpc_core::ExecCtx::Get()->Now() + 30 * GPR_MS_PER_SEC, + GRPC_CLOSURE_CREATE(got_port_from_server, &pr, + grpc_schedule_on_exec_ctx), + &pr.response); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(pr.mu); + while (pr.port == -1) { + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work( + grpc_polling_entity_pollset(&pr.pops), &worker, + grpc_core::ExecCtx::Get()->Now() + GPR_MS_PER_SEC))) { + pr.port = 0; + } + } + gpr_mu_unlock(pr.mu); + + grpc_http_response_destroy(&pr.response); + grpc_httpcli_context_destroy(&context); + grpc_pollset_shutdown(grpc_polling_entity_pollset(&pr.pops), + shutdown_closure); + + grpc_core::ExecCtx::Get()->Flush(); + } + grpc_shutdown(); + + return pr.port; +} + +#endif // GRPC_TEST_PICK_PORT diff --git a/test/core/util/reconnect_server.cc b/test/core/util/reconnect_server.cc new file mode 100644 index 00000000..93a149bf --- /dev/null +++ b/test/core/util/reconnect_server.cc @@ -0,0 +1,131 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/reconnect_server.h" + +#include + +#include "absl/strings/string_view.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "test/core/util/port.h" +#include "test/core/util/test_tcp_server.h" + +static void pretty_print_backoffs(reconnect_server* server) { + gpr_timespec diff; + int i = 1; + double expected_backoff = 1000.0, backoff; + timestamp_list* head = server->head; + gpr_log(GPR_INFO, "reconnect server: new connection"); + for (head = server->head; head && head->next; head = head->next, i++) { + diff = gpr_time_sub(head->next->timestamp, head->timestamp); + backoff = gpr_time_to_millis(diff); + gpr_log(GPR_INFO, + "retry %2d:backoff %6.2fs,expected backoff %6.2fs, jitter %4.2f%%", + i, backoff / 1000.0, expected_backoff / 1000.0, + (backoff - expected_backoff) * 100.0 / expected_backoff); + expected_backoff *= 1.6; + int max_reconnect_backoff_ms = 120 * 1000; + if (server->max_reconnect_backoff_ms > 0) { + max_reconnect_backoff_ms = server->max_reconnect_backoff_ms; + } + if (expected_backoff > max_reconnect_backoff_ms) { + expected_backoff = max_reconnect_backoff_ms; + } + } +} + +static void on_connect(void* arg, grpc_endpoint* tcp, + grpc_pollset* /*accepting_pollset*/, + grpc_tcp_server_acceptor* acceptor) { + gpr_free(acceptor); + absl::string_view peer; + absl::string_view::size_type last_colon; + reconnect_server* server = static_cast(arg); + gpr_timespec now = gpr_now(GPR_CLOCK_REALTIME); + timestamp_list* new_tail; + peer = grpc_endpoint_get_peer(tcp); + grpc_endpoint_shutdown(tcp, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Connected")); + grpc_endpoint_destroy(tcp); + last_colon = peer.rfind(':'); + if (server->peer == nullptr) { + server->peer = new std::string(peer); + } else { + if (last_colon == std::string::npos) { + gpr_log(GPR_ERROR, "peer does not contain a ':'"); + } else if (peer.compare(0, static_cast(last_colon), + *server->peer) != 0) { + gpr_log(GPR_ERROR, "mismatched peer! 
%s vs %s", server->peer->c_str(), + std::string(peer).c_str()); + } + } + new_tail = static_cast(gpr_malloc(sizeof(timestamp_list))); + new_tail->timestamp = now; + new_tail->next = nullptr; + if (server->tail == nullptr) { + server->head = new_tail; + server->tail = new_tail; + } else { + server->tail->next = new_tail; + server->tail = new_tail; + } + pretty_print_backoffs(server); +} + +void reconnect_server_init(reconnect_server* server) { + test_tcp_server_init(&server->tcp_server, on_connect, server); + server->head = nullptr; + server->tail = nullptr; + server->peer = nullptr; + server->max_reconnect_backoff_ms = 0; +} + +void reconnect_server_start(reconnect_server* server, int port) { + test_tcp_server_start(&server->tcp_server, port); +} + +void reconnect_server_poll(reconnect_server* server, int seconds) { + test_tcp_server_poll(&server->tcp_server, 1000 * seconds); +} + +void reconnect_server_clear_timestamps(reconnect_server* server) { + timestamp_list* new_head = server->head; + while (server->head) { + new_head = server->head->next; + gpr_free(server->head); + server->head = new_head; + } + server->tail = nullptr; + delete server->peer; + server->peer = nullptr; +} + +void reconnect_server_destroy(reconnect_server* server) { + reconnect_server_clear_timestamps(server); + test_tcp_server_destroy(&server->tcp_server); +} diff --git a/test/core/util/resolve_localhost_ip46.cc b/test/core/util/resolve_localhost_ip46.cc new file mode 100644 index 00000000..bf3f621c --- /dev/null +++ b/test/core/util/resolve_localhost_ip46.cc @@ -0,0 +1,58 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include "test/core/util/resolve_localhost_ip46.h" + +#include + +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" + +namespace grpc_core { +namespace { + +bool localhost_to_ipv4 = false; +bool localhost_to_ipv6 = false; +gpr_once g_resolve_localhost_ipv46 = GPR_ONCE_INIT; + +void InitResolveLocalhost() { + grpc_resolved_addresses* addresses; + grpc_error_handle err = + grpc_blocking_resolve_address("localhost", "https", &addresses); + GPR_ASSERT(err == GRPC_ERROR_NONE); + for (size_t i = 0; i < addresses->naddrs; i++) { + grpc_sockaddr* addr = + reinterpret_cast(addresses->addrs[i].addr); + if (addr->sa_family == GRPC_AF_INET) { + localhost_to_ipv4 = true; + } else if (addr->sa_family == GRPC_AF_INET6) { + localhost_to_ipv6 = true; + } + } + grpc_resolved_addresses_destroy(addresses); +} +} // namespace + +void LocalhostResolves(bool* ipv4, bool* ipv6) { + gpr_once_init(&g_resolve_localhost_ipv46, InitResolveLocalhost); + *ipv4 = localhost_to_ipv4; + *ipv6 = localhost_to_ipv6; +} + +} // namespace grpc_core diff --git a/test/core/util/resource_user_util.cc b/test/core/util/resource_user_util.cc new file mode 100644 index 00000000..2fa609c9 --- /dev/null +++ b/test/core/util/resource_user_util.cc @@ -0,0 +1,43 @@ +// Copyright 2021 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/core/util/resource_user_util.h" + +#include "absl/strings/str_format.h" + +grpc_resource_user* grpc_resource_user_create_unlimited( + grpc_resource_quota* resource_quota) { + if (resource_quota == nullptr) { + resource_quota = grpc_resource_quota_create("anonymous mock quota"); + } else { + grpc_resource_quota_ref_internal(resource_quota); + } + grpc_resource_user* ru = nullptr; + ru = grpc_resource_user_create( + resource_quota, absl::StrFormat("mock_resource_user_%" PRIxPTR, + reinterpret_cast(&ru)) + .c_str()); + grpc_resource_quota_unref_internal(resource_quota); + return ru; +} + +grpc_slice_allocator* grpc_slice_allocator_create_unlimited() { + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("anonymous mock quota"); + grpc_slice_allocator* slice_allocator = grpc_slice_allocator_create( + resource_quota, + absl::StrFormat("mock_resource_user_from_quota:%p", resource_quota)); + grpc_resource_quota_unref(resource_quota); + return slice_allocator; +} diff --git a/test/core/util/slice_splitter.cc b/test/core/util/slice_splitter.cc new file mode 100644 index 00000000..82864d6a --- /dev/null +++ b/test/core/util/slice_splitter.cc @@ -0,0 +1,128 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/slice_splitter.h" + +#include + +#include + +#include + +#include "src/core/lib/gpr/useful.h" + +const char* grpc_slice_split_mode_name(grpc_slice_split_mode mode) { + switch (mode) { + case GRPC_SLICE_SPLIT_IDENTITY: + return "identity"; + case GRPC_SLICE_SPLIT_MERGE_ALL: + return "merge_all"; + case GRPC_SLICE_SPLIT_ONE_BYTE: + return "one_byte"; + } + return "error"; +} + +void grpc_split_slices(grpc_slice_split_mode mode, grpc_slice* src_slices, + size_t src_slice_count, grpc_slice** dst_slices, + size_t* dst_slice_count) { + size_t i, j; + size_t length; + + switch (mode) { + case GRPC_SLICE_SPLIT_IDENTITY: + *dst_slice_count = src_slice_count; + *dst_slices = static_cast( + gpr_malloc(sizeof(grpc_slice) * src_slice_count)); + for (i = 0; i < src_slice_count; i++) { + (*dst_slices)[i] = src_slices[i]; + grpc_slice_ref((*dst_slices)[i]); + } + break; + case GRPC_SLICE_SPLIT_MERGE_ALL: + *dst_slice_count = 1; + length = 0; + for (i = 0; i < src_slice_count; i++) { + length += GRPC_SLICE_LENGTH(src_slices[i]); + } + *dst_slices = static_cast(gpr_malloc(sizeof(grpc_slice))); + **dst_slices = grpc_slice_malloc(length); + length = 0; + for (i = 0; i < src_slice_count; i++) { + memcpy(GRPC_SLICE_START_PTR(**dst_slices) + length, + GRPC_SLICE_START_PTR(src_slices[i]), + GRPC_SLICE_LENGTH(src_slices[i])); + length += GRPC_SLICE_LENGTH(src_slices[i]); + } + break; + case GRPC_SLICE_SPLIT_ONE_BYTE: + length = 0; + for (i = 0; i < src_slice_count; i++) { + length += GRPC_SLICE_LENGTH(src_slices[i]); + } + *dst_slice_count = length; + *dst_slices = + static_cast(gpr_malloc(sizeof(grpc_slice) * length)); + length = 0; + for (i = 0; i < src_slice_count; i++) { + for (j = 0; j < GRPC_SLICE_LENGTH(src_slices[i]); j++) { + (*dst_slices)[length] = grpc_slice_sub(src_slices[i], j, j + 1); + length++; + } + } + break; + } +} + +void grpc_split_slices_to_buffer(grpc_slice_split_mode mode, + grpc_slice* src_slices, size_t src_slice_count, + grpc_slice_buffer* dst) { + grpc_slice* slices; + size_t nslices; + size_t i; + grpc_split_slices(mode, src_slices, src_slice_count, &slices, &nslices); + for (i = 0; i < nslices; i++) { + /* add indexed to avoid re-merging split slices */ + grpc_slice_buffer_add_indexed(dst, slices[i]); + } + gpr_free(slices); +} + +void grpc_split_slice_buffer(grpc_slice_split_mode mode, grpc_slice_buffer* src, + grpc_slice_buffer* dst) { + grpc_split_slices_to_buffer(mode, src->slices, src->count, dst); +} + +grpc_slice grpc_slice_merge(grpc_slice* slices, size_t nslices) { + uint8_t* out = nullptr; + size_t length = 0; + size_t capacity = 0; + size_t i; + + for (i = 0; i < nslices; i++) { + if (GRPC_SLICE_LENGTH(slices[i]) + length > capacity) { + capacity = std::max(capacity * 2, GRPC_SLICE_LENGTH(slices[i]) + length); + out = static_cast(gpr_realloc(out, capacity)); + } + memcpy(out + length, GRPC_SLICE_START_PTR(slices[i]), + GRPC_SLICE_LENGTH(slices[i])); + length += GRPC_SLICE_LENGTH(slices[i]); + } + + return grpc_slice_new(out, length, gpr_free); +} diff --git a/test/core/util/stack_tracer.cc 
b/test/core/util/stack_tracer.cc new file mode 100644 index 00000000..f611ff1e --- /dev/null +++ b/test/core/util/stack_tracer.cc @@ -0,0 +1,110 @@ +/* + * + * Copyright 2020 the gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "test/core/util/stack_tracer.h" + +#include +#include + +#include "absl/debugging/stacktrace.h" +#include "absl/debugging/symbolize.h" + +#include "src/core/lib/gprpp/examine_stack.h" + +namespace { + +static constexpr int kPrintfPointerFieldWidth = 2 + 2 * sizeof(void*); + +static void DumpPCAndFrameSizeAndSymbol(void (*writerfn)(const char*, void*), + void* writerfn_arg, void* pc, + void* symbolize_pc, int framesize, + const char* const prefix) { + char tmp[1024]; + const char* symbol = "(unknown)"; + if (absl::Symbolize(symbolize_pc, tmp, sizeof(tmp))) { + symbol = tmp; + } + char buf[1024]; + if (framesize <= 0) { + snprintf(buf, sizeof(buf), "%s@ %*p (unknown) %s\n", prefix, + kPrintfPointerFieldWidth, pc, symbol); + } else { + snprintf(buf, sizeof(buf), "%s@ %*p %9d %s\n", prefix, + kPrintfPointerFieldWidth, pc, framesize, symbol); + } + writerfn(buf, writerfn_arg); +} + +static void DumpPCAndFrameSize(void (*writerfn)(const char*, void*), + void* writerfn_arg, void* pc, int framesize, + const char* const prefix) { + char buf[100]; + if (framesize <= 0) { + snprintf(buf, sizeof(buf), "%s@ %*p (unknown)\n", prefix, + kPrintfPointerFieldWidth, pc); + } else { + snprintf(buf, sizeof(buf), "%s@ %*p %9d\n", prefix, + kPrintfPointerFieldWidth, pc, framesize); + } + writerfn(buf, writerfn_arg); +} + +static void DumpStackTrace(void* const stack[], int frame_sizes[], int depth, + bool symbolize_stacktrace, + void (*writerfn)(const char*, void*), + void* writerfn_arg) { + for (int i = 0; i < depth; i++) { + if (symbolize_stacktrace) { + DumpPCAndFrameSizeAndSymbol(writerfn, writerfn_arg, stack[i], + reinterpret_cast(stack[i]) - 1, + frame_sizes[i], " "); + } else { + DumpPCAndFrameSize(writerfn, writerfn_arg, stack[i], frame_sizes[i], + " "); + } + } +} + +static void DebugWriteToString(const char* data, void* str) { + reinterpret_cast(str)->append(data); +} + +} // namespace + +namespace grpc_core { +namespace testing { + +std::string GetCurrentStackTrace() { + std::string result = "Stack trace:\n"; + constexpr int kNumStackFrames = 32; + void* stack[kNumStackFrames]; + int frame_sizes[kNumStackFrames]; + int depth = absl::GetStackFrames(stack, frame_sizes, kNumStackFrames, 1); + DumpStackTrace(stack, frame_sizes, depth, true, DebugWriteToString, &result); + return result; +} + +void InitializeStackTracer(const char* argv0) { + absl::InitializeSymbolizer(argv0); + grpc_core::SetCurrentStackTraceProvider(&GetCurrentStackTrace); +} + +} // namespace testing +} // namespace grpc_core diff --git a/test/core/util/stack_tracer_test.cc b/test/core/util/stack_tracer_test.cc new file mode 100644 index 00000000..3687fdde --- /dev/null +++ b/test/core/util/stack_tracer_test.cc @@ -0,0 +1,45 @@ +/* + * + * Copyright 2020 the 
gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/stack_tracer.h" + +#include + +#include + +#include "absl/debugging/symbolize.h" +#include "absl/strings/match.h" + +#include + +#include "test/core/util/test_config.h" + +TEST(StackTracerTest, Basic) { + std::string stack_trace = grpc_core::testing::GetCurrentStackTrace(); + gpr_log(GPR_INFO, "stack_trace=%s", stack_trace.c_str()); +#if !defined(NDEBUG) && !defined(GPR_MUSL_LIBC_COMPAT) + EXPECT_TRUE(absl::StrContains(stack_trace, "Basic")); +#endif +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/core/util/subprocess_posix.cc b/test/core/util/subprocess_posix.cc new file mode 100644 index 00000000..61363142 --- /dev/null +++ b/test/core/util/subprocess_posix.cc @@ -0,0 +1,100 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#ifdef GPR_POSIX_SUBPROCESS + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "src/core/lib/gprpp/memory.h" +#include "test/core/util/subprocess.h" + +struct gpr_subprocess { + int pid; + bool joined; +}; + +const char* gpr_subprocess_binary_extension() { return ""; } + +gpr_subprocess* gpr_subprocess_create(int argc, const char** argv) { + gpr_subprocess* r; + int pid; + char** exec_args; + + pid = fork(); + if (pid == -1) { + return nullptr; + } else if (pid == 0) { + exec_args = static_cast( + gpr_malloc((static_cast(argc) + 1) * sizeof(char*))); + memcpy(exec_args, argv, static_cast(argc) * sizeof(char*)); + exec_args[argc] = nullptr; + execv(exec_args[0], exec_args); + /* if we reach here, an error has occurred */ + gpr_log(GPR_ERROR, "execv '%s' failed: %s", exec_args[0], strerror(errno)); + _exit(1); + } else { + r = grpc_core::Zalloc(); + r->pid = pid; + return r; + } +} + +void gpr_subprocess_destroy(gpr_subprocess* p) { + if (!p->joined) { + kill(p->pid, SIGKILL); + gpr_subprocess_join(p); + } + gpr_free(p); +} + +int gpr_subprocess_join(gpr_subprocess* p) { + int status; +retry: + if (waitpid(p->pid, &status, 0) == -1) { + if (errno == EINTR) { + goto retry; + } + gpr_log(GPR_ERROR, "waitpid failed for pid %d: %s", p->pid, + strerror(errno)); + return -1; + } + p->joined = true; + return status; +} + +void gpr_subprocess_interrupt(gpr_subprocess* p) { + if (!p->joined) { + kill(p->pid, SIGINT); + } +} + +#endif /* GPR_POSIX_SUBPROCESS */ diff --git a/test/core/util/subprocess_windows.cc b/test/core/util/subprocess_windows.cc new file mode 100644 index 00000000..7d69af13 --- /dev/null +++ b/test/core/util/subprocess_windows.cc @@ -0,0 +1,127 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#ifdef GPR_WINDOWS_SUBPROCESS + +#include +#include +#include + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/string_windows.h" +#include "test/core/util/subprocess.h" + +struct gpr_subprocess { + PROCESS_INFORMATION pi; + int joined; + int interrupted; +}; + +const char* gpr_subprocess_binary_extension() { return ".exe"; } + +gpr_subprocess* gpr_subprocess_create(int argc, const char** argv) { + gpr_subprocess* r; + + STARTUPINFO si; + PROCESS_INFORMATION pi; + + char* args = gpr_strjoin_sep(argv, (size_t)argc, " ", NULL); + TCHAR* args_tchar; + + args_tchar = gpr_char_to_tchar(args); + gpr_free(args); + + memset(&si, 0, sizeof(si)); + si.cb = sizeof(si); + memset(&pi, 0, sizeof(pi)); + + if (!CreateProcess(NULL, args_tchar, NULL, NULL, FALSE, + CREATE_NEW_PROCESS_GROUP, NULL, NULL, &si, &pi)) { + gpr_free(args_tchar); + return NULL; + } + gpr_free(args_tchar); + + r = (gpr_subprocess*)gpr_malloc(sizeof(gpr_subprocess)); + memset(r, 0, sizeof(*r)); + r->pi = pi; + return r; +} + +void gpr_subprocess_destroy(gpr_subprocess* p) { + if (p) { + if (!p->joined) { + gpr_subprocess_interrupt(p); + gpr_subprocess_join(p); + } + if (p->pi.hProcess) { + CloseHandle(p->pi.hProcess); + } + if (p->pi.hThread) { + CloseHandle(p->pi.hThread); + } + gpr_free(p); + } +} + +int gpr_subprocess_join(gpr_subprocess* p) { + DWORD dwExitCode; + if (GetExitCodeProcess(p->pi.hProcess, &dwExitCode)) { + if (dwExitCode == STILL_ACTIVE) { + if (WaitForSingleObject(p->pi.hProcess, INFINITE) == WAIT_OBJECT_0) { + p->joined = 1; + goto getExitCode; + } + return -1; // failed to join + } else { + goto getExitCode; + } + } else { + return -1; // failed to get exit code + } + +getExitCode: + if (p->interrupted) { + return 0; + } + if (GetExitCodeProcess(p->pi.hProcess, &dwExitCode)) { + return (int)dwExitCode; + } else { + return -1; // failed to get exit code + } +} + +void gpr_subprocess_interrupt(gpr_subprocess* p) { + DWORD dwExitCode; + if (GetExitCodeProcess(p->pi.hProcess, &dwExitCode)) { + if (dwExitCode == STILL_ACTIVE) { + gpr_log(GPR_INFO, "sending ctrl-break"); + GenerateConsoleCtrlEvent(CTRL_BREAK_EVENT, p->pi.dwProcessId); + p->joined = 1; + p->interrupted = 1; + } + } + return; +} + +#endif /* GPR_WINDOWS_SUBPROCESS */ diff --git a/test/core/util/test_config.cc b/test/core/util/test_config.cc new file mode 100644 index 00000000..09952c0b --- /dev/null +++ b/test/core/util/test_config.cc @@ -0,0 +1,219 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/core/util/test_config.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/debugging/failure_signal_handler.h" +#include "absl/debugging/symbolize.h" + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/gprpp/examine_stack.h" +#include "src/core/lib/surface/init.h" +#include "test/core/util/stack_tracer.h" + +int64_t g_fixture_slowdown_factor = 1; +int64_t g_poller_slowdown_factor = 1; + +#if GPR_GETPID_IN_UNISTD_H +#include +static unsigned seed(void) { return static_cast(getpid()); } +#endif + +#if GPR_GETPID_IN_PROCESS_H +#include +static unsigned seed(void) { return (unsigned)_getpid(); } +#endif + +bool BuiltUnderValgrind() { +#ifdef RUNNING_ON_VALGRIND + return true; +#else + return false; +#endif +} + +bool BuiltUnderTsan() { +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) + return true; +#else + return false; +#endif +#else +#ifdef THREAD_SANITIZER + return true; +#else + return false; +#endif +#endif +} + +bool BuiltUnderAsan() { +#if defined(__has_feature) +#if __has_feature(address_sanitizer) + return true; +#else + return false; +#endif +#else +#ifdef ADDRESS_SANITIZER + return true; +#else + return false; +#endif +#endif +} + +bool BuiltUnderMsan() { +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) + return true; +#else + return false; +#endif +#else +#ifdef MEMORY_SANITIZER + return true; +#else + return false; +#endif +#endif +} + +bool BuiltUnderUbsan() { +#ifdef GRPC_UBSAN + return true; +#else + return false; +#endif +} + +int64_t grpc_test_sanitizer_slowdown_factor() { + int64_t sanitizer_multiplier = 1; + if (BuiltUnderValgrind()) { + sanitizer_multiplier = 20; + } else if (BuiltUnderTsan()) { + sanitizer_multiplier = 5; + } else if (BuiltUnderAsan()) { + sanitizer_multiplier = 3; + } else if (BuiltUnderMsan()) { + sanitizer_multiplier = 4; + } else if (BuiltUnderUbsan()) { + sanitizer_multiplier = 5; + } + return sanitizer_multiplier; +} + +int64_t grpc_test_slowdown_factor() { + return grpc_test_sanitizer_slowdown_factor() * g_fixture_slowdown_factor * + g_poller_slowdown_factor; +} + +gpr_timespec grpc_timeout_seconds_to_deadline(int64_t time_s) { + return gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis( + grpc_test_slowdown_factor() * static_cast(1e3) * time_s, + GPR_TIMESPAN)); +} + +gpr_timespec grpc_timeout_milliseconds_to_deadline(int64_t time_ms) { + return gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros( + grpc_test_slowdown_factor() * static_cast(1e3) * time_ms, + GPR_TIMESPAN)); +} + +void grpc_test_init(int /*argc*/, char** argv) { + grpc_core::testing::InitializeStackTracer(argv[0]); + absl::FailureSignalHandlerOptions options; + absl::InstallFailureSignalHandler(options); + gpr_log_verbosity_init(); + gpr_log(GPR_DEBUG, + "test slowdown factor: sanitizer=%" PRId64 ", fixture=%" PRId64 + ", poller=%" PRId64 ", total=%" PRId64, + grpc_test_sanitizer_slowdown_factor(), g_fixture_slowdown_factor, + g_poller_slowdown_factor, grpc_test_slowdown_factor()); + /* seed rng with pid, so we don't end up with the same random numbers as a + concurrently running test binary */ + srand(seed()); +} + +bool grpc_wait_until_shutdown(int64_t time_s) { + gpr_timespec deadline = grpc_timeout_seconds_to_deadline(time_s); + while (grpc_is_initialized()) { + grpc_maybe_wait_for_async_shutdown(); + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + 
gpr_time_from_millis(1, GPR_TIMESPAN))); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), deadline) > 0) { + return false; + } + } + return true; +} + +namespace grpc { +namespace testing { + +TestEnvironment::TestEnvironment(int argc, char** argv) { + grpc_test_init(argc, argv); +} + +TestEnvironment::~TestEnvironment() { + // This will wait until gRPC shutdown has actually happened to make sure + // no gRPC resources (such as thread) are active. (timeout = 10s) + if (!grpc_wait_until_shutdown(10)) { + gpr_log(GPR_ERROR, "Timeout in waiting for gRPC shutdown"); + } + if (BuiltUnderMsan()) { + // This is a workaround for MSAN. MSAN doesn't like having shutdown thread + // running. Although the code above waits until shutdown is done, chances + // are that thread itself is still alive. To workaround this problem, this + // is going to wait for 0.5 sec to give a chance to the shutdown thread to + // exit. https://github.com/grpc/grpc/issues/23695 + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(500, GPR_TIMESPAN))); + } + gpr_log(GPR_INFO, "TestEnvironment ends"); +} + +TestGrpcScope::TestGrpcScope() { grpc_init(); } + +TestGrpcScope::~TestGrpcScope() { + grpc_shutdown(); + if (!grpc_wait_until_shutdown(10)) { + gpr_log(GPR_ERROR, "Timeout in waiting for gRPC shutdown"); + } +} + +} // namespace testing +} // namespace grpc diff --git a/test/core/util/test_lb_policies.cc b/test/core/util/test_lb_policies.cc new file mode 100644 index 00000000..03b6ef09 --- /dev/null +++ b/test/core/util/test_lb_policies.cc @@ -0,0 +1,556 @@ +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "test/core/util/test_lb_policies.h" + +#include + +#include + +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channelz.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/json/json.h" +#include "src/core/lib/json/json_util.h" +#include "src/core/lib/transport/connectivity_state.h" + +namespace grpc_core { + +namespace { + +// +// ForwardingLoadBalancingPolicy +// + +// A minimal forwarding class to avoid implementing a standalone test LB. 
+class ForwardingLoadBalancingPolicy : public LoadBalancingPolicy { + public: + ForwardingLoadBalancingPolicy( + std::unique_ptr delegating_helper, Args args, + const char* delegate_policy_name, intptr_t initial_refcount = 1) + : LoadBalancingPolicy(std::move(args), initial_refcount) { + Args delegate_args; + delegate_args.work_serializer = work_serializer(); + delegate_args.channel_control_helper = std::move(delegating_helper); + delegate_args.args = args.args; + delegate_ = LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( + delegate_policy_name, std::move(delegate_args)); + grpc_pollset_set_add_pollset_set(delegate_->interested_parties(), + interested_parties()); + } + + ~ForwardingLoadBalancingPolicy() override = default; + + void UpdateLocked(UpdateArgs args) override { + delegate_->UpdateLocked(std::move(args)); + } + + void ExitIdleLocked() override { delegate_->ExitIdleLocked(); } + + void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); } + + private: + void ShutdownLocked() override { delegate_.reset(); } + + OrphanablePtr delegate_; +}; + +// +// TestPickArgsLb +// + +constexpr char kTestPickArgsLbPolicyName[] = "test_pick_args_lb"; + +class TestPickArgsLb : public ForwardingLoadBalancingPolicy { + public: + TestPickArgsLb(Args args, TestPickArgsCallback cb, + const char* delegate_policy_name) + : ForwardingLoadBalancingPolicy( + absl::make_unique(RefCountedPtr(this), cb), + std::move(args), delegate_policy_name, + /*initial_refcount=*/2) {} + + ~TestPickArgsLb() override = default; + + const char* name() const override { return kTestPickArgsLbPolicyName; } + + private: + class Picker : public SubchannelPicker { + public: + Picker(std::unique_ptr delegate_picker, + TestPickArgsCallback cb) + : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {} + + PickResult Pick(PickArgs args) override { + // Report args seen. + PickArgsSeen args_seen; + args_seen.path = std::string(args.path); + args_seen.metadata = args.initial_metadata->TestOnlyCopyToVector(); + cb_(args_seen); + // Do pick. 
+ return delegate_picker_->Pick(args); + } + + private: + std::unique_ptr delegate_picker_; + TestPickArgsCallback cb_; + }; + + class Helper : public ChannelControlHelper { + public: + Helper(RefCountedPtr parent, TestPickArgsCallback cb) + : parent_(std::move(parent)), cb_(std::move(cb)) {} + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override { + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); + } + + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override { + parent_->channel_control_helper()->UpdateState( + state, status, absl::make_unique(std::move(picker), cb_)); + } + + void RequestReresolution() override { + parent_->channel_control_helper()->RequestReresolution(); + } + + absl::string_view GetAuthority() override { + return parent_->channel_control_helper()->GetAuthority(); + } + + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override { + parent_->channel_control_helper()->AddTraceEvent(severity, message); + } + + private: + RefCountedPtr parent_; + TestPickArgsCallback cb_; + }; +}; + +class TestPickArgsLbConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kTestPickArgsLbPolicyName; } +}; + +class TestPickArgsLbFactory : public LoadBalancingPolicyFactory { + public: + explicit TestPickArgsLbFactory(TestPickArgsCallback cb, + const char* delegate_policy_name) + : cb_(std::move(cb)), delegate_policy_name_(delegate_policy_name) {} + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args), cb_, + delegate_policy_name_); + } + + const char* name() const override { return kTestPickArgsLbPolicyName; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error_handle* /*error*/) const override { + return MakeRefCounted(); + } + + private: + TestPickArgsCallback cb_; + const char* delegate_policy_name_; +}; + +// +// InterceptRecvTrailingMetadataLoadBalancingPolicy +// + +constexpr char kInterceptRecvTrailingMetadataLbPolicyName[] = + "intercept_trailing_metadata_lb"; + +class InterceptRecvTrailingMetadataLoadBalancingPolicy + : public ForwardingLoadBalancingPolicy { + public: + InterceptRecvTrailingMetadataLoadBalancingPolicy( + Args args, InterceptRecvTrailingMetadataCallback cb) + : ForwardingLoadBalancingPolicy( + absl::make_unique( + RefCountedPtr( + this), + std::move(cb)), + std::move(args), + /*delegate_policy_name=*/"pick_first", + /*initial_refcount=*/2) {} + + ~InterceptRecvTrailingMetadataLoadBalancingPolicy() override = default; + + const char* name() const override { + return kInterceptRecvTrailingMetadataLbPolicyName; + } + + private: + class Picker : public SubchannelPicker { + public: + Picker(std::unique_ptr delegate_picker, + InterceptRecvTrailingMetadataCallback cb) + : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {} + + PickResult Pick(PickArgs args) override { + // Do pick. + PickResult result = delegate_picker_->Pick(args); + // Intercept trailing metadata. 
+ auto* complete_pick = absl::get_if(&result.result); + if (complete_pick != nullptr) { + new (args.call_state->Alloc(sizeof(TrailingMetadataHandler))) + TrailingMetadataHandler(complete_pick, cb_); + } + return result; + } + + private: + std::unique_ptr delegate_picker_; + InterceptRecvTrailingMetadataCallback cb_; + }; + + class Helper : public ChannelControlHelper { + public: + Helper( + RefCountedPtr parent, + InterceptRecvTrailingMetadataCallback cb) + : parent_(std::move(parent)), cb_(std::move(cb)) {} + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override { + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); + } + + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override { + parent_->channel_control_helper()->UpdateState( + state, status, absl::make_unique(std::move(picker), cb_)); + } + + void RequestReresolution() override { + parent_->channel_control_helper()->RequestReresolution(); + } + + absl::string_view GetAuthority() override { + return parent_->channel_control_helper()->GetAuthority(); + } + + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override { + parent_->channel_control_helper()->AddTraceEvent(severity, message); + } + + private: + RefCountedPtr parent_; + InterceptRecvTrailingMetadataCallback cb_; + }; + + class TrailingMetadataHandler { + public: + TrailingMetadataHandler(PickResult::Complete* result, + InterceptRecvTrailingMetadataCallback cb) + : cb_(std::move(cb)) { + result->recv_trailing_metadata_ready = [this](absl::Status /*status*/, + MetadataInterface* metadata, + CallState* call_state) { + RecordRecvTrailingMetadata(metadata, call_state); + }; + } + + private: + void RecordRecvTrailingMetadata(MetadataInterface* recv_trailing_metadata, + CallState* call_state) { + TrailingMetadataArgsSeen args_seen; + args_seen.backend_metric_data = call_state->GetBackendMetricData(); + GPR_ASSERT(recv_trailing_metadata != nullptr); + args_seen.metadata = recv_trailing_metadata->TestOnlyCopyToVector(); + cb_(args_seen); + this->~TrailingMetadataHandler(); + } + + InterceptRecvTrailingMetadataCallback cb_; + }; +}; + +class InterceptTrailingConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { + return kInterceptRecvTrailingMetadataLbPolicyName; + } +}; + +class InterceptTrailingFactory : public LoadBalancingPolicyFactory { + public: + explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb) + : cb_(std::move(cb)) {} + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable( + std::move(args), cb_); + } + + const char* name() const override { + return kInterceptRecvTrailingMetadataLbPolicyName; + } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error_handle* /*error*/) const override { + return MakeRefCounted(); + } + + private: + InterceptRecvTrailingMetadataCallback cb_; +}; + +// +// AddressTestLoadBalancingPolicy +// + +constexpr char kAddressTestLbPolicyName[] = "address_test_lb"; + +class AddressTestLoadBalancingPolicy : public ForwardingLoadBalancingPolicy { + public: + AddressTestLoadBalancingPolicy(Args args, AddressTestCallback cb) + : ForwardingLoadBalancingPolicy( + absl::make_unique( + RefCountedPtr(this), + std::move(cb)), + std::move(args), + /*delegate_policy_name=*/"pick_first", + /*initial_refcount=*/2) {} + + ~AddressTestLoadBalancingPolicy() 
override = default; + + const char* name() const override { return kAddressTestLbPolicyName; } + + private: + class Helper : public ChannelControlHelper { + public: + Helper(RefCountedPtr parent, + AddressTestCallback cb) + : parent_(std::move(parent)), cb_(std::move(cb)) {} + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override { + cb_(address); + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); + } + + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override { + parent_->channel_control_helper()->UpdateState(state, status, + std::move(picker)); + } + + void RequestReresolution() override { + parent_->channel_control_helper()->RequestReresolution(); + } + + absl::string_view GetAuthority() override { + return parent_->channel_control_helper()->GetAuthority(); + } + + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override { + parent_->channel_control_helper()->AddTraceEvent(severity, message); + } + + private: + RefCountedPtr parent_; + AddressTestCallback cb_; + }; +}; + +class AddressTestConfig : public LoadBalancingPolicy::Config { + public: + const char* name() const override { return kAddressTestLbPolicyName; } +}; + +class AddressTestFactory : public LoadBalancingPolicyFactory { + public: + explicit AddressTestFactory(AddressTestCallback cb) : cb_(std::move(cb)) {} + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args), cb_); + } + + const char* name() const override { return kAddressTestLbPolicyName; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& /*json*/, grpc_error_handle* /*error*/) const override { + return MakeRefCounted(); + } + + private: + AddressTestCallback cb_; +}; + +// +// FixedAddressLoadBalancingPolicy +// + +constexpr char kFixedAddressLbPolicyName[] = "fixed_address_lb"; + +class FixedAddressConfig : public LoadBalancingPolicy::Config { + public: + explicit FixedAddressConfig(std::string address) + : address_(std::move(address)) {} + + const char* name() const override { return kFixedAddressLbPolicyName; } + + const std::string& address() const { return address_; } + + private: + std::string address_; +}; + +class FixedAddressLoadBalancingPolicy : public ForwardingLoadBalancingPolicy { + public: + explicit FixedAddressLoadBalancingPolicy(Args args) + : ForwardingLoadBalancingPolicy( + absl::make_unique( + RefCountedPtr(this)), + std::move(args), + /*delegate_policy_name=*/"pick_first", + /*initial_refcount=*/2) {} + + ~FixedAddressLoadBalancingPolicy() override = default; + + const char* name() const override { return kFixedAddressLbPolicyName; } + + void UpdateLocked(UpdateArgs args) override { + auto* config = static_cast(args.config.get()); + gpr_log(GPR_INFO, "%s: update URI: %s", kFixedAddressLbPolicyName, + config->address().c_str()); + auto uri = URI::Parse(config->address()); + args.config.reset(); + args.addresses.clear(); + if (uri.ok()) { + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*uri, &address)); + args.addresses.emplace_back(address, /*args=*/nullptr); + } else { + gpr_log(GPR_ERROR, + "%s: could not parse URI (%s), using empty address list", + kFixedAddressLbPolicyName, uri.status().ToString().c_str()); + } + ForwardingLoadBalancingPolicy::UpdateLocked(std::move(args)); + } + + private: + class Helper : public ChannelControlHelper { + public: + explicit Helper(RefCountedPtr 
parent) + : parent_(std::move(parent)) {} + + RefCountedPtr CreateSubchannel( + ServerAddress address, const grpc_channel_args& args) override { + return parent_->channel_control_helper()->CreateSubchannel( + std::move(address), args); + } + + void UpdateState(grpc_connectivity_state state, const absl::Status& status, + std::unique_ptr picker) override { + parent_->channel_control_helper()->UpdateState(state, status, + std::move(picker)); + } + + void RequestReresolution() override { + parent_->channel_control_helper()->RequestReresolution(); + } + + absl::string_view GetAuthority() override { + return parent_->channel_control_helper()->GetAuthority(); + } + + void AddTraceEvent(TraceSeverity severity, + absl::string_view message) override { + parent_->channel_control_helper()->AddTraceEvent(severity, message); + } + + private: + RefCountedPtr parent_; + }; +}; + +class FixedAddressFactory : public LoadBalancingPolicyFactory { + public: + FixedAddressFactory() = default; + + OrphanablePtr CreateLoadBalancingPolicy( + LoadBalancingPolicy::Args args) const override { + return MakeOrphanable(std::move(args)); + } + + const char* name() const override { return kFixedAddressLbPolicyName; } + + RefCountedPtr ParseLoadBalancingConfig( + const Json& json, grpc_error_handle* error) const override { + std::vector error_list; + std::string address; + ParseJsonObjectField(json.object_value(), "address", &address, &error_list); + if (!error_list.empty()) { + *error = GRPC_ERROR_CREATE_FROM_VECTOR( + "errors parsing fixed_address_lb config", &error_list); + return nullptr; + } + return MakeRefCounted(std::move(address)); + } +}; + +} // namespace + +void RegisterTestPickArgsLoadBalancingPolicy(TestPickArgsCallback cb, + const char* delegate_policy_name) { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique(std::move(cb), + delegate_policy_name)); +} + +void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( + InterceptRecvTrailingMetadataCallback cb) { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique(std::move(cb))); +} + +void RegisterAddressTestLoadBalancingPolicy(AddressTestCallback cb) { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique(std::move(cb))); +} + +void RegisterFixedAddressLoadBalancingPolicy() { + LoadBalancingPolicyRegistry::Builder::RegisterLoadBalancingPolicyFactory( + absl::make_unique()); +} + +} // namespace grpc_core diff --git a/test/core/util/test_tcp_server.cc b/test/core/util/test_tcp_server.cc new file mode 100644 index 00000000..048d2f9d --- /dev/null +++ b/test/core/util/test_tcp_server.cc @@ -0,0 +1,119 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
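// Hypothetical usage sketch (not part of the diff above): registering the
// test-only fixed_address_lb policy and selecting it through a channel's
// service config. The target and backend address are invented; registration
// would normally happen once near process start-up, before channels are built.
#include <memory>

#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/channel_arguments.h>

#include "test/core/util/test_lb_policies.h"

std::shared_ptr<grpc::Channel> MakeFixedAddressChannelSketch() {
  grpc_core::RegisterFixedAddressLoadBalancingPolicy();
  grpc::ChannelArguments args;
  // The "address" field is what FixedAddressFactory::ParseLoadBalancingConfig
  // reads above; every pick is then pinned to that single backend.
  args.SetServiceConfigJSON(
      "{\"loadBalancingConfig\":"
      "[{\"fixed_address_lb\":{\"address\":\"ipv4:127.0.0.1:1234\"}}]}");
  return grpc::CreateCustomChannel("dns:///example.test:443",
                                   grpc::InsecureChannelCredentials(), args);
}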
+ * + */ + +#include "test/core/util/test_tcp_server.h" + +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +static void on_server_destroyed(void* data, grpc_error_handle /*error*/) { + test_tcp_server* server = static_cast(data); + server->shutdown = true; +} + +void test_tcp_server_init(test_tcp_server* server, + grpc_tcp_server_cb on_connect, void* user_data) { + grpc_init(); + GRPC_CLOSURE_INIT(&server->shutdown_complete, on_server_destroyed, server, + grpc_schedule_on_exec_ctx); + + grpc_pollset* pollset = + static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &server->mu); + server->pollset.push_back(pollset); + server->on_connect = on_connect; + server->cb_data = user_data; +} + +void test_tcp_server_start(test_tcp_server* server, int port) { + grpc_resolved_address resolved_addr; + grpc_sockaddr_in* addr = + reinterpret_cast(resolved_addr.addr); + int port_added; + grpc_core::ExecCtx exec_ctx; + + addr->sin_family = GRPC_AF_INET; + addr->sin_port = grpc_htons(static_cast(port)); + memset(&addr->sin_addr, 0, sizeof(addr->sin_addr)); + resolved_addr.len = static_cast(sizeof(grpc_sockaddr_in)); + + grpc_error_handle error = grpc_tcp_server_create( + &server->shutdown_complete, nullptr, + grpc_slice_allocator_factory_create(grpc_resource_quota_create(nullptr)), + &server->tcp_server); + GPR_ASSERT(error == GRPC_ERROR_NONE); + error = + grpc_tcp_server_add_port(server->tcp_server, &resolved_addr, &port_added); + GPR_ASSERT(error == GRPC_ERROR_NONE); + GPR_ASSERT(port_added == port); + + grpc_tcp_server_start(server->tcp_server, &server->pollset, + server->on_connect, server->cb_data); + gpr_log(GPR_INFO, "test tcp server listening on 0.0.0.0:%d", port); +} + +void test_tcp_server_poll(test_tcp_server* server, int milliseconds) { + grpc_pollset_worker* worker = nullptr; + grpc_core::ExecCtx exec_ctx; + grpc_millis deadline = grpc_timespec_to_millis_round_up( + grpc_timeout_milliseconds_to_deadline(milliseconds)); + gpr_mu_lock(server->mu); + GRPC_LOG_IF_ERROR("pollset_work", + grpc_pollset_work(server->pollset[0], &worker, deadline)); + gpr_mu_unlock(server->mu); +} + +static void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} +static void finish_pollset(void* arg, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(arg)); +} + +void test_tcp_server_destroy(test_tcp_server* server) { + grpc_core::ExecCtx exec_ctx; + gpr_timespec shutdown_deadline; + grpc_closure do_nothing_cb; + grpc_tcp_server_unref(server->tcp_server); + GRPC_CLOSURE_INIT(&do_nothing_cb, do_nothing, nullptr, + grpc_schedule_on_exec_ctx); + shutdown_deadline = gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(5, GPR_TIMESPAN)); + grpc_core::ExecCtx::Get()->Flush(); + while (!server->shutdown && + gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), shutdown_deadline) < 0) { + test_tcp_server_poll(server, 1000); + } + grpc_pollset_shutdown(server->pollset[0], + GRPC_CLOSURE_CREATE(finish_pollset, server->pollset[0], + grpc_schedule_on_exec_ctx)); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(server->pollset[0]); + grpc_shutdown(); +} diff --git a/test/core/util/tls_utils.cc b/test/core/util/tls_utils.cc new file mode 100644 index 00000000..e70cdcf1 --- 
/dev/null +++ b/test/core/util/tls_utils.cc @@ -0,0 +1,76 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "test/core/util/tls_utils.h" + +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/slice/slice_internal.h" + +namespace grpc_core { + +namespace testing { + +TmpFile::TmpFile(absl::string_view credential_data) { + name_ = CreateTmpFileAndWriteData(credential_data); + GPR_ASSERT(!name_.empty()); +} + +TmpFile::~TmpFile() { GPR_ASSERT(remove(name_.c_str()) == 0); } + +void TmpFile::RewriteFile(absl::string_view credential_data) { + // Create a new file containing new data. + std::string new_name = CreateTmpFileAndWriteData(credential_data); + GPR_ASSERT(!new_name.empty()); + // Remove the old file. + GPR_ASSERT(remove(name_.c_str()) == 0); + // Rename the new file to the original name. + GPR_ASSERT(rename(new_name.c_str(), name_.c_str()) == 0); +} + +std::string TmpFile::CreateTmpFileAndWriteData( + absl::string_view credential_data) { + char* name = nullptr; + FILE* file_descriptor = gpr_tmpfile("GrpcTlsCertificateProviderTest", &name); + GPR_ASSERT(fwrite(credential_data.data(), 1, credential_data.size(), + file_descriptor) == credential_data.size()); + GPR_ASSERT(fclose(file_descriptor) == 0); + GPR_ASSERT(file_descriptor != nullptr); + GPR_ASSERT(name != nullptr); + std::string name_to_return = name; + gpr_free(name); + return name_to_return; +} + +PemKeyCertPairList MakeCertKeyPairs(absl::string_view private_key, + absl::string_view certs) { + if (private_key.empty() && certs.empty()) { + return {}; + } + return PemKeyCertPairList{PemKeyCertPair(private_key, certs)}; +} + +std::string GetFileContents(const char* path) { + grpc_slice slice = grpc_empty_slice(); + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", grpc_load_file(path, 0, &slice))); + std::string credential = std::string(StringViewFromSlice(slice)); + grpc_slice_unref(slice); + return credential; +} + +} // namespace testing + +} // namespace grpc_core diff --git a/test/core/util/tracer_util.cc b/test/core/util/tracer_util.cc new file mode 100644 index 00000000..e48ae0cc --- /dev/null +++ b/test/core/util/tracer_util.cc @@ -0,0 +1,30 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
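// Hypothetical usage sketch (not part of the diff above): how a TLS test might
// combine TmpFile and MakeCertKeyPairs. Assumes the name() accessor declared
// in tls_utils.h; the PEM strings are placeholders, not real credentials.
#include <string>

#include "test/core/util/tls_utils.h"

void TlsFixtureSketch() {
  // Write fake root-cert data to a temp file that a file-watching certificate
  // provider under test could be pointed at.
  grpc_core::testing::TmpFile ca_file("-----BEGIN CERTIFICATE-----\n...fake...");
  std::string on_disk =
      grpc_core::testing::GetFileContents(ca_file.name().c_str());
  // Simulate rotation: the path stays stable while the contents change.
  ca_file.RewriteFile("-----BEGIN CERTIFICATE-----\n...rotated...");
  // In-memory identity pair for feeding a distributor/provider directly.
  grpc_core::PemKeyCertPairList pairs =
      grpc_core::testing::MakeCertKeyPairs("fake-key-pem", "fake-cert-pem");
  (void)on_disk;
  (void)pairs;
}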
+ * + */ + +#include "src/core/lib/debug/trace.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { + +void grpc_tracer_enable_flag(grpc_core::TraceFlag* flag) { + flag->set_enabled(true); +} + +} // namespace testing +} // namespace grpc_core diff --git a/test/core/util/trickle_endpoint.cc b/test/core/util/trickle_endpoint.cc new file mode 100644 index 00000000..49c70093 --- /dev/null +++ b/test/core/util/trickle_endpoint.cc @@ -0,0 +1,207 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/slice/slice_internal.h" +#include "test/core/util/passthru_endpoint.h" + +#define WRITE_BUFFER_SIZE (2 * 1024 * 1024) + +typedef struct { + grpc_endpoint base; + double bytes_per_second; + grpc_endpoint* wrapped; + gpr_timespec last_write; + + gpr_mu mu; + grpc_slice_buffer write_buffer; + grpc_slice_buffer writing_buffer; + grpc_error_handle error; + bool writing; + grpc_closure* write_cb; +} trickle_endpoint; + +static void te_read(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool urgent) { + trickle_endpoint* te = reinterpret_cast(ep); + grpc_endpoint_read(te->wrapped, slices, cb, urgent); +} + +static void maybe_call_write_cb_locked(trickle_endpoint* te) { + if (te->write_cb != nullptr && + (te->error != GRPC_ERROR_NONE || + te->write_buffer.length <= WRITE_BUFFER_SIZE)) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, te->write_cb, + GRPC_ERROR_REF(te->error)); + te->write_cb = nullptr; + } +} + +static void te_write(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, void* /*arg*/) { + trickle_endpoint* te = reinterpret_cast(ep); + gpr_mu_lock(&te->mu); + GPR_ASSERT(te->write_cb == nullptr); + if (te->write_buffer.length == 0) { + te->last_write = gpr_now(GPR_CLOCK_MONOTONIC); + } + for (size_t i = 0; i < slices->count; i++) { + grpc_slice_buffer_add(&te->write_buffer, + grpc_slice_copy(slices->slices[i])); + } + te->write_cb = cb; + maybe_call_write_cb_locked(te); + gpr_mu_unlock(&te->mu); +} + +static void te_add_to_pollset(grpc_endpoint* ep, grpc_pollset* pollset) { + trickle_endpoint* te = reinterpret_cast(ep); + grpc_endpoint_add_to_pollset(te->wrapped, pollset); +} + +static void te_add_to_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset_set) { + trickle_endpoint* te = reinterpret_cast(ep); + grpc_endpoint_add_to_pollset_set(te->wrapped, pollset_set); +} + +static void te_delete_from_pollset_set(grpc_endpoint* ep, + grpc_pollset_set* pollset_set) { + trickle_endpoint* te = reinterpret_cast(ep); + grpc_endpoint_delete_from_pollset_set(te->wrapped, pollset_set); +} + +static void te_shutdown(grpc_endpoint* ep, grpc_error_handle why) { + trickle_endpoint* te = reinterpret_cast(ep); + gpr_mu_lock(&te->mu); + if (te->error == GRPC_ERROR_NONE) { + te->error = GRPC_ERROR_REF(why); + } + 
maybe_call_write_cb_locked(te); + gpr_mu_unlock(&te->mu); + grpc_endpoint_shutdown(te->wrapped, why); +} + +static void te_destroy(grpc_endpoint* ep) { + trickle_endpoint* te = reinterpret_cast(ep); + grpc_endpoint_destroy(te->wrapped); + gpr_mu_destroy(&te->mu); + grpc_slice_buffer_destroy_internal(&te->write_buffer); + grpc_slice_buffer_destroy_internal(&te->writing_buffer); + GRPC_ERROR_UNREF(te->error); + gpr_free(te); +} + +static absl::string_view te_get_peer(grpc_endpoint* ep) { + trickle_endpoint* te = reinterpret_cast(ep); + return grpc_endpoint_get_peer(te->wrapped); +} + +static absl::string_view te_get_local_address(grpc_endpoint* ep) { + trickle_endpoint* te = reinterpret_cast(ep); + return grpc_endpoint_get_local_address(te->wrapped); +} + +static int te_get_fd(grpc_endpoint* ep) { + trickle_endpoint* te = reinterpret_cast(ep); + return grpc_endpoint_get_fd(te->wrapped); +} + +static bool te_can_track_err(grpc_endpoint* /*ep*/) { return false; } + +static void te_finish_write(void* arg, grpc_error_handle /*error*/) { + trickle_endpoint* te = static_cast(arg); + gpr_mu_lock(&te->mu); + te->writing = false; + grpc_slice_buffer_reset_and_unref(&te->writing_buffer); + gpr_mu_unlock(&te->mu); +} + +static const grpc_endpoint_vtable vtable = {te_read, + te_write, + te_add_to_pollset, + te_add_to_pollset_set, + te_delete_from_pollset_set, + te_shutdown, + te_destroy, + te_get_peer, + te_get_local_address, + te_get_fd, + te_can_track_err}; + +grpc_endpoint* grpc_trickle_endpoint_create(grpc_endpoint* wrap, + double bytes_per_second) { + trickle_endpoint* te = + static_cast(gpr_malloc(sizeof(*te))); + te->base.vtable = &vtable; + te->wrapped = wrap; + te->bytes_per_second = bytes_per_second; + te->write_cb = nullptr; + gpr_mu_init(&te->mu); + grpc_slice_buffer_init(&te->write_buffer); + grpc_slice_buffer_init(&te->writing_buffer); + te->error = GRPC_ERROR_NONE; + te->writing = false; + return &te->base; +} + +static double ts2dbl(gpr_timespec s) { + return static_cast(s.tv_sec) + 1e-9 * static_cast(s.tv_nsec); +} + +size_t grpc_trickle_endpoint_trickle(grpc_endpoint* ep) { + trickle_endpoint* te = reinterpret_cast(ep); + gpr_mu_lock(&te->mu); + if (!te->writing && te->write_buffer.length > 0) { + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + double elapsed = ts2dbl(gpr_time_sub(now, te->last_write)); + size_t bytes = static_cast(te->bytes_per_second * elapsed); + // gpr_log(GPR_DEBUG, "%lf elapsed --> %" PRIdPTR " bytes", elapsed, bytes); + if (bytes > 0) { + grpc_slice_buffer_move_first(&te->write_buffer, + std::min(bytes, te->write_buffer.length), + &te->writing_buffer); + te->writing = true; + te->last_write = now; + grpc_endpoint_write( + te->wrapped, &te->writing_buffer, + GRPC_CLOSURE_CREATE(te_finish_write, te, grpc_schedule_on_exec_ctx), + nullptr); + maybe_call_write_cb_locked(te); + } + } + size_t backlog = te->write_buffer.length; + gpr_mu_unlock(&te->mu); + return backlog; +} + +size_t grpc_trickle_get_backlog(grpc_endpoint* ep) { + trickle_endpoint* te = reinterpret_cast(ep); + gpr_mu_lock(&te->mu); + size_t backlog = te->write_buffer.length; + gpr_mu_unlock(&te->mu); + return backlog; +} diff --git a/test/core/xds/certificate_provider_store_test.cc b/test/core/xds/certificate_provider_store_test.cc new file mode 100644 index 00000000..b82bf0a8 --- /dev/null +++ b/test/core/xds/certificate_provider_store_test.cc @@ -0,0 +1,168 @@ +// +// +// Copyright 2020 gRPC authors. 
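// Hypothetical usage sketch (not part of the diff above): throttling a
// benchmark endpoint with the trickle wrapper. `underlying` is assumed to come
// from another fixture (e.g. a passthru endpoint pair), the header name is
// assumed to be test/core/util/trickle_endpoint.h, and the event-loop pumping
// between trickle calls is elided.
#include <grpc/support/log.h>

#include "src/core/lib/iomgr/endpoint.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "test/core/util/trickle_endpoint.h"

void ThrottleSketch(grpc_endpoint* underlying) {
  grpc_core::ExecCtx exec_ctx;
  // Queue writes and release them at roughly 64 KiB per second.
  grpc_endpoint* throttled =
      grpc_trickle_endpoint_create(underlying, /*bytes_per_second=*/64 * 1024.0);
  // The caller must pump the endpoint; each call releases whatever the elapsed
  // time allows and returns the bytes still buffered.
  while (grpc_trickle_endpoint_trickle(throttled) > 0) {
    // ... poll / run the benchmark loop for a bit ...
  }
  GPR_ASSERT(grpc_trickle_get_backlog(throttled) == 0);
  grpc_endpoint_destroy(throttled);
}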
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include "src/core/ext/xds/certificate_provider_store.h" + +#include + +#include + +#include "src/core/ext/xds/certificate_provider_registry.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +class CertificateProviderStoreTest : public ::testing::Test { + public: + CertificateProviderStoreTest() { grpc_init(); } + + ~CertificateProviderStoreTest() override { grpc_shutdown_blocking(); } +}; + +class FakeCertificateProvider : public grpc_tls_certificate_provider { + public: + RefCountedPtr distributor() const override { + // never called + GPR_ASSERT(0); + return nullptr; + } +}; + +class FakeCertificateProviderFactory1 : public CertificateProviderFactory { + public: + class Config : public CertificateProviderFactory::Config { + public: + const char* name() const override { return "fake1"; } + + std::string ToString() const override { return "{}"; } + }; + + const char* name() const override { return "fake1"; } + + RefCountedPtr + CreateCertificateProviderConfig(const Json& /*config_json*/, + grpc_error_handle* /*error*/) override { + return MakeRefCounted(); + } + + RefCountedPtr CreateCertificateProvider( + RefCountedPtr /*config*/) override { + return MakeRefCounted(); + } +}; + +class FakeCertificateProviderFactory2 : public CertificateProviderFactory { + public: + class Config : public CertificateProviderFactory::Config { + public: + const char* name() const override { return "fake2"; } + + std::string ToString() const override { return "{}"; } + }; + + const char* name() const override { return "fake2"; } + + RefCountedPtr + CreateCertificateProviderConfig(const Json& /*config_json*/, + grpc_error_handle* /*error*/) override { + return MakeRefCounted(); + } + + RefCountedPtr CreateCertificateProvider( + RefCountedPtr /*config*/) override { + return MakeRefCounted(); + } +}; + +TEST_F(CertificateProviderStoreTest, Basic) { + // Set up factories. (Register only one of the factories.) + auto* fake_factory_1 = new FakeCertificateProviderFactory1; + CertificateProviderRegistry::RegisterCertificateProviderFactory( + std::unique_ptr(fake_factory_1)); + auto fake_factory_2 = absl::make_unique(); + // Set up store + CertificateProviderStore::PluginDefinitionMap map = { + {"fake_plugin_1", + {"fake1", fake_factory_1->CreateCertificateProviderConfig(Json::Object(), + nullptr)}}, + {"fake_plugin_2", + {"fake2", fake_factory_2->CreateCertificateProviderConfig(Json::Object(), + nullptr)}}, + {"fake_plugin_3", + {"fake1", fake_factory_1->CreateCertificateProviderConfig(Json::Object(), + nullptr)}}, + }; + auto store = MakeOrphanable(std::move(map)); + // Test for creating certificate providers with known plugin configuration. 
+ auto cert_provider_1 = store->CreateOrGetCertificateProvider("fake_plugin_1"); + ASSERT_NE(cert_provider_1, nullptr); + auto cert_provider_3 = store->CreateOrGetCertificateProvider("fake_plugin_3"); + ASSERT_NE(cert_provider_3, nullptr); + // Test for creating certificate provider with known plugin configuration but + // unregistered factory. + ASSERT_EQ(store->CreateOrGetCertificateProvider("fake_plugin_2"), nullptr); + // Test for creating certificate provider with unknown plugin configuration. + ASSERT_EQ(store->CreateOrGetCertificateProvider("unknown"), nullptr); + // Test for getting previously created certificate providers. + ASSERT_EQ(store->CreateOrGetCertificateProvider("fake_plugin_1"), + cert_provider_1); + ASSERT_EQ(store->CreateOrGetCertificateProvider("fake_plugin_3"), + cert_provider_3); + // Release previously created certificate providers so that the store outlasts + // the certificate providers. + cert_provider_1.reset(); + cert_provider_3.reset(); +} + +TEST_F(CertificateProviderStoreTest, Multithreaded) { + auto* fake_factory_1 = new FakeCertificateProviderFactory1; + CertificateProviderRegistry::RegisterCertificateProviderFactory( + std::unique_ptr(fake_factory_1)); + CertificateProviderStore::PluginDefinitionMap map = { + {"fake_plugin_1", + {"fake1", fake_factory_1->CreateCertificateProviderConfig(Json::Object(), + nullptr)}}}; + auto store = MakeOrphanable(std::move(map)); + // Test concurrent `CreateOrGetCertificateProvider()` with the same key. + std::vector threads; + threads.reserve(1000); + for (auto i = 0; i < 1000; i++) { + threads.emplace_back([&store]() { + for (auto i = 0; i < 10; ++i) { + ASSERT_NE(store->CreateOrGetCertificateProvider("fake_plugin_1"), + nullptr); + } + }); + } + for (auto& thread : threads) { + thread.join(); + } +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/core/xds/file_watcher_certificate_provider_factory_test.cc b/test/core/xds/file_watcher_certificate_provider_factory_test.cc new file mode 100644 index 00000000..bba2f374 --- /dev/null +++ b/test/core/xds/file_watcher_certificate_provider_factory_test.cc @@ -0,0 +1,202 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include "src/core/ext/xds/file_watcher_certificate_provider_factory.h" + +#include +#include + +#include "absl/strings/str_format.h" + +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +const char* kIdentityCertFile = "/path/to/identity_cert_file"; +const char* kPrivateKeyFile = "/path/to/private_key_file"; +const char* kRootCertFile = "/path/to/root_cert_file"; +const int kRefreshInterval = 400; + +TEST(FileWatcherConfigTest, Basic) { + std::string json_str = absl::StrFormat( + "{" + " \"certificate_file\": \"%s\"," + " \"private_key_file\": \"%s\"," + " \"ca_certificate_file\": \"%s\"," + " \"refresh_interval\": \"%ds\"" + "}", + kIdentityCertFile, kPrivateKeyFile, kRootCertFile, kRefreshInterval); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(config->identity_cert_file(), kIdentityCertFile); + EXPECT_EQ(config->private_key_file(), kPrivateKeyFile); + EXPECT_EQ(config->root_cert_file(), kRootCertFile); + EXPECT_EQ(config->refresh_interval_ms(), kRefreshInterval * 1000); +} + +TEST(FileWatcherConfigTest, DefaultRefreshInterval) { + std::string json_str = absl::StrFormat( + "{" + " \"certificate_file\": \"%s\"," + " \"private_key_file\": \"%s\"," + " \"ca_certificate_file\": \"%s\"" + "}", + kIdentityCertFile, kPrivateKeyFile, kRootCertFile); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(config->identity_cert_file(), kIdentityCertFile); + EXPECT_EQ(config->private_key_file(), kPrivateKeyFile); + EXPECT_EQ(config->root_cert_file(), kRootCertFile); + EXPECT_EQ(config->refresh_interval_ms(), 600 * 1000); +} + +TEST(FileWatcherConfigTest, OnlyRootCertificatesFileProvided) { + std::string json_str = absl::StrFormat( + "{" + " \"ca_certificate_file\": \"%s\"" + "}", + kRootCertFile); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_TRUE(config->identity_cert_file().empty()); + EXPECT_TRUE(config->private_key_file().empty()); + EXPECT_EQ(config->root_cert_file(), kRootCertFile); + EXPECT_EQ(config->refresh_interval_ms(), 600 * 1000); +} + +TEST(FileWatcherConfigTest, OnlyIdenityCertificatesAndPrivateKeyProvided) { + std::string json_str = absl::StrFormat( + "{" + " \"certificate_file\": \"%s\"," + " \"private_key_file\": \"%s\"" + "}", + kIdentityCertFile, kPrivateKeyFile); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(config->identity_cert_file(), kIdentityCertFile); + EXPECT_EQ(config->private_key_file(), kPrivateKeyFile); + 
EXPECT_TRUE(config->root_cert_file().empty()); + EXPECT_EQ(config->refresh_interval_ms(), 600 * 1000); +} + +TEST(FileWatcherConfigTest, WrongTypes) { + const char* json_str = + "{" + " \"certificate_file\": 123," + " \"private_key_file\": 123," + " \"ca_certificate_file\": 123," + " \"refresh_interval\": 123" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "field:certificate_file error:type should be STRING.*" + "field:private_key_file error:type should be STRING.*" + "field:ca_certificate_file error:type should be STRING.*" + "field:refresh_interval error:type should be STRING of the " + "form given by " + "google.proto.Duration.*")); + GRPC_ERROR_UNREF(error); +} + +TEST(FileWatcherConfigTest, IdentityCertProvidedButPrivateKeyMissing) { + std::string json_str = absl::StrFormat( + "{" + " \"certificate_file\": \"%s\"" + "}", + kIdentityCertFile); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "fields \"certificate_file\" and \"private_key_file\" must " + "be both set or both unset.")); + GRPC_ERROR_UNREF(error); +} + +TEST(FileWatcherConfigTest, PrivateKeyProvidedButIdentityCertMissing) { + std::string json_str = absl::StrFormat( + "{" + " \"private_key_file\": \"%s\"" + "}", + kPrivateKeyFile); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "fields \"certificate_file\" and \"private_key_file\" must " + "be both set or both unset.")); + GRPC_ERROR_UNREF(error); +} + +TEST(FileWatcherConfigTest, EmptyJsonObject) { + std::string json_str = absl::StrFormat("{}"); + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + FileWatcherCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex("At least one of \"certificate_file\" and " + "\"ca_certificate_file\" must be specified.")); + GRPC_ERROR_UNREF(error); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/xds/google_mesh_ca_certificate_provider_factory_test.cc b/test/core/xds/google_mesh_ca_certificate_provider_factory_test.cc new file mode 100644 index 00000000..8c6af3e1 --- /dev/null +++ b/test/core/xds/google_mesh_ca_certificate_provider_factory_test.cc @@ -0,0 +1,368 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include "src/core/ext/xds/google_mesh_ca_certificate_provider_factory.h" + +#include +#include + +#include + +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +TEST(GoogleMeshCaConfigTest, Basic) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": \"GRPC\"," + " \"grpc_services\": [{" + " \"google_grpc\": {" + " \"target_uri\": \"newmeshca.googleapis.com\"," + " \"channel_credentials\": { \"google_default\": {}}," + " \"call_credentials\": [{" + " \"sts_service\": {" + " \"token_exchange_service_uri\": " + "\"newsecuretoken.googleapis.com\"," + " \"resource\": \"newmeshca.googleapis.com\"," + " \"audience\": \"newmeshca.googleapis.com\"," + " \"scope\": " + "\"https://www.newgoogleapis.com/auth/cloud-platform\"," + " \"requested_token_type\": " + "\"urn:ietf:params:oauth:token-type:jwt\"," + " \"subject_token_path\": \"/etc/secret/sajwt.token\"," + " \"subject_token_type\": " + "\"urn:ietf:params:oauth:token-type:jwt\"," + " \"actor_token_path\": \"/etc/secret/sajwt.token\"," + " \"actor_token_type\": " + "\"urn:ietf:params:oauth:token-type:jwt\"" + " }" + " }]" + " }," + " \"timeout\": \"20s\"" + " }]" + " }," + " \"certificate_lifetime\": \"400s\"," + " \"renewal_grace_period\": \"100s\"," + " \"key_type\": \"RSA\"," + " \"key_size\": 1024," + " \"location\": " + "\"https://container.googleapis.com/v1/project/test-project1/locations/" + "test-zone2/clusters/test-cluster3\"" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(config->endpoint(), "newmeshca.googleapis.com"); + EXPECT_EQ(config->sts_config().token_exchange_service_uri, + "newsecuretoken.googleapis.com"); + EXPECT_EQ(config->sts_config().resource, "newmeshca.googleapis.com"); + EXPECT_EQ(config->sts_config().audience, "newmeshca.googleapis.com"); + EXPECT_EQ(config->sts_config().scope, + "https://www.newgoogleapis.com/auth/cloud-platform"); + EXPECT_EQ(config->sts_config().requested_token_type, + "urn:ietf:params:oauth:token-type:jwt"); + EXPECT_EQ(config->sts_config().subject_token_path, "/etc/secret/sajwt.token"); + EXPECT_EQ(config->sts_config().subject_token_type, + "urn:ietf:params:oauth:token-type:jwt"); + EXPECT_EQ(config->sts_config().actor_token_path, "/etc/secret/sajwt.token"); + EXPECT_EQ(config->sts_config().actor_token_type, + "urn:ietf:params:oauth:token-type:jwt"); + EXPECT_EQ(config->timeout(), 20 * 1000); + EXPECT_EQ(config->certificate_lifetime(), 400 * 1000); + EXPECT_EQ(config->renewal_grace_period(), 100 * 1000); + EXPECT_EQ(config->key_size(), 1024); + EXPECT_EQ(config->location(), + "https://container.googleapis.com/v1/project/test-project1/" + "locations/test-zone2/clusters/test-cluster3"); +} + +TEST(GoogleMeshCaConfigTest, Defaults) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": \"GRPC\"," + " 
\"grpc_services\": [{" + " \"google_grpc\": {" + " \"call_credentials\": [{" + " \"sts_service\": {" + " \"scope\": " + "\"https://www.googleapis.com/auth/cloud-platform\"," + " \"subject_token_path\": \"/etc/secret/sajwt.token\"," + " \"subject_token_type\": " + "\"urn:ietf:params:oauth:token-type:jwt\"" + " }" + " }]" + " }" + " }]" + " }," + " \"location\": " + "\"https://container.googleapis.com/v1/project/test-project1/locations/" + "test-zone2/clusters/test-cluster3\"" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(config->endpoint(), "meshca.googleapis.com"); + EXPECT_EQ(config->sts_config().token_exchange_service_uri, + "securetoken.googleapis.com"); + EXPECT_EQ(config->sts_config().resource, ""); + EXPECT_EQ(config->sts_config().audience, ""); + EXPECT_EQ(config->sts_config().scope, + "https://www.googleapis.com/auth/cloud-platform"); + EXPECT_EQ(config->sts_config().requested_token_type, ""); + EXPECT_EQ(config->sts_config().subject_token_path, "/etc/secret/sajwt.token"); + EXPECT_EQ(config->sts_config().subject_token_type, + "urn:ietf:params:oauth:token-type:jwt"); + EXPECT_EQ(config->sts_config().actor_token_path, ""); + EXPECT_EQ(config->sts_config().actor_token_type, ""); + EXPECT_EQ(config->timeout(), 10 * 1000); + EXPECT_EQ(config->certificate_lifetime(), 24 * 60 * 60 * 1000); + EXPECT_EQ(config->renewal_grace_period(), 12 * 60 * 60 * 1000); + EXPECT_EQ(config->key_size(), 2048); + EXPECT_EQ(config->location(), + "https://container.googleapis.com/v1/project/test-project1/" + "locations/test-zone2/clusters/test-cluster3"); +} + +TEST(GoogleMeshCaConfigTest, WrongExpectedValues) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": \"REST\"," + " \"grpc_services\": [{" + " \"google_grpc\": {" + " \"call_credentials\": [{" + " \"sts_service\": {" + " \"scope\": " + "\"https://www.googleapis.com/auth/cloud-platform\"," + " \"subject_token_path\": \"/etc/secret/sajwt.token\"," + " \"subject_token_type\": " + "\"urn:ietf:params:oauth:token-type:jwt\"" + " }" + " }]" + " }" + " }]" + " }," + " \"key_type\": \"DSA\"," + " \"location\": " + "\"https://container.googleapis.com/v1/project/test-project1/locations/" + "test-zone2/clusters/test-cluster3\"" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex("field:api_type error:Only GRPC is supported.*" + "field:key_type error:Only RSA is supported")); + GRPC_ERROR_UNREF(error); +} + +TEST(GoogleMeshCaConfigTest, WrongTypes) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": 123," + " \"grpc_services\": [{" + " \"google_grpc\": {" + " \"target_uri\": 123," + " \"call_credentials\": [{" + " \"sts_service\": {" + " \"token_exchange_service_uri\": 123," + " \"resource\": 123," + " \"audience\": 123," + " \"scope\": 123," + " \"requested_token_type\": 123," + " \"subject_token_path\": 123," + " \"subject_token_type\": 123," + " \"actor_token_path\": 123," + " \"actor_token_type\": 123" + " }" + " }]" + " }," + " \"timeout\": 20" + " }]" + " }," + " 
\"certificate_lifetime\": 400," + " \"renewal_grace_period\": 100," + " \"key_type\": 123," + " \"key_size\": \"1024\"," + " \"location\": 123" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "field:server.*field:api_type error:type should be STRING.*" + "field:grpc_services.*field:google_grpc.*field:target_uri " + "error:type should be STRING.*" + "field:call_credentials.*field:sts_service.*field:token_exchange_" + "service_uri error:type should be STRING.*" + "field:resource error:type should be STRING.*" + "field:audience error:type should be STRING.*" + "field:scope error:type should be STRING.*" + "field:requested_token_type error:type should be STRING.*" + "field:subject_token_path error:type should be STRING.*" + "field:subject_token_type error:type should be STRING.*" + "field:actor_token_path error:type should be STRING.*" + "field:actor_token_type error:type should be STRING.*" + "field:timeout error:type should be STRING of the form given by " + "google.proto.Duration.*" + "field:certificate_lifetime error:type should be STRING of the form " + "given by google.proto.Duration.*" + "field:renewal_grace_period error:type should be STRING of the form " + "given by google.proto.Duration..*" + "field:key_type error:type should be STRING.*" + "field:key_size error:type should be NUMBER.*" + "field:location error:type should be STRING")); + GRPC_ERROR_UNREF(error); +} + +TEST(GoogleMeshCaConfigTest, GrpcServicesNotAnArray) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": \"GRPC\"," + " \"grpc_services\": 123" + " }," + " \"location\": " + "\"https://container.googleapis.com/v1/project/test-project1/locations/" + "test-zone2/clusters/test-cluster3\"" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "field:server.*field:grpc_services error:type should be ARRAY")); + GRPC_ERROR_UNREF(error); +} + +TEST(GoogleMeshCaConfigTest, GoogleGrpcNotAnObject) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": \"GRPC\"," + " \"grpc_services\": [{" + " \"google_grpc\": 123" + " }]" + " }," + " \"location\": " + "\"https://container.googleapis.com/v1/project/test-project1/locations/" + "test-zone2/clusters/test-cluster3\"" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex("field:server.*field:grpc_services.*field:" + "google_grpc error:type should be OBJECT")); + GRPC_ERROR_UNREF(error); +} + +TEST(GoogleMeshCaConfigTest, CallCredentialsNotAnArray) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": \"GRPC\"," + " \"grpc_services\": [{" + " \"google_grpc\": {" + " \"call_credentials\": 123" + " }" + " }]" + " }," + " \"location\": " + "\"https://container.googleapis.com/v1/project/test-project1/locations/" + 
"test-zone2/clusters/test-cluster3\"" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "field:server.*field:grpc_services.*field:google_grpc.*" + "field:call_credentials error:type should be ARRAY")); + GRPC_ERROR_UNREF(error); +} + +TEST(GoogleMeshCaConfigTest, StsServiceNotAnObject) { + const char* json_str = + "{" + " \"server\": {" + " \"api_type\": \"GRPC\"," + " \"grpc_services\": [{" + " \"google_grpc\": {" + " \"call_credentials\": [{" + " \"sts_service\": 123" + " }]" + " }" + " }]" + " }," + " \"location\": " + "\"https://container.googleapis.com/v1/project/test-project1/locations/" + "test-zone2/clusters/test-cluster3\"" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + auto config = + GoogleMeshCaCertificateProviderFactory::Config::Parse(json, &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex( + "field:server.*field:grpc_services.*field:google_grpc.*field:" + "call_credentials.*field:sts_service error:type should be OBJECT")); + GRPC_ERROR_UNREF(error); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/xds/xds_bootstrap_test.cc b/test/core/xds/xds_bootstrap_test.cc new file mode 100644 index 00000000..d007780b --- /dev/null +++ b/test/core/xds/xds_bootstrap_test.cc @@ -0,0 +1,613 @@ +// +// Copyright 2019 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include "src/core/ext/xds/xds_bootstrap.h" + +#include + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" + +#include +#include + +#include "src/core/ext/xds/certificate_provider_registry.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +TEST(XdsBootstrapTest, Basic) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [" + " {" + " \"type\": \"fake\"," + " \"ignore\": 0" + " }" + " ]," + " \"ignore\": 0" + " }," + " {" + " \"server_uri\": \"ignored\"," + " \"channel_creds\": [" + " {" + " \"type\": \"ignored\"," + " \"ignore\": 0" + " }," + " {" + " \"type\": \"fake\"" + " }" + " ]," + " \"ignore\": 0" + " }" + " ]," + " \"node\": {" + " \"id\": \"foo\"," + " \"cluster\": \"bar\"," + " \"locality\": {" + " \"region\": \"milky_way\"," + " \"zone\": \"sol_system\"," + " \"sub_zone\": \"earth\"," + " \"ignore\": {}" + " }," + " \"metadata\": {" + " \"foo\": 1," + " \"bar\": 2" + " }," + " \"ignore\": \"whee\"" + " }," + " \"server_listener_resource_name_template\": \"example/resource\"," + " \"ignore\": {}" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(bootstrap.server().server_uri, "fake:///lb"); + EXPECT_EQ(bootstrap.server().channel_creds_type, "fake"); + EXPECT_EQ(bootstrap.server().channel_creds_config.type(), + Json::Type::JSON_NULL); + ASSERT_NE(bootstrap.node(), nullptr); + EXPECT_EQ(bootstrap.node()->id, "foo"); + EXPECT_EQ(bootstrap.node()->cluster, "bar"); + EXPECT_EQ(bootstrap.node()->locality_region, "milky_way"); + EXPECT_EQ(bootstrap.node()->locality_zone, "sol_system"); + EXPECT_EQ(bootstrap.node()->locality_sub_zone, "earth"); + ASSERT_EQ(bootstrap.node()->metadata.type(), Json::Type::OBJECT); + EXPECT_THAT(bootstrap.node()->metadata.object_value(), + ::testing::ElementsAre( + ::testing::Pair( + ::testing::Eq("bar"), + ::testing::AllOf( + ::testing::Property(&Json::type, Json::Type::NUMBER), + ::testing::Property(&Json::string_value, "2"))), + ::testing::Pair( + ::testing::Eq("foo"), + ::testing::AllOf( + ::testing::Property(&Json::type, Json::Type::NUMBER), + ::testing::Property(&Json::string_value, "1"))))); + EXPECT_EQ(bootstrap.server_listener_resource_name_template(), + "example/resource"); +} + +TEST(XdsBootstrapTest, ValidWithoutNode) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"fake\"}]" + " }" + " ]" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(bootstrap.server().server_uri, "fake:///lb"); + EXPECT_EQ(bootstrap.server().channel_creds_type, "fake"); + EXPECT_EQ(bootstrap.node(), nullptr); +} + +TEST(XdsBootstrapTest, InsecureCreds) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"insecure\"}]" + " }" + " ]" + "}"; + grpc_error_handle error = 
GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(bootstrap.server().server_uri, "fake:///lb"); + EXPECT_EQ(bootstrap.server().channel_creds_type, "insecure"); + EXPECT_EQ(bootstrap.node(), nullptr); +} + +TEST(XdsBootstrapTest, GoogleDefaultCreds) { + // Generate call creds file needed by GoogleDefaultCreds. + const char token_str[] = + "{ \"client_id\": \"32555999999.apps.googleusercontent.com\"," + " \"client_secret\": \"EmssLNjJy1332hD4KFsecret\"," + " \"refresh_token\": \"1/Blahblasj424jladJDSGNf-u4Sua3HDA2ngjd42\"," + " \"type\": \"authorized_user\"}"; + char* creds_file_name; + FILE* creds_file = gpr_tmpfile("xds_bootstrap_test", &creds_file_name); + ASSERT_NE(creds_file_name, nullptr); + ASSERT_NE(creds_file, nullptr); + ASSERT_EQ(fwrite(token_str, 1, sizeof(token_str), creds_file), + sizeof(token_str)); + fclose(creds_file); + gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, creds_file_name); + gpr_free(creds_file_name); + // Now run test. + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"google_default\"}]" + " }" + " ]" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + EXPECT_EQ(bootstrap.server().server_uri, "fake:///lb"); + EXPECT_EQ(bootstrap.server().channel_creds_type, "google_default"); + EXPECT_EQ(bootstrap.node(), nullptr); +} + +TEST(XdsBootstrapTest, MissingChannelCreds) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"" + " }" + " ]" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex("\"channel_creds\" field not present")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, NoKnownChannelCreds) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"unknown\"}]" + " }" + " ]" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "no known creds type found in \"channel_creds\"")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, MissingXdsServers) { + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse("{}", &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex("\"xds_servers\" field not present")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, TopFieldsWrongTypes) { + const char* json_str = + "{" + " \"xds_servers\":1," + " \"node\":1," + " \"server_listener_resource_name_template\":1," + " \"certificate_providers\":1" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + 
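+  // The JSON itself is well formed, so Json::Parse() succeeds; the per-field
+  // type errors are reported by the XdsBootstrap constructor and matched
+  // against the regex below.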
Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex("\"xds_servers\" field is not an array.*" + "\"node\" field is not an object.*" + "\"server_listener_resource_name_" + "template\" field is not a string.*")); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "\"certificate_providers\" field is not an object")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, XdsServerMissingServerUri) { + const char* json_str = + "{" + " \"xds_servers\":[{}]" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex("errors parsing \"xds_servers\" array.*" + "errors parsing index 0.*" + "\"server_uri\" field not present")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, XdsServerUriAndCredsWrongTypes) { + const char* json_str = + "{" + " \"xds_servers\":[" + " {" + " \"server_uri\":1," + " \"channel_creds\":1" + " }" + " ]" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex("errors parsing \"xds_servers\" array.*" + "errors parsing index 0.*" + "\"server_uri\" field is not a string.*" + "\"channel_creds\" field is not an array")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, ChannelCredsFieldsWrongTypes) { + const char* json_str = + "{" + " \"xds_servers\":[" + " {" + " \"server_uri\":\"foo\"," + " \"channel_creds\":[" + " {" + " \"type\":0," + " \"config\":1" + " }" + " ]" + " }" + " ]" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT( + grpc_error_std_string(error), + ::testing::ContainsRegex("errors parsing \"xds_servers\" array.*" + "errors parsing index 0.*" + "errors parsing \"channel_creds\" array.*" + "errors parsing index 0.*" + "\"type\" field is not a string.*" + "\"config\" field is not an object")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, NodeFieldsWrongTypes) { + const char* json_str = + "{" + " \"node\":{" + " \"id\":0," + " \"cluster\":0," + " \"locality\":0," + " \"metadata\":0" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex("errors parsing \"node\" object.*" + "\"id\" field is not a string.*" + "\"cluster\" field is not a string.*" + "\"locality\" field is not an object.*" + "\"metadata\" field is not an object")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, LocalityFieldsWrongType) { + const char* json_str = + "{" + " \"node\":{" + " \"locality\":{" + " \"region\":0," + " \"zone\":0," + " \"sub_zone\":0" + " }" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, 
GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex("errors parsing \"node\" object.*" + "errors parsing \"locality\" object.*" + "\"region\" field is not a string.*" + "\"zone\" field is not a string.*" + "\"sub_zone\" field is not a string")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, CertificateProvidersElementWrongType) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"fake\"}]" + " }" + " ]," + " \"certificate_providers\": {" + " \"plugin\":1" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing \"certificate_providers\" object.*" + "element \"plugin\" is not an object")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, CertificateProvidersPluginNameWrongType) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"fake\"}]" + " }" + " ]," + " \"certificate_providers\": {" + " \"plugin\": {" + " \"plugin_name\":1" + " }" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing \"certificate_providers\" object.*" + "errors parsing element \"plugin\".*" + "\"plugin_name\" field is not a string")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, CertificateProvidersUnrecognizedPluginName) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"fake\"}]" + " }" + " ]," + " \"certificate_providers\": {" + " \"plugin\": {" + " \"plugin_name\":\"unknown\"" + " }" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing \"certificate_providers\" object.*" + "errors parsing element \"plugin\".*" + "Unrecognized plugin name: unknown")); + GRPC_ERROR_UNREF(error); +} + +class FakeCertificateProviderFactory : public CertificateProviderFactory { + public: + class Config : public CertificateProviderFactory::Config { + public: + explicit Config(int value) : value_(value) {} + + int value() const { return value_; } + + const char* name() const override { return "fake"; } + + std::string ToString() const override { + return absl::StrFormat( + "{\n" + " value=%d" + "}", + value_); + } + + private: + int value_; + }; + + const char* name() const override { return "fake"; } + + RefCountedPtr + CreateCertificateProviderConfig(const Json& config_json, + grpc_error_handle* error) override { + std::vector error_list; + EXPECT_EQ(config_json.type(), Json::Type::OBJECT); + auto it = config_json.object_value().find("value"); + if (it == config_json.object_value().end()) { + return MakeRefCounted(0); + } else if (it->second.type() != Json::Type::NUMBER) { + 
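+      // A JSON string such as "10" (rather than the number 10) lands in this
+      // branch; the CertificateProvidersFakePluginParsingError test below
+      // relies on that.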
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "field:config field:value not of type number"); + } else { + int value = 0; + EXPECT_TRUE(absl::SimpleAtoi(it->second.string_value(), &value)); + return MakeRefCounted(value); + } + return nullptr; + } + + RefCountedPtr CreateCertificateProvider( + RefCountedPtr /*config*/) override { + return nullptr; + } +}; + +TEST(XdsBootstrapTest, CertificateProvidersFakePluginParsingError) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"fake\"}]" + " }" + " ]," + " \"certificate_providers\": {" + " \"fake_plugin\": {" + " \"plugin_name\": \"fake\"," + " \"config\": {" + " \"value\": \"10\"" + " }" + " }" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + EXPECT_THAT(grpc_error_std_string(error), + ::testing::ContainsRegex( + "errors parsing \"certificate_providers\" object.*" + "errors parsing element \"fake_plugin\".*" + "field:config field:value not of type number")); + GRPC_ERROR_UNREF(error); +} + +TEST(XdsBootstrapTest, CertificateProvidersFakePluginParsingSuccess) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"fake\"}]" + " }" + " ]," + " \"certificate_providers\": {" + " \"fake_plugin\": {" + " \"plugin_name\": \"fake\"," + " \"config\": {" + " \"value\": 10" + " }" + " }" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + ASSERT_TRUE(error == GRPC_ERROR_NONE) << grpc_error_std_string(error); + const CertificateProviderStore::PluginDefinition& fake_plugin = + bootstrap.certificate_providers().at("fake_plugin"); + ASSERT_EQ(fake_plugin.plugin_name, "fake"); + ASSERT_STREQ(fake_plugin.config->name(), "fake"); + ASSERT_EQ(static_cast>( + fake_plugin.config) + ->value(), + 10); +} + +TEST(XdsBootstrapTest, CertificateProvidersFakePluginEmptyConfig) { + const char* json_str = + "{" + " \"xds_servers\": [" + " {" + " \"server_uri\": \"fake:///lb\"," + " \"channel_creds\": [{\"type\": \"fake\"}]" + " }" + " ]," + " \"certificate_providers\": {" + " \"fake_plugin\": {" + " \"plugin_name\": \"fake\"" + " }" + " }" + "}"; + grpc_error_handle error = GRPC_ERROR_NONE; + Json json = Json::Parse(json_str, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + XdsBootstrap bootstrap(std::move(json), &error); + ASSERT_TRUE(error == GRPC_ERROR_NONE) << grpc_error_std_string(error); + const CertificateProviderStore::PluginDefinition& fake_plugin = + bootstrap.certificate_providers().at("fake_plugin"); + ASSERT_EQ(fake_plugin.plugin_name, "fake"); + ASSERT_STREQ(fake_plugin.config->name(), "fake"); + ASSERT_EQ(static_cast>( + fake_plugin.config) + ->value(), + 0); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + grpc_core::CertificateProviderRegistry::RegisterCertificateProviderFactory( + absl::make_unique()); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/xds/xds_certificate_provider_test.cc 
b/test/core/xds/xds_certificate_provider_test.cc
new file mode 100644
index 00000000..147eff3d
--- /dev/null
+++ b/test/core/xds/xds_certificate_provider_test.cc
@@ -0,0 +1,581 @@
+//
+//
+// Copyright 2020 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//
+
+#include "src/core/ext/xds/xds_certificate_provider.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "test/core/util/test_config.h"
+#include "test/core/util/tls_utils.h"
+
+namespace grpc_core {
+namespace testing {
+namespace {
+
+constexpr const char* kRootCert1 = "root_cert_1_contents";
+constexpr const char* kRootCert2 = "root_cert_2_contents";
+constexpr const char* kIdentityCert1PrivateKey = "identity_private_key_1";
+constexpr const char* kIdentityCert1 = "identity_cert_1_contents";
+constexpr const char* kIdentityCert2PrivateKey = "identity_private_key_2";
+constexpr const char* kIdentityCert2 = "identity_cert_2_contents";
+constexpr const char* kRootErrorMessage = "root_error_message";
+constexpr const char* kIdentityErrorMessage = "identity_error_message";
+
+PemKeyCertPairList MakeKeyCertPairsType1() {
+  return MakeCertKeyPairs(kIdentityCert1PrivateKey, kIdentityCert1);
+}
+
+PemKeyCertPairList MakeKeyCertPairsType2() {
+  return MakeCertKeyPairs(kIdentityCert2PrivateKey, kIdentityCert2);
+}
+
+class TestCertificatesWatcher
+    : public grpc_tls_certificate_distributor::TlsCertificatesWatcherInterface {
+ public:
+  ~TestCertificatesWatcher() override {
+    GRPC_ERROR_UNREF(root_cert_error_);
+    GRPC_ERROR_UNREF(identity_cert_error_);
+  }
+
+  void OnCertificatesChanged(
+      absl::optional<absl::string_view> root_certs,
+      absl::optional<PemKeyCertPairList> key_cert_pairs) override {
+    if (root_certs.has_value()) {
+      if (!root_certs_.has_value() ||
+          (root_certs_.has_value() &&
+           std::string(root_certs.value()) != root_certs_.value())) {
+        GRPC_ERROR_UNREF(root_cert_error_);
+        root_cert_error_ = GRPC_ERROR_NONE;
+      }
+      root_certs_.emplace(std::string(root_certs.value()));
+    }
+    if (key_cert_pairs.has_value()) {
+      if (key_cert_pairs != key_cert_pairs_) {
+        GRPC_ERROR_UNREF(identity_cert_error_);
+        identity_cert_error_ = GRPC_ERROR_NONE;
+        key_cert_pairs_ = key_cert_pairs;
+      }
+    }
+  }
+
+  void OnError(grpc_error_handle root_cert_error,
+               grpc_error_handle identity_cert_error) override {
+    GRPC_ERROR_UNREF(root_cert_error_);
+    root_cert_error_ = root_cert_error;
+    GRPC_ERROR_UNREF(identity_cert_error_);
+    identity_cert_error_ = identity_cert_error;
+  }
+
+  const absl::optional<std::string>& root_certs() const { return root_certs_; }
+
+  const absl::optional<PemKeyCertPairList>& key_cert_pairs() const {
+    return key_cert_pairs_;
+  }
+
+  grpc_error_handle root_cert_error() const { return root_cert_error_; }
+
+  grpc_error_handle identity_cert_error() const { return identity_cert_error_; }
+
+ private:
+  absl::optional<std::string> root_certs_;
+  absl::optional<PemKeyCertPairList> key_cert_pairs_;
+  grpc_error_handle root_cert_error_ = GRPC_ERROR_NONE;
+  grpc_error_handle identity_cert_error_ = GRPC_ERROR_NONE;
+};
+
+TEST(
+    XdsCertificateProviderTest,
RootCertDistributorDifferentFromIdentityCertDistributorDifferentCertNames) { + auto root_cert_distributor = + MakeRefCounted(); + auto identity_cert_distributor = + MakeRefCounted(); + XdsCertificateProvider provider; + provider.UpdateRootCertNameAndDistributor("", "root", root_cert_distributor); + provider.UpdateIdentityCertNameAndDistributor("", "identity", + identity_cert_distributor); + auto* watcher = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher), "", ""); + EXPECT_EQ(watcher->root_certs(), absl::nullopt); + EXPECT_EQ(watcher->key_cert_pairs(), absl::nullopt); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Update both root certs and identity certs + root_cert_distributor->SetKeyMaterials("root", kRootCert1, absl::nullopt); + identity_cert_distributor->SetKeyMaterials("identity", absl::nullopt, + MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for just root certs + root_cert_distributor->SetKeyMaterials( + "root", kRootCert2, + MakeKeyCertPairsType2() /* does not have an effect */); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for identity certs + identity_cert_distributor->SetKeyMaterials( + "identity", kRootCert1 /* does not have an effect */, + MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Set error for both root and identity + root_cert_distributor->SetErrorForCert( + "root", GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + identity_cert_distributor->SetErrorForCert( + "identity", absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr(kRootErrorMessage)); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update for root certs. Test that the root cert error is reset. + root_cert_distributor->SetKeyMaterials("root", kRootCert1, absl::nullopt); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update for identity certs. Test that the identity cert error is + // reset. 
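+  // (TestCertificatesWatcher::OnCertificatesChanged above drops its stored
+  // error as soon as fresh key material arrives for the watched name.)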
+ identity_cert_distributor->SetKeyMaterials("identity", absl::nullopt, + MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); +} + +TEST(XdsCertificateProviderTest, + RootCertDistributorDifferentFromIdentityCertDistributorSameCertNames) { + auto root_cert_distributor = + MakeRefCounted(); + auto identity_cert_distributor = + MakeRefCounted(); + XdsCertificateProvider provider; + provider.UpdateRootCertNameAndDistributor("", "test", root_cert_distributor); + provider.UpdateIdentityCertNameAndDistributor("", "test", + identity_cert_distributor); + auto* watcher = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher), "", ""); + EXPECT_EQ(watcher->root_certs(), absl::nullopt); + EXPECT_EQ(watcher->key_cert_pairs(), absl::nullopt); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Update both root certs and identity certs + root_cert_distributor->SetKeyMaterials("test", kRootCert1, absl::nullopt); + identity_cert_distributor->SetKeyMaterials("test", absl::nullopt, + MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for just root certs + root_cert_distributor->SetKeyMaterials("test", kRootCert2, absl::nullopt); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for identity certs + identity_cert_distributor->SetKeyMaterials("test", absl::nullopt, + MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Set error for both root and identity + root_cert_distributor->SetErrorForCert( + "test", GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + identity_cert_distributor->SetErrorForCert( + "test", absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr(kRootErrorMessage)); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update for root certs. Test that the root cert error is reset. + root_cert_distributor->SetKeyMaterials("test", kRootCert1, absl::nullopt); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update for identity certs. Test that the identity cert error is + // reset. 
+ identity_cert_distributor->SetKeyMaterials("test", absl::nullopt, + MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Test update on unwatched cert name + identity_cert_distributor->SetKeyMaterials("identity", kRootCert2, + MakeKeyCertPairsType2()); + root_cert_distributor->SetKeyMaterials("root", kRootCert1, + MakeKeyCertPairsType1()); +} + +TEST(XdsCertificateProviderTest, + RootCertDistributorSameAsIdentityCertDistributorDifferentCertNames) { + auto distributor = MakeRefCounted(); + XdsCertificateProvider provider; + provider.UpdateRootCertNameAndDistributor("", "root", distributor); + provider.UpdateIdentityCertNameAndDistributor("", "identity", distributor); + auto* watcher = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher), "", ""); + EXPECT_EQ(watcher->root_certs(), absl::nullopt); + EXPECT_EQ(watcher->key_cert_pairs(), absl::nullopt); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Update both root certs and identity certs + distributor->SetKeyMaterials("root", kRootCert1, MakeKeyCertPairsType2()); + distributor->SetKeyMaterials("identity", kRootCert2, MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for just root certs + distributor->SetKeyMaterials("root", kRootCert2, MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for identity certs + distributor->SetKeyMaterials("identity", kRootCert1, MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Set error for root + distributor->SetErrorForCert( + "root", GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage)); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr(kRootErrorMessage)); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + distributor->SetErrorForCert( + "identity", GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr(kRootErrorMessage)); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update for root + distributor->SetKeyMaterials("root", kRootCert1, MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), 
MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update for identity + distributor->SetKeyMaterials("identity", kRootCert2, MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); +} + +TEST(XdsCertificateProviderTest, + RootCertDistributorSameAsIdentityCertDistributorSameCertNames) { + auto distributor = MakeRefCounted(); + XdsCertificateProvider provider; + provider.UpdateRootCertNameAndDistributor("", "", distributor); + provider.UpdateIdentityCertNameAndDistributor("", "", distributor); + auto* watcher = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher), "", ""); + EXPECT_EQ(watcher->root_certs(), absl::nullopt); + EXPECT_EQ(watcher->key_cert_pairs(), absl::nullopt); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Update both root certs and identity certs + distributor->SetKeyMaterials("", kRootCert1, MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for just root certs + distributor->SetKeyMaterials("", kRootCert2, absl::nullopt); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Second update for identity certs + distributor->SetKeyMaterials("", absl::nullopt, MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Set error for root + distributor->SetErrorForCert( + "", GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + absl::nullopt); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr(kRootErrorMessage)); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Set error for identity + distributor->SetErrorForCert( + "", absl::nullopt, + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr(kRootErrorMessage)); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update for root + distributor->SetKeyMaterials("", kRootCert1, absl::nullopt); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an 
update for identity + distributor->SetKeyMaterials("", absl::nullopt, MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); +} + +TEST(XdsCertificateProviderTest, SwapOutDistributorsMultipleTimes) { + auto distributor = MakeRefCounted(); + distributor->SetKeyMaterials("", kRootCert1, MakeKeyCertPairsType1()); + XdsCertificateProvider provider; + auto* watcher = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher), "", ""); + // Initially there are no certificate providers. + EXPECT_EQ(watcher->root_certs(), absl::nullopt); + EXPECT_EQ(watcher->key_cert_pairs(), absl::nullopt); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for root certificates")); + EXPECT_THAT( + grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for identity certificates")); + // Update root cert distributor. + provider.UpdateRootCertNameAndDistributor("", "", distributor); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), absl::nullopt); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_THAT( + grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for identity certificates")); + // Update identity cert distributor + provider.UpdateIdentityCertNameAndDistributor("", "", distributor); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Update both root and identity certs + distributor->SetKeyMaterials("", kRootCert2, MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Set error for both root and identity + distributor->SetErrorForCert( + "", GRPC_ERROR_CREATE_FROM_STATIC_STRING(kRootErrorMessage), + GRPC_ERROR_CREATE_FROM_STATIC_STRING(kIdentityErrorMessage)); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr(kRootErrorMessage)); + EXPECT_THAT(grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr(kIdentityErrorMessage)); + // Send an update again + distributor->SetKeyMaterials("", kRootCert1, MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Remove root cert provider + provider.UpdateRootCertNameAndDistributor("", "", nullptr); + distributor->SetKeyMaterials("", kRootCert2, MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); // not updated + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr( + "No certificate provider 
available for root certificates")); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Remove identity cert provider too + provider.UpdateIdentityCertNameAndDistributor("", "", nullptr); + distributor->SetKeyMaterials("", kRootCert1, MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); // not updated + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for root certificates")); + EXPECT_THAT( + grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for identity certificates")); + // Change certificate names being watched, without any certificate updates. + provider.UpdateRootCertNameAndDistributor("", "root", distributor); + provider.UpdateIdentityCertNameAndDistributor("", "identity", distributor); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for root certificates")); + EXPECT_THAT( + grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for identity certificates")); + // Send out certificate updates. + distributor->SetKeyMaterials("root", kRootCert2, absl::nullopt); + distributor->SetKeyMaterials("identity", absl::nullopt, + MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Swap in new certificate distributors with different certificate names and + // existing updates. + auto root_cert_distributor = + MakeRefCounted(); + auto identity_cert_distributor = + MakeRefCounted(); + provider.UpdateRootCertNameAndDistributor("", "root", root_cert_distributor); + provider.UpdateIdentityCertNameAndDistributor("", "identity", + identity_cert_distributor); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Change certificate names without any certificate updates. + provider.UpdateRootCertNameAndDistributor("", "test", root_cert_distributor); + provider.UpdateIdentityCertNameAndDistributor("", "test", + identity_cert_distributor); + EXPECT_EQ(watcher->root_certs(), kRootCert2); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); + // Send out certificate updates. + root_cert_distributor->SetKeyMaterials("test", kRootCert1, + MakeKeyCertPairsType1()); + identity_cert_distributor->SetKeyMaterials("test", kRootCert2, + MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_certs(), kRootCert1); + EXPECT_EQ(watcher->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher->identity_cert_error(), GRPC_ERROR_NONE); +} + +TEST(XdsCertificateProviderTest, MultipleCertNames) { + XdsCertificateProvider provider; + // Start watch for "test1". There are no underlying distributors for + // that cert name, so it will return an error. 
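+  // Ownership of the watcher passes to the distributor through the unique_ptr
+  // below; the raw pointer is kept only to inspect the watcher's state.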
+ auto* watcher1 = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher1), "test1", "test1"); + EXPECT_EQ(watcher1->root_certs(), absl::nullopt); + EXPECT_EQ(watcher1->key_cert_pairs(), absl::nullopt); + EXPECT_THAT(grpc_error_std_string(watcher1->root_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for root certificates")); + EXPECT_THAT( + grpc_error_std_string(watcher1->identity_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for identity certificates")); + // Add distributor for "test1". This will return data to the watcher. + auto cert_distributor1 = MakeRefCounted(); + cert_distributor1->SetKeyMaterials("root", kRootCert1, absl::nullopt); + cert_distributor1->SetKeyMaterials("identity", absl::nullopt, + MakeKeyCertPairsType1()); + provider.UpdateRootCertNameAndDistributor("test1", "root", cert_distributor1); + provider.UpdateIdentityCertNameAndDistributor("test1", "identity", + cert_distributor1); + EXPECT_EQ(watcher1->root_certs(), kRootCert1); + EXPECT_EQ(watcher1->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher1->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher1->identity_cert_error(), GRPC_ERROR_NONE); + // Add distributor for "test2". + auto cert_distributor2 = MakeRefCounted(); + cert_distributor2->SetKeyMaterials("root2", kRootCert2, absl::nullopt); + cert_distributor2->SetKeyMaterials("identity2", absl::nullopt, + MakeKeyCertPairsType2()); + provider.UpdateRootCertNameAndDistributor("test2", "root2", + cert_distributor2); + provider.UpdateIdentityCertNameAndDistributor("test2", "identity2", + cert_distributor2); + // Add watcher for "test2". This one should return data immediately. + auto* watcher2 = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher2), "test2", "test2"); + EXPECT_EQ(watcher2->root_certs(), kRootCert2); + EXPECT_EQ(watcher2->key_cert_pairs(), MakeKeyCertPairsType2()); + EXPECT_EQ(watcher2->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher2->identity_cert_error(), GRPC_ERROR_NONE); + // The presence of "test2" should not affect "test1". + EXPECT_EQ(watcher1->root_certs(), kRootCert1); + EXPECT_EQ(watcher1->key_cert_pairs(), MakeKeyCertPairsType1()); + EXPECT_EQ(watcher1->root_cert_error(), GRPC_ERROR_NONE); + EXPECT_EQ(watcher1->identity_cert_error(), GRPC_ERROR_NONE); +} + +TEST(XdsCertificateProviderTest, UnknownCertName) { + XdsCertificateProvider provider; + auto* watcher = new TestCertificatesWatcher; + provider.distributor()->WatchTlsCertificates( + std::unique_ptr(watcher), "test", "test"); + EXPECT_THAT(grpc_error_std_string(watcher->root_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for root certificates")); + EXPECT_THAT( + grpc_error_std_string(watcher->identity_cert_error()), + ::testing::HasSubstr( + "No certificate provider available for identity certificates")); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/core/xds/xds_channel_stack_modifier_test.cc b/test/core/xds/xds_channel_stack_modifier_test.cc new file mode 100644 index 00000000..3518e610 --- /dev/null +++ b/test/core/xds/xds_channel_stack_modifier_test.cc @@ -0,0 +1,171 @@ +// +// +// Copyright 2021 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +#include "src/core/ext/xds/xds_channel_stack_modifier.h" + +#include + +#include + +#include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/surface/channel_init.h" +#include "src/core/lib/transport/transport_impl.h" +#include "test/core/util/test_config.h" + +namespace grpc_core { +namespace testing { +namespace { + +// Test that XdsChannelStackModifier can be safely copied to channel args +// and destroyed +TEST(XdsChannelStackModifierTest, CopyChannelArgs) { + grpc_init(); + auto server_config_selector_provider = + MakeRefCounted( + std::vector{}); + grpc_arg arg = server_config_selector_provider->MakeChannelArg(); + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + EXPECT_EQ(server_config_selector_provider, + XdsChannelStackModifier::GetFromChannelArgs(*args)); + grpc_channel_args_destroy(args); + grpc_shutdown(); +} + +// Test compare on channel args with the same XdsChannelStackModifier +TEST(XdsChannelStackModifierTest, ChannelArgsCompare) { + grpc_init(); + auto server_config_selector_provider = + MakeRefCounted( + std::vector{}); + grpc_arg arg = server_config_selector_provider->MakeChannelArg(); + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + grpc_channel_args* new_args = grpc_channel_args_copy(args); + EXPECT_EQ(XdsChannelStackModifier::GetFromChannelArgs(*new_args), + XdsChannelStackModifier::GetFromChannelArgs(*args)); + grpc_channel_args_destroy(args); + grpc_channel_args_destroy(new_args); + grpc_shutdown(); +} + +constexpr char kTestFilter1[] = "test_filter_1"; +constexpr char kTestFilter2[] = "test_filter_2"; + +// Test filters insertion +TEST(XdsChannelStackModifierTest, XdsHttpFiltersInsertion) { + CoreConfiguration::Reset(); + grpc_init(); + // Add 2 test filters to XdsChannelStackModifier + const grpc_channel_filter test_filter_1 = { + nullptr, nullptr, 0, nullptr, nullptr, nullptr, + 0, nullptr, nullptr, nullptr, kTestFilter1}; + const grpc_channel_filter test_filter_2 = { + nullptr, nullptr, 0, nullptr, nullptr, nullptr, + 0, nullptr, nullptr, nullptr, kTestFilter2}; + auto server_config_selector_provider = + MakeRefCounted( + std::vector{&test_filter_1, + &test_filter_2}); + grpc_arg arg = server_config_selector_provider->MakeChannelArg(); + // Create a phony grpc_channel_stack_builder object + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + grpc_channel_stack_builder* builder = grpc_channel_stack_builder_create(); + grpc_channel_stack_builder_set_channel_arguments(builder, args); + grpc_channel_args_destroy(args); + grpc_transport_vtable fake_transport_vtable; + memset(&fake_transport_vtable, 0, sizeof(grpc_transport_vtable)); + fake_transport_vtable.name = "fake"; + grpc_transport fake_transport = {&fake_transport_vtable}; + grpc_channel_stack_builder_set_transport(builder, &fake_transport); + // Construct channel stack and verify that the test filters were successfully + // added + 
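+  // The expected order is the "server" filter first, then the two test
+  // filters injected by XdsChannelStackModifier from the channel args.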
ASSERT_TRUE(CoreConfiguration::Get().channel_init().CreateStack( + builder, GRPC_SERVER_CHANNEL)); + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_first(builder); + ASSERT_TRUE(grpc_channel_stack_builder_move_next(it)); + ASSERT_STREQ(grpc_channel_stack_builder_iterator_filter_name(it), "server"); + ASSERT_TRUE(grpc_channel_stack_builder_move_next(it)); + ASSERT_STREQ(grpc_channel_stack_builder_iterator_filter_name(it), + kTestFilter1); + ASSERT_TRUE(grpc_channel_stack_builder_move_next(it)); + ASSERT_STREQ(grpc_channel_stack_builder_iterator_filter_name(it), + kTestFilter2); + grpc_channel_stack_builder_iterator_destroy(it); + grpc_channel_stack_builder_destroy(builder); + grpc_shutdown(); +} + +// Test filters insertion with OpenCensus plugin registered +TEST(XdsChannelStackModifierTest, XdsHttpFiltersInsertionAfterCensus) { + CoreConfiguration::Reset(); + grpc::RegisterOpenCensusPlugin(); + grpc_init(); + // Add 2 test filters to XdsChannelStackModifier + const grpc_channel_filter test_filter_1 = { + nullptr, nullptr, 0, nullptr, nullptr, nullptr, + 0, nullptr, nullptr, nullptr, kTestFilter1}; + const grpc_channel_filter test_filter_2 = { + nullptr, nullptr, 0, nullptr, nullptr, nullptr, + 0, nullptr, nullptr, nullptr, kTestFilter2}; + auto server_config_selector_provider = + MakeRefCounted( + std::vector{&test_filter_1, + &test_filter_2}); + grpc_arg arg = server_config_selector_provider->MakeChannelArg(); + // Create a phony grpc_channel_stack_builder object + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + grpc_channel_stack_builder* builder = grpc_channel_stack_builder_create(); + grpc_channel_stack_builder_set_channel_arguments(builder, args); + grpc_channel_args_destroy(args); + grpc_transport_vtable fake_transport_vtable; + memset(&fake_transport_vtable, 0, sizeof(grpc_transport_vtable)); + fake_transport_vtable.name = "fake"; + grpc_transport fake_transport = {&fake_transport_vtable}; + grpc_channel_stack_builder_set_transport(builder, &fake_transport); + // Construct channel stack and verify that the test filters were successfully + // added after the census filter + ASSERT_TRUE(CoreConfiguration::Get().channel_init().CreateStack( + builder, GRPC_SERVER_CHANNEL)); + grpc_channel_stack_builder_iterator* it = + grpc_channel_stack_builder_create_iterator_at_first(builder); + ASSERT_TRUE(grpc_channel_stack_builder_move_next(it)); + ASSERT_STREQ(grpc_channel_stack_builder_iterator_filter_name(it), "server"); + ASSERT_TRUE(grpc_channel_stack_builder_move_next(it)); + ASSERT_STREQ(grpc_channel_stack_builder_iterator_filter_name(it), + "opencensus_server"); + ASSERT_TRUE(grpc_channel_stack_builder_move_next(it)); + ASSERT_STREQ(grpc_channel_stack_builder_iterator_filter_name(it), + kTestFilter1); + ASSERT_TRUE(grpc_channel_stack_builder_move_next(it)); + ASSERT_STREQ(grpc_channel_stack_builder_iterator_filter_name(it), + kTestFilter2); + grpc_channel_stack_builder_iterator_destroy(it); + grpc_channel_stack_builder_destroy(builder); + grpc_shutdown(); +} + +} // namespace +} // namespace testing +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/client/client_channel_stress_test.cc b/test/cpp/client/client_channel_stress_test.cc new file mode 100644 index 00000000..24c91908 --- /dev/null +++ b/test/cpp/client/client_channel_stress_test.cc @@ 
-0,0 +1,351 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/proto/grpc/lb/v1/load_balancer.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +using grpc::lb::v1::LoadBalancer; +using grpc::lb::v1::LoadBalanceRequest; +using grpc::lb::v1::LoadBalanceResponse; + +namespace grpc { +namespace testing { +namespace { + +const size_t kNumBackends = 10; +const size_t kNumBalancers = 5; +const size_t kNumClientThreads = 100; +const int kResolutionUpdateIntervalMs = 50; +const int kServerlistUpdateIntervalMs = 10; +const int kTestDurationSec = 30; + +using BackendServiceImpl = TestServiceImpl; + +class BalancerServiceImpl : public LoadBalancer::Service { + public: + using Stream = ServerReaderWriter; + + explicit BalancerServiceImpl(const std::vector& all_backend_ports) + : all_backend_ports_(all_backend_ports) {} + + Status BalanceLoad(ServerContext* /*context*/, Stream* stream) override { + gpr_log(GPR_INFO, "LB[%p]: Start BalanceLoad.", this); + LoadBalanceRequest request; + stream->Read(&request); + while (!shutdown_) { + stream->Write(BuildRandomResponseForBackends()); + std::this_thread::sleep_for( + std::chrono::milliseconds(kServerlistUpdateIntervalMs)); + } + gpr_log(GPR_INFO, "LB[%p]: Finish BalanceLoad.", this); + return Status::OK; + } + + void Shutdown() { shutdown_ = true; } + + private: + std::string Ip4ToPackedString(const char* ip_str) { + struct in_addr ip4; + GPR_ASSERT(inet_pton(AF_INET, ip_str, &ip4) == 1); + return std::string(reinterpret_cast(&ip4), sizeof(ip4)); + } + + LoadBalanceResponse BuildRandomResponseForBackends() { + // Generate a random serverlist with varying size (if N = + // all_backend_ports_.size(), num_non_drop_entry is in [0, 2N], + // num_drop_entry is in [0, N]), order, duplicate, and drop rate. 
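+    // For example, with the 10 backends used in this test, num_non_drop_entry
+    // is drawn from [0, 20] and num_drop_entry from [0, 10], so a single
+    // response may repeat a backend several times and mix in drop entries.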
+ size_t num_non_drop_entry = + std::rand() % (all_backend_ports_.size() * 2 + 1); + size_t num_drop_entry = std::rand() % (all_backend_ports_.size() + 1); + std::vector random_backend_indices; + for (size_t i = 0; i < num_non_drop_entry; ++i) { + random_backend_indices.push_back(std::rand() % all_backend_ports_.size()); + } + for (size_t i = 0; i < num_drop_entry; ++i) { + random_backend_indices.push_back(-1); + } + std::shuffle(random_backend_indices.begin(), random_backend_indices.end(), + std::mt19937(std::random_device()())); + // Build the response according to the random list generated above. + LoadBalanceResponse response; + for (int index : random_backend_indices) { + auto* server = response.mutable_server_list()->add_servers(); + if (index < 0) { + server->set_drop(true); + server->set_load_balance_token("load_balancing"); + } else { + server->set_ip_address(Ip4ToPackedString("127.0.0.1")); + server->set_port(all_backend_ports_[index]); + } + } + return response; + } + + std::atomic_bool shutdown_{false}; + const std::vector all_backend_ports_; +}; + +class ClientChannelStressTest { + public: + void Run() { + Start(); + // Keep updating resolution for the test duration. + gpr_log(GPR_INFO, "Start updating resolution."); + const auto wait_duration = + std::chrono::milliseconds(kResolutionUpdateIntervalMs); + std::vector addresses; + auto start_time = std::chrono::steady_clock::now(); + while (true) { + if (std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time) + .count() > kTestDurationSec) { + break; + } + // Generate a random subset of balancers. + addresses.clear(); + for (const auto& balancer_server : balancer_servers_) { + // Select each address with probability of 0.8. + if (std::rand() % 10 < 8) { + addresses.emplace_back(AddressData{balancer_server.port_, ""}); + } + } + std::shuffle(addresses.begin(), addresses.end(), + std::mt19937(std::random_device()())); + SetNextResolution(addresses); + std::this_thread::sleep_for(wait_duration); + } + gpr_log(GPR_INFO, "Finish updating resolution."); + Shutdown(); + } + + private: + template + struct ServerThread { + explicit ServerThread(const std::string& type, + const std::string& server_host, T* service) + : type_(type), service_(service) { + grpc::internal::Mutex mu; + // We need to acquire the lock here in order to prevent the notify_one + // by ServerThread::Start from firing before the wait below is hit. + grpc::internal::MutexLock lock(&mu); + port_ = grpc_pick_unused_port_or_die(); + gpr_log(GPR_INFO, "starting %s server on port %d", type_.c_str(), port_); + grpc::internal::CondVar cond; + thread_ = absl::make_unique( + std::bind(&ServerThread::Start, this, server_host, &mu, &cond)); + cond.Wait(&mu); + gpr_log(GPR_INFO, "%s server startup complete", type_.c_str()); + } + + void Start(const std::string& server_host, grpc::internal::Mutex* mu, + grpc::internal::CondVar* cond) { + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. 
+ grpc::internal::MutexLock lock(mu); + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + builder.AddListeningPort(server_address.str(), + InsecureServerCredentials()); + builder.RegisterService(service_); + server_ = builder.BuildAndStart(); + cond->Signal(); + } + + void Shutdown() { + gpr_log(GPR_INFO, "%s about to shutdown", type_.c_str()); + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + gpr_log(GPR_INFO, "%s shutdown completed", type_.c_str()); + } + + int port_; + std::string type_; + std::unique_ptr server_; + T* service_; + std::unique_ptr thread_; + }; + + struct AddressData { + int port; + std::string balancer_name; + }; + + static grpc_core::ServerAddressList CreateAddressListFromAddressDataList( + const std::vector& address_data) { + grpc_core::ServerAddressList addresses; + for (const auto& addr : address_data) { + std::string lb_uri_str = absl::StrCat("ipv4:127.0.0.1:", addr.port); + absl::StatusOr lb_uri = grpc_core::URI::Parse(lb_uri_str); + GPR_ASSERT(lb_uri.ok()); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*lb_uri, &address)); + grpc_arg arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast(addr.balancer_name.c_str())); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(nullptr, &arg, 1); + addresses.emplace_back(address.addr, address.len, args); + } + return addresses; + } + + static grpc_core::Resolver::Result MakeResolverResult( + const std::vector& balancer_address_data) { + grpc_core::Resolver::Result result; + grpc_error_handle error = GRPC_ERROR_NONE; + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, "{\"loadBalancingConfig\":[{\"grpclb\":{}}]}", &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ServerAddressList balancer_addresses = + CreateAddressListFromAddressDataList(balancer_address_data); + grpc_arg arg = CreateGrpclbBalancerAddressesArg(&balancer_addresses); + result.args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + return result; + } + + void SetNextResolution(const std::vector& address_data) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = MakeResolverResult(address_data); + response_generator_->SetResponse(std::move(result)); + } + + void KeepSendingRequests() { + gpr_log(GPR_INFO, "Start sending requests."); + while (!shutdown_) { + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(1000)); + EchoRequest request; + request.set_message("test"); + EchoResponse response; + { + std::lock_guard lock(stub_mutex_); + Status status = stub_->Echo(&context, request, &response); + } + } + gpr_log(GPR_INFO, "Finish sending requests."); + } + + void CreateStub() { + ChannelArguments args; + response_generator_ = + grpc_core::MakeRefCounted(); + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + std::ostringstream uri; + uri << "fake:///servername_not_used"; + channel_ = ::grpc::CreateCustomChannel(uri.str(), + InsecureChannelCredentials(), args); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void Start() { + // Start the backends. + std::vector backend_ports; + for (size_t i = 0; i < kNumBackends; ++i) { + backends_.emplace_back(new BackendServiceImpl()); + backend_servers_.emplace_back(ServerThread( + "backend", server_host_, backends_.back().get())); + backend_ports.push_back(backend_servers_.back().port_); + } + // Start the load balancers. 
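+    // Each balancer gets the full list of backend ports so that it can build
+    // random serverlists over all backends.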
+ for (size_t i = 0; i < kNumBalancers; ++i) { + balancers_.emplace_back(new BalancerServiceImpl(backend_ports)); + balancer_servers_.emplace_back(ServerThread( + "balancer", server_host_, balancers_.back().get())); + } + // Start sending RPCs in multiple threads. + CreateStub(); + for (size_t i = 0; i < kNumClientThreads; ++i) { + client_threads_.emplace_back( + std::thread(&ClientChannelStressTest::KeepSendingRequests, this)); + } + } + + void Shutdown() { + shutdown_ = true; + for (size_t i = 0; i < client_threads_.size(); ++i) { + client_threads_[i].join(); + } + for (size_t i = 0; i < balancers_.size(); ++i) { + balancers_[i]->Shutdown(); + balancer_servers_[i].Shutdown(); + } + for (size_t i = 0; i < backends_.size(); ++i) { + backend_servers_[i].Shutdown(); + } + } + + std::atomic_bool shutdown_{false}; + const std::string server_host_ = "localhost"; + std::shared_ptr channel_; + std::unique_ptr stub_; + std::mutex stub_mutex_; + std::vector> backends_; + std::vector> balancers_; + std::vector> backend_servers_; + std::vector> balancer_servers_; + grpc_core::RefCountedPtr + response_generator_; + std::vector client_threads_; +}; + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::ClientChannelStressTest test; + grpc_init(); + test.Run(); + grpc_shutdown(); + return 0; +} diff --git a/test/cpp/client/credentials_test.cc b/test/cpp/client/credentials_test.cc new file mode 100644 index 00000000..cc140da2 --- /dev/null +++ b/test/cpp/client/credentials_test.cc @@ -0,0 +1,583 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/common/tls_credentials_options_util.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +namespace { + +constexpr const char* kRootCertName = "root_cert_name"; +constexpr const char* kRootCertContents = "root_cert_contents"; +constexpr const char* kIdentityCertName = "identity_cert_name"; +constexpr const char* kIdentityCertPrivateKey = "identity_private_key"; +constexpr const char* kIdentityCertContents = "identity_cert_contents"; + +using ::grpc::experimental::FileWatcherCertificateProvider; +using ::grpc::experimental::StaticDataCertificateProvider; +using ::grpc::experimental::TlsServerAuthorizationCheckArg; +using ::grpc::experimental::TlsServerAuthorizationCheckConfig; +using ::grpc::experimental::TlsServerAuthorizationCheckInterface; + +static void tls_server_authorization_check_callback( + grpc_tls_server_authorization_check_arg* arg) { + GPR_ASSERT(arg != nullptr); + std::string cb_user_data = "cb_user_data"; + arg->cb_user_data = static_cast(gpr_strdup(cb_user_data.c_str())); + arg->success = 1; + arg->target_name = gpr_strdup("callback_target_name"); + arg->peer_cert = gpr_strdup("callback_peer_cert"); + arg->status = GRPC_STATUS_OK; + arg->error_details->set_error_details("callback_error_details"); +} + +class TestTlsServerAuthorizationCheck + : public TlsServerAuthorizationCheckInterface { + int Schedule(TlsServerAuthorizationCheckArg* arg) override { + GPR_ASSERT(arg != nullptr); + std::string cb_user_data = "cb_user_data"; + arg->set_cb_user_data(static_cast(gpr_strdup(cb_user_data.c_str()))); + arg->set_success(1); + arg->set_target_name("sync_target_name"); + arg->set_peer_cert("sync_peer_cert"); + arg->set_status(GRPC_STATUS_OK); + arg->set_error_details("sync_error_details"); + return 1; + } + + void Cancel(TlsServerAuthorizationCheckArg* arg) override { + GPR_ASSERT(arg != nullptr); + arg->set_status(GRPC_STATUS_PERMISSION_DENIED); + arg->set_error_details("cancelled"); + } +}; +} // namespace + +namespace grpc { +namespace testing { +namespace { + +TEST(CredentialsTest, InvalidGoogleRefreshToken) { + std::shared_ptr bad1 = GoogleRefreshTokenCredentials(""); + EXPECT_EQ(static_cast(nullptr), bad1.get()); +} + +TEST(CredentialsTest, DefaultCredentials) { + auto creds = GoogleDefaultCredentials(); +} + +TEST(CredentialsTest, ExternalAccountCredentials) { + // url credentials + std::string url_options_string( + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"url\":\"https://foo.com:5555/" + "generate_subject_token_format_json\",\"headers\":{\"Metadata-Flavor\":" + "\"Google\"},\"format\":{\"type\":\"json\",\"subject_token_field_name\":" + "\"access_token\"}},\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"); + auto url_creds = grpc::ExternalAccountCredentials(url_options_string, + {"scope1", "scope2"}); + EXPECT_TRUE(url_creds != nullptr); + // file 
credentials + std::string file_options_string( + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"file\":\"credentials_file_path\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"); + auto file_creds = grpc::ExternalAccountCredentials(file_options_string, + {"scope1", "scope2"}); + EXPECT_TRUE(file_creds != nullptr); + // aws credentials + std::string aws_options_string( + "{\"type\":\"external_account\",\"audience\":\"audience\",\"subject_" + "token_type\":\"subject_token_type\",\"service_account_impersonation_" + "url\":\"service_account_impersonation_url\",\"token_url\":\"https://" + "foo.com:5555/token\",\"token_info_url\":\"https://foo.com:5555/" + "token_info\",\"credential_source\":{\"environment_id\":\"aws1\"," + "\"region_url\":\"https://foo.com:5555/region_url\",\"url\":\"https://" + "foo.com:5555/url\",\"regional_cred_verification_url\":\"https://" + "foo.com:5555/regional_cred_verification_url_{region}\"}," + "\"quota_project_id\":\"quota_" + "project_id\",\"client_id\":\"client_id\",\"client_secret\":\"client_" + "secret\"}"); + auto aws_creds = grpc::ExternalAccountCredentials(aws_options_string, + {"scope1", "scope2"}); + EXPECT_TRUE(aws_creds != nullptr); +} + +TEST(CredentialsTest, StsCredentialsOptionsCppToCore) { + grpc::experimental::StsCredentialsOptions options; + options.token_exchange_service_uri = "https://foo.com/exchange"; + options.resource = "resource"; + options.audience = "audience"; + options.scope = "scope"; + // options.requested_token_type explicitly not set. 
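+  // Leaving it at its (empty) default lets the EXPECT_EQ below verify that
+  // the conversion carries the default through unchanged.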
+ options.subject_token_path = "/foo/bar"; + options.subject_token_type = "nice_token_type"; + options.actor_token_path = "/foo/baz"; + options.actor_token_type = "even_nicer_token_type"; + grpc_sts_credentials_options core_opts = + grpc::experimental::StsCredentialsCppToCoreOptions(options); + EXPECT_EQ(options.token_exchange_service_uri, + core_opts.token_exchange_service_uri); + EXPECT_EQ(options.resource, core_opts.resource); + EXPECT_EQ(options.audience, core_opts.audience); + EXPECT_EQ(options.scope, core_opts.scope); + EXPECT_EQ(options.requested_token_type, core_opts.requested_token_type); + EXPECT_EQ(options.subject_token_path, core_opts.subject_token_path); + EXPECT_EQ(options.subject_token_type, core_opts.subject_token_type); + EXPECT_EQ(options.actor_token_path, core_opts.actor_token_path); + EXPECT_EQ(options.actor_token_type, core_opts.actor_token_type); +} + +TEST(CredentialsTest, StsCredentialsOptionsJson) { + const char valid_json[] = R"( + { + "token_exchange_service_uri": "https://foo/exchange", + "resource": "resource", + "audience": "audience", + "scope": "scope", + "requested_token_type": "requested_token_type", + "subject_token_path": "subject_token_path", + "subject_token_type": "subject_token_type", + "actor_token_path": "actor_token_path", + "actor_token_type": "actor_token_type" + })"; + grpc::experimental::StsCredentialsOptions options; + EXPECT_TRUE( + grpc::experimental::StsCredentialsOptionsFromJson(valid_json, &options) + .ok()); + EXPECT_EQ(options.token_exchange_service_uri, "https://foo/exchange"); + EXPECT_EQ(options.resource, "resource"); + EXPECT_EQ(options.audience, "audience"); + EXPECT_EQ(options.scope, "scope"); + EXPECT_EQ(options.requested_token_type, "requested_token_type"); + EXPECT_EQ(options.subject_token_path, "subject_token_path"); + EXPECT_EQ(options.subject_token_type, "subject_token_type"); + EXPECT_EQ(options.actor_token_path, "actor_token_path"); + EXPECT_EQ(options.actor_token_type, "actor_token_type"); + + const char minimum_valid_json[] = R"( + { + "token_exchange_service_uri": "https://foo/exchange", + "subject_token_path": "subject_token_path", + "subject_token_type": "subject_token_type" + })"; + EXPECT_TRUE(grpc::experimental::StsCredentialsOptionsFromJson( + minimum_valid_json, &options) + .ok()); + EXPECT_EQ(options.token_exchange_service_uri, "https://foo/exchange"); + EXPECT_EQ(options.resource, ""); + EXPECT_EQ(options.audience, ""); + EXPECT_EQ(options.scope, ""); + EXPECT_EQ(options.requested_token_type, ""); + EXPECT_EQ(options.subject_token_path, "subject_token_path"); + EXPECT_EQ(options.subject_token_type, "subject_token_type"); + EXPECT_EQ(options.actor_token_path, ""); + EXPECT_EQ(options.actor_token_type, ""); + + const char invalid_json[] = R"( + I'm not a valid JSON. 
+ )"; + EXPECT_EQ( + grpc::StatusCode::INVALID_ARGUMENT, + grpc::experimental::StsCredentialsOptionsFromJson(invalid_json, &options) + .error_code()); + + const char invalid_json_missing_subject_token_type[] = R"( + { + "token_exchange_service_uri": "https://foo/exchange", + "subject_token_path": "subject_token_path" + })"; + auto status = grpc::experimental::StsCredentialsOptionsFromJson( + invalid_json_missing_subject_token_type, &options); + EXPECT_EQ(grpc::StatusCode::INVALID_ARGUMENT, status.error_code()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("subject_token_type")); + + const char invalid_json_missing_subject_token_path[] = R"( + { + "token_exchange_service_uri": "https://foo/exchange", + "subject_token_type": "subject_token_type" + })"; + status = grpc::experimental::StsCredentialsOptionsFromJson( + invalid_json_missing_subject_token_path, &options); + EXPECT_EQ(grpc::StatusCode::INVALID_ARGUMENT, status.error_code()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("subject_token_path")); + + const char invalid_json_missing_token_exchange_uri[] = R"( + { + "subject_token_path": "subject_token_path", + "subject_token_type": "subject_token_type" + })"; + status = grpc::experimental::StsCredentialsOptionsFromJson( + invalid_json_missing_token_exchange_uri, &options); + EXPECT_EQ(grpc::StatusCode::INVALID_ARGUMENT, status.error_code()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("token_exchange_service_uri")); +} + +TEST(CredentialsTest, StsCredentialsOptionsFromEnv) { + // Unset env and check expected failure. + gpr_unsetenv("STS_CREDENTIALS"); + grpc::experimental::StsCredentialsOptions options; + auto status = grpc::experimental::StsCredentialsOptionsFromEnv(&options); + EXPECT_EQ(grpc::StatusCode::NOT_FOUND, status.error_code()); + + // Set env and check for success. + const char valid_json[] = R"( + { + "token_exchange_service_uri": "https://foo/exchange", + "subject_token_path": "subject_token_path", + "subject_token_type": "subject_token_type" + })"; + char* creds_file_name; + FILE* creds_file = gpr_tmpfile("sts_creds_options", &creds_file_name); + ASSERT_NE(creds_file_name, nullptr); + ASSERT_NE(creds_file, nullptr); + ASSERT_EQ(sizeof(valid_json), + fwrite(valid_json, 1, sizeof(valid_json), creds_file)); + fclose(creds_file); + gpr_setenv("STS_CREDENTIALS", creds_file_name); + gpr_free(creds_file_name); + status = grpc::experimental::StsCredentialsOptionsFromEnv(&options); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(options.token_exchange_service_uri, "https://foo/exchange"); + EXPECT_EQ(options.resource, ""); + EXPECT_EQ(options.audience, ""); + EXPECT_EQ(options.scope, ""); + EXPECT_EQ(options.requested_token_type, ""); + EXPECT_EQ(options.subject_token_path, "subject_token_path"); + EXPECT_EQ(options.subject_token_type, "subject_token_type"); + EXPECT_EQ(options.actor_token_path, ""); + EXPECT_EQ(options.actor_token_type, ""); + + // Cleanup. 
+ gpr_unsetenv("STS_CREDENTIALS"); +} + +TEST(CredentialsTest, TlsServerAuthorizationCheckArgCallback) { + grpc_tls_server_authorization_check_arg* c_arg = + new grpc_tls_server_authorization_check_arg; + c_arg->cb = tls_server_authorization_check_callback; + c_arg->context = nullptr; + c_arg->error_details = new grpc_tls_error_details(); + TlsServerAuthorizationCheckArg* arg = + new TlsServerAuthorizationCheckArg(c_arg); + arg->set_cb_user_data(nullptr); + arg->set_success(0); + arg->set_target_name("target_name"); + arg->set_peer_cert("peer_cert"); + arg->set_status(GRPC_STATUS_UNAUTHENTICATED); + arg->set_error_details("error_details"); + const char* target_name_before_callback = c_arg->target_name; + const char* peer_cert_before_callback = c_arg->peer_cert; + + arg->OnServerAuthorizationCheckDoneCallback(); + EXPECT_STREQ(static_cast(arg->cb_user_data()), "cb_user_data"); + gpr_free(arg->cb_user_data()); + EXPECT_EQ(arg->success(), 1); + EXPECT_STREQ(arg->target_name().c_str(), "callback_target_name"); + EXPECT_STREQ(arg->peer_cert().c_str(), "callback_peer_cert"); + EXPECT_EQ(arg->status(), GRPC_STATUS_OK); + EXPECT_STREQ(arg->error_details().c_str(), "callback_error_details"); + + // Cleanup. + gpr_free(const_cast(target_name_before_callback)); + gpr_free(const_cast(peer_cert_before_callback)); + gpr_free(const_cast(c_arg->target_name)); + gpr_free(const_cast(c_arg->peer_cert)); + delete c_arg->error_details; + delete arg; + delete c_arg; +} + +TEST(CredentialsTest, TlsServerAuthorizationCheckConfigSchedule) { + std::shared_ptr + test_server_authorization_check(new TestTlsServerAuthorizationCheck()); + TlsServerAuthorizationCheckConfig config(test_server_authorization_check); + grpc_tls_server_authorization_check_arg* c_arg = + new grpc_tls_server_authorization_check_arg(); + c_arg->error_details = new grpc_tls_error_details(); + c_arg->context = nullptr; + TlsServerAuthorizationCheckArg* arg = + new TlsServerAuthorizationCheckArg(c_arg); + arg->set_cb_user_data(nullptr); + arg->set_success(0); + arg->set_target_name("target_name"); + arg->set_peer_cert("peer_cert"); + arg->set_status(GRPC_STATUS_PERMISSION_DENIED); + arg->set_error_details("error_details"); + const char* target_name_before_schedule = c_arg->target_name; + const char* peer_cert_before_schedule = c_arg->peer_cert; + + int schedule_output = config.Schedule(arg); + EXPECT_EQ(schedule_output, 1); + EXPECT_STREQ(static_cast(arg->cb_user_data()), "cb_user_data"); + EXPECT_EQ(arg->success(), 1); + EXPECT_STREQ(arg->target_name().c_str(), "sync_target_name"); + EXPECT_STREQ(arg->peer_cert().c_str(), "sync_peer_cert"); + EXPECT_EQ(arg->status(), GRPC_STATUS_OK); + EXPECT_STREQ(arg->error_details().c_str(), "sync_error_details"); + + // Cleanup. 
+ gpr_free(arg->cb_user_data()); + gpr_free(const_cast(target_name_before_schedule)); + gpr_free(const_cast(peer_cert_before_schedule)); + gpr_free(const_cast(c_arg->target_name)); + gpr_free(const_cast(c_arg->peer_cert)); + delete c_arg->error_details; + if (c_arg->destroy_context != nullptr) { + c_arg->destroy_context(c_arg->context); + } + delete c_arg; +} + +TEST(CredentialsTest, TlsServerAuthorizationCheckConfigCppToC) { + std::shared_ptr + test_server_authorization_check(new TestTlsServerAuthorizationCheck()); + TlsServerAuthorizationCheckConfig config(test_server_authorization_check); + grpc_tls_server_authorization_check_arg c_arg; + c_arg.cb = tls_server_authorization_check_callback; + c_arg.cb_user_data = nullptr; + c_arg.success = 0; + c_arg.target_name = "target_name"; + c_arg.peer_cert = "peer_cert"; + c_arg.status = GRPC_STATUS_UNAUTHENTICATED; + c_arg.error_details = new grpc_tls_error_details(); + c_arg.error_details->set_error_details("error_details"); + c_arg.config = config.c_config(); + c_arg.context = nullptr; + int c_schedule_output = (c_arg.config)->Schedule(&c_arg); + EXPECT_EQ(c_schedule_output, 1); + EXPECT_STREQ(static_cast(c_arg.cb_user_data), "cb_user_data"); + EXPECT_EQ(c_arg.success, 1); + EXPECT_STREQ(c_arg.target_name, "sync_target_name"); + EXPECT_STREQ(c_arg.peer_cert, "sync_peer_cert"); + EXPECT_EQ(c_arg.status, GRPC_STATUS_OK); + EXPECT_STREQ(c_arg.error_details->error_details().c_str(), + "sync_error_details"); + + // Cleanup. + gpr_free(c_arg.cb_user_data); + c_arg.destroy_context(c_arg.context); + delete c_arg.error_details; + gpr_free(const_cast(c_arg.target_name)); + gpr_free(const_cast(c_arg.peer_cert)); +} + +TEST(CredentialsTest, TlsChannelCredentialsWithDefaultRoots) { + grpc::experimental::TlsChannelCredentialsOptions options; + options.set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto test_server_authorization_check = + std::make_shared(); + auto server_authorization_check_config = + std::make_shared( + test_server_authorization_check); + options.set_server_authorization_check_config( + server_authorization_check_config); + auto channel_credentials = grpc::experimental::TlsCredentials(options); + GPR_ASSERT(channel_credentials.get() != nullptr); +} + +TEST( + CredentialsTest, + TlsChannelCredentialsWithStaticDataCertificateProviderLoadingRootAndIdentity) { + experimental::IdentityKeyCertPair key_cert_pair; + key_cert_pair.private_key = kIdentityCertPrivateKey; + key_cert_pair.certificate_chain = kIdentityCertContents; + std::vector identity_key_cert_pairs; + identity_key_cert_pairs.emplace_back(key_cert_pair); + auto certificate_provider = std::make_shared( + kRootCertContents, identity_key_cert_pairs); + auto test_server_authorization_check = + std::make_shared(); + auto server_authorization_check_config = + std::make_shared( + test_server_authorization_check); + grpc::experimental::TlsChannelCredentialsOptions options; + options.set_certificate_provider(certificate_provider); + options.watch_root_certs(); + options.set_root_cert_name(kRootCertName); + options.watch_identity_key_cert_pairs(); + options.set_identity_cert_name(kIdentityCertName); + options.set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + options.set_server_authorization_check_config( + server_authorization_check_config); + auto channel_credentials = grpc::experimental::TlsCredentials(options); + GPR_ASSERT(channel_credentials.get() != nullptr); +} + +TEST(CredentialsTest, + TlsChannelCredentialsWithStaticDataCertificateProviderLoadingRootOnly) 
{ + auto certificate_provider = + std::make_shared(kRootCertContents); + auto test_server_authorization_check = + std::make_shared(); + auto server_authorization_check_config = + std::make_shared( + test_server_authorization_check); + GPR_ASSERT(certificate_provider != nullptr); + GPR_ASSERT(certificate_provider->c_provider() != nullptr); + grpc::experimental::TlsChannelCredentialsOptions options; + options.set_certificate_provider(certificate_provider); + options.watch_root_certs(); + options.set_root_cert_name(kRootCertName); + options.set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + options.set_server_authorization_check_config( + server_authorization_check_config); + auto channel_credentials = grpc::experimental::TlsCredentials(options); + GPR_ASSERT(channel_credentials.get() != nullptr); +} + +TEST( + CredentialsTest, + TlsChannelCredentialsWithDefaultRootsAndStaticDataCertificateProviderLoadingIdentityOnly) { + experimental::IdentityKeyCertPair key_cert_pair; + key_cert_pair.private_key = kIdentityCertPrivateKey; + key_cert_pair.certificate_chain = kIdentityCertContents; + std::vector identity_key_cert_pairs; + identity_key_cert_pairs.emplace_back(key_cert_pair); + auto certificate_provider = + std::make_shared(identity_key_cert_pairs); + auto test_server_authorization_check = + std::make_shared(); + auto server_authorization_check_config = + std::make_shared( + test_server_authorization_check); + grpc::experimental::TlsChannelCredentialsOptions options; + options.set_certificate_provider(certificate_provider); + options.watch_identity_key_cert_pairs(); + options.set_identity_cert_name(kIdentityCertName); + options.set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + options.set_server_authorization_check_config( + server_authorization_check_config); + auto channel_credentials = grpc::experimental::TlsCredentials(options); + GPR_ASSERT(channel_credentials.get() != nullptr); +} + +TEST( + CredentialsTest, + TlsChannelCredentialsWithFileWatcherCertificateProviderLoadingRootAndIdentity) { + auto certificate_provider = std::make_shared( + SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH, 1); + grpc::experimental::TlsChannelCredentialsOptions options; + options.set_certificate_provider(certificate_provider); + options.watch_root_certs(); + options.set_root_cert_name(kRootCertName); + options.watch_identity_key_cert_pairs(); + options.set_identity_cert_name(kIdentityCertName); + options.set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto test_server_authorization_check = + std::make_shared(); + auto server_authorization_check_config = + std::make_shared( + test_server_authorization_check); + options.set_server_authorization_check_config( + server_authorization_check_config); + auto channel_credentials = grpc::experimental::TlsCredentials(options); + GPR_ASSERT(channel_credentials.get() != nullptr); +} + +TEST(CredentialsTest, + TlsChannelCredentialsWithFileWatcherCertificateProviderLoadingRootOnly) { + auto certificate_provider = + std::make_shared(CA_CERT_PATH, 1); + grpc::experimental::TlsChannelCredentialsOptions options; + options.set_certificate_provider(certificate_provider); + options.watch_root_certs(); + options.set_root_cert_name(kRootCertName); + options.set_server_verification_option(GRPC_TLS_SERVER_VERIFICATION); + auto test_server_authorization_check = + std::make_shared(); + auto server_authorization_check_config = + std::make_shared( + test_server_authorization_check); + options.set_server_authorization_check_config( + 
server_authorization_check_config); + auto channel_credentials = grpc::experimental::TlsCredentials(options); + GPR_ASSERT(channel_credentials.get() != nullptr); +} + +TEST(CredentialsTest, TlsServerAuthorizationCheckConfigErrorMessages) { + std::shared_ptr config( + new TlsServerAuthorizationCheckConfig(nullptr)); + grpc_tls_server_authorization_check_arg* c_arg = + new grpc_tls_server_authorization_check_arg; + c_arg->error_details = new grpc_tls_error_details(); + c_arg->context = nullptr; + TlsServerAuthorizationCheckArg* arg = + new TlsServerAuthorizationCheckArg(c_arg); + int schedule_output = config->Schedule(arg); + + EXPECT_EQ(schedule_output, 1); + EXPECT_EQ(arg->status(), GRPC_STATUS_NOT_FOUND); + EXPECT_STREQ( + arg->error_details().c_str(), + "the interface of the server authorization check config is nullptr"); + + arg->set_status(GRPC_STATUS_OK); + config->Cancel(arg); + EXPECT_EQ(arg->status(), GRPC_STATUS_NOT_FOUND); + EXPECT_STREQ( + arg->error_details().c_str(), + "the interface of the server authorization check config is nullptr"); + + // Cleanup. + delete c_arg->error_details; + if (c_arg->destroy_context != nullptr) { + c_arg->destroy_context(c_arg->context); + } + delete c_arg; +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/client/destroy_grpclb_channel_with_active_connect_stress_test.cc b/test/cpp/client/destroy_grpclb_channel_with_active_connect_stress_test.cc new file mode 100644 index 00000000..7b5fa992 --- /dev/null +++ b/test/cpp/client/destroy_grpclb_channel_with_active_connect_stress_test.cc @@ -0,0 +1,125 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace { + +void TryConnectAndDestroy() { + auto response_generator = + grpc_core::MakeRefCounted(); + // Return a grpclb address with an IP address on the IPv6 discard prefix + // (https://tools.ietf.org/html/rfc6666). This is important because + // the behavior we want in this test is for a TCP connect attempt to "freeze", + // i.e. we want to send SYN, and then *not* receive SYN-ACK or RST. 
+ // The precise behavior is dependant on the test runtime environment though, + // since connect() attempts on this address may unfortunately result in + // "network unreachable" errors in some test runtime environments. + absl::StatusOr lb_uri = + grpc_core::URI::Parse("ipv6:[0100::1234]:443"); + ASSERT_TRUE(lb_uri.ok()); + grpc_resolved_address address; + ASSERT_TRUE(grpc_parse_uri(*lb_uri, &address)); + grpc_core::ServerAddressList addresses; + addresses.emplace_back(address.addr, address.len, nullptr); + grpc_core::Resolver::Result lb_address_result; + grpc_error_handle error = GRPC_ERROR_NONE; + lb_address_result.service_config = grpc_core::ServiceConfig::Create( + nullptr, "{\"loadBalancingConfig\":[{\"grpclb\":{}}]}", &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + grpc_arg arg = grpc_core::CreateGrpclbBalancerAddressesArg(&addresses); + lb_address_result.args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + response_generator->SetResponse(lb_address_result); + grpc::ChannelArguments args; + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator.get()); + // Explicitly set the connect deadline to the same amount of + // time as the WaitForConnected time. The goal is to get the + // connect timeout code to run at about the same time as when + // the channel gets destroyed, to try to reproduce a race. + args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", + grpc_test_slowdown_factor() * 100); + std::ostringstream uri; + uri << "fake:///servername_not_used"; + auto channel = ::grpc::CreateCustomChannel( + uri.str(), grpc::InsecureChannelCredentials(), args); + // Start connecting, and give some time for the TCP connection attempt to the + // unreachable balancer to begin. The connection should never become ready + // because the LB we're trying to connect to is unreachable. + channel->GetState(true /* try_to_connect */); + ASSERT_FALSE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(100))); + ASSERT_EQ("grpclb", channel->GetLoadBalancingPolicyName()); + channel.reset(); +}; + +TEST(DestroyGrpclbChannelWithActiveConnectStressTest, + LoopTryConnectAndDestroy) { + grpc_init(); + std::vector> threads; + // 100 is picked for number of threads just + // because it's enough to reproduce a certain crash almost 100% + // at this time of writing. + const int kNumThreads = 100; + threads.reserve(kNumThreads); + for (int i = 0; i < kNumThreads; i++) { + threads.emplace_back(new std::thread(TryConnectAndDestroy)); + } + for (size_t i = 0; i < threads.size(); i++) { + threads[i]->join(); + } + grpc_shutdown(); +} + +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/codegen/codegen_test_full.cc b/test/cpp/codegen/codegen_test_full.cc new file mode 100644 index 00000000..b0b46b63 --- /dev/null +++ b/test/cpp/codegen/codegen_test_full.cc @@ -0,0 +1,46 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +class CodegenTestFull : public ::testing::Test {}; + +TEST_F(CodegenTestFull, Init) { + grpc::CompletionQueue cq; + void* tag = nullptr; + bool ok = false; + cq.AsyncNext(&tag, &ok, gpr_time_0(GPR_CLOCK_REALTIME)); + ASSERT_FALSE(ok); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/codegen/codegen_test_minimal.cc b/test/cpp/codegen/codegen_test_minimal.cc new file mode 100644 index 00000000..60457308 --- /dev/null +++ b/test/cpp/codegen/codegen_test_minimal.cc @@ -0,0 +1,37 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +class CodegenTestMinimal : public ::testing::Test {}; + +TEST_F(CodegenTestMinimal, Build) {} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/codegen/golden_file_test.cc b/test/cpp/codegen/golden_file_test.cc new file mode 100644 index 00000000..9af017e2 --- /dev/null +++ b/test/cpp/codegen/golden_file_test.cc @@ -0,0 +1,78 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include "absl/flags/flag.h" + +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG( + std::string, generated_file_path, "", + "path to the directory containing generated files compiler_test.grpc.pb.h " + "and compiler_test_mock.grpc.pb.h"); + +const char kGoldenFilePath[] = "test/cpp/codegen/compiler_test_golden"; +const char kMockGoldenFilePath[] = "test/cpp/codegen/compiler_test_mock_golden"; + +void run_test(const std::basic_string& generated_file, + const std::basic_string& golden_file) { + std::ifstream generated(generated_file); + std::ifstream golden(golden_file); + + ASSERT_TRUE(generated.good()); + ASSERT_TRUE(golden.good()); + + std::ostringstream gen_oss; + std::ostringstream gold_oss; + gen_oss << generated.rdbuf(); + gold_oss << golden.rdbuf(); + EXPECT_EQ(gold_oss.str(), gen_oss.str()); + + generated.close(); + golden.close(); +} + +TEST(GoldenFileTest, TestGeneratedFile) { + run_test(absl::GetFlag(FLAGS_generated_file_path) + "compiler_test.grpc.pb.h", + kGoldenFilePath); +} + +TEST(GoldenMockFileTest, TestGeneratedMockFile) { + run_test( + absl::GetFlag(FLAGS_generated_file_path) + "compiler_test_mock.grpc.pb.h", + kMockGoldenFilePath); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + if (absl::GetFlag(FLAGS_generated_file_path).empty()) { + absl::SetFlag(&FLAGS_generated_file_path, "gens/src/proto/grpc/testing/"); + } + if (absl::GetFlag(FLAGS_generated_file_path).back() != '/') { + absl::SetFlag(&FLAGS_generated_file_path, + absl::GetFlag(FLAGS_generated_file_path).append("/")); + } + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/codegen/proto_utils_test.cc b/test/cpp/codegen/proto_utils_test.cc new file mode 100644 index 00000000..da2ec637 --- /dev/null +++ b/test/cpp/codegen/proto_utils_test.cc @@ -0,0 +1,196 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { + +namespace internal { + +// Provide access to ProtoBufferWriter internals. +class ProtoBufferWriterPeer { + public: + explicit ProtoBufferWriterPeer(ProtoBufferWriter* writer) : writer_(writer) {} + bool have_backup() const { return writer_->have_backup_; } + const grpc_slice& backup_slice() const { return writer_->backup_slice_; } + const grpc_slice& slice() const { return writer_->slice_; } + + private: + ProtoBufferWriter* writer_; +}; + +// Provide access to ByteBuffer internals. 
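+// The tests below use this to reach the underlying grpc_byte_buffer so that
+// data written through ProtoBufferWriter can be read back with the C-core
+// grpc_byte_buffer_reader API.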
+class GrpcByteBufferPeer { + public: + explicit GrpcByteBufferPeer(ByteBuffer* bb) : bb_(bb) {} + grpc_byte_buffer* c_buffer() { return bb_->c_buffer(); } + + private: + ByteBuffer* bb_; +}; + +class ProtoUtilsTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + // Ensure the ProtoBufferWriter internals are initialized. + grpc::internal::GrpcLibraryInitializer init; + init.summon(); + grpc::GrpcLibraryCodegen lib; + grpc_init(); + } + + static void TearDownTestCase() { grpc_shutdown(); } +}; + +// Regression test for a memory corruption bug where a series of +// ProtoBufferWriter Next()/Backup() invocations could result in a dangling +// pointer returned by Next() due to the interaction between grpc_slice inlining +// and GRPC_SLICE_START_PTR. +TEST_F(ProtoUtilsTest, TinyBackupThenNext) { + ByteBuffer bp; + const int block_size = 1024; + ProtoBufferWriter writer(&bp, block_size, 8192); + ProtoBufferWriterPeer peer(&writer); + + void* data; + int size; + // Allocate a slice. + ASSERT_TRUE(writer.Next(&data, &size)); + EXPECT_EQ(block_size, size); + // Return a single byte. + writer.BackUp(1); + EXPECT_FALSE(peer.have_backup()); + // On the next allocation, the returned slice is non-inlined. + ASSERT_TRUE(writer.Next(&data, &size)); + EXPECT_TRUE(peer.slice().refcount != nullptr); + EXPECT_EQ(block_size, size); +} + +namespace { + +// Set backup_size to 0 to indicate no backup is needed. +void BufferWriterTest(int block_size, int total_size, int backup_size) { + ByteBuffer bb; + ProtoBufferWriter writer(&bb, block_size, total_size); + + int written_size = 0; + void* data; + int size = 0; + bool backed_up_entire_slice = false; + + while (written_size < total_size) { + EXPECT_TRUE(writer.Next(&data, &size)); + EXPECT_GT(size, 0); + EXPECT_TRUE(data); + int write_size = size; + bool should_backup = false; + if (backup_size > 0 && size > backup_size) { + write_size = size - backup_size; + should_backup = true; + } else if (size == backup_size && !backed_up_entire_slice) { + // only backup entire slice once. + backed_up_entire_slice = true; + should_backup = true; + write_size = 0; + } + // May need a last backup. + if (write_size + written_size > total_size) { + write_size = total_size - written_size; + should_backup = true; + backup_size = size - write_size; + ASSERT_GT(backup_size, 0); + } + for (int i = 0; i < write_size; i++) { + (static_cast(data))[i] = written_size % 128; + written_size++; + } + if (should_backup) { + writer.BackUp(backup_size); + } + } + EXPECT_EQ(bb.Length(), (size_t)total_size); + + grpc_byte_buffer_reader reader; + GrpcByteBufferPeer peer(&bb); + grpc_byte_buffer_reader_init(&reader, peer.c_buffer()); + int read_bytes = 0; + while (read_bytes < total_size) { + grpc_slice s; + EXPECT_TRUE(grpc_byte_buffer_reader_next(&reader, &s)); + for (size_t i = 0; i < GRPC_SLICE_LENGTH(s); i++) { + EXPECT_EQ(GRPC_SLICE_START_PTR(s)[i], read_bytes % 128); + read_bytes++; + } + grpc_slice_unref(s); + } + EXPECT_EQ(read_bytes, total_size); + grpc_byte_buffer_reader_destroy(&reader); +} + +class WriterTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + grpc::internal::GrpcLibraryInitializer init; + init.summon(); + grpc::GrpcLibraryCodegen lib; + // Ensure the ProtoBufferWriter internals are initialized. 
+ grpc_init(); + } + + static void TearDownTestCase() { grpc_shutdown(); } +}; + +TEST_F(WriterTest, TinyBlockTinyBackup) { + for (int i = 2; i < static_cast GRPC_SLICE_INLINED_SIZE; i++) { + BufferWriterTest(i, 256, 1); + } +} + +TEST_F(WriterTest, SmallBlockTinyBackup) { BufferWriterTest(64, 256, 1); } + +TEST_F(WriterTest, SmallBlockNoBackup) { BufferWriterTest(64, 256, 0); } + +TEST_F(WriterTest, SmallBlockFullBackup) { BufferWriterTest(64, 256, 64); } + +TEST_F(WriterTest, LargeBlockTinyBackup) { BufferWriterTest(4096, 8192, 1); } + +TEST_F(WriterTest, LargeBlockNoBackup) { BufferWriterTest(4096, 8192, 0); } + +TEST_F(WriterTest, LargeBlockFullBackup) { BufferWriterTest(4096, 8192, 4096); } + +TEST_F(WriterTest, LargeBlockLargeBackup) { + BufferWriterTest(4096, 8192, 4095); +} + +} // namespace +} // namespace internal +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/common/alarm_test.cc b/test/cpp/common/alarm_test.cc new file mode 100644 index 00000000..4a347f13 --- /dev/null +++ b/test/cpp/common/alarm_test.cc @@ -0,0 +1,373 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +TEST(AlarmTest, RegularExpiry) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm alarm; + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(1), junk); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, RegularExpiryMultiSet) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm alarm; + + for (int i = 0; i < 3; i++) { + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(1), junk); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); + } +} + +TEST(AlarmTest, RegularExpiryMultiSetMultiCQ) { + void* junk = reinterpret_cast(1618033); + Alarm alarm; + + for (int i = 0; i < 3; i++) { + CompletionQueue cq; + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(1), junk); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); + } +} + +struct Completion { + bool completed = false; + std::mutex mu; + std::condition_variable cv; +}; + +TEST(AlarmTest, CallbackRegularExpiry) { + Alarm alarm; + + auto c = std::make_shared(); + alarm.Set(std::chrono::system_clock::now() + std::chrono::seconds(1), + [c](bool ok) { + EXPECT_TRUE(ok); + std::lock_guard l(c->mu); + c->completed = true; + c->cv.notify_one(); + }); + + std::unique_lock l(c->mu); + EXPECT_TRUE(c->cv.wait_until( + l, std::chrono::system_clock::now() + std::chrono::seconds(10), + [c] { return c->completed; })); +} + +TEST(AlarmTest, CallbackZeroExpiry) { + Alarm alarm; + + auto c = std::make_shared(); + alarm.Set(grpc_timeout_seconds_to_deadline(0), [c](bool ok) { + EXPECT_TRUE(ok); + std::lock_guard l(c->mu); + c->completed = true; + c->cv.notify_one(); + }); + + std::unique_lock l(c->mu); + EXPECT_TRUE(c->cv.wait_until( + l, std::chrono::system_clock::now() + std::chrono::seconds(10), + [c] { return c->completed; })); +} + +TEST(AlarmTest, CallbackNegativeExpiry) { + Alarm alarm; + + auto c = std::make_shared(); + alarm.Set(std::chrono::system_clock::now() + std::chrono::seconds(-1), + [c](bool ok) { + EXPECT_TRUE(ok); + std::lock_guard l(c->mu); + c->completed = true; + c->cv.notify_one(); + }); + + std::unique_lock l(c->mu); + EXPECT_TRUE(c->cv.wait_until( + l, std::chrono::system_clock::now() + std::chrono::seconds(10), + [c] { return c->completed; })); +} + +TEST(AlarmTest, MultithreadedRegularExpiry) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + void* output_tag; + bool ok; + CompletionQueue::NextStatus status; + Alarm alarm; + + std::thread t1([&alarm, &cq, &junk] { + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(1), junk); + }); + + std::thread t2([&cq, &ok, &output_tag, &status] { + status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + }); + + t1.join(); + t2.join(); + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + 
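+// The following test exercises the deprecated Alarm constructor, which takes
+// the completion queue, deadline, and tag directly instead of a later Set().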
+TEST(AlarmTest, DeprecatedRegularExpiry) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm alarm(&cq, grpc_timeout_seconds_to_deadline(1), junk); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, MoveConstructor) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm first; + first.Set(&cq, grpc_timeout_seconds_to_deadline(1), junk); + Alarm second(std::move(first)); + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, MoveAssignment) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm first; + first.Set(&cq, grpc_timeout_seconds_to_deadline(1), junk); + Alarm second(std::move(first)); + first = std::move(second); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, RegularExpiryChrono) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + std::chrono::system_clock::time_point one_sec_deadline = + std::chrono::system_clock::now() + std::chrono::seconds(1); + Alarm alarm; + alarm.Set(&cq, one_sec_deadline, junk); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(10)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, ZeroExpiry) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm alarm; + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(0), junk); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(1)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, NegativeExpiry) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm alarm; + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(-1), junk); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(1)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_TRUE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, Cancellation) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + Alarm alarm; + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(10), junk); + alarm.Cancel(); + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(1)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_FALSE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, CallbackCancellation) { + Alarm alarm; + + auto c = std::make_shared(); + alarm.Set(std::chrono::system_clock::now() + std::chrono::seconds(10), + [c](bool ok) { + EXPECT_FALSE(ok); + std::lock_guard l(c->mu); + c->completed = true; + c->cv.notify_one(); + }); + alarm.Cancel(); + + std::unique_lock l(c->mu); + 
EXPECT_TRUE(c->cv.wait_until( + l, std::chrono::system_clock::now() + std::chrono::seconds(1), + [c] { return c->completed; })); +} + +TEST(AlarmTest, CallbackCancellationLocked) { + Alarm alarm; + + auto c = std::make_shared(); + alarm.Set(std::chrono::system_clock::now() + std::chrono::seconds(10), + [c](bool ok) { + EXPECT_FALSE(ok); + std::lock_guard l(c->mu); + c->completed = true; + c->cv.notify_one(); + }); + std::unique_lock l(c->mu); + alarm.Cancel(); + + EXPECT_TRUE(c->cv.wait_until( + l, std::chrono::system_clock::now() + std::chrono::seconds(1), + [c] { return c->completed; })); +} + +TEST(AlarmTest, SetDestruction) { + CompletionQueue cq; + void* junk = reinterpret_cast(1618033); + { + Alarm alarm; + alarm.Set(&cq, grpc_timeout_seconds_to_deadline(10), junk); + } + + void* output_tag; + bool ok; + const CompletionQueue::NextStatus status = + cq.AsyncNext(&output_tag, &ok, grpc_timeout_seconds_to_deadline(1)); + + EXPECT_EQ(status, CompletionQueue::GOT_EVENT); + EXPECT_FALSE(ok); + EXPECT_EQ(junk, output_tag); +} + +TEST(AlarmTest, CallbackSetDestruction) { + auto c = std::make_shared(); + { + Alarm alarm; + alarm.Set(std::chrono::system_clock::now() + std::chrono::seconds(10), + [c](bool ok) { + EXPECT_FALSE(ok); + std::lock_guard l(c->mu); + c->completed = true; + c->cv.notify_one(); + }); + } + + std::unique_lock l(c->mu); + EXPECT_TRUE(c->cv.wait_until( + l, std::chrono::system_clock::now() + std::chrono::seconds(1), + [c] { return c->completed; })); +} + +TEST(AlarmTest, UnsetDestruction) { + CompletionQueue cq; + Alarm alarm; +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/common/alts_util_test.cc b/test/cpp/common/alts_util_test.cc new file mode 100644 index 00000000..f22385aa --- /dev/null +++ b/test/cpp/common/alts_util_test.cc @@ -0,0 +1,219 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "upb/upb.hpp" + +#include +#include +#include + +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/cpp/common/secure_auth_context.h" +#include "src/proto/grpc/gcp/altscontext.upb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/string_ref_helper.h" + +namespace grpc { +namespace { + +TEST(AltsUtilTest, NullAuthContext) { + std::unique_ptr alts_context = + experimental::GetAltsContextFromAuthContext(nullptr); + EXPECT_EQ(alts_context, nullptr); +} + +TEST(AltsUtilTest, EmptyAuthContext) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + const std::shared_ptr auth_context( + new SecureAuthContext(ctx.get())); + std::unique_ptr alts_context = + experimental::GetAltsContextFromAuthContext(auth_context); + EXPECT_EQ(alts_context, nullptr); +} + +TEST(AltsUtilTest, AuthContextWithMoreThanOneAltsContext) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + const std::shared_ptr auth_context( + new SecureAuthContext(ctx.get())); + ctx.reset(); + auth_context->AddProperty(TSI_ALTS_CONTEXT, "context1"); + auth_context->AddProperty(TSI_ALTS_CONTEXT, "context2"); + std::unique_ptr alts_context = + experimental::GetAltsContextFromAuthContext(auth_context); + EXPECT_EQ(alts_context, nullptr); +} + +TEST(AltsUtilTest, AuthContextWithBadAltsContext) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + const std::shared_ptr auth_context( + new SecureAuthContext(ctx.get())); + ctx.reset(); + auth_context->AddProperty(TSI_ALTS_CONTEXT, + "bad context string serialization"); + std::unique_ptr alts_context = + experimental::GetAltsContextFromAuthContext(auth_context); + EXPECT_EQ(alts_context, nullptr); +} + +TEST(AltsUtilTest, AuthContextWithGoodAltsContextWithoutRpcVersions) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + const std::shared_ptr auth_context( + new SecureAuthContext(ctx.get())); + ctx.reset(); + std::string expected_ap("application protocol"); + std::string expected_rp("record protocol"); + std::string expected_peer("peer"); + std::string expected_local("local"); + std::string expected_peer_atrributes_key("peer"); + std::string expected_peer_atrributes_value("attributes"); + grpc_security_level expected_sl = GRPC_INTEGRITY_ONLY; + upb::Arena context_arena; + grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr()); + grpc_gcp_AltsContext_set_application_protocol( + context, upb_strview_make(expected_ap.data(), expected_ap.length())); + grpc_gcp_AltsContext_set_record_protocol( + context, upb_strview_make(expected_rp.data(), expected_rp.length())); + grpc_gcp_AltsContext_set_security_level(context, expected_sl); + grpc_gcp_AltsContext_set_peer_service_account( + context, upb_strview_make(expected_peer.data(), expected_peer.length())); + grpc_gcp_AltsContext_set_local_service_account( + context, + upb_strview_make(expected_local.data(), expected_local.length())); + grpc_gcp_AltsContext_peer_attributes_set( + context, + upb_strview_make(expected_peer_atrributes_key.data(), + expected_peer_atrributes_key.length()), + upb_strview_make(expected_peer_atrributes_value.data(), + expected_peer_atrributes_value.length()), + context_arena.ptr()); + size_t serialized_ctx_length; + char* serialized_ctx = grpc_gcp_AltsContext_serialize( + context, context_arena.ptr(), &serialized_ctx_length); + EXPECT_NE(serialized_ctx, nullptr); + auth_context->AddProperty(TSI_ALTS_CONTEXT, + string(serialized_ctx, 
serialized_ctx_length)); + std::unique_ptr alts_context = + experimental::GetAltsContextFromAuthContext(auth_context); + EXPECT_NE(alts_context, nullptr); + EXPECT_EQ(expected_ap, alts_context->application_protocol()); + EXPECT_EQ(expected_rp, alts_context->record_protocol()); + EXPECT_EQ(expected_peer, alts_context->peer_service_account()); + EXPECT_EQ(expected_local, alts_context->local_service_account()); + EXPECT_EQ(expected_sl, alts_context->security_level()); + // all rpc versions should be 0 if not set + experimental::AltsContext::RpcProtocolVersions rpc_protocol_versions = + alts_context->peer_rpc_versions(); + EXPECT_EQ(0, rpc_protocol_versions.max_rpc_version.major_version); + EXPECT_EQ(0, rpc_protocol_versions.max_rpc_version.minor_version); + EXPECT_EQ(0, rpc_protocol_versions.min_rpc_version.major_version); + EXPECT_EQ(0, rpc_protocol_versions.min_rpc_version.minor_version); + EXPECT_EQ(expected_peer_atrributes_value, + alts_context->peer_attributes().at(expected_peer_atrributes_key)); +} + +TEST(AltsUtilTest, AuthContextWithGoodAltsContext) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + const std::shared_ptr auth_context( + new SecureAuthContext(ctx.get())); + ctx.reset(); + upb::Arena context_arena; + grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr()); + upb::Arena versions_arena; + grpc_gcp_RpcProtocolVersions* versions = + grpc_gcp_RpcProtocolVersions_new(versions_arena.ptr()); + upb::Arena max_major_version_arena; + grpc_gcp_RpcProtocolVersions_Version* version = + grpc_gcp_RpcProtocolVersions_Version_new(max_major_version_arena.ptr()); + grpc_gcp_RpcProtocolVersions_Version_set_major(version, 10); + grpc_gcp_RpcProtocolVersions_set_max_rpc_version(versions, version); + grpc_gcp_AltsContext_set_peer_rpc_versions(context, versions); + size_t serialized_ctx_length; + char* serialized_ctx = grpc_gcp_AltsContext_serialize( + context, context_arena.ptr(), &serialized_ctx_length); + EXPECT_NE(serialized_ctx, nullptr); + auth_context->AddProperty(TSI_ALTS_CONTEXT, + string(serialized_ctx, serialized_ctx_length)); + std::unique_ptr alts_context = + experimental::GetAltsContextFromAuthContext(auth_context); + EXPECT_NE(alts_context, nullptr); + EXPECT_EQ("", alts_context->application_protocol()); + EXPECT_EQ("", alts_context->record_protocol()); + EXPECT_EQ("", alts_context->peer_service_account()); + EXPECT_EQ("", alts_context->local_service_account()); + EXPECT_EQ(GRPC_SECURITY_NONE, alts_context->security_level()); + experimental::AltsContext::RpcProtocolVersions rpc_protocol_versions = + alts_context->peer_rpc_versions(); + EXPECT_EQ(10, rpc_protocol_versions.max_rpc_version.major_version); + EXPECT_EQ(0, rpc_protocol_versions.max_rpc_version.minor_version); + EXPECT_EQ(0, rpc_protocol_versions.min_rpc_version.major_version); + EXPECT_EQ(0, rpc_protocol_versions.min_rpc_version.minor_version); +} + +TEST(AltsUtilTest, AltsClientAuthzCheck) { + // AltsClientAuthzCheck function should return a permission denied error on + // the bad_auth_context, whose internal ALTS context does not exist + const std::shared_ptr bad_auth_context( + new SecureAuthContext(nullptr)); + std::vector service_accounts{"client"}; + grpc::Status status = + experimental::AltsClientAuthzCheck(bad_auth_context, service_accounts); + EXPECT_EQ(grpc::StatusCode::PERMISSION_DENIED, status.error_code()); + // AltsClientAuthzCheck function should function normally when the peer name + // in ALTS context is listed in service_accounts + grpc_core::RefCountedPtr 
ctx = + grpc_core::MakeRefCounted(nullptr); + const std::shared_ptr auth_context( + new SecureAuthContext(ctx.get())); + ctx.reset(); + std::string peer("good_client"); + std::vector good_service_accounts{"good_client", + "good_client_1"}; + std::vector bad_service_accounts{"bad_client", "bad_client_1"}; + upb::Arena context_arena; + grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr()); + grpc_gcp_AltsContext_set_peer_service_account( + context, upb_strview_make(peer.data(), peer.length())); + size_t serialized_ctx_length; + char* serialized_ctx = grpc_gcp_AltsContext_serialize( + context, context_arena.ptr(), &serialized_ctx_length); + EXPECT_NE(serialized_ctx, nullptr); + auth_context->AddProperty(TSI_ALTS_CONTEXT, + string(serialized_ctx, serialized_ctx_length)); + grpc::Status good_status = + experimental::AltsClientAuthzCheck(auth_context, good_service_accounts); + EXPECT_TRUE(good_status.ok()); + grpc::Status bad_status = + experimental::AltsClientAuthzCheck(auth_context, bad_service_accounts); + EXPECT_EQ(grpc::StatusCode::PERMISSION_DENIED, bad_status.error_code()); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/common/auth_property_iterator_test.cc b/test/cpp/common/auth_property_iterator_test.cc new file mode 100644 index 00000000..5a6084be --- /dev/null +++ b/test/cpp/common/auth_property_iterator_test.cc @@ -0,0 +1,89 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include "src/core/lib/security/context/security_context.h" +#include "src/cpp/common/secure_auth_context.h" +#include "test/cpp/util/string_ref_helper.h" + +using ::grpc::testing::ToString; + +namespace grpc { +namespace { + +class TestAuthPropertyIterator : public AuthPropertyIterator { + public: + TestAuthPropertyIterator() {} + TestAuthPropertyIterator(const grpc_auth_property* property, + const grpc_auth_property_iterator* iter) + : AuthPropertyIterator(property, iter) {} +}; + +class AuthPropertyIteratorTest : public ::testing::Test { + protected: + void SetUp() override { + ctx_ = grpc_core::MakeRefCounted(nullptr); + grpc_auth_context_add_cstring_property(ctx_.get(), "name", "chapi"); + grpc_auth_context_add_cstring_property(ctx_.get(), "name", "chapo"); + grpc_auth_context_add_cstring_property(ctx_.get(), "foo", "bar"); + EXPECT_EQ(1, grpc_auth_context_set_peer_identity_property_name(ctx_.get(), + "name")); + } + grpc_core::RefCountedPtr ctx_; +}; + +TEST_F(AuthPropertyIteratorTest, DefaultCtor) { + TestAuthPropertyIterator iter1; + TestAuthPropertyIterator iter2; + EXPECT_EQ(iter1, iter2); +} + +TEST_F(AuthPropertyIteratorTest, GeneralTest) { + grpc_auth_property_iterator c_iter = + grpc_auth_context_property_iterator(ctx_.get()); + const grpc_auth_property* property = + grpc_auth_property_iterator_next(&c_iter); + TestAuthPropertyIterator iter(property, &c_iter); + TestAuthPropertyIterator empty_iter; + EXPECT_FALSE(iter == empty_iter); + AuthProperty p0 = *iter; + ++iter; + AuthProperty p1 = *iter; + iter++; + AuthProperty p2 = *iter; + EXPECT_EQ("name", ToString(p0.first)); + EXPECT_EQ("chapi", ToString(p0.second)); + EXPECT_EQ("name", ToString(p1.first)); + EXPECT_EQ("chapo", ToString(p1.second)); + EXPECT_EQ("foo", ToString(p2.first)); + EXPECT_EQ("bar", ToString(p2.second)); + ++iter; + EXPECT_EQ(empty_iter, iter); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/common/channel_arguments_test.cc b/test/cpp/common/channel_arguments_test.cc new file mode 100644 index 00000000..f3c67d59 --- /dev/null +++ b/test/cpp/common/channel_arguments_test.cc @@ -0,0 +1,264 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/socket_mutator.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { + +namespace { + +// A simple grpc_socket_mutator to be used to test SetSocketMutator +class TestSocketMutator : public grpc_socket_mutator { + public: + TestSocketMutator(); + + bool MutateFd(int /*fd*/) { + // Do nothing on the fd + return true; + } +}; + +// +// C API for TestSocketMutator +// + +bool test_mutator_mutate_fd(int fd, grpc_socket_mutator* mutator) { + TestSocketMutator* tsm = reinterpret_cast(mutator); + return tsm->MutateFd(fd); +} + +int test_mutator_compare(grpc_socket_mutator* a, grpc_socket_mutator* b) { + return grpc_core::QsortCompare(a, b); +} + +void test_mutator_destroy(grpc_socket_mutator* mutator) { + TestSocketMutator* tsm = reinterpret_cast(mutator); + delete tsm; +} + +grpc_socket_mutator_vtable test_mutator_vtable = { + test_mutator_mutate_fd, test_mutator_compare, test_mutator_destroy, + nullptr}; + +// +// TestSocketMutator implementation +// + +TestSocketMutator::TestSocketMutator() { + grpc_socket_mutator_init(this, &test_mutator_vtable); +} +} // namespace + +class ChannelArgumentsTest : public ::testing::Test { + protected: + ChannelArgumentsTest() + : pointer_vtable_({&ChannelArguments::PointerVtableMembers::Copy, + &ChannelArguments::PointerVtableMembers::Destroy, + &ChannelArguments::PointerVtableMembers::Compare}) {} + + void SetChannelArgs(const ChannelArguments& channel_args, + grpc_channel_args* args) { + channel_args.SetChannelArgs(args); + } + + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } + + std::string GetDefaultUserAgentPrefix() { + std::ostringstream user_agent_prefix; + user_agent_prefix << "grpc-c++/" << Version(); + return user_agent_prefix.str(); + } + + void VerifyDefaultChannelArgs() { + grpc_channel_args args; + SetChannelArgs(channel_args_, &args); + EXPECT_EQ(static_cast(1), args.num_args); + EXPECT_STREQ(GRPC_ARG_PRIMARY_USER_AGENT_STRING, args.args[0].key); + EXPECT_EQ(GetDefaultUserAgentPrefix(), + std::string(args.args[0].value.string)); + } + + bool HasArg(grpc_arg expected_arg) { + grpc_channel_args args; + SetChannelArgs(channel_args_, &args); + for (size_t i = 0; i < args.num_args; i++) { + const grpc_arg& arg = args.args[i]; + if (arg.type == expected_arg.type && + std::string(arg.key) == expected_arg.key) { + if (arg.type == GRPC_ARG_INTEGER) { + return arg.value.integer == expected_arg.value.integer; + } else if (arg.type == GRPC_ARG_STRING) { + return std::string(arg.value.string) == expected_arg.value.string; + } else if (arg.type == GRPC_ARG_POINTER) { + return arg.value.pointer.p == expected_arg.value.pointer.p && + arg.value.pointer.vtable->copy == + expected_arg.value.pointer.vtable->copy && + arg.value.pointer.vtable->destroy == + expected_arg.value.pointer.vtable->destroy; + } + } + } + return false; + } + grpc_arg_pointer_vtable pointer_vtable_; + ChannelArguments channel_args_; +}; + +TEST_F(ChannelArgumentsTest, SetInt) { + VerifyDefaultChannelArgs(); + std::string key0("key0"); + grpc_arg arg0; + arg0.type = GRPC_ARG_INTEGER; + arg0.key = const_cast(key0.c_str()); + arg0.value.integer = 0; + std::string key1("key1"); + grpc_arg arg1; + arg1.type = GRPC_ARG_INTEGER; + arg1.key = const_cast(key1.c_str()); + arg1.value.integer = 1; + + std::string arg_key0(key0); + 
channel_args_.SetInt(arg_key0, arg0.value.integer); + // Clear key early to make sure channel_args takes a copy + arg_key0.clear(); + EXPECT_TRUE(HasArg(arg0)); + + std::string arg_key1(key1); + channel_args_.SetInt(arg_key1, arg1.value.integer); + arg_key1.clear(); + EXPECT_TRUE(HasArg(arg0)); + EXPECT_TRUE(HasArg(arg1)); +} + +TEST_F(ChannelArgumentsTest, SetString) { + VerifyDefaultChannelArgs(); + std::string key0("key0"); + std::string val0("val0"); + grpc_arg arg0; + arg0.type = GRPC_ARG_STRING; + arg0.key = const_cast(key0.c_str()); + arg0.value.string = const_cast(val0.c_str()); + std::string key1("key1"); + std::string val1("val1"); + grpc_arg arg1; + arg1.type = GRPC_ARG_STRING; + arg1.key = const_cast(key1.c_str()); + arg1.value.string = const_cast(val1.c_str()); + + std::string key(key0); + std::string val(val0); + channel_args_.SetString(key, val); + // Clear key/val early to make sure channel_args takes a copy + key = ""; + val = ""; + EXPECT_TRUE(HasArg(arg0)); + + key = key1; + val = val1; + channel_args_.SetString(key, val); + // Clear key/val early to make sure channel_args takes a copy + key = ""; + val = ""; + EXPECT_TRUE(HasArg(arg0)); + EXPECT_TRUE(HasArg(arg1)); +} + +TEST_F(ChannelArgumentsTest, SetPointer) { + VerifyDefaultChannelArgs(); + std::string key0("key0"); + grpc_arg arg0; + arg0.type = GRPC_ARG_POINTER; + arg0.key = const_cast(key0.c_str()); + arg0.value.pointer.p = &key0; + arg0.value.pointer.vtable = &pointer_vtable_; + + std::string key(key0); + channel_args_.SetPointer(key, arg0.value.pointer.p); + EXPECT_TRUE(HasArg(arg0)); +} + +TEST_F(ChannelArgumentsTest, SetSocketMutator) { + VerifyDefaultChannelArgs(); + grpc_arg arg0, arg1; + TestSocketMutator* mutator0 = new TestSocketMutator(); + TestSocketMutator* mutator1 = new TestSocketMutator(); + arg0 = grpc_socket_mutator_to_arg(mutator0); + arg1 = grpc_socket_mutator_to_arg(mutator1); + + channel_args_.SetSocketMutator(mutator0); + EXPECT_TRUE(HasArg(arg0)); + + // Exercise the copy constructor because we ran some sanity checks in it. 
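  // Descriptive note: the original socket-mutator argument (arg0) should
  // survive this copy; only the later SetSocketMutator(mutator1) call replaces
  // it, which the EXPECT_FALSE(HasArg(arg0)) further below verifies.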
+ grpc::ChannelArguments new_args{channel_args_}; + + channel_args_.SetSocketMutator(mutator1); + EXPECT_TRUE(HasArg(arg1)); + // arg0 is replaced by arg1 + EXPECT_FALSE(HasArg(arg0)); +} + +TEST_F(ChannelArgumentsTest, SetUserAgentPrefix) { + VerifyDefaultChannelArgs(); + std::string prefix("prefix"); + std::string whole_prefix = prefix + " " + GetDefaultUserAgentPrefix(); + grpc_arg arg0; + arg0.type = GRPC_ARG_STRING; + arg0.key = const_cast(GRPC_ARG_PRIMARY_USER_AGENT_STRING); + arg0.value.string = const_cast(whole_prefix.c_str()); + + channel_args_.SetUserAgentPrefix(prefix); + EXPECT_TRUE(HasArg(arg0)); + + // Test if the user agent string is copied correctly + ChannelArguments new_channel_args(channel_args_); + grpc_channel_args args; + SetChannelArgs(new_channel_args, &args); + bool found = false; + for (size_t i = 0; i < args.num_args; i++) { + const grpc_arg& arg = args.args[i]; + if (arg.type == GRPC_ARG_STRING && + std::string(arg.key) == GRPC_ARG_PRIMARY_USER_AGENT_STRING) { + EXPECT_FALSE(found); + EXPECT_EQ(0, strcmp(arg.value.string, arg0.value.string)); + found = true; + } + } + EXPECT_TRUE(found); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/common/channel_filter_test.cc b/test/cpp/common/channel_filter_test.cc new file mode 100644 index 00000000..139c0313 --- /dev/null +++ b/test/cpp/common/channel_filter_test.cc @@ -0,0 +1,66 @@ +// +// Copyright 2016 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "src/cpp/common/channel_filter.h" + +#include + +#include + +#include + +namespace grpc { +namespace testing { + +class MyChannelData : public ChannelData { + public: + MyChannelData() {} + + grpc_error_handle Init(grpc_channel_element* /*elem*/, + grpc_channel_element_args* args) override { + (void)args->channel_args; // Make sure field is available. + return GRPC_ERROR_NONE; + } +}; + +class MyCallData : public CallData { + public: + MyCallData() {} + + grpc_error_handle Init(grpc_call_element* /*elem*/, + const grpc_call_element_args* args) override { + (void)args->path; // Make sure field is available. + return GRPC_ERROR_NONE; + } +}; + +// This test ensures that when we make changes to the filter API in +// C-core, we don't accidentally break the C++ filter API. +TEST(ChannelFilterTest, RegisterChannelFilter) { + grpc::RegisterChannelFilter( + "myfilter", GRPC_CLIENT_CHANNEL, INT_MAX, nullptr); +} + +// TODO(roth): When we have time, add tests for all methods of the +// filter API. 
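The RegisterChannelFilter call above passes nullptr as its final argument. A minimal sketch of a conditional registration follows, under the assumption that in this revision of channel_filter.h the fourth parameter is an include-filter predicate of type std::function<bool(const grpc_channel_args&)>; the test name, filter name, and the "enable.myfilter" channel-arg key are invented for the example (grpc_channel_args_find may additionally require including src/core/lib/channel/channel_args.h).

TEST(ChannelFilterTest, RegisterConditionalChannelFilter) {
  // Hypothetical sketch, not part of the diff: only attach the filter when an
  // "enable.myfilter" channel arg is present on the channel being built.
  grpc::RegisterChannelFilter<MyChannelData, MyCallData>(
      "myfilter_conditional", GRPC_CLIENT_CHANNEL, INT_MAX,
      [](const grpc_channel_args& args) {
        // grpc_channel_args_find looks up a channel arg by key and returns
        // nullptr when it is absent.
        return grpc_channel_args_find(&args, "enable.myfilter") != nullptr;
      });
}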
+ +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/common/secure_auth_context_test.cc b/test/cpp/common/secure_auth_context_test.cc new file mode 100644 index 00000000..fcae9ba1 --- /dev/null +++ b/test/cpp/common/secure_auth_context_test.cc @@ -0,0 +1,99 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/cpp/common/secure_auth_context.h" + +#include + +#include +#include + +#include "src/core/lib/security/context/security_context.h" +#include "test/cpp/util/string_ref_helper.h" + +using grpc::testing::ToString; + +namespace grpc { +namespace { + +class SecureAuthContextTest : public ::testing::Test {}; + +// Created with nullptr +TEST_F(SecureAuthContextTest, EmptyContext) { + SecureAuthContext context(nullptr); + EXPECT_TRUE(context.GetPeerIdentity().empty()); + EXPECT_TRUE(context.GetPeerIdentityPropertyName().empty()); + EXPECT_TRUE(context.FindPropertyValues("").empty()); + EXPECT_TRUE(context.FindPropertyValues("whatever").empty()); + EXPECT_TRUE(context.begin() == context.end()); +} + +TEST_F(SecureAuthContextTest, Properties) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + SecureAuthContext context(ctx.get()); + ctx.reset(); + context.AddProperty("name", "chapi"); + context.AddProperty("name", "chapo"); + context.AddProperty("foo", "bar"); + EXPECT_TRUE(context.SetPeerIdentityPropertyName("name")); + + std::vector peer_identity = context.GetPeerIdentity(); + EXPECT_EQ(2u, peer_identity.size()); + EXPECT_EQ("chapi", ToString(peer_identity[0])); + EXPECT_EQ("chapo", ToString(peer_identity[1])); + EXPECT_EQ("name", context.GetPeerIdentityPropertyName()); + std::vector bar = context.FindPropertyValues("foo"); + EXPECT_EQ(1u, bar.size()); + EXPECT_EQ("bar", ToString(bar[0])); +} + +TEST_F(SecureAuthContextTest, Iterators) { + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + SecureAuthContext context(ctx.get()); + ctx.reset(); + context.AddProperty("name", "chapi"); + context.AddProperty("name", "chapo"); + context.AddProperty("foo", "bar"); + EXPECT_TRUE(context.SetPeerIdentityPropertyName("name")); + + AuthPropertyIterator iter = context.begin(); + EXPECT_TRUE(context.end() != iter); + AuthProperty p0 = *iter; + ++iter; + AuthProperty p1 = *iter; + iter++; + AuthProperty p2 = *iter; + EXPECT_EQ("name", ToString(p0.first)); + EXPECT_EQ("chapi", ToString(p0.second)); + EXPECT_EQ("name", ToString(p1.first)); + EXPECT_EQ("chapo", ToString(p1.second)); + EXPECT_EQ("foo", ToString(p2.first)); + EXPECT_EQ("bar", ToString(p2.second)); + ++iter; + EXPECT_EQ(context.end(), iter); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/common/time_jump_test.cc b/test/cpp/common/time_jump_test.cc new file mode 100644 index 
00000000..91643a75 --- /dev/null +++ b/test/cpp/common/time_jump_test.cc @@ -0,0 +1,146 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include + +#include + +#include "absl/time/time.h" + +#include +#include + +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/timer.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "test/core/util/test_config.h" + +extern char** environ; + +#ifdef GPR_ANDROID +// Android doesn't have posix_spawn. Use std::system instead +void run_cmd(const char* cmd) { std::system(cmd); } +#else +void run_cmd(const char* cmd) { + pid_t pid; + const char* argv[] = {const_cast("sh"), + const_cast("-c"), cmd, nullptr}; + int status; + + status = posix_spawn(&pid, const_cast("/bin/sh"), nullptr, + nullptr, const_cast(argv), environ); + if (status == 0) { + if (waitpid(pid, &status, 0) == -1) { + perror("waitpid"); + } + } +} +#endif + +class TimeJumpTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + // Skip test if slowdown factor > 1 + if (grpc_test_slowdown_factor() != 1) { + GTEST_SKIP(); + } else { + grpc_init(); + } + } + void TearDown() override { + // Skip test if slowdown factor > 1 + if (grpc_test_slowdown_factor() == 1) { + run_cmd("sudo sntp -sS pool.ntp.org"); + grpc_shutdown(); + } + } + + const int kWaitTimeMs = 1500; +}; + +std::vector CreateTestScenarios() { + return {"-1M", "+1M", "-1H", "+1H", "-1d", "+1d", "-1y", "+1y"}; +} +INSTANTIATE_TEST_SUITE_P(TimeJump, TimeJumpTest, + ::testing::ValuesIn(CreateTestScenarios())); + +TEST_P(TimeJumpTest, TimerRunning) { + grpc_core::ExecCtx exec_ctx; + grpc_timer timer; + grpc_timer_init(&timer, grpc_core::ExecCtx::Get()->Now() + 3000, + GRPC_CLOSURE_CREATE( + [](void*, grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_CANCELLED); + }, + nullptr, grpc_schedule_on_exec_ctx)); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); + std::ostringstream cmd; + cmd << "sudo date `date -v" << GetParam() << " \"+%m%d%H%M%y\"`"; + run_cmd(cmd.str().c_str()); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(kWaitTimeMs)); + // We expect 1 wakeup/sec when there are not timer expiries + int64_t wakeups = grpc_timer_manager_get_wakeups_testonly(); + gpr_log(GPR_DEBUG, "wakeups: %" PRId64 "", wakeups); + GPR_ASSERT(wakeups <= 3); + grpc_timer_cancel(&timer); +} + +TEST_P(TimeJumpTest, TimedWait) { + grpc_core::CondVar cond; + grpc_core::Mutex mu; + { + grpc_core::MutexLock lock(&mu); + std::thread thd = std::thread([]() { + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); + std::ostringstream cmd; + cmd << "sudo date `date -v" << GetParam() << " \"+%m%d%H%M%y\"`"; + run_cmd(cmd.str().c_str()); + }); + gpr_timespec before = gpr_now(GPR_CLOCK_MONOTONIC); + bool timedout = cond.WaitWithTimeout(&mu, 
absl::Milliseconds(kWaitTimeMs)); + gpr_timespec after = gpr_now(GPR_CLOCK_MONOTONIC); + int32_t elapsed_ms = gpr_time_to_millis(gpr_time_sub(after, before)); + gpr_log(GPR_DEBUG, "After wait, timedout = %d elapsed_ms = %d", timedout, + elapsed_ms); + GPR_ASSERT(1 == timedout); + GPR_ASSERT(1 == + gpr_time_similar(gpr_time_sub(after, before), + gpr_time_from_millis(kWaitTimeMs, GPR_TIMESPAN), + gpr_time_from_millis(50, GPR_TIMESPAN))); + + thd.join(); + } + // We expect 1 wakeup/sec when there are not timer expiries + int64_t wakeups = grpc_timer_manager_get_wakeups_testonly(); + gpr_log(GPR_DEBUG, "wakeups: %" PRId64 "", wakeups); + GPR_ASSERT(wakeups <= 3); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/common/timer_test.cc b/test/cpp/common/timer_test.cc new file mode 100644 index 00000000..2a802461 --- /dev/null +++ b/test/cpp/common/timer_test.cc @@ -0,0 +1,232 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/timer.h" + +#include + +#include +#include + +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "test/core/util/test_config.h" + +#ifdef GRPC_POSIX_SOCKET_EV +#include "src/core/lib/iomgr/ev_posix.h" +#endif + +// MAYBE_SKIP_TEST is a macro to determine if this particular test configuration +// should be skipped based on a decision made at SetUp time. +#define MAYBE_SKIP_TEST \ + do { \ + if (do_not_test_) { \ + return; \ + } \ + } while (0) + +class TimerTest : public ::testing::Test { + protected: + void SetUp() override { + grpc_init(); + // Skip test if slowdown factor > 1, or we are + // using event manager. +#ifdef GRPC_POSIX_SOCKET_EV + if (grpc_test_slowdown_factor() != 1 || + grpc_event_engine_run_in_background()) { +#else + if (grpc_test_slowdown_factor() != 1) { +#endif + do_not_test_ = true; + } + } + + void TearDown() override { grpc_shutdown(); } + + bool do_not_test_{false}; +}; + +#ifndef GPR_WINDOWS +// the test fails with too many wakeups on windows opt build +// the mechanism by which that happens is described in +// https://github.com/grpc/grpc/issues/20436 +TEST_F(TimerTest, NoTimers) { + MAYBE_SKIP_TEST; + grpc_core::ExecCtx exec_ctx; + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1500)); + + // We expect to get 1 wakeup per second. Sometimes we also get a wakeup + // during initialization, so in 1.5 seconds we expect to get 1 or 2 wakeups. 
+ int64_t wakeups = grpc_timer_manager_get_wakeups_testonly(); + GPR_ASSERT(wakeups == 1 || wakeups == 2); +} +#endif + +TEST_F(TimerTest, OneTimerExpires) { + MAYBE_SKIP_TEST; + grpc_core::ExecCtx exec_ctx; + grpc_timer timer; + int timer_fired = 0; + grpc_timer_init(&timer, grpc_core::ExecCtx::Get()->Now() + 500, + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle) { + int* timer_fired = static_cast(arg); + ++*timer_fired; + }, + &timer_fired, grpc_schedule_on_exec_ctx)); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1500)); + GPR_ASSERT(1 == timer_fired); + + // We expect to get 1 wakeup/second + 1 wakeup for the expired timer + maybe 1 + // wakeup during initialization. i.e. in 1.5 seconds we expect 2 or 3 wakeups. + // Actual number of wakeups is more due to bug + // https://github.com/grpc/grpc/issues/19947 + int64_t wakeups = grpc_timer_manager_get_wakeups_testonly(); + gpr_log(GPR_DEBUG, "wakeups: %" PRId64 "", wakeups); +} + +TEST_F(TimerTest, MultipleTimersExpire) { + MAYBE_SKIP_TEST; + grpc_core::ExecCtx exec_ctx; + const int kNumTimers = 10; + grpc_timer timers[kNumTimers]; + int timer_fired = 0; + for (int i = 0; i < kNumTimers; ++i) { + grpc_timer_init(&timers[i], grpc_core::ExecCtx::Get()->Now() + 500 + i, + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle) { + int* timer_fired = static_cast(arg); + ++*timer_fired; + }, + &timer_fired, grpc_schedule_on_exec_ctx)); + } + + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1500)); + GPR_ASSERT(kNumTimers == timer_fired); + + // We expect to get 1 wakeup/second + 1 wakeup for per timer fired + maybe 1 + // wakeup during initialization. i.e. in 1.5 seconds we expect 11 or 12 + // wakeups. Actual number of wakeups is more due to bug + // https://github.com/grpc/grpc/issues/19947 + int64_t wakeups = grpc_timer_manager_get_wakeups_testonly(); + gpr_log(GPR_DEBUG, "wakeups: %" PRId64 "", wakeups); +} + +TEST_F(TimerTest, CancelSomeTimers) { + MAYBE_SKIP_TEST; + grpc_core::ExecCtx exec_ctx; + const int kNumTimers = 10; + grpc_timer timers[kNumTimers]; + int timer_fired = 0; + for (int i = 0; i < kNumTimers; ++i) { + grpc_timer_init(&timers[i], grpc_core::ExecCtx::Get()->Now() + 500 + i, + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle error) { + if (error == GRPC_ERROR_CANCELLED) { + return; + } + int* timer_fired = static_cast(arg); + ++*timer_fired; + }, + &timer_fired, grpc_schedule_on_exec_ctx)); + } + for (int i = 0; i < kNumTimers / 2; ++i) { + grpc_timer_cancel(&timers[i]); + } + + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1500)); + GPR_ASSERT(kNumTimers / 2 == timer_fired); + + // We expect to get 1 wakeup/second + 1 wakeup per timer fired + maybe 1 + // wakeup during initialization. i.e. in 1.5 seconds we expect 6 or 7 wakeups. + // Actual number of wakeups is more due to bug + // https://github.com/grpc/grpc/issues/19947 + int64_t wakeups = grpc_timer_manager_get_wakeups_testonly(); + gpr_log(GPR_DEBUG, "wakeups: %" PRId64 "", wakeups); +} + +// Enable the following test after +// https://github.com/grpc/grpc/issues/20049 has been fixed. +TEST_F(TimerTest, DISABLED_TimerNotCanceled) { + grpc_core::ExecCtx exec_ctx; + grpc_timer timer; + grpc_timer_init(&timer, grpc_core::ExecCtx::Get()->Now() + 10000, + GRPC_CLOSURE_CREATE([](void*, grpc_error_handle) {}, nullptr, + grpc_schedule_on_exec_ctx)); +} + +// Enable the following test after +// https://github.com/grpc/grpc/issues/20064 has been fixed. 
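// Descriptive note: in googletest the DISABLED_ prefix keeps a test compiled
// but excluded from normal runs; it can still be exercised with
// --gtest_also_run_disabled_tests once the referenced issue is resolved.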
+TEST_F(TimerTest, DISABLED_CancelRace) { + MAYBE_SKIP_TEST; + grpc_core::ExecCtx exec_ctx; + const int kNumTimers = 10; + grpc_timer timers[kNumTimers]; + for (int i = 0; i < kNumTimers; ++i) { + grpc_timer* arg = (i != 0) ? &timers[i - 1] : nullptr; + grpc_timer_init(&timers[i], grpc_core::ExecCtx::Get()->Now() + 100, + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle /*error*/) { + grpc_timer* timer = static_cast(arg); + if (timer) { + grpc_timer_cancel(timer); + } + }, + arg, grpc_schedule_on_exec_ctx)); + } + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); +} + +// Enable the following test after +// https://github.com/grpc/grpc/issues/20066 has been fixed. +TEST_F(TimerTest, DISABLED_CancelNextTimer) { + MAYBE_SKIP_TEST; + grpc_core::ExecCtx exec_ctx; + const int kNumTimers = 10; + grpc_timer timers[kNumTimers]; + + for (int i = 0; i < kNumTimers; ++i) { + grpc_timer_init_unset(&timers[i]); + } + + for (int i = 0; i < kNumTimers; ++i) { + grpc_timer* arg = nullptr; + if (i < kNumTimers - 1) { + arg = &timers[i + 1]; + } + grpc_timer_init(&timers[i], grpc_core::ExecCtx::Get()->Now() + 100, + GRPC_CLOSURE_CREATE( + [](void* arg, grpc_error_handle /*error*/) { + grpc_timer* timer = static_cast(arg); + if (timer) { + grpc_timer_cancel(timer); + } + }, + arg, grpc_schedule_on_exec_ctx)); + } + grpc_timer_cancel(&timers[0]); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/admin_services_end2end_test.cc b/test/cpp/end2end/admin_services_end2end_test.cc new file mode 100644 index 00000000..b1baccc8 --- /dev/null +++ b/test/cpp/end2end/admin_services_end2end_test.cc @@ -0,0 +1,101 @@ +// +// +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include +#include + +#include "absl/strings/str_cat.h" + +#include +#include +#include + +#include "src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { + +class AdminServicesTest : public ::testing::Test { + public: + void SetUp() override { + std::string address = + absl::StrCat("localhost:", grpc_pick_unused_port_or_die()); + // Create admin server + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + builder.AddListeningPort(address, InsecureServerCredentials()); + ::grpc::AddAdminServices(&builder); + server_ = builder.BuildAndStart(); + // Create channel + auto reflection_stub = reflection::v1alpha::ServerReflection::NewStub( + CreateChannel(address, InsecureChannelCredentials())); + stream_ = reflection_stub->ServerReflectionInfo(&reflection_ctx_); + } + + std::vector GetServiceList() { + std::vector services; + reflection::v1alpha::ServerReflectionRequest request; + reflection::v1alpha::ServerReflectionResponse response; + request.set_list_services(""); + stream_->Write(request); + stream_->Read(&response); + for (auto& service : response.list_services_response().service()) { + services.push_back(service.name()); + } + return services; + } + + private: + std::unique_ptr server_; + ClientContext reflection_ctx_; + std::shared_ptr< + ClientReaderWriter> + stream_; +}; + +TEST_F(AdminServicesTest, ValidateRegisteredServices) { + // Using Contains here, because the server builder might register other + // services in certain environments. + EXPECT_THAT( + GetServiceList(), + ::testing::AllOf( + ::testing::Contains("grpc.channelz.v1.Channelz"), + ::testing::Contains("grpc.reflection.v1alpha.ServerReflection"))); +#if defined(GRPC_NO_XDS) || defined(DISABLED_XDS_PROTO_IN_CC) + EXPECT_THAT(GetServiceList(), + ::testing::Not(::testing::Contains( + "envoy.service.status.v3.ClientStatusDiscoveryService"))); +#else + EXPECT_THAT(GetServiceList(), + ::testing::Contains( + "envoy.service.status.v3.ClientStatusDiscoveryService")); +#endif // GRPC_NO_XDS or DISABLED_XDS_PROTO_IN_CC +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc new file mode 100644 index 00000000..71e8d2e5 --- /dev/null +++ b/test/cpp/end2end/async_end2end_test.cc @@ -0,0 +1,1954 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/health/v1/health.grpc.pb.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/string_ref_helper.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_POSIX_SOCKET_EV +#include "src/core/lib/iomgr/ev_posix.h" +#endif // GRPC_POSIX_SOCKET_EV + +#include + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { + +namespace { + +void* tag(int t) { return reinterpret_cast(t); } +int detag(void* p) { return static_cast(reinterpret_cast(p)); } + +class Verifier { + public: + Verifier() : lambda_run_(false) {} + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + return ExpectUnless(i, expect_ok, false); + } + // ExpectUnless sets the expected ok value for a specific tag + // unless the tag was already marked seen (as a result of ExpectMaybe) + Verifier& ExpectUnless(int i, bool expect_ok, bool seen) { + if (!seen) { + expectations_[tag(i)] = expect_ok; + } + return *this; + } + // ExpectMaybe sets the expected ok value for a specific tag, but does not + // require it to appear + // If it does, sets *seen to true + Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) { + if (!*seen) { + maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen}; + } + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + template + CompletionQueue::NextStatus DoOnceThenAsyncNext( + CompletionQueue* cq, void** got_tag, bool* ok, T deadline, + std::function lambda) { + if (lambda_run_) { + return cq->AsyncNext(got_tag, ok, deadline); + } else { + lambda_run_ = true; + return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline); + } + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { Verify(cq, false); } + + // This version of Verify allows optionally ignoring the + // outcome of the expectation + void Verify(CompletionQueue* cq, bool ignore_ok) { + GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, ignore_ok); + } + maybe_expectations_.clear(); + } + + // This version of Verify stops after a certain deadline + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(cq->AsyncNext(&got_tag, &ok, deadline), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(cq->AsyncNext(&got_tag, &ok, deadline), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + maybe_expectations_.clear(); + } + + // This version of Verify stops after a certain deadline, and uses the + // DoThenAsyncNext API + // to 
call the lambda + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline, + const std::function& lambda) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + maybe_expectations_.clear(); + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } else { + auto it2 = maybe_expectations_.find(got_tag); + if (it2 != maybe_expectations_.end()) { + if (it2->second.seen != nullptr) { + EXPECT_FALSE(*it2->second.seen); + *it2->second.seen = true; + } + if (!ignore_ok) { + EXPECT_EQ(it2->second.ok, ok); + } + maybe_expectations_.erase(it2); + } else { + gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag); + abort(); + } + } + } + + struct MaybeExpect { + bool ok; + bool* seen; + }; + + std::map expectations_; + std::map maybe_expectations_; + bool lambda_run_; +}; + +bool plugin_has_sync_methods(std::unique_ptr& plugin) { + return plugin->has_sync_methods(); +} + +// This class disables the server builder plugins that may add sync services to +// the server. If there are sync services, UnimplementedRpc test will triger +// the sync unknown rpc routine on the server side, rather than the async one +// that needs to be tested here. +class ServerBuilderSyncPluginDisabler : public ::grpc::ServerBuilderOption { + public: + void UpdateArguments(ChannelArguments* /*arg*/) override {} + + void UpdatePlugins( + std::vector>* plugins) override { + plugins->erase(std::remove_if(plugins->begin(), plugins->end(), + plugin_has_sync_methods), + plugins->end()); + } +}; + +class TestScenario { + public: + TestScenario(bool inproc_stub, const std::string& creds_type, bool hcs, + const std::string& content) + : inproc(inproc_stub), + health_check_service(hcs), + credentials_type(creds_type), + message_content(content) {} + void Log() const; + bool inproc; + bool health_check_service; + const std::string credentials_type; + const std::string message_content; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{inproc=" << (scenario.inproc ? "true" : "false") + << ", credentials='" << scenario.credentials_type + << ", health_check_service=" + << (scenario.health_check_service ? 
"true" : "false") + << "', message_size=" << scenario.message_content.size() << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_DEBUG, "%s", out.str().c_str()); +} + +class HealthCheck : public health::v1::Health::Service {}; + +class AsyncEnd2endTest : public ::testing::TestWithParam { + protected: + AsyncEnd2endTest() { GetParam().Log(); } + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port_; + + // Setup server + BuildAndStartServer(); + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) { + } + stub_.reset(); + grpc_recycle_unused_port(port_); + } + + void BuildAndStartServer() { + ServerBuilder builder; + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + builder.AddListeningPort(server_address_.str(), server_creds); + service_ = + absl::make_unique(); + builder.RegisterService(service_.get()); + if (GetParam().health_check_service) { + builder.RegisterService(&health_check_); + } + cq_ = builder.AddCompletionQueue(); + + // TODO(zyc): make a test option to choose wheather sync plugins should be + // deleted + std::unique_ptr sync_plugin_disabler( + new ServerBuilderSyncPluginDisabler()); + builder.SetOption(move(sync_plugin_disabler)); + server_ = builder.BuildAndStart(); + } + + void ResetStub() { + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::shared_ptr channel = + !(GetParam().inproc) ? ::grpc::CreateCustomChannel( + server_address_.str(), channel_creds, args) + : server_->InProcessChannel(args); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + void SendRpc(int num_rpcs) { + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, + cq_.get(), cq_.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + std::unique_ptr cq_; + std::unique_ptr stub_; + std::unique_ptr server_; + std::unique_ptr service_; + HealthCheck health_check_; + std::ostringstream server_address_; + int port_; +}; + +TEST_P(AsyncEnd2endTest, SimpleRpc) { + ResetStub(); + SendRpc(1); +} + +TEST_P(AsyncEnd2endTest, SimpleRpcWithExpectedError) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + ErrorStatus error_status; + + 
send_request.set_message(GetParam().message_content); + error_status.set_code(1); // CANCELLED + error_status.set_error_message("cancel error message"); + *send_request.mutable_param()->mutable_expected_error() = error_status; + + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish( + send_response, + Status( + static_cast(recv_request.param().expected_error().code()), + recv_request.param().expected_error().error_message()), + tag(3)); + Verifier().Expect(3, true).Expect(4, true).Expect(5, true).Verify(cq_.get()); + + EXPECT_EQ(recv_response.message(), ""); + EXPECT_EQ(recv_status.error_code(), error_status.code()); + EXPECT_EQ(recv_status.error_message(), error_status.error_message()); + EXPECT_FALSE(srv_ctx.IsCancelled()); +} + +TEST_P(AsyncEnd2endTest, SequentialRpcs) { + ResetStub(); + SendRpc(10); +} + +TEST_P(AsyncEnd2endTest, ReconnectChannel) { + // GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS is set to 100ms in main() + if (GetParam().inproc) { + return; + } + int poller_slowdown_factor = 1; +#ifdef GRPC_POSIX_SOCKET_EV + // It needs 2 pollset_works to reconnect the channel with polling engine + // "poll" + grpc_core::UniquePtr poller = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy); + if (0 == strcmp(poller.get(), "poll")) { + poller_slowdown_factor = 2; + } +#endif // GRPC_POSIX_SOCKET_EV + ResetStub(); + SendRpc(1); + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) { + } + BuildAndStartServer(); + // It needs more than GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS time to + // reconnect the channel. + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis( + 300 * poller_slowdown_factor * grpc_test_slowdown_factor(), + GPR_TIMESPAN))); + SendRpc(1); +} + +// We do not need to protect notify because the use is synchronized. 
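// Descriptive note: Server::Wait() returns only after Shutdown() has been
// called and pending work has drained, so ServerWait() below cannot set
// *notify until the test's explicit Shutdown(); no extra locking is needed.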
+void ServerWait(Server* server, int* notify) { + server->Wait(); + *notify = 1; +} +TEST_P(AsyncEnd2endTest, WaitAndShutdownTest) { + int notify = 0; + std::thread wait_thread(&ServerWait, server_.get(), ¬ify); + ResetStub(); + SendRpc(1); + EXPECT_EQ(0, notify); + server_->Shutdown(); + wait_thread.join(); + EXPECT_EQ(1, notify); +} + +TEST_P(AsyncEnd2endTest, ShutdownThenWait) { + ResetStub(); + SendRpc(1); + std::thread t([this]() { server_->Shutdown(); }); + server_->Wait(); + t.join(); +} + +// Test a simple RPC using the async version of Next +TEST_P(AsyncEnd2endTest, AsyncNextRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + std::chrono::system_clock::time_point time_now( + std::chrono::system_clock::now()); + std::chrono::system_clock::time_point time_limit( + std::chrono::system_clock::now() + std::chrono::seconds(10)); + Verifier().Verify(cq_.get(), time_now); + Verifier().Verify(cq_.get(), time_now); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get(), time_limit); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify( + cq_.get(), std::chrono::system_clock::time_point::max()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// Test a simple RPC using the async version of Next +TEST_P(AsyncEnd2endTest, DoThenAsyncNextRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + std::chrono::system_clock::time_point time_now( + std::chrono::system_clock::now()); + std::chrono::system_clock::time_point time_limit( + std::chrono::system_clock::now() + std::chrono::seconds(10)); + Verifier().Verify(cq_.get(), time_now); + Verifier().Verify(cq_.get(), time_now); + + auto resp_writer_ptr = &response_writer; + auto lambda_2 = [&, this, resp_writer_ptr]() { + service_->RequestEcho(&srv_ctx, &recv_request, resp_writer_ptr, cq_.get(), + cq_.get(), tag(2)); + }; + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get(), time_limit, lambda_2); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + auto lambda_3 = [resp_writer_ptr, send_response]() { + resp_writer_ptr->Finish(send_response, Status::OK, tag(3)); + }; + Verifier().Expect(3, true).Expect(4, true).Verify( + cq_.get(), std::chrono::system_clock::time_point::max(), lambda_3); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// Two pings and a 
final pong. +TEST_P(AsyncEnd2endTest, SimpleClientStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> cli_stream( + stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1))); + + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(2, true).Expect(1, true).Verify(cq_.get()); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_stream->Write(send_request, tag(5)); + srv_stream.Read(&recv_request, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + EXPECT_EQ(send_request.message(), recv_request.message()); + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.Finish(send_response, Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// Two pings and a final pong. +TEST_P(AsyncEnd2endTest, SimpleClientStreamingWithCoalescingApi) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + cli_ctx.set_initial_metadata_corked(true); + // tag:1 never comes up since no op is performed + std::unique_ptr> cli_stream( + stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1))); + + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + cli_stream->Write(send_request, tag(3)); + + bool seen3 = false; + + Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get()); + + srv_stream.Read(&recv_request, tag(4)); + + Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_stream->WriteLast(send_request, WriteOptions(), tag(5)); + srv_stream.Read(&recv_request, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_stream.Read(&recv_request, tag(7)); + Verifier().Expect(7, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.Finish(send_response, Status::OK, tag(8)); + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(8, true).Expect(9, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, two pongs. 
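+// A rough aside, not exercised by any test: most tests in this file drive
+// both the client and the server end of an RPC on the shared completion
+// queue cq_, identifying every queued operation by a small integer tag.
+// Stripped of the Verifier helper, waiting on a single completion looks
+// roughly like the sketch below; `expected_tag` and `expected_ok` are
+// illustrative names only, not part of the fixture.
+//
+//   void* got_tag;
+//   bool ok;
+//   GPR_ASSERT(cq_->Next(&got_tag, &ok));      // block for one completion
+//   GPR_ASSERT(got_tag == tag(expected_tag));  // the tag queued with the op
+//   GPR_ASSERT(ok == expected_ok);             // false marks a failed/ended op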
+TEST_P(AsyncEnd2endTest, SimpleServerStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, two pongs. Using WriteAndFinish API +TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWAF) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.WriteAndFinish(send_response, WriteOptions(), Status::OK, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Read(&recv_response, tag(7)); + Verifier().Expect(7, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(8)); + Verifier().Expect(8, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, two pongs. 
Using WriteLast API +TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWL) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.WriteLast(send_response, WriteOptions(), tag(5)); + cli_stream->Read(&recv_response, tag(6)); + srv_stream.Finish(Status::OK, tag(7)); + Verifier().Expect(5, true).Expect(6, true).Expect(7, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, one pong. +TEST_P(AsyncEnd2endTest, SimpleBidiStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, one pong. 
Using server:WriteAndFinish api +TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWAF) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + cli_ctx.set_initial_metadata_corked(true); + std::unique_ptr> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + cli_stream->WriteLast(send_request, WriteOptions(), tag(3)); + + bool seen3 = false; + + Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get()); + + srv_stream.Read(&recv_request, tag(4)); + + Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_stream.Read(&recv_request, tag(5)); + Verifier().Expect(5, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.WriteAndFinish(send_response, WriteOptions(), Status::OK, tag(6)); + cli_stream->Read(&recv_response, tag(7)); + Verifier().Expect(6, true).Expect(7, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Finish(&recv_status, tag(8)); + Verifier().Expect(8, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, one pong. Using server:WriteLast api +TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWL) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + cli_ctx.set_initial_metadata_corked(true); + std::unique_ptr> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + cli_stream->WriteLast(send_request, WriteOptions(), tag(3)); + + bool seen3 = false; + + Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get()); + + srv_stream.Read(&recv_request, tag(4)); + + Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_stream.Read(&recv_request, tag(5)); + Verifier().Expect(5, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.WriteLast(send_response, WriteOptions(), tag(6)); + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(6, true).Expect(7, true).Expect(8, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// Metadata tests +TEST_P(AsyncEnd2endTest, ClientInitialMetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair meta1("key1", "val1"); + std::pair meta2("key2", 
"val2"); + std::pair meta3("g.r.d-bin", "xyz"); + cli_ctx.AddMetadata(meta1.first, meta1.second); + cli_ctx.AddMetadata(meta2.first, meta2.second); + cli_ctx.AddMetadata(meta3.first, meta3.second); + + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + const auto& client_initial_metadata = srv_ctx.client_metadata(); + EXPECT_EQ(meta1.second, + ToString(client_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(client_initial_metadata.find(meta2.first)->second)); + EXPECT_EQ(meta3.second, + ToString(client_initial_metadata.find(meta3.first)->second)); + EXPECT_GE(client_initial_metadata.size(), static_cast(2)); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +TEST_P(AsyncEnd2endTest, ServerInitialMetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair meta1("key1", "val1"); + std::pair meta2("key2", "val2"); + + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->ReadInitialMetadata(tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + srv_ctx.AddInitialMetadata(meta1.first, meta1.second); + srv_ctx.AddInitialMetadata(meta2.first, meta2.second); + response_writer.SendInitialMetadata(tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + const auto& server_initial_metadata = cli_ctx.GetServerInitialMetadata(); + EXPECT_EQ(meta1.second, + ToString(server_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(server_initial_metadata.find(meta2.first)->second)); + EXPECT_EQ(static_cast(2), server_initial_metadata.size()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(5)); + response_reader->Finish(&recv_response, &recv_status, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// 1 ping, 2 pongs. 
+TEST_P(AsyncEnd2endTest, ServerInitialMetadataServerStreaming) {
+  ResetStub();
+  EchoRequest send_request;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  EchoResponse recv_response;
+  Status recv_status;
+  ClientContext cli_ctx;
+  ServerContext srv_ctx;
+  ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+  std::pair<::std::string, ::std::string> meta1("key1", "val1");
+  std::pair<::std::string, ::std::string> meta2("key2", "val2");
+
+  std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+      stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1)));
+  cli_stream->ReadInitialMetadata(tag(11));
+  service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+                                  cq_.get(), cq_.get(), tag(2));
+
+  Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+
+  srv_ctx.AddInitialMetadata(meta1.first, meta1.second);
+  srv_ctx.AddInitialMetadata(meta2.first, meta2.second);
+  srv_stream.SendInitialMetadata(tag(10));
+  Verifier().Expect(10, true).Expect(11, true).Verify(cq_.get());
+  auto server_initial_metadata = cli_ctx.GetServerInitialMetadata();
+  EXPECT_EQ(meta1.second,
+            ToString(server_initial_metadata.find(meta1.first)->second));
+  EXPECT_EQ(meta2.second,
+            ToString(server_initial_metadata.find(meta2.first)->second));
+  EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size());
+
+  srv_stream.Write(send_response, tag(3));
+
+  cli_stream->Read(&recv_response, tag(4));
+  Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+
+  srv_stream.Write(send_response, tag(5));
+  cli_stream->Read(&recv_response, tag(6));
+  Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+
+  srv_stream.Finish(Status::OK, tag(7));
+  cli_stream->Read(&recv_response, tag(8));
+  Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+  cli_stream->Finish(&recv_status, tag(9));
+  Verifier().Expect(9, true).Verify(cq_.get());
+
+  EXPECT_TRUE(recv_status.ok());
+}
+
+// 1 ping, 2 pongs.
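+// Hedged aside on the metadata flow used above and in the implicit variant
+// below: server initial metadata only becomes readable through
+// ClientContext::GetServerInitialMetadata() after some operation has carried
+// it to the client, either an explicit ReadInitialMetadata() or the first
+// Read()/Finish(). With the synchronous API the same thing happens inside the
+// blocking call; a rough sketch, with `stub` and `ctx` as illustrative names:
+//
+//   grpc::ClientContext ctx;
+//   EchoResponse resp;
+//   grpc::Status s = stub->Echo(&ctx, request, &resp);  // blocks until done
+//   const auto& md = ctx.GetServerInitialMetadata();    // now populated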
+// Test for server initial metadata being sent implicitly +TEST_P(AsyncEnd2endTest, ServerInitialMetadataServerStreamingImplicit) { + ResetStub(); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair<::std::string, ::std::string> meta1("key1", "val1"); + std::pair<::std::string, ::std::string> meta2("key2", "val2"); + + std::unique_ptr> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_ctx.AddInitialMetadata(meta1.first, meta1.second); + srv_ctx.AddInitialMetadata(meta2.first, meta2.second); + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + auto server_initial_metadata = cli_ctx.GetServerInitialMetadata(); + EXPECT_EQ(meta1.second, + ToString(server_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(server_initial_metadata.find(meta2.first)->second)); + EXPECT_EQ(static_cast(2), server_initial_metadata.size()); + + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +TEST_P(AsyncEnd2endTest, ServerTrailingMetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair meta1("key1", "val1"); + std::pair meta2("key2", "val2"); + + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(5)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + response_writer.SendInitialMetadata(tag(3)); + Verifier().Expect(3, true).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_ctx.AddTrailingMetadata(meta1.first, meta1.second); + srv_ctx.AddTrailingMetadata(meta2.first, meta2.second); + response_writer.Finish(send_response, Status::OK, tag(4)); + + Verifier().Expect(4, true).Expect(5, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + const auto& server_trailing_metadata = cli_ctx.GetServerTrailingMetadata(); + EXPECT_EQ(meta1.second, + ToString(server_trailing_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + 
ToString(server_trailing_metadata.find(meta2.first)->second)); + EXPECT_EQ(static_cast(2), server_trailing_metadata.size()); +} + +TEST_P(AsyncEnd2endTest, MetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair meta1("key1", "val1"); + std::pair meta2( + "key2-bin", + std::string("\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc", 13)); + std::pair meta3("key3", "val3"); + std::pair meta6( + "key4-bin", + std::string("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d", + 14)); + std::pair meta5("key5", "val5"); + std::pair meta4( + "key6-bin", + std::string( + "\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee", 15)); + + cli_ctx.AddMetadata(meta1.first, meta1.second); + cli_ctx.AddMetadata(meta2.first, meta2.second); + + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->ReadInitialMetadata(tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + const auto& client_initial_metadata = srv_ctx.client_metadata(); + EXPECT_EQ(meta1.second, + ToString(client_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(client_initial_metadata.find(meta2.first)->second)); + EXPECT_GE(client_initial_metadata.size(), static_cast(2)); + + srv_ctx.AddInitialMetadata(meta3.first, meta3.second); + srv_ctx.AddInitialMetadata(meta4.first, meta4.second); + response_writer.SendInitialMetadata(tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + const auto& server_initial_metadata = cli_ctx.GetServerInitialMetadata(); + EXPECT_EQ(meta3.second, + ToString(server_initial_metadata.find(meta3.first)->second)); + EXPECT_EQ(meta4.second, + ToString(server_initial_metadata.find(meta4.first)->second)); + EXPECT_GE(server_initial_metadata.size(), static_cast(2)); + + send_response.set_message(recv_request.message()); + srv_ctx.AddTrailingMetadata(meta5.first, meta5.second); + srv_ctx.AddTrailingMetadata(meta6.first, meta6.second); + response_writer.Finish(send_response, Status::OK, tag(5)); + response_reader->Finish(&recv_response, &recv_status, tag(6)); + + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + const auto& server_trailing_metadata = cli_ctx.GetServerTrailingMetadata(); + EXPECT_EQ(meta5.second, + ToString(server_trailing_metadata.find(meta5.first)->second)); + EXPECT_EQ(meta6.second, + ToString(server_trailing_metadata.find(meta6.first)->second)); + EXPECT_GE(server_trailing_metadata.size(), static_cast(2)); +} + +// Server uses AsyncNotifyWhenDone API to check for cancellation +TEST_P(AsyncEnd2endTest, ServerCheckCancellation) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, 
send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_ctx.TryCancel(); + Verifier().Expect(5, true).Expect(4, true).Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + + EXPECT_EQ(StatusCode::CANCELLED, recv_status.error_code()); +} + +// Server uses AsyncNotifyWhenDone API to check for normal finish +TEST_P(AsyncEnd2endTest, ServerCheckDone) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Expect(5, true).Verify(cq_.get()); + EXPECT_FALSE(srv_ctx.IsCancelled()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +TEST_P(AsyncEnd2endTest, UnimplementedRpc) { + ChannelArguments args; + const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::shared_ptr channel = + !(GetParam().inproc) ? ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args) + : server_->InProcessChannel(args); + std::unique_ptr stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + send_request.set_message(GetParam().message_content); + std::unique_ptr> response_reader( + stub->AsyncUnimplemented(&cli_ctx, send_request, cq_.get())); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); +} + +// This class is for testing scenarios where RPCs are cancelled on the server +// by calling ServerContext::TryCancel(). Server uses AsyncNotifyWhenDone +// API to check for cancellation +class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest { + protected: + typedef enum { + DO_NOT_CANCEL = 0, + CANCEL_BEFORE_PROCESSING, + CANCEL_DURING_PROCESSING, + CANCEL_AFTER_PROCESSING + } ServerTryCancelRequestPhase; + + // Helper for testing client-streaming RPCs which are cancelled on the server. 
+  // Depending on the value of server_try_cancel parameter, this will test one
+  // of the following three scenarios:
+  //   CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading
+  //   any messages from the client
+  //
+  //   CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading
+  //   messages from the client
+  //
+  //   CANCEL_AFTER_PROCESSING: Rpc is cancelled by the server after reading
+  //   all messages from the client (but before sending any status back to the
+  //   client)
+  void TestClientStreamingServerCancel(
+      ServerTryCancelRequestPhase server_try_cancel) {
+    ResetStub();
+
+    EchoRequest recv_request;
+    EchoResponse send_response;
+    EchoResponse recv_response;
+    Status recv_status;
+
+    ClientContext cli_ctx;
+    ServerContext srv_ctx;
+    ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+    // Initiate the 'RequestStream' call on the client
+    CompletionQueue cli_cq;
+
+    std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream(
+        stub_->AsyncRequestStream(&cli_ctx, &recv_response, &cli_cq, tag(1)));
+
+    // On the server, request to be notified of 'RequestStream' calls
+    // and receive the 'RequestStream' call just made by the client
+    srv_ctx.AsyncNotifyWhenDone(tag(11));
+    service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+                                   tag(2));
+    std::thread t1([&cli_cq] { Verifier().Expect(1, true).Verify(&cli_cq); });
+    Verifier().Expect(2, true).Verify(cq_.get());
+    t1.join();
+
+    bool expected_server_cq_result = true;
+    bool expected_client_cq_result = true;
+
+    if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+      srv_ctx.TryCancel();
+      Verifier().Expect(11, true).Verify(cq_.get());
+      EXPECT_TRUE(srv_ctx.IsCancelled());
+
+      // Since cancellation is done before the server reads any results, we
+      // know for sure that all server cq results will return false from this
+      // point forward
+      expected_server_cq_result = false;
+      expected_client_cq_result = false;
+    }
+
+    bool ignore_client_cq_result =
+        (server_try_cancel == CANCEL_DURING_PROCESSING) ||
+        (server_try_cancel == CANCEL_BEFORE_PROCESSING);
+
+    std::thread cli_thread([&cli_cq, &cli_stream, &expected_client_cq_result,
+                            &ignore_client_cq_result] {
+      EchoRequest send_request;
+      // Client sends 3 messages (tags 3, 4 and 5)
+      for (int tag_idx = 3; tag_idx <= 5; tag_idx++) {
+        send_request.set_message("Ping " + std::to_string(tag_idx));
+        cli_stream->Write(send_request, tag(tag_idx));
+        Verifier()
+            .Expect(tag_idx, expected_client_cq_result)
+            .Verify(&cli_cq, ignore_client_cq_result);
+      }
+      cli_stream->WritesDone(tag(6));
+      // Ignore ok on WritesDone since cancel can affect it
+      Verifier()
+          .Expect(6, expected_client_cq_result)
+          .Verify(&cli_cq, ignore_client_cq_result);
+    });
+
+    bool ignore_cq_result = false;
+    bool want_done_tag = false;
+    std::thread* server_try_cancel_thd = nullptr;
+
+    auto verif = Verifier();
+
+    if (server_try_cancel == CANCEL_DURING_PROCESSING) {
+      server_try_cancel_thd =
+          new std::thread([&srv_ctx] { srv_ctx.TryCancel(); });
+      // Server will cancel the RPC in a parallel thread while reading the
+      // requests from the client. Since the cancellation can happen at any
+      // time, some of the cq results (i.e. those until cancellation) might be
+      // true, but it is non-deterministic.
So better to ignore the cq results + ignore_cq_result = true; + // Expect that we might possibly see the done tag that + // indicates cancellation completion in this case + want_done_tag = true; + verif.Expect(11, true); + } + + // Server reads 3 messages (tags 6, 7 and 8) + // But if want_done_tag is true, we might also see tag 11 + for (int tag_idx = 6; tag_idx <= 8; tag_idx++) { + srv_stream.Read(&recv_request, tag(tag_idx)); + // Note that we'll add something to the verifier and verify that + // something was seen, but it might be tag 11 and not what we + // just added + int got_tag = verif.Expect(tag_idx, expected_server_cq_result) + .Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == tag_idx) || (got_tag == 11 && want_done_tag)); + if (got_tag == 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + // Now get the other entry that we were waiting on + EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), tag_idx); + } + } + + cli_thread.join(); + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + srv_ctx.TryCancel(); + want_done_tag = true; + verif.Expect(11, true); + } + + if (want_done_tag) { + verif.Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + } + + // The RPC has been cancelled at this point for sure (i.e irrespective of + // the value of `server_try_cancel` is). So, from this point forward, we + // know that cq results are supposed to return false on server. + + // Server sends the final message and cancelled status (but the RPC is + // already cancelled at this point. So we expect the operation to fail) + srv_stream.Finish(send_response, Status::CANCELLED, tag(9)); + Verifier().Expect(9, false).Verify(cq_.get()); + + // Client will see the cancellation + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(10, true).Verify(&cli_cq); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code()); + + cli_cq.Shutdown(); + void* phony_tag; + bool phony_ok; + while (cli_cq.Next(&phony_tag, &phony_ok)) { + } + } + + // Helper for testing server-streaming RPCs which are cancelled on the server. 
+ // Depending on the value of server_try_cancel parameter, this will test one + // of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before sending + // any messages to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while sending + // messages to the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after sending all + // messages to the client (but before sending any status back to the + // client) + void TestServerStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter srv_stream(&srv_ctx); + + send_request.set_message("Ping"); + // Initiate the 'ResponseStream' call on the client + CompletionQueue cli_cq; + std::unique_ptr> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, &cli_cq, tag(1))); + // On the server, request to be notified of 'ResponseStream' calls and + // receive the call just made by the client + srv_ctx.AsyncNotifyWhenDone(tag(11)); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + std::thread t1([&cli_cq] { Verifier().Expect(1, true).Verify(&cli_cq); }); + Verifier().Expect(2, true).Verify(cq_.get()); + t1.join(); + + EXPECT_EQ(send_request.message(), recv_request.message()); + + bool expected_cq_result = true; + bool ignore_cq_result = false; + bool want_done_tag = false; + bool expected_client_cq_result = true; + bool ignore_client_cq_result = + (server_try_cancel != CANCEL_BEFORE_PROCESSING); + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + srv_ctx.TryCancel(); + Verifier().Expect(11, true).Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + + // We know for sure that all cq results will be false from this point + // since the server cancelled the RPC + expected_cq_result = false; + expected_client_cq_result = false; + } + + std::thread cli_thread([&cli_cq, &cli_stream, &expected_client_cq_result, + &ignore_client_cq_result] { + // Client attempts to read the three messages from the server + for (int tag_idx = 6; tag_idx <= 8; tag_idx++) { + EchoResponse recv_response; + cli_stream->Read(&recv_response, tag(tag_idx)); + Verifier() + .Expect(tag_idx, expected_client_cq_result) + .Verify(&cli_cq, ignore_client_cq_result); + } + }); + + std::thread* server_try_cancel_thd = nullptr; + + auto verif = Verifier(); + + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([&srv_ctx] { srv_ctx.TryCancel(); }); + + // Server will cancel the RPC in a parallel thread while writing responses + // to the client. Since the cancellation can happen at anytime, some of + // the cq results (i.e those until cancellation) might be true but it is + // non deterministic. 
So it is better to ignore the cq results.
+      ignore_cq_result = true;
+      // Expect that we might possibly see the done tag that
+      // indicates cancellation completion in this case
+      want_done_tag = true;
+      verif.Expect(11, true);
+    }
+
+    // Server sends three messages (tags 3, 4 and 5)
+    // But if want_done_tag is true, we might also see tag 11
+    for (int tag_idx = 3; tag_idx <= 5; tag_idx++) {
+      send_response.set_message("Pong " + std::to_string(tag_idx));
+      srv_stream.Write(send_response, tag(tag_idx));
+      // Note that we'll add something to the verifier and verify that
+      // something was seen, but it might be tag 11 and not what we
+      // just added
+      int got_tag = verif.Expect(tag_idx, expected_cq_result)
+                        .Next(cq_.get(), ignore_cq_result);
+      GPR_ASSERT((got_tag == tag_idx) || (got_tag == 11 && want_done_tag));
+      if (got_tag == 11) {
+        EXPECT_TRUE(srv_ctx.IsCancelled());
+        want_done_tag = false;
+        // Now get the other entry that we were waiting on
+        EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), tag_idx);
+      }
+    }
+
+    if (server_try_cancel_thd != nullptr) {
+      server_try_cancel_thd->join();
+      delete server_try_cancel_thd;
+    }
+
+    if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
+      srv_ctx.TryCancel();
+      want_done_tag = true;
+      verif.Expect(11, true);
+    }
+
+    if (want_done_tag) {
+      verif.Verify(cq_.get());
+      EXPECT_TRUE(srv_ctx.IsCancelled());
+      want_done_tag = false;
+    }
+
+    cli_thread.join();
+
+    // The RPC has been cancelled at this point for sure (i.e. irrespective of
+    // the value of `server_try_cancel`). So, from this point forward, we know
+    // that cq results are supposed to return false on the server.
+
+    // Server finishes the stream (but the RPC is already cancelled)
+    srv_stream.Finish(Status::CANCELLED, tag(9));
+    Verifier().Expect(9, false).Verify(cq_.get());
+
+    // Client will see the cancellation
+    cli_stream->Finish(&recv_status, tag(10));
+    Verifier().Expect(10, true).Verify(&cli_cq);
+    EXPECT_FALSE(recv_status.ok());
+    EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code());
+
+    cli_cq.Shutdown();
+    void* phony_tag;
+    bool phony_ok;
+    while (cli_cq.Next(&phony_tag, &phony_ok)) {
+    }
+  }
+
+  // Helper for testing bidirectional-streaming RPCs which are cancelled on the
+  // server.
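+  // Hedged note: unlike the client- and server-streaming helpers above,
+  // which pump the client side on a separate CompletionQueue cli_cq, this
+  // helper drives both ends of the call on the single test queue cq_.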
+ // + // Depending on the value of server_try_cancel parameter, this will + // test one of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading/ + // writing any messages from/to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading + // messages from the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading all + // messages from the client (but before sending any status back to the + // client) + void TestBidiStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter srv_stream(&srv_ctx); + + // Initiate the call from the client side + std::unique_ptr> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + // On the server, request to be notified of the 'BidiStream' call and + // receive the call just made by the client + srv_ctx.AsyncNotifyWhenDone(tag(11)); + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + + auto verif = Verifier(); + + // Client sends the first and the only message + send_request.set_message("Ping"); + cli_stream->Write(send_request, tag(3)); + verif.Expect(3, true); + + bool expected_cq_result = true; + bool ignore_cq_result = false; + bool want_done_tag = false; + + int got_tag, got_tag2; + bool tag_3_done = false; + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + srv_ctx.TryCancel(); + verif.Expect(11, true); + // We know for sure that all server cq results will be false from + // this point since the server cancelled the RPC. However, we can't + // say for sure about the client + expected_cq_result = false; + ignore_cq_result = true; + + do { + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT(((got_tag == 3) && !tag_3_done) || (got_tag == 11)); + if (got_tag == 3) { + tag_3_done = true; + } + } while (got_tag != 11); + EXPECT_TRUE(srv_ctx.IsCancelled()); + } + + std::thread* server_try_cancel_thd = nullptr; + + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([&srv_ctx] { srv_ctx.TryCancel(); }); + + // Since server is going to cancel the RPC in a parallel thread, some of + // the cq results (i.e those until the cancellation) might be true. Since + // that number is non-deterministic, it is better to ignore the cq results + ignore_cq_result = true; + // Expect that we might possibly see the done tag that + // indicates cancellation completion in this case + want_done_tag = true; + verif.Expect(11, true); + } + + srv_stream.Read(&recv_request, tag(4)); + verif.Expect(4, expected_cq_result); + got_tag = tag_3_done ? 
3 : verif.Next(cq_.get(), ignore_cq_result); + got_tag2 = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 3) || (got_tag == 4) || + (got_tag == 11 && want_done_tag)); + GPR_ASSERT((got_tag2 == 3) || (got_tag2 == 4) || + (got_tag2 == 11 && want_done_tag)); + // If we get 3 and 4, we don't need to wait for 11, but if + // we get 11, we should also clear 3 and 4 + if (got_tag + got_tag2 != 7) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 3) || (got_tag == 4)); + } + + send_response.set_message("Pong"); + srv_stream.Write(send_response, tag(5)); + verif.Expect(5, expected_cq_result); + + cli_stream->Read(&recv_response, tag(6)); + verif.Expect(6, expected_cq_result); + got_tag = verif.Next(cq_.get(), ignore_cq_result); + got_tag2 = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 5) || (got_tag == 6) || + (got_tag == 11 && want_done_tag)); + GPR_ASSERT((got_tag2 == 5) || (got_tag2 == 6) || + (got_tag2 == 11 && want_done_tag)); + // If we get 5 and 6, we don't need to wait for 11, but if + // we get 11, we should also clear 5 and 6 + if (got_tag + got_tag2 != 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 5) || (got_tag == 6)); + } + + // This is expected to succeed in all cases + cli_stream->WritesDone(tag(7)); + verif.Expect(7, true); + // TODO(vjpai): Consider whether the following is too flexible + // or whether it should just be reset to ignore_cq_result + bool ignore_cq_wd_result = + ignore_cq_result || (server_try_cancel == CANCEL_BEFORE_PROCESSING); + got_tag = verif.Next(cq_.get(), ignore_cq_wd_result); + GPR_ASSERT((got_tag == 7) || (got_tag == 11 && want_done_tag)); + if (got_tag == 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + // Now get the other entry that we were waiting on + EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_wd_result), 7); + } + + // This is expected to fail in all cases i.e for all values of + // server_try_cancel. This is because at this point, either there are no + // more msgs from the client (because client called WritesDone) or the RPC + // is cancelled on the server + srv_stream.Read(&recv_request, tag(8)); + verif.Expect(8, false); + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 8) || (got_tag == 11 && want_done_tag)); + if (got_tag == 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + // Now get the other entry that we were waiting on + EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), 8); + } + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + srv_ctx.TryCancel(); + want_done_tag = true; + verif.Expect(11, true); + } + + if (want_done_tag) { + verif.Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + } + + // The RPC has been cancelled at this point for sure (i.e irrespective of + // the value of `server_try_cancel` is). So, from this point forward, we + // know that cq results are supposed to return false on server. 
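+    // Hedged clarification: Finish() is still queued below even though the
+    // call is already cancelled; it is expected to complete with ok == false,
+    // which is what the verifier asserts via Expect(9, false). On a live,
+    // non-cancelled call the same teardown would complete with ok == true.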
+ + srv_stream.Finish(Status::CANCELLED, tag(9)); + Verifier().Expect(9, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(10, true).Verify(cq_.get()); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, recv_status.error_code()); + } +}; + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelBefore) { + TestClientStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelDuring) { + TestClientStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelAfter) { + TestClientStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelBefore) { + TestServerStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelDuring) { + TestServerStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelAfter) { + TestServerStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelBefore) { + TestBidiStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelDuring) { + TestBidiStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelAfter) { + TestBidiStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +std::vector CreateTestScenarios(bool /*test_secure*/, + bool test_message_size_limit) { + std::vector scenarios; + std::vector credentials_types; + std::vector messages; + + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. + return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + + if (insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + GPR_ASSERT(!credentials_types.empty()); + + messages.push_back("Hello"); + if (test_message_size_limit) { + for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; + k *= 32) { + std::string big_msg; + for (size_t i = 0; i < k * 1024; ++i) { + char c = 'a' + (i % 26); + big_msg += c; + } + messages.push_back(big_msg); + } + if (!BuiltUnderMsan()) { + // 4MB message processing with SSL is very slow under msan + // (causes timeouts) and doesn't really increase the signal from tests. + // Reserve 100 bytes for other fields of the message proto. 
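+      // Hedged note: GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH defaults to 4 MiB
+      // (4 * 1024 * 1024 bytes), so the string pushed below is 100 bytes
+      // short of the default receive limit, leaving room for the remaining
+      // serialized fields of the request proto.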
+ messages.push_back( + std::string(GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH - 100, 'a')); + } + } + + // TODO (sreek) Renable tests with health check service after the issue + // https://github.com/grpc/grpc/issues/11223 is resolved + for (auto health_check_service : {false}) { + for (auto msg = messages.begin(); msg != messages.end(); msg++) { + for (auto cred = credentials_types.begin(); + cred != credentials_types.end(); ++cred) { + scenarios.emplace_back(false, *cred, health_check_service, *msg); + } + if (insec_ok()) { + scenarios.emplace_back(true, kInsecureCredentialsType, + health_check_service, *msg); + } + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(AsyncEnd2end, AsyncEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(true, true))); +INSTANTIATE_TEST_SUITE_P(AsyncEnd2endServerTryCancel, + AsyncEnd2endServerTryCancelTest, + ::testing::ValuesIn(CreateTestScenarios(false, + false))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + // Change the backup poll interval from 5s to 100ms to speed up the + // ReconnectChannel test + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 100); + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/end2end/cfstream_test.cc b/test/cpp/end2end/cfstream_test.cc new file mode 100644 index 00000000..a5e65cb2 --- /dev/null +++ b/test/cpp/end2end/cfstream_test.cc @@ -0,0 +1,478 @@ +/* + * + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_CFSTREAM +using grpc::ClientAsyncResponseReader; +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using grpc::testing::RequestParams; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace { + +struct TestScenario { + TestScenario(const std::string& creds_type, const std::string& content) + : credentials_type(creds_type), message_content(content) {} + const std::string credentials_type; + const std::string message_content; +}; + +class CFStreamTest : public ::testing::TestWithParam { + protected: + CFStreamTest() + : server_host_("grpctest"), + interface_("lo0"), + ipv4_address_("10.0.0.1") {} + + void DNSUp() { + std::ostringstream cmd; + // Add DNS entry for server_host_ in /etc/hosts + cmd << "echo '" << ipv4_address_ << " " << server_host_ + << " ' | sudo tee -a /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void DNSDown() { + std::ostringstream cmd; + // Remove DNS entry for server_host_ in /etc/hosts + cmd << "sudo sed -i '.bak' '/" << server_host_ << "/d' /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void InterfaceUp() { + std::ostringstream cmd; + cmd << "sudo /sbin/ifconfig " << interface_ << " alias " << ipv4_address_; + std::system(cmd.str().c_str()); + } + + void InterfaceDown() { + std::ostringstream cmd; + cmd << "sudo /sbin/ifconfig " << interface_ << " -alias " << ipv4_address_; + std::system(cmd.str().c_str()); + } + + void NetworkUp() { + gpr_log(GPR_DEBUG, "Bringing network up"); + InterfaceUp(); + DNSUp(); + } + + void NetworkDown() { + gpr_log(GPR_DEBUG, "Bringing network down"); + InterfaceDown(); + DNSDown(); + } + + void SetUp() override { + NetworkUp(); + grpc_init(); + StartServer(); + } + + void TearDown() override { + NetworkDown(); + StopServer(); + grpc_shutdown(); + } + + void StartServer() { + port_ = grpc_pick_unused_port_or_die(); + server_.reset(new ServerData(port_, GetParam().credentials_type)); + server_->Start(server_host_); + } + void StopServer() { server_->Shutdown(); } + + std::unique_ptr BuildStub( + const std::shared_ptr& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr BuildChannel() { + std::ostringstream server_address; + server_address << server_host_ << ":" << port_; + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + return CreateCustomChannel(server_address.str(), channel_creds, args); + } + + void SendRpc( + const std::unique_ptr& stub, + bool expect_success = false) { + auto response = std::unique_ptr(new EchoResponse()); + EchoRequest request; + auto& msg = GetParam().message_content; + request.set_message(msg); + ClientContext context; + Status status = stub->Echo(&context, request, response.get()); + if (status.ok()) { + gpr_log(GPR_DEBUG, "RPC with succeeded"); + EXPECT_EQ(msg, response->message()); + } else { + gpr_log(GPR_DEBUG, "RPC failed: %s", status.error_message().c_str()); + } + if 
(expect_success) { + EXPECT_TRUE(status.ok()); + } + } + void SendAsyncRpc( + const std::unique_ptr& stub, + RequestParams param = RequestParams()) { + EchoRequest request; + request.set_message(GetParam().message_content); + *request.mutable_param() = std::move(param); + AsyncClientCall* call = new AsyncClientCall; + + call->response_reader = + stub->PrepareAsyncEcho(&call->context, request, &cq_); + + call->response_reader->StartCall(); + call->response_reader->Finish(&call->reply, &call->status, (void*)call); + } + + void ShutdownCQ() { cq_.Shutdown(); } + + bool CQNext(void** tag, bool* ok) { + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10); + auto ret = cq_.AsyncNext(tag, ok, deadline); + if (ret == grpc::CompletionQueue::GOT_EVENT) { + return true; + } else if (ret == grpc::CompletionQueue::SHUTDOWN) { + return false; + } else { + GPR_ASSERT(ret == grpc::CompletionQueue::TIMEOUT); + // This can happen if we hit the Apple CFStream bug which results in the + // read stream freezing. We are ignoring hangs and timeouts, but these + // tests are still useful as they can catch memory memory corruptions, + // crashes and other bugs that don't result in test freeze/timeout. + return false; + } + } + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(false /* try_to_connect */)) == + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 10) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(true /* try_to_connect */)) != + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + struct AsyncClientCall { + EchoResponse reply; + ClientContext context; + Status status; + std::unique_ptr> response_reader; + }; + + private: + struct ServerData { + int port_; + const std::string creds_; + std::unique_ptr server_; + TestServiceImpl service_; + std::unique_ptr thread_; + bool server_ready_ = false; + + ServerData(int port, const std::string& creds) + : port_(port), creds_(creds) {} + + void Start(const std::string& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + std::mutex mu; + std::unique_lock lock(mu); + std::condition_variable cond; + thread_.reset(new std::thread( + std::bind(&ServerData::Serve, this, server_host, &mu, &cond))); + cond.wait(lock, [this] { return server_ready_; }); + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const std::string& server_host, std::mutex* mu, + std::condition_variable* cond) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + auto server_creds = + GetCredentialsProvider()->GetServerCredentials(creds_); + builder.AddListeningPort(server_address.str(), server_creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + std::lock_guard lock(*mu); + server_ready_ = true; + cond->notify_one(); + } + + void Shutdown(bool join = true) { + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + if (join) thread_->join(); + } + }; + + CompletionQueue cq_; + const std::string server_host_; + const std::string interface_; + 
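+  // Hedged note: server_host_ ("grpctest"), interface_ ("lo0") and
+  // ipv4_address_ ("10.0.0.1") are what NetworkUp()/NetworkDown() manipulate:
+  // the tests alias and un-alias the address on the loopback interface via
+  // ifconfig, and add or remove the matching /etc/hosts entry, to simulate a
+  // network flap.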
const std::string ipv4_address_; + std::unique_ptr server_; + int port_; +}; + +std::vector CreateTestScenarios() { + std::vector scenarios; + std::vector credentials_types; + std::vector messages; + + credentials_types.push_back(kInsecureCredentialsType); + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + + messages.push_back("🖖"); + for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; k *= 32) { + std::string big_msg; + for (size_t i = 0; i < k * 1024; ++i) { + char c = 'a' + (i % 26); + big_msg += c; + } + messages.push_back(big_msg); + } + for (auto cred = credentials_types.begin(); cred != credentials_types.end(); + ++cred) { + for (auto msg = messages.begin(); msg != messages.end(); msg++) { + scenarios.emplace_back(*cred, *msg); + } + } + + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(CFStreamTest, CFStreamTest, + ::testing::ValuesIn(CreateTestScenarios())); + +// gRPC should automatically detect network flaps (without enabling keepalives) +// when CFStream is enabled +TEST_P(CFStreamTest, NetworkTransition) { + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + SendRpc(stub, /*expect_success=*/true); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + std::atomic_bool shutdown{false}; + std::thread sender = std::thread([this, &stub, &shutdown]() { + while (true) { + if (shutdown.load()) { + return; + } + SendRpc(stub); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + + // bring down network + NetworkDown(); + + // network going down should be detected by cfstream + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + + // bring network interface back up + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + NetworkUp(); + + // channel should reconnect + EXPECT_TRUE(WaitForChannelReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + shutdown.store(true); + sender.join(); +} + +// Network flaps while RPCs are in flight +TEST_P(CFStreamTest, NetworkFlapRpcsInFlight) { + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + std::atomic_int rpcs_sent{0}; + + // Channel should be in READY state after we send some RPCs + for (int i = 0; i < 10; ++i) { + RequestParams param; + param.set_skip_cancelled_check(true); + SendAsyncRpc(stub, param); + ++rpcs_sent; + } + EXPECT_TRUE(WaitForChannelReady(channel.get())); + + // Bring down the network + NetworkDown(); + + std::thread thd = std::thread([this, &rpcs_sent]() { + void* got_tag; + bool ok = false; + bool network_down = true; + int total_completions = 0; + + while (CQNext(&got_tag, &ok)) { + ++total_completions; + GPR_ASSERT(ok); + AsyncClientCall* call = static_cast(got_tag); + if (!call->status.ok()) { + gpr_log(GPR_DEBUG, "RPC failed with error: %s", + call->status.error_message().c_str()); + // Bring network up when RPCs start failing + if (network_down) { + NetworkUp(); + network_down = false; + } + } else { + gpr_log(GPR_DEBUG, "RPC succeeded"); + } + delete call; + } + // Remove line below and uncomment the following line after Apple CFStream + // bug has been fixed. 
+ (void)rpcs_sent; + // EXPECT_EQ(total_completions, rpcs_sent); + }); + + for (int i = 0; i < 100; ++i) { + RequestParams param; + param.set_skip_cancelled_check(true); + SendAsyncRpc(stub, param); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + ++rpcs_sent; + } + + ShutdownCQ(); + + thd.join(); +} + +// Send a bunch of RPCs, some of which are expected to fail. +// We should get back a response for all RPCs +TEST_P(CFStreamTest, ConcurrentRpc) { + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + std::atomic_int rpcs_sent{0}; + std::thread thd = std::thread([this, &rpcs_sent]() { + void* got_tag; + bool ok = false; + int total_completions = 0; + + while (CQNext(&got_tag, &ok)) { + ++total_completions; + GPR_ASSERT(ok); + AsyncClientCall* call = static_cast(got_tag); + if (!call->status.ok()) { + gpr_log(GPR_DEBUG, "RPC failed with error: %s", + call->status.error_message().c_str()); + // Bring network up when RPCs start failing + } else { + gpr_log(GPR_DEBUG, "RPC succeeded"); + } + delete call; + } + // Remove line below and uncomment the following line after Apple CFStream + // bug has been fixed. + (void)rpcs_sent; + // EXPECT_EQ(total_completions, rpcs_sent); + }); + + for (int i = 0; i < 10; ++i) { + if (i % 3 == 0) { + RequestParams param; + ErrorStatus* error = param.mutable_expected_error(); + error->set_code(StatusCode::INTERNAL); + error->set_error_message("internal error"); + SendAsyncRpc(stub, param); + } else if (i % 5 == 0) { + RequestParams param; + param.set_echo_metadata(true); + DebugInfo* info = param.mutable_debug_info(); + info->add_stack_entries("stack_entry1"); + info->add_stack_entries("stack_entry2"); + info->set_detail("detailed debug info"); + SendAsyncRpc(stub, param); + } else { + SendAsyncRpc(stub); + } + ++rpcs_sent; + } + + ShutdownCQ(); + + thd.join(); +} + +} // namespace +} // namespace testing +} // namespace grpc +#endif // GRPC_CFSTREAM + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + gpr_setenv("grpc_cfstream", "1"); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/end2end/channelz_service_test.cc b/test/cpp/end2end/channelz_service_test.cc new file mode 100644 index 00000000..6395f7a4 --- /dev/null +++ b/test/cpp/end2end/channelz_service_test.cc @@ -0,0 +1,937 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/security/credentials/tls/grpc_tls_certificate_provider.h" +#include "src/core/lib/security/security_connector/ssl_utils.h" +#include "src/core/lib/slice/slice_utils.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/proto/grpc/channelz/channelz.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +using grpc::channelz::v1::Address; +using grpc::channelz::v1::GetChannelRequest; +using grpc::channelz::v1::GetChannelResponse; +using grpc::channelz::v1::GetServerRequest; +using grpc::channelz::v1::GetServerResponse; +using grpc::channelz::v1::GetServerSocketsRequest; +using grpc::channelz::v1::GetServerSocketsResponse; +using grpc::channelz::v1::GetServersRequest; +using grpc::channelz::v1::GetServersResponse; +using grpc::channelz::v1::GetSocketRequest; +using grpc::channelz::v1::GetSocketResponse; +using grpc::channelz::v1::GetSubchannelRequest; +using grpc::channelz::v1::GetSubchannelResponse; +using grpc::channelz::v1::GetTopChannelsRequest; +using grpc::channelz::v1::GetTopChannelsResponse; + +namespace grpc { +namespace testing { +namespace { + +static bool ValidateAddress(const Address& address) { + if (address.address_case() != Address::kTcpipAddress) { + return true; + } + return address.tcpip_address().ip_address().size() == 4 || + address.tcpip_address().ip_address().size() == 16; +} + +// Proxy service supports N backends. Sends RPC to backend dictated by +// request->backend_channel_idx(). 
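+// The proxy is the process under test here: channelz is enabled on its
+// outbound channels to the backends, while the backends themselves are built
+// with channelz disabled (see ConfigureProxy() below).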
+class Proxy : public ::grpc::testing::EchoTestService::Service { + public: + Proxy() {} + + void AddChannelToBackend(const std::shared_ptr& channel) { + stubs_.push_back(grpc::testing::EchoTestService::NewStub(channel)); + } + + Status Echo(ServerContext* server_context, const EchoRequest* request, + EchoResponse* response) override { + std::unique_ptr client_context = + ClientContext::FromServerContext(*server_context); + size_t idx = request->param().backend_channel_idx(); + GPR_ASSERT(idx < stubs_.size()); + return stubs_[idx]->Echo(client_context.get(), *request, response); + } + + Status BidiStream(ServerContext* server_context, + ServerReaderWriter* + stream_from_client) override { + EchoRequest request; + EchoResponse response; + std::unique_ptr client_context = + ClientContext::FromServerContext(*server_context); + + // always use the first proxy for streaming + auto stream_to_backend = stubs_[0]->BidiStream(client_context.get()); + while (stream_from_client->Read(&request)) { + stream_to_backend->Write(request); + stream_to_backend->Read(&response); + stream_from_client->Write(response); + } + + stream_to_backend->WritesDone(); + return stream_to_backend->Finish(); + } + + private: + std::vector> stubs_; +}; + +enum class CredentialsType { + kInsecure = 0, + kTls = 1, + kMtls = 2, +}; + +constexpr char kCaCertPath[] = "src/core/tsi/test_creds/ca.pem"; +constexpr char kServerCertPath[] = "src/core/tsi/test_creds/server1.pem"; +constexpr char kServerKeyPath[] = "src/core/tsi/test_creds/server1.key"; +constexpr char kClientCertPath[] = "src/core/tsi/test_creds/client.pem"; +constexpr char kClientKeyPath[] = "src/core/tsi/test_creds/client.key"; + +std::string ReadFile(const char* file_path) { + grpc_slice slice; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("load_file", grpc_load_file(file_path, 0, &slice))); + std::string file_contents(grpc_core::StringViewFromSlice(slice)); + grpc_slice_unref(slice); + return file_contents; +} + +grpc_core::PemKeyCertPairList ReadTlsIdentityPair(const char* key_path, + const char* cert_path) { + return grpc_core::PemKeyCertPairList{ + grpc_core::PemKeyCertPair(ReadFile(key_path), ReadFile(cert_path))}; +} + +std::shared_ptr GetChannelCredentials( + CredentialsType type, ChannelArguments* args) { + if (type == CredentialsType::kInsecure) { + return InsecureChannelCredentials(); + } + args->SetSslTargetNameOverride("foo.test.google.fr"); + std::vector identity_key_cert_pairs = { + {ReadFile(kClientKeyPath), ReadFile(kClientCertPath)}}; + grpc::experimental::TlsChannelCredentialsOptions options; + options.set_certificate_provider( + std::make_shared( + ReadFile(kCaCertPath), identity_key_cert_pairs)); + if (type == CredentialsType::kMtls) { + options.watch_identity_key_cert_pairs(); + } + options.watch_root_certs(); + return grpc::experimental::TlsCredentials(options); +} + +std::shared_ptr GetServerCredentials( + CredentialsType type) { + if (type == CredentialsType::kInsecure) { + return InsecureServerCredentials(); + } + std::vector identity_key_cert_pairs = { + {ReadFile(kServerKeyPath), ReadFile(kServerCertPath)}}; + auto certificate_provider = + std::make_shared( + ReadFile(kCaCertPath), identity_key_cert_pairs); + grpc::experimental::TlsServerCredentialsOptions options(certificate_provider); + options.watch_root_certs(); + options.watch_identity_key_cert_pairs(); + options.set_cert_request_type(GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY); + return grpc::experimental::TlsServerCredentials(options); +} + +std::string RemoveWhitespaces(std::string input) 
{ + input.erase(remove_if(input.begin(), input.end(), isspace), input.end()); + return input; +} + +class ChannelzServerTest : public ::testing::TestWithParam { + public: + ChannelzServerTest() {} + static void SetUpTestCase() { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + void SetUp() override { + // ensure the channelz service is brought up on all servers we build. + ::grpc::channelz::experimental::InitChannelzService(); + + // We set up a proxy server with channelz enabled. + proxy_port_ = grpc_pick_unused_port_or_die(); + ServerBuilder proxy_builder; + std::string proxy_server_address = "localhost:" + to_string(proxy_port_); + proxy_builder.AddListeningPort(proxy_server_address, + GetServerCredentials(GetParam())); + // forces channelz and channel tracing to be enabled. + proxy_builder.AddChannelArgument(GRPC_ARG_ENABLE_CHANNELZ, 1); + proxy_builder.AddChannelArgument( + GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, 1024); + proxy_builder.RegisterService(&proxy_service_); + proxy_server_ = proxy_builder.BuildAndStart(); + } + + // Sets the proxy up to have an arbitrary number of backends. + void ConfigureProxy(size_t num_backends) { + backends_.resize(num_backends); + for (size_t i = 0; i < num_backends; ++i) { + // create a new backend. + backends_[i].port = grpc_pick_unused_port_or_die(); + ServerBuilder backend_builder; + std::string backend_server_address = + "localhost:" + to_string(backends_[i].port); + backend_builder.AddListeningPort(backend_server_address, + GetServerCredentials(GetParam())); + backends_[i].service = absl::make_unique(); + // ensure that the backend itself has channelz disabled. + backend_builder.AddChannelArgument(GRPC_ARG_ENABLE_CHANNELZ, 0); + backend_builder.RegisterService(backends_[i].service.get()); + backends_[i].server = backend_builder.BuildAndStart(); + // set up a channel to the backend. We ensure that this channel has + // channelz enabled since these channels (proxy outbound to backends) + // are the ones that our test will actually be validating. + ChannelArguments args; + args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 1); + args.SetInt(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, 1024); + std::shared_ptr channel_to_backend = ::grpc::CreateCustomChannel( + backend_server_address, GetChannelCredentials(GetParam(), &args), + args); + proxy_service_.AddChannelToBackend(channel_to_backend); + } + } + + void ResetStubs() { + string target = "dns:localhost:" + to_string(proxy_port_); + ChannelArguments args; + // disable channelz. We only want to focus on proxy to backend outbound. + args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 0); + std::shared_ptr channel = ::grpc::CreateCustomChannel( + target, GetChannelCredentials(GetParam(), &args), args); + channelz_stub_ = grpc::channelz::v1::Channelz::NewStub(channel); + echo_stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr NewEchoStub() { + string target = "dns:localhost:" + to_string(proxy_port_); + ChannelArguments args; + // disable channelz. We only want to focus on proxy to backend outbound. + args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 0); + // This ensures that gRPC will not do connection sharing. 
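+ // As a result each stub returned by NewEchoStub() gets its own subchannel
+ // and server-side socket, which the GetServerSocketsPagination test below
+ // relies on when counting sockets.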
+ args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true); + std::shared_ptr channel = ::grpc::CreateCustomChannel( + target, GetChannelCredentials(GetParam(), &args), args); + return grpc::testing::EchoTestService::NewStub(channel); + } + + void SendSuccessfulEcho(int channel_idx) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + request.mutable_param()->set_backend_channel_idx(channel_idx); + ClientContext context; + Status s = echo_stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } + + void SendSuccessfulStream(int num_messages) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + ClientContext context; + auto stream_to_proxy = echo_stub_->BidiStream(&context); + for (int i = 0; i < num_messages; ++i) { + EXPECT_TRUE(stream_to_proxy->Write(request)); + EXPECT_TRUE(stream_to_proxy->Read(&response)); + } + stream_to_proxy->WritesDone(); + Status s = stream_to_proxy->Finish(); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } + + void SendFailedEcho(int channel_idx) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + request.mutable_param()->set_backend_channel_idx(channel_idx); + auto* error = request.mutable_param()->mutable_expected_error(); + error->set_code(13); // INTERNAL + error->set_error_message("error"); + ClientContext context; + Status s = echo_stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + } + + // Uses GetTopChannels to return the channel_id of a particular channel, + // so that the unit tests may test GetChannel call. + intptr_t GetChannelId(int channel_idx) { + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(0); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_GT(response.channel_size(), channel_idx); + return response.channel(channel_idx).ref().channel_id(); + } + + static string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + protected: + // package of data needed for each backend server. + struct BackendData { + std::unique_ptr server; + int port; + std::unique_ptr service; + }; + + std::unique_ptr channelz_stub_; + std::unique_ptr echo_stub_; + + // proxy server to ping with channelz requests. + std::unique_ptr proxy_server_; + int proxy_port_; + Proxy proxy_service_; + + // backends. All implement the echo service. 
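+ // (one BackendData entry per backend created by ConfigureProxy()).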
+ std::vector backends_; +}; + +TEST_P(ChannelzServerTest, BasicTest) { + ResetStubs(); + ConfigureProxy(1); + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(0); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel_size(), 1); +} + +TEST_P(ChannelzServerTest, HighStartId) { + ResetStubs(); + ConfigureProxy(1); + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(10000); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel_size(), 0); +} + +TEST_P(ChannelzServerTest, SuccessfulRequestTest) { + ResetStubs(); + ConfigureProxy(1); + SendSuccessfulEcho(0); + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), 1); + EXPECT_EQ(response.channel().data().calls_succeeded(), 1); + EXPECT_EQ(response.channel().data().calls_failed(), 0); +} + +TEST_P(ChannelzServerTest, FailedRequestTest) { + ResetStubs(); + ConfigureProxy(1); + SendFailedEcho(0); + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), 1); + EXPECT_EQ(response.channel().data().calls_succeeded(), 0); + EXPECT_EQ(response.channel().data().calls_failed(), 1); +} + +TEST_P(ChannelzServerTest, ManyRequestsTest) { + ResetStubs(); + ConfigureProxy(1); + // send some RPCs + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(0); + } + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), + kNumSuccess + kNumFailed); + EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed); +} + +TEST_P(ChannelzServerTest, ManyChannels) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(0); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel_size(), kNumChannels); +} + +TEST_P(ChannelzServerTest, ManyRequestsManyChannels) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + SendSuccessfulEcho(2); + } + for (int i = 0; i < kNumFailed; ++i) { + 
SendFailedEcho(1); + SendFailedEcho(2); + } + + // the first channel saw only successes + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_failed(), 0); + } + + // the second channel saw only failures + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(1)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), kNumFailed); + EXPECT_EQ(response.channel().data().calls_succeeded(), 0); + EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed); + } + + // the third channel saw both + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(2)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), + kNumSuccess + kNumFailed); + EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed); + } + + // the fourth channel saw nothing + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(3)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), 0); + EXPECT_EQ(response.channel().data().calls_succeeded(), 0); + EXPECT_EQ(response.channel().data().calls_failed(), 0); + } +} + +TEST_P(ChannelzServerTest, ManySubchannels) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + SendSuccessfulEcho(2); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(1); + SendFailedEcho(2); + } + GetTopChannelsRequest gtc_request; + GetTopChannelsResponse gtc_response; + gtc_request.set_start_channel_id(0); + ClientContext context; + Status s = + channelz_stub_->GetTopChannels(&context, gtc_request, &gtc_response); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(gtc_response.channel_size(), kNumChannels); + for (int i = 0; i < gtc_response.channel_size(); ++i) { + // if the channel sent no RPCs, then expect no subchannels to have been + // created. + if (gtc_response.channel(i).data().calls_started() == 0) { + EXPECT_EQ(gtc_response.channel(i).subchannel_ref_size(), 0); + continue; + } + // The resolver must return at least one address. 
+ ASSERT_GT(gtc_response.channel(i).subchannel_ref_size(), 0); + GetSubchannelRequest gsc_request; + GetSubchannelResponse gsc_response; + gsc_request.set_subchannel_id( + gtc_response.channel(i).subchannel_ref(0).subchannel_id()); + ClientContext context; + Status s = + channelz_stub_->GetSubchannel(&context, gsc_request, &gsc_response); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(gtc_response.channel(i).data().calls_started(), + gsc_response.subchannel().data().calls_started()); + EXPECT_EQ(gtc_response.channel(i).data().calls_succeeded(), + gsc_response.subchannel().data().calls_succeeded()); + EXPECT_EQ(gtc_response.channel(i).data().calls_failed(), + gsc_response.subchannel().data().calls_failed()); + } +} + +TEST_P(ChannelzServerTest, BasicServerTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest request; + GetServersResponse response; + request.set_start_server_id(0); + ClientContext context; + Status s = channelz_stub_->GetServers(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.server_size(), 1); +} + +TEST_P(ChannelzServerTest, BasicGetServerTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest get_servers_request; + GetServersResponse get_servers_response; + get_servers_request.set_start_server_id(0); + ClientContext get_servers_context; + Status s = channelz_stub_->GetServers( + &get_servers_context, get_servers_request, &get_servers_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_servers_response.server_size(), 1); + GetServerRequest get_server_request; + GetServerResponse get_server_response; + get_server_request.set_server_id( + get_servers_response.server(0).ref().server_id()); + ClientContext get_server_context; + s = channelz_stub_->GetServer(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_servers_response.server(0).ref().server_id(), + get_server_response.server().ref().server_id()); +} + +TEST_P(ChannelzServerTest, ServerCallTest) { + ResetStubs(); + ConfigureProxy(1); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(0); + } + GetServersRequest request; + GetServersResponse response; + request.set_start_server_id(0); + ClientContext context; + Status s = channelz_stub_->GetServers(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.server_size(), 1); + EXPECT_EQ(response.server(0).data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.server(0).data().calls_failed(), kNumFailed); + // This is success+failure+1 because the call that retrieved this information + // will be counted as started. It will not track success/failure until after + // it has returned, so that is not included in the response. 
+ EXPECT_EQ(response.server(0).data().calls_started(), + kNumSuccess + kNumFailed + 1); +} + +TEST_P(ChannelzServerTest, ManySubchannelsAndSockets) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + SendSuccessfulEcho(2); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(1); + SendFailedEcho(2); + } + GetTopChannelsRequest gtc_request; + GetTopChannelsResponse gtc_response; + gtc_request.set_start_channel_id(0); + ClientContext context; + Status s = + channelz_stub_->GetTopChannels(&context, gtc_request, &gtc_response); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(gtc_response.channel_size(), kNumChannels); + for (int i = 0; i < gtc_response.channel_size(); ++i) { + // if the channel sent no RPCs, then expect no subchannels to have been + // created. + if (gtc_response.channel(i).data().calls_started() == 0) { + EXPECT_EQ(gtc_response.channel(i).subchannel_ref_size(), 0); + continue; + } + // The resolver must return at least one address. + ASSERT_GT(gtc_response.channel(i).subchannel_ref_size(), 0); + // First grab the subchannel + GetSubchannelRequest get_subchannel_req; + GetSubchannelResponse get_subchannel_resp; + get_subchannel_req.set_subchannel_id( + gtc_response.channel(i).subchannel_ref(0).subchannel_id()); + ClientContext get_subchannel_ctx; + Status s = channelz_stub_->GetSubchannel( + &get_subchannel_ctx, get_subchannel_req, &get_subchannel_resp); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(get_subchannel_resp.subchannel().socket_ref_size(), 1); + // Now grab the socket. + GetSocketRequest get_socket_req; + GetSocketResponse get_socket_resp; + ClientContext get_socket_ctx; + get_socket_req.set_socket_id( + get_subchannel_resp.subchannel().socket_ref(0).socket_id()); + s = channelz_stub_->GetSocket(&get_socket_ctx, get_socket_req, + &get_socket_resp); + EXPECT_TRUE( + get_subchannel_resp.subchannel().socket_ref(0).name().find("http")); + EXPECT_TRUE(s.ok()) << s.error_message(); + // calls started == streams started AND streams succeeded. Since none of + // these RPCs were canceled, all of the streams will succeed even though + // the RPCs they represent might have failed. + EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(), + get_socket_resp.socket().data().streams_started()); + EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(), + get_socket_resp.socket().data().streams_succeeded()); + // All of the calls were unary, so calls started == messages sent. + EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(), + get_socket_resp.socket().data().messages_sent()); + // We only get responses when the RPC was successful, so + // calls succeeded == messages received. 
+ EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_succeeded(), + get_socket_resp.socket().data().messages_received()); + switch (GetParam()) { + case CredentialsType::kInsecure: + EXPECT_FALSE(get_socket_resp.socket().has_security()); + break; + case CredentialsType::kTls: + case CredentialsType::kMtls: + EXPECT_TRUE(get_socket_resp.socket().has_security()); + EXPECT_TRUE(get_socket_resp.socket().security().has_tls()); + EXPECT_EQ( + RemoveWhitespaces( + get_socket_resp.socket().security().tls().remote_certificate()), + RemoveWhitespaces(ReadFile(kServerCertPath))); + break; + } + } +} + +TEST_P(ChannelzServerTest, StreamingRPC) { + ResetStubs(); + ConfigureProxy(1); + const int kNumMessages = 5; + SendSuccessfulStream(kNumMessages); + // Get the channel + GetChannelRequest get_channel_request; + GetChannelResponse get_channel_response; + get_channel_request.set_channel_id(GetChannelId(0)); + ClientContext get_channel_context; + Status s = channelz_stub_->GetChannel( + &get_channel_context, get_channel_request, &get_channel_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_channel_response.channel().data().calls_started(), 1); + EXPECT_EQ(get_channel_response.channel().data().calls_succeeded(), 1); + EXPECT_EQ(get_channel_response.channel().data().calls_failed(), 0); + // Get the subchannel + ASSERT_GT(get_channel_response.channel().subchannel_ref_size(), 0); + GetSubchannelRequest get_subchannel_request; + GetSubchannelResponse get_subchannel_response; + ClientContext get_subchannel_context; + get_subchannel_request.set_subchannel_id( + get_channel_response.channel().subchannel_ref(0).subchannel_id()); + s = channelz_stub_->GetSubchannel(&get_subchannel_context, + get_subchannel_request, + &get_subchannel_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_subchannel_response.subchannel().data().calls_started(), 1); + EXPECT_EQ(get_subchannel_response.subchannel().data().calls_succeeded(), 1); + EXPECT_EQ(get_subchannel_response.subchannel().data().calls_failed(), 0); + // Get the socket + ASSERT_GT(get_subchannel_response.subchannel().socket_ref_size(), 0); + GetSocketRequest get_socket_request; + GetSocketResponse get_socket_response; + ClientContext get_socket_context; + get_socket_request.set_socket_id( + get_subchannel_response.subchannel().socket_ref(0).socket_id()); + EXPECT_TRUE( + get_subchannel_response.subchannel().socket_ref(0).name().find("http")); + s = channelz_stub_->GetSocket(&get_socket_context, get_socket_request, + &get_socket_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_socket_response.socket().data().streams_started(), 1); + EXPECT_EQ(get_socket_response.socket().data().streams_succeeded(), 1); + EXPECT_EQ(get_socket_response.socket().data().streams_failed(), 0); + EXPECT_EQ(get_socket_response.socket().data().messages_sent(), kNumMessages); + EXPECT_EQ(get_socket_response.socket().data().messages_received(), + kNumMessages); + switch (GetParam()) { + case CredentialsType::kInsecure: + EXPECT_FALSE(get_socket_response.socket().has_security()); + break; + case CredentialsType::kTls: + case CredentialsType::kMtls: + EXPECT_TRUE(get_socket_response.socket().has_security()); + EXPECT_TRUE(get_socket_response.socket().security().has_tls()); + EXPECT_EQ(RemoveWhitespaces(get_socket_response.socket() + .security() + .tls() + .remote_certificate()), + RemoveWhitespaces(ReadFile(kServerCertPath))); + break; + } +} + 
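+// The remaining ChannelzServerTest cases exercise the server-side channelz
+// queries: GetServers/GetServerSockets (including pagination) and the
+// server's listen sockets.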
+TEST_P(ChannelzServerTest, GetServerSocketsTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest get_server_request; + GetServersResponse get_server_response; + get_server_request.set_start_server_id(0); + ClientContext get_server_context; + Status s = channelz_stub_->GetServers(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_response.server_size(), 1); + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), 1); + EXPECT_TRUE(get_server_sockets_response.socket_ref(0).name().find("http")); + // Get the socket to verify security information. + GetSocketRequest get_socket_request; + GetSocketResponse get_socket_response; + ClientContext get_socket_context; + get_socket_request.set_socket_id( + get_server_sockets_response.socket_ref(0).socket_id()); + s = channelz_stub_->GetSocket(&get_socket_context, get_socket_request, + &get_socket_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_TRUE(ValidateAddress(get_socket_response.socket().remote())); + EXPECT_TRUE(ValidateAddress(get_socket_response.socket().local())); + switch (GetParam()) { + case CredentialsType::kInsecure: + EXPECT_FALSE(get_socket_response.socket().has_security()); + break; + case CredentialsType::kTls: + case CredentialsType::kMtls: + EXPECT_TRUE(get_socket_response.socket().has_security()); + EXPECT_TRUE(get_socket_response.socket().security().has_tls()); + if (GetParam() == CredentialsType::kMtls) { + EXPECT_EQ(RemoveWhitespaces(get_socket_response.socket() + .security() + .tls() + .remote_certificate()), + RemoveWhitespaces(ReadFile(kClientCertPath))); + } else { + EXPECT_TRUE(get_socket_response.socket() + .security() + .tls() + .remote_certificate() + .empty()); + } + break; + } +} + +TEST_P(ChannelzServerTest, GetServerSocketsPaginationTest) { + ResetStubs(); + ConfigureProxy(1); + std::vector> stubs; + const int kNumServerSocketsCreated = 20; + for (int i = 0; i < kNumServerSocketsCreated; ++i) { + stubs.push_back(NewEchoStub()); + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + request.mutable_param()->set_backend_channel_idx(0); + ClientContext context; + Status s = stubs.back()->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } + GetServersRequest get_server_request; + GetServersResponse get_server_response; + get_server_request.set_start_server_id(0); + ClientContext get_server_context; + Status s = channelz_stub_->GetServers(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_response.server_size(), 1); + // Make a request that gets all of the serversockets + { + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + 
get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + // We add one to account for the channelz stub that will end up creating + // a serversocket. + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), + kNumServerSocketsCreated + 1); + EXPECT_TRUE(get_server_sockets_response.end()); + } + // Now we make a request that exercises pagination. + { + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + const int kMaxResults = 10; + get_server_sockets_request.set_max_results(kMaxResults); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), kMaxResults); + EXPECT_FALSE(get_server_sockets_response.end()); + } +} + +TEST_P(ChannelzServerTest, GetServerListenSocketsTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest get_server_request; + GetServersResponse get_server_response; + get_server_request.set_start_server_id(0); + ClientContext get_server_context; + Status s = channelz_stub_->GetServers(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_response.server_size(), 1); + // The resolver might return one or two addresses depending on the + // configuration, one for ipv4 and one for ipv6. 
+ int listen_socket_size = get_server_response.server(0).listen_socket_size(); + EXPECT_TRUE(listen_socket_size == 1 || listen_socket_size == 2); + GetSocketRequest get_socket_request; + GetSocketResponse get_socket_response; + get_socket_request.set_socket_id( + get_server_response.server(0).listen_socket(0).socket_id()); + EXPECT_TRUE( + get_server_response.server(0).listen_socket(0).name().find("http")); + ClientContext get_socket_context_1; + s = channelz_stub_->GetSocket(&get_socket_context_1, get_socket_request, + &get_socket_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + + EXPECT_TRUE(ValidateAddress(get_socket_response.socket().remote())); + EXPECT_TRUE(ValidateAddress(get_socket_response.socket().local())); + if (listen_socket_size == 2) { + get_socket_request.set_socket_id( + get_server_response.server(0).listen_socket(1).socket_id()); + ClientContext get_socket_context_2; + EXPECT_TRUE( + get_server_response.server(0).listen_socket(1).name().find("http")); + s = channelz_stub_->GetSocket(&get_socket_context_2, get_socket_request, + &get_socket_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } +} + +INSTANTIATE_TEST_SUITE_P(ChannelzServer, ChannelzServerTest, + ::testing::ValuesIn(std::vector( + {CredentialsType::kInsecure, CredentialsType::kTls, + CredentialsType::kMtls}))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/client_callback_end2end_test.cc b/test/cpp/end2end/client_callback_end2end_test.cc new file mode 100644 index 00000000..9118d96a --- /dev/null +++ b/test/cpp/end2end/client_callback_end2end_test.cc @@ -0,0 +1,1578 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { +namespace { + +enum class Protocol { INPROC, TCP }; + +class TestScenario { + public: + TestScenario(bool serve_callback, Protocol protocol, bool intercept, + const std::string& creds_type) + : callback_server(serve_callback), + protocol(protocol), + use_interceptors(intercept), + credentials_type(creds_type) {} + void Log() const; + bool callback_server; + Protocol protocol; + bool use_interceptors; + const std::string credentials_type; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{callback_server=" + << (scenario.callback_server ? "true" : "false") << ",protocol=" + << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP") + << ",intercept=" << (scenario.use_interceptors ? "true" : "false") + << ",creds=" << scenario.credentials_type << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_DEBUG, "%s", out.str().c_str()); +} + +class ClientCallbackEnd2endTest + : public ::testing::TestWithParam { + protected: + ClientCallbackEnd2endTest() { GetParam().Log(); } + + void SetUp() override { + ServerBuilder builder; + + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + // TODO(vjpai): Support testing of AuthMetadataProcessor + + if (GetParam().protocol == Protocol::TCP) { + picked_port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << picked_port_; + builder.AddListeningPort(server_address_.str(), server_creds); + } + if (!GetParam().callback_server) { + builder.RegisterService(&service_); + } else { + builder.RegisterService(&callback_service_); + } + + if (GetParam().use_interceptors) { + std::vector< + std::unique_ptr> + creators; + // Add 20 phony server interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + } + + server_ = builder.BuildAndStart(); + is_server_started_ = true; + } + + void ResetStub( + std::unique_ptr + interceptor = nullptr) { + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + auto interceptors = CreatePhonyClientInterceptors(); + if (interceptor != nullptr) interceptors.push_back(std::move(interceptor)); + switch (GetParam().protocol) { + case Protocol::TCP: + if (!GetParam().use_interceptors) { + channel_ = ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args); + } else { + channel_ = CreateCustomChannelWithInterceptors( + server_address_.str(), channel_creds, args, + std::move(interceptors)); + } + break; + case Protocol::INPROC: + if (!GetParam().use_interceptors) { + channel_ = server_->InProcessChannel(args); + } else { + channel_ = 
server_->experimental().InProcessChannelWithInterceptors( + args, std::move(interceptors)); + } + break; + default: + assert(false); + } + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + generic_stub_ = absl::make_unique(channel_); + PhonyInterceptor::Reset(); + } + + void TearDown() override { + if (is_server_started_) { + // Although we would normally do an explicit shutdown, the server + // should also work correctly with just a destructor call. The regular + // end2end test uses explicit shutdown, so let this one just do reset. + server_.reset(); + } + if (picked_port_ > 0) { + grpc_recycle_unused_port(picked_port_); + } + } + + void SendRpcs(int num_rpcs, bool with_binary_metadata) { + std::string test_string(""); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + + test_string += "Hello world. "; + request.set_message(test_string); + std::string val; + if (with_binary_metadata) { + request.mutable_param()->set_echo_metadata(true); + char bytes[8] = {'\0', '\1', '\2', '\3', + '\4', '\5', '\6', static_cast(i)}; + val = std::string(bytes, 8); + cli_ctx.AddMetadata("custom-bin", val); + } + + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->async()->Echo( + &cli_ctx, &request, &response, + [&cli_ctx, &request, &response, &done, &mu, &cv, val, + with_binary_metadata](Status s) { + GPR_ASSERT(s.ok()); + + EXPECT_EQ(request.message(), response.message()); + if (with_binary_metadata) { + EXPECT_EQ( + 1u, cli_ctx.GetServerTrailingMetadata().count("custom-bin")); + EXPECT_EQ(val, ToString(cli_ctx.GetServerTrailingMetadata() + .find("custom-bin") + ->second)); + } + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } + } + } + + void SendRpcsGeneric(int num_rpcs, bool maybe_except, + const char* suffix_for_stats) { + const std::string kMethodName("/grpc.testing.EchoTestService/Echo"); + std::string test_string(""); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest request; + std::unique_ptr send_buf; + ByteBuffer recv_buf; + ClientContext cli_ctx; + + test_string += "Hello world. "; + request.set_message(test_string); + send_buf = SerializeToByteBuffer(&request); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + StubOptions options(suffix_for_stats); + generic_stub_->UnaryCall( + &cli_ctx, kMethodName, options, send_buf.get(), &recv_buf, + [&request, &recv_buf, &done, &mu, &cv, maybe_except](Status s) { + GPR_ASSERT(s.ok()); + + EchoResponse response; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buf, &response)); + EXPECT_EQ(request.message(), response.message()); + std::lock_guard l(mu); + done = true; + cv.notify_one(); +#if GRPC_ALLOW_EXCEPTIONS + if (maybe_except) { + throw -1; + } +#else + GPR_ASSERT(!maybe_except); +#endif + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } + } + } + + void SendGenericEchoAsBidi(int num_rpcs, int reuses, bool do_writes_done, + const char* suffix_for_stats) { + const std::string kMethodName("/grpc.testing.EchoTestService/Echo"); + std::string test_string(""); + for (int i = 0; i < num_rpcs; i++) { + test_string += "Hello world. 
"; + class Client : public grpc::ClientBidiReactor { + public: + Client(ClientCallbackEnd2endTest* test, const std::string& method_name, + const char* suffix_for_stats, const std::string& test_str, + int reuses, bool do_writes_done) + : reuses_remaining_(reuses), do_writes_done_(do_writes_done) { + activate_ = [this, test, method_name, suffix_for_stats, test_str] { + if (reuses_remaining_ > 0) { + cli_ctx_ = absl::make_unique(); + reuses_remaining_--; + StubOptions options(suffix_for_stats); + test->generic_stub_->PrepareBidiStreamingCall( + cli_ctx_.get(), method_name, options, this); + request_.set_message(test_str); + send_buf_ = SerializeToByteBuffer(&request_); + StartWrite(send_buf_.get()); + StartRead(&recv_buf_); + StartCall(); + } else { + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + }; + activate_(); + } + void OnWriteDone(bool /*ok*/) override { + if (do_writes_done_) { + StartWritesDone(); + } + } + void OnReadDone(bool /*ok*/) override { + EchoResponse response; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response)); + EXPECT_EQ(request_.message(), response.message()); + }; + void OnDone(const Status& s) override { + EXPECT_TRUE(s.ok()); + activate_(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + EchoRequest request_; + std::unique_ptr send_buf_; + ByteBuffer recv_buf_; + std::unique_ptr cli_ctx_; + int reuses_remaining_; + std::function activate_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + const bool do_writes_done_; + }; + + Client rpc(this, kMethodName, suffix_for_stats, test_string, reuses, + do_writes_done); + + rpc.Await(); + } + } + bool is_server_started_{false}; + int picked_port_{0}; + std::shared_ptr channel_; + std::unique_ptr stub_; + std::unique_ptr generic_stub_; + TestServiceImpl service_; + CallbackTestServiceImpl callback_service_; + std::unique_ptr server_; + std::ostringstream server_address_; +}; + +TEST_P(ClientCallbackEnd2endTest, SimpleRpc) { + ResetStub(); + SendRpcs(1, false); +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcExpectedError) { + ResetStub(); + + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + ErrorStatus error_status; + + request.set_message("Hello failure"); + error_status.set_code(1); // CANCELLED + error_status.set_error_message("cancel error message"); + *request.mutable_param()->mutable_expected_error() = error_status; + + std::mutex mu; + std::condition_variable cv; + bool done = false; + + stub_->async()->Echo(&cli_ctx, &request, &response, + [&response, &done, &mu, &cv, &error_status](Status s) { + EXPECT_EQ("", response.message()); + EXPECT_EQ(error_status.code(), s.error_code()); + EXPECT_EQ(error_status.error_message(), + s.error_message()); + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLockNested) { + ResetStub(); + + // The request/response state associated with an RPC and the synchronization + // variables needed to notify its completion. 
+ struct RpcState { + std::mutex mu; + std::condition_variable cv; + bool done = false; + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + + RpcState() = default; + ~RpcState() { + // Grab the lock to prevent destruction while another is still holding + // lock + std::lock_guard lock(mu); + } + }; + std::vector rpc_state(3); + for (size_t i = 0; i < rpc_state.size(); i++) { + std::string message = "Hello locked world"; + message += std::to_string(i); + rpc_state[i].request.set_message(message); + } + + // Grab a lock and then start an RPC whose callback grabs the same lock and + // then calls this function to start the next RPC under lock (up to a limit of + // the size of the rpc_state vector). + std::function nested_call = [this, &nested_call, + &rpc_state](int index) { + std::lock_guard l(rpc_state[index].mu); + stub_->async()->Echo(&rpc_state[index].cli_ctx, &rpc_state[index].request, + &rpc_state[index].response, + [index, &nested_call, &rpc_state](Status s) { + std::lock_guard l1(rpc_state[index].mu); + EXPECT_TRUE(s.ok()); + rpc_state[index].done = true; + rpc_state[index].cv.notify_all(); + // Call the next level of nesting if possible + if (index + 1 < int(rpc_state.size())) { + nested_call(index + 1); + } + }); + }; + + nested_call(0); + + // Wait for completion notifications from all RPCs. Order doesn't matter. + for (RpcState& state : rpc_state) { + std::unique_lock l(state.mu); + while (!state.done) { + state.cv.wait(l); + } + EXPECT_EQ(state.request.message(), state.response.message()); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLock) { + ResetStub(); + std::mutex mu; + std::condition_variable cv; + bool done = false; + EchoRequest request; + request.set_message("Hello locked world."); + EchoResponse response; + ClientContext cli_ctx; + { + std::lock_guard l(mu); + stub_->async()->Echo(&cli_ctx, &request, &response, + [&mu, &cv, &done, &request, &response](Status s) { + std::lock_guard l(mu); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(request.message(), response.message()); + done = true; + cv.notify_one(); + }); + } + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, SequentialRpcs) { + ResetStub(); + SendRpcs(10, false); +} + +TEST_P(ClientCallbackEnd2endTest, SendClientInitialMetadata) { + ResetStub(); + SimpleRequest request; + SimpleResponse response; + ClientContext cli_ctx; + + cli_ctx.AddMetadata(kCheckClientInitialMetadataKey, + kCheckClientInitialMetadataVal); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->async()->CheckClientInitialMetadata( + &cli_ctx, &request, &response, [&done, &mu, &cv](Status s) { + GPR_ASSERT(s.ok()); + + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcWithBinaryMetadata) { + ResetStub(); + SendRpcs(1, true); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialRpcsWithVariedBinaryMetadataValue) { + ResetStub(); + SendRpcs(10, true); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) { + ResetStub(absl::make_unique( + "/grpc.testing.EchoTestService/Echo", nullptr)); + SendRpcsGeneric(10, false, /*suffix_for_stats=*/nullptr); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsWithSuffix) { + ResetStub(absl::make_unique( + "/grpc.testing.EchoTestService/Echo", "TestSuffix")); + SendRpcsGeneric(10, false, "TestSuffix"); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidi) { + 
ResetStub(absl::make_unique( + "/grpc.testing.EchoTestService/Echo", nullptr)); + SendGenericEchoAsBidi(10, 1, /*do_writes_done=*/true, + /*suffix_for_stats=*/nullptr); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithSuffix) { + ResetStub(absl::make_unique( + "/grpc.testing.EchoTestService/Echo", "TestSuffix")); + SendGenericEchoAsBidi(10, 1, /*do_writes_done=*/true, "TestSuffix"); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithReactorReuse) { + ResetStub(); + SendGenericEchoAsBidi(10, 10, /*do_writes_done=*/true, + /*suffix_for_stats=*/nullptr); +} + +TEST_P(ClientCallbackEnd2endTest, GenericRpcNoWritesDone) { + ResetStub(); + SendGenericEchoAsBidi(1, 1, /*do_writes_done=*/false, + /*suffix_for_stats=*/nullptr); +} + +#if GRPC_ALLOW_EXCEPTIONS +TEST_P(ClientCallbackEnd2endTest, ExceptingRpc) { + ResetStub(); + SendRpcsGeneric(10, true, nullptr); +} +#endif + +TEST_P(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) { + ResetStub(); + std::vector threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back([this] { SendRpcs(10, true); }); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(ClientCallbackEnd2endTest, MultipleRpcs) { + ResetStub(); + std::vector threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back([this] { SendRpcs(10, false); }); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(ClientCallbackEnd2endTest, CancelRpcBeforeStart) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.TryCancel(); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->async()->Echo(&context, &request, &response, + [&response, &done, &mu, &cv](Status s) { + EXPECT_EQ("", response.message()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, RequestEchoServerCancel) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.AddMetadata(kServerTryCancelRequest, + std::to_string(CANCEL_BEFORE_PROCESSING)); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->async()->Echo(&context, &request, &response, + [&done, &mu, &cv](Status s) { + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } +} + +struct ClientCancelInfo { + bool cancel{false}; + int ops_before_cancel; + + ClientCancelInfo() : cancel{false} {} + explicit ClientCancelInfo(int ops) : cancel{true}, ops_before_cancel{ops} {} +}; + +class WriteClient : public grpc::ClientWriteReactor { + public: + WriteClient(grpc::testing::EchoTestService::Stub* stub, + ServerTryCancelRequestPhase server_try_cancel, + int num_msgs_to_send, ClientCancelInfo client_cancel = {}) + : server_try_cancel_(server_try_cancel), + num_msgs_to_send_(num_msgs_to_send), + client_cancel_{client_cancel} { + std::string msg{"Hello server."}; + for (int i = 0; i < num_msgs_to_send; i++) { + desired_ += msg; + } + if (server_try_cancel != DO_NOT_CANCEL) { + // Send server_try_cancel 
value in the client metadata + context_.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + } + context_.set_initial_metadata_corked(true); + stub->async()->RequestStream(&context_, &response_, this); + StartCall(); + request_.set_message(msg); + MaybeWrite(); + } + void OnWriteDone(bool ok) override { + if (ok) { + num_msgs_sent_++; + MaybeWrite(); + } + } + void OnDone(const Status& s) override { + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent_); + int num_to_send = + (client_cancel_.cancel) + ? std::min(num_msgs_to_send_, client_cancel_.ops_before_cancel) + : num_msgs_to_send_; + switch (server_try_cancel_) { + case CANCEL_BEFORE_PROCESSING: + case CANCEL_DURING_PROCESSING: + // If the RPC is canceled by server before / during messages from the + // client, it means that the client most likely did not get a chance to + // send all the messages it wanted to send. i.e num_msgs_sent <= + // num_msgs_to_send + EXPECT_LE(num_msgs_sent_, num_to_send); + break; + case DO_NOT_CANCEL: + case CANCEL_AFTER_PROCESSING: + // If the RPC was not canceled or canceled after all messages were read + // by the server, the client did get a chance to send all its messages + EXPECT_EQ(num_msgs_sent_, num_to_send); + break; + default: + assert(false); + break; + } + if ((server_try_cancel_ == DO_NOT_CANCEL) && !client_cancel_.cancel) { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(response_.message(), desired_); + } else { + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + } + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + void MaybeWrite() { + if (client_cancel_.cancel && + num_msgs_sent_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } else if (num_msgs_to_send_ > num_msgs_sent_ + 1) { + StartWrite(&request_); + } else if (num_msgs_to_send_ == num_msgs_sent_ + 1) { + StartWriteLast(&request_, WriteOptions()); + } + } + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + const ServerTryCancelRequestPhase server_try_cancel_; + int num_msgs_sent_{0}; + const int num_msgs_to_send_; + std::string desired_; + const ClientCancelInfo client_cancel_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; +}; + +TEST_P(ClientCallbackEnd2endTest, RequestStream) { + ResetStub(); + WriteClient test{stub_.get(), DO_NOT_CANCEL, 3}; + test.Await(); + // Make sure that the server interceptors were not notified to cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, ClientCancelsRequestStream) { + ResetStub(); + WriteClient test{stub_.get(), DO_NOT_CANCEL, 3, ClientCancelInfo{2}}; + test.Await(); + // Make sure that the server interceptors got the cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel before doing reading the request +TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelBeforeReads) { + ResetStub(); + WriteClient test{stub_.get(), CANCEL_BEFORE_PROCESSING, 1}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel while reading a request from the stream in parallel +TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelDuringRead) { + ResetStub(); + WriteClient 
test{stub_.get(), CANCEL_DURING_PROCESSING, 10}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel after reading all the requests but before returning to the +// client +TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelAfterReads) { + ResetStub(); + WriteClient test{stub_.get(), CANCEL_AFTER_PROCESSING, 4}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, UnaryReactor) { + ResetStub(); + class UnaryClient : public grpc::ClientUnaryReactor { + public: + explicit UnaryClient(grpc::testing::EchoTestService::Stub* stub) { + cli_ctx_.AddMetadata("key1", "val1"); + cli_ctx_.AddMetadata("key2", "val2"); + request_.mutable_param()->set_echo_metadata_initially(true); + request_.set_message("Hello metadata"); + stub->async()->Echo(&cli_ctx_, &request_, &response_, this); + StartCall(); + } + void OnReadInitialMetadataDone(bool ok) override { + EXPECT_TRUE(ok); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1")); + EXPECT_EQ( + "val1", + ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second)); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2")); + EXPECT_EQ( + "val2", + ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second)); + initial_metadata_done_ = true; + } + void OnDone(const Status& s) override { + EXPECT_TRUE(initial_metadata_done_); + EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size()); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(request_.message(), response_.message()); + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + EchoResponse response_; + ClientContext cli_ctx_; + std::mutex mu_; + std::condition_variable cv_; + bool done_{false}; + bool initial_metadata_done_{false}; + }; + + UnaryClient test{stub_.get()}; + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, GenericUnaryReactor) { + const std::string kMethodName("/grpc.testing.EchoTestService/Echo"); + constexpr char kSuffixForStats[] = "TestSuffixForStats"; + ResetStub( + absl::make_unique(kMethodName, kSuffixForStats)); + class UnaryClient : public grpc::ClientUnaryReactor { + public: + UnaryClient(grpc::GenericStub* stub, const std::string& method_name, + const char* suffix_for_stats) { + cli_ctx_.AddMetadata("key1", "val1"); + cli_ctx_.AddMetadata("key2", "val2"); + request_.mutable_param()->set_echo_metadata_initially(true); + request_.set_message("Hello metadata"); + send_buf_ = SerializeToByteBuffer(&request_); + + StubOptions options(suffix_for_stats); + stub->PrepareUnaryCall(&cli_ctx_, method_name, options, send_buf_.get(), + &recv_buf_, this); + StartCall(); + } + void OnReadInitialMetadataDone(bool ok) override { + EXPECT_TRUE(ok); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1")); + EXPECT_EQ( + "val1", + ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second)); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2")); + EXPECT_EQ( + "val2", + ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second)); 
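+      // Record that initial metadata was seen; OnDone() checks this flag so a
+      // missing OnReadInitialMetadataDone callback is caught.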
+ initial_metadata_done_ = true; + } + void OnDone(const Status& s) override { + EXPECT_TRUE(initial_metadata_done_); + EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size()); + EXPECT_TRUE(s.ok()); + EchoResponse response; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response)); + EXPECT_EQ(request_.message(), response.message()); + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + std::unique_ptr send_buf_; + ByteBuffer recv_buf_; + ClientContext cli_ctx_; + std::mutex mu_; + std::condition_variable cv_; + bool done_{false}; + bool initial_metadata_done_{false}; + }; + + UnaryClient test{generic_stub_.get(), kMethodName, kSuffixForStats}; + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +class ReadClient : public grpc::ClientReadReactor { + public: + ReadClient(grpc::testing::EchoTestService::Stub* stub, + ServerTryCancelRequestPhase server_try_cancel, + ClientCancelInfo client_cancel = {}) + : server_try_cancel_(server_try_cancel), client_cancel_{client_cancel} { + if (server_try_cancel_ != DO_NOT_CANCEL) { + // Send server_try_cancel value in the client metadata + context_.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + } + request_.set_message("Hello client "); + stub->async()->ResponseStream(&context_, &request_, this); + if (client_cancel_.cancel && + reads_complete_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } + // Even if we cancel, read until failure because there might be responses + // pending + StartRead(&response_); + StartCall(); + } + void OnReadDone(bool ok) override { + if (!ok) { + if (server_try_cancel_ == DO_NOT_CANCEL && !client_cancel_.cancel) { + EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend); + } + } else { + EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend); + EXPECT_EQ(response_.message(), + request_.message() + std::to_string(reads_complete_)); + reads_complete_++; + if (client_cancel_.cancel && + reads_complete_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } + // Even if we cancel, read until failure because there might be responses + // pending + StartRead(&response_); + } + } + void OnDone(const Status& s) override { + gpr_log(GPR_INFO, "Read %d messages", reads_complete_); + switch (server_try_cancel_) { + case DO_NOT_CANCEL: + if (!client_cancel_.cancel || client_cancel_.ops_before_cancel > + kServerDefaultResponseStreamsToSend) { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend); + } else { + EXPECT_GE(reads_complete_, client_cancel_.ops_before_cancel); + EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend); + // Status might be ok or cancelled depending on whether server + // sent status before client cancel went through + if (!s.ok()) { + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + } + } + break; + case CANCEL_BEFORE_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(reads_complete_, 0); + break; + case CANCEL_DURING_PROCESSING: + case CANCEL_AFTER_PROCESSING: + // If server canceled while writing messages, client must have read + // less than or equal to the expected number of messages. 
Even if the + // server canceled after writing all messages, the RPC may be canceled + // before the Client got a chance to read all the messages. + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend); + break; + default: + assert(false); + } + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + const ServerTryCancelRequestPhase server_try_cancel_; + int reads_complete_{0}; + const ClientCancelInfo client_cancel_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; +}; + +TEST_P(ClientCallbackEnd2endTest, ResponseStream) { + ResetStub(); + ReadClient test{stub_.get(), DO_NOT_CANCEL}; + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, ClientCancelsResponseStream) { + ResetStub(); + ReadClient test{stub_.get(), DO_NOT_CANCEL, ClientCancelInfo{2}}; + test.Await(); + // Because cancel in this case races with server finish, we can't be sure that + // server interceptors even see cancellation +} + +// Server to cancel before sending any response messages +TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelBefore) { + ResetStub(); + ReadClient test{stub_.get(), CANCEL_BEFORE_PROCESSING}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel while writing a response to the stream in parallel +TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelDuring) { + ResetStub(); + ReadClient test{stub_.get(), CANCEL_DURING_PROCESSING}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel after writing all the respones to the stream but before +// returning to the client +TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelAfter) { + ResetStub(); + ReadClient test{stub_.get(), CANCEL_AFTER_PROCESSING}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +class BidiClient : public grpc::ClientBidiReactor { + public: + BidiClient(grpc::testing::EchoTestService::Stub* stub, + ServerTryCancelRequestPhase server_try_cancel, + int num_msgs_to_send, bool cork_metadata, bool first_write_async, + ClientCancelInfo client_cancel = {}) + : server_try_cancel_(server_try_cancel), + msgs_to_send_{num_msgs_to_send}, + client_cancel_{client_cancel} { + if (server_try_cancel_ != DO_NOT_CANCEL) { + // Send server_try_cancel value in the client metadata + context_.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + } + request_.set_message("Hello fren "); + context_.set_initial_metadata_corked(cork_metadata); + stub->async()->BidiStream(&context_, this); + MaybeAsyncWrite(first_write_async); + StartRead(&response_); + StartCall(); + } + void OnReadDone(bool ok) override { + if (!ok) { + if (server_try_cancel_ == DO_NOT_CANCEL) { + if (!client_cancel_.cancel) { + EXPECT_EQ(reads_complete_, msgs_to_send_); + } 
else { + EXPECT_LE(reads_complete_, writes_complete_); + } + } + } else { + EXPECT_LE(reads_complete_, msgs_to_send_); + EXPECT_EQ(response_.message(), request_.message()); + reads_complete_++; + StartRead(&response_); + } + } + void OnWriteDone(bool ok) override { + if (async_write_thread_.joinable()) { + async_write_thread_.join(); + RemoveHold(); + } + if (server_try_cancel_ == DO_NOT_CANCEL) { + EXPECT_TRUE(ok); + } else if (!ok) { + return; + } + writes_complete_++; + MaybeWrite(); + } + void OnDone(const Status& s) override { + gpr_log(GPR_INFO, "Sent %d messages", writes_complete_); + gpr_log(GPR_INFO, "Read %d messages", reads_complete_); + switch (server_try_cancel_) { + case DO_NOT_CANCEL: + if (!client_cancel_.cancel || + client_cancel_.ops_before_cancel > msgs_to_send_) { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(writes_complete_, msgs_to_send_); + EXPECT_EQ(reads_complete_, writes_complete_); + } else { + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(writes_complete_, client_cancel_.ops_before_cancel); + EXPECT_LE(reads_complete_, writes_complete_); + } + break; + case CANCEL_BEFORE_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + // The RPC is canceled before the server did any work or returned any + // reads, but it's possible that some writes took place first from the + // client + EXPECT_LE(writes_complete_, msgs_to_send_); + EXPECT_EQ(reads_complete_, 0); + break; + case CANCEL_DURING_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_LE(writes_complete_, msgs_to_send_); + EXPECT_LE(reads_complete_, writes_complete_); + break; + case CANCEL_AFTER_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(writes_complete_, msgs_to_send_); + // The Server canceled after reading the last message and after writing + // the message to the client. However, the RPC cancellation might have + // taken effect before the client actually read the response. + EXPECT_LE(reads_complete_, writes_complete_); + break; + default: + assert(false); + } + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + void MaybeAsyncWrite(bool first_write_async) { + if (first_write_async) { + // Make sure that we have a write to issue. + // TODO(vjpai): Make this work with 0 writes case as well. 
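+      // AddHold() below defers delivery of OnDone() while the first write is
+      // issued from a separate thread; the matching RemoveHold() happens in
+      // OnWriteDone() once that thread has been joined.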
+ assert(msgs_to_send_ >= 1); + + AddHold(); + async_write_thread_ = std::thread([this] { + std::unique_lock lock(async_write_thread_mu_); + async_write_thread_cv_.wait( + lock, [this] { return async_write_thread_start_; }); + MaybeWrite(); + }); + std::lock_guard lock(async_write_thread_mu_); + async_write_thread_start_ = true; + async_write_thread_cv_.notify_one(); + return; + } + MaybeWrite(); + } + void MaybeWrite() { + if (client_cancel_.cancel && + writes_complete_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } else if (writes_complete_ == msgs_to_send_) { + StartWritesDone(); + } else { + StartWrite(&request_); + } + } + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + const ServerTryCancelRequestPhase server_try_cancel_; + int reads_complete_{0}; + int writes_complete_{0}; + const int msgs_to_send_; + const ClientCancelInfo client_cancel_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + std::thread async_write_thread_; + bool async_write_thread_start_ = false; + std::mutex async_write_thread_mu_; + std::condition_variable async_write_thread_cv_; +}; + +TEST_P(ClientCallbackEnd2endTest, BidiStream) { + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/false, /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, BidiStreamFirstWriteAsync) { + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/false, /*first_write_async=*/true); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, BidiStreamCorked) { + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/true, /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, BidiStreamCorkedFirstWriteAsync) { + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/true, /*first_write_async=*/true); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, ClientCancelsBidiStream) { + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/false, /*first_write_async=*/false, + ClientCancelInfo(2)); + test.Await(); + // Make sure that the server interceptors were notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel before reading/writing any requests/responses on the stream +TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelBefore) { + ResetStub(); + BidiClient test(stub_.get(), CANCEL_BEFORE_PROCESSING, /*num_msgs_to_send=*/2, + /*cork_metadata=*/false, /*first_write_async=*/false); + test.Await(); + // Make sure that the server 
interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel while reading/writing requests/responses on the stream in +// parallel +TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelDuring) { + ResetStub(); + BidiClient test(stub_.get(), CANCEL_DURING_PROCESSING, + /*num_msgs_to_send=*/10, /*cork_metadata=*/false, + /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel after reading/writing all requests/responses on the stream +// but before returning to the client +TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelAfter) { + ResetStub(); + BidiClient test(stub_.get(), CANCEL_AFTER_PROCESSING, /*num_msgs_to_send=*/5, + /*cork_metadata=*/false, /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimultaneousReadAndWritesDone) { + ResetStub(); + class Client : public grpc::ClientBidiReactor { + public: + explicit Client(grpc::testing::EchoTestService::Stub* stub) { + request_.set_message("Hello bidi "); + stub->async()->BidiStream(&context_, this); + StartWrite(&request_); + StartCall(); + } + void OnReadDone(bool ok) override { + EXPECT_TRUE(ok); + EXPECT_EQ(response_.message(), request_.message()); + } + void OnWriteDone(bool ok) override { + EXPECT_TRUE(ok); + // Now send out the simultaneous Read and WritesDone + StartWritesDone(); + StartRead(&response_); + } + void OnDone(const Status& s) override { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(response_.message(), request_.message()); + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + } test{stub_.get()}; + + test.Await(); +} + +TEST_P(ClientCallbackEnd2endTest, UnimplementedRpc) { + ChannelArguments args; + const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::shared_ptr channel = + (GetParam().protocol == Protocol::TCP) + ? ::grpc::CreateCustomChannel(server_address_.str(), channel_creds, + args) + : server_->InProcessChannel(args); + std::unique_ptr stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + request.set_message("Hello world."); + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub->async()->Unimplemented( + &cli_ctx, &request, &response, [&done, &mu, &cv](Status s) { + EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code()); + EXPECT_EQ("", s.error_message()); + + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, TestTrailersOnlyOnError) { + // Note that trailers-only is an HTTP/2 concept so we shouldn't do this test + // for any other transport such as inproc. 
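+  // In a trailers-only response the server returns only the status and
+  // trailing metadata, so the reactor below expects
+  // OnReadInitialMetadataDone(false) followed by an UNIMPLEMENTED status.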
+ if (GetParam().protocol != Protocol::TCP) { + return; + } + + ResetStub(); + class Reactor : public grpc::ClientBidiReactor { + public: + explicit Reactor(grpc::testing::EchoTestService::Stub* stub) { + stub->async()->UnimplementedBidi(&context_, this); + StartCall(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + done_cv_.wait(l); + } + } + + private: + void OnReadInitialMetadataDone(bool ok) override { EXPECT_FALSE(ok); } + void OnDone(const Status& s) override { + EXPECT_EQ(s.error_code(), grpc::StatusCode::UNIMPLEMENTED); + EXPECT_EQ(s.error_message(), ""); + std::unique_lock l(mu_); + done_ = true; + done_cv_.notify_one(); + } + + ClientContext context_; + std::mutex mu_; + std::condition_variable done_cv_; + bool done_ = false; + } client(stub_.get()); + + client.Await(); +} + +TEST_P(ClientCallbackEnd2endTest, + ResponseStreamExtraReactionFlowReadsUntilDone) { + ResetStub(); + class ReadAllIncomingDataClient + : public grpc::ClientReadReactor { + public: + explicit ReadAllIncomingDataClient( + grpc::testing::EchoTestService::Stub* stub) { + request_.set_message("Hello client "); + stub->async()->ResponseStream(&context_, &request_, this); + } + bool WaitForReadDone() { + std::unique_lock l(mu_); + while (!read_done_) { + read_cv_.wait(l); + } + read_done_ = false; + return read_ok_; + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + done_cv_.wait(l); + } + } + // RemoveHold under the same lock used for OnDone to make sure that we don't + // call OnDone directly or indirectly from the RemoveHold function. + void RemoveHoldUnderLock() { + std::unique_lock l(mu_); + RemoveHold(); + } + const Status& status() { + std::unique_lock l(mu_); + return status_; + } + + private: + void OnReadDone(bool ok) override { + std::unique_lock l(mu_); + read_ok_ = ok; + read_done_ = true; + read_cv_.notify_one(); + } + void OnDone(const Status& s) override { + std::unique_lock l(mu_); + done_ = true; + status_ = s; + done_cv_.notify_one(); + } + + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + bool read_ok_ = false; + bool read_done_ = false; + std::mutex mu_; + std::condition_variable read_cv_; + std::condition_variable done_cv_; + bool done_ = false; + Status status_; + } client{stub_.get()}; + + int reads_complete = 0; + client.AddHold(); + client.StartCall(); + + EchoResponse response; + bool read_ok = true; + while (read_ok) { + client.StartRead(&response); + read_ok = client.WaitForReadDone(); + if (read_ok) { + ++reads_complete; + } + } + client.RemoveHoldUnderLock(); + client.Await(); + + EXPECT_EQ(kServerDefaultResponseStreamsToSend, reads_complete); + EXPECT_EQ(client.status().error_code(), grpc::StatusCode::OK); +} + +std::vector CreateTestScenarios(bool test_insecure) { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + + std::vector scenarios; + std::vector credentials_types{ + GetCredentialsProvider()->GetSecureCredentialsTypeList()}; + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. 
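+    // The scenario loop below only adds insecure credentials, and only
+    // generates INPROC scenarios, when this check passes.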
+ return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + if (test_insecure && insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + GPR_ASSERT(!credentials_types.empty()); + + bool barr[]{false, true}; + Protocol parr[]{Protocol::INPROC, Protocol::TCP}; + for (Protocol p : parr) { + for (const auto& cred : credentials_types) { + // TODO(vjpai): Test inproc with secure credentials when feasible + if (p == Protocol::INPROC && + (cred != kInsecureCredentialsType || !insec_ok())) { + continue; + } + for (bool callback_server : barr) { + for (bool use_interceptors : barr) { + scenarios.emplace_back(callback_server, p, use_interceptors, cred); + } + } + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(ClientCallbackEnd2endTest, ClientCallbackEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(true))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/cpp/end2end/client_crash_test.cc b/test/cpp/end2end/client_crash_test.cc new file mode 100644 index 00000000..e34186be --- /dev/null +++ b/test/cpp/end2end/client_crash_test.cc @@ -0,0 +1,148 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +static std::string g_root; + +namespace grpc { +namespace testing { + +namespace { + +class CrashTest : public ::testing::Test { + protected: + CrashTest() {} + + std::unique_ptr CreateServerAndStub() { + auto port = grpc_pick_unused_port_or_die(); + std::ostringstream addr_stream; + addr_stream << "localhost:" << port; + auto addr = addr_stream.str(); + server_ = absl::make_unique(std::vector({ + g_root + "/client_crash_test_server", + "--address=" + addr, + })); + GPR_ASSERT(server_); + return grpc::testing::EchoTestService::NewStub( + grpc::CreateChannel(addr, InsecureChannelCredentials())); + } + + void KillServer() { server_.reset(); } + + private: + std::unique_ptr server_; +}; + +TEST_F(CrashTest, KillBeforeWrite) { + auto stub = CreateServerAndStub(); + + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + KillServer(); + + request.set_message("You should be dead"); + // This may succeed or fail depending on the state of the TCP connection + stream->Write(request); + // But the read will definitely fail + EXPECT_FALSE(stream->Read(&response)); + + EXPECT_FALSE(stream->Finish().ok()); +} + +TEST_F(CrashTest, KillAfterWrite) { + auto stub = CreateServerAndStub(); + + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message("I'm going to kill you"); + EXPECT_TRUE(stream->Write(request)); + + KillServer(); + + // This may succeed or fail depending on how quick the server was + stream->Read(&response); + + EXPECT_FALSE(stream->Finish().ok()); +} + +} // namespace + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + std::string me = argv[0]; + auto lslash = me.rfind('/'); + if (lslash != std::string::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + // Order seems to matter on these tests: run three times to eliminate that + for (int i = 0; i < 3; i++) { + if (RUN_ALL_TESTS() != 0) { + return 1; + } + } + return 0; +} diff --git a/test/cpp/end2end/client_crash_test_server.cc b/test/cpp/end2end/client_crash_test_server.cc new file mode 100644 index 00000000..5b8a533f --- /dev/null +++ b/test/cpp/end2end/client_crash_test_server.cc @@ -0,0 +1,76 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(std::string, address, "", "Address to bind to"); + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { + +class ServiceImpl final : public ::grpc::testing::EchoTestService::Service { + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter* stream) override { + EchoRequest request; + EchoResponse response; + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + response.set_message(request.message()); + stream->Write(response); + } + return Status::OK; + } +}; + +void RunServer() { + ServiceImpl service; + + ServerBuilder builder; + builder.AddListeningPort(absl::GetFlag(FLAGS_address), + grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << absl::GetFlag(FLAGS_address) + << std::endl; + server->Wait(); +} +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + grpc::testing::RunServer(); + + return 0; +} diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc new file mode 100644 index 00000000..9a792ca7 --- /dev/null +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -0,0 +1,1244 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +#ifdef GRPC_POSIX_SOCKET +#include + +#include "src/core/lib/iomgr/socket_utils_posix.h" +#endif /* GRPC_POSIX_SOCKET */ + +namespace grpc { +namespace testing { +namespace { + +enum class RPCType { + kSyncUnary, + kSyncClientStreaming, + kSyncServerStreaming, + kSyncBidiStreaming, + kAsyncCQUnary, + kAsyncCQClientStreaming, + kAsyncCQServerStreaming, + kAsyncCQBidiStreaming, +}; + +enum class ChannelType { + kHttpChannel, + kFdChannel, +}; + +/* Hijacks Echo RPC and fills in the expected values */ +class HijackingInterceptor : public experimental::Interceptor { + public: + explicit HijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + EXPECT_EQ(info->suffix_for_stats(), nullptr); + EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY); + } + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + 
experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class HijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptor(info); + } +}; + +class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { + public: + explicit HijackingInterceptorMakesAnotherCall( + experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0); + } + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + // Make a copy of the map + metadata_map_ = *map; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + req_ = req; + stub_ = grpc::testing::EchoTestService::NewStub( + methods->GetInterceptedChannel()); + ctx_.AddMetadata(metadata_map_.begin()->first, + metadata_map_.begin()->second); + stub_->async()->Echo(&ctx_, &req_, &resp_, [this, methods](Status s) { + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp_.message(), "Hello"); + methods->Hijack(); + }); + // This is a Unary RPC and we have got nothing interesting to do in the + // PRE_SEND_CLOSE interception hook point for this interceptor, so let's + // return here. (We do not want to call methods->Proceed(). When the new + // RPC returns, we will call methods->Hijack() instead.) 
+ return; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message(resp_.message()); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_; + std::multimap metadata_map_; + ClientContext ctx_; + EchoRequest req_; + EchoResponse resp_; + std::unique_ptr stub_; +}; + +class HijackingInterceptorMakesAnotherCallFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptorMakesAnotherCall(info); + } +}; + +class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { + public: + explicit BidiStreamingRpcHijackingInterceptor( + experimental::ClientRpcInfo* info) { + info_ = info; + EXPECT_EQ(info->suffix_for_stats(), nullptr); + } + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue"); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message().find("Hello"), 0u); + msg = req.message(); + } + if (methods->QueryInterceptionHookPoint( + 
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey", + "testvalue"); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message(msg); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EXPECT_EQ(static_cast(methods->GetRecvMessage()) + ->message() + .find("Hello"), + 0u); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + std::string msg; +}; + +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + explicit ClientStreamingRpcHijackingInterceptor( + experimental::ClientRpcInfo* info) { + info_ = info; + EXPECT_EQ( + strcmp("/grpc.testing.EchoTestService/RequestStream", info->method()), + 0); + EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0); + } + void Intercept(experimental::InterceptorBatchMethods* methods) override { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedSend() { return got_failed_send_; } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; + static bool got_failed_send_; +}; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + +class ServerStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + explicit ServerStreamingRpcHijackingInterceptor( + experimental::ClientRpcInfo* info) { + info_ = info; + got_failed_message_ = false; + EXPECT_EQ(info->suffix_for_stats(), nullptr); + } + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + 
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedRecvMessage(); + } + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + // Only the last message will be a failure + EXPECT_FALSE(got_failed_message_); + got_failed_message_ = methods->GetRecvMessage() == nullptr; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedMessage() { return got_failed_message_; } + + private: + experimental::ClientRpcInfo* info_; + static bool got_failed_message_; + int count_ = 0; +}; + +bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false; + +class ServerStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ServerStreamingRpcHijackingInterceptor(info); + } +}; + +class BidiStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new BidiStreamingRpcHijackingInterceptor(info); + } +}; + +// The logging interceptor is for testing purposes only. It is used to verify +// that all the appropriate hook points are invoked for an RPC. The counts are +// reset each time a new object of LoggingInterceptor is created, so only a +// single RPC should be made on the channel before calling the Verify methods. 
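+// The counters and the Verify* helpers below are static, so they describe the
+// single most recent RPC made on the channel rather than any particular
+// interceptor instance.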
+class LoggingInterceptor : public experimental::Interceptor { + public: + explicit LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) { + pre_send_initial_metadata_ = false; + pre_send_message_count_ = 0; + pre_send_close_ = false; + post_recv_initial_metadata_ = false; + post_recv_message_count_ = 0; + post_recv_status_ = false; + } + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + ASSERT_FALSE(pre_send_initial_metadata_); + pre_send_initial_metadata_ = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* send_msg = methods->GetSendMessage(); + if (send_msg == nullptr) { + // We did not get the non-serialized form of the message. Get the + // serialized form. + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EchoRequest req; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } else { + EXPECT_EQ( + static_cast(send_msg)->message().find("Hello"), + 0u); + } + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_TRUE(req.message().find("Hello") == 0u); + pre_send_message_count_++; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + pre_send_close_ = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + post_recv_initial_metadata_ = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + if (resp != nullptr) { + EXPECT_TRUE(resp->message().find("Hello") == 0u); + post_recv_message_count_++; + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + post_recv_status_ = true; + } + methods->Proceed(); + } + + static void VerifyCall(RPCType type) { + switch (type) { + case RPCType::kSyncUnary: + case RPCType::kAsyncCQUnary: + VerifyUnaryCall(); + break; + case RPCType::kSyncClientStreaming: + case RPCType::kAsyncCQClientStreaming: + VerifyClientStreamingCall(); + break; + case RPCType::kSyncServerStreaming: + case RPCType::kAsyncCQServerStreaming: + VerifyServerStreamingCall(); + break; + case RPCType::kSyncBidiStreaming: + case RPCType::kAsyncCQBidiStreaming: + VerifyBidiStreamingCall(); + break; + } + } + + static void VerifyCallCommon() { + 
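+    // Hook points that every RPC type is expected to have hit.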
EXPECT_TRUE(pre_send_initial_metadata_); + EXPECT_TRUE(pre_send_close_); + EXPECT_TRUE(post_recv_initial_metadata_); + EXPECT_TRUE(post_recv_status_); + } + + static void VerifyUnaryCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, 1); + EXPECT_EQ(post_recv_message_count_, 1); + } + + static void VerifyClientStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages); + EXPECT_EQ(post_recv_message_count_, 1); + } + + static void VerifyServerStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, 1); + EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages); + } + + static void VerifyBidiStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages); + EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages); + } + + private: + static bool pre_send_initial_metadata_; + static int pre_send_message_count_; + static bool pre_send_close_; + static bool post_recv_initial_metadata_; + static int post_recv_message_count_; + static bool post_recv_status_; +}; + +bool LoggingInterceptor::pre_send_initial_metadata_; +int LoggingInterceptor::pre_send_message_count_; +bool LoggingInterceptor::pre_send_close_; +bool LoggingInterceptor::post_recv_initial_metadata_; +int LoggingInterceptor::post_recv_message_count_; +bool LoggingInterceptor::post_recv_status_; + +class LoggingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +class TestScenario { + public: + explicit TestScenario(const ChannelType& channel_type, + const RPCType& rpc_type) + : channel_type_(channel_type), rpc_type_(rpc_type) {} + + ChannelType channel_type() const { return channel_type_; } + + RPCType rpc_type() const { return rpc_type_; } + + private: + const ChannelType channel_type_; + const RPCType rpc_type_; +}; + +std::vector CreateTestScenarios() { + std::vector scenarios; + std::vector rpc_types; + rpc_types.emplace_back(RPCType::kSyncUnary); + rpc_types.emplace_back(RPCType::kSyncClientStreaming); + rpc_types.emplace_back(RPCType::kSyncServerStreaming); + rpc_types.emplace_back(RPCType::kSyncBidiStreaming); + rpc_types.emplace_back(RPCType::kAsyncCQUnary); + rpc_types.emplace_back(RPCType::kAsyncCQServerStreaming); + for (const auto& rpc_type : rpc_types) { + scenarios.emplace_back(ChannelType::kHttpChannel, rpc_type); +// TODO(yashykt): Maybe add support for non-posix sockets too +#ifdef GRPC_POSIX_SOCKET + scenarios.emplace_back(ChannelType::kFdChannel, rpc_type); +#endif /* GRPC_POSIX_SOCKET */ + } + return scenarios; +} + +class ParameterizedClientInterceptorsEnd2endTest + : public ::testing::TestWithParam { + protected: + ParameterizedClientInterceptorsEnd2endTest() { + ServerBuilder builder; + builder.RegisterService(&service_); + if (GetParam().channel_type() == ChannelType::kHttpChannel) { + int port = grpc_pick_unused_port_or_die(); + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + server_ = builder.BuildAndStart(); + } +#ifdef GRPC_POSIX_SOCKET + else if (GetParam().channel_type() == ChannelType::kFdChannel) { + int flags; + GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv_) == 0); + flags = fcntl(sv_[0], F_GETFL, 0); + GPR_ASSERT(fcntl(sv_[0], F_SETFL, flags | O_NONBLOCK) == 0); + flags = fcntl(sv_[1], F_GETFL, 0); + 
GPR_ASSERT(fcntl(sv_[1], F_SETFL, flags | O_NONBLOCK) == 0); + GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[0]) == + GRPC_ERROR_NONE); + GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[1]) == + GRPC_ERROR_NONE); + server_ = builder.BuildAndStart(); + AddInsecureChannelFromFd(server_.get(), sv_[1]); + } +#endif /* GRPC_POSIX_SOCKET */ + } + + ~ParameterizedClientInterceptorsEnd2endTest() override { + server_->Shutdown(); + } + + std::shared_ptr CreateClientChannel( + std::vector> + creators) { + if (GetParam().channel_type() == ChannelType::kHttpChannel) { + return experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), ChannelArguments(), + std::move(creators)); + } +#ifdef GRPC_POSIX_SOCKET + else if (GetParam().channel_type() == ChannelType::kFdChannel) { + return experimental::CreateCustomInsecureChannelWithInterceptorsFromFd( + "", sv_[0], ChannelArguments(), std::move(creators)); + } +#endif /* GRPC_POSIX_SOCKET */ + return nullptr; + } + + void SendRPC(const std::shared_ptr& channel) { + switch (GetParam().rpc_type()) { + case RPCType::kSyncUnary: + MakeCall(channel); + break; + case RPCType::kSyncClientStreaming: + MakeClientStreamingCall(channel); + break; + case RPCType::kSyncServerStreaming: + MakeServerStreamingCall(channel); + break; + case RPCType::kSyncBidiStreaming: + MakeBidiStreamingCall(channel); + break; + case RPCType::kAsyncCQUnary: + MakeAsyncCQCall(channel); + break; + case RPCType::kAsyncCQClientStreaming: + // TODO(yashykt) : Fill this out + break; + case RPCType::kAsyncCQServerStreaming: + MakeAsyncCQServerStreamingCall(channel); + break; + case RPCType::kAsyncCQBidiStreaming: + // TODO(yashykt) : Fill this out + break; + } + } + + std::string server_address_; + int sv_[2]; + EchoTestServiceStreamingImpl service_; + std::unique_ptr server_; +}; + +TEST_P(ParameterizedClientInterceptorsEnd2endTest, + ClientInterceptorLoggingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back(absl::make_unique()); + // Add 20 phony interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = CreateClientChannel(std::move(creators)); + SendRPC(channel); + LoggingInterceptor::VerifyCall(GetParam().rpc_type()); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end, + ParameterizedClientInterceptorsEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios())); + +class ClientInterceptorsEnd2endTest + : public ::testing::TestWithParam { + protected: + ClientInterceptorsEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsEnd2endTest() override { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ClientInterceptorsEnd2endTest, + LameChannelClientInterceptorHijackingTest) { + ChannelArguments args; + std::vector> + creators; + creators.push_back(absl::make_unique()); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, nullptr, args, std::move(creators)); + MakeCall(channel); +} + +TEST_F(ClientInterceptorsEnd2endTest, 
ClientInterceptorHijackingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + // Add 20 phony interceptors before hijacking interceptor + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + creators.push_back(absl::make_unique()); + // Add 20 phony interceptors after hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure only 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { + ChannelArguments args; + std::vector> + creators; + creators.push_back(absl::make_unique()); + creators.push_back(absl::make_unique()); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + LoggingInterceptor::VerifyUnaryCall(); +} + +TEST_F(ClientInterceptorsEnd2endTest, + ClientInterceptorHijackingMakesAnotherCallTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + // Add 5 phony interceptors before hijacking interceptor + creators.reserve(5); + for (auto i = 0; i < 5; i++) { + creators.push_back(absl::make_unique()); + } + creators.push_back( + std::unique_ptr( + new HijackingInterceptorMakesAnotherCallFactory())); + // Add 7 phony interceptors after hijacking interceptor + for (auto i = 0; i < 7; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + + MakeCall(channel, StubOptions("TestSuffixForStats")); + // Make sure all interceptors were run once, since the hijacking interceptor + // makes an RPC on the intercepted channel + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 12); +} + +class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsCallbackEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsCallbackEnd2endTest() override { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ClientInterceptorsCallbackEnd2endTest, + ClientInterceptorLoggingTestWithCallback) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back(absl::make_unique()); + // Add 20 phony interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + MakeCallbackCall(channel); + LoggingInterceptor::VerifyUnaryCall(); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsCallbackEnd2endTest, + ClientInterceptorFactoryAllowsNullptrReturn) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back(absl::make_unique()); + // Add 20 phony interceptors and 20 null interceptors + 
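+  // A factory may return nullptr from CreateClientInterceptor(); the channel
+  // then simply adds no interceptor for that slot, which is why the 20 null
+  // factories below are expected to leave the phony-interceptor count at 20.
+  // A minimal sketch of such a conditional factory (the name
+  // MaybeLoggingFactory and the ShouldLog() helper are hypothetical, for
+  // illustration only):
+  //
+  //   class MaybeLoggingFactory
+  //       : public experimental::ClientInterceptorFactoryInterface {
+  //    public:
+  //     experimental::Interceptor* CreateClientInterceptor(
+  //         experimental::ClientRpcInfo* info) override {
+  //       return ShouldLog(info) ? new LoggingInterceptor(info) : nullptr;
+  //     }
+  //   };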
for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + creators.push_back(absl::make_unique()); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + MakeCallbackCall(channel); + LoggingInterceptor::VerifyUnaryCall(); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsStreamingEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsStreamingEnd2endTest() override { server_->Shutdown(); } + + std::string server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back(absl::make_unique()); + // Add 20 phony interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeClientStreamingCall(channel); + LoggingInterceptor::VerifyClientStreamingCall(); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back(absl::make_unique()); + // Add 20 phony interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + LoggingInterceptor::VerifyServerStreamingCall(); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + std::vector> + creators; + creators.push_back( + absl::make_unique()); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub( + channel, StubOptions("TestSuffixForStats")); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // The interceptor will reject the 11th message + writer->Write(req); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back( + absl::make_unique()); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, 
InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, + AsyncCQServerStreamingHijackingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back( + absl::make_unique()); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeAsyncCQServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back( + absl::make_unique()); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + creators.push_back(absl::make_unique()); + // Add 20 phony interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); + LoggingInterceptor::VerifyBidiStreamingCall(); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +class ClientGlobalInterceptorEnd2endTest : public ::testing::Test { + protected: + ClientGlobalInterceptorEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientGlobalInterceptorEnd2endTest() override { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ClientGlobalInterceptorEnd2endTest, PhonyGlobalInterceptor) { + // We should ideally be registering a global interceptor only once per + // process, but for the purposes of testing, it should be fine to modify the + // registered global interceptor when there are no ongoing gRPC operations + PhonyInterceptorFactory global_factory; + experimental::RegisterGlobalClientInterceptorFactory(&global_factory); + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + // Add 20 phony interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure all 20 phony interceptors were run with the global interceptor + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 21); + experimental::TestOnlyResetGlobalClientInterceptorFactory(); +} + +TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) { + // We should ideally be registering a global interceptor only once per + // process, but for the purposes of testing, it should be fine to modify the + // registered global interceptor when there are no ongoing gRPC 
operations + LoggingInterceptorFactory global_factory; + experimental::RegisterGlobalClientInterceptorFactory(&global_factory); + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + // Add 20 phony interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + LoggingInterceptor::VerifyUnaryCall(); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); + experimental::TestOnlyResetGlobalClientInterceptorFactory(); +} + +TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) { + // We should ideally be registering a global interceptor only once per + // process, but for the purposes of testing, it should be fine to modify the + // registered global interceptor when there are no ongoing gRPC operations + HijackingInterceptorFactory global_factory; + experimental::RegisterGlobalClientInterceptorFactory(&global_factory); + ChannelArguments args; + PhonyInterceptor::Reset(); + std::vector> + creators; + // Add 20 phony interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); + experimental::TestOnlyResetGlobalClientInterceptorFactory(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/client_lb_end2end_test.cc b/test/cpp/end2end/client_lb_end2end_test.cc new file mode 100644 index 00000000..3892400f --- /dev/null +++ b/test/cpp/end2end/client_lb_end2end_test.cc @@ -0,0 +1,2021 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/global_subchannel_pool.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/service_config/service_config.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/orca_load_report.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/resolve_localhost_ip46.h" +#include "test/core/util/test_config.h" +#include "test/core/util/test_lb_policies.h" +#include "test/cpp/end2end/test_service_impl.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +// defined in tcp_client.cc +extern grpc_tcp_client_vtable* grpc_tcp_client_impl; + +static grpc_tcp_client_vtable* default_client_impl; + +namespace grpc { +namespace testing { +namespace { + +gpr_atm g_connection_delay_ms; + +void tcp_client_connect_with_delay(grpc_closure* closure, grpc_endpoint** ep, + grpc_slice_allocator* slice_allocator, + grpc_pollset_set* interested_parties, + const grpc_channel_args* channel_args, + const grpc_resolved_address* addr, + grpc_millis deadline) { + const int delay_ms = gpr_atm_acq_load(&g_connection_delay_ms); + if (delay_ms > 0) { + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms)); + } + default_client_impl->connect(closure, ep, slice_allocator, interested_parties, + channel_args, addr, deadline + delay_ms); +} + +grpc_tcp_client_vtable delayed_connect = {tcp_client_connect_with_delay}; + +// Subclass of TestServiceImpl that increments a request counter for +// every call to the Echo RPC. +class MyTestServiceImpl : public TestServiceImpl { + public: + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + const xds::data::orca::v3::OrcaLoadReport* load_report = nullptr; + { + grpc::internal::MutexLock lock(&mu_); + ++request_count_; + load_report = load_report_; + } + AddClient(context->peer()); + if (load_report != nullptr) { + // TODO(roth): Once we provide a more standard server-side API for + // populating this data, use that API here. 
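+      // The xds.data.orca.v3.OrcaLoadReport is serialized into binary
+      // ("-bin") trailing metadata so client-side LB policies can observe
+      // per-call backend metrics. A test would typically populate it roughly
+      // like this (a sketch; the exact fields a given test sets may differ,
+      // and the report must outlive the RPCs since only a pointer is stored):
+      //
+      //   xds::data::orca::v3::OrcaLoadReport report;
+      //   report.set_cpu_utilization(0.5);
+      //   report.set_mem_utilization(0.75);
+      //   servers_[0]->service_.set_load_report(&report);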
+ context->AddTrailingMetadata("x-endpoint-load-metrics-bin", + load_report->SerializeAsString()); + } + return TestServiceImpl::Echo(context, request, response); + } + + int request_count() { + grpc::internal::MutexLock lock(&mu_); + return request_count_; + } + + void ResetCounters() { + grpc::internal::MutexLock lock(&mu_); + request_count_ = 0; + } + + std::set clients() { + grpc::internal::MutexLock lock(&clients_mu_); + return clients_; + } + + void set_load_report(xds::data::orca::v3::OrcaLoadReport* load_report) { + grpc::internal::MutexLock lock(&mu_); + load_report_ = load_report; + } + + private: + void AddClient(const std::string& client) { + grpc::internal::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + grpc::internal::Mutex mu_; + int request_count_ = 0; + const xds::data::orca::v3::OrcaLoadReport* load_report_ = nullptr; + grpc::internal::Mutex clients_mu_; + std::set clients_; +}; + +class FakeResolverResponseGeneratorWrapper { + public: + explicit FakeResolverResponseGeneratorWrapper(bool ipv6_only) + : ipv6_only_(ipv6_only), + response_generator_(grpc_core::MakeRefCounted< + grpc_core::FakeResolverResponseGenerator>()) {} + + FakeResolverResponseGeneratorWrapper( + FakeResolverResponseGeneratorWrapper&& other) noexcept { + ipv6_only_ = other.ipv6_only_; + response_generator_ = std::move(other.response_generator_); + } + + void SetNextResolution( + const std::vector& ports, const char* service_config_json = nullptr, + const char* attribute_key = nullptr, + std::unique_ptr attribute = + nullptr) { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetResponse( + BuildFakeResults(ipv6_only_, ports, service_config_json, attribute_key, + std::move(attribute))); + } + + void SetNextResolutionUponError(const std::vector& ports) { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetReresolutionResponse( + BuildFakeResults(ipv6_only_, ports)); + } + + void SetFailureOnReresolution() { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetFailureOnReresolution(); + } + + grpc_core::FakeResolverResponseGenerator* Get() const { + return response_generator_.get(); + } + + private: + static grpc_core::Resolver::Result BuildFakeResults( + bool ipv6_only, const std::vector& ports, + const char* service_config_json = nullptr, + const char* attribute_key = nullptr, + std::unique_ptr attribute = + nullptr) { + grpc_core::Resolver::Result result; + for (const int& port : ports) { + absl::StatusOr lb_uri = grpc_core::URI::Parse( + absl::StrCat(ipv6_only ? 
"ipv6:[::1]:" : "ipv4:127.0.0.1:", port)); + GPR_ASSERT(lb_uri.ok()); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*lb_uri, &address)); + std::map> + attributes; + if (attribute != nullptr) { + attributes[attribute_key] = attribute->Copy(); + } + result.addresses.emplace_back(address.addr, address.len, + nullptr /* args */, std::move(attributes)); + } + if (service_config_json != nullptr) { + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, service_config_json, &result.service_config_error); + GPR_ASSERT(result.service_config != nullptr); + } + return result; + } + + bool ipv6_only_ = false; + grpc_core::RefCountedPtr + response_generator_; +}; + +class ClientLbEnd2endTest : public ::testing::Test { + protected: + ClientLbEnd2endTest() + : server_host_("localhost"), + kRequestMessage_("Live long and prosper."), + creds_(new SecureChannelCredentials( + grpc_fake_transport_security_credentials_create())) {} + + static void SetUpTestCase() { + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + + void SetUp() override { + grpc_init(); + bool localhost_resolves_to_ipv4 = false; + bool localhost_resolves_to_ipv6 = false; + grpc_core::LocalhostResolves(&localhost_resolves_to_ipv4, + &localhost_resolves_to_ipv6); + ipv6_only_ = !localhost_resolves_to_ipv4 && localhost_resolves_to_ipv6; + } + + void TearDown() override { + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + servers_.clear(); + creds_.reset(); + grpc_shutdown(); + } + + void CreateServers(size_t num_servers, + std::vector ports = std::vector()) { + servers_.clear(); + for (size_t i = 0; i < num_servers; ++i) { + int port = 0; + if (ports.size() == num_servers) port = ports[i]; + servers_.emplace_back(new ServerData(port)); + } + } + + void StartServer(size_t index) { servers_[index]->Start(server_host_); } + + void StartServers(size_t num_servers, + std::vector ports = std::vector()) { + CreateServers(num_servers, std::move(ports)); + for (size_t i = 0; i < num_servers; ++i) { + StartServer(i); + } + } + + std::vector GetServersPorts(size_t start_index = 0) { + std::vector ports; + for (size_t i = start_index; i < servers_.size(); ++i) { + ports.push_back(servers_[i]->port_); + } + return ports; + } + + FakeResolverResponseGeneratorWrapper BuildResolverResponseGenerator() { + return FakeResolverResponseGeneratorWrapper(ipv6_only_); + } + + std::unique_ptr BuildStub( + const std::shared_ptr& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr BuildChannel( + const std::string& lb_policy_name, + const FakeResolverResponseGeneratorWrapper& response_generator, + ChannelArguments args = ChannelArguments()) { + if (!lb_policy_name.empty()) { + args.SetLoadBalancingPolicyName(lb_policy_name); + } // else, default to pick first + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator.Get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + bool SendRpc( + const std::unique_ptr& stub, + EchoResponse* response = nullptr, int timeout_ms = 1000, + Status* result = nullptr, bool wait_for_ready = false) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + EchoRequest request; + 
request.set_message(kRequestMessage_); + request.mutable_param()->set_echo_metadata(true); + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + if (wait_for_ready) context.set_wait_for_ready(true); + context.AddMetadata("foo", "1"); + context.AddMetadata("bar", "2"); + context.AddMetadata("baz", "3"); + Status status = stub->Echo(&context, request, response); + if (result != nullptr) *result = status; + if (local_response) delete response; + return status.ok(); + } + + void CheckRpcSendOk( + const std::unique_ptr& stub, + const grpc_core::DebugLocation& location, bool wait_for_ready = false) { + EchoResponse response; + Status status; + const bool success = + SendRpc(stub, &response, 2000, &status, wait_for_ready); + ASSERT_TRUE(success) << "From " << location.file() << ":" << location.line() + << "\n" + << "Error: " << status.error_message() << " " + << status.error_details(); + ASSERT_EQ(response.message(), kRequestMessage_) + << "From " << location.file() << ":" << location.line(); + if (!success) abort(); + } + + void CheckRpcSendFailure( + const std::unique_ptr& stub) { + const bool success = SendRpc(stub); + EXPECT_FALSE(success); + } + + struct ServerData { + const int port_; + std::unique_ptr server_; + MyTestServiceImpl service_; + std::unique_ptr thread_; + + grpc::internal::Mutex mu_; + grpc::internal::CondVar cond_; + bool server_ready_ ABSL_GUARDED_BY(mu_) = false; + bool started_ ABSL_GUARDED_BY(mu_) = false; + + explicit ServerData(int port = 0) + : port_(port > 0 ? port : grpc_pick_unused_port_or_die()) {} + + void Start(const std::string& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + grpc::internal::MutexLock lock(&mu_); + started_ = true; + thread_ = absl::make_unique( + std::bind(&ServerData::Serve, this, server_host)); + while (!server_ready_) { + cond_.Wait(&mu_); + } + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const std::string& server_host) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + std::shared_ptr creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), std::move(creds)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + grpc::internal::MutexLock lock(&mu_); + server_ready_ = true; + cond_.Signal(); + } + + void Shutdown() { + grpc::internal::MutexLock lock(&mu_); + if (!started_) return; + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + started_ = false; + } + + void SetServingStatus(const std::string& service, bool serving) { + server_->GetHealthCheckService()->SetServingStatus(service, serving); + } + }; + + void ResetCounters() { + for (const auto& server : servers_) server->service_.ResetCounters(); + } + + void WaitForServer( + const std::unique_ptr& stub, + size_t server_idx, const grpc_core::DebugLocation& location, + bool ignore_failure = false) { + do { + if (ignore_failure) { + SendRpc(stub); + } else { + CheckRpcSendOk(stub, location, true); + } + } while (servers_[server_idx]->service_.request_count() == 0); + ResetCounters(); + } + + bool WaitForChannelState( + Channel* channel, + const std::function& predicate, + bool try_to_connect = false, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + while (true) { + grpc_connectivity_state state = 
channel->GetState(try_to_connect); + if (predicate(state)) break; + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + auto predicate = [](grpc_connectivity_state state) { + return state != GRPC_CHANNEL_READY; + }; + return WaitForChannelState(channel, predicate, false, timeout_seconds); + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) { + auto predicate = [](grpc_connectivity_state state) { + return state == GRPC_CHANNEL_READY; + }; + return WaitForChannelState(channel, predicate, true, timeout_seconds); + } + + bool SeenAllServers() { + for (const auto& server : servers_) { + if (server->service_.request_count() == 0) return false; + } + return true; + } + + // Updates \a connection_order by appending to it the index of the newly + // connected server. Must be called after every single RPC. + void UpdateConnectionOrder( + const std::vector>& servers, + std::vector* connection_order) { + for (size_t i = 0; i < servers.size(); ++i) { + if (servers[i]->service_.request_count() == 1) { + // Was the server index known? If not, update connection_order. + const auto it = + std::find(connection_order->begin(), connection_order->end(), i); + if (it == connection_order->end()) { + connection_order->push_back(i); + return; + } + } + } + } + + const std::string server_host_; + std::vector> servers_; + const std::string kRequestMessage_; + std::shared_ptr creds_; + bool ipv6_only_ = false; +}; + +TEST_F(ClientLbEnd2endTest, ChannelStateConnectingWhenResolving) { + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("", response_generator); + auto stub = BuildStub(channel); + // Initial state should be IDLE. + EXPECT_EQ(channel->GetState(false /* try_to_connect */), GRPC_CHANNEL_IDLE); + // Tell the channel to try to connect. + // Note that this call also returns IDLE, since the state change has + // not yet occurred; it just gets triggered by this call. + EXPECT_EQ(channel->GetState(true /* try_to_connect */), GRPC_CHANNEL_IDLE); + // Now that the channel is trying to connect, we should be in state + // CONNECTING. + EXPECT_EQ(channel->GetState(false /* try_to_connect */), + GRPC_CHANNEL_CONNECTING); + // Return a resolver result, which allows the connection attempt to proceed. + response_generator.SetNextResolution(GetServersPorts()); + // We should eventually transition into state READY. + EXPECT_TRUE(WaitForChannelReady(channel.get())); +} + +TEST_F(ClientLbEnd2endTest, PickFirst) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel( + "", response_generator); // test that pick first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < servers_.size(); ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // All requests should have gone to a single server. + bool found = false; + for (size_t i = 0; i < servers_.size(); ++i) { + const int request_count = servers_[i]->service_.request_count(); + if (request_count == kNumServers) { + found = true; + } else { + EXPECT_EQ(0, request_count); + } + } + EXPECT_TRUE(found); + // Check LB policy name for the channel. 
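+  // Since BuildChannel("") set no policy explicitly, the channel falls back
+  // to the client-channel default, pick_first, which is what the check below
+  // verifies. An explicit selection looks roughly like the BuildChannel()
+  // helper above, e.g.:
+  //
+  //   ChannelArguments args;
+  //   args.SetLoadBalancingPolicyName("round_robin");
+  //   auto channel = ::grpc::CreateCustomChannel("fake:///", creds_, args);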
+ EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstProcessPending) { + StartServers(1); // Single server + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel( + "", response_generator); // test that pick first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution({servers_[0]->port_}); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Create a new channel and its corresponding PF LB policy, which will pick + // the subchannels in READY state from the previous RPC against the same + // target (even if it happened over a different channel, because subchannels + // are globally reused). Progress should happen without any transition from + // this READY state. + auto second_response_generator = BuildResolverResponseGenerator(); + auto second_channel = BuildChannel("", second_response_generator); + auto second_stub = BuildStub(second_channel); + second_response_generator.SetNextResolution({servers_[0]->port_}); + CheckRpcSendOk(second_stub, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, PickFirstSelectsReadyAtStartup) { + ChannelArguments args; + constexpr int kInitialBackOffMs = 5000; + args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs); + // Create 2 servers, but start only the second one. + std::vector ports = {grpc_pick_unused_port_or_die(), + grpc_pick_unused_port_or_die()}; + CreateServers(2, ports); + StartServer(1); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("pick_first", response_generator1, args); + auto stub1 = BuildStub(channel1); + response_generator1.SetNextResolution(ports); + // Wait for second server to be ready. + WaitForServer(stub1, 1, DEBUG_LOCATION); + // Create a second channel with the same addresses. Its PF instance + // should immediately pick the second subchannel, since it's already + // in READY state. + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("pick_first", response_generator2, args); + response_generator2.SetNextResolution(ports); + // Check that the channel reports READY without waiting for the + // initial backoff. + EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1 /* timeout_seconds */)); +} + +TEST_F(ClientLbEnd2endTest, PickFirstBackOffInitialReconnect) { + ChannelArguments args; + constexpr int kInitialBackOffMs = 100; + args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs); + const std::vector ports = {grpc_pick_unused_port_or_die()}; + const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + // The channel won't become connected (there's no server). + ASSERT_FALSE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 2))); + // Bring up a server on the chosen port. + StartServers(1, ports); + // Now it will. + ASSERT_TRUE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 2))); + const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC); + const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0)); + gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms); + // We should have waited at least kInitialBackOffMs. We substract one to + // account for test and precision accuracy drift. 
+  EXPECT_GE(waited_ms, kInitialBackOffMs - 1);
+  // But not much more.
+  EXPECT_GT(
+      gpr_time_cmp(
+          grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 1.10), t1),
+      0);
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstBackOffMinReconnect) {
+  ChannelArguments args;
+  constexpr int kMinReconnectBackOffMs = 1000;
+  args.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, kMinReconnectBackOffMs);
+  const std::vector<int> ports = {grpc_pick_unused_port_or_die()};
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("pick_first", response_generator, args);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(ports);
+  // Make the connection delay 10% longer than the min reconnect backoff, to
+  // make sure we hit the code path that waits for the min reconnect backoff.
+  gpr_atm_rel_store(&g_connection_delay_ms, kMinReconnectBackOffMs * 1.10);
+  default_client_impl = grpc_tcp_client_impl;
+  grpc_set_tcp_client_impl(&delayed_connect);
+  const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC);
+  channel->WaitForConnected(
+      grpc_timeout_milliseconds_to_deadline(kMinReconnectBackOffMs * 2));
+  const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC);
+  const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0));
+  gpr_log(GPR_DEBUG, "Waited %" PRId64 " ms", waited_ms);
+  // We should have waited at least kMinReconnectBackOffMs. We subtract one to
+  // account for test timing and precision drift.
+  EXPECT_GE(waited_ms, kMinReconnectBackOffMs - 1);
+  gpr_atm_rel_store(&g_connection_delay_ms, 0);
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstResetConnectionBackoff) {
+  ChannelArguments args;
+  constexpr int kInitialBackOffMs = 1000;
+  args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs);
+  const std::vector<int> ports = {grpc_pick_unused_port_or_die()};
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("pick_first", response_generator, args);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(ports);
+  // The channel won't become connected (there's no server).
+  EXPECT_FALSE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10)));
+  // Bring up a server on the chosen port.
+  StartServers(1, ports);
+  const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC);
+  // Wait for connect, but not long enough. This proves that we're
+  // being throttled by initial backoff.
+  EXPECT_FALSE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10)));
+  // Reset connection backoff.
+  experimental::ChannelResetConnectionBackoff(channel.get());
+  // Wait for connect. Should happen as soon as the client connects to
+  // the newly started server, which should be before the initial
+  // backoff timeout elapses.
+  EXPECT_TRUE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(20)));
+  const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC);
+  const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0));
+  gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms);
+  // We should have waited less than kInitialBackOffMs.
+ EXPECT_LT(waited_ms, kInitialBackOffMs); +} + +TEST_F(ClientLbEnd2endTest, + PickFirstResetConnectionBackoffNextAttemptStartsImmediately) { + ChannelArguments args; + constexpr int kInitialBackOffMs = 1000; + args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs); + const std::vector ports = {grpc_pick_unused_port_or_die()}; + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + // Wait for connect, which should fail ~immediately, because the server + // is not up. + gpr_log(GPR_INFO, "=== INITIAL CONNECTION ATTEMPT"); + EXPECT_FALSE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10))); + // Reset connection backoff. + // Note that the time at which the third attempt will be started is + // actually computed at this point, so we record the start time here. + gpr_log(GPR_INFO, "=== RESETTING BACKOFF"); + const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC); + experimental::ChannelResetConnectionBackoff(channel.get()); + // Trigger a second connection attempt. This should also fail + // ~immediately, but the retry should be scheduled for + // kInitialBackOffMs instead of applying the multiplier. + gpr_log(GPR_INFO, "=== POLLING FOR SECOND CONNECTION ATTEMPT"); + EXPECT_FALSE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10))); + // Bring up a server on the chosen port. + gpr_log(GPR_INFO, "=== STARTING BACKEND"); + StartServers(1, ports); + // Wait for connect. Should happen within kInitialBackOffMs. + // Give an extra 100ms to account for the time spent in the second and + // third connection attempts themselves (since what we really want to + // measure is the time between the two). As long as this is less than + // the 1.6x increase we would see if the backoff state was not reset + // properly, the test is still proving that the backoff was reset. + constexpr int kWaitMs = kInitialBackOffMs + 100; + gpr_log(GPR_INFO, "=== POLLING FOR THIRD CONNECTION ATTEMPT"); + EXPECT_TRUE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kWaitMs))); + const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC); + const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0)); + gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms); + EXPECT_LT(waited_ms, kWaitMs); +} + +TEST_F(ClientLbEnd2endTest, PickFirstUpdates) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + + std::vector ports; + + // Perform one RPC against the first server. + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [0] *******"); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 1); + + // An empty update will result in the channel going into TRANSIENT_FAILURE. 
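+  // (With no addresses left to try, pick_first has nothing to connect to, so
+  // the channel leaves READY; the loop below simply waits for that
+  // transition before asserting on the state.)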
+ ports.clear(); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET none *******"); + grpc_connectivity_state channel_state; + do { + channel_state = channel->GetState(true /* try to connect */); + } while (channel_state == GRPC_CHANNEL_READY); + ASSERT_NE(channel_state, GRPC_CHANNEL_READY); + servers_[0]->service_.ResetCounters(); + + // Next update introduces servers_[1], making the channel recover. + ports.clear(); + ports.emplace_back(servers_[1]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [1] *******"); + WaitForServer(stub, 1, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 0); + + // And again for servers_[2] + ports.clear(); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [2] *******"); + WaitForServer(stub, 2, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 0); + EXPECT_EQ(servers_[1]->service_.request_count(), 0); + + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstUpdateSuperset) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + + std::vector ports; + + // Perform one RPC against the first server. + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [0] *******"); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 1); + servers_[0]->service_.ResetCounters(); + + // Send and superset update + ports.clear(); + ports.emplace_back(servers_[1]->port_); + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET superset *******"); + CheckRpcSendOk(stub, DEBUG_LOCATION); + // We stick to the previously connected server. + WaitForServer(stub, 0, DEBUG_LOCATION); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstGlobalSubchannelPool) { + // Start one server. + const int kNumServers = 1; + StartServers(kNumServers); + std::vector ports = GetServersPorts(); + // Create two channels that (by default) use the global subchannel pool. + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("pick_first", response_generator1); + auto stub1 = BuildStub(channel1); + response_generator1.SetNextResolution(ports); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("pick_first", response_generator2); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + WaitForServer(stub1, 0, DEBUG_LOCATION); + // Send one RPC on each channel. + CheckRpcSendOk(stub1, DEBUG_LOCATION); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // The server receives two requests. + EXPECT_EQ(2, servers_[0]->service_.request_count()); + // The two requests are from the same client port, because the two channels + // share subchannels via the global subchannel pool. 
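+  // clients() records distinct peer strings, so one shared subchannel (a
+  // single client connection) shows up as exactly one entry. The next test
+  // sets GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL to opt out of sharing and
+  // expects two entries instead.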
+ EXPECT_EQ(1UL, servers_[0]->service_.clients().size()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstLocalSubchannelPool) { + // Start one server. + const int kNumServers = 1; + StartServers(kNumServers); + std::vector ports = GetServersPorts(); + // Create two channels that use local subchannel pool. + ChannelArguments args; + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("pick_first", response_generator1, args); + auto stub1 = BuildStub(channel1); + response_generator1.SetNextResolution(ports); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("pick_first", response_generator2, args); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + WaitForServer(stub1, 0, DEBUG_LOCATION); + // Send one RPC on each channel. + CheckRpcSendOk(stub1, DEBUG_LOCATION); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // The server receives two requests. + EXPECT_EQ(2, servers_[0]->service_.request_count()); + // The two requests are from two client ports, because the two channels didn't + // share subchannels with each other. + EXPECT_EQ(2UL, servers_[0]->service_.clients().size()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstManyUpdates) { + const int kNumUpdates = 1000; + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + std::vector ports = GetServersPorts(); + for (size_t i = 0; i < kNumUpdates; ++i) { + std::shuffle(ports.begin(), ports.end(), + std::mt19937(std::random_device()())); + response_generator.SetNextResolution(ports); + // We should re-enter core at the end of the loop to give the resolution + // setting closure a chance to run. + if ((i + 1) % 10 == 0) CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstReresolutionNoSelected) { + // Prepare the ports for up servers and down servers. + const int kNumServers = 3; + const int kNumAliveServers = 1; + StartServers(kNumAliveServers); + std::vector alive_ports, dead_ports; + for (size_t i = 0; i < kNumServers; ++i) { + if (i < kNumAliveServers) { + alive_ports.emplace_back(servers_[i]->port_); + } else { + dead_ports.emplace_back(grpc_pick_unused_port_or_die()); + } + } + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + // The initial resolution only contains dead ports. There won't be any + // selected subchannel. Re-resolution will return the same result. + response_generator.SetNextResolution(dead_ports); + gpr_log(GPR_INFO, "****** INITIAL RESOLUTION SET *******"); + for (size_t i = 0; i < 10; ++i) CheckRpcSendFailure(stub); + // Set a re-resolution result that contains reachable ports, so that the + // pick_first LB policy can recover soon. + response_generator.SetNextResolutionUponError(alive_ports); + gpr_log(GPR_INFO, "****** RE-RESOLUTION SET *******"); + WaitForServer(stub, 0, DEBUG_LOCATION, true /* ignore_failure */); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 1); + // Check LB policy name for the channel. 
+ EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstReconnectWithoutNewResolverResult) { + std::vector ports = {grpc_pick_unused_port_or_die()}; + StartServers(1, ports); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** INITIAL CONNECTION *******"); + WaitForServer(stub, 0, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** STOPPING SERVER ******"); + servers_[0]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + gpr_log(GPR_INFO, "****** RESTARTING SERVER ******"); + StartServers(1, ports); + WaitForServer(stub, 0, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, + PickFirstReconnectWithoutNewResolverResultStartsFromTopOfList) { + std::vector ports = {grpc_pick_unused_port_or_die(), + grpc_pick_unused_port_or_die()}; + CreateServers(2, ports); + StartServer(1); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** INITIAL CONNECTION *******"); + WaitForServer(stub, 1, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** STOPPING SERVER ******"); + servers_[1]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + gpr_log(GPR_INFO, "****** STARTING BOTH SERVERS ******"); + StartServers(2, ports); + WaitForServer(stub, 0, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, PickFirstCheckStateBeforeStartWatch) { + std::vector ports = {grpc_pick_unused_port_or_die()}; + StartServers(1, ports); + auto response_generator = BuildResolverResponseGenerator(); + auto channel_1 = BuildChannel("pick_first", response_generator); + auto stub_1 = BuildStub(channel_1); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** RESOLUTION SET FOR CHANNEL 1 *******"); + WaitForServer(stub_1, 0, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** CHANNEL 1 CONNECTED *******"); + servers_[0]->Shutdown(); + // Channel 1 will receive a re-resolution containing the same server. It will + // create a new subchannel and hold a ref to it. + StartServers(1, ports); + gpr_log(GPR_INFO, "****** SERVER RESTARTED *******"); + auto response_generator_2 = BuildResolverResponseGenerator(); + auto channel_2 = BuildChannel("pick_first", response_generator_2); + auto stub_2 = BuildStub(channel_2); + response_generator_2.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** RESOLUTION SET FOR CHANNEL 2 *******"); + WaitForServer(stub_2, 0, DEBUG_LOCATION, true); + gpr_log(GPR_INFO, "****** CHANNEL 2 CONNECTED *******"); + servers_[0]->Shutdown(); + // Wait until the disconnection has triggered the connectivity notification. + // Otherwise, the subchannel may be picked for next call but will fail soon. + EXPECT_TRUE(WaitForChannelNotReady(channel_2.get())); + // Channel 2 will also receive a re-resolution containing the same server. + // Both channels will ref the same subchannel that failed. + StartServers(1, ports); + gpr_log(GPR_INFO, "****** SERVER RESTARTED AGAIN *******"); + gpr_log(GPR_INFO, "****** CHANNEL 2 STARTING A CALL *******"); + // The first call after the server restart will succeed. + CheckRpcSendOk(stub_2, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** CHANNEL 2 FINISHED A CALL *******"); + // Check LB policy name for the channel. 
+ EXPECT_EQ("pick_first", channel_1->GetLoadBalancingPolicyName()); + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel_2->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstIdleOnDisconnect) { + // Start server, send RPC, and make sure channel is READY. + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("", response_generator); // pick_first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // Stop server. Channel should go into state IDLE. + response_generator.SetFailureOnReresolution(); + servers_[0]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + servers_.clear(); +} + +TEST_F(ClientLbEnd2endTest, PickFirstPendingUpdateAndSelectedSubchannelFails) { + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("", response_generator); // pick_first is the default. + auto stub = BuildStub(channel); + // Create a number of servers, but only start 1 of them. + CreateServers(10); + StartServer(0); + // Initially resolve to first server and make sure it connects. + gpr_log(GPR_INFO, "Phase 1: Connect to first server."); + response_generator.SetNextResolution({servers_[0]->port_}); + CheckRpcSendOk(stub, DEBUG_LOCATION, true /* wait_for_ready */); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // Send a resolution update with the remaining servers, none of which are + // running yet, so the update will stay pending. Note that it's important + // to have multiple servers here, or else the test will be flaky; with only + // one server, the pending subchannel list has already gone into + // TRANSIENT_FAILURE due to hitting the end of the list by the time we + // check the state. + gpr_log(GPR_INFO, + "Phase 2: Resolver update pointing to remaining " + "(not started) servers."); + response_generator.SetNextResolution(GetServersPorts(1 /* start_index */)); + // RPCs will continue to be sent to the first server. + CheckRpcSendOk(stub, DEBUG_LOCATION); + // Now stop the first server, so that the current subchannel list + // fails. This should cause us to immediately swap over to the + // pending list, even though it's not yet connected. The state should + // be set to CONNECTING, since that's what the pending subchannel list + // was doing when we swapped over. + gpr_log(GPR_INFO, "Phase 3: Stopping first server."); + servers_[0]->Shutdown(); + WaitForChannelNotReady(channel.get()); + // TODO(roth): This should always return CONNECTING, but it's flaky + // between that and TRANSIENT_FAILURE. I suspect that this problem + // will go away once we move the backoff code out of the subchannel + // and into the LB policies. + EXPECT_THAT(channel->GetState(false), + ::testing::AnyOf(GRPC_CHANNEL_CONNECTING, + GRPC_CHANNEL_TRANSIENT_FAILURE)); + // Now start the second server. + gpr_log(GPR_INFO, "Phase 4: Starting second server."); + StartServer(1); + // The channel should go to READY state and RPCs should go to the + // second server. + WaitForChannelReady(channel.get()); + WaitForServer(stub, 1, DEBUG_LOCATION, true /* ignore_failure */); +} + +TEST_F(ClientLbEnd2endTest, PickFirstStaysIdleUponEmptyUpdate) { + // Start server, send RPC, and make sure channel is READY. 
+ const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("", response_generator); // pick_first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // Stop server. Channel should go into state IDLE. + servers_[0]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + // Now send resolver update that includes no addresses. Channel + // should stay in state IDLE. + response_generator.SetNextResolution({}); + EXPECT_FALSE(channel->WaitForStateChange( + GRPC_CHANNEL_IDLE, grpc_timeout_seconds_to_deadline(3))); + // Now bring the backend back up and send a non-empty resolver update, + // and then try to send an RPC. Channel should go back into state READY. + StartServer(0); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); +} + +TEST_F(ClientLbEnd2endTest, RoundRobin) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + // Wait until all backends are ready. + do { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } while (!SeenAllServers()); + ResetCounters(); + // "Sync" to the end of the list. Next sequence of picks will start at the + // first server (index 0). + WaitForServer(stub, servers_.size() - 1, DEBUG_LOCATION); + std::vector connection_order; + for (size_t i = 0; i < servers_.size(); ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + UpdateConnectionOrder(servers_, &connection_order); + } + // Backends should be iterated over in the order in which the addresses were + // given. + const auto expected = std::vector{0, 1, 2}; + EXPECT_EQ(expected, connection_order); + // Check LB policy name for the channel. + EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinProcessPending) { + StartServers(1); // Single server + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution({servers_[0]->port_}); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Create a new channel and its corresponding RR LB policy, which will pick + // the subchannels in READY state from the previous RPC against the same + // target (even if it happened over a different channel, because subchannels + // are globally reused). Progress should happen without any transition from + // this READY state. + auto second_response_generator = BuildResolverResponseGenerator(); + auto second_channel = BuildChannel("round_robin", second_response_generator); + auto second_stub = BuildStub(second_channel); + second_response_generator.SetNextResolution({servers_[0]->port_}); + CheckRpcSendOk(second_stub, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinUpdates) { + // Start servers and send one RPC per server. 
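+  // Unlike pick_first, round_robin tries to maintain a connection to every
+  // address in the current list and rotates RPCs across the READY
+  // subchannels, so each resolver update below changes which backends get
+  // traffic rather than which single backend stays pinned.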
+ const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector ports; + // Start with a single server. + gpr_log(GPR_INFO, "*** FIRST BACKEND ***"); + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Send RPCs. They should all go servers_[0] + for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(10, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + servers_[0]->service_.ResetCounters(); + // And now for the second server. + gpr_log(GPR_INFO, "*** SECOND BACKEND ***"); + ports.clear(); + ports.emplace_back(servers_[1]->port_); + response_generator.SetNextResolution(ports); + // Wait until update has been processed, as signaled by the second backend + // receiving a request. + EXPECT_EQ(0, servers_[1]->service_.request_count()); + WaitForServer(stub, 1, DEBUG_LOCATION); + for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(0, servers_[0]->service_.request_count()); + EXPECT_EQ(10, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + servers_[1]->service_.ResetCounters(); + // ... and for the last server. + gpr_log(GPR_INFO, "*** THIRD BACKEND ***"); + ports.clear(); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 2, DEBUG_LOCATION); + for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(0, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(10, servers_[2]->service_.request_count()); + servers_[2]->service_.ResetCounters(); + // Back to all servers. + gpr_log(GPR_INFO, "*** ALL BACKENDS ***"); + ports.clear(); + ports.emplace_back(servers_[0]->port_); + ports.emplace_back(servers_[1]->port_); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + WaitForServer(stub, 1, DEBUG_LOCATION); + WaitForServer(stub, 2, DEBUG_LOCATION); + // Send three RPCs, one per server. + for (size_t i = 0; i < 3; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(1, servers_[0]->service_.request_count()); + EXPECT_EQ(1, servers_[1]->service_.request_count()); + EXPECT_EQ(1, servers_[2]->service_.request_count()); + // An empty update will result in the channel going into TRANSIENT_FAILURE. + gpr_log(GPR_INFO, "*** NO BACKENDS ***"); + ports.clear(); + response_generator.SetNextResolution(ports); + grpc_connectivity_state channel_state; + do { + channel_state = channel->GetState(true /* try to connect */); + } while (channel_state == GRPC_CHANNEL_READY); + ASSERT_NE(channel_state, GRPC_CHANNEL_READY); + servers_[0]->service_.ResetCounters(); + // Next update introduces servers_[1], making the channel recover. + gpr_log(GPR_INFO, "*** BACK TO SECOND BACKEND ***"); + ports.clear(); + ports.emplace_back(servers_[1]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 1, DEBUG_LOCATION); + channel_state = channel->GetState(false /* try to connect */); + ASSERT_EQ(channel_state, GRPC_CHANNEL_READY); + // Check LB policy name for the channel. 
+ EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinUpdateInError) { + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector ports; + // Start with a single server. + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Send RPCs. They should all go to servers_[0] + for (size_t i = 0; i < 10; ++i) SendRpc(stub); + EXPECT_EQ(10, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + servers_[0]->service_.ResetCounters(); + // Shutdown one of the servers to be sent in the update. + servers_[1]->Shutdown(); + ports.emplace_back(servers_[1]->port_); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + WaitForServer(stub, 2, DEBUG_LOCATION); + // Send three RPCs, one per server. + for (size_t i = 0; i < kNumServers; ++i) SendRpc(stub); + // The server in shutdown shouldn't receive any. + EXPECT_EQ(0, servers_[1]->service_.request_count()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinManyUpdates) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector ports = GetServersPorts(); + for (size_t i = 0; i < 1000; ++i) { + std::shuffle(ports.begin(), ports.end(), + std::mt19937(std::random_device()())); + response_generator.SetNextResolution(ports); + if (i % 10 == 0) CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinConcurrentUpdates) { + // TODO(dgq): replicate the way internal testing exercises the concurrent + // update provisions of RR. +} + +TEST_F(ClientLbEnd2endTest, RoundRobinReresolve) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + std::vector first_ports; + std::vector second_ports; + first_ports.reserve(kNumServers); + for (int i = 0; i < kNumServers; ++i) { + first_ports.push_back(grpc_pick_unused_port_or_die()); + } + second_ports.reserve(kNumServers); + for (int i = 0; i < kNumServers; ++i) { + second_ports.push_back(grpc_pick_unused_port_or_die()); + } + StartServers(kNumServers, first_ports); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(first_ports); + // Send a number of RPCs, which succeed. + for (size_t i = 0; i < 100; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Kill all servers + gpr_log(GPR_INFO, "****** ABOUT TO KILL SERVERS *******"); + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + gpr_log(GPR_INFO, "****** SERVERS KILLED *******"); + gpr_log(GPR_INFO, "****** SENDING DOOMED REQUESTS *******"); + // Client requests should fail. Send enough to tickle all subchannels. 
+ for (size_t i = 0; i < servers_.size(); ++i) CheckRpcSendFailure(stub); + gpr_log(GPR_INFO, "****** DOOMED REQUESTS SENT *******"); + // Bring servers back up on a different set of ports. We need to do this to be + // sure that the eventual success is *not* due to subchannel reconnection + // attempts and that an actual re-resolution has happened as a result of the + // RR policy going into transient failure when all its subchannels become + // unavailable (in transient failure as well). + gpr_log(GPR_INFO, "****** RESTARTING SERVERS *******"); + StartServers(kNumServers, second_ports); + // Don't notify of the update. Wait for the LB policy's re-resolution to + // "pull" the new ports. + response_generator.SetNextResolutionUponError(second_ports); + gpr_log(GPR_INFO, "****** SERVERS RESTARTED *******"); + gpr_log(GPR_INFO, "****** SENDING REQUEST TO SUCCEED *******"); + // Client request should eventually (but still fairly soon) succeed. + const gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5); + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + while (gpr_time_cmp(deadline, now) > 0) { + if (SendRpc(stub)) break; + now = gpr_now(GPR_CLOCK_MONOTONIC); + } + ASSERT_GT(gpr_time_cmp(deadline, now), 0); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinTransientFailure) { + // Start servers and create channel. Channel should go to READY state. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + // Now kill the servers. The channel should transition to TRANSIENT_FAILURE. + // TODO(roth): This test should ideally check that even when the + // subchannels are in state CONNECTING for an extended period of time, + // we will still report TRANSIENT_FAILURE. Unfortunately, we don't + // currently have a good way to get a subchannel to report CONNECTING + // for a long period of time, since the servers in this test framework + // are on the loopback interface, which will immediately return a + // "Connection refused" error, so the subchannels will only be in + // CONNECTING state very briefly. When we have time, see if we can + // find a way to fix this. + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + auto predicate = [](grpc_connectivity_state state) { + return state == GRPC_CHANNEL_TRANSIENT_FAILURE; + }; + EXPECT_TRUE(WaitForChannelState(channel.get(), predicate)); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinTransientFailureAtStartup) { + // Create channel and return servers that don't exist. Channel should + // quickly transition into TRANSIENT_FAILURE. + // TODO(roth): This test should ideally check that even when the + // subchannels are in state CONNECTING for an extended period of time, + // we will still report TRANSIENT_FAILURE. Unfortunately, we don't + // currently have a good way to get a subchannel to report CONNECTING + // for a long period of time, since the servers in this test framework + // are on the loopback interface, which will immediately return a + // "Connection refused" error, so the subchannels will only be in + // CONNECTING state very briefly. When we have time, see if we can + // find a way to fix this. 
+ auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution({ + grpc_pick_unused_port_or_die(), + grpc_pick_unused_port_or_die(), + grpc_pick_unused_port_or_die(), + }); + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + auto predicate = [](grpc_connectivity_state state) { + return state == GRPC_CHANNEL_TRANSIENT_FAILURE; + }; + EXPECT_TRUE(WaitForChannelState(channel.get(), predicate, true)); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinSingleReconnect) { + const int kNumServers = 3; + StartServers(kNumServers); + const auto ports = GetServersPorts(); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + for (size_t i = 0; i < kNumServers; ++i) { + WaitForServer(stub, i, DEBUG_LOCATION); + } + for (size_t i = 0; i < servers_.size(); ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(1, servers_[i]->service_.request_count()) << "for backend #" << i; + } + // One request should have gone to each server. + for (size_t i = 0; i < servers_.size(); ++i) { + EXPECT_EQ(1, servers_[i]->service_.request_count()); + } + const auto pre_death = servers_[0]->service_.request_count(); + // Kill the first server. + servers_[0]->Shutdown(); + // Client request still succeed. May need retrying if RR had returned a pick + // before noticing the change in the server's connectivity. + while (!SendRpc(stub)) { + } // Retry until success. + // Send a bunch of RPCs that should succeed. + for (int i = 0; i < 10 * kNumServers; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + const auto post_death = servers_[0]->service_.request_count(); + // No requests have gone to the deceased server. + EXPECT_EQ(pre_death, post_death); + // Bring the first server back up. + StartServer(0); + // Requests should start arriving at the first server either right away (if + // the server managed to start before the RR policy retried the subchannel) or + // after the subchannel retry delay otherwise (RR's subchannel retried before + // the server was fully back up). + WaitForServer(stub, 0, DEBUG_LOCATION); +} + +// If health checking is required by client but health checking service +// is not running on the server, the channel should be treated as healthy. +TEST_F(ClientLbEnd2endTest, + RoundRobinServersHealthCheckingUnimplementedTreatedAsHealthy) { + StartServers(1); // Single server + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution({servers_[0]->port_}); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + CheckRpcSendOk(stub, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthChecking) { + EnableDefaultHealthCheckService(true); + // Start servers. 
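+  // (Hedged sketch, not part of the original test: in a real server the same
+  // knob is the standard health-checking service, e.g.
+  //   grpc::EnableDefaultHealthCheckService(true);
+  //   auto server = builder.BuildAndStart();
+  //   server->GetHealthCheckService()->SetServingStatus(
+  //       "health_check_service_name", /*serving=*/true);
+  // The per-server SetServingStatus() helper used below presumably wraps the
+  // same interface.)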
+ const int kNumServers = 3; + StartServers(kNumServers); + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + // Channel should not become READY, because health checks should be failing. + gpr_log(GPR_INFO, + "*** initial state: unknown health check service name for " + "all servers"); + EXPECT_FALSE(WaitForChannelReady(channel.get(), 1)); + // Now set one of the servers to be healthy. + // The channel should become healthy and all requests should go to + // the healthy server. + gpr_log(GPR_INFO, "*** server 0 healthy"); + servers_[0]->SetServingStatus("health_check_service_name", true); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + for (int i = 0; i < 10; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + EXPECT_EQ(10, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + // Now set a second server to be healthy. + gpr_log(GPR_INFO, "*** server 2 healthy"); + servers_[2]->SetServingStatus("health_check_service_name", true); + WaitForServer(stub, 2, DEBUG_LOCATION); + for (int i = 0; i < 10; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + EXPECT_EQ(5, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(5, servers_[2]->service_.request_count()); + // Now set the remaining server to be healthy. + gpr_log(GPR_INFO, "*** server 1 healthy"); + servers_[1]->SetServingStatus("health_check_service_name", true); + WaitForServer(stub, 1, DEBUG_LOCATION); + for (int i = 0; i < 9; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + EXPECT_EQ(3, servers_[0]->service_.request_count()); + EXPECT_EQ(3, servers_[1]->service_.request_count()); + EXPECT_EQ(3, servers_[2]->service_.request_count()); + // Now set one server to be unhealthy again. Then wait until the + // unhealthiness has hit the client. We know that the client will see + // this when we send kNumServers requests and one of the remaining servers + // sees two of the requests. + gpr_log(GPR_INFO, "*** server 0 unhealthy"); + servers_[0]->SetServingStatus("health_check_service_name", false); + do { + ResetCounters(); + for (int i = 0; i < kNumServers; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + } while (servers_[1]->service_.request_count() != 2 && + servers_[2]->service_.request_count() != 2); + // Now set the remaining two servers to be unhealthy. Make sure the + // channel leaves READY state and that RPCs fail. + gpr_log(GPR_INFO, "*** all servers unhealthy"); + servers_[1]->SetServingStatus("health_check_service_name", false); + servers_[2]->SetServingStatus("health_check_service_name", false); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + CheckRpcSendFailure(stub); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, + RoundRobinWithHealthCheckingHandlesSubchannelFailure) { + EnableDefaultHealthCheckService(true); + // Start servers. 
+ const int kNumServers = 3; + StartServers(kNumServers); + servers_[0]->SetServingStatus("health_check_service_name", true); + servers_[1]->SetServingStatus("health_check_service_name", true); + servers_[2]->SetServingStatus("health_check_service_name", true); + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Stop server 0 and send a new resolver result to ensure that RR + // checks each subchannel's state. + servers_[0]->Shutdown(); + response_generator.SetNextResolution(GetServersPorts()); + // Send a bunch more RPCs. + for (size_t i = 0; i < 100; i++) { + SendRpc(stub); + } +} + +TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthCheckingInhibitPerChannel) { + EnableDefaultHealthCheckService(true); + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Create a channel with health-checking enabled. + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("round_robin", response_generator1, args); + auto stub1 = BuildStub(channel1); + std::vector ports = GetServersPorts(); + response_generator1.SetNextResolution(ports); + // Create a channel with health checking enabled but inhibited. + args.SetInt(GRPC_ARG_INHIBIT_HEALTH_CHECKING, 1); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("round_robin", response_generator2, args); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + // First channel should not become READY, because health checks should be + // failing. + EXPECT_FALSE(WaitForChannelReady(channel1.get(), 1)); + CheckRpcSendFailure(stub1); + // Second channel should be READY. + EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1)); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // Enable health checks on the backend and wait for channel 1 to succeed. + servers_[0]->SetServingStatus("health_check_service_name", true); + CheckRpcSendOk(stub1, DEBUG_LOCATION, true /* wait_for_ready */); + // Check that we created only one subchannel to the backend. + EXPECT_EQ(1UL, servers_[0]->service_.clients().size()); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthCheckingServiceNamePerChannel) { + EnableDefaultHealthCheckService(true); + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Create a channel with health-checking enabled. + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("round_robin", response_generator1, args); + auto stub1 = BuildStub(channel1); + std::vector ports = GetServersPorts(); + response_generator1.SetNextResolution(ports); + // Create a channel with health-checking enabled with a different + // service name. 
+ ChannelArguments args2; + args2.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name2\"}}"); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("round_robin", response_generator2, args2); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + // Allow health checks from channel 2 to succeed. + servers_[0]->SetServingStatus("health_check_service_name2", true); + // First channel should not become READY, because health checks should be + // failing. + EXPECT_FALSE(WaitForChannelReady(channel1.get(), 1)); + CheckRpcSendFailure(stub1); + // Second channel should be READY. + EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1)); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // Enable health checks for channel 1 and wait for it to succeed. + servers_[0]->SetServingStatus("health_check_service_name", true); + CheckRpcSendOk(stub1, DEBUG_LOCATION, true /* wait_for_ready */); + // Check that we created only one subchannel to the backend. + EXPECT_EQ(1UL, servers_[0]->service_.clients().size()); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, + RoundRobinWithHealthCheckingServiceNameChangesAfterSubchannelsCreated) { + EnableDefaultHealthCheckService(true); + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Create a channel with health-checking enabled. + const char* kServiceConfigJson = + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"; + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector ports = GetServersPorts(); + response_generator.SetNextResolution(ports, kServiceConfigJson); + servers_[0]->SetServingStatus("health_check_service_name", true); + EXPECT_TRUE(WaitForChannelReady(channel.get(), 1 /* timeout_seconds */)); + // Send an update on the channel to change it to use a health checking + // service name that is not being reported as healthy. + const char* kServiceConfigJson2 = + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name2\"}}"; + response_generator.SetNextResolution(ports, kServiceConfigJson2); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, ChannelIdleness) { + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Set max idle time and build the channel. + ChannelArguments args; + args.SetInt(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS, 1000); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("", response_generator, args); + auto stub = BuildStub(channel); + // The initial channel state should be IDLE. + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + // After sending RPC, channel state should be READY. + gpr_log(GPR_INFO, "*** SENDING RPC, CHANNEL SHOULD CONNECT ***"); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // After a period time not using the channel, the channel state should switch + // to IDLE. + gpr_log(GPR_INFO, "*** WAITING FOR CHANNEL TO GO IDLE ***"); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1200)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + // Sending a new RPC should awake the IDLE channel. 
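+  // (Illustrative only, with the target as an assumed placeholder: an
+  // application opts into the same idle behaviour by setting the channel
+  // argument at channel-creation time, e.g.
+  //   grpc::ChannelArguments args;
+  //   args.SetInt(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS, 1000);
+  //   auto channel = grpc::CreateCustomChannel(
+  //       "localhost:50051", grpc::InsecureChannelCredentials(), args);
+  // where "localhost:50051" stands in for the real server address.)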
+ gpr_log(GPR_INFO, "*** SENDING ANOTHER RPC, CHANNEL SHOULD RECONNECT ***"); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); +} + +class ClientLbPickArgsTest : public ClientLbEnd2endTest { + protected: + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + current_test_instance_ = this; + } + + static void SetUpTestCase() { + grpc_init(); + grpc_core::RegisterTestPickArgsLoadBalancingPolicy(SavePickArgs); + } + + static void TearDownTestCase() { grpc_shutdown(); } + + std::vector args_seen_list() { + grpc::internal::MutexLock lock(&mu_); + return args_seen_list_; + } + + static std::string ArgsSeenListString( + const std::vector& args_seen_list) { + std::vector entries; + for (const auto& args_seen : args_seen_list) { + std::vector metadata; + for (const auto& p : args_seen.metadata) { + metadata.push_back(absl::StrCat(p.first, "=", p.second)); + } + entries.push_back(absl::StrFormat("{path=\"%s\", metadata=[%s]}", + args_seen.path, + absl::StrJoin(metadata, ", "))); + } + return absl::StrCat("[", absl::StrJoin(entries, ", "), "]"); + } + + private: + static void SavePickArgs(const grpc_core::PickArgsSeen& args_seen) { + ClientLbPickArgsTest* self = current_test_instance_; + grpc::internal::MutexLock lock(&self->mu_); + self->args_seen_list_.emplace_back(args_seen); + } + + static ClientLbPickArgsTest* current_test_instance_; + grpc::internal::Mutex mu_; + std::vector args_seen_list_; +}; + +ClientLbPickArgsTest* ClientLbPickArgsTest::current_test_instance_ = nullptr; + +TEST_F(ClientLbPickArgsTest, Basic) { + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("test_pick_args_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + // Proactively connect the channel, so that the LB policy will always + // be connected before it sees the pick. Otherwise, the test would be + // flaky because sometimes the pick would be seen twice (once in + // CONNECTING and again in READY) and other times only once (in READY). + ASSERT_TRUE(channel->WaitForConnected(gpr_inf_future(GPR_CLOCK_MONOTONIC))); + // Check LB policy name for the channel. + EXPECT_EQ("test_pick_args_lb", channel->GetLoadBalancingPolicyName()); + // Now send an RPC and check that the picker sees the expected data. 
+ CheckRpcSendOk(stub, DEBUG_LOCATION, /*wait_for_ready=*/true); + auto pick_args_seen_list = args_seen_list(); + EXPECT_THAT(pick_args_seen_list, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Field(&grpc_core::PickArgsSeen::path, + "/grpc.testing.EchoTestService/Echo"), + ::testing::Field(&grpc_core::PickArgsSeen::metadata, + ::testing::UnorderedElementsAre( + ::testing::Pair("foo", "1"), + ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3")))))) + << ArgsSeenListString(pick_args_seen_list); +} + +class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { + protected: + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + current_test_instance_ = this; + } + + static void SetUpTestCase() { + grpc_init(); + grpc_core::RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( + ReportTrailerIntercepted); + } + + static void TearDownTestCase() { grpc_shutdown(); } + + int trailers_intercepted() { + grpc::internal::MutexLock lock(&mu_); + return trailers_intercepted_; + } + + const grpc_core::MetadataVector& trailing_metadata() { + grpc::internal::MutexLock lock(&mu_); + return trailing_metadata_; + } + + const xds::data::orca::v3::OrcaLoadReport* backend_load_report() { + grpc::internal::MutexLock lock(&mu_); + return load_report_.get(); + } + + private: + static void ReportTrailerIntercepted( + const grpc_core::TrailingMetadataArgsSeen& args_seen) { + const auto* backend_metric_data = args_seen.backend_metric_data; + ClientLbInterceptTrailingMetadataTest* self = current_test_instance_; + grpc::internal::MutexLock lock(&self->mu_); + self->trailers_intercepted_++; + self->trailing_metadata_ = args_seen.metadata; + if (backend_metric_data != nullptr) { + self->load_report_ = + absl::make_unique(); + self->load_report_->set_cpu_utilization( + backend_metric_data->cpu_utilization); + self->load_report_->set_mem_utilization( + backend_metric_data->mem_utilization); + self->load_report_->set_rps(backend_metric_data->requests_per_second); + for (const auto& p : backend_metric_data->request_cost) { + std::string name = std::string(p.first); + (*self->load_report_->mutable_request_cost())[name] = p.second; + } + for (const auto& p : backend_metric_data->utilization) { + std::string name = std::string(p.first); + (*self->load_report_->mutable_utilization())[name] = p.second; + } + } + } + + static ClientLbInterceptTrailingMetadataTest* current_test_instance_; + grpc::internal::Mutex mu_; + int trailers_intercepted_ = 0; + grpc_core::MetadataVector trailing_metadata_; + std::unique_ptr load_report_; +}; + +ClientLbInterceptTrailingMetadataTest* + ClientLbInterceptTrailingMetadataTest::current_test_instance_ = nullptr; + +TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesDisabled) { + const int kNumServers = 1; + const int kNumRpcs = 10; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("intercept_trailing_metadata_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < kNumRpcs; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("intercept_trailing_metadata_lb", + channel->GetLoadBalancingPolicyName()); + EXPECT_EQ(kNumRpcs, trailers_intercepted()); + EXPECT_THAT(trailing_metadata(), + ::testing::UnorderedElementsAre( + // TODO(roth): Should grpc-status be visible here? 
+ ::testing::Pair("grpc-status", "0"), + ::testing::Pair("user-agent", ::testing::_), + ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3"))); + EXPECT_EQ(nullptr, backend_load_report()); +} + +TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesEnabled) { + const int kNumServers = 1; + const int kNumRpcs = 10; + StartServers(kNumServers); + ChannelArguments args; + args.SetServiceConfigJSON( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"grpc.testing.EchoTestService\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("intercept_trailing_metadata_lb", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < kNumRpcs; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("intercept_trailing_metadata_lb", + channel->GetLoadBalancingPolicyName()); + EXPECT_EQ(kNumRpcs, trailers_intercepted()); + EXPECT_THAT(trailing_metadata(), + ::testing::UnorderedElementsAre( + // TODO(roth): Should grpc-status be visible here? + ::testing::Pair("grpc-status", "0"), + ::testing::Pair("user-agent", ::testing::_), + ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3"))); + EXPECT_EQ(nullptr, backend_load_report()); +} + +TEST_F(ClientLbInterceptTrailingMetadataTest, BackendMetricData) { + const int kNumServers = 1; + const int kNumRpcs = 10; + StartServers(kNumServers); + xds::data::orca::v3::OrcaLoadReport load_report; + load_report.set_cpu_utilization(0.5); + load_report.set_mem_utilization(0.75); + load_report.set_rps(25); + auto* request_cost = load_report.mutable_request_cost(); + (*request_cost)["foo"] = 0.8; + (*request_cost)["bar"] = 1.4; + auto* utilization = load_report.mutable_utilization(); + (*utilization)["baz"] = 1.1; + (*utilization)["quux"] = 0.9; + for (const auto& server : servers_) { + server->service_.set_load_report(&load_report); + } + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("intercept_trailing_metadata_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < kNumRpcs; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + auto* actual = backend_load_report(); + ASSERT_NE(actual, nullptr); + // TODO(roth): Change this to use EqualsProto() once that becomes + // available in OSS. 
+    EXPECT_EQ(actual->cpu_utilization(), load_report.cpu_utilization());
+    EXPECT_EQ(actual->mem_utilization(), load_report.mem_utilization());
+    EXPECT_EQ(actual->rps(), load_report.rps());
+    EXPECT_EQ(actual->request_cost().size(), load_report.request_cost().size());
+    for (const auto& p : actual->request_cost()) {
+      auto it = load_report.request_cost().find(p.first);
+      ASSERT_NE(it, load_report.request_cost().end());
+      EXPECT_EQ(it->second, p.second);
+    }
+    EXPECT_EQ(actual->utilization().size(), load_report.utilization().size());
+    for (const auto& p : actual->utilization()) {
+      auto it = load_report.utilization().find(p.first);
+      ASSERT_NE(it, load_report.utilization().end());
+      EXPECT_EQ(it->second, p.second);
+    }
+  }
+  // Check LB policy name for the channel.
+  EXPECT_EQ("intercept_trailing_metadata_lb",
+            channel->GetLoadBalancingPolicyName());
+  EXPECT_EQ(kNumRpcs, trailers_intercepted());
+}
+
+class ClientLbAddressTest : public ClientLbEnd2endTest {
+ protected:
+  static const char* kAttributeKey;
+
+  class Attribute : public grpc_core::ServerAddress::AttributeInterface {
+   public:
+    explicit Attribute(const std::string& str) : str_(str) {}
+
+    std::unique_ptr<AttributeInterface> Copy() const override {
+      return absl::make_unique<Attribute>(str_);
+    }
+
+    int Cmp(const AttributeInterface* other) const override {
+      return str_.compare(static_cast<const Attribute*>(other)->str_);
+    }
+
+    std::string ToString() const override { return str_; }
+
+   private:
+    std::string str_;
+  };
+
+  void SetUp() override {
+    ClientLbEnd2endTest::SetUp();
+    current_test_instance_ = this;
+  }
+
+  static void SetUpTestCase() {
+    grpc_init();
+    grpc_core::RegisterAddressTestLoadBalancingPolicy(SaveAddress);
+  }
+
+  static void TearDownTestCase() { grpc_shutdown(); }
+
+  const std::vector<std::string>& addresses_seen() {
+    grpc::internal::MutexLock lock(&mu_);
+    return addresses_seen_;
+  }
+
+ private:
+  static void SaveAddress(const grpc_core::ServerAddress& address) {
+    ClientLbAddressTest* self = current_test_instance_;
+    grpc::internal::MutexLock lock(&self->mu_);
+    self->addresses_seen_.emplace_back(address.ToString());
+  }
+
+  static ClientLbAddressTest* current_test_instance_;
+  grpc::internal::Mutex mu_;
+  std::vector<std::string> addresses_seen_;
+};
+
+const char* ClientLbAddressTest::kAttributeKey = "attribute_key";
+
+ClientLbAddressTest* ClientLbAddressTest::current_test_instance_ = nullptr;
+
+TEST_F(ClientLbAddressTest, Basic) {
+  const int kNumServers = 1;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("address_test_lb", response_generator);
+  auto stub = BuildStub(channel);
+  // Addresses returned by the resolver will have attached attributes.
+  response_generator.SetNextResolution(GetServersPorts(), nullptr,
+                                       kAttributeKey,
+                                       absl::make_unique<Attribute>("foo"));
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  // Check LB policy name for the channel.
+  EXPECT_EQ("address_test_lb", channel->GetLoadBalancingPolicyName());
+  // Make sure that the attributes wind up on the subchannels.
+  std::vector<std::string> expected;
+  for (const int port : GetServersPorts()) {
+    expected.emplace_back(
+        absl::StrCat(ipv6_only_ ?
"[::1]:" : "127.0.0.1:", port, + " args={} attributes={", kAttributeKey, "=foo}")); + } + EXPECT_EQ(addresses_seen(), expected); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/end2end/context_allocator_end2end_test.cc b/test/cpp/end2end/context_allocator_end2end_test.cc new file mode 100644 index 00000000..67e8cddb --- /dev/null +++ b/test/cpp/end2end/context_allocator_end2end_test.cc @@ -0,0 +1,331 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/iomgr.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { +namespace { + +enum class Protocol { INPROC, TCP }; + +class TestScenario { + public: + TestScenario(Protocol protocol, const std::string& creds_type) + : protocol(protocol), credentials_type(creds_type) {} + void Log() const; + Protocol protocol; + const std::string credentials_type; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{protocol=" + << (scenario.protocol == Protocol::INPROC ? 
"INPROC" : "TCP") + << "," << scenario.credentials_type << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_INFO, "%s", out.str().c_str()); +} + +class ContextAllocatorEnd2endTestBase + : public ::testing::TestWithParam { + protected: + static void SetUpTestCase() { grpc_init(); } + static void TearDownTestCase() { grpc_shutdown(); } + ContextAllocatorEnd2endTestBase() {} + + ~ContextAllocatorEnd2endTestBase() override = default; + + void SetUp() override { GetParam().Log(); } + + void CreateServer(std::unique_ptr context_allocator) { + ServerBuilder builder; + + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + if (GetParam().protocol == Protocol::TCP) { + picked_port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << picked_port_; + builder.AddListeningPort(server_address_.str(), server_creds); + } + builder.SetContextAllocator(std::move(context_allocator)); + builder.RegisterService(&callback_service_); + + server_ = builder.BuildAndStart(); + } + + void DestroyServer() { + if (server_) { + server_->Shutdown(); + server_.reset(); + } + } + + void ResetStub() { + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + switch (GetParam().protocol) { + case Protocol::TCP: + channel_ = ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args); + break; + case Protocol::INPROC: + channel_ = server_->InProcessChannel(args); + break; + default: + assert(false); + } + stub_ = EchoTestService::NewStub(channel_); + } + + void TearDown() override { + DestroyServer(); + if (picked_port_ > 0) { + grpc_recycle_unused_port(picked_port_); + } + } + + void SendRpcs(int num_rpcs) { + std::string test_string(""); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + + test_string += std::string(1024, 'x'); + request.set_message(test_string); + std::string val; + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->async()->Echo( + &cli_ctx, &request, &response, + [&request, &response, &done, &mu, &cv, val](Status s) { + GPR_ASSERT(s.ok()); + + EXPECT_EQ(request.message(), response.message()); + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } + } + } + + int picked_port_{0}; + std::shared_ptr channel_; + std::unique_ptr stub_; + CallbackTestServiceImpl callback_service_; + std::unique_ptr server_; + std::ostringstream server_address_; +}; + +class DefaultContextAllocatorTest : public ContextAllocatorEnd2endTestBase {}; + +TEST_P(DefaultContextAllocatorTest, SimpleRpc) { + const int kRpcCount = 10; + CreateServer(nullptr); + ResetStub(); + SendRpcs(kRpcCount); +} + +class NullContextAllocatorTest : public ContextAllocatorEnd2endTestBase { + public: + class NullAllocator : public grpc::ContextAllocator { + public: + NullAllocator(std::atomic* allocation_count, + std::atomic* deallocation_count) + : allocation_count_(allocation_count), + deallocation_count_(deallocation_count) {} + grpc::CallbackServerContext* NewCallbackServerContext() override { + allocation_count_->fetch_add(1, std::memory_order_relaxed); + return nullptr; + } + + GenericCallbackServerContext* NewGenericCallbackServerContext() override { + allocation_count_->fetch_add(1, std::memory_order_relaxed); + return 
nullptr; + } + + void Release( + grpc::CallbackServerContext* /*callback_server_context*/) override { + deallocation_count_->fetch_add(1, std::memory_order_relaxed); + } + + void Release( + GenericCallbackServerContext* /*generic_callback_server_context*/) + override { + deallocation_count_->fetch_add(1, std::memory_order_relaxed); + } + + std::atomic* allocation_count_; + std::atomic* deallocation_count_; + }; +}; + +TEST_P(NullContextAllocatorTest, UnaryRpc) { + const int kRpcCount = 10; + std::atomic allocation_count{0}; + std::atomic deallocation_count{0}; + std::unique_ptr allocator( + new NullAllocator(&allocation_count, &deallocation_count)); + CreateServer(std::move(allocator)); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side + // OnDone. + DestroyServer(); + EXPECT_EQ(kRpcCount, allocation_count); + EXPECT_EQ(kRpcCount, deallocation_count); +} + +class SimpleContextAllocatorTest : public ContextAllocatorEnd2endTestBase { + public: + class SimpleAllocator : public grpc::ContextAllocator { + public: + SimpleAllocator(std::atomic* allocation_count, + std::atomic* deallocation_count) + : allocation_count_(allocation_count), + deallocation_count_(deallocation_count) {} + grpc::CallbackServerContext* NewCallbackServerContext() override { + allocation_count_->fetch_add(1, std::memory_order_relaxed); + return new grpc::CallbackServerContext(); + } + GenericCallbackServerContext* NewGenericCallbackServerContext() override { + allocation_count_->fetch_add(1, std::memory_order_relaxed); + return new GenericCallbackServerContext(); + } + + void Release( + grpc::CallbackServerContext* callback_server_context) override { + deallocation_count_->fetch_add(1, std::memory_order_relaxed); + delete callback_server_context; + } + + void Release(GenericCallbackServerContext* generic_callback_server_context) + override { + deallocation_count_->fetch_add(1, std::memory_order_relaxed); + delete generic_callback_server_context; + } + + std::atomic* allocation_count_; + std::atomic* deallocation_count_; + }; +}; + +TEST_P(SimpleContextAllocatorTest, UnaryRpc) { + const int kRpcCount = 10; + std::atomic allocation_count{0}; + std::atomic deallocation_count{0}; + std::unique_ptr allocator( + new SimpleAllocator(&allocation_count, &deallocation_count)); + CreateServer(std::move(allocator)); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side + // OnDone. + DestroyServer(); + EXPECT_EQ(kRpcCount, allocation_count); + EXPECT_EQ(kRpcCount, deallocation_count); +} + +std::vector CreateTestScenarios(bool test_insecure) { + std::vector scenarios; + std::vector credentials_types{ + GetCredentialsProvider()->GetSecureCredentialsTypeList()}; + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. 
+ return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + if (test_insecure && insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + GPR_ASSERT(!credentials_types.empty()); + + Protocol parr[]{Protocol::INPROC, Protocol::TCP}; + for (Protocol p : parr) { + for (const auto& cred : credentials_types) { + if (p == Protocol::INPROC && + (cred != kInsecureCredentialsType || !insec_ok())) { + continue; + } + scenarios.emplace_back(p, cred); + } + } + return scenarios; +} + +// TODO(ddyihai): adding client streaming/server streaming/bidi streaming +// test. + +INSTANTIATE_TEST_SUITE_P(DefaultContextAllocatorTest, + DefaultContextAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); +INSTANTIATE_TEST_SUITE_P(NullContextAllocatorTest, NullContextAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); +INSTANTIATE_TEST_SUITE_P(SimpleContextAllocatorTest, SimpleContextAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/end2end/delegating_channel_test.cc b/test/cpp/end2end/delegating_channel_test.cc new file mode 100644 index 00000000..0c065e63 --- /dev/null +++ b/test/cpp/end2end/delegating_channel_test.cc @@ -0,0 +1,101 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +namespace grpc { +namespace testing { +namespace { + +class TestChannel : public experimental::DelegatingChannel { + public: + explicit TestChannel( + const std::shared_ptr& delegate_channel) + : experimental::DelegatingChannel(delegate_channel) {} + // Always returns GRPC_CHANNEL_READY + grpc_connectivity_state GetState(bool /*try_to_connect*/) override { + return GRPC_CHANNEL_READY; + } +}; + +class DelegatingChannelTest : public ::testing::Test { + protected: + DelegatingChannelTest() { + int port = grpc_pick_unused_port_or_die(); + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~DelegatingChannelTest() override { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_F(DelegatingChannelTest, SimpleTest) { + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + std::shared_ptr test_channel = + std::make_shared(channel); + // gRPC channel should be in idle state at this point but our test channel + // will return ready. + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + EXPECT_EQ(test_channel->GetState(false), GRPC_CHANNEL_READY); + auto stub = grpc::testing::EchoTestService::NewStub(test_channel); + ClientContext ctx; + EchoRequest req; + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc new file mode 100644 index 00000000..0d375167 --- /dev/null +++ b/test/cpp/end2end/end2end_test.cc @@ -0,0 +1,2299 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/string_ref_helper.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_POSIX_SOCKET_EV +#include "src/core/lib/iomgr/ev_posix.h" +#endif // GRPC_POSIX_SOCKET_EV + +#include + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using grpc::testing::kTlsCredentialsType; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace { + +bool CheckIsLocalhost(const std::string& addr) { + const std::string kIpv6("ipv6:[::1]:"); + const std::string kIpv4MappedIpv6("ipv6:[::ffff:127.0.0.1]:"); + const std::string kIpv4("ipv4:127.0.0.1:"); + return addr.substr(0, kIpv4.size()) == kIpv4 || + addr.substr(0, kIpv4MappedIpv6.size()) == kIpv4MappedIpv6 || + addr.substr(0, kIpv6.size()) == kIpv6; +} + +const int kClientChannelBackupPollIntervalMs = 200; + +const char kTestCredsPluginErrorMsg[] = "Could not find plugin metadata."; + +const char kFakeToken[] = "fake_token"; +const char kFakeSelector[] = "fake_selector"; +const char kExpectedFakeCredsDebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:fake_selector}}"; + +const char kWrongToken[] = "wrong_token"; +const char kWrongSelector[] = "wrong_selector"; +const char kExpectedWrongCredsDebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:wrong_selector}}"; + +const char kFakeToken1[] = "fake_token1"; +const char kFakeSelector1[] = "fake_selector1"; +const char kExpectedFakeCreds1DebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:fake_selector1}}"; + +const char kFakeToken2[] = "fake_token2"; +const char kFakeSelector2[] = "fake_selector2"; +const char kExpectedFakeCreds2DebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:fake_selector2}}"; + +const char kExpectedAuthMetadataPluginKeyFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:TestPluginMetadata," + "value:Does not matter, will fail the key is invalid.}}"; +const char kExpectedAuthMetadataPluginValueFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata," + "value:With illegal \n value.}}"; +const char kExpectedAuthMetadataPluginWithDeadlineCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:meta_key,value:Does not " + "matter}}"; +const char kExpectedNonBlockingAuthMetadataPluginFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata," + "value:Does not matter, will fail anyway (see 3rd param)}}"; +const char + kExpectedNonBlockingAuthMetadataPluginAndProcessorSuccessCredsDebugString + [] = 
"SecureCallCredentials{TestMetadataCredentials{key:test-plugin-" + "metadata,value:Dr Jekyll}}"; +const char + kExpectedNonBlockingAuthMetadataPluginAndProcessorFailureCredsDebugString + [] = "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-" + "metadata,value:Mr Hyde}}"; +const char kExpectedBlockingAuthMetadataPluginFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata," + "value:Does not matter, will fail anyway (see 3rd param)}}"; +const char kExpectedCompositeCallCredsDebugString[] = + "SecureCallCredentials{CompositeCallCredentials{TestMetadataCredentials{" + "key:call-creds-key1,value:call-creds-val1},TestMetadataCredentials{key:" + "call-creds-key2,value:call-creds-val2}}}"; + +class TestMetadataCredentialsPlugin : public MetadataCredentialsPlugin { + public: + static const char kGoodMetadataKey[]; + static const char kBadMetadataKey[]; + + TestMetadataCredentialsPlugin(const grpc::string_ref& metadata_key, + const grpc::string_ref& metadata_value, + bool is_blocking, bool is_successful, + int delay_ms) + : metadata_key_(metadata_key.data(), metadata_key.length()), + metadata_value_(metadata_value.data(), metadata_value.length()), + is_blocking_(is_blocking), + is_successful_(is_successful), + delay_ms_(delay_ms) {} + + bool IsBlocking() const override { return is_blocking_; } + + Status GetMetadata( + grpc::string_ref service_url, grpc::string_ref method_name, + const grpc::AuthContext& channel_auth_context, + std::multimap* metadata) override { + if (delay_ms_ != 0) { + gpr_sleep_until( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(delay_ms_, GPR_TIMESPAN))); + } + EXPECT_GT(service_url.length(), 0UL); + EXPECT_GT(method_name.length(), 0UL); + EXPECT_TRUE(channel_auth_context.IsPeerAuthenticated()); + EXPECT_TRUE(metadata != nullptr); + if (is_successful_) { + metadata->insert(std::make_pair(metadata_key_, metadata_value_)); + return Status::OK; + } else { + return Status(StatusCode::NOT_FOUND, kTestCredsPluginErrorMsg); + } + } + + std::string DebugString() override { + return absl::StrFormat("TestMetadataCredentials{key:%s,value:%s}", + metadata_key_.c_str(), metadata_value_.c_str()); + } + + private: + std::string metadata_key_; + std::string metadata_value_; + bool is_blocking_; + bool is_successful_; + int delay_ms_; +}; + +const char TestMetadataCredentialsPlugin::kBadMetadataKey[] = + "TestPluginMetadata"; +const char TestMetadataCredentialsPlugin::kGoodMetadataKey[] = + "test-plugin-metadata"; + +class TestAuthMetadataProcessor : public AuthMetadataProcessor { + public: + static const char kGoodGuy[]; + + explicit TestAuthMetadataProcessor(bool is_blocking) + : is_blocking_(is_blocking) {} + + std::shared_ptr GetCompatibleClientCreds() { + return grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, kGoodGuy, + is_blocking_, true, 0))); + } + + std::shared_ptr GetIncompatibleClientCreds() { + return grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, "Mr Hyde", + is_blocking_, true, 0))); + } + + // Interface implementation + bool IsBlocking() const override { return is_blocking_; } + + Status Process(const InputMetadata& auth_metadata, AuthContext* context, + OutputMetadata* consumed_auth_metadata, + OutputMetadata* response_metadata) override { + EXPECT_TRUE(consumed_auth_metadata != nullptr); + 
EXPECT_TRUE(context != nullptr); + EXPECT_TRUE(response_metadata != nullptr); + auto auth_md = + auth_metadata.find(TestMetadataCredentialsPlugin::kGoodMetadataKey); + EXPECT_NE(auth_md, auth_metadata.end()); + string_ref auth_md_value = auth_md->second; + if (auth_md_value == kGoodGuy) { + context->AddProperty(kIdentityPropName, kGoodGuy); + context->SetPeerIdentityPropertyName(kIdentityPropName); + consumed_auth_metadata->insert(std::make_pair( + string(auth_md->first.data(), auth_md->first.length()), + string(auth_md->second.data(), auth_md->second.length()))); + return Status::OK; + } else { + return Status(StatusCode::UNAUTHENTICATED, + string("Invalid principal: ") + + string(auth_md_value.data(), auth_md_value.length())); + } + } + + private: + static const char kIdentityPropName[]; + bool is_blocking_; +}; + +const char TestAuthMetadataProcessor::kGoodGuy[] = "Dr Jekyll"; +const char TestAuthMetadataProcessor::kIdentityPropName[] = "novel identity"; + +class Proxy : public ::grpc::testing::EchoTestService::Service { + public: + explicit Proxy(const std::shared_ptr& channel) + : stub_(grpc::testing::EchoTestService::NewStub(channel)) {} + + Status Echo(ServerContext* server_context, const EchoRequest* request, + EchoResponse* response) override { + std::unique_ptr client_context = + ClientContext::FromServerContext(*server_context); + return stub_->Echo(client_context.get(), *request, response); + } + + private: + std::unique_ptr<::grpc::testing::EchoTestService::Stub> stub_; +}; + +class TestServiceImplDupPkg + : public ::grpc::testing::duplicate::EchoTestService::Service { + public: + Status Echo(ServerContext* /*context*/, const EchoRequest* /*request*/, + EchoResponse* response) override { + response->set_message("no package"); + return Status::OK; + } +}; + +class TestScenario { + public: + TestScenario(bool interceptors, bool proxy, bool inproc_stub, + const std::string& creds_type, bool use_callback_server) + : use_interceptors(interceptors), + use_proxy(proxy), + inproc(inproc_stub), + credentials_type(creds_type), + callback_server(use_callback_server) {} + void Log() const; + bool use_interceptors; + bool use_proxy; + bool inproc; + const std::string credentials_type; + bool callback_server; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{use_interceptors=" + << (scenario.use_interceptors ? "true" : "false") + << ", use_proxy=" << (scenario.use_proxy ? "true" : "false") + << ", inproc=" << (scenario.inproc ? "true" : "false") + << ", server_type=" + << (scenario.callback_server ? 
"callback" : "sync") + << ", credentials='" << scenario.credentials_type << "'}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_DEBUG, "%s", out.str().c_str()); +} + +class End2endTest : public ::testing::TestWithParam { + protected: + static void SetUpTestCase() { grpc_init(); } + static void TearDownTestCase() { grpc_shutdown(); } + End2endTest() + : is_server_started_(false), + kMaxMessageSize_(8192), + special_service_("special"), + first_picked_port_(0) { + GetParam().Log(); + } + + void TearDown() override { + if (is_server_started_) { + server_->Shutdown(); + if (proxy_server_) proxy_server_->Shutdown(); + } + if (first_picked_port_ > 0) { + grpc_recycle_unused_port(first_picked_port_); + } + } + + void StartServer(const std::shared_ptr& processor) { + int port = grpc_pick_unused_port_or_die(); + first_picked_port_ = port; + server_address_ << "localhost:" << port; + // Setup server + BuildAndStartServer(processor); + } + + void RestartServer(const std::shared_ptr& processor) { + if (is_server_started_) { + server_->Shutdown(); + BuildAndStartServer(processor); + } + } + + void BuildAndStartServer( + const std::shared_ptr& processor) { + ServerBuilder builder; + ConfigureServerBuilder(&builder); + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + if (GetParam().credentials_type != kInsecureCredentialsType) { + server_creds->SetAuthMetadataProcessor(processor); + } + if (GetParam().use_interceptors) { + std::vector< + std::unique_ptr> + creators; + // Add 20 phony server interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + } + builder.AddListeningPort(server_address_.str(), server_creds); + if (!GetParam().callback_server) { + builder.RegisterService(&service_); + } else { + builder.RegisterService(&callback_service_); + } + builder.RegisterService("foo.test.youtube.com", &special_service_); + builder.RegisterService(&dup_pkg_service_); + + builder.SetSyncServerOption(ServerBuilder::SyncServerOption::NUM_CQS, 4); + builder.SetSyncServerOption( + ServerBuilder::SyncServerOption::CQ_TIMEOUT_MSEC, 10); + + server_ = builder.BuildAndStart(); + is_server_started_ = true; + } + + virtual void ConfigureServerBuilder(ServerBuilder* builder) { + builder->SetMaxMessageSize( + kMaxMessageSize_); // For testing max message size. + } + + void ResetChannel( + std::vector< + std::unique_ptr> + interceptor_creators = {}) { + if (!is_server_started_) { + StartServer(std::shared_ptr()); + } + EXPECT_TRUE(is_server_started_); + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + if (!user_agent_prefix_.empty()) { + args.SetUserAgentPrefix(user_agent_prefix_); + } + args.SetString(GRPC_ARG_SECONDARY_USER_AGENT_STRING, "end2end_test"); + + if (!GetParam().inproc) { + if (!GetParam().use_interceptors) { + channel_ = ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args); + } else { + channel_ = CreateCustomChannelWithInterceptors( + server_address_.str(), channel_creds, args, + interceptor_creators.empty() ? 
CreatePhonyClientInterceptors() + : std::move(interceptor_creators)); + } + } else { + if (!GetParam().use_interceptors) { + channel_ = server_->InProcessChannel(args); + } else { + channel_ = server_->experimental().InProcessChannelWithInterceptors( + args, interceptor_creators.empty() + ? CreatePhonyClientInterceptors() + : std::move(interceptor_creators)); + } + } + } + + void ResetStub( + std::vector< + std::unique_ptr> + interceptor_creators = {}) { + ResetChannel(std::move(interceptor_creators)); + if (GetParam().use_proxy) { + proxy_service_ = absl::make_unique(channel_); + int port = grpc_pick_unused_port_or_die(); + std::ostringstream proxyaddr; + proxyaddr << "localhost:" << port; + ServerBuilder builder; + builder.AddListeningPort(proxyaddr.str(), InsecureServerCredentials()); + builder.RegisterService(proxy_service_.get()); + + builder.SetSyncServerOption(ServerBuilder::SyncServerOption::NUM_CQS, 4); + builder.SetSyncServerOption( + ServerBuilder::SyncServerOption::CQ_TIMEOUT_MSEC, 10); + + proxy_server_ = builder.BuildAndStart(); + + channel_ = + grpc::CreateChannel(proxyaddr.str(), InsecureChannelCredentials()); + } + + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + PhonyInterceptor::Reset(); + } + + bool is_server_started_; + std::shared_ptr channel_; + std::unique_ptr stub_; + std::unique_ptr server_; + std::unique_ptr proxy_server_; + std::unique_ptr proxy_service_; + std::ostringstream server_address_; + const int kMaxMessageSize_; + TestServiceImpl service_; + CallbackTestServiceImpl callback_service_; + TestServiceImpl special_service_; + TestServiceImplDupPkg dup_pkg_service_; + std::string user_agent_prefix_; + int first_picked_port_; +}; + +static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs, + bool with_binary_metadata) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + + for (int i = 0; i < num_rpcs; ++i) { + ClientContext context; + if (with_binary_metadata) { + char bytes[8] = {'\0', '\1', '\2', '\3', + '\4', '\5', '\6', static_cast(i)}; + context.AddMetadata("custom-bin", std::string(bytes, 8)); + } + context.set_compression_algorithm(GRPC_COMPRESS_GZIP); + Status s = stub->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + } +} + +// This class is for testing scenarios where RPCs are cancelled on the server +// by calling ServerContext::TryCancel() +class End2endServerTryCancelTest : public End2endTest { + protected: + // Helper for testing client-streaming RPCs which are cancelled on the server. + // Depending on the value of server_try_cancel parameter, this will test one + // of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading + // any messages from the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading + // messages from the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading all + // the messages from the client + // + // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL. 
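  //
  // For context: the test service (defined elsewhere) presumably reacts to the
  // kServerTryCancelRequest metadata by calling ServerContext::TryCancel() at
  // the requested phase. An illustrative handler shape (helper names are
  // assumed; this is not the repo's actual implementation):
  //
  //   Status RequestStream(ServerContext* context,
  //                        ServerReader<EchoRequest>* reader,
  //                        EchoResponse* response) override {
  //     int phase = CancelPhaseFromMetadata(context);  // hypothetical helper
  //     if (phase == CANCEL_BEFORE_PROCESSING) context->TryCancel();
  //     EchoRequest msg;
  //     while (reader->Read(&msg)) {
  //       if (phase == CANCEL_DURING_PROCESSING) context->TryCancel();
  //       response->mutable_message()->append(msg.message());
  //     }
  //     if (phase == CANCEL_AFTER_PROCESSING) context->TryCancel();
  //     // After TryCancel() the client observes CANCELLED regardless of the
  //     // status returned here.
  //     return Status::OK;
  //   }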
+ void TestRequestStreamServerCancel( + ServerTryCancelRequestPhase server_try_cancel, int num_msgs_to_send) { + RestartServer(std::shared_ptr()); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + // Send server_try_cancel value in the client metadata + context.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + + auto stream = stub_->RequestStream(&context, &response); + + int num_msgs_sent = 0; + while (num_msgs_sent < num_msgs_to_send) { + request.set_message("hello"); + if (!stream->Write(request)) { + break; + } + num_msgs_sent++; + } + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent); + + stream->WritesDone(); + Status s = stream->Finish(); + + // At this point, we know for sure that RPC was cancelled by the server + // since we passed server_try_cancel value in the metadata. Depending on the + // value of server_try_cancel, the RPC might have been cancelled by the + // server at different stages. The following validates our expectations of + // number of messages sent in various cancellation scenarios: + + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: + case CANCEL_DURING_PROCESSING: + // If the RPC is cancelled by server before / during messages from the + // client, it means that the client most likely did not get a chance to + // send all the messages it wanted to send. i.e num_msgs_sent <= + // num_msgs_to_send + EXPECT_LE(num_msgs_sent, num_msgs_to_send); + break; + + case CANCEL_AFTER_PROCESSING: + // If the RPC was cancelled after all messages were read by the server, + // the client did get a chance to send all its messages + EXPECT_EQ(num_msgs_sent, num_msgs_to_send); + break; + + default: + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } + } + + // Helper for testing server-streaming RPCs which are cancelled on the server. + // Depending on the value of server_try_cancel parameter, this will test one + // of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before writing + // any messages to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while writing + // messages to the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after writing all + // the messages to the client + // + // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL. 
+ void TestResponseStreamServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + RestartServer(std::shared_ptr()); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + // Send server_try_cancel in the client metadata + context.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + + request.set_message("hello"); + auto stream = stub_->ResponseStream(&context, request); + + int num_msgs_read = 0; + while (num_msgs_read < kServerDefaultResponseStreamsToSend) { + if (!stream->Read(&response)) { + break; + } + EXPECT_EQ(response.message(), + request.message() + std::to_string(num_msgs_read)); + num_msgs_read++; + } + gpr_log(GPR_INFO, "Read %d messages", num_msgs_read); + + Status s = stream->Finish(); + + // Depending on the value of server_try_cancel, the RPC might have been + // cancelled by the server at different stages. The following validates our + // expectations of number of messages read in various cancellation + // scenarios: + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: + // Server cancelled before sending any messages. Which means the client + // wouldn't have read any + EXPECT_EQ(num_msgs_read, 0); + break; + + case CANCEL_DURING_PROCESSING: + // Server cancelled while writing messages. Client must have read less + // than or equal to the expected number of messages + EXPECT_LE(num_msgs_read, kServerDefaultResponseStreamsToSend); + break; + + case CANCEL_AFTER_PROCESSING: + // Even though the Server cancelled after writing all messages, the RPC + // may be cancelled before the Client got a chance to read all the + // messages. + EXPECT_LE(num_msgs_read, kServerDefaultResponseStreamsToSend); + break; + + default: { + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + } + + EXPECT_FALSE(s.ok()); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } + } + + // Helper for testing bidirectional-streaming RPCs which are cancelled on the + // server. Depending on the value of server_try_cancel parameter, this will + // test one of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading/ + // writing any messages from/to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading/ + // writing messages from/to the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading/writing + // all the messages from/to the client + // + // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL. 
+ void TestBidiStreamServerCancel(ServerTryCancelRequestPhase server_try_cancel, + int num_messages) { + RestartServer(std::shared_ptr()); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + // Send server_try_cancel in the client metadata + context.AddMetadata(kServerTryCancelRequest, + std::to_string(server_try_cancel)); + + auto stream = stub_->BidiStream(&context); + + int num_msgs_read = 0; + int num_msgs_sent = 0; + while (num_msgs_sent < num_messages) { + request.set_message("hello " + std::to_string(num_msgs_sent)); + if (!stream->Write(request)) { + break; + } + num_msgs_sent++; + + if (!stream->Read(&response)) { + break; + } + num_msgs_read++; + + EXPECT_EQ(response.message(), request.message()); + } + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent); + gpr_log(GPR_INFO, "Read %d messages", num_msgs_read); + + stream->WritesDone(); + Status s = stream->Finish(); + + // Depending on the value of server_try_cancel, the RPC might have been + // cancelled by the server at different stages. The following validates our + // expectations of number of messages read in various cancellation + // scenarios: + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: + EXPECT_EQ(num_msgs_read, 0); + break; + + case CANCEL_DURING_PROCESSING: + EXPECT_LE(num_msgs_sent, num_messages); + EXPECT_LE(num_msgs_read, num_msgs_sent); + break; + + case CANCEL_AFTER_PROCESSING: + EXPECT_EQ(num_msgs_sent, num_messages); + + // The Server cancelled after reading the last message and after writing + // the message to the client. However, the RPC cancellation might have + // taken effect before the client actually read the response. + EXPECT_LE(num_msgs_read, num_msgs_sent); + break; + + default: + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } + } +}; + +TEST_P(End2endServerTryCancelTest, RequestEchoServerCancel) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerTryCancelRequest, + std::to_string(CANCEL_BEFORE_PROCESSING)); + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); +} + +// Server to cancel before doing reading the request +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelBeforeReads) { + TestRequestStreamServerCancel(CANCEL_BEFORE_PROCESSING, 1); +} + +// Server to cancel while reading a request from the stream in parallel +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelDuringRead) { + TestRequestStreamServerCancel(CANCEL_DURING_PROCESSING, 10); +} + +// Server to cancel after reading all the requests but before returning to the +// client +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelAfterReads) { + TestRequestStreamServerCancel(CANCEL_AFTER_PROCESSING, 4); +} + +// Server to cancel before sending any response messages +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelBefore) { + TestResponseStreamServerCancel(CANCEL_BEFORE_PROCESSING); +} + +// Server to cancel while writing a response to the stream in parallel +TEST_P(End2endServerTryCancelTest, 
ResponseStreamServerCancelDuring) { + TestResponseStreamServerCancel(CANCEL_DURING_PROCESSING); +} + +// Server to cancel after writing all the respones to the stream but before +// returning to the client +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelAfter) { + TestResponseStreamServerCancel(CANCEL_AFTER_PROCESSING); +} + +// Server to cancel before reading/writing any requests/responses on the stream +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelBefore) { + TestBidiStreamServerCancel(CANCEL_BEFORE_PROCESSING, 2); +} + +// Server to cancel while reading/writing requests/responses on the stream in +// parallel +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelDuring) { + TestBidiStreamServerCancel(CANCEL_DURING_PROCESSING, 10); +} + +// Server to cancel after reading/writing all requests/responses on the stream +// but before returning to the client +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelAfter) { + TestBidiStreamServerCancel(CANCEL_AFTER_PROCESSING, 5); +} + +TEST_P(End2endTest, SimpleRpcWithCustomUserAgentPrefix) { + // User-Agent is an HTTP header for HTTP transports only + if (GetParam().inproc) { + return; + } + user_agent_prefix_ = "custom_prefix"; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + request.mutable_param()->set_echo_metadata(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + const auto& trailing_metadata = context.GetServerTrailingMetadata(); + auto iter = trailing_metadata.find("user-agent"); + EXPECT_TRUE(iter != trailing_metadata.end()); + std::string expected_prefix = user_agent_prefix_ + " grpc-c++/"; + EXPECT_TRUE(iter->second.starts_with(expected_prefix)) << iter->second; +} + +TEST_P(End2endTest, MultipleRpcsWithVariedBinaryMetadataValue) { + ResetStub(); + std::vector threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back(SendRpc, stub_.get(), 10, true); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(End2endTest, MultipleRpcs) { + ResetStub(); + std::vector threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back(SendRpc, stub_.get(), 10, false); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(End2endTest, ManyStubs) { + ResetStub(); + ChannelTestPeer peer(channel_.get()); + int registered_calls_pre = peer.registered_calls(); + int registration_attempts_pre = peer.registration_attempts(); + for (int i = 0; i < 1000; ++i) { + grpc::testing::EchoTestService::NewStub(channel_); + } + EXPECT_EQ(peer.registered_calls(), registered_calls_pre); + EXPECT_GT(peer.registration_attempts(), registration_attempts_pre); +} + +TEST_P(End2endTest, EmptyBinaryMetadata) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + ClientContext context; + context.AddMetadata("custom-bin", ""); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ReconnectChannel) { + if (GetParam().inproc) { + return; + } + int poller_slowdown_factor = 1; + // It needs 2 pollset_works to reconnect the channel with polling engine + // "poll" +#ifdef GRPC_POSIX_SOCKET_EV + grpc_core::UniquePtr poller = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy); + if (0 == strcmp(poller.get(), "poll")) { + 
poller_slowdown_factor = 2; + } +#endif // GRPC_POSIX_SOCKET_EV + ResetStub(); + SendRpc(stub_.get(), 1, false); + RestartServer(std::shared_ptr()); + // It needs more than GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS time to + // reconnect the channel. Make it a factor of 5x + gpr_sleep_until( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(kClientChannelBackupPollIntervalMs * 5 * + poller_slowdown_factor * + grpc_test_slowdown_factor(), + GPR_TIMESPAN))); + SendRpc(stub_.get(), 1, false); +} + +TEST_P(End2endTest, RequestStreamOneRequest) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request)); + stream->WritesDone(); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(context.debug_error_string().empty()); +} + +TEST_P(End2endTest, RequestStreamOneRequestWithCoalescingApi) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.set_initial_metadata_corked(true); + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + stream->WriteLast(request, WriteOptions()); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, RequestStreamTwoRequests) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Write(request)); + stream->WritesDone(); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), "hellohello"); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, RequestStreamTwoRequestsWithWriteThrough) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request, WriteOptions().set_write_through())); + EXPECT_TRUE(stream->Write(request, WriteOptions().set_write_through())); + stream->WritesDone(); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), "hellohello"); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, RequestStreamTwoRequestsWithCoalescingApi) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.set_initial_metadata_corked(true); + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request)); + stream->WriteLast(request, WriteOptions()); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), "hellohello"); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ResponseStream) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + + auto stream = stub_->ResponseStream(&context, request); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) { + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + std::to_string(i)); + } + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ResponseStreamWithCoalescingApi) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + 
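  // kServerUseCoalescingApi (defined alongside the test service) presumably
  // tells the server to emit its final message with WriteLast(), so the last
  // message and the trailing status are coalesced into one batch; the client
  // read loop below is the same either way.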
context.AddMetadata(kServerUseCoalescingApi, "1"); + + auto stream = stub_->ResponseStream(&context, request); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) { + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + std::to_string(i)); + } + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +// This was added to prevent regression from issue: +// https://github.com/grpc/grpc/issues/11546 +TEST_P(End2endTest, ResponseStreamWithEverythingCoalesced) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.AddMetadata(kServerUseCoalescingApi, "1"); + // We will only send one message, forcing everything (init metadata, message, + // trailing) to be coalesced together. + context.AddMetadata(kServerResponseStreamsToSend, "1"); + + auto stream = stub_->ResponseStream(&context, request); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0"); + + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, BidiStream) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::string msg("hello"); + + auto stream = stub_->BidiStream(&context); + + for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) { + request.set_message(msg + std::to_string(i)); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + } + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, BidiStreamWithCoalescingApi) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.AddMetadata(kServerFinishAfterNReads, "3"); + context.set_initial_metadata_corked(true); + std::string msg("hello"); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "2"); + stream->WriteLast(request, WriteOptions()); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +// This was added to prevent regression from issue: +// https://github.com/grpc/grpc/issues/11546 +TEST_P(End2endTest, BidiStreamWithEverythingCoalesced) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.AddMetadata(kServerFinishAfterNReads, "1"); + context.set_initial_metadata_corked(true); + std::string msg("hello"); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + stream->WriteLast(request, WriteOptions()); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +// Talk to the two services with the same name but different 
package names. +// The two stubs are created on the same channel. +TEST_P(End2endTest, DiffPackageServices) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + + std::unique_ptr dup_pkg_stub( + grpc::testing::duplicate::EchoTestService::NewStub(channel_)); + ClientContext context2; + s = dup_pkg_stub->Echo(&context2, request, &response); + EXPECT_EQ("no package", response.message()); + EXPECT_TRUE(s.ok()); +} + +template +void CancelRpc(ClientContext* context, int delay_us, ServiceType* service) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(delay_us, GPR_TIMESPAN))); + while (!service->signal_client()) { + } + context->TryCancel(); +} + +TEST_P(End2endTest, CancelRpcBeforeStart) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.TryCancel(); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ("", response.message()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(End2endTest, CancelRpcAfterStart) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + request.mutable_param()->set_server_notify_client_when_started(true); + request.mutable_param()->set_skip_cancelled_check(true); + Status s; + std::thread echo_thread([this, &s, &context, &request, &response] { + s = stub_->Echo(&context, request, &response); + EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); + }); + if (!GetParam().callback_server) { + service_.ClientWaitUntilRpcStarted(); + } else { + callback_service_.ClientWaitUntilRpcStarted(); + } + + context.TryCancel(); + + if (!GetParam().callback_server) { + service_.SignalServerToContinue(); + } else { + callback_service_.SignalServerToContinue(); + } + + echo_thread.join(); + EXPECT_EQ("", response.message()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Client cancels request stream after sending two messages +TEST_P(End2endTest, ClientCancelsRequestStream) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + + auto stream = stub_->RequestStream(&context, &response); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Write(request)); + + context.TryCancel(); + + Status s = stream->Finish(); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + + EXPECT_EQ(response.message(), ""); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Client cancels server stream after sending some messages +TEST_P(End2endTest, ClientCancelsResponseStream) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + + auto stream = stub_->ResponseStream(&context, request); + + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "1"); + + context.TryCancel(); + + // The cancellation races with responses, so there might be zero or 
+ // one responses pending, read till failure + + if (stream->Read(&response)) { + EXPECT_EQ(response.message(), request.message() + "2"); + // Since we have cancelled, we expect the next attempt to read to fail + EXPECT_FALSE(stream->Read(&response)); + } + + Status s = stream->Finish(); + // The final status could be either of CANCELLED or OK depending on + // who won the race. + EXPECT_GE(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +// Client cancels bidi stream after sending some messages +TEST_P(End2endTest, ClientCancelsBidi) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::string msg("hello"); + + // Send server_try_cancel value in the client metadata + context.AddMetadata(kClientTryCancelRequest, std::to_string(1)); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + + context.TryCancel(); + + // The cancellation races with responses, so there might be zero or + // one responses pending, read till failure + + if (stream->Read(&response)) { + EXPECT_EQ(response.message(), request.message()); + // Since we have cancelled, we expect the next attempt to read to fail + EXPECT_FALSE(stream->Read(&response)); + } + + Status s = stream->Finish(); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(End2endTest, RpcMaxMessageSize) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message(string(kMaxMessageSize_ * 2, 'a')); + request.mutable_param()->set_server_die(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); +} + +void ReaderThreadFunc(ClientReaderWriter* stream, + gpr_event* ev) { + EchoResponse resp; + gpr_event_set(ev, reinterpret_cast(1)); + while (stream->Read(&resp)) { + gpr_log(GPR_INFO, "Read message"); + } +} + +// Run a Read and a WritesDone simultaneously. +TEST_P(End2endTest, SimultaneousReadWritesDone) { + ResetStub(); + ClientContext context; + gpr_event ev; + gpr_event_init(&ev); + auto stream = stub_->BidiStream(&context); + std::thread reader_thread(ReaderThreadFunc, stream.get(), &ev); + gpr_event_wait(&ev, gpr_inf_future(GPR_CLOCK_REALTIME)); + stream->WritesDone(); + reader_thread.join(); + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ChannelState) { + if (GetParam().inproc) { + return; + } + + ResetStub(); + // Start IDLE + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(false)); + + // Did not ask to connect, no state change. + CompletionQueue cq; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(10); + channel_->NotifyOnStateChange(GRPC_CHANNEL_IDLE, deadline, &cq, nullptr); + void* tag; + bool ok = true; + cq.Next(&tag, &ok); + EXPECT_FALSE(ok); + + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(true)); + EXPECT_TRUE(channel_->WaitForStateChange(GRPC_CHANNEL_IDLE, + gpr_inf_future(GPR_CLOCK_REALTIME))); + auto state = channel_->GetState(false); + EXPECT_TRUE(state == GRPC_CHANNEL_CONNECTING || state == GRPC_CHANNEL_READY); +} + +// Takes 10s. 
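// ("Takes 10s" because the loop below makes ten WaitForStateChange() calls,
// each with a one-second deadline, against a server that never comes up.)
// Outside of tests, the same connectivity-state API is typically wrapped in a
// small wait-for-ready helper; an illustrative sketch, assuming only the
// public Channel API:
bool WaitUntilChannelReady(grpc::Channel* channel, int timeout_sec) {
  // WaitForConnected() asks the channel to connect and blocks until it
  // reaches READY or the deadline expires.
  return channel->WaitForConnected(std::chrono::system_clock::now() +
                                   std::chrono::seconds(timeout_sec));
}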
+TEST_P(End2endTest, ChannelStateTimeout) { + if ((GetParam().credentials_type != kInsecureCredentialsType) || + GetParam().inproc) { + return; + } + int port = grpc_pick_unused_port_or_die(); + std::ostringstream server_address; + server_address << "localhost:" << port; + // Channel to non-existing server + auto channel = + grpc::CreateChannel(server_address.str(), InsecureChannelCredentials()); + // Start IDLE + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel->GetState(true)); + + auto state = GRPC_CHANNEL_IDLE; + for (int i = 0; i < 10; i++) { + channel->WaitForStateChange( + state, std::chrono::system_clock::now() + std::chrono::seconds(1)); + state = channel->GetState(false); + } +} + +TEST_P(End2endTest, ChannelStateOnLameChannel) { + if ((GetParam().credentials_type != kInsecureCredentialsType) || + GetParam().inproc) { + return; + } + // Channel using invalid target URI. This creates a lame channel. + auto channel = grpc::CreateChannel("dns:///", InsecureChannelCredentials()); + // Channel should immediately report TRANSIENT_FAILURE. + EXPECT_EQ(GRPC_CHANNEL_TRANSIENT_FAILURE, channel->GetState(true)); + // And state will never change. + auto state = GRPC_CHANNEL_TRANSIENT_FAILURE; + for (int i = 0; i < 10; ++i) { + channel->WaitForStateChange( + state, std::chrono::system_clock::now() + std::chrono::seconds(1)); + state = channel->GetState(false); + } +} + +// Talking to a non-existing service. +TEST_P(End2endTest, NonExistingService) { + ResetChannel(); + std::unique_ptr stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel_); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub->Unimplemented(&context, request, &response); + EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code()); + EXPECT_EQ("", s.error_message()); +} + +// Ask the server to send back a serialized proto in trailer. +// This is an example of setting error details. +TEST_P(End2endTest, BinaryTrailerTest) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + request.mutable_param()->set_echo_metadata(true); + DebugInfo* info = request.mutable_param()->mutable_debug_info(); + info->add_stack_entries("stack_entry_1"); + info->add_stack_entries("stack_entry_2"); + info->add_stack_entries("stack_entry_3"); + info->set_detail("detailed debug info"); + std::string expected_string = info->SerializeAsString(); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + auto trailers = context.GetServerTrailingMetadata(); + EXPECT_EQ(1u, trailers.count(kDebugInfoTrailerKey)); + auto iter = trailers.find(kDebugInfoTrailerKey); + EXPECT_EQ(expected_string, iter->second); + // Parse the returned trailer into a DebugInfo proto. 
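  // (Server side, for illustration: the handler presumably serializes the
  // proto into a trailing metadata value, roughly
  //   context->AddTrailingMetadata(kDebugInfoTrailerKey,
  //                                debug_info.SerializeAsString());
  // and the client recovers it here by parsing the raw trailer bytes back
  // into a DebugInfo message.)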
+ DebugInfo returned_info; + EXPECT_TRUE(returned_info.ParseFromString(ToString(iter->second))); +} + +TEST_P(End2endTest, ExpectErrorTest) { + ResetStub(); + + std::vector expected_status; + expected_status.emplace_back(); + expected_status.back().set_code(13); // INTERNAL + // No Error message or details + + expected_status.emplace_back(); + expected_status.back().set_code(13); // INTERNAL + expected_status.back().set_error_message("text error message"); + expected_status.back().set_binary_error_details("text error details"); + + expected_status.emplace_back(); + expected_status.back().set_code(13); // INTERNAL + expected_status.back().set_error_message("text error message"); + expected_status.back().set_binary_error_details( + "\x0\x1\x2\x3\x4\x5\x6\x8\x9\xA\xB"); + + for (auto iter = expected_status.begin(); iter != expected_status.end(); + ++iter) { + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("Hello"); + auto* error = request.mutable_param()->mutable_expected_error(); + error->set_code(iter->code()); + error->set_error_message(iter->error_message()); + error->set_binary_error_details(iter->binary_error_details()); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(iter->code(), s.error_code()); + EXPECT_EQ(iter->error_message(), s.error_message()); + EXPECT_EQ(iter->binary_error_details(), s.error_details()); + EXPECT_TRUE(absl::StrContains(context.debug_error_string(), "created")); +#ifndef NDEBUG + // GRPC_ERROR_INT_FILE_LINE is for debug only + EXPECT_TRUE(absl::StrContains(context.debug_error_string(), "file")); + EXPECT_TRUE(absl::StrContains(context.debug_error_string(), "line")); +#endif + EXPECT_TRUE(absl::StrContains(context.debug_error_string(), "status")); + EXPECT_TRUE(absl::StrContains(context.debug_error_string(), "13")); + } +} + +////////////////////////////////////////////////////////////////////////// +// Test with and without a proxy. +class ProxyEnd2endTest : public End2endTest { + protected: +}; + +TEST_P(ProxyEnd2endTest, SimpleRpc) { + ResetStub(); + SendRpc(stub_.get(), 1, false); +} + +TEST_P(ProxyEnd2endTest, SimpleRpcWithEmptyMessages) { + ResetStub(); + EchoRequest request; + EchoResponse response; + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_TRUE(s.ok()); +} + +TEST_P(ProxyEnd2endTest, MultipleRpcs) { + ResetStub(); + std::vector threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back(SendRpc, stub_.get(), 10, false); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +// Set a 10us deadline and make sure proper error is returned. +TEST_P(ProxyEnd2endTest, RpcDeadlineExpires) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_skip_cancelled_check(true); + // Let server sleep for 40 ms first to guarantee expiry. + // 40 ms might seem a bit extreme but the timer manager would have been just + // initialized (when ResetStub() was called) and there are some warmup costs + // i.e the timer thread many not have even started. There might also be other + // delays in the timer manager thread (in acquiring locks, timer data + // structure manipulations, starting backup timer threads) that add to the + // delays. 
40ms is still not enough in some cases but this significantly + // reduces the test flakes + request.mutable_param()->set_server_sleep_us(40 * 1000); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(1); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, s.error_code()); +} + +// Set a long but finite deadline. +TEST_P(ProxyEnd2endTest, RpcLongDeadline) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::hours(1); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +// Ask server to echo back the deadline it sees. +TEST_P(ProxyEnd2endTest, EchoDeadline) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_echo_deadline(true); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::seconds(100); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + gpr_timespec sent_deadline; + Timepoint2Timespec(deadline, &sent_deadline); + // We want to allow some reasonable error given: + // - request_deadline() only has 1sec resolution so the best we can do is +-1 + // - if sent_deadline.tv_nsec is very close to the next second's boundary we + // can end up being off by 2 in one direction. + EXPECT_LE(response.param().request_deadline() - sent_deadline.tv_sec, 2); + EXPECT_GE(response.param().request_deadline() - sent_deadline.tv_sec, -1); +} + +// Ask server to echo back the deadline it sees. The rpc has no deadline. +TEST_P(ProxyEnd2endTest, EchoDeadlineForNoDeadlineRpc) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_echo_deadline(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(response.param().request_deadline(), + gpr_inf_future(GPR_CLOCK_REALTIME).tv_sec); +} + +TEST_P(ProxyEnd2endTest, UnimplementedRpc) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub_->Unimplemented(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), grpc::StatusCode::UNIMPLEMENTED); + EXPECT_EQ(s.error_message(), ""); + EXPECT_EQ(response.message(), ""); +} + +// Client cancels rpc after 10ms +TEST_P(ProxyEnd2endTest, ClientCancelsRpc) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + const int kCancelDelayUs = 10 * 1000; + request.mutable_param()->set_client_cancel_after_us(kCancelDelayUs); + + ClientContext context; + std::thread cancel_thread; + if (!GetParam().callback_server) { + cancel_thread = std::thread( + [&context, this](int delay) { CancelRpc(&context, delay, &service_); }, + kCancelDelayUs); + // Note: the unusual pattern above (and below) is caused by a conflict + // between two sets of compiler expectations. 
clang allows const to be + // captured without mention, so there is no need to capture kCancelDelayUs + // (and indeed clang-tidy complains if you do so). OTOH, a Windows compiler + // in our tests requires an explicit capture even for const. We square this + // circle by passing the const value in as an argument to the lambda. + } else { + cancel_thread = std::thread( + [&context, this](int delay) { + CancelRpc(&context, delay, &callback_service_); + }, + kCancelDelayUs); + } + Status s = stub_->Echo(&context, request, &response); + cancel_thread.join(); + EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(s.error_message(), "CANCELLED"); +} + +// Server cancels rpc after 1ms +TEST_P(ProxyEnd2endTest, ServerCancelsRpc) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_server_cancel_after_us(1000); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); + EXPECT_TRUE(s.error_message().empty()); +} + +// Make the response larger than the flow control window. +TEST_P(ProxyEnd2endTest, HugeResponse) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("huge response"); + const size_t kResponseSize = 1024 * (1024 + 10); + request.mutable_param()->set_response_message_length(kResponseSize); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::seconds(20); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(kResponseSize, response.message().size()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(ProxyEnd2endTest, Peer) { + // Peer is not meaningful for inproc + if (GetParam().inproc) { + return; + } + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("hello"); + request.mutable_param()->set_echo_peer(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(CheckIsLocalhost(response.param().peer())); + EXPECT_TRUE(CheckIsLocalhost(context.peer())); +} + +////////////////////////////////////////////////////////////////////////// +class SecureEnd2endTest : public End2endTest { + protected: + SecureEnd2endTest() { + GPR_ASSERT(!GetParam().use_proxy); + GPR_ASSERT(GetParam().credentials_type != kInsecureCredentialsType); + } +}; + +TEST_P(SecureEnd2endTest, SimpleRpcWithHost) { + ResetStub(); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + context.set_authority("foo.test.youtube.com"); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(response.has_param()); + EXPECT_EQ("special", response.param().host()); + EXPECT_TRUE(s.ok()); +} + +bool MetadataContains( + const std::multimap& metadata, + const std::string& key, const std::string& value) { + int count = 0; + + for (std::multimap::const_iterator iter = + metadata.begin(); + iter != metadata.end(); ++iter) { + if (ToString(iter->first) == key && ToString(iter->second) == value) { + count++; + } + } + return count == 1; +} + +TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginAndProcessorSuccess) { + auto* processor = new TestAuthMetadataProcessor(true); + StartServer(std::shared_ptr(processor)); + ResetStub(); + EchoRequest request; + EchoResponse 
response; + ClientContext context; + context.set_credentials(processor->GetCompatibleClientCreds()); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + request.mutable_param()->set_expected_client_identity( + TestAuthMetadataProcessor::kGoodGuy); + request.mutable_param()->set_expected_transport_security_type( + GetParam().credentials_type); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + + // Metadata should have been consumed by the processor. + EXPECT_FALSE(MetadataContains( + context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY, + std::string("Bearer ") + TestAuthMetadataProcessor::kGoodGuy)); +} + +TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginAndProcessorFailure) { + auto* processor = new TestAuthMetadataProcessor(true); + StartServer(std::shared_ptr(processor)); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetIncompatibleClientCreds()); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED); +} + +TEST_P(SecureEnd2endTest, SetPerCallCredentials) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr creds = + GoogleIAMCredentials(kFakeToken, kFakeSelector); + context.set_credentials(creds); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCredsDebugString); +} + +class CredentialsInterceptor : public experimental::Interceptor { + public: + explicit CredentialsInterceptor(experimental::ClientRpcInfo* info) + : info_(info) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + std::shared_ptr creds = + GoogleIAMCredentials(kFakeToken, kFakeSelector); + info_->client_context()->set_credentials(creds); + } + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_ = nullptr; +}; + +class CredentialsInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + CredentialsInterceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new CredentialsInterceptor(info); + } +}; + +TEST_P(SecureEnd2endTest, CallCredentialsInterception) { + if (!GetParam().use_interceptors) { + return; + } + std::vector> + interceptor_creators; + interceptor_creators.push_back( + absl::make_unique()); + ResetStub(std::move(interceptor_creators)); + EchoRequest request; + EchoResponse response; + ClientContext context; + + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + 
GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCredsDebugString); +} + +TEST_P(SecureEnd2endTest, CallCredentialsInterceptionWithSetCredentials) { + if (!GetParam().use_interceptors) { + return; + } + std::vector> + interceptor_creators; + interceptor_creators.push_back( + absl::make_unique()); + ResetStub(std::move(interceptor_creators)); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr creds1 = + GoogleIAMCredentials(kWrongToken, kWrongSelector); + context.set_credentials(creds1); + EXPECT_EQ(context.credentials(), creds1); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedWrongCredsDebugString); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCredsDebugString); +} + +TEST_P(SecureEnd2endTest, OverridePerCallCredentials) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr creds1 = + GoogleIAMCredentials(kFakeToken1, kFakeSelector1); + context.set_credentials(creds1); + EXPECT_EQ(context.credentials(), creds1); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCreds1DebugString); + std::shared_ptr creds2 = + GoogleIAMCredentials(kFakeToken2, kFakeSelector2); + context.set_credentials(creds2); + EXPECT_EQ(context.credentials(), creds2); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken2)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector2)); + EXPECT_FALSE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken1)); + EXPECT_FALSE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector1)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCreds2DebugString); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginKeyFailure) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kBadMetadataKey, + "Does not matter, will fail the key is invalid.", false, true, + 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginKeyFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginValueFailure) { + 
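  // Like the key-failure case above, this is expected to fail on the client:
  // a metadata value containing '\n' (or a key with uppercase characters, as
  // in kBadMetadataKey) is not legal gRPC metadata, so the call appears to be
  // rejected locally with UNAVAILABLE rather than reaching the server.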
ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, + "With illegal \n value.", false, true, 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginValueFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginWithDeadline) { + ResetStub(); + EchoRequest request; + request.mutable_param()->set_skip_cancelled_check(true); + EchoResponse response; + ClientContext context; + const int delay = 100; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(delay); + context.set_deadline(deadline); + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin("meta_key", "Does not matter", true, + true, delay)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + if (!s.ok()) { + EXPECT_TRUE(s.error_code() == StatusCode::DEADLINE_EXCEEDED || + s.error_code() == StatusCode::UNAVAILABLE); + } + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginWithDeadlineCredsDebugString); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginWithCancel) { + ResetStub(); + EchoRequest request; + request.mutable_param()->set_skip_cancelled_check(true); + EchoResponse response; + ClientContext context; + const int delay = 100; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin("meta_key", "Does not matter", true, + true, delay)))); + request.set_message("Hello"); + + std::thread cancel_thread([&] { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(delay, GPR_TIMESPAN))); + context.TryCancel(); + }); + Status s = stub_->Echo(&context, request, &response); + if (!s.ok()) { + EXPECT_TRUE(s.error_code() == StatusCode::CANCELLED || + s.error_code() == StatusCode::UNAVAILABLE); + } + cancel_thread.join(); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginWithDeadlineCredsDebugString); +} + +TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginFailure) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, + "Does not matter, will fail anyway (see 3rd param)", false, false, + 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(s.error_message(), + std::string("Getting metadata from plugin failed with error: ") + + kTestCredsPluginErrorMsg); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedNonBlockingAuthMetadataPluginFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginAndProcessorSuccess) { + auto* processor = new TestAuthMetadataProcessor(false); + StartServer(std::shared_ptr(processor)); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + 
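  // "NonBlocking" means the plugin and processor report IsBlocking() == false,
  // so gRPC invokes them inline instead of scheduling them on a separate
  // thread; the success-path expectations below match the Blocking variant
  // earlier in the file.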
context.set_credentials(processor->GetCompatibleClientCreds()); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + request.mutable_param()->set_expected_client_identity( + TestAuthMetadataProcessor::kGoodGuy); + request.mutable_param()->set_expected_transport_security_type( + GetParam().credentials_type); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + + // Metadata should have been consumed by the processor. + EXPECT_FALSE(MetadataContains( + context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY, + std::string("Bearer ") + TestAuthMetadataProcessor::kGoodGuy)); + EXPECT_EQ( + context.credentials()->DebugString(), + kExpectedNonBlockingAuthMetadataPluginAndProcessorSuccessCredsDebugString); +} + +TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginAndProcessorFailure) { + auto* processor = new TestAuthMetadataProcessor(false); + StartServer(std::shared_ptr(processor)); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetIncompatibleClientCreds()); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED); + EXPECT_EQ( + context.credentials()->DebugString(), + kExpectedNonBlockingAuthMetadataPluginAndProcessorFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginFailure) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, + "Does not matter, will fail anyway (see 3rd param)", true, false, + 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(s.error_message(), + std::string("Getting metadata from plugin failed with error: ") + + kTestCredsPluginErrorMsg); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedBlockingAuthMetadataPluginFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, CompositeCallCreds) { + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + const char kMetadataKey1[] = "call-creds-key1"; + const char kMetadataKey2[] = "call-creds-key2"; + const char kMetadataVal1[] = "call-creds-val1"; + const char kMetadataVal2[] = "call-creds-val2"; + + context.set_credentials(grpc::CompositeCallCredentials( + grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin(kMetadataKey1, kMetadataVal1, + true, true, 0))), + grpc::MetadataCredentialsFromPlugin( + std::unique_ptr( + new TestMetadataCredentialsPlugin(kMetadataKey2, kMetadataVal2, + true, true, 0))))); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + kMetadataKey1, kMetadataVal1)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + kMetadataKey2, kMetadataVal2)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedCompositeCallCredsDebugString); +} + +TEST_P(SecureEnd2endTest, ClientAuthContext) { + ResetStub(); + EchoRequest request; + EchoResponse 
response; + request.set_message("Hello"); + request.mutable_param()->set_check_auth_context(GetParam().credentials_type == + kTlsCredentialsType); + request.mutable_param()->set_expected_transport_security_type( + GetParam().credentials_type); + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + + std::shared_ptr auth_ctx = context.auth_context(); + std::vector tst = + auth_ctx->FindPropertyValues("transport_security_type"); + ASSERT_EQ(1u, tst.size()); + EXPECT_EQ(GetParam().credentials_type, ToString(tst[0])); + if (GetParam().credentials_type == kTlsCredentialsType) { + EXPECT_EQ("x509_subject_alternative_name", + auth_ctx->GetPeerIdentityPropertyName()); + EXPECT_EQ(4u, auth_ctx->GetPeerIdentity().size()); + EXPECT_EQ("*.test.google.fr", ToString(auth_ctx->GetPeerIdentity()[0])); + EXPECT_EQ("waterzooi.test.google.be", + ToString(auth_ctx->GetPeerIdentity()[1])); + EXPECT_EQ("*.test.youtube.com", ToString(auth_ctx->GetPeerIdentity()[2])); + EXPECT_EQ("192.168.1.3", ToString(auth_ctx->GetPeerIdentity()[3])); + } +} + +class ResourceQuotaEnd2endTest : public End2endTest { + public: + ResourceQuotaEnd2endTest() + : server_resource_quota_("server_resource_quota") {} + + void ConfigureServerBuilder(ServerBuilder* builder) override { + builder->SetResourceQuota(server_resource_quota_); + } + + private: + ResourceQuota server_resource_quota_; +}; + +TEST_P(ResourceQuotaEnd2endTest, SimpleRequest) { + ResetStub(); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +// TODO(vjpai): refactor arguments into a struct if it makes sense +std::vector CreateTestScenarios(bool use_proxy, + bool test_insecure, + bool test_secure, + bool test_inproc, + bool test_callback_server) { + std::vector scenarios; + std::vector credentials_types; + + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, + kClientChannelBackupPollIntervalMs); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + + if (test_secure) { + credentials_types = + GetCredentialsProvider()->GetSecureCredentialsTypeList(); + } + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. 
+ return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + if (test_insecure && insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + + // Test callback with inproc or if the event-engine allows it + GPR_ASSERT(!credentials_types.empty()); + for (const auto& cred : credentials_types) { + scenarios.emplace_back(false, false, false, cred, false); + scenarios.emplace_back(true, false, false, cred, false); + if (test_callback_server) { + // Note that these scenarios will be dynamically disabled if the event + // engine doesn't run in the background + scenarios.emplace_back(false, false, false, cred, true); + scenarios.emplace_back(true, false, false, cred, true); + } + if (use_proxy) { + scenarios.emplace_back(false, true, false, cred, false); + scenarios.emplace_back(true, true, false, cred, false); + } + } + if (test_inproc && insec_ok()) { + scenarios.emplace_back(false, false, true, kInsecureCredentialsType, false); + scenarios.emplace_back(true, false, true, kInsecureCredentialsType, false); + if (test_callback_server) { + scenarios.emplace_back(false, false, true, kInsecureCredentialsType, + true); + scenarios.emplace_back(true, false, true, kInsecureCredentialsType, true); + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P( + End2end, End2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); + +INSTANTIATE_TEST_SUITE_P( + End2endServerTryCancel, End2endServerTryCancelTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); + +INSTANTIATE_TEST_SUITE_P( + ProxyEnd2end, ProxyEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(true, true, true, true, true))); + +INSTANTIATE_TEST_SUITE_P( + SecureEnd2end, SecureEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, false, true, false, true))); + +INSTANTIATE_TEST_SUITE_P( + ResourceQuotaEnd2end, ResourceQuotaEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/end2end/exception_test.cc b/test/cpp/end2end/exception_test.cc new file mode 100644 index 00000000..1a4c418b --- /dev/null +++ b/test/cpp/end2end/exception_test.cc @@ -0,0 +1,124 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { + +const char* kErrorMessage = "This service caused an exception"; + +#if GRPC_ALLOW_EXCEPTIONS +class ExceptingServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + Status Echo(ServerContext* /*server_context*/, const EchoRequest* /*request*/, + EchoResponse* /*response*/) override { + throw -1; + } + Status RequestStream(ServerContext* /*context*/, + ServerReader* /*reader*/, + EchoResponse* /*response*/) override { + throw ServiceException(); + } + + private: + class ServiceException final : public std::exception { + public: + ServiceException() {} + + private: + const char* what() const noexcept override { return kErrorMessage; } + }; +}; + +class ExceptionTest : public ::testing::Test { + protected: + ExceptionTest() {} + + void SetUp() override { + ServerBuilder builder; + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + channel_ = server_->InProcessChannel(ChannelArguments()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + std::shared_ptr channel_; + std::unique_ptr stub_; + std::unique_ptr server_; + ExceptingServiceImpl service_; +}; + +TEST_F(ExceptionTest, Unary) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("test"); + + for (int i = 0; i < 10; i++) { + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNKNOWN); + } +} + +TEST_F(ExceptionTest, RequestStream) { + ResetStub(); + EchoResponse response; + + for (int i = 0; i < 10; i++) { + ClientContext context; + auto stream = stub_->RequestStream(&context, &response); + stream->WritesDone(); + Status s = stream->Finish(); + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNKNOWN); + } +} + +#endif // GRPC_ALLOW_EXCEPTIONS + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/filter_end2end_test.cc b/test/cpp/end2end/filter_end2end_test.cc new file mode 100644 index 00000000..93f07b0a --- /dev/null +++ b/test/cpp/end2end/filter_end2end_test.cc @@ -0,0 +1,347 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/cpp/common/channel_filter.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { +namespace { + +void* tag(int i) { return reinterpret_cast(i); } + +void verify_ok(CompletionQueue* cq, int i, bool expect_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + EXPECT_EQ(expect_ok, ok); + EXPECT_EQ(tag(i), got_tag); +} + +namespace { + +int global_num_connections = 0; +int global_num_calls = 0; +std::mutex global_mu; + +void IncrementConnectionCounter() { + std::unique_lock lock(global_mu); + ++global_num_connections; +} + +void ResetConnectionCounter() { + std::unique_lock lock(global_mu); + global_num_connections = 0; +} + +int GetConnectionCounterValue() { + std::unique_lock lock(global_mu); + return global_num_connections; +} + +void IncrementCallCounter() { + std::unique_lock lock(global_mu); + ++global_num_calls; +} + +void ResetCallCounter() { + std::unique_lock lock(global_mu); + global_num_calls = 0; +} + +int GetCallCounterValue() { + std::unique_lock lock(global_mu); + return global_num_calls; +} + +} // namespace + +class ChannelDataImpl : public ChannelData { + public: + grpc_error_handle Init(grpc_channel_element* /*elem*/, + grpc_channel_element_args* /*args*/) override { + IncrementConnectionCounter(); + return GRPC_ERROR_NONE; + } +}; + +class CallDataImpl : public CallData { + public: + void StartTransportStreamOpBatch(grpc_call_element* elem, + TransportStreamOpBatch* op) override { + // Incrementing the counter could be done from Init(), but we want + // to test that the individual methods are actually called correctly. 
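+    // Only the batch carrying recv_initial_metadata bumps the counter, so each
+    // RPC is counted exactly once even though several batches pass through this
+    // filter per call; the per-call expectations in the tests below rely on that.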
+ if (op->recv_initial_metadata() != nullptr) IncrementCallCounter(); + grpc_call_next_op(elem, op->op()); + } +}; + +class FilterEnd2endTest : public ::testing::Test { + protected: + FilterEnd2endTest() : server_host_("localhost") {} + + static void SetUpTestCase() { + // Workaround for + // https://github.com/google/google-toolbox-for-mac/issues/242 + static bool setup_done = false; + if (!setup_done) { + setup_done = true; + grpc::RegisterChannelFilter( + "test-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr); + } + } + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << server_host_ << ":" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&generic_service_); + srv_cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cli_cq_.Shutdown(); + srv_cq_->Shutdown(); + while (cli_cq_.Next(&ignored_tag, &ignored_ok)) { + } + while (srv_cq_->Next(&ignored_tag, &ignored_ok)) { + } + } + + void ResetStub() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + generic_stub_ = absl::make_unique(channel); + ResetConnectionCounter(); + ResetCallCounter(); + } + + void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); } + void client_ok(int i) { verify_ok(&cli_cq_, i, true); } + void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); } + void client_fail(int i) { verify_ok(&cli_cq_, i, false); } + + void SendRpc(int num_rpcs) { + const std::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello world. Hello world. Hello world."); + std::thread request_call([this]() { server_ok(4); }); + std::unique_ptr call = + generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_); + call->StartCall(tag(1)); + client_ok(1); + std::unique_ptr send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. 
+ send_buffer.reset(); + client_ok(2); + call->WritesDone(tag(3)); + client_ok(3); + + generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(), + srv_cq_.get(), tag(4)); + + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + server_ok(5); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + server_ok(6); + + stream.Finish(Status::OK, tag(7)); + server_ok(7); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + client_ok(8); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + client_ok(9); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + CompletionQueue cli_cq_; + std::unique_ptr srv_cq_; + std::unique_ptr stub_; + std::unique_ptr generic_stub_; + std::unique_ptr server_; + AsyncGenericService generic_service_; + const std::string server_host_; + std::ostringstream server_address_; +}; + +TEST_F(FilterEnd2endTest, SimpleRpc) { + ResetStub(); + EXPECT_EQ(0, GetConnectionCounterValue()); + EXPECT_EQ(0, GetCallCounterValue()); + SendRpc(1); + EXPECT_EQ(1, GetConnectionCounterValue()); + EXPECT_EQ(1, GetCallCounterValue()); +} + +TEST_F(FilterEnd2endTest, SequentialRpcs) { + ResetStub(); + EXPECT_EQ(0, GetConnectionCounterValue()); + EXPECT_EQ(0, GetCallCounterValue()); + SendRpc(10); + EXPECT_EQ(1, GetConnectionCounterValue()); + EXPECT_EQ(10, GetCallCounterValue()); +} + +// One ping, one pong. 
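+// Client and server each write exactly one message over a generic bidi stream;
+// tags 1-10 step both completion queues in lock step, and the server-side read
+// at tag 8 is expected to complete with ok == false because the client has
+// already half-closed via WritesDone.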
+TEST_F(FilterEnd2endTest, SimpleBidiStreaming) { + ResetStub(); + EXPECT_EQ(0, GetConnectionCounterValue()); + EXPECT_EQ(0, GetCallCounterValue()); + + const std::string kMethodName( + "/grpc.cpp.test.util.EchoTestService/BidiStream"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter srv_stream(&srv_ctx); + + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + send_request.set_message("Hello"); + std::thread request_call([this]() { server_ok(2); }); + std::unique_ptr cli_stream = + generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_); + cli_stream->StartCall(tag(1)); + client_ok(1); + + generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(), + srv_cq_.get(), tag(2)); + + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + std::unique_ptr send_buffer = + SerializeToByteBuffer(&send_request); + cli_stream->Write(*send_buffer, tag(3)); + send_buffer.reset(); + client_ok(3); + + ByteBuffer recv_buffer; + srv_stream.Read(&recv_buffer, tag(4)); + server_ok(4); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + srv_stream.Write(*send_buffer, tag(5)); + send_buffer.reset(); + server_ok(5); + + cli_stream->Read(&recv_buffer, tag(6)); + client_ok(6); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + client_ok(7); + + srv_stream.Read(&recv_buffer, tag(8)); + server_fail(8); + + srv_stream.Finish(Status::OK, tag(9)); + server_ok(9); + + cli_stream->Finish(&recv_status, tag(10)); + client_ok(10); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + + EXPECT_EQ(1, GetCallCounterValue()); + EXPECT_EQ(1, GetConnectionCounterValue()); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/flaky_network_test.cc b/test/cpp/end2end/flaky_network_test.cc new file mode 100644 index 00000000..890d4a78 --- /dev/null +++ b/test/cpp/end2end/flaky_network_test.cc @@ -0,0 +1,552 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/gpr/env.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GPR_LINUX +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { +namespace { + +struct TestScenario { + TestScenario(const std::string& creds_type, const std::string& content) + : credentials_type(creds_type), message_content(content) {} + const std::string credentials_type; + const std::string message_content; +}; + +class FlakyNetworkTest : public ::testing::TestWithParam { + protected: + FlakyNetworkTest() + : server_host_("grpctest"), + interface_("lo:1"), + ipv4_address_("10.0.0.1"), + netmask_("/32") {} + + void InterfaceUp() { + std::ostringstream cmd; + // create interface_ with address ipv4_address_ + cmd << "ip addr add " << ipv4_address_ << netmask_ << " dev " << interface_; + std::system(cmd.str().c_str()); + } + + void InterfaceDown() { + std::ostringstream cmd; + // remove interface_ + cmd << "ip addr del " << ipv4_address_ << netmask_ << " dev " << interface_; + std::system(cmd.str().c_str()); + } + + void DNSUp() { + std::ostringstream cmd; + // Add DNS entry for server_host_ in /etc/hosts + cmd << "echo '" << ipv4_address_ << " " << server_host_ + << "' >> /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void DNSDown() { + std::ostringstream cmd; + // Remove DNS entry for server_host_ from /etc/hosts + // NOTE: we can't do this in one step with sed -i because when we are + // running under docker, the file is mounted by docker so we can't change + // its inode from within the container (sed -i creates a new file and + // replaces the old file, which changes the inode) + cmd << "sed '/" << server_host_ << "/d' /etc/hosts > /etc/hosts.orig"; + std::system(cmd.str().c_str()); + + // clear the stream + cmd.str(""); + + cmd << "cat /etc/hosts.orig > /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void DropPackets() { + std::ostringstream cmd; + // drop packets with src IP = ipv4_address_ + cmd << "iptables -A INPUT -s " << ipv4_address_ << " -j DROP"; + + std::system(cmd.str().c_str()); + // clear the stream + cmd.str(""); + + // drop packets with dst IP = ipv4_address_ + cmd << "iptables -A INPUT -d " << ipv4_address_ << " -j DROP"; + std::system(cmd.str().c_str()); + } + + void RestoreNetwork() { + std::ostringstream cmd; + // remove iptables rule to drop packets with src IP = ipv4_address_ + cmd << "iptables -D INPUT -s " << ipv4_address_ << " -j DROP"; + std::system(cmd.str().c_str()); + // clear the stream + cmd.str(""); + // remove iptables rule to drop packets with dest IP = ipv4_address_ + cmd << "iptables -D INPUT -d " << ipv4_address_ << " -j DROP"; + std::system(cmd.str().c_str()); + } + + void FlakeNetwork() { + std::ostringstream cmd; + // Emulate a flaky network connection over interface_. Add a delay of 100ms + // +/- 20ms, 0.1% packet loss, 0.1% duplicates and 0.01% corrupt packets.
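+    // 'tc ... netem' is Linux-specific traffic shaping and typically needs
+    // root (CAP_NET_ADMIN); that, plus the ip/iptables calls above, is
+    // presumably why this file is only compiled under the GPR_LINUX guard.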
+ cmd << "tc qdisc replace dev " << interface_ + << " root netem delay 100ms 20ms distribution normal loss 0.1% " + "duplicate " + "0.1% corrupt 0.01% "; + std::system(cmd.str().c_str()); + } + + void UnflakeNetwork() { + // Remove simulated network flake on interface_ + std::ostringstream cmd; + cmd << "tc qdisc del dev " << interface_ << " root netem"; + std::system(cmd.str().c_str()); + } + + void NetworkUp() { + InterfaceUp(); + DNSUp(); + } + + void NetworkDown() { + InterfaceDown(); + DNSDown(); + } + + void SetUp() override { + NetworkUp(); + grpc_init(); + StartServer(); + } + + void TearDown() override { + NetworkDown(); + StopServer(); + grpc_shutdown(); + } + + void StartServer() { + // TODO(pjaikumar): Ideally, we should allocate the port dynamically using + // grpc_pick_unused_port_or_die(). That doesn't work inside some docker + // containers because port_server listens on localhost which maps to + // ip6-loopback, but ipv6 support is not enabled by default in docker. + port_ = SERVER_PORT; + + server_ = absl::make_unique(port_, GetParam().credentials_type); + server_->Start(server_host_); + } + void StopServer() { server_->Shutdown(); } + + std::unique_ptr BuildStub( + const std::shared_ptr& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr BuildChannel( + const std::string& lb_policy_name, + ChannelArguments args = ChannelArguments()) { + if (!lb_policy_name.empty()) { + args.SetLoadBalancingPolicyName(lb_policy_name); + } // else, default to pick first + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::ostringstream server_address; + server_address << server_host_ << ":" << port_; + return CreateCustomChannel(server_address.str(), channel_creds, args); + } + + bool SendRpc( + const std::unique_ptr& stub, + int timeout_ms = 0, bool wait_for_ready = false) { + auto response = absl::make_unique(); + EchoRequest request; + auto& msg = GetParam().message_content; + request.set_message(msg); + ClientContext context; + if (timeout_ms > 0) { + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + // Allow an RPC to be canceled (for deadline exceeded) after it has + // reached the server.
+ request.mutable_param()->set_skip_cancelled_check(true); + } + // See https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md for + // details of wait-for-ready semantics + if (wait_for_ready) { + context.set_wait_for_ready(true); + } + Status status = stub->Echo(&context, request, response.get()); + auto ok = status.ok(); + if (ok) { + gpr_log(GPR_DEBUG, "RPC succeeded"); + } else { + gpr_log(GPR_DEBUG, "RPC failed: %s", status.error_message().c_str()); + } + return ok; + } + + struct ServerData { + int port_; + const std::string creds_; + std::unique_ptr server_; + TestServiceImpl service_; + std::unique_ptr thread_; + bool server_ready_ = false; + + ServerData(int port, const std::string& creds) + : port_(port), creds_(creds) {} + + void Start(const std::string& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + std::mutex mu; + std::unique_lock lock(mu); + std::condition_variable cond; + thread_ = absl::make_unique( + std::bind(&ServerData::Serve, this, server_host, &mu, &cond)); + cond.wait(lock, [this] { return server_ready_; }); + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const std::string& server_host, std::mutex* mu, + std::condition_variable* cond) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + auto server_creds = + GetCredentialsProvider()->GetServerCredentials(creds_); + builder.AddListeningPort(server_address.str(), server_creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + std::lock_guard lock(*mu); + server_ready_ = true; + cond->notify_one(); + } + + void Shutdown() { + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + } + }; + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(false /* try_to_connect */)) == + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(true /* try_to_connect */)) != + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + private: + const std::string server_host_; + const std::string interface_; + const std::string ipv4_address_; + const std::string netmask_; + std::unique_ptr stub_; + std::unique_ptr server_; + const int SERVER_PORT = 32750; + int port_; +}; + +std::vector CreateTestScenarios() { + std::vector scenarios; + std::vector credentials_types; + std::vector messages; + + credentials_types.push_back(kInsecureCredentialsType); + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + + messages.push_back("🖖"); + for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; k *= 32) { + std::string big_msg; + for (size_t i = 0; i < k * 1024; ++i) { + char c = 'a' + (i % 26); + big_msg += c; + } + messages.push_back(big_msg); + } + for (auto cred = credentials_types.begin(); cred != credentials_types.end(); + ++cred) { + for (auto msg = messages.begin(); msg != messages.end(); 
msg++) { + scenarios.emplace_back(*cred, *msg); + } + } + + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(FlakyNetworkTest, FlakyNetworkTest, + ::testing::ValuesIn(CreateTestScenarios())); + +// Network interface connected to server flaps +TEST_P(FlakyNetworkTest, NetworkTransition) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + std::atomic_bool shutdown{false}; + std::thread sender = std::thread([this, &stub, &shutdown]() { + while (true) { + if (shutdown.load()) { + return; + } + SendRpc(stub); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + + // bring down network + NetworkDown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + // bring network interface back up + InterfaceUp(); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + // Restore DNS entry for server + DNSUp(); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + shutdown.store(true); + sender.join(); +} + +// Traffic to the server is blackholed temporarily with keepalives enabled +TEST_P(FlakyNetworkTest, ServerUnreachableWithKeepalive) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + const int kReconnectBackoffMs = 1000; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + // max time for a connection attempt + args.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, kReconnectBackoffMs); + // max time between reconnect attempts + args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, kReconnectBackoffMs); + + gpr_log(GPR_DEBUG, "FlakyNetworkTest.ServerUnreachableWithKeepalive start"); + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + std::atomic_bool shutdown{false}; + std::thread sender = std::thread([this, &stub, &shutdown]() { + while (true) { + if (shutdown.load()) { + return; + } + SendRpc(stub); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + + // break network connectivity + gpr_log(GPR_DEBUG, "Adding iptables rule to drop packets"); + DropPackets(); + std::this_thread::sleep_for(std::chrono::milliseconds(10000)); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + // bring network interface back up + RestoreNetwork(); + gpr_log(GPR_DEBUG, "Removed iptables rule to drop packets"); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + shutdown.store(true); + sender.join(); + gpr_log(GPR_DEBUG, "FlakyNetworkTest.ServerUnreachableWithKeepalive end"); +} + +// +// Traffic to the server is blackholed temporarily with keepalives disabled +TEST_P(FlakyNetworkTest,
ServerUnreachableNoKeepalive) { + auto channel = BuildChannel("pick_first", ChannelArguments()); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // break network connectivity + DropPackets(); + + std::thread sender = std::thread([this, &stub]() { + // RPC with a deadline should time out + EXPECT_FALSE(SendRpc(stub, /*timeout_ms=*/500, /*wait_for_ready=*/true)); + // RPC without a deadline should block until the network is restored, then succeed + EXPECT_TRUE(SendRpc(stub, /*timeout_ms=*/0, /*wait_for_ready=*/true)); + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + // bring network interface back up + RestoreNetwork(); + + // wait for RPC to finish + sender.join(); +} + +// Send RPCs over a flaky network connection +TEST_P(FlakyNetworkTest, FlakyNetwork) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + const int kMessageCount = 100; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // simulate flaky network (packet loss, corruption and delays) + FlakeNetwork(); + for (int i = 0; i < kMessageCount; ++i) { + SendRpc(stub); + } + // remove network flakiness + UnflakeNetwork(); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); +} + +// Server is shut down gracefully and restarted. Client keepalives are enabled +TEST_P(FlakyNetworkTest, ServerRestartKeepaliveEnabled) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // server goes down, client should detect server going down and calls should + // fail + StopServer(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + EXPECT_FALSE(SendRpc(stub)); + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // server restarts, calls succeed + StartServer(); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + // EXPECT_TRUE(SendRpc(stub)); +} + +// Server is shut down gracefully and restarted.
Client keepalives are disabled +TEST_P(FlakyNetworkTest, ServerRestartKeepaliveDisabled) { + auto channel = BuildChannel("pick_first", ChannelArguments()); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // server sends GOAWAY when it is shut down, so client attempts to reconnect + StopServer(); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // server restarts, calls succeed + StartServer(); + EXPECT_TRUE(WaitForChannelReady(channel.get())); +} + +} // namespace +} // namespace testing +} // namespace grpc +#endif // GPR_LINUX + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/end2end/generic_end2end_test.cc b/test/cpp/end2end/generic_end2end_test.cc new file mode 100644 index 00000000..0ebc863a --- /dev/null +++ b/test/cpp/end2end/generic_end2end_test.cc @@ -0,0 +1,431 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * + */ + +#include +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { +namespace { + +void* tag(int i) { return reinterpret_cast(i); } + +void verify_ok(CompletionQueue* cq, int i, bool expect_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + EXPECT_EQ(expect_ok, ok); + EXPECT_EQ(tag(i), got_tag); +} + +class GenericEnd2endTest : public ::testing::Test { + protected: + GenericEnd2endTest() : server_host_("localhost") {} + + void SetUp() override { + shut_down_ = false; + int port = grpc_pick_unused_port_or_die(); + server_address_ << server_host_ << ":" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&generic_service_); + // Include a second call to RegisterAsyncGenericService to make sure that + // we get an error in the log, since it is not allowed to have 2 async + // generic services + builder.RegisterAsyncGenericService(&generic_service_); + srv_cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + } + + void ShutDownServerAndCQs() { + if (!shut_down_) { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cli_cq_.Shutdown(); + srv_cq_->Shutdown(); + while (cli_cq_.Next(&ignored_tag, &ignored_ok)) { + } + while (srv_cq_->Next(&ignored_tag, &ignored_ok)) { + } + shut_down_ = true; + } + } + void TearDown() override { ShutDownServerAndCQs(); } + + void ResetStub() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + generic_stub_ = absl::make_unique(channel); + } + + void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); } + void client_ok(int i) { verify_ok(&cli_cq_, i, true); } + void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); } + void client_fail(int i) { verify_ok(&cli_cq_, i, false); } + + void SendRpc(int num_rpcs) { + SendRpc(num_rpcs, false, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + + void SendRpc(int num_rpcs, bool check_deadline, gpr_timespec deadline) { + const std::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello world. Hello world. 
Hello world."); + + if (check_deadline) { + cli_ctx.set_deadline(deadline); + } + + // Rather than using the original kMethodName, make a short-lived + // copy to also confirm that we don't refer to this object beyond + // the initial call preparation + const std::string* method_name = new std::string(kMethodName); + + std::unique_ptr call = + generic_stub_->PrepareCall(&cli_ctx, *method_name, &cli_cq_); + + delete method_name; // Make sure that this is not needed after invocation + + std::thread request_call([this]() { server_ok(4); }); + call->StartCall(tag(1)); + client_ok(1); + std::unique_ptr send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. + send_buffer.reset(); + client_ok(2); + call->WritesDone(tag(3)); + client_ok(3); + + generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(), + srv_cq_.get(), tag(4)); + + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + if (check_deadline) { + EXPECT_TRUE(gpr_time_similar(deadline, srv_ctx.raw_deadline(), + gpr_time_from_millis(1000, GPR_TIMESPAN))); + } + + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + server_ok(5); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + server_ok(6); + + stream.Finish(Status::OK, tag(7)); + server_ok(7); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + client_ok(8); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + client_ok(9); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + // Return errors to up to one call that comes in on the supplied completion + // queue, until the CQ is being shut down (and therefore we can no longer + // enqueue further events). + void DriveCompletionQueue() { + enum class Event : uintptr_t { + kCallReceived, + kResponseSent, + }; + // Request the call, but only if the main thread hasn't beaten us to + // shutting down the CQ. + grpc::GenericServerContext server_context; + grpc::GenericServerAsyncReaderWriter reader_writer(&server_context); + + { + std::lock_guard lock(shutting_down_mu_); + if (!shutting_down_) { + generic_service_.RequestCall( + &server_context, &reader_writer, srv_cq_.get(), srv_cq_.get(), + reinterpret_cast(Event::kCallReceived)); + } + } + // Process events. + { + Event event; + bool ok; + while (srv_cq_->Next(reinterpret_cast(&event), &ok)) { + std::lock_guard lock(shutting_down_mu_); + if (shutting_down_) { + // The main thread has started shutting down. Simply continue to drain + // events. + continue; + } + + switch (event) { + case Event::kCallReceived: + reader_writer.Finish( + ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "go away"), + reinterpret_cast(Event::kResponseSent)); + break; + + case Event::kResponseSent: + // We are done. 
+ break; + } + } + } + } + + CompletionQueue cli_cq_; + std::unique_ptr srv_cq_; + std::unique_ptr stub_; + std::unique_ptr generic_stub_; + std::unique_ptr server_; + AsyncGenericService generic_service_; + const std::string server_host_; + std::ostringstream server_address_; + bool shutting_down_; + bool shut_down_; + std::mutex shutting_down_mu_; +}; + +TEST_F(GenericEnd2endTest, SimpleRpc) { + ResetStub(); + SendRpc(1); +} + +TEST_F(GenericEnd2endTest, SequentialRpcs) { + ResetStub(); + SendRpc(10); +} + +TEST_F(GenericEnd2endTest, SequentialUnaryRpcs) { + ResetStub(); + const int num_rpcs = 10; + const std::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello world. Hello world. Hello world."); + + std::unique_ptr cli_send_buffer = + SerializeToByteBuffer(&send_request); + std::thread request_call([this]() { server_ok(4); }); + std::unique_ptr call = + generic_stub_->PrepareUnaryCall(&cli_ctx, kMethodName, *cli_send_buffer, + &cli_cq_); + call->StartCall(); + ByteBuffer cli_recv_buffer; + call->Finish(&cli_recv_buffer, &recv_status, tag(1)); + std::thread client_check([this] { client_ok(1); }); + + generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(), + srv_cq_.get(), tag(4)); + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + ByteBuffer srv_recv_buffer; + stream.Read(&srv_recv_buffer, tag(5)); + server_ok(5); + EXPECT_TRUE(ParseFromByteBuffer(&srv_recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + std::unique_ptr srv_send_buffer = + SerializeToByteBuffer(&send_response); + stream.Write(*srv_send_buffer, tag(6)); + server_ok(6); + + stream.Finish(Status::OK, tag(7)); + server_ok(7); + + client_check.join(); + EXPECT_TRUE(ParseFromByteBuffer(&cli_recv_buffer, &recv_response)); + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } +} + +// One ping, one pong. 
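+// One request and one response exchanged entirely through the untyped
+// GenericStub / AsyncGenericService surface: payloads travel as serialized
+// ByteBuffers and the method is addressed by its full name string, here the
+// BidiStream method of the echo service.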
+TEST_F(GenericEnd2endTest, SimpleBidiStreaming) { + ResetStub(); + + const std::string kMethodName( + "/grpc.cpp.test.util.EchoTestService/BidiStream"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter srv_stream(&srv_ctx); + + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + send_request.set_message("Hello"); + std::thread request_call([this]() { server_ok(2); }); + std::unique_ptr cli_stream = + generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_); + cli_stream->StartCall(tag(1)); + client_ok(1); + + generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(), + srv_cq_.get(), tag(2)); + request_call.join(); + + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + std::unique_ptr send_buffer = + SerializeToByteBuffer(&send_request); + cli_stream->Write(*send_buffer, tag(3)); + send_buffer.reset(); + client_ok(3); + + ByteBuffer recv_buffer; + srv_stream.Read(&recv_buffer, tag(4)); + server_ok(4); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + srv_stream.Write(*send_buffer, tag(5)); + send_buffer.reset(); + server_ok(5); + + cli_stream->Read(&recv_buffer, tag(6)); + client_ok(6); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + client_ok(7); + + srv_stream.Read(&recv_buffer, tag(8)); + server_fail(8); + + srv_stream.Finish(Status::OK, tag(9)); + server_ok(9); + + cli_stream->Finish(&recv_status, tag(10)); + client_ok(10); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +TEST_F(GenericEnd2endTest, Deadline) { + ResetStub(); + SendRpc(1, true, + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(10, GPR_TIMESPAN))); +} + +TEST_F(GenericEnd2endTest, ShortDeadline) { + ResetStub(); + + ClientContext cli_ctx; + EchoRequest request; + EchoResponse response; + + shutting_down_ = false; + std::thread driver([this] { DriveCompletionQueue(); }); + + request.set_message(""); + cli_ctx.set_deadline(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(500, GPR_TIMESPAN))); + Status s = stub_->Echo(&cli_ctx, request, &response); + EXPECT_FALSE(s.ok()); + { + std::lock_guard lock(shutting_down_mu_); + shutting_down_ = true; + } + ShutDownServerAndCQs(); + driver.join(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/grpclb_end2end_test.cc b/test/cpp/end2end/grpclb_end2end_test.cc new file mode 100644 index 00000000..11b8115e --- /dev/null +++ b/test/cpp/end2end/grpclb_end2end_test.cc @@ -0,0 +1,2038 @@ +// +// Copyright 2017 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/service_config/service_config.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" +#include "src/proto/grpc/lb/v1/load_balancer.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/resolve_localhost_ip46.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/counted_service.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_config.h" + +// TODO(dgq): Other scenarios in need of testing: +// - Send a serverlist with faulty ip:port addresses (port > 2^16, etc). +// - Test reception of invalid serverlist +// - Test against a non-LB server. +// - Random LB server closing the stream unexpectedly. +// +// Findings from end to end testing to be covered here: +// - Handling of LB servers restart, including reconnection after backing-off +// retries. +// - Destruction of load balanced channel (and therefore of grpclb instance) +// while: +// 1) the internal LB call is still active. This should work by virtue +// of the weak reference the LB call holds. The call should be terminated as +// part of the grpclb shutdown process. +// 2) the retry timer is active. Again, the weak reference it holds should +// prevent a premature call to \a glb_destroy. + +using std::chrono::system_clock; + +using grpc::lb::v1::LoadBalancer; +using grpc::lb::v1::LoadBalanceRequest; +using grpc::lb::v1::LoadBalanceResponse; + +namespace grpc { +namespace testing { +namespace { + +constexpr char kDefaultServiceConfig[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{} }\n" + " ]\n" + "}"; + +using BackendService = CountedService; +using BalancerService = CountedService; + +const char g_kCallCredsMdKey[] = "Balancer should not ..."; +const char g_kCallCredsMdValue[] = "... receive me"; + +class BackendServiceImpl : public BackendService { + public: + BackendServiceImpl() {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + // Backend should receive the call credentials metadata. 
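+    // g_kCallCredsMdKey / g_kCallCredsMdValue are attached by the client's call
+    // credentials; the mirror-image check in BalanceLoad() below verifies that
+    // the balancer, unlike the backend, never receives them.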
+ auto call_credentials_entry = + context->client_metadata().find(g_kCallCredsMdKey); + EXPECT_NE(call_credentials_entry, context->client_metadata().end()); + if (call_credentials_entry != context->client_metadata().end()) { + EXPECT_EQ(call_credentials_entry->second, g_kCallCredsMdValue); + } + IncreaseRequestCount(); + const auto status = TestServiceImpl::Echo(context, request, response); + IncreaseResponseCount(); + AddClient(context->peer()); + return status; + } + + void Start() {} + + void Shutdown() {} + + std::set clients() { + grpc::internal::MutexLock lock(&clients_mu_); + return clients_; + } + + private: + void AddClient(const std::string& client) { + grpc::internal::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + grpc::internal::Mutex clients_mu_; + std::set clients_ ABSL_GUARDED_BY(&clients_mu_); +}; + +std::string Ip4ToPackedString(const char* ip_str) { + struct in_addr ip4; + GPR_ASSERT(inet_pton(AF_INET, ip_str, &ip4) == 1); + return std::string(reinterpret_cast(&ip4), sizeof(ip4)); +} + +std::string Ip6ToPackedString(const char* ip_str) { + struct in6_addr ip6; + GPR_ASSERT(inet_pton(AF_INET6, ip_str, &ip6) == 1); + return std::string(reinterpret_cast(&ip6), sizeof(ip6)); +} + +struct ClientStats { + size_t num_calls_started = 0; + size_t num_calls_finished = 0; + size_t num_calls_finished_with_client_failed_to_send = 0; + size_t num_calls_finished_known_received = 0; + std::map drop_token_counts; + + ClientStats& operator+=(const ClientStats& other) { + num_calls_started += other.num_calls_started; + num_calls_finished += other.num_calls_finished; + num_calls_finished_with_client_failed_to_send += + other.num_calls_finished_with_client_failed_to_send; + num_calls_finished_known_received += + other.num_calls_finished_known_received; + for (const auto& p : other.drop_token_counts) { + drop_token_counts[p.first] += p.second; + } + return *this; + } + + void Reset() { + num_calls_started = 0; + num_calls_finished = 0; + num_calls_finished_with_client_failed_to_send = 0; + num_calls_finished_known_received = 0; + drop_token_counts.clear(); + } +}; + +class BalancerServiceImpl : public BalancerService { + public: + using Stream = ServerReaderWriter; + using ResponseDelayPair = std::pair; + + explicit BalancerServiceImpl(int client_load_reporting_interval_seconds) + : client_load_reporting_interval_seconds_( + client_load_reporting_interval_seconds) {} + + Status BalanceLoad(ServerContext* context, Stream* stream) override { + gpr_log(GPR_INFO, "LB[%p]: BalanceLoad", this); + { + grpc::internal::MutexLock lock(&mu_); + if (serverlist_done_) goto done; + } + { + // Balancer shouldn't receive the call credentials metadata. + EXPECT_EQ(context->client_metadata().find(g_kCallCredsMdKey), + context->client_metadata().end()); + LoadBalanceRequest request; + std::vector responses_and_delays; + + if (!stream->Read(&request)) { + goto done; + } else { + if (request.has_initial_request()) { + grpc::internal::MutexLock lock(&mu_); + service_names_.push_back(request.initial_request().name()); + } + } + IncreaseRequestCount(); + gpr_log(GPR_INFO, "LB[%p]: received initial message '%s'", this, + request.DebugString().c_str()); + + // TODO(juanlishen): Initial response should always be the first response. 
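+      // When load reporting is enabled, the initial response advertises the
+      // reporting interval; the client is then expected to stream ClientStats
+      // messages back, which are collected into load_report_queue_ further down.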
+ if (client_load_reporting_interval_seconds_ > 0) { + LoadBalanceResponse initial_response; + initial_response.mutable_initial_response() + ->mutable_client_stats_report_interval() + ->set_seconds(client_load_reporting_interval_seconds_); + stream->Write(initial_response); + } + + { + grpc::internal::MutexLock lock(&mu_); + responses_and_delays = responses_and_delays_; + } + for (const auto& response_and_delay : responses_and_delays) { + SendResponse(stream, response_and_delay.first, + response_and_delay.second); + } + { + grpc::internal::MutexLock lock(&mu_); + while (!serverlist_done_) { + serverlist_cond_.Wait(&mu_); + } + } + + if (client_load_reporting_interval_seconds_ > 0) { + request.Clear(); + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "LB[%p]: received client load report message '%s'", + this, request.DebugString().c_str()); + GPR_ASSERT(request.has_client_stats()); + ClientStats load_report; + load_report.num_calls_started = + request.client_stats().num_calls_started(); + load_report.num_calls_finished = + request.client_stats().num_calls_finished(); + load_report.num_calls_finished_with_client_failed_to_send = + request.client_stats() + .num_calls_finished_with_client_failed_to_send(); + load_report.num_calls_finished_known_received = + request.client_stats().num_calls_finished_known_received(); + for (const auto& drop_token_count : + request.client_stats().calls_finished_with_drop()) { + load_report + .drop_token_counts[drop_token_count.load_balance_token()] = + drop_token_count.num_calls(); + } + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. + grpc::internal::MutexLock lock(&mu_); + load_report_queue_.emplace_back(std::move(load_report)); + load_report_cond_.Signal(); + } + } + } + done: + gpr_log(GPR_INFO, "LB[%p]: done", this); + return Status::OK; + } + + void add_response(const LoadBalanceResponse& response, int send_after_ms) { + grpc::internal::MutexLock lock(&mu_); + responses_and_delays_.push_back(std::make_pair(response, send_after_ms)); + } + + void Start() { + grpc::internal::MutexLock lock(&mu_); + serverlist_done_ = false; + responses_and_delays_.clear(); + load_report_queue_.clear(); + } + + void Shutdown() { + NotifyDoneWithServerlists(); + gpr_log(GPR_INFO, "LB[%p]: shut down", this); + } + + ClientStats WaitForLoadReport() { + grpc::internal::MutexLock lock(&mu_); + if (load_report_queue_.empty()) { + while (load_report_queue_.empty()) { + load_report_cond_.Wait(&mu_); + } + } + ClientStats load_report = std::move(load_report_queue_.front()); + load_report_queue_.pop_front(); + return load_report; + } + + void NotifyDoneWithServerlists() { + grpc::internal::MutexLock lock(&mu_); + if (!serverlist_done_) { + serverlist_done_ = true; + serverlist_cond_.SignalAll(); + } + } + + std::vector service_names() { + grpc::internal::MutexLock lock(&mu_); + return service_names_; + } + + private: + void SendResponse(Stream* stream, const LoadBalanceResponse& response, + int delay_ms) { + gpr_log(GPR_INFO, "LB[%p]: sleeping for %d ms...", this, delay_ms); + if (delay_ms > 0) { + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms)); + } + gpr_log(GPR_INFO, "LB[%p]: Woke up! 
Sending response '%s'", this, + response.DebugString().c_str()); + IncreaseResponseCount(); + stream->Write(response); + } + + const int client_load_reporting_interval_seconds_; + std::vector responses_and_delays_; + std::vector service_names_; + + grpc::internal::Mutex mu_; + grpc::internal::CondVar serverlist_cond_; + bool serverlist_done_ ABSL_GUARDED_BY(mu_) = false; + grpc::internal::CondVar load_report_cond_; + std::deque load_report_queue_ ABSL_GUARDED_BY(mu_); +}; + +class GrpclbEnd2endTest : public ::testing::Test { + protected: + GrpclbEnd2endTest(size_t num_backends, size_t num_balancers, + int client_load_reporting_interval_seconds) + : server_host_("localhost"), + num_backends_(num_backends), + num_balancers_(num_balancers), + client_load_reporting_interval_seconds_( + client_load_reporting_interval_seconds) {} + + static void SetUpTestCase() { + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + grpc_init(); + } + + static void TearDownTestCase() { grpc_shutdown(); } + + void SetUp() override { + bool localhost_resolves_to_ipv4 = false; + bool localhost_resolves_to_ipv6 = false; + grpc_core::LocalhostResolves(&localhost_resolves_to_ipv4, + &localhost_resolves_to_ipv6); + ipv6_only_ = !localhost_resolves_to_ipv4 && localhost_resolves_to_ipv6; + response_generator_ = + grpc_core::MakeRefCounted(); + // Start the backends. + for (size_t i = 0; i < num_backends_; ++i) { + backends_.emplace_back(new ServerThread("backend")); + backends_.back()->Start(server_host_); + } + // Start the load balancers. + for (size_t i = 0; i < num_balancers_; ++i) { + balancers_.emplace_back(new ServerThread( + "balancer", client_load_reporting_interval_seconds_)); + balancers_.back()->Start(server_host_); + } + ResetStub(); + } + + void TearDown() override { + ShutdownAllBackends(); + for (auto& balancer : balancers_) balancer->Shutdown(); + } + + void StartAllBackends() { + for (auto& backend : backends_) backend->Start(server_host_); + } + + void StartBackend(size_t index) { backends_[index]->Start(server_host_); } + + void ShutdownAllBackends() { + for (auto& backend : backends_) backend->Shutdown(); + } + + void ShutdownBackend(size_t index) { backends_[index]->Shutdown(); } + + void ResetStub(int fallback_timeout = 0, + const std::string& expected_targets = "", + int subchannel_cache_delay_ms = 0) { + ChannelArguments args; + if (fallback_timeout > 0) args.SetGrpclbFallbackTimeout(fallback_timeout); + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + if (!expected_targets.empty()) { + args.SetString(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS, expected_targets); + } + if (subchannel_cache_delay_ms > 0) { + args.SetInt(GRPC_ARG_GRPCLB_SUBCHANNEL_CACHE_INTERVAL_MS, + subchannel_cache_delay_ms); + } + std::ostringstream uri; + uri << "fake:///" << kApplicationTargetName_; + // TODO(dgq): templatize tests to run everything using both secure and + // insecure channel credentials. 
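+ // An insecure variant of the channel created below would look roughly
+ // like this (sketch only; the test currently always uses the fake
+ // transport security credentials):
+ //
+ //   channel_ = ::grpc::CreateCustomChannel(
+ //       uri.str(), InsecureChannelCredentials(), args);
+ //   stub_ = grpc::testing::EchoTestService::NewStub(channel_);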
+ grpc_channel_credentials* channel_creds = + grpc_fake_transport_security_credentials_create(); + grpc_call_credentials* call_creds = grpc_md_only_test_credentials_create( + g_kCallCredsMdKey, g_kCallCredsMdValue, false); + std::shared_ptr creds( + new SecureChannelCredentials(grpc_composite_channel_credentials_create( + channel_creds, call_creds, nullptr))); + call_creds->Unref(); + channel_creds->Unref(); + channel_ = ::grpc::CreateCustomChannel(uri.str(), creds, args); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void ResetBackendCounters() { + for (auto& backend : backends_) backend->service_.ResetCounters(); + } + + ClientStats WaitForLoadReports() { + ClientStats client_stats; + for (auto& balancer : balancers_) { + client_stats += balancer->service_.WaitForLoadReport(); + } + return client_stats; + } + + bool SeenAllBackends(size_t start_index = 0, size_t stop_index = 0) { + if (stop_index == 0) stop_index = backends_.size(); + for (size_t i = start_index; i < stop_index; ++i) { + if (backends_[i]->service_.request_count() == 0) return false; + } + return true; + } + + void SendRpcAndCount(int* num_total, int* num_ok, int* num_failure, + int* num_drops) { + const Status status = SendRpc(); + if (status.ok()) { + ++*num_ok; + } else { + if (status.error_message() == "drop directed by grpclb balancer") { + ++*num_drops; + } else { + ++*num_failure; + } + } + ++*num_total; + } + + std::tuple WaitForAllBackends(int num_requests_multiple_of = 1, + size_t start_index = 0, + size_t stop_index = 0) { + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + int num_total = 0; + while (!SeenAllBackends(start_index, stop_index)) { + SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops); + } + while (num_total % num_requests_multiple_of != 0) { + SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops); + } + ResetBackendCounters(); + gpr_log(GPR_INFO, + "Performed %d warm up requests (a multiple of %d) against the " + "backends. %d succeeded, %d failed, %d dropped.", + num_total, num_requests_multiple_of, num_ok, num_failure, + num_drops); + return std::make_tuple(num_ok, num_failure, num_drops); + } + + void WaitForBackend(size_t backend_idx) { + do { + (void)SendRpc(); + } while (backends_[backend_idx]->service_.request_count() == 0); + ResetBackendCounters(); + } + + struct AddressData { + int port; + std::string balancer_name; + }; + + grpc_core::ServerAddressList CreateLbAddressesFromAddressDataList( + const std::vector& address_data) { + grpc_core::ServerAddressList addresses; + for (const auto& addr : address_data) { + absl::StatusOr lb_uri = + grpc_core::URI::Parse(absl::StrCat( + ipv6_only_ ? 
"ipv6:[::1]:" : "ipv4:127.0.0.1:", addr.port)); + GPR_ASSERT(lb_uri.ok()); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*lb_uri, &address)); + grpc_arg arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_DEFAULT_AUTHORITY), + const_cast(addr.balancer_name.c_str())); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(nullptr, &arg, 1); + addresses.emplace_back(address.addr, address.len, args); + } + return addresses; + } + + grpc_core::Resolver::Result MakeResolverResult( + const std::vector& balancer_address_data, + const std::vector& backend_address_data = {}, + const char* service_config_json = kDefaultServiceConfig) { + grpc_core::Resolver::Result result; + result.addresses = + CreateLbAddressesFromAddressDataList(backend_address_data); + grpc_error_handle error = GRPC_ERROR_NONE; + result.service_config = + grpc_core::ServiceConfig::Create(nullptr, service_config_json, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ServerAddressList balancer_addresses = + CreateLbAddressesFromAddressDataList(balancer_address_data); + grpc_arg arg = CreateGrpclbBalancerAddressesArg(&balancer_addresses); + result.args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + return result; + } + + void SetNextResolutionAllBalancers( + const char* service_config_json = kDefaultServiceConfig) { + std::vector addresses; + for (size_t i = 0; i < balancers_.size(); ++i) { + addresses.emplace_back(AddressData{balancers_[i]->port_, ""}); + } + SetNextResolution(addresses, {}, service_config_json); + } + + void SetNextResolution( + const std::vector& balancer_address_data, + const std::vector& backend_address_data = {}, + const char* service_config_json = kDefaultServiceConfig) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = MakeResolverResult( + balancer_address_data, backend_address_data, service_config_json); + response_generator_->SetResponse(std::move(result)); + } + + void SetNextReresolutionResponse( + const std::vector& balancer_address_data, + const std::vector& backend_address_data = {}, + const char* service_config_json = kDefaultServiceConfig) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = MakeResolverResult( + balancer_address_data, backend_address_data, service_config_json); + response_generator_->SetReresolutionResponse(std::move(result)); + } + + std::vector GetBackendPorts(size_t start_index = 0, + size_t stop_index = 0) const { + if (stop_index == 0) stop_index = backends_.size(); + std::vector backend_ports; + for (size_t i = start_index; i < stop_index; ++i) { + backend_ports.push_back(backends_[i]->port_); + } + return backend_ports; + } + + void ScheduleResponseForBalancer(size_t i, + const LoadBalanceResponse& response, + int delay_ms) { + balancers_[i]->service_.add_response(response, delay_ms); + } + + LoadBalanceResponse BuildResponseForBackends( + const std::vector& backend_ports, + const std::map& drop_token_counts) { + LoadBalanceResponse response; + for (const auto& drop_token_count : drop_token_counts) { + for (size_t i = 0; i < drop_token_count.second; ++i) { + auto* server = response.mutable_server_list()->add_servers(); + server->set_drop(true); + server->set_load_balance_token(drop_token_count.first); + } + } + for (const int& backend_port : backend_ports) { + auto* server = response.mutable_server_list()->add_servers(); + server->set_ip_address(ipv6_only_ ? 
Ip6ToPackedString("::1") + : Ip4ToPackedString("127.0.0.1")); + server->set_port(backend_port); + static int token_count = 0; + server->set_load_balance_token( + absl::StrFormat("token%03d", ++token_count)); + } + return response; + } + + Status SendRpc(EchoResponse* response = nullptr, int timeout_ms = 1000, + bool wait_for_ready = false, + const Status& expected_status = Status::OK) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + EchoRequest request; + request.set_message(kRequestMessage_); + if (!expected_status.ok()) { + auto* error = request.mutable_param()->mutable_expected_error(); + error->set_code(expected_status.error_code()); + error->set_error_message(expected_status.error_message()); + } + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + if (wait_for_ready) context.set_wait_for_ready(true); + Status status = stub_->Echo(&context, request, response); + if (local_response) delete response; + return status; + } + + void CheckRpcSendOk(const size_t times = 1, const int timeout_ms = 1000, + bool wait_for_ready = false) { + for (size_t i = 0; i < times; ++i) { + EchoResponse response; + const Status status = SendRpc(&response, timeout_ms, wait_for_ready); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage_); + } + } + + void CheckRpcSendFailure() { + const Status status = SendRpc(); + EXPECT_FALSE(status.ok()); + } + + template + struct ServerThread { + template + explicit ServerThread(const std::string& type, Args&&... args) + : port_(grpc_pick_unused_port_or_die()), + type_(type), + service_(std::forward(args)...) {} + + void Start(const std::string& server_host) { + gpr_log(GPR_INFO, "starting %s server on port %d", type_.c_str(), port_); + GPR_ASSERT(!running_); + running_ = true; + service_.Start(); + grpc::internal::Mutex mu; + // We need to acquire the lock here in order to prevent the notify_one + // by ServerThread::Serve from firing before the wait below is hit. + grpc::internal::MutexLock lock(&mu); + grpc::internal::CondVar cond; + thread_ = absl::make_unique( + std::bind(&ServerThread::Serve, this, server_host, &mu, &cond)); + cond.Wait(&mu); + gpr_log(GPR_INFO, "%s server startup complete", type_.c_str()); + } + + void Serve(const std::string& server_host, grpc::internal::Mutex* mu, + grpc::internal::CondVar* cond) { + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. 
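+ // (Start() holds *mu from before this thread is spawned until it calls
+ // cond.Wait(&mu), so the MutexLock below, and therefore the Signal()
+ // call, cannot run until Start() is actually waiting. That rules out the
+ // lost wakeup where the server finishes starting before Start() begins
+ // to wait.)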
+ grpc::internal::MutexLock lock(mu); + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + std::shared_ptr creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + cond->Signal(); + } + + void Shutdown() { + if (!running_) return; + gpr_log(GPR_INFO, "%s about to shutdown", type_.c_str()); + service_.Shutdown(); + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + gpr_log(GPR_INFO, "%s shutdown completed", type_.c_str()); + running_ = false; + } + + const int port_; + std::string type_; + T service_; + std::unique_ptr server_; + std::unique_ptr thread_; + bool running_ = false; + }; + + const std::string server_host_; + const size_t num_backends_; + const size_t num_balancers_; + const int client_load_reporting_interval_seconds_; + bool ipv6_only_ = false; + std::shared_ptr channel_; + std::unique_ptr stub_; + std::vector>> backends_; + std::vector>> balancers_; + grpc_core::RefCountedPtr + response_generator_; + const std::string kRequestMessage_ = "Live long and prosper."; + const std::string kApplicationTargetName_ = "application_target_name"; +}; + +class SingleBalancerTest : public GrpclbEnd2endTest { + public: + SingleBalancerTest() : GrpclbEnd2endTest(4, 1, 0) {} +}; + +TEST_F(SingleBalancerTest, Vanilla) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SubchannelCaching) { + ResetStub(/*fallback_timeout=*/0, /*expected_targets=*/"", + /*subchannel_cache_delay_ms=*/1500); + SetNextResolutionAllBalancers(); + // Initially send all backends. + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + // Then remove backends 0 and 1. + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(2), {}), 1000); + // Now re-add backend 1. + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(1), {}), 1000); + // Wait for all backends to come online. + WaitForAllBackends(); + // Send RPCs for long enough to get all responses. + gpr_timespec deadline = grpc_timeout_milliseconds_to_deadline(3000); + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), deadline) < 0); + // Backend 0 should have received less traffic than the others. + // Backend 1 would have received less traffic than 2 and 3. 
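+ // Rough timeline behind those expectations, given the 1500 ms subchannel
+ // cache configured above (delays are approximate):
+ //   t ~ 0 ms     serverlist {0, 1, 2, 3}
+ //   t ~ 1000 ms  serverlist {2, 3}; backends 0 and 1 are dropped, but
+ //                their subchannels stay cached for another 1500 ms
+ //   t ~ 2000 ms  serverlist {1, 2, 3}; backend 1 is re-added while its
+ //                cached subchannel, and hence its connection, is still
+ //                alive, which is why clients() below has size 1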
+ gpr_log(GPR_INFO, "BACKEND 0: %" PRIuPTR " requests", + backends_[0]->service_.request_count()); + EXPECT_GT(backends_[0]->service_.request_count(), 0); + for (size_t i = 1; i < backends_.size(); ++i) { + gpr_log(GPR_INFO, "BACKEND %" PRIuPTR ": %" PRIuPTR " requests", i, + backends_[i]->service_.request_count()); + EXPECT_GT(backends_[i]->service_.request_count(), + backends_[0]->service_.request_count()) + << "backend " << i; + if (i >= 2) { + EXPECT_GT(backends_[i]->service_.request_count(), + backends_[1]->service_.request_count()) + << "backend " << i; + } + } + // Backend 1 should never have lost its connection from the client. + EXPECT_EQ(1UL, backends_[1]->service_.clients().size()); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // And sent 3 responses. + EXPECT_EQ(3U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, ReturnServerStatus) { + SetNextResolutionAllBalancers(); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send a request that the backend will fail, and make sure we get + // back the right status. + Status expected(StatusCode::INVALID_ARGUMENT, "He's dead, Jim!"); + Status actual = SendRpc(/*response=*/nullptr, /*timeout_ms=*/1000, + /*wait_for_ready=*/false, expected); + EXPECT_EQ(actual.error_code(), expected.error_code()); + EXPECT_EQ(actual.error_message(), expected.error_message()); +} + +TEST_F(SingleBalancerTest, SelectGrpclbWithMigrationServiceConfig) { + SetNextResolutionAllBalancers( + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"grpclb\":{} }\n" + " ]\n" + "}"); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + CheckRpcSendOk(1, 1000 /* timeout_ms */, true /* wait_for_ready */); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, + SelectGrpclbWithMigrationServiceConfigAndNoAddresses) { + const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + SetNextResolution({}, {}, + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"grpclb\":{} }\n" + " ]\n" + "}"); + // Try to connect. + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(true)); + // Should go into state TRANSIENT_FAILURE when we enter fallback mode. + const gpr_timespec deadline = grpc_timeout_seconds_to_deadline(1); + grpc_connectivity_state state; + while ((state = channel_->GetState(false)) != + GRPC_CHANNEL_TRANSIENT_FAILURE) { + ASSERT_TRUE(channel_->WaitForStateChange(state, deadline)); + } + // Check LB policy name for the channel. 
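+ // ("does_not_exist" is skipped because the first policy in
+ // loadBalancingConfig that the client recognizes is the one selected, so
+ // grpclb is chosen even though it is listed second.)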
+ EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, UsePickFirstChildPolicy) { + SetNextResolutionAllBalancers( + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{\n" + " \"childPolicy\":[\n" + " { \"pick_first\":{} }\n" + " ]\n" + " } }\n" + " ]\n" + "}"); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + const size_t kNumRpcs = num_backends_ * 2; + CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Check that all requests went to the first backend. This verifies + // that we used pick_first instead of round_robin as the child policy. + EXPECT_EQ(backends_[0]->service_.request_count(), kNumRpcs); + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(backends_[i]->service_.request_count(), 0UL); + } + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SwapChildPolicy) { + SetNextResolutionAllBalancers( + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{\n" + " \"childPolicy\":[\n" + " { \"pick_first\":{} }\n" + " ]\n" + " } }\n" + " ]\n" + "}"); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + const size_t kNumRpcs = num_backends_ * 2; + CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */); + // Check that all requests went to the first backend. This verifies + // that we used pick_first instead of round_robin as the child policy. + EXPECT_EQ(backends_[0]->service_.request_count(), kNumRpcs); + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(backends_[i]->service_.request_count(), 0UL); + } + // Send new resolution that removes child policy from service config. + SetNextResolutionAllBalancers(); + WaitForAllBackends(); + CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */); + // Check that every backend saw the same number of requests. This verifies + // that we used round_robin. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(backends_[i]->service_.request_count(), 2UL); + } + // Done. + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SameBackendListedMultipleTimes) { + SetNextResolutionAllBalancers(); + // Same backend listed twice. + std::vector ports; + ports.push_back(backends_[0]->port_); + ports.push_back(backends_[0]->port_); + const size_t kNumRpcsPerAddress = 10; + ScheduleResponseForBalancer(0, BuildResponseForBackends(ports, {}), 0); + // We need to wait for the backend to come online. + WaitForBackend(0); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * ports.size()); + // Backend should have gotten 20 requests. + EXPECT_EQ(kNumRpcsPerAddress * 2, backends_[0]->service_.request_count()); + // And they should have come from a single client port, because of + // subchannel sharing. 
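+ // (Both serverlist entries resolve to the same ip:port, so they share one
+ // subchannel and therefore one connection; the backend records a single
+ // peer string, which is what the clients() check below asserts.)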
+ EXPECT_EQ(1UL, backends_[0]->service_.clients().size()); + balancers_[0]->service_.NotifyDoneWithServerlists(); +} + +TEST_F(SingleBalancerTest, SecureNaming) { + ResetStub(0, kApplicationTargetName_ + ";lb"); + SetNextResolution({AddressData{balancers_[0]->port_, "lb"}}); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SecureNamingDeathTest) { + GRPC_GTEST_FLAG_SET_DEATH_TEST_STYLE("threadsafe"); + // Make sure that we blow up (via abort() from the security connector) when + // the name from the balancer doesn't match expectations. + ASSERT_DEATH_IF_SUPPORTED( + { + ResetStub(0, kApplicationTargetName_ + ";lb"); + SetNextResolution({AddressData{balancers_[0]->port_, "woops"}}); + channel_->WaitForConnected(grpc_timeout_seconds_to_deadline(1)); + }, + ""); +} + +TEST_F(SingleBalancerTest, InitiallyEmptyServerlist) { + SetNextResolutionAllBalancers(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const int kCallDeadlineMs = kServerlistDelayMs * 2; + // First response is an empty serverlist, sent right away. + ScheduleResponseForBalancer(0, LoadBalanceResponse(), 0); + // Send non-empty serverlist only after kServerlistDelayMs + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), kServerlistDelayMs); + const auto t0 = system_clock::now(); + // Client will block: LB will initially send empty serverlist. + CheckRpcSendOk(1, kCallDeadlineMs, true /* wait_for_ready */); + const auto ellapsed_ms = + std::chrono::duration_cast( + system_clock::now() - t0); + // but eventually, the LB sends a serverlist update that allows the call to + // proceed. The call delay must be larger than the delay in sending the + // populated serverlist but under the call's deadline (which is enforced by + // the call's deadline). + EXPECT_GT(ellapsed_ms.count(), kServerlistDelayMs); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent two responses. + EXPECT_EQ(2U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, AllServersUnreachableFailFast) { + SetNextResolutionAllBalancers(); + const size_t kNumUnreachableServers = 5; + std::vector ports; + for (size_t i = 0; i < kNumUnreachableServers; ++i) { + ports.push_back(grpc_pick_unused_port_or_die()); + } + ScheduleResponseForBalancer(0, BuildResponseForBackends(ports, {}), 0); + const Status status = SendRpc(); + // The error shouldn't be DEADLINE_EXCEEDED. 
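+ // (Every port in the serverlist was obtained from
+ // grpc_pick_unused_port_or_die and has no listener behind it, so
+ // connection attempts fail quickly and the RPC fails with UNAVAILABLE
+ // instead of hanging until its deadline.)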
+ EXPECT_EQ(StatusCode::UNAVAILABLE, status.error_code()); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, Fallback) { + SetNextResolutionAllBalancers(); + const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const size_t kNumBackendsInResolution = backends_.size() / 2; + + ResetStub(kFallbackTimeoutMs); + std::vector balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector backend_addresses; + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + + // Send non-empty serverlist only after kServerlistDelayMs. + ScheduleResponseForBalancer( + 0, + BuildResponseForBackends( + GetBackendPorts(kNumBackendsInResolution /* start_index */), {}), + kServerlistDelayMs); + + // Wait until all the fallback backends are reachable. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + WaitForBackend(i); + } + + // The first request. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(kNumBackendsInResolution); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // Fallback is used: each backend returned by the resolver should have + // gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + + // Wait until the serverlist reception has been processed and all backends + // in the serverlist are reachable. + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + WaitForBackend(i); + } + + // Send out the second request. + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(backends_.size() - kNumBackendsInResolution); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + + // Serverlist is used: each backend returned by the balancer should + // have gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. 
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, FallbackUpdate) { + SetNextResolutionAllBalancers(); + const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const size_t kNumBackendsInResolution = backends_.size() / 3; + const size_t kNumBackendsInResolutionUpdate = backends_.size() / 3; + + ResetStub(kFallbackTimeoutMs); + std::vector balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector backend_addresses; + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + + // Send non-empty serverlist only after kServerlistDelayMs. + ScheduleResponseForBalancer( + 0, + BuildResponseForBackends( + GetBackendPorts(kNumBackendsInResolution + + kNumBackendsInResolutionUpdate /* start_index */), + {}), + kServerlistDelayMs); + + // Wait until all the fallback backends are reachable. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + WaitForBackend(i); + } + + // The first request. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(kNumBackendsInResolution); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // Fallback is used: each backend returned by the resolver should have + // gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + + balancer_addresses.clear(); + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + backend_addresses.clear(); + for (size_t i = kNumBackendsInResolution; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + + // Wait until the resolution update has been processed and all the new + // fallback backends are reachable. + for (size_t i = kNumBackendsInResolution; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + WaitForBackend(i); + } + + // Send out the second request. + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(kNumBackendsInResolutionUpdate); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + + // The resolution update is used: each backend in the resolution update should + // have gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate; + i < backends_.size(); ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + + // Wait until the serverlist reception has been processed and all backends + // in the serverlist are reachable. + for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate; + i < backends_.size(); ++i) { + WaitForBackend(i); + } + + // Send out the third request. 
+ gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH =========="); + CheckRpcSendOk(backends_.size() - kNumBackendsInResolution - + kNumBackendsInResolutionUpdate); + gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH =========="); + + // Serverlist is used: each backend returned by the balancer should + // have gotten one request. + for (size_t i = 0; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate; + i < backends_.size(); ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, + FallbackAfterStartup_LoseContactWithBalancerThenBackends) { + // First two backends are fallback, last two are pointed to by balancer. + const size_t kNumFallbackBackends = 2; + const size_t kNumBalancerBackends = backends_.size() - kNumFallbackBackends; + std::vector backend_addresses; + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + std::vector balancer_addresses; + for (size_t i = 0; i < balancers_.size(); ++i) { + balancer_addresses.emplace_back(AddressData{balancers_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(kNumFallbackBackends), {}), + 0); + // Try to connect. + channel_->GetState(true /* try_to_connect */); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); + // Stop balancer. RPCs should continue going to backends from balancer. + balancers_[0]->Shutdown(); + CheckRpcSendOk(100 * kNumBalancerBackends); + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + EXPECT_EQ(100UL, backends_[i]->service_.request_count()); + } + // Stop backends from balancer. This should put us in fallback mode. + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + ShutdownBackend(i); + } + WaitForAllBackends(1 /* num_requests_multiple_of */, 0 /* start_index */, + kNumFallbackBackends /* stop_index */); + // Restart the backends from the balancer. We should *not* start + // sending traffic back to them at this point (although the behavior + // in xds may be different). + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + StartBackend(i); + } + CheckRpcSendOk(100 * kNumBalancerBackends); + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + EXPECT_EQ(100UL, backends_[i]->service_.request_count()); + } + // Now start the balancer again. This should cause us to exit + // fallback mode. + balancers_[0]->Start(server_host_); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(kNumFallbackBackends), {}), + 0); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); +} + +TEST_F(SingleBalancerTest, + FallbackAfterStartup_LoseContactWithBackendsThenBalancer) { + // First two backends are fallback, last two are pointed to by balancer. 
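+ // (Fallback-after-startup behavior exercised here and in the previous
+ // test: once a serverlist has been received, the client re-enters
+ // fallback only after losing contact with both the balancer and all of
+ // the balancer-provided backends, regardless of the order in which the
+ // two happen.)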
+ const size_t kNumFallbackBackends = 2; + const size_t kNumBalancerBackends = backends_.size() - kNumFallbackBackends; + std::vector backend_addresses; + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + std::vector balancer_addresses; + for (size_t i = 0; i < balancers_.size(); ++i) { + balancer_addresses.emplace_back(AddressData{balancers_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(kNumFallbackBackends), {}), + 0); + // Try to connect. + channel_->GetState(true /* try_to_connect */); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); + // Stop backends from balancer. Since we are still in contact with + // the balancer at this point, RPCs should be failing. + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + ShutdownBackend(i); + } + CheckRpcSendFailure(); + // Stop balancer. This should put us in fallback mode. + balancers_[0]->Shutdown(); + WaitForAllBackends(1 /* num_requests_multiple_of */, 0 /* start_index */, + kNumFallbackBackends /* stop_index */); + // Restart the backends from the balancer. We should *not* start + // sending traffic back to them at this point (although the behavior + // in xds may be different). + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + StartBackend(i); + } + CheckRpcSendOk(100 * kNumBalancerBackends); + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + EXPECT_EQ(100UL, backends_[i]->service_.request_count()); + } + // Now start the balancer again. This should cause us to exit + // fallback mode. + balancers_[0]->Start(server_host_); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(kNumFallbackBackends), {}), + 0); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); +} + +TEST_F(SingleBalancerTest, FallbackEarlyWhenBalancerChannelFails) { + const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + // Return an unreachable balancer and one fallback backend. + std::vector balancer_addresses; + balancer_addresses.emplace_back( + AddressData{grpc_pick_unused_port_or_die(), ""}); + std::vector backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Send RPC with deadline less than the fallback timeout and make sure it + // succeeds. + CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000, + /* wait_for_ready */ false); +} + +TEST_F(SingleBalancerTest, FallbackEarlyWhenBalancerCallFails) { + const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + // Return one balancer and one fallback backend. + std::vector balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Balancer drops call without sending a serverlist. + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Send RPC with deadline less than the fallback timeout and make sure it + // succeeds. 
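+ // (The 1 second deadline is far below the 10 second fallback timeout, so
+ // this can only succeed because the balancer call ending without a
+ // serverlist sends the client into fallback, to backend 0, right away
+ // instead of waiting for the timer.)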
+ CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000, + /* wait_for_ready */ false); +} + +TEST_F(SingleBalancerTest, FallbackControlledByBalancer_BeforeFirstServerlist) { + const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + // Return one balancer and one fallback backend. + std::vector balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Balancer explicitly tells client to fallback. + LoadBalanceResponse resp; + resp.mutable_fallback_response(); + ScheduleResponseForBalancer(0, resp, 0); + // Send RPC with deadline less than the fallback timeout and make sure it + // succeeds. + CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000, + /* wait_for_ready */ false); +} + +TEST_F(SingleBalancerTest, FallbackControlledByBalancer_AfterFirstServerlist) { + // Return one balancer and one fallback backend (backend 0). + std::vector balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Balancer initially sends serverlist, then tells client to fall back, + // then sends the serverlist again. + // The serverlist points to backend 1. + LoadBalanceResponse serverlist_resp = + BuildResponseForBackends({backends_[1]->port_}, {}); + LoadBalanceResponse fallback_resp; + fallback_resp.mutable_fallback_response(); + ScheduleResponseForBalancer(0, serverlist_resp, 0); + ScheduleResponseForBalancer(0, fallback_resp, 100); + ScheduleResponseForBalancer(0, serverlist_resp, 100); + // Requests initially go to backend 1, then go to backend 0 in + // fallback mode, then go back to backend 1 when we exit fallback. + WaitForBackend(1); + WaitForBackend(0); + WaitForBackend(1); +} + +TEST_F(SingleBalancerTest, BackendsRestart) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + // Stop backends. RPCs should fail. + ShutdownAllBackends(); + CheckRpcSendFailure(); + // Restart backends. RPCs should start succeeding again. + StartAllBackends(); + CheckRpcSendOk(1 /* times */, 2000 /* timeout_ms */, + true /* wait_for_ready */); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, ServiceNameFromLbPolicyConfig) { + constexpr char kServiceConfigWithTarget[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{\n" + " \"serviceName\":\"test_service\"\n" + " }}\n" + " ]\n" + "}"; + + SetNextResolutionAllBalancers(kServiceConfigWithTarget); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. 
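+ // With the "serviceName" field in the config above, the initial
+ // LoadBalanceRequest is expected to carry that value rather than the
+ // channel target, roughly:
+ //
+ //   initial_request { name: "test_service" }
+ //
+ // which is what the service_names() check below asserts.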
+ WaitForAllBackends(); + EXPECT_EQ(balancers_[0]->service_.service_names().back(), "test_service"); +} + +class UpdatesTest : public GrpclbEnd2endTest { + public: + UpdatesTest() : GrpclbEnd2endTest(4, 3, 0) {} +}; + +TEST_F(UpdatesTest, UpdateBalancersButKeepUsingOriginalBalancer) { + SetNextResolutionAllBalancers(); + const std::vector first_backend{GetBackendPorts()[0]}; + const std::vector second_backend{GetBackendPorts()[1]}; + ScheduleResponseForBalancer(0, BuildResponseForBackends(first_backend, {}), + 0); + ScheduleResponseForBalancer(1, BuildResponseForBackends(second_backend, {}), + 0); + + // Wait until the first backend is ready. + WaitForBackend(0); + + // Send 10 requests. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Balancer 0 got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + std::vector addresses; + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // The current LB call is still working, so grpclb continued using it to the + // first balancer, which doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +// Send an update with the same set of LBs as the one in SetUp() in order to +// verify that the LB channel inside grpclb keeps the initial connection (which +// by definition is also present in the update). +TEST_F(UpdatesTest, UpdateBalancersRepeated) { + SetNextResolutionAllBalancers(); + const std::vector first_backend{GetBackendPorts()[0]}; + const std::vector second_backend{GetBackendPorts()[0]}; + + ScheduleResponseForBalancer(0, BuildResponseForBackends(first_backend, {}), + 0); + ScheduleResponseForBalancer(1, BuildResponseForBackends(second_backend, {}), + 0); + + // Wait until the first backend is ready. + WaitForBackend(0); + + // Send 10 requests. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Balancer 0 got a single request. 
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + std::vector addresses; + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + addresses.emplace_back(AddressData{balancers_[2]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // grpclb continued using the original LB call to the first balancer, which + // doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + balancers_[0]->service_.NotifyDoneWithServerlists(); + + addresses.clear(); + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 2 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 2 DONE =========="); + + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // grpclb continued using the original LB call to the first balancer, which + // doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + balancers_[0]->service_.NotifyDoneWithServerlists(); +} + +TEST_F(UpdatesTest, UpdateBalancersDeadUpdate) { + std::vector addresses; + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + SetNextResolution(addresses); + const std::vector first_backend{GetBackendPorts()[0]}; + const std::vector second_backend{GetBackendPorts()[1]}; + + ScheduleResponseForBalancer(0, BuildResponseForBackends(first_backend, {}), + 0); + ScheduleResponseForBalancer(1, BuildResponseForBackends(second_backend, {}), + 0); + + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Kill balancer 0 + gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************"); + balancers_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************"); + + // This is serviced by the existing RR policy + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should again have gone to the first backend. + EXPECT_EQ(20U, backends_[0]->service_.request_count()); + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + + // Balancer 0 got a single request. 
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + addresses.clear(); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + + // Wait until update has been processed, as signaled by the second backend + // receiving a request. In the meantime, the client continues to be serviced + // (by the first backend) without interruption. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + WaitForBackend(1); + + // This is serviced by the updated RR policy + backends_[1]->service_.ResetCounters(); + gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH =========="); + // All 10 requests should have gone to the second backend. + EXPECT_EQ(10U, backends_[1]->service_.request_count()); + + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // The second balancer, published as part of the first update, may end up + // getting two requests (that is, 1 <= #req <= 2) if the LB call retry timer + // firing races with the arrival of the update containing the second + // balancer. + EXPECT_GE(balancers_[1]->service_.request_count(), 1U); + EXPECT_GE(balancers_[1]->service_.response_count(), 1U); + EXPECT_LE(balancers_[1]->service_.request_count(), 2U); + EXPECT_LE(balancers_[1]->service_.response_count(), 2U); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +TEST_F(UpdatesTest, ReresolveDeadBackend) { + ResetStub(500); + // The first resolution contains the addresses of a balancer that never + // responds, and a fallback backend. + std::vector balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Ask channel to connect to trigger resolver creation. + channel_->GetState(true); + // The re-resolution result will contain the addresses of the same balancer + // and a new fallback backend. + balancer_addresses.clear(); + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + backend_addresses.clear(); + backend_addresses.emplace_back(AddressData{backends_[1]->port_, ""}); + SetNextReresolutionResponse(balancer_addresses, backend_addresses); + + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the fallback backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Kill backend 0. + gpr_log(GPR_INFO, "********** ABOUT TO KILL BACKEND 0 *************"); + backends_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BACKEND 0 *************"); + + // Wait until re-resolution has finished, as signaled by the second backend + // receiving a request. 
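+ // (Killing backend 0, the only usable address, makes the LB policy
+ // request re-resolution; the fake resolver then returns the result
+ // installed via SetNextReresolutionResponse() above, whose fallback
+ // backend is backend 1.)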
+ WaitForBackend(1); + + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should have gone to the second backend. + EXPECT_EQ(10U, backends_[1]->service_.request_count()); + + balancers_[0]->service_.NotifyDoneWithServerlists(); + balancers_[1]->service_.NotifyDoneWithServerlists(); + balancers_[2]->service_.NotifyDoneWithServerlists(); + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(0U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +// TODO(juanlishen): Should be removed when the first response is always the +// initial response. Currently, if client load reporting is not enabled, the +// balancer doesn't send initial response. When the backend shuts down, an +// unexpected re-resolution will happen. This test configuration is a workaround +// for test ReresolveDeadBalancer. +class UpdatesWithClientLoadReportingTest : public GrpclbEnd2endTest { + public: + UpdatesWithClientLoadReportingTest() : GrpclbEnd2endTest(4, 3, 2) {} +}; + +TEST_F(UpdatesWithClientLoadReportingTest, ReresolveDeadBalancer) { + const std::vector first_backend{GetBackendPorts()[0]}; + const std::vector second_backend{GetBackendPorts()[1]}; + ScheduleResponseForBalancer(0, BuildResponseForBackends(first_backend, {}), + 0); + ScheduleResponseForBalancer(1, BuildResponseForBackends(second_backend, {}), + 0); + + // Ask channel to connect to trigger resolver creation. + channel_->GetState(true); + std::vector addresses; + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + SetNextResolution(addresses); + addresses.clear(); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + SetNextReresolutionResponse(addresses); + + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Kill backend 0. + gpr_log(GPR_INFO, "********** ABOUT TO KILL BACKEND 0 *************"); + backends_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BACKEND 0 *************"); + + CheckRpcSendFailure(); + + // Balancer 0 got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + // Kill balancer 0. + gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************"); + balancers_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************"); + + // Wait until re-resolution has finished, as signaled by the second backend + // receiving a request. + WaitForBackend(1); + + // This is serviced by the new serverlist. 
+ gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should have gone to the second backend. + EXPECT_EQ(10U, backends_[1]->service_.request_count()); + + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // After balancer 0 is killed, we restart an LB call immediately (because we + // disconnect to a previously connected balancer). Although we will cancel + // this call when the re-resolution update is done and another LB call restart + // is needed, this old call may still succeed reaching the LB server if + // re-resolution is slow. So balancer 1 may have received 2 requests and sent + // 2 responses. + EXPECT_GE(balancers_[1]->service_.request_count(), 1U); + EXPECT_GE(balancers_[1]->service_.response_count(), 1U); + EXPECT_LE(balancers_[1]->service_.request_count(), 2U); + EXPECT_LE(balancers_[1]->service_.response_count(), 2U); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, Drop) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + const int num_of_drop_by_rate_limiting_addresses = 1; + const int num_of_drop_by_load_balancing_addresses = 2; + const int num_of_drop_addresses = num_of_drop_by_rate_limiting_addresses + + num_of_drop_by_load_balancing_addresses; + const int num_total_addresses = num_backends_ + num_of_drop_addresses; + ScheduleResponseForBalancer( + 0, + BuildResponseForBackends( + GetBackendPorts(), + {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 0); + // Wait until all backends are ready. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs for each server and drop address. + size_t num_drops = 0; + for (size_t i = 0; i < kNumRpcsPerAddress * num_total_addresses; ++i) { + EchoResponse response; + const Status status = SendRpc(&response); + if (!status.ok() && + status.error_message() == "drop directed by grpclb balancer") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage_); + } + } + EXPECT_EQ(kNumRpcsPerAddress * num_of_drop_addresses, num_drops); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, DropAllFirst) { + SetNextResolutionAllBalancers(); + // All registered addresses are marked as "drop". 
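For reference, the drop entries that BuildResponseForBackends attaches to a serverlist correspond to grpc.lb.v1 servers with the drop bit set; a minimal sketch of building such a response directly (assuming the generated load_balancer proto header path) is:

#include "src/proto/grpc/lb/v1/load_balancer.pb.h"

// Sketch: a serverlist containing only drop markers, in the spirit of
// BuildResponseForBackends({}, {{"rate_limiting", 1}, {"load_balancing", 1}}).
// Each dropped pick is later reported back to the balancer under its token.
grpc::lb::v1::LoadBalanceResponse BuildDropOnlyResponse() {
  grpc::lb::v1::LoadBalanceResponse response;
  auto* server_list = response.mutable_server_list();
  for (const char* token : {"rate_limiting", "load_balancing"}) {
    auto* server = server_list->add_servers();
    server->set_drop(true);                 // client must fail this pick
    server->set_load_balance_token(token);  // grouped in the load report
  }
  return response;
}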
+ const int num_of_drop_by_rate_limiting_addresses = 1; + const int num_of_drop_by_load_balancing_addresses = 1; + ScheduleResponseForBalancer( + 0, + BuildResponseForBackends( + {}, {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 0); + const Status status = SendRpc(nullptr, 1000, true); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), "drop directed by grpclb balancer"); +} + +TEST_F(SingleBalancerTest, DropAll) { + SetNextResolutionAllBalancers(); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + const int num_of_drop_by_rate_limiting_addresses = 1; + const int num_of_drop_by_load_balancing_addresses = 1; + ScheduleResponseForBalancer( + 0, + BuildResponseForBackends( + {}, {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 1000); + + // First call succeeds. + CheckRpcSendOk(); + // But eventually, the update with only dropped servers is processed and calls + // fail. + Status status; + do { + status = SendRpc(nullptr, 1000, true); + } while (status.ok()); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), "drop directed by grpclb balancer"); +} + +class SingleBalancerWithClientLoadReportingTest : public GrpclbEnd2endTest { + public: + SingleBalancerWithClientLoadReportingTest() : GrpclbEnd2endTest(4, 1, 3) {} +}; + +TEST_F(SingleBalancerWithClientLoadReportingTest, Vanilla) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(), {}), 0); + // Wait until all backends are ready. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + + ClientStats client_stats; + do { + client_stats += WaitForLoadReports(); + } while (client_stats.num_calls_finished != + kNumRpcsPerAddress * num_backends_ + num_ok); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok, + client_stats.num_calls_started); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok, + client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + (num_ok + num_drops), + client_stats.num_calls_finished_known_received); + EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre()); +} + +TEST_F(SingleBalancerWithClientLoadReportingTest, BalancerRestart) { + SetNextResolutionAllBalancers(); + const size_t kNumBackendsFirstPass = 2; + const size_t kNumBackendsSecondPass = + backends_.size() - kNumBackendsFirstPass; + // Balancer returns backends starting at index 1. + ScheduleResponseForBalancer( + 0, + BuildResponseForBackends(GetBackendPorts(0, kNumBackendsFirstPass), {}), + 0); + // Wait until all backends returned by the balancer are ready. 
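The `client_stats += WaitForLoadReports()` accumulation above presumes a field-wise merge of successive load reports; a minimal sketch of such an aggregate (hypothetical, mirroring only the fields these tests check, not the actual ClientStats helper) is:

#include <cstddef>
#include <map>
#include <string>

// Sketch: accumulate successive load reports field by field so that several
// reporting intervals can be checked against a single expected total.
struct ClientStatsSketch {
  size_t num_calls_started = 0;
  size_t num_calls_finished = 0;
  size_t num_calls_finished_with_client_failed_to_send = 0;
  size_t num_calls_finished_known_received = 0;
  std::map<std::string, size_t> drop_token_counts;

  ClientStatsSketch& operator+=(const ClientStatsSketch& other) {
    num_calls_started += other.num_calls_started;
    num_calls_finished += other.num_calls_finished;
    num_calls_finished_with_client_failed_to_send +=
        other.num_calls_finished_with_client_failed_to_send;
    num_calls_finished_known_received +=
        other.num_calls_finished_known_received;
    for (const auto& entry : other.drop_token_counts) {
      drop_token_counts[entry.first] += entry.second;
    }
    return *this;
  }
};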
+ int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = + WaitForAllBackends(/* num_requests_multiple_of */ 1, /* start_index */ 0, + /* stop_index */ kNumBackendsFirstPass); + balancers_[0]->service_.NotifyDoneWithServerlists(); + ClientStats client_stats = WaitForLoadReports(); + EXPECT_EQ(static_cast(num_ok), client_stats.num_calls_started); + EXPECT_EQ(static_cast(num_ok), client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(static_cast(num_ok), + client_stats.num_calls_finished_known_received); + EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre()); + // Shut down the balancer. + balancers_[0]->Shutdown(); + // Send 10 more requests per backend. This will continue using the + // last serverlist we received from the balancer before it was shut down. + ResetBackendCounters(); + CheckRpcSendOk(kNumBackendsFirstPass); + // Each backend should have gotten 1 request. + for (size_t i = 0; i < kNumBackendsFirstPass; ++i) { + EXPECT_EQ(1UL, backends_[i]->service_.request_count()); + } + // Now restart the balancer, this time pointing to all backends. + balancers_[0]->Start(server_host_); + ScheduleResponseForBalancer( + 0, BuildResponseForBackends(GetBackendPorts(kNumBackendsFirstPass), {}), + 0); + // Wait for queries to start going to one of the new backends. + // This tells us that we're now using the new serverlist. + do { + CheckRpcSendOk(); + } while (backends_[2]->service_.request_count() == 0 && + backends_[3]->service_.request_count() == 0); + // Send one RPC per backend. + CheckRpcSendOk(kNumBackendsSecondPass); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Check client stats. + client_stats = WaitForLoadReports(); + EXPECT_EQ(kNumBackendsSecondPass + 1, client_stats.num_calls_started); + EXPECT_EQ(kNumBackendsSecondPass + 1, client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(kNumBackendsSecondPass + 1, + client_stats.num_calls_finished_known_received); + EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre()); +} + +TEST_F(SingleBalancerWithClientLoadReportingTest, Drop) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 3; + const int num_of_drop_by_rate_limiting_addresses = 2; + const int num_of_drop_by_load_balancing_addresses = 1; + const int num_of_drop_addresses = num_of_drop_by_rate_limiting_addresses + + num_of_drop_by_load_balancing_addresses; + const int num_total_addresses = num_backends_ + num_of_drop_addresses; + ScheduleResponseForBalancer( + 0, + BuildResponseForBackends( + GetBackendPorts(), + {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 0); + // Wait until all backends are ready. 
+ int num_warmup_ok = 0; + int num_warmup_failure = 0; + int num_warmup_drops = 0; + std::tie(num_warmup_ok, num_warmup_failure, num_warmup_drops) = + WaitForAllBackends(num_total_addresses /* num_requests_multiple_of */); + const int num_total_warmup_requests = + num_warmup_ok + num_warmup_failure + num_warmup_drops; + size_t num_drops = 0; + for (size_t i = 0; i < kNumRpcsPerAddress * num_total_addresses; ++i) { + EchoResponse response; + const Status status = SendRpc(&response); + if (!status.ok() && + status.error_message() == "drop directed by grpclb balancer") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage_); + } + } + EXPECT_EQ(kNumRpcsPerAddress * num_of_drop_addresses, num_drops); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + + const ClientStats client_stats = WaitForLoadReports(); + EXPECT_EQ( + kNumRpcsPerAddress * num_total_addresses + num_total_warmup_requests, + client_stats.num_calls_started); + EXPECT_EQ( + kNumRpcsPerAddress * num_total_addresses + num_total_warmup_requests, + client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_warmup_ok, + client_stats.num_calls_finished_known_received); + // The number of warmup request is a multiple of the number of addresses. + // Therefore, all addresses in the scheduled balancer response are hit the + // same number of times. + const int num_times_drop_addresses_hit = + num_warmup_drops / num_of_drop_addresses; + EXPECT_THAT( + client_stats.drop_token_counts, + ::testing::ElementsAre( + ::testing::Pair("load_balancing", + (kNumRpcsPerAddress + num_times_drop_addresses_hit)), + ::testing::Pair( + "rate_limiting", + (kNumRpcsPerAddress + num_times_drop_addresses_hit) * 2))); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/end2end/health_service_end2end_test.cc b/test/cpp/end2end/health_service_end2end_test.cc new file mode 100644 index 00000000..efbdc36b --- /dev/null +++ b/test/cpp/end2end/health_service_end2end_test.cc @@ -0,0 +1,374 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/health/v1/health.grpc.pb.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_health_check_service_impl.h" +#include "test/cpp/end2end/test_service_impl.h" + +using grpc::health::v1::Health; +using grpc::health::v1::HealthCheckRequest; +using grpc::health::v1::HealthCheckResponse; + +namespace grpc { +namespace testing { +namespace { + +// A custom implementation of the health checking service interface. This is +// used to test that it prevents the server from creating a default service and +// also serves as an example of how to override the default service. +class CustomHealthCheckService : public HealthCheckServiceInterface { + public: + explicit CustomHealthCheckService(HealthCheckServiceImpl* impl) + : impl_(impl) { + impl_->SetStatus("", HealthCheckResponse::SERVING); + } + void SetServingStatus(const std::string& service_name, + bool serving) override { + impl_->SetStatus(service_name, serving ? HealthCheckResponse::SERVING + : HealthCheckResponse::NOT_SERVING); + } + + void SetServingStatus(bool serving) override { + impl_->SetAll(serving ? HealthCheckResponse::SERVING + : HealthCheckResponse::NOT_SERVING); + } + + void Shutdown() override { impl_->Shutdown(); } + + private: + HealthCheckServiceImpl* impl_; // not owned +}; + +class HealthServiceEnd2endTest : public ::testing::Test { + protected: + HealthServiceEnd2endTest() {} + + void SetUpServer(bool register_sync_test_service, bool add_async_cq, + bool explicit_health_service, + std::unique_ptr service) { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + + bool register_sync_health_service_impl = + explicit_health_service && service != nullptr; + + // Setup server + ServerBuilder builder; + if (explicit_health_service) { + std::unique_ptr option( + new HealthCheckServiceServerBuilderOption(std::move(service))); + builder.SetOption(std::move(option)); + } + builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + if (register_sync_test_service) { + // Register a sync service. + builder.RegisterService(&echo_test_service_); + } + if (register_sync_health_service_impl) { + builder.RegisterService(&health_check_service_impl_); + } + if (add_async_cq) { + cq_ = builder.AddCompletionQueue(); + } + server_ = builder.BuildAndStart(); + } + + void TearDown() override { + if (server_) { + server_->Shutdown(); + if (cq_ != nullptr) { + cq_->Shutdown(); + } + if (cq_thread_.joinable()) { + cq_thread_.join(); + } + } + } + + void ResetStubs() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + hc_stub_ = grpc::health::v1::Health::NewStub(channel); + } + + // When the expected_status is NOT OK, we do not care about the response. 
+ void SendHealthCheckRpc(const std::string& service_name, + const Status& expected_status) { + EXPECT_FALSE(expected_status.ok()); + SendHealthCheckRpc(service_name, expected_status, + HealthCheckResponse::UNKNOWN); + } + + void SendHealthCheckRpc( + const std::string& service_name, const Status& expected_status, + HealthCheckResponse::ServingStatus expected_serving_status) { + HealthCheckRequest request; + request.set_service(service_name); + HealthCheckResponse response; + ClientContext context; + Status s = hc_stub_->Check(&context, request, &response); + EXPECT_EQ(expected_status.error_code(), s.error_code()); + if (s.ok()) { + EXPECT_EQ(expected_serving_status, response.status()); + } + } + + void VerifyHealthCheckService() { + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service != nullptr); + const std::string kHealthyService("healthy_service"); + const std::string kUnhealthyService("unhealthy_service"); + const std::string kNotRegisteredService("not_registered"); + service->SetServingStatus(kHealthyService, true); + service->SetServingStatus(kUnhealthyService, false); + + ResetStubs(); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + + service->SetServingStatus(false); + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + } + + void VerifyHealthCheckServiceStreaming() { + const std::string kServiceName("service_name"); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + // Start Watch for service. + ClientContext context; + HealthCheckRequest request; + request.set_service(kServiceName); + std::unique_ptr<::grpc::ClientReaderInterface> reader = + hc_stub_->Watch(&context, request); + // Initial response will be SERVICE_UNKNOWN. + HealthCheckResponse response; + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVICE_UNKNOWN, response.status()); + response.Clear(); + // Now set service to NOT_SERVING and make sure we get an update. + service->SetServingStatus(kServiceName, false); + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.NOT_SERVING, response.status()); + response.Clear(); + // Now set service to SERVING and make sure we get another update. + service->SetServingStatus(kServiceName, true); + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVING, response.status()); + // Finish call. + context.TryCancel(); + } + + // Verify that after HealthCheckServiceInterface::Shutdown is called + // 1. unary client will see NOT_SERVING. + // 2. unary client still sees NOT_SERVING after a SetServing(true) is called. + // 3. streaming (Watch) client will see an update. + // 4. setting a new service to serving after shutdown will add the service + // name but return NOT_SERVING to client. + // This has to be called last. 
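The Watch semantics exercised below (one stream message per status change, including the update triggered by Shutdown) are easiest to see in a stripped-down client. A minimal sketch outside the test fixture, assuming only the generated health proto, is:

#include <cstdio>
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

#include "src/proto/grpc/health/v1/health.grpc.pb.h"

// Sketch: subscribe to serving-status updates for one service and print each
// update until the server ends the stream.
void WatchServingStatus(const std::shared_ptr<grpc::Channel>& channel,
                        const std::string& service_name) {
  auto stub = grpc::health::v1::Health::NewStub(channel);
  grpc::ClientContext context;
  grpc::health::v1::HealthCheckRequest request;
  request.set_service(service_name);
  auto reader = stub->Watch(&context, request);
  grpc::health::v1::HealthCheckResponse response;
  while (reader->Read(&response)) {  // blocks until the status changes
    std::printf("serving status: %d\n", response.status());
  }
  grpc::Status status = reader->Finish();
  (void)status;  // a real client would inspect this
}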
+ void VerifyHealthCheckServiceShutdown() { + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service != nullptr); + const std::string kHealthyService("healthy_service"); + const std::string kUnhealthyService("unhealthy_service"); + const std::string kNotRegisteredService("not_registered"); + const std::string kNewService("add_after_shutdown"); + service->SetServingStatus(kHealthyService, true); + service->SetServingStatus(kUnhealthyService, false); + + ResetStubs(); + + // Start Watch for service. + ClientContext context; + HealthCheckRequest request; + request.set_service(kHealthyService); + std::unique_ptr<::grpc::ClientReaderInterface> reader = + hc_stub_->Watch(&context, request); + + HealthCheckResponse response; + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVING, response.status()); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + SendHealthCheckRpc(kNewService, Status(StatusCode::NOT_FOUND, "")); + + // Shutdown health check service. + service->Shutdown(); + + // Watch client gets another update. + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.NOT_SERVING, response.status()); + // Finish Watch call. + context.TryCancel(); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + + // Setting status after Shutdown has no effect. + service->SetServingStatus(kHealthyService, true); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + + // Adding serving status for a new service after shutdown will return + // NOT_SERVING. + service->SetServingStatus(kNewService, true); + SendHealthCheckRpc(kNewService, Status::OK, + HealthCheckResponse::NOT_SERVING); + } + + TestServiceImpl echo_test_service_; + HealthCheckServiceImpl health_check_service_impl_; + std::unique_ptr hc_stub_; + std::unique_ptr cq_; + std::unique_ptr server_; + std::ostringstream server_address_; + std::thread cq_thread_; +}; + +TEST_F(HealthServiceEnd2endTest, DefaultHealthServiceDisabled) { + EnableDefaultHealthCheckService(false); + EXPECT_FALSE(DefaultHealthCheckServiceEnabled()); + SetUpServer(true, false, false, nullptr); + HealthCheckServiceInterface* default_service = + server_->GetHealthCheckService(); + EXPECT_TRUE(default_service == nullptr); + + ResetStubs(); + + SendHealthCheckRpc("", Status(StatusCode::UNIMPLEMENTED, "")); +} + +TEST_F(HealthServiceEnd2endTest, DefaultHealthService) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + SetUpServer(true, false, false, nullptr); + VerifyHealthCheckService(); + VerifyHealthCheckServiceStreaming(); + + // The default service has a size limit of the service name. 
+ const std::string kTooLongServiceName(201, 'x'); + SendHealthCheckRpc(kTooLongServiceName, + Status(StatusCode::INVALID_ARGUMENT, "")); +} + +TEST_F(HealthServiceEnd2endTest, DefaultHealthServiceShutdown) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + SetUpServer(true, false, false, nullptr); + VerifyHealthCheckServiceShutdown(); +} + +// Provide an empty service to disable the default service. +TEST_F(HealthServiceEnd2endTest, ExplicitlyDisableViaOverride) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + std::unique_ptr empty_service; + SetUpServer(true, false, true, std::move(empty_service)); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service == nullptr); + + ResetStubs(); + + SendHealthCheckRpc("", Status(StatusCode::UNIMPLEMENTED, "")); +} + +// Provide an explicit override of health checking service interface. +TEST_F(HealthServiceEnd2endTest, ExplicitlyOverride) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + std::unique_ptr override_service( + new CustomHealthCheckService(&health_check_service_impl_)); + HealthCheckServiceInterface* underlying_service = override_service.get(); + SetUpServer(false, false, true, std::move(override_service)); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service == underlying_service); + + ResetStubs(); + + VerifyHealthCheckService(); + VerifyHealthCheckServiceStreaming(); +} + +TEST_F(HealthServiceEnd2endTest, ExplicitlyHealthServiceShutdown) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + std::unique_ptr override_service( + new CustomHealthCheckService(&health_check_service_impl_)); + HealthCheckServiceInterface* underlying_service = override_service.get(); + SetUpServer(false, false, true, std::move(override_service)); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service == underlying_service); + + ResetStubs(); + + VerifyHealthCheckServiceShutdown(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/hybrid_end2end_test.cc b/test/cpp/end2end/hybrid_end2end_test.cc new file mode 100644 index 00000000..70c82c8d --- /dev/null +++ b/test/cpp/end2end/hybrid_end2end_test.cc @@ -0,0 +1,976 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +namespace grpc { +namespace testing { +namespace { + +void* tag(int i) { return reinterpret_cast(i); } + +bool VerifyReturnSuccess(CompletionQueue* cq, int i) { + void* got_tag; + bool ok; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + EXPECT_EQ(tag(i), got_tag); + return ok; +} + +void Verify(CompletionQueue* cq, int i, bool expect_ok) { + EXPECT_EQ(expect_ok, VerifyReturnSuccess(cq, i)); +} + +// Handlers to handle async request at a server. To be run in a separate thread. +template +void HandleEcho(Service* service, ServerCompletionQueue* cq, bool dup_service) { + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + EchoRequest recv_request; + EchoResponse send_response; + service->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq, cq, + tag(1)); + Verify(cq, 1, true); + send_response.set_message(recv_request.message()); + if (dup_service) { + send_response.mutable_message()->append("_dup"); + } + response_writer.Finish(send_response, Status::OK, tag(2)); + Verify(cq, 2, true); +} + +// Handlers to handle raw request at a server. To be run in a +// separate thread. Note that this is the same as the async version, except +// that the req/resp are ByteBuffers +template +void HandleRawEcho(Service* service, ServerCompletionQueue* cq, + bool /*dup_service*/) { + ServerContext srv_ctx; + GenericServerAsyncResponseWriter response_writer(&srv_ctx); + ByteBuffer recv_buffer; + service->RequestEcho(&srv_ctx, &recv_buffer, &response_writer, cq, cq, + tag(1)); + Verify(cq, 1, true); + EchoRequest recv_request; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EchoResponse send_response; + send_response.set_message(recv_request.message()); + auto send_buffer = SerializeToByteBuffer(&send_response); + response_writer.Finish(*send_buffer, Status::OK, tag(2)); + Verify(cq, 2, true); +} + +template +void HandleClientStreaming(Service* service, ServerCompletionQueue* cq) { + ServerContext srv_ctx; + EchoRequest recv_request; + EchoResponse send_response; + ServerAsyncReader srv_stream(&srv_ctx); + service->RequestRequestStream(&srv_ctx, &srv_stream, cq, cq, tag(1)); + Verify(cq, 1, true); + int i = 1; + do { + i++; + send_response.mutable_message()->append(recv_request.message()); + srv_stream.Read(&recv_request, tag(i)); + } while (VerifyReturnSuccess(cq, i)); + srv_stream.Finish(send_response, Status::OK, tag(100)); + Verify(cq, 100, true); +} + +template +void HandleRawClientStreaming(Service* service, ServerCompletionQueue* cq) { + ServerContext srv_ctx; + ByteBuffer recv_buffer; + EchoRequest recv_request; + EchoResponse send_response; + GenericServerAsyncReader srv_stream(&srv_ctx); + service->RequestRequestStream(&srv_ctx, &srv_stream, cq, cq, tag(1)); + Verify(cq, 1, true); + int i = 1; + while (true) { + i++; + srv_stream.Read(&recv_buffer, tag(i)); + if (!VerifyReturnSuccess(cq, i)) { + break; + } + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + send_response.mutable_message()->append(recv_request.message()); + } + auto send_buffer 
= SerializeToByteBuffer(&send_response); + srv_stream.Finish(*send_buffer, Status::OK, tag(100)); + Verify(cq, 100, true); +} + +template +void HandleServerStreaming(Service* service, ServerCompletionQueue* cq) { + ServerContext srv_ctx; + EchoRequest recv_request; + EchoResponse send_response; + ServerAsyncWriter srv_stream(&srv_ctx); + service->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, cq, cq, + tag(1)); + Verify(cq, 1, true); + send_response.set_message(recv_request.message() + "0"); + srv_stream.Write(send_response, tag(2)); + Verify(cq, 2, true); + send_response.set_message(recv_request.message() + "1"); + srv_stream.Write(send_response, tag(3)); + Verify(cq, 3, true); + send_response.set_message(recv_request.message() + "2"); + srv_stream.Write(send_response, tag(4)); + Verify(cq, 4, true); + srv_stream.Finish(Status::OK, tag(5)); + Verify(cq, 5, true); +} + +void HandleGenericEcho(GenericServerAsyncReaderWriter* stream, + CompletionQueue* cq) { + ByteBuffer recv_buffer; + stream->Read(&recv_buffer, tag(2)); + Verify(cq, 2, true); + EchoRequest recv_request; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EchoResponse send_response; + send_response.set_message(recv_request.message()); + auto send_buffer = SerializeToByteBuffer(&send_response); + stream->Write(*send_buffer, tag(3)); + Verify(cq, 3, true); + stream->Finish(Status::OK, tag(4)); + Verify(cq, 4, true); +} + +void HandleGenericRequestStream(GenericServerAsyncReaderWriter* stream, + CompletionQueue* cq) { + ByteBuffer recv_buffer; + EchoRequest recv_request; + EchoResponse send_response; + int i = 1; + while (true) { + i++; + stream->Read(&recv_buffer, tag(i)); + if (!VerifyReturnSuccess(cq, i)) { + break; + } + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + send_response.mutable_message()->append(recv_request.message()); + } + auto send_buffer = SerializeToByteBuffer(&send_response); + stream->Write(*send_buffer, tag(99)); + Verify(cq, 99, true); + stream->Finish(Status::OK, tag(100)); + Verify(cq, 100, true); +} + +// Request and handle one generic call. +void HandleGenericCall(AsyncGenericService* service, + ServerCompletionQueue* cq) { + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + service->RequestCall(&srv_ctx, &stream, cq, cq, tag(1)); + Verify(cq, 1, true); + if (srv_ctx.method() == "/grpc.testing.EchoTestService/Echo") { + HandleGenericEcho(&stream, cq); + } else if (srv_ctx.method() == + "/grpc.testing.EchoTestService/RequestStream") { + HandleGenericRequestStream(&stream, cq); + } else { // other methods not handled yet. + gpr_log(GPR_ERROR, "method: %s", srv_ctx.method().c_str()); + GPR_ASSERT(0); + } +} + +class TestServiceImplDupPkg + : public ::grpc::testing::duplicate::EchoTestService::Service { + public: + Status Echo(ServerContext* /*context*/, const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message() + "_dup"); + return Status::OK; + } +}; + +class HybridEnd2endTest : public ::testing::TestWithParam { + protected: + HybridEnd2endTest() {} + + static void SetUpTestCase() { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + + void SetUp() override { + inproc_ = (::testing::UnitTest::GetInstance() + ->current_test_info() + ->value_param() != nullptr) + ? 
GetParam() + : false; + } + + bool SetUpServer(::grpc::Service* service1, ::grpc::Service* service2, + AsyncGenericService* generic_service, + CallbackGenericService* callback_generic_service, + int max_message_size = 0) { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + // Always add a sync unimplemented service: we rely on having at least one + // synchronous method to get a listening cq + builder.RegisterService(&unimplemented_service_); + builder.RegisterService(service1); + if (service2) { + builder.RegisterService(service2); + } + if (generic_service) { + builder.RegisterAsyncGenericService(generic_service); + } + if (callback_generic_service) { + builder.RegisterCallbackGenericService(callback_generic_service); + } + + if (max_message_size != 0) { + builder.SetMaxMessageSize(max_message_size); + } + + // Create a separate cq for each potential handler. + for (int i = 0; i < 5; i++) { + cqs_.push_back(builder.AddCompletionQueue(false)); + } + server_ = builder.BuildAndStart(); + + // If there is a generic callback service, this setup is only successful if + // we have an iomgr that can run in the background or are inprocess + return !callback_generic_service || grpc_iomgr_run_in_background() || + inproc_; + } + + void TearDown() override { + if (server_) { + server_->Shutdown(); + } + void* ignored_tag; + bool ignored_ok; + for (auto it = cqs_.begin(); it != cqs_.end(); ++it) { + (*it)->Shutdown(); + while ((*it)->Next(&ignored_tag, &ignored_ok)) { + } + } + } + + void ResetStub() { + std::shared_ptr channel = + inproc_ ? server_->InProcessChannel(ChannelArguments()) + : grpc::CreateChannel(server_address_.str(), + InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + // Test all rpc methods. 
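The raw and generic handlers above move requests and responses as grpc::ByteBuffer and rely on the SerializeToByteBuffer/ParseFromByteBuffer test utilities; a self-contained sketch of that round trip (assumed helper names, not the ones from test/cpp/util) is:

#include <memory>
#include <string>
#include <vector>

#include <grpcpp/support/byte_buffer.h>
#include <grpcpp/support/slice.h>

#include <google/protobuf/message.h>

// Sketch: serialize a proto message into the ByteBuffer form used by raw and
// generic methods, and parse it back out.
std::unique_ptr<grpc::ByteBuffer> ProtoToByteBuffer(
    const google::protobuf::Message& msg) {
  grpc::Slice slice(msg.SerializeAsString());
  return std::unique_ptr<grpc::ByteBuffer>(new grpc::ByteBuffer(&slice, 1));
}

bool ByteBufferToProto(const grpc::ByteBuffer& buffer,
                       google::protobuf::Message* msg) {
  std::vector<grpc::Slice> slices;
  if (!buffer.Dump(&slices).ok()) return false;
  std::string bytes;
  for (const auto& slice : slices) {
    bytes.append(reinterpret_cast<const char*>(slice.begin()), slice.size());
  }
  return msg->ParseFromString(bytes);
}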
+ void TestAllMethods() { + SendEcho(); + SendSimpleClientStreaming(); + SendSimpleServerStreaming(); + SendBidiStreaming(); + } + + void SendEcho() { + EchoRequest send_request; + EchoResponse recv_response; + ClientContext cli_ctx; + cli_ctx.set_wait_for_ready(true); + send_request.set_message("Hello"); + Status recv_status = stub_->Echo(&cli_ctx, send_request, &recv_response); + EXPECT_EQ(send_request.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + + void SendEchoToDupService() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + auto stub = grpc::testing::duplicate::EchoTestService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + ClientContext cli_ctx; + cli_ctx.set_wait_for_ready(true); + send_request.set_message("Hello"); + Status recv_status = stub->Echo(&cli_ctx, send_request, &recv_response); + EXPECT_EQ(send_request.message() + "_dup", recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + + void SendSimpleClientStreaming() { + EchoRequest send_request; + EchoResponse recv_response; + std::string expected_message; + ClientContext cli_ctx; + cli_ctx.set_wait_for_ready(true); + send_request.set_message("Hello"); + auto stream = stub_->RequestStream(&cli_ctx, &recv_response); + for (int i = 0; i < 5; i++) { + EXPECT_TRUE(stream->Write(send_request)); + expected_message.append(send_request.message()); + } + stream->WritesDone(); + Status recv_status = stream->Finish(); + EXPECT_EQ(expected_message, recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + + void SendSimpleServerStreaming() { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + request.set_message("hello"); + + auto stream = stub_->ResponseStream(&context, request); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "1"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "2"); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void SendSimpleServerStreamingToDupService() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + auto stub = grpc::testing::duplicate::EchoTestService::NewStub(channel); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + request.set_message("hello"); + + auto stream = stub->ResponseStream(&context, request); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0_dup"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "1_dup"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "2_dup"); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void SendBidiStreaming() { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + std::string msg("hello"); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + 
EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "2"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + grpc::testing::UnimplementedEchoService::Service unimplemented_service_; + std::vector> cqs_; + std::unique_ptr stub_; + std::unique_ptr server_; + std::ostringstream server_address_; + bool inproc_; +}; + +TEST_F(HybridEnd2endTest, AsyncEcho) { + typedef EchoTestService::WithAsyncMethod_Echo SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleEcho, &service, cqs_[0].get(), + false); + TestAllMethods(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, RawEcho) { + typedef EchoTestService::WithRawMethod_Echo SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleRawEcho, &service, cqs_[0].get(), + false); + TestAllMethods(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, RawRequestStream) { + typedef EchoTestService::WithRawMethod_RequestStream SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread request_stream_handler_thread(HandleRawClientStreaming, + &service, cqs_[0].get()); + TestAllMethods(); + request_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, AsyncEchoRawRequestStream) { + typedef EchoTestService::WithRawMethod_RequestStream< + EchoTestService::WithAsyncMethod_Echo> + SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleEcho, &service, cqs_[0].get(), + false); + std::thread request_stream_handler_thread(HandleRawClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + request_stream_handler_thread.join(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEchoRawRequestStream) { + typedef EchoTestService::WithRawMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo> + SType; + SType service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleRawClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, AsyncEchoRequestStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_Echo> + SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleEcho, &service, cqs_[0].get(), + false); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + echo_handler_thread.join(); + request_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread 
response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync method. +TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream_SyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + TestServiceImplDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync streamed unary method. +class StreamedUnaryDupPkg + : public duplicate::EchoTestService::WithStreamedUnaryMethod_Echo< + TestServiceImplDupPkg> { + public: + Status StreamedEcho( + ServerContext* /*context*/, + ServerUnaryStreamer* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + resp.set_message(req.message() + "_dup"); + GPR_ASSERT(stream->Write(resp)); + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_SyncStreamedUnaryDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + StreamedUnaryDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service that is fully Streamed Unary +class FullyStreamedUnaryDupPkg + : public duplicate::EchoTestService::StreamedUnaryService { + public: + Status StreamedEcho( + ServerContext* /*context*/, + ServerUnaryStreamer* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + resp.set_message(req.message() + "_dup"); + GPR_ASSERT(stream->Write(resp)); + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_SyncFullyStreamedUnaryDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + FullyStreamedUnaryDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync split 
server streaming method. +class SplitResponseStreamDupPkg + : public duplicate::EchoTestService:: + WithSplitStreamingMethod_ResponseStream { + public: + Status StreamedResponseStream( + ServerContext* /*context*/, + ServerSplitStreamer* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + resp.set_message(req.message() + std::to_string(i) + "_dup"); + GPR_ASSERT(stream->Write(resp)); + } + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_SyncSplitStreamedDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + SplitResponseStreamDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + SendSimpleServerStreamingToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service that is fully split server streamed +class FullySplitStreamedDupPkg + : public duplicate::EchoTestService::SplitStreamedService { + public: + Status StreamedResponseStream( + ServerContext* /*context*/, + ServerSplitStreamer* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + resp.set_message(req.message() + std::to_string(i) + "_dup"); + GPR_ASSERT(stream->Write(resp)); + } + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_FullySplitStreamedDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + FullySplitStreamedDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + SendSimpleServerStreamingToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service that is fully server streamed +class FullyStreamedDupPkg : public duplicate::EchoTestService::StreamedService { + public: + Status StreamedEcho( + ServerContext* /*context*/, + ServerUnaryStreamer* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + resp.set_message(req.message() + "_dup"); + GPR_ASSERT(stream->Write(resp)); + return Status::OK; + } + Status StreamedResponseStream( + ServerContext* /*context*/, + ServerSplitStreamer* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Split Streamed Next Message 
Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + resp.set_message(req.message() + std::to_string(i) + "_dup"); + GPR_ASSERT(stream->Write(resp)); + } + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_FullyStreamedDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + FullyStreamedDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + SendSimpleServerStreamingToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one async method. +TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream_AsyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream> + SType; + SType service; + duplicate::EchoTestService::AsyncService dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + std::thread echo_handler_thread( + HandleEcho, &dup_service, + cqs_[2].get(), true); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEcho) { + EchoTestService::WithGenericMethod_Echo service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + TestAllMethods(); + generic_handler_thread.join(); +} + +TEST_P(HybridEnd2endTest, CallbackGenericEcho) { + EchoTestService::WithGenericMethod_Echo service; + class GenericEchoService : public CallbackGenericService { + private: + ServerGenericBidiReactor* CreateReactor( + GenericCallbackServerContext* context) override { + EXPECT_EQ(context->method(), "/grpc.testing.EchoTestService/Echo"); + gpr_log(GPR_DEBUG, "Constructor of generic service %d", + static_cast(context->deadline().time_since_epoch().count())); + + class Reactor : public ServerGenericBidiReactor { + public: + Reactor() { StartRead(&request_); } + + private: + void OnDone() override { delete this; } + void OnReadDone(bool ok) override { + if (!ok) { + EXPECT_EQ(reads_complete_, 1); + } else { + EXPECT_EQ(reads_complete_++, 0); + response_ = request_; + StartWrite(&response_); + StartRead(&request_); + } + } + void OnWriteDone(bool ok) override { + Finish(ok ? 
Status::OK + : Status(StatusCode::UNKNOWN, "Unexpected failure")); + } + ByteBuffer request_; + ByteBuffer response_; + std::atomic_int reads_complete_{0}; + }; + return new Reactor; + } + } generic_service; + + if (!SetUpServer(&service, nullptr, nullptr, &generic_service)) { + return; + } + ResetStub(); + TestAllMethods(); +} + +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo> + SType; + SType service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync method. +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream_SyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo> + SType; + SType service; + AsyncGenericService generic_service; + TestServiceImplDupPkg dup_service; + SetUpServer(&service, &dup_service, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one async method. +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream_AsyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo> + SType; + SType service; + AsyncGenericService generic_service; + duplicate::EchoTestService::AsyncService dup_service; + SetUpServer(&service, &dup_service, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + std::thread echo_handler_thread( + HandleEcho, &dup_service, + cqs_[2].get(), true); + TestAllMethods(); + SendEchoToDupService(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStreamResponseStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo< + EchoTestService::WithAsyncMethod_ResponseStream>> + SType; + SType service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming, + &service, cqs_[1].get()); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[2].get()); + TestAllMethods(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); + response_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEchoRequestStreamAsyncResponseStream) { + typedef EchoTestService::WithGenericMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo< + EchoTestService::WithAsyncMethod_ResponseStream>> + SType; + SType service; + 
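These generic-method tests drive the server through the regular typed stub; for completeness, the client side of a generic call can also be issued untyped via grpc::GenericStub, roughly as in the following sketch (the serialized request/response ByteBuffers are assumed to come from helpers like the ones above):

#include <memory>
#include <string>

#include <grpcpp/generic/generic_stub.h>
#include <grpcpp/grpcpp.h>

// Sketch: issue a unary call by full method name with untyped payloads. The
// caller supplies an already-serialized request and receives the serialized
// response.
bool GenericUnaryCall(const std::shared_ptr<grpc::Channel>& channel,
                      const std::string& method,
                      const grpc::ByteBuffer& request,
                      grpc::ByteBuffer* response) {
  grpc::GenericStub generic_stub(channel);
  grpc::CompletionQueue cq;
  grpc::ClientContext context;
  grpc::Status status;
  auto call = generic_stub.PrepareUnaryCall(&context, method, request, &cq);
  call->StartCall();
  call->Finish(response, &status, reinterpret_cast<void*>(1));
  void* got_tag = nullptr;
  bool ok = false;
  if (!cq.Next(&got_tag, &ok) || !ok) return false;
  return status.ok();
}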
AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread generic_handler_thread2(HandleGenericCall, &generic_service, + cqs_[1].get()); + std::thread response_stream_handler_thread(HandleServerStreaming, + &service, cqs_[2].get()); + TestAllMethods(); + generic_handler_thread.join(); + generic_handler_thread2.join(); + response_stream_handler_thread.join(); +} + +// If WithGenericMethod is called and no generic service is registered, the +// server will fail to build. +TEST_F(HybridEnd2endTest, GenericMethodWithoutGenericService) { + EchoTestService::WithGenericMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo< + EchoTestService::WithAsyncMethod_ResponseStream>> + service; + SetUpServer(&service, nullptr, nullptr, nullptr); + EXPECT_EQ(nullptr, server_.get()); +} + +INSTANTIATE_TEST_SUITE_P(HybridEnd2endTest, HybridEnd2endTest, + ::testing::Bool()); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc new file mode 100644 index 00000000..55684a49 --- /dev/null +++ b/test/cpp/end2end/interceptors_util.cc @@ -0,0 +1,214 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/end2end/interceptors_util.h" + +#include "absl/memory/memory.h" + +namespace grpc { +namespace testing { + +std::atomic PhonyInterceptor::num_times_run_; +std::atomic PhonyInterceptor::num_times_run_reverse_; +std::atomic PhonyInterceptor::num_times_cancel_; + +void MakeCall(const std::shared_ptr& channel, + const StubOptions& options) { + auto stub = grpc::testing::EchoTestService::NewStub(channel, options); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +void MakeClientStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < kNumStreamingMessages; i++) { + writer->Write(req); + expected_resp += "Hello"; + } + writer->WritesDone(); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), expected_resp); +} + +void MakeServerStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + auto reader = stub->ResponseStream(&ctx, req); + int count = 0; + while (reader->Read(&resp)) { + EXPECT_EQ(resp.message(), "Hello"); + count++; + } + ASSERT_EQ(count, kNumStreamingMessages); + Status s = reader->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeBidiStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + req.mutable_param()->set_echo_metadata(true); + auto stream = stub->BidiStream(&ctx); + for (auto i = 0; i < kNumStreamingMessages; i++) { + req.set_message("Hello" + std::to_string(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeAsyncCQCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + CompletionQueue cq; + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, &cq)); + response_reader->Finish(&recv_response, &recv_status, tag(1)); + Verifier().Expect(1, true).Verify(&cq); + EXPECT_EQ(send_request.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +void MakeAsyncCQClientStreamingCall( + const std::shared_ptr& /*channel*/) { + // TODO(yashykt) : Fill this out +} + +void MakeAsyncCQServerStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + CompletionQueue cq; + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + + 
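  // Everything below is driven by explicit completion-queue tags: each async
  // operation is given a tag() and Verifier().Expect(...).Verify(&cq) blocks
  // until that tag comes back out of the CQ with the expected ok bit.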
cli_ctx.AddMetadata("testkey", "testvalue"); + send_request.set_message("Hello"); + std::unique_ptr> cli_stream( + stub->AsyncResponseStream(&cli_ctx, send_request, &cq, tag(1))); + Verifier().Expect(1, true).Verify(&cq); + // Read the expected number of messages + for (int i = 0; i < kNumStreamingMessages; i++) { + cli_stream->Read(&recv_response, tag(2)); + Verifier().Expect(2, true).Verify(&cq); + ASSERT_EQ(recv_response.message(), send_request.message()); + } + // The next read should fail + cli_stream->Read(&recv_response, tag(3)); + Verifier().Expect(3, false).Verify(&cq); + // Get the status + cli_stream->Finish(&recv_status, tag(4)); + Verifier().Expect(4, true).Verify(&cq); + EXPECT_TRUE(recv_status.ok()); +} + +void MakeAsyncCQBidiStreamingCall(const std::shared_ptr& /*channel*/) { + // TODO(yashykt) : Fill this out +} + +void MakeCallbackCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + std::mutex mu; + std::condition_variable cv; + bool done = false; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + stub->async()->Echo(&ctx, &req, &resp, [&resp, &mu, &done, &cv](Status s) { + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } +} + +bool CheckMetadata(const std::multimap& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first.starts_with(key) && pair.second.starts_with(value)) { + return true; + } + } + return false; +} + +bool CheckMetadata(const std::multimap& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first == key && pair.second == value) { + return true; + } + } + return false; +} + +std::vector> +CreatePhonyClientInterceptors() { + std::vector> + creators; + // Add 20 phony interceptors before hijacking interceptor + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + return creators; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/end2end/message_allocator_end2end_test.cc b/test/cpp/end2end/message_allocator_end2end_test.cc new file mode 100644 index 00000000..f42ac3a1 --- /dev/null +++ b/test/cpp/end2end/message_allocator_end2end_test.cc @@ -0,0 +1,401 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/iomgr.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { +namespace { + +class CallbackTestServiceImpl : public EchoTestService::CallbackService { + public: + explicit CallbackTestServiceImpl() {} + + void SetAllocatorMutator( + std::function + mutator) { + allocator_mutator_ = std::move(mutator); + } + + ServerUnaryReactor* Echo(CallbackServerContext* context, + const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message()); + if (allocator_mutator_) { + allocator_mutator_(context->GetRpcAllocatorState(), request, response); + } + auto* reactor = context->DefaultReactor(); + reactor->Finish(Status::OK); + return reactor; + } + + private: + std::function + allocator_mutator_; +}; + +enum class Protocol { INPROC, TCP }; + +class TestScenario { + public: + TestScenario(Protocol protocol, const std::string& creds_type) + : protocol(protocol), credentials_type(creds_type) {} + void Log() const; + Protocol protocol; + const std::string credentials_type; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{protocol=" + << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP") + << "," << scenario.credentials_type << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_INFO, "%s", out.str().c_str()); +} + +class MessageAllocatorEnd2endTestBase + : public ::testing::TestWithParam { + protected: + MessageAllocatorEnd2endTestBase() { GetParam().Log(); } + + ~MessageAllocatorEnd2endTestBase() override = default; + + void CreateServer(MessageAllocator* allocator) { + ServerBuilder builder; + + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + if (GetParam().protocol == Protocol::TCP) { + picked_port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << picked_port_; + builder.AddListeningPort(server_address_.str(), server_creds); + } + callback_service_.SetMessageAllocatorFor_Echo(allocator); + builder.RegisterService(&callback_service_); + + server_ = builder.BuildAndStart(); + } + + void DestroyServer() { + if (server_) { + server_->Shutdown(); + server_.reset(); + } + } + + void ResetStub() { + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + switch (GetParam().protocol) { + case Protocol::TCP: + channel_ = ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args); + break; + case Protocol::INPROC: + channel_ = server_->InProcessChannel(args); + break; + default: + assert(false); + } + stub_ = EchoTestService::NewStub(channel_); + } + + void TearDown() override { + DestroyServer(); + if (picked_port_ > 0) { + grpc_recycle_unused_port(picked_port_); + } + } + + void SendRpcs(int num_rpcs) { + std::string test_string(""); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + + test_string += std::string(1024, 'x'); + request.set_message(test_string); + std::string val; + 
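      // Each iteration issues one callback-API Echo and then blocks on the
      // condition variable until the callback fires, so the RPCs in this loop
      // run strictly one at a time.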
cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->async()->Echo( + &cli_ctx, &request, &response, + [&request, &response, &done, &mu, &cv, val](Status s) { + GPR_ASSERT(s.ok()); + + EXPECT_EQ(request.message(), response.message()); + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } + } + } + + int picked_port_{0}; + std::shared_ptr channel_; + std::unique_ptr stub_; + CallbackTestServiceImpl callback_service_; + std::unique_ptr server_; + std::ostringstream server_address_; +}; + +class NullAllocatorTest : public MessageAllocatorEnd2endTestBase {}; + +TEST_P(NullAllocatorTest, SimpleRpc) { + CreateServer(nullptr); + ResetStub(); + SendRpcs(1); +} + +class SimpleAllocatorTest : public MessageAllocatorEnd2endTestBase { + public: + class SimpleAllocator : public MessageAllocator { + public: + class MessageHolderImpl : public MessageHolder { + public: + MessageHolderImpl(std::atomic_int* request_deallocation_count, + std::atomic_int* messages_deallocation_count) + : request_deallocation_count_(request_deallocation_count), + messages_deallocation_count_(messages_deallocation_count) { + set_request(new EchoRequest); + set_response(new EchoResponse); + } + void Release() override { + (*messages_deallocation_count_)++; + delete request(); + delete response(); + delete this; + } + void FreeRequest() override { + (*request_deallocation_count_)++; + delete request(); + set_request(nullptr); + } + + EchoRequest* ReleaseRequest() { + auto* ret = request(); + set_request(nullptr); + return ret; + } + + private: + std::atomic_int* const request_deallocation_count_; + std::atomic_int* const messages_deallocation_count_; + }; + MessageHolder* AllocateMessages() override { + allocation_count++; + return new MessageHolderImpl(&request_deallocation_count, + &messages_deallocation_count); + } + int allocation_count = 0; + std::atomic_int request_deallocation_count{0}; + std::atomic_int messages_deallocation_count{0}; + }; +}; + +TEST_P(SimpleAllocatorTest, SimpleRpc) { + const int kRpcCount = 10; + std::unique_ptr allocator(new SimpleAllocator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side OnDone. + // Destroy server to make sure it has been updated. + DestroyServer(); + EXPECT_EQ(kRpcCount, allocator->allocation_count); + EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count); + EXPECT_EQ(0, allocator->request_deallocation_count); +} + +TEST_P(SimpleAllocatorTest, RpcWithEarlyFreeRequest) { + const int kRpcCount = 10; + std::unique_ptr allocator(new SimpleAllocator); + auto mutator = [](RpcAllocatorState* allocator_state, const EchoRequest* req, + EchoResponse* resp) { + auto* info = + static_cast(allocator_state); + EXPECT_EQ(req, info->request()); + EXPECT_EQ(resp, info->response()); + allocator_state->FreeRequest(); + EXPECT_EQ(nullptr, info->request()); + }; + callback_service_.SetAllocatorMutator(mutator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side OnDone. + // Destroy server to make sure it has been updated. 
+ DestroyServer(); + EXPECT_EQ(kRpcCount, allocator->allocation_count); + EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count); + EXPECT_EQ(kRpcCount, allocator->request_deallocation_count); +} + +TEST_P(SimpleAllocatorTest, RpcWithReleaseRequest) { + const int kRpcCount = 10; + std::unique_ptr allocator(new SimpleAllocator); + std::vector released_requests; + auto mutator = [&released_requests](RpcAllocatorState* allocator_state, + const EchoRequest* req, + EchoResponse* resp) { + auto* info = + static_cast(allocator_state); + EXPECT_EQ(req, info->request()); + EXPECT_EQ(resp, info->response()); + released_requests.push_back(info->ReleaseRequest()); + EXPECT_EQ(nullptr, info->request()); + }; + callback_service_.SetAllocatorMutator(mutator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side OnDone. + // Destroy server to make sure it has been updated. + DestroyServer(); + EXPECT_EQ(kRpcCount, allocator->allocation_count); + EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count); + EXPECT_EQ(0, allocator->request_deallocation_count); + EXPECT_EQ(static_cast(kRpcCount), released_requests.size()); + for (auto* req : released_requests) { + delete req; + } +} + +class ArenaAllocatorTest : public MessageAllocatorEnd2endTestBase { + public: + class ArenaAllocator : public MessageAllocator { + public: + class MessageHolderImpl : public MessageHolder { + public: + MessageHolderImpl() { + set_request( + google::protobuf::Arena::CreateMessage(&arena_)); + set_response( + google::protobuf::Arena::CreateMessage(&arena_)); + } + void Release() override { delete this; } + void FreeRequest() override { GPR_ASSERT(0); } + + private: + google::protobuf::Arena arena_; + }; + MessageHolder* AllocateMessages() override { + allocation_count++; + return new MessageHolderImpl; + } + int allocation_count = 0; + }; +}; + +TEST_P(ArenaAllocatorTest, SimpleRpc) { + const int kRpcCount = 10; + std::unique_ptr allocator(new ArenaAllocator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + EXPECT_EQ(kRpcCount, allocator->allocation_count); +} + +std::vector CreateTestScenarios(bool test_insecure) { + std::vector scenarios; + std::vector credentials_types{ + GetCredentialsProvider()->GetSecureCredentialsTypeList()}; + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. 
+ return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + if (test_insecure && insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + GPR_ASSERT(!credentials_types.empty()); + + Protocol parr[]{Protocol::INPROC, Protocol::TCP}; + for (Protocol p : parr) { + for (const auto& cred : credentials_types) { + // TODO(vjpai): Test inproc with secure credentials when feasible + if (p == Protocol::INPROC && + (cred != kInsecureCredentialsType || !insec_ok())) { + continue; + } + scenarios.emplace_back(p, cred); + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(NullAllocatorTest, NullAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); +INSTANTIATE_TEST_SUITE_P(SimpleAllocatorTest, SimpleAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); +INSTANTIATE_TEST_SUITE_P(ArenaAllocatorTest, ArenaAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/end2end/mock_test.cc b/test/cpp/end2end/mock_test.cc new file mode 100644 index 00000000..14431540 --- /dev/null +++ b/test/cpp/end2end/mock_test.cc @@ -0,0 +1,433 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include + +#include "absl/types/optional.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/echo_mock.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +using grpc::testing::DefaultReactorTestPeer; +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using grpc::testing::EchoTestService; +using grpc::testing::MockClientReaderWriter; +using std::vector; +using ::testing::_; +using ::testing::AtLeast; +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SaveArg; +using ::testing::SetArgPointee; +using ::testing::WithArg; + +namespace grpc { +namespace testing { + +namespace { +class FakeClient { + public: + explicit FakeClient(EchoTestService::StubInterface* stub) : stub_(stub) {} + + void DoEcho() { + ClientContext context; + EchoRequest request; + EchoResponse response; + request.set_message("hello world"); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + } + + void DoRequestStream() { + EchoRequest request; + EchoResponse response; + + ClientContext context; + std::string msg("hello"); + std::string exp(msg); + + std::unique_ptr> cstream = + stub_->RequestStream(&context, &response); + + request.set_message(msg); + EXPECT_TRUE(cstream->Write(request)); + + msg = ", world"; + request.set_message(msg); + exp.append(msg); + EXPECT_TRUE(cstream->Write(request)); + + cstream->WritesDone(); + Status s = cstream->Finish(); + + EXPECT_EQ(exp, response.message()); + EXPECT_TRUE(s.ok()); + } + + void DoResponseStream() { + EchoRequest request; + EchoResponse response; + request.set_message("hello world"); + + ClientContext context; + std::unique_ptr> cstream = + stub_->ResponseStream(&context, request); + + std::string exp = ""; + EXPECT_TRUE(cstream->Read(&response)); + exp.append(response.message() + " "); + + EXPECT_TRUE(cstream->Read(&response)); + exp.append(response.message()); + + EXPECT_FALSE(cstream->Read(&response)); + EXPECT_EQ(request.message(), exp); + + Status s = cstream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void DoBidiStream() { + EchoRequest request; + EchoResponse response; + ClientContext context; + std::string msg("hello"); + + std::unique_ptr> + stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "2"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void ResetStub(EchoTestService::StubInterface* stub) { stub_ = stub; } + + private: + EchoTestService::StubInterface* stub_; +}; + +class CallbackTestServiceImpl : public EchoTestService::CallbackService { + public: + ServerUnaryReactor* Echo(CallbackServerContext* context, + const EchoRequest* request, + EchoResponse* response) override { + // Make the mock 
service explicitly treat empty input messages as invalid + // arguments so that we can test various results of status. In general, a + // mocked service should just use the original service methods, but we are + // adding this variance in Status return value just to improve coverage in + // this test. + auto* reactor = context->DefaultReactor(); + if (request->message().length() > 0) { + response->set_message(request->message()); + reactor->Finish(Status::OK); + } else { + reactor->Finish(Status(StatusCode::INVALID_ARGUMENT, "Invalid request")); + } + return reactor; + } +}; + +class MockCallbackTest : public ::testing::Test { + protected: + CallbackTestServiceImpl service_; + ServerContext context_; +}; + +TEST_F(MockCallbackTest, MockedCallSucceedsWithWait) { + CallbackServerContext ctx; + EchoRequest req; + EchoResponse resp; + struct { + grpc::internal::Mutex mu; + grpc::internal::CondVar cv; + absl::optional ABSL_GUARDED_BY(mu) status; + } status; + DefaultReactorTestPeer peer(&ctx, [&](::grpc::Status s) { + grpc::internal::MutexLock l(&status.mu); + status.status = std::move(s); + status.cv.Signal(); + }); + + req.set_message("mock 1"); + auto* reactor = service_.Echo(&ctx, &req, &resp); + + grpc::internal::MutexLock l(&status.mu); + while (!status.status.has_value()) { + status.cv.Wait(&status.mu); + } + + EXPECT_EQ(reactor, peer.reactor()); + EXPECT_TRUE(peer.test_status_set()); + EXPECT_TRUE(peer.test_status().ok()); + EXPECT_TRUE(status.status.has_value()); + EXPECT_TRUE(status.status.value().ok()); + EXPECT_EQ(req.message(), resp.message()); +} + +TEST_F(MockCallbackTest, MockedCallSucceeds) { + CallbackServerContext ctx; + EchoRequest req; + EchoResponse resp; + DefaultReactorTestPeer peer(&ctx); + + req.set_message("ha ha, consider yourself mocked."); + auto* reactor = service_.Echo(&ctx, &req, &resp); + EXPECT_EQ(reactor, peer.reactor()); + EXPECT_TRUE(peer.test_status_set()); + EXPECT_TRUE(peer.test_status().ok()); +} + +TEST_F(MockCallbackTest, MockedCallFails) { + CallbackServerContext ctx; + EchoRequest req; + EchoResponse resp; + DefaultReactorTestPeer peer(&ctx); + + auto* reactor = service_.Echo(&ctx, &req, &resp); + EXPECT_EQ(reactor, peer.reactor()); + EXPECT_TRUE(peer.test_status_set()); + EXPECT_EQ(peer.test_status().error_code(), StatusCode::INVALID_ARGUMENT); +} + +class TestServiceImpl : public EchoTestService::Service { + public: + Status Echo(ServerContext* /*context*/, const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message()); + return Status::OK; + } + + Status RequestStream(ServerContext* /*context*/, + ServerReader* reader, + EchoResponse* response) override { + EchoRequest request; + std::string resp(""); + while (reader->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + resp.append(request.message()); + } + response->set_message(resp); + return Status::OK; + } + + Status ResponseStream(ServerContext* /*context*/, const EchoRequest* request, + ServerWriter* writer) override { + EchoResponse response; + vector tokens = split(request->message()); + for (const std::string& token : tokens) { + response.set_message(token); + writer->Write(response); + } + return Status::OK; + } + + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter* stream) override { + EchoRequest request; + EchoResponse response; + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + response.set_message(request.message()); + 
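      // Echo each request straight back on the same stream before reading the
      // next one.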
stream->Write(response); + } + return Status::OK; + } + + private: + vector split(const std::string& input) { + std::string buff(""); + vector result; + + for (auto n : input) { + if (n != ' ') { + buff += n; + continue; + } + if (buff.empty()) continue; + result.push_back(buff); + buff = ""; + } + if (!buff.empty()) result.push_back(buff); + + return result; + } +}; + +class MockTest : public ::testing::Test { + protected: + MockTest() {} + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr stub_; + std::unique_ptr server_; + std::ostringstream server_address_; + TestServiceImpl service_; +}; + +// Do one real rpc and one mocked one +TEST_F(MockTest, SimpleRpc) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoEcho(); + MockEchoTestServiceStub stub; + EchoResponse resp; + resp.set_message("hello world"); + EXPECT_CALL(stub, Echo(_, _, _)) + .Times(AtLeast(1)) + .WillOnce(DoAll(SetArgPointee<2>(resp), Return(Status::OK))); + client.ResetStub(&stub); + client.DoEcho(); +} + +TEST_F(MockTest, ClientStream) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoRequestStream(); + + MockEchoTestServiceStub stub; + auto w = new MockClientWriter(); + EchoResponse resp; + resp.set_message("hello, world"); + + EXPECT_CALL(*w, Write(_, _)).Times(2).WillRepeatedly(Return(true)); + EXPECT_CALL(*w, WritesDone()); + EXPECT_CALL(*w, Finish()).WillOnce(Return(Status::OK)); + + EXPECT_CALL(stub, RequestStreamRaw(_, _)) + .WillOnce(DoAll(SetArgPointee<1>(resp), Return(w))); + client.ResetStub(&stub); + client.DoRequestStream(); +} + +TEST_F(MockTest, ServerStream) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoResponseStream(); + + MockEchoTestServiceStub stub; + auto r = new MockClientReader(); + EchoResponse resp1; + resp1.set_message("hello"); + EchoResponse resp2; + resp2.set_message("world"); + + EXPECT_CALL(*r, Read(_)) + .WillOnce(DoAll(SetArgPointee<0>(resp1), Return(true))) + .WillOnce(DoAll(SetArgPointee<0>(resp2), Return(true))) + .WillOnce(Return(false)); + EXPECT_CALL(*r, Finish()).WillOnce(Return(Status::OK)); + + EXPECT_CALL(stub, ResponseStreamRaw(_, _)).WillOnce(Return(r)); + + client.ResetStub(&stub); + client.DoResponseStream(); +} + +ACTION_P(copy, msg) { arg0->set_message(msg->message()); } + +TEST_F(MockTest, BidiStream) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoBidiStream(); + MockEchoTestServiceStub stub; + auto rw = new MockClientReaderWriter(); + EchoRequest msg; + + EXPECT_CALL(*rw, Write(_, _)) + .Times(3) + .WillRepeatedly(DoAll(SaveArg<0>(&msg), Return(true))); + EXPECT_CALL(*rw, Read(_)) + .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true))) + .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true))) + .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true))) + .WillOnce(Return(false)); + EXPECT_CALL(*rw, WritesDone()); + EXPECT_CALL(*rw, Finish()).WillOnce(Return(Status::OK)); + + EXPECT_CALL(stub, BidiStreamRaw(_)).WillOnce(Return(rw)); + client.ResetStub(&stub); + client.DoBidiStream(); +} + +} // 
namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/nonblocking_test.cc b/test/cpp/end2end/nonblocking_test.cc new file mode 100644 index 00000000..cb00810d --- /dev/null +++ b/test/cpp/end2end/nonblocking_test.cc @@ -0,0 +1,218 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#ifdef GRPC_POSIX_SOCKET +#include "src/core/lib/iomgr/ev_posix.h" +#endif // GRPC_POSIX_SOCKET + +#include + +#ifdef GRPC_POSIX_SOCKET +// Thread-local variable to so that only polls from this test assert +// non-blocking (not polls from resolver, timer thread, etc), and only when the +// thread is waiting on polls caused by CompletionQueue::AsyncNext (not for +// picking a port or other reasons). +static GPR_THREAD_LOCAL(bool) g_is_nonblocking_poll; + +namespace { + +int maybe_assert_non_blocking_poll(struct pollfd* pfds, nfds_t nfds, + int timeout) { + // Only assert that this poll should have zero timeout if we're in the + // middle of a zero-timeout CQ Next. + if (g_is_nonblocking_poll) { + GPR_ASSERT(timeout == 0); + } + return poll(pfds, nfds, timeout); +} + +} // namespace + +namespace grpc { +namespace testing { +namespace { + +void* tag(int i) { return reinterpret_cast(static_cast(i)); } +int detag(void* p) { return static_cast(reinterpret_cast(p)); } + +class NonblockingTest : public ::testing::Test { + protected: + NonblockingTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port_; + + // Setup server + BuildAndStartServer(); + } + + bool LoopForTag(void** tag, bool* ok) { + // Temporarily set the thread-local nonblocking poll flag so that the polls + // caused by this loop are indeed sent by the library with zero timeout. 
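    // AsyncNext() is called with a gpr_time_0 deadline, so any poll() the
    // library issues from this thread must carry a zero timeout, which
    // maybe_assert_non_blocking_poll() checks through this flag.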
+ bool orig_val = g_is_nonblocking_poll; + g_is_nonblocking_poll = true; + for (;;) { + auto r = cq_->AsyncNext(tag, ok, gpr_time_0(GPR_CLOCK_REALTIME)); + if (r == CompletionQueue::SHUTDOWN) { + g_is_nonblocking_poll = orig_val; + return false; + } else if (r == CompletionQueue::GOT_EVENT) { + g_is_nonblocking_poll = orig_val; + return true; + } + } + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (LoopForTag(&ignored_tag, &ignored_ok)) { + } + stub_.reset(); + grpc_recycle_unused_port(port_); + } + + void BuildAndStartServer() { + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + service_ = + absl::make_unique(); + builder.RegisterService(service_.get()); + cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + } + + void ResetStub() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), grpc::InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + void SendRpc(int num_rpcs) { + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message("hello non-blocking world"); + std::unique_ptr> response_reader( + stub_->PrepareAsyncEcho(&cli_ctx, send_request, cq_.get())); + + response_reader->StartCall(); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, + cq_.get(), cq_.get(), tag(2)); + + void* got_tag; + bool ok; + EXPECT_TRUE(LoopForTag(&got_tag, &ok)); + EXPECT_TRUE(ok); + EXPECT_EQ(detag(got_tag), 2); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + + int tagsum = 0; + int tagprod = 1; + EXPECT_TRUE(LoopForTag(&got_tag, &ok)); + EXPECT_TRUE(ok); + tagsum += detag(got_tag); + tagprod *= detag(got_tag); + + EXPECT_TRUE(LoopForTag(&got_tag, &ok)); + EXPECT_TRUE(ok); + tagsum += detag(got_tag); + tagprod *= detag(got_tag); + + EXPECT_EQ(tagsum, 7); + EXPECT_EQ(tagprod, 12); + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + std::unique_ptr cq_; + std::unique_ptr stub_; + std::unique_ptr server_; + std::unique_ptr service_; + std::ostringstream server_address_; + int port_; +}; + +TEST_F(NonblockingTest, SimpleRpc) { + ResetStub(); + SendRpc(10); +} + +} // namespace +} // namespace testing +} // namespace grpc + +#endif // GRPC_POSIX_SOCKET + +int main(int argc, char** argv) { +#ifdef GRPC_POSIX_SOCKET + // Override the poll function before anything else can happen + grpc_poll_function = maybe_assert_non_blocking_poll; + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + // Start the nonblocking poll thread-local variable as false because the + // thread that issues RPCs starts by picking a port (which has non-zero + // timeout). 
+ g_is_nonblocking_poll = false; + + int ret = RUN_ALL_TESTS(); + + return ret; +#else // GRPC_POSIX_SOCKET + (void)argc; + (void)argv; + return 0; +#endif // GRPC_POSIX_SOCKET +} diff --git a/test/cpp/end2end/port_sharing_end2end_test.cc b/test/cpp/end2end/port_sharing_end2end_test.cc new file mode 100644 index 00000000..1bc7a01f --- /dev/null +++ b/test/cpp/end2end/port_sharing_end2end_test.cc @@ -0,0 +1,375 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/core/util/test_tcp_server.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_POSIX_SOCKET_TCP_SERVER + +#include "src/core/lib/iomgr/tcp_posix.h" + +namespace grpc { +namespace testing { +namespace { + +class TestScenario { + public: + TestScenario(bool server_port, bool pending_data, + const std::string& creds_type) + : server_has_port(server_port), + queue_pending_data(pending_data), + credentials_type(creds_type) {} + void Log() const; + // server has its own port or not + bool server_has_port; + // whether tcp server should read some data before handoff + bool queue_pending_data; + const std::string credentials_type; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{server_has_port=" + << (scenario.server_has_port ? "true" : "false") + << ", queue_pending_data=" + << (scenario.queue_pending_data ? "true" : "false") + << ", credentials='" << scenario.credentials_type << "'}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_ERROR, "%s", out.str().c_str()); +} + +// Set up a test tcp server which is in charge of accepting connections and +// handing off the connections as fds. +class TestTcpServer { + public: + TestTcpServer() + : shutdown_(false), + queue_data_(false), + port_(grpc_pick_unused_port_or_die()) { + std::ostringstream server_address; + server_address << "localhost:" << port_; + address_ = server_address.str(); + test_tcp_server_init(&tcp_server_, &TestTcpServer::OnConnect, this); + GRPC_CLOSURE_INIT(&on_fd_released_, &TestTcpServer::OnFdReleased, this, + grpc_schedule_on_exec_ctx); + } + + ~TestTcpServer() { + running_thread_.join(); + test_tcp_server_destroy(&tcp_server_); + grpc_recycle_unused_port(port_); + } + + // Read some data before handing off the connection. 
+ void SetQueueData() { queue_data_ = true; } + + void Start() { + test_tcp_server_start(&tcp_server_, port_); + gpr_log(GPR_INFO, "Test TCP server started at %s", address_.c_str()); + } + + const std::string& address() { return address_; } + + void SetAcceptor( + std::unique_ptr acceptor) { + connection_acceptor_ = std::move(acceptor); + } + + void Run() { + running_thread_ = std::thread([this]() { + while (true) { + { + std::lock_guard lock(mu_); + if (shutdown_) { + return; + } + } + test_tcp_server_poll(&tcp_server_, 1); + } + }); + } + + void Shutdown() { + std::lock_guard lock(mu_); + shutdown_ = true; + } + + static void OnConnect(void* arg, grpc_endpoint* tcp, + grpc_pollset* accepting_pollset, + grpc_tcp_server_acceptor* acceptor) { + auto* self = static_cast(arg); + self->OnConnect(tcp, accepting_pollset, acceptor); + } + + static void OnFdReleased(void* arg, grpc_error_handle err) { + auto* self = static_cast(arg); + self->OnFdReleased(err); + } + + private: + void OnConnect(grpc_endpoint* tcp, grpc_pollset* /*accepting_pollset*/, + grpc_tcp_server_acceptor* acceptor) { + std::string peer(grpc_endpoint_get_peer(tcp)); + gpr_log(GPR_INFO, "Got incoming connection! from %s", peer.c_str()); + EXPECT_FALSE(acceptor->external_connection); + listener_fd_ = grpc_tcp_server_port_fd( + acceptor->from_server, acceptor->port_index, acceptor->fd_index); + gpr_free(acceptor); + grpc_tcp_destroy_and_release_fd(tcp, &fd_, &on_fd_released_); + } + + void OnFdReleased(grpc_error_handle err) { + EXPECT_EQ(GRPC_ERROR_NONE, err); + experimental::ExternalConnectionAcceptor::NewConnectionParameters p; + p.listener_fd = listener_fd_; + p.fd = fd_; + if (queue_data_) { + char buf[1024]; + ssize_t read_bytes = 0; + while (read_bytes <= 0) { + read_bytes = read(fd_, buf, 1024); + } + Slice data(buf, read_bytes); + p.read_buffer = ByteBuffer(&data, 1); + } + gpr_log(GPR_INFO, "Handing off fd %d with data size %d from listener fd %d", + fd_, static_cast(p.read_buffer.Length()), listener_fd_); + connection_acceptor_->HandleNewConnection(&p); + } + + std::mutex mu_; + bool shutdown_; + + int listener_fd_ = -1; + int fd_ = -1; + bool queue_data_ = false; + + grpc_closure on_fd_released_; + std::thread running_thread_; + int port_ = -1; + std::string address_; + std::unique_ptr + connection_acceptor_; + test_tcp_server tcp_server_; +}; + +class PortSharingEnd2endTest : public ::testing::TestWithParam { + protected: + PortSharingEnd2endTest() : is_server_started_(false), first_picked_port_(0) { + GetParam().Log(); + } + + void SetUp() override { + if (GetParam().queue_pending_data) { + tcp_server1_.SetQueueData(); + tcp_server2_.SetQueueData(); + } + tcp_server1_.Start(); + tcp_server2_.Start(); + ServerBuilder builder; + if (GetParam().server_has_port) { + int port = grpc_pick_unused_port_or_die(); + first_picked_port_ = port; + server_address_ << "localhost:" << port; + auto creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + builder.AddListeningPort(server_address_.str(), creds); + gpr_log(GPR_INFO, "gRPC server listening on %s", + server_address_.str().c_str()); + } + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + auto acceptor1 = builder.experimental().AddExternalConnectionAcceptor( + ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD, + server_creds); + tcp_server1_.SetAcceptor(std::move(acceptor1)); + auto acceptor2 = builder.experimental().AddExternalConnectionAcceptor( + 
ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD, + server_creds); + tcp_server2_.SetAcceptor(std::move(acceptor2)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + is_server_started_ = true; + + tcp_server1_.Run(); + tcp_server2_.Run(); + } + + void TearDown() override { + tcp_server1_.Shutdown(); + tcp_server2_.Shutdown(); + if (is_server_started_) { + server_->Shutdown(); + } + if (first_picked_port_ > 0) { + grpc_recycle_unused_port(first_picked_port_); + } + } + + void ResetStubs() { + EXPECT_TRUE(is_server_started_); + ChannelArguments args; + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + channel_handoff1_ = + CreateCustomChannel(tcp_server1_.address(), channel_creds, args); + stub_handoff1_ = EchoTestService::NewStub(channel_handoff1_); + channel_handoff2_ = + CreateCustomChannel(tcp_server2_.address(), channel_creds, args); + stub_handoff2_ = EchoTestService::NewStub(channel_handoff2_); + if (GetParam().server_has_port) { + ChannelArguments direct_args; + direct_args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + auto direct_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &direct_args); + channel_direct_ = + CreateCustomChannel(server_address_.str(), direct_creds, direct_args); + stub_direct_ = EchoTestService::NewStub(channel_direct_); + } + } + + bool is_server_started_; + // channel/stub to the test tcp server, the connection will be handed to the + // grpc server. + std::shared_ptr channel_handoff1_; + std::unique_ptr stub_handoff1_; + std::shared_ptr channel_handoff2_; + std::unique_ptr stub_handoff2_; + // channel/stub to talk to the grpc server directly, if applicable. + std::shared_ptr channel_direct_; + std::unique_ptr stub_direct_; + std::unique_ptr server_; + std::ostringstream server_address_; + TestServiceImpl service_; + TestTcpServer tcp_server1_; + TestTcpServer tcp_server2_; + int first_picked_port_; +}; + +static void SendRpc(EchoTestService::Stub* stub, int num_rpcs) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + + for (int i = 0; i < num_rpcs; ++i) { + ClientContext context; + Status s = stub->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + } +} + +std::vector CreateTestScenarios() { + std::vector scenarios; + std::vector credentials_types; + +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + + credentials_types = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. 
+ if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType, + nullptr) != nullptr) { + credentials_types.push_back(kInsecureCredentialsType); + } + + GPR_ASSERT(!credentials_types.empty()); + for (const auto& cred : credentials_types) { + for (auto server_has_port : {true, false}) { + for (auto queue_pending_data : {true, false}) { + scenarios.emplace_back(server_has_port, queue_pending_data, cred); + } + } + } + return scenarios; +} + +TEST_P(PortSharingEnd2endTest, HandoffAndDirectCalls) { + ResetStubs(); + SendRpc(stub_handoff1_.get(), 5); + if (GetParam().server_has_port) { + SendRpc(stub_direct_.get(), 5); + } +} + +TEST_P(PortSharingEnd2endTest, MultipleHandoff) { + for (int i = 0; i < 3; i++) { + ResetStubs(); + SendRpc(stub_handoff2_.get(), 1); + } +} + +TEST_P(PortSharingEnd2endTest, TwoHandoffPorts) { + for (int i = 0; i < 3; i++) { + ResetStubs(); + SendRpc(stub_handoff1_.get(), 5); + SendRpc(stub_handoff2_.get(), 5); + } +} + +INSTANTIATE_TEST_SUITE_P(PortSharingEnd2end, PortSharingEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios())); + +} // namespace +} // namespace testing +} // namespace grpc + +#endif // GRPC_POSIX_SOCKET_TCP_SERVER + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/proto_server_reflection_test.cc b/test/cpp/end2end/proto_server_reflection_test.cc new file mode 100644 index 00000000..b9d14e47 --- /dev/null +++ b/test/cpp/end2end/proto_server_reflection_test.cc @@ -0,0 +1,152 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/proto_reflection_descriptor_database.h" + +namespace grpc { +namespace testing { + +class ProtoServerReflectionTest : public ::testing::Test { + public: + ProtoServerReflectionTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + ref_desc_pool_ = protobuf::DescriptorPool::generated_pool(); + + ServerBuilder builder; + std::string server_address = "localhost:" + to_string(port_); + builder.AddListeningPort(server_address, InsecureServerCredentials()); + server_ = builder.BuildAndStart(); + } + + void ResetStub() { + string target = "dns:localhost:" + to_string(port_); + std::shared_ptr channel = + grpc::CreateChannel(target, InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + desc_db_ = absl::make_unique(channel); + desc_pool_ = absl::make_unique(desc_db_.get()); + } + + string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + void CompareService(const std::string& service) { + const protobuf::ServiceDescriptor* service_desc = + desc_pool_->FindServiceByName(service); + const protobuf::ServiceDescriptor* ref_service_desc = + ref_desc_pool_->FindServiceByName(service); + EXPECT_TRUE(service_desc != nullptr); + EXPECT_TRUE(ref_service_desc != nullptr); + EXPECT_EQ(service_desc->DebugString(), ref_service_desc->DebugString()); + + const protobuf::FileDescriptor* file_desc = service_desc->file(); + if (known_files_.find(file_desc->package() + "/" + file_desc->name()) != + known_files_.end()) { + EXPECT_EQ(file_desc->DebugString(), + ref_service_desc->file()->DebugString()); + known_files_.insert(file_desc->package() + "/" + file_desc->name()); + } + + for (int i = 0; i < service_desc->method_count(); ++i) { + CompareMethod(service_desc->method(i)->full_name()); + } + } + + void CompareMethod(const std::string& method) { + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(method); + const protobuf::MethodDescriptor* ref_method_desc = + ref_desc_pool_->FindMethodByName(method); + EXPECT_TRUE(method_desc != nullptr); + EXPECT_TRUE(ref_method_desc != nullptr); + EXPECT_EQ(method_desc->DebugString(), ref_method_desc->DebugString()); + + CompareType(method_desc->input_type()->full_name()); + CompareType(method_desc->output_type()->full_name()); + } + + void CompareType(const std::string& type) { + if (known_types_.find(type) != known_types_.end()) { + return; + } + + const protobuf::Descriptor* desc = desc_pool_->FindMessageTypeByName(type); + const protobuf::Descriptor* ref_desc = + ref_desc_pool_->FindMessageTypeByName(type); + EXPECT_TRUE(desc != nullptr); + EXPECT_TRUE(ref_desc != nullptr); + EXPECT_EQ(desc->DebugString(), ref_desc->DebugString()); + } + + protected: + std::unique_ptr server_; + std::unique_ptr stub_; + std::unique_ptr desc_db_; + std::unique_ptr desc_pool_; + std::unordered_set known_files_; + std::unordered_set known_types_; + const protobuf::DescriptorPool* ref_desc_pool_; + int port_; + reflection::ProtoServerReflectionPlugin plugin_; +}; + +TEST_F(ProtoServerReflectionTest, CheckResponseWithLocalDescriptorPool) { + ResetStub(); + + std::vector services; + 
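  // GetServices() asks the server's reflection service for its fully
  // qualified service names; each one is then checked against the locally
  // generated descriptor pool (service, method, and message DebugString()s
  // must match).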
desc_db_->GetServices(&services); + // The service list has at least one service (reflection servcie). + EXPECT_TRUE(!services.empty()); + + for (auto it = services.begin(); it != services.end(); ++it) { + CompareService(*it); + } +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/raw_end2end_test.cc b/test/cpp/end2end/raw_end2end_test.cc new file mode 100644 index 00000000..e0c29cd6 --- /dev/null +++ b/test/cpp/end2end/raw_end2end_test.cc @@ -0,0 +1,370 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { + +namespace { + +void* tag(int i) { return reinterpret_cast(i); } +int detag(void* p) { return static_cast(reinterpret_cast(p)); } + +class Verifier { + public: + Verifier() {} + + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + expectations_[tag(i)] = expect_ok; + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { + GPR_ASSERT(!expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, false); + } + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } + } + + std::map expectations_; +}; + +class RawEnd2EndTest : public ::testing::Test { + protected: + RawEnd2EndTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port_; + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) { + } + stub_.reset(); + grpc_recycle_unused_port(port_); + } + + template + std::unique_ptr BuildAndStartServer() { + ServerBuilder builder; + 
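    // ServerType selects which generated base backs the test server: the
    // fully async service, or the sync service with individual methods
    // switched to raw ByteBuffer handling via WithRawMethod_*.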
builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + std::unique_ptr service(new ServerType()); + builder.RegisterService(service.get()); + cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + return service; + } + + void ResetStub() { + ChannelArguments args; + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), grpc::InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr cq_; + std::unique_ptr stub_; + std::unique_ptr server_; + std::ostringstream server_address_; + int port_; + + // For the client application to populate and send to server. + EchoRequest send_request_; + ::grpc::ByteBuffer send_request_buffer_; + + // For the server to give to gRPC to be populated by incoming request + // from client. + EchoRequest recv_request_; + ::grpc::ByteBuffer recv_request_buffer_; + + // For the server application to populate and send back to client. + EchoResponse send_response_; + ::grpc::ByteBuffer send_response_buffer_; + + // For the client to give to gRPC to be populated by incoming response + // from server. + EchoResponse recv_response_; + ::grpc::ByteBuffer recv_response_buffer_; + Status recv_status_; + + // Both sides need contexts + ClientContext cli_ctx_; + ServerContext srv_ctx_; +}; + +// Regular Async, both peers use proto +TEST_F(RawEnd2EndTest, PureAsyncService) { + typedef grpc::testing::EchoTestService::AsyncService SType; + ResetStub(); + auto service = BuildAndStartServer(); + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx_); + + send_request_.set_message("hello"); + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx_, send_request_, cq_.get())); + service->RequestEcho(&srv_ctx_, &recv_request_, &response_writer, cq_.get(), + cq_.get(), tag(2)); + response_reader->Finish(&recv_response_, &recv_status_, tag(4)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + send_response_.set_message(recv_request_.message()); + response_writer.Finish(send_response_, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response_.message(), recv_response_.message()); + EXPECT_TRUE(recv_status_.ok()); +} + +// Client uses proto, server uses generic codegen, unary +TEST_F(RawEnd2EndTest, RawServerUnary) { + typedef grpc::testing::EchoTestService::WithRawMethod_Echo< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer(); + grpc::GenericServerAsyncResponseWriter response_writer(&srv_ctx_); + + send_request_.set_message("hello unary"); + std::unique_ptr> response_reader( + stub_->AsyncEcho(&cli_ctx_, send_request_, cq_.get())); + service->RequestEcho(&srv_ctx_, &recv_request_buffer_, &response_writer, + cq_.get(), cq_.get(), tag(2)); + response_reader->Finish(&recv_response_, &recv_status_, tag(4)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_request_buffer_, &recv_request_)); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + send_response_.set_message(recv_request_.message()); + EXPECT_TRUE( + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_)); + response_writer.Finish(send_response_buffer_, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response_.message(), recv_response_.message()); + EXPECT_TRUE(recv_status_.ok()); +} + +// Client uses 
proto, server uses generic codegen, client streaming +TEST_F(RawEnd2EndTest, RawServerClientStreaming) { + typedef grpc::testing::EchoTestService::WithRawMethod_RequestStream< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer(); + + grpc::GenericServerAsyncReader srv_stream(&srv_ctx_); + + send_request_.set_message("hello client streaming"); + std::unique_ptr> cli_stream( + stub_->AsyncRequestStream(&cli_ctx_, &recv_response_, cq_.get(), tag(1))); + + service->RequestRequestStream(&srv_ctx_, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(2, true).Expect(1, true).Verify(cq_.get()); + + cli_stream->Write(send_request_, tag(3)); + srv_stream.Read(&recv_request_buffer_, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + + cli_stream->Write(send_request_, tag(5)); + srv_stream.Read(&recv_request_buffer_, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request_buffer_, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + send_response_.set_message(recv_request_.message()); + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_); + srv_stream.Finish(send_response_buffer_, Status::OK, tag(9)); + cli_stream->Finish(&recv_status_, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_EQ(send_response_.message(), recv_response_.message()); + EXPECT_TRUE(recv_status_.ok()); +} + +// Client uses proto, server uses generic codegen, server streaming +TEST_F(RawEnd2EndTest, RawServerServerStreaming) { + typedef grpc::testing::EchoTestService::WithRawMethod_ResponseStream< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer(); + grpc::GenericServerAsyncWriter srv_stream(&srv_ctx_); + + send_request_.set_message("hello server streaming"); + std::unique_ptr> cli_stream( + stub_->AsyncResponseStream(&cli_ctx_, send_request_, cq_.get(), tag(1))); + + service->RequestResponseStream(&srv_ctx_, &recv_request_buffer_, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + + send_response_.set_message(recv_request_.message()); + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_); + srv_stream.Write(send_response_buffer_, tag(3)); + cli_stream->Read(&recv_response_, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response_.message(), recv_response_.message()); + + srv_stream.Write(send_response_buffer_, tag(5)); + cli_stream->Read(&recv_response_, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response_.message(), recv_response_.message()); + + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response_, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status_, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status_.ok()); +} + 
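// For reference, the raw-server tests in this file all follow the same
// completion-queue discipline: every async operation is issued with a unique
// integer tag, and Verifier drains the shared CompletionQueue until each
// expected tag has completed with the expected ok value. A Read() posted
// after the peer has half-closed (via WritesDone()/Finish()) is expected to
// complete with ok == false, which is why those expectations use
// Expect(N, false). A minimal sketch of the pattern, reusing the tag() and
// Verifier helpers defined above (the operation pairing is illustrative, not
// an additional test):
//
//   cli_stream->Write(send_request_, tag(3));        // client write -> tag 3
//   srv_stream.Read(&recv_request_buffer_, tag(4));  // server read  -> tag 4
//   Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
//   cli_stream->WritesDone(tag(7));                  // client half-close
//   srv_stream.Read(&recv_request_buffer_, tag(8));  // read observes the close
//   Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());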
+// Client uses proto, server uses generic codegen, bidi streaming +TEST_F(RawEnd2EndTest, RawServerBidiStreaming) { + typedef grpc::testing::EchoTestService::WithRawMethod_BidiStream< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer(); + + grpc::GenericServerAsyncReaderWriter srv_stream(&srv_ctx_); + + send_request_.set_message("hello bidi streaming"); + std::unique_ptr> + cli_stream(stub_->AsyncBidiStream(&cli_ctx_, cq_.get(), tag(1))); + + service->RequestBidiStream(&srv_ctx_, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + + cli_stream->Write(send_request_, tag(3)); + srv_stream.Read(&recv_request_buffer_, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + + send_response_.set_message(recv_request_.message()); + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_); + srv_stream.Write(send_response_buffer_, tag(5)); + cli_stream->Read(&recv_response_, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response_.message(), recv_response_.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request_buffer_, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status_, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status_.ok()); +} + +// Testing that this pattern compiles +TEST_F(RawEnd2EndTest, CompileTest) { + typedef grpc::testing::EchoTestService::WithRawMethod_Echo< + grpc::testing::EchoTestService::AsyncService> + SType; + ResetStub(); + auto service = BuildAndStartServer(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + // Change the backup poll interval from 5s to 100ms to speed up the + // ReconnectChannel test + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/end2end/rls_end2end_test.cc b/test/cpp/end2end/rls_end2end_test.cc new file mode 100644 index 00000000..8b4e43e1 --- /dev/null +++ b/test/cpp/end2end/rls_end2end_test.cc @@ -0,0 +1,1458 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
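// Overview of the setup used throughout this file: each test installs a
// service config that selects the "rls" LB policy, points routeLookupConfig
// at an in-process fake RLS server (RlsServiceImpl below), and then checks
// how data-plane RPCs are routed and cached. A representative service config
// emitted by the ServiceConfigBuilder helper looks roughly like the sketch
// below; the port and key values are placeholders, and the exact field set
// depends on which builder setters a test calls (see
// ServiceConfigBuilder::Build()).
//
//   {
//     "loadBalancingConfig": [{
//       "rls": {
//         "routeLookupConfig": {
//           "lookupService": "localhost:<rls_server_port>",
//           "cacheSizeBytes": 10485760,
//           "grpcKeybuilders": [{
//             "names": [{"service": "grpc.testing.EchoTestService",
//                        "method": "Echo"}],
//             "headers": [{"key": "test_key", "names": ["key1"]}]
//           }]
//         },
//         "childPolicy": [{"fixed_address_lb": {}}],
//         "childPolicyConfigTargetFieldName": "address"
//       }
//     }]
//   }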
+// + +// FIXME: add tests: +// - cache eviction via cleanup timer (based on age) +// - RLS channel is down; wait_for_ready request is sent and RLS request fails +// and goes into backoff; RLS channel comes back up before backoff timer +// fires; request is processed at that point + +#include +#include +#include + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/uri/uri_parser.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" +#include "src/proto/grpc/lookup/v1/rls.grpc.pb.h" +#include "src/proto/grpc/lookup/v1/rls.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/resolve_localhost_ip46.h" +#include "test/core/util/test_config.h" +#include "test/core/util/test_lb_policies.h" +#include "test/cpp/end2end/counted_service.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_config.h" + +using ::grpc::lookup::v1::RouteLookupRequest; +using ::grpc::lookup::v1::RouteLookupResponse; + +namespace grpc { +namespace testing { +namespace { + +const char* kServerName = "test.google.fr"; +const char* kRequestMessage = "Live long and prosper."; + +const char* kCallCredsMdKey = "call_cred_name"; +const char* kCallCredsMdValue = "call_cred_value"; + +const char* kTestKey = "test_key"; +const char* kTestValue = "test_value"; +const char* kHostKey = "host_key"; +const char* kServiceKey = "service_key"; +const char* kServiceValue = "grpc.testing.EchoTestService"; +const char* kMethodKey = "method_key"; +const char* kMethodValue = "Echo"; +const char* kConstantKey = "constant_key"; +const char* kConstantValue = "constant_value"; + +using BackendService = CountedService; +using RlsService = + CountedService; + +class RlsServiceImpl : public RlsService { + public: + ::grpc::Status RouteLookup(::grpc::ServerContext* context, + const RouteLookupRequest* request, + RouteLookupResponse* response) override { + gpr_log(GPR_INFO, "RLS: Received request: %s", + request->DebugString().c_str()); + // RLS server should see call creds. + EXPECT_THAT(context->client_metadata(), + ::testing::Contains( + ::testing::Pair(kCallCredsMdKey, kCallCredsMdValue))); + IncreaseRequestCount(); + EXPECT_EQ(request->target_type(), "grpc"); + // See if we have a configured response for this request. + ResponseData res; + { + grpc::internal::MutexLock lock(&mu_); + auto it = responses_.find(*request); + if (it == responses_.end()) { + gpr_log(GPR_INFO, "RLS: no matching request, returning INTERNAL"); + unmatched_requests_.push_back(*request); + return Status(StatusCode::INTERNAL, "no response entry"); + } + res = it->second; + } + // Configured response found, so use it. 
+ if (res.response_delay > 0) { + gpr_sleep_until( + grpc_timeout_milliseconds_to_deadline(res.response_delay)); + } + IncreaseResponseCount(); + *response = res.response; + gpr_log(GPR_INFO, "RLS: returning configured response: %s", + response->DebugString().c_str()); + return Status::OK; + } + + void Start() {} + + void Shutdown() {} + + void SetResponse(RouteLookupRequest request, RouteLookupResponse response, + grpc_millis response_delay = 0) { + grpc::internal::MutexLock lock(&mu_); + responses_[std::move(request)] = {std::move(response), response_delay}; + } + + void RemoveResponse(const RouteLookupRequest& request) { + grpc::internal::MutexLock lock(&mu_); + responses_.erase(request); + } + + std::vector GetUnmatchedRequests() { + grpc::internal::MutexLock lock(&mu_); + return std::move(unmatched_requests_); + } + + private: + // Sorting thunk for RouteLookupRequest. + struct RlsRequestLessThan { + bool operator()(const RouteLookupRequest& req1, + const RouteLookupRequest& req2) const { + std::map key_map1( + req1.key_map().begin(), req1.key_map().end()); + std::map key_map2( + req2.key_map().begin(), req2.key_map().end()); + if (key_map1 < key_map2) return true; + if (req1.reason() < req2.reason()) return true; + if (req1.stale_header_data() < req2.stale_header_data()) return true; + return false; + } + }; + + struct ResponseData { + RouteLookupResponse response; + grpc_millis response_delay; + }; + + grpc::internal::Mutex mu_; + std::map responses_ + ABSL_GUARDED_BY(&mu_); + std::vector unmatched_requests_ ABSL_GUARDED_BY(&mu_); +}; + +// Subclass of TestServiceImpl that increments a request counter for +// every call to the Echo Rpc. +class MyTestServiceImpl : public BackendService { + public: + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + // Backend should see call creds. 
+ EXPECT_THAT(context->client_metadata(), + ::testing::Contains( + ::testing::Pair(kCallCredsMdKey, kCallCredsMdValue))); + IncreaseRequestCount(); + auto client_metadata = context->client_metadata(); + auto range = client_metadata.equal_range("X-Google-RLS-Data"); + { + grpc::internal::MutexLock lock(&mu_); + for (auto it = range.first; it != range.second; ++it) { + rls_header_data_.insert( + std::string(it->second.begin(), it->second.length())); + } + } + IncreaseResponseCount(); + return TestServiceImpl::Echo(context, request, response); + } + + std::set rls_data() { + grpc::internal::MutexLock lock(&mu_); + return std::move(rls_header_data_); + } + + void Start() {} + + void Shutdown() {} + + private: + grpc::internal::Mutex mu_; + std::set rls_header_data_ ABSL_GUARDED_BY(&mu_); +}; + +class FakeResolverResponseGeneratorWrapper { + public: + FakeResolverResponseGeneratorWrapper() + : response_generator_(grpc_core::MakeRefCounted< + grpc_core::FakeResolverResponseGenerator>()) {} + + void SetNextResolution(absl::string_view service_config_json) { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetResponse(BuildFakeResults(service_config_json)); + } + + grpc_core::FakeResolverResponseGenerator* Get() const { + return response_generator_.get(); + } + + private: + static grpc_core::Resolver::Result BuildFakeResults( + absl::string_view service_config_json) { + grpc_core::Resolver::Result result; + result.service_config_error = GRPC_ERROR_NONE; + result.service_config = grpc_core::ServiceConfig::Create( + result.args, service_config_json, &result.service_config_error); + EXPECT_EQ(result.service_config_error, GRPC_ERROR_NONE) + << "JSON: " << service_config_json + << "Error: " << grpc_error_std_string(result.service_config_error); + EXPECT_NE(result.service_config, nullptr); + return result; + } + + grpc_core::RefCountedPtr + response_generator_; +}; + +class RlsEnd2endTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + gpr_setenv("GRPC_EXPERIMENTAL_ENABLE_RLS_LB_POLICY", "true"); + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); + grpc_init(); + grpc_core::RegisterFixedAddressLoadBalancingPolicy(); + } + + static void TearDownTestSuite() { + grpc_shutdown_blocking(); + gpr_unsetenv("GRPC_EXPERIMENTAL_ENABLE_RLS_LB_POLICY"); + } + + void SetUp() override { + bool localhost_resolves_to_ipv4 = false; + bool localhost_resolves_to_ipv6 = false; + grpc_core::LocalhostResolves(&localhost_resolves_to_ipv4, + &localhost_resolves_to_ipv6); + ipv6_only_ = !localhost_resolves_to_ipv4 && localhost_resolves_to_ipv6; + rls_server_ = absl::make_unique>("rls"); + rls_server_->Start(); + resolver_response_generator_ = + absl::make_unique(); + ResetStub(); + } + + void TearDown() override { + ShutdownBackends(); + rls_server_->Shutdown(); + } + + void ResetStub(const char* expected_authority = kServerName) { + ChannelArguments args; + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + resolver_response_generator_->Get()); + args.SetString(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS, expected_authority); + grpc_channel_credentials* channel_creds = + grpc_fake_transport_security_credentials_create(); + grpc_call_credentials* call_creds = grpc_md_only_test_credentials_create( + kCallCredsMdKey, kCallCredsMdValue, false); + auto creds = std::make_shared( + grpc_composite_channel_credentials_create(channel_creds, call_creds, + nullptr)); + call_creds->Unref(); + channel_creds->Unref(); + channel_ = ::grpc::CreateCustomChannel( + 
absl::StrCat("fake:///", kServerName).c_str(), std::move(creds), args); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void ShutdownBackends() { + for (auto& server : backends_) { + server->Shutdown(); + } + } + + void StartBackends(size_t num_servers) { + backends_.clear(); + for (size_t i = 0; i < num_servers; ++i) { + backends_.push_back( + absl::make_unique>("backend")); + backends_.back()->Start(); + } + } + + std::string TargetStringForPort(int port) { + if (ipv6_only_) return absl::StrCat("ipv6:[::1]:", port); + return absl::StrCat("ipv4:127.0.0.1:", port); + } + + static RouteLookupRequest BuildRlsRequest( + std::map key, + RouteLookupRequest::Reason reason = RouteLookupRequest::REASON_MISS, + const char* stale_header_data = "") { + RouteLookupRequest request; + request.set_target_type("grpc"); + request.mutable_key_map()->insert(key.begin(), key.end()); + request.set_reason(reason); + request.set_stale_header_data(stale_header_data); + return request; + } + + static RouteLookupResponse BuildRlsResponse(std::vector targets, + const char* header_data = "") { + RouteLookupResponse response; + response.mutable_targets()->Add(targets.begin(), targets.end()); + response.set_header_data(header_data); + return response; + } + + struct RpcOptions { + int timeout_ms = 1000; + bool wait_for_ready = false; + std::vector> metadata; + + RpcOptions() {} + + RpcOptions& set_timeout_ms(int rpc_timeout_ms) { + timeout_ms = rpc_timeout_ms; + return *this; + } + + RpcOptions& set_wait_for_ready(bool rpc_wait_for_ready) { + wait_for_ready = rpc_wait_for_ready; + return *this; + } + + RpcOptions& set_metadata( + std::vector> rpc_metadata) { + metadata = std::move(rpc_metadata); + return *this; + } + + // Populates context. + void SetupRpc(ClientContext* context) const { + for (const auto& item : metadata) { + context->AddMetadata(item.first, item.second); + } + if (timeout_ms != 0) { + context->set_deadline( + grpc_timeout_milliseconds_to_deadline(timeout_ms)); + } + if (wait_for_ready) context->set_wait_for_ready(true); + } + }; + + Status SendRpc(const RpcOptions& rpc_options = RpcOptions(), + EchoResponse* response = nullptr) { + EchoResponse local_response; + if (response == nullptr) response = &local_response; + ClientContext context; + rpc_options.SetupRpc(&context); + EchoRequest request; + request.set_message(kRequestMessage); + return stub_->Echo(&context, request, response); + } + + void CheckRpcSendOk(const grpc_core::DebugLocation& location, + const RpcOptions& rpc_options = RpcOptions()) { + EchoResponse response; + Status status = SendRpc(rpc_options, &response); + ASSERT_TRUE(status.ok()) << location.file() << ":" << location.line() + << ": RPC failed: " << status.error_code() << ": " + << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage) + << location.file() << ":" << location.line(); + } + + void CheckRpcSendFailure(const grpc_core::DebugLocation& location, + const RpcOptions& rpc_options = RpcOptions()) { + Status status = SendRpc(rpc_options); + ASSERT_FALSE(status.ok()) << location.file() << ":" << location.line(); + } + + class ServiceConfigBuilder { + public: + explicit ServiceConfigBuilder(int rls_server_port) + : rls_server_port_(rls_server_port) {} + + ServiceConfigBuilder& set_lookup_service_timeout(grpc_millis timeout) { + lookup_service_timeout_ = timeout * grpc_test_slowdown_factor(); + return *this; + } + + ServiceConfigBuilder& set_default_target(std::string default_target) { + default_target_ = std::move(default_target); + 
return *this; + } + + ServiceConfigBuilder& set_max_age(grpc_millis max_age) { + max_age_ = max_age * grpc_test_slowdown_factor(); + return *this; + } + + ServiceConfigBuilder& set_stale_age(grpc_millis stale_age) { + stale_age_ = stale_age * grpc_test_slowdown_factor(); + return *this; + } + + ServiceConfigBuilder& set_cache_size_bytes(int64_t size) { + cache_size_bytes_ = size; + return *this; + } + + ServiceConfigBuilder& AddKeyBuilder(absl::string_view key_builder) { + key_builders_.push_back(absl::StrCat("{", key_builder, "}")); + return *this; + } + + std::string Build() { + // First build parts of routeLookupConfig. + std::vector route_lookup_config_parts; + route_lookup_config_parts.push_back(absl::StrFormat( + " \"lookupService\":\"localhost:%d\"", rls_server_port_)); + if (lookup_service_timeout_ > 0) { + route_lookup_config_parts.push_back(absl::StrFormat( + " \"lookupServiceTimeout\":\"%d.%09ds\"", + lookup_service_timeout_ / 1000, lookup_service_timeout_ % 1000)); + } + if (!default_target_.empty()) { + route_lookup_config_parts.push_back(absl::StrFormat( + " \"defaultTarget\":\"%s\"", default_target_)); + } + route_lookup_config_parts.push_back(absl::StrFormat( + " \"cacheSizeBytes\":%" PRId64, cache_size_bytes_)); + if (max_age_ > 0) { + route_lookup_config_parts.push_back( + absl::StrFormat(" \"maxAge\":\"%d.%09ds\"", max_age_ / 1000, + max_age_ % 1000)); + } + if (stale_age_ > 0) { + route_lookup_config_parts.push_back( + absl::StrFormat(" \"staleAge\":\"%d.%09ds\"", + stale_age_ / 1000, stale_age_ % 1000)); + } + if (!key_builders_.empty()) { + route_lookup_config_parts.push_back( + absl::StrFormat(" \"grpcKeybuilders\":[%s]", + absl::StrJoin(key_builders_, ","))); + } + // Now build parts of RLS LB policy config. + std::vector rls_config_parts; + if (!route_lookup_config_parts.empty()) { + rls_config_parts.push_back(absl::StrCat( + " \"routeLookupConfig\":{", + absl::StrJoin(route_lookup_config_parts, ","), " }")); + } + rls_config_parts.push_back( + " \"childPolicy\":[{" + " \"fixed_address_lb\":{}\n" + " }],\n" + " \"childPolicyConfigTargetFieldName\":\"address\"\n"); + // Put it all together. + return absl::StrCat( + "{" + " \"loadBalancingConfig\":[{" + " \"rls\":{", + absl::StrJoin(rls_config_parts, ","), + " }" + " }]" + "}"); + } + + private: + int rls_server_port_; + grpc_millis lookup_service_timeout_ = 0; + std::string default_target_; + grpc_millis max_age_ = 0; + grpc_millis stale_age_ = 0; + int64_t cache_size_bytes_ = 10485760; + std::vector key_builders_; + }; + + ServiceConfigBuilder MakeServiceConfigBuilder() { + return ServiceConfigBuilder(rls_server_->port_); + } + + void SetNextResolution(absl::string_view service_config_json) { + resolver_response_generator_->SetNextResolution(service_config_json); + } + + template + struct ServerThread { + template + explicit ServerThread(const grpc::string& type, Args&&... args) + : port_(grpc_pick_unused_port_or_die()), + type_(type), + service_(std::forward(args)...) {} + + void Start() { + gpr_log(GPR_INFO, "starting %s server on port %d", type_.c_str(), port_); + GPR_ASSERT(!running_); + running_ = true; + service_.Start(); + grpc::internal::Mutex mu; + // We need to acquire the lock here in order to prevent the notify_one + // by ServerThread::Serve from firing before the wait below is hit. 
+ grpc::internal::MutexLock lock(&mu); + grpc::internal::CondVar cond; + thread_ = absl::make_unique( + std::bind(&ServerThread::Serve, this, &mu, &cond)); + cond.Wait(&mu); + gpr_log(GPR_INFO, "%s server startup complete", type_.c_str()); + } + + void Serve(grpc::internal::Mutex* mu, grpc::internal::CondVar* cond) { + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. + grpc::internal::MutexLock lock(mu); + ServerBuilder builder; + auto creds = std::make_shared( + grpc_fake_transport_security_server_credentials_create()); + builder.AddListeningPort(absl::StrCat("localhost:", port_), + std::move(creds)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + cond->Signal(); + } + + void Shutdown() { + if (!running_) return; + gpr_log(GPR_INFO, "%s about to shutdown", type_.c_str()); + service_.Shutdown(); + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + gpr_log(GPR_INFO, "%s shutdown completed", type_.c_str()); + running_ = false; + } + + const int port_; + grpc::string type_; + T service_; + std::unique_ptr server_; + std::unique_ptr thread_; + bool running_ = false; + }; + + bool ipv6_only_; + std::vector>> backends_; + std::unique_ptr> rls_server_; + std::unique_ptr + resolver_response_generator_; + std::shared_ptr channel_; + std::unique_ptr stub_; +}; + +TEST_F(RlsEnd2endTest, Basic) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + // No RLS header seen by the backend, since the RLS response didn't set any. + EXPECT_THAT(backends_[0]->service_.rls_data(), ::testing::ElementsAre()); +} + +TEST_F(RlsEnd2endTest, DuplicateHeadersAreMerged) { + const char* kTestValue2 = "test_value_2"; + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, absl::StrCat(kTestValue, ",", kTestValue2)}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + // Same header present twice in the request. Values should be merged. 
+ CheckRpcSendOk( + DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}, {"key1", kTestValue2}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, SecondHeaderUsed) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\", \"key2\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key2", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, MultipleHeaderKeys) { + const char* kTestKey2 = "test_key_2"; + const char* kTestValue2 = "test_value_2"; + StartBackends(1); + SetNextResolution(MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat( + "\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }," + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key2\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey, kTestKey2)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({ + {kTestKey, kTestValue}, + {kTestKey2, kTestValue2}, + }), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk( + DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}, {"key2", kTestValue2}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + // No RLS header seen by the backend, since the RLS response didn't set any. + EXPECT_THAT(backends_[0]->service_.rls_data(), ::testing::ElementsAre()); +} + +TEST_F(RlsEnd2endTest, NoHeaderMatch) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + // Request does not have header "key1", so kTestKey will not be added. 
+ CheckRpcSendOk(DEBUG_LOCATION); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, WildcardMethod) { + StartBackends(1); + SetNextResolution(MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, NoKeyBuilderForMethod) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"some_other_method\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, HeaderData) { + const char* kHeaderData = "header_data"; + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)}, + kHeaderData)); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + EXPECT_THAT(backends_[0]->service_.rls_data(), + ::testing::ElementsAre(kHeaderData)); +} + +TEST_F(RlsEnd2endTest, ExtraKeysAndConstantKeys) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\",\"key2\",\"key3\"" + " ]" + " }" + "]," + "\"extraKeys\":{" + " \"host\":\"%s\"," + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}," + "\"constantKeys\":{" + " \"%s\":\"%s\"" + "}", + kServiceValue, kMethodValue, kTestKey, + kHostKey, kServiceKey, kMethodKey, + kConstantKey, kConstantValue)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({ + {kTestKey, kTestValue}, + {kHostKey, kServerName}, + {kServiceKey, kServiceValue}, + {kMethodKey, kMethodValue}, + {kConstantKey, kConstantValue}, + }), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + 
EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, TwoCacheEntriesWithSameTarget) { + const char* kTestValue2 = "test_value2"; + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue2}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue2}})); + EXPECT_EQ(rls_server_->service_.request_count(), 2); + EXPECT_EQ(rls_server_->service_.response_count(), 2); + EXPECT_EQ(backends_[0]->service_.request_count(), 2); +} + +TEST_F(RlsEnd2endTest, FailedRlsRequestWithoutDefaultTarget) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + // Send an RPC before we give the RLS server a response. + // The RLS request will fail, and thus so will the data plane RPC. + CheckRpcSendFailure(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_THAT( + rls_server_->service_.GetUnmatchedRequests(), + ::testing::ElementsAre( + // TODO(roth): Change this to use ::testing::ProtoEquals() + // once that becomes available in OSS. + ::testing::Property( + &RouteLookupRequest::DebugString, + BuildRlsRequest({{kTestKey, kTestValue}}).DebugString()))); + // Now give the RLS server the right response. + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + // Sleep long enough for backoff to elapse, then try another RPC. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 2); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, FailedRlsRequestWithDefaultTarget) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .set_default_target(TargetStringForPort(backends_[0]->port_)) + .Build()); + // Don't give the RLS server a response, so the RLS request will fail. + // The data plane RPC should be sent to the default target. 
+ CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_THAT( + rls_server_->service_.GetUnmatchedRequests(), + ::testing::ElementsAre( + // TODO(roth): Change this to use ::testing::ProtoEquals() + // once that becomes available in OSS. + ::testing::Property( + &RouteLookupRequest::DebugString, + BuildRlsRequest({{kTestKey, kTestValue}}).DebugString()))); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 0); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, RlsRequestTimeout) { + StartBackends(2); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .set_default_target(TargetStringForPort(backends_[1]->port_)) + .set_lookup_service_timeout(2000) + .Build()); + // RLS server will send a response, but it's longer than the timeout. + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)}), + /*response_delay=*/3000); + // The data plane RPC should be sent to the default target. + CheckRpcSendOk(DEBUG_LOCATION, RpcOptions().set_timeout_ms(4000).set_metadata( + {{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 0); + EXPECT_EQ(backends_[1]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, UpdateConfig) { + StartBackends(2); + auto service_config_builder = + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .set_default_target(TargetStringForPort(backends_[0]->port_)); + SetNextResolution(service_config_builder.Build()); + // Don't give the RLS server a response, so the RLS request will fail. + // The data plane RPC should be sent to the default target. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_THAT( + rls_server_->service_.GetUnmatchedRequests(), + ::testing::ElementsAre( + // TODO(roth): Change this to use ::testing::ProtoEquals() + // once that becomes available in OSS. + ::testing::Property( + &RouteLookupRequest::DebugString, + BuildRlsRequest({{kTestKey, kTestValue}}).DebugString()))); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 0); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + EXPECT_EQ(backends_[1]->service_.request_count(), 0); + // Now update the config to point to a new default target. + service_config_builder.set_default_target( + TargetStringForPort(backends_[1]->port_)); + SetNextResolution(service_config_builder.Build()); + // Send another RPC, which should go to the new default target. + // The RLS server will *not* see another request, because the cache + // entry is still in backoff. 
+ CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 0); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + EXPECT_EQ(backends_[1]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, CachedResponse) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + // Send two RPCs. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + // The RLS server should have seen only one request. + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 2); +} + +TEST_F(RlsEnd2endTest, StaleCacheEntry) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .set_max_age(5000) + .set_stale_age(1000) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + // Send one RPC. RLS server gets a request, and RPC goes to backend. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + // Update RLS server to expect stale request. + rls_server_->service_.RemoveResponse( + BuildRlsRequest({{kTestKey, kTestValue}})); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}, + RouteLookupRequest::REASON_STALE), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + // Wait longer than stale age. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(1)); + // Send another RPC. This should use the stale value but should + // dispatch a second RLS request. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(backends_[0]->service_.request_count(), 2); + // Wait for RLS server to receive the second request. 
+ gpr_sleep_until(grpc_timeout_seconds_to_deadline(1)); + EXPECT_EQ(rls_server_->service_.request_count(), 2); + EXPECT_EQ(rls_server_->service_.response_count(), 2); +} + +TEST_F(RlsEnd2endTest, StaleCacheEntryWithHeaderData) { + const char* kHeaderData = "header_data"; + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .set_max_age(5000) + .set_stale_age(1000) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)}, + kHeaderData)); + // Send one RPC. RLS server gets a request, and RPC goes to backend. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + // Update RLS server to expect stale request. + rls_server_->service_.RemoveResponse( + BuildRlsRequest({{kTestKey, kTestValue}})); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}, + RouteLookupRequest::REASON_STALE, kHeaderData), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)}, + kHeaderData)); + // Wait longer than stale age. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(1)); + // Send another RPC. This should use the stale value but should + // dispatch a second RLS request. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(backends_[0]->service_.request_count(), 2); + // Wait for RLS server to receive the second request. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(1)); + EXPECT_EQ(rls_server_->service_.request_count(), 2); + EXPECT_EQ(rls_server_->service_.response_count(), 2); +} + +TEST_F(RlsEnd2endTest, ExpiredCacheEntry) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .set_max_age(1000) + .set_lookup_service_timeout(1000) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + // Send one RPC. RLS server gets a request, and RPC goes to backend. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + // Remove response from RLS server so that the next RLS request fails. + rls_server_->service_.RemoveResponse( + BuildRlsRequest({{kTestKey, kTestValue}})); + // Wait for cache to be expired. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(1)); + // Send another RPC. This should trigger a second RLS request, but + // that fails, so the RPC fails. 
+ CheckRpcSendFailure(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 2); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, CacheSizeLimit) { + const char* kTestValue2 = "test_value_2"; + StartBackends(2); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, + kTestKey)) + .set_cache_size_bytes(1) // Not even big enough for one entry. + .Build()); + // Set RLS responses for both kTestValue and kTestValue2. + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({TargetStringForPort(backends_[0]->port_)})); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue2}}), + BuildRlsResponse({TargetStringForPort(backends_[1]->port_)})); + // Send an RPC for kTestValue. + // RLS server gets a request, and RPC goes to backend. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + EXPECT_EQ(backends_[1]->service_.request_count(), 0); + // A second RPC for kTestValue should not generate another RLS + // request, because the cache entry is held by min_eviction_time. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 2); + EXPECT_EQ(backends_[1]->service_.request_count(), 0); + // Wait for min_eviction_time to elapse. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(5)); + // Send a request for kTestValue2. + // RLS server gets a request, and RPC goes to backend. + // This causes the entry for kTestValue to be evicted. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue2}})); + EXPECT_EQ(rls_server_->service_.request_count(), 2); + EXPECT_EQ(rls_server_->service_.response_count(), 2); + EXPECT_EQ(backends_[0]->service_.request_count(), 2); + EXPECT_EQ(backends_[1]->service_.request_count(), 1); + // Send another RPC for kTestValue. + // This should now trigger a new RLS request. + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 3); + EXPECT_EQ(rls_server_->service_.response_count(), 3); + EXPECT_EQ(backends_[0]->service_.request_count(), 3); + EXPECT_EQ(backends_[1]->service_.request_count(), 1); + // Another RPC for kTestValue2 should still work due to min_eviction_time. 
+ CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue2}})); + EXPECT_EQ(rls_server_->service_.request_count(), 3); + EXPECT_EQ(rls_server_->service_.response_count(), 3); + EXPECT_EQ(backends_[0]->service_.request_count(), 3); + EXPECT_EQ(backends_[1]->service_.request_count(), 2); +} + +TEST_F(RlsEnd2endTest, MultipleTargets) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse( + // First target will report TRANSIENT_FAILURE.. + {"invalid_target", TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); +} + +TEST_F(RlsEnd2endTest, ConnectivityStateReady) { + StartBackends(1); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(/*try_to_connect=*/false)); + rls_server_->service_.SetResponse( + BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse( + // One target in TRANSIENT_FAILURE, the other in READY. + {"invalid_target", TargetStringForPort(backends_[0]->port_)})); + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(backends_[0]->service_.request_count(), 1); + EXPECT_EQ(GRPC_CHANNEL_READY, channel_->GetState(/*try_to_connect=*/false)); +} + +TEST_F(RlsEnd2endTest, ConnectivityStateIdle) { + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(/*try_to_connect=*/false)); + // RLS server not given any responses, so the request will fail. + CheckRpcSendFailure(DEBUG_LOCATION); + // No child policies, so should be IDLE. 
+ EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(/*try_to_connect=*/false)); +} + +TEST_F(RlsEnd2endTest, ConnectivityStateTransientFailure) { + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(/*try_to_connect=*/false)); + rls_server_->service_.SetResponse(BuildRlsRequest({{kTestKey, kTestValue}}), + BuildRlsResponse({"invalid_target"})); + CheckRpcSendFailure(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + EXPECT_EQ(rls_server_->service_.request_count(), 1); + EXPECT_EQ(rls_server_->service_.response_count(), 1); + EXPECT_EQ(GRPC_CHANNEL_TRANSIENT_FAILURE, + channel_->GetState(/*try_to_connect=*/false)); +} + +TEST_F(RlsEnd2endTest, RlsAuthorityDeathTest) { + GRPC_GTEST_FLAG_SET_DEATH_TEST_STYLE("threadsafe"); + ResetStub("incorrect_authority"); + SetNextResolution( + MakeServiceConfigBuilder() + .AddKeyBuilder(absl::StrFormat("\"names\":[{" + " \"service\":\"%s\"," + " \"method\":\"%s\"" + "}]," + "\"headers\":[" + " {" + " \"key\":\"%s\"," + " \"names\":[" + " \"key1\"" + " ]" + " }" + "]", + kServiceValue, kMethodValue, kTestKey)) + .Build()); + // Make sure that we blow up (via abort() from the security connector) when + // the authority for the RLS channel doesn't match expectations. + ASSERT_DEATH_IF_SUPPORTED( + { + CheckRpcSendOk(DEBUG_LOCATION, + RpcOptions().set_metadata({{"key1", kTestValue}})); + }, + ""); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/sdk_authz_end2end_test.cc b/test/cpp/end2end/sdk_authz_end2end_test.cc new file mode 100644 index 00000000..3b6fb4a4 --- /dev/null +++ b/test/cpp/end2end/sdk_authz_end2end_test.cc @@ -0,0 +1,763 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
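// The tests in this file drive the SDK authorization filter end to end: each
// one installs an authorization policy through either a static-data provider
// or a file-watcher provider and then checks whether an Echo RPC is admitted
// or rejected with PERMISSION_DENIED. Every policy below follows the same
// JSON shape; a representative example assembled from the values used in the
// tests that follow:
//
//   {
//     "name": "authz",
//     "allow_rules": [{
//       "name": "allow_echo",
//       "request": {
//         "paths": ["*/Echo"],
//         "headers": [{"key": "key-foo", "values": ["foo1", "foo2"]}]
//       }
//     }],
//     "deny_rules": [{
//       "name": "deny_clientstreamingecho",
//       "request": {"paths": ["*/ClientStreamingEcho"]}
//     }]
//   }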
+ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/core/util/tls_utils.h" +#include "test/cpp/end2end/test_service_impl.h" + +namespace grpc { +namespace testing { +namespace { + +constexpr char kMessage[] = "Hello"; + +class SdkAuthzEnd2EndTest : public ::testing::Test { + protected: + SdkAuthzEnd2EndTest() + : server_address_( + absl::StrCat("localhost:", grpc_pick_unused_port_or_die())), + server_creds_( + std::shared_ptr(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create()))), + channel_creds_( + std::shared_ptr(new SecureChannelCredentials( + grpc_fake_transport_security_credentials_create()))) {} + + ~SdkAuthzEnd2EndTest() override { server_->Shutdown(); } + + // Replaces existing credentials with insecure credentials. + void UseInsecureCredentials() { + server_creds_ = InsecureServerCredentials(); + channel_creds_ = InsecureChannelCredentials(); + } + + // Creates server with sdk authorization enabled when provider is not null. + void InitServer( + std::shared_ptr + provider) { + ServerBuilder builder; + builder.AddListeningPort(server_address_, std::move(server_creds_)); + builder.experimental().SetAuthorizationPolicyProvider(std::move(provider)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + std::shared_ptr + CreateStaticAuthzPolicyProvider(const std::string& policy) { + grpc::Status status; + auto provider = experimental::StaticDataAuthorizationPolicyProvider::Create( + policy, &status); + EXPECT_TRUE(status.ok()); + return provider; + } + + std::shared_ptr + CreateFileWatcherAuthzPolicyProvider(const std::string& policy_path, + unsigned int refresh_interval_sec) { + grpc::Status status; + auto provider = + experimental::FileWatcherAuthorizationPolicyProvider::Create( + policy_path, refresh_interval_sec, &status); + EXPECT_TRUE(status.ok()); + return provider; + } + + std::shared_ptr BuildChannel() { + ChannelArguments args; + return ::grpc::CreateCustomChannel(server_address_, channel_creds_, args); + } + + grpc::Status SendRpc(const std::shared_ptr& channel, + ClientContext* context, + grpc::testing::EchoResponse* response = nullptr) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + grpc::testing::EchoRequest request; + request.set_message(kMessage); + return stub->Echo(context, request, response); + } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; + std::shared_ptr server_creds_; + std::shared_ptr channel_creds_; +}; + +TEST_F(SdkAuthzEnd2EndTest, + StaticInitAllowsRpcRequestNoMatchInDenyMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]," + " \"headers\": [" + " {" + " \"key\": \"key-foo\"," + " \"values\": [\"foo1\", \"foo2\"]" + " }," + " {" + " \"key\": \"key-bar\"," + " \"values\": [\"bar1\"]" + " }" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_clientstreamingecho\"," + " \"request\": {" + " \"paths\": [" + " \"*/ClientStreamingEcho\"" + " ]" + " }" + " }" + " ]" + "}"; + InitServer(CreateStaticAuthzPolicyProvider(policy)); + 
auto channel = BuildChannel(); + ClientContext context; + context.AddMetadata("key-foo", "foo2"); + context.AddMetadata("key-bar", "bar1"); + context.AddMetadata("key-baz", "baz1"); + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp.message(), kMessage); +} + +TEST_F(SdkAuthzEnd2EndTest, StaticInitDeniesRpcRequestNoMatchInAllowAndDeny) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_bar\"," + " \"source\": {" + " \"principals\": [" + " \"bar\"" + " ]" + " }" + " }" + " ]" + "}"; + InitServer(CreateStaticAuthzPolicyProvider(policy)); + auto channel = BuildChannel(); + ClientContext context; + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, StaticInitDeniesRpcRequestMatchInDenyMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_all\"" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + InitServer(CreateStaticAuthzPolicyProvider(policy)); + auto channel = BuildChannel(); + ClientContext context; + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, + StaticInitDeniesRpcRequestMatchInDenyNoMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_clientstreamingecho\"," + " \"request\": {" + " \"paths\": [" + " \"*/ClientStreamingEcho\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + InitServer(CreateStaticAuthzPolicyProvider(policy)); + auto channel = BuildChannel(); + ClientContext context; + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, StaticInitAllowsRpcRequestEmptyDenyMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]," + " \"headers\": [" + " {" + " \"key\": \"key-foo\"," + " \"values\": [\"foo1\", \"foo2\"]" + " }," + " {" + " \"key\": \"key-bar\"," + " \"values\": [\"bar1\"]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + InitServer(CreateStaticAuthzPolicyProvider(policy)); + auto channel = BuildChannel(); + ClientContext context; + context.AddMetadata("key-foo", "foo2"); + context.AddMetadata("key-bar", "bar1"); + context.AddMetadata("key-baz", "baz1"); + grpc::testing::EchoResponse resp; + grpc::Status 
status = SendRpc(channel, &context, &resp); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp.message(), kMessage); +} + +TEST_F(SdkAuthzEnd2EndTest, StaticInitDeniesRpcRequestEmptyDenyNoMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]," + " \"headers\": [" + " {" + " \"key\": \"key-foo\"," + " \"values\": [\"foo1\"]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + InitServer(CreateStaticAuthzPolicyProvider(policy)); + auto channel = BuildChannel(); + ClientContext context; + context.AddMetadata("key-bar", "bar1"); + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F( + SdkAuthzEnd2EndTest, + StaticInitDeniesRpcRequestWithPrincipalsFieldOnUnauthenticatedConnection) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"source\": {" + " \"principals\": [" + " \"foo\"" + " ]" + " }," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + UseInsecureCredentials(); + InitServer(CreateStaticAuthzPolicyProvider(policy)); + auto channel = BuildChannel(); + ClientContext context; + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, + FileWatcherInitAllowsRpcRequestNoMatchInDenyMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]," + " \"headers\": [" + " {" + " \"key\": \"key-foo\"," + " \"values\": [\"foo1\", \"foo2\"]" + " }," + " {" + " \"key\": \"key-bar\"," + " \"values\": [\"bar1\"]" + " }" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_clientstreamingecho\"," + " \"request\": {" + " \"paths\": [" + " \"*/ClientStreamingEcho\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 5)); + auto channel = BuildChannel(); + ClientContext context; + context.AddMetadata("key-foo", "foo2"); + context.AddMetadata("key-bar", "bar1"); + context.AddMetadata("key-baz", "baz1"); + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp.message(), kMessage); +} + +TEST_F(SdkAuthzEnd2EndTest, + FileWatcherInitDeniesRpcRequestNoMatchInAllowAndDeny) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_bar\"," + " \"source\": {" + " \"principals\": [" + " \"bar\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 5)); + auto channel = BuildChannel(); + ClientContext context; + grpc::testing::EchoResponse resp; + grpc::Status status = 
SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, + FileWatcherInitDeniesRpcRequestMatchInDenyMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_all\"" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 5)); + auto channel = BuildChannel(); + ClientContext context; + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, + FileWatcherInitDeniesRpcRequestMatchInDenyNoMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_clientstreamingecho\"," + " \"request\": {" + " \"paths\": [" + " \"*/ClientStreamingEcho\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 5)); + auto channel = BuildChannel(); + ClientContext context; + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, + FileWatcherInitAllowsRpcRequestEmptyDenyMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]," + " \"headers\": [" + " {" + " \"key\": \"key-foo\"," + " \"values\": [\"foo1\", \"foo2\"]" + " }," + " {" + " \"key\": \"key-bar\"," + " \"values\": [\"bar1\"]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 5)); + auto channel = BuildChannel(); + ClientContext context; + context.AddMetadata("key-foo", "foo2"); + context.AddMetadata("key-bar", "bar1"); + context.AddMetadata("key-baz", "baz1"); + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp.message(), kMessage); +} + +TEST_F(SdkAuthzEnd2EndTest, + FileWatcherInitDeniesRpcRequestEmptyDenyNoMatchInAllow) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]," + " \"headers\": [" + " {" + " \"key\": \"key-foo\"," + " \"values\": [\"foo1\"]" + " }" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 5)); + auto channel = BuildChannel(); + ClientContext context; + context.AddMetadata("key-bar", 
"bar1"); + grpc::testing::EchoResponse resp; + grpc::Status status = SendRpc(channel, &context, &resp); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, FileWatcherValidPolicyRefresh) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 1)); + auto channel = BuildChannel(); + ClientContext context1; + grpc::testing::EchoResponse resp1; + grpc::Status status = SendRpc(channel, &context1, &resp1); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp1.message(), kMessage); + // Replace the existing policy with a new authorization policy. + policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + tmp_policy.RewriteFile(policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + ClientContext context2; + grpc::testing::EchoResponse resp2; + status = SendRpc(channel, &context2, &resp2); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp2.message().empty()); +} + +TEST_F(SdkAuthzEnd2EndTest, FileWatcherInvalidPolicyRefreshSkipsReload) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 1)); + auto channel = BuildChannel(); + ClientContext context1; + grpc::testing::EchoResponse resp1; + grpc::Status status = SendRpc(channel, &context1, &resp1); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp1.message(), kMessage); + // Replaces existing policy with an invalid authorization policy. + policy = "{}"; + tmp_policy.RewriteFile(policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + ClientContext context2; + grpc::testing::EchoResponse resp2; + status = SendRpc(channel, &context2, &resp2); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp2.message(), kMessage); +} + +TEST_F(SdkAuthzEnd2EndTest, FileWatcherRecoversFromFailure) { + std::string policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + grpc_core::testing::TmpFile tmp_policy(policy); + InitServer(CreateFileWatcherAuthzPolicyProvider(tmp_policy.name(), 1)); + auto channel = BuildChannel(); + ClientContext context1; + grpc::testing::EchoResponse resp1; + grpc::Status status = SendRpc(channel, &context1, &resp1); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp1.message(), kMessage); + // Replaces existing policy with an invalid authorization policy. 
+ policy = "{}"; + tmp_policy.RewriteFile(policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + ClientContext context2; + grpc::testing::EchoResponse resp2; + status = SendRpc(channel, &context2, &resp2); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(resp2.message(), kMessage); + // Replace the existing invalid policy with a valid authorization policy. + policy = + "{" + " \"name\": \"authz\"," + " \"allow_rules\": [" + " {" + " \"name\": \"allow_foo\"," + " \"request\": {" + " \"paths\": [" + " \"*/foo\"" + " ]" + " }" + " }" + " ]," + " \"deny_rules\": [" + " {" + " \"name\": \"deny_echo\"," + " \"request\": {" + " \"paths\": [" + " \"*/Echo\"" + " ]" + " }" + " }" + " ]" + "}"; + tmp_policy.RewriteFile(policy); + // Wait 2 seconds for the provider's refresh thread to read the updated files. + gpr_sleep_until(grpc_timeout_seconds_to_deadline(2)); + ClientContext context3; + grpc::testing::EchoResponse resp3; + status = SendRpc(channel, &context3, &resp3); + EXPECT_EQ(status.error_code(), grpc::StatusCode::PERMISSION_DENIED); + EXPECT_EQ(status.error_message(), "Unauthorized RPC request rejected."); + EXPECT_TRUE(resp3.message().empty()); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/end2end/server_builder_plugin_test.cc b/test/cpp/end2end/server_builder_plugin_test.cc new file mode 100644 index 00000000..4aaf0183 --- /dev/null +++ b/test/cpp/end2end/server_builder_plugin_test.cc @@ -0,0 +1,267 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +#define PLUGIN_NAME "TestServerBuilderPlugin" + +namespace grpc { +namespace testing { + +class TestServerBuilderPlugin : public ServerBuilderPlugin { + public: + TestServerBuilderPlugin() : service_(new TestServiceImpl()) { + init_server_is_called_ = false; + finish_is_called_ = false; + change_arguments_is_called_ = false; + register_service_ = false; + } + + std::string name() override { return PLUGIN_NAME; } + + void InitServer(ServerInitializer* si) override { + init_server_is_called_ = true; + if (register_service_) { + si->RegisterService(service_); + } + } + + void Finish(ServerInitializer* /*si*/) override { finish_is_called_ = true; } + + void ChangeArguments(const std::string& /*name*/, void* /*value*/) override { + change_arguments_is_called_ = true; + } + + bool has_async_methods() const override { + if (register_service_) { + return service_->has_async_methods(); + } + return false; + } + + bool has_sync_methods() const override { + if (register_service_) { + return service_->has_synchronous_methods(); + } + return false; + } + + void SetRegisterService() { register_service_ = true; } + + bool init_server_is_called() { return init_server_is_called_; } + bool finish_is_called() { return finish_is_called_; } + bool change_arguments_is_called() { return change_arguments_is_called_; } + + private: + bool init_server_is_called_; + bool finish_is_called_; + bool change_arguments_is_called_; + bool register_service_; + std::shared_ptr service_; +}; + +class InsertPluginServerBuilderOption : public ServerBuilderOption { + public: + InsertPluginServerBuilderOption() { register_service_ = false; } + + void UpdateArguments(ChannelArguments* /*arg*/) override {} + + void UpdatePlugins( + std::vector>* plugins) override { + plugins->clear(); + + std::unique_ptr plugin( + new TestServerBuilderPlugin()); + if (register_service_) plugin->SetRegisterService(); + plugins->emplace_back(std::move(plugin)); + } + + void SetRegisterService() { register_service_ = true; } + + private: + bool register_service_; +}; + +std::unique_ptr CreateTestServerBuilderPlugin() { + return std::unique_ptr(new TestServerBuilderPlugin()); +} + +// Force AddServerBuilderPlugin() to be called at static initialization time. +struct StaticTestPluginInitializer { + StaticTestPluginInitializer() { + ::grpc::ServerBuilder::InternalAddPluginFactory( + &CreateTestServerBuilderPlugin); + } +} static_plugin_initializer_test_; + +// When the param boolean is true, the ServerBuilder plugin will be added at the +// time of static initialization. When it's false, the ServerBuilder plugin will +// be added using ServerBuilder::SetOption(). 
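Editor's sketch: the hunk above has lost several template arguments (for example the element type of the plugin vector handed to UpdatePlugins, presumably std::unique_ptr<grpc::ServerBuilderPlugin>). The snippet below spells out the two registration routes this test exercises with the types written in full; the class names are hypothetical and the impl/ header paths are assumptions that may differ between grpcpp versions.

#include <memory>
#include <string>
#include <vector>

#include <grpcpp/grpcpp.h>
#include <grpcpp/impl/server_builder_option.h>
#include <grpcpp/impl/server_builder_plugin.h>
#include <grpcpp/impl/server_initializer.h>

// A no-op plugin that only satisfies the pure-virtual ServerBuilderPlugin
// interface (name, InitServer, Finish, ChangeArguments).
class TracingPlugin : public grpc::ServerBuilderPlugin {
 public:
  std::string name() override { return "TracingPlugin"; }
  void InitServer(grpc::ServerInitializer* /*si*/) override {}
  void Finish(grpc::ServerInitializer* /*si*/) override {}
  void ChangeArguments(const std::string& /*name*/, void* /*value*/) override {}
  bool has_async_methods() const override { return false; }
  bool has_sync_methods() const override { return false; }
};

// Route 1: a ServerBuilderOption supplies the plugin list to one builder.
class TracingPluginOption : public grpc::ServerBuilderOption {
 public:
  void UpdateArguments(grpc::ChannelArguments* /*args*/) override {}
  void UpdatePlugins(std::vector<std::unique_ptr<grpc::ServerBuilderPlugin>>*
                         plugins) override {
    plugins->emplace_back(new TracingPlugin());
  }
};

// Route 2: a factory registered once (typically at static-initialization
// time) so that every ServerBuilder instance picks the plugin up.
std::unique_ptr<grpc::ServerBuilderPlugin> MakeTracingPlugin() {
  return std::unique_ptr<grpc::ServerBuilderPlugin>(new TracingPlugin());
}

void RegisterTracingPlugin() {
  grpc::ServerBuilder::InternalAddPluginFactory(&MakeTracingPlugin);
}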
+class ServerBuilderPluginTest : public ::testing::TestWithParam { + public: + ServerBuilderPluginTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + builder_ = absl::make_unique(); + } + + void InsertPlugin() { + if (GetParam()) { + // Add ServerBuilder plugin in static initialization + CheckPresent(); + } else { + // Add ServerBuilder plugin using ServerBuilder::SetOption() + builder_->SetOption(std::unique_ptr( + new InsertPluginServerBuilderOption())); + } + } + + void InsertPluginWithTestService() { + if (GetParam()) { + // Add ServerBuilder plugin in static initialization + auto plugin = CheckPresent(); + EXPECT_TRUE(plugin); + plugin->SetRegisterService(); + } else { + // Add ServerBuilder plugin using ServerBuilder::SetOption() + std::unique_ptr option( + new InsertPluginServerBuilderOption()); + option->SetRegisterService(); + builder_->SetOption(std::move(option)); + } + } + + void StartServer() { + std::string server_address = "localhost:" + to_string(port_); + builder_->AddListeningPort(server_address, InsecureServerCredentials()); + // we run some tests without a service, and for those we need to supply a + // frequently polled completion queue + cq_ = builder_->AddCompletionQueue(); + cq_thread_ = new std::thread(&ServerBuilderPluginTest::RunCQ, this); + server_ = builder_->BuildAndStart(); + EXPECT_TRUE(CheckPresent()); + } + + void ResetStub() { + string target = "dns:localhost:" + to_string(port_); + channel_ = grpc::CreateChannel(target, InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void TearDown() override { + auto plugin = CheckPresent(); + EXPECT_TRUE(plugin); + EXPECT_TRUE(plugin->init_server_is_called()); + EXPECT_TRUE(plugin->finish_is_called()); + server_->Shutdown(); + cq_->Shutdown(); + cq_thread_->join(); + delete cq_thread_; + } + + string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + protected: + std::shared_ptr channel_; + std::unique_ptr builder_; + std::unique_ptr stub_; + std::unique_ptr cq_; + std::unique_ptr server_; + std::thread* cq_thread_; + TestServiceImpl service_; + int port_; + + private: + TestServerBuilderPlugin* CheckPresent() { + auto it = builder_->plugins_.begin(); + for (; it != builder_->plugins_.end(); it++) { + if ((*it)->name() == PLUGIN_NAME) break; + } + if (it != builder_->plugins_.end()) { + return static_cast(it->get()); + } else { + return nullptr; + } + } + + void RunCQ() { + void* tag; + bool ok; + while (cq_->Next(&tag, &ok)) { + } + } +}; + +TEST_P(ServerBuilderPluginTest, PluginWithoutServiceTest) { + InsertPlugin(); + StartServer(); +} + +TEST_P(ServerBuilderPluginTest, PluginWithServiceTest) { + InsertPluginWithTestService(); + StartServer(); + ResetStub(); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + ClientContext context; + context.set_compression_algorithm(GRPC_COMPRESS_GZIP); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +INSTANTIATE_TEST_SUITE_P(ServerBuilderPluginTest, ServerBuilderPluginTest, + ::testing::Values(false, true)); + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/server_crash_test.cc b/test/cpp/end2end/server_crash_test.cc new file mode 
100644 index 00000000..5b5d90b6 --- /dev/null +++ b/test/cpp/end2end/server_crash_test.cc @@ -0,0 +1,162 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +static std::string g_root; + +namespace grpc { +namespace testing { + +namespace { + +class ServiceImpl final : public ::grpc::testing::EchoTestService::Service { + public: + ServiceImpl() : bidi_stream_count_(0), response_stream_count_(0) {} + + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter* stream) override { + bidi_stream_count_++; + EchoRequest request; + EchoResponse response; + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + response.set_message(request.message()); + stream->Write(response); + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(1, GPR_TIMESPAN))); + } + return Status::OK; + } + + Status ResponseStream(ServerContext* /*context*/, + const EchoRequest* /*request*/, + ServerWriter* writer) override { + EchoResponse response; + response_stream_count_++; + for (int i = 0;; i++) { + std::ostringstream msg; + msg << "Hello " << i; + response.set_message(msg.str()); + if (!writer->Write(response)) break; + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(1, GPR_TIMESPAN))); + } + return Status::OK; + } + + int bidi_stream_count() { return bidi_stream_count_; } + + int response_stream_count() { return response_stream_count_; } + + private: + int bidi_stream_count_; + int response_stream_count_; +}; + +class CrashTest : public ::testing::Test { + protected: + CrashTest() {} + + std::unique_ptr CreateServerAndClient(const std::string& mode) { + auto port = grpc_pick_unused_port_or_die(); + std::ostringstream addr_stream; + addr_stream << "localhost:" << port; + auto addr = addr_stream.str(); + client_ = absl::make_unique( + std::vector({g_root + "/server_crash_test_client", + "--address=" + addr, "--mode=" + mode})); + GPR_ASSERT(client_); + + ServerBuilder builder; + builder.AddListeningPort(addr, grpc::InsecureServerCredentials()); + builder.RegisterService(&service_); + return builder.BuildAndStart(); + } + + void KillClient() { client_.reset(); } + + bool HadOneBidiStream() { return service_.bidi_stream_count() == 1; } + + bool HadOneResponseStream() { return service_.response_stream_count() == 1; } + + private: + std::unique_ptr client_; + ServiceImpl service_; +}; + +TEST_F(CrashTest, ResponseStream) { + auto server = CreateServerAndClient("response"); + + 
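Editor's sketch: the reader/writer template arguments in the crash-test service above were truncated in this hunk. A self-contained sketch of the same two streaming handlers with the expected signatures, assuming the echo proto used throughout these tests (ServerReaderWriter<Response, Request> for bidi streams, ServerWriter<Response> for server streaming); the class name is hypothetical.

#include <string>

#include <grpcpp/grpcpp.h>

#include "src/proto/grpc/testing/echo.grpc.pb.h"

class StreamingEchoService final
    : public grpc::testing::EchoTestService::Service {
 public:
  grpc::Status BidiStream(
      grpc::ServerContext* /*context*/,
      grpc::ServerReaderWriter<grpc::testing::EchoResponse,
                               grpc::testing::EchoRequest>* stream) override {
    grpc::testing::EchoRequest request;
    grpc::testing::EchoResponse response;
    // Echo every message back until the client half-closes.
    while (stream->Read(&request)) {
      response.set_message(request.message());
      stream->Write(response);
    }
    return grpc::Status::OK;
  }

  grpc::Status ResponseStream(
      grpc::ServerContext* /*context*/, const grpc::testing::EchoRequest* request,
      grpc::ServerWriter<grpc::testing::EchoResponse>* writer) override {
    grpc::testing::EchoResponse response;
    // Stream a few responses; stop early if the client goes away.
    for (int i = 0; i < 3; ++i) {
      response.set_message(request->message() + std::to_string(i));
      if (!writer->Write(response)) break;
    }
    return grpc::Status::OK;
  }
};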
gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(60, GPR_TIMESPAN))); + KillClient(); + server->Shutdown(); + GPR_ASSERT(HadOneResponseStream()); +} + +TEST_F(CrashTest, BidiStream) { + auto server = CreateServerAndClient("bidi"); + + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(60, GPR_TIMESPAN))); + KillClient(); + server->Shutdown(); + GPR_ASSERT(HadOneBidiStream()); +} + +} // namespace + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + std::string me = argv[0]; + auto lslash = me.rfind('/'); + if (lslash != std::string::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/server_crash_test_client.cc b/test/cpp/end2end/server_crash_test_client.cc new file mode 100644 index 00000000..102df3c5 --- /dev/null +++ b/test/cpp/end2end/server_crash_test_client.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(std::string, address, "", "Address to connect to"); +ABSL_FLAG(std::string, mode, "", "Test mode to use"); + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + auto stub = grpc::testing::EchoTestService::NewStub(grpc::CreateChannel( + absl::GetFlag(FLAGS_address), grpc::InsecureChannelCredentials())); + + EchoRequest request; + EchoResponse response; + grpc::ClientContext context; + context.set_wait_for_ready(true); + + if (absl::GetFlag(FLAGS_mode) == "bidi") { + auto stream = stub->BidiStream(&context); + for (int i = 0;; i++) { + std::ostringstream msg; + msg << "Hello " << i; + request.set_message(msg.str()); + GPR_ASSERT(stream->Write(request)); + GPR_ASSERT(stream->Read(&response)); + GPR_ASSERT(response.message() == request.message()); + } + } else if (absl::GetFlag(FLAGS_mode) == "response") { + EchoRequest request; + request.set_message("Hello"); + auto stream = stub->ResponseStream(&context, request); + for (;;) { + GPR_ASSERT(stream->Read(&response)); + } + } else { + gpr_log(GPR_ERROR, "invalid test mode '%s'", + absl::GetFlag(FLAGS_mode).c_str()); + return 1; + } +} diff --git a/test/cpp/end2end/server_early_return_test.cc b/test/cpp/end2end/server_early_return_test.cc new file mode 100644 index 00000000..b318a266 --- /dev/null +++ b/test/cpp/end2end/server_early_return_test.cc @@ -0,0 +1,232 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/string_ref_helper.h" + +namespace grpc { +namespace testing { +namespace { + +const char kServerReturnStatusCode[] = "server_return_status_code"; +const char kServerDelayBeforeReturnUs[] = "server_delay_before_return_us"; +const char kServerReturnAfterNReads[] = "server_return_after_n_reads"; + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + // Unused methods are not implemented. + + Status RequestStream(ServerContext* context, + ServerReader* reader, + EchoResponse* response) override { + int server_return_status_code = + GetIntValueFromMetadata(context, kServerReturnStatusCode, 0); + int server_delay_before_return_us = + GetIntValueFromMetadata(context, kServerDelayBeforeReturnUs, 0); + int server_return_after_n_reads = + GetIntValueFromMetadata(context, kServerReturnAfterNReads, 0); + + EchoRequest request; + while (server_return_after_n_reads--) { + EXPECT_TRUE(reader->Read(&request)); + } + + response->set_message("response msg"); + + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(server_delay_before_return_us, GPR_TIMESPAN))); + + return Status(static_cast(server_return_status_code), ""); + } + + Status BidiStream( + ServerContext* context, + ServerReaderWriter* stream) override { + int server_return_status_code = + GetIntValueFromMetadata(context, kServerReturnStatusCode, 0); + int server_delay_before_return_us = + GetIntValueFromMetadata(context, kServerDelayBeforeReturnUs, 0); + int server_return_after_n_reads = + GetIntValueFromMetadata(context, kServerReturnAfterNReads, 0); + + EchoRequest request; + EchoResponse response; + while (server_return_after_n_reads--) { + EXPECT_TRUE(stream->Read(&request)); + response.set_message(request.message()); + EXPECT_TRUE(stream->Write(response)); + } + + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(server_delay_before_return_us, GPR_TIMESPAN))); + + return Status(static_cast(server_return_status_code), ""); + } + + int GetIntValueFromMetadata(ServerContext* context, const char* key, + int default_value) { + auto metadata = context->client_metadata(); + if (metadata.find(key) != metadata.end()) { + std::istringstream iss(ToString(metadata.find(key)->second)); + iss >> default_value; + } + return default_value; + } +}; + +class ServerEarlyReturnTest : public ::testing::Test { + protected: + ServerEarlyReturnTest() : picked_port_(0) {} + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + picked_port_ = port; + server_address_ << "localhost:" << port; + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + + channel_ = grpc::CreateChannel(server_address_.str(), + 
InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void TearDown() override { + server_->Shutdown(); + if (picked_port_ > 0) { + grpc_recycle_unused_port(picked_port_); + } + } + + // Client sends 20 requests and the server returns after reading 10 requests. + // If return_cancel is true, server returns CANCELLED status. Otherwise it + // returns OK. + void DoBidiStream(bool return_cancelled) { + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerReturnAfterNReads, "10"); + if (return_cancelled) { + // "1" means CANCELLED + context.AddMetadata(kServerReturnStatusCode, "1"); + } + context.AddMetadata(kServerDelayBeforeReturnUs, "10000"); + + auto stream = stub_->BidiStream(&context); + + for (int i = 0; i < 20; i++) { + request.set_message(std::string("hello") + std::to_string(i)); + bool write_ok = stream->Write(request); + bool read_ok = stream->Read(&response); + if (i < 10) { + EXPECT_TRUE(write_ok); + EXPECT_TRUE(read_ok); + EXPECT_EQ(response.message(), request.message()); + } else { + EXPECT_FALSE(read_ok); + } + } + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + if (return_cancelled) { + EXPECT_EQ(s.error_code(), StatusCode::CANCELLED); + } else { + EXPECT_TRUE(s.ok()); + } + } + + void DoRequestStream(bool return_cancelled) { + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerReturnAfterNReads, "10"); + if (return_cancelled) { + // "1" means CANCELLED + context.AddMetadata(kServerReturnStatusCode, "1"); + } + context.AddMetadata(kServerDelayBeforeReturnUs, "10000"); + + auto stream = stub_->RequestStream(&context, &response); + for (int i = 0; i < 20; i++) { + request.set_message(std::string("hello") + std::to_string(i)); + bool written = stream->Write(request); + if (i < 10) { + EXPECT_TRUE(written); + } + } + stream->WritesDone(); + Status s = stream->Finish(); + if (return_cancelled) { + EXPECT_EQ(s.error_code(), StatusCode::CANCELLED); + } else { + EXPECT_TRUE(s.ok()); + } + } + + std::shared_ptr channel_; + std::unique_ptr stub_; + std::unique_ptr server_; + std::ostringstream server_address_; + TestServiceImpl service_; + int picked_port_; +}; + +TEST_F(ServerEarlyReturnTest, BidiStreamEarlyOk) { DoBidiStream(false); } + +TEST_F(ServerEarlyReturnTest, BidiStreamEarlyCancel) { DoBidiStream(true); } + +TEST_F(ServerEarlyReturnTest, RequestStreamEarlyOK) { DoRequestStream(false); } +TEST_F(ServerEarlyReturnTest, RequestStreamEarlyCancel) { + DoRequestStream(true); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc new file mode 100644 index 00000000..3f303cdc --- /dev/null +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -0,0 +1,703 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/match.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +namespace grpc { +namespace testing { +namespace { + +class LoggingInterceptor : public experimental::Interceptor { + public: + explicit LoggingInterceptor(experimental::ServerRpcInfo* info) { + info_ = info; + + // Check the method name and compare to the type + const char* method = info->method(); + experimental::ServerRpcInfo::Type type = info->type(); + + // Check that we use one of our standard methods with expected type. + // Also allow the health checking service. + // We accept BIDI_STREAMING for Echo in case it's an AsyncGenericService + // being tested (the GenericRpc test). + // The empty method is for the Unimplemented requests that arise + // when draining the CQ. + EXPECT_TRUE( + strstr(method, "/grpc.health") == method || + (strcmp(method, "/grpc.testing.EchoTestService/Echo") == 0 && + (type == experimental::ServerRpcInfo::Type::UNARY || + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)) || + (strcmp(method, "/grpc.testing.EchoTestService/RequestStream") == 0 && + type == experimental::ServerRpcInfo::Type::CLIENT_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/ResponseStream") == 0 && + type == experimental::ServerRpcInfo::Type::SERVER_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/BidiStream") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING) || + strcmp(method, "/grpc.testing.EchoTestService/Unimplemented") == 0 || + (strcmp(method, "") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)); + } + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_TRUE(req.message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS)) { + auto* map = methods->GetSendTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = absl::StartsWith(pair.first, "testkey") && + absl::StartsWith(pair.second, "testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto status = methods->GetSendStatus(); + 
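Editor's sketch: for readers following the interceptor hunks, this is a minimal server interceptor plus the factory and ServerBuilder wiring the tests rely on. The class names and the counter are hypothetical, and the <grpcpp/support/server_interceptor.h> header path is an assumption; the hook-point and factory APIs themselves match the ones used above.

#include <atomic>
#include <memory>
#include <vector>

#include <grpcpp/grpcpp.h>
#include <grpcpp/support/server_interceptor.h>

std::atomic<int> g_rpcs_finished{0};

// Counts RPCs when their batch reaches PRE_SEND_STATUS, then lets it proceed.
class CountingInterceptor : public grpc::experimental::Interceptor {
 public:
  explicit CountingInterceptor(grpc::experimental::ServerRpcInfo* info)
      : info_(info) {}

  void Intercept(grpc::experimental::InterceptorBatchMethods* methods) override {
    if (methods->QueryInterceptionHookPoint(
            grpc::experimental::InterceptionHookPoints::PRE_SEND_STATUS)) {
      g_rpcs_finished.fetch_add(1, std::memory_order_relaxed);
    }
    methods->Proceed();  // Always hand control back to the library.
  }

 private:
  grpc::experimental::ServerRpcInfo* info_;
};

class CountingInterceptorFactory
    : public grpc::experimental::ServerInterceptorFactoryInterface {
 public:
  grpc::experimental::Interceptor* CreateServerInterceptor(
      grpc::experimental::ServerRpcInfo* info) override {
    return new CountingInterceptor(info);
  }
};

// Factories are handed to the builder before BuildAndStart().
void AddCountingInterceptor(grpc::ServerBuilder* builder) {
  std::vector<
      std::unique_ptr<grpc::experimental::ServerInterceptorFactoryInterface>>
      creators;
  creators.push_back(std::make_unique<CountingInterceptorFactory>());
  builder->experimental().SetInterceptorCreators(std::move(creators));
}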
EXPECT_EQ(status.ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.find("testkey") == 0 && + pair.second.find("testvalue") == 0; + if (found) break; + } + EXPECT_EQ(found, true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + if (resp != nullptr) { + EXPECT_TRUE(resp->message().find("Hello") == 0); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE)) { + // Got nothing interesting to do here + } + methods->Proceed(); + } + + private: + experimental::ServerRpcInfo* info_; +}; + +class LoggingInterceptorFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageTester : public experimental::Interceptor { + public: + explicit SyncSendMessageTester(experimental::ServerRpcInfo* /*info*/) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + string old_msg = + static_cast(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("Hello"), 0u); + new_msg_.set_message("World" + old_msg); + methods->ModifySendMessage(&new_msg_); + } + methods->Proceed(); + } + + private: + EchoRequest new_msg_; +}; + +class SyncSendMessageTesterFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new SyncSendMessageTester(info); + } +}; + +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageVerifier : public experimental::Interceptor { + public: + explicit SyncSendMessageVerifier(experimental::ServerRpcInfo* /*info*/) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + // Make sure that the changes made in SyncSendMessageTester persisted + string old_msg = + static_cast(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("World"), 0u); + + // Remove the "World" part of the string that we added earlier + new_msg_.set_message(old_msg.erase(0, 5)); + methods->ModifySendMessage(&new_msg_); + + // LoggingInterceptor verifies that changes got reverted + } + methods->Proceed(); + } + + private: + EchoRequest new_msg_; +}; + +class SyncSendMessageVerifierFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new SyncSendMessageVerifier(info); + } +}; + +void MakeBidiStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + auto stream = stub->BidiStream(&ctx); + for 
(auto i = 0; i < 10; i++) { + req.set_message("Hello" + std::to_string(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncUnaryTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr> + creators; + creators.push_back( + std::unique_ptr( + new SyncSendMessageTesterFactory())); + creators.push_back( + std::unique_ptr( + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + // Add 20 phony interceptor factories and null interceptor factories + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncUnaryTest, UnaryTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + MakeCall(channel); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncStreamingTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr> + creators; + creators.push_back( + std::unique_ptr( + new SyncSendMessageTesterFactory())); + creators.push_back( + std::unique_ptr( + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + std::string server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ClientStreamingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + MakeClientStreamingCall(channel); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ServerStreamingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + MakeServerStreamingCall(channel); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) { + ChannelArguments args; + PhonyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, 
InsecureChannelCredentials()); + MakeBidiStreamingCall(channel); + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {}; + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) { + PhonyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, cq.get())); + + service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(), + cq.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) { + } + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) { + PhonyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + 
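Editor's sketch: the async tests that follow use typed async readers and writers whose template arguments were truncated in this hunk (for example ClientAsyncResponseReader<EchoResponse> and ServerAsyncReaderWriter<EchoResponse, EchoRequest>). A small client-side sketch of an async unary Echo with the types spelled out; the helper name AsyncEchoOnce is hypothetical.

#include <memory>
#include <string>

#include <grpc/support/log.h>
#include <grpcpp/grpcpp.h>

#include "src/proto/grpc/testing/echo.grpc.pb.h"

grpc::Status AsyncEchoOnce(const std::shared_ptr<grpc::Channel>& channel,
                           const std::string& message,
                           grpc::testing::EchoResponse* response) {
  auto stub = grpc::testing::EchoTestService::NewStub(channel);
  grpc::CompletionQueue cq;
  grpc::ClientContext context;
  grpc::testing::EchoRequest request;
  request.set_message(message);

  // AsyncEcho starts the call immediately; Finish registers a tag that the
  // completion queue returns once the response and status are available.
  std::unique_ptr<grpc::ClientAsyncResponseReader<grpc::testing::EchoResponse>>
      reader(stub->AsyncEcho(&context, request, &cq));
  grpc::Status status;
  reader->Finish(response, &status, /*tag=*/reinterpret_cast<void*>(1));

  // Block until the single outstanding tag comes back.
  void* got_tag = nullptr;
  bool ok = false;
  GPR_ASSERT(cq.Next(&got_tag, &ok));
  GPR_ASSERT(ok && got_tag == reinterpret_cast<void*>(1));
  return status;
}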
EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncReaderWriter srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr> + cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1))); + + service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq.get()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq.get()); + + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) { + } + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, GenericRPCTest) { + PhonyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + AsyncGenericService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&service); + std::vector> + creators; + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto srv_cq = builder.AddCompletionQueue(); + CompletionQueue cli_cq; + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + GenericStub generic_stub(channel); + + const std::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. 
+ send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + + CompletionQueue* cq = srv_cq.get(); + std::thread request_call([cq]() { Verifier().Expect(4, true).Verify(cq); }); + std::unique_ptr call = + generic_stub.PrepareCall(&cli_ctx, kMethodName, &cli_cq); + call->StartCall(tag(1)); + Verifier().Expect(1, true).Verify(&cli_cq); + std::unique_ptr send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. + send_buffer.reset(); + Verifier().Expect(2, true).Verify(&cli_cq); + call->WritesDone(tag(3)); + Verifier().Expect(3, true).Verify(&cli_cq); + + service.RequestCall(&srv_ctx, &stream, srv_cq.get(), srv_cq.get(), tag(4)); + + request_call.join(); + EXPECT_EQ(kMethodName, srv_ctx.method()); + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + Verifier().Expect(5, true).Verify(srv_cq.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + Verifier().Expect(6, true).Verify(srv_cq.get()); + + stream.Finish(Status::OK, tag(7)); + // Shutdown srv_cq before we try to get the tag back, to verify that the + // interception API handles completion queue shutdowns that take place before + // all the tags are returned + srv_cq->Shutdown(); + Verifier().Expect(7, true).Verify(srv_cq.get()); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + Verifier().Expect(8, true).Verify(&cli_cq); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + cli_cq.Shutdown(); + Verifier().Expect(9, true).Verify(&cli_cq); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cli_cq.Next(&ignored_tag, &ignored_ok)) { + } + while (srv_cq->Next(&ignored_tag, &ignored_ok)) { + } + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnimplementedRpcTest) { + PhonyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector> + creators; + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + std::unique_ptr> response_reader( + stub->AsyncUnimplemented(&cli_ctx, send_request, 
cq.get())); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) { + } + grpc_recycle_unused_port(port); +} + +class ServerInterceptorsSyncUnimplementedEnd2endTest : public ::testing::Test { +}; + +TEST_F(ServerInterceptorsSyncUnimplementedEnd2endTest, UnimplementedRpcTest) { + PhonyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + TestServiceImpl service; + builder.RegisterService(&service); + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector> + creators; + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(absl::make_unique()); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + Status recv_status = + stub->Unimplemented(&cli_ctx, send_request, &recv_response); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 phony interceptors were run + EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + grpc_recycle_unused_port(port); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/server_load_reporting_end2end_test.cc b/test/cpp/end2end/server_load_reporting_end2end_test.cc new file mode 100644 index 00000000..8c83afad --- /dev/null +++ b/test/cpp/end2end/server_load_reporting_end2end_test.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/lb/v1/load_reporter.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +constexpr double kMetricValue = 3.1415; +constexpr char kMetricName[] = "METRIC_PI"; + +// Different messages result in different response statuses. 
For simplicity in +// computing request bytes, the message sizes should be the same. +const char kOkMessage[] = "hello"; +const char kServerErrorMessage[] = "sverr"; +const char kClientErrorMessage[] = "clerr"; + +class EchoTestServiceImpl : public EchoTestService::Service { + public: + ~EchoTestServiceImpl() override {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + if (request->message() == kServerErrorMessage) { + return Status(StatusCode::UNKNOWN, "Server error requested"); + } + if (request->message() == kClientErrorMessage) { + return Status(StatusCode::FAILED_PRECONDITION, "Client error requested"); + } + response->set_message(request->message()); + ::grpc::load_reporter::experimental::AddLoadReportingCost( + context, kMetricName, kMetricValue); + return Status::OK; + } +}; + +class ServerLoadReportingEnd2endTest : public ::testing::Test { + protected: + void SetUp() override { + server_address_ = + "localhost:" + std::to_string(grpc_pick_unused_port_or_die()); + server_ = + ServerBuilder() + .AddListeningPort(server_address_, InsecureServerCredentials()) + .RegisterService(&echo_service_) + .SetOption(std::unique_ptr<::grpc::ServerBuilderOption>( + new ::grpc::load_reporter::experimental:: + LoadReportingServiceServerBuilderOption())) + .BuildAndStart(); + server_thread_ = + std::thread(&ServerLoadReportingEnd2endTest::RunServerLoop, this); + } + + void RunServerLoop() { server_->Wait(); } + + void TearDown() override { + server_->Shutdown(); + server_thread_.join(); + } + + void ClientMakeEchoCalls(const std::string& lb_id, const std::string& lb_tag, + const std::string& message, size_t num_requests) { + auto stub = EchoTestService::NewStub( + grpc::CreateChannel(server_address_, InsecureChannelCredentials())); + std::string lb_token = lb_id + lb_tag; + for (size_t i = 0; i < num_requests; ++i) { + ClientContext ctx; + if (!lb_token.empty()) ctx.AddMetadata(GRPC_LB_TOKEN_MD_KEY, lb_token); + EchoRequest request; + EchoResponse response; + request.set_message(message); + Status status = stub->Echo(&ctx, request, &response); + if (message == kOkMessage) { + ASSERT_EQ(status.error_code(), StatusCode::OK); + ASSERT_EQ(request.message(), response.message()); + } else if (message == kServerErrorMessage) { + ASSERT_EQ(status.error_code(), StatusCode::UNKNOWN); + } else if (message == kClientErrorMessage) { + ASSERT_EQ(status.error_code(), StatusCode::FAILED_PRECONDITION); + } + } + } + + std::string server_address_; + std::unique_ptr server_; + std::thread server_thread_; + EchoTestServiceImpl echo_service_; +}; + +TEST_F(ServerLoadReportingEnd2endTest, NoCall) {} + +TEST_F(ServerLoadReportingEnd2endTest, BasicReport) { + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + auto stub = ::grpc::lb::v1::LoadReporter::NewStub(channel); + ClientContext ctx; + auto stream = stub->ReportLoad(&ctx); + ::grpc::lb::v1::LoadReportRequest request; + request.mutable_initial_request()->set_load_balanced_hostname( + server_address_); + request.mutable_initial_request()->set_load_key("LOAD_KEY"); + request.mutable_initial_request() + ->mutable_load_report_interval() + ->set_seconds(5); + stream->Write(request); + gpr_log(GPR_INFO, "Initial request sent."); + ::grpc::lb::v1::LoadReportResponse response; + stream->Read(&response); + const std::string& lb_id = response.initial_response().load_balancer_id(); + gpr_log(GPR_INFO, "Initial response received (lb_id: %s).", lb_id.c_str()); + ClientMakeEchoCalls(lb_id, 
"LB_TAG", kOkMessage, 1); + while (true) { + stream->Read(&response); + if (!response.load().empty()) { + ASSERT_EQ(response.load().size(), 3); + for (const auto& load : response.load()) { + if (load.in_progress_report_case()) { + // The special load record that reports the number of in-progress + // calls. + ASSERT_EQ(load.num_calls_in_progress(), 1); + } else if (load.orphaned_load_case()) { + // The call from the balancer doesn't have any valid LB token. + ASSERT_EQ(load.orphaned_load_case(), load.kLoadKeyUnknown); + ASSERT_EQ(load.num_calls_started(), 1); + ASSERT_EQ(load.num_calls_finished_without_error(), 0); + ASSERT_EQ(load.num_calls_finished_with_error(), 0); + } else { + // This corresponds to the calls from the client. + ASSERT_EQ(load.num_calls_started(), 1); + ASSERT_EQ(load.num_calls_finished_without_error(), 1); + ASSERT_EQ(load.num_calls_finished_with_error(), 0); + ASSERT_GE(load.total_bytes_received(), sizeof(kOkMessage)); + ASSERT_GE(load.total_bytes_sent(), sizeof(kOkMessage)); + ASSERT_EQ(load.metric_data().size(), 1); + ASSERT_EQ(load.metric_data().Get(0).metric_name(), kMetricName); + ASSERT_EQ(load.metric_data().Get(0).num_calls_finished_with_metric(), + 1); + ASSERT_EQ(load.metric_data().Get(0).total_metric_value(), + kMetricValue); + } + } + break; + } + } + stream->WritesDone(); + ASSERT_EQ(stream->Finish().error_code(), StatusCode::CANCELLED); +} + +// TODO(juanlishen): Add more tests. + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/service_config_end2end_test.cc b/test/cpp/end2end/service_config_end2end_test.cc new file mode 100644 index 00000000..bd64e172 --- /dev/null +++ b/test/cpp/end2end/service_config_end2end_test.cc @@ -0,0 +1,621 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/global_subchannel_pool.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/resolve_localhost_ip46.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { +namespace { + +// Subclass of TestServiceImpl that increments a request counter for +// every call to the Echo RPC. +class MyTestServiceImpl : public TestServiceImpl { + public: + MyTestServiceImpl() : request_count_(0) {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + { + grpc::internal::MutexLock lock(&mu_); + ++request_count_; + } + AddClient(context->peer()); + return TestServiceImpl::Echo(context, request, response); + } + + int request_count() { + grpc::internal::MutexLock lock(&mu_); + return request_count_; + } + + void ResetCounters() { + grpc::internal::MutexLock lock(&mu_); + request_count_ = 0; + } + + std::set clients() { + grpc::internal::MutexLock lock(&clients_mu_); + return clients_; + } + + private: + void AddClient(const std::string& client) { + grpc::internal::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + grpc::internal::Mutex mu_; + int request_count_; + grpc::internal::Mutex clients_mu_; + std::set clients_; +}; + +class ServiceConfigEnd2endTest : public ::testing::Test { + protected: + ServiceConfigEnd2endTest() + : server_host_("localhost"), + kRequestMessage_("Live long and prosper."), + creds_(new SecureChannelCredentials( + grpc_fake_transport_security_credentials_create())) {} + + static void SetUpTestCase() { + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. 
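+    // A process-wide override applied via GPR_GLOBAL_CONFIG_SET. For
+    // reference, the same knob is normally read from an environment variable
+    // (assuming the usual GPR_GLOBAL_CONFIG naming convention), so an
+    // equivalent sketch, if applied before grpc_init(), would be:
+    //
+    //   gpr_setenv("GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS", "1");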
+ GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); + } + + void SetUp() override { + grpc_init(); + response_generator_ = + grpc_core::MakeRefCounted(); + bool localhost_resolves_to_ipv4 = false; + bool localhost_resolves_to_ipv6 = false; + grpc_core::LocalhostResolves(&localhost_resolves_to_ipv4, + &localhost_resolves_to_ipv6); + ipv6_only_ = !localhost_resolves_to_ipv4 && localhost_resolves_to_ipv6; + } + + void TearDown() override { + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + // Explicitly destroy all the members so that we can make sure grpc_shutdown + // has finished by the end of this function, and thus all the registered + // LB policy factories are removed. + stub_.reset(); + servers_.clear(); + creds_.reset(); + grpc_shutdown(); + } + + void CreateServers(size_t num_servers, + std::vector ports = std::vector()) { + servers_.clear(); + for (size_t i = 0; i < num_servers; ++i) { + int port = 0; + if (ports.size() == num_servers) port = ports[i]; + servers_.emplace_back(new ServerData(port)); + } + } + + void StartServer(size_t index) { servers_[index]->Start(server_host_); } + + void StartServers(size_t num_servers, + std::vector ports = std::vector()) { + CreateServers(num_servers, std::move(ports)); + for (size_t i = 0; i < num_servers; ++i) { + StartServer(i); + } + } + + grpc_core::Resolver::Result BuildFakeResults(const std::vector& ports) { + grpc_core::Resolver::Result result; + for (const int& port : ports) { + std::string lb_uri_str = + absl::StrCat(ipv6_only_ ? "ipv6:[::1]:" : "ipv4:127.0.0.1:", port); + absl::StatusOr lb_uri = grpc_core::URI::Parse(lb_uri_str); + GPR_ASSERT(lb_uri.ok()); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*lb_uri, &address)); + result.addresses.emplace_back(address.addr, address.len, + nullptr /* args */); + } + return result; + } + + void SetNextResolutionNoServiceConfig(const std::vector& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + response_generator_->SetResponse(result); + } + + void SetNextResolutionValidServiceConfig(const std::vector& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, "{}", &result.service_config_error); + response_generator_->SetResponse(result); + } + + void SetNextResolutionInvalidServiceConfig(const std::vector& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, "{", &result.service_config_error); + response_generator_->SetResponse(result); + } + + void SetNextResolutionWithServiceConfig(const std::vector& ports, + const char* svc_cfg) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, svc_cfg, &result.service_config_error); + response_generator_->SetResponse(result); + } + + std::vector GetServersPorts(size_t start_index = 0) { + std::vector ports; + for (size_t i = start_index; i < servers_.size(); ++i) { + ports.push_back(servers_[i]->port_); + } + return ports; + } + + std::unique_ptr BuildStub( + const std::shared_ptr& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr BuildChannel() { + ChannelArguments args; + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + 
response_generator_.get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + std::shared_ptr BuildChannelWithDefaultServiceConfig() { + ChannelArguments args; + EXPECT_THAT(grpc::experimental::ValidateServiceConfigJSON( + ValidDefaultServiceConfig()), + ::testing::StrEq("")); + args.SetServiceConfigJSON(ValidDefaultServiceConfig()); + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + std::shared_ptr BuildChannelWithInvalidDefaultServiceConfig() { + ChannelArguments args; + EXPECT_THAT(grpc::experimental::ValidateServiceConfigJSON( + InvalidDefaultServiceConfig()), + ::testing::HasSubstr("JSON parse error")); + args.SetServiceConfigJSON(InvalidDefaultServiceConfig()); + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + bool SendRpc( + const std::unique_ptr& stub, + EchoResponse* response = nullptr, int timeout_ms = 1000, + Status* result = nullptr, bool wait_for_ready = false) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + EchoRequest request; + request.set_message(kRequestMessage_); + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + if (wait_for_ready) context.set_wait_for_ready(true); + Status status = stub->Echo(&context, request, response); + if (result != nullptr) *result = status; + if (local_response) delete response; + return status.ok(); + } + + void CheckRpcSendOk( + const std::unique_ptr& stub, + const grpc_core::DebugLocation& location, bool wait_for_ready = false) { + EchoResponse response; + Status status; + const bool success = + SendRpc(stub, &response, 2000, &status, wait_for_ready); + ASSERT_TRUE(success) << "From " << location.file() << ":" << location.line() + << "\n" + << "Error: " << status.error_message() << " " + << status.error_details(); + ASSERT_EQ(response.message(), kRequestMessage_) + << "From " << location.file() << ":" << location.line(); + if (!success) abort(); + } + + void CheckRpcSendFailure( + const std::unique_ptr& stub) { + const bool success = SendRpc(stub); + EXPECT_FALSE(success); + } + + struct ServerData { + const int port_; + std::unique_ptr server_; + MyTestServiceImpl service_; + std::unique_ptr thread_; + + grpc::internal::Mutex mu_; + grpc::internal::CondVar cond_; + bool server_ready_ ABSL_GUARDED_BY(mu_) = false; + bool started_ ABSL_GUARDED_BY(mu_) = false; + + explicit ServerData(int port = 0) + : port_(port > 0 ? 
port : grpc_pick_unused_port_or_die()) {} + + void Start(const std::string& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + grpc::internal::MutexLock lock(&mu_); + started_ = true; + thread_ = absl::make_unique( + std::bind(&ServerData::Serve, this, server_host)); + while (!server_ready_) { + cond_.Wait(&mu_); + } + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const std::string& server_host) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + std::shared_ptr creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), std::move(creds)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + grpc::internal::MutexLock lock(&mu_); + server_ready_ = true; + cond_.Signal(); + } + + void Shutdown() { + grpc::internal::MutexLock lock(&mu_); + if (!started_) return; + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + started_ = false; + } + + void SetServingStatus(const std::string& service, bool serving) { + server_->GetHealthCheckService()->SetServingStatus(service, serving); + } + }; + + void ResetCounters() { + for (const auto& server : servers_) server->service_.ResetCounters(); + } + + void WaitForServer( + const std::unique_ptr& stub, + size_t server_idx, const grpc_core::DebugLocation& location, + bool ignore_failure = false) { + do { + if (ignore_failure) { + SendRpc(stub); + } else { + CheckRpcSendOk(stub, location, true); + } + } while (servers_[server_idx]->service_.request_count() == 0); + ResetCounters(); + } + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(false /* try_to_connect */)) == + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(true /* try_to_connect */)) != + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool SeenAllServers() { + for (const auto& server : servers_) { + if (server->service_.request_count() == 0) return false; + } + return true; + } + + // Updates \a connection_order by appending to it the index of the newly + // connected server. Must be called after every single RPC. + void UpdateConnectionOrder( + const std::vector>& servers, + std::vector* connection_order) { + for (size_t i = 0; i < servers.size(); ++i) { + if (servers[i]->service_.request_count() == 1) { + // Was the server index known? If not, update connection_order. 
+ const auto it = + std::find(connection_order->begin(), connection_order->end(), i); + if (it == connection_order->end()) { + connection_order->push_back(i); + return; + } + } + } + } + + const char* ValidServiceConfigV1() { return "{\"version\": \"1\"}"; } + + const char* ValidServiceConfigV2() { return "{\"version\": \"2\"}"; } + + const char* ValidDefaultServiceConfig() { + return "{\"version\": \"valid_default\"}"; + } + + const char* InvalidDefaultServiceConfig() { + return "{\"version\": \"invalid_default\""; + } + + bool ipv6_only_ = false; + const std::string server_host_; + std::unique_ptr stub_; + std::vector> servers_; + grpc_core::RefCountedPtr + response_generator_; + const std::string kRequestMessage_; + std::shared_ptr creds_; +}; + +TEST_F(ServiceConfigEnd2endTest, NoServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, NoServiceConfigWithDefaultConfigTest) { + StartServers(1); + auto channel = BuildChannelWithDefaultServiceConfig(); + auto stub = BuildStub(channel); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidDefaultServiceConfig(), + channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, InvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, ValidServiceConfigUpdatesTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV2()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV2(), channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + NoServiceConfigUpdateAfterValidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + NoServiceConfigUpdateAfterValidServiceConfigWithDefaultConfigTest) { + StartServers(1); + auto channel = BuildChannelWithDefaultServiceConfig(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidDefaultServiceConfig(), + channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidServiceConfigUpdateAfterValidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = 
BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidServiceConfigUpdateAfterValidServiceConfigWithDefaultConfigTest) { + StartServers(1); + auto channel = BuildChannelWithDefaultServiceConfig(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + ValidServiceConfigAfterInvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); + SetNextResolutionValidServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); +} + +TEST_F(ServiceConfigEnd2endTest, NoServiceConfigAfterInvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + AnotherInvalidServiceConfigAfterInvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, InvalidDefaultServiceConfigTest) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + // An invalid default service config results in a lame channel which fails all + // RPCs + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidDefaultServiceConfigTestWithValidServiceConfig) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + CheckRpcSendFailure(stub); + // An invalid default service config results in a lame channel which fails all + // RPCs + SetNextResolutionValidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidDefaultServiceConfigTestWithInvalidServiceConfig) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + CheckRpcSendFailure(stub); + // An invalid default service config results in a lame channel which fails all + // RPCs + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidDefaultServiceConfigTestWithNoServiceConfig) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + 
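+  // The configs exercised throughout these tests are minimal placeholders of
+  // the form "{\"version\": ...}". For reference, a more realistic per-method
+  // service config passed to SetServiceConfigJSON() (or returned by the
+  // resolver) might look like the following sketch; the field values are
+  // illustrative only and not used by this test:
+  //
+  //   args.SetServiceConfigJSON(
+  //       "{\"methodConfig\": [{"
+  //       "\"name\": [{\"service\": \"grpc.testing.EchoTestService\"}],"
+  //       "\"waitForReady\": true,"
+  //       "\"timeout\": \"5s\"}]}");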
CheckRpcSendFailure(stub); + // An invalid default service config results in a lame channel which fails all + // RPCs + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/end2end/shutdown_test.cc b/test/cpp/end2end/shutdown_test.cc new file mode 100644 index 00000000..384b2e8b --- /dev/null +++ b/test/cpp/end2end/shutdown_test.cc @@ -0,0 +1,170 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + explicit TestServiceImpl(gpr_event* ev) : ev_(ev) {} + + Status Echo(ServerContext* context, const EchoRequest* /*request*/, + EchoResponse* /*response*/) override { + gpr_event_set(ev_, reinterpret_cast(1)); + while (!context->IsCancelled()) { + } + return Status::OK; + } + + private: + gpr_event* ev_; +}; + +class ShutdownTest : public ::testing::TestWithParam { + public: + ShutdownTest() : shutdown_(false), service_(&ev_) { gpr_event_init(&ev_); } + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_ = SetUpServer(port_); + } + + std::unique_ptr SetUpServer(const int port) { + std::string server_address = "localhost:" + to_string(port); + + ServerBuilder builder; + auto server_creds = + GetCredentialsProvider()->GetServerCredentials(GetParam()); + builder.AddListeningPort(server_address, server_creds); + builder.RegisterService(&service_); + std::unique_ptr server = builder.BuildAndStart(); + return server; + } + + void TearDown() override { GPR_ASSERT(shutdown_); } + + void ResetStub() { + string target = "dns:localhost:" + to_string(port_); + ChannelArguments args; + auto channel_creds = + GetCredentialsProvider()->GetChannelCredentials(GetParam(), &args); + channel_ = ::grpc::CreateCustomChannel(target, channel_creds, args); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + void SendRequest() { + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + ClientContext context; + GPR_ASSERT(!shutdown_); + Status s = stub_->Echo(&context, request, &response); + GPR_ASSERT(shutdown_); + } + + protected: + std::shared_ptr channel_; + 
std::unique_ptr stub_; + std::unique_ptr server_; + bool shutdown_; + int port_; + gpr_event ev_; + TestServiceImpl service_; +}; + +std::vector GetAllCredentialsTypeList() { + std::vector credentials_types; + if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType, + nullptr) != nullptr) { + credentials_types.push_back(kInsecureCredentialsType); + } + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + GPR_ASSERT(!credentials_types.empty()); + + std::string credentials_type_list("credentials types:"); + for (const string& type : credentials_types) { + credentials_type_list.append(" " + type); + } + gpr_log(GPR_INFO, "%s", credentials_type_list.c_str()); + return credentials_types; +} + +INSTANTIATE_TEST_SUITE_P(End2EndShutdown, ShutdownTest, + ::testing::ValuesIn(GetAllCredentialsTypeList())); + +// TODO(ctiller): leaked objects in this test +TEST_P(ShutdownTest, ShutdownTest) { + ResetStub(); + + // send the request in a background thread + std::thread thr(std::bind(&ShutdownTest::SendRequest, this)); + + // wait for the server to get the event + gpr_event_wait(&ev_, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + + shutdown_ = true; + + // shutdown should trigger cancellation causing everything to shutdown + auto deadline = + std::chrono::system_clock::now() + std::chrono::microseconds(100); + server_->Shutdown(deadline); + EXPECT_GE(std::chrono::system_clock::now(), deadline); + + thr.join(); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/streaming_throughput_test.cc b/test/cpp/end2end/streaming_throughput_test.cc new file mode 100644 index 00000000..41721957 --- /dev/null +++ b/test/cpp/end2end/streaming_throughput_test.cc @@ -0,0 +1,193 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +const char* kLargeString = + "(" + "To be, or not to be- that is the question:" + "Whether 'tis nobler in the mind to suffer" + "The slings and arrows of outrageous fortune" + "Or to take arms against a sea of troubles," + "And by opposing end them. To die- to sleep-" + "No more; and by a sleep to say we end" + "The heartache, and the thousand natural shock" + "That flesh is heir to. 'Tis a consummation" + "Devoutly to be wish'd. To die- to sleep." 
+ "To sleep- perchance to dream: ay, there's the rub!" + "For in that sleep of death what dreams may come" + "When we have shuffled off this mortal coil," + "Must give us pause. There's the respect" + "That makes calamity of so long life." + "For who would bear the whips and scorns of time," + "Th' oppressor's wrong, the proud man's contumely," + "The pangs of despis'd love, the law's delay," + "The insolence of office, and the spurns" + "That patient merit of th' unworthy takes," + "When he himself might his quietus make" + "With a bare bodkin? Who would these fardels bear," + "To grunt and sweat under a weary life," + "But that the dread of something after death-" + "The undiscover'd country, from whose bourn" + "No traveller returns- puzzles the will," + "And makes us rather bear those ills we have" + "Than fly to others that we know not of?" + "Thus conscience does make cowards of us all," + "And thus the native hue of resolution" + "Is sicklied o'er with the pale cast of thought," + "And enterprises of great pith and moment" + "With this regard their currents turn awry" + "And lose the name of action.- Soft you now!" + "The fair Ophelia!- Nymph, in thy orisons" + "Be all my sins rememb'red."; + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + static void BidiStream_Sender( + ServerReaderWriter* stream, + gpr_atm* should_exit) { + EchoResponse response; + response.set_message(kLargeString); + while (gpr_atm_acq_load(should_exit) == static_cast(0)) { + struct timespec tv = {0, 1000000}; // 1 ms + struct timespec rem; + // TODO (vpai): Mark this blocking + while (nanosleep(&tv, &rem) != 0) { + tv = rem; + }; + + stream->Write(response); + } + } + + // Only implement the one method we will be calling for brevity. + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter* stream) override { + EchoRequest request; + gpr_atm should_exit; + gpr_atm_rel_store(&should_exit, static_cast(0)); + + std::thread sender( + std::bind(&TestServiceImpl::BidiStream_Sender, stream, &should_exit)); + + while (stream->Read(&request)) { + struct timespec tv = {0, 3000000}; // 3 ms + struct timespec rem; + // TODO (vpai): Mark this blocking + while (nanosleep(&tv, &rem) != 0) { + tv = rem; + }; + } + gpr_atm_rel_store(&should_exit, static_cast(1)); + sender.join(); + return Status::OK; + } +}; + +class End2endTest : public ::testing::Test { + protected: + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr stub_; + std::unique_ptr server_; + std::ostringstream server_address_; + TestServiceImpl service_; +}; + +static void Drainer(ClientReaderWriter* reader) { + EchoResponse response; + while (reader->Read(&response)) { + // Just drain out the responses as fast as possible. 
+ } +} + +TEST_F(End2endTest, StreamingThroughput) { + ResetStub(); + grpc::ClientContext context; + auto stream = stub_->BidiStream(&context); + + auto reader = stream.get(); + std::thread receiver(std::bind(Drainer, reader)); + + for (int i = 0; i < 10000; i++) { + EchoRequest request; + request.set_message(kLargeString); + ASSERT_TRUE(stream->Write(request)); + if (i % 1000 == 0) { + gpr_log(GPR_INFO, "Send count = %d", i); + } + } + stream->WritesDone(); + receiver.join(); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/test_health_check_service_impl.cc b/test/cpp/end2end/test_health_check_service_impl.cc new file mode 100644 index 00000000..957538eb --- /dev/null +++ b/test/cpp/end2end/test_health_check_service_impl.cc @@ -0,0 +1,98 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/end2end/test_health_check_service_impl.h" + +#include + +using grpc::health::v1::HealthCheckRequest; +using grpc::health::v1::HealthCheckResponse; + +namespace grpc { +namespace testing { + +Status HealthCheckServiceImpl::Check(ServerContext* /*context*/, + const HealthCheckRequest* request, + HealthCheckResponse* response) { + std::lock_guard lock(mu_); + auto iter = status_map_.find(request->service()); + if (iter == status_map_.end()) { + return Status(StatusCode::NOT_FOUND, ""); + } + response->set_status(iter->second); + return Status::OK; +} + +Status HealthCheckServiceImpl::Watch( + ServerContext* context, const HealthCheckRequest* request, + ::grpc::ServerWriter* writer) { + auto last_state = HealthCheckResponse::UNKNOWN; + while (!context->IsCancelled()) { + { + std::lock_guard lock(mu_); + HealthCheckResponse response; + auto iter = status_map_.find(request->service()); + if (iter == status_map_.end()) { + response.set_status(response.SERVICE_UNKNOWN); + } else { + response.set_status(iter->second); + } + if (response.status() != last_state) { + writer->Write(response, ::grpc::WriteOptions()); + last_state = response.status(); + } + } + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(1000, GPR_TIMESPAN))); + } + return Status::OK; +} + +void HealthCheckServiceImpl::SetStatus( + const std::string& service_name, + HealthCheckResponse::ServingStatus status) { + std::lock_guard lock(mu_); + if (shutdown_) { + status = HealthCheckResponse::NOT_SERVING; + } + status_map_[service_name] = status; +} + +void HealthCheckServiceImpl::SetAll(HealthCheckResponse::ServingStatus status) { + std::lock_guard lock(mu_); + if (shutdown_) { + return; + } + for (auto iter = status_map_.begin(); iter != status_map_.end(); ++iter) { + iter->second = status; + } +} + +void HealthCheckServiceImpl::Shutdown() { + std::lock_guard lock(mu_); + if (shutdown_) { + return; + } + shutdown_ = true; + for (auto iter = 
status_map_.begin(); iter != status_map_.end(); ++iter) { + iter->second = HealthCheckResponse::NOT_SERVING; + } +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/end2end/test_service_impl.cc b/test/cpp/end2end/test_service_impl.cc new file mode 100644 index 00000000..6bc18e1a --- /dev/null +++ b/test/cpp/end2end/test_service_impl.cc @@ -0,0 +1,635 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/end2end/test_service_impl.h" + +#include +#include + +#include + +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/string_ref_helper.h" + +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace internal { + +// When echo_deadline is requested, deadline seen in the ServerContext is set in +// the response in seconds. +void MaybeEchoDeadline(ServerContextBase* context, const EchoRequest* request, + EchoResponse* response) { + if (request->has_param() && request->param().echo_deadline()) { + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_REALTIME); + if (context->deadline() != system_clock::time_point::max()) { + Timepoint2Timespec(context->deadline(), &deadline); + } + response->mutable_param()->set_request_deadline(deadline.tv_sec); + } +} + +void CheckServerAuthContext(const ServerContextBase* context, + const std::string& expected_transport_security_type, + const std::string& expected_client_identity) { + std::shared_ptr auth_ctx = context->auth_context(); + std::vector tst = + auth_ctx->FindPropertyValues("transport_security_type"); + EXPECT_EQ(1u, tst.size()); + EXPECT_EQ(expected_transport_security_type, ToString(tst[0])); + if (expected_client_identity.empty()) { + EXPECT_TRUE(auth_ctx->GetPeerIdentityPropertyName().empty()); + EXPECT_TRUE(auth_ctx->GetPeerIdentity().empty()); + EXPECT_FALSE(auth_ctx->IsPeerAuthenticated()); + } else { + auto identity = auth_ctx->GetPeerIdentity(); + EXPECT_TRUE(auth_ctx->IsPeerAuthenticated()); + EXPECT_EQ(1u, identity.size()); + EXPECT_EQ(expected_client_identity, identity[0]); + } +} + +// Returns the number of pairs in metadata that exactly match the given +// key-value pair. Returns -1 if the pair wasn't found. 
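+// For example, a handler that echoes client metadata can assert on it roughly
+// like this (a sketch; the key/value literals are illustrative):
+//
+//   EXPECT_EQ(internal::MetadataMatchCount(ctx->client_metadata(),
+//                                          "custom-key", "custom-value"),
+//             1);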
+int MetadataMatchCount( + const std::multimap& metadata, + const std::string& key, const std::string& value) { + int count = 0; + for (const auto& metadatum : metadata) { + if (ToString(metadatum.first) == key && + ToString(metadatum.second) == value) { + count++; + } + } + return count; +} + +int GetIntValueFromMetadataHelper( + const char* key, + const std::multimap& metadata, + int default_value) { + if (metadata.find(key) != metadata.end()) { + std::istringstream iss(ToString(metadata.find(key)->second)); + iss >> default_value; + gpr_log(GPR_INFO, "%s : %d", key, default_value); + } + + return default_value; +} + +int GetIntValueFromMetadata( + const char* key, + const std::multimap& metadata, + int default_value) { + return GetIntValueFromMetadataHelper(key, metadata, default_value); +} + +void ServerTryCancel(ServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); + // Now wait until it's really canceled + while (!context->IsCancelled()) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1000, GPR_TIMESPAN))); + } +} + +void ServerTryCancelNonblocking(CallbackServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, + "Server called TryCancelNonblocking() to cancel the request"); +} + +} // namespace internal + +ServerUnaryReactor* CallbackTestServiceImpl::Echo( + CallbackServerContext* context, const EchoRequest* request, + EchoResponse* response) { + class Reactor : public ::grpc::ServerUnaryReactor { + public: + Reactor(CallbackTestServiceImpl* service, CallbackServerContext* ctx, + const EchoRequest* request, EchoResponse* response) + : service_(service), ctx_(ctx), req_(request), resp_(response) { + // It should be safe to call IsCancelled here, even though we don't know + // the result. Call it asynchronously to see if we trigger any data races. + // Join it in OnDone (technically that could be blocking but shouldn't be + // for very long). + async_cancel_check_ = std::thread([this] { (void)ctx_->IsCancelled(); }); + + started_ = true; + + if (request->has_param() && + request->param().server_notify_client_when_started()) { + service->signaller_.SignalClientThatRpcStarted(); + // Block on the "wait to continue" decision in a different thread since + // we can't tie up an EM thread with blocking events. We can join it in + // OnDone since it would definitely be done by then. 
+ rpc_wait_thread_ = std::thread([this] { + service_->signaller_.ServerWaitToContinue(); + StartRpc(); + }); + } else { + StartRpc(); + } + } + + void StartRpc() { + if (req_->has_param() && req_->param().server_sleep_us() > 0) { + // Set an alarm for that much time + alarm_.Set( + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(req_->param().server_sleep_us(), + GPR_TIMESPAN)), + [this](bool ok) { NonDelayed(ok); }); + return; + } + NonDelayed(true); + } + void OnSendInitialMetadataDone(bool ok) override { + EXPECT_TRUE(ok); + initial_metadata_sent_ = true; + } + void OnCancel() override { + EXPECT_TRUE(started_); + EXPECT_TRUE(ctx_->IsCancelled()); + on_cancel_invoked_ = true; + std::lock_guard l(cancel_mu_); + cancel_cv_.notify_one(); + } + void OnDone() override { + if (req_->has_param() && req_->param().echo_metadata_initially()) { + EXPECT_TRUE(initial_metadata_sent_); + } + EXPECT_EQ(ctx_->IsCancelled(), on_cancel_invoked_); + // Validate that finishing with a non-OK status doesn't cause cancellation + if (req_->has_param() && req_->param().has_expected_error()) { + EXPECT_FALSE(on_cancel_invoked_); + } + async_cancel_check_.join(); + if (rpc_wait_thread_.joinable()) { + rpc_wait_thread_.join(); + } + if (finish_when_cancelled_.joinable()) { + finish_when_cancelled_.join(); + } + delete this; + } + + private: + void NonDelayed(bool ok) { + if (!ok) { + EXPECT_TRUE(ctx_->IsCancelled()); + Finish(Status::CANCELLED); + return; + } + if (req_->has_param() && req_->param().server_die()) { + gpr_log(GPR_ERROR, "The request should not reach application handler."); + GPR_ASSERT(0); + } + if (req_->has_param() && req_->param().has_expected_error()) { + const auto& error = req_->param().expected_error(); + Finish(Status(static_cast(error.code()), + error.error_message(), error.binary_error_details())); + return; + } + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, ctx_->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel != DO_NOT_CANCEL) { + // Since this is a unary RPC, by the time this server handler is called, + // the 'request' message is already read from the client. So the + // scenarios in server_try_cancel don't make much sense. 
Just cancel the + // RPC as long as server_try_cancel is not DO_NOT_CANCEL + EXPECT_FALSE(ctx_->IsCancelled()); + ctx_->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); + FinishWhenCancelledAsync(); + return; + } + resp_->set_message(req_->message()); + internal::MaybeEchoDeadline(ctx_, req_, resp_); + if (service_->host_) { + resp_->mutable_param()->set_host(*service_->host_); + } + if (req_->has_param() && req_->param().client_cancel_after_us()) { + { + std::unique_lock lock(service_->mu_); + service_->signal_client_ = true; + } + FinishWhenCancelledAsync(); + return; + } else if (req_->has_param() && req_->param().server_cancel_after_us()) { + alarm_.Set(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros( + req_->param().server_cancel_after_us(), + GPR_TIMESPAN)), + [this](bool) { Finish(Status::CANCELLED); }); + return; + } else if (!req_->has_param() || !req_->param().skip_cancelled_check()) { + EXPECT_FALSE(ctx_->IsCancelled()); + } + + if (req_->has_param() && req_->param().echo_metadata_initially()) { + const std::multimap& + client_metadata = ctx_->client_metadata(); + for (const auto& metadatum : client_metadata) { + ctx_->AddInitialMetadata(ToString(metadatum.first), + ToString(metadatum.second)); + } + StartSendInitialMetadata(); + } + + if (req_->has_param() && req_->param().echo_metadata()) { + const std::multimap& + client_metadata = ctx_->client_metadata(); + for (const auto& metadatum : client_metadata) { + ctx_->AddTrailingMetadata(ToString(metadatum.first), + ToString(metadatum.second)); + } + // Terminate rpc with error and debug info in trailer. + if (req_->param().debug_info().stack_entries_size() || + !req_->param().debug_info().detail().empty()) { + std::string serialized_debug_info = + req_->param().debug_info().SerializeAsString(); + ctx_->AddTrailingMetadata(kDebugInfoTrailerKey, + serialized_debug_info); + Finish(Status::CANCELLED); + return; + } + } + if (req_->has_param() && + (req_->param().expected_client_identity().length() > 0 || + req_->param().check_auth_context())) { + internal::CheckServerAuthContext( + ctx_, req_->param().expected_transport_security_type(), + req_->param().expected_client_identity()); + } + if (req_->has_param() && req_->param().response_message_length() > 0) { + resp_->set_message( + std::string(req_->param().response_message_length(), '\0')); + } + if (req_->has_param() && req_->param().echo_peer()) { + resp_->mutable_param()->set_peer(ctx_->peer()); + } + Finish(Status::OK); + } + void FinishWhenCancelledAsync() { + finish_when_cancelled_ = std::thread([this] { + std::unique_lock l(cancel_mu_); + cancel_cv_.wait(l, [this] { return ctx_->IsCancelled(); }); + Finish(Status::CANCELLED); + }); + } + + CallbackTestServiceImpl* const service_; + CallbackServerContext* const ctx_; + const EchoRequest* const req_; + EchoResponse* const resp_; + Alarm alarm_; + std::mutex cancel_mu_; + std::condition_variable cancel_cv_; + bool initial_metadata_sent_ = false; + bool started_ = false; + bool on_cancel_invoked_ = false; + std::thread async_cancel_check_; + std::thread rpc_wait_thread_; + std::thread finish_when_cancelled_; + }; + + return new Reactor(this, context, request, response); +} + +ServerUnaryReactor* CallbackTestServiceImpl::CheckClientInitialMetadata( + CallbackServerContext* context, const SimpleRequest*, SimpleResponse*) { + class Reactor : public ::grpc::ServerUnaryReactor { + public: + explicit Reactor(CallbackServerContext* ctx) { + 
EXPECT_EQ(internal::MetadataMatchCount(ctx->client_metadata(), + kCheckClientInitialMetadataKey, + kCheckClientInitialMetadataVal), + 1); + EXPECT_EQ(ctx->client_metadata().count(kCheckClientInitialMetadataKey), + 1u); + Finish(Status::OK); + } + void OnDone() override { delete this; } + }; + + return new Reactor(context); +} + +ServerReadReactor* CallbackTestServiceImpl::RequestStream( + CallbackServerContext* context, EchoResponse* response) { + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancelNonblocking(context); + // Don't need to provide a reactor since the RPC is canceled + return nullptr; + } + + class Reactor : public ::grpc::ServerReadReactor { + public: + Reactor(CallbackServerContext* ctx, EchoResponse* response, + int server_try_cancel) + : ctx_(ctx), + response_(response), + server_try_cancel_(server_try_cancel) { + EXPECT_NE(server_try_cancel, CANCEL_BEFORE_PROCESSING); + response->set_message(""); + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx->TryCancel(); + // Don't wait for it here + } + StartRead(&request_); + setup_done_ = true; + } + void OnDone() override { delete this; } + void OnCancel() override { + EXPECT_TRUE(setup_done_); + EXPECT_TRUE(ctx_->IsCancelled()); + FinishOnce(Status::CANCELLED); + } + void OnReadDone(bool ok) override { + if (ok) { + response_->mutable_message()->append(request_.message()); + num_msgs_read_++; + StartRead(&request_); + } else { + gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read_); + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel recover this + return; + } + if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx_); + return; + } + FinishOnce(Status::OK); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard l(finish_mu_); + if (!finished_) { + Finish(s); + finished_ = true; + } + } + + CallbackServerContext* const ctx_; + EchoResponse* const response_; + EchoRequest request_; + int num_msgs_read_{0}; + int server_try_cancel_; + std::mutex finish_mu_; + bool finished_{false}; + bool setup_done_{false}; + }; + + return new Reactor(context, response, server_try_cancel); +} + +// Return 'kNumResponseStreamMsgs' messages. 
+// TODO(yangg) make it generic by adding a parameter into EchoRequest +ServerWriteReactor* CallbackTestServiceImpl::ResponseStream( + CallbackServerContext* context, const EchoRequest* request) { + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancelNonblocking(context); + } + + class Reactor : public ::grpc::ServerWriteReactor { + public: + Reactor(CallbackServerContext* ctx, const EchoRequest* request, + int server_try_cancel) + : ctx_(ctx), request_(request), server_try_cancel_(server_try_cancel) { + server_coalescing_api_ = internal::GetIntValueFromMetadata( + kServerUseCoalescingApi, ctx->client_metadata(), 0); + server_responses_to_send_ = internal::GetIntValueFromMetadata( + kServerResponseStreamsToSend, ctx->client_metadata(), + kServerDefaultResponseStreamsToSend); + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx->TryCancel(); + } + if (server_try_cancel_ != CANCEL_BEFORE_PROCESSING) { + if (num_msgs_sent_ < server_responses_to_send_) { + NextWrite(); + } + } + setup_done_ = true; + } + void OnDone() override { delete this; } + void OnCancel() override { + EXPECT_TRUE(setup_done_); + EXPECT_TRUE(ctx_->IsCancelled()); + FinishOnce(Status::CANCELLED); + } + void OnWriteDone(bool /*ok*/) override { + if (num_msgs_sent_ < server_responses_to_send_) { + NextWrite(); + } else if (server_coalescing_api_ != 0) { + // We would have already done Finish just after the WriteLast + } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel recover this + } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx_); + } else { + FinishOnce(Status::OK); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard l(finish_mu_); + if (!finished_) { + Finish(s); + finished_ = true; + } + } + + void NextWrite() { + response_.set_message(request_->message() + + std::to_string(num_msgs_sent_)); + if (num_msgs_sent_ == server_responses_to_send_ - 1 && + server_coalescing_api_ != 0) { + { + std::lock_guard l(finish_mu_); + if (!finished_) { + num_msgs_sent_++; + StartWriteLast(&response_, WriteOptions()); + } + } + // If we use WriteLast, we shouldn't wait before attempting Finish + FinishOnce(Status::OK); + } else { + std::lock_guard l(finish_mu_); + if (!finished_) { + num_msgs_sent_++; + StartWrite(&response_); + } + } + } + CallbackServerContext* const ctx_; + const EchoRequest* const request_; + EchoResponse response_; + int num_msgs_sent_{0}; + int server_try_cancel_; + int server_coalescing_api_; + int server_responses_to_send_; + std::mutex finish_mu_; + bool finished_{false}; + bool setup_done_{false}; + }; + return new Reactor(context, request, server_try_cancel); +} + +ServerBidiReactor* +CallbackTestServiceImpl::BidiStream(CallbackServerContext* context) { + class Reactor : public ::grpc::ServerBidiReactor { + public: + explicit Reactor(CallbackServerContext* ctx) : 
ctx_(ctx) { + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + server_try_cancel_ = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, ctx->client_metadata(), DO_NOT_CANCEL); + server_write_last_ = internal::GetIntValueFromMetadata( + kServerFinishAfterNReads, ctx->client_metadata(), 0); + client_try_cancel_ = static_cast(internal::GetIntValueFromMetadata( + kClientTryCancelRequest, ctx->client_metadata(), 0)); + if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx); + } else { + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx->TryCancel(); + } + StartRead(&request_); + } + setup_done_ = true; + } + void OnDone() override { + { + // Use the same lock as finish to make sure that OnDone isn't inlined. + std::lock_guard l(finish_mu_); + EXPECT_TRUE(finished_); + finish_thread_.join(); + } + delete this; + } + void OnCancel() override { + EXPECT_TRUE(setup_done_); + EXPECT_TRUE(ctx_->IsCancelled()); + FinishOnce(Status::CANCELLED); + } + void OnReadDone(bool ok) override { + if (ok) { + num_msgs_read_++; + response_.set_message(request_.message()); + std::lock_guard l(finish_mu_); + if (!finished_) { + if (num_msgs_read_ == server_write_last_) { + StartWriteLast(&response_, WriteOptions()); + // If we use WriteLast, we shouldn't wait before attempting Finish + } else { + StartWrite(&response_); + return; + } + } + } else if (client_try_cancel_) { + EXPECT_TRUE(ctx_->IsCancelled()); + } + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel handle this + } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx_); + } else { + FinishOnce(Status::OK); + } + } + void OnWriteDone(bool /*ok*/) override { + std::lock_guard l(finish_mu_); + if (!finished_) { + StartRead(&request_); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard l(finish_mu_); + if (!finished_) { + finished_ = true; + // Finish asynchronously to make sure that there are no deadlocks. + finish_thread_ = std::thread([this, s] { + std::lock_guard l(finish_mu_); + Finish(s); + }); + } + } + + CallbackServerContext* const ctx_; + EchoRequest request_; + EchoResponse response_; + int num_msgs_read_{0}; + int server_try_cancel_; + int server_write_last_; + std::mutex finish_mu_; + bool finished_{false}; + bool setup_done_{false}; + std::thread finish_thread_; + bool client_try_cancel_ = false; + }; + + return new Reactor(context); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/end2end/thread_stress_test.cc b/test/cpp/end2end/thread_stress_test.cc new file mode 100644 index 00000000..8b5170b9 --- /dev/null +++ b/test/cpp/end2end/thread_stress_test.cc @@ -0,0 +1,440 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +const int kNumThreads = 100; // Number of threads +const int kNumAsyncSendThreads = 2; +const int kNumAsyncReceiveThreads = 50; +const int kNumAsyncServerThreads = 50; +const int kNumRpcs = 1000; // Number of RPCs per thread + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + TestServiceImpl() {} + + Status Echo(ServerContext* /*context*/, const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message()); + return Status::OK; + } +}; + +template +class CommonStressTest { + public: + CommonStressTest() : kMaxMessageSize_(8192) { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + virtual ~CommonStressTest() {} + virtual void SetUp() = 0; + virtual void TearDown() = 0; + virtual void ResetStub() = 0; + virtual bool AllowExhaustion() = 0; + grpc::testing::EchoTestService::Stub* GetStub() { return stub_.get(); } + + protected: + std::unique_ptr stub_; + std::unique_ptr server_; + + virtual void SetUpStart(ServerBuilder* builder, Service* service) = 0; + void SetUpStartCommon(ServerBuilder* builder, Service* service) { + builder->RegisterService(service); + builder->SetMaxMessageSize( + kMaxMessageSize_); // For testing max message size. 
+ } + void SetUpEnd(ServerBuilder* builder) { server_ = builder->BuildAndStart(); } + void TearDownStart() { server_->Shutdown(); } + void TearDownEnd() {} + + private: + const int kMaxMessageSize_; +}; + +template +class CommonStressTestInsecure : public CommonStressTest { + public: + void ResetStub() override { + std::shared_ptr channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + this->stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + bool AllowExhaustion() override { return false; } + + protected: + void SetUpStart(ServerBuilder* builder, Service* service) override { + int port = grpc_pick_unused_port_or_die(); + this->server_address_ << "localhost:" << port; + // Setup server + builder->AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + this->SetUpStartCommon(builder, service); + } + + private: + std::ostringstream server_address_; +}; + +template +class CommonStressTestInproc : public CommonStressTest { + public: + void ResetStub() override { + ChannelArguments args; + std::shared_ptr channel = this->server_->InProcessChannel(args); + this->stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + bool AllowExhaustion() override { return allow_resource_exhaustion; } + + protected: + void SetUpStart(ServerBuilder* builder, Service* service) override { + this->SetUpStartCommon(builder, service); + } +}; + +template +class CommonStressTestSyncServer : public BaseClass { + public: + void SetUp() override { + ServerBuilder builder; + this->SetUpStart(&builder, &service_); + this->SetUpEnd(&builder); + } + void TearDown() override { + this->TearDownStart(); + this->TearDownEnd(); + } + + private: + TestServiceImpl service_; +}; + +template +class CommonStressTestSyncServerLowThreadCount : public BaseClass { + public: + void SetUp() override { + ServerBuilder builder; + ResourceQuota quota; + this->SetUpStart(&builder, &service_); + quota.SetMaxThreads(4); + builder.SetResourceQuota(quota); + this->SetUpEnd(&builder); + } + void TearDown() override { + this->TearDownStart(); + this->TearDownEnd(); + } + + private: + TestServiceImpl service_; +}; + +template +class CommonStressTestAsyncServer : public BaseClass { + public: + CommonStressTestAsyncServer() : contexts_(kNumAsyncServerThreads * 100) {} + void SetUp() override { + shutting_down_ = false; + ServerBuilder builder; + this->SetUpStart(&builder, &service_); + cq_ = builder.AddCompletionQueue(); + this->SetUpEnd(&builder); + for (int i = 0; i < kNumAsyncServerThreads * 100; i++) { + RefreshContext(i); + } + for (int i = 0; i < kNumAsyncServerThreads; i++) { + server_threads_.emplace_back(&CommonStressTestAsyncServer::ProcessRpcs, + this); + } + } + void TearDown() override { + { + grpc::internal::MutexLock l(&mu_); + this->TearDownStart(); + shutting_down_ = true; + cq_->Shutdown(); + } + + for (int i = 0; i < kNumAsyncServerThreads; i++) { + server_threads_[i].join(); + } + + void* ignored_tag; + bool ignored_ok; + while (cq_->Next(&ignored_tag, &ignored_ok)) { + } + this->TearDownEnd(); + } + + private: + void ProcessRpcs() { + void* tag; + bool ok; + while (cq_->Next(&tag, &ok)) { + if (ok) { + int i = static_cast(reinterpret_cast(tag)); + switch (contexts_[i].state) { + case Context::READY: { + contexts_[i].state = Context::DONE; + EchoResponse send_response; + send_response.set_message(contexts_[i].recv_request.message()); + contexts_[i].response_writer->Finish(send_response, Status::OK, + tag); + break; + } + case Context::DONE: + RefreshContext(i); + 
break; + } + } + } + } + void RefreshContext(int i) { + grpc::internal::MutexLock l(&mu_); + if (!shutting_down_) { + contexts_[i].state = Context::READY; + contexts_[i].srv_ctx.reset(new ServerContext); + contexts_[i].response_writer.reset( + new grpc::ServerAsyncResponseWriter( + contexts_[i].srv_ctx.get())); + service_.RequestEcho(contexts_[i].srv_ctx.get(), + &contexts_[i].recv_request, + contexts_[i].response_writer.get(), cq_.get(), + cq_.get(), reinterpret_cast(i)); + } + } + struct Context { + std::unique_ptr srv_ctx; + std::unique_ptr> + response_writer; + EchoRequest recv_request; + enum { READY, DONE } state; + }; + std::vector contexts_; + ::grpc::testing::EchoTestService::AsyncService service_; + std::unique_ptr cq_; + bool shutting_down_; + grpc::internal::Mutex mu_; + std::vector server_threads_; +}; + +template +class End2endTest : public ::testing::Test { + protected: + End2endTest() {} + void SetUp() override { common_.SetUp(); } + void TearDown() override { common_.TearDown(); } + void ResetStub() { common_.ResetStub(); } + + Common common_; +}; + +static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs, + bool allow_exhaustion, gpr_atm* errors) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + for (int i = 0; i < num_rpcs; ++i) { + ClientContext context; + Status s = stub->Echo(&context, request, &response); + EXPECT_TRUE(s.ok() || (allow_exhaustion && + s.error_code() == StatusCode::RESOURCE_EXHAUSTED)); + if (!s.ok()) { + if (!(allow_exhaustion && + s.error_code() == StatusCode::RESOURCE_EXHAUSTED)) { + gpr_log(GPR_ERROR, "RPC error: %d: %s", s.error_code(), + s.error_message().c_str()); + } + gpr_atm_no_barrier_fetch_add(errors, static_cast(1)); + } else { + EXPECT_EQ(response.message(), request.message()); + } + } +} + +typedef ::testing::Types< + CommonStressTestSyncServer>, + CommonStressTestSyncServer>, + CommonStressTestSyncServerLowThreadCount< + CommonStressTestInproc>, + CommonStressTestAsyncServer< + CommonStressTestInsecure>, + CommonStressTestAsyncServer>> + CommonTypes; +TYPED_TEST_SUITE(End2endTest, CommonTypes); +TYPED_TEST(End2endTest, ThreadStress) { + this->common_.ResetStub(); + std::vector threads; + gpr_atm errors; + gpr_atm_rel_store(&errors, static_cast(0)); + threads.reserve(kNumThreads); + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back(SendRpc, this->common_.GetStub(), kNumRpcs, + this->common_.AllowExhaustion(), &errors); + } + for (int i = 0; i < kNumThreads; ++i) { + threads[i].join(); + } + uint64_t error_cnt = static_cast(gpr_atm_no_barrier_load(&errors)); + if (error_cnt != 0) { + gpr_log(GPR_INFO, "RPC error count: %" PRIu64, error_cnt); + } + // If this test allows resource exhaustion, expect that it actually sees some + if (this->common_.AllowExhaustion()) { + EXPECT_GT(error_cnt, static_cast(0)); + } +} + +template +class AsyncClientEnd2endTest : public ::testing::Test { + protected: + AsyncClientEnd2endTest() : rpcs_outstanding_(0) {} + + void SetUp() override { common_.SetUp(); } + void TearDown() override { + void* ignored_tag; + bool ignored_ok; + while (cq_.Next(&ignored_tag, &ignored_ok)) { + } + common_.TearDown(); + } + + void Wait() { + grpc::internal::MutexLock l(&mu_); + while (rpcs_outstanding_ != 0) { + cv_.Wait(&mu_); + } + + cq_.Shutdown(); + } + + struct AsyncClientCall { + EchoResponse response; + ClientContext context; + Status status; + std::unique_ptr> response_reader; + }; + + void AsyncSendRpc(int num_rpcs) { + for (int i = 0; i < num_rpcs; 
++i) { + AsyncClientCall* call = new AsyncClientCall; + EchoRequest request; + request.set_message("Hello: " + std::to_string(i)); + call->response_reader = + common_.GetStub()->AsyncEcho(&call->context, request, &cq_); + call->response_reader->Finish(&call->response, &call->status, call); + + grpc::internal::MutexLock l(&mu_); + rpcs_outstanding_++; + } + } + + void AsyncCompleteRpc() { + while (true) { + void* got_tag; + bool ok = false; + if (!cq_.Next(&got_tag, &ok)) break; + AsyncClientCall* call = static_cast(got_tag); + if (!ok) { + gpr_log(GPR_DEBUG, "Error: %d", call->status.error_code()); + } + delete call; + + bool notify; + { + grpc::internal::MutexLock l(&mu_); + rpcs_outstanding_--; + notify = (rpcs_outstanding_ == 0); + } + if (notify) { + cv_.Signal(); + } + } + } + + Common common_; + CompletionQueue cq_; + grpc::internal::Mutex mu_; + grpc::internal::CondVar cv_; + int rpcs_outstanding_; +}; + +TYPED_TEST_SUITE(AsyncClientEnd2endTest, CommonTypes); +TYPED_TEST(AsyncClientEnd2endTest, ThreadStress) { + this->common_.ResetStub(); + std::vector send_threads, completion_threads; + for (int i = 0; i < kNumAsyncReceiveThreads; ++i) { + completion_threads.emplace_back( + &AsyncClientEnd2endTest_ThreadStress_Test::AsyncCompleteRpc, + this); + } + for (int i = 0; i < kNumAsyncSendThreads; ++i) { + send_threads.emplace_back( + &AsyncClientEnd2endTest_ThreadStress_Test::AsyncSendRpc, + this, kNumRpcs); + } + for (int i = 0; i < kNumAsyncSendThreads; ++i) { + send_threads[i].join(); + } + + this->Wait(); + for (int i = 0; i < kNumAsyncReceiveThreads; ++i) { + completion_threads[i].join(); + } +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/time_change_test.cc b/test/cpp/end2end/time_change_test.cc new file mode 100644 index 00000000..0ea080e6 --- /dev/null +++ b/test/cpp/end2end/time_change_test.cc @@ -0,0 +1,371 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/iomgr/timer.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/subprocess.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +static std::string g_root; + +static gpr_mu g_mu; +extern gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type); +gpr_timespec (*gpr_now_impl_orig)(gpr_clock_type clock_type) = gpr_now_impl; +static int g_time_shift_sec = 0; +static int g_time_shift_nsec = 0; +static gpr_timespec now_impl(gpr_clock_type clock) { + auto ts = gpr_now_impl_orig(clock); + // We only manipulate the realtime clock to simulate changes in wall-clock + // time + if (clock != GPR_CLOCK_REALTIME) { + return ts; + } + GPR_ASSERT(ts.tv_nsec >= 0); + GPR_ASSERT(ts.tv_nsec < GPR_NS_PER_SEC); + gpr_mu_lock(&g_mu); + ts.tv_sec += g_time_shift_sec; + ts.tv_nsec += g_time_shift_nsec; + gpr_mu_unlock(&g_mu); + if (ts.tv_nsec >= GPR_NS_PER_SEC) { + ts.tv_nsec -= GPR_NS_PER_SEC; + ++ts.tv_sec; + } else if (ts.tv_nsec < 0) { + --ts.tv_sec; + ts.tv_nsec = GPR_NS_PER_SEC + ts.tv_nsec; + } + return ts; +} + +// offset the value returned by gpr_now(GPR_CLOCK_REALTIME) by msecs +// milliseconds +static void set_now_offset(int msecs) { + gpr_mu_lock(&g_mu); + g_time_shift_sec = msecs / 1000; + g_time_shift_nsec = (msecs % 1000) * 1e6; + gpr_mu_unlock(&g_mu); +} + +// restore the original implementation of gpr_now() +static void reset_now_offset() { + gpr_mu_lock(&g_mu); + g_time_shift_sec = 0; + g_time_shift_nsec = 0; + gpr_mu_unlock(&g_mu); +} + +namespace grpc { +namespace testing { + +namespace { + +// gpr_now() is called with invalid clock_type +TEST(TimespecTest, GprNowInvalidClockType) { + // initialize to some junk value + gpr_clock_type invalid_clock_type = static_cast(32641); + EXPECT_DEATH(gpr_now(invalid_clock_type), ".*"); +} + +// Add timespan with negative nanoseconds +TEST(TimespecTest, GprTimeAddNegativeNs) { + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec bad_ts = {1, -1000, GPR_TIMESPAN}; + EXPECT_DEATH(gpr_time_add(now, bad_ts), ".*"); +} + +// Subtract timespan with negative nanoseconds +TEST(TimespecTest, GprTimeSubNegativeNs) { + // Nanoseconds must always be positive. 
Negative timestamps are represented by + // (negative seconds, positive nanoseconds) + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec bad_ts = {1, -1000, GPR_TIMESPAN}; + EXPECT_DEATH(gpr_time_sub(now, bad_ts), ".*"); +} + +// Add negative milliseconds to gpr_timespec +TEST(TimespecTest, GrpcNegativeMillisToTimespec) { + // -1500 milliseconds converts to timespec (-2 secs, 5 * 10^8 nsec) + gpr_timespec ts = grpc_millis_to_timespec(-1500, GPR_CLOCK_MONOTONIC); + GPR_ASSERT(ts.tv_sec = -2); + GPR_ASSERT(ts.tv_nsec = 5e8); + GPR_ASSERT(ts.clock_type == GPR_CLOCK_MONOTONIC); +} + +class TimeChangeTest : public ::testing::Test { + protected: + TimeChangeTest() {} + + static void SetUpTestCase() { + auto port = grpc_pick_unused_port_or_die(); + std::ostringstream addr_stream; + addr_stream << "localhost:" << port; + server_address_ = addr_stream.str(); + server_ = absl::make_unique(std::vector({ + g_root + "/client_crash_test_server", + "--address=" + server_address_, + })); + GPR_ASSERT(server_); + // connect to server and make sure it's reachable. + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + GPR_ASSERT(channel); + EXPECT_TRUE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(30000))); + } + + static void TearDownTestCase() { server_.reset(); } + + void SetUp() override { + channel_ = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + GPR_ASSERT(channel_); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void TearDown() override { reset_now_offset(); } + + std::unique_ptr CreateStub() { + return grpc::testing::EchoTestService::NewStub(channel_); + } + + std::shared_ptr GetChannel() { return channel_; } + // time jump offsets in milliseconds + const int TIME_OFFSET1 = 20123; + const int TIME_OFFSET2 = 5678; + + private: + static std::string server_address_; + static std::unique_ptr server_; + std::shared_ptr channel_; + std::unique_ptr stub_; +}; +std::string TimeChangeTest::server_address_; +std::unique_ptr TimeChangeTest::server_; + +// Wall-clock time jumps forward on client before bidi stream is created +TEST_F(TimeChangeTest, TimeJumpForwardBeforeStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "1"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + // time jumps forward by TIME_OFFSET1 milliseconds + set_now_offset(TIME_OFFSET1); + auto stream = stub->BidiStream(&context); + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps back on client before bidi stream is created +TEST_F(TimeChangeTest, TimeJumpBackBeforeStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "1"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + // time jumps back by TIME_OFFSET1 milliseconds + set_now_offset(-TIME_OFFSET1); + auto stream = 
stub->BidiStream(&context); + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(request.message(), response.message()); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps forward on client while call is in progress +TEST_F(TimeChangeTest, TimeJumpForwardAfterStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "2"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + + // time jumps forward by TIME_OFFSET1 milliseconds. + set_now_offset(TIME_OFFSET1); + + request.set_message("World"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps back on client while call is in progress +TEST_F(TimeChangeTest, TimeJumpBackAfterStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "2"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + + // time jumps back TIME_OFFSET1 milliseconds. 
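  // (Note: now_impl() above only shifts GPR_CLOCK_REALTIME, so this simulates
  // a wall-clock change while GPR_CLOCK_MONOTONIC keeps advancing normally;
  // the stream is still expected to complete successfully across the jump.)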
+ set_now_offset(-TIME_OFFSET1); + + request.set_message("World"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps forward and backwards during call +TEST_F(TimeChangeTest, TimeJumpForwardAndBackDuringCall) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "2"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + EXPECT_TRUE(stream->Read(&response)); + request.set_message("World"); + + // time jumps forward by TIME_OFFSET milliseconds + set_now_offset(TIME_OFFSET1); + + EXPECT_TRUE(stream->Write(request)); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + EXPECT_TRUE(stream->WritesDone()); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + EXPECT_TRUE(stream->Read(&response)); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +} // namespace + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + std::string me = argv[0]; + // get index of last slash in path to test binary + auto lslash = me.rfind('/'); + // set g_root = path to directory containing test binary + if (lslash != std::string::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + + gpr_mu_init(&g_mu); + gpr_now_impl = now_impl; + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + auto ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/end2end/xds/xds_credentials_end2end_test.cc b/test/cpp/end2end/xds/xds_credentials_end2end_test.cc new file mode 100644 index 00000000..a6704716 --- /dev/null +++ b/test/cpp/end2end/xds/xds_credentials_end2end_test.cc @@ -0,0 +1,127 @@ +// +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// + +#include +#include + +#include +#include + +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { +namespace { + +class XdsCredentialsEnd2EndFallbackTest + : public ::testing::TestWithParam { + protected: + XdsCredentialsEnd2EndFallbackTest() { + int port = grpc_pick_unused_port_or_die(); + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort( + server_address_, + GetCredentialsProvider()->GetServerCredentials(GetParam())); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_P(XdsCredentialsEnd2EndFallbackTest, NoXdsSchemeInTarget) { + // Target does not use 'xds:///' scheme and should result in using fallback + // credentials. + ChannelArguments args; + auto channel = grpc::CreateCustomChannel( + server_address_, + grpc::XdsCredentials( + GetCredentialsProvider()->GetChannelCredentials(GetParam(), &args)), + args); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +class XdsServerCredentialsEnd2EndFallbackTest + : public ::testing::TestWithParam { + protected: + XdsServerCredentialsEnd2EndFallbackTest() { + int port = grpc_pick_unused_port_or_die(); + // Build a server that is not xDS enabled but uses XdsServerCredentials. + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort( + server_address_, + grpc::XdsServerCredentials( + GetCredentialsProvider()->GetServerCredentials(GetParam()))); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_P(XdsServerCredentialsEnd2EndFallbackTest, Basic) { + ChannelArguments args; + auto channel = grpc::CreateCustomChannel( + server_address_, + GetCredentialsProvider()->GetChannelCredentials(GetParam(), &args), args); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +INSTANTIATE_TEST_SUITE_P(XdsCredentialsEnd2EndFallback, + XdsCredentialsEnd2EndFallbackTest, + ::testing::ValuesIn(std::vector( + {kInsecureCredentialsType, kTlsCredentialsType}))); + +INSTANTIATE_TEST_SUITE_P(XdsServerCredentialsEnd2EndFallback, + XdsServerCredentialsEnd2EndFallbackTest, + ::testing::ValuesIn(std::vector( + {kInsecureCredentialsType, kTlsCredentialsType}))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/end2end/xds/xds_end2end_test.cc b/test/cpp/end2end/xds/xds_end2end_test.cc new file mode 100644 index 00000000..66553e71 --- /dev/null +++ b/test/cpp/end2end/xds/xds_end2end_test.cc @@ -0,0 +1,12672 @@ +// +// Copyright 2017 gRPC authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// TODO(roth): Split this file up into a common test framework and a set +// of test files that use that framework. Need to figure out the best +// way to split up the tests. One option would be to split it up by xDS +// resource type; another approach would be to have all of the "core" +// xDS functionality in one file and then move specific features to +// their own files (e.g., mTLS security, fault injection, circuit +// breaking, etc). + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/functional/bind_front.h" +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/xds/certificate_provider_registry.h" +#include "src/core/ext/xds/xds_api.h" +#include "src/core/ext/xds/xds_channel_args.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/time_precise.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/time_util.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/xds/ads_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/cds_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/eds_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/lds_rds_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/lrs_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/ads.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/aggregate_cluster.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/cluster.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/discovery.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/endpoint.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/fault.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/http_connection_manager.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/listener.grpc.pb.h" +#include 
"src/proto/grpc/testing/xds/v3/lrs.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/route.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/router.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/tls.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/resolve_localhost_ip46.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/counted_service.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/end2end/xds/xds_server.h" +#include "test/cpp/util/test_config.h" + +#ifndef DISABLED_XDS_PROTO_IN_CC +#include "src/cpp/server/csds/csds.h" +#include "src/proto/grpc/testing/xds/v3/csds.grpc.pb.h" +#endif // DISABLED_XDS_PROTO_IN_CC + +namespace grpc { +namespace testing { +namespace { + +using std::chrono::system_clock; + +#ifndef DISABLED_XDS_PROTO_IN_CC +using ::envoy::admin::v3::ClientResourceStatus; +#endif // DISABLED_XDS_PROTO_IN_CC +using ::envoy::config::cluster::v3::CircuitBreakers; +using ::envoy::config::cluster::v3::Cluster; +using ::envoy::config::cluster::v3::CustomClusterType; +using ::envoy::config::cluster::v3::RoutingPriority; +using ::envoy::config::endpoint::v3::ClusterLoadAssignment; +using ::envoy::config::endpoint::v3::HealthStatus; +using ::envoy::config::listener::v3::FilterChainMatch; +using ::envoy::config::listener::v3::Listener; +using ::envoy::config::route::v3::RouteConfiguration; +using ::envoy::extensions::clusters::aggregate::v3::ClusterConfig; +using ::envoy::extensions::filters::http::fault::v3::HTTPFault; +using ::envoy::extensions::filters::network::http_connection_manager::v3:: + HttpConnectionManager; +using ::envoy::extensions::filters::network::http_connection_manager::v3:: + HttpFilter; +using ::envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext; +using ::envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext; +using ::envoy::type::matcher::v3::StringMatcher; +using ::envoy::type::v3::FractionalPercent; + +using ClientStats = LrsServiceImpl::ClientStats; + +constexpr char kDefaultLocalityRegion[] = "xds_default_locality_region"; +constexpr char kDefaultLocalityZone[] = "xds_default_locality_zone"; +constexpr char kLbDropType[] = "lb"; +constexpr char kThrottleDropType[] = "throttle"; +constexpr char kServerName[] = "server.example.com"; +constexpr char kDefaultRouteConfigurationName[] = "route_config_name"; +constexpr char kDefaultClusterName[] = "cluster_name"; +constexpr char kDefaultEdsServiceName[] = "eds_service_name"; +constexpr int kDefaultLocalityWeight = 3; +constexpr int kDefaultLocalityPriority = 0; + +constexpr char kRequestMessage[] = "Live long and prosper."; +constexpr char kDefaultServiceConfig[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"xds_cluster_resolver_experimental\":{\n" + " \"discoveryMechanisms\": [\n" + " { \"clusterName\": \"server.example.com\",\n" + " \"type\": \"EDS\",\n" + " \"lrsLoadReportingServerName\": \"\"\n" + " } ]\n" + " } }\n" + " ]\n" + "}"; +constexpr char kDefaultServiceConfigWithoutLoadReporting[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"xds_cluster_resolver_experimental\":{\n" + " \"discoveryMechanisms\": [\n" + " { \"clusterName\": \"server.example.com\",\n" + " \"type\": \"EDS\"\n" + " } ]\n" + " } }\n" + " ]\n" + "}"; + +constexpr char kBootstrapFileV3[] = + "{\n" + " \"xds_servers\": [\n" + " {\n" + " \"server_uri\": \"fake:///xds_server\",\n" + " \"channel_creds\": [\n" + " {\n" + " \"type\": \"fake\"\n" + " }\n" + " ],\n" + " 
\"server_features\": [\"xds_v3\"]\n" + " }\n" + " ],\n" + " \"node\": {\n" + " \"id\": \"xds_end2end_test\",\n" + " \"cluster\": \"test\",\n" + " \"metadata\": {\n" + " \"foo\": \"bar\"\n" + " },\n" + " \"locality\": {\n" + " \"region\": \"corp\",\n" + " \"zone\": \"svl\",\n" + " \"sub_zone\": \"mp3\"\n" + " }\n" + " },\n" + " \"server_listener_resource_name_template\": " + "\"grpc/server?xds.resource.listening_address=%s\",\n" + " \"certificate_providers\": {\n" + " \"fake_plugin1\": {\n" + " \"plugin_name\": \"fake1\"\n" + " },\n" + " \"fake_plugin2\": {\n" + " \"plugin_name\": \"fake2\"\n" + " },\n" + " \"file_plugin\": {\n" + " \"plugin_name\": \"file_watcher\",\n" + " \"config\": {\n" + " \"certificate_file\": \"src/core/tsi/test_creds/client.pem\",\n" + " \"private_key_file\": \"src/core/tsi/test_creds/client.key\",\n" + " \"ca_certificate_file\": \"src/core/tsi/test_creds/ca.pem\"\n" + " }" + " }\n" + " }\n" + "}\n"; + +constexpr char kBootstrapFileV2[] = + "{\n" + " \"xds_servers\": [\n" + " {\n" + " \"server_uri\": \"fake:///xds_server\",\n" + " \"channel_creds\": [\n" + " {\n" + " \"type\": \"fake\"\n" + " }\n" + " ]\n" + " }\n" + " ],\n" + " \"node\": {\n" + " \"id\": \"xds_end2end_test\",\n" + " \"cluster\": \"test\",\n" + " \"metadata\": {\n" + " \"foo\": \"bar\"\n" + " },\n" + " \"locality\": {\n" + " \"region\": \"corp\",\n" + " \"zone\": \"svl\",\n" + " \"sub_zone\": \"mp3\"\n" + " }\n" + " }\n" + "}\n"; +constexpr char kCaCertPath[] = "src/core/tsi/test_creds/ca.pem"; +constexpr char kServerCertPath[] = "src/core/tsi/test_creds/server1.pem"; +constexpr char kServerKeyPath[] = "src/core/tsi/test_creds/server1.key"; +constexpr char kClientCertPath[] = "src/core/tsi/test_creds/client.pem"; +constexpr char kClientKeyPath[] = "src/core/tsi/test_creds/client.key"; +constexpr char kBadClientCertPath[] = "src/core/tsi/test_creds/badclient.pem"; +constexpr char kBadClientKeyPath[] = "src/core/tsi/test_creds/badclient.key"; + +char* g_bootstrap_file_v3; +char* g_bootstrap_file_v2; + +void WriteBootstrapFiles() { + char* bootstrap_file; + FILE* out = gpr_tmpfile("xds_bootstrap_v3", &bootstrap_file); + fputs(kBootstrapFileV3, out); + fclose(out); + g_bootstrap_file_v3 = bootstrap_file; + out = gpr_tmpfile("xds_bootstrap_v2", &bootstrap_file); + fputs(kBootstrapFileV2, out); + fclose(out); + g_bootstrap_file_v2 = bootstrap_file; +} + +template +class BackendServiceImpl + : public CountedService> { + public: + BackendServiceImpl() {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + auto peer_identity = context->auth_context()->GetPeerIdentity(); + CountedService>::IncreaseRequestCount(); + const auto status = + TestMultipleServiceImpl::Echo(context, request, response); + CountedService< + TestMultipleServiceImpl>::IncreaseResponseCount(); + { + grpc_core::MutexLock lock(&mu_); + clients_.insert(context->peer()); + last_peer_identity_.clear(); + for (const auto& entry : peer_identity) { + last_peer_identity_.emplace_back(entry.data(), entry.size()); + } + } + return status; + } + + Status Echo1(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + return Echo(context, request, response); + } + + Status Echo2(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + return Echo(context, request, response); + } + + void Start() {} + void Shutdown() {} + + std::set clients() { + grpc_core::MutexLock lock(&mu_); + return clients_; + } + + const std::vector& 
last_peer_identity() { + grpc_core::MutexLock lock(&mu_); + return last_peer_identity_; + } + + private: + grpc_core::Mutex mu_; + std::set clients_ ABSL_GUARDED_BY(mu_); + std::vector last_peer_identity_ ABSL_GUARDED_BY(mu_); +}; + +class TestType { + public: + enum FilterConfigSetup { + // Set the fault injection filter directly from LDS + kHTTPConnectionManagerOriginal, + // Enable the fault injection filter in LDS, but override the filter config + // in route. + kRouteOverride, + }; + + enum BootstrapSource { + kBootstrapFromChannelArg, + kBootstrapFromFile, + kBootstrapFromEnvVar, + }; + + TestType& set_use_fake_resolver() { + use_fake_resolver_ = true; + return *this; + } + + TestType& set_enable_load_reporting() { + enable_load_reporting_ = true; + return *this; + } + + TestType& set_enable_rds_testing() { + enable_rds_testing_ = true; + return *this; + } + + TestType& set_use_v2() { + use_v2_ = true; + return *this; + } + + TestType& set_use_xds_credentials() { + use_xds_credentials_ = true; + return *this; + } + + TestType& set_use_csds_streaming() { + use_csds_streaming_ = true; + return *this; + } + + TestType& set_filter_config_setup(FilterConfigSetup setup) { + filter_config_setup_ = setup; + return *this; + } + + TestType& set_bootstrap_source(BootstrapSource bootstrap_source) { + bootstrap_source_ = bootstrap_source; + return *this; + } + + bool use_fake_resolver() const { return use_fake_resolver_; } + bool enable_load_reporting() const { return enable_load_reporting_; } + bool enable_rds_testing() const { return enable_rds_testing_; } + bool use_v2() const { return use_v2_; } + bool use_xds_credentials() const { return use_xds_credentials_; } + bool use_csds_streaming() const { return use_csds_streaming_; } + FilterConfigSetup filter_config_setup() const { return filter_config_setup_; } + BootstrapSource bootstrap_source() const { return bootstrap_source_; } + + std::string AsString() const { + std::string retval = (use_fake_resolver_ ? "FakeResolver" : "XdsResolver"); + retval += (use_v2_ ? "V2" : "V3"); + if (enable_load_reporting_) retval += "WithLoadReporting"; + if (enable_rds_testing_) retval += "Rds"; + if (use_xds_credentials_) retval += "XdsCreds"; + if (use_csds_streaming_) retval += "CsdsStreaming"; + if (filter_config_setup_ == kRouteOverride) { + retval += "FilterPerRouteOverride"; + } + if (bootstrap_source_ == kBootstrapFromFile) { + retval += "BootstrapFromFile"; + } else if (bootstrap_source_ == kBootstrapFromEnvVar) { + retval += "BootstrapFromEnvVar"; + } + return retval; + } + + private: + bool use_fake_resolver_ = false; + bool enable_load_reporting_ = false; + bool enable_rds_testing_ = false; + bool use_v2_ = false; + bool use_xds_credentials_ = false; + bool use_csds_streaming_ = false; + FilterConfigSetup filter_config_setup_ = kHTTPConnectionManagerOriginal; + BootstrapSource bootstrap_source_ = kBootstrapFromChannelArg; +}; + +std::string ReadFile(const char* file_path) { + grpc_slice slice; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("load_file", grpc_load_file(file_path, 0, &slice))); + std::string file_contents(grpc_core::StringViewFromSlice(slice)); + grpc_slice_unref(slice); + return file_contents; +} + +grpc_core::PemKeyCertPairList ReadTlsIdentityPair(const char* key_path, + const char* cert_path) { + return grpc_core::PemKeyCertPairList{ + grpc_core::PemKeyCertPair(ReadFile(key_path), ReadFile(cert_path))}; +} + +// Based on StaticDataCertificateProvider, but provides alternate certificates +// if the certificate name is not empty. 
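// A minimal usage sketch (illustrative, with a hypothetical local name): a
// test points one of the g_fake*_cert_data_map globals defined further below
// at a map such as
//
//   FakeCertificateProvider::CertDataMap fake1_cert_map = {
//       {"", {ReadFile(kCaCertPath),
//             ReadTlsIdentityPair(kServerKeyPath, kServerCertPath)}}};
//   g_fake1_cert_data_map = &fake1_cert_map;
//
// so watchers asking for the default certificate name ("") receive the test
// CA plus server identity pair, while unknown names get SetErrorForCert().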
+class FakeCertificateProvider final : public grpc_tls_certificate_provider { + public: + struct CertData { + std::string root_certificate; + grpc_core::PemKeyCertPairList identity_key_cert_pairs; + }; + + using CertDataMap = std::map; + + explicit FakeCertificateProvider(CertDataMap cert_data_map) + : distributor_( + grpc_core::MakeRefCounted()), + cert_data_map_(std::move(cert_data_map)) { + distributor_->SetWatchStatusCallback([this](std::string cert_name, + bool root_being_watched, + bool identity_being_watched) { + if (!root_being_watched && !identity_being_watched) return; + auto it = cert_data_map_.find(cert_name); + if (it == cert_data_map_.end()) { + grpc_error_handle error = + GRPC_ERROR_CREATE_FROM_CPP_STRING(absl::StrCat( + "No certificates available for cert_name \"", cert_name, "\"")); + distributor_->SetErrorForCert(cert_name, GRPC_ERROR_REF(error), + GRPC_ERROR_REF(error)); + GRPC_ERROR_UNREF(error); + } else { + absl::optional root_certificate; + absl::optional pem_key_cert_pairs; + if (root_being_watched) { + root_certificate = it->second.root_certificate; + } + if (identity_being_watched) { + pem_key_cert_pairs = it->second.identity_key_cert_pairs; + } + distributor_->SetKeyMaterials(cert_name, std::move(root_certificate), + std::move(pem_key_cert_pairs)); + } + }); + } + + ~FakeCertificateProvider() override { + distributor_->SetWatchStatusCallback(nullptr); + } + + grpc_core::RefCountedPtr distributor() + const override { + return distributor_; + } + + private: + grpc_core::RefCountedPtr distributor_; + CertDataMap cert_data_map_; +}; + +class FakeCertificateProviderFactory + : public grpc_core::CertificateProviderFactory { + public: + class Config : public grpc_core::CertificateProviderFactory::Config { + public: + explicit Config(const char* name) : name_(name) {} + + const char* name() const override { return name_; } + + std::string ToString() const override { return "{}"; } + + private: + const char* name_; + }; + + FakeCertificateProviderFactory( + const char* name, FakeCertificateProvider::CertDataMap** cert_data_map) + : name_(name), cert_data_map_(cert_data_map) { + GPR_ASSERT(cert_data_map != nullptr); + } + + const char* name() const override { return name_; } + + grpc_core::RefCountedPtr + CreateCertificateProviderConfig(const grpc_core::Json& /*config_json*/, + grpc_error_handle* /*error*/) override { + return grpc_core::MakeRefCounted(name_); + } + + grpc_core::RefCountedPtr + CreateCertificateProvider( + grpc_core::RefCountedPtr + /*config*/) override { + if (*cert_data_map_ == nullptr) return nullptr; + return grpc_core::MakeRefCounted(**cert_data_map_); + } + + private: + const char* name_; + FakeCertificateProvider::CertDataMap** cert_data_map_; +}; + +// Global variables for each provider. +FakeCertificateProvider::CertDataMap* g_fake1_cert_data_map = nullptr; +FakeCertificateProvider::CertDataMap* g_fake2_cert_data_map = nullptr; + +int ServerAuthCheckSchedule(void* /* config_user_data */, + grpc_tls_server_authorization_check_arg* arg) { + arg->success = 1; + arg->status = GRPC_STATUS_OK; + return 0; /* synchronous check */ +} + +std::shared_ptr CreateTlsFallbackCredentials() { + // TODO(yashykt): Switch to using C++ API once b/173823806 is fixed. 
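  // In outline, the calls below:
  //   1. create grpc_tls_credentials_options and skip hostname verification
  //      of the server certificate;
  //   2. install a StaticDataCertificateProvider holding the test CA
  //      (kCaCertPath) and the server key/cert pair, watching both root and
  //      identity materials;
  //   3. register a server-authorization check that always succeeds
  //      (ServerAuthCheckSchedule above sets success = 1 synchronously);
  //   4. wrap the resulting C-core TLS credentials in a C++ channel
  //      credentials object, which callers pass to XdsCredentials() as the
  //      fallback.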
+ grpc_tls_credentials_options* options = grpc_tls_credentials_options_create(); + grpc_tls_credentials_options_set_server_verification_option( + options, GRPC_TLS_SKIP_HOSTNAME_VERIFICATION); + grpc_tls_credentials_options_set_certificate_provider( + options, + grpc_core::MakeRefCounted( + ReadFile(kCaCertPath), + ReadTlsIdentityPair(kServerKeyPath, kServerCertPath)) + .get()); + grpc_tls_credentials_options_watch_root_certs(options); + grpc_tls_credentials_options_watch_identity_key_cert_pairs(options); + grpc_tls_server_authorization_check_config* check_config = + grpc_tls_server_authorization_check_config_create( + nullptr, ServerAuthCheckSchedule, nullptr, nullptr); + grpc_tls_credentials_options_set_server_authorization_check_config( + options, check_config); + auto channel_creds = std::make_shared( + grpc_tls_credentials_create(options)); + grpc_tls_server_authorization_check_config_release(check_config); + return channel_creds; +} + +// A No-op HTTP filter used for verifying parsing logic. +class NoOpHttpFilter : public grpc_core::XdsHttpFilterImpl { + public: + NoOpHttpFilter(std::string name, bool supported_on_clients, + bool supported_on_servers) + : name_(std::move(name)), + supported_on_clients_(supported_on_clients), + supported_on_servers_(supported_on_servers) {} + + void PopulateSymtab(upb_symtab* /* symtab */) const override {} + + absl::StatusOr + GenerateFilterConfig(upb_strview /* serialized_filter_config */, + upb_arena* /* arena */) const override { + return grpc_core::XdsHttpFilterImpl::FilterConfig{name_, grpc_core::Json()}; + } + + absl::StatusOr + GenerateFilterConfigOverride(upb_strview /*serialized_filter_config*/, + upb_arena* /*arena*/) const override { + return grpc_core::XdsHttpFilterImpl::FilterConfig{name_, grpc_core::Json()}; + } + + const grpc_channel_filter* channel_filter() const override { return nullptr; } + + absl::StatusOr + GenerateServiceConfig( + const FilterConfig& /*hcm_filter_config*/, + const FilterConfig* /*filter_config_override*/) const override { + return grpc_core::XdsHttpFilterImpl::ServiceConfigJsonEntry{name_, ""}; + } + + bool IsSupportedOnClients() const override { return supported_on_clients_; } + + bool IsSupportedOnServers() const override { return supported_on_servers_; } + + private: + const std::string name_; + const bool supported_on_clients_; + const bool supported_on_servers_; +}; + +// There is slight difference between time fetched by GPR and by C++ system +// clock API. It's unclear if they are using the same syscall, but we do know +// GPR round the number at millisecond-level. This creates a 1ms difference, +// which could cause flake. +grpc_millis NowFromCycleCounter() { + return grpc_timespec_to_millis_round_down(gpr_now(GPR_CLOCK_MONOTONIC)); +} + +// Returns the number of RPCs needed to pass error_tolerance at 99.99994% +// chance. Rolling dices in drop/fault-injection generates a binomial +// distribution (if our code is not horribly wrong). Let's make "n" the number +// of samples, "p" the probability. If we have np>5 & n(1-p)>5, we can +// approximately treat the binomial distribution as a normal distribution. +// +// For normal distribution, we can easily look up how many standard deviation we +// need to reach 99.995%. Based on Wiki's table +// https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule, we need 5.00 +// sigma (standard deviation) to cover the probability area of 99.99994%. 
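// As a concrete instance of the bound derived below: with p = 0.5 and
// error_tolerance = 0.05, ComputeIdealNumRpcs() returns
// ceil(0.5 * 0.5 * 5.00 * 5.00 / (0.05 * 0.05)) = 2500 RPCs.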
In +// another word, for a sample with size "n" probability "p" error-tolerance "k", +// we want the error always land within 5.00 sigma. The sigma of binominal +// distribution and be computed as sqrt(np(1-p)). Hence, we have the equation: +// +// kn <= 5.00 * sqrt(np(1-p)) +size_t ComputeIdealNumRpcs(double p, double error_tolerance) { + GPR_ASSERT(p >= 0 && p <= 1); + size_t num_rpcs = + ceil(p * (1 - p) * 5.00 * 5.00 / error_tolerance / error_tolerance); + gpr_log(GPR_INFO, + "Sending %" PRIuPTR " RPCs for percentage=%.3f error_tolerance=%.3f", + num_rpcs, p, error_tolerance); + return num_rpcs; +} + +// Channel arg pointer vtable for storing xDS channel args in the parent +// channel's channel args. +void* ChannelArgsArgCopy(void* p) { + auto* args = static_cast(p); + return grpc_channel_args_copy(args); +} +void ChannelArgsArgDestroy(void* p) { + auto* args = static_cast(p); + grpc_channel_args_destroy(args); +} +int ChannelArgsArgCmp(void* a, void* b) { + auto* args_a = static_cast(a); + auto* args_b = static_cast(b); + return grpc_channel_args_compare(args_a, args_b); +} +const grpc_arg_pointer_vtable kChannelArgsArgVtable = { + ChannelArgsArgCopy, ChannelArgsArgDestroy, ChannelArgsArgCmp}; + +class XdsEnd2endTest : public ::testing::TestWithParam { + protected: + // TODO(roth): We currently set the number of backends and number of + // balancers on a per-test-suite basis, not a per-test-case basis. + // However, not every individual test case in a given test suite uses + // the same number of backends or balancers, so we wind up having to + // set the numbers for the test suite to the max number needed by any + // one test case in that test suite. This results in starting more + // servers (and using more ports) than we actually need. When we have + // time, change each test to directly start the number of backends and + // balancers that it needs, so that we aren't wasting resources. + XdsEnd2endTest(size_t num_backends, size_t num_balancers, + int client_load_reporting_interval_seconds = 100, + bool use_xds_enabled_server = false) + : num_backends_(num_backends), + num_balancers_(num_balancers), + client_load_reporting_interval_seconds_( + client_load_reporting_interval_seconds), + use_xds_enabled_server_(use_xds_enabled_server) {} + + void SetUp() override { + bool localhost_resolves_to_ipv4 = false; + bool localhost_resolves_to_ipv6 = false; + grpc_core::LocalhostResolves(&localhost_resolves_to_ipv4, + &localhost_resolves_to_ipv6); + ipv6_only_ = !localhost_resolves_to_ipv4 && localhost_resolves_to_ipv6; + // Initialize default xDS resources. + // Construct LDS resource. + default_listener_.set_name(kServerName); + HttpConnectionManager http_connection_manager; + if (!GetParam().use_v2()) { + auto* filter = http_connection_manager.add_http_filters(); + filter->set_name("router"); + filter->mutable_typed_config()->PackFrom( + envoy::extensions::filters::http::router::v3::Router()); + } + default_listener_.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + // Construct RDS resource. + default_route_config_.set_name(kDefaultRouteConfigurationName); + auto* virtual_host = default_route_config_.add_virtual_hosts(); + virtual_host->add_domains("*"); + auto* route = virtual_host->add_routes(); + route->mutable_match()->set_prefix(""); + route->mutable_route()->set_cluster(kDefaultClusterName); + // Construct CDS resource. 
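    // (For orientation: the resources assembled in this SetUp form the usual
    // xDS chain for kServerName -- the route config above sends every domain
    // ("*") and path prefix ("") to kDefaultClusterName, and the EDS cluster
    // built below pulls its endpoints from kDefaultEdsServiceName, with LRS
    // load reporting enabled only when the test parameter requests it.)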
+ default_cluster_.set_name(kDefaultClusterName); + default_cluster_.set_type(Cluster::EDS); + auto* eds_config = default_cluster_.mutable_eds_cluster_config(); + eds_config->mutable_eds_config()->mutable_ads(); + eds_config->set_service_name(kDefaultEdsServiceName); + default_cluster_.set_lb_policy(Cluster::ROUND_ROBIN); + if (GetParam().enable_load_reporting()) { + default_cluster_.mutable_lrs_server()->mutable_self(); + } + // Start the load balancers. + for (size_t i = 0; i < num_balancers_; ++i) { + balancers_.emplace_back(new BalancerServerThread( + this, GetParam().enable_load_reporting() + ? client_load_reporting_interval_seconds_ + : 0)); + balancers_.back()->Start(); + // Initialize resources. + SetListenerAndRouteConfiguration(i, default_listener_, + default_route_config_); + balancers_.back()->ads_service()->SetCdsResource(default_cluster_); + } + // Create fake resolver response generators used by client. + if (GetParam().use_fake_resolver()) { + response_generator_ = + grpc_core::MakeRefCounted(); + } + logical_dns_cluster_resolver_response_generator_ = + grpc_core::MakeRefCounted(); + lb_channel_response_generator_ = + grpc_core::MakeRefCounted(); + // Construct channel args for XdsClient. + xds_channel_args_to_add_.emplace_back( + grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + lb_channel_response_generator_.get())); + if (xds_resource_does_not_exist_timeout_ms_ > 0) { + xds_channel_args_to_add_.emplace_back(grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_XDS_RESOURCE_DOES_NOT_EXIST_TIMEOUT_MS), + xds_resource_does_not_exist_timeout_ms_)); + } + xds_channel_args_.num_args = xds_channel_args_to_add_.size(); + xds_channel_args_.args = xds_channel_args_to_add_.data(); + // Initialize XdsClient state. + // TODO(roth): Consider changing this to dynamically generate the + // bootstrap config in each individual test instead of hard-coding + // the contents here. That would allow us to use an ipv4: or ipv6: + // URI for the xDS server instead of using the fake resolver. + if (GetParam().bootstrap_source() == TestType::kBootstrapFromEnvVar) { + gpr_setenv("GRPC_XDS_BOOTSTRAP_CONFIG", + GetParam().use_v2() ? kBootstrapFileV2 : kBootstrapFileV3); + } else if (GetParam().bootstrap_source() == TestType::kBootstrapFromFile) { + gpr_setenv("GRPC_XDS_BOOTSTRAP", GetParam().use_v2() + ? g_bootstrap_file_v2 + : g_bootstrap_file_v3); + } + if (GetParam().bootstrap_source() != TestType::kBootstrapFromChannelArg) { + // If getting bootstrap from channel arg, we'll pass these args in + // via the parent channel args in CreateChannel() instead. + grpc_core::internal::SetXdsChannelArgsForTest(&xds_channel_args_); + // Make sure each test creates a new XdsClient instance rather than + // reusing the one from the previous test. This avoids spurious failures + // caused when a load reporting test runs after a non-load reporting test + // and the XdsClient is still talking to the old LRS server, which fails + // because it's not expecting the client to connect. It also + // ensures that each test can independently set the global channel + // args for the xDS channel. + grpc_core::internal::UnsetGlobalXdsClientForTest(); + } + // Start the backends. + for (size_t i = 0; i < num_backends_; ++i) { + backends_.emplace_back( + new BackendServerThread(this, use_xds_enabled_server_)); + backends_.back()->Start(); + } + // Create channel and stub. + ResetStub(); + } + + const char* DefaultEdsServiceName() const { + return GetParam().use_fake_resolver() ? 
kServerName + : kDefaultEdsServiceName; + } + + void TearDown() override { + ShutdownAllBackends(); + for (auto& balancer : balancers_) balancer->Shutdown(); + // Clear global xDS channel args, since they will go out of scope + // when this test object is destroyed. + grpc_core::internal::SetXdsChannelArgsForTest(nullptr); + gpr_unsetenv("GRPC_XDS_BOOTSTRAP"); + gpr_unsetenv("GRPC_XDS_BOOTSTRAP_CONFIG"); + } + + void StartAllBackends() { + for (auto& backend : backends_) backend->Start(); + } + + void StartBackend(size_t index) { backends_[index]->Start(); } + + void ShutdownAllBackends() { + for (auto& backend : backends_) backend->Shutdown(); + } + + void ShutdownBackend(size_t index) { backends_[index]->Shutdown(); } + + void ResetStub(int failover_timeout = 0) { + channel_ = CreateChannel(failover_timeout); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + stub1_ = grpc::testing::EchoTest1Service::NewStub(channel_); + stub2_ = grpc::testing::EchoTest2Service::NewStub(channel_); + } + + std::shared_ptr CreateChannel( + int failover_timeout = 0, const char* server_name = kServerName, + grpc_core::FakeResolverResponseGenerator* response_generator = nullptr, + grpc_channel_args* xds_channel_args = nullptr) { + ChannelArguments args; + if (failover_timeout > 0) { + args.SetInt(GRPC_ARG_PRIORITY_FAILOVER_TIMEOUT_MS, failover_timeout); + } + // If the parent channel is using the fake resolver, we inject the + // response generator here. + if (GetParam().use_fake_resolver()) { + if (response_generator == nullptr) { + response_generator = response_generator_.get(); + } + args.SetPointerWithVtable( + GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, response_generator, + &grpc_core::FakeResolverResponseGenerator::kChannelArgPointerVtable); + } + if (GetParam().bootstrap_source() == TestType::kBootstrapFromChannelArg) { + // We're getting the bootstrap from a channel arg, so we do the + // same thing for the response generator to use for the xDS + // channel and the xDS resource-does-not-exist timeout value. + args.SetString(GRPC_ARG_TEST_ONLY_DO_NOT_USE_IN_PROD_XDS_BOOTSTRAP_CONFIG, + GetParam().use_v2() ? kBootstrapFileV2 : kBootstrapFileV3); + if (xds_channel_args == nullptr) xds_channel_args = &xds_channel_args_; + args.SetPointerWithVtable( + GRPC_ARG_TEST_ONLY_DO_NOT_USE_IN_PROD_XDS_CLIENT_CHANNEL_ARGS, + xds_channel_args, &kChannelArgsArgVtable); + } + args.SetPointerWithVtable( + GRPC_ARG_XDS_LOGICAL_DNS_CLUSTER_FAKE_RESOLVER_RESPONSE_GENERATOR, + logical_dns_cluster_resolver_response_generator_.get(), + &grpc_core::FakeResolverResponseGenerator::kChannelArgPointerVtable); + std::string uri = absl::StrCat( + GetParam().use_fake_resolver() ? "fake" : "xds", ":///", server_name); + std::shared_ptr channel_creds = + GetParam().use_xds_credentials() + ? 
XdsCredentials(CreateTlsFallbackCredentials()) + : std::make_shared( + grpc_fake_transport_security_credentials_create()); + return ::grpc::CreateCustomChannel(uri, channel_creds, args); + } + + enum RpcService { + SERVICE_ECHO, + SERVICE_ECHO1, + SERVICE_ECHO2, + }; + + enum RpcMethod { + METHOD_ECHO, + METHOD_ECHO1, + METHOD_ECHO2, + }; + + struct RpcOptions { + RpcService service = SERVICE_ECHO; + RpcMethod method = METHOD_ECHO; + int timeout_ms = 1000; + bool wait_for_ready = false; + bool server_fail = false; + std::vector> metadata; + int server_sleep_us = 0; + int client_cancel_after_us = 0; + bool skip_cancelled_check = false; + StatusCode server_expected_error = StatusCode::OK; + + RpcOptions() {} + + RpcOptions& set_rpc_service(RpcService rpc_service) { + service = rpc_service; + return *this; + } + + RpcOptions& set_rpc_method(RpcMethod rpc_method) { + method = rpc_method; + return *this; + } + + RpcOptions& set_timeout_ms(int rpc_timeout_ms) { + timeout_ms = rpc_timeout_ms; + return *this; + } + + RpcOptions& set_wait_for_ready(bool rpc_wait_for_ready) { + wait_for_ready = rpc_wait_for_ready; + return *this; + } + + RpcOptions& set_server_fail(bool rpc_server_fail) { + server_fail = rpc_server_fail; + return *this; + } + + RpcOptions& set_skip_cancelled_check(bool rpc_skip_cancelled_check) { + skip_cancelled_check = rpc_skip_cancelled_check; + return *this; + } + + RpcOptions& set_metadata( + std::vector> rpc_metadata) { + metadata = std::move(rpc_metadata); + return *this; + } + + RpcOptions& set_server_sleep_us(int rpc_server_sleep_us) { + server_sleep_us = rpc_server_sleep_us; + return *this; + } + + RpcOptions& set_client_cancel_after_us(int rpc_client_cancel_after_us) { + client_cancel_after_us = rpc_client_cancel_after_us; + return *this; + } + + RpcOptions& set_server_expected_error(StatusCode code) { + server_expected_error = code; + return *this; + } + + // Populates context and request. 
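+ // Applies the configured metadata, deadline, and wait_for_ready setting to
+ // the ClientContext, sets the request message, and copies the server-side
+ // behaviors (expected error, sleep, client cancellation, skip-cancelled
+ // check) into the request's param fields so the backend test service can
+ // act on them. Typical use mirrors the tests below, e.g.:
+ //   CheckRpcSendOk(1, RpcOptions().set_timeout_ms(2000).set_wait_for_ready(true));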
+ void SetupRpc(ClientContext* context, EchoRequest* request) const { + for (const auto& item : metadata) { + context->AddMetadata(item.first, item.second); + } + if (timeout_ms != 0) { + context->set_deadline( + grpc_timeout_milliseconds_to_deadline(timeout_ms)); + } + if (wait_for_ready) context->set_wait_for_ready(true); + request->set_message(kRequestMessage); + if (server_fail) { + request->mutable_param()->mutable_expected_error()->set_code( + GRPC_STATUS_FAILED_PRECONDITION); + } + if (server_sleep_us != 0) { + request->mutable_param()->set_server_sleep_us(server_sleep_us); + } + if (client_cancel_after_us != 0) { + request->mutable_param()->set_client_cancel_after_us( + client_cancel_after_us); + } + if (skip_cancelled_check) { + request->mutable_param()->set_skip_cancelled_check(true); + } + } + }; + + template + Status SendRpcMethod(Stub* stub, const RpcOptions& rpc_options, + ClientContext* context, EchoRequest& request, + EchoResponse* response) { + switch (rpc_options.method) { + case METHOD_ECHO: + return (*stub)->Echo(context, request, response); + case METHOD_ECHO1: + return (*stub)->Echo1(context, request, response); + case METHOD_ECHO2: + return (*stub)->Echo2(context, request, response); + } + GPR_UNREACHABLE_CODE(); + } + + void ResetBackendCounters(size_t start_index = 0, size_t stop_index = 0) { + if (stop_index == 0) stop_index = backends_.size(); + for (size_t i = start_index; i < stop_index; ++i) { + backends_[i]->backend_service()->ResetCounters(); + backends_[i]->backend_service1()->ResetCounters(); + backends_[i]->backend_service2()->ResetCounters(); + } + } + + bool SeenBackend(size_t backend_idx, + const RpcService rpc_service = SERVICE_ECHO) { + switch (rpc_service) { + case SERVICE_ECHO: + if (backends_[backend_idx]->backend_service()->request_count() == 0) { + return false; + } + break; + case SERVICE_ECHO1: + if (backends_[backend_idx]->backend_service1()->request_count() == 0) { + return false; + } + break; + case SERVICE_ECHO2: + if (backends_[backend_idx]->backend_service2()->request_count() == 0) { + return false; + } + break; + } + return true; + } + + bool SeenAllBackends(size_t start_index = 0, size_t stop_index = 0, + const RpcService rpc_service = SERVICE_ECHO) { + if (stop_index == 0) stop_index = backends_.size(); + for (size_t i = start_index; i < stop_index; ++i) { + if (!SeenBackend(i, rpc_service)) { + return false; + } + } + return true; + } + + void SendRpcAndCount( + int* num_total, int* num_ok, int* num_failure, int* num_drops, + const RpcOptions& rpc_options = RpcOptions(), + const char* drop_error_message_prefix = "EDS-configured drop: ") { + const Status status = SendRpc(rpc_options); + if (status.ok()) { + ++*num_ok; + } else { + if (absl::StartsWith(status.error_message(), drop_error_message_prefix)) { + ++*num_drops; + } else { + ++*num_failure; + } + } + ++*num_total; + } + + struct WaitForBackendOptions { + bool reset_counters = true; + bool allow_failures = false; + + WaitForBackendOptions() {} + + WaitForBackendOptions& set_reset_counters(bool enable) { + reset_counters = enable; + return *this; + } + + WaitForBackendOptions& set_allow_failures(bool enable) { + allow_failures = enable; + return *this; + } + }; + + std::tuple WaitForAllBackends( + size_t start_index = 0, size_t stop_index = 0, + const WaitForBackendOptions& wait_options = WaitForBackendOptions(), + const RpcOptions& rpc_options = RpcOptions()) { + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + int num_total = 0; + gpr_log(GPR_INFO, "========= 
WAITING FOR All BACKEND %lu TO %lu ==========", + static_cast(start_index), + static_cast(stop_index)); + while (!SeenAllBackends(start_index, stop_index, rpc_options.service)) { + SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops, + rpc_options); + } + if (wait_options.reset_counters) ResetBackendCounters(); + gpr_log(GPR_INFO, + "Performed %d warm up requests against the backends. " + "%d succeeded, %d failed, %d dropped.", + num_total, num_ok, num_failure, num_drops); + if (!wait_options.allow_failures) EXPECT_EQ(num_failure, 0); + return std::make_tuple(num_ok, num_failure, num_drops); + } + + void WaitForBackend( + size_t backend_idx, + const WaitForBackendOptions& wait_options = WaitForBackendOptions(), + const RpcOptions& rpc_options = RpcOptions()) { + gpr_log(GPR_INFO, "========= WAITING FOR BACKEND %lu ==========", + static_cast(backend_idx)); + do { + Status status = SendRpc(rpc_options); + if (!wait_options.allow_failures) { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + } + } while (!SeenBackend(backend_idx, rpc_options.service)); + if (wait_options.reset_counters) ResetBackendCounters(); + gpr_log(GPR_INFO, "========= BACKEND %lu READY ==========", + static_cast(backend_idx)); + } + + grpc_core::ServerAddressList CreateAddressListFromPortList( + const std::vector& ports) { + grpc_core::ServerAddressList addresses; + for (int port : ports) { + absl::StatusOr lb_uri = grpc_core::URI::Parse( + absl::StrCat(ipv6_only_ ? "ipv6:[::1]:" : "ipv4:127.0.0.1:", port)); + GPR_ASSERT(lb_uri.ok()); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(*lb_uri, &address)); + addresses.emplace_back(address.addr, address.len, nullptr); + } + return addresses; + } + + std::string CreateMetadataValueThatHashesToBackendPort(int port) { + return absl::StrCat(ipv6_only_ ? "[::1]" : "127.0.0.1", ":", port, "_0"); + } + + std::string CreateMetadataValueThatHashesToBackend(int index) { + return CreateMetadataValueThatHashesToBackendPort(backends_[index]->port()); + } + + void SetNextResolution( + const std::vector& ports, + grpc_core::FakeResolverResponseGenerator* response_generator = nullptr) { + if (!GetParam().use_fake_resolver()) return; // Not used with xds resolver. + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result; + result.addresses = CreateAddressListFromPortList(ports); + grpc_error_handle error = GRPC_ERROR_NONE; + const char* service_config_json = + GetParam().enable_load_reporting() + ? 
kDefaultServiceConfig + : kDefaultServiceConfigWithoutLoadReporting; + result.service_config = + grpc_core::ServiceConfig::Create(nullptr, service_config_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_NE(result.service_config.get(), nullptr); + if (response_generator == nullptr) { + response_generator = response_generator_.get(); + } + response_generator->SetResponse(std::move(result)); + } + + void SetNextResolutionForLbChannelAllBalancers( + const char* service_config_json = nullptr, + const char* expected_targets = nullptr, + grpc_core::FakeResolverResponseGenerator* response_generator = nullptr) { + std::vector ports; + for (size_t i = 0; i < balancers_.size(); ++i) { + ports.emplace_back(balancers_[i]->port()); + } + SetNextResolutionForLbChannel(ports, service_config_json, expected_targets, + response_generator); + } + + void SetNextResolutionForLbChannel( + const std::vector& ports, const char* service_config_json = nullptr, + const char* expected_targets = nullptr, + grpc_core::FakeResolverResponseGenerator* response_generator = nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result; + result.addresses = CreateAddressListFromPortList(ports); + if (service_config_json != nullptr) { + grpc_error_handle error = GRPC_ERROR_NONE; + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, service_config_json, &error); + ASSERT_NE(result.service_config.get(), nullptr); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + } + if (expected_targets != nullptr) { + grpc_arg expected_targets_arg = grpc_channel_arg_string_create( + const_cast(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS), + const_cast(expected_targets)); + result.args = + grpc_channel_args_copy_and_add(nullptr, &expected_targets_arg, 1); + } + if (response_generator == nullptr) { + response_generator = lb_channel_response_generator_.get(); + } + response_generator->SetResponse(std::move(result)); + } + + void SetNextReresolutionResponse(const std::vector& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result; + result.addresses = CreateAddressListFromPortList(ports); + response_generator_->SetReresolutionResponse(std::move(result)); + } + + std::vector GetBackendPorts(size_t start_index = 0, + size_t stop_index = 0) const { + if (stop_index == 0) stop_index = backends_.size(); + std::vector backend_ports; + for (size_t i = start_index; i < stop_index; ++i) { + backend_ports.push_back(backends_[i]->port()); + } + return backend_ports; + } + + Status SendRpc(const RpcOptions& rpc_options = RpcOptions(), + EchoResponse* response = nullptr) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + ClientContext context; + EchoRequest request; + if (rpc_options.server_expected_error != StatusCode::OK) { + auto* error = request.mutable_param()->mutable_expected_error(); + error->set_code(rpc_options.server_expected_error); + } + rpc_options.SetupRpc(&context, &request); + Status status; + switch (rpc_options.service) { + case SERVICE_ECHO: + status = + SendRpcMethod(&stub_, rpc_options, &context, request, response); + break; + case SERVICE_ECHO1: + status = + SendRpcMethod(&stub1_, rpc_options, &context, request, response); + break; + case SERVICE_ECHO2: + status = + SendRpcMethod(&stub2_, rpc_options, &context, request, response); + break; + } + if (local_response) delete response; + return status; + } + + void CheckRpcSendOk(const size_t times = 1, + const 
RpcOptions& rpc_options = RpcOptions()) { + for (size_t i = 0; i < times; ++i) { + EchoResponse response; + const Status status = SendRpc(rpc_options, &response); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + + struct CheckRpcSendFailureOptions { + std::function continue_predicate = [](size_t i) { + return i < 1; + }; + RpcOptions rpc_options; + StatusCode expected_error_code = StatusCode::OK; + + CheckRpcSendFailureOptions() {} + + CheckRpcSendFailureOptions& set_times(size_t times) { + continue_predicate = [times](size_t i) { return i < times; }; + return *this; + } + + CheckRpcSendFailureOptions& set_continue_predicate( + std::function pred) { + continue_predicate = std::move(pred); + return *this; + } + + CheckRpcSendFailureOptions& set_rpc_options(const RpcOptions& options) { + rpc_options = options; + return *this; + } + + CheckRpcSendFailureOptions& set_expected_error_code(StatusCode code) { + expected_error_code = code; + return *this; + } + }; + + void CheckRpcSendFailure(const CheckRpcSendFailureOptions& options = + CheckRpcSendFailureOptions()) { + for (size_t i = 0; options.continue_predicate(i); ++i) { + const Status status = SendRpc(options.rpc_options); + EXPECT_FALSE(status.ok()); + if (options.expected_error_code != StatusCode::OK) { + EXPECT_EQ(options.expected_error_code, status.error_code()); + } + } + } + + bool WaitForNack( + std::function get_state, + StatusCode expected_status = StatusCode::UNAVAILABLE) { + auto deadline = absl::Now() + absl::Seconds(30); + bool success = true; + CheckRpcSendFailure(CheckRpcSendFailureOptions() + .set_continue_predicate([&](size_t) { + if (absl::Now() >= deadline) { + success = false; + return false; + } + return get_state() != + AdsServiceImpl::ResponseState::NACKED; + }) + .set_expected_error_code(expected_status)); + return success; + } + + bool WaitForLdsNack(StatusCode expected_status = StatusCode::UNAVAILABLE) { + return WaitForNack( + [&]() { + return balancers_[0]->ads_service()->lds_response_state().state; + }, + expected_status); + } + + bool WaitForRdsNack() { + return WaitForNack( + [&]() { return RouteConfigurationResponseState(0).state; }); + } + + bool WaitForCdsNack() { + return WaitForNack([&]() { + return balancers_[0]->ads_service()->cds_response_state().state; + }); + } + + bool WaitForEdsNack() { + return WaitForNack([&]() { + return balancers_[0]->ads_service()->eds_response_state().state; + }); + } + + AdsServiceImpl::ResponseState RouteConfigurationResponseState(int idx) const { + AdsServiceImpl* ads_service = balancers_[idx]->ads_service(); + if (GetParam().enable_rds_testing()) { + return ads_service->rds_response_state(); + } + return ads_service->lds_response_state(); + } + + void SetListenerAndRouteConfiguration( + int idx, Listener listener, const RouteConfiguration& route_config) { + auto* api_listener = + listener.mutable_api_listener()->mutable_api_listener(); + HttpConnectionManager http_connection_manager; + api_listener->UnpackTo(&http_connection_manager); + if (GetParam().enable_rds_testing()) { + auto* rds = http_connection_manager.mutable_rds(); + rds->set_route_config_name(kDefaultRouteConfigurationName); + rds->mutable_config_source()->mutable_ads(); + balancers_[idx]->ads_service()->SetRdsResource(route_config); + } else { + *http_connection_manager.mutable_route_config() = route_config; + } + api_listener->PackFrom(http_connection_manager); + 
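// When RDS is enabled, the RouteConfiguration was already registered with the + // ADS service above; either way, the repacked Listener is pushed to the ADS + // service next so the client picks up the new routing configuration. + 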
balancers_[idx]->ads_service()->SetLdsResource(listener); + } + + void SetRouteConfiguration(int idx, const RouteConfiguration& route_config, + const Listener* listener_to_copy = nullptr) { + if (GetParam().enable_rds_testing()) { + balancers_[idx]->ads_service()->SetRdsResource(route_config); + } else { + Listener listener(listener_to_copy == nullptr ? default_listener_ + : *listener_to_copy); + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *(http_connection_manager.mutable_route_config()) = route_config; + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[idx]->ads_service()->SetLdsResource(listener); + } + } + + struct EdsResourceArgs { + struct Endpoint { + explicit Endpoint(int port, + HealthStatus health_status = HealthStatus::UNKNOWN, + int lb_weight = 1) + : port(port), health_status(health_status), lb_weight(lb_weight) {} + + int port; + HealthStatus health_status; + int lb_weight; + }; + + struct Locality { + Locality(std::string sub_zone, std::vector endpoints, + int lb_weight = kDefaultLocalityWeight, + int priority = kDefaultLocalityPriority) + : sub_zone(std::move(sub_zone)), + endpoints(std::move(endpoints)), + lb_weight(lb_weight), + priority(priority) {} + + const std::string sub_zone; + std::vector endpoints; + int lb_weight; + int priority; + }; + + EdsResourceArgs() = default; + explicit EdsResourceArgs(std::vector locality_list) + : locality_list(std::move(locality_list)) {} + + std::vector locality_list; + std::map drop_categories; + FractionalPercent::DenominatorType drop_denominator = + FractionalPercent::MILLION; + }; + + EdsResourceArgs::Endpoint CreateEndpoint( + size_t backend_idx, HealthStatus health_status = HealthStatus::UNKNOWN, + int lb_weight = 1) { + return EdsResourceArgs::Endpoint(backends_[backend_idx]->port(), + health_status, lb_weight); + } + + std::vector CreateEndpointsForBackends( + size_t start_index = 0, size_t stop_index = 0, + HealthStatus health_status = HealthStatus::UNKNOWN, int lb_weight = 1) { + if (stop_index == 0) stop_index = backends_.size(); + std::vector endpoints; + for (size_t i = start_index; i < stop_index; ++i) { + endpoints.emplace_back(CreateEndpoint(i, health_status, lb_weight)); + } + return endpoints; + } + + EdsResourceArgs::Endpoint MakeNonExistantEndpoint() { + return EdsResourceArgs::Endpoint(grpc_pick_unused_port_or_die()); + } + + ClusterLoadAssignment BuildEdsResource( + const EdsResourceArgs& args, + const char* eds_service_name = kDefaultEdsServiceName) { + ClusterLoadAssignment assignment; + assignment.set_cluster_name(eds_service_name); + for (const auto& locality : args.locality_list) { + auto* endpoints = assignment.add_endpoints(); + endpoints->mutable_load_balancing_weight()->set_value(locality.lb_weight); + endpoints->set_priority(locality.priority); + endpoints->mutable_locality()->set_region(kDefaultLocalityRegion); + endpoints->mutable_locality()->set_zone(kDefaultLocalityZone); + endpoints->mutable_locality()->set_sub_zone(locality.sub_zone); + for (size_t i = 0; i < locality.endpoints.size(); ++i) { + const int& port = locality.endpoints[i].port; + auto* lb_endpoints = endpoints->add_lb_endpoints(); + if (locality.endpoints.size() > i && + locality.endpoints[i].health_status != HealthStatus::UNKNOWN) { + lb_endpoints->set_health_status(locality.endpoints[i].health_status); + } + if (locality.endpoints.size() > i && + 
locality.endpoints[i].lb_weight >= 1) { + lb_endpoints->mutable_load_balancing_weight()->set_value( + locality.endpoints[i].lb_weight); + } + auto* endpoint = lb_endpoints->mutable_endpoint(); + auto* address = endpoint->mutable_address(); + auto* socket_address = address->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(port); + } + } + if (!args.drop_categories.empty()) { + auto* policy = assignment.mutable_policy(); + for (const auto& p : args.drop_categories) { + const std::string& name = p.first; + const uint32_t parts_per_million = p.second; + auto* drop_overload = policy->add_drop_overloads(); + drop_overload->set_category(name); + auto* drop_percentage = drop_overload->mutable_drop_percentage(); + drop_percentage->set_numerator(parts_per_million); + drop_percentage->set_denominator(args.drop_denominator); + } + } + return assignment; + } + + public: + // This method could benefit test subclasses; to make it accessible + // via bind with a qualified name, it needs to be public. + void SetEdsResourceWithDelay(size_t i, + const ClusterLoadAssignment& assignment, + int delay_ms) { + GPR_ASSERT(delay_ms > 0); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms)); + balancers_[i]->ads_service()->SetEdsResource(assignment); + } + + protected: + class XdsServingStatusNotifier + : public grpc::experimental::XdsServerServingStatusNotifierInterface { + public: + void OnServingStatusUpdate(std::string uri, + ServingStatusUpdate update) override { + grpc_core::MutexLock lock(&mu_); + status_map[uri] = update.status; + cond_.Signal(); + } + + void WaitOnServingStatusChange(std::string uri, + grpc::StatusCode expected_status) { + grpc_core::MutexLock lock(&mu_); + std::map::iterator it; + while ((it = status_map.find(uri)) == status_map.end() || + it->second.error_code() != expected_status) { + cond_.Wait(&mu_); + } + } + + private: + grpc_core::Mutex mu_; + grpc_core::CondVar cond_; + std::map status_map ABSL_GUARDED_BY(mu_); + }; + + class ServerThread { + public: + explicit ServerThread(XdsEnd2endTest* test_obj, + bool use_xds_enabled_server = false) + : test_obj_(test_obj), + port_(grpc_pick_unused_port_or_die()), + use_xds_enabled_server_(use_xds_enabled_server) {} + virtual ~ServerThread(){}; + + void Start() { + gpr_log(GPR_INFO, "starting %s server on port %d", Type(), port_); + GPR_ASSERT(!running_); + running_ = true; + StartAllServices(); + grpc_core::Mutex mu; + // We need to acquire the lock here in order to prevent the notify_one + // by ServerThread::Serve from firing before the wait below is hit. + grpc_core::MutexLock lock(&mu); + grpc_core::CondVar cond; + thread_ = absl::make_unique( + std::bind(&ServerThread::Serve, this, &mu, &cond)); + cond.Wait(&mu); + gpr_log(GPR_INFO, "%s server startup complete", Type()); + } + + void Serve(grpc_core::Mutex* mu, grpc_core::CondVar* cond) { + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. 
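+ // Note that cond->Signal() below is not called until BuildAndStart() has
+ // returned, so Start() does not return to the test until the server is
+ // actually serving.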
+ grpc_core::MutexLock lock(mu); + std::ostringstream server_address; + server_address << "localhost:" << port_; + if (use_xds_enabled_server_) { + XdsServerBuilder builder; + if (GetParam().bootstrap_source() == + TestType::kBootstrapFromChannelArg) { + builder.SetOption( + absl::make_unique(test_obj_)); + } + builder.set_status_notifier(¬ifier_); + builder.AddListeningPort(server_address.str(), Credentials()); + RegisterAllServices(&builder); + server_ = builder.BuildAndStart(); + } else { + ServerBuilder builder; + builder.AddListeningPort(server_address.str(), Credentials()); + RegisterAllServices(&builder); + server_ = builder.BuildAndStart(); + } + cond->Signal(); + } + + void Shutdown() { + if (!running_) return; + gpr_log(GPR_INFO, "%s about to shutdown", Type()); + ShutdownAllServices(); + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + gpr_log(GPR_INFO, "%s shutdown completed", Type()); + running_ = false; + } + + virtual std::shared_ptr Credentials() { + return std::make_shared( + grpc_fake_transport_security_server_credentials_create()); + } + + int port() const { return port_; } + + bool use_xds_enabled_server() const { return use_xds_enabled_server_; } + + XdsServingStatusNotifier* notifier() { return ¬ifier_; } + + private: + class XdsChannelArgsServerBuilderOption + : public ::grpc::ServerBuilderOption { + public: + explicit XdsChannelArgsServerBuilderOption(XdsEnd2endTest* test_obj) + : test_obj_(test_obj) {} + + void UpdateArguments(grpc::ChannelArguments* args) override { + args->SetString( + GRPC_ARG_TEST_ONLY_DO_NOT_USE_IN_PROD_XDS_BOOTSTRAP_CONFIG, + GetParam().use_v2() ? kBootstrapFileV2 : kBootstrapFileV3); + args->SetPointerWithVtable( + GRPC_ARG_TEST_ONLY_DO_NOT_USE_IN_PROD_XDS_CLIENT_CHANNEL_ARGS, + &test_obj_->xds_channel_args_, &kChannelArgsArgVtable); + } + + void UpdatePlugins( + std::vector>* /*plugins*/) + override {} + + private: + XdsEnd2endTest* test_obj_; + }; + + virtual void RegisterAllServices(ServerBuilder* builder) = 0; + virtual void StartAllServices() = 0; + virtual void ShutdownAllServices() = 0; + + virtual const char* Type() = 0; + + XdsEnd2endTest* test_obj_; + const int port_; + std::unique_ptr server_; + XdsServingStatusNotifier notifier_; + std::unique_ptr thread_; + bool running_ = false; + const bool use_xds_enabled_server_; + }; + + class BackendServerThread : public ServerThread { + public: + explicit BackendServerThread(XdsEnd2endTest* test_obj, + bool use_xds_enabled_server) + : ServerThread(test_obj, use_xds_enabled_server) {} + + BackendServiceImpl<::grpc::testing::EchoTestService::Service>* + backend_service() { + return &backend_service_; + } + BackendServiceImpl<::grpc::testing::EchoTest1Service::Service>* + backend_service1() { + return &backend_service1_; + } + BackendServiceImpl<::grpc::testing::EchoTest2Service::Service>* + backend_service2() { + return &backend_service2_; + } + + std::shared_ptr Credentials() override { + if (GetParam().use_xds_credentials()) { + if (use_xds_enabled_server()) { + // We are testing server's use of XdsServerCredentials + return XdsServerCredentials(InsecureServerCredentials()); + } else { + // We are testing client's use of XdsCredentials + std::string root_cert = ReadFile(kCaCertPath); + std::string identity_cert = ReadFile(kServerCertPath); + std::string private_key = ReadFile(kServerKeyPath); + std::vector + identity_key_cert_pairs = {{private_key, identity_cert}}; + auto certificate_provider = std::make_shared< + 
grpc::experimental::StaticDataCertificateProvider>( + root_cert, identity_key_cert_pairs); + grpc::experimental::TlsServerCredentialsOptions options( + certificate_provider); + options.watch_root_certs(); + options.watch_identity_key_cert_pairs(); + options.set_cert_request_type( + GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY); + return grpc::experimental::TlsServerCredentials(options); + } + } + return ServerThread::Credentials(); + } + + private: + void RegisterAllServices(ServerBuilder* builder) override { + builder->RegisterService(&backend_service_); + builder->RegisterService(&backend_service1_); + builder->RegisterService(&backend_service2_); + } + + void StartAllServices() override { + backend_service_.Start(); + backend_service1_.Start(); + backend_service2_.Start(); + } + + void ShutdownAllServices() override { + backend_service_.Shutdown(); + backend_service1_.Shutdown(); + backend_service2_.Shutdown(); + } + + const char* Type() override { return "Backend"; } + + BackendServiceImpl<::grpc::testing::EchoTestService::Service> + backend_service_; + BackendServiceImpl<::grpc::testing::EchoTest1Service::Service> + backend_service1_; + BackendServiceImpl<::grpc::testing::EchoTest2Service::Service> + backend_service2_; + }; + + class BalancerServerThread : public ServerThread { + public: + explicit BalancerServerThread(XdsEnd2endTest* test_obj, + int client_load_reporting_interval = 0) + : ServerThread(test_obj), + ads_service_(new AdsServiceImpl()), + lrs_service_(new LrsServiceImpl(client_load_reporting_interval, + {kDefaultClusterName})) {} + + AdsServiceImpl* ads_service() { return ads_service_.get(); } + LrsServiceImpl* lrs_service() { return lrs_service_.get(); } + + private: + void RegisterAllServices(ServerBuilder* builder) override { + builder->RegisterService(ads_service_->v2_rpc_service()); + builder->RegisterService(ads_service_->v3_rpc_service()); + builder->RegisterService(lrs_service_->v2_rpc_service()); + builder->RegisterService(lrs_service_->v3_rpc_service()); + } + + void StartAllServices() override { + ads_service_->Start(); + lrs_service_->Start(); + } + + void ShutdownAllServices() override { + ads_service_->Shutdown(); + lrs_service_->Shutdown(); + } + + const char* Type() override { return "Balancer"; } + + std::shared_ptr ads_service_; + std::shared_ptr lrs_service_; + }; + +#ifndef DISABLED_XDS_PROTO_IN_CC + class AdminServerThread : public ServerThread { + public: + explicit AdminServerThread(XdsEnd2endTest* test_obj) + : ServerThread(test_obj) {} + + private: + void RegisterAllServices(ServerBuilder* builder) override { + builder->RegisterService(&csds_service_); + } + void StartAllServices() override {} + void ShutdownAllServices() override {} + + const char* Type() override { return "Admin"; } + + grpc::xds::experimental::ClientStatusDiscoveryService csds_service_; + }; +#endif // DISABLED_XDS_PROTO_IN_CC + + class LongRunningRpc { + public: + void StartRpc(grpc::testing::EchoTestService::Stub* stub, + const RpcOptions& rpc_options = + RpcOptions().set_timeout_ms(0).set_client_cancel_after_us( + 1 * 1000 * 1000)) { + sender_thread_ = std::thread([this, stub, rpc_options]() { + EchoRequest request; + EchoResponse response; + rpc_options.SetupRpc(&context_, &request); + status_ = stub->Echo(&context_, request, &response); + }); + } + + void CancelRpc() { + context_.TryCancel(); + if (sender_thread_.joinable()) sender_thread_.join(); + } + + Status GetStatus() { + if (sender_thread_.joinable()) sender_thread_.join(); + return status_; + } + + private: + 
std::thread sender_thread_; + ClientContext context_; + Status status_; + }; + + struct ConcurrentRpc { + ClientContext context; + Status status; + grpc_millis elapsed_time; + EchoResponse response; + }; + + std::vector SendConcurrentRpcs( + grpc::testing::EchoTestService::Stub* stub, size_t num_rpcs, + const RpcOptions& rpc_options) { + // Variables for RPCs. + std::vector rpcs(num_rpcs); + EchoRequest request; + // Variables for synchronization + absl::Mutex mu; + absl::CondVar cv; + size_t completed = 0; + // Set-off callback RPCs + for (size_t i = 0; i < num_rpcs; i++) { + ConcurrentRpc* rpc = &rpcs[i]; + rpc_options.SetupRpc(&rpc->context, &request); + grpc_millis t0 = NowFromCycleCounter(); + stub->async()->Echo(&rpc->context, &request, &rpc->response, + [rpc, &mu, &completed, &cv, num_rpcs, t0](Status s) { + rpc->status = s; + rpc->elapsed_time = NowFromCycleCounter() - t0; + bool done; + { + absl::MutexLock lock(&mu); + done = (++completed) == num_rpcs; + } + if (done) cv.Signal(); + }); + } + { + absl::MutexLock lock(&mu); + cv.Wait(&mu); + } + EXPECT_EQ(completed, num_rpcs); + return rpcs; + } + + const size_t num_backends_; + const size_t num_balancers_; + const int client_load_reporting_interval_seconds_; + bool ipv6_only_ = false; + std::shared_ptr channel_; + std::unique_ptr stub_; + std::unique_ptr stub1_; + std::unique_ptr stub2_; + std::vector> backends_; + std::vector> balancers_; + grpc_core::RefCountedPtr + response_generator_; + grpc_core::RefCountedPtr + lb_channel_response_generator_; + grpc_core::RefCountedPtr + logical_dns_cluster_resolver_response_generator_; + int xds_resource_does_not_exist_timeout_ms_ = 0; + absl::InlinedVector xds_channel_args_to_add_; + grpc_channel_args xds_channel_args_; + + Listener default_listener_; + RouteConfiguration default_route_config_; + Cluster default_cluster_; + bool use_xds_enabled_server_; + bool bootstrap_contents_from_env_var_; +}; + +class BasicTest : public XdsEnd2endTest { + public: + BasicTest() : XdsEnd2endTest(4, 1) {} +}; + +// Tests that the balancer sends the correct response to the client, and the +// client sends RPCs to the backends using the default child policy. +TEST_P(BasicTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backends_[i]->backend_service()->request_count()); + } + // Check LB policy name for the channel. + EXPECT_EQ( + (GetParam().use_fake_resolver() ? 
"xds_cluster_resolver_experimental" + : "xds_cluster_manager_experimental"), + channel_->GetLoadBalancingPolicyName()); +} + +TEST_P(BasicTest, IgnoresUnhealthyEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + auto endpoints = CreateEndpointsForBackends(); + endpoints[0].health_status = HealthStatus::DRAINING; + EdsResourceArgs args({ + {"locality0", std::move(endpoints), kDefaultLocalityWeight, + kDefaultLocalityPriority}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(/*start_index=*/1); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * (num_backends_ - 1)); + // Each backend should have gotten 100 requests. + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backends_[i]->backend_service()->request_count()); + } +} + +// Tests that subchannel sharing works when the same backend is listed +// multiple times. +TEST_P(BasicTest, SameBackendListedMultipleTimes) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Same backend listed twice. + auto endpoints = CreateEndpointsForBackends(0, 1); + endpoints.push_back(endpoints.front()); + EdsResourceArgs args({ + {"locality0", endpoints}, + }); + const size_t kNumRpcsPerAddress = 10; + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // We need to wait for the backend to come online. + WaitForBackend(0); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * endpoints.size()); + // Backend should have gotten 20 requests. + EXPECT_EQ(kNumRpcsPerAddress * endpoints.size(), + backends_[0]->backend_service()->request_count()); + // And they should have come from a single client port, because of + // subchannel sharing. + EXPECT_EQ(1UL, backends_[0]->backend_service()->clients().size()); +} + +// Tests that RPCs will be blocked until a non-empty serverlist is received. +TEST_P(BasicTest, InitiallyEmptyServerlist) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const int kCallDeadlineMs = kServerlistDelayMs * 2; + // First response is an empty serverlist, sent right away. + EdsResourceArgs::Locality empty_locality("locality0", {}); + EdsResourceArgs args({ + empty_locality, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Send non-empty serverlist only after kServerlistDelayMs. + args = EdsResourceArgs({ + {"locality0", CreateEndpointsForBackends()}, + }); + std::thread delayed_resource_setter(std::bind( + &BasicTest::SetEdsResourceWithDelay, this, 0, + BuildEdsResource(args, DefaultEdsServiceName()), kServerlistDelayMs)); + const auto t0 = system_clock::now(); + // Client will block: LB will initially send empty serverlist. + CheckRpcSendOk( + 1, RpcOptions().set_timeout_ms(kCallDeadlineMs).set_wait_for_ready(true)); + const auto ellapsed_ms = + std::chrono::duration_cast( + system_clock::now() - t0); + // but eventually, the LB sends a serverlist update that allows the call to + // proceed. 
The call delay must be larger than the delay in sending the + // populated serverlist but shorter than the call's deadline (otherwise the + // RPC would have failed with DEADLINE_EXCEEDED). + EXPECT_GT(ellapsed_ms.count(), kServerlistDelayMs); + delayed_resource_setter.join(); +} + +// Tests that RPCs will fail with UNAVAILABLE instead of DEADLINE_EXCEEDED if +// all the servers are unreachable. +TEST_P(BasicTest, AllServersUnreachableFailFast) { + // Set the RPC timeout to 5 seconds to ensure there is enough time + // for communication with the xDS server to take place upon test start up. + const uint32_t kRpcTimeoutMs = 5000; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumUnreachableServers = 5; + std::vector<EdsResourceArgs::Endpoint> endpoints; + for (size_t i = 0; i < kNumUnreachableServers; ++i) { + endpoints.emplace_back(grpc_pick_unused_port_or_die()); + } + EdsResourceArgs args({ + {"locality0", endpoints}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + const Status status = SendRpc(RpcOptions().set_timeout_ms(kRpcTimeoutMs)); + // The error shouldn't be DEADLINE_EXCEEDED because the timeout is set to 5 + // seconds, and we should discover in that time that the target backend is + // down. + EXPECT_EQ(StatusCode::UNAVAILABLE, status.error_code()); +} + +// Tests that RPCs fail when the backends are down, and will succeed again +// after the backends are restarted. +TEST_P(BasicTest, BackendsRestart) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + WaitForAllBackends(); + // Stop backends. RPCs should fail. + ShutdownAllBackends(); + // We send multiple failed requests instead of just one to ensure that the + // client notices that all backends are down before we restart them. If we + // didn't do this, then a single RPC could fail here due to the race + // condition between the LB pick and the GOAWAY from the chosen backend + // being shut down, which would not actually prove that the client noticed + // that all of the backends are down. Then, when we send another request + // below (which we expect to succeed), if the callbacks happen in the wrong + // order, the same race condition could happen again due to the client not + // yet having noticed that the backends were all down. + CheckRpcSendFailure(CheckRpcSendFailureOptions().set_times(num_backends_)); + // Restart all backends. RPCs should start succeeding again. + StartAllBackends(); + CheckRpcSendOk(1, RpcOptions().set_timeout_ms(2000).set_wait_for_ready(true)); +} + +TEST_P(BasicTest, IgnoresDuplicateUpdates) { + const size_t kNumRpcsPerAddress = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server, but send an EDS update in + // between. If the update is not ignored, this will cause the + // round_robin policy to see an update, which will randomly reset its + // position in the address list.
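+ // Each loop iteration sends 2 RPCs, re-sends the identical EDS resource,
+ // and then sends 2 more RPCs, so the 4 backends used by BasicTest are
+ // expected to end up with exactly kNumRpcsPerAddress requests each; that
+ // would likely not hold if the duplicate updates made round_robin restart
+ // from a random position.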
+ for (size_t i = 0; i < kNumRpcsPerAddress; ++i) { + CheckRpcSendOk(2); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + CheckRpcSendOk(2); + } + // Each backend should have gotten the right number of requests. + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backends_[i]->backend_service()->request_count()); + } +} + +using XdsResolverOnlyTest = BasicTest; + +TEST_P(XdsResolverOnlyTest, ResourceTypeVersionPersistsAcrossStreamRestarts) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Wait for backends to come online. + WaitForAllBackends(0, 1); + // Stop balancer. + balancers_[0]->Shutdown(); + // Tell balancer to require minimum version 1 for all resource types. + balancers_[0]->ads_service()->SetResourceMinVersion(kLdsTypeUrl, 1); + balancers_[0]->ads_service()->SetResourceMinVersion(kRdsTypeUrl, 1); + balancers_[0]->ads_service()->SetResourceMinVersion(kCdsTypeUrl, 1); + balancers_[0]->ads_service()->SetResourceMinVersion(kEdsTypeUrl, 1); + // Update backend, just so we can be sure that the client has + // reconnected to the balancer. + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args2)); + // Restart balancer. + balancers_[0]->Start(); + // Make sure client has reconnected. + WaitForAllBackends(1, 2); +} + +// Tests switching over from one cluster to another. +TEST_P(XdsResolverOnlyTest, ChangeClusters) { + const char* kNewClusterName = "new_cluster_name"; + const char* kNewEdsServiceName = "new_eds_service_name"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(0, 2); + // Populate new EDS resource. + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(2, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsServiceName)); + // Populate new CDS resource. + Cluster new_cluster = default_cluster_; + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Change RDS resource to point to new cluster. + RouteConfiguration new_route_config = default_route_config_; + new_route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->set_cluster(kNewClusterName); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + // Wait for all new backends to be used. + std::tuple counts = WaitForAllBackends(2, 4); + // Make sure no RPCs failed in the transition. + EXPECT_EQ(0, std::get<1>(counts)); +} + +// Tests that we go into TRANSIENT_FAILURE if the Cluster disappears. +TEST_P(XdsResolverOnlyTest, ClusterRemoved) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Unset CDS resource. 
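+ // With the Cluster resource gone, the channel is expected to go into
+ // TRANSIENT_FAILURE, so RPCs should start failing; the loop below simply
+ // waits for the first failure.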
+ balancers_[0]->ads_service()->UnsetResource(kCdsTypeUrl, kDefaultClusterName); + // Wait for RPCs to start failing. + do { + } while (SendRpc(RpcOptions(), nullptr).ok()); + // Make sure RPCs are still failing. + CheckRpcSendFailure(CheckRpcSendFailureOptions().set_times(1000)); + // Make sure we ACK'ed the update. + EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that we restart all xDS requests when we reestablish the ADS call. +TEST_P(XdsResolverOnlyTest, RestartsRequestsUponReconnection) { + // Manually configure use of RDS. + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + auto* rds = http_connection_manager.mutable_rds(); + rds->set_route_config_name(kDefaultRouteConfigurationName); + rds->mutable_config_source()->mutable_ads(); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + balancers_[0]->ads_service()->SetRdsResource(default_route_config_); + const char* kNewClusterName = "new_cluster_name"; + const char* kNewEdsServiceName = "new_eds_service_name"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(0, 2); + // Now shut down and restart the balancer. When the client + // reconnects, it should automatically restart the requests for all + // resource types. + balancers_[0]->Shutdown(); + balancers_[0]->Start(); + // Make sure things are still working. + CheckRpcSendOk(100); + // Populate new EDS resource. + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(2, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsServiceName)); + // Populate new CDS resource. + Cluster new_cluster = default_cluster_; + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Change RDS resource to point to new cluster. + RouteConfiguration new_route_config = default_route_config_; + new_route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->set_cluster(kNewClusterName); + balancers_[0]->ads_service()->SetRdsResource(new_route_config); + // Wait for all new backends to be used. + std::tuple counts = WaitForAllBackends(2, 4); + // Make sure no RPCs failed in the transition. + EXPECT_EQ(0, std::get<1>(counts)); +} + +TEST_P(XdsResolverOnlyTest, DefaultRouteSpecifiesSlashPrefix) { + RouteConfiguration route_config = default_route_config_; + route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_match() + ->set_prefix("/"); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // We need to wait for all backends to come online. 
+ WaitForAllBackends(); +} + +TEST_P(XdsResolverOnlyTest, CircuitBreaking) { + constexpr size_t kMaxConcurrentRequests = 10; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Update CDS resource to set max concurrent request. + CircuitBreakers circuit_breaks; + Cluster cluster = default_cluster_; + auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds(); + threshold->set_priority(RoutingPriority::DEFAULT); + threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Send exactly max_concurrent_requests long RPCs. + LongRunningRpc rpcs[kMaxConcurrentRequests]; + for (size_t i = 0; i < kMaxConcurrentRequests; ++i) { + rpcs[i].StartRpc(stub_.get()); + } + // Wait for all RPCs to be in flight. + while (backends_[0]->backend_service()->RpcsWaitingForClientCancel() < + kMaxConcurrentRequests) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1 * 1000, GPR_TIMESPAN))); + } + // Sending a RPC now should fail, the error message should tell us + // we hit the max concurrent requests limit and got dropped. + Status status = SendRpc(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), "circuit breaker drop"); + // Cancel one RPC to allow another one through + rpcs[0].CancelRpc(); + status = SendRpc(); + EXPECT_TRUE(status.ok()); + for (size_t i = 1; i < kMaxConcurrentRequests; ++i) { + rpcs[i].CancelRpc(); + } + // Make sure RPCs go to the correct backend: + EXPECT_EQ(kMaxConcurrentRequests + 1, + backends_[0]->backend_service()->request_count()); +} + +TEST_P(XdsResolverOnlyTest, CircuitBreakingMultipleChannelsShareCallCounter) { + constexpr size_t kMaxConcurrentRequests = 10; + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Update CDS resource to set max concurrent request. + CircuitBreakers circuit_breaks; + Cluster cluster = default_cluster_; + auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds(); + threshold->set_priority(RoutingPriority::DEFAULT); + threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Create second channel. + auto response_generator2 = + grpc_core::MakeRefCounted(); + auto lb_response_generator2 = + grpc_core::MakeRefCounted(); + grpc_arg xds_arg = grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + lb_response_generator2.get()); + grpc_channel_args xds_channel_args2 = {1, &xds_arg}; + auto channel2 = CreateChannel( + /*failover_timeout=*/0, /*server_name=*/kServerName, + response_generator2.get(), &xds_channel_args2); + auto stub2 = grpc::testing::EchoTestService::NewStub(channel2); + // Set resolution results for both channels and for the xDS channel. + SetNextResolution({}); + SetNextResolution({}, response_generator2.get()); + SetNextResolutionForLbChannelAllBalancers(); + SetNextResolutionForLbChannelAllBalancers(nullptr, nullptr, + lb_response_generator2.get()); + // Send exactly max_concurrent_requests long RPCs, alternating between + // the two channels. 
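+ // Both channels resolve to the same cluster, so they are expected to share
+ // a single circuit-breaker call counter; the next RPC after the limit is
+ // reached should be dropped no matter which channel carried the earlier
+ // RPCs.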
+ LongRunningRpc rpcs[kMaxConcurrentRequests]; + for (size_t i = 0; i < kMaxConcurrentRequests; ++i) { + rpcs[i].StartRpc(i % 2 == 0 ? stub_.get() : stub2.get()); + } + // Wait for all RPCs to be in flight. + while (backends_[0]->backend_service()->RpcsWaitingForClientCancel() < + kMaxConcurrentRequests) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1 * 1000, GPR_TIMESPAN))); + } + // Sending a RPC now should fail, the error message should tell us + // we hit the max concurrent requests limit and got dropped. + Status status = SendRpc(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), "circuit breaker drop"); + // Cancel one RPC to allow another one through + rpcs[0].CancelRpc(); + status = SendRpc(); + EXPECT_TRUE(status.ok()); + for (size_t i = 1; i < kMaxConcurrentRequests; ++i) { + rpcs[i].CancelRpc(); + } + // Make sure RPCs go to the correct backend: + EXPECT_EQ(kMaxConcurrentRequests + 1, + backends_[0]->backend_service()->request_count()); +} + +TEST_P(XdsResolverOnlyTest, ClusterChangeAfterAdsCallFails) { + const char* kNewEdsResourceName = "new_eds_resource_name"; + // Populate EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + SetNextResolutionForLbChannelAllBalancers(); + // Check that the channel is working. + CheckRpcSendOk(); + // Stop and restart the balancer. + balancers_[0]->Shutdown(); + balancers_[0]->Start(); + // Create new EDS resource. + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsResourceName)); + // Change CDS resource to point to new EDS resource. + auto cluster = default_cluster_; + cluster.mutable_eds_cluster_config()->set_service_name(kNewEdsResourceName); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Make sure client sees the change. + // TODO(roth): This should not be allowing errors. The errors are + // being caused by a bug that triggers in the following situation: + // + // 1. xDS call fails. + // 2. When xDS call is restarted, the server sends the updated CDS + // resource that points to the new EDS resource name. + // 3. When the client receives the CDS update, it does two things: + // - Sends the update to the CDS LB policy, which creates a new + // xds_cluster_resolver policy using the new EDS service name. + // - Notices that the CDS update no longer refers to the old EDS + // service name, so removes that resource, notifying the old + // xds_cluster_resolver policy that the resource no longer exists. + // + // Need to figure out a way to fix this bug, and then change this to + // not allow failures. + WaitForBackend(1, WaitForBackendOptions().set_allow_failures(true)); +} + +using GlobalXdsClientTest = BasicTest; + +TEST_P(GlobalXdsClientTest, MultipleChannelsShareXdsClient) { + const char* kNewServerName = "new-server.example.com"; + Listener listener = default_listener_; + listener.set_name(kNewServerName); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + WaitForAllBackends(); + // Create second channel and tell it to connect to kNewServerName. 
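+ // The second channel targets a different server name but should reuse the
+ // process-wide XdsClient, so the ADS server is expected to see only one
+ // connected client (verified below).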
+ auto channel2 = CreateChannel(/*failover_timeout=*/0, kNewServerName); + channel2->GetState(/*try_to_connect=*/true); + ASSERT_TRUE( + channel2->WaitForConnected(grpc_timeout_milliseconds_to_deadline(100))); + // Make sure there's only one client connected. + EXPECT_EQ(1UL, balancers_[0]->ads_service()->clients().size()); +} + +// Tests that the NACK for multiple bad LDS resources includes both errors. +TEST_P(GlobalXdsClientTest, MultipleBadResources) { + constexpr char kServerName2[] = "server.other.com"; + constexpr char kServerName3[] = "server.another.com"; + auto listener = default_listener_; + listener.clear_api_listener(); + balancers_[0]->ads_service()->SetLdsResource(listener); + listener.set_name(kServerName2); + balancers_[0]->ads_service()->SetLdsResource(listener); + listener = default_listener_; + listener.set_name(kServerName3); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + // Need to create a second channel to subscribe to a second LDS resource. + auto channel2 = CreateChannel(0, kServerName2); + auto stub2 = grpc::testing::EchoTestService::NewStub(channel2); + { + ClientContext context; + EchoRequest request; + request.set_message(kRequestMessage); + EchoResponse response; + grpc::Status status = stub2->Echo(&context, request, &response); + EXPECT_FALSE(status.ok()); + // Wait for second NACK to be reported to xDS server. + auto deadline = absl::Now() + absl::Seconds(30); + bool timed_out = false; + CheckRpcSendFailure( + CheckRpcSendFailureOptions().set_continue_predicate([&](size_t) { + if (absl::Now() >= deadline) { + timed_out = true; + return false; + } + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + return response_state.state != + AdsServiceImpl::ResponseState::NACKED || + ::testing::Matches(::testing::ContainsRegex(absl::StrCat( + kServerName, + ": validation error.*" + "Listener has neither address nor ApiListener.*", + kServerName2, + ": validation error.*" + "Listener has neither address nor ApiListener")))( + response_state.error_message); + })); + ASSERT_FALSE(timed_out); + } + // Now start a new channel with a third server name, this one with a + // valid resource. + auto channel3 = CreateChannel(0, kServerName3); + auto stub3 = grpc::testing::EchoTestService::NewStub(channel3); + { + ClientContext context; + EchoRequest request; + request.set_message(kRequestMessage); + EchoResponse response; + grpc::Status status = stub3->Echo(&context, request, &response); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + } +} + +// Tests that we don't trigger does-not-exist callbacks for a resource +// that was previously valid but is updated to be invalid. +TEST_P(GlobalXdsClientTest, InvalidListenerStillExistsIfPreviouslyCached) { + // Set up valid resources and check that the channel works. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendOk(); + // Now send an update changing the Listener to be invalid. 
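+ // The invalid update should be NACKed, but the previously cached valid
+ // Listener should remain in use, so RPCs are expected to keep succeeding
+ // while we wait for the NACK below.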
+ auto listener = default_listener_; + listener.clear_api_listener(); + balancers_[0]->ads_service()->SetLdsResource(listener); + // Wait for xDS server to see NACK. + auto deadline = absl::Now() + absl::Seconds(30); + do { + CheckRpcSendOk(); + ASSERT_LT(absl::Now(), deadline); + } while (balancers_[0]->ads_service()->lds_response_state().state != + AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::ContainsRegex(absl::StrCat( + kServerName, + ": validation error.*" + "Listener has neither address nor ApiListener"))); + // Check one more time, just to make sure it still works after NACK. + CheckRpcSendOk(); +} + +class XdsResolverLoadReportingOnlyTest : public XdsEnd2endTest { + public: + XdsResolverLoadReportingOnlyTest() : XdsEnd2endTest(4, 1, 3) {} +}; + +// Tests load reporting when switching over from one cluster to another. +TEST_P(XdsResolverLoadReportingOnlyTest, ChangeClusters) { + const char* kNewClusterName = "new_cluster_name"; + const char* kNewEdsServiceName = "new_eds_service_name"; + balancers_[0]->lrs_service()->set_cluster_names( + {kDefaultClusterName, kNewClusterName}); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // cluster kDefaultClusterName -> locality0 -> backends 0 and 1 + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // cluster kNewClusterName -> locality1 -> backends 2 and 3 + EdsResourceArgs args2({ + {"locality1", CreateEndpointsForBackends(2, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsServiceName)); + // CDS resource for kNewClusterName. + Cluster new_cluster = default_cluster_; + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Wait for all backends to come online. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(0, 2); + // The load report received at the balancer should be correct. + std::vector load_report = + balancers_[0]->lrs_service()->WaitForLoadReport(); + EXPECT_THAT( + load_report, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Property(&ClientStats::cluster_name, kDefaultClusterName), + ::testing::Property( + &ClientStats::locality_stats, + ::testing::ElementsAre(::testing::Pair( + "locality0", + ::testing::AllOf( + ::testing::Field(&ClientStats::LocalityStats:: + total_successful_requests, + num_ok), + ::testing::Field(&ClientStats::LocalityStats:: + total_requests_in_progress, + 0UL), + ::testing::Field( + &ClientStats::LocalityStats::total_error_requests, + num_failure), + ::testing::Field( + &ClientStats::LocalityStats::total_issued_requests, + num_failure + num_ok))))), + ::testing::Property(&ClientStats::total_dropped_requests, + num_drops)))); + // Change RDS resource to point to new cluster. + RouteConfiguration new_route_config = default_route_config_; + new_route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->set_cluster(kNewClusterName); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + // Wait for all new backends to be used. + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(2, 4); + // The load report received at the balancer should be correct. 
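+ // After the switch, the report is expected to contain stats for both
+ // clusters: the tail end of the traffic to kDefaultClusterName plus the new
+ // traffic to kNewClusterName, with combined totals matching the counts
+ // returned by WaitForAllBackends() above.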
+ load_report = balancers_[0]->lrs_service()->WaitForLoadReport(); + EXPECT_THAT( + load_report, + ::testing::ElementsAre( + ::testing::AllOf( + ::testing::Property(&ClientStats::cluster_name, + kDefaultClusterName), + ::testing::Property( + &ClientStats::locality_stats, + ::testing::ElementsAre(::testing::Pair( + "locality0", + ::testing::AllOf( + ::testing::Field(&ClientStats::LocalityStats:: + total_successful_requests, + ::testing::Lt(num_ok)), + ::testing::Field(&ClientStats::LocalityStats:: + total_requests_in_progress, + 0UL), + ::testing::Field( + &ClientStats::LocalityStats::total_error_requests, + ::testing::Le(num_failure)), + ::testing::Field( + &ClientStats::LocalityStats:: + total_issued_requests, + ::testing::Le(num_failure + num_ok)))))), + ::testing::Property(&ClientStats::total_dropped_requests, + num_drops)), + ::testing::AllOf( + ::testing::Property(&ClientStats::cluster_name, kNewClusterName), + ::testing::Property( + &ClientStats::locality_stats, + ::testing::ElementsAre(::testing::Pair( + "locality1", + ::testing::AllOf( + ::testing::Field(&ClientStats::LocalityStats:: + total_successful_requests, + ::testing::Le(num_ok)), + ::testing::Field(&ClientStats::LocalityStats:: + total_requests_in_progress, + 0UL), + ::testing::Field( + &ClientStats::LocalityStats::total_error_requests, + ::testing::Le(num_failure)), + ::testing::Field( + &ClientStats::LocalityStats:: + total_issued_requests, + ::testing::Le(num_failure + num_ok)))))), + ::testing::Property(&ClientStats::total_dropped_requests, + num_drops)))); + int total_ok = 0; + int total_failure = 0; + for (const ClientStats& client_stats : load_report) { + total_ok += client_stats.total_successful_requests(); + total_failure += client_stats.total_error_requests(); + } + EXPECT_EQ(total_ok, num_ok); + EXPECT_EQ(total_failure, num_failure); + // The LRS service got a single request, and sent a single response. + EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count()); + EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count()); +} + +using SecureNamingTest = BasicTest; + +// Tests that secure naming check passes if target name is expected. +TEST_P(SecureNamingTest, TargetNameIsExpected) { + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}, nullptr, "xds_server"); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + CheckRpcSendOk(); +} + +// Tests that secure naming check fails if target name is unexpected. +TEST_P(SecureNamingTest, TargetNameIsUnexpected) { + GRPC_GTEST_FLAG_SET_DEATH_TEST_STYLE("threadsafe"); + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}, nullptr, + "incorrect_server_name"); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Make sure that we blow up (via abort() from the security connector) when + // the name from the balancer doesn't match expectations. + ASSERT_DEATH_IF_SUPPORTED({ CheckRpcSendOk(); }, ""); +} + +using LdsTest = BasicTest; + +// Tests that LDS client should send a NACK if there is no API listener in the +// Listener in the LDS response. 
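For context on the LdsTest cases that follow: they all verify rejection the same way, by pushing an invalid Listener, waiting for the client to NACK it, and then checking the error message recorded by the fake ADS server. A condensed sketch of that shared pattern, using only fixture helpers that appear elsewhere in this diff (the ExpectLdsNackWith wrapper itself is hypothetical; the real tests inline these steps):

    // Hypothetical fixture helper condensing the NACK-check pattern used by
    // the LdsTest cases below.
    void ExpectLdsNackWith(const std::string& expected_substring) {
      SetNextResolutionForLbChannelAllBalancers();
      // Blocks until the fake ADS server records a NACK for LDS, or times out.
      ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK";
      const auto response_state =
          balancers_[0]->ads_service()->lds_response_state();
      EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
      EXPECT_THAT(response_state.error_message,
                  ::testing::HasSubstr(expected_substring));
    }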
+TEST_P(LdsTest, NoApiListener) { + auto listener = default_listener_; + listener.clear_api_listener(); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("Listener has neither address nor ApiListener")); +} + +// Tests that LDS client should send a NACK if the route_specifier in the +// http_connection_manager is neither inlined route_config nor RDS. +TEST_P(LdsTest, WrongRouteSpecifier) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + http_connection_manager.mutable_scoped_routes(); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "HttpConnectionManager neither has inlined route_config nor RDS.")); +} + +// Tests that LDS client should send a NACK if the rds message in the +// http_connection_manager is missing the config_source field. +TEST_P(LdsTest, RdsMissingConfigSource) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + http_connection_manager.mutable_rds()->set_route_config_name( + kDefaultRouteConfigurationName); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "HttpConnectionManager missing config_source for RDS.")); +} + +// Tests that LDS client should send a NACK if the rds message in the +// http_connection_manager has a config_source field that does not specify +// ADS. 
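A note on the setup these tests share: the ApiListener carries its HttpConnectionManager as a protobuf Any, so each test edits it with an unpack/modify/pack round trip before handing the Listener to the fake ADS server. A sketch of that round trip (every call here appears in the tests in this diff; the RDS mutation is just one example of the step in the middle):

    auto listener = default_listener_;
    HttpConnectionManager http_connection_manager;
    // Pull the HttpConnectionManager out of the Any...
    listener.mutable_api_listener()->mutable_api_listener()->UnpackTo(
        &http_connection_manager);
    // ...mutate it, e.g. point it at RDS...
    http_connection_manager.mutable_rds()->set_route_config_name(
        kDefaultRouteConfigurationName);
    // ...and pack it back before installing the Listener on the ADS server.
    listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
        http_connection_manager);
    balancers_[0]->ads_service()->SetLdsResource(listener);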
+TEST_P(LdsTest, RdsConfigSourceDoesNotSpecifyAds) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + auto* rds = http_connection_manager.mutable_rds(); + rds->set_route_config_name(kDefaultRouteConfigurationName); + rds->mutable_config_source()->mutable_self(); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("HttpConnectionManager ConfigSource for " + "RDS does not specify ADS.")); +} + +// Tests that we NACK non-terminal filters at the end of the list. +TEST_P(LdsTest, NacksNonTerminalHttpFilterAtEndOfList) { + SetNextResolutionForLbChannelAllBalancers(); + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->set_name("unknown"); + filter->mutable_typed_config()->set_type_url( + "grpc.testing.client_only_http_filter"); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "non-terminal filter for config type grpc.testing" + ".client_only_http_filter is the last filter in the chain")); +} + +// Test that we NACK terminal filters that are not at the end of the list. +TEST_P(LdsTest, NacksTerminalFilterBeforeEndOfList) { + SetNextResolutionForLbChannelAllBalancers(); + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "terminal filter for config type envoy.extensions.filters.http" + ".router.v3.Router must be the last filter in the chain")); +} + +// Test that we NACK empty filter names. 
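The HTTP filter tests below exercise the chain-validation rules: filter names must be non-empty and unique, the last filter must be the terminal router filter, and an unknown or config-less filter is NACKed unless it is marked optional. A sketch of appending an extra filter while keeping the chain valid, assuming http_connection_manager has already been unpacked as in the round trip above (the filter name is a placeholder):

    // Copy the terminal router filter (slot 0 in the default config) to the
    // end of the chain so it stays last...
    *http_connection_manager.add_http_filters() =
        http_connection_manager.http_filters(0);
    // ...then turn slot 0 into an optional filter of an unregistered type,
    // which the client ignores instead of NACKing.
    auto* filter = http_connection_manager.mutable_http_filters(0);
    filter->set_name("my.optional.filter");  // any non-empty, unique name
    filter->mutable_typed_config()->PackFrom(Listener());
    filter->set_is_optional(true);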
+TEST_P(LdsTest, RejectsEmptyHttpFilterName) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->Clear(); + filter->mutable_typed_config()->PackFrom(Listener()); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("empty filter name at index 0")); +} + +// Test that we NACK duplicate HTTP filter names. +TEST_P(LdsTest, RejectsDuplicateHttpFilterName) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + http_connection_manager.mutable_http_filters(0) + ->mutable_typed_config() + ->PackFrom(HTTPFault()); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("duplicate HTTP filter name: router")); +} + +// Test that we NACK unknown filter types. +TEST_P(LdsTest, RejectsUnknownHttpFilterType) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->set_name("unknown"); + filter->mutable_typed_config()->PackFrom(Listener()); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("no filter registered for config type " + "envoy.config.listener.v3.Listener")); +} + +// Test that we ignore optional unknown filter types. 
+TEST_P(LdsTest, IgnoresOptionalUnknownHttpFilterType) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->set_name("unknown"); + filter->mutable_typed_config()->PackFrom(Listener()); + filter->set_is_optional(true); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(balancers_[0]->ads_service()->lds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK filters without configs. +TEST_P(LdsTest, RejectsHttpFilterWithoutConfig) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->Clear(); + filter->set_name("unknown"); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "no filter config specified for filter name unknown")); +} + +// Test that we ignore optional filters without configs. +TEST_P(LdsTest, IgnoresOptionalHttpFilterWithoutConfig) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->Clear(); + filter->set_name("unknown"); + filter->set_is_optional(true); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(balancers_[0]->ads_service()->lds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK unparseable filter configs. 
+TEST_P(LdsTest, RejectsUnparseableHttpFilterType) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->set_name("unknown"); + filter->mutable_typed_config()->PackFrom(listener); + filter->mutable_typed_config()->set_type_url( + "type.googleapis.com/envoy.extensions.filters.http.fault.v3.HTTPFault"); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "filter config for type " + "envoy.extensions.filters.http.fault.v3.HTTPFault failed to parse")); +} + +// Test that we NACK HTTP filters unsupported on client-side. +TEST_P(LdsTest, RejectsHttpFiltersNotSupportedOnClients) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->set_name("grpc.testing.server_only_http_filter"); + filter->mutable_typed_config()->set_type_url( + "grpc.testing.server_only_http_filter"); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("Filter grpc.testing.server_only_http_filter is not " + "supported on clients")); +} + +// Test that we ignore optional HTTP filters unsupported on client-side. 
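The Ignores* cases verify acceptance with the mirror-image pattern: drive traffic through the channel and confirm the last LDS response was ACKed. Condensed into a hypothetical helper, again using only fixture calls visible in this diff:

    // Hypothetical counterpart to the NACK check sketched earlier.
    void ExpectLdsAcked() {
      SetNextResolutionForLbChannelAllBalancers();
      WaitForAllBackends();
      EXPECT_EQ(balancers_[0]->ads_service()->lds_response_state().state,
                AdsServiceImpl::ResponseState::ACKED);
    }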
+TEST_P(LdsTest, IgnoresOptionalHttpFiltersNotSupportedOnClients) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + *http_connection_manager.add_http_filters() = + http_connection_manager.http_filters(0); + auto* filter = http_connection_manager.mutable_http_filters(0); + filter->set_name("grpc.testing.server_only_http_filter"); + filter->mutable_typed_config()->set_type_url( + "grpc.testing.server_only_http_filter"); + filter->set_is_optional(true); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + WaitForBackend(0); + EXPECT_EQ(balancers_[0]->ads_service()->lds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +using LdsV2Test = LdsTest; + +// Tests that we ignore the HTTP filter list in v2. +// TODO(roth): The test framework is not set up to allow us to test +// the server sending v2 resources when the client requests v3, so this +// just tests a pure v2 setup. When we have time, fix this. +TEST_P(LdsV2Test, IgnoresHttpFilters) { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + auto* filter = http_connection_manager.add_http_filters(); + filter->set_name("unknown"); + filter->mutable_typed_config()->PackFrom(Listener()); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetListenerAndRouteConfiguration(0, listener, default_route_config_); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendOk(); +} + +using LdsRdsTest = BasicTest; + +// Tests that LDS client should send an ACK upon correct LDS response (with +// inlined RDS result). +TEST_P(LdsRdsTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); + // Make sure we actually used the RPC service for the right version of xDS. + EXPECT_EQ(balancers_[0]->ads_service()->seen_v2_client(), + GetParam().use_v2()); + EXPECT_NE(balancers_[0]->ads_service()->seen_v3_client(), + GetParam().use_v2()); +} + +// Tests that we go into TRANSIENT_FAILURE if the Listener is removed. +TEST_P(LdsRdsTest, ListenerRemoved) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Unset LDS resource. + balancers_[0]->ads_service()->UnsetResource(kLdsTypeUrl, kServerName); + // Wait for RPCs to start failing. + do { + } while (SendRpc(RpcOptions(), nullptr).ok()); + // Make sure RPCs are still failing. 
+ CheckRpcSendFailure(CheckRpcSendFailureOptions().set_times(1000)); + // Make sure we ACK'ed the update. + EXPECT_EQ(balancers_[0]->ads_service()->lds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client ACKs but fails if matching domain can't be found in +// the LDS response. +TEST_P(LdsRdsTest, NoMatchedDomain) { + RouteConfiguration route_config = default_route_config_; + route_config.mutable_virtual_hosts(0)->clear_domains(); + route_config.mutable_virtual_hosts(0)->add_domains("unmatched_domain"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + // Do a bit of polling, to allow the ACK to get to the ADS server. + channel_->WaitForConnected(grpc_timeout_milliseconds_to_deadline(100)); + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should choose the virtual host with matching domain +// if multiple virtual hosts exist in the LDS response. +TEST_P(LdsRdsTest, ChooseMatchedDomain) { + RouteConfiguration route_config = default_route_config_; + *(route_config.add_virtual_hosts()) = route_config.virtual_hosts(0); + route_config.mutable_virtual_hosts(0)->clear_domains(); + route_config.mutable_virtual_hosts(0)->add_domains("unmatched_domain"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should choose the last route in the virtual host if +// multiple routes exist in the LDS response. +TEST_P(LdsRdsTest, ChooseLastRoute) { + RouteConfiguration route_config = default_route_config_; + *(route_config.mutable_virtual_hosts(0)->add_routes()) = + route_config.virtual_hosts(0).routes(0); + route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_cluster_header(); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should ignore route which has query_parameters. 
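The route-match tests below encode the gRPC path grammar: per these tests, a prefix must be "", "/", or "/service/", and a path must be exactly "/service/method"; any other form makes the route invalid, and once no valid route remains the whole RouteConfiguration is NACKed with "No valid routes specified." A sketch of the accepted forms, built the same way the tests build their configs:

    RouteConfiguration route_config = default_route_config_;
    auto* route = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
    // Accepted matchers (invalid examples from the tests below include
    // "grpc.testing.EchoTest1Service/", "//", and "/Service/"):
    route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
    // or: route->mutable_match()->set_prefix("");   // matches every method
    // or: route->mutable_match()->set_prefix("/");  // matches every method
    // or: route->mutable_match()->set_path(
    //         "/grpc.testing.EchoTest1Service/Echo1");
    route->mutable_route()->set_cluster(kDefaultClusterName);
    SetRouteConfiguration(0, route_config);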
+TEST_P(LdsRdsTest, RouteMatchHasQueryParameters) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_match()->add_query_parameters(); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should send a ACK if route match has a prefix +// that is either empty or a single slash +TEST_P(LdsRdsTest, RouteMatchHasValidPrefixEmptyOrSingleSlash) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix(""); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix("/"); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should ignore route which has a path +// prefix string does not start with "/". +TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixNoLeadingSlash) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("grpc.testing.EchoTest1Service/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has a prefix +// string with more than 2 slashes. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixExtraContent) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/Echo1/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has a prefix +// string "//". 
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixDoubleSlash) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("//"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has path +// but it's empty. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathEmptyPath) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path(""); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has path +// string does not start with "/". +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathNoLeadingSlash) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("grpc.testing.EchoTest1Service/Echo1"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has path +// string that has too many slashes; for example, ends with "/". +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathTooManySlashes) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/Echo1/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has path +// string that has only 1 slash: missing "/" between service and method. 
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathOnlyOneSlash) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service.Echo1"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has path +// string that is missing service. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathMissingService) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("//Echo1"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Tests that LDS client should ignore route which has path +// string that is missing method. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathMissingMethod) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No valid routes specified.")); +} + +// Test that LDS client should reject route which has invalid path regex. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathRegex) { + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->mutable_safe_regex()->set_regex("a[z-a]"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "path matcher: Invalid regex string specified in matcher.")); +} + +// Tests that LDS client should send a NACK if route has an action other than +// RouteAction in the LDS response. 
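The next group of tests covers RouteAction validation, mostly of the weighted_clusters form: every entry needs a non-empty cluster name and an explicit weight, entries with zero weight are simply not used (a route whose entries are all zero-weight is rejected), and total_weight must match the sum of the entry weights. For contrast, a well-formed weighted action as the later XdsRoutingWeightedCluster test builds it (the cluster names here are the ones that test defines):

    RouteConfiguration route_config = default_route_config_;
    auto* route = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
    route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
    auto* weighted = route->mutable_route()->mutable_weighted_clusters();
    auto* cluster1 = weighted->add_clusters();
    cluster1->set_name("new_cluster_1");
    cluster1->mutable_weight()->set_value(75);
    auto* cluster2 = weighted->add_clusters();
    cluster2->set_name("new_cluster_2");
    cluster2->mutable_weight()->set_value(25);
    // total_weight must match the sum of the per-cluster weights.
    weighted->mutable_total_weight()->set_value(75 + 25);
    SetRouteConfiguration(0, route_config);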
+TEST_P(LdsRdsTest, RouteHasNoRouteAction) { + RouteConfiguration route_config = default_route_config_; + route_config.mutable_virtual_hosts(0)->mutable_routes(0)->mutable_redirect(); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("No RouteAction found in route.")); +} + +TEST_P(LdsRdsTest, RouteActionClusterHasEmptyClusterName) { + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_route()->set_cluster(""); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("RouteAction cluster contains empty cluster name.")); +} + +TEST_P(LdsRdsTest, RouteActionWeightedTargetHasIncorrectTotalWeightSet) { + const size_t kWeight75 = 75; + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75 + 1); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "RouteAction weighted_cluster has incorrect total weight")); +} + +TEST_P(LdsRdsTest, RouteActionWeightedClusterHasZeroTotalWeight) { + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(0); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(0); + auto* 
default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "RouteAction weighted_cluster has no valid clusters specified.")); +} + +TEST_P(LdsRdsTest, RouteActionWeightedTargetClusterHasEmptyClusterName) { + const size_t kWeight75 = 75; + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(""); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("RouteAction weighted_cluster cluster " + "contains empty cluster name.")); +} + +TEST_P(LdsRdsTest, RouteActionWeightedTargetClusterHasNoWeight) { + const size_t kWeight75 = 75; + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "RouteAction weighted_cluster cluster missing weight")); +} + +TEST_P(LdsRdsTest, RouteHeaderMatchInvalidRegex) { + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher1 = 
route1->mutable_match()->add_headers(); + header_matcher1->set_name("header1"); + header_matcher1->mutable_safe_regex_match()->set_regex("a[z-a]"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "header matcher: Invalid regex string specified in matcher.")); +} + +TEST_P(LdsRdsTest, RouteHeaderMatchInvalidRange) { + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("header1"); + header_matcher1->mutable_range_match()->set_start(1001); + header_matcher1->mutable_range_match()->set_end(1000); + route1->mutable_route()->set_cluster(kNewCluster1Name); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "header matcher: Invalid range specifier specified: end cannot be " + "smaller than start.")); +} + +// Tests that LDS client should choose the default route (with no matching +// specified) after unable to find a match with previous routes. +TEST_P(LdsRdsTest, XdsRoutingPathMatching) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEcho2Rpcs = 20; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 2)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(2, 3)}, + }); + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. 
+ RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/Echo1"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_path("/grpc.testing.EchoTest2Service/Echo2"); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* route3 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route3->mutable_match()->set_path("/grpc.testing.EchoTest3Service/Echo3"); + route3->mutable_route()->set_cluster(kDefaultClusterName); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 2); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true)); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_wait_for_ready(true)); + CheckRpcSendOk(kNumEcho2Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO2) + .set_rpc_method(METHOD_ECHO2) + .set_wait_for_ready(true)); + // Make sure RPCs all go to the correct backend. + for (size_t i = 0; i < 2; ++i) { + EXPECT_EQ(kNumEchoRpcs / 2, + backends_[i]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service2()->request_count()); + } + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingPathMatchingCaseInsensitive) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(2, 3)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. 
+ Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = default_route_config_; + // First route will not match, since it's case-sensitive. + // Second route will match with same path. + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/GrPc.TeStInG.EcHoTeSt1SErViCe/EcHo1"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_path("/GrPc.TeStInG.EcHoTeSt1SErViCe/EcHo1"); + route2->mutable_match()->mutable_case_sensitive()->set_value(false); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true)); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_wait_for_ready(true)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingPrefixMatching) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEcho2Rpcs = 20; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 2)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(2, 3)}, + }); + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. 
+ Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_prefix("/grpc.testing.EchoTest2Service/"); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 2); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true)); + CheckRpcSendOk( + kNumEcho1Rpcs, + RpcOptions().set_rpc_service(SERVICE_ECHO1).set_wait_for_ready(true)); + CheckRpcSendOk( + kNumEcho2Rpcs, + RpcOptions().set_rpc_service(SERVICE_ECHO2).set_wait_for_ready(true)); + // Make sure RPCs all go to the correct backend. + for (size_t i = 0; i < 2; ++i) { + EXPECT_EQ(kNumEchoRpcs / 2, + backends_[i]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service2()->request_count()); + } + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingPrefixMatchingCaseInsensitive) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(2, 3)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. 
+ Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = default_route_config_; + // First route will not match, since it's case-sensitive. + // Second route will match with same path. + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/GrPc.TeStInG.EcHoTeSt1SErViCe"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_prefix("/GrPc.TeStInG.EcHoTeSt1SErViCe"); + route2->mutable_match()->mutable_case_sensitive()->set_value(false); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true)); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_wait_for_ready(true)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingPathRegexMatching) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEcho2Rpcs = 20; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 2)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(2, 3)}, + }); + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. 
+  Cluster new_cluster1 = default_cluster_;
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = default_cluster_;
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration new_route_config = default_route_config_;
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  // Will match "/grpc.testing.EchoTest1Service/"
+  route1->mutable_match()->mutable_safe_regex()->set_regex(".*1.*");
+  route1->mutable_route()->set_cluster(kNewCluster1Name);
+  auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  // Will match "/grpc.testing.EchoTest2Service/"
+  route2->mutable_match()->mutable_safe_regex()->set_regex(".*2.*");
+  route2->mutable_route()->set_cluster(kNewCluster2Name);
+  auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  WaitForAllBackends(0, 2);
+  CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true));
+  CheckRpcSendOk(
+      kNumEcho1Rpcs,
+      RpcOptions().set_rpc_service(SERVICE_ECHO1).set_wait_for_ready(true));
+  CheckRpcSendOk(
+      kNumEcho2Rpcs,
+      RpcOptions().set_rpc_service(SERVICE_ECHO2).set_wait_for_ready(true));
+  // Make sure RPCs all go to the correct backend.
+  for (size_t i = 0; i < 2; ++i) {
+    EXPECT_EQ(kNumEchoRpcs / 2,
+              backends_[i]->backend_service()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service1()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service2()->request_count());
+  }
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service2()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count());
+}
+
+TEST_P(LdsRdsTest, XdsRoutingWeightedCluster) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const char* kNotUsedClusterName = "not_used_cluster";
+  const size_t kNumEchoRpcs = 10;  // RPCs that will go to a fixed backend.
+  const size_t kWeight75 = 75;
+  const size_t kWeight25 = 25;
+  const double kErrorTolerance = 0.05;
+  const double kWeight75Percent = static_cast<double>(kWeight75) / 100;
+  const double kWeight25Percent = static_cast<double>(kWeight25) / 100;
+  const size_t kNumEcho1Rpcs =
+      ComputeIdealNumRpcs(kWeight75Percent, kErrorTolerance);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
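+  // Backend 0 backs the default cluster; backends 1 and 2 each back one of the
+  // new clusters, so the 75/25 weighted split can be measured per backend.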
+ EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(2, 3)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + auto* weighted_cluster2 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster2->set_name(kNewCluster2Name); + weighted_cluster2->mutable_weight()->set_value(kWeight25); + // Cluster with weight 0 will not be used. + auto* weighted_cluster3 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster3->set_name(kNotUsedClusterName); + weighted_cluster3->mutable_weight()->set_value(0); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75 + kWeight25); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 1); + WaitForAllBackends(1, 3, WaitForBackendOptions(), + RpcOptions().set_rpc_service(SERVICE_ECHO1)); + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + // Make sure RPCs all go to the correct backend. 
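+  // Echo RPCs should all use the default route to backend 0, while Echo1 RPCs
+  // are split between backends 1 and 2 roughly 75/25; the observed fractions
+  // are checked against kErrorTolerance below.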
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  const int weight_75_request_count =
+      backends_[1]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  const int weight_25_request_count =
+      backends_[2]->backend_service1()->request_count();
+  gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+          weight_75_request_count, weight_25_request_count);
+  EXPECT_THAT(static_cast<double>(weight_75_request_count) / kNumEcho1Rpcs,
+              ::testing::DoubleNear(kWeight75Percent, kErrorTolerance));
+  EXPECT_THAT(static_cast<double>(weight_25_request_count) / kNumEcho1Rpcs,
+              ::testing::DoubleNear(kWeight25Percent, kErrorTolerance));
+}
+
+TEST_P(LdsRdsTest, RouteActionWeightedTargetDefaultRoute) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const size_t kWeight75 = 75;
+  const size_t kWeight25 = 25;
+  const double kErrorTolerance = 0.05;
+  const double kWeight75Percent = static_cast<double>(kWeight75) / 100;
+  const double kWeight25Percent = static_cast<double>(kWeight25) / 100;
+  const size_t kNumEchoRpcs =
+      ComputeIdealNumRpcs(kWeight75Percent, kErrorTolerance);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, 1)},
+  });
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  EdsResourceArgs args2({
+      {"locality0", CreateEndpointsForBackends(2, 3)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = default_cluster_;
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = default_cluster_;
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration new_route_config = default_route_config_;
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("");
+  auto* weighted_cluster1 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster1->set_name(kNewCluster1Name);
+  weighted_cluster1->mutable_weight()->set_value(kWeight75);
+  auto* weighted_cluster2 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster2->set_name(kNewCluster2Name);
+  weighted_cluster2->mutable_weight()->set_value(kWeight25);
+  route1->mutable_route()
+      ->mutable_weighted_clusters()
+      ->mutable_total_weight()
+      ->set_value(kWeight75 + kWeight25);
+  SetRouteConfiguration(0, new_route_config);
+  WaitForAllBackends(1, 3);
+  CheckRpcSendOk(kNumEchoRpcs);
+  // Make sure RPCs all go to the correct backend.
+  EXPECT_EQ(0, backends_[0]->backend_service()->request_count());
+  const int weight_75_request_count =
+      backends_[1]->backend_service()->request_count();
+  const int weight_25_request_count =
+      backends_[2]->backend_service()->request_count();
+  gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+          weight_75_request_count, weight_25_request_count);
+  EXPECT_THAT(static_cast<double>(weight_75_request_count) / kNumEchoRpcs,
+              ::testing::DoubleNear(kWeight75Percent, kErrorTolerance));
+  EXPECT_THAT(static_cast<double>(weight_25_request_count) / kNumEchoRpcs,
+              ::testing::DoubleNear(kWeight25Percent, kErrorTolerance));
+}
+
+TEST_P(LdsRdsTest, XdsRoutingWeightedClusterUpdateWeights) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const char* kNewCluster3Name = "new_cluster_3";
+  const char* kNewEdsService3Name = "new_eds_service_name_3";
+  const size_t kNumEchoRpcs = 10;
+  const size_t kWeight75 = 75;
+  const size_t kWeight25 = 25;
+  const size_t kWeight50 = 50;
+  const double kErrorTolerance = 0.05;
+  const double kWeight75Percent = static_cast<double>(kWeight75) / 100;
+  const double kWeight25Percent = static_cast<double>(kWeight25) / 100;
+  const double kWeight50Percent = static_cast<double>(kWeight50) / 100;
+  const size_t kNumEcho1Rpcs7525 =
+      ComputeIdealNumRpcs(kWeight75Percent, kErrorTolerance);
+  const size_t kNumEcho1Rpcs5050 =
+      ComputeIdealNumRpcs(kWeight50Percent, kErrorTolerance);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, 1)},
+  });
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  EdsResourceArgs args2({
+      {"locality0", CreateEndpointsForBackends(2, 3)},
+  });
+  EdsResourceArgs args3({
+      {"locality0", CreateEndpointsForBackends(3, 4)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args2, kNewEdsService2Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args3, kNewEdsService3Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = default_cluster_;
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = default_cluster_;
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  Cluster new_cluster3 = default_cluster_;
+  new_cluster3.set_name(kNewCluster3Name);
+  new_cluster3.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService3Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster3);
+  // Populating Route Configurations.
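+  // The Echo1 route starts with a 75/25 split between cluster 1 and cluster 2;
+  // the weights are later changed to 50/50, and the default route is pointed
+  // at cluster 3 so the test can tell when the client has seen the update.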
+  RouteConfiguration new_route_config = default_route_config_;
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* weighted_cluster1 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster1->set_name(kNewCluster1Name);
+  weighted_cluster1->mutable_weight()->set_value(kWeight75);
+  auto* weighted_cluster2 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster2->set_name(kNewCluster2Name);
+  weighted_cluster2->mutable_weight()->set_value(kWeight25);
+  route1->mutable_route()
+      ->mutable_weighted_clusters()
+      ->mutable_total_weight()
+      ->set_value(kWeight75 + kWeight25);
+  auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  WaitForAllBackends(0, 1);
+  WaitForAllBackends(1, 3, WaitForBackendOptions(),
+                     RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  CheckRpcSendOk(kNumEchoRpcs);
+  CheckRpcSendOk(kNumEcho1Rpcs7525,
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  // Make sure RPCs all go to the correct backend.
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  const int weight_75_request_count =
+      backends_[1]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[1]->backend_service2()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  const int weight_25_request_count =
+      backends_[2]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+          weight_75_request_count, weight_25_request_count);
+  EXPECT_THAT(static_cast<double>(weight_75_request_count) / kNumEcho1Rpcs7525,
+              ::testing::DoubleNear(kWeight75Percent, kErrorTolerance));
+  EXPECT_THAT(static_cast<double>(weight_25_request_count) / kNumEcho1Rpcs7525,
+              ::testing::DoubleNear(kWeight25Percent, kErrorTolerance));
+  // Change Route Configurations: same clusters, different weights.
+  weighted_cluster1->mutable_weight()->set_value(kWeight50);
+  weighted_cluster2->mutable_weight()->set_value(kWeight50);
+  // Change default route to a new cluster to help to identify when new
+  // policies are seen by the client.
+  default_route->mutable_route()->set_cluster(kNewCluster3Name);
+  SetRouteConfiguration(0, new_route_config);
+  ResetBackendCounters();
+  WaitForAllBackends(3, 4);
+  CheckRpcSendOk(kNumEchoRpcs);
+  CheckRpcSendOk(kNumEcho1Rpcs5050,
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  // Make sure RPCs all go to the correct backend.
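+  // After the weight update, echo RPCs should land on backend 3 (the new
+  // default cluster) and Echo1 RPCs should split roughly 50/50 between
+  // backends 1 and 2.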
+  EXPECT_EQ(0, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  const int weight_50_request_count_1 =
+      backends_[1]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  const int weight_50_request_count_2 =
+      backends_[2]->backend_service1()->request_count();
+  EXPECT_EQ(kNumEchoRpcs, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  EXPECT_THAT(
+      static_cast<double>(weight_50_request_count_1) / kNumEcho1Rpcs5050,
+      ::testing::DoubleNear(kWeight50Percent, kErrorTolerance));
+  EXPECT_THAT(
+      static_cast<double>(weight_50_request_count_2) / kNumEcho1Rpcs5050,
+      ::testing::DoubleNear(kWeight50Percent, kErrorTolerance));
+}
+
+TEST_P(LdsRdsTest, XdsRoutingWeightedClusterUpdateClusters) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const char* kNewCluster3Name = "new_cluster_3";
+  const char* kNewEdsService3Name = "new_eds_service_name_3";
+  const size_t kNumEchoRpcs = 10;
+  const size_t kWeight75 = 75;
+  const size_t kWeight25 = 25;
+  const size_t kWeight50 = 50;
+  const double kErrorTolerance = 0.05;
+  const double kWeight75Percent = static_cast<double>(kWeight75) / 100;
+  const double kWeight25Percent = static_cast<double>(kWeight25) / 100;
+  const double kWeight50Percent = static_cast<double>(kWeight50) / 100;
+  const size_t kNumEcho1Rpcs7525 =
+      ComputeIdealNumRpcs(kWeight75Percent, kErrorTolerance);
+  const size_t kNumEcho1Rpcs5050 =
+      ComputeIdealNumRpcs(kWeight50Percent, kErrorTolerance);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, 1)},
+  });
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  EdsResourceArgs args2({
+      {"locality0", CreateEndpointsForBackends(2, 3)},
+  });
+  EdsResourceArgs args3({
+      {"locality0", CreateEndpointsForBackends(3, 4)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args2, kNewEdsService2Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args3, kNewEdsService3Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = default_cluster_;
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = default_cluster_;
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  Cluster new_cluster3 = default_cluster_;
+  new_cluster3.set_name(kNewCluster3Name);
+  new_cluster3.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService3Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster3);
+  // Populating Route Configurations.
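+  // The Echo1 route initially splits 75/25 between the new cluster 1 and the
+  // default cluster; the second target is then swapped to cluster 2 at 50/50
+  // and finally to cluster 3 at 75/25, verifying that cluster changes inside
+  // a weighted action take effect.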
+  RouteConfiguration new_route_config = default_route_config_;
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* weighted_cluster1 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster1->set_name(kNewCluster1Name);
+  weighted_cluster1->mutable_weight()->set_value(kWeight75);
+  auto* weighted_cluster2 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster2->set_name(kDefaultClusterName);
+  weighted_cluster2->mutable_weight()->set_value(kWeight25);
+  route1->mutable_route()
+      ->mutable_weighted_clusters()
+      ->mutable_total_weight()
+      ->set_value(kWeight75 + kWeight25);
+  auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  WaitForBackend(0);
+  WaitForBackend(1, WaitForBackendOptions(),
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  CheckRpcSendOk(kNumEchoRpcs);
+  CheckRpcSendOk(kNumEcho1Rpcs7525,
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  // Make sure RPCs all go to the correct backend.
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  int weight_25_request_count =
+      backends_[0]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  int weight_75_request_count =
+      backends_[1]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+          weight_75_request_count, weight_25_request_count);
+  EXPECT_THAT(static_cast<double>(weight_75_request_count) / kNumEcho1Rpcs7525,
+              ::testing::DoubleNear(kWeight75Percent, kErrorTolerance));
+  EXPECT_THAT(static_cast<double>(weight_25_request_count) / kNumEcho1Rpcs7525,
+              ::testing::DoubleNear(kWeight25Percent, kErrorTolerance));
+  // Change Route Configurations: new set of clusters with different weights.
+  weighted_cluster1->mutable_weight()->set_value(kWeight50);
+  weighted_cluster2->set_name(kNewCluster2Name);
+  weighted_cluster2->mutable_weight()->set_value(kWeight50);
+  SetRouteConfiguration(0, new_route_config);
+  ResetBackendCounters();
+  WaitForBackend(2, WaitForBackendOptions(),
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  CheckRpcSendOk(kNumEchoRpcs);
+  CheckRpcSendOk(kNumEcho1Rpcs5050,
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  // Make sure RPCs all go to the correct backend.
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  const int weight_50_request_count_1 =
+      backends_[1]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  const int weight_50_request_count_2 =
+      backends_[2]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  EXPECT_THAT(
+      static_cast<double>(weight_50_request_count_1) / kNumEcho1Rpcs5050,
+      ::testing::DoubleNear(kWeight50Percent, kErrorTolerance));
+  EXPECT_THAT(
+      static_cast<double>(weight_50_request_count_2) / kNumEcho1Rpcs5050,
+      ::testing::DoubleNear(kWeight50Percent, kErrorTolerance));
+  // Change Route Configurations.
+  weighted_cluster1->mutable_weight()->set_value(kWeight75);
+  weighted_cluster2->set_name(kNewCluster3Name);
+  weighted_cluster2->mutable_weight()->set_value(kWeight25);
+  SetRouteConfiguration(0, new_route_config);
+  ResetBackendCounters();
+  WaitForBackend(3, WaitForBackendOptions(),
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  CheckRpcSendOk(kNumEchoRpcs);
+  CheckRpcSendOk(kNumEcho1Rpcs7525,
+                 RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  // Make sure RPCs all go to the correct backend.
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  weight_75_request_count = backends_[1]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  weight_25_request_count = backends_[3]->backend_service1()->request_count();
+  gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+          weight_75_request_count, weight_25_request_count);
+  EXPECT_THAT(static_cast<double>(weight_75_request_count) / kNumEcho1Rpcs7525,
+              ::testing::DoubleNear(kWeight75Percent, kErrorTolerance));
+  EXPECT_THAT(static_cast<double>(weight_25_request_count) / kNumEcho1Rpcs7525,
+              ::testing::DoubleNear(kWeight25Percent, kErrorTolerance));
+}
+
+TEST_P(LdsRdsTest, XdsRoutingClusterUpdateClusters) {
+  const char* kNewClusterName = "new_cluster";
+  const char* kNewEdsServiceName = "new_eds_service_name";
+  const size_t kNumEchoRpcs = 5;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, 1)},
+  });
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsServiceName));
+  // Populate new CDS resources.
+  Cluster new_cluster = default_cluster_;
+  new_cluster.set_name(kNewClusterName);
+  new_cluster.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsServiceName);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+  // Send Route Configuration.
+  RouteConfiguration new_route_config = default_route_config_;
+  SetRouteConfiguration(0, new_route_config);
+  WaitForAllBackends(0, 1);
+  CheckRpcSendOk(kNumEchoRpcs);
+  // Make sure RPCs all go to the correct backend.
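+  // Before the update, all echo RPCs use the default cluster on backend 0;
+  // after the default route is switched to the new cluster they should go to
+  // backend 1.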
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  // Change Route Configurations: new default cluster.
+  auto* default_route =
+      new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  default_route->mutable_route()->set_cluster(kNewClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  WaitForAllBackends(1, 2);
+  CheckRpcSendOk(kNumEchoRpcs);
+  // Make sure RPCs all go to the correct backend.
+  EXPECT_EQ(kNumEchoRpcs, backends_[1]->backend_service()->request_count());
+}
+
+TEST_P(LdsRdsTest, XdsRoutingClusterUpdateClustersWithPickingDelays) {
+  const char* kNewClusterName = "new_cluster";
+  const char* kNewEdsServiceName = "new_eds_service_name";
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, 1)},
+  });
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsServiceName));
+  // Populate new CDS resources.
+  Cluster new_cluster = default_cluster_;
+  new_cluster.set_name(kNewClusterName);
+  new_cluster.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsServiceName);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+  // Bring down the current backend (0); this delays route picking,
+  // resulting in un-committed RPCs.
+  ShutdownBackend(0);
+  // Send a RouteConfiguration with a default route that points to
+  // backend 0.
+  RouteConfiguration new_route_config = default_route_config_;
+  SetRouteConfiguration(0, new_route_config);
+  // Send exactly one RPC with no deadline and with wait_for_ready=true.
+  // This RPC will not complete until after backend 0 is started.
+  std::thread sending_rpc([this]() {
+    CheckRpcSendOk(1, RpcOptions().set_wait_for_ready(true).set_timeout_ms(0));
+  });
+  // Send a non-wait_for_ready RPC, which should fail; this tells us that the
+  // client has received the update and attempted to connect.
+  const Status status = SendRpc(RpcOptions().set_timeout_ms(0));
+  EXPECT_FALSE(status.ok());
+  // Send an updated RouteConfiguration to use backend 1.
+  auto* default_route =
+      new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  default_route->mutable_route()->set_cluster(kNewClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  // Wait for RPCs to go to the new backend (1); this ensures that the client
+  // has processed the update.
+  WaitForBackend(
+      1, WaitForBackendOptions().set_reset_counters(false).set_allow_failures(
+             true));
+  // Bring up the previous backend (0); this allows the delayed RPC to
+  // finally call on_call_committed upon completion.
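+  // Once backend 0 restarts, the delayed wait_for_ready RPC should complete
+  // there; backend 1 should have served exactly one RPC, since WaitForBackend
+  // above did not reset the counters.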
+ StartBackend(0); + sending_rpc.join(); + // Make sure RPCs go to the correct backend: + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(1, backends_[1]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingApplyXdsTimeout) { + const int64_t kTimeoutMillis = 500; + const int64_t kTimeoutNano = kTimeoutMillis * 1000000; + const int64_t kTimeoutGrpcTimeoutHeaderMaxSecond = 1; + const int64_t kTimeoutMaxStreamDurationSecond = 2; + const int64_t kTimeoutHttpMaxStreamDurationSecond = 3; + const int64_t kTimeoutApplicationSecond = 4; + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const char* kNewCluster3Name = "new_cluster_3"; + const char* kNewEdsService3Name = "new_eds_service_name_3"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({{"locality0", {MakeNonExistantEndpoint()}}}); + EdsResourceArgs args1({{"locality0", {MakeNonExistantEndpoint()}}}); + EdsResourceArgs args2({{"locality0", {MakeNonExistantEndpoint()}}}); + EdsResourceArgs args3({{"locality0", {MakeNonExistantEndpoint()}}}); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args3, kNewEdsService3Name)); + // Populate new CDS resources. + Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + Cluster new_cluster3 = default_cluster_; + new_cluster3.set_name(kNewCluster3Name); + new_cluster3.mutable_eds_cluster_config()->set_service_name( + kNewEdsService3Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster3); + // Construct listener. + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + // Set up HTTP max_stream_duration of 3.5 seconds + auto* duration = + http_connection_manager.mutable_common_http_protocol_options() + ->mutable_max_stream_duration(); + duration->set_seconds(kTimeoutHttpMaxStreamDurationSecond); + duration->set_nanos(kTimeoutNano); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + // Construct route config. 
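+  // Route 1 caps the deadline via grpc_timeout_header_max (~1.5s), route 2 via
+  // the route-level max_stream_duration (~2.5s), and route 3 falls back to the
+  // HTTP connection manager's max_stream_duration (~3.5s). Each RPC is sent
+  // with a 4s application deadline and should hit DEADLINE_EXCEEDED at the
+  // tighter bound.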
+ RouteConfiguration new_route_config = default_route_config_; + // route 1: Set max_stream_duration of 2.5 seconds, Set + // grpc_timeout_header_max of 1.5 + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/Echo1"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* max_stream_duration = + route1->mutable_route()->mutable_max_stream_duration(); + duration = max_stream_duration->mutable_max_stream_duration(); + duration->set_seconds(kTimeoutMaxStreamDurationSecond); + duration->set_nanos(kTimeoutNano); + duration = max_stream_duration->mutable_grpc_timeout_header_max(); + duration->set_seconds(kTimeoutGrpcTimeoutHeaderMaxSecond); + duration->set_nanos(kTimeoutNano); + // route 2: Set max_stream_duration of 2.5 seconds + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_path("/grpc.testing.EchoTest2Service/Echo2"); + route2->mutable_route()->set_cluster(kNewCluster2Name); + max_stream_duration = route2->mutable_route()->mutable_max_stream_duration(); + duration = max_stream_duration->mutable_max_stream_duration(); + duration->set_seconds(kTimeoutMaxStreamDurationSecond); + duration->set_nanos(kTimeoutNano); + // route 3: No timeout values in route configuration + auto* route3 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route3->mutable_match()->set_path("/grpc.testing.EchoTestService/Echo"); + route3->mutable_route()->set_cluster(kNewCluster3Name); + // Set listener and route config. + SetListenerAndRouteConfiguration(0, std::move(listener), new_route_config); + // Test grpc_timeout_header_max of 1.5 seconds applied + grpc_millis t0 = NowFromCycleCounter(); + grpc_millis t1 = + t0 + kTimeoutGrpcTimeoutHeaderMaxSecond * 1000 + kTimeoutMillis; + grpc_millis t2 = t0 + kTimeoutMaxStreamDurationSecond * 1000 + kTimeoutMillis; + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options( + RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_wait_for_ready(true) + .set_timeout_ms(kTimeoutApplicationSecond * 1000)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + t0 = NowFromCycleCounter(); + EXPECT_GE(t0, t1); + EXPECT_LT(t0, t2); + // Test max_stream_duration of 2.5 seconds applied + t0 = NowFromCycleCounter(); + t1 = t0 + kTimeoutMaxStreamDurationSecond * 1000 + kTimeoutMillis; + t2 = t0 + kTimeoutHttpMaxStreamDurationSecond * 1000 + kTimeoutMillis; + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options( + RpcOptions() + .set_rpc_service(SERVICE_ECHO2) + .set_rpc_method(METHOD_ECHO2) + .set_wait_for_ready(true) + .set_timeout_ms(kTimeoutApplicationSecond * 1000)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + t0 = NowFromCycleCounter(); + EXPECT_GE(t0, t1); + EXPECT_LT(t0, t2); + // Test http_stream_duration of 3.5 seconds applied + t0 = NowFromCycleCounter(); + t1 = t0 + kTimeoutHttpMaxStreamDurationSecond * 1000 + kTimeoutMillis; + t2 = t0 + kTimeoutApplicationSecond * 1000 + kTimeoutMillis; + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_wait_for_ready(true).set_timeout_ms( + kTimeoutApplicationSecond * 1000)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + t0 = NowFromCycleCounter(); + EXPECT_GE(t0, t1); + EXPECT_LT(t0, t2); +} + +TEST_P(LdsRdsTest, XdsRoutingApplyApplicationTimeoutWhenXdsTimeoutExplicit0) { + const int64_t kTimeoutNano = 500000000; + const int64_t 
kTimeoutMaxStreamDurationSecond = 2; + const int64_t kTimeoutHttpMaxStreamDurationSecond = 3; + const int64_t kTimeoutApplicationSecond = 4; + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({{"locality0", {MakeNonExistantEndpoint()}}}); + EdsResourceArgs args1({{"locality0", {MakeNonExistantEndpoint()}}}); + EdsResourceArgs args2({{"locality0", {MakeNonExistantEndpoint()}}}); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Construct listener. + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + listener.mutable_api_listener()->mutable_api_listener()->UnpackTo( + &http_connection_manager); + // Set up HTTP max_stream_duration of 3.5 seconds + auto* duration = + http_connection_manager.mutable_common_http_protocol_options() + ->mutable_max_stream_duration(); + duration->set_seconds(kTimeoutHttpMaxStreamDurationSecond); + duration->set_nanos(kTimeoutNano); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + // Construct route config. + RouteConfiguration new_route_config = default_route_config_; + // route 1: Set max_stream_duration of 2.5 seconds, Set + // grpc_timeout_header_max of 0 + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/Echo1"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* max_stream_duration = + route1->mutable_route()->mutable_max_stream_duration(); + duration = max_stream_duration->mutable_max_stream_duration(); + duration->set_seconds(kTimeoutMaxStreamDurationSecond); + duration->set_nanos(kTimeoutNano); + duration = max_stream_duration->mutable_grpc_timeout_header_max(); + duration->set_seconds(0); + duration->set_nanos(0); + // route 2: Set max_stream_duration to 0 + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_path("/grpc.testing.EchoTest2Service/Echo2"); + route2->mutable_route()->set_cluster(kNewCluster2Name); + max_stream_duration = route2->mutable_route()->mutable_max_stream_duration(); + duration = max_stream_duration->mutable_max_stream_duration(); + duration->set_seconds(0); + duration->set_nanos(0); + // Set listener and route config. 
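+  // An explicit zero for grpc_timeout_header_max or max_stream_duration is
+  // expected to disable the xDS-imposed timeout for that route rather than
+  // fall back to the listener value, so both RPCs below should run until the
+  // 4s application deadline expires.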
+  SetListenerAndRouteConfiguration(0, std::move(listener), new_route_config);
+  // Test application timeout is applied for route 1
+  auto t0 = system_clock::now();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(
+              RpcOptions()
+                  .set_rpc_service(SERVICE_ECHO1)
+                  .set_rpc_method(METHOD_ECHO1)
+                  .set_wait_for_ready(true)
+                  .set_timeout_ms(kTimeoutApplicationSecond * 1000))
+          .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED));
+  auto ellapsed_nano_seconds =
+      std::chrono::duration_cast<std::chrono::nanoseconds>(system_clock::now() -
+                                                           t0);
+  EXPECT_GT(ellapsed_nano_seconds.count(),
+            kTimeoutApplicationSecond * 1000000000);
+  // Test application timeout is applied for route 2
+  t0 = system_clock::now();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(
+              RpcOptions()
+                  .set_rpc_service(SERVICE_ECHO2)
+                  .set_rpc_method(METHOD_ECHO2)
+                  .set_wait_for_ready(true)
+                  .set_timeout_ms(kTimeoutApplicationSecond * 1000))
+          .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED));
+  ellapsed_nano_seconds = std::chrono::duration_cast<std::chrono::nanoseconds>(
+      system_clock::now() - t0);
+  EXPECT_GT(ellapsed_nano_seconds.count(),
+            kTimeoutApplicationSecond * 1000000000);
+}
+
+TEST_P(LdsRdsTest, XdsRoutingApplyApplicationTimeoutWhenHttpTimeoutExplicit0) {
+  const int64_t kTimeoutApplicationSecond = 4;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({{"locality0", {MakeNonExistantEndpoint()}}});
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  auto listener = default_listener_;
+  HttpConnectionManager http_connection_manager;
+  listener.mutable_api_listener()->mutable_api_listener()->UnpackTo(
+      &http_connection_manager);
+  // Set up HTTP max_stream_duration to be explicit 0
+  auto* duration =
+      http_connection_manager.mutable_common_http_protocol_options()
+          ->mutable_max_stream_duration();
+  duration->set_seconds(0);
+  duration->set_nanos(0);
+  listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
+      http_connection_manager);
+  // Set listener and route config.
+  SetListenerAndRouteConfiguration(0, std::move(listener),
+                                   default_route_config_);
+  // Test application timeout is applied for route 1
+  auto t0 = system_clock::now();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(RpcOptions().set_wait_for_ready(true).set_timeout_ms(
+              kTimeoutApplicationSecond * 1000))
+          .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED));
+  auto ellapsed_nano_seconds =
+      std::chrono::duration_cast<std::chrono::nanoseconds>(system_clock::now() -
+                                                           t0);
+  EXPECT_GT(ellapsed_nano_seconds.count(),
+            kTimeoutApplicationSecond * 1000000000);
+}
+
+// Test to ensure application-specified deadline won't be affected when
+// the xDS config does not specify a timeout.
+TEST_P(LdsRdsTest, XdsRoutingWithOnlyApplicationTimeout) {
+  const int64_t kTimeoutApplicationSecond = 4;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({{"locality0", {MakeNonExistantEndpoint()}}});
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  auto t0 = system_clock::now();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(RpcOptions().set_wait_for_ready(true).set_timeout_ms(
+              kTimeoutApplicationSecond * 1000))
+          .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED));
+  auto ellapsed_nano_seconds =
+      std::chrono::duration_cast<std::chrono::nanoseconds>(system_clock::now() -
+                                                           t0);
+  EXPECT_GT(ellapsed_nano_seconds.count(),
+            kTimeoutApplicationSecond * 1000000000);
+}
+
+TEST_P(LdsRdsTest, XdsRetryPolicyNumRetries) {
+  const size_t kNumRetries = 3;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, 1)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  // Construct route config to set retry policy.
+  RouteConfiguration new_route_config = default_route_config_;
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  auto* retry_policy = route1->mutable_route()->mutable_retry_policy();
+  retry_policy->set_retry_on(
+      "5xx,cancelled,deadline-exceeded,internal,resource-exhausted,"
+      "unavailable");
+  retry_policy->mutable_num_retries()->set_value(kNumRetries);
+  SetRouteConfiguration(0, new_route_config);
+  // Ensure we retried the correct number of times on all supported status.
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(
+              RpcOptions().set_server_expected_error(StatusCode::CANCELLED))
+          .set_expected_error_code(StatusCode::CANCELLED));
+  EXPECT_EQ(kNumRetries + 1, backends_[0]->backend_service()->request_count());
+  ResetBackendCounters();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(RpcOptions().set_server_expected_error(
+              StatusCode::DEADLINE_EXCEEDED))
+          .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED));
+  EXPECT_EQ(kNumRetries + 1, backends_[0]->backend_service()->request_count());
+  ResetBackendCounters();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(
+              RpcOptions().set_server_expected_error(StatusCode::INTERNAL))
+          .set_expected_error_code(StatusCode::INTERNAL));
+  EXPECT_EQ(kNumRetries + 1, backends_[0]->backend_service()->request_count());
+  ResetBackendCounters();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(RpcOptions().set_server_expected_error(
+              StatusCode::RESOURCE_EXHAUSTED))
+          .set_expected_error_code(StatusCode::RESOURCE_EXHAUSTED));
+  EXPECT_EQ(kNumRetries + 1, backends_[0]->backend_service()->request_count());
+  ResetBackendCounters();
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(
+              RpcOptions().set_server_expected_error(StatusCode::UNAVAILABLE))
+          .set_expected_error_code(StatusCode::UNAVAILABLE));
+  EXPECT_EQ(kNumRetries + 1, backends_[0]->backend_service()->request_count());
+  ResetBackendCounters();
+  // Ensure we don't retry on an unsupported status.
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_rpc_options(RpcOptions().set_server_expected_error(
+              StatusCode::UNAUTHENTICATED))
+          .set_expected_error_code(StatusCode::UNAUTHENTICATED));
+  EXPECT_EQ(1, backends_[0]->backend_service()->request_count());
+}
+
+TEST_P(LdsRdsTest, XdsRetryPolicyAtVirtualHostLevel) {
+  const size_t kNumRetries = 3;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
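+  // The retry policy below is attached to the virtual host rather than to an
+  // individual route, so it should apply to routes that do not define their
+  // own policy.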
+ EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Construct route config to set retry policy. + RouteConfiguration new_route_config = default_route_config_; + auto* retry_policy = + new_route_config.mutable_virtual_hosts(0)->mutable_retry_policy(); + retry_policy->set_retry_on( + "cancelled,deadline-exceeded,internal,resource-exhausted,unavailable"); + retry_policy->mutable_num_retries()->set_value(kNumRetries); + SetRouteConfiguration(0, new_route_config); + // Ensure we retried the correct number of times on a supported status. + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_server_expected_error( + StatusCode::DEADLINE_EXCEEDED)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + EXPECT_EQ(kNumRetries + 1, backends_[0]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRetryPolicyLongBackOff) { + // Set num retries to 3, but due to longer back off, we expect only 1 retry + // will take place. + const size_t kNumRetries = 3; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Construct route config to set retry policy. + RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* retry_policy = route1->mutable_route()->mutable_retry_policy(); + retry_policy->set_retry_on( + "5xx,cancelled,deadline-exceeded,internal,resource-exhausted," + "unavailable"); + retry_policy->mutable_num_retries()->set_value(kNumRetries); + auto base_interval = + retry_policy->mutable_retry_back_off()->mutable_base_interval(); + // Set backoff to 1 second, 1/2 of rpc timeout of 2 second. + base_interval->set_seconds(1 * grpc_test_slowdown_factor()); + base_interval->set_nanos(0); + SetRouteConfiguration(0, new_route_config); + // No need to set max interval and just let it be the default of 10x of base. + // We expect 1 retry before the RPC times out with DEADLINE_EXCEEDED. + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options( + RpcOptions().set_timeout_ms(2500).set_server_expected_error( + StatusCode::CANCELLED)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + EXPECT_EQ(1 + 1, backends_[0]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRetryPolicyMaxBackOff) { + // Set num retries to 3, but due to longer back off, we expect only 2 retry + // will take place, while the 2nd one will obey the max backoff. + const size_t kNumRetries = 3; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Construct route config to set retry policy. 
+ RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* retry_policy = route1->mutable_route()->mutable_retry_policy(); + retry_policy->set_retry_on( + "5xx,cancelled,deadline-exceeded,internal,resource-exhausted," + "unavailable"); + retry_policy->mutable_num_retries()->set_value(kNumRetries); + auto base_interval = + retry_policy->mutable_retry_back_off()->mutable_base_interval(); + // Set backoff to 1 second. + base_interval->set_seconds(1 * grpc_test_slowdown_factor()); + base_interval->set_nanos(0); + auto max_interval = + retry_policy->mutable_retry_back_off()->mutable_max_interval(); + // Set max interval to be the same as base, so 2 retries will take 2 seconds + // and both retries will take place before the 2.5 seconds rpc timeout. + // Tested to ensure if max is not set, this test will be the same as + // XdsRetryPolicyLongBackOff and we will only see 1 retry in that case. + max_interval->set_seconds(1 * grpc_test_slowdown_factor()); + max_interval->set_nanos(0); + SetRouteConfiguration(0, new_route_config); + // We expect 2 retry before the RPC times out with DEADLINE_EXCEEDED. + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options( + RpcOptions().set_timeout_ms(2500).set_server_expected_error( + StatusCode::CANCELLED)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + EXPECT_EQ(2 + 1, backends_[0]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRetryPolicyUnsupportedStatusCode) { + const size_t kNumRetries = 3; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Construct route config to set retry policy. + RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* retry_policy = route1->mutable_route()->mutable_retry_policy(); + retry_policy->set_retry_on("5xx"); + retry_policy->mutable_num_retries()->set_value(kNumRetries); + SetRouteConfiguration(0, new_route_config); + // We expect no retry. + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_server_expected_error( + StatusCode::DEADLINE_EXCEEDED)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, + XdsRetryPolicyUnsupportedStatusCodeWithVirtualHostLevelRetry) { + const size_t kNumRetries = 3; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Construct route config to set retry policy with no supported retry_on + // statuses. + RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* retry_policy = route1->mutable_route()->mutable_retry_policy(); + retry_policy->set_retry_on("5xx"); + retry_policy->mutable_num_retries()->set_value(kNumRetries); + // Construct a virtual host level retry policy with supported statuses. 
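+  // The route-level policy (retry_on "5xx" only) should take precedence over
+  // the virtual-host policy, so no retry is expected even though the
+  // virtual-host policy lists deadline-exceeded.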
+ auto* virtual_host_retry_policy = + new_route_config.mutable_virtual_hosts(0)->mutable_retry_policy(); + virtual_host_retry_policy->set_retry_on( + "cancelled,deadline-exceeded,internal,resource-exhausted,unavailable"); + virtual_host_retry_policy->mutable_num_retries()->set_value(kNumRetries); + SetRouteConfiguration(0, new_route_config); + // We expect no retry. + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_server_expected_error( + StatusCode::DEADLINE_EXCEEDED)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRetryPolicyInvalidNumRetriesZero) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Construct route config to set retry policy. + RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* retry_policy = route1->mutable_route()->mutable_retry_policy(); + retry_policy->set_retry_on("deadline-exceeded"); + // Setting num_retries to zero is not valid. + retry_policy->mutable_num_retries()->set_value(0); + SetRouteConfiguration(0, new_route_config); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "RouteAction RetryPolicy num_retries set to invalid value 0.")); +} + +TEST_P(LdsRdsTest, XdsRetryPolicyRetryBackOffMissingBaseInterval) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Construct route config to set retry policy. + RouteConfiguration new_route_config = default_route_config_; + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* retry_policy = route1->mutable_route()->mutable_retry_policy(); + retry_policy->set_retry_on("deadline-exceeded"); + retry_policy->mutable_num_retries()->set_value(1); + // RetryBackoff is there but base interval is missing. + auto max_interval = + retry_policy->mutable_retry_back_off()->mutable_max_interval(); + max_interval->set_seconds(0); + max_interval->set_nanos(250000000); + SetRouteConfiguration(0, new_route_config); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "RouteAction RetryPolicy RetryBackoff missing base interval.")); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatching) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + const size_t kNumEcho1Rpcs = 100; + const size_t kNumEchoRpcs = 5; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. 
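+  // The first route combines exact, regex, range, present/absent, prefix, and
+  // inverted-suffix header matchers; an RPC must satisfy all of them to be
+  // routed to the new cluster.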
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, 1)},
+  });
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsServiceName));
+  // Populate new CDS resources.
+  Cluster new_cluster = default_cluster_;
+  new_cluster.set_name(kNewClusterName);
+  new_cluster.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsServiceName);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration route_config = default_route_config_;
+  auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* header_matcher1 = route1->mutable_match()->add_headers();
+  header_matcher1->set_name("header1");
+  header_matcher1->set_exact_match("POST,PUT,GET");
+  auto* header_matcher2 = route1->mutable_match()->add_headers();
+  header_matcher2->set_name("header2");
+  header_matcher2->mutable_safe_regex_match()->set_regex("[a-z]*");
+  auto* header_matcher3 = route1->mutable_match()->add_headers();
+  header_matcher3->set_name("header3");
+  header_matcher3->mutable_range_match()->set_start(1);
+  header_matcher3->mutable_range_match()->set_end(1000);
+  auto* header_matcher4 = route1->mutable_match()->add_headers();
+  header_matcher4->set_name("header4");
+  header_matcher4->set_present_match(false);
+  auto* header_matcher5 = route1->mutable_match()->add_headers();
+  header_matcher5->set_name("header5");
+  header_matcher5->set_present_match(true);
+  auto* header_matcher6 = route1->mutable_match()->add_headers();
+  header_matcher6->set_name("header6");
+  header_matcher6->set_prefix_match("/grpc");
+  auto* header_matcher7 = route1->mutable_match()->add_headers();
+  header_matcher7->set_name("header7");
+  header_matcher7->set_suffix_match(".cc");
+  header_matcher7->set_invert_match(true);
+  route1->mutable_route()->set_cluster(kNewClusterName);
+  auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  std::vector<std::pair<std::string, std::string>> metadata = {
+      {"header1", "POST"},
+      {"header2", "blah"},
+      {"header3", "1"},
+      {"header5", "anything"},
+      {"header6", "/grpc.testing.EchoTest1Service/"},
+      {"header1", "PUT"},
+      {"header7", "grpc.java"},
+      {"header1", "GET"},
+  };
+  const auto header_match_rpc_options = RpcOptions()
+                                            .set_rpc_service(SERVICE_ECHO1)
+                                            .set_rpc_method(METHOD_ECHO1)
+                                            .set_metadata(std::move(metadata));
+  // Make sure all backends are up.
+  WaitForBackend(0);
+  WaitForBackend(1, WaitForBackendOptions(), header_match_rpc_options);
+  // Send RPCs.
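+  // The Echo1 RPCs carry metadata satisfying every matcher above (the three
+  // "header1" values are joined with commas, matching the exact matcher on
+  // "POST,PUT,GET"), so they should reach backend 1; plain echo RPCs take the
+  // default route to backend 0.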
+ CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, header_match_rpc_options); + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service2()->request_count()); + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingSpecialHeaderContentType) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + const size_t kNumEchoRpcs = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = default_cluster_; + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix(""); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("content-type"); + header_matcher1->set_exact_match("notapplication/grpc"); + route1->mutable_route()->set_cluster(kNewClusterName); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + auto* header_matcher2 = default_route->mutable_match()->add_headers(); + header_matcher2->set_name("content-type"); + header_matcher2->set_exact_match("application/grpc"); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Make sure the backend is up. + WaitForAllBackends(0, 1); + // Send RPCs. + CheckRpcSendOk(kNumEchoRpcs); + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingSpecialCasesToIgnore) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const size_t kNumEchoRpcs = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + // Populate new CDS resources. 
+ Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix(""); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("grpc-foo-bin"); + header_matcher1->set_present_match(true); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Send headers which will mismatch each route + std::vector> metadata = { + {"grpc-foo-bin", "grpc-foo-bin"}, + }; + WaitForAllBackends(0, 1); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_metadata(metadata)); + // Verify that only the default backend got RPCs since all previous routes + // were mismatched. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingRuntimeFractionMatching) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + const double kErrorTolerance = 0.05; + const size_t kRouteMatchNumerator = 25; + const double kRouteMatchPercent = + static_cast(kRouteMatchNumerator) / 100; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kRouteMatchPercent, kErrorTolerance); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = default_cluster_; + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Populating Route Configurations for LDS. 
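+  // route1 below matches on a runtime fraction with numerator 25 over the
+  // implicit HUNDRED denominator, so roughly 25% of RPCs should go to the new
+  // cluster (backend 1) and the rest fall through to the default route
+  // (backend 0).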
+ RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match() + ->mutable_runtime_fraction() + ->mutable_default_value() + ->set_numerator(kRouteMatchNumerator); + route1->mutable_route()->set_cluster(kNewClusterName); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + WaitForAllBackends(0, 2); + CheckRpcSendOk(kNumRpcs); + const int default_backend_count = + backends_[0]->backend_service()->request_count(); + const int matched_backend_count = + backends_[1]->backend_service()->request_count(); + EXPECT_THAT(static_cast(default_backend_count) / kNumRpcs, + ::testing::DoubleNear(1 - kRouteMatchPercent, kErrorTolerance)); + EXPECT_THAT(static_cast(matched_backend_count) / kNumRpcs, + ::testing::DoubleNear(kRouteMatchPercent, kErrorTolerance)); + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingUnmatchCases) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const char* kNewCluster3Name = "new_cluster_3"; + const char* kNewEdsService3Name = "new_eds_service_name_3"; + const size_t kNumEcho1Rpcs = 100; + const size_t kNumEchoRpcs = 5; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + EdsResourceArgs args2({ + {"locality0", CreateEndpointsForBackends(2, 3)}, + }); + EdsResourceArgs args3({ + {"locality0", CreateEndpointsForBackends(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args2, kNewEdsService2Name)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args3, kNewEdsService3Name)); + // Populate new CDS resources. + Cluster new_cluster1 = default_cluster_; + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = default_cluster_; + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + Cluster new_cluster3 = default_cluster_; + new_cluster3.set_name(kNewCluster3Name); + new_cluster3.mutable_eds_cluster_config()->set_service_name( + kNewEdsService3Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster3); + // Populating Route Configurations for LDS. 
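+  // Each of the three routes below requires a header condition that the
+  // metadata sent later deliberately fails: "header1" is sent twice, so its
+  // comma-joined value "POST,GET" does not exactly match "POST"; "header2" is
+  // 1000, which falls outside the half-open range [1, 1000); and "header3"
+  // ("123") does not fully match the regex "[a-z]*". All RPCs should
+  // therefore use the default route.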
+ RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("header1"); + header_matcher1->set_exact_match("POST"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto route2 = route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher2 = route2->mutable_match()->add_headers(); + header_matcher2->set_name("header2"); + header_matcher2->mutable_range_match()->set_start(1); + header_matcher2->mutable_range_match()->set_end(1000); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto route3 = route_config.mutable_virtual_hosts(0)->add_routes(); + route3->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher3 = route3->mutable_match()->add_headers(); + header_matcher3->set_name("header3"); + header_matcher3->mutable_safe_regex_match()->set_regex("[a-z]*"); + route3->mutable_route()->set_cluster(kNewCluster3Name); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Send headers which will mismatch each route + std::vector> metadata = { + {"header1", "POST"}, + {"header2", "1000"}, + {"header3", "123"}, + {"header1", "GET"}, + }; + WaitForAllBackends(0, 1); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_metadata(metadata)); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_metadata(metadata)); + // Verify that only the default backend got RPCs since all previous routes + // were mismatched. + for (size_t i = 1; i < 4; ++i) { + EXPECT_EQ(0, backends_[i]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service2()->request_count()); + } + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service2()->request_count()); + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingChangeRoutesWithoutChangingClusters) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + EdsResourceArgs args1({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = default_cluster_; + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Populating Route Configurations for LDS. 
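+  // Initially EchoTest1Service RPCs are routed to the new cluster (backend 1)
+  // and everything else to the default cluster (backend 0); the test later
+  // flips route1 to EchoTest2Service without changing the clusters.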
+ RouteConfiguration route_config = default_route_config_; + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_route()->set_cluster(kNewClusterName); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Make sure all backends are up and that requests for each RPC + // service go to the right backends. + WaitForBackend(0, WaitForBackendOptions().set_reset_counters(false)); + WaitForBackend(1, WaitForBackendOptions().set_reset_counters(false), + RpcOptions().set_rpc_service(SERVICE_ECHO1)); + WaitForBackend(0, WaitForBackendOptions().set_reset_counters(false), + RpcOptions().set_rpc_service(SERVICE_ECHO2)); + // Requests for services Echo and Echo2 should have gone to backend 0. + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(1, backends_[0]->backend_service2()->request_count()); + // Requests for service Echo1 should have gone to backend 1. + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(1, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service2()->request_count()); + // Now send an update that changes the first route to match a + // different RPC service, and wait for the client to make the change. + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest2Service/"); + SetRouteConfiguration(0, route_config); + WaitForBackend(1, WaitForBackendOptions(), + RpcOptions().set_rpc_service(SERVICE_ECHO2)); + // Now repeat the earlier test, making sure all traffic goes to the + // right place. + WaitForBackend(0, WaitForBackendOptions().set_reset_counters(false)); + WaitForBackend(0, WaitForBackendOptions().set_reset_counters(false), + RpcOptions().set_rpc_service(SERVICE_ECHO1)); + WaitForBackend(1, WaitForBackendOptions().set_reset_counters(false), + RpcOptions().set_rpc_service(SERVICE_ECHO2)); + // Requests for services Echo and Echo1 should have gone to backend 0. + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(1, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service2()->request_count()); + // Requests for service Echo2 should have gone to backend 1. + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(1, backends_[1]->backend_service2()->request_count()); +} + +// Test that we NACK unknown filter types in VirtualHost. +TEST_P(LdsRdsTest, RejectsUnknownHttpFilterTypeInVirtualHost) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. 
+ RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = + route_config.mutable_virtual_hosts(0)->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom(Listener()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("no filter registered for config type " + "envoy.config.listener.v3.Listener")); +} + +// Test that we ignore optional unknown filter types in VirtualHost. +TEST_P(LdsRdsTest, IgnoresOptionalUnknownHttpFilterTypeInVirtualHost) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = + route_config.mutable_virtual_hosts(0)->mutable_typed_per_filter_config(); + ::envoy::config::route::v3::FilterConfig filter_config; + filter_config.mutable_config()->PackFrom(Listener()); + filter_config.set_is_optional(true); + (*per_filter_config)["unknown"].PackFrom(filter_config); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK filters without configs in VirtualHost. +TEST_P(LdsRdsTest, RejectsHttpFilterWithoutConfigInVirtualHost) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = + route_config.mutable_virtual_hosts(0)->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"]; + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "no filter config specified for filter name unknown")); +} + +// Test that we NACK filters without configs in FilterConfig in VirtualHost. +TEST_P(LdsRdsTest, RejectsHttpFilterWithoutConfigInFilterConfigInVirtualHost) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. 
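+  // The FilterConfig wrapper below is packed with its config field unset and
+  // is_optional left false, so the client is expected to NACK it; the
+  // "optional" variants of these tests set is_optional to true instead.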
+ RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = + route_config.mutable_virtual_hosts(0)->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom( + ::envoy::config::route::v3::FilterConfig()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "no filter config specified for filter name unknown")); +} + +// Test that we ignore optional filters without configs in VirtualHost. +TEST_P(LdsRdsTest, IgnoresOptionalHttpFilterWithoutConfigInVirtualHost) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = + route_config.mutable_virtual_hosts(0)->mutable_typed_per_filter_config(); + ::envoy::config::route::v3::FilterConfig filter_config; + filter_config.set_is_optional(true); + (*per_filter_config)["unknown"].PackFrom(filter_config); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK unparseable filter types in VirtualHost. +TEST_P(LdsRdsTest, RejectsUnparseableHttpFilterTypeInVirtualHost) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = + route_config.mutable_virtual_hosts(0)->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom( + envoy::extensions::filters::http::router::v3::Router()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("router filter does not support config override")); +} + +// Test that we NACK unknown filter types in Route. +TEST_P(LdsRdsTest, RejectsUnknownHttpFilterTypeInRoute) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. 
+ RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom(Listener()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("no filter registered for config type " + "envoy.config.listener.v3.Listener")); +} + +// Test that we ignore optional unknown filter types in Route. +TEST_P(LdsRdsTest, IgnoresOptionalUnknownHttpFilterTypeInRoute) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_typed_per_filter_config(); + ::envoy::config::route::v3::FilterConfig filter_config; + filter_config.mutable_config()->PackFrom(Listener()); + filter_config.set_is_optional(true); + (*per_filter_config)["unknown"].PackFrom(filter_config); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK filters without configs in Route. +TEST_P(LdsRdsTest, RejectsHttpFilterWithoutConfigInRoute) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"]; + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "no filter config specified for filter name unknown")); +} + +// Test that we NACK filters without configs in FilterConfig in Route. +TEST_P(LdsRdsTest, RejectsHttpFilterWithoutConfigInFilterConfigInRoute) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. 
+ RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom( + ::envoy::config::route::v3::FilterConfig()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "no filter config specified for filter name unknown")); +} + +// Test that we ignore optional filters without configs in Route. +TEST_P(LdsRdsTest, IgnoresOptionalHttpFilterWithoutConfigInRoute) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_typed_per_filter_config(); + ::envoy::config::route::v3::FilterConfig filter_config; + filter_config.set_is_optional(true); + (*per_filter_config)["unknown"].PackFrom(filter_config); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK unparseable filter types in Route. +TEST_P(LdsRdsTest, RejectsUnparseableHttpFilterTypeInRoute) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* per_filter_config = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom( + envoy::extensions::filters::http::router::v3::Router()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("router filter does not support config override")); +} + +// Test that we NACK unknown filter types in ClusterWeight. +TEST_P(LdsRdsTest, RejectsUnknownHttpFilterTypeInClusterWeight) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. 
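+  // The ClusterWeight variants attach typed_per_filter_config to a
+  // weighted_clusters entry (a single entry with weight 100 pointing at the
+  // default cluster) rather than to the VirtualHost or Route.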
+ RouteConfiguration route_config = default_route_config_; + auto* cluster_weight = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_weighted_clusters() + ->add_clusters(); + cluster_weight->set_name(kDefaultClusterName); + cluster_weight->mutable_weight()->set_value(100); + auto* per_filter_config = cluster_weight->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom(Listener()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("no filter registered for config type " + "envoy.config.listener.v3.Listener")); +} + +// Test that we ignore optional unknown filter types in ClusterWeight. +TEST_P(LdsRdsTest, IgnoresOptionalUnknownHttpFilterTypeInClusterWeight) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* cluster_weight = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_weighted_clusters() + ->add_clusters(); + cluster_weight->set_name(kDefaultClusterName); + cluster_weight->mutable_weight()->set_value(100); + auto* per_filter_config = cluster_weight->mutable_typed_per_filter_config(); + ::envoy::config::route::v3::FilterConfig filter_config; + filter_config.mutable_config()->PackFrom(Listener()); + filter_config.set_is_optional(true); + (*per_filter_config)["unknown"].PackFrom(filter_config); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK filters without configs in ClusterWeight. +TEST_P(LdsRdsTest, RejectsHttpFilterWithoutConfigInClusterWeight) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* cluster_weight = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_weighted_clusters() + ->add_clusters(); + cluster_weight->set_name(kDefaultClusterName); + cluster_weight->mutable_weight()->set_value(100); + auto* per_filter_config = cluster_weight->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"]; + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "no filter config specified for filter name unknown")); +} + +// Test that we NACK filters without configs in FilterConfig in ClusterWeight. 
+TEST_P(LdsRdsTest, + RejectsHttpFilterWithoutConfigInFilterConfigInClusterWeight) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* cluster_weight = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_weighted_clusters() + ->add_clusters(); + cluster_weight->set_name(kDefaultClusterName); + cluster_weight->mutable_weight()->set_value(100); + auto* per_filter_config = cluster_weight->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom( + ::envoy::config::route::v3::FilterConfig()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "no filter config specified for filter name unknown")); +} + +// Test that we ignore optional filters without configs in ClusterWeight. +TEST_P(LdsRdsTest, IgnoresOptionalHttpFilterWithoutConfigInClusterWeight) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. + RouteConfiguration route_config = default_route_config_; + auto* cluster_weight = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_weighted_clusters() + ->add_clusters(); + cluster_weight->set_name(kDefaultClusterName); + cluster_weight->mutable_weight()->set_value(100); + auto* per_filter_config = cluster_weight->mutable_typed_per_filter_config(); + ::envoy::config::route::v3::FilterConfig filter_config; + filter_config.set_is_optional(true); + (*per_filter_config)["unknown"].PackFrom(filter_config); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + WaitForAllBackends(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Test that we NACK unparseable filter types in ClusterWeight. +TEST_P(LdsRdsTest, RejectsUnparseableHttpFilterTypeInClusterWeight) { + if (GetParam().use_v2()) return; // Filters supported in v3 only. 
+ RouteConfiguration route_config = default_route_config_; + auto* cluster_weight = route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_weighted_clusters() + ->add_clusters(); + cluster_weight->set_name(kDefaultClusterName); + cluster_weight->mutable_weight()->set_value(100); + auto* per_filter_config = cluster_weight->mutable_typed_per_filter_config(); + (*per_filter_config)["unknown"].PackFrom( + envoy::extensions::filters::http::router::v3::Router()); + SetListenerAndRouteConfiguration(0, default_listener_, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForRdsNack()) << "timed out waiting for NACK"; + const auto response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("router filter does not support config override")); +} + +using CdsTest = BasicTest; + +// Tests that CDS client should send an ACK upon correct CDS response. +TEST_P(CdsTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(CdsTest, LogicalDNSClusterType) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + auto* address = cluster.mutable_load_assignment() + ->add_endpoints() + ->add_lb_endpoints() + ->mutable_endpoint() + ->mutable_address() + ->mutable_socket_address(); + address->set_address(kServerName); + address->set_port_value(443); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Set Logical DNS result + { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result; + result.addresses = CreateAddressListFromPortList(GetBackendPorts(1, 2)); + logical_dns_cluster_resolver_response_generator_->SetResponse( + std::move(result)); + } + // Wait for traffic to go to backend 1. 
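+  // The fake DNS result injected above contains only backend 1's port, so the
+  // LOGICAL_DNS cluster resolves to backend 1.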
+ WaitForBackend(1); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeMissingLoadAssignment) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "load_assignment not present for LOGICAL_DNS cluster")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeMissingLocalities) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("load_assignment for LOGICAL_DNS cluster must have " + "exactly one locality, found 0")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeMultipleLocalities) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + auto* load_assignment = cluster.mutable_load_assignment(); + load_assignment->add_endpoints(); + load_assignment->add_endpoints(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("load_assignment for LOGICAL_DNS cluster must have " + "exactly one locality, found 2")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeMissingEndpoints) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment()->add_endpoints(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + 
::testing::HasSubstr( + "locality for LOGICAL_DNS cluster must have exactly one " + "endpoint, found 0")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeMultipleEndpoints) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + auto* locality = cluster.mutable_load_assignment()->add_endpoints(); + locality->add_lb_endpoints(); + locality->add_lb_endpoints(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "locality for LOGICAL_DNS cluster must have exactly one " + "endpoint, found 2")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeEmptyEndpoint) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment()->add_endpoints()->add_lb_endpoints(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("LbEndpoint endpoint field not set")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeEndpointMissingAddress) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment() + ->add_endpoints() + ->add_lb_endpoints() + ->mutable_endpoint(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("Endpoint address field not set")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeAddressMissingSocketAddress) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment() + ->add_endpoints() + ->add_lb_endpoints() + ->mutable_endpoint() + ->mutable_address(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for 
NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("Address socket_address field not set")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeSocketAddressHasResolverName) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment() + ->add_endpoints() + ->add_lb_endpoints() + ->mutable_endpoint() + ->mutable_address() + ->mutable_socket_address() + ->set_resolver_name("foo"); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("LOGICAL_DNS clusters must NOT have a " + "custom resolver name set")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeSocketAddressMissingAddress) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment() + ->add_endpoints() + ->add_lb_endpoints() + ->mutable_endpoint() + ->mutable_address() + ->mutable_socket_address(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("SocketAddress address field not set")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, LogicalDNSClusterTypeSocketAddressMissingPort) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create Logical DNS Cluster + auto cluster = default_cluster_; + cluster.set_type(Cluster::LOGICAL_DNS); + cluster.mutable_load_assignment() + ->add_endpoints() + ->add_lb_endpoints() + ->mutable_endpoint() + ->mutable_address() + ->mutable_socket_address() + ->set_address(kServerName); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("SocketAddress port_value field not set")); + gpr_unsetenv( + "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER"); +} + +TEST_P(CdsTest, AggregateClusterType) { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER", + "true"); + const char* kNewCluster1Name = "new_cluster_1"; + 
const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  EdsResourceArgs args2({
+      {"locality0", CreateEndpointsForBackends(2, 3)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = default_cluster_;
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = default_cluster_;
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Create Aggregate Cluster
+  auto cluster = default_cluster_;
+  CustomClusterType* custom_cluster = cluster.mutable_cluster_type();
+  custom_cluster->set_name("envoy.clusters.aggregate");
+  ClusterConfig cluster_config;
+  cluster_config.add_clusters(kNewCluster1Name);
+  cluster_config.add_clusters(kNewCluster2Name);
+  custom_cluster->mutable_typed_config()->PackFrom(cluster_config);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  // Wait for traffic to go to backend 1.
+  WaitForBackend(1);
+  // Shutdown backend 1 and wait for all traffic to go to backend 2.
+  ShutdownBackend(1);
+  WaitForBackend(2, WaitForBackendOptions().set_allow_failures(true));
+  EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state,
+            AdsServiceImpl::ResponseState::ACKED);
+  // Bring backend 1 back and ensure all traffic goes back to it.
+  StartBackend(1);
+  WaitForBackend(1);
+  gpr_unsetenv(
+      "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER");
+}
+
+TEST_P(CdsTest, AggregateClusterEdsToLogicalDns) {
+  gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER",
+             "true");
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kLogicalDNSClusterName = "logical_dns_cluster";
+  // Populate new EDS resources.
+  EdsResourceArgs args1({
+      {"locality0", CreateEndpointsForBackends(1, 2)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args1, kNewEdsService1Name));
+  // Populate new CDS resources.
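+  // The aggregate cluster built below lists the EDS cluster ahead of the
+  // LOGICAL_DNS cluster, so traffic should prefer the EDS endpoint
+  // (backend 1) and only fail over to the DNS-resolved endpoint (backend 2)
+  // when backend 1 goes down.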
+  Cluster new_cluster1 = default_cluster_;
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  // Create Logical DNS Cluster
+  auto logical_dns_cluster = default_cluster_;
+  logical_dns_cluster.set_name(kLogicalDNSClusterName);
+  logical_dns_cluster.set_type(Cluster::LOGICAL_DNS);
+  auto* address = logical_dns_cluster.mutable_load_assignment()
+                      ->add_endpoints()
+                      ->add_lb_endpoints()
+                      ->mutable_endpoint()
+                      ->mutable_address()
+                      ->mutable_socket_address();
+  address->set_address(kServerName);
+  address->set_port_value(443);
+  balancers_[0]->ads_service()->SetCdsResource(logical_dns_cluster);
+  // Create Aggregate Cluster
+  auto cluster = default_cluster_;
+  CustomClusterType* custom_cluster = cluster.mutable_cluster_type();
+  custom_cluster->set_name("envoy.clusters.aggregate");
+  ClusterConfig cluster_config;
+  cluster_config.add_clusters(kNewCluster1Name);
+  cluster_config.add_clusters(kLogicalDNSClusterName);
+  custom_cluster->mutable_typed_config()->PackFrom(cluster_config);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  // Set Logical DNS result
+  {
+    grpc_core::ExecCtx exec_ctx;
+    grpc_core::Resolver::Result result;
+    result.addresses = CreateAddressListFromPortList(GetBackendPorts(2, 3));
+    logical_dns_cluster_resolver_response_generator_->SetResponse(
+        std::move(result));
+  }
+  // Wait for traffic to go to backend 1.
+  WaitForBackend(1);
+  // Shutdown backend 1 and wait for all traffic to go to backend 2.
+  ShutdownBackend(1);
+  WaitForBackend(2, WaitForBackendOptions().set_allow_failures(true));
+  EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state,
+            AdsServiceImpl::ResponseState::ACKED);
+  // Bring backend 1 back and ensure all traffic goes back to it.
+  StartBackend(1);
+  WaitForBackend(1);
+  gpr_unsetenv(
+      "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER");
+}
+
+TEST_P(CdsTest, AggregateClusterLogicalDnsToEds) {
+  gpr_setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER",
+             "true");
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const char* kLogicalDNSClusterName = "logical_dns_cluster";
+  // Populate new EDS resources.
+  EdsResourceArgs args2({
+      {"locality0", CreateEndpointsForBackends(2, 3)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster2 = default_cluster_;
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Create Logical DNS Cluster
+  auto logical_dns_cluster = default_cluster_;
+  logical_dns_cluster.set_name(kLogicalDNSClusterName);
+  logical_dns_cluster.set_type(Cluster::LOGICAL_DNS);
+  auto* address = logical_dns_cluster.mutable_load_assignment()
+                      ->add_endpoints()
+                      ->add_lb_endpoints()
+                      ->mutable_endpoint()
+                      ->mutable_address()
+                      ->mutable_socket_address();
+  address->set_address(kServerName);
+  address->set_port_value(443);
+  balancers_[0]->ads_service()->SetCdsResource(logical_dns_cluster);
+  // Create Aggregate Cluster
+  auto cluster = default_cluster_;
+  CustomClusterType* custom_cluster = cluster.mutable_cluster_type();
+  custom_cluster->set_name("envoy.clusters.aggregate");
+  ClusterConfig cluster_config;
+  cluster_config.add_clusters(kLogicalDNSClusterName);
+  cluster_config.add_clusters(kNewCluster2Name);
+  custom_cluster->mutable_typed_config()->PackFrom(cluster_config);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  // Set Logical DNS result
+  {
+    grpc_core::ExecCtx exec_ctx;
+    grpc_core::Resolver::Result result;
+    result.addresses = CreateAddressListFromPortList(GetBackendPorts(1, 2));
+    logical_dns_cluster_resolver_response_generator_->SetResponse(
+        std::move(result));
+  }
+  // Wait for traffic to go to backend 1.
+  WaitForBackend(1);
+  // Shutdown backend 1 and wait for all traffic to go to backend 2.
+  ShutdownBackend(1);
+  WaitForBackend(2, WaitForBackendOptions().set_allow_failures(true));
+  EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state,
+            AdsServiceImpl::ResponseState::ACKED);
+  // Bring backend 1 back and ensure all traffic goes back to it.
+  StartBackend(1);
+  WaitForBackend(1);
+  gpr_unsetenv(
+      "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER");
+}
+
+// Test that CDS client should send a NACK if cluster type is Logical DNS but
+// the feature is not yet supported.
+TEST_P(CdsTest, LogicalDNSClusterTypeDisabled) {
+  auto cluster = default_cluster_;
+  cluster.set_type(Cluster::LOGICAL_DNS);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK";
+  const auto response_state =
+      balancers_[0]->ads_service()->cds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_THAT(response_state.error_message,
+              ::testing::HasSubstr("DiscoveryType is not valid."));
+}
+
+// Test that CDS client should send a NACK if cluster type is AGGREGATE but
+// the feature is not yet supported.
+TEST_P(CdsTest, AggregateClusterTypeDisabled) { + auto cluster = default_cluster_; + CustomClusterType* custom_cluster = cluster.mutable_cluster_type(); + custom_cluster->set_name("envoy.clusters.aggregate"); + ClusterConfig cluster_config; + cluster_config.add_clusters("cluster1"); + cluster_config.add_clusters("cluster2"); + custom_cluster->mutable_typed_config()->PackFrom(cluster_config); + cluster.set_type(Cluster::LOGICAL_DNS); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("DiscoveryType is not valid.")); +} + +// Tests that CDS client should send a NACK if the cluster type in CDS +// response is unsupported. +TEST_P(CdsTest, UnsupportedClusterType) { + auto cluster = default_cluster_; + cluster.set_type(Cluster::STATIC); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("DiscoveryType is not valid.")); +} + +// Tests that the NACK for multiple bad resources includes both errors. +TEST_P(CdsTest, MultipleBadResources) { + constexpr char kClusterName2[] = "cluster_name_2"; + constexpr char kClusterName3[] = "cluster_name_3"; + // Add cluster with unsupported type. + auto cluster = default_cluster_; + cluster.set_name(kClusterName2); + cluster.set_type(Cluster::STATIC); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Add second cluster with the same error. + cluster.set_name(kClusterName3); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Change RouteConfig to point to all clusters. + RouteConfiguration route_config = default_route_config_; + route_config.mutable_virtual_hosts(0)->clear_routes(); + // First route: default cluster, selected based on header. + auto* route = route_config.mutable_virtual_hosts(0)->add_routes(); + route->mutable_match()->set_prefix(""); + auto* header_matcher = route->mutable_match()->add_headers(); + header_matcher->set_name("cluster"); + header_matcher->set_exact_match(kDefaultClusterName); + route->mutable_route()->set_cluster(kDefaultClusterName); + // Second route: cluster 2, selected based on header. + route = route_config.mutable_virtual_hosts(0)->add_routes(); + route->mutable_match()->set_prefix(""); + header_matcher = route->mutable_match()->add_headers(); + header_matcher->set_name("cluster"); + header_matcher->set_exact_match(kClusterName2); + route->mutable_route()->set_cluster(kClusterName2); + // Third route: cluster 3, used by default. + route = route_config.mutable_virtual_hosts(0)->add_routes(); + route->mutable_match()->set_prefix(""); + route->mutable_route()->set_cluster(kClusterName3); + SetRouteConfiguration(0, route_config); + // Add EDS resource. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + // Send RPC. 
+ SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::ContainsRegex(absl::StrCat(kClusterName2, + ": validation error.*" + "DiscoveryType is not valid.*", + kClusterName3, + ": validation error.*" + "DiscoveryType is not valid"))); + // RPCs for default cluster should succeed. + std::vector> metadata_default_cluster = { + {"cluster", kDefaultClusterName}, + }; + CheckRpcSendOk( + 1, RpcOptions().set_metadata(std::move(metadata_default_cluster))); + // RPCs for cluster 2 should fail. + std::vector> metadata_cluster_2 = { + {"cluster", kClusterName2}, + }; + CheckRpcSendFailure(CheckRpcSendFailureOptions().set_rpc_options( + RpcOptions().set_metadata(std::move(metadata_cluster_2)))); +} + +// Tests that we don't trigger does-not-exist callbacks for a resource +// that was previously valid but is updated to be invalid. +TEST_P(CdsTest, InvalidClusterStillExistsIfPreviouslyCached) { + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Check that everything works. + CheckRpcSendOk(); + // Now send an update changing the Cluster to be invalid. + auto cluster = default_cluster_; + cluster.set_type(Cluster::STATIC); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Wait for xDS server to see NACK. + auto deadline = absl::Now() + absl::Seconds(30); + do { + CheckRpcSendOk(); + ASSERT_LT(absl::Now(), deadline); + } while (balancers_[0]->ads_service()->cds_response_state().state != + AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(balancers_[0]->ads_service()->cds_response_state().error_message, + ::testing::ContainsRegex(absl::StrCat( + kDefaultClusterName, + ": validation error.*DiscoveryType is not valid"))); + // Check one more time, just to make sure it still works after NACK. + CheckRpcSendOk(); +} + +// Tests that CDS client should send a NACK if the eds_config in CDS response +// is other than ADS. +TEST_P(CdsTest, WrongEdsConfig) { + auto cluster = default_cluster_; + cluster.mutable_eds_cluster_config()->mutable_eds_config()->mutable_self(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("EDS ConfigSource is not ADS.")); +} + +// Tests that CDS client should send a NACK if the lb_policy in CDS response +// is other than ROUND_ROBIN. 
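+// LEAST_REQUEST is used below as a representative unsupported policy; the
+// ring-hash tests further down set Cluster::RING_HASH and are accepted, so
+// ROUND_ROBIN is not the only policy the client recognizes.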
+TEST_P(CdsTest, WrongLbPolicy) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::LEAST_REQUEST); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("LB policy is not supported.")); +} + +// Tests that CDS client should send a NACK if the lrs_server in CDS response +// is other than SELF. +TEST_P(CdsTest, WrongLrsServer) { + auto cluster = default_cluster_; + cluster.mutable_lrs_server()->mutable_ads(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("LRS ConfigSource is not self.")); +} + +// Tests that ring hash policy that hashes using channel id ensures all RPCs +// to go 1 particular backend. +TEST_P(CdsTest, RingHashChannelIdHashing) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_filter_state()->set_key("io.grpc.channel_id"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendOk(100); + bool found = false; + for (size_t i = 0; i < backends_.size(); ++i) { + if (backends_[i]->backend_service()->request_count() > 0) { + EXPECT_EQ(backends_[i]->backend_service()->request_count(), 100) + << "backend " << i; + EXPECT_FALSE(found) << "backend " << i; + found = true; + } + } + EXPECT_TRUE(found); +} + +// Tests that ring hash policy that hashes using a header value can spread +// RPCs across all the backends. +TEST_P(CdsTest, RingHashHeaderHashing) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("address_hash"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + // Note each type of RPC will contains a header value that will always be + // hashed to a specific backend as the header value matches the value used + // to create the entry in the ring. 
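+  // CreateMetadataValueThatHashesToBackend(i), used below, synthesizes an
+  // "address_hash" header value whose hash lands on backend i's ring entry,
+  // so each of the four RPC flavors should be pinned to a distinct backend.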
+ std::vector> metadata = { + {"address_hash", CreateMetadataValueThatHashesToBackend(0)}}; + std::vector> metadata1 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(1)}}; + std::vector> metadata2 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(2)}}; + std::vector> metadata3 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(3)}}; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + const auto rpc_options1 = RpcOptions().set_metadata(std::move(metadata1)); + const auto rpc_options2 = RpcOptions().set_metadata(std::move(metadata2)); + const auto rpc_options3 = RpcOptions().set_metadata(std::move(metadata3)); + WaitForBackend(0, WaitForBackendOptions(), rpc_options); + WaitForBackend(1, WaitForBackendOptions(), rpc_options1); + WaitForBackend(2, WaitForBackendOptions(), rpc_options2); + WaitForBackend(3, WaitForBackendOptions(), rpc_options3); + CheckRpcSendOk(100, rpc_options); + CheckRpcSendOk(100, rpc_options1); + CheckRpcSendOk(100, rpc_options2); + CheckRpcSendOk(100, rpc_options3); + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(100, backends_[i]->backend_service()->request_count()); + } +} + +// Tests that ring hash policy that hashes using a header value and regex +// rewrite to aggregate RPCs to 1 backend. +TEST_P(CdsTest, RingHashHeaderHashingWithRegexRewrite) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("address_hash"); + hash_policy->mutable_header() + ->mutable_regex_rewrite() + ->mutable_pattern() + ->set_regex("[0-9]+"); + hash_policy->mutable_header()->mutable_regex_rewrite()->set_substitution( + "foo"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"address_hash", CreateMetadataValueThatHashesToBackend(0)}}; + std::vector> metadata1 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(1)}}; + std::vector> metadata2 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(2)}}; + std::vector> metadata3 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(3)}}; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + const auto rpc_options1 = RpcOptions().set_metadata(std::move(metadata1)); + const auto rpc_options2 = RpcOptions().set_metadata(std::move(metadata2)); + const auto rpc_options3 = RpcOptions().set_metadata(std::move(metadata3)); + CheckRpcSendOk(100, rpc_options); + CheckRpcSendOk(100, rpc_options1); + CheckRpcSendOk(100, rpc_options2); + CheckRpcSendOk(100, rpc_options3); + bool found = false; + for (size_t i = 0; i < backends_.size(); ++i) { + if (backends_[i]->backend_service()->request_count() > 0) { + EXPECT_EQ(backends_[i]->backend_service()->request_count(), 400) + << "backend " << i; + EXPECT_FALSE(found) << "backend " << i; + found = true; + } + } + EXPECT_TRUE(found); +} + +// Tests that ring hash policy that hashes using a random value. 
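+// With no hash policy configured, each RPC is assigned a random hash, so the
+// two equally weighted backends should each receive roughly 50% of the load.
+// ComputeIdealNumRpcs(0.5, 0.05) chooses a sample size large enough for the
+// DoubleNear checks below to be statistically meaningful.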
+TEST_P(CdsTest, RingHashNoHashPolicy) { + const double kDistribution50Percent = 0.5; + const double kErrorTolerance = 0.05; + const uint32_t kRpcTimeoutMs = 10000; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kDistribution50Percent, kErrorTolerance); + auto cluster = default_cluster_; + // Increasing min ring size for random distribution. + cluster.mutable_ring_hash_lb_config()->mutable_minimum_ring_size()->set_value( + 100000); + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 2)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + // TODO(donnadionne): remove extended timeout after ring creation + // optimization. + WaitForAllBackends(0, 2, WaitForBackendOptions(), + RpcOptions().set_timeout_ms(kRpcTimeoutMs)); + CheckRpcSendOk(kNumRpcs); + const int request_count_1 = backends_[0]->backend_service()->request_count(); + const int request_count_2 = backends_[1]->backend_service()->request_count(); + EXPECT_THAT(static_cast(request_count_1) / kNumRpcs, + ::testing::DoubleNear(kDistribution50Percent, kErrorTolerance)); + EXPECT_THAT(static_cast(request_count_2) / kNumRpcs, + ::testing::DoubleNear(kDistribution50Percent, kErrorTolerance)); +} + +// Test that ring hash policy evaluation will continue past the terminal +// policy if no results are produced yet. +TEST_P(CdsTest, RingHashContinuesPastTerminalPolicyThatDoesNotProduceResult) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("header_not_present"); + hash_policy->set_terminal(true); + auto* hash_policy2 = route->mutable_route()->add_hash_policy(); + hash_policy2->mutable_header()->set_header_name("address_hash"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 2)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"address_hash", CreateMetadataValueThatHashesToBackend(0)}}; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + CheckRpcSendOk(100, rpc_options); + EXPECT_EQ(backends_[0]->backend_service()->request_count(), 100); + EXPECT_EQ(backends_[1]->backend_service()->request_count(), 0); +} + +// Test random hash is used when header hashing specified a header field that +// the RPC did not have. +TEST_P(CdsTest, RingHashOnHeaderThatIsNotPresent) { + const double kDistribution50Percent = 0.5; + const double kErrorTolerance = 0.05; + const uint32_t kRpcTimeoutMs = 10000; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kDistribution50Percent, kErrorTolerance); + auto cluster = default_cluster_; + // Increasing min ring size for random distribution. 
+ cluster.mutable_ring_hash_lb_config()->mutable_minimum_ring_size()->set_value( + 100000); + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("header_not_present"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 2)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"unmatched_header", absl::StrFormat("%" PRIu32, rand())}, + }; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + // TODO(donnadionne): remove extended timeout after ring creation + // optimization. + WaitForAllBackends(0, 2, WaitForBackendOptions(), + RpcOptions().set_timeout_ms(kRpcTimeoutMs)); + CheckRpcSendOk(kNumRpcs, rpc_options); + const int request_count_1 = backends_[0]->backend_service()->request_count(); + const int request_count_2 = backends_[1]->backend_service()->request_count(); + EXPECT_THAT(static_cast(request_count_1) / kNumRpcs, + ::testing::DoubleNear(kDistribution50Percent, kErrorTolerance)); + EXPECT_THAT(static_cast(request_count_2) / kNumRpcs, + ::testing::DoubleNear(kDistribution50Percent, kErrorTolerance)); +} + +// Test random hash is used when only unsupported hash policies are +// configured. +TEST_P(CdsTest, RingHashUnsupportedHashPolicyDefaultToRandomHashing) { + const double kDistribution50Percent = 0.5; + const double kErrorTolerance = 0.05; + const uint32_t kRpcTimeoutMs = 10000; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kDistribution50Percent, kErrorTolerance); + auto cluster = default_cluster_; + // Increasing min ring size for random distribution. + cluster.mutable_ring_hash_lb_config()->mutable_minimum_ring_size()->set_value( + 100000); + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy_unsupported_1 = route->mutable_route()->add_hash_policy(); + hash_policy_unsupported_1->mutable_cookie()->set_name("cookie"); + auto* hash_policy_unsupported_2 = route->mutable_route()->add_hash_policy(); + hash_policy_unsupported_2->mutable_connection_properties()->set_source_ip( + true); + auto* hash_policy_unsupported_3 = route->mutable_route()->add_hash_policy(); + hash_policy_unsupported_3->mutable_query_parameter()->set_name( + "query_parameter"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 2)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + // TODO(donnadionne): remove extended timeout after ring creation + // optimization. 
+ WaitForAllBackends(0, 2, WaitForBackendOptions(), + RpcOptions().set_timeout_ms(kRpcTimeoutMs)); + CheckRpcSendOk(kNumRpcs); + const int request_count_1 = backends_[0]->backend_service()->request_count(); + const int request_count_2 = backends_[1]->backend_service()->request_count(); + EXPECT_THAT(static_cast(request_count_1) / kNumRpcs, + ::testing::DoubleNear(kDistribution50Percent, kErrorTolerance)); + EXPECT_THAT(static_cast(request_count_2) / kNumRpcs, + ::testing::DoubleNear(kDistribution50Percent, kErrorTolerance)); +} + +// Tests that ring hash policy that hashes using a random value can spread +// RPCs across all the backends according to locality weight. +TEST_P(CdsTest, RingHashRandomHashingDistributionAccordingToEndpointWeight) { + const size_t kWeight1 = 1; + const size_t kWeight2 = 2; + const size_t kWeightTotal = kWeight1 + kWeight2; + const double kWeight33Percent = static_cast(kWeight1) / kWeightTotal; + const double kWeight66Percent = static_cast(kWeight2) / kWeightTotal; + const double kErrorTolerance = 0.05; + const uint32_t kRpcTimeoutMs = 10000; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kWeight33Percent, kErrorTolerance); + auto cluster = default_cluster_; + // Increasing min ring size for random distribution. + cluster.mutable_ring_hash_lb_config()->mutable_minimum_ring_size()->set_value( + 100000); + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + EdsResourceArgs args({{"locality0", + {CreateEndpoint(0, HealthStatus::UNKNOWN, 1), + CreateEndpoint(1, HealthStatus::UNKNOWN, 2)}}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + // TODO(donnadionne): remove extended timeout after ring creation + // optimization. + WaitForAllBackends(0, 2, WaitForBackendOptions(), + RpcOptions().set_timeout_ms(kRpcTimeoutMs)); + CheckRpcSendOk(kNumRpcs); + const int weight_33_request_count = + backends_[0]->backend_service()->request_count(); + const int weight_66_request_count = + backends_[1]->backend_service()->request_count(); + EXPECT_THAT(static_cast(weight_33_request_count) / kNumRpcs, + ::testing::DoubleNear(kWeight33Percent, kErrorTolerance)); + EXPECT_THAT(static_cast(weight_66_request_count) / kNumRpcs, + ::testing::DoubleNear(kWeight66Percent, kErrorTolerance)); +} + +// Tests that ring hash policy that hashes using a random value can spread +// RPCs across all the backends according to locality weight. +TEST_P(CdsTest, + RingHashRandomHashingDistributionAccordingToLocalityAndEndpointWeight) { + const size_t kWeight1 = 1 * 1; + const size_t kWeight2 = 2 * 2; + const size_t kWeightTotal = kWeight1 + kWeight2; + const double kWeight20Percent = static_cast(kWeight1) / kWeightTotal; + const double kWeight80Percent = static_cast(kWeight2) / kWeightTotal; + const double kErrorTolerance = 0.05; + const uint32_t kRpcTimeoutMs = 10000; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kWeight20Percent, kErrorTolerance); + auto cluster = default_cluster_; + // Increasing min ring size for random distribution. 
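+  // Each endpoint's expected share is proportional to
+  // locality_weight * endpoint_weight: 1 * 1 = 1 vs. 2 * 2 = 4, which gives
+  // the 20% / 80% split (kWeight20Percent / kWeight80Percent) asserted below.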
+  cluster.mutable_ring_hash_lb_config()->mutable_minimum_ring_size()->set_value(
+      100000);
+  cluster.set_lb_policy(Cluster::RING_HASH);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  EdsResourceArgs args(
+      {{"locality0", {CreateEndpoint(0, HealthStatus::UNKNOWN, 1)}, 1},
+       {"locality1", {CreateEndpoint(1, HealthStatus::UNKNOWN, 2)}, 2}});
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  SetNextResolutionForLbChannelAllBalancers();
+  // TODO(donnadionne): remove extended timeout after ring creation
+  // optimization.
+  WaitForAllBackends(0, 2, WaitForBackendOptions(),
+                     RpcOptions().set_timeout_ms(kRpcTimeoutMs));
+  CheckRpcSendOk(kNumRpcs);
+  const int weight_20_request_count =
+      backends_[0]->backend_service()->request_count();
+  const int weight_80_request_count =
+      backends_[1]->backend_service()->request_count();
+  EXPECT_THAT(static_cast<double>(weight_20_request_count) / kNumRpcs,
+              ::testing::DoubleNear(kWeight20Percent, kErrorTolerance));
+  EXPECT_THAT(static_cast<double>(weight_80_request_count) / kNumRpcs,
+              ::testing::DoubleNear(kWeight80Percent, kErrorTolerance));
+}
+
+// Tests that round robin is not impacted by the endpoint weight, and that the
+// localities in a locality map are picked according to their weights.
+TEST_P(CdsTest, RingHashEndpointWeightDoesNotImpactWeightedRoundRobin) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const int kLocalityWeight0 = 2;
+  const int kLocalityWeight1 = 8;
+  const int kTotalLocalityWeight = kLocalityWeight0 + kLocalityWeight1;
+  const double kLocalityWeightRate0 =
+      static_cast<double>(kLocalityWeight0) / kTotalLocalityWeight;
+  const double kLocalityWeightRate1 =
+      static_cast<double>(kLocalityWeight1) / kTotalLocalityWeight;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs =
+      ComputeIdealNumRpcs(kLocalityWeightRate0, kErrorTolerance);
+  // ADS response contains 2 localities, each of which contains 1 backend.
+  EdsResourceArgs args({
+      {"locality0",
+       {CreateEndpoint(0, HealthStatus::UNKNOWN, 8)},
+       kLocalityWeight0},
+      {"locality1",
+       {CreateEndpoint(1, HealthStatus::UNKNOWN, 2)},
+       kLocalityWeight1},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait for both backends to be ready.
+  WaitForAllBackends(0, 2);
+  // Send kNumRpcs RPCs.
+  CheckRpcSendOk(kNumRpcs);
+  // The locality picking rates should be roughly equal to the expectation.
+  const double locality_picked_rate_0 =
+      static_cast<double>(backends_[0]->backend_service()->request_count()) /
+      kNumRpcs;
+  const double locality_picked_rate_1 =
+      static_cast<double>(backends_[1]->backend_service()->request_count()) /
+      kNumRpcs;
+  EXPECT_THAT(locality_picked_rate_0,
+              ::testing::DoubleNear(kLocalityWeightRate0, kErrorTolerance));
+  EXPECT_THAT(locality_picked_rate_1,
+              ::testing::DoubleNear(kLocalityWeightRate1, kErrorTolerance));
+}
+
+// Tests that a ring hash policy that hashes a fixed string ensures all RPCs
+// go to 1 particular backend, and that subsequent hash policies are ignored
+// because the first policy is marked terminal.
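+// Because the first hash policy is terminal and the "fixed_string" header is
+// always present, its hash is used for every RPC; the "random_string" policy
+// added after it never contributes, so all 100 RPCs land on a single backend.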
+TEST_P(CdsTest, RingHashFixedHashingTerminalPolicy) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("fixed_string"); + hash_policy->set_terminal(true); + auto* hash_policy_to_be_ignored = route->mutable_route()->add_hash_policy(); + hash_policy_to_be_ignored->mutable_header()->set_header_name("random_string"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"fixed_string", "fixed_value"}, + {"random_string", absl::StrFormat("%" PRIu32, rand())}, + }; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + CheckRpcSendOk(100, rpc_options); + bool found = false; + for (size_t i = 0; i < backends_.size(); ++i) { + if (backends_[i]->backend_service()->request_count() > 0) { + EXPECT_EQ(backends_[i]->backend_service()->request_count(), 100) + << "backend " << i; + EXPECT_FALSE(found) << "backend " << i; + found = true; + } + } + EXPECT_TRUE(found); +} + +// Test that the channel will go from idle to ready via connecting; +// (tho it is not possible to catch the connecting state before moving to +// ready) +TEST_P(CdsTest, RingHashIdleToReady) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_filter_state()->set_key("io.grpc.channel_id"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(false)); + CheckRpcSendOk(); + EXPECT_EQ(GRPC_CHANNEL_READY, channel_->GetState(false)); +} + +// Test that when the first pick is down leading to a transient failure, we +// will move on to the next ring hash entry. 
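+// The EDS resource below deliberately lists an unused port, and the RPCs are
+// hashed to that dead endpoint; the picker is expected to fall through to the
+// next entry on the ring (backend 1), which then receives all 100 RPCs.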
+TEST_P(CdsTest, RingHashTransientFailureCheckNextOne) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("address_hash"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + std::vector endpoints; + const int unused_port = grpc_pick_unused_port_or_die(); + endpoints.emplace_back(unused_port); + endpoints.emplace_back(backends_[1]->port()); + EdsResourceArgs args({ + {"locality0", std::move(endpoints)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"address_hash", + CreateMetadataValueThatHashesToBackendPort(unused_port)}}; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + WaitForBackend(1, WaitForBackendOptions(), rpc_options); + CheckRpcSendOk(100, rpc_options); + EXPECT_EQ(0, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(100, backends_[1]->backend_service()->request_count()); +} + +// Test that when a backend goes down, we will move on to the next subchannel +// (with a lower priority). When the backend comes back up, traffic will move +// back. +TEST_P(CdsTest, RingHashSwitchToLowerPrioirtyAndThenBack) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("address_hash"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 0}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"address_hash", CreateMetadataValueThatHashesToBackend(0)}}; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + WaitForBackend(0, WaitForBackendOptions(), rpc_options); + ShutdownBackend(0); + WaitForBackend(1, WaitForBackendOptions().set_allow_failures(true), + rpc_options); + StartBackend(0); + WaitForBackend(0, WaitForBackendOptions(), rpc_options); + CheckRpcSendOk(100, rpc_options); + EXPECT_EQ(100, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); +} + +// Test that when all backends are down, we will keep reattempting. 
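+// After the expected failure below (the hashed-to endpoint is unreachable and
+// backend 1 is shut down), backend 1 is restarted without sending any more
+// RPCs; WaitForConnected verifies that the channel keeps retrying and
+// reconnects on its own within kConnectionTimeoutMilliseconds.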
+TEST_P(CdsTest, RingHashAllFailReattempt) { + const uint32_t kConnectionTimeoutMilliseconds = 5000; + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("address_hash"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + std::vector endpoints; + endpoints.emplace_back(grpc_pick_unused_port_or_die()); + endpoints.emplace_back(backends_[1]->port()); + EdsResourceArgs args({ + {"locality0", std::move(endpoints)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"address_hash", CreateMetadataValueThatHashesToBackend(0)}}; + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(false)); + ShutdownBackend(1); + CheckRpcSendFailure(CheckRpcSendFailureOptions().set_rpc_options( + RpcOptions().set_metadata(std::move(metadata)))); + StartBackend(1); + // Ensure we are actively connecting without any traffic. + EXPECT_TRUE(channel_->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kConnectionTimeoutMilliseconds))); +} + +// Test that when all backends are down and then up, we may pick a TF backend +// and we will then jump to ready backend. +TEST_P(CdsTest, RingHashTransientFailureSkipToAvailableReady) { + const uint32_t kConnectionTimeoutMilliseconds = 5000; + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_header()->set_header_name("address_hash"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + std::vector endpoints; + // Make sure we include some unused ports to fill the ring. + endpoints.emplace_back(backends_[0]->port()); + endpoints.emplace_back(backends_[1]->port()); + endpoints.emplace_back(grpc_pick_unused_port_or_die()); + endpoints.emplace_back(grpc_pick_unused_port_or_die()); + EdsResourceArgs args({ + {"locality0", std::move(endpoints)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + std::vector> metadata = { + {"address_hash", CreateMetadataValueThatHashesToBackend(0)}}; + const auto rpc_options = RpcOptions().set_metadata(std::move(metadata)); + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(false)); + ShutdownBackend(0); + ShutdownBackend(1); + CheckRpcSendFailure( + CheckRpcSendFailureOptions().set_rpc_options(rpc_options)); + EXPECT_EQ(GRPC_CHANNEL_TRANSIENT_FAILURE, channel_->GetState(false)); + // Bring up 0, should be picked as the RPC is hashed to it. + StartBackend(0); + EXPECT_TRUE(channel_->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kConnectionTimeoutMilliseconds))); + WaitForBackend(0, WaitForBackendOptions(), rpc_options); + // Bring down 0 and bring up 1. + // Note the RPC contains a header value that will always be hashed to + // backend 0. 
So by purposely bringing down backend 0 and bringing up another
+  // backend, this will ensure the picker's first choice of backend 0 will
+  // fail and it will
+  // 1. reattempt backend 0 and
+  // 2. go through the remaining subchannels to find one in READY.
+  // Since the entries in the ring are pretty well distributed and we have
+  // unused ports to fill the ring, it is almost guaranteed that the picker
+  // will go through some non-READY entries and skip them as per design.
+  ShutdownBackend(0);
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions().set_rpc_options(rpc_options));
+  StartBackend(1);
+  EXPECT_TRUE(channel_->WaitForConnected(
+      grpc_timeout_milliseconds_to_deadline(kConnectionTimeoutMilliseconds)));
+  WaitForBackend(1, WaitForBackendOptions(), rpc_options);
+}
+
+// Test that unsupported hash policy types are all ignored before a supported
+// policy.
+TEST_P(CdsTest, RingHashUnsupportedHashPolicyUntilChannelIdHashing) {
+  auto cluster = default_cluster_;
+  cluster.set_lb_policy(Cluster::RING_HASH);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  auto new_route_config = default_route_config_;
+  auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  auto* hash_policy_unsupported_1 = route->mutable_route()->add_hash_policy();
+  hash_policy_unsupported_1->mutable_cookie()->set_name("cookie");
+  auto* hash_policy_unsupported_2 = route->mutable_route()->add_hash_policy();
+  hash_policy_unsupported_2->mutable_connection_properties()->set_source_ip(
+      true);
+  auto* hash_policy_unsupported_3 = route->mutable_route()->add_hash_policy();
+  hash_policy_unsupported_3->mutable_query_parameter()->set_name(
+      "query_parameter");
+  auto* hash_policy = route->mutable_route()->add_hash_policy();
+  hash_policy->mutable_filter_state()->set_key("io.grpc.channel_id");
+  SetListenerAndRouteConfiguration(0, default_listener_, new_route_config);
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendOk(100);
+  bool found = false;
+  for (size_t i = 0; i < backends_.size(); ++i) {
+    if (backends_[i]->backend_service()->request_count() > 0) {
+      EXPECT_EQ(backends_[i]->backend_service()->request_count(), 100)
+          << "backend " << i;
+      EXPECT_FALSE(found) << "backend " << i;
+      found = true;
+    }
+  }
+  EXPECT_TRUE(found);
+}
+
+// Test that we NACK when the ring hash policy has an invalid hash function
+// (something other than XX_HASH).
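+// XX_HASH is the only hash function accepted for RING_HASH here;
+// MURMUR_HASH_2 below stands in for any other value and should be rejected
+// with the "invalid hash function" validation error.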
+TEST_P(CdsTest, RingHashPolicyHasInvalidHashFunction) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + cluster.mutable_ring_hash_lb_config()->set_hash_function( + Cluster::RingHashLbConfig::MURMUR_HASH_2); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_filter_state()->set_key("io.grpc.channel_id"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("ring hash lb config has invalid hash function.")); +} + +// Test we nack when ring hash policy has invalid ring size. +TEST_P(CdsTest, RingHashPolicyHasInvalidMinimumRingSize) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + cluster.mutable_ring_hash_lb_config()->mutable_minimum_ring_size()->set_value( + 0); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_filter_state()->set_key("io.grpc.channel_id"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "min_ring_size is not in the range of 1 to 8388608.")); +} + +// Test we nack when ring hash policy has invalid ring size. 
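+// The validated range is 1 to 8388608 (2^23) ring entries, matching the
+// error strings asserted here; the two tests below cover a maximum_ring_size
+// above that limit and a configuration where the minimum exceeds the maximum.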
+TEST_P(CdsTest, RingHashPolicyHasInvalidMaxmumRingSize) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + cluster.mutable_ring_hash_lb_config()->mutable_maximum_ring_size()->set_value( + 8388609); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_filter_state()->set_key("io.grpc.channel_id"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "max_ring_size is not in the range of 1 to 8388608.")); +} + +// Test we nack when ring hash policy has invalid ring size. +TEST_P(CdsTest, RingHashPolicyHasInvalidRingSizeMinGreaterThanMax) { + auto cluster = default_cluster_; + cluster.set_lb_policy(Cluster::RING_HASH); + cluster.mutable_ring_hash_lb_config()->mutable_maximum_ring_size()->set_value( + 5000); + cluster.mutable_ring_hash_lb_config()->mutable_minimum_ring_size()->set_value( + 5001); + balancers_[0]->ads_service()->SetCdsResource(cluster); + auto new_route_config = default_route_config_; + auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + auto* hash_policy = route->mutable_route()->add_hash_policy(); + hash_policy->mutable_filter_state()->set_key("io.grpc.channel_id"); + SetListenerAndRouteConfiguration(0, default_listener_, new_route_config); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "min_ring_size cannot be greater than max_ring_size.")); +} + +class XdsSecurityTest : public BasicTest { + protected: + void SetUp() override { + BasicTest::SetUp(); + root_cert_ = ReadFile(kCaCertPath); + bad_root_cert_ = ReadFile(kBadClientCertPath); + identity_pair_ = ReadTlsIdentityPair(kClientKeyPath, kClientCertPath); + // TODO(yashykt): Use different client certs here instead of reusing + // server certs after https://github.com/grpc/grpc/pull/24876 is merged + fallback_identity_pair_ = + ReadTlsIdentityPair(kServerKeyPath, kServerCertPath); + bad_identity_pair_ = + ReadTlsIdentityPair(kBadClientKeyPath, kBadClientCertPath); + server_san_exact_.set_exact("*.test.google.fr"); + server_san_prefix_.set_prefix("waterzooi.test.google"); + server_san_suffix_.set_suffix("google.fr"); + server_san_contains_.set_contains("google"); + server_san_regex_.mutable_safe_regex()->mutable_google_re2(); + server_san_regex_.mutable_safe_regex()->set_regex( + "(foo|waterzooi).test.google.(fr|be)"); + bad_san_1_.set_exact("192.168.1.4"); + 
bad_san_2_.set_exact("foo.test.google.in"); + authenticated_identity_ = {"testclient"}; + fallback_authenticated_identity_ = {"*.test.google.fr", + "waterzooi.test.google.be", + "*.test.youtube.com", "192.168.1.3"}; + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolutionForLbChannelAllBalancers(); + } + + void TearDown() override { + g_fake1_cert_data_map = nullptr; + g_fake2_cert_data_map = nullptr; + BasicTest::TearDown(); + } + + // Sends CDS updates with the new security configuration and verifies that + // after propagation, this new configuration is used for connections. If \a + // identity_instance_name and \a root_instance_name are both empty, + // connections are expected to use fallback credentials. + void UpdateAndVerifyXdsSecurityConfiguration( + absl::string_view root_instance_name, + absl::string_view root_certificate_name, + absl::string_view identity_instance_name, + absl::string_view identity_certificate_name, + const std::vector& san_matchers, + const std::vector& expected_authenticated_identity, + bool test_expects_failure = false) { + auto cluster = default_cluster_; + if (!identity_instance_name.empty() || !root_instance_name.empty()) { + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + if (!identity_instance_name.empty()) { + upstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_instance_name(std::string(identity_instance_name)); + upstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_certificate_name(std::string(identity_certificate_name)); + } + if (!root_instance_name.empty()) { + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name(std::string(root_instance_name)); + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_certificate_name(std::string(root_certificate_name)); + } + if (!san_matchers.empty()) { + auto* validation_context = + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context(); + for (const auto& san_matcher : san_matchers) { + *validation_context->add_match_subject_alt_names() = san_matcher; + } + } + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + } + balancers_[0]->ads_service()->SetCdsResource(cluster); + // The updates might take time to have an effect, so use a retry loop. + constexpr int kRetryCount = 100; + int num_tries = 0; + for (; num_tries < kRetryCount; num_tries++) { + // Give some time for the updates to propagate. + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(100)); + if (test_expects_failure) { + // Restart the servers to force a reconnection so that previously + // connected subchannels are not used for the RPC. + ShutdownBackend(0); + StartBackend(0); + if (SendRpc().ok()) { + gpr_log(GPR_ERROR, "RPC succeeded. Failure expected. Trying again."); + continue; + } + } else { + WaitForBackend(0, WaitForBackendOptions().set_allow_failures(true)); + Status status = SendRpc(); + if (!status.ok()) { + gpr_log(GPR_ERROR, "RPC failed. 
code=%d message=%s Trying again.", + status.error_code(), status.error_message().c_str()); + continue; + } + if (backends_[0]->backend_service()->last_peer_identity() != + expected_authenticated_identity) { + gpr_log( + GPR_ERROR, + "Expected client identity does not match. (actual) %s vs " + "(expected) %s Trying again.", + absl::StrJoin( + backends_[0]->backend_service()->last_peer_identity(), ",") + .c_str(), + absl::StrJoin(expected_authenticated_identity, ",").c_str()); + continue; + } + } + break; + } + EXPECT_LT(num_tries, kRetryCount); + } + + std::string root_cert_; + std::string bad_root_cert_; + grpc_core::PemKeyCertPairList identity_pair_; + grpc_core::PemKeyCertPairList fallback_identity_pair_; + grpc_core::PemKeyCertPairList bad_identity_pair_; + StringMatcher server_san_exact_; + StringMatcher server_san_prefix_; + StringMatcher server_san_suffix_; + StringMatcher server_san_contains_; + StringMatcher server_san_regex_; + StringMatcher bad_san_1_; + StringMatcher bad_san_2_; + std::vector authenticated_identity_; + std::vector fallback_authenticated_identity_; +}; + +TEST_P(XdsSecurityTest, UnknownTransportSocket) { + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("unknown_transport_socket"); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "Unrecognized transport socket: unknown_transport_socket")); +} + +TEST_P(XdsSecurityTest, + TLSConfigurationWithoutValidationContextCertificateProviderInstance) { + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("TLS configuration provided but no " + "ca_certificate_provider_instance found.")); +} + +TEST_P( + XdsSecurityTest, + MatchSubjectAltNamesProvidedWithoutValidationContextCertificateProviderInstance) { + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + auto* validation_context = upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context(); + *validation_context->add_match_subject_alt_names() = server_san_exact_; + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("TLS configuration provided but no " + "ca_certificate_provider_instance found.")); +} + +TEST_P( + XdsSecurityTest, + TlsCertificateProviderInstanceWithoutValidationContextCertificateProviderInstance) { + auto 
cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_instance_name(std::string("fake_plugin1")); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("TLS configuration provided but no " + "ca_certificate_provider_instance found.")); +} + +TEST_P(XdsSecurityTest, RegexSanMatcherDoesNotAllowIgnoreCase) { + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name(std::string("fake_plugin1")); + auto* validation_context = upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context(); + StringMatcher matcher; + matcher.mutable_safe_regex()->mutable_google_re2(); + matcher.mutable_safe_regex()->set_regex( + "(foo|waterzooi).test.google.(fr|be)"); + matcher.set_ignore_case(true); + *validation_context->add_match_subject_alt_names() = matcher; + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "StringMatcher: ignore_case has no effect for SAFE_REGEX.")); +} + +TEST_P(XdsSecurityTest, UnknownRootCertificateProvider) { + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("unknown"); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "Unrecognized certificate provider instance name: unknown")); +} + +TEST_P(XdsSecurityTest, UnknownIdentityCertificateProvider) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() 
+ ->mutable_tls_certificate_provider_instance() + ->set_instance_name("unknown"); + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "Unrecognized certificate provider instance name: unknown")); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, + NacksCertificateValidationContextWithVerifyCertificateSpki) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->add_verify_certificate_spki("spki"); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "CertificateValidationContext: verify_certificate_spki unsupported")); +} + +TEST_P(XdsSecurityTest, + NacksCertificateValidationContextWithVerifyCertificateHash) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->add_verify_certificate_hash("hash"); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "CertificateValidationContext: verify_certificate_hash unsupported")); +} + +TEST_P(XdsSecurityTest, + NacksCertificateValidationContextWithRequireSignedCertificateTimes) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = 
cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_require_signed_certificate_timestamp() + ->set_value(true); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("CertificateValidationContext: " + "require_signed_certificate_timestamp unsupported")); +} + +TEST_P(XdsSecurityTest, NacksCertificateValidationContextWithCrl) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_crl(); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("CertificateValidationContext: crl unsupported")); +} + +TEST_P(XdsSecurityTest, + NacksCertificateValidationContextWithCustomValidatorConfig) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_custom_validator_config(); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr( + "CertificateValidationContext: custom_validator_config unsupported")); +} + +TEST_P(XdsSecurityTest, NacksValidationContextSdsSecretConfig) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + 
g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context_sds_secret_config(); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("validation_context_sds_secret_config unsupported")); +} + +TEST_P(XdsSecurityTest, NacksTlsParams) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context()->mutable_tls_params(); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("tls_params unsupported")); +} + +TEST_P(XdsSecurityTest, NacksCustomHandshaker) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context() + ->mutable_custom_handshaker(); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("custom_handshaker unsupported")); +} + +TEST_P(XdsSecurityTest, NacksTlsCertificates) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + 
->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context()->add_tls_certificates(); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("tls_certificates unsupported")); +} + +TEST_P(XdsSecurityTest, NacksTlsCertificateSdsSecretConfigs) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + upstream_tls_context.mutable_common_tls_context() + ->add_tls_certificate_sds_secret_configs(); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + ASSERT_TRUE(WaitForCdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("tls_certificate_sds_secret_configs unsupported")); +} + +TEST_P(XdsSecurityTest, TestTlsConfigurationInCombinedValidationContext) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_combined_validation_context() + ->mutable_default_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); + WaitForBackend(0, WaitForBackendOptions().set_allow_failures(true)); + Status status = SendRpc(); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); +} + +// TODO(yashykt): Remove this test once we stop supporting old fields +TEST_P(XdsSecurityTest, + TestTlsConfigurationInValidationContextCertificateProviderInstance) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + auto cluster = default_cluster_; + auto* transport_socket = cluster.mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + UpstreamTlsContext upstream_tls_context; + upstream_tls_context.mutable_common_tls_context() + ->mutable_combined_validation_context() + ->mutable_validation_context_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + transport_socket->mutable_typed_config()->PackFrom(upstream_tls_context); + balancers_[0]->ads_service()->SetCdsResource(cluster); 
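+ // The deprecated validation_context_certificate_provider_instance field
+ // should still be accepted, so the RPC below is expected to succeed.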
+ WaitForBackend(0, WaitForBackendOptions().set_allow_failures(true)); + Status status = SendRpc(); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithNoSanMatchers) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {}, authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithExactSanMatcher) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithPrefixSanMatcher) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_prefix_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithSuffixSanMatcher) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_suffix_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithContainsSanMatcher) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_contains_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithRegexSanMatcher) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_regex_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithSanMatchersUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin1", "", "fake_plugin1", "", + {server_san_exact_, server_san_prefix_}, authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {bad_san_1_, bad_san_2_}, {}, + true /* failure */); + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin1", "", "fake_plugin1", "", + {server_san_prefix_, server_san_regex_}, authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithRootPluginUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"", {bad_root_cert_, bad_identity_pair_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + 
UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin2" /* bad root */, "", + "fake_plugin1", "", {}, {}, + true /* failure */); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; + g_fake2_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithIdentityPluginUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"", {root_cert_, fallback_identity_pair_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin2", + "", {server_san_exact_}, + fallback_authenticated_identity_); + g_fake1_cert_data_map = nullptr; + g_fake2_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithBothPluginsUpdated) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"", {bad_root_cert_, bad_identity_pair_}}, + {"good", {root_cert_, fallback_identity_pair_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin2", "", "fake_plugin2", + "", {}, {}, true /* failure */); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_prefix_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin2", "good", "fake_plugin2", "good", {server_san_prefix_}, + fallback_authenticated_identity_); + g_fake1_cert_data_map = nullptr; + g_fake2_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithRootCertificateNameUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"bad", {bad_root_cert_, bad_identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_regex_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "bad", "fake_plugin1", + "", {server_san_regex_}, {}, + true /* failure */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, + TestMtlsConfigurationWithIdentityCertificateNameUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"bad", {bad_root_cert_, bad_identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "bad", {server_san_exact_}, {}, + true /* failure */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, + TestMtlsConfigurationWithIdentityCertificateNameUpdateGoodCerts) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"good", {root_cert_, fallback_identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + 
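+ // Both identity certificate names map to valid certs, so updating the name
+ // should keep RPCs succeeding while changing the identity presented to the
+ // server.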
UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "good", {server_san_exact_}, + fallback_authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsConfigurationWithBothCertificateNamesUpdated) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"bad", {bad_root_cert_, bad_identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "bad", "fake_plugin1", + "bad", {server_san_prefix_}, {}, + true /* failure */); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_prefix_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestTlsConfigurationWithNoSanMatchers) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "", "", {}, + {} /* unauthenticated */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestTlsConfigurationWithSanMatchers) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin1", "", "", "", + {server_san_exact_, server_san_prefix_, server_san_regex_}, + {} /* unauthenticated */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestTlsConfigurationWithSanMatchersUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin1", "", "", "", {server_san_exact_, server_san_prefix_}, + {} /* unauthenticated */); + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin1", "", "", "", {bad_san_1_, bad_san_2_}, + {} /* unauthenticated */, true /* failure */); + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin1", "", "", "", {server_san_prefix_, server_san_regex_}, + {} /* unauthenticated */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestTlsConfigurationWithRootCertificateNameUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"bad", {bad_root_cert_, bad_identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "", "", + {server_san_exact_}, + {} /* unauthenticated */); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "bad", "", "", + {server_san_exact_}, {}, + true /* failure */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestTlsConfigurationWithRootPluginUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"", {bad_root_cert_, bad_identity_pair_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "", "", + {server_san_exact_}, + {} /* unauthenticated */); + UpdateAndVerifyXdsSecurityConfiguration( + "fake_plugin2", "", "", "", {server_san_exact_}, {}, true /* failure */); + g_fake1_cert_data_map = nullptr; + g_fake2_cert_data_map = 
nullptr; +} + +TEST_P(XdsSecurityTest, TestFallbackConfiguration) { + UpdateAndVerifyXdsSecurityConfiguration("", "", "", "", {}, + fallback_authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsToTls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "", "", + {server_san_exact_}, + {} /* unauthenticated */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestMtlsToFallback) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("", "", "", "", {}, + fallback_authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestTlsToMtls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "", "", + {server_san_exact_}, + {} /* unauthenticated */); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestTlsToFallback) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "", "", + {server_san_exact_}, + {} /* unauthenticated */); + UpdateAndVerifyXdsSecurityConfiguration("", "", "", "", {}, + fallback_authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestFallbackToMtls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("", "", "", "", {}, + fallback_authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "fake_plugin1", + "", {server_san_exact_}, + authenticated_identity_); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestFallbackToTls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + UpdateAndVerifyXdsSecurityConfiguration("", "", "", "", {}, + fallback_authenticated_identity_); + UpdateAndVerifyXdsSecurityConfiguration("fake_plugin1", "", "", "", + {server_san_exact_}, + {} /* unauthenticated */); + g_fake1_cert_data_map = nullptr; +} + +TEST_P(XdsSecurityTest, TestFileWatcherCertificateProvider) { + UpdateAndVerifyXdsSecurityConfiguration("file_plugin", "", "file_plugin", "", + {server_san_exact_}, + authenticated_identity_); +} + +class XdsEnabledServerTest : public XdsEnd2endTest { + protected: + XdsEnabledServerTest() + : XdsEnd2endTest(1, 1, 100, true /* use_xds_enabled_server */) {} + + void SetUp() override { + XdsEnd2endTest::SetUp(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, 
DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + } +}; + +TEST_P(XdsEnabledServerTest, Basic) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? "::1" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + balancers_[0]->ads_service()->SetLdsResource(listener); + WaitForBackend(0); +} + +TEST_P(XdsEnabledServerTest, BadLdsUpdateNoApiListenerNorAddress) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("Listener has neither address nor ApiListener")); +} + +TEST_P(XdsEnabledServerTest, BadLdsUpdateBothApiListenerAndAddress) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? "::1" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + listener.mutable_api_listener(); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("Listener has both address and ApiListener")); +} + +TEST_P(XdsEnabledServerTest, UnsupportedL4Filter) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? 
"::1" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom(default_listener_ /* any proto object other than HttpConnectionManager */); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("Unsupported filter type")); +} + +TEST_P(XdsEnabledServerTest, UnsupportedHttpFilter) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? "::1" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + HttpConnectionManager http_connection_manager; + auto* http_filter = http_connection_manager.add_http_filters(); + http_filter->set_name("grpc.testing.unsupported_http_filter"); + http_filter->mutable_typed_config()->set_type_url( + "grpc.testing.unsupported_http_filter"); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("no filter registered for config type " + "grpc.testing.unsupported_http_filter")); +} + +TEST_P(XdsEnabledServerTest, HttpFilterNotSupportedOnServer) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? 
"::1" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + HttpConnectionManager http_connection_manager; + auto* http_filter = http_connection_manager.add_http_filters(); + http_filter->set_name("grpc.testing.client_only_http_filter"); + http_filter->mutable_typed_config()->set_type_url( + "grpc.testing.client_only_http_filter"); + http_filter = http_connection_manager.add_http_filters(); + http_filter->set_name("router"); + http_filter->mutable_typed_config()->PackFrom( + envoy::extensions::filters::http::router::v3::Router()); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("Filter grpc.testing.client_only_http_filter is not " + "supported on servers")); +} + +TEST_P(XdsEnabledServerTest, + HttpFilterNotSupportedOnServerIgnoredWhenOptional) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? "::1" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + HttpConnectionManager http_connection_manager; + auto* http_filter = http_connection_manager.add_http_filters(); + http_filter->set_name("grpc.testing.client_only_http_filter"); + http_filter->mutable_typed_config()->set_type_url( + "grpc.testing.client_only_http_filter"); + http_filter->set_is_optional(true); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + WaitForBackend(0); + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +// Verify that a mismatch of listening address results in "not serving" +// status. +TEST_P(XdsEnabledServerTest, ListenerAddressMismatch) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + // Set a different listening address in the LDS update + listener.mutable_address()->mutable_socket_address()->set_address( + "192.168.1.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + balancers_[0]->ads_service()->SetLdsResource(listener); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::FAILED_PRECONDITION); +} + +TEST_P(XdsEnabledServerTest, UseOriginalDstNotSupported) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? 
"::1" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + listener.add_filter_chains()->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + listener.mutable_use_original_dst()->set_value(true); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("Field \'use_original_dst\' is not supported.")); +} + +class XdsServerSecurityTest : public XdsEnd2endTest { + protected: + XdsServerSecurityTest() + : XdsEnd2endTest(1, 1, 100, true /* use_xds_enabled_server */) {} + + void SetUp() override { + XdsEnd2endTest::SetUp(); + root_cert_ = ReadFile(kCaCertPath); + bad_root_cert_ = ReadFile(kBadClientCertPath); + identity_pair_ = ReadTlsIdentityPair(kServerKeyPath, kServerCertPath); + bad_identity_pair_ = + ReadTlsIdentityPair(kBadClientKeyPath, kBadClientCertPath); + identity_pair_2_ = ReadTlsIdentityPair(kClientKeyPath, kClientCertPath); + server_authenticated_identity_ = {"*.test.google.fr", + "waterzooi.test.google.be", + "*.test.youtube.com", "192.168.1.3"}; + server_authenticated_identity_2_ = {"testclient"}; + client_authenticated_identity_ = {"*.test.google.fr", + "waterzooi.test.google.be", + "*.test.youtube.com", "192.168.1.3"}; + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + } + + void TearDown() override { + g_fake1_cert_data_map = nullptr; + g_fake2_cert_data_map = nullptr; + XdsEnd2endTest::TearDown(); + } + + void SetLdsUpdate(absl::string_view root_instance_name, + absl::string_view root_certificate_name, + absl::string_view identity_instance_name, + absl::string_view identity_certificate_name, + bool require_client_certificates) { + Listener listener; + listener.set_name(absl::StrCat( + ipv6_only_ ? "grpc/server?xds.resource.listening_address=[::1]:" + : "grpc/server?xds.resource.listening_address=127.0.0.1:", + backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? 
"[::1]" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + if (!identity_instance_name.empty()) { + auto* transport_socket = filter_chain->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + DownstreamTlsContext downstream_tls_context; + downstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_instance_name(std::string(identity_instance_name)); + downstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_certificate_name(std::string(identity_certificate_name)); + if (!root_instance_name.empty()) { + downstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_instance_name(std::string(root_instance_name)); + downstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->mutable_ca_certificate_provider_instance() + ->set_certificate_name(std::string(root_certificate_name)); + downstream_tls_context.mutable_require_client_certificate()->set_value( + require_client_certificates); + } + transport_socket->mutable_typed_config()->PackFrom( + downstream_tls_context); + } + balancers_[0]->ads_service()->SetLdsResource(listener); + } + + std::shared_ptr CreateMtlsChannel() { + ChannelArguments args; + // Override target name for host name check + args.SetString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG, + ipv6_only_ ? "::1" : "127.0.0.1"); + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + std::string uri = absl::StrCat( + ipv6_only_ ? "ipv6:[::1]:" : "ipv4:127.0.0.1:", backends_[0]->port()); + // TODO(yashykt): Switch to using C++ API once b/173823806 is fixed. + grpc_tls_credentials_options* options = + grpc_tls_credentials_options_create(); + grpc_tls_credentials_options_set_server_verification_option( + options, GRPC_TLS_SKIP_HOSTNAME_VERIFICATION); + grpc_tls_credentials_options_set_certificate_provider( + options, + grpc_core::MakeRefCounted( + ReadFile(kCaCertPath), + ReadTlsIdentityPair(kServerKeyPath, kServerCertPath)) + .get()); + grpc_tls_credentials_options_watch_root_certs(options); + grpc_tls_credentials_options_watch_identity_key_cert_pairs(options); + grpc_tls_server_authorization_check_config* check_config = + grpc_tls_server_authorization_check_config_create( + nullptr, ServerAuthCheckSchedule, nullptr, nullptr); + grpc_tls_credentials_options_set_server_authorization_check_config( + options, check_config); + auto channel_creds = std::make_shared( + grpc_tls_credentials_create(options)); + grpc_tls_server_authorization_check_config_release(check_config); + return CreateCustomChannel(uri, channel_creds, args); + } + + std::shared_ptr CreateTlsChannel() { + ChannelArguments args; + // Override target name for host name check + args.SetString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG, + ipv6_only_ ? "::1" : "127.0.0.1"); + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + std::string uri = absl::StrCat( + ipv6_only_ ? "ipv6:[::1]:" : "ipv4:127.0.0.1:", backends_[0]->port()); + // TODO(yashykt): Switch to using C++ API once b/173823806 is fixed. 
+    grpc_tls_credentials_options* options =
+        grpc_tls_credentials_options_create();
+    grpc_tls_credentials_options_set_server_verification_option(
+        options, GRPC_TLS_SKIP_HOSTNAME_VERIFICATION);
+    grpc_tls_credentials_options_set_certificate_provider(
+        options,
+        grpc_core::MakeRefCounted<grpc_core::StaticDataCertificateProvider>(
+            ReadFile(kCaCertPath),
+            ReadTlsIdentityPair(kServerKeyPath, kServerCertPath))
+            .get());
+    grpc_tls_credentials_options_watch_root_certs(options);
+    grpc_tls_server_authorization_check_config* check_config =
+        grpc_tls_server_authorization_check_config_create(
+            nullptr, ServerAuthCheckSchedule, nullptr, nullptr);
+    grpc_tls_credentials_options_set_server_authorization_check_config(
+        options, check_config);
+    auto channel_creds = std::make_shared<SecureChannelCredentials>(
+        grpc_tls_credentials_create(options));
+    grpc_tls_server_authorization_check_config_release(check_config);
+    return CreateCustomChannel(uri, channel_creds, args);
+  }
+
+  std::shared_ptr<Channel> CreateInsecureChannel() {
+    ChannelArguments args;
+    // Override target name for host name check
+    args.SetString(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG,
+                   ipv6_only_ ? "::1" : "127.0.0.1");
+    args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1);
+    std::string uri = absl::StrCat(
+        ipv6_only_ ? "ipv6:[::1]:" : "ipv4:127.0.0.1:", backends_[0]->port());
+    return CreateCustomChannel(uri, InsecureChannelCredentials(), args);
+  }
+
+  void SendRpc(std::function<std::shared_ptr<Channel>()> channel_creator,
+               std::vector<std::string> expected_server_identity,
+               std::vector<std::string> expected_client_identity,
+               bool test_expects_failure = false) {
+    gpr_log(GPR_INFO, "Sending RPC");
+    int num_tries = 0;
+    constexpr int kRetryCount = 100;
+    for (; num_tries < kRetryCount; num_tries++) {
+      auto channel = channel_creator();
+      auto stub = grpc::testing::EchoTestService::NewStub(channel);
+      ClientContext context;
+      context.set_wait_for_ready(true);
+      context.set_deadline(grpc_timeout_milliseconds_to_deadline(2000));
+      EchoRequest request;
+      request.set_message(kRequestMessage);
+      EchoResponse response;
+      Status status = stub->Echo(&context, request, &response);
+      if (test_expects_failure) {
+        if (status.ok()) {
+          gpr_log(GPR_ERROR, "RPC succeeded. Failure expected. Trying again.");
+          continue;
+        }
+      } else {
+        if (!status.ok()) {
+          gpr_log(GPR_ERROR, "RPC failed. code=%d message=%s Trying again.",
+                  status.error_code(), status.error_message().c_str());
+          continue;
+        }
+        EXPECT_EQ(response.message(), kRequestMessage);
+        std::vector<std::string> peer_identity;
+        for (const auto& entry : context.auth_context()->GetPeerIdentity()) {
+          peer_identity.emplace_back(
+              std::string(entry.data(), entry.size()).c_str());
+        }
+        if (peer_identity != expected_server_identity) {
+          gpr_log(GPR_ERROR,
+                  "Expected server identity does not match. (actual) %s vs "
+                  "(expected) %s Trying again.",
+                  absl::StrJoin(peer_identity, ",").c_str(),
+                  absl::StrJoin(expected_server_identity, ",").c_str());
+          continue;
+        }
+        if (backends_[0]->backend_service()->last_peer_identity() !=
+            expected_client_identity) {
+          gpr_log(
+              GPR_ERROR,
+              "Expected client identity does not match. (actual) %s vs "
+              "(expected) %s Trying again.",
+              absl::StrJoin(
+                  backends_[0]->backend_service()->last_peer_identity(), ",")
+                  .c_str(),
+              absl::StrJoin(expected_client_identity, ",").c_str());
+          continue;
+        }
+      }
+      break;
+    }
+    EXPECT_LT(num_tries, kRetryCount);
+  }
+
+  std::string root_cert_;
+  std::string bad_root_cert_;
+  grpc_core::PemKeyCertPairList identity_pair_;
+  grpc_core::PemKeyCertPairList bad_identity_pair_;
+  grpc_core::PemKeyCertPairList identity_pair_2_;
+  std::vector<std::string> server_authenticated_identity_;
+  std::vector<std::string> server_authenticated_identity_2_;
+  std::vector<std::string> client_authenticated_identity_;
+};
+
+TEST_P(XdsServerSecurityTest, UnknownTransportSocket) {
+  Listener listener;
+  listener.set_name(
+      absl::StrCat("grpc/server?xds.resource.listening_address=",
+                   ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()));
+  auto* socket_address = listener.mutable_address()->mutable_socket_address();
+  socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1");
+  socket_address->set_port_value(backends_[0]->port());
+  auto* filter_chain = listener.add_filter_chains();
+  filter_chain->add_filters()->mutable_typed_config()->PackFrom(
+      HttpConnectionManager());
+  auto* transport_socket = filter_chain->mutable_transport_socket();
+  transport_socket->set_name("unknown_transport_socket");
+  balancers_[0]->ads_service()->SetLdsResource(listener);
+  ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED))
+      << "timed out waiting for NACK";
+  const auto response_state =
+      balancers_[0]->ads_service()->lds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_THAT(response_state.error_message,
+              ::testing::HasSubstr(
+                  "Unrecognized transport socket: unknown_transport_socket"));
+}
+
+TEST_P(XdsServerSecurityTest, NacksRequireSNI) {
+  Listener listener;
+  listener.set_name(
+      absl::StrCat("grpc/server?xds.resource.listening_address=",
+                   ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()));
+  auto* socket_address = listener.mutable_address()->mutable_socket_address();
+  socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1");
+  socket_address->set_port_value(backends_[0]->port());
+  auto* filter_chain = listener.add_filter_chains();
+  filter_chain->add_filters()->mutable_typed_config()->PackFrom(
+      HttpConnectionManager());
+  auto* transport_socket = filter_chain->mutable_transport_socket();
+  transport_socket->set_name("envoy.transport_sockets.tls");
+  DownstreamTlsContext downstream_tls_context;
+  downstream_tls_context.mutable_common_tls_context()
+      ->mutable_tls_certificate_provider_instance()
+      ->set_instance_name("fake_plugin1");
+  downstream_tls_context.mutable_require_sni()->set_value(true);
+  transport_socket->mutable_typed_config()->PackFrom(downstream_tls_context);
+  balancers_[0]->ads_service()->SetLdsResource(listener);
+  ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED))
+      << "timed out waiting for NACK";
+  const auto response_state =
+      balancers_[0]->ads_service()->lds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_THAT(response_state.error_message,
+              ::testing::HasSubstr("require_sni: unsupported"));
+}
+
+TEST_P(XdsServerSecurityTest, NacksOcspStaplePolicyOtherThanLenientStapling) {
+  Listener listener;
+  listener.set_name(
+      absl::StrCat("grpc/server?xds.resource.listening_address=",
+                   ipv6_only_ ?
"[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* transport_socket = filter_chain->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + DownstreamTlsContext downstream_tls_context; + downstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + downstream_tls_context.set_ocsp_staple_policy( + envoy::extensions::transport_sockets::tls::v3:: + DownstreamTlsContext_OcspStaplePolicy_STRICT_STAPLING); + transport_socket->mutable_typed_config()->PackFrom(downstream_tls_context); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "ocsp_staple_policy: Only LENIENT_STAPLING supported")); +} + +TEST_P( + XdsServerSecurityTest, + NacksRequiringClientCertificateWithoutValidationCertificateProviderInstance) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* transport_socket = filter_chain->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + DownstreamTlsContext downstream_tls_context; + downstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + downstream_tls_context.mutable_require_client_certificate()->set_value(true); + transport_socket->mutable_typed_config()->PackFrom(downstream_tls_context); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "TLS configuration requires client certificates but no " + "certificate provider instance specified for validation.")); +} + +TEST_P(XdsServerSecurityTest, + NacksTlsConfigurationWithoutIdentityProviderInstance) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? 
"::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* transport_socket = filter_chain->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + DownstreamTlsContext downstream_tls_context; + transport_socket->mutable_typed_config()->PackFrom(downstream_tls_context); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("TLS configuration provided but no " + "tls_certificate_provider_instance found.")); +} + +TEST_P(XdsServerSecurityTest, NacksMatchSubjectAltNames) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* transport_socket = filter_chain->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + DownstreamTlsContext downstream_tls_context; + downstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + downstream_tls_context.mutable_common_tls_context() + ->mutable_validation_context() + ->add_match_subject_alt_names() + ->set_exact("*.test.google.fr"); + transport_socket->mutable_typed_config()->PackFrom(downstream_tls_context); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT( + response_state.error_message, + ::testing::HasSubstr("match_subject_alt_names not supported on servers")); +} + +TEST_P(XdsServerSecurityTest, UnknownIdentityCertificateProvider) { + SetLdsUpdate("", "", "unknown", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, {}, {}, + true /* test_expects_failure */); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "Unrecognized certificate provider instance name: unknown")); +} + +TEST_P(XdsServerSecurityTest, UnknownRootCertificateProvider) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + SetLdsUpdate("unknown", "", "fake_plugin1", "", false); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, 
AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr( + "Unrecognized certificate provider instance name: unknown")); +} + +TEST_P(XdsServerSecurityTest, + TestDeprecateTlsCertificateCertificateProviderInstanceField) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + Listener listener; + listener.set_name(absl::StrCat( + ipv6_only_ ? "grpc/server?xds.resource.listening_address=[::1]:" + : "grpc/server?xds.resource.listening_address=127.0.0.1:", + backends_[0]->port())); + listener.mutable_address()->mutable_socket_address()->set_address( + ipv6_only_ ? "[::1]" : "127.0.0.1"); + listener.mutable_address()->mutable_socket_address()->set_port_value( + backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* transport_socket = filter_chain->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + DownstreamTlsContext downstream_tls_context; + downstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + transport_socket->mutable_typed_config()->PackFrom(downstream_tls_context); + balancers_[0]->ads_service()->SetLdsResource(listener); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); +} + +TEST_P(XdsServerSecurityTest, CertificatesNotAvailable) { + FakeCertificateProvider::CertDataMap fake1_cert_map; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerSecurityTest, TestMtls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); +} + +TEST_P(XdsServerSecurityTest, TestMtlsWithRootPluginUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"", {bad_root_cert_, bad_identity_pair_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); + SetLdsUpdate("fake_plugin2", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerSecurityTest, TestMtlsWithIdentityPluginUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"", {root_cert_, identity_pair_2_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); + SetLdsUpdate("fake_plugin1", "", "fake_plugin2", "", true); + SendRpc([this]() { return 
CreateMtlsChannel(); }, + server_authenticated_identity_2_, client_authenticated_identity_); +} + +TEST_P(XdsServerSecurityTest, TestMtlsWithBothPluginsUpdated) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"good", {root_cert_, identity_pair_2_}}, + {"", {bad_root_cert_, bad_identity_pair_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + SetLdsUpdate("fake_plugin2", "", "fake_plugin2", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, {}, {}, + true /* test_expects_failure */); + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); + SetLdsUpdate("fake_plugin2", "good", "fake_plugin2", "good", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_2_, client_authenticated_identity_); +} + +TEST_P(XdsServerSecurityTest, TestMtlsWithRootCertificateNameUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"bad", {bad_root_cert_, bad_identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); + SetLdsUpdate("fake_plugin1", "bad", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerSecurityTest, TestMtlsWithIdentityCertificateNameUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"good", {root_cert_, identity_pair_2_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "good", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_2_, client_authenticated_identity_); +} + +TEST_P(XdsServerSecurityTest, TestMtlsWithBothCertificateNamesUpdated) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"good", {root_cert_, identity_pair_2_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); + SetLdsUpdate("fake_plugin1", "good", "fake_plugin1", "good", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_2_, client_authenticated_identity_); +} + +TEST_P(XdsServerSecurityTest, TestMtlsNotRequiringButProvidingClientCerts) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); +} + +TEST_P(XdsServerSecurityTest, TestMtlsNotRequiringAndNotProvidingClientCerts) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + 
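+ // require_client_certificates is false here, so a TLS client that presents
+ // no certificate should still connect; only the server's identity is
+ // authenticated.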
SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); +} + +TEST_P(XdsServerSecurityTest, TestTls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); +} + +TEST_P(XdsServerSecurityTest, TestTlsWithIdentityPluginUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + FakeCertificateProvider::CertDataMap fake2_cert_map = { + {"", {root_cert_, identity_pair_2_}}}; + g_fake2_cert_data_map = &fake2_cert_map; + SetLdsUpdate("", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); + SetLdsUpdate("", "", "fake_plugin2", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_2_, {}); +} + +TEST_P(XdsServerSecurityTest, TestTlsWithIdentityCertificateNameUpdate) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}, + {"good", {root_cert_, identity_pair_2_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); + SetLdsUpdate("", "", "fake_plugin1", "good", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_2_, {}); +} + +TEST_P(XdsServerSecurityTest, TestFallback) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("", "", "", "", false); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerSecurityTest, TestMtlsToTls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateTlsChannel(); }, {}, {}, + true /* test_expects_failure */); + SetLdsUpdate("", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); +} + +TEST_P(XdsServerSecurityTest, TestTlsToMtls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateTlsChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerSecurityTest, TestMtlsToFallback) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); + SetLdsUpdate("", "", "", "", false); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerSecurityTest, TestFallbackToMtls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", 
{root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("", "", "", "", false); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); + SetLdsUpdate("fake_plugin1", "", "fake_plugin1", "", true); + SendRpc([this]() { return CreateMtlsChannel(); }, + server_authenticated_identity_, client_authenticated_identity_); +} + +TEST_P(XdsServerSecurityTest, TestTlsToFallback) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); + SetLdsUpdate("", "", "", "", false); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerSecurityTest, TestFallbackToTls) { + FakeCertificateProvider::CertDataMap fake1_cert_map = { + {"", {root_cert_, identity_pair_}}}; + g_fake1_cert_data_map = &fake1_cert_map; + SetLdsUpdate("", "", "", "", false); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); + SetLdsUpdate("", "", "fake_plugin1", "", false); + SendRpc([this]() { return CreateTlsChannel(); }, + server_authenticated_identity_, {}); +} + +class XdsEnabledServerStatusNotificationTest : public XdsServerSecurityTest { + protected: + void SetValidLdsUpdate() { SetLdsUpdate("", "", "", "", false); } + + void SetInvalidLdsUpdate() { + Listener listener; + listener.set_name(absl::StrCat( + "grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + balancers_[0]->ads_service()->SetLdsResource(listener); + } + + void UnsetLdsUpdate() { + balancers_[0]->ads_service()->UnsetResource( + kLdsTypeUrl, absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", + backends_[0]->port())); + } +}; + +TEST_P(XdsEnabledServerStatusNotificationTest, ServingStatus) { + SetValidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::OK); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsEnabledServerStatusNotificationTest, NotServingStatus) { + SetInvalidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::UNAVAILABLE); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsEnabledServerStatusNotificationTest, ErrorUpdateWhenAlreadyServing) { + SetValidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::OK); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); + // Invalid update does not lead to a change in the serving status. + SetInvalidLdsUpdate(); + do { + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); + } while (balancers_[0]->ads_service()->lds_response_state().state == + AdsServiceImpl::ResponseState::SENT); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? 
"[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::OK); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsEnabledServerStatusNotificationTest, + NotServingStatusToServingStatusTransition) { + SetInvalidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::UNAVAILABLE); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); + // Send a valid LDS update to change to serving status + SetValidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::OK); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +// This test verifies that the resource getting deleted when already serving +// results in future connections being dropped. +TEST_P(XdsEnabledServerStatusNotificationTest, + ServingStatusToNonServingStatusTransition) { + SetValidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::OK); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); + // Deleting the resource should result in a non-serving status. + UnsetLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::NOT_FOUND); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsEnabledServerStatusNotificationTest, RepeatedServingStatusChanges) { + for (int i = 0; i < 5; i++) { + // Send a valid LDS update to get the server to start listening + SetValidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", + backends_[0]->port()), + grpc::StatusCode::OK); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); + // Deleting the resource will make the server start rejecting connections + UnsetLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", + backends_[0]->port()), + grpc::StatusCode::NOT_FOUND); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); + } +} + +TEST_P(XdsEnabledServerStatusNotificationTest, ExistingRpcsOnResourceDeletion) { + // Send a valid LDS update to get the server to start listening + SetValidLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? 
"[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::OK); + constexpr int kNumChannels = 10; + struct StreamingRpc { + std::shared_ptr channel; + std::unique_ptr stub; + ClientContext context; + std::unique_ptr> stream; + } streaming_rpcs[kNumChannels]; + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + for (int i = 0; i < kNumChannels; i++) { + streaming_rpcs[i].channel = CreateInsecureChannel(); + streaming_rpcs[i].stub = + grpc::testing::EchoTestService::NewStub(streaming_rpcs[i].channel); + streaming_rpcs[i].context.set_wait_for_ready(true); + streaming_rpcs[i].stream = + streaming_rpcs[i].stub->BidiStream(&streaming_rpcs[i].context); + EXPECT_TRUE(streaming_rpcs[i].stream->Write(request)); + streaming_rpcs[i].stream->Read(&response); + EXPECT_EQ(request.message(), response.message()); + } + // Deleting the resource will make the server start rejecting connections + UnsetLdsUpdate(); + backends_[0]->notifier()->WaitOnServingStatusChange( + absl::StrCat(ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port()), + grpc::StatusCode::NOT_FOUND); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); + for (int i = 0; i < kNumChannels; i++) { + EXPECT_TRUE(streaming_rpcs[i].stream->Write(request)); + streaming_rpcs[i].stream->Read(&response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(streaming_rpcs[i].stream->WritesDone()); + auto status = streaming_rpcs[i].stream->Finish(); + EXPECT_TRUE(status.ok()) + << status.error_message() << ", " << status.error_details() << ", " + << streaming_rpcs[i].context.debug_error_string(); + // New RPCs on the existing channels should fail. + ClientContext new_context; + new_context.set_deadline(grpc_timeout_milliseconds_to_deadline(1000)); + EXPECT_FALSE( + streaming_rpcs[i].stub->Echo(&new_context, request, &response).ok()); + } +} + +using XdsServerFilterChainMatchTest = XdsServerSecurityTest; + +TEST_P(XdsServerFilterChainMatchTest, + DefaultFilterChainUsedWhenNoFilterChainMentioned) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + listener.mutable_default_filter_chain() + ->add_filters() + ->mutable_typed_config() + ->PackFrom(HttpConnectionManager()); + balancers_[0]->ads_service()->SetLdsResource(listener); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerFilterChainMatchTest, + DefaultFilterChainUsedWhenOtherFilterChainsDontMatch) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? 
"::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add a filter chain that will never get matched + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match() + ->mutable_destination_port() + ->set_value(8080); + // Add default filter chain that should get used + listener.mutable_default_filter_chain() + ->add_filters() + ->mutable_typed_config() + ->PackFrom(HttpConnectionManager()); + balancers_[0]->ads_service()->SetLdsResource(listener); + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsWithDestinationPortDontMatch) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with destination port that should never get matched + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match() + ->mutable_destination_port() + ->set_value(8080); + balancers_[0]->ads_service()->SetLdsResource(listener); + // RPC should fail since no matching filter chain was found and no default + // filter chain is configured. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerFilterChainMatchTest, FilterChainsWithServerNamesDontMatch) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with server name that should never get matched + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_server_names("server_name"); + balancers_[0]->ads_service()->SetLdsResource(listener); + // RPC should fail since no matching filter chain was found and no default + // filter chain is configured. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsWithTransportProtocolsOtherThanRawBufferDontMatch) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? 
"::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with transport protocol "tls" that should never match + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_transport_protocol("tls"); + balancers_[0]->ads_service()->SetLdsResource(listener); + // RPC should fail since no matching filter chain was found and no default + // filter chain is configured. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsWithApplicationProtocolsDontMatch) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with application protocol that should never get matched + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_application_protocols("h2"); + balancers_[0]->ads_service()->SetLdsResource(listener); + // RPC should fail since no matching filter chain was found and no default + // filter chain is configured. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}, + true /* test_expects_failure */); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsWithTransportProtocolRawBufferIsPreferred) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with "raw_buffer" transport protocol + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_transport_protocol( + "raw_buffer"); + // Add another filter chain with no transport protocol set but application + // protocol set (fails match) + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_application_protocols("h2"); + balancers_[0]->ads_service()->SetLdsResource(listener); + // A successful RPC proves that filter chains that mention "raw_buffer" as + // the transport protocol are chosen as the best match in the round. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsWithMoreSpecificDestinationPrefixRangesArePreferred) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? 
"::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with prefix range (length 4 and 16) but with server name + // mentioned. (Prefix range is matched first.) + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(4); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(16); + filter_chain->mutable_filter_chain_match()->add_server_names("server_name"); + // Add filter chain with two prefix ranges (length 8 and 24). Since 24 is + // the highest match, it should be chosen. + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(8); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(24); + // Add another filter chain with a non-matching prefix range (with length + // 30) + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix("192.168.1.1"); + prefix_range->mutable_prefix_len()->set_value(30); + filter_chain->mutable_filter_chain_match()->add_server_names("server_name"); + // Add another filter chain with no prefix range mentioned + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_server_names("server_name"); + balancers_[0]->ads_service()->SetLdsResource(listener); + // A successful RPC proves that the filter chain with the longest matching + // prefix range was the best match. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsThatMentionSourceTypeArePreferred) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with the local source type (best match) + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_source_type( + FilterChainMatch::SAME_IP_OR_LOOPBACK); + // Add filter chain with the external source type but bad source port. + // Note that backends_[0]->port() will never be a match for the source port + // because it is already being used by a backend. 
+ filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_source_type( + FilterChainMatch::EXTERNAL); + filter_chain->mutable_filter_chain_match()->add_source_ports( + backends_[0]->port()); + // Add filter chain with the default source type (ANY) but bad source port. + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_source_ports( + backends_[0]->port()); + balancers_[0]->ads_service()->SetLdsResource(listener); + // A successful RPC proves that the filter chain with the longest matching + // prefix range was the best match. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsWithMoreSpecificSourcePrefixRangesArePreferred) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with source prefix range (length 16) but with a bad + // source port mentioned. (Prefix range is matched first.) Note that + // backends_[0]->port() will never be a match for the source port because it + // is already being used by a backend. + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* source_prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + source_prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + source_prefix_range->mutable_prefix_len()->set_value(4); + source_prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + source_prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + source_prefix_range->mutable_prefix_len()->set_value(16); + filter_chain->mutable_filter_chain_match()->add_source_ports( + backends_[0]->port()); + // Add filter chain with two source prefix ranges (length 8 and 24). Since + // 24 is the highest match, it should be chosen. + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + source_prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + source_prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + source_prefix_range->mutable_prefix_len()->set_value(8); + source_prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + source_prefix_range->set_address_prefix(ipv6_only_ ? 
"::1" : "127.0.0.1"); + source_prefix_range->mutable_prefix_len()->set_value(24); + // Add another filter chain with a non-matching source prefix range (with + // length 30) and bad source port + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + source_prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + source_prefix_range->set_address_prefix("192.168.1.1"); + source_prefix_range->mutable_prefix_len()->set_value(30); + filter_chain->mutable_filter_chain_match()->add_source_ports( + backends_[0]->port()); + // Add another filter chain with no source prefix range mentioned and bad + // source port + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_source_ports( + backends_[0]->port()); + balancers_[0]->ads_service()->SetLdsResource(listener); + // A successful RPC proves that the filter chain with the longest matching + // source prefix range was the best match. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerFilterChainMatchTest, + FilterChainsWithMoreSpecificSourcePortArePreferred) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + // Since we don't know which port will be used by the channel, just add all + // ports except for 0. + for (int i = 1; i < 65536; i++) { + filter_chain->mutable_filter_chain_match()->add_source_ports(i); + } + // Add another filter chain with no source port mentioned with a bad + // DownstreamTlsContext configuration. + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* transport_socket = filter_chain->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.tls"); + DownstreamTlsContext downstream_tls_context; + downstream_tls_context.mutable_common_tls_context() + ->mutable_tls_certificate_provider_instance() + ->set_instance_name("fake_plugin1"); + transport_socket->mutable_typed_config()->PackFrom(downstream_tls_context); + balancers_[0]->ads_service()->SetLdsResource(listener); + // A successful RPC proves that the filter chain with matching source port + // was chosen. + SendRpc([this]() { return CreateInsecureChannel(); }, {}, {}); +} + +TEST_P(XdsServerFilterChainMatchTest, DuplicateMatchNacked) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? 
"::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + // Add a duplicate filter chain + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr( + "Duplicate matching rules detected when adding filter chain: {}")); +} + +TEST_P(XdsServerFilterChainMatchTest, DuplicateMatchOnPrefixRangesNacked) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with prefix range + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(16); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(24); + // Add a filter chain with a duplicate prefix range entry + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(16); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(32); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + if (ipv6_only_) { + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr( + "Duplicate matching rules detected when adding filter chain: " + "{prefix_ranges={{address_prefix=[::]:0, prefix_len=16}, " + "{address_prefix=[::]:0, prefix_len=32}}}")); + } else { + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr( + "Duplicate matching rules detected when adding filter chain: " + "{prefix_ranges={{address_prefix=127.0.0.0:0, prefix_len=16}, " + "{address_prefix=127.0.0.1:0, prefix_len=32}}}")); + } +} + +TEST_P(XdsServerFilterChainMatchTest, DuplicateMatchOnTransportProtocolNacked) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? 
"[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with "raw_buffer" transport protocol + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_transport_protocol( + "raw_buffer"); + // Add a duplicate filter chain with the same "raw_buffer" transport + // protocol entry + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_transport_protocol( + "raw_buffer"); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr("Duplicate matching rules detected when adding " + "filter chain: {transport_protocol=raw_buffer}")); +} + +TEST_P(XdsServerFilterChainMatchTest, DuplicateMatchOnLocalSourceTypeNacked) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with the local source type + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_source_type( + FilterChainMatch::SAME_IP_OR_LOOPBACK); + // Add a duplicate filter chain with the same local source type entry + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_source_type( + FilterChainMatch::SAME_IP_OR_LOOPBACK); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr("Duplicate matching rules detected when adding " + "filter chain: {source_type=SAME_IP_OR_LOOPBACK}")); +} + +TEST_P(XdsServerFilterChainMatchTest, + DuplicateMatchOnExternalSourceTypeNacked) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? 
"::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with the external source type + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_source_type( + FilterChainMatch::EXTERNAL); + // Add a duplicate filter chain with the same external source type entry + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->set_source_type( + FilterChainMatch::EXTERNAL); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr("Duplicate matching rules detected when adding " + "filter chain: {source_type=EXTERNAL}")); +} + +TEST_P(XdsServerFilterChainMatchTest, + DuplicateMatchOnSourcePrefixRangesNacked) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with source prefix range + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + auto* prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(16); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(24); + // Add a filter chain with a duplicate source prefix range entry + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? "::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(16); + prefix_range = + filter_chain->mutable_filter_chain_match()->add_source_prefix_ranges(); + prefix_range->set_address_prefix(ipv6_only_ ? 
"::1" : "127.0.0.1"); + prefix_range->mutable_prefix_len()->set_value(32); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + if (ipv6_only_) { + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr( + "Duplicate matching rules detected when adding filter chain: " + "{source_prefix_ranges={{address_prefix=[::]:0, prefix_len=16}, " + "{address_prefix=[::]:0, prefix_len=32}}}")); + } else { + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr( + "Duplicate matching rules detected when adding filter chain: " + "{source_prefix_ranges={{address_prefix=127.0.0.0:0, " + "prefix_len=16}, " + "{address_prefix=127.0.0.1:0, prefix_len=32}}}")); + } +} + +TEST_P(XdsServerFilterChainMatchTest, DuplicateMatchOnSourcePortNacked) { + Listener listener; + listener.set_name( + absl::StrCat("grpc/server?xds.resource.listening_address=", + ipv6_only_ ? "[::1]:" : "127.0.0.1:", backends_[0]->port())); + auto* socket_address = listener.mutable_address()->mutable_socket_address(); + socket_address->set_address(ipv6_only_ ? "::1" : "127.0.0.1"); + socket_address->set_port_value(backends_[0]->port()); + // Add filter chain with the external source type + auto* filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_source_ports(8080); + // Add a duplicate filter chain with the same source port entry + filter_chain = listener.add_filter_chains(); + filter_chain->add_filters()->mutable_typed_config()->PackFrom( + HttpConnectionManager()); + filter_chain->mutable_filter_chain_match()->add_source_ports(8080); + balancers_[0]->ads_service()->SetLdsResource(listener); + ASSERT_TRUE(WaitForLdsNack(StatusCode::DEADLINE_EXCEEDED)) + << "timed out waiting for NACK"; + EXPECT_THAT( + balancers_[0]->ads_service()->lds_response_state().error_message, + ::testing::HasSubstr("Duplicate matching rules detected when adding " + "filter chain: {source_ports={8080}}")); +} + +using EdsTest = BasicTest; + +// Tests that EDS client should send a NACK if the EDS update contains +// sparse priorities. +TEST_P(EdsTest, NacksSparsePriorityList) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(), kDefaultLocalityWeight, 1}, + }); + balancers_[0]->ads_service()->SetEdsResource(BuildEdsResource(args)); + ASSERT_TRUE(WaitForEdsNack()) << "timed out waiting for NACK"; + const auto response_state = + balancers_[0]->ads_service()->eds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_THAT(response_state.error_message, + ::testing::HasSubstr("sparse priority list")); +} + +// In most of our tests, we use different names for different resource +// types, to make sure that there are no cut-and-paste errors in the code +// that cause us to look at data for the wrong resource type. So we add +// this test to make sure that the EDS resource name defaults to the +// cluster name if not specified in the CDS resource. 
+TEST_P(EdsTest, EdsServiceNameDefaultsToClusterName) { + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, kDefaultClusterName)); + Cluster cluster = default_cluster_; + cluster.mutable_eds_cluster_config()->clear_service_name(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendOk(); +} + +class TimeoutTest : public BasicTest { + protected: + void SetUp() override { + xds_resource_does_not_exist_timeout_ms_ = 500; + BasicTest::SetUp(); + } +}; + +// Tests that LDS client times out when no response received. +TEST_P(TimeoutTest, Lds) { + balancers_[0]->ads_service()->IgnoreResourceType(kLdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +TEST_P(TimeoutTest, Rds) { + balancers_[0]->ads_service()->IgnoreResourceType(kRdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +// Tests that CDS client times out when no response received. +TEST_P(TimeoutTest, Cds) { + balancers_[0]->ads_service()->IgnoreResourceType(kCdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +TEST_P(TimeoutTest, Eds) { + balancers_[0]->ads_service()->IgnoreResourceType(kEdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +using LocalityMapTest = BasicTest; + +// Tests that the localities in a locality map are picked according to their +// weights. +TEST_P(LocalityMapTest, WeightedRoundRobin) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const int kLocalityWeight0 = 2; + const int kLocalityWeight1 = 8; + const int kTotalLocalityWeight = kLocalityWeight0 + kLocalityWeight1; + const double kLocalityWeightRate0 = + static_cast(kLocalityWeight0) / kTotalLocalityWeight; + const double kLocalityWeightRate1 = + static_cast(kLocalityWeight1) / kTotalLocalityWeight; + const double kErrorTolerance = 0.05; + const size_t kNumRpcs = + ComputeIdealNumRpcs(kLocalityWeightRate0, kErrorTolerance); + // ADS response contains 2 localities, each of which contains 1 backend. + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kLocalityWeight0}, + {"locality1", CreateEndpointsForBackends(1, 2), kLocalityWeight1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for both backends to be ready. + WaitForAllBackends(0, 2); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + // The locality picking rates should be roughly equal to the expectation. + const double locality_picked_rate_0 = + static_cast(backends_[0]->backend_service()->request_count()) / + kNumRpcs; + const double locality_picked_rate_1 = + static_cast(backends_[1]->backend_service()->request_count()) / + kNumRpcs; + EXPECT_THAT(locality_picked_rate_0, + ::testing::DoubleNear(kLocalityWeightRate0, kErrorTolerance)); + EXPECT_THAT(locality_picked_rate_1, + ::testing::DoubleNear(kLocalityWeightRate1, kErrorTolerance)); +} + +// Tests that we correctly handle a locality containing no endpoints. +TEST_P(LocalityMapTest, LocalityContainingNoEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 5000; + // EDS response contains 2 localities, one with no endpoints. 
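+ // Worked expectation: locality1 is empty, so all kNumRpcs RPCs are spread
+ // round-robin across locality0's endpoints only. With the 4 backends
+ // checked below, each should see 5000 / 4 = 1250 requests, which is what
+ // the EXPECT_EQ checks below assert.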
+ EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + {"locality1", {}}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for both backends to be ready. + WaitForAllBackends(); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + // All traffic should go to the reachable locality. + EXPECT_EQ(backends_[0]->backend_service()->request_count(), + kNumRpcs / backends_.size()); + EXPECT_EQ(backends_[1]->backend_service()->request_count(), + kNumRpcs / backends_.size()); + EXPECT_EQ(backends_[2]->backend_service()->request_count(), + kNumRpcs / backends_.size()); + EXPECT_EQ(backends_[3]->backend_service()->request_count(), + kNumRpcs / backends_.size()); +} + +// EDS update with no localities. +TEST_P(LocalityMapTest, NoLocalities) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource({}, DefaultEdsServiceName())); + Status status = SendRpc(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), StatusCode::UNAVAILABLE); +} + +// Tests that the locality map can work properly even when it contains a large +// number of localities. +TEST_P(LocalityMapTest, StressTest) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumLocalities = 100; + const uint32_t kRpcTimeoutMs = 5000; + // The first ADS response contains kNumLocalities localities, each of which + // contains backend 0. + EdsResourceArgs args; + for (size_t i = 0; i < kNumLocalities; ++i) { + std::string name = absl::StrCat("locality", i); + EdsResourceArgs::Locality locality(name, CreateEndpointsForBackends(0, 1)); + args.locality_list.emplace_back(std::move(locality)); + } + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // The second ADS response contains 1 locality, which contains backend 1. + args = EdsResourceArgs({ + {"locality0", CreateEndpointsForBackends(1, 2)}, + }); + std::thread delayed_resource_setter( + std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0, + BuildEdsResource(args, DefaultEdsServiceName()), 60 * 1000)); + // Wait until backend 0 is ready, before which kNumLocalities localities are + // received and handled by the xds policy. + WaitForBackend(0, WaitForBackendOptions().set_reset_counters(false), + RpcOptions().set_timeout_ms(kRpcTimeoutMs)); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + // Wait until backend 1 is ready, before which kNumLocalities localities are + // removed by the xds policy. + WaitForBackend(1); + delayed_resource_setter.join(); +} + +// Tests that the localities in a locality map are picked correctly after +// update (addition, modification, deletion). +TEST_P(LocalityMapTest, UpdateMap) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 3000; + // The locality weight for the first 3 localities. + const std::vector kLocalityWeights0 = {2, 3, 4}; + const double kTotalLocalityWeight0 = + std::accumulate(kLocalityWeights0.begin(), kLocalityWeights0.end(), 0); + std::vector locality_weight_rate_0; + locality_weight_rate_0.reserve(kLocalityWeights0.size()); + for (int weight : kLocalityWeights0) { + locality_weight_rate_0.push_back(weight / kTotalLocalityWeight0); + } + // Delete the first locality, keep the second locality, change the third + // locality's weight from 4 to 2, and add a new locality with weight 6. 
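+ // Worked weights: the first map uses {2, 3, 4} / 9 ~ {0.22, 0.33, 0.44}
+ // for localities 0-2; the second uses {3, 2, 6} / 11 ~ {0.27, 0.18, 0.55}
+ // for localities 1-3. These are the locality_weight_rate_0 /
+ // locality_weight_rate_1 vectors computed just above and below, and the
+ // observed pick rates are later checked against them with a 20% relative
+ // tolerance (kErrorTolerance = 0.2).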
+ const std::vector kLocalityWeights1 = {3, 2, 6}; + const double kTotalLocalityWeight1 = + std::accumulate(kLocalityWeights1.begin(), kLocalityWeights1.end(), 0); + std::vector locality_weight_rate_1 = { + 0 /* placeholder for locality 0 */}; + for (int weight : kLocalityWeights1) { + locality_weight_rate_1.push_back(weight / kTotalLocalityWeight1); + } + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), 2}, + {"locality1", CreateEndpointsForBackends(1, 2), 3}, + {"locality2", CreateEndpointsForBackends(2, 3), 4}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for the first 3 backends to be ready. + WaitForAllBackends(0, 3); + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // The picking rates of the first 3 backends should be roughly equal to the + // expectation. + std::vector locality_picked_rates; + for (size_t i = 0; i < 3; ++i) { + locality_picked_rates.push_back( + static_cast(backends_[i]->backend_service()->request_count()) / + kNumRpcs); + } + const double kErrorTolerance = 0.2; + for (size_t i = 0; i < 3; ++i) { + gpr_log(GPR_INFO, "Locality %" PRIuPTR " rate %f", i, + locality_picked_rates[i]); + EXPECT_THAT( + locality_picked_rates[i], + ::testing::AllOf( + ::testing::Ge(locality_weight_rate_0[i] * (1 - kErrorTolerance)), + ::testing::Le(locality_weight_rate_0[i] * (1 + kErrorTolerance)))); + } + args = EdsResourceArgs({ + {"locality1", CreateEndpointsForBackends(1, 2), 3}, + {"locality2", CreateEndpointsForBackends(2, 3), 2}, + {"locality3", CreateEndpointsForBackends(3, 4), 6}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Backend 3 hasn't received any request. + EXPECT_EQ(0U, backends_[3]->backend_service()->request_count()); + // Wait until the locality update has been processed, as signaled by backend + // 3 receiving a request. + WaitForAllBackends(3, 4); + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // Backend 0 no longer receives any request. + EXPECT_EQ(0U, backends_[0]->backend_service()->request_count()); + // The picking rates of the last 3 backends should be roughly equal to the + // expectation. + locality_picked_rates = {0 /* placeholder for backend 0 */}; + for (size_t i = 1; i < 4; ++i) { + locality_picked_rates.push_back( + static_cast(backends_[i]->backend_service()->request_count()) / + kNumRpcs); + } + for (size_t i = 1; i < 4; ++i) { + gpr_log(GPR_INFO, "Locality %" PRIuPTR " rate %f", i, + locality_picked_rates[i]); + EXPECT_THAT( + locality_picked_rates[i], + ::testing::AllOf( + ::testing::Ge(locality_weight_rate_1[i] * (1 - kErrorTolerance)), + ::testing::Le(locality_weight_rate_1[i] * (1 + kErrorTolerance)))); + } +} + +// Tests that we don't fail RPCs when replacing all of the localities in +// a given priority. 
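+// Note: the "no failed RPCs" guarantee below is enforced implicitly:
+// WaitForBackend is called without set_allow_failures(true), so any RPC
+// failure during the locality0 -> locality1 switchover would presumably
+// fail the test (contrast FailoverTest.SwitchBackToHigherPriority, which
+// has to opt in to failures explicitly).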
+TEST_P(LocalityMapTest, ReplaceAllLocalitiesInPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + args = EdsResourceArgs({ + {"locality1", CreateEndpointsForBackends(1, 2)}, + }); + std::thread delayed_resource_setter( + std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0, + BuildEdsResource(args, DefaultEdsServiceName()), 5000)); + // Wait for the first backend to be ready. + WaitForBackend(0); + // Keep sending RPCs until we switch over to backend 1, which tells us + // that we received the update. No RPCs should fail during this + // transition. + WaitForBackend(1); + delayed_resource_setter.join(); +} + +class FailoverTest : public BasicTest { + public: + void SetUp() override { + BasicTest::SetUp(); + ResetStub(500); + } +}; + +// Localities with the highest priority are used when multiple priority exist. +TEST_P(FailoverTest, ChooseHighestPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 1}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 2}, + {"locality2", CreateEndpointsForBackends(2, 3), kDefaultLocalityWeight, + 3}, + {"locality3", CreateEndpointsForBackends(3, 4), kDefaultLocalityWeight, + 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(3, WaitForBackendOptions().set_reset_counters(false)); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } +} + +// Does not choose priority with no endpoints. +TEST_P(FailoverTest, DoesNotUsePriorityWithNoEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 1}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 2}, + {"locality2", CreateEndpointsForBackends(2, 3), kDefaultLocalityWeight, + 3}, + {"locality3", {}, kDefaultLocalityWeight, 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(0, WaitForBackendOptions().set_reset_counters(false)); + for (size_t i = 1; i < 3; ++i) { + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } +} + +// Does not choose locality with no endpoints. +TEST_P(FailoverTest, DoesNotUseLocalityWithNoEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", {}, kDefaultLocalityWeight, 0}, + {"locality1", CreateEndpointsForBackends(), kDefaultLocalityWeight, 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for all backends to be used. + std::tuple counts = WaitForAllBackends(); + // Make sure no RPCs failed in the transition. + EXPECT_EQ(0, std::get<1>(counts)); +} + +// If the higher priority localities are not reachable, failover to the +// highest priority among the rest. 
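+// Worked example for the test below: priorities map to
+//   priority 0 -> locality3 (backend 3), priority 1 -> locality0 (backend 0),
+//   priority 2 -> locality1 (backend 1), priority 3 -> locality2 (backend 2).
+// Backends 3 and 0 are shut down, so the highest reachable priority is 2,
+// which is why all traffic is expected to land on backend 1.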
+TEST_P(FailoverTest, Failover) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 1}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 2}, + {"locality2", CreateEndpointsForBackends(2, 3), kDefaultLocalityWeight, + 3}, + {"locality3", CreateEndpointsForBackends(3, 4), kDefaultLocalityWeight, + 0}, + }); + ShutdownBackend(3); + ShutdownBackend(0); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(1, WaitForBackendOptions().set_reset_counters(false)); + for (size_t i = 0; i < 4; ++i) { + if (i == 1) continue; + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } +} + +// If a locality with higher priority than the current one becomes ready, +// switch to it. +TEST_P(FailoverTest, SwitchBackToHigherPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 100; + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 1}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 2}, + {"locality2", CreateEndpointsForBackends(2, 3), kDefaultLocalityWeight, + 3}, + {"locality3", CreateEndpointsForBackends(3, 4), kDefaultLocalityWeight, + 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(3); + ShutdownBackend(3); + ShutdownBackend(0); + WaitForBackend( + 1, WaitForBackendOptions().set_reset_counters(false).set_allow_failures( + true)); + for (size_t i = 0; i < 4; ++i) { + if (i == 1) continue; + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } + StartBackend(0); + WaitForBackend(0); + CheckRpcSendOk(kNumRpcs); + EXPECT_EQ(kNumRpcs, backends_[0]->backend_service()->request_count()); +} + +// The first update only contains unavailable priorities. The second update +// contains available priorities. +TEST_P(FailoverTest, UpdateInitialUnavailable) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 0}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + args = EdsResourceArgs({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 0}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 1}, + {"locality2", CreateEndpointsForBackends(2, 3), kDefaultLocalityWeight, + 2}, + {"locality3", CreateEndpointsForBackends(3, 4), kDefaultLocalityWeight, + 3}, + }); + ShutdownBackend(0); + ShutdownBackend(1); + std::thread delayed_resource_setter( + std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0, + BuildEdsResource(args, DefaultEdsServiceName()), 1000)); + gpr_timespec deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(500, GPR_TIMESPAN)); + // Send 0.5 second worth of RPCs. 
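+ // Every RPC in this window is expected to fail: both priorities in the
+ // initial update point at backends 0 and 1, which were just shut down, and
+ // the update that adds reachable priorities 2 and 3 is delayed by about
+ // one second (SetEdsResourceWithDelay above), so nothing is reachable yet.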
+ do { + CheckRpcSendFailure(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + WaitForBackend( + 2, WaitForBackendOptions().set_reset_counters(false).set_allow_failures( + true)); + for (size_t i = 0; i < 4; ++i) { + if (i == 2) continue; + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } + delayed_resource_setter.join(); +} + +// Tests that after the localities' priorities are updated, we still choose +// the highest READY priority with the updated localities. +TEST_P(FailoverTest, UpdatePriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 100; + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 1}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 2}, + {"locality2", CreateEndpointsForBackends(2, 3), kDefaultLocalityWeight, + 3}, + {"locality3", CreateEndpointsForBackends(3, 4), kDefaultLocalityWeight, + 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + args = EdsResourceArgs({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 2}, + {"locality1", CreateEndpointsForBackends(1, 2), kDefaultLocalityWeight, + 0}, + {"locality2", CreateEndpointsForBackends(2, 3), kDefaultLocalityWeight, + 1}, + {"locality3", CreateEndpointsForBackends(3, 4), kDefaultLocalityWeight, + 3}, + }); + std::thread delayed_resource_setter( + std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0, + BuildEdsResource(args, DefaultEdsServiceName()), 1000)); + WaitForBackend(3, WaitForBackendOptions().set_reset_counters(false)); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } + WaitForBackend(1); + CheckRpcSendOk(kNumRpcs); + EXPECT_EQ(kNumRpcs, backends_[1]->backend_service()->request_count()); + delayed_resource_setter.join(); +} + +// Moves all localities in the current priority to a higher priority. +TEST_P(FailoverTest, MoveAllLocalitiesInCurrentPriorityToHigherPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // First update: + // - Priority 0 is locality 0, containing backend 0, which is down. + // - Priority 1 is locality 1, containing backends 1 and 2, which are up. + ShutdownBackend(0); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 0}, + {"locality1", CreateEndpointsForBackends(1, 3), kDefaultLocalityWeight, + 1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Second update: + // - Priority 0 contains both localities 0 and 1. + // - Priority 1 is not present. + // - We add backend 3 to locality 1, just so we have a way to know + // when the update has been seen by the client. + args = EdsResourceArgs({ + {"locality0", CreateEndpointsForBackends(0, 1), kDefaultLocalityWeight, + 0}, + {"locality1", CreateEndpointsForBackends(1, 4), kDefaultLocalityWeight, + 0}, + }); + std::thread delayed_resource_setter( + std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0, + BuildEdsResource(args, DefaultEdsServiceName()), 1000)); + // When we get the first update, all backends in priority 0 are down, + // so we will create priority 1. Backends 1 and 2 should have traffic, + // but backend 3 should not. 
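+ // Note: the (1, 3) range below is half-open, so this waits on backends 1
+ // and 2 only; backend 3 is then asserted to have seen no traffic yet,
+ // which is the signal distinguishing the first update from the second.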
+  WaitForAllBackends(1, 3, WaitForBackendOptions().set_reset_counters(false));
+  EXPECT_EQ(0UL, backends_[3]->backend_service()->request_count());
+  // When backend 3 gets traffic, we know the second update has been seen.
+  WaitForBackend(3);
+  // The ADS service of balancer 0 got at least 1 response.
+  EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  delayed_resource_setter.join();
+}
+
+using DropTest = BasicTest;
+
+// Tests that RPCs are dropped according to the drop config.
+TEST_P(DropTest, Vanilla) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 200000;
+  const double kDropRateForLb = kDropPerMillionForLb / 1000000.0;
+  const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0;
+  const double kDropRateForLbAndThrottle =
+      kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs =
+      ComputeIdealNumRpcs(kDropRateForLbAndThrottle, kErrorTolerance);
+  // The ADS response contains two drop categories.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Send kNumRpcs RPCs and count the drops.
+  size_t num_drops = 0;
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        absl::StartsWith(status.error_message(), "EDS-configured drop: ")) {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  // The drop rate should be roughly equal to the expectation.
+  const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs;
+  EXPECT_THAT(seen_drop_rate, ::testing::DoubleNear(kDropRateForLbAndThrottle,
+                                                    kErrorTolerance));
+}
+
+// Tests that drop config is converted correctly from per hundred.
+TEST_P(DropTest, DropPerHundred) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const uint32_t kDropPerHundredForLb = 10;
+  const double kDropRateForLb = kDropPerHundredForLb / 100.0;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kDropRateForLb, kErrorTolerance);
+  // The ADS response contains one drop category.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerHundredForLb}};
+  args.drop_denominator = FractionalPercent::HUNDRED;
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Send kNumRpcs RPCs and count the drops.
+  size_t num_drops = 0;
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        absl::StartsWith(status.error_message(), "EDS-configured drop: ")) {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  // The drop rate should be roughly equal to the expectation.
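+  // With kDropPerHundredForLb = 10, the expected drop rate is 10 / 100 = 0.10,
+  // and the observed rate should land within kErrorTolerance (0.05) of it.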
+  const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs;
+  EXPECT_THAT(seen_drop_rate,
+              ::testing::DoubleNear(kDropRateForLb, kErrorTolerance));
+}
+
+// Tests that drop config is converted correctly from per ten thousand.
+TEST_P(DropTest, DropPerTenThousand) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const uint32_t kDropPerTenThousandForLb = 1000;
+  const double kDropRateForLb = kDropPerTenThousandForLb / 10000.0;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kDropRateForLb, kErrorTolerance);
+  // The ADS response contains one drop category.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerTenThousandForLb}};
+  args.drop_denominator = FractionalPercent::TEN_THOUSAND;
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Send kNumRpcs RPCs and count the drops.
+  size_t num_drops = 0;
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        absl::StartsWith(status.error_message(), "EDS-configured drop: ")) {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  // The drop rate should be roughly equal to the expectation.
+  const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs;
+  EXPECT_THAT(seen_drop_rate,
+              ::testing::DoubleNear(kDropRateForLb, kErrorTolerance));
+}
+
+// Tests that drop is working correctly after update.
+TEST_P(DropTest, Update) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 200000;
+  const double kErrorTolerance = 0.05;
+  const double kDropRateForLb = kDropPerMillionForLb / 1000000.0;
+  const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0;
+  const double kDropRateForLbAndThrottle =
+      kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle;
+  const size_t kNumRpcsLbOnly =
+      ComputeIdealNumRpcs(kDropRateForLb, kErrorTolerance);
+  const size_t kNumRpcsBoth =
+      ComputeIdealNumRpcs(kDropRateForLbAndThrottle, kErrorTolerance);
+  // The first ADS response contains one drop category.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Send kNumRpcsLbOnly RPCs and count the drops.
+  size_t num_drops = 0;
+  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+  for (size_t i = 0; i < kNumRpcsLbOnly; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        absl::StartsWith(status.error_message(), "EDS-configured drop: ")) {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+  // The drop rate should be roughly equal to the expectation.
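+  // For reference: with the constants above, the first batch should show a
+  // drop rate near kDropRateForLb = 0.1, and once the update adds the throttle
+  // category it should rise toward 0.1 + 0.9 * 0.2 = 0.28.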
+  double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcsLbOnly;
+  gpr_log(GPR_INFO, "First batch drop rate %f", seen_drop_rate);
+  EXPECT_THAT(seen_drop_rate,
+              ::testing::DoubleNear(kDropRateForLb, kErrorTolerance));
+  // The second ADS response contains two drop categories; send an updated EDS
+  // response.
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until the drop rate increases to the middle of the two configs,
+  // which implies that the update has taken effect.
+  const double kDropRateThreshold =
+      (kDropRateForLb + kDropRateForLbAndThrottle) / 2;
+  size_t num_rpcs = kNumRpcsBoth;
+  while (seen_drop_rate < kDropRateThreshold) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    ++num_rpcs;
+    if (!status.ok() &&
+        absl::StartsWith(status.error_message(), "EDS-configured drop: ")) {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+    seen_drop_rate = static_cast<double>(num_drops) / num_rpcs;
+  }
+  // Send kNumRpcsBoth RPCs and count the drops.
+  num_drops = 0;
+  gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
+  for (size_t i = 0; i < kNumRpcsBoth; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        absl::StartsWith(status.error_message(), "EDS-configured drop: ")) {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");
+  // The new drop rate should be roughly equal to the expectation.
+  seen_drop_rate = static_cast<double>(num_drops) / kNumRpcsBoth;
+  gpr_log(GPR_INFO, "Second batch drop rate %f", seen_drop_rate);
+  EXPECT_THAT(seen_drop_rate, ::testing::DoubleNear(kDropRateForLbAndThrottle,
+                                                    kErrorTolerance));
+}
+
+// Tests that all the RPCs are dropped if any drop category drops 100%.
+TEST_P(DropTest, DropAll) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 1000;
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 1000000;
+  // The ADS response contains two drop categories.
+  EdsResourceArgs args;
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Send kNumRpcs RPCs; all of them should be dropped.
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    EXPECT_EQ(status.error_code(), StatusCode::UNAVAILABLE);
+    EXPECT_THAT(status.error_message(),
+                ::testing::StartsWith("EDS-configured drop: "));
+  }
+}
+
+class BalancerUpdateTest : public XdsEnd2endTest {
+ public:
+  BalancerUpdateTest() : XdsEnd2endTest(4, 3) {}
+};
+
+// Tests that the old LB call is still used after the balancer address update
+// as long as that call is still alive.
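+// ("LB call" here presumably refers to the client's ADS stream: as long as
+// that stream to balancer 0 stays alive, the client keeps using it even after
+// the resolver update points the LB channel at balancer 1.)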
+TEST_P(BalancerUpdateTest, UpdateBalancersButKeepUsingOriginalBalancer) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + args = EdsResourceArgs({{"locality0", CreateEndpointsForBackends(1, 2)}}); + balancers_[1]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until the first backend is ready. + WaitForBackend(0); + // Send 10 requests. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->backend_service()->request_count()); + // The ADS service of balancer 0 sent at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolutionForLbChannel({balancers_[1]->port()}); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // The current LB call is still working, so xds continued using it to the + // first balancer, which doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + // The ADS service of balancer 0 sent at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; +} + +// Tests that the old LB call is still used after multiple balancer address +// updates as long as that call is still alive. Send an update with the same +// set of LBs as the one in SetUp() in order to verify that the LB channel +// inside xds keeps the initial connection (which by definition is also +// present in the update). 
+TEST_P(BalancerUpdateTest, Repeated) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}});
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  args = EdsResourceArgs({{"locality0", CreateEndpointsForBackends(1, 2)}});
+  balancers_[1]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until the first backend is ready.
+  WaitForBackend(0);
+  // Send 10 requests.
+  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+  CheckRpcSendOk(10);
+  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+  // All 10 requests should have gone to the first backend.
+  EXPECT_EQ(10U, backends_[0]->backend_service()->request_count());
+  // The ADS service of balancer 0 sent at least 1 response.
+  EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[1]->ads_service()->eds_response_state().error_message;
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+  std::vector<int> ports;
+  ports.emplace_back(balancers_[0]->port());
+  ports.emplace_back(balancers_[1]->port());
+  ports.emplace_back(balancers_[2]->port());
+  gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 ==========");
+  SetNextResolutionForLbChannel(ports);
+  gpr_log(GPR_INFO, "========= UPDATE 1 DONE ==========");
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  gpr_timespec deadline = gpr_time_add(
+      gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN));
+  // Send 10 seconds worth of RPCs.
+  do {
+    CheckRpcSendOk();
+  } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+  // xds continued using the original LB call to the first balancer, which
+  // doesn't assign the second backend.
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  ports.clear();
+  ports.emplace_back(balancers_[0]->port());
+  ports.emplace_back(balancers_[1]->port());
+  gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 2 ==========");
+  SetNextResolutionForLbChannel(ports);
+  gpr_log(GPR_INFO, "========= UPDATE 2 DONE ==========");
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+                          gpr_time_from_millis(10000, GPR_TIMESPAN));
+  // Send 10 seconds worth of RPCs.
+  do {
+    CheckRpcSendOk();
+  } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+  // xds continued using the original LB call to the first balancer, which
+  // doesn't assign the second backend.
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+}
+
+// Tests that if the balancer is down, the RPCs will still be sent to the
+// backends according to the last balancer response, until a new balancer is
+// reachable.
+TEST_P(BalancerUpdateTest, DeadUpdate) { + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + args = EdsResourceArgs({{"locality0", CreateEndpointsForBackends(1, 2)}}); + balancers_[1]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->backend_service()->request_count()); + // The ADS service of balancer 0 sent at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; + // Kill balancer 0 + gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************"); + balancers_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************"); + // This is serviced by the existing child policy. + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should again have gone to the first backend. + EXPECT_EQ(20U, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + // The ADS service of no balancers sent anything + EXPECT_EQ(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[0]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolutionForLbChannel({balancers_[1]->port()}); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + // Wait until update has been processed, as signaled by the second backend + // receiving a request. In the meantime, the client continues to be serviced + // (by the first backend) without interruption. + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + WaitForBackend(1); + // This is serviced by the updated RR policy + backends_[1]->backend_service()->ResetCounters(); + gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH =========="); + // All 10 requests should have gone to the second backend. 
+  EXPECT_EQ(10U, backends_[1]->backend_service()->request_count());
+  // The ADS service of balancer 1 sent at least 1 response.
+  EXPECT_EQ(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[0]->ads_service()->eds_response_state().error_message;
+  EXPECT_GT(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+}
+
+class ClientLoadReportingTest : public XdsEnd2endTest {
+ public:
+  ClientLoadReportingTest() : XdsEnd2endTest(4, 1, 3) {}
+};
+
+// Tests that the load report received at the balancer is correct.
+TEST_P(ClientLoadReportingTest, Vanilla) {
+  if (GetParam().use_fake_resolver()) {
+    balancers_[0]->lrs_service()->set_cluster_names({kServerName});
+  }
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  const size_t kNumRpcsPerAddress = 10;
+  const size_t kNumFailuresPerAddress = 3;
+  // TODO(juanlishen): Partition the backends after multiple localities is
+  // tested.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until all backends are ready.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  // Send kNumRpcsPerAddress RPCs per server.
+  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+  CheckRpcSendFailure(CheckRpcSendFailureOptions()
+                          .set_times(kNumFailuresPerAddress * num_backends_)
+                          .set_rpc_options(RpcOptions().set_server_fail(true)));
+  // Check that each backend got the right number of requests.
+  for (size_t i = 0; i < backends_.size(); ++i) {
+    EXPECT_EQ(kNumRpcsPerAddress + kNumFailuresPerAddress,
+              backends_[i]->backend_service()->request_count());
+  }
+  // The load report received at the balancer should be correct.
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  ClientStats& client_stats = load_report.front();
+  EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok,
+            client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ((kNumRpcsPerAddress + kNumFailuresPerAddress) * num_backends_ +
+                num_ok + num_failure,
+            client_stats.total_issued_requests());
+  EXPECT_EQ(kNumFailuresPerAddress * num_backends_ + num_failure,
+            client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+  // The LRS service got a single request, and sent a single response.
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count());
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count());
+}
+
+// Tests send_all_clusters.
+TEST_P(ClientLoadReportingTest, SendAllClusters) {
+  balancers_[0]->lrs_service()->set_send_all_clusters(true);
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  const size_t kNumRpcsPerAddress = 10;
+  const size_t kNumFailuresPerAddress = 3;
+  // TODO(juanlishen): Partition the backends after multiple localities is
+  // tested.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until all backends are ready.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  // Send kNumRpcsPerAddress RPCs per server.
+  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+  CheckRpcSendFailure(CheckRpcSendFailureOptions()
+                          .set_times(kNumFailuresPerAddress * num_backends_)
+                          .set_rpc_options(RpcOptions().set_server_fail(true)));
+  // Check that each backend got the right number of requests.
+  for (size_t i = 0; i < backends_.size(); ++i) {
+    EXPECT_EQ(kNumRpcsPerAddress + kNumFailuresPerAddress,
+              backends_[i]->backend_service()->request_count());
+  }
+  // The load report received at the balancer should be correct.
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  ClientStats& client_stats = load_report.front();
+  EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok,
+            client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ((kNumRpcsPerAddress + kNumFailuresPerAddress) * num_backends_ +
+                num_ok + num_failure,
+            client_stats.total_issued_requests());
+  EXPECT_EQ(kNumFailuresPerAddress * num_backends_ + num_failure,
+            client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+  // The LRS service got a single request, and sent a single response.
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count());
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count());
+}
+
+// Tests that we don't include stats for clusters that are not requested
+// by the LRS server.
+TEST_P(ClientLoadReportingTest, HonorsClustersRequestedByLrsServer) {
+  balancers_[0]->lrs_service()->set_cluster_names({"bogus"});
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  const size_t kNumRpcsPerAddress = 100;
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until all backends are ready.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  // Send kNumRpcsPerAddress RPCs per server.
+  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+  // Each backend should have gotten 100 requests.
+  for (size_t i = 0; i < backends_.size(); ++i) {
+    EXPECT_EQ(kNumRpcsPerAddress,
+              backends_[i]->backend_service()->request_count());
+  }
+  // The LRS service got a single request, and sent a single response.
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count());
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count());
+  // The load report received at the balancer should be correct.
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 0UL);
+}
+
+// Tests that if the balancer restarts, the client load report contains the
+// stats before and after the restart correctly.
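+// (Every successful warm-up and probe RPC is accumulated into num_started
+// below, so the final total_successful_requests check covers traffic sent
+// both before and after the restart.)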
+TEST_P(ClientLoadReportingTest, BalancerRestart) {
+  if (GetParam().use_fake_resolver()) {
+    balancers_[0]->lrs_service()->set_cluster_names({kServerName});
+  }
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  const size_t kNumBackendsFirstPass = backends_.size() / 2;
+  const size_t kNumBackendsSecondPass =
+      backends_.size() - kNumBackendsFirstPass;
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends(0, kNumBackendsFirstPass)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until all backends returned by the balancer are ready.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) =
+      WaitForAllBackends(/* start_index */ 0,
+                         /* stop_index */ kNumBackendsFirstPass);
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  ClientStats client_stats = std::move(load_report.front());
+  EXPECT_EQ(static_cast<size_t>(num_ok),
+            client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ(0U, client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+  // Shut down the balancer.
+  balancers_[0]->Shutdown();
+  // We should continue using the last EDS response we received from the
+  // balancer before it was shut down.
+  // Note: We need to use WaitForAllBackends() here instead of just
+  // CheckRpcSendOk(kNumBackendsFirstPass), because when the balancer
+  // shuts down, the XdsClient will generate an error to the
+  // ServiceConfigWatcher, which will cause the xds resolver to send a
+  // no-op update to the LB policy. When this update gets down to the
+  // round_robin child policy for the locality, it will generate a new
+  // subchannel list, which resets the start index randomly. So we need
+  // to be a little more permissive here to avoid spurious failures.
+  ResetBackendCounters();
+  int num_started = std::get<0>(WaitForAllBackends(
+      /* start_index */ 0, /* stop_index */ kNumBackendsFirstPass));
+  // Now restart the balancer, this time pointing to the new backends.
+  balancers_[0]->Start();
+  args = EdsResourceArgs({
+      {"locality0", CreateEndpointsForBackends(kNumBackendsFirstPass)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait for queries to start going to one of the new backends.
+  // This tells us that we're now using the new serverlist.
+  std::tie(num_ok, num_failure, num_drops) =
+      WaitForAllBackends(/* start_index */ kNumBackendsFirstPass);
+  num_started += num_ok + num_failure + num_drops;
+  // Send one RPC per backend.
+  CheckRpcSendOk(kNumBackendsSecondPass);
+  num_started += kNumBackendsSecondPass;
+  // Check client stats.
+  load_report = balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  client_stats = std::move(load_report.front());
+  EXPECT_EQ(num_started, client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ(0U, client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+}
+
+class ClientLoadReportingWithDropTest : public XdsEnd2endTest {
+ public:
+  ClientLoadReportingWithDropTest() : XdsEnd2endTest(4, 1, 20) {}
+};
+
+// Tests that the drop stats are correctly reported by client load reporting.
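+// (Note for the per-category checks below: the throttle-category counter is
+// compared against total_rpc * (1 - kDropRateForLb), since throttle drops are
+// only applied to RPCs that survived the LB drop category.)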
+TEST_P(ClientLoadReportingWithDropTest, Vanilla) {
+  if (GetParam().use_fake_resolver()) {
+    balancers_[0]->lrs_service()->set_cluster_names({kServerName});
+  }
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 200000;
+  const double kErrorTolerance = 0.05;
+  const double kDropRateForLb = kDropPerMillionForLb / 1000000.0;
+  const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0;
+  const double kDropRateForLbAndThrottle =
+      kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle;
+  const size_t kNumRpcs =
+      ComputeIdealNumRpcs(kDropRateForLbAndThrottle, kErrorTolerance);
+  // The ADS response contains two drop categories.
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  const size_t num_warmup = num_ok + num_failure + num_drops;
+  // Send kNumRpcs RPCs and count the drops.
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        absl::StartsWith(status.error_message(), "EDS-configured drop: ")) {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  // The drop rate should be roughly equal to the expectation.
+  const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs;
+  EXPECT_THAT(seen_drop_rate, ::testing::DoubleNear(kDropRateForLbAndThrottle,
+                                                    kErrorTolerance));
+  // Check client stats.
+  const size_t total_rpc = num_warmup + kNumRpcs;
+  ClientStats client_stats;
+  do {
+    std::vector<ClientStats> load_reports =
+        balancers_[0]->lrs_service()->WaitForLoadReport();
+    for (const auto& load_report : load_reports) {
+      client_stats += load_report;
+    }
+  } while (client_stats.total_issued_requests() +
+               client_stats.total_dropped_requests() <
+           total_rpc);
+  EXPECT_EQ(num_drops, client_stats.total_dropped_requests());
+  EXPECT_THAT(static_cast<double>(client_stats.dropped_requests(kLbDropType)) /
+                  total_rpc,
+              ::testing::DoubleNear(kDropRateForLb, kErrorTolerance));
+  EXPECT_THAT(
+      static_cast<double>(client_stats.dropped_requests(kThrottleDropType)) /
+          (total_rpc * (1 - kDropRateForLb)),
+      ::testing::DoubleNear(kDropRateForThrottle, kErrorTolerance));
+}
+
+class FaultInjectionTest : public XdsEnd2endTest {
+ public:
+  FaultInjectionTest() : XdsEnd2endTest(1, 1) {}
+
+  // Builds a Listener with the fault injection filter config. If http_fault
+  // is not specified, an empty filter config is assigned. This filter config
+  // is required to enable the fault injection features.
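+  // (The fault filter is added ahead of the router filter below; the router
+  // filter is expected to terminate the HTTP filter chain.)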
+  static Listener BuildListenerWithFaultInjection(
+      const HTTPFault& http_fault = HTTPFault()) {
+    HttpConnectionManager http_connection_manager;
+    Listener listener;
+    listener.set_name(kServerName);
+    HttpFilter* fault_filter = http_connection_manager.add_http_filters();
+    fault_filter->set_name("envoy.fault");
+    fault_filter->mutable_typed_config()->PackFrom(http_fault);
+    HttpFilter* router_filter = http_connection_manager.add_http_filters();
+    router_filter->set_name("router");
+    router_filter->mutable_typed_config()->PackFrom(
+        envoy::extensions::filters::http::router::v3::Router());
+    listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
+        http_connection_manager);
+    return listener;
+  }
+
+  RouteConfiguration BuildRouteConfigurationWithFaultInjection(
+      const HTTPFault& http_fault) {
+    // Package as Any
+    google::protobuf::Any filter_config;
+    filter_config.PackFrom(http_fault);
+    // Plug into the RouteConfiguration
+    RouteConfiguration new_route_config = default_route_config_;
+    auto* config_map = new_route_config.mutable_virtual_hosts(0)
+                           ->mutable_routes(0)
+                           ->mutable_typed_per_filter_config();
+    (*config_map)["envoy.fault"] = std::move(filter_config);
+    return new_route_config;
+  }
+
+  void SetFilterConfig(HTTPFault& http_fault) {
+    switch (GetParam().filter_config_setup()) {
+      case TestType::FilterConfigSetup::kRouteOverride: {
+        Listener listener = BuildListenerWithFaultInjection();
+        RouteConfiguration route =
+            BuildRouteConfigurationWithFaultInjection(http_fault);
+        SetListenerAndRouteConfiguration(0, listener, route);
+        break;
+      }
+      case TestType::FilterConfigSetup::kHTTPConnectionManagerOriginal: {
+        Listener listener = BuildListenerWithFaultInjection(http_fault);
+        SetListenerAndRouteConfiguration(0, listener, default_route_config_);
+      }
+    };
+  }
+};
+
+// Test to ensure the most basic fault injection config works.
+TEST_P(FaultInjectionTest, XdsFaultInjectionAlwaysAbort) {
+  const uint32_t kAbortPercentagePerHundred = 100;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* abort_percentage = http_fault.mutable_abort()->mutable_percentage();
+  abort_percentage->set_numerator(kAbortPercentagePerHundred);
+  abort_percentage->set_denominator(FractionalPercent::HUNDRED);
+  http_fault.mutable_abort()->set_grpc_status(
+      static_cast<uint32_t>(StatusCode::ABORTED));
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Fire several RPCs, and expect all of them to be aborted.
+  CheckRpcSendFailure(
+      CheckRpcSendFailureOptions()
+          .set_times(5)
+          .set_rpc_options(RpcOptions().set_wait_for_ready(true))
+          .set_expected_error_code(StatusCode::ABORTED));
+}
+
+// Without the listener config, the fault injection won't be enabled.
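+// (The route-level override alone is not enough: the HttpConnectionManager
+// must list the fault filter for the override to take effect, so the RPCs
+// below are expected to succeed.)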
+TEST_P(FaultInjectionTest, XdsFaultInjectionWithoutListenerFilter) {
+  const uint32_t kAbortPercentagePerHundred = 100;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* abort_percentage = http_fault.mutable_abort()->mutable_percentage();
+  abort_percentage->set_numerator(kAbortPercentagePerHundred);
+  abort_percentage->set_denominator(FractionalPercent::HUNDRED);
+  http_fault.mutable_abort()->set_grpc_status(
+      static_cast<uint32_t>(StatusCode::ABORTED));
+  // Turn on fault injection
+  RouteConfiguration route =
+      BuildRouteConfigurationWithFaultInjection(http_fault);
+  SetListenerAndRouteConfiguration(0, default_listener_, route);
+  // Fire several RPCs, and expect all of them to pass.
+  CheckRpcSendOk(5, RpcOptions().set_wait_for_ready(true));
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionPercentageAbort) {
+  const uint32_t kAbortPercentagePerHundred = 50;
+  const double kAbortRate = kAbortPercentagePerHundred / 100.0;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kAbortRate, kErrorTolerance);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* abort_percentage = http_fault.mutable_abort()->mutable_percentage();
+  abort_percentage->set_numerator(kAbortPercentagePerHundred);
+  abort_percentage->set_denominator(FractionalPercent::HUNDRED);
+  http_fault.mutable_abort()->set_grpc_status(
+      static_cast<uint32_t>(StatusCode::ABORTED));
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Send kNumRpcs RPCs and count the aborts.
+  int num_total = 0, num_ok = 0, num_failure = 0, num_aborted = 0;
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_aborted,
+                    RpcOptions(), "Fault injected");
+  }
+  EXPECT_EQ(kNumRpcs, num_total);
+  EXPECT_EQ(0, num_failure);
+  // The abort rate should be roughly equal to the expectation.
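+  // With kAbortPercentagePerHundred = 50, roughly half of the RPCs should
+  // come back as ABORTED with the "Fault injected" message.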
+  const double seen_abort_rate = static_cast<double>(num_aborted) / kNumRpcs;
+  EXPECT_THAT(seen_abort_rate,
+              ::testing::DoubleNear(kAbortRate, kErrorTolerance));
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionPercentageAbortViaHeaders) {
+  const uint32_t kAbortPercentageCap = 100;
+  const uint32_t kAbortPercentage = 50;
+  const double kAbortRate = kAbortPercentage / 100.0;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kAbortRate, kErrorTolerance);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  http_fault.mutable_abort()->mutable_header_abort();
+  http_fault.mutable_abort()->mutable_percentage()->set_numerator(
+      kAbortPercentageCap);
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Send kNumRpcs RPCs and count the aborts.
+  std::vector<std::pair<std::string, std::string>> metadata = {
+      {"x-envoy-fault-abort-grpc-request", "10"},
+      {"x-envoy-fault-abort-percentage", std::to_string(kAbortPercentage)},
+  };
+  int num_total = 0, num_ok = 0, num_failure = 0, num_aborted = 0;
+  RpcOptions options = RpcOptions().set_metadata(metadata);
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_aborted, options,
+                    "Fault injected");
+  }
+  EXPECT_EQ(kNumRpcs, num_total);
+  EXPECT_EQ(0, num_failure);
+  // The abort rate should be roughly equal to the expectation.
+  const double seen_abort_rate = static_cast<double>(num_aborted) / kNumRpcs;
+  EXPECT_THAT(seen_abort_rate,
+              ::testing::DoubleNear(kAbortRate, kErrorTolerance));
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionPercentageDelay) {
+  const uint32_t kRpcTimeoutMilliseconds = grpc_test_slowdown_factor() * 3000;
+  const uint32_t kFixedDelaySeconds = 100;
+  const uint32_t kDelayPercentagePerHundred = 50;
+  const double kDelayRate = kDelayPercentagePerHundred / 100.0;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kDelayRate, kErrorTolerance);
+  const size_t kMaxConcurrentRequests = kNumRpcs;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Loosen the max concurrent request limit
+  Cluster cluster = default_cluster_;
+  auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds();
+  threshold->set_priority(RoutingPriority::DEFAULT);
+  threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* delay_percentage = http_fault.mutable_delay()->mutable_percentage();
+  delay_percentage->set_numerator(kDelayPercentagePerHundred);
+  delay_percentage->set_denominator(FractionalPercent::HUNDRED);
+  auto* fixed_delay = http_fault.mutable_delay()->mutable_fixed_delay();
+  fixed_delay->set_seconds(kFixedDelaySeconds);
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Send kNumRpcs RPCs and count the delays.
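+  // The fixed 100s delay is far longer than the ~3s RPC deadline, so every
+  // delayed RPC is expected to surface as DEADLINE_EXCEEDED below.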
+  RpcOptions rpc_options = RpcOptions()
+                               .set_timeout_ms(kRpcTimeoutMilliseconds)
+                               .set_skip_cancelled_check(true);
+  std::vector<ConcurrentRpc> rpcs =
+      SendConcurrentRpcs(stub_.get(), kNumRpcs, rpc_options);
+  size_t num_delayed = 0;
+  for (auto& rpc : rpcs) {
+    if (rpc.status.error_code() == StatusCode::OK) continue;
+    EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, rpc.status.error_code());
+    ++num_delayed;
+  }
+  // The delay rate should be roughly equal to the expectation.
+  const double seen_delay_rate = static_cast<double>(num_delayed) / kNumRpcs;
+  EXPECT_THAT(seen_delay_rate,
+              ::testing::DoubleNear(kDelayRate, kErrorTolerance));
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionPercentageDelayViaHeaders) {
+  const uint32_t kFixedDelayMilliseconds = 100000;
+  const uint32_t kRpcTimeoutMilliseconds = grpc_test_slowdown_factor() * 3000;
+  const uint32_t kDelayPercentageCap = 100;
+  const uint32_t kDelayPercentage = 50;
+  const double kDelayRate = kDelayPercentage / 100.0;
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kDelayRate, kErrorTolerance);
+  const size_t kMaxConcurrentRequests = kNumRpcs;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Loosen the max concurrent request limit
+  Cluster cluster = default_cluster_;
+  auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds();
+  threshold->set_priority(RoutingPriority::DEFAULT);
+  threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  http_fault.mutable_delay()->mutable_header_delay();
+  http_fault.mutable_delay()->mutable_percentage()->set_numerator(
+      kDelayPercentageCap);
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Send kNumRpcs RPCs and count the delays.
+  std::vector<std::pair<std::string, std::string>> metadata = {
+      {"x-envoy-fault-delay-request", std::to_string(kFixedDelayMilliseconds)},
+      {"x-envoy-fault-delay-request-percentage",
+       std::to_string(kDelayPercentage)},
+  };
+  RpcOptions rpc_options = RpcOptions()
+                               .set_metadata(metadata)
+                               .set_timeout_ms(kRpcTimeoutMilliseconds)
+                               .set_skip_cancelled_check(true);
+  std::vector<ConcurrentRpc> rpcs =
+      SendConcurrentRpcs(stub_.get(), kNumRpcs, rpc_options);
+  size_t num_delayed = 0;
+  for (auto& rpc : rpcs) {
+    if (rpc.status.error_code() == StatusCode::OK) continue;
+    EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, rpc.status.error_code());
+    ++num_delayed;
+  }
+  // The delay rate should be roughly equal to the expectation.
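+  // Since kDelayPercentageCap is 100, the 50% rate requested via the
+  // x-envoy-fault-delay-request-percentage header is presumably applied
+  // uncapped, so the expected delay rate is kDelayRate = 0.5.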
+  const double seen_delay_rate = static_cast<double>(num_delayed) / kNumRpcs;
+  EXPECT_THAT(seen_delay_rate,
+              ::testing::DoubleNear(kDelayRate, kErrorTolerance));
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionAbortAfterDelayForStreamCall) {
+  const uint32_t kFixedDelaySeconds = 1;
+  const uint32_t kRpcTimeoutMilliseconds = 100 * 1000;  // 100s; should not be reached
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* abort_percentage = http_fault.mutable_abort()->mutable_percentage();
+  abort_percentage->set_numerator(100);  // Always inject ABORT!
+  abort_percentage->set_denominator(FractionalPercent::HUNDRED);
+  http_fault.mutable_abort()->set_grpc_status(
+      static_cast<uint32_t>(StatusCode::ABORTED));
+  auto* delay_percentage = http_fault.mutable_delay()->mutable_percentage();
+  delay_percentage->set_numerator(100);  // Always inject DELAY!
+  delay_percentage->set_denominator(FractionalPercent::HUNDRED);
+  auto* fixed_delay = http_fault.mutable_delay()->mutable_fixed_delay();
+  fixed_delay->set_seconds(kFixedDelaySeconds);
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Send a stream RPC and check its status code
+  ClientContext context;
+  context.set_deadline(
+      grpc_timeout_milliseconds_to_deadline(kRpcTimeoutMilliseconds));
+  auto stream = stub_->BidiStream(&context);
+  stream->WritesDone();
+  auto status = stream->Finish();
+  EXPECT_EQ(StatusCode::ABORTED, status.error_code())
+      << status.error_message() << ", " << status.error_details() << ", "
+      << context.debug_error_string();
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionAlwaysDelayPercentageAbort) {
+  const uint32_t kAbortPercentagePerHundred = 50;
+  const double kAbortRate = kAbortPercentagePerHundred / 100.0;
+  const uint32_t kFixedDelaySeconds = 1;
+  const uint32_t kRpcTimeoutMilliseconds = 100 * 1000;  // 100s; should not be reached
+  const uint32_t kConnectionTimeoutMilliseconds =
+      10 * 1000;  // 10s; should not be reached
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kAbortRate, kErrorTolerance);
+  const size_t kMaxConcurrentRequests = kNumRpcs;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Loosen the max concurrent request limit
+  Cluster cluster = default_cluster_;
+  auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds();
+  threshold->set_priority(RoutingPriority::DEFAULT);
+  threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* abort_percentage = http_fault.mutable_abort()->mutable_percentage();
+  abort_percentage->set_numerator(kAbortPercentagePerHundred);
+  abort_percentage->set_denominator(FractionalPercent::HUNDRED);
+  http_fault.mutable_abort()->set_grpc_status(
+      static_cast<uint32_t>(StatusCode::ABORTED));
+  auto* delay_percentage = http_fault.mutable_delay()->mutable_percentage();
+  delay_percentage->set_numerator(1000000);  // Always inject DELAY!
+  delay_percentage->set_denominator(FractionalPercent::MILLION);
+  auto* fixed_delay = http_fault.mutable_delay()->mutable_fixed_delay();
+  fixed_delay->set_seconds(kFixedDelaySeconds);
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Allow the channel to connect to one backend, so that the herd of queued
+  // RPCs won't all be executed on the same ExecCtx object using the cached
+  // Now() value, which would cause millisecond-level errors in the measured
+  // delay.
+  channel_->WaitForConnected(
+      grpc_timeout_milliseconds_to_deadline(kConnectionTimeoutMilliseconds));
+  // Send kNumRpcs RPCs and count the aborts.
+  int num_aborted = 0;
+  RpcOptions rpc_options = RpcOptions().set_timeout_ms(kRpcTimeoutMilliseconds);
+  std::vector<ConcurrentRpc> rpcs =
+      SendConcurrentRpcs(stub_.get(), kNumRpcs, rpc_options);
+  for (auto& rpc : rpcs) {
+    EXPECT_GE(rpc.elapsed_time, kFixedDelaySeconds * 1000);
+    if (rpc.status.error_code() == StatusCode::OK) continue;
+    EXPECT_EQ("Fault injected", rpc.status.error_message());
+    ++num_aborted;
+  }
+  // The abort rate should be roughly equal to the expectation.
+  const double seen_abort_rate = static_cast<double>(num_aborted) / kNumRpcs;
+  EXPECT_THAT(seen_abort_rate,
+              ::testing::DoubleNear(kAbortRate, kErrorTolerance));
+}
+
+// This test and the above test apply different denominators to delay and
+// abort. This ensures that we are using the right denominator for each
+// injected fault in our code.
+TEST_P(FaultInjectionTest,
+       XdsFaultInjectionAlwaysDelayPercentageAbortSwitchDenominator) {
+  const uint32_t kAbortPercentagePerMillion = 500000;
+  const double kAbortRate = kAbortPercentagePerMillion / 1000000.0;
+  const uint32_t kFixedDelaySeconds = 1;                // 1s
+  const uint32_t kRpcTimeoutMilliseconds = 100 * 1000;  // 100s; should not be reached
+  const uint32_t kConnectionTimeoutMilliseconds =
+      10 * 1000;  // 10s; should not be reached
+  const double kErrorTolerance = 0.05;
+  const size_t kNumRpcs = ComputeIdealNumRpcs(kAbortRate, kErrorTolerance);
+  const size_t kMaxConcurrentRequests = kNumRpcs;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Loosen the max concurrent request limit
+  Cluster cluster = default_cluster_;
+  auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds();
+  threshold->set_priority(RoutingPriority::DEFAULT);
+  threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* abort_percentage = http_fault.mutable_abort()->mutable_percentage();
+  abort_percentage->set_numerator(kAbortPercentagePerMillion);
+  abort_percentage->set_denominator(FractionalPercent::MILLION);
+  http_fault.mutable_abort()->set_grpc_status(
+      static_cast<uint32_t>(StatusCode::ABORTED));
+  auto* delay_percentage = http_fault.mutable_delay()->mutable_percentage();
+  delay_percentage->set_numerator(100);  // Always inject DELAY!
+  delay_percentage->set_denominator(FractionalPercent::HUNDRED);
+  auto* fixed_delay = http_fault.mutable_delay()->mutable_fixed_delay();
+  fixed_delay->set_seconds(kFixedDelaySeconds);
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Allow the channel to connect to one backend, so that the herd of queued
+  // RPCs won't all be executed on the same ExecCtx object using the cached
+  // Now() value, which would cause millisecond-level errors in the measured
+  // delay.
+  channel_->WaitForConnected(
+      grpc_timeout_milliseconds_to_deadline(kConnectionTimeoutMilliseconds));
+  // Send kNumRpcs RPCs and count the aborts.
+  int num_aborted = 0;
+  RpcOptions rpc_options = RpcOptions().set_timeout_ms(kRpcTimeoutMilliseconds);
+  std::vector<ConcurrentRpc> rpcs =
+      SendConcurrentRpcs(stub_.get(), kNumRpcs, rpc_options);
+  for (auto& rpc : rpcs) {
+    EXPECT_GE(rpc.elapsed_time, kFixedDelaySeconds * 1000);
+    if (rpc.status.error_code() == StatusCode::OK) continue;
+    EXPECT_EQ("Fault injected", rpc.status.error_message());
+    ++num_aborted;
+  }
+  // The abort rate should be roughly equal to the expectation.
+  const double seen_abort_rate = static_cast<double>(num_aborted) / kNumRpcs;
+  EXPECT_THAT(seen_abort_rate,
+              ::testing::DoubleNear(kAbortRate, kErrorTolerance));
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionMaxFault) {
+  const uint32_t kMaxFault = 10;
+  const uint32_t kNumRpcs = 30;  // kNumRpcs should be bigger than kMaxFault
+  const uint32_t kRpcTimeoutMs = 4000;     // 4 seconds
+  const uint32_t kLongDelaySeconds = 100;  // 100 seconds
+  const uint32_t kAlwaysDelayPercentage = 100;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Create an EDS resource
+  EdsResourceArgs args({
+      {"locality0", CreateEndpointsForBackends()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      BuildEdsResource(args, DefaultEdsServiceName()));
+  // Construct the fault injection filter config
+  HTTPFault http_fault;
+  auto* delay_percentage = http_fault.mutable_delay()->mutable_percentage();
+  delay_percentage->set_numerator(
+      kAlwaysDelayPercentage);  // Always inject DELAY!
+  delay_percentage->set_denominator(FractionalPercent::HUNDRED);
+  auto* fixed_delay = http_fault.mutable_delay()->mutable_fixed_delay();
+  fixed_delay->set_seconds(kLongDelaySeconds);
+  http_fault.mutable_max_active_faults()->set_value(kMaxFault);
+  // Config fault injection via different setup
+  SetFilterConfig(http_fault);
+  // Send a batch of long-running RPCs with a long timeout to consume all of
+  // the active-fault quota.
+  int num_delayed = 0;
+  RpcOptions rpc_options = RpcOptions().set_timeout_ms(kRpcTimeoutMs);
+  std::vector<ConcurrentRpc> rpcs =
+      SendConcurrentRpcs(stub_.get(), kNumRpcs, rpc_options);
+  for (auto& rpc : rpcs) {
+    if (rpc.status.error_code() == StatusCode::OK) continue;
+    EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, rpc.status.error_code());
+    ++num_delayed;
+  }
+  // Only kMaxFault RPCs should be fault-injected.
+  EXPECT_EQ(kMaxFault, num_delayed);
+}
+
+TEST_P(FaultInjectionTest, XdsFaultInjectionBidiStreamDelayOk) {
+  // kRpcTimeoutMilliseconds is 10s and should never be reached.
+ const uint32_t kRpcTimeoutMilliseconds = grpc_test_slowdown_factor() * 10000; + const uint32_t kFixedDelaySeconds = 1; + const uint32_t kDelayPercentagePerHundred = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create an EDS resource + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Construct the fault injection filter config + HTTPFault http_fault; + auto* delay_percentage = http_fault.mutable_delay()->mutable_percentage(); + delay_percentage->set_numerator(kDelayPercentagePerHundred); + delay_percentage->set_denominator(FractionalPercent::HUNDRED); + auto* fixed_delay = http_fault.mutable_delay()->mutable_fixed_delay(); + fixed_delay->set_seconds(kFixedDelaySeconds); + // Config fault injection via different setup + SetFilterConfig(http_fault); + ClientContext context; + context.set_deadline( + grpc_timeout_milliseconds_to_deadline(kRpcTimeoutMilliseconds)); + auto stream = stub_->BidiStream(&context); + stream->WritesDone(); + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()) << status.error_message() << ", " + << status.error_details() << ", " + << context.debug_error_string(); +} + +// This case catches a bug in the retry code that was triggered by a bad +// interaction with the FI code. See https://github.com/grpc/grpc/pull/27217 +// for description. +TEST_P(FaultInjectionTest, XdsFaultInjectionBidiStreamDelayError) { + const uint32_t kRpcTimeoutMilliseconds = grpc_test_slowdown_factor() * 500; + const uint32_t kFixedDelaySeconds = 100; + const uint32_t kDelayPercentagePerHundred = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create an EDS resource + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Construct the fault injection filter config + HTTPFault http_fault; + auto* delay_percentage = http_fault.mutable_delay()->mutable_percentage(); + delay_percentage->set_numerator(kDelayPercentagePerHundred); + delay_percentage->set_denominator(FractionalPercent::HUNDRED); + auto* fixed_delay = http_fault.mutable_delay()->mutable_fixed_delay(); + fixed_delay->set_seconds(kFixedDelaySeconds); + // Config fault injection via different setup + SetFilterConfig(http_fault); + ClientContext context; + context.set_deadline( + grpc_timeout_milliseconds_to_deadline(kRpcTimeoutMilliseconds)); + auto stream = stub_->BidiStream(&context); + stream->WritesDone(); + auto status = stream->Finish(); + EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, status.error_code()) + << status.error_message() << ", " << status.error_details() << ", " + << context.debug_error_string(); +} + +class BootstrapSourceTest : public XdsEnd2endTest { + public: + BootstrapSourceTest() : XdsEnd2endTest(4, 1) {} +}; + +TEST_P(BootstrapSourceTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({ + {"locality0", CreateEndpointsForBackends()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + WaitForAllBackends(); +} + +#ifndef DISABLED_XDS_PROTO_IN_CC +class ClientStatusDiscoveryServiceTest : public XdsEnd2endTest { + public: + ClientStatusDiscoveryServiceTest() : XdsEnd2endTest(1, 1) {} + + void SetUp() override { + XdsEnd2endTest::SetUp(); + admin_server_thread_ = 
absl::make_unique(this); + admin_server_thread_->Start(); + std::string admin_server_address = absl::StrCat( + ipv6_only_ ? "[::1]:" : "127.0.0.1:", admin_server_thread_->port()); + admin_channel_ = grpc::CreateChannel( + admin_server_address, + std::make_shared( + grpc_fake_transport_security_credentials_create())); + csds_stub_ = + envoy::service::status::v3::ClientStatusDiscoveryService::NewStub( + admin_channel_); + if (GetParam().use_csds_streaming()) { + stream_ = csds_stub_->StreamClientStatus(&stream_context_); + } + } + + void TearDown() override { + if (stream_ != nullptr) { + EXPECT_TRUE(stream_->WritesDone()); + Status status = stream_->Finish(); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + } + admin_server_thread_->Shutdown(); + XdsEnd2endTest::TearDown(); + } + + envoy::service::status::v3::ClientStatusResponse FetchCsdsResponse() { + envoy::service::status::v3::ClientStatusResponse response; + if (!GetParam().use_csds_streaming()) { + // Fetch through unary pulls + ClientContext context; + Status status = csds_stub_->FetchClientStatus( + &context, envoy::service::status::v3::ClientStatusRequest(), + &response); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + } else { + // Fetch through streaming pulls + EXPECT_TRUE( + stream_->Write(envoy::service::status::v3::ClientStatusRequest())); + EXPECT_TRUE(stream_->Read(&response)); + } + return response; + } + + private: + std::unique_ptr admin_server_thread_; + std::shared_ptr admin_channel_; + std::unique_ptr< + envoy::service::status::v3::ClientStatusDiscoveryService::Stub> + csds_stub_; + ClientContext stream_context_; + std::unique_ptr< + ClientReaderWriter> + stream_; +}; + +MATCHER_P4(EqNode, id, user_agent_name, user_agent_version, client_features, + "equals Node") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(id, arg.id(), result_listener); + ok &= ::testing::ExplainMatchResult(user_agent_name, arg.user_agent_name(), + result_listener); + ok &= ::testing::ExplainMatchResult( + user_agent_version, arg.user_agent_version(), result_listener); + ok &= ::testing::ExplainMatchResult(client_features, arg.client_features(), + result_listener); + return ok; +} + +MATCHER_P2(EqListenersConfigDump, version_info, dynamic_listeners, + "equals ListenerConfigDump") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(::testing::ElementsAre(), + arg.static_listeners(), result_listener); + ok &= ::testing::ExplainMatchResult(version_info, arg.version_info(), + result_listener); + ok &= ::testing::ExplainMatchResult(dynamic_listeners, + arg.dynamic_listeners(), result_listener); + return ok; +} + +MATCHER_P2(EqDynamicListenerState, version_info, listener, + "equals DynamicListenerState") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(version_info, arg.version_info(), + result_listener); + ok &= + ::testing::ExplainMatchResult(listener, arg.listener(), result_listener); + return ok; +} + +MATCHER_P2(EqListener, name, api_listener, "equals Listener") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(name, arg.name(), result_listener); + ok &= ::testing::ExplainMatchResult( + api_listener, arg.api_listener().api_listener(), result_listener); + return ok; +} + +MATCHER_P(EqHttpConnectionManagerNotRds, route_config, + "equals HttpConnectionManager") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(route_config, arg.route_config(), + result_listener); + return ok; +} + 
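The matchers in this test file follow a recurring gMock pattern: each MATCHER_P* delegates its field checks to ::testing::ExplainMatchResult, which lets one matcher forward to another (or to a plain expected value) while failure messages still point at the offending sub-field. The following is a minimal, self-contained sketch of that pattern; the Endpoint struct, the HasPort and EqEndpointEntry names, and the test itself are illustrative assumptions, not part of this patch.

#include <string>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace {

// Hypothetical stand-in for a generated proto message.
struct Endpoint {
  std::string address;
  int port;
};

// Leaf matcher: checks a single field and reports through result_listener.
MATCHER_P(HasPort, port, "has the expected port") {
  return ::testing::ExplainMatchResult(port, arg.port, result_listener);
}

// Composed matcher: forwards to another matcher via ExplainMatchResult,
// mirroring how EqListenersConfigDump and EqDynamicListenerState above
// accept and forward sub-matchers.
MATCHER_P2(EqEndpointEntry, address, port, "equals Endpoint") {
  bool ok = true;
  ok &= ::testing::ExplainMatchResult(address, arg.address, result_listener);
  ok &= ::testing::ExplainMatchResult(HasPort(port), arg, result_listener);
  return ok;
}

// Usage: the composed matcher reads like a value comparison in assertions.
TEST(ComposedMatcherSketch, MatchesEndpoint) {
  Endpoint endpoint{"127.0.0.1", 443};
  EXPECT_THAT(endpoint, EqEndpointEntry("127.0.0.1", 443));
}

}  // namespace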
+MATCHER_P(EqRouteConfigurationName, name, "equals RouteConfiguration") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(name, arg.name(), result_listener); + return ok; +} + +MATCHER_P2(EqRouteConfiguration, name, cluster_name, + "equals RouteConfiguration") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(name, arg.name(), result_listener); + ok &= ::testing::ExplainMatchResult( + ::testing::ElementsAre(::testing::Property( + &envoy::config::route::v3::VirtualHost::routes, + ::testing::ElementsAre(::testing::Property( + &envoy::config::route::v3::Route::route, + ::testing::Property( + &envoy::config::route::v3::RouteAction::cluster, + cluster_name))))), + arg.virtual_hosts(), result_listener); + return ok; +} + +MATCHER_P(EqRoutesConfigDump, dynamic_route_configs, + "equals RoutesConfigDump") { + bool ok = true; + ok &= ::testing::ExplainMatchResult( + ::testing::ElementsAre(), arg.static_route_configs(), result_listener); + ok &= ::testing::ExplainMatchResult( + dynamic_route_configs, arg.dynamic_route_configs(), result_listener); + return ok; +} + +MATCHER_P2(EqClustersConfigDump, version_info, dynamic_active_clusters, + "equals ClustersConfigDump") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(::testing::ElementsAre(), + arg.static_clusters(), result_listener); + ok &= ::testing::ExplainMatchResult(::testing::ElementsAre(), + arg.dynamic_warming_clusters(), + result_listener); + ok &= ::testing::ExplainMatchResult(version_info, arg.version_info(), + result_listener); + ok &= ::testing::ExplainMatchResult( + dynamic_active_clusters, arg.dynamic_active_clusters(), result_listener); + return ok; +} + +MATCHER_P(EqCluster, name, "equals Cluster") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(name, arg.name(), result_listener); + return ok; +} + +MATCHER_P(EqEndpointsConfigDump, dynamic_endpoint_configs, + "equals EndpointsConfigDump") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(dynamic_endpoint_configs, + arg.dynamic_endpoint_configs(), + result_listener); + return ok; +} + +MATCHER_P(EqEndpoint, port, "equals Endpoint") { + bool ok = true; + ok &= ::testing::ExplainMatchResult( + port, arg.address().socket_address().port_value(), result_listener); + return ok; +} + +MATCHER_P2(EqLocalityLbEndpoints, port, weight, "equals LocalityLbEndpoints") { + bool ok = true; + ok &= ::testing::ExplainMatchResult( + ::testing::ElementsAre(::testing::Property( + &envoy::config::endpoint::v3::LbEndpoint::endpoint, + EqEndpoint(port))), + arg.lb_endpoints(), result_listener); + ok &= ::testing::ExplainMatchResult( + weight, arg.load_balancing_weight().value(), result_listener); + return ok; +} + +MATCHER_P(EqClusterLoadAssignmentName, cluster_name, + "equals ClusterLoadAssignment") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(cluster_name, arg.cluster_name(), + result_listener); + return ok; +} + +MATCHER_P3(EqClusterLoadAssignment, cluster_name, port, weight, + "equals ClusterLoadAssignment") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(cluster_name, arg.cluster_name(), + result_listener); + ok &= ::testing::ExplainMatchResult( + ::testing::ElementsAre(EqLocalityLbEndpoints(port, weight)), + arg.endpoints(), result_listener); + return ok; +} + +MATCHER_P2(EqUpdateFailureState, details, version_info, + "equals UpdateFailureState") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(details, arg.details(), result_listener); + ok &= ::testing::ExplainMatchResult(version_info, arg.version_info(), + result_listener); + 
return ok; +} + +MATCHER_P(UnpackListener, matcher, "is a Listener") { + Listener config; + if (!::testing::ExplainMatchResult(true, arg.UnpackTo(&config), + result_listener)) { + return false; + } + return ::testing::ExplainMatchResult(matcher, config, result_listener); +} + +MATCHER_P(UnpackRouteConfiguration, matcher, "is a RouteConfiguration") { + RouteConfiguration config; + if (!::testing::ExplainMatchResult(true, arg.UnpackTo(&config), + result_listener)) { + return false; + } + return ::testing::ExplainMatchResult(matcher, config, result_listener); +} + +MATCHER_P(UnpackHttpConnectionManager, matcher, "is a HttpConnectionManager") { + HttpConnectionManager config; + if (!::testing::ExplainMatchResult(true, arg.UnpackTo(&config), + result_listener)) { + return false; + } + return ::testing::ExplainMatchResult(matcher, config, result_listener); +} + +MATCHER_P(UnpackCluster, matcher, "is a Cluster") { + Cluster config; + if (!::testing::ExplainMatchResult(true, arg.UnpackTo(&config), + result_listener)) { + return false; + } + return ::testing::ExplainMatchResult(matcher, config, result_listener); +} + +MATCHER_P(UnpackClusterLoadAssignment, matcher, "is a ClusterLoadAssignment") { + ClusterLoadAssignment config; + if (!::testing::ExplainMatchResult(true, arg.UnpackTo(&config), + result_listener)) { + return false; + } + return ::testing::ExplainMatchResult(matcher, config, result_listener); +} + +MATCHER_P5(EqDynamicListener, name, version_info, client_status, + api_listener_matcher, error_state, "equals DynamicListener") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(false, arg.has_warming_state(), + result_listener); + ok &= ::testing::ExplainMatchResult(false, arg.has_draining_state(), + result_listener); + ok &= ::testing::ExplainMatchResult(name, arg.name(), result_listener); + ok &= ::testing::ExplainMatchResult(client_status, arg.client_status(), + result_listener); + if (client_status == ClientResourceStatus::ACKED || + client_status == ClientResourceStatus::NACKED) { + ok &= ::testing::ExplainMatchResult( + EqDynamicListenerState(version_info, UnpackListener(EqListener( + name, api_listener_matcher))), + arg.active_state(), result_listener); + } + ok &= ::testing::ExplainMatchResult(error_state, arg.error_state(), + result_listener); + return ok; +} + +MATCHER_P5(EqDynamicRouteConfig, name, version_info, client_status, + cluster_name, error_state, "equals DynamicRouteConfig") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(version_info, arg.version_info(), + result_listener); + if (client_status == ClientResourceStatus::REQUESTED || + client_status == ClientResourceStatus::DOES_NOT_EXIST) { + ok &= ::testing::ExplainMatchResult( + UnpackRouteConfiguration(EqRouteConfigurationName(name)), + arg.route_config(), result_listener); + } else { + ok &= ::testing::ExplainMatchResult( + UnpackRouteConfiguration(EqRouteConfiguration(name, cluster_name)), + arg.route_config(), result_listener); + } + ok &= ::testing::ExplainMatchResult(error_state, arg.error_state(), + result_listener); + ok &= ::testing::ExplainMatchResult(client_status, arg.client_status(), + result_listener); + return ok; +} + +MATCHER_P4(EqDynamicCluster, name, version_info, client_status, error_state, + "equals DynamicCluster") { + bool ok = true; + ok &= ::testing::ExplainMatchResult(UnpackCluster(EqCluster(name)), + arg.cluster(), result_listener); + ok &= ::testing::ExplainMatchResult(version_info, arg.version_info(), + result_listener); + ok &= ::testing::ExplainMatchResult(client_status, 
arg.client_status(), + result_listener); + ok &= ::testing::ExplainMatchResult(error_state, arg.error_state(), + result_listener); + return ok; +} + +MATCHER_P6(EqDynamicEndpointConfig, name, version_info, client_status, port, + weight, error_state, "equals DynamicEndpointConfig") { + bool ok = true; + if (client_status == ClientResourceStatus::REQUESTED || + client_status == ClientResourceStatus::DOES_NOT_EXIST) { + ok &= ::testing::ExplainMatchResult( + UnpackClusterLoadAssignment(EqClusterLoadAssignmentName(name)), + arg.endpoint_config(), result_listener); + } else { + ok &= ::testing::ExplainMatchResult( + UnpackClusterLoadAssignment( + EqClusterLoadAssignment(name, port, weight)), + arg.endpoint_config(), result_listener); + } + ok &= ::testing::ExplainMatchResult(version_info, arg.version_info(), + result_listener); + ok &= ::testing::ExplainMatchResult(client_status, arg.client_status(), + result_listener); + ok &= ::testing::ExplainMatchResult(error_state, arg.error_state(), + result_listener); + return ok; +} + +MATCHER(IsRdsEnabledHCM, "is a RDS enabled HttpConnectionManager") { + return ::testing::ExplainMatchResult( + UnpackHttpConnectionManager( + ::testing::Property(&HttpConnectionManager::has_rds, true)), + arg, result_listener); +} + +MATCHER_P2(EqNoRdsHCM, route_configuration_name, cluster_name, + "equals RDS disabled HttpConnectionManager") { + return ::testing::ExplainMatchResult( + UnpackHttpConnectionManager(EqHttpConnectionManagerNotRds( + EqRouteConfiguration(route_configuration_name, cluster_name))), + arg, result_listener); +} + +TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpVanilla) { + const size_t kNumRpcs = 5; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Send several RPCs to ensure the xDS setup works + CheckRpcSendOk(kNumRpcs); + // Fetches the client config + auto csds_response = FetchCsdsResponse(); + gpr_log(GPR_INFO, "xDS config dump: %s", csds_response.DebugString().c_str()); + EXPECT_EQ(1, csds_response.config_size()); + const auto& client_config = csds_response.config(0); + // Validate the Node information + EXPECT_THAT(client_config.node(), + EqNode("xds_end2end_test", ::testing::HasSubstr("C-core"), + ::testing::HasSubstr(grpc_version_string()), + ::testing::ElementsAre( + "envoy.lb.does_not_support_overprovisioning"))); + // Prepare matches for RDS on or off + ::testing::Matcher api_listener_matcher; + ::testing::Matcher + route_config_dump_matcher; + if (GetParam().enable_rds_testing()) { + api_listener_matcher = IsRdsEnabledHCM(); + route_config_dump_matcher = + EqRoutesConfigDump(::testing::ElementsAre(EqDynamicRouteConfig( + kDefaultRouteConfigurationName, "1", ClientResourceStatus::ACKED, + kDefaultClusterName, ::testing::_))); + } else { + api_listener_matcher = + EqNoRdsHCM(kDefaultRouteConfigurationName, kDefaultClusterName); + route_config_dump_matcher = EqRoutesConfigDump(::testing::ElementsAre()); + } + // Validate the dumped xDS configs + EXPECT_THAT( + client_config.xds_config(), + ::testing::UnorderedElementsAre( + ::testing::Property( + &envoy::service::status::v3::PerXdsConfig::listener_config, + EqListenersConfigDump( + "1", ::testing::ElementsAre(EqDynamicListener( + kServerName, "1", ClientResourceStatus::ACKED, + api_listener_matcher, ::testing::_)))), + ::testing::Property( + 
&envoy::service::status::v3::PerXdsConfig::route_config, + route_config_dump_matcher), + ::testing::Property( + &envoy::service::status::v3::PerXdsConfig::cluster_config, + EqClustersConfigDump( + "1", ::testing::ElementsAre(EqDynamicCluster( + kDefaultClusterName, "1", + ClientResourceStatus::ACKED, ::testing::_)))), + ::testing::Property( + &envoy::service::status::v3::PerXdsConfig::endpoint_config, + EqEndpointsConfigDump( + ::testing::ElementsAre(EqDynamicEndpointConfig( + kDefaultEdsServiceName, "1", ClientResourceStatus::ACKED, + backends_[0]->port(), kDefaultLocalityWeight, + ::testing::_)))))); +} + +TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpEmpty) { + // The CSDS service should not fail if XdsClient is not initialized or there + // is no working xDS configs. + FetchCsdsResponse(); +} + +TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpListenerError) { + int kFetchConfigRetries = 3; + int kFetchIntervalMilliseconds = 200; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Ensure the xDS resolver has working configs. + CheckRpcSendOk(); + // Bad Listener should be rejected. + Listener listener; + listener.set_name(kServerName); + balancers_[0]->ads_service()->SetLdsResource(listener); + // The old xDS configs should still be effective. + CheckRpcSendOk(); + ::testing::Matcher api_listener_matcher; + if (GetParam().enable_rds_testing()) { + api_listener_matcher = IsRdsEnabledHCM(); + } else { + api_listener_matcher = + EqNoRdsHCM(kDefaultRouteConfigurationName, kDefaultClusterName); + } + for (int o = 0; o < kFetchConfigRetries; o++) { + auto csds_response = FetchCsdsResponse(); + // Check if error state is propagated + bool ok = ::testing::Value( + csds_response.config(0).xds_config(), + ::testing::Contains(::testing::Property( + &envoy::service::status::v3::PerXdsConfig::listener_config, + EqListenersConfigDump( + "1", + ::testing::ElementsAre(EqDynamicListener( + kServerName, "1", ClientResourceStatus::NACKED, + api_listener_matcher, + EqUpdateFailureState( + ::testing::HasSubstr( + "Listener has neither address nor ApiListener"), + "2"))))))); + if (ok) return; // TEST PASSED! + gpr_sleep_until( + grpc_timeout_milliseconds_to_deadline(kFetchIntervalMilliseconds)); + } + FAIL() << "error_state not seen in CSDS responses"; +} + +TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpRouteError) { + int kFetchConfigRetries = 3; + int kFetchIntervalMilliseconds = 200; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}}); + balancers_[0]->ads_service()->SetEdsResource( + BuildEdsResource(args, DefaultEdsServiceName())); + // Ensure the xDS resolver has working configs. + CheckRpcSendOk(); + // Bad route config will be rejected. + RouteConfiguration route_config; + route_config.set_name(kDefaultRouteConfigurationName); + route_config.add_virtual_hosts(); + SetRouteConfiguration(0, route_config); + // The old xDS configs should still be effective. 
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ CheckRpcSendOk();
+ for (int o = 0; o < kFetchConfigRetries; o++) {
+ auto csds_response = FetchCsdsResponse();
+ bool ok = false;
+ if (GetParam().enable_rds_testing()) {
+ ok = ::testing::Value(
+ csds_response.config(0).xds_config(),
+ ::testing::Contains(::testing::Property(
+ &envoy::service::status::v3::PerXdsConfig::route_config,
+ EqRoutesConfigDump(::testing::ElementsAre(EqDynamicRouteConfig(
+ kDefaultRouteConfigurationName, "1",
+ ClientResourceStatus::NACKED, kDefaultClusterName,
+ EqUpdateFailureState(
+ ::testing::HasSubstr("VirtualHost has no domains"),
+ "2"))))));
+ } else {
+ ok = ::testing::Value(
+ csds_response.config(0).xds_config(),
+ ::testing::Contains(::testing::Property(
+ &envoy::service::status::v3::PerXdsConfig::listener_config,
+ EqListenersConfigDump(
+ "1",
+ ::testing::ElementsAre(EqDynamicListener(
+ kServerName, "1", ClientResourceStatus::NACKED,
+ EqNoRdsHCM(kDefaultRouteConfigurationName,
+ kDefaultClusterName),
+ EqUpdateFailureState(
+ ::testing::HasSubstr("VirtualHost has no domains"),
+ "2"))))));
+ }
+ if (ok) return; // TEST PASSED!
+ gpr_sleep_until(
+ grpc_timeout_milliseconds_to_deadline(kFetchIntervalMilliseconds));
+ }
+ FAIL() << "error_state not seen in CSDS responses";
+}
+
+TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpClusterError) {
+ int kFetchConfigRetries = 3;
+ int kFetchIntervalMilliseconds = 200;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}});
+ balancers_[0]->ads_service()->SetEdsResource(
+ BuildEdsResource(args, DefaultEdsServiceName()));
+ // Ensure the xDS resolver has working configs.
+ CheckRpcSendOk();
+ // Cluster without a DiscoveryType will be rejected.
+ Cluster cluster;
+ cluster.set_name(kDefaultClusterName);
+ balancers_[0]->ads_service()->SetCdsResource(cluster);
+ // The old xDS configs should still be effective.
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ CheckRpcSendOk();
+ for (int o = 0; o < kFetchConfigRetries; o++) {
+ auto csds_response = FetchCsdsResponse();
+ // Check if error state is propagated
+ bool ok = ::testing::Value(
+ csds_response.config(0).xds_config(),
+ ::testing::Contains(::testing::Property(
+ &envoy::service::status::v3::PerXdsConfig::cluster_config,
+ EqClustersConfigDump(
+ "1", ::testing::ElementsAre(EqDynamicCluster(
+ kDefaultClusterName, "1", ClientResourceStatus::NACKED,
+ EqUpdateFailureState(
+ ::testing::HasSubstr("DiscoveryType not found"),
+ "2")))))));
+ if (ok) return; // TEST PASSED!
+ gpr_sleep_until(
+ grpc_timeout_milliseconds_to_deadline(kFetchIntervalMilliseconds));
+ }
+ FAIL() << "error_state not seen in CSDS responses";
+}
+
+TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpEndpointError) {
+ int kFetchConfigRetries = 3;
+ int kFetchIntervalMilliseconds = 200;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ EdsResourceArgs args({{"locality0", CreateEndpointsForBackends(0, 1)}});
+ balancers_[0]->ads_service()->SetEdsResource(
+ BuildEdsResource(args, DefaultEdsServiceName()));
+ // Ensure the xDS resolver has working configs.
+ CheckRpcSendOk();
+ // Bad endpoint config will be rejected.
+ ClusterLoadAssignment cluster_load_assignment; + cluster_load_assignment.set_cluster_name(kDefaultEdsServiceName); + auto* endpoints = cluster_load_assignment.add_endpoints(); + endpoints->mutable_load_balancing_weight()->set_value(1); + auto* endpoint = endpoints->add_lb_endpoints()->mutable_endpoint(); + endpoint->mutable_address()->mutable_socket_address()->set_port_value(1 << 1); + balancers_[0]->ads_service()->SetEdsResource(cluster_load_assignment); + // The old xDS configs should still be effective. + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendOk(); + for (int o = 0; o < kFetchConfigRetries; o++) { + auto csds_response = FetchCsdsResponse(); + // Check if error state is propagated + bool ok = ::testing::Value( + csds_response.config(0).xds_config(), + ::testing::Contains(::testing::Property( + &envoy::service::status::v3::PerXdsConfig::endpoint_config, + EqEndpointsConfigDump( + ::testing::ElementsAre(EqDynamicEndpointConfig( + kDefaultEdsServiceName, "1", ClientResourceStatus::NACKED, + backends_[0]->port(), kDefaultLocalityWeight, + EqUpdateFailureState(::testing::HasSubstr("Empty locality"), + "2"))))))); + if (ok) return; // TEST PASSED! + gpr_sleep_until( + grpc_timeout_milliseconds_to_deadline(kFetchIntervalMilliseconds)); + } + FAIL() << "error_state not seen in CSDS responses"; +} + +TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpListenerRequested) { + int kTimeoutMillisecond = 1000; + balancers_[0]->ads_service()->UnsetResource(kLdsTypeUrl, kServerName); + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_timeout_ms(kTimeoutMillisecond)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + auto csds_response = FetchCsdsResponse(); + EXPECT_THAT(csds_response.config(0).xds_config(), + ::testing::Contains(::testing::Property( + &envoy::service::status::v3::PerXdsConfig::listener_config, + EqListenersConfigDump( + ::testing::_, ::testing::ElementsAre(EqDynamicListener( + kServerName, ::testing::_, + ClientResourceStatus::REQUESTED, + ::testing::_, ::testing::_)))))); +} + +TEST_P(ClientStatusDiscoveryServiceTest, XdsConfigDumpClusterRequested) { + int kTimeoutMillisecond = 1000; + std::string kClusterName1 = "cluster-1"; + std::string kClusterName2 = "cluster-2"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Create a route config requesting two non-existing clusters + RouteConfiguration route_config; + route_config.set_name(kDefaultRouteConfigurationName); + auto* vh = route_config.add_virtual_hosts(); + // The VirtualHost must match the domain name, otherwise will cause resolver + // transient failure. 
+ vh->add_domains("*"); + auto* routes1 = vh->add_routes(); + routes1->mutable_match()->set_prefix(""); + routes1->mutable_route()->set_cluster(kClusterName1); + auto* routes2 = vh->add_routes(); + routes2->mutable_match()->set_prefix(""); + routes2->mutable_route()->set_cluster(kClusterName2); + SetRouteConfiguration(0, route_config); + // Try to get the configs plumb through + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_timeout_ms(kTimeoutMillisecond)) + .set_expected_error_code(StatusCode::DEADLINE_EXCEEDED)); + auto csds_response = FetchCsdsResponse(); + EXPECT_THAT(csds_response.config(0).xds_config(), + ::testing::Contains(::testing::Property( + &envoy::service::status::v3::PerXdsConfig::cluster_config, + EqClustersConfigDump( + ::testing::_, + ::testing::UnorderedElementsAre( + EqDynamicCluster(kClusterName1, ::testing::_, + ClientResourceStatus::REQUESTED, + ::testing::_), + EqDynamicCluster(kClusterName2, ::testing::_, + ClientResourceStatus::REQUESTED, + ::testing::_)))))); +} + +class CsdsShortAdsTimeoutTest : public ClientStatusDiscoveryServiceTest { + void SetUp() override { + // Shorten the ADS subscription timeout to speed up the test run. + xds_resource_does_not_exist_timeout_ms_ = 2000; + ClientStatusDiscoveryServiceTest::SetUp(); + } +}; + +TEST_P(CsdsShortAdsTimeoutTest, XdsConfigDumpListenerDoesNotExist) { + int kTimeoutMillisecond = 1000000; // 1000s wait for the transient failure. + balancers_[0]->ads_service()->UnsetResource(kLdsTypeUrl, kServerName); + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_timeout_ms(kTimeoutMillisecond)) + .set_expected_error_code(grpc::UNAVAILABLE)); + auto csds_response = FetchCsdsResponse(); + EXPECT_THAT(csds_response.config(0).xds_config(), + ::testing::Contains(::testing::Property( + &envoy::service::status::v3::PerXdsConfig::listener_config, + EqListenersConfigDump( + ::testing::_, ::testing::ElementsAre(EqDynamicListener( + kServerName, ::testing::_, + ClientResourceStatus::DOES_NOT_EXIST, + ::testing::_, ::testing::_)))))); +} + +TEST_P(CsdsShortAdsTimeoutTest, XdsConfigDumpRouteConfigDoesNotExist) { + if (!GetParam().enable_rds_testing()) return; + int kTimeoutMillisecond = 1000000; // 1000s wait for the transient failure. + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + balancers_[0]->ads_service()->UnsetResource(kRdsTypeUrl, + kDefaultRouteConfigurationName); + CheckRpcSendFailure( + CheckRpcSendFailureOptions() + .set_rpc_options(RpcOptions().set_timeout_ms(kTimeoutMillisecond)) + .set_expected_error_code(grpc::UNAVAILABLE)); + auto csds_response = FetchCsdsResponse(); + EXPECT_THAT( + csds_response.config(0).xds_config(), + ::testing::Contains(::testing::Property( + &envoy::service::status::v3::PerXdsConfig::route_config, + EqRoutesConfigDump(::testing::ElementsAre( + EqDynamicRouteConfig(kDefaultRouteConfigurationName, ::testing::_, + ClientResourceStatus::DOES_NOT_EXIST, + ::testing::_, ::testing::_)))))); +} + +TEST_P(CsdsShortAdsTimeoutTest, XdsConfigDumpClusterDoesNotExist) { + int kTimeoutMillisecond = 1000000; // 1000s wait for the transient failure. 
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ balancers_[0]->ads_service()->UnsetResource(kCdsTypeUrl, kDefaultClusterName);
+ CheckRpcSendFailure(
+ CheckRpcSendFailureOptions()
+ .set_rpc_options(RpcOptions().set_timeout_ms(kTimeoutMillisecond))
+ .set_expected_error_code(grpc::UNAVAILABLE));
+ auto csds_response = FetchCsdsResponse();
+ EXPECT_THAT(csds_response.config(0).xds_config(),
+ ::testing::Contains(::testing::Property(
+ &envoy::service::status::v3::PerXdsConfig::cluster_config,
+ EqClustersConfigDump(::testing::_,
+ ::testing::ElementsAre(EqDynamicCluster(
+ kDefaultClusterName, ::testing::_,
+ ClientResourceStatus::DOES_NOT_EXIST,
+ ::testing::_))))));
+}
+
+TEST_P(CsdsShortAdsTimeoutTest, XdsConfigDumpEndpointDoesNotExist) {
+ int kTimeoutMillisecond = 1000000; // 1000s wait for the transient failure.
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ balancers_[0]->ads_service()->UnsetResource(kEdsTypeUrl,
+ kDefaultEdsServiceName);
+ CheckRpcSendFailure(
+ CheckRpcSendFailureOptions()
+ .set_rpc_options(RpcOptions().set_timeout_ms(kTimeoutMillisecond))
+ .set_expected_error_code(grpc::UNAVAILABLE));
+ auto csds_response = FetchCsdsResponse();
+ EXPECT_THAT(
+ csds_response.config(0).xds_config(),
+ ::testing::Contains(::testing::Property(
+ &envoy::service::status::v3::PerXdsConfig::endpoint_config,
+ EqEndpointsConfigDump(::testing::ElementsAre(EqDynamicEndpointConfig(
+ kDefaultEdsServiceName, ::testing::_,
+ ClientResourceStatus::DOES_NOT_EXIST, ::testing::_, ::testing::_,
+ ::testing::_))))));
+}
+#endif // DISABLED_XDS_PROTO_IN_CC
+
+std::string TestTypeName(const ::testing::TestParamInfo& info) {
+ return info.param.AsString();
+}
+
+// Run with all combinations of xds/fake resolver and load reporting
+// enabled/disabled.
+INSTANTIATE_TEST_SUITE_P(
+ XdsTest, BasicTest,
+ ::testing::Values(
+ TestType(), TestType().set_enable_load_reporting(),
+ TestType().set_use_fake_resolver(),
+ TestType().set_use_fake_resolver().set_enable_load_reporting()),
+ &TestTypeName);
+
+// Run with both fake resolver and xds resolver.
+// Don't run with load reporting or v2 or RDS, since they are irrelevant to
+// the tests.
+INSTANTIATE_TEST_SUITE_P(XdsTest, SecureNamingTest,
+ ::testing::Values(TestType(),
+ TestType().set_use_fake_resolver()),
+ &TestTypeName);
+
+// LDS depends on XdsResolver.
+INSTANTIATE_TEST_SUITE_P(XdsTest, LdsTest, ::testing::Values(TestType()),
+ &TestTypeName);
+INSTANTIATE_TEST_SUITE_P(XdsTest, LdsV2Test,
+ ::testing::Values(TestType().set_use_v2()),
+ &TestTypeName);
+
+// LDS/RDS common tests depend on XdsResolver.
+INSTANTIATE_TEST_SUITE_P(
+ XdsTest, LdsRdsTest,
+ ::testing::Values(TestType(), TestType().set_enable_rds_testing(),
+ // Also test with xDS v2.
+ TestType().set_enable_rds_testing().set_use_v2()),
+ &TestTypeName);
+
+// CDS depends on XdsResolver.
+INSTANTIATE_TEST_SUITE_P(
+ XdsTest, CdsTest,
+ ::testing::Values(TestType(), TestType().set_enable_load_reporting()),
+ &TestTypeName);
+
+// CDS depends on XdsResolver.
+// Security depends on v3.
+// Not enabling load reporting or RDS, since those are irrelevant to these
+// tests.
+INSTANTIATE_TEST_SUITE_P(
+ XdsTest, XdsSecurityTest,
+ ::testing::Values(TestType().set_use_xds_credentials()), &TestTypeName);
+
+// We are only testing the server here.
+// Run with bootstrap from env var, so that we use a global XdsClient
+// instance. Otherwise, we would need to use a separate fake resolver
+// result generator on the client and server sides.
+INSTANTIATE_TEST_SUITE_P(XdsTest, XdsEnabledServerTest, + ::testing::Values(TestType().set_bootstrap_source( + TestType::kBootstrapFromEnvVar)), + &TestTypeName); + +// We are only testing the server here. +INSTANTIATE_TEST_SUITE_P(XdsTest, XdsServerSecurityTest, + ::testing::Values(TestType() + .set_use_fake_resolver() + .set_use_xds_credentials()), + &TestTypeName); + +// We are only testing the server here. +INSTANTIATE_TEST_SUITE_P(XdsTest, XdsEnabledServerStatusNotificationTest, + ::testing::Values(TestType() + .set_use_fake_resolver() + .set_use_xds_credentials()), + &TestTypeName); + +// We are only testing the server here. +INSTANTIATE_TEST_SUITE_P(XdsTest, XdsServerFilterChainMatchTest, + ::testing::Values(TestType() + .set_use_fake_resolver() + .set_use_xds_credentials()), + &TestTypeName); + +// EDS could be tested with or without XdsResolver, but the tests would +// be the same either way, so we test it only with XdsResolver. +INSTANTIATE_TEST_SUITE_P( + XdsTest, EdsTest, + ::testing::Values(TestType(), TestType().set_enable_load_reporting()), + &TestTypeName); + +// Test initial resource timeouts for each resource type. +// Do this only for XdsResolver with RDS enabled, so that we can test +// all resource types. +// Run with V3 only, since the functionality is no different in V2. +INSTANTIATE_TEST_SUITE_P(XdsTest, TimeoutTest, + ::testing::Values(TestType().set_enable_rds_testing()), + &TestTypeName); + +// XdsResolverOnlyTest depends on XdsResolver. +INSTANTIATE_TEST_SUITE_P( + XdsTest, XdsResolverOnlyTest, + ::testing::Values(TestType(), TestType().set_enable_load_reporting()), + &TestTypeName); + +// Runs with bootstrap from env var, so that there's a global XdsClient. +INSTANTIATE_TEST_SUITE_P( + XdsTest, GlobalXdsClientTest, + ::testing::Values( + TestType().set_bootstrap_source(TestType::kBootstrapFromEnvVar), + TestType() + .set_bootstrap_source(TestType::kBootstrapFromEnvVar) + .set_enable_load_reporting()), + &TestTypeName); + +// XdsResolverLoadReprtingOnlyTest depends on XdsResolver and load reporting. +INSTANTIATE_TEST_SUITE_P( + XdsTest, XdsResolverLoadReportingOnlyTest, + ::testing::Values(TestType().set_enable_load_reporting()), &TestTypeName); + +INSTANTIATE_TEST_SUITE_P( + XdsTest, LocalityMapTest, + ::testing::Values( + TestType(), TestType().set_enable_load_reporting(), + TestType().set_use_fake_resolver(), + TestType().set_use_fake_resolver().set_enable_load_reporting()), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P( + XdsTest, FailoverTest, + ::testing::Values( + TestType(), TestType().set_enable_load_reporting(), + TestType().set_use_fake_resolver(), + TestType().set_use_fake_resolver().set_enable_load_reporting()), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P( + XdsTest, DropTest, + ::testing::Values( + TestType(), TestType().set_enable_load_reporting(), + TestType().set_use_fake_resolver(), + TestType().set_use_fake_resolver().set_enable_load_reporting()), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P( + XdsTest, BalancerUpdateTest, + ::testing::Values( + TestType().set_use_fake_resolver(), + TestType().set_use_fake_resolver().set_enable_load_reporting(), + TestType().set_enable_load_reporting()), + &TestTypeName); + +// Load reporting tests are not run with load reporting disabled. 
+INSTANTIATE_TEST_SUITE_P( + XdsTest, ClientLoadReportingTest, + ::testing::Values( + TestType().set_enable_load_reporting(), + TestType().set_enable_load_reporting().set_use_fake_resolver()), + &TestTypeName); + +// Load reporting tests are not run with load reporting disabled. +INSTANTIATE_TEST_SUITE_P( + XdsTest, ClientLoadReportingWithDropTest, + ::testing::Values( + TestType().set_enable_load_reporting(), + TestType().set_enable_load_reporting().set_use_fake_resolver()), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P( + XdsTest, FaultInjectionTest, + ::testing::Values( + TestType(), TestType().set_enable_rds_testing(), + TestType().set_filter_config_setup( + TestType::FilterConfigSetup::kRouteOverride), + TestType().set_enable_rds_testing().set_filter_config_setup( + TestType::FilterConfigSetup::kRouteOverride)), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P( + XdsTest, BootstrapSourceTest, + ::testing::Values( + TestType().set_bootstrap_source(TestType::kBootstrapFromEnvVar), + TestType().set_bootstrap_source(TestType::kBootstrapFromFile)), + &TestTypeName); + +#ifndef DISABLED_XDS_PROTO_IN_CC +// Run CSDS tests with RDS enabled and disabled. +// These need to run with the bootstrap from an env var instead of from +// a channel arg, since there needs to be a global XdsClient instance. +INSTANTIATE_TEST_SUITE_P( + XdsTest, ClientStatusDiscoveryServiceTest, + ::testing::Values( + TestType().set_bootstrap_source(TestType::kBootstrapFromEnvVar), + TestType() + .set_bootstrap_source(TestType::kBootstrapFromEnvVar) + .set_enable_rds_testing(), + TestType() + .set_bootstrap_source(TestType::kBootstrapFromEnvVar) + .set_use_csds_streaming(), + TestType() + .set_bootstrap_source(TestType::kBootstrapFromEnvVar) + .set_enable_rds_testing() + .set_use_csds_streaming()), + &TestTypeName); +INSTANTIATE_TEST_SUITE_P( + XdsTest, CsdsShortAdsTimeoutTest, + ::testing::Values( + TestType().set_bootstrap_source(TestType::kBootstrapFromEnvVar), + TestType() + .set_bootstrap_source(TestType::kBootstrapFromEnvVar) + .set_enable_rds_testing(), + TestType() + .set_bootstrap_source(TestType::kBootstrapFromEnvVar) + .set_use_csds_streaming(), + TestType() + .set_bootstrap_source(TestType::kBootstrapFromEnvVar) + .set_enable_rds_testing() + .set_use_csds_streaming()), + &TestTypeName); +#endif // DISABLED_XDS_PROTO_IN_CC + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::WriteBootstrapFiles(); + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. 
+ GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + grpc_core::CertificateProviderRegistry::RegisterCertificateProviderFactory( + absl::make_unique( + "fake1", &grpc::testing::g_fake1_cert_data_map)); + grpc_core::CertificateProviderRegistry::RegisterCertificateProviderFactory( + absl::make_unique( + "fake2", &grpc::testing::g_fake2_cert_data_map)); + grpc_init(); + grpc_core::XdsHttpFilterRegistry::RegisterFilter( + absl::make_unique( + "grpc.testing.client_only_http_filter", true, false), + {"grpc.testing.client_only_http_filter"}); + grpc_core::XdsHttpFilterRegistry::RegisterFilter( + absl::make_unique( + "grpc.testing.server_only_http_filter", false, true), + {"grpc.testing.server_only_http_filter"}); + const auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/cpp/end2end/xds/xds_server.cc b/test/cpp/end2end/xds/xds_server.cc new file mode 100644 index 00000000..f9ac5a18 --- /dev/null +++ b/test/cpp/end2end/xds/xds_server.cc @@ -0,0 +1,257 @@ +// +// Copyright 2017 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "test/cpp/end2end/xds/xds_server.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/types/optional.h" + +#include + +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/proto/grpc/testing/xds/ads_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/lrs_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/ads.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/discovery.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/lrs.grpc.pb.h" + +namespace grpc { +namespace testing { + +// +// AdsServiceImpl +// + +void AdsServiceImpl::SetResource(google::protobuf::Any resource, + const std::string& type_url, + const std::string& name) { + grpc_core::MutexLock lock(&ads_mu_); + ResourceTypeState& resource_type_state = resource_map_[type_url]; + ++resource_type_state.resource_type_version; + ResourceState& resource_state = resource_type_state.resource_name_map[name]; + resource_state.resource_type_version = + resource_type_state.resource_type_version; + resource_state.resource = std::move(resource); + gpr_log(GPR_INFO, + "ADS[%p]: Updating %s resource %s; resource_type_version now %u", + this, type_url.c_str(), name.c_str(), + resource_type_state.resource_type_version); + for (SubscriptionState* subscription : resource_state.subscriptions) { + subscription->update_queue->emplace_back(type_url, name); + } +} + +void AdsServiceImpl::UnsetResource(const std::string& type_url, + const std::string& name) { + grpc_core::MutexLock lock(&ads_mu_); + ResourceTypeState& resource_type_state = resource_map_[type_url]; + ++resource_type_state.resource_type_version; + ResourceState& resource_state = resource_type_state.resource_name_map[name]; + resource_state.resource_type_version = + 
resource_type_state.resource_type_version; + resource_state.resource.reset(); + gpr_log(GPR_INFO, + "ADS[%p]: Unsetting %s resource %s; resource_type_version now %u", + this, type_url.c_str(), name.c_str(), + resource_type_state.resource_type_version); + for (SubscriptionState* subscription : resource_state.subscriptions) { + subscription->update_queue->emplace_back(type_url, name); + } +} + +// Checks whether the client needs to receive a newer version of +// the resource. +bool AdsServiceImpl::ClientNeedsResourceUpdate( + const ResourceTypeState& resource_type_state, + const ResourceState& resource_state, int client_resource_type_version) { + return client_resource_type_version < + resource_type_state.resource_type_version && + resource_state.resource_type_version <= + resource_type_state.resource_type_version; +} + +// Subscribes to a resource if not already subscribed: +// 1. Sets the update_queue field in subscription_state. +// 2. Adds subscription_state to resource_state->subscriptions. +bool AdsServiceImpl::MaybeSubscribe(const std::string& resource_type, + const std::string& resource_name, + SubscriptionState* subscription_state, + ResourceState* resource_state, + UpdateQueue* update_queue) { + // The update_queue will be null if we were not previously subscribed. + if (subscription_state->update_queue != nullptr) return false; + subscription_state->update_queue = update_queue; + resource_state->subscriptions.emplace(subscription_state); + gpr_log(GPR_INFO, "ADS[%p]: subscribe to resource type %s name %s state %p", + this, resource_type.c_str(), resource_name.c_str(), + &subscription_state); + return true; +} + +// Removes subscriptions for resources no longer present in the +// current request. +void AdsServiceImpl::ProcessUnsubscriptions( + const std::string& resource_type, + const std::set& resources_in_current_request, + SubscriptionNameMap* subscription_name_map, + ResourceNameMap* resource_name_map) { + for (auto it = subscription_name_map->begin(); + it != subscription_name_map->end();) { + const std::string& resource_name = it->first; + SubscriptionState& subscription_state = it->second; + if (resources_in_current_request.find(resource_name) != + resources_in_current_request.end()) { + ++it; + continue; + } + gpr_log(GPR_INFO, "ADS[%p]: Unsubscribe to type=%s name=%s state=%p", this, + resource_type.c_str(), resource_name.c_str(), &subscription_state); + auto resource_it = resource_name_map->find(resource_name); + GPR_ASSERT(resource_it != resource_name_map->end()); + auto& resource_state = resource_it->second; + resource_state.subscriptions.erase(&subscription_state); + if (resource_state.subscriptions.empty() && + !resource_state.resource.has_value()) { + resource_name_map->erase(resource_it); + } + it = subscription_name_map->erase(it); + } +} + +void AdsServiceImpl::Start() { + grpc_core::MutexLock lock(&ads_mu_); + ads_done_ = false; +} + +void AdsServiceImpl::Shutdown() { + { + grpc_core::MutexLock lock(&ads_mu_); + if (!ads_done_) { + ads_done_ = true; + ads_cond_.SignalAll(); + } + resource_type_response_state_.clear(); + } + gpr_log(GPR_INFO, "ADS[%p]: shut down", this); +} + +// +// LrsServiceImpl::ClientStats +// + +uint64_t LrsServiceImpl::ClientStats::total_successful_requests() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += p.second.total_successful_requests; + } + return sum; +} + +uint64_t LrsServiceImpl::ClientStats::total_requests_in_progress() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += 
p.second.total_requests_in_progress; + } + return sum; +} + +uint64_t LrsServiceImpl::ClientStats::total_error_requests() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += p.second.total_error_requests; + } + return sum; +} + +uint64_t LrsServiceImpl::ClientStats::total_issued_requests() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += p.second.total_issued_requests; + } + return sum; +} + +uint64_t LrsServiceImpl::ClientStats::dropped_requests( + const std::string& category) const { + auto iter = dropped_requests_.find(category); + GPR_ASSERT(iter != dropped_requests_.end()); + return iter->second; +} + +LrsServiceImpl::ClientStats& LrsServiceImpl::ClientStats::operator+=( + const ClientStats& other) { + for (const auto& p : other.locality_stats_) { + locality_stats_[p.first] += p.second; + } + total_dropped_requests_ += other.total_dropped_requests_; + for (const auto& p : other.dropped_requests_) { + dropped_requests_[p.first] += p.second; + } + return *this; +} + +// +// LrsServiceImpl +// + +void LrsServiceImpl::Start() { + { + grpc_core::MutexLock lock(&lrs_mu_); + lrs_done_ = false; + } + { + grpc_core::MutexLock lock(&load_report_mu_); + result_queue_.clear(); + } +} + +void LrsServiceImpl::Shutdown() { + { + grpc_core::MutexLock lock(&lrs_mu_); + if (!lrs_done_) { + lrs_done_ = true; + lrs_cv_.SignalAll(); + } + } + gpr_log(GPR_INFO, "LRS[%p]: shut down", this); +} + +std::vector LrsServiceImpl::WaitForLoadReport() { + grpc_core::MutexLock lock(&load_report_mu_); + grpc_core::CondVar cv; + if (result_queue_.empty()) { + load_report_cond_ = &cv; + while (result_queue_.empty()) { + cv.Wait(&load_report_mu_); + } + load_report_cond_ = nullptr; + } + std::vector result = std::move(result_queue_.front()); + result_queue_.pop_front(); + return result; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/ext/filters/census/stats_plugin_end2end_test.cc b/test/cpp/ext/filters/census/stats_plugin_end2end_test.cc new file mode 100644 index 00000000..a39ae72f --- /dev/null +++ b/test/cpp/ext/filters/census/stats_plugin_end2end_test.cc @@ -0,0 +1,559 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include // NOLINT +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "opencensus/stats/stats.h" +#include "opencensus/stats/tag_key.h" +#include "opencensus/stats/testing/test_utils.h" +#include "opencensus/tags/tag_map.h" +#include "opencensus/tags/with_tag_map.h" + +#include +#include + +#include "src/cpp/ext/filters/census/context.h" +#include "src/cpp/ext/filters/census/grpc_plugin.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +using ::opencensus::stats::Aggregation; +using ::opencensus::stats::Distribution; +using ::opencensus::stats::View; +using ::opencensus::stats::ViewDescriptor; +using ::opencensus::stats::testing::TestUtils; +using ::opencensus::tags::TagKey; +using ::opencensus::tags::WithTagMap; + +static const auto TEST_TAG_KEY = TagKey::Register("my_key"); +static const auto TEST_TAG_VALUE = "my_value"; +const char* kExpectedTraceIdKey = "expected_trace_id"; + +class EchoServer final : public EchoTestService::Service { + ::grpc::Status Echo(::grpc::ServerContext* context, + const EchoRequest* request, + EchoResponse* response) override { + for (const auto& metadata : context->client_metadata()) { + if (metadata.first == kExpectedTraceIdKey) { + EXPECT_EQ(metadata.second, reinterpret_cast( + context->census_context()) + ->Span() + .context() + .trace_id() + .ToHex()); + break; + } + } + if (request->param().expected_error().code() == 0) { + response->set_message(request->message()); + return ::grpc::Status::OK; + } else { + return ::grpc::Status(static_cast<::grpc::StatusCode>( + request->param().expected_error().code()), + ""); + } + } +}; + +class StatsPluginEnd2EndTest : public ::testing::Test { + protected: + static void SetUpTestCase() { RegisterOpenCensusPlugin(); } + + void SetUp() override { + // Set up a synchronous server on a different thread to avoid the asynch + // interface. + ::grpc::ServerBuilder builder; + int port; + // Use IPv4 here because it's less flaky than IPv6 ("[::]:0") on Travis. 
+ builder.AddListeningPort("0.0.0.0:0", ::grpc::InsecureServerCredentials(), + &port); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + ASSERT_NE(nullptr, server_); + ASSERT_NE(0, port); + server_address_ = absl::StrCat("localhost:", port); + server_thread_ = std::thread(&StatsPluginEnd2EndTest::RunServerLoop, this); + + stub_ = EchoTestService::NewStub(::grpc::CreateChannel( + server_address_, ::grpc::InsecureChannelCredentials())); + } + + void ResetStub(std::shared_ptr channel) { + stub_ = EchoTestService::NewStub(channel); + } + + void TearDown() override { + server_->Shutdown(); + server_thread_.join(); + } + + void RunServerLoop() { server_->Wait(); } + + const std::string client_method_name_ = "grpc.testing.EchoTestService/Echo"; + const std::string server_method_name_ = "grpc.testing.EchoTestService/Echo"; + + std::string server_address_; + EchoServer service_; + std::unique_ptr server_; + std::thread server_thread_; + + std::unique_ptr stub_; +}; + +TEST_F(StatsPluginEnd2EndTest, ErrorCount) { + const auto client_method_descriptor = + ViewDescriptor() + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_name("client_method") + .set_aggregation(Aggregation::Count()) + .add_column(ClientMethodTagKey()) + .add_column(TEST_TAG_KEY); + View client_method_view(client_method_descriptor); + const auto server_method_descriptor = + ViewDescriptor() + .set_measure(kRpcServerServerLatencyMeasureName) + .set_name("server_method") + .set_aggregation(Aggregation::Count()) + .add_column(ServerMethodTagKey()); + //.add_column(TEST_TAG_KEY); + View server_method_view(server_method_descriptor); + + const auto client_status_descriptor = + ViewDescriptor() + .set_measure(kRpcClientRoundtripLatencyMeasureName) + .set_name("client_status") + .set_aggregation(Aggregation::Count()) + .add_column(ClientStatusTagKey()) + .add_column(TEST_TAG_KEY); + View client_status_view(client_status_descriptor); + const auto server_status_descriptor = + ViewDescriptor() + .set_measure(kRpcServerServerLatencyMeasureName) + .set_name("server_status") + .set_aggregation(Aggregation::Count()) + .add_column(ServerStatusTagKey()); + View server_status_view(server_status_descriptor); + + // Cover all valid statuses. + for (int i = 0; i <= 16; ++i) { + EchoRequest request; + request.set_message("foo"); + request.mutable_param()->mutable_expected_error()->set_code(i); + EchoResponse response; + ::grpc::ClientContext context; + { + WithTagMap tags({{TEST_TAG_KEY, TEST_TAG_VALUE}}); + ::grpc::Status status = stub_->Echo(&context, request, &response); + } + } + absl::SleepFor(absl::Milliseconds(500)); + TestUtils::Flush(); + + // Client side views can be tagged with custom tags. + EXPECT_THAT( + client_method_view.GetData().int_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_, TEST_TAG_VALUE), 17))); + // TODO(unknown): Implement server view tagging with custom tags. + EXPECT_THAT(server_method_view.GetData().int_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(server_method_name_), 17))); + + // Client side views can be tagged with custom tags. 
+ auto client_tags = { + ::testing::Pair(::testing::ElementsAre("OK", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("CANCELLED", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("UNKNOWN", TEST_TAG_VALUE), 1), + ::testing::Pair( + ::testing::ElementsAre("INVALID_ARGUMENT", TEST_TAG_VALUE), 1), + ::testing::Pair( + ::testing::ElementsAre("DEADLINE_EXCEEDED", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("NOT_FOUND", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("ALREADY_EXISTS", TEST_TAG_VALUE), + 1), + ::testing::Pair( + ::testing::ElementsAre("PERMISSION_DENIED", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("UNAUTHENTICATED", TEST_TAG_VALUE), + 1), + ::testing::Pair( + ::testing::ElementsAre("RESOURCE_EXHAUSTED", TEST_TAG_VALUE), 1), + ::testing::Pair( + ::testing::ElementsAre("FAILED_PRECONDITION", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("ABORTED", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("OUT_OF_RANGE", TEST_TAG_VALUE), + 1), + ::testing::Pair(::testing::ElementsAre("UNIMPLEMENTED", TEST_TAG_VALUE), + 1), + ::testing::Pair(::testing::ElementsAre("INTERNAL", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("UNAVAILABLE", TEST_TAG_VALUE), 1), + ::testing::Pair(::testing::ElementsAre("DATA_LOSS", TEST_TAG_VALUE), 1), + }; + + // TODO(unknown): Implement server view tagging with custom tags. + auto server_tags = { + ::testing::Pair(::testing::ElementsAre("OK"), 1), + ::testing::Pair(::testing::ElementsAre("CANCELLED"), 1), + ::testing::Pair(::testing::ElementsAre("UNKNOWN"), 1), + ::testing::Pair(::testing::ElementsAre("INVALID_ARGUMENT"), 1), + ::testing::Pair(::testing::ElementsAre("DEADLINE_EXCEEDED"), 1), + ::testing::Pair(::testing::ElementsAre("NOT_FOUND"), 1), + ::testing::Pair(::testing::ElementsAre("ALREADY_EXISTS"), 1), + ::testing::Pair(::testing::ElementsAre("PERMISSION_DENIED"), 1), + ::testing::Pair(::testing::ElementsAre("UNAUTHENTICATED"), 1), + ::testing::Pair(::testing::ElementsAre("RESOURCE_EXHAUSTED"), 1), + ::testing::Pair(::testing::ElementsAre("FAILED_PRECONDITION"), 1), + ::testing::Pair(::testing::ElementsAre("ABORTED"), 1), + ::testing::Pair(::testing::ElementsAre("OUT_OF_RANGE"), 1), + ::testing::Pair(::testing::ElementsAre("UNIMPLEMENTED"), 1), + ::testing::Pair(::testing::ElementsAre("INTERNAL"), 1), + ::testing::Pair(::testing::ElementsAre("UNAVAILABLE"), 1), + ::testing::Pair(::testing::ElementsAre("DATA_LOSS"), 1), + }; + + EXPECT_THAT(client_status_view.GetData().int_data(), + ::testing::UnorderedElementsAreArray(client_tags)); + EXPECT_THAT(server_status_view.GetData().int_data(), + ::testing::UnorderedElementsAreArray(server_tags)); +} + +TEST_F(StatsPluginEnd2EndTest, RequestReceivedBytesPerRpc) { + View client_sent_bytes_per_rpc_view(ClientSentBytesPerRpcCumulative()); + View client_received_bytes_per_rpc_view( + ClientReceivedBytesPerRpcCumulative()); + View server_sent_bytes_per_rpc_view(ServerSentBytesPerRpcCumulative()); + View server_received_bytes_per_rpc_view( + ServerReceivedBytesPerRpcCumulative()); + + { + EchoRequest request; + request.set_message("foo"); + EchoResponse response; + ::grpc::ClientContext context; + ::grpc::Status status = stub_->Echo(&context, request, &response); + ASSERT_TRUE(status.ok()); + EXPECT_EQ("foo", response.message()); + } + absl::SleepFor(absl::Milliseconds(500)); + TestUtils::Flush(); + + EXPECT_THAT(client_received_bytes_per_rpc_view.GetData().distribution_data(), + 
::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, 1), + ::testing::Property(&Distribution::mean, + ::testing::Gt(0.0)))))); + EXPECT_THAT(client_sent_bytes_per_rpc_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, 1), + ::testing::Property(&Distribution::mean, + ::testing::Gt(0.0)))))); + EXPECT_THAT(server_received_bytes_per_rpc_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(server_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, 1), + ::testing::Property(&Distribution::mean, + ::testing::Gt(0.0)))))); + EXPECT_THAT(server_sent_bytes_per_rpc_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(server_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, 1), + ::testing::Property(&Distribution::mean, + ::testing::Gt(0.0)))))); +} + +TEST_F(StatsPluginEnd2EndTest, Latency) { + View client_latency_view(ClientRoundtripLatencyCumulative()); + View client_server_latency_view(ClientServerLatencyCumulative()); + View server_server_latency_view(ServerServerLatencyCumulative()); + + const absl::Time start_time = absl::Now(); + { + EchoRequest request; + request.set_message("foo"); + EchoResponse response; + ::grpc::ClientContext context; + ::grpc::Status status = stub_->Echo(&context, request, &response); + ASSERT_TRUE(status.ok()); + EXPECT_EQ("foo", response.message()); + } + // We do not know exact latency/elapsed time, but we know it is less than the + // entire time spent making the RPC. + const double max_time = absl::ToDoubleMilliseconds(absl::Now() - start_time); + + absl::SleepFor(absl::Milliseconds(500)); + TestUtils::Flush(); + + EXPECT_THAT( + client_latency_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::AllOf( + ::testing::Property(&Distribution::count, 1), + ::testing::Property(&Distribution::mean, ::testing::Gt(0.0)), + ::testing::Property(&Distribution::mean, + ::testing::Lt(max_time)))))); + + // Elapsed time is a subinterval of total latency. + const auto client_latency = client_latency_view.GetData() + .distribution_data() + .find({client_method_name_}) + ->second.mean(); + EXPECT_THAT( + client_server_latency_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::AllOf( + ::testing::Property(&Distribution::count, 1), + ::testing::Property(&Distribution::mean, ::testing::Gt(0.0)), + ::testing::Property(&Distribution::mean, + ::testing::Lt(client_latency)))))); + + // client server elapsed time should be the same value propagated to the + // client. 
+ const auto client_elapsed_time = client_server_latency_view.GetData() + .distribution_data() + .find({client_method_name_}) + ->second.mean(); + EXPECT_THAT( + server_server_latency_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(server_method_name_), + ::testing::AllOf( + ::testing::Property(&Distribution::count, 1), + ::testing::Property(&Distribution::mean, + ::testing::DoubleEq(client_elapsed_time)))))); +} + +TEST_F(StatsPluginEnd2EndTest, CompletedRpcs) { + View client_completed_rpcs_view(ClientCompletedRpcsCumulative()); + View server_completed_rpcs_view(ServerCompletedRpcsCumulative()); + + EchoRequest request; + request.set_message("foo"); + EchoResponse response; + const int count = 5; + for (int i = 0; i < count; ++i) { + { + ::grpc::ClientContext context; + ::grpc::Status status = stub_->Echo(&context, request, &response); + ASSERT_TRUE(status.ok()); + EXPECT_EQ("foo", response.message()); + } + absl::SleepFor(absl::Milliseconds(500)); + TestUtils::Flush(); + + EXPECT_THAT(client_completed_rpcs_view.GetData().int_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_, "OK"), i + 1))); + EXPECT_THAT(server_completed_rpcs_view.GetData().int_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(server_method_name_, "OK"), i + 1))); + } +} + +TEST_F(StatsPluginEnd2EndTest, RequestReceivedMessagesPerRpc) { + // TODO(unknown): Use streaming RPCs. + View client_received_messages_per_rpc_view( + ClientSentMessagesPerRpcCumulative()); + View client_sent_messages_per_rpc_view( + ClientReceivedMessagesPerRpcCumulative()); + View server_received_messages_per_rpc_view( + ServerSentMessagesPerRpcCumulative()); + View server_sent_messages_per_rpc_view( + ServerReceivedMessagesPerRpcCumulative()); + + EchoRequest request; + request.set_message("foo"); + EchoResponse response; + const int count = 5; + for (int i = 0; i < count; ++i) { + { + ::grpc::ClientContext context; + ::grpc::Status status = stub_->Echo(&context, request, &response); + ASSERT_TRUE(status.ok()); + EXPECT_EQ("foo", response.message()); + } + absl::SleepFor(absl::Milliseconds(500)); + TestUtils::Flush(); + + EXPECT_THAT( + client_received_messages_per_rpc_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, i + 1), + ::testing::Property(&Distribution::mean, + ::testing::DoubleEq(1.0)))))); + EXPECT_THAT( + client_sent_messages_per_rpc_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, i + 1), + ::testing::Property(&Distribution::mean, + ::testing::DoubleEq(1.0)))))); + EXPECT_THAT( + server_received_messages_per_rpc_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(server_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, i + 1), + ::testing::Property(&Distribution::mean, + ::testing::DoubleEq(1.0)))))); + EXPECT_THAT( + server_sent_messages_per_rpc_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(server_method_name_), + ::testing::AllOf(::testing::Property(&Distribution::count, i + 1), + ::testing::Property(&Distribution::mean, + 
::testing::DoubleEq(1.0)))))); + } +} + +TEST_F(StatsPluginEnd2EndTest, TestRetryStatsWithoutAdditionalRetries) { + View client_retries_cumulative_view(ClientRetriesCumulative()); + View client_transparent_retries_cumulative_view( + ClientTransparentRetriesCumulative()); + View client_retry_delay_per_call_view(ClientRetryDelayPerCallCumulative()); + EchoRequest request; + request.set_message("foo"); + EchoResponse response; + const int count = 5; + for (int i = 0; i < count; ++i) { + { + ::grpc::ClientContext context; + ::grpc::Status status = stub_->Echo(&context, request, &response); + ASSERT_TRUE(status.ok()); + EXPECT_EQ("foo", response.message()); + } + absl::SleepFor(absl::Milliseconds(500)); + TestUtils::Flush(); + EXPECT_THAT( + client_retries_cumulative_view.GetData().int_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), ::testing::Eq(0)))); + EXPECT_THAT( + client_transparent_retries_cumulative_view.GetData().int_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), ::testing::Eq(0)))); + EXPECT_THAT( + client_retry_delay_per_call_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::Property(&Distribution::mean, ::testing::Eq(0))))); + } +} + +TEST_F(StatsPluginEnd2EndTest, TestRetryStatsWithAdditionalRetries) { + View client_retries_cumulative_view(ClientRetriesCumulative()); + View client_transparent_retries_cumulative_view( + ClientTransparentRetriesCumulative()); + View client_retry_delay_per_call_view(ClientRetryDelayPerCallCumulative()); + ChannelArguments args; + args.SetInt(GRPC_ARG_ENABLE_RETRIES, 1); + args.SetString(GRPC_ARG_SERVICE_CONFIG, + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"grpc.testing.EchoTestService\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"0.1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"); + auto channel = + CreateCustomChannel(server_address_, InsecureChannelCredentials(), args); + ResetStub(channel); + EchoRequest request; + request.mutable_param()->mutable_expected_error()->set_code( + StatusCode::ABORTED); + request.set_message("foo"); + EchoResponse response; + const int count = 5; + for (int i = 0; i < count; ++i) { + { + ::grpc::ClientContext context; + ::grpc::Status status = stub_->Echo(&context, request, &response); + EXPECT_EQ(status.error_code(), StatusCode::ABORTED); + } + absl::SleepFor(absl::Milliseconds(500)); + TestUtils::Flush(); + EXPECT_THAT(client_retries_cumulative_view.GetData().int_data(), + ::testing::UnorderedElementsAre( + ::testing::Pair(::testing::ElementsAre(client_method_name_), + ::testing::Eq((i + 1) * 2)))); + EXPECT_THAT( + client_transparent_retries_cumulative_view.GetData().int_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), ::testing::Eq(0)))); + auto data = client_retry_delay_per_call_view.GetData().distribution_data(); + for (const auto& entry : data) { + gpr_log(GPR_ERROR, "Mean Retry Delay %s: %lf ms", entry.first[0].c_str(), + entry.second.mean()); + } + // We expect the retry delay to be around 100ms. 
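+    // (Per the gRPC retry design, each retry waits a randomized delay between
+    // 0 and the current backoff; with maxAttempts: 3, initialBackoff: 0.1s and
+    // backoffMultiplier: 1, every ABORTED call performs two retries of up to
+    // ~100ms each, so the per-call mean should land near 100ms. The 50-300ms
+    // bounds below leave room for that randomness.)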
+ EXPECT_THAT( + client_retry_delay_per_call_view.GetData().distribution_data(), + ::testing::UnorderedElementsAre(::testing::Pair( + ::testing::ElementsAre(client_method_name_), + ::testing::Property( + &Distribution::mean, + ::testing::AllOf(::testing::Ge(50), ::testing::Le(300)))))); + } +} + +// Test that CensusContext object set by application is used. +TEST_F(StatsPluginEnd2EndTest, TestApplicationCensusContextFlows) { + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + ResetStub(channel); + EchoRequest request; + request.set_message("foo"); + EchoResponse response; + ::grpc::ClientContext context; + ::grpc::CensusContext app_census_context("root", + ::opencensus::tags::TagMap{}); + context.set_census_context( + reinterpret_cast(&app_census_context)); + context.AddMetadata(kExpectedTraceIdKey, + app_census_context.Span().context().trace_id().ToHex()); + ::grpc::Status status = stub_->Echo(&context, request, &response); + EXPECT_TRUE(status.ok()); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/grpclb/grpclb_api_test.cc b/test/cpp/grpclb/grpclb_api_test.cc new file mode 100644 index 00000000..6e2d19f3 --- /dev/null +++ b/test/cpp/grpclb/grpclb_api_test.cc @@ -0,0 +1,145 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "google/protobuf/duration.upb.h" +#include "upb/upb.hpp" + +#include +#include + +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/event_engine/sockaddr.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/proto/grpc/lb/v1/load_balancer.pb.h" // C++ version +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +using grpc::lb::v1::LoadBalanceRequest; +using grpc::lb::v1::LoadBalanceResponse; + +class GrpclbTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } +}; + +std::string Ip4ToPackedString(const char* ip_str) { + struct in_addr ip4; + GPR_ASSERT(inet_pton(AF_INET, ip_str, &ip4) == 1); + return std::string(reinterpret_cast(&ip4), sizeof(ip4)); +} + +std::string PackedStringToIp(const grpc_core::GrpcLbServer& server) { + char ip_str[46] = {0}; + int af = -1; + if (server.ip_size == 4) { + af = AF_INET; + } else if (server.ip_size == 16) { + af = AF_INET6; + } else { + abort(); + } + GPR_ASSERT(inet_ntop(af, (void*)server.ip_addr, ip_str, 46) != nullptr); + return ip_str; +} + +TEST_F(GrpclbTest, CreateRequest) { + const std::string service_name = "AServiceName"; + LoadBalanceRequest request; + upb::Arena arena; + grpc_slice slice = + grpc_core::GrpcLbRequestCreate(service_name.c_str(), arena.ptr()); + const int num_bytes_written = GRPC_SLICE_LENGTH(slice); + EXPECT_GT(num_bytes_written, 0); + request.ParseFromArray(GRPC_SLICE_START_PTR(slice), num_bytes_written); + EXPECT_EQ(request.initial_request().name(), service_name); + grpc_slice_unref(slice); +} + +TEST_F(GrpclbTest, ParseInitialResponse) { + // Construct response to parse. + LoadBalanceResponse response; + auto* initial_response = response.mutable_initial_response(); + auto* client_stats_report_interval = + initial_response->mutable_client_stats_report_interval(); + client_stats_report_interval->set_seconds(123); + client_stats_report_interval->set_nanos(456000000); + const std::string encoded_response = response.SerializeAsString(); + grpc_slice encoded_slice = + grpc_slice_from_copied_string(encoded_response.c_str()); + // Test parsing. + grpc_core::GrpcLbResponse resp; + upb::Arena arena; + ASSERT_TRUE( + grpc_core::GrpcLbResponseParse(encoded_slice, arena.ptr(), &resp)); + grpc_slice_unref(encoded_slice); + EXPECT_EQ(resp.type, resp.INITIAL); + EXPECT_EQ(resp.client_stats_report_interval, 123456); + EXPECT_EQ(resp.serverlist.size(), 0); +} + +TEST_F(GrpclbTest, ParseResponseServerList) { + // Construct response to parse. + LoadBalanceResponse response; + auto* serverlist = response.mutable_server_list(); + auto* server = serverlist->add_servers(); + server->set_ip_address(Ip4ToPackedString("127.0.0.1")); + server->set_port(12345); + server->set_load_balance_token("rate_limting"); + server->set_drop(true); + server = response.mutable_server_list()->add_servers(); + server->set_ip_address(Ip4ToPackedString("10.0.0.1")); + server->set_port(54321); + server->set_load_balance_token("load_balancing"); + server->set_drop(true); + const std::string encoded_response = response.SerializeAsString(); + const grpc_slice encoded_slice = grpc_slice_from_copied_buffer( + encoded_response.data(), encoded_response.size()); + // Test parsing. 
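+  // GrpcLbResponseParse should identify this as a SERVERLIST response and
+  // round-trip both servers: packed IPv4 address, port, LB token and drop bit.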
+ grpc_core::GrpcLbResponse resp; + upb::Arena arena; + ASSERT_TRUE( + grpc_core::GrpcLbResponseParse(encoded_slice, arena.ptr(), &resp)); + grpc_slice_unref(encoded_slice); + EXPECT_EQ(resp.type, resp.SERVERLIST); + EXPECT_EQ(resp.serverlist.size(), 2); + EXPECT_EQ(PackedStringToIp(resp.serverlist[0]), "127.0.0.1"); + EXPECT_EQ(resp.serverlist[0].port, 12345); + EXPECT_STREQ(resp.serverlist[0].load_balance_token, "rate_limting"); + EXPECT_TRUE(resp.serverlist[0].drop); + EXPECT_EQ(PackedStringToIp(resp.serverlist[1]), "10.0.0.1"); + EXPECT_EQ(resp.serverlist[1].port, 54321); + EXPECT_STREQ(resp.serverlist[1].load_balance_token, "load_balancing"); + EXPECT_TRUE(resp.serverlist[1].drop); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/interop/client.cc b/test/cpp/interop/client.cc new file mode 100644 index 00000000..ce7b1499 --- /dev/null +++ b/test/cpp/interop/client.cc @@ -0,0 +1,322 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "test/core/util/test_config.h" +#include "test/cpp/interop/client_helper.h" +#include "test/cpp/interop/interop_client.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(bool, use_alts, false, + "Whether to use alts. Enable alts will disable tls."); +ABSL_FLAG(bool, use_tls, false, "Whether to use tls."); +ABSL_FLAG(std::string, custom_credentials_type, "", + "User provided credentials type."); +ABSL_FLAG(bool, use_test_ca, false, "False to use SSL roots for google"); +ABSL_FLAG(int32_t, server_port, 0, "Server port."); +ABSL_FLAG(std::string, server_host, "localhost", "Server host to connect to"); +ABSL_FLAG(std::string, server_host_override, "", + "Override the server host which is sent in HTTP header"); +ABSL_FLAG( + std::string, test_case, "large_unary", + "Configure different test cases. Valid options are:\n\n" + "all : all test cases;\n" + + // TODO(veblush): Replace the help message with the following full message + // once Abseil fixes the flag-help compiler error on Windows. 
(b/171659833) + /* + "cancel_after_begin : cancel stream after starting it;\n" + "cancel_after_first_response: cancel on first response;\n" + "channel_soak: sends 'soak_iterations' rpcs, rebuilds channel each time;\n" + "client_compressed_streaming : compressed request streaming with " + "client_compressed_unary : single compressed request;\n" + "client_streaming : request streaming with single response;\n" + "compute_engine_creds: large_unary with compute engine auth;\n" + "custom_metadata: server will echo custom metadata;\n" + "empty_stream : bi-di stream with no request/response;\n" + "empty_unary : empty (zero bytes) request and response;\n" + "google_default_credentials: large unary using GDC;\n" + "half_duplex : half-duplex streaming;\n" + "jwt_token_creds: large_unary with JWT token auth;\n" + "large_unary : single request and (large) response;\n" + "long_lived_channel: sends large_unary rpcs over a long-lived channel;\n" + "oauth2_auth_token: raw oauth2 access token auth;\n" + "per_rpc_creds: raw oauth2 access token on a single rpc;\n" + "ping_pong : full-duplex streaming;\n" + "response streaming;\n" + "rpc_soak: 'sends soak_iterations' large_unary rpcs;\n" + "server_compressed_streaming : single request with compressed " + "server_compressed_unary : single compressed response;\n" + "server_streaming : single request with response streaming;\n" + "slow_consumer : single request with response streaming with " + "slow client consumer;\n" + "special_status_message: verify Unicode and whitespace in status message;\n" + "status_code_and_message: verify status code & message;\n" + "timeout_on_sleeping_server: deadline exceeds on stream;\n" + "unimplemented_method: client calls an unimplemented method;\n" + "unimplemented_service: client calls an unimplemented service;\n" + */ +); +ABSL_FLAG(std::string, default_service_account, "", + "Email of GCE default service account"); +ABSL_FLAG(std::string, service_account_key_file, "", + "Path to service account json key file."); +ABSL_FLAG(std::string, oauth_scope, "", "Scope for OAuth tokens."); +ABSL_FLAG(bool, do_not_abort_on_transient_failures, false, + "If set to 'true', abort() is not called in case of transient " + "failures (i.e failures that are temporary and will likely go away " + "on retrying; like a temporary connection failure) and an error " + "message is printed instead. Note that this flag just controls " + "whether abort() is called or not. It does not control whether the " + "test is retried in case of transient failures (and currently the " + "interop tests are not retried even if this flag is set to true)"); +ABSL_FLAG(int32_t, soak_iterations, 1000, + "The number of iterations to use for the two soak tests; rpc_soak " + "and channel_soak."); +ABSL_FLAG(int32_t, soak_max_failures, 0, + "The number of iterations in soak tests that are allowed to fail " + "(either due to non-OK status code or exceeding the " + "per-iteration max acceptable latency)."); +ABSL_FLAG(int32_t, soak_per_iteration_max_acceptable_latency_ms, 0, + "The number of milliseconds a single iteration in the two soak " + "tests (rpc_soak and channel_soak) should take."); +ABSL_FLAG(int32_t, soak_overall_timeout_seconds, 0, + "The overall number of seconds after which a soak test should " + "stop and fail, if the desired number of iterations have not yet " + "completed."); +ABSL_FLAG(int32_t, iteration_interval, 10, + "The interval in seconds between rpcs. 
This is used by " + "long_connection test"); +ABSL_FLAG(std::string, additional_metadata, "", + "Additional metadata to send in each request, as a " + "semicolon-separated list of key:value pairs."); + +using grpc::testing::CreateChannelForTestCase; +using grpc::testing::GetServiceAccountJsonKey; +using grpc::testing::UpdateActions; + +namespace { + +// Parse the contents of FLAGS_additional_metadata into a map. Allow +// alphanumeric characters and dashes in keys, and any character but semicolons +// in values. Convert keys to lowercase. On failure, log an error and return +// false. +bool ParseAdditionalMetadataFlag( + const std::string& flag, + std::multimap* additional_metadata) { + size_t start_pos = 0; + while (start_pos < flag.length()) { + size_t colon_pos = flag.find(':', start_pos); + if (colon_pos == std::string::npos) { + gpr_log(GPR_ERROR, + "Couldn't parse metadata flag: extra characters at end of flag"); + return false; + } + size_t semicolon_pos = flag.find(';', colon_pos); + + std::string key = flag.substr(start_pos, colon_pos - start_pos); + std::string value = + flag.substr(colon_pos + 1, semicolon_pos - colon_pos - 1); + + constexpr char alphanum_and_hyphen[] = + "-0123456789" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + if (key.find_first_not_of(alphanum_and_hyphen) != std::string::npos) { + gpr_log(GPR_ERROR, + "Couldn't parse metadata flag: key contains characters other " + "than alphanumeric and hyphens: %s", + key.c_str()); + return false; + } + + // Convert to lowercase. + for (char& c : key) { + if (c >= 'A' && c <= 'Z') { + c += ('a' - 'A'); + } + } + + gpr_log(GPR_INFO, "Adding additional metadata with key %s and value %s", + key.c_str(), value.c_str()); + additional_metadata->insert({key, value}); + + if (semicolon_pos == std::string::npos) { + break; + } else { + start_pos = semicolon_pos + 1; + } + } + + return true; +} + +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + gpr_log(GPR_INFO, "Testing these cases: %s", + absl::GetFlag(FLAGS_test_case).c_str()); + int ret = 0; + + grpc::testing::ChannelCreationFunc channel_creation_func; + std::string test_case = absl::GetFlag(FLAGS_test_case); + if (absl::GetFlag(FLAGS_additional_metadata).empty()) { + channel_creation_func = [test_case]() { + return CreateChannelForTestCase(test_case); + }; + } else { + std::multimap additional_metadata; + if (!ParseAdditionalMetadataFlag(absl::GetFlag(FLAGS_additional_metadata), + &additional_metadata)) { + return 1; + } + + channel_creation_func = [test_case, additional_metadata]() { + std::vector> + factories; + factories.emplace_back( + new grpc::testing::AdditionalMetadataInterceptorFactory( + additional_metadata)); + return CreateChannelForTestCase(test_case, std::move(factories)); + }; + } + + grpc::testing::InteropClient client( + channel_creation_func, true, + absl::GetFlag(FLAGS_do_not_abort_on_transient_failures)); + + std::unordered_map> actions; + actions["empty_unary"] = + std::bind(&grpc::testing::InteropClient::DoEmpty, &client); + actions["large_unary"] = + std::bind(&grpc::testing::InteropClient::DoLargeUnary, &client); + actions["server_compressed_unary"] = std::bind( + &grpc::testing::InteropClient::DoServerCompressedUnary, &client); + actions["client_compressed_unary"] = std::bind( + &grpc::testing::InteropClient::DoClientCompressedUnary, &client); + actions["client_streaming"] = + 
std::bind(&grpc::testing::InteropClient::DoRequestStreaming, &client); + actions["server_streaming"] = + std::bind(&grpc::testing::InteropClient::DoResponseStreaming, &client); + actions["server_compressed_streaming"] = std::bind( + &grpc::testing::InteropClient::DoServerCompressedStreaming, &client); + actions["client_compressed_streaming"] = std::bind( + &grpc::testing::InteropClient::DoClientCompressedStreaming, &client); + actions["slow_consumer"] = std::bind( + &grpc::testing::InteropClient::DoResponseStreamingWithSlowConsumer, + &client); + actions["half_duplex"] = + std::bind(&grpc::testing::InteropClient::DoHalfDuplex, &client); + actions["ping_pong"] = + std::bind(&grpc::testing::InteropClient::DoPingPong, &client); + actions["cancel_after_begin"] = + std::bind(&grpc::testing::InteropClient::DoCancelAfterBegin, &client); + actions["cancel_after_first_response"] = std::bind( + &grpc::testing::InteropClient::DoCancelAfterFirstResponse, &client); + actions["timeout_on_sleeping_server"] = std::bind( + &grpc::testing::InteropClient::DoTimeoutOnSleepingServer, &client); + actions["empty_stream"] = + std::bind(&grpc::testing::InteropClient::DoEmptyStream, &client); + actions["pick_first_unary"] = + std::bind(&grpc::testing::InteropClient::DoPickFirstUnary, &client); + if (absl::GetFlag(FLAGS_use_tls)) { + actions["compute_engine_creds"] = + std::bind(&grpc::testing::InteropClient::DoComputeEngineCreds, &client, + absl::GetFlag(FLAGS_default_service_account), + absl::GetFlag(FLAGS_oauth_scope)); + actions["jwt_token_creds"] = + std::bind(&grpc::testing::InteropClient::DoJwtTokenCreds, &client, + GetServiceAccountJsonKey()); + actions["oauth2_auth_token"] = + std::bind(&grpc::testing::InteropClient::DoOauth2AuthToken, &client, + absl::GetFlag(FLAGS_default_service_account), + absl::GetFlag(FLAGS_oauth_scope)); + actions["per_rpc_creds"] = + std::bind(&grpc::testing::InteropClient::DoPerRpcCreds, &client, + GetServiceAccountJsonKey()); + } + if (absl::GetFlag(FLAGS_custom_credentials_type) == + "google_default_credentials") { + actions["google_default_credentials"] = + std::bind(&grpc::testing::InteropClient::DoGoogleDefaultCredentials, + &client, absl::GetFlag(FLAGS_default_service_account)); + } + actions["status_code_and_message"] = + std::bind(&grpc::testing::InteropClient::DoStatusWithMessage, &client); + actions["special_status_message"] = + std::bind(&grpc::testing::InteropClient::DoSpecialStatusMessage, &client); + actions["custom_metadata"] = + std::bind(&grpc::testing::InteropClient::DoCustomMetadata, &client); + actions["unimplemented_method"] = + std::bind(&grpc::testing::InteropClient::DoUnimplementedMethod, &client); + actions["unimplemented_service"] = + std::bind(&grpc::testing::InteropClient::DoUnimplementedService, &client); + actions["cacheable_unary"] = + std::bind(&grpc::testing::InteropClient::DoCacheableUnary, &client); + actions["channel_soak"] = std::bind( + &grpc::testing::InteropClient::DoChannelSoakTest, &client, + absl::GetFlag(FLAGS_soak_iterations), + absl::GetFlag(FLAGS_soak_max_failures), + absl::GetFlag(FLAGS_soak_per_iteration_max_acceptable_latency_ms), + absl::GetFlag(FLAGS_soak_overall_timeout_seconds)); + actions["rpc_soak"] = std::bind( + &grpc::testing::InteropClient::DoRpcSoakTest, &client, + absl::GetFlag(FLAGS_soak_iterations), + absl::GetFlag(FLAGS_soak_max_failures), + absl::GetFlag(FLAGS_soak_per_iteration_max_acceptable_latency_ms), + absl::GetFlag(FLAGS_soak_overall_timeout_seconds)); + actions["long_lived_channel"] = + 
std::bind(&grpc::testing::InteropClient::DoLongLivedChannelTest, &client, + absl::GetFlag(FLAGS_soak_iterations), + absl::GetFlag(FLAGS_iteration_interval)); + + UpdateActions(&actions); + + if (absl::GetFlag(FLAGS_test_case) == "all") { + for (const auto& action : actions) { + action.second(); + } + } else if (actions.find(absl::GetFlag(FLAGS_test_case)) != actions.end()) { + actions.find(absl::GetFlag(FLAGS_test_case))->second(); + } else { + std::string test_cases; + for (const auto& action : actions) { + if (!test_cases.empty()) test_cases += "\n"; + test_cases += action.first; + } + gpr_log(GPR_ERROR, "Unsupported test case %s. Valid options are\n%s", + absl::GetFlag(FLAGS_test_case).c_str(), test_cases.c_str()); + ret = 1; + } + + return ret; +} diff --git a/test/cpp/interop/client_helper.cc b/test/cpp/interop/client_helper.cc new file mode 100644 index 00000000..17f2ad25 --- /dev/null +++ b/test/cpp/interop/client_helper.cc @@ -0,0 +1,143 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/interop/client_helper.h" + +#include +#include +#include + +#include "absl/flags/declare.h" +#include "absl/flags/flag.h" + +#include +#include +#include +#include +#include +#include + +#include "src/cpp/client/secure_credentials.h" +#include "test/core/security/oauth2_utils.h" +#include "test/cpp/util/create_test_channel.h" +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_DECLARE_FLAG(bool, use_alts); +ABSL_DECLARE_FLAG(bool, use_tls); +ABSL_DECLARE_FLAG(std::string, custom_credentials_type); +ABSL_DECLARE_FLAG(bool, use_test_ca); +ABSL_DECLARE_FLAG(int32_t, server_port); +ABSL_DECLARE_FLAG(std::string, server_host); +ABSL_DECLARE_FLAG(std::string, server_host_override); +ABSL_DECLARE_FLAG(std::string, test_case); +ABSL_DECLARE_FLAG(std::string, default_service_account); +ABSL_DECLARE_FLAG(std::string, service_account_key_file); +ABSL_DECLARE_FLAG(std::string, oauth_scope); + +namespace grpc { +namespace testing { + +std::string GetServiceAccountJsonKey() { + static std::string json_key; + if (json_key.empty()) { + std::ifstream json_key_file(absl::GetFlag(FLAGS_service_account_key_file)); + std::stringstream key_stream; + key_stream << json_key_file.rdbuf(); + json_key = key_stream.str(); + } + return json_key; +} + +std::string GetOauth2AccessToken() { + std::shared_ptr creds = GoogleComputeEngineCredentials(); + SecureCallCredentials* secure_creds = + dynamic_cast(creds.get()); + GPR_ASSERT(secure_creds != nullptr); + grpc_call_credentials* c_creds = secure_creds->GetRawCreds(); + char* token = grpc_test_fetch_oauth2_token_with_credentials(c_creds); + GPR_ASSERT(token != nullptr); + gpr_log(GPR_INFO, "Get raw oauth2 access token: %s", token); + std::string access_token(token + sizeof("Bearer ") - 1); + gpr_free(token); + return access_token; +} + +void UpdateActions( + std::unordered_map>* /*actions*/) {} + +std::shared_ptr CreateChannelForTestCase( + const std::string& test_case, + std::vector< + 
std::unique_ptr> + interceptor_creators) { + std::string server_uri = absl::GetFlag(FLAGS_server_host); + int32_t port = absl::GetFlag(FLAGS_server_port); + if (port != 0) { + absl::StrAppend(&server_uri, ":", std::to_string(port)); + } + std::shared_ptr creds; + if (test_case == "compute_engine_creds") { + creds = absl::GetFlag(FLAGS_custom_credentials_type) == + "google_default_credentials" + ? nullptr + : GoogleComputeEngineCredentials(); + } else if (test_case == "jwt_token_creds") { + std::string json_key = GetServiceAccountJsonKey(); + std::chrono::seconds token_lifetime = std::chrono::hours(1); + creds = absl::GetFlag(FLAGS_custom_credentials_type) == + "google_default_credentials" + ? nullptr + : ServiceAccountJWTAccessCredentials(json_key, + token_lifetime.count()); + } else if (test_case == "oauth2_auth_token") { + creds = absl::GetFlag(FLAGS_custom_credentials_type) == + "google_default_credentials" + ? nullptr + : AccessTokenCredentials(GetOauth2AccessToken()); + } else if (test_case == "pick_first_unary") { + ChannelArguments channel_args; + // allow the LB policy to be configured with service config + channel_args.SetInt(GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION, 0); + return CreateTestChannel( + server_uri, absl::GetFlag(FLAGS_custom_credentials_type), + absl::GetFlag(FLAGS_server_host_override), + !absl::GetFlag(FLAGS_use_test_ca), creds, channel_args); + } + if (absl::GetFlag(FLAGS_custom_credentials_type).empty()) { + transport_security security_type = + absl::GetFlag(FLAGS_use_alts) + ? ALTS + : (absl::GetFlag(FLAGS_use_tls) ? TLS : INSECURE); + return CreateTestChannel(server_uri, + absl::GetFlag(FLAGS_server_host_override), + security_type, !absl::GetFlag(FLAGS_use_test_ca), + creds, std::move(interceptor_creators)); + } else { + if (interceptor_creators.empty()) { + return CreateTestChannel( + server_uri, absl::GetFlag(FLAGS_custom_credentials_type), creds); + } else { + return CreateTestChannel(server_uri, + absl::GetFlag(FLAGS_custom_credentials_type), + creds, std::move(interceptor_creators)); + } + } +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/interop/grpclb_fallback_test.cc b/test/cpp/interop/grpclb_fallback_test.cc new file mode 100644 index 00000000..dc4c08a7 --- /dev/null +++ b/test/cpp/interop/grpclb_fallback_test.cc @@ -0,0 +1,296 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/socket_mutator.h" +#include "src/proto/grpc/testing/empty.pb.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "src/proto/grpc/testing/test.pb.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_FLAG(std::string, custom_credentials_type, "", + "User provided credentials type."); +ABSL_FLAG(std::string, server_uri, "localhost:1000", "Server URI target"); +ABSL_FLAG(std::string, unroute_lb_and_backend_addrs_cmd, "exit 1", + "Shell command used to make LB and backend addresses unroutable"); +ABSL_FLAG(std::string, blackhole_lb_and_backend_addrs_cmd, "exit 1", + "Shell command used to make LB and backend addresses blackholed"); +ABSL_FLAG( + std::string, test_case, "", + "Test case to run. Valid options are:\n\n" + "fast_fallback_before_startup : fallback before establishing connection to " + "LB;\n" + "fast_fallback_after_startup : fallback after startup due to LB/backend " + "addresses becoming unroutable;\n" + "slow_fallback_before_startup : fallback before startup due to LB address " + "being blackholed;\n" + "slow_fallback_after_startup : fallback after startup due to LB/backend " + "addresses becoming blackholed;\n"); + +#ifdef LINUX_VERSION_CODE +#if LINUX_VERSION_CODE >= KERNEL_VERSION(2, 6, 37) +#define SOCKET_SUPPORTS_TCP_USER_TIMEOUT +#endif +#endif + +#ifdef SOCKET_SUPPORTS_TCP_USER_TIMEOUT +using grpc::testing::GrpclbRouteType; +using grpc::testing::SimpleRequest; +using grpc::testing::SimpleResponse; +using grpc::testing::TestService; + +namespace { + +enum RpcMode { + FailFast, + WaitForReady, +}; + +GrpclbRouteType DoRPCAndGetPath(TestService::Stub* stub, int deadline_seconds, + RpcMode rpc_mode) { + gpr_log(GPR_INFO, "DoRPCAndGetPath deadline_seconds:%d rpc_mode:%d", + deadline_seconds, rpc_mode); + SimpleRequest request; + SimpleResponse response; + grpc::ClientContext context; + if (rpc_mode == WaitForReady) { + context.set_wait_for_ready(true); + } + request.set_fill_grpclb_route_type(true); + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::seconds(deadline_seconds); + context.set_deadline(deadline); + grpc::Status s = stub->UnaryCall(&context, request, &response); + if (!s.ok()) { + gpr_log(GPR_INFO, "DoRPCAndGetPath failed. status-message: %s", + s.error_message().c_str()); + return GrpclbRouteType::GRPCLB_ROUTE_TYPE_UNKNOWN; + } + GPR_ASSERT(response.grpclb_route_type() == + GrpclbRouteType::GRPCLB_ROUTE_TYPE_BACKEND || + response.grpclb_route_type() == + GrpclbRouteType::GRPCLB_ROUTE_TYPE_FALLBACK); + gpr_log(GPR_INFO, "DoRPCAndGetPath done. 
grpclb_route_type:%d", + response.grpclb_route_type()); + return response.grpclb_route_type(); +} + +GrpclbRouteType DoRPCAndGetPath(TestService::Stub* stub, int deadline_seconds) { + return DoRPCAndGetPath(stub, deadline_seconds, FailFast); +} + +GrpclbRouteType DoWaitForReadyRPCAndGetPath(TestService::Stub* stub, + int deadline_seconds) { + return DoRPCAndGetPath(stub, deadline_seconds, WaitForReady); +} + +bool TcpUserTimeoutMutateFd(int fd, grpc_socket_mutator* /*mutator*/) { + int timeout = 20000; // 20 seconds + gpr_log(GPR_INFO, "Setting socket option TCP_USER_TIMEOUT on fd: %d", fd); + if (0 != setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &timeout, + sizeof(timeout))) { + gpr_log(GPR_ERROR, "Failed to set socket option TCP_USER_TIMEOUT"); + abort(); + } + int newval; + socklen_t len = sizeof(newval); + if (0 != getsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &newval, &len) || + newval != timeout) { + gpr_log(GPR_ERROR, "Failed to get expected socket option TCP_USER_TIMEOUT"); + abort(); + } + return true; +} + +int TcpUserTimeoutCompare(grpc_socket_mutator* /*a*/, + grpc_socket_mutator* /*b*/) { + return 0; +} + +void TcpUserTimeoutDestroy(grpc_socket_mutator* mutator) { gpr_free(mutator); } + +const grpc_socket_mutator_vtable kTcpUserTimeoutMutatorVtable = + grpc_socket_mutator_vtable{TcpUserTimeoutMutateFd, TcpUserTimeoutCompare, + TcpUserTimeoutDestroy, nullptr}; + +std::unique_ptr CreateFallbackTestStub() { + grpc::ChannelArguments channel_args; + grpc_socket_mutator* tcp_user_timeout_mutator = + static_cast( + gpr_malloc(sizeof(tcp_user_timeout_mutator))); + grpc_socket_mutator_init(tcp_user_timeout_mutator, + &kTcpUserTimeoutMutatorVtable); + channel_args.SetSocketMutator(tcp_user_timeout_mutator); + // Allow LB policy to be configured by service config + channel_args.SetInt(GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION, 0); + std::shared_ptr channel_creds = + grpc::testing::GetCredentialsProvider()->GetChannelCredentials( + absl::GetFlag(FLAGS_custom_credentials_type), &channel_args); + return TestService::NewStub(grpc::CreateCustomChannel( + absl::GetFlag(FLAGS_server_uri), channel_creds, channel_args)); +} + +void RunCommand(const std::string& command) { + gpr_log(GPR_INFO, "RunCommand: |%s|", command.c_str()); + int out = std::system(command.c_str()); + if (WIFEXITED(out)) { + int code = WEXITSTATUS(out); + if (code != 0) { + gpr_log(GPR_ERROR, "RunCommand failed exit code:%d command:|%s|", code, + command.c_str()); + abort(); + } + } else { + gpr_log(GPR_ERROR, "RunCommand failed command:|%s|", command.c_str()); + abort(); + } +} + +void RunFallbackBeforeStartupTest( + const std::string& break_lb_and_backend_conns_cmd, + int per_rpc_deadline_seconds) { + std::unique_ptr stub = CreateFallbackTestStub(); + RunCommand(break_lb_and_backend_conns_cmd); + for (size_t i = 0; i < 30; i++) { + GrpclbRouteType grpclb_route_type = + DoRPCAndGetPath(stub.get(), per_rpc_deadline_seconds); + if (grpclb_route_type != GrpclbRouteType::GRPCLB_ROUTE_TYPE_FALLBACK) { + gpr_log(GPR_ERROR, "Expected grpclb route type: FALLBACK. 
Got: %d", + grpclb_route_type); + abort(); + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + } +} + +void DoFastFallbackBeforeStartup() { + RunFallbackBeforeStartupTest( + absl::GetFlag(FLAGS_unroute_lb_and_backend_addrs_cmd), 9); +} + +void DoSlowFallbackBeforeStartup() { + RunFallbackBeforeStartupTest( + absl::GetFlag(FLAGS_blackhole_lb_and_backend_addrs_cmd), 20); +} + +void RunFallbackAfterStartupTest( + const std::string& break_lb_and_backend_conns_cmd) { + std::unique_ptr stub = CreateFallbackTestStub(); + GrpclbRouteType grpclb_route_type = DoRPCAndGetPath(stub.get(), 20); + if (grpclb_route_type != GrpclbRouteType::GRPCLB_ROUTE_TYPE_BACKEND) { + gpr_log(GPR_ERROR, "Expected grpclb route type: BACKEND. Got: %d", + grpclb_route_type); + abort(); + } + RunCommand(break_lb_and_backend_conns_cmd); + for (size_t i = 0; i < 40; i++) { + GrpclbRouteType grpclb_route_type = + DoWaitForReadyRPCAndGetPath(stub.get(), 1); + // Backends should be unreachable by now, otherwise the test is broken. + GPR_ASSERT(grpclb_route_type != GrpclbRouteType::GRPCLB_ROUTE_TYPE_BACKEND); + if (grpclb_route_type == GrpclbRouteType::GRPCLB_ROUTE_TYPE_FALLBACK) { + gpr_log(GPR_INFO, + "Made one successul RPC to a fallback. Now expect the same for " + "the rest."); + break; + } else { + gpr_log(GPR_ERROR, "Retryable RPC failure on iteration: %" PRIdPTR, i); + } + } + for (size_t i = 0; i < 30; i++) { + GrpclbRouteType grpclb_route_type = DoRPCAndGetPath(stub.get(), 20); + if (grpclb_route_type != GrpclbRouteType::GRPCLB_ROUTE_TYPE_FALLBACK) { + gpr_log(GPR_ERROR, "Expected grpclb route type: FALLBACK. Got: %d", + grpclb_route_type); + abort(); + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + } +} + +void DoFastFallbackAfterStartup() { + RunFallbackAfterStartupTest( + absl::GetFlag(FLAGS_unroute_lb_and_backend_addrs_cmd)); +} + +void DoSlowFallbackAfterStartup() { + RunFallbackAfterStartupTest( + absl::GetFlag(FLAGS_blackhole_lb_and_backend_addrs_cmd)); +} +} // namespace + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + gpr_log(GPR_INFO, "Testing: %s", absl::GetFlag(FLAGS_test_case).c_str()); + if (absl::GetFlag(FLAGS_test_case) == "fast_fallback_before_startup") { + DoFastFallbackBeforeStartup(); + gpr_log(GPR_INFO, "DoFastFallbackBeforeStartup done!"); + } else if (absl::GetFlag(FLAGS_test_case) == "slow_fallback_before_startup") { + DoSlowFallbackBeforeStartup(); + gpr_log(GPR_INFO, "DoSlowFallbackBeforeStartup done!"); + } else if (absl::GetFlag(FLAGS_test_case) == "fast_fallback_after_startup") { + DoFastFallbackAfterStartup(); + gpr_log(GPR_INFO, "DoFastFallbackAfterStartup done!"); + } else if (absl::GetFlag(FLAGS_test_case) == "slow_fallback_after_startup") { + DoSlowFallbackAfterStartup(); + gpr_log(GPR_INFO, "DoSlowFallbackAfterStartup done!"); + } else { + gpr_log(GPR_ERROR, "Invalid test case: %s", + absl::GetFlag(FLAGS_test_case).c_str()); + abort(); + } +} + +#else + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + gpr_log(GPR_ERROR, + "This test requires TCP_USER_TIMEOUT, which isn't available"); + abort(); +} + +#endif // SOCKET_SUPPORTS_TCP_USER_TIMEOUT diff --git a/test/cpp/interop/http2_client.cc b/test/cpp/interop/http2_client.cc new file mode 100644 index 00000000..d890d6a3 --- /dev/null +++ b/test/cpp/interop/http2_client.cc @@ -0,0 +1,232 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/interop/http2_client.h" + +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/transport/byte_stream.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/cpp/util/create_test_channel.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +namespace { +const int kLargeRequestSize = 271828; +const int kLargeResponseSize = 314159; +} // namespace + +Http2Client::ServiceStub::ServiceStub(const std::shared_ptr& channel) + : channel_(channel) { + stub_ = TestService::NewStub(channel); +} + +TestService::Stub* Http2Client::ServiceStub::Get() { return stub_.get(); } + +Http2Client::Http2Client(const std::shared_ptr& channel) + : serviceStub_(channel), + channel_(channel), + defaultRequest_(BuildDefaultRequest()) {} + +bool Http2Client::AssertStatusCode(const Status& s, StatusCode expected_code) { + if (s.error_code() == expected_code) { + return true; + } + + gpr_log(GPR_ERROR, "Error status code: %d (expected: %d), message: %s", + s.error_code(), expected_code, s.error_message().c_str()); + abort(); +} + +Status Http2Client::SendUnaryCall(SimpleResponse* response) { + ClientContext context; + return serviceStub_.Get()->UnaryCall(&context, defaultRequest_, response); +} + +SimpleRequest Http2Client::BuildDefaultRequest() { + SimpleRequest request; + request.set_response_size(kLargeResponseSize); + std::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + return request; +} + +bool Http2Client::DoRstAfterHeader() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting reset stream after header"); + + SimpleResponse response; + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::INTERNAL); + GPR_ASSERT(!response.has_payload()); // no data should be received + + gpr_log(GPR_DEBUG, "Done testing reset stream after header"); + return true; +} + +bool Http2Client::DoRstAfterData() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting reset stream after data"); + + SimpleResponse response; + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::INTERNAL); + // There is no guarantee that data would be received. 
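+  // (The server resets the stream only after it has started sending response
+  // data, so depending on timing the client may or may not have read a
+  // payload before the RST_STREAM arrives; only the INTERNAL status is
+  // asserted here.)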
+ + gpr_log(GPR_DEBUG, "Done testing reset stream after data"); + return true; +} + +bool Http2Client::DoRstDuringData() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting reset stream during data"); + + SimpleResponse response; + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::INTERNAL); + GPR_ASSERT(!response.has_payload()); // no data should be received + + gpr_log(GPR_DEBUG, "Done testing reset stream during data"); + return true; +} + +bool Http2Client::DoGoaway() { + gpr_log(GPR_DEBUG, "Sending two RPCs and expecting goaway"); + SimpleResponse response; + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + std::string(kLargeResponseSize, '\0')); + + // Sleep for one second to give time for client to receive goaway frame. + gpr_timespec sleep_time = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(1, GPR_TIMESPAN)); + gpr_sleep_until(sleep_time); + + response.Clear(); + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + std::string(kLargeResponseSize, '\0')); + gpr_log(GPR_DEBUG, "Done testing goaway"); + return true; +} + +bool Http2Client::DoPing() { + gpr_log(GPR_DEBUG, "Sending RPC and expecting ping"); + SimpleResponse response; + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + std::string(kLargeResponseSize, '\0')); + gpr_log(GPR_DEBUG, "Done testing ping"); + return true; +} + +void Http2Client::MaxStreamsWorker( + const std::shared_ptr& /*channel*/) { + SimpleResponse response; + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + std::string(kLargeResponseSize, '\0')); +} + +bool Http2Client::DoMaxStreams() { + gpr_log(GPR_DEBUG, "Testing max streams"); + + // Make an initial call on the channel to ensure the server's max streams + // setting is received + SimpleResponse response; + AssertStatusCode(SendUnaryCall(&response), grpc::StatusCode::OK); + GPR_ASSERT(response.payload().body() == + std::string(kLargeResponseSize, '\0')); + + std::vector test_threads; + test_threads.reserve(10); + for (int i = 0; i < 10; i++) { + test_threads.emplace_back( + std::thread(&Http2Client::MaxStreamsWorker, this, channel_)); + } + + for (auto it = test_threads.begin(); it != test_threads.end(); it++) { + it->join(); + } + + gpr_log(GPR_DEBUG, "Done testing max streams"); + return true; +} + +} // namespace testing +} // namespace grpc + +ABSL_FLAG(int32_t, server_port, 0, "Server port."); +ABSL_FLAG(std::string, server_host, "localhost", "Server host to connect to"); +ABSL_FLAG(std::string, test_case, "rst_after_header", + "Configure different test cases. 
Valid options are:\n\n" + "goaway\n" + "max_streams\n" + "ping\n" + "rst_after_data\n" + "rst_after_header\n" + "rst_during_data\n"); + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + GPR_ASSERT(absl::GetFlag(FLAGS_server_port)); + const int host_port_buf_size = 1024; + char host_port[host_port_buf_size]; + snprintf(host_port, host_port_buf_size, "%s:%d", + absl::GetFlag(FLAGS_server_host).c_str(), + absl::GetFlag(FLAGS_server_port)); + std::shared_ptr channel = + grpc::CreateTestChannel(host_port, grpc::testing::INSECURE); + GPR_ASSERT(channel->WaitForConnected(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(300, GPR_TIMESPAN)))); + grpc::testing::Http2Client client(channel); + gpr_log(GPR_INFO, "Testing case: %s", absl::GetFlag(FLAGS_test_case).c_str()); + int ret = 0; + if (absl::GetFlag(FLAGS_test_case) == "rst_after_header") { + client.DoRstAfterHeader(); + } else if (absl::GetFlag(FLAGS_test_case) == "rst_after_data") { + client.DoRstAfterData(); + } else if (absl::GetFlag(FLAGS_test_case) == "rst_during_data") { + client.DoRstDuringData(); + } else if (absl::GetFlag(FLAGS_test_case) == "goaway") { + client.DoGoaway(); + } else if (absl::GetFlag(FLAGS_test_case) == "ping") { + client.DoPing(); + } else if (absl::GetFlag(FLAGS_test_case) == "max_streams") { + client.DoMaxStreams(); + } else { + const char* testcases[] = { + "goaway", "max_streams", "ping", + "rst_after_data", "rst_after_header", "rst_during_data"}; + char* joined_testcases = + gpr_strjoin_sep(testcases, GPR_ARRAY_SIZE(testcases), "\n", nullptr); + + gpr_log(GPR_ERROR, "Unsupported test case %s. Valid options are\n%s", + absl::GetFlag(FLAGS_test_case).c_str(), joined_testcases); + gpr_free(joined_testcases); + ret = 1; + } + + return ret; +} diff --git a/test/cpp/interop/interop_client.cc b/test/cpp/interop/interop_client.cc new file mode 100644 index 00000000..c29f6eb9 --- /dev/null +++ b/test/cpp/interop/interop_client.cc @@ -0,0 +1,1302 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/interop/interop_client.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/empty.pb.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/core/util/histogram.h" +#include "test/cpp/interop/client_helper.h" + +namespace grpc { +namespace testing { + +namespace { +// The same value is defined by the Java client. 
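+// These request/response stream sizes are fixed by the cross-language interop
+// test descriptions, so every client and server implementation must agree on
+// the same payload sequence.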
+const std::vector request_stream_sizes = {27182, 8, 1828, 45904}; +const std::vector response_stream_sizes = {31415, 9, 2653, 58979}; +const int kNumResponseMessages = 2000; +const int kResponseMessageSize = 1030; +const int kReceiveDelayMilliSeconds = 20; +const int kLargeRequestSize = 271828; +const int kLargeResponseSize = 314159; + +void NoopChecks(const InteropClientContextInspector& /*inspector*/, + const SimpleRequest* /*request*/, + const SimpleResponse* /*response*/) {} + +void UnaryCompressionChecks(const InteropClientContextInspector& inspector, + const SimpleRequest* request, + const SimpleResponse* /*response*/) { + const grpc_compression_algorithm received_compression = + inspector.GetCallCompressionAlgorithm(); + if (request->response_compressed().value()) { + if (received_compression == GRPC_COMPRESS_NONE) { + // Requested some compression, got NONE. This is an error. + gpr_log(GPR_ERROR, + "Failure: Requested compression but got uncompressed response " + "from server."); + abort(); + } + GPR_ASSERT(inspector.WasCompressed()); + } else { + // Didn't request compression -> make sure the response is uncompressed + GPR_ASSERT(!(inspector.WasCompressed())); + } +} +} // namespace + +InteropClient::ServiceStub::ServiceStub( + ChannelCreationFunc channel_creation_func, bool new_stub_every_call) + : channel_creation_func_(std::move(channel_creation_func)), + channel_(channel_creation_func_()), + new_stub_every_call_(new_stub_every_call) { + // If new_stub_every_call is false, then this is our chance to initialize + // stub_. (see Get()) + if (!new_stub_every_call) { + stub_ = TestService::NewStub(channel_); + } +} + +TestService::Stub* InteropClient::ServiceStub::Get() { + if (new_stub_every_call_) { + stub_ = TestService::NewStub(channel_); + } + + return stub_.get(); +} + +UnimplementedService::Stub* +InteropClient::ServiceStub::GetUnimplementedServiceStub() { + if (unimplemented_service_stub_ == nullptr) { + unimplemented_service_stub_ = UnimplementedService::NewStub(channel_); + } + return unimplemented_service_stub_.get(); +} + +void InteropClient::ServiceStub::ResetChannel() { + channel_ = channel_creation_func_(); + if (!new_stub_every_call_) { + stub_ = TestService::NewStub(channel_); + } +} + +InteropClient::InteropClient(ChannelCreationFunc channel_creation_func, + bool new_stub_every_test_case, + bool do_not_abort_on_transient_failures) + : serviceStub_(std::move(channel_creation_func), new_stub_every_test_case), + do_not_abort_on_transient_failures_(do_not_abort_on_transient_failures) {} + +bool InteropClient::AssertStatusOk(const Status& s, + const std::string& optional_debug_string) { + if (s.ok()) { + return true; + } + + // Note: At this point, s.error_code is definitely not StatusCode::OK (we + // already checked for s.ok() above). 
So, the following will call abort() + // (unless s.error_code() corresponds to a transient failure and + // 'do_not_abort_on_transient_failures' is true) + return AssertStatusCode(s, StatusCode::OK, optional_debug_string); +} + +bool InteropClient::AssertStatusCode(const Status& s, StatusCode expected_code, + const std::string& optional_debug_string) { + if (s.error_code() == expected_code) { + return true; + } + + gpr_log(GPR_ERROR, + "Error status code: %d (expected: %d), message: %s," + " debug string: %s", + s.error_code(), expected_code, s.error_message().c_str(), + optional_debug_string.c_str()); + + // In case of transient transient/retryable failures (like a broken + // connection) we may or may not abort (see TransientFailureOrAbort()) + if (s.error_code() == grpc::StatusCode::UNAVAILABLE) { + return TransientFailureOrAbort(); + } + + abort(); +} + +bool InteropClient::DoEmpty() { + gpr_log(GPR_DEBUG, "Sending an empty rpc..."); + + Empty request; + Empty response; + ClientContext context; + + Status s = serviceStub_.Get()->EmptyCall(&context, request, &response); + + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "Empty rpc done."); + return true; +} + +bool InteropClient::PerformLargeUnary(SimpleRequest* request, + SimpleResponse* response) { + return PerformLargeUnary(request, response, NoopChecks); +} + +bool InteropClient::PerformLargeUnary(SimpleRequest* request, + SimpleResponse* response, + const CheckerFn& custom_checks_fn) { + ClientContext context; + InteropClientContextInspector inspector(context); + request->set_response_size(kLargeResponseSize); + std::string payload(kLargeRequestSize, '\0'); + request->mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + if (request->has_expect_compressed()) { + if (request->expect_compressed().value()) { + context.set_compression_algorithm(GRPC_COMPRESS_GZIP); + } else { + context.set_compression_algorithm(GRPC_COMPRESS_NONE); + } + } + + Status s = serviceStub_.Get()->UnaryCall(&context, *request, response); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + custom_checks_fn(inspector, request, response); + + // Payload related checks. 
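+  // (The request asked for kLargeResponseSize response bytes, so the body is
+  // expected to be exactly that many zero bytes.)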
+ GPR_ASSERT(response->payload().body() == + std::string(kLargeResponseSize, '\0')); + return true; +} + +bool InteropClient::DoComputeEngineCreds( + const std::string& default_service_account, + const std::string& oauth_scope) { + gpr_log(GPR_DEBUG, + "Sending a large unary rpc with compute engine credentials ..."); + SimpleRequest request; + SimpleResponse response; + request.set_fill_username(true); + request.set_fill_oauth_scope(true); + + if (!PerformLargeUnary(&request, &response)) { + return false; + } + + gpr_log(GPR_DEBUG, "Got username %s", response.username().c_str()); + gpr_log(GPR_DEBUG, "Got oauth_scope %s", response.oauth_scope().c_str()); + GPR_ASSERT(!response.username().empty()); + GPR_ASSERT(response.username().c_str() == default_service_account); + GPR_ASSERT(!response.oauth_scope().empty()); + const char* oauth_scope_str = response.oauth_scope().c_str(); + GPR_ASSERT(absl::StrContains(oauth_scope, oauth_scope_str)); + gpr_log(GPR_DEBUG, "Large unary with compute engine creds done."); + return true; +} + +bool InteropClient::DoOauth2AuthToken(const std::string& username, + const std::string& oauth_scope) { + gpr_log(GPR_DEBUG, + "Sending a unary rpc with raw oauth2 access token credentials ..."); + SimpleRequest request; + SimpleResponse response; + request.set_fill_username(true); + request.set_fill_oauth_scope(true); + + ClientContext context; + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + GPR_ASSERT(!response.username().empty()); + GPR_ASSERT(!response.oauth_scope().empty()); + GPR_ASSERT(username == response.username()); + const char* oauth_scope_str = response.oauth_scope().c_str(); + GPR_ASSERT(absl::StrContains(oauth_scope, oauth_scope_str)); + gpr_log(GPR_DEBUG, "Unary with oauth2 access token credentials done."); + return true; +} + +bool InteropClient::DoPerRpcCreds(const std::string& json_key) { + gpr_log(GPR_DEBUG, "Sending a unary rpc with per-rpc JWT access token ..."); + SimpleRequest request; + SimpleResponse response; + request.set_fill_username(true); + + ClientContext context; + std::chrono::seconds token_lifetime = std::chrono::hours(1); + std::shared_ptr creds = + ServiceAccountJWTAccessCredentials(json_key, token_lifetime.count()); + + context.set_credentials(creds); + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + GPR_ASSERT(!response.username().empty()); + GPR_ASSERT(json_key.find(response.username()) != std::string::npos); + gpr_log(GPR_DEBUG, "Unary with per-rpc JWT access token done."); + return true; +} + +bool InteropClient::DoJwtTokenCreds(const std::string& username) { + gpr_log(GPR_DEBUG, + "Sending a large unary rpc with JWT token credentials ..."); + SimpleRequest request; + SimpleResponse response; + request.set_fill_username(true); + + if (!PerformLargeUnary(&request, &response)) { + return false; + } + + GPR_ASSERT(!response.username().empty()); + GPR_ASSERT(username.find(response.username()) != std::string::npos); + gpr_log(GPR_DEBUG, "Large unary with JWT token creds done."); + return true; +} + +bool InteropClient::DoGoogleDefaultCredentials( + const std::string& default_service_account) { + gpr_log(GPR_DEBUG, + "Sending a large unary rpc with GoogleDefaultCredentials..."); + SimpleRequest request; + SimpleResponse response; + request.set_fill_username(true); + + if (!PerformLargeUnary(&request, &response)) { + 
return false; + } + + gpr_log(GPR_DEBUG, "Got username %s", response.username().c_str()); + GPR_ASSERT(!response.username().empty()); + GPR_ASSERT(response.username().c_str() == default_service_account); + gpr_log(GPR_DEBUG, "Large unary rpc with GoogleDefaultCredentials done."); + return true; +} + +bool InteropClient::DoLargeUnary() { + gpr_log(GPR_DEBUG, "Sending a large unary rpc..."); + SimpleRequest request; + SimpleResponse response; + if (!PerformLargeUnary(&request, &response)) { + return false; + } + gpr_log(GPR_DEBUG, "Large unary done."); + return true; +} + +bool InteropClient::DoClientCompressedUnary() { + // Probing for compression-checks support. + ClientContext probe_context; + SimpleRequest probe_req; + SimpleResponse probe_res; + + probe_context.set_compression_algorithm(GRPC_COMPRESS_NONE); + probe_req.mutable_expect_compressed()->set_value(true); // lies! + + probe_req.set_response_size(kLargeResponseSize); + probe_req.mutable_payload()->set_body(std::string(kLargeRequestSize, '\0')); + + gpr_log(GPR_DEBUG, "Sending probe for compressed unary request."); + const Status s = + serviceStub_.Get()->UnaryCall(&probe_context, probe_req, &probe_res); + if (s.error_code() != grpc::StatusCode::INVALID_ARGUMENT) { + // The server isn't able to evaluate incoming compression, making the rest + // of this test moot. + gpr_log(GPR_DEBUG, "Compressed unary request probe failed"); + return false; + } + gpr_log(GPR_DEBUG, "Compressed unary request probe succeeded. Proceeding."); + + const std::vector compressions = {true, false}; + for (size_t i = 0; i < compressions.size(); i++) { + std::string log_suffix = + absl::StrFormat("(compression=%s)", compressions[i] ? "true" : "false"); + + gpr_log(GPR_DEBUG, "Sending compressed unary request %s.", + log_suffix.c_str()); + SimpleRequest request; + SimpleResponse response; + request.mutable_expect_compressed()->set_value(compressions[i]); + if (!PerformLargeUnary(&request, &response, UnaryCompressionChecks)) { + gpr_log(GPR_ERROR, "Compressed unary request failed %s", + log_suffix.c_str()); + return false; + } + + gpr_log(GPR_DEBUG, "Compressed unary request failed %s", + log_suffix.c_str()); + } + + return true; +} + +bool InteropClient::DoServerCompressedUnary() { + const std::vector compressions = {true, false}; + for (size_t i = 0; i < compressions.size(); i++) { + std::string log_suffix = + absl::StrFormat("(compression=%s)", compressions[i] ? 
"true" : "false"); + + gpr_log(GPR_DEBUG, "Sending unary request for compressed response %s.", + log_suffix.c_str()); + SimpleRequest request; + SimpleResponse response; + request.mutable_response_compressed()->set_value(compressions[i]); + + if (!PerformLargeUnary(&request, &response, UnaryCompressionChecks)) { + gpr_log(GPR_ERROR, "Request for compressed unary failed %s", + log_suffix.c_str()); + return false; + } + + gpr_log(GPR_DEBUG, "Request for compressed unary failed %s", + log_suffix.c_str()); + } + + return true; +} + +// Either abort() (unless do_not_abort_on_transient_failures_ is true) or return +// false +bool InteropClient::TransientFailureOrAbort() { + if (do_not_abort_on_transient_failures_) { + return false; + } + + abort(); +} + +bool InteropClient::DoRequestStreaming() { + gpr_log(GPR_DEBUG, "Sending request steaming rpc ..."); + + ClientContext context; + StreamingInputCallRequest request; + StreamingInputCallResponse response; + + std::unique_ptr> stream( + serviceStub_.Get()->StreamingInputCall(&context, &response)); + + int aggregated_payload_size = 0; + for (size_t i = 0; i < request_stream_sizes.size(); ++i) { + Payload* payload = request.mutable_payload(); + payload->set_body(std::string(request_stream_sizes[i], '\0')); + if (!stream->Write(request)) { + gpr_log(GPR_ERROR, "DoRequestStreaming(): stream->Write() failed"); + return TransientFailureOrAbort(); + } + aggregated_payload_size += request_stream_sizes[i]; + } + GPR_ASSERT(stream->WritesDone()); + + Status s = stream->Finish(); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + GPR_ASSERT(response.aggregated_payload_size() == aggregated_payload_size); + return true; +} + +bool InteropClient::DoResponseStreaming() { + gpr_log(GPR_DEBUG, "Receiving response streaming rpc ..."); + + ClientContext context; + StreamingOutputCallRequest request; + for (unsigned int i = 0; i < response_stream_sizes.size(); ++i) { + ResponseParameters* response_parameter = request.add_response_parameters(); + response_parameter->set_size(response_stream_sizes[i]); + } + StreamingOutputCallResponse response; + std::unique_ptr> stream( + serviceStub_.Get()->StreamingOutputCall(&context, request)); + + unsigned int i = 0; + while (stream->Read(&response)) { + GPR_ASSERT(response.payload().body() == + std::string(response_stream_sizes[i], '\0')); + ++i; + } + + if (i < response_stream_sizes.size()) { + // stream->Read() failed before reading all the expected messages. This is + // most likely due to connection failure. + gpr_log(GPR_ERROR, + "DoResponseStreaming(): Read fewer streams (%d) than " + "response_stream_sizes.size() (%" PRIuPTR ")", + i, response_stream_sizes.size()); + return TransientFailureOrAbort(); + } + + Status s = stream->Finish(); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "Response streaming done."); + return true; +} + +bool InteropClient::DoClientCompressedStreaming() { + // Probing for compression-checks support. + ClientContext probe_context; + StreamingInputCallRequest probe_req; + StreamingInputCallResponse probe_res; + + probe_context.set_compression_algorithm(GRPC_COMPRESS_NONE); + probe_req.mutable_expect_compressed()->set_value(true); // lies! 
+ probe_req.mutable_payload()->set_body(std::string(27182, '\0')); + + gpr_log(GPR_DEBUG, "Sending probe for compressed streaming request."); + + std::unique_ptr> probe_stream( + serviceStub_.Get()->StreamingInputCall(&probe_context, &probe_res)); + + if (!probe_stream->Write(probe_req)) { + gpr_log(GPR_ERROR, "%s(): stream->Write() failed", __func__); + return TransientFailureOrAbort(); + } + Status s = probe_stream->Finish(); + if (s.error_code() != grpc::StatusCode::INVALID_ARGUMENT) { + // The server isn't able to evaluate incoming compression, making the rest + // of this test moot. + gpr_log(GPR_DEBUG, "Compressed streaming request probe failed"); + return false; + } + gpr_log(GPR_DEBUG, + "Compressed streaming request probe succeeded. Proceeding."); + + ClientContext context; + StreamingInputCallRequest request; + StreamingInputCallResponse response; + + context.set_compression_algorithm(GRPC_COMPRESS_GZIP); + std::unique_ptr> stream( + serviceStub_.Get()->StreamingInputCall(&context, &response)); + + request.mutable_payload()->set_body(std::string(27182, '\0')); + request.mutable_expect_compressed()->set_value(true); + gpr_log(GPR_DEBUG, "Sending streaming request with compression enabled"); + if (!stream->Write(request)) { + gpr_log(GPR_ERROR, "%s(): stream->Write() failed", __func__); + return TransientFailureOrAbort(); + } + + WriteOptions wopts; + wopts.set_no_compression(); + request.mutable_payload()->set_body(std::string(45904, '\0')); + request.mutable_expect_compressed()->set_value(false); + gpr_log(GPR_DEBUG, "Sending streaming request with compression disabled"); + if (!stream->Write(request, wopts)) { + gpr_log(GPR_ERROR, "%s(): stream->Write() failed", __func__); + return TransientFailureOrAbort(); + } + GPR_ASSERT(stream->WritesDone()); + + s = stream->Finish(); + return AssertStatusOk(s, context.debug_error_string()); +} + +bool InteropClient::DoServerCompressedStreaming() { + const std::vector compressions = {true, false}; + const std::vector sizes = {31415, 92653}; + + ClientContext context; + InteropClientContextInspector inspector(context); + StreamingOutputCallRequest request; + + GPR_ASSERT(compressions.size() == sizes.size()); + for (size_t i = 0; i < sizes.size(); i++) { + std::string log_suffix = + absl::StrFormat("(compression=%s; size=%d)", + compressions[i] ? "true" : "false", sizes[i]); + + gpr_log(GPR_DEBUG, "Sending request streaming rpc %s.", log_suffix.c_str()); + + ResponseParameters* const response_parameter = + request.add_response_parameters(); + response_parameter->mutable_compressed()->set_value(compressions[i]); + response_parameter->set_size(sizes[i]); + } + std::unique_ptr> stream( + serviceStub_.Get()->StreamingOutputCall(&context, request)); + + size_t k = 0; + StreamingOutputCallResponse response; + while (stream->Read(&response)) { + // Payload size checks. + GPR_ASSERT(response.payload().body() == + std::string(request.response_parameters(k).size(), '\0')); + + // Compression checks. + GPR_ASSERT(request.response_parameters(k).has_compressed()); + if (request.response_parameters(k).compressed().value()) { + GPR_ASSERT(inspector.GetCallCompressionAlgorithm() > GRPC_COMPRESS_NONE); + GPR_ASSERT(inspector.WasCompressed()); + } else { + // requested *no* compression. + GPR_ASSERT(!(inspector.WasCompressed())); + } + ++k; + } + + if (k < sizes.size()) { + // stream->Read() failed before reading all the expected messages. This + // is most likely due to a connection failure. 
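+ // TransientFailureOrAbort() (defined above) only returns false when the
+ // client was started with do_not_abort_on_transient_failures_; otherwise
+ // it calls abort() so that flaky-connection failures surface immediately.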
+ gpr_log(GPR_ERROR, + "%s(): Responses read (k=%" PRIuPTR + ") is less than the expected number of messages (%" PRIuPTR ").", + __func__, k, sizes.size()); + return TransientFailureOrAbort(); + } + + Status s = stream->Finish(); + return AssertStatusOk(s, context.debug_error_string()); +} + +bool InteropClient::DoResponseStreamingWithSlowConsumer() { + gpr_log(GPR_DEBUG, "Receiving response streaming rpc with slow consumer ..."); + + ClientContext context; + StreamingOutputCallRequest request; + + for (int i = 0; i < kNumResponseMessages; ++i) { + ResponseParameters* response_parameter = request.add_response_parameters(); + response_parameter->set_size(kResponseMessageSize); + } + StreamingOutputCallResponse response; + std::unique_ptr> stream( + serviceStub_.Get()->StreamingOutputCall(&context, request)); + + int i = 0; + while (stream->Read(&response)) { + GPR_ASSERT(response.payload().body() == + std::string(kResponseMessageSize, '\0')); + gpr_log(GPR_DEBUG, "received message %d", i); + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(kReceiveDelayMilliSeconds, GPR_TIMESPAN))); + ++i; + } + + if (i < kNumResponseMessages) { + gpr_log(GPR_ERROR, + "DoResponseStreamingWithSlowConsumer(): Responses read (i=%d) is " + "less than the expected messages (i.e kNumResponseMessages = %d)", + i, kNumResponseMessages); + + return TransientFailureOrAbort(); + } + + Status s = stream->Finish(); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "Response streaming done."); + return true; +} + +bool InteropClient::DoHalfDuplex() { + gpr_log(GPR_DEBUG, "Sending half-duplex streaming rpc ..."); + + ClientContext context; + std::unique_ptr> + stream(serviceStub_.Get()->HalfDuplexCall(&context)); + + StreamingOutputCallRequest request; + ResponseParameters* response_parameter = request.add_response_parameters(); + for (unsigned int i = 0; i < response_stream_sizes.size(); ++i) { + response_parameter->set_size(response_stream_sizes[i]); + + if (!stream->Write(request)) { + gpr_log(GPR_ERROR, "DoHalfDuplex(): stream->Write() failed. i=%d", i); + return TransientFailureOrAbort(); + } + } + stream->WritesDone(); + + unsigned int i = 0; + StreamingOutputCallResponse response; + while (stream->Read(&response)) { + GPR_ASSERT(response.payload().body() == + std::string(response_stream_sizes[i], '\0')); + ++i; + } + + if (i < response_stream_sizes.size()) { + // stream->Read() failed before reading all the expected messages. 
This is + // most likely due to a connection failure + gpr_log(GPR_ERROR, + "DoHalfDuplex(): Responses read (i=%d) are less than the expected " + "number of messages response_stream_sizes.size() (%" PRIuPTR ")", + i, response_stream_sizes.size()); + return TransientFailureOrAbort(); + } + + Status s = stream->Finish(); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "Half-duplex streaming rpc done."); + return true; +} + +bool InteropClient::DoPingPong() { + gpr_log(GPR_DEBUG, "Sending Ping Pong streaming rpc ..."); + + ClientContext context; + std::unique_ptr> + stream(serviceStub_.Get()->FullDuplexCall(&context)); + + StreamingOutputCallRequest request; + ResponseParameters* response_parameter = request.add_response_parameters(); + Payload* payload = request.mutable_payload(); + StreamingOutputCallResponse response; + + for (unsigned int i = 0; i < request_stream_sizes.size(); ++i) { + response_parameter->set_size(response_stream_sizes[i]); + payload->set_body(std::string(request_stream_sizes[i], '\0')); + + if (!stream->Write(request)) { + gpr_log(GPR_ERROR, "DoPingPong(): stream->Write() failed. i: %d", i); + return TransientFailureOrAbort(); + } + + if (!stream->Read(&response)) { + gpr_log(GPR_ERROR, "DoPingPong(): stream->Read() failed. i:%d", i); + return TransientFailureOrAbort(); + } + + GPR_ASSERT(response.payload().body() == + std::string(response_stream_sizes[i], '\0')); + } + + stream->WritesDone(); + + GPR_ASSERT(!stream->Read(&response)); + + Status s = stream->Finish(); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "Ping pong streaming done."); + return true; +} + +bool InteropClient::DoCancelAfterBegin() { + gpr_log(GPR_DEBUG, "Sending request streaming rpc ..."); + + ClientContext context; + StreamingInputCallRequest request; + StreamingInputCallResponse response; + + std::unique_ptr> stream( + serviceStub_.Get()->StreamingInputCall(&context, &response)); + + gpr_log(GPR_DEBUG, "Trying to cancel..."); + context.TryCancel(); + Status s = stream->Finish(); + + if (!AssertStatusCode(s, StatusCode::CANCELLED, + context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "Canceling streaming done."); + return true; +} + +bool InteropClient::DoCancelAfterFirstResponse() { + gpr_log(GPR_DEBUG, "Sending Ping Pong streaming rpc ..."); + + ClientContext context; + std::unique_ptr> + stream(serviceStub_.Get()->FullDuplexCall(&context)); + + StreamingOutputCallRequest request; + ResponseParameters* response_parameter = request.add_response_parameters(); + response_parameter->set_size(31415); + request.mutable_payload()->set_body(std::string(27182, '\0')); + StreamingOutputCallResponse response; + + if (!stream->Write(request)) { + gpr_log(GPR_ERROR, "DoCancelAfterFirstResponse(): stream->Write() failed"); + return TransientFailureOrAbort(); + } + + if (!stream->Read(&response)) { + gpr_log(GPR_ERROR, "DoCancelAfterFirstResponse(): stream->Read failed"); + return TransientFailureOrAbort(); + } + GPR_ASSERT(response.payload().body() == std::string(31415, '\0')); + + gpr_log(GPR_DEBUG, "Trying to cancel..."); + context.TryCancel(); + + Status s = stream->Finish(); + gpr_log(GPR_DEBUG, "Canceling pingpong streaming done."); + return true; +} + +bool InteropClient::DoTimeoutOnSleepingServer() { + gpr_log(GPR_DEBUG, + "Sending Ping Pong streaming rpc with a short deadline..."); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + 
std::chrono::system_clock::now() + std::chrono::milliseconds(1); + context.set_deadline(deadline); + std::unique_ptr> + stream(serviceStub_.Get()->FullDuplexCall(&context)); + + StreamingOutputCallRequest request; + request.mutable_payload()->set_body(std::string(27182, '\0')); + stream->Write(request); + + Status s = stream->Finish(); + if (!AssertStatusCode(s, StatusCode::DEADLINE_EXCEEDED, + context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "Pingpong streaming timeout done."); + return true; +} + +bool InteropClient::DoEmptyStream() { + gpr_log(GPR_DEBUG, "Starting empty_stream."); + + ClientContext context; + std::unique_ptr> + stream(serviceStub_.Get()->FullDuplexCall(&context)); + stream->WritesDone(); + StreamingOutputCallResponse response; + GPR_ASSERT(stream->Read(&response) == false); + + Status s = stream->Finish(); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "empty_stream done."); + return true; +} + +bool InteropClient::DoStatusWithMessage() { + gpr_log(GPR_DEBUG, + "Sending RPC with a request for status code 2 and message"); + + const grpc::StatusCode test_code = grpc::StatusCode::UNKNOWN; + const std::string test_msg = "This is a test message"; + + // Test UnaryCall. + ClientContext context; + SimpleRequest request; + SimpleResponse response; + EchoStatus* requested_status = request.mutable_response_status(); + requested_status->set_code(test_code); + requested_status->set_message(test_msg); + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + if (!AssertStatusCode(s, grpc::StatusCode::UNKNOWN, + context.debug_error_string())) { + return false; + } + GPR_ASSERT(s.error_message() == test_msg); + + // Test FullDuplexCall. + ClientContext stream_context; + std::shared_ptr> + stream(serviceStub_.Get()->FullDuplexCall(&stream_context)); + StreamingOutputCallRequest streaming_request; + requested_status = streaming_request.mutable_response_status(); + requested_status->set_code(test_code); + requested_status->set_message(test_msg); + stream->Write(streaming_request); + stream->WritesDone(); + StreamingOutputCallResponse streaming_response; + while (stream->Read(&streaming_response)) { + } + s = stream->Finish(); + if (!AssertStatusCode(s, grpc::StatusCode::UNKNOWN, + context.debug_error_string())) { + return false; + } + GPR_ASSERT(s.error_message() == test_msg); + + gpr_log(GPR_DEBUG, "Done testing Status and Message"); + return true; +} + +bool InteropClient::DoSpecialStatusMessage() { + gpr_log( + GPR_DEBUG, + "Sending RPC with a request for status code 2 and message - \\t\\ntest " + "with whitespace\\r\\nand Unicode BMP ☺ and non-BMP 😈\\t\\n"); + const grpc::StatusCode test_code = grpc::StatusCode::UNKNOWN; + const std::string test_msg = + "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP 😈\t\n"; + ClientContext context; + SimpleRequest request; + SimpleResponse response; + EchoStatus* requested_status = request.mutable_response_status(); + requested_status->set_code(test_code); + requested_status->set_message(test_msg); + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + if (!AssertStatusCode(s, grpc::StatusCode::UNKNOWN, + context.debug_error_string())) { + return false; + } + GPR_ASSERT(s.error_message() == test_msg); + gpr_log(GPR_DEBUG, "Done testing Special Status Message"); + return true; +} + +bool InteropClient::DoCacheableUnary() { + gpr_log(GPR_DEBUG, "Sending RPC with cacheable response"); + + // Create request with current 
timestamp + gpr_timespec ts = gpr_now(GPR_CLOCK_PRECISE); + std::string timestamp = + std::to_string(static_cast(ts.tv_nsec)); + SimpleRequest request; + request.mutable_payload()->set_body(timestamp.c_str(), timestamp.size()); + + // Request 1 + ClientContext context1; + SimpleResponse response1; + context1.set_cacheable(true); + // Add fake user IP since some proxy's (GFE) won't cache requests from + // localhost. + context1.AddMetadata("x-user-ip", "1.2.3.4"); + Status s1 = + serviceStub_.Get()->CacheableUnaryCall(&context1, request, &response1); + if (!AssertStatusOk(s1, context1.debug_error_string())) { + return false; + } + gpr_log(GPR_DEBUG, "response 1 payload: %s", + response1.payload().body().c_str()); + + // Request 2 + ClientContext context2; + SimpleResponse response2; + context2.set_cacheable(true); + context2.AddMetadata("x-user-ip", "1.2.3.4"); + Status s2 = + serviceStub_.Get()->CacheableUnaryCall(&context2, request, &response2); + if (!AssertStatusOk(s2, context2.debug_error_string())) { + return false; + } + gpr_log(GPR_DEBUG, "response 2 payload: %s", + response2.payload().body().c_str()); + + // Check that the body is same for both requests. It will be the same if the + // second response is a cached copy of the first response + GPR_ASSERT(response2.payload().body() == response1.payload().body()); + + // Request 3 + // Modify the request body so it will not get a cache hit + ts = gpr_now(GPR_CLOCK_PRECISE); + timestamp = std::to_string(static_cast(ts.tv_nsec)); + SimpleRequest request1; + request1.mutable_payload()->set_body(timestamp.c_str(), timestamp.size()); + ClientContext context3; + SimpleResponse response3; + context3.set_cacheable(true); + context3.AddMetadata("x-user-ip", "1.2.3.4"); + Status s3 = + serviceStub_.Get()->CacheableUnaryCall(&context3, request1, &response3); + if (!AssertStatusOk(s3, context3.debug_error_string())) { + return false; + } + gpr_log(GPR_DEBUG, "response 3 payload: %s", + response3.payload().body().c_str()); + + // Check that the response is different from the previous response. 
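+ // request1 carries a fresh timestamp in its payload, so a cache keyed on
+ // the request body (the server marks responses with
+ // "cache-control: max-age=60, public") cannot serve the earlier cached
+ // response; the two payloads must therefore differ.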
+ GPR_ASSERT(response3.payload().body() != response1.payload().body()); + return true; +} + +bool InteropClient::DoPickFirstUnary() { + const int rpcCount = 100; + SimpleRequest request; + SimpleResponse response; + std::string server_id; + request.set_fill_server_id(true); + for (int i = 0; i < rpcCount; i++) { + ClientContext context; + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + if (i == 0) { + server_id = response.server_id(); + continue; + } + if (response.server_id() != server_id) { + gpr_log(GPR_ERROR, "#%d rpc hits server_id %s, expect server_id %s", i, + response.server_id().c_str(), server_id.c_str()); + return false; + } + } + gpr_log(GPR_DEBUG, "pick first unary successfully finished"); + return true; +} + +bool InteropClient::DoCustomMetadata() { + const std::string kEchoInitialMetadataKey("x-grpc-test-echo-initial"); + const std::string kInitialMetadataValue("test_initial_metadata_value"); + const std::string kEchoTrailingBinMetadataKey( + "x-grpc-test-echo-trailing-bin"); + const std::string kTrailingBinValue("\x0a\x0b\x0a\x0b\x0a\x0b"); + + { + gpr_log(GPR_DEBUG, "Sending RPC with custom metadata"); + ClientContext context; + context.AddMetadata(kEchoInitialMetadataKey, kInitialMetadataValue); + context.AddMetadata(kEchoTrailingBinMetadataKey, kTrailingBinValue); + SimpleRequest request; + SimpleResponse response; + request.set_response_size(kLargeResponseSize); + std::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + if (!AssertStatusOk(s, context.debug_error_string())) { + return false; + } + + const auto& server_initial_metadata = context.GetServerInitialMetadata(); + auto iter = server_initial_metadata.find(kEchoInitialMetadataKey); + GPR_ASSERT(iter != server_initial_metadata.end()); + GPR_ASSERT(iter->second == kInitialMetadataValue); + const auto& server_trailing_metadata = context.GetServerTrailingMetadata(); + iter = server_trailing_metadata.find(kEchoTrailingBinMetadataKey); + GPR_ASSERT(iter != server_trailing_metadata.end()); + GPR_ASSERT(std::string(iter->second.begin(), iter->second.end()) == + kTrailingBinValue); + + gpr_log(GPR_DEBUG, "Done testing RPC with custom metadata"); + } + + { + gpr_log(GPR_DEBUG, "Sending stream with custom metadata"); + ClientContext context; + context.AddMetadata(kEchoInitialMetadataKey, kInitialMetadataValue); + context.AddMetadata(kEchoTrailingBinMetadataKey, kTrailingBinValue); + std::unique_ptr> + stream(serviceStub_.Get()->FullDuplexCall(&context)); + + StreamingOutputCallRequest request; + ResponseParameters* response_parameter = request.add_response_parameters(); + response_parameter->set_size(kLargeResponseSize); + std::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + StreamingOutputCallResponse response; + + if (!stream->Write(request)) { + gpr_log(GPR_ERROR, "DoCustomMetadata(): stream->Write() failed"); + return TransientFailureOrAbort(); + } + + stream->WritesDone(); + + if (!stream->Read(&response)) { + gpr_log(GPR_ERROR, "DoCustomMetadata(): stream->Read() failed"); + return TransientFailureOrAbort(); + } + + GPR_ASSERT(response.payload().body() == + std::string(kLargeResponseSize, '\0')); + + GPR_ASSERT(!stream->Read(&response)); + + Status s = stream->Finish(); + if (!AssertStatusOk(s, 
context.debug_error_string())) { + return false; + } + + const auto& server_initial_metadata = context.GetServerInitialMetadata(); + auto iter = server_initial_metadata.find(kEchoInitialMetadataKey); + GPR_ASSERT(iter != server_initial_metadata.end()); + GPR_ASSERT(iter->second == kInitialMetadataValue); + const auto& server_trailing_metadata = context.GetServerTrailingMetadata(); + iter = server_trailing_metadata.find(kEchoTrailingBinMetadataKey); + GPR_ASSERT(iter != server_trailing_metadata.end()); + GPR_ASSERT(std::string(iter->second.begin(), iter->second.end()) == + kTrailingBinValue); + + gpr_log(GPR_DEBUG, "Done testing stream with custom metadata"); + } + + return true; +} + +std::tuple +InteropClient::PerformOneSoakTestIteration( + const bool reset_channel, + const int32_t max_acceptable_per_iteration_latency_ms) { + gpr_timespec start = gpr_now(GPR_CLOCK_MONOTONIC); + SimpleRequest request; + SimpleResponse response; + // Don't set the deadline on the RPC, and instead just + // record how long the RPC took and compare. This makes + // debugging easier when looking at failure results. + ClientContext context; + InteropClientContextInspector inspector(context); + request.set_response_size(kLargeResponseSize); + std::string payload(kLargeRequestSize, '\0'); + request.mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); + if (reset_channel) { + serviceStub_.ResetChannel(); + } + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + int32_t elapsed_ms = gpr_time_to_millis(gpr_time_sub(now, start)); + if (!s.ok()) { + return std::make_tuple(false, elapsed_ms, context.debug_error_string()); + } else if (elapsed_ms > max_acceptable_per_iteration_latency_ms) { + std::string debug_string = absl::StrFormat( + "%d ms exceeds max acceptable latency: %d ms, peer: %s", elapsed_ms, + max_acceptable_per_iteration_latency_ms, context.peer()); + return std::make_tuple(false, elapsed_ms, std::move(debug_string)); + } else { + return std::make_tuple(true, elapsed_ms, ""); + } +} + +void InteropClient::PerformSoakTest( + const bool reset_channel_per_iteration, const int32_t soak_iterations, + const int32_t max_failures, + const int32_t max_acceptable_per_iteration_latency_ms, + const int32_t overall_timeout_seconds) { + std::vector> results; + grpc_histogram* latencies_ms_histogram = grpc_histogram_create( + 1 /* resolution */, + 500 * 1e3 /* largest bucket; 500 seconds is unlikely */); + gpr_timespec overall_deadline = gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(overall_timeout_seconds, GPR_TIMESPAN)); + int32_t iterations_ran = 0; + for (int i = 0; + i < soak_iterations && + gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), overall_deadline) < 0; + ++i) { + auto result = PerformOneSoakTestIteration( + reset_channel_per_iteration, max_acceptable_per_iteration_latency_ms); + results.push_back(result); + grpc_histogram_add(latencies_ms_histogram, std::get<1>(result)); + iterations_ran++; + } + int total_failures = 0; + for (size_t i = 0; i < results.size(); i++) { + bool success = std::get<0>(results[i]); + int32_t elapsed_ms = std::get<1>(results[i]); + std::string debug_string = std::get<2>(results[i]); + if (!success) { + gpr_log(GPR_DEBUG, "soak iteration: %ld elapsed_ms: %d failed: %s", i, + elapsed_ms, debug_string.c_str()); + total_failures++; + } else { + gpr_log(GPR_DEBUG, "soak iteration: %ld elapsed_ms: %d succeeded", i, + elapsed_ms); + } + } + double latency_ms_median = + 
grpc_histogram_percentile(latencies_ms_histogram, 50); + double latency_ms_90th = + grpc_histogram_percentile(latencies_ms_histogram, 90); + double latency_ms_worst = grpc_histogram_maximum(latencies_ms_histogram); + grpc_histogram_destroy(latencies_ms_histogram); + if (iterations_ran < soak_iterations) { + gpr_log( + GPR_ERROR, + "soak test consumed all %d seconds of time and quit early, only " + "having ran %d out of desired %d iterations. " + "total_failures: %d. " + "max_failures_threshold: %d. " + "median_soak_iteration_latency: %lf ms. " + "90th_soak_iteration_latency: %lf ms. " + "worst_soak_iteration_latency: %lf ms. " + "Some or all of the iterations that did run were unexpectedly slow. " + "See breakdown above for which iterations succeeded, failed, and " + "why for more info.", + overall_timeout_seconds, iterations_ran, soak_iterations, + total_failures, max_failures, latency_ms_median, latency_ms_90th, + latency_ms_worst); + GPR_ASSERT(0); + } else if (total_failures > max_failures) { + gpr_log(GPR_ERROR, + "soak test ran: %d iterations. total_failures: %d exceeds " + "max_failures_threshold: %d. " + "median_soak_iteration_latency: %lf ms. " + "90th_soak_iteration_latency: %lf ms. " + "worst_soak_iteration_latency: %lf ms. " + "See breakdown above for which iterations succeeded, failed, and " + "why for more info.", + soak_iterations, total_failures, max_failures, latency_ms_median, + latency_ms_90th, latency_ms_worst); + GPR_ASSERT(0); + } else { + gpr_log(GPR_INFO, + "soak test ran: %d iterations. total_failures: %d is within " + "max_failures_threshold: %d. " + "median_soak_iteration_latency: %lf ms. " + "90th_soak_iteration_latency: %lf ms. " + "worst_soak_iteration_latency: %lf ms. " + "See breakdown above for which iterations succeeded, failed, and " + "why for more info.", + soak_iterations, total_failures, max_failures, latency_ms_median, + latency_ms_90th, latency_ms_worst); + } +} + +bool InteropClient::DoRpcSoakTest( + int32_t soak_iterations, int32_t max_failures, + int64_t max_acceptable_per_iteration_latency_ms, + int32_t overall_timeout_seconds) { + gpr_log(GPR_DEBUG, "Sending %d RPCs...", soak_iterations); + GPR_ASSERT(soak_iterations > 0); + PerformSoakTest(false /* reset channel per iteration */, soak_iterations, + max_failures, max_acceptable_per_iteration_latency_ms, + overall_timeout_seconds); + gpr_log(GPR_DEBUG, "rpc_soak test done."); + return true; +} + +bool InteropClient::DoChannelSoakTest( + int32_t soak_iterations, int32_t max_failures, + int64_t max_acceptable_per_iteration_latency_ms, + int32_t overall_timeout_seconds) { + gpr_log(GPR_DEBUG, "Sending %d RPCs, tearing down the channel each time...", + soak_iterations); + GPR_ASSERT(soak_iterations > 0); + PerformSoakTest(true /* reset channel per iteration */, soak_iterations, + max_failures, max_acceptable_per_iteration_latency_ms, + overall_timeout_seconds); + gpr_log(GPR_DEBUG, "channel_soak test done."); + return true; +} + +bool InteropClient::DoLongLivedChannelTest(int32_t soak_iterations, + int32_t iteration_interval) { + gpr_log(GPR_DEBUG, "Sending %d RPCs...", soak_iterations); + GPR_ASSERT(soak_iterations > 0); + GPR_ASSERT(iteration_interval > 0); + SimpleRequest request; + SimpleResponse response; + int num_failures = 0; + for (int i = 0; i < soak_iterations; ++i) { + gpr_log(GPR_DEBUG, "Sending RPC number %d...", i); + if (!PerformLargeUnary(&request, &response)) { + gpr_log(GPR_ERROR, "Iteration %d failed.", i); + num_failures++; + } + gpr_sleep_until( + 
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(iteration_interval, GPR_TIMESPAN))); + } + if (num_failures == 0) { + gpr_log(GPR_DEBUG, "long_lived_channel test done."); + return true; + } else { + gpr_log(GPR_DEBUG, "long_lived_channel test failed with %d rpc failures.", + num_failures); + return false; + } +} + +bool InteropClient::DoUnimplementedService() { + gpr_log(GPR_DEBUG, "Sending a request for an unimplemented service..."); + + Empty request; + Empty response; + ClientContext context; + + UnimplementedService::Stub* stub = serviceStub_.GetUnimplementedServiceStub(); + + Status s = stub->UnimplementedCall(&context, request, &response); + + if (!AssertStatusCode(s, StatusCode::UNIMPLEMENTED, + context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "unimplemented service done."); + return true; +} + +bool InteropClient::DoUnimplementedMethod() { + gpr_log(GPR_DEBUG, "Sending a request for an unimplemented rpc..."); + + Empty request; + Empty response; + ClientContext context; + + Status s = + serviceStub_.Get()->UnimplementedCall(&context, request, &response); + + if (!AssertStatusCode(s, StatusCode::UNIMPLEMENTED, + context.debug_error_string())) { + return false; + } + + gpr_log(GPR_DEBUG, "unimplemented rpc done."); + return true; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/interop/interop_server.cc b/test/cpp/interop/interop_server.cc new file mode 100644 index 00000000..10661465 --- /dev/null +++ b/test/cpp/interop/interop_server.cc @@ -0,0 +1,377 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/proto/grpc/testing/empty.pb.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/cpp/interop/server_helper.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(bool, use_alts, false, + "Whether to use alts. 
Enable alts will disable tls."); +ABSL_FLAG(bool, use_tls, false, "Whether to use tls."); +ABSL_FLAG(std::string, custom_credentials_type, "", + "User provided credentials type."); +ABSL_FLAG(int32_t, port, 0, "Server port."); +ABSL_FLAG(int32_t, max_send_message_size, -1, "The maximum send message size."); + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::ServerCredentials; +using grpc::ServerReader; +using grpc::ServerReaderWriter; +using grpc::ServerWriter; +using grpc::Status; +using grpc::WriteOptions; +using grpc::testing::InteropServerContextInspector; +using grpc::testing::Payload; +using grpc::testing::SimpleRequest; +using grpc::testing::SimpleResponse; +using grpc::testing::StreamingInputCallRequest; +using grpc::testing::StreamingInputCallResponse; +using grpc::testing::StreamingOutputCallRequest; +using grpc::testing::StreamingOutputCallResponse; +using grpc::testing::TestService; + +const char kEchoInitialMetadataKey[] = "x-grpc-test-echo-initial"; +const char kEchoTrailingBinMetadataKey[] = "x-grpc-test-echo-trailing-bin"; +const char kEchoUserAgentKey[] = "x-grpc-test-echo-useragent"; + +void MaybeEchoMetadata(ServerContext* context) { + const auto& client_metadata = context->client_metadata(); + GPR_ASSERT(client_metadata.count(kEchoInitialMetadataKey) <= 1); + GPR_ASSERT(client_metadata.count(kEchoTrailingBinMetadataKey) <= 1); + + auto iter = client_metadata.find(kEchoInitialMetadataKey); + if (iter != client_metadata.end()) { + context->AddInitialMetadata( + kEchoInitialMetadataKey, + std::string(iter->second.begin(), iter->second.end())); + } + iter = client_metadata.find(kEchoTrailingBinMetadataKey); + if (iter != client_metadata.end()) { + context->AddTrailingMetadata( + kEchoTrailingBinMetadataKey, + std::string(iter->second.begin(), iter->second.end())); + } + // Check if client sent a magic key in the header that makes us echo + // back the user-agent (for testing purpose) + iter = client_metadata.find(kEchoUserAgentKey); + if (iter != client_metadata.end()) { + iter = client_metadata.find("user-agent"); + if (iter != client_metadata.end()) { + context->AddInitialMetadata( + kEchoUserAgentKey, + std::string(iter->second.begin(), iter->second.end())); + } + } +} + +bool SetPayload(int size, Payload* payload) { + std::unique_ptr body(new char[size]()); + payload->set_body(body.get(), size); + return true; +} + +bool CheckExpectedCompression(const ServerContext& context, + const bool compression_expected) { + const InteropServerContextInspector inspector(context); + const grpc_compression_algorithm received_compression = + inspector.GetCallCompressionAlgorithm(); + + if (compression_expected) { + if (received_compression == GRPC_COMPRESS_NONE) { + // Expected some compression, got NONE. This is an error. 
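+ // The call-level algorithm already reports NONE, so there is no need to
+ // also check the per-message compressed flag; report the mismatch and let
+ // the caller turn it into an INVALID_ARGUMENT status.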
+ gpr_log(GPR_ERROR, + "Expected compression but got uncompressed request from client."); + return false; + } + if (!(inspector.WasCompressed())) { + gpr_log(GPR_ERROR, + "Failure: Requested compression in a compressable request, but " + "compression bit in message flags not set."); + return false; + } + } else { + // Didn't expect compression -> make sure the request is uncompressed + if (inspector.WasCompressed()) { + gpr_log(GPR_ERROR, + "Failure: Didn't requested compression, but compression bit in " + "message flags set."); + return false; + } + } + return true; +} + +class TestServiceImpl : public TestService::Service { + public: + Status EmptyCall(ServerContext* context, + const grpc::testing::Empty* /*request*/, + grpc::testing::Empty* /*response*/) override { + MaybeEchoMetadata(context); + return Status::OK; + } + + // Response contains current timestamp. We ignore everything in the request. + Status CacheableUnaryCall(ServerContext* context, + const SimpleRequest* /*request*/, + SimpleResponse* response) override { + gpr_timespec ts = gpr_now(GPR_CLOCK_PRECISE); + std::string timestamp = std::to_string(ts.tv_nsec); + response->mutable_payload()->set_body(timestamp.c_str(), timestamp.size()); + context->AddInitialMetadata("cache-control", "max-age=60, public"); + return Status::OK; + } + + Status UnaryCall(ServerContext* context, const SimpleRequest* request, + SimpleResponse* response) override { + MaybeEchoMetadata(context); + if (request->has_response_compressed()) { + const bool compression_requested = request->response_compressed().value(); + gpr_log(GPR_DEBUG, "Request for compression (%s) present for %s", + compression_requested ? "enabled" : "disabled", __func__); + if (compression_requested) { + // Any level would do, let's go for HIGH because we are overachievers. + context->set_compression_level(GRPC_COMPRESS_LEVEL_HIGH); + } else { + context->set_compression_level(GRPC_COMPRESS_LEVEL_NONE); + } + } + if (!CheckExpectedCompression(*context, + request->expect_compressed().value())) { + return Status(grpc::StatusCode::INVALID_ARGUMENT, + "Compressed request expectation not met."); + } + if (request->response_size() > 0) { + if (!SetPayload(request->response_size(), response->mutable_payload())) { + return Status(grpc::StatusCode::INVALID_ARGUMENT, + "Error creating payload."); + } + } + + if (request->has_response_status()) { + return Status( + static_cast(request->response_status().code()), + request->response_status().message()); + } + + return Status::OK; + } + + Status StreamingOutputCall( + ServerContext* context, const StreamingOutputCallRequest* request, + ServerWriter* writer) override { + StreamingOutputCallResponse response; + bool write_success = true; + for (int i = 0; write_success && i < request->response_parameters_size(); + i++) { + if (!SetPayload(request->response_parameters(i).size(), + response.mutable_payload())) { + return Status(grpc::StatusCode::INVALID_ARGUMENT, + "Error creating payload."); + } + WriteOptions wopts; + if (request->response_parameters(i).has_compressed()) { + // Compress by default. Disabled on a per-message basis. + context->set_compression_level(GRPC_COMPRESS_LEVEL_HIGH); + const bool compression_requested = + request->response_parameters(i).compressed().value(); + gpr_log(GPR_DEBUG, "Request for compression (%s) present for %s", + compression_requested ? "enabled" : "disabled", __func__); + if (!compression_requested) { + wopts.set_no_compression(); + } // else, compression is already enabled via the context. 
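+ // Compression is toggled per message: the context sets the level for the
+ // whole call, while wopts.set_no_compression() clears it for this single
+ // Write(), mirroring how the client disables compression for selected
+ // messages in DoClientCompressedStreaming().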
+ } + int time_us; + if ((time_us = request->response_parameters(i).interval_us()) > 0) { + // Sleep before response if needed + gpr_timespec sleep_time = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(time_us, GPR_TIMESPAN)); + gpr_sleep_until(sleep_time); + } + write_success = writer->Write(response, wopts); + } + if (write_success) { + return Status::OK; + } else { + return Status(grpc::StatusCode::INTERNAL, "Error writing response."); + } + } + + Status StreamingInputCall(ServerContext* context, + ServerReader* reader, + StreamingInputCallResponse* response) override { + StreamingInputCallRequest request; + int aggregated_payload_size = 0; + while (reader->Read(&request)) { + if (!CheckExpectedCompression(*context, + request.expect_compressed().value())) { + return Status(grpc::StatusCode::INVALID_ARGUMENT, + "Compressed request expectation not met."); + } + if (request.has_payload()) { + aggregated_payload_size += request.payload().body().size(); + } + } + response->set_aggregated_payload_size(aggregated_payload_size); + return Status::OK; + } + + Status FullDuplexCall( + ServerContext* context, + ServerReaderWriter* stream) override { + MaybeEchoMetadata(context); + StreamingOutputCallRequest request; + StreamingOutputCallResponse response; + bool write_success = true; + while (write_success && stream->Read(&request)) { + if (request.has_response_status()) { + return Status( + static_cast(request.response_status().code()), + request.response_status().message()); + } + if (request.response_parameters_size() != 0) { + response.mutable_payload()->set_type(request.payload().type()); + response.mutable_payload()->set_body( + std::string(request.response_parameters(0).size(), '\0')); + int time_us; + if ((time_us = request.response_parameters(0).interval_us()) > 0) { + // Sleep before response if needed + gpr_timespec sleep_time = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(time_us, GPR_TIMESPAN)); + gpr_sleep_until(sleep_time); + } + write_success = stream->Write(response); + } + } + if (write_success) { + return Status::OK; + } else { + return Status(grpc::StatusCode::INTERNAL, "Error writing response."); + } + } + + Status HalfDuplexCall( + ServerContext* /*context*/, + ServerReaderWriter* stream) override { + std::vector requests; + StreamingOutputCallRequest request; + while (stream->Read(&request)) { + requests.push_back(request); + } + + StreamingOutputCallResponse response; + bool write_success = true; + for (unsigned int i = 0; write_success && i < requests.size(); i++) { + response.mutable_payload()->set_type(requests[i].payload().type()); + if (requests[i].response_parameters_size() == 0) { + return Status(grpc::StatusCode::INTERNAL, + "Request does not have response parameters."); + } + response.mutable_payload()->set_body( + std::string(requests[i].response_parameters(0).size(), '\0')); + write_success = stream->Write(response); + } + if (write_success) { + return Status::OK; + } else { + return Status(grpc::StatusCode::INTERNAL, "Error writing response."); + } + } +}; + +void grpc::testing::interop::RunServer( + const std::shared_ptr& creds) { + RunServer(creds, absl::GetFlag(FLAGS_port), nullptr, nullptr); +} + +void grpc::testing::interop::RunServer( + const std::shared_ptr& creds, + std::unique_ptr>> + server_options) { + RunServer(creds, absl::GetFlag(FLAGS_port), nullptr, + std::move(server_options)); +} + +void grpc::testing::interop::RunServer( + const std::shared_ptr& creds, const int port, + ServerStartedCondition* 
server_started_condition) { + RunServer(creds, port, server_started_condition, nullptr); +} + +void grpc::testing::interop::RunServer( + const std::shared_ptr& creds, const int port, + ServerStartedCondition* server_started_condition, + std::unique_ptr>> + server_options) { + GPR_ASSERT(port != 0); + std::ostringstream server_address; + server_address << "0.0.0.0:" << port; + TestServiceImpl service; + + SimpleRequest request; + SimpleResponse response; + + ServerBuilder builder; + builder.RegisterService(&service); + builder.AddListeningPort(server_address.str(), creds); + if (server_options != nullptr) { + for (size_t i = 0; i < server_options->size(); i++) { + builder.SetOption(std::move((*server_options)[i])); + } + } + if (absl::GetFlag(FLAGS_max_send_message_size) >= 0) { + builder.SetMaxSendMessageSize(absl::GetFlag(FLAGS_max_send_message_size)); + } + std::unique_ptr server(builder.BuildAndStart()); + gpr_log(GPR_INFO, "Server listening on %s", server_address.str().c_str()); + + // Signal that the server has started. + if (server_started_condition) { + std::unique_lock lock(server_started_condition->mutex); + server_started_condition->server_started = true; + server_started_condition->condition.notify_all(); + } + + while (!gpr_atm_no_barrier_load(&g_got_sigint)) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(5, GPR_TIMESPAN))); + } +} diff --git a/test/cpp/interop/interop_server_bootstrap.cc b/test/cpp/interop/interop_server_bootstrap.cc new file mode 100644 index 00000000..a33f60fe --- /dev/null +++ b/test/cpp/interop/interop_server_bootstrap.cc @@ -0,0 +1,40 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/interop/server_helper.h" +#include "test/cpp/util/test_config.h" + +gpr_atm grpc::testing::interop::g_got_sigint; + +static void sigint_handler(int /*x*/) { + gpr_atm_no_barrier_store(&grpc::testing::interop::g_got_sigint, true); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + signal(SIGINT, sigint_handler); + + grpc::testing::interop::RunServer( + grpc::testing::CreateInteropServerCredentials()); + + return 0; +} diff --git a/test/cpp/interop/interop_test.cc b/test/cpp/interop/interop_test.cc new file mode 100644 index 00000000..f0985e7b --- /dev/null +++ b/test/cpp/interop/interop_test.cc @@ -0,0 +1,123 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/strings/str_cat.h" + +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/socket_utils_posix.h" +#include "test/core/util/port.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(std::string, extra_server_flags, "", + "Extra flags to pass to server."); + +int test_client(const char* root, const char* host, int port) { + int status; + pid_t cli; + cli = fork(); + if (cli == 0) { + std::string binary_path = absl::StrCat(root, "/interop_client"); + std::string port_arg = absl::StrCat("--server_port=", port); + execl(binary_path.c_str(), binary_path.c_str(), port_arg.c_str(), NULL); + return 1; + } + /* wait for client */ + gpr_log(GPR_INFO, "Waiting for client: %s", host); + if (waitpid(cli, &status, 0) == -1) return 2; + if (!WIFEXITED(status)) return 4; + if (WEXITSTATUS(status)) return WEXITSTATUS(status); + return 0; +} + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + char* me = argv[0]; + char* lslash = strrchr(me, '/'); + char root[1024]; + int port = grpc_pick_unused_port_or_die(); + int status; + pid_t svr; + int ret; + int do_ipv6 = 1; + /* seed rng with pid, so we don't end up with the same random numbers as a + concurrently running test binary */ + srand(getpid()); + if (!grpc_ipv6_loopback_available()) { + gpr_log(GPR_INFO, "Can't bind to ::1. Skipping IPv6 tests."); + do_ipv6 = 0; + } + /* figure out where we are */ + if (lslash) { + memcpy(root, me, lslash - me); + root[lslash - me] = 0; + } else { + strcpy(root, "."); + } + /* start the server */ + svr = fork(); + if (svr == 0) { + std::vector args; + std::string command = absl::StrCat(root, "/interop_server"); + args.push_back(const_cast(command.c_str())); + std::string port_arg = absl::StrCat("--port=", port); + args.push_back(const_cast(port_arg.c_str())); + if (!absl::GetFlag(FLAGS_extra_server_flags).empty()) { + args.push_back( + const_cast(absl::GetFlag(FLAGS_extra_server_flags).c_str())); + } + args.push_back(nullptr); + execv(args[0], args.data()); + return 1; + } + /* wait a little */ + sleep(10); + /* start the clients */ + ret = test_client(root, "127.0.0.1", port); + if (ret != 0) return ret; + ret = test_client(root, "::ffff:127.0.0.1", port); + if (ret != 0) return ret; + ret = test_client(root, "localhost", port); + if (ret != 0) return ret; + if (do_ipv6) { + ret = test_client(root, "::1", port); + if (ret != 0) return ret; + } + /* wait for server */ + gpr_log(GPR_INFO, "Waiting for server"); + kill(svr, SIGINT); + if (waitpid(svr, &status, 0) == -1) return 2; + if (!WIFEXITED(status)) return 4; + if (WEXITSTATUS(status)) return WEXITSTATUS(status); + return 0; +} diff --git a/test/cpp/interop/metrics_client.cc b/test/cpp/interop/metrics_client.cc new file mode 100644 index 00000000..f733dfc2 --- /dev/null +++ b/test/cpp/interop/metrics_client.cc @@ -0,0 +1,108 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *is % allowed in string + */ + +#include +#include + +#include "absl/flags/flag.h" + +#include +#include + +#include "src/proto/grpc/testing/metrics.grpc.pb.h" +#include "src/proto/grpc/testing/metrics.pb.h" +#include "test/cpp/util/metrics_server.h" +#include "test/cpp/util/test_config.h" + +int kDeadlineSecs = 10; + +ABSL_FLAG(std::string, metrics_server_address, "localhost:8081", + "The metrics server addresses in the fomrat :"); +// TODO(Capstan): Consider using absl::Duration +ABSL_FLAG(int32_t, deadline_secs, kDeadlineSecs, + "The deadline (in seconds) for RCP call"); +ABSL_FLAG(bool, total_only, false, + "If true, this prints only the total value of all gauges"); + +using grpc::testing::EmptyMessage; +using grpc::testing::GaugeResponse; +using grpc::testing::MetricsService; + +// Do not log anything +void BlackholeLogger(gpr_log_func_args* /*args*/) {} + +// Prints the values of all Gauges (unless total_only is set to 'true' in which +// case this only prints the sum of all gauge values). +bool PrintMetrics(std::unique_ptr stub, bool total_only, + int deadline_secs) { + grpc::ClientContext context; + EmptyMessage message; + + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::seconds(deadline_secs); + + context.set_deadline(deadline); + + std::unique_ptr> reader( + stub->GetAllGauges(&context, message)); + + GaugeResponse gauge_response; + long overall_qps = 0; + while (reader->Read(&gauge_response)) { + if (gauge_response.value_case() == GaugeResponse::kLongValue) { + if (!total_only) { + std::cout << gauge_response.name() << ": " + << gauge_response.long_value() << std::endl; + } + overall_qps += gauge_response.long_value(); + } else { + std::cout << "Gauge '" << gauge_response.name() << "' is not long valued" + << std::endl; + } + } + + std::cout << overall_qps << std::endl; + + const grpc::Status status = reader->Finish(); + if (!status.ok()) { + std::cout << "Error in getting metrics from the client" << std::endl; + } + + return status.ok(); +} + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + + // The output of metrics client is in some cases programmatically parsed (for + // example by the stress test framework). So, we do not want any of the log + // from the grpc library appearing on stdout. + gpr_set_log_function(BlackholeLogger); + + std::shared_ptr channel( + grpc::CreateChannel(absl::GetFlag(FLAGS_metrics_server_address), + grpc::InsecureChannelCredentials())); + + if (!PrintMetrics(MetricsService::NewStub(channel), + absl::GetFlag(FLAGS_total_only), + absl::GetFlag(FLAGS_deadline_secs))) { + return 1; + } + + return 0; +} diff --git a/test/cpp/interop/reconnect_interop_client.cc b/test/cpp/interop/reconnect_interop_client.cc new file mode 100644 index 00000000..9554c17b --- /dev/null +++ b/test/cpp/interop/reconnect_interop_client.cc @@ -0,0 +1,110 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/empty.pb.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/cpp/util/create_test_channel.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(int32_t, server_control_port, 0, "Server port for control rpcs."); +ABSL_FLAG(int32_t, server_retry_port, 0, + "Server port for testing reconnection."); +ABSL_FLAG(std::string, server_host, "localhost", "Server host to connect to"); +// TODO(Capstan): Consider using absl::Duration +ABSL_FLAG(int32_t, max_reconnect_backoff_ms, 0, + "Maximum backoff time, or 0 for default."); + +using grpc::CallCredentials; +using grpc::Channel; +using grpc::ChannelArguments; +using grpc::ClientContext; +using grpc::CreateTestChannel; +using grpc::Status; +using grpc::testing::Empty; +using grpc::testing::INSECURE; +using grpc::testing::ReconnectInfo; +using grpc::testing::ReconnectParams; +using grpc::testing::ReconnectService; +using grpc::testing::TLS; + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + GPR_ASSERT(absl::GetFlag(FLAGS_server_control_port)); + GPR_ASSERT(absl::GetFlag(FLAGS_server_retry_port)); + + std::ostringstream server_address; + server_address << absl::GetFlag(FLAGS_server_host) << ':' + << absl::GetFlag(FLAGS_server_control_port); + std::unique_ptr control_stub( + ReconnectService::NewStub( + CreateTestChannel(server_address.str(), INSECURE))); + ClientContext start_context; + ReconnectParams reconnect_params; + reconnect_params.set_max_reconnect_backoff_ms( + absl::GetFlag(FLAGS_max_reconnect_backoff_ms)); + Empty empty_response; + Status start_status = + control_stub->Start(&start_context, reconnect_params, &empty_response); + GPR_ASSERT(start_status.ok()); + + gpr_log(GPR_INFO, "Starting connections with retries."); + server_address.str(""); + server_address << absl::GetFlag(FLAGS_server_host) << ':' + << absl::GetFlag(FLAGS_server_retry_port); + ChannelArguments channel_args; + if (absl::GetFlag(FLAGS_max_reconnect_backoff_ms) > 0) { + channel_args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, + absl::GetFlag(FLAGS_max_reconnect_backoff_ms)); + } + std::shared_ptr retry_channel = + CreateTestChannel(server_address.str(), "foo.test.google.fr", TLS, false, + std::shared_ptr(), channel_args); + + // About 13 retries. + const int kDeadlineSeconds = 540; + // Use any rpc to test retry. 
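+ // The retry port accepts raw TCP connections and closes them immediately
+ // (see reconnect_interop_server.cc below), so this channel keeps
+ // reconnecting with exponential backoff: roughly 1s, 1.6s, 2.56s, ...,
+ // capped at 120s by default. About 13 such intervals add up to a bit over
+ // 530s, so the 540-second deadline above is expected to expire with
+ // DEADLINE_EXCEEDED before the Start RPC on this channel ever succeeds.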
+ std::unique_ptr retry_stub( + ReconnectService::NewStub(retry_channel)); + ClientContext retry_context; + retry_context.set_deadline(std::chrono::system_clock::now() + + std::chrono::seconds(kDeadlineSeconds)); + Status retry_status = + retry_stub->Start(&retry_context, reconnect_params, &empty_response); + GPR_ASSERT(retry_status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED); + gpr_log(GPR_INFO, "Done retrying, getting final data from server"); + + ClientContext stop_context; + ReconnectInfo response; + Status stop_status = control_stub->Stop(&stop_context, Empty(), &response); + GPR_ASSERT(stop_status.ok()); + GPR_ASSERT(response.passed() == true); + gpr_log(GPR_INFO, "Passed"); + return 0; +} diff --git a/test/cpp/interop/reconnect_interop_server.cc b/test/cpp/interop/reconnect_interop_server.cc new file mode 100644 index 00000000..6496f8f4 --- /dev/null +++ b/test/cpp/interop/reconnect_interop_server.cc @@ -0,0 +1,186 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Test description at doc/connection-backoff-interop-test-description.md + +#include + +#include +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/empty.pb.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/core/util/reconnect_server.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(int32_t, control_port, 0, "Server port for controlling the server."); +ABSL_FLAG(int32_t, retry_port, 0, + "Server port for raw tcp connections. 
All incoming " + "connections will be closed immediately."); + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using grpc::testing::Empty; +using grpc::testing::ReconnectInfo; +using grpc::testing::ReconnectParams; +using grpc::testing::ReconnectService; + +static bool got_sigint = false; + +class ReconnectServiceImpl : public ReconnectService::Service { + public: + explicit ReconnectServiceImpl(int retry_port) + : retry_port_(retry_port), + serving_(false), + server_started_(false), + shutdown_(false) { + reconnect_server_init(&tcp_server_); + } + + ~ReconnectServiceImpl() override { + if (server_started_) { + reconnect_server_destroy(&tcp_server_); + } + } + + void Poll(int seconds) { reconnect_server_poll(&tcp_server_, seconds); } + + Status Start(ServerContext* /*context*/, const ReconnectParams* request, + Empty* /*response*/) override { + bool start_server = true; + std::unique_lock lock(mu_); + while (serving_ && !shutdown_) { + cv_.wait(lock); + } + if (shutdown_) { + return Status(grpc::StatusCode::UNAVAILABLE, "shutting down"); + } + serving_ = true; + if (server_started_) { + start_server = false; + } else { + tcp_server_.max_reconnect_backoff_ms = + request->max_reconnect_backoff_ms(); + server_started_ = true; + } + lock.unlock(); + + if (start_server) { + reconnect_server_start(&tcp_server_, retry_port_); + } else { + reconnect_server_clear_timestamps(&tcp_server_); + } + return Status::OK; + } + + Status Stop(ServerContext* /*context*/, const Empty* /*request*/, + ReconnectInfo* response) override { + // extract timestamps and set response + Verify(response); + reconnect_server_clear_timestamps(&tcp_server_); + std::lock_guard lock(mu_); + serving_ = false; + cv_.notify_one(); + return Status::OK; + } + + void Verify(ReconnectInfo* response) { + double expected_backoff = 1000.0; + const double kTransmissionDelay = 100.0; + const double kBackoffMultiplier = 1.6; + const double kJitterFactor = 0.2; + const int kMaxBackoffMs = tcp_server_.max_reconnect_backoff_ms + ? tcp_server_.max_reconnect_backoff_ms + : 120 * 1000; + bool passed = true; + for (timestamp_list* cur = tcp_server_.head; cur && cur->next; + cur = cur->next) { + double backoff = gpr_time_to_millis( + gpr_time_sub(cur->next->timestamp, cur->timestamp)); + double min_backoff = expected_backoff * (1 - kJitterFactor); + double max_backoff = expected_backoff * (1 + kJitterFactor); + if (backoff < min_backoff - kTransmissionDelay || + backoff > max_backoff + kTransmissionDelay) { + passed = false; + } + response->add_backoff_ms(static_cast(backoff)); + expected_backoff *= kBackoffMultiplier; + expected_backoff = + expected_backoff > kMaxBackoffMs ? 
kMaxBackoffMs : expected_backoff; + } + response->set_passed(passed); + } + + void Shutdown() { + std::lock_guard lock(mu_); + shutdown_ = true; + cv_.notify_all(); + } + + private: + int retry_port_; + reconnect_server tcp_server_; + bool serving_; + bool server_started_; + bool shutdown_; + std::mutex mu_; + std::condition_variable cv_; +}; + +void RunServer() { + std::ostringstream server_address; + server_address << "0.0.0.0:" << absl::GetFlag(FLAGS_control_port); + ReconnectServiceImpl service(absl::GetFlag(FLAGS_retry_port)); + + ServerBuilder builder; + builder.RegisterService(&service); + builder.AddListeningPort(server_address.str(), + grpc::InsecureServerCredentials()); + std::unique_ptr server(builder.BuildAndStart()); + gpr_log(GPR_INFO, "Server listening on %s", server_address.str().c_str()); + while (!got_sigint) { + service.Poll(5); + } + service.Shutdown(); +} + +static void sigint_handler(int /*x*/) { got_sigint = true; } + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + signal(SIGINT, sigint_handler); + + GPR_ASSERT(absl::GetFlag(FLAGS_control_port) != 0); + GPR_ASSERT(absl::GetFlag(FLAGS_retry_port) != 0); + RunServer(); + + return 0; +} diff --git a/test/cpp/interop/server_helper.cc b/test/cpp/interop/server_helper.cc new file mode 100644 index 00000000..4c7e7b55 --- /dev/null +++ b/test/cpp/interop/server_helper.cc @@ -0,0 +1,84 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/interop/server_helper.h" + +#include + +#include "absl/flags/declare.h" +#include "absl/flags/flag.h" + +#include + +#include "src/core/lib/surface/call_test_only.h" +#include "src/core/lib/transport/byte_stream.h" +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_DECLARE_FLAG(bool, use_alts); +ABSL_DECLARE_FLAG(bool, use_tls); +ABSL_DECLARE_FLAG(std::string, custom_credentials_type); + +namespace grpc { +namespace testing { + +std::shared_ptr CreateInteropServerCredentials() { + if (!absl::GetFlag(FLAGS_custom_credentials_type).empty()) { + return GetCredentialsProvider()->GetServerCredentials( + absl::GetFlag(FLAGS_custom_credentials_type)); + } else if (absl::GetFlag(FLAGS_use_alts)) { + return GetCredentialsProvider()->GetServerCredentials(kAltsCredentialsType); + } else if (absl::GetFlag(FLAGS_use_tls)) { + return GetCredentialsProvider()->GetServerCredentials(kTlsCredentialsType); + } else { + return GetCredentialsProvider()->GetServerCredentials( + kInsecureCredentialsType); + } +} + +InteropServerContextInspector::InteropServerContextInspector( + const ::grpc::ServerContext& context) + : context_(context) {} + +grpc_compression_algorithm +InteropServerContextInspector::GetCallCompressionAlgorithm() const { + return grpc_call_test_only_get_compression_algorithm(context_.call_.call); +} + +uint32_t InteropServerContextInspector::GetEncodingsAcceptedByClient() const { + return grpc_call_test_only_get_encodings_accepted_by_peer( + context_.call_.call); +} + +bool InteropServerContextInspector::WasCompressed() const { + return (grpc_call_test_only_get_message_flags(context_.call_.call) & + GRPC_WRITE_INTERNAL_COMPRESS) || + (grpc_call_test_only_get_message_flags(context_.call_.call) & + GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED); +} + +std::shared_ptr +InteropServerContextInspector::GetAuthContext() const { + return context_.auth_context(); +} + +bool InteropServerContextInspector::IsCancelled() const { + return context_.IsCancelled(); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/interop/stress_interop_client.cc b/test/cpp/interop/stress_interop_client.cc new file mode 100644 index 00000000..f001f815 --- /dev/null +++ b/test/cpp/interop/stress_interop_client.cc @@ -0,0 +1,199 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *is % allowed in string + */ + +#include "test/cpp/interop/stress_interop_client.h" + +#include +#include +#include + +#include +#include + +#include "test/cpp/interop/interop_client.h" +#include "test/cpp/util/metrics_server.h" + +namespace grpc { +namespace testing { + +using std::pair; +using std::vector; + +WeightedRandomTestSelector::WeightedRandomTestSelector( + const vector>& tests) + : tests_(tests) { + total_weight_ = 0; + for (auto it = tests.begin(); it != tests.end(); it++) { + total_weight_ += it->second; + } +} + +// Returns a weighted-randomly selected test case based on the test weights +// passed in the constructror +TestCaseType WeightedRandomTestSelector::GetNextTest() const { + int random = 0; + TestCaseType selected_test = UNKNOWN_TEST; + + // Get a random number from [0 to the total_weight - 1] + random = rand() % total_weight_; + + int weight_sofar = 0; + for (auto it = tests_.begin(); it != tests_.end(); it++) { + weight_sofar += it->second; + if (random < weight_sofar) { + selected_test = it->first; + break; + } + } + + // It is a bug in the logic if no test is selected at this point + GPR_ASSERT(selected_test != UNKNOWN_TEST); + return selected_test; +} + +StressTestInteropClient::StressTestInteropClient( + int test_id, const std::string& server_address, + ChannelCreationFunc channel_creation_func, + const WeightedRandomTestSelector& test_selector, long test_duration_secs, + long sleep_duration_ms, bool do_not_abort_on_transient_failures) + : test_id_(test_id), + server_address_(server_address), + channel_creation_func_(std::move(channel_creation_func)), + interop_client_(new InteropClient(channel_creation_func_, false, + do_not_abort_on_transient_failures)), + test_selector_(test_selector), + test_duration_secs_(test_duration_secs), + sleep_duration_ms_(sleep_duration_ms) {} + +void StressTestInteropClient::MainLoop( + const std::shared_ptr& qps_gauge) { + gpr_log(GPR_INFO, "Running test %d. 
ServerAddr: %s", test_id_, + server_address_.c_str()); + + gpr_timespec test_end_time; + if (test_duration_secs_ < 0) { + test_end_time = gpr_inf_future(GPR_CLOCK_REALTIME); + } else { + test_end_time = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(test_duration_secs_, GPR_TIMESPAN)); + } + + qps_gauge->Reset(); + + while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), test_end_time) < 0) { + // Select the test case to execute based on the weights and execute it + TestCaseType test_case = test_selector_.GetNextTest(); + gpr_log(GPR_DEBUG, "%d - Executing the test case %d", test_id_, test_case); + RunTest(test_case); + + qps_gauge->Incr(); + + // Sleep between successive calls if needed + if (sleep_duration_ms_ > 0) { + gpr_timespec sleep_time = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(sleep_duration_ms_, GPR_TIMESPAN)); + gpr_sleep_until(sleep_time); + } + } +} + +bool StressTestInteropClient::RunTest(TestCaseType test_case) { + bool is_success = false; + switch (test_case) { + case EMPTY_UNARY: { + is_success = interop_client_->DoEmpty(); + break; + } + case LARGE_UNARY: { + is_success = interop_client_->DoLargeUnary(); + break; + } + case CLIENT_COMPRESSED_UNARY: { + is_success = interop_client_->DoClientCompressedUnary(); + break; + } + case CLIENT_COMPRESSED_STREAMING: { + is_success = interop_client_->DoClientCompressedStreaming(); + break; + } + case CLIENT_STREAMING: { + is_success = interop_client_->DoRequestStreaming(); + break; + } + case SERVER_STREAMING: { + is_success = interop_client_->DoResponseStreaming(); + break; + } + case SERVER_COMPRESSED_UNARY: { + is_success = interop_client_->DoServerCompressedUnary(); + break; + } + case SERVER_COMPRESSED_STREAMING: { + is_success = interop_client_->DoServerCompressedStreaming(); + break; + } + case SLOW_CONSUMER: { + is_success = interop_client_->DoResponseStreamingWithSlowConsumer(); + break; + } + case HALF_DUPLEX: { + is_success = interop_client_->DoHalfDuplex(); + break; + } + case PING_PONG: { + is_success = interop_client_->DoPingPong(); + break; + } + case CANCEL_AFTER_BEGIN: { + is_success = interop_client_->DoCancelAfterBegin(); + break; + } + case CANCEL_AFTER_FIRST_RESPONSE: { + is_success = interop_client_->DoCancelAfterFirstResponse(); + break; + } + case TIMEOUT_ON_SLEEPING_SERVER: { + is_success = interop_client_->DoTimeoutOnSleepingServer(); + break; + } + case EMPTY_STREAM: { + is_success = interop_client_->DoEmptyStream(); + break; + } + case STATUS_CODE_AND_MESSAGE: { + is_success = interop_client_->DoStatusWithMessage(); + break; + } + case CUSTOM_METADATA: { + is_success = interop_client_->DoCustomMetadata(); + break; + } + default: { + gpr_log(GPR_ERROR, "Invalid test case (%d)", test_case); + GPR_ASSERT(false); + break; + } + } + + return is_success; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/interop/stress_test.cc b/test/cpp/interop/stress_test.cc new file mode 100644 index 00000000..1faed828 --- /dev/null +++ b/test/cpp/interop/stress_test.cc @@ -0,0 +1,344 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *is % allowed in string + */ + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include + +#include "src/proto/grpc/testing/metrics.grpc.pb.h" +#include "src/proto/grpc/testing/metrics.pb.h" +#include "test/cpp/interop/interop_client.h" +#include "test/cpp/interop/stress_interop_client.h" +#include "test/cpp/util/create_test_channel.h" +#include "test/cpp/util/metrics_server.h" +#include "test/cpp/util/test_config.h" + +extern void gpr_default_log(gpr_log_func_args* args); + +ABSL_FLAG(int32_t, metrics_port, 8081, "The metrics server port."); + +// TODO(Capstan): Consider using absl::Duration +ABSL_FLAG(int32_t, sleep_duration_ms, 0, + "The duration (in millisec) between two" + " consecutive test calls (per server) issued by the server."); + +// TODO(Capstan): Consider using absl::Duration +ABSL_FLAG(int32_t, test_duration_secs, -1, + "The length of time (in seconds) to run" + " the test. Enter -1 if the test should run continuously until" + " forcefully terminated."); + +ABSL_FLAG(std::string, server_addresses, "localhost:8080", + "The list of server addresses. The format is: \n" + " \":,:...:\"\n" + " Note: can be servername or IP address."); + +ABSL_FLAG(int32_t, num_channels_per_server, 1, + "Number of channels for each server"); + +ABSL_FLAG(int32_t, num_stubs_per_channel, 1, + "Number of stubs per each channels to server. This number also " + "indicates the max number of parallel RPC calls on each channel " + "at any given time."); + +// TODO(sreek): Add more test cases here in future +ABSL_FLAG(std::string, test_cases, "", + "List of test cases to call along with the" + " relative weights in the following format:\n" + " \",...\"\n" + " The following testcases are currently supported:\n" + " empty_unary\n" + " large_unary\n" + " large_compressed_unary\n" + " client_streaming\n" + " server_streaming\n" + " server_compressed_streaming\n" + " slow_consumer\n" + " half_duplex\n" + " ping_pong\n" + " cancel_after_begin\n" + " cancel_after_first_response\n" + " timeout_on_sleeping_server\n" + " empty_stream\n" + " status_code_and_message\n" + " custom_metadata\n" + " Example: \"empty_unary:20,large_unary:10,empty_stream:70\"\n" + " The above will execute 'empty_unary', 20% of the time," + " 'large_unary', 10% of the time and 'empty_stream' the remaining" + " 70% of the time"); + +ABSL_FLAG(int32_t, log_level, GPR_LOG_SEVERITY_INFO, + "Severity level of messages that should be logged. Any messages " + "greater than or equal to the level set here will be logged. " + "The choices are: 0 (GPR_LOG_SEVERITY_DEBUG), 1 " + "(GPR_LOG_SEVERITY_INFO) and 2 (GPR_LOG_SEVERITY_ERROR)"); + +ABSL_FLAG(bool, do_not_abort_on_transient_failures, true, + "If set to 'true', abort() is not called in case of transient " + "failures like temporary connection failures."); + +// Options from client.cc (for compatibility with interop test). +// TODO(sreek): Consolidate overlapping options +ABSL_FLAG(bool, use_alts, false, + "Whether to use alts. 
Enable alts will disable tls."); +ABSL_FLAG(bool, use_tls, false, "Whether to use tls."); +ABSL_FLAG(bool, use_test_ca, false, "False to use SSL roots for google"); +ABSL_FLAG(std::string, server_host_override, "", + "Override the server host which is sent in HTTP header"); + +using grpc::testing::ALTS; +using grpc::testing::INSECURE; +using grpc::testing::kTestCaseList; +using grpc::testing::MetricsServiceImpl; +using grpc::testing::StressTestInteropClient; +using grpc::testing::TestCaseType; +using grpc::testing::TLS; +using grpc::testing::transport_security; +using grpc::testing::UNKNOWN_TEST; +using grpc::testing::WeightedRandomTestSelector; + +static int log_level = GPR_LOG_SEVERITY_DEBUG; + +// A simple wrapper to grp_default_log() function. This only logs messages at or +// above the current log level (set in 'log_level' variable) +void TestLogFunction(gpr_log_func_args* args) { + if (args->severity >= log_level) { + gpr_default_log(args); + } +} + +TestCaseType GetTestTypeFromName(const std::string& test_name) { + TestCaseType test_case = UNKNOWN_TEST; + + for (auto it = kTestCaseList.begin(); it != kTestCaseList.end(); it++) { + if (test_name == it->second) { + test_case = it->first; + break; + } + } + + return test_case; +} + +// Converts a string of comma delimited tokens to a vector of tokens +bool ParseCommaDelimitedString(const std::string& comma_delimited_str, + std::vector& tokens) { + size_t bpos = 0; + size_t epos = std::string::npos; + + while ((epos = comma_delimited_str.find(',', bpos)) != std::string::npos) { + tokens.emplace_back(comma_delimited_str.substr(bpos, epos - bpos)); + bpos = epos + 1; + } + + tokens.emplace_back(comma_delimited_str.substr(bpos)); // Last token + return true; +} + +// Input: Test case string ",...." +// Output: +// - Whether parsing was successful (return value) +// - Vector of (test_type_enum, weight) pairs returned via 'tests' parameter +bool ParseTestCasesString(const std::string& test_cases, + std::vector>& tests) { + bool is_success = true; + + std::vector tokens; + ParseCommaDelimitedString(test_cases, tokens); + + for (auto it = tokens.begin(); it != tokens.end(); it++) { + // Token is in the form : + size_t colon_pos = it->find(':'); + if (colon_pos == std::string::npos) { + gpr_log(GPR_ERROR, "Error in parsing test case string: %s", it->c_str()); + is_success = false; + break; + } + + std::string test_name = it->substr(0, colon_pos); + int weight = std::stoi(it->substr(colon_pos + 1)); + TestCaseType test_case = GetTestTypeFromName(test_name); + if (test_case == UNKNOWN_TEST) { + gpr_log(GPR_ERROR, "Unknown test case: %s", test_name.c_str()); + is_success = false; + break; + } + + tests.emplace_back(std::make_pair(test_case, weight)); + } + + return is_success; +} + +// For debugging purposes +void LogParameterInfo(const std::vector& addresses, + const std::vector>& tests) { + gpr_log(GPR_INFO, "server_addresses: %s", + absl::GetFlag(FLAGS_server_addresses).c_str()); + gpr_log(GPR_INFO, "test_cases : %s", absl::GetFlag(FLAGS_test_cases).c_str()); + gpr_log(GPR_INFO, "sleep_duration_ms: %d", + absl::GetFlag(FLAGS_sleep_duration_ms)); + gpr_log(GPR_INFO, "test_duration_secs: %d", + absl::GetFlag(FLAGS_test_duration_secs)); + gpr_log(GPR_INFO, "num_channels_per_server: %d", + absl::GetFlag(FLAGS_num_channels_per_server)); + gpr_log(GPR_INFO, "num_stubs_per_channel: %d", + absl::GetFlag(FLAGS_num_stubs_per_channel)); + gpr_log(GPR_INFO, "log_level: %d", absl::GetFlag(FLAGS_log_level)); + gpr_log(GPR_INFO, 
"do_not_abort_on_transient_failures: %s", + absl::GetFlag(FLAGS_do_not_abort_on_transient_failures) ? "true" + : "false"); + + int num = 0; + for (auto it = addresses.begin(); it != addresses.end(); it++) { + gpr_log(GPR_INFO, "%d:%s", ++num, it->c_str()); + } + + num = 0; + for (auto it = tests.begin(); it != tests.end(); it++) { + TestCaseType test_case = it->first; + int weight = it->second; + gpr_log(GPR_INFO, "%d. TestCaseType: %d, Weight: %d", ++num, test_case, + weight); + } +} + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + + if (absl::GetFlag(FLAGS_log_level) > GPR_LOG_SEVERITY_ERROR || + absl::GetFlag(FLAGS_log_level) < GPR_LOG_SEVERITY_DEBUG) { + gpr_log(GPR_ERROR, "log_level should be an integer between %d and %d", + GPR_LOG_SEVERITY_DEBUG, GPR_LOG_SEVERITY_ERROR); + return 1; + } + + // Change the default log function to TestLogFunction which respects the + // log_level setting. + log_level = absl::GetFlag(FLAGS_log_level); + gpr_set_log_function(TestLogFunction); + + srand(time(nullptr)); + + // Parse the server addresses + std::vector server_addresses; + ParseCommaDelimitedString(absl::GetFlag(FLAGS_server_addresses), + server_addresses); + + // Parse test cases and weights + if (absl::GetFlag(FLAGS_test_cases).length() == 0) { + gpr_log(GPR_ERROR, "No test cases supplied"); + return 1; + } + + std::vector> tests; + if (!ParseTestCasesString(absl::GetFlag(FLAGS_test_cases), tests)) { + gpr_log(GPR_ERROR, "Error in parsing test cases string %s ", + absl::GetFlag(FLAGS_test_cases).c_str()); + return 1; + } + + LogParameterInfo(server_addresses, tests); + + WeightedRandomTestSelector test_selector(tests); + MetricsServiceImpl metrics_service; + + gpr_log(GPR_INFO, "Starting test(s).."); + + std::vector test_threads; + std::vector> clients; + + // Create and start the test threads. + // Note that: + // - Each server can have multiple channels (as configured by + // FLAGS_num_channels_per_server). + // + // - Each channel can have multiple stubs (as configured by + // FLAGS_num_stubs_per_channel). This is to test calling multiple RPCs in + // parallel on the same channel. + int thread_idx = 0; + int server_idx = -1; + char buffer[256]; + transport_security security_type = + absl::GetFlag(FLAGS_use_alts) + ? ALTS + : (absl::GetFlag(FLAGS_use_tls) ? 
TLS : INSECURE); + for (auto it = server_addresses.begin(); it != server_addresses.end(); it++) { + ++server_idx; + // Create channel(s) for each server + for (int channel_idx = 0; + channel_idx < absl::GetFlag(FLAGS_num_channels_per_server); + channel_idx++) { + gpr_log(GPR_INFO, "Starting test with %s channel_idx=%d..", it->c_str(), + channel_idx); + grpc::testing::ChannelCreationFunc channel_creation_func = + std::bind(static_cast (*)( + const std::string&, const std::string&, + grpc::testing::transport_security, bool)>( + grpc::CreateTestChannel), + *it, absl::GetFlag(FLAGS_server_host_override), + security_type, !absl::GetFlag(FLAGS_use_test_ca)); + + // Create stub(s) for each channel + for (int stub_idx = 0; + stub_idx < absl::GetFlag(FLAGS_num_stubs_per_channel); stub_idx++) { + clients.emplace_back(new StressTestInteropClient( + ++thread_idx, *it, channel_creation_func, test_selector, + absl::GetFlag(FLAGS_test_duration_secs), + absl::GetFlag(FLAGS_sleep_duration_ms), + absl::GetFlag(FLAGS_do_not_abort_on_transient_failures))); + + bool is_already_created = false; + // QpsGauge name + std::snprintf(buffer, sizeof(buffer), + "/stress_test/server_%d/channel_%d/stub_%d/qps", + server_idx, channel_idx, stub_idx); + + test_threads.emplace_back(std::thread( + &StressTestInteropClient::MainLoop, clients.back().get(), + metrics_service.CreateQpsGauge(buffer, &is_already_created))); + + // The QpsGauge should not have been already created + GPR_ASSERT(!is_already_created); + } + } + } + + // Start metrics server before waiting for the stress test threads + std::unique_ptr metrics_server; + if (absl::GetFlag(FLAGS_metrics_port) > 0) { + metrics_server = + metrics_service.StartServer(absl::GetFlag(FLAGS_metrics_port)); + } + + // Wait for the stress test threads to complete + for (auto it = test_threads.begin(); it != test_threads.end(); it++) { + it->join(); + } + + return 0; +} diff --git a/test/cpp/interop/xds_interop_client.cc b/test/cpp/interop/xds_interop_client.cc new file mode 100644 index 00000000..c688a2e6 --- /dev/null +++ b/test/cpp/interop/xds_interop_client.cc @@ -0,0 +1,619 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/flags/flag.h" +#include "absl/strings/str_split.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/env.h" +#include "src/proto/grpc/testing/empty.pb.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(bool, fail_on_failed_rpc, false, + "Fail client if any RPCs fail after first successful RPC."); +ABSL_FLAG(int32_t, num_channels, 1, "Number of channels."); +ABSL_FLAG(bool, print_response, false, "Write RPC response to stdout."); +ABSL_FLAG(int32_t, qps, 1, "Qps per channel."); +// TODO(Capstan): Consider using absl::Duration +ABSL_FLAG(int32_t, rpc_timeout_sec, 30, "Per RPC timeout seconds."); +ABSL_FLAG(std::string, server, "localhost:50051", "Address of server."); +ABSL_FLAG(int32_t, stats_port, 50052, + "Port to expose peer distribution stats service."); +ABSL_FLAG(std::string, rpc, "UnaryCall", + "a comma separated list of rpc methods."); +ABSL_FLAG(std::string, metadata, "", "metadata to send with the RPC."); +ABSL_FLAG(std::string, expect_status, "OK", + "RPC status for the test RPC to be considered successful"); +ABSL_FLAG( + bool, secure_mode, false, + "If true, XdsCredentials are used, InsecureChannelCredentials otherwise"); + +using grpc::Channel; +using grpc::ClientAsyncResponseReader; +using grpc::ClientContext; +using grpc::CompletionQueue; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using grpc::testing::ClientConfigureRequest; +using grpc::testing::ClientConfigureRequest_RpcType_Name; +using grpc::testing::ClientConfigureResponse; +using grpc::testing::Empty; +using grpc::testing::LoadBalancerAccumulatedStatsRequest; +using grpc::testing::LoadBalancerAccumulatedStatsResponse; +using grpc::testing::LoadBalancerStatsRequest; +using grpc::testing::LoadBalancerStatsResponse; +using grpc::testing::LoadBalancerStatsService; +using grpc::testing::SimpleRequest; +using grpc::testing::SimpleResponse; +using grpc::testing::TestService; +using grpc::testing::XdsUpdateClientConfigureService; + +class XdsStatsWatcher; + +struct StatsWatchers { + // Unique ID for each outgoing RPC + int global_request_id = 0; + // Unique ID for each outgoing RPC by RPC method type + std::map global_request_id_by_type; + // Stores a set of watchers that should be notified upon outgoing RPC + // completion + std::set watchers; + // Global watcher for accumululated stats. + XdsStatsWatcher* global_watcher; + // Mutex for global_request_id and watchers + std::mutex mu; +}; +// Whether at least one RPC has succeeded, indicating xDS resolution completed. +std::atomic one_rpc_succeeded(false); +// RPC configuration detailing how RPC should be sent. +struct RpcConfig { + ClientConfigureRequest::RpcType type; + std::vector> metadata; + int timeout_sec = 0; +}; +struct RpcConfigurationsQueue { + // A queue of RPC configurations detailing how RPCs should be sent. 
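+  // Each entry is one batch of configurations, pushed either from the command
+  // line flags at startup or by a Configure() RPC; the test loop pops batches
+  // from the front and keeps sending the most recently popped batch until a
+  // newer one arrives.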
+ std::deque> rpc_configs_queue; + // Mutex for rpc_configs_queue + std::mutex mu_rpc_configs_queue; +}; +struct AsyncClientCall { + Empty empty_response; + SimpleResponse simple_response; + ClientContext context; + Status status; + int saved_request_id; + ClientConfigureRequest::RpcType rpc_type; + std::unique_ptr> empty_response_reader; + std::unique_ptr> + simple_response_reader; +}; + +/** Records the remote peer distribution for a given range of RPCs. */ +class XdsStatsWatcher { + public: + XdsStatsWatcher(int start_id, int end_id) + : start_id_(start_id), end_id_(end_id), rpcs_needed_(end_id - start_id) {} + + // Upon the completion of an RPC, we will look at the request_id, the + // rpc_type, and the peer the RPC was sent to in order to count + // this RPC into the right stats bin. + void RpcCompleted(AsyncClientCall* call, const std::string& peer) { + // We count RPCs for global watcher or if the request_id falls into the + // watcher's interested range of request ids. + if ((start_id_ == 0 && end_id_ == 0) || + (start_id_ <= call->saved_request_id && + call->saved_request_id < end_id_)) { + { + std::lock_guard lock(m_); + if (peer.empty()) { + no_remote_peer_++; + ++no_remote_peer_by_type_[call->rpc_type]; + } else { + // RPC is counted into both per-peer bin and per-method-per-peer bin. + rpcs_by_peer_[peer]++; + rpcs_by_type_[call->rpc_type][peer]++; + } + rpcs_needed_--; + // Report accumulated stats. + auto& stats_per_method = *accumulated_stats_.mutable_stats_per_method(); + auto& method_stat = + stats_per_method[ClientConfigureRequest_RpcType_Name( + call->rpc_type)]; + auto& result = *method_stat.mutable_result(); + grpc_status_code code = + static_cast(call->status.error_code()); + auto& num_rpcs = result[code]; + ++num_rpcs; + auto rpcs_started = method_stat.rpcs_started(); + method_stat.set_rpcs_started(++rpcs_started); + } + cv_.notify_one(); + } + } + + void WaitForRpcStatsResponse(LoadBalancerStatsResponse* response, + int timeout_sec) { + std::unique_lock lock(m_); + cv_.wait_for(lock, std::chrono::seconds(timeout_sec), + [this] { return rpcs_needed_ == 0; }); + response->mutable_rpcs_by_peer()->insert(rpcs_by_peer_.begin(), + rpcs_by_peer_.end()); + auto& response_rpcs_by_method = *response->mutable_rpcs_by_method(); + for (const auto& rpc_by_type : rpcs_by_type_) { + std::string method_name; + if (rpc_by_type.first == ClientConfigureRequest::EMPTY_CALL) { + method_name = "EmptyCall"; + } else if (rpc_by_type.first == ClientConfigureRequest::UNARY_CALL) { + method_name = "UnaryCall"; + } else { + GPR_ASSERT(0); + } + // TODO(@donnadionne): When the test runner changes to accept EMPTY_CALL + // and UNARY_CALL we will just use the name of the enum instead of the + // method_name variable. + auto& response_rpc_by_method = response_rpcs_by_method[method_name]; + auto& response_rpcs_by_peer = + *response_rpc_by_method.mutable_rpcs_by_peer(); + for (const auto& rpc_by_peer : rpc_by_type.second) { + auto& response_rpc_by_peer = response_rpcs_by_peer[rpc_by_peer.first]; + response_rpc_by_peer = rpc_by_peer.second; + } + } + response->set_num_failures(no_remote_peer_ + rpcs_needed_); + } + + void GetCurrentRpcStats(LoadBalancerAccumulatedStatsResponse* response, + StatsWatchers* stats_watchers) { + std::unique_lock lock(m_); + response->CopyFrom(accumulated_stats_); + // TODO(@donnadionne): delete deprecated stats below when the test is no + // longer using them. 
+ auto& response_rpcs_started_by_method = + *response->mutable_num_rpcs_started_by_method(); + auto& response_rpcs_succeeded_by_method = + *response->mutable_num_rpcs_succeeded_by_method(); + auto& response_rpcs_failed_by_method = + *response->mutable_num_rpcs_failed_by_method(); + for (const auto& rpc_by_type : rpcs_by_type_) { + auto total_succeeded = 0; + for (const auto& rpc_by_peer : rpc_by_type.second) { + total_succeeded += rpc_by_peer.second; + } + response_rpcs_succeeded_by_method[ClientConfigureRequest_RpcType_Name( + rpc_by_type.first)] = total_succeeded; + response_rpcs_started_by_method[ClientConfigureRequest_RpcType_Name( + rpc_by_type.first)] = + stats_watchers->global_request_id_by_type[rpc_by_type.first]; + response_rpcs_failed_by_method[ClientConfigureRequest_RpcType_Name( + rpc_by_type.first)] = no_remote_peer_by_type_[rpc_by_type.first]; + } + } + + private: + int start_id_; + int end_id_; + int rpcs_needed_; + int no_remote_peer_ = 0; + std::map no_remote_peer_by_type_; + // A map of stats keyed by peer name. + std::map rpcs_by_peer_; + // A two-level map of stats keyed at top level by RPC method and second level + // by peer name. + std::map> rpcs_by_type_; + // Storing accumulated stats in the response proto format. + LoadBalancerAccumulatedStatsResponse accumulated_stats_; + std::mutex m_; + std::condition_variable cv_; +}; + +class TestClient { + public: + TestClient(const std::shared_ptr& channel, + StatsWatchers* stats_watchers) + : stub_(TestService::NewStub(channel)), stats_watchers_(stats_watchers) {} + + void AsyncUnaryCall(const RpcConfig& config) { + SimpleResponse response; + int saved_request_id; + { + std::lock_guard lock(stats_watchers_->mu); + saved_request_id = ++stats_watchers_->global_request_id; + ++stats_watchers_ + ->global_request_id_by_type[ClientConfigureRequest::UNARY_CALL]; + } + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + + std::chrono::seconds(config.timeout_sec != 0 + ? config.timeout_sec + : absl::GetFlag(FLAGS_rpc_timeout_sec)); + AsyncClientCall* call = new AsyncClientCall; + for (const auto& data : config.metadata) { + call->context.AddMetadata(data.first, data.second); + // TODO(@donnadionne): move deadline to separate proto. + if (data.first == "rpc-behavior" && data.second == "keep-open") { + deadline = + std::chrono::system_clock::now() + std::chrono::seconds(INT_MAX); + } + } + call->context.set_deadline(deadline); + call->saved_request_id = saved_request_id; + call->rpc_type = ClientConfigureRequest::UNARY_CALL; + call->simple_response_reader = stub_->PrepareAsyncUnaryCall( + &call->context, SimpleRequest::default_instance(), &cq_); + call->simple_response_reader->StartCall(); + call->simple_response_reader->Finish(&call->simple_response, &call->status, + call); + } + + void AsyncEmptyCall(const RpcConfig& config) { + Empty response; + int saved_request_id; + { + std::lock_guard lock(stats_watchers_->mu); + saved_request_id = ++stats_watchers_->global_request_id; + ++stats_watchers_ + ->global_request_id_by_type[ClientConfigureRequest::EMPTY_CALL]; + } + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + + std::chrono::seconds(config.timeout_sec != 0 + ? config.timeout_sec + : absl::GetFlag(FLAGS_rpc_timeout_sec)); + AsyncClientCall* call = new AsyncClientCall; + for (const auto& data : config.metadata) { + call->context.AddMetadata(data.first, data.second); + // TODO(@donnadionne): move deadline to separate proto. 
+ if (data.first == "rpc-behavior" && data.second == "keep-open") { + deadline = + std::chrono::system_clock::now() + std::chrono::seconds(INT_MAX); + } + } + call->context.set_deadline(deadline); + call->saved_request_id = saved_request_id; + call->rpc_type = ClientConfigureRequest::EMPTY_CALL; + call->empty_response_reader = stub_->PrepareAsyncEmptyCall( + &call->context, Empty::default_instance(), &cq_); + call->empty_response_reader->StartCall(); + call->empty_response_reader->Finish(&call->empty_response, &call->status, + call); + } + + void AsyncCompleteRpc() { + void* got_tag; + bool ok = false; + while (cq_.Next(&got_tag, &ok)) { + AsyncClientCall* call = static_cast(got_tag); + GPR_ASSERT(ok); + { + std::lock_guard lock(stats_watchers_->mu); + auto server_initial_metadata = call->context.GetServerInitialMetadata(); + auto metadata_hostname = + call->context.GetServerInitialMetadata().find("hostname"); + std::string hostname = + metadata_hostname != call->context.GetServerInitialMetadata().end() + ? std::string(metadata_hostname->second.data(), + metadata_hostname->second.length()) + : call->simple_response.hostname(); + for (auto watcher : stats_watchers_->watchers) { + watcher->RpcCompleted(call, hostname); + } + } + + if (!RpcStatusCheckSuccess(call)) { + if (absl::GetFlag(FLAGS_print_response) || + absl::GetFlag(FLAGS_fail_on_failed_rpc)) { + std::cout << "RPC failed: " << call->status.error_code() << ": " + << call->status.error_message() << std::endl; + } + if (absl::GetFlag(FLAGS_fail_on_failed_rpc) && + one_rpc_succeeded.load()) { + abort(); + } + } else { + if (absl::GetFlag(FLAGS_print_response)) { + auto metadata_hostname = + call->context.GetServerInitialMetadata().find("hostname"); + std::string hostname = + metadata_hostname != + call->context.GetServerInitialMetadata().end() + ? std::string(metadata_hostname->second.data(), + metadata_hostname->second.length()) + : call->simple_response.hostname(); + std::cout << "Greeting: Hello world, this is " << hostname + << ", from " << call->context.peer() << std::endl; + } + one_rpc_succeeded = true; + } + + delete call; + } + } + + private: + static bool RpcStatusCheckSuccess(AsyncClientCall* call) { + // Determine RPC success based on expected status. 
+ grpc_status_code code; + GPR_ASSERT(grpc_status_code_from_string( + absl::GetFlag(FLAGS_expect_status).c_str(), &code)); + return code == static_cast(call->status.error_code()); + } + + std::unique_ptr stub_; + StatsWatchers* stats_watchers_; + CompletionQueue cq_; +}; + +class LoadBalancerStatsServiceImpl : public LoadBalancerStatsService::Service { + public: + explicit LoadBalancerStatsServiceImpl(StatsWatchers* stats_watchers) + : stats_watchers_(stats_watchers) {} + + Status GetClientStats(ServerContext* /*context*/, + const LoadBalancerStatsRequest* request, + LoadBalancerStatsResponse* response) override { + int start_id; + int end_id; + XdsStatsWatcher* watcher; + { + std::lock_guard lock(stats_watchers_->mu); + start_id = stats_watchers_->global_request_id + 1; + end_id = start_id + request->num_rpcs(); + watcher = new XdsStatsWatcher(start_id, end_id); + stats_watchers_->watchers.insert(watcher); + } + watcher->WaitForRpcStatsResponse(response, request->timeout_sec()); + { + std::lock_guard lock(stats_watchers_->mu); + stats_watchers_->watchers.erase(watcher); + } + delete watcher; + return Status::OK; + } + + Status GetClientAccumulatedStats( + ServerContext* /*context*/, + const LoadBalancerAccumulatedStatsRequest* /*request*/, + LoadBalancerAccumulatedStatsResponse* response) override { + std::lock_guard lock(stats_watchers_->mu); + stats_watchers_->global_watcher->GetCurrentRpcStats(response, + stats_watchers_); + return Status::OK; + } + + private: + StatsWatchers* stats_watchers_; +}; + +class XdsUpdateClientConfigureServiceImpl + : public XdsUpdateClientConfigureService::Service { + public: + explicit XdsUpdateClientConfigureServiceImpl( + RpcConfigurationsQueue* rpc_configs_queue) + : rpc_configs_queue_(rpc_configs_queue) {} + + Status Configure(ServerContext* /*context*/, + const ClientConfigureRequest* request, + ClientConfigureResponse* /*response*/) override { + std::map>> + metadata_map; + for (const auto& data : request->metadata()) { + metadata_map[data.type()].push_back({data.key(), data.value()}); + } + std::vector configs; + for (const auto& rpc : request->types()) { + RpcConfig config; + config.timeout_sec = request->timeout_sec(); + config.type = static_cast(rpc); + auto metadata_iter = metadata_map.find(rpc); + if (metadata_iter != metadata_map.end()) { + config.metadata = metadata_iter->second; + } + configs.push_back(std::move(config)); + } + { + std::lock_guard lock( + rpc_configs_queue_->mu_rpc_configs_queue); + rpc_configs_queue_->rpc_configs_queue.emplace_back(std::move(configs)); + } + return Status::OK; + } + + private: + RpcConfigurationsQueue* rpc_configs_queue_; +}; + +void RunTestLoop(std::chrono::duration duration_per_query, + StatsWatchers* stats_watchers, + RpcConfigurationsQueue* rpc_configs_queue) { + grpc::ChannelArguments channel_args; + channel_args.SetInt(GRPC_ARG_ENABLE_RETRIES, 1); + TestClient client( + grpc::CreateCustomChannel(absl::GetFlag(FLAGS_server), + absl::GetFlag(FLAGS_secure_mode) + ? 
grpc::experimental::XdsCredentials( + grpc::InsecureChannelCredentials()) + : grpc::InsecureChannelCredentials(), + channel_args), + stats_watchers); + std::chrono::time_point start = + std::chrono::system_clock::now(); + std::chrono::duration elapsed; + + std::thread thread = std::thread(&TestClient::AsyncCompleteRpc, &client); + + std::vector configs; + while (true) { + { + std::lock_guard lockk( + rpc_configs_queue->mu_rpc_configs_queue); + if (!rpc_configs_queue->rpc_configs_queue.empty()) { + configs = std::move(rpc_configs_queue->rpc_configs_queue.front()); + rpc_configs_queue->rpc_configs_queue.pop_front(); + } + } + + elapsed = std::chrono::system_clock::now() - start; + if (elapsed > duration_per_query) { + start = std::chrono::system_clock::now(); + for (const auto& config : configs) { + if (config.type == ClientConfigureRequest::EMPTY_CALL) { + client.AsyncEmptyCall(config); + } else if (config.type == ClientConfigureRequest::UNARY_CALL) { + client.AsyncUnaryCall(config); + } else { + GPR_ASSERT(0); + } + } + } + } + GPR_UNREACHABLE_CODE(thread.join()); +} + +void RunServer(const int port, StatsWatchers* stats_watchers, + RpcConfigurationsQueue* rpc_configs_queue) { + GPR_ASSERT(port != 0); + std::ostringstream server_address; + server_address << "0.0.0.0:" << port; + + LoadBalancerStatsServiceImpl stats_service(stats_watchers); + XdsUpdateClientConfigureServiceImpl client_config_service(rpc_configs_queue); + + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + builder.RegisterService(&stats_service); + builder.RegisterService(&client_config_service); + grpc::AddAdminServices(&builder); + builder.AddListeningPort(server_address.str(), + grpc::InsecureServerCredentials()); + std::unique_ptr server(builder.BuildAndStart()); + gpr_log(GPR_DEBUG, "Server listening on %s", server_address.str().c_str()); + + server->Wait(); +} + +void BuildRpcConfigsFromFlags(RpcConfigurationsQueue* rpc_configs_queue) { + // Store Metadata like + // "EmptyCall:key1:value1,UnaryCall:key1:value1,UnaryCall:key2:value2" into a + // map where the key is the RPC method and value is a vector of key:value + // pairs. 
{EmptyCall, [{key1,value1}], + // UnaryCall, [{key1,value1}, {key2,value2}]} + std::vector rpc_metadata = + absl::StrSplit(absl::GetFlag(FLAGS_metadata), ',', absl::SkipEmpty()); + std::map>> metadata_map; + for (auto& data : rpc_metadata) { + std::vector metadata = + absl::StrSplit(data, ':', absl::SkipEmpty()); + GPR_ASSERT(metadata.size() == 3); + if (metadata[0] == "EmptyCall") { + metadata_map[ClientConfigureRequest::EMPTY_CALL].push_back( + {metadata[1], metadata[2]}); + } else if (metadata[0] == "UnaryCall") { + metadata_map[ClientConfigureRequest::UNARY_CALL].push_back( + {metadata[1], metadata[2]}); + } else { + GPR_ASSERT(0); + } + } + std::vector configs; + std::vector rpc_methods = + absl::StrSplit(absl::GetFlag(FLAGS_rpc), ',', absl::SkipEmpty()); + for (const std::string& rpc_method : rpc_methods) { + RpcConfig config; + if (rpc_method == "EmptyCall") { + config.type = ClientConfigureRequest::EMPTY_CALL; + } else if (rpc_method == "UnaryCall") { + config.type = ClientConfigureRequest::UNARY_CALL; + } else { + GPR_ASSERT(0); + } + auto metadata_iter = metadata_map.find(config.type); + if (metadata_iter != metadata_map.end()) { + config.metadata = metadata_iter->second; + } + configs.push_back(std::move(config)); + } + { + std::lock_guard lock(rpc_configs_queue->mu_rpc_configs_queue); + rpc_configs_queue->rpc_configs_queue.emplace_back(std::move(configs)); + } +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + // Validate the expect_status flag. + grpc_status_code code; + GPR_ASSERT(grpc_status_code_from_string( + absl::GetFlag(FLAGS_expect_status).c_str(), &code)); + StatsWatchers stats_watchers; + RpcConfigurationsQueue rpc_config_queue; + + { + std::lock_guard lock(stats_watchers.mu); + stats_watchers.global_watcher = new XdsStatsWatcher(0, 0); + stats_watchers.watchers.insert(stats_watchers.global_watcher); + } + + BuildRpcConfigsFromFlags(&rpc_config_queue); + + std::chrono::duration duration_per_query = + std::chrono::nanoseconds(std::chrono::seconds(1)) / + absl::GetFlag(FLAGS_qps); + + std::vector test_threads; + test_threads.reserve(absl::GetFlag(FLAGS_num_channels)); + for (int i = 0; i < absl::GetFlag(FLAGS_num_channels); i++) { + test_threads.emplace_back(std::thread(&RunTestLoop, duration_per_query, + &stats_watchers, &rpc_config_queue)); + } + + RunServer(absl::GetFlag(FLAGS_stats_port), &stats_watchers, + &rpc_config_queue); + + for (auto it = test_threads.begin(); it != test_threads.end(); it++) { + it->join(); + } + + return 0; +} diff --git a/test/cpp/interop/xds_interop_server.cc b/test/cpp/interop/xds_interop_server.cc new file mode 100644 index 00000000..7b6f5a6e --- /dev/null +++ b/test/cpp/interop/xds_interop_server.cc @@ -0,0 +1,183 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "absl/flags/flag.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/gethostname.h" +#include "src/core/lib/transport/byte_stream.h" +#include "src/proto/grpc/testing/empty.pb.h" +#include "src/proto/grpc/testing/messages.pb.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_health_check_service_impl.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(int32_t, port, 8080, "Server port for service."); +ABSL_FLAG(int32_t, maintenance_port, 8081, + "Server port for maintenance if --security is \"secure\"."); +ABSL_FLAG(std::string, server_id, "cpp_server", + "Server ID to include in responses."); +ABSL_FLAG(bool, secure_mode, false, + "If true, XdsServerCredentials are used, InsecureServerCredentials " + "otherwise"); + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using grpc::XdsServerBuilder; +using grpc::testing::Empty; +using grpc::testing::HealthCheckServiceImpl; +using grpc::testing::SimpleRequest; +using grpc::testing::SimpleResponse; +using grpc::testing::TestService; +using grpc::testing::XdsUpdateHealthService; + +class TestServiceImpl : public TestService::Service { + public: + explicit TestServiceImpl(const std::string& hostname) : hostname_(hostname) {} + + Status UnaryCall(ServerContext* context, const SimpleRequest* /*request*/, + SimpleResponse* response) override { + response->set_server_id(absl::GetFlag(FLAGS_server_id)); + response->set_hostname(hostname_); + context->AddInitialMetadata("hostname", hostname_); + return Status::OK; + } + + Status EmptyCall(ServerContext* context, const Empty* /*request*/, + Empty* /*response*/) override { + context->AddInitialMetadata("hostname", hostname_); + return Status::OK; + } + + private: + std::string hostname_; +}; + +class XdsUpdateHealthServiceImpl : public XdsUpdateHealthService::Service { + public: + explicit XdsUpdateHealthServiceImpl( + HealthCheckServiceImpl* health_check_service) + : health_check_service_(health_check_service) {} + + Status SetServing(ServerContext* /* context */, const Empty* /* request */, + Empty* /* response */) override { + health_check_service_->SetAll( + grpc::health::v1::HealthCheckResponse::SERVING); + return Status::OK; + } + + Status SetNotServing(ServerContext* /* context */, const Empty* /* request */, + Empty* /* response */) override { + health_check_service_->SetAll( + grpc::health::v1::HealthCheckResponse::NOT_SERVING); + return Status::OK; + } + + private: + HealthCheckServiceImpl* const health_check_service_; +}; + +void RunServer(bool secure_mode, const int port, const int maintenance_port, + const std::string& hostname) { + std::unique_ptr xds_enabled_server; + std::unique_ptr server; + TestServiceImpl service(hostname); + HealthCheckServiceImpl health_check_service; + health_check_service.SetStatus( + "", grpc::health::v1::HealthCheckResponse::SERVING); + health_check_service.SetStatus( + "grpc.testing.TestService", + grpc::health::v1::HealthCheckResponse::SERVING); + health_check_service.SetStatus( + "grpc.testing.XdsUpdateHealthService", + grpc::health::v1::HealthCheckResponse::SERVING); + XdsUpdateHealthServiceImpl update_health_service(&health_check_service); + + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + 
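+  // In secure mode two servers are built: an XdsServerBuilder-based server for
+  // the test service on --port, plus a maintenance server on
+  // --maintenance_port (insecure credentials) carrying the health-check,
+  // admin and reflection services. In insecure mode all services share the
+  // single server built below.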
ServerBuilder builder; + if (secure_mode) { + grpc::XdsServerBuilder xds_builder; + xds_builder.RegisterService(&service); + xds_builder.AddListeningPort( + absl::StrCat("0.0.0.0:", port), + grpc::XdsServerCredentials(grpc::InsecureServerCredentials())); + xds_enabled_server = xds_builder.BuildAndStart(); + gpr_log(GPR_INFO, "Server starting on 0.0.0.0:%d", port); + builder.RegisterService(&health_check_service); + builder.RegisterService(&update_health_service); + grpc::AddAdminServices(&builder); + builder.AddListeningPort(absl::StrCat("0.0.0.0:", maintenance_port), + grpc::InsecureServerCredentials()); + server = builder.BuildAndStart(); + gpr_log(GPR_INFO, "Maintenance server listening on 0.0.0.0:%d", + maintenance_port); + } else { + builder.RegisterService(&service); + builder.RegisterService(&health_check_service); + builder.RegisterService(&update_health_service); + grpc::AddAdminServices(&builder); + builder.AddListeningPort(absl::StrCat("0.0.0.0:", port), + grpc::InsecureServerCredentials()); + server = builder.BuildAndStart(); + gpr_log(GPR_INFO, "Server listening on 0.0.0.0:%d", port); + } + + server->Wait(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + char* hostname = grpc_gethostname(); + if (hostname == nullptr) { + std::cout << "Failed to get hostname, terminating" << std::endl; + return 1; + } + int port = absl::GetFlag(FLAGS_port); + if (port == 0) { + std::cout << "Invalid port, terminating" << std::endl; + return 1; + } + int maintenance_port = absl::GetFlag(FLAGS_maintenance_port); + if (maintenance_port == 0) { + std::cout << "Invalid maintenance port, terminating" << std::endl; + return 1; + } + grpc::EnableDefaultHealthCheckService(false); + RunServer(absl::GetFlag(FLAGS_secure_mode), port, maintenance_port, hostname); + + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_alarm.cc b/test/cpp/microbenchmarks/bm_alarm.cc new file mode 100644 index 00000000..83899d1a --- /dev/null +++ b/test/cpp/microbenchmarks/bm_alarm.cc @@ -0,0 +1,66 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This benchmark exists to ensure that immediately-firing alarms are fast */ + +#include + +#include +#include +#include +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +static void BM_Alarm_Tag_Immediate(benchmark::State& state) { + TrackCounters track_counters; + CompletionQueue cq; + Alarm alarm; + void* output_tag; + bool ok; + auto deadline = grpc_timeout_seconds_to_deadline(0); + for (auto _ : state) { + alarm.Set(&cq, deadline, nullptr); + cq.Next(&output_tag, &ok); + } + track_counters.Finish(state); +} +BENCHMARK(BM_Alarm_Tag_Immediate); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. 
This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_arena.cc b/test/cpp/microbenchmarks/bm_arena.cc new file mode 100644 index 00000000..aba5ceb6 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_arena.cc @@ -0,0 +1,76 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Benchmark arenas */ + +#include + +#include "src/core/lib/gprpp/arena.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +using grpc_core::Arena; + +static void BM_Arena_NoOp(benchmark::State& state) { + for (auto _ : state) { + Arena::Create(state.range(0))->Destroy(); + } +} +BENCHMARK(BM_Arena_NoOp)->Range(1, 1024 * 1024); + +static void BM_Arena_ManyAlloc(benchmark::State& state) { + Arena* a = Arena::Create(state.range(0)); + const size_t realloc_after = + 1024 * 1024 * 1024 / ((state.range(1) + 15) & 0xffffff0u); + while (state.KeepRunning()) { + a->Alloc(state.range(1)); + // periodically recreate arena to avoid OOM + if (state.iterations() % realloc_after == 0) { + a->Destroy(); + a = Arena::Create(state.range(0)); + } + } + a->Destroy(); +} +BENCHMARK(BM_Arena_ManyAlloc)->Ranges({{1, 1024 * 1024}, {1, 32 * 1024}}); + +static void BM_Arena_Batch(benchmark::State& state) { + for (auto _ : state) { + Arena* a = Arena::Create(state.range(0)); + for (int i = 0; i < state.range(1); i++) { + a->Alloc(state.range(2)); + } + a->Destroy(); + } +} +BENCHMARK(BM_Arena_Batch)->Ranges({{1, 64 * 1024}, {1, 64}, {1, 1024}}); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_byte_buffer.cc b/test/cpp/microbenchmarks/bm_byte_buffer.cc new file mode 100644 index 00000000..78d87c68 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_byte_buffer.cc @@ -0,0 +1,135 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This benchmark exists to show that byte-buffer copy is size-independent */ + +#include + +#include + +#include +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +static void BM_ByteBuffer_Copy(benchmark::State& state) { + int num_slices = state.range(0); + size_t slice_size = state.range(1); + std::vector slices; + while (num_slices > 0) { + num_slices--; + std::unique_ptr buf(new char[slice_size]); + memset(buf.get(), 0, slice_size); + slices.emplace_back(buf.get(), slice_size); + } + grpc::ByteBuffer bb(slices.data(), num_slices); + for (auto _ : state) { + grpc::ByteBuffer cc(bb); + } +} +BENCHMARK(BM_ByteBuffer_Copy)->Ranges({{1, 64}, {1, 1024 * 1024}}); + +static void BM_ByteBufferReader_Next(benchmark::State& state) { + const int num_slices = state.range(0); + constexpr size_t kSliceSize = 16; + std::vector slices; + for (int i = 0; i < num_slices; ++i) { + std::unique_ptr buf(new char[kSliceSize]); + slices.emplace_back(g_core_codegen_interface->grpc_slice_from_copied_buffer( + buf.get(), kSliceSize)); + } + grpc_byte_buffer* bb = g_core_codegen_interface->grpc_raw_byte_buffer_create( + slices.data(), num_slices); + grpc_byte_buffer_reader reader; + GPR_ASSERT( + g_core_codegen_interface->grpc_byte_buffer_reader_init(&reader, bb)); + for (auto _ : state) { + grpc_slice* slice; + if (GPR_UNLIKELY(!g_core_codegen_interface->grpc_byte_buffer_reader_peek( + &reader, &slice))) { + g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader); + GPR_ASSERT( + g_core_codegen_interface->grpc_byte_buffer_reader_init(&reader, bb)); + continue; + } + } + + g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader); + g_core_codegen_interface->grpc_byte_buffer_destroy(bb); + for (auto& slice : slices) { + g_core_codegen_interface->grpc_slice_unref(slice); + } +} +BENCHMARK(BM_ByteBufferReader_Next)->Ranges({{64 * 1024, 1024 * 1024}}); + +static void BM_ByteBufferReader_Peek(benchmark::State& state) { + const int num_slices = state.range(0); + constexpr size_t kSliceSize = 16; + std::vector slices; + for (int i = 0; i < num_slices; ++i) { + std::unique_ptr buf(new char[kSliceSize]); + slices.emplace_back(g_core_codegen_interface->grpc_slice_from_copied_buffer( + buf.get(), kSliceSize)); + } + grpc_byte_buffer* bb = g_core_codegen_interface->grpc_raw_byte_buffer_create( + slices.data(), num_slices); + grpc_byte_buffer_reader reader; + GPR_ASSERT( + g_core_codegen_interface->grpc_byte_buffer_reader_init(&reader, bb)); + for (auto _ : state) { + grpc_slice* slice; + if (GPR_UNLIKELY(!g_core_codegen_interface->grpc_byte_buffer_reader_peek( + &reader, &slice))) { + g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader); + GPR_ASSERT( + g_core_codegen_interface->grpc_byte_buffer_reader_init(&reader, bb)); + continue; + } + } + + g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader); + g_core_codegen_interface->grpc_byte_buffer_destroy(bb); + for (auto& slice : slices) { + 
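    // The byte buffer was destroyed just above; this loop releases the
    // references still held locally on the slices it was built from.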
g_core_codegen_interface->grpc_slice_unref(slice); + } +} +BENCHMARK(BM_ByteBufferReader_Peek)->Ranges({{64 * 1024, 1024 * 1024}}); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_call_create.cc b/test/cpp/microbenchmarks/bm_call_create.cc new file mode 100644 index 00000000..4c756c13 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_call_create.cc @@ -0,0 +1,843 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This benchmark exists to ensure that the benchmark integration is + * working */ + +#include + +#include + +#include + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/deadline/deadline_filter.h" +#include "src/core/ext/filters/http/client/http_client_filter.h" +#include "src/core/ext/filters/http/message_compress/message_compress_filter.h" +#include "src/core/ext/filters/http/server/http_server_filter.h" +#include "src/core/ext/filters/message_size/message_size_filter.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/iomgr/call_combiner.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/transport_impl.h" +#include "src/cpp/client/create_channel_internal.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +void BM_Zalloc(benchmark::State& state) { + // speed of light for call creation is zalloc, so benchmark a few interesting + // sizes + TrackCounters track_counters; + size_t sz = state.range(0); + for (auto _ : state) { + gpr_free(gpr_zalloc(sz)); + } + track_counters.Finish(state); +} +BENCHMARK(BM_Zalloc) + ->Arg(64) + ->Arg(128) + ->Arg(256) + ->Arg(512) + ->Arg(1024) + ->Arg(1536) + ->Arg(2048) + ->Arg(3072) + ->Arg(4096) + ->Arg(5120) + ->Arg(6144) + ->Arg(7168); + +//////////////////////////////////////////////////////////////////////////////// +// Benchmarks creating full stacks + +class BaseChannelFixture { + public: + explicit BaseChannelFixture(grpc_channel* channel) : channel_(channel) {} + ~BaseChannelFixture() { grpc_channel_destroy(channel_); } + + grpc_channel* channel() const { return channel_; } + + private: + grpc_channel* const 
channel_; +}; + +class InsecureChannel : public BaseChannelFixture { + public: + InsecureChannel() + : BaseChannelFixture( + grpc_insecure_channel_create("localhost:1234", nullptr, nullptr)) {} +}; + +class LameChannel : public BaseChannelFixture { + public: + LameChannel() + : BaseChannelFixture(grpc_lame_client_channel_create( + "localhost:1234", GRPC_STATUS_UNAUTHENTICATED, "blah")) {} +}; + +template +static void BM_CallCreateDestroy(benchmark::State& state) { + TrackCounters track_counters; + Fixture fixture; + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_MONOTONIC); + void* method_hdl = grpc_channel_register_call(fixture.channel(), "/foo/bar", + nullptr, nullptr); + for (auto _ : state) { + grpc_call_unref(grpc_channel_create_registered_call( + fixture.channel(), nullptr, GRPC_PROPAGATE_DEFAULTS, cq, method_hdl, + deadline, nullptr)); + } + grpc_completion_queue_destroy(cq); + track_counters.Finish(state); +} + +BENCHMARK_TEMPLATE(BM_CallCreateDestroy, InsecureChannel); +BENCHMARK_TEMPLATE(BM_CallCreateDestroy, LameChannel); + +//////////////////////////////////////////////////////////////////////////////// +// Benchmarks isolating individual filters + +static void* tag(int i) { + return reinterpret_cast(static_cast(i)); +} + +static void BM_LameChannelCallCreateCpp(benchmark::State& state) { + TrackCounters track_counters; + auto stub = + grpc::testing::EchoTestService::NewStub(grpc::CreateChannelInternal( + "", + grpc_lame_client_channel_create("localhost:1234", + GRPC_STATUS_UNAUTHENTICATED, "blah"), + std::vector>())); + grpc::CompletionQueue cq; + grpc::testing::EchoRequest send_request; + grpc::testing::EchoResponse recv_response; + grpc::Status recv_status; + for (auto _ : state) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + grpc::ClientContext cli_ctx; + auto reader = stub->AsyncEcho(&cli_ctx, send_request, &cq); + reader->Finish(&recv_response, &recv_status, tag(0)); + void* t; + bool ok; + GPR_ASSERT(cq.Next(&t, &ok)); + GPR_ASSERT(ok); + } + track_counters.Finish(state); +} +BENCHMARK(BM_LameChannelCallCreateCpp); + +static void do_nothing(void* /*ignored*/) {} + +static void BM_LameChannelCallCreateCore(benchmark::State& state) { + TrackCounters track_counters; + + grpc_channel* channel; + grpc_completion_queue* cq; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_slice details; + grpc::testing::EchoRequest send_request; + grpc_slice send_request_slice = + grpc_slice_new(&send_request, sizeof(send_request), do_nothing); + + channel = grpc_lame_client_channel_create( + "localhost:1234", GRPC_STATUS_UNAUTHENTICATED, "blah"); + cq = grpc_completion_queue_create_for_next(nullptr); + void* rc = grpc_channel_register_call( + channel, "/grpc.testing.EchoTestService/Echo", nullptr, nullptr); + for (auto _ : state) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + grpc_call* call = grpc_channel_create_registered_call( + channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, rc, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_byte_buffer* request_payload_send = + grpc_raw_byte_buffer_create(&send_request_slice, 1); + + // Fill in call ops + grpc_op ops[6]; + memset(ops, 0, sizeof(ops)); + grpc_op* op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + 
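    // This batch bundles the whole unary exchange: send initial metadata,
    // send the request message, half-close, then receive initial metadata,
    // the response message, and the final status. All six ops are submitted
    // in one grpc_call_start_batch() and retired by a single
    // completion-queue event below.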
op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload_send; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op++; + + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call, ops, + (size_t)(op - ops), + (void*)1, nullptr)); + grpc_event ev = grpc_completion_queue_next( + cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type != GRPC_QUEUE_SHUTDOWN); + GPR_ASSERT(ev.success != 0); + grpc_call_unref(call); + grpc_byte_buffer_destroy(request_payload_send); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + } + grpc_channel_destroy(channel); + grpc_completion_queue_destroy(cq); + grpc_slice_unref(send_request_slice); + track_counters.Finish(state); +} +BENCHMARK(BM_LameChannelCallCreateCore); + +static void BM_LameChannelCallCreateCoreSeparateBatch(benchmark::State& state) { + TrackCounters track_counters; + + grpc_channel* channel; + grpc_completion_queue* cq; + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_byte_buffer* response_payload_recv = nullptr; + grpc_status_code status; + grpc_slice details; + grpc::testing::EchoRequest send_request; + grpc_slice send_request_slice = + grpc_slice_new(&send_request, sizeof(send_request), do_nothing); + + channel = grpc_lame_client_channel_create( + "localhost:1234", GRPC_STATUS_UNAUTHENTICATED, "blah"); + cq = grpc_completion_queue_create_for_next(nullptr); + void* rc = grpc_channel_register_call( + channel, "/grpc.testing.EchoTestService/Echo", nullptr, nullptr); + for (auto _ : state) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + grpc_call* call = grpc_channel_create_registered_call( + channel, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, rc, + gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_byte_buffer* request_payload_send = + grpc_raw_byte_buffer_create(&send_request_slice, 1); + + // Fill in call ops + grpc_op ops[3]; + memset(ops, 0, sizeof(ops)); + grpc_op* op = ops; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op++; + op->op = GRPC_OP_SEND_MESSAGE; + op->data.send_message.send_message = request_payload_send; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op++; + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call, ops, + (size_t)(op - ops), + (void*)nullptr, nullptr)); + memset(ops, 0, sizeof(ops)); + op = ops; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = + &initial_metadata_recv; + op++; + op->op = GRPC_OP_RECV_MESSAGE; + op->data.recv_message.recv_message = &response_payload_recv; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = 
&details; + op++; + + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call, ops, + (size_t)(op - ops), + (void*)1, nullptr)); + grpc_event ev = grpc_completion_queue_next( + cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); + GPR_ASSERT(ev.type != GRPC_QUEUE_SHUTDOWN); + GPR_ASSERT(ev.success == 0); + ev = grpc_completion_queue_next(cq, gpr_inf_future(GPR_CLOCK_REALTIME), + nullptr); + GPR_ASSERT(ev.type != GRPC_QUEUE_SHUTDOWN); + GPR_ASSERT(ev.success != 0); + grpc_call_unref(call); + grpc_byte_buffer_destroy(request_payload_send); + grpc_byte_buffer_destroy(response_payload_recv); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + } + grpc_channel_destroy(channel); + grpc_completion_queue_destroy(cq); + grpc_slice_unref(send_request_slice); + track_counters.Finish(state); +} +BENCHMARK(BM_LameChannelCallCreateCoreSeparateBatch); + +static void FilterDestroy(void* arg, grpc_error_handle /*error*/) { + gpr_free(arg); +} + +static void DoNothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +class FakeClientChannelFactory : public grpc_core::ClientChannelFactory { + public: + grpc_core::RefCountedPtr CreateSubchannel( + const grpc_resolved_address& /*address*/, + const grpc_channel_args* /*args*/) override { + return nullptr; + } +}; + +static grpc_arg StringArg(const char* key, const char* value) { + grpc_arg a; + a.type = GRPC_ARG_STRING; + a.key = const_cast(key); + a.value.string = const_cast(value); + return a; +} + +enum FixtureFlags : uint32_t { + CHECKS_NOT_LAST = 1, + REQUIRES_TRANSPORT = 2, +}; + +template +struct Fixture { + const grpc_channel_filter* filter = kFilter; + const uint32_t flags = kFlags; +}; + +namespace phony_filter { + +static void StartTransportStreamOp(grpc_call_element* /*elem*/, + grpc_transport_stream_op_batch* /*op*/) {} + +static void StartTransportOp(grpc_channel_element* /*elem*/, + grpc_transport_op* /*op*/) {} + +static grpc_error_handle InitCallElem(grpc_call_element* /*elem*/, + const grpc_call_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +static void SetPollsetOrPollsetSet(grpc_call_element* /*elem*/, + grpc_polling_entity* /*pollent*/) {} + +static void DestroyCallElem(grpc_call_element* /*elem*/, + const grpc_call_final_info* /*final_info*/, + grpc_closure* /*then_sched_closure*/) {} + +grpc_error_handle InitChannelElem(grpc_channel_element* /*elem*/, + grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +void DestroyChannelElem(grpc_channel_element* /*elem*/) {} + +void GetChannelInfo(grpc_channel_element* /*elem*/, + const grpc_channel_info* /*channel_info*/) {} + +static const grpc_channel_filter phony_filter = {StartTransportStreamOp, + StartTransportOp, + 0, + InitCallElem, + SetPollsetOrPollsetSet, + DestroyCallElem, + 0, + InitChannelElem, + DestroyChannelElem, + GetChannelInfo, + "phony_filter"}; + +} // namespace phony_filter + +namespace phony_transport { + +/* Memory required for a single stream element - this is allocated by upper + layers and initialized by the transport */ +size_t sizeof_stream; /* = sizeof(transport stream) */ + +/* name of this transport implementation */ +const char* name; + +/* implementation of grpc_transport_init_stream */ +int InitStream(grpc_transport* /*self*/, grpc_stream* /*stream*/, + grpc_stream_refcount* /*refcount*/, const void* /*server_data*/, + grpc_core::Arena* /*arena*/) { + return 0; +} + +/* implementation of grpc_transport_set_pollset */ +void SetPollset(grpc_transport* /*self*/, grpc_stream* 
/*stream*/, + grpc_pollset* /*pollset*/) {} + +/* implementation of grpc_transport_set_pollset */ +void SetPollsetSet(grpc_transport* /*self*/, grpc_stream* /*stream*/, + grpc_pollset_set* /*pollset_set*/) {} + +/* implementation of grpc_transport_perform_stream_op */ +void PerformStreamOp(grpc_transport* /*self*/, grpc_stream* /*stream*/, + grpc_transport_stream_op_batch* op) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, GRPC_ERROR_NONE); +} + +/* implementation of grpc_transport_perform_op */ +void PerformOp(grpc_transport* /*self*/, grpc_transport_op* /*op*/) {} + +/* implementation of grpc_transport_destroy_stream */ +void DestroyStream(grpc_transport* /*self*/, grpc_stream* /*stream*/, + grpc_closure* /*then_sched_closure*/) {} + +/* implementation of grpc_transport_destroy */ +void Destroy(grpc_transport* /*self*/) {} + +/* implementation of grpc_transport_get_endpoint */ +grpc_endpoint* GetEndpoint(grpc_transport* /*self*/) { return nullptr; } + +static const grpc_transport_vtable phony_transport_vtable = { + 0, "phony_http2", InitStream, + SetPollset, SetPollsetSet, PerformStreamOp, + PerformOp, DestroyStream, Destroy, + GetEndpoint}; + +static grpc_transport phony_transport = {&phony_transport_vtable}; + +} // namespace phony_transport + +class NoOp { + public: + class Op { + public: + Op(NoOp* /*p*/, grpc_call_stack* /*s*/, grpc_core::Arena*) {} + void Finish() {} + }; +}; + +class SendEmptyMetadata { + public: + SendEmptyMetadata() : op_payload_(nullptr) { + op_ = {}; + op_.on_complete = GRPC_CLOSURE_INIT(&closure_, DoNothing, nullptr, + grpc_schedule_on_exec_ctx); + op_.send_initial_metadata = true; + op_.payload = &op_payload_; + } + + class Op { + public: + Op(SendEmptyMetadata* p, grpc_call_stack* /*s*/, grpc_core::Arena* arena) + : batch_(arena) { + p->op_payload_.send_initial_metadata.send_initial_metadata = &batch_; + } + void Finish() {} + + private: + grpc_metadata_batch batch_; + }; + + private: + const gpr_timespec deadline_ = gpr_inf_future(GPR_CLOCK_MONOTONIC); + const gpr_timespec start_time_ = gpr_now(GPR_CLOCK_MONOTONIC); + const grpc_slice method_ = grpc_slice_from_static_string("/foo/bar"); + grpc_transport_stream_op_batch op_; + grpc_transport_stream_op_batch_payload op_payload_; + grpc_closure closure_; +}; + +// Test a filter in isolation. Fixture specifies the filter under test (use the +// Fixture<> template to specify this), and TestOp defines some unit of work to +// perform on said filter. +template +static void BM_IsolatedFilter(benchmark::State& state) { + TrackCounters track_counters; + Fixture fixture; + std::ostringstream label; + FakeClientChannelFactory fake_client_channel_factory; + + std::vector args = { + grpc_core::ClientChannelFactory::CreateChannelArg( + &fake_client_channel_factory), + StringArg(GRPC_ARG_SERVER_URI, "localhost"), + }; + grpc_channel_args channel_args = {args.size(), &args[0]}; + + std::vector filters; + if (fixture.filter != nullptr) { + filters.push_back(fixture.filter); + } + if (fixture.flags & CHECKS_NOT_LAST) { + filters.push_back(&phony_filter::phony_filter); + label << " #has_phony_filter"; + } + + grpc_core::ExecCtx exec_ctx; + size_t channel_size = grpc_channel_stack_size( + filters.empty() ? nullptr : &filters[0], filters.size()); + grpc_channel_stack* channel_stack = + static_cast(gpr_zalloc(channel_size)); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "channel_stack_init", + grpc_channel_stack_init(1, FilterDestroy, channel_stack, + filters.empty() ? 
nullptr : &filters[0], + filters.size(), &channel_args, + fixture.flags & REQUIRES_TRANSPORT + ? &phony_transport::phony_transport + : nullptr, + "CHANNEL", channel_stack))); + grpc_core::ExecCtx::Get()->Flush(); + grpc_call_stack* call_stack = + static_cast(gpr_zalloc(channel_stack->call_stack_size)); + grpc_millis deadline = GRPC_MILLIS_INF_FUTURE; + gpr_cycle_counter start_time = gpr_get_cycle_counter(); + grpc_slice method = grpc_slice_from_static_string("/foo/bar"); + grpc_call_final_info final_info; + TestOp test_op_data; + const int kArenaSize = 4096; + grpc_call_context_element context[GRPC_CONTEXT_COUNT] = {}; + grpc_call_element_args call_args{call_stack, + nullptr, + context, + method, + start_time, + deadline, + grpc_core::Arena::Create(kArenaSize), + nullptr}; + while (state.KeepRunning()) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + GRPC_ERROR_UNREF( + grpc_call_stack_init(channel_stack, 1, DoNothing, nullptr, &call_args)); + typename TestOp::Op op(&test_op_data, call_stack, call_args.arena); + grpc_call_stack_destroy(call_stack, &final_info, nullptr); + op.Finish(); + grpc_core::ExecCtx::Get()->Flush(); + // recreate arena every 64k iterations to avoid oom + if (0 == (state.iterations() & 0xffff)) { + call_args.arena->Destroy(); + call_args.arena = grpc_core::Arena::Create(kArenaSize); + } + } + call_args.arena->Destroy(); + grpc_channel_stack_destroy(channel_stack); + grpc_core::ExecCtx::Get()->Flush(); + + gpr_free(channel_stack); + gpr_free(call_stack); + + state.SetLabel(label.str()); + track_counters.Finish(state); +} + +typedef Fixture NoFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, NoFilter, NoOp); +typedef Fixture<&phony_filter::phony_filter, 0> PhonyFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, PhonyFilter, NoOp); +BENCHMARK_TEMPLATE(BM_IsolatedFilter, PhonyFilter, SendEmptyMetadata); +typedef Fixture<&grpc_core::ClientChannel::kFilterVtable, 0> + ClientChannelFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, ClientChannelFilter, NoOp); +typedef Fixture<&grpc_message_compress_filter, CHECKS_NOT_LAST> CompressFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, CompressFilter, NoOp); +BENCHMARK_TEMPLATE(BM_IsolatedFilter, CompressFilter, SendEmptyMetadata); +typedef Fixture<&grpc_client_deadline_filter, CHECKS_NOT_LAST> + ClientDeadlineFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, ClientDeadlineFilter, NoOp); +BENCHMARK_TEMPLATE(BM_IsolatedFilter, ClientDeadlineFilter, SendEmptyMetadata); +typedef Fixture<&grpc_server_deadline_filter, CHECKS_NOT_LAST> + ServerDeadlineFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, ServerDeadlineFilter, NoOp); +BENCHMARK_TEMPLATE(BM_IsolatedFilter, ServerDeadlineFilter, SendEmptyMetadata); +typedef Fixture<&grpc_http_client_filter, CHECKS_NOT_LAST | REQUIRES_TRANSPORT> + HttpClientFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpClientFilter, NoOp); +BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpClientFilter, SendEmptyMetadata); +typedef Fixture<&grpc_http_server_filter, CHECKS_NOT_LAST> HttpServerFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpServerFilter, NoOp); +BENCHMARK_TEMPLATE(BM_IsolatedFilter, HttpServerFilter, SendEmptyMetadata); +typedef Fixture<&grpc_message_size_filter, CHECKS_NOT_LAST> MessageSizeFilter; +BENCHMARK_TEMPLATE(BM_IsolatedFilter, MessageSizeFilter, NoOp); +BENCHMARK_TEMPLATE(BM_IsolatedFilter, MessageSizeFilter, SendEmptyMetadata); +// This cmake target is disabled for now because it depends on OpenCensus, which +// is Bazel-only. 
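// (Hence the grpc_server_load_reporting_filter registrations just below stay
// commented out.)
//
// Adding a benchmark for another filter follows the pattern of the typedefs
// above: name the filter vtable and its flags in a Fixture<>, then
// instantiate BM_IsolatedFilter with a TestOp. A sketch with a hypothetical
// my_filter vtable, not part of this change:
//
//   typedef Fixture<&my_filter, CHECKS_NOT_LAST> MyFilter;
//   BENCHMARK_TEMPLATE(BM_IsolatedFilter, MyFilter, NoOp);
//   BENCHMARK_TEMPLATE(BM_IsolatedFilter, MyFilter, SendEmptyMetadata);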
+// typedef Fixture<&grpc_server_load_reporting_filter, CHECKS_NOT_LAST> +// LoadReportingFilter; +// BENCHMARK_TEMPLATE(BM_IsolatedFilter, LoadReportingFilter, NoOp); +// BENCHMARK_TEMPLATE(BM_IsolatedFilter, LoadReportingFilter, +// SendEmptyMetadata); + +//////////////////////////////////////////////////////////////////////////////// +// Benchmarks isolating grpc_call + +namespace isolated_call_filter { + +typedef struct { + grpc_core::CallCombiner* call_combiner; +} call_data; + +static void StartTransportStreamOp(grpc_call_element* elem, + grpc_transport_stream_op_batch* op) { + call_data* calld = static_cast(elem->call_data); + // Construct list of closures to return. + grpc_core::CallCombinerClosureList closures; + if (op->recv_initial_metadata) { + closures.Add(op->payload->recv_initial_metadata.recv_initial_metadata_ready, + GRPC_ERROR_NONE, "recv_initial_metadata"); + } + if (op->recv_message) { + closures.Add(op->payload->recv_message.recv_message_ready, GRPC_ERROR_NONE, + "recv_message"); + } + if (op->recv_trailing_metadata) { + closures.Add( + op->payload->recv_trailing_metadata.recv_trailing_metadata_ready, + GRPC_ERROR_NONE, "recv_trailing_metadata"); + } + if (op->on_complete != nullptr) { + closures.Add(op->on_complete, GRPC_ERROR_NONE, "on_complete"); + } + // Execute closures. + closures.RunClosures(calld->call_combiner); +} + +static void StartTransportOp(grpc_channel_element* /*elem*/, + grpc_transport_op* op) { + if (op->disconnect_with_error != GRPC_ERROR_NONE) { + GRPC_ERROR_UNREF(op->disconnect_with_error); + } + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_consumed, GRPC_ERROR_NONE); +} + +static grpc_error_handle InitCallElem(grpc_call_element* elem, + const grpc_call_element_args* args) { + call_data* calld = static_cast(elem->call_data); + calld->call_combiner = args->call_combiner; + return GRPC_ERROR_NONE; +} + +static void SetPollsetOrPollsetSet(grpc_call_element* /*elem*/, + grpc_polling_entity* /*pollent*/) {} + +static void DestroyCallElem(grpc_call_element* /*elem*/, + const grpc_call_final_info* /*final_info*/, + grpc_closure* then_sched_closure) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, then_sched_closure, GRPC_ERROR_NONE); +} + +grpc_error_handle InitChannelElem(grpc_channel_element* /*elem*/, + grpc_channel_element_args* /*args*/) { + return GRPC_ERROR_NONE; +} + +void DestroyChannelElem(grpc_channel_element* /*elem*/) {} + +void GetChannelInfo(grpc_channel_element* /*elem*/, + const grpc_channel_info* /*channel_info*/) {} + +static const grpc_channel_filter isolated_call_filter = { + StartTransportStreamOp, + StartTransportOp, + sizeof(call_data), + InitCallElem, + SetPollsetOrPollsetSet, + DestroyCallElem, + 0, + InitChannelElem, + DestroyChannelElem, + GetChannelInfo, + "isolated_call_filter"}; +} // namespace isolated_call_filter + +class IsolatedCallFixture : public TrackCounters { + public: + IsolatedCallFixture() { + // We are calling grpc_channel_stack_builder_create() instead of + // grpc_channel_create() here, which means we're not getting the + // grpc_init() called by grpc_channel_create(), but we are getting + // the grpc_shutdown() run by grpc_channel_destroy(). So we need to + // call grpc_init() manually here to balance things out. 
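    // grpc_init() / grpc_shutdown() are reference counted, so the extra
    // grpc_init() here pairs with the grpc_shutdown() that
    // grpc_channel_destroy() will run, keeping the library's init count
    // balanced over the fixture's lifetime.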
+ grpc_init(); + grpc_channel_stack_builder* builder = grpc_channel_stack_builder_create(); + grpc_channel_stack_builder_set_name(builder, "phony"); + grpc_channel_stack_builder_set_target(builder, "phony_target"); + GPR_ASSERT(grpc_channel_stack_builder_append_filter( + builder, &isolated_call_filter::isolated_call_filter, nullptr, + nullptr)); + { + grpc_core::ExecCtx exec_ctx; + channel_ = grpc_channel_create_with_builder( + builder, GRPC_CLIENT_CHANNEL, grpc_resource_user_create_unlimited(), + 0); + } + cq_ = grpc_completion_queue_create_for_next(nullptr); + } + + void Finish(benchmark::State& state) override { + grpc_completion_queue_destroy(cq_); + grpc_channel_destroy(channel_); + TrackCounters::Finish(state); + } + + grpc_channel* channel() const { return channel_; } + grpc_completion_queue* cq() const { return cq_; } + + private: + grpc_completion_queue* cq_; + grpc_channel* channel_; +}; + +static void BM_IsolatedCall_NoOp(benchmark::State& state) { + IsolatedCallFixture fixture; + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_MONOTONIC); + void* method_hdl = grpc_channel_register_call(fixture.channel(), "/foo/bar", + nullptr, nullptr); + for (auto _ : state) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + grpc_call_unref(grpc_channel_create_registered_call( + fixture.channel(), nullptr, GRPC_PROPAGATE_DEFAULTS, fixture.cq(), + method_hdl, deadline, nullptr)); + } + fixture.Finish(state); +} +BENCHMARK(BM_IsolatedCall_NoOp); + +static void BM_IsolatedCall_Unary(benchmark::State& state) { + IsolatedCallFixture fixture; + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_MONOTONIC); + void* method_hdl = grpc_channel_register_call(fixture.channel(), "/foo/bar", + nullptr, nullptr); + grpc_slice slice = grpc_slice_from_static_string("hello world"); + grpc_byte_buffer* send_message = grpc_raw_byte_buffer_create(&slice, 1); + grpc_byte_buffer* recv_message = nullptr; + grpc_status_code status_code; + grpc_slice status_details = grpc_empty_slice(); + grpc_metadata_array recv_initial_metadata; + grpc_metadata_array_init(&recv_initial_metadata); + grpc_metadata_array recv_trailing_metadata; + grpc_metadata_array_init(&recv_trailing_metadata); + grpc_op ops[6]; + memset(ops, 0, sizeof(ops)); + ops[0].op = GRPC_OP_SEND_INITIAL_METADATA; + ops[1].op = GRPC_OP_SEND_MESSAGE; + ops[1].data.send_message.send_message = send_message; + ops[2].op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + ops[3].op = GRPC_OP_RECV_INITIAL_METADATA; + ops[3].data.recv_initial_metadata.recv_initial_metadata = + &recv_initial_metadata; + ops[4].op = GRPC_OP_RECV_MESSAGE; + ops[4].data.recv_message.recv_message = &recv_message; + ops[5].op = GRPC_OP_RECV_STATUS_ON_CLIENT; + ops[5].data.recv_status_on_client.status = &status_code; + ops[5].data.recv_status_on_client.status_details = &status_details; + ops[5].data.recv_status_on_client.trailing_metadata = &recv_trailing_metadata; + for (auto _ : state) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + grpc_call* call = grpc_channel_create_registered_call( + fixture.channel(), nullptr, GRPC_PROPAGATE_DEFAULTS, fixture.cq(), + method_hdl, deadline, nullptr); + grpc_call_start_batch(call, ops, 6, tag(1), nullptr); + grpc_completion_queue_next(fixture.cq(), + gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + grpc_call_unref(call); + } + fixture.Finish(state); + grpc_metadata_array_destroy(&recv_initial_metadata); + grpc_metadata_array_destroy(&recv_trailing_metadata); + grpc_byte_buffer_destroy(send_message); +} +BENCHMARK(BM_IsolatedCall_Unary); + +static void 
BM_IsolatedCall_StreamingSend(benchmark::State& state) { + IsolatedCallFixture fixture; + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_MONOTONIC); + void* method_hdl = grpc_channel_register_call(fixture.channel(), "/foo/bar", + nullptr, nullptr); + grpc_slice slice = grpc_slice_from_static_string("hello world"); + grpc_byte_buffer* send_message = grpc_raw_byte_buffer_create(&slice, 1); + grpc_metadata_array recv_initial_metadata; + grpc_metadata_array_init(&recv_initial_metadata); + grpc_metadata_array recv_trailing_metadata; + grpc_metadata_array_init(&recv_trailing_metadata); + grpc_op ops[2]; + memset(ops, 0, sizeof(ops)); + ops[0].op = GRPC_OP_SEND_INITIAL_METADATA; + ops[1].op = GRPC_OP_RECV_INITIAL_METADATA; + ops[1].data.recv_initial_metadata.recv_initial_metadata = + &recv_initial_metadata; + grpc_call* call = grpc_channel_create_registered_call( + fixture.channel(), nullptr, GRPC_PROPAGATE_DEFAULTS, fixture.cq(), + method_hdl, deadline, nullptr); + grpc_call_start_batch(call, ops, 2, tag(1), nullptr); + grpc_completion_queue_next(fixture.cq(), gpr_inf_future(GPR_CLOCK_MONOTONIC), + nullptr); + memset(ops, 0, sizeof(ops)); + ops[0].op = GRPC_OP_SEND_MESSAGE; + ops[0].data.send_message.send_message = send_message; + for (auto _ : state) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + grpc_call_start_batch(call, ops, 1, tag(2), nullptr); + grpc_completion_queue_next(fixture.cq(), + gpr_inf_future(GPR_CLOCK_MONOTONIC), nullptr); + } + grpc_call_unref(call); + fixture.Finish(state); + grpc_metadata_array_destroy(&recv_initial_metadata); + grpc_metadata_array_destroy(&recv_trailing_metadata); + grpc_byte_buffer_destroy(send_message); +} +BENCHMARK(BM_IsolatedCall_StreamingSend); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_callback_streaming_ping_pong.cc b/test/cpp/microbenchmarks/bm_callback_streaming_ping_pong.cc new file mode 100644 index 00000000..bc02bb9c --- /dev/null +++ b/test/cpp/microbenchmarks/bm_callback_streaming_ping_pong.cc @@ -0,0 +1,138 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/callback_streaming_ping_pong.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +/******************************************************************************* + * CONFIGURATIONS + */ + +// Replace "benchmark::internal::Benchmark" with "::testing::Benchmark" to use +// internal microbenchmarking tooling +static void StreamingPingPongMsgSizeArgs(benchmark::internal::Benchmark* b) { + // base case: 0 byte ping-pong msgs + b->Args({0, 1}); + b->Args({0, 2}); + + for (int msg_size = 1; msg_size <= 128 * 1024 * 1024; msg_size *= 8) { + b->Args({msg_size, 1}); + b->Args({msg_size, 2}); + } +} + +// Replace "benchmark::internal::Benchmark" with "::testing::Benchmark" to use +// internal microbenchmarking tooling +static void StreamingPingPongMsgsNumberArgs(benchmark::internal::Benchmark* b) { + for (int msg_number = 1; msg_number <= 256 * 1024; msg_number *= 8) { + b->Args({0, msg_number}); + b->Args({1024, msg_number}); + } +} + +// Streaming with different message size +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + NoOpMutator) + ->Apply(StreamingPingPongMsgSizeArgs); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, MinInProcess, NoOpMutator, + NoOpMutator) + ->Apply(StreamingPingPongMsgSizeArgs); + +// Streaming with different message number +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + NoOpMutator) + ->Apply(StreamingPingPongMsgsNumberArgs); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, MinInProcess, NoOpMutator, + NoOpMutator) + ->Apply(StreamingPingPongMsgsNumberArgs); + +// Client context with different metadata +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 1>, + NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 2>, + NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 1}); + +// Server context with different metadata +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 1}); +BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 1}); 
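// In the registrations above and below, Args({0, 1}) requests zero-byte
// messages and a single message per stream, so only the cost of the extra
// server metadata varies (the argument meaning follows the generators at the
// top of this file).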
+BENCHMARK_TEMPLATE(BM_CallbackBidiStreaming, InProcess, NoOpMutator, + Server_AddInitialMetadata, 100>) + ->Args({0, 1}); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_callback_unary_ping_pong.cc b/test/cpp/microbenchmarks/bm_callback_unary_ping_pong.cc new file mode 100644 index 00000000..93eef7c1 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_callback_unary_ping_pong.cc @@ -0,0 +1,120 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/callback_unary_ping_pong.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +/******************************************************************************* + * CONFIGURATIONS + */ + +// Replace "benchmark::internal::Benchmark" with "::testing::Benchmark" to use +// internal microbenchmarking tooling +static void SweepSizesArgs(benchmark::internal::Benchmark* b) { + b->Args({0, 0}); + for (int i = 1; i <= 128 * 1024 * 1024; i *= 8) { + // First argument is the message size of request + // Second argument is the message size of response + b->Args({i, 0}); + b->Args({0, i}); + b->Args({i, i}); + } +} + +// Unary ping pong with different message size of request and response +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + NoOpMutator) + ->Apply(SweepSizesArgs); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, MinInProcess, NoOpMutator, + NoOpMutator) + ->Apply(SweepSizesArgs); + +// Client context with different metadata +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 1>, + NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 2>, + NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); 
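// Args({0, 0}) in the unary registrations above and below keeps both the
// request and response payloads at zero bytes (see SweepSizesArgs), so these
// cases measure only the overhead of the added client metadata.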
+BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); + +// Server context with different metadata +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_CallbackUnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 100>) + ->Args({0, 0}); +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_channel.cc b/test/cpp/microbenchmarks/bm_channel.cc new file mode 100644 index 00000000..b3e9b77b --- /dev/null +++ b/test/cpp/microbenchmarks/bm_channel.cc @@ -0,0 +1,93 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Benchmark channel */ + +#include + +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +class ChannelDestroyerFixture { + public: + ChannelDestroyerFixture() {} + virtual ~ChannelDestroyerFixture() { + if (channel_) { + grpc_channel_destroy(channel_); + } + } + virtual void Init() = 0; + + protected: + grpc_channel* channel_ = nullptr; +}; + +class InsecureChannelFixture : public ChannelDestroyerFixture { + public: + InsecureChannelFixture() {} + void Init() override { + channel_ = grpc_insecure_channel_create("localhost:1234", nullptr, nullptr); + } +}; + +class LameChannelFixture : public ChannelDestroyerFixture { + public: + LameChannelFixture() {} + void Init() override { + channel_ = grpc_lame_client_channel_create( + "localhost:1234", GRPC_STATUS_UNAUTHENTICATED, "blah"); + } +}; + +template +static void BM_InsecureChannelCreateDestroy(benchmark::State& state) { + // In order to test if channel creation time is affected by the number of + // already existing channels, we create some initial channels here. 
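  // state.range(0) picks how many of the 512 pre-allocated fixtures get
  // initialized, so the ->Range(0, 512) registrations below sweep from a
  // process with no pre-existing channels up to 512 of them.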
+ Fixture initial_channels[512]; + for (int i = 0; i < state.range(0); i++) { + initial_channels[i].Init(); + } + for (auto _ : state) { + Fixture channel; + channel.Init(); + } +} +BENCHMARK_TEMPLATE(BM_InsecureChannelCreateDestroy, InsecureChannelFixture) + ->Range(0, 512); +; +BENCHMARK_TEMPLATE(BM_InsecureChannelCreateDestroy, LameChannelFixture) + ->Range(0, 512); +; + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_chttp2_hpack.cc b/test/cpp/microbenchmarks/bm_chttp2_hpack.cc new file mode 100644 index 00000000..529f41ee --- /dev/null +++ b/test/cpp/microbenchmarks/bm_chttp2_hpack.cc @@ -0,0 +1,867 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Microbenchmarks around CHTTP2 HPACK operations */ + +#include + +#include +#include + +#include + +#include +#include + +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" +#include "src/core/ext/transport/chttp2/transport/hpack_parser.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/timeout_encoding.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +static grpc_slice MakeSlice(std::vector bytes) { + grpc_slice s = grpc_slice_malloc(bytes.size()); + uint8_t* p = GRPC_SLICE_START_PTR(s); + for (auto b : bytes) { + *p++ = b; + } + return s; +} + +//////////////////////////////////////////////////////////////////////////////// +// HPACK encoder +// + +static void BM_HpackEncoderInitDestroy(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_core::HPackCompressor c; + grpc_core::ExecCtx::Get()->Flush(); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_HpackEncoderInitDestroy); + +static void BM_HpackEncoderEncodeDeadline(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + grpc_millis saved_now = grpc_core::ExecCtx::Get()->Now(); + + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch b(arena.get()); + b.Set(grpc_core::GrpcTimeoutMetadata(), saved_now + 30 * 1000); + + grpc_core::HPackCompressor c; + grpc_transport_one_way_stats stats; + stats = {}; + grpc_slice_buffer outbuf; + grpc_slice_buffer_init(&outbuf); + while (state.KeepRunning()) { + c.EncodeHeaders( + 
grpc_core::HPackCompressor::EncodeHeaderOptions{ + static_cast(state.iterations()), + true, + false, + static_cast(1024), + &stats, + }, + b, &outbuf); + grpc_slice_buffer_reset_and_unref_internal(&outbuf); + grpc_core::ExecCtx::Get()->Flush(); + } + grpc_slice_buffer_destroy_internal(&outbuf); + + std::ostringstream label; + label << "framing_bytes/iter:" + << (static_cast(stats.framing_bytes) / + static_cast(state.iterations())) + << " header_bytes/iter:" + << (static_cast(stats.header_bytes) / + static_cast(state.iterations())); + track_counters.AddLabel(label.str()); + track_counters.Finish(state); +} +BENCHMARK(BM_HpackEncoderEncodeDeadline); + +template +static void BM_HpackEncoderEncodeHeader(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + static bool logged_representative_output = false; + + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch b(arena.get()); + Fixture::Prepare(&b); + + grpc_core::HPackCompressor c; + grpc_transport_one_way_stats stats; + stats = {}; + grpc_slice_buffer outbuf; + grpc_slice_buffer_init(&outbuf); + while (state.KeepRunning()) { + static constexpr int kEnsureMaxFrameAtLeast = 2; + c.EncodeHeaders( + grpc_core::HPackCompressor::EncodeHeaderOptions{ + static_cast(state.iterations()), + state.range(0) != 0, + Fixture::kEnableTrueBinary, + static_cast(state.range(1) + kEnsureMaxFrameAtLeast), + &stats, + }, + b, &outbuf); + if (!logged_representative_output && state.iterations() > 3) { + logged_representative_output = true; + for (size_t i = 0; i < outbuf.count; i++) { + char* s = grpc_dump_slice(outbuf.slices[i], GPR_DUMP_HEX); + gpr_log(GPR_DEBUG, "%" PRIdPTR ": %s", i, s); + gpr_free(s); + } + } + grpc_slice_buffer_reset_and_unref_internal(&outbuf); + grpc_core::ExecCtx::Get()->Flush(); + } + grpc_slice_buffer_destroy_internal(&outbuf); + + std::ostringstream label; + label << "framing_bytes/iter:" + << (static_cast(stats.framing_bytes) / + static_cast(state.iterations())) + << " header_bytes/iter:" + << (static_cast(stats.header_bytes) / + static_cast(state.iterations())); + track_counters.AddLabel(label.str()); + track_counters.Finish(state); +} + +namespace hpack_encoder_fixtures { + +class EmptyBatch { + public: + static constexpr bool kEnableTrueBinary = false; + static void Prepare(grpc_metadata_batch*) {} +}; + +class SingleStaticElem { + public: + static constexpr bool kEnableTrueBinary = false; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_DEFLATE))); + } +}; + +class SingleInternedElem { + public: + static constexpr bool kEnableTrueBinary = false; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string("abc")), + grpc_slice_intern(grpc_slice_from_static_string("def")))))); + } +}; + +template +class SingleInternedBinaryElem { + public: + static constexpr bool kEnableTrueBinary = kTrueBinary; + static void Prepare(grpc_metadata_batch* b) { + grpc_slice bytes = MakeBytes(); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string("abc-bin")), + grpc_slice_intern(bytes))))); + grpc_slice_unref(bytes); + } + + private: + static grpc_slice MakeBytes() { + std::vector v; + v.reserve(kLength); + for (int i = 0; i < kLength; i++) { + v.push_back(static_cast(rand())); + } + return 
grpc_slice_from_copied_buffer(v.data(), v.size()); + } +}; + +class SingleInternedKeyElem { + public: + static constexpr bool kEnableTrueBinary = false; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + grpc_slice_intern(grpc_slice_from_static_string("abc")), + grpc_slice_from_static_string("def"))))); + } +}; + +class SingleNonInternedElem { + public: + static constexpr bool kEnableTrueBinary = false; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT( + GRPC_LOG_IF_ERROR("addmd", b->Append(grpc_mdelem_from_slices( + grpc_slice_from_static_string("abc"), + grpc_slice_from_static_string("def"))))); + } +}; + +template +class SingleNonInternedBinaryElem { + public: + static constexpr bool kEnableTrueBinary = kTrueBinary; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + grpc_slice_from_static_string("abc-bin"), MakeBytes())))); + } + + private: + static grpc_slice MakeBytes() { + std::vector v; + v.reserve(kLength); + for (int i = 0; i < kLength; i++) { + v.push_back(static_cast(rand())); + } + return grpc_slice_from_copied_buffer(v.data(), v.size()); + } +}; + +class RepresentativeClientInitialMetadata { + public: + static constexpr bool kEnableTrueBinary = true; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_SCHEME_HTTP))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_METHOD_POST))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_PATH, + grpc_slice_intern(grpc_slice_from_static_string("/foo/bar")))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_AUTHORITY, + grpc_slice_intern(grpc_slice_from_static_string( + "foo.test.google.fr:1234")))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append( + GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_DEFLATE_COMMA_GZIP))); + b->Set(grpc_core::TeMetadata(), grpc_core::TeMetadata::kTrailers); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_USER_AGENT, + grpc_slice_intern(grpc_slice_from_static_string( + "grpc-c/3.0.0-dev (linux; chttp2; green)")))))); + } +}; + +// This fixture reflects how initial metadata are sent by a production client, +// with non-indexed :path and binary headers. The metadata here are the same as +// the corresponding parser benchmark below. 
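// "Non-indexed" means the :path value is sent as an HPACK literal rather
// than as a reference into the dynamic table, and the grpc-trace-bin /
// grpc-tags-bin entries carry raw bytes, exercising the binary-metadata
// encoding path that kEnableTrueBinary switches on.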
+class MoreRepresentativeClientInitialMetadata { + public: + static constexpr bool kEnableTrueBinary = true; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_SCHEME_HTTP))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_METHOD_POST))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_PATH, grpc_slice_intern(grpc_slice_from_static_string( + "/grpc.test.FooService/BarMethod")))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_AUTHORITY, + grpc_slice_intern(grpc_slice_from_static_string( + "foo.test.google.fr:1234")))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_TRACE_BIN, + grpc_slice_from_static_string("\x00\x01\x02\x03\x04\x05\x06\x07\x08" + "\x09\x0a\x0b\x0c\x0d\x0e\x0f" + "\x10\x11\x12\x13\x14\x15\x16\x17\x18" + "\x19\x1a\x1b\x1c\x1d\x1e\x1f" + "\x20\x21\x22\x23\x24\x25\x26\x27\x28" + "\x29\x2a\x2b\x2c\x2d\x2e\x2f" + "\x30"))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_GRPC_TAGS_BIN, + grpc_slice_from_static_string("\x00\x01\x02\x03\x04\x05\x06\x07\x08" + "\x09\x0a\x0b\x0c\x0d\x0e\x0f" + "\x10\x11\x12\x13"))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append( + GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_DEFLATE_COMMA_GZIP))); + b->Set(grpc_core::TeMetadata(), grpc_core::TeMetadata::kTrailers); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_USER_AGENT, + grpc_slice_intern(grpc_slice_from_static_string( + "grpc-c/3.0.0-dev (linux; chttp2; green)")))))); + } +}; + +class RepresentativeServerInitialMetadata { + public: + static constexpr bool kEnableTrueBinary = true; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_STATUS_200))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append( + GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_DEFLATE_COMMA_GZIP))); + } +}; + +class RepresentativeServerTrailingMetadata { + public: + static constexpr bool kEnableTrueBinary = true; + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT( + GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_GRPC_STATUS_0))); + } +}; + +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, EmptyBatch)->Args({0, 16384}); +// test with eof (shouldn't affect anything) +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, EmptyBatch)->Args({1, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, SingleStaticElem) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, SingleInternedKeyElem) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, SingleInternedElem) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<1, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<3, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<10, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<31, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<100, false>) + 
->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<1, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<3, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<10, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<31, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleInternedBinaryElem<100, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, SingleNonInternedElem) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<1, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<3, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<10, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<31, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<100, false>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<1, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<3, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<10, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<31, true>) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + SingleNonInternedBinaryElem<100, true>) + ->Args({0, 16384}); +// test with a tiny frame size, to highlight continuation costs +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, SingleNonInternedElem) + ->Args({0, 1}); + +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + RepresentativeClientInitialMetadata) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + MoreRepresentativeClientInitialMetadata) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + RepresentativeServerInitialMetadata) + ->Args({0, 16384}); +BENCHMARK_TEMPLATE(BM_HpackEncoderEncodeHeader, + RepresentativeServerTrailingMetadata) + ->Args({1, 16384}); + +} // namespace hpack_encoder_fixtures + +//////////////////////////////////////////////////////////////////////////////// +// HPACK parser +// + +static void BM_HpackParserInitDestroy(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + { grpc_core::HPackParser(); } + grpc_core::ExecCtx::Get()->Flush(); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_HpackParserInitDestroy); + +template +static void BM_HpackParserParseHeader(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + std::vector init_slices = Fixture::GetInitSlices(); + std::vector benchmark_slices = Fixture::GetBenchmarkSlices(); + grpc_core::HPackParser p; + const int kArenaSize = 4096 * 4096; + auto* arena = grpc_core::Arena::Create(kArenaSize); + grpc_core::ManualConstructor b; + b.Init(arena); + p.BeginFrame(&*b, std::numeric_limits::max(), + grpc_core::HPackParser::Boundary::None, + grpc_core::HPackParser::Priority::None, + grpc_core::HPackParser::LogInfo{ + 1, grpc_core::HPackParser::LogInfo::kHeaders, false}); + for (auto slice : init_slices) { + GPR_ASSERT(GRPC_ERROR_NONE 
== p.Parse(slice, false)); + } + while (state.KeepRunning()) { + b->Clear(); + for (auto slice : benchmark_slices) { + GPR_ASSERT(GRPC_ERROR_NONE == p.Parse(slice, false)); + } + grpc_core::ExecCtx::Get()->Flush(); + // Recreate arena every 4k iterations to avoid oom + if (0 == (state.iterations() & 0xfff)) { + b.Destroy(); + arena->Destroy(); + arena = grpc_core::Arena::Create(kArenaSize); + b.Init(arena); + p.BeginFrame(&*b, std::numeric_limits::max(), + grpc_core::HPackParser::Boundary::None, + grpc_core::HPackParser::Priority::None, + grpc_core::HPackParser::LogInfo{ + 1, grpc_core::HPackParser::LogInfo::kHeaders, false}); + } + } + // Clean up + b.Destroy(); + for (auto slice : init_slices) grpc_slice_unref(slice); + for (auto slice : benchmark_slices) grpc_slice_unref(slice); + arena->Destroy(); + + track_counters.Finish(state); +} + +namespace hpack_parser_fixtures { + +class EmptyBatch { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({})}; + } +}; + +class IndexedSingleStaticElem { + public: + static std::vector GetInitSlices() { + return {MakeSlice( + {0x40, 0x07, ':', 's', 't', 'a', 't', 'u', 's', 0x03, '2', '0', '0'})}; + } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0xbe})}; + } +}; + +class AddIndexedSingleStaticElem { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice( + {0x40, 0x07, ':', 's', 't', 'a', 't', 'u', 's', 0x03, '2', '0', '0'})}; + } +}; + +class KeyIndexedSingleStaticElem { + public: + static std::vector GetInitSlices() { + return {MakeSlice( + {0x40, 0x07, ':', 's', 't', 'a', 't', 'u', 's', 0x03, '2', '0', '0'})}; + } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0x7e, 0x03, 'd', 'e', 'f'})}; + } +}; + +class IndexedSingleInternedElem { + public: + static std::vector GetInitSlices() { + return {MakeSlice({0x40, 0x03, 'a', 'b', 'c', 0x03, 'd', 'e', 'f'})}; + } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0xbe})}; + } +}; + +class AddIndexedSingleInternedElem { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0x40, 0x03, 'a', 'b', 'c', 0x03, 'd', 'e', 'f'})}; + } +}; + +class KeyIndexedSingleInternedElem { + public: + static std::vector GetInitSlices() { + return {MakeSlice({0x40, 0x03, 'a', 'b', 'c', 0x03, 'd', 'e', 'f'})}; + } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0x7e, 0x03, 'g', 'h', 'i'})}; + } +}; + +class NonIndexedElem { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0x00, 0x03, 'a', 'b', 'c', 0x03, 'd', 'e', 'f'})}; + } +}; + +template +class NonIndexedBinaryElem; + +template +class NonIndexedBinaryElem { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + std::vector v = { + 0x00, 0x07, 'a', 'b', 'c', + '-', 'b', 'i', 'n', static_cast(kLength + 1), + 0}; + for (int i = 0; i < kLength; i++) { + v.push_back(static_cast(i)); + } + return {MakeSlice(v)}; + } +}; + +template <> +class NonIndexedBinaryElem<1, false> { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice( + {0x00, 0x07, 'a', 'b', 'c', '-', 'b', 'i', 'n', 0x82, 0xf7, 0xb3})}; + } +}; + +template <> +class NonIndexedBinaryElem<3, false> { + public: + static std::vector 
GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0x00, 0x07, 'a', 'b', 'c', '-', 'b', 'i', 'n', 0x84, + 0x7f, 0x4e, 0x29, 0x3f})}; + } +}; + +template <> +class NonIndexedBinaryElem<10, false> { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0x00, 0x07, 'a', 'b', 'c', '-', 'b', + 'i', 'n', 0x8b, 0x71, 0x0c, 0xa5, 0x81, + 0x73, 0x7b, 0x47, 0x13, 0xe9, 0xf7, 0xe3})}; + } +}; + +template <> +class NonIndexedBinaryElem<31, false> { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({0x00, 0x07, 'a', 'b', 'c', '-', 'b', 'i', 'n', + 0xa3, 0x92, 0x43, 0x7f, 0xbe, 0x7c, 0xea, 0x6f, 0xf3, + 0x3d, 0xa7, 0xa7, 0x67, 0xfb, 0xe2, 0x82, 0xf7, 0xf2, + 0x8f, 0x1f, 0x9d, 0xdf, 0xf1, 0x7e, 0xb3, 0xef, 0xb2, + 0x8f, 0x53, 0x77, 0xce, 0x0c, 0x13, 0xe3, 0xfd, 0x87})}; + } +}; + +template <> +class NonIndexedBinaryElem<100, false> { + public: + static std::vector GetInitSlices() { return {}; } + static std::vector GetBenchmarkSlices() { + return {MakeSlice( + {0x00, 0x07, 'a', 'b', 'c', '-', 'b', 'i', 'n', 0xeb, 0x1d, 0x4d, + 0xe8, 0x96, 0x8c, 0x14, 0x20, 0x06, 0xc1, 0xc3, 0xdf, 0x6e, 0x1f, 0xef, + 0xde, 0x2f, 0xde, 0xb7, 0xf2, 0xfe, 0x6d, 0xd4, 0xe4, 0x7d, 0xf5, 0x55, + 0x46, 0x52, 0x3d, 0x91, 0xf2, 0xd4, 0x6f, 0xca, 0x34, 0xcd, 0xd9, 0x39, + 0xbd, 0x03, 0x27, 0xe3, 0x9c, 0x74, 0xcc, 0x17, 0x34, 0xed, 0xa6, 0x6a, + 0x77, 0x73, 0x10, 0xcd, 0x8e, 0x4e, 0x5c, 0x7c, 0x72, 0x39, 0xd8, 0xe6, + 0x78, 0x6b, 0xdb, 0xa5, 0xb7, 0xab, 0xe7, 0x46, 0xae, 0x21, 0xab, 0x7f, + 0x01, 0x89, 0x13, 0xd7, 0xca, 0x17, 0x6e, 0xcb, 0xd6, 0x79, 0x71, 0x68, + 0xbf, 0x8a, 0x3f, 0x32, 0xe8, 0xba, 0xf5, 0xbe, 0xb3, 0xbc, 0xde, 0x28, + 0xc7, 0xcf, 0x62, 0x7a, 0x58, 0x2c, 0xcf, 0x4d, 0xe3})}; + } +}; + +class RepresentativeClientInitialMetadata { + public: + static std::vector GetInitSlices() { + return {grpc_slice_from_static_string( + // generated with: + // ``` + // tools/codegen/core/gen_header_frame.py --compression inc --no_framing + // < test/core/bad_client/tests/simple_request.headers + // ``` + "@\x05:path\x08/foo/bar" + "@\x07:scheme\x04http" + "@\x07:method\x04POST" + "@\x0a:authority\x09localhost" + "@\x0c" + "content-type\x10" + "application/grpc" + "@\x14grpc-accept-encoding\x15identity,deflate,gzip" + "@\x02te\x08trailers" + "@\x0auser-agent\"bad-client grpc-c/0.12.0.0 (linux)")}; + } + static std::vector GetBenchmarkSlices() { + // generated with: + // ``` + // tools/codegen/core/gen_header_frame.py --compression pre --no_framing + // --hex < test/core/bad_client/tests/simple_request.headers + // ``` + return {MakeSlice({0xc5, 0xc4, 0xc3, 0xc2, 0xc1, 0xc0, 0xbf, 0xbe})}; + } +}; + +// This fixture reflects how initial metadata are sent by a production client, +// with non-indexed :path and binary headers. The metadata here are the same as +// the corresponding encoder benchmark above. 
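+// The init slice below is raw HPACK (RFC 7541): each header is a "literal
+// header field with incremental indexing -- new name" (0x40 prefix) followed
+// by a length-prefixed, non-Huffman key and value, which seeds the parser's
+// dynamic table. The benchmark slices can then replay most of those headers
+// as compact indexed fields (single bytes with the high bit set).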
+class MoreRepresentativeClientInitialMetadata { + public: + static std::vector GetInitSlices() { + return {MakeSlice( + {0x40, 0x07, ':', 's', 'c', 'h', 'e', 'm', 'e', 0x04, 'h', 't', + 't', 'p', 0x40, 0x07, ':', 'm', 'e', 't', 'h', 'o', 'd', 0x04, + 'P', 'O', 'S', 'T', 0x40, 0x05, ':', 'p', 'a', 't', 'h', 0x1f, + '/', 'g', 'r', 'p', 'c', '.', 't', 'e', 's', 't', '.', 'F', + 'o', 'o', 'S', 'e', 'r', 'v', 'i', 'c', 'e', '/', 'B', 'a', + 'r', 'M', 'e', 't', 'h', 'o', 'd', 0x40, 0x0a, ':', 'a', 'u', + 't', 'h', 'o', 'r', 'i', 't', 'y', 0x09, 'l', 'o', 'c', 'a', + 'l', 'h', 'o', 's', 't', 0x40, 0x0e, 'g', 'r', 'p', 'c', '-', + 't', 'r', 'a', 'c', 'e', '-', 'b', 'i', 'n', 0x31, 0x00, 0x01, + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, + 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x40, + 0x0d, 'g', 'r', 'p', 'c', '-', 't', 'a', 'g', 's', '-', 'b', + 'i', 'n', 0x14, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x40, + 0x0c, 'c', 'o', 'n', 't', 'e', 'n', 't', '-', 't', 'y', 'p', + 'e', 0x10, 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', + 'n', '/', 'g', 'r', 'p', 'c', 0x40, 0x14, 'g', 'r', 'p', 'c', + '-', 'a', 'c', 'c', 'e', 'p', 't', '-', 'e', 'n', 'c', 'o', + 'd', 'i', 'n', 'g', 0x15, 'i', 'd', 'e', 'n', 't', 'i', 't', + 'y', ',', 'd', 'e', 'f', 'l', 'a', 't', 'e', ',', 'g', 'z', + 'i', 'p', 0x40, 0x02, 't', 'e', 0x08, 't', 'r', 'a', 'i', 'l', + 'e', 'r', 's', 0x40, 0x0a, 'u', 's', 'e', 'r', '-', 'a', 'g', + 'e', 'n', 't', 0x22, 'b', 'a', 'd', '-', 'c', 'l', 'i', 'e', + 'n', 't', ' ', 'g', 'r', 'p', 'c', '-', 'c', '/', '0', '.', + '1', '2', '.', '0', '.', '0', ' ', '(', 'l', 'i', 'n', 'u', + 'x', ')'})}; + } + static std::vector GetBenchmarkSlices() { + return {MakeSlice({ + 0xc7, 0xc6, 0xc5, 0xc4, 0x7f, 0x04, 0x31, 0x00, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, + 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, + })}; + } +}; + +class RepresentativeServerInitialMetadata { + public: + static std::vector GetInitSlices() { + return {grpc_slice_from_static_string( + // generated with: + // ``` + // tools/codegen/core/gen_header_frame.py --compression inc --no_framing + // < + // test/cpp/microbenchmarks/representative_server_initial_metadata.headers + // ``` + "@\x07:status\x03" + "200" + "@\x0c" + "content-type\x10" + "application/grpc" + "@\x14grpc-accept-encoding\x15identity,deflate,gzip")}; + } + static std::vector GetBenchmarkSlices() { + // generated with: + // ``` + // tools/codegen/core/gen_header_frame.py --compression pre --no_framing + // --hex < + // test/cpp/microbenchmarks/representative_server_initial_metadata.headers + // ``` + return {MakeSlice({0xc0, 0xbf, 0xbe})}; + } +}; + +class RepresentativeServerTrailingMetadata { + public: + static std::vector GetInitSlices() { + return {grpc_slice_from_static_string( + // generated with: + // ``` + // tools/codegen/core/gen_header_frame.py --compression inc --no_framing + // < + // test/cpp/microbenchmarks/representative_server_trailing_metadata.headers + // ``` + "@\x0bgrpc-status\x01" + "0" + "@\x0cgrpc-message\x00")}; + } + static std::vector 
GetBenchmarkSlices() { + // generated with: + // ``` + // tools/codegen/core/gen_header_frame.py --compression pre --no_framing + // --hex < + // test/cpp/microbenchmarks/representative_server_trailing_metadata.headers + // ``` + return {MakeSlice({0xbf, 0xbe})}; + } +}; + +// Send the same deadline repeatedly +class SameDeadline { + public: + static std::vector GetInitSlices() { + return { + grpc_slice_from_static_string("@\x0cgrpc-timeout\x03" + "30S")}; + } + static std::vector GetBenchmarkSlices() { + // Use saved key and literal value. + return {MakeSlice({0x0f, 0x2f, 0x03, '3', '0', 'S'})}; + } +}; + +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, EmptyBatch); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, IndexedSingleStaticElem); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, AddIndexedSingleStaticElem); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, KeyIndexedSingleStaticElem); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, IndexedSingleInternedElem); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, AddIndexedSingleInternedElem); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, KeyIndexedSingleInternedElem); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedElem); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<1, false>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<3, false>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<10, false>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<31, false>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<100, false>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<1, true>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<3, true>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<10, true>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<31, true>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, NonIndexedBinaryElem<100, true>); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, + RepresentativeClientInitialMetadata); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, + MoreRepresentativeClientInitialMetadata); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, + RepresentativeServerInitialMetadata); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, + RepresentativeServerTrailingMetadata); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, + RepresentativeClientInitialMetadata); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, + MoreRepresentativeClientInitialMetadata); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, + RepresentativeServerInitialMetadata); +BENCHMARK_TEMPLATE(BM_HpackParserParseHeader, SameDeadline); + +} // namespace hpack_parser_fixtures + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_chttp2_transport.cc b/test/cpp/microbenchmarks/bm_chttp2_transport.cc new file mode 100644 index 00000000..7aaa338f --- /dev/null +++ b/test/cpp/microbenchmarks/bm_chttp2_transport.cc @@ -0,0 +1,687 @@ +/* + * + * Copyright 2015 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Microbenchmarks around CHTTP2 transport operations */ + +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/resource_quota.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +//////////////////////////////////////////////////////////////////////////////// +// Helper classes +// + +class PhonyEndpoint : public grpc_endpoint { + public: + PhonyEndpoint() { + static const grpc_endpoint_vtable my_vtable = {read, + write, + add_to_pollset, + add_to_pollset_set, + delete_from_pollset_set, + shutdown, + destroy, + get_peer, + get_local_address, + get_fd, + can_track_err}; + grpc_endpoint::vtable = &my_vtable; + } + + void PushInput(grpc_slice slice) { + if (read_cb_ == nullptr) { + GPR_ASSERT(!have_slice_); + buffered_slice_ = slice; + have_slice_ = true; + return; + } + grpc_slice_buffer_add(slices_, slice); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, read_cb_, GRPC_ERROR_NONE); + read_cb_ = nullptr; + } + + private: + grpc_closure* read_cb_ = nullptr; + grpc_slice_buffer* slices_ = nullptr; + bool have_slice_ = false; + grpc_slice buffered_slice_; + + void QueueRead(grpc_slice_buffer* slices, grpc_closure* cb) { + GPR_ASSERT(read_cb_ == nullptr); + if (have_slice_) { + have_slice_ = false; + grpc_slice_buffer_add(slices, buffered_slice_); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); + return; + } + read_cb_ = cb; + slices_ = slices; + } + + static void read(grpc_endpoint* ep, grpc_slice_buffer* slices, + grpc_closure* cb, bool /*urgent*/) { + static_cast(ep)->QueueRead(slices, cb); + } + + static void write(grpc_endpoint* /*ep*/, grpc_slice_buffer* /*slices*/, + grpc_closure* cb, void* /*arg*/) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); + } + + static void add_to_pollset(grpc_endpoint* /*ep*/, grpc_pollset* /*pollset*/) { + } + + static void add_to_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + + static void delete_from_pollset_set(grpc_endpoint* /*ep*/, + grpc_pollset_set* /*pollset*/) {} + + static void shutdown(grpc_endpoint* ep, grpc_error_handle why) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, + static_cast(ep)->read_cb_, why); + } + + static void destroy(grpc_endpoint* ep) { + delete static_cast(ep); + } + + static absl::string_view get_peer(grpc_endpoint* /*ep*/) { return "test"; } + static absl::string_view get_local_address(grpc_endpoint* /*ep*/) { + return "test"; + } + static int get_fd(grpc_endpoint* /*ep*/) { return 0; } + static bool can_track_err(grpc_endpoint* /*ep*/) { 
return false; } +}; + +class Fixture { + public: + Fixture(const grpc::ChannelArguments& args, bool client) { + grpc_channel_args c_args = args.c_channel_args(); + ep_ = new PhonyEndpoint; + t_ = grpc_create_chttp2_transport(&c_args, ep_, client, + grpc_resource_user_create_unlimited()); + grpc_chttp2_transport_start_reading(t_, nullptr, nullptr, nullptr); + FlushExecCtx(); + } + + void FlushExecCtx() { grpc_core::ExecCtx::Get()->Flush(); } + + ~Fixture() { grpc_transport_destroy(t_); } + + grpc_chttp2_transport* chttp2_transport() { + return reinterpret_cast(t_); + } + grpc_transport* transport() { return t_; } + + void PushInput(grpc_slice slice) { ep_->PushInput(slice); } + + private: + PhonyEndpoint* ep_; + grpc_transport* t_; +}; + +class TestClosure : public grpc_closure { + public: + virtual ~TestClosure() {} +}; + +template +std::unique_ptr MakeTestClosure(F f) { + struct C : public TestClosure { + explicit C(const F& f) : f_(f) { + GRPC_CLOSURE_INIT(this, Execute, this, nullptr); + } + F f_; + static void Execute(void* arg, grpc_error_handle error) { + static_cast(arg)->f_(error); + } + }; + return std::unique_ptr(new C(f)); +} + +template +grpc_closure* MakeOnceClosure(F f) { + struct C : public grpc_closure { + explicit C(const F& f) : f_(f) {} + F f_; + static void Execute(void* arg, grpc_error_handle error) { + static_cast(arg)->f_(error); + delete static_cast(arg); + } + }; + auto* c = new C{f}; + return GRPC_CLOSURE_INIT(c, C::Execute, c, nullptr); +} + +class Stream { + public: + explicit Stream(Fixture* f) : f_(f) { + stream_size_ = grpc_transport_stream_size(f->transport()); + stream_ = gpr_malloc(stream_size_); + arena_ = grpc_core::Arena::Create(4096); + } + + ~Stream() { + gpr_event_wait(&done_, gpr_inf_future(GPR_CLOCK_REALTIME)); + gpr_free(stream_); + arena_->Destroy(); + } + + void Init(benchmark::State& state) { + GRPC_STREAM_REF_INIT(&refcount_, 1, &Stream::FinishDestroy, this, + "test_stream"); + gpr_event_init(&done_); + memset(stream_, 0, stream_size_); + if ((state.iterations() & 0xffff) == 0) { + arena_->Destroy(); + arena_ = grpc_core::Arena::Create(4096); + } + grpc_transport_init_stream(f_->transport(), + static_cast(stream_), &refcount_, + nullptr, arena_); + } + + void DestroyThen(grpc_closure* closure) { + destroy_closure_ = closure; +#ifndef NDEBUG + grpc_stream_unref(&refcount_, "DestroyThen"); +#else + grpc_stream_unref(&refcount_); +#endif + } + + void Op(grpc_transport_stream_op_batch* op) { + grpc_transport_perform_stream_op(f_->transport(), + static_cast(stream_), op); + } + + grpc_chttp2_stream* chttp2_stream() { + return static_cast(stream_); + } + + private: + static void FinishDestroy(void* arg, grpc_error_handle /*error*/) { + auto stream = static_cast(arg); + grpc_transport_destroy_stream(stream->f_->transport(), + static_cast(stream->stream_), + stream->destroy_closure_); + gpr_event_set(&stream->done_, reinterpret_cast(1)); + } + + Fixture* f_; + grpc_stream_refcount refcount_; + grpc_core::Arena* arena_; + size_t stream_size_; + void* stream_; + grpc_closure* destroy_closure_ = nullptr; + gpr_event done_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Benchmarks +// +std::vector> done_events; + +static void BM_StreamCreateDestroy(benchmark::State& state) { + grpc_core::ExecCtx exec_ctx; + TrackCounters track_counters; + Fixture f(grpc::ChannelArguments(), true); + auto* s = new Stream(&f); + grpc_transport_stream_op_batch op; + grpc_transport_stream_op_batch_payload op_payload(nullptr); + op 
= {}; + op.cancel_stream = true; + op.payload = &op_payload; + op_payload.cancel_stream.cancel_error = GRPC_ERROR_CANCELLED; + std::unique_ptr next = + MakeTestClosure([&, s](grpc_error_handle /*error*/) { + if (!state.KeepRunning()) { + delete s; + return; + } + s->Init(state); + s->Op(&op); + s->DestroyThen(next.get()); + }); + grpc_core::Closure::Run(DEBUG_LOCATION, next.get(), GRPC_ERROR_NONE); + f.FlushExecCtx(); + track_counters.Finish(state); +} +BENCHMARK(BM_StreamCreateDestroy); + +class RepresentativeClientInitialMetadata { + public: + static void Prepare(grpc_metadata_batch* b) { + GPR_ASSERT(GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_SCHEME_HTTP))); + GPR_ASSERT(GRPC_LOG_IF_ERROR("addmd", b->Append(GRPC_MDELEM_METHOD_POST))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_PATH, grpc_slice_intern(grpc_slice_from_static_string( + "/foo/bar/bm_chttp2_transport")))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_AUTHORITY, + grpc_slice_intern(grpc_slice_from_static_string( + "foo.test.google.fr:1234")))))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", + b->Append( + GRPC_MDELEM_GRPC_ACCEPT_ENCODING_IDENTITY_COMMA_DEFLATE_COMMA_GZIP))); + b->Set(grpc_core::TeMetadata(), grpc_core::TeMetadata::kTrailers); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(GRPC_MDELEM_CONTENT_TYPE_APPLICATION_SLASH_GRPC))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "addmd", b->Append(grpc_mdelem_from_slices( + GRPC_MDSTR_USER_AGENT, + grpc_slice_intern(grpc_slice_from_static_string( + "grpc-c/3.0.0-dev (linux; chttp2; green)")))))); + } +}; + +template +static void BM_StreamCreateSendInitialMetadataDestroy(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + Fixture f(grpc::ChannelArguments(), true); + auto* s = new Stream(&f); + grpc_transport_stream_op_batch op; + grpc_transport_stream_op_batch_payload op_payload(nullptr); + std::unique_ptr start; + std::unique_ptr done; + + auto reset_op = [&]() { + op = {}; + op.payload = &op_payload; + }; + + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch b(arena.get()); + Metadata::Prepare(&b); + + f.FlushExecCtx(); + gpr_event bm_done; + gpr_event_init(&bm_done); + start = MakeTestClosure([&, s](grpc_error_handle /*error*/) { + if (!state.KeepRunning()) { + delete s; + gpr_event_set(&bm_done, (void*)1); + return; + } + s->Init(state); + reset_op(); + op.on_complete = done.get(); + op.send_initial_metadata = true; + op.payload->send_initial_metadata.send_initial_metadata = &b; + s->Op(&op); + }); + done = MakeTestClosure([&](grpc_error_handle /*error*/) { + reset_op(); + op.cancel_stream = true; + op.payload->cancel_stream.cancel_error = GRPC_ERROR_CANCELLED; + s->Op(&op); + s->DestroyThen(start.get()); + }); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, start.get(), GRPC_ERROR_NONE); + f.FlushExecCtx(); + gpr_event_wait(&bm_done, gpr_inf_future(GPR_CLOCK_REALTIME)); + track_counters.Finish(state); +} +BENCHMARK_TEMPLATE(BM_StreamCreateSendInitialMetadataDestroy, + RepresentativeClientInitialMetadata); + +static void BM_TransportEmptyOp(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + Fixture f(grpc::ChannelArguments(), true); + auto* s = new Stream(&f); + s->Init(state); + grpc_transport_stream_op_batch op; + grpc_transport_stream_op_batch_payload op_payload(nullptr); + auto reset_op = [&]() { + op = {}; + op.payload = &op_payload; + }; + std::unique_ptr c = + 
MakeTestClosure([&](grpc_error_handle /*error*/) { + if (!state.KeepRunning()) return; + reset_op(); + op.on_complete = c.get(); + s->Op(&op); + }); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, c.get(), GRPC_ERROR_NONE); + f.FlushExecCtx(); + reset_op(); + op.cancel_stream = true; + op_payload.cancel_stream.cancel_error = GRPC_ERROR_CANCELLED; + gpr_event* stream_cancel_done = new gpr_event; + gpr_event_init(stream_cancel_done); + std::unique_ptr stream_cancel_closure = + MakeTestClosure([&](grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + gpr_event_set(stream_cancel_done, reinterpret_cast(1)); + }); + op.on_complete = stream_cancel_closure.get(); + s->Op(&op); + f.FlushExecCtx(); + gpr_event_wait(stream_cancel_done, gpr_inf_future(GPR_CLOCK_REALTIME)); + done_events.emplace_back(stream_cancel_done); + s->DestroyThen( + MakeOnceClosure([s](grpc_error_handle /*error*/) { delete s; })); + f.FlushExecCtx(); + track_counters.Finish(state); +} +BENCHMARK(BM_TransportEmptyOp); + +static void BM_TransportStreamSend(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + Fixture f(grpc::ChannelArguments(), true); + auto* s = new Stream(&f); + s->Init(state); + grpc_transport_stream_op_batch op; + grpc_transport_stream_op_batch_payload op_payload(nullptr); + auto reset_op = [&]() { + op = {}; + op.payload = &op_payload; + }; + // Create the send_message payload slice. + // Note: We use grpc_slice_malloc_large() instead of grpc_slice_malloc() + // to force the slice to be refcounted, so that it remains alive when it + // is unreffed after each send_message op. + grpc_slice send_slice = grpc_slice_malloc_large(state.range(0)); + memset(GRPC_SLICE_START_PTR(send_slice), 0, GRPC_SLICE_LENGTH(send_slice)); + grpc_core::ManualConstructor send_stream; + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch b(arena.get()); + RepresentativeClientInitialMetadata::Prepare(&b); + + gpr_event* bm_done = new gpr_event; + gpr_event_init(bm_done); + + std::unique_ptr c = + MakeTestClosure([&](grpc_error_handle /*error*/) { + if (!state.KeepRunning()) { + gpr_event_set(bm_done, reinterpret_cast(1)); + return; + } + grpc_slice_buffer send_buffer; + grpc_slice_buffer_init(&send_buffer); + grpc_slice_buffer_add(&send_buffer, grpc_slice_ref(send_slice)); + send_stream.Init(&send_buffer, 0); + grpc_slice_buffer_destroy(&send_buffer); + // force outgoing window to be yuge + s->chttp2_stream()->flow_control->TestOnlyForceHugeWindow(); + f.chttp2_transport()->flow_control->TestOnlyForceHugeWindow(); + reset_op(); + op.on_complete = c.get(); + op.send_message = true; + op.payload->send_message.send_message.reset(send_stream.get()); + s->Op(&op); + }); + + reset_op(); + op.send_initial_metadata = true; + op.payload->send_initial_metadata.send_initial_metadata = &b; + op.on_complete = c.get(); + s->Op(&op); + + f.FlushExecCtx(); + gpr_event_wait(bm_done, gpr_inf_future(GPR_CLOCK_REALTIME)); + done_events.emplace_back(bm_done); + + reset_op(); + op.cancel_stream = true; + op.payload->cancel_stream.cancel_error = GRPC_ERROR_CANCELLED; + gpr_event* stream_cancel_done = new gpr_event; + gpr_event_init(stream_cancel_done); + std::unique_ptr stream_cancel_closure = + MakeTestClosure([&](grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + gpr_event_set(stream_cancel_done, reinterpret_cast(1)); + }); + op.on_complete = stream_cancel_closure.get(); + s->Op(&op); + f.FlushExecCtx(); + gpr_event_wait(stream_cancel_done, 
gpr_inf_future(GPR_CLOCK_REALTIME)); + done_events.emplace_back(stream_cancel_done); + s->DestroyThen( + MakeOnceClosure([s](grpc_error_handle /*error*/) { delete s; })); + f.FlushExecCtx(); + track_counters.Finish(state); + grpc_slice_unref(send_slice); +} +BENCHMARK(BM_TransportStreamSend)->Range(0, 128 * 1024 * 1024); + +#define SLICE_FROM_BUFFER(s) grpc_slice_from_static_buffer(s, sizeof(s) - 1) + +static grpc_slice CreateIncomingDataSlice(size_t length, size_t frame_size) { + std::queue unframed; + + unframed.push(static_cast(0)); + unframed.push(static_cast(length >> 24)); + unframed.push(static_cast(length >> 16)); + unframed.push(static_cast(length >> 8)); + unframed.push(static_cast(length)); + for (size_t i = 0; i < length; i++) { + unframed.push('a'); + } + + std::vector framed; + while (unframed.size() > frame_size) { + // frame size + framed.push_back(static_cast(frame_size >> 16)); + framed.push_back(static_cast(frame_size >> 8)); + framed.push_back(static_cast(frame_size)); + // data frame + framed.push_back(0); + // no flags + framed.push_back(0); + // stream id + framed.push_back(0); + framed.push_back(0); + framed.push_back(0); + framed.push_back(1); + // frame data + for (size_t i = 0; i < frame_size; i++) { + framed.push_back(unframed.front()); + unframed.pop(); + } + } + + // frame size + framed.push_back(static_cast(unframed.size() >> 16)); + framed.push_back(static_cast(unframed.size() >> 8)); + framed.push_back(static_cast(unframed.size())); + // data frame + framed.push_back(0); + // no flags + framed.push_back(0); + // stream id + framed.push_back(0); + framed.push_back(0); + framed.push_back(0); + framed.push_back(1); + while (!unframed.empty()) { + framed.push_back(unframed.front()); + unframed.pop(); + } + + return grpc_slice_from_copied_buffer(framed.data(), framed.size()); +} + +static void BM_TransportStreamRecv(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + Fixture f(grpc::ChannelArguments(), true); + auto* s = new Stream(&f); + s->Init(state); + grpc_transport_stream_op_batch_payload op_payload(nullptr); + grpc_transport_stream_op_batch op; + grpc_core::OrphanablePtr recv_stream; + grpc_slice incoming_data = CreateIncomingDataSlice(state.range(0), 16384); + + auto reset_op = [&]() { + op = {}; + op.payload = &op_payload; + }; + + auto arena = grpc_core::MakeScopedArena(1024); + grpc_metadata_batch b(arena.get()); + RepresentativeClientInitialMetadata::Prepare(&b); + + std::unique_ptr do_nothing = + MakeTestClosure([](grpc_error_handle /*error*/) {}); + + uint32_t received; + + std::unique_ptr drain_start; + std::unique_ptr drain; + std::unique_ptr drain_continue; + grpc_slice recv_slice; + + std::unique_ptr c = + MakeTestClosure([&](grpc_error_handle /*error*/) { + if (!state.KeepRunning()) return; + // force outgoing window to be yuge + s->chttp2_stream()->flow_control->TestOnlyForceHugeWindow(); + f.chttp2_transport()->flow_control->TestOnlyForceHugeWindow(); + received = 0; + reset_op(); + op.on_complete = do_nothing.get(); + op.recv_message = true; + op.payload->recv_message.recv_message = &recv_stream; + op.payload->recv_message.call_failed_before_recv_message = nullptr; + op.payload->recv_message.recv_message_ready = drain_start.get(); + s->Op(&op); + f.PushInput(grpc_slice_ref(incoming_data)); + }); + + drain_start = MakeTestClosure([&](grpc_error_handle /*error*/) { + if (recv_stream == nullptr) { + GPR_ASSERT(!state.KeepRunning()); + return; + } + grpc_core::Closure::Run(DEBUG_LOCATION, drain.get(), 
GRPC_ERROR_NONE); + }); + + drain = MakeTestClosure([&](grpc_error_handle /*error*/) { + do { + if (received == recv_stream->length()) { + recv_stream.reset(); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, c.get(), GRPC_ERROR_NONE); + return; + } + } while (recv_stream->Next(recv_stream->length() - received, + drain_continue.get()) && + GRPC_ERROR_NONE == recv_stream->Pull(&recv_slice) && + (received += GRPC_SLICE_LENGTH(recv_slice), + grpc_slice_unref_internal(recv_slice), true)); + }); + + drain_continue = MakeTestClosure([&](grpc_error_handle /*error*/) { + GPR_ASSERT(GRPC_LOG_IF_ERROR("Pull", recv_stream->Pull(&recv_slice))); + received += GRPC_SLICE_LENGTH(recv_slice); + grpc_slice_unref_internal(recv_slice); + grpc_core::Closure::Run(DEBUG_LOCATION, drain.get(), GRPC_ERROR_NONE); + }); + + reset_op(); + auto b_recv = absl::make_unique(arena.get()); + op.send_initial_metadata = true; + op.payload->send_initial_metadata.send_initial_metadata = &b; + op.recv_initial_metadata = true; + op.payload->recv_initial_metadata.recv_initial_metadata = b_recv.get(); + op.payload->recv_initial_metadata.recv_initial_metadata_ready = + do_nothing.get(); + op.on_complete = c.get(); + s->Op(&op); + f.PushInput(SLICE_FROM_BUFFER( + "\x00\x00\x00\x04\x00\x00\x00\x00\x00" + // Generated using: + // tools/codegen/core/gen_header_frame.py < + // test/cpp/microbenchmarks/representative_server_initial_metadata.headers + "\x00\x00X\x01\x04\x00\x00\x00\x01" + "\x10\x07:status\x03" + "200" + "\x10\x0c" + "content-type\x10" + "application/grpc" + "\x10\x14grpc-accept-encoding\x15identity,deflate,gzip")); + + f.FlushExecCtx(); + reset_op(); + op.cancel_stream = true; + op.payload->cancel_stream.cancel_error = GRPC_ERROR_CANCELLED; + gpr_event* stream_cancel_done = new gpr_event; + gpr_event_init(stream_cancel_done); + std::unique_ptr stream_cancel_closure = + MakeTestClosure([&](grpc_error_handle error) { + GPR_ASSERT(error == GRPC_ERROR_NONE); + gpr_event_set(stream_cancel_done, reinterpret_cast(1)); + }); + op.on_complete = stream_cancel_closure.get(); + s->Op(&op); + f.FlushExecCtx(); + gpr_event_wait(stream_cancel_done, gpr_inf_future(GPR_CLOCK_REALTIME)); + done_events.emplace_back(stream_cancel_done); + s->DestroyThen(MakeOnceClosure([s, &b_recv](grpc_error_handle /*error*/) { + b_recv.reset(); + delete s; + })); + f.FlushExecCtx(); + track_counters.Finish(state); + grpc_slice_unref(incoming_data); +} +BENCHMARK(BM_TransportStreamRecv)->Range(0, 128 * 1024 * 1024); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_closure.cc b/test/cpp/microbenchmarks/bm_closure.cc new file mode 100644 index 00000000..4c643fc5 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_closure.cc @@ -0,0 +1,404 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test various closure related operations */ + +#include + +#include + +#include + +#include "src/core/lib/gpr/spinlock.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +static void BM_NoOpExecCtx(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + grpc_core::ExecCtx exec_ctx; + } + track_counters.Finish(state); +} +BENCHMARK(BM_NoOpExecCtx); + +static void BM_WellFlushed(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_core::ExecCtx::Get()->Flush(); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_WellFlushed); + +static void DoNothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +static void BM_ClosureInitAgainstExecCtx(benchmark::State& state) { + TrackCounters track_counters; + grpc_closure c; + for (auto _ : state) { + benchmark::DoNotOptimize( + GRPC_CLOSURE_INIT(&c, DoNothing, nullptr, grpc_schedule_on_exec_ctx)); + } + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureInitAgainstExecCtx); + +static void BM_ClosureInitAgainstCombiner(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::Combiner* combiner = grpc_combiner_create(); + grpc_closure c; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + benchmark::DoNotOptimize( + GRPC_CLOSURE_INIT(&c, DoNothing, nullptr, nullptr)); + } + GRPC_COMBINER_UNREF(combiner, "finished"); + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureInitAgainstCombiner); + +static void BM_ClosureRun(benchmark::State& state) { + TrackCounters track_counters; + grpc_closure c; + GRPC_CLOSURE_INIT(&c, DoNothing, nullptr, grpc_schedule_on_exec_ctx); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_core::Closure::Run(DEBUG_LOCATION, &c, GRPC_ERROR_NONE); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureRun); + +static void BM_ClosureCreateAndRun(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_core::Closure::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_CREATE(DoNothing, nullptr, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureCreateAndRun); + +static void BM_ClosureInitAndRun(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + grpc_closure c; + for (auto _ : state) { + grpc_core::Closure::Run( + DEBUG_LOCATION, + GRPC_CLOSURE_INIT(&c, DoNothing, nullptr, grpc_schedule_on_exec_ctx), + GRPC_ERROR_NONE); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureInitAndRun); + +static void BM_ClosureSchedOnExecCtx(benchmark::State& state) { + TrackCounters track_counters; + grpc_closure c; + GRPC_CLOSURE_INIT(&c, DoNothing, nullptr, grpc_schedule_on_exec_ctx); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &c, GRPC_ERROR_NONE); + 
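+    // Flush runs the closure that was just scheduled, so each iteration
+    // measures one schedule plus one execution instead of unbounded queuing.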
grpc_core::ExecCtx::Get()->Flush(); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSchedOnExecCtx); + +static void BM_ClosureSched2OnExecCtx(benchmark::State& state) { + TrackCounters track_counters; + grpc_closure c1; + grpc_closure c2; + GRPC_CLOSURE_INIT(&c1, DoNothing, nullptr, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&c2, DoNothing, nullptr, grpc_schedule_on_exec_ctx); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &c1, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &c2, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSched2OnExecCtx); + +static void BM_ClosureSched3OnExecCtx(benchmark::State& state) { + TrackCounters track_counters; + grpc_closure c1; + grpc_closure c2; + grpc_closure c3; + GRPC_CLOSURE_INIT(&c1, DoNothing, nullptr, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&c2, DoNothing, nullptr, grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&c3, DoNothing, nullptr, grpc_schedule_on_exec_ctx); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &c1, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &c2, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &c3, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSched3OnExecCtx); + +static void BM_AcquireMutex(benchmark::State& state) { + TrackCounters track_counters; + // for comparison with the combiner stuff below + gpr_mu mu; + gpr_mu_init(&mu); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + gpr_mu_lock(&mu); + DoNothing(nullptr, GRPC_ERROR_NONE); + gpr_mu_unlock(&mu); + } + gpr_mu_destroy(&mu); + + track_counters.Finish(state); +} +BENCHMARK(BM_AcquireMutex); + +static void BM_TryAcquireMutex(benchmark::State& state) { + TrackCounters track_counters; + // for comparison with the combiner stuff below + gpr_mu mu; + gpr_mu_init(&mu); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + if (gpr_mu_trylock(&mu)) { + DoNothing(nullptr, GRPC_ERROR_NONE); + gpr_mu_unlock(&mu); + } else { + abort(); + } + } + gpr_mu_destroy(&mu); + + track_counters.Finish(state); +} +BENCHMARK(BM_TryAcquireMutex); + +static void BM_AcquireSpinlock(benchmark::State& state) { + TrackCounters track_counters; + // for comparison with the combiner stuff below + gpr_spinlock mu = GPR_SPINLOCK_INITIALIZER; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + gpr_spinlock_lock(&mu); + DoNothing(nullptr, GRPC_ERROR_NONE); + gpr_spinlock_unlock(&mu); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_AcquireSpinlock); + +static void BM_TryAcquireSpinlock(benchmark::State& state) { + TrackCounters track_counters; + // for comparison with the combiner stuff below + gpr_spinlock mu = GPR_SPINLOCK_INITIALIZER; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + if (gpr_spinlock_trylock(&mu)) { + DoNothing(nullptr, GRPC_ERROR_NONE); + gpr_spinlock_unlock(&mu); + } else { + abort(); + } + } + + track_counters.Finish(state); +} +BENCHMARK(BM_TryAcquireSpinlock); + +static void BM_ClosureSchedOnCombiner(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::Combiner* combiner = grpc_combiner_create(); + grpc_closure c; + GRPC_CLOSURE_INIT(&c, DoNothing, nullptr, nullptr); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + combiner->Run(&c, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + 
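+  // Drop the benchmark's reference; the combiner tears itself down once its
+  // queue is drained.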
GRPC_COMBINER_UNREF(combiner, "finished"); + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSchedOnCombiner); + +static void BM_ClosureSched2OnCombiner(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::Combiner* combiner = grpc_combiner_create(); + grpc_closure c1; + grpc_closure c2; + GRPC_CLOSURE_INIT(&c1, DoNothing, nullptr, nullptr); + GRPC_CLOSURE_INIT(&c2, DoNothing, nullptr, nullptr); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + combiner->Run(&c1, GRPC_ERROR_NONE); + combiner->Run(&c2, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + GRPC_COMBINER_UNREF(combiner, "finished"); + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSched2OnCombiner); + +static void BM_ClosureSched3OnCombiner(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::Combiner* combiner = grpc_combiner_create(); + grpc_closure c1; + grpc_closure c2; + grpc_closure c3; + GRPC_CLOSURE_INIT(&c1, DoNothing, nullptr, nullptr); + GRPC_CLOSURE_INIT(&c2, DoNothing, nullptr, nullptr); + GRPC_CLOSURE_INIT(&c3, DoNothing, nullptr, nullptr); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + combiner->Run(&c1, GRPC_ERROR_NONE); + combiner->Run(&c2, GRPC_ERROR_NONE); + combiner->Run(&c3, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + GRPC_COMBINER_UNREF(combiner, "finished"); + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSched3OnCombiner); + +static void BM_ClosureSched2OnTwoCombiners(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::Combiner* combiner1 = grpc_combiner_create(); + grpc_core::Combiner* combiner2 = grpc_combiner_create(); + grpc_closure c1; + grpc_closure c2; + GRPC_CLOSURE_INIT(&c1, DoNothing, nullptr, nullptr); + GRPC_CLOSURE_INIT(&c2, DoNothing, nullptr, nullptr); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + combiner1->Run(&c1, GRPC_ERROR_NONE); + combiner2->Run(&c2, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + GRPC_COMBINER_UNREF(combiner1, "finished"); + GRPC_COMBINER_UNREF(combiner2, "finished"); + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSched2OnTwoCombiners); + +static void BM_ClosureSched4OnTwoCombiners(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::Combiner* combiner1 = grpc_combiner_create(); + grpc_core::Combiner* combiner2 = grpc_combiner_create(); + grpc_closure c1; + grpc_closure c2; + grpc_closure c3; + grpc_closure c4; + GRPC_CLOSURE_INIT(&c1, DoNothing, nullptr, nullptr); + GRPC_CLOSURE_INIT(&c2, DoNothing, nullptr, nullptr); + GRPC_CLOSURE_INIT(&c3, DoNothing, nullptr, nullptr); + GRPC_CLOSURE_INIT(&c4, DoNothing, nullptr, nullptr); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + combiner1->Run(&c1, GRPC_ERROR_NONE); + combiner2->Run(&c2, GRPC_ERROR_NONE); + combiner1->Run(&c3, GRPC_ERROR_NONE); + combiner2->Run(&c4, GRPC_ERROR_NONE); + grpc_core::ExecCtx::Get()->Flush(); + } + GRPC_COMBINER_UNREF(combiner1, "finished"); + GRPC_COMBINER_UNREF(combiner2, "finished"); + + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureSched4OnTwoCombiners); + +// Helper that continuously reschedules the same closure against something until +// the benchmark is complete +class Rescheduler { + public: + explicit Rescheduler(benchmark::State& state) : state_(state) { + GRPC_CLOSURE_INIT(&closure_, Step, this, nullptr); + } + + void ScheduleFirst() { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &closure_, GRPC_ERROR_NONE); + } + + void ScheduleFirstAgainstDifferentScheduler() { + 
grpc_core::ExecCtx::Run(DEBUG_LOCATION, + GRPC_CLOSURE_CREATE(Step, this, nullptr), + GRPC_ERROR_NONE); + } + + private: + benchmark::State& state_; + grpc_closure closure_; + + static void Step(void* arg, grpc_error_handle /*error*/) { + Rescheduler* self = static_cast(arg); + if (self->state_.KeepRunning()) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, &self->closure_, GRPC_ERROR_NONE); + } + } +}; + +static void BM_ClosureReschedOnExecCtx(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + Rescheduler r(state); + r.ScheduleFirst(); + grpc_core::ExecCtx::Get()->Flush(); + track_counters.Finish(state); +} +BENCHMARK(BM_ClosureReschedOnExecCtx); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_cq.cc b/test/cpp/microbenchmarks/bm_cq.cc new file mode 100644 index 00000000..4b9aeefd --- /dev/null +++ b/test/cpp/microbenchmarks/bm_cq.cc @@ -0,0 +1,325 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +/* This benchmark exists to ensure that the benchmark integration is + * working */ + +#include + +#include +#include +#include +#include + +#include "src/core/lib/surface/completion_queue.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +static void BM_CreateDestroyCpp(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + CompletionQueue cq; + } + track_counters.Finish(state); +} +BENCHMARK(BM_CreateDestroyCpp); + +/* Create cq using a different constructor */ +static void BM_CreateDestroyCpp2(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + grpc_completion_queue* core_cq = + grpc_completion_queue_create_for_next(nullptr); + CompletionQueue cq(core_cq); + } + track_counters.Finish(state); +} +BENCHMARK(BM_CreateDestroyCpp2); + +static void BM_CreateDestroyCore(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + // TODO(sreek): Templatize this benchmark and pass completion type and + // polling type as parameters + grpc_completion_queue_destroy( + grpc_completion_queue_create_for_next(nullptr)); + } + track_counters.Finish(state); +} +BENCHMARK(BM_CreateDestroyCore); + +static void DoneWithCompletionOnStack(void* /*arg*/, + grpc_cq_completion* /*completion*/) {} + +static void DoneWithCompletionOnHeap(void* /*arg*/, + grpc_cq_completion* completion) { + delete completion; +} + +class PhonyTag final : public internal::CompletionQueueTag { + public: + bool FinalizeResult(void** /*tag*/, bool* /*status*/) override { + return true; + } +}; + +static void BM_Pass1Cpp(benchmark::State& state) { + TrackCounters track_counters; + CompletionQueue cq; + grpc_completion_queue* c_cq = cq.cq(); + for (auto _ : state) { + grpc_cq_completion completion; + PhonyTag phony_tag; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(grpc_cq_begin_op(c_cq, &phony_tag)); + grpc_cq_end_op(c_cq, &phony_tag, GRPC_ERROR_NONE, DoneWithCompletionOnStack, + nullptr, &completion); + + void* tag; + bool ok; + cq.Next(&tag, &ok); + } + track_counters.Finish(state); +} +BENCHMARK(BM_Pass1Cpp); + +static void BM_Pass1Core(benchmark::State& state) { + TrackCounters track_counters; + // TODO(sreek): Templatize this benchmark and pass polling_type as a param + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_MONOTONIC); + for (auto _ : state) { + grpc_cq_completion completion; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(grpc_cq_begin_op(cq, nullptr)); + grpc_cq_end_op(cq, nullptr, GRPC_ERROR_NONE, DoneWithCompletionOnStack, + nullptr, &completion); + + grpc_completion_queue_next(cq, deadline, nullptr); + } + grpc_completion_queue_destroy(cq); + track_counters.Finish(state); +} +BENCHMARK(BM_Pass1Core); + +static void BM_Pluck1Core(benchmark::State& state) { + TrackCounters track_counters; + // TODO(sreek): Templatize this benchmark and pass polling_type as a param + grpc_completion_queue* cq = grpc_completion_queue_create_for_pluck(nullptr); + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_MONOTONIC); + for (auto _ : state) { + grpc_cq_completion completion; + grpc_core::ExecCtx exec_ctx; + GPR_ASSERT(grpc_cq_begin_op(cq, nullptr)); + grpc_cq_end_op(cq, nullptr, GRPC_ERROR_NONE, DoneWithCompletionOnStack, + nullptr, &completion); + + grpc_completion_queue_pluck(cq, nullptr, deadline, nullptr); + } + 
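+  // Every iteration enqueued exactly one completion and plucked it back, so
+  // the queue is empty by the time it is destroyed.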
grpc_completion_queue_destroy(cq); + track_counters.Finish(state); +} +BENCHMARK(BM_Pluck1Core); + +static void BM_EmptyCore(benchmark::State& state) { + TrackCounters track_counters; + // TODO(sreek): Templatize this benchmark and pass polling_type as a param + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + gpr_timespec deadline = gpr_inf_past(GPR_CLOCK_MONOTONIC); + for (auto _ : state) { + grpc_completion_queue_next(cq, deadline, nullptr); + } + grpc_completion_queue_destroy(cq); + track_counters.Finish(state); +} +BENCHMARK(BM_EmptyCore); + +// Helper for tests to shutdown correctly and tersely +static void shutdown_and_destroy(grpc_completion_queue* cc) { + grpc_completion_queue_shutdown(cc); + grpc_completion_queue_destroy(cc); +} + +static gpr_mu shutdown_mu, mu; +static gpr_cv shutdown_cv, cv; + +// Tag completion queue iterate times +class TagCallback : public grpc_completion_queue_functor { + public: + explicit TagCallback(int* iter) : iter_(iter) { + functor_run = &TagCallback::Run; + inlineable = false; + } + ~TagCallback() {} + static void Run(grpc_completion_queue_functor* cb, int ok) { + gpr_mu_lock(&mu); + GPR_ASSERT(static_cast(ok)); + *static_cast(cb)->iter_ += 1; + gpr_cv_signal(&cv); + gpr_mu_unlock(&mu); + }; + + private: + int* iter_; +}; + +// Check if completion queue is shut down +class ShutdownCallback : public grpc_completion_queue_functor { + public: + explicit ShutdownCallback(bool* done) : done_(done) { + functor_run = &ShutdownCallback::Run; + inlineable = false; + } + ~ShutdownCallback() {} + static void Run(grpc_completion_queue_functor* cb, int ok) { + gpr_mu_lock(&shutdown_mu); + *static_cast(cb)->done_ = static_cast(ok); + gpr_cv_signal(&shutdown_cv); + gpr_mu_unlock(&shutdown_mu); + } + + private: + bool* done_; +}; + +static void BM_Callback_CQ_Pass1Core(benchmark::State& state) { + TrackCounters track_counters; + int iteration = 0, current_iterations = 0; + TagCallback tag_cb(&iteration); + gpr_mu_init(&mu); + gpr_cv_init(&cv); + gpr_mu_init(&shutdown_mu); + gpr_cv_init(&shutdown_cv); + bool got_shutdown = false; + ShutdownCallback shutdown_cb(&got_shutdown); + // This test with stack-allocated completions only works for non-polling or + // EM-polling callback core CQs because otherwise the callback could execute + // on another thread after the stack objects here go out of scope. An + // alternative would be to synchronize between the benchmark loop and the + // callback, but then it would be measuring the overhead of synchronization + // rather than the overhead of the completion queue. + // For generality, test here with non-polling. + grpc_completion_queue_attributes attr; + attr.version = 2; + attr.cq_completion_type = GRPC_CQ_CALLBACK; + attr.cq_polling_type = GRPC_CQ_NON_POLLING; + attr.cq_shutdown_cb = &shutdown_cb; + grpc_completion_queue* cc = grpc_completion_queue_create( + grpc_completion_queue_factory_lookup(&attr), &attr, nullptr); + for (auto _ : state) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_cq_completion completion; + GPR_ASSERT(grpc_cq_begin_op(cc, &tag_cb)); + grpc_cq_end_op(cc, &tag_cb, GRPC_ERROR_NONE, DoneWithCompletionOnStack, + nullptr, &completion); + } + shutdown_and_destroy(cc); + + gpr_mu_lock(&mu); + current_iterations = static_cast(state.iterations()); + while (current_iterations != iteration) { + // Wait for all the callbacks to complete. 
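+    // TagCallback::Run increments `iteration` under `mu` and signals `cv`
+    // once per completed op.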
+ gpr_cv_wait(&cv, &mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&mu); + + gpr_mu_lock(&shutdown_mu); + while (!got_shutdown) { + // Wait for the shutdown callback to complete. + gpr_cv_wait(&shutdown_cv, &shutdown_mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&shutdown_mu); + + GPR_ASSERT(got_shutdown); + GPR_ASSERT(iteration == static_cast(state.iterations())); + track_counters.Finish(state); + gpr_cv_destroy(&cv); + gpr_mu_destroy(&mu); + gpr_cv_destroy(&shutdown_cv); + gpr_mu_destroy(&shutdown_mu); +} +static void BM_Callback_CQ_Pass1CoreHeapCompletion(benchmark::State& state) { + TrackCounters track_counters; + int iteration = 0, current_iterations = 0; + TagCallback tag_cb(&iteration); + gpr_mu_init(&mu); + gpr_cv_init(&cv); + gpr_mu_init(&shutdown_mu); + gpr_cv_init(&shutdown_cv); + bool got_shutdown = false; + ShutdownCallback shutdown_cb(&got_shutdown); + grpc_completion_queue* cc = + grpc_completion_queue_create_for_callback(&shutdown_cb, nullptr); + for (auto _ : state) { + grpc_core::ApplicationCallbackExecCtx callback_exec_ctx; + grpc_core::ExecCtx exec_ctx; + grpc_cq_completion* completion = new grpc_cq_completion; + GPR_ASSERT(grpc_cq_begin_op(cc, &tag_cb)); + grpc_cq_end_op(cc, &tag_cb, GRPC_ERROR_NONE, DoneWithCompletionOnHeap, + nullptr, completion); + } + shutdown_and_destroy(cc); + + gpr_mu_lock(&mu); + current_iterations = static_cast(state.iterations()); + while (current_iterations != iteration) { + // Wait for all the callbacks to complete. + gpr_cv_wait(&cv, &mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&mu); + + gpr_mu_lock(&shutdown_mu); + while (!got_shutdown) { + // Wait for the shutdown callback to complete. + gpr_cv_wait(&shutdown_cv, &shutdown_mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&shutdown_mu); + + GPR_ASSERT(got_shutdown); + GPR_ASSERT(iteration == static_cast(state.iterations())); + track_counters.Finish(state); + gpr_cv_destroy(&cv); + gpr_mu_destroy(&mu); + gpr_cv_destroy(&shutdown_cv); + gpr_mu_destroy(&shutdown_mu); +} +BENCHMARK(BM_Callback_CQ_Pass1Core); +BENCHMARK(BM_Callback_CQ_Pass1CoreHeapCompletion); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_cq_multiple_threads.cc b/test/cpp/microbenchmarks/bm_cq_multiple_threads.cc new file mode 100644 index 00000000..2c4241c0 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_cq_multiple_threads.cc @@ -0,0 +1,228 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/surface/completion_queue.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +struct grpc_pollset { + gpr_mu mu; +}; + +static gpr_mu g_mu; +static gpr_cv g_cv; +static int g_threads_active; +static bool g_active; + +namespace grpc { +namespace testing { +static grpc_completion_queue* g_cq; +static grpc_event_engine_vtable g_vtable; + +static void pollset_shutdown(grpc_pollset* /*ps*/, grpc_closure* closure) { + grpc_core::ExecCtx::Run(DEBUG_LOCATION, closure, GRPC_ERROR_NONE); +} + +static void pollset_init(grpc_pollset* ps, gpr_mu** mu) { + gpr_mu_init(&ps->mu); + *mu = &ps->mu; +} + +static void pollset_destroy(grpc_pollset* ps) { gpr_mu_destroy(&ps->mu); } + +static grpc_error_handle pollset_kick(grpc_pollset* /*p*/, + grpc_pollset_worker* /*worker*/) { + return GRPC_ERROR_NONE; +} + +/* Callback when the tag is dequeued from the completion queue. Does nothing */ +static void cq_done_cb(void* /*done_arg*/, grpc_cq_completion* cq_completion) { + gpr_free(cq_completion); +} + +/* Queues a completion tag if deadline is > 0. + * Does nothing if deadline is 0 (i.e gpr_time_0(GPR_CLOCK_MONOTONIC)) */ +static grpc_error_handle pollset_work(grpc_pollset* ps, + grpc_pollset_worker** /*worker*/, + grpc_millis deadline) { + if (deadline == 0) { + gpr_log(GPR_DEBUG, "no-op"); + return GRPC_ERROR_NONE; + } + + gpr_mu_unlock(&ps->mu); + + void* tag = reinterpret_cast(10); // Some random number + GPR_ASSERT(grpc_cq_begin_op(g_cq, tag)); + grpc_cq_end_op( + g_cq, tag, GRPC_ERROR_NONE, cq_done_cb, nullptr, + static_cast(gpr_malloc(sizeof(grpc_cq_completion)))); + grpc_core::ExecCtx::Get()->Flush(); + gpr_mu_lock(&ps->mu); + return GRPC_ERROR_NONE; +} + +static const grpc_event_engine_vtable* init_engine_vtable(bool) { + memset(&g_vtable, 0, sizeof(g_vtable)); + + g_vtable.pollset_size = sizeof(grpc_pollset); + g_vtable.pollset_init = pollset_init; + g_vtable.pollset_shutdown = pollset_shutdown; + g_vtable.pollset_destroy = pollset_destroy; + g_vtable.pollset_work = pollset_work; + g_vtable.pollset_kick = pollset_kick; + g_vtable.is_any_background_poller_thread = [] { return false; }; + g_vtable.add_closure_to_background_poller = [](grpc_closure* /*closure*/, + grpc_error_handle /*error*/) { + return false; + }; + g_vtable.shutdown_background_closure = [] {}; + g_vtable.shutdown_engine = [] {}; + + return &g_vtable; +} + +static void setup() { + // This test should only ever be run with a non or any polling engine + // Override the polling engine for the non-polling engine + // and add a custom polling engine + grpc_register_event_engine_factory("none", init_engine_vtable, false); + grpc_register_event_engine_factory("bm_cq_multiple_threads", + init_engine_vtable, true); + + grpc_init(); + GPR_ASSERT(strcmp(grpc_get_poll_strategy_name(), "none") == 0 || + strcmp(grpc_get_poll_strategy_name(), "bm_cq_multiple_threads") == + 0); + + g_cq = grpc_completion_queue_create_for_next(nullptr); +} + +static void teardown() { + grpc_completion_queue_shutdown(g_cq); + + /* Drain any events */ + gpr_timespec deadline = gpr_time_0(GPR_CLOCK_MONOTONIC); + while (grpc_completion_queue_next(g_cq, deadline, nullptr).type != + GRPC_QUEUE_SHUTDOWN) { + /* Do 
nothing */ + } + + grpc_completion_queue_destroy(g_cq); + grpc_shutdown(); +} + +/* A few notes about Multi-threaded benchmarks: + + Setup: + The benchmark framework ensures that none of the threads proceed beyond the + state.KeepRunning() call unless all the threads have called state.keepRunning + at least once. So it is safe to do the initialization in one of the threads + before state.KeepRunning() is called. + + Teardown: + The benchmark framework also ensures that no thread is running the benchmark + code (i.e the code between two successive calls of state.KeepRunning()) if + state.KeepRunning() returns false. So it is safe to do the teardown in one + of the threads after state.keepRunning() returns false. + + However, our use requires synchronization because we do additional work at + each thread that requires specific ordering (TrackCounters must be constructed + after grpc_init because it needs the number of cores, initialized by grpc, + and its Finish call must take place before grpc_shutdown so that it can use + grpc_stats). +*/ +static void BM_Cq_Throughput(benchmark::State& state) { + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_MONOTONIC); + auto thd_idx = hack::get_thread_idx(state); + + gpr_mu_lock(&g_mu); + g_threads_active++; + if (thd_idx == 0) { + setup(); + g_active = true; + gpr_cv_broadcast(&g_cv); + } else { + while (!g_active) { + gpr_cv_wait(&g_cv, &g_mu, deadline); + } + } + gpr_mu_unlock(&g_mu); + + // Use a TrackCounters object to monitor the gRPC performance statistics + // (optionally including low-level counters) before and after the test + TrackCounters track_counters; + + for (auto _ : state) { + GPR_ASSERT(grpc_completion_queue_next(g_cq, deadline, nullptr).type == + GRPC_OP_COMPLETE); + } + + state.SetItemsProcessed(state.iterations()); + track_counters.Finish(state); + + gpr_mu_lock(&g_mu); + g_threads_active--; + if (g_threads_active == 0) { + gpr_cv_broadcast(&g_cv); + } else { + while (g_threads_active > 0) { + gpr_cv_wait(&g_cv, &g_mu, deadline); + } + } + gpr_mu_unlock(&g_mu); + + if (thd_idx == 0) { + teardown(); + g_active = false; + } +} + +BENCHMARK(BM_Cq_Throughput)->ThreadRange(1, 16)->UseRealTime(); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + gpr_mu_init(&g_mu); + gpr_cv_init(&g_cv); + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_error.cc b/test/cpp/microbenchmarks/bm_error.cc new file mode 100644 index 00000000..544d8879 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_error.cc @@ -0,0 +1,329 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test various operations on grpc_error */ + +#include + +#include + +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/transport/error_utils.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +class ErrorHandleHolder { + public: + explicit ErrorHandleHolder(grpc_error_handle error) : error_(error) {} + ~ErrorHandleHolder() { GRPC_ERROR_UNREF(error_); } + const grpc_error_handle& get() const { return error_; } + + private: + grpc_error_handle error_; +}; + +static void BM_ErrorCreateFromStatic(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + GRPC_ERROR_UNREF(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error")); + } + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorCreateFromStatic); + +static void BM_ErrorCreateFromCopied(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + GRPC_ERROR_UNREF(GRPC_ERROR_CREATE_FROM_COPIED_STRING("Error not inline")); + } + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorCreateFromCopied); + +static void BM_ErrorCreateAndSetStatus(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + GRPC_ERROR_UNREF( + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_ABORTED)); + } + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorCreateAndSetStatus); + +static void BM_ErrorCreateAndSetIntAndStr(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + GRPC_ERROR_UNREF(grpc_error_set_str( + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("GOAWAY received"), + GRPC_ERROR_INT_HTTP2_ERROR, (intptr_t)0), + GRPC_ERROR_STR_RAW_BYTES, "raw bytes")); + } + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorCreateAndSetIntAndStr); + +static void BM_ErrorCreateAndSetIntLoop(benchmark::State& state) { + TrackCounters track_counters; + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"); + int n = 0; + for (auto _ : state) { + error = grpc_error_set_int(error, GRPC_ERROR_INT_GRPC_STATUS, n++); + } + GRPC_ERROR_UNREF(error); + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorCreateAndSetIntLoop); + +static void BM_ErrorCreateAndSetStrLoop(benchmark::State& state) { + TrackCounters track_counters; + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"); + const char* str = "hello"; + for (auto _ : state) { + error = grpc_error_set_str(error, GRPC_ERROR_STR_GRPC_MESSAGE, str); + } + GRPC_ERROR_UNREF(error); + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorCreateAndSetStrLoop); + +static void BM_ErrorRefUnref(benchmark::State& state) { + TrackCounters track_counters; + grpc_error_handle error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"); + for (auto _ : state) { + GRPC_ERROR_UNREF(GRPC_ERROR_REF(error)); + } + GRPC_ERROR_UNREF(error); + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorRefUnref); + +static void BM_ErrorUnrefNone(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + GRPC_ERROR_UNREF(GRPC_ERROR_NONE); + } +} +BENCHMARK(BM_ErrorUnrefNone); + +static void BM_ErrorGetIntFromNoError(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + intptr_t value; + grpc_error_get_int(GRPC_ERROR_NONE, GRPC_ERROR_INT_GRPC_STATUS, &value); + } + 
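+  // Each iteration queries a status int on GRPC_ERROR_NONE, so this measures
+  // the lookup on the no-error value rather than on a real error object.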
track_counters.Finish(state); +} +BENCHMARK(BM_ErrorGetIntFromNoError); + +static void BM_ErrorGetMissingInt(benchmark::State& state) { + TrackCounters track_counters; + ErrorHandleHolder error(grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"), GRPC_ERROR_INT_INDEX, 1)); + for (auto _ : state) { + intptr_t value; + grpc_error_get_int(error.get(), GRPC_ERROR_INT_OFFSET, &value); + } + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorGetMissingInt); + +static void BM_ErrorGetPresentInt(benchmark::State& state) { + TrackCounters track_counters; + ErrorHandleHolder error(grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"), GRPC_ERROR_INT_OFFSET, 1)); + for (auto _ : state) { + intptr_t value; + grpc_error_get_int(error.get(), GRPC_ERROR_INT_OFFSET, &value); + } + track_counters.Finish(state); +} +BENCHMARK(BM_ErrorGetPresentInt); + +// Fixtures for tests: generate different kinds of errors +class ErrorNone { + public: + grpc_millis deadline() const { return deadline_; } + grpc_error_handle error() const { return GRPC_ERROR_NONE; } + + private: + const grpc_millis deadline_ = GRPC_MILLIS_INF_FUTURE; +}; + +class ErrorCancelled { + public: + grpc_millis deadline() const { return deadline_; } + grpc_error_handle error() const { return GRPC_ERROR_CANCELLED; } + + private: + const grpc_millis deadline_ = GRPC_MILLIS_INF_FUTURE; +}; + +class SimpleError { + public: + grpc_millis deadline() const { return deadline_; } + grpc_error_handle error() const { return error_.get(); } + + private: + const grpc_millis deadline_ = GRPC_MILLIS_INF_FUTURE; + ErrorHandleHolder error_{GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error")}; +}; + +class ErrorWithGrpcStatus { + public: + grpc_millis deadline() const { return deadline_; } + grpc_error_handle error() const { return error_.get(); } + + private: + const grpc_millis deadline_ = GRPC_MILLIS_INF_FUTURE; + ErrorHandleHolder error_{grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"), GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNIMPLEMENTED)}; +}; + +class ErrorWithHttpError { + public: + grpc_millis deadline() const { return deadline_; } + grpc_error_handle error() const { return error_.get(); } + + private: + const grpc_millis deadline_ = GRPC_MILLIS_INF_FUTURE; + ErrorHandleHolder error_{grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"), GRPC_ERROR_INT_HTTP2_ERROR, + GRPC_HTTP2_COMPRESSION_ERROR)}; +}; + +class ErrorWithNestedGrpcStatus { + public: + grpc_millis deadline() const { return deadline_; } + grpc_error_handle error() const { return error_.get(); } + + private: + const grpc_millis deadline_ = GRPC_MILLIS_INF_FUTURE; + ErrorHandleHolder nested_error_{grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Error"), GRPC_ERROR_INT_GRPC_STATUS, + GRPC_STATUS_UNIMPLEMENTED)}; + grpc_error_handle nested_errors_[1] = {nested_error_.get()}; + ErrorHandleHolder error_{GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Error", nested_errors_, 1)}; +}; + +template +static void BM_ErrorStringOnNewError(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + Fixture fixture; + grpc_error_std_string(fixture.error()); + } + track_counters.Finish(state); +} + +template +static void BM_ErrorStringRepeatedly(benchmark::State& state) { + TrackCounters track_counters; + Fixture fixture; + for (auto _ : state) { + grpc_error_std_string(fixture.error()); + } + track_counters.Finish(state); +} + +template +static void BM_ErrorGetStatus(benchmark::State& state) { + 
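+  // Extracts both the status code and the message from the Fixture's error on
+  // every iteration; BENCHMARK_SUITE below instantiates this template for each
+  // of the error fixture classes defined above.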
TrackCounters track_counters; + Fixture fixture; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_status_code status; + std::string message; + grpc_error_get_status(fixture.error(), fixture.deadline(), &status, + &message, nullptr, nullptr); + } + + track_counters.Finish(state); +} + +template +static void BM_ErrorGetStatusCode(benchmark::State& state) { + TrackCounters track_counters; + Fixture fixture; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_status_code status; + grpc_error_get_status(fixture.error(), fixture.deadline(), &status, nullptr, + nullptr, nullptr); + } + + track_counters.Finish(state); +} + +template +static void BM_ErrorHttpError(benchmark::State& state) { + TrackCounters track_counters; + Fixture fixture; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + grpc_http2_error_code error; + grpc_error_get_status(fixture.error(), fixture.deadline(), nullptr, nullptr, + &error, nullptr); + } + + track_counters.Finish(state); +} + +template +static void BM_HasClearGrpcStatus(benchmark::State& state) { + TrackCounters track_counters; + Fixture fixture; + for (auto _ : state) { + grpc_error_has_clear_grpc_status(fixture.error()); + } + track_counters.Finish(state); +} + +#define BENCHMARK_SUITE(fixture) \ + BENCHMARK_TEMPLATE(BM_ErrorStringOnNewError, fixture); \ + BENCHMARK_TEMPLATE(BM_ErrorStringRepeatedly, fixture); \ + BENCHMARK_TEMPLATE(BM_ErrorGetStatus, fixture); \ + BENCHMARK_TEMPLATE(BM_ErrorGetStatusCode, fixture); \ + BENCHMARK_TEMPLATE(BM_ErrorHttpError, fixture); \ + BENCHMARK_TEMPLATE(BM_HasClearGrpcStatus, fixture) + +BENCHMARK_SUITE(ErrorNone); +BENCHMARK_SUITE(ErrorCancelled); +BENCHMARK_SUITE(SimpleError); +BENCHMARK_SUITE(ErrorWithGrpcStatus); +BENCHMARK_SUITE(ErrorWithHttpError); +BENCHMARK_SUITE(ErrorWithNestedGrpcStatus); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_fullstack_streaming_ping_pong.cc b/test/cpp/microbenchmarks/bm_fullstack_streaming_ping_pong.cc new file mode 100644 index 00000000..7096a84b --- /dev/null +++ b/test/cpp/microbenchmarks/bm_fullstack_streaming_ping_pong.cc @@ -0,0 +1,129 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +/* Benchmark gRPC end2end in various configurations */ + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/fullstack_streaming_ping_pong.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +/******************************************************************************* + * CONFIGURATIONS + */ + +// Generate Args for StreamingPingPong benchmarks. Currently generates args for +// only "small streams" (i.e streams with 0, 1 or 2 messages) +static void StreamingPingPongArgs(benchmark::internal::Benchmark* b) { + int msg_size = 0; + + b->Args({0, 0}); // spl case: 0 ping-pong msgs (msg_size doesn't matter here) + + for (msg_size = 0; msg_size <= 128 * 1024 * 1024; + msg_size == 0 ? msg_size++ : msg_size *= 8) { + b->Args({msg_size, 1}); + b->Args({msg_size, 2}); + } +} + +BENCHMARK_TEMPLATE(BM_StreamingPingPong, InProcessCHTTP2, NoOpMutator, + NoOpMutator) + ->Apply(StreamingPingPongArgs); +BENCHMARK_TEMPLATE(BM_StreamingPingPong, TCP, NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongArgs); +BENCHMARK_TEMPLATE(BM_StreamingPingPong, InProcess, NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongArgs); + +BENCHMARK_TEMPLATE(BM_StreamingPingPongMsgs, InProcessCHTTP2, NoOpMutator, + NoOpMutator) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_StreamingPingPongMsgs, TCP, NoOpMutator, NoOpMutator) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_StreamingPingPongMsgs, InProcess, NoOpMutator, + NoOpMutator) + ->Range(0, 128 * 1024 * 1024); + +BENCHMARK_TEMPLATE(BM_StreamingPingPong, MinInProcessCHTTP2, NoOpMutator, + NoOpMutator) + ->Apply(StreamingPingPongArgs); +BENCHMARK_TEMPLATE(BM_StreamingPingPong, MinTCP, NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongArgs); +BENCHMARK_TEMPLATE(BM_StreamingPingPong, MinInProcess, NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongArgs); + +BENCHMARK_TEMPLATE(BM_StreamingPingPongMsgs, MinInProcessCHTTP2, NoOpMutator, + NoOpMutator) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_StreamingPingPongMsgs, MinTCP, NoOpMutator, NoOpMutator) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_StreamingPingPongMsgs, MinInProcess, NoOpMutator, + NoOpMutator) + ->Range(0, 128 * 1024 * 1024); + +// Generate Args for StreamingPingPongWithCoalescingApi benchmarks. Currently +// generates args for only "small streams" (i.e streams with 0, 1 or 2 messages) +static void StreamingPingPongWithCoalescingApiArgs( + benchmark::internal::Benchmark* b) { + int msg_size = 0; + + b->Args( + {0, 0, 0}); // spl case: 0 ping-pong msgs (msg_size doesn't matter here) + b->Args( + {0, 0, 1}); // spl case: 0 ping-pong msgs (msg_size doesn't matter here) + + for (msg_size = 0; msg_size <= 128 * 1024 * 1024; + msg_size == 0 ? 
msg_size++ : msg_size *= 8) { + b->Args({msg_size, 1, 0}); + b->Args({msg_size, 2, 0}); + b->Args({msg_size, 1, 1}); + b->Args({msg_size, 2, 1}); + } +} + +BENCHMARK_TEMPLATE(BM_StreamingPingPongWithCoalescingApi, InProcessCHTTP2, + NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongWithCoalescingApiArgs); +BENCHMARK_TEMPLATE(BM_StreamingPingPongWithCoalescingApi, MinInProcessCHTTP2, + NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongWithCoalescingApiArgs); +BENCHMARK_TEMPLATE(BM_StreamingPingPongWithCoalescingApi, InProcess, + NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongWithCoalescingApiArgs); +BENCHMARK_TEMPLATE(BM_StreamingPingPongWithCoalescingApi, MinInProcess, + NoOpMutator, NoOpMutator) + ->Apply(StreamingPingPongWithCoalescingApiArgs); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_fullstack_streaming_pump.cc b/test/cpp/microbenchmarks/bm_fullstack_streaming_pump.cc new file mode 100644 index 00000000..f0f57f2f --- /dev/null +++ b/test/cpp/microbenchmarks/bm_fullstack_streaming_pump.cc @@ -0,0 +1,73 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +/* Benchmark gRPC end2end in various configurations */ + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/fullstack_streaming_pump.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +/******************************************************************************* + * CONFIGURATIONS + */ + +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, TCP) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, UDS) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, InProcess) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, InProcessCHTTP2) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, TCP) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, UDS) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, InProcess) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, InProcessCHTTP2) + ->Range(0, 128 * 1024 * 1024); +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, MinTCP)->Arg(0); +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, MinUDS)->Arg(0); +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, MinInProcess)->Arg(0); +BENCHMARK_TEMPLATE(BM_PumpStreamClientToServer, MinInProcessCHTTP2)->Arg(0); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, MinTCP)->Arg(0); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, MinUDS)->Arg(0); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, MinInProcess)->Arg(0); +BENCHMARK_TEMPLATE(BM_PumpStreamServerToClient, MinInProcessCHTTP2)->Arg(0); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_fullstack_trickle.cc b/test/cpp/microbenchmarks/bm_fullstack_trickle.cc new file mode 100644 index 00000000..6d7394aa --- /dev/null +++ b/test/cpp/microbenchmarks/bm_fullstack_trickle.cc @@ -0,0 +1,480 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +/* Benchmark gRPC end2end in various configurations */ + +#include + +#include + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/ext/transport/chttp2/transport/internal.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/profiling/timers.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" +#include "test/core/util/trickle_endpoint.h" +#include "test/cpp/microbenchmarks/fullstack_context_mutators.h" +#include "test/cpp/microbenchmarks/fullstack_fixtures.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(bool, log, false, "Log state to CSV files"); +ABSL_FLAG(int32_t, warmup_megabytes, 1, + "Number of megabytes to pump before collecting flow control stats"); +ABSL_FLAG(int32_t, warmup_iterations, 100, + "Number of iterations to run before collecting flow control stats"); +ABSL_FLAG(int32_t, warmup_max_time_seconds, 10, + "Maximum number of seconds to run warmup loop"); + +namespace grpc { +namespace testing { + +gpr_atm g_now_us = 0; + +static gpr_timespec fake_now(gpr_clock_type clock_type) { + gpr_timespec t; + gpr_atm now = gpr_atm_no_barrier_load(&g_now_us); + t.tv_sec = now / GPR_US_PER_SEC; + t.tv_nsec = (now % GPR_US_PER_SEC) * GPR_NS_PER_US; + t.clock_type = clock_type; + return t; +} + +static void inc_time() { + gpr_atm_no_barrier_fetch_add(&g_now_us, 100); + grpc_timer_manager_tick(); +} + +static void* tag(intptr_t x) { return reinterpret_cast(x); } + +template +static void write_csv(std::ostream* out, A0&& a0) { + if (!out) return; + (*out) << a0 << "\n"; +} + +template +static void write_csv(std::ostream* out, A0&& a0, Arg&&... arg) { + if (!out) return; + (*out) << a0 << ","; + write_csv(out, std::forward(arg)...); +} + +class TrickledCHTTP2 : public EndpointPairFixture { + public: + TrickledCHTTP2(Service* service, bool streaming, size_t req_size, + size_t resp_size, size_t kilobits_per_second, + grpc_passthru_endpoint_stats* stats) + : EndpointPairFixture(service, MakeEndpoints(kilobits_per_second, stats), + FixtureConfiguration()), + stats_(stats) { + if (absl::GetFlag(FLAGS_log)) { + std::ostringstream fn; + fn << "trickle." << (streaming ? "streaming" : "unary") << "." << req_size + << "." << resp_size << "." 
<< kilobits_per_second << ".csv"; + log_ = absl::make_unique(fn.str().c_str()); + write_csv(log_.get(), "t", "iteration", "client_backlog", + "server_backlog", "client_t_stall", "client_s_stall", + "server_t_stall", "server_s_stall", "client_t_remote", + "server_t_remote", "client_t_announced", "server_t_announced", + "client_s_remote_delta", "server_s_remote_delta", + "client_s_local_delta", "server_s_local_delta", + "client_s_announced_delta", "server_s_announced_delta", + "client_peer_iws", "client_local_iws", "client_sent_iws", + "client_acked_iws", "server_peer_iws", "server_local_iws", + "server_sent_iws", "server_acked_iws", "client_queued_bytes", + "server_queued_bytes"); + } + } + + ~TrickledCHTTP2() override { + if (stats_ != nullptr) { + grpc_passthru_endpoint_stats_destroy(stats_); + } + } + + void AddToLabel(std::ostream& out, benchmark::State& state) override { + out << " writes/iter:" + << (static_cast(stats_->num_writes) / + static_cast(state.iterations())) + << " cli_transport_stalls/iter:" + << (static_cast( + client_stats_.streams_stalled_due_to_transport_flow_control) / + static_cast(state.iterations())) + << " cli_stream_stalls/iter:" + << (static_cast( + client_stats_.streams_stalled_due_to_stream_flow_control) / + static_cast(state.iterations())) + << " svr_transport_stalls/iter:" + << (static_cast( + server_stats_.streams_stalled_due_to_transport_flow_control) / + static_cast(state.iterations())) + << " svr_stream_stalls/iter:" + << (static_cast( + server_stats_.streams_stalled_due_to_stream_flow_control) / + static_cast(state.iterations())); + } + + void Log(int64_t iteration) GPR_ATTRIBUTE_NO_TSAN { + auto now = gpr_time_sub(gpr_now(GPR_CLOCK_MONOTONIC), start_); + grpc_chttp2_transport* client = + reinterpret_cast(client_transport_); + grpc_chttp2_transport* server = + reinterpret_cast(server_transport_); + grpc_chttp2_stream* client_stream = + client->stream_map.count == 1 + ? static_cast(client->stream_map.values[0]) + : nullptr; + grpc_chttp2_stream* server_stream = + server->stream_map.count == 1 + ? static_cast(server->stream_map.values[0]) + : nullptr; + write_csv( + log_.get(), + static_cast(now.tv_sec) + + 1e-9 * static_cast(now.tv_nsec), + iteration, grpc_trickle_get_backlog(endpoint_pair_.client), + grpc_trickle_get_backlog(endpoint_pair_.server), + client->lists[GRPC_CHTTP2_LIST_STALLED_BY_TRANSPORT].head != nullptr, + client->lists[GRPC_CHTTP2_LIST_STALLED_BY_STREAM].head != nullptr, + server->lists[GRPC_CHTTP2_LIST_STALLED_BY_TRANSPORT].head != nullptr, + server->lists[GRPC_CHTTP2_LIST_STALLED_BY_STREAM].head != nullptr, + client->flow_control->remote_window_, + server->flow_control->remote_window_, + client->flow_control->announced_window_, + server->flow_control->announced_window_, + client_stream ? client_stream->flow_control->remote_window_delta_ : -1, + server_stream ? server_stream->flow_control->remote_window_delta_ : -1, + client_stream ? client_stream->flow_control->local_window_delta_ : -1, + server_stream ? server_stream->flow_control->local_window_delta_ : -1, + client_stream ? client_stream->flow_control->announced_window_delta_ + : -1, + server_stream ? 
server_stream->flow_control->announced_window_delta_ + : -1, + client->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + client->settings[GRPC_LOCAL_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + client->settings[GRPC_SENT_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + client->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + server->settings[GRPC_PEER_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + server->settings[GRPC_LOCAL_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + server->settings[GRPC_SENT_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + server->settings[GRPC_ACKED_SETTINGS] + [GRPC_CHTTP2_SETTINGS_INITIAL_WINDOW_SIZE], + client_stream ? client_stream->flow_controlled_buffer.length : 0, + server_stream ? server_stream->flow_controlled_buffer.length : 0); + } + + void Step(bool update_stats) { + grpc_core::ExecCtx exec_ctx; + inc_time(); + size_t client_backlog = + grpc_trickle_endpoint_trickle(endpoint_pair_.client); + size_t server_backlog = + grpc_trickle_endpoint_trickle(endpoint_pair_.server); + + if (update_stats) { + UpdateStats(reinterpret_cast(client_transport_), + &client_stats_, client_backlog); + UpdateStats(reinterpret_cast(server_transport_), + &server_stats_, server_backlog); + } + } + + private: + grpc_passthru_endpoint_stats* stats_; + struct Stats { + int streams_stalled_due_to_stream_flow_control = 0; + int streams_stalled_due_to_transport_flow_control = 0; + }; + Stats client_stats_; + Stats server_stats_; + std::unique_ptr log_; + gpr_timespec start_ = gpr_now(GPR_CLOCK_MONOTONIC); + + static grpc_endpoint_pair MakeEndpoints(size_t kilobits, + grpc_passthru_endpoint_stats* stats) { + grpc_endpoint_pair p; + grpc_passthru_endpoint_create(&p.client, &p.server, stats); + double bytes_per_second = 125.0 * kilobits; + p.client = grpc_trickle_endpoint_create(p.client, bytes_per_second); + p.server = grpc_trickle_endpoint_create(p.server, bytes_per_second); + return p; + } + + void UpdateStats(grpc_chttp2_transport* t, Stats* s, + size_t backlog) GPR_ATTRIBUTE_NO_TSAN { + if (backlog == 0) { + if (t->lists[GRPC_CHTTP2_LIST_STALLED_BY_STREAM].head != nullptr) { + s->streams_stalled_due_to_stream_flow_control++; + } + if (t->lists[GRPC_CHTTP2_LIST_STALLED_BY_TRANSPORT].head != nullptr) { + s->streams_stalled_due_to_transport_flow_control++; + } + } + } +}; + +static void TrickleCQNext(TrickledCHTTP2* fixture, void** t, bool* ok, + int64_t iteration) { + while (true) { + fixture->Log(iteration); + switch ( + fixture->cq()->AsyncNext(t, ok, gpr_inf_past(GPR_CLOCK_MONOTONIC))) { + case CompletionQueue::TIMEOUT: + fixture->Step(iteration != -1); + break; + case CompletionQueue::SHUTDOWN: + GPR_ASSERT(false); + break; + case CompletionQueue::GOT_EVENT: + return; + } + } +} + +static void BM_PumpStreamServerToClient_Trickle(benchmark::State& state) { + EchoTestService::AsyncService service; + std::unique_ptr fixture(new TrickledCHTTP2( + &service, true, state.range(0) /* req_size */, + state.range(0) /* resp_size */, state.range(1) /* bw in kbit/s */, + grpc_passthru_endpoint_stats_create())); + { + EchoResponse send_response; + EchoResponse recv_response; + if (state.range(0) > 0) { + send_response.set_message(std::string(state.range(0), 'a')); + } + Status recv_status; + ServerContext svr_ctx; + ServerAsyncReaderWriter response_rw(&svr_ctx); + service.RequestBidiStream(&svr_ctx, &response_rw, fixture->cq(), + fixture->cq(), tag(0)); + std::unique_ptr stub( + 
EchoTestService::NewStub(fixture->channel())); + ClientContext cli_ctx; + auto request_rw = stub->AsyncBidiStream(&cli_ctx, fixture->cq(), tag(1)); + int need_tags = (1 << 0) | (1 << 1); + void* t; + bool ok; + while (need_tags) { + TrickleCQNext(fixture.get(), &t, &ok, -1); + GPR_ASSERT(ok); + int i = static_cast(reinterpret_cast(t)); + GPR_ASSERT(need_tags & (1 << i)); + need_tags &= ~(1 << i); + } + request_rw->Read(&recv_response, tag(0)); + auto inner_loop = [&](bool in_warmup) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + response_rw.Write(send_response, tag(1)); + while (true) { + TrickleCQNext(fixture.get(), &t, &ok, + in_warmup ? -1 : state.iterations()); + if (t == tag(0)) { + request_rw->Read(&recv_response, tag(0)); + } else if (t == tag(1)) { + break; + } else { + GPR_ASSERT(false); + } + } + }; + gpr_timespec warmup_start = gpr_now(GPR_CLOCK_MONOTONIC); + for (int i = 0; + i < std::max(int64_t(absl::GetFlag(FLAGS_warmup_iterations)), + absl::GetFlag(FLAGS_warmup_megabytes) * 1024 * 1024 / + (14 + state.range(0))); + i++) { + inner_loop(true); + if (gpr_time_cmp(gpr_time_sub(gpr_now(GPR_CLOCK_MONOTONIC), warmup_start), + gpr_time_from_seconds( + absl::GetFlag(FLAGS_warmup_max_time_seconds), + GPR_TIMESPAN)) > 0) { + break; + } + } + while (state.KeepRunning()) { + inner_loop(false); + } + response_rw.Finish(Status::OK, tag(1)); + grpc::Status status; + request_rw->Finish(&status, tag(2)); + need_tags = (1 << 0) | (1 << 1) | (1 << 2); + while (need_tags) { + TrickleCQNext(fixture.get(), &t, &ok, -1); + if (t == tag(0) && ok) { + request_rw->Read(&recv_response, tag(0)); + continue; + } + int i = static_cast(reinterpret_cast(t)); + GPR_ASSERT(need_tags & (1 << i)); + need_tags &= ~(1 << i); + } + } + fixture->Finish(state); + fixture.reset(); + state.SetBytesProcessed(state.range(0) * state.iterations()); +} + +static void StreamingTrickleArgs(benchmark::internal::Benchmark* b) { + for (int i = 1; i <= 128 * 1024 * 1024; i *= 8) { + for (int j = 64; j <= 128 * 1024 * 1024; j *= 8) { + double expected_time = + static_cast(14 + i) / (125.0 * static_cast(j)); + if (expected_time > 2.0) continue; + b->Args({i, j}); + } + } +} +BENCHMARK(BM_PumpStreamServerToClient_Trickle)->Apply(StreamingTrickleArgs); + +static void BM_PumpUnbalancedUnary_Trickle(benchmark::State& state) { + EchoTestService::AsyncService service; + std::unique_ptr fixture(new TrickledCHTTP2( + &service, false, state.range(0) /* req_size */, + state.range(1) /* resp_size */, state.range(2) /* bw in kbit/s */, + grpc_passthru_endpoint_stats_create())); + EchoRequest send_request; + EchoResponse send_response; + EchoResponse recv_response; + if (state.range(0) > 0) { + send_request.set_message(std::string(state.range(0), 'a')); + } + if (state.range(1) > 0) { + send_response.set_message(std::string(state.range(1), 'a')); + } + Status recv_status; + struct ServerEnv { + ServerContext ctx; + EchoRequest recv_request; + grpc::ServerAsyncResponseWriter response_writer; + ServerEnv() : response_writer(&ctx) {} + }; + uint8_t server_env_buffer[2 * sizeof(ServerEnv)]; + ServerEnv* server_env[2] = { + reinterpret_cast(server_env_buffer), + reinterpret_cast(server_env_buffer + sizeof(ServerEnv))}; + new (server_env[0]) ServerEnv; + new (server_env[1]) ServerEnv; + service.RequestEcho(&server_env[0]->ctx, &server_env[0]->recv_request, + &server_env[0]->response_writer, fixture->cq(), + fixture->cq(), tag(0)); + service.RequestEcho(&server_env[1]->ctx, &server_env[1]->recv_request, + &server_env[1]->response_writer, fixture->cq(), 
+ fixture->cq(), tag(1)); + std::unique_ptr stub( + EchoTestService::NewStub(fixture->channel())); + auto inner_loop = [&](bool in_warmup) { + GPR_TIMER_SCOPE("BenchmarkCycle", 0); + recv_response.Clear(); + ClientContext cli_ctx; + std::unique_ptr> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, fixture->cq())); + void* t; + bool ok; + response_reader->Finish(&recv_response, &recv_status, tag(4)); + TrickleCQNext(fixture.get(), &t, &ok, in_warmup ? -1 : state.iterations()); + GPR_ASSERT(ok); + GPR_ASSERT(t == tag(0) || t == tag(1)); + intptr_t slot = reinterpret_cast(t); + ServerEnv* senv = server_env[slot]; + senv->response_writer.Finish(send_response, Status::OK, tag(3)); + for (int i = (1 << 3) | (1 << 4); i != 0;) { + TrickleCQNext(fixture.get(), &t, &ok, + in_warmup ? -1 : state.iterations()); + GPR_ASSERT(ok); + int tagnum = static_cast(reinterpret_cast(t)); + GPR_ASSERT(i & (1 << tagnum)); + i -= 1 << tagnum; + } + GPR_ASSERT(recv_status.ok()); + + senv->~ServerEnv(); + senv = new (senv) ServerEnv(); + service.RequestEcho(&senv->ctx, &senv->recv_request, &senv->response_writer, + fixture->cq(), fixture->cq(), tag(slot)); + }; + gpr_timespec warmup_start = gpr_now(GPR_CLOCK_MONOTONIC); + for (int i = 0; i < std::max(int64_t(absl::GetFlag(FLAGS_warmup_iterations)), + absl::GetFlag(FLAGS_warmup_megabytes) * 1024 * + 1024 / (14 + state.range(0))); + i++) { + inner_loop(true); + if (gpr_time_cmp( + gpr_time_sub(gpr_now(GPR_CLOCK_MONOTONIC), warmup_start), + gpr_time_from_seconds(absl::GetFlag(FLAGS_warmup_max_time_seconds), + GPR_TIMESPAN)) > 0) { + break; + } + } + while (state.KeepRunning()) { + inner_loop(false); + } + fixture->Finish(state); + fixture.reset(); + server_env[0]->~ServerEnv(); + server_env[1]->~ServerEnv(); + state.SetBytesProcessed(state.range(0) * state.iterations() + + state.range(1) * state.iterations()); +} + +static void UnaryTrickleArgs(benchmark::internal::Benchmark* b) { + for (int bw = 64; bw <= 128 * 1024 * 1024; bw *= 16) { + b->Args({1, 1, bw}); + for (int i = 64; i <= 128 * 1024 * 1024; i *= 64) { + double expected_time = + static_cast(14 + i) / (125.0 * static_cast(bw)); + if (expected_time > 2.0) continue; + b->Args({i, 1, bw}); + b->Args({1, i, bw}); + b->Args({i, i, bw}); + } + } +} +BENCHMARK(BM_PumpUnbalancedUnary_Trickle)->Apply(UnaryTrickleArgs); +} // namespace testing +} // namespace grpc + +extern gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + grpc_timer_manager_set_threading(false); + gpr_now_impl = ::grpc::testing::fake_now; + benchmark::RunTheBenchmarksNamespaced(); +} diff --git a/test/cpp/microbenchmarks/bm_fullstack_unary_ping_pong.cc b/test/cpp/microbenchmarks/bm_fullstack_unary_ping_pong.cc new file mode 100644 index 00000000..c47476e7 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_fullstack_unary_ping_pong.cc @@ -0,0 +1,181 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Benchmark gRPC end2end in various configurations */ + +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/fullstack_unary_ping_pong.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +/******************************************************************************* + * CONFIGURATIONS + */ + +// Replace "benchmark::internal::Benchmark" with "::testing::Benchmark" to use +// internal microbenchmarking tooling +static void SweepSizesArgs(benchmark::internal::Benchmark* b) { + b->Args({0, 0}); + for (int i = 1; i <= 128 * 1024 * 1024; i *= 8) { + b->Args({i, 0}); + b->Args({0, i}); + b->Args({i, i}); + } +} + +BENCHMARK_TEMPLATE(BM_UnaryPingPong, TCP, NoOpMutator, NoOpMutator) + ->Apply(SweepSizesArgs); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, MinTCP, NoOpMutator, NoOpMutator) + ->Apply(SweepSizesArgs); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, UDS, NoOpMutator, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, MinUDS, NoOpMutator, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, NoOpMutator) + ->Apply(SweepSizesArgs); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, MinInProcess, NoOpMutator, NoOpMutator) + ->Apply(SweepSizesArgs); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, SockPair, NoOpMutator, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, MinSockPair, NoOpMutator, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, NoOpMutator) + ->Apply(SweepSizesArgs); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, MinInProcessCHTTP2, NoOpMutator, + NoOpMutator) + ->Apply(SweepSizesArgs); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 1>, + NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 2>, + NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, + 
Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcessCHTTP2, NoOpMutator, + Server_AddInitialMetadata, 100>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 1>, + NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 2>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 2>, + NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, + Client_AddMetadata, 1>, NoOpMutator) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 1>) + ->Args({0, 0}); +BENCHMARK_TEMPLATE(BM_UnaryPingPong, InProcess, NoOpMutator, + Server_AddInitialMetadata, 100>) + ->Args({0, 0}); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_metadata.cc b/test/cpp/microbenchmarks/bm_metadata.cc new file mode 100644 index 00000000..ac48c096 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_metadata.cc @@ -0,0 +1,305 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +/* Test out various metadata handling primitives */ + +#include + +#include + +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/static_metadata.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +static void BM_SliceFromStatic(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + benchmark::DoNotOptimize(grpc_core::ExternallyManagedSlice("abc")); + } + track_counters.Finish(state); +} +BENCHMARK(BM_SliceFromStatic); + +static void BM_SliceFromCopied(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + grpc_slice_unref(grpc_core::UnmanagedMemorySlice("abc")); + } + track_counters.Finish(state); +} +BENCHMARK(BM_SliceFromCopied); + +static void BM_SliceIntern(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExternallyManagedSlice slice("abc"); + for (auto _ : state) { + grpc_slice_unref(grpc_core::ManagedMemorySlice(&slice)); + } + track_counters.Finish(state); +} +BENCHMARK(BM_SliceIntern); + +static void BM_SliceReIntern(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExternallyManagedSlice static_slice("abc"); + grpc_core::ManagedMemorySlice slice(&static_slice); + for (auto _ : state) { + grpc_slice_unref(grpc_core::ManagedMemorySlice(&slice)); + } + grpc_slice_unref(slice); + track_counters.Finish(state); +} +BENCHMARK(BM_SliceReIntern); + +static void BM_SliceInternStaticMetadata(benchmark::State& state) { + TrackCounters track_counters; + for (auto _ : state) { + benchmark::DoNotOptimize(grpc_core::ManagedMemorySlice(&GRPC_MDSTR_GZIP)); + } + track_counters.Finish(state); +} +BENCHMARK(BM_SliceInternStaticMetadata); + +static void BM_SliceInternEqualToStaticMetadata(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExternallyManagedSlice slice("gzip"); + for (auto _ : state) { + benchmark::DoNotOptimize(grpc_core::ManagedMemorySlice(&slice)); + } + track_counters.Finish(state); +} +BENCHMARK(BM_SliceInternEqualToStaticMetadata); + +static void BM_MetadataFromNonInternedSlices(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExternallyManagedSlice k("key"); + grpc_core::ExternallyManagedSlice v("value"); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF(grpc_mdelem_create(k, v, nullptr)); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromNonInternedSlices); + +static void BM_MetadataFromInternedSlices(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ManagedMemorySlice k("key"); + grpc_core::ManagedMemorySlice v("value"); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF(grpc_mdelem_create(k, v, nullptr)); + } + + grpc_slice_unref(k); + grpc_slice_unref(v); + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromInternedSlices); + +static void BM_MetadataFromInternedSlicesAlreadyInIndex( + benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ManagedMemorySlice k("key"); + grpc_core::ManagedMemorySlice v("value"); + grpc_core::ExecCtx exec_ctx; + grpc_mdelem seed = grpc_mdelem_create(k, v, nullptr); + for (auto _ : state) { + GRPC_MDELEM_UNREF(grpc_mdelem_create(k, v, nullptr)); + } + GRPC_MDELEM_UNREF(seed); + + grpc_slice_unref(k); + grpc_slice_unref(v); + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromInternedSlicesAlreadyInIndex); + 
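+// The BM_MetadataFrom* benchmarks below differ only in which of the key/value
+// slices is interned (grpc_core::ManagedMemorySlice) versus non-interned
+// (grpc_core::ExternallyManagedSlice) and in whether a caller-supplied backing
+// store is handed to grpc_mdelem_create. The common shape, sketched here for
+// illustration only, is roughly:
+//
+//   grpc_core::ManagedMemorySlice key("key");          // interned: unref later
+//   grpc_core::ExternallyManagedSlice value("value");  // static lifetime
+//   grpc_mdelem md = grpc_mdelem_create(key, value, nullptr);
+//   GRPC_MDELEM_UNREF(md);
+//   grpc_slice_unref(key);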
+static void BM_MetadataFromInternedKey(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ManagedMemorySlice k("key"); + grpc_core::ExternallyManagedSlice v("value"); + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF(grpc_mdelem_create(k, v, nullptr)); + } + + grpc_slice_unref(k); + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromInternedKey); + +static void BM_MetadataFromNonInternedSlicesWithBackingStore( + benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExternallyManagedSlice k("key"); + grpc_core::ExternallyManagedSlice v("value"); + char backing_store[sizeof(grpc_mdelem_data)]; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF(grpc_mdelem_create( + k, v, reinterpret_cast(backing_store))); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromNonInternedSlicesWithBackingStore); + +static void BM_MetadataFromInternedSlicesWithBackingStore( + benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ManagedMemorySlice k("key"); + grpc_core::ManagedMemorySlice v("value"); + char backing_store[sizeof(grpc_mdelem_data)]; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF(grpc_mdelem_create( + k, v, reinterpret_cast(backing_store))); + } + + grpc_slice_unref(k); + grpc_slice_unref(v); + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromInternedSlicesWithBackingStore); + +static void BM_MetadataFromInternedKeyWithBackingStore( + benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ManagedMemorySlice k("key"); + grpc_core::ExternallyManagedSlice v("value"); + char backing_store[sizeof(grpc_mdelem_data)]; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF(grpc_mdelem_create( + k, v, reinterpret_cast(backing_store))); + } + + grpc_slice_unref(k); + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromInternedKeyWithBackingStore); + +static void BM_MetadataFromStaticMetadataStrings(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF( + grpc_mdelem_create(GRPC_MDSTR_STATUS, GRPC_MDSTR_200, nullptr)); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromStaticMetadataStrings); + +static void BM_MetadataFromStaticMetadataStringsNotIndexed( + benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + for (auto _ : state) { + GRPC_MDELEM_UNREF( + grpc_mdelem_create(GRPC_MDSTR_STATUS, GRPC_MDSTR_GZIP, nullptr)); + } + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataFromStaticMetadataStringsNotIndexed); + +static void BM_MetadataRefUnrefExternal(benchmark::State& state) { + TrackCounters track_counters; + char backing_store[sizeof(grpc_mdelem_data)]; + grpc_core::ExecCtx exec_ctx; + grpc_mdelem el = + grpc_mdelem_create(grpc_core::ExternallyManagedSlice("a"), + grpc_core::ExternallyManagedSlice("b"), + reinterpret_cast(backing_store)); + for (auto _ : state) { + GRPC_MDELEM_UNREF(GRPC_MDELEM_REF(el)); + } + GRPC_MDELEM_UNREF(el); + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataRefUnrefExternal); + +static void BM_MetadataRefUnrefInterned(benchmark::State& state) { + TrackCounters track_counters; + char backing_store[sizeof(grpc_mdelem_data)]; + grpc_core::ExecCtx exec_ctx; + grpc_core::ManagedMemorySlice k("key"); + grpc_core::ManagedMemorySlice v("value"); + grpc_mdelem el = grpc_mdelem_create( + k, v, reinterpret_cast(backing_store)); + 
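+  // The interned key/value slices are unreffed up front; `el` keeps the
+  // metadata element alive, so the loop below isolates the cost of
+  // GRPC_MDELEM_REF/GRPC_MDELEM_UNREF on an interned element.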
grpc_slice_unref(k); + grpc_slice_unref(v); + for (auto _ : state) { + GRPC_MDELEM_UNREF(GRPC_MDELEM_REF(el)); + } + GRPC_MDELEM_UNREF(el); + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataRefUnrefInterned); + +static void BM_MetadataRefUnrefAllocated(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + grpc_mdelem el = + grpc_mdelem_create(grpc_core::ExternallyManagedSlice("a"), + grpc_core::ExternallyManagedSlice("b"), nullptr); + for (auto _ : state) { + GRPC_MDELEM_UNREF(GRPC_MDELEM_REF(el)); + } + GRPC_MDELEM_UNREF(el); + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataRefUnrefAllocated); + +static void BM_MetadataRefUnrefStatic(benchmark::State& state) { + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + grpc_mdelem el = + grpc_mdelem_create(GRPC_MDSTR_STATUS, GRPC_MDSTR_200, nullptr); + for (auto _ : state) { + GRPC_MDELEM_UNREF(GRPC_MDELEM_REF(el)); + } + GRPC_MDELEM_UNREF(el); + + track_counters.Finish(state); +} +BENCHMARK(BM_MetadataRefUnrefStatic); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_opencensus_plugin.cc b/test/cpp/microbenchmarks/bm_opencensus_plugin.cc new file mode 100644 index 00000000..800162f7 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_opencensus_plugin.cc @@ -0,0 +1,134 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include // NOLINT + +#include + +#include "absl/base/call_once.h" +#include "absl/strings/str_cat.h" +#include "opencensus/stats/stats.h" + +#include +#include + +#include "src/core/lib/config/core_configuration.h" +#include "src/cpp/ext/filters/census/grpc_plugin.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" + +absl::once_flag once; +void RegisterOnce() { absl::call_once(once, grpc::RegisterOpenCensusPlugin); } + +class EchoServer final : public grpc::testing::EchoTestService::Service { + grpc::Status Echo(grpc::ServerContext* /*context*/, + const grpc::testing::EchoRequest* request, + grpc::testing::EchoResponse* response) override { + if (request->param().expected_error().code() == 0) { + response->set_message(request->message()); + return grpc::Status::OK; + } else { + return grpc::Status(static_cast( + request->param().expected_error().code()), + ""); + } + } +}; + +// An EchoServerThread object creates an EchoServer on a separate thread and +// shuts down the server and thread when it goes out of scope. +class EchoServerThread final { + public: + EchoServerThread() { + grpc::ServerBuilder builder; + int port; + builder.AddListeningPort("[::]:0", grpc::InsecureServerCredentials(), + &port); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + if (server_ == nullptr || port == 0) { + std::abort(); + } + server_address_ = absl::StrCat("[::]:", port); + server_thread_ = std::thread(&EchoServerThread::RunServerLoop, this); + } + + ~EchoServerThread() { + server_->Shutdown(); + server_thread_.join(); + } + + const std::string& address() { return server_address_; } + + private: + void RunServerLoop() { server_->Wait(); } + + std::string server_address_; + EchoServer service_; + std::unique_ptr server_; + std::thread server_thread_; +}; + +static void BM_E2eLatencyCensusDisabled(benchmark::State& state) { + grpc_core::CoreConfiguration::Reset(); + grpc::testing::TestGrpcScope grpc_scope; + EchoServerThread server; + std::unique_ptr stub = + grpc::testing::EchoTestService::NewStub(grpc::CreateChannel( + server.address(), grpc::InsecureChannelCredentials())); + + grpc::testing::EchoResponse response; + for (auto _ : state) { + grpc::testing::EchoRequest request; + grpc::ClientContext context; + grpc::Status status = stub->Echo(&context, request, &response); + } +} +BENCHMARK(BM_E2eLatencyCensusDisabled); + +static void BM_E2eLatencyCensusEnabled(benchmark::State& state) { + grpc_core::CoreConfiguration::Reset(); + // Now start the test by registering the plugin (once in the execution) + RegisterOnce(); + // This we can safely repeat, and doing so clears accumulated data to avoid + // initialization costs varying between runs. 
+ grpc::RegisterOpenCensusViewsForExport(); + + grpc::testing::TestGrpcScope grpc_scope; + EchoServerThread server; + std::unique_ptr stub = + grpc::testing::EchoTestService::NewStub(grpc::CreateChannel( + server.address(), grpc::InsecureChannelCredentials())); + + grpc::testing::EchoResponse response; + for (auto _ : state) { + grpc::testing::EchoRequest request; + grpc::ClientContext context; + grpc::Status status = stub->Echo(&context, request, &response); + } +} +BENCHMARK(BM_E2eLatencyCensusEnabled); + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::benchmark::Initialize(&argc, argv); + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; + ::benchmark::RunSpecifiedBenchmarks(); +} diff --git a/test/cpp/microbenchmarks/bm_pollset.cc b/test/cpp/microbenchmarks/bm_pollset.cc new file mode 100644 index 00000000..beb38d62 --- /dev/null +++ b/test/cpp/microbenchmarks/bm_pollset.cc @@ -0,0 +1,268 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* Test out pollset latencies */ + +#include + +#include + +#include +#include +#include + +#include "src/core/lib/gpr/useful.h" +#include "src/core/lib/iomgr/ev_posix.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/wakeup_fd_posix.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +#ifdef GRPC_LINUX_MULTIPOLL_WITH_EPOLL +#include +#include +#include +#endif + +static void shutdown_ps(void* ps, grpc_error_handle /*error*/) { + grpc_pollset_destroy(static_cast(ps)); +} + +static void BM_CreateDestroyPollset(benchmark::State& state) { + TrackCounters track_counters; + size_t ps_sz = grpc_pollset_size(); + grpc_pollset* ps = static_cast(gpr_malloc(ps_sz)); + gpr_mu* mu; + grpc_core::ExecCtx exec_ctx; + grpc_closure shutdown_ps_closure; + GRPC_CLOSURE_INIT(&shutdown_ps_closure, shutdown_ps, ps, + grpc_schedule_on_exec_ctx); + for (auto _ : state) { + memset(ps, 0, ps_sz); + grpc_pollset_init(ps, &mu); + gpr_mu_lock(mu); + grpc_pollset_shutdown(ps, &shutdown_ps_closure); + gpr_mu_unlock(mu); + grpc_core::ExecCtx::Get()->Flush(); + } + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(ps); + track_counters.Finish(state); +} +BENCHMARK(BM_CreateDestroyPollset); + +#ifdef GRPC_LINUX_MULTIPOLL_WITH_EPOLL +static void BM_PollEmptyPollset_SpeedOfLight(benchmark::State& state) { + // equivalent to BM_PollEmptyPollset, but just use the OS primitives to guage + // what the speed of light would be if we abstracted perfectly + TrackCounters track_counters; + int epfd = epoll_create1(0); + GPR_ASSERT(epfd != -1); + size_t nev = state.range(0); + size_t nfd = state.range(1); + epoll_event* ev = new epoll_event[nev]; + std::vector fds; + for (size_t i = 0; i < nfd; i++) { + fds.push_back(eventfd(0, 0)); + epoll_event ev; + ev.events = EPOLLIN; + epoll_ctl(epfd, EPOLL_CTL_ADD, fds.back(), &ev); + } + for (auto _ 
: state) { + epoll_wait(epfd, ev, nev, 0); + } + for (auto fd : fds) { + close(fd); + } + close(epfd); + delete[] ev; + track_counters.Finish(state); +} +BENCHMARK(BM_PollEmptyPollset_SpeedOfLight) + ->Args({1, 0}) + ->Args({1, 1}) + ->Args({1, 10}) + ->Args({1, 100}) + ->Args({1, 1000}) + ->Args({1, 10000}) + ->Args({1, 100000}) + ->Args({10, 1}) + ->Args({100, 1}) + ->Args({1000, 1}); +#endif + +static void BM_PollEmptyPollset(benchmark::State& state) { + TrackCounters track_counters; + size_t ps_sz = grpc_pollset_size(); + grpc_pollset* ps = static_cast(gpr_zalloc(ps_sz)); + gpr_mu* mu; + grpc_pollset_init(ps, &mu); + grpc_core::ExecCtx exec_ctx; + gpr_mu_lock(mu); + for (auto _ : state) { + GRPC_ERROR_UNREF(grpc_pollset_work(ps, nullptr, 0)); + } + grpc_closure shutdown_ps_closure; + GRPC_CLOSURE_INIT(&shutdown_ps_closure, shutdown_ps, ps, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(ps, &shutdown_ps_closure); + gpr_mu_unlock(mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(ps); + track_counters.Finish(state); +} +BENCHMARK(BM_PollEmptyPollset); + +static void BM_PollAddFd(benchmark::State& state) { + TrackCounters track_counters; + size_t ps_sz = grpc_pollset_size(); + grpc_pollset* ps = static_cast(gpr_zalloc(ps_sz)); + gpr_mu* mu; + grpc_pollset_init(ps, &mu); + grpc_core::ExecCtx exec_ctx; + grpc_wakeup_fd wakeup_fd; + GPR_ASSERT( + GRPC_LOG_IF_ERROR("wakeup_fd_init", grpc_wakeup_fd_init(&wakeup_fd))); + grpc_fd* fd = grpc_fd_create(wakeup_fd.read_fd, "xxx", false); + for (auto _ : state) { + grpc_pollset_add_fd(ps, fd); + grpc_core::ExecCtx::Get()->Flush(); + } + grpc_fd_orphan(fd, nullptr, nullptr, "xxx"); + grpc_closure shutdown_ps_closure; + GRPC_CLOSURE_INIT(&shutdown_ps_closure, shutdown_ps, ps, + grpc_schedule_on_exec_ctx); + gpr_mu_lock(mu); + grpc_pollset_shutdown(ps, &shutdown_ps_closure); + gpr_mu_unlock(mu); + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(ps); + track_counters.Finish(state); +} +BENCHMARK(BM_PollAddFd); + +class TestClosure : public grpc_closure { + public: + virtual ~TestClosure() {} +}; + +template +TestClosure* MakeTestClosure(F f) { + struct C : public TestClosure { + explicit C(F f) : f_(f) { GRPC_CLOSURE_INIT(this, C::cbfn, this, nullptr); } + static void cbfn(void* arg, grpc_error_handle /*error*/) { + C* p = static_cast(arg); + p->f_(); + } + F f_; + }; + return new C(f); +} + +#ifdef GRPC_LINUX_MULTIPOLL_WITH_EPOLL +static void BM_SingleThreadPollOneFd_SpeedOfLight(benchmark::State& state) { + // equivalent to BM_PollEmptyPollset, but just use the OS primitives to guage + // what the speed of light would be if we abstracted perfectly + TrackCounters track_counters; + int epfd = epoll_create1(0); + GPR_ASSERT(epfd != -1); + epoll_event ev[100]; + int fd = eventfd(0, EFD_NONBLOCK); + ev[0].events = EPOLLIN; + epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev[0]); + for (auto _ : state) { + int err; + do { + err = eventfd_write(fd, 1); + } while (err < 0 && errno == EINTR); + GPR_ASSERT(err == 0); + do { + err = epoll_wait(epfd, ev, GPR_ARRAY_SIZE(ev), 0); + } while (err < 0 && errno == EINTR); + GPR_ASSERT(err == 1); + eventfd_t value; + do { + err = eventfd_read(fd, &value); + } while (err < 0 && errno == EINTR); + GPR_ASSERT(err == 0); + } + close(fd); + close(epfd); + track_counters.Finish(state); +} +BENCHMARK(BM_SingleThreadPollOneFd_SpeedOfLight); +#endif + +static void BM_SingleThreadPollOneFd(benchmark::State& state) { + TrackCounters track_counters; + size_t ps_sz = grpc_pollset_size(); + grpc_pollset* ps = 
static_cast(gpr_zalloc(ps_sz)); + gpr_mu* mu; + grpc_pollset_init(ps, &mu); + grpc_core::ExecCtx exec_ctx; + grpc_wakeup_fd wakeup_fd; + GRPC_ERROR_UNREF(grpc_wakeup_fd_init(&wakeup_fd)); + grpc_fd* wakeup = grpc_fd_create(wakeup_fd.read_fd, "wakeup_read", false); + grpc_pollset_add_fd(ps, wakeup); + bool done = false; + TestClosure* continue_closure = MakeTestClosure([&]() { + GRPC_ERROR_UNREF(grpc_wakeup_fd_consume_wakeup(&wakeup_fd)); + if (!state.KeepRunning()) { + done = true; + return; + } + GRPC_ERROR_UNREF(grpc_wakeup_fd_wakeup(&wakeup_fd)); + grpc_fd_notify_on_read(wakeup, continue_closure); + }); + GRPC_ERROR_UNREF(grpc_wakeup_fd_wakeup(&wakeup_fd)); + grpc_fd_notify_on_read(wakeup, continue_closure); + gpr_mu_lock(mu); + while (!done) { + GRPC_ERROR_UNREF(grpc_pollset_work(ps, nullptr, GRPC_MILLIS_INF_FUTURE)); + } + grpc_fd_orphan(wakeup, nullptr, nullptr, "done"); + wakeup_fd.read_fd = 0; + grpc_closure shutdown_ps_closure; + GRPC_CLOSURE_INIT(&shutdown_ps_closure, shutdown_ps, ps, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(ps, &shutdown_ps_closure); + gpr_mu_unlock(mu); + grpc_core::ExecCtx::Get()->Flush(); + grpc_wakeup_fd_destroy(&wakeup_fd); + gpr_free(ps); + track_counters.Finish(state); + delete continue_closure; +} +BENCHMARK(BM_SingleThreadPollOneFd); + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_threadpool.cc b/test/cpp/microbenchmarks/bm_threadpool.cc new file mode 100644 index 00000000..601392dd --- /dev/null +++ b/test/cpp/microbenchmarks/bm_threadpool.cc @@ -0,0 +1,331 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include + +#include "src/core/lib/iomgr/executor/threadpool.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +// This helper class allows a thread to block for a pre-specified number of +// actions. BlockingCounter has an initial non-negative count on initialization. +// Each call to DecrementCount will decrease the count by 1. When making a call +// to Wait, if the count is greater than 0, the thread will be blocked, until +// the count reaches 0. 
+class BlockingCounter {
+ public:
+  explicit BlockingCounter(int count) : count_(count) {}
+  void DecrementCount() {
+    std::lock_guard<std::mutex> l(mu_);
+    count_--;
+    if (count_ == 0) cv_.notify_all();
+  }
+
+  void Wait() {
+    std::unique_lock<std::mutex> l(mu_);
+    while (count_ > 0) {
+      cv_.wait(l);
+    }
+  }
+
+ private:
+  int count_;
+  std::mutex mu_;
+  std::condition_variable cv_;
+};
+
+// This is a functor/closure class for the threadpool microbenchmark.
+// This functor (closure) adds another functor into the pool if the number
+// passed in (num_add) is greater than 0; otherwise it decrements the counter
+// to indicate that the task is finished. The functor deletes itself at the
+// end, so the caller does not need to clean it up.
+class AddAnotherFunctor : public grpc_completion_queue_functor {
+ public:
+  AddAnotherFunctor(grpc_core::ThreadPool* pool, BlockingCounter* counter,
+                    int num_add)
+      : pool_(pool), counter_(counter), num_add_(num_add) {
+    functor_run = &AddAnotherFunctor::Run;
+    inlineable = false;
+    internal_next = this;
+    internal_success = 0;
+  }
+  // When the functor runs in the thread pool, it receives itself as the first
+  // argument and internal_success as the second.
+  static void Run(grpc_completion_queue_functor* cb, int /*ok*/) {
+    auto* callback = static_cast<AddAnotherFunctor*>(cb);
+    if (--callback->num_add_ > 0) {
+      callback->pool_->Add(new AddAnotherFunctor(
+          callback->pool_, callback->counter_, callback->num_add_));
+    } else {
+      callback->counter_->DecrementCount();
+    }
+    // Deletes itself.
+    delete callback;
+  }
+
+ private:
+  grpc_core::ThreadPool* pool_;
+  BlockingCounter* counter_;
+  int num_add_;
+};
+
+template <int kConcurrentFunctor>
+static void ThreadPoolAddAnother(benchmark::State& state) {
+  const int num_iterations = state.range(0);
+  const int num_threads = state.range(1);
+  // Number of adds done by each closure.
+  const int num_add = num_iterations / kConcurrentFunctor;
+  grpc_core::ThreadPool pool(num_threads);
+  while (state.KeepRunningBatch(num_iterations)) {
+    BlockingCounter counter(kConcurrentFunctor);
+    for (int i = 0; i < kConcurrentFunctor; ++i) {
+      pool.Add(new AddAnotherFunctor(&pool, &counter, num_add));
+    }
+    counter.Wait();
+  }
+  state.SetItemsProcessed(state.iterations());
+}
+
+// First pair of arguments is range for number of iterations (num_iterations).
+// Second pair of arguments is range for thread pool size (num_threads).
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 1)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 4)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 8)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 16)
+    ->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 32)
+    ->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 64)
+    ->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 128)
+    ->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 512)
+    ->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddAnother, 2048)
+    ->RangePair(524288, 524288, 1, 1024);
+
+// A functor class that deletes itself when it finishes running.
+class SuicideFunctorForAdd : public grpc_completion_queue_functor {
+ public:
+  explicit SuicideFunctorForAdd(BlockingCounter* counter) : counter_(counter) {
+    functor_run = &SuicideFunctorForAdd::Run;
+    inlineable = false;
+    internal_next = this;
+    internal_success = 0;
+  }
+
+  static void Run(grpc_completion_queue_functor* cb, int /*ok*/) {
+    // When run, the first argument is the functor itself.
+    auto* callback = static_cast<SuicideFunctorForAdd*>(cb);
+    callback->counter_->DecrementCount();
+    delete callback;
+  }
+
+ private:
+  BlockingCounter* counter_;
+};
+
+// Performs the scenario of external thread(s) adding closures into the pool.
+static void BM_ThreadPoolExternalAdd(benchmark::State& state) {
+  static grpc_core::ThreadPool* external_add_pool = nullptr;
+  int thread_idx = hack::get_thread_idx(state);
+  // Setup for each run of the test.
+  if (thread_idx == 0) {
+    const int num_threads = state.range(1);
+    external_add_pool = new grpc_core::ThreadPool(num_threads);
+  }
+  const int num_iterations = state.range(0) / hack::get_threads(state);
+  while (state.KeepRunningBatch(num_iterations)) {
+    BlockingCounter counter(num_iterations);
+    for (int i = 0; i < num_iterations; ++i) {
+      external_add_pool->Add(new SuicideFunctorForAdd(&counter));
+    }
+    counter.Wait();
+  }
+
+  // Teardown at the end of each test run.
+  if (thread_idx == 0) {
+    state.SetItemsProcessed(state.range(0));
+    delete external_add_pool;
+  }
+}
+BENCHMARK(BM_ThreadPoolExternalAdd)
+    // First pair is range for number of iterations (num_iterations).
+    // Second pair is range for thread pool size (num_threads).
+    ->RangePair(524288, 524288, 1, 1024)
+    ->ThreadRange(1, 256);  // Concurrent external thread(s) up to 256
+
+// Functor (closure) that adds itself into the pool repeatedly. Because it
+// re-adds itself, the overhead stays low and the cost of Add can be measured
+// more accurately.
+class AddSelfFunctor : public grpc_completion_queue_functor {
+ public:
+  AddSelfFunctor(grpc_core::ThreadPool* pool, BlockingCounter* counter,
+                 int num_add)
+      : pool_(pool), counter_(counter), num_add_(num_add) {
+    functor_run = &AddSelfFunctor::Run;
+    inlineable = false;
+    internal_next = this;
+    internal_success = 0;
+  }
+  // When the functor runs in the thread pool, it receives itself as the first
+  // argument and internal_success as the second.
+  static void Run(grpc_completion_queue_functor* cb, int /*ok*/) {
+    auto* callback = static_cast<AddSelfFunctor*>(cb);
+    if (--callback->num_add_ > 0) {
+      callback->pool_->Add(cb);
+    } else {
+      callback->counter_->DecrementCount();
+      // Deletes itself.
+      delete callback;
+    }
+  }
+
+ private:
+  grpc_core::ThreadPool* pool_;
+  BlockingCounter* counter_;
+  int num_add_;
+};
+
+template <int kConcurrentFunctor>
+static void ThreadPoolAddSelf(benchmark::State& state) {
+  const int num_iterations = state.range(0);
+  const int num_threads = state.range(1);
+  // Number of adds done by each closure.
+  const int num_add = num_iterations / kConcurrentFunctor;
+  grpc_core::ThreadPool pool(num_threads);
+  while (state.KeepRunningBatch(num_iterations)) {
+    BlockingCounter counter(kConcurrentFunctor);
+    for (int i = 0; i < kConcurrentFunctor; ++i) {
+      pool.Add(new AddSelfFunctor(&pool, &counter, num_add));
+    }
+    counter.Wait();
+  }
+  state.SetItemsProcessed(state.iterations());
+}
+
+// First pair of arguments is range for number of iterations (num_iterations).
+// Second pair of arguments is range for thread pool size (num_threads).
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 1)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 4)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 8)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 16)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 32)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 64)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 128)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 512)->RangePair(524288, 524288, 1, 1024);
+BENCHMARK_TEMPLATE(ThreadPoolAddSelf, 2048)->RangePair(524288, 524288, 1, 1024);
+
+#if defined(__GNUC__) && !defined(SWIG)
+#if defined(__i386__) || defined(__x86_64__)
+#define CACHELINE_SIZE 64
+#elif defined(__powerpc64__)
+#define CACHELINE_SIZE 128
+#elif defined(__aarch64__)
+#define CACHELINE_SIZE 64
+#elif defined(__arm__)
+#if defined(__ARM_ARCH_5T__)
+#define CACHELINE_SIZE 32
+#elif defined(__ARM_ARCH_7A__)
+#define CACHELINE_SIZE 64
+#endif
+#endif
+#ifndef CACHELINE_SIZE
+#define CACHELINE_SIZE 64
+#endif
+#endif
+
+// A functor (closure) that simulates closures with a small but non-trivial
+// amount of work.
+class ShortWorkFunctorForAdd : public grpc_completion_queue_functor {
+ public:
+  BlockingCounter* counter_;
+
+  ShortWorkFunctorForAdd() {
+    functor_run = &ShortWorkFunctorForAdd::Run;
+    inlineable = false;
+    internal_next = this;
+    internal_success = 0;
+    val_ = 0;
+  }
+  static void Run(grpc_completion_queue_functor* cb, int /*ok*/) {
+    auto* callback = static_cast<ShortWorkFunctorForAdd*>(cb);
+    // Uses pad to avoid the compiler complaining about an unused variable.
+    callback->pad[0] = 0;
+    for (int i = 0; i < 1000; ++i) {
+      callback->val_++;
+    }
+    callback->counter_->DecrementCount();
+  }
+
+ private:
+  char pad[CACHELINE_SIZE];
+  volatile int val_;
+};
+
+// Simulates workloads where many short-running callbacks are added to the
+// threadpool. The callbacks are not enough to keep all the workers busy
+// continuously, so the number of running workers changes over time.
+//
+// In effect this tests how well the threadpool avoids spurious wakeups.
+static void BM_SpikyLoad(benchmark::State& state) {
+  const int num_threads = state.range(0);
+
+  const int kNumSpikes = 1000;
+  const int batch_size = 3 * num_threads;
+  std::vector<ShortWorkFunctorForAdd> work_vector(batch_size);
+  grpc_core::ThreadPool pool(num_threads);
+  while (state.KeepRunningBatch(kNumSpikes * batch_size)) {
+    for (int i = 0; i != kNumSpikes; ++i) {
+      BlockingCounter counter(batch_size);
+      for (auto& w : work_vector) {
+        w.counter_ = &counter;
+        pool.Add(&w);
+      }
+      counter.Wait();
+    }
+  }
+  state.SetItemsProcessed(state.iterations() * batch_size);
+}
+BENCHMARK(BM_SpikyLoad)->Arg(1)->Arg(2)->Arg(4)->Arg(8)->Arg(16);
+
+}  // namespace testing
+}  // namespace grpc
+
+// Some distros have RunSpecifiedBenchmarks under the benchmark namespace,
+// and others do not. This allows us to support both modes.
+namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char* argv[]) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/bm_timer.cc b/test/cpp/microbenchmarks/bm_timer.cc new file mode 100644 index 00000000..8e3a74ab --- /dev/null +++ b/test/cpp/microbenchmarks/bm_timer.cc @@ -0,0 +1,123 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include + +#include +#include +#include + +#include "src/core/lib/iomgr/timer.h" +#include "test/core/util/test_config.h" +#include "test/cpp/microbenchmarks/helpers.h" +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +struct TimerClosure { + grpc_timer timer; + grpc_closure closure; +}; + +static void BM_InitCancelTimer(benchmark::State& state) { + constexpr int kTimerCount = 1024; + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + std::vector timer_closures(kTimerCount); + int i = 0; + for (auto _ : state) { + TimerClosure* timer_closure = &timer_closures[i++ % kTimerCount]; + GRPC_CLOSURE_INIT( + &timer_closure->closure, + [](void* /*args*/, grpc_error_handle /*err*/) {}, nullptr, + grpc_schedule_on_exec_ctx); + grpc_timer_init(&timer_closure->timer, GRPC_MILLIS_INF_FUTURE, + &timer_closure->closure); + grpc_timer_cancel(&timer_closure->timer); + exec_ctx.Flush(); + } + track_counters.Finish(state); +} +BENCHMARK(BM_InitCancelTimer); + +static void BM_TimerBatch(benchmark::State& state) { + constexpr int kTimerCount = 1024; + const bool check = state.range(0); + const bool reverse = state.range(1); + + const grpc_millis start = + reverse ? GRPC_MILLIS_INF_FUTURE : GRPC_MILLIS_INF_FUTURE - kTimerCount; + const grpc_millis end = + reverse ? GRPC_MILLIS_INF_FUTURE - kTimerCount : GRPC_MILLIS_INF_FUTURE; + const grpc_millis increment = reverse ? 
-1 : 1; + + TrackCounters track_counters; + grpc_core::ExecCtx exec_ctx; + std::vector timer_closures(kTimerCount); + for (auto _ : state) { + for (grpc_millis deadline = start; deadline != end; deadline += increment) { + TimerClosure* timer_closure = &timer_closures[deadline % kTimerCount]; + GRPC_CLOSURE_INIT( + &timer_closure->closure, + [](void* /*args*/, grpc_error_handle /*err*/) {}, nullptr, + grpc_schedule_on_exec_ctx); + + grpc_timer_init(&timer_closure->timer, deadline, &timer_closure->closure); + } + if (check) { + grpc_millis next = GRPC_MILLIS_INF_FUTURE; + grpc_timer_check(&next); + } + for (grpc_millis deadline = start; deadline != end; deadline += increment) { + TimerClosure* timer_closure = &timer_closures[deadline % kTimerCount]; + grpc_timer_cancel(&timer_closure->timer); + } + exec_ctx.Flush(); + } + track_counters.Finish(state); +} +BENCHMARK(BM_TimerBatch) + ->Args({/*check=*/false, /*reverse=*/false}) + ->Args({/*check=*/false, /*reverse=*/true}) + ->Args({/*check=*/true, /*reverse=*/false}) + ->Args({/*check=*/true, /*reverse=*/true}) + ->ThreadRange(1, 128); + +} // namespace testing +} // namespace grpc + +// Some distros have RunSpecifiedBenchmarks under the benchmark namespace, +// and others do not. This allows us to support both modes. +namespace benchmark { +void RunTheBenchmarksNamespaced() { RunSpecifiedBenchmarks(); } +} // namespace benchmark + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + LibraryInitializer libInit; + ::benchmark::Initialize(&argc, argv); + ::grpc::testing::InitTest(&argc, &argv, false); + benchmark::RunTheBenchmarksNamespaced(); + return 0; +} diff --git a/test/cpp/microbenchmarks/callback_test_service.cc b/test/cpp/microbenchmarks/callback_test_service.cc new file mode 100644 index 00000000..049e361c --- /dev/null +++ b/test/cpp/microbenchmarks/callback_test_service.cc @@ -0,0 +1,110 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/microbenchmarks/callback_test_service.h" + +namespace grpc { +namespace testing { +namespace { + +std::string ToString(const grpc::string_ref& r) { + return std::string(r.data(), r.size()); +} + +int GetIntValueFromMetadataHelper( + const char* key, + const std::multimap& metadata, + int default_value) { + if (metadata.find(key) != metadata.end()) { + std::istringstream iss(ToString(metadata.find(key)->second)); + iss >> default_value; + } + + return default_value; +} + +int GetIntValueFromMetadata( + const char* key, + const std::multimap& metadata, + int default_value) { + return GetIntValueFromMetadataHelper(key, metadata, default_value); +} +} // namespace + +ServerUnaryReactor* CallbackStreamingTestService::Echo( + CallbackServerContext* context, const EchoRequest* /*request*/, + EchoResponse* response) { + int response_msgs_size = GetIntValueFromMetadata( + kServerMessageSize, context->client_metadata(), 0); + if (response_msgs_size > 0) { + response->set_message(std::string(response_msgs_size, 'a')); + } else { + response->set_message(""); + } + auto* reactor = context->DefaultReactor(); + reactor->Finish(::grpc::Status::OK); + return reactor; +} + +ServerBidiReactor* +CallbackStreamingTestService::BidiStream(CallbackServerContext* context) { + class Reactor : public ServerBidiReactor { + public: + explicit Reactor(CallbackServerContext* context) { + message_size_ = GetIntValueFromMetadata(kServerMessageSize, + context->client_metadata(), 0); + StartRead(&request_); + } + void OnDone() override { + GPR_ASSERT(finished_); + delete this; + } + void OnCancel() override {} + void OnReadDone(bool ok) override { + if (!ok) { + // Stream is over + Finish(::grpc::Status::OK); + finished_ = true; + return; + } + if (message_size_ > 0) { + response_.set_message(std::string(message_size_, 'a')); + } else { + response_.set_message(""); + } + StartWrite(&response_); + } + void OnWriteDone(bool ok) override { + if (!ok) { + gpr_log(GPR_ERROR, "Server write failed"); + return; + } + StartRead(&request_); + } + + private: + EchoRequest request_; + EchoResponse response_; + int message_size_; + bool finished_{false}; + }; + + return new Reactor(context); +} +} // namespace testing +} // namespace grpc diff --git a/test/cpp/microbenchmarks/helpers.cc b/test/cpp/microbenchmarks/helpers.cc new file mode 100644 index 00000000..d46321c2 --- /dev/null +++ b/test/cpp/microbenchmarks/helpers.cc @@ -0,0 +1,109 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/microbenchmarks/helpers.h" + +#include + +static grpc::internal::GrpcLibraryInitializer g_gli_initializer; +static LibraryInitializer* g_libraryInitializer; + +LibraryInitializer::LibraryInitializer() { + GPR_ASSERT(g_libraryInitializer == nullptr); + g_libraryInitializer = this; + + g_gli_initializer.summon(); +#ifdef GPR_LOW_LEVEL_COUNTERS + grpc_memory_counters_init(); +#endif + init_lib_.init(); +} + +LibraryInitializer::~LibraryInitializer() { + g_libraryInitializer = nullptr; + init_lib_.shutdown(); +} + +LibraryInitializer& LibraryInitializer::get() { + GPR_ASSERT(g_libraryInitializer != nullptr); + return *g_libraryInitializer; +} + +void TrackCounters::Finish(benchmark::State& state) { + std::ostringstream out; + for (const auto& l : labels_) { + out << l << ' '; + } + AddToLabel(out, state); + std::string label = out.str(); + if (label.length() && label[0] == ' ') { + label = label.substr(1); + } + state.SetLabel(label.c_str()); +} + +void TrackCounters::AddLabel(const std::string& label) { + labels_.push_back(label); +} + +void TrackCounters::AddToLabel(std::ostream& out, benchmark::State& state) { + // Use the parameters to avoid unused-parameter warnings depending on the + // #define's present + (void)out; + (void)state; +#ifdef GRPC_COLLECT_STATS + grpc_stats_data stats_end; + grpc_stats_collect(&stats_end); + grpc_stats_data stats; + grpc_stats_diff(&stats_end, &stats_begin_, &stats); + for (int i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + out << " " << grpc_stats_counter_name[i] << "/iter:" + << (static_cast(stats.counters[i]) / + static_cast(state.iterations())); + } + for (int i = 0; i < GRPC_STATS_HISTOGRAM_COUNT; i++) { + out << " " << grpc_stats_histogram_name[i] << "-median:" + << grpc_stats_histo_percentile(&stats, (grpc_stats_histograms)i, 50.0) + << " " << grpc_stats_histogram_name[i] << "-99p:" + << grpc_stats_histo_percentile(&stats, (grpc_stats_histograms)i, 99.0); + } +#endif +#ifdef GPR_LOW_LEVEL_COUNTERS + grpc_memory_counters counters_at_end = grpc_memory_counters_snapshot(); + out << " locks/iter:" + << ((double)(gpr_atm_no_barrier_load(&gpr_mu_locks) - + mu_locks_at_start_) / + (double)state.iterations()) + << " atm_cas/iter:" + << ((double)(gpr_atm_no_barrier_load(&gpr_counter_atm_cas) - + atm_cas_at_start_) / + (double)state.iterations()) + << " atm_add/iter:" + << ((double)(gpr_atm_no_barrier_load(&gpr_counter_atm_add) - + atm_add_at_start_) / + (double)state.iterations()) + << " nows/iter:" + << ((double)(gpr_atm_no_barrier_load(&gpr_now_call_count) - + now_calls_at_start_) / + (double)state.iterations()) + << " allocs/iter:" + << ((double)(counters_at_end.total_allocs_absolute - + counters_at_start_.total_allocs_absolute) / + (double)state.iterations()); +#endif +} diff --git a/test/cpp/microbenchmarks/noop-benchmark.cc b/test/cpp/microbenchmarks/noop-benchmark.cc new file mode 100644 index 00000000..49ffbf84 --- /dev/null +++ b/test/cpp/microbenchmarks/noop-benchmark.cc @@ -0,0 +1,30 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* This benchmark exists to ensure that the benchmark integration is + * working */ + +#include + +static void BM_NoOp(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_NoOp); + +BENCHMARK_MAIN(); diff --git a/test/cpp/naming/address_sorting_test.cc b/test/cpp/naming/address_sorting_test.cc new file mode 100644 index 00000000..33603665 --- /dev/null +++ b/test/cpp/naming/address_sorting_test.cc @@ -0,0 +1,843 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/resolver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" +#include "test/cpp/util/test_config.h" + +#ifndef GPR_WINDOWS +#include +#include +#include +#endif + +namespace { + +struct TestAddress { + std::string dest_addr; + int family; +}; + +grpc_resolved_address TestAddressToGrpcResolvedAddress(TestAddress test_addr) { + std::string host; + std::string port; + grpc_resolved_address resolved_addr; + grpc_core::SplitHostPort(test_addr.dest_addr.c_str(), &host, &port); + if (test_addr.family == AF_INET) { + sockaddr_in in_dest; + memset(&in_dest, 0, sizeof(sockaddr_in)); + in_dest.sin_port = htons(atoi(port.c_str())); + in_dest.sin_family = AF_INET; + GPR_ASSERT(inet_pton(AF_INET, host.c_str(), &in_dest.sin_addr) == 1); + memcpy(&resolved_addr.addr, &in_dest, sizeof(sockaddr_in)); + resolved_addr.len = sizeof(sockaddr_in); + } else { + GPR_ASSERT(test_addr.family == AF_INET6); + sockaddr_in6 in6_dest; + memset(&in6_dest, 0, sizeof(sockaddr_in6)); + in6_dest.sin6_port = htons(atoi(port.c_str())); + in6_dest.sin6_family = AF_INET6; + GPR_ASSERT(inet_pton(AF_INET6, host.c_str(), &in6_dest.sin6_addr) == 1); + memcpy(&resolved_addr.addr, &in6_dest, sizeof(sockaddr_in6)); + resolved_addr.len = sizeof(sockaddr_in6); + } + return resolved_addr; +} + +class MockSourceAddrFactory : public address_sorting_source_addr_factory { + public: + MockSourceAddrFactory( + bool ipv4_supported, bool ipv6_supported, 
+ const std::map& dest_addr_to_src_addr) + : ipv4_supported_(ipv4_supported), + ipv6_supported_(ipv6_supported), + dest_addr_to_src_addr_(dest_addr_to_src_addr) {} + + bool GetSourceAddr(const address_sorting_address* dest_addr, + address_sorting_address* source_addr) { + if ((address_sorting_abstract_get_family(dest_addr) == + ADDRESS_SORTING_AF_INET && + !ipv4_supported_) || + (address_sorting_abstract_get_family(dest_addr) == + ADDRESS_SORTING_AF_INET6 && + !ipv6_supported_)) { + return false; + } + grpc_resolved_address dest_addr_as_resolved_addr; + memcpy(&dest_addr_as_resolved_addr.addr, dest_addr, dest_addr->len); + dest_addr_as_resolved_addr.len = dest_addr->len; + std::string ip_addr_str = grpc_sockaddr_to_string( + &dest_addr_as_resolved_addr, false /* normalize */); + auto it = dest_addr_to_src_addr_.find(ip_addr_str); + if (it == dest_addr_to_src_addr_.end()) { + gpr_log(GPR_DEBUG, "can't find |%s| in dest to src map", + ip_addr_str.c_str()); + return false; + } + grpc_resolved_address source_addr_as_resolved_addr = + TestAddressToGrpcResolvedAddress(it->second); + memcpy(source_addr->addr, &source_addr_as_resolved_addr.addr, + source_addr_as_resolved_addr.len); + source_addr->len = source_addr_as_resolved_addr.len; + return true; + } + + private: + // user provided test config + bool ipv4_supported_; + bool ipv6_supported_; + std::map dest_addr_to_src_addr_; +}; + +static bool mock_source_addr_factory_wrapper_get_source_addr( + address_sorting_source_addr_factory* factory, + const address_sorting_address* dest_addr, + address_sorting_address* source_addr) { + MockSourceAddrFactory* mock = + reinterpret_cast(factory); + return mock->GetSourceAddr(dest_addr, source_addr); +} + +void mock_source_addr_factory_wrapper_destroy( + address_sorting_source_addr_factory* factory) { + MockSourceAddrFactory* mock = + reinterpret_cast(factory); + delete mock; +} + +const address_sorting_source_addr_factory_vtable kMockSourceAddrFactoryVtable = + { + mock_source_addr_factory_wrapper_get_source_addr, + mock_source_addr_factory_wrapper_destroy, +}; + +void OverrideAddressSortingSourceAddrFactory( + bool ipv4_supported, bool ipv6_supported, + const std::map& dest_addr_to_src_addr) { + address_sorting_source_addr_factory* factory = new MockSourceAddrFactory( + ipv4_supported, ipv6_supported, dest_addr_to_src_addr); + factory->vtable = &kMockSourceAddrFactoryVtable; + address_sorting_override_source_addr_factory_for_testing(factory); +} + +grpc_core::ServerAddressList BuildLbAddrInputs( + const std::vector& test_addrs) { + grpc_core::ServerAddressList addresses; + for (const auto& addr : test_addrs) { + addresses.emplace_back(TestAddressToGrpcResolvedAddress(addr), nullptr); + } + return addresses; +} + +void VerifyLbAddrOutputs(const grpc_core::ServerAddressList& addresses, + std::vector expected_addrs) { + EXPECT_EQ(addresses.size(), expected_addrs.size()); + for (size_t i = 0; i < addresses.size(); ++i) { + std::string ip_addr_str = + grpc_sockaddr_to_string(&addresses[i].address(), false /* normalize */); + EXPECT_EQ(expected_addrs[i], ip_addr_str); + } +} + +/* We need to run each test case inside of its own + * isolated grpc_init/grpc_shutdown pair, so that + * the "address sorting source addr factory" can be + * restored to its default for each test case. 
*/ +class AddressSortingTest : public ::testing::Test { + protected: + void SetUp() override { grpc_init(); } + void TearDown() override { grpc_shutdown(); } +}; + +/* Tests for rule 1 */ +TEST_F(AddressSortingTest, TestDepriotizesUnreachableAddresses) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"1.2.3.4:443", {"4.3.2.1:443", AF_INET}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"1.2.3.4:443", AF_INET}, + {"5.6.7.8:443", AF_INET}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "1.2.3.4:443", + "5.6.7.8:443", + }); +} + +TEST_F(AddressSortingTest, TestDepriotizesUnsupportedDomainIpv6) { + bool ipv4_supported = true; + bool ipv6_supported = false; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"1.2.3.4:443", {"4.3.2.1:0", AF_INET}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[2607:f8b0:400a:801::1002]:443", AF_INET6}, + {"1.2.3.4:443", AF_INET}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "1.2.3.4:443", + "[2607:f8b0:400a:801::1002]:443", + }); +} + +TEST_F(AddressSortingTest, TestDepriotizesUnsupportedDomainIpv4) { + bool ipv4_supported = false; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"1.2.3.4:443", {"4.3.2.1:0", AF_INET}}, + {"[2607:f8b0:400a:801::1002]:443", {"[fec0::1234]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[2607:f8b0:400a:801::1002]:443", AF_INET6}, + {"1.2.3.4:443", AF_INET}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[2607:f8b0:400a:801::1002]:443", + "1.2.3.4:443", + }); +} + +/* Tests for rule 2 */ + +TEST_F(AddressSortingTest, TestDepriotizesNonMatchingScope) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[2000:f8b0:400a:801::1002]:443", + {"[fec0::1000]:0", AF_INET6}}, // global and site-local scope + {"[fec0::5000]:443", + {"[fec0::5001]:0", AF_INET6}}, // site-local and site-local scope + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[2000:f8b0:400a:801::1002]:443", AF_INET6}, + {"[fec0::5000]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[fec0::5000]:443", + "[2000:f8b0:400a:801::1002]:443", + }); +} + +/* Tests for rule 5 */ + +TEST_F(AddressSortingTest, TestUsesLabelFromDefaultTable) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[2002::5001]:443", {"[2001::5002]:0", AF_INET6}}, + {"[2001::5001]:443", + {"[2001::5002]:0", AF_INET6}}, // matching labels + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[2002::5001]:443", AF_INET6}, + {"[2001::5001]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[2001::5001]:443", + "[2002::5001]:443", + }); +} + +/* Flip the input on the test above to reorder the sort function's + * comparator's inputs. 
*/ +TEST_F(AddressSortingTest, TestUsesLabelFromDefaultTableInputFlipped) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[2002::5001]:443", {"[2001::5002]:0", AF_INET6}}, + {"[2001::5001]:443", + {"[2001::5002]:0", AF_INET6}}, // matching labels + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[2001::5001]:443", AF_INET6}, + {"[2002::5001]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[2001::5001]:443", + "[2002::5001]:443", + }); +} + +/* Tests for rule 6 */ + +TEST_F(AddressSortingTest, + TestUsesDestinationWithHigherPrecedenceWithAnIpv4Address) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[3ffe::5001]:443", {"[3ffe::5002]:0", AF_INET6}}, + {"1.2.3.4:443", {"5.6.7.8:0", AF_INET}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe::5001]:443", AF_INET6}, + {"1.2.3.4:443", AF_INET}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs( + lb_addrs, { + // The AF_INET address should be IPv4-mapped by the sort, + // and IPv4-mapped + // addresses have higher precedence than 3ffe::/16 by spec. + "1.2.3.4:443", + "[3ffe::5001]:443", + }); +} + +TEST_F(AddressSortingTest, + TestUsesDestinationWithHigherPrecedenceWithV4CompatAndLocalhostAddress) { + bool ipv4_supported = true; + bool ipv6_supported = true; +// Handle unique observed behavior of inet_ntop(v4-compatible-address) on OS X. +#if GPR_APPLE == 1 + const char* v4_compat_dest = "[::0.0.0.2]:443"; + const char* v4_compat_src = "[::0.0.0.2]:0"; +#else + const char* v4_compat_dest = "[::2]:443"; + const char* v4_compat_src = "[::2]:0"; +#endif + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[::1]:443", {"[::1]:0", AF_INET6}}, + {v4_compat_dest, {v4_compat_src, AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {v4_compat_dest, AF_INET6}, + {"[::1]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[::1]:443", + v4_compat_dest, + }); +} + +TEST_F(AddressSortingTest, + TestUsesDestinationWithHigherPrecedenceWithCatchAllAndLocalhostAddress) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + // 1234::2 for src and dest to make sure that prefix matching has no + // influence on this test. 
+ {"[1234::2]:443", {"[1234::2]:0", AF_INET6}}, + {"[::1]:443", {"[::1]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[1234::2]:443", AF_INET6}, + {"[::1]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs( + lb_addrs, + { + // ::1 should match the localhost precedence entry and be prioritized + "[::1]:443", + "[1234::2]:443", + }); +} + +TEST_F(AddressSortingTest, + TestUsesDestinationWithHigherPrecedenceWith2000PrefixedAddress) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[2001::1234]:443", {"[2001::5678]:0", AF_INET6}}, + {"[2000::5001]:443", {"[2000::5002]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[2001::1234]:443", AF_INET6}, + {"[2000::5001]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs( + lb_addrs, { + // The 2000::/16 address should match the ::/0 prefix rule + "[2000::5001]:443", + "[2001::1234]:443", + }); +} + +TEST_F( + AddressSortingTest, + TestUsesDestinationWithHigherPrecedenceWith2000PrefixedAddressEnsurePrefixMatchHasNoEffect) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[2001::1231]:443", {"[2001::1232]:0", AF_INET6}}, + {"[2000::5001]:443", {"[2000::5002]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[2001::1231]:443", AF_INET6}, + {"[2000::5001]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[2000::5001]:443", + "[2001::1231]:443", + }); +} + +TEST_F(AddressSortingTest, + TestUsesDestinationWithHigherPrecedenceWithLinkAndSiteLocalAddresses) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[fec0::1234]:443", {"[fec0::5678]:0", AF_INET6}}, + {"[fc00::5001]:443", {"[fc00::5002]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[fec0::1234]:443", AF_INET6}, + {"[fc00::5001]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[fc00::5001]:443", + "[fec0::1234]:443", + }); +} + +TEST_F( + AddressSortingTest, + TestUsesDestinationWithHigherPrecedenceWithCatchAllAndAndV4MappedAddresses) { + bool ipv4_supported = true; + bool ipv6_supported = true; + // Use embedded ipv4 addresses with leading 1's instead of zero's to be + // compatible with inet_ntop implementations that can display such + // addresses with leading zero's as e.g.: "::ffff:0:2", as on windows. + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[::ffff:1.1.1.2]:443", {"[::ffff:1.1.1.3]:0", AF_INET6}}, + {"[1234::2]:443", {"[1234::3]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[::ffff:1.1.1.2]:443", AF_INET6}, + {"[1234::2]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + // ::ffff:0:2 should match the v4-mapped + // precedence entry and be deprioritized. 
+ "[1234::2]:443", + "[::ffff:1.1.1.2]:443", + }); +} + +/* Tests for rule 8 */ + +TEST_F(AddressSortingTest, TestPrefersSmallerScope) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + // Both of these destinations have the same precedence in default + // policy + // table. + {"[fec0::1234]:443", {"[fec0::5678]:0", AF_INET6}}, + {"[3ffe::5001]:443", {"[3ffe::5002]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe::5001]:443", AF_INET6}, + {"[fec0::1234]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[fec0::1234]:443", + "[3ffe::5001]:443", + }); +} + +/* Tests for rule 9 */ + +TEST_F(AddressSortingTest, TestPrefersLongestMatchingSrcDstPrefix) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + // Both of these destinations have the same precedence in default + // policy + // table. + {"[3ffe:1234::]:443", {"[3ffe:1235::]:0", AF_INET6}}, + {"[3ffe:5001::]:443", {"[3ffe:4321::]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe:5001::]:443", AF_INET6}, + {"[3ffe:1234::]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe:1234::]:443", + "[3ffe:5001::]:443", + }); +} + +TEST_F(AddressSortingTest, + TestPrefersLongestMatchingSrcDstPrefixMatchesWholeAddress) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[3ffe::1234]:443", {"[3ffe::1235]:0", AF_INET6}}, + {"[3ffe::5001]:443", {"[3ffe::4321]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe::5001]:443", AF_INET6}, + {"[3ffe::1234]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe::1234]:443", + "[3ffe::5001]:443", + }); +} + +TEST_F(AddressSortingTest, TestPrefersLongestPrefixStressInnerBytePrefix) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[3ffe:8000::]:443", {"[3ffe:C000::]:0", AF_INET6}}, + {"[3ffe:2000::]:443", {"[3ffe:3000::]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe:8000::]:443", AF_INET6}, + {"[3ffe:2000::]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe:2000::]:443", + "[3ffe:8000::]:443", + }); +} + +TEST_F(AddressSortingTest, TestPrefersLongestPrefixDiffersOnHighestBitOfByte) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[3ffe:6::]:443", {"[3ffe:8::]:0", AF_INET6}}, + {"[3ffe:c::]:443", {"[3ffe:8::]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe:6::]:443", AF_INET6}, + {"[3ffe:c::]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe:c::]:443", + "[3ffe:6::]:443", + }); +} + +TEST_F(AddressSortingTest, TestPrefersLongestPrefixDiffersByLastBit) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[3ffe:1111:1111:1111::]:443", + {"[3ffe:1111:1111:1111::]:0", AF_INET6}}, + 
{"[3ffe:1111:1111:1110::]:443", + {"[3ffe:1111:1111:1111::]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe:1111:1111:1110::]:443", AF_INET6}, + {"[3ffe:1111:1111:1111::]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe:1111:1111:1111::]:443", + "[3ffe:1111:1111:1110::]:443", + }); +} + +/* Tests for rule 10 */ + +TEST_F(AddressSortingTest, TestStableSort) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[3ffe::1234]:443", {"[3ffe::1236]:0", AF_INET6}}, + {"[3ffe::1235]:443", {"[3ffe::1237]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe::1234]:443", AF_INET6}, + {"[3ffe::1235]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe::1234]:443", + "[3ffe::1235]:443", + }); +} + +TEST_F(AddressSortingTest, TestStableSortFiveElements) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[3ffe::1231]:443", {"[3ffe::1201]:0", AF_INET6}}, + {"[3ffe::1232]:443", {"[3ffe::1202]:0", AF_INET6}}, + {"[3ffe::1233]:443", {"[3ffe::1203]:0", AF_INET6}}, + {"[3ffe::1234]:443", {"[3ffe::1204]:0", AF_INET6}}, + {"[3ffe::1235]:443", {"[3ffe::1205]:0", AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe::1231]:443", AF_INET6}, + {"[3ffe::1232]:443", AF_INET6}, + {"[3ffe::1233]:443", AF_INET6}, + {"[3ffe::1234]:443", AF_INET6}, + {"[3ffe::1235]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe::1231]:443", + "[3ffe::1232]:443", + "[3ffe::1233]:443", + "[3ffe::1234]:443", + "[3ffe::1235]:443", + }); +} + +TEST_F(AddressSortingTest, TestStableSortNoSrcAddrsExist) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory(ipv4_supported, ipv6_supported, {}); + auto lb_addrs = BuildLbAddrInputs({ + {"[3ffe::1231]:443", AF_INET6}, + {"[3ffe::1232]:443", AF_INET6}, + {"[3ffe::1233]:443", AF_INET6}, + {"[3ffe::1234]:443", AF_INET6}, + {"[3ffe::1235]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[3ffe::1231]:443", + "[3ffe::1232]:443", + "[3ffe::1233]:443", + "[3ffe::1234]:443", + "[3ffe::1235]:443", + }); +} + +TEST_F(AddressSortingTest, TestStableSortNoSrcAddrsExistWithIpv4) { + bool ipv4_supported = true; + bool ipv6_supported = true; + OverrideAddressSortingSourceAddrFactory(ipv4_supported, ipv6_supported, {}); + auto lb_addrs = BuildLbAddrInputs({ + {"[::ffff:5.6.7.8]:443", AF_INET6}, + {"1.2.3.4:443", AF_INET}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[::ffff:5.6.7.8]:443", + "1.2.3.4:443", + }); +} + +TEST_F(AddressSortingTest, TestStableSortV4CompatAndSiteLocalAddresses) { + bool ipv4_supported = true; + bool ipv6_supported = true; +// Handle unique observed behavior of inet_ntop(v4-compatible-address) on OS X. 
+#if GPR_APPLE == 1 + const char* v4_compat_dest = "[::0.0.0.2]:443"; + const char* v4_compat_src = "[::0.0.0.3]:0"; +#else + const char* v4_compat_dest = "[::2]:443"; + const char* v4_compat_src = "[::3]:0"; +#endif + OverrideAddressSortingSourceAddrFactory( + ipv4_supported, ipv6_supported, + { + {"[fec0::2000]:443", {"[fec0::2001]:0", AF_INET6}}, + {v4_compat_dest, {v4_compat_src, AF_INET6}}, + }); + auto lb_addrs = BuildLbAddrInputs({ + {"[fec0::2000]:443", AF_INET6}, + {v4_compat_dest, AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, + { + // The sort should be stable since + // v4-compatible has same precedence as site-local. + "[fec0::2000]:443", + v4_compat_dest, + }); +} + +/* TestPrefersIpv6Loopback tests the actual "address probing" code + * for the current platform, without any mocks. + * This test relies on the assumption that the ipv6 loopback address is + * available in the hosts/containers that grpc C/C++ tests run on + * (whether ipv4 loopback is available or not, an available ipv6 + * loopback should be preferred). */ +TEST_F(AddressSortingTest, TestPrefersIpv6Loopback) { + auto lb_addrs = BuildLbAddrInputs({ + {"[::1]:443", AF_INET6}, + {"127.0.0.1:443", AF_INET}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[::1]:443", + "127.0.0.1:443", + }); +} + +/* Flip the order of the inputs above and expect the same output order + * (try to rule out influence of arbitrary qsort ordering) */ +TEST_F(AddressSortingTest, TestPrefersIpv6LoopbackInputsFlipped) { + auto lb_addrs = BuildLbAddrInputs({ + {"127.0.0.1:443", AF_INET}, + {"[::1]:443", AF_INET6}, + }); + grpc_cares_wrapper_address_sorting_sort(nullptr, &lb_addrs); + VerifyLbAddrOutputs(lb_addrs, { + "[::1]:443", + "127.0.0.1:443", + }); +} + +/* Try to rule out false positives in the above two tests in which + * the sorter might think that neither ipv6 or ipv4 loopback is + * available, but ipv6 loopback is still preferred only due + * to precedence table lookups. */ +TEST_F(AddressSortingTest, TestSorterKnowsIpv6LoopbackIsAvailable) { + sockaddr_in6 ipv6_loopback; + memset(&ipv6_loopback, 0, sizeof(ipv6_loopback)); + ipv6_loopback.sin6_family = AF_INET6; + (reinterpret_cast(&ipv6_loopback.sin6_addr))[15] = 1; + ipv6_loopback.sin6_port = htons(443); + // Set up the source and destination parameters of + // address_sorting_get_source_addr + address_sorting_address sort_input_dest; + memcpy(&sort_input_dest.addr, &ipv6_loopback, sizeof(ipv6_loopback)); + sort_input_dest.len = sizeof(ipv6_loopback); + address_sorting_address source_for_sort_input_dest; + memset(&source_for_sort_input_dest, 0, sizeof(source_for_sort_input_dest)); + // address_sorting_get_source_addr returns true if a source address was found + // for the destination address, otherwise false. + EXPECT_TRUE(address_sorting_get_source_addr_for_testing( + &sort_input_dest, &source_for_sort_input_dest)); + // Now also check that the source address was filled in correctly. + EXPECT_GT(source_for_sort_input_dest.len, 0u); + sockaddr_in6* source_addr_output = + reinterpret_cast(source_for_sort_input_dest.addr); + EXPECT_EQ(source_addr_output->sin6_family, AF_INET6); + char* buf = static_cast(gpr_zalloc(100)); + EXPECT_NE(inet_ntop(AF_INET6, &source_addr_output->sin6_addr, buf, 100), + nullptr) + << "inet_ntop failed. 
Errno: " + std::to_string(errno); + std::string source_addr_str(buf); + gpr_free(buf); + // This test + // assumes that the source address for any loopback destination is also the + // loopback address. + EXPECT_EQ(source_addr_str, "::1"); +} + +} // namespace + +int main(int argc, char** argv) { + grpc_core::UniquePtr resolver = + GPR_GLOBAL_CONFIG_GET(grpc_dns_resolver); + if (strlen(resolver.get()) == 0) { + GPR_GLOBAL_CONFIG_SET(grpc_dns_resolver, "ares"); + } else if (strcmp("ares", resolver.get()) != 0) { + gpr_log(GPR_INFO, "GRPC_DNS_RESOLVER != ares: %s.", resolver.get()); + } + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + auto result = RUN_ALL_TESTS(); + // Test sequential and nested inits and shutdowns. + grpc_init(); + grpc_init(); + grpc_shutdown(); + grpc_shutdown(); + grpc_init(); + grpc_shutdown(); + return result; +} diff --git a/test/cpp/naming/cancel_ares_query_test.cc b/test/cpp/naming/cancel_ares_query_test.cc new file mode 100644 index 00000000..ff9ce382 --- /dev/null +++ b/test/cpp/naming/cancel_ares_query_test.cc @@ -0,0 +1,414 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/resolver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/dns_resolver_selection.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/debug/stats.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "test/core/end2end/cq_verifier.h" +#include "test/core/util/cmdline.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/naming/dns_test_util.h" + +#ifdef GPR_WINDOWS +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" +#define BAD_SOCKET_RETURN_VAL INVALID_SOCKET +#else +#include "src/core/lib/iomgr/sockaddr_posix.h" +#define BAD_SOCKET_RETURN_VAL (-1) +#endif + +namespace { + +void* Tag(intptr_t t) { return reinterpret_cast(t); } + +gpr_timespec FiveSecondsFromNow(void) { + return grpc_timeout_seconds_to_deadline(5); +} + +void DrainCq(grpc_completion_queue* cq) { + grpc_event ev; + do { + ev = grpc_completion_queue_next(cq, FiveSecondsFromNow(), nullptr); + } while (ev.type != GRPC_QUEUE_SHUTDOWN); +} + +void EndTest(grpc_channel* client, grpc_completion_queue* cq) { + grpc_channel_destroy(client); + grpc_completion_queue_shutdown(cq); + DrainCq(cq); + grpc_completion_queue_destroy(cq); +} + +struct ArgsStruct { + gpr_atm done_atm; + gpr_mu* mu; + grpc_pollset* pollset; + 
grpc_pollset_set* pollset_set; + std::shared_ptr lock; + grpc_channel_args* channel_args; +}; + +void ArgsInit(ArgsStruct* args) { + args->pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(args->pollset, &args->mu); + args->pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(args->pollset_set, args->pollset); + args->lock = std::make_shared(); + gpr_atm_rel_store(&args->done_atm, 0); + args->channel_args = nullptr; +} + +void DoNothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +void ArgsFinish(ArgsStruct* args) { + grpc_pollset_set_del_pollset(args->pollset_set, args->pollset); + grpc_pollset_set_destroy(args->pollset_set); + grpc_closure DoNothing_cb; + GRPC_CLOSURE_INIT(&DoNothing_cb, DoNothing, nullptr, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(args->pollset, &DoNothing_cb); + // exec_ctx needs to be flushed before calling grpc_pollset_destroy() + grpc_channel_args_destroy(args->channel_args); + grpc_core::ExecCtx::Get()->Flush(); + grpc_pollset_destroy(args->pollset); + gpr_free(args->pollset); +} + +void PollPollsetUntilRequestDone(ArgsStruct* args) { + while (true) { + bool done = gpr_atm_acq_load(&args->done_atm) != 0; + if (done) { + break; + } + grpc_pollset_worker* worker = nullptr; + grpc_core::ExecCtx exec_ctx; + gpr_mu_lock(args->mu); + GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(args->pollset, &worker, + grpc_timespec_to_millis_round_up( + gpr_inf_future(GPR_CLOCK_REALTIME)))); + gpr_mu_unlock(args->mu); + } +} + +class AssertFailureResultHandler : public grpc_core::Resolver::ResultHandler { + public: + explicit AssertFailureResultHandler(ArgsStruct* args) : args_(args) {} + + ~AssertFailureResultHandler() override { + gpr_atm_rel_store(&args_->done_atm, 1); + gpr_mu_lock(args_->mu); + GRPC_LOG_IF_ERROR("pollset_kick", + grpc_pollset_kick(args_->pollset, nullptr)); + gpr_mu_unlock(args_->mu); + } + + void ReturnResult(grpc_core::Resolver::Result /*result*/) override { + GPR_ASSERT(false); + } + + void ReturnError(grpc_error_handle /*error*/) override { GPR_ASSERT(false); } + + private: + ArgsStruct* args_; +}; + +void TestCancelActiveDNSQuery(ArgsStruct* args) { + int fake_dns_port = grpc_pick_unused_port_or_die(); + grpc::testing::FakeNonResponsiveDNSServer fake_dns_server(fake_dns_port); + std::string client_target = absl::StrFormat( + "dns://[::1]:%d/dont-care-since-wont-be-resolved.test.com:1234", + fake_dns_port); + // create resolver and resolve + grpc_core::OrphanablePtr resolver = + grpc_core::ResolverRegistry::CreateResolver( + client_target.c_str(), nullptr, args->pollset_set, args->lock, + std::unique_ptr( + new AssertFailureResultHandler(args))); + resolver->StartLocked(); + // Without resetting and causing resolver shutdown, the + // PollPollsetUntilRequestDone call should never finish. + resolver.reset(); + grpc_core::ExecCtx::Get()->Flush(); + PollPollsetUntilRequestDone(args); + ArgsFinish(args); +} + +class CancelDuringAresQuery : public ::testing::Test { + protected: + static void SetUpTestCase() { + GPR_GLOBAL_CONFIG_SET(grpc_dns_resolver, "ares"); + // Sanity check the time that it takes to run the test + // including the teardown time (the teardown + // part of the test involves cancelling the DNS query, + // which is the main point of interest for this test). 
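+ // TearDownTestCase() aborts if this 4 second budget is exceeded, so a
+ // cancellation that leaves the c-ares query running (which would otherwise
+ // wait out the much longer default DNS resolution timeout) fails the test.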
+ overall_deadline = grpc_timeout_seconds_to_deadline(4); + grpc_init(); + } + + static void TearDownTestCase() { + grpc_shutdown(); + if (gpr_time_cmp(gpr_now(GPR_CLOCK_MONOTONIC), overall_deadline) > 0) { + gpr_log(GPR_ERROR, "Test took too long"); + abort(); + } + } + + private: + static gpr_timespec overall_deadline; +}; +gpr_timespec CancelDuringAresQuery::overall_deadline; + +TEST_F(CancelDuringAresQuery, TestCancelActiveDNSQuery) { + grpc_core::ExecCtx exec_ctx; + ArgsStruct args; + ArgsInit(&args); + TestCancelActiveDNSQuery(&args); +} + +#ifdef GPR_WINDOWS + +void MaybePollArbitraryPollsetTwice() { + grpc_pollset* pollset = (grpc_pollset*)gpr_zalloc(grpc_pollset_size()); + gpr_mu* mu; + grpc_pollset_init(pollset, &mu); + grpc_pollset_worker* worker = nullptr; + // Make a zero timeout poll + gpr_mu_lock(mu); + GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(pollset, &worker, grpc_core::ExecCtx::Get()->Now())); + gpr_mu_unlock(mu); + grpc_core::ExecCtx::Get()->Flush(); + // Make a second zero-timeout poll (in case the first one + // short-circuited by picking up a previous "kick") + gpr_mu_lock(mu); + GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(pollset, &worker, grpc_core::ExecCtx::Get()->Now())); + gpr_mu_unlock(mu); + grpc_core::ExecCtx::Get()->Flush(); + grpc_pollset_destroy(pollset); + gpr_free(pollset); +} + +#else + +void MaybePollArbitraryPollsetTwice() {} + +#endif + +TEST_F(CancelDuringAresQuery, TestFdsAreDeletedFromPollsetSet) { + grpc_core::ExecCtx exec_ctx; + ArgsStruct args; + ArgsInit(&args); + // Add fake_other_pollset_set into the mix to test + // that we're explicitly deleting fd's from their pollset. + // If we aren't doing so, then the remaining presence of + // "fake_other_pollset_set" after the request is done and the resolver + // pollset set is destroyed should keep the resolver's fd alive and + // fail the test. + grpc_pollset_set* fake_other_pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset_set(fake_other_pollset_set, args.pollset_set); + // Note that running the cancellation c-ares test is somewhat irrelevant for + // this test. This test only cares about what happens to fd's that c-ares + // opens. + TestCancelActiveDNSQuery(&args); + // This test relies on the assumption that cancelling a c-ares query + // will flush out all callbacks on the current exec ctx, which is true + // on posix platforms but not on Windows, because fd shutdown on Windows + // requires a trip through the polling loop to schedule the callback. + // So we need to do extra polling work on Windows to free things up. + MaybePollArbitraryPollsetTwice(); + EXPECT_EQ(grpc_iomgr_count_objects_for_testing(), 0u); + grpc_pollset_set_destroy(fake_other_pollset_set); +} + +// Settings for TestCancelDuringActiveQuery test +typedef enum { + NONE, + SHORT, + ZERO, +} cancellation_test_query_timeout_setting; + +void TestCancelDuringActiveQuery( + cancellation_test_query_timeout_setting query_timeout_setting) { + // Start up fake non responsive DNS server + int fake_dns_port = grpc_pick_unused_port_or_die(); + grpc::testing::FakeNonResponsiveDNSServer fake_dns_server(fake_dns_port); + // Create a call that will try to use the fake DNS server + std::string client_target = absl::StrFormat( + "dns://[::1]:%d/dont-care-since-wont-be-resolved.test.com:1234", + fake_dns_port); + gpr_log(GPR_DEBUG, "TestCancelActiveDNSQuery. 
query timeout setting: %d", + query_timeout_setting); + grpc_channel_args* client_args = nullptr; + grpc_status_code expected_status_code = GRPC_STATUS_OK; + gpr_timespec rpc_deadline; + if (query_timeout_setting == NONE) { + // The RPC deadline should go off well before the DNS resolution + // timeout fires. + expected_status_code = GRPC_STATUS_DEADLINE_EXCEEDED; + // use default DNS resolution timeout (which is over one minute). + client_args = nullptr; + rpc_deadline = grpc_timeout_milliseconds_to_deadline(100); + } else if (query_timeout_setting == SHORT) { + // The DNS resolution timeout should fire well before the + // RPC's deadline expires. + expected_status_code = GRPC_STATUS_UNAVAILABLE; + grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + arg.key = const_cast(GRPC_ARG_DNS_ARES_QUERY_TIMEOUT_MS); + arg.value.integer = + 1; // Set this shorter than the call deadline so that it goes off. + client_args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + // Set the deadline high enough such that if we hit this and get + // a deadline exceeded status code, then we are confident that there's + // a bug causing cancellation of DNS resolutions to not happen in a timely + // manner. + rpc_deadline = grpc_timeout_seconds_to_deadline(10); + } else if (query_timeout_setting == ZERO) { + // The RPC deadline should go off well before the DNS resolution + // timeout fires. + expected_status_code = GRPC_STATUS_DEADLINE_EXCEEDED; + grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + arg.key = const_cast(GRPC_ARG_DNS_ARES_QUERY_TIMEOUT_MS); + arg.value.integer = 0; // Set this to zero to disable query timeouts. + client_args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + rpc_deadline = grpc_timeout_milliseconds_to_deadline(100); + } else { + abort(); + } + grpc_channel* client = + grpc_insecure_channel_create(client_target.c_str(), client_args, nullptr); + grpc_completion_queue* cq = grpc_completion_queue_create_for_next(nullptr); + cq_verifier* cqv = cq_verifier_create(cq); + grpc_call* call = grpc_channel_create_call( + client, nullptr, GRPC_PROPAGATE_DEFAULTS, cq, + grpc_slice_from_static_string("/foo"), nullptr, rpc_deadline, nullptr); + GPR_ASSERT(call); + grpc_metadata_array initial_metadata_recv; + grpc_metadata_array trailing_metadata_recv; + grpc_metadata_array request_metadata_recv; + grpc_metadata_array_init(&initial_metadata_recv); + grpc_metadata_array_init(&trailing_metadata_recv); + grpc_metadata_array_init(&request_metadata_recv); + grpc_call_details call_details; + grpc_call_details_init(&call_details); + grpc_status_code status; + const char* error_string; + grpc_slice details; + // Set ops for client the request + grpc_op ops_base[6]; + memset(ops_base, 0, sizeof(ops_base)); + grpc_op* op = ops_base; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->data.send_initial_metadata.count = 0; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv; + op->flags = 0; + op->reserved = nullptr; + op++; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = &trailing_metadata_recv; + op->data.recv_status_on_client.status = &status; + op->data.recv_status_on_client.status_details = &details; + op->data.recv_status_on_client.error_string = &error_string; + op->flags = 0; + op->reserved = nullptr; + op++; + // Run the call and sanity check it 
failed as expected + grpc_call_error error = grpc_call_start_batch( + call, ops_base, static_cast(op - ops_base), Tag(1), nullptr); + EXPECT_EQ(GRPC_CALL_OK, error); + CQ_EXPECT_COMPLETION(cqv, Tag(1), 1); + cq_verify(cqv); + EXPECT_EQ(status, expected_status_code); + // Teardown + grpc_channel_args_destroy(client_args); + grpc_slice_unref(details); + gpr_free(const_cast(error_string)); + grpc_metadata_array_destroy(&initial_metadata_recv); + grpc_metadata_array_destroy(&trailing_metadata_recv); + grpc_metadata_array_destroy(&request_metadata_recv); + grpc_call_details_destroy(&call_details); + grpc_call_unref(call); + cq_verifier_destroy(cqv); + EndTest(client, cq); +} + +TEST_F(CancelDuringAresQuery, + TestHitDeadlineAndDestroyChannelDuringAresResolutionIsGraceful) { + TestCancelDuringActiveQuery(NONE /* don't set query timeouts */); +} + +TEST_F( + CancelDuringAresQuery, + TestHitDeadlineAndDestroyChannelDuringAresResolutionWithQueryTimeoutIsGraceful) { + TestCancelDuringActiveQuery(SHORT /* set short query timeout */); +} + +TEST_F( + CancelDuringAresQuery, + TestHitDeadlineAndDestroyChannelDuringAresResolutionWithZeroQueryTimeoutIsGraceful) { + TestCancelDuringActiveQuery(ZERO /* disable query timeouts */); +} + +} // namespace + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/test/cpp/naming/dns_test_util.cc b/test/cpp/naming/dns_test_util.cc new file mode 100644 index 00000000..e0ea96fd --- /dev/null +++ b/test/cpp/naming/dns_test_util.cc @@ -0,0 +1,101 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/naming/dns_test_util.h" + +#include +#include + +#include +#include + +#include "src/core/lib/event_engine/sockaddr.h" + +#ifdef GPR_WINDOWS +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" +#define BAD_SOCKET_RETURN_VAL INVALID_SOCKET +#else +#include "src/core/lib/iomgr/sockaddr_posix.h" +#define BAD_SOCKET_RETURN_VAL (-1) +#endif + +namespace grpc { +namespace testing { + +FakeNonResponsiveDNSServer::FakeNonResponsiveDNSServer(int port) { + udp_socket_ = socket(AF_INET6, SOCK_DGRAM, 0); + tcp_socket_ = socket(AF_INET6, SOCK_STREAM, 0); + if (udp_socket_ == BAD_SOCKET_RETURN_VAL) { + gpr_log(GPR_DEBUG, "Failed to create UDP ipv6 socket"); + abort(); + } + if (tcp_socket_ == BAD_SOCKET_RETURN_VAL) { + gpr_log(GPR_DEBUG, "Failed to create TCP ipv6 socket"); + abort(); + } + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(port); + (reinterpret_cast(&addr.sin6_addr))[15] = 1; + if (bind(udp_socket_, reinterpret_cast(&addr), + sizeof(addr)) != 0) { + gpr_log(GPR_DEBUG, "Failed to bind UDP ipv6 socket to [::1]:%d", port); + abort(); + } +#ifdef GPR_WINDOWS + char val = 1; + if (setsockopt(tcp_socket_, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == + SOCKET_ERROR) { + gpr_log(GPR_DEBUG, + "Failed to set SO_REUSEADDR on TCP ipv6 socket to [::1]:%d", port); + abort(); + } +#else + int val = 1; + if (setsockopt(tcp_socket_, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) != + 0) { + gpr_log(GPR_DEBUG, + "Failed to set SO_REUSEADDR on TCP ipv6 socket to [::1]:%d", port); + abort(); + } +#endif + if (bind(tcp_socket_, reinterpret_cast(&addr), + sizeof(addr)) != 0) { + gpr_log(GPR_DEBUG, "Failed to bind TCP ipv6 socket to [::1]:%d", port); + abort(); + } + if (listen(tcp_socket_, 100)) { + gpr_log(GPR_DEBUG, "Failed to listen on TCP ipv6 socket to [::1]:%d", port); + abort(); + } +} + +FakeNonResponsiveDNSServer::~FakeNonResponsiveDNSServer() { +#ifdef GPR_WINDOWS + closesocket(udp_socket_); + closesocket(tcp_socket_); +#else + close(udp_socket_); + close(tcp_socket_); +#endif +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/naming/resolver_component_test.cc b/test/cpp/naming/resolver_component_test.cc new file mode 100644 index 00000000..a230b5a4 --- /dev/null +++ b/test/cpp/naming/resolver_component_test.cc @@ -0,0 +1,685 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#include +#include +#include + +#include + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/resolver.h" +#include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/address_utils/parse_address.h" +#include "src/core/lib/address_utils/sockaddr_utils.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/work_serializer.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/naming/dns_test_util.h" +#include "test/cpp/util/subprocess.h" +#include "test/cpp/util/test_config.h" + +// TODO(unknown): pull in different headers when enabling this +// test on windows. Also set BAD_SOCKET_RETURN_VAL +// to INVALID_SOCKET on windows. +#ifdef GPR_WINDOWS +#include "src/core/lib/iomgr/sockaddr_windows.h" +#include "src/core/lib/iomgr/socket_windows.h" +#include "src/core/lib/iomgr/tcp_windows.h" +#define BAD_SOCKET_RETURN_VAL INVALID_SOCKET +#else +#include "src/core/lib/iomgr/sockaddr_posix.h" +#define BAD_SOCKET_RETURN_VAL (-1) +#endif + +using std::vector; +using testing::UnorderedElementsAreArray; + +ABSL_FLAG(std::string, target_name, "", "Target name to resolve."); +ABSL_FLAG(std::string, do_ordered_address_comparison, "", + "Whether or not to compare resolved addresses to expected " + "addresses using an ordered comparison. This is useful for " + "testing certain behaviors that involve sorting of resolved " + "addresses. Note it would be better if this argument was a " + "bool flag, but it's a string for ease of invocation from " + "the generated python test runner."); +ABSL_FLAG(std::string, expected_addrs, "", + "List of expected backend or balancer addresses in the form " + "',;,;...'. " + "'is_balancer' should be bool, i.e. true or false."); +ABSL_FLAG(std::string, expected_chosen_service_config, "", + "Expected service config json string that gets chosen (no " + "whitespace). Empty for none."); +ABSL_FLAG(std::string, expected_service_config_error, "", + "Expected service config error. Empty for none."); +ABSL_FLAG(std::string, local_dns_server_address, "", + "Optional. This address is placed as the uri authority if present."); +// TODO(Capstan): Is this worth making `bool` now with Abseil flags? +ABSL_FLAG( + std::string, enable_srv_queries, "", + "Whether or not to enable SRV queries for the ares resolver instance." + "It would be better if this arg could be bool, but the way that we " + "generate " + "the python script runner doesn't allow us to pass a gflags bool to this " + "binary."); +// TODO(Capstan): Is this worth making `bool` now with Abseil flags? 
+ABSL_FLAG( + std::string, enable_txt_queries, "", + "Whether or not to enable TXT queries for the ares resolver instance." + "It would be better if this arg could be bool, but the way that we " + "generate " + "the python script runner doesn't allow us to pass a gflags bool to this " + "binary."); +// TODO(Capstan): Is this worth making `bool` now with Abseil flags? +ABSL_FLAG( + std::string, inject_broken_nameserver_list, "", + "Whether or not to configure c-ares to use a broken nameserver list, in " + "which " + "the first nameserver in the list is non-responsive, but the second one " + "works, i.e " + "serves the expected DNS records; using for testing such a real scenario." + "It would be better if this arg could be bool, but the way that we " + "generate " + "the python script runner doesn't allow us to pass a gflags bool to this " + "binary."); +ABSL_FLAG(std::string, expected_lb_policy, "", + "Expected lb policy name that appears in resolver result channel " + "arg. Empty for none."); + +namespace { + +class GrpcLBAddress final { + public: + GrpcLBAddress(std::string address, bool is_balancer) + : is_balancer(is_balancer), address(std::move(address)) {} + + bool operator==(const GrpcLBAddress& other) const { + return this->is_balancer == other.is_balancer && + this->address == other.address; + } + + bool operator!=(const GrpcLBAddress& other) const { + return !(*this == other); + } + + bool is_balancer; + std::string address; +}; + +vector ParseExpectedAddrs(std::string expected_addrs) { + std::vector out; + while (!expected_addrs.empty()) { + // get the next , (v4 or v6) + size_t next_comma = expected_addrs.find(','); + if (next_comma == std::string::npos) { + gpr_log(GPR_ERROR, + "Missing ','. Expected_addrs arg should be a semicolon-separated " + "list of , pairs. 
Left-to-be-parsed arg is |%s|", + expected_addrs.c_str()); + abort(); + } + std::string next_addr = expected_addrs.substr(0, next_comma); + expected_addrs = expected_addrs.substr(next_comma + 1, std::string::npos); + // get the next is_balancer 'bool' associated with this address + size_t next_semicolon = expected_addrs.find(';'); + bool is_balancer = false; + gpr_parse_bool_value(expected_addrs.substr(0, next_semicolon).c_str(), + &is_balancer); + out.emplace_back(GrpcLBAddress(next_addr, is_balancer)); + if (next_semicolon == std::string::npos) { + break; + } + expected_addrs = + expected_addrs.substr(next_semicolon + 1, std::string::npos); + } + if (out.empty()) { + gpr_log(GPR_ERROR, + "expected_addrs arg should be a semicolon-separated list of " + ", pairs"); + abort(); + } + return out; +} + +gpr_timespec TestDeadline(void) { + return grpc_timeout_seconds_to_deadline(100); +} + +struct ArgsStruct { + gpr_event ev; + gpr_atm done_atm; + gpr_mu* mu; + grpc_pollset* pollset; + grpc_pollset_set* pollset_set; + std::shared_ptr lock; + grpc_channel_args* channel_args; + vector expected_addrs; + std::string expected_service_config_string; + std::string expected_service_config_error; + std::string expected_lb_policy; +}; + +void ArgsInit(ArgsStruct* args) { + gpr_event_init(&args->ev); + args->pollset = static_cast(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(args->pollset, &args->mu); + args->pollset_set = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(args->pollset_set, args->pollset); + args->lock = std::make_shared(); + gpr_atm_rel_store(&args->done_atm, 0); + args->channel_args = nullptr; +} + +void DoNothing(void* /*arg*/, grpc_error_handle /*error*/) {} + +void ArgsFinish(ArgsStruct* args) { + GPR_ASSERT(gpr_event_wait(&args->ev, TestDeadline())); + grpc_pollset_set_del_pollset(args->pollset_set, args->pollset); + grpc_pollset_set_destroy(args->pollset_set); + grpc_closure DoNothing_cb; + GRPC_CLOSURE_INIT(&DoNothing_cb, DoNothing, nullptr, + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(args->pollset, &DoNothing_cb); + // exec_ctx needs to be flushed before calling grpc_pollset_destroy() + grpc_channel_args_destroy(args->channel_args); + grpc_core::ExecCtx::Get()->Flush(); + grpc_pollset_destroy(args->pollset); + gpr_free(args->pollset); +} + +gpr_timespec NSecondDeadline(int seconds) { + return gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(seconds, GPR_TIMESPAN)); +} + +void PollPollsetUntilRequestDone(ArgsStruct* args) { + // Use a 20-second timeout to give room for the tests that involve + // a non-responsive name server (c-ares uses a ~5 second query timeout + // for that server before succeeding with the healthy one). 
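+ // The GPR_ASSERT on time_left below aborts the test once this deadline
+ // passes, rather than letting an unresolved query poll forever.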
+ gpr_timespec deadline = NSecondDeadline(20); + while (true) { + bool done = gpr_atm_acq_load(&args->done_atm) != 0; + if (done) { + break; + } + gpr_timespec time_left = + gpr_time_sub(deadline, gpr_now(GPR_CLOCK_REALTIME)); + gpr_log(GPR_DEBUG, "done=%d, time_left=%" PRId64 ".%09d", done, + time_left.tv_sec, time_left.tv_nsec); + GPR_ASSERT(gpr_time_cmp(time_left, gpr_time_0(GPR_TIMESPAN)) >= 0); + grpc_pollset_worker* worker = nullptr; + grpc_core::ExecCtx exec_ctx; + gpr_mu_lock(args->mu); + GRPC_LOG_IF_ERROR("pollset_work", + grpc_pollset_work(args->pollset, &worker, + grpc_timespec_to_millis_round_up( + NSecondDeadline(1)))); + gpr_mu_unlock(args->mu); + } + gpr_event_set(&args->ev, reinterpret_cast(1)); +} + +void CheckServiceConfigResultLocked(const char* service_config_json, + grpc_error_handle service_config_error, + ArgsStruct* args) { + if (!args->expected_service_config_string.empty()) { + GPR_ASSERT(service_config_json != nullptr); + EXPECT_EQ(service_config_json, args->expected_service_config_string); + } + if (args->expected_service_config_error.empty()) { + EXPECT_EQ(service_config_error, GRPC_ERROR_NONE); + } else { + EXPECT_THAT(grpc_error_std_string(service_config_error), + testing::HasSubstr(args->expected_service_config_error)); + } + GRPC_ERROR_UNREF(service_config_error); +} + +void CheckLBPolicyResultLocked(const grpc_channel_args* channel_args, + ArgsStruct* args) { + const grpc_arg* lb_policy_arg = + grpc_channel_args_find(channel_args, GRPC_ARG_LB_POLICY_NAME); + if (!args->expected_lb_policy.empty()) { + GPR_ASSERT(lb_policy_arg != nullptr); + GPR_ASSERT(lb_policy_arg->type == GRPC_ARG_STRING); + EXPECT_EQ(lb_policy_arg->value.string, args->expected_lb_policy); + } else { + GPR_ASSERT(lb_policy_arg == nullptr); + } +} + +#ifdef GPR_WINDOWS +void OpenAndCloseSocketsStressLoop(int phony_port, gpr_event* done_ev) { + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(phony_port); + ((char*)&addr.sin6_addr)[15] = 1; + for (;;) { + if (gpr_event_get(done_ev)) { + return; + } + std::vector sockets; + for (size_t i = 0; i < 50; i++) { + SOCKET s = WSASocket(AF_INET6, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, + WSA_FLAG_OVERLAPPED); + ASSERT_TRUE(s != BAD_SOCKET_RETURN_VAL) + << "Failed to create TCP ipv6 socket"; + gpr_log(GPR_DEBUG, "Opened socket: %d", s); + char val = 1; + ASSERT_TRUE(setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) != + SOCKET_ERROR) + << "Failed to set socketopt reuseaddr. WSA error: " + + std::to_string(WSAGetLastError()); + ASSERT_TRUE(grpc_tcp_set_non_block(s) == GRPC_ERROR_NONE) + << "Failed to set socket non-blocking"; + ASSERT_TRUE(bind(s, (const sockaddr*)&addr, sizeof(addr)) != SOCKET_ERROR) + << "Failed to bind socket " + std::to_string(s) + + " to [::1]:" + std::to_string(phony_port) + + ". WSA error: " + std::to_string(WSAGetLastError()); + ASSERT_TRUE(listen(s, 1) != SOCKET_ERROR) + << "Failed to listen on socket " + std::to_string(s) + + ". WSA error: " + std::to_string(WSAGetLastError()); + sockets.push_back(s); + } + // Do a non-blocking accept followed by a close on all of those sockets. + // Do this in a separate loop to try to induce a time window to hit races. 
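+ // Nothing ever connects to phony_port, so each accept below is expected to
+ // fail with WSAEWOULDBLOCK; any other outcome points at a socket
+ // use-after-close by the resolver.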
+ for (size_t i = 0; i < sockets.size(); i++) { + gpr_log(GPR_DEBUG, "non-blocking accept then close on %d", sockets[i]); + ASSERT_TRUE(accept(sockets[i], nullptr, nullptr) == INVALID_SOCKET) + << "Accept on phony socket unexpectedly accepted actual connection."; + ASSERT_TRUE(WSAGetLastError() == WSAEWOULDBLOCK) + << "OpenAndCloseSocketsStressLoop accept on socket " + + std::to_string(sockets[i]) + + " failed in " + "an unexpected way. " + "WSA error: " + + std::to_string(WSAGetLastError()) + + ". Socket use-after-close bugs are likely."; + ASSERT_TRUE(closesocket(sockets[i]) != SOCKET_ERROR) + << "Failed to close socket: " + std::to_string(sockets[i]) + + ". WSA error: " + std::to_string(WSAGetLastError()); + } + } + return; +} +#else +void OpenAndCloseSocketsStressLoop(int phony_port, gpr_event* done_ev) { + // The goal of this loop is to catch socket + // "use after close" bugs within the c-ares resolver by acting + // like some separate thread doing I/O. + // It's goal is to try to hit race conditions whereby: + // 1) The c-ares resolver closes a socket. + // 2) This loop opens a socket with (coincidentally) the same handle. + // 3) the c-ares resolver mistakenly uses that same socket without + // realizing that its closed. + // 4) This loop performs an operation on that socket that should + // succeed but instead fails because of what the c-ares + // resolver did in the meantime. + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(phony_port); + (reinterpret_cast(&addr.sin6_addr))[15] = 1; + for (;;) { + if (gpr_event_get(done_ev)) { + return; + } + std::vector sockets; + // First open a bunch of sockets, bind and listen + // '50' is an arbitrary number that, experimentally, + // has a good chance of catching bugs. + for (size_t i = 0; i < 50; i++) { + int s = socket(AF_INET6, SOCK_STREAM, 0); + int val = 1; + ASSERT_TRUE(setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val)) == + 0) + << "Failed to set socketopt reuseport"; + ASSERT_TRUE(setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == + 0) + << "Failed to set socket reuseaddr"; + ASSERT_TRUE(fcntl(s, F_SETFL, O_NONBLOCK) == 0) + << "Failed to set socket non-blocking"; + ASSERT_TRUE(s != BAD_SOCKET_RETURN_VAL) + << "Failed to create TCP ipv6 socket"; + gpr_log(GPR_DEBUG, "Opened fd: %d", s); + ASSERT_TRUE(bind(s, (const sockaddr*)&addr, sizeof(addr)) == 0) + << "Failed to bind socket " + std::to_string(s) + + " to [::1]:" + std::to_string(phony_port) + + ". errno: " + std::to_string(errno); + ASSERT_TRUE(listen(s, 1) == 0) << "Failed to listen on socket " + + std::to_string(s) + + ". errno: " + std::to_string(errno); + sockets.push_back(s); + } + // Do a non-blocking accept followed by a close on all of those sockets. + // Do this in a separate loop to try to induce a time window to hit races. + for (size_t i = 0; i < sockets.size(); i++) { + gpr_log(GPR_DEBUG, "non-blocking accept then close on %d", sockets[i]); + if (accept(sockets[i], nullptr, nullptr)) { + // If e.g. a "shutdown" was called on this fd from another thread, + // then this accept call should fail with an unexpected error. + ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK) + << "OpenAndCloseSocketsStressLoop accept on socket " + + std::to_string(sockets[i]) + + " failed in " + "an unexpected way. " + "errno: " + + std::to_string(errno) + + ". 
Socket use-after-close bugs are likely."; + } + ASSERT_TRUE(close(sockets[i]) == 0) + << "Failed to close socket: " + std::to_string(sockets[i]) + + ". errno: " + std::to_string(errno); + } + } +} +#endif + +class ResultHandler : public grpc_core::Resolver::ResultHandler { + public: + static std::unique_ptr Create( + ArgsStruct* args) { + return std::unique_ptr( + new ResultHandler(args)); + } + + explicit ResultHandler(ArgsStruct* args) : args_(args) {} + + void ReturnResult(grpc_core::Resolver::Result result) override { + CheckResult(result); + gpr_atm_rel_store(&args_->done_atm, 1); + gpr_mu_lock(args_->mu); + GRPC_LOG_IF_ERROR("pollset_kick", + grpc_pollset_kick(args_->pollset, nullptr)); + gpr_mu_unlock(args_->mu); + } + + void ReturnError(grpc_error_handle error) override { + gpr_log(GPR_ERROR, "resolver returned error: %s", + grpc_error_std_string(error).c_str()); + GPR_ASSERT(false); + } + + virtual void CheckResult(const grpc_core::Resolver::Result& /*result*/) {} + + protected: + ArgsStruct* args_struct() const { return args_; } + + private: + ArgsStruct* args_; +}; + +class CheckingResultHandler : public ResultHandler { + public: + static std::unique_ptr Create( + ArgsStruct* args) { + return std::unique_ptr( + new CheckingResultHandler(args)); + } + + explicit CheckingResultHandler(ArgsStruct* args) : ResultHandler(args) {} + + void CheckResult(const grpc_core::Resolver::Result& result) override { + ArgsStruct* args = args_struct(); + std::vector found_lb_addrs; + AddActualAddresses(result.addresses, /*is_balancer=*/false, + &found_lb_addrs); + const grpc_core::ServerAddressList* balancer_addresses = + grpc_core::FindGrpclbBalancerAddressesInChannelArgs(*result.args); + if (balancer_addresses != nullptr) { + AddActualAddresses(*balancer_addresses, /*is_balancer=*/true, + &found_lb_addrs); + } + gpr_log(GPR_INFO, + "found %" PRIdPTR " backend addresses and %" PRIdPTR + " balancer addresses", + result.addresses.size(), + balancer_addresses == nullptr ? 0L : balancer_addresses->size()); + if (args->expected_addrs.size() != found_lb_addrs.size()) { + gpr_log(GPR_DEBUG, + "found lb addrs size is: %" PRIdPTR + ". expected addrs size is %" PRIdPTR, + found_lb_addrs.size(), args->expected_addrs.size()); + abort(); + } + if (absl::GetFlag(FLAGS_do_ordered_address_comparison) == "True") { + EXPECT_EQ(args->expected_addrs, found_lb_addrs); + } else if (absl::GetFlag(FLAGS_do_ordered_address_comparison) == "False") { + EXPECT_THAT(args->expected_addrs, + UnorderedElementsAreArray(found_lb_addrs)); + } else { + gpr_log(GPR_ERROR, + "Invalid for setting for --do_ordered_address_comparison. " + "Have %s, want True or False", + absl::GetFlag(FLAGS_do_ordered_address_comparison).c_str()); + GPR_ASSERT(0); + } + const char* service_config_json = + result.service_config == nullptr + ? 
nullptr + : result.service_config->json_string().c_str(); + CheckServiceConfigResultLocked( + service_config_json, GRPC_ERROR_REF(result.service_config_error), args); + if (args->expected_service_config_string.empty()) { + CheckLBPolicyResultLocked(result.args, args); + } + } + + private: + static void AddActualAddresses(const grpc_core::ServerAddressList& addresses, + bool is_balancer, + std::vector* out) { + for (size_t i = 0; i < addresses.size(); i++) { + const grpc_core::ServerAddress& addr = addresses[i]; + std::string str = + grpc_sockaddr_to_string(&addr.address(), true /* normalize */); + gpr_log(GPR_INFO, "%s", str.c_str()); + out->emplace_back(GrpcLBAddress(std::move(str), is_balancer)); + } + } +}; + +int g_fake_non_responsive_dns_server_port = -1; + +/* This function will configure any ares_channel created by the c-ares based + * resolver. This is useful to effectively mock /etc/resolv.conf settings + * (and equivalent on Windows), which unit tests don't have write permissions. + */ +void InjectBrokenNameServerList(ares_channel channel) { + struct ares_addr_port_node dns_server_addrs[2]; + memset(dns_server_addrs, 0, sizeof(dns_server_addrs)); + std::string unused_host; + std::string local_dns_server_port; + GPR_ASSERT(grpc_core::SplitHostPort( + absl::GetFlag(FLAGS_local_dns_server_address).c_str(), &unused_host, + &local_dns_server_port)); + gpr_log(GPR_DEBUG, + "Injecting broken nameserver list. Bad server address:|[::1]:%d|. " + "Good server address:%s", + g_fake_non_responsive_dns_server_port, + absl::GetFlag(FLAGS_local_dns_server_address).c_str()); + // Put the non-responsive DNS server at the front of c-ares's nameserver list. + dns_server_addrs[0].family = AF_INET6; + (reinterpret_cast(&dns_server_addrs[0].addr.addr6))[15] = 0x1; + dns_server_addrs[0].tcp_port = g_fake_non_responsive_dns_server_port; + dns_server_addrs[0].udp_port = g_fake_non_responsive_dns_server_port; + dns_server_addrs[0].next = &dns_server_addrs[1]; + // Put the actual healthy DNS server after the first one. The expectation is + // that the resolver will timeout the query to the non-responsive DNS server + // and will skip over to this healthy DNS server, without causing any DNS + // resolution errors. 
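+ // The healthy server below is 127.0.0.1 (bytes 0x7f, 0, 0, 0x1), on the
+ // port parsed out of --local_dns_server_address.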
+ dns_server_addrs[1].family = AF_INET; + (reinterpret_cast(&dns_server_addrs[1].addr.addr4))[0] = 0x7f; + (reinterpret_cast(&dns_server_addrs[1].addr.addr4))[3] = 0x1; + dns_server_addrs[1].tcp_port = atoi(local_dns_server_port.c_str()); + dns_server_addrs[1].udp_port = atoi(local_dns_server_port.c_str()); + dns_server_addrs[1].next = nullptr; + GPR_ASSERT(ares_set_servers_ports(channel, dns_server_addrs) == ARES_SUCCESS); +} + +void StartResolvingLocked(grpc_core::Resolver* r) { r->StartLocked(); } + +void RunResolvesRelevantRecordsTest( + std::unique_ptr (*CreateResultHandler)( + ArgsStruct* args)) { + grpc_core::ExecCtx exec_ctx; + ArgsStruct args; + ArgsInit(&args); + args.expected_addrs = ParseExpectedAddrs(absl::GetFlag(FLAGS_expected_addrs)); + args.expected_service_config_string = + absl::GetFlag(FLAGS_expected_chosen_service_config); + args.expected_service_config_error = + absl::GetFlag(FLAGS_expected_service_config_error); + args.expected_lb_policy = absl::GetFlag(FLAGS_expected_lb_policy); + // maybe build the address with an authority + std::string whole_uri; + gpr_log(GPR_DEBUG, + "resolver_component_test: --inject_broken_nameserver_list: %s", + absl::GetFlag(FLAGS_inject_broken_nameserver_list).c_str()); + std::unique_ptr + fake_non_responsive_dns_server; + if (absl::GetFlag(FLAGS_inject_broken_nameserver_list) == "True") { + g_fake_non_responsive_dns_server_port = grpc_pick_unused_port_or_die(); + fake_non_responsive_dns_server = + absl::make_unique( + + g_fake_non_responsive_dns_server_port); + grpc_ares_test_only_inject_config = InjectBrokenNameServerList; + whole_uri = absl::StrCat("dns:///", absl::GetFlag(FLAGS_target_name)); + } else if (absl::GetFlag(FLAGS_inject_broken_nameserver_list) == "False") { + gpr_log(GPR_INFO, "Specifying authority in uris to: %s", + absl::GetFlag(FLAGS_local_dns_server_address).c_str()); + whole_uri = absl::StrFormat("dns://%s/%s", + absl::GetFlag(FLAGS_local_dns_server_address), + absl::GetFlag(FLAGS_target_name)); + } else { + gpr_log(GPR_DEBUG, "Invalid value for --inject_broken_nameserver_list."); + abort(); + } + gpr_log(GPR_DEBUG, "resolver_component_test: --enable_srv_queries: %s", + absl::GetFlag(FLAGS_enable_srv_queries).c_str()); + grpc_channel_args* resolver_args = nullptr; + // By default, SRV queries are disabled, so tests that expect no SRV query + // should avoid setting any channel arg. Test cases that do rely on the SRV + // query must explicitly enable SRV though. + if (absl::GetFlag(FLAGS_enable_srv_queries) == "True") { + grpc_arg srv_queries_arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_DNS_ENABLE_SRV_QUERIES), true); + resolver_args = + grpc_channel_args_copy_and_add(nullptr, &srv_queries_arg, 1); + } else if (absl::GetFlag(FLAGS_enable_srv_queries) != "False") { + gpr_log(GPR_DEBUG, "Invalid value for --enable_srv_queries."); + abort(); + } + gpr_log(GPR_DEBUG, "resolver_component_test: --enable_txt_queries: %s", + absl::GetFlag(FLAGS_enable_txt_queries).c_str()); + // By default, TXT queries are disabled, so tests that expect no TXT query + // should avoid setting any channel arg. Test cases that do rely on the TXT + // query must explicitly enable TXT though. + if (absl::GetFlag(FLAGS_enable_txt_queries) == "True") { + // Unlike SRV queries, there isn't a channel arg specific to TXT records. + // Rather, we use the resolver-agnostic "service config" resolution option, + // for which c-ares has its own specific default value, which isn't + // necessarily shared by other resolvers. 
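+ // Passing false for GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION below
+ // enables service config resolution, which is what drives the TXT lookup.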
+ grpc_arg txt_queries_arg = grpc_channel_arg_integer_create( + const_cast(GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION), false); + grpc_channel_args* tmp_args = + grpc_channel_args_copy_and_add(resolver_args, &txt_queries_arg, 1); + grpc_channel_args_destroy(resolver_args); + resolver_args = tmp_args; + } else if (absl::GetFlag(FLAGS_enable_txt_queries) != "False") { + gpr_log(GPR_DEBUG, "Invalid value for --enable_txt_queries."); + abort(); + } + // create resolver and resolve + grpc_core::OrphanablePtr resolver = + grpc_core::ResolverRegistry::CreateResolver( + whole_uri.c_str(), resolver_args, args.pollset_set, args.lock, + CreateResultHandler(&args)); + grpc_channel_args_destroy(resolver_args); + auto* resolver_ptr = resolver.get(); + args.lock->Run([resolver_ptr]() { StartResolvingLocked(resolver_ptr); }, + DEBUG_LOCATION); + grpc_core::ExecCtx::Get()->Flush(); + PollPollsetUntilRequestDone(&args); + ArgsFinish(&args); +} + +TEST(ResolverComponentTest, TestResolvesRelevantRecords) { + RunResolvesRelevantRecordsTest(CheckingResultHandler::Create); +} + +TEST(ResolverComponentTest, TestResolvesRelevantRecordsWithConcurrentFdStress) { + // Start up background stress thread + int phony_port = grpc_pick_unused_port_or_die(); + gpr_event done_ev; + gpr_event_init(&done_ev); + std::thread socket_stress_thread(OpenAndCloseSocketsStressLoop, phony_port, + &done_ev); + // Run the resolver test + RunResolvesRelevantRecordsTest(ResultHandler::Create); + // Shutdown and join stress thread + gpr_event_set(&done_ev, reinterpret_cast(1)); + socket_stress_thread.join(); +} + +} // namespace + +int main(int argc, char** argv) { + grpc_init(); + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + if (absl::GetFlag(FLAGS_target_name).empty()) { + gpr_log(GPR_ERROR, "Missing target_name param."); + abort(); + } + auto result = RUN_ALL_TESTS(); + grpc_shutdown(); + return result; +} diff --git a/test/cpp/naming/resolver_component_tests_runner_invoker.cc b/test/cpp/naming/resolver_component_tests_runner_invoker.cc new file mode 100644 index 00000000..d6f12171 --- /dev/null +++ b/test/cpp/naming/resolver_component_tests_runner_invoker.cc @@ -0,0 +1,164 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include +#include + +#ifdef __FreeBSD__ +#include +#endif + +#include "src/core/lib/gpr/env.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG( + bool, running_under_bazel, false, + "True if this test is running under bazel. " + "False indicates that this test is running under run_tests.py. " + "Child process test binaries are located differently based on this flag. 
"); + +ABSL_FLAG(std::string, test_bin_name, "", + "Name, without the preceding path, of the test binary"); + +ABSL_FLAG(std::string, grpc_test_directory_relative_to_test_srcdir, + "/com_github_grpc_grpc", + "This flag only applies if runner_under_bazel is true. This " + "flag is ignored if runner_under_bazel is false. " + "Directory of the /test directory relative to bazel's " + "TEST_SRCDIR environment variable"); + +ABSL_FLAG(std::string, extra_args, "", + "Comma-separated list of opaque command args to plumb through to " + "the binary pointed at by --test_bin_name"); + +using grpc::SubProcess; + +namespace grpc { + +namespace testing { + +void InvokeResolverComponentTestsRunner( + std::string test_runner_bin_path, const std::string& test_bin_path, + const std::string& dns_server_bin_path, + const std::string& records_config_path, + const std::string& dns_resolver_bin_path, + const std::string& tcp_connect_bin_path) { + int dns_server_port = grpc_pick_unused_port_or_die(); + + SubProcess* test_driver = new SubProcess( + {std::move(test_runner_bin_path), "--test_bin_path=" + test_bin_path, + "--dns_server_bin_path=" + dns_server_bin_path, + "--records_config_path=" + records_config_path, + "--dns_server_port=" + std::to_string(dns_server_port), + "--dns_resolver_bin_path=" + dns_resolver_bin_path, + "--tcp_connect_bin_path=" + tcp_connect_bin_path, + "--extra_args=" + absl::GetFlag(FLAGS_extra_args)}); + gpr_mu test_driver_mu; + gpr_mu_init(&test_driver_mu); + gpr_cv test_driver_cv; + gpr_cv_init(&test_driver_cv); + int test_driver_done = 0; + int status = test_driver->Join(); + if (WIFEXITED(status)) { + if (WEXITSTATUS(status)) { + gpr_log(GPR_INFO, + "Resolver component test test-runner exited with code %d", + WEXITSTATUS(status)); + abort(); + } + } else if (WIFSIGNALED(status)) { + gpr_log(GPR_INFO, + "Resolver component test test-runner ended from signal %d", + WTERMSIG(status)); + abort(); + } else { + gpr_log(GPR_INFO, + "Resolver component test test-runner ended with unknown status %d", + status); + abort(); + } + gpr_mu_lock(&test_driver_mu); + test_driver_done = 1; + gpr_cv_signal(&test_driver_cv); + gpr_mu_unlock(&test_driver_mu); + delete test_driver; + gpr_mu_destroy(&test_driver_mu); + gpr_cv_destroy(&test_driver_cv); +} + +} // namespace testing + +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + grpc_init(); + GPR_ASSERT(!absl::GetFlag(FLAGS_test_bin_name).empty()); + std::string my_bin = argv[0]; + if (absl::GetFlag(FLAGS_running_under_bazel)) { + GPR_ASSERT(!absl::GetFlag(FLAGS_grpc_test_directory_relative_to_test_srcdir) + .empty()); + // Use bazel's TEST_SRCDIR environment variable to locate the "test data" + // binaries. + char* test_srcdir = gpr_getenv("TEST_SRCDIR"); + std::string const bin_dir = + test_srcdir + + absl::GetFlag(FLAGS_grpc_test_directory_relative_to_test_srcdir) + + std::string("/test/cpp/naming"); + // Invoke bazel's executeable links to the .sh and .py scripts (don't use + // the .sh and .py suffixes) to make + // sure that we're using bazel's test environment. 
+ grpc::testing::InvokeResolverComponentTestsRunner( + bin_dir + "/resolver_component_tests_runner", + bin_dir + "/" + absl::GetFlag(FLAGS_test_bin_name), + bin_dir + "/utils/dns_server", + bin_dir + "/resolver_test_record_groups.yaml", + bin_dir + "/utils/dns_resolver", bin_dir + "/utils/tcp_connect"); + gpr_free(test_srcdir); + } else { + // Get the current binary's directory relative to repo root to invoke the + // correct build config (asan/tsan/dbg, etc.). + std::string const bin_dir = my_bin.substr(0, my_bin.rfind('/')); + // Invoke the .sh and .py scripts directly where they are in source code. + grpc::testing::InvokeResolverComponentTestsRunner( + "test/cpp/naming/resolver_component_tests_runner.py", + bin_dir + "/" + absl::GetFlag(FLAGS_test_bin_name), + "test/cpp/naming/utils/dns_server.py", + "test/cpp/naming/resolver_test_record_groups.yaml", + "test/cpp/naming/utils/dns_resolver.py", + "test/cpp/naming/utils/tcp_connect.py"); + } + grpc_shutdown(); + return 0; +} diff --git a/test/cpp/performance/writes_per_rpc_test.cc b/test/cpp/performance/writes_per_rpc_test.cc new file mode 100644 index 00000000..900d55df --- /dev/null +++ b/test/cpp/performance/writes_per_rpc_test.cc @@ -0,0 +1,248 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/endpoint_pair.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/tcp_posix.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/core/lib/surface/server.h" +#include "src/cpp/client/create_channel_internal.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/passthru_endpoint.h" +#include "test/core/util/port.h" +#include "test/core/util/resource_user_util.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { + +static void* tag(intptr_t x) { return reinterpret_cast(x); } + +static void ApplyCommonServerBuilderConfig(ServerBuilder* b) { + b->SetMaxReceiveMessageSize(INT_MAX); + b->SetMaxSendMessageSize(INT_MAX); +} + +static void ApplyCommonChannelArguments(ChannelArguments* c) { + c->SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, INT_MAX); + c->SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, INT_MAX); +} + +class EndpointPairFixture { + public: + EndpointPairFixture(Service* service, grpc_endpoint_pair endpoints) { + ServerBuilder b; + cq_ = b.AddCompletionQueue(true); + b.RegisterService(service); + ApplyCommonServerBuilderConfig(&b); + server_ = b.BuildAndStart(); + + grpc_core::ExecCtx exec_ctx; + + /* add server endpoint to server_ */ + { + const grpc_channel_args* server_args = + server_->c_server()->core_server->channel_args(); + grpc_transport* transport = grpc_create_chttp2_transport( + server_args, endpoints.server, false /* is_client */, + grpc_resource_user_create_unlimited()); + for (grpc_pollset* pollset : + server_->c_server()->core_server->pollsets()) { + grpc_endpoint_add_to_pollset(endpoints.server, pollset); + } + + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "SetupTransport", server_->c_server()->core_server->SetupTransport( + transport, nullptr, server_args, nullptr))); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + } + + /* create channel */ + { + ChannelArguments args; + args.SetString(GRPC_ARG_DEFAULT_AUTHORITY, "test.authority"); + ApplyCommonChannelArguments(&args); + + grpc_channel_args c_args = args.c_channel_args(); + grpc_transport* transport = + grpc_create_chttp2_transport(&c_args, endpoints.client, true, + grpc_resource_user_create_unlimited()); + GPR_ASSERT(transport); + grpc_channel* channel = + grpc_channel_create("target", &c_args, GRPC_CLIENT_DIRECT_CHANNEL, + transport, nullptr, 0, nullptr); + grpc_chttp2_transport_start_reading(transport, nullptr, nullptr, nullptr); + + channel_ = ::grpc::CreateChannelInternal( + "", channel, + std::vector>()); + } + } + + virtual ~EndpointPairFixture() { + server_->Shutdown(); + cq_->Shutdown(); + void* tag; + bool ok; + while (cq_->Next(&tag, &ok)) { + } + } + + ServerCompletionQueue* cq() { return cq_.get(); } + std::shared_ptr channel() { return channel_; } + + private: + std::unique_ptr server_; + std::unique_ptr cq_; + std::shared_ptr channel_; +}; + +class InProcessCHTTP2 : public EndpointPairFixture { + public: + InProcessCHTTP2(Service* service, grpc_passthru_endpoint_stats* stats) + : EndpointPairFixture(service, MakeEndpoints(stats)), stats_(stats) {} + + ~InProcessCHTTP2() override { + if (stats_ != nullptr) { + grpc_passthru_endpoint_stats_destroy(stats_); + } + } + 
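+ // Number of writes recorded by the passthru endpoint; UnaryPingPong divides
+ // this by its iteration count to report writes per RPC.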
+ int writes_performed() const { return stats_->num_writes; } + + private: + grpc_passthru_endpoint_stats* stats_; + + static grpc_endpoint_pair MakeEndpoints(grpc_passthru_endpoint_stats* stats) { + grpc_endpoint_pair p; + grpc_passthru_endpoint_create(&p.client, &p.server, stats); + return p; + } +}; + +static double UnaryPingPong(int request_size, int response_size) { + const int kIterations = 10000; + + EchoTestService::AsyncService service; + std::unique_ptr fixture( + new InProcessCHTTP2(&service, grpc_passthru_endpoint_stats_create())); + EchoRequest send_request; + EchoResponse send_response; + EchoResponse recv_response; + if (request_size > 0) { + send_request.set_message(std::string(request_size, 'a')); + } + if (response_size > 0) { + send_response.set_message(std::string(response_size, 'a')); + } + Status recv_status; + struct ServerEnv { + ServerContext ctx; + EchoRequest recv_request; + grpc::ServerAsyncResponseWriter response_writer; + ServerEnv() : response_writer(&ctx) {} + }; + uint8_t server_env_buffer[2 * sizeof(ServerEnv)]; + ServerEnv* server_env[2] = { + reinterpret_cast(server_env_buffer), + reinterpret_cast(server_env_buffer + sizeof(ServerEnv))}; + new (server_env[0]) ServerEnv; + new (server_env[1]) ServerEnv; + service.RequestEcho(&server_env[0]->ctx, &server_env[0]->recv_request, + &server_env[0]->response_writer, fixture->cq(), + fixture->cq(), tag(0)); + service.RequestEcho(&server_env[1]->ctx, &server_env[1]->recv_request, + &server_env[1]->response_writer, fixture->cq(), + fixture->cq(), tag(1)); + std::unique_ptr stub( + EchoTestService::NewStub(fixture->channel())); + for (int iteration = 0; iteration < kIterations; iteration++) { + recv_response.Clear(); + ClientContext cli_ctx; + std::unique_ptr> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, fixture->cq())); + void* t; + bool ok; + response_reader->Finish(&recv_response, &recv_status, tag(4)); + GPR_ASSERT(fixture->cq()->Next(&t, &ok)); + GPR_ASSERT(ok); + GPR_ASSERT(t == tag(0) || t == tag(1)); + intptr_t slot = reinterpret_cast(t); + ServerEnv* senv = server_env[slot]; + senv->response_writer.Finish(send_response, Status::OK, tag(3)); + for (int i = (1 << 3) | (1 << 4); i != 0;) { + GPR_ASSERT(fixture->cq()->Next(&t, &ok)); + GPR_ASSERT(ok); + int tagnum = static_cast(reinterpret_cast(t)); + GPR_ASSERT(i & (1 << tagnum)); + i -= 1 << tagnum; + } + GPR_ASSERT(recv_status.ok()); + + senv->~ServerEnv(); + senv = new (senv) ServerEnv(); + service.RequestEcho(&senv->ctx, &senv->recv_request, &senv->response_writer, + fixture->cq(), fixture->cq(), tag(slot)); + } + + double writes_per_iteration = + static_cast(fixture->writes_performed()) / + static_cast(kIterations); + + fixture.reset(); + server_env[0]->~ServerEnv(); + server_env[1]->~ServerEnv(); + + return writes_per_iteration; +} + +TEST(WritesPerRpcTest, UnaryPingPong) { + EXPECT_LT(UnaryPingPong(0, 0), 2.05); + EXPECT_LT(UnaryPingPong(1, 0), 2.05); + EXPECT_LT(UnaryPingPong(0, 1), 2.05); + EXPECT_LT(UnaryPingPong(4096, 0), 2.5); + EXPECT_LT(UnaryPingPong(0, 4096), 2.5); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/cpp/qps/benchmark_config.cc b/test/cpp/qps/benchmark_config.cc new file mode 100644 index 00000000..3a44eab8 --- /dev/null +++ b/test/cpp/qps/benchmark_config.cc @@ -0,0 +1,90 @@ +/* + * + * 
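// A note on the completion loop in UnaryPingPong above: the client-side
// Finish (tag 4) and the server-side Finish (tag 3) can complete in either
// order, so the test waits for both by clearing one bit per event from the
// mask (1 << 3) | (1 << 4). The same idiom written generically, as a hedged
// sketch rather than a verbatim excerpt (cq stands for a
// grpc::CompletionQueue*):
//
//   int pending = (1 << 3) | (1 << 4);
//   while (pending != 0) {
//     void* t;
//     bool ok;
//     GPR_ASSERT(cq->Next(&t, &ok) && ok);
//     int tagnum = static_cast<int>(reinterpret_cast<intptr_t>(t));
//     GPR_ASSERT(pending & (1 << tagnum));  // each tag may arrive only once
//     pending &= ~(1 << tagnum);
//   }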
Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/qps/benchmark_config.h" + +#include "absl/flags/flag.h" + +#include +#include +#include + +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_FLAG(bool, enable_log_reporter, true, + "Enable reporting of benchmark results through GprLog"); + +ABSL_FLAG(std::string, scenario_result_file, "", + "Write JSON benchmark report to the file specified."); + +ABSL_FLAG(std::string, hashed_id, "", "Hash of the user id"); + +ABSL_FLAG(std::string, test_name, "", "Name of the test being executed"); + +ABSL_FLAG(std::string, sys_info, "", "System information"); + +ABSL_FLAG(std::string, server_address, "localhost:50052", + "Address of the performance database server"); + +ABSL_FLAG(std::string, tag, "", "Optional tag for the test"); + +ABSL_FLAG(std::string, rpc_reporter_server_address, "", + "Server address for rpc reporter to send results to"); + +ABSL_FLAG(bool, enable_rpc_reporter, false, "Enable use of RPC reporter"); + +ABSL_FLAG( + std::string, rpc_reporter_credential_type, + grpc::testing::kInsecureCredentialsType, + "Credential type for communication to the QPS benchmark report server"); + +namespace grpc { +namespace testing { + +static std::shared_ptr InitBenchmarkReporters() { + auto* composite_reporter = new CompositeReporter; + if (absl::GetFlag(FLAGS_enable_log_reporter)) { + composite_reporter->add( + std::unique_ptr(new GprLogReporter("LogReporter"))); + } + if (!absl::GetFlag(FLAGS_scenario_result_file).empty()) { + composite_reporter->add(std::unique_ptr(new JsonReporter( + "JsonReporter", absl::GetFlag(FLAGS_scenario_result_file)))); + } + if (absl::GetFlag(FLAGS_enable_rpc_reporter)) { + ChannelArguments channel_args; + std::shared_ptr channel_creds = + testing::GetCredentialsProvider()->GetChannelCredentials( + absl::GetFlag(FLAGS_rpc_reporter_credential_type), &channel_args); + GPR_ASSERT(!absl::GetFlag(FLAGS_rpc_reporter_server_address).empty()); + composite_reporter->add(std::unique_ptr(new RpcReporter( + "RpcReporter", + grpc::CreateChannel(absl::GetFlag(FLAGS_rpc_reporter_server_address), + channel_creds)))); + } + + return std::shared_ptr(composite_reporter); +} + +std::shared_ptr GetReporter() { + static std::shared_ptr reporter(InitBenchmarkReporters()); + return reporter; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/client_async.cc b/test/cpp/qps/client_async.cc new file mode 100644 index 00000000..4e57de3a --- /dev/null +++ b/test/cpp/qps/client_async.cc @@ -0,0 +1,961 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
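// Usage note for benchmark_config.cc above: GetReporter() lazily builds a
// CompositeReporter whose members are chosen by the ABSL flags, so a caller
// reports once and every configured backend (gpr log, JSON file, RPC
// reporter) receives the result. Hedged sketch of a typical call site,
// assuming the Reporter interface declared in test/cpp/qps/report.h:
//
//   std::shared_ptr<Reporter> reporter = GetReporter();
//   reporter->ReportQPS(scenario_result);
//   reporter->ReportLatency(scenario_result);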
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/surface/completion_queue.h" +#include "src/proto/grpc/testing/benchmark_service.grpc.pb.h" +#include "test/cpp/qps/client.h" +#include "test/cpp/qps/usage_timer.h" +#include "test/cpp/util/create_test_channel.h" + +namespace grpc { +namespace testing { + +class ClientRpcContext { + public: + ClientRpcContext() {} + virtual ~ClientRpcContext() {} + // next state, return false if done. Collect stats when appropriate + virtual bool RunNextState(bool, HistogramEntry* entry) = 0; + virtual void StartNewClone(CompletionQueue* cq) = 0; + static void* tag(ClientRpcContext* c) { return static_cast(c); } + static ClientRpcContext* detag(void* t) { + return static_cast(t); + } + + virtual void Start(CompletionQueue* cq, const ClientConfig& config) = 0; + virtual void TryCancel() = 0; +}; + +template +class ClientRpcContextUnaryImpl : public ClientRpcContext { + public: + ClientRpcContextUnaryImpl( + BenchmarkService::Stub* stub, const RequestType& req, + std::function next_issue, + std::function< + std::unique_ptr>( + BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&, + CompletionQueue*)> + prepare_req, + std::function on_done) + : context_(), + stub_(stub), + cq_(nullptr), + req_(req), + response_(), + next_state_(State::READY), + callback_(on_done), + next_issue_(std::move(next_issue)), + prepare_req_(prepare_req) {} + ~ClientRpcContextUnaryImpl() override {} + void Start(CompletionQueue* cq, const ClientConfig& config) override { + GPR_ASSERT(!config.use_coalesce_api()); // not supported. 
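// Each ClientRpcContext doubles as its own completion-queue tag: tag()
// erases the pointer to void* when an op is started and detag() recovers it
// when the completion pops out of the queue, e.g. (commented sketch):
//
//   void* t = ClientRpcContext::tag(this);               // when starting ops
//   ClientRpcContext* ctx = ClientRpcContext::detag(t);  // in the poller loop
//   ctx->RunNextState(ok, entry);
//
// For this unary context, RunNextState() below walks READY -> RESP_DONE:
// READY issues the call and requests Finish, RESP_DONE records latency and
// status, then returns false so the poller clones a fresh context via
// StartNewClone().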
+ StartInternal(cq); + } + bool RunNextState(bool /*ok*/, HistogramEntry* entry) override { + switch (next_state_) { + case State::READY: + start_ = UsageTimer::Now(); + response_reader_ = prepare_req_(stub_, &context_, req_, cq_); + response_reader_->StartCall(); + next_state_ = State::RESP_DONE; + response_reader_->Finish(&response_, &status_, + ClientRpcContext::tag(this)); + return true; + case State::RESP_DONE: + if (status_.ok()) { + entry->set_value((UsageTimer::Now() - start_) * 1e9); + } + callback_(status_, &response_, entry); + next_state_ = State::INVALID; + return false; + default: + GPR_ASSERT(false); + return false; + } + } + void StartNewClone(CompletionQueue* cq) override { + auto* clone = new ClientRpcContextUnaryImpl(stub_, req_, next_issue_, + prepare_req_, callback_); + clone->StartInternal(cq); + } + void TryCancel() override { context_.TryCancel(); } + + private: + grpc::ClientContext context_; + BenchmarkService::Stub* stub_; + CompletionQueue* cq_; + std::unique_ptr alarm_; + const RequestType& req_; + ResponseType response_; + enum State { INVALID, READY, RESP_DONE }; + State next_state_; + std::function callback_; + std::function next_issue_; + std::function>( + BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&, + CompletionQueue*)> + prepare_req_; + grpc::Status status_; + double start_; + std::unique_ptr> + response_reader_; + + void StartInternal(CompletionQueue* cq) { + cq_ = cq; + if (!next_issue_) { // ready to issue + RunNextState(true, nullptr); + } else { // wait for the issue time + alarm_ = absl::make_unique(); + alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this)); + } + } +}; + +template +class AsyncClient : public ClientImpl { + // Specify which protected members we are using since there is no + // member name resolution until the template types are fully resolved + public: + using Client::closed_loop_; + using Client::NextIssuer; + using Client::SetupLoadTest; + using ClientImpl::cores_; + using ClientImpl::channels_; + using ClientImpl::request_; + AsyncClient(const ClientConfig& config, + std::function next_issue, + const RequestType&)> + setup_ctx, + std::function(std::shared_ptr)> + create_stub) + : ClientImpl(config, create_stub), + num_async_threads_(NumThreads(config)) { + SetupLoadTest(config, num_async_threads_); + + int tpc = std::max(1, config.threads_per_cq()); // 1 if unspecified + int num_cqs = (num_async_threads_ + tpc - 1) / tpc; // ceiling operator + for (int i = 0; i < num_cqs; i++) { + cli_cqs_.emplace_back(new CompletionQueue); + } + + for (int i = 0; i < num_async_threads_; i++) { + cq_.emplace_back(i % cli_cqs_.size()); + next_issuers_.emplace_back(NextIssuer(i)); + shutdown_state_.emplace_back(new PerThreadShutdownState()); + } + + int t = 0; + for (int ch = 0; ch < config.client_channels(); ch++) { + for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) { + auto* cq = cli_cqs_[t].get(); + auto ctx = + setup_ctx(channels_[ch].get_stub(), next_issuers_[t], request_); + ctx->Start(cq, config); + } + t = (t + 1) % cli_cqs_.size(); + } + } + ~AsyncClient() override { + for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { + void* got_tag; + bool ok; + while ((*cq)->Next(&got_tag, &ok)) { + delete ClientRpcContext::detag(got_tag); + } + } + } + + int GetPollCount() override { + int count = 0; + for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { + count += grpc_get_cq_poll_num((*cq)->cq()); + } + return count; + } + + protected: + const int num_async_threads_; + + private: + 
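// How the AsyncClient constructor above shares completion queues: with T
// async threads and threads_per_cq = tpc, it creates ceil(T / tpc) queues and
// assigns thread i to queue (i % num_cqs), so no queue serves more than tpc
// threads and the per-queue load differs by at most one thread. Worked
// example with made-up values:
//
//   int T = 5, tpc = 2;
//   int num_cqs = (T + tpc - 1) / tpc;  // ceiling division, = 3
//   // thread -> cq index: 0->0, 1->1, 2->2, 3->0, 4->1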
struct PerThreadShutdownState { + mutable std::mutex mutex; + bool shutdown; + PerThreadShutdownState() : shutdown(false) {} + }; + + int NumThreads(const ClientConfig& config) { + int num_threads = config.async_client_threads(); + if (num_threads <= 0) { // Use dynamic sizing + num_threads = cores_; + gpr_log(GPR_INFO, "Sizing async client to %d threads", num_threads); + } + return num_threads; + } + void DestroyMultithreading() final { + for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) { + std::lock_guard lock((*ss)->mutex); + (*ss)->shutdown = true; + } + for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) { + (*cq)->Shutdown(); + } + this->EndThreads(); // this needed for resolution + } + + ClientRpcContext* ProcessTag(size_t thread_idx, void* tag) { + ClientRpcContext* ctx = ClientRpcContext::detag(tag); + if (shutdown_state_[thread_idx]->shutdown) { + ctx->TryCancel(); + delete ctx; + bool ok; + while (cli_cqs_[cq_[thread_idx]]->Next(&tag, &ok)) { + ctx = ClientRpcContext::detag(tag); + ctx->TryCancel(); + delete ctx; + } + return nullptr; + } + return ctx; + } + + void ThreadFunc(size_t thread_idx, Client::Thread* t) final { + void* got_tag; + bool ok; + + HistogramEntry entry; + HistogramEntry* entry_ptr = &entry; + if (!cli_cqs_[cq_[thread_idx]]->Next(&got_tag, &ok)) { + return; + } + std::mutex* shutdown_mu = &shutdown_state_[thread_idx]->mutex; + shutdown_mu->lock(); + ClientRpcContext* ctx = ProcessTag(thread_idx, got_tag); + if (ctx == nullptr) { + shutdown_mu->unlock(); + return; + } + while (cli_cqs_[cq_[thread_idx]]->DoThenAsyncNext( + [&, ctx, ok, entry_ptr, shutdown_mu]() { + if (!ctx->RunNextState(ok, entry_ptr)) { + // The RPC and callback are done, so clone the ctx + // and kickstart the new one + ctx->StartNewClone(cli_cqs_[cq_[thread_idx]].get()); + delete ctx; + } + shutdown_mu->unlock(); + }, + &got_tag, &ok, gpr_inf_future(GPR_CLOCK_REALTIME))) { + t->UpdateHistogram(entry_ptr); + entry = HistogramEntry(); + shutdown_mu->lock(); + ctx = ProcessTag(thread_idx, got_tag); + if (ctx == nullptr) { + shutdown_mu->unlock(); + return; + } + } + } + + std::vector> cli_cqs_; + std::vector cq_; + std::vector> next_issuers_; + std::vector> shutdown_state_; +}; + +static std::unique_ptr BenchmarkStubCreator( + const std::shared_ptr& ch) { + return BenchmarkService::NewStub(ch); +} + +class AsyncUnaryClient final + : public AsyncClient { + public: + explicit AsyncUnaryClient(const ClientConfig& config) + : AsyncClient( + config, SetupCtx, BenchmarkStubCreator) { + StartThreads(num_async_threads_); + } + ~AsyncUnaryClient() override {} + + private: + static void CheckDone(const grpc::Status& s, SimpleResponse* /*response*/, + HistogramEntry* entry) { + entry->set_status(s.error_code()); + } + static std::unique_ptr> + PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx, + const SimpleRequest& request, CompletionQueue* cq) { + return stub->PrepareAsyncUnaryCall(ctx, request, cq); + }; + static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub, + std::function next_issue, + const SimpleRequest& req) { + return new ClientRpcContextUnaryImpl( + stub, req, std::move(next_issue), AsyncUnaryClient::PrepareReq, + AsyncUnaryClient::CheckDone); + } +}; + +template +class ClientRpcContextStreamingPingPongImpl : public ClientRpcContext { + public: + ClientRpcContextStreamingPingPongImpl( + BenchmarkService::Stub* stub, const RequestType& req, + std::function next_issue, + std::function>( + BenchmarkService::Stub*, grpc::ClientContext*, 
CompletionQueue*)> + prepare_req, + std::function on_done) + : context_(), + stub_(stub), + cq_(nullptr), + req_(req), + response_(), + next_state_(State::INVALID), + callback_(on_done), + next_issue_(std::move(next_issue)), + prepare_req_(prepare_req), + coalesce_(false) {} + ~ClientRpcContextStreamingPingPongImpl() override {} + void Start(CompletionQueue* cq, const ClientConfig& config) override { + StartInternal(cq, config.messages_per_stream(), config.use_coalesce_api()); + } + bool RunNextState(bool ok, HistogramEntry* entry) override { + while (true) { + switch (next_state_) { + case State::STREAM_IDLE: + if (!next_issue_) { // ready to issue + next_state_ = State::READY_TO_WRITE; + } else { + next_state_ = State::WAIT; + } + break; // loop around, don't return + case State::WAIT: + next_state_ = State::READY_TO_WRITE; + alarm_ = absl::make_unique(); + alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this)); + return true; + case State::READY_TO_WRITE: + if (!ok) { + return false; + } + start_ = UsageTimer::Now(); + next_state_ = State::WRITE_DONE; + if (coalesce_ && messages_issued_ == messages_per_stream_ - 1) { + stream_->WriteLast(req_, WriteOptions(), + ClientRpcContext::tag(this)); + } else { + stream_->Write(req_, ClientRpcContext::tag(this)); + } + return true; + case State::WRITE_DONE: + if (!ok) { + return false; + } + next_state_ = State::READ_DONE; + stream_->Read(&response_, ClientRpcContext::tag(this)); + return true; + break; + case State::READ_DONE: + entry->set_value((UsageTimer::Now() - start_) * 1e9); + callback_(status_, &response_); + if ((messages_per_stream_ != 0) && + (++messages_issued_ >= messages_per_stream_)) { + next_state_ = State::WRITES_DONE_DONE; + if (coalesce_) { + // WritesDone should have been called on the last Write. + // loop around to call Finish. + break; + } + stream_->WritesDone(ClientRpcContext::tag(this)); + return true; + } + next_state_ = State::STREAM_IDLE; + break; // loop around + case State::WRITES_DONE_DONE: + next_state_ = State::FINISH_DONE; + stream_->Finish(&status_, ClientRpcContext::tag(this)); + return true; + case State::FINISH_DONE: + next_state_ = State::INVALID; + return false; + break; + default: + GPR_ASSERT(false); + return false; + } + } + } + void StartNewClone(CompletionQueue* cq) override { + auto* clone = new ClientRpcContextStreamingPingPongImpl( + stub_, req_, next_issue_, prepare_req_, callback_); + clone->StartInternal(cq, messages_per_stream_, coalesce_); + } + void TryCancel() override { context_.TryCancel(); } + + private: + grpc::ClientContext context_; + BenchmarkService::Stub* stub_; + CompletionQueue* cq_; + std::unique_ptr alarm_; + const RequestType& req_; + ResponseType response_; + enum State { + INVALID, + STREAM_IDLE, + WAIT, + READY_TO_WRITE, + WRITE_DONE, + READ_DONE, + WRITES_DONE_DONE, + FINISH_DONE + }; + State next_state_; + std::function callback_; + std::function next_issue_; + std::function< + std::unique_ptr>( + BenchmarkService::Stub*, grpc::ClientContext*, CompletionQueue*)> + prepare_req_; + grpc::Status status_; + double start_; + std::unique_ptr> + stream_; + + // Allow a limit on number of messages in a stream + int messages_per_stream_; + int messages_issued_; + // Whether to use coalescing API. 
+ bool coalesce_; + + void StartInternal(CompletionQueue* cq, int messages_per_stream, + bool coalesce) { + cq_ = cq; + messages_per_stream_ = messages_per_stream; + messages_issued_ = 0; + coalesce_ = coalesce; + if (coalesce_) { + GPR_ASSERT(messages_per_stream_ != 0); + context_.set_initial_metadata_corked(true); + } + stream_ = prepare_req_(stub_, &context_, cq); + next_state_ = State::STREAM_IDLE; + stream_->StartCall(ClientRpcContext::tag(this)); + if (coalesce_) { + // When the initial metadata is corked, the tag will not come back and we + // need to manually drive the state machine. + RunNextState(true, nullptr); + } + } +}; + +class AsyncStreamingPingPongClient final + : public AsyncClient { + public: + explicit AsyncStreamingPingPongClient(const ClientConfig& config) + : AsyncClient( + config, SetupCtx, BenchmarkStubCreator) { + StartThreads(num_async_threads_); + } + + ~AsyncStreamingPingPongClient() override {} + + private: + static void CheckDone(const grpc::Status& /*s*/, + SimpleResponse* /*response*/) {} + static std::unique_ptr< + grpc::ClientAsyncReaderWriter> + PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx, + CompletionQueue* cq) { + auto stream = stub->PrepareAsyncStreamingCall(ctx, cq); + return stream; + }; + static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub, + std::function next_issue, + const SimpleRequest& req) { + return new ClientRpcContextStreamingPingPongImpl( + stub, req, std::move(next_issue), + AsyncStreamingPingPongClient::PrepareReq, + AsyncStreamingPingPongClient::CheckDone); + } +}; + +template +class ClientRpcContextStreamingFromClientImpl : public ClientRpcContext { + public: + ClientRpcContextStreamingFromClientImpl( + BenchmarkService::Stub* stub, const RequestType& req, + std::function next_issue, + std::function>( + BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*, + CompletionQueue*)> + prepare_req, + std::function on_done) + : context_(), + stub_(stub), + cq_(nullptr), + req_(req), + response_(), + next_state_(State::INVALID), + callback_(on_done), + next_issue_(std::move(next_issue)), + prepare_req_(prepare_req) {} + ~ClientRpcContextStreamingFromClientImpl() override {} + void Start(CompletionQueue* cq, const ClientConfig& config) override { + GPR_ASSERT(!config.use_coalesce_api()); // not supported yet. 
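// "Coalesce" in the assertion above refers to the batching path exercised by
// the ping-pong context earlier in this file: set_initial_metadata_corked(true)
// holds initial metadata back so it rides with the first Write, and WriteLast
// folds WritesDone into the final Write of the stream. This client-streaming
// context does not implement that path, hence the GPR_ASSERT. Hedged sketch
// of the coalescing calls on a client context and async stream (context,
// stream, req and tag are placeholders):
//
//   context.set_initial_metadata_corked(true);  // before creating the stream
//   stream->Write(req, tag);                    // carries metadata + message
//   stream->WriteLast(req, grpc::WriteOptions(), tag);  // message + WritesDone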
+ StartInternal(cq); + } + bool RunNextState(bool ok, HistogramEntry* entry) override { + while (true) { + switch (next_state_) { + case State::STREAM_IDLE: + if (!next_issue_) { // ready to issue + next_state_ = State::READY_TO_WRITE; + } else { + next_state_ = State::WAIT; + } + break; // loop around, don't return + case State::WAIT: + alarm_ = absl::make_unique(); + alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this)); + next_state_ = State::READY_TO_WRITE; + return true; + case State::READY_TO_WRITE: + if (!ok) { + return false; + } + start_ = UsageTimer::Now(); + next_state_ = State::WRITE_DONE; + stream_->Write(req_, ClientRpcContext::tag(this)); + return true; + case State::WRITE_DONE: + if (!ok) { + return false; + } + entry->set_value((UsageTimer::Now() - start_) * 1e9); + next_state_ = State::STREAM_IDLE; + break; // loop around + default: + GPR_ASSERT(false); + return false; + } + } + } + void StartNewClone(CompletionQueue* cq) override { + auto* clone = new ClientRpcContextStreamingFromClientImpl( + stub_, req_, next_issue_, prepare_req_, callback_); + clone->StartInternal(cq); + } + void TryCancel() override { context_.TryCancel(); } + + private: + grpc::ClientContext context_; + BenchmarkService::Stub* stub_; + CompletionQueue* cq_; + std::unique_ptr alarm_; + const RequestType& req_; + ResponseType response_; + enum State { + INVALID, + STREAM_IDLE, + WAIT, + READY_TO_WRITE, + WRITE_DONE, + }; + State next_state_; + std::function callback_; + std::function next_issue_; + std::function>( + BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*, + CompletionQueue*)> + prepare_req_; + grpc::Status status_; + double start_; + std::unique_ptr> stream_; + + void StartInternal(CompletionQueue* cq) { + cq_ = cq; + stream_ = prepare_req_(stub_, &context_, &response_, cq); + next_state_ = State::STREAM_IDLE; + stream_->StartCall(ClientRpcContext::tag(this)); + } +}; + +class AsyncStreamingFromClientClient final + : public AsyncClient { + public: + explicit AsyncStreamingFromClientClient(const ClientConfig& config) + : AsyncClient( + config, SetupCtx, BenchmarkStubCreator) { + StartThreads(num_async_threads_); + } + + ~AsyncStreamingFromClientClient() override {} + + private: + static void CheckDone(const grpc::Status& /*s*/, + SimpleResponse* /*response*/) {} + static std::unique_ptr> PrepareReq( + BenchmarkService::Stub* stub, grpc::ClientContext* ctx, + SimpleResponse* resp, CompletionQueue* cq) { + auto stream = stub->PrepareAsyncStreamingFromClient(ctx, resp, cq); + return stream; + }; + static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub, + std::function next_issue, + const SimpleRequest& req) { + return new ClientRpcContextStreamingFromClientImpl( + stub, req, std::move(next_issue), + AsyncStreamingFromClientClient::PrepareReq, + AsyncStreamingFromClientClient::CheckDone); + } +}; + +template +class ClientRpcContextStreamingFromServerImpl : public ClientRpcContext { + public: + ClientRpcContextStreamingFromServerImpl( + BenchmarkService::Stub* stub, const RequestType& req, + std::function next_issue, + std::function>( + BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&, + CompletionQueue*)> + prepare_req, + std::function on_done) + : context_(), + stub_(stub), + cq_(nullptr), + req_(req), + response_(), + next_state_(State::INVALID), + callback_(on_done), + next_issue_(std::move(next_issue)), + prepare_req_(prepare_req) {} + ~ClientRpcContextStreamingFromServerImpl() override {} + void Start(CompletionQueue* cq, const ClientConfig& 
config) override { + GPR_ASSERT(!config.use_coalesce_api()); // not supported + StartInternal(cq); + } + bool RunNextState(bool ok, HistogramEntry* entry) override { + while (true) { + switch (next_state_) { + case State::STREAM_IDLE: + if (!ok) { + return false; + } + start_ = UsageTimer::Now(); + next_state_ = State::READ_DONE; + stream_->Read(&response_, ClientRpcContext::tag(this)); + return true; + case State::READ_DONE: + if (!ok) { + return false; + } + entry->set_value((UsageTimer::Now() - start_) * 1e9); + callback_(status_, &response_); + next_state_ = State::STREAM_IDLE; + break; // loop around + default: + GPR_ASSERT(false); + return false; + } + } + } + void StartNewClone(CompletionQueue* cq) override { + auto* clone = new ClientRpcContextStreamingFromServerImpl( + stub_, req_, next_issue_, prepare_req_, callback_); + clone->StartInternal(cq); + } + void TryCancel() override { context_.TryCancel(); } + + private: + grpc::ClientContext context_; + BenchmarkService::Stub* stub_; + CompletionQueue* cq_; + std::unique_ptr alarm_; + const RequestType& req_; + ResponseType response_; + enum State { INVALID, STREAM_IDLE, READ_DONE }; + State next_state_; + std::function callback_; + std::function next_issue_; + std::function>( + BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&, + CompletionQueue*)> + prepare_req_; + grpc::Status status_; + double start_; + std::unique_ptr> stream_; + + void StartInternal(CompletionQueue* cq) { + // TODO(vjpai): Add support to rate-pace this + cq_ = cq; + stream_ = prepare_req_(stub_, &context_, req_, cq); + next_state_ = State::STREAM_IDLE; + stream_->StartCall(ClientRpcContext::tag(this)); + } +}; + +class AsyncStreamingFromServerClient final + : public AsyncClient { + public: + explicit AsyncStreamingFromServerClient(const ClientConfig& config) + : AsyncClient( + config, SetupCtx, BenchmarkStubCreator) { + StartThreads(num_async_threads_); + } + + ~AsyncStreamingFromServerClient() override {} + + private: + static void CheckDone(const grpc::Status& /*s*/, + SimpleResponse* /*response*/) {} + static std::unique_ptr> PrepareReq( + BenchmarkService::Stub* stub, grpc::ClientContext* ctx, + const SimpleRequest& req, CompletionQueue* cq) { + auto stream = stub->PrepareAsyncStreamingFromServer(ctx, req, cq); + return stream; + }; + static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub, + std::function next_issue, + const SimpleRequest& req) { + return new ClientRpcContextStreamingFromServerImpl( + stub, req, std::move(next_issue), + AsyncStreamingFromServerClient::PrepareReq, + AsyncStreamingFromServerClient::CheckDone); + } +}; + +class ClientRpcContextGenericStreamingImpl : public ClientRpcContext { + public: + ClientRpcContextGenericStreamingImpl( + grpc::GenericStub* stub, const ByteBuffer& req, + std::function next_issue, + std::function( + grpc::GenericStub*, grpc::ClientContext*, + const std::string& method_name, CompletionQueue*)> + prepare_req, + std::function on_done) + : context_(), + stub_(stub), + cq_(nullptr), + req_(req), + response_(), + next_state_(State::INVALID), + callback_(std::move(on_done)), + next_issue_(std::move(next_issue)), + prepare_req_(std::move(prepare_req)) {} + ~ClientRpcContextGenericStreamingImpl() override {} + void Start(CompletionQueue* cq, const ClientConfig& config) override { + GPR_ASSERT(!config.use_coalesce_api()); // not supported yet. 
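// The generic path below differs from the typed contexts above in two ways:
// the RPC method is named explicitly as a string
// ("/grpc.testing.BenchmarkService/StreamingCall" in StartInternal) and the
// payloads travel as opaque ByteBuffers, so the client skips protobuf
// serialization entirely. Hedged sketch of building such a payload (the size
// and names are made up):
//
//   std::string payload(64, 'x');
//   grpc::Slice slice(payload);
//   grpc::ByteBuffer req(&slice, 1);  // single-slice ByteBuffer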
+ StartInternal(cq, config.messages_per_stream()); + } + bool RunNextState(bool ok, HistogramEntry* entry) override { + while (true) { + switch (next_state_) { + case State::STREAM_IDLE: + if (!next_issue_) { // ready to issue + next_state_ = State::READY_TO_WRITE; + } else { + next_state_ = State::WAIT; + } + break; // loop around, don't return + case State::WAIT: + next_state_ = State::READY_TO_WRITE; + alarm_ = absl::make_unique(); + alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this)); + return true; + case State::READY_TO_WRITE: + if (!ok) { + return false; + } + start_ = UsageTimer::Now(); + next_state_ = State::WRITE_DONE; + stream_->Write(req_, ClientRpcContext::tag(this)); + return true; + case State::WRITE_DONE: + if (!ok) { + return false; + } + next_state_ = State::READ_DONE; + stream_->Read(&response_, ClientRpcContext::tag(this)); + return true; + case State::READ_DONE: + entry->set_value((UsageTimer::Now() - start_) * 1e9); + callback_(status_, &response_); + if ((messages_per_stream_ != 0) && + (++messages_issued_ >= messages_per_stream_)) { + next_state_ = State::WRITES_DONE_DONE; + stream_->WritesDone(ClientRpcContext::tag(this)); + return true; + } + next_state_ = State::STREAM_IDLE; + break; // loop around + case State::WRITES_DONE_DONE: + next_state_ = State::FINISH_DONE; + stream_->Finish(&status_, ClientRpcContext::tag(this)); + return true; + case State::FINISH_DONE: + next_state_ = State::INVALID; + return false; + default: + GPR_ASSERT(false); + return false; + } + } + } + void StartNewClone(CompletionQueue* cq) override { + auto* clone = new ClientRpcContextGenericStreamingImpl( + stub_, req_, next_issue_, prepare_req_, callback_); + clone->StartInternal(cq, messages_per_stream_); + } + void TryCancel() override { context_.TryCancel(); } + + private: + grpc::ClientContext context_; + grpc::GenericStub* stub_; + CompletionQueue* cq_; + std::unique_ptr alarm_; + ByteBuffer req_; + ByteBuffer response_; + enum State { + INVALID, + STREAM_IDLE, + WAIT, + READY_TO_WRITE, + WRITE_DONE, + READ_DONE, + WRITES_DONE_DONE, + FINISH_DONE + }; + State next_state_; + std::function callback_; + std::function next_issue_; + std::function( + grpc::GenericStub*, grpc::ClientContext*, const std::string&, + CompletionQueue*)> + prepare_req_; + grpc::Status status_; + double start_; + std::unique_ptr stream_; + + // Allow a limit on number of messages in a stream + int messages_per_stream_; + int messages_issued_; + + void StartInternal(CompletionQueue* cq, int messages_per_stream) { + cq_ = cq; + const std::string kMethodName( + "/grpc.testing.BenchmarkService/StreamingCall"); + messages_per_stream_ = messages_per_stream; + messages_issued_ = 0; + stream_ = prepare_req_(stub_, &context_, kMethodName, cq); + next_state_ = State::STREAM_IDLE; + stream_->StartCall(ClientRpcContext::tag(this)); + } +}; + +static std::unique_ptr GenericStubCreator( + const std::shared_ptr& ch) { + return absl::make_unique(ch); +} + +class GenericAsyncStreamingClient final + : public AsyncClient { + public: + explicit GenericAsyncStreamingClient(const ClientConfig& config) + : AsyncClient(config, SetupCtx, + GenericStubCreator) { + StartThreads(num_async_threads_); + } + + ~GenericAsyncStreamingClient() override {} + + private: + static void CheckDone(const grpc::Status& /*s*/, ByteBuffer* /*response*/) {} + static std::unique_ptr PrepareReq( + grpc::GenericStub* stub, grpc::ClientContext* ctx, + const std::string& method_name, CompletionQueue* cq) { + auto stream = stub->PrepareCall(ctx, 
method_name, cq); + return stream; + }; + static ClientRpcContext* SetupCtx(grpc::GenericStub* stub, + std::function next_issue, + const ByteBuffer& req) { + return new ClientRpcContextGenericStreamingImpl( + stub, req, std::move(next_issue), + GenericAsyncStreamingClient::PrepareReq, + GenericAsyncStreamingClient::CheckDone); + } +}; + +std::unique_ptr CreateAsyncClient(const ClientConfig& config) { + switch (config.rpc_type()) { + case UNARY: + return std::unique_ptr(new AsyncUnaryClient(config)); + case STREAMING: + return std::unique_ptr(new AsyncStreamingPingPongClient(config)); + case STREAMING_FROM_CLIENT: + return std::unique_ptr( + new AsyncStreamingFromClientClient(config)); + case STREAMING_FROM_SERVER: + return std::unique_ptr( + new AsyncStreamingFromServerClient(config)); + case STREAMING_BOTH_WAYS: + // TODO(vjpai): Implement this + assert(false); + return nullptr; + default: + assert(false); + return nullptr; + } +} +std::unique_ptr CreateGenericAsyncStreamingClient( + const ClientConfig& config) { + return std::unique_ptr(new GenericAsyncStreamingClient(config)); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/client_callback.cc b/test/cpp/qps/client_callback.cc new file mode 100644 index 00000000..3119d2a1 --- /dev/null +++ b/test/cpp/qps/client_callback.cc @@ -0,0 +1,390 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/benchmark_service.grpc.pb.h" +#include "test/cpp/qps/client.h" +#include "test/cpp/qps/usage_timer.h" + +namespace grpc { +namespace testing { + +/** + * Maintains context info per RPC + */ +struct CallbackClientRpcContext { + explicit CallbackClientRpcContext(BenchmarkService::Stub* stub) + : alarm_(nullptr), stub_(stub) {} + + ~CallbackClientRpcContext() {} + + SimpleResponse response_; + ClientContext context_; + std::unique_ptr alarm_; + BenchmarkService::Stub* stub_; +}; + +static std::unique_ptr BenchmarkStubCreator( + const std::shared_ptr& ch) { + return BenchmarkService::NewStub(ch); +} + +class CallbackClient + : public ClientImpl { + public: + explicit CallbackClient(const ClientConfig& config) + : ClientImpl( + config, BenchmarkStubCreator) { + num_threads_ = NumThreads(config); + rpcs_done_ = 0; + + // Don't divide the fixed load among threads as the user threads + // only bootstrap the RPCs + SetupLoadTest(config, 1); + total_outstanding_rpcs_ = + config.client_channels() * config.outstanding_rpcs_per_channel(); + } + + ~CallbackClient() override {} + + /** + * The main thread of the benchmark will be waiting on DestroyMultithreading. + * Increment the rpcs_done_ variable to signify that the Callback RPC + * after thread completion is done. 
When the last outstanding rpc increments + * the counter it should also signal the main thread's conditional variable. + */ + void NotifyMainThreadOfThreadCompletion() { + std::lock_guard l(shutdown_mu_); + rpcs_done_++; + if (rpcs_done_ == total_outstanding_rpcs_) { + shutdown_cv_.notify_one(); + } + } + + gpr_timespec NextRPCIssueTime() { + std::lock_guard l(next_issue_time_mu_); + return Client::NextIssueTime(0); + } + + protected: + size_t num_threads_; + size_t total_outstanding_rpcs_; + // The below mutex and condition variable is used by main benchmark thread to + // wait on completion of all RPCs before shutdown + std::mutex shutdown_mu_; + std::condition_variable shutdown_cv_; + // Number of rpcs done after thread completion + size_t rpcs_done_; + // Vector of Context data pointers for running a RPC + std::vector> ctx_; + + virtual void InitThreadFuncImpl(size_t thread_idx) = 0; + virtual bool ThreadFuncImpl(Thread* t, size_t thread_idx) = 0; + + void ThreadFunc(size_t thread_idx, Thread* t) override { + InitThreadFuncImpl(thread_idx); + ThreadFuncImpl(t, thread_idx); + } + + private: + std::mutex next_issue_time_mu_; // Used by next issue time + + int NumThreads(const ClientConfig& config) { + int num_threads = config.async_client_threads(); + if (num_threads <= 0) { // Use dynamic sizing + num_threads = cores_; + gpr_log(GPR_INFO, "Sizing callback client to %d threads", num_threads); + } + return num_threads; + } + + /** + * Wait until all outstanding Callback RPCs are done + */ + void DestroyMultithreading() final { + std::unique_lock l(shutdown_mu_); + while (rpcs_done_ != total_outstanding_rpcs_) { + shutdown_cv_.wait(l); + } + EndThreads(); + } +}; + +class CallbackUnaryClient final : public CallbackClient { + public: + explicit CallbackUnaryClient(const ClientConfig& config) + : CallbackClient(config) { + for (int ch = 0; ch < config.client_channels(); ch++) { + for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) { + ctx_.emplace_back( + new CallbackClientRpcContext(channels_[ch].get_stub())); + } + } + StartThreads(num_threads_); + } + ~CallbackUnaryClient() override {} + + protected: + bool ThreadFuncImpl(Thread* t, size_t thread_idx) override { + for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_; + vector_idx += num_threads_) { + ScheduleRpc(t, vector_idx); + } + return true; + } + + void InitThreadFuncImpl(size_t /*thread_idx*/) override {} + + private: + void ScheduleRpc(Thread* t, size_t vector_idx) { + if (!closed_loop_) { + gpr_timespec next_issue_time = NextRPCIssueTime(); + // Start an alarm callback to run the internal callback after + // next_issue_time + if (ctx_[vector_idx]->alarm_ == nullptr) { + ctx_[vector_idx]->alarm_ = absl::make_unique(); + } + ctx_[vector_idx]->alarm_->Set(next_issue_time, + [this, t, vector_idx](bool /*ok*/) { + IssueUnaryCallbackRpc(t, vector_idx); + }); + } else { + IssueUnaryCallbackRpc(t, vector_idx); + } + } + + void IssueUnaryCallbackRpc(Thread* t, size_t vector_idx) { + GPR_TIMER_SCOPE("CallbackUnaryClient::ThreadFunc", 0); + double start = UsageTimer::Now(); + ctx_[vector_idx]->stub_->async()->UnaryCall( + (&ctx_[vector_idx]->context_), &request_, &ctx_[vector_idx]->response_, + [this, t, start, vector_idx](grpc::Status s) { + // Update Histogram with data from the callback run + HistogramEntry entry; + if (s.ok()) { + entry.set_value((UsageTimer::Now() - start) * 1e9); + } + entry.set_status(s.error_code()); + t->UpdateHistogram(&entry); + + if (ThreadCompleted() || !s.ok()) { + // Notify 
thread of completion + NotifyMainThreadOfThreadCompletion(); + } else { + // Reallocate ctx for next RPC + ctx_[vector_idx] = absl::make_unique( + ctx_[vector_idx]->stub_); + // Schedule a new RPC + ScheduleRpc(t, vector_idx); + } + }); + } +}; + +class CallbackStreamingClient : public CallbackClient { + public: + explicit CallbackStreamingClient(const ClientConfig& config) + : CallbackClient(config), + messages_per_stream_(config.messages_per_stream()) { + for (int ch = 0; ch < config.client_channels(); ch++) { + for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) { + ctx_.emplace_back( + new CallbackClientRpcContext(channels_[ch].get_stub())); + } + } + StartThreads(num_threads_); + } + ~CallbackStreamingClient() override {} + + void AddHistogramEntry(double start, bool ok, Thread* thread_ptr) { + // Update Histogram with data from the callback run + HistogramEntry entry; + if (ok) { + entry.set_value((UsageTimer::Now() - start) * 1e9); + } + thread_ptr->UpdateHistogram(&entry); + } + + int messages_per_stream() { return messages_per_stream_; } + + protected: + const int messages_per_stream_; +}; + +class CallbackStreamingPingPongClient : public CallbackStreamingClient { + public: + explicit CallbackStreamingPingPongClient(const ClientConfig& config) + : CallbackStreamingClient(config) {} + ~CallbackStreamingPingPongClient() override {} +}; + +class CallbackStreamingPingPongReactor final + : public grpc::ClientBidiReactor { + public: + CallbackStreamingPingPongReactor( + CallbackStreamingPingPongClient* client, + std::unique_ptr ctx) + : client_(client), ctx_(std::move(ctx)), messages_issued_(0) {} + + void StartNewRpc() { + ctx_->stub_->async()->StreamingCall(&(ctx_->context_), this); + write_time_ = UsageTimer::Now(); + StartWrite(client_->request()); + writes_done_started_.clear(); + StartCall(); + } + + void OnWriteDone(bool ok) override { + if (!ok) { + gpr_log(GPR_ERROR, "Error writing RPC"); + } + if ((!ok || client_->ThreadCompleted()) && + !writes_done_started_.test_and_set()) { + StartWritesDone(); + } + StartRead(&ctx_->response_); + } + + void OnReadDone(bool ok) override { + client_->AddHistogramEntry(write_time_, ok, thread_ptr_); + + if (client_->ThreadCompleted() || !ok || + (client_->messages_per_stream() != 0 && + ++messages_issued_ >= client_->messages_per_stream())) { + if (!ok) { + gpr_log(GPR_ERROR, "Error reading RPC"); + } + if (!writes_done_started_.test_and_set()) { + StartWritesDone(); + } + return; + } + if (!client_->IsClosedLoop()) { + gpr_timespec next_issue_time = client_->NextRPCIssueTime(); + // Start an alarm callback to run the internal callback after + // next_issue_time + ctx_->alarm_->Set(next_issue_time, [this](bool /*ok*/) { + write_time_ = UsageTimer::Now(); + StartWrite(client_->request()); + }); + } else { + write_time_ = UsageTimer::Now(); + StartWrite(client_->request()); + } + } + + void OnDone(const Status& s) override { + if (client_->ThreadCompleted() || !s.ok()) { + client_->NotifyMainThreadOfThreadCompletion(); + return; + } + ctx_ = absl::make_unique(ctx_->stub_); + ScheduleRpc(); + } + + void ScheduleRpc() { + if (!client_->IsClosedLoop()) { + gpr_timespec next_issue_time = client_->NextRPCIssueTime(); + // Start an alarm callback to run the internal callback after + // next_issue_time + if (ctx_->alarm_ == nullptr) { + ctx_->alarm_ = absl::make_unique(); + } + ctx_->alarm_->Set(next_issue_time, + [this](bool /*ok*/) { StartNewRpc(); }); + } else { + StartNewRpc(); + } + } + + void set_thread_ptr(Client::Thread* ptr) { 
thread_ptr_ = ptr; } + + CallbackStreamingPingPongClient* client_; + std::unique_ptr ctx_; + std::atomic_flag writes_done_started_; + Client::Thread* thread_ptr_; // Needed to update histogram entries + double write_time_; // Track ping-pong round start time + int messages_issued_; // Messages issued by this stream +}; + +class CallbackStreamingPingPongClientImpl final + : public CallbackStreamingPingPongClient { + public: + explicit CallbackStreamingPingPongClientImpl(const ClientConfig& config) + : CallbackStreamingPingPongClient(config) { + for (size_t i = 0; i < total_outstanding_rpcs_; i++) { + reactor_.emplace_back( + new CallbackStreamingPingPongReactor(this, std::move(ctx_[i]))); + } + } + ~CallbackStreamingPingPongClientImpl() override {} + + bool ThreadFuncImpl(Client::Thread* t, size_t thread_idx) override { + for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_; + vector_idx += num_threads_) { + reactor_[vector_idx]->set_thread_ptr(t); + reactor_[vector_idx]->ScheduleRpc(); + } + return true; + } + + void InitThreadFuncImpl(size_t /*thread_idx*/) override {} + + private: + std::vector> reactor_; +}; + +// TODO(mhaidry) : Implement Streaming from client, server and both ways + +std::unique_ptr CreateCallbackClient(const ClientConfig& config) { + switch (config.rpc_type()) { + case UNARY: + return std::unique_ptr(new CallbackUnaryClient(config)); + case STREAMING: + return std::unique_ptr( + new CallbackStreamingPingPongClientImpl(config)); + case STREAMING_FROM_CLIENT: + case STREAMING_FROM_SERVER: + case STREAMING_BOTH_WAYS: + assert(false); + return nullptr; + default: + assert(false); + return nullptr; + } +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/client_sync.cc b/test/cpp/qps/client_sync.cc new file mode 100644 index 00000000..72fff1a8 --- /dev/null +++ b/test/cpp/qps/client_sync.cc @@ -0,0 +1,428 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
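// A note on the callback client above: it drives streams through the C++
// callback API instead of a completion queue. Each stream owns a
// ClientBidiReactor whose OnWriteDone/OnReadDone hooks chain the next write
// or read, and OnDone either re-arms the reactor for a new RPC or signals the
// main thread via NotifyMainThreadOfThreadCompletion(). A minimal reactor of
// the same shape, as a hedged sketch (the class name is made up; registration
// happens as in StartNewRpc() above via stub->async()->StreamingCall()):
//
//   class OneShotPingPongReactor
//       : public grpc::ClientBidiReactor<SimpleRequest, SimpleResponse> {
//    public:
//     explicit OneShotPingPongReactor(SimpleRequest req) : req_(std::move(req)) {}
//     void Begin() { StartWrite(&req_); StartCall(); }
//     void OnWriteDone(bool ok) override { if (ok) StartRead(&resp_); }
//     void OnReadDone(bool /*ok*/) override { StartWritesDone(); }
//     void OnDone(const grpc::Status& s) override { /* record s, clean up */ }
//    private:
//     SimpleRequest req_;
//     SimpleResponse resp_;
//   };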
+ * + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/profiling/timers.h" +#include "src/proto/grpc/testing/benchmark_service.grpc.pb.h" +#include "test/cpp/qps/client.h" +#include "test/cpp/qps/interarrival.h" +#include "test/cpp/qps/usage_timer.h" + +namespace grpc { +namespace testing { + +static std::unique_ptr BenchmarkStubCreator( + const std::shared_ptr& ch) { + return BenchmarkService::NewStub(ch); +} + +class SynchronousClient + : public ClientImpl { + public: + explicit SynchronousClient(const ClientConfig& config) + : ClientImpl( + config, BenchmarkStubCreator) { + num_threads_ = + config.outstanding_rpcs_per_channel() * config.client_channels(); + responses_.resize(num_threads_); + SetupLoadTest(config, num_threads_); + } + + ~SynchronousClient() override {} + + virtual bool InitThreadFuncImpl(size_t thread_idx) = 0; + virtual bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) = 0; + + void ThreadFunc(size_t thread_idx, Thread* t) override { + if (!InitThreadFuncImpl(thread_idx)) { + return; + } + for (;;) { + // run the loop body + HistogramEntry entry; + const bool thread_still_ok = ThreadFuncImpl(&entry, thread_idx); + t->UpdateHistogram(&entry); + if (!thread_still_ok || ThreadCompleted()) { + return; + } + } + } + + protected: + // WaitToIssue returns false if we realize that we need to break out + bool WaitToIssue(int thread_idx) { + if (!closed_loop_) { + const gpr_timespec next_issue_time = NextIssueTime(thread_idx); + // Avoid sleeping for too long continuously because we might + // need to terminate before then. This is an issue since + // exponential distribution can occasionally produce bad outliers + while (true) { + const gpr_timespec one_sec_delay = + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(1, GPR_TIMESPAN)); + if (gpr_time_cmp(next_issue_time, one_sec_delay) <= 0) { + gpr_sleep_until(next_issue_time); + return true; + } else { + gpr_sleep_until(one_sec_delay); + if (gpr_atm_acq_load(&thread_pool_done_) != static_cast(0)) { + return false; + } + } + } + } + return true; + } + + size_t num_threads_; + std::vector responses_; +}; + +class SynchronousUnaryClient final : public SynchronousClient { + public: + explicit SynchronousUnaryClient(const ClientConfig& config) + : SynchronousClient(config) { + StartThreads(num_threads_); + } + ~SynchronousUnaryClient() override {} + + bool InitThreadFuncImpl(size_t /*thread_idx*/) override { return true; } + + bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { + if (!WaitToIssue(thread_idx)) { + return true; + } + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + double start = UsageTimer::Now(); + GPR_TIMER_SCOPE("SynchronousUnaryClient::ThreadFunc", 0); + grpc::ClientContext context; + grpc::Status s = + stub->UnaryCall(&context, request_, &responses_[thread_idx]); + if (s.ok()) { + entry->set_value((UsageTimer::Now() - start) * 1e9); + } + entry->set_status(s.error_code()); + return true; + } + + private: + void DestroyMultithreading() final { EndThreads(); } +}; + +template +class SynchronousStreamingClient : public SynchronousClient { + public: + explicit SynchronousStreamingClient(const ClientConfig& config) + : SynchronousClient(config), + context_(num_threads_), + stream_(num_threads_), + stream_mu_(num_threads_), + shutdown_(num_threads_), + messages_per_stream_(config.messages_per_stream()), + 
messages_issued_(num_threads_) { + StartThreads(num_threads_); + } + ~SynchronousStreamingClient() override { + CleanupAllStreams([this](size_t thread_idx) { + // Don't log any kind of error since we may have canceled this + stream_[thread_idx]->Finish().IgnoreError(); + }); + } + + protected: + std::vector context_; + std::vector> stream_; + // stream_mu_ is only needed when changing an element of stream_ or context_ + std::vector stream_mu_; + // use struct Bool rather than bool because vector is not concurrent + struct Bool { + bool val; + Bool() : val(false) {} + }; + std::vector shutdown_; + const int messages_per_stream_; + std::vector messages_issued_; + + void FinishStream(HistogramEntry* entry, size_t thread_idx) { + Status s = stream_[thread_idx]->Finish(); + // don't set the value since the stream is failed and shouldn't be timed + entry->set_status(s.error_code()); + if (!s.ok()) { + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + gpr_log(GPR_ERROR, "Stream %" PRIuPTR " received an error %s", + thread_idx, s.error_message().c_str()); + } + } + // Lock the stream_mu_ now because the client context could change + std::lock_guard l(stream_mu_[thread_idx]); + context_[thread_idx].~ClientContext(); + new (&context_[thread_idx]) ClientContext(); + } + + void CleanupAllStreams(const std::function& cleaner) { + std::vector cleanup_threads; + for (size_t i = 0; i < num_threads_; i++) { + cleanup_threads.emplace_back([this, i, cleaner] { + std::lock_guard l(stream_mu_[i]); + shutdown_[i].val = true; + if (stream_[i]) { + cleaner(i); + } + }); + } + for (auto& th : cleanup_threads) { + th.join(); + } + } + + private: + void DestroyMultithreading() final { + CleanupAllStreams( + [this](size_t thread_idx) { context_[thread_idx].TryCancel(); }); + EndThreads(); + } +}; + +class SynchronousStreamingPingPongClient final + : public SynchronousStreamingClient< + grpc::ClientReaderWriter> { + public: + explicit SynchronousStreamingPingPongClient(const ClientConfig& config) + : SynchronousStreamingClient(config) {} + ~SynchronousStreamingPingPongClient() override { + CleanupAllStreams( + [this](size_t thread_idx) { stream_[thread_idx]->WritesDone(); }); + } + + private: + bool InitThreadFuncImpl(size_t thread_idx) override { + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); + } else { + return false; + } + messages_issued_[thread_idx] = 0; + return true; + } + + bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { + if (!WaitToIssue(thread_idx)) { + return true; + } + GPR_TIMER_SCOPE("SynchronousStreamingPingPongClient::ThreadFunc", 0); + double start = UsageTimer::Now(); + if (stream_[thread_idx]->Write(request_) && + stream_[thread_idx]->Read(&responses_[thread_idx])) { + entry->set_value((UsageTimer::Now() - start) * 1e9); + // don't set the status since there isn't one yet + if ((messages_per_stream_ != 0) && + (++messages_issued_[thread_idx] < messages_per_stream_)) { + return true; + } else if (messages_per_stream_ == 0) { + return true; + } else { + // Fall through to the below resetting code after finish + } + } + stream_[thread_idx]->WritesDone(); + FinishStream(entry, thread_idx); + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = 
stub->StreamingCall(&context_[thread_idx]); + } else { + stream_[thread_idx].reset(); + return false; + } + messages_issued_[thread_idx] = 0; + return true; + } +}; + +class SynchronousStreamingFromClientClient final + : public SynchronousStreamingClient> { + public: + explicit SynchronousStreamingFromClientClient(const ClientConfig& config) + : SynchronousStreamingClient(config), last_issue_(num_threads_) {} + ~SynchronousStreamingFromClientClient() override { + CleanupAllStreams( + [this](size_t thread_idx) { stream_[thread_idx]->WritesDone(); }); + } + + private: + std::vector last_issue_; + + bool InitThreadFuncImpl(size_t thread_idx) override { + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], + &responses_[thread_idx]); + } else { + return false; + } + last_issue_[thread_idx] = UsageTimer::Now(); + return true; + } + + bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { + // Figure out how to make histogram sensible if this is rate-paced + if (!WaitToIssue(thread_idx)) { + return true; + } + GPR_TIMER_SCOPE("SynchronousStreamingFromClientClient::ThreadFunc", 0); + if (stream_[thread_idx]->Write(request_)) { + double now = UsageTimer::Now(); + entry->set_value((now - last_issue_[thread_idx]) * 1e9); + last_issue_[thread_idx] = now; + return true; + } + stream_[thread_idx]->WritesDone(); + FinishStream(entry, thread_idx); + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], + &responses_[thread_idx]); + } else { + stream_[thread_idx].reset(); + return false; + } + return true; + } +}; + +class SynchronousStreamingFromServerClient final + : public SynchronousStreamingClient> { + public: + explicit SynchronousStreamingFromServerClient(const ClientConfig& config) + : SynchronousStreamingClient(config), last_recv_(num_threads_) {} + ~SynchronousStreamingFromServerClient() override {} + + private: + std::vector last_recv_; + + bool InitThreadFuncImpl(size_t thread_idx) override { + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = + stub->StreamingFromServer(&context_[thread_idx], request_); + } else { + return false; + } + last_recv_[thread_idx] = UsageTimer::Now(); + return true; + } + + bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { + GPR_TIMER_SCOPE("SynchronousStreamingFromServerClient::ThreadFunc", 0); + if (stream_[thread_idx]->Read(&responses_[thread_idx])) { + double now = UsageTimer::Now(); + entry->set_value((now - last_recv_[thread_idx]) * 1e9); + last_recv_[thread_idx] = now; + return true; + } + FinishStream(entry, thread_idx); + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = + stub->StreamingFromServer(&context_[thread_idx], request_); + } else { + stream_[thread_idx].reset(); + return false; + } + return true; + } +}; + +class SynchronousStreamingBothWaysClient final + : public SynchronousStreamingClient< + grpc::ClientReaderWriter> { + public: + explicit SynchronousStreamingBothWaysClient(const ClientConfig& config) + : 
SynchronousStreamingClient(config) {} + ~SynchronousStreamingBothWaysClient() override { + CleanupAllStreams( + [this](size_t thread_idx) { stream_[thread_idx]->WritesDone(); }); + } + + private: + bool InitThreadFuncImpl(size_t thread_idx) override { + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + std::lock_guard l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingBothWays(&context_[thread_idx]); + } else { + return false; + } + return true; + } + + bool ThreadFuncImpl(HistogramEntry* /*entry*/, + size_t /*thread_idx*/) override { + // TODO (vjpai): Do this + return true; + } +}; + +std::unique_ptr CreateSynchronousClient(const ClientConfig& config) { + GPR_ASSERT(!config.use_coalesce_api()); // not supported yet. + switch (config.rpc_type()) { + case UNARY: + return std::unique_ptr(new SynchronousUnaryClient(config)); + case STREAMING: + return std::unique_ptr( + new SynchronousStreamingPingPongClient(config)); + case STREAMING_FROM_CLIENT: + return std::unique_ptr( + new SynchronousStreamingFromClientClient(config)); + case STREAMING_FROM_SERVER: + return std::unique_ptr( + new SynchronousStreamingFromServerClient(config)); + case STREAMING_BOTH_WAYS: + return std::unique_ptr( + new SynchronousStreamingBothWaysClient(config)); + default: + assert(false); + return nullptr; + } +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/driver.cc b/test/cpp/qps/driver.cc new file mode 100644 index 00000000..0a70084e --- /dev/null +++ b/test/cpp/qps/driver.cc @@ -0,0 +1,685 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
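// Design note on the synchronous clients above: grpc::ClientContext is
// neither copyable nor movable and is valid for exactly one RPC, so
// FinishStream() recycles a slot of the context_ vector with an explicit
// destructor call followed by placement new rather than assignment. The same
// idiom in isolation, as a hedged sketch:
//
//   std::vector<grpc::ClientContext> contexts(n);  // sized once, never resized
//   contexts[i].~ClientContext();                  // end the old lifetime
//   new (&contexts[i]) grpc::ClientContext();      // fresh context in place
//
// WaitToIssue() above similarly sleeps in slices of at most one second so a
// thread waiting on a far-future issue time still notices shutdown promptly.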
+ * + */ + +#include "test/cpp/qps/driver.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/profiling/timers.h" +#include "src/proto/grpc/testing/worker_service.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/qps/client.h" +#include "test/cpp/qps/histogram.h" +#include "test/cpp/qps/qps_worker.h" +#include "test/cpp/qps/stats.h" +#include "test/cpp/util/test_credentials_provider.h" + +using std::deque; +using std::list; +using std::unique_ptr; +using std::vector; + +namespace grpc { +namespace testing { +static std::string get_host(const std::string& worker) { + absl::string_view host; + absl::string_view port; + grpc_core::SplitHostPort(worker.c_str(), &host, &port); + return std::string(host.data(), host.size()); +} + +static deque get_workers(const string& env_name) { + deque out; + char* env = gpr_getenv(env_name.c_str()); + if (!env) { + env = gpr_strdup(""); + } + char* p = env; + if (strlen(env) != 0) { + for (;;) { + char* comma = strchr(p, ','); + if (comma) { + out.emplace_back(p, comma); + p = comma + 1; + } else { + out.emplace_back(p); + break; + } + } + } + if (out.empty()) { + gpr_log(GPR_ERROR, + "Environment variable \"%s\" does not contain a list of QPS " + "workers to use. Set it to a comma-separated list of " + "hostname:port pairs, starting with hosts that should act as " + "servers. E.g. export " + "%s=\"serverhost1:1234,clienthost1:1234,clienthost2:1234\"", + env_name.c_str(), env_name.c_str()); + } + gpr_free(env); + return out; +} + +std::string GetCredType( + const std::string& worker_addr, + const std::map& per_worker_credential_types, + const std::string& credential_type) { + auto it = per_worker_credential_types.find(worker_addr); + if (it != per_worker_credential_types.end()) { + return it->second; + } + return credential_type; +} + +// helpers for postprocess_scenario_result +static double WallTime(const ClientStats& s) { return s.time_elapsed(); } +static double SystemTime(const ClientStats& s) { return s.time_system(); } +static double UserTime(const ClientStats& s) { return s.time_user(); } +static double CliPollCount(const ClientStats& s) { return s.cq_poll_count(); } +static double SvrPollCount(const ServerStats& s) { return s.cq_poll_count(); } +static double ServerWallTime(const ServerStats& s) { return s.time_elapsed(); } +static double ServerSystemTime(const ServerStats& s) { return s.time_system(); } +static double ServerUserTime(const ServerStats& s) { return s.time_user(); } +static double ServerTotalCpuTime(const ServerStats& s) { + return s.total_cpu_time(); +} +static double ServerIdleCpuTime(const ServerStats& s) { + return s.idle_cpu_time(); +} +static int Cores(int n) { return n; } + +static bool IsSuccess(const Status& s) { + if (s.ok()) return true; + // Since we shutdown servers and clients at the same time, they both can + // observe cancellation. Thus, we consider CANCELLED as good status. + if (static_cast(s.error_code()) == StatusCode::CANCELLED) { + return true; + } + // Since we shutdown servers and clients at the same time, server can close + // the socket before the client attempts to do that, and vice versa. Thus + // receiving a "Socket closed" error is fine. + if (s.error_message() == "Socket closed") return true; + return false; +} + +// Postprocess ScenarioResult and populate result summary. 
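Note (illustrative sketch, not from the gRPC sources): get_workers() above turns the QPS_WORKERS environment variable, a comma-separated list of host:port entries with the server hosts listed first, into a deque of worker addresses. A rough equivalent of that splitting written with standard streams instead of the gpr string helpers:

// (illustrative sketch, not part of this patch)
#include <cstdlib>
#include <deque>
#include <iostream>
#include <sstream>
#include <string>

// Splits a comma-separated "host:port,host:port,..." list the same way
// get_workers() above does; server entries are expected to come first.
static std::deque<std::string> SplitWorkerList(const std::string& csv) {
  std::deque<std::string> out;
  std::stringstream ss(csv);
  std::string entry;
  while (std::getline(ss, entry, ',')) {
    if (!entry.empty()) out.push_back(entry);
  }
  return out;
}

int main() {
  const char* env = std::getenv("QPS_WORKERS");  // e.g. "srv:1234,cli:1235"
  for (const auto& w : SplitWorkerList(env != nullptr ? env : "")) {
    std::cout << w << "\n";
  }
}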
+static void postprocess_scenario_result(ScenarioResult* result) { + // Get latencies from ScenarioResult latencies histogram and populate to + // result summary. + Histogram histogram; + histogram.MergeProto(result->latencies()); + result->mutable_summary()->set_latency_50(histogram.Percentile(50)); + result->mutable_summary()->set_latency_90(histogram.Percentile(90)); + result->mutable_summary()->set_latency_95(histogram.Percentile(95)); + result->mutable_summary()->set_latency_99(histogram.Percentile(99)); + result->mutable_summary()->set_latency_999(histogram.Percentile(99.9)); + + // Calculate qps and cpu load for each client and then aggregate results for + // all clients + double qps = 0; + double client_system_cpu_load = 0, client_user_cpu_load = 0; + for (int i = 0; i < result->client_stats_size(); i++) { + auto client_stat = result->client_stats(i); + qps += client_stat.latencies().count() / client_stat.time_elapsed(); + client_system_cpu_load += + client_stat.time_system() / client_stat.time_elapsed(); + client_user_cpu_load += + client_stat.time_user() / client_stat.time_elapsed(); + } + // Calculate cpu load for each server and then aggregate results for all + // servers + double server_system_cpu_load = 0, server_user_cpu_load = 0; + for (int i = 0; i < result->server_stats_size(); i++) { + auto server_stat = result->server_stats(i); + server_system_cpu_load += + server_stat.time_system() / server_stat.time_elapsed(); + server_user_cpu_load += + server_stat.time_user() / server_stat.time_elapsed(); + } + result->mutable_summary()->set_qps(qps); + // Populate the percentage of cpu load to result summary. + result->mutable_summary()->set_server_system_time(100 * + server_system_cpu_load); + result->mutable_summary()->set_server_user_time(100 * server_user_cpu_load); + result->mutable_summary()->set_client_system_time(100 * + client_system_cpu_load); + result->mutable_summary()->set_client_user_time(100 * client_user_cpu_load); + + // For Non-linux platform, get_cpu_usage() is not implemented. Thus, + // ServerTotalCpuTime and ServerIdleCpuTime are both 0. + if (average(result->server_stats(), ServerTotalCpuTime) == 0) { + result->mutable_summary()->set_server_cpu_usage(0); + } else { + auto server_cpu_usage = + 100 - 100 * average(result->server_stats(), ServerIdleCpuTime) / + average(result->server_stats(), ServerTotalCpuTime); + result->mutable_summary()->set_server_cpu_usage(server_cpu_usage); + } + + // Calculate and populate successful request per second and failed requests + // per seconds to result summary. 
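Note (illustrative sketch, not from the gRPC sources): the aggregation above computes overall QPS as the sum of each client's own rate (count / elapsed), CPU load as the sum of per-worker time ratios scaled to percent, and server CPU usage as 100 - 100 * avg(idle) / avg(total). A self-contained sketch of the same arithmetic on plain structs; the field names are stand-ins for the proto accessors:

// (illustrative sketch, not part of this patch)
#include <cstdio>
#include <vector>

// Plain stand-ins for the per-worker stats consumed by
// postprocess_scenario_result() above.
struct ClientSample { double count, elapsed, time_system, time_user; };
struct ServerSample { double idle_cpu, total_cpu; };

int main() {
  std::vector<ClientSample> clients = {{1000, 10, 2, 3}, {2000, 10, 1, 4}};
  std::vector<ServerSample> servers = {{4, 10}, {6, 10}};

  // QPS is the sum of each client's own rate, not total count / total time,
  // so clients that ran for different wall times are weighted correctly.
  double qps = 0, client_system_load = 0;
  for (const auto& c : clients) {
    qps += c.count / c.elapsed;
    client_system_load += c.time_system / c.elapsed;
  }

  // Server CPU usage: 100 - 100 * avg(idle) / avg(total), as in the code above.
  double idle_avg = 0, total_avg = 0;
  for (const auto& s : servers) {
    idle_avg += s.idle_cpu;
    total_avg += s.total_cpu;
  }
  idle_avg /= servers.size();
  total_avg /= servers.size();
  double server_cpu_usage = 100 - 100 * idle_avg / total_avg;

  std::printf("qps=%.1f client_system=%.1f%% server_cpu=%.1f%%\n", qps,
              100 * client_system_load, server_cpu_usage);
}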
+ auto time_estimate = average(result->client_stats(), WallTime); + if (result->request_results_size() > 0) { + int64_t successes = 0; + int64_t failures = 0; + for (int i = 0; i < result->request_results_size(); i++) { + const RequestResultCount& rrc = result->request_results(i); + if (rrc.status_code() == 0) { + successes += rrc.count(); + } else { + failures += rrc.count(); + } + } + result->mutable_summary()->set_successful_requests_per_second( + successes / time_estimate); + result->mutable_summary()->set_failed_requests_per_second(failures / + time_estimate); + } + + // Fill in data for other metrics required in result summary + auto qps_per_server_core = qps / sum(result->server_cores(), Cores); + result->mutable_summary()->set_qps_per_server_core(qps_per_server_core); + result->mutable_summary()->set_client_polls_per_request( + sum(result->client_stats(), CliPollCount) / histogram.Count()); + result->mutable_summary()->set_server_polls_per_request( + sum(result->server_stats(), SvrPollCount) / histogram.Count()); + + auto server_queries_per_cpu_sec = + histogram.Count() / (sum(result->server_stats(), ServerSystemTime) + + sum(result->server_stats(), ServerUserTime)); + auto client_queries_per_cpu_sec = + histogram.Count() / (sum(result->client_stats(), SystemTime) + + sum(result->client_stats(), UserTime)); + + result->mutable_summary()->set_server_queries_per_cpu_sec( + server_queries_per_cpu_sec); + result->mutable_summary()->set_client_queries_per_cpu_sec( + client_queries_per_cpu_sec); +} + +struct ClientData { + unique_ptr stub; + unique_ptr> stream; +}; + +struct ServerData { + unique_ptr stub; + unique_ptr> stream; +}; + +static void FinishClients(const std::vector& clients, + const ClientArgs& client_mark) { + gpr_log(GPR_INFO, "Finishing clients"); + for (size_t i = 0, i_end = clients.size(); i < i_end; i++) { + auto client = &clients[i]; + if (!client->stream->Write(client_mark)) { + gpr_log(GPR_ERROR, "Couldn't write mark to client %zu", i); + GPR_ASSERT(false); + } + if (!client->stream->WritesDone()) { + gpr_log(GPR_ERROR, "Failed WritesDone for client %zu", i); + GPR_ASSERT(false); + } + } +} + +static void ReceiveFinalStatusFromClients( + const std::vector& clients, Histogram& merged_latencies, + std::unordered_map& merged_statuses, ScenarioResult& result) { + gpr_log(GPR_INFO, "Receiving final status from clients"); + ClientStatus client_status; + for (size_t i = 0, i_end = clients.size(); i < i_end; i++) { + auto client = &clients[i]; + // Read the client final status + if (client->stream->Read(&client_status)) { + gpr_log(GPR_INFO, "Received final status from client %zu", i); + const auto& stats = client_status.stats(); + merged_latencies.MergeProto(stats.latencies()); + for (int i = 0; i < stats.request_results_size(); i++) { + merged_statuses[stats.request_results(i).status_code()] += + stats.request_results(i).count(); + } + result.add_client_stats()->CopyFrom(stats); + // Check that the final status is the last message on the client + // stream. + // TODO(jtattermusch): note that waiting for Read to return can take + // long on some scenarios (e.g. unconstrained streaming_from_server). See + // https://github.com/grpc/grpc/blob/3bd0cd208ea549760a2daf595f79b91b247fe240/test/cpp/qps/server_async.cc#L176 + // where the shutdown delay pretty much determines the wait here.
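Note (illustrative sketch, not from the gRPC sources): ReceiveFinalStatusFromClients() above folds each client's RequestResultCount entries into merged_statuses keyed by status code, and postprocess_scenario_result() then treats code 0 as success and everything else as failure when computing requests per second. A small sketch of that merge and split; the counts and wall time are made up:

// (illustrative sketch, not part of this patch)
#include <cstdint>
#include <cstdio>
#include <unordered_map>

int main() {
  // Accumulates RequestResultCount entries the way merged_statuses does:
  // keyed by gRPC status code, summing counts across all clients.
  std::unordered_map<int32_t, int64_t> merged;
  merged[0] += 980;   // OK from client 1
  merged[0] += 1010;  // OK from client 2
  merged[4] += 10;    // DEADLINE_EXCEEDED from client 2

  // Status code 0 (OK) counts as a success; everything else is a failure.
  int64_t successes = 0, failures = 0;
  for (const auto& kv : merged) {
    (kv.first == 0 ? successes : failures) += kv.second;
  }

  const double wall_time = 10.0;  // average client wall time, seconds
  std::printf("ok/s=%.1f failed/s=%.1f\n", successes / wall_time,
              failures / wall_time);
}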
+ GPR_ASSERT(!client->stream->Read(&client_status)); + } else { + gpr_log(GPR_ERROR, "Couldn't get final status from client %zu", i); + GPR_ASSERT(false); + } + } +} + +static void ShutdownClients(const std::vector& clients, + ScenarioResult& result) { + gpr_log(GPR_INFO, "Shutdown clients"); + for (size_t i = 0, i_end = clients.size(); i < i_end; i++) { + auto client = &clients[i]; + Status s = client->stream->Finish(); + // Since we shutdown servers and clients at the same time, clients can + // observe cancellation. Thus, we consider both OK and CANCELLED as good + // status. + const bool success = IsSuccess(s); + result.add_client_success(success); + if (!success) { + gpr_log(GPR_ERROR, "Client %zu had an error %s", i, + s.error_message().c_str()); + GPR_ASSERT(false); + } + } +} + +static void FinishServers(const std::vector& servers, + const ServerArgs& server_mark) { + gpr_log(GPR_INFO, "Finishing servers"); + for (size_t i = 0, i_end = servers.size(); i < i_end; i++) { + auto server = &servers[i]; + if (!server->stream->Write(server_mark)) { + gpr_log(GPR_ERROR, "Couldn't write mark to server %zu", i); + GPR_ASSERT(false); + } + if (!server->stream->WritesDone()) { + gpr_log(GPR_ERROR, "Failed WritesDone for server %zu", i); + GPR_ASSERT(false); + } + } +} + +static void ReceiveFinalStatusFromServer(const std::vector& servers, + ScenarioResult& result) { + gpr_log(GPR_INFO, "Receiving final status from servers"); + ServerStatus server_status; + for (size_t i = 0, i_end = servers.size(); i < i_end; i++) { + auto server = &servers[i]; + // Read the server final status + if (server->stream->Read(&server_status)) { + gpr_log(GPR_INFO, "Received final status from server %zu", i); + result.add_server_stats()->CopyFrom(server_status.stats()); + result.add_server_cores(server_status.cores()); + // That final status should be the last message on the server stream + GPR_ASSERT(!server->stream->Read(&server_status)); + } else { + gpr_log(GPR_ERROR, "Couldn't get final status from server %zu", i); + GPR_ASSERT(false); + } + } +} + +static void ShutdownServers(const std::vector& servers, + ScenarioResult& result) { + gpr_log(GPR_INFO, "Shutdown servers"); + for (size_t i = 0, i_end = servers.size(); i < i_end; i++) { + auto server = &servers[i]; + Status s = server->stream->Finish(); + // Since we shutdown servers and clients at the same time, servers can + // observe cancellation. Thus, we consider both OK and CANCELLED as good + // status. 
+ const bool success = IsSuccess(s); + result.add_server_success(success); + if (!success) { + gpr_log(GPR_ERROR, "Server %zu had an error %s", i, + s.error_message().c_str()); + GPR_ASSERT(false); + } + } +} + +std::vector* g_inproc_servers = nullptr; + +std::unique_ptr RunScenario( + const ClientConfig& initial_client_config, size_t num_clients, + const ServerConfig& initial_server_config, size_t num_servers, + int warmup_seconds, int benchmark_seconds, int spawn_local_worker_count, + const std::string& qps_server_target_override, + const std::string& credential_type, + const std::map& per_worker_credential_types, + bool run_inproc, int32_t median_latency_collection_interval_millis) { + if (run_inproc) { + g_inproc_servers = new std::vector; + } + // Log everything from the driver + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); + + // ClientContext allocations (all are destroyed at scope exit) + list contexts; + auto alloc_context = [](list* contexts) { + contexts->emplace_back(); + auto context = &contexts->back(); + context->set_wait_for_ready(true); + return context; + }; + + // To be added to the result, containing the final configuration used for + // client and config (including host, etc.) + ClientConfig result_client_config; + + // Get client, server lists; ignore if inproc test + auto workers = (!run_inproc) ? get_workers("QPS_WORKERS") : deque(); + ClientConfig client_config = initial_client_config; + + // Spawn some local workers if desired + vector> local_workers; + for (int i = 0; i < abs(spawn_local_worker_count); i++) { + // act as if we're a new test -- gets a good rng seed + static bool called_init = false; + if (!called_init) { + char args_buf[100]; + strcpy(args_buf, "some-benchmark"); + char* args[] = {args_buf}; + grpc_test_init(1, args); + called_init = true; + } + + char addr[256]; + // we use port # of -1 to indicate inproc + int driver_port = (!run_inproc) ? 
grpc_pick_unused_port_or_die() : -1; + local_workers.emplace_back(new QpsWorker(driver_port, 0, credential_type)); + sprintf(addr, "localhost:%d", driver_port); + if (spawn_local_worker_count < 0) { + workers.push_front(addr); + } else { + workers.push_back(addr); + } + } + GPR_ASSERT(!workers.empty()); + + // if num_clients is set to <=0, do dynamic sizing: all workers + // except for servers are clients + if (num_clients <= 0) { + num_clients = workers.size() - num_servers; + } + + // TODO(ctiller): support running multiple configurations, and binpack + // client/server pairs + // to available workers + GPR_ASSERT(workers.size() >= num_clients + num_servers); + + // Trim to just what we need + workers.resize(num_clients + num_servers); + + // Start servers + std::vector servers(num_servers); + std::unordered_map> hosts_cores; + ChannelArguments channel_args; + + for (size_t i = 0; i < num_servers; i++) { + gpr_log(GPR_INFO, "Starting server on %s (worker #%" PRIuPTR ")", + workers[i].c_str(), i); + if (!run_inproc) { + servers[i].stub = WorkerService::NewStub(grpc::CreateTestChannel( + workers[i], + GetCredType(workers[i], per_worker_credential_types, credential_type), + nullptr /* call creds */, {} /* interceptor creators */)); + } else { + servers[i].stub = WorkerService::NewStub( + local_workers[i]->InProcessChannel(channel_args)); + } + + const ServerConfig& server_config = initial_server_config; + if (server_config.core_limit() != 0) { + gpr_log(GPR_ERROR, + "server config core limit is set but ignored by driver"); + GPR_ASSERT(false); + } + + ServerArgs args; + *args.mutable_setup() = server_config; + servers[i].stream = servers[i].stub->RunServer(alloc_context(&contexts)); + if (!servers[i].stream->Write(args)) { + gpr_log(GPR_ERROR, "Could not write args to server %zu", i); + GPR_ASSERT(false); + } + ServerStatus init_status; + if (!servers[i].stream->Read(&init_status)) { + gpr_log(GPR_ERROR, "Server %zu did not yield initial status", i); + GPR_ASSERT(false); + } + if (run_inproc) { + std::string cli_target(INPROC_NAME_PREFIX); + cli_target += std::to_string(i); + client_config.add_server_targets(cli_target); + } else { + std::string host = get_host(workers[i]); + std::string cli_target = + grpc_core::JoinHostPort(host.c_str(), init_status.port()); + client_config.add_server_targets(cli_target.c_str()); + } + } + if (qps_server_target_override.length() > 0) { + // overriding the qps server target only makes since if there is <= 1 + // servers + GPR_ASSERT(num_servers <= 1); + client_config.add_server_targets(qps_server_target_override); + } + client_config.set_median_latency_collection_interval_millis( + median_latency_collection_interval_millis); + + // Targets are all set by now + result_client_config = client_config; + // Start clients + std::vector clients(num_clients); + size_t channels_allocated = 0; + for (size_t i = 0; i < num_clients; i++) { + const auto& worker = workers[i + num_servers]; + gpr_log(GPR_INFO, "Starting client on %s (worker #%" PRIuPTR ")", + worker.c_str(), i + num_servers); + if (!run_inproc) { + clients[i].stub = WorkerService::NewStub(grpc::CreateTestChannel( + worker, + GetCredType(worker, per_worker_credential_types, credential_type), + nullptr /* call creds */, {} /* interceptor creators */)); + } else { + clients[i].stub = WorkerService::NewStub( + local_workers[i + num_servers]->InProcessChannel(channel_args)); + } + ClientConfig per_client_config = client_config; + + if (initial_client_config.core_limit() != 0) { + gpr_log(GPR_ERROR, "client 
config core limit set but ignored"); + GPR_ASSERT(false); + } + + // Reduce channel count so that total channels specified is held regardless + // of the number of clients available + size_t num_channels = + (client_config.client_channels() - channels_allocated) / + (num_clients - i); + channels_allocated += num_channels; + gpr_log(GPR_DEBUG, "Client %" PRIdPTR " gets %" PRIdPTR " channels", i, + num_channels); + per_client_config.set_client_channels(num_channels); + + ClientArgs args; + *args.mutable_setup() = per_client_config; + clients[i].stream = clients[i].stub->RunClient(alloc_context(&contexts)); + if (!clients[i].stream->Write(args)) { + gpr_log(GPR_ERROR, "Could not write args to client %zu", i); + GPR_ASSERT(false); + } + } + + for (size_t i = 0; i < num_clients; i++) { + ClientStatus init_status; + if (!clients[i].stream->Read(&init_status)) { + gpr_log(GPR_ERROR, "Client %zu did not yield initial status", i); + GPR_ASSERT(false); + } + } + + // Send an initial mark: clients can use this to know that everything is ready + // to start + gpr_log(GPR_INFO, "Initiating"); + ServerArgs server_mark; + server_mark.mutable_mark()->set_reset(true); + ClientArgs client_mark; + client_mark.mutable_mark()->set_reset(true); + ServerStatus server_status; + ClientStatus client_status; + for (size_t i = 0; i < num_clients; i++) { + auto client = &clients[i]; + if (!client->stream->Write(client_mark)) { + gpr_log(GPR_ERROR, "Couldn't write mark to client %zu", i); + GPR_ASSERT(false); + } + } + for (size_t i = 0; i < num_clients; i++) { + auto client = &clients[i]; + if (!client->stream->Read(&client_status)) { + gpr_log(GPR_ERROR, "Couldn't get status from client %zu", i); + GPR_ASSERT(false); + } + } + + // Let everything warmup + gpr_log(GPR_INFO, "Warming up"); + gpr_timespec start = gpr_now(GPR_CLOCK_REALTIME); + gpr_sleep_until( + gpr_time_add(start, gpr_time_from_seconds(warmup_seconds, GPR_TIMESPAN))); + + // Start a run + gpr_log(GPR_INFO, "Starting"); + for (size_t i = 0; i < num_servers; i++) { + auto server = &servers[i]; + if (!server->stream->Write(server_mark)) { + gpr_log(GPR_ERROR, "Couldn't write mark to server %zu", i); + GPR_ASSERT(false); + } + } + for (size_t i = 0; i < num_clients; i++) { + auto client = &clients[i]; + if (!client->stream->Write(client_mark)) { + gpr_log(GPR_ERROR, "Couldn't write mark to client %zu", i); + GPR_ASSERT(false); + } + } + for (size_t i = 0; i < num_servers; i++) { + auto server = &servers[i]; + if (!server->stream->Read(&server_status)) { + gpr_log(GPR_ERROR, "Couldn't get status from server %zu", i); + GPR_ASSERT(false); + } + } + for (size_t i = 0; i < num_clients; i++) { + auto client = &clients[i]; + if (!client->stream->Read(&client_status)) { + gpr_log(GPR_ERROR, "Couldn't get status from client %zu", i); + GPR_ASSERT(false); + } + } + + // Wait some time + gpr_log(GPR_INFO, "Running"); + // Use gpr_sleep_until rather than this_thread::sleep_until to support + // compilers that don't work with this_thread + gpr_sleep_until(gpr_time_add( + start, + gpr_time_from_seconds(warmup_seconds + benchmark_seconds, GPR_TIMESPAN))); + + gpr_timer_set_enabled(0); + + // Finish a run + std::unique_ptr result(new ScenarioResult); + Histogram merged_latencies; + std::unordered_map merged_statuses; + + // For the case where clients lead the test such as UNARY and + // STREAMING_FROM_CLIENT, clients need to finish completely while a server + // is running to prevent the clients from being stuck while waiting for + // the result. 
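Note (illustrative sketch, not from the gRPC sources): the client start-up loop above hands each client its share of the channels that are still unallocated, (client_channels - channels_allocated) / (num_clients - i), so the configured total is preserved even when it does not divide evenly. A sketch of that allocation rule in isolation:

// (illustrative sketch, not part of this patch)
#include <cstdio>
#include <vector>

// Reproduces the channel-allocation rule used when starting clients above:
// each client gets its share of the channels that are still unallocated,
// so the total adds up exactly even when it does not divide evenly.
static std::vector<size_t> DistributeChannels(size_t total, size_t clients) {
  std::vector<size_t> per_client;
  size_t allocated = 0;
  for (size_t i = 0; i < clients; i++) {
    size_t n = (total - allocated) / (clients - i);
    allocated += n;
    per_client.push_back(n);
  }
  return per_client;
}

int main() {
  // 10 channels over 4 clients -> 2, 2, 3, 3 (sums to 10).
  for (size_t n : DistributeChannels(10, 4)) std::printf("%zu ", n);
  std::printf("\n");
}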
+ bool client_finish_first = + (client_config.rpc_type() != STREAMING_FROM_SERVER); + + FinishClients(clients, client_mark); + + if (!client_finish_first) { + FinishServers(servers, server_mark); + } + + ReceiveFinalStatusFromClients(clients, merged_latencies, merged_statuses, + *result); + ShutdownClients(clients, *result); + + if (client_finish_first) { + FinishServers(servers, server_mark); + } + + ReceiveFinalStatusFromServer(servers, *result); + ShutdownServers(servers, *result); + + delete g_inproc_servers; + + merged_latencies.FillProto(result->mutable_latencies()); + for (std::unordered_map::iterator it = merged_statuses.begin(); + it != merged_statuses.end(); ++it) { + RequestResultCount* rrc = result->add_request_results(); + rrc->set_status_code(it->first); + rrc->set_count(it->second); + } + postprocess_scenario_result(result.get()); + return result; +} + +bool RunQuit( + const std::string& credential_type, + const std::map& per_worker_credential_types) { + // Get client, server lists + bool result = true; + auto workers = get_workers("QPS_WORKERS"); + if (workers.empty()) { + return false; + } + + for (size_t i = 0; i < workers.size(); i++) { + auto stub = WorkerService::NewStub(grpc::CreateTestChannel( + workers[i], + GetCredType(workers[i], per_worker_credential_types, credential_type), + nullptr /* call creds */, {} /* interceptor creators */)); + Void phony; + grpc::ClientContext ctx; + ctx.set_wait_for_ready(true); + Status s = stub->QuitWorker(&ctx, phony, &phony); + if (!s.ok()) { + gpr_log(GPR_ERROR, "Worker %zu could not be properly quit because %s", i, + s.error_message().c_str()); + result = false; + } + } + return result; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/inproc_sync_unary_ping_pong_test.cc b/test/cpp/qps/inproc_sync_unary_ping_pong_test.cc new file mode 100644 index 00000000..1a3682e3 --- /dev/null +++ b/test/cpp/qps/inproc_sync_unary_ping_pong_test.cc @@ -0,0 +1,68 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/qps/benchmark_config.h" +#include "test/cpp/qps/driver.h" +#include "test/cpp/qps/report.h" +#include "test/cpp/qps/server.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { + +static const int WARMUP = 1; +static const int BENCHMARK = 3; + +static void RunSynchronousUnaryPingPong() { + gpr_log(GPR_INFO, "Running Synchronous Unary Ping Pong"); + + ClientConfig client_config; + client_config.set_client_type(SYNC_CLIENT); + client_config.set_outstanding_rpcs_per_channel(1); + client_config.set_client_channels(1); + client_config.set_rpc_type(UNARY); + client_config.mutable_load_params()->mutable_closed_loop(); + + ServerConfig server_config; + server_config.set_server_type(SYNC_SERVER); + + const auto result = + RunScenario(client_config, 1, server_config, 1, WARMUP, BENCHMARK, -2, "", + kInsecureCredentialsType, {}, true, 0); + + GetReporter()->ReportQPS(*result); + GetReporter()->ReportLatency(*result); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + + grpc::testing::RunSynchronousUnaryPingPong(); + + return 0; +} diff --git a/test/cpp/qps/json_run_localhost.cc b/test/cpp/qps/json_run_localhost.cc new file mode 100644 index 00000000..b515efeb --- /dev/null +++ b/test/cpp/qps/json_run_localhost.cc @@ -0,0 +1,137 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include +#include +#include +#include + +#ifdef __FreeBSD__ +#include +#endif + +#include + +#include "src/core/lib/gpr/env.h" +#include "test/core/util/port.h" +#include "test/cpp/util/subprocess.h" + +using grpc::SubProcess; + +constexpr auto kNumWorkers = 2; + +static SubProcess* g_driver; +static SubProcess* g_workers[kNumWorkers]; + +template +std::string as_string(const T& val) { + std::ostringstream out; + out << val; + return out.str(); +} + +static void sighandler(int /*sig*/) { + const int errno_saved = errno; + if (g_driver != nullptr) g_driver->Interrupt(); + for (int i = 0; i < kNumWorkers; ++i) { + if (g_workers[i]) g_workers[i]->Interrupt(); + } + errno = errno_saved; +} + +static void register_sighandler() { + struct sigaction act; + memset(&act, 0, sizeof(act)); + act.sa_handler = sighandler; + + sigaction(SIGINT, &act, nullptr); + sigaction(SIGTERM, &act, nullptr); +} + +static void LogStatus(int status, const char* label) { + if (WIFEXITED(status)) { + gpr_log(GPR_INFO, "%s: subprocess exited with status %d", label, + WEXITSTATUS(status)); + } else if (WIFSIGNALED(status)) { + gpr_log(GPR_INFO, "%s: subprocess terminated with signal %d", label, + WTERMSIG(status)); + } else { + gpr_log(GPR_INFO, "%s: unknown subprocess status: %d", label, status); + } +} + +int main(int argc, char** argv) { + register_sighandler(); + + std::string my_bin = argv[0]; + std::string bin_dir = my_bin.substr(0, my_bin.rfind('/')); + + std::ostringstream env; + bool first = true; + + for (int i = 0; i < kNumWorkers; i++) { + const auto driver_port = grpc_pick_unused_port_or_die(); + // ServerPort can be used or not later depending on the type of worker + // but we like to issue all ports required here to avoid port conflict. + const auto server_port = grpc_pick_unused_port_or_die(); + std::vector args = {bin_dir + "/qps_worker", "-driver_port", + as_string(driver_port), "-server_port", + as_string(server_port)}; + g_workers[i] = new SubProcess(args); + if (!first) env << ","; + env << "localhost:" << driver_port; + first = false; + } + + gpr_setenv("QPS_WORKERS", env.str().c_str()); + std::vector args = {bin_dir + "/qps_json_driver"}; + for (int i = 1; i < argc; i++) { + args.push_back(argv[i]); + } + + g_driver = new SubProcess(args); + const int driver_join_status = g_driver->Join(); + if (driver_join_status != 0) { + LogStatus(driver_join_status, "driver"); + } + for (int i = 0; i < kNumWorkers; ++i) { + if (g_workers[i]) g_workers[i]->Interrupt(); + } + + for (int i = 0; i < kNumWorkers; ++i) { + if (g_workers[i]) { + const int worker_status = g_workers[i]->Join(); + if (worker_status != 0) { + LogStatus(worker_status, "worker"); + } + } + } + + delete g_driver; + + g_driver = nullptr; + for (int i = 0; i < kNumWorkers; ++i) { + if (g_workers[i] != nullptr) { + delete g_workers[i]; + } + } + GPR_ASSERT(driver_join_status == 0); +} diff --git a/test/cpp/qps/parse_json.cc b/test/cpp/qps/parse_json.cc new file mode 100644 index 00000000..f3ac1052 --- /dev/null +++ b/test/cpp/qps/parse_json.cc @@ -0,0 +1,61 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/qps/parse_json.h" + +#include + +#include + +namespace grpc { +namespace testing { + +void ParseJson(const std::string& json, const std::string& type, + GRPC_CUSTOM_MESSAGE* msg) { + std::unique_ptr type_resolver( + protobuf::json::NewTypeResolverForDescriptorPool( + "type.googleapis.com", protobuf::DescriptorPool::generated_pool())); + std::string binary; + auto status = JsonToBinaryString( + type_resolver.get(), "type.googleapis.com/" + type, json, &binary); + if (!status.ok()) { + std::string errmsg(status.message()); + gpr_log(GPR_ERROR, "Failed to convert json to binary: errcode=%d msg=%s", + status.code(), errmsg.c_str()); + gpr_log(GPR_ERROR, "JSON: %s", json.c_str()); + abort(); + } + GPR_ASSERT(msg->ParseFromString(binary)); +} + +std::string SerializeJson(const GRPC_CUSTOM_MESSAGE& msg, + const std::string& type) { + std::unique_ptr type_resolver( + protobuf::json::NewTypeResolverForDescriptorPool( + "type.googleapis.com", protobuf::DescriptorPool::generated_pool())); + std::string binary; + std::string json_string; + msg.SerializeToString(&binary); + auto status = + BinaryToJsonString(type_resolver.get(), type, binary, &json_string); + GPR_ASSERT(status.ok()); + return json_string; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/qps_interarrival_test.cc b/test/cpp/qps/qps_interarrival_test.cc new file mode 100644 index 00000000..768c897d --- /dev/null +++ b/test/cpp/qps/qps_interarrival_test.cc @@ -0,0 +1,60 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
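Note (illustrative sketch, not from the gRPC sources): the interarrival test that follows drives an ExpDist-based InterarrivalTimer and histograms the resulting gaps. The underlying idea is plain inverse-CDF sampling of an exponential distribution; the sketch below shows only that math with <random>, not gRPC's ExpDist implementation:

// (illustrative sketch, not part of this patch)
#include <cmath>
#include <cstdio>
#include <random>

// Generic inverse-CDF sampling of exponential interarrival gaps, i.e. the
// kind of open-loop Poisson pacing that the interarrival test below measures.
// With U uniform in (0,1), t = -ln(U) / lambda is Exp(lambda)-distributed.
int main() {
  const double lambda = 10.0;  // target average rate: 10 arrivals per second
  std::mt19937_64 rng(42);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);

  double total = 0;
  const int kSamples = 100000;
  for (int i = 0; i < kSamples; i++) {
    double u = uniform(rng);
    if (u <= 0.0) u = 1e-12;         // guard against log(0)
    total += -std::log(u) / lambda;  // seconds until the next arrival
  }
  // The mean gap should come out near 1/lambda = 0.1s.
  std::printf("mean interarrival: %.4f s\n", total / kSamples);
}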
+ * + */ + +#include +#include + +// Use the C histogram rather than C++ to avoid depending on proto +#include "test/core/util/histogram.h" +#include "test/core/util/test_config.h" +#include "test/cpp/qps/interarrival.h" +#include "test/cpp/util/test_config.h" + +using grpc::testing::InterarrivalTimer; +using grpc::testing::RandomDistInterface; + +static void RunTest(RandomDistInterface&& r, int threads, + const std::string& title) { + InterarrivalTimer timer; + timer.init(r, threads); + grpc_histogram* h(grpc_histogram_create(0.01, 60e9)); + + for (int i = 0; i < 10000000; i++) { + for (int j = 0; j < threads; j++) { + grpc_histogram_add(h, timer.next(j)); + } + } + + std::cout << title << " Distribution" << std::endl; + std::cout << "Value, Percentile" << std::endl; + for (double pct = 0.0; pct < 100.0; pct += 1.0) { + std::cout << grpc_histogram_percentile(h, pct) << "," << pct << std::endl; + } + + grpc_histogram_destroy(h); +} + +using grpc::testing::ExpDist; + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + + RunTest(ExpDist(10.0), 5, std::string("Exponential(10)")); + return 0; +} diff --git a/test/cpp/qps/qps_json_driver.cc b/test/cpp/qps/qps_json_driver.cc new file mode 100644 index 00000000..2e8cb42c --- /dev/null +++ b/test/cpp/qps/qps_json_driver.cc @@ -0,0 +1,305 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/qps/benchmark_config.h" +#include "test/cpp/qps/driver.h" +#include "test/cpp/qps/parse_json.h" +#include "test/cpp/qps/report.h" +#include "test/cpp/qps/server.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_FLAG(std::string, scenarios_file, "", + "JSON file containing an array of Scenario objects"); +ABSL_FLAG(std::string, scenarios_json, "", + "JSON string containing an array of Scenario objects"); +ABSL_FLAG(bool, quit, false, "Quit the workers"); +ABSL_FLAG(std::string, search_param, "", + "The parameter, whose value is to be searched for to achieve " + "targeted cpu load. For now, we have 'offered_load'. Later, " + "'num_channels', 'num_outstanding_requests', etc. shall be " + "added."); +ABSL_FLAG( + double, initial_search_value, 0.0, + "initial parameter value to start the search with (i.e. lower bound)"); +ABSL_FLAG(double, targeted_cpu_load, 70.0, + "Targeted cpu load (unit: %, range [0,100])"); +ABSL_FLAG(double, stride, 1, + "Defines each stride of the search. The larger the stride is, " + "the coarser the result will be, but will also be faster."); +ABSL_FLAG(double, error_tolerance, 0.01, + "Defines threshold for stopping the search. 
When current search " + "range is narrower than the error_tolerance computed range, we " + "stop the search."); + +ABSL_FLAG(std::string, qps_server_target_override, "", + "Override QPS server target to configure in client configs. " + "Only applicable if there is a single benchmark server."); + +ABSL_FLAG(std::string, json_file_out, "", "File to write the JSON output to."); + +ABSL_FLAG(std::string, credential_type, grpc::testing::kInsecureCredentialsType, + "Credential type for communication with workers"); +ABSL_FLAG( + std::string, per_worker_credential_types, "", + "A map of QPS worker addresses to credential types. When creating a " + "channel to a QPS worker's driver port, the qps_json_driver first checks " + "if the 'name:port' string is in the map, and it uses the corresponding " + "credential type if so. If the QPS worker's 'name:port' string is not " + "in the map, then the driver -> worker channel will be created with " + "the credentials specified in --credential_type. The value of this flag " + "is a semicolon-separated list of map entries, where each map entry is " + "a comma-separated pair."); +ABSL_FLAG(bool, run_inproc, false, "Perform an in-process transport test"); +ABSL_FLAG( + int32_t, median_latency_collection_interval_millis, 0, + "Specifies the period between gathering latency medians in " + "milliseconds. The medians will be logged out on the client at the " + "end of the benchmark run. If 0, this periodic collection is disabled."); + +namespace grpc { +namespace testing { + +static std::map +ConstructPerWorkerCredentialTypesMap() { + // Parse a list of the form: "addr1,cred_type1;addr2,cred_type2;..." into + // a map. + std::string remaining = absl::GetFlag(FLAGS_per_worker_credential_types); + std::map out; + while (!remaining.empty()) { + size_t next_semicolon = remaining.find(';'); + std::string next_entry = remaining.substr(0, next_semicolon); + if (next_semicolon == std::string::npos) { + remaining = ""; + } else { + remaining = remaining.substr(next_semicolon + 1, std::string::npos); + } + size_t comma = next_entry.find(','); + if (comma == std::string::npos) { + gpr_log(GPR_ERROR, + "Expected --per_worker_credential_types to be a list " + "of the form: 'addr1,cred_type1;addr2,cred_type2;...' " + "pairs."); + abort(); + } + std::string addr = next_entry.substr(0, comma); + std::string cred_type = next_entry.substr(comma + 1, std::string::npos); + if (out.find(addr) != out.end()) { + gpr_log(GPR_ERROR, + "Found duplicate addr in per_worker_credential_types."); + abort(); + } + out[addr] = cred_type; + } + return out; +} + +static std::unique_ptr RunAndReport( + const Scenario& scenario, + const std::map& per_worker_credential_types, + bool* success) { + std::cerr << "RUNNING SCENARIO: " << scenario.name() << "\n"; + auto result = RunScenario( + scenario.client_config(), scenario.num_clients(), + scenario.server_config(), scenario.num_servers(), + scenario.warmup_seconds(), scenario.benchmark_seconds(), + !absl::GetFlag(FLAGS_run_inproc) ? scenario.spawn_local_worker_count() + : -2, + absl::GetFlag(FLAGS_qps_server_target_override), + absl::GetFlag(FLAGS_credential_type), per_worker_credential_types, + absl::GetFlag(FLAGS_run_inproc), + absl::GetFlag(FLAGS_median_latency_collection_interval_millis)); + + // Amend the result with scenario config. Eventually we should adjust + // RunScenario contract so we don't need to touch the result here.
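Note (illustrative sketch, not from the gRPC sources): ConstructPerWorkerCredentialTypesMap() above walks the "addr1,cred_type1;addr2,cred_type2;..." flag with find/substr. An equivalent split written with std::getline, for comparison; the worker addresses and credential names used in main are placeholders:

// (illustrative sketch, not part of this patch)
#include <cstdio>
#include <map>
#include <sstream>
#include <string>

// Parses the same "addr1,cred_type1;addr2,cred_type2;..." format that
// --per_worker_credential_types uses, written with std::getline instead of
// the manual find/substr loop above.
static std::map<std::string, std::string> ParsePerWorkerCreds(
    const std::string& flag) {
  std::map<std::string, std::string> out;
  std::stringstream entries(flag);
  std::string entry;
  while (std::getline(entries, entry, ';')) {
    size_t comma = entry.find(',');
    if (comma == std::string::npos) continue;  // the real code aborts here
    out[entry.substr(0, comma)] = entry.substr(comma + 1);
  }
  return out;
}

int main() {
  // Placeholder addresses and credential-type names, for illustration only.
  auto m = ParsePerWorkerCreds("worker1:10000,ssl;worker2:10000,insecure");
  for (const auto& kv : m) {
    std::printf("%s -> %s\n", kv.first.c_str(), kv.second.c_str());
  }
}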
+ result->mutable_scenario()->CopyFrom(scenario); + + GetReporter()->ReportQPS(*result); + GetReporter()->ReportQPSPerCore(*result); + GetReporter()->ReportLatency(*result); + GetReporter()->ReportTimes(*result); + GetReporter()->ReportCpuUsage(*result); + GetReporter()->ReportPollCount(*result); + GetReporter()->ReportQueriesPerCpuSec(*result); + + for (int i = 0; *success && i < result->client_success_size(); i++) { + *success = result->client_success(i); + } + for (int i = 0; *success && i < result->server_success_size(); i++) { + *success = result->server_success(i); + } + + if (!absl::GetFlag(FLAGS_json_file_out).empty()) { + std::ofstream json_outfile; + json_outfile.open(absl::GetFlag(FLAGS_json_file_out)); + json_outfile << "{\"qps\": " << result->summary().qps() << "}\n"; + json_outfile.close(); + } + + return result; +} + +static double GetCpuLoad( + Scenario* scenario, double offered_load, + const std::map& per_worker_credential_types, + bool* success) { + scenario->mutable_client_config() + ->mutable_load_params() + ->mutable_poisson() + ->set_offered_load(offered_load); + auto result = RunAndReport(*scenario, per_worker_credential_types, success); + return result->summary().server_cpu_usage(); +} + +static double BinarySearch( + Scenario* scenario, double targeted_cpu_load, double low, double high, + const std::map& per_worker_credential_types, + bool* success) { + while (low <= high * (1 - absl::GetFlag(FLAGS_error_tolerance))) { + double mid = low + (high - low) / 2; + double current_cpu_load = + GetCpuLoad(scenario, mid, per_worker_credential_types, success); + gpr_log(GPR_DEBUG, "Binary Search: current_offered_load %.0f", mid); + if (!*success) { + gpr_log(GPR_ERROR, "Client/Server Failure"); + break; + } + if (targeted_cpu_load <= current_cpu_load) { + high = mid - absl::GetFlag(FLAGS_stride); + } else { + low = mid + absl::GetFlag(FLAGS_stride); + } + } + + return low; +} + +static double SearchOfferedLoad( + double initial_offered_load, double targeted_cpu_load, Scenario* scenario, + const std::map& per_worker_credential_types, + bool* success) { + std::cerr << "RUNNING SCENARIO: " << scenario->name() << "\n"; + double current_offered_load = initial_offered_load; + double current_cpu_load = GetCpuLoad(scenario, current_offered_load, + per_worker_credential_types, success); + if (current_cpu_load > targeted_cpu_load) { + gpr_log(GPR_ERROR, "Initial offered load too high"); + return -1; + } + + while (*success && (current_cpu_load < targeted_cpu_load)) { + current_offered_load *= 2; + current_cpu_load = GetCpuLoad(scenario, current_offered_load, + per_worker_credential_types, success); + gpr_log(GPR_DEBUG, "Binary Search: current_offered_load %.0f", + current_offered_load); + } + + double targeted_offered_load = + BinarySearch(scenario, targeted_cpu_load, current_offered_load / 2, + current_offered_load, per_worker_credential_types, success); + + return targeted_offered_load; +} + +static bool QpsDriver() { + std::string json; + + bool scfile = (!absl::GetFlag(FLAGS_scenarios_file).empty()); + bool scjson = (!absl::GetFlag(FLAGS_scenarios_json).empty()); + if ((!scfile && !scjson && !absl::GetFlag(FLAGS_quit)) || + (scfile && (scjson || absl::GetFlag(FLAGS_quit))) || + (scjson && absl::GetFlag(FLAGS_quit))) { + gpr_log(GPR_ERROR, + "Exactly one of --scenarios_file, --scenarios_json, " + "or --quit must be set"); + abort(); + } + + auto per_worker_credential_types = ConstructPerWorkerCredentialTypesMap(); + if (scfile) { + // Read the json data from disk + FILE* json_file 
= fopen(absl::GetFlag(FLAGS_scenarios_file).c_str(), "r"); + GPR_ASSERT(json_file != nullptr); + fseek(json_file, 0, SEEK_END); + long len = ftell(json_file); + char* data = new char[len]; + fseek(json_file, 0, SEEK_SET); + GPR_ASSERT(len == (long)fread(data, 1, len, json_file)); + fclose(json_file); + json = std::string(data, data + len); + delete[] data; + } else if (scjson) { + json = absl::GetFlag(FLAGS_scenarios_json).c_str(); + } else if (absl::GetFlag(FLAGS_quit)) { + return RunQuit(absl::GetFlag(FLAGS_credential_type), + per_worker_credential_types); + } + + // Parse into an array of scenarios + Scenarios scenarios; + ParseJson(json.c_str(), "grpc.testing.Scenarios", &scenarios); + bool success = true; + + // Make sure that there is at least some valid scenario here + GPR_ASSERT(scenarios.scenarios_size() > 0); + + for (int i = 0; i < scenarios.scenarios_size(); i++) { + if (absl::GetFlag(FLAGS_search_param).empty()) { + const Scenario& scenario = scenarios.scenarios(i); + RunAndReport(scenario, per_worker_credential_types, &success); + } else { + if (absl::GetFlag(FLAGS_search_param) == "offered_load") { + Scenario* scenario = scenarios.mutable_scenarios(i); + double targeted_offered_load = + SearchOfferedLoad(absl::GetFlag(FLAGS_initial_search_value), + absl::GetFlag(FLAGS_targeted_cpu_load), scenario, + per_worker_credential_types, &success); + gpr_log(GPR_INFO, "targeted_offered_load %f", targeted_offered_load); + GetCpuLoad(scenario, targeted_offered_load, per_worker_credential_types, + &success); + } else { + gpr_log(GPR_ERROR, "Unimplemented search param"); + } + } + } + return success; +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + + bool ok = grpc::testing::QpsDriver(); + + return ok ? 0 : 1; +} diff --git a/test/cpp/qps/qps_openloop_test.cc b/test/cpp/qps/qps_openloop_test.cc new file mode 100644 index 00000000..38e70a84 --- /dev/null +++ b/test/cpp/qps/qps_openloop_test.cc @@ -0,0 +1,71 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
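Note (illustrative sketch, not from the gRPC sources): SearchOfferedLoad() and BinarySearch() in qps_json_driver.cc above first double the offered load until the measured server CPU exceeds the target, then binary-search inside the last doubling interval, nudging the bounds by --stride and stopping once the range is within --error_tolerance. The sketch below reproduces that control flow against a toy, monotone CPU model instead of a real scenario run:

// (illustrative sketch, not part of this patch)
#include <cstdio>

// Toy stand-in for GetCpuLoad(): pretend server CPU saturates at 100% as the
// offered load approaches 5000 QPS. The real driver runs a full scenario here.
static double FakeCpuLoad(double offered_load) {
  double load = offered_load / 50.0;  // 1% CPU per 50 QPS
  return load > 100.0 ? 100.0 : load;
}

// Mirrors the search above: double the offered load until the target CPU load
// is exceeded, then binary-search inside the last doubling interval, shrinking
// the bounds by `stride` on each probe until they are within `tolerance`.
static double SearchLoadForTargetCpu(double initial, double target_cpu,
                                     double stride, double tolerance) {
  double offered = initial;
  while (FakeCpuLoad(offered) < target_cpu) offered *= 2;

  double low = offered / 2, high = offered;
  while (low <= high * (1 - tolerance)) {
    double mid = low + (high - low) / 2;
    if (FakeCpuLoad(mid) >= target_cpu) {
      high = mid - stride;
    } else {
      low = mid + stride;
    }
  }
  return low;
}

int main() {
  // Expect roughly 3500 QPS for a 70% CPU target with the toy model above.
  std::printf("offered load: %.0f QPS\n",
              SearchLoadForTargetCpu(100.0, 70.0, 1.0, 0.01));
}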
+ * + */ + +#include + +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/qps/benchmark_config.h" +#include "test/cpp/qps/driver.h" +#include "test/cpp/qps/report.h" +#include "test/cpp/qps/server.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { + +static const int WARMUP = 1; +static const int BENCHMARK = 3; + +static void RunQPS() { + gpr_log(GPR_INFO, "Running QPS test, open-loop"); + + ClientConfig client_config; + client_config.set_client_type(ASYNC_CLIENT); + client_config.set_outstanding_rpcs_per_channel(100); + client_config.set_client_channels(8); + client_config.set_async_client_threads(8); + client_config.set_rpc_type(STREAMING); + client_config.mutable_load_params()->mutable_poisson()->set_offered_load( + 1000.0 / grpc_test_slowdown_factor()); + + ServerConfig server_config; + server_config.set_server_type(ASYNC_SERVER); + server_config.set_async_server_threads(8); + + const auto result = + RunScenario(client_config, 1, server_config, 1, WARMUP, BENCHMARK, -2, "", + kInsecureCredentialsType, {}, false, 0); + + GetReporter()->ReportQPSPerCore(*result); + GetReporter()->ReportLatency(*result); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + + grpc::testing::RunQPS(); + + return 0; +} diff --git a/test/cpp/qps/qps_server_builder.cc b/test/cpp/qps/qps_server_builder.cc new file mode 100644 index 00000000..87cb452a --- /dev/null +++ b/test/cpp/qps/qps_server_builder.cc @@ -0,0 +1,47 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "qps_server_builder.h" + +#include "absl/memory/memory.h" + +using grpc::ServerBuilder; + +namespace grpc { +namespace testing { + +namespace { +std::unique_ptr DefaultCreateQpsServerBuilder() { + return absl::make_unique(); +} + +std::function()> g_create_qps_server_builder = + DefaultCreateQpsServerBuilder; +} // namespace + +std::unique_ptr CreateQpsServerBuilder() { + return g_create_qps_server_builder(); +} + +void SetCreateQpsServerBuilderFunc( + std::function()> create_qps_server_builder) { + g_create_qps_server_builder = std::move(create_qps_server_builder); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/qps_worker.cc b/test/cpp/qps/qps_worker.cc new file mode 100644 index 00000000..cd916bfa --- /dev/null +++ b/test/cpp/qps/qps_worker.cc @@ -0,0 +1,313 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/qps/qps_worker.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/proto/grpc/testing/worker_service.grpc.pb.h" +#include "test/core/util/grpc_profiler.h" +#include "test/core/util/histogram.h" +#include "test/cpp/qps/client.h" +#include "test/cpp/qps/qps_server_builder.h" +#include "test/cpp/qps/server.h" +#include "test/cpp/util/create_test_channel.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { + +static std::unique_ptr CreateClient(const ClientConfig& config) { + gpr_log(GPR_INFO, "Starting client of type %s %s %d", + ClientType_Name(config.client_type()).c_str(), + RpcType_Name(config.rpc_type()).c_str(), + config.payload_config().has_bytebuf_params()); + + switch (config.client_type()) { + case ClientType::SYNC_CLIENT: + return CreateSynchronousClient(config); + case ClientType::ASYNC_CLIENT: + return config.payload_config().has_bytebuf_params() + ? CreateGenericAsyncStreamingClient(config) + : CreateAsyncClient(config); + case ClientType::CALLBACK_CLIENT: + return CreateCallbackClient(config); + default: + abort(); + } +} + +static std::unique_ptr CreateServer(const ServerConfig& config) { + gpr_log(GPR_INFO, "Starting server of type %s", + ServerType_Name(config.server_type()).c_str()); + + switch (config.server_type()) { + case ServerType::SYNC_SERVER: + return CreateSynchronousServer(config); + case ServerType::ASYNC_SERVER: + return CreateAsyncServer(config); + case ServerType::ASYNC_GENERIC_SERVER: + return CreateAsyncGenericServer(config); + case ServerType::CALLBACK_SERVER: + return CreateCallbackServer(config); + default: + abort(); + } +} + +class ScopedProfile final { + public: + ScopedProfile(const char* filename, bool enable) : enable_(enable) { + if (enable_) grpc_profiler_start(filename); + } + ~ScopedProfile() { + if (enable_) grpc_profiler_stop(); + } + + private: + const bool enable_; +}; + +class WorkerServiceImpl final : public WorkerService::Service { + public: + WorkerServiceImpl(int server_port, QpsWorker* worker) + : acquired_(false), server_port_(server_port), worker_(worker) {} + + Status RunClient( + ServerContext* ctx, + ServerReaderWriter* stream) override { + gpr_log(GPR_INFO, "RunClient: Entering"); + InstanceGuard g(this); + if (!g.Acquired()) { + return Status(StatusCode::RESOURCE_EXHAUSTED, "Client worker busy"); + } + + ScopedProfile profile("qps_client.prof", false); + Status ret = RunClientBody(ctx, stream); + gpr_log(GPR_INFO, "RunClient: Returning"); + return ret; + } + + Status RunServer( + ServerContext* ctx, + ServerReaderWriter* stream) override { + gpr_log(GPR_INFO, "RunServer: Entering"); + InstanceGuard g(this); + if (!g.Acquired()) { + return Status(StatusCode::RESOURCE_EXHAUSTED, "Server worker busy"); + } + + ScopedProfile profile("qps_server.prof", false); + Status ret = RunServerBody(ctx, stream); + gpr_log(GPR_INFO, "RunServer: Returning"); + return 
ret; + } + + Status CoreCount(ServerContext* /*ctx*/, const CoreRequest*, + CoreResponse* resp) override { + resp->set_cores(gpr_cpu_num_cores()); + return Status::OK; + } + + Status QuitWorker(ServerContext* /*ctx*/, const Void*, Void*) override { + InstanceGuard g(this); + if (!g.Acquired()) { + return Status(StatusCode::RESOURCE_EXHAUSTED, "Quitting worker busy"); + } + + worker_->MarkDone(); + return Status::OK; + } + + private: + // Protect against multiple clients using this worker at once. + class InstanceGuard { + public: + explicit InstanceGuard(WorkerServiceImpl* impl) + : impl_(impl), acquired_(impl->TryAcquireInstance()) {} + ~InstanceGuard() { + if (acquired_) { + impl_->ReleaseInstance(); + } + } + + bool Acquired() const { return acquired_; } + + private: + WorkerServiceImpl* const impl_; + const bool acquired_; + }; + + bool TryAcquireInstance() { + std::lock_guard g(mu_); + if (acquired_) return false; + acquired_ = true; + return true; + } + + void ReleaseInstance() { + std::lock_guard g(mu_); + GPR_ASSERT(acquired_); + acquired_ = false; + } + + Status RunClientBody(ServerContext* /*ctx*/, + ServerReaderWriter* stream) { + ClientArgs args; + if (!stream->Read(&args)) { + return Status(StatusCode::INVALID_ARGUMENT, "Couldn't read args"); + } + if (!args.has_setup()) { + return Status(StatusCode::INVALID_ARGUMENT, "Invalid setup arg"); + } + gpr_log(GPR_INFO, "RunClientBody: about to create client"); + std::unique_ptr client = CreateClient(args.setup()); + if (!client) { + return Status(StatusCode::INVALID_ARGUMENT, "Couldn't create client"); + } + gpr_log(GPR_INFO, "RunClientBody: client created"); + ClientStatus status; + if (!stream->Write(status)) { + return Status(StatusCode::UNKNOWN, "Client couldn't report init status"); + } + gpr_log(GPR_INFO, "RunClientBody: creation status reported"); + while (stream->Read(&args)) { + gpr_log(GPR_INFO, "RunClientBody: Message read"); + if (!args.has_mark()) { + gpr_log(GPR_INFO, "RunClientBody: Message is not a mark!"); + return Status(StatusCode::INVALID_ARGUMENT, "Invalid mark"); + } + *status.mutable_stats() = client->Mark(args.mark().reset()); + if (!stream->Write(status)) { + return Status(StatusCode::UNKNOWN, "Client couldn't respond to mark"); + } + gpr_log(GPR_INFO, "RunClientBody: Mark response given"); + } + + gpr_log(GPR_INFO, "RunClientBody: Awaiting Threads Completion"); + client->AwaitThreadsCompletion(); + + gpr_log(GPR_INFO, "RunClientBody: Returning"); + return Status::OK; + } + + Status RunServerBody(ServerContext* /*ctx*/, + ServerReaderWriter* stream) { + ServerArgs args; + if (!stream->Read(&args)) { + return Status(StatusCode::INVALID_ARGUMENT, "Couldn't read server args"); + } + if (!args.has_setup()) { + return Status(StatusCode::INVALID_ARGUMENT, "Bad server creation args"); + } + if (server_port_ > 0 && args.setup().port() == 0) { + args.mutable_setup()->set_port(server_port_); + } + gpr_log(GPR_INFO, "RunServerBody: about to create server"); + std::unique_ptr server = CreateServer(args.setup()); + if (g_inproc_servers != nullptr) { + g_inproc_servers->push_back(server.get()); + } + if (!server) { + return Status(StatusCode::INVALID_ARGUMENT, "Couldn't create server"); + } + gpr_log(GPR_INFO, "RunServerBody: server created"); + ServerStatus status; + status.set_port(server->port()); + status.set_cores(server->cores()); + if (!stream->Write(status)) { + return Status(StatusCode::UNKNOWN, "Server couldn't report init status"); + } + gpr_log(GPR_INFO, "RunServerBody: creation status reported"); + while 
(stream->Read(&args)) { + gpr_log(GPR_INFO, "RunServerBody: Message read"); + if (!args.has_mark()) { + gpr_log(GPR_INFO, "RunServerBody: Message not a mark!"); + return Status(StatusCode::INVALID_ARGUMENT, "Invalid mark"); + } + *status.mutable_stats() = server->Mark(args.mark().reset()); + if (!stream->Write(status)) { + return Status(StatusCode::UNKNOWN, "Server couldn't respond to mark"); + } + gpr_log(GPR_INFO, "RunServerBody: Mark response given"); + } + + gpr_log(GPR_INFO, "RunServerBody: Returning"); + return Status::OK; + } + + std::mutex mu_; + bool acquired_; + int server_port_; + QpsWorker* worker_; +}; + +QpsWorker::QpsWorker(int driver_port, int server_port, + const std::string& credential_type) { + impl_ = absl::make_unique(server_port, this); + gpr_atm_rel_store(&done_, static_cast(0)); + + std::unique_ptr builder = CreateQpsServerBuilder(); + builder->AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); + if (driver_port >= 0) { + std::string server_address = grpc_core::JoinHostPort("::", driver_port); + builder->AddListeningPort( + server_address.c_str(), + GetCredentialsProvider()->GetServerCredentials(credential_type)); + } + builder->RegisterService(impl_.get()); + + server_ = builder->BuildAndStart(); + if (server_ == nullptr) { + gpr_log(GPR_ERROR, + "QpsWorker: Fail to BuildAndStart(driver_port=%d, server_port=%d)", + driver_port, server_port); + } else { + gpr_log(GPR_INFO, + "QpsWorker: BuildAndStart(driver_port=%d, server_port=%d) done", + driver_port, server_port); + } +} + +QpsWorker::~QpsWorker() {} + +bool QpsWorker::Done() const { + return (gpr_atm_acq_load(&done_) != static_cast(0)); +} +void QpsWorker::MarkDone() { + gpr_atm_rel_store(&done_, static_cast(1)); +} +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/report.cc b/test/cpp/qps/report.cc new file mode 100644 index 00000000..b4299d2f --- /dev/null +++ b/test/cpp/qps/report.cc @@ -0,0 +1,239 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
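Note (illustrative sketch, not from the gRPC sources): WorkerServiceImpl above serializes RunClient, RunServer, and QuitWorker with an InstanceGuard, a try-acquire under a mutex that is released in the guard's destructor, so a busy worker can answer RESOURCE_EXHAUSTED instead of running two benchmarks at once. A standalone sketch of that guard pattern; the class names here are made up:

// (illustrative sketch, not part of this patch)
#include <cstdio>
#include <mutex>

// Same single-occupancy pattern as WorkerServiceImpl::InstanceGuard above:
// a try-acquire under a mutex, released by the guard's destructor, so a
// second concurrent caller can be rejected as "busy".
class Busy {
 public:
  bool TryAcquire() {
    std::lock_guard<std::mutex> g(mu_);
    if (acquired_) return false;
    acquired_ = true;
    return true;
  }
  void Release() {
    std::lock_guard<std::mutex> g(mu_);
    acquired_ = false;
  }

 private:
  std::mutex mu_;
  bool acquired_ = false;
};

class Guard {
 public:
  explicit Guard(Busy* b) : busy_(b), acquired_(b->TryAcquire()) {}
  ~Guard() {
    if (acquired_) busy_->Release();
  }
  bool Acquired() const { return acquired_; }

 private:
  Busy* const busy_;
  const bool acquired_;
};

int main() {
  Busy worker;
  Guard first(&worker), second(&worker);
  // Prints "first=1 second=0": the second caller would get RESOURCE_EXHAUSTED.
  std::printf("first=%d second=%d\n", first.Acquired(), second.Acquired());
}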
+ * + */ + +#include "test/cpp/qps/report.h" + +#include + +#include +#include + +#include "src/cpp/util/core_stats.h" +#include "src/proto/grpc/testing/report_qps_scenario_service.grpc.pb.h" +#include "test/cpp/qps/driver.h" +#include "test/cpp/qps/parse_json.h" +#include "test/cpp/qps/stats.h" + +namespace grpc { +namespace testing { + +void CompositeReporter::add(std::unique_ptr reporter) { + reporters_.emplace_back(std::move(reporter)); +} + +void CompositeReporter::ReportQPS(const ScenarioResult& result) { + for (size_t i = 0; i < reporters_.size(); ++i) { + reporters_[i]->ReportQPS(result); + } +} + +void CompositeReporter::ReportQPSPerCore(const ScenarioResult& result) { + for (size_t i = 0; i < reporters_.size(); ++i) { + reporters_[i]->ReportQPSPerCore(result); + } +} + +void CompositeReporter::ReportLatency(const ScenarioResult& result) { + for (size_t i = 0; i < reporters_.size(); ++i) { + reporters_[i]->ReportLatency(result); + } +} + +void CompositeReporter::ReportTimes(const ScenarioResult& result) { + for (size_t i = 0; i < reporters_.size(); ++i) { + reporters_[i]->ReportTimes(result); + } +} + +void CompositeReporter::ReportCpuUsage(const ScenarioResult& result) { + for (size_t i = 0; i < reporters_.size(); ++i) { + reporters_[i]->ReportCpuUsage(result); + } +} + +void CompositeReporter::ReportPollCount(const ScenarioResult& result) { + for (size_t i = 0; i < reporters_.size(); ++i) { + reporters_[i]->ReportPollCount(result); + } +} + +void CompositeReporter::ReportQueriesPerCpuSec(const ScenarioResult& result) { + for (size_t i = 0; i < reporters_.size(); ++i) { + reporters_[i]->ReportQueriesPerCpuSec(result); + } +} + +void GprLogReporter::ReportQPS(const ScenarioResult& result) { + gpr_log(GPR_INFO, "QPS: %.1f", result.summary().qps()); + if (result.summary().failed_requests_per_second() > 0) { + gpr_log(GPR_INFO, "failed requests/second: %.1f", + result.summary().failed_requests_per_second()); + gpr_log(GPR_INFO, "successful requests/second: %.1f", + result.summary().successful_requests_per_second()); + } + for (int i = 0; i < result.client_stats_size(); i++) { + if (result.client_stats(i).has_core_stats()) { + ReportCoreStats("CLIENT", i, result.client_stats(i).core_stats()); + } + } + for (int i = 0; i < result.server_stats_size(); i++) { + if (result.server_stats(i).has_core_stats()) { + ReportCoreStats("SERVER", i, result.server_stats(i).core_stats()); + } + } +} + +void GprLogReporter::ReportCoreStats(const char* name, int idx, + const grpc::core::Stats& stats) { + grpc_stats_data data; + ProtoToCoreStats(stats, &data); + for (int i = 0; i < GRPC_STATS_COUNTER_COUNT; i++) { + gpr_log(GPR_DEBUG, "%s[%d].%s = %" PRIdPTR, name, idx, + grpc_stats_counter_name[i], data.counters[i]); + } + for (int i = 0; i < GRPC_STATS_HISTOGRAM_COUNT; i++) { + gpr_log(GPR_DEBUG, "%s[%d].%s = %.1lf/%.1lf/%.1lf (50/95/99%%-ile)", name, + idx, grpc_stats_histogram_name[i], + grpc_stats_histo_percentile( + &data, static_cast(i), 50), + grpc_stats_histo_percentile( + &data, static_cast(i), 95), + grpc_stats_histo_percentile( + &data, static_cast(i), 99)); + } +} + +void GprLogReporter::ReportQPSPerCore(const ScenarioResult& result) { + gpr_log(GPR_INFO, "QPS: %.1f (%.1f/server core)", result.summary().qps(), + result.summary().qps_per_server_core()); +} + +void GprLogReporter::ReportLatency(const ScenarioResult& result) { + gpr_log(GPR_INFO, + "Latencies (50/90/95/99/99.9%%-ile): %.1f/%.1f/%.1f/%.1f/%.1f us", + result.summary().latency_50() / 1000, + result.summary().latency_90() / 1000, + 
result.summary().latency_95() / 1000, + result.summary().latency_99() / 1000, + result.summary().latency_999() / 1000); +} + +void GprLogReporter::ReportTimes(const ScenarioResult& result) { + gpr_log(GPR_INFO, "Server system time: %.2f%%", + result.summary().server_system_time()); + gpr_log(GPR_INFO, "Server user time: %.2f%%", + result.summary().server_user_time()); + gpr_log(GPR_INFO, "Client system time: %.2f%%", + result.summary().client_system_time()); + gpr_log(GPR_INFO, "Client user time: %.2f%%", + result.summary().client_user_time()); +} + +void GprLogReporter::ReportCpuUsage(const ScenarioResult& result) { + gpr_log(GPR_INFO, "Server CPU usage: %.2f%%", + result.summary().server_cpu_usage()); +} + +void GprLogReporter::ReportPollCount(const ScenarioResult& result) { + gpr_log(GPR_INFO, "Client Polls per Request: %.2f", + result.summary().client_polls_per_request()); + gpr_log(GPR_INFO, "Server Polls per Request: %.2f", + result.summary().server_polls_per_request()); +} + +void GprLogReporter::ReportQueriesPerCpuSec(const ScenarioResult& result) { + gpr_log(GPR_INFO, "Server Queries/CPU-sec: %.2f", + result.summary().server_queries_per_cpu_sec()); + gpr_log(GPR_INFO, "Client Queries/CPU-sec: %.2f", + result.summary().client_queries_per_cpu_sec()); +} + +void JsonReporter::ReportQPS(const ScenarioResult& result) { + std::string json_string = + SerializeJson(result, "type.googleapis.com/grpc.testing.ScenarioResult"); + std::ofstream output_file(report_file_); + output_file << json_string; + output_file.close(); +} + +void JsonReporter::ReportQPSPerCore(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void JsonReporter::ReportLatency(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void JsonReporter::ReportTimes(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void JsonReporter::ReportCpuUsage(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void JsonReporter::ReportPollCount(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void JsonReporter::ReportQueriesPerCpuSec(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void RpcReporter::ReportQPS(const ScenarioResult& result) { + grpc::ClientContext context; + grpc::Status status; + Void phony; + + gpr_log(GPR_INFO, "RPC reporter sending scenario result to server"); + status = stub_->ReportScenario(&context, result, &phony); + + if (status.ok()) { + gpr_log(GPR_INFO, "RpcReporter report RPC success!"); + } else { + gpr_log(GPR_ERROR, "RpcReporter report RPC: code: %d. message: %s", + status.error_code(), status.error_message().c_str()); + } +} + +void RpcReporter::ReportQPSPerCore(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void RpcReporter::ReportLatency(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void RpcReporter::ReportTimes(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void RpcReporter::ReportCpuUsage(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void RpcReporter::ReportPollCount(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. +} + +void RpcReporter::ReportQueriesPerCpuSec(const ScenarioResult& /*result*/) { + // NOP - all reporting is handled by ReportQPS. 
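+  // (All of RpcReporter's per-metric hooks are intentionally empty: the full
+  // ScenarioResult proto is shipped to the reporting server exactly once, in
+  // ReportQPS, via the ReportScenario RPC. JsonReporter follows the same
+  // pattern, serializing the whole result to report_file_ in ReportQPS.)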
+} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/secure_sync_unary_ping_pong_test.cc b/test/cpp/qps/secure_sync_unary_ping_pong_test.cc new file mode 100644 index 00000000..16cffc92 --- /dev/null +++ b/test/cpp/qps/secure_sync_unary_ping_pong_test.cc @@ -0,0 +1,74 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/qps/benchmark_config.h" +#include "test/cpp/qps/driver.h" +#include "test/cpp/qps/report.h" +#include "test/cpp/qps/server.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +namespace grpc { +namespace testing { + +static const int WARMUP = 1; +static const int BENCHMARK = 3; + +static void RunSynchronousUnaryPingPong() { + gpr_log(GPR_INFO, "Running Synchronous Unary Ping Pong"); + + ClientConfig client_config; + client_config.set_client_type(SYNC_CLIENT); + client_config.set_outstanding_rpcs_per_channel(1); + client_config.set_client_channels(1); + client_config.set_rpc_type(UNARY); + client_config.mutable_load_params()->mutable_closed_loop(); + + ServerConfig server_config; + server_config.set_server_type(SYNC_SERVER); + + // Set up security params + SecurityParams security; + security.set_use_test_ca(true); + security.set_server_host_override("foo.test.google.fr"); + client_config.mutable_security_params()->CopyFrom(security); + server_config.mutable_security_params()->CopyFrom(security); + + const auto result = + RunScenario(client_config, 1, server_config, 1, WARMUP, BENCHMARK, -2, "", + kInsecureCredentialsType, {}, false, 0); + + GetReporter()->ReportQPS(*result); + GetReporter()->ReportLatency(*result); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + + grpc::testing::RunSynchronousUnaryPingPong(); + return 0; +} diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc new file mode 100644 index 00000000..d57e5a8c --- /dev/null +++ b/test/cpp/qps/server_async.cc @@ -0,0 +1,598 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/core/lib/surface/completion_queue.h" +#include "src/proto/grpc/testing/benchmark_service.grpc.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/qps/qps_server_builder.h" +#include "test/cpp/qps/server.h" + +namespace grpc { +namespace testing { + +template +class AsyncQpsServerTest final : public grpc::testing::Server { + public: + AsyncQpsServerTest( + const ServerConfig& config, + std::function register_service, + std::function*, + CompletionQueue*, ServerCompletionQueue*, void*)> + request_unary_function, + std::function*, + CompletionQueue*, ServerCompletionQueue*, void*)> + request_streaming_function, + std::function*, + CompletionQueue*, ServerCompletionQueue*, void*)> + request_streaming_from_client_function, + std::function*, CompletionQueue*, + ServerCompletionQueue*, void*)> + request_streaming_from_server_function, + std::function*, + CompletionQueue*, ServerCompletionQueue*, void*)> + request_streaming_both_ways_function, + std::function + process_rpc) + : Server(config) { + std::unique_ptr builder = CreateQpsServerBuilder(); + + auto port_num = port(); + // Negative port number means inproc server, so no listen port needed + if (port_num >= 0) { + std::string server_address = grpc_core::JoinHostPort("::", port_num); + builder->AddListeningPort(server_address.c_str(), + Server::CreateServerCredentials(config), + &port_num); + } + + register_service(builder.get(), &async_service_); + + int num_threads = config.async_server_threads(); + if (num_threads <= 0) { // dynamic sizing + num_threads = std::min(64, cores()); + gpr_log(GPR_INFO, + "Sizing async server to %d threads. 
Defaults to number of cores " + "in machine or 64 threads if machine has more than 64 cores to " + "avoid OOMs.", + num_threads); + } + + int tpc = std::max(1, config.threads_per_cq()); // 1 if unspecified + int num_cqs = (num_threads + tpc - 1) / tpc; // ceiling operator + for (int i = 0; i < num_cqs; i++) { + srv_cqs_.emplace_back(builder->AddCompletionQueue()); + } + for (int i = 0; i < num_threads; i++) { + cq_.emplace_back(i % srv_cqs_.size()); + } + + ApplyConfigToBuilder(config, builder.get()); + + server_ = builder->BuildAndStart(); + if (server_ == nullptr) { + gpr_log(GPR_ERROR, "Server: Fail to BuildAndStart(port=%d)", port_num); + } else { + gpr_log(GPR_INFO, "Server: BuildAndStart(port=%d)", port_num); + } + + auto process_rpc_bound = + std::bind(process_rpc, config.payload_config(), std::placeholders::_1, + std::placeholders::_2); + + for (int i = 0; i < 5000; i++) { + for (int j = 0; j < num_cqs; j++) { + if (request_unary_function) { + auto request_unary = std::bind( + request_unary_function, &async_service_, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, srv_cqs_[j].get(), + srv_cqs_[j].get(), std::placeholders::_4); + contexts_.emplace_back( + new ServerRpcContextUnaryImpl(request_unary, process_rpc_bound)); + } + if (request_streaming_function) { + auto request_streaming = std::bind( + request_streaming_function, &async_service_, + std::placeholders::_1, std::placeholders::_2, srv_cqs_[j].get(), + srv_cqs_[j].get(), std::placeholders::_3); + contexts_.emplace_back(new ServerRpcContextStreamingImpl( + request_streaming, process_rpc_bound)); + } + if (request_streaming_from_client_function) { + auto request_streaming_from_client = std::bind( + request_streaming_from_client_function, &async_service_, + std::placeholders::_1, std::placeholders::_2, srv_cqs_[j].get(), + srv_cqs_[j].get(), std::placeholders::_3); + contexts_.emplace_back(new ServerRpcContextStreamingFromClientImpl( + request_streaming_from_client, process_rpc_bound)); + } + if (request_streaming_from_server_function) { + auto request_streaming_from_server = + std::bind(request_streaming_from_server_function, &async_service_, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, srv_cqs_[j].get(), + srv_cqs_[j].get(), std::placeholders::_4); + contexts_.emplace_back(new ServerRpcContextStreamingFromServerImpl( + request_streaming_from_server, process_rpc_bound)); + } + if (request_streaming_both_ways_function) { + // TODO(vjpai): Add this code + } + } + } + + for (int i = 0; i < num_threads; i++) { + shutdown_state_.emplace_back(new PerThreadShutdownState()); + threads_.emplace_back(&AsyncQpsServerTest::ThreadFunc, this, i); + } + } + ~AsyncQpsServerTest() override { + for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) { + std::lock_guard lock((*ss)->mutex); + (*ss)->shutdown = true; + } + // TODO(vjpai): Remove the following deadline and allow full proper + // shutdown. 
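+    // Shutdown order in this destructor: mark every per-thread state as shut
+    // down, ask the server to stop (calls still pending after the 3 second
+    // deadline are cancelled), shut down each completion queue, join the
+    // worker threads, and finally drain the queues so no completion tags are
+    // left behind.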
+ server_->Shutdown(std::chrono::system_clock::now() + + std::chrono::seconds(3)); + for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) { + (*cq)->Shutdown(); + } + for (auto thr = threads_.begin(); thr != threads_.end(); thr++) { + thr->join(); + } + for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); ++cq) { + bool ok; + void* got_tag; + while ((*cq)->Next(&got_tag, &ok)) { + } + } + } + + int GetPollCount() override { + int count = 0; + for (auto cq = srv_cqs_.begin(); cq != srv_cqs_.end(); cq++) { + count += grpc_get_cq_poll_num((*cq)->cq()); + } + return count; + } + + std::shared_ptr InProcessChannel( + const ChannelArguments& args) override { + return server_->InProcessChannel(args); + } + + private: + void ThreadFunc(int thread_idx) { + // Wait until work is available or we are shutting down + bool ok; + void* got_tag; + if (!srv_cqs_[cq_[thread_idx]]->Next(&got_tag, &ok)) { + return; + } + ServerRpcContext* ctx; + std::mutex* mu_ptr = &shutdown_state_[thread_idx]->mutex; + do { + ctx = detag(got_tag); + // The tag is a pointer to an RPC context to invoke + // Proceed while holding a lock to make sure that + // this thread isn't supposed to shut down + mu_ptr->lock(); + if (shutdown_state_[thread_idx]->shutdown) { + mu_ptr->unlock(); + return; + } + } while (srv_cqs_[cq_[thread_idx]]->DoThenAsyncNext( + [&, ctx, ok, mu_ptr]() { + ctx->lock(); + if (!ctx->RunNextState(ok)) { + ctx->Reset(); + } + ctx->unlock(); + mu_ptr->unlock(); + }, + &got_tag, &ok, gpr_inf_future(GPR_CLOCK_REALTIME))); + } + + class ServerRpcContext { + public: + ServerRpcContext() {} + void lock() { mu_.lock(); } + void unlock() { mu_.unlock(); } + virtual ~ServerRpcContext(){}; + virtual bool RunNextState(bool) = 0; // next state, return false if done + virtual void Reset() = 0; // start this back at a clean state + private: + std::mutex mu_; + }; + static void* tag(ServerRpcContext* func) { return static_cast(func); } + static ServerRpcContext* detag(void* tag) { + return static_cast(tag); + } + + class ServerRpcContextUnaryImpl final : public ServerRpcContext { + public: + ServerRpcContextUnaryImpl( + std::function*, + void*)> + request_method, + std::function invoke_method) + : srv_ctx_(new ServerContextType), + next_state_(&ServerRpcContextUnaryImpl::invoker), + request_method_(request_method), + invoke_method_(invoke_method), + response_writer_(srv_ctx_.get()) { + request_method_(srv_ctx_.get(), &req_, &response_writer_, + AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextUnaryImpl() override {} + bool RunNextState(bool ok) override { return (this->*next_state_)(ok); } + void Reset() override { + srv_ctx_.reset(new ServerContextType); + req_ = RequestType(); + response_writer_ = + grpc::ServerAsyncResponseWriter(srv_ctx_.get()); + + // Then request the method + next_state_ = &ServerRpcContextUnaryImpl::invoker; + request_method_(srv_ctx_.get(), &req_, &response_writer_, + AsyncQpsServerTest::tag(this)); + } + + private: + bool finisher(bool) { return false; } + bool invoker(bool ok) { + if (!ok) { + return false; + } + + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response_); + + // Have the response writer work and invoke on_finish when done + next_state_ = &ServerRpcContextUnaryImpl::finisher; + response_writer_.Finish(response_, status, AsyncQpsServerTest::tag(this)); + return true; + } + std::unique_ptr srv_ctx_; + RequestType req_; + ResponseType response_; + bool (ServerRpcContextUnaryImpl::*next_state_)(bool); + std::function*, void*)> + 
request_method_; + std::function invoke_method_; + grpc::ServerAsyncResponseWriter response_writer_; + }; + + class ServerRpcContextStreamingImpl final : public ServerRpcContext { + public: + ServerRpcContextStreamingImpl( + std::function*, void*)> + request_method, + std::function invoke_method) + : srv_ctx_(new ServerContextType), + next_state_(&ServerRpcContextStreamingImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(srv_ctx_.get()) { + request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingImpl() override {} + bool RunNextState(bool ok) override { return (this->*next_state_)(ok); } + void Reset() override { + srv_ctx_.reset(new ServerContextType); + req_ = RequestType(); + stream_ = grpc::ServerAsyncReaderWriter( + srv_ctx_.get()); + + // Then request the method + next_state_ = &ServerRpcContextStreamingImpl::request_done; + request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) { + return false; + } + next_state_ = &ServerRpcContextStreamingImpl::read_done; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + return true; + } + + bool read_done(bool ok) { + if (ok) { + // invoke the method + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response_); + // initiate the write + next_state_ = &ServerRpcContextStreamingImpl::write_done; + stream_.Write(response_, AsyncQpsServerTest::tag(this)); + } else { // client has sent writes done + // finish the stream + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + } + return true; + } + bool write_done(bool ok) { + // now go back and get another streaming read! 
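+      // State machine recap: request_done arms the RPC, read_done answers a
+      // request with a write, and write_done loops back to another read. A
+      // failed operation (or the client finishing its writes) moves to
+      // finish_done, which returns false so ThreadFunc calls Reset() and the
+      // context is re-armed for the next RPC.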
+ if (ok) { + next_state_ = &ServerRpcContextStreamingImpl::read_done; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + } else { + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + } + return true; + } + bool finish_done(bool /*ok*/) { return false; /*reset the context*/ } + + std::unique_ptr srv_ctx_; + RequestType req_; + ResponseType response_; + bool (ServerRpcContextStreamingImpl::*next_state_)(bool); + std::function*, void*)> + request_method_; + std::function invoke_method_; + grpc::ServerAsyncReaderWriter stream_; + }; + + class ServerRpcContextStreamingFromClientImpl final + : public ServerRpcContext { + public: + ServerRpcContextStreamingFromClientImpl( + std::function*, + void*)> + request_method, + std::function invoke_method) + : srv_ctx_(new ServerContextType), + next_state_(&ServerRpcContextStreamingFromClientImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(srv_ctx_.get()) { + request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingFromClientImpl() override {} + bool RunNextState(bool ok) override { return (this->*next_state_)(ok); } + void Reset() override { + srv_ctx_.reset(new ServerContextType); + req_ = RequestType(); + stream_ = + grpc::ServerAsyncReader(srv_ctx_.get()); + + // Then request the method + next_state_ = &ServerRpcContextStreamingFromClientImpl::request_done; + request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) { + return false; + } + next_state_ = &ServerRpcContextStreamingFromClientImpl::read_done; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + return true; + } + + bool read_done(bool ok) { + if (ok) { + // In this case, just do another read + // next_state_ is unchanged + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + return true; + } else { // client has sent writes done + // invoke the method + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response_); + // finish the stream + next_state_ = &ServerRpcContextStreamingFromClientImpl::finish_done; + stream_.Finish(response_, Status::OK, AsyncQpsServerTest::tag(this)); + } + return true; + } + bool finish_done(bool /*ok*/) { return false; /*reset the context*/ } + + std::unique_ptr srv_ctx_; + RequestType req_; + ResponseType response_; + bool (ServerRpcContextStreamingFromClientImpl::*next_state_)(bool); + std::function*, + void*)> + request_method_; + std::function invoke_method_; + grpc::ServerAsyncReader stream_; + }; + + class ServerRpcContextStreamingFromServerImpl final + : public ServerRpcContext { + public: + ServerRpcContextStreamingFromServerImpl( + std::function*, void*)> + request_method, + std::function invoke_method) + : srv_ctx_(new ServerContextType), + next_state_(&ServerRpcContextStreamingFromServerImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(srv_ctx_.get()) { + request_method_(srv_ctx_.get(), &req_, &stream_, + AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingFromServerImpl() override {} + bool RunNextState(bool ok) override { return (this->*next_state_)(ok); } + void Reset() override { + srv_ctx_.reset(new ServerContextType); + req_ = RequestType(); + stream_ = grpc::ServerAsyncWriter(srv_ctx_.get()); + + // Then request the method + next_state_ = &ServerRpcContextStreamingFromServerImpl::request_done; + 
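+      // Re-arm the server-streaming request so this context can be matched
+      // against the next incoming RPC.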
request_method_(srv_ctx_.get(), &req_, &stream_, + AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) { + return false; + } + // invoke the method + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response_); + + next_state_ = &ServerRpcContextStreamingFromServerImpl::write_done; + stream_.Write(response_, AsyncQpsServerTest::tag(this)); + return true; + } + + bool write_done(bool ok) { + if (ok) { + // Do another write! + // next_state_ is unchanged + stream_.Write(response_, AsyncQpsServerTest::tag(this)); + } else { // must be done so let's finish + next_state_ = &ServerRpcContextStreamingFromServerImpl::finish_done; + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + } + return true; + } + bool finish_done(bool /*ok*/) { return false; /*reset the context*/ } + + std::unique_ptr srv_ctx_; + RequestType req_; + ResponseType response_; + bool (ServerRpcContextStreamingFromServerImpl::*next_state_)(bool); + std::function*, void*)> + request_method_; + std::function invoke_method_; + grpc::ServerAsyncWriter stream_; + }; + + std::vector threads_; + std::unique_ptr server_; + std::vector> srv_cqs_; + std::vector cq_; + ServiceType async_service_; + std::vector> contexts_; + + struct PerThreadShutdownState { + mutable std::mutex mutex; + bool shutdown; + PerThreadShutdownState() : shutdown(false) {} + }; + + std::vector> shutdown_state_; +}; + +static void RegisterBenchmarkService(ServerBuilder* builder, + BenchmarkService::AsyncService* service) { + builder->RegisterService(service); +} +static void RegisterGenericService(ServerBuilder* builder, + grpc::AsyncGenericService* service) { + builder->RegisterAsyncGenericService(service); +} + +static Status ProcessSimpleRPC(const PayloadConfig&, SimpleRequest* request, + SimpleResponse* response) { + if (request->response_size() > 0) { + if (!Server::SetPayload(request->response_type(), request->response_size(), + response->mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + // We are done using the request. Clear it to reduce working memory. + // This proves to reduce cache misses in large message size cases. + request->Clear(); + return Status::OK; +} + +static Status ProcessGenericRPC(const PayloadConfig& payload_config, + ByteBuffer* request, ByteBuffer* response) { + // We are done using the request. Clear it to reduce working memory. + // This proves to reduce cache misses in large message size cases. 
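+  // After dropping the request, the generic path below answers with a single
+  // zero-filled slice of bytebuf_params().resp_size() bytes wrapped in a
+  // ByteBuffer.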
+ request->Clear(); + int resp_size = payload_config.bytebuf_params().resp_size(); + std::unique_ptr buf(new char[resp_size]); + memset(buf.get(), 0, static_cast(resp_size)); + Slice slice(buf.get(), resp_size); + *response = ByteBuffer(&slice, 1); + return Status::OK; +} + +std::unique_ptr CreateAsyncServer(const ServerConfig& config) { + return std::unique_ptr( + new AsyncQpsServerTest( + config, RegisterBenchmarkService, + &BenchmarkService::AsyncService::RequestUnaryCall, + &BenchmarkService::AsyncService::RequestStreamingCall, + &BenchmarkService::AsyncService::RequestStreamingFromClient, + &BenchmarkService::AsyncService::RequestStreamingFromServer, + &BenchmarkService::AsyncService::RequestStreamingBothWays, + ProcessSimpleRPC)); +} +std::unique_ptr CreateAsyncGenericServer(const ServerConfig& config) { + return std::unique_ptr( + new AsyncQpsServerTest( + config, RegisterGenericService, nullptr, + &grpc::AsyncGenericService::RequestCall, nullptr, nullptr, nullptr, + ProcessGenericRPC)); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/server_callback.cc b/test/cpp/qps/server_callback.cc new file mode 100644 index 00000000..e05f7364 --- /dev/null +++ b/test/cpp/qps/server_callback.cc @@ -0,0 +1,139 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/proto/grpc/testing/benchmark_service.grpc.pb.h" +#include "test/cpp/qps/qps_server_builder.h" +#include "test/cpp/qps/server.h" +#include "test/cpp/qps/usage_timer.h" + +namespace grpc { +namespace testing { + +class BenchmarkCallbackServiceImpl final + : public BenchmarkService::CallbackService { + public: + ::grpc::ServerUnaryReactor* UnaryCall(::grpc::CallbackServerContext* context, + const SimpleRequest* request, + SimpleResponse* response) override { + auto* reactor = context->DefaultReactor(); + reactor->Finish(SetResponse(request, response)); + return reactor; + } + + ::grpc::ServerBidiReactor<::grpc::testing::SimpleRequest, + ::grpc::testing::SimpleResponse>* + StreamingCall(::grpc::CallbackServerContext*) override { + class Reactor + : public ::grpc::ServerBidiReactor<::grpc::testing::SimpleRequest, + ::grpc::testing::SimpleResponse> { + public: + Reactor() { StartRead(&request_); } + + void OnReadDone(bool ok) override { + if (!ok) { + Finish(::grpc::Status::OK); + return; + } + auto s = SetResponse(&request_, &response_); + if (!s.ok()) { + Finish(s); + return; + } + StartWrite(&response_); + } + + void OnWriteDone(bool ok) override { + if (!ok) { + Finish(::grpc::Status::OK); + return; + } + StartRead(&request_); + } + + void OnDone() override { delete (this); } + + private: + SimpleRequest request_; + SimpleResponse response_; + }; + return new Reactor; + } + + private: + static Status SetResponse(const SimpleRequest* request, + SimpleResponse* response) { + if (request->response_size() > 0) { + if (!Server::SetPayload(request->response_type(), + request->response_size(), + response->mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + return Status::OK; + } +}; + +class CallbackServer final : public grpc::testing::Server { + public: + explicit CallbackServer(const ServerConfig& config) : Server(config) { + std::unique_ptr builder = CreateQpsServerBuilder(); + + auto port_num = port(); + // Negative port number means inproc server, so no listen port needed + if (port_num >= 0) { + std::string server_address = grpc_core::JoinHostPort("::", port_num); + builder->AddListeningPort(server_address.c_str(), + Server::CreateServerCredentials(config), + &port_num); + } + + ApplyConfigToBuilder(config, builder.get()); + + builder->RegisterService(&service_); + + impl_ = builder->BuildAndStart(); + if (impl_ == nullptr) { + gpr_log(GPR_ERROR, "Server: Fail to BuildAndStart(port=%d)", port_num); + } else { + gpr_log(GPR_INFO, "Server: BuildAndStart(port=%d)", port_num); + } + } + + std::shared_ptr InProcessChannel( + const ChannelArguments& args) override { + return impl_->InProcessChannel(args); + } + + private: + BenchmarkCallbackServiceImpl service_; + std::unique_ptr impl_; +}; + +std::unique_ptr CreateCallbackServer( + const ServerConfig& config) { + return std::unique_ptr(new CallbackServer(config)); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/server_sync.cc b/test/cpp/qps/server_sync.cc new file mode 100644 index 00000000..094e1c1c --- /dev/null +++ b/test/cpp/qps/server_sync.cc @@ -0,0 +1,197 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/host_port.h" +#include "src/proto/grpc/testing/benchmark_service.grpc.pb.h" +#include "test/cpp/qps/qps_server_builder.h" +#include "test/cpp/qps/server.h" +#include "test/cpp/qps/usage_timer.h" + +namespace grpc { +namespace testing { + +class BenchmarkServiceImpl final : public BenchmarkService::Service { + public: + Status UnaryCall(ServerContext* /*context*/, const SimpleRequest* request, + SimpleResponse* response) override { + auto s = SetResponse(request, response); + if (!s.ok()) { + return s; + } + return Status::OK; + } + Status StreamingCall( + ServerContext* /*context*/, + ServerReaderWriter* stream) override { + SimpleRequest request; + while (stream->Read(&request)) { + SimpleResponse response; + auto s = SetResponse(&request, &response); + if (!s.ok()) { + return s; + } + if (!stream->Write(response)) { + return Status(StatusCode::INTERNAL, "Server couldn't respond"); + } + } + return Status::OK; + } + Status StreamingFromClient(ServerContext* context, + ServerReader* stream, + SimpleResponse* response) override { + auto s = ClientPull(context, stream, response); + if (!s.ok()) { + return s; + } + return Status::OK; + } + Status StreamingFromServer(ServerContext* context, + const SimpleRequest* request, + ServerWriter* stream) override { + SimpleResponse response; + auto s = SetResponse(request, &response); + if (!s.ok()) { + return s; + } + return ServerPush(context, stream, response, nullptr); + } + Status StreamingBothWays( + ServerContext* context, + ServerReaderWriter* stream) override { + // Read the first client message to setup server response + SimpleRequest request; + if (!stream->Read(&request)) { + return Status::OK; + } + SimpleResponse response; + auto s = SetResponse(&request, &response); + if (!s.ok()) { + return s; + } + std::atomic_bool done; + Status sp; + std::thread t([context, stream, &response, &done, &sp]() { + sp = ServerPush(context, stream, response, [&done]() { + return done.load(std::memory_order_relaxed); + }); + }); + SimpleResponse phony; + auto cp = ClientPull(context, stream, &phony); + done.store(true, std::memory_order_relaxed); // can be lazy + t.join(); + if (!cp.ok()) { + return cp; + } + if (!sp.ok()) { + return sp; + } + return Status::OK; + } + + private: + template + static Status ClientPull(ServerContext* /*context*/, R* stream, + SimpleResponse* response) { + SimpleRequest request; + while (stream->Read(&request)) { + } + if (request.response_size() > 0) { + if (!Server::SetPayload(request.response_type(), request.response_size(), + response->mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + return Status::OK; + } + template + static Status ServerPush(ServerContext* /*context*/, W* stream, + const SimpleResponse& response, + const std::function& done) { + while ((done == nullptr) || !done()) { + // TODO(vjpai): Add potential for rate-pacing on this + if (!stream->Write(response)) { + return Status(StatusCode::INTERNAL, "Server couldn't push"); + } 
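+      // `done` is null for StreamingFromServer, so this loop pushes until the
+      // write fails (client gone). For StreamingBothWays it reads the atomic
+      // flag set once ClientPull observes the client's half-close, which ends
+      // the push loop.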
+ } + return Status::OK; + } + static Status SetResponse(const SimpleRequest* request, + SimpleResponse* response) { + if (request->response_size() > 0) { + if (!Server::SetPayload(request->response_type(), + request->response_size(), + response->mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + return Status::OK; + } +}; + +class SynchronousServer final : public grpc::testing::Server { + public: + explicit SynchronousServer(const ServerConfig& config) : Server(config) { + std::unique_ptr builder = CreateQpsServerBuilder(); + + auto port_num = port(); + // Negative port number means inproc server, so no listen port needed + if (port_num >= 0) { + std::string server_address = grpc_core::JoinHostPort("::", port_num); + builder->AddListeningPort(server_address.c_str(), + Server::CreateServerCredentials(config), + &port_num); + } + + ApplyConfigToBuilder(config, builder.get()); + + builder->RegisterService(&service_); + + impl_ = builder->BuildAndStart(); + if (impl_ == nullptr) { + gpr_log(GPR_ERROR, "Server: Fail to BuildAndStart(port=%d)", port_num); + } else { + gpr_log(GPR_INFO, "Server: BuildAndStart(port=%d)", port_num); + } + } + + std::shared_ptr InProcessChannel( + const ChannelArguments& args) override { + return impl_->InProcessChannel(args); + } + + private: + BenchmarkServiceImpl service_; + std::unique_ptr impl_; +}; + +std::unique_ptr CreateSynchronousServer( + const ServerConfig& config) { + return std::unique_ptr(new SynchronousServer(config)); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/qps/usage_timer.cc b/test/cpp/qps/usage_timer.cc new file mode 100644 index 00000000..29a6aaab --- /dev/null +++ b/test/cpp/qps/usage_timer.cc @@ -0,0 +1,99 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/qps/usage_timer.h" + +#include +#include +#include + +#include +#include +#ifdef __linux__ +#include +#include + +static double time_double(struct timeval* tv) { + return tv->tv_sec + 1e-6 * tv->tv_usec; +} +#endif + +UsageTimer::UsageTimer() : start_(Sample()) {} + +double UsageTimer::Now() { + auto ts = gpr_now(GPR_CLOCK_REALTIME); + return ts.tv_sec + 1e-9 * ts.tv_nsec; +} + +static void get_resource_usage(double* utime, double* stime) { +#ifdef __linux__ + struct rusage usage; + getrusage(RUSAGE_SELF, &usage); + *utime = time_double(&usage.ru_utime); + *stime = time_double(&usage.ru_stime); +#else + *utime = 0; + *stime = 0; +#endif +} + +static void get_cpu_usage(unsigned long long* total_cpu_time, + unsigned long long* idle_cpu_time) { +#ifdef __linux__ + std::ifstream proc_stat("/proc/stat"); + proc_stat.ignore(5); + std::string cpu_time_str; + std::string first_line; + std::getline(proc_stat, first_line); + std::stringstream first_line_s(first_line); + for (int i = 0; i < 10; ++i) { + std::getline(first_line_s, cpu_time_str, ' '); + *total_cpu_time += std::stol(cpu_time_str); + if (i == 3) { + *idle_cpu_time = std::stol(cpu_time_str); + } + } +#else + // Use the parameters to avoid unused-parameter warning + (void)total_cpu_time; + (void)idle_cpu_time; + gpr_log(GPR_INFO, "get_cpu_usage(): Non-linux platform is not supported."); +#endif +} + +UsageTimer::Result UsageTimer::Sample() { + Result r; + r.wall = Now(); + get_resource_usage(&r.user, &r.system); + r.total_cpu_time = 0; + r.idle_cpu_time = 0; + get_cpu_usage(&r.total_cpu_time, &r.idle_cpu_time); + return r; +} + +UsageTimer::Result UsageTimer::Mark() const { + Result s = Sample(); + Result r; + r.wall = s.wall - start_.wall; + r.user = s.user - start_.user; + r.system = s.system - start_.system; + r.total_cpu_time = s.total_cpu_time - start_.total_cpu_time; + r.idle_cpu_time = s.idle_cpu_time - start_.idle_cpu_time; + + return r; +} diff --git a/test/cpp/qps/worker.cc b/test/cpp/qps/worker.cc new file mode 100644 index 00000000..28c7cc7e --- /dev/null +++ b/test/cpp/qps/worker.cc @@ -0,0 +1,74 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include + +#include "test/core/util/test_config.h" +#include "test/cpp/qps/qps_worker.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_FLAG(int32_t, driver_port, 0, "Port for communication with driver"); +ABSL_FLAG(int32_t, server_port, 0, + "Port for operation as a server, if not specified by the server " + "config message"); +ABSL_FLAG(std::string, credential_type, grpc::testing::kInsecureCredentialsType, + "Credential type for communication with driver"); + +static bool got_sigint = false; + +static void sigint_handler(int /*x*/) { got_sigint = true; } + +namespace grpc { +namespace testing { + +std::vector* g_inproc_servers = nullptr; + +static void RunServer() { + QpsWorker worker(absl::GetFlag(FLAGS_driver_port), + absl::GetFlag(FLAGS_server_port), + absl::GetFlag(FLAGS_credential_type)); + + while (!got_sigint && !worker.Done()) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(500, GPR_TIMESPAN))); + } +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + + signal(SIGINT, sigint_handler); + + grpc::testing::RunServer(); + + return 0; +} diff --git a/test/cpp/server/authorization_policy_provider_test.cc b/test/cpp/server/authorization_policy_provider_test.cc new file mode 100644 index 00000000..b6426461 --- /dev/null +++ b/test/cpp/server/authorization_policy_provider_test.cc @@ -0,0 +1,79 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include + +#include "test/core/util/test_config.h" +#include "test/core/util/tls_utils.h" + +#define VALID_POLICY_PATH_1 \ + "test/core/security/authorization/test_policies/valid_policy_1.json" +#define VALID_POLICY_PATH_2 \ + "test/core/security/authorization/test_policies/valid_policy_2.json" +#define INVALID_POLICY_PATH \ + "test/core/security/authorization/test_policies/invalid_policy.json" + +namespace grpc { + +TEST(AuthorizationPolicyProviderTest, StaticDataCreateReturnsProvider) { + grpc::Status status; + auto provider = experimental::StaticDataAuthorizationPolicyProvider::Create( + grpc_core::testing::GetFileContents(VALID_POLICY_PATH_1), &status); + ASSERT_NE(provider, nullptr); + EXPECT_NE(provider->c_provider(), nullptr); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(status.error_message().empty()); +} + +TEST(AuthorizationPolicyProviderTest, StaticDataCreateReturnsErrorStatus) { + grpc::Status status; + auto provider = experimental::StaticDataAuthorizationPolicyProvider::Create( + grpc_core::testing::GetFileContents(INVALID_POLICY_PATH), &status); + ASSERT_EQ(provider, nullptr); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + EXPECT_EQ(status.error_message(), "\"name\" field is not present."); +} + +TEST(AuthorizationPolicyProviderTest, FileWatcherCreateReturnsProvider) { + auto tmp_authz_policy = absl::make_unique( + grpc_core::testing::GetFileContents(VALID_POLICY_PATH_1)); + grpc::Status status; + auto provider = experimental::FileWatcherAuthorizationPolicyProvider::Create( + tmp_authz_policy->name(), /*refresh_interval_sec=*/1, &status); + ASSERT_NE(provider, nullptr); + EXPECT_NE(provider->c_provider(), nullptr); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(status.error_message().empty()); +} + +TEST(AuthorizationPolicyProviderTest, FileWatcherCreateReturnsErrorStatus) { + auto tmp_authz_policy = absl::make_unique( + grpc_core::testing::GetFileContents(INVALID_POLICY_PATH)); + grpc::Status status; + auto provider = experimental::FileWatcherAuthorizationPolicyProvider::Create( + tmp_authz_policy->name(), /*refresh_interval_sec=*/1, &status); + ASSERT_EQ(provider, nullptr); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + EXPECT_EQ(status.error_message(), "\"name\" field is not present."); +} + +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/server/credentials_test.cc b/test/cpp/server/credentials_test.cc new file mode 100644 index 00000000..f0a6e099 --- /dev/null +++ b/test/cpp/server/credentials_test.cc @@ -0,0 +1,136 @@ +// +// Copyright 2020 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include + +#include +#include + +#include +#include +#include +#include + +#include "src/cpp/client/secure_credentials.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +namespace { + +constexpr const char* kRootCertName = "root_cert_name"; +constexpr const char* kRootCertContents = "root_cert_contents"; +constexpr const char* kIdentityCertName = "identity_cert_name"; +constexpr const char* kIdentityCertPrivateKey = "identity_private_key"; +constexpr const char* kIdentityCertContents = "identity_cert_contents"; + +using ::grpc::experimental::FileWatcherCertificateProvider; +using ::grpc::experimental::StaticDataCertificateProvider; + +} // namespace + +namespace grpc { +namespace testing { +namespace { + +TEST( + CredentialsTest, + TlsServerCredentialsWithStaticDataCertificateProviderLoadingRootAndIdentity) { + experimental::IdentityKeyCertPair key_cert_pair; + key_cert_pair.private_key = kIdentityCertPrivateKey; + key_cert_pair.certificate_chain = kIdentityCertContents; + std::vector identity_key_cert_pairs; + identity_key_cert_pairs.emplace_back(key_cert_pair); + auto certificate_provider = std::make_shared( + kRootCertContents, identity_key_cert_pairs); + grpc::experimental::TlsServerCredentialsOptions options(certificate_provider); + options.watch_root_certs(); + options.set_root_cert_name(kRootCertName); + options.watch_identity_key_cert_pairs(); + options.set_identity_cert_name(kIdentityCertName); + options.set_cert_request_type( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + auto server_credentials = grpc::experimental::TlsServerCredentials(options); + GPR_ASSERT(server_credentials.get() != nullptr); +} + +// ServerCredentials should always have identity credential presented. +// Otherwise gRPC stack will fail. +TEST(CredentialsTest, + TlsServerCredentialsWithStaticDataCertificateProviderLoadingIdentityOnly) { + experimental::IdentityKeyCertPair key_cert_pair; + key_cert_pair.private_key = kIdentityCertPrivateKey; + key_cert_pair.certificate_chain = kIdentityCertContents; + std::vector identity_key_cert_pairs; + // Adding two key_cert_pair(s) should still work. 
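+  // The provider below is constructed from identity pairs only (no root
+  // certificate), and the options watch just the identity pairs, which is the
+  // "identity only" configuration this test exercises.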
+ identity_key_cert_pairs.emplace_back(key_cert_pair); + identity_key_cert_pairs.emplace_back(key_cert_pair); + auto certificate_provider = + std::make_shared(identity_key_cert_pairs); + grpc::experimental::TlsServerCredentialsOptions options(certificate_provider); + options.watch_identity_key_cert_pairs(); + options.set_identity_cert_name(kIdentityCertName); + options.set_cert_request_type( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + auto server_credentials = grpc::experimental::TlsServerCredentials(options); + GPR_ASSERT(server_credentials.get() != nullptr); +} + +TEST( + CredentialsTest, + TlsServerCredentialsWithFileWatcherCertificateProviderLoadingRootAndIdentity) { + auto certificate_provider = std::make_shared( + SERVER_KEY_PATH, SERVER_CERT_PATH, CA_CERT_PATH, 1); + grpc::experimental::TlsServerCredentialsOptions options(certificate_provider); + options.watch_root_certs(); + options.set_root_cert_name(kRootCertName); + options.watch_identity_key_cert_pairs(); + options.set_identity_cert_name(kIdentityCertName); + options.set_cert_request_type( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + auto server_credentials = grpc::experimental::TlsServerCredentials(options); + GPR_ASSERT(server_credentials.get() != nullptr); +} + +// ServerCredentials should always have identity credential presented. +// Otherwise gRPC stack will fail. +TEST( + CredentialsTest, + TlsServerCredentialsWithFileWatcherCertificateProviderLoadingIdentityOnly) { + auto certificate_provider = std::make_shared( + SERVER_KEY_PATH, SERVER_CERT_PATH, 1); + grpc::experimental::TlsServerCredentialsOptions options(certificate_provider); + options.watch_identity_key_cert_pairs(); + options.set_identity_cert_name(kIdentityCertName); + options.set_cert_request_type( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + auto server_credentials = grpc::experimental::TlsServerCredentials(options); + GPR_ASSERT(server_credentials.get() != nullptr); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/server/load_reporter/get_cpu_stats_test.cc b/test/cpp/server/load_reporter/get_cpu_stats_test.cc new file mode 100644 index 00000000..9b821d2c --- /dev/null +++ b/test/cpp/server/load_reporter/get_cpu_stats_test.cc @@ -0,0 +1,62 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include "src/cpp/server/load_reporter/get_cpu_stats.h" + +#include + +#include + +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +TEST(GetCpuStatsTest, ReadOnce) { ::grpc::load_reporter::GetCpuStatsImpl(); } + +TEST(GetCpuStatsTest, BusyNoLargerThanTotal) { + auto p = ::grpc::load_reporter::GetCpuStatsImpl(); + uint64_t busy = p.first; + uint64_t total = p.second; + ASSERT_LE(busy, total); +} + +TEST(GetCpuStatsTest, Ascending) { + const size_t kRuns = 100; + auto prev = ::grpc::load_reporter::GetCpuStatsImpl(); + for (size_t i = 0; i < kRuns; ++i) { + auto cur = ::grpc::load_reporter::GetCpuStatsImpl(); + ASSERT_LE(prev.first, cur.first); + ASSERT_LE(prev.second, cur.second); + prev = cur; + } +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/server/load_reporter/load_data_store_test.cc b/test/cpp/server/load_reporter/load_data_store_test.cc new file mode 100644 index 00000000..4d5a8775 --- /dev/null +++ b/test/cpp/server/load_reporter/load_data_store_test.cc @@ -0,0 +1,483 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/server/load_reporter/load_data_store.h" + +#include +#include + +#include + +#include + +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +using ::grpc::load_reporter::CallMetricValue; +using ::grpc::load_reporter::kInvalidLbId; +using ::grpc::load_reporter::LoadDataStore; +using ::grpc::load_reporter::LoadRecordKey; +using ::grpc::load_reporter::LoadRecordValue; +using ::grpc::load_reporter::PerBalancerStore; + +class LoadDataStoreTest : public ::testing::Test { + public: + LoadDataStoreTest() + : kKey1(kLbId1, kLbTag1, kUser1, kClientIp1), + kKey2(kLbId2, kLbTag2, kUser2, kClientIp2) {} + + // Check whether per_balancer_stores contains a store which was originally + // created for . 
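+  // (that is, a store keyed by the given hostname and lb_id and created with
+  // the given load_key).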
+ bool PerBalancerStoresContains( + const LoadDataStore& load_data_store, + const std::set* per_balancer_stores, + const std::string& hostname, const std::string& lb_id, + const std::string& load_key) { + auto original_per_balancer_store = + load_data_store.FindPerBalancerStore(hostname, lb_id); + EXPECT_NE(original_per_balancer_store, nullptr); + EXPECT_EQ(original_per_balancer_store->lb_id(), lb_id); + EXPECT_EQ(original_per_balancer_store->load_key(), load_key); + for (auto per_balancer_store : *per_balancer_stores) { + if (per_balancer_store == original_per_balancer_store) { + return true; + } + } + return false; + } + + std::string FormatLbId(size_t index) { + return "kLbId" + std::to_string(index); + } + + const std::string kHostname1 = "kHostname1"; + const std::string kHostname2 = "kHostname2"; + const std::string kLbId1 = "kLbId1"; + const std::string kLbId2 = "kLbId2"; + const std::string kLbId3 = "kLbId3"; + const std::string kLbId4 = "kLbId4"; + const std::string kLoadKey1 = "kLoadKey1"; + const std::string kLoadKey2 = "kLoadKey2"; + const std::string kLbTag1 = "kLbTag1"; + const std::string kLbTag2 = "kLbTag2"; + const std::string kUser1 = "kUser1"; + const std::string kUser2 = "kUser2"; + const std::string kClientIp1 = "00"; + const std::string kClientIp2 = "02"; + const std::string kMetric1 = "kMetric1"; + const std::string kMetric2 = "kMetric2"; + const LoadRecordKey kKey1; + const LoadRecordKey kKey2; +}; + +using PerBalancerStoreTest = LoadDataStoreTest; + +TEST_F(LoadDataStoreTest, AssignToSelf) { + LoadDataStore load_data_store; + load_data_store.ReportStreamCreated(kHostname1, kLbId1, kLoadKey1); + auto assigned_stores = load_data_store.GetAssignedStores(kHostname1, kLbId1); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_stores, + kHostname1, kLbId1, kLoadKey1)); +} + +TEST_F(LoadDataStoreTest, ReassignOrphanStores) { + LoadDataStore load_data_store; + load_data_store.ReportStreamCreated(kHostname1, kLbId1, kLoadKey1); + load_data_store.ReportStreamCreated(kHostname1, kLbId2, kLoadKey1); + load_data_store.ReportStreamCreated(kHostname1, kLbId3, kLoadKey2); + load_data_store.ReportStreamCreated(kHostname2, kLbId4, kLoadKey1); + // 1. Close the second stream. + load_data_store.ReportStreamClosed(kHostname1, kLbId2); + auto assigned_to_lb_id_1 = + load_data_store.GetAssignedStores(kHostname1, kLbId1); + // The orphaned store is re-assigned to kLbId1 with the same load key. + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_1, + kHostname1, kLbId1, kLoadKey1)); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_1, + kHostname1, kLbId2, kLoadKey1)); + // 2. Close the first stream. + load_data_store.ReportStreamClosed(kHostname1, kLbId1); + auto assigned_to_lb_id_3 = + load_data_store.GetAssignedStores(kHostname1, kLbId3); + // The orphaned stores are re-assigned to kLbId3 with the same host, + // because there isn't any LB with the same load key. + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_3, + kHostname1, kLbId1, kLoadKey1)); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_3, + kHostname1, kLbId2, kLoadKey1)); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_3, + kHostname1, kLbId3, kLoadKey2)); + // 3. Close the third stream. + load_data_store.ReportStreamClosed(kHostname1, kLbId3); + auto assigned_to_lb_id_4 = + load_data_store.GetAssignedStores(kHostname2, kLbId4); + // There is no active LB for the first host now. 
kLbId4 is active but + // it's for the second host, so it wll NOT adopt the orphaned stores. + EXPECT_FALSE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_4, + kHostname1, kLbId1, kLoadKey1)); + EXPECT_FALSE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_4, + kHostname1, kLbId2, kLoadKey1)); + EXPECT_FALSE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_4, + kHostname1, kLbId3, kLoadKey2)); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_4, + kHostname2, kLbId4, kLoadKey1)); +} + +TEST_F(LoadDataStoreTest, OrphanAssignmentIsSticky) { + LoadDataStore load_data_store; + std::set active_lb_ids; + size_t num_lb_ids = 1000; + for (size_t i = 0; i < num_lb_ids; ++i) { + load_data_store.ReportStreamCreated(kHostname1, FormatLbId(i), kLoadKey1); + active_lb_ids.insert(FormatLbId(i)); + } + std::string orphaned_lb_id = FormatLbId(std::rand() % num_lb_ids); + load_data_store.ReportStreamClosed(kHostname1, orphaned_lb_id); + active_lb_ids.erase(orphaned_lb_id); + // Find which LB is assigned the orphaned store. + std::string assigned_lb_id = ""; + for (const auto& lb_id : active_lb_ids) { + if (PerBalancerStoresContains( + load_data_store, + load_data_store.GetAssignedStores(kHostname1, lb_id), kHostname1, + orphaned_lb_id, kLoadKey1)) { + assigned_lb_id = lb_id; + break; + } + } + EXPECT_STRNE(assigned_lb_id.c_str(), ""); + // Close 10 more stream, skipping the assigned_lb_id. The assignment of + // orphaned_lb_id shouldn't change. + for (size_t _ = 0; _ < 10; ++_) { + std::string lb_id_to_close = ""; + for (const auto& lb_id : active_lb_ids) { + if (lb_id != assigned_lb_id) { + lb_id_to_close = lb_id; + break; + } + } + EXPECT_STRNE(lb_id_to_close.c_str(), ""); + load_data_store.ReportStreamClosed(kHostname1, lb_id_to_close); + active_lb_ids.erase(lb_id_to_close); + EXPECT_TRUE(PerBalancerStoresContains( + load_data_store, + load_data_store.GetAssignedStores(kHostname1, assigned_lb_id), + kHostname1, orphaned_lb_id, kLoadKey1)); + } + // Close the assigned_lb_id, orphaned_lb_id will be re-assigned again. + load_data_store.ReportStreamClosed(kHostname1, assigned_lb_id); + active_lb_ids.erase(assigned_lb_id); + size_t orphaned_lb_id_occurences = 0; + for (const auto& lb_id : active_lb_ids) { + if (PerBalancerStoresContains( + load_data_store, + load_data_store.GetAssignedStores(kHostname1, lb_id), kHostname1, + orphaned_lb_id, kLoadKey1)) { + orphaned_lb_id_occurences++; + } + } + EXPECT_EQ(orphaned_lb_id_occurences, 1U); +} + +TEST_F(LoadDataStoreTest, HostTemporarilyLoseAllStreams) { + LoadDataStore load_data_store; + load_data_store.ReportStreamCreated(kHostname1, kLbId1, kLoadKey1); + load_data_store.ReportStreamCreated(kHostname2, kLbId2, kLoadKey1); + auto store_lb_id_1 = load_data_store.FindPerBalancerStore(kHostname1, kLbId1); + auto store_invalid_lb_id_1 = + load_data_store.FindPerBalancerStore(kHostname1, kInvalidLbId); + EXPECT_FALSE(store_lb_id_1->IsSuspended()); + EXPECT_FALSE(store_invalid_lb_id_1->IsSuspended()); + // Disconnect all the streams of the first host. + load_data_store.ReportStreamClosed(kHostname1, kLbId1); + // All the streams of that host are suspended. + EXPECT_TRUE(store_lb_id_1->IsSuspended()); + EXPECT_TRUE(store_invalid_lb_id_1->IsSuspended()); + // Detailed load data won't be kept when the PerBalancerStore is suspended. 
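+  // (As the Suspend test further below also checks, a suspended store still tracks the
+  // in-progress call count when rows are merged; only the detailed per-key records are dropped.)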
+ store_lb_id_1->MergeRow(kKey1, LoadRecordValue()); + store_invalid_lb_id_1->MergeRow(kKey1, LoadRecordValue()); + EXPECT_EQ(store_lb_id_1->load_record_map().size(), 0U); + EXPECT_EQ(store_invalid_lb_id_1->load_record_map().size(), 0U); + // The stores for different hosts won't mix, even if the load key is the same. + auto assigned_to_lb_id_2 = + load_data_store.GetAssignedStores(kHostname2, kLbId2); + EXPECT_EQ(assigned_to_lb_id_2->size(), 2U); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_2, + kHostname2, kLbId2, kLoadKey1)); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_2, + kHostname2, kInvalidLbId, "")); + // A new stream is created for the first host. + load_data_store.ReportStreamCreated(kHostname1, kLbId3, kLoadKey2); + // The stores for the first host are resumed. + EXPECT_FALSE(store_lb_id_1->IsSuspended()); + EXPECT_FALSE(store_invalid_lb_id_1->IsSuspended()); + store_lb_id_1->MergeRow(kKey1, LoadRecordValue()); + store_invalid_lb_id_1->MergeRow(kKey1, LoadRecordValue()); + EXPECT_EQ(store_lb_id_1->load_record_map().size(), 1U); + EXPECT_EQ(store_invalid_lb_id_1->load_record_map().size(), 1U); + // The resumed stores are assigned to the new LB. + auto assigned_to_lb_id_3 = + load_data_store.GetAssignedStores(kHostname1, kLbId3); + EXPECT_EQ(assigned_to_lb_id_3->size(), 3U); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_3, + kHostname1, kLbId1, kLoadKey1)); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_3, + kHostname1, kInvalidLbId, "")); + EXPECT_TRUE(PerBalancerStoresContains(load_data_store, assigned_to_lb_id_3, + kHostname1, kLbId3, kLoadKey2)); +} + +TEST_F(LoadDataStoreTest, OneStorePerLbId) { + LoadDataStore load_data_store; + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname1, kLbId1), nullptr); + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname1, kInvalidLbId), + nullptr); + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname2, kLbId2), nullptr); + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname2, kLbId3), nullptr); + // Create The first stream. + load_data_store.ReportStreamCreated(kHostname1, kLbId1, kLoadKey1); + auto store_lb_id_1 = load_data_store.FindPerBalancerStore(kHostname1, kLbId1); + auto store_invalid_lb_id_1 = + load_data_store.FindPerBalancerStore(kHostname1, kInvalidLbId); + // Two stores will be created: one is for the stream; the other one is for + // kInvalidLbId. + EXPECT_NE(store_lb_id_1, nullptr); + EXPECT_NE(store_invalid_lb_id_1, nullptr); + EXPECT_NE(store_lb_id_1, store_invalid_lb_id_1); + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname2, kLbId2), nullptr); + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname2, kLbId3), nullptr); + // Create the second stream. + load_data_store.ReportStreamCreated(kHostname2, kLbId3, kLoadKey1); + auto store_lb_id_3 = load_data_store.FindPerBalancerStore(kHostname2, kLbId3); + auto store_invalid_lb_id_2 = + load_data_store.FindPerBalancerStore(kHostname2, kInvalidLbId); + EXPECT_NE(store_lb_id_3, nullptr); + EXPECT_NE(store_invalid_lb_id_2, nullptr); + EXPECT_NE(store_lb_id_3, store_invalid_lb_id_2); + // The PerBalancerStores created for different hosts are independent. 
+ EXPECT_NE(store_lb_id_3, store_invalid_lb_id_1); + EXPECT_NE(store_invalid_lb_id_2, store_invalid_lb_id_1); + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname2, kLbId2), nullptr); +} + +TEST_F(LoadDataStoreTest, ExactlyOnceAssignment) { + LoadDataStore load_data_store; + size_t num_create = 100; + size_t num_close = 50; + for (size_t i = 0; i < num_create; ++i) { + load_data_store.ReportStreamCreated(kHostname1, FormatLbId(i), kLoadKey1); + } + for (size_t i = 0; i < num_close; ++i) { + load_data_store.ReportStreamClosed(kHostname1, FormatLbId(i)); + } + std::set reported_lb_ids; + for (size_t i = num_close; i < num_create; ++i) { + for (auto assigned_store : + *load_data_store.GetAssignedStores(kHostname1, FormatLbId(i))) { + EXPECT_TRUE(reported_lb_ids.insert(assigned_store->lb_id()).second); + } + } + // Add one for kInvalidLbId. + EXPECT_EQ(reported_lb_ids.size(), (num_create + 1)); + EXPECT_NE(reported_lb_ids.find(kInvalidLbId), reported_lb_ids.end()); +} + +TEST_F(LoadDataStoreTest, UnknownBalancerIdTracking) { + LoadDataStore load_data_store; + load_data_store.ReportStreamCreated(kHostname1, kLbId1, kLoadKey1); + // Merge data for a known LB ID. + LoadRecordValue v1(192); + load_data_store.MergeRow(kHostname1, kKey1, v1); + // Merge data for unknown LB ID. + LoadRecordValue v2(23); + EXPECT_FALSE(load_data_store.IsTrackedUnknownBalancerId(kLbId2)); + load_data_store.MergeRow( + kHostname1, LoadRecordKey(kLbId2, kLbTag1, kUser1, kClientIp1), v2); + EXPECT_TRUE(load_data_store.IsTrackedUnknownBalancerId(kLbId2)); + LoadRecordValue v3(952); + load_data_store.MergeRow( + kHostname2, LoadRecordKey(kLbId3, kLbTag1, kUser1, kClientIp1), v3); + EXPECT_TRUE(load_data_store.IsTrackedUnknownBalancerId(kLbId3)); + // The data kept for a known LB ID is correct. + auto store_lb_id_1 = load_data_store.FindPerBalancerStore(kHostname1, kLbId1); + EXPECT_EQ(store_lb_id_1->load_record_map().size(), 1U); + EXPECT_EQ(store_lb_id_1->load_record_map().find(kKey1)->second.start_count(), + v1.start_count()); + EXPECT_EQ(store_lb_id_1->GetNumCallsInProgressForReport(), v1.start_count()); + // No PerBalancerStore created for Unknown LB ID. + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname1, kLbId2), nullptr); + EXPECT_EQ(load_data_store.FindPerBalancerStore(kHostname2, kLbId3), nullptr); + // End all the started RPCs for kLbId1. + LoadRecordValue v4(0, v1.start_count()); + load_data_store.MergeRow(kHostname1, kKey1, v4); + EXPECT_EQ(store_lb_id_1->load_record_map().size(), 1U); + EXPECT_EQ(store_lb_id_1->load_record_map().find(kKey1)->second.start_count(), + v1.start_count()); + EXPECT_EQ(store_lb_id_1->load_record_map().find(kKey1)->second.ok_count(), + v4.ok_count()); + EXPECT_EQ(store_lb_id_1->GetNumCallsInProgressForReport(), 0U); + EXPECT_FALSE(load_data_store.IsTrackedUnknownBalancerId(kLbId1)); + // End all the started RPCs for kLbId2. + LoadRecordValue v5(0, v2.start_count()); + load_data_store.MergeRow( + kHostname1, LoadRecordKey(kLbId2, kLbTag1, kUser1, kClientIp1), v5); + EXPECT_FALSE(load_data_store.IsTrackedUnknownBalancerId(kLbId2)); + // End some of the started RPCs for kLbId3. + LoadRecordValue v6(0, v3.start_count() / 2); + load_data_store.MergeRow( + kHostname2, LoadRecordKey(kLbId3, kLbTag1, kUser1, kClientIp1), v6); + EXPECT_TRUE(load_data_store.IsTrackedUnknownBalancerId(kLbId3)); +} + +TEST_F(PerBalancerStoreTest, Suspend) { + PerBalancerStore per_balancer_store(kLbId1, kLoadKey1); + EXPECT_FALSE(per_balancer_store.IsSuspended()); + // Suspend the store. 
+ per_balancer_store.Suspend(); + EXPECT_TRUE(per_balancer_store.IsSuspended()); + EXPECT_EQ(0U, per_balancer_store.load_record_map().size()); + // Data merged when the store is suspended won't be kept. + LoadRecordValue v1(139, 19); + per_balancer_store.MergeRow(kKey1, v1); + EXPECT_EQ(0U, per_balancer_store.load_record_map().size()); + // Resume the store. + per_balancer_store.Resume(); + EXPECT_FALSE(per_balancer_store.IsSuspended()); + EXPECT_EQ(0U, per_balancer_store.load_record_map().size()); + // Data merged after the store is resumed will be kept. + LoadRecordValue v2(23, 0, 51); + per_balancer_store.MergeRow(kKey1, v2); + EXPECT_EQ(1U, per_balancer_store.load_record_map().size()); + // Suspend the store. + per_balancer_store.Suspend(); + EXPECT_TRUE(per_balancer_store.IsSuspended()); + EXPECT_EQ(0U, per_balancer_store.load_record_map().size()); + // Data merged when the store is suspended won't be kept. + LoadRecordValue v3(62, 11); + per_balancer_store.MergeRow(kKey1, v3); + EXPECT_EQ(0U, per_balancer_store.load_record_map().size()); + // Resume the store. + per_balancer_store.Resume(); + EXPECT_FALSE(per_balancer_store.IsSuspended()); + EXPECT_EQ(0U, per_balancer_store.load_record_map().size()); + // Data merged after the store is resumed will be kept. + LoadRecordValue v4(225, 98); + per_balancer_store.MergeRow(kKey1, v4); + EXPECT_EQ(1U, per_balancer_store.load_record_map().size()); + // In-progress count is always kept. + EXPECT_EQ(per_balancer_store.GetNumCallsInProgressForReport(), + v1.start_count() - v1.ok_count() + v2.start_count() - + v2.error_count() + v3.start_count() - v3.ok_count() + + v4.start_count() - v4.ok_count()); +} + +TEST_F(PerBalancerStoreTest, DataAggregation) { + PerBalancerStore per_balancer_store(kLbId1, kLoadKey1); + // Construct some Values. + LoadRecordValue v1(992, 34, 13, 234, 164, 173467); + v1.InsertCallMetric(kMetric1, CallMetricValue(3, 2773.2)); + LoadRecordValue v2(4842, 213, 9, 393, 974, 1345); + v2.InsertCallMetric(kMetric1, CallMetricValue(7, 25.234)); + v2.InsertCallMetric(kMetric2, CallMetricValue(2, 387.08)); + // v3 doesn't change the number of in-progress RPCs. + LoadRecordValue v3(293, 55, 293 - 55, 28764, 5284, 5772); + v3.InsertCallMetric(kMetric1, CallMetricValue(61, 3465.0)); + v3.InsertCallMetric(kMetric2, CallMetricValue(13, 672.0)); + // The initial state of the store. + uint64_t num_calls_in_progress = 0; + EXPECT_FALSE(per_balancer_store.IsNumCallsInProgressChangedSinceLastReport()); + EXPECT_EQ(per_balancer_store.GetNumCallsInProgressForReport(), + num_calls_in_progress); + // Merge v1 and get report of the number of in-progress calls. + per_balancer_store.MergeRow(kKey1, v1); + EXPECT_TRUE(per_balancer_store.IsNumCallsInProgressChangedSinceLastReport()); + EXPECT_EQ(per_balancer_store.GetNumCallsInProgressForReport(), + num_calls_in_progress += + (v1.start_count() - v1.ok_count() - v1.error_count())); + EXPECT_FALSE(per_balancer_store.IsNumCallsInProgressChangedSinceLastReport()); + // Merge v2 and get report of the number of in-progress calls. + per_balancer_store.MergeRow(kKey2, v2); + EXPECT_TRUE(per_balancer_store.IsNumCallsInProgressChangedSinceLastReport()); + EXPECT_EQ(per_balancer_store.GetNumCallsInProgressForReport(), + num_calls_in_progress += + (v2.start_count() - v2.ok_count() - v2.error_count())); + EXPECT_FALSE(per_balancer_store.IsNumCallsInProgressChangedSinceLastReport()); + // Merge v3 and get report of the number of in-progress calls. 
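+  // v3 was built so that ok_count + error_count == start_count, so merging it should leave the
+  // in-progress count unchanged and should not trigger a new report.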
+ per_balancer_store.MergeRow(kKey1, v3); + EXPECT_FALSE(per_balancer_store.IsNumCallsInProgressChangedSinceLastReport()); + EXPECT_EQ(per_balancer_store.GetNumCallsInProgressForReport(), + num_calls_in_progress); + // LoadRecordValue for kKey1 is aggregated correctly. + LoadRecordValue value_for_key1 = + per_balancer_store.load_record_map().find(kKey1)->second; + EXPECT_EQ(value_for_key1.start_count(), v1.start_count() + v3.start_count()); + EXPECT_EQ(value_for_key1.ok_count(), v1.ok_count() + v3.ok_count()); + EXPECT_EQ(value_for_key1.error_count(), v1.error_count() + v3.error_count()); + EXPECT_EQ(value_for_key1.bytes_sent(), v1.bytes_sent() + v3.bytes_sent()); + EXPECT_EQ(value_for_key1.bytes_recv(), v1.bytes_recv() + v3.bytes_recv()); + EXPECT_EQ(value_for_key1.latency_ms(), v1.latency_ms() + v3.latency_ms()); + EXPECT_EQ(value_for_key1.call_metrics().size(), 2U); + EXPECT_EQ(value_for_key1.call_metrics().find(kMetric1)->second.num_calls(), + v1.call_metrics().find(kMetric1)->second.num_calls() + + v3.call_metrics().find(kMetric1)->second.num_calls()); + EXPECT_EQ( + value_for_key1.call_metrics().find(kMetric1)->second.total_metric_value(), + v1.call_metrics().find(kMetric1)->second.total_metric_value() + + v3.call_metrics().find(kMetric1)->second.total_metric_value()); + EXPECT_EQ(value_for_key1.call_metrics().find(kMetric2)->second.num_calls(), + v3.call_metrics().find(kMetric2)->second.num_calls()); + EXPECT_EQ( + value_for_key1.call_metrics().find(kMetric2)->second.total_metric_value(), + v3.call_metrics().find(kMetric2)->second.total_metric_value()); + // LoadRecordValue for kKey2 is aggregated (trivially) correctly. + LoadRecordValue value_for_key2 = + per_balancer_store.load_record_map().find(kKey2)->second; + EXPECT_EQ(value_for_key2.start_count(), v2.start_count()); + EXPECT_EQ(value_for_key2.ok_count(), v2.ok_count()); + EXPECT_EQ(value_for_key2.error_count(), v2.error_count()); + EXPECT_EQ(value_for_key2.bytes_sent(), v2.bytes_sent()); + EXPECT_EQ(value_for_key2.bytes_recv(), v2.bytes_recv()); + EXPECT_EQ(value_for_key2.latency_ms(), v2.latency_ms()); + EXPECT_EQ(value_for_key2.call_metrics().size(), 2U); + EXPECT_EQ(value_for_key2.call_metrics().find(kMetric1)->second.num_calls(), + v2.call_metrics().find(kMetric1)->second.num_calls()); + EXPECT_EQ( + value_for_key2.call_metrics().find(kMetric1)->second.total_metric_value(), + v2.call_metrics().find(kMetric1)->second.total_metric_value()); + EXPECT_EQ(value_for_key2.call_metrics().find(kMetric2)->second.num_calls(), + v2.call_metrics().find(kMetric2)->second.num_calls()); + EXPECT_EQ( + value_for_key2.call_metrics().find(kMetric2)->second.total_metric_value(), + v2.call_metrics().find(kMetric2)->second.total_metric_value()); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/server/load_reporter/load_reporter_test.cc b/test/cpp/server/load_reporter/load_reporter_test.cc new file mode 100644 index 00000000..480935fc --- /dev/null +++ b/test/cpp/server/load_reporter/load_reporter_test.cc @@ -0,0 +1,507 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "src/cpp/server/load_reporter/load_reporter.h" + +#include +#include + +#include +#include + +#include "absl/memory/memory.h" +#include "opencensus/stats/testing/test_utils.h" + +#include + +#include "src/core/ext/filters/load_reporting/registered_opencensus_objects.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/cpp/server/load_reporter/constants.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +using ::grpc::lb::v1::LoadBalancingFeedback; +using ::grpc::load_reporter::CensusViewProvider; +using ::grpc::load_reporter::CpuStatsProvider; +using ::grpc::load_reporter::LoadReporter; +using ::opencensus::stats::ViewDescriptor; +using ::testing::DoubleNear; +using ::testing::Return; + +constexpr uint64_t kFeedbackSampleWindowSeconds = 5; +constexpr uint64_t kFetchAndSampleIntervalSeconds = 1; +constexpr uint64_t kNumFeedbackSamplesInWindow = + kFeedbackSampleWindowSeconds / kFetchAndSampleIntervalSeconds; + +class MockCensusViewProvider : public CensusViewProvider { + public: + MOCK_METHOD0(FetchViewData, CensusViewProvider::ViewDataMap()); + + const ::opencensus::stats::ViewDescriptor& FindViewDescriptor( + const std::string& view_name) { + auto it = view_descriptor_map().find(view_name); + GPR_ASSERT(it != view_descriptor_map().end()); + return it->second; + } +}; + +class MockCpuStatsProvider : public CpuStatsProvider { + public: + MOCK_METHOD0(GetCpuStats, CpuStatsProvider::CpuStatsSample()); +}; + +class LoadReporterTest : public ::testing::Test { + public: + LoadReporterTest() {} + + MockCensusViewProvider* mock_census_view_provider() { + return static_cast( + load_reporter_->census_view_provider()); + } + + void PrepareCpuExpectation(size_t call_num) { + auto mock_cpu_stats_provider = static_cast( + load_reporter_->cpu_stats_provider()); + ::testing::InSequence s; + for (size_t i = 0; i < call_num; ++i) { + EXPECT_CALL(*mock_cpu_stats_provider, GetCpuStats()) + .WillOnce(Return(kCpuStatsSamples[i])) + .RetiresOnSaturation(); + } + } + + CpuStatsProvider::CpuStatsSample initial_cpu_stats_{2, 20}; + const std::vector kCpuStatsSamples = { + {13, 53}, {64, 96}, {245, 345}, {314, 785}, + {874, 1230}, {1236, 2145}, {1864, 2974}}; + + std::unique_ptr load_reporter_; + + const std::string kHostname1 = "kHostname1"; + const std::string kHostname2 = "kHostname2"; + const std::string kHostname3 = "kHostname3"; + // Pad to the length of a valid LB ID. 
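+  // (An LB token is the LB ID concatenated with the LB tag, e.g. kLbToken1 == kLbId1 + kLbTag1,
+  // which is why the IDs below are padded to a fixed length.)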
+ const std::string kLbId1 = "kLbId111"; + const std::string kLbId2 = "kLbId222"; + const std::string kLbId3 = "kLbId333"; + const std::string kLbId4 = "kLbId444"; + const std::string kLoadKey1 = "kLoadKey1"; + const std::string kLoadKey2 = "kLoadKey2"; + const std::string kLoadKey3 = "kLoadKey3"; + const std::string kLbTag1 = "kLbTag1"; + const std::string kLbTag2 = "kLbTag2"; + const std::string kLbToken1 = "kLbId111kLbTag1"; + const std::string kLbToken2 = "kLbId222kLbTag2"; + const std::string kUser1 = "kUser1"; + const std::string kUser2 = "kUser2"; + const std::string kUser3 = "kUser3"; + const std::string kClientIp0 = "00"; + const std::string kClientIp1 = "0800000001"; + const std::string kClientIp2 = "3200000000000000000000000000000002"; + const std::string kMetric1 = "kMetric1"; + const std::string kMetric2 = "kMetric2"; + + private: + void SetUp() override { + // Access the measures to make them valid. + ::grpc::load_reporter::MeasureStartCount(); + ::grpc::load_reporter::MeasureEndCount(); + ::grpc::load_reporter::MeasureEndBytesSent(); + ::grpc::load_reporter::MeasureEndBytesReceived(); + ::grpc::load_reporter::MeasureEndLatencyMs(); + ::grpc::load_reporter::MeasureOtherCallMetric(); + // Set up the load reporter. + auto mock_cpu = new MockCpuStatsProvider(); + auto mock_census = new MockCensusViewProvider(); + // Prepare the initial CPU stats data. Note that the expectation should be + // set up before the load reporter is initialized, because CPU stats is + // sampled at that point. + EXPECT_CALL(*mock_cpu, GetCpuStats()) + .WillOnce(Return(initial_cpu_stats_)) + .RetiresOnSaturation(); + load_reporter_ = absl::make_unique( + kFeedbackSampleWindowSeconds, + std::unique_ptr(mock_census), + std::unique_ptr(mock_cpu)); + } +}; + +class LbFeedbackTest : public LoadReporterTest { + public: + // Note that [start, start + count) of the fake samples (maybe plus the + // initial record) are in the window now. + void VerifyLbFeedback(const LoadBalancingFeedback& lb_feedback, size_t start, + size_t count) { + const CpuStatsProvider::CpuStatsSample* base = + start == 0 ? &initial_cpu_stats_ : &kCpuStatsSamples[start - 1]; + double expected_cpu_util = + static_cast(kCpuStatsSamples[start + count - 1].first - + base->first) / + static_cast(kCpuStatsSamples[start + count - 1].second - + base->second); + ASSERT_THAT(static_cast(lb_feedback.server_utilization()), + DoubleNear(expected_cpu_util, 0.00001)); + double qps_sum = 0, eps_sum = 0; + for (size_t i = 0; i < count; ++i) { + qps_sum += kQpsEpsSamples[start + i].first; + eps_sum += kQpsEpsSamples[start + i].second; + } + double expected_qps = qps_sum / count; + double expected_eps = eps_sum / count; + // TODO(juanlishen): The error is big because we use sleep(). It should be + // much smaller when we use fake clock. 
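+    // The 30% tolerance below reflects the sleep()-based timing noted in the TODO above.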
+ ASSERT_THAT(static_cast(lb_feedback.calls_per_second()), + DoubleNear(expected_qps, expected_qps * 0.3)); + ASSERT_THAT(static_cast(lb_feedback.errors_per_second()), + DoubleNear(expected_eps, expected_eps * 0.3)); + gpr_log(GPR_INFO, + "Verified LB feedback matches the samples of index [%zu, %zu).", + start, start + count); + } + + const std::vector> kQpsEpsSamples = { + {546.1, 153.1}, {62.1, 54.1}, {578.1, 154.2}, {978.1, 645.1}, + {1132.1, 846.4}, {531.5, 315.4}, {874.1, 324.9}}; +}; + +TEST_F(LbFeedbackTest, ZeroDuration) { + PrepareCpuExpectation(kCpuStatsSamples.size()); + EXPECT_CALL(*mock_census_view_provider(), FetchViewData()) + .WillRepeatedly( + Return(::grpc::load_reporter::CensusViewProvider::ViewDataMap())); + // Verify that divide-by-zero exception doesn't happen. + for (size_t i = 0; i < kCpuStatsSamples.size(); ++i) { + load_reporter_->FetchAndSample(); + } + load_reporter_->GenerateLoadBalancingFeedback(); +} + +TEST_F(LbFeedbackTest, Normal) { + // Prepare view data list using the samples. + std::vector view_data_map_list; + for (const auto& p : LbFeedbackTest::kQpsEpsSamples) { + double qps = p.first; + double eps = p.second; + double ok_count = (qps - eps) * kFetchAndSampleIntervalSeconds; + double error_count = eps * kFetchAndSampleIntervalSeconds; + double ok_count_1 = ok_count / 3.0; + double ok_count_2 = ok_count - ok_count_1; + auto end_count_vd = ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndCount), + {{{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + ok_count_1}, + {{kClientIp0 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + ok_count_2}, + {{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + error_count}}); + // Values for other view data don't matter. 
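+    // (The bytes and latency views below still carry the same tag sets as the end-count view so
+    // the reporter can look up matching rows; only their values are irrelevant to the QPS/EPS feedback.)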
+ auto end_bytes_sent_vd = + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndBytesSent), + {{{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 0}, + {{kClientIp0 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 0}, + {{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + 0}}); + auto end_bytes_received_vd = + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndBytesReceived), + {{{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 0}, + {{kClientIp0 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 0}, + {{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + 0}}); + auto end_latency_vd = ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndLatencyMs), + {{{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 0}, + {{kClientIp0 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 0}, + {{kClientIp0 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + 0}}); + view_data_map_list.push_back( + {{::grpc::load_reporter::kViewEndCount, end_count_vd}, + {::grpc::load_reporter::kViewEndBytesSent, end_bytes_sent_vd}, + {::grpc::load_reporter::kViewEndBytesReceived, end_bytes_received_vd}, + {::grpc::load_reporter::kViewEndLatencyMs, end_latency_vd}}); + } + { + ::testing::InSequence s; + for (size_t i = 0; i < view_data_map_list.size(); ++i) { + EXPECT_CALL(*mock_census_view_provider(), FetchViewData()) + .WillOnce(Return(view_data_map_list[i])) + .RetiresOnSaturation(); + } + } + PrepareCpuExpectation(kNumFeedbackSamplesInWindow + 2); + // When the load reporter is created, a trivial LB feedback record is added. + // But that's not enough for generating an LB feedback. + // Fetch some view data so that non-trivial LB feedback can be generated. + for (size_t i = 0; i < kNumFeedbackSamplesInWindow / 2; ++i) { + // TODO(juanlishen): Find some fake clock to speed up testing. + sleep(1); + load_reporter_->FetchAndSample(); + } + VerifyLbFeedback(load_reporter_->GenerateLoadBalancingFeedback(), 0, + kNumFeedbackSamplesInWindow / 2); + // Fetch more view data so that the feedback record window is just full (the + // initial record just falls out of the window). + for (size_t i = 0; i < (kNumFeedbackSamplesInWindow + 1) / 2; ++i) { + sleep(1); + load_reporter_->FetchAndSample(); + } + VerifyLbFeedback(load_reporter_->GenerateLoadBalancingFeedback(), 0, + kNumFeedbackSamplesInWindow); + // Further fetching will cause the old records to fall out of the window. + for (size_t i = 0; i < 2; ++i) { + sleep(1); + load_reporter_->FetchAndSample(); + } + VerifyLbFeedback(load_reporter_->GenerateLoadBalancingFeedback(), 2, + kNumFeedbackSamplesInWindow); +} + +using LoadReportTest = LoadReporterTest; + +TEST_F(LoadReportTest, BasicReport) { + // Make up the first view data map. 
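+  // Each of the two view data maps (vdm1, vdm2) is returned by one FetchAndSample() call through
+  // the mocked CensusViewProvider; see the EXPECT_CALL set up below.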
+ CensusViewProvider::ViewDataMap vdm1; + vdm1.emplace( + ::grpc::load_reporter::kViewStartCount, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewStartCount), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1}, 1234}, + {{kClientIp2 + kLbToken1, kHostname1, kUser1}, 1225}, + {{kClientIp0 + kLbToken1, kHostname1, kUser1}, 10}, + {{kClientIp2 + kLbToken1, kHostname1, kUser2}, 464}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3}, 101}, + {{kClientIp1 + kLbToken2, kHostname2, kUser3}, 17}, + {{kClientIp2 + kLbId3 + kLbTag2, kHostname2, kUser3}, 23}})); + vdm1.emplace(::grpc::load_reporter::kViewEndCount, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndCount), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 641}, + {{kClientIp2 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + 272}, + {{kClientIp2 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 996}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 34}, + {{kClientIp1 + kLbToken2, kHostname2, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 18}})); + vdm1.emplace(::grpc::load_reporter::kViewEndBytesSent, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndBytesSent), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 8977}, + {{kClientIp2 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + 266}, + {{kClientIp2 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 1276}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 77823}, + {{kClientIp1 + kLbToken2, kHostname2, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 48}})); + vdm1.emplace(::grpc::load_reporter::kViewEndBytesReceived, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndBytesReceived), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 2341}, + {{kClientIp2 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + 466}, + {{kClientIp2 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 518}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 81}, + {{kClientIp1 + kLbToken2, kHostname2, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 27}})); + vdm1.emplace(::grpc::load_reporter::kViewEndLatencyMs, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndLatencyMs), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 3.14}, + {{kClientIp2 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusClientError}, + 5.26}, + {{kClientIp2 + kLbToken1, kHostname1, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 45.4}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 4.4}, + {{kClientIp1 + kLbToken2, kHostname2, kUser2, + ::grpc::load_reporter::kCallStatusOk}, + 2348.0}})); + 
vdm1.emplace( + ::grpc::load_reporter::kViewOtherCallMetricCount, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewOtherCallMetricCount), + {{{kClientIp1 + kLbToken1, kHostname1, kUser2, kMetric1}, 1}, + {{kClientIp1 + kLbToken1, kHostname1, kUser2, kMetric1}, 1}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, kMetric2}, + 1}})); + vdm1.emplace( + ::grpc::load_reporter::kViewOtherCallMetricValue, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewOtherCallMetricValue), + {{{kClientIp1 + kLbToken1, kHostname1, kUser2, kMetric1}, 1.2}, + {{kClientIp1 + kLbToken1, kHostname1, kUser2, kMetric1}, 1.2}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, kMetric2}, + 3.2}})); + // Make up the second view data map. + CensusViewProvider::ViewDataMap vdm2; + vdm2.emplace( + ::grpc::load_reporter::kViewStartCount, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewStartCount), + {{{kClientIp2 + kLbToken1, kHostname1, kUser1}, 3}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3}, 778}})); + vdm2.emplace(::grpc::load_reporter::kViewEndCount, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndCount), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 24}, + {{kClientIp1 + kLbToken2, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 546}})); + vdm2.emplace(::grpc::load_reporter::kViewEndBytesSent, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndBytesSent), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 747}, + {{kClientIp1 + kLbToken2, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 229}})); + vdm2.emplace(::grpc::load_reporter::kViewEndBytesReceived, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndBytesReceived), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 173}, + {{kClientIp1 + kLbToken2, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 438}})); + vdm2.emplace(::grpc::load_reporter::kViewEndLatencyMs, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewEndLatencyMs), + {{{kClientIp1 + kLbToken1, kHostname1, kUser1, + ::grpc::load_reporter::kCallStatusOk}, + 187}, + {{kClientIp1 + kLbToken2, kHostname2, kUser3, + ::grpc::load_reporter::kCallStatusClientError}, + 34}})); + vdm2.emplace( + ::grpc::load_reporter::kViewOtherCallMetricCount, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewOtherCallMetricCount), + {{{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, kMetric1}, 1}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, kMetric2}, + 1}})); + vdm2.emplace( + ::grpc::load_reporter::kViewOtherCallMetricValue, + ::opencensus::stats::testing::TestUtils::MakeViewData( + mock_census_view_provider()->FindViewDescriptor( + ::grpc::load_reporter::kViewOtherCallMetricValue), + 
{{{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, kMetric1}, 9.6}, + {{kClientIp1 + kLbId2 + kLbTag1, kHostname2, kUser3, kMetric2}, + 5.7}})); + // Set up mock expectation. + EXPECT_CALL(*mock_census_view_provider(), FetchViewData()) + .WillOnce(Return(vdm1)) + .WillOnce(Return(vdm2)); + PrepareCpuExpectation(2); + // Start testing. + load_reporter_->ReportStreamCreated(kHostname1, kLbId1, kLoadKey1); + load_reporter_->ReportStreamCreated(kHostname2, kLbId2, kLoadKey2); + load_reporter_->ReportStreamCreated(kHostname2, kLbId3, kLoadKey3); + // First fetch. + load_reporter_->FetchAndSample(); + load_reporter_->GenerateLoads(kHostname1, kLbId1); + gpr_log(GPR_INFO, "First load generated."); + // Second fetch. + load_reporter_->FetchAndSample(); + load_reporter_->GenerateLoads(kHostname2, kLbId2); + gpr_log(GPR_INFO, "Second load generated."); + // TODO(juanlishen): Verify the data. +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/server/server_builder_test.cc b/test/cpp/server/server_builder_test.cc new file mode 100644 index 00000000..eb11e2d8 --- /dev/null +++ b/test/cpp/server/server_builder_test.cc @@ -0,0 +1,94 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +testing::EchoTestService::Service g_service; + +std::string MakePort() { + std::ostringstream s; + int p = grpc_pick_unused_port_or_die(); + s << "localhost:" << p; + return s.str(); +} + +const std::string& GetPort() { + static std::string g_port = MakePort(); + return g_port; +} + +class ServerBuilderTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } +}; +TEST_F(ServerBuilderTest, NoOp) { ServerBuilder b; } + +TEST_F(ServerBuilderTest, CreateServerNoPorts) { + ServerBuilder().RegisterService(&g_service).BuildAndStart()->Shutdown(); +} + +TEST_F(ServerBuilderTest, CreateServerOnePort) { + ServerBuilder() + .RegisterService(&g_service) + .AddListeningPort(GetPort(), InsecureServerCredentials()) + .BuildAndStart() + ->Shutdown(); +} + +TEST_F(ServerBuilderTest, CreateServerRepeatedPort) { + ServerBuilder() + .RegisterService(&g_service) + .AddListeningPort(GetPort(), InsecureServerCredentials()) + .AddListeningPort(GetPort(), InsecureServerCredentials()) + .BuildAndStart() + ->Shutdown(); +} + +TEST_F(ServerBuilderTest, CreateServerRepeatedPortWithDisallowedReusePort) { + EXPECT_EQ(ServerBuilder() + .RegisterService(&g_service) + .AddListeningPort(GetPort(), InsecureServerCredentials()) + .AddListeningPort(GetPort(), InsecureServerCredentials()) + .AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0) + .BuildAndStart(), + nullptr); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/server/server_builder_with_socket_mutator_test.cc b/test/cpp/server/server_builder_with_socket_mutator_test.cc new file mode 100644 index 00000000..8ed5372d --- /dev/null +++ b/test/cpp/server/server_builder_with_socket_mutator_test.cc @@ -0,0 +1,125 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include +#include +#include + +#include "src/core/lib/iomgr/socket_mutator.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +/* This test does a sanity check that grpc_socket_mutator's + * are used by servers. It's meant to protect code and end-to-end + * tests that rely on this functionality but which live outside + * of the grpc github repo. 
*/ + +namespace grpc { +namespace { + +bool mock_socket_mutator_mutate_fd(int, grpc_socket_mutator*); +int mock_socket_mutator_compare(grpc_socket_mutator*, grpc_socket_mutator*); +void mock_socket_mutator_destroy(grpc_socket_mutator*); + +const grpc_socket_mutator_vtable mock_socket_mutator_vtable = { + mock_socket_mutator_mutate_fd, + mock_socket_mutator_compare, + mock_socket_mutator_destroy, + nullptr, +}; + +class MockSocketMutator : public grpc_socket_mutator { + public: + MockSocketMutator() : mutate_fd_call_count_(0) { + grpc_socket_mutator_init(this, &mock_socket_mutator_vtable); + } + int mutate_fd_call_count_; +}; + +bool mock_socket_mutator_mutate_fd(int /*fd*/, grpc_socket_mutator* m) { + MockSocketMutator* s = reinterpret_cast(m); + s->mutate_fd_call_count_++; + return true; +} + +int mock_socket_mutator_compare(grpc_socket_mutator* a, + grpc_socket_mutator* b) { + return reinterpret_cast(a) - reinterpret_cast(b); +} + +void mock_socket_mutator_destroy(grpc_socket_mutator* m) { + MockSocketMutator* s = reinterpret_cast(m); + delete s; +} + +class MockSocketMutatorServerBuilderOption : public grpc::ServerBuilderOption { + public: + explicit MockSocketMutatorServerBuilderOption( + MockSocketMutator* mock_socket_mutator) + : mock_socket_mutator_(mock_socket_mutator) {} + + void UpdateArguments(ChannelArguments* args) override { + args->SetSocketMutator(mock_socket_mutator_); + } + + void UpdatePlugins( + std::vector>*) override{}; + + MockSocketMutator* mock_socket_mutator_; +}; + +class ServerBuilderWithSocketMutatorTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } +}; + +TEST_F(ServerBuilderWithSocketMutatorTest, CreateServerWithSocketMutator) { + auto address = "localhost:" + std::to_string(grpc_pick_unused_port_or_die()); + auto mock_socket_mutator = new MockSocketMutator(); + std::unique_ptr mock_socket_mutator_builder_option( + new MockSocketMutatorServerBuilderOption(mock_socket_mutator)); + testing::EchoTestService::Service echo_service; + EXPECT_EQ(mock_socket_mutator->mutate_fd_call_count_, 0); + ServerBuilder builder; + builder.RegisterService(&echo_service); + builder.AddListeningPort(address, InsecureServerCredentials()); + builder.SetOption(std::move(mock_socket_mutator_builder_option)); + std::unique_ptr server(builder.BuildAndStart()); + EXPECT_NE(server, nullptr); + // Only assert that the socket mutator was used. + EXPECT_GE(mock_socket_mutator->mutate_fd_call_count_, 1); + server->Shutdown(); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/server/server_request_call_test.cc b/test/cpp/server/server_request_call_test.cc new file mode 100644 index 00000000..a41d86c3 --- /dev/null +++ b/test/cpp/server/server_request_call_test.cc @@ -0,0 +1,163 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +TEST(ServerRequestCallTest, ShortDeadlineDoesNotCauseOkayFalse) { + std::mutex mu; + bool shutting_down = false; + + // grpc server config. + std::ostringstream s; + int p = grpc_pick_unused_port_or_die(); + s << "[::1]:" << p; + const string address = s.str(); + testing::EchoTestService::AsyncService service; + ServerBuilder builder; + builder.AddListeningPort(address, InsecureServerCredentials()); + auto cq = builder.AddCompletionQueue(); + builder.RegisterService(&service); + auto server = builder.BuildAndStart(); + + // server thread. + std::thread t([address, &service, &cq, &mu, &shutting_down] { + for (int n = 0; true; n++) { + ServerContext ctx; + testing::EchoRequest req; + ServerAsyncResponseWriter responder(&ctx); + + // if shutting down, don't enqueue a new request. + { + std::lock_guard lock(mu); + if (!shutting_down) { + service.RequestEcho(&ctx, &req, &responder, cq.get(), cq.get(), + reinterpret_cast(1)); + } + } + + bool ok; + void* tag; + if (!cq->Next(&tag, &ok)) { + break; + } + + EXPECT_EQ((void*)1, tag); + // If not shutting down, ok must be true for new requests. + { + std::lock_guard lock(mu); + if (!shutting_down && !ok) { + gpr_log(GPR_INFO, "!ok on request %d", n); + abort(); + } + if (shutting_down && !ok) { + // Failed connection due to shutdown, continue flushing the CQ. + continue; + } + } + + // Send a simple response after a small delay that would ensure the client + // deadline is exceeded. + gpr_log(GPR_INFO, "Got request %d", n); + testing::EchoResponse response; + response.set_message("foobar"); + // A bit of sleep to make sure the deadline elapses. + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(50, GPR_TIMESPAN))); + { + std::lock_guard lock(mu); + if (shutting_down) { + gpr_log(GPR_INFO, + "shut down while processing call, not calling Finish()"); + // Continue flushing the CQ. + continue; + } + gpr_log(GPR_INFO, "Finishing request %d", n); + responder.Finish(response, grpc::Status::OK, + reinterpret_cast(2)); + if (!cq->Next(&tag, &ok)) { + break; + } + EXPECT_EQ((void*)2, tag); + } + } + }); + + auto stub = testing::EchoTestService::NewStub( + grpc::CreateChannel(address, InsecureChannelCredentials())); + + for (int i = 0; i < 100; i++) { + gpr_log(GPR_INFO, "Sending %d.", i); + testing::EchoRequest request; + + ///////// + // Comment out the following line to get ok=false due to invalid request. + // Otherwise, ok=false due to deadline being exceeded. + ///////// + request.set_message("foobar"); + + // A simple request with a short deadline. The server will always exceed the + // deadline, whether due to the sleep or because the server was unable to + // even fetch the request from the CQ before the deadline elapsed. + testing::EchoResponse response; + ::grpc::ClientContext ctx; + ctx.set_fail_fast(false); + ctx.set_deadline(std::chrono::system_clock::now() + + std::chrono::milliseconds(1)); + grpc::Status status = stub->Echo(&ctx, request, &response); + EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, status.error_code()); + gpr_log(GPR_INFO, "Success."); + } + gpr_log(GPR_INFO, "Done sending RPCs."); + + // Shut down everything properly. 
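+  // Flip shutting_down under the lock first so the server thread stops enqueueing requests and
+  // calling Finish(), then shut down the server and the completion queue.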
+ gpr_log(GPR_INFO, "Shutting down."); + { + std::lock_guard lock(mu); + shutting_down = true; + } + server->Shutdown(); + cq->Shutdown(); + server->Wait(); + + t.join(); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/test/client_context_test_peer_test.cc b/test/cpp/test/client_context_test_peer_test.cc new file mode 100644 index 00000000..67ab62d6 --- /dev/null +++ b/test/cpp/test/client_context_test_peer_test.cc @@ -0,0 +1,81 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include + +#include +#include + +namespace grpc { +namespace testing { + +static internal::GrpcLibraryInitializer g_initializer; + +const char key1[] = "metadata-key1"; +const char key2[] = "metadata-key2"; +const char val1[] = "metadata-val1"; +const char val2[] = "metadata-val2"; + +bool ServerInitialMetadataContains(const ClientContext& context, + const grpc::string_ref& key, + const grpc::string_ref& value) { + const auto& server_metadata = context.GetServerInitialMetadata(); + for (auto iter = server_metadata.begin(); iter != server_metadata.end(); + ++iter) { + if (iter->first == key && iter->second == value) { + return true; + } + } + return true; +} + +TEST(ClientContextTestPeerTest, AddServerInitialMetadata) { + ClientContext context; + ClientContextTestPeer peer(&context); + + peer.AddServerInitialMetadata(key1, val1); + ASSERT_TRUE(ServerInitialMetadataContains(context, key1, val1)); + peer.AddServerInitialMetadata(key2, val2); + ASSERT_TRUE(ServerInitialMetadataContains(context, key1, val1)); + ASSERT_TRUE(ServerInitialMetadataContains(context, key2, val2)); +} + +TEST(ClientContextTestPeerTest, GetSendInitialMetadata) { + ClientContext context; + ClientContextTestPeer peer(&context); + std::multimap metadata; + + context.AddMetadata(key1, val1); + metadata.insert(std::pair(key1, val1)); + ASSERT_EQ(metadata, peer.GetSendInitialMetadata()); + + context.AddMetadata(key2, val2); + metadata.insert(std::pair(key2, val2)); + ASSERT_EQ(metadata, peer.GetSendInitialMetadata()); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/test/mock_stream_test.cc b/test/cpp/test/mock_stream_test.cc new file mode 100644 index 00000000..646596b1 --- /dev/null +++ b/test/cpp/test/mock_stream_test.cc @@ -0,0 +1,71 @@ +/* + * + * Copyright 2020 the gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "absl/memory/memory.h" + +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +TEST(MockStreamTest, Basic) { + auto cr = absl::make_unique>(); + ASSERT_NE(cr, nullptr); + + auto cw = absl::make_unique>(); + ASSERT_NE(cw, nullptr); + + auto crw = absl::make_unique< + grpc::testing::MockClientReaderWriter>(); + ASSERT_NE(crw, nullptr); + + auto carr = absl::make_unique< + grpc::testing::MockClientAsyncResponseReader>(); + ASSERT_NE(carr, nullptr); + + auto car = + absl::make_unique>(); + ASSERT_NE(car, nullptr); + + auto caw = + absl::make_unique>(); + ASSERT_NE(caw, nullptr); + + auto carw = absl::make_unique< + grpc::testing::MockClientAsyncReaderWriter>(); + ASSERT_NE(carw, nullptr); + + auto sr = absl::make_unique>(); + ASSERT_NE(sr, nullptr); + + auto sw = absl::make_unique>(); + ASSERT_NE(sw, nullptr); + + auto srw = absl::make_unique< + grpc::testing::MockServerReaderWriter>(); + ASSERT_NE(srw, nullptr); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/test/server_context_test_spouse_test.cc b/test/cpp/test/server_context_test_spouse_test.cc new file mode 100644 index 00000000..42e9e89d --- /dev/null +++ b/test/cpp/test/server_context_test_spouse_test.cc @@ -0,0 +1,96 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include +#include + +namespace grpc { +namespace testing { + +static internal::GrpcLibraryInitializer g_initializer; + +const char key1[] = "metadata-key1"; +const char key2[] = "metadata-key2"; +const char val1[] = "metadata-val1"; +const char val2[] = "metadata-val2"; + +bool ClientMetadataContains(const ServerContext& context, + const grpc::string_ref& key, + const grpc::string_ref& value) { + const auto& client_metadata = context.client_metadata(); + for (auto iter = client_metadata.begin(); iter != client_metadata.end(); + ++iter) { + if (iter->first == key && iter->second == value) { + return true; + } + } + return false; +} + +TEST(ServerContextTestSpouseTest, ClientMetadata) { + ServerContext context; + ServerContextTestSpouse spouse(&context); + + spouse.AddClientMetadata(key1, val1); + ASSERT_TRUE(ClientMetadataContains(context, key1, val1)); + + spouse.AddClientMetadata(key2, val2); + ASSERT_TRUE(ClientMetadataContains(context, key1, val1)); + ASSERT_TRUE(ClientMetadataContains(context, key2, val2)); +} + +TEST(ServerContextTestSpouseTest, InitialMetadata) { + ServerContext context; + ServerContextTestSpouse spouse(&context); + std::multimap metadata; + + context.AddInitialMetadata(key1, val1); + metadata.insert(std::pair(key1, val1)); + ASSERT_EQ(metadata, spouse.GetInitialMetadata()); + + context.AddInitialMetadata(key2, val2); + metadata.insert(std::pair(key2, val2)); + ASSERT_EQ(metadata, spouse.GetInitialMetadata()); +} + +TEST(ServerContextTestSpouseTest, TrailingMetadata) { + ServerContext context; + ServerContextTestSpouse spouse(&context); + std::multimap metadata; + + context.AddTrailingMetadata(key1, val1); + metadata.insert(std::pair(key1, val1)); + ASSERT_EQ(metadata, spouse.GetTrailingMetadata()); + + context.AddTrailingMetadata(key2, val2); + metadata.insert(std::pair(key2, val2)); + ASSERT_EQ(metadata, spouse.GetTrailingMetadata()); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/thread_manager/thread_manager_test.cc b/test/cpp/thread_manager/thread_manager_test.cc new file mode 100644 index 00000000..9d8db74d --- /dev/null +++ b/test/cpp/thread_manager/thread_manager_test.cc @@ -0,0 +1,194 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *is % allowed in string + */ + +#include + +#include "src/cpp/thread_manager/thread_manager.h" + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +struct TestThreadManagerSettings { + // The min number of pollers that SHOULD be active in ThreadManager + int min_pollers; + + // The max number of pollers that could be active in ThreadManager + int max_pollers; + + // The sleep duration in PollForWork() function to simulate "polling" + int poll_duration_ms; + + // The sleep duration in DoWork() function to simulate "work" + int work_duration_ms; + + // Max number of times PollForWork() is called before shutting down + int max_poll_calls; + + // The thread limit (for use in resource quote) + int thread_limit; + + // How many should be instantiated + int thread_manager_count; +}; + +class TestThreadManager final : public grpc::ThreadManager { + public: + TestThreadManager(const char* name, grpc_resource_quota* rq, + const TestThreadManagerSettings& settings) + : ThreadManager(name, rq, settings.min_pollers, settings.max_pollers), + settings_(settings), + num_do_work_(0), + num_poll_for_work_(0), + num_work_found_(0) {} + + grpc::ThreadManager::WorkStatus PollForWork(void** tag, bool* ok) override; + void DoWork(void* /* tag */, bool /*ok*/, bool /*resources*/) override { + num_do_work_.fetch_add(1, std::memory_order_relaxed); + + // Simulate work by sleeping + std::this_thread::sleep_for( + std::chrono::milliseconds(settings_.work_duration_ms)); + } + + // Get number of times PollForWork() was called + int num_poll_for_work() const { + return num_poll_for_work_.load(std::memory_order_relaxed); + } + // Get number of times PollForWork() returned WORK_FOUND + int num_work_found() const { + return num_work_found_.load(std::memory_order_relaxed); + } + // Get number of times DoWork() was called + int num_do_work() const { + return num_do_work_.load(std::memory_order_relaxed); + } + + private: + TestThreadManagerSettings settings_; + + // Counters + std::atomic_int num_do_work_; // Number of calls to DoWork + std::atomic_int num_poll_for_work_; // Number of calls to PollForWork + std::atomic_int num_work_found_; // Number of times WORK_FOUND was returned +}; + +grpc::ThreadManager::WorkStatus TestThreadManager::PollForWork(void** tag, + bool* ok) { + int call_num = num_poll_for_work_.fetch_add(1, std::memory_order_relaxed); + if (call_num >= settings_.max_poll_calls) { + Shutdown(); + return SHUTDOWN; + } + + // Simulate "polling" duration + std::this_thread::sleep_for( + std::chrono::milliseconds(settings_.poll_duration_ms)); + *tag = nullptr; + *ok = true; + + // Return timeout roughly 1 out of every 3 calls just to make the test a bit + // more interesting + if (call_num % 3 == 0) { + return TIMEOUT; + } + + num_work_found_.fetch_add(1, std::memory_order_relaxed); + return WORK_FOUND; +} + +class ThreadManagerTest + : public ::testing::TestWithParam { + protected: + void SetUp() override { + grpc_resource_quota* rq = grpc_resource_quota_create("Thread manager test"); + if (GetParam().thread_limit > 0) { + grpc_resource_quota_set_max_threads(rq, GetParam().thread_limit); + } + for (int i = 0; i < GetParam().thread_manager_count; i++) { + thread_manager_.emplace_back( + new TestThreadManager("TestThreadManager", rq, GetParam())); + } + grpc_resource_quota_unref(rq); + for (auto& tm : thread_manager_) { + tm->Initialize(); + } + for (auto& tm : thread_manager_) { + tm->Wait(); + } + } + 
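+  // The thread managers under test; by the end of SetUp() each has been Initialize()d and
+  // Wait()ed on, so the counters the tests read below are final.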
+ std::vector> thread_manager_; +}; + +TestThreadManagerSettings scenarios[] = { + {2 /* min_pollers */, 10 /* max_pollers */, 10 /* poll_duration_ms */, + 1 /* work_duration_ms */, 50 /* max_poll_calls */, + INT_MAX /* thread_limit */, 1 /* thread_manager_count */}, + {1 /* min_pollers */, 1 /* max_pollers */, 1 /* poll_duration_ms */, + 10 /* work_duration_ms */, 50 /* max_poll_calls */, 3 /* thread_limit */, + 2 /* thread_manager_count */}}; + +INSTANTIATE_TEST_SUITE_P(ThreadManagerTest, ThreadManagerTest, + ::testing::ValuesIn(scenarios)); + +TEST_P(ThreadManagerTest, TestPollAndWork) { + for (auto& tm : thread_manager_) { + // Verify that The number of times DoWork() was called is equal to the + // number of times WORK_FOUND was returned + gpr_log(GPR_DEBUG, "DoWork() called %d times", tm->num_do_work()); + EXPECT_GE(tm->num_poll_for_work(), GetParam().max_poll_calls); + EXPECT_EQ(tm->num_do_work(), tm->num_work_found()); + } +} + +TEST_P(ThreadManagerTest, TestThreadQuota) { + if (GetParam().thread_limit > 0) { + for (auto& tm : thread_manager_) { + EXPECT_GE(tm->num_poll_for_work(), GetParam().max_poll_calls); + EXPECT_LE(tm->GetMaxActiveThreadsSoFar(), GetParam().thread_limit); + } + } +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + std::srand(std::time(nullptr)); + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + grpc_init(); + auto ret = RUN_ALL_TESTS(); + grpc_shutdown(); + + return ret; +} diff --git a/test/cpp/util/byte_buffer_proto_helper.cc b/test/cpp/util/byte_buffer_proto_helper.cc new file mode 100644 index 00000000..efccb863 --- /dev/null +++ b/test/cpp/util/byte_buffer_proto_helper.cc @@ -0,0 +1,59 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include "absl/memory/memory.h" + +namespace grpc { +namespace testing { + +bool ParseFromByteBuffer(ByteBuffer* buffer, grpc::protobuf::Message* message) { + std::vector slices; + (void)buffer->Dump(&slices); + std::string buf; + buf.reserve(buffer->Length()); + for (auto s = slices.begin(); s != slices.end(); s++) { + buf.append(reinterpret_cast(s->begin()), s->size()); + } + return message->ParseFromString(buf); +} + +std::unique_ptr SerializeToByteBuffer( + grpc::protobuf::Message* message) { + std::string buf; + message->SerializeToString(&buf); + Slice slice(buf); + return absl::make_unique(&slice, 1); +} + +bool SerializeToByteBufferInPlace(grpc::protobuf::Message* message, + ByteBuffer* buffer) { + std::string buf; + if (!message->SerializeToString(&buf)) { + return false; + } + buffer->Clear(); + Slice slice(buf); + ByteBuffer tmp(&slice, 1); + buffer->Swap(&tmp); + return true; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/byte_buffer_test.cc b/test/cpp/util/byte_buffer_test.cc new file mode 100644 index 00000000..ab18b5ec --- /dev/null +++ b/test/cpp/util/byte_buffer_test.cc @@ -0,0 +1,163 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { + +static internal::GrpcLibraryInitializer g_gli_initializer; + +namespace { + +const char* kContent1 = "hello xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; +const char* kContent2 = "yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy world"; + +class ByteBufferTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } +}; + +TEST_F(ByteBufferTest, CopyCtor) { + ByteBuffer buffer1; + EXPECT_FALSE(buffer1.Valid()); + const ByteBuffer& buffer2 = buffer1; + EXPECT_FALSE(buffer2.Valid()); +} + +TEST_F(ByteBufferTest, CreateFromSingleSlice) { + Slice s(kContent1); + ByteBuffer buffer(&s, 1); + EXPECT_EQ(strlen(kContent1), buffer.Length()); +} + +TEST_F(ByteBufferTest, CreateFromVector) { + std::vector slices; + slices.emplace_back(kContent1); + slices.emplace_back(kContent2); + ByteBuffer buffer(&slices[0], 2); + EXPECT_EQ(strlen(kContent1) + strlen(kContent2), buffer.Length()); +} + +TEST_F(ByteBufferTest, Clear) { + Slice s(kContent1); + ByteBuffer buffer(&s, 1); + buffer.Clear(); + EXPECT_EQ(static_cast(0), buffer.Length()); +} + +TEST_F(ByteBufferTest, Length) { + std::vector slices; + slices.emplace_back(kContent1); + slices.emplace_back(kContent2); + ByteBuffer buffer(&slices[0], 2); + EXPECT_EQ(strlen(kContent1) + strlen(kContent2), buffer.Length()); +} + +bool SliceEqual(const Slice& a, grpc_slice b) { + if (a.size() != GRPC_SLICE_LENGTH(b)) { + return false; + } + for (size_t i = 0; i < a.size(); i++) { + if (a.begin()[i] != GRPC_SLICE_START_PTR(b)[i]) { + return false; + } + } + return true; +} + +TEST_F(ByteBufferTest, Dump) { + grpc_slice hello = grpc_slice_from_copied_string(kContent1); + grpc_slice world = grpc_slice_from_copied_string(kContent2); + std::vector slices; + slices.push_back(Slice(hello, Slice::STEAL_REF)); + slices.push_back(Slice(world, Slice::STEAL_REF)); + ByteBuffer buffer(&slices[0], 2); + slices.clear(); + (void)buffer.Dump(&slices); + EXPECT_TRUE(SliceEqual(slices[0], hello)); + EXPECT_TRUE(SliceEqual(slices[1], world)); +} + +TEST_F(ByteBufferTest, SerializationMakesCopy) { + grpc_slice hello = grpc_slice_from_copied_string(kContent1); + grpc_slice world = grpc_slice_from_copied_string(kContent2); + std::vector slices; + slices.push_back(Slice(hello, Slice::STEAL_REF)); + slices.push_back(Slice(world, Slice::STEAL_REF)); + ByteBuffer send_buffer; + bool owned = false; + ByteBuffer buffer(&slices[0], 2); + slices.clear(); + auto status = SerializationTraits::Serialize( + buffer, &send_buffer, &owned); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(owned); + EXPECT_TRUE(send_buffer.Valid()); +} + +TEST_F(ByteBufferTest, TrySingleSliceWithSingleSlice) { + std::vector slices; + slices.emplace_back(kContent1); + ByteBuffer buffer(&slices[0], 1); + Slice slice; + EXPECT_TRUE(buffer.TrySingleSlice(&slice).ok()); + EXPECT_EQ(slice.size(), slices[0].size()); + EXPECT_EQ(memcmp(slice.begin(), slices[0].begin(), slice.size()), 0); +} + +TEST_F(ByteBufferTest, TrySingleSliceWithMultipleSlices) { + std::vector slices; + slices.emplace_back(kContent1); + slices.emplace_back(kContent2); + ByteBuffer buffer(&slices[0], 2); + Slice slice; + EXPECT_FALSE(buffer.TrySingleSlice(&slice).ok()); +} + +TEST_F(ByteBufferTest, DumpToSingleSlice) { + std::vector slices; + slices.emplace_back(kContent1); + slices.emplace_back(kContent2); + ByteBuffer buffer(&slices[0], 2); + 
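+  // Unlike TrySingleSlice above, DumpToSingleSlice is expected to succeed on
+  // a multi-slice buffer: it coalesces both content slices into one
+  // contiguous slice whose size is the sum of the parts, as checked below.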
Slice slice; + EXPECT_TRUE(buffer.DumpToSingleSlice(&slice).ok()); + EXPECT_EQ(strlen(kContent1) + strlen(kContent2), slice.size()); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/util/channel_trace_proto_helper.cc b/test/cpp/util/channel_trace_proto_helper.cc new file mode 100644 index 00000000..770c9b00 --- /dev/null +++ b/test/cpp/util/channel_trace_proto_helper.cc @@ -0,0 +1,116 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "test/cpp/util/channel_trace_proto_helper.h" + +#include + +#include +#include +#include +#include + +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/json/json.h" +#include "src/proto/grpc/channelz/channelz.pb.h" + +namespace grpc { + +namespace { + +// Generic helper that takes in a json string, converts it to a proto, and +// then back to json. This ensures that the json string was correctly formatted +// according to https://developers.google.com/protocol-buffers/docs/proto3#json +template +void VaidateProtoJsonTranslation(const std::string& json_str) { + Message msg; + grpc::protobuf::json::JsonParseOptions parse_options; + // If the following line is failing, then uncomment the last line of the + // comment, and uncomment the lines that print the two strings. You can + // then compare the output, and determine what fields are missing. + // + // parse_options.ignore_unknown_fields = true; + grpc::protobuf::util::Status s = + grpc::protobuf::json::JsonStringToMessage(json_str, &msg, parse_options); + EXPECT_TRUE(s.ok()); + std::string proto_json_str; + grpc::protobuf::json::JsonPrintOptions print_options; + // We usually do not want this to be true, however it can be helpful to + // uncomment and see the output produced then all fields are printed. + // print_options.always_print_primitive_fields = true; + s = grpc::protobuf::json::MessageToJsonString(msg, &proto_json_str); + EXPECT_TRUE(s.ok()); + // Parse JSON and re-dump to string, to make sure formatting is the + // same as what would be generated by our JSON library. + grpc_error_handle error = GRPC_ERROR_NONE; + grpc_core::Json parsed_json = + grpc_core::Json::Parse(proto_json_str.c_str(), &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_std_string(error); + ASSERT_EQ(parsed_json.type(), grpc_core::Json::Type::OBJECT); + proto_json_str = parsed_json.Dump(); + // uncomment these to compare the json strings. 
+ // gpr_log(GPR_ERROR, "tracer json: %s", json_str.c_str()); + // gpr_log(GPR_ERROR, "proto json: %s", proto_json_str.c_str()); + EXPECT_EQ(json_str, proto_json_str); +} + +} // namespace + +namespace testing { + +void ValidateChannelTraceProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation(json_c_str); +} + +void ValidateChannelProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation(json_c_str); +} + +void ValidateGetTopChannelsResponseProtoJsonTranslation( + const char* json_c_str) { + VaidateProtoJsonTranslation( + json_c_str); +} + +void ValidateGetChannelResponseProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation( + json_c_str); +} + +void ValidateGetServerResponseProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation( + json_c_str); +} + +void ValidateSubchannelProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation(json_c_str); +} + +void ValidateServerProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation(json_c_str); +} + +void ValidateGetServersResponseProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation( + json_c_str); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/channelz_sampler.cc b/test/cpp/util/channelz_sampler.cc new file mode 100644 index 00000000..a63e6b71 --- /dev/null +++ b/test/cpp/util/channelz_sampler.cc @@ -0,0 +1,593 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "google/protobuf/text_format.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/json/json.h" +#include "src/cpp/server/channelz/channelz_service.h" +#include "src/proto/grpc/channelz/channelz.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_FLAG(std::string, server_address, "", "channelz server address"); +ABSL_FLAG(std::string, custom_credentials_type, "", "custom credentials type"); +ABSL_FLAG(int64_t, sampling_times, 1, "number of sampling"); +// TODO(Capstan): Consider using absl::Duration +ABSL_FLAG(int64_t, sampling_interval_seconds, 0, + "sampling interval in seconds"); +ABSL_FLAG(std::string, output_json, "", "output filename in json format"); + +namespace { +using grpc::ClientContext; +using grpc::Status; +using grpc::StatusCode; +using grpc::channelz::v1::GetChannelRequest; +using grpc::channelz::v1::GetChannelResponse; +using grpc::channelz::v1::GetServersRequest; +using grpc::channelz::v1::GetServersResponse; +using grpc::channelz::v1::GetSocketRequest; +using grpc::channelz::v1::GetSocketResponse; +using grpc::channelz::v1::GetSubchannelRequest; +using grpc::channelz::v1::GetSubchannelResponse; +using grpc::channelz::v1::GetTopChannelsRequest; +using grpc::channelz::v1::GetTopChannelsResponse; +} // namespace + +class ChannelzSampler final { + public: + // Get server_id of a server + int64_t GetServerID(const grpc::channelz::v1::Server& server) { + return server.ref().server_id(); + } + + // Get channel_id of a channel + inline int64_t GetChannelID(const grpc::channelz::v1::Channel& channel) { + return channel.ref().channel_id(); + } + + // Get subchannel_id of a subchannel + inline int64_t GetSubchannelID( + const grpc::channelz::v1::Subchannel& subchannel) { + return subchannel.ref().subchannel_id(); + } + + // Get socket_id of a socket + inline int64_t GetSocketID(const grpc::channelz::v1::Socket& socket) { + return socket.ref().socket_id(); + } + + // Get name of a server + inline std::string GetServerName(const grpc::channelz::v1::Server& server) { + return server.ref().name(); + } + + // Get name of a channel + inline std::string GetChannelName( + const grpc::channelz::v1::Channel& channel) { + return channel.ref().name(); + } + + // Get name of a subchannel + inline std::string GetSubchannelName( + const grpc::channelz::v1::Subchannel& subchannel) { + return subchannel.ref().name(); + } + + // Get name of a socket + inline std::string GetSocketName(const grpc::channelz::v1::Socket& socket) { + return socket.ref().name(); + } + + // Get a channel based on channel_id + grpc::channelz::v1::Channel GetChannelRPC(int64_t channel_id) { + GetChannelRequest get_channel_request; + get_channel_request.set_channel_id(channel_id); + GetChannelResponse get_channel_response; + ClientContext get_channel_context; + get_channel_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + Status status = channelz_stub_->GetChannel( + &get_channel_context, get_channel_request, &get_channel_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "GetChannelRPC failed: %s", + get_channel_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + return 
get_channel_response.channel(); + } + + // Get a subchannel based on subchannel_id + grpc::channelz::v1::Subchannel GetSubchannelRPC(int64_t subchannel_id) { + GetSubchannelRequest get_subchannel_request; + get_subchannel_request.set_subchannel_id(subchannel_id); + GetSubchannelResponse get_subchannel_response; + ClientContext get_subchannel_context; + get_subchannel_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + Status status = channelz_stub_->GetSubchannel(&get_subchannel_context, + get_subchannel_request, + &get_subchannel_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "GetSubchannelRPC failed: %s", + get_subchannel_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + return get_subchannel_response.subchannel(); + } + + // get a socket based on socket_id + grpc::channelz::v1::Socket GetSocketRPC(int64_t socket_id) { + GetSocketRequest get_socket_request; + get_socket_request.set_socket_id(socket_id); + GetSocketResponse get_socket_response; + ClientContext get_socket_context; + get_socket_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + Status status = channelz_stub_->GetSocket( + &get_socket_context, get_socket_request, &get_socket_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "GetSocketRPC failed: %s", + get_socket_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + return get_socket_response.socket(); + } + + // get the descedent channels/subchannels/sockets of a channel + // push descedent channels/subchannels to queue for layer traverse + // store descedent channels/subchannels/sockets for dumping data + void GetChannelDescedence( + const grpc::channelz::v1::Channel& channel, + std::queue& channel_queue, + std::queue& subchannel_queue) { + std::cout << " Channel ID" << GetChannelID(channel) << "_" + << GetChannelName(channel) << " descendence - "; + if (channel.channel_ref_size() > 0 || channel.subchannel_ref_size() > 0) { + if (channel.channel_ref_size() > 0) { + std::cout << "channel: "; + for (const auto& _channelref : channel.channel_ref()) { + int64_t ch_id = _channelref.channel_id(); + std::cout << "ID" << ch_id << "_" << _channelref.name() << " "; + grpc::channelz::v1::Channel ch = GetChannelRPC(ch_id); + channel_queue.push(ch); + if (CheckID(ch_id)) { + all_channels_.push_back(ch); + StoreChannelInJson(ch); + } + } + if (channel.subchannel_ref_size() > 0) { + std::cout << ", "; + } + } + if (channel.subchannel_ref_size() > 0) { + std::cout << "subchannel: "; + for (const auto& _subchannelref : channel.subchannel_ref()) { + int64_t subch_id = _subchannelref.subchannel_id(); + std::cout << "ID" << subch_id << "_" << _subchannelref.name() << " "; + grpc::channelz::v1::Subchannel subch = GetSubchannelRPC(subch_id); + subchannel_queue.push(subch); + if (CheckID(subch_id)) { + all_subchannels_.push_back(subch); + StoreSubchannelInJson(subch); + } + } + } + } else if (channel.socket_ref_size() > 0) { + std::cout << "socket: "; + for (const auto& _socketref : channel.socket_ref()) { + int64_t so_id = _socketref.socket_id(); + std::cout << "ID" << so_id << "_" << _socketref.name() << " "; + grpc::channelz::v1::Socket so = GetSocketRPC(so_id); + if (CheckID(so_id)) { + all_sockets_.push_back(so); + StoreSocketInJson(so); + } + } + } + std::cout << std::endl; + } + + // get the descedent channels/subchannels/sockets of a subchannel + // push descedent channels/subchannels to queue for layer traverse + // store descedent channels/subchannels/sockets for dumping data + void 
GetSubchannelDescedence( + grpc::channelz::v1::Subchannel& subchannel, + std::queue& channel_queue, + std::queue& subchannel_queue) { + std::cout << " Subchannel ID" << GetSubchannelID(subchannel) << "_" + << GetSubchannelName(subchannel) << " descendence - "; + if (subchannel.channel_ref_size() > 0 || + subchannel.subchannel_ref_size() > 0) { + if (subchannel.channel_ref_size() > 0) { + std::cout << "channel: "; + for (const auto& _channelref : subchannel.channel_ref()) { + int64_t ch_id = _channelref.channel_id(); + std::cout << "ID" << ch_id << "_" << _channelref.name() << " "; + grpc::channelz::v1::Channel ch = GetChannelRPC(ch_id); + channel_queue.push(ch); + if (CheckID(ch_id)) { + all_channels_.push_back(ch); + StoreChannelInJson(ch); + } + } + if (subchannel.subchannel_ref_size() > 0) { + std::cout << ", "; + } + } + if (subchannel.subchannel_ref_size() > 0) { + std::cout << "subchannel: "; + for (const auto& _subchannelref : subchannel.subchannel_ref()) { + int64_t subch_id = _subchannelref.subchannel_id(); + std::cout << "ID" << subch_id << "_" << _subchannelref.name() << " "; + grpc::channelz::v1::Subchannel subch = GetSubchannelRPC(subch_id); + subchannel_queue.push(subch); + if (CheckID(subch_id)) { + all_subchannels_.push_back(subch); + StoreSubchannelInJson(subch); + } + } + } + } else if (subchannel.socket_ref_size() > 0) { + std::cout << "socket: "; + for (const auto& _socketref : subchannel.socket_ref()) { + int64_t so_id = _socketref.socket_id(); + std::cout << "ID" << so_id << "_" << _socketref.name() << " "; + grpc::channelz::v1::Socket so = GetSocketRPC(so_id); + if (CheckID(so_id)) { + all_sockets_.push_back(so); + StoreSocketInJson(so); + } + } + } + std::cout << std::endl; + } + + // Set up the channelz sampler client + // Initialize json as an array + void Setup(const std::string& custom_credentials_type, + const std::string& server_address) { + json_ = grpc_core::Json::Array(); + rpc_timeout_seconds_ = 20; + grpc::ChannelArguments channel_args; + std::shared_ptr channel_creds = + grpc::testing::GetCredentialsProvider()->GetChannelCredentials( + custom_credentials_type, &channel_args); + if (!channel_creds) { + gpr_log(GPR_ERROR, + "Wrong user credential type: %s. Allowed credential types: " + "INSECURE_CREDENTIALS, ssl, alts, google_default_credentials.", + custom_credentials_type.c_str()); + GPR_ASSERT(0); + } + std::shared_ptr channel = + CreateChannel(server_address, channel_creds); + channelz_stub_ = grpc::channelz::v1::Channelz::NewStub(channel); + } + + // Get all servers, keep querying until getting all + // Store servers for dumping data + // Need to check id repeating for servers + void GetServersRPC() { + int64_t server_start_id = 0; + while (true) { + GetServersRequest get_servers_request; + GetServersResponse get_servers_response; + ClientContext get_servers_context; + get_servers_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + get_servers_request.set_start_server_id(server_start_id); + Status status = channelz_stub_->GetServers( + &get_servers_context, get_servers_request, &get_servers_response); + if (!status.ok()) { + if (status.error_code() == StatusCode::UNIMPLEMENTED) { + gpr_log(GPR_ERROR, + "Error status UNIMPLEMENTED. 
Please check and make sure " + "channelz has been registered on the server being queried."); + } else { + gpr_log(GPR_ERROR, + "GetServers RPC with GetServersRequest.server_start_id=%d, " + "failed: %s", + int(server_start_id), + get_servers_context.debug_error_string().c_str()); + } + GPR_ASSERT(0); + } + for (const auto& _server : get_servers_response.server()) { + all_servers_.push_back(_server); + StoreServerInJson(_server); + } + if (!get_servers_response.end()) { + server_start_id = GetServerID(all_servers_.back()) + 1; + } else { + break; + } + } + std::cout << "Number of servers = " << all_servers_.size() << std::endl; + } + + // Get sockets that belongs to servers + // Store sockets for dumping data + void GetSocketsOfServers() { + for (const auto& _server : all_servers_) { + std::cout << "Server ID" << GetServerID(_server) << "_" + << GetServerName(_server) << " listen_socket - "; + for (const auto& _socket : _server.listen_socket()) { + int64_t so_id = _socket.socket_id(); + std::cout << "ID" << so_id << "_" << _socket.name() << " "; + if (CheckID(so_id)) { + grpc::channelz::v1::Socket so = GetSocketRPC(so_id); + all_sockets_.push_back(so); + StoreSocketInJson(so); + } + } + std::cout << std::endl; + } + } + + // Get all top channels, keep querying until getting all + // Store channels for dumping data + // No need to check id repeating for top channels + void GetTopChannelsRPC() { + int64_t channel_start_id = 0; + while (true) { + GetTopChannelsRequest get_top_channels_request; + GetTopChannelsResponse get_top_channels_response; + ClientContext get_top_channels_context; + get_top_channels_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + get_top_channels_request.set_start_channel_id(channel_start_id); + Status status = channelz_stub_->GetTopChannels( + &get_top_channels_context, get_top_channels_request, + &get_top_channels_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, + "GetTopChannels RPC with " + "GetTopChannelsRequest.channel_start_id=%d failed: %s", + int(channel_start_id), + get_top_channels_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + for (const auto& _topchannel : get_top_channels_response.channel()) { + top_channels_.push_back(_topchannel); + all_channels_.push_back(_topchannel); + StoreChannelInJson(_topchannel); + } + if (!get_top_channels_response.end()) { + channel_start_id = GetChannelID(top_channels_.back()) + 1; + } else { + break; + } + } + std::cout << std::endl + << "Number of top channels = " << top_channels_.size() + << std::endl; + } + + // layer traverse for each top channel + void TraverseTopChannels() { + for (const auto& _topchannel : top_channels_) { + int tree_depth = 0; + std::queue channel_queue; + std::queue subchannel_queue; + std::cout << "Tree depth = " << tree_depth << std::endl; + GetChannelDescedence(_topchannel, channel_queue, subchannel_queue); + while (!channel_queue.empty() || !subchannel_queue.empty()) { + ++tree_depth; + std::cout << "Tree depth = " << tree_depth << std::endl; + int ch_q_size = channel_queue.size(); + int subch_q_size = subchannel_queue.size(); + for (int i = 0; i < ch_q_size; ++i) { + grpc::channelz::v1::Channel ch = channel_queue.front(); + channel_queue.pop(); + GetChannelDescedence(ch, channel_queue, subchannel_queue); + } + for (int i = 0; i < subch_q_size; ++i) { + grpc::channelz::v1::Subchannel subch = subchannel_queue.front(); + subchannel_queue.pop(); + GetSubchannelDescedence(subch, channel_queue, subchannel_queue); + } + } + std::cout << std::endl; + 
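+      // At this point every channel/subchannel reachable from _topchannel has
+      // been visited once via the two queues (CheckID() skips entities that
+      // were already recorded), so all_channels_/all_subchannels_/all_sockets_
+      // and the JSON output are complete for this top channel before the next
+      // one is traversed.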
} + } + + // dump data of all entities to stdout + void DumpStdout() { + std::string data_str; + for (const auto& _channel : all_channels_) { + std::cout << "channel ID" << GetChannelID(_channel) << "_" + << GetChannelName(_channel) << " data:" << std::endl; + // TODO(mohanli): TextFormat::PrintToString records time as seconds and + // nanos. Need a more human readable way. + ::google::protobuf::TextFormat::PrintToString(_channel.data(), &data_str); + printf("%s\n", data_str.c_str()); + } + for (const auto& _subchannel : all_subchannels_) { + std::cout << "subchannel ID" << GetSubchannelID(_subchannel) << "_" + << GetSubchannelName(_subchannel) << " data:" << std::endl; + ::google::protobuf::TextFormat::PrintToString(_subchannel.data(), + &data_str); + printf("%s\n", data_str.c_str()); + } + for (const auto& _server : all_servers_) { + std::cout << "server ID" << GetServerID(_server) << "_" + << GetServerName(_server) << " data:" << std::endl; + ::google::protobuf::TextFormat::PrintToString(_server.data(), &data_str); + printf("%s\n", data_str.c_str()); + } + for (const auto& _socket : all_sockets_) { + std::cout << "socket ID" << GetSocketID(_socket) << "_" + << GetSocketName(_socket) << " data:" << std::endl; + ::google::protobuf::TextFormat::PrintToString(_socket.data(), &data_str); + printf("%s\n", data_str.c_str()); + } + } + + // Store a channel in Json + void StoreChannelInJson(const grpc::channelz::v1::Channel& channel) { + std::string id = grpc::to_string(GetChannelID(channel)); + std::string type = "Channel"; + std::string description; + ::google::protobuf::TextFormat::PrintToString(channel.data(), &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store a subchannel in Json + void StoreSubchannelInJson(const grpc::channelz::v1::Subchannel& subchannel) { + std::string id = grpc::to_string(GetSubchannelID(subchannel)); + std::string type = "Subchannel"; + std::string description; + ::google::protobuf::TextFormat::PrintToString(subchannel.data(), + &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store a server in Json + void StoreServerInJson(const grpc::channelz::v1::Server& server) { + std::string id = grpc::to_string(GetServerID(server)); + std::string type = "Server"; + std::string description; + ::google::protobuf::TextFormat::PrintToString(server.data(), &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store a socket in Json + void StoreSocketInJson(const grpc::channelz::v1::Socket& socket) { + std::string id = grpc::to_string(GetSocketID(socket)); + std::string type = "Socket"; + std::string description; + ::google::protobuf::TextFormat::PrintToString(socket.data(), &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store an entity in Json + void StoreEntityInJson(std::string& id, std::string& type, + const grpc_core::Json& description) { + std::string start, finish; + gpr_timespec ago = gpr_time_sub( + now_, + gpr_time_from_seconds(absl::GetFlag(FLAGS_sampling_interval_seconds), + GPR_TIMESPAN)); + std::stringstream ss; + const time_t time_now = now_.tv_sec; + ss << std::put_time(std::localtime(&time_now), "%F %T"); + finish = ss.str(); // example: "2019-02-01 12:12:18" + ss.str(""); + const time_t time_ago = 
ago.tv_sec; + ss << std::put_time(std::localtime(&time_ago), "%F %T"); + start = ss.str(); + grpc_core::Json obj = + grpc_core::Json::Object{{"Task", absl::StrFormat("%s_ID%s", type, id)}, + {"Start", start}, + {"Finish", finish}, + {"ID", id}, + {"Type", type}, + {"Description", description}}; + json_.mutable_array()->push_back(obj); + } + + // Dump data in json + std::string DumpJson() { return json_.Dump(); } + + // Check if one entity has been recorded + bool CheckID(int64_t id) { + if (id_set_.count(id) == 0) { + id_set_.insert(id); + return true; + } else { + return false; + } + } + + // Record current time + void RecordNow() { now_ = gpr_now(GPR_CLOCK_REALTIME); } + + private: + std::unique_ptr channelz_stub_; + std::vector top_channels_; + std::vector all_servers_; + std::vector all_channels_; + std::vector all_subchannels_; + std::vector all_sockets_; + std::unordered_set id_set_; + grpc_core::Json json_; + int64_t rpc_timeout_seconds_; + gpr_timespec now_; +}; + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + std::ofstream output_file(absl::GetFlag(FLAGS_output_json)); + for (int i = 0; i < absl::GetFlag(FLAGS_sampling_times); ++i) { + ChannelzSampler channelz_sampler; + channelz_sampler.Setup(absl::GetFlag(FLAGS_custom_credentials_type), + absl::GetFlag(FLAGS_server_address)); + std::cout << "Wait for sampling interval " + << absl::GetFlag(FLAGS_sampling_interval_seconds) << "s..." + << std::endl; + const gpr_timespec kDelay = gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(absl::GetFlag(FLAGS_sampling_interval_seconds), + GPR_TIMESPAN)); + gpr_sleep_until(kDelay); + std::cout << "##### " << i << "th sampling #####" << std::endl; + channelz_sampler.RecordNow(); + channelz_sampler.GetServersRPC(); + channelz_sampler.GetSocketsOfServers(); + channelz_sampler.GetTopChannelsRPC(); + channelz_sampler.TraverseTopChannels(); + channelz_sampler.DumpStdout(); + if (!absl::GetFlag(FLAGS_output_json).empty()) { + output_file << channelz_sampler.DumpJson() << "\n" << std::flush; + } + } + output_file.close(); + return 0; +} diff --git a/test/cpp/util/channelz_sampler_test.cc b/test/cpp/util/channelz_sampler_test.cc new file mode 100644 index 00000000..c8519f66 --- /dev/null +++ b/test/cpp/util/channelz_sampler_test.cc @@ -0,0 +1,179 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/cpp/server/channelz/channelz_service.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" +#include "test/cpp/util/test_credentials_provider.h" + +static std::string g_root; + +namespace { +using grpc::ClientContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +} // namespace + +// Test variables +std::string server_address("0.0.0.0:10000"); +std::string custom_credentials_type("INSECURE_CREDENTIALS"); +std::string sampling_times = "2"; +std::string sampling_interval_seconds = "3"; +std::string output_json("output.json"); + +// Creata an echo server +class EchoServerImpl final : public grpc::testing::TestService::Service { + Status EmptyCall(::grpc::ServerContext* /*context*/, + const grpc::testing::Empty* /*request*/, + grpc::testing::Empty* /*response*/) override { + return Status::OK; + } +}; + +// Run client in a thread +void RunClient(const std::string& client_id, gpr_event* done_ev) { + grpc::ChannelArguments channel_args; + std::shared_ptr channel_creds = + grpc::testing::GetCredentialsProvider()->GetChannelCredentials( + custom_credentials_type, &channel_args); + std::unique_ptr stub = + grpc::testing::TestService::NewStub( + grpc::CreateChannel(server_address, channel_creds)); + gpr_log(GPR_INFO, "Client %s is echoing!", client_id.c_str()); + while (true) { + if (gpr_event_wait(done_ev, grpc_timeout_seconds_to_deadline(1)) != + nullptr) { + return; + } + grpc::testing::Empty request; + grpc::testing::Empty response; + ClientContext context; + Status status = stub->EmptyCall(&context, request, &response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "Client echo failed."); + GPR_ASSERT(0); + } + } +} + +// Create the channelz to test the connection to the server +bool WaitForConnection(int wait_server_seconds) { + grpc::ChannelArguments channel_args; + std::shared_ptr channel_creds = + grpc::testing::GetCredentialsProvider()->GetChannelCredentials( + custom_credentials_type, &channel_args); + auto channel = grpc::CreateChannel(server_address, channel_creds); + return channel->WaitForConnected( + grpc_timeout_seconds_to_deadline(wait_server_seconds)); +} + +// Test the channelz sampler +TEST(ChannelzSamplerTest, SimpleTest) { + // start server + ::grpc::channelz::experimental::InitChannelzService(); + EchoServerImpl service; + grpc::ServerBuilder builder; + auto server_creds = + grpc::testing::GetCredentialsProvider()->GetServerCredentials( + custom_credentials_type); + builder.AddListeningPort(server_address, server_creds); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + gpr_log(GPR_INFO, "Server listening on %s", server_address.c_str()); + const int kWaitForServerSeconds = 10; + ASSERT_TRUE(WaitForConnection(kWaitForServerSeconds)); + // client threads + gpr_event done_ev1, done_ev2; + gpr_event_init(&done_ev1); + gpr_event_init(&done_ev2); + std::thread client_thread_1(RunClient, "1", &done_ev1); + std::thread client_thread_2(RunClient, "2", &done_ev2); + // Run the channelz sampler + grpc::SubProcess* test_driver = new grpc::SubProcess( + {g_root + "/channelz_sampler", "--server_address=" + server_address, + 
"--custom_credentials_type=" + custom_credentials_type, + "--sampling_times=" + sampling_times, + "--sampling_interval_seconds=" + sampling_interval_seconds, + "--output_json=" + output_json}); + int status = test_driver->Join(); + if (WIFEXITED(status)) { + if (WEXITSTATUS(status)) { + gpr_log(GPR_ERROR, + "Channelz sampler test test-runner exited with code %d", + WEXITSTATUS(status)); + GPR_ASSERT(0); // log the line number of the assertion failure + } + } else if (WIFSIGNALED(status)) { + gpr_log(GPR_ERROR, "Channelz sampler test test-runner ended from signal %d", + WTERMSIG(status)); + GPR_ASSERT(0); + } else { + gpr_log(GPR_ERROR, + "Channelz sampler test test-runner ended with unknown status %d", + status); + GPR_ASSERT(0); + } + delete test_driver; + gpr_event_set(&done_ev1, reinterpret_cast(1)); + gpr_event_set(&done_ev2, reinterpret_cast(1)); + client_thread_1.join(); + client_thread_2.join(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + std::string me = argv[0]; + auto lslash = me.rfind('/'); + if (lslash != std::string::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/test/cpp/util/cli_call.cc b/test/cpp/util/cli_call.cc new file mode 100644 index 00000000..9f8f1c99 --- /dev/null +++ b/test/cpp/util/cli_call.cc @@ -0,0 +1,225 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/util/cli_call.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace grpc { +namespace testing { +namespace { +void* tag(intptr_t t) { return reinterpret_cast(t); } +} // namespace + +Status CliCall::Call(const std::string& request, std::string* response, + IncomingMetadataContainer* server_initial_metadata, + IncomingMetadataContainer* server_trailing_metadata) { + Write(request); + WritesDone(); + if (!Read(response, server_initial_metadata)) { + fprintf(stderr, "Failed to read response.\n"); + } + return Finish(server_trailing_metadata); +} + +CliCall::CliCall(const std::shared_ptr& channel, + const std::string& method, + const OutgoingMetadataContainer& metadata, CliArgs args) + : stub_(new grpc::GenericStub(channel)) { + gpr_mu_init(&write_mu_); + gpr_cv_init(&write_cv_); + if (!metadata.empty()) { + for (OutgoingMetadataContainer::const_iterator iter = metadata.begin(); + iter != metadata.end(); ++iter) { + ctx_.AddMetadata(iter->first, iter->second); + } + } + + // Set deadline if timeout > 0 (default value -1 if no timeout specified) + if (args.timeout > 0) { + int64_t timeout_in_ns = ceil(args.timeout * 1e9); + + // Convert timeout (in nanoseconds) to a deadline + auto deadline = + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_nanos(timeout_in_ns, GPR_TIMESPAN)); + ctx_.set_deadline(deadline); + } else if (args.timeout != -1) { + fprintf( + stderr, + "WARNING: Non-positive timeout value, skipping setting deadline.\n"); + } + + call_ = stub_->PrepareCall(&ctx_, method, &cq_); + call_->StartCall(tag(1)); + void* got_tag; + bool ok; + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); +} + +CliCall::~CliCall() { + gpr_cv_destroy(&write_cv_); + gpr_mu_destroy(&write_mu_); +} + +void CliCall::Write(const std::string& request) { + void* got_tag; + bool ok; + + gpr_slice s = gpr_slice_from_copied_buffer(request.data(), request.size()); + grpc::Slice req_slice(s, grpc::Slice::STEAL_REF); + grpc::ByteBuffer send_buffer(&req_slice, 1); + call_->Write(send_buffer, tag(2)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); +} + +bool CliCall::Read(std::string* response, + IncomingMetadataContainer* server_initial_metadata) { + void* got_tag; + bool ok; + + grpc::ByteBuffer recv_buffer; + call_->Read(&recv_buffer, tag(3)); + + if (!cq_.Next(&got_tag, &ok) || !ok) { + return false; + } + std::vector slices; + GPR_ASSERT(recv_buffer.Dump(&slices).ok()); + + response->clear(); + for (size_t i = 0; i < slices.size(); i++) { + response->append(reinterpret_cast(slices[i].begin()), + slices[i].size()); + } + if (server_initial_metadata) { + *server_initial_metadata = ctx_.GetServerInitialMetadata(); + } + return true; +} + +void CliCall::WritesDone() { + void* got_tag; + bool ok; + + call_->WritesDone(tag(4)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); +} + +void CliCall::WriteAndWait(const std::string& request) { + grpc::Slice req_slice(request); + grpc::ByteBuffer send_buffer(&req_slice, 1); + + gpr_mu_lock(&write_mu_); + call_->Write(send_buffer, tag(2)); + write_done_ = false; + while (!write_done_) { + gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&write_mu_); +} + +void CliCall::WritesDoneAndWait() { + gpr_mu_lock(&write_mu_); + call_->WritesDone(tag(4)); + write_done_ = false; + while (!write_done_) { + gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&write_mu_); +} + +bool CliCall::ReadAndMaybeNotifyWrite( + 
std::string* response, IncomingMetadataContainer* server_initial_metadata) { + void* got_tag; + bool ok; + grpc::ByteBuffer recv_buffer; + + call_->Read(&recv_buffer, tag(3)); + bool cq_result = cq_.Next(&got_tag, &ok); + + while (got_tag != tag(3)) { + gpr_mu_lock(&write_mu_); + write_done_ = true; + gpr_cv_signal(&write_cv_); + gpr_mu_unlock(&write_mu_); + + cq_result = cq_.Next(&got_tag, &ok); + if (got_tag == tag(2)) { + GPR_ASSERT(ok); + } + } + + if (!cq_result || !ok) { + // If the RPC is ended on the server side, we should still wait for the + // pending write on the client side to be done. + if (!ok) { + gpr_mu_lock(&write_mu_); + if (!write_done_) { + cq_.Next(&got_tag, &ok); + GPR_ASSERT(got_tag != tag(2)); + write_done_ = true; + gpr_cv_signal(&write_cv_); + } + gpr_mu_unlock(&write_mu_); + } + return false; + } + + std::vector slices; + GPR_ASSERT(recv_buffer.Dump(&slices).ok()); + response->clear(); + for (size_t i = 0; i < slices.size(); i++) { + response->append(reinterpret_cast(slices[i].begin()), + slices[i].size()); + } + if (server_initial_metadata) { + *server_initial_metadata = ctx_.GetServerInitialMetadata(); + } + return true; +} + +Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) { + void* got_tag; + bool ok; + grpc::Status status; + + call_->Finish(&status, tag(5)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); + if (server_trailing_metadata) { + *server_trailing_metadata = ctx_.GetServerTrailingMetadata(); + } + + return status; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/cli_call_test.cc b/test/cpp/util/cli_call_test.cc new file mode 100644 index 00000000..c57defbc --- /dev/null +++ b/test/cpp/util/cli_call_test.cc @@ -0,0 +1,129 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/util/cli_call.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/string_ref_helper.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + if (!context->client_metadata().empty()) { + for (std::multimap::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + response->set_message(request->message()); + return Status::OK; + } +}; + +class CliCallTest : public ::testing::Test { + protected: + CliCallTest() {} + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + channel_ = grpc::CreateChannel(server_address_.str(), + InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + std::shared_ptr channel_; + std::unique_ptr stub_; + std::unique_ptr server_; + std::ostringstream server_address_; + TestServiceImpl service_; +}; + +// Send a rpc with a normal stub and then a CliCall. Verify they match. +TEST_F(CliCallTest, SimpleRpc) { + ResetStub(); + // Normal stub. + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + context.AddMetadata("key1", "val1"); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + + const std::string kMethod("/grpc.testing.EchoTestService/Echo"); + std::string request_bin, response_bin, expected_response_bin; + EXPECT_TRUE(request.SerializeToString(&request_bin)); + EXPECT_TRUE(response.SerializeToString(&expected_response_bin)); + std::multimap client_metadata; + std::multimap server_initial_metadata, + server_trailing_metadata; + client_metadata.insert(std::pair("key1", "val1")); + CliCall call(channel_, kMethod, client_metadata); + Status s2 = call.Call(request_bin, &response_bin, &server_initial_metadata, + &server_trailing_metadata); + EXPECT_TRUE(s2.ok()); + + EXPECT_EQ(expected_response_bin, response_bin); + EXPECT_EQ(context.GetServerInitialMetadata(), server_initial_metadata); + EXPECT_EQ(context.GetServerTrailingMetadata(), server_trailing_metadata); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/util/cli_credentials.cc b/test/cpp/util/cli_credentials.cc new file mode 100644 index 00000000..579276ab --- /dev/null +++ b/test/cpp/util/cli_credentials.cc @@ -0,0 +1,191 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/cli_credentials.h" + +#include "absl/flags/flag.h" + +#include +#include +#include + +#include "src/core/lib/iomgr/load_file.h" + +ABSL_RETIRED_FLAG(bool, enable_ssl, false, + "Replaced by --channel_creds_type=ssl."); +ABSL_RETIRED_FLAG(bool, use_auth, false, + "Replaced by --channel_creds_type=gdc."); +ABSL_RETIRED_FLAG(std::string, access_token, "", + "Replaced by --call_creds=access_token=."); +ABSL_FLAG( + std::string, ssl_target, "", + "If not empty, treat the server host name as this for ssl/tls certificate " + "validation."); +ABSL_FLAG( + std::string, ssl_client_cert, "", + "If not empty, load this PEM formatted client certificate file. Requires " + "use of --ssl_client_key."); +ABSL_FLAG(std::string, ssl_client_key, "", + "If not empty, load this PEM formatted private key. Requires use of " + "--ssl_client_cert"); +ABSL_FLAG( + std::string, local_connect_type, "local_tcp", + "The type of local connections for which local channel credentials will " + "be applied. Should be local_tcp or uds."); +ABSL_FLAG( + std::string, channel_creds_type, "", + "The channel creds type: insecure, ssl, gdc (Google Default Credentials), " + "alts, or local."); +ABSL_FLAG( + std::string, call_creds, "", + "Call credentials to use: none (default), or access_token=. If " + "provided, the call creds are composited on top of channel creds."); + +namespace grpc { +namespace testing { + +namespace { + +const char ACCESS_TOKEN_PREFIX[] = "access_token="; +constexpr int ACCESS_TOKEN_PREFIX_LEN = + sizeof(ACCESS_TOKEN_PREFIX) / sizeof(*ACCESS_TOKEN_PREFIX) - 1; + +bool IsAccessToken(const std::string& auth) { + return auth.length() > ACCESS_TOKEN_PREFIX_LEN && + auth.compare(0, ACCESS_TOKEN_PREFIX_LEN, ACCESS_TOKEN_PREFIX) == 0; +} + +std::string AccessToken(const std::string& auth) { + if (!IsAccessToken(auth)) { + return ""; + } + return std::string(auth, ACCESS_TOKEN_PREFIX_LEN); +} + +} // namespace + +std::string CliCredentials::GetDefaultChannelCredsType() const { + return "insecure"; +} + +std::string CliCredentials::GetDefaultCallCreds() const { return "none"; } + +std::shared_ptr +CliCredentials::GetChannelCredentials() const { + if (absl::GetFlag(FLAGS_channel_creds_type) == "insecure") { + return grpc::InsecureChannelCredentials(); + } else if (absl::GetFlag(FLAGS_channel_creds_type) == "ssl") { + grpc::SslCredentialsOptions ssl_creds_options; + // TODO(@Capstan): This won't affect Google Default Credentials using SSL. 
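+    // The --ssl_* flags defined above feed this branch. A typical command
+    // line that exercises it (flag values are illustrative only) would look
+    // something like:
+    //   grpc_cli call example.com:443 SayHello "name: 'world'" \
+    //     --channel_creds_type=ssl --ssl_client_cert=client.pem \
+    //     --ssl_client_key=client.key --ssl_target=example.com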
+ if (!absl::GetFlag(FLAGS_ssl_client_cert).empty()) { + grpc_slice cert_slice = grpc_empty_slice(); + GRPC_LOG_IF_ERROR( + "load_file", + grpc_load_file(absl::GetFlag(FLAGS_ssl_client_cert).c_str(), 1, + &cert_slice)); + ssl_creds_options.pem_cert_chain = + grpc::StringFromCopiedSlice(cert_slice); + grpc_slice_unref(cert_slice); + } + if (!absl::GetFlag(FLAGS_ssl_client_key).empty()) { + grpc_slice key_slice = grpc_empty_slice(); + GRPC_LOG_IF_ERROR( + "load_file", + grpc_load_file(absl::GetFlag(FLAGS_ssl_client_key).c_str(), 1, + &key_slice)); + ssl_creds_options.pem_private_key = + grpc::StringFromCopiedSlice(key_slice); + grpc_slice_unref(key_slice); + } + return grpc::SslCredentials(ssl_creds_options); + } else if (absl::GetFlag(FLAGS_channel_creds_type) == "gdc") { + return grpc::GoogleDefaultCredentials(); + } else if (absl::GetFlag(FLAGS_channel_creds_type) == "alts") { + return grpc::experimental::AltsCredentials( + grpc::experimental::AltsCredentialsOptions()); + } else if (absl::GetFlag(FLAGS_channel_creds_type) == "local") { + if (absl::GetFlag(FLAGS_local_connect_type) == "local_tcp") { + return grpc::experimental::LocalCredentials(LOCAL_TCP); + } else if (absl::GetFlag(FLAGS_local_connect_type) == "uds") { + return grpc::experimental::LocalCredentials(UDS); + } else { + fprintf(stderr, + "--local_connect_type=%s invalid; must be local_tcp or uds.\n", + absl::GetFlag(FLAGS_local_connect_type).c_str()); + } + } + fprintf(stderr, + "--channel_creds_type=%s invalid; must be insecure, ssl, gdc, " + "alts, or local.\n", + absl::GetFlag(FLAGS_channel_creds_type).c_str()); + return std::shared_ptr(); +} + +std::shared_ptr CliCredentials::GetCallCredentials() + const { + if (IsAccessToken(absl::GetFlag(FLAGS_call_creds))) { + return grpc::AccessTokenCredentials( + AccessToken(absl::GetFlag(FLAGS_call_creds))); + } + if (absl::GetFlag(FLAGS_call_creds) == "none") { + // Nothing to do; creds, if any, are baked into the channel. + return std::shared_ptr(); + } + fprintf(stderr, + "--call_creds=%s invalid; must be none " + "or access_token=.\n", + absl::GetFlag(FLAGS_call_creds).c_str()); + return std::shared_ptr(); +} + +std::shared_ptr CliCredentials::GetCredentials() + const { + if (absl::GetFlag(FLAGS_call_creds).empty()) { + absl::SetFlag(&FLAGS_call_creds, GetDefaultCallCreds()); + } + if (absl::GetFlag(FLAGS_channel_creds_type).empty()) { + absl::SetFlag(&FLAGS_channel_creds_type, GetDefaultChannelCredsType()); + } + std::shared_ptr channel_creds = + GetChannelCredentials(); + // Composite any call-type credentials on top of the base channel. + std::shared_ptr call_creds = GetCallCredentials(); + return (channel_creds == nullptr || call_creds == nullptr) + ? channel_creds + : grpc::CompositeChannelCredentials(channel_creds, call_creds); +} + +std::string CliCredentials::GetCredentialUsage() const { + return " --ssl_target ; Set server host for ssl validation\n" + " --ssl_client_cert ; Client cert for ssl\n" + " --ssl_client_key ; Client private key for ssl\n" + " --local_connect_type ; Set to local_tcp or uds\n" + " --channel_creds_type ; Set to insecure, ssl, gdc, alts, or " + "local\n" + " --call_creds ; Set to none, or" + " access_token=\n"; +} + +std::string CliCredentials::GetSslTargetNameOverride() const { + bool use_ssl = absl::GetFlag(FLAGS_channel_creds_type) == "ssl" || + absl::GetFlag(FLAGS_channel_creds_type) == "gdc"; + return use_ssl ? 
absl::GetFlag(FLAGS_ssl_target) : ""; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/create_test_channel.cc b/test/cpp/util/create_test_channel.cc new file mode 100644 index 00000000..a3915320 --- /dev/null +++ b/test/cpp/util/create_test_channel.cc @@ -0,0 +1,252 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/create_test_channel.h" + +#include "absl/flags/flag.h" + +#include +#include +#include + +#include "test/cpp/util/test_credentials_provider.h" + +ABSL_FLAG(std::string, grpc_test_use_grpclb_with_child_policy, "", + "If non-empty, set a static service config on channels created by " + "grpc::CreateTestChannel, that configures the grpclb LB policy " + "with a child policy being the value of this flag (e.g. round_robin " + "or pick_first)."); + +namespace grpc { + +namespace { + +const char kProdTlsCredentialsType[] = "prod_ssl"; + +class SslCredentialProvider : public testing::CredentialTypeProvider { + public: + std::shared_ptr GetChannelCredentials( + grpc::ChannelArguments* /*args*/) override { + return grpc::SslCredentials(SslCredentialsOptions()); + } + std::shared_ptr GetServerCredentials() override { + return nullptr; + } +}; + +gpr_once g_once_init_add_prod_ssl_provider = GPR_ONCE_INIT; +// Register ssl with non-test roots type to the credentials provider. +void AddProdSslType() { + testing::GetCredentialsProvider()->AddSecureType( + kProdTlsCredentialsType, std::unique_ptr( + new SslCredentialProvider)); +} + +void MaybeSetCustomChannelArgs(grpc::ChannelArguments* args) { + if (!absl::GetFlag(FLAGS_grpc_test_use_grpclb_with_child_policy).empty()) { + args->SetString( + "grpc.service_config", + "{\"loadBalancingConfig\":[{\"grpclb\":{\"childPolicy\":[{" + "\"" + + absl::GetFlag(FLAGS_grpc_test_use_grpclb_with_child_policy) + + "\":{}}]}}]}"); + } +} + +} // namespace + +// When cred_type is 'ssl', if server is empty, override_hostname is used to +// create channel. Otherwise, connect to server and override hostname if +// override_hostname is provided. +// When cred_type is not 'ssl', override_hostname is ignored. +// Set use_prod_root to true to use the SSL root for connecting to google. +// In this case, path to the roots pem file must be set via environment variable +// GRPC_DEFAULT_SSL_ROOTS_FILE_PATH. +// Otherwise, root for test SSL cert will be used. +// creds will be used to create a channel when cred_type is 'ssl'. 
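+// For the use_prod_roots=true case the root bundle comes from that
+// environment variable, so a caller would typically launch the test binary
+// as (path and binary name are illustrative):
+//   GRPC_DEFAULT_SSL_ROOTS_FILE_PATH=/etc/ssl/certs/ca-certificates.crt \
+//     ./interop_client --server_host=test.google.com ...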
+// Use examples: +// CreateTestChannel( +// "1.1.1.1:12345", "ssl", "override.hostname.com", false, creds); +// CreateTestChannel("test.google.com:443", "ssl", "", true, creds); +// same as above +// CreateTestChannel("", "ssl", "test.google.com:443", true, creds); +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& cred_type, + const std::string& override_hostname, bool use_prod_roots, + const std::shared_ptr& creds, + const ChannelArguments& args) { + return CreateTestChannel(server, cred_type, override_hostname, use_prod_roots, + creds, args, + /*interceptor_creators=*/{}); +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds, + const ChannelArguments& args) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, args, + /*interceptor_creators=*/{}); +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, ChannelArguments()); +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, std::shared_ptr()); +} + +// Shortcut for end2end and interop tests. +std::shared_ptr CreateTestChannel( + const std::string& server, testing::transport_security security_type) { + return CreateTestChannel(server, "foo.test.google.fr", security_type, false); +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& credential_type, + const std::shared_ptr& creds) { + ChannelArguments channel_args; + MaybeSetCustomChannelArgs(&channel_args); + std::shared_ptr channel_creds = + testing::GetCredentialsProvider()->GetChannelCredentials(credential_type, + &channel_args); + GPR_ASSERT(channel_creds != nullptr); + if (creds.get()) { + channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds); + } + return ::grpc::CreateCustomChannel(server, channel_creds, channel_args); +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& cred_type, + const std::string& override_hostname, bool use_prod_roots, + const std::shared_ptr& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { + ChannelArguments channel_args(args); + MaybeSetCustomChannelArgs(&channel_args); + std::shared_ptr channel_creds; + if (cred_type.empty()) { + if (interceptor_creators.empty()) { + return ::grpc::CreateCustomChannel(server, InsecureChannelCredentials(), + channel_args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + server, InsecureChannelCredentials(), channel_args, + std::move(interceptor_creators)); + } + } else if (cred_type == testing::kTlsCredentialsType) { // cred_type == "ssl" + if (use_prod_roots) { + gpr_once_init(&g_once_init_add_prod_ssl_provider, &AddProdSslType); + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + kProdTlsCredentialsType, &channel_args); + if (!server.empty() && !override_hostname.empty()) { + channel_args.SetSslTargetNameOverride(override_hostname); + } + } else { + // 
override_hostname is discarded as the provider handles it. + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + testing::kTlsCredentialsType, &channel_args); + } + GPR_ASSERT(channel_creds != nullptr); + + const std::string& connect_to = server.empty() ? override_hostname : server; + if (creds.get()) { + channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds); + } + if (interceptor_creators.empty()) { + return ::grpc::CreateCustomChannel(connect_to, channel_creds, + channel_args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + connect_to, channel_creds, channel_args, + std::move(interceptor_creators)); + } + } else { + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + cred_type, &channel_args); + GPR_ASSERT(channel_creds != nullptr); + + if (interceptor_creators.empty()) { + return ::grpc::CreateCustomChannel(server, channel_creds, channel_args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + server, channel_creds, channel_args, std::move(interceptor_creators)); + } + } +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr> + interceptor_creators) { + std::string credential_type = + security_type == testing::ALTS + ? testing::kAltsCredentialsType + : (security_type == testing::TLS ? testing::kTlsCredentialsType + : testing::kInsecureCredentialsType); + return CreateTestChannel(server, credential_type, override_hostname, + use_prod_roots, creds, args, + std::move(interceptor_creators)); +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr& creds, + std::vector< + std::unique_ptr> + interceptor_creators) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, ChannelArguments(), + std::move(interceptor_creators)); +} + +std::shared_ptr CreateTestChannel( + const std::string& server, const std::string& credential_type, + const std::shared_ptr& creds, + std::vector< + std::unique_ptr> + interceptor_creators) { + ChannelArguments channel_args; + MaybeSetCustomChannelArgs(&channel_args); + std::shared_ptr channel_creds = + testing::GetCredentialsProvider()->GetChannelCredentials(credential_type, + &channel_args); + GPR_ASSERT(channel_creds != nullptr); + if (creds.get()) { + channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds); + } + return experimental::CreateCustomChannelWithInterceptors( + server, channel_creds, channel_args, std::move(interceptor_creators)); +} + +} // namespace grpc diff --git a/test/cpp/util/error_details_test.cc b/test/cpp/util/error_details_test.cc new file mode 100644 index 00000000..cfb5a2af --- /dev/null +++ b/test/cpp/util/error_details_test.cc @@ -0,0 +1,126 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include + +#include "src/proto/grpc/status/status.pb.h" +#include "src/proto/grpc/testing/echo_messages.pb.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +TEST(ExtractTest, Success) { + google::rpc::Status expected; + expected.set_code(13); // INTERNAL + expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(std::string(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + google::rpc::Status to; + std::string error_details = expected.SerializeAsString(); + Status from(static_cast(expected.code()), expected.message(), + error_details); + EXPECT_TRUE(ExtractErrorDetails(from, &to).ok()); + EXPECT_EQ(expected.code(), to.code()); + EXPECT_EQ(expected.message(), to.message()); + EXPECT_EQ(1, to.details_size()); + testing::EchoRequest details; + to.details(0).UnpackTo(&details); + EXPECT_EQ(expected_details.message(), details.message()); +} + +TEST(ExtractTest, NullInput) { + EXPECT_EQ(StatusCode::FAILED_PRECONDITION, + ExtractErrorDetails(Status(), nullptr).error_code()); +} + +TEST(ExtractTest, Unparsable) { + std::string error_details("I am not a status object"); + Status from(StatusCode::INTERNAL, "", error_details); + google::rpc::Status to; + EXPECT_EQ(StatusCode::INVALID_ARGUMENT, + ExtractErrorDetails(from, &to).error_code()); +} + +TEST(SetTest, Success) { + google::rpc::Status expected; + expected.set_code(13); // INTERNAL + expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(std::string(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + Status to; + Status s = SetErrorDetails(expected, &to); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(expected.code(), to.error_code()); + EXPECT_EQ(expected.message(), to.error_message()); + EXPECT_EQ(expected.SerializeAsString(), to.error_details()); +} + +TEST(SetTest, NullInput) { + EXPECT_EQ(StatusCode::FAILED_PRECONDITION, + SetErrorDetails(google::rpc::Status(), nullptr).error_code()); +} + +TEST(SetTest, OutOfScopeErrorCode) { + google::rpc::Status expected; + expected.set_code(17); // Out of scope (UNAUTHENTICATED is 16). 
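+  // 17 has no grpc::StatusCode equivalent (UNAUTHENTICATED == 16 is the last
+  // defined code), so SetErrorDetails is expected to fall back to UNKNOWN
+  // below.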
+ expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(std::string(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + Status to; + Status s = SetErrorDetails(expected, &to); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(StatusCode::UNKNOWN, to.error_code()); + EXPECT_EQ(expected.message(), to.error_message()); + EXPECT_EQ(expected.SerializeAsString(), to.error_details()); +} + +TEST(SetTest, ValidScopeErrorCode) { + for (int c = StatusCode::OK; c <= StatusCode::UNAUTHENTICATED; c++) { + google::rpc::Status expected; + expected.set_code(c); + expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(std::string(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + Status to; + Status s = SetErrorDetails(expected, &to); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(c, to.error_code()); + EXPECT_EQ(expected.message(), to.error_message()); + EXPECT_EQ(expected.SerializeAsString(), to.error_details()); + } +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/util/grpc_cli.cc b/test/cpp/util/grpc_cli.cc new file mode 100644 index 00000000..88ce1b76 --- /dev/null +++ b/test/cpp/util/grpc_cli.cc @@ -0,0 +1,95 @@ +/* + + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* + A command line tool to talk to a grpc server. + Run `grpc_cli help` command to see its usage information. + + Example of talking to grpc interop server: + grpc_cli call localhost:50051 UnaryCall "response_size:10" \ + --protofiles=src/proto/grpc/testing/test.proto \ + --channel_creds_type=insecure + + Options: + 1. --protofiles, use this flag to provide proto files if the server does + does not have the reflection service. + 2. --proto_path, if your proto file is not under current working directory, + use this flag to provide a search root. It should work similar to the + counterpart in protoc. This option is valid only when protofiles is + provided. + 3. --metadata specifies metadata to be sent to the server, such as: + --metadata="MyHeaderKey1:Value1:MyHeaderKey2:Value2" + 4. --channel_creds_type, whether to use tls, insecure or platform-specific + options. + 5. --use_auth, if set to true, attach a GoogleDefaultCredentials to the call + 6. --infile, input filename (defaults to stdin) + 7. --outfile, output filename (defaults to stdout) + 8. --binary_input, use the serialized request as input. The serialized + request can be generated by calling something like: + protoc --proto_path=src/proto/grpc/testing/ \ + --encode=grpc.testing.SimpleRequest \ + src/proto/grpc/testing/messages.proto \ + < input.txt > input.bin + If this is used and no proto file is provided in the argument list, the + method string has to be exact in the form of /package.service/method. + 9. 
--binary_output, use binary format response as output, it can + be later decoded using protoc: + protoc --proto_path=src/proto/grpc/testing/ \ + --decode=grpc.testing.SimpleResponse \ + src/proto/grpc/testing/messages.proto \ + < output.bin > output.txt + 10. --default_service_config, optional default service config to use + on the channel. Note that this may be ignored if the name resolver + returns a service config. + 11. --display_peer_address, on CallMethod commands, log the peer socket + address of the connection that each RPC is made on to stderr. +*/ + +#include +#include +#include + +#include "absl/flags/flag.h" + +#include + +#include "test/cpp/util/cli_credentials.h" +#include "test/cpp/util/grpc_tool.h" +#include "test/cpp/util/test_config.h" + +ABSL_FLAG(std::string, outfile, "", "Output file (default is stdout)"); + +static bool SimplePrint(const std::string& outfile, const std::string& output) { + if (outfile.empty()) { + std::cout << output << std::flush; + } else { + std::ofstream output_file(outfile, std::ios::app | std::ios::binary); + output_file << output << std::flush; + output_file.close(); + } + return true; +} + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + + return grpc::testing::GrpcToolMainLib( + argc, const_cast(argv), grpc::testing::CliCredentials(), + std::bind(SimplePrint, absl::GetFlag(FLAGS_outfile), + std::placeholders::_1)); +} diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc new file mode 100644 index 00000000..bcfbbbaf --- /dev/null +++ b/test/cpp/util/grpc_tool.cc @@ -0,0 +1,1010 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "test/cpp/util/grpc_tool.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" + +#include +#include +#include +#include +#include +#include + +#include "test/cpp/util/cli_call.h" +#include "test/cpp/util/proto_file_parser.h" +#include "test/cpp/util/proto_reflection_descriptor_database.h" +#include "test/cpp/util/service_describer.h" + +#if GPR_WINDOWS +#include +#else +#include +#endif + +ABSL_FLAG(bool, l, false, "Use a long listing format"); +ABSL_FLAG(bool, remotedb, true, + "Use server types to parse and format messages"); +ABSL_FLAG(std::string, metadata, "", + "Metadata to send to server, in the form of key1:val1:key2:val2"); +ABSL_FLAG(std::string, proto_path, ".", + "Path to look for the proto file. " + "Multiple paths can be separated by " GRPC_CLI_PATH_SEPARATOR); +ABSL_FLAG(std::string, protofiles, "", "Name of the proto file."); +ABSL_FLAG(bool, binary_input, false, "Input in binary format"); +ABSL_FLAG(bool, binary_output, false, "Output in binary format"); +ABSL_FLAG(std::string, default_service_config, "", + "Default service config to use on the channel, if non-empty. 
Note " + "that this will be ignored if the name resolver returns a service " + "config."); +ABSL_FLAG(bool, display_peer_address, false, + "Log the peer socket address of the connection that each RPC is made " + "on to stderr."); +ABSL_FLAG(bool, json_input, false, "Input in json format"); +ABSL_FLAG(bool, json_output, false, "Output in json format"); +ABSL_FLAG(std::string, infile, "", "Input file (default is stdin)"); +ABSL_FLAG(bool, batch, false, + "Input contains multiple requests. Please do not use this to send " + "more than a few RPCs. gRPC CLI has very different performance " + "characteristics compared with normal RPC calls which make it " + "unsuitable for loadtesting or significant production traffic."); +// TODO(Capstan): Consider using absl::Duration +ABSL_FLAG(double, timeout, -1, + "Specify timeout in seconds, used to set the deadline for all " + "RPCs. The default value of -1 means no deadline has been set."); + +namespace grpc { +namespace testing { +namespace { + +class GrpcTool { + public: + explicit GrpcTool(); + virtual ~GrpcTool() {} + + bool Help(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + bool CallMethod(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + bool ListServices(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + bool PrintType(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + // TODO(zyc): implement the following methods + // bool ListServices(int argc, const char** argv, GrpcToolOutputCallback + // callback); + // bool PrintTypeId(int argc, const char** argv, GrpcToolOutputCallback + // callback); + bool ParseMessage(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + bool ToText(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + bool ToJson(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + bool ToBinary(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback); + + void SetPrintCommandMode(int exit_status) { + print_command_usage_ = true; + usage_exit_status_ = exit_status; + } + + private: + void CommandUsage(const std::string& usage) const; + bool print_command_usage_; + int usage_exit_status_; + const std::string cred_usage_; +}; + +template +std::function +BindWith5Args(T&& func) { + return std::bind(std::forward(func), std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, + std::placeholders::_4, std::placeholders::_5); +} + +template +size_t ArraySize(T& a) { + return ((sizeof(a) / sizeof(*(a))) / + static_cast(!(sizeof(a) % sizeof(*(a))))); +} + +void ParseMetadataFlag( + std::multimap* client_metadata) { + if (absl::GetFlag(FLAGS_metadata).empty()) { + return; + } + std::vector fields; + const char delim = ':'; + const char escape = '\\'; + size_t cur = -1; + std::stringstream ss; + while (++cur < absl::GetFlag(FLAGS_metadata).length()) { + switch (absl::GetFlag(FLAGS_metadata).at(cur)) { + case escape: + if (cur < absl::GetFlag(FLAGS_metadata).length() - 1) { + char c = absl::GetFlag(FLAGS_metadata).at(++cur); + if (c == delim || c == escape) { + ss << c; + continue; + } + } + fprintf(stderr, "Failed to parse metadata flag.\n"); + exit(1); + case delim: + fields.push_back(ss.str()); + ss.str(""); + ss.clear(); + break; + default: 
+ ss << absl::GetFlag(FLAGS_metadata).at(cur); + } + } + fields.push_back(ss.str()); + if (fields.size() % 2) { + fprintf(stderr, "Failed to parse metadata flag.\n"); + exit(1); + } + for (size_t i = 0; i < fields.size(); i += 2) { + client_metadata->insert( + std::pair(fields[i], fields[i + 1])); + } +} + +template +void PrintMetadata(const T& m, const std::string& message) { + if (m.empty()) { + return; + } + fprintf(stderr, "%s\n", message.c_str()); + std::string pair; + for (typename T::const_iterator iter = m.begin(); iter != m.end(); ++iter) { + pair.clear(); + pair.append(iter->first.data(), iter->first.size()); + pair.append(" : "); + pair.append(iter->second.data(), iter->second.size()); + fprintf(stderr, "%s\n", pair.c_str()); + } +} + +void ReadResponse(CliCall* call, const std::string& method_name, + const GrpcToolOutputCallback& callback, + ProtoFileParser* parser, gpr_mu* parser_mu, bool print_mode) { + std::string serialized_response_proto; + std::multimap server_initial_metadata; + + for (bool receive_initial_metadata = true; call->ReadAndMaybeNotifyWrite( + &serialized_response_proto, + receive_initial_metadata ? &server_initial_metadata : nullptr); + receive_initial_metadata = false) { + fprintf(stderr, "got response.\n"); + if (!absl::GetFlag(FLAGS_binary_output)) { + gpr_mu_lock(parser_mu); + serialized_response_proto = parser->GetFormattedStringFromMethod( + method_name, serialized_response_proto, false /* is_request */, + absl::GetFlag(FLAGS_json_output)); + if (parser->HasError() && print_mode) { + fprintf(stderr, "Failed to parse response.\n"); + } + gpr_mu_unlock(parser_mu); + } + if (receive_initial_metadata) { + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + } + if (!callback(serialized_response_proto) && print_mode) { + fprintf(stderr, "Failed to output response.\n"); + } + } +} + +std::shared_ptr CreateCliChannel( + const std::string& server_address, const CliCredentials& cred) { + grpc::ChannelArguments args; + if (!cred.GetSslTargetNameOverride().empty()) { + args.SetSslTargetNameOverride(cred.GetSslTargetNameOverride()); + } + if (!absl::GetFlag(FLAGS_default_service_config).empty()) { + args.SetString(GRPC_ARG_SERVICE_CONFIG, + absl::GetFlag(FLAGS_default_service_config).c_str()); + } + // See |GRPC_ARG_MAX_METADATA_SIZE| in |grpc_types.h|. + // Set to large enough size (10M) that should work for most use cases. + args.SetInt(GRPC_ARG_MAX_METADATA_SIZE, 10 * 1024 * 1024); + return ::grpc::CreateCustomChannel(server_address, cred.GetCredentials(), + args); +} + +struct Command { + const char* command; + std::function + function; + int min_args; + int max_args; +}; + +const Command ops[] = { + {"help", BindWith5Args(&GrpcTool::Help), 0, INT_MAX}, + {"ls", BindWith5Args(&GrpcTool::ListServices), 1, 3}, + {"list", BindWith5Args(&GrpcTool::ListServices), 1, 3}, + {"call", BindWith5Args(&GrpcTool::CallMethod), 2, 3}, + {"type", BindWith5Args(&GrpcTool::PrintType), 2, 2}, + {"parse", BindWith5Args(&GrpcTool::ParseMessage), 2, 3}, + {"totext", BindWith5Args(&GrpcTool::ToText), 2, 3}, + {"tobinary", BindWith5Args(&GrpcTool::ToBinary), 2, 3}, + {"tojson", BindWith5Args(&GrpcTool::ToJson), 2, 3}, +}; + +void Usage(const std::string& msg) { + fprintf( + stderr, + "%s\n" + " grpc_cli ls ... ; List services\n" + " grpc_cli call ... ; Call method\n" + " grpc_cli type ... ; Print type\n" + " grpc_cli parse ... ; Parse message\n" + " grpc_cli totext ... ; Convert binary message to text\n" + " grpc_cli tojson ... 
; Convert binary message to json\n" + " grpc_cli tobinary ... ; Convert text message to binary\n" + " grpc_cli help ... ; Print this message, or per-command usage\n" + "\n", + msg.c_str()); + + exit(1); +} + +const Command* FindCommand(const std::string& name) { + for (int i = 0; i < static_cast(ArraySize(ops)); i++) { + if (name == ops[i].command) { + return &ops[i]; + } + } + return nullptr; +} +} // namespace + +int GrpcToolMainLib(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + if (argc < 2) { + Usage("No command specified"); + } + + std::string command = argv[1]; + argc -= 2; + argv += 2; + + const Command* cmd = FindCommand(command); + if (cmd != nullptr) { + GrpcTool grpc_tool; + if (argc < cmd->min_args || argc > cmd->max_args) { + // Force the command to print its usage message + fprintf(stderr, "\nWrong number of arguments for %s\n", command.c_str()); + grpc_tool.SetPrintCommandMode(1); + return cmd->function(&grpc_tool, -1, nullptr, cred, callback); + } + const bool ok = cmd->function(&grpc_tool, argc, argv, cred, callback); + return ok ? 0 : 1; + } else { + Usage("Invalid command '" + std::string(command.c_str()) + "'"); + } + return 1; +} + +GrpcTool::GrpcTool() : print_command_usage_(false), usage_exit_status_(0) {} + +void GrpcTool::CommandUsage(const std::string& usage) const { + if (print_command_usage_) { + fprintf(stderr, "\n%s%s\n", usage.c_str(), + (usage.empty() || usage[usage.size() - 1] != '\n') ? "\n" : ""); + exit(usage_exit_status_); + } +} + +bool GrpcTool::Help(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "Print help\n" + " grpc_cli help [subcommand]\n"); + + if (argc == 0) { + Usage(""); + } else { + const Command* cmd = FindCommand(argv[0]); + if (cmd == nullptr) { + Usage("Unknown command '" + std::string(argv[0]) + "'"); + } + SetPrintCommandMode(0); + cmd->function(this, -1, nullptr, cred, callback); + } + return true; +} + +bool GrpcTool::ListServices(int argc, const char** argv, + const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "List services\n" + " grpc_cli ls
<address> [<service>[/<method>]]\n" + "
; host:port\n" + " ; Exported service name\n" + " ; Method name\n" + " --l ; Use a long listing format\n" + " --outfile ; Output filename (defaults to stdout)\n" + + cred.GetCredentialUsage()); + + std::string server_address(argv[0]); + std::shared_ptr channel = + CreateCliChannel(server_address, cred); + grpc::ProtoReflectionDescriptorDatabase desc_db(channel); + grpc::protobuf::DescriptorPool desc_pool(&desc_db); + + std::vector service_list; + if (!desc_db.GetServices(&service_list)) { + fprintf(stderr, "Received an error when querying services endpoint.\n"); + return false; + } + + // If no service is specified, dump the list of services. + std::string output; + if (argc < 2) { + // List all services, if --l is passed, then include full description, + // otherwise include a summarized list only. + if (absl::GetFlag(FLAGS_l)) { + output = DescribeServiceList(service_list, desc_pool); + } else { + for (auto it = service_list.begin(); it != service_list.end(); it++) { + auto const& service = *it; + output.append(service); + output.append("\n"); + } + } + } else { + std::string service_name; + std::string method_name; + std::stringstream ss(argv[1]); + + // Remove leading slashes. + while (ss.peek() == '/') { + ss.get(); + } + + // Parse service and method names. Support the following patterns: + // Service + // Service Method + // Service.Method + // Service/Method + if (argc == 3) { + std::getline(ss, service_name, '/'); + method_name = argv[2]; + } else { + if (std::getline(ss, service_name, '/')) { + std::getline(ss, method_name); + } + } + + const grpc::protobuf::ServiceDescriptor* service = + desc_pool.FindServiceByName(service_name); + if (service != nullptr) { + if (method_name.empty()) { + output = absl::GetFlag(FLAGS_l) ? DescribeService(service) + : SummarizeService(service); + } else { + method_name.insert(0, 1, '.'); + method_name.insert(0, service_name); + const grpc::protobuf::MethodDescriptor* method = + desc_pool.FindMethodByName(method_name); + if (method != nullptr) { + output = absl::GetFlag(FLAGS_l) ? DescribeMethod(method) + : SummarizeMethod(method); + } else { + fprintf(stderr, "Method %s not found in service %s.\n", + method_name.c_str(), service_name.c_str()); + return false; + } + } + } else { + if (!method_name.empty()) { + fprintf(stderr, "Service %s not found.\n", service_name.c_str()); + return false; + } else { + const grpc::protobuf::MethodDescriptor* method = + desc_pool.FindMethodByName(service_name); + if (method != nullptr) { + output = absl::GetFlag(FLAGS_l) ? DescribeMethod(method) + : SummarizeMethod(method); + } else { + fprintf(stderr, "Service or method %s not found.\n", + service_name.c_str()); + return false; + } + } + } + } + return callback(output); +} + +bool GrpcTool::PrintType(int /*argc*/, const char** argv, + const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "Print type\n" + " grpc_cli type
\n" + "
; host:port\n" + " ; Protocol buffer type name\n" + + cred.GetCredentialUsage()); + + std::string server_address(argv[0]); + std::shared_ptr channel = + CreateCliChannel(server_address, cred); + grpc::ProtoReflectionDescriptorDatabase desc_db(channel); + grpc::protobuf::DescriptorPool desc_pool(&desc_db); + + std::string output; + const grpc::protobuf::Descriptor* descriptor = + desc_pool.FindMessageTypeByName(argv[1]); + if (descriptor != nullptr) { + output = descriptor->DebugString(); + } else { + fprintf(stderr, "Type %s not found.\n", argv[1]); + return false; + } + return callback(output); +} + +bool GrpcTool::CallMethod(int argc, const char** argv, + const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "Call method\n" + " grpc_cli call
<address> <service>[.<method>] <request>\n" + "
; host:port\n" + " ; Exported service name\n" + " ; Method name\n" + " ; Text protobuffer (overrides infile)\n" + " --protofiles ; Comma separated proto files used as a" + " fallback when parsing request/response\n" + " --proto_path ; The search paths of proto files" + " (" GRPC_CLI_PATH_SEPARATOR + " separated), valid only when --protofiles is given\n" + " --noremotedb ; Don't attempt to use reflection service" + " at all\n" + " --metadata ; The metadata to be sent to the server\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n" + " --binary_input ; Input in binary format\n" + " --binary_output ; Output in binary format\n" + " --json_input ; Input in json format\n" + " --json_output ; Output in json format\n" + " --timeout ; Specify timeout (in seconds), used to " + "set the deadline for RPCs. The default value of -1 means no " + "deadline has been set.\n" + + cred.GetCredentialUsage()); + + std::stringstream output_ss; + std::string request_text; + std::string server_address(argv[0]); + std::string method_name(argv[1]); + std::string formatted_method_name; + std::unique_ptr parser; + std::string serialized_request_proto; + CliArgs cli_args; + cli_args.timeout = absl::GetFlag(FLAGS_timeout); + bool print_mode = false; + + std::shared_ptr channel = + CreateCliChannel(server_address, cred); + + if (!absl::GetFlag(FLAGS_binary_input) || + !absl::GetFlag(FLAGS_binary_output)) { + parser = absl::make_unique( + absl::GetFlag(FLAGS_remotedb) ? channel : nullptr, + absl::GetFlag(FLAGS_proto_path), absl::GetFlag(FLAGS_protofiles)); + if (parser->HasError()) { + fprintf( + stderr, + "Failed to find remote reflection service and local proto files.\n"); + return false; + } + } + + if (absl::GetFlag(FLAGS_binary_input)) { + formatted_method_name = method_name; + } else { + formatted_method_name = parser->GetFormattedMethodName(method_name); + if (parser->HasError()) { + fprintf(stderr, "Failed to find method %s in proto files.\n", + method_name.c_str()); + } + } + + if (argc == 3) { + request_text = argv[2]; + } + + if (parser->IsStreaming(method_name, true /* is_request */)) { + std::istream* input_stream; + std::ifstream input_file; + + if (absl::GetFlag(FLAGS_batch)) { + fprintf(stderr, "Batch mode for streaming RPC is not supported.\n"); + return false; + } + + std::multimap client_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + + CliCall call(channel, formatted_method_name, client_metadata, cli_args); + if (absl::GetFlag(FLAGS_display_peer_address)) { + fprintf(stderr, "New call for method_name:%s has peer address:|%s|\n", + formatted_method_name.c_str(), call.peer().c_str()); + } + + if (absl::GetFlag(FLAGS_infile).empty()) { + if (isatty(fileno(stdin))) { + print_mode = true; + fprintf(stderr, "reading streaming request message from stdin...\n"); + } + input_stream = &std::cin; + } else { + input_file.open(absl::GetFlag(FLAGS_infile), + std::ios::in | std::ios::binary); + input_stream = &input_file; + } + + gpr_mu parser_mu; + gpr_mu_init(&parser_mu); + std::thread read_thread(ReadResponse, &call, method_name, callback, + parser.get(), &parser_mu, print_mode); + + std::stringstream request_ss; + std::string line; + while (!request_text.empty() || + (!input_stream->eof() && getline(*input_stream, line))) { + if (!request_text.empty()) { + if (absl::GetFlag(FLAGS_binary_input)) { + serialized_request_proto = request_text; + request_text.clear(); + } else { + 
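+          // Text/JSON input: the request is serialized through the shared
+          // parser, which the response-reading thread also uses, so hold
+          // parser_mu while touching it.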
gpr_mu_lock(&parser_mu); + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */, + absl::GetFlag(FLAGS_json_input)); + request_text.clear(); + if (parser->HasError()) { + if (print_mode) { + fprintf(stderr, "Failed to parse request.\n"); + } + gpr_mu_unlock(&parser_mu); + continue; + } + gpr_mu_unlock(&parser_mu); + } + + call.WriteAndWait(serialized_request_proto); + if (print_mode) { + fprintf(stderr, "Request sent.\n"); + } + } else { + if (line.length() == 0) { + request_text = request_ss.str(); + request_ss.str(std::string()); + request_ss.clear(); + } else { + request_ss << line << ' '; + } + } + } + if (input_file.is_open()) { + input_file.close(); + } + + call.WritesDoneAndWait(); + read_thread.join(); + gpr_mu_destroy(&parser_mu); + + std::multimap server_trailing_metadata; + Status status = call.Finish(&server_trailing_metadata); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + + if (status.ok()) { + fprintf(stderr, "Stream RPC succeeded with OK status\n"); + return true; + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + return false; + } + + } else { // parser->IsStreaming(method_name, true /* is_request */) + if (absl::GetFlag(FLAGS_batch)) { + if (parser->IsStreaming(method_name, false /* is_request */)) { + fprintf(stderr, "Batch mode for streaming RPC is not supported.\n"); + return false; + } + + std::istream* input_stream; + std::ifstream input_file; + + if (absl::GetFlag(FLAGS_infile).empty()) { + if (isatty(fileno(stdin))) { + print_mode = true; + fprintf(stderr, "reading request messages from stdin...\n"); + } + input_stream = &std::cin; + } else { + input_file.open(absl::GetFlag(FLAGS_infile), + std::ios::in | std::ios::binary); + input_stream = &input_file; + } + + std::multimap client_metadata; + ParseMetadataFlag(&client_metadata); + if (print_mode) { + PrintMetadata(client_metadata, "Sending client initial metadata:"); + } + + std::stringstream request_ss; + std::string line; + while (!request_text.empty() || + (!input_stream->eof() && getline(*input_stream, line))) { + if (!request_text.empty()) { + if (absl::GetFlag(FLAGS_binary_input)) { + serialized_request_proto = request_text; + request_text.clear(); + } else { + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */, + absl::GetFlag(FLAGS_json_input)); + request_text.clear(); + if (parser->HasError()) { + if (print_mode) { + fprintf(stderr, "Failed to parse request.\n"); + } + continue; + } + } + + std::string serialized_response_proto; + std::multimap + server_initial_metadata, server_trailing_metadata; + CliCall call(channel, formatted_method_name, client_metadata, + cli_args); + if (absl::GetFlag(FLAGS_display_peer_address)) { + fprintf(stderr, + "New call for method_name:%s has peer address:|%s|\n", + formatted_method_name.c_str(), call.peer().c_str()); + } + call.Write(serialized_request_proto); + call.WritesDone(); + if (!call.Read(&serialized_response_proto, + &server_initial_metadata)) { + fprintf(stderr, "Failed to read response.\n"); + } + Status status = call.Finish(&server_trailing_metadata); + + if (status.ok()) { + if (print_mode) { + fprintf(stderr, "Rpc succeeded with OK status.\n"); + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata 
from server:"); + } + + if (absl::GetFlag(FLAGS_binary_output)) { + if (!callback(serialized_response_proto)) { + break; + } + } else { + std::string response_text = parser->GetFormattedStringFromMethod( + method_name, serialized_response_proto, + false /* is_request */, absl::GetFlag(FLAGS_json_output)); + + if (parser->HasError() && print_mode) { + fprintf(stderr, "Failed to parse response.\n"); + } else { + if (!callback(response_text)) { + break; + } + } + } + } else { + if (print_mode) { + fprintf(stderr, + "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + } + } + } else { + if (line.length() == 0) { + request_text = request_ss.str(); + request_ss.str(std::string()); + request_ss.clear(); + } else { + request_ss << line << ' '; + } + } + } + + if (input_file.is_open()) { + input_file.close(); + } + + return true; + } + + if (argc == 3) { + if (!absl::GetFlag(FLAGS_infile).empty()) { + fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); + } + } else { + std::stringstream input_stream; + if (absl::GetFlag(FLAGS_infile).empty()) { + if (isatty(fileno(stdin))) { + fprintf(stderr, "reading request message from stdin...\n"); + } + input_stream << std::cin.rdbuf(); + } else { + std::ifstream input_file(absl::GetFlag(FLAGS_infile), + std::ios::in | std::ios::binary); + input_stream << input_file.rdbuf(); + input_file.close(); + } + request_text = input_stream.str(); + } + + if (absl::GetFlag(FLAGS_binary_input)) { + serialized_request_proto = request_text; + } else { + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */, + absl::GetFlag(FLAGS_json_input)); + if (parser->HasError()) { + fprintf(stderr, "Failed to parse request.\n"); + return false; + } + } + fprintf(stderr, "connecting to %s\n", server_address.c_str()); + + std::string serialized_response_proto; + std::multimap client_metadata; + std::multimap server_initial_metadata, + server_trailing_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + + CliCall call(channel, formatted_method_name, client_metadata, cli_args); + if (absl::GetFlag(FLAGS_display_peer_address)) { + fprintf(stderr, "New call for method_name:%s has peer address:|%s|\n", + formatted_method_name.c_str(), call.peer().c_str()); + } + call.Write(serialized_request_proto); + call.WritesDone(); + + for (bool receive_initial_metadata = true; call.Read( + &serialized_response_proto, + receive_initial_metadata ? 
&server_initial_metadata : nullptr); + receive_initial_metadata = false) { + if (!absl::GetFlag(FLAGS_binary_output)) { + serialized_response_proto = parser->GetFormattedStringFromMethod( + method_name, serialized_response_proto, false /* is_request */, + absl::GetFlag(FLAGS_json_output)); + if (parser->HasError()) { + fprintf(stderr, "Failed to parse response.\n"); + return false; + } + } + + if (receive_initial_metadata) { + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + } + if (!callback(serialized_response_proto)) { + return false; + } + } + Status status = call.Finish(&server_trailing_metadata); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + if (status.ok()) { + fprintf(stderr, "Rpc succeeded with OK status\n"); + return true; + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + return false; + } + } + GPR_UNREACHABLE_CODE(return false); +} + +bool GrpcTool::ParseMessage(int argc, const char** argv, + const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "Parse message\n" + " grpc_cli parse
<address> <type> [<message>]\n" + "
; host:port\n" + " ; Protocol buffer type name\n" + " ; Text protobuffer (overrides --infile)\n" + " --protofiles ; Comma separated proto files used as a" + " fallback when parsing request/response\n" + " --proto_path ; The search paths of proto files" + " (" GRPC_CLI_PATH_SEPARATOR + " separated), valid only when --protofiles is given\n" + " --noremotedb ; Don't attempt to use reflection service" + " at all\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n" + " --binary_input ; Input in binary format\n" + " --binary_output ; Output in binary format\n" + " --json_input ; Input in json format\n" + " --json_output ; Output in json format\n" + + cred.GetCredentialUsage()); + + std::stringstream output_ss; + std::string message_text; + std::string server_address(argv[0]); + std::string type_name(argv[1]); + std::unique_ptr parser; + std::string serialized_request_proto; + + if (argc == 3) { + message_text = argv[2]; + if (!absl::GetFlag(FLAGS_infile).empty()) { + fprintf(stderr, "warning: message given in argv, ignoring --infile.\n"); + } + } else { + std::stringstream input_stream; + if (absl::GetFlag(FLAGS_infile).empty()) { + if (isatty(fileno(stdin))) { + fprintf(stderr, "reading request message from stdin...\n"); + } + input_stream << std::cin.rdbuf(); + } else { + std::ifstream input_file(absl::GetFlag(FLAGS_infile), + std::ios::in | std::ios::binary); + input_stream << input_file.rdbuf(); + input_file.close(); + } + message_text = input_stream.str(); + } + + if (!absl::GetFlag(FLAGS_binary_input) || + !absl::GetFlag(FLAGS_binary_output)) { + std::shared_ptr channel = + CreateCliChannel(server_address, cred); + parser = absl::make_unique( + absl::GetFlag(FLAGS_remotedb) ? channel : nullptr, + absl::GetFlag(FLAGS_proto_path), absl::GetFlag(FLAGS_protofiles)); + if (parser->HasError()) { + fprintf( + stderr, + "Failed to find remote reflection service and local proto files.\n"); + return false; + } + } + + if (absl::GetFlag(FLAGS_binary_input)) { + serialized_request_proto = message_text; + } else { + serialized_request_proto = parser->GetSerializedProtoFromMessageType( + type_name, message_text, absl::GetFlag(FLAGS_json_input)); + if (parser->HasError()) { + fprintf(stderr, "Failed to serialize the message.\n"); + return false; + } + } + + if (absl::GetFlag(FLAGS_binary_output)) { + output_ss << serialized_request_proto; + } else { + std::string output_text; + output_text = parser->GetFormattedStringFromMessageType( + type_name, serialized_request_proto, absl::GetFlag(FLAGS_json_output)); + if (parser->HasError()) { + fprintf(stderr, "Failed to deserialize the message.\n"); + return false; + } + + output_ss << output_text << std::endl; + } + + return callback(output_ss.str()); +} + +bool GrpcTool::ToText(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "Convert binary message to text\n" + " grpc_cli totext \n" + " ; Comma separated list of proto files\n" + " ; Protocol buffer type name\n" + " --proto_path ; The search paths of proto files" + " (" GRPC_CLI_PATH_SEPARATOR + " separated)\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n"); + + absl::SetFlag(&FLAGS_protofiles, argv[0]); + absl::SetFlag(&FLAGS_remotedb, false); + absl::SetFlag(&FLAGS_binary_input, true); + absl::SetFlag(&FLAGS_binary_output, false); + return ParseMessage(argc, argv, cred, callback); +} + +bool GrpcTool::ToJson(int argc, 
const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "Convert binary message to json\n" + " grpc_cli tojson \n" + " ; Comma separated list of proto files\n" + " ; Protocol buffer type name\n" + " --proto_path ; The search paths of proto files" + " (" GRPC_CLI_PATH_SEPARATOR + " separated)\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n"); + + absl::SetFlag(&FLAGS_protofiles, argv[0]); + absl::SetFlag(&FLAGS_remotedb, false); + absl::SetFlag(&FLAGS_binary_input, true); + absl::SetFlag(&FLAGS_binary_output, false); + absl::SetFlag(&FLAGS_json_output, true); + return ParseMessage(argc, argv, cred, callback); +} + +bool GrpcTool::ToBinary(int argc, const char** argv, const CliCredentials& cred, + const GrpcToolOutputCallback& callback) { + CommandUsage( + "Convert text message to binary\n" + " grpc_cli tobinary []\n" + " ; Comma separated list of proto files\n" + " ; Protocol buffer type name\n" + " --proto_path ; The search paths of proto files" + " (" GRPC_CLI_PATH_SEPARATOR + " separated)\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n"); + + absl::SetFlag(&FLAGS_protofiles, argv[0]); + absl::SetFlag(&FLAGS_remotedb, false); + absl::SetFlag(&FLAGS_binary_input, false); + absl::SetFlag(&FLAGS_binary_output, true); + return ParseMessage(argc, argv, cred, callback); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/grpc_tool_test.cc b/test/cpp/util/grpc_tool_test.cc new file mode 100644 index 00000000..9dec8a22 --- /dev/null +++ b/test/cpp/util/grpc_tool_test.cc @@ -0,0 +1,1353 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "test/cpp/util/grpc_tool.h" + +#include +#include + +#include + +#include "absl/flags/declare.h" +#include "absl/flags/flag.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/echo.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/cli_credentials.h" +#include "test/cpp/util/string_ref_helper.h" +#include "test/cpp/util/test_config.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +#define USAGE_REGEX "( grpc_cli .+\n){2,10}" + +#define ECHO_TEST_SERVICE_SUMMARY \ + "Echo\n" \ + "Echo1\n" \ + "Echo2\n" \ + "CheckDeadlineUpperBound\n" \ + "CheckDeadlineSet\n" \ + "CheckClientInitialMetadata\n" \ + "RequestStream\n" \ + "ResponseStream\n" \ + "BidiStream\n" \ + "Unimplemented\n" \ + "UnimplementedBidi\n" + +#define ECHO_TEST_SERVICE_DESCRIPTION \ + "filename: src/proto/grpc/testing/echo.proto\n" \ + "package: grpc.testing;\n" \ + "service EchoTestService {\n" \ + " rpc Echo(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" \ + " rpc Echo1(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" \ + " rpc Echo2(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" \ + " rpc CheckDeadlineUpperBound(grpc.testing.SimpleRequest) returns " \ + "(grpc.testing.StringValue) {}\n" \ + " rpc CheckDeadlineSet(grpc.testing.SimpleRequest) returns " \ + "(grpc.testing.StringValue) {}\n" \ + " rpc CheckClientInitialMetadata(grpc.testing.SimpleRequest) returns " \ + "(grpc.testing.SimpleResponse) {}\n" \ + " rpc RequestStream(stream grpc.testing.EchoRequest) returns " \ + "(grpc.testing.EchoResponse) {}\n" \ + " rpc ResponseStream(grpc.testing.EchoRequest) returns (stream " \ + "grpc.testing.EchoResponse) {}\n" \ + " rpc BidiStream(stream grpc.testing.EchoRequest) returns (stream " \ + "grpc.testing.EchoResponse) {}\n" \ + " rpc Unimplemented(grpc.testing.EchoRequest) returns " \ + "(grpc.testing.EchoResponse) {}\n" \ + " rpc UnimplementedBidi(stream grpc.testing.EchoRequest) returns (stream " \ + "grpc.testing.EchoResponse) {}\n" \ + "}\n" \ + "\n" + +#define ECHO_METHOD_DESCRIPTION \ + " rpc Echo(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" + +#define ECHO_RESPONSE_MESSAGE_TEXT_FORMAT \ + "message: \"echo\"\n" \ + "param {\n" \ + " host: \"localhost\"\n" \ + " peer: \"peer\"\n" \ + "}\n\n" + +#define ECHO_RESPONSE_MESSAGE_JSON_FORMAT \ + "{\n" \ + " \"message\": \"echo\",\n" \ + " \"param\": {\n" \ + " \"host\": \"localhost\",\n" \ + " \"peer\": \"peer\"\n" \ + " }\n" \ + "}\n\n" + +ABSL_DECLARE_FLAG(std::string, channel_creds_type); +ABSL_DECLARE_FLAG(std::string, ssl_target); +ABSL_DECLARE_FLAG(bool, binary_input); +ABSL_DECLARE_FLAG(bool, binary_output); +ABSL_DECLARE_FLAG(bool, json_input); +ABSL_DECLARE_FLAG(bool, json_output); +ABSL_DECLARE_FLAG(bool, l); +ABSL_DECLARE_FLAG(bool, batch); +ABSL_DECLARE_FLAG(std::string, metadata); +ABSL_DECLARE_FLAG(std::string, protofiles); +ABSL_DECLARE_FLAG(std::string, proto_path); +ABSL_DECLARE_FLAG(std::string, default_service_config); +ABSL_DECLARE_FLAG(double, timeout); + +namespace 
grpc { +namespace testing { +namespace { + +const int kServerDefaultResponseStreamsToSend = 3; + +class TestCliCredentials final : public grpc::testing::CliCredentials { + public: + explicit TestCliCredentials(bool secure = false) : secure_(secure) {} + std::shared_ptr GetChannelCredentials() + const override { + if (!secure_) { + return InsecureChannelCredentials(); + } + grpc_slice ca_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + const char* test_root_cert = + reinterpret_cast GRPC_SLICE_START_PTR(ca_slice); + SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; + std::shared_ptr credential_ptr = + grpc::SslCredentials(grpc::SslCredentialsOptions(ssl_opts)); + grpc_slice_unref(ca_slice); + return credential_ptr; + } + std::string GetCredentialUsage() const override { return ""; } + + private: + const bool secure_; +}; + +bool PrintStream(std::stringstream* ss, const std::string& output) { + (*ss) << output; + return true; +} + +template +size_t ArraySize(T& a) { + return ((sizeof(a) / sizeof(*(a))) / + static_cast(!(sizeof(a) % sizeof(*(a))))); +} + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + if (!context->client_metadata().empty()) { + for (std::multimap::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + response->set_message(request->message()); + return Status::OK; + } + + Status CheckDeadlineSet(ServerContext* context, + const SimpleRequest* /*request*/, + StringValue* response) override { + response->set_message(context->deadline() != + std::chrono::system_clock::time_point::max() + ? "true" + : "false"); + return Status::OK; + } + + // Check if deadline - current time <= timeout + // If deadline set, timeout + current time should be an upper bound for it + Status CheckDeadlineUpperBound(ServerContext* context, + const SimpleRequest* /*request*/, + StringValue* response) override { + auto seconds = std::chrono::duration_cast( + context->deadline() - std::chrono::system_clock::now()); + + // Returning string instead of bool to avoid using embedded messages in + // proto3 + response->set_message( + seconds.count() <= absl::GetFlag(FLAGS_timeout) ? 
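+        // "true" iff the deadline observed by the server is no later than
+        // now + --timeout.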
"true" : "false"); + return Status::OK; + } + + Status RequestStream(ServerContext* context, + ServerReader* reader, + EchoResponse* response) override { + EchoRequest request; + response->set_message(""); + if (!context->client_metadata().empty()) { + for (std::multimap::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + while (reader->Read(&request)) { + response->mutable_message()->append(request.message()); + } + + return Status::OK; + } + + Status ResponseStream(ServerContext* context, const EchoRequest* request, + ServerWriter* writer) override { + if (!context->client_metadata().empty()) { + for (std::multimap::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + + EchoResponse response; + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + response.set_message(request->message() + std::to_string(i)); + writer->Write(response); + } + + return Status::OK; + } + + Status BidiStream( + ServerContext* context, + ServerReaderWriter* stream) override { + EchoRequest request; + EchoResponse response; + if (!context->client_metadata().empty()) { + for (std::multimap::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + + while (stream->Read(&request)) { + response.set_message(request.message()); + stream->Write(response); + } + + return Status::OK; + } +}; + +} // namespace + +class GrpcToolTest : public ::testing::Test { + protected: + GrpcToolTest() {} + + // SetUpServer cannot be used with EXPECT_EXIT. grpc_pick_unused_port_or_die() + // uses atexit() to free chosen ports, and it will spawn a new thread in + // resolve_address_posix.c:192 at exit time. 
+ std::string SetUpServer(bool secure = false) { + std::ostringstream server_address; + int port = grpc_pick_unused_port_or_die(); + server_address << "localhost:" << port; + // Setup server + ServerBuilder builder; + std::shared_ptr creds; + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast GRPC_SLICE_START_PTR(key_slice); + SslServerCredentialsOptions::PemKeyCertPair pkcp = {server_key, + server_cert}; + if (secure) { + SslServerCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = ""; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + creds = SslServerCredentials(ssl_opts); + } else { + creds = InsecureServerCredentials(); + } + builder.AddListeningPort(server_address.str(), creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + return server_address.str(); + } + + void ShutdownServer() { server_->Shutdown(); } + + std::unique_ptr server_; + TestServiceImpl service_; + reflection::ProtoServerReflectionPlugin plugin_; +}; + +TEST_F(GrpcToolTest, NoCommand) { + // Test input "grpc_cli" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli"}; + // Exit with 1, print usage instruction in stderr + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), "No command specified\n" USAGE_REGEX); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, InvalidCommand) { + // Test input "grpc_cli" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "abc"}; + // Exit with 1, print usage instruction in stderr + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), "Invalid command 'abc'\n" USAGE_REGEX); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, HelpCommand) { + // Test input "grpc_cli help" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "help"}; + // Exit with 1, print usage instruction in stderr + EXPECT_EXIT(GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1)), + ::testing::ExitedWithCode(1), USAGE_REGEX); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, ListCommand) { + // Test input "grpc_cli list localhost:" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; + + absl::SetFlag(&FLAGS_l, false); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + "grpc.testing.EchoTestService\n" + "grpc.reflection.v1alpha.ServerReflection\n")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ListOneService) { + // Test input "grpc_cli list localhost: grpc.testing.EchoTestService" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = 
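+      // Simulated command line:
+      //   grpc_cli ls <server_address> grpc.testing.EchoTestService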
{"grpc_cli", "ls", server_address.c_str(), + "grpc.testing.EchoTestService"}; + // without -l flag + absl::SetFlag(&FLAGS_l, false); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_TEST_SERVICE_SUMMARY + EXPECT_TRUE(0 == + strcmp(output_stream.str().c_str(), ECHO_TEST_SERVICE_SUMMARY)); + + // with -l flag + output_stream.str(std::string()); + output_stream.clear(); + absl::SetFlag(&FLAGS_l, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_TEST_SERVICE_DESCRIPTION + EXPECT_TRUE( + 0 == strcmp(output_stream.str().c_str(), ECHO_TEST_SERVICE_DESCRIPTION)); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, TypeCommand) { + // Test input "grpc_cli type localhost: grpc.testing.EchoRequest" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "type", server_address.c_str(), + "grpc.testing.EchoRequest"}; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + const grpc::protobuf::Descriptor* desc = + grpc::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "grpc.testing.EchoRequest"); + // Expected output: the DebugString of grpc.testing.EchoRequest + EXPECT_TRUE(0 == + strcmp(output_stream.str().c_str(), desc->DebugString().c_str())); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ListOneMethod) { + // Test input "grpc_cli list localhost: grpc.testing.EchoTestService" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "ls", server_address.c_str(), + "grpc.testing.EchoTestService.Echo"}; + // without -l flag + absl::SetFlag(&FLAGS_l, false); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "Echo" + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), "Echo\n")); + + // with -l flag + output_stream.str(std::string()); + output_stream.clear(); + absl::SetFlag(&FLAGS_l, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_METHOD_DESCRIPTION + EXPECT_TRUE(0 == + strcmp(output_stream.str().c_str(), ECHO_METHOD_DESCRIPTION)); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, TypeNotFound) { + // Test input "grpc_cli type localhost: grpc.testing.PhonyRequest" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "type", server_address.c_str(), + "grpc.testing.PhonyRequest"}; + + EXPECT_TRUE(1 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommand) { + // Test input "grpc_cli call localhost: Echo "message: 'Hello'" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello'"}; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + 
std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + + // TODO(Capstan): Consider using absl::FlagSaver + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + + // Expected output: + // { + // "message": "Hello" + // } + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello\"\n}")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandJsonInput) { + // Test input "grpc_cli call localhost: Echo "{ \"message\": \"Hello\"}" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "{ \"message\": \"Hello\"}"}; + + absl::SetFlag(&FLAGS_json_input, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + absl::SetFlag(&FLAGS_json_input, false); + + // Expected output: + // { + // "message": "Hello" + // } + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello\"\n}")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatch) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_batch, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_batch, false); + + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: " + "\"Hello1\"\nmessage: \"Hello2\"\n")); + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_batch, true); + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + absl::SetFlag(&FLAGS_batch, false); + + // Expected output: + // { + // "message": "Hello0" + // } + // { + // "message": "Hello1" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr 
!= strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello1\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatchJsonInput) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "{\"message\": \"Hello0\"}"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{\"message\": \"Hello1\"}\n\n{\"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_json_input, true); + absl::SetFlag(&FLAGS_batch, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_batch, false); + + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: " + "\"Hello1\"\nmessage: \"Hello2\"\n")); + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_batch, true); + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + absl::SetFlag(&FLAGS_batch, false); + absl::SetFlag(&FLAGS_json_input, false); + + // Expected output: + // { + // "message": "Hello0" + // } + // { + // "message": "Hello1" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello1\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatchWithBadRequest) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello0'"}; + + // Mock std::cin input "message: 1\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 1\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_batch, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_batch, false); + + // Expected output: "message: "Hello0"\nmessage: "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: \"Hello2\"\n")); + + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_batch, true); + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + absl::SetFlag(&FLAGS_batch, false); + + // Expected output: + // { + // 
"message": "Hello0" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatchJsonInputWithBadRequest) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "{ \"message\": \"Hello0\"}"}; + + // Mock std::cin input "message: 1\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{ \"message\": 1 }\n\n { \"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_batch, true); + absl::SetFlag(&FLAGS_json_input, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_input, false); + absl::SetFlag(&FLAGS_batch, false); + + // Expected output: "message: "Hello0"\nmessage: "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: \"Hello2\"\n")); + + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_batch, true); + absl::SetFlag(&FLAGS_json_input, true); + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + absl::SetFlag(&FLAGS_json_input, false); + absl::SetFlag(&FLAGS_batch, false); + + // Expected output: + // { + // "message": "Hello0" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStream) { + // Test input: grpc_cli call localhost: RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0Hello1Hello2\"" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0Hello1Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStreamJsonInput) { + // Test input: grpc_cli call localhost: RequestStream "{ \"message\": + // \"Hello0\"}" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "{ \"message\": \"Hello0\" }"}; + + // Mock 
std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{ \"message\": \"Hello1\" }\n\n{ \"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_json_input, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_input, false); + + // Expected output: "message: \"Hello0Hello1Hello2\"" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0Hello1Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStreamWithBadRequest) { + // Test input: grpc_cli call localhost: RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "message: 'Hello0'"}; + + // Mock std::cin input "bad_field: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("bad_field: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0Hello2\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello0Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStreamWithBadRequestJsonInput) { + // Test input: grpc_cli call localhost: RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "{ \"message\": \"Hello0\" }"}; + + // Mock std::cin input "bad_field: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{ \"bad_field\": \"Hello1\" }\n\n{ \"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + absl::SetFlag(&FLAGS_json_input, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_input, false); + + // Expected output: "message: \"Hello0Hello2\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello0Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithTimeoutDeadlineSet) { + // Test input "grpc_cli call CheckDeadlineSet --timeout=5000.25" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineSet"}; + + // Set timeout to 5000.25 seconds + absl::SetFlag(&FLAGS_timeout, 5000.25); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "true"", deadline set + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"true\"")); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithTimeoutDeadlineUpperBound) { + // Test input "grpc_cli call CheckDeadlineUpperBound --timeout=900" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + 
const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineUpperBound"}; + + // Set timeout to 900 seconds + absl::SetFlag(&FLAGS_timeout, 900); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "true"" + // deadline not greater than timeout + current time + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"true\"")); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithNegativeTimeoutValue) { + // Test input "grpc_cli call CheckDeadlineSet --timeout=-5" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineSet"}; + + // Set timeout to -5 (deadline not set) + absl::SetFlag(&FLAGS_timeout, -5); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "false"", deadline not set + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"false\"")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithDefaultTimeoutValue) { + // Test input "grpc_cli call CheckDeadlineSet --timeout=-1" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineSet"}; + + // Set timeout to -1 (default value, deadline not set) + absl::SetFlag(&FLAGS_timeout, -1); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "false"", deadline not set + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"false\"")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandResponseStream) { + // Test input: grpc_cli call localhost: ResponseStream "message: + // 'Hello'" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "ResponseStream", "message: 'Hello'"}; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello{n}\"" + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + std::string expected_response_text = + "message: \"Hello" + std::to_string(i) + "\"\n"; + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + expected_response_text.c_str())); + } + + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + + // Expected output: "{\n \"message\": \"Hello{n}\"\n}\n" + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + std::string expected_response_text = + "{\n \"message\": \"Hello" + std::to_string(i) + "\"\n}\n"; + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + expected_response_text.c_str())); + } + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBidiStream) { + // Test input: grpc_cli call localhost: BidiStream "message: 'Hello0'" + 
std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "BidiStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0\"\nmessage: \"Hello1\"\nmessage: + // \"Hello2\"\n\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: " + "\"Hello1\"\nmessage: \"Hello2\"\n")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBidiStreamWithBadRequest) { + // Test input: grpc_cli call localhost: BidiStream "message: 'Hello0'" + std::stringstream output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "BidiStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 1.0\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0\"\nmessage: \"Hello1\"\nmessage: + // \"Hello2\"\n\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: \"Hello2\"\n")); + std::cin.rdbuf(orig); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ParseCommand) { + // Test input "grpc_cli parse localhost: grpc.testing.EchoResponse + // ECHO_RESPONSE_MESSAGE" + std::stringstream output_stream; + std::stringstream binary_output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "parse", server_address.c_str(), + "grpc.testing.EchoResponse", + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT}; + + absl::SetFlag(&FLAGS_binary_input, false); + absl::SetFlag(&FLAGS_binary_output, false); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_RESPONSE_MESSAGE_TEXT_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT)); + + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + + // Expected output: ECHO_RESPONSE_MESSAGE_JSON_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_JSON_FORMAT)); + + // Parse text message to binary message and then parse it back to text message + output_stream.str(std::string()); + output_stream.clear(); + absl::SetFlag(&FLAGS_binary_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + std::string binary_data = output_stream.str(); + output_stream.str(std::string()); + output_stream.clear(); + argv[4] = binary_data.c_str(); + 
absl::SetFlag(&FLAGS_binary_input, true); + absl::SetFlag(&FLAGS_binary_output, false); + EXPECT_TRUE(0 == GrpcToolMainLib(5, argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: ECHO_RESPONSE_MESSAGE + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT)); + + absl::SetFlag(&FLAGS_binary_input, false); + absl::SetFlag(&FLAGS_binary_output, false); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ParseCommandJsonFormat) { + // Test input "grpc_cli parse localhost: grpc.testing.EchoResponse + // ECHO_RESPONSE_MESSAGE_JSON_FORMAT" + std::stringstream output_stream; + std::stringstream binary_output_stream; + + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "parse", server_address.c_str(), + "grpc.testing.EchoResponse", + ECHO_RESPONSE_MESSAGE_JSON_FORMAT}; + + absl::SetFlag(&FLAGS_json_input, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: ECHO_RESPONSE_MESSAGE_TEXT_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT)); + + // with json_output + output_stream.str(std::string()); + output_stream.clear(); + + absl::SetFlag(&FLAGS_json_output, true); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_json_output, false); + absl::SetFlag(&FLAGS_json_input, false); + + // Expected output: ECHO_RESPONSE_MESSAGE_JSON_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_JSON_FORMAT)); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, TooFewArguments) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "call", "Echo"}; + + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Wrong number of arguments for call.*"); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, TooManyArguments) { + // Test input "grpc_cli call localhost: Echo Echo "message: 'Hello'" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "call", "localhost:10000", + "Echo", "Echo", "message: 'Hello'"}; + + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Wrong number of arguments for call.*"); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, CallCommandWithMetadata) { + // Test input "grpc_cli call localhost: Echo "message: 'Hello'" + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello'"}; + + { + std::stringstream output_stream; + absl::SetFlag(&FLAGS_metadata, "key0:val0:key1:valq:key2:val2"); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, + TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + } + + { + std::stringstream output_stream; + absl::SetFlag(&FLAGS_metadata, 
"key:val\\:val"); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, + TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + } + + { + std::stringstream output_stream; + absl::SetFlag(&FLAGS_metadata, "key:val\\\\val"); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, + TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + } + + absl::SetFlag(&FLAGS_metadata, ""); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithBadMetadata) { + // Test input "grpc_cli call localhost:10000 Echo "message: 'Hello'" + const char* argv[] = {"grpc_cli", "call", "localhost:10000", + "grpc.testing.EchoTestService.Echo", + "message: 'Hello'"}; + absl::SetFlag(&FLAGS_protofiles, "src/proto/grpc/testing/echo.proto"); + char* test_srcdir = gpr_getenv("TEST_SRCDIR"); + if (test_srcdir != nullptr) { + absl::SetFlag(&FLAGS_proto_path, + test_srcdir + std::string("/com_github_grpc_grpc")); + } + + { + std::stringstream output_stream; + absl::SetFlag(&FLAGS_metadata, "key0:val0:key1"); + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Failed to parse metadata flag.*"); + } + + { + std::stringstream output_stream; + absl::SetFlag(&FLAGS_metadata, "key:val\\val"); + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Failed to parse metadata flag.*"); + } + + absl::SetFlag(&FLAGS_metadata, ""); + absl::SetFlag(&FLAGS_protofiles, ""); + + gpr_free(test_srcdir); +} + +TEST_F(GrpcToolTest, ListCommand_OverrideSslHostName) { + const std::string server_address = SetUpServer(true); + + // Test input "grpc_cli ls localhost: --channel_creds_type=ssl + // --ssl_target=z.test.google.fr" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; + absl::SetFlag(&FLAGS_l, false); + absl::SetFlag(&FLAGS_channel_creds_type, "ssl"); + absl::SetFlag(&FLAGS_ssl_target, "z.test.google.fr"); + EXPECT_TRUE( + 0 == GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(true), + std::bind(PrintStream, &output_stream, std::placeholders::_1))); + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + "grpc.testing.EchoTestService\n" + "grpc.reflection.v1alpha.ServerReflection\n")); + + absl::SetFlag(&FLAGS_channel_creds_type, ""); + absl::SetFlag(&FLAGS_ssl_target, ""); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ConfiguringDefaultServiceConfig) { + // Test input "grpc_cli list localhost: + // --default_service_config={\"loadBalancingConfig\":[{\"pick_first\":{}}]}" + std::stringstream output_stream; + const std::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; + // Just check that the tool is still operational when --default_service_config + // is configured. This particular service config is in reality redundant with + // the channel's default configuration. 
+ absl::SetFlag(&FLAGS_l, false); + absl::SetFlag(&FLAGS_default_service_config, + "{\"loadBalancingConfig\":[{\"pick_first\":{}}]}"); + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + absl::SetFlag(&FLAGS_default_service_config, ""); + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + "grpc.testing.EchoTestService\n" + "grpc.reflection.v1alpha.ServerReflection\n")); + ShutdownServer(); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + GRPC_GTEST_FLAG_SET_DEATH_TEST_STYLE("threadsafe"); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/util/metrics_server.cc b/test/cpp/util/metrics_server.cc new file mode 100644 index 00000000..9504c700 --- /dev/null +++ b/test/cpp/util/metrics_server.cc @@ -0,0 +1,117 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *is % allowed in string + */ + +#include "test/cpp/util/metrics_server.h" + +#include +#include +#include + +#include "src/proto/grpc/testing/metrics.grpc.pb.h" +#include "src/proto/grpc/testing/metrics.pb.h" + +namespace grpc { +namespace testing { + +QpsGauge::QpsGauge() + : start_time_(gpr_now(GPR_CLOCK_REALTIME)), num_queries_(0) {} + +void QpsGauge::Reset() { + std::lock_guard lock(num_queries_mu_); + num_queries_ = 0; + start_time_ = gpr_now(GPR_CLOCK_REALTIME); +} + +void QpsGauge::Incr() { + std::lock_guard lock(num_queries_mu_); + num_queries_++; +} + +long QpsGauge::Get() { + std::lock_guard lock(num_queries_mu_); + gpr_timespec time_diff = + gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), start_time_); + long duration_secs = time_diff.tv_sec > 0 ? time_diff.tv_sec : 1; + return num_queries_ / duration_secs; +} + +grpc::Status MetricsServiceImpl::GetAllGauges( + ServerContext* /*context*/, const EmptyMessage* /*request*/, + ServerWriter* writer) { + gpr_log(GPR_DEBUG, "GetAllGauges called"); + + std::lock_guard lock(mu_); + for (auto it = qps_gauges_.begin(); it != qps_gauges_.end(); it++) { + GaugeResponse resp; + resp.set_name(it->first); // Gauge name + resp.set_long_value(it->second->Get()); // Gauge value + writer->Write(resp); + } + + return Status::OK; +} + +grpc::Status MetricsServiceImpl::GetGauge(ServerContext* /*context*/, + const GaugeRequest* request, + GaugeResponse* response) { + std::lock_guard lock(mu_); + + const auto it = qps_gauges_.find(request->name()); + if (it != qps_gauges_.end()) { + response->set_name(it->first); + response->set_long_value(it->second->Get()); + } + + return Status::OK; +} + +std::shared_ptr MetricsServiceImpl::CreateQpsGauge( + const std::string& name, bool* already_present) { + std::lock_guard lock(mu_); + + std::shared_ptr qps_gauge(new QpsGauge()); + const auto p = qps_gauges_.insert(std::make_pair(name, qps_gauge)); + + // p.first is an iterator pointing to > pair. 
+ // p.second is a boolean which is set to 'true' if the QpsGauge is + // successfully inserted in the guages_ map and 'false' if it is already + // present in the map + *already_present = !p.second; + return p.first->second; +} + +// Starts the metrics server and returns the grpc::Server instance. Call Wait() +// on the returned server instance. +std::unique_ptr MetricsServiceImpl::StartServer(int port) { + gpr_log(GPR_INFO, "Building metrics server.."); + + const std::string address = "0.0.0.0:" + std::to_string(port); + + ServerBuilder builder; + builder.AddListeningPort(address, grpc::InsecureServerCredentials()); + builder.RegisterService(this); + + std::unique_ptr server(builder.BuildAndStart()); + gpr_log(GPR_INFO, "Metrics server %s started. Ready to receive requests..", + address.c_str()); + + return server; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/proto_file_parser.cc b/test/cpp/util/proto_file_parser.cc new file mode 100644 index 00000000..7a977f31 --- /dev/null +++ b/test/cpp/util/proto_file_parser.cc @@ -0,0 +1,332 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/proto_file_parser.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_split.h" + +#include + +namespace grpc { +namespace testing { +namespace { + +// Match the user input method string to the full_name from method descriptor. 
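The MethodNameMatch helper defined just below implements the lookup rule the CLI uses: the user-supplied method string, with '/' normalized to '.', must be a suffix of the descriptor's full name. The following small self-contained check of that rule is illustrative only; the function name and test cases are not part of the patch.

```cpp
// Illustrative check of the suffix-matching rule used by MethodNameMatch.
#include <algorithm>
#include <cassert>
#include <string>

static bool SuffixMatch(const std::string& full_name, std::string input) {
  // Accept both '/' and '.' separators in user input.
  std::replace(input.begin(), input.end(), '/', '.');
  return input.size() <= full_name.size() &&
         full_name.compare(full_name.size() - input.size(), input.size(),
                           input) == 0;
}

int main() {
  const std::string full = "grpc.testing.EchoTestService.Echo";
  assert(SuffixMatch(full, "Echo"));
  assert(SuffixMatch(full, "EchoTestService/Echo"));
  assert(SuffixMatch(full, "grpc.testing.EchoTestService.Echo"));
  assert(!SuffixMatch(full, "OtherService.Echo"));
  return 0;
}
```

This is why the grpc_tool tests above can refer to the method as plain "Echo", while GetFullMethodName still reports an ambiguity error if more than one registered method matches the same input.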
+bool MethodNameMatch(const std::string& full_name, const std::string& input) { + std::string clean_input = input; + std::replace(clean_input.begin(), clean_input.end(), '/', '.'); + if (clean_input.size() > full_name.size()) { + return false; + } + return full_name.compare(full_name.size() - clean_input.size(), + clean_input.size(), clean_input) == 0; +} +} // namespace + +class ErrorPrinter : public protobuf::compiler::MultiFileErrorCollector { + public: + explicit ErrorPrinter(ProtoFileParser* parser) : parser_(parser) {} + + void AddError(const std::string& filename, int line, int column, + const std::string& message) override { + std::ostringstream oss; + oss << "error " << filename << " " << line << " " << column << " " + << message << "\n"; + parser_->LogError(oss.str()); + } + + void AddWarning(const std::string& filename, int line, int column, + const std::string& message) override { + std::cerr << "warning " << filename << " " << line << " " << column << " " + << message << std::endl; + } + + private: + ProtoFileParser* parser_; // not owned +}; + +ProtoFileParser::ProtoFileParser(const std::shared_ptr& channel, + const std::string& proto_path, + const std::string& protofiles) + : has_error_(false), + dynamic_factory_(new protobuf::DynamicMessageFactory()) { + std::vector service_list; + if (channel) { + reflection_db_ = + absl::make_unique(channel); + reflection_db_->GetServices(&service_list); + } + + std::unordered_set known_services; + if (!protofiles.empty()) { + for (const absl::string_view single_path : absl::StrSplit( + proto_path, GRPC_CLI_PATH_SEPARATOR, absl::AllowEmpty())) { + source_tree_.MapPath("", std::string(single_path)); + } + error_printer_ = absl::make_unique(this); + importer_ = absl::make_unique( + &source_tree_, error_printer_.get()); + + std::string file_name; + std::stringstream ss(protofiles); + while (std::getline(ss, file_name, ',')) { + const auto* file_desc = importer_->Import(file_name); + if (file_desc) { + for (int i = 0; i < file_desc->service_count(); i++) { + service_desc_list_.push_back(file_desc->service(i)); + known_services.insert(file_desc->service(i)->full_name()); + } + } else { + std::cerr << file_name << " not found" << std::endl; + } + } + + file_db_ = + absl::make_unique(*importer_->pool()); + } + + if (!reflection_db_ && !file_db_) { + LogError("No available proto database"); + return; + } + + if (!reflection_db_) { + desc_db_ = std::move(file_db_); + } else if (!file_db_) { + desc_db_ = std::move(reflection_db_); + } else { + desc_db_ = absl::make_unique( + reflection_db_.get(), file_db_.get()); + } + + desc_pool_ = absl::make_unique(desc_db_.get()); + + for (auto it = service_list.begin(); it != service_list.end(); it++) { + if (known_services.find(*it) == known_services.end()) { + if (const protobuf::ServiceDescriptor* service_desc = + desc_pool_->FindServiceByName(*it)) { + service_desc_list_.push_back(service_desc); + known_services.insert(*it); + } + } + } +} + +ProtoFileParser::~ProtoFileParser() {} + +std::string ProtoFileParser::GetFullMethodName(const std::string& method) { + has_error_ = false; + + if (known_methods_.find(method) != known_methods_.end()) { + return known_methods_[method]; + } + + const protobuf::MethodDescriptor* method_descriptor = nullptr; + for (auto it = service_desc_list_.begin(); it != service_desc_list_.end(); + it++) { + const auto* service_desc = *it; + for (int j = 0; j < service_desc->method_count(); j++) { + const auto* method_desc = service_desc->method(j); + if 
(MethodNameMatch(method_desc->full_name(), method)) { + if (method_descriptor) { + std::ostringstream error_stream; + error_stream << "Ambiguous method names: "; + error_stream << method_descriptor->full_name() << " "; + error_stream << method_desc->full_name(); + LogError(error_stream.str()); + } + method_descriptor = method_desc; + } + } + } + if (!method_descriptor) { + LogError("Method name not found"); + } + if (has_error_) { + return ""; + } + + known_methods_[method] = method_descriptor->full_name(); + + return method_descriptor->full_name(); +} + +std::string ProtoFileParser::GetFormattedMethodName(const std::string& method) { + has_error_ = false; + std::string formatted_method_name = GetFullMethodName(method); + if (has_error_) { + return ""; + } + size_t last_dot = formatted_method_name.find_last_of('.'); + if (last_dot != std::string::npos) { + formatted_method_name[last_dot] = '/'; + } + formatted_method_name.insert(formatted_method_name.begin(), '/'); + return formatted_method_name; +} + +std::string ProtoFileParser::GetMessageTypeFromMethod(const std::string& method, + bool is_request) { + has_error_ = false; + std::string full_method_name = GetFullMethodName(method); + if (has_error_) { + return ""; + } + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(full_method_name); + if (!method_desc) { + LogError("Method not found"); + return ""; + } + + return is_request ? method_desc->input_type()->full_name() + : method_desc->output_type()->full_name(); +} + +bool ProtoFileParser::IsStreaming(const std::string& method, bool is_request) { + has_error_ = false; + + std::string full_method_name = GetFullMethodName(method); + if (has_error_) { + return false; + } + + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(full_method_name); + if (!method_desc) { + LogError("Method not found"); + return false; + } + + return is_request ? 
method_desc->client_streaming() + : method_desc->server_streaming(); +} + +std::string ProtoFileParser::GetSerializedProtoFromMethod( + const std::string& method, const std::string& formatted_proto, + bool is_request, bool is_json_format) { + has_error_ = false; + std::string message_type_name = GetMessageTypeFromMethod(method, is_request); + if (has_error_) { + return ""; + } + return GetSerializedProtoFromMessageType(message_type_name, formatted_proto, + is_json_format); +} + +std::string ProtoFileParser::GetFormattedStringFromMethod( + const std::string& method, const std::string& serialized_proto, + bool is_request, bool is_json_format) { + has_error_ = false; + std::string message_type_name = GetMessageTypeFromMethod(method, is_request); + if (has_error_) { + return ""; + } + return GetFormattedStringFromMessageType(message_type_name, serialized_proto, + is_json_format); +} + +std::string ProtoFileParser::GetSerializedProtoFromMessageType( + const std::string& message_type_name, const std::string& formatted_proto, + bool is_json_format) { + has_error_ = false; + std::string serialized; + const protobuf::Descriptor* desc = + desc_pool_->FindMessageTypeByName(message_type_name); + if (!desc) { + LogError("Message type not found"); + return ""; + } + std::unique_ptr msg( + dynamic_factory_->GetPrototype(desc)->New()); + + bool ok; + if (is_json_format) { + ok = grpc::protobuf::json::JsonStringToMessage(formatted_proto, msg.get()) + .ok(); + if (!ok) { + LogError("Failed to convert json format to proto."); + return ""; + } + } else { + ok = protobuf::TextFormat::ParseFromString(formatted_proto, msg.get()); + if (!ok) { + LogError("Failed to convert text format to proto."); + return ""; + } + } + + ok = msg->SerializeToString(&serialized); + if (!ok) { + LogError("Failed to serialize proto."); + return ""; + } + return serialized; +} + +std::string ProtoFileParser::GetFormattedStringFromMessageType( + const std::string& message_type_name, const std::string& serialized_proto, + bool is_json_format) { + has_error_ = false; + const protobuf::Descriptor* desc = + desc_pool_->FindMessageTypeByName(message_type_name); + if (!desc) { + LogError("Message type not found"); + return ""; + } + std::unique_ptr msg( + dynamic_factory_->GetPrototype(desc)->New()); + if (!msg->ParseFromString(serialized_proto)) { + LogError("Failed to deserialize proto."); + return ""; + } + std::string formatted_string; + + if (is_json_format) { + grpc::protobuf::json::JsonPrintOptions jsonPrintOptions; + jsonPrintOptions.add_whitespace = true; + if (!grpc::protobuf::json::MessageToJsonString(*msg, &formatted_string, + jsonPrintOptions) + .ok()) { + LogError("Failed to print proto message to json format"); + return ""; + } + } else { + if (!protobuf::TextFormat::PrintToString(*msg, &formatted_string)) { + LogError("Failed to print proto message to text format"); + return ""; + } + } + return formatted_string; +} + +void ProtoFileParser::LogError(const std::string& error_msg) { + if (!error_msg.empty()) { + std::cerr << error_msg << std::endl; + } + has_error_ = true; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/proto_reflection_descriptor_database.cc b/test/cpp/util/proto_reflection_descriptor_database.cc new file mode 100644 index 00000000..f2356021 --- /dev/null +++ b/test/cpp/util/proto_reflection_descriptor_database.cc @@ -0,0 +1,333 @@ +/* + * + * Copyright 2016 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/proto_reflection_descriptor_database.h" + +#include + +#include + +using grpc::reflection::v1alpha::ErrorResponse; +using grpc::reflection::v1alpha::ListServiceResponse; +using grpc::reflection::v1alpha::ServerReflection; +using grpc::reflection::v1alpha::ServerReflectionRequest; +using grpc::reflection::v1alpha::ServerReflectionResponse; + +namespace grpc { + +ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase( + std::unique_ptr stub) + : stub_(std::move(stub)) {} + +ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase( + const std::shared_ptr& channel) + : stub_(ServerReflection::NewStub(channel)) {} + +ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() { + if (stream_) { + stream_->WritesDone(); + Status status = stream_->Finish(); + if (!status.ok()) { + if (status.error_code() == StatusCode::UNIMPLEMENTED) { + fprintf(stderr, + "Reflection request not implemented; " + "is the ServerReflection service enabled?\n"); + } else { + fprintf(stderr, + "ServerReflectionInfo rpc failed. Error code: %d, message: %s, " + "debug info: %s\n", + static_cast(status.error_code()), + status.error_message().c_str(), + ctx_.debug_error_string().c_str()); + } + } + } +} + +bool ProtoReflectionDescriptorDatabase::FindFileByName( + const string& filename, protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileByName(filename, output)) { + return true; + } + + if (known_files_.find(filename) != known_files_.end()) { + return false; + } + + ServerReflectionRequest request; + request.set_file_by_filename(filename); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + gpr_log(GPR_INFO, "NOT_FOUND from server for FindFileByName(%s)", + filename.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindFileByName(%s)\n\tError code: %d\n" + "\tError Message: %s", + filename.c_str(), error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileByName(%s) response type\n" + "\tExpecting: %d\n\tReceived: %d", + filename.c_str(), + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + + return cached_db_.FindFileByName(filename, output); +} + +bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol( + const string& symbol_name, protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileContainingSymbol(symbol_name, output)) { + return true; + } + + if (missing_symbols_.find(symbol_name) != 
missing_symbols_.end()) { + return false; + } + + ServerReflectionRequest request; + request.set_file_containing_symbol(symbol_name); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + missing_symbols_.insert(symbol_name); + gpr_log(GPR_INFO, + "NOT_FOUND from server for FindFileContainingSymbol(%s)", + symbol_name.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindFileContainingSymbol(%s)\n" + "\tError code: %d\n\tError Message: %s", + symbol_name.c_str(), error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileContainingSymbol(%s) response type\n" + "\tExpecting: %d\n\tReceived: %d", + symbol_name.c_str(), + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + return cached_db_.FindFileContainingSymbol(symbol_name, output); +} + +bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension( + const string& containing_type, int field_number, + protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileContainingExtension(containing_type, field_number, + output)) { + return true; + } + + if (missing_extensions_.find(containing_type) != missing_extensions_.end() && + missing_extensions_[containing_type].find(field_number) != + missing_extensions_[containing_type].end()) { + gpr_log(GPR_INFO, "nested map."); + return false; + } + + ServerReflectionRequest request; + request.mutable_file_containing_extension()->set_containing_type( + containing_type); + request.mutable_file_containing_extension()->set_extension_number( + field_number); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + if (missing_extensions_.find(containing_type) == + missing_extensions_.end()) { + missing_extensions_[containing_type] = {}; + } + missing_extensions_[containing_type].insert(field_number); + gpr_log(GPR_INFO, + "NOT_FOUND from server for FindFileContainingExtension(%s, %d)", + containing_type.c_str(), field_number); + } else { + gpr_log(GPR_INFO, + "Error on FindFileContainingExtension(%s, %d)\n" + "\tError code: %d\n\tError Message: %s", + containing_type.c_str(), field_number, error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileContainingExtension(%s, %d) response type\n" + "\tExpecting: %d\n\tReceived: %d", + containing_type.c_str(), field_number, + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + + return cached_db_.FindFileContainingExtension(containing_type, field_number, + output); +} + +bool 
ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers( + const string& extendee_type, std::vector* output) { + if (cached_extension_numbers_.find(extendee_type) != + cached_extension_numbers_.end()) { + *output = cached_extension_numbers_[extendee_type]; + return true; + } + + ServerReflectionRequest request; + request.set_all_extension_numbers_of_type(extendee_type); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase:: + kAllExtensionNumbersResponse) { + auto number = response.all_extension_numbers_response().extension_number(); + *output = std::vector(number.begin(), number.end()); + cached_extension_numbers_[extendee_type] = *output; + return true; + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + gpr_log(GPR_INFO, "NOT_FOUND from server for FindAllExtensionNumbers(%s)", + extendee_type.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindAllExtensionNumbersExtension(%s)\n" + "\tError code: %d\n\tError Message: %s", + extendee_type.c_str(), error.error_code(), + error.error_message().c_str()); + } + } + return false; +} + +bool ProtoReflectionDescriptorDatabase::GetServices( + std::vector* output) { + ServerReflectionRequest request; + request.set_list_services(""); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kListServicesResponse) { + const ListServiceResponse& ls_response = response.list_services_response(); + for (int i = 0; i < ls_response.service_size(); ++i) { + (*output).push_back(ls_response.service(i).name()); + } + return true; + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + gpr_log(GPR_INFO, + "Error on GetServices()\n\tError code: %d\n" + "\tError Message: %s", + error.error_code(), error.error_message().c_str()); + } else { + gpr_log( + GPR_INFO, + "Error on GetServices() response type\n\tExpecting: %d\n\tReceived: %d", + ServerReflectionResponse::MessageResponseCase::kListServicesResponse, + response.message_response_case()); + } + return false; +} + +protobuf::FileDescriptorProto +ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse( + const std::string& byte_fd_proto) { + protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.ParseFromString(byte_fd_proto); + return file_desc_proto; +} + +void ProtoReflectionDescriptorDatabase::AddFileFromResponse( + const grpc::reflection::v1alpha::FileDescriptorResponse& response) { + for (int i = 0; i < response.file_descriptor_proto_size(); ++i) { + const protobuf::FileDescriptorProto file_proto = + ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i)); + if (known_files_.find(file_proto.name()) == known_files_.end()) { + known_files_.insert(file_proto.name()); + cached_db_.Add(file_proto); + } + } +} + +std::shared_ptr +ProtoReflectionDescriptorDatabase::GetStream() { + if (!stream_) { + stream_ = stub_->ServerReflectionInfo(&ctx_); + } + return stream_; +} + +bool ProtoReflectionDescriptorDatabase::DoOneRequest( + const ServerReflectionRequest& request, + ServerReflectionResponse& response) { + 
bool success = false; + stream_mutex_.lock(); + if (GetStream()->Write(request) && GetStream()->Read(&response)) { + success = true; + } + stream_mutex_.unlock(); + return success; +} + +} // namespace grpc diff --git a/test/cpp/util/service_describer.cc b/test/cpp/util/service_describer.cc new file mode 100644 index 00000000..f5b6208d --- /dev/null +++ b/test/cpp/util/service_describer.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/service_describer.h" + +#include +#include +#include +#include + +namespace grpc { +namespace testing { + +std::string DescribeServiceList(std::vector service_list, + grpc::protobuf::DescriptorPool& desc_pool) { + std::stringstream result; + for (auto it = service_list.begin(); it != service_list.end(); it++) { + auto const& service = *it; + const grpc::protobuf::ServiceDescriptor* service_desc = + desc_pool.FindServiceByName(service); + if (service_desc != nullptr) { + result << DescribeService(service_desc); + } + } + return result.str(); +} + +std::string DescribeService(const grpc::protobuf::ServiceDescriptor* service) { + std::string result; + if (service->options().deprecated()) { + result.append("DEPRECATED\n"); + } + result.append("filename: " + service->file()->name() + "\n"); + + std::string package = service->full_name(); + size_t pos = package.rfind("." + service->name()); + if (pos != std::string::npos) { + package.erase(pos); + result.append("package: " + package + ";\n"); + } + result.append("service " + service->name() + " {\n"); + for (int i = 0; i < service->method_count(); ++i) { + result.append(DescribeMethod(service->method(i))); + } + result.append("}\n\n"); + return result; +} + +std::string DescribeMethod(const grpc::protobuf::MethodDescriptor* method) { + std::stringstream result; + result << " rpc " << method->name() + << (method->client_streaming() ? "(stream " : "(") + << method->input_type()->full_name() << ") returns " + << (method->server_streaming() ? "(stream " : "(") + << method->output_type()->full_name() << ") {}\n"; + if (method->options().deprecated()) { + result << " DEPRECATED"; + } + return result.str(); +} + +std::string SummarizeService(const grpc::protobuf::ServiceDescriptor* service) { + std::string result; + for (int i = 0; i < service->method_count(); ++i) { + result.append(SummarizeMethod(service->method(i))); + } + return result; +} + +std::string SummarizeMethod(const grpc::protobuf::MethodDescriptor* method) { + std::string result = method->name(); + result.append("\n"); + return result; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/slice_test.cc b/test/cpp/util/slice_test.cc new file mode 100644 index 00000000..d74d3f8a --- /dev/null +++ b/test/cpp/util/slice_test.cc @@ -0,0 +1,151 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include +#include +#include + +#include "test/core/util/test_config.h" + +namespace grpc { + +static internal::GrpcLibraryInitializer g_gli_initializer; + +namespace { + +const char* kContent = "hello xxxxxxxxxxxxxxxxxxxx world"; + +class SliceTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } + + void CheckSliceSize(const Slice& s, const std::string& content) { + EXPECT_EQ(content.size(), s.size()); + } + void CheckSlice(const Slice& s, const std::string& content) { + EXPECT_EQ(content.size(), s.size()); + EXPECT_EQ(content, + std::string(reinterpret_cast(s.begin()), s.size())); + } +}; + +TEST_F(SliceTest, Empty) { + Slice empty_slice; + CheckSlice(empty_slice, ""); +} + +TEST_F(SliceTest, Sized) { + Slice sized_slice(strlen(kContent)); + CheckSliceSize(sized_slice, kContent); +} + +TEST_F(SliceTest, String) { + Slice spp(kContent); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Buf) { + Slice spp(kContent, strlen(kContent)); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, StaticBuf) { + Slice spp(kContent, strlen(kContent), Slice::STATIC_SLICE); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNew) { + char* x = new char[strlen(kContent) + 1]; + strcpy(x, kContent); + Slice spp(x, strlen(x), [](void* p) { delete[] static_cast(p); }); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNewDoNothing) { + Slice spp(const_cast(kContent), strlen(kContent), [](void* /*p*/) {}); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNewWithUserData) { + struct stest { + char* x; + int y; + }; + auto* t = new stest; + t->x = new char[strlen(kContent) + 1]; + strcpy(t->x, kContent); + Slice spp( + t->x, strlen(t->x), + [](void* p) { + auto* t = static_cast(p); + delete[] t->x; + delete t; + }, + t); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNewLen) { + Slice spp(const_cast(kContent), strlen(kContent), + [](void* /*p*/, size_t l) { EXPECT_EQ(l, strlen(kContent)); }); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Steal) { + grpc_slice s = grpc_slice_from_copied_string(kContent); + Slice spp(s, Slice::STEAL_REF); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Add) { + grpc_slice s = grpc_slice_from_copied_string(kContent); + Slice spp(s, Slice::ADD_REF); + grpc_slice_unref(s); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Sub) { + Slice spp("0123456789"); + Slice sub = spp.sub(1, 9); + CheckSlice(sub, "12345678"); +} + +TEST_F(SliceTest, Cslice) { + grpc_slice s = grpc_slice_from_copied_string(kContent); + Slice spp(s, Slice::STEAL_REF); + CheckSlice(spp, kContent); + grpc_slice c_slice = spp.c_slice(); + EXPECT_EQ(GRPC_SLICE_START_PTR(s), GRPC_SLICE_START_PTR(c_slice)); + EXPECT_EQ(GRPC_SLICE_END_PTR(s), GRPC_SLICE_END_PTR(c_slice)); + grpc_slice_unref(c_slice); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git 
a/test/cpp/util/string_ref_helper.cc b/test/cpp/util/string_ref_helper.cc new file mode 100644 index 00000000..247e7a71 --- /dev/null +++ b/test/cpp/util/string_ref_helper.cc @@ -0,0 +1,29 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/string_ref_helper.h" + +namespace grpc { +namespace testing { + +std::string ToString(const grpc::string_ref& r) { + return std::string(r.data(), r.size()); +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/string_ref_test.cc b/test/cpp/util/string_ref_test.cc new file mode 100644 index 00000000..1537e187 --- /dev/null +++ b/test/cpp/util/string_ref_test.cc @@ -0,0 +1,205 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include + +#include + +#include + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +const char kTestString[] = "blah"; +const char kTestStringWithEmbeddedNull[] = "blah\0foo"; +const size_t kTestStringWithEmbeddedNullLength = 8; +const char kTestUnrelatedString[] = "foo"; + +class StringRefTest : public ::testing::Test {}; + +TEST_F(StringRefTest, Empty) { + string_ref s; + EXPECT_EQ(0U, s.length()); + EXPECT_EQ(nullptr, s.data()); +} + +TEST_F(StringRefTest, FromCString) { + string_ref s(kTestString); + EXPECT_EQ(strlen(kTestString), s.length()); + EXPECT_EQ(kTestString, s.data()); +} + +TEST_F(StringRefTest, FromCStringWithLength) { + string_ref s(kTestString, 2); + EXPECT_EQ(2U, s.length()); + EXPECT_EQ(kTestString, s.data()); +} + +TEST_F(StringRefTest, FromString) { + string copy(kTestString); + string_ref s(copy); + EXPECT_EQ(copy.data(), s.data()); + EXPECT_EQ(copy.length(), s.length()); +} + +TEST_F(StringRefTest, CopyConstructor) { + string_ref s1(kTestString); + ; + const string_ref& s2(s1); + EXPECT_EQ(s1.length(), s2.length()); + EXPECT_EQ(s1.data(), s2.data()); +} + +TEST_F(StringRefTest, FromStringWithEmbeddedNull) { + string copy(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + string_ref s(copy); + EXPECT_EQ(copy.data(), s.data()); + EXPECT_EQ(copy.length(), s.length()); + EXPECT_EQ(kTestStringWithEmbeddedNullLength, s.length()); +} + +TEST_F(StringRefTest, Assignment) { + string_ref s1(kTestString); + ; + string_ref s2; + EXPECT_EQ(nullptr, s2.data()); + s2 = s1; + EXPECT_EQ(s1.length(), s2.length()); + EXPECT_EQ(s1.data(), s2.data()); +} + +TEST_F(StringRefTest, Iterator) { + string_ref s(kTestString); + size_t i = 0; + for (auto it = s.cbegin(); it != s.cend(); ++it) { + auto val = kTestString[i++]; + EXPECT_EQ(val, *it); + } + EXPECT_EQ(strlen(kTestString), i); +} + +TEST_F(StringRefTest, ReverseIterator) { + string_ref s(kTestString); + size_t i = strlen(kTestString); + for (auto rit = s.crbegin(); rit != s.crend(); ++rit) { + auto val = kTestString[--i]; + EXPECT_EQ(val, *rit); + } + EXPECT_EQ(0U, i); +} + +TEST_F(StringRefTest, Capacity) { + string_ref empty; + EXPECT_EQ(0U, empty.length()); + EXPECT_EQ(0U, empty.size()); + EXPECT_EQ(0U, empty.max_size()); + EXPECT_TRUE(empty.empty()); + + string_ref s(kTestString); + EXPECT_EQ(strlen(kTestString), s.length()); + EXPECT_EQ(s.length(), s.size()); + EXPECT_EQ(s.max_size(), s.length()); + EXPECT_FALSE(s.empty()); +} + +TEST_F(StringRefTest, Compare) { + string_ref s1(kTestString); + string s1_copy(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_EQ(0, s1.compare(s1_copy)); + EXPECT_NE(0, s1.compare(s2)); + EXPECT_NE(0, s1.compare(s3)); +} + +TEST_F(StringRefTest, StartsWith) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_TRUE(s1.starts_with(s1)); + EXPECT_FALSE(s1.starts_with(s2)); + EXPECT_FALSE(s2.starts_with(s1)); + EXPECT_FALSE(s1.starts_with(s3)); + EXPECT_TRUE(s3.starts_with(s1)); +} + +TEST_F(StringRefTest, Endswith) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_TRUE(s1.ends_with(s1)); + EXPECT_FALSE(s1.ends_with(s2)); + EXPECT_FALSE(s2.ends_with(s1)); + EXPECT_FALSE(s2.ends_with(s3)); + EXPECT_TRUE(s3.ends_with(s2)); +} + 
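+// s3 below holds "blah\0foo" (8 bytes). Finding "blah" at offset 0 and "foo"
+// at offset 5 checks that lookups work across the embedded NUL, i.e. that
+// string_ref is length-delimited rather than NUL-terminated.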
+TEST_F(StringRefTest, Find) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_EQ(0U, s1.find(s1)); + EXPECT_EQ(0U, s2.find(s2)); + EXPECT_EQ(0U, s3.find(s3)); + EXPECT_EQ(string_ref::npos, s1.find(s2)); + EXPECT_EQ(string_ref::npos, s2.find(s1)); + EXPECT_EQ(string_ref::npos, s1.find(s3)); + EXPECT_EQ(0U, s3.find(s1)); + EXPECT_EQ(5U, s3.find(s2)); + EXPECT_EQ(string_ref::npos, s1.find('z')); + EXPECT_EQ(1U, s2.find('o')); +} + +TEST_F(StringRefTest, SubString) { + string_ref s(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + string_ref sub1 = s.substr(0, 4); + EXPECT_EQ(string_ref(kTestString), sub1); + string_ref sub2 = s.substr(5); + EXPECT_EQ(string_ref(kTestUnrelatedString), sub2); +} + +TEST_F(StringRefTest, ComparisonOperators) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_EQ(s1, s1); + EXPECT_EQ(s2, s2); + EXPECT_EQ(s3, s3); + EXPECT_GE(s1, s1); + EXPECT_GE(s2, s2); + EXPECT_GE(s3, s3); + EXPECT_LE(s1, s1); + EXPECT_LE(s2, s2); + EXPECT_LE(s3, s3); + EXPECT_NE(s1, s2); + EXPECT_NE(s1, s3); + EXPECT_NE(s2, s3); + EXPECT_GT(s3, s1); + EXPECT_LT(s1, s3); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/util/subprocess.cc b/test/cpp/util/subprocess.cc new file mode 100644 index 00000000..ddaad898 --- /dev/null +++ b/test/cpp/util/subprocess.cc @@ -0,0 +1,44 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/subprocess.h" + +#include + +#include "test/core/util/subprocess.h" + +namespace grpc { + +static gpr_subprocess* MakeProcess(const std::vector& args) { + std::vector vargs; + for (auto it = args.begin(); it != args.end(); ++it) { + vargs.push_back(it->c_str()); + } + return gpr_subprocess_create(vargs.size(), &vargs[0]); +} + +SubProcess::SubProcess(const std::vector& args) + : subprocess_(MakeProcess(args)) {} + +SubProcess::~SubProcess() { gpr_subprocess_destroy(subprocess_); } + +int SubProcess::Join() { return gpr_subprocess_join(subprocess_); } + +void SubProcess::Interrupt() { gpr_subprocess_interrupt(subprocess_); } + +} // namespace grpc diff --git a/test/cpp/util/test_config_cc.cc b/test/cpp/util/test_config_cc.cc new file mode 100644 index 00000000..866540d5 --- /dev/null +++ b/test/cpp/util/test_config_cc.cc @@ -0,0 +1,39 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include "absl/flags/parse.h" + +#include "test/cpp/util/test_config.h" + +namespace grpc { +namespace testing { + +void InitTest(int* argc, char*** argv, bool remove_flags) { + std::vector reduced_argv = absl::ParseCommandLine(*argc, *argv); + if (remove_flags) { + *argc = reduced_argv.size(); + for (int i = 0; i < *argc; i++) { + (*argv)[i] = reduced_argv.at(i); + } + } +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/test_credentials_provider.cc b/test/cpp/util/test_credentials_provider.cc new file mode 100644 index 00000000..b635e27c --- /dev/null +++ b/test/cpp/util/test_credentials_provider.cc @@ -0,0 +1,183 @@ + +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/test_credentials_provider.h" + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" + +#include +#include +#include + +#include "test/core/end2end/data/ssl_test_data.h" + +ABSL_FLAG(std::string, tls_cert_file, "", + "The TLS cert file used when --use_tls=true"); +ABSL_FLAG(std::string, tls_key_file, "", + "The TLS key file used when --use_tls=true"); + +namespace grpc { +namespace testing { +namespace { + +std::string ReadFile(const std::string& src_path) { + std::ifstream src; + src.open(src_path, std::ifstream::in | std::ifstream::binary); + + std::string contents; + src.seekg(0, std::ios::end); + contents.reserve(src.tellg()); + src.seekg(0, std::ios::beg); + contents.assign((std::istreambuf_iterator(src)), + (std::istreambuf_iterator())); + return contents; +} + +class DefaultCredentialsProvider : public CredentialsProvider { + public: + DefaultCredentialsProvider() { + if (!absl::GetFlag(FLAGS_tls_key_file).empty()) { + custom_server_key_ = ReadFile(absl::GetFlag(FLAGS_tls_key_file)); + } + if (!absl::GetFlag(FLAGS_tls_cert_file).empty()) { + custom_server_cert_ = ReadFile(absl::GetFlag(FLAGS_tls_cert_file)); + } + } + ~DefaultCredentialsProvider() override {} + + void AddSecureType( + const std::string& type, + std::unique_ptr type_provider) override { + // This clobbers any existing entry for type, except the defaults, which + // can't be clobbered. 
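+    // Added types live in a plain vector and are looked up with a linear
+    // scan; that is fine for the handful of custom credential types a test
+    // typically registers.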
+ std::unique_lock lock(mu_); + auto it = std::find(added_secure_type_names_.begin(), + added_secure_type_names_.end(), type); + if (it == added_secure_type_names_.end()) { + added_secure_type_names_.push_back(type); + added_secure_type_providers_.push_back(std::move(type_provider)); + } else { + added_secure_type_providers_[it - added_secure_type_names_.begin()] = + std::move(type_provider); + } + } + + std::shared_ptr GetChannelCredentials( + const std::string& type, ChannelArguments* args) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureChannelCredentials(); + } else if (type == grpc::testing::kAltsCredentialsType) { + grpc::experimental::AltsCredentialsOptions alts_opts; + return grpc::experimental::AltsCredentials(alts_opts); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; + args->SetSslTargetNameOverride("foo.test.google.fr"); + return grpc::SslCredentials(ssl_opts); + } else if (type == grpc::testing::kGoogleDefaultCredentialsType) { + return grpc::GoogleDefaultCredentials(); + } else { + std::unique_lock lock(mu_); + auto it(std::find(added_secure_type_names_.begin(), + added_secure_type_names_.end(), type)); + if (it == added_secure_type_names_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return added_secure_type_providers_[it - added_secure_type_names_.begin()] + ->GetChannelCredentials(args); + } + } + + std::shared_ptr GetServerCredentials( + const std::string& type) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureServerCredentials(); + } else if (type == grpc::testing::kAltsCredentialsType) { + grpc::experimental::AltsServerCredentialsOptions alts_opts; + return grpc::experimental::AltsServerCredentials(alts_opts); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslServerCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = ""; + if (!custom_server_key_.empty() && !custom_server_cert_.empty()) { + SslServerCredentialsOptions::PemKeyCertPair pkcp = { + custom_server_key_, custom_server_cert_}; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + } else { + SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, + test_server1_cert}; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + } + return SslServerCredentials(ssl_opts); + } else { + std::unique_lock lock(mu_); + auto it(std::find(added_secure_type_names_.begin(), + added_secure_type_names_.end(), type)); + if (it == added_secure_type_names_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return added_secure_type_providers_[it - added_secure_type_names_.begin()] + ->GetServerCredentials(); + } + } + std::vector GetSecureCredentialsTypeList() override { + std::vector types; + types.push_back(grpc::testing::kTlsCredentialsType); + std::unique_lock lock(mu_); + for (auto it = added_secure_type_names_.begin(); + it != added_secure_type_names_.end(); it++) { + types.push_back(*it); + } + return types; + } + + private: + std::mutex mu_; + std::vector added_secure_type_names_; + std::vector> + added_secure_type_providers_; + std::string custom_server_key_; + std::string custom_server_cert_; +}; + +CredentialsProvider* g_provider = nullptr; + +} // namespace + +CredentialsProvider* GetCredentialsProvider() { + if (g_provider == nullptr) { + g_provider = new DefaultCredentialsProvider; + } + return g_provider; +} + +void 
SetCredentialsProvider(CredentialsProvider* provider) { + // For now, forbids overriding provider. + GPR_ASSERT(g_provider == nullptr); + g_provider = provider; +} + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/util/time_test.cc b/test/cpp/util/time_test.cc new file mode 100644 index 00000000..4970f4b5 --- /dev/null +++ b/test/cpp/util/time_test.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +#include +#include + +#include "test/core/util/test_config.h" + +using std::chrono::microseconds; +using std::chrono::system_clock; + +namespace grpc { +namespace { + +class TimeTest : public ::testing::Test {}; + +TEST_F(TimeTest, AbsolutePointTest) { + int64_t us = 10000000L; + gpr_timespec ts = gpr_time_from_micros(us, GPR_TIMESPAN); + ts.clock_type = GPR_CLOCK_REALTIME; + system_clock::time_point tp{microseconds(us)}; + system_clock::time_point tp_converted = Timespec2Timepoint(ts); + gpr_timespec ts_converted; + Timepoint2Timespec(tp_converted, &ts_converted); + EXPECT_TRUE(ts.tv_sec == ts_converted.tv_sec); + EXPECT_TRUE(ts.tv_nsec == ts_converted.tv_nsec); + system_clock::time_point tp_converted_2 = Timespec2Timepoint(ts_converted); + EXPECT_TRUE(tp == tp_converted); + EXPECT_TRUE(tp == tp_converted_2); +} + +// gpr_inf_future is treated specially and mapped to/from time_point::max() +TEST_F(TimeTest, InfFuture) { + EXPECT_EQ(system_clock::time_point::max(), + Timespec2Timepoint(gpr_inf_future(GPR_CLOCK_REALTIME))); + gpr_timespec from_time_point_max; + Timepoint2Timespec(system_clock::time_point::max(), &from_time_point_max); + EXPECT_EQ( + 0, gpr_time_cmp(gpr_inf_future(GPR_CLOCK_REALTIME), from_time_point_max)); + // This will cause an overflow + Timepoint2Timespec( + std::chrono::time_point::max(), + &from_time_point_max); + EXPECT_EQ( + 0, gpr_time_cmp(gpr_inf_future(GPR_CLOCK_REALTIME), from_time_point_max)); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/spm_build/test.cc b/test/spm_build/test.cc new file mode 100644 index 00000000..2eab4acf --- /dev/null +++ b/test/spm_build/test.cc @@ -0,0 +1,29 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/completion_queue.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/generic/generic_stub.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/impl/codegen/grpc_library.h" +#include "grpcpp/support/byte_buffer.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/status_code_enum.h" +#include "grpcpp/support/string_ref.h" diff --git a/third_party/upb/benchmarks/benchmark.cc b/third_party/upb/benchmarks/benchmark.cc new file mode 100644 index 00000000..60e3b530 --- /dev/null +++ b/third_party/upb/benchmarks/benchmark.cc @@ -0,0 +1,282 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "benchmarks/descriptor.pb.h" +#include "benchmarks/descriptor.upb.h" +#include "benchmarks/descriptor.upbdefs.h" +#include "benchmarks/descriptor_sv.pb.h" +#include "google/ads/googleads/v7/services/google_ads_service.upbdefs.h" +#include "google/protobuf/descriptor.pb.h" +#include "upb/def.hpp" + +upb_strview descriptor = benchmarks_descriptor_proto_upbdefinit.descriptor; +namespace protobuf = ::google::protobuf; + +/* A buffer big enough to parse descriptor.proto without going to heap. 
*/ +char buf[65535]; + +void CollectFileDescriptors(const upb_def_init* file, + std::vector& serialized_files, + absl::flat_hash_set& seen) { + if (!seen.insert(file).second) return; + for (upb_def_init **deps = file->deps; *deps; deps++) { + CollectFileDescriptors(*deps, serialized_files, seen); + } + serialized_files.push_back(file->descriptor); +} + +static void BM_ArenaOneAlloc(benchmark::State& state) { + for (auto _ : state) { + upb_arena* arena = upb_arena_new(); + upb_arena_malloc(arena, 1); + upb_arena_free(arena); + } +} +BENCHMARK(BM_ArenaOneAlloc); + +static void BM_ArenaInitialBlockOneAlloc(benchmark::State& state) { + for (auto _ : state) { + upb_arena* arena = upb_arena_init(buf, sizeof(buf), NULL); + upb_arena_malloc(arena, 1); + upb_arena_free(arena); + } +} +BENCHMARK(BM_ArenaInitialBlockOneAlloc); + +static void BM_LoadDescriptor_Upb(benchmark::State& state) { + size_t bytes_per_iter = 0; + for (auto _ : state) { + upb::SymbolTable symtab; + upb_benchmark_DescriptorProto_getmsgdef(symtab.ptr()); + bytes_per_iter = _upb_symtab_bytesloaded(symtab.ptr()); + } + state.SetBytesProcessed(state.iterations() * bytes_per_iter); +} +BENCHMARK(BM_LoadDescriptor_Upb); + +static void BM_LoadAdsDescriptor_Upb(benchmark::State& state) { + size_t bytes_per_iter = 0; + for (auto _ : state) { + upb::SymbolTable symtab; + google_ads_googleads_v7_services_SearchGoogleAdsRequest_getmsgdef( + symtab.ptr()); + bytes_per_iter = _upb_symtab_bytesloaded(symtab.ptr()); + } + state.SetBytesProcessed(state.iterations() * bytes_per_iter); +} +BENCHMARK(BM_LoadAdsDescriptor_Upb); + +static void BM_LoadDescriptor_Proto2(benchmark::State& state) { + for (auto _ : state) { + protobuf::Arena arena; + protobuf::StringPiece input(descriptor.data,descriptor.size); + auto proto = protobuf::Arena::CreateMessage( + &arena); + protobuf::DescriptorPool pool; + bool ok = proto->ParseFrom(input) && + pool.BuildFile(*proto) != nullptr; + if (!ok) { + printf("Failed to add file.\n"); + exit(1); + } + } + state.SetBytesProcessed(state.iterations() * descriptor.size); +} +BENCHMARK(BM_LoadDescriptor_Proto2); + +static void BM_LoadAdsDescriptor_Proto2(benchmark::State& state) { + extern upb_def_init google_ads_googleads_v7_services_google_ads_service_proto_upbdefinit; + std::vector serialized_files; + absl::flat_hash_set seen_files; + CollectFileDescriptors( + &google_ads_googleads_v7_services_google_ads_service_proto_upbdefinit, + serialized_files, seen_files); + size_t bytes_per_iter = 0; + for (auto _ : state) { + bytes_per_iter = 0; + protobuf::Arena arena; + protobuf::DescriptorPool pool; + for (auto file : serialized_files) { + protobuf::StringPiece input(file.data, file.size); + auto proto = protobuf::Arena::CreateMessage( + &arena); + bool ok = proto->ParseFrom(input) && + pool.BuildFile(*proto) != nullptr; + if (!ok) { + printf("Failed to add file.\n"); + exit(1); + } + bytes_per_iter += input.size(); + } + } + state.SetBytesProcessed(state.iterations() * bytes_per_iter); +} +BENCHMARK(BM_LoadAdsDescriptor_Proto2); + +enum CopyStrings { + Copy, + Alias, +}; + +enum ArenaMode { + NoArena, + UseArena, + InitBlock, +}; + +template +static void BM_Parse_Upb_FileDesc(benchmark::State& state) { + size_t bytes = 0; + for (auto _ : state) { + upb_arena *arena; + if (AMode == InitBlock) { + arena = upb_arena_init(buf, sizeof(buf), NULL); + } else { + arena = upb_arena_new(); + } + upb_benchmark_FileDescriptorProto* set = + upb_benchmark_FileDescriptorProto_parse_ex( + descriptor.data, descriptor.size, NULL, + Copy == 
Alias ? UPB_DECODE_ALIAS : 0, arena); + if (!set) { + printf("Failed to parse.\n"); + exit(1); + } + bytes += descriptor.size; + upb_arena_free(arena); + } + state.SetBytesProcessed(state.iterations() * descriptor.size); +} +BENCHMARK_TEMPLATE(BM_Parse_Upb_FileDesc, UseArena, Copy); +BENCHMARK_TEMPLATE(BM_Parse_Upb_FileDesc, UseArena, Alias); +BENCHMARK_TEMPLATE(BM_Parse_Upb_FileDesc, InitBlock, Copy); +BENCHMARK_TEMPLATE(BM_Parse_Upb_FileDesc, InitBlock, Alias); + +template +struct Proto2Factory; + +template +struct Proto2Factory { + public: + P* GetProto() { return &proto_; } + + private: + P proto_; +}; + +template +struct Proto2Factory { + public: + P* GetProto() { return protobuf::Arena::CreateMessage
<P>
(&arena_); } + + private: + protobuf::Arena arena_; +}; + +template +struct Proto2Factory { + public: + Proto2Factory() : arena_(GetOptions()) {} + P* GetProto() { return protobuf::Arena::CreateMessage
<P>
(&arena_); } + + private: + protobuf::ArenaOptions GetOptions() { + protobuf::ArenaOptions opts; + opts.initial_block = buf; + opts.initial_block_size = sizeof(buf); + return opts; + } + + protobuf::Arena arena_; +}; + +using FileDesc = ::upb_benchmark::FileDescriptorProto; +using FileDescSV = ::upb_benchmark::sv::FileDescriptorProto; + +template +void BM_Parse_Proto2(benchmark::State& state) { + size_t bytes = 0; + constexpr protobuf::MessageLite::ParseFlags kParseFlags = + kCopy == Copy + ? protobuf::MessageLite::ParseFlags::kMergePartial + : protobuf::MessageLite::ParseFlags::kMergePartialWithAliasing; + for (auto _ : state) { + Proto2Factory proto_factory; + auto proto = proto_factory.GetProto(); + protobuf::StringPiece input(descriptor.data,descriptor.size); + bool ok = proto->template ParseFrom(input); + if (!ok) { + printf("Failed to parse.\n"); + exit(1); + } + bytes += descriptor.size; + } + state.SetBytesProcessed(state.iterations() * descriptor.size); +} +BENCHMARK_TEMPLATE(BM_Parse_Proto2, FileDesc, NoArena, Copy); +BENCHMARK_TEMPLATE(BM_Parse_Proto2, FileDesc, UseArena, Copy); +BENCHMARK_TEMPLATE(BM_Parse_Proto2, FileDesc, InitBlock, Copy); +BENCHMARK_TEMPLATE(BM_Parse_Proto2, FileDescSV, InitBlock, Alias); + +static void BM_SerializeDescriptor_Proto2(benchmark::State& state) { + size_t bytes = 0; + upb_benchmark::FileDescriptorProto proto; + proto.ParseFromArray(descriptor.data, descriptor.size); + for (auto _ : state) { + proto.SerializePartialToArray(buf, sizeof(buf)); + bytes += descriptor.size; + } + state.SetBytesProcessed(state.iterations() * descriptor.size); +} +BENCHMARK(BM_SerializeDescriptor_Proto2); + +static void BM_SerializeDescriptor_Upb(benchmark::State& state) { + int64_t total = 0; + upb_arena* arena = upb_arena_new(); + upb_benchmark_FileDescriptorProto* set = + upb_benchmark_FileDescriptorProto_parse(descriptor.data, descriptor.size, + arena); + if (!set) { + printf("Failed to parse.\n"); + exit(1); + } + for (auto _ : state) { + upb_arena* enc_arena = upb_arena_init(buf, sizeof(buf), NULL); + size_t size; + char* data = + upb_benchmark_FileDescriptorProto_serialize(set, enc_arena, &size); + if (!data) { + printf("Failed to serialize.\n"); + exit(1); + } + total += size; + } + state.SetBytesProcessed(total); +} +BENCHMARK(BM_SerializeDescriptor_Upb); diff --git a/third_party/upb/tests/corpus/temp.cc b/third_party/upb/tests/corpus/temp.cc new file mode 100644 index 00000000..c8f2691a --- /dev/null +++ b/third_party/upb/tests/corpus/temp.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Hello World diff --git a/third_party/upb/tests/file_descriptor_parsenew_fuzzer.cc b/third_party/upb/tests/file_descriptor_parsenew_fuzzer.cc new file mode 100644 index 00000000..55fff040 --- /dev/null +++ b/third_party/upb/tests/file_descriptor_parsenew_fuzzer.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +#include "google/protobuf/descriptor.upb.h" +#include "upb/upb.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + upb::Arena arena; + google_protobuf_FileDescriptorProto_parse(reinterpret_cast(data), + size, arena.ptr()); + return 0; +} + +#ifndef HAVE_FUZZER +int main() {} +#endif diff --git a/third_party/upb/tests/test_cpp.cc b/third_party/upb/tests/test_cpp.cc new file mode 100644 index 00000000..ece707b1 --- /dev/null +++ b/third_party/upb/tests/test_cpp.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* + * Tests for C++ wrappers. + */ + +#include +#include + +#include +#include +#include +#include + +#include "tests/test_cpp.upbdefs.h" +#include "tests/test_cpp.upb.h" +#include "tests/upb_test.h" +#include "upb/def.h" +#include "upb/def.hpp" +#include "upb/json_decode.h" +#include "upb/json_encode.h" +#include "upb/upb.h" + +// Must be last. +#include "upb/port_def.inc" + +void TestIteration() { + upb::SymbolTable symtab; + upb::MessageDefPtr md(upb_test_TestMessage_getmsgdef(symtab.ptr())); + + // Test range-based for on both fields and oneofs (with the iterator adaptor). + int field_count = 0; + for (auto field : md.fields()) { + UPB_UNUSED(field); + field_count++; + } + ASSERT(field_count == md.field_count()); + + int oneof_count = 0; + for (auto oneof : md.oneofs()) { + UPB_UNUSED(oneof); + oneof_count++; + } + ASSERT(oneof_count == md.oneof_count()); +} + +void TestArena() { + int n = 100000; + + struct Decrementer { + Decrementer(int* _p) : p(_p) {} + ~Decrementer() { (*p)--; } + int* p; + }; + + { + upb::Arena arena; + for (int i = 0; i < n; i++) { + arena.Own(new Decrementer(&n)); + + // Intersperse allocation and ensure we can write to it. + int* val = static_cast(upb_arena_malloc(arena.ptr(), sizeof(int))); + *val = i; + } + + // Test a large allocation. + upb_arena_malloc(arena.ptr(), 1000000); + } + ASSERT(n == 0); + + { + // Test fuse. + upb::Arena arena1; + upb::Arena arena2; + + arena1.Fuse(arena2); + + upb_arena_malloc(arena1.ptr(), 10000); + upb_arena_malloc(arena2.ptr(), 10000); + } +} + +void TestInlinedArena() { + int n = 100000; + + struct Decrementer { + Decrementer(int* _p) : p(_p) {} + ~Decrementer() { (*p)--; } + int* p; + }; + + { + upb::InlinedArena<1024> arena; + for (int i = 0; i < n; i++) { + arena.Own(new Decrementer(&n)); + + // Intersperse allocation and ensure we can write to it. + int* val = static_cast(upb_arena_malloc(arena.ptr(), sizeof(int))); + *val = i; + } + + // Test a large allocation. 
+ upb_arena_malloc(arena.ptr(), 1000000); + } + ASSERT(n == 0); +} + +void TestDefault() { + upb::SymbolTable symtab; + upb::Arena arena; + upb::MessageDefPtr md(upb_test_TestMessage_getmsgdef(symtab.ptr())); + upb_test_TestMessage *msg = upb_test_TestMessage_new(arena.ptr()); + size_t size = upb_json_encode(msg, md.ptr(), NULL, 0, NULL, 0, NULL); + ASSERT(size == 2); // "{}" +} + +void TestJsonNull() { + upb::SymbolTable symtab; + upb::MessageDefPtr md(upb_test_TestMessage_getmsgdef(symtab.ptr())); + upb::FieldDefPtr i32_f = md.FindFieldByName("i32"); + upb::FieldDefPtr str_f = md.FindFieldByName("str"); + ASSERT(i32_f && str_f); + ASSERT(i32_f.default_value().int32_val == 5); + ASSERT(strcmp(str_f.default_value().str_val.data, "abc") == 0); + ASSERT(str_f.default_value().str_val.size == 3); +} + +extern "C" { + +int run_tests() { + TestIteration(); + TestArena(); + TestDefault(); + + return 0; +} + +} diff --git a/third_party/upb/tests/test_table.cc b/third_party/upb/tests/test_table.cc new file mode 100644 index 00000000..84ede2be --- /dev/null +++ b/third_party/upb/tests/test_table.cc @@ -0,0 +1,709 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* + * Tests for upb_table. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tests/upb_test.h" +#include "upb/upb.hpp" +#include "upb/table_internal.h" + +#include "upb/port_def.inc" + +// Convenience interface for C++. We don't put this in upb itself because +// the table is not exposed to users. 
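+// MakeUpbValue()/GetUpbValue() below wrap and unwrap native C++ types in
+// upb_value so the typed table wrappers can insert and look up entries
+// generically.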
+ +namespace upb { + +template upb_value MakeUpbValue(T val); +template T GetUpbValue(upb_value val); + +#define FUNCS(name, type_t, enumval) \ + template<> upb_value MakeUpbValue(type_t val) { return upb_value_ ## name(val); } \ + template<> type_t GetUpbValue(upb_value val) { return upb_value_get ## name(val); } \ + +FUNCS(int32, int32_t, UPB_CTYPE_INT32) +FUNCS(int64, int64_t, UPB_CTYPE_INT64) +FUNCS(uint32, uint32_t, UPB_CTYPE_UINT32) +FUNCS(uint64, uint64_t, UPB_CTYPE_UINT64) +FUNCS(bool, bool, UPB_CTYPE_BOOL) +FUNCS(cstr, char*, UPB_CTYPE_CSTR) +FUNCS(ptr, void*, UPB_CTYPE_PTR) +FUNCS(constptr, const void*, UPB_CTYPE_CONSTPTR) +FUNCS(fptr, upb_func*, UPB_CTYPE_FPTR) + +#undef FUNCS + +class IntTable { + public: + IntTable() { upb_inttable_init(&table_, arena_.ptr()); } + + size_t count() { return upb_inttable_count(&table_); } + + bool Insert(uintptr_t key, upb_value val) { + return upb_inttable_insert(&table_, key, val, arena_.ptr()); + } + + bool Replace(uintptr_t key, upb_value val) { + return upb_inttable_replace(&table_, key, val); + } + + std::pair Remove(uintptr_t key) { + std::pair ret; + ret.first = upb_inttable_remove(&table_, key, &ret.second); + return ret; + } + + std::pair Lookup(uintptr_t key) const { + std::pair ret; + ret.first = upb_inttable_lookup(&table_, key, &ret.second); + return ret; + } + + std::pair Lookup32(uint32_t key) const { + std::pair ret; + ret.first = upb_inttable_lookup(&table_, key, &ret.second); + return ret; + } + + void Compact() { upb_inttable_compact(&table_, arena_.ptr()); } + + class iterator : public std::iterator > { + public: + explicit iterator(IntTable* table) { + upb_inttable_begin(&iter_, &table->table_); + } + + static iterator end(IntTable* table) { + iterator iter(table); + upb_inttable_iter_setdone(&iter.iter_); + return iter; + } + + void operator++() { + return upb_inttable_next(&iter_); + } + + std::pair operator*() const { + std::pair ret; + ret.first = upb_inttable_iter_key(&iter_); + ret.second = upb_inttable_iter_value(&iter_); + return ret; + } + + bool operator==(const iterator& other) const { + return upb_inttable_iter_isequal(&iter_, &other.iter_); + } + + bool operator!=(const iterator& other) const { + return !(*this == other); + } + + private: + upb_inttable_iter iter_; + }; + + upb::Arena arena_; + upb_inttable table_; +}; + +class StrTable { + public: + StrTable() { upb_strtable_init(&table_, 4, arena_.ptr()); } + + size_t count() { return upb_strtable_count(&table_); } + + bool Insert(const std::string& key, upb_value val) { + return upb_strtable_insert(&table_, key.c_str(), key.size(), val, + arena_.ptr()); + } + + std::pair Remove(const std::string& key) { + std::pair ret; + ret.first = + upb_strtable_remove(&table_, key.c_str(), key.size(), &ret.second); + return ret; + } + + std::pair Lookup(const std::string& key) const { + std::pair ret; + ret.first = + upb_strtable_lookup2(&table_, key.c_str(), key.size(), &ret.second); + return ret; + } + + void Resize(size_t size_lg2) { + upb_strtable_resize(&table_, size_lg2, arena_.ptr()); + } + + class iterator : public std::iterator > { + public: + explicit iterator(StrTable* table) { + upb_strtable_begin(&iter_, &table->table_); + } + + static iterator end(StrTable* table) { + iterator iter(table); + upb_strtable_iter_setdone(&iter.iter_); + return iter; + } + + void operator++() { + return upb_strtable_next(&iter_); + } + + std::pair operator*() const { + std::pair ret; + upb_strview view = upb_strtable_iter_key(&iter_); + ret.first.assign(view.data, view.size); + 
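+      // Copy the key bytes out of the table's string view so the returned
+      // pair owns its own std::string.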
ret.second = upb_strtable_iter_value(&iter_); + return ret; + } + + bool operator==(const iterator& other) const { + return upb_strtable_iter_isequal(&iter_, &other.iter_); + } + + bool operator!=(const iterator& other) const { + return !(*this == other); + } + + private: + upb_strtable_iter iter_; + }; + + upb::Arena arena_; + upb_strtable table_; +}; + +template class TypedStrTable { + public: + size_t count() { return table_.count(); } + + bool Insert(const std::string &key, T val) { + return table_.Insert(key, MakeUpbValue(val)); + } + + std::pair Remove(const std::string& key) { + std::pair found = table_.Remove(key); + std::pair ret; + ret.first = found.first; + if (ret.first) { + ret.second = GetUpbValue(found.second); + } + return ret; + } + + std::pair Lookup(const std::string& key) const { + std::pair found = table_.Lookup(key); + std::pair ret; + ret.first = found.first; + if (ret.first) { + ret.second = GetUpbValue(found.second); + } + return ret; + } + + void Resize(size_t size_lg2) { + table_.Resize(size_lg2); + } + + class iterator : public std::iterator > { + public: + explicit iterator(TypedStrTable* table) : iter_(&table->table_) {} + static iterator end(TypedStrTable* table) { + iterator iter(table); + iter.iter_ = StrTable::iterator::end(&table->table_); + return iter; + } + + void operator++() { ++iter_; } + + std::pair operator*() const { + std::pair val = *iter_; + std::pair ret; + ret.first = val.first; + ret.second = GetUpbValue(val.second); + return ret; + } + + bool operator==(const iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const iterator& other) const { + return iter_ != other.iter_; + } + + private: + StrTable::iterator iter_; + }; + + iterator begin() { return iterator(this); } + iterator end() { return iterator::end(this); } + + StrTable table_; +}; + +template class TypedIntTable { + public: + size_t count() { return table_.count(); } + + bool Insert(uintptr_t key, T val) { + return table_.Insert(key, MakeUpbValue(val)); + } + + bool Replace(uintptr_t key, T val) { + return table_.Replace(key, MakeUpbValue(val)); + } + + std::pair Remove(uintptr_t key) { + std::pair found = table_.Remove(key); + std::pair ret; + ret.first = found.first; + if (ret.first) { + ret.second = GetUpbValue(found.second); + } + return ret; + } + + std::pair Lookup(uintptr_t key) const { + std::pair found = table_.Lookup(key); + std::pair ret; + ret.first = found.first; + if (ret.first) { + ret.second = GetUpbValue(found.second); + } + return ret; + } + + void Compact() { table_.Compact(); } + + class iterator : public std::iterator > { + public: + explicit iterator(TypedIntTable* table) : iter_(&table->table_) {} + static iterator end(TypedIntTable* table) { + return IntTable::iterator::end(&table->table_); + } + + void operator++() { ++iter_; } + + std::pair operator*() const { + std::pair val = *iter_; + std::pair ret; + ret.first = val.first; + ret.second = GetUpbValue(val.second); + return ret; + } + + bool operator==(const iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const iterator& other) const { + return iter_ != other.iter_; + } + + private: + IntTable::iterator iter_; + }; + + iterator begin() { return iterator(this); } + iterator end() { return iterator::end(this); } + + IntTable table_; +}; + +} + +bool benchmark = false; +#define CPU_TIME_PER_TEST 0.5 + +using std::vector; + +double get_usertime() { + struct rusage usage; + getrusage(RUSAGE_SELF, &usage); + return usage.ru_utime.tv_sec + 
(usage.ru_utime.tv_usec/1000000.0); +} + +/* num_entries must be a power of 2. */ +void test_strtable(const vector& keys, uint32_t num_to_insert) { + /* Initialize structures. */ + std::map m; + typedef upb::TypedStrTable Table; + Table table; + std::set all; + for(size_t i = 0; i < num_to_insert; i++) { + const std::string& key = keys[i]; + all.insert(key); + table.Insert(key, key[0]); + m[key] = key[0]; + } + + /* Test correctness. */ + for(uint32_t i = 0; i < keys.size(); i++) { + const std::string& key = keys[i]; + std::pair found = table.Lookup(key); + if(m.find(key) != m.end()) { /* Assume map implementation is correct. */ + ASSERT(found.first); + ASSERT(found.second == key[0]); + ASSERT(m[key] == key[0]); + } else { + ASSERT(!found.first); + } + } + + for (Table::iterator it = table.begin(); it != table.end(); ++it) { + std::set::iterator i = all.find((*it).first); + ASSERT(i != all.end()); + all.erase(i); + } + ASSERT(all.empty()); + + // Test iteration with resizes. + + for (int i = 0; i < 10; i++) { + for (Table::iterator it = table.begin(); it != table.end(); ++it) { + // Even if we invalidate the iterator it should only return real elements. + ASSERT((*it).second == m[(*it).first]); + + // Force a resize even though the size isn't changing. + // Also forces the table size to grow so some new buckets end up empty. + int new_lg2 = table.table_.table_.t.size_lg2 + 1; + // Don't use more than 64k tables, to avoid exhausting memory. + new_lg2 = UPB_MIN(new_lg2, 16); + table.Resize(new_lg2); + } + } + +} + +/* num_entries must be a power of 2. */ +void test_inttable(int32_t *keys, uint16_t num_entries, const char *desc) { + /* Initialize structures. */ + typedef upb::TypedIntTable Table; + Table table; + uint32_t largest_key = 0; + std::map m; + std::unordered_map hm; + for(size_t i = 0; i < num_entries; i++) { + int32_t key = keys[i]; + largest_key = UPB_MAX((int32_t)largest_key, key); + table.Insert(key, key * 2); + m[key] = key*2; + hm[key] = key*2; + } + + /* Test correctness. */ + for(uint32_t i = 0; i <= largest_key; i++) { + std::pair found = table.Lookup(i); + if(m.find(i) != m.end()) { /* Assume map implementation is correct. */ + ASSERT(found.first); + ASSERT(found.second == i*2); + ASSERT(m[i] == i*2); + ASSERT(hm[i] == i*2); + } else { + ASSERT(!found.first); + } + } + + for(uint16_t i = 0; i < num_entries; i += 2) { + std::pair found = table.Remove(keys[i]); + ASSERT(found.first == (m.erase(keys[i]) == 1)); + if (found.first) ASSERT(found.second == (uint32_t)keys[i] * 2); + hm.erase(keys[i]); + m.erase(keys[i]); + } + + ASSERT(table.count() == hm.size()); + + /* Test correctness. */ + for(uint32_t i = 0; i <= largest_key; i++) { + std::pair found = table.Lookup(i); + if(m.find(i) != m.end()) { /* Assume map implementation is correct. */ + ASSERT(found.first); + ASSERT(found.second == i*2); + ASSERT(m[i] == i*2); + ASSERT(hm[i] == i*2); + } else { + ASSERT(!found.first); + } + } + + // Test replace. + for(uint32_t i = 0; i <= largest_key; i++) { + bool replaced = table.Replace(i, i*3); + if(m.find(i) != m.end()) { /* Assume map implementation is correct. */ + ASSERT(replaced); + m[i] = i * 3; + hm[i] = i * 3; + } else { + ASSERT(!replaced); + } + } + + // Compact and test correctness again. + table.Compact(); + for(uint32_t i = 0; i <= largest_key; i++) { + std::pair found = table.Lookup(i); + if(m.find(i) != m.end()) { /* Assume map implementation is correct. 
*/ + ASSERT(found.first); + ASSERT(found.second == i*3); + ASSERT(m[i] == i*3); + ASSERT(hm[i] == i*3); + } else { + ASSERT(!found.first); + } + } + + if(!benchmark) { + return; + } + + printf("%s\n", desc); + + /* Test performance. We only test lookups for keys that are known to exist. */ + uint16_t *rand_order = new uint16_t[num_entries]; + for(uint16_t i = 0; i < num_entries; i++) { + rand_order[i] = i; + } + for(uint16_t i = num_entries - 1; i >= 1; i--) { + uint16_t rand_i = (random() / (double)RAND_MAX) * i; + ASSERT(rand_i <= i); + uint16_t tmp = rand_order[rand_i]; + rand_order[rand_i] = rand_order[i]; + rand_order[i] = tmp; + } + + uintptr_t x = 0; + const int mask = num_entries - 1; + int time_mask = 0xffff; + + printf("upb_inttable(seq): "); + fflush(stdout); + double before = get_usertime(); + unsigned int i; + +#define MAYBE_BREAK \ + if ((i & time_mask) == 0 && (get_usertime() - before) > CPU_TIME_PER_TEST) \ + break; + for(i = 0; true; i++) { + MAYBE_BREAK; + int32_t key = keys[i & mask]; + upb_value v; + bool ok = upb_inttable_lookup(&table.table_.table_, key, &v); + x += (uintptr_t)ok; + } + double total = get_usertime() - before; + printf("%ld/s\n", (long)(i/total)); + double upb_seq_i = i / 100; // For later percentage calcuation. + + printf("upb_inttable(rand): "); + fflush(stdout); + before = get_usertime(); + for(i = 0; true; i++) { + MAYBE_BREAK; + int32_t key = keys[rand_order[i & mask]]; + upb_value v; + bool ok = upb_inttable_lookup(&table.table_.table_, key, &v); + x += (uintptr_t)ok; + } + total = get_usertime() - before; + printf("%ld/s\n", (long)(i/total)); + double upb_rand_i = i / 100; // For later percentage calculation. + + printf("std::map(seq): "); + fflush(stdout); + before = get_usertime(); + for(i = 0; true; i++) { + MAYBE_BREAK; + int32_t key = keys[i & mask]; + x += m[key]; + } + total = get_usertime() - before; + printf("%ld/s (%0.1f%% of upb)\n", (long)(i/total), i / upb_seq_i); + + printf("std::map(rand): "); + fflush(stdout); + before = get_usertime(); + for(i = 0; true; i++) { + MAYBE_BREAK; + int32_t key = keys[rand_order[i & mask]]; + x += m[key]; + } + total = get_usertime() - before; + printf("%ld/s (%0.1f%% of upb)\n", (long)(i/total), i / upb_rand_i); + + printf("std::unordered_map(seq): "); + fflush(stdout); + before = get_usertime(); + for(i = 0; true; i++) { + MAYBE_BREAK; + int32_t key = keys[rand_order[i & mask]]; + x += hm[key]; + } + total = get_usertime() - before; + printf("%ld/s (%0.1f%% of upb)\n", (long)(i/total), i / upb_seq_i); + + printf("std::unordered_map(rand): "); + fflush(stdout); + before = get_usertime(); + for(i = 0; true; i++) { + MAYBE_BREAK; + int32_t key = keys[rand_order[i & mask]]; + x += hm[key]; + } + total = get_usertime() - before; + if (x == INT_MAX) abort(); + printf("%ld/s (%0.1f%% of upb)\n\n", (long)(i/total), i / upb_rand_i); + delete[] rand_order; +} + +/* + * This test can't pass right now because the table can't store a value of + * (uint64_t)-1. 
+ */ +void test_int64_max_value() { +/* + typedef upb::TypedIntTable Table; + Table table; + uintptr_t uint64_max = (uint64_t)-1; + table.Insert(1, uint64_max); + std::pair found = table.Lookup(1); + ASSERT(found.first); + ASSERT(found.second == uint64_max); +*/ +} + +int32_t *get_contiguous_keys(int32_t num) { + int32_t *buf = new int32_t[num]; + for(int32_t i = 0; i < num; i++) + buf[i] = i; + return buf; +} + +void test_delete() { + upb::Arena arena; + upb_inttable t; + upb_inttable_init(&t, arena.ptr()); + upb_inttable_insert(&t, 0, upb_value_bool(true), arena.ptr()); + upb_inttable_insert(&t, 2, upb_value_bool(true), arena.ptr()); + upb_inttable_insert(&t, 4, upb_value_bool(true), arena.ptr()); + upb_inttable_compact(&t, arena.ptr()); + upb_inttable_remove(&t, 0, NULL); + upb_inttable_remove(&t, 2, NULL); + upb_inttable_remove(&t, 4, NULL); + + upb_inttable_iter iter; + for (upb_inttable_begin(&iter, &t); !upb_inttable_done(&iter); + upb_inttable_next(&iter)) { + ASSERT(false); + } +} + +void test_init() { + for (int i = 0; i < 2048; i++) { + /* Tests that the size calculations in init() (lg2 size for target load) + * work for all expected sizes. */ + upb::Arena arena; + upb_strtable t; + upb_strtable_init(&t, i, arena.ptr()); + } +} + +extern "C" { + +int run_tests(int argc, char *argv[]) { + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "benchmark") == 0) benchmark = true; + } + + vector keys; + keys.push_back("google.protobuf.FileDescriptorSet"); + keys.push_back("google.protobuf.FileDescriptorProto"); + keys.push_back("google.protobuf.DescriptorProto"); + keys.push_back("google.protobuf.DescriptorProto.ExtensionRange"); + keys.push_back("google.protobuf.FieldDescriptorProto"); + keys.push_back("google.protobuf.EnumDescriptorProto"); + keys.push_back("google.protobuf.EnumValueDescriptorProto"); + keys.push_back("google.protobuf.ServiceDescriptorProto"); + keys.push_back("google.protobuf.MethodDescriptorProto"); + keys.push_back("google.protobuf.FileOptions"); + keys.push_back("google.protobuf.MessageOptions"); + keys.push_back("google.protobuf.FieldOptions"); + keys.push_back("google.protobuf.EnumOptions"); + keys.push_back("google.protobuf.EnumValueOptions"); + keys.push_back("google.protobuf.ServiceOptions"); + keys.push_back("google.protobuf.MethodOptions"); + keys.push_back("google.protobuf.UninterpretedOption"); + keys.push_back("google.protobuf.UninterpretedOption.NamePart"); + + for (int i = 0; i < 10; i++) { + test_strtable(keys, 18); + } + + int32_t *keys1 = get_contiguous_keys(8); + test_inttable(keys1, 8, "Table size: 8, keys: 1-8 ===="); + delete[] keys1; + + int32_t *keys2 = get_contiguous_keys(64); + test_inttable(keys2, 64, "Table size: 64, keys: 1-64 ====\n"); + delete[] keys2; + + int32_t *keys3 = get_contiguous_keys(512); + test_inttable(keys3, 512, "Table size: 512, keys: 1-512 ====\n"); + delete[] keys3; + + int32_t *keys4 = new int32_t[64]; + for(int32_t i = 0; i < 64; i++) { + if(i < 32) + keys4[i] = i+1; + else + keys4[i] = 10101+i; + } + test_inttable(keys4, 64, "Table size: 64, keys: 1-32 and 10133-10164 ====\n"); + delete[] keys4; + + test_delete(); + test_int64_max_value(); + + return 0; +} + +} diff --git a/third_party/upb/tests/testmain.cc b/third_party/upb/tests/testmain.cc new file mode 100644 index 00000000..3bd345ea --- /dev/null +++ b/third_party/upb/tests/testmain.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#ifdef USE_GOOGLE +#include "base/init_google.h" +#endif + +extern "C" { +int run_tests(int argc, char *argv[]); +} + +int main(int argc, char *argv[]) { +#ifdef USE_GOOGLE + InitGoogle(NULL, &argc, &argv, true); +#endif + run_tests(argc, argv); +} diff --git a/third_party/upb/upb/bindings/lua/upbc.cc b/third_party/upb/upb/bindings/lua/upbc.cc new file mode 100644 index 00000000..e2bb0dd6 --- /dev/null +++ b/third_party/upb/upb/bindings/lua/upbc.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/strings/str_replace.h" +#include "google/protobuf/compiler/code_generator.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor.pb.h" +#include +#include + +namespace protoc = ::google::protobuf::compiler; +namespace protobuf = ::google::protobuf; + +class LuaGenerator : public protoc::CodeGenerator { + bool Generate(const protobuf::FileDescriptor* file, + const std::string& parameter, protoc::GeneratorContext* context, + std::string* error) const override; + +}; + +static std::string StripExtension(absl::string_view fname) { + size_t lastdot = fname.find_last_of('.'); + if (lastdot == std::string::npos) { + return std::string(fname); + } + return std::string(fname.substr(0, lastdot)); +} + +static std::string Filename(const protobuf::FileDescriptor* file) { + return StripExtension(file->name()) + "_pb.lua"; +} + +static std::string ModuleName(const protobuf::FileDescriptor* file) { + std::string ret = StripExtension(file->name()) + "_pb"; + return absl::StrReplaceAll(ret, {{"/", "."}}); +} + +static void PrintHexDigit(char digit, protobuf::io::Printer* printer) { + char text; + if (digit < 10) { + text = '0' + digit; + } else { + text = 'A' + (digit - 10); + } + printer->WriteRaw(&text, 1); +} + +static void PrintString(int max_cols, absl::string_view* str, + protobuf::io::Printer* printer) { + printer->Print("\'"); + while (max_cols > 0 && !str->empty()) { + char ch = (*str)[0]; + if (ch == '\\') { + printer->PrintRaw("\\\\"); + max_cols--; + } else if (ch == '\'') { + printer->PrintRaw("\\'"); + max_cols--; + } else if (isprint(ch)) { + printer->WriteRaw(&ch, 1); + max_cols--; + } else { + unsigned char byte = ch; + printer->PrintRaw("\\x"); + PrintHexDigit(byte >> 4, printer); + PrintHexDigit(byte & 15, printer); + max_cols -= 4; + } + str->remove_prefix(1); + } + printer->Print("\'"); +} + +bool LuaGenerator::Generate( + const protobuf::FileDescriptor* file, + const std::string& /* parameter */, + protoc::GeneratorContext* context, + std::string* /* error */) const { + std::string filename = Filename(file); + protobuf::io::ZeroCopyOutputStream* out = context->Open(filename); + protobuf::io::Printer printer(out, '$'); + + for (int i = 0; i < file->dependency_count(); i++) { + const protobuf::FileDescriptor* dep = file->dependency(i); + printer.Print("require('$name$')\n", "name", ModuleName(dep)); + } + + printer.Print("local upb = require('upb')\n"); + + protobuf::FileDescriptorProto file_proto; + file->CopyTo(&file_proto); + std::string file_data; + file_proto.SerializeToString(&file_data); + + printer.Print("local descriptor = table.concat({\n"); + absl::string_view data(file_data); + while (!data.empty()) { + printer.Print(" "); + PrintString(72, &data, &printer); + printer.Print(",\n"); + } + printer.Print("})\n"); + + printer.Print("return upb._generated_module(descriptor)\n"); + + return true; +} + +int main(int argc, char** argv) { + LuaGenerator generator; + return google::protobuf::compiler::PluginMain(argc, argv, &generator); +} diff --git a/third_party/upb/upbc/common.cc b/third_party/upb/upbc/common.cc new file mode 100644 index 00000000..d2df929c --- /dev/null +++ b/third_party/upb/upbc/common.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "absl/strings/str_replace.h" +#include "upbc/common.h" + +namespace upbc { +namespace { + +namespace protobuf = ::google::protobuf; + +void AddMessages(const protobuf::Descriptor* message, + std::vector* messages) { + messages->push_back(message); + for (int i = 0; i < message->nested_type_count(); i++) { + AddMessages(message->nested_type(i), messages); + } +} + +} // namespace + +std::string StripExtension(absl::string_view fname) { + size_t lastdot = fname.find_last_of('.'); + if (lastdot == std::string::npos) { + return std::string(fname); + } + return std::string(fname.substr(0, lastdot)); +} + +std::string ToCIdent(absl::string_view str) { + return absl::StrReplaceAll(str, {{".", "_"}, {"/", "_"}}); +} + +std::string ToPreproc(absl::string_view str) { + return absl::AsciiStrToUpper(ToCIdent(str)); +} + +void EmitFileWarning(const protobuf::FileDescriptor* file, Output& output) { + output( + "/* This file was generated by upbc (the upb compiler) from the input\n" + " * file:\n" + " *\n" + " * $0\n" + " *\n" + " * Do not edit -- your changes will be discarded when the file is\n" + " * regenerated. */\n\n", + file->name()); +} + +std::vector SortedMessages( + const protobuf::FileDescriptor* file) { + std::vector messages; + for (int i = 0; i < file->message_type_count(); i++) { + AddMessages(file->message_type(i), &messages); + } + return messages; +} + +std::string MessageName(const protobuf::Descriptor* descriptor) { + return ToCIdent(descriptor->full_name()); +} + +std::string MessageInit(const protobuf::Descriptor* descriptor) { + return MessageName(descriptor) + "_msginit"; +} + +} // namespace upbc diff --git a/third_party/upb/upbc/message_layout.cc b/third_party/upb/upbc/message_layout.cc new file mode 100644 index 00000000..1c5cd0b1 --- /dev/null +++ b/third_party/upb/upbc/message_layout.cc @@ -0,0 +1,218 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. 
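The naming helpers in upbc/common.cc above determine every identifier in the generated code. The snippet below is a hypothetical, standalone sanity check and not part of the patch; it assumes upbc/common.h is on the include path, and the inputs ("google.protobuf.Timestamp", "google/protobuf/timestamp.proto") are arbitrary examples.

// Illustrative only: the expected strings follow directly from the
// StrReplaceAll / AsciiStrToUpper logic shown in upbc/common.cc above.
#include <cassert>
#include <string>

#include "upbc/common.h"

int main() {
  // '.' and '/' both map to '_', giving a valid C identifier.
  assert(upbc::ToCIdent("google.protobuf.Timestamp") ==
         "google_protobuf_Timestamp");
  // Upper-cased C identifier, used for include guards such as FOO_UPB_H_.
  assert(upbc::ToPreproc("google/protobuf/timestamp.proto") ==
         "GOOGLE_PROTOBUF_TIMESTAMP_PROTO");
  // StripExtension drops everything from the last '.' onward.
  assert(upbc::StripExtension("google/protobuf/timestamp.proto") ==
         "google/protobuf/timestamp");
  return 0;
}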
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "upbc/message_layout.h" +#include "google/protobuf/descriptor.pb.h" + +namespace upbc { + +namespace protobuf = ::google::protobuf; + +static int64_t DivRoundUp(int64_t a, int64_t b) { + ABSL_ASSERT(a >= 0); + ABSL_ASSERT(b > 0); + return (a + b - 1) / b; +} + +MessageLayout::Size MessageLayout::Place( + MessageLayout::SizeAndAlign size_and_align) { + Size offset = size_; + offset.AlignUp(size_and_align.align); + size_ = offset; + size_.Add(size_and_align.size); + //maxalign_.MaxFrom(size_and_align.align); + maxalign_.MaxFrom(size_and_align.size); + return offset; +} + +bool MessageLayout::HasHasbit(const protobuf::FieldDescriptor* field) { + return field->has_presence() && !field->real_containing_oneof() && + !field->containing_type()->options().map_entry(); +} + +MessageLayout::SizeAndAlign MessageLayout::SizeOf( + const protobuf::FieldDescriptor* field) { + if (field->is_repeated()) { + return {{4, 8}, {4, 8}}; // Pointer to array object. + } else { + return SizeOfUnwrapped(field); + } +} + +MessageLayout::SizeAndAlign MessageLayout::SizeOfUnwrapped( + const protobuf::FieldDescriptor* field) { + switch (field->cpp_type()) { + case protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return {{4, 8}, {4, 8}}; // Pointer to message. + case protobuf::FieldDescriptor::CPPTYPE_STRING: + return {{8, 16}, {4, 8}}; // upb_strview + case protobuf::FieldDescriptor::CPPTYPE_BOOL: + return {{1, 1}, {1, 1}}; + case protobuf::FieldDescriptor::CPPTYPE_FLOAT: + case protobuf::FieldDescriptor::CPPTYPE_INT32: + case protobuf::FieldDescriptor::CPPTYPE_UINT32: + case protobuf::FieldDescriptor::CPPTYPE_ENUM: + return {{4, 4}, {4, 4}}; + case protobuf::FieldDescriptor::CPPTYPE_INT64: + case protobuf::FieldDescriptor::CPPTYPE_UINT64: + case protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return {{8, 8}, {8, 8}}; + } + assert(false); + return {{-1, -1}, {-1, -1}}; +} + +int64_t MessageLayout::FieldLayoutRank(const protobuf::FieldDescriptor* field) { + // Order: + // 1, 2, 3. primitive fields (8, 4, 1 byte) + // 4. string fields + // 5. submessage fields + // 6. 
repeated fields + // + // This has the following nice properties: + // + // 1. padding alignment is (nearly) minimized. + // 2. fields that might have defaults (1-4) are segregated + // from fields that are always zero-initialized (5-7). + // + // We skip oneof fields, because they are emitted in a separate pass. + int64_t rank; + if (field->containing_oneof()) { + fprintf(stderr, "shouldn't have oneofs here.\n"); + abort(); + } else if (field->label() == protobuf::FieldDescriptor::LABEL_REPEATED) { + rank = 6; + } else { + switch (field->cpp_type()) { + case protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + rank = 5; + break; + case protobuf::FieldDescriptor::CPPTYPE_STRING: + rank = 4; + break; + case protobuf::FieldDescriptor::CPPTYPE_BOOL: + rank = 3; + break; + case protobuf::FieldDescriptor::CPPTYPE_FLOAT: + case protobuf::FieldDescriptor::CPPTYPE_INT32: + case protobuf::FieldDescriptor::CPPTYPE_UINT32: + rank = 2; + break; + default: + rank = 1; + break; + } + } + + // Break ties with field number. + return (rank << 29) | field->number(); +} + +void MessageLayout::ComputeLayout(const protobuf::Descriptor* descriptor) { + size_ = Size{0, 0}; + maxalign_ = Size{8, 8}; + + if (descriptor->options().map_entry()) { + // Map entries aren't actually stored, they are only used during parsing. + // For parsing, it helps a lot if all map entry messages have the same + // layout. + SizeAndAlign size{{8, 16}, {4, 8}}; // upb_strview + field_offsets_[descriptor->FindFieldByNumber(1)] = Place(size); + field_offsets_[descriptor->FindFieldByNumber(2)] = Place(size); + } else { + PlaceNonOneofFields(descriptor); + PlaceOneofFields(descriptor); + } + + // Align overall size up to max size. + size_.AlignUp(maxalign_); +} + +void MessageLayout::PlaceNonOneofFields( + const protobuf::Descriptor* descriptor) { + std::vector field_order; + for (int i = 0; i < descriptor->field_count(); i++) { + const protobuf::FieldDescriptor* field = descriptor->field(i); + if (!field->containing_oneof()) { + field_order.push_back(descriptor->field(i)); + } + } + std::sort(field_order.begin(), field_order.end(), + [](const protobuf::FieldDescriptor* a, + const protobuf::FieldDescriptor* b) { + return FieldLayoutRank(a) < FieldLayoutRank(b); + }); + + // Place/count hasbits. + int hasbit_count = 0; + for (auto field : FieldHotnessOrder(descriptor)) { + if (HasHasbit(field)) { + // We don't use hasbit 0, so that 0 can indicate "no presence" in the + // table. This wastes one hasbit, but we don't worry about it for now. + hasbit_indexes_[field] = ++hasbit_count; + } + } + + // Place hasbits at the beginning. + int64_t hasbit_bytes = DivRoundUp(hasbit_count, 8); + Place(SizeAndAlign{{hasbit_bytes, hasbit_bytes}, {1, 1}}); + + // Place non-oneof fields. + for (auto field : field_order) { + field_offsets_[field] = Place(SizeOf(field)); + } +} + +void MessageLayout::PlaceOneofFields(const protobuf::Descriptor* descriptor) { + std::vector oneof_order; + for (int i = 0; i < descriptor->oneof_decl_count(); i++) { + oneof_order.push_back(descriptor->oneof_decl(i)); + } + std::sort(oneof_order.begin(), oneof_order.end(), + [](const protobuf::OneofDescriptor* a, + const protobuf::OneofDescriptor* b) { + return a->full_name() < b->full_name(); + }); + + for (auto oneof : oneof_order) { + SizeAndAlign oneof_maxsize{{0, 0}, {0, 0}}; + // Calculate max size. + for (int i = 0; i < oneof->field_count(); i++) { + oneof_maxsize.MaxFrom(SizeOf(oneof->field(i))); + } + + // Place discriminator enum and data. 
+ Size data = Place(oneof_maxsize); + Size discriminator = Place(SizeAndAlign{{4, 4}, {4, 4}}); + + oneof_case_offsets_[oneof] = discriminator; + + for (int i = 0; i < oneof->field_count(); i++) { + field_offsets_[oneof->field(i)] = data; + } + } +} + +} // namespace upbc diff --git a/third_party/upb/upbc/protoc-gen-upb.cc b/third_party/upb/upbc/protoc-gen-upb.cc new file mode 100644 index 00000000..a68a691b --- /dev/null +++ b/third_party/upb/upbc/protoc-gen-upb.cc @@ -0,0 +1,1059 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
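Stepping back to message_layout.cc above: the 32-/64-bit dual bookkeeping in MessageLayout::Place() is easier to see with concrete numbers. The sketch below is a simplified, standalone reimplementation of just the align-then-advance arithmetic, with hypothetical field names, sizes taken from SizeOf(), and fields listed in the order FieldLayoutRank() would produce; it is illustrative only and not part of the generator.

#include <cstdio>

// Illustrative only: mirrors the (size32, size64) pairs returned by SizeOf()
// and the placement logic of Place(), with hasbits placed first as in
// PlaceNonOneofFields().
struct Size { int s32; int s64; };

static int AlignUp(int v, int align) { return (v + align - 1) / align * align; }

int main() {
  struct Field { const char* name; Size size; Size align; };
  const Field fields[] = {
      {"hasbits",    {1, 1},  {1, 1}},  // one byte covers up to 8 hasbits
      {"int64_f",    {8, 8},  {8, 8}},  // rank 1: 8-byte scalar
      {"int32_f",    {4, 4},  {4, 4}},  // rank 2: 4-byte scalar
      {"bool_f",     {1, 1},  {1, 1}},  // rank 3: bool
      {"string_f",   {8, 16}, {4, 8}},  // rank 4: upb_strview
      {"msg_f",      {4, 8},  {4, 8}},  // rank 5: pointer to submessage
      {"repeated_f", {4, 8},  {4, 8}},  // rank 6: pointer to array
  };
  Size used = {0, 0};
  for (const Field& f : fields) {
    Size offset = {AlignUp(used.s32, f.align.s32),
                   AlignUp(used.s64, f.align.s64)};
    used = {offset.s32 + f.size.s32, offset.s64 + f.size.s64};
    std::printf("%-10s UPB_SIZE(%d, %d)\n", f.name, offset.s32, offset.s64);
  }
  // The two columns stay equal through string_f (UPB_SIZE(24, 24)) and
  // diverge afterwards: msg_f lands at UPB_SIZE(32, 40) and repeated_f at
  // UPB_SIZE(36, 48), which is exactly why generated offsets are emitted via
  // the UPB_SIZE(size32, size64) macro.
  return 0;
}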
+ +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" +#include "absl/strings/substitute.h" +#include "google/protobuf/compiler/code_generator.h" +#include "google/protobuf/compiler/plugin.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/wire_format.h" +#include "upbc/common.h" +#include "upbc/message_layout.h" + +namespace upbc { +namespace { + +namespace protoc = ::google::protobuf::compiler; +namespace protobuf = ::google::protobuf; + +std::string HeaderFilename(std::string proto_filename) { + return StripExtension(proto_filename) + ".upb.h"; +} + +std::string SourceFilename(std::string proto_filename) { + return StripExtension(proto_filename) + ".upb.c"; +} + +void AddEnums(const protobuf::Descriptor* message, + std::vector* enums) { + for (int i = 0; i < message->enum_type_count(); i++) { + enums->push_back(message->enum_type(i)); + } + for (int i = 0; i < message->nested_type_count(); i++) { + AddEnums(message->nested_type(i), enums); + } +} + +template +void SortDefs(std::vector* defs) { + std::sort(defs->begin(), defs->end(), + [](T a, T b) { return a->full_name() < b->full_name(); }); +} + +std::vector SortedEnums( + const protobuf::FileDescriptor* file) { + std::vector enums; + for (int i = 0; i < file->enum_type_count(); i++) { + enums.push_back(file->enum_type(i)); + } + for (int i = 0; i < file->message_type_count(); i++) { + AddEnums(file->message_type(i), &enums); + } + SortDefs(&enums); + return enums; +} + +std::vector FieldNumberOrder( + const protobuf::Descriptor* message) { + std::vector fields; + for (int i = 0; i < message->field_count(); i++) { + fields.push_back(message->field(i)); + } + std::sort(fields.begin(), fields.end(), + [](const protobuf::FieldDescriptor* a, + const protobuf::FieldDescriptor* b) { + return a->number() < b->number(); + }); + return fields; +} + +std::vector SortedSubmessages( + const protobuf::Descriptor* message) { + std::vector ret; + for (int i = 0; i < message->field_count(); i++) { + if (message->field(i)->cpp_type() == + protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + ret.push_back(message->field(i)); + } + } + std::sort(ret.begin(), ret.end(), + [](const protobuf::FieldDescriptor* a, + const protobuf::FieldDescriptor* b) { + return a->message_type()->full_name() < + b->message_type()->full_name(); + }); + return ret; +} + +std::string EnumValueSymbol(const protobuf::EnumValueDescriptor* value) { + return ToCIdent(value->full_name()); +} + +std::string GetSizeInit(const MessageLayout::Size& size) { + return absl::Substitute("UPB_SIZE($0, $1)", size.size32, size.size64); +} + +std::string CTypeInternal(const protobuf::FieldDescriptor* field, + bool is_const) { + std::string maybe_const = is_const ? "const " : ""; + switch (field->cpp_type()) { + case protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + std::string maybe_struct = + field->file() != field->message_type()->file() ? 
"struct " : ""; + return maybe_const + maybe_struct + MessageName(field->message_type()) + + "*"; + } + case protobuf::FieldDescriptor::CPPTYPE_BOOL: + return "bool"; + case protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return "float"; + case protobuf::FieldDescriptor::CPPTYPE_INT32: + case protobuf::FieldDescriptor::CPPTYPE_ENUM: + return "int32_t"; + case protobuf::FieldDescriptor::CPPTYPE_UINT32: + return "uint32_t"; + case protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return "double"; + case protobuf::FieldDescriptor::CPPTYPE_INT64: + return "int64_t"; + case protobuf::FieldDescriptor::CPPTYPE_UINT64: + return "uint64_t"; + case protobuf::FieldDescriptor::CPPTYPE_STRING: + return "upb_strview"; + default: + fprintf(stderr, "Unexpected type"); + abort(); + } +} + +std::string SizeLg2(const protobuf::FieldDescriptor* field) { + switch (field->cpp_type()) { + case protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return "UPB_SIZE(2, 3)"; + case protobuf::FieldDescriptor::CPPTYPE_ENUM: + return std::to_string(2); + case protobuf::FieldDescriptor::CPPTYPE_BOOL: + return std::to_string(1); + case protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return std::to_string(2); + case protobuf::FieldDescriptor::CPPTYPE_INT32: + return std::to_string(2); + case protobuf::FieldDescriptor::CPPTYPE_UINT32: + return std::to_string(2); + case protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return std::to_string(3); + case protobuf::FieldDescriptor::CPPTYPE_INT64: + return std::to_string(3); + case protobuf::FieldDescriptor::CPPTYPE_UINT64: + return std::to_string(3); + case protobuf::FieldDescriptor::CPPTYPE_STRING: + return "UPB_SIZE(3, 4)"; + default: + fprintf(stderr, "Unexpected type"); + abort(); + } +} + +std::string FieldDefault(const protobuf::FieldDescriptor* field) { + switch (field->cpp_type()) { + case protobuf::FieldDescriptor::CPPTYPE_MESSAGE: + return "NULL"; + case protobuf::FieldDescriptor::CPPTYPE_STRING: + return absl::Substitute("upb_strview_make(\"$0\", strlen(\"$0\"))", + absl::CEscape(field->default_value_string())); + case protobuf::FieldDescriptor::CPPTYPE_INT32: + return absl::StrCat(field->default_value_int32()); + case protobuf::FieldDescriptor::CPPTYPE_INT64: + return absl::StrCat(field->default_value_int64()); + case protobuf::FieldDescriptor::CPPTYPE_UINT32: + return absl::StrCat(field->default_value_uint32()); + case protobuf::FieldDescriptor::CPPTYPE_UINT64: + return absl::StrCat(field->default_value_uint64()); + case protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return absl::StrCat(field->default_value_float()); + case protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return absl::StrCat(field->default_value_double()); + case protobuf::FieldDescriptor::CPPTYPE_BOOL: + return field->default_value_bool() ? "true" : "false"; + case protobuf::FieldDescriptor::CPPTYPE_ENUM: + // Use a number instead of a symbolic name so that we don't require + // this enum's header to be included. 
+ return absl::StrCat(field->default_value_enum()->number()); + } + ABSL_ASSERT(false); + return "XXX"; +} + +std::string CType(const protobuf::FieldDescriptor* field) { + return CTypeInternal(field, false); +} + +std::string CTypeConst(const protobuf::FieldDescriptor* field) { + return CTypeInternal(field, true); +} + +void DumpEnumValues(const protobuf::EnumDescriptor* desc, Output& output) { + std::vector values; + for (int i = 0; i < desc->value_count(); i++) { + values.push_back(desc->value(i)); + } + std::sort(values.begin(), values.end(), + [](const protobuf::EnumValueDescriptor* a, + const protobuf::EnumValueDescriptor* b) { + return a->number() < b->number(); + }); + + for (size_t i = 0; i < values.size(); i++) { + auto value = values[i]; + output(" $0 = $1", EnumValueSymbol(value), value->number()); + if (i != values.size() - 1) { + output(","); + } + output("\n"); + } +} + +void GenerateMessageInHeader(const protobuf::Descriptor* message, Output& output) { + MessageLayout layout(message); + + output("/* $0 */\n\n", message->full_name()); + std::string msg_name = ToCIdent(message->full_name()); + + if (!message->options().map_entry()) { + output( + "UPB_INLINE $0 *$0_new(upb_arena *arena) {\n" + " return ($0 *)_upb_msg_new(&$1, arena);\n" + "}\n" + "UPB_INLINE $0 *$0_parse(const char *buf, size_t size,\n" + " upb_arena *arena) {\n" + " $0 *ret = $0_new(arena);\n" + " if (!ret) return NULL;\n" + " if (!upb_decode(buf, size, ret, &$1, arena)) return NULL;\n" + " return ret;\n" + "}\n" + "UPB_INLINE $0 *$0_parse_ex(const char *buf, size_t size,\n" + " const upb_extreg *extreg, int options,\n" + " upb_arena *arena) {\n" + " $0 *ret = $0_new(arena);\n" + " if (!ret) return NULL;\n" + " if (!_upb_decode(buf, size, ret, &$1, extreg, options, arena)) {\n" + " return NULL;\n" + " }\n" + " return ret;\n" + "}\n" + "UPB_INLINE char *$0_serialize(const $0 *msg, upb_arena *arena, size_t " + "*len) {\n" + " return upb_encode(msg, &$1, arena, len);\n" + "}\n" + "\n", + MessageName(message), MessageInit(message)); + } + + for (int i = 0; i < message->real_oneof_decl_count(); i++) { + const protobuf::OneofDescriptor* oneof = message->oneof_decl(i); + std::string fullname = ToCIdent(oneof->full_name()); + output("typedef enum {\n"); + for (int j = 0; j < oneof->field_count(); j++) { + const protobuf::FieldDescriptor* field = oneof->field(j); + output(" $0_$1 = $2,\n", fullname, field->name(), field->number()); + } + output( + " $0_NOT_SET = 0\n" + "} $0_oneofcases;\n", + fullname); + output( + "UPB_INLINE $0_oneofcases $1_$2_case(const $1* msg) { " + "return ($0_oneofcases)*UPB_PTR_AT(msg, $3, int32_t); }\n" + "\n", + fullname, msg_name, oneof->name(), + GetSizeInit(layout.GetOneofCaseOffset(oneof))); + } + + // Generate const methods. + + for (auto field : FieldNumberOrder(message)) { + // Generate hazzer (if any). 
+ if (layout.HasHasbit(field)) { + output( + "UPB_INLINE bool $0_has_$1(const $0 *msg) { " + "return _upb_hasbit(msg, $2); }\n", + msg_name, field->name(), layout.GetHasbitIndex(field)); + } else if (field->real_containing_oneof()) { + output( + "UPB_INLINE bool $0_has_$1(const $0 *msg) { " + "return _upb_getoneofcase(msg, $2) == $3; }\n", + msg_name, field->name(), + GetSizeInit( + layout.GetOneofCaseOffset(field->real_containing_oneof())), + field->number()); + } else if (field->message_type()) { + output( + "UPB_INLINE bool $0_has_$1(const $0 *msg) { " + "return _upb_has_submsg_nohasbit(msg, $2); }\n", + msg_name, field->name(), GetSizeInit(layout.GetFieldOffset(field))); + } + + // Generate getter. + if (field->is_map()) { + const protobuf::Descriptor* entry = field->message_type(); + const protobuf::FieldDescriptor* key = entry->FindFieldByNumber(1); + const protobuf::FieldDescriptor* val = entry->FindFieldByNumber(2); + output( + "UPB_INLINE size_t $0_$1_size(const $0 *msg) {" + "return _upb_msg_map_size(msg, $2); }\n", + msg_name, field->name(), GetSizeInit(layout.GetFieldOffset(field))); + output( + "UPB_INLINE bool $0_$1_get(const $0 *msg, $2 key, $3 *val) { " + "return _upb_msg_map_get(msg, $4, &key, $5, val, $6); }\n", + msg_name, field->name(), CType(key), CType(val), + GetSizeInit(layout.GetFieldOffset(field)), + key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(key)", + val->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(*val)"); + output( + "UPB_INLINE $0 $1_$2_next(const $1 *msg, size_t* iter) { " + "return ($0)_upb_msg_map_next(msg, $3, iter); }\n", + CTypeConst(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + } else if (message->options().map_entry()) { + output( + "UPB_INLINE $0 $1_$2(const $1 *msg) {\n" + " $3 ret;\n" + " _upb_msg_map_$2(msg, &ret, $4);\n" + " return ret;\n" + "}\n", + CTypeConst(field), msg_name, field->name(), CType(field), + field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(ret)"); + } else if (field->is_repeated()) { + output( + "UPB_INLINE $0 const* $1_$2(const $1 *msg, size_t *len) { " + "return ($0 const*)_upb_array_accessor(msg, $3, len); }\n", + CTypeConst(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + } else if (field->real_containing_oneof()) { + output( + "UPB_INLINE $0 $1_$2(const $1 *msg) { " + "return UPB_READ_ONEOF(msg, $0, $3, $4, $5, $6); }\n", + CTypeConst(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field)), + GetSizeInit(layout.GetOneofCaseOffset(field->real_containing_oneof())), + field->number(), FieldDefault(field)); + } else { + output( + "UPB_INLINE $0 $1_$2(const $1 *msg) { " + "return *UPB_PTR_AT(msg, $3, $0); }\n", + CTypeConst(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + } + } + + output("\n"); + + // Generate mutable methods. + + for (auto field : FieldNumberOrder(message)) { + if (field->is_map()) { + // TODO(haberman): add map-based mutators. 
+ const protobuf::Descriptor* entry = field->message_type(); + const protobuf::FieldDescriptor* key = entry->FindFieldByNumber(1); + const protobuf::FieldDescriptor* val = entry->FindFieldByNumber(2); + output( + "UPB_INLINE void $0_$1_clear($0 *msg) { _upb_msg_map_clear(msg, $2); }\n", + msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + output( + "UPB_INLINE bool $0_$1_set($0 *msg, $2 key, $3 val, upb_arena *a) { " + "return _upb_msg_map_set(msg, $4, &key, $5, &val, $6, a); }\n", + msg_name, field->name(), CType(key), CType(val), + GetSizeInit(layout.GetFieldOffset(field)), + key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(key)", + val->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(val)"); + output( + "UPB_INLINE bool $0_$1_delete($0 *msg, $2 key) { " + "return _upb_msg_map_delete(msg, $3, &key, $4); }\n", + msg_name, field->name(), CType(key), + GetSizeInit(layout.GetFieldOffset(field)), + key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? "0" + : "sizeof(key)"); + output( + "UPB_INLINE $0 $1_$2_nextmutable($1 *msg, size_t* iter) { " + "return ($0)_upb_msg_map_next(msg, $3, iter); }\n", + CType(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + } else if (field->is_repeated()) { + output( + "UPB_INLINE $0* $1_mutable_$2($1 *msg, size_t *len) {\n" + " return ($0*)_upb_array_mutable_accessor(msg, $3, len);\n" + "}\n", + CType(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field))); + output( + "UPB_INLINE $0* $1_resize_$2($1 *msg, size_t len, " + "upb_arena *arena) {\n" + " return ($0*)_upb_array_resize_accessor2(msg, $3, len, $4, arena);\n" + "}\n", + CType(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field)), + SizeLg2(field)); + if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + output( + "UPB_INLINE struct $0* $1_add_$2($1 *msg, upb_arena *arena) {\n" + " struct $0* sub = (struct $0*)_upb_msg_new(&$3, arena);\n" + " bool ok = _upb_array_append_accessor2(\n" + " msg, $4, $5, &sub, arena);\n" + " if (!ok) return NULL;\n" + " return sub;\n" + "}\n", + MessageName(field->message_type()), msg_name, field->name(), + MessageInit(field->message_type()), + GetSizeInit(layout.GetFieldOffset(field)), + SizeLg2(field)); + } else { + output( + "UPB_INLINE bool $1_add_$2($1 *msg, $0 val, upb_arena *arena) {\n" + " return _upb_array_append_accessor2(msg, $3, $4, &val,\n" + " arena);\n" + "}\n", + CType(field), msg_name, field->name(), + GetSizeInit(layout.GetFieldOffset(field)), + SizeLg2(field)); + } + } else { + // Non-repeated field. + if (message->options().map_entry() && field->name() == "key") { + // Key cannot be mutated. + continue; + } + + // The common function signature for all setters. Varying implementations + // follow. + output("UPB_INLINE void $0_set_$1($0 *msg, $2 value) {\n", msg_name, + field->name(), CType(field)); + + if (message->options().map_entry()) { + output( + " _upb_msg_map_set_value(msg, &value, $0);\n" + "}\n", + field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING + ? 
"0" + : "sizeof(" + CType(field) + ")"); + } else if (field->real_containing_oneof()) { + output( + " UPB_WRITE_ONEOF(msg, $0, $1, value, $2, $3);\n" + "}\n", + CType(field), GetSizeInit(layout.GetFieldOffset(field)), + GetSizeInit( + layout.GetOneofCaseOffset(field->real_containing_oneof())), + field->number()); + } else { + if (MessageLayout::HasHasbit(field)) { + output(" _upb_sethas(msg, $0);\n", layout.GetHasbitIndex(field)); + } + output( + " *UPB_PTR_AT(msg, $1, $0) = value;\n" + "}\n", + CType(field), GetSizeInit(layout.GetFieldOffset(field))); + } + + if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + !message->options().map_entry()) { + output( + "UPB_INLINE struct $0* $1_mutable_$2($1 *msg, upb_arena *arena) {\n" + " struct $0* sub = (struct $0*)$1_$2(msg);\n" + " if (sub == NULL) {\n" + " sub = (struct $0*)_upb_msg_new(&$3, arena);\n" + " if (!sub) return NULL;\n" + " $1_set_$2(msg, sub);\n" + " }\n" + " return sub;\n" + "}\n", + MessageName(field->message_type()), msg_name, field->name(), + MessageInit(field->message_type())); + } + } + } + + output("\n"); +} + +void WriteHeader(const protobuf::FileDescriptor* file, Output& output) { + EmitFileWarning(file, output); + output( + "#ifndef $0_UPB_H_\n" + "#define $0_UPB_H_\n\n" + "#include \"upb/msg_internal.h\"\n" + "#include \"upb/decode.h\"\n" + "#include \"upb/decode_fast.h\"\n" + "#include \"upb/encode.h\"\n\n", + ToPreproc(file->name())); + + for (int i = 0; i < file->public_dependency_count(); i++) { + const auto& name = file->public_dependency(i)->name(); + if (i == 0) { + output("/* Public Imports. */\n"); + } + output("#include \"$0\"\n", HeaderFilename(name)); + if (i == file->public_dependency_count() - 1) { + output("\n"); + } + } + + output( + "#include \"upb/port_def.inc\"\n" + "\n" + "#ifdef __cplusplus\n" + "extern \"C\" {\n" + "#endif\n" + "\n"); + + const std::vector this_file_messages = + SortedMessages(file); + + // Forward-declare types defined in this file. + for (auto message : this_file_messages) { + output("struct $0;\n", ToCIdent(message->full_name())); + } + for (auto message : this_file_messages) { + output("typedef struct $0 $0;\n", ToCIdent(message->full_name())); + } + for (auto message : this_file_messages) { + output("extern const upb_msglayout $0;\n", MessageInit(message)); + } + + // Forward-declare types not in this file, but used as submessages. + // Order by full name for consistent ordering. 
+ std::map forward_messages; + + for (auto* message : this_file_messages) { + for (int i = 0; i < message->field_count(); i++) { + const protobuf::FieldDescriptor* field = message->field(i); + if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + field->file() != field->message_type()->file()) { + forward_messages[field->message_type()->full_name()] = + field->message_type(); + } + } + } + for (const auto& pair : forward_messages) { + output("struct $0;\n", MessageName(pair.second)); + } + for (const auto& pair : forward_messages) { + output("extern const upb_msglayout $0;\n", MessageInit(pair.second)); + } + + if (!this_file_messages.empty()) { + output("\n"); + } + + std::vector this_file_enums = + SortedEnums(file); + + for (auto enumdesc : this_file_enums) { + output("typedef enum {\n"); + DumpEnumValues(enumdesc, output); + output("} $0;\n\n", ToCIdent(enumdesc->full_name())); + } + + output("\n"); + + for (auto message : this_file_messages) { + GenerateMessageInHeader(message, output); + } + + output( + "#ifdef __cplusplus\n" + "} /* extern \"C\" */\n" + "#endif\n" + "\n" + "#include \"upb/port_undef.inc\"\n" + "\n" + "#endif /* $0_UPB_H_ */\n", + ToPreproc(file->name())); +} + +int TableDescriptorType(const protobuf::FieldDescriptor* field) { + if (field->file()->syntax() == protobuf::FileDescriptor::SYNTAX_PROTO2 && + field->type() == protobuf::FieldDescriptor::TYPE_STRING) { + // From the perspective of the binary encoder/decoder, proto2 string fields + // are identical to bytes fields. Only in proto3 do we check UTF-8 for + // string fields at parse time. + // + // If we ever use these tables for JSON encoding/decoding (for example by + // embedding field names on the side) we will have to revisit this, because + // string vs. bytes behavior is not affected by proto2 vs proto3. + return protobuf::FieldDescriptor::TYPE_BYTES; + } else { + return field->type(); + } +} + +struct SubmsgArray { + public: + SubmsgArray(const protobuf::Descriptor* message) : message_(message) { + MessageLayout layout(message); + std::vector sorted_submsgs = + SortedSubmessages(message); + int i = 0; + for (auto submsg : sorted_submsgs) { + if (indexes_.find(submsg->message_type()) != indexes_.end()) { + continue; + } + submsgs_.push_back(submsg->message_type()); + indexes_[submsg->message_type()] = i++; + } + } + + const std::vector& submsgs() const { + return submsgs_; + } + + int GetIndex(const protobuf::FieldDescriptor* field) { + (void)message_; + assert(field->containing_type() == message_); + auto it = indexes_.find(field->message_type()); + assert(it != indexes_.end()); + return it->second; + } + + private: + const protobuf::Descriptor* message_; + std::vector submsgs_; + absl::flat_hash_map indexes_; +}; + +typedef std::pair TableEntry; + +uint64_t GetEncodedTag(const protobuf::FieldDescriptor* field) { + protobuf::internal::WireFormatLite::WireType wire_type = + protobuf::internal::WireFormat::WireTypeForField(field); + uint32_t unencoded_tag = + protobuf::internal::WireFormatLite::MakeTag(field->number(), wire_type); + uint8_t tag_bytes[10] = {0}; + protobuf::io::CodedOutputStream::WriteVarint32ToArray(unencoded_tag, + tag_bytes); + uint64_t encoded_tag = 0; + memcpy(&encoded_tag, tag_bytes, sizeof(encoded_tag)); + // TODO: byte-swap for big endian. + return encoded_tag; +} + +int GetTableSlot(const protobuf::FieldDescriptor* field) { + uint64_t tag = GetEncodedTag(field); + if (tag > 0x7fff) { + // Tag must fit within a two-byte varint. 
+ return -1; + } + return (tag & 0xf8) >> 3; +} + +bool TryFillTableEntry(const protobuf::Descriptor* message, + const MessageLayout& layout, + const protobuf::FieldDescriptor* field, + TableEntry& ent) { + std::string type = ""; + std::string cardinality = ""; + switch (field->type()) { + case protobuf::FieldDescriptor::TYPE_BOOL: + type = "b1"; + break; + case protobuf::FieldDescriptor::TYPE_INT32: + case protobuf::FieldDescriptor::TYPE_ENUM: + case protobuf::FieldDescriptor::TYPE_UINT32: + type = "v4"; + break; + case protobuf::FieldDescriptor::TYPE_INT64: + case protobuf::FieldDescriptor::TYPE_UINT64: + type = "v8"; + break; + case protobuf::FieldDescriptor::TYPE_FIXED32: + case protobuf::FieldDescriptor::TYPE_SFIXED32: + case protobuf::FieldDescriptor::TYPE_FLOAT: + type = "f4"; + break; + case protobuf::FieldDescriptor::TYPE_FIXED64: + case protobuf::FieldDescriptor::TYPE_SFIXED64: + case protobuf::FieldDescriptor::TYPE_DOUBLE: + type = "f8"; + break; + case protobuf::FieldDescriptor::TYPE_SINT32: + type = "z4"; + break; + case protobuf::FieldDescriptor::TYPE_SINT64: + type = "z8"; + break; + case protobuf::FieldDescriptor::TYPE_STRING: + if (field->file()->syntax() == protobuf::FileDescriptor::SYNTAX_PROTO3) { + // Only proto3 validates UTF-8. + type = "s"; + break; + } + ABSL_FALLTHROUGH_INTENDED; + case protobuf::FieldDescriptor::TYPE_BYTES: + type = "b"; + break; + case protobuf::FieldDescriptor::TYPE_MESSAGE: + if (field->is_map()) { + return false; // Not supported yet (ever?). + } + type = "m"; + break; + default: + return false; // Not supported yet. + } + + switch (field->label()) { + case protobuf::FieldDescriptor::LABEL_REPEATED: + if (field->is_packed()) { + cardinality = "p"; + } else { + cardinality = "r"; + } + break; + case protobuf::FieldDescriptor::LABEL_OPTIONAL: + case protobuf::FieldDescriptor::LABEL_REQUIRED: + if (field->real_containing_oneof()) { + cardinality = "o"; + } else { + cardinality = "s"; + } + break; + } + + uint64_t expected_tag = GetEncodedTag(field); + MessageLayout::Size offset = layout.GetFieldOffset(field); + + // Data is: + // + // 48 32 16 0 + // |--------|--------|--------|--------|--------|--------|--------|--------| + // | offset (16) |case offset (16) |presence| submsg | exp. tag (16) | + // |--------|--------|--------|--------|--------|--------|--------|--------| + // + // - |presence| is either hasbit index or field number for oneofs. + + uint64_t data = offset.size64 << 48 | expected_tag; + + if (field->is_repeated()) { + // No hasbit/oneof-related fields. + } if (field->real_containing_oneof()) { + MessageLayout::Size case_offset = + layout.GetOneofCaseOffset(field->real_containing_oneof()); + if (case_offset.size64 > 0xffff) return false; + assert(field->number() < 256); + data |= field->number() << 24; + data |= case_offset.size64 << 32; + } else { + uint64_t hasbit_index = 63; // No hasbit (set a high, unused bit). + if (layout.HasHasbit(field)) { + hasbit_index = layout.GetHasbitIndex(field); + if (hasbit_index > 31) return false; + } + data |= hasbit_index << 24; + } + + if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + SubmsgArray submsg_array(message); + uint64_t idx = submsg_array.GetIndex(field); + if (idx > 255) return false; + data |= idx << 16; + + std::string size_ceil = "max"; + size_t size = SIZE_MAX; + if (field->message_type()->file() == field->file()) { + // We can only be guaranteed the size of the sub-message if it is in the + // same file as us. 
We could relax this to increase the speed of + // cross-file sub-message parsing if we are comfortable requiring that + // users compile all messages at the same time. + MessageLayout sub_layout(field->message_type()); + size = sub_layout.message_size().size64 + 8; + } + std::vector breaks = {64, 128, 192, 256}; + for (auto brk : breaks) { + if (size <= brk) { + size_ceil = std::to_string(brk); + break; + } + } + ent.first = absl::Substitute("upb_p$0$1_$2bt_max$3b", cardinality, type, + expected_tag > 0xff ? "2" : "1", size_ceil); + + } else { + ent.first = absl::Substitute("upb_p$0$1_$2bt", cardinality, type, + expected_tag > 0xff ? "2" : "1"); + } + ent.second = data; + return true; +} + +std::vector FastDecodeTable(const protobuf::Descriptor* message, + const MessageLayout& layout) { + std::vector table; + for (const auto field : FieldHotnessOrder(message)) { + TableEntry ent; + int slot = GetTableSlot(field); + // std::cerr << "table slot: " << field->number() << ": " << slot << "\n"; + if (slot < 0) { + // Tag can't fit in the table. + continue; + } + if (!TryFillTableEntry(message, layout, field, ent)) { + // Unsupported field type or offset, hasbit index, etc. doesn't fit. + continue; + } + while ((size_t)slot >= table.size()) { + size_t size = std::max(static_cast(1), table.size() * 2); + table.resize(size, TableEntry{"fastdecode_generic", 0}); + } + if (table[slot].first != "fastdecode_generic") { + // A hotter field already filled this slot. + continue; + } + table[slot] = ent; + } + return table; +} + +void WriteField(const protobuf::FieldDescriptor* field, + absl::string_view offset, absl::string_view presence, + int submsg_index, Output& output) { + std::string mode; + if (field->is_map()) { + mode = "_UPB_MODE_MAP"; + } else if (field->is_repeated()) { + mode = "_UPB_MODE_ARRAY"; + } else { + mode = "_UPB_MODE_SCALAR"; + } + + if (field->is_packed()) { + absl::StrAppend(&mode, " | _UPB_MODE_IS_PACKED"); + } + + output("{$0, $1, $2, $3, $4, $5}", field->number(), offset, presence, + submsg_index, TableDescriptorType(field), mode); +} + +// Writes a single field into a .upb.c source file. +void WriteMessageField(const protobuf::FieldDescriptor* field, + const MessageLayout& layout, int submsg_index, + Output& output) { + std::string presence = "0"; + + if (MessageLayout::HasHasbit(field)) { + int index = layout.GetHasbitIndex(field); + assert(index != 0); + presence = absl::StrCat(index); + } else if (field->real_containing_oneof()) { + MessageLayout::Size case_offset = + layout.GetOneofCaseOffset(field->real_containing_oneof()); + + // We encode as negative to distinguish from hasbits. + case_offset.size32 = ~case_offset.size32; + case_offset.size64 = ~case_offset.size64; + assert(case_offset.size32 < 0); + assert(case_offset.size64 < 0); + presence = GetSizeInit(case_offset); + } + + output(" "); + WriteField(field, GetSizeInit(layout.GetFieldOffset(field)), presence, + submsg_index, output); + output(",\n"); +} + +// Writes a single message into a .upb.c source file. 
+void WriteMessage(const protobuf::Descriptor* message, Output& output, + bool fasttable_enabled) { + std::string msg_name = ToCIdent(message->full_name()); + std::string fields_array_ref = "NULL"; + std::string submsgs_array_ref = "NULL"; + uint8_t dense_below = 0; + const int dense_below_max = std::numeric_limits::max(); + MessageLayout layout(message); + SubmsgArray submsg_array(message); + + if (!submsg_array.submsgs().empty()) { + // TODO(haberman): could save a little bit of space by only generating a + // "submsgs" array for every strongly-connected component. + std::string submsgs_array_name = msg_name + "_submsgs"; + submsgs_array_ref = "&" + submsgs_array_name + "[0]"; + output("static const upb_msglayout *const $0[$1] = {\n", + submsgs_array_name, submsg_array.submsgs().size()); + + for (auto submsg : submsg_array.submsgs()) { + output(" &$0,\n", MessageInit(submsg)); + } + + output("};\n\n"); + } + + std::vector field_number_order = + FieldNumberOrder(message); + if (!field_number_order.empty()) { + std::string fields_array_name = msg_name + "__fields"; + fields_array_ref = "&" + fields_array_name + "[0]"; + output("static const upb_msglayout_field $0[$1] = {\n", + fields_array_name, field_number_order.size()); + for (int i = 0; i < static_cast(field_number_order.size()); i++) { + auto field = field_number_order[i]; + int submsg_index = 0; + + if (i < dense_below_max && field->number() == i + 1 && + (i == 0 || field_number_order[i - 1]->number() == i)) { + dense_below = i + 1; + } + + if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + submsg_index = submsg_array.GetIndex(field); + } + + WriteMessageField(field, layout, submsg_index, output); + } + output("};\n\n"); + } + + std::vector table; + uint8_t table_mask = -1; + + if (fasttable_enabled) { + table = FastDecodeTable(message, layout); + } + + if (table.size() > 1) { + assert((table.size() & (table.size() - 1)) == 0); + table_mask = (table.size() - 1) << 3; + } + + output("const upb_msglayout $0 = {\n", MessageInit(message)); + output(" $0,\n", submsgs_array_ref); + output(" $0,\n", fields_array_ref); + output(" $0, $1, $2, $3, $4,\n", GetSizeInit(layout.message_size()), + field_number_order.size(), + "false", // TODO: extendable + dense_below, + table_mask + ); + if (!table.empty()) { + output(" UPB_FASTTABLE_INIT({\n"); + for (const auto& ent : table) { + output(" {0x$1, &$0},\n", ent.first, + absl::StrCat(absl::Hex(ent.second, absl::kZeroPad16))); + } + output(" }),\n"); + } + output("};\n\n"); +} + +void WriteMessages(const protobuf::FileDescriptor* file, Output& output, + bool fasttable_enabled) { + for (auto* message : SortedMessages(file)) { + WriteMessage(message, output, fasttable_enabled); + } +} + +// Writes a .upb.c source file. 
+void WriteSource(const protobuf::FileDescriptor* file, Output& output, + bool fasttable_enabled) { + EmitFileWarning(file, output); + + output( + "#include \n" + "#include \"upb/msg_internal.h\"\n" + "#include \"$0\"\n", + HeaderFilename(file->name())); + + for (int i = 0; i < file->dependency_count(); i++) { + output("#include \"$0\"\n", HeaderFilename(file->dependency(i)->name())); + } + + output( + "\n" + "#include \"upb/port_def.inc\"\n" + "\n"); + + WriteMessages(file, output, fasttable_enabled); + + output("#include \"upb/port_undef.inc\"\n"); + output("\n"); +} + +class Generator : public protoc::CodeGenerator { + ~Generator() override {} + bool Generate(const protobuf::FileDescriptor* file, + const std::string& parameter, protoc::GeneratorContext* context, + std::string* error) const override; + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } +}; + +bool Generator::Generate(const protobuf::FileDescriptor* file, + const std::string& parameter, + protoc::GeneratorContext* context, + std::string* error) const { + bool fasttable_enabled = false; + std::vector> params; + google::protobuf::compiler::ParseGeneratorParameter(parameter, ¶ms); + + for (const auto& pair : params) { + if (pair.first == "fasttable") { + fasttable_enabled = true; + } else { + *error = "Unknown parameter: " + pair.first; + return false; + } + } + + Output h_output(context->Open(HeaderFilename(file->name()))); + WriteHeader(file, h_output); + + Output c_output(context->Open(SourceFilename(file->name()))); + WriteSource(file, c_output, fasttable_enabled); + + return true; +} + +} // namespace +} // namespace upbc + +int main(int argc, char** argv) { + std::unique_ptr generator( + new upbc::Generator()); + return google::protobuf::compiler::PluginMain(argc, argv, generator.get()); +} diff --git a/third_party/upb/upbc/protoc-gen-upbdefs.cc b/third_party/upb/upbc/protoc-gen-upbdefs.cc new file mode 100644 index 00000000..4ec4163f --- /dev/null +++ b/third_party/upb/upbc/protoc-gen-upbdefs.cc @@ -0,0 +1,207 @@ +// Copyright (c) 2009-2021, Google LLC +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Google LLC nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +#include "google/protobuf/compiler/code_generator.h" +#include "google/protobuf/compiler/plugin.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor.pb.h" +#include "upbc/common.h" + +namespace upbc { +namespace { + +namespace protoc = ::google::protobuf::compiler; +namespace protobuf = ::google::protobuf; + +std::string DefInitSymbol(const protobuf::FileDescriptor *file) { + return ToCIdent(file->name()) + "_upbdefinit"; +} + +static std::string DefHeaderFilename(std::string proto_filename) { + return StripExtension(proto_filename) + ".upbdefs.h"; +} + +static std::string DefSourceFilename(std::string proto_filename) { + return StripExtension(proto_filename) + ".upbdefs.c"; +} + +void GenerateMessageDefAccessor(const protobuf::Descriptor* d, Output& output) { + output("UPB_INLINE const upb_msgdef *$0_getmsgdef(upb_symtab *s) {\n", + ToCIdent(d->full_name())); + output(" _upb_symtab_loaddefinit(s, &$0);\n", DefInitSymbol(d->file())); + output(" return upb_symtab_lookupmsg(s, \"$0\");\n", d->full_name()); + output("}\n"); + output("\n"); + + for (int i = 0; i < d->nested_type_count(); i++) { + GenerateMessageDefAccessor(d->nested_type(i), output); + } +} + +void WriteDefHeader(const protobuf::FileDescriptor* file, Output& output) { + EmitFileWarning(file, output); + + output( + "#ifndef $0_UPBDEFS_H_\n" + "#define $0_UPBDEFS_H_\n\n" + "#include \"upb/def.h\"\n" + "#include \"upb/port_def.inc\"\n" + "#ifdef __cplusplus\n" + "extern \"C\" {\n" + "#endif\n\n", + ToPreproc(file->name())); + + output("#include \"upb/def.h\"\n"); + output("\n"); + output("#include \"upb/port_def.inc\"\n"); + output("\n"); + + output("extern upb_def_init $0;\n", DefInitSymbol(file)); + output("\n"); + + for (int i = 0; i < file->message_type_count(); i++) { + GenerateMessageDefAccessor(file->message_type(i), output); + } + + output( + "#ifdef __cplusplus\n" + "} /* extern \"C\" */\n" + "#endif\n" + "\n" + "#include \"upb/port_undef.inc\"\n" + "\n" + "#endif /* $0_UPBDEFS_H_ */\n", + ToPreproc(file->name())); +} + + +void WriteDefSource(const protobuf::FileDescriptor* file, Output& output) { + EmitFileWarning(file, output); + + output("#include \"upb/def.h\"\n"); + output("#include \"$0\"\n", DefHeaderFilename(file->name())); + output("\n"); + + for (int i = 0; i < file->dependency_count(); i++) { + output("extern upb_def_init $0;\n", DefInitSymbol(file->dependency(i))); + } + + std::vector file_messages = + SortedMessages(file); + + for (auto message : file_messages) { + output("extern const upb_msglayout $0;\n", MessageInit(message)); + } + output("\n"); + + if (!file_messages.empty()) { + output("static const upb_msglayout *layouts[$0] = {\n", file_messages.size()); + for (auto message : file_messages) { + output(" &$0,\n", MessageInit(message)); + } + output("};\n"); + output("\n"); + } + + protobuf::FileDescriptorProto file_proto; + file->CopyTo(&file_proto); + std::string file_data; + file_proto.SerializeToString(&file_data); + + output("static const char 
descriptor[$0] = {", file_data.size()); + + // C90 only guarantees that strings can be up to 509 characters, and some + // implementations have limits here (for example, MSVC only allows 64k: + // https://docs.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/fatal-error-c1091. + // So we always emit an array instead of a string. + for (size_t i = 0; i < file_data.size();) { + for (size_t j = 0; j < 25 && i < file_data.size(); ++i, ++j) { + output("'$0', ", absl::CEscape(file_data.substr(i, 1))); + } + output("\n"); + } + output("};\n\n"); + + output("static upb_def_init *deps[$0] = {\n", file->dependency_count() + 1); + for (int i = 0; i < file->dependency_count(); i++) { + output(" &$0,\n", DefInitSymbol(file->dependency(i))); + } + output(" NULL\n"); + output("};\n"); + output("\n"); + + output("upb_def_init $0 = {\n", DefInitSymbol(file)); + output(" deps,\n"); + if (file_messages.empty()) { + output(" NULL,\n"); + } else { + output(" layouts,\n"); + } + output(" \"$0\",\n", file->name()); + output(" UPB_STRVIEW_INIT(descriptor, $0)\n", file_data.size()); + output("};\n"); +} + +class Generator : public protoc::CodeGenerator { + ~Generator() override {} + bool Generate(const protobuf::FileDescriptor* file, + const std::string& parameter, protoc::GeneratorContext* context, + std::string* error) const override; + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } +}; + +bool Generator::Generate(const protobuf::FileDescriptor* file, + const std::string& parameter, + protoc::GeneratorContext* context, + std::string* error) const { + std::vector> params; + google::protobuf::compiler::ParseGeneratorParameter(parameter, ¶ms); + + for (const auto& pair : params) { + *error = "Unknown parameter: " + pair.first; + return false; + } + + Output h_def_output(context->Open(DefHeaderFilename(file->name()))); + WriteDefHeader(file, h_def_output); + + Output c_def_output(context->Open(DefSourceFilename(file->name()))); + WriteDefSource(file, c_def_output); + + return true; +} + +} // namespace +} // namespace upbc + +int main(int argc, char** argv) { + std::unique_ptr generator( + new upbc::Generator()); + return google::protobuf::compiler::PluginMain(argc, argv, generator.get()); +} diff --git a/tools/codegen/core/gen_hpack_tables.cc b/tools/codegen/core/gen_hpack_tables.cc new file mode 100644 index 00000000..d01a90eb --- /dev/null +++ b/tools/codegen/core/gen_hpack_tables.cc @@ -0,0 +1,247 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/tools/codegen/core/gen_hpack_tables.cc b/tools/codegen/core/gen_hpack_tables.cc
new file mode 100644
index 00000000..d01a90eb
--- /dev/null
+++ b/tools/codegen/core/gen_hpack_tables.cc
@@ -0,0 +1,247 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+/* generates constant tables for hpack.cc */
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <grpc/support/log.h>
+#include "src/core/ext/transport/chttp2/transport/huffsyms.h"
+
+/*
+ * Huffman decoder table generation
+ */
+
+#define MAXHUFFSTATES 1024
+
+/* represents a set of symbols as an array of booleans indicating inclusion */
+typedef struct { char included[GRPC_CHTTP2_NUM_HUFFSYMS]; } symset;
+/* represents a lookup table indexed by a nibble */
+typedef struct { unsigned values[16]; } nibblelut;
+
+#define NOT_SET (~(unsigned)0)
+
+/* returns a symset that includes all possible symbols */
+static symset symset_all(void) {
+  symset x;
+  memset(x.included, 1, sizeof(x.included));
+  return x;
+}
+
+/* returns a symset that includes no symbols */
+static symset symset_none(void) {
+  symset x;
+  memset(x.included, 0, sizeof(x.included));
+  return x;
+}
+
+/* returns an empty nibblelut */
+static nibblelut nibblelut_empty(void) {
+  nibblelut x;
+  int i;
+  for (i = 0; i < 16; i++) {
+    x.values[i] = NOT_SET;
+  }
+  return x;
+}
+
+/* counts symbols in a symset - only used for debug builds */
+#ifndef NDEBUG
+static int nsyms(symset s) {
+  int i;
+  int c = 0;
+  for (i = 0; i < GRPC_CHTTP2_NUM_HUFFSYMS; i++) {
+    c += s.included[i] != 0;
+  }
+  return c;
+}
+#endif
+
+/* global table of discovered huffman decoding states */
+static struct {
+  /* the bit offset that this state starts at */
+  unsigned bitofs;
+  /* the set of symbols that this state started with */
+  symset syms;
+
+  /* lookup table for the next state */
+  nibblelut next;
+  /* lookup table for what to emit */
+  nibblelut emit;
+} huffstates[MAXHUFFSTATES];
+static unsigned nhuffstates = 0;
+
+/* given a number of decoded bits and a set of symbols that are live,
+   return the index into the decoder table for this state.
+   set isnew to 1 if this state was previously undiscovered */
+static unsigned state_index(unsigned bitofs, symset syms, unsigned *isnew) {
+  unsigned i;
+  for (i = 0; i < nhuffstates; i++) {
+    if (huffstates[i].bitofs != bitofs) continue;
+    if (0 != memcmp(huffstates[i].syms.included, syms.included,
+                    GRPC_CHTTP2_NUM_HUFFSYMS))
+      continue;
+    *isnew = 0;
+    return i;
+  }
+  GPR_ASSERT(nhuffstates != MAXHUFFSTATES);
+
+  i = nhuffstates;
+  nhuffstates++;
+
+  huffstates[i].bitofs = bitofs;
+  huffstates[i].syms = syms;
+  huffstates[i].next = nibblelut_empty();
+  huffstates[i].emit = nibblelut_empty();
+  *isnew = 1;
+  return i;
+}
+
+/* recursively build a decoding table
+
+   state   - the huffman state that we are trying to fill in
+   nibble  - the current nibble
+   nibbits - the number of bits in the nibble that have been filled in
+   bitofs  - the number of bits of symbol that have been decoded
+   emit    - the symbol to emit on this nibble (or -1 if no symbol has been
+             found)
+   syms    - the set of symbols that could be matched */
+static void build_dec_tbl(unsigned state, unsigned nibble, int nibbits,
+                          unsigned bitofs, unsigned emit, symset syms) {
+  unsigned i;
+  unsigned bit;
+
+  /* If we have four bits in the nibble we're looking at, then we can fill in
+     a slot in the lookup tables. */
+  if (nibbits == 4) {
+    unsigned isnew;
+    /* Find the state that we are in: this may be a new state, in which case
+       we recurse to fill it in, or we may have already seen this state, in
+       which case the recursion terminates */
+    unsigned st = state_index(bitofs, syms, &isnew);
+    GPR_ASSERT(huffstates[state].next.values[nibble] == NOT_SET);
+    huffstates[state].next.values[nibble] = st;
+    huffstates[state].emit.values[nibble] = emit;
+    if (isnew) {
+      build_dec_tbl(st, 0, 0, bitofs, NOT_SET, syms);
+    }
+    return;
+  }
+
+  assert(nsyms(syms));
+
+  /* A bit can be 0 or 1 */
+  for (bit = 0; bit < 2; bit++) {
+    /* walk over active symbols and see if they have this bit set */
+    symset nextsyms = symset_none();
+    for (i = 0; i < GRPC_CHTTP2_NUM_HUFFSYMS; i++) {
+      if (!syms.included[i]) continue; /* disregard inactive symbols */
+      if (((grpc_chttp2_huffsyms[i].bits >>
+            (grpc_chttp2_huffsyms[i].length - bitofs - 1)) &
+           1) == bit) {
+        /* the bit is set, include it in the next recursive set */
+        if (grpc_chttp2_huffsyms[i].length == bitofs + 1) {
+          /* additionally, we've gotten to the end of a symbol - this is a
+             special recursion step: re-activate all the symbols, reset
+             bitofs to zero, and recurse */
+          build_dec_tbl(state, (nibble << 1) | bit, nibbits + 1, 0, i,
+                        symset_all());
+          /* skip the remainder of this loop */
+          goto next;
+        }
+        nextsyms.included[i] = 1;
+      }
+    }
+    /* recurse down for this bit */
+    build_dec_tbl(state, (nibble << 1) | bit, nibbits + 1, bitofs + 1, emit,
+                  nextsyms);
+  next:;
+  }
+}
+
+static nibblelut ctbl[MAXHUFFSTATES];
+static int nctbl;
+
+static int ctbl_idx(nibblelut x) {
+  int i;
+  for (i = 0; i < nctbl; i++) {
+    if (0 == memcmp(&x, ctbl + i, sizeof(nibblelut))) return i;
+  }
+  ctbl[i] = x;
+  nctbl++;
+  return i;
+}
+
+static void dump_ctbl(const char *name) {
+  int i, j;
+  printf("static const gpr_int16 %s[%d*16] = {\n", name, nctbl);
+  for (i = 0; i < nctbl; i++) {
+    for (j = 0; j < 16; j++) {
+      printf("%d,", ctbl[i].values[j]);
+    }
+    printf("\n");
+  }
+  printf("};\n");
+}
+
+static void generate_huff_tables(void) {
+  unsigned i;
+  build_dec_tbl(state_index(0, symset_all(), &i), 0, 0, 0, NOT_SET,
+                symset_all());
+
+  nctbl = 0;
+  printf("static const gpr_uint8 next_tbl[%d] = {", nhuffstates);
+  for (i = 0; i < nhuffstates; i++) {
+    printf("%d,", ctbl_idx(huffstates[i].next));
+  }
+  printf("};\n");
+  dump_ctbl("next_sub_tbl");
+
+  nctbl = 0;
+  printf("static const gpr_uint16 emit_tbl[%d] = {", nhuffstates);
+  for (i = 0; i < nhuffstates; i++) {
+    printf("%d,", ctbl_idx(huffstates[i].emit));
+  }
+  printf("};\n");
+  dump_ctbl("emit_sub_tbl");
+}
+
+static void generate_base64_huff_encoder_table(void) {
+  static const char alphabet[] =
+      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+  int i;
+
+  printf(
+      "static const struct { gpr_uint16 bits, gpr_uint8 length } "
+      "base64_syms[64] = {\n");
+  for (i = 0; i < 64; i++) {
+    printf("{0x%x, %d},",
+           grpc_chttp2_huffsyms[(unsigned char)alphabet[i]].bits,
+           grpc_chttp2_huffsyms[(unsigned char)alphabet[i]].length);
+  }
+  printf("};\n");
+}
+
+int main(void) {
+  generate_huff_tables();
+  generate_base64_huff_encoder_table();
+
+  return 0;
+}
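The generator above writes its tables to stdout; the arrays are meant to be pasted into the chttp2 HPACK decoder. To make the table shape concrete: `next_tbl[state]` and `emit_tbl[state]` select deduplicated 16-entry rows in `next_sub_tbl` and `emit_sub_tbl`, and each row is indexed by the next 4 input bits. Below is a minimal consumption sketch, assuming hypothetical helpers `have_input()`, `read_nibble()`, and `emit_byte()`, and using -1 as the printed form of NOT_SET; it illustrates the table layout and is not the actual hpack parser code.

    /* Illustrative only: decode one nibble (4 Huffman-coded bits) at a time. */
    unsigned state = 0;
    while (have_input()) {
      int nib = read_nibble();                             /* next 4 bits, 0..15 */
      int sym = emit_sub_tbl[16 * emit_tbl[state] + nib];  /* completed symbol, or -1 */
      if (sym >= 0) emit_byte((unsigned char)sym);
      state = next_sub_tbl[16 * next_tbl[state] + nib];    /* next decoder state */
    }

Because the shortest HPACK Huffman code is 5 bits long, at most one symbol can complete inside any 4-bit nibble, which is why a single emit slot per (state, nibble) pair is enough.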
diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/main.cc b/tools/distrib/python/grpcio_tools/grpc_tools/main.cc
new file mode 100644
index 00000000..2f1c70ff
--- /dev/null
+++ b/tools/distrib/python/grpcio_tools/grpc_tools/main.cc
@@ -0,0 +1,183 @@
+// Copyright 2016 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "grpc_tools/main.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include <google/protobuf/compiler/code_generator.h>
+#include <google/protobuf/compiler/command_line_interface.h>
+#include <google/protobuf/compiler/importer.h>
+#include <google/protobuf/compiler/python/python_generator.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
+
+#include "src/compiler/python_generator.h"
+
+using ::google::protobuf::FileDescriptor;
+using ::google::protobuf::compiler::CodeGenerator;
+using ::google::protobuf::compiler::DiskSourceTree;
+using ::google::protobuf::compiler::GeneratorContext;
+using ::google::protobuf::compiler::Importer;
+using ::google::protobuf::compiler::MultiFileErrorCollector;
+using ::google::protobuf::io::StringOutputStream;
+using ::google::protobuf::io::ZeroCopyOutputStream;
+
+namespace grpc_tools {
+int protoc_main(int argc, char* argv[]) {
+  google::protobuf::compiler::CommandLineInterface cli;
+  cli.AllowPlugins("protoc-");
+
+  // Proto2 Python
+  google::protobuf::compiler::python::Generator py_generator;
+  cli.RegisterGenerator("--python_out", &py_generator,
+                        "Generate Python source file.");
+
+  // gRPC Python
+  grpc_python_generator::GeneratorConfiguration grpc_py_config;
+  grpc_python_generator::PythonGrpcGenerator grpc_py_generator(grpc_py_config);
+  cli.RegisterGenerator("--grpc_python_out", &grpc_py_generator,
+                        "Generate Python source file.");
+
+  return cli.Run(argc, argv);
+}
+
+namespace internal {
+
+class GeneratorContextImpl : public GeneratorContext {
+ public:
+  GeneratorContextImpl(
+      const std::vector<const FileDescriptor*>& parsed_files,
+      std::vector<std::pair<std::string, std::string>>* files_out)
+      : files_(files_out), parsed_files_(parsed_files) {}
+
+  ZeroCopyOutputStream* Open(const std::string& filename) {
+    files_->emplace_back(filename, "");
+    return new StringOutputStream(&(files_->back().second));
+  }
+
+  // NOTE(rbellevi): Equivalent to Open, since all files start out empty.
+  ZeroCopyOutputStream* OpenForAppend(const std::string& filename) {
+    return Open(filename);
+  }
+
+  // NOTE(rbellevi): Equivalent to Open, since all files start out empty.
+  ZeroCopyOutputStream* OpenForInsert(const std::string& filename,
+                                      const std::string& insertion_point) {
+    return Open(filename);
+  }
+
+  void ListParsedFiles(std::vector<const FileDescriptor*>* output) {
+    *output = parsed_files_;
+  }
+
+ private:
+  std::vector<std::pair<std::string, std::string>>* files_;
+  const std::vector<const FileDescriptor*>& parsed_files_;
+};
+
+class ErrorCollectorImpl : public MultiFileErrorCollector {
+ public:
+  ErrorCollectorImpl(std::vector<::grpc_tools::ProtocError>* errors,
+                     std::vector<::grpc_tools::ProtocWarning>* warnings)
+      : errors_(errors), warnings_(warnings) {}
+
+  void AddError(const std::string& filename, int line, int column,
+                const std::string& message) {
+    errors_->emplace_back(filename, line, column, message);
+  }
+
+  void AddWarning(const std::string& filename, int line, int column,
+                  const std::string& message) {
+    warnings_->emplace_back(filename, line, column, message);
+  }
+
+ private:
+  std::vector<::grpc_tools::ProtocError>* errors_;
+  std::vector<::grpc_tools::ProtocWarning>* warnings_;
+};
+
+static void calculate_transitive_closure(
+    const FileDescriptor* descriptor,
+    std::vector<const FileDescriptor*>* transitive_closure,
+    std::unordered_set<const FileDescriptor*>* visited) {
+  for (int i = 0; i < descriptor->dependency_count(); ++i) {
+    const FileDescriptor* dependency = descriptor->dependency(i);
+    if (visited->find(dependency) == visited->end()) {
+      calculate_transitive_closure(dependency, transitive_closure, visited);
+    }
+  }
+  transitive_closure->push_back(descriptor);
+  visited->insert(descriptor);
+}
+
+}  // end namespace internal
+
+static int generate_code(
+    CodeGenerator* code_generator, char* protobuf_path,
+    const std::vector<std::string>* include_paths,
+    std::vector<std::pair<std::string, std::string>>* files_out,
+    std::vector<::grpc_tools::ProtocError>* errors,
+    std::vector<::grpc_tools::ProtocWarning>* warnings) {
+  std::unique_ptr<internal::ErrorCollectorImpl> error_collector(
+      new internal::ErrorCollectorImpl(errors, warnings));
+  std::unique_ptr<DiskSourceTree> source_tree(new DiskSourceTree());
+  for (const auto& include_path : *include_paths) {
+    source_tree->MapPath("", include_path);
+  }
+  Importer importer(source_tree.get(), error_collector.get());
+  const FileDescriptor* parsed_file = importer.Import(protobuf_path);
+  if (parsed_file == nullptr) {
+    return 1;
+  }
+  std::vector<const FileDescriptor*> transitive_closure;
+  std::unordered_set<const FileDescriptor*> visited;
+  internal::calculate_transitive_closure(parsed_file, &transitive_closure,
+                                         &visited);
+  internal::GeneratorContextImpl generator_context(transitive_closure,
+                                                   files_out);
+  std::string error;
+  for (const auto descriptor : transitive_closure) {
+    code_generator->Generate(descriptor, "", &generator_context, &error);
+  }
+  return 0;
+}
+
+int protoc_get_protos(
+    char* protobuf_path, const std::vector<std::string>* include_paths,
+    std::vector<std::pair<std::string, std::string>>* files_out,
+    std::vector<::grpc_tools::ProtocError>* errors,
+    std::vector<::grpc_tools::ProtocWarning>* warnings) {
+  ::google::protobuf::compiler::python::Generator python_generator;
+  return generate_code(&python_generator, protobuf_path, include_paths,
+                       files_out, errors, warnings);
+}
+
+int protoc_get_services(
+    char* protobuf_path, const std::vector<std::string>* include_paths,
+    std::vector<std::pair<std::string, std::string>>* files_out,
+    std::vector<::grpc_tools::ProtocError>* errors,
+    std::vector<::grpc_tools::ProtocWarning>* warnings) {
+  grpc_python_generator::GeneratorConfiguration grpc_py_config;
+  grpc_python_generator::PythonGrpcGenerator grpc_py_generator(grpc_py_config);
+  return generate_code(&grpc_py_generator, protobuf_path, include_paths,
+                       files_out, errors, warnings);
+}
+}  // end namespace grpc_tools
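Taken together, main.cc is the native half of grpcio-tools: protoc_main backs the protoc-style command-line entry point, while protoc_get_protos / protoc_get_services generate code entirely in memory for callers such as the package's Cython layer. The following is a minimal sketch of driving the in-memory path directly from C++, using only the functions declared above; the proto file name and include path are hypothetical, and error handling is reduced to the return code.

    #include <string>
    #include <utility>
    #include <vector>

    #include "grpc_tools/main.h"

    int generate_helloworld_in_memory() {
      char proto_file[] = "helloworld.proto";               // hypothetical input file
      std::vector<std::string> include_paths = {"protos"};  // hypothetical -I root
      std::vector<std::pair<std::string, std::string>> files_out;  // filename -> contents
      std::vector<::grpc_tools::ProtocError> errors;
      std::vector<::grpc_tools::ProtocWarning> warnings;

      // Plain protobuf modules first, then the gRPC service stubs; both calls
      // append (filename, contents) pairs to files_out instead of writing to disk.
      int rc = ::grpc_tools::protoc_get_protos(proto_file, &include_paths,
                                               &files_out, &errors, &warnings);
      if (rc == 0) {
        rc = ::grpc_tools::protoc_get_services(proto_file, &include_paths,
                                               &files_out, &errors, &warnings);
      }
      return rc;
    }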